diff --git a/internal/core/src/common/QueryResult.h b/internal/core/src/common/QueryResult.h index 9fd2d13d77..4cb7fef00e 100644 --- a/internal/core/src/common/QueryResult.h +++ b/internal/core/src/common/QueryResult.h @@ -228,6 +228,7 @@ struct RetrieveResult { void* segment_; std::vector result_offsets_; std::vector field_data_; + bool has_more_result = true; }; using RetrieveResultPtr = std::shared_ptr; diff --git a/internal/core/src/query/visitors/ExecPlanNodeVisitor.cpp b/internal/core/src/query/visitors/ExecPlanNodeVisitor.cpp index e5a7a8f1c7..d9e8a6c125 100644 --- a/internal/core/src/query/visitors/ExecPlanNodeVisitor.cpp +++ b/internal/core/src/query/visitors/ExecPlanNodeVisitor.cpp @@ -291,8 +291,10 @@ ExecPlanNodeVisitor::visit(RetrievePlanNode& node) { false_filtered_out = true; segment->timestamp_filter(bitset_holder, timestamp_); } - retrieve_result.result_offsets_ = + auto results_pair = segment->find_first(node.limit_, bitset_holder, false_filtered_out); + retrieve_result.result_offsets_ = std::move(results_pair.first); + retrieve_result.has_more_result = results_pair.second; retrieve_result_opt_ = std::move(retrieve_result); } diff --git a/internal/core/src/segcore/InsertRecord.h b/internal/core/src/segcore/InsertRecord.h index 7e85a64c23..7da03c1828 100644 --- a/internal/core/src/segcore/InsertRecord.h +++ b/internal/core/src/segcore/InsertRecord.h @@ -60,7 +60,7 @@ class OffsetMap { using OffsetType = int64_t; // TODO: in fact, we can retrieve the pk here. Not sure which way is more efficient. 
- virtual std::vector + virtual std::pair, bool> find_first(int64_t limit, const BitsetType& bitset, bool false_filtered_out) const = 0; @@ -109,7 +109,7 @@ class OffsetOrderedMap : public OffsetMap { return map_.empty(); } - std::vector + std::pair, bool> find_first(int64_t limit, const BitsetType& bitset, bool false_filtered_out) const override { @@ -131,7 +131,7 @@ class OffsetOrderedMap : public OffsetMap { } private: - std::vector + std::pair, bool> find_first_by_index(int64_t limit, const BitsetType& bitset, bool false_filtered_out) const { @@ -144,8 +144,8 @@ class OffsetOrderedMap : public OffsetMap { limit = std::min(limit, cnt); std::vector seg_offsets; seg_offsets.reserve(limit); - for (auto it = map_.begin(); hit_num < limit && it != map_.end(); - it++) { + auto it = map_.begin(); + for (; hit_num < limit && it != map_.end(); it++) { for (auto seg_offset : it->second) { if (seg_offset >= size) { // Frequently concurrent insert/query will cause this case. @@ -161,7 +161,7 @@ class OffsetOrderedMap : public OffsetMap { } } } - return seg_offsets; + return {seg_offsets, it != map_.end()}; } private: @@ -226,7 +226,7 @@ class OffsetOrderedArray : public OffsetMap { return array_.empty(); } - std::vector + std::pair, bool> find_first(int64_t limit, const BitsetType& bitset, bool false_filtered_out) const override { @@ -248,7 +248,7 @@ class OffsetOrderedArray : public OffsetMap { } private: - std::vector + std::pair, bool> find_first_by_index(int64_t limit, const BitsetType& bitset, bool false_filtered_out) const { @@ -261,11 +261,11 @@ class OffsetOrderedArray : public OffsetMap { limit = std::min(limit, cnt); std::vector seg_offsets; seg_offsets.reserve(limit); - for (auto it = array_.begin(); hit_num < limit && it != array_.end(); - it++) { + auto it = array_.begin(); + for (; hit_num < limit && it != array_.end(); it++) { auto seg_offset = it->second; if (seg_offset >= size) { - // In fact, this case won't happend on sealed segments. 
+ // In fact, this case won't happen on sealed segments. continue; } @@ -274,7 +274,7 @@ class OffsetOrderedArray : public OffsetMap { hit_num++; } } - return seg_offsets; + return {seg_offsets, it != array_.end()}; } void diff --git a/internal/core/src/segcore/SegmentGrowingImpl.h b/internal/core/src/segcore/SegmentGrowingImpl.h index 1cc308216b..06f9048d5a 100644 --- a/internal/core/src/segcore/SegmentGrowingImpl.h +++ b/internal/core/src/segcore/SegmentGrowingImpl.h @@ -268,7 +268,7 @@ class SegmentGrowingImpl : public SegmentGrowing { return true; } - std::vector + std::pair, bool> find_first(int64_t limit, const BitsetType& bitset, bool false_filtered_out) const override { diff --git a/internal/core/src/segcore/SegmentInterface.cpp b/internal/core/src/segcore/SegmentInterface.cpp index 3d79fc0b35..0ad695ff59 100644 --- a/internal/core/src/segcore/SegmentInterface.cpp +++ b/internal/core/src/segcore/SegmentInterface.cpp @@ -91,6 +91,7 @@ SegmentInternalInterface::Retrieve(tracer::TraceContext* trace_ctx, query::ExecPlanNodeVisitor visitor(*this, timestamp); auto retrieve_results = visitor.get_retrieve_result(*plan->plan_node_); retrieve_results.segment_ = (void*)this; + results->set_has_more_result(retrieve_results.has_more_result); auto result_rows = retrieve_results.result_offsets_.size(); int64_t output_data_size = 0; @@ -120,7 +121,6 @@ SegmentInternalInterface::Retrieve(tracer::TraceContext* trace_ctx, retrieve_results.result_offsets_.size(), ignore_non_pk, true); - return results; } diff --git a/internal/core/src/segcore/SegmentInterface.h b/internal/core/src/segcore/SegmentInterface.h index 2715e387c7..6a2dbf1485 100644 --- a/internal/core/src/segcore/SegmentInterface.h +++ b/internal/core/src/segcore/SegmentInterface.h @@ -290,7 +290,7 @@ class SegmentInternalInterface : public SegmentInterface { * @param false_filtered_out * @return All candidates offsets. 
*/ - virtual std::vector + virtual std::pair, bool> find_first(int64_t limit, const BitsetType& bitset, bool false_filtered_out) const = 0; diff --git a/internal/core/src/segcore/SegmentSealedImpl.h b/internal/core/src/segcore/SegmentSealedImpl.h index 21306616e8..b7e8b89e2c 100644 --- a/internal/core/src/segcore/SegmentSealedImpl.h +++ b/internal/core/src/segcore/SegmentSealedImpl.h @@ -133,7 +133,7 @@ class SegmentSealedImpl : public SegmentSealed { const IdArray* pks, const Timestamp* timestamps) override; - std::vector + std::pair, bool> find_first(int64_t limit, const BitsetType& bitset, bool false_filtered_out) const override { diff --git a/internal/core/unittest/test_offset_ordered_array.cpp b/internal/core/unittest/test_offset_ordered_array.cpp index ec371c6114..1eb2e272b0 100644 --- a/internal/core/unittest/test_offset_ordered_array.cpp +++ b/internal/core/unittest/test_offset_ordered_array.cpp @@ -65,8 +65,6 @@ using TypeOfPks = testing::Types; TYPED_TEST_SUITE_P(TypedOffsetOrderedArrayTest); TYPED_TEST_P(TypedOffsetOrderedArrayTest, find_first) { - std::vector offsets; - // not sealed. ASSERT_ANY_THROW(this->map_.find_first(Unlimited, {}, true)); @@ -81,40 +79,62 @@ TYPED_TEST_P(TypedOffsetOrderedArrayTest, find_first) { this->seal(); // all is satisfied. 
- BitsetType all(num); - all.set(); - offsets = this->map_.find_first(num / 2, all, true); - ASSERT_EQ(num / 2, offsets.size()); - for (int i = 1; i < offsets.size(); i++) { - ASSERT_TRUE(data[offsets[i - 1]] <= data[offsets[i]]); + { + BitsetType all(num); + all.set(); + { + auto [offsets, has_more_res] = + this->map_.find_first(num / 2, all, true); + ASSERT_EQ(num / 2, offsets.size()); + ASSERT_TRUE(has_more_res); + for (int i = 1; i < offsets.size(); i++) { + ASSERT_TRUE(data[offsets[i - 1]] <= data[offsets[i]]); + } + } + { + auto [offsets, has_more_res] = + this->map_.find_first(Unlimited, all, true); + ASSERT_EQ(num, offsets.size()); + ASSERT_FALSE(has_more_res); + for (int i = 1; i < offsets.size(); i++) { + ASSERT_TRUE(data[offsets[i - 1]] <= data[offsets[i]]); + } + } } - offsets = this->map_.find_first(Unlimited, all, true); - ASSERT_EQ(num, offsets.size()); - for (int i = 1; i < offsets.size(); i++) { - ASSERT_TRUE(data[offsets[i - 1]] <= data[offsets[i]]); + { + // corner case, segment offset exceeds the size of bitset. + BitsetType all_minus_1(num - 1); + all_minus_1.set(); + { + auto [offsets, has_more_res] = + this->map_.find_first(num / 2, all_minus_1, true); + ASSERT_EQ(num / 2, offsets.size()); + ASSERT_TRUE(has_more_res); + for (int i = 1; i < offsets.size(); i++) { + ASSERT_TRUE(data[offsets[i - 1]] <= data[offsets[i]]); + } + } + { + auto [offsets, has_more_res] = + this->map_.find_first(Unlimited, all_minus_1, true); + ASSERT_EQ(all_minus_1.size(), offsets.size()); + ASSERT_FALSE(has_more_res); + for (int i = 1; i < offsets.size(); i++) { + ASSERT_TRUE(data[offsets[i - 1]] <= data[offsets[i]]); + } + } } - - // corner case, segment offset exceeds the size of bitset. 
- BitsetType all_minus_1(num - 1); - all_minus_1.set(); - offsets = this->map_.find_first(num / 2, all_minus_1, true); - ASSERT_EQ(num / 2, offsets.size()); - for (int i = 1; i < offsets.size(); i++) { - ASSERT_TRUE(data[offsets[i - 1]] <= data[offsets[i]]); + { + // none is satisfied. + BitsetType none(num); + none.reset(); + auto result_pair = this->map_.find_first(num / 2, none, true); + ASSERT_EQ(0, result_pair.first.size()); + ASSERT_TRUE(result_pair.second); + result_pair = this->map_.find_first(NoLimit, none, true); + ASSERT_EQ(0, result_pair.first.size()); + ASSERT_TRUE(result_pair.second); } - offsets = this->map_.find_first(Unlimited, all_minus_1, true); - ASSERT_EQ(all_minus_1.size(), offsets.size()); - for (int i = 1; i < offsets.size(); i++) { - ASSERT_TRUE(data[offsets[i - 1]] <= data[offsets[i]]); - } - - // none is satisfied. - BitsetType none(num); - none.reset(); - offsets = this->map_.find_first(num / 2, none, true); - ASSERT_EQ(0, offsets.size()); - offsets = this->map_.find_first(NoLimit, none, true); - ASSERT_EQ(0, offsets.size()); } REGISTER_TYPED_TEST_SUITE_P(TypedOffsetOrderedArrayTest, find_first); diff --git a/internal/core/unittest/test_offset_ordered_map.cpp b/internal/core/unittest/test_offset_ordered_map.cpp index be16aed9e0..36f4bafc83 100644 --- a/internal/core/unittest/test_offset_ordered_map.cpp +++ b/internal/core/unittest/test_offset_ordered_map.cpp @@ -60,12 +60,13 @@ using TypeOfPks = testing::Types; TYPED_TEST_SUITE_P(TypedOffsetOrderedMapTest); TYPED_TEST_P(TypedOffsetOrderedMapTest, find_first) { - std::vector offsets; - // no data. - offsets = this->map_.find_first(Unlimited, {}, true); - ASSERT_EQ(0, offsets.size()); - + { + auto [offsets, has_more_res] = + this->map_.find_first(Unlimited, {}, true); + ASSERT_EQ(0, offsets.size()); + ASSERT_FALSE(has_more_res); + } // insert 10 entities. 
int num = 10; auto data = this->random_generate(num); @@ -76,38 +77,63 @@ TYPED_TEST_P(TypedOffsetOrderedMapTest, find_first) { // all is satisfied. BitsetType all(num); all.set(); - offsets = this->map_.find_first(num / 2, all, true); - ASSERT_EQ(num / 2, offsets.size()); - for (int i = 1; i < offsets.size(); i++) { - ASSERT_TRUE(data[offsets[i - 1]] <= data[offsets[i]]); + + { + auto [offsets, has_more_res] = + this->map_.find_first(num / 2, all, true); + ASSERT_EQ(num / 2, offsets.size()); + ASSERT_TRUE(has_more_res); + for (int i = 1; i < offsets.size(); i++) { + ASSERT_TRUE(data[offsets[i - 1]] <= data[offsets[i]]); + } } - offsets = this->map_.find_first(Unlimited, all, true); - ASSERT_EQ(num, offsets.size()); - for (int i = 1; i < offsets.size(); i++) { - ASSERT_TRUE(data[offsets[i - 1]] <= data[offsets[i]]); + { + auto [offsets, has_more_res] = + this->map_.find_first(Unlimited, all, true); + ASSERT_EQ(num, offsets.size()); + ASSERT_FALSE(has_more_res); + for (int i = 1; i < offsets.size(); i++) { + ASSERT_TRUE(data[offsets[i - 1]] <= data[offsets[i]]); + } } // corner case, segment offset exceeds the size of bitset. 
BitsetType all_minus_1(num - 1); all_minus_1.set(); - offsets = this->map_.find_first(num / 2, all_minus_1, true); - ASSERT_EQ(num / 2, offsets.size()); - for (int i = 1; i < offsets.size(); i++) { - ASSERT_TRUE(data[offsets[i - 1]] <= data[offsets[i]]); + { + auto [offsets, has_more_res] = + this->map_.find_first(num / 2, all_minus_1, true); + ASSERT_EQ(num / 2, offsets.size()); + ASSERT_TRUE(has_more_res); + for (int i = 1; i < offsets.size(); i++) { + ASSERT_TRUE(data[offsets[i - 1]] <= data[offsets[i]]); + } } - offsets = this->map_.find_first(Unlimited, all_minus_1, true); - ASSERT_EQ(all_minus_1.size(), offsets.size()); - for (int i = 1; i < offsets.size(); i++) { - ASSERT_TRUE(data[offsets[i - 1]] <= data[offsets[i]]); + { + auto [offsets, has_more_res] = + this->map_.find_first(Unlimited, all_minus_1, true); + ASSERT_EQ(all_minus_1.size(), offsets.size()); + ASSERT_FALSE(has_more_res); + for (int i = 1; i < offsets.size(); i++) { + ASSERT_TRUE(data[offsets[i - 1]] <= data[offsets[i]]); + } } // none is satisfied. 
BitsetType none(num); none.reset(); - offsets = this->map_.find_first(num / 2, none, true); - ASSERT_EQ(0, offsets.size()); - offsets = this->map_.find_first(NoLimit, none, true); - ASSERT_EQ(0, offsets.size()); + { + auto [offsets, has_more_res] = + this->map_.find_first(num / 2, none, true); + ASSERT_TRUE(has_more_res); + ASSERT_EQ(0, offsets.size()); + } + { + auto [offsets, has_more_res] = + this->map_.find_first(NoLimit, none, true); + ASSERT_TRUE(has_more_res); + ASSERT_EQ(0, offsets.size()); + } } REGISTER_TYPED_TEST_SUITE_P(TypedOffsetOrderedMapTest, find_first); diff --git a/internal/proto/internal.proto b/internal/proto/internal.proto index 6715af58d9..980cf35769 100644 --- a/internal/proto/internal.proto +++ b/internal/proto/internal.proto @@ -198,6 +198,7 @@ message RetrieveResults { // query request cost CostAggregation costAggregation = 13; int64 all_retrieve_count = 14; + bool has_more_result = 15; } message LoadIndex { diff --git a/internal/proto/segcore.proto b/internal/proto/segcore.proto index ea7697f48c..aaf502bc1e 100644 --- a/internal/proto/segcore.proto +++ b/internal/proto/segcore.proto @@ -10,6 +10,7 @@ message RetrieveResults { repeated int64 offset = 2; repeated schema.FieldData fields_data = 3; int64 all_retrieve_count = 4; + bool has_more_result = 5; } message LoadFieldMeta { diff --git a/internal/proxy/task_query.go b/internal/proxy/task_query.go index 18abb6ed05..212015b440 100644 --- a/internal/proxy/task_query.go +++ b/internal/proxy/task_query.go @@ -607,9 +607,9 @@ func reduceRetrieveResults(ctx context.Context, retrieveResults []*internalpb.Re idSet := make(map[interface{}]struct{}) cursors := make([]int64, len(validRetrieveResults)) - retrieveLimit := typeutil.Unlimited if queryParams != nil && queryParams.limit != typeutil.Unlimited { - retrieveLimit = queryParams.limit + queryParams.offset + // reduceStopForBest will try to get as many results as possible + // so loopEnd in this case will be set to the sum of all results' size 
if !queryParams.reduceStopForBest { loopEnd = int(queryParams.limit) } @@ -618,7 +618,7 @@ func reduceRetrieveResults(ctx context.Context, retrieveResults []*internalpb.Re // handle offset if queryParams != nil && queryParams.offset > 0 { for i := int64(0); i < queryParams.offset; i++ { - sel, drainOneResult := typeutil.SelectMinPK(retrieveLimit, validRetrieveResults, cursors) + sel, drainOneResult := typeutil.SelectMinPK(validRetrieveResults, cursors) if sel == -1 || (queryParams.reduceStopForBest && drainOneResult) { return ret, nil } @@ -626,16 +626,11 @@ func reduceRetrieveResults(ctx context.Context, retrieveResults []*internalpb.Re } } - reduceStopForBest := false - if queryParams != nil { - reduceStopForBest = queryParams.reduceStopForBest - } - var retSize int64 maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64() for j := 0; j < loopEnd; { - sel, drainOneResult := typeutil.SelectMinPK(retrieveLimit, validRetrieveResults, cursors) - if sel == -1 || (reduceStopForBest && drainOneResult) { + sel, drainOneResult := typeutil.SelectMinPK(validRetrieveResults, cursors) + if sel == -1 || (queryParams.reduceStopForBest && drainOneResult) { break } diff --git a/internal/proxy/task_query_test.go b/internal/proxy/task_query_test.go index 5112b53ac2..9b62b9ece5 100644 --- a/internal/proxy/task_query_test.go +++ b/internal/proxy/task_query_test.go @@ -479,8 +479,7 @@ func TestTaskQuery_functions(t *testing.T) { }, FieldsData: fieldDataArray2, } - - result, err := reduceRetrieveResults(context.Background(), []*internalpb.RetrieveResults{result1, result2}, nil) + result, err := reduceRetrieveResults(context.Background(), []*internalpb.RetrieveResults{result1, result2}, &queryParams{limit: 2}) assert.NoError(t, err) assert.Equal(t, 2, len(result.GetFieldsData())) assert.Equal(t, Int64Array, result.GetFieldsData()[0].GetScalars().GetLongData().Data) @@ -488,7 +487,7 @@ func TestTaskQuery_functions(t *testing.T) { }) t.Run("test nil results", func(t 
*testing.T) { - ret, err := reduceRetrieveResults(context.Background(), nil, nil) + ret, err := reduceRetrieveResults(context.Background(), nil, &queryParams{}) assert.NoError(t, err) assert.Empty(t, ret.GetFieldsData()) }) @@ -594,6 +593,8 @@ func TestTaskQuery_functions(t *testing.T) { }) t.Run("test stop reduce for best for limit", func(t *testing.T) { + r1.HasMoreResult = true + r2.HasMoreResult = false result, err := reduceRetrieveResults(context.Background(), []*internalpb.RetrieveResults{r1, r2}, &queryParams{limit: 2, reduceStopForBest: true}) @@ -605,6 +606,8 @@ }) t.Run("test stop reduce for best for limit and offset", func(t *testing.T) { + r1.HasMoreResult = true + r2.HasMoreResult = true result, err := reduceRetrieveResults(context.Background(), []*internalpb.RetrieveResults{r1, r2}, &queryParams{limit: 1, offset: 1, reduceStopForBest: true}) @@ -614,6 +617,8 @@ }) t.Run("test stop reduce for best for limit and offset", func(t *testing.T) { + r1.HasMoreResult = false + r2.HasMoreResult = true result, err := reduceRetrieveResults(context.Background(), []*internalpb.RetrieveResults{r1, r2}, &queryParams{limit: 2, offset: 1, reduceStopForBest: true}) @@ -625,6 +630,8 @@ }) t.Run("test stop reduce for best for unlimited set", func(t *testing.T) { + r1.HasMoreResult = false + r2.HasMoreResult = false result, err := reduceRetrieveResults(context.Background(), []*internalpb.RetrieveResults{r1, r2}, &queryParams{limit: typeutil.Unlimited, reduceStopForBest: true}) @@ -635,7 +642,7 @@ assert.InDeltaSlice(t, resultFloat[0:(len)*Dim], result.FieldsData[1].GetVectors().GetFloatVector().Data, 10e-10) }) - t.Run("test stop reduce for best for unlimited set amd pffset", func(t *testing.T) { + t.Run("test stop reduce for best for unlimited set and offset", func(t *testing.T) { result, err :=
reduceRetrieveResults(context.Background(), []*internalpb.RetrieveResults{r1, r2}, &queryParams{limit: typeutil.Unlimited, offset: 3, reduceStopForBest: true}) diff --git a/internal/querynodev2/segments/result.go b/internal/querynodev2/segments/result.go index 34a001e6e6..0ac61d81c9 100644 --- a/internal/querynodev2/segments/result.go +++ b/internal/querynodev2/segments/result.go @@ -401,6 +401,7 @@ func MergeInternalRetrieveResult(ctx context.Context, retrieveResults []*interna validRetrieveResults := []*internalpb.RetrieveResults{} relatedDataSize := int64(0) + hasMoreResult := false for _, r := range retrieveResults { ret.AllRetrieveCount += r.GetAllRetrieveCount() relatedDataSize += r.GetCostAggregation().GetTotalRelatedDataSize() @@ -410,7 +411,9 @@ func MergeInternalRetrieveResult(ctx context.Context, retrieveResults []*interna } validRetrieveResults = append(validRetrieveResults, r) loopEnd += size + hasMoreResult = hasMoreResult || r.GetHasMoreResult() } + ret.HasMoreResult = hasMoreResult if len(validRetrieveResults) == 0 { return ret, nil @@ -427,7 +430,7 @@ func MergeInternalRetrieveResult(ctx context.Context, retrieveResults []*interna var retSize int64 maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64() for j := 0; j < loopEnd; { - sel, drainOneResult := typeutil.SelectMinPK(param.limit, validRetrieveResults, cursors) + sel, drainOneResult := typeutil.SelectMinPK(validRetrieveResults, cursors) if sel == -1 || (param.mergeStopForBest && drainOneResult) { break } @@ -515,6 +518,7 @@ func MergeSegcoreRetrieveResults(ctx context.Context, retrieveResults []*segcore validSegments := make([]Segment, 0, len(segments)) selectedOffsets := make([][]int64, 0, len(retrieveResults)) selectedIndexes := make([][]int64, 0, len(retrieveResults)) + hasMoreResult := false for i, r := range retrieveResults { size := typeutil.GetSizeOfIDs(r.GetIds()) ret.AllRetrieveCount += r.GetAllRetrieveCount() @@ -529,7 +533,9 @@ func MergeSegcoreRetrieveResults(ctx 
context.Context, retrieveResults []*segcore selectedOffsets = append(selectedOffsets, make([]int64, 0, len(r.GetOffset()))) selectedIndexes = append(selectedIndexes, make([]int64, 0, len(r.GetOffset()))) loopEnd += size + hasMoreResult = r.GetHasMoreResult() || hasMoreResult } + ret.HasMoreResult = hasMoreResult if len(validRetrieveResults) == 0 { return ret, nil @@ -549,7 +555,7 @@ func MergeSegcoreRetrieveResults(ctx context.Context, retrieveResults []*segcore var retSize int64 maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64() for j := 0; j < loopEnd && (limit == -1 || availableCount < limit); j++ { - sel, drainOneResult := typeutil.SelectMinPK(param.limit, validRetrieveResults, cursors) + sel, drainOneResult := typeutil.SelectMinPK(validRetrieveResults, cursors) if sel == -1 || (param.mergeStopForBest && drainOneResult) { break } diff --git a/internal/querynodev2/segments/result_test.go b/internal/querynodev2/segments/result_test.go index 79e75007d6..6fcaf41965 100644 --- a/internal/querynodev2/segments/result_test.go +++ b/internal/querynodev2/segments/result_test.go @@ -513,29 +513,46 @@ func (suite *ResultSuite) TestResult_MergeStopForBestResult() { FieldsData: fieldDataArray2, } suite.Run("merge stop finite limited", func() { + result1.HasMoreResult = true + result2.HasMoreResult = true result, err := MergeSegcoreRetrieveResultsV1(context.Background(), []*segcorepb.RetrieveResults{result1, result2}, NewMergeParam(3, make([]int64, 0), nil, true)) suite.NoError(err) suite.Equal(2, len(result.GetFieldsData())) + // has more result both, stop reduce when draining one result + // here, we can only get best result from 0 to 4 without 6, because result1 has more results suite.Equal([]int64{0, 1, 2, 3, 4}, result.GetIds().GetIntId().GetData()) - // here, we can only get best result from 0 to 4 without 6, because we can never know whether there is - // one potential 5 in following result1 suite.Equal([]int64{11, 22, 11, 22, 33}, 
result.GetFieldsData()[0].GetScalars().GetLongData().Data) suite.InDeltaSlice([]float32{1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 11, 22, 33, 44}, result.FieldsData[1].GetVectors().GetFloatVector().Data, 10e-10) }) suite.Run("merge stop unlimited", func() { + result1.HasMoreResult = false + result2.HasMoreResult = false result, err := MergeSegcoreRetrieveResultsV1(context.Background(), []*segcorepb.RetrieveResults{result1, result2}, NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, true)) suite.NoError(err) suite.Equal(2, len(result.GetFieldsData())) + // as neither result1 nor result2 has more results + // we can reduce all available results into the reduced result suite.Equal([]int64{0, 1, 2, 3, 4, 6}, result.GetIds().GetIntId().GetData()) - // here, we can only get best result from 0 to 4 without 6, because we can never know whether there is - // one potential 5 in following result1 suite.Equal([]int64{11, 22, 11, 22, 33, 33}, result.GetFieldsData()[0].GetScalars().GetLongData().Data) suite.InDeltaSlice([]float32{1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 11, 22, 33, 44, 11, 22, 33, 44}, result.FieldsData[1].GetVectors().GetFloatVector().Data, 10e-10) }) + suite.Run("merge stop one limited", func() { + result1.HasMoreResult = true + result2.HasMoreResult = false + result, err := MergeSegcoreRetrieveResultsV1(context.Background(), []*segcorepb.RetrieveResults{result1, result2}, + NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, true)) + suite.NoError(err) + suite.Equal(2, len(result.GetFieldsData())) + // as result1 may have better results, stop reducing when draining it + suite.Equal([]int64{0, 1, 2, 3, 4}, result.GetIds().GetIntId().GetData()) + suite.Equal([]int64{11, 22, 11, 22, 33}, result.GetFieldsData()[0].GetScalars().GetLongData().Data) + suite.InDeltaSlice([]float32{1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 11, 22, 33, 44}, + result.FieldsData[1].GetVectors().GetFloatVector().Data, 10e-10) + }) }) suite.Run("test
stop internal merge for best", func() { @@ -559,6 +576,8 @@ func (suite *ResultSuite) TestResult_MergeStopForBestResult() { }, FieldsData: fieldDataArray2, } + result1.HasMoreResult = true + result2.HasMoreResult = false result, err := MergeInternalRetrieveResult(context.Background(), []*internalpb.RetrieveResults{result1, result2}, NewMergeParam(3, make([]int64, 0), nil, true)) suite.NoError(err) @@ -590,11 +609,24 @@ func (suite *ResultSuite) TestResult_MergeStopForBestResult() { }, FieldsData: fieldDataArray2, } - result, err := MergeInternalRetrieveResult(context.Background(), []*internalpb.RetrieveResults{result1, result2}, - NewMergeParam(3, make([]int64, 0), nil, true)) - suite.NoError(err) - suite.Equal(2, len(result.GetFieldsData())) - suite.Equal([]int64{0, 2, 4, 7}, result.GetIds().GetIntId().GetData()) + suite.Run("test drain one result without more results", func() { + result1.HasMoreResult = false + result2.HasMoreResult = false + result, err := MergeInternalRetrieveResult(context.Background(), []*internalpb.RetrieveResults{result1, result2}, + NewMergeParam(3, make([]int64, 0), nil, true)) + suite.NoError(err) + suite.Equal(2, len(result.GetFieldsData())) + suite.Equal([]int64{0, 2, 4, 7}, result.GetIds().GetIntId().GetData()) + }) + suite.Run("test drain one result with more results", func() { + result1.HasMoreResult = false + result2.HasMoreResult = true + result, err := MergeInternalRetrieveResult(context.Background(), []*internalpb.RetrieveResults{result1, result2}, + NewMergeParam(3, make([]int64, 0), nil, true)) + suite.NoError(err) + suite.Equal(2, len(result.GetFieldsData())) + suite.Equal([]int64{0, 2}, result.GetIds().GetIntId().GetData()) + }) }) } diff --git a/internal/querynodev2/tasks/query_task.go b/internal/querynodev2/tasks/query_task.go index 831d782d34..d4b0ec5c80 100644 --- a/internal/querynodev2/tasks/query_task.go +++ b/internal/querynodev2/tasks/query_task.go @@ -160,6 +160,7 @@ func (t *QueryTask) Execute() error { 
TotalRelatedDataSize: relatedDataSize, }, AllRetrieveCount: reducedResult.GetAllRetrieveCount(), + HasMoreResult: reducedResult.HasMoreResult, } return nil } diff --git a/pkg/util/typeutil/schema.go b/pkg/util/typeutil/schema.go index fd29f632f7..dfa35f2109 100644 --- a/pkg/util/typeutil/schema.go +++ b/pkg/util/typeutil/schema.go @@ -1323,10 +1323,11 @@ func ComparePK(pkA, pkB interface{}) bool { type ResultWithID interface { GetIds() *schemapb.IDs + GetHasMoreResult() bool } // SelectMinPK select the index of the minPK in results T of the cursors. -func SelectMinPK[T ResultWithID](limit int64, results []T, cursors []int64) (int, bool) { +func SelectMinPK[T ResultWithID](results []T, cursors []int64) (int, bool) { var ( sel = -1 drainResult = false @@ -1336,8 +1337,9 @@ func SelectMinPK[T ResultWithID](limit int64, results []T, cursors []int64) (int minStrPK string ) for i, cursor := range cursors { - // if result size < limit, this means we should ignore the result from this segment + // if the cursor has run out of all entries of one result and that result has more matched results + // in this case we have to tell reduce to stop because better results may be retrieved in the following iteration + if int(cursor) >= GetSizeOfIDs(results[i].GetIds()) && (results[i].GetHasMoreResult()) { drainResult = true continue }