feat : Support decay rerank (#41223)

https://github.com/milvus-io/milvus/issues/35856
#41312

Signed-off-by: junjie.jiang <junjie.jiang@zilliz.com>
This commit is contained in:
junjiejiangjjj 2025-04-23 20:48:39 +08:00 committed by GitHub
parent f52c2909c4
commit f23df95a77
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 2386 additions and 190 deletions

2
go.mod
View File

@ -21,7 +21,7 @@ require (
github.com/grpc-ecosystem/go-grpc-middleware v1.3.0
github.com/klauspost/compress v1.17.9
github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d
github.com/milvus-io/milvus-proto/go-api/v2 v2.5.0-beta.0.20250325034212-6e98baa34971
github.com/milvus-io/milvus-proto/go-api/v2 v2.5.0-beta.0.20250407030015-dcf7688ad54a
github.com/minio/minio-go/v7 v7.0.73
github.com/pingcap/log v1.1.1-0.20221015072633-39906604fb81
github.com/prometheus/client_golang v1.14.0

4
go.sum
View File

@ -740,6 +740,10 @@ github.com/milvus-io/gorocksdb v0.0.0-20220624081344-8c5f4212846b h1:TfeY0NxYxZz
github.com/milvus-io/gorocksdb v0.0.0-20220624081344-8c5f4212846b/go.mod h1:iwW+9cWfIzzDseEBCCeDSN5SD16Tidvy8cwQ7ZY8Qj4=
github.com/milvus-io/milvus-proto/go-api/v2 v2.5.0-beta.0.20250325034212-6e98baa34971 h1:CKKrOtri+dbTUkMJehDuSM489OIqJab1t0pUq+PV73E=
github.com/milvus-io/milvus-proto/go-api/v2 v2.5.0-beta.0.20250325034212-6e98baa34971/go.mod h1:/6UT4zZl6awVeXLeE7UGDWZvXj3IWkRsh3mqsn0DiAs=
github.com/milvus-io/milvus-proto/go-api/v2 v2.5.0-beta.0.20250407030015-dcf7688ad54a h1:W+9nVXKcI9FdiyrFbrs9BIFUqRW0pLY+Fn0fsmmuLyw=
github.com/milvus-io/milvus-proto/go-api/v2 v2.5.0-beta.0.20250407030015-dcf7688ad54a/go.mod h1:/6UT4zZl6awVeXLeE7UGDWZvXj3IWkRsh3mqsn0DiAs=
github.com/milvus-io/milvus-proto/go-api/v2 v2.5.8 h1:/oUdiYtwVlqiEMSzME7vDvir49Lt23nMpaZC9u22bIo=
github.com/milvus-io/milvus-proto/go-api/v2 v2.5.8/go.mod h1:/6UT4zZl6awVeXLeE7UGDWZvXj3IWkRsh3mqsn0DiAs=
github.com/milvus-io/pulsar-client-go v0.12.1 h1:O2JZp1tsYiO7C0MQ4hrUY/aJXnn2Gry6hpm7UodghmE=
github.com/milvus-io/pulsar-client-go v0.12.1/go.mod h1:dkutuH4oS2pXiGm+Ti7fQZ4MRjrMPZ8IJeEGAWMeckk=
github.com/minio/asm2plan9s v0.0.0-20200509001527-cdd76441f9d8 h1:AMFGa4R4MiIpspGNG7Z948v4n35fFGB3RR3G/ry4FWs=

View File

@ -3168,6 +3168,7 @@ func (node *Proxy) search(ctx context.Context, request *milvuspb.SearchRequest,
lb: node.lbPolicy,
enableMaterializedView: node.enableMaterializedView,
mustUsePartitionKey: Params.ProxyCfg.MustUsePartitionKey.GetAsBool(),
requeryFunc: requeryImpl,
}
log := log.Ctx(ctx).With( // TODO: it might cause some cpu consumption
@ -3406,6 +3407,7 @@ func (node *Proxy) hybridSearch(ctx context.Context, request *milvuspb.HybridSea
node: node,
lb: node.lbPolicy,
mustUsePartitionKey: Params.ProxyCfg.MustUsePartitionKey.GetAsBool(),
requeryFunc: requeryImpl,
}
log := log.Ctx(ctx).With(

View File

@ -567,6 +567,7 @@ func convertHybridSearchToSearch(req *milvuspb.HybridSearchRequest) *milvuspb.Se
UseDefaultConsistency: req.GetUseDefaultConsistency(),
SearchByPrimaryKeys: false,
SubReqs: nil,
FunctionScore: req.FunctionScore,
}
for _, sub := range req.GetRequests() {

View File

@ -21,6 +21,7 @@ import (
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/exprutil"
"github.com/milvus-io/milvus/internal/util/function"
"github.com/milvus-io/milvus/internal/util/function/rerank"
"github.com/milvus-io/milvus/internal/util/reduce"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/metrics"
@ -50,6 +51,8 @@ const (
rangeFilterKey = "range_filter"
)
// type requery func(span trace.Span, ids *schemapb.IDs, outputFields []string) (*milvuspb.QueryResults, error)
type searchTask struct {
baseTask
Condition
@ -62,7 +65,7 @@ type searchTask struct {
tr *timerecord.TimeRecorder
collectionName string
schema *schemaInfo
requery bool
needRequery bool
partitionKeyMode bool
enableMaterializedView bool
mustUsePartitionKey bool
@ -85,14 +88,21 @@ type searchTask struct {
queryInfos []*planpb.QueryInfo
relatedDataSize int64
// Will be deprecated, use functionScore after milvus 2.6
reScorers []reScorer
rankParams *rankParams
groupScorer func(group *Group) error
// New reranker functions
functionScore *rerank.FunctionScore
rankParams *rankParams
isIterator bool
// we always remove pk field from output fields, as search result already contains pk field.
// if the user explicitly set pk field in output fields, we add it back to the result.
userRequestedPkFieldExplicitly bool
// To facilitate writing unit tests
requeryFunc func(t *searchTask, span trace.Span, ids *schemapb.IDs, outputFields []string) (*milvuspb.QueryResults, error)
}
func (t *searchTask) CanSkipAllocTimestamp() bool {
@ -126,6 +136,7 @@ func (t *searchTask) CanSkipAllocTimestamp() bool {
func (t *searchTask) PreExecute(ctx context.Context) error {
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-Search-PreExecute")
defer sp.End()
t.SearchRequest.IsAdvanced = len(t.request.GetSubReqs()) > 0
t.Base.MsgType = commonpb.MsgType_Search
t.Base.SourceID = paramtable.GetNodeID()
@ -181,16 +192,7 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
return errors.New(fmt.Sprintf("maximum of ann search requests is %d", defaultMaxSearchRequest))
}
}
if t.SearchRequest.GetIsAdvanced() {
t.rankParams, err = parseRankParams(t.request.GetSearchParams(), t.schema.CollectionSchema)
if err != nil {
log.Info("parseRankParams failed", zap.Error(err))
return err
}
} else {
t.rankParams = nil
}
// Manually update nq if not set.
nq, err := t.checkNq(ctx)
if err != nil {
log.Info("failed to check nq", zap.Error(err))
@ -211,15 +213,9 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
// Currently, we get vectors by requery. Once we support getting vectors from search,
// searches with small result size could no longer need requery.
vectorOutputFields := lo.Filter(t.schema.GetFields(), func(field *schemapb.FieldSchema, _ int) bool {
return lo.Contains(t.translatedOutputFields, field.GetName()) && typeutil.IsVectorType(field.GetDataType())
})
if t.SearchRequest.GetIsAdvanced() {
t.requery = len(t.translatedOutputFields) > 0
err = t.initAdvancedSearchRequest(ctx)
} else {
t.requery = len(vectorOutputFields) > 0
err = t.initSearchRequest(ctx)
}
if err != nil {
@ -363,8 +359,43 @@ func (t *searchTask) initAdvancedSearchRequest(ctx context.Context) error {
defer sp.End()
t.partitionIDsSet = typeutil.NewConcurrentSet[UniqueID]()
log := log.Ctx(ctx).With(zap.Int64("collID", t.GetCollectionID()), zap.String("collName", t.collectionName))
var err error
// TODO: Use function score uniformly to implement related logic
if t.request.FunctionScore != nil {
if t.functionScore, err = rerank.NewFunctionScore(t.schema.CollectionSchema, t.request.FunctionScore); err != nil {
log.Warn("Failed to create function score", zap.Error(err))
return err
}
} else {
t.reScorers, err = NewReScorers(ctx, len(t.request.GetSubReqs()), t.request.GetSearchParams())
if err != nil {
log.Info("generate reScorer failed", zap.Any("params", t.request.GetSearchParams()), zap.Error(err))
return err
}
// set up groupScorer for hybridsearch+groupBy
groupScorerStr, err := funcutil.GetAttrByKeyFromRepeatedKV(RankGroupScorer, t.request.GetSearchParams())
if err != nil {
groupScorerStr = MaxScorer
}
groupScorer, err := GetGroupScorer(groupScorerStr)
if err != nil {
return err
}
t.groupScorer = groupScorer
}
t.needRequery = len(t.request.OutputFields) > 0 || len(t.functionScore.GetAllInputFieldNames()) > 0
if t.rankParams, err = parseRankParams(t.request.GetSearchParams(), t.schema.CollectionSchema); err != nil {
log.Error("parseRankParams failed", zap.Error(err))
return err
}
if !t.functionScore.IsSupportGroup() && t.rankParams.GetGroupByFieldId() >= 0 {
return merr.WrapErrParameterInvalidMsg("Current rerank does not support grouping search")
}
// fetch search_growing from search param
t.SearchRequest.SubReqs = make([]*internalpb.SubSearchRequest, len(t.request.GetSubReqs()))
t.queryInfos = make([]*planpb.QueryInfo, len(t.request.GetSubReqs()))
queryFieldIDs := []int64{}
@ -418,13 +449,8 @@ func (t *searchTask) initAdvancedSearchRequest(ctx context.Context) error {
internalSubReq.PartitionIDs = t.SearchRequest.GetPartitionIDs()
}
if t.requery {
plan.OutputFieldIds = nil
plan.DynamicFields = nil
} else {
plan.OutputFieldIds = t.SearchRequest.OutputFieldsId
plan.DynamicFields = t.userDynamicFields
}
plan.OutputFieldIds = nil
plan.DynamicFields = nil
internalSubReq.SerializedExprPlan, err = proto.Marshal(plan)
if err != nil {
@ -440,7 +466,6 @@ func (t *searchTask) initAdvancedSearchRequest(ctx context.Context) error {
zap.Stringer("plan", plan)) // may be very large if large term passed.
}
var err error
if function.HasNonBM25Functions(t.schema.CollectionSchema.Functions, queryFieldIDs) {
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-AdvancedSearch-call-function-udf")
defer sp.End()
@ -463,38 +488,244 @@ func (t *searchTask) initAdvancedSearchRequest(ctx context.Context) error {
t.SearchRequest.PartitionIDs = t.partitionIDsSet.Collect()
}
t.reScorers, err = NewReScorers(ctx, len(t.request.GetSubReqs()), t.request.GetSearchParams())
return nil
}
func (t *searchTask) advancedPostProcess(ctx context.Context, span trace.Span, toReduceResults []*internalpb.SearchResults) error {
// Collecting the results of a subsearch
// [[shard1, shard2, ...],[shard1, shard2, ...]]
multipleInternalResults := make([][]*internalpb.SearchResults, len(t.SearchRequest.GetSubReqs()))
for _, searchResult := range toReduceResults {
// if get a non-advanced result, skip all
if !searchResult.GetIsAdvanced() {
continue
}
for _, subResult := range searchResult.GetSubResults() {
// swallow copy
internalResults := &internalpb.SearchResults{
MetricType: subResult.GetMetricType(),
NumQueries: subResult.GetNumQueries(),
TopK: subResult.GetTopK(),
SlicedBlob: subResult.GetSlicedBlob(),
SlicedNumCount: subResult.GetSlicedNumCount(),
SlicedOffset: subResult.GetSlicedOffset(),
IsAdvanced: false,
}
reqIndex := subResult.GetReqIndex()
multipleInternalResults[reqIndex] = append(multipleInternalResults[reqIndex], internalResults)
}
}
multipleMilvusResults := make([]*milvuspb.SearchResults, len(t.SearchRequest.GetSubReqs()))
for index, internalResults := range multipleInternalResults {
subReq := t.SearchRequest.GetSubReqs()[index]
// Since the metrictype in the request may be empty, it can only be obtained from the result
subMetricType := getMetricType(internalResults)
result, err := t.reduceResults(t.ctx, internalResults, subReq.GetNq(), subReq.GetTopk(), subReq.GetOffset(), subMetricType, t.queryInfos[index], true)
if err != nil {
return err
}
if t.functionScore == nil {
t.reScorers[index].setMetricType(subMetricType)
t.reScorers[index].reScore(result)
}
multipleMilvusResults[index] = result
}
if t.functionScore == nil {
if err := t.rank(ctx, span, multipleMilvusResults); err != nil {
return err
}
} else {
if err := t.hybridSearchRank(ctx, span, multipleMilvusResults); err != nil {
return err
}
}
t.result.Results.FieldsData = lo.Filter(t.result.Results.FieldsData, func(fieldData *schemapb.FieldData, i int) bool {
return lo.Contains(t.translatedOutputFields, fieldData.GetFieldName())
})
t.fillResult()
return nil
}
func (t *searchTask) fillResult() {
limit := t.SearchRequest.GetTopk() - t.SearchRequest.GetOffset()
resultSizeInsufficient := false
for _, topk := range t.result.Results.Topks {
if topk < limit {
resultSizeInsufficient = true
break
}
}
t.resultSizeInsufficient = resultSizeInsufficient
t.result.CollectionName = t.collectionName
t.fillInFieldInfo()
}
// TODO: Old version rerank: rrf/weighted, subsequent unified rerank implementation
func (t *searchTask) rank(ctx context.Context, span trace.Span, multipleMilvusResults []*milvuspb.SearchResults) error {
primaryFieldSchema, err := t.schema.GetPkField()
if err != nil {
log.Info("generate reScorer failed", zap.Any("params", t.request.GetSearchParams()), zap.Error(err))
log.Warn("failed to get primary field schema", zap.Error(err))
return err
}
if t.result, err = rankSearchResultData(ctx, t.SearchRequest.GetNq(),
t.rankParams,
primaryFieldSchema.GetDataType(),
multipleMilvusResults,
t.SearchRequest.GetGroupByFieldId(),
t.SearchRequest.GetGroupSize(),
t.groupScorer); err != nil {
log.Warn("rank search result failed", zap.Error(err))
return err
}
// set up groupScorer for hybridsearch+groupBy
groupScorerStr, err := funcutil.GetAttrByKeyFromRepeatedKV(RankGroupScorer, t.request.GetSearchParams())
if err != nil {
groupScorerStr = MaxScorer
if t.needRequery {
if t.requeryFunc == nil {
t.requeryFunc = requeryImpl
}
queryResult, err := t.requeryFunc(t, span, t.result.Results.Ids, t.translatedOutputFields)
if err != nil {
log.Warn("failed to requery", zap.Error(err))
return err
}
fields, err := t.reorganizeRequeryResults(ctx, queryResult, []*schemapb.IDs{t.result.Results.Ids})
if err != nil {
return err
}
t.result.Results.FieldsData = fields[0]
}
groupScorer, err := GetGroupScorer(groupScorerStr)
if err != nil {
return err
}
t.groupScorer = groupScorer
return nil
}
func mergeIDs(idsList []*schemapb.IDs) (*schemapb.IDs, int) {
uniqueIDs := &schemapb.IDs{}
count := 0
switch idsList[0].GetIdField().(type) {
case *schemapb.IDs_IntId:
idsSet := typeutil.NewSet[int64]()
for _, ids := range idsList {
if data := ids.GetIntId().GetData(); data != nil {
idsSet.Insert(data...)
}
}
count = idsSet.Len()
uniqueIDs.IdField = &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: idsSet.Collect(),
},
}
case *schemapb.IDs_StrId:
idsSet := typeutil.NewSet[string]()
for _, ids := range idsList {
if data := ids.GetStrId().GetData(); data != nil {
idsSet.Insert(data...)
}
}
count = idsSet.Len()
uniqueIDs.IdField = &schemapb.IDs_StrId{
StrId: &schemapb.StringArray{
Data: idsSet.Collect(),
},
}
}
return uniqueIDs, count
}
func (t *searchTask) hybridSearchRank(ctx context.Context, span trace.Span, multipleMilvusResults []*milvuspb.SearchResults) error {
var err error
// The first step of hybrid search is without meta information. If rerank requires meta data, we need to do requery.
// At this time, outputFields and rerank input_fields will be recalled.
// If we want to save memory, we can only recall the rerank input_fields in this step, and recall the output_fields in the third step
if t.needRequery {
idsList := lo.FilterMap(multipleMilvusResults, func(m *milvuspb.SearchResults, _ int) (*schemapb.IDs, bool) {
return m.Results.Ids, true
})
allIDs, count := mergeIDs(idsList)
if count == 0 {
t.result = &milvuspb.SearchResults{
Status: merr.Success(),
Results: &schemapb.SearchResultData{
NumQueries: t.Nq,
TopK: t.rankParams.limit,
FieldsData: make([]*schemapb.FieldData, 0),
Scores: []float32{},
Ids: &schemapb.IDs{},
Topks: []int64{},
},
}
return nil
}
allNames := typeutil.NewSet[string](t.translatedOutputFields...)
allNames.Insert(t.functionScore.GetAllInputFieldNames()...)
queryResult, err := t.requeryFunc(t, span, allIDs, allNames.Collect())
if err != nil {
log.Warn("failed to requery", zap.Error(err))
return err
}
fields, err := t.reorganizeRequeryResults(ctx, queryResult, idsList)
if err != nil {
return err
}
for i := 0; i < len(multipleMilvusResults); i++ {
multipleMilvusResults[i].Results.FieldsData = fields[i]
}
params := rerank.NewSearchParams(
t.Nq, t.rankParams.limit, t.rankParams.offset, t.rankParams.roundDecimal,
t.rankParams.groupByFieldId, t.rankParams.groupSize, t.rankParams.strictGroupSize,
)
if t.result, err = t.functionScore.Process(ctx, params, multipleMilvusResults); err != nil {
return err
}
if fields, err := t.reorganizeRequeryResults(ctx, queryResult, []*schemapb.IDs{t.result.Results.Ids}); err != nil {
return err
} else {
t.result.Results.FieldsData = fields[0]
}
} else {
params := rerank.NewSearchParams(
t.Nq, t.rankParams.limit, t.rankParams.offset, t.rankParams.roundDecimal,
t.rankParams.groupByFieldId, t.rankParams.groupSize, t.rankParams.strictGroupSize,
)
if t.result, err = t.functionScore.Process(ctx, params, multipleMilvusResults); err != nil {
return err
}
}
return nil
}
func (t *searchTask) initSearchRequest(ctx context.Context) error {
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "init search request")
defer sp.End()
log := log.Ctx(ctx).With(zap.Int64("collID", t.GetCollectionID()), zap.String("collName", t.collectionName))
// fetch search_growing from search param
plan, queryInfo, offset, isIterator, err := t.tryGeneratePlan(t.request.GetSearchParams(), t.request.GetDsl(), t.request.GetExprTemplateValues())
if err != nil {
return err
}
if t.request.FunctionScore != nil {
// TODO: When rerank is configured, range search is also supported
if isIterator {
return merr.WrapErrParameterInvalidMsg("Range search do not support rerank")
}
if t.functionScore, err = rerank.NewFunctionScore(t.schema.CollectionSchema, t.request.FunctionScore); err != nil {
log.Warn("Failed to create function score", zap.Error(err))
return err
}
// TODO: When rerank is configured, grouping search is also supported
if !t.functionScore.IsSupportGroup() && queryInfo.GetGroupByFieldId() > 0 {
return merr.WrapErrParameterInvalidMsg("Current rerank does not support grouping search")
}
}
t.isIterator = isIterator
t.SearchRequest.Offset = offset
t.SearchRequest.FieldId = queryInfo.GetQueryFieldId()
@ -514,10 +745,16 @@ func (t *searchTask) initSearchRequest(ctx context.Context) error {
}
}
if t.requery {
plan.OutputFieldIds = nil
vectorOutputFields := lo.Filter(t.schema.GetFields(), func(field *schemapb.FieldSchema, _ int) bool {
return lo.Contains(t.translatedOutputFields, field.GetName()) && typeutil.IsVectorType(field.GetDataType())
})
t.needRequery = len(vectorOutputFields) > 0
if t.needRequery {
plan.OutputFieldIds = t.functionScore.GetAllInputFieldIDs()
} else {
plan.OutputFieldIds = t.SearchRequest.OutputFieldsId
allFieldIDs := typeutil.NewSet[int64](t.SearchRequest.OutputFieldsId...)
allFieldIDs.Insert(t.functionScore.GetAllInputFieldIDs()...)
plan.OutputFieldIds = allFieldIDs.Collect()
plan.DynamicFields = t.userDynamicFields
}
@ -556,6 +793,7 @@ func (t *searchTask) initSearchRequest(ctx context.Context) error {
}
sp.AddEvent("Call-function-udf")
}
log.Debug("proxy init search request",
zap.Int64s("plan.OutputFieldIds", plan.GetOutputFieldIds()),
zap.Stringer("plan", plan)) // may be very large if large term passed.
@ -563,6 +801,45 @@ func (t *searchTask) initSearchRequest(ctx context.Context) error {
return nil
}
func (t *searchTask) searchPostProcess(ctx context.Context, span trace.Span, toReduceResults []*internalpb.SearchResults) error {
metricType := getMetricType(toReduceResults)
result, err := t.reduceResults(t.ctx, toReduceResults, t.SearchRequest.GetNq(), t.SearchRequest.GetTopk(), t.SearchRequest.GetOffset(), metricType, t.queryInfos[0], false)
if err != nil {
return err
}
if t.functionScore != nil && len(result.Results.FieldsData) != 0 {
params := rerank.NewSearchParams(t.Nq, t.SearchRequest.GetTopk(), t.SearchRequest.GetOffset(),
t.queryInfos[0].RoundDecimal, t.queryInfos[0].GroupByFieldId, t.queryInfos[0].GroupSize, t.queryInfos[0].StrictGroupSize)
if t.result, err = t.functionScore.Process(ctx, params, []*milvuspb.SearchResults{result}); err != nil {
return err
}
} else {
t.result = result
}
t.fillResult()
if t.needRequery {
if t.requeryFunc == nil {
t.requeryFunc = requeryImpl
}
queryResult, err := t.requeryFunc(t, span, t.result.Results.Ids, t.translatedOutputFields)
if err != nil {
log.Warn("failed to requery", zap.Error(err))
return err
}
fields, err := t.reorganizeRequeryResults(ctx, queryResult, []*schemapb.IDs{t.result.Results.Ids})
if err != nil {
return err
}
t.result.Results.FieldsData = fields[0]
}
t.result.Results.FieldsData = lo.Filter(t.result.Results.FieldsData, func(fieldData *schemapb.FieldData, i int) bool {
return lo.Contains(t.translatedOutputFields, fieldData.GetFieldName())
})
return nil
}
func (t *searchTask) tryGeneratePlan(params []*commonpb.KeyValuePair, dsl string, exprTemplateValues map[string]*schemapb.TemplateValue) (*planpb.PlanNode, *planpb.QueryInfo, int64, bool, error) {
annsFieldName, err := funcutil.GetAttrByKeyFromRepeatedKV(AnnsFieldKey, params)
if err != nil || len(annsFieldName) == 0 {
@ -753,90 +1030,22 @@ func (t *searchTask) PostExecute(ctx context.Context) error {
}
}
primaryFieldSchema, err := t.schema.GetPkField()
if err != nil {
log.Warn("failed to get primary field schema", zap.Error(err))
return err
}
metricType := getMetricType(toReduceResults)
// reduce
if t.SearchRequest.GetIsAdvanced() {
multipleInternalResults := make([][]*internalpb.SearchResults, len(t.SearchRequest.GetSubReqs()))
for _, searchResult := range toReduceResults {
// if get a non-advanced result, skip all
if !searchResult.GetIsAdvanced() {
continue
}
for _, subResult := range searchResult.GetSubResults() {
// swallow copy
internalResults := &internalpb.SearchResults{
MetricType: subResult.GetMetricType(),
NumQueries: subResult.GetNumQueries(),
TopK: subResult.GetTopK(),
SlicedBlob: subResult.GetSlicedBlob(),
SlicedNumCount: subResult.GetSlicedNumCount(),
SlicedOffset: subResult.GetSlicedOffset(),
IsAdvanced: false,
}
reqIndex := subResult.GetReqIndex()
multipleInternalResults[reqIndex] = append(multipleInternalResults[reqIndex], internalResults)
}
}
multipleMilvusResults := make([]*milvuspb.SearchResults, len(t.SearchRequest.GetSubReqs()))
for index, internalResults := range multipleInternalResults {
subReq := t.SearchRequest.GetSubReqs()[index]
subMetricType := getMetricType(internalResults)
result, err := t.reduceResults(t.ctx, internalResults, subReq.GetNq(), subReq.GetTopk(), subReq.GetOffset(), subMetricType, t.queryInfos[index], true)
if err != nil {
return err
}
t.reScorers[index].setMetricType(subMetricType)
t.reScorers[index].reScore(result)
multipleMilvusResults[index] = result
}
t.result, err = rankSearchResultData(ctx, t.SearchRequest.GetNq(),
t.rankParams,
primaryFieldSchema.GetDataType(),
multipleMilvusResults,
t.SearchRequest.GetGroupByFieldId(),
t.SearchRequest.GetGroupSize(),
t.groupScorer)
if err != nil {
log.Warn("rank search result failed", zap.Error(err))
return err
}
} else {
t.result, err = t.reduceResults(t.ctx, toReduceResults, t.SearchRequest.GetNq(), t.SearchRequest.GetTopk(), t.SearchRequest.GetOffset(), metricType, t.queryInfos[0], false)
if err != nil {
return err
}
}
// reduce done, get final result
limit := t.SearchRequest.GetTopk() - t.SearchRequest.GetOffset()
resultSizeInsufficient := false
for _, topk := range t.result.Results.Topks {
if topk < limit {
resultSizeInsufficient = true
break
}
}
t.resultSizeInsufficient = resultSizeInsufficient
t.isTopkReduce = isTopkReduce
t.isRecallEvaluation = isRecallEvaluation
t.result.CollectionName = t.collectionName
t.fillInFieldInfo()
if t.requery {
err = t.Requery(sp)
if err != nil {
log.Warn("failed to requery", zap.Error(err))
return err
}
if t.SearchRequest.GetIsAdvanced() {
err = t.advancedPostProcess(ctx, sp, toReduceResults)
} else {
err = t.searchPostProcess(ctx, sp, toReduceResults)
}
if err != nil {
return err
}
t.result.Results.OutputFields = t.userOutputFields
t.result.CollectionName = t.request.GetCollectionName()
primaryFieldSchema, _ := t.schema.GetPkField()
if t.userRequestedPkFieldExplicitly {
t.result.Results.OutputFields = append(t.result.Results.OutputFields, primaryFieldSchema.GetName())
var scalars *schemapb.ScalarField
@ -869,7 +1078,7 @@ func (t *searchTask) PostExecute(ctx context.Context) error {
if iterInfo := t.queryInfos[0].GetSearchIteratorV2Info(); iterInfo != nil {
t.result.Results.SearchIteratorV2Results = &schemapb.SearchIteratorV2Results{
Token: iterInfo.GetToken(),
LastBound: getLastBound(t.result, iterInfo.LastBound, metricType),
LastBound: getLastBound(t.result, iterInfo.LastBound, getMetricType(toReduceResults)),
}
}
}
@ -950,7 +1159,7 @@ func (t *searchTask) estimateResultSize(nq int64, topK int64) (int64, error) {
//return int64(sizePerRecord) * nq * topK, nil
}
func (t *searchTask) Requery(span trace.Span) error {
func requeryImpl(t *searchTask, span trace.Span, ids *schemapb.IDs, outputFields []string) (*milvuspb.QueryResults, error) {
queryReq := &milvuspb.QueryRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Retrieve,
@ -961,16 +1170,16 @@ func (t *searchTask) Requery(span trace.Span) error {
ConsistencyLevel: t.SearchRequest.GetConsistencyLevel(),
NotReturnAllMeta: t.request.GetNotReturnAllMeta(),
Expr: "",
OutputFields: t.translatedOutputFields,
OutputFields: outputFields,
PartitionNames: t.request.GetPartitionNames(),
UseDefaultConsistency: false,
GuaranteeTimestamp: t.SearchRequest.GuaranteeTimestamp,
}
pkField, err := typeutil.GetPrimaryFieldSchema(t.schema.CollectionSchema)
if err != nil {
return err
return nil, err
}
ids := t.result.GetResults().GetIds()
plan := planparserv2.CreateRequeryPlan(pkField, ids)
channelsMvcc := make(map[string]Timestamp)
for k, v := range t.queryChannelsTs {
@ -998,11 +1207,49 @@ func (t *searchTask) Requery(span trace.Span) error {
}
queryResult, err := t.node.(*Proxy).query(t.ctx, qt, span)
if err != nil {
return err
return nil, err
}
if queryResult.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
return merr.Error(queryResult.GetStatus())
return nil, merr.Error(queryResult.GetStatus())
}
return queryResult, err
}
func (t *searchTask) reorganizeRequeryResults(ctx context.Context, queryResult *milvuspb.QueryResults, idsList []*schemapb.IDs) ([][]*schemapb.FieldData, error) {
_, sp := otel.Tracer(typeutil.ProxyRole).Start(t.ctx, "reorganizeRequeryResults")
defer sp.End()
pkField, err := typeutil.GetPrimaryFieldSchema(t.schema.CollectionSchema)
if err != nil {
return nil, err
}
pkFieldData, err := typeutil.GetPrimaryFieldData(queryResult.GetFieldsData(), pkField)
if err != nil {
return nil, err
}
offsets := make(map[any]int)
for i := 0; i < typeutil.GetPKSize(pkFieldData); i++ {
pk := typeutil.GetData(pkFieldData, i)
offsets[pk] = i
}
allFieldData := make([][]*schemapb.FieldData, len(idsList))
for idx, ids := range idsList {
if ids == nil {
allFieldData[idx] = []*schemapb.FieldData{}
continue
}
if fieldData, err := t.pickFieldData(ids, offsets, queryResult); err != nil {
return nil, err
} else {
allFieldData[idx] = fieldData
}
}
return allFieldData, nil
}
// pick field data from query results
func (t *searchTask) pickFieldData(ids *schemapb.IDs, pkOffset map[any]int, queryResult *milvuspb.QueryResults) ([]*schemapb.FieldData, error) {
// Reorganize Results. The order of query result ids will be altered and differ from queried ids.
// We should reorganize query results to keep the order of original queried ids. For example:
// ===========================================
@ -1018,44 +1265,26 @@ func (t *searchTask) Requery(span trace.Span) error {
// 3 2 5 4 1 (result ids)
// v3 v2 v5 v4 v1 (result vectors)
// ===========================================
_, sp := otel.Tracer(typeutil.ProxyRole).Start(t.ctx, "reorganizeRequeryResults")
defer sp.End()
pkFieldData, err := typeutil.GetPrimaryFieldData(queryResult.GetFieldsData(), pkField)
if err != nil {
return err
}
offsets := make(map[any]int)
for i := 0; i < typeutil.GetPKSize(pkFieldData); i++ {
pk := typeutil.GetData(pkFieldData, i)
offsets[pk] = i
}
t.result.Results.FieldsData = make([]*schemapb.FieldData, len(queryResult.GetFieldsData()))
fieldsData := make([]*schemapb.FieldData, len(queryResult.GetFieldsData()))
for i := 0; i < typeutil.GetSizeOfIDs(ids); i++ {
id := typeutil.GetPK(ids, int64(i))
if _, ok := offsets[id]; !ok {
return merr.WrapErrInconsistentRequery(fmt.Sprintf("incomplete query result, missing id %v, len(searchIDs) = %d, len(queryIDs) = %d, collection=%d",
id, typeutil.GetSizeOfIDs(ids), len(offsets), t.GetCollectionID()))
if _, ok := pkOffset[id]; !ok {
return nil, merr.WrapErrInconsistentRequery(fmt.Sprintf("incomplete query result, missing id %s, len(searchIDs) = %d, len(queryIDs) = %d, collection=%d",
id, typeutil.GetSizeOfIDs(ids), len(pkOffset), t.GetCollectionID()))
}
typeutil.AppendFieldData(t.result.Results.FieldsData, queryResult.GetFieldsData(), int64(offsets[id]))
typeutil.AppendFieldData(fieldsData, queryResult.GetFieldsData(), int64(pkOffset[id]))
}
t.result.Results.FieldsData = lo.Filter(t.result.Results.FieldsData, func(fieldData *schemapb.FieldData, i int) bool {
return lo.Contains(t.translatedOutputFields, fieldData.GetFieldName())
})
return nil
return fieldsData, nil
}
func (t *searchTask) fillInFieldInfo() {
if len(t.translatedOutputFields) != 0 && len(t.result.Results.FieldsData) != 0 {
for i, name := range t.translatedOutputFields {
for _, field := range t.schema.Fields {
if t.result.Results.FieldsData[i] != nil && field.Name == name {
t.result.Results.FieldsData[i].FieldName = field.Name
t.result.Results.FieldsData[i].FieldId = field.FieldID
t.result.Results.FieldsData[i].Type = field.DataType
t.result.Results.FieldsData[i].IsDynamic = field.IsDynamic
}
for _, retField := range t.result.Results.FieldsData {
for _, schemaField := range t.schema.Fields {
if retField != nil && retField.FieldId == schemaField.FieldID {
retField.FieldName = schemaField.Name
retField.Type = schemaField.DataType
retField.IsDynamic = schemaField.IsDynamic
}
}
}

View File

@ -19,6 +19,7 @@ import (
"context"
"fmt"
"math"
"slices"
"strconv"
"strings"
"testing"
@ -31,6 +32,7 @@ import (
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
"go.opentelemetry.io/otel/trace"
"google.golang.org/grpc"
"google.golang.org/protobuf/proto"
@ -50,6 +52,7 @@ import (
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/metric"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
"github.com/milvus-io/milvus/pkg/v2/util/testutils"
"github.com/milvus-io/milvus/pkg/v2/util/timerecord"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
@ -218,6 +221,375 @@ func TestSearchTask_PostExecute(t *testing.T) {
}
})
})
getSearchTaskWithRerank := func(t *testing.T, collName string, funcInput string) *searchTask {
functionSchema := &schemapb.FunctionSchema{
Name: "test",
Type: schemapb.FunctionType_Rerank,
InputFieldNames: []string{funcInput},
Params: []*commonpb.KeyValuePair{
{Key: "reranker", Value: "decay"},
{Key: "origin", Value: "4"},
{Key: "scale", Value: "4"},
{Key: "offset", Value: "4"},
{Key: "decay", Value: "0.5"},
{Key: "function", Value: "gauss"},
},
}
task := &searchTask{
ctx: ctx,
collectionName: collName,
SearchRequest: &internalpb.SearchRequest{
IsTopkReduce: true,
},
request: &milvuspb.SearchRequest{
CollectionName: collName,
Nq: 1,
SearchParams: getBaseSearchParams(),
FunctionScore: &schemapb.FunctionScore{
Functions: []*schemapb.FunctionSchema{functionSchema},
},
},
mixCoord: qc,
tr: timerecord.NewTimeRecorder("test-search"),
}
require.NoError(t, task.OnEnqueue())
return task
}
t.Run("Test empty result with rerank", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
collName := "test_collection_empty_result_with_rerank" + funcutil.GenRandomStr()
createCollWithFields(t, collName, qc)
qt := getSearchTaskWithRerank(t, collName, testFloatField)
err = qt.PreExecute(ctx)
assert.NoError(t, err)
assert.NotNil(t, qt.resultBuf)
qt.resultBuf.Insert(&internalpb.SearchResults{})
err := qt.PostExecute(context.TODO())
assert.NoError(t, err)
assert.Equal(t, qt.resultSizeInsufficient, true)
assert.Equal(t, qt.isTopkReduce, false)
})
t.Run("Test search rerank", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
collName := "test_collection_empty_result_with_rerank" + funcutil.GenRandomStr()
_, fieldNameId := createCollWithFields(t, collName, qc)
qt := getSearchTaskWithRerank(t, collName, testFloatField)
err = qt.PreExecute(ctx)
assert.NoError(t, err)
assert.NotNil(t, qt.resultBuf)
qt.resultBuf.Insert(genTestSearchResultData(1, 10, schemapb.DataType_Float, testFloatField, fieldNameId[testFloatField], false))
err := qt.PostExecute(context.TODO())
assert.NoError(t, err)
assert.Equal(t, []int64{10}, qt.result.Results.Topks)
assert.Equal(t, int64(10), qt.result.Results.TopK)
assert.Equal(t, []int64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}, qt.result.Results.Ids.GetIntId().Data)
})
getHybridSearchTaskWithRerank := func(t *testing.T, collName string, funcInput string, data [][]string) *searchTask {
subReqs := []*milvuspb.SubSearchRequest{}
for _, item := range data {
placeholderValue := &commonpb.PlaceholderValue{
Tag: "$0",
Type: commonpb.PlaceholderType_VarChar,
Values: lo.Map(item, func(str string, _ int) []byte { return []byte(str) }),
}
holder := &commonpb.PlaceholderGroup{
Placeholders: []*commonpb.PlaceholderValue{placeholderValue},
}
holderByte, _ := proto.Marshal(holder)
subReq := &milvuspb.SubSearchRequest{
PlaceholderGroup: holderByte,
SearchParams: []*commonpb.KeyValuePair{
{Key: AnnsFieldKey, Value: testFloatVecField},
{Key: TopKKey, Value: "10"},
},
Nq: int64(len(item)),
}
subReqs = append(subReqs, subReq)
}
functionSchema := &schemapb.FunctionSchema{
Name: "test",
Type: schemapb.FunctionType_Rerank,
InputFieldNames: []string{funcInput},
Params: []*commonpb.KeyValuePair{
{Key: "reranker", Value: "decay"},
{Key: "origin", Value: "4"},
{Key: "scale", Value: "4"},
{Key: "offset", Value: "4"},
{Key: "decay", Value: "0.5"},
{Key: "function", Value: "gauss"},
},
}
task := &searchTask{
ctx: ctx,
collectionName: collName,
SearchRequest: &internalpb.SearchRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Search,
Timestamp: uint64(time.Now().UnixNano()),
},
},
request: &milvuspb.SearchRequest{
CollectionName: collName,
SubReqs: subReqs,
SearchParams: []*commonpb.KeyValuePair{
{Key: LimitKey, Value: "10"},
},
FunctionScore: &schemapb.FunctionScore{
Functions: []*schemapb.FunctionSchema{functionSchema},
},
OutputFields: []string{testInt32Field},
},
mixCoord: qc,
tr: timerecord.NewTimeRecorder("test-search"),
}
require.NoError(t, task.OnEnqueue())
return task
}
t.Run("Test hybridsearch all empty result with rerank", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
collName := "test_collection_empty_result_with_rerank" + funcutil.GenRandomStr()
createCollWithFields(t, collName, qc)
qt := getHybridSearchTaskWithRerank(t, collName, testFloatField, [][]string{{"sentence"}, {"sentence"}})
err = qt.PreExecute(ctx)
assert.NoError(t, err)
assert.NotNil(t, qt.resultBuf)
qt.resultBuf.Insert(&internalpb.SearchResults{})
qt.resultBuf.Insert(&internalpb.SearchResults{})
err := qt.PostExecute(context.TODO())
assert.NoError(t, err)
})
t.Run("Test hybridsearch search rerank with empty", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
collName := "test_hybridsearch_rerank_with_empty" + funcutil.GenRandomStr()
_, fieldNameId := createCollWithFields(t, collName, qc)
qt := getHybridSearchTaskWithRerank(t, collName, testFloatField, [][]string{{"sentence", "sentence"}, {"sentence", "sentence"}})
err = qt.PreExecute(ctx)
assert.Equal(t, qt.Nq, int64(2))
assert.NoError(t, err)
assert.NotNil(t, qt.resultBuf)
// All data are from the same subsearch
qt.resultBuf.Insert(genTestSearchResultData(2, 10, schemapb.DataType_Int64, testInt64Field, fieldNameId[testInt64Field], true))
qt.resultBuf.Insert(genTestSearchResultData(2, 10, schemapb.DataType_Int64, testInt64Field, fieldNameId[testInt64Field], true))
// rerank inputs
f1 := testutils.GenerateScalarFieldData(schemapb.DataType_Float, testFloatField, 20)
f1.FieldId = fieldNameId[testFloatField]
// search output field
f2 := testutils.GenerateScalarFieldData(schemapb.DataType_Int32, testInt32Field, 20)
f2.FieldId = fieldNameId[testInt32Field]
// pk
f3 := testutils.GenerateScalarFieldData(schemapb.DataType_Int64, testInt64Field, 20)
f3.FieldId = fieldNameId[testInt64Field]
qt.requeryFunc = func(t *searchTask, span trace.Span, ids *schemapb.IDs, outputFields []string) (*milvuspb.QueryResults, error) {
return &milvuspb.QueryResults{
FieldsData: []*schemapb.FieldData{f1, f2, f3},
PrimaryFieldName: testInt64Field,
}, nil
}
err := qt.PostExecute(context.TODO())
assert.NoError(t, err)
assert.Equal(t, []int64{10, 10}, qt.result.Results.Topks)
assert.Equal(t, int64(10), qt.result.Results.TopK)
assert.Equal(t, int64(2), qt.result.Results.NumQueries)
assert.Equal(t, []int64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19}, qt.result.Results.Ids.GetIntId().Data)
assert.Equal(t, testInt32Field, qt.result.Results.FieldsData[0].FieldName)
})
t.Run("Test hybridsearch search rerank ", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
collName := "test_hybridsearch_result_with_rerank" + funcutil.GenRandomStr()
_, fieldNameId := createCollWithFields(t, collName, qc)
qt := getHybridSearchTaskWithRerank(t, collName, testFloatField, [][]string{{"sentence", "sentence"}, {"sentence", "sentence"}})
err = qt.PreExecute(ctx)
assert.Equal(t, qt.Nq, int64(2))
assert.NoError(t, err)
assert.NotNil(t, qt.resultBuf)
data1 := genTestSearchResultData(2, 10, schemapb.DataType_Int64, testInt64Field, fieldNameId[testInt64Field], true)
data1.SubResults[0].ReqIndex = 0
data2 := genTestSearchResultData(2, 10, schemapb.DataType_Int64, testInt64Field, fieldNameId[testInt64Field], true)
data1.SubResults[0].ReqIndex = 2
qt.resultBuf.Insert(data2)
// rerank inputs
f1 := testutils.GenerateScalarFieldData(schemapb.DataType_Float, testFloatField, 20)
f1.FieldId = fieldNameId[testFloatField]
// search output field
f2 := testutils.GenerateScalarFieldData(schemapb.DataType_Int32, testInt32Field, 20)
f2.FieldId = fieldNameId[testInt32Field]
// pk
f3 := testutils.GenerateScalarFieldData(schemapb.DataType_Int64, testInt64Field, 20)
f3.FieldId = fieldNameId[testInt64Field]
qt.requeryFunc = func(t *searchTask, span trace.Span, ids *schemapb.IDs, outputFields []string) (*milvuspb.QueryResults, error) {
return &milvuspb.QueryResults{
FieldsData: []*schemapb.FieldData{f1, f2, f3},
PrimaryFieldName: testInt64Field,
}, nil
}
err := qt.PostExecute(context.TODO())
assert.NoError(t, err)
assert.Equal(t, []int64{10, 10}, qt.result.Results.Topks)
assert.Equal(t, int64(10), qt.result.Results.TopK)
assert.Equal(t, int64(2), qt.result.Results.NumQueries)
assert.Equal(t, []int64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19}, qt.result.Results.Ids.GetIntId().Data)
assert.Equal(t, testInt32Field, qt.result.Results.FieldsData[0].FieldName)
})
// rrf/weigted rank
t.Run("Test rank function", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
collName := "test_rank_function" + funcutil.GenRandomStr()
_, fieldNameId := createCollWithFields(t, collName, qc)
qt := getHybridSearchTaskWithRerank(t, collName, testFloatField, [][]string{{"sentence", "sentence"}, {"sentence", "sentence"}})
qt.request.FunctionScore = nil
qt.request.SearchParams = []*commonpb.KeyValuePair{{Key: "limit", Value: "10"}}
qt.request.OutputFields = []string{"*"}
err = qt.PreExecute(ctx)
assert.NoError(t, err)
data1 := genTestSearchResultData(2, 10, schemapb.DataType_Int64, testInt64Field, fieldNameId[testInt64Field], true)
data1.SubResults[0].ReqIndex = 0
data2 := genTestSearchResultData(2, 10, schemapb.DataType_Int64, testInt64Field, fieldNameId[testInt64Field], true)
data1.SubResults[0].ReqIndex = 2
qt.resultBuf.Insert(data2)
f1 := testutils.GenerateScalarFieldData(schemapb.DataType_Int32, testInt32Field, 20)
f1.FieldId = fieldNameId[testInt32Field]
f2 := testutils.GenerateVectorFieldData(schemapb.DataType_FloatVector, testFloatVecField, 20, testVecDim)
f2.FieldId = fieldNameId[testFloatVecField]
f3 := testutils.GenerateScalarFieldData(schemapb.DataType_Int64, testInt64Field, 20)
f3.FieldId = fieldNameId[testInt64Field]
f4 := testutils.GenerateScalarFieldData(schemapb.DataType_Float, testFloatField, 20)
f4.FieldId = fieldNameId[testFloatField]
qt.requeryFunc = func(t *searchTask, span trace.Span, ids *schemapb.IDs, outputFields []string) (*milvuspb.QueryResults, error) {
return &milvuspb.QueryResults{
FieldsData: []*schemapb.FieldData{f1, f2, f3, f4},
PrimaryFieldName: testInt64Field,
}, nil
}
err := qt.PostExecute(context.TODO())
assert.NoError(t, err)
assert.Equal(t, []int64{10, 10}, qt.result.Results.Topks)
assert.Equal(t, int64(10), qt.result.Results.TopK)
assert.Equal(t, int64(2), qt.result.Results.NumQueries)
assert.Equal(t, []int64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19}, qt.result.Results.Ids.GetIntId().Data)
for _, field := range qt.result.Results.FieldsData {
switch field.FieldName {
case testInt32Field:
assert.True(t, len(field.GetScalars().GetIntData().Data) != 0)
case testBoolField:
assert.True(t, len(field.GetScalars().GetBoolData().Data) != 0)
case testFloatField:
assert.True(t, len(field.GetScalars().GetFloatData().Data) != 0)
case testFloatVecField:
assert.True(t, len(field.GetVectors().GetFloatVector().Data) != 0)
case testInt64Field:
assert.True(t, len(field.GetScalars().GetLongData().Data) != 0)
}
}
})
t.Run("Test mergeIDs function", func(t *testing.T) {
{
ids1 := &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: []int64{1, 2, 3, 5},
},
},
}
ids2 := &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: []int64{1, 2, 4, 5, 6},
},
},
}
allIDs, count := mergeIDs([]*schemapb.IDs{ids1, ids2})
assert.Equal(t, count, 6)
sortedIds := allIDs.GetIntId().GetData()
slices.Sort(sortedIds)
assert.Equal(t, sortedIds, []int64{1, 2, 3, 4, 5, 6})
}
{
ids1 := &schemapb.IDs{
IdField: &schemapb.IDs_StrId{
StrId: &schemapb.StringArray{
Data: []string{"a", "b", "e"},
},
},
}
ids2 := &schemapb.IDs{
IdField: &schemapb.IDs_StrId{
StrId: &schemapb.StringArray{
Data: []string{"a", "b", "c", "d"},
},
},
}
allIDs, count := mergeIDs([]*schemapb.IDs{ids1, ids2})
assert.Equal(t, count, 5)
sortedIds := allIDs.GetStrId().GetData()
slices.Sort(sortedIds)
assert.Equal(t, sortedIds, []string{"a", "b", "c", "d", "e"})
}
})
}
func createCollWithFields(t *testing.T, collName string, rc types.MixCoordClient) (*schemapb.CollectionSchema, map[string]int64) {
fieldName2Types := map[string]schemapb.DataType{
testInt64Field: schemapb.DataType_Int64,
testFloatField: schemapb.DataType_Float,
testFloatVecField: schemapb.DataType_FloatVector,
testInt32Field: schemapb.DataType_Int32,
testBoolField: schemapb.DataType_Bool,
}
schema := constructCollectionSchemaByDataType(collName, fieldName2Types, testInt64Field, true)
marshaledSchema, err := proto.Marshal(schema)
assert.NoError(t, err)
ctx := context.TODO()
createColT := &createCollectionTask{
Condition: NewTaskCondition(ctx),
CreateCollectionRequest: &milvuspb.CreateCollectionRequest{
CollectionName: collName,
Schema: marshaledSchema,
ShardsNum: 1,
},
ctx: ctx,
mixCoord: rc,
}
require.NoError(t, createColT.OnEnqueue())
require.NoError(t, createColT.PreExecute(ctx))
require.NoError(t, createColT.Execute(ctx))
require.NoError(t, createColT.PostExecute(ctx))
_, err = globalMetaCache.GetCollectionID(ctx, GetCurDBNameFromContextOrDefault(ctx), collName)
assert.NoError(t, err)
fieldNameId := make(map[string]int64)
for _, field := range schema.Fields {
fieldNameId[field.Name] = field.FieldID
}
return schema, fieldNameId
}
func createColl(t *testing.T, name string, rc types.MixCoordClient) *schemapb.CollectionSchema {
@ -349,6 +721,39 @@ func TestSearchTask_PreExecute(t *testing.T) {
return task
}
getSearchTaskWithRerank := func(t *testing.T, collName string, funcInput string) *searchTask {
functionSchema := &schemapb.FunctionSchema{
Name: "test",
Type: schemapb.FunctionType_Rerank,
InputFieldNames: []string{funcInput},
Params: []*commonpb.KeyValuePair{
{Key: "reranker", Value: "decay"},
{Key: "origin", Value: "4"},
{Key: "scale", Value: "4"},
{Key: "offset", Value: "4"},
{Key: "decay", Value: "0.5"},
{Key: "function", Value: "gauss"},
},
}
task := &searchTask{
ctx: ctx,
collectionName: collName,
SearchRequest: &internalpb.SearchRequest{},
request: &milvuspb.SearchRequest{
CollectionName: collName,
Nq: 1,
SubReqs: []*milvuspb.SubSearchRequest{},
FunctionScore: &schemapb.FunctionScore{
Functions: []*schemapb.FunctionSchema{functionSchema},
},
},
mixCoord: qc,
tr: timerecord.NewTimeRecorder("test-search"),
}
require.NoError(t, task.OnEnqueue())
return task
}
t.Run("bad nq 0", func(t *testing.T) {
collName := "test_bad_nq0_error" + funcutil.GenRandomStr()
createColl(t, collName, qc)
@ -508,6 +913,64 @@ func TestSearchTask_PreExecute(t *testing.T) {
assert.NoError(t, st.PreExecute(ctx))
assert.Equal(t, collInfo.updateTimestamp, st.SearchRequest.GuaranteeTimestamp)
})
t.Run("search with rerank", func(t *testing.T) {
collName := "search_with_rerank" + funcutil.GenRandomStr()
createCollWithFields(t, collName, qc)
st := getSearchTaskWithRerank(t, collName, testFloatField)
st.request.SearchParams = getValidSearchParams()
st.request.DslType = commonpb.DslType_BoolExprV1
_, cancel := context.WithTimeout(ctx, time.Second)
defer cancel()
require.Equal(t, typeutil.ZeroTimestamp, st.TimeoutTimestamp)
enqueueTs := uint64(100000)
st.SetTs(enqueueTs)
assert.NoError(t, st.PreExecute(ctx))
assert.NotNil(t, st.functionScore)
assert.Equal(t, false, st.SearchRequest.GetIsAdvanced())
})
t.Run("advance search with rerank", func(t *testing.T) {
collName := "search_with_rerank" + funcutil.GenRandomStr()
createCollWithFields(t, collName, qc)
st := getSearchTaskWithRerank(t, collName, testFloatField)
st.request.SearchParams = getValidSearchParams()
st.request.SearchParams = append(st.request.SearchParams, &commonpb.KeyValuePair{
Key: LimitKey,
Value: "10",
})
st.request.DslType = commonpb.DslType_BoolExprV1
st.request.SubReqs = append(st.request.SubReqs, &milvuspb.SubSearchRequest{Nq: 1})
st.request.SubReqs = append(st.request.SubReqs, &milvuspb.SubSearchRequest{Nq: 1})
_, cancel := context.WithTimeout(ctx, time.Second)
defer cancel()
require.Equal(t, typeutil.ZeroTimestamp, st.TimeoutTimestamp)
enqueueTs := uint64(100000)
st.SetTs(enqueueTs)
assert.NoError(t, st.PreExecute(ctx))
assert.NotNil(t, st.functionScore)
assert.Equal(t, true, st.SearchRequest.GetIsAdvanced())
})
t.Run("search with rerank grouping", func(t *testing.T) {
collName := "search_with_rerank" + funcutil.GenRandomStr()
createCollWithFields(t, collName, qc)
st := getSearchTaskWithRerank(t, collName, testFloatField)
st.request.SearchParams = getValidSearchParams()
st.request.DslType = commonpb.DslType_BoolExprV1
st.request.SearchParams = append(st.request.SearchParams, &commonpb.KeyValuePair{
Key: GroupByFieldKey,
Value: testInt64Field,
})
_, cancel := context.WithTimeout(ctx, time.Second)
defer cancel()
require.Equal(t, typeutil.ZeroTimestamp, st.TimeoutTimestamp)
enqueueTs := uint64(100000)
st.SetTs(enqueueTs)
assert.ErrorContains(t, st.PreExecute(ctx), "Current rerank does not support grouping search")
})
}
func TestSearchTask_WithFunctions(t *testing.T) {
@ -538,6 +1001,9 @@ func TestSearchTask_WithFunctions(t *testing.T) {
{Key: "dim", Value: "4"},
},
},
{
FieldID: 104, Name: "ts", DataType: schemapb.DataType_Int64,
},
},
Functions: []*schemapb.FunctionSchema{
{
@ -584,7 +1050,7 @@ func TestSearchTask_WithFunctions(t *testing.T) {
err = InitMetaCache(ctx, qc, mgr)
require.NoError(t, err)
getSearchTask := func(t *testing.T, collName string, data []string) *searchTask {
getSearchTask := func(t *testing.T, collName string, data []string, withRerank bool) *searchTask {
placeholderValue := &commonpb.PlaceholderValue{
Tag: "$0",
Type: commonpb.PlaceholderType_VarChar,
@ -594,6 +1060,21 @@ func TestSearchTask_WithFunctions(t *testing.T) {
Placeholders: []*commonpb.PlaceholderValue{placeholderValue},
}
holderByte, _ := proto.Marshal(holder)
functionSchema := &schemapb.FunctionSchema{
Name: "test",
Type: schemapb.FunctionType_Rerank,
InputFieldNames: []string{"ts"},
Params: []*commonpb.KeyValuePair{
{Key: "reranker", Value: "decay"},
{Key: "origin", Value: "4"},
{Key: "scale", Value: "4"},
{Key: "offset", Value: "4"},
{Key: "decay", Value: "0.5"},
{Key: "function", Value: "gauss"},
},
}
task := &searchTask{
ctx: ctx,
collectionName: collectionName,
@ -615,6 +1096,11 @@ func TestSearchTask_WithFunctions(t *testing.T) {
mixCoord: qc,
tr: timerecord.NewTimeRecorder("test-search"),
}
if withRerank {
task.request.FunctionScore = &schemapb.FunctionScore{
Functions: []*schemapb.FunctionSchema{functionSchema},
}
}
require.NoError(t, task.OnEnqueue())
return task
}
@ -631,7 +1117,7 @@ func TestSearchTask_WithFunctions(t *testing.T) {
globalMetaCache = cache
{
task := getSearchTask(t, collectionName, []string{"sentence"})
task := getSearchTask(t, collectionName, []string{"sentence"}, false)
err = task.PreExecute(ctx)
assert.NoError(t, err)
pb := &commonpb.PlaceholderGroup{}
@ -642,7 +1128,7 @@ func TestSearchTask_WithFunctions(t *testing.T) {
}
{
task := getSearchTask(t, collectionName, []string{"sentence 1", "sentence 2"})
task := getSearchTask(t, collectionName, []string{"sentence 1", "sentence 2"}, false)
err = task.PreExecute(ctx)
assert.NoError(t, err)
pb := &commonpb.PlaceholderGroup{}
@ -654,7 +1140,7 @@ func TestSearchTask_WithFunctions(t *testing.T) {
// process failed
{
task := getSearchTask(t, collectionName, []string{"sentence"})
task := getSearchTask(t, collectionName, []string{"sentence"}, false)
task.request.Nq = 10000
err = task.PreExecute(ctx)
assert.Error(t, err)
@ -3167,11 +3653,11 @@ func TestSearchTask_Requery(t *testing.T) {
tr: timerecord.NewTimeRecorder("search"),
node: node,
translatedOutputFields: outputFields,
requeryFunc: requeryImpl,
}
err := qt.Requery(nil)
queryResult, err := qt.requeryFunc(qt, nil, qt.result.Results.Ids, outputFields)
assert.NoError(t, err)
assert.Len(t, qt.result.Results.FieldsData, 2)
assert.Len(t, queryResult.FieldsData, 2)
for _, field := range qt.result.Results.FieldsData {
fieldName := field.GetFieldName()
assert.Contains(t, []string{pkField, vecField}, fieldName)
@ -3192,13 +3678,14 @@ func TestSearchTask_Requery(t *testing.T) {
SourceID: paramtable.GetNodeID(),
},
},
request: &milvuspb.SearchRequest{},
schema: schema,
tr: timerecord.NewTimeRecorder("search"),
node: node,
request: &milvuspb.SearchRequest{},
schema: schema,
tr: timerecord.NewTimeRecorder("search"),
node: node,
requeryFunc: requeryImpl,
}
err := qt.Requery(nil)
_, err := qt.requeryFunc(qt, nil, &schemapb.IDs{}, []string{})
t.Logf("err = %s", err)
assert.Error(t, err)
})
@ -3227,12 +3714,13 @@ func TestSearchTask_Requery(t *testing.T) {
request: &milvuspb.SearchRequest{
CollectionName: collectionName,
},
schema: schema,
tr: timerecord.NewTimeRecorder("search"),
node: node,
schema: schema,
tr: timerecord.NewTimeRecorder("search"),
node: node,
requeryFunc: requeryImpl,
}
err := qt.Requery(nil)
_, err := qt.requeryFunc(qt, nil, &schemapb.IDs{}, []string{})
t.Logf("err = %s", err)
assert.Error(t, err)
})
@ -3274,11 +3762,11 @@ func TestSearchTask_Requery(t *testing.T) {
Ids: resultIDs,
},
},
requery: true,
schema: schema,
resultBuf: typeutil.NewConcurrentSet[*internalpb.SearchResults](),
tr: timerecord.NewTimeRecorder("search"),
node: node,
needRequery: true,
schema: schema,
resultBuf: typeutil.NewConcurrentSet[*internalpb.SearchResults](),
tr: timerecord.NewTimeRecorder("search"),
node: node,
}
scores := make([]float32, rows)
for i := range scores {
@ -3712,3 +4200,64 @@ func (s *MaterializedViewTestSuite) TestMvEnabledPartitionKeyOnVarCharWithIsolat
func TestMaterializedView(t *testing.T) {
suite.Run(t, new(MaterializedViewTestSuite))
}
func genTestSearchResultData(nq int64, topk int64, dType schemapb.DataType, fieldName string, fieldId int64, IsAdvanced bool) *internalpb.SearchResults {
result := &internalpb.SearchResults{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_SearchResult,
MsgID: 0,
Timestamp: 0,
SourceID: 0,
},
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
Reason: "",
},
MetricType: "COSINE",
NumQueries: nq,
TopK: topk,
SealedSegmentIDsSearched: nil,
ChannelIDsSearched: nil,
GlobalSealedSegmentIDs: nil,
SlicedBlob: nil,
SlicedNumCount: 1,
SlicedOffset: 0,
IsAdvanced: IsAdvanced,
}
tops := make([]int64, nq)
for i := 0; i < int(nq); i++ {
tops[i] = topk
}
resultData := &schemapb.SearchResultData{
NumQueries: nq,
TopK: topk,
Scores: testutils.GenerateFloat32Array(int(nq * topk)),
Ids: &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: testutils.GenerateInt64Array(int(nq * topk)),
},
},
},
Topks: tops,
FieldsData: []*schemapb.FieldData{
testutils.GenerateScalarFieldData(dType, fieldName, int(nq*topk)),
},
}
resultData.FieldsData[0].FieldId = fieldId
sliceBlob, _ := proto.Marshal(resultData)
if !IsAdvanced {
result.SlicedBlob = sliceBlob
} else {
result.SubResults = []*internalpb.SubSearchResults{
{
SlicedBlob: sliceBlob,
SlicedNumCount: 1,
SlicedOffset: 0,
},
}
}
return result
}

View File

@ -164,7 +164,7 @@ func parseAKAndURL(params []*commonpb.KeyValuePair, confParams map[string]string
// from env, url doesn't support configuration in in env
if apiKey == "" {
url = os.Getenv(apiKeyEnv)
apiKey = os.Getenv(apiKeyEnv)
}
return apiKey, url
}

View File

@ -0,0 +1,316 @@
/*
* # Licensed to the LF AI & Data foundation under one
* # or more contributor license agreements. See the NOTICE file
* # distributed with this work for additional information
* # regarding copyright ownership. The ASF licenses this file
* # to you under the Apache License, Version 2.0 (the
* # "License"); you may not use this file except in compliance
* # with the License. You may obtain a copy of the License at
* #
* # http://www.apache.org/licenses/LICENSE-2.0
* #
* # Unless required by applicable law or agreed to in writing, software
* # distributed under the License is distributed on an "AS IS" BASIS,
* # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* # See the License for the specific language governing permissions and
* # limitations under the License.
*/
package rerank
import (
"context"
"fmt"
"math"
"sort"
"strconv"
"strings"
"github.com/samber/lo"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
const (
originKey string = "origin"
scaleKey string = "scale"
offsetKey string = "offset"
decayKey string = "decay"
functionKey string = "function"
)
const (
gaussFunction string = "gauss"
linerFunction string = "liner"
expFunction string = "exp"
)
type DecayFunction[T int64 | string, R int32 | int64 | float32 | float64] struct {
RerankBase
functionName string
origin float64
scale float64
offset float64
decay float64
reScorer decayReScorer
}
func newDecayFunction(collSchema *schemapb.CollectionSchema, funcSchema *schemapb.FunctionSchema) (Reranker, error) {
pkType := schemapb.DataType_None
for _, field := range collSchema.Fields {
if field.IsPrimaryKey {
pkType = field.DataType
}
}
if pkType == schemapb.DataType_None {
return nil, fmt.Errorf("Collection %s can not found pk field", collSchema.Name)
}
base, err := newRerankBase(collSchema, funcSchema, decayFunctionName, false, pkType)
if err != nil {
return nil, err
}
if len(base.GetInputFieldNames()) != 1 {
return nil, fmt.Errorf("Decay function only supoorts single input, but gets [%s] input", base.GetInputFieldNames())
}
var inputType schemapb.DataType
for _, field := range collSchema.Fields {
if field.Name == base.GetInputFieldNames()[0] {
inputType = field.DataType
}
}
if pkType == schemapb.DataType_Int64 {
switch inputType {
case schemapb.DataType_Int8, schemapb.DataType_Int16, schemapb.DataType_Int32:
return newFunction[int64, int32](base, funcSchema)
case schemapb.DataType_Int64:
return newFunction[int64, int64](base, funcSchema)
case schemapb.DataType_Float:
return newFunction[int64, float32](base, funcSchema)
case schemapb.DataType_Double:
return newFunction[int64, float64](base, funcSchema)
default:
return nil, fmt.Errorf("Decay rerank: unsupported input field type:%s, only support numberic field", inputType.String())
}
} else {
switch inputType {
case schemapb.DataType_Int8, schemapb.DataType_Int16, schemapb.DataType_Int32:
return newFunction[string, int32](base, funcSchema)
case schemapb.DataType_Int64:
return newFunction[string, int64](base, funcSchema)
case schemapb.DataType_Float:
return newFunction[string, float32](base, funcSchema)
case schemapb.DataType_Double:
return newFunction[string, float64](base, funcSchema)
default:
return nil, fmt.Errorf("Decay rerank: unsupported input field type:%s, only support numberic field", inputType.String())
}
}
}
// T: PK Type, R: field type
func newFunction[T int64 | string, R int32 | int64 | float32 | float64](base *RerankBase, funcSchema *schemapb.FunctionSchema) (Reranker, error) {
var err error
decayFunc := &DecayFunction[T, R]{RerankBase: *base, offset: 0, decay: 0.5}
orginInit := false
scaleInit := false
for _, param := range funcSchema.Params {
switch strings.ToLower(param.Key) {
case functionKey:
decayFunc.functionName = param.Value
case originKey:
if decayFunc.origin, err = strconv.ParseFloat(param.Value, 64); err != nil {
return nil, fmt.Errorf("Param origin:%s is not a number", param.Value)
}
orginInit = true
case scaleKey:
if decayFunc.scale, err = strconv.ParseFloat(param.Value, 64); err != nil {
return nil, fmt.Errorf("Param scale:%s is not a number", param.Value)
}
scaleInit = true
case offsetKey:
if decayFunc.offset, err = strconv.ParseFloat(param.Value, 64); err != nil {
return nil, fmt.Errorf("Param offset:%s is not a number", param.Value)
}
case decayKey:
if decayFunc.decay, err = strconv.ParseFloat(param.Value, 64); err != nil {
return nil, fmt.Errorf("Param decay:%s is not a number", param.Value)
}
default:
}
}
if !orginInit {
return nil, fmt.Errorf("Decay function lost param: origin")
}
if !scaleInit {
return nil, fmt.Errorf("Decay function lost param: scale")
}
if decayFunc.scale <= 0 {
return nil, fmt.Errorf("Decay function param: scale must > 0, but got %f", decayFunc.scale)
}
if decayFunc.offset < 0 {
return nil, fmt.Errorf("Decay function param: offset must => 0, but got %f", decayFunc.offset)
}
if decayFunc.decay <= 0 || decayFunc.decay >= 1 {
return nil, fmt.Errorf("Decay function param: decay must 0 < decay < 1 0, but got %f", decayFunc.offset)
}
switch decayFunc.functionName {
case gaussFunction:
decayFunc.reScorer = gaussianDecay
case expFunction:
decayFunc.reScorer = expDecay
case linerFunction:
decayFunc.reScorer = linearDecay
default:
return nil, fmt.Errorf("Invaild decay function: %s, only support [%s,%s,%s]", decayFunctionName, gaussFunction, linerFunction, expFunction)
}
return decayFunc, nil
}
func (decay *DecayFunction[T, R]) reScore(multipSearchResultData []*schemapb.SearchResultData) (*idSocres[T], error) {
newScores := newIdScores[T]()
for _, data := range multipSearchResultData {
var inputField *schemapb.FieldData
for _, field := range data.FieldsData {
if field.FieldId == decay.GetInputFieldIDs()[0] {
inputField = field
}
}
if inputField == nil {
return nil, fmt.Errorf("Rerank decay function can not find input field, name: %s", decay.GetInputFieldNames()[0])
}
var inputValues *numberField[R]
if tmp, err := getNumberic(inputField); err != nil {
return nil, err
} else {
inputValues = tmp.(*numberField[R])
}
ids := newMilvusIDs(data.Ids, decay.pkType).(milvusIDs[T])
for idx, id := range ids.data {
if !newScores.exist(id) {
if v, err := inputValues.GetFloat64(idx); err != nil {
return nil, err
} else {
newScores.set(id, float32(decay.reScorer(decay.origin, decay.scale, decay.decay, decay.offset, v)))
}
}
}
}
return newScores, nil
}
func (decay *DecayFunction[T, R]) orgnizeNqScores(searchParams *SearchParams, multipSearchResultData []*schemapb.SearchResultData, idScoreData *idSocres[T]) []map[T]float32 {
nqScores := make([]map[T]float32, searchParams.nq)
for i := int64(0); i < searchParams.nq; i++ {
nqScores[i] = make(map[T]float32)
}
for _, data := range multipSearchResultData {
start := int64(0)
for nqIdx := int64(0); nqIdx < searchParams.nq; nqIdx++ {
realTopk := data.Topks[nqIdx]
for j := start; j < start+realTopk; j++ {
id := typeutil.GetPK(data.GetIds(), j).(T)
if _, exists := nqScores[nqIdx][id]; !exists {
nqScores[nqIdx][id] = idScoreData.get(id)
}
}
start += realTopk
}
}
return nqScores
}
func (decay *DecayFunction[T, R]) Process(ctx context.Context, searchParams *SearchParams, multipSearchResultData []*schemapb.SearchResultData) (*schemapb.SearchResultData, error) {
ret := &schemapb.SearchResultData{
NumQueries: searchParams.nq,
TopK: searchParams.limit,
FieldsData: make([]*schemapb.FieldData, 0),
Scores: []float32{},
Ids: &schemapb.IDs{},
Topks: []int64{},
}
multipSearchResultData = lo.Filter(multipSearchResultData, func(searchResult *schemapb.SearchResultData, i int) bool {
return len(searchResult.FieldsData) != 0
})
if len(multipSearchResultData) == 0 {
return ret, nil
}
idScoreData, err := decay.reScore(multipSearchResultData)
if err != nil {
return nil, err
}
nqScores := decay.orgnizeNqScores(searchParams, multipSearchResultData, idScoreData)
topk := searchParams.limit + searchParams.offset
for i := int64(0); i < searchParams.nq; i++ {
idScoreMap := nqScores[i]
ids := make([]T, 0)
for id := range idScoreMap {
ids = append(ids, id)
}
big := func(i, j int) bool {
if idScoreMap[ids[i]] == idScoreMap[ids[j]] {
return ids[i] < ids[j]
}
return idScoreMap[ids[i]] > idScoreMap[ids[j]]
}
sort.Slice(ids, big)
if int64(len(ids)) > topk {
ids = ids[:topk]
}
// set real topk
ret.Topks = append(ret.Topks, max(0, int64(len(ids))-searchParams.offset))
// append id and score
for index := searchParams.offset; index < int64(len(ids)); index++ {
typeutil.AppendPKs(ret.Ids, ids[index])
score := idScoreMap[ids[index]]
if searchParams.roundDecimal != -1 {
multiplier := math.Pow(10.0, float64(searchParams.roundDecimal))
score = float32(math.Floor(float64(score)*multiplier+0.5) / multiplier)
}
ret.Scores = append(ret.Scores, score)
}
}
return ret, nil
}
type decayReScorer func(float64, float64, float64, float64, float64) float64
func gaussianDecay(origin, scale, decay, offset, distance float64) float64 {
adjustedDist := math.Max(0, math.Abs(distance-origin)-offset)
sigmaSquare := 0.5 * math.Pow(scale, 2.0) / math.Log(decay)
exponent := math.Pow(adjustedDist, 2.0) / sigmaSquare
return math.Exp(exponent)
}
func expDecay(origin, scale, decay, offset, distance float64) float64 {
adjustedDist := math.Max(0, math.Abs(distance-origin)-offset)
lambda := math.Log(decay) / scale
return math.Exp(lambda * adjustedDist)
}
func linearDecay(origin, scale, decay, offset, distance float64) float64 {
adjustedDist := math.Max(0, math.Abs(distance-origin)-offset)
slope := (1 - decay) / scale
return math.Max(decay, 1-slope*adjustedDist)
}

View File

@ -0,0 +1,434 @@
/*
* # Licensed to the LF AI & Data foundation under one
* # or more contributor license agreements. See the NOTICE file
* # distributed with this work for additional information
* # regarding copyright ownership. The ASF licenses this file
* # to you under the Apache License, Version 2.0 (the
* # "License"); you may not use this file except in compliance
* # with the License. You may obtain a copy of the License at
* #
* # http://www.apache.org/licenses/LICENSE-2.0
* #
* # Unless required by applicable law or agreed to in writing, software
* # distributed under the License is distributed on an "AS IS" BASIS,
* # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* # See the License for the specific language governing permissions and
* # limitations under the License.
*/
package rerank
import (
"context"
"testing"
"github.com/stretchr/testify/suite"
"google.golang.org/protobuf/proto"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/pkg/v2/util/testutils"
)
func TestDecayFunction(t *testing.T) {
suite.Run(t, new(DecayFunctionSuite))
}
type DecayFunctionSuite struct {
suite.Suite
schema *schemapb.CollectionSchema
providers []string
}
func (s *DecayFunctionSuite) TestNewDecayErrors() {
schema := &schemapb.CollectionSchema{
Name: "test",
Fields: []*schemapb.FieldSchema{
{FieldID: 100, Name: "pk", DataType: schemapb.DataType_Int64, IsPrimaryKey: true},
{FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar},
{
FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{
{Key: "dim", Value: "4"},
},
},
{FieldID: 102, Name: "ts", DataType: schemapb.DataType_Int64},
},
}
functionSchema := &schemapb.FunctionSchema{
Name: "test",
Type: schemapb.FunctionType_Rerank,
InputFieldNames: []string{"ts"},
OutputFieldNames: []string{"text"},
Params: []*commonpb.KeyValuePair{
{Key: originKey, Value: "4"},
{Key: scaleKey, Value: "4"},
{Key: offsetKey, Value: "4"},
{Key: decayKey, Value: "0.5"},
{Key: functionKey, Value: "gauss"},
},
}
{
_, err := newDecayFunction(schema, functionSchema)
s.ErrorContains(err, "Rerank function output field names should be empty")
}
{
functionSchema.OutputFieldNames = []string{}
functionSchema.InputFieldNames = []string{""}
_, err := newDecayFunction(schema, functionSchema)
s.ErrorContains(err, "Rerank input field name cannot be empty string")
}
{
functionSchema.InputFieldNames = []string{"ts", "ts"}
_, err := newDecayFunction(schema, functionSchema)
s.ErrorContains(err, "Each function input field should be used exactly once in the same function")
}
{
functionSchema.InputFieldNames = []string{"ts", "pk"}
_, err := newDecayFunction(schema, functionSchema)
s.ErrorContains(err, "Decay function only supoorts single input, but gets")
}
{
functionSchema.InputFieldNames = []string{"notExists"}
_, err := newDecayFunction(schema, functionSchema)
s.ErrorContains(err, "Function input field not found:")
}
{
functionSchema.InputFieldNames = []string{"vector"}
_, err := newDecayFunction(schema, functionSchema)
s.ErrorContains(err, "Decay rerank: unsupported input field type")
}
{
functionSchema.InputFieldNames = []string{"ts"}
_, err := newDecayFunction(schema, functionSchema)
s.NoError(err)
}
{
for i := 0; i < 4; i++ {
functionSchema.Params[i].Value = "NotNum"
_, err := newDecayFunction(schema, functionSchema)
s.ErrorContains(err, "is not a number")
functionSchema.Params[i].Value = "0.9"
}
}
{
fs := []string{gaussFunction, linerFunction, expFunction}
for i := 0; i < 3; i++ {
functionSchema.Params[4].Value = fs[i]
_, err := newDecayFunction(schema, functionSchema)
s.NoError(err)
}
functionSchema.Params[4].Value = "NotExist"
_, err := newDecayFunction(schema, functionSchema)
s.ErrorContains(err, "Invaild decay function:")
functionSchema.Params[4].Value = "exp"
}
{
newSchema := proto.Clone(schema).(*schemapb.CollectionSchema)
newSchema.Fields[0].IsPrimaryKey = false
_, err := newDecayFunction(newSchema, functionSchema)
s.ErrorContains(err, " can not found pk field")
}
}
func (s *DecayFunctionSuite) TestAllTypesInput() {
schema := &schemapb.CollectionSchema{
Name: "test",
Fields: []*schemapb.FieldSchema{
{FieldID: 100, Name: "pk", DataType: schemapb.DataType_Int64, IsPrimaryKey: true},
{FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar},
{
FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{
{Key: "dim", Value: "4"},
},
},
{FieldID: 102, Name: "ts", DataType: schemapb.DataType_Int64},
},
}
functionSchema := &schemapb.FunctionSchema{
Name: "test",
Type: schemapb.FunctionType_Rerank,
InputFieldNames: []string{"ts"},
Params: []*commonpb.KeyValuePair{
{Key: functionKey, Value: "gauss"},
{Key: originKey, Value: "4"},
{Key: scaleKey, Value: "4"},
{Key: offsetKey, Value: "4"},
{Key: decayKey, Value: "0.5"},
},
}
inputTypes := []schemapb.DataType{schemapb.DataType_Int64, schemapb.DataType_Int32, schemapb.DataType_Int16, schemapb.DataType_Int8, schemapb.DataType_Float, schemapb.DataType_Double, schemapb.DataType_Bool}
for i, inputType := range inputTypes {
schema.Fields[3].DataType = inputType
_, err := newDecayFunction(schema, functionSchema)
if i < len(inputTypes)-1 {
s.NoError(err)
} else {
s.ErrorContains(err, "Decay rerank: unsupported input field type")
}
}
schema.Fields[0].DataType = schemapb.DataType_String
for i, inputType := range inputTypes {
schema.Fields[3].DataType = inputType
_, err := newDecayFunction(schema, functionSchema)
if i < len(inputTypes)-1 {
s.NoError(err)
} else {
s.ErrorContains(err, "Decay rerank: unsupported input field type")
}
}
schema.Fields[3].DataType = schemapb.DataType_Double
{
functionSchema.Params[1].Key = "N"
_, err := newDecayFunction(schema, functionSchema)
s.ErrorContains(err, "Decay function lost param: origin")
functionSchema.Params[1].Key = originKey
}
{
functionSchema.Params[2].Key = "N"
_, err := newDecayFunction(schema, functionSchema)
s.ErrorContains(err, "Decay function lost param: scale")
functionSchema.Params[2].Key = scaleKey
}
{
functionSchema.Params[2].Value = "-1"
_, err := newDecayFunction(schema, functionSchema)
s.ErrorContains(err, "Decay function param: scale must > 0,")
functionSchema.Params[2].Value = "0.5"
}
{
functionSchema.Params[3].Value = "-1"
_, err := newDecayFunction(schema, functionSchema)
s.ErrorContains(err, "Decay function param: offset must => 0")
functionSchema.Params[3].Value = "0.5"
}
{
functionSchema.Params[4].Value = "10"
_, err := newDecayFunction(schema, functionSchema)
s.ErrorContains(err, "Decay function param: decay must 0 < decay < 1 0")
functionSchema.Params[2].Value = "0.5"
}
}
func (s *DecayFunctionSuite) TestRerankProcess() {
schema := &schemapb.CollectionSchema{
Name: "test",
Fields: []*schemapb.FieldSchema{
{FieldID: 100, Name: "pk", DataType: schemapb.DataType_Int64, IsPrimaryKey: true},
{FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar},
{
FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{
{Key: "dim", Value: "4"},
},
},
{FieldID: 102, Name: "ts", DataType: schemapb.DataType_Int64},
},
}
functionSchema := &schemapb.FunctionSchema{
Name: "test",
Type: schemapb.FunctionType_Rerank,
InputFieldNames: []string{"ts"},
Params: []*commonpb.KeyValuePair{
{Key: functionKey, Value: "gauss"},
{Key: originKey, Value: "0"},
{Key: scaleKey, Value: "4"},
{Key: offsetKey, Value: "2"},
},
}
// empty
{
nq := int64(1)
f, err := newDecayFunction(schema, functionSchema)
s.NoError(err)
ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false}, []*schemapb.SearchResultData{})
s.NoError(err)
s.Equal(int64(3), ret.TopK)
s.Equal([]int64{}, ret.Topks)
}
// no input field exist
{
nq := int64(1)
f, err := newDecayFunction(schema, functionSchema)
data := genSearchResultData(nq, 10, schemapb.DataType_Int64, "noExist", 1000)
s.NoError(err)
_, err = f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false}, []*schemapb.SearchResultData{data})
s.ErrorContains(err, "Rerank decay function can not find input field, name")
}
// singleSearchResultData
// nq = 1
{
nq := int64(1)
f, err := newDecayFunction(schema, functionSchema)
s.NoError(err)
data := genSearchResultData(nq, 10, schemapb.DataType_Int64, "ts", 102)
ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false}, []*schemapb.SearchResultData{data})
s.NoError(err)
s.Equal([]int64{3}, ret.Topks)
s.Equal(int64(3), ret.TopK)
s.Equal([]int64{2, 3, 4}, ret.Ids.GetIntId().Data)
}
// nq = 3
{
nq := int64(3)
f, err := newDecayFunction(schema, functionSchema)
s.NoError(err)
data := genSearchResultData(nq, 10, schemapb.DataType_Int64, "ts", 102)
ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false}, []*schemapb.SearchResultData{data})
s.NoError(err)
s.Equal([]int64{3, 3, 3}, ret.Topks)
s.Equal(int64(3), ret.TopK)
s.Equal([]int64{2, 3, 4, 12, 13, 14, 22, 23, 24}, ret.Ids.GetIntId().Data)
}
// multipSearchResultData
// nq = 1
functionSchema2 := &schemapb.FunctionSchema{
Name: "test",
Type: schemapb.FunctionType_Rerank,
InputFieldNames: []string{"ts"},
Params: []*commonpb.KeyValuePair{
{Key: functionKey, Value: "gauss"},
{Key: originKey, Value: "5"},
{Key: scaleKey, Value: "4"},
{Key: offsetKey, Value: "2"},
},
}
{
nq := int64(1)
f, err := newDecayFunction(schema, functionSchema2)
s.NoError(err)
// ts/id data: 0 - 9
data1 := genSearchResultData(nq, 10, schemapb.DataType_Int64, "ts", 102)
// ts/id data: 0 - 3
data2 := genSearchResultData(nq, 4, schemapb.DataType_Int64, "ts", 102)
ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false}, []*schemapb.SearchResultData{data1, data2})
s.NoError(err)
s.Equal([]int64{3}, ret.Topks)
s.Equal(int64(3), ret.TopK)
s.Equal([]int64{5, 6, 7}, ret.Ids.GetIntId().Data)
}
// nq = 3
{
nq := int64(3)
f, err := newDecayFunction(schema, functionSchema2)
s.NoError(err)
// nq1 ts/id data: 0 - 9
// nq2 ts/id data: 10 - 19
// nq3 ts/id data: 20 - 29
data1 := genSearchResultData(nq, 10, schemapb.DataType_Int64, "ts", 102)
// nq1 ts/id data: 0 - 3
// nq2 ts/id data: 4 - 7
// nq3 ts/id data: 8 - 11
data2 := genSearchResultData(nq, 4, schemapb.DataType_Int64, "ts", 102)
ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, 1, -1, 1, false}, []*schemapb.SearchResultData{data1, data2})
s.NoError(err)
s.Equal([]int64{3, 3, 3}, ret.Topks)
s.Equal(int64(3), ret.TopK)
s.Equal([]int64{5, 6, 7, 6, 7, 10, 10, 11, 20}, ret.Ids.GetIntId().Data)
}
}
func (s *DecayFunctionSuite) TestDecay() {
s.Equal(gaussianDecay(0, 1, 0.5, 5, 4), 1.0)
s.Equal(gaussianDecay(0, 1, 0.5, 5, 5), 1.0)
s.Less(gaussianDecay(0, 1, 0.5, 5, 6), 1.0)
s.Equal(expDecay(0, 1, 0.5, 5, 4), 1.0)
s.Equal(expDecay(0, 1, 0.5, 5, 5), 1.0)
s.Less(expDecay(0, 1, 0.5, 5, 6), 1.0)
s.Equal(linearDecay(0, 1, 0.5, 5, 4), 1.0)
s.Equal(linearDecay(0, 1, 0.5, 5, 5), 1.0)
s.Less(linearDecay(0, 1, 0.5, 5, 6), 1.0)
}
func (s *DecayFunctionSuite) TestUtil() {
inputTypes := []schemapb.DataType{schemapb.DataType_Int64, schemapb.DataType_Int32, schemapb.DataType_Int16, schemapb.DataType_Int8, schemapb.DataType_Float, schemapb.DataType_Double}
for _, tp := range inputTypes {
field := genSearchResultData(2, 10, tp, "test", 100)
num, err := getNumberic(field.FieldsData[0])
s.NoError(err)
switch tp {
case schemapb.DataType_Int32, schemapb.DataType_Int16, schemapb.DataType_Int8:
s.True(len(num.(*numberField[int32]).data) == 20)
case schemapb.DataType_Int64:
s.True(len(num.(*numberField[int64]).data) == 20)
case schemapb.DataType_Float:
s.True(len(num.(*numberField[float32]).data) == 20)
case schemapb.DataType_Double:
s.True(len(num.(*numberField[float64]).data) == 20)
}
}
field := genSearchResultData(2, 10, schemapb.DataType_Bool, "test", 100)
_, err := getNumberic(field.FieldsData[0])
s.ErrorContains(err, "only support numberic field")
{
ids := &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: testutils.GenerateInt64Array(10),
},
},
}
mid := newMilvusIDs(ids, schemapb.DataType_Int64).(milvusIDs[int64])
s.True(len(mid.data) == 10)
}
{
ids := &schemapb.IDs{
IdField: &schemapb.IDs_StrId{
StrId: &schemapb.StringArray{
Data: testutils.GenerateStringArray(10),
},
},
}
mid := newMilvusIDs(ids, schemapb.DataType_String).(milvusIDs[string])
s.True(len(mid.data) == 10)
}
}
func genSearchResultData(nq int64, topk int64, dType schemapb.DataType, fieldName string, fieldId int64) *schemapb.SearchResultData {
tops := make([]int64, nq)
for i := 0; i < int(nq); i++ {
tops[i] = topk
}
data := &schemapb.SearchResultData{
NumQueries: nq,
TopK: topk,
Scores: testutils.GenerateFloat32Array(int(nq * topk)),
Ids: &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: testutils.GenerateInt64Array(int(nq * topk)),
},
},
},
Topks: tops,
FieldsData: []*schemapb.FieldData{testutils.GenerateScalarFieldData(dType, fieldName, int(nq*topk))},
}
data.FieldsData[0].FieldId = fieldId
return data
}

View File

@ -0,0 +1,166 @@
/*
* # Licensed to the LF AI & Data foundation under one
* # or more contributor license agreements. See the NOTICE file
* # distributed with this work for additional information
* # regarding copyright ownership. The ASF licenses this file
* # to you under the Apache License, Version 2.0 (the
* # "License"); you may not use this file except in compliance
* # with the License. You may obtain a copy of the License at
* #
* # http://www.apache.org/licenses/LICENSE-2.0
* #
* # Unless required by applicable law or agreed to in writing, software
* # distributed under the License is distributed on an "AS IS" BASIS,
* # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* # See the License for the specific language governing permissions and
* # limitations under the License.
*/
package rerank
import (
"context"
"fmt"
"strings"
"github.com/samber/lo"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
)
const (
decayFunctionName string = "decay"
)
type SearchParams struct {
nq int64
limit int64
offset int64
roundDecimal int64
// TODO: supports group search
groupByFieldId int64
groupSize int64
strictGroupSize bool
}
func NewSearchParams(nq, limit, offset, roundDecimal, groupByFieldId, groupSize int64, strictGroupSize bool) *SearchParams {
return &SearchParams{
nq, limit, offset, roundDecimal, groupByFieldId, groupSize, strictGroupSize,
}
}
type Reranker interface {
Process(ctx context.Context, searchParams *SearchParams, searchData []*schemapb.SearchResultData) (*schemapb.SearchResultData, error)
IsSupportGroup() bool
GetInputFieldNames() []string
GetInputFieldIDs() []int64
GetRankName() string
}
func getRerankName(funcSchema *schemapb.FunctionSchema) string {
for _, param := range funcSchema.Params {
switch strings.ToLower(param.Key) {
case reranker:
return strings.ToLower(param.Value)
default:
}
}
return ""
}
// Currently only supports single rerank
type FunctionScore struct {
reranker Reranker
}
func createFunction(collSchema *schemapb.CollectionSchema, funcSchema *schemapb.FunctionSchema) (Reranker, error) {
if funcSchema.GetType() != schemapb.FunctionType_Rerank {
return nil, fmt.Errorf("%s is not rerank function.", funcSchema.GetType().String())
}
if len(funcSchema.GetOutputFieldNames()) != 0 {
return nil, fmt.Errorf("Rerank function should not have output field, but now is %d", len(funcSchema.GetOutputFieldNames()))
}
rerankerName := getRerankName(funcSchema)
var rerankFunc Reranker
var newRerankErr error
switch rerankerName {
case decayFunctionName:
rerankFunc, newRerankErr = newDecayFunction(collSchema, funcSchema)
default:
return nil, fmt.Errorf("Unsupported rerank function: [%s] , list of supported [%s]", rerankerName, decayFunctionName)
}
if newRerankErr != nil {
return nil, newRerankErr
}
return rerankFunc, nil
}
func NewFunctionScore(collSchema *schemapb.CollectionSchema, funcScoreSchema *schemapb.FunctionScore) (*FunctionScore, error) {
if len(funcScoreSchema.Functions) > 1 || len(funcScoreSchema.Functions) == 0 {
return nil, fmt.Errorf("Currently only supports one rerank, but got %d", len(funcScoreSchema.Functions))
}
funcScore := &FunctionScore{}
var err error
if funcScore.reranker, err = createFunction(collSchema, funcScoreSchema.Functions[0]); err != nil {
return nil, err
}
return funcScore, nil
}
func (fScore *FunctionScore) Process(ctx context.Context, searchParams *SearchParams, multipleMilvusResults []*milvuspb.SearchResults) (*milvuspb.SearchResults, error) {
if len(multipleMilvusResults) == 0 {
return &milvuspb.SearchResults{
Status: merr.Success(),
Results: &schemapb.SearchResultData{
NumQueries: searchParams.nq,
TopK: searchParams.limit,
FieldsData: make([]*schemapb.FieldData, 0),
Scores: []float32{},
Ids: &schemapb.IDs{},
Topks: []int64{},
},
}, nil
}
allSearchResultData := lo.FilterMap(multipleMilvusResults, func(m *milvuspb.SearchResults, _ int) (*schemapb.SearchResultData, bool) {
return m.Results, true
})
// rankResult only has scores
rankResult, err := fScore.reranker.Process(ctx, searchParams, allSearchResultData)
if err != nil {
return nil, err
}
ret := &milvuspb.SearchResults{
Status: merr.Success(),
Results: rankResult,
}
return ret, nil
}
func (fScore *FunctionScore) GetAllInputFieldNames() []string {
if fScore == nil {
return []string{}
}
return fScore.reranker.GetInputFieldNames()
}
func (fScore *FunctionScore) GetAllInputFieldIDs() []int64 {
if fScore == nil {
return []int64{}
}
return fScore.reranker.GetInputFieldIDs()
}
func (fScore *FunctionScore) IsSupportGroup() bool {
if fScore == nil {
return true
}
return fScore.reranker.IsSupportGroup()
}

View File

@ -0,0 +1,302 @@
/*
* # Licensed to the LF AI & Data foundation under one
* # or more contributor license agreements. See the NOTICE file
* # distributed with this work for additional information
* # regarding copyright ownership. The ASF licenses this file
* # to you under the Apache License, Version 2.0 (the
* # "License"); you may not use this file except in compliance
* # with the License. You may obtain a copy of the License at
* #
* # http://www.apache.org/licenses/LICENSE-2.0
* #
* # Unless required by applicable law or agreed to in writing, software
* # distributed under the License is distributed on an "AS IS" BASIS,
* # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* # See the License for the specific language governing permissions and
* # limitations under the License.
*/
package rerank
import (
"context"
"testing"
"github.com/stretchr/testify/suite"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
)
func TestFunctionScore(t *testing.T) {
suite.Run(t, new(FunctionScoreSuite))
}
type FunctionScoreSuite struct {
suite.Suite
schema *schemapb.CollectionSchema
providers []string
}
func (s *FunctionScoreSuite) TestNewFunctionScore() {
schema := &schemapb.CollectionSchema{
Name: "test",
Fields: []*schemapb.FieldSchema{
{FieldID: 100, Name: "pk", DataType: schemapb.DataType_Int64, IsPrimaryKey: true},
{FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar},
{
FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{
{Key: "dim", Value: "4"},
},
},
{FieldID: 102, Name: "ts", DataType: schemapb.DataType_Int64},
},
}
functionSchema := &schemapb.FunctionSchema{
Name: "test",
Type: schemapb.FunctionType_Rerank,
InputFieldNames: []string{"ts"},
Params: []*commonpb.KeyValuePair{
{Key: reranker, Value: decayFunctionName},
{Key: originKey, Value: "4"},
{Key: scaleKey, Value: "4"},
{Key: offsetKey, Value: "4"},
{Key: decayKey, Value: "0.5"},
{Key: functionKey, Value: "gauss"},
},
}
funcScores := &schemapb.FunctionScore{
Functions: []*schemapb.FunctionSchema{functionSchema},
}
f, err := NewFunctionScore(schema, funcScores)
s.NoError(err)
s.Equal([]string{"ts"}, f.GetAllInputFieldNames())
s.Equal([]int64{102}, f.GetAllInputFieldIDs())
s.Equal(false, f.IsSupportGroup())
s.Equal("decay", f.reranker.GetRankName())
{
schema.Fields[3].Nullable = true
_, err := NewFunctionScore(schema, funcScores)
s.ErrorContains(err, "Function input field cannot be nullable")
schema.Fields[3].Nullable = false
}
{
funcScores.Functions[0].Params[0].Value = "NotExist"
_, err := NewFunctionScore(schema, funcScores)
s.ErrorContains(err, "Unsupported rerank function")
funcScores.Functions[0].Params[0].Value = decayFunctionName
}
{
funcScores.Functions = append(funcScores.Functions, functionSchema)
_, err := NewFunctionScore(schema, funcScores)
s.ErrorContains(err, "Currently only supports one rerank, but got")
funcScores.Functions = funcScores.Functions[:1]
}
{
funcScores.Functions[0].Type = schemapb.FunctionType_BM25
_, err := NewFunctionScore(schema, funcScores)
s.ErrorContains(err, "is not rerank function")
funcScores.Functions[0].Type = schemapb.FunctionType_Rerank
}
{
funcScores.Functions[0].OutputFieldNames = []string{"text"}
_, err := NewFunctionScore(schema, funcScores)
s.ErrorContains(err, "Rerank function should not have output field")
funcScores.Functions[0].OutputFieldNames = []string{""}
}
}
func (s *FunctionScoreSuite) TestFunctionScoreProcess() {
schema := &schemapb.CollectionSchema{
Name: "test",
Fields: []*schemapb.FieldSchema{
{FieldID: 100, Name: "pk", DataType: schemapb.DataType_Int64, IsPrimaryKey: true},
{FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar},
{
FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{
{Key: "dim", Value: "4"},
},
},
{FieldID: 102, Name: "ts", DataType: schemapb.DataType_Int64},
},
}
functionSchema := &schemapb.FunctionSchema{
Name: "test",
Type: schemapb.FunctionType_Rerank,
InputFieldNames: []string{"ts"},
Params: []*commonpb.KeyValuePair{
{Key: reranker, Value: decayFunctionName},
{Key: originKey, Value: "4"},
{Key: scaleKey, Value: "4"},
{Key: offsetKey, Value: "4"},
{Key: decayKey, Value: "0.5"},
{Key: functionKey, Value: "gauss"},
},
}
funcScores := &schemapb.FunctionScore{
Functions: []*schemapb.FunctionSchema{functionSchema},
}
f, err := NewFunctionScore(schema, funcScores)
s.NoError(err)
// empty inputs
{
nq := int64(1)
ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false}, []*milvuspb.SearchResults{})
s.NoError(err)
s.Equal(int64(3), ret.Results.TopK)
s.Equal(0, len(ret.Results.FieldsData))
}
// single input
// nq = 1
{
nq := int64(1)
data := genSearchResultData(nq, 10, schemapb.DataType_Int64, "ts", 102)
searchData := &milvuspb.SearchResults{
Results: data,
}
ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false}, []*milvuspb.SearchResults{searchData})
s.NoError(err)
s.Equal(int64(3), ret.Results.TopK)
s.Equal([]int64{3}, ret.Results.Topks)
}
// nq=1, input is empty
{
nq := int64(1)
data := genSearchResultData(nq, 0, schemapb.DataType_Int64, "ts", 102)
searchData := &milvuspb.SearchResults{
Results: data,
}
ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false}, []*milvuspb.SearchResults{searchData})
s.NoError(err)
s.Equal(int64(3), ret.Results.TopK)
s.Equal([]int64{0}, ret.Results.Topks)
}
// nq=3
{
nq := int64(3)
data := genSearchResultData(nq, 10, schemapb.DataType_Int64, "ts", 102)
searchData := &milvuspb.SearchResults{
Results: data,
}
ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false}, []*milvuspb.SearchResults{searchData})
s.NoError(err)
s.Equal(int64(3), ret.Results.TopK)
s.Equal([]int64{3, 3, 3}, ret.Results.Topks)
}
// nq=3, all input is empty
{
nq := int64(3)
data := genSearchResultData(nq, 0, schemapb.DataType_Int64, "ts", 102)
searchData := &milvuspb.SearchResults{
Results: data,
}
ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false}, []*milvuspb.SearchResults{searchData})
s.NoError(err)
s.Equal(int64(3), ret.Results.TopK)
s.Equal([]int64{0, 0, 0}, ret.Results.Topks)
}
// multi inputs
// nq = 1
{
nq := int64(1)
searchData1 := &milvuspb.SearchResults{
Results: genSearchResultData(nq, 10, schemapb.DataType_Int64, "ts", 102),
}
searchData2 := &milvuspb.SearchResults{
Results: genSearchResultData(nq, 20, schemapb.DataType_Int64, "ts", 102),
}
ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false}, []*milvuspb.SearchResults{searchData1, searchData2})
s.NoError(err)
s.Equal(int64(3), ret.Results.TopK)
s.Equal([]int64{3}, ret.Results.Topks)
}
// nq=1, all input is empty
{
nq := int64(1)
searchData1 := &milvuspb.SearchResults{
Results: genSearchResultData(nq, 0, schemapb.DataType_Int64, "ts", 102),
}
searchData2 := &milvuspb.SearchResults{
Results: genSearchResultData(nq, 0, schemapb.DataType_Int64, "ts", 102),
}
ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false}, []*milvuspb.SearchResults{searchData1, searchData2})
s.NoError(err)
s.Equal(int64(3), ret.Results.TopK)
s.Equal([]int64{0}, ret.Results.Topks)
}
// nq=1, has empty input
{
nq := int64(1)
searchData1 := &milvuspb.SearchResults{
Results: genSearchResultData(nq, 0, schemapb.DataType_Int64, "ts", 102),
}
searchData2 := &milvuspb.SearchResults{
Results: genSearchResultData(nq, 10, schemapb.DataType_Int64, "ts", 102),
}
ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false}, []*milvuspb.SearchResults{searchData1, searchData2})
s.NoError(err)
s.Equal(int64(3), ret.Results.TopK)
s.Equal([]int64{3}, ret.Results.Topks)
}
// nq = 3
{
nq := int64(3)
searchData1 := &milvuspb.SearchResults{
Results: genSearchResultData(nq, 10, schemapb.DataType_Int64, "ts", 102),
}
searchData2 := &milvuspb.SearchResults{
Results: genSearchResultData(nq, 20, schemapb.DataType_Int64, "ts", 102),
}
ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false}, []*milvuspb.SearchResults{searchData1, searchData2})
s.NoError(err)
s.Equal(int64(3), ret.Results.TopK)
s.Equal([]int64{3, 3, 3}, ret.Results.Topks)
}
// nq=3, all input is empty
{
nq := int64(3)
searchData1 := &milvuspb.SearchResults{
Results: genSearchResultData(nq, 0, schemapb.DataType_Int64, "ts", 102),
}
searchData2 := &milvuspb.SearchResults{
Results: genSearchResultData(nq, 0, schemapb.DataType_Int64, "ts", 102),
}
ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false}, []*milvuspb.SearchResults{searchData1, searchData2})
s.NoError(err)
s.Equal(int64(3), ret.Results.TopK)
s.Equal([]int64{0, 0, 0}, ret.Results.Topks)
}
// nq=3, has empty input
{
nq := int64(3)
searchData1 := &milvuspb.SearchResults{
Results: genSearchResultData(nq, 0, schemapb.DataType_Int64, "ts", 102),
}
searchData2 := &milvuspb.SearchResults{
Results: genSearchResultData(nq, 10, schemapb.DataType_Int64, "ts", 102),
}
ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false}, []*milvuspb.SearchResults{searchData1, searchData2})
s.NoError(err)
s.Equal(int64(3), ret.Results.TopK)
s.Equal([]int64{3, 3, 3}, ret.Results.Topks)
}
}

View File

@ -0,0 +1,107 @@
/*
* # Licensed to the LF AI & Data foundation under one
* # or more contributor license agreements. See the NOTICE file
* # distributed with this work for additional information
* # regarding copyright ownership. The ASF licenses this file
* # to you under the Apache License, Version 2.0 (the
* # "License"); you may not use this file except in compliance
* # with the License. You may obtain a copy of the License at
* #
* # http://www.apache.org/licenses/LICENSE-2.0
* #
* # Unless required by applicable law or agreed to in writing, software
* # distributed under the License is distributed on an "AS IS" BASIS,
* # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* # See the License for the specific language governing permissions and
* # limitations under the License.
*/
package rerank
import (
"fmt"
"github.com/samber/lo"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
)
const (
reranker string = "reranker"
)
// topk and group related parameters, reranker can choose to process or ignore
type searchParams struct {
limit int64
groupByFieldId int64
groupSize int64
strictGroupSize bool
}
type RerankBase struct {
coll *schemapb.CollectionSchema
funcSchema *schemapb.FunctionSchema
rerankerName string
isSupportGroup bool
pkType schemapb.DataType
inputFields []*schemapb.FieldSchema
inputFieldIDs []int64
// TODO: The parameter is passed to the reranker, and the reranker decides whether to implement the parameter
searchParams *searchParams
}
func newRerankBase(coll *schemapb.CollectionSchema, funcSchema *schemapb.FunctionSchema, rerankerName string, isSupportGroup bool, pkType schemapb.DataType) (*RerankBase, error) {
base := RerankBase{
coll: coll,
funcSchema: funcSchema,
rerankerName: rerankerName,
isSupportGroup: isSupportGroup,
pkType: pkType,
}
nameMap := lo.SliceToMap(coll.GetFields(), func(field *schemapb.FieldSchema) (string, *schemapb.FieldSchema) {
return field.GetName(), field
})
if len(funcSchema.GetOutputFieldNames()) != 0 {
return nil, fmt.Errorf("Rerank function output field names should be empty")
}
for _, name := range funcSchema.GetInputFieldNames() {
if name == "" {
return nil, fmt.Errorf("Rerank input field name cannot be empty string")
}
if lo.Count(funcSchema.GetInputFieldNames(), name) > 1 {
return nil, fmt.Errorf("Each function input field should be used exactly once in the same function, input field: %s", name)
}
inputField, ok := nameMap[name]
if !ok {
return nil, fmt.Errorf("Function input field not found: %s", name)
}
if inputField.GetNullable() {
return nil, fmt.Errorf("Function input field cannot be nullable: field %s", inputField.GetName())
}
base.inputFields = append(base.inputFields, inputField)
base.inputFieldIDs = append(base.inputFieldIDs, inputField.FieldID)
}
return &base, nil
}
func (base *RerankBase) GetInputFieldNames() []string {
return base.funcSchema.InputFieldNames
}
func (base *RerankBase) GetInputFieldIDs() []int64 {
return base.inputFieldIDs
}
func (base *RerankBase) IsSupportGroup() bool {
return base.isSupportGroup
}
func (base *RerankBase) GetRankName() string {
return base.rerankerName
}

View File

@ -0,0 +1,86 @@
/*
* # Licensed to the LF AI & Data foundation under one
* # or more contributor license agreements. See the NOTICE file
* # distributed with this work for additional information
* # regarding copyright ownership. The ASF licenses this file
* # to you under the Apache License, Version 2.0 (the
* # "License"); you may not use this file except in compliance
* # with the License. You may obtain a copy of the License at
* #
* # http://www.apache.org/licenses/LICENSE-2.0
* #
* # Unless required by applicable law or agreed to in writing, software
* # distributed under the License is distributed on an "AS IS" BASIS,
* # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* # See the License for the specific language governing permissions and
* # limitations under the License.
*/
package rerank
import (
"fmt"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
)
type milvusIDs[T int64 | string] struct {
data []T
}
func newMilvusIDs(ids *schemapb.IDs, pkType schemapb.DataType) any {
switch pkType {
case schemapb.DataType_Int64:
return milvusIDs[int64]{ids.GetIntId().Data}
case schemapb.DataType_String:
return milvusIDs[string]{ids.GetStrId().Data}
}
return nil
}
type idSocres[T int64 | string] struct {
idScoreMap map[T]float32
}
func newIdScores[T int64 | string]() *idSocres[T] {
return &idSocres[T]{idScoreMap: make(map[T]float32)}
}
func (ids *idSocres[T]) exist(id T) bool {
_, exists := ids.idScoreMap[id]
return exists
}
func (ids *idSocres[T]) set(id T, scores float32) {
ids.idScoreMap[id] = scores
}
func (ids *idSocres[T]) get(id T) float32 {
return ids.idScoreMap[id]
}
type numberField[T int32 | int64 | float32 | float64] struct {
data []T
}
func (n *numberField[T]) GetFloat64(idx int) (float64, error) {
if len(n.data) <= idx {
return 0.0, fmt.Errorf("Get field err, idx:%d is larger than data size:%d", idx, len(n.data))
}
return float64(n.data[idx]), nil
}
func getNumberic(inputField *schemapb.FieldData) (any, error) {
switch inputField.Type {
case schemapb.DataType_Int8, schemapb.DataType_Int16, schemapb.DataType_Int32:
return &numberField[int32]{inputField.GetScalars().GetIntData().Data}, nil
case schemapb.DataType_Int64:
return &numberField[int64]{inputField.GetScalars().GetLongData().Data}, nil
case schemapb.DataType_Float:
return &numberField[float32]{inputField.GetScalars().GetFloatData().Data}, nil
case schemapb.DataType_Double:
return &numberField[float64]{inputField.GetScalars().GetDoubleData().Data}, nil
default:
return nil, fmt.Errorf("Unsupported field type:%s, only support numberic field", inputField.Type.String())
}
}

View File

@ -64,7 +64,7 @@ class TestMilvusClientHybridSearchInvalid(TestMilvusClientV2Base):
ranker = WeightedRanker(0.2, 0.8)
error = {ct.err_code: 100,
ct.err_msg: f"collection not found[database=default][collection={name}]"}
self.hybrid_search(client, name, [sub_search1], ranker, limit=default_limit,
self.hybrid_search(client, name, [sub_search1, sub_search1], ranker, limit=default_limit,
check_task=CheckTasks.err_res, check_items=error)
self.drop_collection(client, collection_name)
@ -88,7 +88,7 @@ class TestMilvusClientHybridSearchInvalid(TestMilvusClientV2Base):
ranker = WeightedRanker(0.2, 0.8)
error = {ct.err_code: 100,
ct.err_msg: f"collection not found[database=default][collection={name}]"}
self.hybrid_search(client, name, [sub_search1], ranker, limit=default_limit,
self.hybrid_search(client, name, [sub_search1, sub_search1], ranker, limit=default_limit,
check_task=CheckTasks.err_res, check_items=error)
self.drop_collection(client, collection_name)
@ -133,7 +133,7 @@ class TestMilvusClientHybridSearchInvalid(TestMilvusClientV2Base):
ranker = WeightedRanker(0.2, 0.8)
error = {ct.err_code: 100,
ct.err_msg: f"collection not found[database=default][collection=1]"}
self.hybrid_search(client, collection_name, [sub_search1], invalid_ranker, limit=default_limit,
self.hybrid_search(client, collection_name, [sub_search1, sub_search1], invalid_ranker, limit=default_limit,
check_task=CheckTasks.err_res, check_items=error)
self.drop_collection(client, collection_name)
@ -156,7 +156,7 @@ class TestMilvusClientHybridSearchInvalid(TestMilvusClientV2Base):
ranker = WeightedRanker(0.2, 0.8)
error = {ct.err_code: 1,
ct.err_msg: f"`limit` value {invalid_limit} is illegal"}
self.hybrid_search(client, collection_name, [sub_search1], ranker, limit=invalid_limit,
self.hybrid_search(client, collection_name, [sub_search1, sub_search1], ranker, limit=invalid_limit,
check_task=CheckTasks.err_res, check_items=error)
self.drop_collection(client, collection_name)
@ -179,7 +179,7 @@ class TestMilvusClientHybridSearchInvalid(TestMilvusClientV2Base):
ranker = WeightedRanker(0.2, 0.8)
error = {ct.err_code: 65535,
ct.err_msg: "invalid max query result window, (offset+limit) should be in range [1, 16384], but got 16385"}
self.hybrid_search(client, collection_name, [sub_search1], ranker, limit=invalid_limit,
self.hybrid_search(client, collection_name, [sub_search1, sub_search1], ranker, limit=invalid_limit,
check_task=CheckTasks.err_res, check_items=error)
self.drop_collection(client, collection_name)
@ -202,7 +202,7 @@ class TestMilvusClientHybridSearchInvalid(TestMilvusClientV2Base):
ranker = WeightedRanker(0.2, 0.8)
error = {ct.err_code: 1,
ct.err_msg: f"`output_fields` value {invalid_output_fields} is illegal"}
self.hybrid_search(client, collection_name, [sub_search1], ranker, limit=default_limit,
self.hybrid_search(client, collection_name, [sub_search1, sub_search1], ranker, limit=default_limit,
output_fields=invalid_output_fields, check_task=CheckTasks.err_res, check_items=error)
self.drop_collection(client, collection_name)
@ -226,7 +226,7 @@ class TestMilvusClientHybridSearchInvalid(TestMilvusClientV2Base):
ranker = WeightedRanker(0.2, 0.8)
error = {ct.err_code: 1,
ct.err_msg: f"`partition_name_array` value {invalid_partition_names} is illegal"}
self.hybrid_search(client, collection_name, [sub_search1], ranker, limit=default_limit,
self.hybrid_search(client, collection_name, [sub_search1, sub_search1], ranker, limit=default_limit,
partition_names=invalid_partition_names, check_task=CheckTasks.err_res, check_items=error)
self.drop_collection(client, collection_name)
@ -249,7 +249,7 @@ class TestMilvusClientHybridSearchInvalid(TestMilvusClientV2Base):
ranker = WeightedRanker(0.2, 0.8)
error = {ct.err_code: 65535,
ct.err_msg: f"partition name {invalid_partition_names} not found"}
self.hybrid_search(client, collection_name, [sub_search1], ranker, limit=default_limit,
self.hybrid_search(client, collection_name, [sub_search1, sub_search1], ranker, limit=default_limit,
partition_names=[invalid_partition_names], check_task=CheckTasks.err_res,
check_items=error)
self.drop_collection(client, collection_name)
@ -274,7 +274,7 @@ class TestMilvusClientHybridSearchInvalid(TestMilvusClientV2Base):
error = {ct.err_code: 1100,
ct.err_msg: f"failed to create query plan: failed to get field schema by name: "
f"fieldName({not_exist_vector_field}) not found: invalid parameter"}
self.hybrid_search(client, collection_name, [sub_search1], ranker, limit=default_limit,
self.hybrid_search(client, collection_name, [sub_search1, sub_search1], ranker, limit=default_limit,
check_task=CheckTasks.err_res, check_items=error)
self.drop_collection(client, collection_name)
@ -485,4 +485,4 @@ class TestMilvusClientHybridSearchValid(TestMilvusClientV2Base):
"nq": len(vectors_to_search),
"ids": insert_ids,
"limit": default_limit})
self.drop_collection(client, collection_name)
self.drop_collection(client, collection_name)