mirror of
https://gitee.com/milvus-io/milvus.git
synced 2025-12-06 17:18:35 +08:00
fix: rerank before requery if reranker didn't use field data (#44942)
issue: #44918 --------- Signed-off-by: chyezh <chyezh@outlook.com>
This commit is contained in:
parent
05df48fbe4
commit
a3a28a4b99
@ -472,7 +472,15 @@ func (op *organizeOperator) run(ctx context.Context, span trace.Span, inputs ...
|
||||
defer sp.End()
|
||||
|
||||
fields := inputs[0].([]*schemapb.FieldData)
|
||||
idsList := inputs[1].([]*schemapb.IDs)
|
||||
var idsList []*schemapb.IDs
|
||||
switch inputs[1].(type) {
|
||||
case *schemapb.IDs:
|
||||
idsList = []*schemapb.IDs{inputs[1].(*schemapb.IDs)}
|
||||
case []*schemapb.IDs:
|
||||
idsList = inputs[1].([]*schemapb.IDs)
|
||||
default:
|
||||
panic(fmt.Sprintf("invalid ids type: %T", inputs[1]))
|
||||
}
|
||||
if len(fields) == 0 {
|
||||
emptyFields := make([][]*schemapb.FieldData, len(idsList))
|
||||
return []any{emptyFields}, nil
|
||||
@ -919,8 +927,8 @@ var hybridSearchPipe = &pipelineDef{
|
||||
},
|
||||
}
|
||||
|
||||
var hybridSearchWithRequeryPipe = &pipelineDef{
|
||||
name: "hybridSearchWithRequery",
|
||||
var hybridSearchWithRequeryAndRerankByFieldDataPipe = &pipelineDef{
|
||||
name: "hybridSearchWithRequeryAndRerankByDataPipe",
|
||||
nodes: []*nodeDef{
|
||||
{
|
||||
name: "reduce",
|
||||
@ -1026,6 +1034,69 @@ var hybridSearchWithRequeryPipe = &pipelineDef{
|
||||
},
|
||||
}
|
||||
|
||||
var hybridSearchWithRequeryPipe = &pipelineDef{
|
||||
name: "hybridSearchWithRequeryPipe",
|
||||
nodes: []*nodeDef{
|
||||
{
|
||||
name: "reduce",
|
||||
inputs: []string{"input", "storage_cost"},
|
||||
outputs: []string{"reduced", "metrics"},
|
||||
opName: hybridSearchReduceOp,
|
||||
},
|
||||
{
|
||||
name: "rerank",
|
||||
inputs: []string{"reduced", "metrics"},
|
||||
outputs: []string{"rank_result"},
|
||||
opName: rerankOp,
|
||||
},
|
||||
{
|
||||
name: "pick_ids",
|
||||
inputs: []string{"rank_result"},
|
||||
outputs: []string{"ids"},
|
||||
params: map[string]any{
|
||||
lambdaParamKey: func(ctx context.Context, span trace.Span, inputs ...any) ([]any, error) {
|
||||
return []any{
|
||||
inputs[0].(*milvuspb.SearchResults).Results.Ids,
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
opName: lambdaOp,
|
||||
},
|
||||
{
|
||||
name: "requery",
|
||||
inputs: []string{"ids", "storage_cost"},
|
||||
outputs: []string{"fields", "storage_cost"},
|
||||
opName: requeryOp,
|
||||
},
|
||||
{
|
||||
name: "organize",
|
||||
inputs: []string{"fields", "ids"},
|
||||
outputs: []string{"organized_fields"},
|
||||
opName: organizeOp,
|
||||
},
|
||||
{
|
||||
name: "result",
|
||||
inputs: []string{"rank_result", "organized_fields"},
|
||||
outputs: []string{"result"},
|
||||
params: map[string]any{
|
||||
lambdaParamKey: func(ctx context.Context, span trace.Span, inputs ...any) ([]any, error) {
|
||||
result := inputs[0].(*milvuspb.SearchResults)
|
||||
fields := inputs[1].([][]*schemapb.FieldData)
|
||||
result.Results.FieldsData = fields[0]
|
||||
return []any{result}, nil
|
||||
},
|
||||
},
|
||||
opName: lambdaOp,
|
||||
},
|
||||
{
|
||||
name: "filter_field",
|
||||
inputs: []string{"result"},
|
||||
outputs: []string{"output"},
|
||||
opName: filterFieldOp,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
func newBuiltInPipeline(t *searchTask) (*pipeline, error) {
|
||||
if !t.SearchRequest.GetIsAdvanced() && !t.needRequery && t.functionScore == nil {
|
||||
return newPipeline(searchPipe, t)
|
||||
@ -1043,7 +1114,16 @@ func newBuiltInPipeline(t *searchTask) (*pipeline, error) {
|
||||
return newPipeline(hybridSearchPipe, t)
|
||||
}
|
||||
if t.SearchRequest.GetIsAdvanced() && t.needRequery {
|
||||
return newPipeline(hybridSearchWithRequeryPipe, t)
|
||||
if len(t.functionScore.GetAllInputFieldIDs()) > 0 {
|
||||
// When the function score need field data, we need to requery to fetch the field data before rerank.
|
||||
// The requery will fetch the field data of all search results,
|
||||
// so there's some memory overhead.
|
||||
return newPipeline(hybridSearchWithRequeryAndRerankByFieldDataPipe, t)
|
||||
} else {
|
||||
// Otherwise, we can rerank and limit the requery size to the limit.
|
||||
// so the memory overhead is less than the hybridSearchWithRequeryAndRerankByFieldDataPipe.
|
||||
return newPipeline(hybridSearchWithRequeryPipe, t)
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("Unsupported pipeline")
|
||||
}
|
||||
|
||||
@ -643,6 +643,47 @@ func (s *SearchPipelineSuite) TestFilterFieldOperatorWithStructArrayFields() {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SearchPipelineSuite) TestHybridSearchWithRequeryAndRerankByDataPipe() {
|
||||
task := getHybridSearchTask("test_collection", [][]string{
|
||||
{"1", "2"},
|
||||
{"3", "4"},
|
||||
},
|
||||
[]string{"intField"},
|
||||
)
|
||||
|
||||
f1 := testutils.GenerateScalarFieldData(schemapb.DataType_Int64, "intField", 20)
|
||||
f1.FieldId = 101
|
||||
f2 := testutils.GenerateScalarFieldData(schemapb.DataType_Int64, "int64", 20)
|
||||
f2.FieldId = 100
|
||||
mocker := mockey.Mock((*requeryOperator).requery).Return(&milvuspb.QueryResults{
|
||||
FieldsData: []*schemapb.FieldData{f1, f2},
|
||||
}, segcore.StorageCost{}, nil).Build()
|
||||
defer mocker.UnPatch()
|
||||
|
||||
pipeline, err := newPipeline(hybridSearchWithRequeryAndRerankByFieldDataPipe, task)
|
||||
s.NoError(err)
|
||||
|
||||
d1 := genTestSearchResultData(2, 10, schemapb.DataType_Int64, "intField", 101, true)
|
||||
d2 := genTestSearchResultData(2, 10, schemapb.DataType_Int64, "intField", 101, true)
|
||||
results, _, err := pipeline.Run(context.Background(), s.span, []*internalpb.SearchResults{d1, d2}, segcore.StorageCost{})
|
||||
|
||||
s.NoError(err)
|
||||
s.NotNil(results)
|
||||
s.NotNil(results.Results)
|
||||
s.Equal(int64(2), results.Results.NumQueries)
|
||||
s.Equal(int64(10), results.Results.Topks[0])
|
||||
s.Equal(int64(10), results.Results.Topks[1])
|
||||
s.NotNil(results.Results.Ids)
|
||||
s.NotNil(results.Results.Ids.GetIntId())
|
||||
s.Len(results.Results.Ids.GetIntId().Data, 20) // 2 queries * 10 topk
|
||||
s.NotNil(results.Results.Scores)
|
||||
s.Len(results.Results.Scores, 20) // 2 queries * 10 topk
|
||||
s.NotNil(results.Results.FieldsData)
|
||||
s.Len(results.Results.FieldsData, 1) // One output field
|
||||
s.Equal("intField", results.Results.FieldsData[0].FieldName)
|
||||
s.Equal(int64(101), results.Results.FieldsData[0].FieldId)
|
||||
}
|
||||
|
||||
func (s *SearchPipelineSuite) TestHybridSearchWithRequeryPipe() {
|
||||
task := getHybridSearchTask("test_collection", [][]string{
|
||||
{"1", "2"},
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user