From a2875f9d957b24b377d6e1cc6286308eeae71f1c Mon Sep 17 00:00:00 2001 From: "shengjun.li" Date: Fri, 23 Apr 2021 18:03:49 +0800 Subject: [PATCH] Update knowhere (#5006) Import performance of ivf::train and hnsw, and fix bugs Signed-off-by: shengjun.li --- .../index/vector_index/ConfAdapter.cpp | 7 +- .../knowhere/index/vector_index/IndexHNSW.cpp | 48 +- .../knowhere/index/vector_index/IndexHNSW.h | 1 - .../vector_index/adapter/SptagAdapter.cpp | 4 +- .../thirdparty/faiss/utils/distances.cpp | 30 +- .../core/src/index/thirdparty/hnswlib/LICENSE | 201 +++ .../src/index/thirdparty/hnswlib/hnswalg.h | 145 +- .../src/index/thirdparty/hnswlib/hnswalg_nm.h | 1227 ----------------- .../src/index/thirdparty/hnswlib/hnswlib.h | 8 +- .../src/index/thirdparty/hnswlib/hnswlib_nm.h | 99 -- .../core/src/index/unittest/test_hnsw.cpp | 14 + 11 files changed, 324 insertions(+), 1460 deletions(-) create mode 100644 internal/core/src/index/thirdparty/hnswlib/LICENSE delete mode 100644 internal/core/src/index/thirdparty/hnswlib/hnswalg_nm.h delete mode 100644 internal/core/src/index/thirdparty/hnswlib/hnswlib_nm.h diff --git a/internal/core/src/index/knowhere/knowhere/index/vector_index/ConfAdapter.cpp b/internal/core/src/index/knowhere/knowhere/index/vector_index/ConfAdapter.cpp index 8befcc51ed..a2a77a679c 100644 --- a/internal/core/src/index/knowhere/knowhere/index/vector_index/ConfAdapter.cpp +++ b/internal/core/src/index/knowhere/knowhere/index/vector_index/ConfAdapter.cpp @@ -27,6 +27,7 @@ namespace knowhere { static const int64_t MIN_NBITS = 1; static const int64_t MAX_NBITS = 16; +static const int64_t DEFAULT_NBITS = 8; static const int64_t MIN_NLIST = 1; static const int64_t MAX_NLIST = 65536; static const int64_t MIN_NPROBE = 1; @@ -91,7 +92,7 @@ ConfAdapter::CheckSearch(Config& oricfg, const IndexType type, const IndexMode m int64_t MatchNlist(int64_t size, int64_t nlist) { - const int64_t MIN_POINTS_PER_CENTROID = 40; + const int64_t MIN_POINTS_PER_CENTROID = 39; if (nlist * MIN_POINTS_PER_CENTROID > size) { // nlist is too large, adjust to a proper value @@ -146,9 +147,7 @@ IVFConfAdapter::CheckSearch(Config& oricfg, const IndexType type, const IndexMod bool IVFSQConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) { - const int64_t DEFAULT_NBITS = 8; oricfg[knowhere::IndexParams::nbits] = DEFAULT_NBITS; - return IVFConfAdapter::CheckTrain(oricfg, mode); } @@ -161,7 +160,7 @@ IVFPQConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) { CheckIntByRange(knowhere::IndexParams::nbits, MIN_NBITS, MAX_NBITS); auto rows = oricfg[knowhere::meta::ROWS].get(); - auto nbits = oricfg[knowhere::IndexParams::nbits].get(); + auto nbits = oricfg.count(IndexParams::nbits) ? 
oricfg[IndexParams::nbits].get() : DEFAULT_NBITS; oricfg[knowhere::IndexParams::nbits] = MatchNbits(rows, nbits); auto m = oricfg[knowhere::IndexParams::m].get(); diff --git a/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexHNSW.cpp b/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexHNSW.cpp index 8e32db9370..881b3b4bfd 100644 --- a/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexHNSW.cpp +++ b/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexHNSW.cpp @@ -83,8 +83,6 @@ IndexHNSW::Load(const BinarySet& index_binary) { } // LOG_KNOWHERE_DEBUG_ << "IndexHNSW::Load finished, show statistics:"; // LOG_KNOWHERE_DEBUG_ << hnsw_stats->ToString(); - - normalize = index_->metric_type_ == 1; // 1 == InnerProduct } catch (std::exception& e) { KNOWHERE_THROW_MSG(e.what()); } @@ -102,7 +100,6 @@ IndexHNSW::Train(const DatasetPtr& dataset_ptr, const Config& config) { space = new hnswlib::L2Space(dim); } else if (metric_type == Metric::IP) { space = new hnswlib::InnerProductSpace(dim); - normalize = true; } else { KNOWHERE_THROW_MSG("Metric type not supported: " + metric_type); } @@ -142,7 +139,7 @@ IndexHNSW::Query(const DatasetPtr& dataset_ptr, const Config& config, const fais if (!index_) { KNOWHERE_THROW_MSG("index not initialize or trained"); } - GET_TENSOR_DATA(dataset_ptr) + GET_TENSOR_DATA_DIM(dataset_ptr) size_t k = config[meta::TOPK].get(); size_t id_size = sizeof(int64_t) * k; @@ -159,44 +156,39 @@ IndexHNSW::Query(const DatasetPtr& dataset_ptr, const Config& config, const fais } index_->setEf(config[IndexParams::ef].get()); - - using P = std::pair; - auto compare = [](const P& v1, const P& v2) { return v1.first < v2.first; }; + bool transform = (index_->metric_type_ == 1); // InnerProduct: 1 std::chrono::high_resolution_clock::time_point query_start, query_end; query_start = std::chrono::high_resolution_clock::now(); #pragma omp parallel for for (unsigned int i = 0; i < rows; ++i) { - std::vector

ret; - const float* single_query = reinterpret_cast(p_data) + i * Dim(); - + auto single_query = (float*)p_data + i * dim; + std::priority_queue> rst; if (STATISTICS_LEVEL >= 3) { - ret = index_->searchKnn(single_query, k, compare, bitset, query_stats[i]); + rst = index_->searchKnn(single_query, k, bitset, query_stats[i]); } else { auto dummy_stat = hnswlib::StatisticsInfo(); - ret = index_->searchKnn(single_query, k, compare, bitset, dummy_stat); + rst = index_->searchKnn(single_query, k, bitset, dummy_stat); } + size_t rst_size = rst.size(); - while (ret.size() < k) { - ret.emplace_back(std::make_pair(-1, -1)); + auto p_single_dis = p_dist + i * k; + auto p_single_id = p_id + i * k; + size_t idx = rst_size - 1; + while (!rst.empty()) { + auto& it = rst.top(); + p_single_dis[idx] = transform ? (1 - it.first) : it.first; + p_single_id[idx] = it.second; + rst.pop(); + idx--; } - std::vector dist; - std::vector ids; + MapOffsetToUid(p_single_id, rst_size); - if (normalize) { - std::transform(ret.begin(), ret.end(), std::back_inserter(dist), - [](const std::pair& e) { return float(1 - e.first); }); - } else { - std::transform(ret.begin(), ret.end(), std::back_inserter(dist), - [](const std::pair& e) { return e.first; }); + for (idx = rst_size; idx < k; idx++) { + p_single_dis[idx] = float(1.0 / 0.0); + p_single_id[idx] = -1; } - std::transform(ret.begin(), ret.end(), std::back_inserter(ids), - [](const std::pair& e) { return e.second; }); - - MapOffsetToUid(ids.data(), ids.size()); - memcpy(p_dist + i * k, dist.data(), dist_size); - memcpy(p_id + i * k, ids.data(), id_size); } query_end = std::chrono::high_resolution_clock::now(); diff --git a/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexHNSW.h b/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexHNSW.h index e26b90144a..43bf8c7a5b 100644 --- a/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexHNSW.h +++ b/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexHNSW.h @@ -56,7 +56,6 @@ class IndexHNSW : public VecIndex { ClearStatistics() override; private: - bool normalize = false; std::shared_ptr> index_; }; diff --git a/internal/core/src/index/knowhere/knowhere/index/vector_index/adapter/SptagAdapter.cpp b/internal/core/src/index/knowhere/knowhere/index/vector_index/adapter/SptagAdapter.cpp index 12932e7cde..9a3f1bb552 100644 --- a/internal/core/src/index/knowhere/knowhere/index/vector_index/adapter/SptagAdapter.cpp +++ b/internal/core/src/index/knowhere/knowhere/index/vector_index/adapter/SptagAdapter.cpp @@ -19,10 +19,10 @@ std::shared_ptr ConvertToMetadataSet(const DatasetPtr& dataset_ptr) { auto elems = dataset_ptr->Get(meta::ROWS); - auto p_id = (int64_t*)malloc(sizeof(int64_t) * elems); + auto p_id = new int64_t[elems]; for (int64_t i = 0; i < elems; ++i) p_id[i] = i; - auto p_offset = (int64_t*)malloc(sizeof(int64_t) * (elems + 1)); + auto p_offset = new int64_t[elems + 1]; for (int64_t i = 0; i <= elems; ++i) p_offset[i] = i * 8; std::shared_ptr metaset( diff --git a/internal/core/src/index/thirdparty/faiss/utils/distances.cpp b/internal/core/src/index/thirdparty/faiss/utils/distances.cpp index d09778e49f..8da34b9b57 100644 --- a/internal/core/src/index/thirdparty/faiss/utils/distances.cpp +++ b/internal/core/src/index/thirdparty/faiss/utils/distances.cpp @@ -1098,12 +1098,16 @@ void elkan_L2_sse ( return (i > j) ? 
data[j + i * (i - 1) / 2] : data[i + j * (j - 1) / 2]; }; -#pragma omp parallel for - for (size_t i = j0 + 1; i < j1; i++) { - const float *y_i = y + i * d; - for (size_t j = j0; j < i; j++) { - const float *y_j = y + j * d; - Y(i, j) = sqrt(fvec_L2sqr(y_i, y_j, d)); +#pragma omp parallel + { + int nt = omp_get_num_threads(); + int rank = omp_get_thread_num(); + for (size_t i = j0 + 1 + rank; i < j1; i += nt) { + const float *y_i = y + i * d; + for (size_t j = j0; j < i; j++) { + const float *y_j = y + j * d; + Y(i, j) = fvec_L2sqr(y_i, y_j, d); + } } } @@ -1112,18 +1116,22 @@ void elkan_L2_sse ( const float *x_i = x + i * d; int64_t ids_i = j0; - float val_i = sqrt(fvec_L2sqr(x_i, y + j0 * d, d)); - float val_i_2 = val_i * 2; + float val_i = fvec_L2sqr(x_i, y + j0 * d, d); + float val_i_time_4 = val_i * 4; for (size_t j = j0 + 1; j < j1; j++) { - if (val_i_2 <= Y(ids_i, j)) { + if (val_i_time_4 <= Y(ids_i, j)) { continue; } const float *y_j = y + j * d; - float disij = sqrt(fvec_L2sqr(x_i, y_j, d)); + float disij = fvec_L2sqr(x_i, y_j, d / 2); + if (disij >= val_i) { + continue; + } + disij += fvec_L2sqr(x_i + d / 2, y_j + d / 2, d - d / 2); if (disij < val_i) { ids_i = j; val_i = disij; - val_i_2 = val_i * 2; + val_i_time_4 = val_i * 4; } } diff --git a/internal/core/src/index/thirdparty/hnswlib/LICENSE b/internal/core/src/index/thirdparty/hnswlib/LICENSE new file mode 100644 index 0000000000..8dada3edaf --- /dev/null +++ b/internal/core/src/index/thirdparty/hnswlib/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. 
For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "{}" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright {yyyy} {name of copyright owner} + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
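Note on the elkan_L2_sse hunk above: the patched code drops the sqrt() calls and works on squared L2 distances throughout, so the Elkan-style pruning test changes from "2 * d(x, best) <= d(best, j)" to "4 * d2(x, best) <= d2(best, j)" (the same triangle-inequality bound with both sides squared), and each candidate distance is evaluated in two halves so the second half can be skipped once the partial sum already reaches the current best. A minimal sketch of that pruning logic with simplified types; l2_sqr, nearest_centroid, and the flat Y2 table with its stride are illustrative stand-ins, not the faiss API:

#include <cstddef>
#include <vector>

// Squared L2 distance over d floats (stand-in for faiss::fvec_L2sqr).
static float l2_sqr(const float* a, const float* b, size_t d) {
    float s = 0.f;
    for (size_t i = 0; i < d; i++) {
        float t = a[i] - b[i];
        s += t * t;
    }
    return s;
}

// Find the nearest centroid among y[j0, j1) for query x.
// Pruning rule (triangle inequality, squared form): if
//     4 * d2(x, best) <= d2(best, j)
// then y_j cannot be closer to x than the current best, so it is skipped.
// Y2 caches squared centroid-to-centroid distances, indexed as Y2[i * stride + j].
size_t nearest_centroid(const float* x, const float* y, size_t d,
                        size_t j0, size_t j1,
                        const std::vector<float>& Y2, size_t stride) {
    size_t best = j0;
    float best_d2 = l2_sqr(x, y + j0 * d, d);
    float best_d2_times_4 = best_d2 * 4;
    for (size_t j = j0 + 1; j < j1; j++) {
        if (best_d2_times_4 <= Y2[best * stride + j]) {
            continue;                       // pruned without touching y_j
        }
        const float* y_j = y + j * d;
        float d2 = l2_sqr(x, y_j, d / 2);   // first half of the dimensions
        if (d2 >= best_d2) {
            continue;                       // partial sum already too large
        }
        d2 += l2_sqr(x + d / 2, y_j + d / 2, d - d / 2);  // remaining half
        if (d2 < best_d2) {
            best = j;
            best_d2 = d2;
            best_d2_times_4 = best_d2 * 4;
        }
    }
    return best;
}

Both forms of the test are mathematically equivalent; staying in squared space simply avoids one sqrt per distance evaluation, which is what the diff above achieves.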
diff --git a/internal/core/src/index/thirdparty/hnswlib/hnswalg.h b/internal/core/src/index/thirdparty/hnswlib/hnswalg.h index de4435e163..90a81eff09 100644 --- a/internal/core/src/index/thirdparty/hnswlib/hnswalg.h +++ b/internal/core/src/index/thirdparty/hnswlib/hnswalg.h @@ -317,50 +317,54 @@ class HierarchicalNSW : public AlgorithmInterface { return top_candidates; } - void getNeighborsByHeuristic2( + std::vector + getNeighborsByHeuristic2 ( std::priority_queue, std::vector>, CompareByFirst> &top_candidates, const size_t M) { - if (top_candidates.size() < M) { - return; - } - std::priority_queue> queue_closest; - std::vector> return_list; - while (top_candidates.size() > 0) { - queue_closest.emplace(-top_candidates.top().first, top_candidates.top().second); - top_candidates.pop(); - } + std::vector return_list; - while (queue_closest.size()) { - if (return_list.size() >= M) - break; - std::pair curent_pair = queue_closest.top(); - dist_t dist_to_query = -curent_pair.first; - queue_closest.pop(); - bool good = true; - for (std::pair second_pair : return_list) { - dist_t curdist = - fstdistfunc_(getDataByInternalId(second_pair.second), - getDataByInternalId(curent_pair.second), - dist_func_param_);; - if (curdist < dist_to_query) { - good = false; - break; + if (top_candidates.size() < M) { + return_list.resize(top_candidates.size()); + + for (int i = static_cast(top_candidates.size() - 1); i >= 0; i--) { + return_list[i] = top_candidates.top().second; + top_candidates.pop(); + } + + } else if (M > 0) { + return_list.reserve(M); + + std::vector> queue_closest; + queue_closest.resize(top_candidates.size()); + for (int i = static_cast(top_candidates.size() - 1); i >= 0; i--) { + queue_closest[i] = top_candidates.top(); + top_candidates.pop(); + } + + for (std::pair ¤t_pair: queue_closest) { + bool good = true; + for (tableint id : return_list) { + dist_t curdist = + fstdistfunc_(getDataByInternalId(id), + getDataByInternalId(current_pair.second), + dist_func_param_); + if (curdist < current_pair.first) { + good = false; + break; + } + } + if (good) { + return_list.push_back(current_pair.second); + if (return_list.size() >= M) { + break; + } } } - if (good) { - return_list.push_back(curent_pair); - } - - } - for (std::pair curent_pair : return_list) { - - top_candidates.emplace(-curent_pair.first, curent_pair.second); - } + return return_list; } - linklistsizeint *get_linklist0(tableint internal_id) const { return (linklistsizeint *) (data_level0_memory_ + internal_id * size_data_per_element_ + offsetLevel0_); }; @@ -373,21 +377,17 @@ class HierarchicalNSW : public AlgorithmInterface { return (linklistsizeint *) (linkLists_[internal_id] + (level - 1) * size_links_per_element_); }; - void mutuallyConnectNewElement(const void *data_point, tableint cur_c, - std::priority_queue, std::vector>, CompareByFirst> &top_candidates, - int level) { + tableint mutuallyConnectNewElement(const void *data_point, tableint cur_c, + std::priority_queue, std::vector>, CompareByFirst> &top_candidates, + int level) { size_t Mcurmax = level ? 
maxM_ : maxM0_; - getNeighborsByHeuristic2(top_candidates, M_); - if (top_candidates.size() > M_) + + std::vector selectedNeighbors(getNeighborsByHeuristic2(top_candidates, M_)); + if (selectedNeighbors.size() > M_) throw std::runtime_error("Should be not be more than M_ candidates returned by the heuristic"); - std::vector selectedNeighbors; - selectedNeighbors.reserve(M_); - while (top_candidates.size() > 0) { - selectedNeighbors.push_back(top_candidates.top().second); - top_candidates.pop(); - } + tableint next_closest_entry_point = selectedNeighbors.front(); { linklistsizeint *ll_cur; @@ -451,15 +451,11 @@ class HierarchicalNSW : public AlgorithmInterface { dist_func_param_), data[j]); } - getNeighborsByHeuristic2(candidates, Mcurmax); - - int indx = 0; - while (candidates.size() > 0) { - data[indx] = candidates.top().second; - candidates.pop(); - indx++; + std::vector selected(getNeighborsByHeuristic2(candidates, Mcurmax)); + setListCount(ll_other, static_cast(selected.size())); + for (size_t idx = 0; idx < selected.size(); idx++) { + data[idx] = selected[idx]; } - setListCount(ll_other, indx); // Nearest K: /*int indx = -1; for (int j = 0; j < sz_link_list_other; j++) { @@ -475,6 +471,8 @@ class HierarchicalNSW : public AlgorithmInterface { } } + + return next_closest_entry_point; } std::mutex global; @@ -499,17 +497,18 @@ class HierarchicalNSW : public AlgorithmInterface { std::vector(new_max_elements).swap(link_list_locks_); - // Reallocate base layer - data_level0_memory_ = (char *) realloc(data_level0_memory_, new_max_elements * size_data_per_element_); - if (data_level0_memory_ == nullptr) + char * data_level0_memory_new = (char *) realloc(data_level0_memory_, new_max_elements * size_data_per_element_); + if (data_level0_memory_new == nullptr) throw std::runtime_error("Not enough memory: resizeIndex failed to allocate base layer"); + data_level0_memory_ = data_level0_memory_new; // Reallocate all other layers - linkLists_ = (char **) realloc(linkLists_, sizeof(void *) * new_max_elements); - if (linkLists_ == nullptr) + char ** linkLists_new = (char **) realloc(linkLists_, sizeof(void *) * new_max_elements); + if (linkLists_new == nullptr) throw std::runtime_error("Not enough memory: resizeIndex failed to allocate other layers"); + linkLists_ = linkLists_new; - max_elements_=new_max_elements; + max_elements_ = new_max_elements; } @@ -814,9 +813,7 @@ class HierarchicalNSW : public AlgorithmInterface { } std::unique_lock lock_el(link_list_locks_[cur_c]); - int curlevel = getRandomLevel(mult_); - if (level > 0) - curlevel = level; + int curlevel = (level > 0) ? 
level : getRandomLevel(mult_); element_levels_[cur_c] = curlevel; @@ -881,9 +878,7 @@ class HierarchicalNSW : public AlgorithmInterface { std::priority_queue, std::vector>, CompareByFirst> top_candidates = searchBaseLayer( currObj, data_point, level); - mutuallyConnectNewElement(data_point, cur_c, top_candidates, level); - - currObj = top_candidates.top().second; + currObj = mutuallyConnectNewElement(data_point, cur_c, top_candidates, level); } } else { // Do nothing for the first element @@ -956,24 +951,6 @@ class HierarchicalNSW : public AlgorithmInterface { return result; }; - template - std::vector> - searchKnn(const void* query_data, size_t k, Comp comp, const faiss::BitsetView bitset, StatisticsInfo &stats) { - std::vector> result; - if (cur_element_count == 0) return result; - - auto ret = searchKnn(query_data, k, bitset, stats); - - while (!ret.empty()) { - result.push_back(ret.top()); - ret.pop(); - } - - std::sort(result.begin(), result.end(), comp); - - return result; - } - int64_t cal_size() { int64_t ret = 0; ret += sizeof(*this); diff --git a/internal/core/src/index/thirdparty/hnswlib/hnswalg_nm.h b/internal/core/src/index/thirdparty/hnswlib/hnswalg_nm.h deleted file mode 100644 index 1f6c45b9d2..0000000000 --- a/internal/core/src/index/thirdparty/hnswlib/hnswalg_nm.h +++ /dev/null @@ -1,1227 +0,0 @@ -#pragma once - -#include "visited_list_pool.h" -#include "hnswlib_nm.h" -#include -#include -#include -#include - -#include "knowhere/index/vector_index/helpers/FaissIO.h" -#include "faiss/impl/ScalarQuantizer.h" -#include "faiss/impl/ScalarQuantizerCodec.h" - -namespace hnswlib_nm { - - typedef unsigned int tableint; - typedef unsigned int linklistsizeint; - - using QuantizerClass = faiss::QuantizerTemplate; - using DCClassIP = faiss::DCTemplate, 1>; - using DCClassL2 = faiss::DCTemplate, 1>; - - template - class HierarchicalNSW_NM : public AlgorithmInterface { - public: - HierarchicalNSW_NM(SpaceInterface *s) { - } - - HierarchicalNSW_NM(SpaceInterface *s, const std::string &location, bool nmslib = false, size_t max_elements=0) { - loadIndex(location, s, max_elements); - } - - HierarchicalNSW_NM(SpaceInterface *s, size_t max_elements, size_t M = 16, size_t ef_construction = 200, size_t random_seed = 100) : - link_list_locks_(max_elements), element_levels_(max_elements) { - // linxj - space = s; - if (auto x = dynamic_cast(s)) { - metric_type_ = 0; - } else if (auto x = dynamic_cast(s)) { - metric_type_ = 1; - } else { - metric_type_ = 100; - } - - max_elements_ = max_elements; - - has_deletions_=false; - data_size_ = s->get_data_size(); - fstdistfunc_ = s->get_dist_func(); - dist_func_param_ = s->get_dist_func_param(); - M_ = M; - maxM_ = M_; - maxM0_ = M_ * 2; - ef_construction_ = std::max(ef_construction,M_); - ef_ = 10; - - is_sq8_ = false; - sq_ = nullptr; - - level_generator_.seed(random_seed); - - size_links_level0_ = maxM0_ * sizeof(tableint) + sizeof(linklistsizeint); - size_data_per_element_ = size_links_level0_; // + sizeof(labeltype); + data_size_;; -// label_offset_ = size_links_level0_; - - data_level0_memory_ = (char *) malloc(max_elements_ * size_data_per_element_); - if (data_level0_memory_ == nullptr) - throw std::runtime_error("Not enough memory"); - - cur_element_count = 0; - - visited_list_pool_ = new hnswlib_nm::VisitedListPool(1, max_elements); - - - - //initializations for special treatment of the first node - enterpoint_node_ = -1; - maxlevel_ = -1; - - linkLists_ = (char **) malloc(sizeof(void *) * max_elements_); - if (linkLists_ == nullptr) - throw 
std::runtime_error("Not enough memory: HierarchicalNSW_NM failed to allocate linklists"); - size_links_per_element_ = maxM_ * sizeof(tableint) + sizeof(linklistsizeint); - mult_ = 1 / log(1.0 * M_); - revSize_ = 1.0 / mult_; - } - - struct CompareByFirst { - constexpr bool operator()(std::pair const &a, - std::pair const &b) const noexcept { - return a.first < b.first; - } - }; - - ~HierarchicalNSW_NM() { - - free(data_level0_memory_); - for (tableint i = 0; i < cur_element_count; i++) { - if (element_levels_[i] > 0) - free(linkLists_[i]); - } - free(linkLists_); - delete visited_list_pool_; - - if (sq_) delete sq_; - - // linxj: delete - delete space; - } - - // linxj: use for free resource - SpaceInterface *space; - size_t metric_type_; // 0:l2, 1:ip - - size_t max_elements_; - size_t cur_element_count; - size_t size_data_per_element_; - size_t size_links_per_element_; - - size_t M_; - size_t maxM_; - size_t maxM0_; - size_t ef_construction_; - - bool is_sq8_ = false; - faiss::ScalarQuantizer *sq_ = nullptr; - - double mult_, revSize_; - int maxlevel_; - - - VisitedListPool *visited_list_pool_; - std::mutex cur_element_count_guard_; - - std::vector link_list_locks_; - tableint enterpoint_node_; - - - size_t size_links_level0_; - - - char *data_level0_memory_; - char **linkLists_; - std::vector element_levels_; - - size_t data_size_; - - bool has_deletions_; - - - DISTFUNC fstdistfunc_; - void *dist_func_param_; - - std::default_random_engine level_generator_; - - inline char *getDataByInternalId(void *pdata, tableint offset) const { - return ((char*)pdata + offset * data_size_); - } - - void SetSq8(const float *trained) { - if (!trained) - throw std::runtime_error("trained sq8 data cannot be null in SetSq8!"); - if (sq_) delete sq_; - is_sq8_ = true; - sq_ = new faiss::ScalarQuantizer(*(size_t*)dist_func_param_, faiss::QuantizerType::QT_8bit); // hard code - sq_->trained.resize((sq_->d) << 1); - memcpy(sq_->trained.data(), trained, sq_->trained.size() * sizeof(float)); - } - - void sq_train(size_t nb, const float *xb, uint8_t *p_codes) { - if (!p_codes) - throw std::runtime_error("p_codes cannot be null in sq_train!"); - if (!xb) - throw std::runtime_error("base vector cannot be null in sq_train!"); - if (sq_) delete sq_; - is_sq8_ = true; - sq_ = new faiss::ScalarQuantizer(*(size_t*)dist_func_param_, faiss::QuantizerType::QT_8bit); // hard code - sq_->train(nb, xb); - sq_->compute_codes(xb, p_codes, nb); - memcpy(p_codes + *(size_t*)dist_func_param_ * nb, sq_->trained.data(), *(size_t*)dist_func_param_ * sizeof(float) * 2); - } - - int getRandomLevel(double reverse_size) { - std::uniform_real_distribution distribution(0.0, 1.0); - double r = -log(distribution(level_generator_)) * reverse_size; - return (int) r; - } - - std::priority_queue, std::vector>, CompareByFirst> - searchBaseLayer(tableint ep_id, const void *data_point, int layer, void *pdata) { - VisitedList *vl = visited_list_pool_->getFreeVisitedList(); - vl_type *visited_array = vl->mass; - vl_type visited_array_tag = vl->curV; - - std::priority_queue, std::vector>, CompareByFirst> top_candidates; - std::priority_queue, std::vector>, CompareByFirst> candidateSet; - - dist_t lowerBound; - if (!isMarkedDeleted(ep_id)) { - dist_t dist = fstdistfunc_(data_point, getDataByInternalId(pdata, ep_id), dist_func_param_); - top_candidates.emplace(dist, ep_id); - lowerBound = dist; - candidateSet.emplace(-dist, ep_id); - } else { - lowerBound = std::numeric_limits::max(); - candidateSet.emplace(-lowerBound, ep_id); - } - 
visited_array[ep_id] = visited_array_tag; - - while (!candidateSet.empty()) { - std::pair curr_el_pair = candidateSet.top(); - if ((-curr_el_pair.first) > lowerBound) { - break; - } - candidateSet.pop(); - - tableint curNodeNum = curr_el_pair.second; - - std::unique_lock lock(link_list_locks_[curNodeNum]); - - int *data;// = (int *)(linkList0_ + curNodeNum * size_links_per_element0_); - if (layer == 0) { - data = (int*)get_linklist0(curNodeNum); - } else { - data = (int*)get_linklist(curNodeNum, layer); - // data = (int *) (linkLists_[curNodeNum] + (layer - 1) * size_links_per_element_); - } - size_t size = getListCount((linklistsizeint*)data); - tableint *datal = (tableint *) (data + 1); -#ifdef USE_SSE - _mm_prefetch((char *) (visited_array + *(data + 1)), _MM_HINT_T0); - _mm_prefetch((char *) (visited_array + *(data + 1) + 64), _MM_HINT_T0); - _mm_prefetch(getDataByInternalId(pdata, *datal), _MM_HINT_T0); - _mm_prefetch(getDataByInternalId(pdata, *(datal + 1)), _MM_HINT_T0); -#endif - - for (size_t j = 0; j < size; j++) { - tableint candidate_id = *(datal + j); - // if (candidate_id == 0) continue; -#ifdef USE_SSE - _mm_prefetch((char *) (visited_array + *(datal + j + 1)), _MM_HINT_T0); - _mm_prefetch(getDataByInternalId(pdata, *(datal + j + 1)), _MM_HINT_T0); -#endif - if (visited_array[candidate_id] == visited_array_tag) continue; - visited_array[candidate_id] = visited_array_tag; - char *currObj1 = (getDataByInternalId(pdata, candidate_id)); - - dist_t dist1 = fstdistfunc_(data_point, currObj1, dist_func_param_); - if (top_candidates.size() < ef_construction_ || lowerBound > dist1) { - candidateSet.emplace(-dist1, candidate_id); -#ifdef USE_SSE - _mm_prefetch(getDataByInternalId(pdata, candidateSet.top().second), _MM_HINT_T0); -#endif - - if (!isMarkedDeleted(candidate_id)) - top_candidates.emplace(dist1, candidate_id); - - if (top_candidates.size() > ef_construction_) - top_candidates.pop(); - - if (!top_candidates.empty()) - lowerBound = top_candidates.top().first; - } - } - } - visited_list_pool_->releaseVisitedList(vl); - - return top_candidates; - } - - template - std::priority_queue, std::vector>, CompareByFirst> - searchBaseLayerST(tableint ep_id, const void *data_point, size_t ef, const faiss::BitsetView bitset, void *pdata) const { - VisitedList *vl = visited_list_pool_->getFreeVisitedList(); - vl_type *visited_array = vl->mass; - vl_type visited_array_tag = vl->curV; - - faiss::SQDistanceComputer *sqdc = nullptr; - if (is_sq8_) { - if (metric_type_ == 0) { // L2 - sqdc = new DCClassL2(sq_->d, sq_->trained); - } else if (metric_type_ == 1) { // IP - sqdc = new DCClassIP(sq_->d, sq_->trained); - } else { - throw std::runtime_error("unsupported metric_type, it must be 0(L2) or 1(IP)!"); - } - sqdc->code_size = sq_->code_size; - sqdc->codes = (uint8_t*)pdata; - sqdc->set_query((const float*)data_point); - } - - std::priority_queue, std::vector>, CompareByFirst> top_candidates; - std::priority_queue, std::vector>, CompareByFirst> candidate_set; - - dist_t lowerBound; -// if (!has_deletions || !isMarkedDeleted(ep_id)) { - if (!has_deletions || !bitset.test((faiss::ConcurrentBitset::id_type_t)(ep_id))) { - dist_t dist; - if (is_sq8_) { - dist = (*sqdc)(ep_id); - } else { - dist = fstdistfunc_(data_point, getDataByInternalId(pdata, ep_id), dist_func_param_); - } - lowerBound = dist; - top_candidates.emplace(dist, ep_id); - candidate_set.emplace(-dist, ep_id); - } else { - lowerBound = std::numeric_limits::max(); - candidate_set.emplace(-lowerBound, ep_id); - } - - 
visited_array[ep_id] = visited_array_tag; - - while (!candidate_set.empty()) { - - std::pair current_node_pair = candidate_set.top(); - - if ((-current_node_pair.first) > lowerBound) { - break; - } - candidate_set.pop(); - - tableint current_node_id = current_node_pair.second; - int *data = (int *) get_linklist0(current_node_id); - size_t size = getListCount((linklistsizeint*)data); - // bool cur_node_deleted = isMarkedDeleted(current_node_id); - -#ifdef USE_SSE - _mm_prefetch((char *) (visited_array + *(data + 1)), _MM_HINT_T0); - _mm_prefetch((char *) (visited_array + *(data + 1) + 64), _MM_HINT_T0); -// _mm_prefetch(data_level0_memory_ + (*(data + 1)) * size_data_per_element_ + offsetData_, _MM_HINT_T0); - _mm_prefetch(getDataByInternalId(pdata, *(data + 1)), _MM_HINT_T0); - _mm_prefetch((char *) (data + 2), _MM_HINT_T0); -#endif - - for (size_t j = 1; j <= size; j++) { - int candidate_id = *(data + j); - // if (candidate_id == 0) continue; -#ifdef USE_SSE - _mm_prefetch((char *) (visited_array + *(data + j + 1)), _MM_HINT_T0); - _mm_prefetch(getDataByInternalId(pdata, *(data + j + 1)), - _MM_HINT_T0);//////////// -#endif - if (!(visited_array[candidate_id] == visited_array_tag)) { - - visited_array[candidate_id] = visited_array_tag; - - dist_t dist; - if (is_sq8_) { - dist = (*sqdc)(candidate_id); - } else { - char *currObj1 = (getDataByInternalId(pdata, candidate_id)); - dist = fstdistfunc_(data_point, currObj1, dist_func_param_); - } - - if (top_candidates.size() < ef || lowerBound > dist) { - candidate_set.emplace(-dist, candidate_id); -#ifdef USE_SSE - _mm_prefetch(data_level0_memory_ + candidate_set.top().second * size_data_per_element_,/////////// - _MM_HINT_T0);//////////////////////// -#endif - -// if (!has_deletions || !isMarkedDeleted(candidate_id)) - if (!has_deletions || (!bitset.test((faiss::ConcurrentBitset::id_type_t)(candidate_id)))) - top_candidates.emplace(dist, candidate_id); - - if (top_candidates.size() > ef) - top_candidates.pop(); - - if (!top_candidates.empty()) - lowerBound = top_candidates.top().first; - } - } - } - } - - visited_list_pool_->releaseVisitedList(vl); - if (is_sq8_) delete sqdc; - return top_candidates; - } - - void getNeighborsByHeuristic2( - std::priority_queue, std::vector>, CompareByFirst> &top_candidates, - const size_t M, tableint *ret, size_t &ret_len, void *pdata) { - if (top_candidates.size() < M) { - while (top_candidates.size() > 0) { - ret[ret_len ++] = top_candidates.top().second; - top_candidates.pop(); - } - return; - } - std::priority_queue> queue_closest; - std::vector> return_list; - while (top_candidates.size() > 0) { - queue_closest.emplace(-top_candidates.top().first, top_candidates.top().second); - top_candidates.pop(); - } - - while (queue_closest.size()) { - if (return_list.size() >= M) - break; - std::pair curent_pair = queue_closest.top(); - dist_t dist_to_query = -curent_pair.first; - queue_closest.pop(); - bool good = true; - for (std::pair second_pair : return_list) { - dist_t curdist = - fstdistfunc_(getDataByInternalId(pdata, second_pair.second), - getDataByInternalId(pdata, curent_pair.second), - dist_func_param_);; - if (curdist < dist_to_query) { - good = false; - break; - } - } - if (good) { - return_list.push_back(curent_pair); - ret[ret_len ++] = curent_pair.second; - } - - - } - -// for (std::pair curent_pair : return_list) { -// -// top_candidates.emplace(-curent_pair.first, curent_pair.second); -// } - } - - - linklistsizeint *get_linklist0(tableint internal_id) const { - return (linklistsizeint *) 
(data_level0_memory_ + internal_id * size_data_per_element_); - }; - - linklistsizeint *get_linklist0(tableint internal_id, char *data_level0_memory_) const { - return (linklistsizeint *) (data_level0_memory_ + internal_id * size_data_per_element_); - }; - - linklistsizeint *get_linklist(tableint internal_id, int level) const { - return (linklistsizeint *) (linkLists_[internal_id] + (level - 1) * size_links_per_element_); - }; - - void mutuallyConnectNewElement(const void *data_point, tableint cur_c, - std::priority_queue, std::vector>, CompareByFirst> &top_candidates, - int level, void *pdata) { - - size_t Mcurmax = level ? maxM_ : maxM0_; -// std::vector selectedNeighbors; -// selectedNeighbors.reserve(M_); - tableint *selectedNeighbors = (tableint*)malloc(sizeof(tableint) * M_); - size_t selectedNeighbors_size = 0; - getNeighborsByHeuristic2(top_candidates, M_, selectedNeighbors, selectedNeighbors_size, pdata); - if (selectedNeighbors_size > M_) - throw std::runtime_error("Should be not be more than M_ candidates returned by the heuristic"); - -// while (top_candidates.size() > 0) { -// selectedNeighbors.push_back(top_candidates.top().second); -// top_candidates.pop(); -// } - - { - linklistsizeint *ll_cur; - if (level == 0) - ll_cur = get_linklist0(cur_c); - else - ll_cur = get_linklist(cur_c, level); - - if (*ll_cur) { - throw std::runtime_error("The newly inserted element should have blank link list"); - } - setListCount(ll_cur,(unsigned short)selectedNeighbors_size); - tableint *data = (tableint *) (ll_cur + 1); - - - for (size_t idx = 0; idx < selectedNeighbors_size; idx++) { - if (data[idx]) - throw std::runtime_error("Possible memory corruption"); - if (level > element_levels_[selectedNeighbors[idx]]) - throw std::runtime_error("Trying to make a link on a non-existent level"); - - data[idx] = selectedNeighbors[idx]; - - } - } - for (size_t idx = 0; idx < selectedNeighbors_size; idx++) { - - std::unique_lock lock(link_list_locks_[selectedNeighbors[idx]]); - - - linklistsizeint *ll_other; - if (level == 0) - ll_other = get_linklist0(selectedNeighbors[idx]); - else - ll_other = get_linklist(selectedNeighbors[idx], level); - - size_t sz_link_list_other = getListCount(ll_other); - - if (sz_link_list_other > Mcurmax) - throw std::runtime_error("Bad value of sz_link_list_other"); - if (selectedNeighbors[idx] == cur_c) - throw std::runtime_error("Trying to connect an element to itself"); - if (level > element_levels_[selectedNeighbors[idx]]) - throw std::runtime_error("Trying to make a link on a non-existent level"); - - tableint *data = (tableint *) (ll_other + 1); - if (sz_link_list_other < Mcurmax) { - data[sz_link_list_other] = cur_c; - setListCount(ll_other, sz_link_list_other + 1); - } else { - // finding the "weakest" element to replace it with the new one - dist_t d_max = fstdistfunc_(getDataByInternalId(pdata, cur_c), getDataByInternalId(pdata, selectedNeighbors[idx]), - dist_func_param_); - // Heuristic: - std::priority_queue, std::vector>, CompareByFirst> candidates; - candidates.emplace(d_max, cur_c); - - for (size_t j = 0; j < sz_link_list_other; j++) { - candidates.emplace( - fstdistfunc_(getDataByInternalId(pdata, data[j]), getDataByInternalId(pdata, selectedNeighbors[idx]), - dist_func_param_), data[j]); - } - - size_t indx = 0; - getNeighborsByHeuristic2(candidates, Mcurmax, data, indx, pdata); - -// while (candidates.size() > 0) { -// data[indx] = candidates.top().second; -// candidates.pop(); -// indx++; -// } - setListCount(ll_other, indx); - // Nearest K: - /*int 
indx = -1; - for (int j = 0; j < sz_link_list_other; j++) { - dist_t d = fstdistfunc_(getDataByInternalId(data[j]), getDataByInternalId(rez[idx]), dist_func_param_); - if (d > d_max) { - indx = j; - d_max = d; - } - } - if (indx >= 0) { - data[indx] = cur_c; - } */ - } - - } - } - - std::mutex global; - size_t ef_; - - void setEf(size_t ef) { - ef_ = ef; - } - - - std::priority_queue> searchKnnInternal(void *query_data, int k, dist_t *pdata) { - std::priority_queue> top_candidates; - if (cur_element_count == 0) return top_candidates; - tableint currObj = enterpoint_node_; - dist_t curdist = fstdistfunc_(query_data, getDataByInternalId(pdata, enterpoint_node_), dist_func_param_); - - for (size_t level = maxlevel_; level > 0; level--) { - bool changed = true; - while (changed) { - changed = false; - int *data; - data = (int *) get_linklist(currObj,level); - int size = getListCount(data); - tableint *datal = (tableint *) (data + 1); - for (int i = 0; i < size; i++) { - tableint cand = datal[i]; - if (cand < 0 || cand > max_elements_) - throw std::runtime_error("cand error"); - dist_t d = fstdistfunc_(query_data, getDataByInternalId(pdata, cand), dist_func_param_); - - if (d < curdist) { - curdist = d; - currObj = cand; - changed = true; - } - } - } - } - - if (has_deletions_) { - std::priority_queue> top_candidates1=searchBaseLayerST(currObj, query_data, - ef_, pdata); - top_candidates.swap(top_candidates1); - } - else{ - std::priority_queue> top_candidates1=searchBaseLayerST(currObj, query_data, - ef_, pdata); - top_candidates.swap(top_candidates1); - } - - while (top_candidates.size() > k) { - top_candidates.pop(); - } - return top_candidates; - }; - - void resizeIndex(size_t new_max_elements){ - if (new_max_elements(new_max_elements).swap(link_list_locks_); - - // Reallocate base layer - char * data_level0_memory_new = (char *) malloc(new_max_elements * size_data_per_element_); - if (data_level0_memory_new == nullptr) - throw std::runtime_error("Not enough memory: resizeIndex failed to allocate base layer"); - memcpy(data_level0_memory_new, data_level0_memory_,cur_element_count * size_data_per_element_); - free(data_level0_memory_); - data_level0_memory_=data_level0_memory_new; - - // Reallocate all other layers - char ** linkLists_new = (char **) malloc(sizeof(void *) * new_max_elements); - if (linkLists_new == nullptr) - throw std::runtime_error("Not enough memory: resizeIndex failed to allocate other layers"); - memcpy(linkLists_new, linkLists_,cur_element_count * sizeof(void *)); - free(linkLists_); - linkLists_=linkLists_new; - - max_elements_=new_max_elements; - - } - - void saveIndex(milvus::knowhere::MemoryIOWriter& output) { - // write l2/ip calculator - writeBinaryPOD(output, metric_type_); - writeBinaryPOD(output, data_size_); - writeBinaryPOD(output, *((size_t *) dist_func_param_)); - -// writeBinaryPOD(output, offsetLevel0_); - writeBinaryPOD(output, max_elements_); - writeBinaryPOD(output, cur_element_count); - writeBinaryPOD(output, size_data_per_element_); -// writeBinaryPOD(output, label_offset_); -// writeBinaryPOD(output, offsetData_); - writeBinaryPOD(output, maxlevel_); - writeBinaryPOD(output, enterpoint_node_); - writeBinaryPOD(output, maxM_); - - writeBinaryPOD(output, maxM0_); - writeBinaryPOD(output, M_); - writeBinaryPOD(output, mult_); - writeBinaryPOD(output, ef_construction_); - - output.write(data_level0_memory_, cur_element_count * size_data_per_element_); - - for (size_t i = 0; i < cur_element_count; i++) { - unsigned int linkListSize = element_levels_[i] > 
0 ? size_links_per_element_ * element_levels_[i] : 0; - writeBinaryPOD(output, linkListSize); - if (linkListSize) - output.write(linkLists_[i], linkListSize); - } - // output.close(); - } - - void loadIndex(milvus::knowhere::MemoryIOReader& input, size_t max_elements_i = 0) { - // linxj: init with metrictype - size_t dim = 100; - readBinaryPOD(input, metric_type_); - readBinaryPOD(input, data_size_); - readBinaryPOD(input, dim); - if (metric_type_ == 0) { - space = new L2Space(dim); - } else if (metric_type_ == 1) { - space = new InnerProductSpace(dim); - } else { - // throw exception - } - fstdistfunc_ = space->get_dist_func(); - dist_func_param_ = space->get_dist_func_param(); - -// readBinaryPOD(input, offsetLevel0_); - readBinaryPOD(input, max_elements_); - readBinaryPOD(input, cur_element_count); - - size_t max_elements=max_elements_i; - if(max_elements < cur_element_count) - max_elements = max_elements_; - max_elements_ = max_elements; - readBinaryPOD(input, size_data_per_element_); -// readBinaryPOD(input, label_offset_); -// readBinaryPOD(input, offsetData_); - readBinaryPOD(input, maxlevel_); - readBinaryPOD(input, enterpoint_node_); - - readBinaryPOD(input, maxM_); - readBinaryPOD(input, maxM0_); - readBinaryPOD(input, M_); - readBinaryPOD(input, mult_); - readBinaryPOD(input, ef_construction_); - - - // data_size_ = s->get_data_size(); - // fstdistfunc_ = s->get_dist_func(); - // dist_func_param_ = s->get_dist_func_param(); - - // auto pos= input.rp; - - - // /// Optional - check if index is ok: - // - // input.seekg(cur_element_count * size_data_per_element_,input.cur); - // for (size_t i = 0; i < cur_element_count; i++) { - // if(input.tellg() < 0 || input.tellg()>=total_filesize){ - // throw std::runtime_error("Index seems to be corrupted or unsupported"); - // } - // - // unsigned int linkListSize; - // readBinaryPOD(input, linkListSize); - // if (linkListSize != 0) { - // input.seekg(linkListSize,input.cur); - // } - // } - // - // // throw exception if it either corrupted or old index - // if(input.tellg()!=total_filesize) - // throw std::runtime_error("Index seems to be corrupted or unsupported"); - // - // input.clear(); - // - // /// Optional check end - // - // input.seekg(pos,input.beg); - - - data_level0_memory_ = (char *) malloc(max_elements * size_data_per_element_); - if (data_level0_memory_ == nullptr) - throw std::runtime_error("Not enough memory: loadIndex failed to allocate level0"); - input.read(data_level0_memory_, cur_element_count * size_data_per_element_); - - - - - size_links_per_element_ = maxM_ * sizeof(tableint) + sizeof(linklistsizeint); - - - size_links_level0_ = maxM0_ * sizeof(tableint) + sizeof(linklistsizeint); - std::vector(max_elements).swap(link_list_locks_); - - - visited_list_pool_ = new VisitedListPool(1, max_elements); - - - linkLists_ = (char **) malloc(sizeof(void *) * max_elements); - if (linkLists_ == nullptr) - throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklists"); - element_levels_ = std::vector(max_elements); - revSize_ = 1.0 / mult_; - ef_ = 10; - for (size_t i = 0; i < cur_element_count; i++) { -// label_lookup_[getExternalLabel(i)]=i; - unsigned int linkListSize; - readBinaryPOD(input, linkListSize); - if (linkListSize == 0) { - element_levels_[i] = 0; - - linkLists_[i] = nullptr; - } else { - element_levels_[i] = linkListSize / size_links_per_element_; - linkLists_[i] = (char *) malloc(linkListSize); - if (linkLists_[i] == nullptr) - throw std::runtime_error("Not enough memory: loadIndex failed 
to allocate linklist"); - input.read(linkLists_[i], linkListSize); - } - } - - has_deletions_=false; - - for (size_t i = 0; i < cur_element_count; i++) { - if(isMarkedDeleted(i)) - has_deletions_=true; - } - - return; - } - - void saveIndex(const std::string &location) { - std::ofstream output(location, std::ios::binary); - std::streampos position; - -// writeBinaryPOD(output, offsetLevel0_); - writeBinaryPOD(output, max_elements_); - writeBinaryPOD(output, cur_element_count); - writeBinaryPOD(output, size_data_per_element_); -// writeBinaryPOD(output, label_offset_); -// writeBinaryPOD(output, offsetData_); - writeBinaryPOD(output, maxlevel_); - writeBinaryPOD(output, enterpoint_node_); - writeBinaryPOD(output, maxM_); - - writeBinaryPOD(output, maxM0_); - writeBinaryPOD(output, M_); - writeBinaryPOD(output, mult_); - writeBinaryPOD(output, ef_construction_); - - output.write(data_level0_memory_, cur_element_count * size_data_per_element_); - - for (size_t i = 0; i < cur_element_count; i++) { - unsigned int linkListSize = element_levels_[i] > 0 ? size_links_per_element_ * element_levels_[i] : 0; - writeBinaryPOD(output, linkListSize); - if (linkListSize) - output.write(linkLists_[i], linkListSize); - } - output.close(); - } - - void loadIndex(const std::string &location, SpaceInterface *s, size_t max_elements_i=0) { - std::ifstream input(location, std::ios::binary); - - if (!input.is_open()) - throw std::runtime_error("Cannot open file"); - - // get file size: - input.seekg(0,input.end); - std::streampos total_filesize=input.tellg(); - input.seekg(0,input.beg); - -// readBinaryPOD(input, offsetLevel0_); - readBinaryPOD(input, max_elements_); - readBinaryPOD(input, cur_element_count); - - size_t max_elements=max_elements_i; - if(max_elements < cur_element_count) - max_elements = max_elements_; - max_elements_ = max_elements; - readBinaryPOD(input, size_data_per_element_); -// readBinaryPOD(input, label_offset_); -// readBinaryPOD(input, offsetData_); - readBinaryPOD(input, maxlevel_); - readBinaryPOD(input, enterpoint_node_); - - readBinaryPOD(input, maxM_); - readBinaryPOD(input, maxM0_); - readBinaryPOD(input, M_); - readBinaryPOD(input, mult_); - readBinaryPOD(input, ef_construction_); - - data_size_ = s->get_data_size(); - fstdistfunc_ = s->get_dist_func(); - dist_func_param_ = s->get_dist_func_param(); - - auto pos=input.tellg(); - - /// Optional - check if index is ok: - - input.seekg(cur_element_count * size_data_per_element_,input.cur); - for (size_t i = 0; i < cur_element_count; i++) { - if(input.tellg() < 0 || input.tellg()>=total_filesize){ - throw std::runtime_error("Index seems to be corrupted or unsupported"); - } - - unsigned int linkListSize; - readBinaryPOD(input, linkListSize); - if (linkListSize != 0) { - input.seekg(linkListSize,input.cur); - } - } - - // throw exception if it either corrupted or old index - if(input.tellg()!=total_filesize) - throw std::runtime_error("Index seems to be corrupted or unsupported"); - - input.clear(); - - /// Optional check end - - input.seekg(pos,input.beg); - - data_level0_memory_ = (char *) malloc(max_elements * size_data_per_element_); - if (data_level0_memory_ == nullptr) - throw std::runtime_error("Not enough memory: loadIndex failed to allocate level0"); - input.read(data_level0_memory_, cur_element_count * size_data_per_element_); - - size_links_per_element_ = maxM_ * sizeof(tableint) + sizeof(linklistsizeint); - - size_links_level0_ = maxM0_ * sizeof(tableint) + sizeof(linklistsizeint); - 
std::vector(max_elements).swap(link_list_locks_); - - visited_list_pool_ = new VisitedListPool(1, max_elements); - - linkLists_ = (char **) malloc(sizeof(void *) * max_elements); - if (linkLists_ == nullptr) - throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklists"); - element_levels_ = std::vector(max_elements); - revSize_ = 1.0 / mult_; - ef_ = 10; - for (size_t i = 0; i < cur_element_count; i++) { -// label_lookup_[getExternalLabel(i)]=i; - unsigned int linkListSize; - readBinaryPOD(input, linkListSize); - if (linkListSize == 0) { - element_levels_[i] = 0; - linkLists_[i] = nullptr; - } else { - element_levels_[i] = linkListSize / size_links_per_element_; - linkLists_[i] = (char *) malloc(linkListSize); - if (linkLists_[i] == nullptr) - throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklist"); - input.read(linkLists_[i], linkListSize); - } - } - - has_deletions_=false; - - for (size_t i = 0; i < cur_element_count; i++) { - if(isMarkedDeleted(i)) - has_deletions_=true; - } - - input.close(); - return; - } - - /* - template - std::vector getDataByLabel(tableint internal_id, dist_t *pdata) { - // tableint label_c; - // auto search = label_lookup_.find(label); - // if (search == label_lookup_.end() || isMarkedDeleted(search->second)) { - // throw std::runtime_error("Label not found"); - // } - // label_c = search->second; - - char* data_ptrv = getDataByInternalId(pdata, internal_id); - size_t dim = *((size_t *) dist_func_param_); - std::vector data; - data_t* data_ptr = (data_t*) data_ptrv; - for (int i = 0; i < dim; i++) { - data.push_back(*data_ptr); - data_ptr += 1; - } - return data; - } - */ - - static const unsigned char DELETE_MARK = 0x01; - // static const unsigned char REUSE_MARK = 0x10; - /** - * Marks an element with the given label deleted, does NOT really change the current graph. - * @param label - */ - void markDelete(labeltype label) - { - has_deletions_=true; -// auto search = label_lookup_.find(label); -// if (search == label_lookup_.end()) { -// throw std::runtime_error("Label not found"); -// } -// markDeletedInternal(search->second); - markDeletedInternal(label); - } - - /** - * Uses the first 8 bits of the memory for the linked list to store the mark, - * whereas maxM0_ has to be limited to the lower 24 bits, however, still large enough in almost all cases. - * @param internalId - */ - void markDeletedInternal(tableint internalId) { - unsigned char *ll_cur = ((unsigned char *)get_linklist0(internalId))+2; - *ll_cur |= DELETE_MARK; - } - - /** - * Remove the deleted mark of the node. - * @param internalId - */ - void unmarkDeletedInternal(tableint internalId) { - unsigned char *ll_cur = ((unsigned char *)get_linklist0(internalId))+2; - *ll_cur &= ~DELETE_MARK; - } - - /** - * Checks the first 8 bits of the memory to see if the element is marked deleted. 
- * @param internalId - * @return - */ - bool isMarkedDeleted(tableint internalId) const { - unsigned char *ll_cur = ((unsigned char*)get_linklist0(internalId))+2; - return *ll_cur & DELETE_MARK; - } - - unsigned short int getListCount(linklistsizeint * ptr) const { - return *((unsigned short int *)ptr); - } - - void setListCount(linklistsizeint * ptr, unsigned short int size) const { - *((unsigned short int*)(ptr))=*((unsigned short int *)&size); - } - - size_t getCurrentElementCount() { - return cur_element_count; - } - - void addPoint(void *data_point, labeltype label, size_t base, size_t offset) { - addPoint(data_point, label,-1, base, offset); - } - - tableint addPoint(void *data_point, labeltype label, int level, size_t base, size_t offset) { - tableint cur_c = 0; - { - std::unique_lock lock(cur_element_count_guard_); - if (cur_element_count >= max_elements_) { - throw std::runtime_error("The number of elements exceeds the specified limit"); - }; - -// cur_c = cur_element_count; - cur_c = tableint(base + offset); - cur_element_count++; - -// auto search = label_lookup_.find(label); -// if (search != label_lookup_.end()) { -// std::unique_lock lock_el(link_list_locks_[search->second]); -// has_deletions_ = true; -// markDeletedInternal(search->second); -// } -// label_lookup_[label] = cur_c; - } - - std::unique_lock lock_el(link_list_locks_[cur_c]); - int curlevel = getRandomLevel(mult_); - if (level > 0) - curlevel = level; - - element_levels_[cur_c] = curlevel; - - // prepose non-concurrent operation - memset(data_level0_memory_ + cur_c * size_data_per_element_, 0, size_data_per_element_); -// setExternalLabel(cur_c, label); -// memcpy(getDataByInternalId(cur_c), data_point, data_size_); - if (curlevel) { - linkLists_[cur_c] = (char *) malloc(size_links_per_element_ * curlevel + 1); - if (linkLists_[cur_c] == nullptr) - throw std::runtime_error("Not enough memory: addPoint failed to allocate linklist"); - memset(linkLists_[cur_c], 0, size_links_per_element_ * curlevel + 1); - } - - - std::unique_lock templock(global); - int maxlevelcopy = maxlevel_; - if (curlevel <= maxlevelcopy) - templock.unlock(); - tableint currObj = enterpoint_node_; - tableint enterpoint_copy = enterpoint_node_; - - if ((signed)currObj != -1) { - - if (curlevel < maxlevelcopy) { - - dist_t curdist = fstdistfunc_(getDataByInternalId(data_point, (tableint)offset), getDataByInternalId(data_point, currObj), dist_func_param_); - for (int level = maxlevelcopy; level > curlevel; level--) { - bool changed = true; - while (changed) { - changed = false; - unsigned int *data; - std::unique_lock lock(link_list_locks_[currObj]); - data = get_linklist(currObj,level); - int size = getListCount(data); - - tableint *datal = (tableint *) (data + 1); - for (int i = 0; i < size; i++) { - tableint cand = datal[i]; - if (cand < 0 || cand > max_elements_) - throw std::runtime_error("cand error"); - dist_t d = fstdistfunc_(getDataByInternalId(data_point, tableint(offset)), getDataByInternalId(data_point, cand), dist_func_param_); - if (d < curdist) { - curdist = d; - currObj = cand; - changed = true; - } - } - } - } - } - - bool epDeleted = isMarkedDeleted(enterpoint_copy); - for (int level = std::min(curlevel, maxlevelcopy); level >= 0; level--) { - if (level > maxlevelcopy || level < 0) // possible? 
- throw std::runtime_error("Level error"); - - std::priority_queue, std::vector>, CompareByFirst> top_candidates = searchBaseLayer( - currObj, getDataByInternalId(data_point, (tableint)offset), level, data_point); - if (epDeleted) { - top_candidates.emplace(fstdistfunc_(getDataByInternalId(data_point, (tableint)offset), getDataByInternalId(data_point, enterpoint_copy), dist_func_param_), enterpoint_copy); - if (top_candidates.size() > ef_construction_) - top_candidates.pop(); - } - currObj = top_candidates.top().second; - - mutuallyConnectNewElement(getDataByInternalId(data_point, (tableint)offset), cur_c, top_candidates, level, data_point); - } - } else { - // Do nothing for the first element - enterpoint_node_ = 0; - maxlevel_ = curlevel; - } - - //Releasing lock for the maximum level - if (curlevel > maxlevelcopy) { - enterpoint_node_ = cur_c; - maxlevel_ = curlevel; - } - return cur_c; - }; - - std::priority_queue> - searchKnn_NM(const void *query_data, size_t k, const faiss::BitsetView bitset, dist_t *pdata) const { - std::priority_queue> result; - if (cur_element_count == 0) return result; - - tableint currObj = enterpoint_node_; - dist_t curdist; - faiss::SQDistanceComputer *sqdc = nullptr; - if (is_sq8_) { - if (metric_type_ == 0) { // L2 - sqdc = new DCClassL2(sq_->d, sq_->trained); - } else if (metric_type_ == 1) { // IP - sqdc = new DCClassIP(sq_->d, sq_->trained); - } else { - throw std::runtime_error("unsupported metric_type, it must be 0(L2) or 1(IP)!"); - } - sqdc->code_size = sq_->code_size; - sqdc->set_query((const float*)query_data); - sqdc->codes = (uint8_t*)pdata; - curdist = (*sqdc)(currObj); - } else { - curdist = fstdistfunc_(query_data, getDataByInternalId(pdata, enterpoint_node_), dist_func_param_); - } - - for (int level = maxlevel_; level > 0; level--) { - bool changed = true; - while (changed) { - changed = false; - unsigned int *data; - - data = (unsigned int *) get_linklist(currObj, level); - int size = getListCount(data); - tableint *datal = (tableint *) (data + 1); - for (int i = 0; i < size; i++) { - tableint cand = datal[i]; - if (cand < 0 || cand > max_elements_) - throw std::runtime_error("cand error"); - dist_t d; - if (is_sq8_) { - d = (*sqdc)(cand); - } else { - d = fstdistfunc_(query_data, getDataByInternalId(pdata, cand), dist_func_param_); - } - - if (d < curdist) { - curdist = d; - currObj = cand; - changed = true; - } - } - } - } - - std::priority_queue, std::vector>, CompareByFirst> top_candidates; - if (!bitset.empty()) { - std::priority_queue, std::vector>, CompareByFirst> - top_candidates1 = searchBaseLayerST(currObj, query_data, std::max(ef_, k), bitset, pdata); - top_candidates.swap(top_candidates1); - } - else{ - std::priority_queue, std::vector>, CompareByFirst> - top_candidates1 = searchBaseLayerST(currObj, query_data, std::max(ef_, k), bitset, pdata); - top_candidates.swap(top_candidates1); - } - while (top_candidates.size() > k) { - top_candidates.pop(); - } - while (top_candidates.size() > 0) { - std::pair rez = top_candidates.top(); -// result.push(std::pair(rez.first, getExternalLabel(rez.second))); - result.push(std::pair(rez.first, rez.second)); - top_candidates.pop(); - } - if (is_sq8_) delete sqdc; - return result; - }; - - template - std::vector> - searchKnn_NM(const void* query_data, size_t k, Comp comp, const faiss::BitsetView bitset, dist_t *pdata) { - std::vector> result; - if (cur_element_count == 0) return result; - - auto ret = searchKnn_NM(query_data, k, bitset, pdata); - - while (!ret.empty()) { - 
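searchKnn_NM above collects up to max(ef_, k) base-layer candidates in a max-heap keyed by distance, pops the farthest entries until k remain, and then drains what is left into the result queue, so the caller receives the k nearest with the farthest of them on top. A small concrete illustration of that trimming step (the numbers are made up, and std::pair's default ordering stands in for CompareByFirst):

#include <cstdint>
#include <cstdio>
#include <queue>
#include <utility>

int main() {
    using dist_t = float;
    using labeltype = int64_t;

    const size_t ef = 6, k = 3;
    std::priority_queue<std::pair<dist_t, labeltype>> candidates;  // top() is the farthest candidate
    const float dists[ef] = {0.9f, 0.1f, 0.5f, 0.3f, 0.7f, 0.2f};
    for (size_t i = 0; i < ef; ++i) {
        candidates.emplace(dists[i], static_cast<labeltype>(i));
    }

    while (candidates.size() > k) {
        candidates.pop();                      // drop the farthest; 0.1, 0.2 and 0.3 survive
    }
    while (!candidates.empty()) {
        std::printf("id=%lld dist=%.1f\n",     // printed farthest-first: 0.3, 0.2, 0.1
                    static_cast<long long>(candidates.top().second), candidates.top().first);
        candidates.pop();
    }
    return 0;
}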
                result.push_back(ret.top());
-                ret.pop();
-            }
-
-            std::sort(result.begin(), result.end(), comp);
-
-            return result;
-        }
-
-        int64_t cal_size() {
-            int64_t ret = 0;
-            ret += sizeof(*this);
-            ret += sizeof(*space);
-            ret += visited_list_pool_->GetSize();
-            ret += link_list_locks_.size() * sizeof(std::mutex);
-            ret += element_levels_.size() * sizeof(int);
-            ret += max_elements_ * size_data_per_element_;
-            ret += max_elements_ * sizeof(void*);
-            for (auto i = 0; i < max_elements_; ++ i) {
-                ret += linkLists_[i] ? size_links_per_element_ * element_levels_[i] : 0;
-            }
-            return ret;
-        }
-    };
-
-}
diff --git a/internal/core/src/index/thirdparty/hnswlib/hnswlib.h b/internal/core/src/index/thirdparty/hnswlib/hnswlib.h
index edb7d7b0d4..b508005e86 100644
--- a/internal/core/src/index/thirdparty/hnswlib/hnswlib.h
+++ b/internal/core/src/index/thirdparty/hnswlib/hnswlib.h
@@ -92,10 +92,10 @@ namespace hnswlib {
     class AlgorithmInterface {
     public:
         virtual void addPoint(const void *datapoint, labeltype label)=0;
-        virtual std::priority_queue<std::pair<dist_t, labeltype>> searchKnn(const void *, size_t, const faiss::BitsetView bitset, hnswlib::StatisticsInfo &stats) const = 0;
-        template <typename Comp>
-        std::vector<std::pair<dist_t, labeltype>> searchKnn(const void*, size_t, Comp, const faiss::BitsetView bitset, hnswlib::StatisticsInfo &stats) {
-        }
+
+        virtual std::priority_queue<std::pair<dist_t, labeltype>>
+        searchKnn(const void *, size_t, const faiss::BitsetView bitset, hnswlib::StatisticsInfo &stats) const = 0;
+
         virtual void saveIndex(const std::string &location)=0;
         virtual ~AlgorithmInterface(){
         }
diff --git a/internal/core/src/index/thirdparty/hnswlib/hnswlib_nm.h b/internal/core/src/index/thirdparty/hnswlib/hnswlib_nm.h
deleted file mode 100644
index 5aaad55f95..0000000000
--- a/internal/core/src/index/thirdparty/hnswlib/hnswlib_nm.h
+++ /dev/null
@@ -1,99 +0,0 @@
-#pragma once
-#ifndef NO_MANUAL_VECTORIZATION
-#ifdef __SSE__
-#define USE_SSE
-#ifdef __AVX__
-#define USE_AVX
-#endif
-#endif
-#endif
-
-#if defined(USE_AVX) || defined(USE_SSE)
-#ifdef _MSC_VER
-#include <intrin.h>
-#include <stdexcept>
-#else
-#include <x86intrin.h>
-#endif
-
-#if defined(__GNUC__)
-#define PORTABLE_ALIGN32 __attribute__((aligned(32)))
-#else
-#define PORTABLE_ALIGN32 __declspec(align(32))
-#endif
-#endif
-
-#include <queue>
-#include <vector>
-#include <iostream>
-
-#include
-#include
-#include
-
-namespace hnswlib_nm {
-    typedef int64_t labeltype;
-
-    template <typename T>
-    class pairGreater {
-    public:
-        bool operator()(const T& p1, const T& p2) {
-            return p1.first > p2.first;
-        }
-    };
-
-    template<typename T>
-    static void writeBinaryPOD(std::ostream &out, const T &podRef) {
-        out.write((char *) &podRef, sizeof(T));
-    }
-
-    template<typename T>
-    static void readBinaryPOD(std::istream &in, T &podRef) {
-        in.read((char *) &podRef, sizeof(T));
-    }
-
-    template<typename W, typename T>
-    static void writeBinaryPOD(W &out, const T &podRef) {
-        out.write((char *) &podRef, sizeof(T));
-    }
-
-    template<typename R, typename T>
-    static void readBinaryPOD(R &in, T &podRef) {
-        in.read((char *) &podRef, sizeof(T));
-    }
-
-    template<typename MTYPE>
-    using DISTFUNC = MTYPE(*)(const void *, const void *, const void *);
-
-
-    template<typename MTYPE>
-    class SpaceInterface {
-    public:
-        //virtual void search(void *);
-        virtual size_t get_data_size() = 0;
-
-        virtual DISTFUNC<MTYPE> get_dist_func() = 0;
-
-        virtual void *get_dist_func_param() = 0;
-
-        virtual ~SpaceInterface() {}
-    };
-
-    template<typename dist_t>
-    class AlgorithmInterface {
-    public:
-        virtual void addPoint(void *datapoint, labeltype label, size_t base, size_t offset)=0;
-        virtual std::priority_queue<std::pair<dist_t, labeltype>> searchKnn_NM(const void *, size_t, const faiss::BitsetView bitset, dist_t *pdata) const = 0;
-        template <typename Comp>
-        std::vector<std::pair<dist_t, labeltype>> searchKnn_NM(const void*, size_t, Comp, const faiss::BitsetView bitset, dist_t *pdata) {
-        }
-        virtual void saveIndex(const std::string &location)=0;
-        virtual ~AlgorithmInterface(){
-        }
-    };
-}
-
-#include "space_l2.h"
-#include "space_ip.h"
-#include "bruteforce.h"
-#include "hnswalg_nm.h"
\ No newline at end of file
diff --git a/internal/core/src/index/unittest/test_hnsw.cpp b/internal/core/src/index/unittest/test_hnsw.cpp
index 99bd68610c..4fb14f8dbb 100644
--- a/internal/core/src/index/unittest/test_hnsw.cpp
+++ b/internal/core/src/index/unittest/test_hnsw.cpp
@@ -80,6 +80,20 @@ TEST_P(HNSWTest, HNSW_basic) {
     auto result = index_->Query(query_dataset, conf, nullptr);
     AssertAnns(result, nq, k);
     ReleaseQueryResult(result);
+
+    // case: k > nb
+    const int64_t new_rows = 6;
+    base_dataset->Set(milvus::knowhere::meta::ROWS, new_rows);
+    index_->Train(base_dataset, conf);
+    index_->AddWithoutIds(base_dataset, conf);
+    auto result2 = index_->Query(query_dataset, conf, nullptr);
+    auto res_ids = result2->Get<int64_t*>(milvus::knowhere::meta::IDS);
+    for (int64_t i = 0; i < nq; i++) {
+        for (int64_t j = new_rows; j < k; j++) {
+            ASSERT_EQ(res_ids[i * k + j], -1);
+        }
+    }
+    ReleaseQueryResult(result2);
 }
 
 TEST_P(HNSWTest, HNSW_delete) {
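The new test case above shrinks the dataset to new_rows = 6 vectors and queries with a larger topk, expecting every id slot past the real neighbours to be -1. A toy walk over the flat nq x k id buffer that those assertions check (the values are invented; only the -1 padding convention comes from the test):

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
    const int64_t nq = 2, k = 10, new_rows = 6;

    // Row-major buffer of nq * k ids; slots with no real neighbour keep the sentinel -1.
    std::vector<int64_t> ids(nq * k, -1);
    for (int64_t i = 0; i < nq; ++i) {
        for (int64_t j = 0; j < new_rows; ++j) {
            ids[i * k + j] = j;                // placeholder hits, just for illustration
        }
    }

    // Same check as the ASSERT_EQ loop in the test: everything past new_rows is padded with -1.
    for (int64_t i = 0; i < nq; ++i) {
        for (int64_t j = new_rows; j < k; ++j) {
            assert(ids[i * k + j] == -1);
        }
    }
    return 0;
}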