fix: rerank before requery if reranker didn't use field data (#44942)

issue: #44918

---------

Signed-off-by: chyezh <chyezh@outlook.com>
This commit is contained in:
Zhen Ye 2025-10-20 14:26:02 +08:00 committed by GitHub
parent 05df48fbe4
commit a3a28a4b99
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 125 additions and 4 deletions

View File

@ -472,7 +472,15 @@ func (op *organizeOperator) run(ctx context.Context, span trace.Span, inputs ...
defer sp.End() defer sp.End()
fields := inputs[0].([]*schemapb.FieldData) fields := inputs[0].([]*schemapb.FieldData)
idsList := inputs[1].([]*schemapb.IDs) var idsList []*schemapb.IDs
switch inputs[1].(type) {
case *schemapb.IDs:
idsList = []*schemapb.IDs{inputs[1].(*schemapb.IDs)}
case []*schemapb.IDs:
idsList = inputs[1].([]*schemapb.IDs)
default:
panic(fmt.Sprintf("invalid ids type: %T", inputs[1]))
}
if len(fields) == 0 { if len(fields) == 0 {
emptyFields := make([][]*schemapb.FieldData, len(idsList)) emptyFields := make([][]*schemapb.FieldData, len(idsList))
return []any{emptyFields}, nil return []any{emptyFields}, nil
@ -919,8 +927,8 @@ var hybridSearchPipe = &pipelineDef{
}, },
} }
var hybridSearchWithRequeryPipe = &pipelineDef{ var hybridSearchWithRequeryAndRerankByFieldDataPipe = &pipelineDef{
name: "hybridSearchWithRequery", name: "hybridSearchWithRequeryAndRerankByDataPipe",
nodes: []*nodeDef{ nodes: []*nodeDef{
{ {
name: "reduce", name: "reduce",
@ -1026,6 +1034,69 @@ var hybridSearchWithRequeryPipe = &pipelineDef{
}, },
} }
var hybridSearchWithRequeryPipe = &pipelineDef{
name: "hybridSearchWithRequeryPipe",
nodes: []*nodeDef{
{
name: "reduce",
inputs: []string{"input", "storage_cost"},
outputs: []string{"reduced", "metrics"},
opName: hybridSearchReduceOp,
},
{
name: "rerank",
inputs: []string{"reduced", "metrics"},
outputs: []string{"rank_result"},
opName: rerankOp,
},
{
name: "pick_ids",
inputs: []string{"rank_result"},
outputs: []string{"ids"},
params: map[string]any{
lambdaParamKey: func(ctx context.Context, span trace.Span, inputs ...any) ([]any, error) {
return []any{
inputs[0].(*milvuspb.SearchResults).Results.Ids,
}, nil
},
},
opName: lambdaOp,
},
{
name: "requery",
inputs: []string{"ids", "storage_cost"},
outputs: []string{"fields", "storage_cost"},
opName: requeryOp,
},
{
name: "organize",
inputs: []string{"fields", "ids"},
outputs: []string{"organized_fields"},
opName: organizeOp,
},
{
name: "result",
inputs: []string{"rank_result", "organized_fields"},
outputs: []string{"result"},
params: map[string]any{
lambdaParamKey: func(ctx context.Context, span trace.Span, inputs ...any) ([]any, error) {
result := inputs[0].(*milvuspb.SearchResults)
fields := inputs[1].([][]*schemapb.FieldData)
result.Results.FieldsData = fields[0]
return []any{result}, nil
},
},
opName: lambdaOp,
},
{
name: "filter_field",
inputs: []string{"result"},
outputs: []string{"output"},
opName: filterFieldOp,
},
},
}
func newBuiltInPipeline(t *searchTask) (*pipeline, error) { func newBuiltInPipeline(t *searchTask) (*pipeline, error) {
if !t.SearchRequest.GetIsAdvanced() && !t.needRequery && t.functionScore == nil { if !t.SearchRequest.GetIsAdvanced() && !t.needRequery && t.functionScore == nil {
return newPipeline(searchPipe, t) return newPipeline(searchPipe, t)
@ -1043,7 +1114,16 @@ func newBuiltInPipeline(t *searchTask) (*pipeline, error) {
return newPipeline(hybridSearchPipe, t) return newPipeline(hybridSearchPipe, t)
} }
if t.SearchRequest.GetIsAdvanced() && t.needRequery { if t.SearchRequest.GetIsAdvanced() && t.needRequery {
return newPipeline(hybridSearchWithRequeryPipe, t) if len(t.functionScore.GetAllInputFieldIDs()) > 0 {
// When the function score need field data, we need to requery to fetch the field data before rerank.
// The requery will fetch the field data of all search results,
// so there's some memory overhead.
return newPipeline(hybridSearchWithRequeryAndRerankByFieldDataPipe, t)
} else {
// Otherwise, we can rerank and limit the requery size to the limit.
// so the memory overhead is less than the hybridSearchWithRequeryAndRerankByFieldDataPipe.
return newPipeline(hybridSearchWithRequeryPipe, t)
}
} }
return nil, fmt.Errorf("Unsupported pipeline") return nil, fmt.Errorf("Unsupported pipeline")
} }

View File

@ -643,6 +643,47 @@ func (s *SearchPipelineSuite) TestFilterFieldOperatorWithStructArrayFields() {
} }
} }
func (s *SearchPipelineSuite) TestHybridSearchWithRequeryAndRerankByDataPipe() {
task := getHybridSearchTask("test_collection", [][]string{
{"1", "2"},
{"3", "4"},
},
[]string{"intField"},
)
f1 := testutils.GenerateScalarFieldData(schemapb.DataType_Int64, "intField", 20)
f1.FieldId = 101
f2 := testutils.GenerateScalarFieldData(schemapb.DataType_Int64, "int64", 20)
f2.FieldId = 100
mocker := mockey.Mock((*requeryOperator).requery).Return(&milvuspb.QueryResults{
FieldsData: []*schemapb.FieldData{f1, f2},
}, segcore.StorageCost{}, nil).Build()
defer mocker.UnPatch()
pipeline, err := newPipeline(hybridSearchWithRequeryAndRerankByFieldDataPipe, task)
s.NoError(err)
d1 := genTestSearchResultData(2, 10, schemapb.DataType_Int64, "intField", 101, true)
d2 := genTestSearchResultData(2, 10, schemapb.DataType_Int64, "intField", 101, true)
results, _, err := pipeline.Run(context.Background(), s.span, []*internalpb.SearchResults{d1, d2}, segcore.StorageCost{})
s.NoError(err)
s.NotNil(results)
s.NotNil(results.Results)
s.Equal(int64(2), results.Results.NumQueries)
s.Equal(int64(10), results.Results.Topks[0])
s.Equal(int64(10), results.Results.Topks[1])
s.NotNil(results.Results.Ids)
s.NotNil(results.Results.Ids.GetIntId())
s.Len(results.Results.Ids.GetIntId().Data, 20) // 2 queries * 10 topk
s.NotNil(results.Results.Scores)
s.Len(results.Results.Scores, 20) // 2 queries * 10 topk
s.NotNil(results.Results.FieldsData)
s.Len(results.Results.FieldsData, 1) // One output field
s.Equal("intField", results.Results.FieldsData[0].FieldName)
s.Equal(int64(101), results.Results.FieldsData[0].FieldId)
}
func (s *SearchPipelineSuite) TestHybridSearchWithRequeryPipe() { func (s *SearchPipelineSuite) TestHybridSearchWithRequeryPipe() {
task := getHybridSearchTask("test_collection", [][]string{ task := getHybridSearchTask("test_collection", [][]string{
{"1", "2"}, {"1", "2"},