diff --git a/internal/proxy/impl.go b/internal/proxy/impl.go index 66a9e604eb..34a7006596 100644 --- a/internal/proxy/impl.go +++ b/internal/proxy/impl.go @@ -3111,7 +3111,6 @@ func (node *Proxy) search(ctx context.Context, request *milvuspb.SearchRequest, lb: node.lbPolicy, enableMaterializedView: node.enableMaterializedView, mustUsePartitionKey: Params.ProxyCfg.MustUsePartitionKey.GetAsBool(), - requeryFunc: requeryImpl, } log := log.Ctx(ctx).With( // TODO: it might cause some cpu consumption @@ -3354,7 +3353,6 @@ func (node *Proxy) hybridSearch(ctx context.Context, request *milvuspb.HybridSea node: node, lb: node.lbPolicy, mustUsePartitionKey: Params.ProxyCfg.MustUsePartitionKey.GetAsBool(), - requeryFunc: requeryImpl, } log := log.Ctx(ctx).With( diff --git a/internal/proxy/search_pipeline.go b/internal/proxy/search_pipeline.go new file mode 100644 index 0000000000..752648c29d --- /dev/null +++ b/internal/proxy/search_pipeline.go @@ -0,0 +1,1017 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proxy + +import ( + "bytes" + "context" + "fmt" + + "github.com/samber/lo" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/trace" + "go.uber.org/zap" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/parser/planparserv2" + "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/internal/util/function/rerank" + "github.com/milvus-io/milvus/pkg/v2/log" + "github.com/milvus-io/milvus/pkg/v2/proto/internalpb" + "github.com/milvus-io/milvus/pkg/v2/proto/planpb" + "github.com/milvus-io/milvus/pkg/v2/util/commonpbutil" + "github.com/milvus-io/milvus/pkg/v2/util/merr" + "github.com/milvus-io/milvus/pkg/v2/util/paramtable" + "github.com/milvus-io/milvus/pkg/v2/util/typeutil" +) + +type opMsg map[string]any + +type operator interface { + run(ctx context.Context, span trace.Span, inputs ...any) ([]any, error) +} + +type nodeDef struct { + name string + inputs []string + outputs []string + params map[string]any + opName string +} + +type Node struct { + name string + inputs []string + outputs []string + + op operator +} + +func (n *Node) unpackInputs(msg opMsg) ([]any, error) { + for _, input := range n.inputs { + if _, ok := msg[input]; !ok { + return nil, fmt.Errorf("Node [%s]'s input %s not found", n.name, input) + } + } + inputs := make([]any, len(n.inputs)) + for i, input := range n.inputs { + inputs[i] = msg[input] + } + return inputs, nil +} + +func (n *Node) packOutputs(outputs []any, srcMsg opMsg) (opMsg, error) { + msg := srcMsg + if len(outputs) != len(n.outputs) { + return nil, fmt.Errorf("Node [%s] output size not match operator output size", n.name) + } + for 
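+	// A minimal sketch of the opMsg contract that unpackInputs/packOutputs
+	// maintain (names taken from this file; the wiring itself lives in the
+	// pipelineDef tables below):
+	//
+	//	msg := opMsg{"input": toReduceResults}            // seeded by pipeline.Run
+	//	msg, _ = node.Run(ctx, span, msg)                 // reads node.inputs, writes node.outputs
+	//	result := msg["output"].(*milvuspb.SearchResults) // the final node fills "output"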
i, output := range n.outputs { + msg[output] = outputs[i] + } + return msg, nil +} + +func (n *Node) Run(ctx context.Context, span trace.Span, msg opMsg) (opMsg, error) { + inputs, err := n.unpackInputs(msg) + if err != nil { + return nil, err + } + ret, err := n.op.run(ctx, span, inputs...) + if err != nil { + return nil, err + } + outputs, err := n.packOutputs(ret, msg) + if err != nil { + return nil, err + } + return outputs, nil +} + +const ( + searchReduceOp = "search_reduce" + hybridSearchReduceOp = "hybrid_search_reduce" + rerankOp = "rerank" + requeryOp = "requery" + organizeOp = "organize" + filterFieldOp = "filter_field" + lambdaOp = "lambda" +) + +var opFactory = map[string]func(t *searchTask, params map[string]any) (operator, error){ + searchReduceOp: newSearchReduceOperator, + hybridSearchReduceOp: newHybridSearchReduceOperator, + rerankOp: newRerankOperator, + organizeOp: newOrganizeOperator, + requeryOp: newRequeryOperator, + lambdaOp: newLambdaOperator, + filterFieldOp: newFilterFieldOperator, +} + +func NewNode(info *nodeDef, t *searchTask) (*Node, error) { + n := Node{ + name: info.name, + inputs: info.inputs, + outputs: info.outputs, + } + op, err := opFactory[info.opName](t, info.params) + if err != nil { + return nil, err + } + n.op = op + return &n, nil +} + +type searchReduceOperator struct { + traceCtx context.Context + primaryFieldSchema *schemapb.FieldSchema + nq int64 + topK int64 + offset int64 + collectionID int64 + partitionIDs []int64 + queryInfos []*planpb.QueryInfo +} + +func newSearchReduceOperator(t *searchTask, _ map[string]any) (operator, error) { + pkField, err := t.schema.GetPkField() + if err != nil { + return nil, err + } + return &searchReduceOperator{ + traceCtx: t.TraceCtx(), + primaryFieldSchema: pkField, + nq: t.GetNq(), + topK: t.GetTopk(), + offset: t.GetOffset(), + collectionID: t.GetCollectionID(), + partitionIDs: t.GetPartitionIDs(), + queryInfos: t.queryInfos, + }, nil +} + +func (op *searchReduceOperator) run(ctx context.Context, span trace.Span, inputs ...any) ([]any, error) { + _, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "searchReduceOperator") + defer sp.End() + toReduceResults := inputs[0].([]*internalpb.SearchResults) + metricType := getMetricType(toReduceResults) + result, err := reduceResults( + op.traceCtx, toReduceResults, op.nq, op.topK, op.offset, + metricType, op.primaryFieldSchema.GetDataType(), op.queryInfos[0], false, op.collectionID, op.partitionIDs) + if err != nil { + return nil, err + } + return []any{[]*milvuspb.SearchResults{result}, []string{metricType}}, nil +} + +type hybridSearchReduceOperator struct { + traceCtx context.Context + subReqs []*internalpb.SubSearchRequest + primaryFieldSchema *schemapb.FieldSchema + collectionID int64 + partitionIDs []int64 + queryInfos []*planpb.QueryInfo +} + +func newHybridSearchReduceOperator(t *searchTask, _ map[string]any) (operator, error) { + pkField, err := t.schema.GetPkField() + if err != nil { + return nil, err + } + return &hybridSearchReduceOperator{ + traceCtx: t.TraceCtx(), + subReqs: t.GetSubReqs(), + primaryFieldSchema: pkField, + collectionID: t.GetCollectionID(), + partitionIDs: t.GetPartitionIDs(), + queryInfos: t.queryInfos, + }, nil +} + +func (op *hybridSearchReduceOperator) run(ctx context.Context, span trace.Span, inputs ...any) ([]any, error) { + _, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "hybridSearchReduceOperator") + defer sp.End() + toReduceResults := inputs[0].([]*internalpb.SearchResults) + // Collecting the results of a subsearch + 
// [[shard1, shard2, ...],[shard1, shard2, ...]]
+	multipleInternalResults := make([][]*internalpb.SearchResults, len(op.subReqs))
+	for _, searchResult := range toReduceResults {
+		// if we get a non-advanced result, skip all
+		if !searchResult.GetIsAdvanced() {
+			continue
+		}
+		for _, subResult := range searchResult.GetSubResults() {
+			// shallow copy
+			internalResults := &internalpb.SearchResults{
+				MetricType:     subResult.GetMetricType(),
+				NumQueries:     subResult.GetNumQueries(),
+				TopK:           subResult.GetTopK(),
+				SlicedBlob:     subResult.GetSlicedBlob(),
+				SlicedNumCount: subResult.GetSlicedNumCount(),
+				SlicedOffset:   subResult.GetSlicedOffset(),
+				IsAdvanced:     false,
+			}
+			reqIndex := subResult.GetReqIndex()
+			multipleInternalResults[reqIndex] = append(multipleInternalResults[reqIndex], internalResults)
+		}
+	}
+
+	multipleMilvusResults := make([]*milvuspb.SearchResults, len(op.subReqs))
+	searchMetrics := []string{}
+	for index, internalResults := range multipleInternalResults {
+		subReq := op.subReqs[index]
+		// Since the metric type in the request may be empty, it can only be obtained from the result
+		subMetricType := getMetricType(internalResults)
+		result, err := reduceResults(
+			op.traceCtx, internalResults, subReq.GetNq(), subReq.GetTopk(), subReq.GetOffset(), subMetricType,
+			op.primaryFieldSchema.GetDataType(), op.queryInfos[index], true, op.collectionID, op.partitionIDs)
+		if err != nil {
+			return nil, err
+		}
+
+		searchMetrics = append(searchMetrics, subMetricType)
+		multipleMilvusResults[index] = result
+	}
+	return []any{multipleMilvusResults, searchMetrics}, nil
+}
+
+type rerankOperator struct {
+	nq              int64
+	topK            int64
+	offset          int64
+	roundDecimal    int64
+	groupByFieldId  int64
+	groupSize       int64
+	strictGroupSize bool
+	groupScorerStr  string
+
+	functionScore *rerank.FunctionScore
+}
+
+func newRerankOperator(t *searchTask, _ map[string]any) (operator, error) {
+	if t.SearchRequest.GetIsAdvanced() {
+		return &rerankOperator{
+			nq:              t.GetNq(),
+			topK:            t.rankParams.limit,
+			offset:          t.rankParams.offset,
+			roundDecimal:    t.rankParams.roundDecimal,
+			groupByFieldId:  t.rankParams.groupByFieldId,
+			groupSize:       t.rankParams.groupSize,
+			strictGroupSize: t.rankParams.strictGroupSize,
+			groupScorerStr:  getGroupScorerStr(t.request.GetSearchParams()),
+			functionScore:   t.functionScore,
+		}, nil
+	}
+	return &rerankOperator{
+		nq:              t.SearchRequest.GetNq(),
+		topK:            t.SearchRequest.GetTopk(),
+		offset:          0, // Search performs Offset in the reduce phase
+		roundDecimal:    t.queryInfos[0].RoundDecimal,
+		groupByFieldId:  t.queryInfos[0].GroupByFieldId,
+		groupSize:       t.queryInfos[0].GroupSize,
+		strictGroupSize: t.queryInfos[0].StrictGroupSize,
+		groupScorerStr:  getGroupScorerStr(t.request.GetSearchParams()),
+		functionScore:   t.functionScore,
+	}, nil
+}
+
+func (op *rerankOperator) run(ctx context.Context, span trace.Span, inputs ...any) ([]any, error) {
+	ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "rerankOperator")
+	defer sp.End()
+
+	reducedResults := inputs[0].([]*milvuspb.SearchResults)
+	metrics := inputs[1].([]string)
+	rankInputs := []*milvuspb.SearchResults{}
+	rankMetrics := []string{}
+	for idx, ret := range reducedResults {
+		if typeutil.GetSizeOfIDs(ret.Results.Ids) == 0 {
+			continue
+		}
+		rankInputs = append(rankInputs, ret)
+		rankMetrics = append(rankMetrics, metrics[idx])
+	}
+	params := rerank.NewSearchParams(op.nq, op.topK, op.offset, op.roundDecimal, op.groupByFieldId,
+		op.groupSize, op.strictGroupSize, op.groupScorerStr, rankMetrics)
+	ret, err := op.functionScore.Process(ctx,
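+		// params carries nq/topK/offset and the grouping knobs together with the
+		// per-subquery metric types collected above; rankInputs keeps only the
+		// sub-results that have at least one hit, so the function score never
+		// sees empty ID columns.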
params, rankInputs) + if err != nil { + return nil, err + } + return []any{ret}, nil +} + +type requeryOperator struct { + traceCtx context.Context + outputFieldNames []string + + timestamp uint64 + dbName string + collectionName string + notReturnAllMeta bool + partitionNames []string + partitionIDs []int64 + primaryFieldSchema *schemapb.FieldSchema + queryChannelsTs map[string]Timestamp + consistencyLevel commonpb.ConsistencyLevel + guaranteeTimestamp uint64 + + node types.ProxyComponent +} + +func newRequeryOperator(t *searchTask, _ map[string]any) (operator, error) { + pkField, err := t.schema.GetPkField() + if err != nil { + return nil, err + } + outputFieldNames := typeutil.NewSet[string](t.translatedOutputFields...) + if t.SearchRequest.GetIsAdvanced() { + outputFieldNames.Insert(t.functionScore.GetAllInputFieldNames()...) + } + return &requeryOperator{ + traceCtx: t.TraceCtx(), + outputFieldNames: outputFieldNames.Collect(), + timestamp: t.BeginTs(), + dbName: t.request.GetDbName(), + collectionName: t.request.GetCollectionName(), + primaryFieldSchema: pkField, + queryChannelsTs: t.queryChannelsTs, + consistencyLevel: t.SearchRequest.GetConsistencyLevel(), + guaranteeTimestamp: t.SearchRequest.GetGuaranteeTimestamp(), + notReturnAllMeta: t.request.GetNotReturnAllMeta(), + partitionNames: t.request.GetPartitionNames(), + partitionIDs: t.SearchRequest.GetPartitionIDs(), + node: t.node, + }, nil +} + +func (op *requeryOperator) run(ctx context.Context, span trace.Span, inputs ...any) ([]any, error) { + allIDs := inputs[0].(*schemapb.IDs) + if typeutil.GetSizeOfIDs(allIDs) == 0 { + return []any{[]*schemapb.FieldData{}}, nil + } + + queryResult, err := op.requery(ctx, span, allIDs, op.outputFieldNames) + if err != nil { + return nil, err + } + return []any{queryResult.GetFieldsData()}, nil +} + +func (op *requeryOperator) requery(ctx context.Context, span trace.Span, ids *schemapb.IDs, outputFields []string) (*milvuspb.QueryResults, error) { + queryReq := &milvuspb.QueryRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_Retrieve, + Timestamp: op.timestamp, + }, + DbName: op.dbName, + CollectionName: op.collectionName, + ConsistencyLevel: op.consistencyLevel, + NotReturnAllMeta: op.notReturnAllMeta, + Expr: "", + OutputFields: outputFields, + PartitionNames: op.partitionNames, + UseDefaultConsistency: false, + GuaranteeTimestamp: op.guaranteeTimestamp, + } + plan := planparserv2.CreateRequeryPlan(op.primaryFieldSchema, ids) + channelsMvcc := make(map[string]Timestamp) + for k, v := range op.queryChannelsTs { + channelsMvcc[k] = v + } + qt := &queryTask{ + ctx: op.traceCtx, + Condition: NewTaskCondition(op.traceCtx), + RetrieveRequest: &internalpb.RetrieveRequest{ + Base: commonpbutil.NewMsgBase( + commonpbutil.WithMsgType(commonpb.MsgType_Retrieve), + commonpbutil.WithSourceID(paramtable.GetNodeID()), + ), + ReqID: paramtable.GetNodeID(), + PartitionIDs: op.partitionIDs, // use search partitionIDs + ConsistencyLevel: op.consistencyLevel, + }, + request: queryReq, + plan: plan, + mixCoord: op.node.(*Proxy).mixCoord, + lb: op.node.(*Proxy).lbPolicy, + channelsMvcc: channelsMvcc, + fastSkip: true, + reQuery: true, + } + queryResult, err := op.node.(*Proxy).query(op.traceCtx, qt, span) + if err != nil { + return nil, err + } + + if queryResult.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { + return nil, merr.Error(queryResult.GetStatus()) + } + return queryResult, nil +} + +type organizeOperator struct { + traceCtx context.Context + primaryFieldSchema 
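+	// primaryFieldSchema is the pivot for both requery and organize: requery
+	// filters by PK via planparserv2.CreateRequeryPlan above, and organize uses
+	// it below to map each requeried row back to its offset in the ID list.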
*schemapb.FieldSchema + collectionID int64 +} + +func newOrganizeOperator(t *searchTask, _ map[string]any) (operator, error) { + pkField, err := t.schema.GetPkField() + if err != nil { + return nil, err + } + return &organizeOperator{ + traceCtx: t.TraceCtx(), + primaryFieldSchema: pkField, + collectionID: t.SearchRequest.GetCollectionID(), + }, nil +} + +func (op *organizeOperator) run(ctx context.Context, span trace.Span, inputs ...any) ([]any, error) { + _, sp := otel.Tracer(typeutil.ProxyRole).Start(op.traceCtx, "organizeOperator") + defer sp.End() + + fields := inputs[0].([]*schemapb.FieldData) + idsList := inputs[1].([]*schemapb.IDs) + if len(fields) == 0 { + emptyFields := make([][]*schemapb.FieldData, len(idsList)) + return []any{emptyFields}, nil + } + pkFieldData, err := typeutil.GetPrimaryFieldData(fields, op.primaryFieldSchema) + if err != nil { + return nil, err + } + offsets := make(map[any]int) + pkItr := typeutil.GetDataIterator(pkFieldData) + for i := 0; i < typeutil.GetPKSize(pkFieldData); i++ { + pk := pkItr(i) + offsets[pk] = i + } + + allFieldData := make([][]*schemapb.FieldData, len(idsList)) + for idx, ids := range idsList { + if typeutil.GetSizeOfIDs(ids) == 0 { + emptyFields := []*schemapb.FieldData{} + for _, field := range fields { + emptyFields = append(emptyFields, &schemapb.FieldData{ + Type: field.Type, + FieldName: field.FieldName, + FieldId: field.FieldId, + IsDynamic: field.IsDynamic, + }) + } + allFieldData[idx] = emptyFields + continue + } + if fieldData, err := pickFieldData(ids, offsets, fields, op.collectionID); err != nil { + return nil, err + } else { + allFieldData[idx] = fieldData + } + } + return []any{allFieldData}, nil +} + +func pickFieldData(ids *schemapb.IDs, pkOffset map[any]int, fields []*schemapb.FieldData, collectionID int64) ([]*schemapb.FieldData, error) { + // Reorganize Results. The order of query result ids will be altered and differ from queried ids. + // We should reorganize query results to keep the order of original queried ids. For example: + // =========================================== + // 3 2 5 4 1 (query ids) + // || + // || (query) + // \/ + // 4 3 5 1 2 (result ids) + // v4 v3 v5 v1 v2 (result vectors) + // || + // || (reorganize) + // \/ + // 3 2 5 4 1 (result ids) + // v3 v2 v5 v4 v1 (result vectors) + // =========================================== + fieldsData := make([]*schemapb.FieldData, len(fields)) + for i := 0; i < typeutil.GetSizeOfIDs(ids); i++ { + id := typeutil.GetPK(ids, int64(i)) + if _, ok := pkOffset[id]; !ok { + return nil, merr.WrapErrInconsistentRequery(fmt.Sprintf("incomplete query result, missing id %s, len(searchIDs) = %d, len(queryIDs) = %d, collection=%d", + id, typeutil.GetSizeOfIDs(ids), len(pkOffset), collectionID)) + } + typeutil.AppendFieldData(fieldsData, fields, int64(pkOffset[id])) + } + + return fieldsData, nil +} + +const ( + lambdaParamKey = "lambda" +) + +type lambdaOperator struct { + f func(ctx context.Context, span trace.Span, inputs ...any) ([]any, error) +} + +func newLambdaOperator(_ *searchTask, params map[string]any) (operator, error) { + return &lambdaOperator{ + f: params[lambdaParamKey].(func(ctx context.Context, span trace.Span, inputs ...any) ([]any, error)), + }, nil +} + +func (op *lambdaOperator) run(ctx context.Context, span trace.Span, inputs ...any) ([]any, error) { + return op.f(ctx, span, inputs...) 
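+	// Lambda nodes keep one-off glue steps out of the operator set. A typical
+	// definition, mirroring the "pick" node of searchPipe below, looks like:
+	//
+	//	{
+	//		name:    "pick",
+	//		inputs:  []string{"reduced"},
+	//		outputs: []string{"result"},
+	//		opName:  lambdaOp,
+	//		params: map[string]any{
+	//			lambdaParamKey: func(ctx context.Context, span trace.Span, inputs ...any) ([]any, error) {
+	//				return []any{inputs[0].([]*milvuspb.SearchResults)[0]}, nil
+	//			},
+	//		},
+	//	}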
+} + +type filterFieldOperator struct { + outputFieldNames []string + schema *schemaInfo +} + +func newFilterFieldOperator(t *searchTask, _ map[string]any) (operator, error) { + return &filterFieldOperator{ + outputFieldNames: t.translatedOutputFields, + schema: t.schema, + }, nil +} + +func (op *filterFieldOperator) run(ctx context.Context, span trace.Span, inputs ...any) ([]any, error) { + result := inputs[0].(*milvuspb.SearchResults) + for _, retField := range result.Results.FieldsData { + for _, schemaField := range op.schema.Fields { + if retField != nil && retField.FieldId == schemaField.FieldID { + retField.FieldName = schemaField.Name + retField.Type = schemaField.DataType + retField.IsDynamic = schemaField.IsDynamic + } + } + } + result.Results.FieldsData = lo.Filter(result.Results.FieldsData, func(field *schemapb.FieldData, _ int) bool { + return lo.Contains(op.outputFieldNames, field.FieldName) + }) + return []any{result}, nil +} + +func mergeIDsFunc(ctx context.Context, span trace.Span, inputs ...any) ([]any, error) { + multipleMilvusResults := inputs[0].([]*milvuspb.SearchResults) + idsList := lo.FilterMap(multipleMilvusResults, func(m *milvuspb.SearchResults, _ int) (*schemapb.IDs, bool) { + return m.Results.Ids, true + }) + uniqueIDs := &schemapb.IDs{} + switch idsList[0].GetIdField().(type) { + case *schemapb.IDs_IntId: + idsSet := typeutil.NewSet[int64]() + for _, ids := range idsList { + if data := ids.GetIntId().GetData(); data != nil { + idsSet.Insert(data...) + } + } + uniqueIDs.IdField = &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{ + Data: idsSet.Collect(), + }, + } + case *schemapb.IDs_StrId: + idsSet := typeutil.NewSet[string]() + for _, ids := range idsList { + if data := ids.GetStrId().GetData(); data != nil { + idsSet.Insert(data...) 
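+				// the set dedupes PKs across all sub-results; ordering is
+				// irrelevant here because organize later re-aligns requeried
+				// rows by PK offset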
+ } + } + uniqueIDs.IdField = &schemapb.IDs_StrId{ + StrId: &schemapb.StringArray{ + Data: idsSet.Collect(), + }, + } + } + return []any{uniqueIDs}, nil +} + +type pipeline struct { + name string + nodes []*Node +} + +func newPipeline(pipeDef *pipelineDef, t *searchTask) (*pipeline, error) { + nodes := make([]*Node, len(pipeDef.nodes)) + for i, def := range pipeDef.nodes { + node, err := NewNode(def, t) + if err != nil { + return nil, err + } + nodes[i] = node + } + return &pipeline{name: pipeDef.name, nodes: nodes}, nil +} + +func (p *pipeline) Run(ctx context.Context, span trace.Span, toReduceResults []*internalpb.SearchResults) (*milvuspb.SearchResults, error) { + log.Ctx(ctx).Debug("SearchPipeline run", zap.String("pipeline", p.String())) + msg := opMsg{} + msg["input"] = toReduceResults + for _, node := range p.nodes { + var err error + log.Ctx(ctx).Debug("SearchPipeline run node", zap.String("node", node.name)) + msg, err = node.Run(ctx, span, msg) + if err != nil { + log.Ctx(ctx).Error("Run node failed: ", zap.String("err", err.Error())) + return nil, err + } + } + return msg["output"].(*milvuspb.SearchResults), nil +} + +func (p *pipeline) String() string { + buf := bytes.NewBufferString(fmt.Sprintf("SearchPipeline: %s", p.name)) + for _, node := range p.nodes { + buf.WriteString(fmt.Sprintf(" %s -> %s", node.name, node.outputs)) + } + return buf.String() +} + +type pipelineDef struct { + name string + nodes []*nodeDef +} + +var searchPipe = &pipelineDef{ + name: "search", + nodes: []*nodeDef{ + { + name: "reduce", + inputs: []string{"input"}, + outputs: []string{"reduced", "metrics"}, + opName: searchReduceOp, + }, + { + name: "pick", + inputs: []string{"reduced"}, + outputs: []string{"result"}, + params: map[string]any{ + lambdaParamKey: func(ctx context.Context, span trace.Span, inputs ...any) ([]any, error) { + result := inputs[0].([]*milvuspb.SearchResults)[0] + return []any{result}, nil + }, + }, + opName: lambdaOp, + }, + { + name: "filter_field", + inputs: []string{"result"}, + outputs: []string{"output"}, + opName: filterFieldOp, + }, + }, +} + +var searchWithRequeryPipe = &pipelineDef{ + name: "searchWithRequery", + nodes: []*nodeDef{ + { + name: "reduce", + inputs: []string{"input"}, + outputs: []string{"reduced", "metrics"}, + opName: searchReduceOp, + }, + { + name: "merge", + inputs: []string{"reduced"}, + outputs: []string{"unique_ids"}, + opName: lambdaOp, + params: map[string]any{ + lambdaParamKey: mergeIDsFunc, + }, + }, + { + name: "requery", + inputs: []string{"unique_ids"}, + outputs: []string{"fields"}, + opName: requeryOp, + }, + { + name: "gen_ids", + inputs: []string{"reduced"}, + outputs: []string{"ids"}, + opName: lambdaOp, + params: map[string]any{ + lambdaParamKey: func(ctx context.Context, span trace.Span, inputs ...any) ([]any, error) { + return []any{[]*schemapb.IDs{inputs[0].([]*milvuspb.SearchResults)[0].Results.Ids}}, nil + }, + }, + }, + { + name: "organize", + inputs: []string{"fields", "ids"}, + outputs: []string{"organized_fields"}, + opName: organizeOp, + }, + { + name: "pick", + inputs: []string{"reduced", "organized_fields"}, + outputs: []string{"result"}, + params: map[string]any{ + lambdaParamKey: func(ctx context.Context, span trace.Span, inputs ...any) ([]any, error) { + result := inputs[0].([]*milvuspb.SearchResults)[0] + fields := inputs[1].([][]*schemapb.FieldData) + result.Results.FieldsData = fields[0] + return []any{result}, nil + }, + }, + opName: lambdaOp, + }, + { + name: "filter_field", + inputs: []string{"result"}, + outputs: 
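+			// dataflow of this pipeline, keyed by opMsg entries:
+			//   input -> reduce -> (reduced, metrics); reduced -> merge -> unique_ids
+			//   -> requery -> fields; reduced -> gen_ids -> ids;
+			//   (fields, ids) -> organize -> organized_fields;
+			//   (reduced, organized_fields) -> pick -> result -> filter_field -> output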
[]string{"output"}, + opName: filterFieldOp, + }, + }, +} + +var searchWithRerankPipe = &pipelineDef{ + name: "searchWithRerank", + nodes: []*nodeDef{ + { + name: "reduce", + inputs: []string{"input"}, + outputs: []string{"reduced", "metrics"}, + opName: searchReduceOp, + }, + { + name: "rerank", + inputs: []string{"reduced", "metrics"}, + outputs: []string{"rank_result"}, + opName: rerankOp, + }, + { + name: "pick", + inputs: []string{"reduced", "rank_result"}, + outputs: []string{"ids", "fields"}, + params: map[string]any{ + lambdaParamKey: func(ctx context.Context, span trace.Span, inputs ...any) ([]any, error) { + return []any{ + inputs[0].([]*milvuspb.SearchResults)[0].Results.FieldsData, + []*schemapb.IDs{inputs[1].(*milvuspb.SearchResults).Results.Ids}, + }, nil + }, + }, + opName: lambdaOp, + }, + { + name: "organize", + inputs: []string{"ids", "fields"}, + outputs: []string{"organized_fields"}, + opName: organizeOp, + }, + { + name: "result", + inputs: []string{"rank_result", "organized_fields"}, + outputs: []string{"result"}, + params: map[string]any{ + lambdaParamKey: func(ctx context.Context, span trace.Span, inputs ...any) ([]any, error) { + result := inputs[0].(*milvuspb.SearchResults) + fields := inputs[1].([][]*schemapb.FieldData) + result.Results.FieldsData = fields[0] + return []any{result}, nil + }, + }, + opName: lambdaOp, + }, + { + name: "filter_field", + inputs: []string{"result"}, + outputs: []string{"output"}, + opName: filterFieldOp, + }, + }, +} + +var searchWithRerankRequeryPipe = &pipelineDef{ + name: "searchWithRerankRequery", + nodes: []*nodeDef{ + { + name: "reduce", + inputs: []string{"input"}, + outputs: []string{"reduced", "metrics"}, + opName: searchReduceOp, + }, + { + name: "rerank", + inputs: []string{"reduced", "metrics"}, + outputs: []string{"rank_result"}, + opName: rerankOp, + }, + { + name: "pick_ids", + inputs: []string{"rank_result"}, + outputs: []string{"ids"}, + params: map[string]any{ + lambdaParamKey: func(ctx context.Context, span trace.Span, inputs ...any) ([]any, error) { + return []any{ + inputs[0].(*milvuspb.SearchResults).Results.Ids, + }, nil + }, + }, + opName: lambdaOp, + }, + { + name: "requery", + inputs: []string{"ids"}, + outputs: []string{"fields"}, + opName: requeryOp, + }, + { + name: "to_ids_list", + inputs: []string{"ids"}, + outputs: []string{"ids"}, + opName: lambdaOp, + params: map[string]any{ + lambdaParamKey: func(ctx context.Context, span trace.Span, inputs ...any) ([]any, error) { + return []any{[]*schemapb.IDs{inputs[0].(*schemapb.IDs)}}, nil + }, + }, + }, + { + name: "organize", + inputs: []string{"fields", "ids"}, + outputs: []string{"organized_fields"}, + opName: organizeOp, + }, + { + name: "result", + inputs: []string{"rank_result", "organized_fields"}, + outputs: []string{"result"}, + params: map[string]any{ + lambdaParamKey: func(ctx context.Context, span trace.Span, inputs ...any) ([]any, error) { + result := inputs[0].(*milvuspb.SearchResults) + fields := inputs[1].([][]*schemapb.FieldData) + result.Results.FieldsData = fields[0] + return []any{result}, nil + }, + }, + opName: lambdaOp, + }, + { + name: "filter_field", + inputs: []string{"result"}, + outputs: []string{"output"}, + opName: filterFieldOp, + }, + }, +} + +var hybridSearchPipe = &pipelineDef{ + name: "hybridSearchPipe", + nodes: []*nodeDef{ + { + name: "reduce", + inputs: []string{"input"}, + outputs: []string{"reduced", "metrics"}, + opName: hybridSearchReduceOp, + }, + { + name: "rerank", + inputs: []string{"reduced", "metrics"}, + outputs: 
[]string{"output"}, + opName: rerankOp, + }, + }, +} + +var hybridSearchWithRequeryPipe = &pipelineDef{ + name: "hybridSearchWithRequery", + nodes: []*nodeDef{ + { + name: "reduce", + inputs: []string{"input"}, + outputs: []string{"reduced", "metrics"}, + opName: hybridSearchReduceOp, + }, + { + name: "merge_ids", + inputs: []string{"reduced"}, + outputs: []string{"ids"}, + opName: lambdaOp, + params: map[string]any{ + lambdaParamKey: mergeIDsFunc, + }, + }, + { + name: "requery", + inputs: []string{"ids"}, + outputs: []string{"fields"}, + opName: requeryOp, + }, + { + name: "parse_ids", + inputs: []string{"reduced"}, + outputs: []string{"id_list"}, + opName: lambdaOp, + params: map[string]any{ + lambdaParamKey: func(ctx context.Context, span trace.Span, inputs ...any) ([]any, error) { + multipleMilvusResults := inputs[0].([]*milvuspb.SearchResults) + idsList := lo.FilterMap(multipleMilvusResults, func(m *milvuspb.SearchResults, _ int) (*schemapb.IDs, bool) { + return m.Results.Ids, true + }) + return []any{idsList}, nil + }, + }, + }, + { + name: "organize_rank_data", + inputs: []string{"fields", "id_list"}, + outputs: []string{"organized_fields"}, + opName: organizeOp, + }, + { + name: "gen_rank_data", + inputs: []string{"reduced", "organized_fields"}, + outputs: []string{"rank_data"}, + opName: lambdaOp, + params: map[string]any{ + lambdaParamKey: func(ctx context.Context, span trace.Span, inputs ...any) ([]any, error) { + results := inputs[0].([]*milvuspb.SearchResults) + fields := inputs[1].([][]*schemapb.FieldData) + for i := 0; i < len(results); i++ { + results[i].Results.FieldsData = fields[i] + } + return []any{results}, nil + }, + }, + }, + { + name: "rerank", + inputs: []string{"rank_data", "metrics"}, + outputs: []string{"rank_result"}, + opName: rerankOp, + }, + { + name: "pick_ids", + inputs: []string{"rank_result"}, + outputs: []string{"ids"}, + opName: lambdaOp, + params: map[string]any{ + lambdaParamKey: func(ctx context.Context, span trace.Span, inputs ...any) ([]any, error) { + return []any{[]*schemapb.IDs{inputs[0].(*milvuspb.SearchResults).Results.Ids}}, nil + }, + }, + }, + { + name: "organize_result", + inputs: []string{"fields", "ids"}, + outputs: []string{"organized_fields"}, + opName: organizeOp, + }, + { + name: "result", + inputs: []string{"rank_result", "organized_fields"}, + outputs: []string{"result"}, + params: map[string]any{ + lambdaParamKey: func(ctx context.Context, span trace.Span, inputs ...any) ([]any, error) { + result := inputs[0].(*milvuspb.SearchResults) + fields := inputs[1].([][]*schemapb.FieldData) + result.Results.FieldsData = fields[0] + return []any{result}, nil + }, + }, + opName: lambdaOp, + }, + { + name: "filter_field", + inputs: []string{"result"}, + outputs: []string{"output"}, + opName: filterFieldOp, + }, + }, +} + +func newBuiltInPipeline(t *searchTask) (*pipeline, error) { + if !t.SearchRequest.GetIsAdvanced() && !t.needRequery && t.functionScore == nil { + return newPipeline(searchPipe, t) + } + if !t.SearchRequest.GetIsAdvanced() && t.needRequery && t.functionScore == nil { + return newPipeline(searchWithRequeryPipe, t) + } + if !t.SearchRequest.GetIsAdvanced() && !t.needRequery && t.functionScore != nil { + return newPipeline(searchWithRerankPipe, t) + } + if !t.SearchRequest.GetIsAdvanced() && t.needRequery && t.functionScore != nil { + return newPipeline(searchWithRerankRequeryPipe, t) + } + if t.SearchRequest.GetIsAdvanced() && !t.needRequery { + return newPipeline(hybridSearchPipe, t) + } + if t.SearchRequest.GetIsAdvanced() 
&& t.needRequery { + return newPipeline(hybridSearchWithRequeryPipe, t) + } + return nil, fmt.Errorf("Unsupported pipeline") +} diff --git a/internal/proxy/search_pipeline_test.go b/internal/proxy/search_pipeline_test.go new file mode 100644 index 0000000000..38831a3c30 --- /dev/null +++ b/internal/proxy/search_pipeline_test.go @@ -0,0 +1,754 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proxy + +import ( + "context" + "fmt" + "slices" + "testing" + "time" + + "github.com/bytedance/mockey" + "github.com/stretchr/testify/suite" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/trace" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/util/function/rerank" + "github.com/milvus-io/milvus/pkg/v2/proto/internalpb" + "github.com/milvus-io/milvus/pkg/v2/proto/planpb" + "github.com/milvus-io/milvus/pkg/v2/util/testutils" + "github.com/milvus-io/milvus/pkg/v2/util/timerecord" +) + +func TestSearchPipeline(t *testing.T) { + suite.Run(t, new(SearchPipelineSuite)) +} + +type SearchPipelineSuite struct { + suite.Suite + span trace.Span +} + +func (s *SearchPipelineSuite) SetupTest() { + _, sp := otel.Tracer("test").Start(context.Background(), "Proxy-Search-PostExecute") + s.span = sp +} + +func (s *SearchPipelineSuite) TearDownTest() { + s.span.End() +} + +func (s *SearchPipelineSuite) TestSearchReduceOp() { + nq := int64(2) + topk := int64(10) + pk := &schemapb.FieldSchema{ + FieldID: 101, + Name: "pk", + DataType: schemapb.DataType_Int64, + IsPrimaryKey: true, + AutoID: true, + } + data := genTestSearchResultData(nq, topk, schemapb.DataType_Int64, "intField", 102, false) + op := searchReduceOperator{ + context.Background(), + pk, + nq, + topk, + 0, + 1, + []int64{1}, + []*planpb.QueryInfo{{}}, + } + _, err := op.run(context.Background(), s.span, []*internalpb.SearchResults{data}) + s.NoError(err) +} + +func (s *SearchPipelineSuite) TestHybridSearchReduceOp() { + nq := int64(2) + topk := int64(10) + pk := &schemapb.FieldSchema{ + FieldID: 101, + Name: "pk", + DataType: schemapb.DataType_Int64, + IsPrimaryKey: true, + AutoID: true, + } + data1 := genTestSearchResultData(nq, topk, schemapb.DataType_Int64, "intField", 102, true) + data1.SubResults[0].ReqIndex = 0 + data2 := genTestSearchResultData(nq, topk, schemapb.DataType_Int64, "intField", 102, true) + data2.SubResults[0].ReqIndex = 1 + + subReqs := []*internalpb.SubSearchRequest{ + { + Nq: 2, + Topk: 10, + Offset: 0, + }, + { + Nq: 2, + Topk: 10, + Offset: 0, + }, + } + + op := hybridSearchReduceOperator{ + context.Background(), + subReqs, + pk, + 1, + []int64{1}, + []*planpb.QueryInfo{{}, {}}, + } + _, err := 
op.run(context.Background(), s.span, []*internalpb.SearchResults{data1, data2}) + s.NoError(err) +} + +func (s *SearchPipelineSuite) TestRerankOp() { + schema := &schemapb.CollectionSchema{ + Name: "test", + Fields: []*schemapb.FieldSchema{ + {FieldID: 100, Name: "pk", DataType: schemapb.DataType_Int64, IsPrimaryKey: true}, + {FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar}, + { + FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + {Key: "dim", Value: "4"}, + }, + }, + {FieldID: 102, Name: "ts", DataType: schemapb.DataType_Int64}, + }, + } + functionSchema := &schemapb.FunctionSchema{ + Name: "test", + Type: schemapb.FunctionType_Rerank, + InputFieldNames: []string{"ts"}, + OutputFieldNames: []string{}, + Params: []*commonpb.KeyValuePair{ + {Key: "reranker", Value: "decay"}, + {Key: "origin", Value: "4"}, + {Key: "scale", Value: "4"}, + {Key: "offset", Value: "4"}, + {Key: "decay", Value: "0.5"}, + {Key: "function", Value: "gauss"}, + }, + } + funcScore, err := rerank.NewFunctionScore(schema, &schemapb.FunctionScore{ + Functions: []*schemapb.FunctionSchema{functionSchema}, + }) + s.NoError(err) + + nq := int64(2) + topk := int64(10) + offset := int64(0) + + reduceOp := searchReduceOperator{ + context.Background(), + schema.Fields[0], + nq, + topk, + offset, + 1, + []int64{1}, + []*planpb.QueryInfo{{}}, + } + + data := genTestSearchResultData(nq, topk, schemapb.DataType_Int64, "intField", 102, false) + reduced, err := reduceOp.run(context.Background(), s.span, []*internalpb.SearchResults{data}) + s.NoError(err) + + op := rerankOperator{ + nq: nq, + topK: topk, + offset: offset, + roundDecimal: 10, + functionScore: funcScore, + } + + _, err = op.run(context.Background(), s.span, reduced[0], []string{"IP"}) + s.NoError(err) +} + +func (s *SearchPipelineSuite) TestRequeryOp() { + f1 := testutils.GenerateScalarFieldData(schemapb.DataType_Int64, "int64", 20) + f1.FieldId = 101 + + mocker := mockey.Mock((*requeryOperator).requery).Return(&milvuspb.QueryResults{ + FieldsData: []*schemapb.FieldData{f1}, + }, nil).Build() + defer mocker.UnPatch() + + op := requeryOperator{ + traceCtx: context.Background(), + outputFieldNames: []string{"int64"}, + } + ids := &schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{ + Data: []int64{1, 2}, + }, + }, + } + _, err := op.run(context.Background(), s.span, ids, []string{"int64"}) + s.NoError(err) +} + +func (s *SearchPipelineSuite) TestOrganizeOp() { + op := organizeOperator{ + traceCtx: context.Background(), + primaryFieldSchema: &schemapb.FieldSchema{FieldID: 100, Name: "pk", DataType: schemapb.DataType_Int64, IsPrimaryKey: true}, + collectionID: 1, + } + fields := []*schemapb.FieldData{ + { + Type: schemapb.DataType_Int64, + FieldName: "pk", + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_LongData{ + LongData: &schemapb.LongArray{ + Data: []int64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + }, + }, + }, + }, + }, { + Type: schemapb.DataType_Int64, + FieldName: "int64", + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_LongData{ + LongData: &schemapb.LongArray{ + Data: []int64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + }, + }, + }, + }, + }, + } + + ids := []*schemapb.IDs{ + { + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{ + Data: []int64{1, 4, 5, 9, 10}, + }, + }, + }, + { + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{ + Data: []int64{5, 6, 7, 
8, 9, 10}, + }, + }, + }, + } + ret, err := op.run(context.Background(), s.span, fields, ids) + s.NoError(err) + fmt.Println(ret) +} + +func (s *SearchPipelineSuite) TestSearchPipeline() { + collectionName := "test" + task := &searchTask{ + ctx: context.Background(), + collectionName: collectionName, + SearchRequest: &internalpb.SearchRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_Search, + Timestamp: uint64(time.Now().UnixNano()), + }, + MetricType: "L2", + Topk: 10, + Nq: 2, + PartitionIDs: []int64{1}, + CollectionID: 1, + DbID: 1, + }, + schema: &schemaInfo{ + CollectionSchema: &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + {FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64, IsPrimaryKey: true}, + {FieldID: 101, Name: "intField", DataType: schemapb.DataType_Int64}, + }, + }, + pkField: &schemapb.FieldSchema{FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64, IsPrimaryKey: true}, + }, + queryInfos: []*planpb.QueryInfo{{}}, + translatedOutputFields: []string{"intField"}, + } + pipeline, err := newPipeline(searchPipe, task) + s.NoError(err) + results, err := pipeline.Run(context.Background(), s.span, []*internalpb.SearchResults{ + genTestSearchResultData(2, 10, schemapb.DataType_Int64, "intField", 101, false), + }) + s.NoError(err) + s.NotNil(results) + s.NotNil(results.Results) + s.Equal(int64(2), results.Results.NumQueries) + s.Equal(int64(10), results.Results.Topks[0]) + s.Equal(int64(10), results.Results.Topks[1]) + s.NotNil(results.Results.Ids) + s.NotNil(results.Results.Ids.GetIntId()) + s.Len(results.Results.Ids.GetIntId().Data, 20) // 2 queries * 10 topk + s.NotNil(results.Results.Scores) + s.Len(results.Results.Scores, 20) // 2 queries * 10 topk + s.NotNil(results.Results.FieldsData) + s.Len(results.Results.FieldsData, 1) // One output field + s.Equal("intField", results.Results.FieldsData[0].FieldName) + s.Equal(int64(101), results.Results.FieldsData[0].FieldId) + fmt.Println(results) +} + +func (s *SearchPipelineSuite) TestSearchPipelineWithRequery() { + collectionName := "test_collection" + task := &searchTask{ + ctx: context.Background(), + collectionName: collectionName, + SearchRequest: &internalpb.SearchRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_Search, + Timestamp: uint64(time.Now().UnixNano()), + }, + MetricType: "L2", + Topk: 10, + Nq: 2, + PartitionIDs: []int64{1}, + CollectionID: 1, + DbID: 1, + }, + schema: &schemaInfo{ + CollectionSchema: &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + {FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64, IsPrimaryKey: true}, + {FieldID: 101, Name: "intField", DataType: schemapb.DataType_Int64}, + }, + }, + pkField: &schemapb.FieldSchema{FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64, IsPrimaryKey: true}, + }, + queryInfos: []*planpb.QueryInfo{{}}, + translatedOutputFields: []string{"intField"}, + node: nil, + } + + // Mock requery operation + f1 := testutils.GenerateScalarFieldData(schemapb.DataType_Int64, "intField", 20) + f1.FieldId = 101 + f2 := testutils.GenerateScalarFieldData(schemapb.DataType_Int64, "int64", 20) + f2.FieldId = 100 + mocker := mockey.Mock((*requeryOperator).requery).Return(&milvuspb.QueryResults{ + FieldsData: []*schemapb.FieldData{f1, f2}, + }, nil).Build() + defer mocker.UnPatch() + + pipeline, err := newPipeline(searchWithRequeryPipe, task) + s.NoError(err) + results, err := pipeline.Run(context.Background(), s.span, []*internalpb.SearchResults{ + genTestSearchResultData(2, 10, 
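+		// arguments read as nq=2, topk=10, then the field type/name/ID and the
+		// non-advanced flag, matching the other genTestSearchResultData calls
+		// in this file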
schemapb.DataType_Int64, "intField", 101, false), + }) + s.NoError(err) + s.NotNil(results) + s.NotNil(results.Results) + s.Equal(int64(2), results.Results.NumQueries) + s.Equal(int64(10), results.Results.Topks[0]) + s.Equal(int64(10), results.Results.Topks[1]) + s.NotNil(results.Results.Ids) + s.NotNil(results.Results.Ids.GetIntId()) + s.Len(results.Results.Ids.GetIntId().Data, 20) // 2 queries * 10 topk + s.NotNil(results.Results.Scores) + s.Len(results.Results.Scores, 20) // 2 queries * 10 topk + s.NotNil(results.Results.FieldsData) + s.Len(results.Results.FieldsData, 1) // One output field + s.Equal("intField", results.Results.FieldsData[0].FieldName) + s.Equal(int64(101), results.Results.FieldsData[0].FieldId) +} + +func (s *SearchPipelineSuite) TestSearchWithRerankPipe() { + functionSchema := &schemapb.FunctionSchema{ + Name: "test", + Type: schemapb.FunctionType_Rerank, + InputFieldNames: []string{"intField"}, + OutputFieldNames: []string{}, + Params: []*commonpb.KeyValuePair{ + {Key: "reranker", Value: "decay"}, + {Key: "origin", Value: "4"}, + {Key: "scale", Value: "4"}, + {Key: "offset", Value: "4"}, + {Key: "decay", Value: "0.5"}, + {Key: "function", Value: "gauss"}, + }, + } + schema := &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + {FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64, IsPrimaryKey: true}, + {FieldID: 101, Name: "intField", DataType: schemapb.DataType_Int64}, + }, + } + funcScore, err := rerank.NewFunctionScore(schema, &schemapb.FunctionScore{ + Functions: []*schemapb.FunctionSchema{functionSchema}, + }) + s.NoError(err) + + task := &searchTask{ + ctx: context.Background(), + collectionName: "test_collection", + SearchRequest: &internalpb.SearchRequest{ + MetricType: "L2", + Topk: 10, + Nq: 2, + PartitionIDs: []int64{1}, + CollectionID: 1, + DbID: 1, + }, + schema: &schemaInfo{ + CollectionSchema: schema, + pkField: &schemapb.FieldSchema{FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64, IsPrimaryKey: true}, + }, + queryInfos: []*planpb.QueryInfo{{}}, + translatedOutputFields: []string{"intField"}, + node: nil, + functionScore: funcScore, + } + + pipeline, err := newPipeline(searchWithRerankPipe, task) + s.NoError(err) + + searchResults := genTestSearchResultData(2, 10, schemapb.DataType_Int64, "intField", 101, false) + results, err := pipeline.Run(context.Background(), s.span, []*internalpb.SearchResults{searchResults}) + + s.NoError(err) + s.NotNil(results) + s.NotNil(results.Results) + s.Equal(int64(2), results.Results.NumQueries) + s.Equal(int64(10), results.Results.Topks[0]) + s.Equal(int64(10), results.Results.Topks[1]) + s.NotNil(results.Results.Ids) + s.NotNil(results.Results.Ids.GetIntId()) + s.Len(results.Results.Ids.GetIntId().Data, 20) // 2 queries * 10 topk + s.NotNil(results.Results.Scores) + s.Len(results.Results.Scores, 20) // 2 queries * 10 topk + s.NotNil(results.Results.FieldsData) + s.Len(results.Results.FieldsData, 1) // One output field + s.Equal("intField", results.Results.FieldsData[0].FieldName) + s.Equal(int64(101), results.Results.FieldsData[0].FieldId) +} + +func (s *SearchPipelineSuite) TestSearchWithRerankRequeryPipe() { + functionSchema := &schemapb.FunctionSchema{ + Name: "test", + Type: schemapb.FunctionType_Rerank, + InputFieldNames: []string{"intField"}, + OutputFieldNames: []string{}, + Params: []*commonpb.KeyValuePair{ + {Key: "reranker", Value: "decay"}, + {Key: "origin", Value: "4"}, + {Key: "scale", Value: "4"}, + {Key: "offset", Value: "4"}, + {Key: "decay", Value: "0.5"}, + 
{Key: "function", Value: "gauss"}, + }, + } + schema := &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + {FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64, IsPrimaryKey: true}, + {FieldID: 101, Name: "intField", DataType: schemapb.DataType_Int64}, + }, + } + funcScore, err := rerank.NewFunctionScore(schema, &schemapb.FunctionScore{ + Functions: []*schemapb.FunctionSchema{functionSchema}, + }) + s.NoError(err) + + task := &searchTask{ + ctx: context.Background(), + collectionName: "test_collection", + SearchRequest: &internalpb.SearchRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_Search, + Timestamp: uint64(time.Now().UnixNano()), + }, + MetricType: "L2", + Topk: 10, + Nq: 2, + PartitionIDs: []int64{1}, + CollectionID: 1, + DbID: 1, + }, + schema: &schemaInfo{ + CollectionSchema: schema, + pkField: &schemapb.FieldSchema{FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64, IsPrimaryKey: true}, + }, + queryInfos: []*planpb.QueryInfo{{}}, + translatedOutputFields: []string{"intField"}, + node: nil, + functionScore: funcScore, + } + f1 := testutils.GenerateScalarFieldData(schemapb.DataType_Int64, "intField", 20) + f1.FieldId = 101 + f2 := testutils.GenerateScalarFieldData(schemapb.DataType_Int64, "int64", 20) + f2.FieldId = 100 + mocker := mockey.Mock((*requeryOperator).requery).Return(&milvuspb.QueryResults{ + FieldsData: []*schemapb.FieldData{f1, f2}, + }, nil).Build() + defer mocker.UnPatch() + + pipeline, err := newPipeline(searchWithRerankRequeryPipe, task) + s.NoError(err) + + searchResults := genTestSearchResultData(2, 10, schemapb.DataType_Int64, "intField", 101, false) + results, err := pipeline.Run(context.Background(), s.span, []*internalpb.SearchResults{searchResults}) + + s.NoError(err) + s.NotNil(results) + s.NotNil(results.Results) + s.Equal(int64(2), results.Results.NumQueries) + s.Equal(int64(10), results.Results.Topks[0]) + s.Equal(int64(10), results.Results.Topks[1]) + s.NotNil(results.Results.Ids) + s.NotNil(results.Results.Ids.GetIntId()) + s.Len(results.Results.Ids.GetIntId().Data, 20) // 2 queries * 10 topk + s.NotNil(results.Results.Scores) + s.Len(results.Results.Scores, 20) // 2 queries * 10 topk + s.NotNil(results.Results.FieldsData) + s.Len(results.Results.FieldsData, 1) // One output field + s.Equal("intField", results.Results.FieldsData[0].FieldName) + s.Equal(int64(101), results.Results.FieldsData[0].FieldId) +} + +func (s *SearchPipelineSuite) TestHybridSearchPipe() { + task := getHybridSearchTask("test_collection", [][]string{ + {"1", "2"}, + {"3", "4"}, + }, + []string{}, + ) + + pipeline, err := newPipeline(hybridSearchPipe, task) + s.NoError(err) + + f1 := genTestSearchResultData(2, 10, schemapb.DataType_Int64, "intField", 101, true) + f2 := genTestSearchResultData(2, 10, schemapb.DataType_Int64, "intField", 101, true) + results, err := pipeline.Run(context.Background(), s.span, []*internalpb.SearchResults{f1, f2}) + + s.NoError(err) + s.NotNil(results) + s.NotNil(results.Results) + s.Equal(int64(2), results.Results.NumQueries) + s.Equal(int64(10), results.Results.Topks[0]) + s.Equal(int64(10), results.Results.Topks[1]) + s.NotNil(results.Results.Ids) + s.NotNil(results.Results.Ids.GetIntId()) + s.Len(results.Results.Ids.GetIntId().Data, 20) // 2 queries * 10 topk + s.NotNil(results.Results.Scores) + s.Len(results.Results.Scores, 20) // 2 queries * 10 topk +} + +func (s *SearchPipelineSuite) TestHybridSearchWithRequeryPipe() { + task := getHybridSearchTask("test_collection", [][]string{ + {"1", "2"}, + 
{"3", "4"}, + }, + []string{"intField"}, + ) + + f1 := testutils.GenerateScalarFieldData(schemapb.DataType_Int64, "intField", 20) + f1.FieldId = 101 + f2 := testutils.GenerateScalarFieldData(schemapb.DataType_Int64, "int64", 20) + f2.FieldId = 100 + mocker := mockey.Mock((*requeryOperator).requery).Return(&milvuspb.QueryResults{ + FieldsData: []*schemapb.FieldData{f1, f2}, + }, nil).Build() + defer mocker.UnPatch() + + pipeline, err := newPipeline(hybridSearchWithRequeryPipe, task) + s.NoError(err) + + d1 := genTestSearchResultData(2, 10, schemapb.DataType_Int64, "intField", 101, true) + d2 := genTestSearchResultData(2, 10, schemapb.DataType_Int64, "intField", 101, true) + results, err := pipeline.Run(context.Background(), s.span, []*internalpb.SearchResults{d1, d2}) + + s.NoError(err) + s.NotNil(results) + s.NotNil(results.Results) + s.Equal(int64(2), results.Results.NumQueries) + s.Equal(int64(10), results.Results.Topks[0]) + s.Equal(int64(10), results.Results.Topks[1]) + s.NotNil(results.Results.Ids) + s.NotNil(results.Results.Ids.GetIntId()) + s.Len(results.Results.Ids.GetIntId().Data, 20) // 2 queries * 10 topk + s.NotNil(results.Results.Scores) + s.Len(results.Results.Scores, 20) // 2 queries * 10 topk + s.NotNil(results.Results.FieldsData) + s.Len(results.Results.FieldsData, 1) // One output field + s.Equal("intField", results.Results.FieldsData[0].FieldName) + s.Equal(int64(101), results.Results.FieldsData[0].FieldId) +} + +func getHybridSearchTask(collName string, data [][]string, outputFields []string) *searchTask { + subReqs := []*milvuspb.SubSearchRequest{} + for _, item := range data { + subReq := &milvuspb.SubSearchRequest{ + SearchParams: []*commonpb.KeyValuePair{ + {Key: TopKKey, Value: "10"}, + }, + Nq: int64(len(item)), + } + subReqs = append(subReqs, subReq) + } + functionSchema := &schemapb.FunctionSchema{ + Name: "test", + Type: schemapb.FunctionType_Rerank, + InputFieldNames: []string{}, + OutputFieldNames: []string{}, + Params: []*commonpb.KeyValuePair{ + {Key: "reranker", Value: "rrf"}, + }, + } + + schema := &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + {FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64, IsPrimaryKey: true}, + {FieldID: 101, Name: "intField", DataType: schemapb.DataType_Int64}, + }, + } + funcScore, _ := rerank.NewFunctionScore(schema, &schemapb.FunctionScore{ + Functions: []*schemapb.FunctionSchema{functionSchema}, + }) + task := &searchTask{ + ctx: context.Background(), + collectionName: collName, + SearchRequest: &internalpb.SearchRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_Search, + Timestamp: uint64(time.Now().UnixNano()), + }, + Topk: 10, + Nq: 2, + IsAdvanced: true, + SubReqs: []*internalpb.SubSearchRequest{ + { + Topk: 10, + Nq: 2, + }, + { + Topk: 10, + Nq: 2, + }, + }, + }, + request: &milvuspb.SearchRequest{ + CollectionName: collName, + SubReqs: subReqs, + SearchParams: []*commonpb.KeyValuePair{ + {Key: LimitKey, Value: "10"}, + }, + FunctionScore: &schemapb.FunctionScore{ + Functions: []*schemapb.FunctionSchema{functionSchema}, + }, + OutputFields: outputFields, + }, + schema: &schemaInfo{ + CollectionSchema: schema, + pkField: &schemapb.FieldSchema{FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64, IsPrimaryKey: true}, + }, + mixCoord: nil, + tr: timerecord.NewTimeRecorder("test-search"), + rankParams: &rankParams{ + limit: 10, + offset: 0, + roundDecimal: 0, + }, + queryInfos: []*planpb.QueryInfo{{}, {}}, + functionScore: funcScore, + translatedOutputFields: 
outputFields, + } + return task +} + +func (s *SearchPipelineSuite) TestMergeIDsFunc() { + { + ids1 := &schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{ + Data: []int64{1, 2, 3, 5}, + }, + }, + } + + ids2 := &schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{ + Data: []int64{1, 2, 4, 5, 6}, + }, + }, + } + rets := []*milvuspb.SearchResults{ + { + Results: &schemapb.SearchResultData{ + Ids: ids1, + }, + }, + { + Results: &schemapb.SearchResultData{ + Ids: ids2, + }, + }, + } + allIDs, err := mergeIDsFunc(context.Background(), s.span, rets) + s.NoError(err) + sortedIds := allIDs[0].(*schemapb.IDs).GetIntId().GetData() + slices.Sort(sortedIds) + s.Equal(sortedIds, []int64{1, 2, 3, 4, 5, 6}) + } + { + ids1 := &schemapb.IDs{ + IdField: &schemapb.IDs_StrId{ + StrId: &schemapb.StringArray{ + Data: []string{"a", "b", "e"}, + }, + }, + } + + ids2 := &schemapb.IDs{ + IdField: &schemapb.IDs_StrId{ + StrId: &schemapb.StringArray{ + Data: []string{"a", "b", "c", "d"}, + }, + }, + } + rets := []*milvuspb.SearchResults{ + { + Results: &schemapb.SearchResultData{ + Ids: ids1, + }, + }, + } + rets = append(rets, &milvuspb.SearchResults{ + Results: &schemapb.SearchResultData{ + Ids: ids2, + }, + }) + allIDs, err := mergeIDsFunc(context.Background(), s.span, rets) + s.NoError(err) + sortedIds := allIDs[0].(*schemapb.IDs).GetStrId().GetData() + slices.Sort(sortedIds) + s.Equal(sortedIds, []string{"a", "b", "c", "d", "e"}) + } +} diff --git a/internal/proxy/search_reduce_util.go b/internal/proxy/search_reduce_util.go index 13fd658ce9..71c556f8ee 100644 --- a/internal/proxy/search_reduce_util.go +++ b/internal/proxy/search_reduce_util.go @@ -5,12 +5,16 @@ import ( "fmt" "github.com/cockroachdb/errors" + "go.opentelemetry.io/otel" "go.uber.org/zap" + "google.golang.org/protobuf/proto" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/util/reduce" "github.com/milvus-io/milvus/pkg/v2/log" + "github.com/milvus-io/milvus/pkg/v2/proto/internalpb" + "github.com/milvus-io/milvus/pkg/v2/proto/planpb" "github.com/milvus-io/milvus/pkg/v2/util/merr" "github.com/milvus-io/milvus/pkg/v2/util/metric" "github.com/milvus-io/milvus/pkg/v2/util/paramtable" @@ -470,3 +474,105 @@ func fillInEmptyResult(numQueries int64) *milvuspb.SearchResults { }, } } + +func reduceResults(ctx context.Context, toReduceResults []*internalpb.SearchResults, nq, topK, offset int64, metricType string, pkType schemapb.DataType, queryInfo *planpb.QueryInfo, isAdvance bool, collectionID int64, partitionIDs []int64) (*milvuspb.SearchResults, error) { + ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "reduceResults") + defer sp.End() + + log := log.Ctx(ctx) + // Decode all search results + validSearchResults, err := decodeSearchResults(ctx, toReduceResults) + if err != nil { + log.Warn("failed to decode search results", zap.Error(err)) + return nil, err + } + + if len(validSearchResults) <= 0 { + return fillInEmptyResult(nq), nil + } + + // Reduce all search results + log.Debug("proxy search post execute reduce", + zap.Int64("collection", collectionID), + zap.Int64s("partitionIDs", partitionIDs), + zap.Int("number of valid search results", len(validSearchResults))) + var result *milvuspb.SearchResults + result, err = reduceSearchResult(ctx, validSearchResults, reduce.NewReduceSearchResultInfo(nq, topK).WithMetricType(metricType).WithPkType(pkType). 
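+		// the reduce options mirror the request: metric/PK type drive scoring and
+		// ID decoding, offset and the group-by settings drive pagination and
+		// grouping, and the advance flag marks hybrid sub-results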
+		WithOffset(offset).WithGroupByField(queryInfo.GetGroupByFieldId()).WithGroupSize(queryInfo.GetGroupSize()).WithAdvance(isAdvance))
+	if err != nil {
+		log.Warn("failed to reduce search results", zap.Error(err))
+		return nil, err
+	}
+	return result, nil
+}
+
+func decodeSearchResults(ctx context.Context, searchResults []*internalpb.SearchResults) ([]*schemapb.SearchResultData, error) {
+	ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "decodeSearchResults")
+	defer sp.End()
+	tr := timerecord.NewTimeRecorder("decodeSearchResults")
+	results := make([]*schemapb.SearchResultData, 0)
+	for _, partialSearchResult := range searchResults {
+		if partialSearchResult.SlicedBlob == nil {
+			continue
+		}
+
+		var partialResultData schemapb.SearchResultData
+		err := proto.Unmarshal(partialSearchResult.SlicedBlob, &partialResultData)
+		if err != nil {
+			return nil, err
+		}
+		results = append(results, &partialResultData)
+	}
+	tr.CtxElapse(ctx, "decodeSearchResults done")
+	return results, nil
+}
+
+func checkSearchResultData(data *schemapb.SearchResultData, nq int64, topk int64, pkHitNum int) error {
+	if data.NumQueries != nq {
+		return fmt.Errorf("search result's nq(%d) does not match %d", data.NumQueries, nq)
+	}
+	if data.TopK != topk {
+		return fmt.Errorf("search result's topk(%d) does not match %d", data.TopK, topk)
+	}
+
+	if len(data.Scores) != pkHitNum {
+		return fmt.Errorf("search result's score length invalid, score length=%d, expectedLength=%d",
+			len(data.Scores), pkHitNum)
+	}
+	return nil
+}
+
+func selectHighestScoreIndex(ctx context.Context, subSearchResultData []*schemapb.SearchResultData, subSearchNqOffset [][]int64, cursors []int64, qi int64) (int, int64) {
+	var (
+		subSearchIdx        = -1
+		resultDataIdx int64 = -1
+	)
+	maxScore := minFloat32
+	for i := range cursors {
+		if cursors[i] >= subSearchResultData[i].Topks[qi] {
+			continue
+		}
+		sIdx := subSearchNqOffset[i][qi] + cursors[i]
+		sScore := subSearchResultData[i].Scores[sIdx]
+
+		// Choose the larger score idx or the smaller pk idx with the same score
+		if subSearchIdx == -1 || sScore > maxScore {
+			subSearchIdx = i
+			resultDataIdx = sIdx
+			maxScore = sScore
+		} else if sScore == maxScore {
+			if subSearchIdx == -1 {
+				// A bad case happens where Knowhere returns distance/score == +/-maxFloat32
+				// by mistake.
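+				// Defensive check: the first branch above already claims
+				// subSearchIdx == -1, so this path should be unreachable and only
+				// flags a bad score if it is ever hit.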
+ log.Ctx(ctx).Error("a bad score is returned, something is wrong here!", zap.Float32("score", sScore)) + } else if typeutil.ComparePK( + typeutil.GetPK(subSearchResultData[i].GetIds(), sIdx), + typeutil.GetPK(subSearchResultData[subSearchIdx].GetIds(), resultDataIdx)) { + subSearchIdx = i + resultDataIdx = sIdx + maxScore = sScore + } + } + } + return subSearchIdx, resultDataIdx +} diff --git a/internal/proxy/search_util.go b/internal/proxy/search_util.go index 706e372dcb..97bd656e72 100644 --- a/internal/proxy/search_util.go +++ b/internal/proxy/search_util.go @@ -15,6 +15,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/pkg/v2/common" + "github.com/milvus-io/milvus/pkg/v2/proto/internalpb" "github.com/milvus-io/milvus/pkg/v2/proto/planpb" "github.com/milvus-io/milvus/pkg/v2/util/funcutil" "github.com/milvus-io/milvus/pkg/v2/util/merr" @@ -577,3 +578,11 @@ func convertHybridSearchToSearch(req *milvuspb.HybridSearchRequest) *milvuspb.Se } return ret } + +func getMetricType(toReduceResults []*internalpb.SearchResults) string { + metricType := "" + if len(toReduceResults) >= 1 { + metricType = toReduceResults[0].GetMetricType() + } + return metricType +} diff --git a/internal/proxy/task_search.go b/internal/proxy/task_search.go index f756faa6d8..3182898b81 100644 --- a/internal/proxy/task_search.go +++ b/internal/proxy/task_search.go @@ -10,7 +10,6 @@ import ( "github.com/cockroachdb/errors" "github.com/samber/lo" "go.opentelemetry.io/otel" - "go.opentelemetry.io/otel/trace" "go.uber.org/zap" "google.golang.org/protobuf/proto" @@ -22,7 +21,6 @@ import ( "github.com/milvus-io/milvus/internal/util/exprutil" "github.com/milvus-io/milvus/internal/util/function" "github.com/milvus-io/milvus/internal/util/function/rerank" - "github.com/milvus-io/milvus/internal/util/reduce" "github.com/milvus-io/milvus/pkg/v2/log" "github.com/milvus-io/milvus/pkg/v2/metrics" "github.com/milvus-io/milvus/pkg/v2/proto/internalpb" @@ -96,9 +94,6 @@ type searchTask struct { // we always remove pk field from output fields, as search result already contains pk field. // if the user explicitly set pk field in output fields, we add it back to the result. 
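// For example, output_fields=["pk", "vec"] is narrowed internally so that only "vec" is fetched, and "pk" is restored in the response because this flag records the explicit request (illustrative field names).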
userRequestedPkFieldExplicitly bool - - // To facilitate writing unit tests - requeryFunc func(t *searchTask, span trace.Span, ids *schemapb.IDs, outputFields []string) (*milvuspb.QueryResults, error) } func (t *searchTask) CanSkipAllocTimestamp() bool { @@ -488,7 +483,6 @@ func (t *searchTask) initAdvancedSearchRequest(ctx context.Context) error { t.SearchRequest.GroupByFieldId = t.rankParams.GetGroupByFieldId() t.SearchRequest.GroupSize = t.rankParams.GetGroupSize() - // used for requery if t.partitionKeyMode { t.SearchRequest.PartitionIDs = t.partitionIDsSet.Collect() } @@ -496,57 +490,6 @@ func (t *searchTask) initAdvancedSearchRequest(ctx context.Context) error { return nil } -func (t *searchTask) advancedPostProcess(ctx context.Context, span trace.Span, toReduceResults []*internalpb.SearchResults) error { - // Collecting the results of a subsearch - // [[shard1, shard2, ...],[shard1, shard2, ...]] - multipleInternalResults := make([][]*internalpb.SearchResults, len(t.SearchRequest.GetSubReqs())) - for _, searchResult := range toReduceResults { - // if get a non-advanced result, skip all - if !searchResult.GetIsAdvanced() { - continue - } - for _, subResult := range searchResult.GetSubResults() { - // swallow copy - internalResults := &internalpb.SearchResults{ - MetricType: subResult.GetMetricType(), - NumQueries: subResult.GetNumQueries(), - TopK: subResult.GetTopK(), - SlicedBlob: subResult.GetSlicedBlob(), - SlicedNumCount: subResult.GetSlicedNumCount(), - SlicedOffset: subResult.GetSlicedOffset(), - IsAdvanced: false, - } - reqIndex := subResult.GetReqIndex() - multipleInternalResults[reqIndex] = append(multipleInternalResults[reqIndex], internalResults) - } - } - - multipleMilvusResults := make([]*milvuspb.SearchResults, len(t.SearchRequest.GetSubReqs())) - searchMetrics := []string{} - for index, internalResults := range multipleInternalResults { - subReq := t.SearchRequest.GetSubReqs()[index] - // Since the metrictype in the request may be empty, it can only be obtained from the result - subMetricType := getMetricType(internalResults) - result, err := t.reduceResults(t.ctx, internalResults, subReq.GetNq(), subReq.GetTopk(), subReq.GetOffset(), subMetricType, t.queryInfos[index], true) - if err != nil { - return err - } - - searchMetrics = append(searchMetrics, subMetricType) - multipleMilvusResults[index] = result - } - - if err := t.hybridSearchRank(ctx, span, multipleMilvusResults, searchMetrics); err != nil { - return err - } - - t.result.Results.FieldsData = lo.Filter(t.result.Results.FieldsData, func(fieldData *schemapb.FieldData, i int) bool { - return lo.Contains(t.translatedOutputFields, fieldData.GetFieldName()) - }) - t.fillResult() - return nil -} - func (t *searchTask) fillResult() { limit := t.SearchRequest.GetTopk() - t.SearchRequest.GetOffset() resultSizeInsufficient := false @@ -558,111 +501,6 @@ func (t *searchTask) fillResult() { } t.resultSizeInsufficient = resultSizeInsufficient t.result.CollectionName = t.collectionName - t.fillInFieldInfo() -} - -func mergeIDs(idsList []*schemapb.IDs) (*schemapb.IDs, int) { - uniqueIDs := &schemapb.IDs{} - int64IDs := typeutil.NewSet[int64]() - strIDs := typeutil.NewSet[string]() - - for _, ids := range idsList { - if ids == nil { - continue - } - switch ids.GetIdField().(type) { - case *schemapb.IDs_IntId: - int64IDs.Insert(ids.GetIntId().GetData()...) - case *schemapb.IDs_StrId: - strIDs.Insert(ids.GetStrId().GetData()...) 
- } - } - - if int64IDs.Len() > 0 { - uniqueIDs.IdField = &schemapb.IDs_IntId{ - IntId: &schemapb.LongArray{ - Data: int64IDs.Collect(), - }, - } - return uniqueIDs, int64IDs.Len() - } - - if strIDs.Len() > 0 { - uniqueIDs.IdField = &schemapb.IDs_StrId{ - StrId: &schemapb.StringArray{ - Data: strIDs.Collect(), - }, - } - return uniqueIDs, strIDs.Len() - } - - return nil, 0 -} - -func (t *searchTask) hybridSearchRank(ctx context.Context, span trace.Span, multipleMilvusResults []*milvuspb.SearchResults, searchMetrics []string) error { - var err error - processRerank := func(ctx context.Context, results []*milvuspb.SearchResults) (*milvuspb.SearchResults, error) { - ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-call-rerank-function-udf") - defer sp.End() - groupScorerStr := getGroupScorerStr(t.request.GetSearchParams()) - params := rerank.NewSearchParams( - t.Nq, t.rankParams.limit, t.rankParams.offset, t.rankParams.roundDecimal, - t.rankParams.groupByFieldId, t.rankParams.groupSize, t.rankParams.strictGroupSize, groupScorerStr, searchMetrics, - ) - return t.functionScore.Process(ctx, params, results) - } - - // The first step of hybrid search is without meta information. If rerank requires meta data, we need to do requery. - // At this time, outputFields and rerank input_fields will be recalled. - // If we want to save memory, we can only recall the rerank input_fields in this step, and recall the output_fields in the third step - if t.needRequery { - idsList := lo.FilterMap(multipleMilvusResults, func(m *milvuspb.SearchResults, _ int) (*schemapb.IDs, bool) { - return m.Results.Ids, true - }) - allIDs, count := mergeIDs(idsList) - if count == 0 { - t.result = &milvuspb.SearchResults{ - Status: merr.Success(), - Results: &schemapb.SearchResultData{ - NumQueries: t.Nq, - TopK: t.rankParams.limit, - FieldsData: make([]*schemapb.FieldData, 0), - Scores: []float32{}, - Ids: &schemapb.IDs{}, - Topks: []int64{}, - }, - } - return nil - } - allNames := typeutil.NewSet[string](t.translatedOutputFields...) - allNames.Insert(t.functionScore.GetAllInputFieldNames()...) 
- queryResult, err := t.requeryFunc(t, span, allIDs, allNames.Collect()) - if err != nil { - log.Warn("failed to requery", zap.Error(err)) - return err - } - fields, err := t.reorganizeRequeryResults(ctx, queryResult.GetFieldsData(), idsList) - if err != nil { - return err - } - for i := 0; i < len(multipleMilvusResults); i++ { - multipleMilvusResults[i].Results.FieldsData = fields[i] - } - - if t.result, err = processRerank(ctx, multipleMilvusResults); err != nil { - return err - } - if fields, err := t.reorganizeRequeryResults(ctx, queryResult.GetFieldsData(), []*schemapb.IDs{t.result.Results.Ids}); err != nil { - return err - } else { - t.result.Results.FieldsData = fields[0] - } - } else { - if t.result, err = processRerank(ctx, multipleMilvusResults); err != nil { - return err - } - } - return nil } func (t *searchTask) initSearchRequest(ctx context.Context) error { @@ -677,19 +515,13 @@ func (t *searchTask) initSearchRequest(ctx context.Context) error { } if t.request.FunctionScore != nil { - // TODO: When rerank is configured, range search is also supported - if isIterator { - return merr.WrapErrParameterInvalidMsg("Range search do not support rerank") - } - if t.functionScore, err = rerank.NewFunctionScore(t.schema.CollectionSchema, t.request.FunctionScore); err != nil { log.Warn("Failed to create function score", zap.Error(err)) return err } - // TODO: When rerank is configured, grouping search is also supported if !t.functionScore.IsSupportGroup() && queryInfo.GetGroupByFieldId() > 0 { - return merr.WrapErrParameterInvalidMsg("Current rerank does not support grouping search") + return merr.WrapErrParameterInvalidMsg("Rerank %s does not support grouping search", t.functionScore.RerankName()) } } @@ -773,57 +605,6 @@ func (t *searchTask) initSearchRequest(ctx context.Context) error { return nil } -func (t *searchTask) searchPostProcess(ctx context.Context, span trace.Span, toReduceResults []*internalpb.SearchResults) error { - metricType := getMetricType(toReduceResults) - result, err := t.reduceResults(t.ctx, toReduceResults, t.SearchRequest.GetNq(), t.SearchRequest.GetTopk(), t.SearchRequest.GetOffset(), metricType, t.queryInfos[0], false) - if err != nil { - return err - } - - if t.functionScore != nil && len(result.Results.FieldsData) != 0 { - { - ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-call-rerank-function-udf") - defer sp.End() - groupScorerStr := getGroupScorerStr(t.request.GetSearchParams()) - params := rerank.NewSearchParams(t.Nq, t.SearchRequest.GetTopk(), t.SearchRequest.GetOffset(), - t.queryInfos[0].RoundDecimal, t.queryInfos[0].GroupByFieldId, t.queryInfos[0].GroupSize, t.queryInfos[0].StrictGroupSize, groupScorerStr, []string{metricType}) - // rank only returns id and score - if t.result, err = t.functionScore.Process(ctx, params, []*milvuspb.SearchResults{result}); err != nil { - return err - } - } - if !t.needRequery { - fields, err := t.reorganizeRequeryResults(ctx, result.Results.FieldsData, []*schemapb.IDs{t.result.Results.Ids}) - if err != nil { - return err - } - t.result.Results.FieldsData = fields[0] - } - } else { - t.result = result - } - t.fillResult() - if t.needRequery { - if t.requeryFunc == nil { - t.requeryFunc = requeryImpl - } - queryResult, err := t.requeryFunc(t, span, t.result.Results.Ids, t.translatedOutputFields) - if err != nil { - log.Warn("failed to requery", zap.Error(err)) - return err - } - fields, err := t.reorganizeRequeryResults(ctx, queryResult.GetFieldsData(), []*schemapb.IDs{t.result.Results.Ids}) - if err != 
nil { - return err - } - t.result.Results.FieldsData = fields[0] - } - t.result.Results.FieldsData = lo.Filter(t.result.Results.FieldsData, func(fieldData *schemapb.FieldData, i int) bool { - return lo.Contains(t.translatedOutputFields, fieldData.GetFieldName()) - }) - return nil -} - func (t *searchTask) tryGeneratePlan(params []*commonpb.KeyValuePair, dsl string, exprTemplateValues map[string]*schemapb.TemplateValue) (*planpb.PlanNode, *planpb.QueryInfo, int64, bool, error) { annsFieldName, err := funcutil.GetAttrByKeyFromRepeatedKV(AnnsFieldKey, params) if err != nil || len(annsFieldName) == 0 { @@ -919,50 +700,6 @@ func (t *searchTask) Execute(ctx context.Context) error { return nil } -func getMetricType(toReduceResults []*internalpb.SearchResults) string { - metricType := "" - if len(toReduceResults) >= 1 { - metricType = toReduceResults[0].GetMetricType() - } - return metricType -} - -func (t *searchTask) reduceResults(ctx context.Context, toReduceResults []*internalpb.SearchResults, nq, topK int64, offset int64, metricType string, queryInfo *planpb.QueryInfo, isAdvance bool) (*milvuspb.SearchResults, error) { - ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "reduceResults") - defer sp.End() - - log := log.Ctx(ctx) - // Decode all search results - validSearchResults, err := decodeSearchResults(ctx, toReduceResults) - if err != nil { - log.Warn("failed to decode search results", zap.Error(err)) - return nil, err - } - - if len(validSearchResults) <= 0 { - return fillInEmptyResult(nq), nil - } - - // Reduce all search results - log.Debug("proxy search post execute reduce", - zap.Int64("collection", t.GetCollectionID()), - zap.Int64s("partitionIDs", t.GetPartitionIDs()), - zap.Int("number of valid search results", len(validSearchResults))) - primaryFieldSchema, err := t.schema.GetPkField() - if err != nil { - log.Warn("failed to get primary field schema", zap.Error(err)) - return nil, err - } - var result *milvuspb.SearchResults - result, err = reduceSearchResult(ctx, validSearchResults, reduce.NewReduceSearchResultInfo(nq, topK).WithMetricType(metricType).WithPkType(primaryFieldSchema.GetDataType()). 
- WithOffset(offset).WithGroupByField(queryInfo.GetGroupByFieldId()).WithGroupSize(queryInfo.GetGroupSize()).WithAdvance(isAdvance)) - if err != nil { - log.Warn("failed to reduce search results", zap.Error(err)) - return nil, err - } - return result, nil -} - // find the last bound based on reduced results and metric type // only support nq == 1, for search iterator v2 func getLastBound(result *milvuspb.SearchResults, incomingLastBound *float32, metricType string) float32 { @@ -1017,15 +754,16 @@ func (t *searchTask) PostExecute(ctx context.Context) error { t.isTopkReduce = isTopkReduce t.isRecallEvaluation = isRecallEvaluation - if t.SearchRequest.GetIsAdvanced() { - err = t.advancedPostProcess(ctx, sp, toReduceResults) - } else { - err = t.searchPostProcess(ctx, sp, toReduceResults) - } - + // run the post-process pipeline built for this task + pipeline, err := newBuiltInPipeline(t) if err != nil { + log.Warn("Failed to create post process pipeline", zap.Error(err)) return err } + if t.result, err = pipeline.Run(ctx, sp, toReduceResults); err != nil { + return err + } + t.fillResult() t.result.Results.OutputFields = t.userOutputFields t.result.CollectionName = t.request.GetCollectionName() @@ -1143,161 +881,6 @@ func (t *searchTask) estimateResultSize(nq int64, topK int64) (int64, error) { //return int64(sizePerRecord) * nq * topK, nil } -func requeryImpl(t *searchTask, span trace.Span, ids *schemapb.IDs, outputFields []string) (*milvuspb.QueryResults, error) { - queryReq := &milvuspb.QueryRequest{ - Base: &commonpb.MsgBase{ - MsgType: commonpb.MsgType_Retrieve, - Timestamp: t.BeginTs(), - }, - DbName: t.request.GetDbName(), - CollectionName: t.request.GetCollectionName(), - ConsistencyLevel: t.SearchRequest.GetConsistencyLevel(), - NotReturnAllMeta: t.request.GetNotReturnAllMeta(), - Expr: "", - OutputFields: outputFields, - PartitionNames: t.request.GetPartitionNames(), - UseDefaultConsistency: false, - GuaranteeTimestamp: t.SearchRequest.GuaranteeTimestamp, - } - pkField, err := typeutil.GetPrimaryFieldSchema(t.schema.CollectionSchema) - if err != nil { - return nil, err - } - - plan := planparserv2.CreateRequeryPlan(pkField, ids) - channelsMvcc := make(map[string]Timestamp) - for k, v := range t.queryChannelsTs { - channelsMvcc[k] = v - } - qt := &queryTask{ - ctx: t.ctx, - Condition: NewTaskCondition(t.ctx), - RetrieveRequest: &internalpb.RetrieveRequest{ - Base: commonpbutil.NewMsgBase( - commonpbutil.WithMsgType(commonpb.MsgType_Retrieve), - commonpbutil.WithSourceID(paramtable.GetNodeID()), - ), - ReqID: paramtable.GetNodeID(), - PartitionIDs: t.GetPartitionIDs(), // use search partitionIDs - ConsistencyLevel: t.ConsistencyLevel, - }, - request: queryReq, - plan: plan, - mixCoord: t.node.(*Proxy).mixCoord, - lb: t.node.(*Proxy).lbPolicy, - channelsMvcc: channelsMvcc, - fastSkip: true, - reQuery: true, - } - queryResult, err := t.node.(*Proxy).query(t.ctx, qt, span) - if err != nil { - return nil, err - } - if queryResult.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { - return nil, merr.Error(queryResult.GetStatus()) - } - return queryResult, err -} - -func isEmpty(ids *schemapb.IDs) bool { - if ids == nil { - return true - } - if ids.GetIntId() != nil && len(ids.GetIntId().Data) != 0 { - return false - } - - if ids.GetStrId() != nil && len(ids.GetStrId().Data) != 0 { - return false - } - return true -} - -func (t *searchTask) reorganizeRequeryResults(ctx context.Context, fields []*schemapb.FieldData, idsList []*schemapb.IDs) ([][]*schemapb.FieldData, error) { - _, sp := otel.Tracer(typeutil.ProxyRole).Start(t.ctx,
"reorganizeRequeryResults") - defer sp.End() - - pkField, err := typeutil.GetPrimaryFieldSchema(t.schema.CollectionSchema) - if err != nil { - return nil, err - } - pkFieldData, err := typeutil.GetPrimaryFieldData(fields, pkField) - if err != nil { - return nil, err - } - offsets := make(map[any]int) - pkItr := typeutil.GetDataIterator(pkFieldData) - for i := 0; i < typeutil.GetPKSize(pkFieldData); i++ { - pk := pkItr(i) - offsets[pk] = i - } - - allFieldData := make([][]*schemapb.FieldData, len(idsList)) - for idx, ids := range idsList { - if isEmpty(ids) { - emptyFields := []*schemapb.FieldData{} - for _, field := range fields { - emptyFields = append(emptyFields, &schemapb.FieldData{ - Type: field.Type, - FieldName: field.FieldName, - FieldId: field.FieldId, - IsDynamic: field.IsDynamic, - }) - } - allFieldData[idx] = emptyFields - continue - } - if fieldData, err := t.pickFieldData(ids, offsets, fields); err != nil { - return nil, err - } else { - allFieldData[idx] = fieldData - } - } - return allFieldData, nil -} - -// pick field data from query results -func (t *searchTask) pickFieldData(ids *schemapb.IDs, pkOffset map[any]int, fields []*schemapb.FieldData) ([]*schemapb.FieldData, error) { - // Reorganize Results. The order of query result ids will be altered and differ from queried ids. - // We should reorganize query results to keep the order of original queried ids. For example: - // =========================================== - // 3 2 5 4 1 (query ids) - // || - // || (query) - // \/ - // 4 3 5 1 2 (result ids) - // v4 v3 v5 v1 v2 (result vectors) - // || - // || (reorganize) - // \/ - // 3 2 5 4 1 (result ids) - // v3 v2 v5 v4 v1 (result vectors) - // =========================================== - fieldsData := make([]*schemapb.FieldData, len(fields)) - for i := 0; i < typeutil.GetSizeOfIDs(ids); i++ { - id := typeutil.GetPK(ids, int64(i)) - if _, ok := pkOffset[id]; !ok { - return nil, merr.WrapErrInconsistentRequery(fmt.Sprintf("incomplete query result, missing id %s, len(searchIDs) = %d, len(queryIDs) = %d, collection=%d", - id, typeutil.GetSizeOfIDs(ids), len(pkOffset), t.GetCollectionID())) - } - typeutil.AppendFieldData(fieldsData, fields, int64(pkOffset[id])) - } - - return fieldsData, nil -} - -func (t *searchTask) fillInFieldInfo() { - for _, retField := range t.result.Results.FieldsData { - for _, schemaField := range t.schema.Fields { - if retField != nil && retField.FieldId == schemaField.FieldID { - retField.FieldName = schemaField.Name - retField.Type = schemaField.DataType - retField.IsDynamic = schemaField.IsDynamic - } - } - } -} - func (t *searchTask) collectSearchResults(ctx context.Context) ([]*internalpb.SearchResults, error) { select { case <-t.TraceCtx().Done(): @@ -1316,77 +899,6 @@ func (t *searchTask) collectSearchResults(ctx context.Context) ([]*internalpb.Se } } -func decodeSearchResults(ctx context.Context, searchResults []*internalpb.SearchResults) ([]*schemapb.SearchResultData, error) { - ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "decodeSearchResults") - defer sp.End() - tr := timerecord.NewTimeRecorder("decodeSearchResults") - results := make([]*schemapb.SearchResultData, 0) - for _, partialSearchResult := range searchResults { - if partialSearchResult.SlicedBlob == nil { - continue - } - - var partialResultData schemapb.SearchResultData - err := proto.Unmarshal(partialSearchResult.SlicedBlob, &partialResultData) - if err != nil { - return nil, err - } - results = append(results, &partialResultData) - } - tr.CtxElapse(ctx, 
"decodeSearchResults done") - return results, nil -} - -func checkSearchResultData(data *schemapb.SearchResultData, nq int64, topk int64, pkHitNum int) error { - if data.NumQueries != nq { - return fmt.Errorf("search result's nq(%d) mis-match with %d", data.NumQueries, nq) - } - if data.TopK != topk { - return fmt.Errorf("search result's topk(%d) mis-match with %d", data.TopK, topk) - } - - if len(data.Scores) != pkHitNum { - return fmt.Errorf("search result's score length invalid, score length=%d, expectedLength=%d", - len(data.Scores), pkHitNum) - } - return nil -} - -func selectHighestScoreIndex(ctx context.Context, subSearchResultData []*schemapb.SearchResultData, subSearchNqOffset [][]int64, cursors []int64, qi int64) (int, int64) { - var ( - subSearchIdx = -1 - resultDataIdx int64 = -1 - ) - maxScore := minFloat32 - for i := range cursors { - if cursors[i] >= subSearchResultData[i].Topks[qi] { - continue - } - sIdx := subSearchNqOffset[i][qi] + cursors[i] - sScore := subSearchResultData[i].Scores[sIdx] - - // Choose the larger score idx or the smaller pk idx with the same score - if subSearchIdx == -1 || sScore > maxScore { - subSearchIdx = i - resultDataIdx = sIdx - maxScore = sScore - } else if sScore == maxScore { - if subSearchIdx == -1 { - // A bad case happens where Knowhere returns distance/score == +/-maxFloat32 - // by mistake. - log.Ctx(ctx).Error("a bad score is returned, something is wrong here!", zap.Float32("score", sScore)) - } else if typeutil.ComparePK( - typeutil.GetPK(subSearchResultData[i].GetIds(), sIdx), - typeutil.GetPK(subSearchResultData[subSearchIdx].GetIds(), resultDataIdx)) { - subSearchIdx = i - resultDataIdx = sIdx - maxScore = sScore - } - } - } - return subSearchIdx, resultDataIdx -} - func (t *searchTask) TraceCtx() context.Context { return t.ctx } diff --git a/internal/proxy/task_search_test.go b/internal/proxy/task_search_test.go index 3b694e4d15..ca705fad87 100644 --- a/internal/proxy/task_search_test.go +++ b/internal/proxy/task_search_test.go @@ -19,12 +19,12 @@ import ( "context" "fmt" "math" - "slices" "strconv" "strings" "testing" "time" + "github.com/bytedance/mockey" "github.com/cockroachdb/errors" "github.com/google/uuid" "github.com/samber/lo" @@ -32,7 +32,6 @@ import ( "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" - "go.opentelemetry.io/otel/trace" "google.golang.org/grpc" "google.golang.org/protobuf/proto" @@ -394,12 +393,10 @@ func TestSearchTask_PostExecute(t *testing.T) { f3 := testutils.GenerateScalarFieldData(schemapb.DataType_Int64, testInt64Field, 20) f3.FieldId = fieldNameId[testInt64Field] - qt.requeryFunc = func(t *searchTask, span trace.Span, ids *schemapb.IDs, outputFields []string) (*milvuspb.QueryResults, error) { - return &milvuspb.QueryResults{ - FieldsData: []*schemapb.FieldData{f1, f2, f3}, - PrimaryFieldName: testInt64Field, - }, nil - } + mocker := mockey.Mock((*requeryOperator).requery).Return(&milvuspb.QueryResults{ + FieldsData: []*schemapb.FieldData{f1, f2, f3}, + }, nil).Build() + defer mocker.UnPatch() err := qt.PostExecute(context.TODO()) assert.NoError(t, err) @@ -436,12 +433,10 @@ func TestSearchTask_PostExecute(t *testing.T) { f3 := testutils.GenerateScalarFieldData(schemapb.DataType_Int64, testInt64Field, 20) f3.FieldId = fieldNameId[testInt64Field] - qt.requeryFunc = func(t *searchTask, span trace.Span, ids *schemapb.IDs, outputFields []string) (*milvuspb.QueryResults, error) { - return &milvuspb.QueryResults{ - FieldsData: 
[]*schemapb.FieldData{f1, f2, f3}, - PrimaryFieldName: testInt64Field, - }, nil - } + mocker := mockey.Mock((*requeryOperator).requery).Return(&milvuspb.QueryResults{ + FieldsData: []*schemapb.FieldData{f1, f2, f3}, + }, nil).Build() + defer mocker.UnPatch() err := qt.PostExecute(context.TODO()) assert.NoError(t, err) @@ -478,12 +473,11 @@ func TestSearchTask_PostExecute(t *testing.T) { f3.FieldId = fieldNameId[testInt64Field] f4 := testutils.GenerateScalarFieldData(schemapb.DataType_Float, testFloatField, 20) f4.FieldId = fieldNameId[testFloatField] - qt.requeryFunc = func(t *searchTask, span trace.Span, ids *schemapb.IDs, outputFields []string) (*milvuspb.QueryResults, error) { - return &milvuspb.QueryResults{ - FieldsData: []*schemapb.FieldData{f1, f2, f3, f4}, - PrimaryFieldName: testInt64Field, - }, nil - } + mocker := mockey.Mock((*requeryOperator).requery).Return(&milvuspb.QueryResults{ + FieldsData: []*schemapb.FieldData{f1, f2, f3, f4}, + }, nil).Build() + defer mocker.UnPatch() + err := qt.PostExecute(context.TODO()) assert.NoError(t, err) assert.Equal(t, []int64{10, 10}, qt.result.Results.Topks) @@ -505,53 +499,6 @@ func TestSearchTask_PostExecute(t *testing.T) { } } }) - - t.Run("Test mergeIDs function", func(t *testing.T) { - { - ids1 := &schemapb.IDs{ - IdField: &schemapb.IDs_IntId{ - IntId: &schemapb.LongArray{ - Data: []int64{1, 2, 3, 5}, - }, - }, - } - - ids2 := &schemapb.IDs{ - IdField: &schemapb.IDs_IntId{ - IntId: &schemapb.LongArray{ - Data: []int64{1, 2, 4, 5, 6}, - }, - }, - } - allIDs, count := mergeIDs([]*schemapb.IDs{ids1, ids2}) - assert.Equal(t, count, 6) - sortedIds := allIDs.GetIntId().GetData() - slices.Sort(sortedIds) - assert.Equal(t, sortedIds, []int64{1, 2, 3, 4, 5, 6}) - } - { - ids1 := &schemapb.IDs{ - IdField: &schemapb.IDs_StrId{ - StrId: &schemapb.StringArray{ - Data: []string{"a", "b", "e"}, - }, - }, - } - - ids2 := &schemapb.IDs{ - IdField: &schemapb.IDs_StrId{ - StrId: &schemapb.StringArray{ - Data: []string{"a", "b", "c", "d"}, - }, - }, - } - allIDs, count := mergeIDs([]*schemapb.IDs{ids1, ids2}) - assert.Equal(t, count, 5) - sortedIds := allIDs.GetStrId().GetData() - slices.Sort(sortedIds) - assert.Equal(t, sortedIds, []string{"a", "b", "c", "d", "e"}) - } - }) } func createCollWithFields(t *testing.T, collName string, rc types.MixCoordClient) (*schemapb.CollectionSchema, map[string]int64) { @@ -3743,9 +3690,10 @@ func TestSearchTask_Requery(t *testing.T) { tr: timerecord.NewTimeRecorder("search"), node: node, translatedOutputFields: outputFields, - requeryFunc: requeryImpl, } - queryResult, err := qt.requeryFunc(qt, nil, qt.result.Results.Ids, outputFields) + op, err := newRequeryOperator(qt, nil) + assert.NoError(t, err) + queryResult, err := op.(*requeryOperator).requery(ctx, nil, qt.result.Results.Ids, outputFields) assert.NoError(t, err) assert.Len(t, queryResult.FieldsData, 2) for _, field := range qt.result.Results.FieldsData { @@ -3768,14 +3716,13 @@ func TestSearchTask_Requery(t *testing.T) { SourceID: paramtable.GetNodeID(), }, }, - request: &milvuspb.SearchRequest{}, - schema: schema, - tr: timerecord.NewTimeRecorder("search"), - node: node, - requeryFunc: requeryImpl, + request: &milvuspb.SearchRequest{}, + schema: schema, + tr: timerecord.NewTimeRecorder("search"), + node: node, } - _, err := qt.requeryFunc(qt, nil, &schemapb.IDs{}, []string{}) + _, err := newRequeryOperator(qt, nil) t.Logf("err = %s", err) assert.Error(t, err) }) @@ -3804,13 +3751,14 @@ func TestSearchTask_Requery(t *testing.T) { request: 
&milvuspb.SearchRequest{ CollectionName: collectionName, }, - schema: schema, - tr: timerecord.NewTimeRecorder("search"), - node: node, - requeryFunc: requeryImpl, + schema: schema, + tr: timerecord.NewTimeRecorder("search"), + node: node, } - _, err := qt.requeryFunc(qt, nil, &schemapb.IDs{}, []string{}) + op, err := newRequeryOperator(qt, nil) + assert.NoError(t, err) + _, err = op.(*requeryOperator).requery(ctx, nil, &schemapb.IDs{}, []string{}) t.Logf("err = %s", err) assert.Error(t, err) }) diff --git a/internal/util/credentials/credentials.go b/internal/util/credentials/credentials.go index 4166586c64..8f92f221b0 100644 --- a/internal/util/credentials/credentials.go +++ b/internal/util/credentials/credentials.go @@ -34,7 +34,7 @@ const ( // The current version only supports plain text, and cipher text will be supported later. type Credentials struct { // key formats: - // {credentialName}.api_key + // {credentialName}.apikey + // {credentialName}.access_key_id // {credentialName}.secret_access_key // {credentialName}.credential_json diff --git a/internal/util/function/rerank/function_score.go b/internal/util/function/rerank/function_score.go index 174034f173..9296fa89c6 100644 --- a/internal/util/function/rerank/function_score.go +++ b/internal/util/function/rerank/function_score.go @@ -234,7 +234,7 @@ func (fScore *FunctionScore) Process(ctx context.Context, searchParams *SearchPa FieldsData: make([]*schemapb.FieldData, 0), Scores: []float32{}, Ids: &schemapb.IDs{}, - Topks: []int64{}, + Topks: make([]int64, searchParams.nq), }, }, nil } @@ -280,3 +280,10 @@ func (fScore *FunctionScore) IsSupportGroup() bool { } return fScore.reranker.IsSupportGroup() } + +func (fScore *FunctionScore) RerankName() string { + if fScore == nil { + return "" + } + return fScore.reranker.GetRankName() +} diff --git a/internal/util/function/rerank/function_score_test.go b/internal/util/function/rerank/function_score_test.go index 09cb7b2d01..5a01c5b9c6 100644 --- a/internal/util/function/rerank/function_score_test.go +++ b/internal/util/function/rerank/function_score_test.go @@ -322,7 +322,7 @@ func (s *FunctionScoreSuite) TestlegacyFunction() { rankParams := []*commonpb.KeyValuePair{} f, err := NewFunctionScoreWithlegacy(schema, rankParams) s.NoError(err) - s.Equal(f.reranker.GetRankName(), rrfName) + s.Equal(f.RerankName(), rrfName) } { rankParams := []*commonpb.KeyValuePair{ diff --git a/tests/python_client/milvus_client/test_milvus_client_hybrid_search.py b/tests/python_client/milvus_client/test_milvus_client_hybrid_search.py index ddd1ffbe51..1b705fc8c4 100644 --- a/tests/python_client/milvus_client/test_milvus_client_hybrid_search.py +++ b/tests/python_client/milvus_client/test_milvus_client_hybrid_search.py @@ -289,8 +289,14 @@ class TestMilvusClientHybridSearchInvalid(TestMilvusClientV2Base): collection_name = cf.gen_unique_str(prefix) # 1. create collection self.create_collection(client, collection_name, default_dim) - # 2. hybrid search + # 2. insert rng = np.random.default_rng(seed=19530) + rows = [ + {default_primary_key_field_name: i, default_vector_field_name: list(rng.random((1, default_dim))[0]), + default_vector_field_name+"new": list(rng.random((1, default_dim))[0]), + default_string_field_name: str(i)} for i in range(default_nb)] + self.insert(client, collection_name, rows) + # 3. hybrid search vectors_to_search = rng.random((1, default_dim)) sub_search1 = AnnSearchRequest(vectors_to_search, "vector", {"level": 1}, 20, expr="id<100") ranker = WeightedRanker(0.2, 0.8)
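# note: rows are inserted above so that the hybrid search runs against a non-empty collection instead of returning early with an empty result set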