diff --git a/CHANGELOG.md b/CHANGELOG.md index a20f36ec11..2511c98150 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ Please mark all change in change log and use the issue from GitHub ## Feature ## Improvement +- \#3213 Allow users to specify a distance type at runtime for Flat index ## Task diff --git a/core/src/db/engine/ExecutionEngineImpl.cpp b/core/src/db/engine/ExecutionEngineImpl.cpp index cbc7aa9194..fee486c532 100644 --- a/core/src/db/engine/ExecutionEngineImpl.cpp +++ b/core/src/db/engine/ExecutionEngineImpl.cpp @@ -1158,7 +1158,10 @@ ExecutionEngineImpl::Search(int64_t n, const float* data, int64_t k, const milvu } milvus::json conf = extra_params; + if (conf.contains(knowhere::Metric::TYPE)) + MappingMetricType(conf[knowhere::Metric::TYPE], conf); conf[knowhere::meta::TOPK] = k; + auto adapter = knowhere::AdapterMgr::GetInstance().GetAdapter(index_->index_type()); if (!adapter->CheckSearch(conf, index_->index_type(), index_->index_mode())) { LOG_ENGINE_ERROR_ << LogOut("[%s][%ld] Illegal search params", "search", 0); @@ -1197,6 +1200,8 @@ ExecutionEngineImpl::Search(int64_t n, const uint8_t* data, int64_t k, const mil } milvus::json conf = extra_params; + if (conf.contains(knowhere::Metric::TYPE)) + MappingMetricType(conf[knowhere::Metric::TYPE], conf); conf[knowhere::meta::TOPK] = k; auto adapter = knowhere::AdapterMgr::GetInstance().GetAdapter(index_->index_type()); if (!adapter->CheckSearch(conf, index_->index_type(), index_->index_mode())) { diff --git a/core/src/index/knowhere/knowhere/index/vector_index/IndexBinaryIDMAP.cpp b/core/src/index/knowhere/knowhere/index/vector_index/IndexBinaryIDMAP.cpp index 056a955abe..5a3ebb923a 100644 --- a/core/src/index/knowhere/knowhere/index/vector_index/IndexBinaryIDMAP.cpp +++ b/core/src/index/knowhere/knowhere/index/vector_index/IndexBinaryIDMAP.cpp @@ -53,8 +53,8 @@ BinaryIDMAP::Query(const DatasetPtr& dataset_ptr, const Config& config) { auto p_id = (int64_t*)malloc(p_id_size); auto p_dist = (float*)malloc(p_dist_size); - QueryImpl(rows, (uint8_t*)p_data, k, p_dist, p_id, Config()); - + // QueryImpl(rows, (uint8_t*)p_data, k, p_dist, p_id, Config()); + QueryImpl(rows, (uint8_t*)p_data, k, p_dist, p_id, config); auto ret_ds = std::make_shared(); if (index_->metric_type == faiss::METRIC_Hamming) { auto pf_dist = (float*)malloc(p_dist_size); @@ -214,7 +214,13 @@ void BinaryIDMAP::QueryImpl(int64_t n, const uint8_t* data, int64_t k, float* distances, int64_t* labels, const Config& config) { int32_t* pdistances = (int32_t*)distances; + + auto flat_index = dynamic_cast(index_.get())->index; + auto default_type = flat_index->metric_type; + if (config.contains(Metric::TYPE)) + flat_index->metric_type = GetMetricType(config[Metric::TYPE].get()); index_->search(n, (uint8_t*)data, k, pdistances, labels, bitset_); + flat_index->metric_type = default_type; } } // namespace knowhere diff --git a/core/src/index/knowhere/knowhere/index/vector_index/IndexIDMAP.cpp b/core/src/index/knowhere/knowhere/index/vector_index/IndexIDMAP.cpp index 612de45547..cf09eeaccb 100644 --- a/core/src/index/knowhere/knowhere/index/vector_index/IndexIDMAP.cpp +++ b/core/src/index/knowhere/knowhere/index/vector_index/IndexIDMAP.cpp @@ -105,8 +105,8 @@ IDMAP::Query(const DatasetPtr& dataset_ptr, const Config& config) { auto p_id = (int64_t*)malloc(p_id_size); auto p_dist = (float*)malloc(p_dist_size); - QueryImpl(rows, (float*)p_data, k, p_dist, p_id, Config()); - + // QueryImpl(rows, (float*)p_data, k, p_dist, p_id, Config()); + QueryImpl(rows, (float*)p_data, k, p_dist, p_id, config); auto ret_ds = std::make_shared(); ret_ds->Set(meta::IDS, p_id); ret_ds->Set(meta::DISTANCE, p_dist); @@ -221,7 +221,12 @@ IDMAP::GetVectorById(const DatasetPtr& dataset_ptr, const Config& config) { void IDMAP::QueryImpl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& config) { + auto flat_index = dynamic_cast(index_.get())->index; + auto default_type = flat_index->metric_type; + if (config.contains(Metric::TYPE)) + flat_index->metric_type = GetMetricType(config[Metric::TYPE].get()); index_->search(n, (float*)data, k, distances, labels, bitset_); + flat_index->metric_type = default_type; } } // namespace knowhere diff --git a/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexGPUIDMAP.cpp b/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexGPUIDMAP.cpp index 5bcc288c32..f91afa19a6 100644 --- a/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexGPUIDMAP.cpp +++ b/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexGPUIDMAP.cpp @@ -17,6 +17,7 @@ #include #endif #include +#include #include "knowhere/common/Exception.h" #include "knowhere/index/vector_index/IndexIDMAP.h" @@ -105,7 +106,13 @@ GPUIDMAP::GetRawIds() { void GPUIDMAP::QueryImpl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& config) { ResScope rs(res_, gpu_id_); + + auto flat_index = dynamic_cast(index_.get())->index; + auto default_type = flat_index->metric_type; + if (config.contains(Metric::TYPE)) + flat_index->metric_type = GetMetricType(config[Metric::TYPE].get()); index_->search(n, (float*)data, k, distances, labels, bitset_); + flat_index->metric_type = default_type; } void diff --git a/core/src/scheduler/task/SearchTask.cpp b/core/src/scheduler/task/SearchTask.cpp index e9dd032f1e..521c81f179 100644 --- a/core/src/scheduler/task/SearchTask.cpp +++ b/core/src/scheduler/task/SearchTask.cpp @@ -22,6 +22,7 @@ #include "db/Utils.h" #include "db/engine/EngineFactory.h" +#include "index/knowhere/knowhere/index/vector_index/helpers/IndexParameter.h" #include "metrics/Metrics.h" #include "scheduler/SchedInst.h" #include "scheduler/job/SearchJob.h" @@ -122,6 +123,9 @@ XSearchTask::XSearchTask(const std::shared_ptr& context, Segmen milvus::json json_params; if (!file_->index_params_.empty()) { json_params = milvus::json::parse(file_->index_params_); + if (json_params.contains(knowhere::Metric::TYPE) && + (engine_type == EngineType::FAISS_BIN_IDMAP || engine_type == EngineType::FAISS_IDMAP)) + ascending_reduce = json_params[knowhere::Metric::TYPE] != static_cast(MetricType::IP); } // if (auto job = job_.lock()) { // auto search_job = std::static_pointer_cast(job);