milvus/internal/proxy/search_pipeline.go
congqixia a592cfc8b4
enhance: [2.6] extract shard client logic into dedicated package (#45018) (#45031)
Cherry-pick from master
pr: #45018 #45030
Related to #44761

Refactor proxy shard client management by creating a new
internal/proxy/shardclient package. This improves code organization and
modularity by:

- Moving load balancing logic (LookAsideBalancer, RoundRobinBalancer) to
shardclient package
- Extracting shard client manager and related interfaces into separate
package
- Relocating shard leader management and client lifecycle code
- Adding package documentation (README.md, OWNERS)
- Updating proxy code to use the new shardclient package interfaces

This change makes the shard client functionality more maintainable and
better encapsulated, reducing coupling in the proxy layer.

Also consolidates the proxy package's mockery generation to use a
centralized `.mockery.yaml` configuration file, aligning with the
pattern used by other packages like querycoordv2.

Changes
- **Makefile**: Replace multiple individual mockery commands with a
single config-based invocation for the `generate-mockery-proxy` target
- **internal/proxy/.mockery.yaml**: Add mockery configuration defining
all mock interfaces for proxy and proxy/shardclient packages
- **Mock files**: Regenerate mocks using the new configuration:
  - `mock_cache.go`: Clean up by removing unused interface methods
  (credential, shard cache, policy methods)
  - `shardclient/mock_lb_balancer.go`: Update type comments
  (nodeInfo → NodeInfo)
  - `shardclient/mock_lb_policy.go`: Update formatting
  - `shardclient/mock_shardclient_manager.go`: Fix parameter naming
  consistency (nodeInfo1 → nodeInfo)
- **task_search_test.go**: Remove obsolete mock expectations for
deprecated cache methods

Benefits
- Centralized mockery configuration for easier maintenance
- Consistent with other packages (querycoordv2, etc.)
- Cleaner mock interfaces by removing unused methods
- Better type consistency in generated mocks

---------

Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
2025-10-22 16:06:06 +08:00


// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package proxy

import (
	"bytes"
	"context"
	"fmt"

	"github.com/samber/lo"
	"go.opentelemetry.io/otel"
	"go.opentelemetry.io/otel/trace"
	"go.uber.org/zap"

	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
	"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
	"github.com/milvus-io/milvus/internal/parser/planparserv2"
	"github.com/milvus-io/milvus/internal/types"
	"github.com/milvus-io/milvus/internal/util/function/rerank"
	"github.com/milvus-io/milvus/internal/util/segcore"
	"github.com/milvus-io/milvus/pkg/v2/log"
	"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
	"github.com/milvus-io/milvus/pkg/v2/proto/planpb"
	"github.com/milvus-io/milvus/pkg/v2/util/commonpbutil"
	"github.com/milvus-io/milvus/pkg/v2/util/merr"
	"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
	"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
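// opMsg carries named intermediate values between pipeline nodes. Each node
// reads its inputs from this map by key and writes its outputs back under its
// own output keys.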
type opMsg map[string]any
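// operator is a single post-processing step executed by a pipeline Node.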
type operator interface {
run(ctx context.Context, span trace.Span, inputs ...any) ([]any, error)
}
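// nodeDef declares a pipeline node: its name, the opMsg keys it consumes and
// produces, the operator to instantiate, and optional operator parameters.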
type nodeDef struct {
name string
inputs []string
outputs []string
params map[string]any
opName string
}
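// Node binds an instantiated operator to the opMsg keys it reads and writes.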
type Node struct {
name string
inputs []string
outputs []string
op operator
}
func (n *Node) unpackInputs(msg opMsg) ([]any, error) {
for _, input := range n.inputs {
if _, ok := msg[input]; !ok {
return nil, fmt.Errorf("Node [%s]'s input %s not found", n.name, input)
}
}
inputs := make([]any, len(n.inputs))
for i, input := range n.inputs {
inputs[i] = msg[input]
}
return inputs, nil
}
func (n *Node) packOutputs(outputs []any, srcMsg opMsg) (opMsg, error) {
msg := srcMsg
if len(outputs) != len(n.outputs) {
return nil, fmt.Errorf("Node [%s] output size not match operator output size", n.name)
}
for i, output := range n.outputs {
msg[output] = outputs[i]
}
return msg, nil
}
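// Run unpacks the node's inputs from msg, executes the operator, and packs the
// operator's return values back into msg under the node's output keys.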
func (n *Node) Run(ctx context.Context, span trace.Span, msg opMsg) (opMsg, error) {
inputs, err := n.unpackInputs(msg)
if err != nil {
return nil, err
}
ret, err := n.op.run(ctx, span, inputs...)
if err != nil {
return nil, err
}
outputs, err := n.packOutputs(ret, msg)
if err != nil {
return nil, err
}
return outputs, nil
}
const (
searchReduceOp = "search_reduce"
hybridSearchReduceOp = "hybrid_search_reduce"
rerankOp = "rerank"
requeryOp = "requery"
organizeOp = "organize"
filterFieldOp = "filter_field"
lambdaOp = "lambda"
)
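// opFactory maps an operator name to its constructor.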
var opFactory = map[string]func(t *searchTask, params map[string]any) (operator, error){
searchReduceOp: newSearchReduceOperator,
hybridSearchReduceOp: newHybridSearchReduceOperator,
rerankOp: newRerankOperator,
organizeOp: newOrganizeOperator,
requeryOp: newRequeryOperator,
lambdaOp: newLambdaOperator,
filterFieldOp: newFilterFieldOperator,
}
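// NewNode builds a Node from its definition, constructing the named operator
// for the given search task.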
func NewNode(info *nodeDef, t *searchTask) (*Node, error) {
n := Node{
name: info.name,
inputs: info.inputs,
outputs: info.outputs,
}
op, err := opFactory[info.opName](t, info.params)
if err != nil {
return nil, err
}
n.op = op
return &n, nil
}
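// searchReduceOperator reduces the per-shard results of a plain (non-hybrid)
// search into a single result set.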
type searchReduceOperator struct {
traceCtx context.Context
primaryFieldSchema *schemapb.FieldSchema
nq int64
topK int64
offset int64
collectionID int64
partitionIDs []int64
queryInfos []*planpb.QueryInfo
}
func newSearchReduceOperator(t *searchTask, _ map[string]any) (operator, error) {
pkField, err := t.schema.GetPkField()
if err != nil {
return nil, err
}
return &searchReduceOperator{
traceCtx: t.TraceCtx(),
primaryFieldSchema: pkField,
nq: t.GetNq(),
topK: t.GetTopk(),
offset: t.GetOffset(),
collectionID: t.GetCollectionID(),
partitionIDs: t.GetPartitionIDs(),
queryInfos: t.queryInfos,
}, nil
}
func (op *searchReduceOperator) run(ctx context.Context, span trace.Span, inputs ...any) ([]any, error) {
_, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "searchReduceOperator")
defer sp.End()
toReduceResults := inputs[0].([]*internalpb.SearchResults)
metricType := getMetricType(toReduceResults)
result, err := reduceResults(
op.traceCtx, toReduceResults, op.nq, op.topK, op.offset,
metricType, op.primaryFieldSchema.GetDataType(), op.queryInfos[0], false, op.collectionID, op.partitionIDs)
if err != nil {
return nil, err
}
return []any{[]*milvuspb.SearchResults{result}, []string{metricType}}, nil
}
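// hybridSearchReduceOperator regroups per-shard hybrid search results by
// sub-request and reduces each sub-request's results independently.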
type hybridSearchReduceOperator struct {
traceCtx context.Context
subReqs []*internalpb.SubSearchRequest
primaryFieldSchema *schemapb.FieldSchema
collectionID int64
partitionIDs []int64
queryInfos []*planpb.QueryInfo
}
func newHybridSearchReduceOperator(t *searchTask, _ map[string]any) (operator, error) {
pkField, err := t.schema.GetPkField()
if err != nil {
return nil, err
}
return &hybridSearchReduceOperator{
traceCtx: t.TraceCtx(),
subReqs: t.GetSubReqs(),
primaryFieldSchema: pkField,
collectionID: t.GetCollectionID(),
partitionIDs: t.GetPartitionIDs(),
queryInfos: t.queryInfos,
}, nil
}
func (op *hybridSearchReduceOperator) run(ctx context.Context, span trace.Span, inputs ...any) ([]any, error) {
_, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "hybridSearchReduceOperator")
defer sp.End()
toReduceResults := inputs[0].([]*internalpb.SearchResults)
// Collect the results of each sub-search, grouped by sub-request index:
// [[shard1, shard2, ...], [shard1, shard2, ...]]
multipleInternalResults := make([][]*internalpb.SearchResults, len(op.subReqs))
for _, searchResult := range toReduceResults {
// skip any result that is not an advanced (hybrid search) result
if !searchResult.GetIsAdvanced() {
continue
}
for _, subResult := range searchResult.GetSubResults() {
// shallow copy
internalResults := &internalpb.SearchResults{
MetricType: subResult.GetMetricType(),
NumQueries: subResult.GetNumQueries(),
TopK: subResult.GetTopK(),
SlicedBlob: subResult.GetSlicedBlob(),
SlicedNumCount: subResult.GetSlicedNumCount(),
SlicedOffset: subResult.GetSlicedOffset(),
IsAdvanced: false,
}
reqIndex := subResult.GetReqIndex()
multipleInternalResults[reqIndex] = append(multipleInternalResults[reqIndex], internalResults)
}
}
multipleMilvusResults := make([]*milvuspb.SearchResults, len(op.subReqs))
searchMetrics := []string{}
for index, internalResults := range multipleInternalResults {
subReq := op.subReqs[index]
// Since the metric type in the request may be empty, it can only be obtained from the result
subMetricType := getMetricType(internalResults)
result, err := reduceResults(
op.traceCtx, internalResults, subReq.GetNq(), subReq.GetTopk(), subReq.GetOffset(), subMetricType,
op.primaryFieldSchema.GetDataType(), op.queryInfos[index], true, op.collectionID, op.partitionIDs)
if err != nil {
return nil, err
}
searchMetrics = append(searchMetrics, subMetricType)
multipleMilvusResults[index] = result
}
return []any{multipleMilvusResults, searchMetrics}, nil
}
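// rerankOperator applies the task's function score to the reduced results and
// produces the final ranked result.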
type rerankOperator struct {
nq int64
topK int64
offset int64
roundDecimal int64
groupByFieldId int64
groupSize int64
strictGroupSize bool
groupScorerStr string
functionScore *rerank.FunctionScore
}
func newRerankOperator(t *searchTask, params map[string]any) (operator, error) {
if t.SearchRequest.GetIsAdvanced() {
return &rerankOperator{
nq: t.GetNq(),
topK: t.rankParams.limit,
offset: t.rankParams.offset,
roundDecimal: t.rankParams.roundDecimal,
groupByFieldId: t.rankParams.groupByFieldId,
groupSize: t.rankParams.groupSize,
strictGroupSize: t.rankParams.strictGroupSize,
groupScorerStr: getGroupScorerStr(t.request.GetSearchParams()),
functionScore: t.functionScore,
}, nil
}
return &rerankOperator{
nq: t.SearchRequest.GetNq(),
topK: t.SearchRequest.GetTopk(),
offset: 0, // for non-hybrid search, the offset has already been applied in the reduce phase
roundDecimal: t.queryInfos[0].RoundDecimal,
groupByFieldId: t.queryInfos[0].GroupByFieldId,
groupSize: t.queryInfos[0].GroupSize,
strictGroupSize: t.queryInfos[0].StrictGroupSize,
groupScorerStr: getGroupScorerStr(t.request.GetSearchParams()),
functionScore: t.functionScore,
}, nil
}
func (op *rerankOperator) run(ctx context.Context, span trace.Span, inputs ...any) ([]any, error) {
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "rerankOperator")
defer sp.End()
reducedResults := inputs[0].([]*milvuspb.SearchResults)
metrics := inputs[1].([]string)
rankInputs := []*milvuspb.SearchResults{}
rankMetrics := []string{}
for idx, ret := range reducedResults {
rankInputs = append(rankInputs, ret)
rankMetrics = append(rankMetrics, metrics[idx])
}
params := rerank.NewSearchParams(op.nq, op.topK, op.offset, op.roundDecimal, op.groupByFieldId,
op.groupSize, op.strictGroupSize, op.groupScorerStr, rankMetrics)
ret, err := op.functionScore.Process(ctx, params, rankInputs)
if err != nil {
return nil, err
}
return []any{ret}, nil
}
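// requeryOperator runs an internal query by primary key to fetch the requested
// output fields for the IDs produced by the search.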
type requeryOperator struct {
traceCtx context.Context
outputFieldNames []string
timestamp uint64
dbName string
collectionName string
notReturnAllMeta bool
partitionNames []string
partitionIDs []int64
primaryFieldSchema *schemapb.FieldSchema
queryChannelsTs map[string]Timestamp
consistencyLevel commonpb.ConsistencyLevel
guaranteeTimestamp uint64
node types.ProxyComponent
}
func newRequeryOperator(t *searchTask, _ map[string]any) (operator, error) {
pkField, err := t.schema.GetPkField()
if err != nil {
return nil, err
}
outputFieldNames := typeutil.NewSet(t.translatedOutputFields...)
if t.SearchRequest.GetIsAdvanced() {
outputFieldNames.Insert(t.functionScore.GetAllInputFieldNames()...)
}
return &requeryOperator{
traceCtx: t.TraceCtx(),
outputFieldNames: outputFieldNames.Collect(),
timestamp: t.BeginTs(),
dbName: t.request.GetDbName(),
collectionName: t.request.GetCollectionName(),
primaryFieldSchema: pkField,
queryChannelsTs: t.queryChannelsTs,
consistencyLevel: t.SearchRequest.GetConsistencyLevel(),
guaranteeTimestamp: t.SearchRequest.GetGuaranteeTimestamp(),
notReturnAllMeta: t.request.GetNotReturnAllMeta(),
partitionNames: t.request.GetPartitionNames(),
partitionIDs: t.SearchRequest.GetPartitionIDs(),
node: t.node,
}, nil
}
func (op *requeryOperator) run(ctx context.Context, span trace.Span, inputs ...any) ([]any, error) {
allIDs := inputs[0].(*schemapb.IDs)
storageCostFromLastOp := inputs[1].(segcore.StorageCost)
if typeutil.GetSizeOfIDs(allIDs) == 0 {
return []any{[]*schemapb.FieldData{}, storageCostFromLastOp}, nil
}
queryResult, storageCost, err := op.requery(ctx, span, allIDs, op.outputFieldNames)
if err != nil {
return nil, err
}
storageCost.ScannedRemoteBytes += storageCostFromLastOp.ScannedRemoteBytes
storageCost.ScannedTotalBytes += storageCostFromLastOp.ScannedTotalBytes
return []any{queryResult.GetFieldsData(), storageCost}, nil
}
func (op *requeryOperator) requery(ctx context.Context, span trace.Span, ids *schemapb.IDs, outputFields []string) (*milvuspb.QueryResults, segcore.StorageCost, error) {
queryReq := &milvuspb.QueryRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Retrieve,
Timestamp: op.timestamp,
},
DbName: op.dbName,
CollectionName: op.collectionName,
ConsistencyLevel: op.consistencyLevel,
NotReturnAllMeta: op.notReturnAllMeta,
Expr: "",
OutputFields: outputFields,
PartitionNames: op.partitionNames,
UseDefaultConsistency: false,
GuaranteeTimestamp: op.guaranteeTimestamp,
}
plan := planparserv2.CreateRequeryPlan(op.primaryFieldSchema, ids)
channelsMvcc := make(map[string]Timestamp)
for k, v := range op.queryChannelsTs {
channelsMvcc[k] = v
}
qt := &queryTask{
ctx: op.traceCtx,
Condition: NewTaskCondition(op.traceCtx),
RetrieveRequest: &internalpb.RetrieveRequest{
Base: commonpbutil.NewMsgBase(
commonpbutil.WithMsgType(commonpb.MsgType_Retrieve),
commonpbutil.WithSourceID(paramtable.GetNodeID()),
),
ReqID: paramtable.GetNodeID(),
PartitionIDs: op.partitionIDs, // use search partitionIDs
ConsistencyLevel: op.consistencyLevel,
},
request: queryReq,
plan: plan,
mixCoord: op.node.(*Proxy).mixCoord,
lb: op.node.(*Proxy).lbPolicy,
shardclientMgr: op.node.(*Proxy).shardMgr,
channelsMvcc: channelsMvcc,
fastSkip: true,
reQuery: true,
}
queryResult, storageCost, err := op.node.(*Proxy).query(op.traceCtx, qt, span)
if err != nil {
return nil, segcore.StorageCost{}, err
}
if queryResult.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
return nil, segcore.StorageCost{}, merr.Error(queryResult.GetStatus())
}
return queryResult, storageCost, nil
}
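// organizeOperator reorders the requeried field data so that it matches the ID
// order of each result set.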
type organizeOperator struct {
traceCtx context.Context
primaryFieldSchema *schemapb.FieldSchema
collectionID int64
}
func newOrganizeOperator(t *searchTask, _ map[string]any) (operator, error) {
pkField, err := t.schema.GetPkField()
if err != nil {
return nil, err
}
return &organizeOperator{
traceCtx: t.TraceCtx(),
primaryFieldSchema: pkField,
collectionID: t.SearchRequest.GetCollectionID(),
}, nil
}
func (op *organizeOperator) emptyFieldDataAccordingFieldSchema(fieldData *schemapb.FieldData) *schemapb.FieldData {
ret := &schemapb.FieldData{
Type: fieldData.Type,
FieldName: fieldData.FieldName,
FieldId: fieldData.FieldId,
IsDynamic: fieldData.IsDynamic,
ValidData: make([]bool, 0),
}
if fieldData.Type == schemapb.DataType_FloatVector ||
fieldData.Type == schemapb.DataType_BinaryVector ||
fieldData.Type == schemapb.DataType_BFloat16Vector ||
fieldData.Type == schemapb.DataType_Float16Vector ||
fieldData.Type == schemapb.DataType_Int8Vector {
ret.Field = &schemapb.FieldData_Vectors{
Vectors: &schemapb.VectorField{
Dim: fieldData.GetVectors().GetDim(),
},
}
}
return ret
}
func (op *organizeOperator) run(ctx context.Context, span trace.Span, inputs ...any) ([]any, error) {
_, sp := otel.Tracer(typeutil.ProxyRole).Start(op.traceCtx, "organizeOperator")
defer sp.End()
fields := inputs[0].([]*schemapb.FieldData)
var idsList []*schemapb.IDs
switch inputs[1].(type) {
case *schemapb.IDs:
idsList = []*schemapb.IDs{inputs[1].(*schemapb.IDs)}
case []*schemapb.IDs:
idsList = inputs[1].([]*schemapb.IDs)
default:
panic(fmt.Sprintf("invalid ids type: %T", inputs[1]))
}
if len(fields) == 0 {
emptyFields := make([][]*schemapb.FieldData, len(idsList))
return []any{emptyFields}, nil
}
pkFieldData, err := typeutil.GetPrimaryFieldData(fields, op.primaryFieldSchema)
if err != nil {
return nil, err
}
offsets := make(map[any]int)
pkItr := typeutil.GetDataIterator(pkFieldData)
for i := 0; i < typeutil.GetPKSize(pkFieldData); i++ {
pk := pkItr(i)
offsets[pk] = i
}
allFieldData := make([][]*schemapb.FieldData, len(idsList))
for idx, ids := range idsList {
if typeutil.GetSizeOfIDs(ids) == 0 {
emptyFields := []*schemapb.FieldData{}
for _, field := range fields {
emptyFields = append(emptyFields, op.emptyFieldDataAccordingFieldSchema(field))
}
allFieldData[idx] = emptyFields
continue
}
if fieldData, err := pickFieldData(ids, offsets, fields, op.collectionID); err != nil {
return nil, err
} else {
allFieldData[idx] = fieldData
}
}
return []any{allFieldData}, nil
}
func pickFieldData(ids *schemapb.IDs, pkOffset map[any]int, fields []*schemapb.FieldData, collectionID int64) ([]*schemapb.FieldData, error) {
// Reorganize Results. The order of query result ids will be altered and differ from queried ids.
// We should reorganize query results to keep the order of original queried ids. For example:
// ===========================================
// 3 2 5 4 1 (query ids)
// ||
// || (query)
// \/
// 4 3 5 1 2 (result ids)
// v4 v3 v5 v1 v2 (result vectors)
// ||
// || (reorganize)
// \/
// 3 2 5 4 1 (result ids)
// v3 v2 v5 v4 v1 (result vectors)
// ===========================================
fieldsData := make([]*schemapb.FieldData, len(fields))
for i := 0; i < typeutil.GetSizeOfIDs(ids); i++ {
id := typeutil.GetPK(ids, int64(i))
if _, ok := pkOffset[id]; !ok {
return nil, merr.WrapErrInconsistentRequery(fmt.Sprintf("incomplete query result, missing id %v, len(searchIDs) = %d, len(queryIDs) = %d, collection=%d",
id, typeutil.GetSizeOfIDs(ids), len(pkOffset), collectionID))
}
typeutil.AppendFieldData(fieldsData, fields, int64(pkOffset[id]))
}
return fieldsData, nil
}
const (
lambdaParamKey = "lambda"
)
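// lambdaOperator wraps an inline function supplied via the node's params; it is
// used for small glue steps between operators.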
type lambdaOperator struct {
f func(ctx context.Context, span trace.Span, inputs ...any) ([]any, error)
}
func newLambdaOperator(_ *searchTask, params map[string]any) (operator, error) {
return &lambdaOperator{
f: params[lambdaParamKey].(func(ctx context.Context, span trace.Span, inputs ...any) ([]any, error)),
}, nil
}
func (op *lambdaOperator) run(ctx context.Context, span trace.Span, inputs ...any) ([]any, error) {
return op.f(ctx, span, inputs...)
}
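// filterFieldOperator fills in field metadata (name, type, dynamic flag) from
// the collection schema and keeps only the requested output fields.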
type filterFieldOperator struct {
outputFieldNames []string
fieldSchemas []*schemapb.FieldSchema
}
func newFilterFieldOperator(t *searchTask, _ map[string]any) (operator, error) {
return &filterFieldOperator{
outputFieldNames: t.translatedOutputFields,
fieldSchemas: typeutil.GetAllFieldSchemas(t.schema.CollectionSchema),
}, nil
}
func (op *filterFieldOperator) run(ctx context.Context, span trace.Span, inputs ...any) ([]any, error) {
result := inputs[0].(*milvuspb.SearchResults)
for _, retField := range result.Results.FieldsData {
for _, fieldSchema := range op.fieldSchemas {
if retField != nil && retField.FieldId == fieldSchema.FieldID {
retField.FieldName = fieldSchema.Name
retField.Type = fieldSchema.DataType
retField.IsDynamic = fieldSchema.IsDynamic
}
}
}
result.Results.FieldsData = lo.Filter(result.Results.FieldsData, func(field *schemapb.FieldData, _ int) bool {
return lo.Contains(op.outputFieldNames, field.FieldName)
})
return []any{result}, nil
}
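// mergeIDsFunc merges the IDs from multiple reduced results into a single
// deduplicated ID set, used to build the requery request.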
func mergeIDsFunc(ctx context.Context, span trace.Span, inputs ...any) ([]any, error) {
multipleMilvusResults := inputs[0].([]*milvuspb.SearchResults)
idInt64Type := false
idsList := lo.FilterMap(multipleMilvusResults, func(m *milvuspb.SearchResults, _ int) (*schemapb.IDs, bool) {
if m.GetResults().GetIds().GetIntId() != nil {
idInt64Type = true
}
return m.Results.Ids, true
})
uniqueIDs := &schemapb.IDs{}
if idInt64Type {
idsSet := typeutil.NewSet[int64]()
for _, ids := range idsList {
if data := ids.GetIntId().GetData(); data != nil {
idsSet.Insert(data...)
}
}
uniqueIDs.IdField = &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: idsSet.Collect(),
},
}
} else {
idsSet := typeutil.NewSet[string]()
for _, ids := range idsList {
if data := ids.GetStrId().GetData(); data != nil {
idsSet.Insert(data...)
}
}
uniqueIDs.IdField = &schemapb.IDs_StrId{
StrId: &schemapb.StringArray{
Data: idsSet.Collect(),
},
}
}
return []any{uniqueIDs}, nil
}
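// pipeline is an ordered list of nodes that post-processes shard search results
// into the final search response.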
type pipeline struct {
name string
nodes []*Node
}
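// newPipeline instantiates every node of a pipeline definition for the given search task.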
func newPipeline(pipeDef *pipelineDef, t *searchTask) (*pipeline, error) {
nodes := make([]*Node, len(pipeDef.nodes))
for i, def := range pipeDef.nodes {
node, err := NewNode(def, t)
if err != nil {
return nil, err
}
nodes[i] = node
}
return &pipeline{name: pipeDef.name, nodes: nodes}, nil
}
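// Run executes the nodes in order over a shared opMsg seeded with the shard
// results ("input") and storage cost, returning the value stored under "output"
// together with the accumulated storage cost.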
func (p *pipeline) Run(ctx context.Context, span trace.Span, toReduceResults []*internalpb.SearchResults, storageCost segcore.StorageCost) (*milvuspb.SearchResults, segcore.StorageCost, error) {
log.Ctx(ctx).Debug("SearchPipeline run", zap.String("pipeline", p.String()))
msg := opMsg{}
msg["input"] = toReduceResults
msg["storage_cost"] = storageCost
for _, node := range p.nodes {
var err error
log.Ctx(ctx).Debug("SearchPipeline run node", zap.String("node", node.name))
msg, err = node.Run(ctx, span, msg)
if err != nil {
log.Ctx(ctx).Error("Run node failed: ", zap.String("err", err.Error()))
return nil, storageCost, err
}
}
return msg["output"].(*milvuspb.SearchResults), msg["storage_cost"].(segcore.StorageCost), nil
}
func (p *pipeline) String() string {
buf := bytes.NewBufferString(fmt.Sprintf("SearchPipeline: %s", p.name))
for _, node := range p.nodes {
buf.WriteString(fmt.Sprintf(" %s -> %s", node.name, node.outputs))
}
return buf.String()
}
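// pipelineDef is the static description of a pipeline: a name plus the node
// definitions to instantiate for a concrete search task.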
type pipelineDef struct {
name string
nodes []*nodeDef
}
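// searchPipe handles a plain search without requery or rerank: reduce the shard
// results, pick the single reduced result, then filter the output fields.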
var searchPipe = &pipelineDef{
name: "search",
nodes: []*nodeDef{
{
name: "reduce",
inputs: []string{"input", "storage_cost"},
outputs: []string{"reduced", "metrics"},
opName: searchReduceOp,
},
{
name: "pick",
inputs: []string{"reduced"},
outputs: []string{"result"},
params: map[string]any{
lambdaParamKey: func(ctx context.Context, span trace.Span, inputs ...any) ([]any, error) {
result := inputs[0].([]*milvuspb.SearchResults)[0]
return []any{result}, nil
},
},
opName: lambdaOp,
},
{
name: "filter_field",
inputs: []string{"result"},
outputs: []string{"output"},
opName: filterFieldOp,
},
},
}
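// searchWithRequeryPipe handles a plain search that needs a requery: the merged
// result IDs are queried again to fetch the output fields, which are then
// reorganized to match the reduced result order.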
var searchWithRequeryPipe = &pipelineDef{
name: "searchWithRequery",
nodes: []*nodeDef{
{
name: "reduce",
inputs: []string{"input", "storage_cost"},
outputs: []string{"reduced", "metrics"},
opName: searchReduceOp,
},
{
name: "merge",
inputs: []string{"reduced"},
outputs: []string{"unique_ids"},
opName: lambdaOp,
params: map[string]any{
lambdaParamKey: mergeIDsFunc,
},
},
{
name: "requery",
inputs: []string{"unique_ids", "storage_cost"},
outputs: []string{"fields", "storage_cost"},
opName: requeryOp,
},
{
name: "gen_ids",
inputs: []string{"reduced"},
outputs: []string{"ids"},
opName: lambdaOp,
params: map[string]any{
lambdaParamKey: func(ctx context.Context, span trace.Span, inputs ...any) ([]any, error) {
return []any{[]*schemapb.IDs{inputs[0].([]*milvuspb.SearchResults)[0].Results.Ids}}, nil
},
},
},
{
name: "organize",
inputs: []string{"fields", "ids"},
outputs: []string{"organized_fields"},
opName: organizeOp,
},
{
name: "pick",
inputs: []string{"reduced", "organized_fields"},
outputs: []string{"result"},
params: map[string]any{
lambdaParamKey: func(ctx context.Context, span trace.Span, inputs ...any) ([]any, error) {
result := inputs[0].([]*milvuspb.SearchResults)[0]
fields := inputs[1].([][]*schemapb.FieldData)
result.Results.FieldsData = fields[0]
return []any{result}, nil
},
},
opName: lambdaOp,
},
{
name: "filter_field",
inputs: []string{"result"},
outputs: []string{"output"},
opName: filterFieldOp,
},
},
}
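// searchWithRerankPipe handles a plain search with a function-score rerank but
// no requery: the reduced result is reranked, then its field data is
// reorganized to match the reranked IDs.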
var searchWithRerankPipe = &pipelineDef{
name: "searchWithRerank",
nodes: []*nodeDef{
{
name: "reduce",
inputs: []string{"input", "storage_cost"},
outputs: []string{"reduced", "metrics"},
opName: searchReduceOp,
},
{
name: "rerank",
inputs: []string{"reduced", "metrics"},
outputs: []string{"rank_result"},
opName: rerankOp,
},
{
name: "pick",
inputs: []string{"reduced", "rank_result"},
outputs: []string{"ids", "fields"},
params: map[string]any{
lambdaParamKey: func(ctx context.Context, span trace.Span, inputs ...any) ([]any, error) {
return []any{
inputs[0].([]*milvuspb.SearchResults)[0].Results.FieldsData,
[]*schemapb.IDs{inputs[1].(*milvuspb.SearchResults).Results.Ids},
}, nil
},
},
opName: lambdaOp,
},
{
name: "organize",
inputs: []string{"ids", "fields"},
outputs: []string{"organized_fields"},
opName: organizeOp,
},
{
name: "result",
inputs: []string{"rank_result", "organized_fields"},
outputs: []string{"result"},
params: map[string]any{
lambdaParamKey: func(ctx context.Context, span trace.Span, inputs ...any) ([]any, error) {
result := inputs[0].(*milvuspb.SearchResults)
fields := inputs[1].([][]*schemapb.FieldData)
result.Results.FieldsData = fields[0]
return []any{result}, nil
},
},
opName: lambdaOp,
},
{
name: "filter_field",
inputs: []string{"result"},
outputs: []string{"output"},
opName: filterFieldOp,
},
},
}
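// searchWithRerankRequeryPipe handles a plain search with both a function-score
// rerank and a requery: the reduced result is reranked first, and only the
// reranked IDs are requeried for output fields.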
var searchWithRerankRequeryPipe = &pipelineDef{
name: "searchWithRerankRequery",
nodes: []*nodeDef{
{
name: "reduce",
inputs: []string{"input", "storage_cost"},
outputs: []string{"reduced", "metrics"},
opName: searchReduceOp,
},
{
name: "rerank",
inputs: []string{"reduced", "metrics"},
outputs: []string{"rank_result"},
opName: rerankOp,
},
{
name: "pick_ids",
inputs: []string{"rank_result"},
outputs: []string{"ids"},
params: map[string]any{
lambdaParamKey: func(ctx context.Context, span trace.Span, inputs ...any) ([]any, error) {
return []any{
inputs[0].(*milvuspb.SearchResults).Results.Ids,
}, nil
},
},
opName: lambdaOp,
},
{
name: "requery",
inputs: []string{"ids", "storage_cost"},
outputs: []string{"fields", "storage_cost"},
opName: requeryOp,
},
{
name: "to_ids_list",
inputs: []string{"ids"},
outputs: []string{"ids"},
opName: lambdaOp,
params: map[string]any{
lambdaParamKey: func(ctx context.Context, span trace.Span, inputs ...any) ([]any, error) {
return []any{[]*schemapb.IDs{inputs[0].(*schemapb.IDs)}}, nil
},
},
},
{
name: "organize",
inputs: []string{"fields", "ids"},
outputs: []string{"organized_fields"},
opName: organizeOp,
},
{
name: "result",
inputs: []string{"rank_result", "organized_fields"},
outputs: []string{"result"},
params: map[string]any{
lambdaParamKey: func(ctx context.Context, span trace.Span, inputs ...any) ([]any, error) {
result := inputs[0].(*milvuspb.SearchResults)
fields := inputs[1].([][]*schemapb.FieldData)
result.Results.FieldsData = fields[0]
return []any{result}, nil
},
},
opName: lambdaOp,
},
{
name: "filter_field",
inputs: []string{"result"},
outputs: []string{"output"},
opName: filterFieldOp,
},
},
}
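// hybridSearchPipe handles a hybrid search without requery: each sub-request is
// reduced separately, the sub-results are reranked into one result, and the
// output fields are filtered.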
var hybridSearchPipe = &pipelineDef{
name: "hybridSearchPipe",
nodes: []*nodeDef{
{
name: "reduce",
inputs: []string{"input", "storage_cost"},
outputs: []string{"reduced", "metrics"},
opName: hybridSearchReduceOp,
},
{
name: "rerank",
inputs: []string{"reduced", "metrics"},
outputs: []string{"result"},
opName: rerankOp,
},
{
name: "filter_field",
inputs: []string{"result"},
outputs: []string{"output"},
opName: filterFieldOp,
},
},
}
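// hybridSearchWithRequeryAndRerankByFieldDataPipe handles a hybrid search whose
// rerank function needs field data: all candidate IDs are requeried first so the
// rerank can read field values, then the final result is organized from the
// reranked IDs.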
var hybridSearchWithRequeryAndRerankByFieldDataPipe = &pipelineDef{
name: "hybridSearchWithRequeryAndRerankByDataPipe",
nodes: []*nodeDef{
{
name: "reduce",
inputs: []string{"input", "storage_cost"},
outputs: []string{"reduced", "metrics"},
opName: hybridSearchReduceOp,
},
{
name: "merge_ids",
inputs: []string{"reduced"},
outputs: []string{"ids"},
opName: lambdaOp,
params: map[string]any{
lambdaParamKey: mergeIDsFunc,
},
},
{
name: "requery",
inputs: []string{"ids", "storage_cost"},
outputs: []string{"fields", "storage_cost"},
opName: requeryOp,
},
{
name: "parse_ids",
inputs: []string{"reduced"},
outputs: []string{"id_list"},
opName: lambdaOp,
params: map[string]any{
lambdaParamKey: func(ctx context.Context, span trace.Span, inputs ...any) ([]any, error) {
multipleMilvusResults := inputs[0].([]*milvuspb.SearchResults)
idsList := lo.FilterMap(multipleMilvusResults, func(m *milvuspb.SearchResults, _ int) (*schemapb.IDs, bool) {
return m.Results.Ids, true
})
return []any{idsList}, nil
},
},
},
{
name: "organize_rank_data",
inputs: []string{"fields", "id_list"},
outputs: []string{"organized_fields"},
opName: organizeOp,
},
{
name: "gen_rank_data",
inputs: []string{"reduced", "organized_fields"},
outputs: []string{"rank_data"},
opName: lambdaOp,
params: map[string]any{
lambdaParamKey: func(ctx context.Context, span trace.Span, inputs ...any) ([]any, error) {
results := inputs[0].([]*milvuspb.SearchResults)
fields := inputs[1].([][]*schemapb.FieldData)
for i := 0; i < len(results); i++ {
results[i].Results.FieldsData = fields[i]
}
return []any{results}, nil
},
},
},
{
name: "rerank",
inputs: []string{"rank_data", "metrics"},
outputs: []string{"rank_result"},
opName: rerankOp,
},
{
name: "pick_ids",
inputs: []string{"rank_result"},
outputs: []string{"ids"},
opName: lambdaOp,
params: map[string]any{
lambdaParamKey: func(ctx context.Context, span trace.Span, inputs ...any) ([]any, error) {
return []any{[]*schemapb.IDs{inputs[0].(*milvuspb.SearchResults).Results.Ids}}, nil
},
},
},
{
name: "organize_result",
inputs: []string{"fields", "ids"},
outputs: []string{"organized_fields"},
opName: organizeOp,
},
{
name: "result",
inputs: []string{"rank_result", "organized_fields"},
outputs: []string{"result"},
params: map[string]any{
lambdaParamKey: func(ctx context.Context, span trace.Span, inputs ...any) ([]any, error) {
result := inputs[0].(*milvuspb.SearchResults)
fields := inputs[1].([][]*schemapb.FieldData)
result.Results.FieldsData = fields[0]
return []any{result}, nil
},
},
opName: lambdaOp,
},
{
name: "filter_field",
inputs: []string{"result"},
outputs: []string{"output"},
opName: filterFieldOp,
},
},
}
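// hybridSearchWithRequeryPipe handles a hybrid search whose rerank does not need
// field data: the sub-results are reranked first and only the final IDs are
// requeried, which keeps the memory overhead lower.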
var hybridSearchWithRequeryPipe = &pipelineDef{
name: "hybridSearchWithRequeryPipe",
nodes: []*nodeDef{
{
name: "reduce",
inputs: []string{"input", "storage_cost"},
outputs: []string{"reduced", "metrics"},
opName: hybridSearchReduceOp,
},
{
name: "rerank",
inputs: []string{"reduced", "metrics"},
outputs: []string{"rank_result"},
opName: rerankOp,
},
{
name: "pick_ids",
inputs: []string{"rank_result"},
outputs: []string{"ids"},
params: map[string]any{
lambdaParamKey: func(ctx context.Context, span trace.Span, inputs ...any) ([]any, error) {
return []any{
inputs[0].(*milvuspb.SearchResults).Results.Ids,
}, nil
},
},
opName: lambdaOp,
},
{
name: "requery",
inputs: []string{"ids", "storage_cost"},
outputs: []string{"fields", "storage_cost"},
opName: requeryOp,
},
{
name: "organize",
inputs: []string{"fields", "ids"},
outputs: []string{"organized_fields"},
opName: organizeOp,
},
{
name: "result",
inputs: []string{"rank_result", "organized_fields"},
outputs: []string{"result"},
params: map[string]any{
lambdaParamKey: func(ctx context.Context, span trace.Span, inputs ...any) ([]any, error) {
result := inputs[0].(*milvuspb.SearchResults)
fields := inputs[1].([][]*schemapb.FieldData)
result.Results.FieldsData = fields[0]
return []any{result}, nil
},
},
opName: lambdaOp,
},
{
name: "filter_field",
inputs: []string{"result"},
outputs: []string{"output"},
opName: filterFieldOp,
},
},
}
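// newBuiltInPipeline selects the built-in pipeline definition that matches the
// search task: plain vs. advanced (hybrid) search, with or without requery, and
// with or without a function-score rerank.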
func newBuiltInPipeline(t *searchTask) (*pipeline, error) {
if !t.SearchRequest.GetIsAdvanced() && !t.needRequery && t.functionScore == nil {
return newPipeline(searchPipe, t)
}
if !t.SearchRequest.GetIsAdvanced() && t.needRequery && t.functionScore == nil {
return newPipeline(searchWithRequeryPipe, t)
}
if !t.SearchRequest.GetIsAdvanced() && !t.needRequery && t.functionScore != nil {
return newPipeline(searchWithRerankPipe, t)
}
if !t.SearchRequest.GetIsAdvanced() && t.needRequery && t.functionScore != nil {
return newPipeline(searchWithRerankRequeryPipe, t)
}
if t.SearchRequest.GetIsAdvanced() && !t.needRequery {
return newPipeline(hybridSearchPipe, t)
}
if t.SearchRequest.GetIsAdvanced() && t.needRequery {
if len(t.functionScore.GetAllInputFieldIDs()) > 0 {
// When the function score needs field data, we must requery to fetch the field data before reranking.
// The requery fetches the field data of all search results,
// so there is some memory overhead.
return newPipeline(hybridSearchWithRequeryAndRerankByFieldDataPipe, t)
} else {
// Otherwise, we can rerank first and limit the requery to the final result IDs,
// so the memory overhead is lower than with hybridSearchWithRequeryAndRerankByFieldDataPipe.
return newPipeline(hybridSearchWithRequeryPipe, t)
}
}
return nil, fmt.Errorf("Unsupported pipeline")
}