fix: Sum AllSearchCount from multiple search results (#45914)

https://github.com/milvus-io/milvus/issues/45842

Signed-off-by: junjie.jiang <junjie.jiang@zilliz.com>
This commit is contained in:
junjiejiangjjj 2025-12-01 14:33:10 +08:00 committed by GitHub
parent d5bd17315c
commit dff62c5423
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 70 additions and 45 deletions

View File

@ -110,11 +110,17 @@ const (
rerankOp = "rerank"
requeryOp = "requery"
organizeOp = "organize"
filterFieldOp = "filter_field"
endOp = "end"
lambdaOp = "lambda"
highlightOp = "highlight"
)
const (
pipelineOutput = "output"
pipelineInput = "input"
pipelineStorageCost = "storage_cost"
)
var opFactory = map[string]func(t *searchTask, params map[string]any) (operator, error){
searchReduceOp: newSearchReduceOperator,
hybridSearchReduceOp: newHybridSearchReduceOperator,
@ -122,7 +128,7 @@ var opFactory = map[string]func(t *searchTask, params map[string]any) (operator,
organizeOp: newOrganizeOperator,
requeryOp: newRequeryOperator,
lambdaOp: newLambdaOperator,
filterFieldOp: newFilterFieldOperator,
endOp: newEndOperator,
highlightOp: newHighlightOperator,
}
@ -565,19 +571,19 @@ func (op *lambdaOperator) run(ctx context.Context, span trace.Span, inputs ...an
return op.f(ctx, span, inputs...)
}
type filterFieldOperator struct {
type endOperator struct {
outputFieldNames []string
fieldSchemas []*schemapb.FieldSchema
}
func newFilterFieldOperator(t *searchTask, _ map[string]any) (operator, error) {
return &filterFieldOperator{
func newEndOperator(t *searchTask, _ map[string]any) (operator, error) {
return &endOperator{
outputFieldNames: t.translatedOutputFields,
fieldSchemas: typeutil.GetAllFieldSchemas(t.schema.CollectionSchema),
}, nil
}
func (op *filterFieldOperator) run(ctx context.Context, span trace.Span, inputs ...any) ([]any, error) {
func (op *endOperator) run(ctx context.Context, span trace.Span, inputs ...any) ([]any, error) {
result := inputs[0].(*milvuspb.SearchResults)
for _, retField := range result.Results.FieldsData {
for _, fieldSchema := range op.fieldSchemas {
@ -591,6 +597,8 @@ func (op *filterFieldOperator) run(ctx context.Context, span trace.Span, inputs
result.Results.FieldsData = lo.Filter(result.Results.FieldsData, func(field *schemapb.FieldData, _ int) bool {
return lo.Contains(op.outputFieldNames, field.FieldName)
})
allSearchCount := aggregatedAllSearchCount(inputs[1].([]*milvuspb.SearchResults))
result.GetResults().AllSearchCount = allSearchCount
return []any{result}, nil
}
@ -668,8 +676,8 @@ func (p *pipeline) AddNodes(t *searchTask, nodes ...*nodeDef) error {
func (p *pipeline) Run(ctx context.Context, span trace.Span, toReduceResults []*internalpb.SearchResults, storageCost segcore.StorageCost) (*milvuspb.SearchResults, segcore.StorageCost, error) {
log.Ctx(ctx).Debug("SearchPipeline run", zap.String("pipeline", p.String()))
msg := opMsg{}
msg["input"] = toReduceResults
msg["storage_cost"] = storageCost
msg[pipelineInput] = toReduceResults
msg[pipelineStorageCost] = storageCost
for _, node := range p.nodes {
var err error
log.Ctx(ctx).Debug("SearchPipeline run node", zap.String("node", node.name))
@ -679,7 +687,7 @@ func (p *pipeline) Run(ctx context.Context, span trace.Span, toReduceResults []*
return nil, storageCost, err
}
}
return msg["output"].(*milvuspb.SearchResults), msg["storage_cost"].(segcore.StorageCost), nil
return msg[pipelineOutput].(*milvuspb.SearchResults), msg[pipelineStorageCost].(segcore.StorageCost), nil
}
func (p *pipeline) String() string {
@ -695,17 +703,17 @@ type pipelineDef struct {
nodes []*nodeDef
}
var filterFieldNode = &nodeDef{
var endNode = &nodeDef{
name: "filter_field",
inputs: []string{"result"},
outputs: []string{"output"},
opName: filterFieldOp,
inputs: []string{"result", "reduced"},
outputs: []string{pipelineOutput},
opName: endOp,
}
var highlightNode = &nodeDef{
name: "highlight",
inputs: []string{"result"},
outputs: []string{"output"},
outputs: []string{pipelineOutput},
opName: highlightOp,
}
@ -714,7 +722,7 @@ var searchPipe = &pipelineDef{
nodes: []*nodeDef{
{
name: "reduce",
inputs: []string{"input", "storage_cost"},
inputs: []string{pipelineInput, pipelineStorageCost},
outputs: []string{"reduced", "metrics"},
opName: searchReduceOp,
},
@ -738,7 +746,7 @@ var searchWithRequeryPipe = &pipelineDef{
nodes: []*nodeDef{
{
name: "reduce",
inputs: []string{"input", "storage_cost"},
inputs: []string{pipelineInput, pipelineStorageCost},
outputs: []string{"reduced", "metrics"},
opName: searchReduceOp,
},
@ -753,8 +761,8 @@ var searchWithRequeryPipe = &pipelineDef{
},
{
name: "requery",
inputs: []string{"unique_ids", "storage_cost"},
outputs: []string{"fields", "storage_cost"},
inputs: []string{"unique_ids", pipelineStorageCost},
outputs: []string{"fields", pipelineStorageCost},
opName: requeryOp,
},
{
@ -796,7 +804,7 @@ var searchWithRerankPipe = &pipelineDef{
nodes: []*nodeDef{
{
name: "reduce",
inputs: []string{"input", "storage_cost"},
inputs: []string{pipelineInput, pipelineStorageCost},
outputs: []string{"reduced", "metrics"},
opName: searchReduceOp,
},
@ -848,7 +856,7 @@ var searchWithRerankRequeryPipe = &pipelineDef{
nodes: []*nodeDef{
{
name: "reduce",
inputs: []string{"input", "storage_cost"},
inputs: []string{pipelineInput, pipelineStorageCost},
outputs: []string{"reduced", "metrics"},
opName: searchReduceOp,
},
@ -873,8 +881,8 @@ var searchWithRerankRequeryPipe = &pipelineDef{
},
{
name: "requery",
inputs: []string{"ids", "storage_cost"},
outputs: []string{"fields", "storage_cost"},
inputs: []string{"ids", pipelineStorageCost},
outputs: []string{"fields", pipelineStorageCost},
opName: requeryOp,
},
{
@ -916,7 +924,7 @@ var hybridSearchPipe = &pipelineDef{
nodes: []*nodeDef{
{
name: "reduce",
inputs: []string{"input", "storage_cost"},
inputs: []string{pipelineInput, pipelineStorageCost},
outputs: []string{"reduced", "metrics"},
opName: hybridSearchReduceOp,
},
@ -934,7 +942,7 @@ var hybridSearchWithRequeryAndRerankByFieldDataPipe = &pipelineDef{
nodes: []*nodeDef{
{
name: "reduce",
inputs: []string{"input", "storage_cost"},
inputs: []string{pipelineInput, pipelineStorageCost},
outputs: []string{"reduced", "metrics"},
opName: hybridSearchReduceOp,
},
@ -949,8 +957,8 @@ var hybridSearchWithRequeryAndRerankByFieldDataPipe = &pipelineDef{
},
{
name: "requery",
inputs: []string{"ids", "storage_cost"},
outputs: []string{"fields", "storage_cost"},
inputs: []string{"ids", pipelineStorageCost},
outputs: []string{"fields", pipelineStorageCost},
opName: requeryOp,
},
{
@ -1035,7 +1043,7 @@ var hybridSearchWithRequeryPipe = &pipelineDef{
nodes: []*nodeDef{
{
name: "reduce",
inputs: []string{"input", "storage_cost"},
inputs: []string{pipelineInput, pipelineStorageCost},
outputs: []string{"reduced", "metrics"},
opName: hybridSearchReduceOp,
},
@ -1060,8 +1068,8 @@ var hybridSearchWithRequeryPipe = &pipelineDef{
},
{
name: "requery",
inputs: []string{"ids", "storage_cost"},
outputs: []string{"fields", "storage_cost"},
inputs: []string{"ids", pipelineStorageCost},
outputs: []string{"fields", pipelineStorageCost},
opName: requeryOp,
},
{
@ -1086,9 +1094,9 @@ var hybridSearchWithRequeryPipe = &pipelineDef{
},
{
name: "filter_field",
inputs: []string{"result"},
outputs: []string{"output"},
opName: filterFieldOp,
inputs: []string{"result", "reduced"},
outputs: []string{pipelineOutput},
opName: endOp,
},
},
}
@ -1131,15 +1139,25 @@ func newSearchPipeline(t *searchTask) (*pipeline, error) {
}
if t.highlighter != nil {
err := p.AddNodes(t, highlightNode, filterFieldNode)
err := p.AddNodes(t, highlightNode, endNode)
if err != nil {
return nil, err
}
} else {
err := p.AddNodes(t, filterFieldNode)
err := p.AddNodes(t, endNode)
if err != nil {
return nil, err
}
}
return p, nil
}
// aggregatedAllSearchCount returns the sum of AllSearchCount over every
// entry in searchResults. Entries that are nil, or that carry no result
// data, contribute nothing to the total.
func aggregatedAllSearchCount(searchResults []*milvuspb.SearchResults) int64 {
	var total int64
	for _, res := range searchResults {
		if res == nil {
			continue
		}
		if data := res.GetResults(); data != nil {
			total += data.GetAllSearchCount()
		}
	}
	return total
}

View File

@ -394,7 +394,7 @@ func (s *SearchPipelineSuite) TestSearchPipeline() {
pipeline, err := newPipeline(searchPipe, task)
s.NoError(err)
pipeline.AddNodes(task, filterFieldNode)
pipeline.AddNodes(task, endNode)
sr := genTestSearchResultData(2, 10, schemapb.DataType_Int64, "intField", 101, false)
results, storageCost, err := pipeline.Run(context.Background(), s.span, []*internalpb.SearchResults{sr}, segcore.StorageCost{ScannedRemoteBytes: 100, ScannedTotalBytes: 250})
@ -415,6 +415,7 @@ func (s *SearchPipelineSuite) TestSearchPipeline() {
s.Equal(int64(101), results.Results.FieldsData[0].FieldId)
s.Equal(int64(100), storageCost.ScannedRemoteBytes)
s.Equal(int64(250), storageCost.ScannedTotalBytes)
s.Equal(int64(2*10), results.GetResults().AllSearchCount)
fmt.Println(results)
}
@ -462,7 +463,7 @@ func (s *SearchPipelineSuite) TestSearchPipelineWithRequery() {
pipeline, err := newPipeline(searchWithRequeryPipe, task)
s.NoError(err)
pipeline.AddNodes(task, filterFieldNode)
pipeline.AddNodes(task, endNode)
results, storageCost, err := pipeline.Run(context.Background(), s.span, []*internalpb.SearchResults{
genTestSearchResultData(2, 10, schemapb.DataType_Int64, "intField", 101, false),
@ -484,6 +485,7 @@ func (s *SearchPipelineSuite) TestSearchPipelineWithRequery() {
s.Equal(int64(101), results.Results.FieldsData[0].FieldId)
s.Equal(int64(200), storageCost.ScannedRemoteBytes)
s.Equal(int64(400), storageCost.ScannedTotalBytes)
s.Equal(int64(2*10), results.GetResults().AllSearchCount)
}
func (s *SearchPipelineSuite) TestSearchWithRerankPipe() {
@ -535,7 +537,7 @@ func (s *SearchPipelineSuite) TestSearchWithRerankPipe() {
pipeline, err := newPipeline(searchWithRerankPipe, task)
s.NoError(err)
pipeline.AddNodes(task, filterFieldNode)
pipeline.AddNodes(task, endNode)
searchResults := genTestSearchResultData(2, 10, schemapb.DataType_Int64, "intField", 101, false)
results, _, err := pipeline.Run(context.Background(), s.span, []*internalpb.SearchResults{searchResults}, segcore.StorageCost{})
@ -555,6 +557,7 @@ func (s *SearchPipelineSuite) TestSearchWithRerankPipe() {
s.Len(results.Results.FieldsData, 1) // One output field
s.Equal("intField", results.Results.FieldsData[0].FieldName)
s.Equal(int64(101), results.Results.FieldsData[0].FieldId)
s.Equal(int64(2*10), results.GetResults().AllSearchCount)
}
func (s *SearchPipelineSuite) TestSearchWithRerankRequeryPipe() {
@ -619,7 +622,7 @@ func (s *SearchPipelineSuite) TestSearchWithRerankRequeryPipe() {
pipeline, err := newPipeline(searchWithRerankRequeryPipe, task)
s.NoError(err)
pipeline.AddNodes(task, filterFieldNode)
pipeline.AddNodes(task, endNode)
searchResults := genTestSearchResultData(2, 10, schemapb.DataType_Int64, "intField", 101, false)
results, storageCost, err := pipeline.Run(context.Background(), s.span, []*internalpb.SearchResults{searchResults}, segcore.StorageCost{})
@ -641,6 +644,7 @@ func (s *SearchPipelineSuite) TestSearchWithRerankRequeryPipe() {
s.Equal(int64(101), results.Results.FieldsData[0].FieldId)
s.Equal(int64(0), storageCost.ScannedRemoteBytes)
s.Equal(int64(0), storageCost.ScannedTotalBytes)
s.Equal(int64(2*10), results.GetResults().AllSearchCount)
}
func (s *SearchPipelineSuite) TestHybridSearchPipe() {
@ -653,7 +657,7 @@ func (s *SearchPipelineSuite) TestHybridSearchPipe() {
pipeline, err := newPipeline(hybridSearchPipe, task)
s.NoError(err)
pipeline.AddNodes(task, filterFieldNode)
pipeline.AddNodes(task, endNode)
f1 := genTestSearchResultData(2, 10, schemapb.DataType_Int64, "intField", 101, true)
f2 := genTestSearchResultData(2, 10, schemapb.DataType_Int64, "intField", 101, true)
@ -672,6 +676,7 @@ func (s *SearchPipelineSuite) TestHybridSearchPipe() {
s.Len(results.Results.Scores, 20) // 2 queries * 10 topk
s.Equal(int64(900), storageCost.ScannedRemoteBytes)
s.Equal(int64(2000), storageCost.ScannedTotalBytes)
s.Equal(int64(2*2*10), results.GetResults().AllSearchCount)
}
func (s *SearchPipelineSuite) TestFilterFieldOperatorWithStructArrayFields() {
@ -699,7 +704,7 @@ func (s *SearchPipelineSuite) TestFilterFieldOperatorWithStructArrayFields() {
translatedOutputFields: []string{"intField", "floatField", "structArrayField", "structVectorField"},
}
op, err := newFilterFieldOperator(task, nil)
op, err := newEndOperator(task, nil)
s.NoError(err)
// Create mock search results with fields including struct array fields
@ -714,7 +719,7 @@ func (s *SearchPipelineSuite) TestFilterFieldOperatorWithStructArrayFields() {
},
}
results, err := op.run(context.Background(), s.span, searchResults)
results, err := op.run(context.Background(), s.span, searchResults, []*milvuspb.SearchResults{{Results: &schemapb.SearchResultData{AllSearchCount: 0}}})
s.NoError(err)
s.NotNil(results)
@ -766,8 +771,7 @@ func (s *SearchPipelineSuite) TestHybridSearchWithRequeryAndRerankByDataPipe() {
pipeline, err := newPipeline(hybridSearchWithRequeryAndRerankByFieldDataPipe, task)
s.NoError(err)
pipeline.AddNodes(task, filterFieldNode)
pipeline.AddNodes(task, endNode)
d1 := genTestSearchResultData(2, 10, schemapb.DataType_Int64, "intField", 101, true)
d2 := genTestSearchResultData(2, 10, schemapb.DataType_Int64, "intField", 101, true)
results, _, err := pipeline.Run(context.Background(), s.span, []*internalpb.SearchResults{d1, d2}, segcore.StorageCost{})
@ -787,6 +791,7 @@ func (s *SearchPipelineSuite) TestHybridSearchWithRequeryAndRerankByDataPipe() {
s.Len(results.Results.FieldsData, 1) // One output field
s.Equal("intField", results.Results.FieldsData[0].FieldName)
s.Equal(int64(101), results.Results.FieldsData[0].FieldId)
s.Equal(int64(2*2*10), results.GetResults().AllSearchCount)
}
func (s *SearchPipelineSuite) TestHybridSearchWithRequeryPipe() {
@ -808,7 +813,7 @@ func (s *SearchPipelineSuite) TestHybridSearchWithRequeryPipe() {
pipeline, err := newPipeline(hybridSearchWithRequeryPipe, task)
s.NoError(err)
pipeline.AddNodes(task, filterFieldNode)
pipeline.AddNodes(task, endNode)
d1 := genTestSearchResultData(2, 10, schemapb.DataType_Int64, "intField", 101, true)
d2 := genTestSearchResultData(2, 10, schemapb.DataType_Int64, "intField", 101, true)
@ -829,6 +834,7 @@ func (s *SearchPipelineSuite) TestHybridSearchWithRequeryPipe() {
s.Len(results.Results.FieldsData, 1) // One output field
s.Equal("intField", results.Results.FieldsData[0].FieldName)
s.Equal(int64(101), results.Results.FieldsData[0].FieldId)
s.Equal(int64(2*2*10), results.GetResults().AllSearchCount)
}
func getHybridSearchTask(collName string, data [][]string, outputFields []string) *searchTask {

View File

@ -4760,6 +4760,7 @@ func genTestSearchResultData(nq int64, topk int64, dType schemapb.DataType, fiel
testutils.GenerateScalarFieldData(dType, fieldName, int(nq*topk)),
testutils.GenerateScalarFieldData(schemapb.DataType_Int64, testInt64Field, int(nq*topk)),
},
AllSearchCount: nq * topk,
}
resultData.FieldsData[0].FieldId = fieldId
sliceBlob, _ := proto.Marshal(resultData)

View File

@ -676,7 +676,7 @@ func TestHybridSearchTextEmbeddingBM25(t *testing.T) {
}
// create collection
err := mc.CreateCollection(ctx, milvusclient.NewCreateCollectionOption(collectionName, schema))
err := mc.CreateCollection(ctx, milvusclient.NewCreateCollectionOption(collectionName, schema).WithConsistencyLevel(entity.ClStrong))
common.CheckErr(t, err, true)
// insert test data with diverse content