package proxy

import (
	"context"
	"fmt"

	"github.com/cockroachdb/errors"
	"go.opentelemetry.io/otel"
	"go.uber.org/zap"
	"google.golang.org/protobuf/proto"

	"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
	"github.com/milvus-io/milvus/internal/util/reduce"
	"github.com/milvus-io/milvus/pkg/v2/log"
	"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
	"github.com/milvus-io/milvus/pkg/v2/proto/planpb"
	"github.com/milvus-io/milvus/pkg/v2/util/merr"
	"github.com/milvus-io/milvus/pkg/v2/util/metric"
	"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
	"github.com/milvus-io/milvus/pkg/v2/util/timerecord"
	"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)

func reduceSearchResult(ctx context.Context, subSearchResultData []*schemapb.SearchResultData, reduceInfo *reduce.ResultInfo) (*milvuspb.SearchResults, error) {
	if reduceInfo.GetGroupByFieldId() > 0 {
		if reduceInfo.GetIsAdvance() {
			// For group-by in hybrid search, results from a single search path cannot
			// be reduced here, because the final score has not been accumulated yet;
			// for the same reason, offset cannot be applied either.
			return reduceAdvanceGroupBy(ctx,
				subSearchResultData, reduceInfo.GetNq(), reduceInfo.GetTopK(), reduceInfo.GetPkType(), reduceInfo.GetMetricType())
		}
		return reduceSearchResultDataWithGroupBy(ctx,
			subSearchResultData,
			reduceInfo.GetNq(),
			reduceInfo.GetTopK(),
			reduceInfo.GetMetricType(),
			reduceInfo.GetPkType(),
			reduceInfo.GetOffset(),
			reduceInfo.GetGroupSize())
	}
	return reduceSearchResultDataNoGroupBy(ctx,
		subSearchResultData,
		reduceInfo.GetNq(),
		reduceInfo.GetTopK(),
		reduceInfo.GetMetricType(),
		reduceInfo.GetPkType(),
		reduceInfo.GetOffset())
}

func checkResultDatas(ctx context.Context, subSearchResultData []*schemapb.SearchResultData,
	nq int64, topK int64,
) (int64, int, error) {
	var allSearchCount int64
	var hitNum int
	for i, sData := range subSearchResultData {
		pkLength := typeutil.GetSizeOfIDs(sData.GetIds())
		log.Ctx(ctx).Debug("subSearchResultData",
			zap.Int("result No.", i),
			zap.Int64("nq", sData.NumQueries),
			zap.Int64("topk", sData.TopK),
			zap.Int("length of pks", pkLength),
			zap.Int("length of FieldsData", len(sData.FieldsData)))
		allSearchCount += sData.GetAllSearchCount()
		hitNum += pkLength
		if err := checkSearchResultData(sData, nq, topK, pkLength); err != nil {
			log.Ctx(ctx).Warn("invalid search results", zap.Error(err))
			return allSearchCount, hitNum, err
		}
	}
	return allSearchCount, hitNum, nil
}

func reduceAdvanceGroupBy(ctx context.Context, subSearchResultData []*schemapb.SearchResultData,
	nq int64, topK int64, pkType schemapb.DataType, metricType string,
) (*milvuspb.SearchResults, error) {
	log.Ctx(ctx).Debug("reduceAdvanceGroupBy", zap.Int("len(subSearchResultData)", len(subSearchResultData)), zap.Int64("nq", nq))
	// For advanced group-by, offset is not applied, so just return when there is only one channel.
	if len(subSearchResultData) == 1 {
		return &milvuspb.SearchResults{
			Status:  merr.Success(),
			Results: subSearchResultData[0],
		}, nil
	}

	ret := &milvuspb.SearchResults{
		Status: merr.Success(),
		Results: &schemapb.SearchResultData{
			NumQueries: nq,
			TopK:       topK,
			Scores:     []float32{},
			Ids:        &schemapb.IDs{},
			Topks:      []int64{},
		},
	}
	var limit int64
	if allSearchCount, hitNum, err := checkResultDatas(ctx, subSearchResultData, nq, topK); err != nil {
		log.Ctx(ctx).Warn("invalid search results", zap.Error(err))
		return ret, err
	} else {
		ret.GetResults().AllSearchCount = allSearchCount
		limit = int64(hitNum)
		// Find the first non-empty FieldsData as the template.
		for _, result := range subSearchResultData {
			if len(result.GetFieldsData()) > 0 {
				ret.GetResults().FieldsData = typeutil.PrepareResultFieldData(result.GetFieldsData(), limit)
				break
			}
		}
	}

	if err := setupIdListForSearchResult(ret, pkType, limit); err != nil {
		return ret, err
	}

	var (
		subSearchNum = len(subSearchResultData)
		// for results of each subSearchResultData, storing the start offset of each query of nq queries
		subSearchNqOffset = make([][]int64, subSearchNum)
	)
	for i := 0; i < subSearchNum; i++ {
		subSearchNqOffset[i] = make([]int64, subSearchResultData[i].GetNumQueries())
		for j := int64(1); j < nq; j++ {
			subSearchNqOffset[i][j] = subSearchNqOffset[i][j-1] + subSearchResultData[i].Topks[j-1]
		}
	}
	gpFieldBuilder, err := typeutil.NewFieldDataBuilder(subSearchResultData[0].GetGroupByFieldValue().GetType(), true, int(limit))
	if err != nil {
		return ret, merr.WrapErrServiceInternal("failed to construct group by field data builder, this is abnormal as segcore should always set up a group by field, no matter data status, check code on qn", err.Error())
	}
	// reduce nq * topK results
	for nqIdx := int64(0); nqIdx < nq; nqIdx++ {
		dataCount := int64(0)
		for subIdx := 0; subIdx < subSearchNum; subIdx++ {
			subData := subSearchResultData[subIdx]
			subPks := subData.GetIds()
			subScores := subData.GetScores()
			subGroupByVals := subData.GetGroupByFieldValue()

			nqTopK := subData.Topks[nqIdx]
			groupByValIterator := typeutil.GetDataIterator(subGroupByVals)

			for i := int64(0); i < nqTopK; i++ {
				innerIdx := subSearchNqOffset[subIdx][nqIdx] + i
				pk := typeutil.GetPK(subPks, innerIdx)
				score := subScores[innerIdx]
				groupByVal := groupByValIterator(int(innerIdx))
				gpFieldBuilder.Add(groupByVal)
				typeutil.AppendPKs(ret.Results.Ids, pk)
				ret.Results.Scores = append(ret.Results.Scores, score)
				dataCount++
			}
		}
		ret.Results.Topks = append(ret.Results.Topks, dataCount)
	}
	ret.Results.GroupByFieldValue = gpFieldBuilder.Build()
	ret.Results.TopK = topK

	if !metric.PositivelyRelated(metricType) {
		for k := range ret.Results.Scores {
			ret.Results.Scores[k] *= -1
		}
	}
	return ret, nil
}
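
// MilvusPKType stands in for a single primary-key value; given the cases
// handled in setupIdListForSearchResult below, it holds an int64 for
// DataType_Int64 PKs or a string for DataType_VarChar PKs.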
type MilvusPKType interface{}

type groupReduceInfo struct {
	subSearchIdx int
	resultIdx    int64
	score        float32
	id           MilvusPKType
}

func reduceSearchResultDataWithGroupBy(ctx context.Context, subSearchResultData []*schemapb.SearchResultData,
	nq int64, topk int64, metricType string, pkType schemapb.DataType, offset int64, groupSize int64,
) (*milvuspb.SearchResults, error) {
	tr := timerecord.NewTimeRecorder("reduceSearchResultData")
	defer func() {
		tr.CtxElapse(ctx, "done")
	}()

	limit := topk - offset
	log.Ctx(ctx).Debug("reduceSearchResultData",
		zap.Int("len(subSearchResultData)", len(subSearchResultData)),
		zap.Int64("nq", nq),
		zap.Int64("offset", offset),
		zap.Int64("limit", limit),
		zap.String("metricType", metricType))

	ret := &milvuspb.SearchResults{
		Status: merr.Success(),
		Results: &schemapb.SearchResultData{
			NumQueries: nq,
			TopK:       topk,
			FieldsData: []*schemapb.FieldData{},
			Scores:     []float32{},
			Ids:        &schemapb.IDs{},
			Topks:      []int64{},
		},
	}
	groupBound := groupSize * limit
	if err := setupIdListForSearchResult(ret, pkType, groupBound); err != nil {
		return ret, err
	}

	if allSearchCount, _, err := checkResultDatas(ctx, subSearchResultData, nq, topk); err != nil {
		log.Ctx(ctx).Warn("invalid search results", zap.Error(err))
		return ret, err
	} else {
		ret.GetResults().AllSearchCount = allSearchCount
	}

	// Find the first non-empty FieldsData as the template.
	for _, result := range subSearchResultData {
		if len(result.GetFieldsData()) > 0 {
			ret.GetResults().FieldsData = typeutil.PrepareResultFieldData(result.GetFieldsData(), limit)
			break
		}
	}
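
	// Offset bookkeeping below, illustrated with hypothetical numbers: a
	// sub-result answering nq = 3 queries with Topks = [2, 3, 1] stores its
	// hits in flat Ids/Scores slices, so its per-query start offsets come out
	// as [0, 2, 5], and query q's hits occupy [offset[q], offset[q]+Topks[q]).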
	var (
		subSearchNum = len(subSearchResultData)
		// for results of each subSearchResultData, storing the start offset of each query of nq queries
		subSearchNqOffset                 = make([][]int64, subSearchNum)
		totalResCount               int64 = 0
		subSearchGroupByValIterator       = make([]func(int) any, subSearchNum)
	)
	for i := 0; i < subSearchNum; i++ {
		subSearchNqOffset[i] = make([]int64, subSearchResultData[i].GetNumQueries())
		for j := int64(1); j < nq; j++ {
			subSearchNqOffset[i][j] = subSearchNqOffset[i][j-1] + subSearchResultData[i].Topks[j-1]
		}
		totalResCount += subSearchNqOffset[i][nq-1]
		subSearchGroupByValIterator[i] = typeutil.GetDataIterator(subSearchResultData[i].GetGroupByFieldValue())
	}

	gpFieldBuilder, err := typeutil.NewFieldDataBuilder(subSearchResultData[0].GetGroupByFieldValue().GetType(), true, int(limit))
	if err != nil {
		return ret, merr.WrapErrServiceInternal("failed to construct group by field data builder, this is abnormal as segcore should always set up a group by field, no matter data status, check code on qn", err.Error())
	}

	var realTopK int64 = -1
	var retSize int64

	maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64()
	// reduce nq * topk results
	for i := int64(0); i < nq; i++ {
		var (
			// cursor of current data of each subSearch for merging the j-th data of TopK;
			// sum(cursors) == j
			cursors = make([]int64, subSearchNum)

			j              int64
			groupByValMap  = make(map[interface{}][]*groupReduceInfo)
			skipOffsetMap  = make(map[interface{}]bool)
			groupByValList = make([]interface{}, limit)
			groupByValIdx  = 0
		)

		for j = 0; j < groupBound; {
			subSearchIdx, resultDataIdx := selectHighestScoreIndex(ctx, subSearchResultData, subSearchNqOffset, cursors, i)
			if subSearchIdx == -1 {
				break
			}
			subSearchRes := subSearchResultData[subSearchIdx]

			id := typeutil.GetPK(subSearchRes.GetIds(), resultDataIdx)
			score := subSearchRes.GetScores()[resultDataIdx]
			groupByVal := subSearchGroupByValIterator[subSearchIdx](int(resultDataIdx))

			if int64(len(skipOffsetMap)) < offset || skipOffsetMap[groupByVal] {
				// groups covered by the offset are skipped entirely
				skipOffsetMap[groupByVal] = true
			} else if len(groupByValMap[groupByVal]) == 0 && int64(len(groupByValMap)) >= limit {
				// skip: the group-by map is already full and this is a new group value
			} else if int64(len(groupByValMap[groupByVal])) >= groupSize {
				// skip: the target group is already full
			} else {
				if len(groupByValMap[groupByVal]) == 0 {
					groupByValList[groupByValIdx] = groupByVal
					groupByValIdx++
				}
				groupByValMap[groupByVal] = append(groupByValMap[groupByVal], &groupReduceInfo{
					subSearchIdx: subSearchIdx,
					resultIdx:    resultDataIdx,
					id:           id,
					score:        score,
				})
				j++
			}
			cursors[subSearchIdx]++
		}

		// Assemble all eligible values per group; entries in groupByValList are
		// ordered by the highest score within each group.
		for _, groupVal := range groupByValList {
			groupEntities := groupByValMap[groupVal]
			for _, groupEntity := range groupEntities {
				subResData := subSearchResultData[groupEntity.subSearchIdx]
				if len(ret.Results.FieldsData) > 0 {
					retSize += typeutil.AppendFieldData(ret.Results.FieldsData, subResData.FieldsData, groupEntity.resultIdx)
				}
				typeutil.AppendPKs(ret.Results.Ids, groupEntity.id)
				ret.Results.Scores = append(ret.Results.Scores, groupEntity.score)
				gpFieldBuilder.Add(groupVal)
			}
		}

		if realTopK != -1 && realTopK != j {
			log.Ctx(ctx).Warn("Proxy Reduce Search Result", zap.Error(errors.New("the length (topk) between all result of query is different")))
		}
		realTopK = j
		ret.Results.Topks = append(ret.Results.Topks, realTopK)

		// limit the search result size to avoid oom
		if retSize > maxOutputSize {
			return nil, fmt.Errorf("search results exceed the maxOutputSize Limit %d", maxOutputSize)
		}
	}
	// Build the group-by field column once, after the groups of all nq queries
	// have been added to the builder.
	ret.Results.GroupByFieldValue = gpFieldBuilder.Build()
	ret.Results.TopK = realTopK // realTopK is the topK of the nq-th query

	if !metric.PositivelyRelated(metricType) {
		for k := range ret.Results.Scores {
			ret.Results.Scores[k] *= -1
		}
	}
	return ret, nil
}
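
// reduceSearchResultDataNoGroupBy merges per-shard results into one globally
// sorted top-k list per query: it first advances the per-shard cursors past
// `offset` hits, then repeatedly takes the highest-scoring remaining hit
// across all sub-results via selectHighestScoreIndex.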
func reduceSearchResultDataNoGroupBy(ctx context.Context, subSearchResultData []*schemapb.SearchResultData, nq int64, topk int64, metricType string, pkType schemapb.DataType, offset int64) (*milvuspb.SearchResults, error) {
	tr := timerecord.NewTimeRecorder("reduceSearchResultData")
	defer func() {
		tr.CtxElapse(ctx, "done")
	}()

	limit := topk - offset
	log.Ctx(ctx).Debug("reduceSearchResultData",
		zap.Int("len(subSearchResultData)", len(subSearchResultData)),
		zap.Int64("nq", nq),
		zap.Int64("offset", offset),
		zap.Int64("limit", limit),
		zap.String("metricType", metricType))

	ret := &milvuspb.SearchResults{
		Status: merr.Success(),
		Results: &schemapb.SearchResultData{
			NumQueries: nq,
			TopK:       topk,
			FieldsData: []*schemapb.FieldData{},
			Scores:     []float32{},
			Ids:        &schemapb.IDs{},
			Topks:      []int64{},
		},
	}

	if err := setupIdListForSearchResult(ret, pkType, limit); err != nil {
		return ret, err
	}

	if allSearchCount, _, err := checkResultDatas(ctx, subSearchResultData, nq, topk); err != nil {
		log.Ctx(ctx).Warn("invalid search results", zap.Error(err))
		return ret, err
	} else {
		ret.GetResults().AllSearchCount = allSearchCount
	}

	// Find the first non-empty FieldsData as the template.
	for _, result := range subSearchResultData {
		if len(result.GetFieldsData()) > 0 {
			ret.GetResults().FieldsData = typeutil.PrepareResultFieldData(result.GetFieldsData(), limit)
			break
		}
	}

	subSearchNum := len(subSearchResultData)
	if subSearchNum == 1 && offset == 0 {
		// Sorting is not needed if there is only one shard and no offset; assign the
		// result directly. The scores still need to be adjusted below.
		ret.Results = subSearchResultData[0]
		// realTopK is the topK of the nq-th query; it is used in the proxy but not
		// set by the delegator.
		topks := subSearchResultData[0].Topks
		if len(topks) > 0 {
			ret.Results.TopK = topks[len(topks)-1]
		}
	} else {
		var realTopK int64 = -1
		var retSize int64

		// for results of each subSearchResultData, storing the start offset of each query of nq queries
		subSearchNqOffset := make([][]int64, subSearchNum)
		for i := 0; i < subSearchNum; i++ {
			subSearchNqOffset[i] = make([]int64, subSearchResultData[i].GetNumQueries())
			for j := int64(1); j < nq; j++ {
				subSearchNqOffset[i][j] = subSearchNqOffset[i][j-1] + subSearchResultData[i].Topks[j-1]
			}
		}

		maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64()
		// reduce nq * topk results
		for i := int64(0); i < nq; i++ {
			var (
				// cursor of current data of each subSearch for merging the j-th data of TopK;
				// sum(cursors) == j
				cursors = make([]int64, subSearchNum)
				j       int64
			)

			// skip offset results
			for k := int64(0); k < offset; k++ {
				subSearchIdx, _ := selectHighestScoreIndex(ctx, subSearchResultData, subSearchNqOffset, cursors, i)
				if subSearchIdx == -1 {
					break
				}
				cursors[subSearchIdx]++
			}

			// keep limit results
			for j = 0; j < limit; j++ {
				// From all the sub-query result sets of the i-th query vector,
				// find the sub-query result set holding the j-th highest score,
				// and the index of that entry within its schemapb.SearchResultData.
				subSearchIdx, resultDataIdx := selectHighestScoreIndex(ctx, subSearchResultData, subSearchNqOffset, cursors, i)
				if subSearchIdx == -1 {
					break
				}
				score := subSearchResultData[subSearchIdx].Scores[resultDataIdx]

				if len(ret.Results.FieldsData) > 0 {
					retSize += typeutil.AppendFieldData(ret.Results.FieldsData, subSearchResultData[subSearchIdx].FieldsData, resultDataIdx)
				}
				typeutil.CopyPk(ret.Results.Ids, subSearchResultData[subSearchIdx].GetIds(), int(resultDataIdx))
				ret.Results.Scores = append(ret.Results.Scores, score)
				cursors[subSearchIdx]++
			}
			if realTopK != -1 && realTopK != j {
				log.Ctx(ctx).Warn("Proxy Reduce Search Result", zap.Error(errors.New("the length (topk) between all result of query is different")))
				// return nil, errors.New("the length (topk) between all result of query is different")
			}
			realTopK = j
			ret.Results.Topks = append(ret.Results.Topks, realTopK)

			// limit the search result size to avoid oom
			if retSize > maxOutputSize {
				return nil, fmt.Errorf("search results exceed the maxOutputSize Limit %d", maxOutputSize)
			}
		}
		ret.Results.TopK = realTopK // realTopK is the topK of the nq-th query
	}

	if !metric.PositivelyRelated(metricType) {
		for k := range ret.Results.Scores {
			ret.Results.Scores[k] *= -1
		}
	}
	return ret, nil
}
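
// compareKey reports whether keyI orders before keyJ; it supports the two
// primary-key types used in this file (int64 and string) and returns false
// for any other type.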
func compareKey(keyI interface{}, keyJ interface{}) bool {
	switch keyI := keyI.(type) {
	case int64:
		return keyI < keyJ.(int64)
	case string:
		return keyI < keyJ.(string)
	}
	return false
}

func setupIdListForSearchResult(searchResult *milvuspb.SearchResults, pkType schemapb.DataType, capacity int64) error {
	switch pkType {
	case schemapb.DataType_Int64:
		searchResult.GetResults().Ids.IdField = &schemapb.IDs_IntId{
			IntId: &schemapb.LongArray{
				Data: make([]int64, 0, capacity),
			},
		}
	case schemapb.DataType_VarChar:
		searchResult.GetResults().Ids.IdField = &schemapb.IDs_StrId{
			StrId: &schemapb.StringArray{
				Data: make([]string, 0, capacity),
			},
		}
	default:
		return errors.New("unsupported pk type")
	}
	return nil
}

func fillInEmptyResult(numQueries int64) *milvuspb.SearchResults {
	return &milvuspb.SearchResults{
		Status: merr.Success("search result is empty"),
		Results: &schemapb.SearchResultData{
			NumQueries: numQueries,
			Topks:      make([]int64, numQueries),
		},
	}
}

func reduceResults(ctx context.Context, toReduceResults []*internalpb.SearchResults, nq, topK, offset int64, metricType string, pkType schemapb.DataType, queryInfo *planpb.QueryInfo, isAdvance bool, collectionID int64, partitionIDs []int64) (*milvuspb.SearchResults, error) {
	ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "reduceResults")
	defer sp.End()

	log := log.Ctx(ctx)

	// Decode all search results.
	validSearchResults, err := decodeSearchResults(ctx, toReduceResults)
	if err != nil {
		log.Warn("failed to decode search results", zap.Error(err))
		return nil, err
	}

	if len(validSearchResults) <= 0 {
		return fillInEmptyResult(nq), nil
	}

	// Reduce all search results.
	log.Debug("proxy search post execute reduce",
		zap.Int64("collection", collectionID),
		zap.Int64s("partitionIDs", partitionIDs),
		zap.Int("number of valid search results", len(validSearchResults)))
	var result *milvuspb.SearchResults
	result, err = reduceSearchResult(ctx, validSearchResults,
		reduce.NewReduceSearchResultInfo(nq, topK).WithMetricType(metricType).WithPkType(pkType).
			WithOffset(offset).WithGroupByField(queryInfo.GetGroupByFieldId()).WithGroupSize(queryInfo.GetGroupSize()).WithAdvance(isAdvance))
	if err != nil {
		log.Warn("failed to reduce search results", zap.Error(err))
		return nil, err
	}
	return result, nil
}
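
// A minimal usage sketch of the reduce entry point (hypothetical values; real
// callers derive them from the search request, as reduceResults does above):
//
//	info := reduce.NewReduceSearchResultInfo(nq, topK).
//		WithMetricType(metric.L2).
//		WithPkType(schemapb.DataType_Int64).
//		WithOffset(0)
//	merged, err := reduceSearchResult(ctx, decoded, info)
//
// Group-by and hybrid-search callers additionally chain WithGroupByField,
// WithGroupSize, and WithAdvance.
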
search results", len(validSearchResults))) var result *milvuspb.SearchResults result, err = reduceSearchResult(ctx, validSearchResults, reduce.NewReduceSearchResultInfo(nq, topK).WithMetricType(metricType).WithPkType(pkType). WithOffset(offset).WithGroupByField(queryInfo.GetGroupByFieldId()).WithGroupSize(queryInfo.GetGroupSize()).WithAdvance(isAdvance)) if err != nil { log.Warn("failed to reduce search results", zap.Error(err)) return nil, err } return result, nil } func decodeSearchResults(ctx context.Context, searchResults []*internalpb.SearchResults) ([]*schemapb.SearchResultData, error) { ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "decodeSearchResults") defer sp.End() tr := timerecord.NewTimeRecorder("decodeSearchResults") results := make([]*schemapb.SearchResultData, 0) for _, partialSearchResult := range searchResults { if partialSearchResult.SlicedBlob == nil { continue } var partialResultData schemapb.SearchResultData err := proto.Unmarshal(partialSearchResult.SlicedBlob, &partialResultData) if err != nil { return nil, err } results = append(results, &partialResultData) } tr.CtxElapse(ctx, "decodeSearchResults done") return results, nil } func checkSearchResultData(data *schemapb.SearchResultData, nq int64, topk int64, pkHitNum int) error { if data.NumQueries != nq { return fmt.Errorf("search result's nq(%d) mis-match with %d", data.NumQueries, nq) } if data.TopK != topk { return fmt.Errorf("search result's topk(%d) mis-match with %d", data.TopK, topk) } if len(data.Scores) != pkHitNum { return fmt.Errorf("search result's score length invalid, score length=%d, expectedLength=%d", len(data.Scores), pkHitNum) } return nil } func selectHighestScoreIndex(ctx context.Context, subSearchResultData []*schemapb.SearchResultData, subSearchNqOffset [][]int64, cursors []int64, qi int64) (int, int64) { var ( subSearchIdx = -1 resultDataIdx int64 = -1 ) maxScore := minFloat32 for i := range cursors { if cursors[i] >= subSearchResultData[i].Topks[qi] { continue } sIdx := subSearchNqOffset[i][qi] + cursors[i] sScore := subSearchResultData[i].Scores[sIdx] // Choose the larger score idx or the smaller pk idx with the same score if subSearchIdx == -1 || sScore > maxScore { subSearchIdx = i resultDataIdx = sIdx maxScore = sScore } else if sScore == maxScore { if subSearchIdx == -1 { // A bad case happens where Knowhere returns distance/score == +/-maxFloat32 // by mistake. log.Ctx(ctx).Error("a bad score is returned, something is wrong here!", zap.Float32("score", sScore)) } else if typeutil.ComparePK( typeutil.GetPK(subSearchResultData[i].GetIds(), sIdx), typeutil.GetPK(subSearchResultData[subSearchIdx].GetIds(), resultDataIdx)) { subSearchIdx = i resultDataIdx = sIdx maxScore = sScore } } } return subSearchIdx, resultDataIdx }