diff --git a/internal/proxy/search_pipeline.go b/internal/proxy/search_pipeline.go index 0a7af3a4be..f834ab97bb 100644 --- a/internal/proxy/search_pipeline.go +++ b/internal/proxy/search_pipeline.go @@ -110,11 +110,17 @@ const ( rerankOp = "rerank" requeryOp = "requery" organizeOp = "organize" - filterFieldOp = "filter_field" + endOp = "end" lambdaOp = "lambda" highlightOp = "highlight" ) +const ( + pipelineOutput = "output" + pipelineInput = "input" + pipelineStorageCost = "storage_cost" +) + var opFactory = map[string]func(t *searchTask, params map[string]any) (operator, error){ searchReduceOp: newSearchReduceOperator, hybridSearchReduceOp: newHybridSearchReduceOperator, @@ -122,7 +128,7 @@ var opFactory = map[string]func(t *searchTask, params map[string]any) (operator, organizeOp: newOrganizeOperator, requeryOp: newRequeryOperator, lambdaOp: newLambdaOperator, - filterFieldOp: newFilterFieldOperator, + endOp: newEndOperator, highlightOp: newHighlightOperator, } @@ -565,19 +571,19 @@ func (op *lambdaOperator) run(ctx context.Context, span trace.Span, inputs ...an return op.f(ctx, span, inputs...) } -type filterFieldOperator struct { +type endOperator struct { outputFieldNames []string fieldSchemas []*schemapb.FieldSchema } -func newFilterFieldOperator(t *searchTask, _ map[string]any) (operator, error) { - return &filterFieldOperator{ +func newEndOperator(t *searchTask, _ map[string]any) (operator, error) { + return &endOperator{ outputFieldNames: t.translatedOutputFields, fieldSchemas: typeutil.GetAllFieldSchemas(t.schema.CollectionSchema), }, nil } -func (op *filterFieldOperator) run(ctx context.Context, span trace.Span, inputs ...any) ([]any, error) { +func (op *endOperator) run(ctx context.Context, span trace.Span, inputs ...any) ([]any, error) { result := inputs[0].(*milvuspb.SearchResults) for _, retField := range result.Results.FieldsData { for _, fieldSchema := range op.fieldSchemas { @@ -591,6 +597,8 @@ func (op *filterFieldOperator) run(ctx context.Context, span trace.Span, inputs result.Results.FieldsData = lo.Filter(result.Results.FieldsData, func(field *schemapb.FieldData, _ int) bool { return lo.Contains(op.outputFieldNames, field.FieldName) }) + allSearchCount := aggregatedAllSearchCount(inputs[1].([]*milvuspb.SearchResults)) + result.GetResults().AllSearchCount = allSearchCount return []any{result}, nil } @@ -668,8 +676,8 @@ func (p *pipeline) AddNodes(t *searchTask, nodes ...*nodeDef) error { func (p *pipeline) Run(ctx context.Context, span trace.Span, toReduceResults []*internalpb.SearchResults, storageCost segcore.StorageCost) (*milvuspb.SearchResults, segcore.StorageCost, error) { log.Ctx(ctx).Debug("SearchPipeline run", zap.String("pipeline", p.String())) msg := opMsg{} - msg["input"] = toReduceResults - msg["storage_cost"] = storageCost + msg[pipelineInput] = toReduceResults + msg[pipelineStorageCost] = storageCost for _, node := range p.nodes { var err error log.Ctx(ctx).Debug("SearchPipeline run node", zap.String("node", node.name)) @@ -679,7 +687,7 @@ func (p *pipeline) Run(ctx context.Context, span trace.Span, toReduceResults []* return nil, storageCost, err } } - return msg["output"].(*milvuspb.SearchResults), msg["storage_cost"].(segcore.StorageCost), nil + return msg[pipelineOutput].(*milvuspb.SearchResults), msg[pipelineStorageCost].(segcore.StorageCost), nil } func (p *pipeline) String() string { @@ -695,17 +703,17 @@ type pipelineDef struct { nodes []*nodeDef } -var filterFieldNode = &nodeDef{ +var endNode = &nodeDef{ name: "filter_field", - inputs: 
[]string{"result"}, - outputs: []string{"output"}, - opName: filterFieldOp, + inputs: []string{"result", "reduced"}, + outputs: []string{pipelineOutput}, + opName: endOp, } var highlightNode = &nodeDef{ name: "highlight", inputs: []string{"result"}, - outputs: []string{"output"}, + outputs: []string{pipelineOutput}, opName: highlightOp, } @@ -714,7 +722,7 @@ var searchPipe = &pipelineDef{ nodes: []*nodeDef{ { name: "reduce", - inputs: []string{"input", "storage_cost"}, + inputs: []string{pipelineInput, pipelineStorageCost}, outputs: []string{"reduced", "metrics"}, opName: searchReduceOp, }, @@ -738,7 +746,7 @@ var searchWithRequeryPipe = &pipelineDef{ nodes: []*nodeDef{ { name: "reduce", - inputs: []string{"input", "storage_cost"}, + inputs: []string{pipelineInput, pipelineStorageCost}, outputs: []string{"reduced", "metrics"}, opName: searchReduceOp, }, @@ -753,8 +761,8 @@ var searchWithRequeryPipe = &pipelineDef{ }, { name: "requery", - inputs: []string{"unique_ids", "storage_cost"}, - outputs: []string{"fields", "storage_cost"}, + inputs: []string{"unique_ids", pipelineStorageCost}, + outputs: []string{"fields", pipelineStorageCost}, opName: requeryOp, }, { @@ -796,7 +804,7 @@ var searchWithRerankPipe = &pipelineDef{ nodes: []*nodeDef{ { name: "reduce", - inputs: []string{"input", "storage_cost"}, + inputs: []string{pipelineInput, pipelineStorageCost}, outputs: []string{"reduced", "metrics"}, opName: searchReduceOp, }, @@ -848,7 +856,7 @@ var searchWithRerankRequeryPipe = &pipelineDef{ nodes: []*nodeDef{ { name: "reduce", - inputs: []string{"input", "storage_cost"}, + inputs: []string{pipelineInput, pipelineStorageCost}, outputs: []string{"reduced", "metrics"}, opName: searchReduceOp, }, @@ -873,8 +881,8 @@ var searchWithRerankRequeryPipe = &pipelineDef{ }, { name: "requery", - inputs: []string{"ids", "storage_cost"}, - outputs: []string{"fields", "storage_cost"}, + inputs: []string{"ids", pipelineStorageCost}, + outputs: []string{"fields", pipelineStorageCost}, opName: requeryOp, }, { @@ -916,7 +924,7 @@ var hybridSearchPipe = &pipelineDef{ nodes: []*nodeDef{ { name: "reduce", - inputs: []string{"input", "storage_cost"}, + inputs: []string{pipelineInput, pipelineStorageCost}, outputs: []string{"reduced", "metrics"}, opName: hybridSearchReduceOp, }, @@ -934,7 +942,7 @@ var hybridSearchWithRequeryAndRerankByFieldDataPipe = &pipelineDef{ nodes: []*nodeDef{ { name: "reduce", - inputs: []string{"input", "storage_cost"}, + inputs: []string{pipelineInput, pipelineStorageCost}, outputs: []string{"reduced", "metrics"}, opName: hybridSearchReduceOp, }, @@ -949,8 +957,8 @@ var hybridSearchWithRequeryAndRerankByFieldDataPipe = &pipelineDef{ }, { name: "requery", - inputs: []string{"ids", "storage_cost"}, - outputs: []string{"fields", "storage_cost"}, + inputs: []string{"ids", pipelineStorageCost}, + outputs: []string{"fields", pipelineStorageCost}, opName: requeryOp, }, { @@ -1035,7 +1043,7 @@ var hybridSearchWithRequeryPipe = &pipelineDef{ nodes: []*nodeDef{ { name: "reduce", - inputs: []string{"input", "storage_cost"}, + inputs: []string{pipelineInput, pipelineStorageCost}, outputs: []string{"reduced", "metrics"}, opName: hybridSearchReduceOp, }, @@ -1060,8 +1068,8 @@ var hybridSearchWithRequeryPipe = &pipelineDef{ }, { name: "requery", - inputs: []string{"ids", "storage_cost"}, - outputs: []string{"fields", "storage_cost"}, + inputs: []string{"ids", pipelineStorageCost}, + outputs: []string{"fields", pipelineStorageCost}, opName: requeryOp, }, { @@ -1086,9 +1094,9 @@ var 
hybridSearchWithRequeryPipe = &pipelineDef{ }, { name: "filter_field", - inputs: []string{"result"}, - outputs: []string{"output"}, - opName: filterFieldOp, + inputs: []string{"result", "reduced"}, + outputs: []string{pipelineOutput}, + opName: endOp, }, }, } @@ -1131,15 +1139,25 @@ func newSearchPipeline(t *searchTask) (*pipeline, error) { } if t.highlighter != nil { - err := p.AddNodes(t, highlightNode, filterFieldNode) + err := p.AddNodes(t, highlightNode, endNode) if err != nil { return nil, err } } else { - err := p.AddNodes(t, filterFieldNode) + err := p.AddNodes(t, endNode) if err != nil { return nil, err } } return p, nil } + +func aggregatedAllSearchCount(searchResults []*milvuspb.SearchResults) int64 { + allSearchCount := int64(0) + for _, sr := range searchResults { + if sr != nil && sr.GetResults() != nil { + allSearchCount += sr.GetResults().GetAllSearchCount() + } + } + return allSearchCount +} diff --git a/internal/proxy/search_pipeline_test.go b/internal/proxy/search_pipeline_test.go index d175b7bdce..bd618fcdbe 100644 --- a/internal/proxy/search_pipeline_test.go +++ b/internal/proxy/search_pipeline_test.go @@ -394,7 +394,7 @@ func (s *SearchPipelineSuite) TestSearchPipeline() { pipeline, err := newPipeline(searchPipe, task) s.NoError(err) - pipeline.AddNodes(task, filterFieldNode) + pipeline.AddNodes(task, endNode) sr := genTestSearchResultData(2, 10, schemapb.DataType_Int64, "intField", 101, false) results, storageCost, err := pipeline.Run(context.Background(), s.span, []*internalpb.SearchResults{sr}, segcore.StorageCost{ScannedRemoteBytes: 100, ScannedTotalBytes: 250}) @@ -415,6 +415,7 @@ func (s *SearchPipelineSuite) TestSearchPipeline() { s.Equal(int64(101), results.Results.FieldsData[0].FieldId) s.Equal(int64(100), storageCost.ScannedRemoteBytes) s.Equal(int64(250), storageCost.ScannedTotalBytes) + s.Equal(int64(2*10), results.GetResults().AllSearchCount) fmt.Println(results) } @@ -462,7 +463,7 @@ func (s *SearchPipelineSuite) TestSearchPipelineWithRequery() { pipeline, err := newPipeline(searchWithRequeryPipe, task) s.NoError(err) - pipeline.AddNodes(task, filterFieldNode) + pipeline.AddNodes(task, endNode) results, storageCost, err := pipeline.Run(context.Background(), s.span, []*internalpb.SearchResults{ genTestSearchResultData(2, 10, schemapb.DataType_Int64, "intField", 101, false), @@ -484,6 +485,7 @@ func (s *SearchPipelineSuite) TestSearchPipelineWithRequery() { s.Equal(int64(101), results.Results.FieldsData[0].FieldId) s.Equal(int64(200), storageCost.ScannedRemoteBytes) s.Equal(int64(400), storageCost.ScannedTotalBytes) + s.Equal(int64(2*10), results.GetResults().AllSearchCount) } func (s *SearchPipelineSuite) TestSearchWithRerankPipe() { @@ -535,7 +537,7 @@ func (s *SearchPipelineSuite) TestSearchWithRerankPipe() { pipeline, err := newPipeline(searchWithRerankPipe, task) s.NoError(err) - pipeline.AddNodes(task, filterFieldNode) + pipeline.AddNodes(task, endNode) searchResults := genTestSearchResultData(2, 10, schemapb.DataType_Int64, "intField", 101, false) results, _, err := pipeline.Run(context.Background(), s.span, []*internalpb.SearchResults{searchResults}, segcore.StorageCost{}) @@ -555,6 +557,7 @@ func (s *SearchPipelineSuite) TestSearchWithRerankPipe() { s.Len(results.Results.FieldsData, 1) // One output field s.Equal("intField", results.Results.FieldsData[0].FieldName) s.Equal(int64(101), results.Results.FieldsData[0].FieldId) + s.Equal(int64(2*10), results.GetResults().AllSearchCount) } func (s *SearchPipelineSuite) TestSearchWithRerankRequeryPipe() 
{ @@ -619,7 +622,7 @@ func (s *SearchPipelineSuite) TestSearchWithRerankRequeryPipe() { pipeline, err := newPipeline(searchWithRerankRequeryPipe, task) s.NoError(err) - pipeline.AddNodes(task, filterFieldNode) + pipeline.AddNodes(task, endNode) searchResults := genTestSearchResultData(2, 10, schemapb.DataType_Int64, "intField", 101, false) results, storageCost, err := pipeline.Run(context.Background(), s.span, []*internalpb.SearchResults{searchResults}, segcore.StorageCost{}) @@ -641,6 +644,7 @@ func (s *SearchPipelineSuite) TestSearchWithRerankRequeryPipe() { s.Equal(int64(101), results.Results.FieldsData[0].FieldId) s.Equal(int64(0), storageCost.ScannedRemoteBytes) s.Equal(int64(0), storageCost.ScannedTotalBytes) + s.Equal(int64(2*10), results.GetResults().AllSearchCount) } func (s *SearchPipelineSuite) TestHybridSearchPipe() { @@ -653,7 +657,7 @@ func (s *SearchPipelineSuite) TestHybridSearchPipe() { pipeline, err := newPipeline(hybridSearchPipe, task) s.NoError(err) - pipeline.AddNodes(task, filterFieldNode) + pipeline.AddNodes(task, endNode) f1 := genTestSearchResultData(2, 10, schemapb.DataType_Int64, "intField", 101, true) f2 := genTestSearchResultData(2, 10, schemapb.DataType_Int64, "intField", 101, true) @@ -672,6 +676,7 @@ func (s *SearchPipelineSuite) TestHybridSearchPipe() { s.Len(results.Results.Scores, 20) // 2 queries * 10 topk s.Equal(int64(900), storageCost.ScannedRemoteBytes) s.Equal(int64(2000), storageCost.ScannedTotalBytes) + s.Equal(int64(2*2*10), results.GetResults().AllSearchCount) } func (s *SearchPipelineSuite) TestFilterFieldOperatorWithStructArrayFields() { @@ -699,7 +704,7 @@ func (s *SearchPipelineSuite) TestFilterFieldOperatorWithStructArrayFields() { translatedOutputFields: []string{"intField", "floatField", "structArrayField", "structVectorField"}, } - op, err := newFilterFieldOperator(task, nil) + op, err := newEndOperator(task, nil) s.NoError(err) // Create mock search results with fields including struct array fields @@ -714,7 +719,7 @@ func (s *SearchPipelineSuite) TestFilterFieldOperatorWithStructArrayFields() { }, } - results, err := op.run(context.Background(), s.span, searchResults) + results, err := op.run(context.Background(), s.span, searchResults, []*milvuspb.SearchResults{{Results: &schemapb.SearchResultData{AllSearchCount: 0}}}) s.NoError(err) s.NotNil(results) @@ -766,8 +771,7 @@ func (s *SearchPipelineSuite) TestHybridSearchWithRequeryAndRerankByDataPipe() { pipeline, err := newPipeline(hybridSearchWithRequeryAndRerankByFieldDataPipe, task) s.NoError(err) - pipeline.AddNodes(task, filterFieldNode) - + pipeline.AddNodes(task, endNode) d1 := genTestSearchResultData(2, 10, schemapb.DataType_Int64, "intField", 101, true) d2 := genTestSearchResultData(2, 10, schemapb.DataType_Int64, "intField", 101, true) results, _, err := pipeline.Run(context.Background(), s.span, []*internalpb.SearchResults{d1, d2}, segcore.StorageCost{}) @@ -787,6 +791,7 @@ func (s *SearchPipelineSuite) TestHybridSearchWithRequeryAndRerankByDataPipe() { s.Len(results.Results.FieldsData, 1) // One output field s.Equal("intField", results.Results.FieldsData[0].FieldName) s.Equal(int64(101), results.Results.FieldsData[0].FieldId) + s.Equal(int64(2*2*10), results.GetResults().AllSearchCount) } func (s *SearchPipelineSuite) TestHybridSearchWithRequeryPipe() { @@ -808,7 +813,7 @@ func (s *SearchPipelineSuite) TestHybridSearchWithRequeryPipe() { pipeline, err := newPipeline(hybridSearchWithRequeryPipe, task) s.NoError(err) - pipeline.AddNodes(task, filterFieldNode) + 
pipeline.AddNodes(task, endNode) d1 := genTestSearchResultData(2, 10, schemapb.DataType_Int64, "intField", 101, true) d2 := genTestSearchResultData(2, 10, schemapb.DataType_Int64, "intField", 101, true) @@ -829,6 +834,7 @@ func (s *SearchPipelineSuite) TestHybridSearchWithRequeryPipe() { s.Len(results.Results.FieldsData, 1) // One output field s.Equal("intField", results.Results.FieldsData[0].FieldName) s.Equal(int64(101), results.Results.FieldsData[0].FieldId) + s.Equal(int64(2*2*10), results.GetResults().AllSearchCount) } func getHybridSearchTask(collName string, data [][]string, outputFields []string) *searchTask { diff --git a/internal/proxy/task_search_test.go b/internal/proxy/task_search_test.go index eab119b922..9f6579baa0 100644 --- a/internal/proxy/task_search_test.go +++ b/internal/proxy/task_search_test.go @@ -4760,6 +4760,7 @@ func genTestSearchResultData(nq int64, topk int64, dType schemapb.DataType, fiel testutils.GenerateScalarFieldData(dType, fieldName, int(nq*topk)), testutils.GenerateScalarFieldData(schemapb.DataType_Int64, testInt64Field, int(nq*topk)), }, + AllSearchCount: nq * topk, } resultData.FieldsData[0].FieldId = fieldId sliceBlob, _ := proto.Marshal(resultData) diff --git a/tests/go_client/testcases/text_embedding_test.go b/tests/go_client/testcases/text_embedding_test.go index fb10760d94..d4512f0a31 100644 --- a/tests/go_client/testcases/text_embedding_test.go +++ b/tests/go_client/testcases/text_embedding_test.go @@ -676,7 +676,7 @@ func TestHybridSearchTextEmbeddingBM25(t *testing.T) { } // create collection - err := mc.CreateCollection(ctx, milvusclient.NewCreateCollectionOption(collectionName, schema)) + err := mc.CreateCollection(ctx, milvusclient.NewCreateCollectionOption(collectionName, schema).WithConsistencyLevel(entity.ClStrong)) common.CheckErr(t, err, true) // insert test data with diverse content
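
For reference, a minimal standalone sketch (not part of the diff) of the AllSearchCount aggregation that the renamed endOperator now performs: it mirrors the aggregatedAllSearchCount helper added above and shows how the total is written onto the final, field-filtered response. Only the milvuspb/schemapb types and getters already used in the changed files are relied on; the import paths below are the usual milvus-proto go-api v2 ones and are an assumption of this sketch, not something introduced by the patch.

package main

import (
	"fmt"

	"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
)

// sumAllSearchCount mirrors aggregatedAllSearchCount from the diff: it totals
// AllSearchCount across the per-request reduced results, skipping nil entries.
func sumAllSearchCount(searchResults []*milvuspb.SearchResults) int64 {
	total := int64(0)
	for _, sr := range searchResults {
		if sr != nil && sr.GetResults() != nil {
			total += sr.GetResults().GetAllSearchCount()
		}
	}
	return total
}

func main() {
	// Two reduced sub-search results, e.g. a hybrid search with nq=2 and topk=10 each,
	// so each carries AllSearchCount = 20 (as genTestSearchResultData now sets).
	reduced := []*milvuspb.SearchResults{
		{Results: &schemapb.SearchResultData{AllSearchCount: 20}},
		{Results: &schemapb.SearchResultData{AllSearchCount: 20}},
	}

	// The end operator writes the aggregate onto the final, field-filtered result.
	final := &milvuspb.SearchResults{Results: &schemapb.SearchResultData{}}
	final.GetResults().AllSearchCount = sumAllSearchCount(reduced)

	fmt.Println(final.GetResults().GetAllSearchCount()) // 40, matching the 2*2*10 expectation in the hybrid tests
}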