diff --git a/CHANGELOG.md b/CHANGELOG.md index 1206657d44..77cf2b37c6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,7 +22,7 @@ Please mark all change in change log and use the issue from GitHub - \#1548 Move store/Directory to storage/Operation and add FSHandler - \#1619 Improve compact performance - \#1649 Fix Milvus crash on old CPU -- \#1653 IndexFlat performance improvement for NQ less than thread_number +- \#1653 IndexFlat (SSE) and IndexBinaryFlat performance improvement for small NQ ## Task diff --git a/core/src/index/thirdparty/faiss/utils/BinaryDistance.cpp b/core/src/index/thirdparty/faiss/utils/BinaryDistance.cpp index e89d6aa7ee..ff20a21277 100644 --- a/core/src/index/thirdparty/faiss/utils/BinaryDistance.cpp +++ b/core/src/index/thirdparty/faiss/utils/BinaryDistance.cpp @@ -15,175 +15,196 @@ namespace faiss { - size_t batch_size = 65536; +static const size_t size_1M = 1 * 1024 * 1024; +static const size_t batch_size = 65536; - template - static - void binary_distence_knn_hc( - int bytes_per_code, - float_maxheap_array_t * ha, - const uint8_t * bs1, - const uint8_t * bs2, - size_t n2, - bool order = true, - bool init_heap = true, - ConcurrentBitsetPtr bitset = nullptr) - { - size_t k = ha->k; +template +static +void binary_distence_knn_hc( + int bytes_per_code, + float_maxheap_array_t * ha, + const uint8_t * bs1, + const uint8_t * bs2, + size_t n2, + bool order = true, + bool init_heap = true, + ConcurrentBitsetPtr bitset = nullptr) +{ + size_t k = ha->k; + + if ((bytes_per_code + k * (sizeof(float) + sizeof(int64_t))) * ha->nh < size_1M) { + int thread_max_num = omp_get_max_threads(); + // init hash + size_t thread_hash_size = ha->nh * k; + size_t all_hash_size = thread_hash_size * thread_max_num; + float *value = new float[all_hash_size]; + int64_t *labels = new int64_t[all_hash_size]; + for (int i = 0; i < all_hash_size; i++) { + value[i] = 1.0 / 0.0; + labels[i] = -1; + } + + T *hc = new T[ha->nh]; + for (size_t i = 0; i < ha->nh; i++) { + hc[i].set(bs1 + i * bytes_per_code, bytes_per_code); + } + +#pragma omp parallel for + for (size_t j = 0; j < n2; j++) { + if(!bitset || !bitset->test(j)) { + int thread_no = omp_get_thread_num(); + + const uint8_t * bs2_ = bs2 + j * bytes_per_code; + for (size_t i = 0; i < ha->nh; i++) { + tadis_t dis = hc[i].compute (bs2_); + + float * val_ = value + thread_no * thread_hash_size + i * k; + int64_t * ids_ = labels + thread_no * thread_hash_size + i * k; + if (dis < val_[0]) { + faiss::maxheap_pop (k, val_, ids_); + faiss::maxheap_push (k, val_, ids_, dis, j); + } + } + } + } + + for (size_t t = 1; t < thread_max_num; t++) { + // merge hash + for (size_t i = 0; i < ha->nh; i++) { + float * __restrict value_x = value + i * k; + int64_t * __restrict labels_x = labels + i * k; + float *value_x_t = value_x + t * thread_hash_size; + int64_t *labels_x_t = labels_x + t * thread_hash_size; + for (size_t j = 0; j < k; j++) { + if (value_x_t[j] < value_x[0]) { + faiss::maxheap_pop (k, value_x, labels_x); + faiss::maxheap_push (k, value_x, labels_x, value_x_t[j], labels_x_t[j]); + } + } + } + } + + // copy result + memcpy(ha->val, value, thread_hash_size * sizeof(float)); + memcpy(ha->ids, labels, thread_hash_size * sizeof(int64_t)); + + delete[] hc; + delete[] value; + delete[] labels; + + } else { if (init_heap) ha->heapify (); - int thread_max_num = omp_get_max_threads(); - if (ha->nh < 4) { - // omp for n2 - int all_hash_size = thread_max_num * k; - float *value = new float[all_hash_size]; - int64_t *labels = new int64_t[all_hash_size]; - - for (int i = 0; i < ha->nh; i++) { - T hc (bs1 + i * bytes_per_code, bytes_per_code); - // init hash - for (int i = 0; i < all_hash_size; i++) { - value[i] = 1.0 / 0.0; - } + const size_t block_size = batch_size; + for (size_t j0 = 0; j0 < n2; j0 += block_size) { + const size_t j1 = std::min(j0 + block_size, n2); #pragma omp parallel for - for (size_t j = 0; j < n2; j++) { - if(!bitset || !bitset->test(j)) { - const uint8_t * bs2_ = bs2 + j * bytes_per_code; - tadis_t dis = hc.compute (bs2_); + for (size_t i = 0; i < ha->nh; i++) { + T hc (bs1 + i * bytes_per_code, bytes_per_code); - int thread_no = omp_get_thread_num(); - float * __restrict val_ = value + thread_no * k; - int64_t * __restrict ids_ = labels + thread_no * k; - if (dis < val_[0]) { - faiss::maxheap_pop (k, val_, ids_); - faiss::maxheap_push (k, val_, ids_, dis, j); - } - } - } - // merge hash + const uint8_t * bs2_ = bs2 + j0 * bytes_per_code; + tadis_t dis; tadis_t * __restrict bh_val_ = ha->val + i * k; int64_t * __restrict bh_ids_ = ha->ids + i * k; - for (int i = 0; i < all_hash_size; i++) { - if (value[i] < bh_val_[0]) { - faiss::maxheap_pop (k, bh_val_, bh_ids_); - faiss::maxheap_push (k, bh_val_, bh_ids_, value[i], labels[i]); - } - } - } - delete[] value; - delete[] labels; - - } else { - const size_t block_size = batch_size; - for (size_t j0 = 0; j0 < n2; j0 += block_size) { - const size_t j1 = std::min(j0 + block_size, n2); -#pragma omp parallel for - for (size_t i = 0; i < ha->nh; i++) { - T hc (bs1 + i * bytes_per_code, bytes_per_code); - - const uint8_t * bs2_ = bs2 + j0 * bytes_per_code; - tadis_t dis; - tadis_t * __restrict bh_val_ = ha->val + i * k; - int64_t * __restrict bh_ids_ = ha->ids + i * k; - size_t j; - for (j = j0; j < j1; j++, bs2_+= bytes_per_code) { - if(!bitset || !bitset->test(j)){ - dis = hc.compute (bs2_); - if (dis < bh_val_[0]) { - faiss::maxheap_pop (k, bh_val_, bh_ids_); - faiss::maxheap_push (k, bh_val_, bh_ids_, dis, j); - } + size_t j; + for (j = j0; j < j1; j++, bs2_+= bytes_per_code) { + if(!bitset || !bitset->test(j)){ + dis = hc.compute (bs2_); + if (dis < bh_val_[0]) { + faiss::maxheap_pop (k, bh_val_, bh_ids_); + faiss::maxheap_push (k, bh_val_, bh_ids_, dis, j); } } - } + } } - - if (order) ha->reorder (); } - void binary_distence_knn_hc ( - MetricType metric_type, - float_maxheap_array_t * ha, - const uint8_t * a, - const uint8_t * b, - size_t nb, - size_t ncodes, - int order, - ConcurrentBitsetPtr bitset) - { - switch (metric_type) { - case METRIC_Jaccard: - case METRIC_Tanimoto: - switch (ncodes) { + if (order) ha->reorder (); +} + +void binary_distence_knn_hc ( + MetricType metric_type, + float_maxheap_array_t * ha, + const uint8_t * a, + const uint8_t * b, + size_t nb, + size_t ncodes, + int order, + ConcurrentBitsetPtr bitset) +{ + switch (metric_type) { + case METRIC_Jaccard: + case METRIC_Tanimoto: + switch (ncodes) { #define binary_distence_knn_hc_jaccard(ncodes) \ - case ncodes: \ - binary_distence_knn_hc \ - (ncodes, ha, a, b, nb, order, true, bitset); \ - break; - binary_distence_knn_hc_jaccard(8); - binary_distence_knn_hc_jaccard(16); - binary_distence_knn_hc_jaccard(32); - binary_distence_knn_hc_jaccard(64); - binary_distence_knn_hc_jaccard(128); - binary_distence_knn_hc_jaccard(256); - binary_distence_knn_hc_jaccard(512); + case ncodes: \ + binary_distence_knn_hc \ + (ncodes, ha, a, b, nb, order, true, bitset); \ + break; + binary_distence_knn_hc_jaccard(8); + binary_distence_knn_hc_jaccard(16); + binary_distence_knn_hc_jaccard(32); + binary_distence_knn_hc_jaccard(64); + binary_distence_knn_hc_jaccard(128); + binary_distence_knn_hc_jaccard(256); + binary_distence_knn_hc_jaccard(512); #undef binary_distence_knn_hc_jaccard - default: - binary_distence_knn_hc - (ncodes, ha, a, b, nb, order, true, bitset); - break; - } - break; - - case METRIC_Substructure: - switch (ncodes) { -#define binary_distence_knn_hc_Substructure(ncodes) \ - case ncodes: \ - binary_distence_knn_hc \ - (ncodes, ha, a, b, nb, order, true, bitset); \ - break; - binary_distence_knn_hc_Substructure(8); - binary_distence_knn_hc_Substructure(16); - binary_distence_knn_hc_Substructure(32); - binary_distence_knn_hc_Substructure(64); - binary_distence_knn_hc_Substructure(128); - binary_distence_knn_hc_Substructure(256); - binary_distence_knn_hc_Substructure(512); -#undef binary_distence_knn_hc_Substructure - default: - binary_distence_knn_hc - (ncodes, ha, a, b, nb, order, true, bitset); - break; - } - break; - - case METRIC_Superstructure: - switch (ncodes) { -#define binary_distence_knn_hc_Superstructure(ncodes) \ - case ncodes: \ - binary_distence_knn_hc \ - (ncodes, ha, a, b, nb, order, true, bitset); \ - break; - binary_distence_knn_hc_Superstructure(8); - binary_distence_knn_hc_Superstructure(16); - binary_distence_knn_hc_Superstructure(32); - binary_distence_knn_hc_Superstructure(64); - binary_distence_knn_hc_Superstructure(128); - binary_distence_knn_hc_Superstructure(256); - binary_distence_knn_hc_Superstructure(512); -#undef binary_distence_knn_hc_Superstructure - default: - binary_distence_knn_hc - (ncodes, ha, a, b, nb, order, true, bitset); - break; - } - break; - default: + binary_distence_knn_hc + (ncodes, ha, a, b, nb, order, true, bitset); break; } + break; + + case METRIC_Substructure: + switch (ncodes) { +#define binary_distence_knn_hc_Substructure(ncodes) \ + case ncodes: \ + binary_distence_knn_hc \ + (ncodes, ha, a, b, nb, order, true, bitset); \ + break; + binary_distence_knn_hc_Substructure(8); + binary_distence_knn_hc_Substructure(16); + binary_distence_knn_hc_Substructure(32); + binary_distence_knn_hc_Substructure(64); + binary_distence_knn_hc_Substructure(128); + binary_distence_knn_hc_Substructure(256); + binary_distence_knn_hc_Substructure(512); +#undef binary_distence_knn_hc_Substructure + default: + binary_distence_knn_hc + (ncodes, ha, a, b, nb, order, true, bitset); + break; + } + break; + + case METRIC_Superstructure: + switch (ncodes) { +#define binary_distence_knn_hc_Superstructure(ncodes) \ + case ncodes: \ + binary_distence_knn_hc \ + (ncodes, ha, a, b, nb, order, true, bitset); \ + break; + binary_distence_knn_hc_Superstructure(8); + binary_distence_knn_hc_Superstructure(16); + binary_distence_knn_hc_Superstructure(32); + binary_distence_knn_hc_Superstructure(64); + binary_distence_knn_hc_Superstructure(128); + binary_distence_knn_hc_Superstructure(256); + binary_distence_knn_hc_Superstructure(512); +#undef binary_distence_knn_hc_Superstructure + default: + binary_distence_knn_hc + (ncodes, ha, a, b, nb, order, true, bitset); + break; + } + break; + + default: + break; } +} } diff --git a/core/src/index/thirdparty/faiss/utils/distances.cpp b/core/src/index/thirdparty/faiss/utils/distances.cpp index 68bc6d9b82..ac0e3b8321 100644 --- a/core/src/index/thirdparty/faiss/utils/distances.cpp +++ b/core/src/index/thirdparty/faiss/utils/distances.cpp @@ -154,50 +154,68 @@ static void knn_inner_product_sse (const float * x, size_t k = res->k; size_t thread_max_num = omp_get_max_threads(); - if (nx < 4) { - // omp for ny - size_t all_hash_size = thread_max_num * k; - float *value = new float[all_hash_size]; - int64_t *labels = new int64_t[all_hash_size]; + + size_t thread_hash_size = nx * k; + size_t all_hash_size = thread_hash_size * thread_max_num; + float *value = new float[all_hash_size]; + int64_t *labels = new int64_t[all_hash_size]; + + // init hash + for (size_t i = 0; i < all_hash_size; i++) { + value[i] = -1.0 / 0.0; + labels[i] = -1; + } - for (size_t i = 0; i < nx; i++) { - // init hash - for (size_t i = 0; i < all_hash_size; i++) { - value[i] = -1.0 / 0.0; - } - const float *x_i = x + i * d; #pragma omp parallel for - for (size_t j = 0; j < ny; j++) { - if(!bitset || !bitset->test(j)) { - const float *y_j = y + j * d; - float ip = fvec_inner_product (x_i, y_j, d); + for (size_t j = 0; j < ny; j++) { + if(!bitset || !bitset->test(j)) { + size_t thread_no = omp_get_thread_num(); + const float *y_j = y + j * d; + for (size_t i = 0; i < nx; i++) { + const float *x_i = x + i * d; + float ip = fvec_inner_product (x_i, y_j, d); - size_t thread_no = omp_get_thread_num(); - float * __restrict val_ = value + thread_no * k; - int64_t * __restrict ids_ = labels + thread_no * k; - if (ip > val_[0]) { - minheap_pop (k, val_, ids_); - minheap_push (k, val_, ids_, ip, j); - } + float * val_ = value + thread_no * thread_hash_size + i * k; + int64_t * ids_ = labels + thread_no * thread_hash_size + i * k; + if (ip > val_[0]) { + minheap_pop (k, val_, ids_); + minheap_push (k, val_, ids_, ip, j); } } - - // merge hash - float * __restrict simi = res->get_val(i); - int64_t * __restrict idxi = res->get_ids (i); - minheap_heapify (k, simi, idxi); - for (size_t i = 0; i < all_hash_size; i++) { - if (value[i] > simi[0]) { - minheap_pop (k, simi, idxi); - minheap_push (k, simi, idxi, value[i], labels[i]); - } - } - minheap_reorder (k, simi, idxi); } - delete[] value; - delete[] labels; + } - } else { + for (size_t t = 1; t < thread_max_num; t++) { + // merge hash + for (size_t i = 0; i < nx; i++) { + float * __restrict value_x = value + i * k; + int64_t * __restrict labels_x = labels + i * k; + float *value_x_t = value_x + t * thread_hash_size; + int64_t *labels_x_t = labels_x + t * thread_hash_size; + for (size_t j = 0; j < k; j++) { + if (value_x_t[j] > value_x[0]) { + minheap_pop (k, value_x, labels_x); + minheap_push (k, value_x, labels_x, value_x_t[j], labels_x_t[j]); + } + } + } + } + + for (size_t i = 0; i < nx; i++) { + float * value_x = value + i * k; + int64_t * labels_x = labels + i * k; + minheap_reorder (k, value_x, labels_x); + } + + // copy result + memcpy(res->val, value, thread_hash_size * sizeof(float)); + memcpy(res->ids, labels, thread_hash_size * sizeof(int64_t)); + + delete[] value; + delete[] labels; + +/* + else { size_t check_period = InterruptCallback::get_period_hint (ny * d); check_period *= thread_max_num; @@ -230,6 +248,7 @@ static void knn_inner_product_sse (const float * x, InterruptCallback::check (); } } + */ } static void knn_L2sqr_sse ( @@ -242,55 +261,68 @@ static void knn_L2sqr_sse ( size_t k = res->k; size_t thread_max_num = omp_get_max_threads(); - if (nx < 4) { - // omp for ny - size_t all_hash_size = thread_max_num * k; - float *value = new float[all_hash_size]; - int64_t *labels = new int64_t[all_hash_size]; - for (size_t i = 0; i < nx; i++) { - // init hash - for (size_t i = 0; i < all_hash_size; i++) { - value[i] = 1.0 / 0.0; - } - for (size_t i = 0; i < k; i++) { - labels[i] = -1; - } - const float *x_i = x + i * d; + size_t thread_hash_size = nx * k; + size_t all_hash_size = thread_hash_size * thread_max_num; + float *value = new float[all_hash_size]; + int64_t *labels = new int64_t[all_hash_size]; + + // init hash + for (size_t i = 0; i < all_hash_size; i++) { + value[i] = 1.0 / 0.0; + labels[i] = -1; + } + #pragma omp parallel for - for (size_t j = 0; j < ny; j++) { - if(!bitset || !bitset->test(j)) { - const float *y_j = y + j * d; - float disij = fvec_L2sqr (x_i, y_j, d); + for (size_t j = 0; j < ny; j++) { + if(!bitset || !bitset->test(j)) { + size_t thread_no = omp_get_thread_num(); + const float *y_j = y + j * d; + for (size_t i = 0; i < nx; i++) { + const float *x_i = x + i * d; + float disij = fvec_L2sqr (x_i, y_j, d); - size_t thread_no = omp_get_thread_num(); - float * __restrict val_ = value + thread_no * k; - int64_t * __restrict ids_ = labels + thread_no * k; - if (disij < val_[0]) { - maxheap_pop (k, val_, ids_); - maxheap_push (k, val_, ids_, disij, j); - } + float * val_ = value + thread_no * thread_hash_size + i * k; + int64_t * ids_ = labels + thread_no * thread_hash_size + i * k; + if (disij < val_[0]) { + maxheap_pop (k, val_, ids_); + maxheap_push (k, val_, ids_, disij, j); } } - - // merge hash - float * __restrict simi = res->get_val(i); - int64_t * __restrict idxi = res->get_ids (i); - memcpy(simi, value, k * sizeof(float)); - memcpy(idxi, labels, k * sizeof(int64_t)); - maxheap_heapify (k, simi, idxi, value, labels, k); - for (size_t i = k; i < all_hash_size; i++) { - if (value[i] < simi[0]) { - maxheap_pop (k, simi, idxi); - maxheap_push (k, simi, idxi, value[i], labels[i]); - } - } - maxheap_reorder (k, simi, idxi); } - delete[] value; - delete[] labels; + } - } else { + for (size_t t = 1; t < thread_max_num; t++) { + // merge hash + for (size_t i = 0; i < nx; i++) { + float * __restrict value_x = value + i * k; + int64_t * __restrict labels_x = labels + i * k; + float *value_x_t = value_x + t * thread_hash_size; + int64_t *labels_x_t = labels_x + t * thread_hash_size; + for (size_t j = 0; j < k; j++) { + if (value_x_t[j] < value_x[0]) { + maxheap_pop (k, value_x, labels_x); + maxheap_push (k, value_x, labels_x, value_x_t[j], labels_x_t[j]); + } + } + } + } + + for (size_t i = 0; i < nx; i++) { + float * value_x = value + i * k; + int64_t * labels_x = labels + i * k; + maxheap_reorder (k, value_x, labels_x); + } + + // copy result + memcpy(res->val, value, thread_hash_size * sizeof(float)); + memcpy(res->ids, labels, thread_hash_size * sizeof(int64_t)); + + delete[] value; + delete[] labels; + + /* + else { size_t check_period = InterruptCallback::get_period_hint (ny * d); check_period *= thread_max_num; @@ -322,6 +354,7 @@ static void knn_L2sqr_sse ( InterruptCallback::check (); } } + */ } diff --git a/core/src/index/thirdparty/faiss/utils/hamming.cpp b/core/src/index/thirdparty/faiss/utils/hamming.cpp index 6fb4c95d56..0760bc0dd3 100644 --- a/core/src/index/thirdparty/faiss/utils/hamming.cpp +++ b/core/src/index/thirdparty/faiss/utils/hamming.cpp @@ -40,7 +40,7 @@ #include static const size_t BLOCKSIZE_QUERY = 8192; - +static const size_t size_1M = 1 * 1024 * 1024; namespace faiss { @@ -278,50 +278,69 @@ void hammings_knn_hc ( ConcurrentBitsetPtr bitset = nullptr) { size_t k = ha->k; - if (init_heap) ha->heapify (); - int thread_max_num = omp_get_max_threads(); - if (ha->nh < 4) { - // omp for n2 - int all_hash_size = thread_max_num * k; + if ((bytes_per_code + k * (sizeof(hamdis_t) + sizeof(int64_t))) * ha->nh < size_1M) { + int thread_max_num = omp_get_max_threads(); + // init hash + size_t thread_hash_size = ha->nh * k; + size_t all_hash_size = thread_hash_size * thread_max_num; hamdis_t *value = new hamdis_t[all_hash_size]; int64_t *labels = new int64_t[all_hash_size]; + for (int i = 0; i < all_hash_size; i++) { + value[i] = 0x7fffffff; + labels[i] = -1; + } + + HammingComputer *hc = new HammingComputer[ha->nh]; + for (size_t i = 0; i < ha->nh; i++) { + hc[i].set(bs1 + i * bytes_per_code, bytes_per_code); + } - for (int i = 0; i < ha->nh; i++) { - HammingComputer hc (bs1 + i * bytes_per_code, bytes_per_code); - // init hash - for (int i = 0; i < all_hash_size; i++) { - value[i] = 0x7fffffff; - } #pragma omp parallel for - for (size_t j = 0; j < n2; j++) { - if(!bitset || !bitset->test(j)) { - const uint8_t * bs2_ = bs2 + j * bytes_per_code; - hamdis_t dis = hc.hamming (bs2_); + for (size_t j = 0; j < n2; j++) { + if(!bitset || !bitset->test(j)) { + int thread_no = omp_get_thread_num(); - int thread_no = omp_get_thread_num(); - hamdis_t * __restrict val_ = value + thread_no * k; - int64_t * __restrict ids_ = labels + thread_no * k; + const uint8_t * bs2_ = bs2 + j * bytes_per_code; + for (size_t i = 0; i < ha->nh; i++) { + hamdis_t dis = hc[i].hamming (bs2_); + + hamdis_t * val_ = value + thread_no * thread_hash_size + i * k; + int64_t * ids_ = labels + thread_no * thread_hash_size + i * k; if (dis < val_[0]) { faiss::maxheap_pop (k, val_, ids_); faiss::maxheap_push (k, val_, ids_, dis, j); } } } + } + + for (size_t t = 1; t < thread_max_num; t++) { // merge hash - hamdis_t * __restrict bh_val_ = ha->val + i * k; - int64_t * __restrict bh_ids_ = ha->ids + i * k; - for (int i = 0; i < all_hash_size; i++) { - if (value[i] < bh_val_[0]) { - faiss::maxheap_pop (k, bh_val_, bh_ids_); - faiss::maxheap_push (k, bh_val_, bh_ids_, value[i], labels[i]); + for (size_t i = 0; i < ha->nh; i++) { + hamdis_t * __restrict value_x = value + i * k; + int64_t * __restrict labels_x = labels + i * k; + hamdis_t *value_x_t = value_x + t * thread_hash_size; + int64_t *labels_x_t = labels_x + t * thread_hash_size; + for (size_t j = 0; j < k; j++) { + if (value_x_t[j] < value_x[0]) { + faiss::maxheap_pop (k, value_x, labels_x); + faiss::maxheap_push (k, value_x, labels_x, value_x_t[j], labels_x_t[j]); + } } } } + + // copy result + memcpy(ha->val, value, thread_hash_size * sizeof(hamdis_t)); + memcpy(ha->ids, labels, thread_hash_size * sizeof(int64_t)); + + delete[] hc; delete[] value; delete[] labels; } else { + if (init_heap) ha->heapify (); const size_t block_size = hamming_batch_size; for (size_t j0 = 0; j0 < n2; j0 += block_size) { const size_t j1 = std::min(j0 + block_size, n2); @@ -426,48 +445,46 @@ void hammings_knn_hc_1 ( const size_t nwords = 1; size_t k = ha->k; - if (init_heap) { ha->heapify (); } int thread_max_num = omp_get_max_threads(); - if (ha->nh < 4) { + if (ha->nh == 1) { // omp for n2 int all_hash_size = thread_max_num * k; hamdis_t *value = new hamdis_t[all_hash_size]; int64_t *labels = new int64_t[all_hash_size]; - for (int i = 0; i < ha->nh; i++) { - // init hash - for (int i = 0; i < all_hash_size; i++) { - value[i] = 0x7fffffff; - } - const uint64_t bs1_ = bs1 [i]; + // init hash + for (int i = 0; i < all_hash_size; i++) { + value[i] = 0x7fffffff; + } + const uint64_t bs1_ = bs1[0]; #pragma omp parallel for - for (size_t j = 0; j < n2; j++) { - if(!bitset || !bitset->test(j)) { - hamdis_t dis = popcount64 (bs1_ ^ bs2[j]); + for (size_t j = 0; j < n2; j++) { + if(!bitset || !bitset->test(j)) { + hamdis_t dis = popcount64 (bs1_ ^ bs2[j]); - int thread_no = omp_get_thread_num(); - hamdis_t * __restrict val_ = value + thread_no * k; - int64_t * __restrict ids_ = labels + thread_no * k; - if (dis < val_[0]) { - faiss::maxheap_pop (k, val_, ids_); - faiss::maxheap_push (k, val_, ids_, dis, j); - } - } - } - // merge hash - hamdis_t * __restrict bh_val_ = ha->val + i * k; - int64_t * __restrict bh_ids_ = ha->ids + i * k; - for (int i = 0; i < all_hash_size; i++) { - if (value[i] < bh_val_[0]) { - faiss::maxheap_pop (k, bh_val_, bh_ids_); - faiss::maxheap_push (k, bh_val_, bh_ids_, value[i], labels[i]); + int thread_no = omp_get_thread_num(); + hamdis_t * __restrict val_ = value + thread_no * k; + int64_t * __restrict ids_ = labels + thread_no * k; + if (dis < val_[0]) { + faiss::maxheap_pop (k, val_, ids_); + faiss::maxheap_push (k, val_, ids_, dis, j); } } } + // merge hash + hamdis_t * __restrict bh_val_ = ha->val; + int64_t * __restrict bh_ids_ = ha->ids; + for (int i = 0; i < all_hash_size; i++) { + if (value[i] < bh_val_[0]) { + faiss::maxheap_pop (k, bh_val_, bh_ids_); + faiss::maxheap_push (k, bh_val_, bh_ids_, value[i], labels[i]); + } + } + delete[] value; delete[] labels; diff --git a/core/src/index/thirdparty/faiss/utils/superstructure-inl.h b/core/src/index/thirdparty/faiss/utils/superstructure-inl.h index 8df39992ae..1ebf8946a5 100644 --- a/core/src/index/thirdparty/faiss/utils/superstructure-inl.h +++ b/core/src/index/thirdparty/faiss/utils/superstructure-inl.h @@ -8,13 +8,13 @@ namespace faiss { SuperstructureComputer8 (const uint8_t *a8, int code_size) { set (a8, code_size); - accu_den = (float)(popcount64 (a0)); } void set (const uint8_t *a8, int code_size) { assert (code_size == 8); const uint64_t *a = (uint64_t *)a8; a0 = a[0]; + accu_den = (float)(popcount64 (a0)); } inline float compute (const uint8_t *b8) const { @@ -35,13 +35,13 @@ namespace faiss { SuperstructureComputer16 (const uint8_t *a8, int code_size) { set (a8, code_size); - accu_den = (float)(popcount64 (a0) + popcount64 (a1)); } void set (const uint8_t *a8, int code_size) { assert (code_size == 16); const uint64_t *a = (uint64_t *)a8; a0 = a[0]; a1 = a[1]; + accu_den = (float)(popcount64 (a0) + popcount64 (a1)); } inline float compute (const uint8_t *b8) const {