mirror of
https://gitee.com/milvus-io/milvus.git
synced 2025-12-06 09:08:43 +08:00
fix: Sum AllSearchCount from multiple search results (#45914)
https://github.com/milvus-io/milvus/issues/45842 Signed-off-by: junjie.jiang <junjie.jiang@zilliz.com>
This commit is contained in:
parent
d5bd17315c
commit
dff62c5423
@ -110,11 +110,17 @@ const (
|
||||
rerankOp = "rerank"
|
||||
requeryOp = "requery"
|
||||
organizeOp = "organize"
|
||||
filterFieldOp = "filter_field"
|
||||
endOp = "end"
|
||||
lambdaOp = "lambda"
|
||||
highlightOp = "highlight"
|
||||
)
|
||||
|
||||
const (
|
||||
pipelineOutput = "output"
|
||||
pipelineInput = "input"
|
||||
pipelineStorageCost = "storage_cost"
|
||||
)
|
||||
|
||||
var opFactory = map[string]func(t *searchTask, params map[string]any) (operator, error){
|
||||
searchReduceOp: newSearchReduceOperator,
|
||||
hybridSearchReduceOp: newHybridSearchReduceOperator,
|
||||
@ -122,7 +128,7 @@ var opFactory = map[string]func(t *searchTask, params map[string]any) (operator,
|
||||
organizeOp: newOrganizeOperator,
|
||||
requeryOp: newRequeryOperator,
|
||||
lambdaOp: newLambdaOperator,
|
||||
filterFieldOp: newFilterFieldOperator,
|
||||
endOp: newEndOperator,
|
||||
highlightOp: newHighlightOperator,
|
||||
}
|
||||
|
||||
@ -565,19 +571,19 @@ func (op *lambdaOperator) run(ctx context.Context, span trace.Span, inputs ...an
|
||||
return op.f(ctx, span, inputs...)
|
||||
}
|
||||
|
||||
type filterFieldOperator struct {
|
||||
type endOperator struct {
|
||||
outputFieldNames []string
|
||||
fieldSchemas []*schemapb.FieldSchema
|
||||
}
|
||||
|
||||
func newFilterFieldOperator(t *searchTask, _ map[string]any) (operator, error) {
|
||||
return &filterFieldOperator{
|
||||
func newEndOperator(t *searchTask, _ map[string]any) (operator, error) {
|
||||
return &endOperator{
|
||||
outputFieldNames: t.translatedOutputFields,
|
||||
fieldSchemas: typeutil.GetAllFieldSchemas(t.schema.CollectionSchema),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (op *filterFieldOperator) run(ctx context.Context, span trace.Span, inputs ...any) ([]any, error) {
|
||||
func (op *endOperator) run(ctx context.Context, span trace.Span, inputs ...any) ([]any, error) {
|
||||
result := inputs[0].(*milvuspb.SearchResults)
|
||||
for _, retField := range result.Results.FieldsData {
|
||||
for _, fieldSchema := range op.fieldSchemas {
|
||||
@ -591,6 +597,8 @@ func (op *filterFieldOperator) run(ctx context.Context, span trace.Span, inputs
|
||||
result.Results.FieldsData = lo.Filter(result.Results.FieldsData, func(field *schemapb.FieldData, _ int) bool {
|
||||
return lo.Contains(op.outputFieldNames, field.FieldName)
|
||||
})
|
||||
allSearchCount := aggregatedAllSearchCount(inputs[1].([]*milvuspb.SearchResults))
|
||||
result.GetResults().AllSearchCount = allSearchCount
|
||||
return []any{result}, nil
|
||||
}
|
||||
|
||||
@ -668,8 +676,8 @@ func (p *pipeline) AddNodes(t *searchTask, nodes ...*nodeDef) error {
|
||||
func (p *pipeline) Run(ctx context.Context, span trace.Span, toReduceResults []*internalpb.SearchResults, storageCost segcore.StorageCost) (*milvuspb.SearchResults, segcore.StorageCost, error) {
|
||||
log.Ctx(ctx).Debug("SearchPipeline run", zap.String("pipeline", p.String()))
|
||||
msg := opMsg{}
|
||||
msg["input"] = toReduceResults
|
||||
msg["storage_cost"] = storageCost
|
||||
msg[pipelineInput] = toReduceResults
|
||||
msg[pipelineStorageCost] = storageCost
|
||||
for _, node := range p.nodes {
|
||||
var err error
|
||||
log.Ctx(ctx).Debug("SearchPipeline run node", zap.String("node", node.name))
|
||||
@ -679,7 +687,7 @@ func (p *pipeline) Run(ctx context.Context, span trace.Span, toReduceResults []*
|
||||
return nil, storageCost, err
|
||||
}
|
||||
}
|
||||
return msg["output"].(*milvuspb.SearchResults), msg["storage_cost"].(segcore.StorageCost), nil
|
||||
return msg[pipelineOutput].(*milvuspb.SearchResults), msg[pipelineStorageCost].(segcore.StorageCost), nil
|
||||
}
|
||||
|
||||
func (p *pipeline) String() string {
|
||||
@ -695,17 +703,17 @@ type pipelineDef struct {
|
||||
nodes []*nodeDef
|
||||
}
|
||||
|
||||
var filterFieldNode = &nodeDef{
|
||||
var endNode = &nodeDef{
|
||||
name: "filter_field",
|
||||
inputs: []string{"result"},
|
||||
outputs: []string{"output"},
|
||||
opName: filterFieldOp,
|
||||
inputs: []string{"result", "reduced"},
|
||||
outputs: []string{pipelineOutput},
|
||||
opName: endOp,
|
||||
}
|
||||
|
||||
var highlightNode = &nodeDef{
|
||||
name: "highlight",
|
||||
inputs: []string{"result"},
|
||||
outputs: []string{"output"},
|
||||
outputs: []string{pipelineOutput},
|
||||
opName: highlightOp,
|
||||
}
|
||||
|
||||
@ -714,7 +722,7 @@ var searchPipe = &pipelineDef{
|
||||
nodes: []*nodeDef{
|
||||
{
|
||||
name: "reduce",
|
||||
inputs: []string{"input", "storage_cost"},
|
||||
inputs: []string{pipelineInput, pipelineStorageCost},
|
||||
outputs: []string{"reduced", "metrics"},
|
||||
opName: searchReduceOp,
|
||||
},
|
||||
@ -738,7 +746,7 @@ var searchWithRequeryPipe = &pipelineDef{
|
||||
nodes: []*nodeDef{
|
||||
{
|
||||
name: "reduce",
|
||||
inputs: []string{"input", "storage_cost"},
|
||||
inputs: []string{pipelineInput, pipelineStorageCost},
|
||||
outputs: []string{"reduced", "metrics"},
|
||||
opName: searchReduceOp,
|
||||
},
|
||||
@ -753,8 +761,8 @@ var searchWithRequeryPipe = &pipelineDef{
|
||||
},
|
||||
{
|
||||
name: "requery",
|
||||
inputs: []string{"unique_ids", "storage_cost"},
|
||||
outputs: []string{"fields", "storage_cost"},
|
||||
inputs: []string{"unique_ids", pipelineStorageCost},
|
||||
outputs: []string{"fields", pipelineStorageCost},
|
||||
opName: requeryOp,
|
||||
},
|
||||
{
|
||||
@ -796,7 +804,7 @@ var searchWithRerankPipe = &pipelineDef{
|
||||
nodes: []*nodeDef{
|
||||
{
|
||||
name: "reduce",
|
||||
inputs: []string{"input", "storage_cost"},
|
||||
inputs: []string{pipelineInput, pipelineStorageCost},
|
||||
outputs: []string{"reduced", "metrics"},
|
||||
opName: searchReduceOp,
|
||||
},
|
||||
@ -848,7 +856,7 @@ var searchWithRerankRequeryPipe = &pipelineDef{
|
||||
nodes: []*nodeDef{
|
||||
{
|
||||
name: "reduce",
|
||||
inputs: []string{"input", "storage_cost"},
|
||||
inputs: []string{pipelineInput, pipelineStorageCost},
|
||||
outputs: []string{"reduced", "metrics"},
|
||||
opName: searchReduceOp,
|
||||
},
|
||||
@ -873,8 +881,8 @@ var searchWithRerankRequeryPipe = &pipelineDef{
|
||||
},
|
||||
{
|
||||
name: "requery",
|
||||
inputs: []string{"ids", "storage_cost"},
|
||||
outputs: []string{"fields", "storage_cost"},
|
||||
inputs: []string{"ids", pipelineStorageCost},
|
||||
outputs: []string{"fields", pipelineStorageCost},
|
||||
opName: requeryOp,
|
||||
},
|
||||
{
|
||||
@ -916,7 +924,7 @@ var hybridSearchPipe = &pipelineDef{
|
||||
nodes: []*nodeDef{
|
||||
{
|
||||
name: "reduce",
|
||||
inputs: []string{"input", "storage_cost"},
|
||||
inputs: []string{pipelineInput, pipelineStorageCost},
|
||||
outputs: []string{"reduced", "metrics"},
|
||||
opName: hybridSearchReduceOp,
|
||||
},
|
||||
@ -934,7 +942,7 @@ var hybridSearchWithRequeryAndRerankByFieldDataPipe = &pipelineDef{
|
||||
nodes: []*nodeDef{
|
||||
{
|
||||
name: "reduce",
|
||||
inputs: []string{"input", "storage_cost"},
|
||||
inputs: []string{pipelineInput, pipelineStorageCost},
|
||||
outputs: []string{"reduced", "metrics"},
|
||||
opName: hybridSearchReduceOp,
|
||||
},
|
||||
@ -949,8 +957,8 @@ var hybridSearchWithRequeryAndRerankByFieldDataPipe = &pipelineDef{
|
||||
},
|
||||
{
|
||||
name: "requery",
|
||||
inputs: []string{"ids", "storage_cost"},
|
||||
outputs: []string{"fields", "storage_cost"},
|
||||
inputs: []string{"ids", pipelineStorageCost},
|
||||
outputs: []string{"fields", pipelineStorageCost},
|
||||
opName: requeryOp,
|
||||
},
|
||||
{
|
||||
@ -1035,7 +1043,7 @@ var hybridSearchWithRequeryPipe = &pipelineDef{
|
||||
nodes: []*nodeDef{
|
||||
{
|
||||
name: "reduce",
|
||||
inputs: []string{"input", "storage_cost"},
|
||||
inputs: []string{pipelineInput, pipelineStorageCost},
|
||||
outputs: []string{"reduced", "metrics"},
|
||||
opName: hybridSearchReduceOp,
|
||||
},
|
||||
@ -1060,8 +1068,8 @@ var hybridSearchWithRequeryPipe = &pipelineDef{
|
||||
},
|
||||
{
|
||||
name: "requery",
|
||||
inputs: []string{"ids", "storage_cost"},
|
||||
outputs: []string{"fields", "storage_cost"},
|
||||
inputs: []string{"ids", pipelineStorageCost},
|
||||
outputs: []string{"fields", pipelineStorageCost},
|
||||
opName: requeryOp,
|
||||
},
|
||||
{
|
||||
@ -1086,9 +1094,9 @@ var hybridSearchWithRequeryPipe = &pipelineDef{
|
||||
},
|
||||
{
|
||||
name: "filter_field",
|
||||
inputs: []string{"result"},
|
||||
outputs: []string{"output"},
|
||||
opName: filterFieldOp,
|
||||
inputs: []string{"result", "reduced"},
|
||||
outputs: []string{pipelineOutput},
|
||||
opName: endOp,
|
||||
},
|
||||
},
|
||||
}
|
||||
@ -1131,15 +1139,25 @@ func newSearchPipeline(t *searchTask) (*pipeline, error) {
|
||||
}
|
||||
|
||||
if t.highlighter != nil {
|
||||
err := p.AddNodes(t, highlightNode, filterFieldNode)
|
||||
err := p.AddNodes(t, highlightNode, endNode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
err := p.AddNodes(t, filterFieldNode)
|
||||
err := p.AddNodes(t, endNode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func aggregatedAllSearchCount(searchResults []*milvuspb.SearchResults) int64 {
|
||||
allSearchCount := int64(0)
|
||||
for _, sr := range searchResults {
|
||||
if sr != nil && sr.GetResults() != nil {
|
||||
allSearchCount += sr.GetResults().GetAllSearchCount()
|
||||
}
|
||||
}
|
||||
return allSearchCount
|
||||
}
|
||||
|
||||
@ -394,7 +394,7 @@ func (s *SearchPipelineSuite) TestSearchPipeline() {
|
||||
|
||||
pipeline, err := newPipeline(searchPipe, task)
|
||||
s.NoError(err)
|
||||
pipeline.AddNodes(task, filterFieldNode)
|
||||
pipeline.AddNodes(task, endNode)
|
||||
|
||||
sr := genTestSearchResultData(2, 10, schemapb.DataType_Int64, "intField", 101, false)
|
||||
results, storageCost, err := pipeline.Run(context.Background(), s.span, []*internalpb.SearchResults{sr}, segcore.StorageCost{ScannedRemoteBytes: 100, ScannedTotalBytes: 250})
|
||||
@ -415,6 +415,7 @@ func (s *SearchPipelineSuite) TestSearchPipeline() {
|
||||
s.Equal(int64(101), results.Results.FieldsData[0].FieldId)
|
||||
s.Equal(int64(100), storageCost.ScannedRemoteBytes)
|
||||
s.Equal(int64(250), storageCost.ScannedTotalBytes)
|
||||
s.Equal(int64(2*10), results.GetResults().AllSearchCount)
|
||||
fmt.Println(results)
|
||||
}
|
||||
|
||||
@ -462,7 +463,7 @@ func (s *SearchPipelineSuite) TestSearchPipelineWithRequery() {
|
||||
|
||||
pipeline, err := newPipeline(searchWithRequeryPipe, task)
|
||||
s.NoError(err)
|
||||
pipeline.AddNodes(task, filterFieldNode)
|
||||
pipeline.AddNodes(task, endNode)
|
||||
|
||||
results, storageCost, err := pipeline.Run(context.Background(), s.span, []*internalpb.SearchResults{
|
||||
genTestSearchResultData(2, 10, schemapb.DataType_Int64, "intField", 101, false),
|
||||
@ -484,6 +485,7 @@ func (s *SearchPipelineSuite) TestSearchPipelineWithRequery() {
|
||||
s.Equal(int64(101), results.Results.FieldsData[0].FieldId)
|
||||
s.Equal(int64(200), storageCost.ScannedRemoteBytes)
|
||||
s.Equal(int64(400), storageCost.ScannedTotalBytes)
|
||||
s.Equal(int64(2*10), results.GetResults().AllSearchCount)
|
||||
}
|
||||
|
||||
func (s *SearchPipelineSuite) TestSearchWithRerankPipe() {
|
||||
@ -535,7 +537,7 @@ func (s *SearchPipelineSuite) TestSearchWithRerankPipe() {
|
||||
|
||||
pipeline, err := newPipeline(searchWithRerankPipe, task)
|
||||
s.NoError(err)
|
||||
pipeline.AddNodes(task, filterFieldNode)
|
||||
pipeline.AddNodes(task, endNode)
|
||||
|
||||
searchResults := genTestSearchResultData(2, 10, schemapb.DataType_Int64, "intField", 101, false)
|
||||
results, _, err := pipeline.Run(context.Background(), s.span, []*internalpb.SearchResults{searchResults}, segcore.StorageCost{})
|
||||
@ -555,6 +557,7 @@ func (s *SearchPipelineSuite) TestSearchWithRerankPipe() {
|
||||
s.Len(results.Results.FieldsData, 1) // One output field
|
||||
s.Equal("intField", results.Results.FieldsData[0].FieldName)
|
||||
s.Equal(int64(101), results.Results.FieldsData[0].FieldId)
|
||||
s.Equal(int64(2*10), results.GetResults().AllSearchCount)
|
||||
}
|
||||
|
||||
func (s *SearchPipelineSuite) TestSearchWithRerankRequeryPipe() {
|
||||
@ -619,7 +622,7 @@ func (s *SearchPipelineSuite) TestSearchWithRerankRequeryPipe() {
|
||||
|
||||
pipeline, err := newPipeline(searchWithRerankRequeryPipe, task)
|
||||
s.NoError(err)
|
||||
pipeline.AddNodes(task, filterFieldNode)
|
||||
pipeline.AddNodes(task, endNode)
|
||||
|
||||
searchResults := genTestSearchResultData(2, 10, schemapb.DataType_Int64, "intField", 101, false)
|
||||
results, storageCost, err := pipeline.Run(context.Background(), s.span, []*internalpb.SearchResults{searchResults}, segcore.StorageCost{})
|
||||
@ -641,6 +644,7 @@ func (s *SearchPipelineSuite) TestSearchWithRerankRequeryPipe() {
|
||||
s.Equal(int64(101), results.Results.FieldsData[0].FieldId)
|
||||
s.Equal(int64(0), storageCost.ScannedRemoteBytes)
|
||||
s.Equal(int64(0), storageCost.ScannedTotalBytes)
|
||||
s.Equal(int64(2*10), results.GetResults().AllSearchCount)
|
||||
}
|
||||
|
||||
func (s *SearchPipelineSuite) TestHybridSearchPipe() {
|
||||
@ -653,7 +657,7 @@ func (s *SearchPipelineSuite) TestHybridSearchPipe() {
|
||||
|
||||
pipeline, err := newPipeline(hybridSearchPipe, task)
|
||||
s.NoError(err)
|
||||
pipeline.AddNodes(task, filterFieldNode)
|
||||
pipeline.AddNodes(task, endNode)
|
||||
|
||||
f1 := genTestSearchResultData(2, 10, schemapb.DataType_Int64, "intField", 101, true)
|
||||
f2 := genTestSearchResultData(2, 10, schemapb.DataType_Int64, "intField", 101, true)
|
||||
@ -672,6 +676,7 @@ func (s *SearchPipelineSuite) TestHybridSearchPipe() {
|
||||
s.Len(results.Results.Scores, 20) // 2 queries * 10 topk
|
||||
s.Equal(int64(900), storageCost.ScannedRemoteBytes)
|
||||
s.Equal(int64(2000), storageCost.ScannedTotalBytes)
|
||||
s.Equal(int64(2*2*10), results.GetResults().AllSearchCount)
|
||||
}
|
||||
|
||||
func (s *SearchPipelineSuite) TestFilterFieldOperatorWithStructArrayFields() {
|
||||
@ -699,7 +704,7 @@ func (s *SearchPipelineSuite) TestFilterFieldOperatorWithStructArrayFields() {
|
||||
translatedOutputFields: []string{"intField", "floatField", "structArrayField", "structVectorField"},
|
||||
}
|
||||
|
||||
op, err := newFilterFieldOperator(task, nil)
|
||||
op, err := newEndOperator(task, nil)
|
||||
s.NoError(err)
|
||||
|
||||
// Create mock search results with fields including struct array fields
|
||||
@ -714,7 +719,7 @@ func (s *SearchPipelineSuite) TestFilterFieldOperatorWithStructArrayFields() {
|
||||
},
|
||||
}
|
||||
|
||||
results, err := op.run(context.Background(), s.span, searchResults)
|
||||
results, err := op.run(context.Background(), s.span, searchResults, []*milvuspb.SearchResults{{Results: &schemapb.SearchResultData{AllSearchCount: 0}}})
|
||||
s.NoError(err)
|
||||
s.NotNil(results)
|
||||
|
||||
@ -766,8 +771,7 @@ func (s *SearchPipelineSuite) TestHybridSearchWithRequeryAndRerankByDataPipe() {
|
||||
|
||||
pipeline, err := newPipeline(hybridSearchWithRequeryAndRerankByFieldDataPipe, task)
|
||||
s.NoError(err)
|
||||
pipeline.AddNodes(task, filterFieldNode)
|
||||
|
||||
pipeline.AddNodes(task, endNode)
|
||||
d1 := genTestSearchResultData(2, 10, schemapb.DataType_Int64, "intField", 101, true)
|
||||
d2 := genTestSearchResultData(2, 10, schemapb.DataType_Int64, "intField", 101, true)
|
||||
results, _, err := pipeline.Run(context.Background(), s.span, []*internalpb.SearchResults{d1, d2}, segcore.StorageCost{})
|
||||
@ -787,6 +791,7 @@ func (s *SearchPipelineSuite) TestHybridSearchWithRequeryAndRerankByDataPipe() {
|
||||
s.Len(results.Results.FieldsData, 1) // One output field
|
||||
s.Equal("intField", results.Results.FieldsData[0].FieldName)
|
||||
s.Equal(int64(101), results.Results.FieldsData[0].FieldId)
|
||||
s.Equal(int64(2*2*10), results.GetResults().AllSearchCount)
|
||||
}
|
||||
|
||||
func (s *SearchPipelineSuite) TestHybridSearchWithRequeryPipe() {
|
||||
@ -808,7 +813,7 @@ func (s *SearchPipelineSuite) TestHybridSearchWithRequeryPipe() {
|
||||
|
||||
pipeline, err := newPipeline(hybridSearchWithRequeryPipe, task)
|
||||
s.NoError(err)
|
||||
pipeline.AddNodes(task, filterFieldNode)
|
||||
pipeline.AddNodes(task, endNode)
|
||||
|
||||
d1 := genTestSearchResultData(2, 10, schemapb.DataType_Int64, "intField", 101, true)
|
||||
d2 := genTestSearchResultData(2, 10, schemapb.DataType_Int64, "intField", 101, true)
|
||||
@ -829,6 +834,7 @@ func (s *SearchPipelineSuite) TestHybridSearchWithRequeryPipe() {
|
||||
s.Len(results.Results.FieldsData, 1) // One output field
|
||||
s.Equal("intField", results.Results.FieldsData[0].FieldName)
|
||||
s.Equal(int64(101), results.Results.FieldsData[0].FieldId)
|
||||
s.Equal(int64(2*2*10), results.GetResults().AllSearchCount)
|
||||
}
|
||||
|
||||
func getHybridSearchTask(collName string, data [][]string, outputFields []string) *searchTask {
|
||||
|
||||
@ -4760,6 +4760,7 @@ func genTestSearchResultData(nq int64, topk int64, dType schemapb.DataType, fiel
|
||||
testutils.GenerateScalarFieldData(dType, fieldName, int(nq*topk)),
|
||||
testutils.GenerateScalarFieldData(schemapb.DataType_Int64, testInt64Field, int(nq*topk)),
|
||||
},
|
||||
AllSearchCount: nq * topk,
|
||||
}
|
||||
resultData.FieldsData[0].FieldId = fieldId
|
||||
sliceBlob, _ := proto.Marshal(resultData)
|
||||
|
||||
@ -676,7 +676,7 @@ func TestHybridSearchTextEmbeddingBM25(t *testing.T) {
|
||||
}
|
||||
|
||||
// create collection
|
||||
err := mc.CreateCollection(ctx, milvusclient.NewCreateCollectionOption(collectionName, schema))
|
||||
err := mc.CreateCollection(ctx, milvusclient.NewCreateCollectionOption(collectionName, schema).WithConsistencyLevel(entity.ClStrong))
|
||||
common.CheckErr(t, err, true)
|
||||
|
||||
// insert test data with diverse content
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user