From a3a28a4b993193aada96c23de962c5c4d7515c49 Mon Sep 17 00:00:00 2001 From: Zhen Ye Date: Mon, 20 Oct 2025 14:26:02 +0800 Subject: [PATCH] fix: rerank before requery if reranker didn't use field data (#44942) issue: #44918 --------- Signed-off-by: chyezh --- internal/proxy/search_pipeline.go | 88 ++++++++++++++++++++++++-- internal/proxy/search_pipeline_test.go | 41 ++++++++++++ 2 files changed, 125 insertions(+), 4 deletions(-) diff --git a/internal/proxy/search_pipeline.go b/internal/proxy/search_pipeline.go index 2921ed07d2..6c1d07903f 100644 --- a/internal/proxy/search_pipeline.go +++ b/internal/proxy/search_pipeline.go @@ -472,7 +472,15 @@ func (op *organizeOperator) run(ctx context.Context, span trace.Span, inputs ... defer sp.End() fields := inputs[0].([]*schemapb.FieldData) - idsList := inputs[1].([]*schemapb.IDs) + var idsList []*schemapb.IDs + switch inputs[1].(type) { + case *schemapb.IDs: + idsList = []*schemapb.IDs{inputs[1].(*schemapb.IDs)} + case []*schemapb.IDs: + idsList = inputs[1].([]*schemapb.IDs) + default: + panic(fmt.Sprintf("invalid ids type: %T", inputs[1])) + } if len(fields) == 0 { emptyFields := make([][]*schemapb.FieldData, len(idsList)) return []any{emptyFields}, nil @@ -919,8 +927,8 @@ var hybridSearchPipe = &pipelineDef{ }, } -var hybridSearchWithRequeryPipe = &pipelineDef{ - name: "hybridSearchWithRequery", +var hybridSearchWithRequeryAndRerankByFieldDataPipe = &pipelineDef{ + name: "hybridSearchWithRequeryAndRerankByDataPipe", nodes: []*nodeDef{ { name: "reduce", @@ -1026,6 +1034,69 @@ var hybridSearchWithRequeryPipe = &pipelineDef{ }, } +var hybridSearchWithRequeryPipe = &pipelineDef{ + name: "hybridSearchWithRequeryPipe", + nodes: []*nodeDef{ + { + name: "reduce", + inputs: []string{"input", "storage_cost"}, + outputs: []string{"reduced", "metrics"}, + opName: hybridSearchReduceOp, + }, + { + name: "rerank", + inputs: []string{"reduced", "metrics"}, + outputs: []string{"rank_result"}, + opName: rerankOp, + }, + { + name: 
"pick_ids", + inputs: []string{"rank_result"}, + outputs: []string{"ids"}, + params: map[string]any{ + lambdaParamKey: func(ctx context.Context, span trace.Span, inputs ...any) ([]any, error) { + return []any{ + inputs[0].(*milvuspb.SearchResults).Results.Ids, + }, nil + }, + }, + opName: lambdaOp, + }, + { + name: "requery", + inputs: []string{"ids", "storage_cost"}, + outputs: []string{"fields", "storage_cost"}, + opName: requeryOp, + }, + { + name: "organize", + inputs: []string{"fields", "ids"}, + outputs: []string{"organized_fields"}, + opName: organizeOp, + }, + { + name: "result", + inputs: []string{"rank_result", "organized_fields"}, + outputs: []string{"result"}, + params: map[string]any{ + lambdaParamKey: func(ctx context.Context, span trace.Span, inputs ...any) ([]any, error) { + result := inputs[0].(*milvuspb.SearchResults) + fields := inputs[1].([][]*schemapb.FieldData) + result.Results.FieldsData = fields[0] + return []any{result}, nil + }, + }, + opName: lambdaOp, + }, + { + name: "filter_field", + inputs: []string{"result"}, + outputs: []string{"output"}, + opName: filterFieldOp, + }, + }, +} + func newBuiltInPipeline(t *searchTask) (*pipeline, error) { if !t.SearchRequest.GetIsAdvanced() && !t.needRequery && t.functionScore == nil { return newPipeline(searchPipe, t) @@ -1043,7 +1114,16 @@ func newBuiltInPipeline(t *searchTask) (*pipeline, error) { return newPipeline(hybridSearchPipe, t) } if t.SearchRequest.GetIsAdvanced() && t.needRequery { - return newPipeline(hybridSearchWithRequeryPipe, t) + if len(t.functionScore.GetAllInputFieldIDs()) > 0 { + // When the function score needs field data, we need to requery to fetch the field data before rerank. + // The requery will fetch the field data of all search results, + // so there's some memory overhead. + return newPipeline(hybridSearchWithRequeryAndRerankByFieldDataPipe, t) + } else { + // Otherwise, we can rerank first and cap the requery size at the limit, 
+ // so the memory overhead is less than the hybridSearchWithRequeryAndRerankByFieldDataPipe. + return newPipeline(hybridSearchWithRequeryPipe, t) + } } return nil, fmt.Errorf("Unsupported pipeline") } diff --git a/internal/proxy/search_pipeline_test.go b/internal/proxy/search_pipeline_test.go index b2c04ec576..05b9a71621 100644 --- a/internal/proxy/search_pipeline_test.go +++ b/internal/proxy/search_pipeline_test.go @@ -643,6 +643,47 @@ func (s *SearchPipelineSuite) TestFilterFieldOperatorWithStructArrayFields() { } } +func (s *SearchPipelineSuite) TestHybridSearchWithRequeryAndRerankByDataPipe() { + task := getHybridSearchTask("test_collection", [][]string{ + {"1", "2"}, + {"3", "4"}, + }, + []string{"intField"}, + ) + + f1 := testutils.GenerateScalarFieldData(schemapb.DataType_Int64, "intField", 20) + f1.FieldId = 101 + f2 := testutils.GenerateScalarFieldData(schemapb.DataType_Int64, "int64", 20) + f2.FieldId = 100 + mocker := mockey.Mock((*requeryOperator).requery).Return(&milvuspb.QueryResults{ + FieldsData: []*schemapb.FieldData{f1, f2}, + }, segcore.StorageCost{}, nil).Build() + defer mocker.UnPatch() + + pipeline, err := newPipeline(hybridSearchWithRequeryAndRerankByFieldDataPipe, task) + s.NoError(err) + + d1 := genTestSearchResultData(2, 10, schemapb.DataType_Int64, "intField", 101, true) + d2 := genTestSearchResultData(2, 10, schemapb.DataType_Int64, "intField", 101, true) + results, _, err := pipeline.Run(context.Background(), s.span, []*internalpb.SearchResults{d1, d2}, segcore.StorageCost{}) + + s.NoError(err) + s.NotNil(results) + s.NotNil(results.Results) + s.Equal(int64(2), results.Results.NumQueries) + s.Equal(int64(10), results.Results.Topks[0]) + s.Equal(int64(10), results.Results.Topks[1]) + s.NotNil(results.Results.Ids) + s.NotNil(results.Results.Ids.GetIntId()) + s.Len(results.Results.Ids.GetIntId().Data, 20) // 2 queries * 10 topk + s.NotNil(results.Results.Scores) + s.Len(results.Results.Scores, 20) // 2 queries * 10 topk + 
s.NotNil(results.Results.FieldsData) + s.Len(results.Results.FieldsData, 1) // One output field + s.Equal("intField", results.Results.FieldsData[0].FieldName) + s.Equal(int64(101), results.Results.FieldsData[0].FieldId) +} + func (s *SearchPipelineSuite) TestHybridSearchWithRequeryPipe() { task := getHybridSearchTask("test_collection", [][]string{ {"1", "2"},