fix: fix chunk iterator merge order (#46461)

issue: #46349 
When using brute-force search, the iterator results from multiple chunks
are merged; the merge order must account for the metric type, since a
larger distance is better for IP/COSINE while a smaller distance is better for L2.
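
A minimal standalone sketch of the ordering rule (hypothetical names, not the actual Milvus types from this patch): a std::priority_queue comparator that flips direction depending on whether the metric is larger-is-closer (IP/COSINE) or smaller-is-closer (L2), so the closest candidate across chunks always sits at the top of the merge heap.

#include <cstdint>
#include <iostream>
#include <queue>
#include <utility>
#include <vector>

// Hypothetical illustration: each heap entry is an (offset, distance) pair
// produced by one chunk iterator.
using OffDis = std::pair<int64_t, float>;

struct MetricAwareCmp {
    bool larger_is_closer = false;  // true for IP/COSINE, false for L2
    bool operator()(const OffDis& l, const OffDis& r) const {
        // std::priority_queue keeps the "largest" element (per this comparator)
        // on top, so return true when l should rank below r.
        if (l.second != r.second) {
            return larger_is_closer ? l.second < r.second   // IP/COSINE: largest distance on top
                                    : l.second > r.second;  // L2: smallest distance on top
        }
        return l.first < r.first;  // tie-break on offset, mirroring the patch
    }
};

int main() {
    // With larger_is_closer = false (L2), the heap surfaces the smallest distance first.
    MetricAwareCmp cmp{false};
    std::priority_queue<OffDis, std::vector<OffDis>, MetricAwareCmp> heap(cmp);
    heap.push({3, 0.9f});
    heap.push({7, 0.1f});
    heap.push({5, 0.4f});
    std::cout << heap.top().first << " " << heap.top().second << "\n";  // prints "7 0.1"
}

This mirrors the direction the patch gives OffsetDisPairComparator; at the call sites in the diff, PositivelyRelated(metric_type) is what decides larger_is_closer.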

Signed-off-by: xianliang.li <xianliang.li@zilliz.com>
foxspy 2025-12-23 10:33:17 +08:00 committed by GitHub
parent 1a7ca339a5
commit ab03521588
5 changed files with 103 additions and 8 deletions


@@ -114,11 +114,25 @@ struct OffsetDisPair {
};
struct OffsetDisPairComparator {
bool larger_is_closer_ = false;
OffsetDisPairComparator(bool larger_is_closer = false)
: larger_is_closer_(larger_is_closer) {
}
bool
operator()(const std::shared_ptr<OffsetDisPair>& left,
const std::shared_ptr<OffsetDisPair>& right) const {
// For priority_queue: return true if left has lower priority than right
// We want the element with better (closer) distance at the top
if (left->GetOffDis().second != right->GetOffDis().second) {
return left->GetOffDis().second < right->GetOffDis().second;
if (larger_is_closer_) {
// IP/Cosine: larger distance is better, smaller has lower priority
return left->GetOffDis().second < right->GetOffDis().second;
} else {
// L2: smaller distance is better, larger has lower priority
return left->GetOffDis().second > right->GetOffDis().second;
}
}
return left->GetOffDis().first < right->GetOffDis().first;
}
@@ -142,8 +156,11 @@ class VectorIterator {
class ChunkMergeIterator : public VectorIterator {
public:
ChunkMergeIterator(int chunk_count,
const std::vector<int64_t>& total_rows_until_chunk = {})
: total_rows_until_chunk_(total_rows_until_chunk) {
const std::vector<int64_t>& total_rows_until_chunk = {},
bool larger_is_closer = false)
: total_rows_until_chunk_(total_rows_until_chunk),
larger_is_closer_(larger_is_closer),
heap_(OffsetDisPairComparator(larger_is_closer)) {
iterators_.reserve(chunk_count);
}
@@ -215,6 +232,7 @@ class ChunkMergeIterator : public VectorIterator {
heap_;
bool sealed = false;
std::vector<int64_t> total_rows_until_chunk_;
bool larger_is_closer_ = false;
//currently, ChunkMergeIterator is guaranteed to be used serially without concurrent problem, in the future
//we may need to add mutex to protect the variable sealed
};
@@ -239,7 +257,8 @@ struct SearchResult {
int64_t nq,
int chunk_count,
const std::vector<int64_t>& total_rows_until_chunk,
const std::vector<knowhere::IndexNode::IteratorPtr>& kw_iterators) {
const std::vector<knowhere::IndexNode::IteratorPtr>& kw_iterators,
bool larger_is_closer = false) {
AssertInfo(kw_iterators.size() == nq * chunk_count,
"kw_iterators count:{} is not equal to nq*chunk_count:{}, "
"wrong state",
@@ -251,7 +270,7 @@ struct SearchResult {
vec_iter_idx = vec_iter_idx % nq;
if (vector_iterators.size() < nq) {
auto chunk_merge_iter = std::make_shared<ChunkMergeIterator>(
chunk_count, total_rows_until_chunk);
chunk_count, total_rows_until_chunk, larger_is_closer);
vector_iterators.emplace_back(chunk_merge_iter);
}
const auto& kw_iterator = kw_iterators[i];


@@ -54,8 +54,10 @@ PrepareVectorIteratorsFromIndex(const SearchInfo& search_info,
iterators_val =
index.VectorIterators(dataset, search_conf, bitset);
if (iterators_val.has_value()) {
bool larger_is_closer =
PositivelyRelated(search_info.metric_type_);
search_result.AssembleChunkVectorIterators(
nq, 1, {0}, iterators_val.value());
nq, 1, {0}, iterators_val.value(), larger_is_closer);
} else {
std::string operator_type = "";
if (search_info.group_by_field_id_.has_value()) {


@@ -13,6 +13,7 @@
#include "common/QueryInfo.h"
#include "common/Tracer.h"
#include "common/Types.h"
#include "common/Utils.h"
#include "SearchOnGrowing.h"
#include <cstddef>
#include "knowhere/comp/index_param.h"
@@ -279,8 +280,13 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
for (int i = 1; i < max_chunk; ++i) {
chunk_rows[i] = i * vec_size_per_chunk;
}
bool larger_is_closer = PositivelyRelated(info.metric_type_);
search_result.AssembleChunkVectorIterators(
num_queries, max_chunk, chunk_rows, final_qr.chunk_iterators());
num_queries,
max_chunk,
chunk_rows,
final_qr.chunk_iterators(),
larger_is_closer);
} else {
if (info.array_offsets_ != nullptr) {
auto [seg_offsets, elem_indicies] =


@@ -18,6 +18,7 @@
#include "common/BitsetView.h"
#include "common/QueryInfo.h"
#include "common/Types.h"
#include "common/Utils.h"
#include "query/CachedSearchIterator.h"
#include "query/SearchBruteForce.h"
#include "query/SearchOnSealed.h"
@@ -237,10 +238,12 @@ SearchOnSealedColumn(const Schema& schema,
offset += chunk_size;
}
if (milvus::exec::UseVectorIterator(search_info)) {
bool larger_is_closer = PositivelyRelated(search_info.metric_type_);
result.AssembleChunkVectorIterators(num_queries,
num_chunk,
column->GetNumRowsUntilChunk(),
final_qr.chunk_iterators());
final_qr.chunk_iterators(),
larger_is_closer);
} else {
if (search_info.array_offsets_ != nullptr) {
auto [seg_offsets, elem_indicies] =


@@ -357,3 +357,68 @@ class TestSearchGroupBy(TestcaseBase):
output_fields=[grpby_field],
check_task=CheckTasks.err_res,
check_items={"err_code": err_code, "err_msg": err_msg})
@pytest.mark.tags(CaseLabel.L1)
@pytest.mark.parametrize("metric", ["L2", "IP", "COSINE"])
def test_search_group_by_flat_index_correctness(self, metric):
"""
target: test search group by with FLAT index returns correct results
method: 1. create a collection with FLAT index
2. insert data with group_by field having multiple values per group
3. search with and without group_by
4. verify group_by search returns the best result (same as normal search top-1)
expected: The top result from group_by search should match the top result from normal search
for the same group
issue: https://github.com/milvus-io/milvus/issues/46349
"""
collection_w = self.init_collection_general(prefix, auto_id=True, insert_data=False, is_index=False,
is_all_data_type=True, with_json=False)[0]
# create FLAT index
_index = {"index_type": "FLAT", "metric_type": metric, "params": {}}
collection_w.create_index(ct.default_float_vec_field_name, index_params=_index)
# insert data with 10 different group values, 100 records per group
for _ in range(10):
data = cf.gen_dataframe_all_data_type(nb=100, auto_id=True, with_json=False)
collection_w.insert(data)
collection_w.flush()
collection_w.load()
nq = 1
limit = 1
search_vectors = cf.gen_vectors(nq, dim=ct.default_dim)
grpby_field = ct.default_int32_field_name
search_params = {"metric_type": metric, "params": {}}
# normal search to get the best result
normal_res = collection_w.search(search_vectors, ct.default_float_vec_field_name,
search_params, limit,
output_fields=[grpby_field])[0]
# group_by search
groupby_res = collection_w.search(search_vectors, ct.default_float_vec_field_name,
search_params, limit,
group_by_field=grpby_field,
output_fields=[grpby_field])[0]
# verify that the top result from group_by search matches the normal search
# for the same group value, group_by should return the best (closest) result
normal_top_distance = normal_res[0][0].distance
normal_top_group = normal_res[0][0].entity.get(grpby_field)
groupby_top_distance = groupby_res[0][0].distance
groupby_top_group = groupby_res[0][0].entity.get(grpby_field)
log.info(f"Normal search top result: distance={normal_top_distance}, group={normal_top_group}")
log.info(f"GroupBy search top result: distance={groupby_top_distance}, group={groupby_top_group}")
# The group_by result should have the same or better distance as normal search
# because group_by returns the best result per group
if metric == "L2":
# For L2, smaller is better
assert groupby_top_distance <= normal_top_distance + epsilon, \
f"GroupBy search should return result with distance <= normal search for L2 metric"
else:
# For IP/COSINE, larger is better
assert groupby_top_distance >= normal_top_distance - epsilon, \
f"GroupBy search should return result with distance >= normal search for {metric} metric"