mirror of
https://gitee.com/milvus-io/milvus.git
synced 2025-12-06 17:18:35 +08:00
relate: https://github.com/milvus-io/milvus/issues/43867 Support boost function score, multiply by the weight if match filter. Signed-off-by: aoiasd <zhicheng.yue@zilliz.com>
312 lines
9.1 KiB
Go
312 lines
9.1 KiB
Go
/*
|
|
* # Licensed to the LF AI & Data foundation under one
|
|
* # or more contributor license agreements. See the NOTICE file
|
|
* # distributed with this work for additional information
|
|
* # regarding copyright ownership. The ASF licenses this file
|
|
* # to you under the Apache License, Version 2.0 (the
|
|
* # "License"); you may not use this file except in compliance
|
|
* # with the License. You may obtain a copy of the License at
|
|
* #
|
|
* # http://www.apache.org/licenses/LICENSE-2.0
|
|
* #
|
|
* # Unless required by applicable law or agreed to in writing, software
|
|
* # distributed under the License is distributed on an "AS IS" BASIS,
|
|
* # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* # See the License for the specific language governing permissions and
|
|
* # limitations under the License.
|
|
*/
|
|
|
|
package rerank
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"reflect"
|
|
"strconv"
|
|
"strings"
|
|
|
|
"github.com/samber/lo"
|
|
|
|
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
|
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
|
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
|
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
|
|
"github.com/milvus-io/milvus/pkg/v2/util/merr"
|
|
)
|
|
|
|
// Rerank function names recognised by createFunction. The name is read from
// the function schema parameters by GetRerankName.
const (
	// DecayFunctionName makes createFunction build the reranker via newDecayFunction.
	DecayFunctionName string = "decay"
	// ModelFunctionName makes createFunction build the reranker via newModelFunction.
	ModelFunctionName string = "model"
	// RRFName makes createFunction build the reranker via newRRFFunction.
	RRFName string = "rrf"
	// WeightedName makes createFunction build the reranker via newWeightedFunction.
	WeightedName string = "weighted"
)
|
|
|
|
// Values for SearchParams.groupScore. NewSearchParams falls back to maxScorer
// when the caller passes an empty string.
// NOTE(review): sumScorer and avgScorer are presumably the other group score
// aggregations; confirm against the code that consumes groupScore.
const (
	maxScorer string = "max"
	sumScorer string = "sum"
	avgScorer string = "avg"
)
|
|
|
|
// legacy rrf/weighted rerank configs
//
// These are the keys NewFunctionScoreWithlegacy looks up in the legacy
// rank params key/value list.
const (
	// legacyRankTypeKey holds the rank strategy name (a key of rankTypeMap).
	legacyRankTypeKey = "strategy"
	// legacyRankParamsKey holds the JSON-encoded strategy parameters.
	legacyRankParamsKey = "params"
)
|
|
|
|
// rankType identifies a legacy rank strategy; rankTypeMap maps the strategy
// strings onto these values.
type rankType int

// The order of these constants is significant: the values come from iota.
const (
	invalidRankType  rankType = iota // invalidRankType = 0
	rrfRankType                      // rrfRankType = 1
	weightedRankType                 // weightedRankType = 2
)
|
|
|
|
// segment scorer configs
const (
	// BoostName is the rerank name for which createFunction builds no
	// Reranker (it returns nil, nil for this name).
	BoostName = "boost"
	// FilterKey and WeightKey are parameter keys of the boost function.
	// NOTE(review): their consumers are not in this file; presumably the
	// score is multiplied by the weight when the filter matches - confirm.
	FilterKey = "filter"
	WeightKey = "weight"
)
|
|
|
|
var rankTypeMap = map[string]rankType{
|
|
"invalid": invalidRankType,
|
|
"rrf": rrfRankType,
|
|
"weighted": weightedRankType,
|
|
}
|
|
|
|
// SearchParams carries the search-level settings handed to a Reranker.
//
// The field order must not change: NewSearchParams fills the struct with an
// unkeyed composite literal.
type SearchParams struct {
	nq           int64 // reported as NumQueries when there are no results to rerank
	limit        int64 // reported as TopK when there are no results to rerank
	offset       int64
	roundDecimal int64

	groupByFieldId  int64 // a value > 0 marks the search as grouped (see isGrouping)
	groupSize       int64
	strictGroupSize bool
	groupScore      string // NewSearchParams defaults this to maxScorer

	searchMetrics []string
}
|
|
|
|
func (s *SearchParams) isGrouping() bool {
|
|
return s.groupByFieldId > 0
|
|
}
|
|
|
|
func NewSearchParams(nq, limit, offset, roundDecimal, groupByFieldId, groupSize int64, strictGroupSize bool, groupScore string, searchMetrics []string) *SearchParams {
|
|
if groupScore == "" {
|
|
groupScore = maxScorer
|
|
}
|
|
return &SearchParams{
|
|
nq, limit, offset, roundDecimal, groupByFieldId, groupSize, strictGroupSize, groupScore, searchMetrics,
|
|
}
|
|
}
|
|
|
|
type Reranker interface {
|
|
Process(ctx context.Context, searchParams *SearchParams, inputs *rerankInputs) (*rerankOutputs, error)
|
|
IsSupportGroup() bool
|
|
GetInputFieldNames() []string
|
|
GetInputFieldIDs() []int64
|
|
GetRankName() string
|
|
}
|
|
|
|
func GetRerankName(funcSchema *schemapb.FunctionSchema) string {
|
|
for _, param := range funcSchema.Params {
|
|
switch strings.ToLower(param.Key) {
|
|
case reranker:
|
|
return strings.ToLower(param.Value)
|
|
default:
|
|
}
|
|
}
|
|
return ""
|
|
}
|
|
|
|
// Currently only supports single rerank
|
|
type FunctionScore struct {
|
|
reranker Reranker
|
|
}
|
|
|
|
func createFunction(collSchema *schemapb.CollectionSchema, funcSchema *schemapb.FunctionSchema) (Reranker, error) {
|
|
if funcSchema.GetType() != schemapb.FunctionType_Rerank {
|
|
return nil, fmt.Errorf("%s is not rerank function.", funcSchema.GetType().String())
|
|
}
|
|
if len(funcSchema.GetOutputFieldNames()) != 0 {
|
|
return nil, fmt.Errorf("Rerank function should not have output field, but now is %d", len(funcSchema.GetOutputFieldNames()))
|
|
}
|
|
|
|
rerankerName := GetRerankName(funcSchema)
|
|
var rerankFunc Reranker
|
|
var newRerankErr error
|
|
switch rerankerName {
|
|
case DecayFunctionName:
|
|
rerankFunc, newRerankErr = newDecayFunction(collSchema, funcSchema)
|
|
case ModelFunctionName:
|
|
rerankFunc, newRerankErr = newModelFunction(collSchema, funcSchema)
|
|
case RRFName:
|
|
rerankFunc, newRerankErr = newRRFFunction(collSchema, funcSchema)
|
|
case WeightedName:
|
|
rerankFunc, newRerankErr = newWeightedFunction(collSchema, funcSchema)
|
|
case BoostName:
|
|
return nil, nil
|
|
default:
|
|
return nil, fmt.Errorf("Unsupported rerank function: [%s] , list of supported [%s,%s,%s,%s]", rerankerName, DecayFunctionName, ModelFunctionName, RRFName, WeightedName)
|
|
}
|
|
|
|
if newRerankErr != nil {
|
|
return nil, newRerankErr
|
|
}
|
|
return rerankFunc, nil
|
|
}
|
|
|
|
func NewFunctionScore(collSchema *schemapb.CollectionSchema, funcScoreSchema *schemapb.FunctionScore) (*FunctionScore, error) {
|
|
funcScore := &FunctionScore{}
|
|
|
|
for _, function := range funcScoreSchema.Functions {
|
|
reranker, err := createFunction(collSchema, function)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if reranker != nil {
|
|
if funcScore.reranker == nil {
|
|
funcScore.reranker = reranker
|
|
} else {
|
|
// now only support only use one proxy rerank
|
|
return nil, fmt.Errorf("Currently only supports one rerank")
|
|
}
|
|
}
|
|
}
|
|
|
|
if funcScore.reranker != nil {
|
|
return funcScore, nil
|
|
}
|
|
return nil, nil
|
|
}
|
|
|
|
func NewFunctionScoreWithlegacy(collSchema *schemapb.CollectionSchema, rankParams []*commonpb.KeyValuePair) (*FunctionScore, error) {
|
|
var params map[string]interface{}
|
|
rankTypeStr, err := funcutil.GetAttrByKeyFromRepeatedKV(legacyRankTypeKey, rankParams)
|
|
if err != nil {
|
|
rankTypeStr = "rrf"
|
|
params = make(map[string]interface{}, 0)
|
|
} else {
|
|
if _, ok := rankTypeMap[rankTypeStr]; !ok {
|
|
return nil, fmt.Errorf("unsupported rank type %s", rankTypeStr)
|
|
}
|
|
paramStr, err := funcutil.GetAttrByKeyFromRepeatedKV(legacyRankParamsKey, rankParams)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("params" + " not found in rank_params")
|
|
}
|
|
err = json.Unmarshal([]byte(paramStr), ¶ms)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("Parse rerank params failed, err: %s", err)
|
|
}
|
|
}
|
|
fSchema := schemapb.FunctionSchema{
|
|
Type: schemapb.FunctionType_Rerank,
|
|
InputFieldNames: []string{},
|
|
OutputFieldNames: []string{},
|
|
Params: []*commonpb.KeyValuePair{},
|
|
}
|
|
switch rankTypeMap[rankTypeStr] {
|
|
case rrfRankType:
|
|
fSchema.Params = append(fSchema.Params, &commonpb.KeyValuePair{Key: reranker, Value: RRFName})
|
|
if v, ok := params[RRFParamsKey]; ok {
|
|
if reflect.ValueOf(params[RRFParamsKey]).CanFloat() {
|
|
k := reflect.ValueOf(v).Float()
|
|
fSchema.Params = append(fSchema.Params, &commonpb.KeyValuePair{Key: RRFParamsKey, Value: strconv.FormatFloat(k, 'f', -1, 64)})
|
|
} else {
|
|
return nil, fmt.Errorf("The type of rank param k should be float")
|
|
}
|
|
}
|
|
case weightedRankType:
|
|
fSchema.Params = append(fSchema.Params, &commonpb.KeyValuePair{Key: reranker, Value: WeightedName})
|
|
if v, ok := params[WeightsParamsKey]; ok {
|
|
if d, err := json.Marshal(v); err != nil {
|
|
return nil, fmt.Errorf("The weights param should be an array")
|
|
} else {
|
|
fSchema.Params = append(fSchema.Params, &commonpb.KeyValuePair{Key: WeightsParamsKey, Value: string(d)})
|
|
}
|
|
}
|
|
if normScore, ok := params[NormScoreKey]; ok {
|
|
if ns, ok := normScore.(bool); ok {
|
|
fSchema.Params = append(fSchema.Params, &commonpb.KeyValuePair{Key: NormScoreKey, Value: strconv.FormatBool(ns)})
|
|
} else {
|
|
return nil, fmt.Errorf("Weighted rerank err, norm_score should been bool type, but [norm_score:%s]'s type is %T", normScore, normScore)
|
|
}
|
|
}
|
|
default:
|
|
return nil, fmt.Errorf("unsupported rank type %s", rankTypeStr)
|
|
}
|
|
funcScore := &FunctionScore{}
|
|
if funcScore.reranker, err = createFunction(collSchema, &fSchema); err != nil {
|
|
return nil, err
|
|
}
|
|
return funcScore, nil
|
|
}
|
|
|
|
func (fScore *FunctionScore) Process(ctx context.Context, searchParams *SearchParams, multipleMilvusResults []*milvuspb.SearchResults) (*milvuspb.SearchResults, error) {
|
|
if len(multipleMilvusResults) == 0 {
|
|
return &milvuspb.SearchResults{
|
|
Status: merr.Success(),
|
|
Results: &schemapb.SearchResultData{
|
|
NumQueries: searchParams.nq,
|
|
TopK: searchParams.limit,
|
|
FieldsData: make([]*schemapb.FieldData, 0),
|
|
Scores: []float32{},
|
|
Ids: &schemapb.IDs{},
|
|
Topks: make([]int64, searchParams.nq),
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
allSearchResultData := lo.FilterMap(multipleMilvusResults, func(m *milvuspb.SearchResults, _ int) (*schemapb.SearchResultData, bool) {
|
|
return m.Results, true
|
|
})
|
|
|
|
// rankResult only has scores
|
|
inputs, err := newRerankInputs(allSearchResultData, fScore.reranker.GetInputFieldIDs(), searchParams.isGrouping())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
rankResult, err := fScore.reranker.Process(ctx, searchParams, inputs)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
ret := &milvuspb.SearchResults{
|
|
Status: merr.Success(),
|
|
Results: rankResult.searchResultData,
|
|
}
|
|
return ret, nil
|
|
}
|
|
|
|
func (fScore *FunctionScore) GetAllInputFieldNames() []string {
|
|
if fScore == nil {
|
|
return []string{}
|
|
}
|
|
return fScore.reranker.GetInputFieldNames()
|
|
}
|
|
|
|
func (fScore *FunctionScore) GetAllInputFieldIDs() []int64 {
|
|
if fScore == nil {
|
|
return []int64{}
|
|
}
|
|
return fScore.reranker.GetInputFieldIDs()
|
|
}
|
|
|
|
func (fScore *FunctionScore) IsSupportGroup() bool {
|
|
if fScore == nil {
|
|
return true
|
|
}
|
|
return fScore.reranker.IsSupportGroup()
|
|
}
|
|
|
|
func (fScore *FunctionScore) RerankName() string {
|
|
if fScore == nil {
|
|
return ""
|
|
}
|
|
return fScore.reranker.GetRankName()
|
|
}
|