fix: fix chunk iterator merge order (#46461)
issue: #46349
When using brute-force search, the iterator results from multiple chunks are merged; the merge must account for how the metric ranks results: for IP/COSINE a larger score is closer, while for L2 a smaller distance is closer.
Signed-off-by: xianliang.li <xianliang.li@zilliz.com>
This commit is contained in:
parent 1a7ca339a5 · commit ab03521588
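Note: the following is a minimal, self-contained sketch (illustrative only, not the Milvus source) of the behavior this commit fixes. When merged candidates sit in a std::priority_queue, the comparator has to be flipped per metric so that the closest candidate is always on top: larger scores are closer for IP/COSINE, smaller distances are closer for L2. All names in the sketch are invented for the example.

    // metric_aware_heap.cpp -- illustrative sketch, not the Milvus implementation.
    #include <cstdint>
    #include <iostream>
    #include <queue>
    #include <utility>
    #include <vector>

    // (offset, distance) pair for one candidate coming out of a chunk iterator.
    using OffsetDis = std::pair<int64_t, float>;

    struct MetricAwareComparator {
        bool larger_is_closer;  // true for IP/COSINE, false for L2

        // Return true when `left` has LOWER priority than `right`, so the
        // closest candidate ends up on top of the priority_queue.
        bool operator()(const OffsetDis& left, const OffsetDis& right) const {
            if (left.second != right.second) {
                return larger_is_closer ? left.second < right.second
                                        : left.second > right.second;
            }
            return left.first < right.first;
        }
    };

    int main() {
        std::vector<OffsetDis> candidates = {{0, 0.9f}, {1, 0.1f}, {2, 0.5f}};

        // L2: the smallest distance (0.1) should be on top.
        std::priority_queue<OffsetDis, std::vector<OffsetDis>, MetricAwareComparator>
            l2_heap(MetricAwareComparator{false});
        for (const auto& c : candidates) l2_heap.push(c);
        std::cout << "L2 top offset: " << l2_heap.top().first << "\n";  // offset 1

        // IP/COSINE: the largest score (0.9) should be on top.
        std::priority_queue<OffsetDis, std::vector<OffsetDis>, MetricAwareComparator>
            ip_heap(MetricAwareComparator{true});
        for (const auto& c : candidates) ip_heap.push(c);
        std::cout << "IP top offset: " << ip_heap.top().first << "\n";  // offset 0
        return 0;
    }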
@@ -114,11 +114,25 @@ struct OffsetDisPair {
 };
 
 struct OffsetDisPairComparator {
+    bool larger_is_closer_ = false;
+
+    OffsetDisPairComparator(bool larger_is_closer = false)
+        : larger_is_closer_(larger_is_closer) {
+    }
+
     bool
     operator()(const std::shared_ptr<OffsetDisPair>& left,
                const std::shared_ptr<OffsetDisPair>& right) const {
+        // For priority_queue: return true if left has lower priority than right
+        // We want the element with better (closer) distance at the top
         if (left->GetOffDis().second != right->GetOffDis().second) {
-            return left->GetOffDis().second < right->GetOffDis().second;
+            if (larger_is_closer_) {
+                // IP/Cosine: larger distance is better, smaller has lower priority
+                return left->GetOffDis().second < right->GetOffDis().second;
+            } else {
+                // L2: smaller distance is better, larger has lower priority
+                return left->GetOffDis().second > right->GetOffDis().second;
+            }
         }
         return left->GetOffDis().first < right->GetOffDis().first;
     }
@@ -142,8 +156,11 @@ class VectorIterator {
 class ChunkMergeIterator : public VectorIterator {
  public:
     ChunkMergeIterator(int chunk_count,
-                       const std::vector<int64_t>& total_rows_until_chunk = {})
-        : total_rows_until_chunk_(total_rows_until_chunk) {
+                       const std::vector<int64_t>& total_rows_until_chunk = {},
+                       bool larger_is_closer = false)
+        : total_rows_until_chunk_(total_rows_until_chunk),
+          larger_is_closer_(larger_is_closer),
+          heap_(OffsetDisPairComparator(larger_is_closer)) {
         iterators_.reserve(chunk_count);
     }
 
@@ -215,6 +232,7 @@ class ChunkMergeIterator : public VectorIterator {
         heap_;
     bool sealed = false;
     std::vector<int64_t> total_rows_until_chunk_;
+    bool larger_is_closer_ = false;
     //currently, ChunkMergeIterator is guaranteed to be used serially without concurrent problem, in the future
     //we may need to add mutex to protect the variable sealed
 };
@@ -239,7 +257,8 @@ struct SearchResult {
         int64_t nq,
         int chunk_count,
         const std::vector<int64_t>& total_rows_until_chunk,
-        const std::vector<knowhere::IndexNode::IteratorPtr>& kw_iterators) {
+        const std::vector<knowhere::IndexNode::IteratorPtr>& kw_iterators,
+        bool larger_is_closer = false) {
         AssertInfo(kw_iterators.size() == nq * chunk_count,
                    "kw_iterators count:{} is not equal to nq*chunk_count:{}, "
                    "wrong state",
@@ -251,7 +270,7 @@ struct SearchResult {
             vec_iter_idx = vec_iter_idx % nq;
             if (vector_iterators.size() < nq) {
                 auto chunk_merge_iter = std::make_shared<ChunkMergeIterator>(
-                    chunk_count, total_rows_until_chunk);
+                    chunk_count, total_rows_until_chunk, larger_is_closer);
                 vector_iterators.emplace_back(chunk_merge_iter);
             }
             const auto& kw_iterator = kw_iterators[i];
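To make the role of ChunkMergeIterator concrete, here is a hedged, stand-alone sketch (not the Milvus implementation) of a k-way merge over per-chunk candidate streams: each stream is assumed to be ordered best-first already, and a metric-aware heap picks the globally closest candidate next. The MergeChunks function and all types below are invented for illustration.

    // kway_merge_sketch.cpp -- illustrative only.
    #include <cstdint>
    #include <queue>
    #include <utility>
    #include <vector>

    using OffsetDis = std::pair<int64_t, float>;

    std::vector<OffsetDis>
    MergeChunks(const std::vector<std::vector<OffsetDis>>& chunks,
                bool larger_is_closer) {
        // "Is distance a closer than distance b?" for the active metric.
        auto better = [larger_is_closer](float a, float b) {
            return larger_is_closer ? a > b : a < b;
        };
        // Heap entry: (chunk index, position within that chunk).
        using Cursor = std::pair<size_t, size_t>;
        auto cmp = [&](const Cursor& l, const Cursor& r) {
            // `l` has lower priority when `r`'s candidate is closer.
            return better(chunks[r.first][r.second].second,
                          chunks[l.first][l.second].second);
        };
        std::priority_queue<Cursor, std::vector<Cursor>, decltype(cmp)> heap(cmp);
        for (size_t i = 0; i < chunks.size(); ++i) {
            if (!chunks[i].empty()) heap.push({i, 0});
        }
        std::vector<OffsetDis> merged;
        while (!heap.empty()) {
            auto [chunk, pos] = heap.top();
            heap.pop();
            merged.push_back(chunks[chunk][pos]);
            if (pos + 1 < chunks[chunk].size()) heap.push({chunk, pos + 1});
        }
        return merged;
    }

With larger_is_closer=false the heap acts as a min-heap on distance (L2); with true it acts as a max-heap on score (IP/COSINE), which is the same flip the new OffsetDisPairComparator performs above.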
@@ -54,8 +54,10 @@ PrepareVectorIteratorsFromIndex(const SearchInfo& search_info,
             iterators_val =
                 index.VectorIterators(dataset, search_conf, bitset);
         if (iterators_val.has_value()) {
+            bool larger_is_closer =
+                PositivelyRelated(search_info.metric_type_);
             search_result.AssembleChunkVectorIterators(
-                nq, 1, {0}, iterators_val.value());
+                nq, 1, {0}, iterators_val.value(), larger_is_closer);
         } else {
             std::string operator_type = "";
             if (search_info.group_by_field_id_.has_value()) {
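PositivelyRelated(metric_type), the helper that the common/Utils.h includes added below appear to bring in, answers whether a larger score means a closer match for the given metric. A simplified stand-in, assuming only the IP/COSINE vs. L2 distinction matters for this fix, might look like:

    // Illustrative stand-in for PositivelyRelated(); the real helper lives in
    // the Milvus common utilities and covers more metric types.
    #include <string>

    static bool PositivelyRelatedSketch(const std::string& metric_type) {
        // IP and COSINE are similarity metrics: bigger is better.
        // L2 is a distance metric: smaller is better.
        return metric_type == "IP" || metric_type == "COSINE";
    }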
@@ -13,6 +13,7 @@
 #include "common/QueryInfo.h"
 #include "common/Tracer.h"
 #include "common/Types.h"
+#include "common/Utils.h"
 #include "SearchOnGrowing.h"
 #include <cstddef>
 #include "knowhere/comp/index_param.h"
@@ -279,8 +280,13 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
             for (int i = 1; i < max_chunk; ++i) {
                 chunk_rows[i] = i * vec_size_per_chunk;
             }
+            bool larger_is_closer = PositivelyRelated(info.metric_type_);
             search_result.AssembleChunkVectorIterators(
-                num_queries, max_chunk, chunk_rows, final_qr.chunk_iterators());
+                num_queries,
+                max_chunk,
+                chunk_rows,
+                final_qr.chunk_iterators(),
+                larger_is_closer);
         } else {
             if (info.array_offsets_ != nullptr) {
                 auto [seg_offsets, elem_indicies] =
@@ -18,6 +18,7 @@
 #include "common/BitsetView.h"
 #include "common/QueryInfo.h"
 #include "common/Types.h"
+#include "common/Utils.h"
 #include "query/CachedSearchIterator.h"
 #include "query/SearchBruteForce.h"
 #include "query/SearchOnSealed.h"
@@ -237,10 +238,12 @@ SearchOnSealedColumn(const Schema& schema,
         offset += chunk_size;
     }
     if (milvus::exec::UseVectorIterator(search_info)) {
+        bool larger_is_closer = PositivelyRelated(search_info.metric_type_);
         result.AssembleChunkVectorIterators(num_queries,
                                             num_chunk,
                                             column->GetNumRowsUntilChunk(),
-                                            final_qr.chunk_iterators());
+                                            final_qr.chunk_iterators(),
+                                            larger_is_closer);
     } else {
         if (search_info.array_offsets_ != nullptr) {
             auto [seg_offsets, elem_indicies] =
@@ -357,3 +357,68 @@ class TestSearchGroupBy(TestcaseBase):
                              output_fields=[grpby_field],
                              check_task=CheckTasks.err_res,
                              check_items={"err_code": err_code, "err_msg": err_msg})
+
+    @pytest.mark.tags(CaseLabel.L1)
+    @pytest.mark.parametrize("metric", ["L2", "IP", "COSINE"])
+    def test_search_group_by_flat_index_correctness(self, metric):
+        """
+        target: test search group by with FLAT index returns correct results
+        method: 1. create a collection with FLAT index
+                2. insert data with group_by field having multiple values per group
+                3. search with and without group_by
+                4. verify group_by search returns the best result (same as normal search top-1)
+        expected: The top result from group_by search should match the top result from normal search
+                  for the same group
+        issue: https://github.com/milvus-io/milvus/issues/46349
+        """
+        collection_w = self.init_collection_general(prefix, auto_id=True, insert_data=False, is_index=False,
+                                                    is_all_data_type=True, with_json=False)[0]
+        # create FLAT index
+        _index = {"index_type": "FLAT", "metric_type": metric, "params": {}}
+        collection_w.create_index(ct.default_float_vec_field_name, index_params=_index)
+
+        # insert data with 10 different group values, 100 records per group
+        for _ in range(10):
+            data = cf.gen_dataframe_all_data_type(nb=100, auto_id=True, with_json=False)
+            collection_w.insert(data)
+
+        collection_w.flush()
+        collection_w.load()
+
+        nq = 1
+        limit = 1
+        search_vectors = cf.gen_vectors(nq, dim=ct.default_dim)
+        grpby_field = ct.default_int32_field_name
+        search_params = {"metric_type": metric, "params": {}}
+
+        # normal search to get the best result
+        normal_res = collection_w.search(search_vectors, ct.default_float_vec_field_name,
+                                         search_params, limit,
+                                         output_fields=[grpby_field])[0]
+
+        # group_by search
+        groupby_res = collection_w.search(search_vectors, ct.default_float_vec_field_name,
+                                          search_params, limit,
+                                          group_by_field=grpby_field,
+                                          output_fields=[grpby_field])[0]
+
+        # verify that the top result from group_by search matches the normal search
+        # for the same group value, group_by should return the best (closest) result
+        normal_top_distance = normal_res[0][0].distance
+        normal_top_group = normal_res[0][0].entity.get(grpby_field)
+        groupby_top_distance = groupby_res[0][0].distance
+        groupby_top_group = groupby_res[0][0].entity.get(grpby_field)
+
+        log.info(f"Normal search top result: distance={normal_top_distance}, group={normal_top_group}")
+        log.info(f"GroupBy search top result: distance={groupby_top_distance}, group={groupby_top_group}")
+
+        # The group_by result should have the same or better distance as normal search
+        # because group_by returns the best result per group
+        if metric == "L2":
+            # For L2, smaller is better
+            assert groupby_top_distance <= normal_top_distance + epsilon, \
+                f"GroupBy search should return result with distance <= normal search for L2 metric"
+        else:
+            # For IP/COSINE, larger is better
+            assert groupby_top_distance >= normal_top_distance - epsilon, \
+                f"GroupBy search should return result with distance >= normal search for {metric} metric"