From 186f36c79413e31acfed18166d458bc1c0f1d603 Mon Sep 17 00:00:00 2001 From: "shengjun.li" Date: Fri, 7 Aug 2020 11:56:01 +0800 Subject: [PATCH] fix metric type (#3158) * fix metric type Signed-off-by: shengjun.li * fix config Signed-off-by: shengjun.li * fix query Signed-off-by: shengjun.li --- .../knowhere/index/vector_index/IndexBinaryIDMAP.cpp | 10 ++++++++-- .../knowhere/index/vector_index/IndexIDMAP.cpp | 9 +++++++-- core/src/segment/SegmentReader.cpp | 10 +++++++--- core/unittest/db/test_db.cpp | 3 ++- 4 files changed, 24 insertions(+), 8 deletions(-) diff --git a/core/src/index/knowhere/knowhere/index/vector_index/IndexBinaryIDMAP.cpp b/core/src/index/knowhere/knowhere/index/vector_index/IndexBinaryIDMAP.cpp index 5e3531a51b..5a3ad0a5a7 100644 --- a/core/src/index/knowhere/knowhere/index/vector_index/IndexBinaryIDMAP.cpp +++ b/core/src/index/knowhere/knowhere/index/vector_index/IndexBinaryIDMAP.cpp @@ -53,7 +53,7 @@ BinaryIDMAP::Query(const DatasetPtr& dataset_ptr, const Config& config) { auto p_id = (int64_t*)malloc(p_id_size); auto p_dist = (float*)malloc(p_dist_size); - QueryImpl(rows, (uint8_t*)p_data, k, p_dist, p_id, Config()); + QueryImpl(rows, (uint8_t*)p_data, k, p_dist, p_id, config); auto ret_ds = std::make_shared(); if (index_->metric_type == faiss::METRIC_Hamming) { @@ -142,9 +142,12 @@ BinaryIDMAP::Add(const DatasetPtr& dataset_ptr, const Config& config) { void BinaryIDMAP::Train(const DatasetPtr& dataset_ptr, const Config& config) { + // users will assign the metric type when querying + // so we let Tanimoto be the default type + faiss::MetricType metric_type = faiss::METRIC_Tanimoto; + const char* desc = "BFlat"; int64_t dim = config[meta::DIM].get(); - faiss::MetricType metric_type = GetMetricType(config[Metric::TYPE].get()); auto index = faiss::index_binary_factory(dim, desc, metric_type); index_.reset(index); } @@ -213,6 +216,9 @@ BinaryIDMAP::GetVectorById(const DatasetPtr& dataset_ptr, const Config& config) void BinaryIDMAP::QueryImpl(int64_t n, const uint8_t* data, int64_t k, float* distances, int64_t* labels, const Config& config) { + // assign the metric type + index_->metric_type = GetMetricType(config[Metric::TYPE].get()); + int32_t* pdistances = (int32_t*)distances; index_->search(n, (uint8_t*)data, k, pdistances, labels, bitset_); } diff --git a/core/src/index/knowhere/knowhere/index/vector_index/IndexIDMAP.cpp b/core/src/index/knowhere/knowhere/index/vector_index/IndexIDMAP.cpp index 01af862115..0bf5178d06 100644 --- a/core/src/index/knowhere/knowhere/index/vector_index/IndexIDMAP.cpp +++ b/core/src/index/knowhere/knowhere/index/vector_index/IndexIDMAP.cpp @@ -54,9 +54,12 @@ IDMAP::Load(const BinarySet& binary_set) { void IDMAP::Train(const DatasetPtr& dataset_ptr, const Config& config) { + // users will assign the metric type when querying + // so we let L2 be the default type + faiss::MetricType metric_type = faiss::METRIC_L2; + const char* desc = "IDMap,Flat"; int64_t dim = config[meta::DIM].get(); - faiss::MetricType metric_type = GetMetricType(config[Metric::TYPE].get()); auto index = faiss::index_factory(dim, desc, metric_type); index_.reset(index); } @@ -105,7 +108,7 @@ IDMAP::Query(const DatasetPtr& dataset_ptr, const Config& config) { auto p_id = (int64_t*)malloc(p_id_size); auto p_dist = (float*)malloc(p_dist_size); - QueryImpl(rows, (float*)p_data, k, p_dist, p_id, Config()); + QueryImpl(rows, (float*)p_data, k, p_dist, p_id, config); auto ret_ds = std::make_shared(); ret_ds->Set(meta::IDS, p_id); @@ -221,6 +224,8 @@ IDMAP::GetVectorById(const DatasetPtr& dataset_ptr, const Config& config) { void IDMAP::QueryImpl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& config) { + // assign the metric type + index_->metric_type = GetMetricType(config[Metric::TYPE].get()); index_->search(n, (float*)data, k, distances, labels, bitset_); } diff --git a/core/src/segment/SegmentReader.cpp b/core/src/segment/SegmentReader.cpp index a8542af181..aaaae5c496 100644 --- a/core/src/segment/SegmentReader.cpp +++ b/core/src/segment/SegmentReader.cpp @@ -303,10 +303,14 @@ SegmentReader::LoadVectorIndex(const std::string& field_name, knowhere::VecIndex // construct IDMAP index knowhere::VecIndexFactory& vec_index_factory = knowhere::VecIndexFactory::GetInstance(); - index_ptr = vec_index_factory.CreateVecIndex(knowhere::IndexEnum::INDEX_FAISS_IDMAP, - knowhere::IndexMode::MODE_CPU); + if (field->GetFtype() == engine::DataType::VECTOR_FLOAT) { + index_ptr = vec_index_factory.CreateVecIndex(knowhere::IndexEnum::INDEX_FAISS_IDMAP, + knowhere::IndexMode::MODE_CPU); + } else { + index_ptr = vec_index_factory.CreateVecIndex(knowhere::IndexEnum::INDEX_FAISS_BIN_IDMAP, + knowhere::IndexMode::MODE_CPU); + } milvus::json conf{{knowhere::meta::DIM, dimension}}; - conf[engine::PARAM_INDEX_METRIC_TYPE] = knowhere::Metric::L2; index_ptr->Train(knowhere::DatasetPtr(), conf); index_ptr->AddWithoutIds(dataset, conf); index_ptr->SetUids(uids); diff --git a/core/unittest/db/test_db.cpp b/core/unittest/db/test_db.cpp index bdad2b2c76..dec13f96a8 100644 --- a/core/unittest/db/test_db.cpp +++ b/core/unittest/db/test_db.cpp @@ -199,7 +199,8 @@ BuildQueryPtr(const std::string& collection_name, int64_t n, int64_t topk, std:: vector_record.float_data[COLLECTION_DIM * i] += i / 2000.; } vector_query->query_vector = vector_record; - vector_query->extra_params = {{"metric_type", "L2"}, {"nprobe", 1024}}; + vector_query->metric_type = "L2"; + vector_query->extra_params = {{"nprobe", 1024}}; query_ptr->root = general_query; query_ptr->vectors.insert(std::make_pair(placeholder, vector_query));