From 2e198e27837b067570d0018ebafcb42e713e2377 Mon Sep 17 00:00:00 2001 From: dragondriver Date: Wed, 30 Dec 2020 19:37:45 +0800 Subject: [PATCH] Add support for idmap, nsg, sq8 Signed-off-by: dragondriver --- .../core/src/indexbuilder/IndexWrapper.cpp | 67 +++++++++++++++---- internal/core/src/indexbuilder/IndexWrapper.h | 3 + internal/core/src/indexbuilder/utils.h | 39 +++++++++-- internal/core/unittest/test_index_wrapper.cpp | 57 ++++++++++++---- 4 files changed, 137 insertions(+), 29 deletions(-) diff --git a/internal/core/src/indexbuilder/IndexWrapper.cpp b/internal/core/src/indexbuilder/IndexWrapper.cpp index ef96d7f5b1..7938db060c 100644 --- a/internal/core/src/indexbuilder/IndexWrapper.cpp +++ b/internal/core/src/indexbuilder/IndexWrapper.cpp @@ -105,6 +105,32 @@ IndexWrapper::parse() { config_[milvus::knowhere::IndexParams::m] = std::stoi(m); } + /************************** NSG Parameter **************************/ + if (!config_.contains(milvus::knowhere::IndexParams::knng)) { + } else { + auto knng = config_[milvus::knowhere::IndexParams::knng].get(); + config_[milvus::knowhere::IndexParams::knng] = std::stoi(knng); + } + + if (!config_.contains(milvus::knowhere::IndexParams::search_length)) { + } else { + auto search_length = config_[milvus::knowhere::IndexParams::search_length].get(); + config_[milvus::knowhere::IndexParams::search_length] = std::stoi(search_length); + } + + if (!config_.contains(milvus::knowhere::IndexParams::out_degree)) { + } else { + auto out_degree = config_[milvus::knowhere::IndexParams::out_degree].get(); + config_[milvus::knowhere::IndexParams::out_degree] = std::stoi(out_degree); + } + + if (!config_.contains(milvus::knowhere::IndexParams::candidate)) { + } else { + auto candidate = config_[milvus::knowhere::IndexParams::candidate].get(); + config_[milvus::knowhere::IndexParams::candidate] = std::stoi(candidate); + } + + /************************** Serialize *******************************/ if (!config_.contains(milvus::knowhere::INDEX_FILE_SLICE_SIZE_IN_MEGABYTE)) { config_[milvus::knowhere::INDEX_FILE_SLICE_SIZE_IN_MEGABYTE] = 4; } else { @@ -132,12 +158,37 @@ IndexWrapper::dim() { void IndexWrapper::BuildWithoutIds(const knowhere::DatasetPtr& dataset) { auto index_type = get_index_type(); - // if (index_type == milvus::knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT) { - // PanicInfo(std::string(index_type) + " doesn't support build without ids yet!"); + if (is_in_need_id_list(index_type)) { + PanicInfo(std::string(index_type) + " doesn't support build without ids yet!"); + } + // if (is_in_need_build_all_list(index_type)) { + // index_->BuildAll(dataset, config_); + // } else { + // index_->Train(dataset, config_); + // index_->AddWithoutIds(dataset, config_); // } - index_->Train(dataset, config_); - index_->AddWithoutIds(dataset, config_); + index_->BuildAll(dataset, config_); + if (is_in_nm_list(index_type)) { + StoreRawData(dataset); + } +} + +void +IndexWrapper::BuildWithIds(const knowhere::DatasetPtr& dataset) { + Assert(dataset->data().find(milvus::knowhere::meta::IDS) != dataset->data().end()); + // index_->Train(dataset, config_); + // index_->Add(dataset, config_); + index_->BuildAll(dataset, config_); + + if (is_in_nm_list(get_index_type())) { + StoreRawData(dataset); + } +} + +void +IndexWrapper::StoreRawData(const knowhere::DatasetPtr& dataset) { + auto index_type = get_index_type(); if (is_in_nm_list(index_type)) { auto tensor = dataset->Get(milvus::knowhere::meta::TENSOR); auto row_num = dataset->Get(milvus::knowhere::meta::ROWS); @@ -153,14 +204,6 @@ IndexWrapper::BuildWithoutIds(const knowhere::DatasetPtr& dataset) { } } -void -IndexWrapper::BuildWithIds(const knowhere::DatasetPtr& dataset) { - Assert(dataset->data().find(milvus::knowhere::meta::IDS) != dataset->data().end()); - // index_->Train(dataset, config_); - // index_->Add(dataset, config_); - index_->BuildAll(dataset, config_); -} - /* * brief Return serialized binary set */ diff --git a/internal/core/src/indexbuilder/IndexWrapper.h b/internal/core/src/indexbuilder/IndexWrapper.h index 38630aa70b..562ef77f1d 100644 --- a/internal/core/src/indexbuilder/IndexWrapper.h +++ b/internal/core/src/indexbuilder/IndexWrapper.h @@ -49,6 +49,9 @@ class IndexWrapper { std::optional get_config_by_name(std::string name); + void + StoreRawData(const knowhere::DatasetPtr& dataset); + public: void BuildWithIds(const knowhere::DatasetPtr& dataset); diff --git a/internal/core/src/indexbuilder/utils.h b/internal/core/src/indexbuilder/utils.h index 06389faf21..7e40e6283a 100644 --- a/internal/core/src/indexbuilder/utils.h +++ b/internal/core/src/indexbuilder/utils.h @@ -36,16 +36,47 @@ BIN_List() { return ret; } +std::vector +Need_ID_List() { + static std::vector ret{ + // milvus::knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT, + // milvus::knowhere::IndexEnum::INDEX_NSG + }; + + return ret; +} + +std::vector +Need_BuildAll_list() { + static std::vector ret{milvus::knowhere::IndexEnum::INDEX_NSG}; + return ret; +} + +template +bool +is_in_list(const T& t, std::function()> list_func) { + auto l = list_func(); + return std::find(l.begin(), l.end(), t) != l.end(); +} + bool is_in_bin_list(const milvus::knowhere::IndexType& index_type) { - auto bin_list = BIN_List(); - return std::find(bin_list.begin(), bin_list.end(), index_type) != bin_list.end(); + return is_in_list(index_type, BIN_List); } bool is_in_nm_list(const milvus::knowhere::IndexType& index_type) { - auto nm_list = NM_List(); - return std::find(nm_list.begin(), nm_list.end(), index_type) != nm_list.end(); + return is_in_list(index_type, NM_List); +} + +bool +is_in_need_build_all_list(const milvus::knowhere::IndexType& index_type) { + return is_in_list(index_type, Need_BuildAll_list); +} + +bool +is_in_need_id_list(const milvus::knowhere::IndexType& index_type) { + return is_in_list(index_type, Need_ID_List); } } // namespace indexbuilder diff --git a/internal/core/unittest/test_index_wrapper.cpp b/internal/core/unittest/test_index_wrapper.cpp index 9064619759..50aed377d3 100644 --- a/internal/core/unittest/test_index_wrapper.cpp +++ b/internal/core/unittest/test_index_wrapper.cpp @@ -38,7 +38,14 @@ int DEVICEID = 0; namespace { auto generate_conf(const milvus::knowhere::IndexType& index_type, const milvus::knowhere::MetricType& metric_type) { - if (index_type == milvus::knowhere::IndexEnum::INDEX_FAISS_IVFPQ) { + if (index_type == milvus::knowhere::IndexEnum::INDEX_FAISS_IDMAP) { + return milvus::knowhere::Config{ + {milvus::knowhere::meta::DIM, DIM}, + // {milvus::knowhere::meta::TOPK, K}, + {milvus::knowhere::Metric::TYPE, metric_type}, + {milvus::knowhere::INDEX_FILE_SLICE_SIZE_IN_MEGABYTE, 4}, + }; + } else if (index_type == milvus::knowhere::IndexEnum::INDEX_FAISS_IVFPQ) { return milvus::knowhere::Config{ {milvus::knowhere::meta::DIM, DIM}, // {milvus::knowhere::meta::TOPK, K}, @@ -55,7 +62,20 @@ generate_conf(const milvus::knowhere::IndexType& index_type, const milvus::knowh // {milvus::knowhere::meta::TOPK, K}, {milvus::knowhere::IndexParams::nlist, 100}, // {milvus::knowhere::IndexParams::nprobe, 4}, - {milvus::knowhere::Metric::TYPE, milvus::knowhere::Metric::L2}, + {milvus::knowhere::Metric::TYPE, metric_type}, + {milvus::knowhere::INDEX_FILE_SLICE_SIZE_IN_MEGABYTE, 4}, +#ifdef MILVUS_GPU_VERSION + {milvus::knowhere::meta::DEVICEID, DEVICEID}, +#endif + }; + } else if (index_type == milvus::knowhere::IndexEnum::INDEX_FAISS_IVFSQ8) { + return milvus::knowhere::Config{ + {milvus::knowhere::meta::DIM, DIM}, + // {milvus::knowhere::meta::TOPK, K}, + {milvus::knowhere::IndexParams::nlist, 100}, + // {milvus::knowhere::IndexParams::nprobe, 4}, + {milvus::knowhere::IndexParams::nbits, 8}, + {milvus::knowhere::Metric::TYPE, metric_type}, {milvus::knowhere::INDEX_FILE_SLICE_SIZE_IN_MEGABYTE, 4}, #ifdef MILVUS_GPU_VERSION {milvus::knowhere::meta::DEVICEID, DEVICEID}, @@ -78,6 +98,15 @@ generate_conf(const milvus::knowhere::IndexType& index_type, const milvus::knowh // {milvus::knowhere::meta::TOPK, K}, {milvus::knowhere::Metric::TYPE, metric_type}, }; + } else if (index_type == milvus::knowhere::IndexEnum::INDEX_NSG) { + return milvus::knowhere::Config{{milvus::knowhere::meta::DIM, DIM}, + {milvus::knowhere::IndexParams::nlist, 163}, + {milvus::knowhere::IndexParams::nprobe, 8}, + {milvus::knowhere::IndexParams::knng, 20}, + {milvus::knowhere::IndexParams::search_length, 40}, + {milvus::knowhere::IndexParams::out_degree, 30}, + {milvus::knowhere::IndexParams::candidate, 100}, + {milvus::knowhere::Metric::TYPE, metric_type}}; } return milvus::knowhere::Config(); } @@ -142,11 +171,6 @@ class IndexWrapperTest : public ::testing::TestWithParam { if (!is_binary) { xb_data = dataset.get_col(0); xb_dataset = milvus::knowhere::GenDataset(NB, DIM, xb_data.data()); - } else if (index_type == milvus::knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT) { - xb_bin_data = dataset.get_col(0); - ids.resize(NB); - std::iota(ids.begin(), ids.end(), 0); - xb_dataset = milvus::knowhere::GenDataset(NB, DIM, xb_bin_data.data()); } else { xb_bin_data = dataset.get_col(0); xb_dataset = milvus::knowhere::GenDataset(NB, DIM, xb_bin_data.data()); @@ -335,11 +359,14 @@ TEST(BinIdMapWrapper, Build) { INSTANTIATE_TEST_CASE_P( IndexTypeParameters, IndexWrapperTest, - ::testing::Values( - std::pair(milvus::knowhere::IndexEnum::INDEX_FAISS_IVFPQ, milvus::knowhere::Metric::L2), - std::pair(milvus::knowhere::IndexEnum::INDEX_FAISS_IVFFLAT, milvus::knowhere::Metric::L2), - std::pair(milvus::knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT, milvus::knowhere::Metric::JACCARD), - std::pair(milvus::knowhere::IndexEnum::INDEX_FAISS_BIN_IDMAP, milvus::knowhere::Metric::JACCARD))); + ::testing::Values(std::pair(milvus::knowhere::IndexEnum::INDEX_FAISS_IDMAP, milvus::knowhere::Metric::L2), + std::pair(milvus::knowhere::IndexEnum::INDEX_FAISS_IVFPQ, milvus::knowhere::Metric::L2), + std::pair(milvus::knowhere::IndexEnum::INDEX_FAISS_IVFFLAT, milvus::knowhere::Metric::L2), + std::pair(milvus::knowhere::IndexEnum::INDEX_FAISS_IVFSQ8, milvus::knowhere::Metric::L2), + std::pair(milvus::knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT, + milvus::knowhere::Metric::JACCARD), + std::pair(milvus::knowhere::IndexEnum::INDEX_FAISS_BIN_IDMAP, milvus::knowhere::Metric::JACCARD), + std::pair(milvus::knowhere::IndexEnum::INDEX_NSG, milvus::knowhere::Metric::L2))); TEST_P(IndexWrapperTest, Constructor) { auto index = @@ -357,7 +384,11 @@ TEST_P(IndexWrapperTest, BuildWithoutIds) { auto index = std::make_unique(type_params_str.c_str(), index_params_str.c_str()); - ASSERT_NO_THROW(index->BuildWithoutIds(xb_dataset)); + if (milvus::indexbuilder::is_in_need_id_list(index_type)) { + ASSERT_ANY_THROW(index->BuildWithoutIds(xb_dataset)); + } else { + ASSERT_NO_THROW(index->BuildWithoutIds(xb_dataset)); + } } TEST_P(IndexWrapperTest, Codec) {