enhance: [GoSDK] Support function reranker (#43845)

Related to #35856

Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
This commit is contained in:
congqixia 2025-08-14 14:57:44 +08:00 committed by GitHub
parent 3e9e830074
commit 1b87e864ca
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 51 additions and 6 deletions

View File

@ -52,7 +52,15 @@ func (s *SearchOptionSuite) TestBasic() {
topK := rand.Intn(100) + 1
opt := NewSearchOption(collName, topK, []entity.Vector{entity.FloatVector([]float32{0.1, 0.2})})
opt = opt.WithANNSField("test_field").WithOutputFields("ID", "Value").WithConsistencyLevel(entity.ClStrong).WithFilter("ID > 1000").WithGroupByField("group_field").WithGroupSize(10).WithStrictGroupSize(true)
rerankerFunction := entity.NewFunction().WithName("time_decay").WithInputFields("timestamp").WithType(entity.FunctionTypeRerank).
WithParam("reranker", "decay").
WithParam("function", "gauss").
WithParam("origin", 1754995249).
WithParam("scale", 7*24*60*60).
WithParam("offset", 24*60*60).
WithParam("decay", 0.5)
opt = opt.WithANNSField("test_field").WithOutputFields("ID", "Value").WithConsistencyLevel(entity.ClStrong).WithFilter("ID > 1000").WithGroupByField("group_field").WithGroupSize(10).WithStrictGroupSize(true).WithFunctionReranker(rerankerFunction)
req, err := opt.Request()
s.Require().NoError(err)
@ -73,6 +81,9 @@ func (s *SearchOptionSuite) TestBasic() {
s.Require().True(ok)
s.Equal("true", spStrictGroupSize)
functionScore := req.GetFunctionScore()
s.Len(functionScore.GetFunctions(), 1)
opt = NewSearchOption(collName, topK, []entity.Vector{nonSupportData{}})
_, err = opt.Request()
s.Error(err)

View File

@ -79,6 +79,8 @@ type AnnRequest struct {
topK int
offset int
templateParams map[string]any
functionRerankers []*entity.Function
}
func NewAnnRequest(annField string, limit int, vectors ...entity.Vector) *AnnRequest {
@ -144,6 +146,13 @@ func (r *AnnRequest) searchRequest() (*milvuspb.SearchRequest, error) {
request.ExprTemplateValues[key] = tmplVal
}
if len(r.functionRerankers) > 0 {
request.FunctionScore = &schemapb.FunctionScore{}
for _, fr := range r.functionRerankers {
request.FunctionScore.Functions = append(request.FunctionScore.Functions, fr.ProtoMessage())
}
}
return request, nil
}
@ -277,6 +286,11 @@ func (r *AnnRequest) WithIgnoreGrowing(ignoreGrowing bool) *AnnRequest {
return r
}
func (r *AnnRequest) WithFunctionReranker(fr *entity.Function) *AnnRequest {
r.functionRerankers = append(r.functionRerankers, fr)
return r
}
func (opt *searchOption) Request() (*milvuspb.SearchRequest, error) {
request, err := opt.annRequest.searchRequest()
if err != nil {
@ -358,6 +372,11 @@ func (opt *searchOption) WithSearchParam(key, value string) *searchOption {
return opt
}
func (opt *searchOption) WithFunctionReranker(fr *entity.Function) *searchOption {
opt.annRequest.WithFunctionReranker(fr)
return opt
}
func NewSearchOption(collectionName string, limit int, vectors []entity.Vector) *searchOption {
return &searchOption{
annRequest: NewAnnRequest("", limit, vectors...),
@ -430,9 +449,10 @@ type hybridSearchOption struct {
useDefaultConsistency bool
consistencyLevel entity.ConsistencyLevel
limit int
offset int
reranker Reranker
limit int
offset int
reranker Reranker
functionRerankers []*entity.Function
}
func (opt *hybridSearchOption) WithConsistencyLevel(cl entity.ConsistencyLevel) *hybridSearchOption {
@ -461,6 +481,11 @@ func (opt *hybridSearchOption) WithReranker(reranker Reranker) *hybridSearchOpti
return opt
}
func (opt *hybridSearchOption) WithFunctionRerankers(functionReranker *entity.Function) *hybridSearchOption {
opt.functionRerankers = append(opt.functionRerankers, functionReranker)
return opt
}
func (opt *hybridSearchOption) WithOffset(offset int) *hybridSearchOption {
opt.offset = offset
return opt
@ -485,7 +510,7 @@ func (opt *hybridSearchOption) HybridRequest() (*milvuspb.HybridSearchRequest, e
params = append(params, &commonpb.KeyValuePair{Key: spOffset, Value: strconv.FormatInt(int64(opt.offset), 10)})
}
return &milvuspb.HybridSearchRequest{
r := &milvuspb.HybridSearchRequest{
CollectionName: opt.collectionName,
PartitionNames: opt.partitionNames,
Requests: requests,
@ -493,7 +518,16 @@ func (opt *hybridSearchOption) HybridRequest() (*milvuspb.HybridSearchRequest, e
ConsistencyLevel: commonpb.ConsistencyLevel(opt.consistencyLevel),
OutputFields: opt.outputFields,
RankParams: params,
}, nil
}
if len(opt.functionRerankers) > 0 {
r.FunctionScore = &schemapb.FunctionScore{}
for _, fr := range opt.functionRerankers {
r.FunctionScore.Functions = append(r.FunctionScore.Functions, fr.ProtoMessage())
}
}
return r, nil
}
func NewHybridSearchOption(collectionName string, limit int, annRequests ...*AnnRequest) *hybridSearchOption {