issue: https://github.com/milvus-io/milvus/issues/42148

For a vector field inside a STRUCT: since a STRUCT can only appear as the element type of an ARRAY field, a vector field in a STRUCT is effectively an array of vectors, i.e. an embedding list. Milvus already supports searching embedding lists with metrics whose names start with the prefix MAX_SIM_. This PR allows Milvus to search embeddings inside an embedding list using the same metrics as normal embedding fields: each embedding in the list is treated as an independent vector and participates in ANN search.

Further, since a STRUCT may contain scalar fields that are highly related to the embedding field, this PR introduces an element-level filter expression to refine search results. The grammar of the element-level filter is:

element_filter(structFieldName, $[subFieldName] == 3)

where $[subFieldName] refers to the value of subFieldName in each element of the STRUCT array structFieldName. It can be combined with existing filter expressions, for example:

"varcharField == 'aaa' && element_filter(struct_field, $[struct_int] == 3)"

A full example:

```
struct_schema = milvus_client.create_struct_field_schema()
struct_schema.add_field("struct_str", DataType.VARCHAR, max_length=65535)
struct_schema.add_field("struct_int", DataType.INT32)
struct_schema.add_field("struct_float_vec", DataType.FLOAT_VECTOR, dim=EMBEDDING_DIM)

schema.add_field(
    "struct_field",
    datatype=DataType.ARRAY,
    element_type=DataType.STRUCT,
    struct_schema=struct_schema,
    max_capacity=1000,
)

...

filter = "varcharField == 'aaa' && element_filter(struct_field, $[struct_int] == 3 && $[struct_str] == 'abc')"
res = milvus_client.search(
    COLLECTION_NAME,
    data=query_embeddings,
    limit=10,
    anns_field="struct_field[struct_float_vec]",
    filter=filter,
    output_fields=["struct_field[struct_int]", "varcharField"],
)
```

TODO:
1. When an `element_filter` expression is used, a regular filter expression must also be present. Remove this restriction.
2. Implement `element_filter` expressions in `query`.

---------

Signed-off-by: SpadeA <tangchenjie1210@gmail.com>
642 lines
22 KiB
Go
package proxy

import (
	"context"
	"fmt"

	"github.com/cockroachdb/errors"
	"go.opentelemetry.io/otel"
	"go.uber.org/zap"
	"google.golang.org/protobuf/proto"

	"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
	"github.com/milvus-io/milvus/internal/util/reduce"
	"github.com/milvus-io/milvus/pkg/v2/log"
	"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
	"github.com/milvus-io/milvus/pkg/v2/proto/planpb"
	"github.com/milvus-io/milvus/pkg/v2/util/merr"
	"github.com/milvus-io/milvus/pkg/v2/util/metric"
	"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
	"github.com/milvus-io/milvus/pkg/v2/util/timerecord"
	"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)

// reduceSearchResult dispatches to the proper reduce path based on whether
// group-by is requested and whether this is one path of a hybrid (advance) search.
func reduceSearchResult(ctx context.Context, subSearchResultData []*schemapb.SearchResultData, reduceInfo *reduce.ResultInfo) (*milvuspb.SearchResults, error) {
	if reduceInfo.GetGroupByFieldId() > 0 {
		if reduceInfo.GetIsAdvance() {
			// For hybrid search with group-by, we cannot reduce results from a single
			// search path: the final score has not been accumulated yet, and the
			// offset cannot be applied either.
			return reduceAdvanceGroupBy(ctx,
				subSearchResultData, reduceInfo.GetNq(), reduceInfo.GetTopK(), reduceInfo.GetPkType(), reduceInfo.GetMetricType())
		}
		return reduceSearchResultDataWithGroupBy(ctx,
			subSearchResultData,
			reduceInfo.GetNq(),
			reduceInfo.GetTopK(),
			reduceInfo.GetMetricType(),
			reduceInfo.GetPkType(),
			reduceInfo.GetOffset(),
			reduceInfo.GetGroupSize())
	}
	return reduceSearchResultDataNoGroupBy(ctx,
		subSearchResultData,
		reduceInfo.GetNq(),
		reduceInfo.GetTopK(),
		reduceInfo.GetMetricType(),
		reduceInfo.GetPkType(),
		reduceInfo.GetOffset())
}
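
// exampleReduceInfo is an illustrative sketch (not part of the original file):
// it shows how a caller builds the reduce.ResultInfo consumed by
// reduceSearchResult, mirroring the builder chain used in reduceResults below.
// The concrete values ("L2", Int64 pk, zero offset) are assumptions for
// demonstration only.
func exampleReduceInfo(nq, topK int64) *reduce.ResultInfo {
	return reduce.NewReduceSearchResultInfo(nq, topK).
		WithMetricType("L2").                // distance metric; scores are sign-flipped for ranking
		WithPkType(schemapb.DataType_Int64). // collection primary-key type
		WithOffset(0)                        // no pagination offset
}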

// checkResultDatas validates every sub-result against the expected nq/topK and
// returns the accumulated AllSearchCount and the total number of hits.
func checkResultDatas(ctx context.Context, subSearchResultData []*schemapb.SearchResultData,
	nq int64, topK int64,
) (int64, int, error) {
	var allSearchCount int64
	var hitNum int
	for i, sData := range subSearchResultData {
		pkLength := typeutil.GetSizeOfIDs(sData.GetIds())
		log.Ctx(ctx).Debug("subSearchResultData",
			zap.Int("result No.", i),
			zap.Int64("nq", sData.NumQueries),
			zap.Int64("topk", sData.TopK),
			zap.Int("length of pks", pkLength),
			zap.Int("length of FieldsData", len(sData.FieldsData)))
		allSearchCount += sData.GetAllSearchCount()
		hitNum += pkLength
		if err := checkSearchResultData(sData, nq, topK, pkLength); err != nil {
			log.Ctx(ctx).Warn("invalid search results", zap.Error(err))
			return allSearchCount, hitNum, err
		}
	}
	return allSearchCount, hitNum, nil
}

// reduceAdvanceGroupBy merges per-channel results for one search path of a
// hybrid (advance) search with group-by. Offset is not applied here; per-query
// hits are concatenated across channels and re-reduced later.
func reduceAdvanceGroupBy(ctx context.Context, subSearchResultData []*schemapb.SearchResultData,
	nq int64, topK int64, pkType schemapb.DataType, metricType string,
) (*milvuspb.SearchResults, error) {
	log.Ctx(ctx).Debug("reduceAdvanceGroupBy", zap.Int("len(subSearchResultData)", len(subSearchResultData)), zap.Int64("nq", nq))
	// For advance group-by, offset is not applied, so just return when there is only one channel.
	if len(subSearchResultData) == 1 {
		return &milvuspb.SearchResults{
			Status:  merr.Success(),
			Results: subSearchResultData[0],
		}, nil
	}

	ret := &milvuspb.SearchResults{
		Status: merr.Success(),
		Results: &schemapb.SearchResultData{
			NumQueries: nq,
			TopK:       topK,
			Scores:     []float32{},
			Ids:        &schemapb.IDs{},
			Topks:      []int64{},
		},
	}

	var limit int64
	if allSearchCount, hitNum, err := checkResultDatas(ctx, subSearchResultData, nq, topK); err != nil {
		log.Ctx(ctx).Warn("invalid search results", zap.Error(err))
		return ret, err
	} else {
		ret.GetResults().AllSearchCount = allSearchCount
		limit = int64(hitNum)
		// Find the first non-empty FieldsData as the template.
		for _, result := range subSearchResultData {
			if len(result.GetFieldsData()) > 0 {
				ret.GetResults().FieldsData = typeutil.PrepareResultFieldData(result.GetFieldsData(), limit)
				break
			}
		}
	}

	if err := setupIdListForSearchResult(ret, pkType, limit); err != nil {
		return ret, err
	}

	var (
		subSearchNum = len(subSearchResultData)
		// For each subSearchResultData, the start offset of each of the nq queries.
		subSearchNqOffset = make([][]int64, subSearchNum)
	)
	for i := 0; i < subSearchNum; i++ {
		subSearchNqOffset[i] = make([]int64, subSearchResultData[i].GetNumQueries())
		for j := int64(1); j < nq; j++ {
			subSearchNqOffset[i][j] = subSearchNqOffset[i][j-1] + subSearchResultData[i].Topks[j-1]
		}
	}

	gpFieldBuilder, err := typeutil.NewFieldDataBuilder(subSearchResultData[0].GetGroupByFieldValue().GetType(), true, int(limit))
	if err != nil {
		return ret, merr.WrapErrServiceInternal("failed to construct group by field data builder, this is abnormal as segcore should always set up a group by field, no matter data status, check code on qn", err.Error())
	}
	// Reduce the nq * topk results.
	for nqIdx := int64(0); nqIdx < nq; nqIdx++ {
		dataCount := int64(0)
		for subIdx := 0; subIdx < subSearchNum; subIdx += 1 {
			subData := subSearchResultData[subIdx]
			subPks := subData.GetIds()
			subScores := subData.GetScores()
			subGroupByVals := subData.GetGroupByFieldValue()

			nqTopK := subData.Topks[nqIdx]
			groupByValIterator := typeutil.GetDataIterator(subGroupByVals)

			for i := int64(0); i < nqTopK; i++ {
				innerIdx := subSearchNqOffset[subIdx][nqIdx] + i
				pk := typeutil.GetPK(subPks, innerIdx)
				score := subScores[innerIdx]
				groupByVal := groupByValIterator(int(innerIdx))
				gpFieldBuilder.Add(groupByVal)
				typeutil.AppendPKs(ret.Results.Ids, pk)
				ret.Results.Scores = append(ret.Results.Scores, score)

				// Carry over ElementIndices (embedding-list hit positions) if present.
				if subData.ElementIndices != nil {
					if ret.Results.ElementIndices == nil {
						ret.Results.ElementIndices = &schemapb.LongArray{
							Data: make([]int64, 0, limit),
						}
					}
					elemIdx := subData.ElementIndices.GetData()[innerIdx]
					ret.Results.ElementIndices.Data = append(ret.Results.ElementIndices.Data, elemIdx)
				}

				dataCount += 1
			}
		}
		ret.Results.Topks = append(ret.Results.Topks, dataCount)
	}

	ret.Results.GroupByFieldValue = gpFieldBuilder.Build()
	ret.Results.TopK = topK
	if !metric.PositivelyRelated(metricType) {
		for k := range ret.Results.Scores {
			ret.Results.Scores[k] *= -1
		}
	}
	return ret, nil
}
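
// exampleNqOffsets is an illustrative sketch (not part of the original file)
// of the subSearchNqOffset computation used above: given one sub-result's
// per-query hit counts (Topks), it derives the start offset of each query's
// slice in that sub-result's flattened hit arrays. For topks = [2, 3, 1] and
// nq = 3 it yields [0, 2, 5].
func exampleNqOffsets(topks []int64, nq int64) []int64 {
	offsets := make([]int64, nq)
	for j := int64(1); j < nq; j++ {
		offsets[j] = offsets[j-1] + topks[j-1]
	}
	return offsets
}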

type MilvusPKType interface{}

// groupReduceInfo records one hit retained for a group during group-by reduce.
type groupReduceInfo struct {
	subSearchIdx int
	resultIdx    int64
	score        float32
	id           MilvusPKType
}

// reduceSearchResultDataWithGroupBy merges sub-results with group-by: per
// query it keeps at most `limit` groups of at most `groupSize` hits each,
// skipping the first `offset` distinct groups.
func reduceSearchResultDataWithGroupBy(ctx context.Context, subSearchResultData []*schemapb.SearchResultData,
	nq int64, topk int64, metricType string,
	pkType schemapb.DataType,
	offset int64,
	groupSize int64,
) (*milvuspb.SearchResults, error) {
	tr := timerecord.NewTimeRecorder("reduceSearchResultData")
	defer func() {
		tr.CtxElapse(ctx, "done")
	}()

	limit := topk - offset
	log.Ctx(ctx).Debug("reduceSearchResultData",
		zap.Int("len(subSearchResultData)", len(subSearchResultData)),
		zap.Int64("nq", nq),
		zap.Int64("offset", offset),
		zap.Int64("limit", limit),
		zap.String("metricType", metricType))

	ret := &milvuspb.SearchResults{
		Status: merr.Success(),
		Results: &schemapb.SearchResultData{
			NumQueries: nq,
			TopK:       topk,
			FieldsData: []*schemapb.FieldData{},
			Scores:     []float32{},
			Ids:        &schemapb.IDs{},
			Topks:      []int64{},
		},
	}
	groupBound := groupSize * limit
	if err := setupIdListForSearchResult(ret, pkType, groupBound); err != nil {
		return ret, err
	}

	if allSearchCount, _, err := checkResultDatas(ctx, subSearchResultData, nq, topk); err != nil {
		log.Ctx(ctx).Warn("invalid search results", zap.Error(err))
		return ret, err
	} else {
		ret.GetResults().AllSearchCount = allSearchCount
	}

	// Find the first non-empty FieldsData as the template.
	for _, result := range subSearchResultData {
		if len(result.GetFieldsData()) > 0 {
			ret.GetResults().FieldsData = typeutil.PrepareResultFieldData(result.GetFieldsData(), limit)
			break
		}
	}

	var (
		subSearchNum = len(subSearchResultData)
		// For each subSearchResultData, the start offset of each of the nq queries.
		subSearchNqOffset           = make([][]int64, subSearchNum)
		totalResCount         int64 = 0
		subSearchGroupByValIterator = make([]func(int) any, subSearchNum)
	)
	for i := 0; i < subSearchNum; i++ {
		subSearchNqOffset[i] = make([]int64, subSearchResultData[i].GetNumQueries())
		for j := int64(1); j < nq; j++ {
			subSearchNqOffset[i][j] = subSearchNqOffset[i][j-1] + subSearchResultData[i].Topks[j-1]
		}
		totalResCount += subSearchNqOffset[i][nq-1]
		subSearchGroupByValIterator[i] = typeutil.GetDataIterator(subSearchResultData[i].GetGroupByFieldValue())
	}

	gpFieldBuilder, err := typeutil.NewFieldDataBuilder(subSearchResultData[0].GetGroupByFieldValue().GetType(), true, int(limit))
	if err != nil {
		return ret, merr.WrapErrServiceInternal("failed to construct group by field data builder, this is abnormal as segcore should always set up a group by field, no matter data status, check code on qn", err.Error())
	}

	var realTopK int64 = -1
	var retSize int64

	maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64()
	// Reduce the nq * topk results.
	for i := int64(0); i < nq; i++ {
		var (
			// Cursor into each subSearch's hits while merging the j-th of topK hits;
			// sum(cursors) == j.
			cursors = make([]int64, subSearchNum)

			j              int64
			groupByValMap  = make(map[interface{}][]*groupReduceInfo)
			skipOffsetMap  = make(map[interface{}]bool)
			groupByValList = make([]interface{}, limit)
			groupByValIdx  = 0
		)

		for j = 0; j < groupBound; {
			subSearchIdx, resultDataIdx := selectHighestScoreIndex(ctx, subSearchResultData, subSearchNqOffset, cursors, i)
			if subSearchIdx == -1 {
				break
			}
			subSearchRes := subSearchResultData[subSearchIdx]

			id := typeutil.GetPK(subSearchRes.GetIds(), resultDataIdx)
			score := subSearchRes.GetScores()[resultDataIdx]
			groupByVal := subSearchGroupByValIterator[subSearchIdx](int(resultDataIdx))

			if int64(len(skipOffsetMap)) < offset || skipOffsetMap[groupByVal] {
				skipOffsetMap[groupByVal] = true
				// Groups falling within the first `offset` distinct groups are skipped.
			} else if len(groupByValMap[groupByVal]) == 0 && int64(len(groupByValMap)) >= limit {
				// Skip a newly seen groupByVal once the group map is already full.
			} else if int64(len(groupByValMap[groupByVal])) >= groupSize {
				// Skip once the target group is already full.
			} else {
				if len(groupByValMap[groupByVal]) == 0 {
					groupByValList[groupByValIdx] = groupByVal
					groupByValIdx++
				}
				groupByValMap[groupByVal] = append(groupByValMap[groupByVal], &groupReduceInfo{
					subSearchIdx: subSearchIdx,
					resultIdx:    resultDataIdx, id: id, score: score,
				})
				j++
			}

			cursors[subSearchIdx]++
		}

		// Assemble all eligible values per group; groupByValList is ordered by
		// the highest score within each group.
		for _, groupVal := range groupByValList {
			groupEntities := groupByValMap[groupVal]
			for _, groupEntity := range groupEntities {
				subResData := subSearchResultData[groupEntity.subSearchIdx]
				if len(ret.Results.FieldsData) > 0 {
					retSize += typeutil.AppendFieldData(ret.Results.FieldsData, subResData.FieldsData, groupEntity.resultIdx)
				}
				typeutil.AppendPKs(ret.Results.Ids, groupEntity.id)
				ret.Results.Scores = append(ret.Results.Scores, groupEntity.score)

				// Carry over ElementIndices (embedding-list hit positions) if present.
				if subResData.ElementIndices != nil {
					if ret.Results.ElementIndices == nil {
						ret.Results.ElementIndices = &schemapb.LongArray{
							Data: make([]int64, 0, limit),
						}
					}
					elemIdx := subResData.ElementIndices.GetData()[groupEntity.resultIdx]
					ret.Results.ElementIndices.Data = append(ret.Results.ElementIndices.Data, elemIdx)
				}

				gpFieldBuilder.Add(groupVal)
			}
		}

		if realTopK != -1 && realTopK != j {
			log.Ctx(ctx).Warn("Proxy Reduce Search Result", zap.Error(errors.New("the length (topk) between all result of query is different")))
		}
		realTopK = j
		ret.Results.Topks = append(ret.Results.Topks, realTopK)
		ret.Results.GroupByFieldValue = gpFieldBuilder.Build()

		// Cap the result size to avoid OOM.
		if retSize > maxOutputSize {
			return nil, fmt.Errorf("search results exceed the maxOutputSize Limit %d", maxOutputSize)
		}
	}
	ret.Results.TopK = realTopK // realTopK is the topK of the last (nq-th) query
	if !metric.PositivelyRelated(metricType) {
		for k := range ret.Results.Scores {
			ret.Results.Scores[k] *= -1
		}
	}
	return ret, nil
}

// reduceSearchResultDataNoGroupBy merges sub-results without group-by: per
// query it skips the first `offset` hits and keeps the next `limit` hits in
// descending score order.
func reduceSearchResultDataNoGroupBy(ctx context.Context, subSearchResultData []*schemapb.SearchResultData, nq int64, topk int64, metricType string, pkType schemapb.DataType, offset int64) (*milvuspb.SearchResults, error) {
	tr := timerecord.NewTimeRecorder("reduceSearchResultData")
	defer func() {
		tr.CtxElapse(ctx, "done")
	}()

	limit := topk - offset
	log.Ctx(ctx).Debug("reduceSearchResultData",
		zap.Int("len(subSearchResultData)", len(subSearchResultData)),
		zap.Int64("nq", nq),
		zap.Int64("offset", offset),
		zap.Int64("limit", limit),
		zap.String("metricType", metricType))

	ret := &milvuspb.SearchResults{
		Status: merr.Success(),
		Results: &schemapb.SearchResultData{
			NumQueries: nq,
			TopK:       topk,
			FieldsData: []*schemapb.FieldData{},
			Scores:     []float32{},
			Ids:        &schemapb.IDs{},
			Topks:      []int64{},
		},
	}

	if err := setupIdListForSearchResult(ret, pkType, limit); err != nil {
		return ret, err
	}

	if allSearchCount, _, err := checkResultDatas(ctx, subSearchResultData, nq, topk); err != nil {
		log.Ctx(ctx).Warn("invalid search results", zap.Error(err))
		return ret, err
	} else {
		ret.GetResults().AllSearchCount = allSearchCount
	}

	// Find the first non-empty FieldsData as the template.
	for _, result := range subSearchResultData {
		if len(result.GetFieldsData()) > 0 {
			ret.GetResults().FieldsData = typeutil.PrepareResultFieldData(result.GetFieldsData(), limit)
			break
		}
	}

	subSearchNum := len(subSearchResultData)
	if subSearchNum == 1 && offset == 0 {
		// Sorting is not needed when there is only one shard and no offset:
		// assign the result directly. The scores are still adjusted below.
		ret.Results = subSearchResultData[0]
		// realTopK is the topK of the nq-th query; it is used in the proxy but
		// not handled by the delegator.
		topks := subSearchResultData[0].Topks
		if len(topks) > 0 {
			ret.Results.TopK = topks[len(topks)-1]
		}
	} else {
		var realTopK int64 = -1
		var retSize int64

		// For each subSearchResultData, the start offset of each of the nq queries.
		subSearchNqOffset := make([][]int64, subSearchNum)
		for i := 0; i < subSearchNum; i++ {
			subSearchNqOffset[i] = make([]int64, subSearchResultData[i].GetNumQueries())
			for j := int64(1); j < nq; j++ {
				subSearchNqOffset[i][j] = subSearchNqOffset[i][j-1] + subSearchResultData[i].Topks[j-1]
			}
		}
		maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64()
		// Reduce the nq * topk results.
		for i := int64(0); i < nq; i++ {
			var (
				// Cursor into each subSearch's hits while merging the j-th of topK hits;
				// sum(cursors) == j.
				cursors = make([]int64, subSearchNum)
				j       int64
			)

			// Skip the first `offset` results.
			for k := int64(0); k < offset; k++ {
				subSearchIdx, _ := selectHighestScoreIndex(ctx, subSearchResultData, subSearchNqOffset, cursors, i)
				if subSearchIdx == -1 {
					break
				}

				cursors[subSearchIdx]++
			}

			// Keep the next `limit` results.
			for j = 0; j < limit; j++ {
				// Among all sub-query result sets of the i-th query vector, find
				// the sub-result holding the j-th best score, and the index of
				// that hit inside its schemapb.SearchResultData.
				subSearchIdx, resultDataIdx := selectHighestScoreIndex(ctx, subSearchResultData, subSearchNqOffset, cursors, i)
				if subSearchIdx == -1 {
					break
				}
				score := subSearchResultData[subSearchIdx].Scores[resultDataIdx]

				if len(ret.Results.FieldsData) > 0 {
					retSize += typeutil.AppendFieldData(ret.Results.FieldsData, subSearchResultData[subSearchIdx].FieldsData, resultDataIdx)
				}
				typeutil.CopyPk(ret.Results.Ids, subSearchResultData[subSearchIdx].GetIds(), int(resultDataIdx))
				ret.Results.Scores = append(ret.Results.Scores, score)

				// Carry over ElementIndices (embedding-list hit positions) if present.
				if subSearchResultData[subSearchIdx].ElementIndices != nil {
					if ret.Results.ElementIndices == nil {
						ret.Results.ElementIndices = &schemapb.LongArray{
							Data: make([]int64, 0, limit),
						}
					}
					elemIdx := subSearchResultData[subSearchIdx].ElementIndices.GetData()[resultDataIdx]
					ret.Results.ElementIndices.Data = append(ret.Results.ElementIndices.Data, elemIdx)
				}

				cursors[subSearchIdx]++
			}
			if realTopK != -1 && realTopK != j {
				log.Ctx(ctx).Warn("Proxy Reduce Search Result", zap.Error(errors.New("the length (topk) between all result of query is different")))
				// return nil, errors.New("the length (topk) between all result of query is different")
			}
			realTopK = j
			ret.Results.Topks = append(ret.Results.Topks, realTopK)

			// Cap the result size to avoid OOM.
			if retSize > maxOutputSize {
				return nil, fmt.Errorf("search results exceed the maxOutputSize Limit %d", maxOutputSize)
			}
		}
		ret.Results.TopK = realTopK // realTopK is the topK of the last (nq-th) query
	}

	if !metric.PositivelyRelated(metricType) {
		for k := range ret.Results.Scores {
			ret.Results.Scores[k] *= -1
		}
	}
	return ret, nil
}
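
// exampleRestoreSign is an illustrative sketch (not part of the original
// file) of the score post-processing used at the end of every reduce path
// above: for metrics where smaller is better (e.g. L2), internal ranking works
// on negated scores so that "larger is better" holds uniformly, and the proxy
// flips the sign back before returning results.
func exampleRestoreSign(scores []float32, metricType string) {
	if !metric.PositivelyRelated(metricType) {
		for k := range scores {
			scores[k] *= -1
		}
	}
}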

// compareKey reports whether keyI sorts before keyJ; only int64 and string
// keys (the supported primary-key types) are comparable.
func compareKey(keyI interface{}, keyJ interface{}) bool {
	switch keyI.(type) {
	case int64:
		return keyI.(int64) < keyJ.(int64)
	case string:
		return keyI.(string) < keyJ.(string)
	}
	return false
}

// setupIdListForSearchResult pre-allocates the result ID list according to the
// primary-key type.
func setupIdListForSearchResult(searchResult *milvuspb.SearchResults, pkType schemapb.DataType, capacity int64) error {
	switch pkType {
	case schemapb.DataType_Int64:
		searchResult.GetResults().Ids.IdField = &schemapb.IDs_IntId{
			IntId: &schemapb.LongArray{
				Data: make([]int64, 0, capacity),
			},
		}
	case schemapb.DataType_VarChar:
		searchResult.GetResults().Ids.IdField = &schemapb.IDs_StrId{
			StrId: &schemapb.StringArray{
				Data: make([]string, 0, capacity),
			},
		}
	default:
		return errors.New("unsupported pk type")
	}
	return nil
}

// fillInEmptyResult returns a well-formed empty result for numQueries queries.
func fillInEmptyResult(numQueries int64) *milvuspb.SearchResults {
	return &milvuspb.SearchResults{
		Status: merr.Success("search result is empty"),
		Results: &schemapb.SearchResultData{
			NumQueries: numQueries,
			Topks:      make([]int64, numQueries),
		},
	}
}

// reduceResults decodes the raw per-node results and reduces them into the
// final client-facing SearchResults.
func reduceResults(ctx context.Context, toReduceResults []*internalpb.SearchResults, nq, topK, offset int64, metricType string, pkType schemapb.DataType, queryInfo *planpb.QueryInfo, isAdvance bool, collectionID int64, partitionIDs []int64) (*milvuspb.SearchResults, error) {
	ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "reduceResults")
	defer sp.End()

	log := log.Ctx(ctx)
	// Decode all search results.
	validSearchResults, err := decodeSearchResults(ctx, toReduceResults)
	if err != nil {
		log.Warn("failed to decode search results", zap.Error(err))
		return nil, err
	}

	if len(validSearchResults) <= 0 {
		log.Debug("reduced search results is empty, fill in empty result")
		return fillInEmptyResult(nq), nil
	}

	// Reduce all search results.
	log.Debug("proxy search post execute reduce",
		zap.Int64("collection", collectionID),
		zap.Int64s("partitionIDs", partitionIDs),
		zap.Int("number of valid search results", len(validSearchResults)))
	var result *milvuspb.SearchResults
	result, err = reduceSearchResult(ctx, validSearchResults, reduce.NewReduceSearchResultInfo(nq, topK).WithMetricType(metricType).WithPkType(pkType).
		WithOffset(offset).WithGroupByField(queryInfo.GetGroupByFieldId()).WithGroupSize(queryInfo.GetGroupSize()).WithAdvance(isAdvance))
	if err != nil {
		log.Warn("failed to reduce search results", zap.Error(err))
		return nil, err
	}
	return result, nil
}

// decodeSearchResults unmarshals each node's serialized SearchResultData,
// skipping entries with no payload.
func decodeSearchResults(ctx context.Context, searchResults []*internalpb.SearchResults) ([]*schemapb.SearchResultData, error) {
	ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "decodeSearchResults")
	defer sp.End()
	tr := timerecord.NewTimeRecorder("decodeSearchResults")
	results := make([]*schemapb.SearchResultData, 0)
	for _, partialSearchResult := range searchResults {
		if partialSearchResult.SlicedBlob == nil {
			continue
		}

		var partialResultData schemapb.SearchResultData
		err := proto.Unmarshal(partialSearchResult.SlicedBlob, &partialResultData)
		if err != nil {
			return nil, err
		}
		results = append(results, &partialResultData)
	}
	tr.CtxElapse(ctx, "decodeSearchResults done")
	return results, nil
}

// checkSearchResultData verifies that one sub-result's nq, topk, and score
// length match expectations.
func checkSearchResultData(data *schemapb.SearchResultData, nq int64, topk int64, pkHitNum int) error {
	if data.NumQueries != nq {
		return fmt.Errorf("search result's nq(%d) does not match the expected nq(%d)", data.NumQueries, nq)
	}
	if data.TopK != topk {
		return fmt.Errorf("search result's topk(%d) does not match the expected topk(%d)", data.TopK, topk)
	}

	if len(data.Scores) != pkHitNum {
		return fmt.Errorf("search result's score length invalid, score length=%d, expectedLength=%d",
			len(data.Scores), pkHitNum)
	}
	return nil
}

// selectHighestScoreIndex scans the cursors of all sub-results for query qi
// and returns the sub-result index and flattened data index of the hit with
// the highest score; ties are broken by the smaller primary key. It returns
// (-1, -1) when every cursor is exhausted.
func selectHighestScoreIndex(ctx context.Context, subSearchResultData []*schemapb.SearchResultData, subSearchNqOffset [][]int64, cursors []int64, qi int64) (int, int64) {
	var (
		subSearchIdx        = -1
		resultDataIdx int64 = -1
	)
	maxScore := minFloat32
	for i := range cursors {
		if cursors[i] >= subSearchResultData[i].Topks[qi] {
			continue
		}
		sIdx := subSearchNqOffset[i][qi] + cursors[i]
		sScore := subSearchResultData[i].Scores[sIdx]

		// Choose the larger score; on equal scores, prefer the smaller pk.
		if subSearchIdx == -1 || sScore > maxScore {
			subSearchIdx = i
			resultDataIdx = sIdx
			maxScore = sScore
		} else if sScore == maxScore {
			if subSearchIdx == -1 {
				// A bad case: Knowhere returned distance/score == +/-maxFloat32 by mistake.
				log.Ctx(ctx).Error("a bad score is returned, something is wrong here!", zap.Float32("score", sScore))
			} else if typeutil.ComparePK(
				typeutil.GetPK(subSearchResultData[i].GetIds(), sIdx),
				typeutil.GetPK(subSearchResultData[subSearchIdx].GetIds(), resultDataIdx)) {
				subSearchIdx = i
				resultDataIdx = sIdx
				maxScore = sScore
			}
		}
	}
	return subSearchIdx, resultDataIdx
}
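
// examplePickHighest is an illustrative sketch (not part of the original
// file) of the selection rule in selectHighestScoreIndex, reduced to plain
// score slices: among sub-results whose cursor is not yet exhausted, pick the
// one with the highest score at its cursor. The real code additionally breaks
// score ties by the smaller primary key.
func examplePickHighest(scores [][]float32, cursors []int64) int {
	best, bestScore := -1, float32(0)
	for i := range cursors {
		if cursors[i] >= int64(len(scores[i])) {
			continue // this sub-result has no hits left for the current query
		}
		if s := scores[i][cursors[i]]; best == -1 || s > bestScore {
			best, bestScore = i, s
		}
	}
	return best // -1 when all cursors are exhausted
}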