mirror of
https://gitee.com/milvus-io/milvus.git
synced 2025-12-07 01:28:27 +08:00
enhance: [GoSDK] Support function reranker (#43845)
Related to #35856 Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
This commit is contained in:
parent
3e9e830074
commit
1b87e864ca
@ -52,7 +52,15 @@ func (s *SearchOptionSuite) TestBasic() {
|
|||||||
topK := rand.Intn(100) + 1
|
topK := rand.Intn(100) + 1
|
||||||
opt := NewSearchOption(collName, topK, []entity.Vector{entity.FloatVector([]float32{0.1, 0.2})})
|
opt := NewSearchOption(collName, topK, []entity.Vector{entity.FloatVector([]float32{0.1, 0.2})})
|
||||||
|
|
||||||
opt = opt.WithANNSField("test_field").WithOutputFields("ID", "Value").WithConsistencyLevel(entity.ClStrong).WithFilter("ID > 1000").WithGroupByField("group_field").WithGroupSize(10).WithStrictGroupSize(true)
|
rerankerFunction := entity.NewFunction().WithName("time_decay").WithInputFields("timestamp").WithType(entity.FunctionTypeRerank).
|
||||||
|
WithParam("reranker", "decay").
|
||||||
|
WithParam("function", "gauss").
|
||||||
|
WithParam("origin", 1754995249).
|
||||||
|
WithParam("scale", 7*24*60*60).
|
||||||
|
WithParam("offset", 24*60*60).
|
||||||
|
WithParam("decay", 0.5)
|
||||||
|
|
||||||
|
opt = opt.WithANNSField("test_field").WithOutputFields("ID", "Value").WithConsistencyLevel(entity.ClStrong).WithFilter("ID > 1000").WithGroupByField("group_field").WithGroupSize(10).WithStrictGroupSize(true).WithFunctionReranker(rerankerFunction)
|
||||||
req, err := opt.Request()
|
req, err := opt.Request()
|
||||||
s.Require().NoError(err)
|
s.Require().NoError(err)
|
||||||
|
|
||||||
@ -73,6 +81,9 @@ func (s *SearchOptionSuite) TestBasic() {
|
|||||||
s.Require().True(ok)
|
s.Require().True(ok)
|
||||||
s.Equal("true", spStrictGroupSize)
|
s.Equal("true", spStrictGroupSize)
|
||||||
|
|
||||||
|
functionScore := req.GetFunctionScore()
|
||||||
|
s.Len(functionScore.GetFunctions(), 1)
|
||||||
|
|
||||||
opt = NewSearchOption(collName, topK, []entity.Vector{nonSupportData{}})
|
opt = NewSearchOption(collName, topK, []entity.Vector{nonSupportData{}})
|
||||||
_, err = opt.Request()
|
_, err = opt.Request()
|
||||||
s.Error(err)
|
s.Error(err)
|
||||||
|
|||||||
@ -79,6 +79,8 @@ type AnnRequest struct {
|
|||||||
topK int
|
topK int
|
||||||
offset int
|
offset int
|
||||||
templateParams map[string]any
|
templateParams map[string]any
|
||||||
|
|
||||||
|
functionRerankers []*entity.Function
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewAnnRequest(annField string, limit int, vectors ...entity.Vector) *AnnRequest {
|
func NewAnnRequest(annField string, limit int, vectors ...entity.Vector) *AnnRequest {
|
||||||
@ -144,6 +146,13 @@ func (r *AnnRequest) searchRequest() (*milvuspb.SearchRequest, error) {
|
|||||||
request.ExprTemplateValues[key] = tmplVal
|
request.ExprTemplateValues[key] = tmplVal
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(r.functionRerankers) > 0 {
|
||||||
|
request.FunctionScore = &schemapb.FunctionScore{}
|
||||||
|
for _, fr := range r.functionRerankers {
|
||||||
|
request.FunctionScore.Functions = append(request.FunctionScore.Functions, fr.ProtoMessage())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return request, nil
|
return request, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -277,6 +286,11 @@ func (r *AnnRequest) WithIgnoreGrowing(ignoreGrowing bool) *AnnRequest {
|
|||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *AnnRequest) WithFunctionReranker(fr *entity.Function) *AnnRequest {
|
||||||
|
r.functionRerankers = append(r.functionRerankers, fr)
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
func (opt *searchOption) Request() (*milvuspb.SearchRequest, error) {
|
func (opt *searchOption) Request() (*milvuspb.SearchRequest, error) {
|
||||||
request, err := opt.annRequest.searchRequest()
|
request, err := opt.annRequest.searchRequest()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -358,6 +372,11 @@ func (opt *searchOption) WithSearchParam(key, value string) *searchOption {
|
|||||||
return opt
|
return opt
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (opt *searchOption) WithFunctionReranker(fr *entity.Function) *searchOption {
|
||||||
|
opt.annRequest.WithFunctionReranker(fr)
|
||||||
|
return opt
|
||||||
|
}
|
||||||
|
|
||||||
func NewSearchOption(collectionName string, limit int, vectors []entity.Vector) *searchOption {
|
func NewSearchOption(collectionName string, limit int, vectors []entity.Vector) *searchOption {
|
||||||
return &searchOption{
|
return &searchOption{
|
||||||
annRequest: NewAnnRequest("", limit, vectors...),
|
annRequest: NewAnnRequest("", limit, vectors...),
|
||||||
@ -433,6 +452,7 @@ type hybridSearchOption struct {
|
|||||||
limit int
|
limit int
|
||||||
offset int
|
offset int
|
||||||
reranker Reranker
|
reranker Reranker
|
||||||
|
functionRerankers []*entity.Function
|
||||||
}
|
}
|
||||||
|
|
||||||
func (opt *hybridSearchOption) WithConsistencyLevel(cl entity.ConsistencyLevel) *hybridSearchOption {
|
func (opt *hybridSearchOption) WithConsistencyLevel(cl entity.ConsistencyLevel) *hybridSearchOption {
|
||||||
@ -461,6 +481,11 @@ func (opt *hybridSearchOption) WithReranker(reranker Reranker) *hybridSearchOpti
|
|||||||
return opt
|
return opt
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (opt *hybridSearchOption) WithFunctionRerankers(functionReranker *entity.Function) *hybridSearchOption {
|
||||||
|
opt.functionRerankers = append(opt.functionRerankers, functionReranker)
|
||||||
|
return opt
|
||||||
|
}
|
||||||
|
|
||||||
func (opt *hybridSearchOption) WithOffset(offset int) *hybridSearchOption {
|
func (opt *hybridSearchOption) WithOffset(offset int) *hybridSearchOption {
|
||||||
opt.offset = offset
|
opt.offset = offset
|
||||||
return opt
|
return opt
|
||||||
@ -485,7 +510,7 @@ func (opt *hybridSearchOption) HybridRequest() (*milvuspb.HybridSearchRequest, e
|
|||||||
params = append(params, &commonpb.KeyValuePair{Key: spOffset, Value: strconv.FormatInt(int64(opt.offset), 10)})
|
params = append(params, &commonpb.KeyValuePair{Key: spOffset, Value: strconv.FormatInt(int64(opt.offset), 10)})
|
||||||
}
|
}
|
||||||
|
|
||||||
return &milvuspb.HybridSearchRequest{
|
r := &milvuspb.HybridSearchRequest{
|
||||||
CollectionName: opt.collectionName,
|
CollectionName: opt.collectionName,
|
||||||
PartitionNames: opt.partitionNames,
|
PartitionNames: opt.partitionNames,
|
||||||
Requests: requests,
|
Requests: requests,
|
||||||
@ -493,7 +518,16 @@ func (opt *hybridSearchOption) HybridRequest() (*milvuspb.HybridSearchRequest, e
|
|||||||
ConsistencyLevel: commonpb.ConsistencyLevel(opt.consistencyLevel),
|
ConsistencyLevel: commonpb.ConsistencyLevel(opt.consistencyLevel),
|
||||||
OutputFields: opt.outputFields,
|
OutputFields: opt.outputFields,
|
||||||
RankParams: params,
|
RankParams: params,
|
||||||
}, nil
|
}
|
||||||
|
|
||||||
|
if len(opt.functionRerankers) > 0 {
|
||||||
|
r.FunctionScore = &schemapb.FunctionScore{}
|
||||||
|
for _, fr := range opt.functionRerankers {
|
||||||
|
r.FunctionScore.Functions = append(r.FunctionScore.Functions, fr.ProtoMessage())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return r, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewHybridSearchOption(collectionName string, limit int, annRequests ...*AnnRequest) *hybridSearchOption {
|
func NewHybridSearchOption(collectionName string, limit int, annRequests ...*AnnRequest) *hybridSearchOption {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user