From e480b103bd2669ffd06b836d9ee0e7d4839dc0b2 Mon Sep 17 00:00:00 2001 From: Chun Han <116052805+MrPresent-Han@users.noreply.github.com> Date: Sun, 8 Sep 2024 17:09:04 +0800 Subject: [PATCH] feat: supporting hybrid search group_by (#35982) related: #35096 Signed-off-by: MrPresent-Han Co-authored-by: MrPresent-Han --- internal/datacoord/index_meta_test.go | 4 +- internal/datacoord/task_scheduler_test.go | 2 +- internal/datanode/compaction/merge_sort.go | 2 +- .../compaction/priority_queue_test.go | 1 - internal/proto/internal.proto | 10 +- internal/proxy/search_reduce_util.go | 514 +++++++++++++----- internal/proxy/search_reduce_util_test.go | 133 +++++ internal/proxy/search_util.go | 4 +- internal/proxy/task.go | 7 + internal/proxy/task_search.go | 54 +- internal/proxy/task_search_test.go | 219 +++++++- internal/querynodev2/delegator/delegator.go | 22 +- internal/querynodev2/handlers.go | 15 +- internal/querynodev2/segments/result.go | 62 +-- internal/querynodev2/segments/result_test.go | 38 ++ .../querynodev2/segments/search_reduce.go | 55 +- .../segments/search_reduce_test.go | 15 +- internal/querynodev2/services.go | 68 +-- internal/util/reduce/reduce_info.go | 92 ++++ 19 files changed, 983 insertions(+), 334 deletions(-) create mode 100644 internal/proxy/search_reduce_util_test.go create mode 100644 internal/util/reduce/reduce_info.go diff --git a/internal/datacoord/index_meta_test.go b/internal/datacoord/index_meta_test.go index 1a991315ef..cad20c5657 100644 --- a/internal/datacoord/index_meta_test.go +++ b/internal/datacoord/index_meta_test.go @@ -170,7 +170,8 @@ func TestMeta_ScalarAutoIndex(t *testing.T) { { Key: common.IndexTypeKey, Value: "HYBRID", - }}, + }, + }, Timestamp: 0, IsAutoIndex: true, UserIndexParams: userIndexParams, @@ -205,7 +206,6 @@ func TestMeta_ScalarAutoIndex(t *testing.T) { assert.Equal(t, newIndexParams[0].Key, common.IndexTypeKey) assert.Equal(t, newIndexParams[0].Value, "INVERTED") }) - } func TestMeta_CanCreateIndex(t *testing.T) { diff --git a/internal/datacoord/task_scheduler_test.go b/internal/datacoord/task_scheduler_test.go index ba47290a73..ed7051a1e3 100644 --- a/internal/datacoord/task_scheduler_test.go +++ b/internal/datacoord/task_scheduler_test.go @@ -769,7 +769,7 @@ func (s *taskSchedulerSuite) scheduler(handler Handler) { return nil }) catalog.EXPECT().AlterSegmentIndexes(mock.Anything, mock.Anything).Return(nil) - //catalog.EXPECT().SaveStatsTask(mock.Anything, mock.Anything).Return(nil) + // catalog.EXPECT().SaveStatsTask(mock.Anything, mock.Anything).Return(nil) in := mocks.NewMockIndexNodeClient(s.T()) in.EXPECT().CreateJobV2(mock.Anything, mock.Anything).Return(merr.Success(), nil) diff --git a/internal/datanode/compaction/merge_sort.go b/internal/datanode/compaction/merge_sort.go index f9294b3620..450a49197b 100644 --- a/internal/datanode/compaction/merge_sort.go +++ b/internal/datanode/compaction/merge_sort.go @@ -63,7 +63,7 @@ func mergeSortMultipleSegments(ctx context.Context, return nil, err } - //SegmentDeserializeReaderTest(binlogPaths, t.binlogIO, writer.GetPkID()) + // SegmentDeserializeReaderTest(binlogPaths, t.binlogIO, writer.GetPkID()) segmentReaders := make([]*SegmentDeserializeReader, len(binlogs)) for i, s := range binlogs { var binlogBatchCount int diff --git a/internal/datanode/compaction/priority_queue_test.go b/internal/datanode/compaction/priority_queue_test.go index d3a73ec062..1bcb4fafa0 100644 --- a/internal/datanode/compaction/priority_queue_test.go +++ b/internal/datanode/compaction/priority_queue_test.go @@ 
-119,7 +119,6 @@ func (s *PriorityQueueSuite) PriorityQueueMergeSort() { heap.Push(&pq, next) } } - } func TestNewPriorityQueueSuite(t *testing.T) { diff --git a/internal/proto/internal.proto b/internal/proto/internal.proto index f01c5256e4..b50ea4004c 100644 --- a/internal/proto/internal.proto +++ b/internal/proto/internal.proto @@ -94,11 +94,8 @@ message SubSearchRequest { int64 topk = 7; int64 offset = 8; string metricType = 9; -} - -message ExtraSearchParam { - int64 group_by_field_id = 1; - int64 group_size = 2; + int64 group_by_field_id = 10; + int64 group_size = 11; } message SearchRequest { @@ -125,7 +122,8 @@ message SearchRequest { bool is_advanced = 20; int64 offset = 21; common.ConsistencyLevel consistency_level = 22; - ExtraSearchParam extra_search_param = 23; + int64 group_by_field_id = 23; + int64 group_size = 24; } message SubSearchResults { diff --git a/internal/proxy/search_reduce_util.go b/internal/proxy/search_reduce_util.go index 488f0ef01d..708b0b54ae 100644 --- a/internal/proxy/search_reduce_util.go +++ b/internal/proxy/search_reduce_util.go @@ -11,7 +11,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" - "github.com/milvus-io/milvus/internal/proto/planpb" + "github.com/milvus-io/milvus/internal/util/reduce" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/metric" @@ -20,54 +20,137 @@ import ( "github.com/milvus-io/milvus/pkg/util/typeutil" ) -type reduceSearchResultInfo struct { - subSearchResultData []*schemapb.SearchResultData - nq int64 - topK int64 - metricType string - pkType schemapb.DataType - offset int64 - queryInfo *planpb.QueryInfo -} - -func NewReduceSearchResultInfo( - subSearchResultData []*schemapb.SearchResultData, - nq int64, - topK int64, - metricType string, - pkType schemapb.DataType, - offset int64, - queryInfo *planpb.QueryInfo, -) *reduceSearchResultInfo { - return &reduceSearchResultInfo{ - subSearchResultData: subSearchResultData, - nq: nq, - topK: topK, - metricType: metricType, - pkType: pkType, - offset: offset, - queryInfo: queryInfo, - } -} - -func reduceSearchResult(ctx context.Context, reduceInfo *reduceSearchResultInfo) (*milvuspb.SearchResults, error) { - if reduceInfo.queryInfo.GroupByFieldId > 0 { +func reduceSearchResult(ctx context.Context, subSearchResultData []*schemapb.SearchResultData, reduceInfo *reduce.ResultInfo) (*milvuspb.SearchResults, error) { + if reduceInfo.GetGroupByFieldId() > 0 { + if reduceInfo.GetIsAdvance() { + // for hybrid search with group-by, we cannot reduce the results of a single search path here, + // because the final scores have not been accumulated yet; offset cannot be applied either + return reduceAdvanceGroupBY(ctx, + subSearchResultData, reduceInfo.GetNq(), reduceInfo.GetTopK(), reduceInfo.GetPkType(), reduceInfo.GetMetricType()) + } return reduceSearchResultDataWithGroupBy(ctx, - reduceInfo.subSearchResultData, - reduceInfo.nq, - reduceInfo.topK, - reduceInfo.metricType, - reduceInfo.pkType, - reduceInfo.offset, - reduceInfo.queryInfo.GroupSize) + subSearchResultData, + reduceInfo.GetNq(), + reduceInfo.GetTopK(), + reduceInfo.GetMetricType(), + reduceInfo.GetPkType(), + reduceInfo.GetOffset(), + reduceInfo.GetGroupSize()) } return reduceSearchResultDataNoGroupBy(ctx, - reduceInfo.subSearchResultData, - reduceInfo.nq, - reduceInfo.topK, - reduceInfo.metricType, - reduceInfo.pkType, - reduceInfo.offset) + subSearchResultData, + reduceInfo.GetNq(), + 
reduceInfo.GetTopK(), + reduceInfo.GetMetricType(), + reduceInfo.GetPkType(), + reduceInfo.GetOffset()) +} + +func checkResultDatas(ctx context.Context, subSearchResultData []*schemapb.SearchResultData, + nq int64, topK int64, +) (int64, int, error) { + var allSearchCount int64 + var hitNum int + for i, sData := range subSearchResultData { + pkLength := typeutil.GetSizeOfIDs(sData.GetIds()) + log.Ctx(ctx).Debug("subSearchResultData", + zap.Int("result No.", i), + zap.Int64("nq", sData.NumQueries), + zap.Int64("topk", sData.TopK), + zap.Int("length of pks", pkLength), + zap.Int("length of FieldsData", len(sData.FieldsData))) + allSearchCount += sData.GetAllSearchCount() + hitNum += pkLength + if err := checkSearchResultData(sData, nq, topK, pkLength); err != nil { + log.Ctx(ctx).Warn("invalid search results", zap.Error(err)) + return allSearchCount, hitNum, err + } + } + return allSearchCount, hitNum, nil +} + +func reduceAdvanceGroupBY(ctx context.Context, subSearchResultData []*schemapb.SearchResultData, + nq int64, topK int64, pkType schemapb.DataType, metricType string, +) (*milvuspb.SearchResults, error) { + log.Ctx(ctx).Debug("reduceAdvanceGroupBY", zap.Int("len(subSearchResultData)", len(subSearchResultData)), zap.Int64("nq", nq)) + // for advanced group-by, offset is not applied, so just return when there's only one channel + if len(subSearchResultData) == 1 { + return &milvuspb.SearchResults{ + Status: merr.Success(), + Results: subSearchResultData[0], + }, nil + } + + ret := &milvuspb.SearchResults{ + Status: merr.Success(), + Results: &schemapb.SearchResultData{ + NumQueries: nq, + TopK: topK, + Scores: []float32{}, + Ids: &schemapb.IDs{}, + Topks: []int64{}, + }, + } + + var limit int64 + if allSearchCount, hitNum, err := checkResultDatas(ctx, subSearchResultData, nq, topK); err != nil { + log.Ctx(ctx).Warn("invalid search results", zap.Error(err)) + return ret, err + } else { + ret.GetResults().AllSearchCount = allSearchCount + limit = int64(hitNum) + ret.GetResults().FieldsData = typeutil.PrepareResultFieldData(subSearchResultData[0].GetFieldsData(), limit) + } + + if err := setupIdListForSearchResult(ret, pkType, limit); err != nil { + return ret, err + } + + var ( + subSearchNum = len(subSearchResultData) + // for each subSearchResultData, store the start offset of each of the nq queries' hits + subSearchNqOffset = make([][]int64, subSearchNum) + ) + for i := 0; i < subSearchNum; i++ { + subSearchNqOffset[i] = make([]int64, subSearchResultData[i].GetNumQueries()) + for j := int64(1); j < nq; j++ { + subSearchNqOffset[i][j] = subSearchNqOffset[i][j-1] + subSearchResultData[i].Topks[j-1] + } + } + // reducing nq * topk results + for nqIdx := int64(0); nqIdx < nq; nqIdx++ { + dataCount := int64(0) + for subIdx := 0; subIdx < subSearchNum; subIdx += 1 { + subData := subSearchResultData[subIdx] + subPks := subData.GetIds() + subScores := subData.GetScores() + subGroupByVals := subData.GetGroupByFieldValue() + + nqTopK := subData.Topks[nqIdx] + for i := int64(0); i < nqTopK; i++ { + innerIdx := subSearchNqOffset[subIdx][nqIdx] + i + pk := typeutil.GetPK(subPks, innerIdx) + score := subScores[innerIdx] + groupByVal := typeutil.GetData(subData.GetGroupByFieldValue(), int(innerIdx)) + typeutil.AppendPKs(ret.Results.Ids, pk) + ret.Results.Scores = append(ret.Results.Scores, score) + if err := typeutil.AppendGroupByValue(ret.Results, groupByVal, subGroupByVals.GetType()); err != nil { + log.Ctx(ctx).Error("failed to append groupByValues", zap.Error(err)) + return ret, err + } + 
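// each appended hit widens this query's merged row; dataCount feeds Topks below +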
dataCount += 1 + } + } + ret.Results.Topks = append(ret.Results.Topks, dataCount) + } + + ret.Results.TopK = topK // return the requested topK; per-query result counts are recorded in Topks + if !metric.PositivelyRelated(metricType) { + for k := range ret.Results.Scores { + ret.Results.Scores[k] *= -1 + } + } + return ret, nil } type MilvusPKType interface{} @@ -109,37 +192,16 @@ func reduceSearchResultDataWithGroupBy(ctx context.Context, subSearchResultData Topks: []int64{}, }, } - - switch pkType { - case schemapb.DataType_Int64: - ret.GetResults().Ids.IdField = &schemapb.IDs_IntId{ - IntId: &schemapb.LongArray{ - Data: make([]int64, 0, limit), - }, - } - case schemapb.DataType_VarChar: - ret.GetResults().Ids.IdField = &schemapb.IDs_StrId{ - StrId: &schemapb.StringArray{ - Data: make([]string, 0, limit), - }, - } - default: - return nil, errors.New("unsupported pk type") + groupBound := groupSize * limit + if err := setupIdListForSearchResult(ret, pkType, groupBound); err != nil { + return ret, err } - for i, sData := range subSearchResultData { - pkLength := typeutil.GetSizeOfIDs(sData.GetIds()) - log.Ctx(ctx).Debug("subSearchResultData", - zap.Int("result No.", i), - zap.Int64("nq", sData.NumQueries), - zap.Int64("topk", sData.TopK), - zap.Int("length of pks", pkLength), - zap.Int("length of FieldsData", len(sData.FieldsData))) - ret.Results.AllSearchCount += sData.GetAllSearchCount() - if err := checkSearchResultData(sData, nq, topk); err != nil { - log.Ctx(ctx).Warn("invalid search results", zap.Error(err)) - return ret, err - } - // printSearchResultData(sData, strconv.FormatInt(int64(i), 10)) + + if allSearchCount, _, err := checkResultDatas(ctx, subSearchResultData, nq, topk); err != nil { + log.Ctx(ctx).Warn("invalid search results", zap.Error(err)) + return ret, err + } else { + ret.GetResults().AllSearchCount = allSearchCount } var ( @@ -163,7 +225,6 @@ func reduceSearchResultDataWithGroupBy(ctx context.Context, subSearchResultData var retSize int64 maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64() - groupBound := groupSize * limit // reducing nq * topk results for i := int64(0); i < nq; i++ { @@ -298,36 +359,15 @@ func reduceSearchResultDataNoGroupBy(ctx context.Context, subSearchResultData [] }, } - switch pkType { - case schemapb.DataType_Int64: - ret.GetResults().Ids.IdField = &schemapb.IDs_IntId{ - IntId: &schemapb.LongArray{ - Data: make([]int64, 0, limit), - }, - } - case schemapb.DataType_VarChar: - ret.GetResults().Ids.IdField = &schemapb.IDs_StrId{ - StrId: &schemapb.StringArray{ - Data: make([]string, 0, limit), - }, - } - default: - return nil, errors.New("unsupported pk type") + if err := setupIdListForSearchResult(ret, pkType, limit); err != nil { + return ret, err } - for i, sData := range subSearchResultData { - pkLength := typeutil.GetSizeOfIDs(sData.GetIds()) - log.Ctx(ctx).Debug("subSearchResultData", - zap.Int("result No.", i), - zap.Int64("nq", sData.NumQueries), - zap.Int64("topk", sData.TopK), - zap.Int("length of pks", pkLength), - zap.Int("length of FieldsData", len(sData.FieldsData))) - ret.Results.AllSearchCount += sData.GetAllSearchCount() - if err := checkSearchResultData(sData, nq, topk); err != nil { - log.Ctx(ctx).Warn("invalid search results", zap.Error(err)) - return ret, err - } - // printSearchResultData(sData, strconv.FormatInt(int64(i), 10)) + + if allSearchCount, _, err := checkResultDatas(ctx, subSearchResultData, nq, topk); err != nil { + log.Ctx(ctx).Warn("invalid search results", zap.Error(err)) + return ret, err + } else { + 
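// checkResultDatas has already validated nq/topk consistency for every sub-result; keep its aggregated count +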
ret.GetResults().AllSearchCount = allSearchCount } var ( @@ -428,23 +468,215 @@ func rankSearchResultData(ctx context.Context, params *rankParams, pkType schemapb.DataType, searchResults []*milvuspb.SearchResults, + groupByFieldID int64, + groupSize int64, + groupScorer func(group *Group) error, ) (*milvuspb.SearchResults, error) { - tr := timerecord.NewTimeRecorder("rankSearchResultData") + if groupByFieldID > 0 { + return rankSearchResultDataByGroup(ctx, nq, params, pkType, searchResults, groupScorer, groupSize) + } + return rankSearchResultDataByPk(ctx, nq, params, pkType, searchResults) +} + +func compareKey(keyI interface{}, keyJ interface{}) bool { + switch keyI.(type) { + case int64: + return keyI.(int64) < keyJ.(int64) + case string: + return keyI.(string) < keyJ.(string) + } + return false +} + +func GetGroupScorer(scorerType string) (func(group *Group) error, error) { + switch scorerType { + case MaxScorer: + return func(group *Group) error { + group.finalScore = group.maxScore + return nil + }, nil + case SumScorer: + return func(group *Group) error { + group.finalScore = group.sumScore + return nil + }, nil + case AvgScorer: + return func(group *Group) error { + if len(group.idList) == 0 { + return merr.WrapErrParameterInvalid(1, len(group.idList), + "input group for scoring must have at least one id; something must be wrong in the code") + } + group.finalScore = group.sumScore / float32(len(group.idList)) + return nil + }, nil + default: + return nil, merr.WrapErrParameterInvalidMsg("input group scorer type: %s is not supported!", scorerType) + } +} + +type Group struct { + idList []interface{} + scoreList []float32 + groupVal interface{} + maxScore float32 + sumScore float32 + finalScore float32 +} + +func rankSearchResultDataByGroup(ctx context.Context, + nq int64, + params *rankParams, + pkType schemapb.DataType, + searchResults []*milvuspb.SearchResults, + groupScorer func(group *Group) error, + groupSize int64, +) (*milvuspb.SearchResults, error) { + tr := timerecord.NewTimeRecorder("rankSearchResultDataByGroup") defer func() { tr.CtxElapse(ctx, "done") }() - - offset := params.offset - limit := params.limit - topk := limit + offset - roundDecimal := params.roundDecimal - log.Ctx(ctx).Debug("rankSearchResultData", + offset, limit, roundDecimal := params.offset, params.limit, params.roundDecimal + // in the context of group-by, offset/limit/topK all refer to numbers of groups + groupTopK := limit + offset + log.Ctx(ctx).Debug("rankSearchResultDataByGroup", zap.Int("len(searchResults)", len(searchResults)), zap.Int64("nq", nq), zap.Int64("offset", offset), zap.Int64("limit", limit)) - ret := &milvuspb.SearchResults{ + var ret *milvuspb.SearchResults + if ret = initSearchResults(nq, limit); len(searchResults) == 0 { + return ret, nil + } + + totalCount := limit * groupSize + if err := setupIdListForSearchResult(ret, pkType, totalCount); err != nil { + return ret, err + } + + type accumulateIDGroupVal struct { + accumulatedScore float32 + groupVal interface{} + } + + accumulatedScores := make([]map[interface{}]*accumulateIDGroupVal, nq) + for i := int64(0); i < nq; i++ { + accumulatedScores[i] = make(map[interface{}]*accumulateIDGroupVal) + } + groupByDataType := searchResults[0].GetResults().GetGroupByFieldValue().GetType() + for _, result := range searchResults { + scores := result.GetResults().GetScores() + start := 0 + // milvus limits the value range of nq and limit, + // so converting nq and topK into int is safe on both 32-bit and 64-bit platforms + 
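// first pass: for every query, accumulate each PK's score across all sub-search results; + // a PK keeps the group value it was first seen with +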
for i := 0; i < int(nq); i++ { + realTopK := int(result.GetResults().Topks[i]) + for j := start; j < start+realTopK; j++ { + id := typeutil.GetPK(result.GetResults().GetIds(), int64(j)) + groupByVal := typeutil.GetData(result.GetResults().GetGroupByFieldValue(), j) + if accumulatedScores[i][id] != nil { + accumulatedScores[i][id].accumulatedScore += scores[j] + } else { + accumulatedScores[i][id] = &accumulateIDGroupVal{accumulatedScore: scores[j], groupVal: groupByVal} + } + } + start += realTopK + } + } + + for i := int64(0); i < nq; i++ { + idSet := accumulatedScores[i] + keys := make([]interface{}, 0) + for key := range idSet { + keys = append(keys, key) + } + + // sort id by score + big := func(i, j int) bool { + scoreItemI := idSet[keys[i]] + scoreItemJ := idSet[keys[j]] + if scoreItemI.accumulatedScore == scoreItemJ.accumulatedScore { + return compareKey(keys[i], keys[j]) + } + return scoreItemI.accumulatedScore > scoreItemJ.accumulatedScore + } + sort.Slice(keys, big) + + // separate keys into buckets according to groupVal + buckets := make(map[interface{}]*Group) + for _, key := range keys { + scoreItem := idSet[key] + groupVal := scoreItem.groupVal + if buckets[groupVal] == nil { + buckets[groupVal] = &Group{ + idList: make([]interface{}, 0), + scoreList: make([]float32, 0), + groupVal: groupVal, + } + } + if int64(len(buckets[groupVal].idList)) >= groupSize { + // keep at most groupSize results in each group + continue + } + buckets[groupVal].idList = append(buckets[groupVal].idList, key) + buckets[groupVal].scoreList = append(buckets[groupVal].scoreList, scoreItem.accumulatedScore) + if scoreItem.accumulatedScore > buckets[groupVal].maxScore { + buckets[groupVal].maxScore = scoreItem.accumulatedScore + } + buckets[groupVal].sumScore += scoreItem.accumulatedScore + } + if int64(len(buckets)) <= offset { + ret.Results.Topks = append(ret.Results.Topks, 0) + continue + } + + groupList := make([]*Group, len(buckets)) + idx := 0 + for _, group := range buckets { + if err := groupScorer(group); err != nil { + return ret, err + } + groupList[idx] = group + idx += 1 + } + sort.Slice(groupList, func(i, j int) bool { + if groupList[i].finalScore == groupList[j].finalScore { + if len(groupList[i].idList) == len(groupList[j].idList) { + // if final score and size of group are both equal, + // choose the group with the smaller first key; + // here, every group is guaranteed to have at least one id in its idList + return compareKey(groupList[i].idList[0], groupList[j].idList[0]) + } + // choose the larger group when scores are equal + return len(groupList[i].idList) > len(groupList[j].idList) + } + return groupList[i].finalScore > groupList[j].finalScore + }) + + if int64(len(groupList)) > groupTopK { + groupList = groupList[:groupTopK] + } + returnedRowNum := 0 + for index := int(offset); index < len(groupList); index++ { + group := groupList[index] + for i, score := range group.scoreList { + // idList and scoreList must have the same length + typeutil.AppendPKs(ret.Results.Ids, group.idList[i]) + if roundDecimal != -1 { + multiplier := math.Pow(10.0, float64(roundDecimal)) + score = float32(math.Floor(float64(score)*multiplier+0.5) / multiplier) + } + ret.Results.Scores = append(ret.Results.Scores, score) + if err := typeutil.AppendGroupByValue(ret.Results, group.groupVal, groupByDataType); err != nil { + return ret, err + } + } + returnedRowNum += len(group.idList) + } + ret.Results.Topks = append(ret.Results.Topks, int64(returnedRowNum)) + } + + return ret, nil +} + +func initSearchResults(nq int64, limit int64) *milvuspb.SearchResults { + return &milvuspb.SearchResults{ Status: 
merr.Success(), Results: &schemapb.SearchResultData{ NumQueries: nq, @@ -455,22 +687,54 @@ func rankSearchResultData(ctx context.Context, Topks: []int64{}, }, } +} +func setupIdListForSearchResult(searchResult *milvuspb.SearchResults, pkType schemapb.DataType, capacity int64) error { switch pkType { case schemapb.DataType_Int64: - ret.GetResults().Ids.IdField = &schemapb.IDs_IntId{ + searchResult.GetResults().Ids.IdField = &schemapb.IDs_IntId{ IntId: &schemapb.LongArray{ - Data: make([]int64, 0), + Data: make([]int64, 0, capacity), }, } case schemapb.DataType_VarChar: - ret.GetResults().Ids.IdField = &schemapb.IDs_StrId{ + searchResult.GetResults().Ids.IdField = &schemapb.IDs_StrId{ StrId: &schemapb.StringArray{ - Data: make([]string, 0), + Data: make([]string, 0, capacity), }, } default: - return nil, errors.New("unsupported pk type") + return errors.New("unsupported pk type") + } + return nil +} + +func rankSearchResultDataByPk(ctx context.Context, + nq int64, + params *rankParams, + pkType schemapb.DataType, + searchResults []*milvuspb.SearchResults, +) (*milvuspb.SearchResults, error) { + tr := timerecord.NewTimeRecorder("rankSearchResultDataByPk") + defer func() { + tr.CtxElapse(ctx, "done") + }() + + offset, limit, roundDecimal := params.offset, params.limit, params.roundDecimal + topk := limit + offset + log.Ctx(ctx).Debug("rankSearchResultDataByPk", + zap.Int("len(searchResults)", len(searchResults)), + zap.Int64("nq", nq), + zap.Int64("offset", offset), + zap.Int64("limit", limit)) + + var ret *milvuspb.SearchResults + if ret = initSearchResults(nq, limit); len(searchResults) == 0 { + return ret, nil + } + + if err := setupIdListForSearchResult(ret, pkType, limit); err != nil { + return ret, err } // []map[id]score @@ -503,20 +767,10 @@ func rankSearchResultData(ctx context.Context, continue } - compareKeys := func(keyI, keyJ interface{}) bool { - switch keyI.(type) { - case int64: - return keyI.(int64) < keyJ.(int64) - case string: - return keyI.(string) < keyJ.(string) - } - return false - } - // sort id by score big := func(i, j int) bool { if idSet[keys[i]] == idSet[keys[j]] { - return compareKeys(keys[i], keys[j]) + return compareKey(keys[i], keys[j]) } return idSet[keys[i]] > idSet[keys[j]] } diff --git a/internal/proxy/search_reduce_util_test.go b/internal/proxy/search_reduce_util_test.go new file mode 100644 index 0000000000..423ac8ca66 --- /dev/null +++ b/internal/proxy/search_reduce_util_test.go @@ -0,0 +1,133 @@ +package proxy + +import ( + "context" + "testing" + + "github.com/stretchr/testify/suite" + + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" +) + +type SearchReduceUtilTestSuite struct { + suite.Suite +} + +func (struts *SearchReduceUtilTestSuite) TestRankByGroup() { + var searchResultData1 *schemapb.SearchResultData + var searchResultData2 *schemapb.SearchResultData + + { + groupFieldValue := []string{"aaa", "bbb", "ccc", "bbb", "bbb", "ccc", "aaa", "ccc", "aaa"} + searchResultData1 = &schemapb.SearchResultData{ + Ids: &schemapb.IDs{ + IdField: &schemapb.IDs_StrId{ + StrId: &schemapb.StringArray{ + Data: []string{"7", "5", "4", "2", "3", "6", "1", "9", "8"}, + }, + }, + }, + Topks: []int64{9}, + Scores: []float32{0.6, 0.53, 0.52, 0.43, 0.41, 0.33, 0.30, 0.27, 0.22}, + GroupByFieldValue: getFieldData("string", int64(101), schemapb.DataType_VarChar, groupFieldValue, 1), + } + } + + { + groupFieldValue := []string{"www", "aaa", "ccc", "www", "www", "ccc", "aaa", "ccc", "aaa"} + searchResultData2 = 
&schemapb.SearchResultData{ + Ids: &schemapb.IDs{ + IdField: &schemapb.IDs_StrId{ + StrId: &schemapb.StringArray{ + Data: []string{"17", "15", "14", "12", "13", "16", "11", "19", "18"}, + }, + }, + }, + Topks: []int64{9}, + Scores: []float32{0.7, 0.43, 0.32, 0.32, 0.31, 0.31, 0.30, 0.30, 0.30}, + GroupByFieldValue: getFieldData("string", int64(101), schemapb.DataType_VarChar, groupFieldValue, 1), + } + } + + searchResults := []*milvuspb.SearchResults{ + {Results: searchResultData1}, + {Results: searchResultData2}, + } + + nq := int64(1) + limit := int64(3) + offset := int64(0) + roundDecimal := int64(1) + groupSize := int64(3) + groupByFieldId := int64(101) + rankParams := &rankParams{limit: limit, offset: offset, roundDecimal: roundDecimal} + + { + // test for sum group scorer + scorerType := "sum" + groupScorer, _ := GetGroupScorer(scorerType) + rankedRes, err := rankSearchResultData(context.Background(), nq, rankParams, schemapb.DataType_VarChar, searchResults, groupByFieldId, groupSize, groupScorer) + struts.NoError(err) + struts.Equal([]string{"5", "2", "3", "17", "12", "13", "7", "15", "1"}, rankedRes.GetResults().GetIds().GetStrId().Data) + struts.Equal([]float32{0.5, 0.4, 0.4, 0.7, 0.3, 0.3, 0.6, 0.4, 0.3}, rankedRes.GetResults().GetScores()) + struts.Equal([]string{"bbb", "bbb", "bbb", "www", "www", "www", "aaa", "aaa", "aaa"}, rankedRes.GetResults().GetGroupByFieldValue().GetScalars().GetStringData().Data) + } + + { + // test for max group scorer + scorerType := "max" + groupScorer, _ := GetGroupScorer(scorerType) + rankedRes, err := rankSearchResultData(context.Background(), nq, rankParams, schemapb.DataType_VarChar, searchResults, groupByFieldId, groupSize, groupScorer) + struts.NoError(err) + struts.Equal([]string{"17", "12", "13", "7", "15", "1", "5", "2", "3"}, rankedRes.GetResults().GetIds().GetStrId().Data) + struts.Equal([]float32{0.7, 0.3, 0.3, 0.6, 0.4, 0.3, 0.5, 0.4, 0.4}, rankedRes.GetResults().GetScores()) + struts.Equal([]string{"www", "www", "www", "aaa", "aaa", "aaa", "bbb", "bbb", "bbb"}, rankedRes.GetResults().GetGroupByFieldValue().GetScalars().GetStringData().Data) + } + + { + // test for avg group scorer + scorerType := "avg" + groupScorer, _ := GetGroupScorer(scorerType) + rankedRes, err := rankSearchResultData(context.Background(), nq, rankParams, schemapb.DataType_VarChar, searchResults, groupByFieldId, groupSize, groupScorer) + struts.NoError(err) + struts.Equal([]string{"5", "2", "3", "17", "12", "13", "7", "15", "1"}, rankedRes.GetResults().GetIds().GetStrId().Data) + struts.Equal([]float32{0.5, 0.4, 0.4, 0.7, 0.3, 0.3, 0.6, 0.4, 0.3}, rankedRes.GetResults().GetScores()) + struts.Equal([]string{"bbb", "bbb", "bbb", "www", "www", "www", "aaa", "aaa", "aaa"}, rankedRes.GetResults().GetGroupByFieldValue().GetScalars().GetStringData().Data) + } + + { + // test for offset for ranking group + scorerType := "avg" + groupScorer, _ := GetGroupScorer(scorerType) + rankParams.offset = 2 + rankedRes, err := rankSearchResultData(context.Background(), nq, rankParams, schemapb.DataType_VarChar, searchResults, groupByFieldId, groupSize, groupScorer) + struts.NoError(err) + struts.Equal([]string{"7", "15", "1", "4", "6", "14"}, rankedRes.GetResults().GetIds().GetStrId().Data) + struts.Equal([]float32{0.6, 0.4, 0.3, 0.5, 0.3, 0.3}, rankedRes.GetResults().GetScores()) + struts.Equal([]string{"aaa", "aaa", "aaa", "ccc", "ccc", "ccc"}, rankedRes.GetResults().GetGroupByFieldValue().GetScalars().GetStringData().Data) + } + + { + // test for offset exceeding the count of final 
groups + scorerType := "avg" + groupScorer, _ := GetGroupScorer(scorerType) + rankParams.offset = 4 + rankedRes, err := rankSearchResultData(context.Background(), nq, rankParams, schemapb.DataType_VarChar, searchResults, groupByFieldId, groupSize, groupScorer) + struts.NoError(err) + struts.Equal([]string{}, rankedRes.GetResults().GetIds().GetStrId().Data) + struts.Equal([]float32{}, rankedRes.GetResults().GetScores()) + } + + { + // test for invalid group scorer + scorerType := "xxx" + groupScorer, err := GetGroupScorer(scorerType) + struts.Error(err) + struts.Nil(groupScorer) + } +} + +func TestSearchReduceUtilTestSuite(t *testing.T) { + suite.Run(t, new(SearchReduceUtilTestSuite)) +} diff --git a/internal/proxy/search_util.go b/internal/proxy/search_util.go index be9e3fc4b2..10e8ea5530 100644 --- a/internal/proxy/search_util.go +++ b/internal/proxy/search_util.go @@ -310,13 +310,15 @@ func parseRankParams(rankParamsPair []*commonpb.KeyValuePair) (*rankParams, erro } func convertHybridSearchToSearch(req *milvuspb.HybridSearchRequest) *milvuspb.SearchRequest { + searchParams := make([]*commonpb.KeyValuePair, len(req.GetRankParams())) + copy(searchParams, req.GetRankParams()) ret := &milvuspb.SearchRequest{ Base: req.GetBase(), DbName: req.GetDbName(), CollectionName: req.GetCollectionName(), PartitionNames: req.GetPartitionNames(), OutputFields: req.GetOutputFields(), - SearchParams: req.GetRankParams(), + SearchParams: searchParams, TravelTimestamp: req.GetTravelTimestamp(), GuaranteeTimestamp: req.GetGuaranteeTimestamp(), Nq: 0, diff --git a/internal/proxy/task.go b/internal/proxy/task.go index 2b4fad685a..92828399d5 100644 --- a/internal/proxy/task.go +++ b/internal/proxy/task.go @@ -42,6 +42,12 @@ import ( "github.com/milvus-io/milvus/pkg/util/typeutil" ) +const ( + SumScorer string = "sum" + MaxScorer string = "max" + AvgScorer string = "avg" +) + const ( IgnoreGrowingKey = "ignore_growing" ReduceStopForBestKey = "reduce_stop_for_best" @@ -49,6 +55,7 @@ const ( GroupByFieldKey = "group_by_field" GroupSizeKey = "group_size" GroupStrictSize = "group_strict_size" + RankGroupScorer = "rank_group_scorer" AnnsFieldKey = "anns_field" TopKKey = "topk" NQKey = "nq" diff --git a/internal/proxy/task_search.go b/internal/proxy/task_search.go index aff3d06690..b8aa6be104 100644 --- a/internal/proxy/task_search.go +++ b/internal/proxy/task_search.go @@ -21,6 +21,7 @@ import ( "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/util/exprutil" + "github.com/milvus-io/milvus/internal/util/reduce" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" "github.com/milvus-io/milvus/pkg/util/commonpbutil" @@ -76,8 +77,9 @@ type searchTask struct { queryInfos []*planpb.QueryInfo relatedDataSize int64 - reScorers []reScorer - rankParams *rankParams + reScorers []reScorer + rankParams *rankParams + groupScorer func(group *Group) error } func (t *searchTask) CanSkipAllocTimestamp() bool { @@ -339,10 +341,9 @@ func setQueryInfoIfMvEnable(queryInfo *planpb.QueryInfo, t *searchTask, plan *pl func (t *searchTask) initAdvancedSearchRequest(ctx context.Context) error { ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "init advanced search request") defer sp.End() - t.partitionIDsSet = typeutil.NewConcurrentSet[UniqueID]() - log := log.Ctx(ctx).With(zap.Int64("collID", t.GetCollectionID()), zap.String("collName", t.collectionName)) + // fetch search_growing from search param t.SearchRequest.SubReqs = 
make([]*internalpb.SubSearchRequest, len(t.request.GetSubReqs())) t.queryInfos = make([]*planpb.QueryInfo, len(t.request.GetSubReqs())) @@ -351,9 +352,7 @@ func (t *searchTask) initAdvancedSearchRequest(ctx context.Context) error { if err != nil { return err } - if queryInfo.GetGroupByFieldId() != -1 { - return errors.New("not support search_group_by operation in the hybrid search") - } + internalSubReq := &internalpb.SubSearchRequest{ Dsl: subReq.GetDsl(), PlaceholderGroup: subReq.GetPlaceholderGroup(), @@ -364,6 +363,8 @@ func (t *searchTask) initAdvancedSearchRequest(ctx context.Context) error { Topk: queryInfo.GetTopk(), Offset: offset, MetricType: queryInfo.GetMetricType(), + GroupByFieldId: queryInfo.GetGroupByFieldId(), + GroupSize: queryInfo.GetGroupSize(), } // set PartitionIDs for sub search @@ -403,6 +404,11 @@ func (t *searchTask) initAdvancedSearchRequest(ctx context.Context) error { zap.Int64s("plan.OutputFieldIds", plan.GetOutputFieldIds()), zap.Stringer("plan", plan)) // may be very large if large term passed. } + if len(t.queryInfos) > 0 { + t.SearchRequest.GroupByFieldId = t.queryInfos[0].GetGroupByFieldId() + t.SearchRequest.GroupSize = t.queryInfos[0].GetGroupSize() + } + // used for requery if t.partitionKeyMode { t.SearchRequest.PartitionIDs = t.partitionIDsSet.Collect() @@ -413,6 +419,18 @@ func (t *searchTask) initAdvancedSearchRequest(ctx context.Context) error { log.Info("generate reScorer failed", zap.Any("params", t.request.GetSearchParams()), zap.Error(err)) return err } + + // set up groupScorer for hybridsearch+groupBy + groupScorerStr, err := funcutil.GetAttrByKeyFromRepeatedKV(RankGroupScorer, t.request.GetSearchParams()) + if err != nil { + groupScorerStr = MaxScorer + } + groupScorer, err := GetGroupScorer(groupScorerStr) + if err != nil { + return err + } + t.groupScorer = groupScorer + return nil } @@ -461,7 +479,8 @@ func (t *searchTask) initSearchRequest(ctx context.Context) error { t.SearchRequest.MetricType = queryInfo.GetMetricType() t.queryInfos = append(t.queryInfos, queryInfo) t.SearchRequest.DslType = commonpb.DslType_BoolExprV1 - t.SearchRequest.ExtraSearchParam = &internalpb.ExtraSearchParam{GroupByFieldId: queryInfo.GroupByFieldId, GroupSize: queryInfo.GroupSize} + t.SearchRequest.GroupByFieldId = queryInfo.GroupByFieldId + t.SearchRequest.GroupSize = queryInfo.GroupSize log.Debug("proxy init search request", zap.Int64s("plan.OutputFieldIds", plan.GetOutputFieldIds()), zap.Stringer("plan", plan)) // may be very large if large term passed. 
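The proxy reads the scorer for hybrid-search grouping from the rank_group_scorer search parameter, falling back to "max" when the key is absent (see the groupScorer setup above). Below is a minimal client-side sketch of a rank-params list carrying these keys; the three grouping keys are the constants this patch defines (GroupByFieldKey, GroupSizeKey, RankGroupScorer), "limit" is the proxy's usual rank-param key, and the field name "color" is purely illustrative:

```go
package main

import (
	"fmt"

	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
)

func main() {
	// Rank params for a hybrid search that groups merged results.
	// "rank_group_scorer" selects how a group's final score is computed
	// from its members' accumulated scores: sum, max (default), or avg.
	rankParams := []*commonpb.KeyValuePair{
		{Key: "limit", Value: "3"},              // number of groups to return
		{Key: "group_by_field", Value: "color"}, // illustrative field name
		{Key: "group_size", Value: "3"},         // max hits kept per group
		{Key: "rank_group_scorer", Value: "avg"},
	}
	for _, kv := range rankParams {
		fmt.Printf("%s=%s\n", kv.GetKey(), kv.GetValue())
	}
}
```
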
@@ -554,7 +573,7 @@ func (t *searchTask) Execute(ctx context.Context) error { return nil } -func (t *searchTask) reduceResults(ctx context.Context, toReduceResults []*internalpb.SearchResults, nq, topK int64, offset int64, queryInfo *planpb.QueryInfo) (*milvuspb.SearchResults, error) { +func (t *searchTask) reduceResults(ctx context.Context, toReduceResults []*internalpb.SearchResults, nq, topK int64, offset int64, queryInfo *planpb.QueryInfo, isAdvance bool) (*milvuspb.SearchResults, error) { metricType := "" if len(toReduceResults) >= 1 { metricType = toReduceResults[0].GetMetricType() @@ -585,8 +604,8 @@ func (t *searchTask) reduceResults(ctx context.Context, toReduceResults []*inter return nil, err } var result *milvuspb.SearchResults - result, err = reduceSearchResult(ctx, NewReduceSearchResultInfo(validSearchResults, nq, topK, - metricType, primaryFieldSchema.DataType, offset, queryInfo)) + result, err = reduceSearchResult(ctx, validSearchResults, reduce.NewReduceSearchResultInfo(nq, topK).WithMetricType(metricType).WithPkType(primaryFieldSchema.GetDataType()). + WithOffset(offset).WithGroupByField(queryInfo.GetGroupByFieldId()).WithGroupSize(queryInfo.GetGroupSize()).WithAdvance(isAdvance)) if err != nil { log.Warn("failed to reduce search results", zap.Error(err)) return nil, err @@ -647,7 +666,6 @@ func (t *searchTask) PostExecute(ctx context.Context) error { multipleInternalResults[reqIndex] = append(multipleInternalResults[reqIndex], internalResults) } } - multipleMilvusResults := make([]*milvuspb.SearchResults, len(t.SearchRequest.GetSubReqs())) for index, internalResults := range multipleInternalResults { subReq := t.SearchRequest.GetSubReqs()[index] @@ -656,7 +674,7 @@ func (t *searchTask) PostExecute(ctx context.Context) error { if len(internalResults) >= 1 { metricType = internalResults[0].GetMetricType() } - result, err := t.reduceResults(t.ctx, internalResults, subReq.GetNq(), subReq.GetTopk(), subReq.GetOffset(), t.queryInfos[index]) + result, err := t.reduceResults(t.ctx, internalResults, subReq.GetNq(), subReq.GetTopk(), subReq.GetOffset(), t.queryInfos[index], true) if err != nil { return err } @@ -667,13 +685,16 @@ func (t *searchTask) PostExecute(ctx context.Context) error { t.result, err = rankSearchResultData(ctx, t.SearchRequest.GetNq(), t.rankParams, primaryFieldSchema.GetDataType(), - multipleMilvusResults) + multipleMilvusResults, + t.SearchRequest.GetGroupByFieldId(), + t.SearchRequest.GetGroupSize(), + t.groupScorer) if err != nil { log.Warn("rank search result failed", zap.Error(err)) return err } } else { - t.result, err = t.reduceResults(t.ctx, toReduceResults, t.SearchRequest.Nq, t.SearchRequest.GetTopk(), t.SearchRequest.GetOffset(), t.queryInfos[0]) + t.result, err = t.reduceResults(t.ctx, toReduceResults, t.SearchRequest.GetNq(), t.SearchRequest.GetTopk(), t.SearchRequest.GetOffset(), t.queryInfos[0], false) if err != nil { return err } @@ -914,7 +935,7 @@ func decodeSearchResults(ctx context.Context, searchResults []*internalpb.Search return results, nil } -func checkSearchResultData(data *schemapb.SearchResultData, nq int64, topk int64) error { +func checkSearchResultData(data *schemapb.SearchResultData, nq int64, topk int64, pkHitNum int) error { if data.NumQueries != nq { return fmt.Errorf("search result's nq(%d) mis-match with %d", data.NumQueries, nq) } @@ -922,7 +943,6 @@ func checkSearchResultData(data *schemapb.SearchResultData, nq int64, topk int64 return fmt.Errorf("search result's topk(%d) mis-match with %d", data.TopK, topk) } - pkHitNum 
:= typeutil.GetSizeOfIDs(data.GetIds()) if len(data.Scores) != pkHitNum { return fmt.Errorf("search result's score length invalid, score length=%d, expectedLength=%d", len(data.Scores), pkHitNum) diff --git a/internal/proxy/task_search_test.go b/internal/proxy/task_search_test.go index 4d53945526..b170b06cdd 100644 --- a/internal/proxy/task_search_test.go +++ b/internal/proxy/task_search_test.go @@ -39,6 +39,7 @@ import ( "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/util/dependency" + "github.com/milvus-io/milvus/internal/util/reduce" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/util/funcutil" "github.com/milvus-io/milvus/pkg/util/merr" @@ -1247,7 +1248,8 @@ func Test_checkSearchResultData(t *testing.T) { for _, test := range tests { t.Run(test.description, func(t *testing.T) { - err := checkSearchResultData(test.args.data, test.args.nq, test.args.topk) + pkLength := typeutil.GetSizeOfIDs(test.args.data.GetIds()) + err := checkSearchResultData(test.args.data, test.args.nq, test.args.topk, pkLength) if test.wantErr { assert.Error(t, err) @@ -1522,8 +1524,9 @@ func TestTaskSearch_reduceSearchResultData(t *testing.T) { } for _, test := range tests { t.Run(test.description, func(t *testing.T) { - reduced, err := reduceSearchResult(context.TODO(), - NewReduceSearchResultInfo(results, nq, topk, metric.L2, schemapb.DataType_Int64, test.offset, queryInfo)) + reduced, err := reduceSearchResult(context.TODO(), results, + reduce.NewReduceSearchResultInfo(nq, topk).WithMetricType(metric.L2).WithPkType(schemapb.DataType_Int64). + WithOffset(test.offset).WithGroupByField(queryInfo.GetGroupByFieldId()).WithGroupSize(queryInfo.GetGroupSize())) assert.NoError(t, err) assert.Equal(t, test.outData, reduced.GetResults().GetIds().GetIntId().GetData()) assert.Equal(t, []int64{test.limit, test.limit}, reduced.GetResults().GetTopks()) @@ -1574,8 +1577,9 @@ func TestTaskSearch_reduceSearchResultData(t *testing.T) { } for _, test := range lessThanLimitTests { t.Run(test.description, func(t *testing.T) { - reduced, err := reduceSearchResult(context.TODO(), NewReduceSearchResultInfo(results, nq, topk, - metric.L2, schemapb.DataType_Int64, test.offset, queryInfo)) + reduced, err := reduceSearchResult(context.TODO(), results, + reduce.NewReduceSearchResultInfo(nq, topk).WithMetricType(metric.L2).WithPkType(schemapb.DataType_Int64).WithOffset(test.offset). 
+ WithGroupByField(queryInfo.GetGroupByFieldId()).WithGroupSize(queryInfo.GetGroupSize())) assert.NoError(t, err) assert.Equal(t, test.outData, reduced.GetResults().GetIds().GetIntId().GetData()) assert.Equal(t, []int64{test.outLimit, test.outLimit}, reduced.GetResults().GetTopks()) @@ -1603,9 +1607,8 @@ func TestTaskSearch_reduceSearchResultData(t *testing.T) { GroupByFieldId: -1, } - reduced, err := reduceSearchResult(context.TODO(), NewReduceSearchResultInfo( - results, nq, topk, metric.L2, schemapb.DataType_Int64, 0, queryInfo)) - + reduced, err := reduceSearchResult(context.TODO(), results, + reduce.NewReduceSearchResultInfo(nq, topk).WithMetricType(metric.L2).WithPkType(schemapb.DataType_Int64).WithGroupByField(queryInfo.GetGroupByFieldId()).WithGroupSize(queryInfo.GetGroupSize())) assert.NoError(t, err) assert.Equal(t, resultData, reduced.GetResults().GetIds().GetIntId().GetData()) assert.Equal(t, []int64{5, 5}, reduced.GetResults().GetTopks()) @@ -1633,9 +1636,8 @@ func TestTaskSearch_reduceSearchResultData(t *testing.T) { queryInfo := &planpb.QueryInfo{ GroupByFieldId: -1, } - - reduced, err := reduceSearchResult(context.TODO(), NewReduceSearchResultInfo(results, - nq, topk, metric.L2, schemapb.DataType_VarChar, 0, queryInfo)) + reduced, err := reduceSearchResult(context.TODO(), results, + reduce.NewReduceSearchResultInfo(nq, topk).WithMetricType(metric.L2).WithPkType(schemapb.DataType_VarChar).WithGroupByField(queryInfo.GetGroupByFieldId()).WithGroupSize(queryInfo.GetGroupSize())) assert.NoError(t, err) assert.Equal(t, resultData, reduced.GetResults().GetIds().GetStrId().GetData()) @@ -1708,8 +1710,8 @@ func TestTaskSearch_reduceGroupBySearchResultData(t *testing.T) { GroupByFieldId: 1, GroupSize: 1, } - reduced, err := reduceSearchResult(context.TODO(), NewReduceSearchResultInfo(results, nq, topK, metric.L2, - schemapb.DataType_Int64, 0, queryInfo)) + reduced, err := reduceSearchResult(context.TODO(), results, + reduce.NewReduceSearchResultInfo(nq, topK).WithMetricType(metric.L2).WithPkType(schemapb.DataType_Int64).WithGroupByField(queryInfo.GetGroupByFieldId()).WithGroupSize(queryInfo.GetGroupSize())) resultIDs := reduced.GetResults().GetIds().GetIntId().Data resultScores := reduced.GetResults().GetScores() resultGroupByValues := reduced.GetResults().GetGroupByFieldValue().GetScalars().GetLongData().GetData() @@ -1768,8 +1770,8 @@ func TestTaskSearch_reduceGroupBySearchResultDataWithOffset(t *testing.T) { GroupByFieldId: 1, GroupSize: 1, } - reduced, err := reduceSearchResult(context.TODO(), NewReduceSearchResultInfo(results, nq, limit+offset, metric.L2, - schemapb.DataType_Int64, offset, queryInfo)) + reduced, err := reduceSearchResult(context.TODO(), results, + reduce.NewReduceSearchResultInfo(nq, limit+offset).WithMetricType(metric.L2).WithPkType(schemapb.DataType_Int64).WithOffset(offset).WithGroupByField(queryInfo.GetGroupByFieldId()).WithGroupSize(queryInfo.GetGroupSize())) resultIDs := reduced.GetResults().GetIds().GetIntId().Data resultScores := reduced.GetResults().GetScores() resultGroupByValues := reduced.GetResults().GetGroupByFieldValue().GetScalars().GetLongData().GetData() @@ -1842,8 +1844,9 @@ func TestTaskSearch_reduceGroupBySearchWithGroupSizeMoreThanOne(t *testing.T) { GroupByFieldId: 1, GroupSize: 2, } - reduced, err := reduceSearchResult(context.TODO(), NewReduceSearchResultInfo(results, nq, topK, metric.L2, - schemapb.DataType_Int64, 0, queryInfo)) + reduced, err := reduceSearchResult(context.TODO(), results, + reduce.NewReduceSearchResultInfo(nq, 
topK).WithMetricType(metric.L2).WithPkType(schemapb.DataType_Int64).WithGroupByField(queryInfo.GetGroupByFieldId()).WithGroupSize(queryInfo.GetGroupSize())) + resultIDs := reduced.GetResults().GetIds().GetIntId().Data resultScores := reduced.GetResults().GetScores() resultGroupByValues := reduced.GetResults().GetGroupByFieldValue().GetScalars().GetLongData().GetData() @@ -1855,6 +1858,188 @@ func TestTaskSearch_reduceGroupBySearchWithGroupSizeMoreThanOne(t *testing.T) { } } +func TestTaskSearch_reduceAdvanceSearchGroupBy(t *testing.T) { + groupByField := int64(101) + nq := int64(1) + subSearchResultData := make([]*schemapb.SearchResultData, 0) + topK := int64(3) + { + scores := []float32{0.9, 0.7, 0.65, 0.55, 0.52, 0.51, 0.5, 0.45, 0.43} + ids := []int64{7, 5, 6, 11, 22, 14, 31, 23, 37} + tops := []int64{9} + groupFieldValue := []string{"aaa", "bbb", "ccc", "bbb", "bbb", "ccc", "aaa", "ccc", "aaa"} + groupByVals := getFieldData("string", groupByField, schemapb.DataType_VarChar, groupFieldValue, 1) + result1 := &schemapb.SearchResultData{ + Scores: scores, + TopK: topK, + Ids: &schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{ + Data: ids, + }, + }, + }, + NumQueries: nq, + Topks: tops, + GroupByFieldValue: groupByVals, + } + subSearchResultData = append(subSearchResultData, result1) + } + { + scores := []float32{0.83, 0.72, 0.72, 0.65, 0.63, 0.55, 0.52, 0.51, 0.48} + ids := []int64{17, 15, 16, 21, 32, 24, 41, 33, 27} + tops := []int64{9} + groupFieldValue := []string{"xxx", "bbb", "ddd", "bbb", "bbb", "ddd", "xxx", "ddd", "xxx"} + groupByVals := getFieldData("string", groupByField, schemapb.DataType_VarChar, groupFieldValue, 1) + result2 := &schemapb.SearchResultData{ + TopK: topK, + Scores: scores, + Ids: &schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{ + Data: ids, + }, + }, + }, + Topks: tops, + NumQueries: nq, + GroupByFieldValue: groupByVals, + } + subSearchResultData = append(subSearchResultData, result2) + } + groupSize := int64(3) + + reducedRes, err := reduceSearchResult(context.Background(), subSearchResultData, + reduce.NewReduceSearchResultInfo(nq, topK).WithMetricType(metric.IP).WithPkType(schemapb.DataType_Int64).WithGroupByField(groupByField).WithGroupSize(groupSize).WithAdvance(true)) + assert.NoError(t, err) + // reduce_advance_groupby will only merge results from different delegator without reducing any result + assert.Equal(t, 18, len(reducedRes.GetResults().Ids.GetIntId().Data)) + assert.Equal(t, 18, len(reducedRes.GetResults().GetScores())) + assert.Equal(t, 18, len(reducedRes.GetResults().GetGroupByFieldValue().GetScalars().GetStringData().Data)) + assert.Equal(t, topK, reducedRes.GetResults().GetTopK()) + assert.Equal(t, []int64{18}, reducedRes.GetResults().GetTopks()) + + assert.Equal(t, []int64{7, 5, 6, 11, 22, 14, 31, 23, 37, 17, 15, 16, 21, 32, 24, 41, 33, 27}, reducedRes.GetResults().Ids.GetIntId().Data) + assert.Equal(t, []float32{0.9, 0.7, 0.65, 0.55, 0.52, 0.51, 0.5, 0.45, 0.43, 0.83, 0.72, 0.72, 0.65, 0.63, 0.55, 0.52, 0.51, 0.48}, reducedRes.GetResults().GetScores()) + assert.Equal(t, []string{"aaa", "bbb", "ccc", "bbb", "bbb", "ccc", "aaa", "ccc", "aaa", "xxx", "bbb", "ddd", "bbb", "bbb", "ddd", "xxx", "ddd", "xxx"}, reducedRes.GetResults().GetGroupByFieldValue().GetScalars().GetStringData().Data) +} + +func TestTaskSearch_reduceAdvanceSearchGroupByShortCut(t *testing.T) { + groupByField := int64(101) + nq := int64(1) + subSearchResultData := make([]*schemapb.SearchResultData, 0) + topK := int64(3) + { 
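+ // a single sub-result: reduceAdvanceGroupBY's one-channel shortcut should hand it back unchanged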
+ scores := []float32{0.9, 0.7, 0.65, 0.55, 0.52, 0.51, 0.5, 0.45, 0.43} + ids := []int64{7, 5, 6, 11, 22, 14, 31, 23, 37} + tops := []int64{9} + groupFieldValue := []string{"aaa", "bbb", "ccc", "bbb", "bbb", "ccc", "aaa", "ccc", "aaa"} + groupByVals := getFieldData("string", groupByField, schemapb.DataType_VarChar, groupFieldValue, 1) + result1 := &schemapb.SearchResultData{ + Scores: scores, + TopK: topK, + Ids: &schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{ + Data: ids, + }, + }, + }, + NumQueries: nq, + Topks: tops, + GroupByFieldValue: groupByVals, + } + subSearchResultData = append(subSearchResultData, result1) + } + groupSize := int64(3) + + reducedRes, err := reduceSearchResult(context.Background(), subSearchResultData, + reduce.NewReduceSearchResultInfo(nq, topK).WithMetricType(metric.L2).WithPkType(schemapb.DataType_Int64).WithGroupByField(groupByField).WithGroupSize(groupSize).WithAdvance(true)) + + assert.NoError(t, err) + // reduce_advance_groupby will only merge results from different delegator without reducing any result + assert.Equal(t, 9, len(reducedRes.GetResults().Ids.GetIntId().Data)) + assert.Equal(t, 9, len(reducedRes.GetResults().GetScores())) + assert.Equal(t, 9, len(reducedRes.GetResults().GetGroupByFieldValue().GetScalars().GetStringData().Data)) + assert.Equal(t, topK, reducedRes.GetResults().GetTopK()) + assert.Equal(t, []int64{9}, reducedRes.GetResults().GetTopks()) + + assert.Equal(t, []int64{7, 5, 6, 11, 22, 14, 31, 23, 37}, reducedRes.GetResults().Ids.GetIntId().Data) + assert.Equal(t, []float32{0.9, 0.7, 0.65, 0.55, 0.52, 0.51, 0.5, 0.45, 0.43}, reducedRes.GetResults().GetScores()) + assert.Equal(t, []string{"aaa", "bbb", "ccc", "bbb", "bbb", "ccc", "aaa", "ccc", "aaa"}, reducedRes.GetResults().GetGroupByFieldValue().GetScalars().GetStringData().Data) +} + +func TestTaskSearch_reduceAdvanceSearchGroupByMultipleNq(t *testing.T) { + groupByField := int64(101) + nq := int64(2) + subSearchResultData := make([]*schemapb.SearchResultData, 0) + topK := int64(2) + groupSize := int64(2) + { + scores := []float32{0.9, 0.7, 0.65, 0.55, 0.51, 0.5, 0.45, 0.43} + ids := []int64{7, 5, 6, 11, 14, 31, 23, 37} + tops := []int64{4, 4} + groupFieldValue := []string{"ccc", "bbb", "ccc", "bbb", "aaa", "xxx", "xxx", "aaa"} + groupByVals := getFieldData("string", groupByField, schemapb.DataType_VarChar, groupFieldValue, 1) + result1 := &schemapb.SearchResultData{ + Scores: scores, + TopK: topK, + Ids: &schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{ + Data: ids, + }, + }, + }, + NumQueries: nq, + Topks: tops, + GroupByFieldValue: groupByVals, + } + subSearchResultData = append(subSearchResultData, result1) + } + { + scores := []float32{0.83, 0.72, 0.72, 0.65, 0.63, 0.55, 0.52, 0.51} + ids := []int64{17, 15, 16, 21, 32, 24, 41, 33} + tops := []int64{4, 4} + groupFieldValue := []string{"ddd", "bbb", "ddd", "bbb", "rrr", "sss", "rrr", "sss"} + groupByVals := getFieldData("string", groupByField, schemapb.DataType_VarChar, groupFieldValue, 1) + result2 := &schemapb.SearchResultData{ + TopK: topK, + Scores: scores, + Ids: &schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{ + Data: ids, + }, + }, + }, + Topks: tops, + NumQueries: nq, + GroupByFieldValue: groupByVals, + } + subSearchResultData = append(subSearchResultData, result2) + } + + reducedRes, err := reduceSearchResult(context.Background(), subSearchResultData, + reduce.NewReduceSearchResultInfo(nq, 
topK).WithMetricType(metric.IP).WithPkType(schemapb.DataType_Int64).WithGroupByField(groupByField).WithGroupSize(groupSize).WithAdvance(true)) + assert.NoError(t, err) + // reduce_advance_groupby will only merge results from different delegator without reducing any result + assert.Equal(t, 16, len(reducedRes.GetResults().Ids.GetIntId().Data)) + assert.Equal(t, 16, len(reducedRes.GetResults().GetScores())) + assert.Equal(t, 16, len(reducedRes.GetResults().GetGroupByFieldValue().GetScalars().GetStringData().Data)) + + assert.Equal(t, topK, reducedRes.GetResults().GetTopK()) + assert.Equal(t, []int64{8, 8}, reducedRes.GetResults().GetTopks()) + + assert.Equal(t, []int64{7, 5, 6, 11, 17, 15, 16, 21, 14, 31, 23, 37, 32, 24, 41, 33}, reducedRes.GetResults().Ids.GetIntId().Data) + assert.Equal(t, []float32{0.9, 0.7, 0.65, 0.55, 0.83, 0.72, 0.72, 0.65, 0.51, 0.5, 0.45, 0.43, 0.63, 0.55, 0.52, 0.51}, reducedRes.GetResults().GetScores()) + assert.Equal(t, []string{"ccc", "bbb", "ccc", "bbb", "ddd", "bbb", "ddd", "bbb", "aaa", "xxx", "xxx", "aaa", "rrr", "sss", "rrr", "sss"}, reducedRes.GetResults().GetGroupByFieldValue().GetScalars().GetStringData().Data) +} + func TestSearchTask_ErrExecute(t *testing.T) { var ( err error diff --git a/internal/querynodev2/delegator/delegator.go b/internal/querynodev2/delegator/delegator.go index e9bbe9877f..b04003ac12 100644 --- a/internal/querynodev2/delegator/delegator.go +++ b/internal/querynodev2/delegator/delegator.go @@ -43,6 +43,7 @@ import ( "github.com/milvus-io/milvus/internal/querynodev2/segments" "github.com/milvus-io/milvus/internal/querynodev2/tsafe" "github.com/milvus-io/milvus/internal/storage" + "github.com/milvus-io/milvus/internal/util/reduce" "github.com/milvus-io/milvus/internal/util/streamrpc" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" @@ -332,6 +333,8 @@ func (sd *shardDelegator) Search(ctx context.Context, req *querypb.SearchRequest IgnoreGrowing: req.GetReq().GetIgnoreGrowing(), Username: req.GetReq().GetUsername(), IsAdvanced: false, + GroupByFieldId: subReq.GetGroupByFieldId(), + GroupSize: subReq.GetGroupSize(), } future := conc.Go(func() (*internalpb.SearchResults, error) { searchReq := &querypb.SearchRequest{ @@ -350,14 +353,12 @@ func (sd *shardDelegator) Search(ctx context.Context, req *querypb.SearchRequest return nil, err } - return segments.ReduceSearchResults(ctx, + return segments.ReduceSearchOnQueryNode(ctx, results, - segments.NewReduceInfo(searchReq.Req.GetNq(), - searchReq.Req.GetTopk(), - searchReq.Req.GetExtraSearchParam().GetGroupByFieldId(), - searchReq.Req.GetExtraSearchParam().GetGroupSize(), - searchReq.Req.GetMetricType()), - ) + reduce.NewReduceSearchResultInfo(searchReq.GetReq().GetNq(), + searchReq.GetReq().GetTopk()).WithMetricType(searchReq.GetReq().GetMetricType()). + WithGroupByField(searchReq.GetReq().GetGroupByFieldId()). 
+ WithGroupSize(searchReq.GetReq().GetGroupSize())) }) futures[index] = future } @@ -376,12 +377,7 @@ func (sd *shardDelegator) Search(ctx context.Context, req *querypb.SearchRequest } results[i] = result } - var ret *internalpb.SearchResults - ret, err = segments.MergeToAdvancedResults(ctx, results) - if err != nil { - return nil, err - } - return []*internalpb.SearchResults{ret}, nil + return results, nil } return sd.search(ctx, req, sealed, growing) } diff --git a/internal/querynodev2/handlers.go b/internal/querynodev2/handlers.go index bf71bf74d6..73db8509a6 100644 --- a/internal/querynodev2/handlers.go +++ b/internal/querynodev2/handlers.go @@ -32,6 +32,7 @@ import ( "github.com/milvus-io/milvus/internal/querynodev2/delegator" "github.com/milvus-io/milvus/internal/querynodev2/segments" "github.com/milvus-io/milvus/internal/querynodev2/tasks" + "github.com/milvus-io/milvus/internal/util/reduce" "github.com/milvus-io/milvus/internal/util/streamrpc" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" @@ -384,16 +385,10 @@ func (node *QueryNode) searchChannel(ctx context.Context, req *querypb.SearchReq req.GetSegmentIDs(), )) - var resp *internalpb.SearchResults - if req.GetReq().GetIsAdvanced() { - resp, err = segments.ReduceAdvancedSearchResults(ctx, results, req.Req.GetNq()) - } else { - resp, err = segments.ReduceSearchResults(ctx, results, segments.NewReduceInfo(req.Req.GetNq(), - req.Req.GetTopk(), - req.Req.GetExtraSearchParam().GetGroupByFieldId(), - req.Req.GetExtraSearchParam().GetGroupSize(), - req.Req.GetMetricType())) - } + resp, err := segments.ReduceSearchOnQueryNode(ctx, results, + reduce.NewReduceSearchResultInfo(req.GetReq().GetNq(), + req.GetReq().GetTopk()).WithMetricType(req.GetReq().GetMetricType()).WithGroupByField(req.GetReq().GetGroupByFieldId()). 
+ WithGroupSize(req.GetReq().GetGroupSize()).WithAdvance(req.GetReq().GetIsAdvanced())) if err != nil { return nil, err } diff --git a/internal/querynodev2/segments/result.go b/internal/querynodev2/segments/result.go index fc43edb5c5..6549835f8b 100644 --- a/internal/querynodev2/segments/result.go +++ b/internal/querynodev2/segments/result.go @@ -29,6 +29,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/segcorepb" + "github.com/milvus-io/milvus/internal/util/reduce" typeutil2 "github.com/milvus-io/milvus/internal/util/typeutil" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" @@ -42,7 +43,14 @@ var _ typeutil.ResultWithID = &internalpb.RetrieveResults{} var _ typeutil.ResultWithID = &segcorepb.RetrieveResults{} -func ReduceSearchResults(ctx context.Context, results []*internalpb.SearchResults, info *ReduceInfo) (*internalpb.SearchResults, error) { +func ReduceSearchOnQueryNode(ctx context.Context, results []*internalpb.SearchResults, info *reduce.ResultInfo) (*internalpb.SearchResults, error) { + if info.GetIsAdvance() { + return ReduceAdvancedSearchResults(ctx, results) + } + return ReduceSearchResults(ctx, results, info) +} + +func ReduceSearchResults(ctx context.Context, results []*internalpb.SearchResults, info *reduce.ResultInfo) (*internalpb.SearchResults, error) { results = lo.Filter(results, func(result *internalpb.SearchResults, _ int) bool { return result != nil && result.GetSlicedBlob() != nil }) @@ -60,8 +68,8 @@ func ReduceSearchResults(ctx context.Context, results []*internalpb.SearchResult channelsMvcc[ch] = ts } // shouldn't let new SearchResults.MetricType to be empty, though the req.MetricType is empty - if info.metricType == "" { - info.metricType = r.MetricType + if info.GetMetricType() == "" { + info.SetMetricType(r.MetricType) } } log := log.Ctx(ctx) @@ -86,7 +94,7 @@ func ReduceSearchResults(ctx context.Context, results []*internalpb.SearchResult log.Warn("shard leader reduce errors", zap.Error(err)) return nil, err } - searchResults, err := EncodeSearchResultData(ctx, reducedResultData, info.nq, info.topK, info.metricType) + searchResults, err := EncodeSearchResultData(ctx, reducedResultData, info.GetNq(), info.GetTopK(), info.GetMetricType()) if err != nil { log.Warn("shard leader encode search result errors", zap.Error(err)) return nil, err @@ -115,7 +123,7 @@ func ReduceSearchResults(ctx context.Context, results []*internalpb.SearchResult return searchResults, nil } -func ReduceAdvancedSearchResults(ctx context.Context, results []*internalpb.SearchResults, nq int64) (*internalpb.SearchResults, error) { +func ReduceAdvancedSearchResults(ctx context.Context, results []*internalpb.SearchResults) (*internalpb.SearchResults, error) { _, sp := otel.Tracer(typeutil.QueryNodeRole).Start(ctx, "ReduceAdvancedSearchResults") defer sp.End() @@ -129,53 +137,14 @@ func ReduceAdvancedSearchResults(ctx context.Context, results []*internalpb.Sear IsAdvanced: true, } - for _, result := range results { - relatedDataSize += result.GetCostAggregation().GetTotalRelatedDataSize() - for ch, ts := range result.GetChannelsMvcc() { - channelsMvcc[ch] = ts - } - if !result.GetIsAdvanced() { - continue - } - // we just append here, no need to split subResult and reduce - // defer this reduce to proxy - searchResults.SubResults = append(searchResults.SubResults, result.GetSubResults()...) 
- searchResults.NumQueries = result.GetNumQueries() - } - searchResults.ChannelsMvcc = channelsMvcc - requestCosts := lo.FilterMap(results, func(result *internalpb.SearchResults, _ int) (*internalpb.CostAggregation, bool) { - if paramtable.Get().QueryNodeCfg.EnableWorkerSQCostMetrics.GetAsBool() { - return result.GetCostAggregation(), true - } - - if result.GetBase().GetSourceID() == paramtable.GetNodeID() { - return result.GetCostAggregation(), true - } - - return nil, false - }) - searchResults.CostAggregation = mergeRequestCost(requestCosts) - if searchResults.CostAggregation == nil { - searchResults.CostAggregation = &internalpb.CostAggregation{} - } - searchResults.CostAggregation.TotalRelatedDataSize = relatedDataSize - return searchResults, nil -} - -func MergeToAdvancedResults(ctx context.Context, results []*internalpb.SearchResults) (*internalpb.SearchResults, error) { - searchResults := &internalpb.SearchResults{ - IsAdvanced: true, - } - - channelsMvcc := make(map[string]uint64) - relatedDataSize := int64(0) for index, result := range results { relatedDataSize += result.GetCostAggregation().GetTotalRelatedDataSize() for ch, ts := range result.GetChannelsMvcc() { channelsMvcc[ch] = ts } + searchResults.NumQueries = result.GetNumQueries() // we just append here, no need to split subResult and reduce - // defer this reduce to proxy + // defer this reduction to proxy subResult := &internalpb.SubSearchResults{ MetricType: result.GetMetricType(), NumQueries: result.GetNumQueries(), @@ -185,7 +154,6 @@ func MergeToAdvancedResults(ctx context.Context, results []*internalpb.SearchRes SlicedOffset: result.GetSlicedOffset(), ReqIndex: int64(index), } - searchResults.NumQueries = result.GetNumQueries() searchResults.SubResults = append(searchResults.SubResults, subResult) } searchResults.ChannelsMvcc = channelsMvcc diff --git a/internal/querynodev2/segments/result_test.go b/internal/querynodev2/segments/result_test.go index 4f4d4fead7..2d9a2fa993 100644 --- a/internal/querynodev2/segments/result_test.go +++ b/internal/querynodev2/segments/result_test.go @@ -29,7 +29,9 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/segcorepb" + "github.com/milvus-io/milvus/internal/util/reduce" "github.com/milvus-io/milvus/pkg/common" + "github.com/milvus-io/milvus/pkg/util/metric" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/typeutil" ) @@ -886,6 +888,42 @@ func (suite *ResultSuite) TestSort() { }, result.FieldsData[9].GetScalars().GetArrayData().GetData()) } +func (suite *ResultSuite) TestReduceSearchOnQueryNode() { + results := make([]*internalpb.SearchResults, 0) + metricType := metric.IP + nq := int64(1) + topK := int64(1) + mockBlob := []byte{65, 66, 67, 65, 66, 67} + { + subRes1 := &internalpb.SearchResults{ + MetricType: metricType, + NumQueries: nq, + TopK: topK, + SlicedBlob: mockBlob, + } + results = append(results, subRes1) + } + { + subRes2 := &internalpb.SearchResults{ + MetricType: metricType, + NumQueries: nq, + TopK: topK, + SlicedBlob: mockBlob, + } + results = append(results, subRes2) + } + reducedRes, err := ReduceSearchOnQueryNode(context.Background(), results, reduce.NewReduceSearchResultInfo(nq, topK). 
+ WithMetricType(metricType).WithPkType(schemapb.DataType_Int8).WithAdvance(true)) + suite.NoError(err) + suite.Equal(2, len(reducedRes.GetSubResults())) + + subRes1 := reducedRes.GetSubResults()[0] + suite.Equal(metricType, subRes1.GetMetricType()) + suite.Equal(nq, subRes1.GetNumQueries()) + suite.Equal(topK, subRes1.GetTopK()) + suite.Equal(mockBlob, subRes1.GetSlicedBlob()) +} + func TestResult_MergeRequestCost(t *testing.T) { costs := []*internalpb.CostAggregation{ { diff --git a/internal/querynodev2/segments/search_reduce.go b/internal/querynodev2/segments/search_reduce.go index 14dff5fc7c..a8997115ab 100644 --- a/internal/querynodev2/segments/search_reduce.go +++ b/internal/querynodev2/segments/search_reduce.go @@ -8,39 +8,28 @@ import ( "go.uber.org/zap" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/util/reduce" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/typeutil" ) -type ReduceInfo struct { - nq int64 - topK int64 - groupByFieldID int64 - groupSize int64 - metricType string -} - -func NewReduceInfo(nq int64, topK int64, groupByFieldID int64, groupSize int64, metric string) *ReduceInfo { - return &ReduceInfo{nq, topK, groupByFieldID, groupSize, metric} -} - type SearchReduce interface { - ReduceSearchResultData(ctx context.Context, searchResultData []*schemapb.SearchResultData, info *ReduceInfo) (*schemapb.SearchResultData, error) + ReduceSearchResultData(ctx context.Context, searchResultData []*schemapb.SearchResultData, info *reduce.ResultInfo) (*schemapb.SearchResultData, error) } type SearchCommonReduce struct{} -func (scr *SearchCommonReduce) ReduceSearchResultData(ctx context.Context, searchResultData []*schemapb.SearchResultData, info *ReduceInfo) (*schemapb.SearchResultData, error) { +func (scr *SearchCommonReduce) ReduceSearchResultData(ctx context.Context, searchResultData []*schemapb.SearchResultData, info *reduce.ResultInfo) (*schemapb.SearchResultData, error) { ctx, sp := otel.Tracer(typeutil.QueryNodeRole).Start(ctx, "ReduceSearchResultData") defer sp.End() log := log.Ctx(ctx) if len(searchResultData) == 0 { return &schemapb.SearchResultData{ - NumQueries: info.nq, - TopK: info.topK, + NumQueries: info.GetNq(), + TopK: info.GetTopK(), FieldsData: make([]*schemapb.FieldData, 0), Scores: make([]float32, 0), Ids: &schemapb.IDs{}, @@ -48,8 +37,8 @@ func (scr *SearchCommonReduce) ReduceSearchResultData(ctx context.Context, searc }, nil } ret := &schemapb.SearchResultData{ - NumQueries: info.nq, - TopK: info.topK, + NumQueries: info.GetNq(), + TopK: info.GetTopK(), FieldsData: make([]*schemapb.FieldData, len(searchResultData[0].FieldsData)), Scores: make([]float32, 0), Ids: &schemapb.IDs{}, @@ -59,7 +48,7 @@ func (scr *SearchCommonReduce) ReduceSearchResultData(ctx context.Context, searc resultOffsets := make([][]int64, len(searchResultData)) for i := 0; i < len(searchResultData); i++ { resultOffsets[i] = make([]int64, len(searchResultData[i].Topks)) - for j := int64(1); j < info.nq; j++ { + for j := int64(1); j < info.GetNq(); j++ { resultOffsets[i][j] = resultOffsets[i][j-1] + searchResultData[i].Topks[j-1] } ret.AllSearchCount += searchResultData[i].GetAllSearchCount() @@ -68,11 +57,11 @@ func (scr *SearchCommonReduce) ReduceSearchResultData(ctx context.Context, searc var skipDupCnt int64 var retSize int64 maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64() - for i := int64(0); 
i < info.nq; i++ { + for i := int64(0); i < info.GetNq(); i++ { offsets := make([]int64, len(searchResultData)) idSet := make(map[interface{}]struct{}) var j int64 - for j = 0; j < info.topK; { + for j = 0; j < info.GetTopK(); { sel := SelectSearchResultData(searchResultData, resultOffsets, offsets, i) if sel == -1 { break @@ -113,15 +102,15 @@ func (scr *SearchCommonReduce) ReduceSearchResultData(ctx context.Context, searc type SearchGroupByReduce struct{} -func (sbr *SearchGroupByReduce) ReduceSearchResultData(ctx context.Context, searchResultData []*schemapb.SearchResultData, info *ReduceInfo) (*schemapb.SearchResultData, error) { +func (sbr *SearchGroupByReduce) ReduceSearchResultData(ctx context.Context, searchResultData []*schemapb.SearchResultData, info *reduce.ResultInfo) (*schemapb.SearchResultData, error) { ctx, sp := otel.Tracer(typeutil.QueryNodeRole).Start(ctx, "ReduceSearchResultData") defer sp.End() log := log.Ctx(ctx) if len(searchResultData) == 0 { return &schemapb.SearchResultData{ - NumQueries: info.nq, - TopK: info.topK, + NumQueries: info.GetNq(), + TopK: info.GetTopK(), FieldsData: make([]*schemapb.FieldData, 0), Scores: make([]float32, 0), Ids: &schemapb.IDs{}, @@ -129,8 +118,8 @@ func (sbr *SearchGroupByReduce) ReduceSearchResultData(ctx context.Context, sear }, nil } ret := &schemapb.SearchResultData{ - NumQueries: info.nq, - TopK: info.topK, + NumQueries: info.GetNq(), + TopK: info.GetTopK(), FieldsData: make([]*schemapb.FieldData, len(searchResultData[0].FieldsData)), Scores: make([]float32, 0), Ids: &schemapb.IDs{}, @@ -140,7 +129,7 @@ func (sbr *SearchGroupByReduce) ReduceSearchResultData(ctx context.Context, sear resultOffsets := make([][]int64, len(searchResultData)) for i := 0; i < len(searchResultData); i++ { resultOffsets[i] = make([]int64, len(searchResultData[i].Topks)) - for j := int64(1); j < info.nq; j++ { + for j := int64(1); j < info.GetNq(); j++ { resultOffsets[i][j] = resultOffsets[i][j-1] + searchResultData[i].Topks[j-1] } ret.AllSearchCount += searchResultData[i].GetAllSearchCount() @@ -149,13 +138,13 @@ func (sbr *SearchGroupByReduce) ReduceSearchResultData(ctx context.Context, sear var filteredCount int64 var retSize int64 maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64() - groupSize := info.groupSize + groupSize := info.GetGroupSize() if groupSize <= 0 { groupSize = 1 } - groupBound := info.topK * groupSize + groupBound := info.GetTopK() * groupSize - for i := int64(0); i < info.nq; i++ { + for i := int64(0); i < info.GetNq(); i++ { offsets := make([]int64, len(searchResultData)) idSet := make(map[interface{}]struct{}) @@ -178,7 +167,7 @@ func (sbr *SearchGroupByReduce) ReduceSearchResultData(ctx context.Context, sear } groupCount := groupByValueMap[groupByVal] - if groupCount == 0 && int64(len(groupByValueMap)) >= info.topK { + if groupCount == 0 && int64(len(groupByValueMap)) >= info.GetTopK() { // exceed the limit for group count, filter this entity filteredCount++ } else if groupCount >= groupSize { @@ -219,8 +208,8 @@ func (sbr *SearchGroupByReduce) ReduceSearchResultData(ctx context.Context, sear return ret, nil } -func InitSearchReducer(info *ReduceInfo) SearchReduce { - if info.groupByFieldID > 0 { +func InitSearchReducer(info *reduce.ResultInfo) SearchReduce { + if info.GetGroupByFieldId() > 0 { return &SearchGroupByReduce{} } return &SearchCommonReduce{} diff --git a/internal/querynodev2/segments/search_reduce_test.go b/internal/querynodev2/segments/search_reduce_test.go index 05eb4a058f..22ff091a0e 100644 --- 
a/internal/querynodev2/segments/search_reduce_test.go +++ b/internal/querynodev2/segments/search_reduce_test.go @@ -7,6 +7,7 @@ import ( "github.com/stretchr/testify/suite" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/util/reduce" "github.com/milvus-io/milvus/pkg/util/paramtable" ) @@ -28,7 +29,7 @@ func (suite *SearchReduceSuite) TestResult_ReduceSearchResultData() { dataArray := make([]*schemapb.SearchResultData, 0) dataArray = append(dataArray, data1) dataArray = append(dataArray, data2) - reduceInfo := &ReduceInfo{nq: nq, topK: topk} + reduceInfo := reduce.NewReduceSearchResultInfo(nq, topk).WithGroupSize(1) searchReduce := InitSearchReducer(reduceInfo) res, err := searchReduce.ReduceSearchResultData(context.TODO(), dataArray, reduceInfo) suite.Nil(err) @@ -47,7 +48,7 @@ func (suite *SearchReduceSuite) TestResult_ReduceSearchResultData() { dataArray := make([]*schemapb.SearchResultData, 0) dataArray = append(dataArray, data1) dataArray = append(dataArray, data2) - reduceInfo := &ReduceInfo{nq: nq, topK: topk} + reduceInfo := reduce.NewReduceSearchResultInfo(nq, topk).WithGroupSize(1) searchReduce := InitSearchReducer(reduceInfo) res, err := searchReduce.ReduceSearchResultData(context.TODO(), dataArray, reduceInfo) suite.Nil(err) @@ -96,7 +97,7 @@ func (suite *SearchReduceSuite) TestResult_SearchGroupByResult() { dataArray := make([]*schemapb.SearchResultData, 0) dataArray = append(dataArray, data1) dataArray = append(dataArray, data2) - reduceInfo := &ReduceInfo{nq: nq, topK: topk, groupByFieldID: 101} + reduceInfo := reduce.NewReduceSearchResultInfo(nq, topk).WithGroupSize(1).WithGroupByField(101) searchReduce := InitSearchReducer(reduceInfo) res, err := searchReduce.ReduceSearchResultData(context.TODO(), dataArray, reduceInfo) suite.Nil(err) @@ -140,7 +141,7 @@ func (suite *SearchReduceSuite) TestResult_SearchGroupByResult() { dataArray := make([]*schemapb.SearchResultData, 0) dataArray = append(dataArray, data1) dataArray = append(dataArray, data2) - reduceInfo := &ReduceInfo{nq: nq, topK: topk, groupByFieldID: 101} + reduceInfo := reduce.NewReduceSearchResultInfo(nq, topk).WithGroupSize(1).WithGroupByField(101) searchReduce := InitSearchReducer(reduceInfo) res, err := searchReduce.ReduceSearchResultData(context.TODO(), dataArray, reduceInfo) suite.Nil(err) @@ -184,7 +185,7 @@ func (suite *SearchReduceSuite) TestResult_SearchGroupByResult() { dataArray := make([]*schemapb.SearchResultData, 0) dataArray = append(dataArray, data1) dataArray = append(dataArray, data2) - reduceInfo := &ReduceInfo{nq: nq, topK: topk, groupByFieldID: 101} + reduceInfo := reduce.NewReduceSearchResultInfo(nq, topk).WithGroupSize(1).WithGroupByField(101) searchReduce := InitSearchReducer(reduceInfo) res, err := searchReduce.ReduceSearchResultData(context.TODO(), dataArray, reduceInfo) suite.Nil(err) @@ -228,7 +229,7 @@ func (suite *SearchReduceSuite) TestResult_SearchGroupByResult() { dataArray := make([]*schemapb.SearchResultData, 0) dataArray = append(dataArray, data1) dataArray = append(dataArray, data2) - reduceInfo := &ReduceInfo{nq: nq, topK: topk, groupByFieldID: 101, groupSize: 3} + reduceInfo := reduce.NewReduceSearchResultInfo(nq, topk).WithGroupSize(3).WithGroupByField(101) searchReduce := InitSearchReducer(reduceInfo) res, err := searchReduce.ReduceSearchResultData(context.TODO(), dataArray, reduceInfo) suite.Nil(err) @@ -239,7 +240,7 @@ func (suite *SearchReduceSuite) TestResult_SearchGroupByResult() { suite.Run("reduce_group_by_empty_input", 
func() {
 		dataArray := make([]*schemapb.SearchResultData, 0)
-		reduceInfo := &ReduceInfo{nq: nq, topK: topk, groupByFieldID: 101, groupSize: 3}
+		reduceInfo := reduce.NewReduceSearchResultInfo(nq, topk).WithGroupSize(3).WithGroupByField(101)
 		searchReduce := InitSearchReducer(reduceInfo)
 		res, err := searchReduce.ReduceSearchResultData(context.TODO(), dataArray, reduceInfo)
 		suite.Nil(err)
diff --git a/internal/querynodev2/services.go b/internal/querynodev2/services.go
index 464b49da5e..3a87f575d2 100644
--- a/internal/querynodev2/services.go
+++ b/internal/querynodev2/services.go
@@ -753,69 +753,41 @@ func (node *QueryNode) Search(ctx context.Context, req *querypb.SearchRequest) (
 		return resp, nil
 	}
 
-	toReduceResults := make([]*internalpb.SearchResults, len(req.GetDmlChannels()))
-	runningGp, runningCtx := errgroup.WithContext(ctx)
-
-	for i, ch := range req.GetDmlChannels() {
-		ch := ch
-		req := &querypb.SearchRequest{
-			Req:             req.Req,
-			DmlChannels:     []string{ch},
-			SegmentIDs:      req.SegmentIDs,
-			Scope:           req.Scope,
-			TotalChannelNum: req.TotalChannelNum,
-		}
-
-		i := i
-		runningGp.Go(func() error {
-			ret, err := node.searchChannel(runningCtx, req, ch)
-			if err != nil {
-				return err
-			}
-			if err := merr.Error(ret.GetStatus()); err != nil {
-				return err
-			}
-			toReduceResults[i] = ret
-			return nil
-		})
+	if len(req.GetDmlChannels()) != 1 {
+		err := merr.WrapErrParameterInvalid(1, len(req.GetDmlChannels()), "the count of DML channels to be searched must be exactly 1")
+		resp.Status = merr.Status(err)
+		log.Warn("got wrong number of channels to be searched", zap.Error(err))
+		return resp, nil
 	}
-	if err := runningGp.Wait(); err != nil {
+
+	ch := req.GetDmlChannels()[0]
+	channelReq := &querypb.SearchRequest{
+		Req:             req.Req,
+		DmlChannels:     []string{ch},
+		SegmentIDs:      req.SegmentIDs,
+		Scope:           req.Scope,
+		TotalChannelNum: req.TotalChannelNum,
+	}
+	ret, err := node.searchChannel(ctx, channelReq, ch)
+	if err != nil {
 		resp.Status = merr.Status(err)
 		return resp, nil
 	}
 
 	tr.RecordSpan()
-	var result *internalpb.SearchResults
-	var err2 error
-	if req.GetReq().GetIsAdvanced() {
-		result, err2 = segments.ReduceAdvancedSearchResults(ctx, toReduceResults, req.Req.GetNq())
-	} else {
-		result, err2 = segments.ReduceSearchResults(ctx, toReduceResults, segments.NewReduceInfo(req.Req.GetNq(),
-			req.Req.GetTopk(),
-			req.Req.GetExtraSearchParam().GetGroupByFieldId(),
-			req.Req.GetExtraSearchParam().GetGroupSize(),
-			req.Req.GetMetricType()))
-	}
-
-	if err2 != nil {
-		log.Warn("failed to reduce search results", zap.Error(err2))
-		resp.Status = merr.Status(err2)
-		return resp, nil
-	}
-	result.Status = merr.Success()
+	ret.Status = merr.Success()
 	reduceLatency := tr.RecordSpan()
 	metrics.QueryNodeReduceLatency.
 		WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.SearchLabel, metrics.ReduceShards, metrics.BatchReduce).
 		Observe(float64(reduceLatency.Milliseconds()))
 
 	metrics.QueryNodeExecuteCounter.WithLabelValues(strconv.FormatInt(node.GetNodeID(), 10), metrics.SearchLabel).
Add(float64(proto.Size(req))) - if result.GetCostAggregation() != nil { - result.GetCostAggregation().ResponseTime = tr.ElapseSpan().Milliseconds() + if ret.GetCostAggregation() != nil { + ret.GetCostAggregation().ResponseTime = tr.ElapseSpan().Milliseconds() } - return result, nil + return ret, nil } // only used for delegator query segments from worker diff --git a/internal/util/reduce/reduce_info.go b/internal/util/reduce/reduce_info.go new file mode 100644 index 0000000000..91de9f2df2 --- /dev/null +++ b/internal/util/reduce/reduce_info.go @@ -0,0 +1,92 @@ +package reduce + +import ( + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" +) + +type ResultInfo struct { + nq int64 + topK int64 + metricType string + pkType schemapb.DataType + offset int64 + groupByFieldId int64 + groupSize int64 + isAdvance bool +} + +func NewReduceSearchResultInfo( + nq int64, + topK int64, +) *ResultInfo { + return &ResultInfo{ + nq: nq, + topK: topK, + } +} + +func (r *ResultInfo) WithMetricType(metricType string) *ResultInfo { + r.metricType = metricType + return r +} + +func (r *ResultInfo) WithPkType(pkType schemapb.DataType) *ResultInfo { + r.pkType = pkType + return r +} + +func (r *ResultInfo) WithOffset(offset int64) *ResultInfo { + r.offset = offset + return r +} + +func (r *ResultInfo) WithGroupByField(groupByField int64) *ResultInfo { + r.groupByFieldId = groupByField + return r +} + +func (r *ResultInfo) WithGroupSize(groupSize int64) *ResultInfo { + r.groupSize = groupSize + return r +} + +func (r *ResultInfo) WithAdvance(advance bool) *ResultInfo { + r.isAdvance = advance + return r +} + +func (r *ResultInfo) GetNq() int64 { + return r.nq +} + +func (r *ResultInfo) GetTopK() int64 { + return r.topK +} + +func (r *ResultInfo) GetMetricType() string { + return r.metricType +} + +func (r *ResultInfo) GetPkType() schemapb.DataType { + return r.pkType +} + +func (r *ResultInfo) GetOffset() int64 { + return r.offset +} + +func (r *ResultInfo) GetGroupByFieldId() int64 { + return r.groupByFieldId +} + +func (r *ResultInfo) GetGroupSize() int64 { + return r.groupSize +} + +func (r *ResultInfo) GetIsAdvance() bool { + return r.isAdvance +} + +func (r *ResultInfo) SetMetricType(metricType string) { + r.metricType = metricType +}
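
---

Usage note: the new internal/util/reduce.ResultInfo builder replaces both the proxy-side reduceSearchResultInfo struct and the query-node-side ReduceInfo struct with a single type shared by both layers. A minimal sketch of how a caller assembles it, under the package paths introduced by this patch; the literal values (nq=2, topK=10, field 101) are illustrative only:

package example

import (
	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"

	"github.com/milvus-io/milvus/internal/util/reduce"
)

// buildInfo shows how the chained With* options compose. Every setter
// returns the receiver, so options may be chained in any order; fields
// left unset keep their zero values (groupByFieldId == 0 means
// "no group-by" to InitSearchReducer, and isAdvance == false selects a
// full reduce rather than sub-result pass-through).
func buildInfo() *reduce.ResultInfo {
	return reduce.NewReduceSearchResultInfo(2, 10).
		WithMetricType("L2").
		WithPkType(schemapb.DataType_Int64).
		WithOffset(5).
		WithGroupByField(101).
		WithGroupSize(3).
		WithAdvance(false)
}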
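ReduceSearchOnQueryNode is now the single reduce entry point on the query node: when the advance flag is set (hybrid search), per-request results are merely wrapped into SubResults and the real reduction is deferred to the proxy, where the fused scores exist. A shape-only sketch of a caller; reduceOnNode is a hypothetical wrapper and the nq/topK/metric values are illustrative:

package example

import (
	"context"

	"github.com/milvus-io/milvus/internal/proto/internalpb"
	"github.com/milvus-io/milvus/internal/querynodev2/segments"
	"github.com/milvus-io/milvus/internal/util/reduce"
)

// reduceOnNode mirrors what searchChannel and QueryNode.Search now do:
// one call site, with the advance flag deciding between SubResults
// pass-through and a full common/group-by reduce.
func reduceOnNode(ctx context.Context, results []*internalpb.SearchResults, isAdvance bool) (*internalpb.SearchResults, error) {
	// nq=1, topK=10 are placeholders for the values taken from the request.
	info := reduce.NewReduceSearchResultInfo(1, 10).
		WithMetricType("IP").
		WithAdvance(isAdvance)
	return segments.ReduceSearchOnQueryNode(ctx, results, info)
}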
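SearchGroupByReduce enforces two caps while merging sorted result streams: at most groupSize hits per group value, and at most topK distinct groups per query, for a total bound of groupBound = topK * groupSize hits. The patch applies this inline over schemapb result blobs; the standalone sketch below (a hypothetical capByGroup over plain slices instead of SearchResultData, and without the PK dedup the real reducer also performs) isolates just that filtering rule:

package example

// capByGroup keeps at most groupSize hits per group value and opens at
// most topK distinct groups; hits and groups are parallel slices already
// sorted by descending score, and the return value is the kept hit ids.
func capByGroup(hits []int64, groups []string, topK, groupSize int64) []int64 {
	if groupSize <= 0 {
		groupSize = 1 // same defaulting as SearchGroupByReduce
	}
	groupBound := topK * groupSize
	counts := make(map[string]int64, topK)
	kept := make([]int64, 0, groupBound)
	for i, id := range hits {
		if int64(len(kept)) >= groupBound {
			break
		}
		g := groups[i]
		n := counts[g]
		switch {
		case n == 0 && int64(len(counts)) >= topK:
			// would open a new group beyond the topK distinct-group cap: filter
		case n >= groupSize:
			// this group already holds groupSize entities: filter
		default:
			counts[g] = n + 1
			kept = append(kept, id)
		}
	}
	return kept
}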