mirror of
https://gitee.com/milvus-io/milvus.git
synced 2025-12-06 09:08:43 +08:00
related: #45418 Signed-off-by: MrPresent-Han <chun.han@gmail.com> Co-authored-by: MrPresent-Han <chun.han@gmail.com>
606 lines
21 KiB
Go
606 lines
21 KiB
Go
package proxy
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
|
|
"github.com/cockroachdb/errors"
|
|
"go.opentelemetry.io/otel"
|
|
"go.uber.org/zap"
|
|
"google.golang.org/protobuf/proto"
|
|
|
|
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
|
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
|
"github.com/milvus-io/milvus/internal/util/reduce"
|
|
"github.com/milvus-io/milvus/pkg/v2/log"
|
|
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
|
|
"github.com/milvus-io/milvus/pkg/v2/proto/planpb"
|
|
"github.com/milvus-io/milvus/pkg/v2/util/merr"
|
|
"github.com/milvus-io/milvus/pkg/v2/util/metric"
|
|
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
|
|
"github.com/milvus-io/milvus/pkg/v2/util/timerecord"
|
|
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
|
|
)
|
|
|
|
func reduceSearchResult(ctx context.Context, subSearchResultData []*schemapb.SearchResultData, reduceInfo *reduce.ResultInfo) (*milvuspb.SearchResults, error) {
|
|
if reduceInfo.GetGroupByFieldId() > 0 {
|
|
if reduceInfo.GetIsAdvance() {
|
|
// for hybrid search group by, we cannot reduce result for results from one single search path,
|
|
// because the final score has not been accumulated, also, offset cannot be applied
|
|
return reduceAdvanceGroupBy(ctx,
|
|
subSearchResultData, reduceInfo.GetNq(), reduceInfo.GetTopK(), reduceInfo.GetPkType(), reduceInfo.GetMetricType())
|
|
}
|
|
return reduceSearchResultDataWithGroupBy(ctx,
|
|
subSearchResultData,
|
|
reduceInfo.GetNq(),
|
|
reduceInfo.GetTopK(),
|
|
reduceInfo.GetMetricType(),
|
|
reduceInfo.GetPkType(),
|
|
reduceInfo.GetOffset(),
|
|
reduceInfo.GetGroupSize())
|
|
}
|
|
return reduceSearchResultDataNoGroupBy(ctx,
|
|
subSearchResultData,
|
|
reduceInfo.GetNq(),
|
|
reduceInfo.GetTopK(),
|
|
reduceInfo.GetMetricType(),
|
|
reduceInfo.GetPkType(),
|
|
reduceInfo.GetOffset())
|
|
}
|
|
|
|
func checkResultDatas(ctx context.Context, subSearchResultData []*schemapb.SearchResultData,
|
|
nq int64, topK int64,
|
|
) (int64, int, error) {
|
|
var allSearchCount int64
|
|
var hitNum int
|
|
for i, sData := range subSearchResultData {
|
|
pkLength := typeutil.GetSizeOfIDs(sData.GetIds())
|
|
log.Ctx(ctx).Debug("subSearchResultData",
|
|
zap.Int("result No.", i),
|
|
zap.Int64("nq", sData.NumQueries),
|
|
zap.Int64("topk", sData.TopK),
|
|
zap.Int("length of pks", pkLength),
|
|
zap.Int("length of FieldsData", len(sData.FieldsData)))
|
|
allSearchCount += sData.GetAllSearchCount()
|
|
hitNum += pkLength
|
|
if err := checkSearchResultData(sData, nq, topK, pkLength); err != nil {
|
|
log.Ctx(ctx).Warn("invalid search results", zap.Error(err))
|
|
return allSearchCount, hitNum, err
|
|
}
|
|
}
|
|
return allSearchCount, hitNum, nil
|
|
}
|
|
|
|
func reduceAdvanceGroupBy(ctx context.Context, subSearchResultData []*schemapb.SearchResultData,
|
|
nq int64, topK int64, pkType schemapb.DataType, metricType string,
|
|
) (*milvuspb.SearchResults, error) {
|
|
log.Ctx(ctx).Debug("reduceAdvanceGroupBY", zap.Int("len(subSearchResultData)", len(subSearchResultData)), zap.Int64("nq", nq))
|
|
// for advance group by, offset is not applied, so just return when there's only one channel
|
|
if len(subSearchResultData) == 1 {
|
|
return &milvuspb.SearchResults{
|
|
Status: merr.Success(),
|
|
Results: subSearchResultData[0],
|
|
}, nil
|
|
}
|
|
|
|
ret := &milvuspb.SearchResults{
|
|
Status: merr.Success(),
|
|
Results: &schemapb.SearchResultData{
|
|
NumQueries: nq,
|
|
TopK: topK,
|
|
Scores: []float32{},
|
|
Ids: &schemapb.IDs{},
|
|
Topks: []int64{},
|
|
},
|
|
}
|
|
|
|
var limit int64
|
|
if allSearchCount, hitNum, err := checkResultDatas(ctx, subSearchResultData, nq, topK); err != nil {
|
|
log.Ctx(ctx).Warn("invalid search results", zap.Error(err))
|
|
return ret, err
|
|
} else {
|
|
ret.GetResults().AllSearchCount = allSearchCount
|
|
limit = int64(hitNum)
|
|
// Find the first non-empty FieldsData as template
|
|
for _, result := range subSearchResultData {
|
|
if len(result.GetFieldsData()) > 0 {
|
|
ret.GetResults().FieldsData = typeutil.PrepareResultFieldData(result.GetFieldsData(), limit)
|
|
break
|
|
}
|
|
}
|
|
}
|
|
|
|
if err := setupIdListForSearchResult(ret, pkType, limit); err != nil {
|
|
return ret, nil
|
|
}
|
|
|
|
var (
|
|
subSearchNum = len(subSearchResultData)
|
|
// for results of each subSearchResultData, storing the start offset of each query of nq queries
|
|
subSearchNqOffset = make([][]int64, subSearchNum)
|
|
)
|
|
for i := 0; i < subSearchNum; i++ {
|
|
subSearchNqOffset[i] = make([]int64, subSearchResultData[i].GetNumQueries())
|
|
for j := int64(1); j < nq; j++ {
|
|
subSearchNqOffset[i][j] = subSearchNqOffset[i][j-1] + subSearchResultData[i].Topks[j-1]
|
|
}
|
|
}
|
|
|
|
gpFieldBuilder, err := typeutil.NewFieldDataBuilder(subSearchResultData[0].GetGroupByFieldValue().GetType(), true, int(limit))
|
|
if err != nil {
|
|
return ret, merr.WrapErrServiceInternal("failed to construct group by field data builder, this is abnormal as segcore should always set up a group by field, no matter data status, check code on qn", err.Error())
|
|
}
|
|
// reducing nq * topk results
|
|
for nqIdx := int64(0); nqIdx < nq; nqIdx++ {
|
|
dataCount := int64(0)
|
|
for subIdx := 0; subIdx < subSearchNum; subIdx += 1 {
|
|
subData := subSearchResultData[subIdx]
|
|
subPks := subData.GetIds()
|
|
subScores := subData.GetScores()
|
|
subGroupByVals := subData.GetGroupByFieldValue()
|
|
|
|
nqTopK := subData.Topks[nqIdx]
|
|
groupByValIterator := typeutil.GetDataIterator(subGroupByVals)
|
|
|
|
for i := int64(0); i < nqTopK; i++ {
|
|
innerIdx := subSearchNqOffset[subIdx][nqIdx] + i
|
|
pk := typeutil.GetPK(subPks, innerIdx)
|
|
score := subScores[innerIdx]
|
|
groupByVal := groupByValIterator(int(innerIdx))
|
|
gpFieldBuilder.Add(groupByVal)
|
|
typeutil.AppendPKs(ret.Results.Ids, pk)
|
|
ret.Results.Scores = append(ret.Results.Scores, score)
|
|
dataCount += 1
|
|
}
|
|
}
|
|
ret.Results.Topks = append(ret.Results.Topks, dataCount)
|
|
}
|
|
|
|
ret.Results.GroupByFieldValue = gpFieldBuilder.Build()
|
|
ret.Results.TopK = topK // realTopK is the topK of the nq-th query
|
|
if !metric.PositivelyRelated(metricType) {
|
|
for k := range ret.Results.Scores {
|
|
ret.Results.Scores[k] *= -1
|
|
}
|
|
}
|
|
return ret, nil
|
|
}
|
|
|
|
// MilvusPKType is an untyped holder for a primary-key value. Within this file
// the supported pk types are Int64 and VarChar (see setupIdListForSearchResult),
// so the dynamic value is expected to be an int64 or a string.
type MilvusPKType any
|
|
|
|
type groupReduceInfo struct {
|
|
subSearchIdx int
|
|
resultIdx int64
|
|
score float32
|
|
id MilvusPKType
|
|
}
|
|
|
|
func reduceSearchResultDataWithGroupBy(ctx context.Context, subSearchResultData []*schemapb.SearchResultData,
|
|
nq int64, topk int64, metricType string,
|
|
pkType schemapb.DataType,
|
|
offset int64,
|
|
groupSize int64,
|
|
) (*milvuspb.SearchResults, error) {
|
|
tr := timerecord.NewTimeRecorder("reduceSearchResultData")
|
|
defer func() {
|
|
tr.CtxElapse(ctx, "done")
|
|
}()
|
|
|
|
limit := topk - offset
|
|
log.Ctx(ctx).Debug("reduceSearchResultData",
|
|
zap.Int("len(subSearchResultData)", len(subSearchResultData)),
|
|
zap.Int64("nq", nq),
|
|
zap.Int64("offset", offset),
|
|
zap.Int64("limit", limit),
|
|
zap.String("metricType", metricType))
|
|
|
|
ret := &milvuspb.SearchResults{
|
|
Status: merr.Success(),
|
|
Results: &schemapb.SearchResultData{
|
|
NumQueries: nq,
|
|
TopK: topk,
|
|
FieldsData: []*schemapb.FieldData{},
|
|
Scores: []float32{},
|
|
Ids: &schemapb.IDs{},
|
|
Topks: []int64{},
|
|
},
|
|
}
|
|
groupBound := groupSize * limit
|
|
if err := setupIdListForSearchResult(ret, pkType, groupBound); err != nil {
|
|
return ret, err
|
|
}
|
|
|
|
if allSearchCount, _, err := checkResultDatas(ctx, subSearchResultData, nq, topk); err != nil {
|
|
log.Ctx(ctx).Warn("invalid search results", zap.Error(err))
|
|
return ret, err
|
|
} else {
|
|
ret.GetResults().AllSearchCount = allSearchCount
|
|
}
|
|
|
|
// Find the first non-empty FieldsData as template
|
|
for _, result := range subSearchResultData {
|
|
if len(result.GetFieldsData()) > 0 {
|
|
ret.GetResults().FieldsData = typeutil.PrepareResultFieldData(result.GetFieldsData(), limit)
|
|
break
|
|
}
|
|
}
|
|
|
|
var (
|
|
subSearchNum = len(subSearchResultData)
|
|
// for results of each subSearchResultData, storing the start offset of each query of nq queries
|
|
subSearchNqOffset = make([][]int64, subSearchNum)
|
|
totalResCount int64 = 0
|
|
subSearchGroupByValIterator = make([]func(int) any, subSearchNum)
|
|
)
|
|
for i := 0; i < subSearchNum; i++ {
|
|
subSearchNqOffset[i] = make([]int64, subSearchResultData[i].GetNumQueries())
|
|
for j := int64(1); j < nq; j++ {
|
|
subSearchNqOffset[i][j] = subSearchNqOffset[i][j-1] + subSearchResultData[i].Topks[j-1]
|
|
}
|
|
totalResCount += subSearchNqOffset[i][nq-1]
|
|
subSearchGroupByValIterator[i] = typeutil.GetDataIterator(subSearchResultData[i].GetGroupByFieldValue())
|
|
}
|
|
|
|
gpFieldBuilder, err := typeutil.NewFieldDataBuilder(subSearchResultData[0].GetGroupByFieldValue().GetType(), true, int(limit))
|
|
if err != nil {
|
|
return ret, merr.WrapErrServiceInternal("failed to construct group by field data builder, this is abnormal as segcore should always set up a group by field, no matter data status, check code on qn", err.Error())
|
|
}
|
|
|
|
var realTopK int64 = -1
|
|
var retSize int64
|
|
|
|
maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64()
|
|
// reducing nq * topk results
|
|
for i := int64(0); i < nq; i++ {
|
|
var (
|
|
// cursor of current data of each subSearch for merging the j-th data of TopK.
|
|
// sum(cursors) == j
|
|
cursors = make([]int64, subSearchNum)
|
|
|
|
j int64
|
|
groupByValMap = make(map[interface{}][]*groupReduceInfo)
|
|
skipOffsetMap = make(map[interface{}]bool)
|
|
groupByValList = make([]interface{}, limit)
|
|
groupByValIdx = 0
|
|
)
|
|
|
|
for j = 0; j < groupBound; {
|
|
subSearchIdx, resultDataIdx := selectHighestScoreIndex(ctx, subSearchResultData, subSearchNqOffset, cursors, i)
|
|
if subSearchIdx == -1 {
|
|
break
|
|
}
|
|
subSearchRes := subSearchResultData[subSearchIdx]
|
|
|
|
id := typeutil.GetPK(subSearchRes.GetIds(), resultDataIdx)
|
|
score := subSearchRes.GetScores()[resultDataIdx]
|
|
groupByVal := subSearchGroupByValIterator[subSearchIdx](int(resultDataIdx))
|
|
|
|
if int64(len(skipOffsetMap)) < offset || skipOffsetMap[groupByVal] {
|
|
skipOffsetMap[groupByVal] = true
|
|
// the first offset's group will be ignored
|
|
} else if len(groupByValMap[groupByVal]) == 0 && int64(len(groupByValMap)) >= limit {
|
|
// skip when groupbyMap has been full and found new groupByVal
|
|
} else if int64(len(groupByValMap[groupByVal])) >= groupSize {
|
|
// skip when target group has been full
|
|
} else {
|
|
if len(groupByValMap[groupByVal]) == 0 {
|
|
groupByValList[groupByValIdx] = groupByVal
|
|
groupByValIdx++
|
|
}
|
|
groupByValMap[groupByVal] = append(groupByValMap[groupByVal], &groupReduceInfo{
|
|
subSearchIdx: subSearchIdx,
|
|
resultIdx: resultDataIdx, id: id, score: score,
|
|
})
|
|
j++
|
|
}
|
|
|
|
cursors[subSearchIdx]++
|
|
}
|
|
|
|
// assemble all eligible values in group
|
|
// values in groupByValList is sorted by the highest score in each group
|
|
for _, groupVal := range groupByValList {
|
|
groupEntities := groupByValMap[groupVal]
|
|
for _, groupEntity := range groupEntities {
|
|
subResData := subSearchResultData[groupEntity.subSearchIdx]
|
|
if len(ret.Results.FieldsData) > 0 {
|
|
retSize += typeutil.AppendFieldData(ret.Results.FieldsData, subResData.FieldsData, groupEntity.resultIdx)
|
|
}
|
|
typeutil.AppendPKs(ret.Results.Ids, groupEntity.id)
|
|
ret.Results.Scores = append(ret.Results.Scores, groupEntity.score)
|
|
gpFieldBuilder.Add(groupVal)
|
|
}
|
|
}
|
|
|
|
if realTopK != -1 && realTopK != j {
|
|
log.Ctx(ctx).Warn("Proxy Reduce Search Result", zap.Error(errors.New("the length (topk) between all result of query is different")))
|
|
}
|
|
realTopK = j
|
|
ret.Results.Topks = append(ret.Results.Topks, realTopK)
|
|
ret.Results.GroupByFieldValue = gpFieldBuilder.Build()
|
|
|
|
// limit search result to avoid oom
|
|
if retSize > maxOutputSize {
|
|
return nil, fmt.Errorf("search results exceed the maxOutputSize Limit %d", maxOutputSize)
|
|
}
|
|
}
|
|
ret.Results.TopK = realTopK // realTopK is the topK of the nq-th query
|
|
if !metric.PositivelyRelated(metricType) {
|
|
for k := range ret.Results.Scores {
|
|
ret.Results.Scores[k] *= -1
|
|
}
|
|
}
|
|
return ret, nil
|
|
}
|
|
|
|
func reduceSearchResultDataNoGroupBy(ctx context.Context, subSearchResultData []*schemapb.SearchResultData, nq int64, topk int64, metricType string, pkType schemapb.DataType, offset int64) (*milvuspb.SearchResults, error) {
|
|
tr := timerecord.NewTimeRecorder("reduceSearchResultData")
|
|
defer func() {
|
|
tr.CtxElapse(ctx, "done")
|
|
}()
|
|
|
|
limit := topk - offset
|
|
log.Ctx(ctx).Debug("reduceSearchResultData",
|
|
zap.Int("len(subSearchResultData)", len(subSearchResultData)),
|
|
zap.Int64("nq", nq),
|
|
zap.Int64("offset", offset),
|
|
zap.Int64("limit", limit),
|
|
zap.String("metricType", metricType))
|
|
|
|
ret := &milvuspb.SearchResults{
|
|
Status: merr.Success(),
|
|
Results: &schemapb.SearchResultData{
|
|
NumQueries: nq,
|
|
TopK: topk,
|
|
FieldsData: []*schemapb.FieldData{},
|
|
Scores: []float32{},
|
|
Ids: &schemapb.IDs{},
|
|
Topks: []int64{},
|
|
},
|
|
}
|
|
|
|
if err := setupIdListForSearchResult(ret, pkType, limit); err != nil {
|
|
return ret, nil
|
|
}
|
|
|
|
if allSearchCount, _, err := checkResultDatas(ctx, subSearchResultData, nq, topk); err != nil {
|
|
log.Ctx(ctx).Warn("invalid search results", zap.Error(err))
|
|
return ret, err
|
|
} else {
|
|
ret.GetResults().AllSearchCount = allSearchCount
|
|
}
|
|
|
|
// Find the first non-empty FieldsData as template
|
|
for _, result := range subSearchResultData {
|
|
if len(result.GetFieldsData()) > 0 {
|
|
ret.GetResults().FieldsData = typeutil.PrepareResultFieldData(result.GetFieldsData(), limit)
|
|
break
|
|
}
|
|
}
|
|
|
|
subSearchNum := len(subSearchResultData)
|
|
if subSearchNum == 1 && offset == 0 {
|
|
// sorting is not needed if there is only one shard and no offset, assigning the result directly.
|
|
// we still need to adjust the scores later.
|
|
ret.Results = subSearchResultData[0]
|
|
// realTopK is the topK of the nq-th query, it is used in proxy but not handled by delegator.
|
|
topks := subSearchResultData[0].Topks
|
|
if len(topks) > 0 {
|
|
ret.Results.TopK = topks[len(topks)-1]
|
|
}
|
|
} else {
|
|
var realTopK int64 = -1
|
|
var retSize int64
|
|
|
|
// for results of each subSearchResultData, storing the start offset of each query of nq queries
|
|
subSearchNqOffset := make([][]int64, subSearchNum)
|
|
for i := 0; i < subSearchNum; i++ {
|
|
subSearchNqOffset[i] = make([]int64, subSearchResultData[i].GetNumQueries())
|
|
for j := int64(1); j < nq; j++ {
|
|
subSearchNqOffset[i][j] = subSearchNqOffset[i][j-1] + subSearchResultData[i].Topks[j-1]
|
|
}
|
|
}
|
|
maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64()
|
|
// reducing nq * topk results
|
|
for i := int64(0); i < nq; i++ {
|
|
var (
|
|
// cursor of current data of each subSearch for merging the j-th data of TopK.
|
|
// sum(cursors) == j
|
|
cursors = make([]int64, subSearchNum)
|
|
j int64
|
|
)
|
|
|
|
// skip offset results
|
|
for k := int64(0); k < offset; k++ {
|
|
subSearchIdx, _ := selectHighestScoreIndex(ctx, subSearchResultData, subSearchNqOffset, cursors, i)
|
|
if subSearchIdx == -1 {
|
|
break
|
|
}
|
|
|
|
cursors[subSearchIdx]++
|
|
}
|
|
|
|
// keep limit results
|
|
for j = 0; j < limit; j++ {
|
|
// From all the sub-query result sets of the i-th query vector,
|
|
// find the sub-query result set index of the score j-th data,
|
|
// and the index of the data in schemapb.SearchResultData
|
|
subSearchIdx, resultDataIdx := selectHighestScoreIndex(ctx, subSearchResultData, subSearchNqOffset, cursors, i)
|
|
if subSearchIdx == -1 {
|
|
break
|
|
}
|
|
score := subSearchResultData[subSearchIdx].Scores[resultDataIdx]
|
|
|
|
if len(ret.Results.FieldsData) > 0 {
|
|
retSize += typeutil.AppendFieldData(ret.Results.FieldsData, subSearchResultData[subSearchIdx].FieldsData, resultDataIdx)
|
|
}
|
|
typeutil.CopyPk(ret.Results.Ids, subSearchResultData[subSearchIdx].GetIds(), int(resultDataIdx))
|
|
ret.Results.Scores = append(ret.Results.Scores, score)
|
|
cursors[subSearchIdx]++
|
|
}
|
|
if realTopK != -1 && realTopK != j {
|
|
log.Ctx(ctx).Warn("Proxy Reduce Search Result", zap.Error(errors.New("the length (topk) between all result of query is different")))
|
|
// return nil, errors.New("the length (topk) between all result of query is different")
|
|
}
|
|
realTopK = j
|
|
ret.Results.Topks = append(ret.Results.Topks, realTopK)
|
|
|
|
// limit search result to avoid oom
|
|
if retSize > maxOutputSize {
|
|
return nil, fmt.Errorf("search results exceed the maxOutputSize Limit %d", maxOutputSize)
|
|
}
|
|
}
|
|
ret.Results.TopK = realTopK // realTopK is the topK of the nq-th query
|
|
}
|
|
|
|
if !metric.PositivelyRelated(metricType) {
|
|
for k := range ret.Results.Scores {
|
|
ret.Results.Scores[k] *= -1
|
|
}
|
|
}
|
|
return ret, nil
|
|
}
|
|
|
|
// compareKey reports whether keyI orders strictly before keyJ.
//
// Only int64 and string keys are comparable, and both keys must hold the same
// dynamic type. Any other combination (unsupported type, mismatched types, or
// nil) yields false rather than panicking on a failed type assertion.
func compareKey(keyI interface{}, keyJ interface{}) bool {
	switch left := keyI.(type) {
	case int64:
		right, ok := keyJ.(int64)
		return ok && left < right
	case string:
		right, ok := keyJ.(string)
		return ok && left < right
	}
	return false
}
|
|
|
|
func setupIdListForSearchResult(searchResult *milvuspb.SearchResults, pkType schemapb.DataType, capacity int64) error {
|
|
switch pkType {
|
|
case schemapb.DataType_Int64:
|
|
searchResult.GetResults().Ids.IdField = &schemapb.IDs_IntId{
|
|
IntId: &schemapb.LongArray{
|
|
Data: make([]int64, 0, capacity),
|
|
},
|
|
}
|
|
case schemapb.DataType_VarChar:
|
|
searchResult.GetResults().Ids.IdField = &schemapb.IDs_StrId{
|
|
StrId: &schemapb.StringArray{
|
|
Data: make([]string, 0, capacity),
|
|
},
|
|
}
|
|
default:
|
|
return errors.New("unsupported pk type")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func fillInEmptyResult(numQueries int64) *milvuspb.SearchResults {
|
|
return &milvuspb.SearchResults{
|
|
Status: merr.Success("search result is empty"),
|
|
Results: &schemapb.SearchResultData{
|
|
NumQueries: numQueries,
|
|
Topks: make([]int64, numQueries),
|
|
},
|
|
}
|
|
}
|
|
|
|
func reduceResults(ctx context.Context, toReduceResults []*internalpb.SearchResults, nq, topK, offset int64, metricType string, pkType schemapb.DataType, queryInfo *planpb.QueryInfo, isAdvance bool, collectionID int64, partitionIDs []int64) (*milvuspb.SearchResults, error) {
|
|
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "reduceResults")
|
|
defer sp.End()
|
|
|
|
log := log.Ctx(ctx)
|
|
// Decode all search results
|
|
validSearchResults, err := decodeSearchResults(ctx, toReduceResults)
|
|
if err != nil {
|
|
log.Warn("failed to decode search results", zap.Error(err))
|
|
return nil, err
|
|
}
|
|
|
|
if len(validSearchResults) <= 0 {
|
|
log.Debug("reduced search results is empty, fill in empty result")
|
|
return fillInEmptyResult(nq), nil
|
|
}
|
|
|
|
// Reduce all search results
|
|
log.Debug("proxy search post execute reduce",
|
|
zap.Int64("collection", collectionID),
|
|
zap.Int64s("partitionIDs", partitionIDs),
|
|
zap.Int("number of valid search results", len(validSearchResults)))
|
|
var result *milvuspb.SearchResults
|
|
result, err = reduceSearchResult(ctx, validSearchResults, reduce.NewReduceSearchResultInfo(nq, topK).WithMetricType(metricType).WithPkType(pkType).
|
|
WithOffset(offset).WithGroupByField(queryInfo.GetGroupByFieldId()).WithGroupSize(queryInfo.GetGroupSize()).WithAdvance(isAdvance))
|
|
if err != nil {
|
|
log.Warn("failed to reduce search results", zap.Error(err))
|
|
return nil, err
|
|
}
|
|
return result, nil
|
|
}
|
|
|
|
func decodeSearchResults(ctx context.Context, searchResults []*internalpb.SearchResults) ([]*schemapb.SearchResultData, error) {
|
|
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "decodeSearchResults")
|
|
defer sp.End()
|
|
tr := timerecord.NewTimeRecorder("decodeSearchResults")
|
|
results := make([]*schemapb.SearchResultData, 0)
|
|
for _, partialSearchResult := range searchResults {
|
|
if partialSearchResult.SlicedBlob == nil {
|
|
continue
|
|
}
|
|
|
|
var partialResultData schemapb.SearchResultData
|
|
err := proto.Unmarshal(partialSearchResult.SlicedBlob, &partialResultData)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
results = append(results, &partialResultData)
|
|
}
|
|
tr.CtxElapse(ctx, "decodeSearchResults done")
|
|
return results, nil
|
|
}
|
|
|
|
func checkSearchResultData(data *schemapb.SearchResultData, nq int64, topk int64, pkHitNum int) error {
|
|
if data.NumQueries != nq {
|
|
return fmt.Errorf("search result's nq(%d) mis-match with %d", data.NumQueries, nq)
|
|
}
|
|
if data.TopK != topk {
|
|
return fmt.Errorf("search result's topk(%d) mis-match with %d", data.TopK, topk)
|
|
}
|
|
|
|
if len(data.Scores) != pkHitNum {
|
|
return fmt.Errorf("search result's score length invalid, score length=%d, expectedLength=%d",
|
|
len(data.Scores), pkHitNum)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func selectHighestScoreIndex(ctx context.Context, subSearchResultData []*schemapb.SearchResultData, subSearchNqOffset [][]int64, cursors []int64, qi int64) (int, int64) {
|
|
var (
|
|
subSearchIdx = -1
|
|
resultDataIdx int64 = -1
|
|
)
|
|
maxScore := minFloat32
|
|
for i := range cursors {
|
|
if cursors[i] >= subSearchResultData[i].Topks[qi] {
|
|
continue
|
|
}
|
|
sIdx := subSearchNqOffset[i][qi] + cursors[i]
|
|
sScore := subSearchResultData[i].Scores[sIdx]
|
|
|
|
// Choose the larger score idx or the smaller pk idx with the same score
|
|
if subSearchIdx == -1 || sScore > maxScore {
|
|
subSearchIdx = i
|
|
resultDataIdx = sIdx
|
|
maxScore = sScore
|
|
} else if sScore == maxScore {
|
|
if subSearchIdx == -1 {
|
|
// A bad case happens where Knowhere returns distance/score == +/-maxFloat32
|
|
// by mistake.
|
|
log.Ctx(ctx).Error("a bad score is returned, something is wrong here!", zap.Float32("score", sScore))
|
|
} else if typeutil.ComparePK(
|
|
typeutil.GetPK(subSearchResultData[i].GetIds(), sIdx),
|
|
typeutil.GetPK(subSearchResultData[subSearchIdx].GetIds(), resultDataIdx)) {
|
|
subSearchIdx = i
|
|
resultDataIdx = sIdx
|
|
maxScore = sScore
|
|
}
|
|
}
|
|
}
|
|
return subSearchIdx, resultDataIdx
|
|
}
|