mirror of
https://gitee.com/milvus-io/milvus.git
synced 2025-12-06 17:18:35 +08:00
https://github.com/milvus-io/milvus/issues/35856 #41312 Signed-off-by: junjie.jiang <junjie.jiang@zilliz.com>
167 lines
4.8 KiB
Go
167 lines
4.8 KiB
Go
/*
|
|
* # Licensed to the LF AI & Data foundation under one
|
|
* # or more contributor license agreements. See the NOTICE file
|
|
* # distributed with this work for additional information
|
|
* # regarding copyright ownership. The ASF licenses this file
|
|
* # to you under the Apache License, Version 2.0 (the
|
|
* # "License"); you may not use this file except in compliance
|
|
* # with the License. You may obtain a copy of the License at
|
|
* #
|
|
* # http://www.apache.org/licenses/LICENSE-2.0
|
|
* #
|
|
* # Unless required by applicable law or agreed to in writing, software
|
|
* # distributed under the License is distributed on an "AS IS" BASIS,
|
|
* # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* # See the License for the specific language governing permissions and
|
|
* # limitations under the License.
|
|
*/
|
|
|
|
package rerank
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"strings"
|
|
|
|
"github.com/samber/lo"
|
|
|
|
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
|
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
|
"github.com/milvus-io/milvus/pkg/v2/util/merr"
|
|
)
|
|
|
|
// decayFunctionName is the reranker name that createFunction maps to the
// decay rerank implementation.
const decayFunctionName string = "decay"
|
|
|
|
// SearchParams bundles the per-request search settings that are passed to a
// Reranker. FunctionScore.Process reports nq as NumQueries and limit as TopK
// when it builds an empty result.
//
// Field order matters: NewSearchParams relies on it.
type SearchParams struct {
	nq, limit, offset, roundDecimal int64

	// TODO: supports group search
	groupByFieldId, groupSize int64
	strictGroupSize           bool
}
|
|
|
|
func NewSearchParams(nq, limit, offset, roundDecimal, groupByFieldId, groupSize int64, strictGroupSize bool) *SearchParams {
|
|
return &SearchParams{
|
|
nq, limit, offset, roundDecimal, groupByFieldId, groupSize, strictGroupSize,
|
|
}
|
|
}
|
|
|
|
type Reranker interface {
|
|
Process(ctx context.Context, searchParams *SearchParams, searchData []*schemapb.SearchResultData) (*schemapb.SearchResultData, error)
|
|
IsSupportGroup() bool
|
|
GetInputFieldNames() []string
|
|
GetInputFieldIDs() []int64
|
|
GetRankName() string
|
|
}
|
|
|
|
func getRerankName(funcSchema *schemapb.FunctionSchema) string {
|
|
for _, param := range funcSchema.Params {
|
|
switch strings.ToLower(param.Key) {
|
|
case reranker:
|
|
return strings.ToLower(param.Value)
|
|
default:
|
|
}
|
|
}
|
|
return ""
|
|
}
|
|
|
|
// Currently only supports single rerank
|
|
type FunctionScore struct {
|
|
reranker Reranker
|
|
}
|
|
|
|
func createFunction(collSchema *schemapb.CollectionSchema, funcSchema *schemapb.FunctionSchema) (Reranker, error) {
|
|
if funcSchema.GetType() != schemapb.FunctionType_Rerank {
|
|
return nil, fmt.Errorf("%s is not rerank function.", funcSchema.GetType().String())
|
|
}
|
|
if len(funcSchema.GetOutputFieldNames()) != 0 {
|
|
return nil, fmt.Errorf("Rerank function should not have output field, but now is %d", len(funcSchema.GetOutputFieldNames()))
|
|
}
|
|
|
|
rerankerName := getRerankName(funcSchema)
|
|
var rerankFunc Reranker
|
|
var newRerankErr error
|
|
switch rerankerName {
|
|
case decayFunctionName:
|
|
rerankFunc, newRerankErr = newDecayFunction(collSchema, funcSchema)
|
|
default:
|
|
return nil, fmt.Errorf("Unsupported rerank function: [%s] , list of supported [%s]", rerankerName, decayFunctionName)
|
|
}
|
|
|
|
if newRerankErr != nil {
|
|
return nil, newRerankErr
|
|
}
|
|
return rerankFunc, nil
|
|
}
|
|
|
|
func NewFunctionScore(collSchema *schemapb.CollectionSchema, funcScoreSchema *schemapb.FunctionScore) (*FunctionScore, error) {
|
|
if len(funcScoreSchema.Functions) > 1 || len(funcScoreSchema.Functions) == 0 {
|
|
return nil, fmt.Errorf("Currently only supports one rerank, but got %d", len(funcScoreSchema.Functions))
|
|
}
|
|
funcScore := &FunctionScore{}
|
|
var err error
|
|
if funcScore.reranker, err = createFunction(collSchema, funcScoreSchema.Functions[0]); err != nil {
|
|
return nil, err
|
|
}
|
|
return funcScore, nil
|
|
}
|
|
|
|
func (fScore *FunctionScore) Process(ctx context.Context, searchParams *SearchParams, multipleMilvusResults []*milvuspb.SearchResults) (*milvuspb.SearchResults, error) {
|
|
if len(multipleMilvusResults) == 0 {
|
|
return &milvuspb.SearchResults{
|
|
Status: merr.Success(),
|
|
Results: &schemapb.SearchResultData{
|
|
NumQueries: searchParams.nq,
|
|
TopK: searchParams.limit,
|
|
FieldsData: make([]*schemapb.FieldData, 0),
|
|
Scores: []float32{},
|
|
Ids: &schemapb.IDs{},
|
|
Topks: []int64{},
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
allSearchResultData := lo.FilterMap(multipleMilvusResults, func(m *milvuspb.SearchResults, _ int) (*schemapb.SearchResultData, bool) {
|
|
return m.Results, true
|
|
})
|
|
|
|
// rankResult only has scores
|
|
rankResult, err := fScore.reranker.Process(ctx, searchParams, allSearchResultData)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
ret := &milvuspb.SearchResults{
|
|
Status: merr.Success(),
|
|
Results: rankResult,
|
|
}
|
|
return ret, nil
|
|
}
|
|
|
|
func (fScore *FunctionScore) GetAllInputFieldNames() []string {
|
|
if fScore == nil {
|
|
return []string{}
|
|
}
|
|
return fScore.reranker.GetInputFieldNames()
|
|
}
|
|
|
|
func (fScore *FunctionScore) GetAllInputFieldIDs() []int64 {
|
|
if fScore == nil {
|
|
return []int64{}
|
|
}
|
|
return fScore.reranker.GetInputFieldIDs()
|
|
}
|
|
|
|
func (fScore *FunctionScore) IsSupportGroup() bool {
|
|
if fScore == nil {
|
|
return true
|
|
}
|
|
return fScore.reranker.IsSupportGroup()
|
|
}
|