mirror of
https://gitee.com/milvus-io/milvus.git
synced 2026-01-07 19:31:51 +08:00
enhance: Add search post pipeline (#43065)
https://github.com/milvus-io/milvus/issues/35856 Signed-off-by: junjiejiangjjj <junjie.jiang@zilliz.com>
This commit is contained in:
parent
21e71f6eb2
commit
77f3a1f213
@ -3111,7 +3111,6 @@ func (node *Proxy) search(ctx context.Context, request *milvuspb.SearchRequest,
|
||||
lb: node.lbPolicy,
|
||||
enableMaterializedView: node.enableMaterializedView,
|
||||
mustUsePartitionKey: Params.ProxyCfg.MustUsePartitionKey.GetAsBool(),
|
||||
requeryFunc: requeryImpl,
|
||||
}
|
||||
|
||||
log := log.Ctx(ctx).With( // TODO: it might cause some cpu consumption
|
||||
@ -3354,7 +3353,6 @@ func (node *Proxy) hybridSearch(ctx context.Context, request *milvuspb.HybridSea
|
||||
node: node,
|
||||
lb: node.lbPolicy,
|
||||
mustUsePartitionKey: Params.ProxyCfg.MustUsePartitionKey.GetAsBool(),
|
||||
requeryFunc: requeryImpl,
|
||||
}
|
||||
|
||||
log := log.Ctx(ctx).With(
|
||||
|
||||
1017
internal/proxy/search_pipeline.go
Normal file
1017
internal/proxy/search_pipeline.go
Normal file
File diff suppressed because it is too large
Load Diff
754
internal/proxy/search_pipeline_test.go
Normal file
754
internal/proxy/search_pipeline_test.go
Normal file
@ -0,0 +1,754 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"slices"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/bytedance/mockey"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"go.opentelemetry.io/otel"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/util/function/rerank"
|
||||
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
|
||||
"github.com/milvus-io/milvus/pkg/v2/proto/planpb"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/testutils"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/timerecord"
|
||||
)
|
||||
|
||||
func TestSearchPipeline(t *testing.T) {
|
||||
suite.Run(t, new(SearchPipelineSuite))
|
||||
}
|
||||
|
||||
type SearchPipelineSuite struct {
|
||||
suite.Suite
|
||||
span trace.Span
|
||||
}
|
||||
|
||||
func (s *SearchPipelineSuite) SetupTest() {
|
||||
_, sp := otel.Tracer("test").Start(context.Background(), "Proxy-Search-PostExecute")
|
||||
s.span = sp
|
||||
}
|
||||
|
||||
func (s *SearchPipelineSuite) TearDownTest() {
|
||||
s.span.End()
|
||||
}
|
||||
|
||||
func (s *SearchPipelineSuite) TestSearchReduceOp() {
|
||||
nq := int64(2)
|
||||
topk := int64(10)
|
||||
pk := &schemapb.FieldSchema{
|
||||
FieldID: 101,
|
||||
Name: "pk",
|
||||
DataType: schemapb.DataType_Int64,
|
||||
IsPrimaryKey: true,
|
||||
AutoID: true,
|
||||
}
|
||||
data := genTestSearchResultData(nq, topk, schemapb.DataType_Int64, "intField", 102, false)
|
||||
op := searchReduceOperator{
|
||||
context.Background(),
|
||||
pk,
|
||||
nq,
|
||||
topk,
|
||||
0,
|
||||
1,
|
||||
[]int64{1},
|
||||
[]*planpb.QueryInfo{{}},
|
||||
}
|
||||
_, err := op.run(context.Background(), s.span, []*internalpb.SearchResults{data})
|
||||
s.NoError(err)
|
||||
}
|
||||
|
||||
func (s *SearchPipelineSuite) TestHybridSearchReduceOp() {
|
||||
nq := int64(2)
|
||||
topk := int64(10)
|
||||
pk := &schemapb.FieldSchema{
|
||||
FieldID: 101,
|
||||
Name: "pk",
|
||||
DataType: schemapb.DataType_Int64,
|
||||
IsPrimaryKey: true,
|
||||
AutoID: true,
|
||||
}
|
||||
data1 := genTestSearchResultData(nq, topk, schemapb.DataType_Int64, "intField", 102, true)
|
||||
data1.SubResults[0].ReqIndex = 0
|
||||
data2 := genTestSearchResultData(nq, topk, schemapb.DataType_Int64, "intField", 102, true)
|
||||
data2.SubResults[0].ReqIndex = 1
|
||||
|
||||
subReqs := []*internalpb.SubSearchRequest{
|
||||
{
|
||||
Nq: 2,
|
||||
Topk: 10,
|
||||
Offset: 0,
|
||||
},
|
||||
{
|
||||
Nq: 2,
|
||||
Topk: 10,
|
||||
Offset: 0,
|
||||
},
|
||||
}
|
||||
|
||||
op := hybridSearchReduceOperator{
|
||||
context.Background(),
|
||||
subReqs,
|
||||
pk,
|
||||
1,
|
||||
[]int64{1},
|
||||
[]*planpb.QueryInfo{{}, {}},
|
||||
}
|
||||
_, err := op.run(context.Background(), s.span, []*internalpb.SearchResults{data1, data2})
|
||||
s.NoError(err)
|
||||
}
|
||||
|
||||
func (s *SearchPipelineSuite) TestRerankOp() {
|
||||
schema := &schemapb.CollectionSchema{
|
||||
Name: "test",
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{FieldID: 100, Name: "pk", DataType: schemapb.DataType_Int64, IsPrimaryKey: true},
|
||||
{FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar},
|
||||
{
|
||||
FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{Key: "dim", Value: "4"},
|
||||
},
|
||||
},
|
||||
{FieldID: 102, Name: "ts", DataType: schemapb.DataType_Int64},
|
||||
},
|
||||
}
|
||||
functionSchema := &schemapb.FunctionSchema{
|
||||
Name: "test",
|
||||
Type: schemapb.FunctionType_Rerank,
|
||||
InputFieldNames: []string{"ts"},
|
||||
OutputFieldNames: []string{},
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: "reranker", Value: "decay"},
|
||||
{Key: "origin", Value: "4"},
|
||||
{Key: "scale", Value: "4"},
|
||||
{Key: "offset", Value: "4"},
|
||||
{Key: "decay", Value: "0.5"},
|
||||
{Key: "function", Value: "gauss"},
|
||||
},
|
||||
}
|
||||
funcScore, err := rerank.NewFunctionScore(schema, &schemapb.FunctionScore{
|
||||
Functions: []*schemapb.FunctionSchema{functionSchema},
|
||||
})
|
||||
s.NoError(err)
|
||||
|
||||
nq := int64(2)
|
||||
topk := int64(10)
|
||||
offset := int64(0)
|
||||
|
||||
reduceOp := searchReduceOperator{
|
||||
context.Background(),
|
||||
schema.Fields[0],
|
||||
nq,
|
||||
topk,
|
||||
offset,
|
||||
1,
|
||||
[]int64{1},
|
||||
[]*planpb.QueryInfo{{}},
|
||||
}
|
||||
|
||||
data := genTestSearchResultData(nq, topk, schemapb.DataType_Int64, "intField", 102, false)
|
||||
reduced, err := reduceOp.run(context.Background(), s.span, []*internalpb.SearchResults{data})
|
||||
s.NoError(err)
|
||||
|
||||
op := rerankOperator{
|
||||
nq: nq,
|
||||
topK: topk,
|
||||
offset: offset,
|
||||
roundDecimal: 10,
|
||||
functionScore: funcScore,
|
||||
}
|
||||
|
||||
_, err = op.run(context.Background(), s.span, reduced[0], []string{"IP"})
|
||||
s.NoError(err)
|
||||
}
|
||||
|
||||
func (s *SearchPipelineSuite) TestRequeryOp() {
|
||||
f1 := testutils.GenerateScalarFieldData(schemapb.DataType_Int64, "int64", 20)
|
||||
f1.FieldId = 101
|
||||
|
||||
mocker := mockey.Mock((*requeryOperator).requery).Return(&milvuspb.QueryResults{
|
||||
FieldsData: []*schemapb.FieldData{f1},
|
||||
}, nil).Build()
|
||||
defer mocker.UnPatch()
|
||||
|
||||
op := requeryOperator{
|
||||
traceCtx: context.Background(),
|
||||
outputFieldNames: []string{"int64"},
|
||||
}
|
||||
ids := &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
Data: []int64{1, 2},
|
||||
},
|
||||
},
|
||||
}
|
||||
_, err := op.run(context.Background(), s.span, ids, []string{"int64"})
|
||||
s.NoError(err)
|
||||
}
|
||||
|
||||
func (s *SearchPipelineSuite) TestOrganizeOp() {
|
||||
op := organizeOperator{
|
||||
traceCtx: context.Background(),
|
||||
primaryFieldSchema: &schemapb.FieldSchema{FieldID: 100, Name: "pk", DataType: schemapb.DataType_Int64, IsPrimaryKey: true},
|
||||
collectionID: 1,
|
||||
}
|
||||
fields := []*schemapb.FieldData{
|
||||
{
|
||||
Type: schemapb.DataType_Int64,
|
||||
FieldName: "pk",
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_LongData{
|
||||
LongData: &schemapb.LongArray{
|
||||
Data: []int64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}, {
|
||||
Type: schemapb.DataType_Int64,
|
||||
FieldName: "int64",
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_LongData{
|
||||
LongData: &schemapb.LongArray{
|
||||
Data: []int64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ids := []*schemapb.IDs{
|
||||
{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
Data: []int64{1, 4, 5, 9, 10},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
Data: []int64{5, 6, 7, 8, 9, 10},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
ret, err := op.run(context.Background(), s.span, fields, ids)
|
||||
s.NoError(err)
|
||||
fmt.Println(ret)
|
||||
}
|
||||
|
||||
func (s *SearchPipelineSuite) TestSearchPipeline() {
|
||||
collectionName := "test"
|
||||
task := &searchTask{
|
||||
ctx: context.Background(),
|
||||
collectionName: collectionName,
|
||||
SearchRequest: &internalpb.SearchRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_Search,
|
||||
Timestamp: uint64(time.Now().UnixNano()),
|
||||
},
|
||||
MetricType: "L2",
|
||||
Topk: 10,
|
||||
Nq: 2,
|
||||
PartitionIDs: []int64{1},
|
||||
CollectionID: 1,
|
||||
DbID: 1,
|
||||
},
|
||||
schema: &schemaInfo{
|
||||
CollectionSchema: &schemapb.CollectionSchema{
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64, IsPrimaryKey: true},
|
||||
{FieldID: 101, Name: "intField", DataType: schemapb.DataType_Int64},
|
||||
},
|
||||
},
|
||||
pkField: &schemapb.FieldSchema{FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64, IsPrimaryKey: true},
|
||||
},
|
||||
queryInfos: []*planpb.QueryInfo{{}},
|
||||
translatedOutputFields: []string{"intField"},
|
||||
}
|
||||
pipeline, err := newPipeline(searchPipe, task)
|
||||
s.NoError(err)
|
||||
results, err := pipeline.Run(context.Background(), s.span, []*internalpb.SearchResults{
|
||||
genTestSearchResultData(2, 10, schemapb.DataType_Int64, "intField", 101, false),
|
||||
})
|
||||
s.NoError(err)
|
||||
s.NotNil(results)
|
||||
s.NotNil(results.Results)
|
||||
s.Equal(int64(2), results.Results.NumQueries)
|
||||
s.Equal(int64(10), results.Results.Topks[0])
|
||||
s.Equal(int64(10), results.Results.Topks[1])
|
||||
s.NotNil(results.Results.Ids)
|
||||
s.NotNil(results.Results.Ids.GetIntId())
|
||||
s.Len(results.Results.Ids.GetIntId().Data, 20) // 2 queries * 10 topk
|
||||
s.NotNil(results.Results.Scores)
|
||||
s.Len(results.Results.Scores, 20) // 2 queries * 10 topk
|
||||
s.NotNil(results.Results.FieldsData)
|
||||
s.Len(results.Results.FieldsData, 1) // One output field
|
||||
s.Equal("intField", results.Results.FieldsData[0].FieldName)
|
||||
s.Equal(int64(101), results.Results.FieldsData[0].FieldId)
|
||||
fmt.Println(results)
|
||||
}
|
||||
|
||||
func (s *SearchPipelineSuite) TestSearchPipelineWithRequery() {
|
||||
collectionName := "test_collection"
|
||||
task := &searchTask{
|
||||
ctx: context.Background(),
|
||||
collectionName: collectionName,
|
||||
SearchRequest: &internalpb.SearchRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_Search,
|
||||
Timestamp: uint64(time.Now().UnixNano()),
|
||||
},
|
||||
MetricType: "L2",
|
||||
Topk: 10,
|
||||
Nq: 2,
|
||||
PartitionIDs: []int64{1},
|
||||
CollectionID: 1,
|
||||
DbID: 1,
|
||||
},
|
||||
schema: &schemaInfo{
|
||||
CollectionSchema: &schemapb.CollectionSchema{
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64, IsPrimaryKey: true},
|
||||
{FieldID: 101, Name: "intField", DataType: schemapb.DataType_Int64},
|
||||
},
|
||||
},
|
||||
pkField: &schemapb.FieldSchema{FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64, IsPrimaryKey: true},
|
||||
},
|
||||
queryInfos: []*planpb.QueryInfo{{}},
|
||||
translatedOutputFields: []string{"intField"},
|
||||
node: nil,
|
||||
}
|
||||
|
||||
// Mock requery operation
|
||||
f1 := testutils.GenerateScalarFieldData(schemapb.DataType_Int64, "intField", 20)
|
||||
f1.FieldId = 101
|
||||
f2 := testutils.GenerateScalarFieldData(schemapb.DataType_Int64, "int64", 20)
|
||||
f2.FieldId = 100
|
||||
mocker := mockey.Mock((*requeryOperator).requery).Return(&milvuspb.QueryResults{
|
||||
FieldsData: []*schemapb.FieldData{f1, f2},
|
||||
}, nil).Build()
|
||||
defer mocker.UnPatch()
|
||||
|
||||
pipeline, err := newPipeline(searchWithRequeryPipe, task)
|
||||
s.NoError(err)
|
||||
results, err := pipeline.Run(context.Background(), s.span, []*internalpb.SearchResults{
|
||||
genTestSearchResultData(2, 10, schemapb.DataType_Int64, "intField", 101, false),
|
||||
})
|
||||
s.NoError(err)
|
||||
s.NotNil(results)
|
||||
s.NotNil(results.Results)
|
||||
s.Equal(int64(2), results.Results.NumQueries)
|
||||
s.Equal(int64(10), results.Results.Topks[0])
|
||||
s.Equal(int64(10), results.Results.Topks[1])
|
||||
s.NotNil(results.Results.Ids)
|
||||
s.NotNil(results.Results.Ids.GetIntId())
|
||||
s.Len(results.Results.Ids.GetIntId().Data, 20) // 2 queries * 10 topk
|
||||
s.NotNil(results.Results.Scores)
|
||||
s.Len(results.Results.Scores, 20) // 2 queries * 10 topk
|
||||
s.NotNil(results.Results.FieldsData)
|
||||
s.Len(results.Results.FieldsData, 1) // One output field
|
||||
s.Equal("intField", results.Results.FieldsData[0].FieldName)
|
||||
s.Equal(int64(101), results.Results.FieldsData[0].FieldId)
|
||||
}
|
||||
|
||||
func (s *SearchPipelineSuite) TestSearchWithRerankPipe() {
|
||||
functionSchema := &schemapb.FunctionSchema{
|
||||
Name: "test",
|
||||
Type: schemapb.FunctionType_Rerank,
|
||||
InputFieldNames: []string{"intField"},
|
||||
OutputFieldNames: []string{},
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: "reranker", Value: "decay"},
|
||||
{Key: "origin", Value: "4"},
|
||||
{Key: "scale", Value: "4"},
|
||||
{Key: "offset", Value: "4"},
|
||||
{Key: "decay", Value: "0.5"},
|
||||
{Key: "function", Value: "gauss"},
|
||||
},
|
||||
}
|
||||
schema := &schemapb.CollectionSchema{
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64, IsPrimaryKey: true},
|
||||
{FieldID: 101, Name: "intField", DataType: schemapb.DataType_Int64},
|
||||
},
|
||||
}
|
||||
funcScore, err := rerank.NewFunctionScore(schema, &schemapb.FunctionScore{
|
||||
Functions: []*schemapb.FunctionSchema{functionSchema},
|
||||
})
|
||||
s.NoError(err)
|
||||
|
||||
task := &searchTask{
|
||||
ctx: context.Background(),
|
||||
collectionName: "test_collection",
|
||||
SearchRequest: &internalpb.SearchRequest{
|
||||
MetricType: "L2",
|
||||
Topk: 10,
|
||||
Nq: 2,
|
||||
PartitionIDs: []int64{1},
|
||||
CollectionID: 1,
|
||||
DbID: 1,
|
||||
},
|
||||
schema: &schemaInfo{
|
||||
CollectionSchema: schema,
|
||||
pkField: &schemapb.FieldSchema{FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64, IsPrimaryKey: true},
|
||||
},
|
||||
queryInfos: []*planpb.QueryInfo{{}},
|
||||
translatedOutputFields: []string{"intField"},
|
||||
node: nil,
|
||||
functionScore: funcScore,
|
||||
}
|
||||
|
||||
pipeline, err := newPipeline(searchWithRerankPipe, task)
|
||||
s.NoError(err)
|
||||
|
||||
searchResults := genTestSearchResultData(2, 10, schemapb.DataType_Int64, "intField", 101, false)
|
||||
results, err := pipeline.Run(context.Background(), s.span, []*internalpb.SearchResults{searchResults})
|
||||
|
||||
s.NoError(err)
|
||||
s.NotNil(results)
|
||||
s.NotNil(results.Results)
|
||||
s.Equal(int64(2), results.Results.NumQueries)
|
||||
s.Equal(int64(10), results.Results.Topks[0])
|
||||
s.Equal(int64(10), results.Results.Topks[1])
|
||||
s.NotNil(results.Results.Ids)
|
||||
s.NotNil(results.Results.Ids.GetIntId())
|
||||
s.Len(results.Results.Ids.GetIntId().Data, 20) // 2 queries * 10 topk
|
||||
s.NotNil(results.Results.Scores)
|
||||
s.Len(results.Results.Scores, 20) // 2 queries * 10 topk
|
||||
s.NotNil(results.Results.FieldsData)
|
||||
s.Len(results.Results.FieldsData, 1) // One output field
|
||||
s.Equal("intField", results.Results.FieldsData[0].FieldName)
|
||||
s.Equal(int64(101), results.Results.FieldsData[0].FieldId)
|
||||
}
|
||||
|
||||
func (s *SearchPipelineSuite) TestSearchWithRerankRequeryPipe() {
|
||||
functionSchema := &schemapb.FunctionSchema{
|
||||
Name: "test",
|
||||
Type: schemapb.FunctionType_Rerank,
|
||||
InputFieldNames: []string{"intField"},
|
||||
OutputFieldNames: []string{},
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: "reranker", Value: "decay"},
|
||||
{Key: "origin", Value: "4"},
|
||||
{Key: "scale", Value: "4"},
|
||||
{Key: "offset", Value: "4"},
|
||||
{Key: "decay", Value: "0.5"},
|
||||
{Key: "function", Value: "gauss"},
|
||||
},
|
||||
}
|
||||
schema := &schemapb.CollectionSchema{
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64, IsPrimaryKey: true},
|
||||
{FieldID: 101, Name: "intField", DataType: schemapb.DataType_Int64},
|
||||
},
|
||||
}
|
||||
funcScore, err := rerank.NewFunctionScore(schema, &schemapb.FunctionScore{
|
||||
Functions: []*schemapb.FunctionSchema{functionSchema},
|
||||
})
|
||||
s.NoError(err)
|
||||
|
||||
task := &searchTask{
|
||||
ctx: context.Background(),
|
||||
collectionName: "test_collection",
|
||||
SearchRequest: &internalpb.SearchRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_Search,
|
||||
Timestamp: uint64(time.Now().UnixNano()),
|
||||
},
|
||||
MetricType: "L2",
|
||||
Topk: 10,
|
||||
Nq: 2,
|
||||
PartitionIDs: []int64{1},
|
||||
CollectionID: 1,
|
||||
DbID: 1,
|
||||
},
|
||||
schema: &schemaInfo{
|
||||
CollectionSchema: schema,
|
||||
pkField: &schemapb.FieldSchema{FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64, IsPrimaryKey: true},
|
||||
},
|
||||
queryInfos: []*planpb.QueryInfo{{}},
|
||||
translatedOutputFields: []string{"intField"},
|
||||
node: nil,
|
||||
functionScore: funcScore,
|
||||
}
|
||||
f1 := testutils.GenerateScalarFieldData(schemapb.DataType_Int64, "intField", 20)
|
||||
f1.FieldId = 101
|
||||
f2 := testutils.GenerateScalarFieldData(schemapb.DataType_Int64, "int64", 20)
|
||||
f2.FieldId = 100
|
||||
mocker := mockey.Mock((*requeryOperator).requery).Return(&milvuspb.QueryResults{
|
||||
FieldsData: []*schemapb.FieldData{f1, f2},
|
||||
}, nil).Build()
|
||||
defer mocker.UnPatch()
|
||||
|
||||
pipeline, err := newPipeline(searchWithRerankRequeryPipe, task)
|
||||
s.NoError(err)
|
||||
|
||||
searchResults := genTestSearchResultData(2, 10, schemapb.DataType_Int64, "intField", 101, false)
|
||||
results, err := pipeline.Run(context.Background(), s.span, []*internalpb.SearchResults{searchResults})
|
||||
|
||||
s.NoError(err)
|
||||
s.NotNil(results)
|
||||
s.NotNil(results.Results)
|
||||
s.Equal(int64(2), results.Results.NumQueries)
|
||||
s.Equal(int64(10), results.Results.Topks[0])
|
||||
s.Equal(int64(10), results.Results.Topks[1])
|
||||
s.NotNil(results.Results.Ids)
|
||||
s.NotNil(results.Results.Ids.GetIntId())
|
||||
s.Len(results.Results.Ids.GetIntId().Data, 20) // 2 queries * 10 topk
|
||||
s.NotNil(results.Results.Scores)
|
||||
s.Len(results.Results.Scores, 20) // 2 queries * 10 topk
|
||||
s.NotNil(results.Results.FieldsData)
|
||||
s.Len(results.Results.FieldsData, 1) // One output field
|
||||
s.Equal("intField", results.Results.FieldsData[0].FieldName)
|
||||
s.Equal(int64(101), results.Results.FieldsData[0].FieldId)
|
||||
}
|
||||
|
||||
func (s *SearchPipelineSuite) TestHybridSearchPipe() {
|
||||
task := getHybridSearchTask("test_collection", [][]string{
|
||||
{"1", "2"},
|
||||
{"3", "4"},
|
||||
},
|
||||
[]string{},
|
||||
)
|
||||
|
||||
pipeline, err := newPipeline(hybridSearchPipe, task)
|
||||
s.NoError(err)
|
||||
|
||||
f1 := genTestSearchResultData(2, 10, schemapb.DataType_Int64, "intField", 101, true)
|
||||
f2 := genTestSearchResultData(2, 10, schemapb.DataType_Int64, "intField", 101, true)
|
||||
results, err := pipeline.Run(context.Background(), s.span, []*internalpb.SearchResults{f1, f2})
|
||||
|
||||
s.NoError(err)
|
||||
s.NotNil(results)
|
||||
s.NotNil(results.Results)
|
||||
s.Equal(int64(2), results.Results.NumQueries)
|
||||
s.Equal(int64(10), results.Results.Topks[0])
|
||||
s.Equal(int64(10), results.Results.Topks[1])
|
||||
s.NotNil(results.Results.Ids)
|
||||
s.NotNil(results.Results.Ids.GetIntId())
|
||||
s.Len(results.Results.Ids.GetIntId().Data, 20) // 2 queries * 10 topk
|
||||
s.NotNil(results.Results.Scores)
|
||||
s.Len(results.Results.Scores, 20) // 2 queries * 10 topk
|
||||
}
|
||||
|
||||
func (s *SearchPipelineSuite) TestHybridSearchWithRequeryPipe() {
|
||||
task := getHybridSearchTask("test_collection", [][]string{
|
||||
{"1", "2"},
|
||||
{"3", "4"},
|
||||
},
|
||||
[]string{"intField"},
|
||||
)
|
||||
|
||||
f1 := testutils.GenerateScalarFieldData(schemapb.DataType_Int64, "intField", 20)
|
||||
f1.FieldId = 101
|
||||
f2 := testutils.GenerateScalarFieldData(schemapb.DataType_Int64, "int64", 20)
|
||||
f2.FieldId = 100
|
||||
mocker := mockey.Mock((*requeryOperator).requery).Return(&milvuspb.QueryResults{
|
||||
FieldsData: []*schemapb.FieldData{f1, f2},
|
||||
}, nil).Build()
|
||||
defer mocker.UnPatch()
|
||||
|
||||
pipeline, err := newPipeline(hybridSearchWithRequeryPipe, task)
|
||||
s.NoError(err)
|
||||
|
||||
d1 := genTestSearchResultData(2, 10, schemapb.DataType_Int64, "intField", 101, true)
|
||||
d2 := genTestSearchResultData(2, 10, schemapb.DataType_Int64, "intField", 101, true)
|
||||
results, err := pipeline.Run(context.Background(), s.span, []*internalpb.SearchResults{d1, d2})
|
||||
|
||||
s.NoError(err)
|
||||
s.NotNil(results)
|
||||
s.NotNil(results.Results)
|
||||
s.Equal(int64(2), results.Results.NumQueries)
|
||||
s.Equal(int64(10), results.Results.Topks[0])
|
||||
s.Equal(int64(10), results.Results.Topks[1])
|
||||
s.NotNil(results.Results.Ids)
|
||||
s.NotNil(results.Results.Ids.GetIntId())
|
||||
s.Len(results.Results.Ids.GetIntId().Data, 20) // 2 queries * 10 topk
|
||||
s.NotNil(results.Results.Scores)
|
||||
s.Len(results.Results.Scores, 20) // 2 queries * 10 topk
|
||||
s.NotNil(results.Results.FieldsData)
|
||||
s.Len(results.Results.FieldsData, 1) // One output field
|
||||
s.Equal("intField", results.Results.FieldsData[0].FieldName)
|
||||
s.Equal(int64(101), results.Results.FieldsData[0].FieldId)
|
||||
}
|
||||
|
||||
func getHybridSearchTask(collName string, data [][]string, outputFields []string) *searchTask {
|
||||
subReqs := []*milvuspb.SubSearchRequest{}
|
||||
for _, item := range data {
|
||||
subReq := &milvuspb.SubSearchRequest{
|
||||
SearchParams: []*commonpb.KeyValuePair{
|
||||
{Key: TopKKey, Value: "10"},
|
||||
},
|
||||
Nq: int64(len(item)),
|
||||
}
|
||||
subReqs = append(subReqs, subReq)
|
||||
}
|
||||
functionSchema := &schemapb.FunctionSchema{
|
||||
Name: "test",
|
||||
Type: schemapb.FunctionType_Rerank,
|
||||
InputFieldNames: []string{},
|
||||
OutputFieldNames: []string{},
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: "reranker", Value: "rrf"},
|
||||
},
|
||||
}
|
||||
|
||||
schema := &schemapb.CollectionSchema{
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64, IsPrimaryKey: true},
|
||||
{FieldID: 101, Name: "intField", DataType: schemapb.DataType_Int64},
|
||||
},
|
||||
}
|
||||
funcScore, _ := rerank.NewFunctionScore(schema, &schemapb.FunctionScore{
|
||||
Functions: []*schemapb.FunctionSchema{functionSchema},
|
||||
})
|
||||
task := &searchTask{
|
||||
ctx: context.Background(),
|
||||
collectionName: collName,
|
||||
SearchRequest: &internalpb.SearchRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_Search,
|
||||
Timestamp: uint64(time.Now().UnixNano()),
|
||||
},
|
||||
Topk: 10,
|
||||
Nq: 2,
|
||||
IsAdvanced: true,
|
||||
SubReqs: []*internalpb.SubSearchRequest{
|
||||
{
|
||||
Topk: 10,
|
||||
Nq: 2,
|
||||
},
|
||||
{
|
||||
Topk: 10,
|
||||
Nq: 2,
|
||||
},
|
||||
},
|
||||
},
|
||||
request: &milvuspb.SearchRequest{
|
||||
CollectionName: collName,
|
||||
SubReqs: subReqs,
|
||||
SearchParams: []*commonpb.KeyValuePair{
|
||||
{Key: LimitKey, Value: "10"},
|
||||
},
|
||||
FunctionScore: &schemapb.FunctionScore{
|
||||
Functions: []*schemapb.FunctionSchema{functionSchema},
|
||||
},
|
||||
OutputFields: outputFields,
|
||||
},
|
||||
schema: &schemaInfo{
|
||||
CollectionSchema: schema,
|
||||
pkField: &schemapb.FieldSchema{FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64, IsPrimaryKey: true},
|
||||
},
|
||||
mixCoord: nil,
|
||||
tr: timerecord.NewTimeRecorder("test-search"),
|
||||
rankParams: &rankParams{
|
||||
limit: 10,
|
||||
offset: 0,
|
||||
roundDecimal: 0,
|
||||
},
|
||||
queryInfos: []*planpb.QueryInfo{{}, {}},
|
||||
functionScore: funcScore,
|
||||
translatedOutputFields: outputFields,
|
||||
}
|
||||
return task
|
||||
}
|
||||
|
||||
func (s *SearchPipelineSuite) TestMergeIDsFunc() {
|
||||
{
|
||||
ids1 := &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
Data: []int64{1, 2, 3, 5},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ids2 := &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
Data: []int64{1, 2, 4, 5, 6},
|
||||
},
|
||||
},
|
||||
}
|
||||
rets := []*milvuspb.SearchResults{
|
||||
{
|
||||
Results: &schemapb.SearchResultData{
|
||||
Ids: ids1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Results: &schemapb.SearchResultData{
|
||||
Ids: ids2,
|
||||
},
|
||||
},
|
||||
}
|
||||
allIDs, err := mergeIDsFunc(context.Background(), s.span, rets)
|
||||
s.NoError(err)
|
||||
sortedIds := allIDs[0].(*schemapb.IDs).GetIntId().GetData()
|
||||
slices.Sort(sortedIds)
|
||||
s.Equal(sortedIds, []int64{1, 2, 3, 4, 5, 6})
|
||||
}
|
||||
{
|
||||
ids1 := &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_StrId{
|
||||
StrId: &schemapb.StringArray{
|
||||
Data: []string{"a", "b", "e"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ids2 := &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_StrId{
|
||||
StrId: &schemapb.StringArray{
|
||||
Data: []string{"a", "b", "c", "d"},
|
||||
},
|
||||
},
|
||||
}
|
||||
rets := []*milvuspb.SearchResults{
|
||||
{
|
||||
Results: &schemapb.SearchResultData{
|
||||
Ids: ids1,
|
||||
},
|
||||
},
|
||||
}
|
||||
rets = append(rets, &milvuspb.SearchResults{
|
||||
Results: &schemapb.SearchResultData{
|
||||
Ids: ids2,
|
||||
},
|
||||
})
|
||||
allIDs, err := mergeIDsFunc(context.Background(), s.span, rets)
|
||||
s.NoError(err)
|
||||
sortedIds := allIDs[0].(*schemapb.IDs).GetStrId().GetData()
|
||||
slices.Sort(sortedIds)
|
||||
s.Equal(sortedIds, []string{"a", "b", "c", "d", "e"})
|
||||
}
|
||||
}
|
||||
@ -5,12 +5,16 @@ import (
|
||||
"fmt"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
"go.opentelemetry.io/otel"
|
||||
"go.uber.org/zap"
|
||||
"google.golang.org/protobuf/proto"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/util/reduce"
|
||||
"github.com/milvus-io/milvus/pkg/v2/log"
|
||||
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
|
||||
"github.com/milvus-io/milvus/pkg/v2/proto/planpb"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/metric"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
|
||||
@ -470,3 +474,105 @@ func fillInEmptyResult(numQueries int64) *milvuspb.SearchResults {
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func reduceResults(ctx context.Context, toReduceResults []*internalpb.SearchResults, nq, topK, offset int64, metricType string, pkType schemapb.DataType, queryInfo *planpb.QueryInfo, isAdvance bool, collectionID int64, partitionIDs []int64) (*milvuspb.SearchResults, error) {
|
||||
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "reduceResults")
|
||||
defer sp.End()
|
||||
|
||||
log := log.Ctx(ctx)
|
||||
// Decode all search results
|
||||
validSearchResults, err := decodeSearchResults(ctx, toReduceResults)
|
||||
if err != nil {
|
||||
log.Warn("failed to decode search results", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(validSearchResults) <= 0 {
|
||||
return fillInEmptyResult(nq), nil
|
||||
}
|
||||
|
||||
// Reduce all search results
|
||||
log.Debug("proxy search post execute reduce",
|
||||
zap.Int64("collection", collectionID),
|
||||
zap.Int64s("partitionIDs", partitionIDs),
|
||||
zap.Int("number of valid search results", len(validSearchResults)))
|
||||
var result *milvuspb.SearchResults
|
||||
result, err = reduceSearchResult(ctx, validSearchResults, reduce.NewReduceSearchResultInfo(nq, topK).WithMetricType(metricType).WithPkType(pkType).
|
||||
WithOffset(offset).WithGroupByField(queryInfo.GetGroupByFieldId()).WithGroupSize(queryInfo.GetGroupSize()).WithAdvance(isAdvance))
|
||||
if err != nil {
|
||||
log.Warn("failed to reduce search results", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func decodeSearchResults(ctx context.Context, searchResults []*internalpb.SearchResults) ([]*schemapb.SearchResultData, error) {
|
||||
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "decodeSearchResults")
|
||||
defer sp.End()
|
||||
tr := timerecord.NewTimeRecorder("decodeSearchResults")
|
||||
results := make([]*schemapb.SearchResultData, 0)
|
||||
for _, partialSearchResult := range searchResults {
|
||||
if partialSearchResult.SlicedBlob == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
var partialResultData schemapb.SearchResultData
|
||||
err := proto.Unmarshal(partialSearchResult.SlicedBlob, &partialResultData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
results = append(results, &partialResultData)
|
||||
}
|
||||
tr.CtxElapse(ctx, "decodeSearchResults done")
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func checkSearchResultData(data *schemapb.SearchResultData, nq int64, topk int64, pkHitNum int) error {
|
||||
if data.NumQueries != nq {
|
||||
return fmt.Errorf("search result's nq(%d) mis-match with %d", data.NumQueries, nq)
|
||||
}
|
||||
if data.TopK != topk {
|
||||
return fmt.Errorf("search result's topk(%d) mis-match with %d", data.TopK, topk)
|
||||
}
|
||||
|
||||
if len(data.Scores) != pkHitNum {
|
||||
return fmt.Errorf("search result's score length invalid, score length=%d, expectedLength=%d",
|
||||
len(data.Scores), pkHitNum)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func selectHighestScoreIndex(ctx context.Context, subSearchResultData []*schemapb.SearchResultData, subSearchNqOffset [][]int64, cursors []int64, qi int64) (int, int64) {
|
||||
var (
|
||||
subSearchIdx = -1
|
||||
resultDataIdx int64 = -1
|
||||
)
|
||||
maxScore := minFloat32
|
||||
for i := range cursors {
|
||||
if cursors[i] >= subSearchResultData[i].Topks[qi] {
|
||||
continue
|
||||
}
|
||||
sIdx := subSearchNqOffset[i][qi] + cursors[i]
|
||||
sScore := subSearchResultData[i].Scores[sIdx]
|
||||
|
||||
// Choose the larger score idx or the smaller pk idx with the same score
|
||||
if subSearchIdx == -1 || sScore > maxScore {
|
||||
subSearchIdx = i
|
||||
resultDataIdx = sIdx
|
||||
maxScore = sScore
|
||||
} else if sScore == maxScore {
|
||||
if subSearchIdx == -1 {
|
||||
// A bad case happens where Knowhere returns distance/score == +/-maxFloat32
|
||||
// by mistake.
|
||||
log.Ctx(ctx).Error("a bad score is returned, something is wrong here!", zap.Float32("score", sScore))
|
||||
} else if typeutil.ComparePK(
|
||||
typeutil.GetPK(subSearchResultData[i].GetIds(), sIdx),
|
||||
typeutil.GetPK(subSearchResultData[subSearchIdx].GetIds(), resultDataIdx)) {
|
||||
subSearchIdx = i
|
||||
resultDataIdx = sIdx
|
||||
maxScore = sScore
|
||||
}
|
||||
}
|
||||
}
|
||||
return subSearchIdx, resultDataIdx
|
||||
}
|
||||
|
||||
@ -15,6 +15,7 @@ import (
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/pkg/v2/common"
|
||||
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
|
||||
"github.com/milvus-io/milvus/pkg/v2/proto/planpb"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/merr"
|
||||
@ -577,3 +578,11 @@ func convertHybridSearchToSearch(req *milvuspb.HybridSearchRequest) *milvuspb.Se
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func getMetricType(toReduceResults []*internalpb.SearchResults) string {
|
||||
metricType := ""
|
||||
if len(toReduceResults) >= 1 {
|
||||
metricType = toReduceResults[0].GetMetricType()
|
||||
}
|
||||
return metricType
|
||||
}
|
||||
|
||||
@ -10,7 +10,6 @@ import (
|
||||
"github.com/cockroachdb/errors"
|
||||
"github.com/samber/lo"
|
||||
"go.opentelemetry.io/otel"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"go.uber.org/zap"
|
||||
"google.golang.org/protobuf/proto"
|
||||
|
||||
@ -22,7 +21,6 @@ import (
|
||||
"github.com/milvus-io/milvus/internal/util/exprutil"
|
||||
"github.com/milvus-io/milvus/internal/util/function"
|
||||
"github.com/milvus-io/milvus/internal/util/function/rerank"
|
||||
"github.com/milvus-io/milvus/internal/util/reduce"
|
||||
"github.com/milvus-io/milvus/pkg/v2/log"
|
||||
"github.com/milvus-io/milvus/pkg/v2/metrics"
|
||||
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
|
||||
@ -96,9 +94,6 @@ type searchTask struct {
|
||||
// we always remove pk field from output fields, as search result already contains pk field.
|
||||
// if the user explicitly set pk field in output fields, we add it back to the result.
|
||||
userRequestedPkFieldExplicitly bool
|
||||
|
||||
// To facilitate writing unit tests
|
||||
requeryFunc func(t *searchTask, span trace.Span, ids *schemapb.IDs, outputFields []string) (*milvuspb.QueryResults, error)
|
||||
}
|
||||
|
||||
func (t *searchTask) CanSkipAllocTimestamp() bool {
|
||||
@ -488,7 +483,6 @@ func (t *searchTask) initAdvancedSearchRequest(ctx context.Context) error {
|
||||
t.SearchRequest.GroupByFieldId = t.rankParams.GetGroupByFieldId()
|
||||
t.SearchRequest.GroupSize = t.rankParams.GetGroupSize()
|
||||
|
||||
// used for requery
|
||||
if t.partitionKeyMode {
|
||||
t.SearchRequest.PartitionIDs = t.partitionIDsSet.Collect()
|
||||
}
|
||||
@ -496,57 +490,6 @@ func (t *searchTask) initAdvancedSearchRequest(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *searchTask) advancedPostProcess(ctx context.Context, span trace.Span, toReduceResults []*internalpb.SearchResults) error {
|
||||
// Collecting the results of a subsearch
|
||||
// [[shard1, shard2, ...],[shard1, shard2, ...]]
|
||||
multipleInternalResults := make([][]*internalpb.SearchResults, len(t.SearchRequest.GetSubReqs()))
|
||||
for _, searchResult := range toReduceResults {
|
||||
// if get a non-advanced result, skip all
|
||||
if !searchResult.GetIsAdvanced() {
|
||||
continue
|
||||
}
|
||||
for _, subResult := range searchResult.GetSubResults() {
|
||||
// swallow copy
|
||||
internalResults := &internalpb.SearchResults{
|
||||
MetricType: subResult.GetMetricType(),
|
||||
NumQueries: subResult.GetNumQueries(),
|
||||
TopK: subResult.GetTopK(),
|
||||
SlicedBlob: subResult.GetSlicedBlob(),
|
||||
SlicedNumCount: subResult.GetSlicedNumCount(),
|
||||
SlicedOffset: subResult.GetSlicedOffset(),
|
||||
IsAdvanced: false,
|
||||
}
|
||||
reqIndex := subResult.GetReqIndex()
|
||||
multipleInternalResults[reqIndex] = append(multipleInternalResults[reqIndex], internalResults)
|
||||
}
|
||||
}
|
||||
|
||||
multipleMilvusResults := make([]*milvuspb.SearchResults, len(t.SearchRequest.GetSubReqs()))
|
||||
searchMetrics := []string{}
|
||||
for index, internalResults := range multipleInternalResults {
|
||||
subReq := t.SearchRequest.GetSubReqs()[index]
|
||||
// Since the metrictype in the request may be empty, it can only be obtained from the result
|
||||
subMetricType := getMetricType(internalResults)
|
||||
result, err := t.reduceResults(t.ctx, internalResults, subReq.GetNq(), subReq.GetTopk(), subReq.GetOffset(), subMetricType, t.queryInfos[index], true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
searchMetrics = append(searchMetrics, subMetricType)
|
||||
multipleMilvusResults[index] = result
|
||||
}
|
||||
|
||||
if err := t.hybridSearchRank(ctx, span, multipleMilvusResults, searchMetrics); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
t.result.Results.FieldsData = lo.Filter(t.result.Results.FieldsData, func(fieldData *schemapb.FieldData, i int) bool {
|
||||
return lo.Contains(t.translatedOutputFields, fieldData.GetFieldName())
|
||||
})
|
||||
t.fillResult()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *searchTask) fillResult() {
|
||||
limit := t.SearchRequest.GetTopk() - t.SearchRequest.GetOffset()
|
||||
resultSizeInsufficient := false
|
||||
@ -558,111 +501,6 @@ func (t *searchTask) fillResult() {
|
||||
}
|
||||
t.resultSizeInsufficient = resultSizeInsufficient
|
||||
t.result.CollectionName = t.collectionName
|
||||
t.fillInFieldInfo()
|
||||
}
|
||||
|
||||
func mergeIDs(idsList []*schemapb.IDs) (*schemapb.IDs, int) {
|
||||
uniqueIDs := &schemapb.IDs{}
|
||||
int64IDs := typeutil.NewSet[int64]()
|
||||
strIDs := typeutil.NewSet[string]()
|
||||
|
||||
for _, ids := range idsList {
|
||||
if ids == nil {
|
||||
continue
|
||||
}
|
||||
switch ids.GetIdField().(type) {
|
||||
case *schemapb.IDs_IntId:
|
||||
int64IDs.Insert(ids.GetIntId().GetData()...)
|
||||
case *schemapb.IDs_StrId:
|
||||
strIDs.Insert(ids.GetStrId().GetData()...)
|
||||
}
|
||||
}
|
||||
|
||||
if int64IDs.Len() > 0 {
|
||||
uniqueIDs.IdField = &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
Data: int64IDs.Collect(),
|
||||
},
|
||||
}
|
||||
return uniqueIDs, int64IDs.Len()
|
||||
}
|
||||
|
||||
if strIDs.Len() > 0 {
|
||||
uniqueIDs.IdField = &schemapb.IDs_StrId{
|
||||
StrId: &schemapb.StringArray{
|
||||
Data: strIDs.Collect(),
|
||||
},
|
||||
}
|
||||
return uniqueIDs, strIDs.Len()
|
||||
}
|
||||
|
||||
return nil, 0
|
||||
}
|
||||
|
||||
func (t *searchTask) hybridSearchRank(ctx context.Context, span trace.Span, multipleMilvusResults []*milvuspb.SearchResults, searchMetrics []string) error {
|
||||
var err error
|
||||
processRerank := func(ctx context.Context, results []*milvuspb.SearchResults) (*milvuspb.SearchResults, error) {
|
||||
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-call-rerank-function-udf")
|
||||
defer sp.End()
|
||||
groupScorerStr := getGroupScorerStr(t.request.GetSearchParams())
|
||||
params := rerank.NewSearchParams(
|
||||
t.Nq, t.rankParams.limit, t.rankParams.offset, t.rankParams.roundDecimal,
|
||||
t.rankParams.groupByFieldId, t.rankParams.groupSize, t.rankParams.strictGroupSize, groupScorerStr, searchMetrics,
|
||||
)
|
||||
return t.functionScore.Process(ctx, params, results)
|
||||
}
|
||||
|
||||
// The first step of hybrid search is without meta information. If rerank requires meta data, we need to do requery.
|
||||
// At this time, outputFields and rerank input_fields will be recalled.
|
||||
// If we want to save memory, we can only recall the rerank input_fields in this step, and recall the output_fields in the third step
|
||||
if t.needRequery {
|
||||
idsList := lo.FilterMap(multipleMilvusResults, func(m *milvuspb.SearchResults, _ int) (*schemapb.IDs, bool) {
|
||||
return m.Results.Ids, true
|
||||
})
|
||||
allIDs, count := mergeIDs(idsList)
|
||||
if count == 0 {
|
||||
t.result = &milvuspb.SearchResults{
|
||||
Status: merr.Success(),
|
||||
Results: &schemapb.SearchResultData{
|
||||
NumQueries: t.Nq,
|
||||
TopK: t.rankParams.limit,
|
||||
FieldsData: make([]*schemapb.FieldData, 0),
|
||||
Scores: []float32{},
|
||||
Ids: &schemapb.IDs{},
|
||||
Topks: []int64{},
|
||||
},
|
||||
}
|
||||
return nil
|
||||
}
|
||||
allNames := typeutil.NewSet[string](t.translatedOutputFields...)
|
||||
allNames.Insert(t.functionScore.GetAllInputFieldNames()...)
|
||||
queryResult, err := t.requeryFunc(t, span, allIDs, allNames.Collect())
|
||||
if err != nil {
|
||||
log.Warn("failed to requery", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
fields, err := t.reorganizeRequeryResults(ctx, queryResult.GetFieldsData(), idsList)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for i := 0; i < len(multipleMilvusResults); i++ {
|
||||
multipleMilvusResults[i].Results.FieldsData = fields[i]
|
||||
}
|
||||
|
||||
if t.result, err = processRerank(ctx, multipleMilvusResults); err != nil {
|
||||
return err
|
||||
}
|
||||
if fields, err := t.reorganizeRequeryResults(ctx, queryResult.GetFieldsData(), []*schemapb.IDs{t.result.Results.Ids}); err != nil {
|
||||
return err
|
||||
} else {
|
||||
t.result.Results.FieldsData = fields[0]
|
||||
}
|
||||
} else {
|
||||
if t.result, err = processRerank(ctx, multipleMilvusResults); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *searchTask) initSearchRequest(ctx context.Context) error {
|
||||
@ -677,19 +515,13 @@ func (t *searchTask) initSearchRequest(ctx context.Context) error {
|
||||
}
|
||||
|
||||
if t.request.FunctionScore != nil {
|
||||
// TODO: When rerank is configured, range search is also supported
|
||||
if isIterator {
|
||||
return merr.WrapErrParameterInvalidMsg("Range search do not support rerank")
|
||||
}
|
||||
|
||||
if t.functionScore, err = rerank.NewFunctionScore(t.schema.CollectionSchema, t.request.FunctionScore); err != nil {
|
||||
log.Warn("Failed to create function score", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO: When rerank is configured, grouping search is also supported
|
||||
if !t.functionScore.IsSupportGroup() && queryInfo.GetGroupByFieldId() > 0 {
|
||||
return merr.WrapErrParameterInvalidMsg("Current rerank does not support grouping search")
|
||||
return merr.WrapErrParameterInvalidMsg("Rerank %s does not support grouping search", t.functionScore.RerankName())
|
||||
}
|
||||
}
|
||||
|
||||
@ -773,57 +605,6 @@ func (t *searchTask) initSearchRequest(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *searchTask) searchPostProcess(ctx context.Context, span trace.Span, toReduceResults []*internalpb.SearchResults) error {
|
||||
metricType := getMetricType(toReduceResults)
|
||||
result, err := t.reduceResults(t.ctx, toReduceResults, t.SearchRequest.GetNq(), t.SearchRequest.GetTopk(), t.SearchRequest.GetOffset(), metricType, t.queryInfos[0], false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if t.functionScore != nil && len(result.Results.FieldsData) != 0 {
|
||||
{
|
||||
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-call-rerank-function-udf")
|
||||
defer sp.End()
|
||||
groupScorerStr := getGroupScorerStr(t.request.GetSearchParams())
|
||||
params := rerank.NewSearchParams(t.Nq, t.SearchRequest.GetTopk(), t.SearchRequest.GetOffset(),
|
||||
t.queryInfos[0].RoundDecimal, t.queryInfos[0].GroupByFieldId, t.queryInfos[0].GroupSize, t.queryInfos[0].StrictGroupSize, groupScorerStr, []string{metricType})
|
||||
// rank only returns id and score
|
||||
if t.result, err = t.functionScore.Process(ctx, params, []*milvuspb.SearchResults{result}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if !t.needRequery {
|
||||
fields, err := t.reorganizeRequeryResults(ctx, result.Results.FieldsData, []*schemapb.IDs{t.result.Results.Ids})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
t.result.Results.FieldsData = fields[0]
|
||||
}
|
||||
} else {
|
||||
t.result = result
|
||||
}
|
||||
t.fillResult()
|
||||
if t.needRequery {
|
||||
if t.requeryFunc == nil {
|
||||
t.requeryFunc = requeryImpl
|
||||
}
|
||||
queryResult, err := t.requeryFunc(t, span, t.result.Results.Ids, t.translatedOutputFields)
|
||||
if err != nil {
|
||||
log.Warn("failed to requery", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
fields, err := t.reorganizeRequeryResults(ctx, queryResult.GetFieldsData(), []*schemapb.IDs{t.result.Results.Ids})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
t.result.Results.FieldsData = fields[0]
|
||||
}
|
||||
t.result.Results.FieldsData = lo.Filter(t.result.Results.FieldsData, func(fieldData *schemapb.FieldData, i int) bool {
|
||||
return lo.Contains(t.translatedOutputFields, fieldData.GetFieldName())
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *searchTask) tryGeneratePlan(params []*commonpb.KeyValuePair, dsl string, exprTemplateValues map[string]*schemapb.TemplateValue) (*planpb.PlanNode, *planpb.QueryInfo, int64, bool, error) {
|
||||
annsFieldName, err := funcutil.GetAttrByKeyFromRepeatedKV(AnnsFieldKey, params)
|
||||
if err != nil || len(annsFieldName) == 0 {
|
||||
@ -919,50 +700,6 @@ func (t *searchTask) Execute(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func getMetricType(toReduceResults []*internalpb.SearchResults) string {
|
||||
metricType := ""
|
||||
if len(toReduceResults) >= 1 {
|
||||
metricType = toReduceResults[0].GetMetricType()
|
||||
}
|
||||
return metricType
|
||||
}
|
||||
|
||||
func (t *searchTask) reduceResults(ctx context.Context, toReduceResults []*internalpb.SearchResults, nq, topK int64, offset int64, metricType string, queryInfo *planpb.QueryInfo, isAdvance bool) (*milvuspb.SearchResults, error) {
|
||||
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "reduceResults")
|
||||
defer sp.End()
|
||||
|
||||
log := log.Ctx(ctx)
|
||||
// Decode all search results
|
||||
validSearchResults, err := decodeSearchResults(ctx, toReduceResults)
|
||||
if err != nil {
|
||||
log.Warn("failed to decode search results", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(validSearchResults) <= 0 {
|
||||
return fillInEmptyResult(nq), nil
|
||||
}
|
||||
|
||||
// Reduce all search results
|
||||
log.Debug("proxy search post execute reduce",
|
||||
zap.Int64("collection", t.GetCollectionID()),
|
||||
zap.Int64s("partitionIDs", t.GetPartitionIDs()),
|
||||
zap.Int("number of valid search results", len(validSearchResults)))
|
||||
primaryFieldSchema, err := t.schema.GetPkField()
|
||||
if err != nil {
|
||||
log.Warn("failed to get primary field schema", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
var result *milvuspb.SearchResults
|
||||
result, err = reduceSearchResult(ctx, validSearchResults, reduce.NewReduceSearchResultInfo(nq, topK).WithMetricType(metricType).WithPkType(primaryFieldSchema.GetDataType()).
|
||||
WithOffset(offset).WithGroupByField(queryInfo.GetGroupByFieldId()).WithGroupSize(queryInfo.GetGroupSize()).WithAdvance(isAdvance))
|
||||
if err != nil {
|
||||
log.Warn("failed to reduce search results", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// find the last bound based on reduced results and metric type
|
||||
// only support nq == 1, for search iterator v2
|
||||
func getLastBound(result *milvuspb.SearchResults, incomingLastBound *float32, metricType string) float32 {
|
||||
@ -1017,15 +754,16 @@ func (t *searchTask) PostExecute(ctx context.Context) error {
|
||||
t.isTopkReduce = isTopkReduce
|
||||
t.isRecallEvaluation = isRecallEvaluation
|
||||
|
||||
if t.SearchRequest.GetIsAdvanced() {
|
||||
err = t.advancedPostProcess(ctx, sp, toReduceResults)
|
||||
} else {
|
||||
err = t.searchPostProcess(ctx, sp, toReduceResults)
|
||||
}
|
||||
|
||||
// call pipeline
|
||||
pipeline, err := newBuiltInPipeline(t)
|
||||
if err != nil {
|
||||
log.Warn("Faild to create post process pipeline")
|
||||
return err
|
||||
}
|
||||
if t.result, err = pipeline.Run(ctx, sp, toReduceResults); err != nil {
|
||||
return err
|
||||
}
|
||||
t.fillResult()
|
||||
t.result.Results.OutputFields = t.userOutputFields
|
||||
t.result.CollectionName = t.request.GetCollectionName()
|
||||
|
||||
@ -1143,161 +881,6 @@ func (t *searchTask) estimateResultSize(nq int64, topK int64) (int64, error) {
|
||||
//return int64(sizePerRecord) * nq * topK, nil
|
||||
}
|
||||
|
||||
func requeryImpl(t *searchTask, span trace.Span, ids *schemapb.IDs, outputFields []string) (*milvuspb.QueryResults, error) {
|
||||
queryReq := &milvuspb.QueryRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_Retrieve,
|
||||
Timestamp: t.BeginTs(),
|
||||
},
|
||||
DbName: t.request.GetDbName(),
|
||||
CollectionName: t.request.GetCollectionName(),
|
||||
ConsistencyLevel: t.SearchRequest.GetConsistencyLevel(),
|
||||
NotReturnAllMeta: t.request.GetNotReturnAllMeta(),
|
||||
Expr: "",
|
||||
OutputFields: outputFields,
|
||||
PartitionNames: t.request.GetPartitionNames(),
|
||||
UseDefaultConsistency: false,
|
||||
GuaranteeTimestamp: t.SearchRequest.GuaranteeTimestamp,
|
||||
}
|
||||
pkField, err := typeutil.GetPrimaryFieldSchema(t.schema.CollectionSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
plan := planparserv2.CreateRequeryPlan(pkField, ids)
|
||||
channelsMvcc := make(map[string]Timestamp)
|
||||
for k, v := range t.queryChannelsTs {
|
||||
channelsMvcc[k] = v
|
||||
}
|
||||
qt := &queryTask{
|
||||
ctx: t.ctx,
|
||||
Condition: NewTaskCondition(t.ctx),
|
||||
RetrieveRequest: &internalpb.RetrieveRequest{
|
||||
Base: commonpbutil.NewMsgBase(
|
||||
commonpbutil.WithMsgType(commonpb.MsgType_Retrieve),
|
||||
commonpbutil.WithSourceID(paramtable.GetNodeID()),
|
||||
),
|
||||
ReqID: paramtable.GetNodeID(),
|
||||
PartitionIDs: t.GetPartitionIDs(), // use search partitionIDs
|
||||
ConsistencyLevel: t.ConsistencyLevel,
|
||||
},
|
||||
request: queryReq,
|
||||
plan: plan,
|
||||
mixCoord: t.node.(*Proxy).mixCoord,
|
||||
lb: t.node.(*Proxy).lbPolicy,
|
||||
channelsMvcc: channelsMvcc,
|
||||
fastSkip: true,
|
||||
reQuery: true,
|
||||
}
|
||||
queryResult, err := t.node.(*Proxy).query(t.ctx, qt, span)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if queryResult.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
|
||||
return nil, merr.Error(queryResult.GetStatus())
|
||||
}
|
||||
return queryResult, err
|
||||
}
|
||||
|
||||
func isEmpty(ids *schemapb.IDs) bool {
|
||||
if ids == nil {
|
||||
return true
|
||||
}
|
||||
if ids.GetIntId() != nil && len(ids.GetIntId().Data) != 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
if ids.GetStrId() != nil && len(ids.GetStrId().Data) != 0 {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (t *searchTask) reorganizeRequeryResults(ctx context.Context, fields []*schemapb.FieldData, idsList []*schemapb.IDs) ([][]*schemapb.FieldData, error) {
|
||||
_, sp := otel.Tracer(typeutil.ProxyRole).Start(t.ctx, "reorganizeRequeryResults")
|
||||
defer sp.End()
|
||||
|
||||
pkField, err := typeutil.GetPrimaryFieldSchema(t.schema.CollectionSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
pkFieldData, err := typeutil.GetPrimaryFieldData(fields, pkField)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
offsets := make(map[any]int)
|
||||
pkItr := typeutil.GetDataIterator(pkFieldData)
|
||||
for i := 0; i < typeutil.GetPKSize(pkFieldData); i++ {
|
||||
pk := pkItr(i)
|
||||
offsets[pk] = i
|
||||
}
|
||||
|
||||
allFieldData := make([][]*schemapb.FieldData, len(idsList))
|
||||
for idx, ids := range idsList {
|
||||
if isEmpty(ids) {
|
||||
emptyFields := []*schemapb.FieldData{}
|
||||
for _, field := range fields {
|
||||
emptyFields = append(emptyFields, &schemapb.FieldData{
|
||||
Type: field.Type,
|
||||
FieldName: field.FieldName,
|
||||
FieldId: field.FieldId,
|
||||
IsDynamic: field.IsDynamic,
|
||||
})
|
||||
}
|
||||
allFieldData[idx] = emptyFields
|
||||
continue
|
||||
}
|
||||
if fieldData, err := t.pickFieldData(ids, offsets, fields); err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
allFieldData[idx] = fieldData
|
||||
}
|
||||
}
|
||||
return allFieldData, nil
|
||||
}
|
||||
|
||||
// pick field data from query results
|
||||
func (t *searchTask) pickFieldData(ids *schemapb.IDs, pkOffset map[any]int, fields []*schemapb.FieldData) ([]*schemapb.FieldData, error) {
|
||||
// Reorganize Results. The order of query result ids will be altered and differ from queried ids.
|
||||
// We should reorganize query results to keep the order of original queried ids. For example:
|
||||
// ===========================================
|
||||
// 3 2 5 4 1 (query ids)
|
||||
// ||
|
||||
// || (query)
|
||||
// \/
|
||||
// 4 3 5 1 2 (result ids)
|
||||
// v4 v3 v5 v1 v2 (result vectors)
|
||||
// ||
|
||||
// || (reorganize)
|
||||
// \/
|
||||
// 3 2 5 4 1 (result ids)
|
||||
// v3 v2 v5 v4 v1 (result vectors)
|
||||
// ===========================================
|
||||
fieldsData := make([]*schemapb.FieldData, len(fields))
|
||||
for i := 0; i < typeutil.GetSizeOfIDs(ids); i++ {
|
||||
id := typeutil.GetPK(ids, int64(i))
|
||||
if _, ok := pkOffset[id]; !ok {
|
||||
return nil, merr.WrapErrInconsistentRequery(fmt.Sprintf("incomplete query result, missing id %s, len(searchIDs) = %d, len(queryIDs) = %d, collection=%d",
|
||||
id, typeutil.GetSizeOfIDs(ids), len(pkOffset), t.GetCollectionID()))
|
||||
}
|
||||
typeutil.AppendFieldData(fieldsData, fields, int64(pkOffset[id]))
|
||||
}
|
||||
|
||||
return fieldsData, nil
|
||||
}
|
||||
|
||||
func (t *searchTask) fillInFieldInfo() {
|
||||
for _, retField := range t.result.Results.FieldsData {
|
||||
for _, schemaField := range t.schema.Fields {
|
||||
if retField != nil && retField.FieldId == schemaField.FieldID {
|
||||
retField.FieldName = schemaField.Name
|
||||
retField.Type = schemaField.DataType
|
||||
retField.IsDynamic = schemaField.IsDynamic
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (t *searchTask) collectSearchResults(ctx context.Context) ([]*internalpb.SearchResults, error) {
|
||||
select {
|
||||
case <-t.TraceCtx().Done():
|
||||
@ -1316,77 +899,6 @@ func (t *searchTask) collectSearchResults(ctx context.Context) ([]*internalpb.Se
|
||||
}
|
||||
}
|
||||
|
||||
func decodeSearchResults(ctx context.Context, searchResults []*internalpb.SearchResults) ([]*schemapb.SearchResultData, error) {
|
||||
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "decodeSearchResults")
|
||||
defer sp.End()
|
||||
tr := timerecord.NewTimeRecorder("decodeSearchResults")
|
||||
results := make([]*schemapb.SearchResultData, 0)
|
||||
for _, partialSearchResult := range searchResults {
|
||||
if partialSearchResult.SlicedBlob == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
var partialResultData schemapb.SearchResultData
|
||||
err := proto.Unmarshal(partialSearchResult.SlicedBlob, &partialResultData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
results = append(results, &partialResultData)
|
||||
}
|
||||
tr.CtxElapse(ctx, "decodeSearchResults done")
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func checkSearchResultData(data *schemapb.SearchResultData, nq int64, topk int64, pkHitNum int) error {
|
||||
if data.NumQueries != nq {
|
||||
return fmt.Errorf("search result's nq(%d) mis-match with %d", data.NumQueries, nq)
|
||||
}
|
||||
if data.TopK != topk {
|
||||
return fmt.Errorf("search result's topk(%d) mis-match with %d", data.TopK, topk)
|
||||
}
|
||||
|
||||
if len(data.Scores) != pkHitNum {
|
||||
return fmt.Errorf("search result's score length invalid, score length=%d, expectedLength=%d",
|
||||
len(data.Scores), pkHitNum)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func selectHighestScoreIndex(ctx context.Context, subSearchResultData []*schemapb.SearchResultData, subSearchNqOffset [][]int64, cursors []int64, qi int64) (int, int64) {
|
||||
var (
|
||||
subSearchIdx = -1
|
||||
resultDataIdx int64 = -1
|
||||
)
|
||||
maxScore := minFloat32
|
||||
for i := range cursors {
|
||||
if cursors[i] >= subSearchResultData[i].Topks[qi] {
|
||||
continue
|
||||
}
|
||||
sIdx := subSearchNqOffset[i][qi] + cursors[i]
|
||||
sScore := subSearchResultData[i].Scores[sIdx]
|
||||
|
||||
// Choose the larger score idx or the smaller pk idx with the same score
|
||||
if subSearchIdx == -1 || sScore > maxScore {
|
||||
subSearchIdx = i
|
||||
resultDataIdx = sIdx
|
||||
maxScore = sScore
|
||||
} else if sScore == maxScore {
|
||||
if subSearchIdx == -1 {
|
||||
// A bad case happens where Knowhere returns distance/score == +/-maxFloat32
|
||||
// by mistake.
|
||||
log.Ctx(ctx).Error("a bad score is returned, something is wrong here!", zap.Float32("score", sScore))
|
||||
} else if typeutil.ComparePK(
|
||||
typeutil.GetPK(subSearchResultData[i].GetIds(), sIdx),
|
||||
typeutil.GetPK(subSearchResultData[subSearchIdx].GetIds(), resultDataIdx)) {
|
||||
subSearchIdx = i
|
||||
resultDataIdx = sIdx
|
||||
maxScore = sScore
|
||||
}
|
||||
}
|
||||
}
|
||||
return subSearchIdx, resultDataIdx
|
||||
}
|
||||
|
||||
// TraceCtx returns the context stored on the search task (t.ctx).
func (t *searchTask) TraceCtx() context.Context {
|
||||
return t.ctx
|
||||
}
|
||||
|
||||
@ -19,12 +19,12 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/bytedance/mockey"
|
||||
"github.com/cockroachdb/errors"
|
||||
"github.com/google/uuid"
|
||||
"github.com/samber/lo"
|
||||
@ -32,7 +32,6 @@ import (
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/protobuf/proto"
|
||||
|
||||
@ -394,12 +393,10 @@ func TestSearchTask_PostExecute(t *testing.T) {
|
||||
f3 := testutils.GenerateScalarFieldData(schemapb.DataType_Int64, testInt64Field, 20)
|
||||
f3.FieldId = fieldNameId[testInt64Field]
|
||||
|
||||
qt.requeryFunc = func(t *searchTask, span trace.Span, ids *schemapb.IDs, outputFields []string) (*milvuspb.QueryResults, error) {
|
||||
return &milvuspb.QueryResults{
|
||||
FieldsData: []*schemapb.FieldData{f1, f2, f3},
|
||||
PrimaryFieldName: testInt64Field,
|
||||
}, nil
|
||||
}
|
||||
mocker := mockey.Mock((*requeryOperator).requery).Return(&milvuspb.QueryResults{
|
||||
FieldsData: []*schemapb.FieldData{f1, f2, f3},
|
||||
}, nil).Build()
|
||||
defer mocker.UnPatch()
|
||||
|
||||
err := qt.PostExecute(context.TODO())
|
||||
assert.NoError(t, err)
|
||||
@ -436,12 +433,10 @@ func TestSearchTask_PostExecute(t *testing.T) {
|
||||
f3 := testutils.GenerateScalarFieldData(schemapb.DataType_Int64, testInt64Field, 20)
|
||||
f3.FieldId = fieldNameId[testInt64Field]
|
||||
|
||||
qt.requeryFunc = func(t *searchTask, span trace.Span, ids *schemapb.IDs, outputFields []string) (*milvuspb.QueryResults, error) {
|
||||
return &milvuspb.QueryResults{
|
||||
FieldsData: []*schemapb.FieldData{f1, f2, f3},
|
||||
PrimaryFieldName: testInt64Field,
|
||||
}, nil
|
||||
}
|
||||
mocker := mockey.Mock((*requeryOperator).requery).Return(&milvuspb.QueryResults{
|
||||
FieldsData: []*schemapb.FieldData{f1, f2, f3},
|
||||
}, nil).Build()
|
||||
defer mocker.UnPatch()
|
||||
|
||||
err := qt.PostExecute(context.TODO())
|
||||
assert.NoError(t, err)
|
||||
@ -478,12 +473,11 @@ func TestSearchTask_PostExecute(t *testing.T) {
|
||||
f3.FieldId = fieldNameId[testInt64Field]
|
||||
f4 := testutils.GenerateScalarFieldData(schemapb.DataType_Float, testFloatField, 20)
|
||||
f4.FieldId = fieldNameId[testFloatField]
|
||||
qt.requeryFunc = func(t *searchTask, span trace.Span, ids *schemapb.IDs, outputFields []string) (*milvuspb.QueryResults, error) {
|
||||
return &milvuspb.QueryResults{
|
||||
FieldsData: []*schemapb.FieldData{f1, f2, f3, f4},
|
||||
PrimaryFieldName: testInt64Field,
|
||||
}, nil
|
||||
}
|
||||
mocker := mockey.Mock((*requeryOperator).requery).Return(&milvuspb.QueryResults{
|
||||
FieldsData: []*schemapb.FieldData{f1, f2, f3, f4},
|
||||
}, nil).Build()
|
||||
defer mocker.UnPatch()
|
||||
|
||||
err := qt.PostExecute(context.TODO())
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []int64{10, 10}, qt.result.Results.Topks)
|
||||
@ -505,53 +499,6 @@ func TestSearchTask_PostExecute(t *testing.T) {
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Test mergeIDs function", func(t *testing.T) {
|
||||
{
|
||||
ids1 := &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
Data: []int64{1, 2, 3, 5},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ids2 := &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
Data: []int64{1, 2, 4, 5, 6},
|
||||
},
|
||||
},
|
||||
}
|
||||
allIDs, count := mergeIDs([]*schemapb.IDs{ids1, ids2})
|
||||
assert.Equal(t, count, 6)
|
||||
sortedIds := allIDs.GetIntId().GetData()
|
||||
slices.Sort(sortedIds)
|
||||
assert.Equal(t, sortedIds, []int64{1, 2, 3, 4, 5, 6})
|
||||
}
|
||||
{
|
||||
ids1 := &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_StrId{
|
||||
StrId: &schemapb.StringArray{
|
||||
Data: []string{"a", "b", "e"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ids2 := &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_StrId{
|
||||
StrId: &schemapb.StringArray{
|
||||
Data: []string{"a", "b", "c", "d"},
|
||||
},
|
||||
},
|
||||
}
|
||||
allIDs, count := mergeIDs([]*schemapb.IDs{ids1, ids2})
|
||||
assert.Equal(t, count, 5)
|
||||
sortedIds := allIDs.GetStrId().GetData()
|
||||
slices.Sort(sortedIds)
|
||||
assert.Equal(t, sortedIds, []string{"a", "b", "c", "d", "e"})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func createCollWithFields(t *testing.T, collName string, rc types.MixCoordClient) (*schemapb.CollectionSchema, map[string]int64) {
|
||||
@ -3743,9 +3690,10 @@ func TestSearchTask_Requery(t *testing.T) {
|
||||
tr: timerecord.NewTimeRecorder("search"),
|
||||
node: node,
|
||||
translatedOutputFields: outputFields,
|
||||
requeryFunc: requeryImpl,
|
||||
}
|
||||
queryResult, err := qt.requeryFunc(qt, nil, qt.result.Results.Ids, outputFields)
|
||||
op, err := newRequeryOperator(qt, nil)
|
||||
assert.NoError(t, err)
|
||||
queryResult, err := op.(*requeryOperator).requery(ctx, nil, qt.result.Results.Ids, outputFields)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, queryResult.FieldsData, 2)
|
||||
for _, field := range qt.result.Results.FieldsData {
|
||||
@ -3768,14 +3716,13 @@ func TestSearchTask_Requery(t *testing.T) {
|
||||
SourceID: paramtable.GetNodeID(),
|
||||
},
|
||||
},
|
||||
request: &milvuspb.SearchRequest{},
|
||||
schema: schema,
|
||||
tr: timerecord.NewTimeRecorder("search"),
|
||||
node: node,
|
||||
requeryFunc: requeryImpl,
|
||||
request: &milvuspb.SearchRequest{},
|
||||
schema: schema,
|
||||
tr: timerecord.NewTimeRecorder("search"),
|
||||
node: node,
|
||||
}
|
||||
|
||||
_, err := qt.requeryFunc(qt, nil, &schemapb.IDs{}, []string{})
|
||||
_, err := newRequeryOperator(qt, nil)
|
||||
t.Logf("err = %s", err)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
@ -3804,13 +3751,14 @@ func TestSearchTask_Requery(t *testing.T) {
|
||||
request: &milvuspb.SearchRequest{
|
||||
CollectionName: collectionName,
|
||||
},
|
||||
schema: schema,
|
||||
tr: timerecord.NewTimeRecorder("search"),
|
||||
node: node,
|
||||
requeryFunc: requeryImpl,
|
||||
schema: schema,
|
||||
tr: timerecord.NewTimeRecorder("search"),
|
||||
node: node,
|
||||
}
|
||||
|
||||
_, err := qt.requeryFunc(qt, nil, &schemapb.IDs{}, []string{})
|
||||
op, err := newRequeryOperator(qt, nil)
|
||||
assert.NoError(t, err)
|
||||
_, err = op.(*requeryOperator).requery(ctx, nil, &schemapb.IDs{}, []string{})
|
||||
t.Logf("err = %s", err)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
@ -34,7 +34,7 @@ const (
|
||||
// The current version only supports plain text, and cipher text will be supported later.
|
||||
type Credentials struct {
|
||||
// key formats:
|
||||
// {credentialName}.api_key
|
||||
// {credentialName}.apikey
|
||||
// {credentialName}.access_key_id
|
||||
// {credentialName}.secret_access_key
|
||||
// {credentialName}.credential_json
|
||||
|
||||
@ -234,7 +234,7 @@ func (fScore *FunctionScore) Process(ctx context.Context, searchParams *SearchPa
|
||||
FieldsData: make([]*schemapb.FieldData, 0),
|
||||
Scores: []float32{},
|
||||
Ids: &schemapb.IDs{},
|
||||
Topks: []int64{},
|
||||
Topks: make([]int64, searchParams.nq),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
@ -280,3 +280,10 @@ func (fScore *FunctionScore) IsSupportGroup() bool {
|
||||
}
|
||||
return fScore.reranker.IsSupportGroup()
|
||||
}
|
||||
|
||||
// RerankName returns the name reported by the underlying reranker's
// GetRankName method. A nil receiver is tolerated and yields an empty string.
func (fScore *FunctionScore) RerankName() string {
|
||||
if fScore == nil {
|
||||
return ""
|
||||
}
|
||||
return fScore.reranker.GetRankName()
|
||||
}
|
||||
|
||||
@ -322,7 +322,7 @@ func (s *FunctionScoreSuite) TestlegacyFunction() {
|
||||
rankParams := []*commonpb.KeyValuePair{}
|
||||
f, err := NewFunctionScoreWithlegacy(schema, rankParams)
|
||||
s.NoError(err)
|
||||
s.Equal(f.reranker.GetRankName(), rrfName)
|
||||
s.Equal(f.RerankName(), rrfName)
|
||||
}
|
||||
{
|
||||
rankParams := []*commonpb.KeyValuePair{
|
||||
|
||||
@ -289,8 +289,14 @@ class TestMilvusClientHybridSearchInvalid(TestMilvusClientV2Base):
|
||||
collection_name = cf.gen_unique_str(prefix)
|
||||
# 1. create collection
|
||||
self.create_collection(client, collection_name, default_dim)
|
||||
# 2. hybrid search
|
||||
# 2. insert
|
||||
rng = np.random.default_rng(seed=19530)
|
||||
rows = [
|
||||
{default_primary_key_field_name: i, default_vector_field_name: list(rng.random((1, default_dim))[0]),
|
||||
default_vector_field_name+"new": list(rng.random((1, default_dim))[0]),
|
||||
default_string_field_name: str(i)} for i in range(default_nb)]
|
||||
self.insert(client, collection_name, rows)
|
||||
# 3. hybrid search
|
||||
vectors_to_search = rng.random((1, default_dim))
|
||||
sub_search1 = AnnSearchRequest(vectors_to_search, "vector", {"level": 1}, 20, expr="id<100")
|
||||
ranker = WeightedRanker(0.2, 0.8)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user