enhance: Add search post pipeline (#43065)

https://github.com/milvus-io/milvus/issues/35856

Signed-off-by: junjiejiangjjj <junjie.jiang@zilliz.com>
This commit is contained in:
junjiejiangjjj 2025-07-21 11:10:52 +08:00 committed by GitHub
parent 21e71f6eb2
commit 77f3a1f213
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 1939 additions and 582 deletions

View File

@ -3111,7 +3111,6 @@ func (node *Proxy) search(ctx context.Context, request *milvuspb.SearchRequest,
lb: node.lbPolicy,
enableMaterializedView: node.enableMaterializedView,
mustUsePartitionKey: Params.ProxyCfg.MustUsePartitionKey.GetAsBool(),
requeryFunc: requeryImpl,
}
log := log.Ctx(ctx).With( // TODO: it might cause some cpu consumption
@ -3354,7 +3353,6 @@ func (node *Proxy) hybridSearch(ctx context.Context, request *milvuspb.HybridSea
node: node,
lb: node.lbPolicy,
mustUsePartitionKey: Params.ProxyCfg.MustUsePartitionKey.GetAsBool(),
requeryFunc: requeryImpl,
}
log := log.Ctx(ctx).With(

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,754 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package proxy
import (
"context"
"fmt"
"slices"
"testing"
"time"
"github.com/bytedance/mockey"
"github.com/stretchr/testify/suite"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/trace"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/util/function/rerank"
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
"github.com/milvus-io/milvus/pkg/v2/proto/planpb"
"github.com/milvus-io/milvus/pkg/v2/util/testutils"
"github.com/milvus-io/milvus/pkg/v2/util/timerecord"
)
func TestSearchPipeline(t *testing.T) {
suite.Run(t, new(SearchPipelineSuite))
}
// SearchPipelineSuite groups the unit tests for the proxy search
// post-processing pipeline operators.
type SearchPipelineSuite struct {
	suite.Suite
	// span is a per-test OpenTelemetry span, created in SetupTest and
	// passed to every pipeline operator under test.
	span trace.Span
}
// SetupTest starts a fresh tracing span before each test; operators under
// test receive it as their tracing context.
func (s *SearchPipelineSuite) SetupTest() {
	_, sp := otel.Tracer("test").Start(context.Background(), "Proxy-Search-PostExecute")
	s.span = sp
}
// TearDownTest ends the per-test span created in SetupTest.
func (s *SearchPipelineSuite) TearDownTest() {
	s.span.End()
}
// TestSearchReduceOp checks that searchReduceOperator reduces a single
// shard-level search result without error.
func (s *SearchPipelineSuite) TestSearchReduceOp() {
	var (
		numQueries = int64(2)
		limit      = int64(10)
	)
	pkField := &schemapb.FieldSchema{
		FieldID:      101,
		Name:         "pk",
		DataType:     schemapb.DataType_Int64,
		IsPrimaryKey: true,
		AutoID:       true,
	}
	shardResult := genTestSearchResultData(numQueries, limit, schemapb.DataType_Int64, "intField", 102, false)
	reducer := searchReduceOperator{
		context.Background(),
		pkField,
		numQueries,
		limit,
		0,
		1,
		[]int64{1},
		[]*planpb.QueryInfo{{}},
	}
	_, err := reducer.run(context.Background(), s.span, []*internalpb.SearchResults{shardResult})
	s.NoError(err)
}
// TestHybridSearchReduceOp checks that hybridSearchReduceOperator reduces
// two sub-request results (one per sub search) without error.
func (s *SearchPipelineSuite) TestHybridSearchReduceOp() {
	var (
		numQueries = int64(2)
		limit      = int64(10)
	)
	pkField := &schemapb.FieldSchema{
		FieldID:      101,
		Name:         "pk",
		DataType:     schemapb.DataType_Int64,
		IsPrimaryKey: true,
		AutoID:       true,
	}
	// Two advanced (hybrid) results, each carrying one sub-result tagged
	// with its sub-request index.
	first := genTestSearchResultData(numQueries, limit, schemapb.DataType_Int64, "intField", 102, true)
	first.SubResults[0].ReqIndex = 0
	second := genTestSearchResultData(numQueries, limit, schemapb.DataType_Int64, "intField", 102, true)
	second.SubResults[0].ReqIndex = 1
	subReqs := []*internalpb.SubSearchRequest{
		{Nq: 2, Topk: 10, Offset: 0},
		{Nq: 2, Topk: 10, Offset: 0},
	}
	reducer := hybridSearchReduceOperator{
		context.Background(),
		subReqs,
		pkField,
		1,
		[]int64{1},
		[]*planpb.QueryInfo{{}, {}},
	}
	_, err := reducer.run(context.Background(), s.span, []*internalpb.SearchResults{first, second})
	s.NoError(err)
}
// TestRerankOp builds a decay (gauss) rerank function over the "ts" field,
// reduces a mocked shard result, and verifies rerankOperator.run completes
// without error.
func (s *SearchPipelineSuite) TestRerankOp() {
	schema := &schemapb.CollectionSchema{
		Name: "test",
		Fields: []*schemapb.FieldSchema{
			{FieldID: 100, Name: "pk", DataType: schemapb.DataType_Int64, IsPrimaryKey: true},
			{FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar},
			{
				FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector,
				TypeParams: []*commonpb.KeyValuePair{
					{Key: "dim", Value: "4"},
				},
			},
			// Fix: field IDs must be unique within a schema; "ts" previously
			// duplicated the vector field's ID (102).
			{FieldID: 103, Name: "ts", DataType: schemapb.DataType_Int64},
		},
	}
	functionSchema := &schemapb.FunctionSchema{
		Name:             "test",
		Type:             schemapb.FunctionType_Rerank,
		InputFieldNames:  []string{"ts"},
		OutputFieldNames: []string{},
		Params: []*commonpb.KeyValuePair{
			{Key: "reranker", Value: "decay"},
			{Key: "origin", Value: "4"},
			{Key: "scale", Value: "4"},
			{Key: "offset", Value: "4"},
			{Key: "decay", Value: "0.5"},
			{Key: "function", Value: "gauss"},
		},
	}
	funcScore, err := rerank.NewFunctionScore(schema, &schemapb.FunctionScore{
		Functions: []*schemapb.FunctionSchema{functionSchema},
	})
	s.NoError(err)
	nq := int64(2)
	topk := int64(10)
	offset := int64(0)
	// Reduce one shard result first; the rerank operator consumes the
	// reduced result.
	reduceOp := searchReduceOperator{
		context.Background(),
		schema.Fields[0],
		nq,
		topk,
		offset,
		1,
		[]int64{1},
		[]*planpb.QueryInfo{{}},
	}
	data := genTestSearchResultData(nq, topk, schemapb.DataType_Int64, "intField", 102, false)
	reduced, err := reduceOp.run(context.Background(), s.span, []*internalpb.SearchResults{data})
	s.NoError(err)
	op := rerankOperator{
		nq:            nq,
		topK:          topk,
		offset:        offset,
		roundDecimal:  10,
		functionScore: funcScore,
	}
	_, err = op.run(context.Background(), s.span, reduced[0], []string{"IP"})
	s.NoError(err)
}
// TestRequeryOp checks that requeryOperator.run succeeds when the underlying
// requery RPC is patched out with mockey.
func (s *SearchPipelineSuite) TestRequeryOp() {
	intField := testutils.GenerateScalarFieldData(schemapb.DataType_Int64, "int64", 20)
	intField.FieldId = 101
	// Patch requeryOperator.requery so no real query node is contacted.
	mocker := mockey.Mock((*requeryOperator).requery).Return(&milvuspb.QueryResults{
		FieldsData: []*schemapb.FieldData{intField},
	}, nil).Build()
	defer mocker.UnPatch()

	op := requeryOperator{
		traceCtx:         context.Background(),
		outputFieldNames: []string{"int64"},
	}
	ids := &schemapb.IDs{
		IdField: &schemapb.IDs_IntId{
			IntId: &schemapb.LongArray{Data: []int64{1, 2}},
		},
	}
	_, err := op.run(context.Background(), s.span, ids, []string{"int64"})
	s.NoError(err)
}
// TestOrganizeOp checks that organizeOperator regroups flat requery field
// data according to the per-result ID lists without error.
func (s *SearchPipelineSuite) TestOrganizeOp() {
	op := organizeOperator{
		traceCtx:           context.Background(),
		primaryFieldSchema: &schemapb.FieldSchema{FieldID: 100, Name: "pk", DataType: schemapb.DataType_Int64, IsPrimaryKey: true},
		collectionID:       1,
	}
	fields := []*schemapb.FieldData{
		{
			Type:      schemapb.DataType_Int64,
			FieldName: "pk",
			Field: &schemapb.FieldData_Scalars{
				Scalars: &schemapb.ScalarField{
					Data: &schemapb.ScalarField_LongData{
						LongData: &schemapb.LongArray{
							Data: []int64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
						},
					},
				},
			},
		}, {
			Type:      schemapb.DataType_Int64,
			FieldName: "int64",
			Field: &schemapb.FieldData_Scalars{
				Scalars: &schemapb.ScalarField{
					Data: &schemapb.ScalarField_LongData{
						LongData: &schemapb.LongArray{
							Data: []int64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
						},
					},
				},
			},
		},
	}
	// Two ID lists selecting (overlapping) subsets of the flat field data.
	ids := []*schemapb.IDs{
		{
			IdField: &schemapb.IDs_IntId{
				IntId: &schemapb.LongArray{
					Data: []int64{1, 4, 5, 9, 10},
				},
			},
		},
		{
			IdField: &schemapb.IDs_IntId{
				IntId: &schemapb.LongArray{
					Data: []int64{5, 6, 7, 8, 9, 10},
				},
			},
		},
	}
	ret, err := op.run(context.Background(), s.span, fields, ids)
	s.NoError(err)
	// Fix: assert on the result instead of the leftover debug fmt.Println.
	s.NotNil(ret)
}
// TestSearchPipeline runs the plain search pipeline end to end on one mocked
// shard result (2 queries, topk 10) and checks the shape of the final
// milvuspb result: query count, topks, IDs, scores and output fields.
func (s *SearchPipelineSuite) TestSearchPipeline() {
	collectionName := "test"
	task := &searchTask{
		ctx:            context.Background(),
		collectionName: collectionName,
		SearchRequest: &internalpb.SearchRequest{
			Base: &commonpb.MsgBase{
				MsgType:   commonpb.MsgType_Search,
				Timestamp: uint64(time.Now().UnixNano()),
			},
			MetricType:   "L2",
			Topk:         10,
			Nq:           2,
			PartitionIDs: []int64{1},
			CollectionID: 1,
			DbID:         1,
		},
		schema: &schemaInfo{
			CollectionSchema: &schemapb.CollectionSchema{
				Fields: []*schemapb.FieldSchema{
					{FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64, IsPrimaryKey: true},
					{FieldID: 101, Name: "intField", DataType: schemapb.DataType_Int64},
				},
			},
			pkField: &schemapb.FieldSchema{FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64, IsPrimaryKey: true},
		},
		queryInfos:             []*planpb.QueryInfo{{}},
		translatedOutputFields: []string{"intField"},
	}
	pipeline, err := newPipeline(searchPipe, task)
	s.NoError(err)
	results, err := pipeline.Run(context.Background(), s.span, []*internalpb.SearchResults{
		genTestSearchResultData(2, 10, schemapb.DataType_Int64, "intField", 101, false),
	})
	s.NoError(err)
	s.NotNil(results)
	s.NotNil(results.Results)
	s.Equal(int64(2), results.Results.NumQueries)
	s.Equal(int64(10), results.Results.Topks[0])
	s.Equal(int64(10), results.Results.Topks[1])
	s.NotNil(results.Results.Ids)
	s.NotNil(results.Results.Ids.GetIntId())
	s.Len(results.Results.Ids.GetIntId().Data, 20) // 2 queries * 10 topk
	s.NotNil(results.Results.Scores)
	s.Len(results.Results.Scores, 20) // 2 queries * 10 topk
	s.NotNil(results.Results.FieldsData)
	s.Len(results.Results.FieldsData, 1) // One output field
	s.Equal("intField", results.Results.FieldsData[0].FieldName)
	s.Equal(int64(101), results.Results.FieldsData[0].FieldId)
	// NOTE(review): debug print left in place; consider removing.
	fmt.Println(results)
}
// TestSearchPipelineWithRequery runs the search-with-requery pipeline on a
// mocked shard result, with the requery RPC patched out by mockey, and
// checks the shape of the final result.
func (s *SearchPipelineSuite) TestSearchPipelineWithRequery() {
	collectionName := "test_collection"
	task := &searchTask{
		ctx:            context.Background(),
		collectionName: collectionName,
		SearchRequest: &internalpb.SearchRequest{
			Base: &commonpb.MsgBase{
				MsgType:   commonpb.MsgType_Search,
				Timestamp: uint64(time.Now().UnixNano()),
			},
			MetricType:   "L2",
			Topk:         10,
			Nq:           2,
			PartitionIDs: []int64{1},
			CollectionID: 1,
			DbID:         1,
		},
		schema: &schemaInfo{
			CollectionSchema: &schemapb.CollectionSchema{
				Fields: []*schemapb.FieldSchema{
					{FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64, IsPrimaryKey: true},
					{FieldID: 101, Name: "intField", DataType: schemapb.DataType_Int64},
				},
			},
			pkField: &schemapb.FieldSchema{FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64, IsPrimaryKey: true},
		},
		queryInfos:             []*planpb.QueryInfo{{}},
		translatedOutputFields: []string{"intField"},
		node:                   nil,
	}
	// Mock requery operation: return both the requested output field and
	// the primary key column, 20 rows each (2 queries * topk 10).
	f1 := testutils.GenerateScalarFieldData(schemapb.DataType_Int64, "intField", 20)
	f1.FieldId = 101
	f2 := testutils.GenerateScalarFieldData(schemapb.DataType_Int64, "int64", 20)
	f2.FieldId = 100
	mocker := mockey.Mock((*requeryOperator).requery).Return(&milvuspb.QueryResults{
		FieldsData: []*schemapb.FieldData{f1, f2},
	}, nil).Build()
	defer mocker.UnPatch()
	pipeline, err := newPipeline(searchWithRequeryPipe, task)
	s.NoError(err)
	results, err := pipeline.Run(context.Background(), s.span, []*internalpb.SearchResults{
		genTestSearchResultData(2, 10, schemapb.DataType_Int64, "intField", 101, false),
	})
	s.NoError(err)
	s.NotNil(results)
	s.NotNil(results.Results)
	s.Equal(int64(2), results.Results.NumQueries)
	s.Equal(int64(10), results.Results.Topks[0])
	s.Equal(int64(10), results.Results.Topks[1])
	s.NotNil(results.Results.Ids)
	s.NotNil(results.Results.Ids.GetIntId())
	s.Len(results.Results.Ids.GetIntId().Data, 20) // 2 queries * 10 topk
	s.NotNil(results.Results.Scores)
	s.Len(results.Results.Scores, 20) // 2 queries * 10 topk
	s.NotNil(results.Results.FieldsData)
	s.Len(results.Results.FieldsData, 1) // One output field
	s.Equal("intField", results.Results.FieldsData[0].FieldName)
	s.Equal(int64(101), results.Results.FieldsData[0].FieldId)
}
// TestSearchWithRerankPipe runs the search-with-rerank pipeline: a decay
// (gauss) rerank function over "intField" is applied to one mocked shard
// result, then the final result shape is checked.
func (s *SearchPipelineSuite) TestSearchWithRerankPipe() {
	functionSchema := &schemapb.FunctionSchema{
		Name:             "test",
		Type:             schemapb.FunctionType_Rerank,
		InputFieldNames:  []string{"intField"},
		OutputFieldNames: []string{},
		Params: []*commonpb.KeyValuePair{
			{Key: "reranker", Value: "decay"},
			{Key: "origin", Value: "4"},
			{Key: "scale", Value: "4"},
			{Key: "offset", Value: "4"},
			{Key: "decay", Value: "0.5"},
			{Key: "function", Value: "gauss"},
		},
	}
	schema := &schemapb.CollectionSchema{
		Fields: []*schemapb.FieldSchema{
			{FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64, IsPrimaryKey: true},
			{FieldID: 101, Name: "intField", DataType: schemapb.DataType_Int64},
		},
	}
	funcScore, err := rerank.NewFunctionScore(schema, &schemapb.FunctionScore{
		Functions: []*schemapb.FunctionSchema{functionSchema},
	})
	s.NoError(err)
	task := &searchTask{
		ctx:            context.Background(),
		collectionName: "test_collection",
		SearchRequest: &internalpb.SearchRequest{
			MetricType:   "L2",
			Topk:         10,
			Nq:           2,
			PartitionIDs: []int64{1},
			CollectionID: 1,
			DbID:         1,
		},
		schema: &schemaInfo{
			CollectionSchema: schema,
			pkField:          &schemapb.FieldSchema{FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64, IsPrimaryKey: true},
		},
		queryInfos:             []*planpb.QueryInfo{{}},
		translatedOutputFields: []string{"intField"},
		node:                   nil,
		functionScore:          funcScore,
	}
	pipeline, err := newPipeline(searchWithRerankPipe, task)
	s.NoError(err)
	searchResults := genTestSearchResultData(2, 10, schemapb.DataType_Int64, "intField", 101, false)
	results, err := pipeline.Run(context.Background(), s.span, []*internalpb.SearchResults{searchResults})
	s.NoError(err)
	s.NotNil(results)
	s.NotNil(results.Results)
	s.Equal(int64(2), results.Results.NumQueries)
	s.Equal(int64(10), results.Results.Topks[0])
	s.Equal(int64(10), results.Results.Topks[1])
	s.NotNil(results.Results.Ids)
	s.NotNil(results.Results.Ids.GetIntId())
	s.Len(results.Results.Ids.GetIntId().Data, 20) // 2 queries * 10 topk
	s.NotNil(results.Results.Scores)
	s.Len(results.Results.Scores, 20) // 2 queries * 10 topk
	s.NotNil(results.Results.FieldsData)
	s.Len(results.Results.FieldsData, 1) // One output field
	s.Equal("intField", results.Results.FieldsData[0].FieldName)
	s.Equal(int64(101), results.Results.FieldsData[0].FieldId)
}
// TestSearchWithRerankRequeryPipe runs the combined rerank+requery pipeline:
// the requery RPC is mocked, a decay rerank over "intField" is applied, and
// the final result shape is checked.
func (s *SearchPipelineSuite) TestSearchWithRerankRequeryPipe() {
	functionSchema := &schemapb.FunctionSchema{
		Name:             "test",
		Type:             schemapb.FunctionType_Rerank,
		InputFieldNames:  []string{"intField"},
		OutputFieldNames: []string{},
		Params: []*commonpb.KeyValuePair{
			{Key: "reranker", Value: "decay"},
			{Key: "origin", Value: "4"},
			{Key: "scale", Value: "4"},
			{Key: "offset", Value: "4"},
			{Key: "decay", Value: "0.5"},
			{Key: "function", Value: "gauss"},
		},
	}
	schema := &schemapb.CollectionSchema{
		Fields: []*schemapb.FieldSchema{
			{FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64, IsPrimaryKey: true},
			{FieldID: 101, Name: "intField", DataType: schemapb.DataType_Int64},
		},
	}
	funcScore, err := rerank.NewFunctionScore(schema, &schemapb.FunctionScore{
		Functions: []*schemapb.FunctionSchema{functionSchema},
	})
	s.NoError(err)
	task := &searchTask{
		ctx:            context.Background(),
		collectionName: "test_collection",
		SearchRequest: &internalpb.SearchRequest{
			Base: &commonpb.MsgBase{
				MsgType:   commonpb.MsgType_Search,
				Timestamp: uint64(time.Now().UnixNano()),
			},
			MetricType:   "L2",
			Topk:         10,
			Nq:           2,
			PartitionIDs: []int64{1},
			CollectionID: 1,
			DbID:         1,
		},
		schema: &schemaInfo{
			CollectionSchema: schema,
			pkField:          &schemapb.FieldSchema{FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64, IsPrimaryKey: true},
		},
		queryInfos:             []*planpb.QueryInfo{{}},
		translatedOutputFields: []string{"intField"},
		node:                   nil,
		functionScore:          funcScore,
	}
	// Mock requery: return the output field plus the primary key column.
	f1 := testutils.GenerateScalarFieldData(schemapb.DataType_Int64, "intField", 20)
	f1.FieldId = 101
	f2 := testutils.GenerateScalarFieldData(schemapb.DataType_Int64, "int64", 20)
	f2.FieldId = 100
	mocker := mockey.Mock((*requeryOperator).requery).Return(&milvuspb.QueryResults{
		FieldsData: []*schemapb.FieldData{f1, f2},
	}, nil).Build()
	defer mocker.UnPatch()
	pipeline, err := newPipeline(searchWithRerankRequeryPipe, task)
	s.NoError(err)
	searchResults := genTestSearchResultData(2, 10, schemapb.DataType_Int64, "intField", 101, false)
	results, err := pipeline.Run(context.Background(), s.span, []*internalpb.SearchResults{searchResults})
	s.NoError(err)
	s.NotNil(results)
	s.NotNil(results.Results)
	s.Equal(int64(2), results.Results.NumQueries)
	s.Equal(int64(10), results.Results.Topks[0])
	s.Equal(int64(10), results.Results.Topks[1])
	s.NotNil(results.Results.Ids)
	s.NotNil(results.Results.Ids.GetIntId())
	s.Len(results.Results.Ids.GetIntId().Data, 20) // 2 queries * 10 topk
	s.NotNil(results.Results.Scores)
	s.Len(results.Results.Scores, 20) // 2 queries * 10 topk
	s.NotNil(results.Results.FieldsData)
	s.Len(results.Results.FieldsData, 1) // One output field
	s.Equal("intField", results.Results.FieldsData[0].FieldName)
	s.Equal(int64(101), results.Results.FieldsData[0].FieldId)
}
// TestHybridSearchPipe runs the hybrid search pipeline end to end on two
// mocked sub-search results and checks the merged output dimensions.
func (s *SearchPipelineSuite) TestHybridSearchPipe() {
	task := getHybridSearchTask("test_collection",
		[][]string{
			{"1", "2"},
			{"3", "4"},
		},
		[]string{})
	pipeline, err := newPipeline(hybridSearchPipe, task)
	s.NoError(err)

	sub1 := genTestSearchResultData(2, 10, schemapb.DataType_Int64, "intField", 101, true)
	sub2 := genTestSearchResultData(2, 10, schemapb.DataType_Int64, "intField", 101, true)
	results, err := pipeline.Run(context.Background(), s.span, []*internalpb.SearchResults{sub1, sub2})
	s.NoError(err)
	s.NotNil(results)
	s.NotNil(results.Results)
	s.Equal(int64(2), results.Results.NumQueries)
	s.Equal(int64(10), results.Results.Topks[0])
	s.Equal(int64(10), results.Results.Topks[1])
	s.NotNil(results.Results.Ids)
	s.NotNil(results.Results.Ids.GetIntId())
	s.Len(results.Results.Ids.GetIntId().Data, 20) // 2 queries * 10 topk
	s.NotNil(results.Results.Scores)
	s.Len(results.Results.Scores, 20) // 2 queries * 10 topk
}
// TestHybridSearchWithRequeryPipe runs the hybrid search pipeline with a
// mocked requery RPC and checks that the requested output field is attached
// to the merged result.
func (s *SearchPipelineSuite) TestHybridSearchWithRequeryPipe() {
	task := getHybridSearchTask("test_collection", [][]string{
		{"1", "2"},
		{"3", "4"},
	},
		[]string{"intField"},
	)
	// Mock requery: return the output field plus the primary key column.
	f1 := testutils.GenerateScalarFieldData(schemapb.DataType_Int64, "intField", 20)
	f1.FieldId = 101
	f2 := testutils.GenerateScalarFieldData(schemapb.DataType_Int64, "int64", 20)
	f2.FieldId = 100
	mocker := mockey.Mock((*requeryOperator).requery).Return(&milvuspb.QueryResults{
		FieldsData: []*schemapb.FieldData{f1, f2},
	}, nil).Build()
	defer mocker.UnPatch()
	pipeline, err := newPipeline(hybridSearchWithRequeryPipe, task)
	s.NoError(err)
	d1 := genTestSearchResultData(2, 10, schemapb.DataType_Int64, "intField", 101, true)
	d2 := genTestSearchResultData(2, 10, schemapb.DataType_Int64, "intField", 101, true)
	results, err := pipeline.Run(context.Background(), s.span, []*internalpb.SearchResults{d1, d2})
	s.NoError(err)
	s.NotNil(results)
	s.NotNil(results.Results)
	s.Equal(int64(2), results.Results.NumQueries)
	s.Equal(int64(10), results.Results.Topks[0])
	s.Equal(int64(10), results.Results.Topks[1])
	s.NotNil(results.Results.Ids)
	s.NotNil(results.Results.Ids.GetIntId())
	s.Len(results.Results.Ids.GetIntId().Data, 20) // 2 queries * 10 topk
	s.NotNil(results.Results.Scores)
	s.Len(results.Results.Scores, 20) // 2 queries * 10 topk
	s.NotNil(results.Results.FieldsData)
	s.Len(results.Results.FieldsData, 1) // One output field
	s.Equal("intField", results.Results.FieldsData[0].FieldName)
	s.Equal(int64(101), results.Results.FieldsData[0].FieldId)
}
// getHybridSearchTask builds a hybrid (advanced) searchTask fixture for the
// pipeline tests. data holds one slice of placeholder strings per sub
// request (its length sets the sub request's Nq); outputFields becomes both
// the user-requested and translated output field list. The task carries an
// RRF rerank function over the two-field test schema (pk "int64" + scalar
// "intField").
func getHybridSearchTask(collName string, data [][]string, outputFields []string) *searchTask {
	subReqs := []*milvuspb.SubSearchRequest{}
	for _, item := range data {
		subReq := &milvuspb.SubSearchRequest{
			SearchParams: []*commonpb.KeyValuePair{
				{Key: TopKKey, Value: "10"},
			},
			Nq: int64(len(item)),
		}
		subReqs = append(subReqs, subReq)
	}
	functionSchema := &schemapb.FunctionSchema{
		Name:             "test",
		Type:             schemapb.FunctionType_Rerank,
		InputFieldNames:  []string{},
		OutputFieldNames: []string{},
		Params: []*commonpb.KeyValuePair{
			{Key: "reranker", Value: "rrf"},
		},
	}
	schema := &schemapb.CollectionSchema{
		Fields: []*schemapb.FieldSchema{
			{FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64, IsPrimaryKey: true},
			{FieldID: 101, Name: "intField", DataType: schemapb.DataType_Int64},
		},
	}
	// Error intentionally ignored: the fixed schema/function pair is valid.
	funcScore, _ := rerank.NewFunctionScore(schema, &schemapb.FunctionScore{
		Functions: []*schemapb.FunctionSchema{functionSchema},
	})
	task := &searchTask{
		ctx:            context.Background(),
		collectionName: collName,
		SearchRequest: &internalpb.SearchRequest{
			Base: &commonpb.MsgBase{
				MsgType:   commonpb.MsgType_Search,
				Timestamp: uint64(time.Now().UnixNano()),
			},
			Topk:       10,
			Nq:         2,
			IsAdvanced: true,
			SubReqs: []*internalpb.SubSearchRequest{
				{
					Topk: 10,
					Nq:   2,
				},
				{
					Topk: 10,
					Nq:   2,
				},
			},
		},
		request: &milvuspb.SearchRequest{
			CollectionName: collName,
			SubReqs:        subReqs,
			SearchParams: []*commonpb.KeyValuePair{
				{Key: LimitKey, Value: "10"},
			},
			FunctionScore: &schemapb.FunctionScore{
				Functions: []*schemapb.FunctionSchema{functionSchema},
			},
			OutputFields: outputFields,
		},
		schema: &schemaInfo{
			CollectionSchema: schema,
			pkField:          &schemapb.FieldSchema{FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64, IsPrimaryKey: true},
		},
		mixCoord: nil,
		tr:       timerecord.NewTimeRecorder("test-search"),
		rankParams: &rankParams{
			limit:        10,
			offset:       0,
			roundDecimal: 0,
		},
		queryInfos:             []*planpb.QueryInfo{{}, {}},
		functionScore:          funcScore,
		translatedOutputFields: outputFields,
	}
	return task
}
// TestMergeIDsFunc verifies that mergeIDsFunc deduplicates and merges the
// primary keys of multiple search results, for both int64 and string PKs.
func (s *SearchPipelineSuite) TestMergeIDsFunc() {
	// int64 primary keys: overlapping sets merge without duplicates.
	{
		ids1 := &schemapb.IDs{
			IdField: &schemapb.IDs_IntId{
				IntId: &schemapb.LongArray{
					Data: []int64{1, 2, 3, 5},
				},
			},
		}
		ids2 := &schemapb.IDs{
			IdField: &schemapb.IDs_IntId{
				IntId: &schemapb.LongArray{
					Data: []int64{1, 2, 4, 5, 6},
				},
			},
		}
		rets := []*milvuspb.SearchResults{
			{
				Results: &schemapb.SearchResultData{
					Ids: ids1,
				},
			},
			{
				Results: &schemapb.SearchResultData{
					Ids: ids2,
				},
			},
		}
		allIDs, err := mergeIDsFunc(context.Background(), s.span, rets)
		s.NoError(err)
		sortedIds := allIDs[0].(*schemapb.IDs).GetIntId().GetData()
		slices.Sort(sortedIds)
		// Fix: testify takes (expected, actual); the arguments were reversed,
		// which produces misleading failure messages.
		s.Equal([]int64{1, 2, 3, 4, 5, 6}, sortedIds)
	}
	// string primary keys: same merge semantics.
	{
		ids1 := &schemapb.IDs{
			IdField: &schemapb.IDs_StrId{
				StrId: &schemapb.StringArray{
					Data: []string{"a", "b", "e"},
				},
			},
		}
		ids2 := &schemapb.IDs{
			IdField: &schemapb.IDs_StrId{
				StrId: &schemapb.StringArray{
					Data: []string{"a", "b", "c", "d"},
				},
			},
		}
		rets := []*milvuspb.SearchResults{
			{
				Results: &schemapb.SearchResultData{
					Ids: ids1,
				},
			},
		}
		rets = append(rets, &milvuspb.SearchResults{
			Results: &schemapb.SearchResultData{
				Ids: ids2,
			},
		})
		allIDs, err := mergeIDsFunc(context.Background(), s.span, rets)
		s.NoError(err)
		sortedIds := allIDs[0].(*schemapb.IDs).GetStrId().GetData()
		slices.Sort(sortedIds)
		s.Equal([]string{"a", "b", "c", "d", "e"}, sortedIds)
	}
}

View File

@ -5,12 +5,16 @@ import (
"fmt"
"github.com/cockroachdb/errors"
"go.opentelemetry.io/otel"
"go.uber.org/zap"
"google.golang.org/protobuf/proto"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/util/reduce"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
"github.com/milvus-io/milvus/pkg/v2/proto/planpb"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/metric"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
@ -470,3 +474,105 @@ func fillInEmptyResult(numQueries int64) *milvuspb.SearchResults {
},
}
}
// reduceResults decodes the shard-level search results and merges them into
// a single milvuspb.SearchResults, honoring nq/topK/offset, the metric type,
// and any group-by settings carried in queryInfo. An empty (but well-formed)
// result is returned when no shard produced decodable data.
func reduceResults(ctx context.Context, toReduceResults []*internalpb.SearchResults, nq, topK, offset int64, metricType string, pkType schemapb.DataType, queryInfo *planpb.QueryInfo, isAdvance bool, collectionID int64, partitionIDs []int64) (*milvuspb.SearchResults, error) {
	ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "reduceResults")
	defer sp.End()
	log := log.Ctx(ctx)

	// Decode all search results
	decoded, err := decodeSearchResults(ctx, toReduceResults)
	if err != nil {
		log.Warn("failed to decode search results", zap.Error(err))
		return nil, err
	}
	if len(decoded) == 0 {
		// Nothing to merge: hand back an empty result set for nq queries.
		return fillInEmptyResult(nq), nil
	}

	// Reduce all search results
	log.Debug("proxy search post execute reduce",
		zap.Int64("collection", collectionID),
		zap.Int64s("partitionIDs", partitionIDs),
		zap.Int("number of valid search results", len(decoded)))
	info := reduce.NewReduceSearchResultInfo(nq, topK).
		WithMetricType(metricType).
		WithPkType(pkType).
		WithOffset(offset).
		WithGroupByField(queryInfo.GetGroupByFieldId()).
		WithGroupSize(queryInfo.GetGroupSize()).
		WithAdvance(isAdvance)
	merged, err := reduceSearchResult(ctx, decoded, info)
	if err != nil {
		log.Warn("failed to reduce search results", zap.Error(err))
		return nil, err
	}
	return merged, nil
}
// decodeSearchResults unmarshals the serialized payload (SlicedBlob) of each
// shard result into schemapb.SearchResultData. Results carrying no blob
// (e.g. empty shards) are skipped.
func decodeSearchResults(ctx context.Context, searchResults []*internalpb.SearchResults) ([]*schemapb.SearchResultData, error) {
	ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "decodeSearchResults")
	defer sp.End()
	tr := timerecord.NewTimeRecorder("decodeSearchResults")

	decoded := make([]*schemapb.SearchResultData, 0, len(searchResults))
	for _, partial := range searchResults {
		if partial.SlicedBlob == nil {
			// No payload from this shard; nothing to decode.
			continue
		}
		data := &schemapb.SearchResultData{}
		if err := proto.Unmarshal(partial.SlicedBlob, data); err != nil {
			return nil, err
		}
		decoded = append(decoded, data)
	}
	tr.CtxElapse(ctx, "decodeSearchResults done")
	return decoded, nil
}
// checkSearchResultData sanity-checks a decoded search result against the
// request: nq and topk must match, and the number of scores must equal the
// number of hit primary keys.
func checkSearchResultData(data *schemapb.SearchResultData, nq int64, topk int64, pkHitNum int) error {
	switch {
	case data.NumQueries != nq:
		return fmt.Errorf("search result's nq(%d) mis-match with %d", data.NumQueries, nq)
	case data.TopK != topk:
		return fmt.Errorf("search result's topk(%d) mis-match with %d", data.TopK, topk)
	case len(data.Scores) != pkHitNum:
		return fmt.Errorf("search result's score length invalid, score length=%d, expectedLength=%d",
			len(data.Scores), pkHitNum)
	}
	return nil
}
// selectHighestScoreIndex picks, for query qi, the sub-search result that
// currently offers the best candidate: the highest score, with ties broken
// by primary-key comparison. cursors[i] is the read position inside
// sub-search i for this query; subSearchNqOffset[i][qi] maps (sub-search,
// query) to the flat offset of that query's first score. Returns the chosen
// sub-search index and the flat index of its candidate, or (-1, -1) when all
// cursors are exhausted.
func selectHighestScoreIndex(ctx context.Context, subSearchResultData []*schemapb.SearchResultData, subSearchNqOffset [][]int64, cursors []int64, qi int64) (int, int64) {
	var (
		subSearchIdx        = -1
		resultDataIdx int64 = -1
	)
	maxScore := minFloat32
	for i := range cursors {
		if cursors[i] >= subSearchResultData[i].Topks[qi] {
			// This sub-search has no more candidates for query qi.
			continue
		}
		sIdx := subSearchNqOffset[i][qi] + cursors[i]
		sScore := subSearchResultData[i].Scores[sIdx]
		// Choose the larger score idx or the smaller pk idx with the same score
		if subSearchIdx == -1 || sScore > maxScore {
			subSearchIdx = i
			resultDataIdx = sIdx
			maxScore = sScore
		} else if sScore == maxScore {
			if subSearchIdx == -1 {
				// A bad case happens where Knowhere returns distance/score == +/-maxFloat32
				// by mistake.
				// NOTE(review): subSearchIdx == -1 is already handled by the
				// branch above, so this guard appears unreachable — confirm
				// before relying on the error log.
				log.Ctx(ctx).Error("a bad score is returned, something is wrong here!", zap.Float32("score", sScore))
			} else if typeutil.ComparePK(
				typeutil.GetPK(subSearchResultData[i].GetIds(), sIdx),
				typeutil.GetPK(subSearchResultData[subSearchIdx].GetIds(), resultDataIdx)) {
				// Equal scores: prefer the candidate whose PK compares first.
				subSearchIdx = i
				resultDataIdx = sIdx
				maxScore = sScore
			}
		}
	}
	return subSearchIdx, resultDataIdx
}

View File

@ -15,6 +15,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/pkg/v2/common"
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
"github.com/milvus-io/milvus/pkg/v2/proto/planpb"
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
@ -577,3 +578,11 @@ func convertHybridSearchToSearch(req *milvuspb.HybridSearchRequest) *milvuspb.Se
}
return ret
}
// getMetricType returns the metric type reported by the first shard result,
// or "" when there are none. The metric type in the request may be empty, so
// it can only be taken from the results themselves.
func getMetricType(toReduceResults []*internalpb.SearchResults) string {
	if len(toReduceResults) == 0 {
		return ""
	}
	return toReduceResults[0].GetMetricType()
}

View File

@ -10,7 +10,6 @@ import (
"github.com/cockroachdb/errors"
"github.com/samber/lo"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/trace"
"go.uber.org/zap"
"google.golang.org/protobuf/proto"
@ -22,7 +21,6 @@ import (
"github.com/milvus-io/milvus/internal/util/exprutil"
"github.com/milvus-io/milvus/internal/util/function"
"github.com/milvus-io/milvus/internal/util/function/rerank"
"github.com/milvus-io/milvus/internal/util/reduce"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/metrics"
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
@ -96,9 +94,6 @@ type searchTask struct {
// we always remove pk field from output fields, as search result already contains pk field.
// if the user explicitly set pk field in output fields, we add it back to the result.
userRequestedPkFieldExplicitly bool
// To facilitate writing unit tests
requeryFunc func(t *searchTask, span trace.Span, ids *schemapb.IDs, outputFields []string) (*milvuspb.QueryResults, error)
}
func (t *searchTask) CanSkipAllocTimestamp() bool {
@ -488,7 +483,6 @@ func (t *searchTask) initAdvancedSearchRequest(ctx context.Context) error {
t.SearchRequest.GroupByFieldId = t.rankParams.GetGroupByFieldId()
t.SearchRequest.GroupSize = t.rankParams.GetGroupSize()
// used for requery
if t.partitionKeyMode {
t.SearchRequest.PartitionIDs = t.partitionIDsSet.Collect()
}
@ -496,57 +490,6 @@ func (t *searchTask) initAdvancedSearchRequest(ctx context.Context) error {
return nil
}
// advancedPostProcess merges the shard-level results of a hybrid (advanced)
// search: it regroups sub-results by their sub-request index, reduces each
// group separately, reranks the reduced results via hybridSearchRank, and
// finally trims the output columns to the user-requested field set.
func (t *searchTask) advancedPostProcess(ctx context.Context, span trace.Span, toReduceResults []*internalpb.SearchResults) error {
	// Collecting the results of a subsearch
	// [[shard1, shard2, ...],[shard1, shard2, ...]]
	multipleInternalResults := make([][]*internalpb.SearchResults, len(t.SearchRequest.GetSubReqs()))
	for _, searchResult := range toReduceResults {
		// if get a non-advanced result, skip all
		if !searchResult.GetIsAdvanced() {
			continue
		}
		for _, subResult := range searchResult.GetSubResults() {
			// Shallow copy: reuse the sub-result's blob and counters without
			// duplicating the payload.
			internalResults := &internalpb.SearchResults{
				MetricType:     subResult.GetMetricType(),
				NumQueries:     subResult.GetNumQueries(),
				TopK:           subResult.GetTopK(),
				SlicedBlob:     subResult.GetSlicedBlob(),
				SlicedNumCount: subResult.GetSlicedNumCount(),
				SlicedOffset:   subResult.GetSlicedOffset(),
				IsAdvanced:     false,
			}
			reqIndex := subResult.GetReqIndex()
			multipleInternalResults[reqIndex] = append(multipleInternalResults[reqIndex], internalResults)
		}
	}
	multipleMilvusResults := make([]*milvuspb.SearchResults, len(t.SearchRequest.GetSubReqs()))
	searchMetrics := []string{}
	for index, internalResults := range multipleInternalResults {
		subReq := t.SearchRequest.GetSubReqs()[index]
		// Since the metrictype in the request may be empty, it can only be obtained from the result
		subMetricType := getMetricType(internalResults)
		result, err := t.reduceResults(t.ctx, internalResults, subReq.GetNq(), subReq.GetTopk(), subReq.GetOffset(), subMetricType, t.queryInfos[index], true)
		if err != nil {
			return err
		}
		searchMetrics = append(searchMetrics, subMetricType)
		multipleMilvusResults[index] = result
	}
	if err := t.hybridSearchRank(ctx, span, multipleMilvusResults, searchMetrics); err != nil {
		return err
	}
	// Keep only the columns the user asked for.
	t.result.Results.FieldsData = lo.Filter(t.result.Results.FieldsData, func(fieldData *schemapb.FieldData, i int) bool {
		return lo.Contains(t.translatedOutputFields, fieldData.GetFieldName())
	})
	t.fillResult()
	return nil
}
func (t *searchTask) fillResult() {
limit := t.SearchRequest.GetTopk() - t.SearchRequest.GetOffset()
resultSizeInsufficient := false
@ -558,111 +501,6 @@ func (t *searchTask) fillResult() {
}
t.resultSizeInsufficient = resultSizeInsufficient
t.result.CollectionName = t.collectionName
t.fillInFieldInfo()
}
// mergeIDs deduplicates the primary keys of several result sets into one
// schemapb.IDs and returns it with the unique-key count. Int64 and string
// keys are collected separately; when any int64 keys exist they take
// precedence and string keys are ignored (PK types are expected to be
// homogeneous across results). Returns (nil, 0) when nothing was collected.
func mergeIDs(idsList []*schemapb.IDs) (*schemapb.IDs, int) {
	intSet := typeutil.NewSet[int64]()
	strSet := typeutil.NewSet[string]()
	for _, ids := range idsList {
		if ids == nil {
			continue
		}
		switch ids.GetIdField().(type) {
		case *schemapb.IDs_IntId:
			intSet.Insert(ids.GetIntId().GetData()...)
		case *schemapb.IDs_StrId:
			strSet.Insert(ids.GetStrId().GetData()...)
		}
	}
	if n := intSet.Len(); n > 0 {
		return &schemapb.IDs{
			IdField: &schemapb.IDs_IntId{
				IntId: &schemapb.LongArray{Data: intSet.Collect()},
			},
		}, n
	}
	if n := strSet.Len(); n > 0 {
		return &schemapb.IDs{
			IdField: &schemapb.IDs_StrId{
				StrId: &schemapb.StringArray{Data: strSet.Collect()},
			},
		}, n
	}
	return nil, 0
}
// hybridSearchRank reranks the per-sub-request reduced results into the
// final t.result. When t.needRequery is set it first merges all candidate
// PKs, requeries the output + rerank input fields in one round trip,
// attaches those fields to each sub-result, reranks, and then re-attaches
// the requeried fields to the reranked result.
func (t *searchTask) hybridSearchRank(ctx context.Context, span trace.Span, multipleMilvusResults []*milvuspb.SearchResults, searchMetrics []string) error {
	var err error
	// processRerank applies the configured rerank function (udf/decay/rrf)
	// to the given per-sub-request results under its own tracing span.
	processRerank := func(ctx context.Context, results []*milvuspb.SearchResults) (*milvuspb.SearchResults, error) {
		ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-call-rerank-function-udf")
		defer sp.End()
		groupScorerStr := getGroupScorerStr(t.request.GetSearchParams())
		params := rerank.NewSearchParams(
			t.Nq, t.rankParams.limit, t.rankParams.offset, t.rankParams.roundDecimal,
			t.rankParams.groupByFieldId, t.rankParams.groupSize, t.rankParams.strictGroupSize, groupScorerStr, searchMetrics,
		)
		return t.functionScore.Process(ctx, params, results)
	}
	// The first step of hybrid search is without meta information. If rerank requires meta data, we need to do requery.
	// At this time, outputFields and rerank input_fields will be recalled.
	// If we want to save memory, we can only recall the rerank input_fields in this step, and recall the output_fields in the third step
	if t.needRequery {
		idsList := lo.FilterMap(multipleMilvusResults, func(m *milvuspb.SearchResults, _ int) (*schemapb.IDs, bool) {
			return m.Results.Ids, true
		})
		allIDs, count := mergeIDs(idsList)
		if count == 0 {
			// No candidates at all: emit an empty, well-formed result.
			t.result = &milvuspb.SearchResults{
				Status: merr.Success(),
				Results: &schemapb.SearchResultData{
					NumQueries: t.Nq,
					TopK:       t.rankParams.limit,
					FieldsData: make([]*schemapb.FieldData, 0),
					Scores:     []float32{},
					Ids:        &schemapb.IDs{},
					Topks:      []int64{},
				},
			}
			return nil
		}
		// Fetch user output fields plus every field the rerank function reads.
		allNames := typeutil.NewSet[string](t.translatedOutputFields...)
		allNames.Insert(t.functionScore.GetAllInputFieldNames()...)
		queryResult, err := t.requeryFunc(t, span, allIDs, allNames.Collect())
		if err != nil {
			log.Warn("failed to requery", zap.Error(err))
			return err
		}
		// Distribute the requeried columns back onto each sub-result by ID.
		fields, err := t.reorganizeRequeryResults(ctx, queryResult.GetFieldsData(), idsList)
		if err != nil {
			return err
		}
		for i := 0; i < len(multipleMilvusResults); i++ {
			multipleMilvusResults[i].Results.FieldsData = fields[i]
		}
		if t.result, err = processRerank(ctx, multipleMilvusResults); err != nil {
			return err
		}
		// Re-attach the requeried fields to the final (reranked) ID order.
		if fields, err := t.reorganizeRequeryResults(ctx, queryResult.GetFieldsData(), []*schemapb.IDs{t.result.Results.Ids}); err != nil {
			return err
		} else {
			t.result.Results.FieldsData = fields[0]
		}
	} else {
		if t.result, err = processRerank(ctx, multipleMilvusResults); err != nil {
			return err
		}
	}
	return nil
}
func (t *searchTask) initSearchRequest(ctx context.Context) error {
@ -677,19 +515,13 @@ func (t *searchTask) initSearchRequest(ctx context.Context) error {
}
if t.request.FunctionScore != nil {
// TODO: When rerank is configured, range search is also supported
if isIterator {
return merr.WrapErrParameterInvalidMsg("Range search do not support rerank")
}
if t.functionScore, err = rerank.NewFunctionScore(t.schema.CollectionSchema, t.request.FunctionScore); err != nil {
log.Warn("Failed to create function score", zap.Error(err))
return err
}
// TODO: When rerank is configured, grouping search is also supported
if !t.functionScore.IsSupportGroup() && queryInfo.GetGroupByFieldId() > 0 {
return merr.WrapErrParameterInvalidMsg("Current rerank does not support grouping search")
return merr.WrapErrParameterInvalidMsg("Rerank %s does not support grouping search", t.functionScore.RerankName())
}
}
@ -773,57 +605,6 @@ func (t *searchTask) initSearchRequest(ctx context.Context) error {
return nil
}
// searchPostProcess reduces the shard-level results of a plain (non-advanced)
// search into t.result, optionally applies the configured rerank function,
// requeries field data when needed, and finally trims the field data down to
// the translated output fields.
func (t *searchTask) searchPostProcess(ctx context.Context, span trace.Span, toReduceResults []*internalpb.SearchResults) error {
	metricType := getMetricType(toReduceResults)
	result, err := t.reduceResults(t.ctx, toReduceResults, t.SearchRequest.GetNq(), t.SearchRequest.GetTopk(), t.SearchRequest.GetOffset(), metricType, t.queryInfos[0], false)
	if err != nil {
		return err
	}
	// Rerank only when a function score is configured and the reduced result
	// actually carries field data for it to consume.
	if t.functionScore != nil && len(result.Results.FieldsData) != 0 {
		{
			ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-call-rerank-function-udf")
			defer sp.End()
			groupScorerStr := getGroupScorerStr(t.request.GetSearchParams())
			params := rerank.NewSearchParams(t.Nq, t.SearchRequest.GetTopk(), t.SearchRequest.GetOffset(),
				t.queryInfos[0].RoundDecimal, t.queryInfos[0].GroupByFieldId, t.queryInfos[0].GroupSize, t.queryInfos[0].StrictGroupSize, groupScorerStr, []string{metricType})
			// rank only returns id and score
			if t.result, err = t.functionScore.Process(ctx, params, []*milvuspb.SearchResults{result}); err != nil {
				return err
			}
		}
		if !t.needRequery {
			// No requery planned: reattach the pre-rerank field data, reordered
			// to follow the reranked ID order.
			fields, err := t.reorganizeRequeryResults(ctx, result.Results.FieldsData, []*schemapb.IDs{t.result.Results.Ids})
			if err != nil {
				return err
			}
			t.result.Results.FieldsData = fields[0]
		}
	} else {
		t.result = result
	}
	t.fillResult()
	if t.needRequery {
		// Default to the standard requery implementation when no override
		// (e.g. a test stub) was installed.
		if t.requeryFunc == nil {
			t.requeryFunc = requeryImpl
		}
		queryResult, err := t.requeryFunc(t, span, t.result.Results.Ids, t.translatedOutputFields)
		if err != nil {
			log.Warn("failed to requery", zap.Error(err))
			return err
		}
		fields, err := t.reorganizeRequeryResults(ctx, queryResult.GetFieldsData(), []*schemapb.IDs{t.result.Results.Ids})
		if err != nil {
			return err
		}
		t.result.Results.FieldsData = fields[0]
	}
	// Drop any helper fields (e.g. rerank inputs or the PK fetched for
	// alignment) that the caller did not ask for.
	t.result.Results.FieldsData = lo.Filter(t.result.Results.FieldsData, func(fieldData *schemapb.FieldData, i int) bool {
		return lo.Contains(t.translatedOutputFields, fieldData.GetFieldName())
	})
	return nil
}
func (t *searchTask) tryGeneratePlan(params []*commonpb.KeyValuePair, dsl string, exprTemplateValues map[string]*schemapb.TemplateValue) (*planpb.PlanNode, *planpb.QueryInfo, int64, bool, error) {
annsFieldName, err := funcutil.GetAttrByKeyFromRepeatedKV(AnnsFieldKey, params)
if err != nil || len(annsFieldName) == 0 {
@ -919,50 +700,6 @@ func (t *searchTask) Execute(ctx context.Context) error {
return nil
}
// getMetricType returns the metric type reported by the first shard result,
// or an empty string when there are no results to inspect.
func getMetricType(toReduceResults []*internalpb.SearchResults) string {
	if len(toReduceResults) == 0 {
		return ""
	}
	return toReduceResults[0].GetMetricType()
}
// reduceResults decodes every shard's serialized search result and merges them
// into a single client-facing result, honoring nq/topK/offset and any grouping
// settings from queryInfo. isAdvance marks reduction for hybrid (advanced)
// search. An empty result is returned when no shard produced data.
func (t *searchTask) reduceResults(ctx context.Context, toReduceResults []*internalpb.SearchResults, nq, topK int64, offset int64, metricType string, queryInfo *planpb.QueryInfo, isAdvance bool) (*milvuspb.SearchResults, error) {
	ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "reduceResults")
	defer sp.End()

	log := log.Ctx(ctx)

	// Deserialize each shard's payload before reducing.
	decoded, err := decodeSearchResults(ctx, toReduceResults)
	if err != nil {
		log.Warn("failed to decode search results", zap.Error(err))
		return nil, err
	}
	if len(decoded) == 0 {
		return fillInEmptyResult(nq), nil
	}

	log.Debug("proxy search post execute reduce",
		zap.Int64("collection", t.GetCollectionID()),
		zap.Int64s("partitionIDs", t.GetPartitionIDs()),
		zap.Int("number of valid search results", len(decoded)))

	// The PK type drives how result IDs are compared during the merge.
	pkField, err := t.schema.GetPkField()
	if err != nil {
		log.Warn("failed to get primary field schema", zap.Error(err))
		return nil, err
	}

	reduceInfo := reduce.NewReduceSearchResultInfo(nq, topK).
		WithMetricType(metricType).
		WithPkType(pkField.GetDataType()).
		WithOffset(offset).
		WithGroupByField(queryInfo.GetGroupByFieldId()).
		WithGroupSize(queryInfo.GetGroupSize()).
		WithAdvance(isAdvance)
	reduced, err := reduceSearchResult(ctx, decoded, reduceInfo)
	if err != nil {
		log.Warn("failed to reduce search results", zap.Error(err))
		return nil, err
	}
	return reduced, nil
}
// find the last bound based on reduced results and metric type
// only support nq == 1, for search iterator v2
func getLastBound(result *milvuspb.SearchResults, incomingLastBound *float32, metricType string) float32 {
@ -1017,15 +754,16 @@ func (t *searchTask) PostExecute(ctx context.Context) error {
t.isTopkReduce = isTopkReduce
t.isRecallEvaluation = isRecallEvaluation
if t.SearchRequest.GetIsAdvanced() {
err = t.advancedPostProcess(ctx, sp, toReduceResults)
} else {
err = t.searchPostProcess(ctx, sp, toReduceResults)
}
// call pipeline
pipeline, err := newBuiltInPipeline(t)
if err != nil {
log.Warn("Faild to create post process pipeline")
return err
}
if t.result, err = pipeline.Run(ctx, sp, toReduceResults); err != nil {
return err
}
t.fillResult()
t.result.Results.OutputFields = t.userOutputFields
t.result.CollectionName = t.request.GetCollectionName()
@ -1143,161 +881,6 @@ func (t *searchTask) estimateResultSize(nq int64, topK int64) (int64, error) {
//return int64(sizePerRecord) * nq * topK, nil
}
// requeryImpl fetches the given outputFields for the primary keys in ids by
// issuing an internal retrieve against the same collection/partitions as the
// search, at the search's timestamps and consistency level. It is the default
// value of searchTask.requeryFunc and is used to backfill field data that the
// first search phase did not return.
func requeryImpl(t *searchTask, span trace.Span, ids *schemapb.IDs, outputFields []string) (*milvuspb.QueryResults, error) {
	queryReq := &milvuspb.QueryRequest{
		Base: &commonpb.MsgBase{
			MsgType:   commonpb.MsgType_Retrieve,
			Timestamp: t.BeginTs(),
		},
		DbName:           t.request.GetDbName(),
		CollectionName:   t.request.GetCollectionName(),
		ConsistencyLevel: t.SearchRequest.GetConsistencyLevel(),
		NotReturnAllMeta: t.request.GetNotReturnAllMeta(),
		Expr:             "",
		OutputFields:     outputFields,
		PartitionNames:   t.request.GetPartitionNames(),
		// Reuse the search's guarantee timestamp instead of deriving a new one
		// so the requery observes the same data view as the search.
		UseDefaultConsistency: false,
		GuaranteeTimestamp:    t.SearchRequest.GuaranteeTimestamp,
	}
	pkField, err := typeutil.GetPrimaryFieldSchema(t.schema.CollectionSchema)
	if err != nil {
		return nil, err
	}
	// Build a "pk in [...]" retrieval plan directly from the ID set.
	plan := planparserv2.CreateRequeryPlan(pkField, ids)
	// Propagate the per-channel timestamps observed during search so the
	// requery can skip redundant waiting (MVCC view reuse).
	channelsMvcc := make(map[string]Timestamp)
	for k, v := range t.queryChannelsTs {
		channelsMvcc[k] = v
	}
	qt := &queryTask{
		ctx:       t.ctx,
		Condition: NewTaskCondition(t.ctx),
		RetrieveRequest: &internalpb.RetrieveRequest{
			Base: commonpbutil.NewMsgBase(
				commonpbutil.WithMsgType(commonpb.MsgType_Retrieve),
				commonpbutil.WithSourceID(paramtable.GetNodeID()),
			),
			ReqID:            paramtable.GetNodeID(),
			PartitionIDs:     t.GetPartitionIDs(), // use search partitionIDs
			ConsistencyLevel: t.ConsistencyLevel,
		},
		request:      queryReq,
		plan:         plan,
		mixCoord:     t.node.(*Proxy).mixCoord,
		lb:           t.node.(*Proxy).lbPolicy,
		channelsMvcc: channelsMvcc,
		fastSkip:     true,
		reQuery:      true,
	}
	queryResult, err := t.node.(*Proxy).query(t.ctx, qt, span)
	if err != nil {
		return nil, err
	}
	if queryResult.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
		return nil, merr.Error(queryResult.GetStatus())
	}
	return queryResult, err
}
// isEmpty reports whether the ID set holds no primary keys. A nil pointer,
// an unset oneof, or ID arrays with zero elements all count as empty.
func isEmpty(ids *schemapb.IDs) bool {
	if ids == nil {
		return true
	}
	if intIDs := ids.GetIntId(); intIDs != nil && len(intIDs.Data) != 0 {
		return false
	}
	if strIDs := ids.GetStrId(); strIDs != nil && len(strIDs.Data) != 0 {
		return false
	}
	return true
}
// reorganizeRequeryResults redistributes requeried field data to each ID list
// in idsList: for every list it returns the same fields with rows reordered to
// follow that list's ID order. Lists with no IDs get empty field placeholders
// that preserve the field layout (name/id/type/dynamic flag).
func (t *searchTask) reorganizeRequeryResults(ctx context.Context, fields []*schemapb.FieldData, idsList []*schemapb.IDs) ([][]*schemapb.FieldData, error) {
	_, sp := otel.Tracer(typeutil.ProxyRole).Start(t.ctx, "reorganizeRequeryResults")
	defer sp.End()

	pkField, err := typeutil.GetPrimaryFieldSchema(t.schema.CollectionSchema)
	if err != nil {
		return nil, err
	}
	pkData, err := typeutil.GetPrimaryFieldData(fields, pkField)
	if err != nil {
		return nil, err
	}

	// Map every primary key in the requery result to its row offset so rows
	// can be picked in arbitrary order below.
	pkOffsets := make(map[any]int)
	pkAt := typeutil.GetDataIterator(pkData)
	total := typeutil.GetPKSize(pkData)
	for row := 0; row < total; row++ {
		pkOffsets[pkAt(row)] = row
	}

	organized := make([][]*schemapb.FieldData, len(idsList))
	for idx, ids := range idsList {
		if isEmpty(ids) {
			// No hits for this entry: keep the field schema but carry no rows.
			placeholders := make([]*schemapb.FieldData, 0, len(fields))
			for _, field := range fields {
				placeholders = append(placeholders, &schemapb.FieldData{
					Type:      field.Type,
					FieldName: field.FieldName,
					FieldId:   field.FieldId,
					IsDynamic: field.IsDynamic,
				})
			}
			organized[idx] = placeholders
			continue
		}
		picked, err := t.pickFieldData(ids, pkOffsets, fields)
		if err != nil {
			return nil, err
		}
		organized[idx] = picked
	}
	return organized, nil
}
// pickFieldData picks rows out of the requeried field data so they follow the
// order of ids rather than the (arbitrary) order the query returned them in.
// pkOffset maps each primary key to its row offset in fields. An
// ErrInconsistentRequery is returned when an ID is missing from the requery
// result.
//
// Reorganize Results. The order of query result ids will be altered and differ from queried ids.
// We should reorganize query results to keep the order of original queried ids. For example:
// ===========================================
//  3  2  5  4  1  (query ids)
//       ||
//       || (query)
//       \/
//  4  3  5  1  2  (result ids)
// v4 v3 v5 v1 v2  (result vectors)
//       ||
//       || (reorganize)
//       \/
//  3  2  5  4  1  (result ids)
// v3 v2 v5 v4 v1  (result vectors)
// ===========================================
func (t *searchTask) pickFieldData(ids *schemapb.IDs, pkOffset map[any]int, fields []*schemapb.FieldData) ([]*schemapb.FieldData, error) {
	fieldsData := make([]*schemapb.FieldData, len(fields))
	// Hoist the size computation out of the loop condition.
	total := typeutil.GetSizeOfIDs(ids)
	for i := 0; i < total; i++ {
		id := typeutil.GetPK(ids, int64(i))
		// Single map lookup (the original looked the key up twice).
		offset, ok := pkOffset[id]
		if !ok {
			// %v instead of %s: the PK is an int64 or string boxed in `any`;
			// %s would render int64 PKs as "%!s(int64=...)".
			return nil, merr.WrapErrInconsistentRequery(fmt.Sprintf(
				"incomplete query result, missing id %v, len(searchIDs) = %d, len(queryIDs) = %d, collection=%d",
				id, total, len(pkOffset), t.GetCollectionID()))
		}
		typeutil.AppendFieldData(fieldsData, fields, int64(offset))
	}
	return fieldsData, nil
}
// fillInFieldInfo backfills name, data type, and dynamic flag on each result
// field from the collection schema, matching entries by field ID.
func (t *searchTask) fillInFieldInfo() {
	for _, resultField := range t.result.Results.FieldsData {
		if resultField == nil {
			continue
		}
		for _, schemaField := range t.schema.Fields {
			if resultField.FieldId != schemaField.FieldID {
				continue
			}
			resultField.FieldName = schemaField.Name
			resultField.Type = schemaField.DataType
			resultField.IsDynamic = schemaField.IsDynamic
		}
	}
}
func (t *searchTask) collectSearchResults(ctx context.Context) ([]*internalpb.SearchResults, error) {
select {
case <-t.TraceCtx().Done():
@ -1316,77 +899,6 @@ func (t *searchTask) collectSearchResults(ctx context.Context) ([]*internalpb.Se
}
}
// decodeSearchResults unmarshals the serialized result blob of every shard
// response. Shards that returned no payload are skipped rather than treated
// as errors.
func decodeSearchResults(ctx context.Context, searchResults []*internalpb.SearchResults) ([]*schemapb.SearchResultData, error) {
	ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "decodeSearchResults")
	defer sp.End()
	tr := timerecord.NewTimeRecorder("decodeSearchResults")

	decoded := make([]*schemapb.SearchResultData, 0, len(searchResults))
	for _, partial := range searchResults {
		if partial.SlicedBlob == nil {
			continue
		}
		data := &schemapb.SearchResultData{}
		if err := proto.Unmarshal(partial.SlicedBlob, data); err != nil {
			return nil, err
		}
		decoded = append(decoded, data)
	}
	tr.CtxElapse(ctx, "decodeSearchResults done")
	return decoded, nil
}
// checkSearchResultData validates that a decoded shard result matches the
// expected query count (nq), topk, and score-array length (pkHitNum).
func checkSearchResultData(data *schemapb.SearchResultData, nq int64, topk int64, pkHitNum int) error {
	if got := data.NumQueries; got != nq {
		return fmt.Errorf("search result's nq(%d) mis-match with %d", got, nq)
	}
	if got := data.TopK; got != topk {
		return fmt.Errorf("search result's topk(%d) mis-match with %d", got, topk)
	}
	if got := len(data.Scores); got != pkHitNum {
		return fmt.Errorf("search result's score length invalid, score length=%d, expectedLength=%d", got, pkHitNum)
	}
	return nil
}
// selectHighestScoreIndex picks, for query qi, the next best candidate across
// all sub-search results: the one with the highest score at each sub-search's
// current cursor, breaking score ties by primary key via typeutil.ComparePK.
// It returns the winning sub-search index and the flat offset of its candidate
// in that sub-search's result data, or (-1, -1) when all cursors are exhausted.
func selectHighestScoreIndex(ctx context.Context, subSearchResultData []*schemapb.SearchResultData, subSearchNqOffset [][]int64, cursors []int64, qi int64) (int, int64) {
	var (
		subSearchIdx        = -1
		resultDataIdx int64 = -1
	)
	maxScore := minFloat32
	for i := range cursors {
		// Skip sub-searches whose results for this query are exhausted.
		if cursors[i] >= subSearchResultData[i].Topks[qi] {
			continue
		}
		// Flat offset of this sub-search's current candidate for query qi.
		sIdx := subSearchNqOffset[i][qi] + cursors[i]
		sScore := subSearchResultData[i].Scores[sIdx]

		// Choose the larger score idx or the smaller pk idx with the same score
		if subSearchIdx == -1 || sScore > maxScore {
			subSearchIdx = i
			resultDataIdx = sIdx
			maxScore = sScore
		} else if sScore == maxScore {
			if subSearchIdx == -1 {
				// A bad case happens where Knowhere returns distance/score == +/-maxFloat32
				// by mistake.
				// NOTE(review): this branch looks unreachable (the first condition
				// already claims subSearchIdx == -1); presumably kept as a defensive
				// guard — confirm before removing.
				log.Ctx(ctx).Error("a bad score is returned, something is wrong here!", zap.Float32("score", sScore))
			} else if typeutil.ComparePK(
				typeutil.GetPK(subSearchResultData[i].GetIds(), sIdx),
				typeutil.GetPK(subSearchResultData[subSearchIdx].GetIds(), resultDataIdx)) {
				// Tie on score: prefer the candidate whose PK compares smaller.
				subSearchIdx = i
				resultDataIdx = sIdx
				maxScore = sScore
			}
		}
	}
	return subSearchIdx, resultDataIdx
}
// TraceCtx returns the context that carries this task's trace information.
func (t *searchTask) TraceCtx() context.Context {
	return t.ctx
}

View File

@ -19,12 +19,12 @@ import (
"context"
"fmt"
"math"
"slices"
"strconv"
"strings"
"testing"
"time"
"github.com/bytedance/mockey"
"github.com/cockroachdb/errors"
"github.com/google/uuid"
"github.com/samber/lo"
@ -32,7 +32,6 @@ import (
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
"go.opentelemetry.io/otel/trace"
"google.golang.org/grpc"
"google.golang.org/protobuf/proto"
@ -394,12 +393,10 @@ func TestSearchTask_PostExecute(t *testing.T) {
f3 := testutils.GenerateScalarFieldData(schemapb.DataType_Int64, testInt64Field, 20)
f3.FieldId = fieldNameId[testInt64Field]
qt.requeryFunc = func(t *searchTask, span trace.Span, ids *schemapb.IDs, outputFields []string) (*milvuspb.QueryResults, error) {
return &milvuspb.QueryResults{
FieldsData: []*schemapb.FieldData{f1, f2, f3},
PrimaryFieldName: testInt64Field,
}, nil
}
mocker := mockey.Mock((*requeryOperator).requery).Return(&milvuspb.QueryResults{
FieldsData: []*schemapb.FieldData{f1, f2, f3},
}, nil).Build()
defer mocker.UnPatch()
err := qt.PostExecute(context.TODO())
assert.NoError(t, err)
@ -436,12 +433,10 @@ func TestSearchTask_PostExecute(t *testing.T) {
f3 := testutils.GenerateScalarFieldData(schemapb.DataType_Int64, testInt64Field, 20)
f3.FieldId = fieldNameId[testInt64Field]
qt.requeryFunc = func(t *searchTask, span trace.Span, ids *schemapb.IDs, outputFields []string) (*milvuspb.QueryResults, error) {
return &milvuspb.QueryResults{
FieldsData: []*schemapb.FieldData{f1, f2, f3},
PrimaryFieldName: testInt64Field,
}, nil
}
mocker := mockey.Mock((*requeryOperator).requery).Return(&milvuspb.QueryResults{
FieldsData: []*schemapb.FieldData{f1, f2, f3},
}, nil).Build()
defer mocker.UnPatch()
err := qt.PostExecute(context.TODO())
assert.NoError(t, err)
@ -478,12 +473,11 @@ func TestSearchTask_PostExecute(t *testing.T) {
f3.FieldId = fieldNameId[testInt64Field]
f4 := testutils.GenerateScalarFieldData(schemapb.DataType_Float, testFloatField, 20)
f4.FieldId = fieldNameId[testFloatField]
qt.requeryFunc = func(t *searchTask, span trace.Span, ids *schemapb.IDs, outputFields []string) (*milvuspb.QueryResults, error) {
return &milvuspb.QueryResults{
FieldsData: []*schemapb.FieldData{f1, f2, f3, f4},
PrimaryFieldName: testInt64Field,
}, nil
}
mocker := mockey.Mock((*requeryOperator).requery).Return(&milvuspb.QueryResults{
FieldsData: []*schemapb.FieldData{f1, f2, f3, f4},
}, nil).Build()
defer mocker.UnPatch()
err := qt.PostExecute(context.TODO())
assert.NoError(t, err)
assert.Equal(t, []int64{10, 10}, qt.result.Results.Topks)
@ -505,53 +499,6 @@ func TestSearchTask_PostExecute(t *testing.T) {
}
}
})
t.Run("Test mergeIDs function", func(t *testing.T) {
{
ids1 := &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: []int64{1, 2, 3, 5},
},
},
}
ids2 := &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: []int64{1, 2, 4, 5, 6},
},
},
}
allIDs, count := mergeIDs([]*schemapb.IDs{ids1, ids2})
assert.Equal(t, count, 6)
sortedIds := allIDs.GetIntId().GetData()
slices.Sort(sortedIds)
assert.Equal(t, sortedIds, []int64{1, 2, 3, 4, 5, 6})
}
{
ids1 := &schemapb.IDs{
IdField: &schemapb.IDs_StrId{
StrId: &schemapb.StringArray{
Data: []string{"a", "b", "e"},
},
},
}
ids2 := &schemapb.IDs{
IdField: &schemapb.IDs_StrId{
StrId: &schemapb.StringArray{
Data: []string{"a", "b", "c", "d"},
},
},
}
allIDs, count := mergeIDs([]*schemapb.IDs{ids1, ids2})
assert.Equal(t, count, 5)
sortedIds := allIDs.GetStrId().GetData()
slices.Sort(sortedIds)
assert.Equal(t, sortedIds, []string{"a", "b", "c", "d", "e"})
}
})
}
func createCollWithFields(t *testing.T, collName string, rc types.MixCoordClient) (*schemapb.CollectionSchema, map[string]int64) {
@ -3743,9 +3690,10 @@ func TestSearchTask_Requery(t *testing.T) {
tr: timerecord.NewTimeRecorder("search"),
node: node,
translatedOutputFields: outputFields,
requeryFunc: requeryImpl,
}
queryResult, err := qt.requeryFunc(qt, nil, qt.result.Results.Ids, outputFields)
op, err := newRequeryOperator(qt, nil)
assert.NoError(t, err)
queryResult, err := op.(*requeryOperator).requery(ctx, nil, qt.result.Results.Ids, outputFields)
assert.NoError(t, err)
assert.Len(t, queryResult.FieldsData, 2)
for _, field := range qt.result.Results.FieldsData {
@ -3768,14 +3716,13 @@ func TestSearchTask_Requery(t *testing.T) {
SourceID: paramtable.GetNodeID(),
},
},
request: &milvuspb.SearchRequest{},
schema: schema,
tr: timerecord.NewTimeRecorder("search"),
node: node,
requeryFunc: requeryImpl,
request: &milvuspb.SearchRequest{},
schema: schema,
tr: timerecord.NewTimeRecorder("search"),
node: node,
}
_, err := qt.requeryFunc(qt, nil, &schemapb.IDs{}, []string{})
_, err := newRequeryOperator(qt, nil)
t.Logf("err = %s", err)
assert.Error(t, err)
})
@ -3804,13 +3751,14 @@ func TestSearchTask_Requery(t *testing.T) {
request: &milvuspb.SearchRequest{
CollectionName: collectionName,
},
schema: schema,
tr: timerecord.NewTimeRecorder("search"),
node: node,
requeryFunc: requeryImpl,
schema: schema,
tr: timerecord.NewTimeRecorder("search"),
node: node,
}
_, err := qt.requeryFunc(qt, nil, &schemapb.IDs{}, []string{})
op, err := newRequeryOperator(qt, nil)
assert.NoError(t, err)
_, err = op.(*requeryOperator).requery(ctx, nil, &schemapb.IDs{}, []string{})
t.Logf("err = %s", err)
assert.Error(t, err)
})

View File

@ -34,7 +34,7 @@ const (
// The current version only supports plain text, and cipher text will be supported later.
type Credentials struct {
// key formats:
// {credentialName}.api_key
// {credentialName}.apikey
// {credentialName}.access_key_id
// {credentialName}.secret_access_key
// {credentialName}.credential_json

View File

@ -234,7 +234,7 @@ func (fScore *FunctionScore) Process(ctx context.Context, searchParams *SearchPa
FieldsData: make([]*schemapb.FieldData, 0),
Scores: []float32{},
Ids: &schemapb.IDs{},
Topks: []int64{},
Topks: make([]int64, searchParams.nq),
},
}, nil
}
@ -280,3 +280,10 @@ func (fScore *FunctionScore) IsSupportGroup() bool {
}
return fScore.reranker.IsSupportGroup()
}
// RerankName returns the name of the configured reranker, or an empty string
// when the FunctionScore itself is nil (safe to call on a nil receiver).
func (fScore *FunctionScore) RerankName() string {
	if fScore == nil {
		return ""
	}
	return fScore.reranker.GetRankName()
}

View File

@ -322,7 +322,7 @@ func (s *FunctionScoreSuite) TestlegacyFunction() {
rankParams := []*commonpb.KeyValuePair{}
f, err := NewFunctionScoreWithlegacy(schema, rankParams)
s.NoError(err)
s.Equal(f.reranker.GetRankName(), rrfName)
s.Equal(f.RerankName(), rrfName)
}
{
rankParams := []*commonpb.KeyValuePair{

View File

@ -289,8 +289,14 @@ class TestMilvusClientHybridSearchInvalid(TestMilvusClientV2Base):
collection_name = cf.gen_unique_str(prefix)
# 1. create collection
self.create_collection(client, collection_name, default_dim)
# 2. hybrid search
# 2. insert
rng = np.random.default_rng(seed=19530)
rows = [
{default_primary_key_field_name: i, default_vector_field_name: list(rng.random((1, default_dim))[0]),
default_vector_field_name+"new": list(rng.random((1, default_dim))[0]),
default_string_field_name: str(i)} for i in range(default_nb)]
self.insert(client, collection_name, rows)
# 2. hybrid search
vectors_to_search = rng.random((1, default_dim))
sub_search1 = AnnSearchRequest(vectors_to_search, "vector", {"level": 1}, 20, expr="id<100")
ranker = WeightedRanker(0.2, 0.8)