milvus/internal/util/function/rerank/function_score.go
junjiejiangjjj f23df95a77
feat : Support decay rerank (#41223)
https://github.com/milvus-io/milvus/issues/35856
#41312

Signed-off-by: junjie.jiang <junjie.jiang@zilliz.com>
2025-04-23 20:48:39 +08:00

167 lines
4.8 KiB
Go

/*
* # Licensed to the LF AI & Data foundation under one
* # or more contributor license agreements. See the NOTICE file
* # distributed with this work for additional information
* # regarding copyright ownership. The ASF licenses this file
* # to you under the Apache License, Version 2.0 (the
* # "License"); you may not use this file except in compliance
* # with the License. You may obtain a copy of the License at
* #
* # http://www.apache.org/licenses/LICENSE-2.0
* #
* # Unless required by applicable law or agreed to in writing, software
* # distributed under the License is distributed on an "AS IS" BASIS,
* # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* # See the License for the specific language governing permissions and
* # limitations under the License.
*/
package rerank
import (
"context"
"fmt"
"strings"
"github.com/samber/lo"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
)
// decayFunctionName is the reranker name that createFunction maps to
// newDecayFunction.
const decayFunctionName string = "decay"
// SearchParams carries the per-request search settings that are handed to a
// Reranker together with the search results.
//
// NewSearchParams populates this struct, so the field order here must stay
// aligned with that constructor's argument order.
type SearchParams struct {
	nq, limit, offset, roundDecimal int64

	// TODO: supports group search
	groupByFieldId, groupSize int64
	strictGroupSize           bool
}
// NewSearchParams returns a pointer to a new SearchParams in which every
// argument is stored in the field of the same name.
func NewSearchParams(nq, limit, offset, roundDecimal, groupByFieldId, groupSize int64, strictGroupSize bool) *SearchParams {
	return &SearchParams{
		nq:              nq,
		limit:           limit,
		offset:          offset,
		roundDecimal:    roundDecimal,
		groupByFieldId:  groupByFieldId,
		groupSize:       groupSize,
		strictGroupSize: strictGroupSize,
	}
}
// Reranker is the set of methods a rerank function implementation must
// provide so that FunctionScore can delegate to it.
type Reranker interface {
	// GetRankName returns the reranker's name.
	// NOTE(review): the exact meaning is defined by implementations, which
	// are not visible in this file.
	GetRankName() string

	// GetInputFieldNames returns the input field names of the reranker;
	// FunctionScore.GetAllInputFieldNames forwards this value.
	GetInputFieldNames() []string

	// GetInputFieldIDs returns the input field IDs of the reranker;
	// FunctionScore.GetAllInputFieldIDs forwards this value.
	GetInputFieldIDs() []int64

	// IsSupportGroup reports whether the reranker supports group search
	// (presumably; FunctionScore.IsSupportGroup forwards this value).
	IsSupportGroup() bool

	// Process reranks searchData using searchParams and returns a single
	// result set, or an error. FunctionScore.Process passes one entry per
	// incoming search result and wraps the returned data in its response.
	Process(ctx context.Context, searchParams *SearchParams, searchData []*schemapb.SearchResultData) (*schemapb.SearchResultData, error)
}
// getRerankName looks through the function's params for the first entry
// whose lowercased key equals the package-level `reranker` identifier and
// returns that entry's value, lowercased. It returns "" when no such param
// exists.
func getRerankName(funcSchema *schemapb.FunctionSchema) string {
	for _, kv := range funcSchema.Params {
		if strings.ToLower(kv.Key) != reranker {
			continue
		}
		return strings.ToLower(kv.Value)
	}
	return ""
}
// FunctionScore wraps the reranker built by NewFunctionScore and exposes it
// through Process and the GetAll*/IsSupportGroup accessors.
// Currently only a single rerank function is supported.
type FunctionScore struct {
// reranker is the one rerank function this score delegates to.
reranker Reranker
}
// createFunction validates funcSchema and builds the Reranker it names.
//
// An error is returned when funcSchema is not of type Rerank, when it
// declares any output field names, or when the name resolved by
// getRerankName is not supported (only decayFunctionName is). An error from
// the reranker's constructor is returned unchanged, with a nil Reranker.
func createFunction(collSchema *schemapb.CollectionSchema, funcSchema *schemapb.FunctionSchema) (Reranker, error) {
	if funcType := funcSchema.GetType(); funcType != schemapb.FunctionType_Rerank {
		return nil, fmt.Errorf("%s is not rerank function.", funcType.String())
	}
	if outputCount := len(funcSchema.GetOutputFieldNames()); outputCount > 0 {
		return nil, fmt.Errorf("Rerank function should not have output field, but now is %d", outputCount)
	}

	// Dispatch on the reranker name; each supported name gets its own case.
	switch name := getRerankName(funcSchema); name {
	case decayFunctionName:
		decay, err := newDecayFunction(collSchema, funcSchema)
		if err != nil {
			// Return a literal nil so callers never see a half-built reranker.
			return nil, err
		}
		return decay, nil
	default:
		return nil, fmt.Errorf("Unsupported rerank function: [%s] , list of supported [%s]", name, decayFunctionName)
	}
}
// NewFunctionScore builds a FunctionScore from funcScoreSchema.
//
// Exactly one function must be present in funcScoreSchema; otherwise an
// error stating the actual count is returned. A nil funcScoreSchema is
// treated as containing zero functions and yields that same error rather
// than a nil-pointer panic. Any error from createFunction is returned
// unchanged.
func NewFunctionScore(collSchema *schemapb.CollectionSchema, funcScoreSchema *schemapb.FunctionScore) (*FunctionScore, error) {
	// The generated getter is safe on a nil receiver, unlike direct field access.
	functions := funcScoreSchema.GetFunctions()
	if len(functions) != 1 {
		return nil, fmt.Errorf("Currently only supports one rerank, but got %d", len(functions))
	}

	fn, err := createFunction(collSchema, functions[0])
	if err != nil {
		return nil, err
	}
	return &FunctionScore{reranker: fn}, nil
}
// Process reranks the given search results with the configured reranker.
//
// With no input results it returns a successful response whose result data
// has NumQueries and TopK taken from searchParams and every other listed
// field set to an empty, non-nil value. Otherwise the Results of every input
// are handed to the reranker and its output becomes the response's Results.
// A reranker error is returned as-is with a nil response.
func (fScore *FunctionScore) Process(ctx context.Context, searchParams *SearchParams, multipleMilvusResults []*milvuspb.SearchResults) (*milvuspb.SearchResults, error) {
	if len(multipleMilvusResults) == 0 {
		emptyData := &schemapb.SearchResultData{
			NumQueries: searchParams.nq,
			TopK:       searchParams.limit,
			FieldsData: make([]*schemapb.FieldData, 0),
			Scores:     []float32{},
			Ids:        &schemapb.IDs{},
			Topks:      []int64{},
		}
		return &milvuspb.SearchResults{Status: merr.Success(), Results: emptyData}, nil
	}

	// Every input contributes its Results unconditionally, in input order.
	inputs := lo.Map(multipleMilvusResults, func(res *milvuspb.SearchResults, _ int) *schemapb.SearchResultData {
		return res.Results
	})

	// Original author's note: the rank result only has scores.
	reranked, err := fScore.reranker.Process(ctx, searchParams, inputs)
	if err != nil {
		return nil, err
	}
	return &milvuspb.SearchResults{Status: merr.Success(), Results: reranked}, nil
}
// GetAllInputFieldNames returns the input field names reported by the
// underlying reranker. A nil receiver yields an empty, non-nil slice.
func (fScore *FunctionScore) GetAllInputFieldNames() []string {
	if fScore != nil {
		return fScore.reranker.GetInputFieldNames()
	}
	return []string{}
}
// GetAllInputFieldIDs returns the input field IDs reported by the underlying
// reranker. A nil receiver yields an empty, non-nil slice.
func (fScore *FunctionScore) GetAllInputFieldIDs() []int64 {
	if fScore != nil {
		return fScore.reranker.GetInputFieldIDs()
	}
	return []int64{}
}
// IsSupportGroup reports the underlying reranker's IsSupportGroup value. A
// nil receiver (no function score configured) reports true.
func (fScore *FunctionScore) IsSupportGroup() bool {
	// Short-circuit keeps the reranker untouched when the receiver is nil.
	return fScore == nil || fScore.reranker.IsSupportGroup()
}