diff --git a/client/milvusclient/read_option_test.go b/client/milvusclient/read_option_test.go index 0ce8af0db4..b4697dc697 100644 --- a/client/milvusclient/read_option_test.go +++ b/client/milvusclient/read_option_test.go @@ -52,7 +52,15 @@ func (s *SearchOptionSuite) TestBasic() { topK := rand.Intn(100) + 1 opt := NewSearchOption(collName, topK, []entity.Vector{entity.FloatVector([]float32{0.1, 0.2})}) - opt = opt.WithANNSField("test_field").WithOutputFields("ID", "Value").WithConsistencyLevel(entity.ClStrong).WithFilter("ID > 1000").WithGroupByField("group_field").WithGroupSize(10).WithStrictGroupSize(true) + rerankerFunction := entity.NewFunction().WithName("time_decay").WithInputFields("timestamp").WithType(entity.FunctionTypeRerank). + WithParam("reranker", "decay"). + WithParam("function", "gauss"). + WithParam("origin", 1754995249). + WithParam("scale", 7*24*60*60). + WithParam("offset", 24*60*60). + WithParam("decay", 0.5) + + opt = opt.WithANNSField("test_field").WithOutputFields("ID", "Value").WithConsistencyLevel(entity.ClStrong).WithFilter("ID > 1000").WithGroupByField("group_field").WithGroupSize(10).WithStrictGroupSize(true).WithFunctionReranker(rerankerFunction) req, err := opt.Request() s.Require().NoError(err) @@ -73,6 +81,9 @@ func (s *SearchOptionSuite) TestBasic() { s.Require().True(ok) s.Equal("true", spStrictGroupSize) + functionScore := req.GetFunctionScore() + s.Len(functionScore.GetFunctions(), 1) + opt = NewSearchOption(collName, topK, []entity.Vector{nonSupportData{}}) _, err = opt.Request() s.Error(err) diff --git a/client/milvusclient/read_options.go b/client/milvusclient/read_options.go index 64681c0926..d99b97216f 100644 --- a/client/milvusclient/read_options.go +++ b/client/milvusclient/read_options.go @@ -79,6 +79,8 @@ type AnnRequest struct { topK int offset int templateParams map[string]any + + functionRerankers []*entity.Function } func NewAnnRequest(annField string, limit int, vectors ...entity.Vector) *AnnRequest { @@ -144,6 +146,13 @@ func (r *AnnRequest) searchRequest() (*milvuspb.SearchRequest, error) { request.ExprTemplateValues[key] = tmplVal } + if len(r.functionRerankers) > 0 { + request.FunctionScore = &schemapb.FunctionScore{} + for _, fr := range r.functionRerankers { + request.FunctionScore.Functions = append(request.FunctionScore.Functions, fr.ProtoMessage()) + } + } + return request, nil } @@ -277,6 +286,11 @@ func (r *AnnRequest) WithIgnoreGrowing(ignoreGrowing bool) *AnnRequest { return r } +func (r *AnnRequest) WithFunctionReranker(fr *entity.Function) *AnnRequest { + r.functionRerankers = append(r.functionRerankers, fr) + return r +} + func (opt *searchOption) Request() (*milvuspb.SearchRequest, error) { request, err := opt.annRequest.searchRequest() if err != nil { @@ -358,6 +372,11 @@ func (opt *searchOption) WithSearchParam(key, value string) *searchOption { return opt } +func (opt *searchOption) WithFunctionReranker(fr *entity.Function) *searchOption { + opt.annRequest.WithFunctionReranker(fr) + return opt +} + func NewSearchOption(collectionName string, limit int, vectors []entity.Vector) *searchOption { return &searchOption{ annRequest: NewAnnRequest("", limit, vectors...), @@ -430,9 +449,10 @@ type hybridSearchOption struct { useDefaultConsistency bool consistencyLevel entity.ConsistencyLevel - limit int - offset int - reranker Reranker + limit int + offset int + reranker Reranker + functionRerankers []*entity.Function } func (opt *hybridSearchOption) WithConsistencyLevel(cl entity.ConsistencyLevel) *hybridSearchOption { @@ -461,6 +481,11 @@ func (opt *hybridSearchOption) WithReranker(reranker Reranker) *hybridSearchOpti return opt } +func (opt *hybridSearchOption) WithFunctionRerankers(functionReranker *entity.Function) *hybridSearchOption { + opt.functionRerankers = append(opt.functionRerankers, functionReranker) + return opt +} + func (opt *hybridSearchOption) WithOffset(offset int) *hybridSearchOption { opt.offset = offset return opt @@ -485,7 +510,7 @@ func (opt *hybridSearchOption) HybridRequest() (*milvuspb.HybridSearchRequest, e params = append(params, &commonpb.KeyValuePair{Key: spOffset, Value: strconv.FormatInt(int64(opt.offset), 10)}) } - return &milvuspb.HybridSearchRequest{ + r := &milvuspb.HybridSearchRequest{ CollectionName: opt.collectionName, PartitionNames: opt.partitionNames, Requests: requests, @@ -493,7 +518,16 @@ func (opt *hybridSearchOption) HybridRequest() (*milvuspb.HybridSearchRequest, e ConsistencyLevel: commonpb.ConsistencyLevel(opt.consistencyLevel), OutputFields: opt.outputFields, RankParams: params, - }, nil + } + + if len(opt.functionRerankers) > 0 { + r.FunctionScore = &schemapb.FunctionScore{} + for _, fr := range opt.functionRerankers { + r.FunctionScore.Functions = append(r.FunctionScore.Functions, fr.ProtoMessage()) + } + } + + return r, nil } func NewHybridSearchOption(collectionName string, limit int, annRequests ...*AnnRequest) *hybridSearchOption {