milvus/internal/proxy/search_reduce_util.go
marcelo-cjl 3b599441fd
feat: Add nullable vector support for proxy and querynode (#46305)
related: #45993 

This commit extends nullable vector support to the proxy layer and
querynode, and adds comprehensive validation, search-reduce, and field
data handling for nullable vectors with sparse storage.

Proxy layer changes:
- Update validate_util.go checkAligned() with a getExpectedVectorRows()
  helper to validate nullable vector field alignment using the valid data count
- Update checkFloatVectorFieldData/checkSparseFloatVectorFieldData for
  nullable vector validation with proper row count expectations
- Add FieldDataIdxComputer in typeutil/schema.go for logical-to-physical
  index translation during search reduce operations
- Update search_reduce_util.go reduceSearchResultData to use idxComputers
  for correct field data indexing with nullable vectors
- Update task.go, task_query.go, task_upsert.go for nullable vector handling
- Update msg_pack.go with nullable vector field data processing

QueryNode layer changes:
- Update segments/result.go for nullable vector result handling
- Update segments/search_reduce.go with nullable vector offset translation

Storage and index changes:
- Update data_codec.go and utils.go for nullable vector serialization
- Update indexcgowrapper/dataset.go and index.go for nullable vector indexing

Utility changes:
- Add FieldDataIdxComputer struct with a Compute() method for efficient
  logical-to-physical index mapping across multiple field data (see the
  sketch after this list)
- Update EstimateEntitySize() and AppendFieldData() with a fieldIdxs parameter
- Update funcutil.go with nullable vector support functions
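
The PR describes FieldDataIdxComputer only at this level; as a rough
illustration of the idea, here is a minimal, hypothetical sketch of a
logical-to-physical index computer, assuming each nullable field carries a
valid-data bitmap and physically stores only its valid rows. The names
idxComputer, newIdxComputer, and compute are illustrative, not the actual
Milvus API.

```go
// Hypothetical sketch (not the Milvus implementation): translate a logical
// row offset into the physical slot each field actually stores, assuming
// nullable fields keep a valid-data bitmap and store only valid rows.
package main

import "fmt"

// idxComputer precomputes, for one field, the physical slot of every logical
// row: row i lands at slot "number of valid rows before i", or -1 when null.
type idxComputer struct {
	physIdx []int64
}

func newIdxComputer(valid []bool) *idxComputer {
	m := make([]int64, len(valid))
	var stored int64
	for i, ok := range valid {
		if ok {
			m[i] = stored
			stored++
		} else {
			m[i] = -1 // null row: nothing physically stored for this field
		}
	}
	return &idxComputer{physIdx: m}
}

// compute mirrors the role the PR gives FieldDataIdxComputer.Compute: given
// one logical result offset, return the per-field physical indices that
// AppendFieldData-style helpers should read from.
func compute(fields []*idxComputer, logical int64) []int64 {
	out := make([]int64, len(fields))
	for i, f := range fields {
		out[i] = f.physIdx[logical]
	}
	return out
}

func main() {
	// Field 0 is non-nullable (all rows valid); field 1 is a nullable vector
	// where rows 1 and 3 are null and therefore not physically stored.
	f0 := newIdxComputer([]bool{true, true, true, true, true})
	f1 := newIdxComputer([]bool{true, false, true, false, true})
	for logical := int64(0); logical < 5; logical++ {
		fmt.Println(logical, compute([]*idxComputer{f0, f1}, logical))
	}
	// Output:
	// 0 [0 0]
	// 1 [1 -1]
	// 2 [2 1]
	// 3 [3 -1]
	// 4 [4 2]
}
```

With this layout, the same logical hit maps to a different physical slot per
field, which is presumably why the PR threads a per-field fieldIdxs slice
through AppendFieldData and EstimateEntitySize.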

## Summary by CodeRabbit

* **New Features**
  * Full support for nullable vector fields (float, binary, float16,
    bfloat16, int8, sparse) across ingest, storage, indexing, search, and
    retrieval; logical↔physical offset mapping preserves row semantics.
  * Client: compaction control and compaction-state APIs.

* **Bug Fixes**
  * Improved validation for adding vector fields (nullable + dimension
    checks) and corrected search/query behavior for nullable vectors.

* **Chores**
  * Persisted validity maps with indexes and on-disk formats.

* **Tests**
  * Extensive new and updated end-to-end nullable-vector tests.


---------

Signed-off-by: marcelo-cjl <marcelo.chen@zilliz.com>
2025-12-24 10:13:19 +08:00

package proxy

import (
	"context"
	"fmt"

	"github.com/cockroachdb/errors"
	"go.opentelemetry.io/otel"
	"go.uber.org/zap"
	"google.golang.org/protobuf/proto"

	"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
	"github.com/milvus-io/milvus/internal/util/reduce"
	"github.com/milvus-io/milvus/pkg/v2/log"
	"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
	"github.com/milvus-io/milvus/pkg/v2/proto/planpb"
	"github.com/milvus-io/milvus/pkg/v2/util/merr"
	"github.com/milvus-io/milvus/pkg/v2/util/metric"
	"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
	"github.com/milvus-io/milvus/pkg/v2/util/timerecord"
	"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
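
// reduceSearchResult routes to the appropriate reduce path: hybrid (advance)
// group-by results are concatenated without re-ranking, regular group-by
// results are merged group-wise, and everything else goes through the plain
// top-k merge.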
func reduceSearchResult(ctx context.Context, subSearchResultData []*schemapb.SearchResultData, reduceInfo *reduce.ResultInfo) (*milvuspb.SearchResults, error) {
	if reduceInfo.GetGroupByFieldId() > 0 {
		if reduceInfo.GetIsAdvance() {
			// for hybrid search group by, we cannot reduce results from one single search path,
			// because the final score has not been accumulated; also, offset cannot be applied
			return reduceAdvanceGroupBy(ctx,
				subSearchResultData, reduceInfo.GetNq(), reduceInfo.GetTopK(), reduceInfo.GetPkType(), reduceInfo.GetMetricType())
		}
		return reduceSearchResultDataWithGroupBy(ctx,
			subSearchResultData,
			reduceInfo.GetNq(),
			reduceInfo.GetTopK(),
			reduceInfo.GetMetricType(),
			reduceInfo.GetPkType(),
			reduceInfo.GetOffset(),
			reduceInfo.GetGroupSize())
	}
	return reduceSearchResultDataNoGroupBy(ctx,
		subSearchResultData,
		reduceInfo.GetNq(),
		reduceInfo.GetTopK(),
		reduceInfo.GetMetricType(),
		reduceInfo.GetPkType(),
		reduceInfo.GetOffset())
}
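
// checkResultDatas validates every sub-result against the request's nq and
// topK, returning the accumulated AllSearchCount and the total number of hits.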
func checkResultDatas(ctx context.Context, subSearchResultData []*schemapb.SearchResultData,
	nq int64, topK int64,
) (int64, int, error) {
	var allSearchCount int64
	var hitNum int
	for i, sData := range subSearchResultData {
		pkLength := typeutil.GetSizeOfIDs(sData.GetIds())
		log.Ctx(ctx).Debug("subSearchResultData",
			zap.Int("result No.", i),
			zap.Int64("nq", sData.NumQueries),
			zap.Int64("topk", sData.TopK),
			zap.Int("length of pks", pkLength),
			zap.Int("length of FieldsData", len(sData.FieldsData)))
		allSearchCount += sData.GetAllSearchCount()
		hitNum += pkLength
		if err := checkSearchResultData(sData, nq, topK, pkLength); err != nil {
			log.Ctx(ctx).Warn("invalid search results", zap.Error(err))
			return allSearchCount, hitNum, err
		}
	}
	return allSearchCount, hitNum, nil
}
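
// reduceAdvanceGroupBy merges hybrid-search group-by results without
// re-ranking and without applying offset: the final scores have not been
// accumulated across search paths yet, so every hit is kept for the upstream
// hybrid-search ranker.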
func reduceAdvanceGroupBy(ctx context.Context, subSearchResultData []*schemapb.SearchResultData,
	nq int64, topK int64, pkType schemapb.DataType, metricType string,
) (*milvuspb.SearchResults, error) {
	log.Ctx(ctx).Debug("reduceAdvanceGroupBy", zap.Int("len(subSearchResultData)", len(subSearchResultData)), zap.Int64("nq", nq))
	// for advance group by, offset is not applied, so just return when there's only one channel
	if len(subSearchResultData) == 1 {
		return &milvuspb.SearchResults{
			Status:  merr.Success(),
			Results: subSearchResultData[0],
		}, nil
	}
	ret := &milvuspb.SearchResults{
		Status: merr.Success(),
		Results: &schemapb.SearchResultData{
			NumQueries: nq,
			TopK:       topK,
			Scores:     []float32{},
			Ids:        &schemapb.IDs{},
			Topks:      []int64{},
		},
	}
	var limit int64
	if allSearchCount, hitNum, err := checkResultDatas(ctx, subSearchResultData, nq, topK); err != nil {
		log.Ctx(ctx).Warn("invalid search results", zap.Error(err))
		return ret, err
	} else {
		ret.GetResults().AllSearchCount = allSearchCount
		limit = int64(hitNum)
		// Find the first non-empty FieldsData as template
		for _, result := range subSearchResultData {
			if len(result.GetFieldsData()) > 0 {
				ret.GetResults().FieldsData = typeutil.PrepareResultFieldData(result.GetFieldsData(), limit)
				break
			}
		}
	}
	if err := setupIdListForSearchResult(ret, pkType, limit); err != nil {
		return ret, err // propagate the error instead of silently dropping it
	}
	var (
		subSearchNum = len(subSearchResultData)
		// for the results of each subSearchResultData, store the start offset of each of the nq queries
		subSearchNqOffset = make([][]int64, subSearchNum)
	)
	for i := 0; i < subSearchNum; i++ {
		subSearchNqOffset[i] = make([]int64, subSearchResultData[i].GetNumQueries())
		for j := int64(1); j < nq; j++ {
			subSearchNqOffset[i][j] = subSearchNqOffset[i][j-1] + subSearchResultData[i].Topks[j-1]
		}
	}
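	// Example: a sub-result with Topks = [2, 3, 1] yields offsets [0, 2, 5];
	// query j's hits occupy positions [offset[j], offset[j]+Topks[j]) of the
	// flattened ids/scores arrays.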
	gpFieldBuilder, err := typeutil.NewFieldDataBuilder(subSearchResultData[0].GetGroupByFieldValue().GetType(), true, int(limit))
	if err != nil {
		return ret, merr.WrapErrServiceInternal("failed to construct group by field data builder, this is abnormal as segcore should always set up a group by field, no matter data status, check code on qn", err.Error())
	}
	// reducing nq * topk results
	for nqIdx := int64(0); nqIdx < nq; nqIdx++ {
		dataCount := int64(0)
		for subIdx := 0; subIdx < subSearchNum; subIdx += 1 {
			subData := subSearchResultData[subIdx]
			subPks := subData.GetIds()
			subScores := subData.GetScores()
			subGroupByVals := subData.GetGroupByFieldValue()
			nqTopK := subData.Topks[nqIdx]
			groupByValIterator := typeutil.GetDataIterator(subGroupByVals)
			for i := int64(0); i < nqTopK; i++ {
				innerIdx := subSearchNqOffset[subIdx][nqIdx] + i
				pk := typeutil.GetPK(subPks, innerIdx)
				score := subScores[innerIdx]
				groupByVal := groupByValIterator(int(innerIdx))
				gpFieldBuilder.Add(groupByVal)
				typeutil.AppendPKs(ret.Results.Ids, pk)
				ret.Results.Scores = append(ret.Results.Scores, score)
				// Handle ElementIndices if present
				if subData.ElementIndices != nil {
					if ret.Results.ElementIndices == nil {
						ret.Results.ElementIndices = &schemapb.LongArray{
							Data: make([]int64, 0, limit),
						}
					}
					elemIdx := subData.ElementIndices.GetData()[innerIdx]
					ret.Results.ElementIndices.Data = append(ret.Results.ElementIndices.Data, elemIdx)
				}
				dataCount += 1
			}
		}
		ret.Results.Topks = append(ret.Results.Topks, dataCount)
	}
	ret.Results.GroupByFieldValue = gpFieldBuilder.Build()
	ret.Results.TopK = topK // advance group-by keeps the requested topK; no re-ranking happened here
	if !metric.PositivelyRelated(metricType) {
		for k := range ret.Results.Scores {
			ret.Results.Scores[k] *= -1
		}
	}
	return ret, nil
}

type MilvusPKType interface{}

type groupReduceInfo struct {
	subSearchIdx int
	resultIdx    int64
	score        float32
	id           MilvusPKType
}

func reduceSearchResultDataWithGroupBy(ctx context.Context, subSearchResultData []*schemapb.SearchResultData,
	nq int64, topk int64, metricType string,
	pkType schemapb.DataType,
	offset int64,
	groupSize int64,
) (*milvuspb.SearchResults, error) {
	tr := timerecord.NewTimeRecorder("reduceSearchResultData")
	defer func() {
		tr.CtxElapse(ctx, "done")
	}()
	limit := topk - offset
	log.Ctx(ctx).Debug("reduceSearchResultData",
		zap.Int("len(subSearchResultData)", len(subSearchResultData)),
		zap.Int64("nq", nq),
		zap.Int64("offset", offset),
		zap.Int64("limit", limit),
		zap.String("metricType", metricType))
	ret := &milvuspb.SearchResults{
		Status: merr.Success(),
		Results: &schemapb.SearchResultData{
			NumQueries: nq,
			TopK:       topk,
			FieldsData: []*schemapb.FieldData{},
			Scores:     []float32{},
			Ids:        &schemapb.IDs{},
			Topks:      []int64{},
		},
	}
	groupBound := groupSize * limit
	if err := setupIdListForSearchResult(ret, pkType, groupBound); err != nil {
		return ret, err
	}
	if allSearchCount, _, err := checkResultDatas(ctx, subSearchResultData, nq, topk); err != nil {
		log.Ctx(ctx).Warn("invalid search results", zap.Error(err))
		return ret, err
	} else {
		ret.GetResults().AllSearchCount = allSearchCount
	}
	// Find the first non-empty FieldsData as template
	for _, result := range subSearchResultData {
		if len(result.GetFieldsData()) > 0 {
			ret.GetResults().FieldsData = typeutil.PrepareResultFieldData(result.GetFieldsData(), limit)
			break
		}
	}
	var (
		subSearchNum = len(subSearchResultData)
		// for the results of each subSearchResultData, store the start offset of each of the nq queries
		subSearchNqOffset               = make([][]int64, subSearchNum)
		totalResCount             int64 = 0
		subSearchGroupByValIterator     = make([]func(int) any, subSearchNum)
	)
	for i := 0; i < subSearchNum; i++ {
		subSearchNqOffset[i] = make([]int64, subSearchResultData[i].GetNumQueries())
		for j := int64(1); j < nq; j++ {
			subSearchNqOffset[i][j] = subSearchNqOffset[i][j-1] + subSearchResultData[i].Topks[j-1]
		}
		totalResCount += subSearchNqOffset[i][nq-1]
		subSearchGroupByValIterator[i] = typeutil.GetDataIterator(subSearchResultData[i].GetGroupByFieldValue())
	}
	gpFieldBuilder, err := typeutil.NewFieldDataBuilder(subSearchResultData[0].GetGroupByFieldValue().GetType(), true, int(limit))
	if err != nil {
		return ret, merr.WrapErrServiceInternal("failed to construct group by field data builder, this is abnormal as segcore should always set up a group by field, no matter data status, check code on qn", err.Error())
	}
	idxComputers := make([]*typeutil.FieldDataIdxComputer, subSearchNum)
	for i, srd := range subSearchResultData {
		idxComputers[i] = typeutil.NewFieldDataIdxComputer(srd.FieldsData)
	}
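	// One index computer per sub-result: with nullable vectors, field data can
	// be stored sparsely (null rows carry no payload), so a hit's logical
	// offset must be translated into per-field physical indices before its
	// field values are copied.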
	var realTopK int64 = -1
	var retSize int64
	maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64()
	// reducing nq * topk results
	for i := int64(0); i < nq; i++ {
		var (
			// cursor of the current position in each subSearch result while
			// merging the j-th of the topK entries; invariant: sum(cursors) == j
			cursors        = make([]int64, subSearchNum)
			j              int64
			groupByValMap  = make(map[interface{}][]*groupReduceInfo)
			skipOffsetMap  = make(map[interface{}]bool)
			groupByValList = make([]interface{}, limit)
			groupByValIdx  = 0
		)
		for j = 0; j < groupBound; {
			subSearchIdx, resultDataIdx := selectHighestScoreIndex(ctx, subSearchResultData, subSearchNqOffset, cursors, i)
			if subSearchIdx == -1 {
				break
			}
			subSearchRes := subSearchResultData[subSearchIdx]
			id := typeutil.GetPK(subSearchRes.GetIds(), resultDataIdx)
			score := subSearchRes.GetScores()[resultDataIdx]
			groupByVal := subSearchGroupByValIterator[subSearchIdx](int(resultDataIdx))
			if int64(len(skipOffsetMap)) < offset || skipOffsetMap[groupByVal] {
				skipOffsetMap[groupByVal] = true
				// the first offset groups are ignored
			} else if len(groupByValMap[groupByVal]) == 0 && int64(len(groupByValMap)) >= limit {
				// skip when the group map is already full and a new groupByVal appears
			} else if int64(len(groupByValMap[groupByVal])) >= groupSize {
				// skip when the target group is already full
			} else {
				if len(groupByValMap[groupByVal]) == 0 {
					groupByValList[groupByValIdx] = groupByVal
					groupByValIdx++
				}
				groupByValMap[groupByVal] = append(groupByValMap[groupByVal], &groupReduceInfo{
					subSearchIdx: subSearchIdx,
					resultIdx:    resultDataIdx, id: id, score: score,
				})
				j++
			}
			cursors[subSearchIdx]++
		}
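		// At this point, for example with offset=1, limit=2, groupSize=2, the
		// first distinct group value encountered has been skipped entirely and
		// groupByValMap holds at most two groups of at most two hits each
		// (j <= groupBound = 4).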
		// assemble all eligible values in each group;
		// values in groupByValList are sorted by the highest score in each group
		for _, groupVal := range groupByValList {
			groupEntities := groupByValMap[groupVal]
			for _, groupEntity := range groupEntities {
				subResData := subSearchResultData[groupEntity.subSearchIdx]
				if len(ret.Results.FieldsData) > 0 {
					fieldIdxs := idxComputers[groupEntity.subSearchIdx].Compute(groupEntity.resultIdx)
					retSize += typeutil.AppendFieldData(ret.Results.FieldsData, subResData.FieldsData, groupEntity.resultIdx, fieldIdxs...)
				}
				typeutil.AppendPKs(ret.Results.Ids, groupEntity.id)
				ret.Results.Scores = append(ret.Results.Scores, groupEntity.score)
				// Handle ElementIndices if present
				if subResData.ElementIndices != nil {
					if ret.Results.ElementIndices == nil {
						ret.Results.ElementIndices = &schemapb.LongArray{
							Data: make([]int64, 0, limit),
						}
					}
					elemIdx := subResData.ElementIndices.GetData()[groupEntity.resultIdx]
					ret.Results.ElementIndices.Data = append(ret.Results.ElementIndices.Data, elemIdx)
				}
				gpFieldBuilder.Add(groupVal)
			}
		}
		if realTopK != -1 && realTopK != j {
			log.Ctx(ctx).Warn("Proxy Reduce Search Result", zap.Error(errors.New("the number of results (topk) differs between sub-results of the query")))
		}
		realTopK = j
		ret.Results.Topks = append(ret.Results.Topks, realTopK)
		ret.Results.GroupByFieldValue = gpFieldBuilder.Build()
		// limit search result size to avoid OOM
		if retSize > maxOutputSize {
			return nil, fmt.Errorf("search results exceed the maxOutputSize limit %d", maxOutputSize)
		}
	}
	ret.Results.TopK = realTopK // realTopK is the topK of the nq-th query
	if !metric.PositivelyRelated(metricType) {
		for k := range ret.Results.Scores {
			ret.Results.Scores[k] *= -1
		}
	}
	return ret, nil
}

func reduceSearchResultDataNoGroupBy(ctx context.Context, subSearchResultData []*schemapb.SearchResultData, nq int64, topk int64, metricType string, pkType schemapb.DataType, offset int64) (*milvuspb.SearchResults, error) {
	tr := timerecord.NewTimeRecorder("reduceSearchResultData")
	defer func() {
		tr.CtxElapse(ctx, "done")
	}()
	limit := topk - offset
	log.Ctx(ctx).Debug("reduceSearchResultData",
		zap.Int("len(subSearchResultData)", len(subSearchResultData)),
		zap.Int64("nq", nq),
		zap.Int64("offset", offset),
		zap.Int64("limit", limit),
		zap.String("metricType", metricType))
	ret := &milvuspb.SearchResults{
		Status: merr.Success(),
		Results: &schemapb.SearchResultData{
			NumQueries: nq,
			TopK:       topk,
			FieldsData: []*schemapb.FieldData{},
			Scores:     []float32{},
			Ids:        &schemapb.IDs{},
			Topks:      []int64{},
		},
	}
	if err := setupIdListForSearchResult(ret, pkType, limit); err != nil {
		return ret, err // propagate the error instead of silently dropping it
	}
	if allSearchCount, _, err := checkResultDatas(ctx, subSearchResultData, nq, topk); err != nil {
		log.Ctx(ctx).Warn("invalid search results", zap.Error(err))
		return ret, err
	} else {
		ret.GetResults().AllSearchCount = allSearchCount
	}
	// Find the first non-empty FieldsData as template
	for _, result := range subSearchResultData {
		if len(result.GetFieldsData()) > 0 {
			ret.GetResults().FieldsData = typeutil.PrepareResultFieldData(result.GetFieldsData(), limit)
			break
		}
	}
	subSearchNum := len(subSearchResultData)
	if subSearchNum == 1 && offset == 0 {
		// sorting is not needed if there is only one shard and no offset; assign the result directly.
		// we still need to adjust the scores later.
		ret.Results = subSearchResultData[0]
		// realTopK is the topK of the nq-th query; it is used in the proxy but not handled by the delegator.
		topks := subSearchResultData[0].Topks
		if len(topks) > 0 {
			ret.Results.TopK = topks[len(topks)-1]
		}
	} else {
		var realTopK int64 = -1
		var retSize int64
		// for the results of each subSearchResultData, store the start offset of each of the nq queries
		subSearchNqOffset := make([][]int64, subSearchNum)
		for i := 0; i < subSearchNum; i++ {
			subSearchNqOffset[i] = make([]int64, subSearchResultData[i].GetNumQueries())
			for j := int64(1); j < nq; j++ {
				subSearchNqOffset[i][j] = subSearchNqOffset[i][j-1] + subSearchResultData[i].Topks[j-1]
			}
		}
		idxComputers := make([]*typeutil.FieldDataIdxComputer, subSearchNum)
		for i, srd := range subSearchResultData {
			idxComputers[i] = typeutil.NewFieldDataIdxComputer(srd.FieldsData)
		}
		maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64()
		// reducing nq * topk results
		for i := int64(0); i < nq; i++ {
			var (
				// cursor of the current position in each subSearch result while
				// merging the j-th of the topK entries; invariant: sum(cursors) == j
				cursors = make([]int64, subSearchNum)
				j       int64
			)
			// skip offset results
			for k := int64(0); k < offset; k++ {
				subSearchIdx, _ := selectHighestScoreIndex(ctx, subSearchResultData, subSearchNqOffset, cursors, i)
				if subSearchIdx == -1 {
					break
				}
				cursors[subSearchIdx]++
			}
			// keep limit results
			for j = 0; j < limit; j++ {
				// From all the sub-query result sets of the i-th query vector,
				// find the sub-query result set holding the j-th highest score,
				// and the index of that entry in its schemapb.SearchResultData
				subSearchIdx, resultDataIdx := selectHighestScoreIndex(ctx, subSearchResultData, subSearchNqOffset, cursors, i)
				if subSearchIdx == -1 {
					break
				}
				score := subSearchResultData[subSearchIdx].Scores[resultDataIdx]
				if len(ret.Results.FieldsData) > 0 {
					fieldsData := subSearchResultData[subSearchIdx].FieldsData
					fieldIdxs := idxComputers[subSearchIdx].Compute(resultDataIdx)
					retSize += typeutil.AppendFieldData(ret.Results.FieldsData, fieldsData, resultDataIdx, fieldIdxs...)
				}
				typeutil.CopyPk(ret.Results.Ids, subSearchResultData[subSearchIdx].GetIds(), int(resultDataIdx))
				ret.Results.Scores = append(ret.Results.Scores, score)
				// Handle ElementIndices if present
				if subSearchResultData[subSearchIdx].ElementIndices != nil {
					if ret.Results.ElementIndices == nil {
						ret.Results.ElementIndices = &schemapb.LongArray{
							Data: make([]int64, 0, limit),
						}
					}
					elemIdx := subSearchResultData[subSearchIdx].ElementIndices.GetData()[resultDataIdx]
					ret.Results.ElementIndices.Data = append(ret.Results.ElementIndices.Data, elemIdx)
				}
				cursors[subSearchIdx]++
			}
			if realTopK != -1 && realTopK != j {
				log.Ctx(ctx).Warn("Proxy Reduce Search Result", zap.Error(errors.New("the number of results (topk) differs between sub-results of the query")))
				// return nil, errors.New("the length (topk) between all result of query is different")
			}
			realTopK = j
			ret.Results.Topks = append(ret.Results.Topks, realTopK)
			// limit search result size to avoid OOM
			if retSize > maxOutputSize {
				return nil, fmt.Errorf("search results exceed the maxOutputSize limit %d", maxOutputSize)
			}
		}
		ret.Results.TopK = realTopK // realTopK is the topK of the nq-th query
	}
	if !metric.PositivelyRelated(metricType) {
		for k := range ret.Results.Scores {
			ret.Results.Scores[k] *= -1
		}
	}
	return ret, nil
}

func compareKey(keyI interface{}, keyJ interface{}) bool {
	switch keyI := keyI.(type) {
	case int64:
		return keyI < keyJ.(int64)
	case string:
		return keyI < keyJ.(string)
	}
	return false
}

func setupIdListForSearchResult(searchResult *milvuspb.SearchResults, pkType schemapb.DataType, capacity int64) error {
	switch pkType {
	case schemapb.DataType_Int64:
		searchResult.GetResults().Ids.IdField = &schemapb.IDs_IntId{
			IntId: &schemapb.LongArray{
				Data: make([]int64, 0, capacity),
			},
		}
	case schemapb.DataType_VarChar:
		searchResult.GetResults().Ids.IdField = &schemapb.IDs_StrId{
			StrId: &schemapb.StringArray{
				Data: make([]string, 0, capacity),
			},
		}
	default:
		return errors.New("unsupported pk type")
	}
	return nil
}

func fillInEmptyResult(numQueries int64) *milvuspb.SearchResults {
	return &milvuspb.SearchResults{
		Status: merr.Success("search result is empty"),
		Results: &schemapb.SearchResultData{
			NumQueries: numQueries,
			Topks:      make([]int64, numQueries),
		},
	}
}

func reduceResults(ctx context.Context, toReduceResults []*internalpb.SearchResults, nq, topK, offset int64, metricType string, pkType schemapb.DataType, queryInfo *planpb.QueryInfo, isAdvance bool, collectionID int64, partitionIDs []int64) (*milvuspb.SearchResults, error) {
	ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "reduceResults")
	defer sp.End()
	log := log.Ctx(ctx)
	// Decode all search results
	validSearchResults, err := decodeSearchResults(ctx, toReduceResults)
	if err != nil {
		log.Warn("failed to decode search results", zap.Error(err))
		return nil, err
	}
	if len(validSearchResults) <= 0 {
		log.Debug("reduced search results are empty, filling in an empty result")
		return fillInEmptyResult(nq), nil
	}
	// Reduce all search results
	log.Debug("proxy search post execute reduce",
		zap.Int64("collection", collectionID),
		zap.Int64s("partitionIDs", partitionIDs),
		zap.Int("number of valid search results", len(validSearchResults)))
	var result *milvuspb.SearchResults
	result, err = reduceSearchResult(ctx, validSearchResults, reduce.NewReduceSearchResultInfo(nq, topK).WithMetricType(metricType).WithPkType(pkType).
		WithOffset(offset).WithGroupByField(queryInfo.GetGroupByFieldId()).WithGroupSize(queryInfo.GetGroupSize()).WithAdvance(isAdvance))
	if err != nil {
		log.Warn("failed to reduce search results", zap.Error(err))
		return nil, err
	}
	return result, nil
}

func decodeSearchResults(ctx context.Context, searchResults []*internalpb.SearchResults) ([]*schemapb.SearchResultData, error) {
	ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "decodeSearchResults")
	defer sp.End()
	tr := timerecord.NewTimeRecorder("decodeSearchResults")
	results := make([]*schemapb.SearchResultData, 0)
	for _, partialSearchResult := range searchResults {
		if partialSearchResult.SlicedBlob == nil {
			continue
		}
		var partialResultData schemapb.SearchResultData
		err := proto.Unmarshal(partialSearchResult.SlicedBlob, &partialResultData)
		if err != nil {
			return nil, err
		}
		results = append(results, &partialResultData)
	}
	tr.CtxElapse(ctx, "decodeSearchResults done")
	return results, nil
}

func checkSearchResultData(data *schemapb.SearchResultData, nq int64, topk int64, pkHitNum int) error {
	if data.NumQueries != nq {
		return fmt.Errorf("search result's nq(%d) mis-match with %d", data.NumQueries, nq)
	}
	if data.TopK != topk {
		return fmt.Errorf("search result's topk(%d) mis-match with %d", data.TopK, topk)
	}
	if len(data.Scores) != pkHitNum {
		return fmt.Errorf("search result's score length invalid, score length=%d, expectedLength=%d",
			len(data.Scores), pkHitNum)
	}
	return nil
}
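
// selectHighestScoreIndex performs one step of a k-way merge for query qi:
// among the sub-results whose cursors have not exhausted their top-k, it picks
// the entry with the highest score, breaking ties by the smaller primary key,
// and returns the winning sub-result index plus the entry's flattened offset.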
func selectHighestScoreIndex(ctx context.Context, subSearchResultData []*schemapb.SearchResultData, subSearchNqOffset [][]int64, cursors []int64, qi int64) (int, int64) {
	var (
		subSearchIdx        = -1
		resultDataIdx int64 = -1
	)
	maxScore := minFloat32
	for i := range cursors {
		if cursors[i] >= subSearchResultData[i].Topks[qi] {
			continue
		}
		sIdx := subSearchNqOffset[i][qi] + cursors[i]
		sScore := subSearchResultData[i].Scores[sIdx]
		// Choose the entry with the larger score, or the smaller pk when scores are equal
		if subSearchIdx == -1 || sScore > maxScore {
			subSearchIdx = i
			resultDataIdx = sIdx
			maxScore = sScore
		} else if sScore == maxScore {
			if subSearchIdx == -1 {
				// A bad case happens where Knowhere returns distance/score == +/-maxFloat32
				// by mistake.
				log.Ctx(ctx).Error("a bad score is returned, something is wrong here!", zap.Float32("score", sScore))
			} else if typeutil.ComparePK(
				typeutil.GetPK(subSearchResultData[i].GetIds(), sIdx),
				typeutil.GetPK(subSearchResultData[subSearchIdx].GetIds(), resultDataIdx)) {
				subSearchIdx = i
				resultDataIdx = sIdx
				maxScore = sScore
			}
		}
	}
	return subSearchIdx, resultDataIdx
}