From 5437fcce8ebafd12981b54d72cb6f5f11d663b2f Mon Sep 17 00:00:00 2001 From: foxspy Date: Tue, 25 Jul 2023 10:07:01 +0800 Subject: [PATCH] fix range search (#25880) Signed-off-by: xianliang --- .../core/src/segcore/IndexConfigGenerator.cpp | 5 +++ .../core/src/segcore/IndexConfigGenerator.h | 3 ++ internal/core/unittest/test_growing_index.cpp | 35 +++++++++++++++++-- 3 files changed, 41 insertions(+), 2 deletions(-) diff --git a/internal/core/src/segcore/IndexConfigGenerator.cpp b/internal/core/src/segcore/IndexConfigGenerator.cpp index b903d9b1bc..37f6f9563b 100644 --- a/internal/core/src/segcore/IndexConfigGenerator.cpp +++ b/internal/core/src/segcore/IndexConfigGenerator.cpp @@ -64,6 +64,11 @@ VecIndexConfig::GetSearchConf(const SearchInfo& searchInfo) { SearchInfo searchParam(searchInfo); searchParam.metric_type_ = metric_type_; searchParam.search_params_ = search_params_; + for (auto& key : maintain_params) { + if (searchInfo.search_params_.contains(key)) { + searchParam.search_params_[key] = searchInfo.search_params_[key]; + } + } return searchParam; } diff --git a/internal/core/src/segcore/IndexConfigGenerator.h b/internal/core/src/segcore/IndexConfigGenerator.h index 7c146ff224..cf31fc63a5 100644 --- a/internal/core/src/segcore/IndexConfigGenerator.h +++ b/internal/core/src/segcore/IndexConfigGenerator.h @@ -33,6 +33,9 @@ class VecIndexConfig { inline static const std::map index_build_ratio = { {knowhere::IndexEnum::INDEX_FAISS_IVFFLAT_CC, 0.1}}; + inline static const std::unordered_set maintain_params = { + "radius", "range_filter"}; + public: VecIndexConfig(const int64_t max_index_row_count, const FieldIndexMeta& index_meta_, diff --git a/internal/core/unittest/test_growing_index.cpp b/internal/core/unittest/test_growing_index.cpp index 9e5fab5a21..7a3a6353b7 100644 --- a/internal/core/unittest/test_growing_index.cpp +++ b/internal/core/unittest/test_growing_index.cpp @@ -75,6 +75,19 @@ TEST(GrowingIndex, Correctness) { query_info->set_search_params(R"({"nprobe": 16})"); auto plan_str = plan_node.SerializeAsString(); + milvus::proto::plan::PlanNode range_query_plan_node; + auto vector_range_querys = range_query_plan_node.mutable_vector_anns(); + vector_range_querys->set_is_binary(false); + vector_range_querys->set_placeholder_tag("$0"); + vector_range_querys->set_field_id(102); + auto range_query_info = vector_range_querys->mutable_query_info(); + range_query_info->set_topk(5); + range_query_info->set_round_decimal(3); + range_query_info->set_metric_type("l2"); + range_query_info->set_search_params( + R"({"nprobe": 10, "radius": 600, "range_filter": 500})"); + auto range_plan_str = range_query_plan_node.SerializeAsString(); + int64_t per_batch = 10000; int64_t n_batch = 20; int64_t top_k = 5; @@ -99,10 +112,11 @@ TEST(GrowingIndex, Correctness) { EXPECT_EQ(filed_data->num_chunk(), 0); } - auto plan = milvus::query::CreateSearchPlanByExpr( - *schema, plan_str.data(), plan_str.size()); auto num_queries = 5; auto ph_group_raw = CreatePlaceholderGroup(num_queries, 128, 1024); + + auto plan = milvus::query::CreateSearchPlanByExpr( + *schema, plan_str.data(), plan_str.size()); auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString()); Timestamp time = 1000000; @@ -111,6 +125,23 @@ TEST(GrowingIndex, Correctness) { EXPECT_EQ(sr->unity_topK_, top_k); EXPECT_EQ(sr->distances_.size(), num_queries * top_k); EXPECT_EQ(sr->seg_offsets_.size(), num_queries * top_k); + + auto range_plan = milvus::query::CreateSearchPlanByExpr( + *schema, range_plan_str.data(), range_plan_str.size()); + auto range_ph_group = ParsePlaceholderGroup( + range_plan.get(), ph_group_raw.SerializeAsString()); + auto range_sr = + segment->Search(range_plan.get(), range_ph_group.get(), time); + ASSERT_EQ(range_sr->total_nq_, num_queries); + EXPECT_EQ(sr->unity_topK_, top_k); + EXPECT_EQ(sr->distances_.size(), num_queries * top_k); + EXPECT_EQ(sr->seg_offsets_.size(), num_queries * top_k); + for (int j = 0; j < range_sr->seg_offsets_.size(); j++) { + if (range_sr->seg_offsets_[j] != -1) { + EXPECT_TRUE(sr->distances_[j] >= 500.0 && + sr->distances_[j] <= 600.0); + } + } } }