mirror of
https://gitee.com/milvus-io/milvus.git
synced 2025-12-06 17:18:35 +08:00
fix: rerank before requery if reranker didn't use field data (#44942)
issue: #44918 --------- Signed-off-by: chyezh <chyezh@outlook.com>
This commit is contained in:
parent
05df48fbe4
commit
a3a28a4b99
@ -472,7 +472,15 @@ func (op *organizeOperator) run(ctx context.Context, span trace.Span, inputs ...
|
|||||||
defer sp.End()
|
defer sp.End()
|
||||||
|
|
||||||
fields := inputs[0].([]*schemapb.FieldData)
|
fields := inputs[0].([]*schemapb.FieldData)
|
||||||
idsList := inputs[1].([]*schemapb.IDs)
|
var idsList []*schemapb.IDs
|
||||||
|
switch inputs[1].(type) {
|
||||||
|
case *schemapb.IDs:
|
||||||
|
idsList = []*schemapb.IDs{inputs[1].(*schemapb.IDs)}
|
||||||
|
case []*schemapb.IDs:
|
||||||
|
idsList = inputs[1].([]*schemapb.IDs)
|
||||||
|
default:
|
||||||
|
panic(fmt.Sprintf("invalid ids type: %T", inputs[1]))
|
||||||
|
}
|
||||||
if len(fields) == 0 {
|
if len(fields) == 0 {
|
||||||
emptyFields := make([][]*schemapb.FieldData, len(idsList))
|
emptyFields := make([][]*schemapb.FieldData, len(idsList))
|
||||||
return []any{emptyFields}, nil
|
return []any{emptyFields}, nil
|
||||||
@ -919,8 +927,8 @@ var hybridSearchPipe = &pipelineDef{
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
var hybridSearchWithRequeryPipe = &pipelineDef{
|
var hybridSearchWithRequeryAndRerankByFieldDataPipe = &pipelineDef{
|
||||||
name: "hybridSearchWithRequery",
|
name: "hybridSearchWithRequeryAndRerankByDataPipe",
|
||||||
nodes: []*nodeDef{
|
nodes: []*nodeDef{
|
||||||
{
|
{
|
||||||
name: "reduce",
|
name: "reduce",
|
||||||
@ -1026,6 +1034,69 @@ var hybridSearchWithRequeryPipe = &pipelineDef{
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var hybridSearchWithRequeryPipe = &pipelineDef{
|
||||||
|
name: "hybridSearchWithRequeryPipe",
|
||||||
|
nodes: []*nodeDef{
|
||||||
|
{
|
||||||
|
name: "reduce",
|
||||||
|
inputs: []string{"input", "storage_cost"},
|
||||||
|
outputs: []string{"reduced", "metrics"},
|
||||||
|
opName: hybridSearchReduceOp,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "rerank",
|
||||||
|
inputs: []string{"reduced", "metrics"},
|
||||||
|
outputs: []string{"rank_result"},
|
||||||
|
opName: rerankOp,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "pick_ids",
|
||||||
|
inputs: []string{"rank_result"},
|
||||||
|
outputs: []string{"ids"},
|
||||||
|
params: map[string]any{
|
||||||
|
lambdaParamKey: func(ctx context.Context, span trace.Span, inputs ...any) ([]any, error) {
|
||||||
|
return []any{
|
||||||
|
inputs[0].(*milvuspb.SearchResults).Results.Ids,
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
opName: lambdaOp,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "requery",
|
||||||
|
inputs: []string{"ids", "storage_cost"},
|
||||||
|
outputs: []string{"fields", "storage_cost"},
|
||||||
|
opName: requeryOp,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "organize",
|
||||||
|
inputs: []string{"fields", "ids"},
|
||||||
|
outputs: []string{"organized_fields"},
|
||||||
|
opName: organizeOp,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "result",
|
||||||
|
inputs: []string{"rank_result", "organized_fields"},
|
||||||
|
outputs: []string{"result"},
|
||||||
|
params: map[string]any{
|
||||||
|
lambdaParamKey: func(ctx context.Context, span trace.Span, inputs ...any) ([]any, error) {
|
||||||
|
result := inputs[0].(*milvuspb.SearchResults)
|
||||||
|
fields := inputs[1].([][]*schemapb.FieldData)
|
||||||
|
result.Results.FieldsData = fields[0]
|
||||||
|
return []any{result}, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
opName: lambdaOp,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "filter_field",
|
||||||
|
inputs: []string{"result"},
|
||||||
|
outputs: []string{"output"},
|
||||||
|
opName: filterFieldOp,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
func newBuiltInPipeline(t *searchTask) (*pipeline, error) {
|
func newBuiltInPipeline(t *searchTask) (*pipeline, error) {
|
||||||
if !t.SearchRequest.GetIsAdvanced() && !t.needRequery && t.functionScore == nil {
|
if !t.SearchRequest.GetIsAdvanced() && !t.needRequery && t.functionScore == nil {
|
||||||
return newPipeline(searchPipe, t)
|
return newPipeline(searchPipe, t)
|
||||||
@ -1043,7 +1114,16 @@ func newBuiltInPipeline(t *searchTask) (*pipeline, error) {
|
|||||||
return newPipeline(hybridSearchPipe, t)
|
return newPipeline(hybridSearchPipe, t)
|
||||||
}
|
}
|
||||||
if t.SearchRequest.GetIsAdvanced() && t.needRequery {
|
if t.SearchRequest.GetIsAdvanced() && t.needRequery {
|
||||||
return newPipeline(hybridSearchWithRequeryPipe, t)
|
if len(t.functionScore.GetAllInputFieldIDs()) > 0 {
|
||||||
|
// When the function score needs field data, we must requery to fetch that field data before reranking.
|
||||||
|
// The requery will fetch the field data of all search results,
|
||||||
|
// so there's some memory overhead.
|
||||||
|
return newPipeline(hybridSearchWithRequeryAndRerankByFieldDataPipe, t)
|
||||||
|
} else {
|
||||||
|
// Otherwise, we can rerank first and cap the requery size at the search limit,
|
||||||
|
// so the memory overhead is lower than that of hybridSearchWithRequeryAndRerankByFieldDataPipe.
|
||||||
|
return newPipeline(hybridSearchWithRequeryPipe, t)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return nil, fmt.Errorf("Unsupported pipeline")
|
return nil, fmt.Errorf("Unsupported pipeline")
|
||||||
}
|
}
|
||||||
|
|||||||
@ -643,6 +643,47 @@ func (s *SearchPipelineSuite) TestFilterFieldOperatorWithStructArrayFields() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *SearchPipelineSuite) TestHybridSearchWithRequeryAndRerankByDataPipe() {
|
||||||
|
task := getHybridSearchTask("test_collection", [][]string{
|
||||||
|
{"1", "2"},
|
||||||
|
{"3", "4"},
|
||||||
|
},
|
||||||
|
[]string{"intField"},
|
||||||
|
)
|
||||||
|
|
||||||
|
f1 := testutils.GenerateScalarFieldData(schemapb.DataType_Int64, "intField", 20)
|
||||||
|
f1.FieldId = 101
|
||||||
|
f2 := testutils.GenerateScalarFieldData(schemapb.DataType_Int64, "int64", 20)
|
||||||
|
f2.FieldId = 100
|
||||||
|
mocker := mockey.Mock((*requeryOperator).requery).Return(&milvuspb.QueryResults{
|
||||||
|
FieldsData: []*schemapb.FieldData{f1, f2},
|
||||||
|
}, segcore.StorageCost{}, nil).Build()
|
||||||
|
defer mocker.UnPatch()
|
||||||
|
|
||||||
|
pipeline, err := newPipeline(hybridSearchWithRequeryAndRerankByFieldDataPipe, task)
|
||||||
|
s.NoError(err)
|
||||||
|
|
||||||
|
d1 := genTestSearchResultData(2, 10, schemapb.DataType_Int64, "intField", 101, true)
|
||||||
|
d2 := genTestSearchResultData(2, 10, schemapb.DataType_Int64, "intField", 101, true)
|
||||||
|
results, _, err := pipeline.Run(context.Background(), s.span, []*internalpb.SearchResults{d1, d2}, segcore.StorageCost{})
|
||||||
|
|
||||||
|
s.NoError(err)
|
||||||
|
s.NotNil(results)
|
||||||
|
s.NotNil(results.Results)
|
||||||
|
s.Equal(int64(2), results.Results.NumQueries)
|
||||||
|
s.Equal(int64(10), results.Results.Topks[0])
|
||||||
|
s.Equal(int64(10), results.Results.Topks[1])
|
||||||
|
s.NotNil(results.Results.Ids)
|
||||||
|
s.NotNil(results.Results.Ids.GetIntId())
|
||||||
|
s.Len(results.Results.Ids.GetIntId().Data, 20) // 2 queries * 10 topk
|
||||||
|
s.NotNil(results.Results.Scores)
|
||||||
|
s.Len(results.Results.Scores, 20) // 2 queries * 10 topk
|
||||||
|
s.NotNil(results.Results.FieldsData)
|
||||||
|
s.Len(results.Results.FieldsData, 1) // One output field
|
||||||
|
s.Equal("intField", results.Results.FieldsData[0].FieldName)
|
||||||
|
s.Equal(int64(101), results.Results.FieldsData[0].FieldId)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *SearchPipelineSuite) TestHybridSearchWithRequeryPipe() {
|
func (s *SearchPipelineSuite) TestHybridSearchWithRequeryPipe() {
|
||||||
task := getHybridSearchTask("test_collection", [][]string{
|
task := getHybridSearchTask("test_collection", [][]string{
|
||||||
{"1", "2"},
|
{"1", "2"},
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user