diff --git a/core/src/db/engine/ExecutionEngine.h b/core/src/db/engine/ExecutionEngine.h index d2a0886131..cbd98bcd89 100644 --- a/core/src/db/engine/ExecutionEngine.h +++ b/core/src/db/engine/ExecutionEngine.h @@ -17,12 +17,12 @@ #pragma once -#include "utils/Status.h" - #include #include #include +#include "utils/Status.h" + namespace milvus { namespace engine { @@ -39,7 +39,8 @@ enum class EngineType { SPTAG_BKT, FAISS_BIN_IDMAP, FAISS_BIN_IVFFLAT, - MAX_VALUE = FAISS_BIN_IVFFLAT, + HNSW, + MAX_VALUE = HNSW, }; enum class MetricType { diff --git a/core/src/db/engine/ExecutionEngineImpl.cpp b/core/src/db/engine/ExecutionEngineImpl.cpp index 3869bfd790..dcc5756e71 100644 --- a/core/src/db/engine/ExecutionEngineImpl.cpp +++ b/core/src/db/engine/ExecutionEngineImpl.cpp @@ -16,6 +16,11 @@ // under the License. #include "db/engine/ExecutionEngineImpl.h" + +#include +#include +#include + #include "cache/CpuCacheMgr.h" #include "cache/GpuCacheMgr.h" #include "knowhere/common/Config.h" @@ -33,10 +38,6 @@ #include "wrapper/VecImpl.h" #include "wrapper/VecIndex.h" -#include -#include -#include - //#define ON_SEARCH namespace milvus { namespace engine { @@ -196,6 +197,10 @@ ExecutionEngineImpl::CreatetVecIndex(EngineType type) { index = GetVecIndexFactory(IndexType::SPTAG_BKT_RNT_CPU); break; } + case EngineType::HNSW: { + index = GetVecIndexFactory(IndexType::HNSW); + break; + } case EngineType::FAISS_BIN_IDMAP: { index = GetVecIndexFactory(IndexType::FAISS_BIN_IDMAP); break; diff --git a/core/src/index/knowhere/CMakeLists.txt b/core/src/index/knowhere/CMakeLists.txt index 7d781a1816..646c0fddf0 100644 --- a/core/src/index/knowhere/CMakeLists.txt +++ b/core/src/index/knowhere/CMakeLists.txt @@ -37,6 +37,7 @@ set(index_srcs knowhere/index/vector_index/IndexBinaryIDMAP.cpp knowhere/index/vector_index/helpers/SPTAGParameterMgr.cpp knowhere/index/vector_index/IndexNSG.cpp + knowhere/index/vector_index/IndexHNSW.cpp knowhere/index/vector_index/nsg/NSG.cpp 
knowhere/index/vector_index/nsg/NSGIO.cpp knowhere/index/vector_index/nsg/NSGHelper.cpp diff --git a/core/src/index/knowhere/knowhere/index/vector_index/IndexHNSW.cpp b/core/src/index/knowhere/knowhere/index/vector_index/IndexHNSW.cpp new file mode 100644 index 0000000000..2ac1efec0f --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_index/IndexHNSW.cpp @@ -0,0 +1,170 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include +#include +#include +#include +#include + +#include "knowhere/adapter/VectorAdapter.h" +#include "knowhere/common/Exception.h" +#include "knowhere/index/vector_index/IndexHNSW.h" +#include "knowhere/index/vector_index/helpers/FaissIO.h" + +#include "hnswlib/hnswalg.h" +#include "hnswlib/space_ip.h" +#include "hnswlib/space_l2.h" + +namespace knowhere { + +BinarySet +IndexHNSW::Serialize() { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize or trained"); + } + + try { + MemoryIOWriter writer; + index_->saveIndex(writer); + auto data = std::make_shared(); + data.reset(writer.data_); + + BinarySet res_set; + res_set.Append("HNSW", data, writer.total); + return res_set; + } catch (std::exception& e) { + KNOWHERE_THROW_MSG(e.what()); + } +} + +void +IndexHNSW::Load(const BinarySet& index_binary) { + try { + auto binary = index_binary.GetByName("HNSW"); + + MemoryIOReader reader; + reader.total = binary->size; + reader.data_ = binary->data.get(); + + hnswlib::SpaceInterface* space; + index_ = std::make_shared>(space); + index_->loadIndex(reader); + } catch (std::exception& e) { + KNOWHERE_THROW_MSG(e.what()); + } +} + +DatasetPtr +IndexHNSW::Search(const DatasetPtr& dataset, const Config& config) { + auto search_cfg = std::dynamic_pointer_cast(config); + if (search_cfg != nullptr) { + search_cfg->CheckValid(); // throw exception + } + + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize or trained"); + } + + GETTENSOR(dataset) + using P = std::pair; + auto compare = [](P v1, P v2) { return v1.second < v2.second; }; + std::vector> ret = index_->searchKnn(p_data, search_cfg->k, compare); + + std::vector dist(ret.size()); + std::vector ids(ret.size()); + std::transform(ret.begin(), ret.end(), std::back_inserter(dist), + [](const std::pair& e) { return e.first; }); + std::transform(ret.begin(), ret.end(), std::back_inserter(ids), + [](const std::pair& e) { return e.second; }); + + auto elems = rows * search_cfg->k; + assert(elems == ret.size()); 
+ size_t p_id_size = sizeof(int64_t) * elems; + size_t p_dist_size = sizeof(float) * elems; + auto p_id = (int64_t*)malloc(p_id_size); + auto p_dist = (float*)malloc(p_dist_size); + memcpy(p_dist, dist.data(), dist.size() * sizeof(float)); + memcpy(p_id, ids.data(), ids.size() * sizeof(int64_t)); + + auto ret_ds = std::make_shared(); + ret_ds->Set(meta::IDS, p_id); + ret_ds->Set(meta::DISTANCE, p_dist); +} + +IndexModelPtr +IndexHNSW::Train(const DatasetPtr& dataset, const Config& config) { + auto build_cfg = std::dynamic_pointer_cast(config); + if (build_cfg != nullptr) { + build_cfg->CheckValid(); // throw exception + } + + GETTENSOR(dataset) + + hnswlib::SpaceInterface* space; + if (config->metric_type == METRICTYPE::L2) { + space = new hnswlib::L2Space(dim); + } else if (config->metric_type == METRICTYPE::IP) { + space = new hnswlib::InnerProductSpace(dim); + } + index_ = std::make_shared>(space, rows, build_cfg->M, build_cfg->ef); + + return nullptr; +} + +void +IndexHNSW::Add(const DatasetPtr& dataset, const Config& config) { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + + std::lock_guard lk(mutex_); + + GETTENSOR(dataset) + auto p_ids = dataset->Get(meta::IDS); + + for (int i = 0; i < 1; i++) { + index_->addPoint((void*)(p_data + dim * i), p_ids[i]); + } +#pragma omp parallel for + for (int i = 1; i < rows; i++) { + index_->addPoint((void*)(p_data + dim * i), p_ids[i]); + } +} + +void +IndexHNSW::Seal() { + // do nothing +} + +int64_t +IndexHNSW::Count() { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + return index_->cur_element_count; +} + +int64_t +IndexHNSW::Dimension() { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + return (*(size_t*)index_->dist_func_param_); +} + +} // namespace knowhere diff --git a/core/src/index/knowhere/knowhere/index/vector_index/IndexHNSW.h b/core/src/index/knowhere/knowhere/index/vector_index/IndexHNSW.h new file mode 100644 index 0000000000..6afc5d7882 --- 
/dev/null
+++ b/core/src/index/knowhere/knowhere/index/vector_index/IndexHNSW.h
@@ -0,0 +1,80 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+#include <mutex>
+
+#include "hnswlib/hnswlib.h"
+
+#include "knowhere/index/vector_index/VectorIndex.h"
+
+namespace knowhere {
+
+// VectorIndex implementation backed by hnswlib::HierarchicalNSW.
+class IndexHNSW : public VectorIndex {
+ public:
+    // Serializes the index into a BinarySet entry named "HNSW".
+    BinarySet
+    Serialize() override;
+
+    // Rebuilds the index from the "HNSW" entry of index_binary.
+    void
+    Load(const BinarySet& index_binary) override;
+
+    // Searches the index for the nearest neighbours of the query in dataset; config supplies k.
+    DatasetPtr
+    Search(const DatasetPtr& dataset, const Config& config) override;
+
+    // void
+    // set_preprocessor(PreprocessorPtr preprocessor) override;
+    //
+    // void
+    // set_index_model(IndexModelPtr model) override;
+    //
+    // PreprocessorPtr
+    // BuildPreprocessor(const DatasetPtr& dataset, const Config& config) override;
+
+    // Allocates the graph for the rows of dataset; inserts nothing and returns nullptr.
+    IndexModelPtr
+    Train(const DatasetPtr& dataset, const Config& config) override;
+
+    // Inserts the rows of dataset under the labels stored at meta::IDS.
+    void
+    Add(const DatasetPtr& dataset, const Config& config) override;
+
+    // No-op for this index type.
+    void
+    Seal() override;
+
+    // Number of elements currently stored.
+    int64_t
+    Count() override;
+
+    // Vector dimension, read from the index's distance-function parameter.
+    int64_t
+    Dimension() override;
+
+ private:
+    // Held by Add() while a batch is inserted.
+    std::mutex mutex_;
+    // Null until Train() or Load() has created the graph.
+    std::shared_ptr<hnswlib::HierarchicalNSW<float>> index_;
+};
+
+}  // namespace knowhere
diff --git a/core/src/index/knowhere/knowhere/index/vector_index/helpers/FaissIO.h
b/core/src/index/knowhere/knowhere/index/vector_index/helpers/FaissIO.h
index a7f8f349e1..08ddc3ebe9 100644
--- a/core/src/index/knowhere/knowhere/index/vector_index/helpers/FaissIO.h
+++ b/core/src/index/knowhere/knowhere/index/vector_index/helpers/FaissIO.h
@@ -28,6 +28,14 @@ struct MemoryIOWriter : public faiss::IOWriter {
 
     size_t
     operator()(const void* ptr, size_t size, size_t nitems) override;
+
+    // Typed-pointer convenience wrapper: forwards to operator() and returns its result.
+    // (The original fell off the end of a size_t function, which is undefined behaviour.)
+    template <typename T>
+    size_t
+    write(T* ptr, size_t size, size_t nitems = 1) {
+        return operator()((const void*)ptr, size, nitems);
+    }
 };
 
 struct MemoryIOReader : public faiss::IOReader {
@@ -37,6 +45,14 @@ struct MemoryIOReader : public faiss::IOReader {
 
     size_t
     operator()(void* ptr, size_t size, size_t nitems) override;
+
+    // Typed-pointer convenience wrapper: forwards to operator() and returns its result.
+    // (The original fell off the end of a size_t function, which is undefined behaviour.)
+    template <typename T>
+    size_t
+    read(T* ptr, size_t size, size_t nitems = 1) {
+        return operator()((void*)ptr, size, nitems);
+    }
 };
 
 }  // namespace knowhere
diff --git a/core/src/index/knowhere/knowhere/index/vector_index/helpers/IndexParameter.h b/core/src/index/knowhere/knowhere/index/vector_index/helpers/IndexParameter.h
index 01ac9930fa..b9f5dde7c4 100644
--- a/core/src/index/knowhere/knowhere/index/vector_index/helpers/IndexParameter.h
+++ b/core/src/index/knowhere/knowhere/index/vector_index/helpers/IndexParameter.h
@@ -68,6 +68,10 @@ constexpr int64_t DEFAULT_BKTNUMBER = INVALID_VALUE;
 constexpr int64_t DEFAULT_BKTKMEANSK = INVALID_VALUE;
 constexpr int64_t DEFAULT_BKTLEAFSIZE = INVALID_VALUE;
 
+// HNSW Config
+constexpr int64_t DEFAULT_M = INVALID_VALUE;
+constexpr int64_t DEFAULT_EF = INVALID_VALUE;
+
 struct IVFCfg : public Cfg {
     int64_t nlist = DEFAULT_NLIST;
     int64_t nprobe = DEFAULT_NPROBE;
@@ -242,4 +246,13 @@ struct BinIDMAPCfg : public Cfg {
     }
 };
 
+// Build parameters read by IndexHNSW::Train.
+struct HNSWCfg : public Cfg {
+    int64_t M = DEFAULT_M;    // passed to hnswlib::HierarchicalNSW as M
+    int64_t ef = DEFAULT_EF;  // passed to hnswlib::HierarchicalNSW as ef_construction
+
+    HNSWCfg() = default;
+};
+using HNSWConfig = std::shared_ptr<HNSWCfg>;
+
 }  // namespace knowhere
diff --git a/core/src/index/thirdparty/hnswlib/bruteforce.h b/core/src/index/thirdparty/hnswlib/bruteforce.h
new file mode 100644
index
0000000000..5b1bd655ac
--- /dev/null
+++ b/core/src/index/thirdparty/hnswlib/bruteforce.h
@@ -0,0 +1,200 @@
+#pragma once
+#include <unordered_map>
+#include <fstream>
+#include <mutex>
+#include <algorithm>
+#include <cstdlib>
+#include <cstring>
+#include <queue>
+#include <stdexcept>
+#include <string>
+#include <utility>
+#include <vector>
+
+namespace hnswlib {
+    // Exhaustive (linear-scan) nearest-neighbour index.
+    // Elements live in one malloc'ed array; each slot holds the vector bytes followed by its label.
+    template<typename dist_t>
+    class BruteforceSearch : public AlgorithmInterface<dist_t> {
+    public:
+        // Leaves the index without storage; loadIndex() has to be called before it can be used.
+        BruteforceSearch(SpaceInterface<dist_t> *s) {
+        }
+
+        // Builds the index from the file at `location`.
+        BruteforceSearch(SpaceInterface<dist_t> *s, const std::string &location) {
+            loadIndex(location, s);
+        }
+
+        // Builds an empty index with room for `maxElements` entries of the space's data size.
+        BruteforceSearch(SpaceInterface<dist_t> *s, size_t maxElements) {
+            maxelements_ = maxElements;
+            data_size_ = s->get_data_size();
+            fstdistfunc_ = s->get_dist_func();
+            dist_func_param_ = s->get_dist_func_param();
+            size_per_element_ = data_size_ + sizeof(labeltype);
+            data_ = (char *) malloc(maxElements * size_per_element_);
+            if (data_ == nullptr && maxElements != 0)
+                throw std::runtime_error("Not enough memory: BruteforceSearch failed to allocate data");
+            cur_element_count = 0;
+        }
+
+        ~BruteforceSearch() {
+            free(data_);
+        }
+
+        // Default initializers keep the destructor safe when no storage was ever allocated.
+        char *data_ = nullptr;
+        size_t maxelements_ = 0;
+        size_t cur_element_count = 0;
+        size_t size_per_element_ = 0;
+
+        size_t data_size_ = 0;
+        DISTFUNC<dist_t> fstdistfunc_ = nullptr;
+        void *dist_func_param_ = nullptr;
+        std::mutex index_lock;
+
+        // External label -> slot index inside data_.
+        std::unordered_map<labeltype, size_t> dict_external_to_internal;
+
+        // Stores `datapoint` under `label`; a label that already exists is overwritten in place.
+        // Throws std::runtime_error when a new label would exceed maxelements_.
+        void addPoint(const void *datapoint, labeltype label) {
+            size_t idx;
+            {
+                std::unique_lock<std::mutex> lock(index_lock);
+
+                auto search = dict_external_to_internal.find(label);
+                if (search != dict_external_to_internal.end()) {
+                    idx = search->second;
+                } else {
+                    if (cur_element_count >= maxelements_) {
+                        throw std::runtime_error("The number of elements exceeds the specified limit\n");
+                    }
+                    idx = cur_element_count;
+                    dict_external_to_internal[label] = idx;
+                    cur_element_count++;
+                }
+            }
+            memcpy(data_ + size_per_element_ * idx + data_size_, &label, sizeof(labeltype));
+            memcpy(data_ + size_per_element_ * idx, datapoint, data_size_);
+        }
+
+        // Removes `cur_external` by moving the last stored element into its slot.
+        // Labels that are not stored are ignored.
+        void removePoint(labeltype cur_external) {
+            auto found = dict_external_to_internal.find(cur_external);
+            if (found == dict_external_to_internal.end()) {
+                return;
+            }
+            const size_t cur_c = found->second;
+            dict_external_to_internal.erase(found);
+
+            const size_t last = cur_element_count - 1;
+            // When the removed element is itself the last one there is nothing to move, and
+            // remapping would re-insert the label that was just erased.
+            if (cur_c != last) {
+                labeltype label = *((labeltype *) (data_ + size_per_element_ * last + data_size_));
+                dict_external_to_internal[label] = cur_c;
+                memcpy(data_ + size_per_element_ * cur_c,
+                       data_ + size_per_element_ * last,
+                       data_size_ + sizeof(labeltype));
+            }
+            cur_element_count--;
+        }
+
+        // Returns up to k nearest stored elements as a max-heap on distance (farthest on top).
+        std::priority_queue<std::pair<dist_t, labeltype >>
+        searchKnn(const void *query_data, size_t k) const {
+            std::priority_queue<std::pair<dist_t, labeltype >> topResults;
+            if (cur_element_count == 0 || k == 0) return topResults;
+            // Seed with at most the stored elements, so a k larger than the index never reads unused slots.
+            const size_t seed = std::min(k, cur_element_count);
+            for (size_t i = 0; i < seed; i++) {
+                dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_);
+                const labeltype label = *((labeltype *) (data_ + size_per_element_ * i + data_size_));
+                topResults.push(std::pair<dist_t, labeltype>(dist, label));
+            }
+            dist_t lastdist = topResults.top().first;
+            for (size_t i = seed; i < cur_element_count; i++) {
+                dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_);
+                if (dist <= lastdist) {
+                    const labeltype label = *((labeltype *) (data_ + size_per_element_ * i + data_size_));
+                    topResults.push(std::pair<dist_t, labeltype>(dist, label));
+                    if (topResults.size() > k)
+                        topResults.pop();
+                    lastdist = topResults.top().first;
+                }
+            }
+            return topResults;
+        }
+
+        // Same search as the two-argument overload, returned as a vector ordered by `comp`.
+        template <typename Comp>
+        std::vector<std::pair<dist_t, labeltype>>
+        searchKnn(const void* query_data, size_t k, Comp comp) {
+            std::vector<std::pair<dist_t, labeltype>> result;
+            if (cur_element_count == 0) return result;
+
+            auto ret = searchKnn(query_data, k);
+
+            while (!ret.empty()) {
+                result.push_back(ret.top());
+                ret.pop();
+            }
+
+            std::sort(result.begin(), result.end(), comp);
+
+            return result;
+        }
+
+        // Writes maxelements_, size_per_element_, cur_element_count and then the whole element array.
+        void saveIndex(const std::string &location) {
+            std::ofstream output(location, std::ios::binary);
+
+            writeBinaryPOD(output, maxelements_);
+            writeBinaryPOD(output, size_per_element_);
+            writeBinaryPOD(output, cur_element_count);
+
+            output.write(data_, maxelements_ * size_per_element_);
+
+            output.close();
+        }
+
+        // Replaces the contents with the index stored at `location`; `s` supplies the distance function.
+        // Throws std::runtime_error if the file cannot be opened, is inconsistent, or memory runs out.
+        void loadIndex(const std::string &location, SpaceInterface<dist_t> *s) {
+            std::ifstream input(location, std::ios::binary);
+            if (!input.is_open())
+                throw std::runtime_error("Cannot open file");
+
+            readBinaryPOD(input, maxelements_);
+            readBinaryPOD(input, size_per_element_);
+            readBinaryPOD(input, cur_element_count);
+            if (cur_element_count > maxelements_)
+                throw std::runtime_error("Index seems to be corrupted or unsupported");
+
+            data_size_ = s->get_data_size();
+            fstdistfunc_ = s->get_dist_func();
+            dist_func_param_ = s->get_dist_func_param();
+            size_per_element_ = data_size_ + sizeof(labeltype);
+
+            free(data_);  // release storage left over from an earlier construction or load
+            data_ = (char *) malloc(maxelements_ * size_per_element_);
+            if (data_ == nullptr && maxelements_ != 0)
+                throw std::runtime_error("Not enough memory: loadIndex failed to allocate data");
+
+            input.read(data_, maxelements_ * size_per_element_);
+
+            input.close();
+
+            // The label lookup is not part of the file, so rebuild it from the stored labels.
+            dict_external_to_internal.clear();
+            for (size_t i = 0; i < cur_element_count; i++) {
+                labeltype label = *((labeltype *) (data_ + size_per_element_ * i + data_size_));
+                dict_external_to_internal[label] = i;
+            }
+        }
+
+    };
+}
diff --git a/core/src/index/thirdparty/hnswlib/hnswalg.h b/core/src/index/thirdparty/hnswlib/hnswalg.h
new file mode 100644
index 0000000000..427ec3e607
--- /dev/null
+++ b/core/src/index/thirdparty/hnswlib/hnswalg.h
@@ -0,0 +1,1160 @@
+#pragma once
+
+#include "visited_list_pool.h"
+#include "hnswlib.h"
+#include <random>
+#include <stdlib.h>
+#include <unordered_set>
+#include <list>
+
+#include "knowhere/index/vector_index/helpers/FaissIO.h"
+
+namespace hnswlib {
+    typedef unsigned int tableint;
+    typedef unsigned int linklistsizeint;
+
+    template<typename dist_t>
+    class HierarchicalNSW : public AlgorithmInterface<dist_t> {
+    public:
+
+        HierarchicalNSW(SpaceInterface<dist_t> *s) {
+        }
+
+        HierarchicalNSW(SpaceInterface<dist_t> *s, const std::string &location, bool nmslib = false, size_t max_elements=0) {
+            loadIndex(location, s, max_elements);
+        }
+
+        HierarchicalNSW(SpaceInterface<dist_t> *s, size_t max_elements, size_t M = 16, size_t ef_construction = 200, size_t random_seed = 100) :
+                link_list_locks_(max_elements), element_levels_(max_elements) {
+            // linxj
+            space = s;
+            if (auto x = dynamic_cast<L2Space*>(s)) {
+                metric_type_ = 0;
+            } else if (auto x = dynamic_cast<InnerProductSpace*>(s)) {
+                metric_type_ = 1;
+            } else {
+                metric_type_ = 100;
+            }
+
+            max_elements_ = max_elements;
+
+            has_deletions_=false;
+            data_size_ = s->get_data_size();
+            fstdistfunc_ = s->get_dist_func();
+            dist_func_param_ = s->get_dist_func_param();
+            M_ = M;
+            maxM_ = M_;
+            maxM0_ = M_ * 2;
+            ef_construction_ = std::max(ef_construction,M_);
+            ef_ = 10;
+
+            level_generator_.seed(random_seed);
+
+
size_links_level0_ = maxM0_ * sizeof(tableint) + sizeof(linklistsizeint); + size_data_per_element_ = size_links_level0_ + data_size_ + sizeof(labeltype); + offsetData_ = size_links_level0_; + label_offset_ = size_links_level0_ + data_size_; + offsetLevel0_ = 0; + + data_level0_memory_ = (char *) malloc(max_elements_ * size_data_per_element_); + if (data_level0_memory_ == nullptr) + throw std::runtime_error("Not enough memory"); + + cur_element_count = 0; + + visited_list_pool_ = new VisitedListPool(1, max_elements); + + + + //initializations for special treatment of the first node + enterpoint_node_ = -1; + maxlevel_ = -1; + + linkLists_ = (char **) malloc(sizeof(void *) * max_elements_); + if (linkLists_ == nullptr) + throw std::runtime_error("Not enough memory: HierarchicalNSW failed to allocate linklists"); + size_links_per_element_ = maxM_ * sizeof(tableint) + sizeof(linklistsizeint); + mult_ = 1 / log(1.0 * M_); + revSize_ = 1.0 / mult_; + } + + struct CompareByFirst { + constexpr bool operator()(std::pair const &a, + std::pair const &b) const noexcept { + return a.first < b.first; + } + }; + + ~HierarchicalNSW() { + + free(data_level0_memory_); + for (tableint i = 0; i < cur_element_count; i++) { + if (element_levels_[i] > 0) + free(linkLists_[i]); + } + free(linkLists_); + delete visited_list_pool_; + + // linxj: delete + delete space; + } + + // linxj: use for free resource + SpaceInterface *space; + size_t metric_type_; // 0:l2, 1:ip + + size_t max_elements_; + size_t cur_element_count; + size_t size_data_per_element_; + size_t size_links_per_element_; + + size_t M_; + size_t maxM_; + size_t maxM0_; + size_t ef_construction_; + + double mult_, revSize_; + int maxlevel_; + + + VisitedListPool *visited_list_pool_; + std::mutex cur_element_count_guard_; + + std::vector link_list_locks_; + tableint enterpoint_node_; + + + size_t size_links_level0_; + size_t offsetData_, offsetLevel0_; + + + char *data_level0_memory_; + char **linkLists_; + std::vector 
element_levels_; + + size_t data_size_; + + bool has_deletions_; + + + size_t label_offset_; + DISTFUNC fstdistfunc_; + void *dist_func_param_; + std::unordered_map label_lookup_; + + std::default_random_engine level_generator_; + + inline labeltype getExternalLabel(tableint internal_id) const { + labeltype return_label; + memcpy(&return_label,(data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_), sizeof(labeltype)); + return return_label; + } + + inline void setExternalLabel(tableint internal_id, labeltype label) const { + memcpy((data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_), &label, sizeof(labeltype)); + } + + inline labeltype *getExternalLabeLp(tableint internal_id) const { + return (labeltype *) (data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_); + } + + inline char *getDataByInternalId(tableint internal_id) const { + return (data_level0_memory_ + internal_id * size_data_per_element_ + offsetData_); + } + + int getRandomLevel(double reverse_size) { + std::uniform_real_distribution distribution(0.0, 1.0); + double r = -log(distribution(level_generator_)) * reverse_size; + return (int) r; + } + + std::priority_queue, std::vector>, CompareByFirst> + searchBaseLayer(tableint ep_id, const void *data_point, int layer) { + VisitedList *vl = visited_list_pool_->getFreeVisitedList(); + vl_type *visited_array = vl->mass; + vl_type visited_array_tag = vl->curV; + + std::priority_queue, std::vector>, CompareByFirst> top_candidates; + std::priority_queue, std::vector>, CompareByFirst> candidateSet; + + dist_t lowerBound; + if (!isMarkedDeleted(ep_id)) { + dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_); + top_candidates.emplace(dist, ep_id); + lowerBound = dist; + candidateSet.emplace(-dist, ep_id); + } else { + lowerBound = std::numeric_limits::max(); + candidateSet.emplace(-lowerBound, ep_id); + } + visited_array[ep_id] = visited_array_tag; + + while 
(!candidateSet.empty()) { + std::pair curr_el_pair = candidateSet.top(); + if ((-curr_el_pair.first) > lowerBound) { + break; + } + candidateSet.pop(); + + tableint curNodeNum = curr_el_pair.second; + + std::unique_lock lock(link_list_locks_[curNodeNum]); + + int *data;// = (int *)(linkList0_ + curNodeNum * size_links_per_element0_); + if (layer == 0) { + data = (int*)get_linklist0(curNodeNum); + } else { + data = (int*)get_linklist(curNodeNum, layer); +// data = (int *) (linkLists_[curNodeNum] + (layer - 1) * size_links_per_element_); + } + size_t size = getListCount((linklistsizeint*)data); + tableint *datal = (tableint *) (data + 1); +#ifdef USE_SSE + _mm_prefetch((char *) (visited_array + *(data + 1)), _MM_HINT_T0); + _mm_prefetch((char *) (visited_array + *(data + 1) + 64), _MM_HINT_T0); + _mm_prefetch(getDataByInternalId(*datal), _MM_HINT_T0); + _mm_prefetch(getDataByInternalId(*(datal + 1)), _MM_HINT_T0); +#endif + + for (size_t j = 0; j < size; j++) { + tableint candidate_id = *(datal + j); +// if (candidate_id == 0) continue; +#ifdef USE_SSE + _mm_prefetch((char *) (visited_array + *(datal + j + 1)), _MM_HINT_T0); + _mm_prefetch(getDataByInternalId(*(datal + j + 1)), _MM_HINT_T0); +#endif + if (visited_array[candidate_id] == visited_array_tag) continue; + visited_array[candidate_id] = visited_array_tag; + char *currObj1 = (getDataByInternalId(candidate_id)); + + dist_t dist1 = fstdistfunc_(data_point, currObj1, dist_func_param_); + if (top_candidates.size() < ef_construction_ || lowerBound > dist1) { + candidateSet.emplace(-dist1, candidate_id); +#ifdef USE_SSE + _mm_prefetch(getDataByInternalId(candidateSet.top().second), _MM_HINT_T0); +#endif + + if (!isMarkedDeleted(candidate_id)) + top_candidates.emplace(dist1, candidate_id); + + if (top_candidates.size() > ef_construction_) + top_candidates.pop(); + + if (!top_candidates.empty()) + lowerBound = top_candidates.top().first; + } + } + } + visited_list_pool_->releaseVisitedList(vl); + + return 
top_candidates; + } + + template + std::priority_queue, std::vector>, CompareByFirst> + searchBaseLayerST(tableint ep_id, const void *data_point, size_t ef) const { + VisitedList *vl = visited_list_pool_->getFreeVisitedList(); + vl_type *visited_array = vl->mass; + vl_type visited_array_tag = vl->curV; + + std::priority_queue, std::vector>, CompareByFirst> top_candidates; + std::priority_queue, std::vector>, CompareByFirst> candidate_set; + + dist_t lowerBound; + if (!has_deletions || !isMarkedDeleted(ep_id)) { + dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_); + lowerBound = dist; + top_candidates.emplace(dist, ep_id); + candidate_set.emplace(-dist, ep_id); + } else { + lowerBound = std::numeric_limits::max(); + candidate_set.emplace(-lowerBound, ep_id); + } + + visited_array[ep_id] = visited_array_tag; + + while (!candidate_set.empty()) { + + std::pair current_node_pair = candidate_set.top(); + + if ((-current_node_pair.first) > lowerBound) { + break; + } + candidate_set.pop(); + + tableint current_node_id = current_node_pair.second; + int *data = (int *) get_linklist0(current_node_id); + size_t size = getListCount((linklistsizeint*)data); +// bool cur_node_deleted = isMarkedDeleted(current_node_id); + +#ifdef USE_SSE + _mm_prefetch((char *) (visited_array + *(data + 1)), _MM_HINT_T0); + _mm_prefetch((char *) (visited_array + *(data + 1) + 64), _MM_HINT_T0); + _mm_prefetch(data_level0_memory_ + (*(data + 1)) * size_data_per_element_ + offsetData_, _MM_HINT_T0); + _mm_prefetch((char *) (data + 2), _MM_HINT_T0); +#endif + + for (size_t j = 1; j <= size; j++) { + int candidate_id = *(data + j); +// if (candidate_id == 0) continue; +#ifdef USE_SSE + _mm_prefetch((char *) (visited_array + *(data + j + 1)), _MM_HINT_T0); + _mm_prefetch(data_level0_memory_ + (*(data + j + 1)) * size_data_per_element_ + offsetData_, + _MM_HINT_T0);//////////// +#endif + if (!(visited_array[candidate_id] == visited_array_tag)) { + + 
visited_array[candidate_id] = visited_array_tag; + + char *currObj1 = (getDataByInternalId(candidate_id)); + dist_t dist = fstdistfunc_(data_point, currObj1, dist_func_param_); + + if (top_candidates.size() < ef || lowerBound > dist) { + candidate_set.emplace(-dist, candidate_id); +#ifdef USE_SSE + _mm_prefetch(data_level0_memory_ + candidate_set.top().second * size_data_per_element_ + + offsetLevel0_,/////////// + _MM_HINT_T0);//////////////////////// +#endif + + if (!has_deletions || !isMarkedDeleted(candidate_id)) + top_candidates.emplace(dist, candidate_id); + + if (top_candidates.size() > ef) + top_candidates.pop(); + + if (!top_candidates.empty()) + lowerBound = top_candidates.top().first; + } + } + } + } + + visited_list_pool_->releaseVisitedList(vl); + return top_candidates; + } + + void getNeighborsByHeuristic2( + std::priority_queue, std::vector>, CompareByFirst> &top_candidates, + const size_t M) { + if (top_candidates.size() < M) { + return; + } + std::priority_queue> queue_closest; + std::vector> return_list; + while (top_candidates.size() > 0) { + queue_closest.emplace(-top_candidates.top().first, top_candidates.top().second); + top_candidates.pop(); + } + + while (queue_closest.size()) { + if (return_list.size() >= M) + break; + std::pair curent_pair = queue_closest.top(); + dist_t dist_to_query = -curent_pair.first; + queue_closest.pop(); + bool good = true; + for (std::pair second_pair : return_list) { + dist_t curdist = + fstdistfunc_(getDataByInternalId(second_pair.second), + getDataByInternalId(curent_pair.second), + dist_func_param_);; + if (curdist < dist_to_query) { + good = false; + break; + } + } + if (good) { + return_list.push_back(curent_pair); + } + + + } + + for (std::pair curent_pair : return_list) { + + top_candidates.emplace(-curent_pair.first, curent_pair.second); + } + } + + + linklistsizeint *get_linklist0(tableint internal_id) const { + return (linklistsizeint *) (data_level0_memory_ + internal_id * size_data_per_element_ + 
offsetLevel0_); + }; + + linklistsizeint *get_linklist0(tableint internal_id, char *data_level0_memory_) const { + return (linklistsizeint *) (data_level0_memory_ + internal_id * size_data_per_element_ + offsetLevel0_); + }; + + linklistsizeint *get_linklist(tableint internal_id, int level) const { + return (linklistsizeint *) (linkLists_[internal_id] + (level - 1) * size_links_per_element_); + }; + + void mutuallyConnectNewElement(const void *data_point, tableint cur_c, + std::priority_queue, std::vector>, CompareByFirst> top_candidates, + int level) { + + size_t Mcurmax = level ? maxM_ : maxM0_; + getNeighborsByHeuristic2(top_candidates, M_); + if (top_candidates.size() > M_) + throw std::runtime_error("Should be not be more than M_ candidates returned by the heuristic"); + + std::vector selectedNeighbors; + selectedNeighbors.reserve(M_); + while (top_candidates.size() > 0) { + selectedNeighbors.push_back(top_candidates.top().second); + top_candidates.pop(); + } + + { + linklistsizeint *ll_cur; + if (level == 0) + ll_cur = get_linklist0(cur_c); + else + ll_cur = get_linklist(cur_c, level); + + if (*ll_cur) { + throw std::runtime_error("The newly inserted element should have blank link list"); + } + setListCount(ll_cur,selectedNeighbors.size()); + tableint *data = (tableint *) (ll_cur + 1); + + + for (size_t idx = 0; idx < selectedNeighbors.size(); idx++) { + if (data[idx]) + throw std::runtime_error("Possible memory corruption"); + if (level > element_levels_[selectedNeighbors[idx]]) + throw std::runtime_error("Trying to make a link on a non-existent level"); + + data[idx] = selectedNeighbors[idx]; + + } + } + for (size_t idx = 0; idx < selectedNeighbors.size(); idx++) { + + std::unique_lock lock(link_list_locks_[selectedNeighbors[idx]]); + + + linklistsizeint *ll_other; + if (level == 0) + ll_other = get_linklist0(selectedNeighbors[idx]); + else + ll_other = get_linklist(selectedNeighbors[idx], level); + + size_t sz_link_list_other = getListCount(ll_other); + + 
if (sz_link_list_other > Mcurmax) + throw std::runtime_error("Bad value of sz_link_list_other"); + if (selectedNeighbors[idx] == cur_c) + throw std::runtime_error("Trying to connect an element to itself"); + if (level > element_levels_[selectedNeighbors[idx]]) + throw std::runtime_error("Trying to make a link on a non-existent level"); + + tableint *data = (tableint *) (ll_other + 1); + if (sz_link_list_other < Mcurmax) { + data[sz_link_list_other] = cur_c; + setListCount(ll_other, sz_link_list_other + 1); + } else { + // finding the "weakest" element to replace it with the new one + dist_t d_max = fstdistfunc_(getDataByInternalId(cur_c), getDataByInternalId(selectedNeighbors[idx]), + dist_func_param_); + // Heuristic: + std::priority_queue, std::vector>, CompareByFirst> candidates; + candidates.emplace(d_max, cur_c); + + for (size_t j = 0; j < sz_link_list_other; j++) { + candidates.emplace( + fstdistfunc_(getDataByInternalId(data[j]), getDataByInternalId(selectedNeighbors[idx]), + dist_func_param_), data[j]); + } + + getNeighborsByHeuristic2(candidates, Mcurmax); + + int indx = 0; + while (candidates.size() > 0) { + data[indx] = candidates.top().second; + candidates.pop(); + indx++; + } + setListCount(ll_other, indx); + // Nearest K: + /*int indx = -1; + for (int j = 0; j < sz_link_list_other; j++) { + dist_t d = fstdistfunc_(getDataByInternalId(data[j]), getDataByInternalId(rez[idx]), dist_func_param_); + if (d > d_max) { + indx = j; + d_max = d; + } + } + if (indx >= 0) { + data[indx] = cur_c; + } */ + } + + } + } + + std::mutex global; + size_t ef_; + + void setEf(size_t ef) { + ef_ = ef; + } + + + std::priority_queue> searchKnnInternal(void *query_data, int k) { + std::priority_queue> top_candidates; + if (cur_element_count == 0) return top_candidates; + tableint currObj = enterpoint_node_; + dist_t curdist = fstdistfunc_(query_data, getDataByInternalId(enterpoint_node_), dist_func_param_); + + for (size_t level = maxlevel_; level > 0; level--) { + bool 
changed = true; + while (changed) { + changed = false; + int *data; + data = (int *) get_linklist(currObj,level); + int size = getListCount(data); + tableint *datal = (tableint *) (data + 1); + for (int i = 0; i < size; i++) { + tableint cand = datal[i]; + if (cand < 0 || cand > max_elements_) + throw std::runtime_error("cand error"); + dist_t d = fstdistfunc_(query_data, getDataByInternalId(cand), dist_func_param_); + + if (d < curdist) { + curdist = d; + currObj = cand; + changed = true; + } + } + } + } + + if (has_deletions_) { + std::priority_queue> top_candidates1=searchBaseLayerST(currObj, query_data, + ef_); + top_candidates.swap(top_candidates1); + } + else{ + std::priority_queue> top_candidates1=searchBaseLayerST(currObj, query_data, + ef_); + top_candidates.swap(top_candidates1); + } + + while (top_candidates.size() > k) { + top_candidates.pop(); + } + return top_candidates; + }; + + void resizeIndex(size_t new_max_elements){ + if (new_max_elements(new_max_elements).swap(link_list_locks_); + + + // Reallocate base layer + char * data_level0_memory_new = (char *) malloc(new_max_elements * size_data_per_element_); + if (data_level0_memory_new == nullptr) + throw std::runtime_error("Not enough memory: resizeIndex failed to allocate base layer"); + memcpy(data_level0_memory_new, data_level0_memory_,cur_element_count * size_data_per_element_); + free(data_level0_memory_); + data_level0_memory_=data_level0_memory_new; + + // Reallocate all other layers + char ** linkLists_new = (char **) malloc(sizeof(void *) * new_max_elements); + if (linkLists_new == nullptr) + throw std::runtime_error("Not enough memory: resizeIndex failed to allocate other layers"); + memcpy(linkLists_new, linkLists_,cur_element_count * sizeof(void *)); + free(linkLists_); + linkLists_=linkLists_new; + + max_elements_=new_max_elements; + + } + + void saveIndex(knowhere::MemoryIOWriter& output) { + // write l2/ip calculator + writeBinaryPOD(output, metric_type_); + writeBinaryPOD(output, 
data_size_); + writeBinaryPOD(output, *((size_t *) dist_func_param_)); + + writeBinaryPOD(output, offsetLevel0_); + writeBinaryPOD(output, max_elements_); + writeBinaryPOD(output, cur_element_count); + writeBinaryPOD(output, size_data_per_element_); + writeBinaryPOD(output, label_offset_); + writeBinaryPOD(output, offsetData_); + writeBinaryPOD(output, maxlevel_); + writeBinaryPOD(output, enterpoint_node_); + writeBinaryPOD(output, maxM_); + + writeBinaryPOD(output, maxM0_); + writeBinaryPOD(output, M_); + writeBinaryPOD(output, mult_); + writeBinaryPOD(output, ef_construction_); + + output.write(data_level0_memory_, cur_element_count * size_data_per_element_); + + for (size_t i = 0; i < cur_element_count; i++) { + unsigned int linkListSize = element_levels_[i] > 0 ? size_links_per_element_ * element_levels_[i] : 0; + writeBinaryPOD(output, linkListSize); + if (linkListSize) + output.write(linkLists_[i], linkListSize); + } +// output.close(); + } + + void loadIndex(knowhere::MemoryIOReader& input, size_t max_elements_i = 0) { + auto totoal_filesize = input.total; + + // linxj: init with metrictype + size_t dim = 100; + readBinaryPOD(input, metric_type_); + readBinaryPOD(input, data_size_); + readBinaryPOD(input, dim); + if (metric_type_ == 0) { + space = new hnswlib::L2Space(dim); + } else if (metric_type_ == 1) { + space = new hnswlib::InnerProductSpace(dim); + } else { + // throw exception + } + fstdistfunc_ = space->get_dist_func(); + dist_func_param_ = space->get_dist_func_param(); + + readBinaryPOD(input, offsetLevel0_); + readBinaryPOD(input, max_elements_); + readBinaryPOD(input, cur_element_count); + + size_t max_elements=max_elements_i; + if(max_elements < cur_element_count) + max_elements = max_elements_; + max_elements_ = max_elements; + readBinaryPOD(input, size_data_per_element_); + readBinaryPOD(input, label_offset_); + readBinaryPOD(input, offsetData_); + readBinaryPOD(input, maxlevel_); + readBinaryPOD(input, enterpoint_node_); + + 
readBinaryPOD(input, maxM_); + readBinaryPOD(input, maxM0_); + readBinaryPOD(input, M_); + readBinaryPOD(input, mult_); + readBinaryPOD(input, ef_construction_); + + +// data_size_ = s->get_data_size(); +// fstdistfunc_ = s->get_dist_func(); +// dist_func_param_ = s->get_dist_func_param(); + +// auto pos= input.rp; + + +// /// Optional - check if index is ok: +// +// input.seekg(cur_element_count * size_data_per_element_,input.cur); +// for (size_t i = 0; i < cur_element_count; i++) { +// if(input.tellg() < 0 || input.tellg()>=total_filesize){ +// throw std::runtime_error("Index seems to be corrupted or unsupported"); +// } +// +// unsigned int linkListSize; +// readBinaryPOD(input, linkListSize); +// if (linkListSize != 0) { +// input.seekg(linkListSize,input.cur); +// } +// } +// +// // throw exception if it either corrupted or old index +// if(input.tellg()!=total_filesize) +// throw std::runtime_error("Index seems to be corrupted or unsupported"); +// +// input.clear(); +// +// /// Optional check end +// +// input.seekg(pos,input.beg); + + + data_level0_memory_ = (char *) malloc(max_elements * size_data_per_element_); + if (data_level0_memory_ == nullptr) + throw std::runtime_error("Not enough memory: loadIndex failed to allocate level0"); + input.read(data_level0_memory_, cur_element_count * size_data_per_element_); + + + + + size_links_per_element_ = maxM_ * sizeof(tableint) + sizeof(linklistsizeint); + + + size_links_level0_ = maxM0_ * sizeof(tableint) + sizeof(linklistsizeint); + std::vector(max_elements).swap(link_list_locks_); + + + visited_list_pool_ = new VisitedListPool(1, max_elements); + + + linkLists_ = (char **) malloc(sizeof(void *) * max_elements); + if (linkLists_ == nullptr) + throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklists"); + element_levels_ = std::vector(max_elements); + revSize_ = 1.0 / mult_; + ef_ = 10; + for (size_t i = 0; i < cur_element_count; i++) { + label_lookup_[getExternalLabel(i)]=i; + 
unsigned int linkListSize; + readBinaryPOD(input, linkListSize); + if (linkListSize == 0) { + element_levels_[i] = 0; + + linkLists_[i] = nullptr; + } else { + element_levels_[i] = linkListSize / size_links_per_element_; + linkLists_[i] = (char *) malloc(linkListSize); + if (linkLists_[i] == nullptr) + throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklist"); + input.read(linkLists_[i], linkListSize); + } + } + + has_deletions_=false; + + for (size_t i = 0; i < cur_element_count; i++) { + if(isMarkedDeleted(i)) + has_deletions_=true; + } + + return; + } + + void saveIndex(const std::string &location) { + std::ofstream output(location, std::ios::binary); + std::streampos position; + + writeBinaryPOD(output, offsetLevel0_); + writeBinaryPOD(output, max_elements_); + writeBinaryPOD(output, cur_element_count); + writeBinaryPOD(output, size_data_per_element_); + writeBinaryPOD(output, label_offset_); + writeBinaryPOD(output, offsetData_); + writeBinaryPOD(output, maxlevel_); + writeBinaryPOD(output, enterpoint_node_); + writeBinaryPOD(output, maxM_); + + writeBinaryPOD(output, maxM0_); + writeBinaryPOD(output, M_); + writeBinaryPOD(output, mult_); + writeBinaryPOD(output, ef_construction_); + + output.write(data_level0_memory_, cur_element_count * size_data_per_element_); + + for (size_t i = 0; i < cur_element_count; i++) { + unsigned int linkListSize = element_levels_[i] > 0 ? 
size_links_per_element_ * element_levels_[i] : 0; + writeBinaryPOD(output, linkListSize); + if (linkListSize) + output.write(linkLists_[i], linkListSize); + } + output.close(); + } + + void loadIndex(const std::string &location, SpaceInterface *s, size_t max_elements_i=0) { + + + std::ifstream input(location, std::ios::binary); + + if (!input.is_open()) + throw std::runtime_error("Cannot open file"); + + + // get file size: + input.seekg(0,input.end); + std::streampos total_filesize=input.tellg(); + input.seekg(0,input.beg); + + readBinaryPOD(input, offsetLevel0_); + readBinaryPOD(input, max_elements_); + readBinaryPOD(input, cur_element_count); + + size_t max_elements=max_elements_i; + if(max_elements < cur_element_count) + max_elements = max_elements_; + max_elements_ = max_elements; + readBinaryPOD(input, size_data_per_element_); + readBinaryPOD(input, label_offset_); + readBinaryPOD(input, offsetData_); + readBinaryPOD(input, maxlevel_); + readBinaryPOD(input, enterpoint_node_); + + readBinaryPOD(input, maxM_); + readBinaryPOD(input, maxM0_); + readBinaryPOD(input, M_); + readBinaryPOD(input, mult_); + readBinaryPOD(input, ef_construction_); + + + data_size_ = s->get_data_size(); + fstdistfunc_ = s->get_dist_func(); + dist_func_param_ = s->get_dist_func_param(); + + auto pos=input.tellg(); + + + /// Optional - check if index is ok: + + input.seekg(cur_element_count * size_data_per_element_,input.cur); + for (size_t i = 0; i < cur_element_count; i++) { + if(input.tellg() < 0 || input.tellg()>=total_filesize){ + throw std::runtime_error("Index seems to be corrupted or unsupported"); + } + + unsigned int linkListSize; + readBinaryPOD(input, linkListSize); + if (linkListSize != 0) { + input.seekg(linkListSize,input.cur); + } + } + + // throw exception if it either corrupted or old index + if(input.tellg()!=total_filesize) + throw std::runtime_error("Index seems to be corrupted or unsupported"); + + input.clear(); + + /// Optional check end + + 
input.seekg(pos,input.beg); + + + data_level0_memory_ = (char *) malloc(max_elements * size_data_per_element_); + if (data_level0_memory_ == nullptr) + throw std::runtime_error("Not enough memory: loadIndex failed to allocate level0"); + input.read(data_level0_memory_, cur_element_count * size_data_per_element_); + + + + + size_links_per_element_ = maxM_ * sizeof(tableint) + sizeof(linklistsizeint); + + + size_links_level0_ = maxM0_ * sizeof(tableint) + sizeof(linklistsizeint); + std::vector(max_elements).swap(link_list_locks_); + + + visited_list_pool_ = new VisitedListPool(1, max_elements); + + + linkLists_ = (char **) malloc(sizeof(void *) * max_elements); + if (linkLists_ == nullptr) + throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklists"); + element_levels_ = std::vector(max_elements); + revSize_ = 1.0 / mult_; + ef_ = 10; + for (size_t i = 0; i < cur_element_count; i++) { + label_lookup_[getExternalLabel(i)]=i; + unsigned int linkListSize; + readBinaryPOD(input, linkListSize); + if (linkListSize == 0) { + element_levels_[i] = 0; + + linkLists_[i] = nullptr; + } else { + element_levels_[i] = linkListSize / size_links_per_element_; + linkLists_[i] = (char *) malloc(linkListSize); + if (linkLists_[i] == nullptr) + throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklist"); + input.read(linkLists_[i], linkListSize); + } + } + + has_deletions_=false; + + for (size_t i = 0; i < cur_element_count; i++) { + if(isMarkedDeleted(i)) + has_deletions_=true; + } + + input.close(); + + return; + } + + template + std::vector getDataByLabel(labeltype label) + { + tableint label_c; + auto search = label_lookup_.find(label); + if (search == label_lookup_.end() || isMarkedDeleted(search->second)) { + throw std::runtime_error("Label not found"); + } + label_c = search->second; + + char* data_ptrv = getDataByInternalId(label_c); + size_t dim = *((size_t *) dist_func_param_); + std::vector data; + data_t* data_ptr = (data_t*) 
data_ptrv; + for (int i = 0; i < dim; i++) { + data.push_back(*data_ptr); + data_ptr += 1; + } + return data; + } + + static const unsigned char DELETE_MARK = 0x01; +// static const unsigned char REUSE_MARK = 0x10; + /** + * Marks an element with the given label deleted, does NOT really change the current graph. + * @param label + */ + void markDelete(labeltype label) + { + has_deletions_=true; + auto search = label_lookup_.find(label); + if (search == label_lookup_.end()) { + throw std::runtime_error("Label not found"); + } + markDeletedInternal(search->second); + } + + /** + * Uses the first 8 bits of the memory for the linked list to store the mark, + * whereas maxM0_ has to be limited to the lower 24 bits, however, still large enough in almost all cases. + * @param internalId + */ + void markDeletedInternal(tableint internalId) { + unsigned char *ll_cur = ((unsigned char *)get_linklist0(internalId))+2; + *ll_cur |= DELETE_MARK; + } + + /** + * Remove the deleted mark of the node. + * @param internalId + */ + void unmarkDeletedInternal(tableint internalId) { + unsigned char *ll_cur = ((unsigned char *)get_linklist0(internalId))+2; + *ll_cur &= ~DELETE_MARK; + } + + /** + * Checks the first 8 bits of the memory to see if the element is marked deleted. 
+ * @param internalId + * @return + */ + bool isMarkedDeleted(tableint internalId) const { + unsigned char *ll_cur = ((unsigned char*)get_linklist0(internalId))+2; + return *ll_cur & DELETE_MARK; + } + + unsigned short int getListCount(linklistsizeint * ptr) const { + return *((unsigned short int *)ptr); + } + + void setListCount(linklistsizeint * ptr, unsigned short int size) const { + *((unsigned short int*)(ptr))=*((unsigned short int *)&size); + } + + void addPoint(const void *data_point, labeltype label) { + addPoint(data_point, label,-1); + } + + tableint addPoint(const void *data_point, labeltype label, int level) { + tableint cur_c = 0; + { + std::unique_lock lock(cur_element_count_guard_); + if (cur_element_count >= max_elements_) { + throw std::runtime_error("The number of elements exceeds the specified limit"); + }; + + cur_c = cur_element_count; + cur_element_count++; + + auto search = label_lookup_.find(label); + if (search != label_lookup_.end()) { + std::unique_lock lock_el(link_list_locks_[search->second]); + has_deletions_ = true; + markDeletedInternal(search->second); + } + label_lookup_[label] = cur_c; + } + + std::unique_lock lock_el(link_list_locks_[cur_c]); + int curlevel = getRandomLevel(mult_); + if (level > 0) + curlevel = level; + + element_levels_[cur_c] = curlevel; + + + std::unique_lock templock(global); + int maxlevelcopy = maxlevel_; + if (curlevel <= maxlevelcopy) + templock.unlock(); + tableint currObj = enterpoint_node_; + tableint enterpoint_copy = enterpoint_node_; + + + memset(data_level0_memory_ + cur_c * size_data_per_element_ + offsetLevel0_, 0, size_data_per_element_); + + // Initialisation of the data and label + memcpy(getExternalLabeLp(cur_c), &label, sizeof(labeltype)); + memcpy(getDataByInternalId(cur_c), data_point, data_size_); + + + if (curlevel) { + linkLists_[cur_c] = (char *) malloc(size_links_per_element_ * curlevel + 1); + if (linkLists_[cur_c] == nullptr) + throw std::runtime_error("Not enough memory: addPoint 
failed to allocate linklist"); + memset(linkLists_[cur_c], 0, size_links_per_element_ * curlevel + 1); + } + + if ((signed)currObj != -1) { + + if (curlevel < maxlevelcopy) { + + dist_t curdist = fstdistfunc_(data_point, getDataByInternalId(currObj), dist_func_param_); + for (int level = maxlevelcopy; level > curlevel; level--) { + + + bool changed = true; + while (changed) { + changed = false; + unsigned int *data; + std::unique_lock lock(link_list_locks_[currObj]); + data = get_linklist(currObj,level); + int size = getListCount(data); + + tableint *datal = (tableint *) (data + 1); + for (int i = 0; i < size; i++) { + tableint cand = datal[i]; + if (cand < 0 || cand > max_elements_) + throw std::runtime_error("cand error"); + dist_t d = fstdistfunc_(data_point, getDataByInternalId(cand), dist_func_param_); + if (d < curdist) { + curdist = d; + currObj = cand; + changed = true; + } + } + } + } + } + + bool epDeleted = isMarkedDeleted(enterpoint_copy); + for (int level = std::min(curlevel, maxlevelcopy); level >= 0; level--) { + if (level > maxlevelcopy || level < 0) // possible? 
+ throw std::runtime_error("Level error"); + + std::priority_queue, std::vector>, CompareByFirst> top_candidates = searchBaseLayer( + currObj, data_point, level); + if (epDeleted) { + top_candidates.emplace(fstdistfunc_(data_point, getDataByInternalId(enterpoint_copy), dist_func_param_), enterpoint_copy); + if (top_candidates.size() > ef_construction_) + top_candidates.pop(); + } + mutuallyConnectNewElement(data_point, cur_c, top_candidates, level); + + currObj = top_candidates.top().second; + } + + + } else { + // Do nothing for the first element + enterpoint_node_ = 0; + maxlevel_ = curlevel; + + } + + //Releasing lock for the maximum level + if (curlevel > maxlevelcopy) { + enterpoint_node_ = cur_c; + maxlevel_ = curlevel; + } + return cur_c; + }; + + std::priority_queue> + searchKnn(const void *query_data, size_t k) const { + std::priority_queue> result; + if (cur_element_count == 0) return result; + + tableint currObj = enterpoint_node_; + dist_t curdist = fstdistfunc_(query_data, getDataByInternalId(enterpoint_node_), dist_func_param_); + + for (int level = maxlevel_; level > 0; level--) { + bool changed = true; + while (changed) { + changed = false; + unsigned int *data; + + data = (unsigned int *) get_linklist(currObj, level); + int size = getListCount(data); + tableint *datal = (tableint *) (data + 1); + for (int i = 0; i < size; i++) { + tableint cand = datal[i]; + if (cand < 0 || cand > max_elements_) + throw std::runtime_error("cand error"); + dist_t d = fstdistfunc_(query_data, getDataByInternalId(cand), dist_func_param_); + + if (d < curdist) { + curdist = d; + currObj = cand; + changed = true; + } + } + } + } + + std::priority_queue, std::vector>, CompareByFirst> top_candidates; + if (has_deletions_) { + std::priority_queue, std::vector>, CompareByFirst> top_candidates1=searchBaseLayerST( + currObj, query_data, std::max(ef_, k)); + top_candidates.swap(top_candidates1); + } + else{ + std::priority_queue, std::vector>, CompareByFirst> 
top_candidates1=searchBaseLayerST( + currObj, query_data, std::max(ef_, k)); + top_candidates.swap(top_candidates1); + } + while (top_candidates.size() > k) { + top_candidates.pop(); + } + while (top_candidates.size() > 0) { + std::pair rez = top_candidates.top(); + result.push(std::pair(rez.first, getExternalLabel(rez.second))); + top_candidates.pop(); + } + return result; + }; + + template + std::vector> + searchKnn(const void* query_data, size_t k, Comp comp) { + std::vector> result; + if (cur_element_count == 0) return result; + + auto ret = searchKnn(query_data, k); + + while (!ret.empty()) { + result.push_back(ret.top()); + ret.pop(); + } + + // TODO(linxj): uncomment + std::sort(result.begin(), result.end(), comp); + + return result; + } + + }; + +} diff --git a/core/src/index/thirdparty/hnswlib/hnswlib.h b/core/src/index/thirdparty/hnswlib/hnswlib.h new file mode 100644 index 0000000000..6089a30b96 --- /dev/null +++ b/core/src/index/thirdparty/hnswlib/hnswlib.h @@ -0,0 +1,98 @@ +#pragma once +#ifndef NO_MANUAL_VECTORIZATION +#ifdef __SSE__ +#define USE_SSE +#ifdef __AVX__ +#define USE_AVX +#endif +#endif +#endif + +#if defined(USE_AVX) || defined(USE_SSE) +#ifdef _MSC_VER +#include +#include +#else +#include +#endif + +#if defined(__GNUC__) +#define PORTABLE_ALIGN32 __attribute__((aligned(32))) +#else +#define PORTABLE_ALIGN32 __declspec(align(32)) +#endif +#endif + +#include +#include + +#include + +namespace hnswlib { + typedef int64_t labeltype; + + template + class pairGreater { + public: + bool operator()(const T& p1, const T& p2) { + return p1.first > p2.first; + } + }; + + template + static void writeBinaryPOD(std::ostream &out, const T &podRef) { + out.write((char *) &podRef, sizeof(T)); + } + + template + static void readBinaryPOD(std::istream &in, T &podRef) { + in.read((char *) &podRef, sizeof(T)); + } + + template + static void writeBinaryPOD(W &out, const T &podRef) { + out.write((char *) &podRef, sizeof(T)); + } + + template + static void 
readBinaryPOD(R &in, T &podRef) { + in.read((char *) &podRef, sizeof(T)); + } + + template + using DISTFUNC = MTYPE(*)(const void *, const void *, const void *); + + + template + class SpaceInterface { + public: + //virtual void search(void *); + virtual size_t get_data_size() = 0; + + virtual DISTFUNC get_dist_func() = 0; + + virtual void *get_dist_func_param() = 0; + + virtual ~SpaceInterface() {} + }; + + template + class AlgorithmInterface { + public: + virtual void addPoint(const void *datapoint, labeltype label)=0; + virtual std::priority_queue> searchKnn(const void *, size_t) const = 0; + template + std::vector> searchKnn(const void*, size_t, Comp) { + } + virtual void saveIndex(const std::string &location)=0; + virtual ~AlgorithmInterface(){ + } + }; + + +} + +#include "space_l2.h" +#include "space_ip.h" +#include "bruteforce.h" +#include "hnswalg.h" diff --git a/core/src/index/thirdparty/hnswlib/space_ip.h b/core/src/index/thirdparty/hnswlib/space_ip.h new file mode 100644 index 0000000000..e94674730c --- /dev/null +++ b/core/src/index/thirdparty/hnswlib/space_ip.h @@ -0,0 +1,248 @@ +#pragma once +#include "hnswlib.h" + +namespace hnswlib { + + static float + InnerProduct(const void *pVect1, const void *pVect2, const void *qty_ptr) { + size_t qty = *((size_t *) qty_ptr); + float res = 0; + for (unsigned i = 0; i < qty; i++) { + res += ((float *) pVect1)[i] * ((float *) pVect2)[i]; + } + return (1.0f - res); + + } + +#if defined(USE_AVX) + +// Favor using AVX if available. 
+ static float + InnerProductSIMD4Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + float PORTABLE_ALIGN32 TmpRes[8]; + float *pVect1 = (float *) pVect1v; + float *pVect2 = (float *) pVect2v; + size_t qty = *((size_t *) qty_ptr); + + size_t qty16 = qty / 16; + size_t qty4 = qty / 4; + + const float *pEnd1 = pVect1 + 16 * qty16; + const float *pEnd2 = pVect1 + 4 * qty4; + + __m256 sum256 = _mm256_set1_ps(0); + + while (pVect1 < pEnd1) { + //_mm_prefetch((char*)(pVect2 + 16), _MM_HINT_T0); + + __m256 v1 = _mm256_loadu_ps(pVect1); + pVect1 += 8; + __m256 v2 = _mm256_loadu_ps(pVect2); + pVect2 += 8; + sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2)); + + v1 = _mm256_loadu_ps(pVect1); + pVect1 += 8; + v2 = _mm256_loadu_ps(pVect2); + pVect2 += 8; + sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2)); + } + + __m128 v1, v2; + __m128 sum_prod = _mm_add_ps(_mm256_extractf128_ps(sum256, 0), _mm256_extractf128_ps(sum256, 1)); + + while (pVect1 < pEnd2) { + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); + } + + _mm_store_ps(TmpRes, sum_prod); + float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3];; + return 1.0f - sum; +} + +#elif defined(USE_SSE) + + static float + InnerProductSIMD4Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + float PORTABLE_ALIGN32 TmpRes[8]; + float *pVect1 = (float *) pVect1v; + float *pVect2 = (float *) pVect2v; + size_t qty = *((size_t *) qty_ptr); + + size_t qty16 = qty / 16; + size_t qty4 = qty / 4; + + const float *pEnd1 = pVect1 + 16 * qty16; + const float *pEnd2 = pVect1 + 4 * qty4; + + __m128 v1, v2; + __m128 sum_prod = _mm_set1_ps(0); + + while (pVect1 < pEnd1) { + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); + + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + 
sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); + + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); + + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); + } + + while (pVect1 < pEnd2) { + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); + } + + _mm_store_ps(TmpRes, sum_prod); + float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3]; + + return 1.0f - sum; + } + +#endif + +#if defined(USE_AVX) + + static float + InnerProductSIMD16Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + float PORTABLE_ALIGN32 TmpRes[8]; + float *pVect1 = (float *) pVect1v; + float *pVect2 = (float *) pVect2v; + size_t qty = *((size_t *) qty_ptr); + + size_t qty16 = qty / 16; + + + const float *pEnd1 = pVect1 + 16 * qty16; + + __m256 sum256 = _mm256_set1_ps(0); + + while (pVect1 < pEnd1) { + //_mm_prefetch((char*)(pVect2 + 16), _MM_HINT_T0); + + __m256 v1 = _mm256_loadu_ps(pVect1); + pVect1 += 8; + __m256 v2 = _mm256_loadu_ps(pVect2); + pVect2 += 8; + sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2)); + + v1 = _mm256_loadu_ps(pVect1); + pVect1 += 8; + v2 = _mm256_loadu_ps(pVect2); + pVect2 += 8; + sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2)); + } + + _mm256_store_ps(TmpRes, sum256); + float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7]; + + return 1.0f - sum; + } + +#elif defined(USE_SSE) + + static float + InnerProductSIMD16Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + float PORTABLE_ALIGN32 TmpRes[8]; + float *pVect1 = (float *) pVect1v; + float *pVect2 = (float *) pVect2v; + size_t qty = *((size_t *) qty_ptr); + + size_t qty16 = qty / 16; + + const float *pEnd1 = pVect1 + 16 * qty16; + + __m128 v1, 
v2; + __m128 sum_prod = _mm_set1_ps(0); + + while (pVect1 < pEnd1) { + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); + + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); + + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); + + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); + } + _mm_store_ps(TmpRes, sum_prod); + float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3]; + + return 1.0f - sum; + } + +#endif + + class InnerProductSpace : public SpaceInterface { + + DISTFUNC fstdistfunc_; + size_t data_size_; + size_t dim_; + public: + InnerProductSpace(size_t dim) { + fstdistfunc_ = InnerProduct; + #if defined(USE_AVX) || defined(USE_SSE) + if (dim % 4 == 0) + fstdistfunc_ = InnerProductSIMD4Ext; + if (dim % 16 == 0) + fstdistfunc_ = InnerProductSIMD16Ext; +#endif + dim_ = dim; + data_size_ = dim * sizeof(float); + } + + size_t get_data_size() { + return data_size_; + } + + DISTFUNC get_dist_func() { + return fstdistfunc_; + } + + void *get_dist_func_param() { + return &dim_; + } + + ~InnerProductSpace() {} + }; + + +} diff --git a/core/src/index/thirdparty/hnswlib/space_l2.h b/core/src/index/thirdparty/hnswlib/space_l2.h new file mode 100644 index 0000000000..4d3ac69ac4 --- /dev/null +++ b/core/src/index/thirdparty/hnswlib/space_l2.h @@ -0,0 +1,244 @@ +#pragma once +#include "hnswlib.h" + +namespace hnswlib { + + static float + L2Sqr(const void *pVect1, const void *pVect2, const void *qty_ptr) { + //return *((float *)pVect2); + size_t qty = *((size_t *) qty_ptr); + float res = 0; + for (unsigned i = 0; i < qty; i++) { + float t = ((float *) pVect1)[i] - ((float *) pVect2)[i]; + res += t * t; + } + 
return (res); + + } + +#if defined(USE_AVX) + + // Favor using AVX if available. + static float + L2SqrSIMD16Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + float *pVect1 = (float *) pVect1v; + float *pVect2 = (float *) pVect2v; + size_t qty = *((size_t *) qty_ptr); + float PORTABLE_ALIGN32 TmpRes[8]; + size_t qty16 = qty >> 4; + + const float *pEnd1 = pVect1 + (qty16 << 4); + + __m256 diff, v1, v2; + __m256 sum = _mm256_set1_ps(0); + + while (pVect1 < pEnd1) { + v1 = _mm256_loadu_ps(pVect1); + pVect1 += 8; + v2 = _mm256_loadu_ps(pVect2); + pVect2 += 8; + diff = _mm256_sub_ps(v1, v2); + sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff)); + + v1 = _mm256_loadu_ps(pVect1); + pVect1 += 8; + v2 = _mm256_loadu_ps(pVect2); + pVect2 += 8; + diff = _mm256_sub_ps(v1, v2); + sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff)); + } + + _mm256_store_ps(TmpRes, sum); + float res = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7]; + + return (res); +} + +#elif defined(USE_SSE) + + static float + L2SqrSIMD16Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + float *pVect1 = (float *) pVect1v; + float *pVect2 = (float *) pVect2v; + size_t qty = *((size_t *) qty_ptr); + float PORTABLE_ALIGN32 TmpRes[8]; + // size_t qty4 = qty >> 2; + size_t qty16 = qty >> 4; + + const float *pEnd1 = pVect1 + (qty16 << 4); + // const float* pEnd2 = pVect1 + (qty4 << 2); + // const float* pEnd3 = pVect1 + qty; + + __m128 diff, v1, v2; + __m128 sum = _mm_set1_ps(0); + + while (pVect1 < pEnd1) { + //_mm_prefetch((char*)(pVect2 + 16), _MM_HINT_T0); + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + diff = _mm_sub_ps(v1, v2); + sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); + + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + diff = _mm_sub_ps(v1, v2); + sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); + + v1 = _mm_loadu_ps(pVect1); + pVect1 
+= 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + diff = _mm_sub_ps(v1, v2); + sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); + + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + diff = _mm_sub_ps(v1, v2); + sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); + } + _mm_store_ps(TmpRes, sum); + float res = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3]; + + return (res); + } +#endif + + +#ifdef USE_SSE + static float + L2SqrSIMD4Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + float PORTABLE_ALIGN32 TmpRes[8]; + float *pVect1 = (float *) pVect1v; + float *pVect2 = (float *) pVect2v; + size_t qty = *((size_t *) qty_ptr); + + + // size_t qty4 = qty >> 2; + size_t qty16 = qty >> 2; + + const float *pEnd1 = pVect1 + (qty16 << 2); + + __m128 diff, v1, v2; + __m128 sum = _mm_set1_ps(0); + + while (pVect1 < pEnd1) { + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + diff = _mm_sub_ps(v1, v2); + sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); + } + _mm_store_ps(TmpRes, sum); + float res = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3]; + + return (res); + } +#endif + + class L2Space : public SpaceInterface { + + DISTFUNC fstdistfunc_; + size_t data_size_; + size_t dim_; + public: + L2Space(size_t dim) { + fstdistfunc_ = L2Sqr; + #if defined(USE_SSE) || defined(USE_AVX) + if (dim % 4 == 0) + fstdistfunc_ = L2SqrSIMD4Ext; + if (dim % 16 == 0) + fstdistfunc_ = L2SqrSIMD16Ext; + /*else{ + throw runtime_error("Data type not supported!"); + }*/ + #endif + dim_ = dim; + data_size_ = dim * sizeof(float); + } + + size_t get_data_size() { + return data_size_; + } + + DISTFUNC get_dist_func() { + return fstdistfunc_; + } + + void *get_dist_func_param() { + return &dim_; + } + + ~L2Space() {} + }; + + static int + L2SqrI(const void *__restrict pVect1, const void *__restrict pVect2, const void *__restrict qty_ptr) { + + size_t qty = *((size_t *) qty_ptr); + int res = 0; + unsigned char *a = 
(unsigned char *) pVect1; + unsigned char *b = (unsigned char *) pVect2; + /*for (int i = 0; i < qty; i++) { + int t = int((a)[i]) - int((b)[i]); + res += t*t; + }*/ + + qty = qty >> 2; + for (size_t i = 0; i < qty; i++) { + + res += ((*a) - (*b)) * ((*a) - (*b)); + a++; + b++; + res += ((*a) - (*b)) * ((*a) - (*b)); + a++; + b++; + res += ((*a) - (*b)) * ((*a) - (*b)); + a++; + b++; + res += ((*a) - (*b)) * ((*a) - (*b)); + a++; + b++; + + + } + + return (res); + + } + + class L2SpaceI : public SpaceInterface { + + DISTFUNC fstdistfunc_; + size_t data_size_; + size_t dim_; + public: + L2SpaceI(size_t dim) { + fstdistfunc_ = L2SqrI; + dim_ = dim; + data_size_ = dim * sizeof(unsigned char); + } + + size_t get_data_size() { + return data_size_; + } + + DISTFUNC get_dist_func() { + return fstdistfunc_; + } + + void *get_dist_func_param() { + return &dim_; + } + + ~L2SpaceI() {} + }; + + +} diff --git a/core/src/index/thirdparty/hnswlib/visited_list_pool.h b/core/src/index/thirdparty/hnswlib/visited_list_pool.h new file mode 100644 index 0000000000..6b0f445878 --- /dev/null +++ b/core/src/index/thirdparty/hnswlib/visited_list_pool.h @@ -0,0 +1,78 @@ +#pragma once + +#include +#include + +namespace hnswlib { + typedef unsigned short int vl_type; + + class VisitedList { + public: + vl_type curV; + vl_type *mass; + unsigned int numelements; + + VisitedList(int numelements1) { + curV = -1; + numelements = numelements1; + mass = new vl_type[numelements]; + } + + void reset() { + curV++; + if (curV == 0) { + memset(mass, 0, sizeof(vl_type) * numelements); + curV++; + } + }; + + ~VisitedList() { delete[] mass; } + }; +/////////////////////////////////////////////////////////// +// +// Class for multi-threaded pool-management of VisitedLists +// +///////////////////////////////////////////////////////// + + class VisitedListPool { + std::deque pool; + std::mutex poolguard; + int numelements; + + public: + VisitedListPool(int initmaxpools, int numelements1) { + numelements = 
numelements1; + for (int i = 0; i < initmaxpools; i++) + pool.push_front(new VisitedList(numelements)); + } + + VisitedList *getFreeVisitedList() { + VisitedList *rez; + { + std::unique_lock lock(poolguard); + if (pool.size() > 0) { + rez = pool.front(); + pool.pop_front(); + } else { + rez = new VisitedList(numelements); + } + } + rez->reset(); + return rez; + }; + + void releaseVisitedList(VisitedList *vl) { + std::unique_lock lock(poolguard); + pool.push_front(vl); + }; + + ~VisitedListPool() { + while (pool.size()) { + VisitedList *rez = pool.front(); + pool.pop_front(); + delete rez; + } + }; + }; +} + diff --git a/core/src/wrapper/ConfAdapter.cpp b/core/src/wrapper/ConfAdapter.cpp index fde3a9bfd6..2c7c1cb9a5 100644 --- a/core/src/wrapper/ConfAdapter.cpp +++ b/core/src/wrapper/ConfAdapter.cpp @@ -16,15 +16,16 @@ // under the License. #include "wrapper/ConfAdapter.h" -#include "WrapperException.h" -#include "knowhere/index/vector_index/helpers/IndexParameter.h" -#include "server/Config.h" -#include "utils/Log.h" #include #include #include +#include "WrapperException.h" +#include "knowhere/index/vector_index/helpers/IndexParameter.h" +#include "server/Config.h" +#include "utils/Log.h" + // TODO(lxj): add conf checker namespace milvus { @@ -266,6 +267,17 @@ SPTAGBKTConfAdapter::MatchSearch(const TempMetaConf& metaconf, const IndexType& return conf; } +knowhere::Config +HNSWConfAdapter::Match(const TempMetaConf& metaconf) { + auto conf = std::make_shared(); + conf->d = metaconf.dim; + conf->metric_type = metaconf.metric_type; + + conf->ef = 100; // ef can be auto-configured by using sample data. + conf->M = 16; // A reasonable range of M is from 5 to 48. 
+ return conf; +} + knowhere::Config BinIDMAPConfAdapter::Match(const TempMetaConf& metaconf) { auto conf = std::make_shared(); diff --git a/core/src/wrapper/ConfAdapter.h b/core/src/wrapper/ConfAdapter.h index 616507bc75..bbb031c2ff 100644 --- a/core/src/wrapper/ConfAdapter.h +++ b/core/src/wrapper/ConfAdapter.h @@ -17,11 +17,11 @@ #pragma once +#include + #include "VecIndex.h" #include "knowhere/common/Config.h" -#include - namespace milvus { namespace engine { @@ -124,5 +124,11 @@ class BinIVFConfAdapter : public IVFConfAdapter { Match(const TempMetaConf& metaconf) override; }; +class HNSWConfAdapter : public ConfAdapter { + public: + knowhere::Config + Match(const TempMetaConf& metaconf) override; +}; + } // namespace engine } // namespace milvus diff --git a/core/src/wrapper/ConfAdapterMgr.cpp b/core/src/wrapper/ConfAdapterMgr.cpp index b0eeb9b5c4..3be6b380f4 100644 --- a/core/src/wrapper/ConfAdapterMgr.cpp +++ b/core/src/wrapper/ConfAdapterMgr.cpp @@ -16,6 +16,7 @@ // under the License. #include "wrapper/ConfAdapterMgr.h" + #include "utils/Exception.h" namespace milvus { @@ -61,6 +62,8 @@ AdapterMgr::RegisterAdapter() { REGISTER_CONF_ADAPTER(SPTAGKDTConfAdapter, IndexType::SPTAG_KDT_RNT_CPU, sptag_kdt); REGISTER_CONF_ADAPTER(SPTAGBKTConfAdapter, IndexType::SPTAG_BKT_RNT_CPU, sptag_bkt); + + REGISTER_CONF_ADAPTER(HNSWConfAdapter, IndexType::HNSW, hnsw); } } // namespace engine diff --git a/core/src/wrapper/VecIndex.cpp b/core/src/wrapper/VecIndex.cpp index e86e1e30eb..1b8e7bd489 100644 --- a/core/src/wrapper/VecIndex.cpp +++ b/core/src/wrapper/VecIndex.cpp @@ -16,10 +16,12 @@ // under the License. 
#include "wrapper/VecIndex.h" + #include "VecImpl.h" #include "knowhere/common/Exception.h" #include "knowhere/index/vector_index/IndexBinaryIDMAP.h" #include "knowhere/index/vector_index/IndexBinaryIVF.h" +#include "knowhere/index/vector_index/IndexHNSW.h" #include "knowhere/index/vector_index/IndexIDMAP.h" #include "knowhere/index/vector_index/IndexIVF.h" #include "knowhere/index/vector_index/IndexIVFPQ.h" @@ -38,6 +40,7 @@ #ifdef MILVUS_GPU_VERSION #include + #include "knowhere/index/vector_index/IndexGPUIDMAP.h" #include "knowhere/index/vector_index/IndexGPUIVF.h" #include "knowhere/index/vector_index/IndexGPUIVFPQ.h" @@ -99,6 +102,10 @@ GetVecIndexFactory(const IndexType& type, const Config& cfg) { index = std::make_shared(); break; } + case IndexType::HNSW: { + index = std::make_shared(); + break; + } #ifdef MILVUS_GPU_VERSION case IndexType::FAISS_IVFFLAT_GPU: { diff --git a/core/src/wrapper/VecIndex.h b/core/src/wrapper/VecIndex.h index 446b13e77e..ed14212710 100644 --- a/core/src/wrapper/VecIndex.h +++ b/core/src/wrapper/VecIndex.h @@ -50,6 +50,7 @@ enum class IndexType { NSG_MIX, FAISS_IVFPQ_MIX, SPTAG_BKT_RNT_CPU, + HNSW, FAISS_BIN_IDMAP = 100, FAISS_BIN_IVFLAT_CPU = 101, };