From d0eeea4b443f5353421de69a22d83dc1fa4dee23 Mon Sep 17 00:00:00 2001
From: MrPresent-Han <116052805+MrPresent-Han@users.noreply.github.com>
Date: Wed, 6 Mar 2024 16:47:00 +0800
Subject: [PATCH] fix: reduce incorrectly for group-by with offset (#30828)
 (#30882)

related: #30828

Signed-off-by: MrPresent-Han
---
 internal/proxy/search_reduce_util.go | 381 +++++++++++++++++++++++++++
 internal/proxy/search_util.go        |   1 +
 internal/proxy/task_search.go        | 171 +-----------
 internal/proxy/task_search_test.go   |  92 ++++++-
 4 files changed, 469 insertions(+), 176 deletions(-)
 create mode 100644 internal/proxy/search_reduce_util.go
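The heart of the fix: with group_by set, only the best hit per group may be kept, so offset has to be consumed in units of groups rather than rows. The old reduce path skipped the first `offset` rows of the merged, score-ordered stream before deduplicating groups, which could discard a group's best hit and let a worse hit of the same group resurface. Below is a minimal, self-contained sketch of the two behaviors; the `hit` type and both reducers are illustrative stand-ins, not patch code (the real reduce operates on schemapb.SearchResultData):

package main

import "fmt"

// hit is a toy stand-in for one entry of a merged, score-ordered result stream.
type hit struct {
	id       int64
	score    float32
	groupVal int64
}

// reduceSkippingRows mimics the pre-patch behavior: offset is applied to rows.
func reduceSkippingRows(hits []hit, offset int64) []int64 {
	var out []int64
	seen := map[int64]struct{}{}
	for _, h := range hits[offset:] { // skips rows, not groups: wrong for group-by
		if _, ok := seen[h.groupVal]; ok {
			continue
		}
		seen[h.groupVal] = struct{}{}
		out = append(out, h.id)
	}
	return out
}

// reduceSkippingGroups mimics the patched behavior: offset counts distinct groups.
func reduceSkippingGroups(hits []hit, offset int64) []int64 {
	var out []int64
	seen := map[int64]struct{}{}
	for _, h := range hits {
		if _, ok := seen[h.groupVal]; ok {
			continue // this group already produced its (best) hit
		}
		seen[h.groupVal] = struct{}{}
		if int64(len(seen)) <= offset {
			continue // still inside the offset window of groups
		}
		out = append(out, h.id)
	}
	return out
}

func main() {
	hits := []hit{ // sorted by descending score
		{id: 1, score: 10, groupVal: 1},
		{id: 2, score: 9, groupVal: 2},
		{id: 3, score: 8, groupVal: 1}, // worse duplicate of group 1
		{id: 4, score: 7, groupVal: 3},
	}
	fmt.Println(reduceSkippingRows(hits, 1))   // [2 3 4]: group 1 wrongly survives via id 3
	fmt.Println(reduceSkippingGroups(hits, 1)) // [2 4]: group 1 was consumed by the offset
}

With offset=1, the row-skipping variant discards group 1's best hit (id 1) and then emits its second-best (id 3); the group-skipping variant charges group 1 against the offset entirely, which is what reduceSearchResultDataWithGroupBy below does.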
diff --git a/internal/proxy/search_reduce_util.go b/internal/proxy/search_reduce_util.go
new file mode 100644
index 0000000000..44a59f795e
--- /dev/null
+++ b/internal/proxy/search_reduce_util.go
@@ -0,0 +1,381 @@
+package proxy
+
+import (
+	"context"
+	"fmt"
+
+	"github.com/cockroachdb/errors"
+	"go.uber.org/zap"
+
+	"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
+	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
+	"github.com/milvus-io/milvus/internal/proto/planpb"
+	"github.com/milvus-io/milvus/pkg/log"
+	"github.com/milvus-io/milvus/pkg/util/merr"
+	"github.com/milvus-io/milvus/pkg/util/metric"
+	"github.com/milvus-io/milvus/pkg/util/paramtable"
+	"github.com/milvus-io/milvus/pkg/util/timerecord"
+	"github.com/milvus-io/milvus/pkg/util/typeutil"
+)
+
+type reduceSearchResultInfo struct {
+	subSearchResultData []*schemapb.SearchResultData
+	nq                  int64
+	topK                int64
+	metricType          string
+	pkType              schemapb.DataType
+	offset              int64
+	queryInfo           *planpb.QueryInfo
+}
+
+func NewReduceSearchResultInfo(
+	subSearchResultData []*schemapb.SearchResultData,
+	nq int64,
+	topK int64,
+	metricType string,
+	pkType schemapb.DataType,
+	offset int64,
+	queryInfo *planpb.QueryInfo,
+) *reduceSearchResultInfo {
+	return &reduceSearchResultInfo{
+		subSearchResultData: subSearchResultData,
+		nq:                  nq,
+		topK:                topK,
+		metricType:          metricType,
+		pkType:              pkType,
+		offset:              offset,
+		queryInfo:           queryInfo,
+	}
+}
+
+func reduceSearchResult(ctx context.Context, reduceInfo *reduceSearchResultInfo) (*milvuspb.SearchResults, error) {
+	if reduceInfo.queryInfo.GroupByFieldId > 0 {
+		return reduceSearchResultDataWithGroupBy(ctx,
+			reduceInfo.subSearchResultData,
+			reduceInfo.nq,
+			reduceInfo.topK,
+			reduceInfo.metricType,
+			reduceInfo.pkType,
+			reduceInfo.offset)
+	}
+	return reduceSearchResultDataNoGroupBy(ctx,
+		reduceInfo.subSearchResultData,
+		reduceInfo.nq,
+		reduceInfo.topK,
+		reduceInfo.metricType,
+		reduceInfo.pkType,
+		reduceInfo.offset)
+}
+
+func reduceSearchResultDataWithGroupBy(ctx context.Context, subSearchResultData []*schemapb.SearchResultData, nq int64, topk int64, metricType string, pkType schemapb.DataType, offset int64) (*milvuspb.SearchResults, error) {
+	tr := timerecord.NewTimeRecorder("reduceSearchResultData")
+	defer func() {
+		tr.CtxElapse(ctx, "done")
+	}()
+
+	limit := topk - offset
+	log.Ctx(ctx).Debug("reduceSearchResultData",
+		zap.Int("len(subSearchResultData)", len(subSearchResultData)),
+		zap.Int64("nq", nq),
+		zap.Int64("offset", offset),
+		zap.Int64("limit", limit),
+		zap.String("metricType", metricType))
+
+	ret := &milvuspb.SearchResults{
+		Status: merr.Success(),
+		Results: &schemapb.SearchResultData{
+			NumQueries: nq,
+			TopK:       topk,
+			FieldsData: typeutil.PrepareResultFieldData(subSearchResultData[0].GetFieldsData(), limit),
+			Scores:     []float32{},
+			Ids:        &schemapb.IDs{},
+			Topks:      []int64{},
+		},
+	}
+
+	switch pkType {
+	case schemapb.DataType_Int64:
+		ret.GetResults().Ids.IdField = &schemapb.IDs_IntId{
+			IntId: &schemapb.LongArray{
+				Data: make([]int64, 0, limit),
+			},
+		}
+	case schemapb.DataType_VarChar:
+		ret.GetResults().Ids.IdField = &schemapb.IDs_StrId{
+			StrId: &schemapb.StringArray{
+				Data: make([]string, 0, limit),
+			},
+		}
+	default:
+		return nil, errors.New("unsupported pk type")
+	}
+	for i, sData := range subSearchResultData {
+		pkLength := typeutil.GetSizeOfIDs(sData.GetIds())
+		log.Ctx(ctx).Debug("subSearchResultData",
+			zap.Int("result No.", i),
+			zap.Int64("nq", sData.NumQueries),
+			zap.Int64("topk", sData.TopK),
+			zap.Int("length of pks", pkLength),
+			zap.Int("length of FieldsData", len(sData.FieldsData)))
+		if err := checkSearchResultData(sData, nq, topk); err != nil {
+			log.Ctx(ctx).Warn("invalid search results", zap.Error(err))
+			return ret, err
+		}
+		// printSearchResultData(sData, strconv.FormatInt(int64(i), 10))
+	}
+
+	var (
+		subSearchNum = len(subSearchResultData)
+		// for results of each subSearchResultData, storing the start offset of each query of nq queries
+		subSearchNqOffset = make([][]int64, subSearchNum)
+	)
+	for i := 0; i < subSearchNum; i++ {
+		subSearchNqOffset[i] = make([]int64, subSearchResultData[i].GetNumQueries())
+		for j := int64(1); j < nq; j++ {
+			subSearchNqOffset[i][j] = subSearchNqOffset[i][j-1] + subSearchResultData[i].Topks[j-1]
+		}
+	}
+
+	var (
+		skipDupCnt int64
+		realTopK   int64 = -1
+	)
+
+	var retSize int64
+	maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64()
+
+	// reducing nq * topk results
+	for i := int64(0); i < nq; i++ {
+		var (
+			// cursor of current data of each subSearch for merging the j-th data of TopK.
+			// sum(cursors) == j
+			cursors = make([]int64, subSearchNum)
+
+			j             int64
+			idSet         = make(map[interface{}]struct{})
+			groupByValSet = make(map[interface{}]struct{})
+		)
+
+		// keep limit results
+		for j = 0; j < limit; {
+			// From all the sub-query result sets of the i-th query vector,
+			// find the sub-query result set index of the score j-th data,
+			// and the index of the data in schemapb.SearchResultData
+			subSearchIdx, resultDataIdx := selectHighestScoreIndex(subSearchResultData, subSearchNqOffset, cursors, i)
+			if subSearchIdx == -1 {
+				break
+			}
+			subSearchRes := subSearchResultData[subSearchIdx]
+
+			id := typeutil.GetPK(subSearchRes.GetIds(), resultDataIdx)
+			score := subSearchRes.Scores[resultDataIdx]
+			groupByVal := typeutil.GetData(subSearchRes.GetGroupByFieldValue(), int(resultDataIdx))
+			if groupByVal == nil {
+				return nil, errors.New("get nil groupByVal from subSearchRes, wrong state, as Milvus does not support nil values, " +
+					"there must be something wrong on the queryNode side")
+			}
+
+			// remove duplicates
+			if _, ok := idSet[id]; !ok {
+				_, groupByValExist := groupByValSet[groupByVal]
+				if !groupByValExist {
+					groupByValSet[groupByVal] = struct{}{}
+					if int64(len(groupByValSet)) <= offset {
+						// skip the first `offset` distinct groups
+						continue
+					}
+					retSize += typeutil.AppendFieldData(ret.Results.FieldsData, subSearchResultData[subSearchIdx].FieldsData, resultDataIdx)
+					typeutil.AppendPKs(ret.Results.Ids, id)
+					ret.Results.Scores = append(ret.Results.Scores, score)
+					idSet[id] = struct{}{}
+					if err := typeutil.AppendGroupByValue(ret.Results, groupByVal, subSearchRes.GetGroupByFieldValue().GetType()); err != nil {
+						log.Ctx(ctx).Error("failed to append groupByValues", zap.Error(err))
+						return ret, err
+					}
+					j++
+				} else {
+					// skip entity with the same group-by value
+					skipDupCnt++
+				}
+			} else {
+				// skip entity with the same id
+				skipDupCnt++
+			}
+			cursors[subSearchIdx]++
+		}
+		if realTopK != -1 && realTopK != j {
+			log.Ctx(ctx).Warn("Proxy Reduce Search Result", zap.Error(errors.New("the length (topk) between all result of query is different")))
+			// return nil, errors.New("the length (topk) between all result of query is different")
+		}
+		realTopK = j
+		ret.Results.Topks = append(ret.Results.Topks, realTopK)
+
+		// limit search result to avoid oom
+		if retSize > maxOutputSize {
+			return nil, fmt.Errorf("search results exceed the maxOutputSize Limit %d", maxOutputSize)
+		}
+	}
+	log.Ctx(ctx).Debug("skip duplicated search result", zap.Int64("count", skipDupCnt))
+
+	if skipDupCnt > 0 {
+		log.Ctx(ctx).Info("skip duplicated search result", zap.Int64("count", skipDupCnt))
+	}
+
+	ret.Results.TopK = realTopK // realTopK is the topK of the last (nq-th) query
+	if !metric.PositivelyRelated(metricType) {
+		for k := range ret.Results.Scores {
+			ret.Results.Scores[k] *= -1
+		}
+	}
+	return ret, nil
+}
+
+func reduceSearchResultDataNoGroupBy(ctx context.Context, subSearchResultData []*schemapb.SearchResultData, nq int64, topk int64, metricType string, pkType schemapb.DataType, offset int64) (*milvuspb.SearchResults, error) {
+	tr := timerecord.NewTimeRecorder("reduceSearchResultData")
+	defer func() {
+		tr.CtxElapse(ctx, "done")
+	}()
+
+	limit := topk - offset
+	log.Ctx(ctx).Debug("reduceSearchResultData",
+		zap.Int("len(subSearchResultData)", len(subSearchResultData)),
+		zap.Int64("nq", nq),
+		zap.Int64("offset", offset),
+		zap.Int64("limit", limit),
+		zap.String("metricType", metricType))
+
+	ret := &milvuspb.SearchResults{
+		Status: merr.Success(),
+		Results: &schemapb.SearchResultData{
+			NumQueries: nq,
+			TopK:       topk,
+			FieldsData: typeutil.PrepareResultFieldData(subSearchResultData[0].GetFieldsData(), limit),
+			Scores:     []float32{},
+			Ids:        &schemapb.IDs{},
+			Topks:      []int64{},
+		},
+	}
+
+	switch pkType {
+	case schemapb.DataType_Int64:
+		ret.GetResults().Ids.IdField = &schemapb.IDs_IntId{
+			IntId: &schemapb.LongArray{
+				Data: make([]int64, 0, limit),
+			},
+		}
+	case schemapb.DataType_VarChar:
+		ret.GetResults().Ids.IdField = &schemapb.IDs_StrId{
+			StrId: &schemapb.StringArray{
+				Data: make([]string, 0, limit),
+			},
+		}
+	default:
+		return nil, errors.New("unsupported pk type")
+	}
+	for i, sData := range subSearchResultData {
+		pkLength := typeutil.GetSizeOfIDs(sData.GetIds())
+		log.Ctx(ctx).Debug("subSearchResultData",
+			zap.Int("result No.", i),
+			zap.Int64("nq", sData.NumQueries),
+			zap.Int64("topk", sData.TopK),
+			zap.Int("length of pks", pkLength),
+			zap.Int("length of FieldsData", len(sData.FieldsData)))
+		if err := checkSearchResultData(sData, nq, topk); err != nil {
+			log.Ctx(ctx).Warn("invalid search results", zap.Error(err))
+			return ret, err
+		}
+		// printSearchResultData(sData, strconv.FormatInt(int64(i), 10))
+	}
+
+	var (
+		subSearchNum = len(subSearchResultData)
+		// for results of each subSearchResultData, storing the start offset of each query of nq queries
+		subSearchNqOffset = make([][]int64, subSearchNum)
+	)
+	for i := 0; i < subSearchNum; i++ {
+		subSearchNqOffset[i] = make([]int64, subSearchResultData[i].GetNumQueries())
+		for j := int64(1); j < nq; j++ {
+			subSearchNqOffset[i][j] = subSearchNqOffset[i][j-1] + subSearchResultData[i].Topks[j-1]
+		}
+	}
+
+	var (
+		skipDupCnt int64
+		realTopK   int64 = -1
+	)
+
+	var retSize int64
+	maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64()
+
+	// reducing nq * topk results
+	for i := int64(0); i < nq; i++ {
+		var (
+			// cursor of current data of each subSearch for merging the j-th data of TopK.
+			// sum(cursors) == j
+			cursors = make([]int64, subSearchNum)
+
+			j     int64
+			idSet = make(map[interface{}]struct{})
+		)
+
+		// skip offset results
+		for k := int64(0); k < offset; k++ {
+			subSearchIdx, _ := selectHighestScoreIndex(subSearchResultData, subSearchNqOffset, cursors, i)
+			if subSearchIdx == -1 {
+				break
+			}
+
+			cursors[subSearchIdx]++
+		}
+
+		// keep limit results
+		for j = 0; j < limit; {
+			// From all the sub-query result sets of the i-th query vector,
+			// find the sub-query result set index of the score j-th data,
+			// and the index of the data in schemapb.SearchResultData
+			subSearchIdx, resultDataIdx := selectHighestScoreIndex(subSearchResultData, subSearchNqOffset, cursors, i)
+			if subSearchIdx == -1 {
+				break
+			}
+			id := typeutil.GetPK(subSearchResultData[subSearchIdx].GetIds(), resultDataIdx)
+			score := subSearchResultData[subSearchIdx].Scores[resultDataIdx]
+
+			// remove duplicates
+			if _, ok := idSet[id]; !ok {
+				retSize += typeutil.AppendFieldData(ret.Results.FieldsData, subSearchResultData[subSearchIdx].FieldsData, resultDataIdx)
+				typeutil.AppendPKs(ret.Results.Ids, id)
+				ret.Results.Scores = append(ret.Results.Scores, score)
+				idSet[id] = struct{}{}
+				j++
+			} else {
+				// skip entity with the same id
+				skipDupCnt++
+			}
+			cursors[subSearchIdx]++
+		}
+		if realTopK != -1 && realTopK != j {
+			log.Ctx(ctx).Warn("Proxy Reduce Search Result", zap.Error(errors.New("the length (topk) between all result of query is different")))
+			// return nil, errors.New("the length (topk) between all result of query is different")
+		}
+		realTopK = j
+		ret.Results.Topks = append(ret.Results.Topks, realTopK)
+
+		// limit search result to avoid oom
+		if retSize > maxOutputSize {
+			return nil, fmt.Errorf("search results exceed the maxOutputSize Limit %d", maxOutputSize)
+		}
+	}
+	log.Ctx(ctx).Debug("skip duplicated search result", zap.Int64("count", skipDupCnt))
+
+	if skipDupCnt > 0 {
+		log.Ctx(ctx).Info("skip duplicated search result", zap.Int64("count", skipDupCnt))
+	}
+
+	ret.Results.TopK = realTopK // realTopK is the topK of the last (nq-th) query
+	if !metric.PositivelyRelated(metricType) {
+		for k := range ret.Results.Scores {
+			ret.Results.Scores[k] *= -1
+		}
+	}
+	return ret, nil
+}
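Two details of the group-by path above are easy to miss. First, a newly seen group value is recorded immediately, and while len(groupByValSet) <= offset the hit is dropped without incrementing j; since that continue also bypasses cursors[subSearchIdx]++, the same hit is visited once more and then falls into the same-group branch, where the cursor finally advances. Second, as in the old code, each query's slice inside a flat sub-result is located via prefix sums of Topks. A small runnable sketch of that prefix-sum step, mirroring the subSearchNqOffset loop (toy values assumed):

package main

import "fmt"

func main() {
	// One sub-result answering nq=4 queries with per-query hit counts
	// Topks=[2 3 1 4]: query j's hits start at the prefix sum of earlier Topks.
	topks := []int64{2, 3, 1, 4}
	nqOffset := make([]int64, len(topks))
	for j := 1; j < len(topks); j++ {
		nqOffset[j] = nqOffset[j-1] + topks[j-1]
	}
	fmt.Println(nqOffset) // [0 2 5 6]
}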
diff --git a/internal/proxy/search_util.go b/internal/proxy/search_util.go
index 69362e9faa..c6bb7cc64a 100644
--- a/internal/proxy/search_util.go
+++ b/internal/proxy/search_util.go
@@ -118,6 +118,7 @@ func initSearchRequest(ctx context.Context, t *searchTask) error {
 
 	t.SearchRequest.Topk = queryInfo.GetTopk()
 	t.SearchRequest.MetricType = queryInfo.GetMetricType()
+	t.queryInfo = queryInfo
 	t.SearchRequest.DslType = commonpb.DslType_BoolExprV1
 
 	estimateSize, err := t.estimateResultSize(nq, t.SearchRequest.Topk)
diff --git a/internal/proxy/task_search.go b/internal/proxy/task_search.go
index 68791f4c57..909e87d1e3 100644
--- a/internal/proxy/task_search.go
+++ b/internal/proxy/task_search.go
@@ -70,6 +70,7 @@ type searchTask struct {
 	node            types.ProxyComponent
 	lb              LBPolicy
 	queryChannelsTs map[string]Timestamp
+	queryInfo       *planpb.QueryInfo
 }
 
 func getPartitionIDs(ctx context.Context, dbName string, collectionName string, partitionNames []string) (partitionIDs []UniqueID, err error) {
@@ -443,7 +444,8 @@ func (t *searchTask) PostExecute(ctx context.Context) error {
 		return err
 	}
 
-	t.result, err = reduceSearchResultData(ctx, validSearchResults, Nq, Topk, MetricType, primaryFieldSchema.DataType, t.offset)
+	t.result, err = reduceSearchResult(ctx, NewReduceSearchResultInfo(validSearchResults, Nq, Topk,
+		MetricType, primaryFieldSchema.DataType, t.offset, t.queryInfo))
 	if err != nil {
 		log.Warn("failed to reduce search results", zap.Error(err))
 		return err
@@ -751,173 +753,6 @@ func selectHighestScoreIndex(subSearchResultData []*schemapb.SearchResultData, s
 	return subSearchIdx, resultDataIdx
 }
 
-func reduceSearchResultData(ctx context.Context, subSearchResultData []*schemapb.SearchResultData, nq int64, topk int64, metricType string, pkType schemapb.DataType, offset int64) (*milvuspb.SearchResults, error) {
-	tr := timerecord.NewTimeRecorder("reduceSearchResultData")
-	defer func() {
-		tr.CtxElapse(ctx, "done")
-	}()
-
-	limit := topk - offset
-	log.Ctx(ctx).Debug("reduceSearchResultData",
-		zap.Int("len(subSearchResultData)", len(subSearchResultData)),
-		zap.Int64("nq", nq),
-		zap.Int64("offset", offset),
-		zap.Int64("limit", limit),
-		zap.String("metricType", metricType))
-
-	ret := &milvuspb.SearchResults{
-		Status: merr.Success(),
-		Results: &schemapb.SearchResultData{
-			NumQueries: nq,
-			TopK:       topk,
-			FieldsData: typeutil.PrepareResultFieldData(subSearchResultData[0].GetFieldsData(), limit),
-			Scores:     []float32{},
-			Ids:        &schemapb.IDs{},
-			Topks:      []int64{},
-		},
-	}
-
-	switch pkType {
-	case schemapb.DataType_Int64:
-		ret.GetResults().Ids.IdField = &schemapb.IDs_IntId{
-			IntId: &schemapb.LongArray{
-				Data: make([]int64, 0, limit),
-			},
-		}
-	case schemapb.DataType_VarChar:
-		ret.GetResults().Ids.IdField = &schemapb.IDs_StrId{
-			StrId: &schemapb.StringArray{
-				Data: make([]string, 0, limit),
-			},
-		}
-	default:
-		return nil, errors.New("unsupported pk type")
-	}
-	for i, sData := range subSearchResultData {
-		pkLength := typeutil.GetSizeOfIDs(sData.GetIds())
-		log.Ctx(ctx).Debug("subSearchResultData",
-			zap.Int("result No.", i),
-			zap.Int64("nq", sData.NumQueries),
-			zap.Int64("topk", sData.TopK),
-			zap.Int("length of pks", pkLength),
-			zap.Int("length of FieldsData", len(sData.FieldsData)))
-		if err := checkSearchResultData(sData, nq, topk); err != nil {
-			log.Ctx(ctx).Warn("invalid search results", zap.Error(err))
-			return ret, err
-		}
-		// printSearchResultData(sData, strconv.FormatInt(int64(i), 10))
-	}
-
-	var (
-		subSearchNum = len(subSearchResultData)
-		// for results of each subSearchResultData, storing the start offset of each query of nq queries
-		subSearchNqOffset = make([][]int64, subSearchNum)
-	)
-	for i := 0; i < subSearchNum; i++ {
-		subSearchNqOffset[i] = make([]int64, subSearchResultData[i].GetNumQueries())
-		for j := int64(1); j < nq; j++ {
-			subSearchNqOffset[i][j] = subSearchNqOffset[i][j-1] + subSearchResultData[i].Topks[j-1]
-		}
-	}
-
-	var (
-		skipDupCnt int64
-		realTopK   int64 = -1
-	)
-
-	var retSize int64
-	maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64()
-
-	// reducing nq * topk results
-	for i := int64(0); i < nq; i++ {
-		var (
-			// cursor of current data of each subSearch for merging the j-th data of TopK.
-			// sum(cursors) == j
-			cursors = make([]int64, subSearchNum)
-
-			j             int64
-			idSet         = make(map[interface{}]struct{})
-			groupByValSet = make(map[interface{}]struct{})
-		)
-
-		// skip offset results
-		for k := int64(0); k < offset; k++ {
-			subSearchIdx, _ := selectHighestScoreIndex(subSearchResultData, subSearchNqOffset, cursors, i)
-			if subSearchIdx == -1 {
-				break
-			}
-
-			cursors[subSearchIdx]++
-		}
-
-		// keep limit results
-		for j = 0; j < limit; {
-			// From all the sub-query result sets of the i-th query vector,
-			// find the sub-query result set index of the score j-th data,
-			// and the index of the data in schemapb.SearchResultData
-			subSearchIdx, resultDataIdx := selectHighestScoreIndex(subSearchResultData, subSearchNqOffset, cursors, i)
-			if subSearchIdx == -1 {
-				break
-			}
-			subSearchRes := subSearchResultData[subSearchIdx]
-
-			id := typeutil.GetPK(subSearchRes.GetIds(), resultDataIdx)
-			score := subSearchRes.Scores[resultDataIdx]
-			groupByVal := typeutil.GetData(subSearchRes.GetGroupByFieldValue(), int(resultDataIdx))
-
-			// remove duplicates
-			if _, ok := idSet[id]; !ok {
-				groupByValExist := false
-				if groupByVal != nil {
-					_, groupByValExist = groupByValSet[groupByVal]
-				}
-				if !groupByValExist {
-					retSize += typeutil.AppendFieldData(ret.Results.FieldsData, subSearchResultData[subSearchIdx].FieldsData, resultDataIdx)
-					typeutil.AppendPKs(ret.Results.Ids, id)
-					ret.Results.Scores = append(ret.Results.Scores, score)
-					idSet[id] = struct{}{}
-					if groupByVal != nil {
-						groupByValSet[groupByVal] = struct{}{}
-						if err := typeutil.AppendGroupByValue(ret.Results, groupByVal, subSearchRes.GetGroupByFieldValue().GetType()); err != nil {
-							log.Ctx(ctx).Error("failed to append groupByValues", zap.Error(err))
-							return ret, err
-						}
-					}
-					j++
-				}
-			} else {
-				// skip entity with same id
-				skipDupCnt++
-			}
-			cursors[subSearchIdx]++
-		}
-		if realTopK != -1 && realTopK != j {
-			log.Ctx(ctx).Warn("Proxy Reduce Search Result", zap.Error(errors.New("the length (topk) between all result of query is different")))
-			// return nil, errors.New("the length (topk) between all result of query is different")
-		}
-		realTopK = j
-		ret.Results.Topks = append(ret.Results.Topks, realTopK)
-
-		// limit search result to avoid oom
-		if retSize > maxOutputSize {
-			return nil, fmt.Errorf("search results exceed the maxOutputSize Limit %d", maxOutputSize)
-		}
-	}
-	log.Ctx(ctx).Debug("skip duplicated search result", zap.Int64("count", skipDupCnt))
-
-	if skipDupCnt > 0 {
-		log.Info("skip duplicated search result", zap.Int64("count", skipDupCnt))
-	}
-
-	ret.Results.TopK = realTopK // realTopK is the topK of the nq-th query
-	if !metric.PositivelyRelated(metricType) {
-		for k := range ret.Results.Scores {
-			ret.Results.Scores[k] *= -1
-		}
-	}
-	return ret, nil
-}
-
 type rangeSearchParams struct {
 	radius      float64
 	rangeFilter float64
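Both reduce loops lean on selectHighestScoreIndex, which stays in task_search.go (its tail appears as context in the removal hunk above). For readers without the full file, here is a simplified, hypothetical stand-in for the contract the loops assume: return the sub-result whose current candidate scores highest for the query at hand, or -1 once every sub-result is exhausted. The real function additionally bounds each query's range via subSearchNqOffset and resolves ties; treat this as a sketch of the contract, not its actual body:

package main

import "fmt"

// pickHighest is a toy stand-in: scores[sub] holds one query's hits for each
// sub-result, already sorted by descending score; cursors[sub] counts how many
// of them were consumed so far.
func pickHighest(scores [][]float32, cursors []int64) (int, int64) {
	subIdx, dataIdx := -1, int64(-1)
	for sub := range scores {
		cur := cursors[sub]
		if cur >= int64(len(scores[sub])) {
			continue // this sub-result has no candidates left
		}
		if subIdx == -1 || scores[sub][cur] > scores[subIdx][dataIdx] {
			subIdx, dataIdx = sub, cur
		}
	}
	return subIdx, dataIdx
}

func main() {
	scores := [][]float32{{10, 8, 6}, {9, 7, 5}}
	cursors := []int64{1, 0}
	fmt.Println(pickHighest(scores, cursors)) // 1 0: score 9 beats score 8
}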
diff --git a/internal/proxy/task_search_test.go b/internal/proxy/task_search_test.go
index f3c671b704..8bd6dd19b8 100644
--- a/internal/proxy/task_search_test.go
+++ b/internal/proxy/task_search_test.go
@@ -35,6 +35,7 @@ import (
 	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
 	"github.com/milvus-io/milvus/internal/mocks"
 	"github.com/milvus-io/milvus/internal/proto/internalpb"
+	"github.com/milvus-io/milvus/internal/proto/planpb"
 	"github.com/milvus-io/milvus/internal/proto/querypb"
 	"github.com/milvus-io/milvus/internal/types"
 	"github.com/milvus-io/milvus/internal/util/dependency"
@@ -1526,9 +1527,13 @@ func TestTaskSearch_reduceSearchResultData(t *testing.T) {
 		results = append(results, r)
 	}
 
+	queryInfo := &planpb.QueryInfo{
+		GroupByFieldId: -1,
+	}
 	for _, test := range tests {
 		t.Run(test.description, func(t *testing.T) {
-			reduced, err := reduceSearchResultData(context.TODO(), results, nq, topk, metric.L2, schemapb.DataType_Int64, test.offset)
+			reduced, err := reduceSearchResult(context.TODO(),
+				NewReduceSearchResultInfo(results, nq, topk, metric.L2, schemapb.DataType_Int64, test.offset, queryInfo))
 			assert.NoError(t, err)
 			assert.Equal(t, test.outData, reduced.GetResults().GetIds().GetIntId().GetData())
 			assert.Equal(t, []int64{test.limit, test.limit}, reduced.GetResults().GetTopks())
@@ -1577,10 +1582,10 @@ func TestTaskSearch_reduceSearchResultData(t *testing.T) {
 			[]int64{},
 		},
 	}
-
 	for _, test := range lessThanLimitTests {
 		t.Run(test.description, func(t *testing.T) {
-			reduced, err := reduceSearchResultData(context.TODO(), results, nq, topk, metric.L2, schemapb.DataType_Int64, test.offset)
+			reduced, err := reduceSearchResult(context.TODO(), NewReduceSearchResultInfo(results, nq, topk,
+				metric.L2, schemapb.DataType_Int64, test.offset, queryInfo))
 			assert.NoError(t, err)
 			assert.Equal(t, test.outData, reduced.GetResults().GetIds().GetIntId().GetData())
 			assert.Equal(t, []int64{test.outLimit, test.outLimit}, reduced.GetResults().GetTopks())
@@ -1604,7 +1609,12 @@ func TestTaskSearch_reduceSearchResultData(t *testing.T) {
 		results = append(results, r)
 	}
 
-	reduced, err := reduceSearchResultData(context.TODO(), results, nq, topk, metric.L2, schemapb.DataType_Int64, 0)
+	queryInfo := &planpb.QueryInfo{
+		GroupByFieldId: -1,
+	}
+
+	reduced, err := reduceSearchResult(context.TODO(), NewReduceSearchResultInfo(
+		results, nq, topk, metric.L2, schemapb.DataType_Int64, 0, queryInfo))
 	assert.NoError(t, err)
 
 	assert.Equal(t, resultData, reduced.GetResults().GetIds().GetIntId().GetData())
@@ -1630,8 +1640,12 @@ func TestTaskSearch_reduceSearchResultData(t *testing.T) {
 		results = append(results, r)
 	}
 
+	queryInfo := &planpb.QueryInfo{
+		GroupByFieldId: -1,
+	}
-	reduced, err := reduceSearchResultData(context.TODO(), results, nq, topk, metric.L2, schemapb.DataType_VarChar, 0)
+	reduced, err := reduceSearchResult(context.TODO(), NewReduceSearchResultInfo(results,
+		nq, topk, metric.L2, schemapb.DataType_VarChar, 0, queryInfo))
 	assert.NoError(t, err)
 
 	assert.Equal(t, resultData, reduced.GetResults().GetIds().GetStrId().GetData())
@@ -1700,8 +1714,11 @@ func TestTaskSearch_reduceGroupBySearchResultData(t *testing.T) {
 		}
 		results = append(results, result)
 	}
-
-	reduced, err := reduceSearchResultData(context.TODO(), results, nq, topK, metric.L2, schemapb.DataType_Int64, 0)
+	queryInfo := &planpb.QueryInfo{
+		GroupByFieldId: 1,
+	}
+	reduced, err := reduceSearchResult(context.TODO(), NewReduceSearchResultInfo(results, nq, topK, metric.L2,
+		schemapb.DataType_Int64, 0, queryInfo))
 	resultIDs := reduced.GetResults().GetIds().GetIntId().Data
 	resultScores := reduced.GetResults().GetScores()
 	resultGroupByValues := reduced.GetResults().GetGroupByFieldValue().GetScalars().GetLongData().GetData()
@@ -1713,6 +1730,63 @@ func TestTaskSearch_reduceGroupBySearchResultData(t *testing.T) {
 	}
 }
 
+func TestTaskSearch_reduceGroupBySearchResultDataWithOffset(t *testing.T) {
+	var (
+		nq     int64 = 1
+		limit  int64 = 5
+		offset int64 = 5
+	)
+	ids := [][]int64{
+		{1, 3, 5, 7, 9},
+		{2, 4, 6, 8, 10},
+	}
+	scores := [][]float32{
+		{10, 8, 6, 4, 2},
+		{9, 7, 5, 3, 1},
+	}
+	groupByValuesArr := [][]int64{
+		{1, 3, 5, 7, 9},
+		{2, 4, 6, 8, 10},
+	}
+	expectedIDs := []int64{6, 7, 8, 9, 10}
+	expectedScores := []float32{-5, -4, -3, -2, -1}
+	expectedGroupByValues := []int64{6, 7, 8, 9, 10}
+
+	var results []*schemapb.SearchResultData
+	for j := range ids {
+		result := getSearchResultData(nq, limit+offset)
+		result.Ids.IdField = &schemapb.IDs_IntId{IntId: &schemapb.LongArray{Data: ids[j]}}
+		result.Scores = scores[j]
+		result.Topks = []int64{limit}
+		result.GroupByFieldValue = &schemapb.FieldData{
+			Type: schemapb.DataType_Int64,
+			Field: &schemapb.FieldData_Scalars{
+				Scalars: &schemapb.ScalarField{
+					Data: &schemapb.ScalarField_LongData{
+						LongData: &schemapb.LongArray{
+							Data: groupByValuesArr[j],
+						},
+					},
+				},
+			},
+		}
+		results = append(results, result)
+	}
+
+	queryInfo := &planpb.QueryInfo{
+		GroupByFieldId: 1,
+	}
+	reduced, err := reduceSearchResult(context.TODO(), NewReduceSearchResultInfo(results, nq, limit+offset, metric.L2,
+		schemapb.DataType_Int64, offset, queryInfo))
+	resultIDs := reduced.GetResults().GetIds().GetIntId().Data
+	resultScores := reduced.GetResults().GetScores()
+	resultGroupByValues := reduced.GetResults().GetGroupByFieldValue().GetScalars().GetLongData().GetData()
+	assert.EqualValues(t, expectedIDs, resultIDs)
+	assert.EqualValues(t, expectedScores, resultScores)
+	assert.EqualValues(t, expectedGroupByValues, resultGroupByValues)
+	assert.NoError(t, err)
+}
+
 func TestSearchTask_ErrExecute(t *testing.T) {
 	var (
 		err error
@@ -2367,7 +2441,9 @@ func TestSearchTask_Requery(t *testing.T) {
 		qt.resultBuf.Insert(&internalpb.SearchResults{
 			SlicedBlob: bytes,
 		})
-
+		qt.queryInfo = &planpb.QueryInfo{
+			GroupByFieldId: -1,
+		}
 		err = qt.PostExecute(ctx)
 		t.Logf("err = %s", err)
 		assert.Error(t, err)
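A final cross-check of TestTaskSearch_reduceGroupBySearchResultDataWithOffset: merging the two sub-results by score, charging the first offset=5 distinct groups to the offset, and negating scores (L2 is not positively related) reproduces the expected values. The merge below is a toy flattening, not the production reduce path:

package main

import (
	"fmt"
	"sort"
)

type hit struct {
	id, group int64
	score     float32
}

func main() {
	// Both sub-results from the test, flattened into one slice.
	hits := []hit{
		{1, 1, 10}, {3, 3, 8}, {5, 5, 6}, {7, 7, 4}, {9, 9, 2},
		{2, 2, 9}, {4, 4, 7}, {6, 6, 5}, {8, 8, 3}, {10, 10, 1},
	}
	sort.Slice(hits, func(a, b int) bool { return hits[a].score > hits[b].score })

	offset, limit := int64(5), int64(5)
	seen := map[int64]struct{}{}
	var ids []int64
	var scores []float32
	for _, h := range hits {
		if _, ok := seen[h.group]; ok {
			continue // only the best hit of each group survives
		}
		seen[h.group] = struct{}{}
		if int64(len(seen)) <= offset {
			continue // offset counts groups, not rows
		}
		ids = append(ids, h.id)
		scores = append(scores, -h.score) // L2 scores are negated on output
		if int64(len(ids)) == limit {
			break
		}
	}
	fmt.Println(ids)    // [6 7 8 9 10]
	fmt.Println(scores) // [-5 -4 -3 -2 -1]
}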