From b04b3e942ce1404b34ee3311fb3487e2fa6f635f Mon Sep 17 00:00:00 2001 From: "shengjun.li" <49774184+shengjun1985@users.noreply.github.com> Date: Sat, 11 Apr 2020 12:56:32 +0800 Subject: [PATCH] #1897 add heap_swap_top (#1898) * add heap_swap_top Signed-off-by: shengjun.li * fix wrong code Signed-off-by: shengjun.li --- CHANGELOG.md | 4 +- .../index/thirdparty/faiss/IndexBinaryIVF.cpp | 6 +- .../index/thirdparty/faiss/IndexIVFFlat.cpp | 3 +- .../src/index/thirdparty/faiss/IndexIVFPQ.cpp | 3 +- .../index/thirdparty/faiss/IndexIVFPQR.cpp | 3 +- .../thirdparty/faiss/IndexIVFSpectralHash.cpp | 3 +- core/src/index/thirdparty/faiss/IndexPQ.cpp | 3 +- .../faiss/impl/ProductQuantizer.cpp | 15 +-- .../thirdparty/faiss/impl/ScalarQuantizer.cpp | 6 +- .../thirdparty/faiss/utils/BinaryDistance.cpp | 35 ++++--- .../src/index/thirdparty/faiss/utils/Heap.cpp | 6 +- core/src/index/thirdparty/faiss/utils/Heap.h | 56 ++++++++++- .../thirdparty/faiss/utils/distances.cpp | 93 +++++++------------ .../index/thirdparty/faiss/utils/hamming.cpp | 61 ++++++------ 14 files changed, 148 insertions(+), 149 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 71e4efbb93..da80f9eec1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,11 +10,10 @@ Please mark all change in change log and use the issue from GitHub - \#1789 Fix multi-client search cause server crash - \#1832 Fix crash in tracing module - \#1873 Fix index file serialize to incorrect path -- \#1881 Fix Annoy index search failure +- \#1881 Fix bad alloc when index files lost ## Feature - \#261 Integrate ANNOY into Milvus -- \#1603 BinaryFlat add 2 Metric: Substructure and Superstructure - \#1655 GPU index support delete vectors - \#1660 IVF PQ CPU support deleted vectors searching - \#1661 HNSW support deleted vectors searching @@ -29,6 +28,7 @@ Please mark all change in change log and use the issue from GitHub - \#1882 Add index annoy into http module - \#1885 Optimize knowhere unittest - \#1886 Refactor log on search and 
insert request +- \#1897 Heap pop and push can be realized by heap_swap_top ## Task diff --git a/core/src/index/thirdparty/faiss/IndexBinaryIVF.cpp b/core/src/index/thirdparty/faiss/IndexBinaryIVF.cpp index bf97ba0f18..f853933877 100644 --- a/core/src/index/thirdparty/faiss/IndexBinaryIVF.cpp +++ b/core/src/index/thirdparty/faiss/IndexBinaryIVF.cpp @@ -420,9 +420,8 @@ struct IVFBinaryScannerL2: BinaryInvertedListScanner { uint32_t dis = hc.hamming (codes); if (dis < simi[0]) { - heap_pop (k, simi, idxi); idx_t id = store_pairs ? (list_no << 32 | j) : ids[j]; - heap_push (k, simi, idxi, dis, id); + heap_swap_top (k, simi, idxi, dis, id); nup++; } } @@ -470,9 +469,8 @@ struct IVFBinaryScannerJaccard: BinaryInvertedListScanner { float dis = hc.compute (codes); if (dis < psimi[0]) { - heap_pop (k, psimi, idxi); idx_t id = store_pairs ? (list_no << 32 | j) : ids[j]; - heap_push (k, psimi, idxi, dis, id); + heap_swap_top (k, psimi, idxi, dis, id); nup++; } } diff --git a/core/src/index/thirdparty/faiss/IndexIVFFlat.cpp b/core/src/index/thirdparty/faiss/IndexIVFFlat.cpp index 65bb73d135..4e531be758 100644 --- a/core/src/index/thirdparty/faiss/IndexIVFFlat.cpp +++ b/core/src/index/thirdparty/faiss/IndexIVFFlat.cpp @@ -159,9 +159,8 @@ struct IVFFlatScanner: InvertedListScanner { float dis = metric == METRIC_INNER_PRODUCT ? fvec_inner_product (xi, yj, d) : fvec_L2sqr (xi, yj, d); if (C::cmp (simi[0], dis)) { - heap_pop (k, simi, idxi); int64_t id = store_pairs ? (list_no << 32 | j) : ids[j]; - heap_push (k, simi, idxi, dis, id); + heap_swap_top (k, simi, idxi, dis, id); nup++; } } diff --git a/core/src/index/thirdparty/faiss/IndexIVFPQ.cpp b/core/src/index/thirdparty/faiss/IndexIVFPQ.cpp index a50006b689..6b47cde4da 100644 --- a/core/src/index/thirdparty/faiss/IndexIVFPQ.cpp +++ b/core/src/index/thirdparty/faiss/IndexIVFPQ.cpp @@ -805,8 +805,7 @@ struct KnnSearchResults { idx_t id = ids ? 
ids[j] : (key << 32 | j); if (bitset != nullptr && bitset->test((faiss::ConcurrentBitset::id_type_t)id)) return; - heap_pop (k, heap_sim, heap_ids); - heap_push (k, heap_sim, heap_ids, dis, id); + heap_swap_top (k, heap_sim, heap_ids, dis, id); nup++; } } diff --git a/core/src/index/thirdparty/faiss/IndexIVFPQR.cpp b/core/src/index/thirdparty/faiss/IndexIVFPQR.cpp index fe832bbd36..b94e16eac0 100644 --- a/core/src/index/thirdparty/faiss/IndexIVFPQR.cpp +++ b/core/src/index/thirdparty/faiss/IndexIVFPQR.cpp @@ -171,9 +171,8 @@ void IndexIVFPQR::search_preassigned (idx_t n, const float *x, idx_t k, float dis = fvec_L2sqr (residual_1, residual_2, d); if (dis < heap_sim[0]) { - maxheap_pop (k, heap_sim, heap_ids); idx_t id_or_pair = store_pairs ? sl : id; - maxheap_push (k, heap_sim, heap_ids, dis, id_or_pair); + maxheap_swap_top (k, heap_sim, heap_ids, dis, id_or_pair); } n_refine ++; } diff --git a/core/src/index/thirdparty/faiss/IndexIVFSpectralHash.cpp b/core/src/index/thirdparty/faiss/IndexIVFSpectralHash.cpp index 534826517b..ddea7f6d87 100644 --- a/core/src/index/thirdparty/faiss/IndexIVFSpectralHash.cpp +++ b/core/src/index/thirdparty/faiss/IndexIVFSpectralHash.cpp @@ -270,9 +270,8 @@ struct IVFScanner: InvertedListScanner { float dis = hc.hamming (codes); if (dis < simi [0]) { - maxheap_pop (k, simi, idxi); int64_t id = store_pairs ? 
(list_no << 32 | j) : ids[j]; - maxheap_push (k, simi, idxi, dis, id); + maxheap_swap_top (k, simi, idxi, dis, id); nup++; } } diff --git a/core/src/index/thirdparty/faiss/IndexPQ.cpp b/core/src/index/thirdparty/faiss/IndexPQ.cpp index defaff1b04..49ba8ce675 100644 --- a/core/src/index/thirdparty/faiss/IndexPQ.cpp +++ b/core/src/index/thirdparty/faiss/IndexPQ.cpp @@ -330,8 +330,7 @@ static size_t polysemous_inner_loop ( } if (dis < heap_dis[0]) { - maxheap_pop (k, heap_dis, heap_ids); - maxheap_push (k, heap_dis, heap_ids, dis, bi); + maxheap_swap_top (k, heap_dis, heap_ids, dis, bi); } } b_code += code_size; diff --git a/core/src/index/thirdparty/faiss/impl/ProductQuantizer.cpp b/core/src/index/thirdparty/faiss/impl/ProductQuantizer.cpp index bbd143611e..379bb78822 100644 --- a/core/src/index/thirdparty/faiss/impl/ProductQuantizer.cpp +++ b/core/src/index/thirdparty/faiss/impl/ProductQuantizer.cpp @@ -63,8 +63,7 @@ void pq_estimators_from_tables_Mmul4 (int M, const CT * codes, } if (C::cmp (heap_dis[0], dis)) { - heap_pop (k, heap_dis, heap_ids); - heap_push (k, heap_dis, heap_ids, dis, j); + heap_swap_top (k, heap_dis, heap_ids, dis, j); } } } @@ -89,8 +88,7 @@ void pq_estimators_from_tables_M4 (const CT * codes, dis += dt[*codes++]; if (C::cmp (heap_dis[0], dis)) { - heap_pop (k, heap_dis, heap_ids); - heap_push (k, heap_dis, heap_ids, dis, j); + heap_swap_top (k, heap_dis, heap_ids, dis, j); } } } @@ -132,8 +130,7 @@ static inline void pq_estimators_from_tables (const ProductQuantizer& pq, dt += ksub; } if (C::cmp (heap_dis[0], dis)) { - heap_pop (k, heap_dis, heap_ids); - heap_push (k, heap_dis, heap_ids, dis, j); + heap_swap_top (k, heap_dis, heap_ids, dis, j); } } } @@ -163,8 +160,7 @@ static inline void pq_estimators_from_tables_generic(const ProductQuantizer& pq, } if (C::cmp(heap_dis[0], dis)) { - heap_pop(k, heap_dis, heap_ids); - heap_push(k, heap_dis, heap_ids, dis, j); + heap_swap_top(k, heap_dis, heap_ids, dis, j); } } } @@ -747,8 +743,7 @@ void 
ProductQuantizer::search_sdc (const uint8_t * qcodes, tab += ksub * ksub; } if (dis < heap_dis[0]) { - maxheap_pop (k, heap_dis, heap_ids); - maxheap_push (k, heap_dis, heap_ids, dis, j); + maxheap_swap_top (k, heap_dis, heap_ids, dis, j); } bcode += code_size; } diff --git a/core/src/index/thirdparty/faiss/impl/ScalarQuantizer.cpp b/core/src/index/thirdparty/faiss/impl/ScalarQuantizer.cpp index f03604a893..53c279bc89 100644 --- a/core/src/index/thirdparty/faiss/impl/ScalarQuantizer.cpp +++ b/core/src/index/thirdparty/faiss/impl/ScalarQuantizer.cpp @@ -231,9 +231,8 @@ struct IVFSQScannerIP: InvertedListScanner { float accu = accu0 + dc.query_to_code (codes); if (accu > simi [0]) { - minheap_pop (k, simi, idxi); int64_t id = store_pairs ? (list_no << 32 | j) : ids[j]; - minheap_push (k, simi, idxi, accu, id); + minheap_swap_top (k, simi, idxi, accu, id); nup++; } } @@ -319,9 +318,8 @@ struct IVFSQScannerL2: InvertedListScanner { float dis = dc.query_to_code (codes); if (dis < simi [0]) { - maxheap_pop (k, simi, idxi); int64_t id = store_pairs ? 
(list_no << 32 | j) : ids[j]; - maxheap_push (k, simi, idxi, dis, id); + maxheap_swap_top (k, simi, idxi, dis, id); nup++; } } diff --git a/core/src/index/thirdparty/faiss/utils/BinaryDistance.cpp b/core/src/index/thirdparty/faiss/utils/BinaryDistance.cpp index 5dabcf026a..d6ebaa44f0 100644 --- a/core/src/index/thirdparty/faiss/utils/BinaryDistance.cpp +++ b/core/src/index/thirdparty/faiss/utils/BinaryDistance.cpp @@ -34,12 +34,12 @@ void binary_distence_knn_hc( if ((bytes_per_code + k * (sizeof(float) + sizeof(int64_t))) * ha->nh < size_1M) { int thread_max_num = omp_get_max_threads(); - // init hash - size_t thread_hash_size = ha->nh * k; - size_t all_hash_size = thread_hash_size * thread_max_num; - float *value = new float[all_hash_size]; - int64_t *labels = new int64_t[all_hash_size]; - for (int i = 0; i < all_hash_size; i++) { + // init heap + size_t thread_heap_size = ha->nh * k; + size_t all_heap_size = thread_heap_size * thread_max_num; + float *value = new float[all_heap_size]; + int64_t *labels = new int64_t[all_heap_size]; + for (int i = 0; i < all_heap_size; i++) { value[i] = 1.0 / 0.0; labels[i] = -1; } @@ -58,35 +58,33 @@ void binary_distence_knn_hc( for (size_t i = 0; i < ha->nh; i++) { tadis_t dis = hc[i].compute (bs2_); - float * val_ = value + thread_no * thread_hash_size + i * k; - int64_t * ids_ = labels + thread_no * thread_hash_size + i * k; + float * val_ = value + thread_no * thread_heap_size + i * k; + int64_t * ids_ = labels + thread_no * thread_heap_size + i * k; if (dis < val_[0]) { - faiss::maxheap_pop (k, val_, ids_); - faiss::maxheap_push (k, val_, ids_, dis, j); + faiss::maxheap_swap_top (k, val_, ids_, dis, j); } } } } for (size_t t = 1; t < thread_max_num; t++) { - // merge hash + // merge heap for (size_t i = 0; i < ha->nh; i++) { float * __restrict value_x = value + i * k; int64_t * __restrict labels_x = labels + i * k; - float *value_x_t = value_x + t * thread_hash_size; - int64_t *labels_x_t = labels_x + t * thread_hash_size; + 
float *value_x_t = value_x + t * thread_heap_size; + int64_t *labels_x_t = labels_x + t * thread_heap_size; for (size_t j = 0; j < k; j++) { if (value_x_t[j] < value_x[0]) { - faiss::maxheap_pop (k, value_x, labels_x); - faiss::maxheap_push (k, value_x, labels_x, value_x_t[j], labels_x_t[j]); + faiss::maxheap_swap_top (k, value_x, labels_x, value_x_t[j], labels_x_t[j]); } } } } // copy result - memcpy(ha->val, value, thread_hash_size * sizeof(float)); - memcpy(ha->ids, labels, thread_hash_size * sizeof(int64_t)); + memcpy(ha->val, value, thread_heap_size * sizeof(float)); + memcpy(ha->ids, labels, thread_heap_size * sizeof(int64_t)); delete[] hc; delete[] value; @@ -111,8 +109,7 @@ void binary_distence_knn_hc( if(!bitset || !bitset->test(j)){ dis = hc.compute (bs2_); if (dis < bh_val_[0]) { - faiss::maxheap_pop (k, bh_val_, bh_ids_); - faiss::maxheap_push (k, bh_val_, bh_ids_, dis, j); + faiss::maxheap_swap_top (k, bh_val_, bh_ids_, dis, j); } } } diff --git a/core/src/index/thirdparty/faiss/utils/Heap.cpp b/core/src/index/thirdparty/faiss/utils/Heap.cpp index 4a5de5ad36..0b7cfab547 100644 --- a/core/src/index/thirdparty/faiss/utils/Heap.cpp +++ b/core/src/index/thirdparty/faiss/utils/Heap.cpp @@ -46,8 +46,7 @@ void HeapArray::addn (size_t nj, const T *vin, TI j0, for (size_t j = 0; j < nj; j++) { T ip = ip_line [j]; if (C::cmp(simi[0], ip)) { - heap_pop (k, simi, idxi); - heap_push (k, simi, idxi, ip, j + j0); + heap_swap_top (k, simi, idxi, ip, j + j0); } } } @@ -74,8 +73,7 @@ void HeapArray::addn_with_ids ( for (size_t j = 0; j < nj; j++) { T ip = ip_line [j]; if (C::cmp(simi[0], ip)) { - heap_pop (k, simi, idxi); - heap_push (k, simi, idxi, ip, id_line [j]); + heap_swap_top (k, simi, idxi, ip, id_line [j]); } } } diff --git a/core/src/index/thirdparty/faiss/utils/Heap.h b/core/src/index/thirdparty/faiss/utils/Heap.h index e691c36c7f..9962cbc112 100644 --- a/core/src/index/thirdparty/faiss/utils/Heap.h +++ b/core/src/index/thirdparty/faiss/utils/Heap.h @@ -83,6 
+83,42 @@ struct CMax { * Basic heap ops: push and pop *******************************************************************/ +/** Replaces the top element of the heap defined by bh_val[0..k-1] and + * bh_ids[0..k-1] with (val, ids) in a single sift-down (pop + push in one pass). + */ +template inline +void heap_swap_top (size_t k, + typename C::T * bh_val, typename C::TI * bh_ids, + typename C::T val, typename C::TI ids) +{ + bh_val--; /* Use 1-based indexing for easier node->child translation */ + bh_ids--; + size_t i = 1, i1, i2; + while (1) { + i1 = i << 1; + i2 = i1 + 1; + if (i1 > k) + break; + if (i2 == k + 1 || C::cmp(bh_val[i1], bh_val[i2])) { + if (C::cmp(val, bh_val[i1])) + break; + bh_val[i] = bh_val[i1]; + bh_ids[i] = bh_ids[i1]; + i = i1; + } + else { + if (C::cmp(val, bh_val[i2])) + break; + bh_val[i] = bh_val[i2]; + bh_ids[i] = bh_ids[i2]; + i = i2; + } + } + bh_val[i] = val; + bh_ids[i] = ids; +} + + /** Pops the top element from the heap defined by bh_val[0..k-1] and * bh_ids[0..k-1]. on output the element at k-1 is undefined. 
*/ @@ -146,6 +182,13 @@ void heap_push (size_t k, /* Partial instanciation for heaps with TI = int64_t */ +template inline +void minheap_swap_top (size_t k, T * bh_val, int64_t * bh_ids, T val, int64_t ids) +{ + heap_swap_top > (k, bh_val, bh_ids, val, ids); +} + + template inline void minheap_pop (size_t k, T * bh_val, int64_t * bh_ids) { @@ -160,6 +203,13 @@ void minheap_push (size_t k, T * bh_val, int64_t * bh_ids, T val, int64_t ids) } +template inline +void maxheap_swap_top (size_t k, T * bh_val, int64_t * bh_ids, T val, int64_t ids) +{ + heap_swap_top > (k, bh_val, bh_ids, val, ids); +} + + template inline void maxheap_pop (size_t k, T * bh_val, int64_t * bh_ids) { @@ -251,15 +301,13 @@ void heap_addn (size_t k, if (ids) for (i = 0; i < n; i++) { if (C::cmp (bh_val[0], x[i])) { - heap_pop (k, bh_val, bh_ids); - heap_push (k, bh_val, bh_ids, x[i], ids[i]); + heap_swap_top (k, bh_val, bh_ids, x[i], ids[i]); } } else for (i = 0; i < n; i++) { if (C::cmp (bh_val[0], x[i])) { - heap_pop (k, bh_val, bh_ids); - heap_push (k, bh_val, bh_ids, x[i], i); + heap_swap_top (k, bh_val, bh_ids, x[i], i); } } } diff --git a/core/src/index/thirdparty/faiss/utils/distances.cpp b/core/src/index/thirdparty/faiss/utils/distances.cpp index ac0e3b8321..cf902c7ce2 100644 --- a/core/src/index/thirdparty/faiss/utils/distances.cpp +++ b/core/src/index/thirdparty/faiss/utils/distances.cpp @@ -155,13 +155,13 @@ static void knn_inner_product_sse (const float * x, size_t thread_max_num = omp_get_max_threads(); - size_t thread_hash_size = nx * k; - size_t all_hash_size = thread_hash_size * thread_max_num; - float *value = new float[all_hash_size]; - int64_t *labels = new int64_t[all_hash_size]; + size_t thread_heap_size = nx * k; + size_t all_heap_size = thread_heap_size * thread_max_num; + float *value = new float[all_heap_size]; + int64_t *labels = new int64_t[all_heap_size]; - // init hash - for (size_t i = 0; i < all_hash_size; i++) { + // init heap + for (size_t i = 0; i < 
all_heap_size; i++) { value[i] = -1.0 / 0.0; labels[i] = -1; } @@ -175,27 +175,25 @@ static void knn_inner_product_sse (const float * x, const float *x_i = x + i * d; float ip = fvec_inner_product (x_i, y_j, d); - float * val_ = value + thread_no * thread_hash_size + i * k; - int64_t * ids_ = labels + thread_no * thread_hash_size + i * k; + float * val_ = value + thread_no * thread_heap_size + i * k; + int64_t * ids_ = labels + thread_no * thread_heap_size + i * k; if (ip > val_[0]) { - minheap_pop (k, val_, ids_); - minheap_push (k, val_, ids_, ip, j); + minheap_swap_top (k, val_, ids_, ip, j); } } } } for (size_t t = 1; t < thread_max_num; t++) { - // merge hash + // merge heap for (size_t i = 0; i < nx; i++) { float * __restrict value_x = value + i * k; int64_t * __restrict labels_x = labels + i * k; - float *value_x_t = value_x + t * thread_hash_size; - int64_t *labels_x_t = labels_x + t * thread_hash_size; + float *value_x_t = value_x + t * thread_heap_size; + int64_t *labels_x_t = labels_x + t * thread_heap_size; for (size_t j = 0; j < k; j++) { if (value_x_t[j] > value_x[0]) { - minheap_pop (k, value_x, labels_x); - minheap_push (k, value_x, labels_x, value_x_t[j], labels_x_t[j]); + minheap_swap_top (k, value_x, labels_x, value_x_t[j], labels_x_t[j]); } } } @@ -208,8 +206,8 @@ static void knn_inner_product_sse (const float * x, } // copy result - memcpy(res->val, value, thread_hash_size * sizeof(float)); - memcpy(res->ids, labels, thread_hash_size * sizeof(int64_t)); + memcpy(res->val, value, thread_heap_size * sizeof(float)); + memcpy(res->ids, labels, thread_heap_size * sizeof(int64_t)); delete[] value; delete[] labels; @@ -262,13 +260,13 @@ static void knn_L2sqr_sse ( size_t thread_max_num = omp_get_max_threads(); - size_t thread_hash_size = nx * k; - size_t all_hash_size = thread_hash_size * thread_max_num; - float *value = new float[all_hash_size]; - int64_t *labels = new int64_t[all_hash_size]; + size_t thread_heap_size = nx * k; + size_t all_heap_size 
= thread_heap_size * thread_max_num; + float *value = new float[all_heap_size]; + int64_t *labels = new int64_t[all_heap_size]; - // init hash - for (size_t i = 0; i < all_hash_size; i++) { + // init heap + for (size_t i = 0; i < all_heap_size; i++) { value[i] = 1.0 / 0.0; labels[i] = -1; } @@ -282,27 +280,25 @@ static void knn_L2sqr_sse ( const float *x_i = x + i * d; float disij = fvec_L2sqr (x_i, y_j, d); - float * val_ = value + thread_no * thread_hash_size + i * k; - int64_t * ids_ = labels + thread_no * thread_hash_size + i * k; + float * val_ = value + thread_no * thread_heap_size + i * k; + int64_t * ids_ = labels + thread_no * thread_heap_size + i * k; if (disij < val_[0]) { - maxheap_pop (k, val_, ids_); - maxheap_push (k, val_, ids_, disij, j); + maxheap_swap_top (k, val_, ids_, disij, j); } } } } for (size_t t = 1; t < thread_max_num; t++) { - // merge hash + // merge heap for (size_t i = 0; i < nx; i++) { float * __restrict value_x = value + i * k; int64_t * __restrict labels_x = labels + i * k; - float *value_x_t = value_x + t * thread_hash_size; - int64_t *labels_x_t = labels_x + t * thread_hash_size; + float *value_x_t = value_x + t * thread_heap_size; + int64_t *labels_x_t = labels_x + t * thread_heap_size; for (size_t j = 0; j < k; j++) { if (value_x_t[j] < value_x[0]) { - maxheap_pop (k, value_x, labels_x); - maxheap_push (k, value_x, labels_x, value_x_t[j], labels_x_t[j]); + maxheap_swap_top (k, value_x, labels_x, value_x_t[j], labels_x_t[j]); } } } @@ -315,8 +311,8 @@ static void knn_L2sqr_sse ( } // copy result - memcpy(res->val, value, thread_hash_size * sizeof(float)); - memcpy(res->ids, labels, thread_hash_size * sizeof(int64_t)); + memcpy(res->val, value, thread_heap_size * sizeof(float)); + memcpy(res->ids, labels, thread_heap_size * sizeof(int64_t)); delete[] value; delete[] labels; @@ -408,8 +404,7 @@ static void knn_inner_product_blas ( float dis = *ip_line; if(dis > simi[0]){ - minheap_pop(k, simi, idxi); - minheap_push(k, simi, idxi, 
dis, j); + minheap_swap_top(k, simi, idxi, dis, j); } } ip_line++; @@ -486,8 +481,7 @@ static void knn_L2sqr_blas (const float * x, dis = corr (dis, i, j); if (dis < simi[0]) { - maxheap_pop (k, simi, idxi); - maxheap_push (k, simi, idxi, dis, j); + maxheap_swap_top (k, simi, idxi, dis, j); } } ip_line++; @@ -563,8 +557,7 @@ static void knn_jaccard_blas (const float * x, dis = corr (dis, i, j); if (dis < simi[0]) { - maxheap_pop (k, simi, idxi); - maxheap_push (k, simi, idxi, dis, j); + maxheap_swap_top (k, simi, idxi, dis, j); } } ip_line++; @@ -638,20 +631,6 @@ void knn_jaccard (const float * x, } } -void knn_jaccard (const float * x, - const float * y, - size_t d, size_t nx, size_t ny, - float_maxheap_array_t * res) -{ - if (d % 4 == 0 && nx < distance_compute_blas_threshold) { -// knn_jaccard_sse (x, y, d, nx, ny, res); - printf("sse_not implemented!\n"); - } else { - NopDistanceCorrection nop; - knn_jaccard_blas (x, y, d, nx, ny, res, nop); - } -} - struct BaseShiftDistanceCorrection { const float *base_shift; float operator()(float dis, size_t /*qno*/, size_t bno) const { @@ -773,8 +752,7 @@ void knn_inner_products_by_idx (const float * x, float ip = fvec_inner_product (x_, y + d * idsi[j], d); if (ip > simi[0]) { - minheap_pop (k, simi, idxi); - minheap_push (k, simi, idxi, ip, idsi[j]); + minheap_swap_top (k, simi, idxi, ip, idsi[j]); } } minheap_reorder (k, simi, idxi); @@ -801,8 +779,7 @@ void knn_L2sqr_by_idx (const float * x, float disij = fvec_L2sqr (x_, y + d * idsi[j], d); if (disij < simi[0]) { - maxheap_pop (k, simi, idxi); - maxheap_push (k, simi, idxi, disij, idsi[j]); + maxheap_swap_top (k, simi, idxi, disij, idsi[j]); } } maxheap_reorder (res->k, simi, idxi); diff --git a/core/src/index/thirdparty/faiss/utils/hamming.cpp b/core/src/index/thirdparty/faiss/utils/hamming.cpp index 0760bc0dd3..e6dca07950 100644 --- a/core/src/index/thirdparty/faiss/utils/hamming.cpp +++ b/core/src/index/thirdparty/faiss/utils/hamming.cpp @@ -281,12 +281,12 @@ void 
hammings_knn_hc ( if ((bytes_per_code + k * (sizeof(hamdis_t) + sizeof(int64_t))) * ha->nh < size_1M) { int thread_max_num = omp_get_max_threads(); - // init hash - size_t thread_hash_size = ha->nh * k; - size_t all_hash_size = thread_hash_size * thread_max_num; - hamdis_t *value = new hamdis_t[all_hash_size]; - int64_t *labels = new int64_t[all_hash_size]; - for (int i = 0; i < all_hash_size; i++) { + // init heap + size_t thread_heap_size = ha->nh * k; + size_t all_heap_size = thread_heap_size * thread_max_num; + hamdis_t *value = new hamdis_t[all_heap_size]; + int64_t *labels = new int64_t[all_heap_size]; + for (int i = 0; i < all_heap_size; i++) { value[i] = 0x7fffffff; labels[i] = -1; } @@ -305,35 +305,33 @@ void hammings_knn_hc ( for (size_t i = 0; i < ha->nh; i++) { hamdis_t dis = hc[i].hamming (bs2_); - hamdis_t * val_ = value + thread_no * thread_hash_size + i * k; - int64_t * ids_ = labels + thread_no * thread_hash_size + i * k; + hamdis_t * val_ = value + thread_no * thread_heap_size + i * k; + int64_t * ids_ = labels + thread_no * thread_heap_size + i * k; if (dis < val_[0]) { - faiss::maxheap_pop (k, val_, ids_); - faiss::maxheap_push (k, val_, ids_, dis, j); + faiss::maxheap_swap_top (k, val_, ids_, dis, j); } } } } for (size_t t = 1; t < thread_max_num; t++) { - // merge hash + // merge heap for (size_t i = 0; i < ha->nh; i++) { hamdis_t * __restrict value_x = value + i * k; int64_t * __restrict labels_x = labels + i * k; - hamdis_t *value_x_t = value_x + t * thread_hash_size; - int64_t *labels_x_t = labels_x + t * thread_hash_size; + hamdis_t *value_x_t = value_x + t * thread_heap_size; + int64_t *labels_x_t = labels_x + t * thread_heap_size; for (size_t j = 0; j < k; j++) { if (value_x_t[j] < value_x[0]) { - faiss::maxheap_pop (k, value_x, labels_x); - faiss::maxheap_push (k, value_x, labels_x, value_x_t[j], labels_x_t[j]); + faiss::maxheap_swap_top (k, value_x, labels_x, value_x_t[j], labels_x_t[j]); } } } } // copy result - memcpy(ha->val, value, 
thread_hash_size * sizeof(hamdis_t)); - memcpy(ha->ids, labels, thread_hash_size * sizeof(int64_t)); + memcpy(ha->val, value, thread_heap_size * sizeof(hamdis_t)); + memcpy(ha->ids, labels, thread_heap_size * sizeof(int64_t)); delete[] hc; delete[] value; @@ -357,8 +355,7 @@ void hammings_knn_hc ( if(!bitset || !bitset->test(j)){ dis = hc.hamming (bs2_); if (dis < bh_val_[0]) { - faiss::maxheap_pop (k, bh_val_, bh_ids_); - faiss::maxheap_push (k, bh_val_, bh_ids_, dis, j); + faiss::maxheap_swap_top (k, bh_val_, bh_ids_, dis, j); } } } @@ -452,12 +449,12 @@ void hammings_knn_hc_1 ( int thread_max_num = omp_get_max_threads(); if (ha->nh == 1) { // omp for n2 - int all_hash_size = thread_max_num * k; - hamdis_t *value = new hamdis_t[all_hash_size]; - int64_t *labels = new int64_t[all_hash_size]; + int all_heap_size = thread_max_num * k; + hamdis_t *value = new hamdis_t[all_heap_size]; + int64_t *labels = new int64_t[all_heap_size]; - // init hash - for (int i = 0; i < all_hash_size; i++) { + // init heap + for (int i = 0; i < all_heap_size; i++) { value[i] = 0x7fffffff; } const uint64_t bs1_ = bs1[0]; @@ -470,18 +467,16 @@ void hammings_knn_hc_1 ( hamdis_t * __restrict val_ = value + thread_no * k; int64_t * __restrict ids_ = labels + thread_no * k; if (dis < val_[0]) { - faiss::maxheap_pop (k, val_, ids_); - faiss::maxheap_push (k, val_, ids_, dis, j); + faiss::maxheap_swap_top (k, val_, ids_, dis, j); } } } - // merge hash + // merge heap hamdis_t * __restrict bh_val_ = ha->val; int64_t * __restrict bh_ids_ = ha->ids; - for (int i = 0; i < all_hash_size; i++) { + for (int i = 0; i < all_heap_size; i++) { if (value[i] < bh_val_[0]) { - faiss::maxheap_pop (k, bh_val_, bh_ids_); - faiss::maxheap_push (k, bh_val_, bh_ids_, value[i], labels[i]); + faiss::maxheap_swap_top (k, bh_val_, bh_ids_, value[i], labels[i]); } } @@ -502,8 +497,7 @@ void hammings_knn_hc_1 ( if(!bitset || !bitset->test(j)){ dis = popcount64 (bs1_ ^ *bs2_); if (dis < bh_val_0) { - faiss::maxheap_pop 
(k, bh_val_, bh_ids_); - faiss::maxheap_push (k, bh_val_, bh_ids_, dis, j); + faiss::maxheap_swap_top (k, bh_val_, bh_ids_, dis, j); bh_val_0 = bh_val_[0]; } } @@ -849,8 +843,7 @@ static void hamming_dis_inner_loop ( int ndiff = hc.hamming (cb); cb += code_size; if (ndiff < bh_val_[0]) { - maxheap_pop (k, bh_val_, bh_ids_); - maxheap_push (k, bh_val_, bh_ids_, ndiff, j); + maxheap_swap_top (k, bh_val_, bh_ids_, ndiff, j); } } }