From ab03521588ce63e26a9d0f41b6437f3ab9ed527d Mon Sep 17 00:00:00 2001
From: foxspy
Date: Tue, 23 Dec 2025 10:33:17 +0800
Subject: [PATCH] fix: fix chunk iterator merge order (#46461)

issue: #46349

When using brute-force search, the iterator results from multiple chunks are
merged; the merge must take the metric type into account, since the metric
determines whether larger or smaller distances rank as closer.

Signed-off-by: xianliang.li
---
 internal/core/src/common/QueryResult.h        | 29 +++++++--
 internal/core/src/exec/operator/Utils.h       |  4 +-
 internal/core/src/query/SearchOnGrowing.cpp   |  8 ++-
 internal/core/src/query/SearchOnSealed.cpp    |  5 +-
 .../test_milvus_client_search_group_by.py     | 65 +++++++++++++++++++
 5 files changed, 103 insertions(+), 8 deletions(-)

diff --git a/internal/core/src/common/QueryResult.h b/internal/core/src/common/QueryResult.h
index 024d5ad069..0e49ce6f3f 100644
--- a/internal/core/src/common/QueryResult.h
+++ b/internal/core/src/common/QueryResult.h
@@ -114,11 +114,25 @@ struct OffsetDisPair {
 };
 
 struct OffsetDisPairComparator {
+    bool larger_is_closer_ = false;
+
+    OffsetDisPairComparator(bool larger_is_closer = false)
+        : larger_is_closer_(larger_is_closer) {
+    }
+
     bool
     operator()(const std::shared_ptr<OffsetDisPair>& left,
                const std::shared_ptr<OffsetDisPair>& right) const {
+        // For priority_queue: return true if left has lower priority than right
+        // We want the element with better (closer) distance at the top
         if (left->GetOffDis().second != right->GetOffDis().second) {
-            return left->GetOffDis().second < right->GetOffDis().second;
+            if (larger_is_closer_) {
+                // IP/Cosine: larger distance is better, smaller has lower priority
+                return left->GetOffDis().second < right->GetOffDis().second;
+            } else {
+                // L2: smaller distance is better, larger has lower priority
+                return left->GetOffDis().second > right->GetOffDis().second;
+            }
         }
         return left->GetOffDis().first < right->GetOffDis().first;
     }
@@ -142,8 +156,11 @@ class VectorIterator {
 
 class ChunkMergeIterator : public VectorIterator {
  public:
     ChunkMergeIterator(int chunk_count,
-                       const std::vector<int64_t>& total_rows_until_chunk = {})
-        : total_rows_until_chunk_(total_rows_until_chunk) {
+                       const std::vector<int64_t>& total_rows_until_chunk = {},
+                       bool larger_is_closer = false)
+        : total_rows_until_chunk_(total_rows_until_chunk),
+          larger_is_closer_(larger_is_closer),
+          heap_(OffsetDisPairComparator(larger_is_closer)) {
         iterators_.reserve(chunk_count);
     }
@@ -215,6 +232,7 @@ class ChunkMergeIterator : public VectorIterator {
         heap_;
     bool sealed = false;
     std::vector<int64_t> total_rows_until_chunk_;
+    bool larger_is_closer_ = false;
     //currently, ChunkMergeIterator is guaranteed to be used serially without concurrent problem, in the future
     //we may need to add mutex to protect the variable sealed
 };
@@ -239,7 +257,8 @@ struct SearchResult {
         int64_t nq,
         int chunk_count,
         const std::vector<int64_t>& total_rows_until_chunk,
-        const std::vector<knowhere::IndexNode::IteratorPtr>& kw_iterators) {
+        const std::vector<knowhere::IndexNode::IteratorPtr>& kw_iterators,
+        bool larger_is_closer = false) {
         AssertInfo(kw_iterators.size() == nq * chunk_count,
                    "kw_iterators count:{} is not equal to nq*chunk_count:{}, "
                    "wrong state",
@@ -251,7 +270,7 @@ struct SearchResult {
             vec_iter_idx = vec_iter_idx % nq;
             if (vector_iterators.size() < nq) {
                 auto chunk_merge_iter = std::make_shared<ChunkMergeIterator>(
-                    chunk_count, total_rows_until_chunk);
+                    chunk_count, total_rows_until_chunk, larger_is_closer);
                 vector_iterators.emplace_back(chunk_merge_iter);
             }
             const auto& kw_iterator = kw_iterators[i];
diff --git a/internal/core/src/exec/operator/Utils.h b/internal/core/src/exec/operator/Utils.h
index ab5b957c77..a840e2d180 100644
--- a/internal/core/src/exec/operator/Utils.h
+++ b/internal/core/src/exec/operator/Utils.h
@@ -54,8 +54,10 @@ PrepareVectorIteratorsFromIndex(const SearchInfo& search_info,
         iterators_val = index.VectorIterators(dataset, search_conf, bitset);
         if (iterators_val.has_value()) {
+            bool larger_is_closer =
+                PositivelyRelated(search_info.metric_type_);
             search_result.AssembleChunkVectorIterators(
-                nq, 1, {0}, iterators_val.value());
+                nq, 1, {0}, iterators_val.value(), larger_is_closer);
         } else {
             std::string operator_type = "";
             if (search_info.group_by_field_id_.has_value()) {
diff --git a/internal/core/src/query/SearchOnGrowing.cpp b/internal/core/src/query/SearchOnGrowing.cpp
index 11877c8437..b306542c5a 100644
--- a/internal/core/src/query/SearchOnGrowing.cpp
+++ b/internal/core/src/query/SearchOnGrowing.cpp
@@ -13,6 +13,7 @@
 #include "common/QueryInfo.h"
 #include "common/Tracer.h"
 #include "common/Types.h"
+#include "common/Utils.h"
 #include "SearchOnGrowing.h"
 #include
 #include "knowhere/comp/index_param.h"
@@ -279,8 +280,13 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
         for (int i = 1; i < max_chunk; ++i) {
             chunk_rows[i] = i * vec_size_per_chunk;
         }
+        bool larger_is_closer = PositivelyRelated(info.metric_type_);
         search_result.AssembleChunkVectorIterators(
-            num_queries, max_chunk, chunk_rows, final_qr.chunk_iterators());
+            num_queries,
+            max_chunk,
+            chunk_rows,
+            final_qr.chunk_iterators(),
+            larger_is_closer);
     } else {
         if (info.array_offsets_ != nullptr) {
             auto [seg_offsets, elem_indicies] =
diff --git a/internal/core/src/query/SearchOnSealed.cpp b/internal/core/src/query/SearchOnSealed.cpp
index ded92fadbe..cfde506887 100644
--- a/internal/core/src/query/SearchOnSealed.cpp
+++ b/internal/core/src/query/SearchOnSealed.cpp
@@ -18,6 +18,7 @@
 #include "common/BitsetView.h"
 #include "common/QueryInfo.h"
 #include "common/Types.h"
+#include "common/Utils.h"
 #include "query/CachedSearchIterator.h"
 #include "query/SearchBruteForce.h"
 #include "query/SearchOnSealed.h"
@@ -237,10 +238,12 @@ SearchOnSealedColumn(const Schema& schema,
         offset += chunk_size;
     }
     if (milvus::exec::UseVectorIterator(search_info)) {
+        bool larger_is_closer = PositivelyRelated(search_info.metric_type_);
         result.AssembleChunkVectorIterators(num_queries,
                                             num_chunk,
                                             column->GetNumRowsUntilChunk(),
-                                            final_qr.chunk_iterators());
+                                            final_qr.chunk_iterators(),
+                                            larger_is_closer);
     } else {
         if (search_info.array_offsets_ != nullptr) {
             auto [seg_offsets, elem_indicies] =
diff --git a/tests/python_client/milvus_client_v2/test_milvus_client_search_group_by.py b/tests/python_client/milvus_client_v2/test_milvus_client_search_group_by.py
index 2f12484bdf..23e0e0db0e 100644
--- a/tests/python_client/milvus_client_v2/test_milvus_client_search_group_by.py
+++ b/tests/python_client/milvus_client_v2/test_milvus_client_search_group_by.py
@@ -357,3 +357,68 @@ class TestSearchGroupBy(TestcaseBase):
                                 output_fields=[grpby_field],
                                 check_task=CheckTasks.err_res,
                                 check_items={"err_code": err_code, "err_msg": err_msg})
+
+    @pytest.mark.tags(CaseLabel.L1)
+    @pytest.mark.parametrize("metric", ["L2", "IP", "COSINE"])
+    def test_search_group_by_flat_index_correctness(self, metric):
+        """
+        target: test search group by with FLAT index returns correct results
+        method: 1. create a collection with FLAT index
+                2. insert data with group_by field having multiple values per group
+                3. search with and without group_by
+                4. verify group_by search returns the best result (same as normal search top-1)
+        expected: The top result from group_by search should match the top result from normal search
+                  for the same group
+        issue: https://github.com/milvus-io/milvus/issues/46349
+        """
+        collection_w = self.init_collection_general(prefix, auto_id=True, insert_data=False, is_index=False,
+                                                    is_all_data_type=True, with_json=False)[0]
+        # create FLAT index
+        _index = {"index_type": "FLAT", "metric_type": metric, "params": {}}
+        collection_w.create_index(ct.default_float_vec_field_name, index_params=_index)
+
+        # insert data with 10 different group values, 100 records per group
+        for _ in range(10):
+            data = cf.gen_dataframe_all_data_type(nb=100, auto_id=True, with_json=False)
+            collection_w.insert(data)
+
+        collection_w.flush()
+        collection_w.load()
+
+        nq = 1
+        limit = 1
+        search_vectors = cf.gen_vectors(nq, dim=ct.default_dim)
+        grpby_field = ct.default_int32_field_name
+        search_params = {"metric_type": metric, "params": {}}
+
+        # normal search to get the best result
+        normal_res = collection_w.search(search_vectors, ct.default_float_vec_field_name,
+                                         search_params, limit,
+                                         output_fields=[grpby_field])[0]
+
+        # group_by search
+        groupby_res = collection_w.search(search_vectors, ct.default_float_vec_field_name,
+                                          search_params, limit,
+                                          group_by_field=grpby_field,
+                                          output_fields=[grpby_field])[0]
+
+        # verify that the top result from group_by search matches the normal search
+        # for the same group value, group_by should return the best (closest) result
+        normal_top_distance = normal_res[0][0].distance
+        normal_top_group = normal_res[0][0].entity.get(grpby_field)
+        groupby_top_distance = groupby_res[0][0].distance
+        groupby_top_group = groupby_res[0][0].entity.get(grpby_field)
+
+        log.info(f"Normal search top result: distance={normal_top_distance}, group={normal_top_group}")
+        log.info(f"GroupBy search top result: distance={groupby_top_distance}, group={groupby_top_group}")
+
+        # The group_by result should have the same or better distance as normal search
+        # because group_by returns the best result per group
+        if metric == "L2":
+            # For L2, smaller is better
+            assert groupby_top_distance <= normal_top_distance + epsilon, \
+                f"GroupBy search should return result with distance <= normal search for L2 metric"
+        else:
+            # For IP/COSINE, larger is better
+            assert groupby_top_distance >= normal_top_distance - epsilon, \
+                f"GroupBy search should return result with distance >= normal search for {metric} metric"
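
As a minimal standalone sketch of the ordering rule the QueryResult.h change encodes (not part of the patch itself), the C++ program below builds a merge heap whose comparator flips with the metric: for IP/COSINE a larger score is closer, for L2 a smaller distance is closer. The names Hit, HitComparator and MergeTop1 are illustrative only; the production code uses OffsetDisPair, OffsetDisPairComparator and ChunkMergeIterator, with larger_is_closer derived from PositivelyRelated(metric_type_) as in the patch.

#include <cstdint>
#include <cstdio>
#include <queue>
#include <vector>

// A single candidate: segment offset plus its distance/score to the query.
struct Hit {
    int64_t offset;
    float distance;
};

// Comparator for std::priority_queue: returns true when `a` ranks below `b`,
// so the heap keeps the "closest" hit on top for the given metric.
struct HitComparator {
    bool larger_is_closer;  // true for IP/COSINE, false for L2
    bool
    operator()(const Hit& a, const Hit& b) const {
        if (a.distance != b.distance) {
            return larger_is_closer ? a.distance < b.distance
                                    : a.distance > b.distance;
        }
        return a.offset < b.offset;  // tie-break on offset, as in the patch
    }
};

// Merge the current best hit of each chunk and return the overall best one.
Hit
MergeTop1(const std::vector<Hit>& chunk_fronts, bool larger_is_closer) {
    HitComparator cmp{larger_is_closer};
    std::priority_queue<Hit, std::vector<Hit>, HitComparator> heap(cmp);
    for (const auto& hit : chunk_fronts) {
        heap.push(hit);
    }
    return heap.top();
}

int
main() {
    // Two chunks, one candidate each: offset 0 scores 0.9, offset 100 scores 0.2.
    std::vector<Hit> fronts = {{0, 0.9f}, {100, 0.2f}};
    // IP/COSINE: larger is closer, so offset 0 must win.
    std::printf("IP best offset: %lld\n",
                static_cast<long long>(MergeTop1(fronts, true).offset));
    // L2: smaller is closer, so offset 100 must win; a comparator fixed to a
    // single direction (the pre-fix behavior) would surface the wrong hit here.
    std::printf("L2 best offset: %lld\n",
                static_cast<long long>(MergeTop1(fronts, false).offset));
    return 0;
}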