mirror of
https://gitee.com/milvus-io/milvus.git
synced 2025-12-06 17:18:35 +08:00
enhance: [GoSDK] Support function reranker (#43845)
Related to #35856 Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
This commit is contained in:
parent
3e9e830074
commit
1b87e864ca
@ -52,7 +52,15 @@ func (s *SearchOptionSuite) TestBasic() {
|
||||
topK := rand.Intn(100) + 1
|
||||
opt := NewSearchOption(collName, topK, []entity.Vector{entity.FloatVector([]float32{0.1, 0.2})})
|
||||
|
||||
opt = opt.WithANNSField("test_field").WithOutputFields("ID", "Value").WithConsistencyLevel(entity.ClStrong).WithFilter("ID > 1000").WithGroupByField("group_field").WithGroupSize(10).WithStrictGroupSize(true)
|
||||
rerankerFunction := entity.NewFunction().WithName("time_decay").WithInputFields("timestamp").WithType(entity.FunctionTypeRerank).
|
||||
WithParam("reranker", "decay").
|
||||
WithParam("function", "gauss").
|
||||
WithParam("origin", 1754995249).
|
||||
WithParam("scale", 7*24*60*60).
|
||||
WithParam("offset", 24*60*60).
|
||||
WithParam("decay", 0.5)
|
||||
|
||||
opt = opt.WithANNSField("test_field").WithOutputFields("ID", "Value").WithConsistencyLevel(entity.ClStrong).WithFilter("ID > 1000").WithGroupByField("group_field").WithGroupSize(10).WithStrictGroupSize(true).WithFunctionReranker(rerankerFunction)
|
||||
req, err := opt.Request()
|
||||
s.Require().NoError(err)
|
||||
|
||||
@ -73,6 +81,9 @@ func (s *SearchOptionSuite) TestBasic() {
|
||||
s.Require().True(ok)
|
||||
s.Equal("true", spStrictGroupSize)
|
||||
|
||||
functionScore := req.GetFunctionScore()
|
||||
s.Len(functionScore.GetFunctions(), 1)
|
||||
|
||||
opt = NewSearchOption(collName, topK, []entity.Vector{nonSupportData{}})
|
||||
_, err = opt.Request()
|
||||
s.Error(err)
|
||||
|
||||
@ -79,6 +79,8 @@ type AnnRequest struct {
|
||||
topK int
|
||||
offset int
|
||||
templateParams map[string]any
|
||||
|
||||
functionRerankers []*entity.Function
|
||||
}
|
||||
|
||||
func NewAnnRequest(annField string, limit int, vectors ...entity.Vector) *AnnRequest {
|
||||
@ -144,6 +146,13 @@ func (r *AnnRequest) searchRequest() (*milvuspb.SearchRequest, error) {
|
||||
request.ExprTemplateValues[key] = tmplVal
|
||||
}
|
||||
|
||||
if len(r.functionRerankers) > 0 {
|
||||
request.FunctionScore = &schemapb.FunctionScore{}
|
||||
for _, fr := range r.functionRerankers {
|
||||
request.FunctionScore.Functions = append(request.FunctionScore.Functions, fr.ProtoMessage())
|
||||
}
|
||||
}
|
||||
|
||||
return request, nil
|
||||
}
|
||||
|
||||
@ -277,6 +286,11 @@ func (r *AnnRequest) WithIgnoreGrowing(ignoreGrowing bool) *AnnRequest {
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *AnnRequest) WithFunctionReranker(fr *entity.Function) *AnnRequest {
|
||||
r.functionRerankers = append(r.functionRerankers, fr)
|
||||
return r
|
||||
}
|
||||
|
||||
func (opt *searchOption) Request() (*milvuspb.SearchRequest, error) {
|
||||
request, err := opt.annRequest.searchRequest()
|
||||
if err != nil {
|
||||
@ -358,6 +372,11 @@ func (opt *searchOption) WithSearchParam(key, value string) *searchOption {
|
||||
return opt
|
||||
}
|
||||
|
||||
func (opt *searchOption) WithFunctionReranker(fr *entity.Function) *searchOption {
|
||||
opt.annRequest.WithFunctionReranker(fr)
|
||||
return opt
|
||||
}
|
||||
|
||||
func NewSearchOption(collectionName string, limit int, vectors []entity.Vector) *searchOption {
|
||||
return &searchOption{
|
||||
annRequest: NewAnnRequest("", limit, vectors...),
|
||||
@ -430,9 +449,10 @@ type hybridSearchOption struct {
|
||||
useDefaultConsistency bool
|
||||
consistencyLevel entity.ConsistencyLevel
|
||||
|
||||
limit int
|
||||
offset int
|
||||
reranker Reranker
|
||||
limit int
|
||||
offset int
|
||||
reranker Reranker
|
||||
functionRerankers []*entity.Function
|
||||
}
|
||||
|
||||
func (opt *hybridSearchOption) WithConsistencyLevel(cl entity.ConsistencyLevel) *hybridSearchOption {
|
||||
@ -461,6 +481,11 @@ func (opt *hybridSearchOption) WithReranker(reranker Reranker) *hybridSearchOpti
|
||||
return opt
|
||||
}
|
||||
|
||||
func (opt *hybridSearchOption) WithFunctionRerankers(functionReranker *entity.Function) *hybridSearchOption {
|
||||
opt.functionRerankers = append(opt.functionRerankers, functionReranker)
|
||||
return opt
|
||||
}
|
||||
|
||||
func (opt *hybridSearchOption) WithOffset(offset int) *hybridSearchOption {
|
||||
opt.offset = offset
|
||||
return opt
|
||||
@ -485,7 +510,7 @@ func (opt *hybridSearchOption) HybridRequest() (*milvuspb.HybridSearchRequest, e
|
||||
params = append(params, &commonpb.KeyValuePair{Key: spOffset, Value: strconv.FormatInt(int64(opt.offset), 10)})
|
||||
}
|
||||
|
||||
return &milvuspb.HybridSearchRequest{
|
||||
r := &milvuspb.HybridSearchRequest{
|
||||
CollectionName: opt.collectionName,
|
||||
PartitionNames: opt.partitionNames,
|
||||
Requests: requests,
|
||||
@ -493,7 +518,16 @@ func (opt *hybridSearchOption) HybridRequest() (*milvuspb.HybridSearchRequest, e
|
||||
ConsistencyLevel: commonpb.ConsistencyLevel(opt.consistencyLevel),
|
||||
OutputFields: opt.outputFields,
|
||||
RankParams: params,
|
||||
}, nil
|
||||
}
|
||||
|
||||
if len(opt.functionRerankers) > 0 {
|
||||
r.FunctionScore = &schemapb.FunctionScore{}
|
||||
for _, fr := range opt.functionRerankers {
|
||||
r.FunctionScore.Functions = append(r.FunctionScore.Functions, fr.ProtoMessage())
|
||||
}
|
||||
}
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func NewHybridSearchOption(collectionName string, limit int, annRequests ...*AnnRequest) *hybridSearchOption {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user