mirror of
https://gitee.com/milvus-io/milvus.git
synced 2025-12-06 17:18:35 +08:00
feat : Support decay rerank (#41223)
https://github.com/milvus-io/milvus/issues/35856 #41312 Signed-off-by: junjie.jiang <junjie.jiang@zilliz.com>
This commit is contained in:
parent
f52c2909c4
commit
f23df95a77
2
go.mod
2
go.mod
@ -21,7 +21,7 @@ require (
|
||||
github.com/grpc-ecosystem/go-grpc-middleware v1.3.0
|
||||
github.com/klauspost/compress v1.17.9
|
||||
github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d
|
||||
github.com/milvus-io/milvus-proto/go-api/v2 v2.5.0-beta.0.20250325034212-6e98baa34971
|
||||
github.com/milvus-io/milvus-proto/go-api/v2 v2.5.0-beta.0.20250407030015-dcf7688ad54a
|
||||
github.com/minio/minio-go/v7 v7.0.73
|
||||
github.com/pingcap/log v1.1.1-0.20221015072633-39906604fb81
|
||||
github.com/prometheus/client_golang v1.14.0
|
||||
|
||||
4
go.sum
4
go.sum
@ -740,6 +740,10 @@ github.com/milvus-io/gorocksdb v0.0.0-20220624081344-8c5f4212846b h1:TfeY0NxYxZz
|
||||
github.com/milvus-io/gorocksdb v0.0.0-20220624081344-8c5f4212846b/go.mod h1:iwW+9cWfIzzDseEBCCeDSN5SD16Tidvy8cwQ7ZY8Qj4=
|
||||
github.com/milvus-io/milvus-proto/go-api/v2 v2.5.0-beta.0.20250325034212-6e98baa34971 h1:CKKrOtri+dbTUkMJehDuSM489OIqJab1t0pUq+PV73E=
|
||||
github.com/milvus-io/milvus-proto/go-api/v2 v2.5.0-beta.0.20250325034212-6e98baa34971/go.mod h1:/6UT4zZl6awVeXLeE7UGDWZvXj3IWkRsh3mqsn0DiAs=
|
||||
github.com/milvus-io/milvus-proto/go-api/v2 v2.5.0-beta.0.20250407030015-dcf7688ad54a h1:W+9nVXKcI9FdiyrFbrs9BIFUqRW0pLY+Fn0fsmmuLyw=
|
||||
github.com/milvus-io/milvus-proto/go-api/v2 v2.5.0-beta.0.20250407030015-dcf7688ad54a/go.mod h1:/6UT4zZl6awVeXLeE7UGDWZvXj3IWkRsh3mqsn0DiAs=
|
||||
github.com/milvus-io/milvus-proto/go-api/v2 v2.5.8 h1:/oUdiYtwVlqiEMSzME7vDvir49Lt23nMpaZC9u22bIo=
|
||||
github.com/milvus-io/milvus-proto/go-api/v2 v2.5.8/go.mod h1:/6UT4zZl6awVeXLeE7UGDWZvXj3IWkRsh3mqsn0DiAs=
|
||||
github.com/milvus-io/pulsar-client-go v0.12.1 h1:O2JZp1tsYiO7C0MQ4hrUY/aJXnn2Gry6hpm7UodghmE=
|
||||
github.com/milvus-io/pulsar-client-go v0.12.1/go.mod h1:dkutuH4oS2pXiGm+Ti7fQZ4MRjrMPZ8IJeEGAWMeckk=
|
||||
github.com/minio/asm2plan9s v0.0.0-20200509001527-cdd76441f9d8 h1:AMFGa4R4MiIpspGNG7Z948v4n35fFGB3RR3G/ry4FWs=
|
||||
|
||||
@ -3168,6 +3168,7 @@ func (node *Proxy) search(ctx context.Context, request *milvuspb.SearchRequest,
|
||||
lb: node.lbPolicy,
|
||||
enableMaterializedView: node.enableMaterializedView,
|
||||
mustUsePartitionKey: Params.ProxyCfg.MustUsePartitionKey.GetAsBool(),
|
||||
requeryFunc: requeryImpl,
|
||||
}
|
||||
|
||||
log := log.Ctx(ctx).With( // TODO: it might cause some cpu consumption
|
||||
@ -3406,6 +3407,7 @@ func (node *Proxy) hybridSearch(ctx context.Context, request *milvuspb.HybridSea
|
||||
node: node,
|
||||
lb: node.lbPolicy,
|
||||
mustUsePartitionKey: Params.ProxyCfg.MustUsePartitionKey.GetAsBool(),
|
||||
requeryFunc: requeryImpl,
|
||||
}
|
||||
|
||||
log := log.Ctx(ctx).With(
|
||||
|
||||
@ -567,6 +567,7 @@ func convertHybridSearchToSearch(req *milvuspb.HybridSearchRequest) *milvuspb.Se
|
||||
UseDefaultConsistency: req.GetUseDefaultConsistency(),
|
||||
SearchByPrimaryKeys: false,
|
||||
SubReqs: nil,
|
||||
FunctionScore: req.FunctionScore,
|
||||
}
|
||||
|
||||
for _, sub := range req.GetRequests() {
|
||||
|
||||
@ -21,6 +21,7 @@ import (
|
||||
"github.com/milvus-io/milvus/internal/types"
|
||||
"github.com/milvus-io/milvus/internal/util/exprutil"
|
||||
"github.com/milvus-io/milvus/internal/util/function"
|
||||
"github.com/milvus-io/milvus/internal/util/function/rerank"
|
||||
"github.com/milvus-io/milvus/internal/util/reduce"
|
||||
"github.com/milvus-io/milvus/pkg/v2/log"
|
||||
"github.com/milvus-io/milvus/pkg/v2/metrics"
|
||||
@ -50,6 +51,8 @@ const (
|
||||
rangeFilterKey = "range_filter"
|
||||
)
|
||||
|
||||
// type requery func(span trace.Span, ids *schemapb.IDs, outputFields []string) (*milvuspb.QueryResults, error)
|
||||
|
||||
type searchTask struct {
|
||||
baseTask
|
||||
Condition
|
||||
@ -62,7 +65,7 @@ type searchTask struct {
|
||||
tr *timerecord.TimeRecorder
|
||||
collectionName string
|
||||
schema *schemaInfo
|
||||
requery bool
|
||||
needRequery bool
|
||||
partitionKeyMode bool
|
||||
enableMaterializedView bool
|
||||
mustUsePartitionKey bool
|
||||
@ -85,14 +88,21 @@ type searchTask struct {
|
||||
queryInfos []*planpb.QueryInfo
|
||||
relatedDataSize int64
|
||||
|
||||
// Will be deprecated, use functionScore after milvus 2.6
|
||||
reScorers []reScorer
|
||||
rankParams *rankParams
|
||||
groupScorer func(group *Group) error
|
||||
|
||||
// New reranker functions
|
||||
functionScore *rerank.FunctionScore
|
||||
rankParams *rankParams
|
||||
|
||||
isIterator bool
|
||||
// we always remove pk field from output fields, as search result already contains pk field.
|
||||
// if the user explicitly set pk field in output fields, we add it back to the result.
|
||||
userRequestedPkFieldExplicitly bool
|
||||
|
||||
// To facilitate writing unit tests
|
||||
requeryFunc func(t *searchTask, span trace.Span, ids *schemapb.IDs, outputFields []string) (*milvuspb.QueryResults, error)
|
||||
}
|
||||
|
||||
func (t *searchTask) CanSkipAllocTimestamp() bool {
|
||||
@ -126,6 +136,7 @@ func (t *searchTask) CanSkipAllocTimestamp() bool {
|
||||
func (t *searchTask) PreExecute(ctx context.Context) error {
|
||||
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-Search-PreExecute")
|
||||
defer sp.End()
|
||||
|
||||
t.SearchRequest.IsAdvanced = len(t.request.GetSubReqs()) > 0
|
||||
t.Base.MsgType = commonpb.MsgType_Search
|
||||
t.Base.SourceID = paramtable.GetNodeID()
|
||||
@ -181,16 +192,7 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
|
||||
return errors.New(fmt.Sprintf("maximum of ann search requests is %d", defaultMaxSearchRequest))
|
||||
}
|
||||
}
|
||||
if t.SearchRequest.GetIsAdvanced() {
|
||||
t.rankParams, err = parseRankParams(t.request.GetSearchParams(), t.schema.CollectionSchema)
|
||||
if err != nil {
|
||||
log.Info("parseRankParams failed", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
t.rankParams = nil
|
||||
}
|
||||
// Manually update nq if not set.
|
||||
|
||||
nq, err := t.checkNq(ctx)
|
||||
if err != nil {
|
||||
log.Info("failed to check nq", zap.Error(err))
|
||||
@ -211,15 +213,9 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
|
||||
|
||||
// Currently, we get vectors by requery. Once we support getting vectors from search,
|
||||
// searches with small result size could no longer need requery.
|
||||
vectorOutputFields := lo.Filter(t.schema.GetFields(), func(field *schemapb.FieldSchema, _ int) bool {
|
||||
return lo.Contains(t.translatedOutputFields, field.GetName()) && typeutil.IsVectorType(field.GetDataType())
|
||||
})
|
||||
|
||||
if t.SearchRequest.GetIsAdvanced() {
|
||||
t.requery = len(t.translatedOutputFields) > 0
|
||||
err = t.initAdvancedSearchRequest(ctx)
|
||||
} else {
|
||||
t.requery = len(vectorOutputFields) > 0
|
||||
err = t.initSearchRequest(ctx)
|
||||
}
|
||||
if err != nil {
|
||||
@ -363,8 +359,43 @@ func (t *searchTask) initAdvancedSearchRequest(ctx context.Context) error {
|
||||
defer sp.End()
|
||||
t.partitionIDsSet = typeutil.NewConcurrentSet[UniqueID]()
|
||||
log := log.Ctx(ctx).With(zap.Int64("collID", t.GetCollectionID()), zap.String("collName", t.collectionName))
|
||||
var err error
|
||||
// TODO: Use function score uniformly to implement related logic
|
||||
if t.request.FunctionScore != nil {
|
||||
if t.functionScore, err = rerank.NewFunctionScore(t.schema.CollectionSchema, t.request.FunctionScore); err != nil {
|
||||
log.Warn("Failed to create function score", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
t.reScorers, err = NewReScorers(ctx, len(t.request.GetSubReqs()), t.request.GetSearchParams())
|
||||
if err != nil {
|
||||
log.Info("generate reScorer failed", zap.Any("params", t.request.GetSearchParams()), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
// set up groupScorer for hybridsearch+groupBy
|
||||
groupScorerStr, err := funcutil.GetAttrByKeyFromRepeatedKV(RankGroupScorer, t.request.GetSearchParams())
|
||||
if err != nil {
|
||||
groupScorerStr = MaxScorer
|
||||
}
|
||||
groupScorer, err := GetGroupScorer(groupScorerStr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
t.groupScorer = groupScorer
|
||||
}
|
||||
|
||||
t.needRequery = len(t.request.OutputFields) > 0 || len(t.functionScore.GetAllInputFieldNames()) > 0
|
||||
|
||||
if t.rankParams, err = parseRankParams(t.request.GetSearchParams(), t.schema.CollectionSchema); err != nil {
|
||||
log.Error("parseRankParams failed", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
if !t.functionScore.IsSupportGroup() && t.rankParams.GetGroupByFieldId() >= 0 {
|
||||
return merr.WrapErrParameterInvalidMsg("Current rerank does not support grouping search")
|
||||
}
|
||||
|
||||
// fetch search_growing from search param
|
||||
t.SearchRequest.SubReqs = make([]*internalpb.SubSearchRequest, len(t.request.GetSubReqs()))
|
||||
t.queryInfos = make([]*planpb.QueryInfo, len(t.request.GetSubReqs()))
|
||||
queryFieldIDs := []int64{}
|
||||
@ -418,13 +449,8 @@ func (t *searchTask) initAdvancedSearchRequest(ctx context.Context) error {
|
||||
internalSubReq.PartitionIDs = t.SearchRequest.GetPartitionIDs()
|
||||
}
|
||||
|
||||
if t.requery {
|
||||
plan.OutputFieldIds = nil
|
||||
plan.DynamicFields = nil
|
||||
} else {
|
||||
plan.OutputFieldIds = t.SearchRequest.OutputFieldsId
|
||||
plan.DynamicFields = t.userDynamicFields
|
||||
}
|
||||
plan.OutputFieldIds = nil
|
||||
plan.DynamicFields = nil
|
||||
|
||||
internalSubReq.SerializedExprPlan, err = proto.Marshal(plan)
|
||||
if err != nil {
|
||||
@ -440,7 +466,6 @@ func (t *searchTask) initAdvancedSearchRequest(ctx context.Context) error {
|
||||
zap.Stringer("plan", plan)) // may be very large if large term passed.
|
||||
}
|
||||
|
||||
var err error
|
||||
if function.HasNonBM25Functions(t.schema.CollectionSchema.Functions, queryFieldIDs) {
|
||||
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-AdvancedSearch-call-function-udf")
|
||||
defer sp.End()
|
||||
@ -463,38 +488,244 @@ func (t *searchTask) initAdvancedSearchRequest(ctx context.Context) error {
|
||||
t.SearchRequest.PartitionIDs = t.partitionIDsSet.Collect()
|
||||
}
|
||||
|
||||
t.reScorers, err = NewReScorers(ctx, len(t.request.GetSubReqs()), t.request.GetSearchParams())
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *searchTask) advancedPostProcess(ctx context.Context, span trace.Span, toReduceResults []*internalpb.SearchResults) error {
|
||||
// Collecting the results of a subsearch
|
||||
// [[shard1, shard2, ...],[shard1, shard2, ...]]
|
||||
multipleInternalResults := make([][]*internalpb.SearchResults, len(t.SearchRequest.GetSubReqs()))
|
||||
for _, searchResult := range toReduceResults {
|
||||
// if get a non-advanced result, skip all
|
||||
if !searchResult.GetIsAdvanced() {
|
||||
continue
|
||||
}
|
||||
for _, subResult := range searchResult.GetSubResults() {
|
||||
// swallow copy
|
||||
internalResults := &internalpb.SearchResults{
|
||||
MetricType: subResult.GetMetricType(),
|
||||
NumQueries: subResult.GetNumQueries(),
|
||||
TopK: subResult.GetTopK(),
|
||||
SlicedBlob: subResult.GetSlicedBlob(),
|
||||
SlicedNumCount: subResult.GetSlicedNumCount(),
|
||||
SlicedOffset: subResult.GetSlicedOffset(),
|
||||
IsAdvanced: false,
|
||||
}
|
||||
reqIndex := subResult.GetReqIndex()
|
||||
multipleInternalResults[reqIndex] = append(multipleInternalResults[reqIndex], internalResults)
|
||||
}
|
||||
}
|
||||
|
||||
multipleMilvusResults := make([]*milvuspb.SearchResults, len(t.SearchRequest.GetSubReqs()))
|
||||
for index, internalResults := range multipleInternalResults {
|
||||
subReq := t.SearchRequest.GetSubReqs()[index]
|
||||
// Since the metrictype in the request may be empty, it can only be obtained from the result
|
||||
subMetricType := getMetricType(internalResults)
|
||||
result, err := t.reduceResults(t.ctx, internalResults, subReq.GetNq(), subReq.GetTopk(), subReq.GetOffset(), subMetricType, t.queryInfos[index], true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if t.functionScore == nil {
|
||||
t.reScorers[index].setMetricType(subMetricType)
|
||||
t.reScorers[index].reScore(result)
|
||||
}
|
||||
multipleMilvusResults[index] = result
|
||||
}
|
||||
|
||||
if t.functionScore == nil {
|
||||
if err := t.rank(ctx, span, multipleMilvusResults); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if err := t.hybridSearchRank(ctx, span, multipleMilvusResults); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
t.result.Results.FieldsData = lo.Filter(t.result.Results.FieldsData, func(fieldData *schemapb.FieldData, i int) bool {
|
||||
return lo.Contains(t.translatedOutputFields, fieldData.GetFieldName())
|
||||
})
|
||||
t.fillResult()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *searchTask) fillResult() {
|
||||
limit := t.SearchRequest.GetTopk() - t.SearchRequest.GetOffset()
|
||||
resultSizeInsufficient := false
|
||||
for _, topk := range t.result.Results.Topks {
|
||||
if topk < limit {
|
||||
resultSizeInsufficient = true
|
||||
break
|
||||
}
|
||||
}
|
||||
t.resultSizeInsufficient = resultSizeInsufficient
|
||||
t.result.CollectionName = t.collectionName
|
||||
t.fillInFieldInfo()
|
||||
}
|
||||
|
||||
// TODO: Old version rerank: rrf/weighted, subsequent unified rerank implementation
|
||||
func (t *searchTask) rank(ctx context.Context, span trace.Span, multipleMilvusResults []*milvuspb.SearchResults) error {
|
||||
primaryFieldSchema, err := t.schema.GetPkField()
|
||||
if err != nil {
|
||||
log.Info("generate reScorer failed", zap.Any("params", t.request.GetSearchParams()), zap.Error(err))
|
||||
log.Warn("failed to get primary field schema", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
if t.result, err = rankSearchResultData(ctx, t.SearchRequest.GetNq(),
|
||||
t.rankParams,
|
||||
primaryFieldSchema.GetDataType(),
|
||||
multipleMilvusResults,
|
||||
t.SearchRequest.GetGroupByFieldId(),
|
||||
t.SearchRequest.GetGroupSize(),
|
||||
t.groupScorer); err != nil {
|
||||
log.Warn("rank search result failed", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
// set up groupScorer for hybridsearch+groupBy
|
||||
groupScorerStr, err := funcutil.GetAttrByKeyFromRepeatedKV(RankGroupScorer, t.request.GetSearchParams())
|
||||
if err != nil {
|
||||
groupScorerStr = MaxScorer
|
||||
if t.needRequery {
|
||||
if t.requeryFunc == nil {
|
||||
t.requeryFunc = requeryImpl
|
||||
}
|
||||
queryResult, err := t.requeryFunc(t, span, t.result.Results.Ids, t.translatedOutputFields)
|
||||
if err != nil {
|
||||
log.Warn("failed to requery", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
fields, err := t.reorganizeRequeryResults(ctx, queryResult, []*schemapb.IDs{t.result.Results.Ids})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
t.result.Results.FieldsData = fields[0]
|
||||
}
|
||||
groupScorer, err := GetGroupScorer(groupScorerStr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
t.groupScorer = groupScorer
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func mergeIDs(idsList []*schemapb.IDs) (*schemapb.IDs, int) {
|
||||
uniqueIDs := &schemapb.IDs{}
|
||||
count := 0
|
||||
switch idsList[0].GetIdField().(type) {
|
||||
case *schemapb.IDs_IntId:
|
||||
idsSet := typeutil.NewSet[int64]()
|
||||
for _, ids := range idsList {
|
||||
if data := ids.GetIntId().GetData(); data != nil {
|
||||
idsSet.Insert(data...)
|
||||
}
|
||||
}
|
||||
count = idsSet.Len()
|
||||
uniqueIDs.IdField = &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
Data: idsSet.Collect(),
|
||||
},
|
||||
}
|
||||
case *schemapb.IDs_StrId:
|
||||
idsSet := typeutil.NewSet[string]()
|
||||
for _, ids := range idsList {
|
||||
if data := ids.GetStrId().GetData(); data != nil {
|
||||
idsSet.Insert(data...)
|
||||
}
|
||||
}
|
||||
count = idsSet.Len()
|
||||
uniqueIDs.IdField = &schemapb.IDs_StrId{
|
||||
StrId: &schemapb.StringArray{
|
||||
Data: idsSet.Collect(),
|
||||
},
|
||||
}
|
||||
}
|
||||
return uniqueIDs, count
|
||||
}
|
||||
|
||||
func (t *searchTask) hybridSearchRank(ctx context.Context, span trace.Span, multipleMilvusResults []*milvuspb.SearchResults) error {
|
||||
var err error
|
||||
// The first step of hybrid search is without meta information. If rerank requires meta data, we need to do requery.
|
||||
// At this time, outputFields and rerank input_fields will be recalled.
|
||||
// If we want to save memory, we can only recall the rerank input_fields in this step, and recall the output_fields in the third step
|
||||
if t.needRequery {
|
||||
idsList := lo.FilterMap(multipleMilvusResults, func(m *milvuspb.SearchResults, _ int) (*schemapb.IDs, bool) {
|
||||
return m.Results.Ids, true
|
||||
})
|
||||
allIDs, count := mergeIDs(idsList)
|
||||
if count == 0 {
|
||||
t.result = &milvuspb.SearchResults{
|
||||
Status: merr.Success(),
|
||||
Results: &schemapb.SearchResultData{
|
||||
NumQueries: t.Nq,
|
||||
TopK: t.rankParams.limit,
|
||||
FieldsData: make([]*schemapb.FieldData, 0),
|
||||
Scores: []float32{},
|
||||
Ids: &schemapb.IDs{},
|
||||
Topks: []int64{},
|
||||
},
|
||||
}
|
||||
return nil
|
||||
}
|
||||
allNames := typeutil.NewSet[string](t.translatedOutputFields...)
|
||||
allNames.Insert(t.functionScore.GetAllInputFieldNames()...)
|
||||
queryResult, err := t.requeryFunc(t, span, allIDs, allNames.Collect())
|
||||
if err != nil {
|
||||
log.Warn("failed to requery", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
fields, err := t.reorganizeRequeryResults(ctx, queryResult, idsList)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for i := 0; i < len(multipleMilvusResults); i++ {
|
||||
multipleMilvusResults[i].Results.FieldsData = fields[i]
|
||||
}
|
||||
params := rerank.NewSearchParams(
|
||||
t.Nq, t.rankParams.limit, t.rankParams.offset, t.rankParams.roundDecimal,
|
||||
t.rankParams.groupByFieldId, t.rankParams.groupSize, t.rankParams.strictGroupSize,
|
||||
)
|
||||
|
||||
if t.result, err = t.functionScore.Process(ctx, params, multipleMilvusResults); err != nil {
|
||||
return err
|
||||
}
|
||||
if fields, err := t.reorganizeRequeryResults(ctx, queryResult, []*schemapb.IDs{t.result.Results.Ids}); err != nil {
|
||||
return err
|
||||
} else {
|
||||
t.result.Results.FieldsData = fields[0]
|
||||
}
|
||||
} else {
|
||||
params := rerank.NewSearchParams(
|
||||
t.Nq, t.rankParams.limit, t.rankParams.offset, t.rankParams.roundDecimal,
|
||||
t.rankParams.groupByFieldId, t.rankParams.groupSize, t.rankParams.strictGroupSize,
|
||||
)
|
||||
if t.result, err = t.functionScore.Process(ctx, params, multipleMilvusResults); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *searchTask) initSearchRequest(ctx context.Context) error {
|
||||
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "init search request")
|
||||
defer sp.End()
|
||||
|
||||
log := log.Ctx(ctx).With(zap.Int64("collID", t.GetCollectionID()), zap.String("collName", t.collectionName))
|
||||
// fetch search_growing from search param
|
||||
|
||||
plan, queryInfo, offset, isIterator, err := t.tryGeneratePlan(t.request.GetSearchParams(), t.request.GetDsl(), t.request.GetExprTemplateValues())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if t.request.FunctionScore != nil {
|
||||
// TODO: When rerank is configured, range search is also supported
|
||||
if isIterator {
|
||||
return merr.WrapErrParameterInvalidMsg("Range search do not support rerank")
|
||||
}
|
||||
|
||||
if t.functionScore, err = rerank.NewFunctionScore(t.schema.CollectionSchema, t.request.FunctionScore); err != nil {
|
||||
log.Warn("Failed to create function score", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO: When rerank is configured, grouping search is also supported
|
||||
if !t.functionScore.IsSupportGroup() && queryInfo.GetGroupByFieldId() > 0 {
|
||||
return merr.WrapErrParameterInvalidMsg("Current rerank does not support grouping search")
|
||||
}
|
||||
}
|
||||
|
||||
t.isIterator = isIterator
|
||||
t.SearchRequest.Offset = offset
|
||||
t.SearchRequest.FieldId = queryInfo.GetQueryFieldId()
|
||||
@ -514,10 +745,16 @@ func (t *searchTask) initSearchRequest(ctx context.Context) error {
|
||||
}
|
||||
}
|
||||
|
||||
if t.requery {
|
||||
plan.OutputFieldIds = nil
|
||||
vectorOutputFields := lo.Filter(t.schema.GetFields(), func(field *schemapb.FieldSchema, _ int) bool {
|
||||
return lo.Contains(t.translatedOutputFields, field.GetName()) && typeutil.IsVectorType(field.GetDataType())
|
||||
})
|
||||
t.needRequery = len(vectorOutputFields) > 0
|
||||
if t.needRequery {
|
||||
plan.OutputFieldIds = t.functionScore.GetAllInputFieldIDs()
|
||||
} else {
|
||||
plan.OutputFieldIds = t.SearchRequest.OutputFieldsId
|
||||
allFieldIDs := typeutil.NewSet[int64](t.SearchRequest.OutputFieldsId...)
|
||||
allFieldIDs.Insert(t.functionScore.GetAllInputFieldIDs()...)
|
||||
plan.OutputFieldIds = allFieldIDs.Collect()
|
||||
plan.DynamicFields = t.userDynamicFields
|
||||
}
|
||||
|
||||
@ -556,6 +793,7 @@ func (t *searchTask) initSearchRequest(ctx context.Context) error {
|
||||
}
|
||||
sp.AddEvent("Call-function-udf")
|
||||
}
|
||||
|
||||
log.Debug("proxy init search request",
|
||||
zap.Int64s("plan.OutputFieldIds", plan.GetOutputFieldIds()),
|
||||
zap.Stringer("plan", plan)) // may be very large if large term passed.
|
||||
@ -563,6 +801,45 @@ func (t *searchTask) initSearchRequest(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *searchTask) searchPostProcess(ctx context.Context, span trace.Span, toReduceResults []*internalpb.SearchResults) error {
|
||||
metricType := getMetricType(toReduceResults)
|
||||
result, err := t.reduceResults(t.ctx, toReduceResults, t.SearchRequest.GetNq(), t.SearchRequest.GetTopk(), t.SearchRequest.GetOffset(), metricType, t.queryInfos[0], false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if t.functionScore != nil && len(result.Results.FieldsData) != 0 {
|
||||
params := rerank.NewSearchParams(t.Nq, t.SearchRequest.GetTopk(), t.SearchRequest.GetOffset(),
|
||||
t.queryInfos[0].RoundDecimal, t.queryInfos[0].GroupByFieldId, t.queryInfos[0].GroupSize, t.queryInfos[0].StrictGroupSize)
|
||||
if t.result, err = t.functionScore.Process(ctx, params, []*milvuspb.SearchResults{result}); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
t.result = result
|
||||
}
|
||||
|
||||
t.fillResult()
|
||||
if t.needRequery {
|
||||
if t.requeryFunc == nil {
|
||||
t.requeryFunc = requeryImpl
|
||||
}
|
||||
queryResult, err := t.requeryFunc(t, span, t.result.Results.Ids, t.translatedOutputFields)
|
||||
if err != nil {
|
||||
log.Warn("failed to requery", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
fields, err := t.reorganizeRequeryResults(ctx, queryResult, []*schemapb.IDs{t.result.Results.Ids})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
t.result.Results.FieldsData = fields[0]
|
||||
}
|
||||
t.result.Results.FieldsData = lo.Filter(t.result.Results.FieldsData, func(fieldData *schemapb.FieldData, i int) bool {
|
||||
return lo.Contains(t.translatedOutputFields, fieldData.GetFieldName())
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *searchTask) tryGeneratePlan(params []*commonpb.KeyValuePair, dsl string, exprTemplateValues map[string]*schemapb.TemplateValue) (*planpb.PlanNode, *planpb.QueryInfo, int64, bool, error) {
|
||||
annsFieldName, err := funcutil.GetAttrByKeyFromRepeatedKV(AnnsFieldKey, params)
|
||||
if err != nil || len(annsFieldName) == 0 {
|
||||
@ -753,90 +1030,22 @@ func (t *searchTask) PostExecute(ctx context.Context) error {
|
||||
}
|
||||
}
|
||||
|
||||
primaryFieldSchema, err := t.schema.GetPkField()
|
||||
if err != nil {
|
||||
log.Warn("failed to get primary field schema", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
metricType := getMetricType(toReduceResults)
|
||||
// reduce
|
||||
if t.SearchRequest.GetIsAdvanced() {
|
||||
multipleInternalResults := make([][]*internalpb.SearchResults, len(t.SearchRequest.GetSubReqs()))
|
||||
for _, searchResult := range toReduceResults {
|
||||
// if get a non-advanced result, skip all
|
||||
if !searchResult.GetIsAdvanced() {
|
||||
continue
|
||||
}
|
||||
for _, subResult := range searchResult.GetSubResults() {
|
||||
// swallow copy
|
||||
internalResults := &internalpb.SearchResults{
|
||||
MetricType: subResult.GetMetricType(),
|
||||
NumQueries: subResult.GetNumQueries(),
|
||||
TopK: subResult.GetTopK(),
|
||||
SlicedBlob: subResult.GetSlicedBlob(),
|
||||
SlicedNumCount: subResult.GetSlicedNumCount(),
|
||||
SlicedOffset: subResult.GetSlicedOffset(),
|
||||
IsAdvanced: false,
|
||||
}
|
||||
reqIndex := subResult.GetReqIndex()
|
||||
multipleInternalResults[reqIndex] = append(multipleInternalResults[reqIndex], internalResults)
|
||||
}
|
||||
}
|
||||
multipleMilvusResults := make([]*milvuspb.SearchResults, len(t.SearchRequest.GetSubReqs()))
|
||||
for index, internalResults := range multipleInternalResults {
|
||||
subReq := t.SearchRequest.GetSubReqs()[index]
|
||||
subMetricType := getMetricType(internalResults)
|
||||
result, err := t.reduceResults(t.ctx, internalResults, subReq.GetNq(), subReq.GetTopk(), subReq.GetOffset(), subMetricType, t.queryInfos[index], true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
t.reScorers[index].setMetricType(subMetricType)
|
||||
t.reScorers[index].reScore(result)
|
||||
multipleMilvusResults[index] = result
|
||||
}
|
||||
t.result, err = rankSearchResultData(ctx, t.SearchRequest.GetNq(),
|
||||
t.rankParams,
|
||||
primaryFieldSchema.GetDataType(),
|
||||
multipleMilvusResults,
|
||||
t.SearchRequest.GetGroupByFieldId(),
|
||||
t.SearchRequest.GetGroupSize(),
|
||||
t.groupScorer)
|
||||
if err != nil {
|
||||
log.Warn("rank search result failed", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
t.result, err = t.reduceResults(t.ctx, toReduceResults, t.SearchRequest.GetNq(), t.SearchRequest.GetTopk(), t.SearchRequest.GetOffset(), metricType, t.queryInfos[0], false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// reduce done, get final result
|
||||
limit := t.SearchRequest.GetTopk() - t.SearchRequest.GetOffset()
|
||||
resultSizeInsufficient := false
|
||||
for _, topk := range t.result.Results.Topks {
|
||||
if topk < limit {
|
||||
resultSizeInsufficient = true
|
||||
break
|
||||
}
|
||||
}
|
||||
t.resultSizeInsufficient = resultSizeInsufficient
|
||||
t.isTopkReduce = isTopkReduce
|
||||
t.isRecallEvaluation = isRecallEvaluation
|
||||
t.result.CollectionName = t.collectionName
|
||||
t.fillInFieldInfo()
|
||||
|
||||
if t.requery {
|
||||
err = t.Requery(sp)
|
||||
if err != nil {
|
||||
log.Warn("failed to requery", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
if t.SearchRequest.GetIsAdvanced() {
|
||||
err = t.advancedPostProcess(ctx, sp, toReduceResults)
|
||||
} else {
|
||||
err = t.searchPostProcess(ctx, sp, toReduceResults)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
t.result.Results.OutputFields = t.userOutputFields
|
||||
t.result.CollectionName = t.request.GetCollectionName()
|
||||
|
||||
primaryFieldSchema, _ := t.schema.GetPkField()
|
||||
if t.userRequestedPkFieldExplicitly {
|
||||
t.result.Results.OutputFields = append(t.result.Results.OutputFields, primaryFieldSchema.GetName())
|
||||
var scalars *schemapb.ScalarField
|
||||
@ -869,7 +1078,7 @@ func (t *searchTask) PostExecute(ctx context.Context) error {
|
||||
if iterInfo := t.queryInfos[0].GetSearchIteratorV2Info(); iterInfo != nil {
|
||||
t.result.Results.SearchIteratorV2Results = &schemapb.SearchIteratorV2Results{
|
||||
Token: iterInfo.GetToken(),
|
||||
LastBound: getLastBound(t.result, iterInfo.LastBound, metricType),
|
||||
LastBound: getLastBound(t.result, iterInfo.LastBound, getMetricType(toReduceResults)),
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -950,7 +1159,7 @@ func (t *searchTask) estimateResultSize(nq int64, topK int64) (int64, error) {
|
||||
//return int64(sizePerRecord) * nq * topK, nil
|
||||
}
|
||||
|
||||
func (t *searchTask) Requery(span trace.Span) error {
|
||||
func requeryImpl(t *searchTask, span trace.Span, ids *schemapb.IDs, outputFields []string) (*milvuspb.QueryResults, error) {
|
||||
queryReq := &milvuspb.QueryRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_Retrieve,
|
||||
@ -961,16 +1170,16 @@ func (t *searchTask) Requery(span trace.Span) error {
|
||||
ConsistencyLevel: t.SearchRequest.GetConsistencyLevel(),
|
||||
NotReturnAllMeta: t.request.GetNotReturnAllMeta(),
|
||||
Expr: "",
|
||||
OutputFields: t.translatedOutputFields,
|
||||
OutputFields: outputFields,
|
||||
PartitionNames: t.request.GetPartitionNames(),
|
||||
UseDefaultConsistency: false,
|
||||
GuaranteeTimestamp: t.SearchRequest.GuaranteeTimestamp,
|
||||
}
|
||||
pkField, err := typeutil.GetPrimaryFieldSchema(t.schema.CollectionSchema)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
ids := t.result.GetResults().GetIds()
|
||||
|
||||
plan := planparserv2.CreateRequeryPlan(pkField, ids)
|
||||
channelsMvcc := make(map[string]Timestamp)
|
||||
for k, v := range t.queryChannelsTs {
|
||||
@ -998,11 +1207,49 @@ func (t *searchTask) Requery(span trace.Span) error {
|
||||
}
|
||||
queryResult, err := t.node.(*Proxy).query(t.ctx, qt, span)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
if queryResult.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
|
||||
return merr.Error(queryResult.GetStatus())
|
||||
return nil, merr.Error(queryResult.GetStatus())
|
||||
}
|
||||
return queryResult, err
|
||||
}
|
||||
|
||||
func (t *searchTask) reorganizeRequeryResults(ctx context.Context, queryResult *milvuspb.QueryResults, idsList []*schemapb.IDs) ([][]*schemapb.FieldData, error) {
|
||||
_, sp := otel.Tracer(typeutil.ProxyRole).Start(t.ctx, "reorganizeRequeryResults")
|
||||
defer sp.End()
|
||||
|
||||
pkField, err := typeutil.GetPrimaryFieldSchema(t.schema.CollectionSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
pkFieldData, err := typeutil.GetPrimaryFieldData(queryResult.GetFieldsData(), pkField)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
offsets := make(map[any]int)
|
||||
for i := 0; i < typeutil.GetPKSize(pkFieldData); i++ {
|
||||
pk := typeutil.GetData(pkFieldData, i)
|
||||
offsets[pk] = i
|
||||
}
|
||||
|
||||
allFieldData := make([][]*schemapb.FieldData, len(idsList))
|
||||
for idx, ids := range idsList {
|
||||
if ids == nil {
|
||||
allFieldData[idx] = []*schemapb.FieldData{}
|
||||
continue
|
||||
}
|
||||
if fieldData, err := t.pickFieldData(ids, offsets, queryResult); err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
allFieldData[idx] = fieldData
|
||||
}
|
||||
}
|
||||
return allFieldData, nil
|
||||
}
|
||||
|
||||
// pick field data from query results
|
||||
func (t *searchTask) pickFieldData(ids *schemapb.IDs, pkOffset map[any]int, queryResult *milvuspb.QueryResults) ([]*schemapb.FieldData, error) {
|
||||
// Reorganize Results. The order of query result ids will be altered and differ from queried ids.
|
||||
// We should reorganize query results to keep the order of original queried ids. For example:
|
||||
// ===========================================
|
||||
@ -1018,44 +1265,26 @@ func (t *searchTask) Requery(span trace.Span) error {
|
||||
// 3 2 5 4 1 (result ids)
|
||||
// v3 v2 v5 v4 v1 (result vectors)
|
||||
// ===========================================
|
||||
_, sp := otel.Tracer(typeutil.ProxyRole).Start(t.ctx, "reorganizeRequeryResults")
|
||||
defer sp.End()
|
||||
pkFieldData, err := typeutil.GetPrimaryFieldData(queryResult.GetFieldsData(), pkField)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
offsets := make(map[any]int)
|
||||
for i := 0; i < typeutil.GetPKSize(pkFieldData); i++ {
|
||||
pk := typeutil.GetData(pkFieldData, i)
|
||||
offsets[pk] = i
|
||||
}
|
||||
|
||||
t.result.Results.FieldsData = make([]*schemapb.FieldData, len(queryResult.GetFieldsData()))
|
||||
fieldsData := make([]*schemapb.FieldData, len(queryResult.GetFieldsData()))
|
||||
for i := 0; i < typeutil.GetSizeOfIDs(ids); i++ {
|
||||
id := typeutil.GetPK(ids, int64(i))
|
||||
if _, ok := offsets[id]; !ok {
|
||||
return merr.WrapErrInconsistentRequery(fmt.Sprintf("incomplete query result, missing id %v, len(searchIDs) = %d, len(queryIDs) = %d, collection=%d",
|
||||
id, typeutil.GetSizeOfIDs(ids), len(offsets), t.GetCollectionID()))
|
||||
if _, ok := pkOffset[id]; !ok {
|
||||
return nil, merr.WrapErrInconsistentRequery(fmt.Sprintf("incomplete query result, missing id %s, len(searchIDs) = %d, len(queryIDs) = %d, collection=%d",
|
||||
id, typeutil.GetSizeOfIDs(ids), len(pkOffset), t.GetCollectionID()))
|
||||
}
|
||||
typeutil.AppendFieldData(t.result.Results.FieldsData, queryResult.GetFieldsData(), int64(offsets[id]))
|
||||
typeutil.AppendFieldData(fieldsData, queryResult.GetFieldsData(), int64(pkOffset[id]))
|
||||
}
|
||||
|
||||
t.result.Results.FieldsData = lo.Filter(t.result.Results.FieldsData, func(fieldData *schemapb.FieldData, i int) bool {
|
||||
return lo.Contains(t.translatedOutputFields, fieldData.GetFieldName())
|
||||
})
|
||||
return nil
|
||||
return fieldsData, nil
|
||||
}
|
||||
|
||||
func (t *searchTask) fillInFieldInfo() {
|
||||
if len(t.translatedOutputFields) != 0 && len(t.result.Results.FieldsData) != 0 {
|
||||
for i, name := range t.translatedOutputFields {
|
||||
for _, field := range t.schema.Fields {
|
||||
if t.result.Results.FieldsData[i] != nil && field.Name == name {
|
||||
t.result.Results.FieldsData[i].FieldName = field.Name
|
||||
t.result.Results.FieldsData[i].FieldId = field.FieldID
|
||||
t.result.Results.FieldsData[i].Type = field.DataType
|
||||
t.result.Results.FieldsData[i].IsDynamic = field.IsDynamic
|
||||
}
|
||||
for _, retField := range t.result.Results.FieldsData {
|
||||
for _, schemaField := range t.schema.Fields {
|
||||
if retField != nil && retField.FieldId == schemaField.FieldID {
|
||||
retField.FieldName = schemaField.Name
|
||||
retField.Type = schemaField.DataType
|
||||
retField.IsDynamic = schemaField.IsDynamic
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -19,6 +19,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
@ -31,6 +32,7 @@ import (
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/protobuf/proto"
|
||||
|
||||
@ -50,6 +52,7 @@ import (
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/metric"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/testutils"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/timerecord"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
|
||||
)
|
||||
@ -218,6 +221,375 @@ func TestSearchTask_PostExecute(t *testing.T) {
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
getSearchTaskWithRerank := func(t *testing.T, collName string, funcInput string) *searchTask {
|
||||
functionSchema := &schemapb.FunctionSchema{
|
||||
Name: "test",
|
||||
Type: schemapb.FunctionType_Rerank,
|
||||
InputFieldNames: []string{funcInput},
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: "reranker", Value: "decay"},
|
||||
{Key: "origin", Value: "4"},
|
||||
{Key: "scale", Value: "4"},
|
||||
{Key: "offset", Value: "4"},
|
||||
{Key: "decay", Value: "0.5"},
|
||||
{Key: "function", Value: "gauss"},
|
||||
},
|
||||
}
|
||||
task := &searchTask{
|
||||
ctx: ctx,
|
||||
collectionName: collName,
|
||||
SearchRequest: &internalpb.SearchRequest{
|
||||
IsTopkReduce: true,
|
||||
},
|
||||
request: &milvuspb.SearchRequest{
|
||||
CollectionName: collName,
|
||||
Nq: 1,
|
||||
SearchParams: getBaseSearchParams(),
|
||||
FunctionScore: &schemapb.FunctionScore{
|
||||
Functions: []*schemapb.FunctionSchema{functionSchema},
|
||||
},
|
||||
},
|
||||
mixCoord: qc,
|
||||
tr: timerecord.NewTimeRecorder("test-search"),
|
||||
}
|
||||
require.NoError(t, task.OnEnqueue())
|
||||
return task
|
||||
}
|
||||
|
||||
t.Run("Test empty result with rerank", func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
collName := "test_collection_empty_result_with_rerank" + funcutil.GenRandomStr()
|
||||
createCollWithFields(t, collName, qc)
|
||||
qt := getSearchTaskWithRerank(t, collName, testFloatField)
|
||||
err = qt.PreExecute(ctx)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.NotNil(t, qt.resultBuf)
|
||||
qt.resultBuf.Insert(&internalpb.SearchResults{})
|
||||
err := qt.PostExecute(context.TODO())
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, qt.resultSizeInsufficient, true)
|
||||
assert.Equal(t, qt.isTopkReduce, false)
|
||||
})
|
||||
|
||||
t.Run("Test search rerank", func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
collName := "test_collection_empty_result_with_rerank" + funcutil.GenRandomStr()
|
||||
_, fieldNameId := createCollWithFields(t, collName, qc)
|
||||
qt := getSearchTaskWithRerank(t, collName, testFloatField)
|
||||
err = qt.PreExecute(ctx)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.NotNil(t, qt.resultBuf)
|
||||
qt.resultBuf.Insert(genTestSearchResultData(1, 10, schemapb.DataType_Float, testFloatField, fieldNameId[testFloatField], false))
|
||||
err := qt.PostExecute(context.TODO())
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []int64{10}, qt.result.Results.Topks)
|
||||
assert.Equal(t, int64(10), qt.result.Results.TopK)
|
||||
assert.Equal(t, []int64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}, qt.result.Results.Ids.GetIntId().Data)
|
||||
})
|
||||
|
||||
getHybridSearchTaskWithRerank := func(t *testing.T, collName string, funcInput string, data [][]string) *searchTask {
|
||||
subReqs := []*milvuspb.SubSearchRequest{}
|
||||
for _, item := range data {
|
||||
placeholderValue := &commonpb.PlaceholderValue{
|
||||
Tag: "$0",
|
||||
Type: commonpb.PlaceholderType_VarChar,
|
||||
Values: lo.Map(item, func(str string, _ int) []byte { return []byte(str) }),
|
||||
}
|
||||
holder := &commonpb.PlaceholderGroup{
|
||||
Placeholders: []*commonpb.PlaceholderValue{placeholderValue},
|
||||
}
|
||||
holderByte, _ := proto.Marshal(holder)
|
||||
subReq := &milvuspb.SubSearchRequest{
|
||||
PlaceholderGroup: holderByte,
|
||||
SearchParams: []*commonpb.KeyValuePair{
|
||||
{Key: AnnsFieldKey, Value: testFloatVecField},
|
||||
{Key: TopKKey, Value: "10"},
|
||||
},
|
||||
Nq: int64(len(item)),
|
||||
}
|
||||
subReqs = append(subReqs, subReq)
|
||||
}
|
||||
functionSchema := &schemapb.FunctionSchema{
|
||||
Name: "test",
|
||||
Type: schemapb.FunctionType_Rerank,
|
||||
InputFieldNames: []string{funcInput},
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: "reranker", Value: "decay"},
|
||||
{Key: "origin", Value: "4"},
|
||||
{Key: "scale", Value: "4"},
|
||||
{Key: "offset", Value: "4"},
|
||||
{Key: "decay", Value: "0.5"},
|
||||
{Key: "function", Value: "gauss"},
|
||||
},
|
||||
}
|
||||
task := &searchTask{
|
||||
ctx: ctx,
|
||||
collectionName: collName,
|
||||
SearchRequest: &internalpb.SearchRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_Search,
|
||||
Timestamp: uint64(time.Now().UnixNano()),
|
||||
},
|
||||
},
|
||||
request: &milvuspb.SearchRequest{
|
||||
CollectionName: collName,
|
||||
SubReqs: subReqs,
|
||||
SearchParams: []*commonpb.KeyValuePair{
|
||||
{Key: LimitKey, Value: "10"},
|
||||
},
|
||||
FunctionScore: &schemapb.FunctionScore{
|
||||
Functions: []*schemapb.FunctionSchema{functionSchema},
|
||||
},
|
||||
OutputFields: []string{testInt32Field},
|
||||
},
|
||||
mixCoord: qc,
|
||||
tr: timerecord.NewTimeRecorder("test-search"),
|
||||
}
|
||||
require.NoError(t, task.OnEnqueue())
|
||||
return task
|
||||
}
|
||||
|
||||
t.Run("Test hybridsearch all empty result with rerank", func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
collName := "test_collection_empty_result_with_rerank" + funcutil.GenRandomStr()
|
||||
createCollWithFields(t, collName, qc)
|
||||
qt := getHybridSearchTaskWithRerank(t, collName, testFloatField, [][]string{{"sentence"}, {"sentence"}})
|
||||
err = qt.PreExecute(ctx)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, qt.resultBuf)
|
||||
qt.resultBuf.Insert(&internalpb.SearchResults{})
|
||||
qt.resultBuf.Insert(&internalpb.SearchResults{})
|
||||
err := qt.PostExecute(context.TODO())
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("Test hybridsearch search rerank with empty", func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
collName := "test_hybridsearch_rerank_with_empty" + funcutil.GenRandomStr()
|
||||
_, fieldNameId := createCollWithFields(t, collName, qc)
|
||||
qt := getHybridSearchTaskWithRerank(t, collName, testFloatField, [][]string{{"sentence", "sentence"}, {"sentence", "sentence"}})
|
||||
err = qt.PreExecute(ctx)
|
||||
assert.Equal(t, qt.Nq, int64(2))
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, qt.resultBuf)
|
||||
// All data are from the same subsearch
|
||||
qt.resultBuf.Insert(genTestSearchResultData(2, 10, schemapb.DataType_Int64, testInt64Field, fieldNameId[testInt64Field], true))
|
||||
qt.resultBuf.Insert(genTestSearchResultData(2, 10, schemapb.DataType_Int64, testInt64Field, fieldNameId[testInt64Field], true))
|
||||
|
||||
// rerank inputs
|
||||
f1 := testutils.GenerateScalarFieldData(schemapb.DataType_Float, testFloatField, 20)
|
||||
f1.FieldId = fieldNameId[testFloatField]
|
||||
// search output field
|
||||
f2 := testutils.GenerateScalarFieldData(schemapb.DataType_Int32, testInt32Field, 20)
|
||||
f2.FieldId = fieldNameId[testInt32Field]
|
||||
// pk
|
||||
f3 := testutils.GenerateScalarFieldData(schemapb.DataType_Int64, testInt64Field, 20)
|
||||
f3.FieldId = fieldNameId[testInt64Field]
|
||||
|
||||
qt.requeryFunc = func(t *searchTask, span trace.Span, ids *schemapb.IDs, outputFields []string) (*milvuspb.QueryResults, error) {
|
||||
return &milvuspb.QueryResults{
|
||||
FieldsData: []*schemapb.FieldData{f1, f2, f3},
|
||||
PrimaryFieldName: testInt64Field,
|
||||
}, nil
|
||||
}
|
||||
|
||||
err := qt.PostExecute(context.TODO())
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []int64{10, 10}, qt.result.Results.Topks)
|
||||
assert.Equal(t, int64(10), qt.result.Results.TopK)
|
||||
assert.Equal(t, int64(2), qt.result.Results.NumQueries)
|
||||
assert.Equal(t, []int64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19}, qt.result.Results.Ids.GetIntId().Data)
|
||||
assert.Equal(t, testInt32Field, qt.result.Results.FieldsData[0].FieldName)
|
||||
})
|
||||
|
||||
t.Run("Test hybridsearch search rerank ", func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
collName := "test_hybridsearch_result_with_rerank" + funcutil.GenRandomStr()
|
||||
_, fieldNameId := createCollWithFields(t, collName, qc)
|
||||
qt := getHybridSearchTaskWithRerank(t, collName, testFloatField, [][]string{{"sentence", "sentence"}, {"sentence", "sentence"}})
|
||||
err = qt.PreExecute(ctx)
|
||||
assert.Equal(t, qt.Nq, int64(2))
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, qt.resultBuf)
|
||||
data1 := genTestSearchResultData(2, 10, schemapb.DataType_Int64, testInt64Field, fieldNameId[testInt64Field], true)
|
||||
data1.SubResults[0].ReqIndex = 0
|
||||
data2 := genTestSearchResultData(2, 10, schemapb.DataType_Int64, testInt64Field, fieldNameId[testInt64Field], true)
|
||||
data1.SubResults[0].ReqIndex = 2
|
||||
qt.resultBuf.Insert(data2)
|
||||
|
||||
// rerank inputs
|
||||
f1 := testutils.GenerateScalarFieldData(schemapb.DataType_Float, testFloatField, 20)
|
||||
f1.FieldId = fieldNameId[testFloatField]
|
||||
// search output field
|
||||
f2 := testutils.GenerateScalarFieldData(schemapb.DataType_Int32, testInt32Field, 20)
|
||||
f2.FieldId = fieldNameId[testInt32Field]
|
||||
// pk
|
||||
f3 := testutils.GenerateScalarFieldData(schemapb.DataType_Int64, testInt64Field, 20)
|
||||
f3.FieldId = fieldNameId[testInt64Field]
|
||||
|
||||
qt.requeryFunc = func(t *searchTask, span trace.Span, ids *schemapb.IDs, outputFields []string) (*milvuspb.QueryResults, error) {
|
||||
return &milvuspb.QueryResults{
|
||||
FieldsData: []*schemapb.FieldData{f1, f2, f3},
|
||||
PrimaryFieldName: testInt64Field,
|
||||
}, nil
|
||||
}
|
||||
|
||||
err := qt.PostExecute(context.TODO())
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []int64{10, 10}, qt.result.Results.Topks)
|
||||
assert.Equal(t, int64(10), qt.result.Results.TopK)
|
||||
assert.Equal(t, int64(2), qt.result.Results.NumQueries)
|
||||
assert.Equal(t, []int64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19}, qt.result.Results.Ids.GetIntId().Data)
|
||||
assert.Equal(t, testInt32Field, qt.result.Results.FieldsData[0].FieldName)
|
||||
})
|
||||
|
||||
// rrf/weighted rank
|
||||
t.Run("Test rank function", func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
collName := "test_rank_function" + funcutil.GenRandomStr()
|
||||
_, fieldNameId := createCollWithFields(t, collName, qc)
|
||||
qt := getHybridSearchTaskWithRerank(t, collName, testFloatField, [][]string{{"sentence", "sentence"}, {"sentence", "sentence"}})
|
||||
qt.request.FunctionScore = nil
|
||||
qt.request.SearchParams = []*commonpb.KeyValuePair{{Key: "limit", Value: "10"}}
|
||||
qt.request.OutputFields = []string{"*"}
|
||||
err = qt.PreExecute(ctx)
|
||||
assert.NoError(t, err)
|
||||
data1 := genTestSearchResultData(2, 10, schemapb.DataType_Int64, testInt64Field, fieldNameId[testInt64Field], true)
|
||||
data1.SubResults[0].ReqIndex = 0
|
||||
data2 := genTestSearchResultData(2, 10, schemapb.DataType_Int64, testInt64Field, fieldNameId[testInt64Field], true)
|
||||
data1.SubResults[0].ReqIndex = 2
|
||||
qt.resultBuf.Insert(data2)
|
||||
|
||||
f1 := testutils.GenerateScalarFieldData(schemapb.DataType_Int32, testInt32Field, 20)
|
||||
f1.FieldId = fieldNameId[testInt32Field]
|
||||
f2 := testutils.GenerateVectorFieldData(schemapb.DataType_FloatVector, testFloatVecField, 20, testVecDim)
|
||||
f2.FieldId = fieldNameId[testFloatVecField]
|
||||
f3 := testutils.GenerateScalarFieldData(schemapb.DataType_Int64, testInt64Field, 20)
|
||||
f3.FieldId = fieldNameId[testInt64Field]
|
||||
f4 := testutils.GenerateScalarFieldData(schemapb.DataType_Float, testFloatField, 20)
|
||||
f4.FieldId = fieldNameId[testFloatField]
|
||||
qt.requeryFunc = func(t *searchTask, span trace.Span, ids *schemapb.IDs, outputFields []string) (*milvuspb.QueryResults, error) {
|
||||
return &milvuspb.QueryResults{
|
||||
FieldsData: []*schemapb.FieldData{f1, f2, f3, f4},
|
||||
PrimaryFieldName: testInt64Field,
|
||||
}, nil
|
||||
}
|
||||
err := qt.PostExecute(context.TODO())
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []int64{10, 10}, qt.result.Results.Topks)
|
||||
assert.Equal(t, int64(10), qt.result.Results.TopK)
|
||||
assert.Equal(t, int64(2), qt.result.Results.NumQueries)
|
||||
assert.Equal(t, []int64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19}, qt.result.Results.Ids.GetIntId().Data)
|
||||
for _, field := range qt.result.Results.FieldsData {
|
||||
switch field.FieldName {
|
||||
case testInt32Field:
|
||||
assert.True(t, len(field.GetScalars().GetIntData().Data) != 0)
|
||||
case testBoolField:
|
||||
assert.True(t, len(field.GetScalars().GetBoolData().Data) != 0)
|
||||
case testFloatField:
|
||||
assert.True(t, len(field.GetScalars().GetFloatData().Data) != 0)
|
||||
case testFloatVecField:
|
||||
assert.True(t, len(field.GetVectors().GetFloatVector().Data) != 0)
|
||||
case testInt64Field:
|
||||
assert.True(t, len(field.GetScalars().GetLongData().Data) != 0)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Test mergeIDs function", func(t *testing.T) {
|
||||
{
|
||||
ids1 := &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
Data: []int64{1, 2, 3, 5},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ids2 := &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
Data: []int64{1, 2, 4, 5, 6},
|
||||
},
|
||||
},
|
||||
}
|
||||
allIDs, count := mergeIDs([]*schemapb.IDs{ids1, ids2})
|
||||
assert.Equal(t, count, 6)
|
||||
sortedIds := allIDs.GetIntId().GetData()
|
||||
slices.Sort(sortedIds)
|
||||
assert.Equal(t, sortedIds, []int64{1, 2, 3, 4, 5, 6})
|
||||
}
|
||||
{
|
||||
ids1 := &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_StrId{
|
||||
StrId: &schemapb.StringArray{
|
||||
Data: []string{"a", "b", "e"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ids2 := &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_StrId{
|
||||
StrId: &schemapb.StringArray{
|
||||
Data: []string{"a", "b", "c", "d"},
|
||||
},
|
||||
},
|
||||
}
|
||||
allIDs, count := mergeIDs([]*schemapb.IDs{ids1, ids2})
|
||||
assert.Equal(t, count, 5)
|
||||
sortedIds := allIDs.GetStrId().GetData()
|
||||
slices.Sort(sortedIds)
|
||||
assert.Equal(t, sortedIds, []string{"a", "b", "c", "d", "e"})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func createCollWithFields(t *testing.T, collName string, rc types.MixCoordClient) (*schemapb.CollectionSchema, map[string]int64) {
|
||||
fieldName2Types := map[string]schemapb.DataType{
|
||||
testInt64Field: schemapb.DataType_Int64,
|
||||
testFloatField: schemapb.DataType_Float,
|
||||
testFloatVecField: schemapb.DataType_FloatVector,
|
||||
testInt32Field: schemapb.DataType_Int32,
|
||||
testBoolField: schemapb.DataType_Bool,
|
||||
}
|
||||
schema := constructCollectionSchemaByDataType(collName, fieldName2Types, testInt64Field, true)
|
||||
marshaledSchema, err := proto.Marshal(schema)
|
||||
assert.NoError(t, err)
|
||||
ctx := context.TODO()
|
||||
|
||||
createColT := &createCollectionTask{
|
||||
Condition: NewTaskCondition(ctx),
|
||||
CreateCollectionRequest: &milvuspb.CreateCollectionRequest{
|
||||
CollectionName: collName,
|
||||
Schema: marshaledSchema,
|
||||
ShardsNum: 1,
|
||||
},
|
||||
ctx: ctx,
|
||||
mixCoord: rc,
|
||||
}
|
||||
|
||||
require.NoError(t, createColT.OnEnqueue())
|
||||
require.NoError(t, createColT.PreExecute(ctx))
|
||||
require.NoError(t, createColT.Execute(ctx))
|
||||
require.NoError(t, createColT.PostExecute(ctx))
|
||||
|
||||
_, err = globalMetaCache.GetCollectionID(ctx, GetCurDBNameFromContextOrDefault(ctx), collName)
|
||||
assert.NoError(t, err)
|
||||
|
||||
fieldNameId := make(map[string]int64)
|
||||
for _, field := range schema.Fields {
|
||||
fieldNameId[field.Name] = field.FieldID
|
||||
}
|
||||
return schema, fieldNameId
|
||||
}
|
||||
|
||||
func createColl(t *testing.T, name string, rc types.MixCoordClient) *schemapb.CollectionSchema {
|
||||
@ -349,6 +721,39 @@ func TestSearchTask_PreExecute(t *testing.T) {
|
||||
return task
|
||||
}
|
||||
|
||||
getSearchTaskWithRerank := func(t *testing.T, collName string, funcInput string) *searchTask {
|
||||
functionSchema := &schemapb.FunctionSchema{
|
||||
Name: "test",
|
||||
Type: schemapb.FunctionType_Rerank,
|
||||
InputFieldNames: []string{funcInput},
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: "reranker", Value: "decay"},
|
||||
{Key: "origin", Value: "4"},
|
||||
{Key: "scale", Value: "4"},
|
||||
{Key: "offset", Value: "4"},
|
||||
{Key: "decay", Value: "0.5"},
|
||||
{Key: "function", Value: "gauss"},
|
||||
},
|
||||
}
|
||||
task := &searchTask{
|
||||
ctx: ctx,
|
||||
collectionName: collName,
|
||||
SearchRequest: &internalpb.SearchRequest{},
|
||||
request: &milvuspb.SearchRequest{
|
||||
CollectionName: collName,
|
||||
Nq: 1,
|
||||
SubReqs: []*milvuspb.SubSearchRequest{},
|
||||
FunctionScore: &schemapb.FunctionScore{
|
||||
Functions: []*schemapb.FunctionSchema{functionSchema},
|
||||
},
|
||||
},
|
||||
mixCoord: qc,
|
||||
tr: timerecord.NewTimeRecorder("test-search"),
|
||||
}
|
||||
require.NoError(t, task.OnEnqueue())
|
||||
return task
|
||||
}
|
||||
|
||||
t.Run("bad nq 0", func(t *testing.T) {
|
||||
collName := "test_bad_nq0_error" + funcutil.GenRandomStr()
|
||||
createColl(t, collName, qc)
|
||||
@ -508,6 +913,64 @@ func TestSearchTask_PreExecute(t *testing.T) {
|
||||
assert.NoError(t, st.PreExecute(ctx))
|
||||
assert.Equal(t, collInfo.updateTimestamp, st.SearchRequest.GuaranteeTimestamp)
|
||||
})
|
||||
t.Run("search with rerank", func(t *testing.T) {
|
||||
collName := "search_with_rerank" + funcutil.GenRandomStr()
|
||||
createCollWithFields(t, collName, qc)
|
||||
st := getSearchTaskWithRerank(t, collName, testFloatField)
|
||||
st.request.SearchParams = getValidSearchParams()
|
||||
st.request.DslType = commonpb.DslType_BoolExprV1
|
||||
|
||||
_, cancel := context.WithTimeout(ctx, time.Second)
|
||||
defer cancel()
|
||||
require.Equal(t, typeutil.ZeroTimestamp, st.TimeoutTimestamp)
|
||||
enqueueTs := uint64(100000)
|
||||
st.SetTs(enqueueTs)
|
||||
assert.NoError(t, st.PreExecute(ctx))
|
||||
assert.NotNil(t, st.functionScore)
|
||||
assert.Equal(t, false, st.SearchRequest.GetIsAdvanced())
|
||||
})
|
||||
|
||||
t.Run("advance search with rerank", func(t *testing.T) {
|
||||
collName := "search_with_rerank" + funcutil.GenRandomStr()
|
||||
createCollWithFields(t, collName, qc)
|
||||
st := getSearchTaskWithRerank(t, collName, testFloatField)
|
||||
st.request.SearchParams = getValidSearchParams()
|
||||
st.request.SearchParams = append(st.request.SearchParams, &commonpb.KeyValuePair{
|
||||
Key: LimitKey,
|
||||
Value: "10",
|
||||
})
|
||||
st.request.DslType = commonpb.DslType_BoolExprV1
|
||||
st.request.SubReqs = append(st.request.SubReqs, &milvuspb.SubSearchRequest{Nq: 1})
|
||||
st.request.SubReqs = append(st.request.SubReqs, &milvuspb.SubSearchRequest{Nq: 1})
|
||||
_, cancel := context.WithTimeout(ctx, time.Second)
|
||||
defer cancel()
|
||||
require.Equal(t, typeutil.ZeroTimestamp, st.TimeoutTimestamp)
|
||||
enqueueTs := uint64(100000)
|
||||
st.SetTs(enqueueTs)
|
||||
assert.NoError(t, st.PreExecute(ctx))
|
||||
assert.NotNil(t, st.functionScore)
|
||||
assert.Equal(t, true, st.SearchRequest.GetIsAdvanced())
|
||||
})
|
||||
|
||||
t.Run("search with rerank grouping", func(t *testing.T) {
|
||||
collName := "search_with_rerank" + funcutil.GenRandomStr()
|
||||
createCollWithFields(t, collName, qc)
|
||||
st := getSearchTaskWithRerank(t, collName, testFloatField)
|
||||
st.request.SearchParams = getValidSearchParams()
|
||||
st.request.DslType = commonpb.DslType_BoolExprV1
|
||||
|
||||
st.request.SearchParams = append(st.request.SearchParams, &commonpb.KeyValuePair{
|
||||
Key: GroupByFieldKey,
|
||||
Value: testInt64Field,
|
||||
})
|
||||
|
||||
_, cancel := context.WithTimeout(ctx, time.Second)
|
||||
defer cancel()
|
||||
require.Equal(t, typeutil.ZeroTimestamp, st.TimeoutTimestamp)
|
||||
enqueueTs := uint64(100000)
|
||||
st.SetTs(enqueueTs)
|
||||
assert.ErrorContains(t, st.PreExecute(ctx), "Current rerank does not support grouping search")
|
||||
})
|
||||
}
|
||||
|
||||
func TestSearchTask_WithFunctions(t *testing.T) {
|
||||
@ -538,6 +1001,9 @@ func TestSearchTask_WithFunctions(t *testing.T) {
|
||||
{Key: "dim", Value: "4"},
|
||||
},
|
||||
},
|
||||
{
|
||||
FieldID: 104, Name: "ts", DataType: schemapb.DataType_Int64,
|
||||
},
|
||||
},
|
||||
Functions: []*schemapb.FunctionSchema{
|
||||
{
|
||||
@ -584,7 +1050,7 @@ func TestSearchTask_WithFunctions(t *testing.T) {
|
||||
err = InitMetaCache(ctx, qc, mgr)
|
||||
require.NoError(t, err)
|
||||
|
||||
getSearchTask := func(t *testing.T, collName string, data []string) *searchTask {
|
||||
getSearchTask := func(t *testing.T, collName string, data []string, withRerank bool) *searchTask {
|
||||
placeholderValue := &commonpb.PlaceholderValue{
|
||||
Tag: "$0",
|
||||
Type: commonpb.PlaceholderType_VarChar,
|
||||
@ -594,6 +1060,21 @@ func TestSearchTask_WithFunctions(t *testing.T) {
|
||||
Placeholders: []*commonpb.PlaceholderValue{placeholderValue},
|
||||
}
|
||||
holderByte, _ := proto.Marshal(holder)
|
||||
|
||||
functionSchema := &schemapb.FunctionSchema{
|
||||
Name: "test",
|
||||
Type: schemapb.FunctionType_Rerank,
|
||||
InputFieldNames: []string{"ts"},
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: "reranker", Value: "decay"},
|
||||
{Key: "origin", Value: "4"},
|
||||
{Key: "scale", Value: "4"},
|
||||
{Key: "offset", Value: "4"},
|
||||
{Key: "decay", Value: "0.5"},
|
||||
{Key: "function", Value: "gauss"},
|
||||
},
|
||||
}
|
||||
|
||||
task := &searchTask{
|
||||
ctx: ctx,
|
||||
collectionName: collectionName,
|
||||
@ -615,6 +1096,11 @@ func TestSearchTask_WithFunctions(t *testing.T) {
|
||||
mixCoord: qc,
|
||||
tr: timerecord.NewTimeRecorder("test-search"),
|
||||
}
|
||||
if withRerank {
|
||||
task.request.FunctionScore = &schemapb.FunctionScore{
|
||||
Functions: []*schemapb.FunctionSchema{functionSchema},
|
||||
}
|
||||
}
|
||||
require.NoError(t, task.OnEnqueue())
|
||||
return task
|
||||
}
|
||||
@ -631,7 +1117,7 @@ func TestSearchTask_WithFunctions(t *testing.T) {
|
||||
globalMetaCache = cache
|
||||
|
||||
{
|
||||
task := getSearchTask(t, collectionName, []string{"sentence"})
|
||||
task := getSearchTask(t, collectionName, []string{"sentence"}, false)
|
||||
err = task.PreExecute(ctx)
|
||||
assert.NoError(t, err)
|
||||
pb := &commonpb.PlaceholderGroup{}
|
||||
@ -642,7 +1128,7 @@ func TestSearchTask_WithFunctions(t *testing.T) {
|
||||
}
|
||||
|
||||
{
|
||||
task := getSearchTask(t, collectionName, []string{"sentence 1", "sentence 2"})
|
||||
task := getSearchTask(t, collectionName, []string{"sentence 1", "sentence 2"}, false)
|
||||
err = task.PreExecute(ctx)
|
||||
assert.NoError(t, err)
|
||||
pb := &commonpb.PlaceholderGroup{}
|
||||
@ -654,7 +1140,7 @@ func TestSearchTask_WithFunctions(t *testing.T) {
|
||||
|
||||
// process failed
|
||||
{
|
||||
task := getSearchTask(t, collectionName, []string{"sentence"})
|
||||
task := getSearchTask(t, collectionName, []string{"sentence"}, false)
|
||||
task.request.Nq = 10000
|
||||
err = task.PreExecute(ctx)
|
||||
assert.Error(t, err)
|
||||
@ -3167,11 +3653,11 @@ func TestSearchTask_Requery(t *testing.T) {
|
||||
tr: timerecord.NewTimeRecorder("search"),
|
||||
node: node,
|
||||
translatedOutputFields: outputFields,
|
||||
requeryFunc: requeryImpl,
|
||||
}
|
||||
|
||||
err := qt.Requery(nil)
|
||||
queryResult, err := qt.requeryFunc(qt, nil, qt.result.Results.Ids, outputFields)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, qt.result.Results.FieldsData, 2)
|
||||
assert.Len(t, queryResult.FieldsData, 2)
|
||||
for _, field := range qt.result.Results.FieldsData {
|
||||
fieldName := field.GetFieldName()
|
||||
assert.Contains(t, []string{pkField, vecField}, fieldName)
|
||||
@ -3192,13 +3678,14 @@ func TestSearchTask_Requery(t *testing.T) {
|
||||
SourceID: paramtable.GetNodeID(),
|
||||
},
|
||||
},
|
||||
request: &milvuspb.SearchRequest{},
|
||||
schema: schema,
|
||||
tr: timerecord.NewTimeRecorder("search"),
|
||||
node: node,
|
||||
request: &milvuspb.SearchRequest{},
|
||||
schema: schema,
|
||||
tr: timerecord.NewTimeRecorder("search"),
|
||||
node: node,
|
||||
requeryFunc: requeryImpl,
|
||||
}
|
||||
|
||||
err := qt.Requery(nil)
|
||||
_, err := qt.requeryFunc(qt, nil, &schemapb.IDs{}, []string{})
|
||||
t.Logf("err = %s", err)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
@ -3227,12 +3714,13 @@ func TestSearchTask_Requery(t *testing.T) {
|
||||
request: &milvuspb.SearchRequest{
|
||||
CollectionName: collectionName,
|
||||
},
|
||||
schema: schema,
|
||||
tr: timerecord.NewTimeRecorder("search"),
|
||||
node: node,
|
||||
schema: schema,
|
||||
tr: timerecord.NewTimeRecorder("search"),
|
||||
node: node,
|
||||
requeryFunc: requeryImpl,
|
||||
}
|
||||
|
||||
err := qt.Requery(nil)
|
||||
_, err := qt.requeryFunc(qt, nil, &schemapb.IDs{}, []string{})
|
||||
t.Logf("err = %s", err)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
@ -3274,11 +3762,11 @@ func TestSearchTask_Requery(t *testing.T) {
|
||||
Ids: resultIDs,
|
||||
},
|
||||
},
|
||||
requery: true,
|
||||
schema: schema,
|
||||
resultBuf: typeutil.NewConcurrentSet[*internalpb.SearchResults](),
|
||||
tr: timerecord.NewTimeRecorder("search"),
|
||||
node: node,
|
||||
needRequery: true,
|
||||
schema: schema,
|
||||
resultBuf: typeutil.NewConcurrentSet[*internalpb.SearchResults](),
|
||||
tr: timerecord.NewTimeRecorder("search"),
|
||||
node: node,
|
||||
}
|
||||
scores := make([]float32, rows)
|
||||
for i := range scores {
|
||||
@ -3712,3 +4200,64 @@ func (s *MaterializedViewTestSuite) TestMvEnabledPartitionKeyOnVarCharWithIsolat
|
||||
func TestMaterializedView(t *testing.T) {
|
||||
suite.Run(t, new(MaterializedViewTestSuite))
|
||||
}
|
||||
|
||||
func genTestSearchResultData(nq int64, topk int64, dType schemapb.DataType, fieldName string, fieldId int64, IsAdvanced bool) *internalpb.SearchResults {
|
||||
result := &internalpb.SearchResults{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_SearchResult,
|
||||
MsgID: 0,
|
||||
Timestamp: 0,
|
||||
SourceID: 0,
|
||||
},
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
Reason: "",
|
||||
},
|
||||
MetricType: "COSINE",
|
||||
NumQueries: nq,
|
||||
TopK: topk,
|
||||
SealedSegmentIDsSearched: nil,
|
||||
ChannelIDsSearched: nil,
|
||||
GlobalSealedSegmentIDs: nil,
|
||||
SlicedBlob: nil,
|
||||
SlicedNumCount: 1,
|
||||
SlicedOffset: 0,
|
||||
IsAdvanced: IsAdvanced,
|
||||
}
|
||||
|
||||
tops := make([]int64, nq)
|
||||
for i := 0; i < int(nq); i++ {
|
||||
tops[i] = topk
|
||||
}
|
||||
|
||||
resultData := &schemapb.SearchResultData{
|
||||
NumQueries: nq,
|
||||
TopK: topk,
|
||||
Scores: testutils.GenerateFloat32Array(int(nq * topk)),
|
||||
Ids: &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
Data: testutils.GenerateInt64Array(int(nq * topk)),
|
||||
},
|
||||
},
|
||||
},
|
||||
Topks: tops,
|
||||
FieldsData: []*schemapb.FieldData{
|
||||
testutils.GenerateScalarFieldData(dType, fieldName, int(nq*topk)),
|
||||
},
|
||||
}
|
||||
resultData.FieldsData[0].FieldId = fieldId
|
||||
sliceBlob, _ := proto.Marshal(resultData)
|
||||
if !IsAdvanced {
|
||||
result.SlicedBlob = sliceBlob
|
||||
} else {
|
||||
result.SubResults = []*internalpb.SubSearchResults{
|
||||
{
|
||||
SlicedBlob: sliceBlob,
|
||||
SlicedNumCount: 1,
|
||||
SlicedOffset: 0,
|
||||
},
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
@ -164,7 +164,7 @@ func parseAKAndURL(params []*commonpb.KeyValuePair, confParams map[string]string
|
||||
|
||||
// from env; url doesn't support configuration in env
|
||||
if apiKey == "" {
|
||||
url = os.Getenv(apiKeyEnv)
|
||||
apiKey = os.Getenv(apiKeyEnv)
|
||||
}
|
||||
return apiKey, url
|
||||
}
|
||||
|
||||
316
internal/util/function/rerank/decay_function.go
Normal file
316
internal/util/function/rerank/decay_function.go
Normal file
@ -0,0 +1,316 @@
|
||||
/*
|
||||
* # Licensed to the LF AI & Data foundation under one
|
||||
* # or more contributor license agreements. See the NOTICE file
|
||||
* # distributed with this work for additional information
|
||||
* # regarding copyright ownership. The ASF licenses this file
|
||||
* # to you under the Apache License, Version 2.0 (the
|
||||
* # "License"); you may not use this file except in compliance
|
||||
* # with the License. You may obtain a copy of the License at
|
||||
* #
|
||||
* # http://www.apache.org/licenses/LICENSE-2.0
|
||||
* #
|
||||
* # Unless required by applicable law or agreed to in writing, software
|
||||
* # distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* # See the License for the specific language governing permissions and
|
||||
* # limitations under the License.
|
||||
*/
|
||||
|
||||
package rerank
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/samber/lo"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
|
||||
)
|
||||
|
||||
const (
	// Parameter keys read from the decay function's params. newFunction
	// lower-cases the user-supplied key before comparing against these.
	originKey   string = "origin"
	scaleKey    string = "scale"
	offsetKey   string = "offset"
	decayKey    string = "decay"
	functionKey string = "function"

	// Accepted values of the "function" parameter.
	gaussFunction string = "gauss"
	// NOTE(review): spelled "liner", presumably meant as "linear"; renaming it
	// would change the value users must pass, so confirm before touching.
	linerFunction string = "liner"
	expFunction   string = "exp"
)
|
||||
|
||||
type DecayFunction[T int64 | string, R int32 | int64 | float32 | float64] struct {
|
||||
RerankBase
|
||||
|
||||
functionName string
|
||||
origin float64
|
||||
scale float64
|
||||
offset float64
|
||||
decay float64
|
||||
reScorer decayReScorer
|
||||
}
|
||||
|
||||
func newDecayFunction(collSchema *schemapb.CollectionSchema, funcSchema *schemapb.FunctionSchema) (Reranker, error) {
|
||||
pkType := schemapb.DataType_None
|
||||
for _, field := range collSchema.Fields {
|
||||
if field.IsPrimaryKey {
|
||||
pkType = field.DataType
|
||||
}
|
||||
}
|
||||
|
||||
if pkType == schemapb.DataType_None {
|
||||
return nil, fmt.Errorf("Collection %s can not found pk field", collSchema.Name)
|
||||
}
|
||||
|
||||
base, err := newRerankBase(collSchema, funcSchema, decayFunctionName, false, pkType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(base.GetInputFieldNames()) != 1 {
|
||||
return nil, fmt.Errorf("Decay function only supoorts single input, but gets [%s] input", base.GetInputFieldNames())
|
||||
}
|
||||
|
||||
var inputType schemapb.DataType
|
||||
for _, field := range collSchema.Fields {
|
||||
if field.Name == base.GetInputFieldNames()[0] {
|
||||
inputType = field.DataType
|
||||
}
|
||||
}
|
||||
|
||||
if pkType == schemapb.DataType_Int64 {
|
||||
switch inputType {
|
||||
case schemapb.DataType_Int8, schemapb.DataType_Int16, schemapb.DataType_Int32:
|
||||
return newFunction[int64, int32](base, funcSchema)
|
||||
case schemapb.DataType_Int64:
|
||||
return newFunction[int64, int64](base, funcSchema)
|
||||
case schemapb.DataType_Float:
|
||||
return newFunction[int64, float32](base, funcSchema)
|
||||
case schemapb.DataType_Double:
|
||||
return newFunction[int64, float64](base, funcSchema)
|
||||
default:
|
||||
return nil, fmt.Errorf("Decay rerank: unsupported input field type:%s, only support numberic field", inputType.String())
|
||||
}
|
||||
} else {
|
||||
switch inputType {
|
||||
case schemapb.DataType_Int8, schemapb.DataType_Int16, schemapb.DataType_Int32:
|
||||
return newFunction[string, int32](base, funcSchema)
|
||||
case schemapb.DataType_Int64:
|
||||
return newFunction[string, int64](base, funcSchema)
|
||||
case schemapb.DataType_Float:
|
||||
return newFunction[string, float32](base, funcSchema)
|
||||
case schemapb.DataType_Double:
|
||||
return newFunction[string, float64](base, funcSchema)
|
||||
default:
|
||||
return nil, fmt.Errorf("Decay rerank: unsupported input field type:%s, only support numberic field", inputType.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// T: PK Type, R: field type
|
||||
func newFunction[T int64 | string, R int32 | int64 | float32 | float64](base *RerankBase, funcSchema *schemapb.FunctionSchema) (Reranker, error) {
|
||||
var err error
|
||||
decayFunc := &DecayFunction[T, R]{RerankBase: *base, offset: 0, decay: 0.5}
|
||||
orginInit := false
|
||||
scaleInit := false
|
||||
for _, param := range funcSchema.Params {
|
||||
switch strings.ToLower(param.Key) {
|
||||
case functionKey:
|
||||
decayFunc.functionName = param.Value
|
||||
case originKey:
|
||||
if decayFunc.origin, err = strconv.ParseFloat(param.Value, 64); err != nil {
|
||||
return nil, fmt.Errorf("Param origin:%s is not a number", param.Value)
|
||||
}
|
||||
orginInit = true
|
||||
case scaleKey:
|
||||
if decayFunc.scale, err = strconv.ParseFloat(param.Value, 64); err != nil {
|
||||
return nil, fmt.Errorf("Param scale:%s is not a number", param.Value)
|
||||
}
|
||||
scaleInit = true
|
||||
case offsetKey:
|
||||
if decayFunc.offset, err = strconv.ParseFloat(param.Value, 64); err != nil {
|
||||
return nil, fmt.Errorf("Param offset:%s is not a number", param.Value)
|
||||
}
|
||||
case decayKey:
|
||||
if decayFunc.decay, err = strconv.ParseFloat(param.Value, 64); err != nil {
|
||||
return nil, fmt.Errorf("Param decay:%s is not a number", param.Value)
|
||||
}
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
if !orginInit {
|
||||
return nil, fmt.Errorf("Decay function lost param: origin")
|
||||
}
|
||||
|
||||
if !scaleInit {
|
||||
return nil, fmt.Errorf("Decay function lost param: scale")
|
||||
}
|
||||
|
||||
if decayFunc.scale <= 0 {
|
||||
return nil, fmt.Errorf("Decay function param: scale must > 0, but got %f", decayFunc.scale)
|
||||
}
|
||||
|
||||
if decayFunc.offset < 0 {
|
||||
return nil, fmt.Errorf("Decay function param: offset must => 0, but got %f", decayFunc.offset)
|
||||
}
|
||||
|
||||
if decayFunc.decay <= 0 || decayFunc.decay >= 1 {
|
||||
return nil, fmt.Errorf("Decay function param: decay must 0 < decay < 1 0, but got %f", decayFunc.offset)
|
||||
}
|
||||
|
||||
switch decayFunc.functionName {
|
||||
case gaussFunction:
|
||||
decayFunc.reScorer = gaussianDecay
|
||||
case expFunction:
|
||||
decayFunc.reScorer = expDecay
|
||||
case linerFunction:
|
||||
decayFunc.reScorer = linearDecay
|
||||
default:
|
||||
return nil, fmt.Errorf("Invaild decay function: %s, only support [%s,%s,%s]", decayFunctionName, gaussFunction, linerFunction, expFunction)
|
||||
}
|
||||
|
||||
return decayFunc, nil
|
||||
}
|
||||
|
||||
func (decay *DecayFunction[T, R]) reScore(multipSearchResultData []*schemapb.SearchResultData) (*idSocres[T], error) {
|
||||
newScores := newIdScores[T]()
|
||||
for _, data := range multipSearchResultData {
|
||||
var inputField *schemapb.FieldData
|
||||
for _, field := range data.FieldsData {
|
||||
if field.FieldId == decay.GetInputFieldIDs()[0] {
|
||||
inputField = field
|
||||
}
|
||||
}
|
||||
if inputField == nil {
|
||||
return nil, fmt.Errorf("Rerank decay function can not find input field, name: %s", decay.GetInputFieldNames()[0])
|
||||
}
|
||||
var inputValues *numberField[R]
|
||||
if tmp, err := getNumberic(inputField); err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
inputValues = tmp.(*numberField[R])
|
||||
}
|
||||
|
||||
ids := newMilvusIDs(data.Ids, decay.pkType).(milvusIDs[T])
|
||||
for idx, id := range ids.data {
|
||||
if !newScores.exist(id) {
|
||||
if v, err := inputValues.GetFloat64(idx); err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
newScores.set(id, float32(decay.reScorer(decay.origin, decay.scale, decay.decay, decay.offset, v)))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return newScores, nil
|
||||
}
|
||||
|
||||
func (decay *DecayFunction[T, R]) orgnizeNqScores(searchParams *SearchParams, multipSearchResultData []*schemapb.SearchResultData, idScoreData *idSocres[T]) []map[T]float32 {
|
||||
nqScores := make([]map[T]float32, searchParams.nq)
|
||||
for i := int64(0); i < searchParams.nq; i++ {
|
||||
nqScores[i] = make(map[T]float32)
|
||||
}
|
||||
|
||||
for _, data := range multipSearchResultData {
|
||||
start := int64(0)
|
||||
for nqIdx := int64(0); nqIdx < searchParams.nq; nqIdx++ {
|
||||
realTopk := data.Topks[nqIdx]
|
||||
for j := start; j < start+realTopk; j++ {
|
||||
id := typeutil.GetPK(data.GetIds(), j).(T)
|
||||
if _, exists := nqScores[nqIdx][id]; !exists {
|
||||
nqScores[nqIdx][id] = idScoreData.get(id)
|
||||
}
|
||||
}
|
||||
start += realTopk
|
||||
}
|
||||
}
|
||||
return nqScores
|
||||
}
|
||||
|
||||
func (decay *DecayFunction[T, R]) Process(ctx context.Context, searchParams *SearchParams, multipSearchResultData []*schemapb.SearchResultData) (*schemapb.SearchResultData, error) {
|
||||
ret := &schemapb.SearchResultData{
|
||||
NumQueries: searchParams.nq,
|
||||
TopK: searchParams.limit,
|
||||
FieldsData: make([]*schemapb.FieldData, 0),
|
||||
Scores: []float32{},
|
||||
Ids: &schemapb.IDs{},
|
||||
Topks: []int64{},
|
||||
}
|
||||
multipSearchResultData = lo.Filter(multipSearchResultData, func(searchResult *schemapb.SearchResultData, i int) bool {
|
||||
return len(searchResult.FieldsData) != 0
|
||||
})
|
||||
|
||||
if len(multipSearchResultData) == 0 {
|
||||
return ret, nil
|
||||
}
|
||||
idScoreData, err := decay.reScore(multipSearchResultData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
nqScores := decay.orgnizeNqScores(searchParams, multipSearchResultData, idScoreData)
|
||||
topk := searchParams.limit + searchParams.offset
|
||||
for i := int64(0); i < searchParams.nq; i++ {
|
||||
idScoreMap := nqScores[i]
|
||||
ids := make([]T, 0)
|
||||
for id := range idScoreMap {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
|
||||
big := func(i, j int) bool {
|
||||
if idScoreMap[ids[i]] == idScoreMap[ids[j]] {
|
||||
return ids[i] < ids[j]
|
||||
}
|
||||
return idScoreMap[ids[i]] > idScoreMap[ids[j]]
|
||||
}
|
||||
sort.Slice(ids, big)
|
||||
|
||||
if int64(len(ids)) > topk {
|
||||
ids = ids[:topk]
|
||||
}
|
||||
|
||||
// set real topk
|
||||
ret.Topks = append(ret.Topks, max(0, int64(len(ids))-searchParams.offset))
|
||||
// append id and score
|
||||
for index := searchParams.offset; index < int64(len(ids)); index++ {
|
||||
typeutil.AppendPKs(ret.Ids, ids[index])
|
||||
score := idScoreMap[ids[index]]
|
||||
if searchParams.roundDecimal != -1 {
|
||||
multiplier := math.Pow(10.0, float64(searchParams.roundDecimal))
|
||||
score = float32(math.Floor(float64(score)*multiplier+0.5) / multiplier)
|
||||
}
|
||||
ret.Scores = append(ret.Scores, score)
|
||||
}
|
||||
}
|
||||
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
// decayReScorer is the signature shared by the decay curves. The arguments
// are, in order: origin, scale, decay, offset and the field value being
// scored; the result is the decay score.
type decayReScorer func(origin, scale, decay, offset, distance float64) float64
|
||||
|
||||
// gaussianDecay scores distance with a Gaussian curve centred on origin.
//
// Values within offset of origin score 1. Beyond that the score is
// exp(-d^2 / (2*sigma^2)) with sigma^2 = -scale^2 / (2*ln(decay)), where d is
// the distance past the offset, so the score equals decay when d == scale,
// matching expDecay and linearDecay. The original omitted the 0.5 factor in
// the exponent and therefore produced decay^2 at d == scale.
func gaussianDecay(origin, scale, decay, offset, distance float64) float64 {
	adjustedDist := math.Max(0, math.Abs(distance-origin)-offset)
	// ln(decay) is negative for 0 < decay < 1, so sigmaSquare is negative and
	// the exponent below is <= 0.
	sigmaSquare := 0.5 * math.Pow(scale, 2.0) / math.Log(decay)
	exponent := 0.5 * math.Pow(adjustedDist, 2.0) / sigmaSquare
	return math.Exp(exponent)
}
|
||||
|
||||
// expDecay scores distance with an exponential curve centred on origin.
//
// Values within offset of origin score 1; past the offset the score is
// exp(ln(decay)/scale * d), where d is the distance beyond the offset, so it
// equals decay when d == scale.
func expDecay(origin, scale, decay, offset, distance float64) float64 {
	rate := math.Log(decay) / scale
	beyond := math.Max(0, math.Abs(distance-origin)-offset)
	return math.Exp(rate * beyond)
}
|
||||
|
||||
// linearDecay scores distance with a straight line centred on origin.
//
// Values within offset of origin score 1; past the offset the score drops by
// (1-decay)/scale per unit of distance, reaching decay at distance scale. The
// result is floored at decay, so it never goes below that value.
func linearDecay(origin, scale, decay, offset, distance float64) float64 {
	beyond := math.Max(0, math.Abs(distance-origin)-offset)
	score := 1 - (1-decay)/scale*beyond
	return math.Max(decay, score)
}
|
||||
434
internal/util/function/rerank/decay_function_test.go
Normal file
434
internal/util/function/rerank/decay_function_test.go
Normal file
@ -0,0 +1,434 @@
|
||||
/*
|
||||
* # Licensed to the LF AI & Data foundation under one
|
||||
* # or more contributor license agreements. See the NOTICE file
|
||||
* # distributed with this work for additional information
|
||||
* # regarding copyright ownership. The ASF licenses this file
|
||||
* # to you under the Apache License, Version 2.0 (the
|
||||
* # "License"); you may not use this file except in compliance
|
||||
* # with the License. You may obtain a copy of the License at
|
||||
* #
|
||||
* # http://www.apache.org/licenses/LICENSE-2.0
|
||||
* #
|
||||
* # Unless required by applicable law or agreed to in writing, software
|
||||
* # distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* # See the License for the specific language governing permissions and
|
||||
* # limitations under the License.
|
||||
*/
|
||||
|
||||
package rerank
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/suite"
|
||||
"google.golang.org/protobuf/proto"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/testutils"
|
||||
)
|
||||
|
||||
func TestDecayFunction(t *testing.T) {
|
||||
suite.Run(t, new(DecayFunctionSuite))
|
||||
}
|
||||
|
||||
type DecayFunctionSuite struct {
|
||||
suite.Suite
|
||||
schema *schemapb.CollectionSchema
|
||||
providers []string
|
||||
}
|
||||
|
||||
func (s *DecayFunctionSuite) TestNewDecayErrors() {
|
||||
schema := &schemapb.CollectionSchema{
|
||||
Name: "test",
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{FieldID: 100, Name: "pk", DataType: schemapb.DataType_Int64, IsPrimaryKey: true},
|
||||
{FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar},
|
||||
{
|
||||
FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{Key: "dim", Value: "4"},
|
||||
},
|
||||
},
|
||||
{FieldID: 102, Name: "ts", DataType: schemapb.DataType_Int64},
|
||||
},
|
||||
}
|
||||
functionSchema := &schemapb.FunctionSchema{
|
||||
Name: "test",
|
||||
Type: schemapb.FunctionType_Rerank,
|
||||
InputFieldNames: []string{"ts"},
|
||||
OutputFieldNames: []string{"text"},
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: originKey, Value: "4"},
|
||||
{Key: scaleKey, Value: "4"},
|
||||
{Key: offsetKey, Value: "4"},
|
||||
{Key: decayKey, Value: "0.5"},
|
||||
{Key: functionKey, Value: "gauss"},
|
||||
},
|
||||
}
|
||||
|
||||
{
|
||||
_, err := newDecayFunction(schema, functionSchema)
|
||||
s.ErrorContains(err, "Rerank function output field names should be empty")
|
||||
}
|
||||
|
||||
{
|
||||
functionSchema.OutputFieldNames = []string{}
|
||||
functionSchema.InputFieldNames = []string{""}
|
||||
_, err := newDecayFunction(schema, functionSchema)
|
||||
s.ErrorContains(err, "Rerank input field name cannot be empty string")
|
||||
}
|
||||
|
||||
{
|
||||
functionSchema.InputFieldNames = []string{"ts", "ts"}
|
||||
_, err := newDecayFunction(schema, functionSchema)
|
||||
s.ErrorContains(err, "Each function input field should be used exactly once in the same function")
|
||||
}
|
||||
|
||||
{
|
||||
functionSchema.InputFieldNames = []string{"ts", "pk"}
|
||||
_, err := newDecayFunction(schema, functionSchema)
|
||||
s.ErrorContains(err, "Decay function only supoorts single input, but gets")
|
||||
}
|
||||
|
||||
{
|
||||
functionSchema.InputFieldNames = []string{"notExists"}
|
||||
_, err := newDecayFunction(schema, functionSchema)
|
||||
s.ErrorContains(err, "Function input field not found:")
|
||||
}
|
||||
|
||||
{
|
||||
functionSchema.InputFieldNames = []string{"vector"}
|
||||
_, err := newDecayFunction(schema, functionSchema)
|
||||
s.ErrorContains(err, "Decay rerank: unsupported input field type")
|
||||
}
|
||||
|
||||
{
|
||||
functionSchema.InputFieldNames = []string{"ts"}
|
||||
_, err := newDecayFunction(schema, functionSchema)
|
||||
s.NoError(err)
|
||||
}
|
||||
|
||||
{
|
||||
for i := 0; i < 4; i++ {
|
||||
functionSchema.Params[i].Value = "NotNum"
|
||||
_, err := newDecayFunction(schema, functionSchema)
|
||||
s.ErrorContains(err, "is not a number")
|
||||
functionSchema.Params[i].Value = "0.9"
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
fs := []string{gaussFunction, linerFunction, expFunction}
|
||||
for i := 0; i < 3; i++ {
|
||||
functionSchema.Params[4].Value = fs[i]
|
||||
_, err := newDecayFunction(schema, functionSchema)
|
||||
s.NoError(err)
|
||||
}
|
||||
functionSchema.Params[4].Value = "NotExist"
|
||||
_, err := newDecayFunction(schema, functionSchema)
|
||||
s.ErrorContains(err, "Invaild decay function:")
|
||||
functionSchema.Params[4].Value = "exp"
|
||||
}
|
||||
|
||||
{
|
||||
newSchema := proto.Clone(schema).(*schemapb.CollectionSchema)
|
||||
newSchema.Fields[0].IsPrimaryKey = false
|
||||
_, err := newDecayFunction(newSchema, functionSchema)
|
||||
s.ErrorContains(err, " can not found pk field")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *DecayFunctionSuite) TestAllTypesInput() {
|
||||
schema := &schemapb.CollectionSchema{
|
||||
Name: "test",
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{FieldID: 100, Name: "pk", DataType: schemapb.DataType_Int64, IsPrimaryKey: true},
|
||||
{FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar},
|
||||
{
|
||||
FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{Key: "dim", Value: "4"},
|
||||
},
|
||||
},
|
||||
{FieldID: 102, Name: "ts", DataType: schemapb.DataType_Int64},
|
||||
},
|
||||
}
|
||||
|
||||
functionSchema := &schemapb.FunctionSchema{
|
||||
Name: "test",
|
||||
Type: schemapb.FunctionType_Rerank,
|
||||
InputFieldNames: []string{"ts"},
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: functionKey, Value: "gauss"},
|
||||
{Key: originKey, Value: "4"},
|
||||
{Key: scaleKey, Value: "4"},
|
||||
{Key: offsetKey, Value: "4"},
|
||||
{Key: decayKey, Value: "0.5"},
|
||||
},
|
||||
}
|
||||
inputTypes := []schemapb.DataType{schemapb.DataType_Int64, schemapb.DataType_Int32, schemapb.DataType_Int16, schemapb.DataType_Int8, schemapb.DataType_Float, schemapb.DataType_Double, schemapb.DataType_Bool}
|
||||
for i, inputType := range inputTypes {
|
||||
schema.Fields[3].DataType = inputType
|
||||
_, err := newDecayFunction(schema, functionSchema)
|
||||
if i < len(inputTypes)-1 {
|
||||
s.NoError(err)
|
||||
} else {
|
||||
s.ErrorContains(err, "Decay rerank: unsupported input field type")
|
||||
}
|
||||
}
|
||||
|
||||
schema.Fields[0].DataType = schemapb.DataType_String
|
||||
for i, inputType := range inputTypes {
|
||||
schema.Fields[3].DataType = inputType
|
||||
_, err := newDecayFunction(schema, functionSchema)
|
||||
if i < len(inputTypes)-1 {
|
||||
s.NoError(err)
|
||||
} else {
|
||||
s.ErrorContains(err, "Decay rerank: unsupported input field type")
|
||||
}
|
||||
}
|
||||
|
||||
schema.Fields[3].DataType = schemapb.DataType_Double
|
||||
|
||||
{
|
||||
functionSchema.Params[1].Key = "N"
|
||||
_, err := newDecayFunction(schema, functionSchema)
|
||||
s.ErrorContains(err, "Decay function lost param: origin")
|
||||
functionSchema.Params[1].Key = originKey
|
||||
}
|
||||
{
|
||||
functionSchema.Params[2].Key = "N"
|
||||
_, err := newDecayFunction(schema, functionSchema)
|
||||
s.ErrorContains(err, "Decay function lost param: scale")
|
||||
functionSchema.Params[2].Key = scaleKey
|
||||
}
|
||||
|
||||
{
|
||||
functionSchema.Params[2].Value = "-1"
|
||||
_, err := newDecayFunction(schema, functionSchema)
|
||||
s.ErrorContains(err, "Decay function param: scale must > 0,")
|
||||
functionSchema.Params[2].Value = "0.5"
|
||||
}
|
||||
|
||||
{
|
||||
functionSchema.Params[3].Value = "-1"
|
||||
_, err := newDecayFunction(schema, functionSchema)
|
||||
s.ErrorContains(err, "Decay function param: offset must => 0")
|
||||
functionSchema.Params[3].Value = "0.5"
|
||||
}
|
||||
|
||||
{
|
||||
functionSchema.Params[4].Value = "10"
|
||||
_, err := newDecayFunction(schema, functionSchema)
|
||||
s.ErrorContains(err, "Decay function param: decay must 0 < decay < 1 0")
|
||||
functionSchema.Params[2].Value = "0.5"
|
||||
}
|
||||
}
|
||||
|
||||
func (s *DecayFunctionSuite) TestRerankProcess() {
|
||||
schema := &schemapb.CollectionSchema{
|
||||
Name: "test",
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{FieldID: 100, Name: "pk", DataType: schemapb.DataType_Int64, IsPrimaryKey: true},
|
||||
{FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar},
|
||||
{
|
||||
FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{Key: "dim", Value: "4"},
|
||||
},
|
||||
},
|
||||
{FieldID: 102, Name: "ts", DataType: schemapb.DataType_Int64},
|
||||
},
|
||||
}
|
||||
functionSchema := &schemapb.FunctionSchema{
|
||||
Name: "test",
|
||||
Type: schemapb.FunctionType_Rerank,
|
||||
InputFieldNames: []string{"ts"},
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: functionKey, Value: "gauss"},
|
||||
{Key: originKey, Value: "0"},
|
||||
{Key: scaleKey, Value: "4"},
|
||||
{Key: offsetKey, Value: "2"},
|
||||
},
|
||||
}
|
||||
|
||||
// empty
|
||||
{
|
||||
nq := int64(1)
|
||||
f, err := newDecayFunction(schema, functionSchema)
|
||||
s.NoError(err)
|
||||
ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false}, []*schemapb.SearchResultData{})
|
||||
s.NoError(err)
|
||||
s.Equal(int64(3), ret.TopK)
|
||||
s.Equal([]int64{}, ret.Topks)
|
||||
}
|
||||
|
||||
// no input field exist
|
||||
{
|
||||
nq := int64(1)
|
||||
f, err := newDecayFunction(schema, functionSchema)
|
||||
data := genSearchResultData(nq, 10, schemapb.DataType_Int64, "noExist", 1000)
|
||||
s.NoError(err)
|
||||
_, err = f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false}, []*schemapb.SearchResultData{data})
|
||||
s.ErrorContains(err, "Rerank decay function can not find input field, name")
|
||||
}
|
||||
|
||||
// singleSearchResultData
|
||||
// nq = 1
|
||||
{
|
||||
nq := int64(1)
|
||||
f, err := newDecayFunction(schema, functionSchema)
|
||||
s.NoError(err)
|
||||
data := genSearchResultData(nq, 10, schemapb.DataType_Int64, "ts", 102)
|
||||
ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false}, []*schemapb.SearchResultData{data})
|
||||
s.NoError(err)
|
||||
s.Equal([]int64{3}, ret.Topks)
|
||||
s.Equal(int64(3), ret.TopK)
|
||||
s.Equal([]int64{2, 3, 4}, ret.Ids.GetIntId().Data)
|
||||
}
|
||||
// nq = 3
|
||||
{
|
||||
nq := int64(3)
|
||||
f, err := newDecayFunction(schema, functionSchema)
|
||||
s.NoError(err)
|
||||
data := genSearchResultData(nq, 10, schemapb.DataType_Int64, "ts", 102)
|
||||
ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false}, []*schemapb.SearchResultData{data})
|
||||
s.NoError(err)
|
||||
s.Equal([]int64{3, 3, 3}, ret.Topks)
|
||||
s.Equal(int64(3), ret.TopK)
|
||||
s.Equal([]int64{2, 3, 4, 12, 13, 14, 22, 23, 24}, ret.Ids.GetIntId().Data)
|
||||
}
|
||||
|
||||
// multipSearchResultData
|
||||
// nq = 1
|
||||
functionSchema2 := &schemapb.FunctionSchema{
|
||||
Name: "test",
|
||||
Type: schemapb.FunctionType_Rerank,
|
||||
InputFieldNames: []string{"ts"},
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: functionKey, Value: "gauss"},
|
||||
{Key: originKey, Value: "5"},
|
||||
{Key: scaleKey, Value: "4"},
|
||||
{Key: offsetKey, Value: "2"},
|
||||
},
|
||||
}
|
||||
{
|
||||
nq := int64(1)
|
||||
f, err := newDecayFunction(schema, functionSchema2)
|
||||
s.NoError(err)
|
||||
// ts/id data: 0 - 9
|
||||
data1 := genSearchResultData(nq, 10, schemapb.DataType_Int64, "ts", 102)
|
||||
// ts/id data: 0 - 3
|
||||
data2 := genSearchResultData(nq, 4, schemapb.DataType_Int64, "ts", 102)
|
||||
ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false}, []*schemapb.SearchResultData{data1, data2})
|
||||
s.NoError(err)
|
||||
s.Equal([]int64{3}, ret.Topks)
|
||||
s.Equal(int64(3), ret.TopK)
|
||||
s.Equal([]int64{5, 6, 7}, ret.Ids.GetIntId().Data)
|
||||
}
|
||||
// nq = 3
|
||||
{
|
||||
nq := int64(3)
|
||||
f, err := newDecayFunction(schema, functionSchema2)
|
||||
s.NoError(err)
|
||||
// nq1 ts/id data: 0 - 9
|
||||
// nq2 ts/id data: 10 - 19
|
||||
// nq3 ts/id data: 20 - 29
|
||||
data1 := genSearchResultData(nq, 10, schemapb.DataType_Int64, "ts", 102)
|
||||
// nq1 ts/id data: 0 - 3
|
||||
// nq2 ts/id data: 4 - 7
|
||||
// nq3 ts/id data: 8 - 11
|
||||
data2 := genSearchResultData(nq, 4, schemapb.DataType_Int64, "ts", 102)
|
||||
ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, 1, -1, 1, false}, []*schemapb.SearchResultData{data1, data2})
|
||||
s.NoError(err)
|
||||
s.Equal([]int64{3, 3, 3}, ret.Topks)
|
||||
s.Equal(int64(3), ret.TopK)
|
||||
s.Equal([]int64{5, 6, 7, 6, 7, 10, 10, 11, 20}, ret.Ids.GetIntId().Data)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *DecayFunctionSuite) TestDecay() {
|
||||
s.Equal(gaussianDecay(0, 1, 0.5, 5, 4), 1.0)
|
||||
s.Equal(gaussianDecay(0, 1, 0.5, 5, 5), 1.0)
|
||||
s.Less(gaussianDecay(0, 1, 0.5, 5, 6), 1.0)
|
||||
|
||||
s.Equal(expDecay(0, 1, 0.5, 5, 4), 1.0)
|
||||
s.Equal(expDecay(0, 1, 0.5, 5, 5), 1.0)
|
||||
s.Less(expDecay(0, 1, 0.5, 5, 6), 1.0)
|
||||
|
||||
s.Equal(linearDecay(0, 1, 0.5, 5, 4), 1.0)
|
||||
s.Equal(linearDecay(0, 1, 0.5, 5, 5), 1.0)
|
||||
s.Less(linearDecay(0, 1, 0.5, 5, 6), 1.0)
|
||||
}
|
||||
|
||||
func (s *DecayFunctionSuite) TestUtil() {
|
||||
inputTypes := []schemapb.DataType{schemapb.DataType_Int64, schemapb.DataType_Int32, schemapb.DataType_Int16, schemapb.DataType_Int8, schemapb.DataType_Float, schemapb.DataType_Double}
|
||||
for _, tp := range inputTypes {
|
||||
field := genSearchResultData(2, 10, tp, "test", 100)
|
||||
num, err := getNumberic(field.FieldsData[0])
|
||||
s.NoError(err)
|
||||
switch tp {
|
||||
case schemapb.DataType_Int32, schemapb.DataType_Int16, schemapb.DataType_Int8:
|
||||
s.True(len(num.(*numberField[int32]).data) == 20)
|
||||
case schemapb.DataType_Int64:
|
||||
s.True(len(num.(*numberField[int64]).data) == 20)
|
||||
case schemapb.DataType_Float:
|
||||
s.True(len(num.(*numberField[float32]).data) == 20)
|
||||
case schemapb.DataType_Double:
|
||||
s.True(len(num.(*numberField[float64]).data) == 20)
|
||||
}
|
||||
}
|
||||
|
||||
field := genSearchResultData(2, 10, schemapb.DataType_Bool, "test", 100)
|
||||
_, err := getNumberic(field.FieldsData[0])
|
||||
s.ErrorContains(err, "only support numberic field")
|
||||
|
||||
{
|
||||
ids := &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
Data: testutils.GenerateInt64Array(10),
|
||||
},
|
||||
},
|
||||
}
|
||||
mid := newMilvusIDs(ids, schemapb.DataType_Int64).(milvusIDs[int64])
|
||||
s.True(len(mid.data) == 10)
|
||||
}
|
||||
{
|
||||
ids := &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_StrId{
|
||||
StrId: &schemapb.StringArray{
|
||||
Data: testutils.GenerateStringArray(10),
|
||||
},
|
||||
},
|
||||
}
|
||||
mid := newMilvusIDs(ids, schemapb.DataType_String).(milvusIDs[string])
|
||||
s.True(len(mid.data) == 10)
|
||||
}
|
||||
}
|
||||
|
||||
func genSearchResultData(nq int64, topk int64, dType schemapb.DataType, fieldName string, fieldId int64) *schemapb.SearchResultData {
|
||||
tops := make([]int64, nq)
|
||||
for i := 0; i < int(nq); i++ {
|
||||
tops[i] = topk
|
||||
}
|
||||
data := &schemapb.SearchResultData{
|
||||
NumQueries: nq,
|
||||
TopK: topk,
|
||||
Scores: testutils.GenerateFloat32Array(int(nq * topk)),
|
||||
Ids: &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
Data: testutils.GenerateInt64Array(int(nq * topk)),
|
||||
},
|
||||
},
|
||||
},
|
||||
Topks: tops,
|
||||
FieldsData: []*schemapb.FieldData{testutils.GenerateScalarFieldData(dType, fieldName, int(nq*topk))},
|
||||
}
|
||||
data.FieldsData[0].FieldId = fieldId
|
||||
return data
|
||||
}
|
||||
166
internal/util/function/rerank/function_score.go
Normal file
166
internal/util/function/rerank/function_score.go
Normal file
@ -0,0 +1,166 @@
|
||||
/*
|
||||
* # Licensed to the LF AI & Data foundation under one
|
||||
* # or more contributor license agreements. See the NOTICE file
|
||||
* # distributed with this work for additional information
|
||||
* # regarding copyright ownership. The ASF licenses this file
|
||||
* # to you under the Apache License, Version 2.0 (the
|
||||
* # "License"); you may not use this file except in compliance
|
||||
* # with the License. You may obtain a copy of the License at
|
||||
* #
|
||||
* # http://www.apache.org/licenses/LICENSE-2.0
|
||||
* #
|
||||
* # Unless required by applicable law or agreed to in writing, software
|
||||
* # distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* # See the License for the specific language governing permissions and
|
||||
* # limitations under the License.
|
||||
*/
|
||||
|
||||
package rerank
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/samber/lo"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/merr"
|
||||
)
|
||||
|
||||
// decayFunctionName is the reranker name that selects the decay rerank
// function in createFunction.
const decayFunctionName string = "decay"
|
||||
|
||||
// SearchParams carries the search-request settings a reranker needs.
//
// The field order is relied upon by positional composite literals
// (NewSearchParams and the tests), so it must not change.
type SearchParams struct {
	// nq is the number of queries; limit is the TopK reported in the result;
	// offset is the number of leading ranked hits skipped per query;
	// roundDecimal is the number of decimals scores are rounded to, with -1
	// meaning no rounding.
	nq, limit, offset, roundDecimal int64

	// TODO: supports group search
	groupByFieldId  int64
	groupSize       int64
	strictGroupSize bool
}
|
||||
|
||||
func NewSearchParams(nq, limit, offset, roundDecimal, groupByFieldId, groupSize int64, strictGroupSize bool) *SearchParams {
|
||||
return &SearchParams{
|
||||
nq, limit, offset, roundDecimal, groupByFieldId, groupSize, strictGroupSize,
|
||||
}
|
||||
}
|
||||
|
||||
type Reranker interface {
|
||||
Process(ctx context.Context, searchParams *SearchParams, searchData []*schemapb.SearchResultData) (*schemapb.SearchResultData, error)
|
||||
IsSupportGroup() bool
|
||||
GetInputFieldNames() []string
|
||||
GetInputFieldIDs() []int64
|
||||
GetRankName() string
|
||||
}
|
||||
|
||||
func getRerankName(funcSchema *schemapb.FunctionSchema) string {
|
||||
for _, param := range funcSchema.Params {
|
||||
switch strings.ToLower(param.Key) {
|
||||
case reranker:
|
||||
return strings.ToLower(param.Value)
|
||||
default:
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// Currently only supports single rerank
|
||||
type FunctionScore struct {
|
||||
reranker Reranker
|
||||
}
|
||||
|
||||
func createFunction(collSchema *schemapb.CollectionSchema, funcSchema *schemapb.FunctionSchema) (Reranker, error) {
|
||||
if funcSchema.GetType() != schemapb.FunctionType_Rerank {
|
||||
return nil, fmt.Errorf("%s is not rerank function.", funcSchema.GetType().String())
|
||||
}
|
||||
if len(funcSchema.GetOutputFieldNames()) != 0 {
|
||||
return nil, fmt.Errorf("Rerank function should not have output field, but now is %d", len(funcSchema.GetOutputFieldNames()))
|
||||
}
|
||||
|
||||
rerankerName := getRerankName(funcSchema)
|
||||
var rerankFunc Reranker
|
||||
var newRerankErr error
|
||||
switch rerankerName {
|
||||
case decayFunctionName:
|
||||
rerankFunc, newRerankErr = newDecayFunction(collSchema, funcSchema)
|
||||
default:
|
||||
return nil, fmt.Errorf("Unsupported rerank function: [%s] , list of supported [%s]", rerankerName, decayFunctionName)
|
||||
}
|
||||
|
||||
if newRerankErr != nil {
|
||||
return nil, newRerankErr
|
||||
}
|
||||
return rerankFunc, nil
|
||||
}
|
||||
|
||||
func NewFunctionScore(collSchema *schemapb.CollectionSchema, funcScoreSchema *schemapb.FunctionScore) (*FunctionScore, error) {
|
||||
if len(funcScoreSchema.Functions) > 1 || len(funcScoreSchema.Functions) == 0 {
|
||||
return nil, fmt.Errorf("Currently only supports one rerank, but got %d", len(funcScoreSchema.Functions))
|
||||
}
|
||||
funcScore := &FunctionScore{}
|
||||
var err error
|
||||
if funcScore.reranker, err = createFunction(collSchema, funcScoreSchema.Functions[0]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return funcScore, nil
|
||||
}
|
||||
|
||||
func (fScore *FunctionScore) Process(ctx context.Context, searchParams *SearchParams, multipleMilvusResults []*milvuspb.SearchResults) (*milvuspb.SearchResults, error) {
|
||||
if len(multipleMilvusResults) == 0 {
|
||||
return &milvuspb.SearchResults{
|
||||
Status: merr.Success(),
|
||||
Results: &schemapb.SearchResultData{
|
||||
NumQueries: searchParams.nq,
|
||||
TopK: searchParams.limit,
|
||||
FieldsData: make([]*schemapb.FieldData, 0),
|
||||
Scores: []float32{},
|
||||
Ids: &schemapb.IDs{},
|
||||
Topks: []int64{},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
allSearchResultData := lo.FilterMap(multipleMilvusResults, func(m *milvuspb.SearchResults, _ int) (*schemapb.SearchResultData, bool) {
|
||||
return m.Results, true
|
||||
})
|
||||
|
||||
// rankResult only has scores
|
||||
rankResult, err := fScore.reranker.Process(ctx, searchParams, allSearchResultData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ret := &milvuspb.SearchResults{
|
||||
Status: merr.Success(),
|
||||
Results: rankResult,
|
||||
}
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func (fScore *FunctionScore) GetAllInputFieldNames() []string {
|
||||
if fScore == nil {
|
||||
return []string{}
|
||||
}
|
||||
return fScore.reranker.GetInputFieldNames()
|
||||
}
|
||||
|
||||
func (fScore *FunctionScore) GetAllInputFieldIDs() []int64 {
|
||||
if fScore == nil {
|
||||
return []int64{}
|
||||
}
|
||||
return fScore.reranker.GetInputFieldIDs()
|
||||
}
|
||||
|
||||
func (fScore *FunctionScore) IsSupportGroup() bool {
|
||||
if fScore == nil {
|
||||
return true
|
||||
}
|
||||
return fScore.reranker.IsSupportGroup()
|
||||
}
|
||||
302
internal/util/function/rerank/function_score_test.go
Normal file
302
internal/util/function/rerank/function_score_test.go
Normal file
@ -0,0 +1,302 @@
|
||||
/*
|
||||
* # Licensed to the LF AI & Data foundation under one
|
||||
* # or more contributor license agreements. See the NOTICE file
|
||||
* # distributed with this work for additional information
|
||||
* # regarding copyright ownership. The ASF licenses this file
|
||||
* # to you under the Apache License, Version 2.0 (the
|
||||
* # "License"); you may not use this file except in compliance
|
||||
* # with the License. You may obtain a copy of the License at
|
||||
* #
|
||||
* # http://www.apache.org/licenses/LICENSE-2.0
|
||||
* #
|
||||
* # Unless required by applicable law or agreed to in writing, software
|
||||
* # distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* # See the License for the specific language governing permissions and
|
||||
* # limitations under the License.
|
||||
*/
|
||||
|
||||
package rerank
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/suite"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
)
|
||||
|
||||
func TestFunctionScore(t *testing.T) {
|
||||
suite.Run(t, new(FunctionScoreSuite))
|
||||
}
|
||||
|
||||
type FunctionScoreSuite struct {
|
||||
suite.Suite
|
||||
schema *schemapb.CollectionSchema
|
||||
providers []string
|
||||
}
|
||||
|
||||
func (s *FunctionScoreSuite) TestNewFunctionScore() {
|
||||
schema := &schemapb.CollectionSchema{
|
||||
Name: "test",
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{FieldID: 100, Name: "pk", DataType: schemapb.DataType_Int64, IsPrimaryKey: true},
|
||||
{FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar},
|
||||
{
|
||||
FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{Key: "dim", Value: "4"},
|
||||
},
|
||||
},
|
||||
{FieldID: 102, Name: "ts", DataType: schemapb.DataType_Int64},
|
||||
},
|
||||
}
|
||||
functionSchema := &schemapb.FunctionSchema{
|
||||
Name: "test",
|
||||
Type: schemapb.FunctionType_Rerank,
|
||||
InputFieldNames: []string{"ts"},
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: reranker, Value: decayFunctionName},
|
||||
{Key: originKey, Value: "4"},
|
||||
{Key: scaleKey, Value: "4"},
|
||||
{Key: offsetKey, Value: "4"},
|
||||
{Key: decayKey, Value: "0.5"},
|
||||
{Key: functionKey, Value: "gauss"},
|
||||
},
|
||||
}
|
||||
funcScores := &schemapb.FunctionScore{
|
||||
Functions: []*schemapb.FunctionSchema{functionSchema},
|
||||
}
|
||||
|
||||
f, err := NewFunctionScore(schema, funcScores)
|
||||
s.NoError(err)
|
||||
s.Equal([]string{"ts"}, f.GetAllInputFieldNames())
|
||||
s.Equal([]int64{102}, f.GetAllInputFieldIDs())
|
||||
s.Equal(false, f.IsSupportGroup())
|
||||
s.Equal("decay", f.reranker.GetRankName())
|
||||
|
||||
{
|
||||
schema.Fields[3].Nullable = true
|
||||
_, err := NewFunctionScore(schema, funcScores)
|
||||
s.ErrorContains(err, "Function input field cannot be nullable")
|
||||
schema.Fields[3].Nullable = false
|
||||
}
|
||||
|
||||
{
|
||||
funcScores.Functions[0].Params[0].Value = "NotExist"
|
||||
_, err := NewFunctionScore(schema, funcScores)
|
||||
s.ErrorContains(err, "Unsupported rerank function")
|
||||
funcScores.Functions[0].Params[0].Value = decayFunctionName
|
||||
}
|
||||
|
||||
{
|
||||
funcScores.Functions = append(funcScores.Functions, functionSchema)
|
||||
_, err := NewFunctionScore(schema, funcScores)
|
||||
s.ErrorContains(err, "Currently only supports one rerank, but got")
|
||||
funcScores.Functions = funcScores.Functions[:1]
|
||||
}
|
||||
|
||||
{
|
||||
funcScores.Functions[0].Type = schemapb.FunctionType_BM25
|
||||
_, err := NewFunctionScore(schema, funcScores)
|
||||
s.ErrorContains(err, "is not rerank function")
|
||||
funcScores.Functions[0].Type = schemapb.FunctionType_Rerank
|
||||
}
|
||||
|
||||
{
|
||||
funcScores.Functions[0].OutputFieldNames = []string{"text"}
|
||||
_, err := NewFunctionScore(schema, funcScores)
|
||||
s.ErrorContains(err, "Rerank function should not have output field")
|
||||
funcScores.Functions[0].OutputFieldNames = []string{""}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *FunctionScoreSuite) TestFunctionScoreProcess() {
|
||||
schema := &schemapb.CollectionSchema{
|
||||
Name: "test",
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{FieldID: 100, Name: "pk", DataType: schemapb.DataType_Int64, IsPrimaryKey: true},
|
||||
{FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar},
|
||||
{
|
||||
FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{Key: "dim", Value: "4"},
|
||||
},
|
||||
},
|
||||
{FieldID: 102, Name: "ts", DataType: schemapb.DataType_Int64},
|
||||
},
|
||||
}
|
||||
functionSchema := &schemapb.FunctionSchema{
|
||||
Name: "test",
|
||||
Type: schemapb.FunctionType_Rerank,
|
||||
InputFieldNames: []string{"ts"},
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: reranker, Value: decayFunctionName},
|
||||
{Key: originKey, Value: "4"},
|
||||
{Key: scaleKey, Value: "4"},
|
||||
{Key: offsetKey, Value: "4"},
|
||||
{Key: decayKey, Value: "0.5"},
|
||||
{Key: functionKey, Value: "gauss"},
|
||||
},
|
||||
}
|
||||
funcScores := &schemapb.FunctionScore{
|
||||
Functions: []*schemapb.FunctionSchema{functionSchema},
|
||||
}
|
||||
|
||||
f, err := NewFunctionScore(schema, funcScores)
|
||||
s.NoError(err)
|
||||
|
||||
// empty inputs
|
||||
{
|
||||
nq := int64(1)
|
||||
ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false}, []*milvuspb.SearchResults{})
|
||||
s.NoError(err)
|
||||
s.Equal(int64(3), ret.Results.TopK)
|
||||
s.Equal(0, len(ret.Results.FieldsData))
|
||||
}
|
||||
|
||||
// single input
|
||||
// nq = 1
|
||||
{
|
||||
nq := int64(1)
|
||||
data := genSearchResultData(nq, 10, schemapb.DataType_Int64, "ts", 102)
|
||||
searchData := &milvuspb.SearchResults{
|
||||
Results: data,
|
||||
}
|
||||
ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false}, []*milvuspb.SearchResults{searchData})
|
||||
s.NoError(err)
|
||||
s.Equal(int64(3), ret.Results.TopK)
|
||||
s.Equal([]int64{3}, ret.Results.Topks)
|
||||
}
|
||||
// nq=1, input is empty
|
||||
{
|
||||
nq := int64(1)
|
||||
data := genSearchResultData(nq, 0, schemapb.DataType_Int64, "ts", 102)
|
||||
searchData := &milvuspb.SearchResults{
|
||||
Results: data,
|
||||
}
|
||||
ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false}, []*milvuspb.SearchResults{searchData})
|
||||
s.NoError(err)
|
||||
s.Equal(int64(3), ret.Results.TopK)
|
||||
s.Equal([]int64{0}, ret.Results.Topks)
|
||||
}
|
||||
// nq=3
|
||||
{
|
||||
nq := int64(3)
|
||||
data := genSearchResultData(nq, 10, schemapb.DataType_Int64, "ts", 102)
|
||||
searchData := &milvuspb.SearchResults{
|
||||
Results: data,
|
||||
}
|
||||
ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false}, []*milvuspb.SearchResults{searchData})
|
||||
s.NoError(err)
|
||||
s.Equal(int64(3), ret.Results.TopK)
|
||||
s.Equal([]int64{3, 3, 3}, ret.Results.Topks)
|
||||
}
|
||||
// nq=3, all input is empty
|
||||
{
|
||||
nq := int64(3)
|
||||
data := genSearchResultData(nq, 0, schemapb.DataType_Int64, "ts", 102)
|
||||
searchData := &milvuspb.SearchResults{
|
||||
Results: data,
|
||||
}
|
||||
ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false}, []*milvuspb.SearchResults{searchData})
|
||||
s.NoError(err)
|
||||
s.Equal(int64(3), ret.Results.TopK)
|
||||
s.Equal([]int64{0, 0, 0}, ret.Results.Topks)
|
||||
}
|
||||
|
||||
// multi inputs
|
||||
// nq = 1
|
||||
{
|
||||
nq := int64(1)
|
||||
searchData1 := &milvuspb.SearchResults{
|
||||
Results: genSearchResultData(nq, 10, schemapb.DataType_Int64, "ts", 102),
|
||||
}
|
||||
|
||||
searchData2 := &milvuspb.SearchResults{
|
||||
Results: genSearchResultData(nq, 20, schemapb.DataType_Int64, "ts", 102),
|
||||
}
|
||||
ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false}, []*milvuspb.SearchResults{searchData1, searchData2})
|
||||
s.NoError(err)
|
||||
s.Equal(int64(3), ret.Results.TopK)
|
||||
s.Equal([]int64{3}, ret.Results.Topks)
|
||||
}
|
||||
// nq=1, all input is empty
|
||||
{
|
||||
nq := int64(1)
|
||||
searchData1 := &milvuspb.SearchResults{
|
||||
Results: genSearchResultData(nq, 0, schemapb.DataType_Int64, "ts", 102),
|
||||
}
|
||||
|
||||
searchData2 := &milvuspb.SearchResults{
|
||||
Results: genSearchResultData(nq, 0, schemapb.DataType_Int64, "ts", 102),
|
||||
}
|
||||
ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false}, []*milvuspb.SearchResults{searchData1, searchData2})
|
||||
s.NoError(err)
|
||||
s.Equal(int64(3), ret.Results.TopK)
|
||||
s.Equal([]int64{0}, ret.Results.Topks)
|
||||
}
|
||||
// nq=1, has empty input
|
||||
{
|
||||
nq := int64(1)
|
||||
searchData1 := &milvuspb.SearchResults{
|
||||
Results: genSearchResultData(nq, 0, schemapb.DataType_Int64, "ts", 102),
|
||||
}
|
||||
|
||||
searchData2 := &milvuspb.SearchResults{
|
||||
Results: genSearchResultData(nq, 10, schemapb.DataType_Int64, "ts", 102),
|
||||
}
|
||||
ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false}, []*milvuspb.SearchResults{searchData1, searchData2})
|
||||
s.NoError(err)
|
||||
s.Equal(int64(3), ret.Results.TopK)
|
||||
s.Equal([]int64{3}, ret.Results.Topks)
|
||||
}
|
||||
// nq = 3
|
||||
{
|
||||
nq := int64(3)
|
||||
searchData1 := &milvuspb.SearchResults{
|
||||
Results: genSearchResultData(nq, 10, schemapb.DataType_Int64, "ts", 102),
|
||||
}
|
||||
|
||||
searchData2 := &milvuspb.SearchResults{
|
||||
Results: genSearchResultData(nq, 20, schemapb.DataType_Int64, "ts", 102),
|
||||
}
|
||||
ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false}, []*milvuspb.SearchResults{searchData1, searchData2})
|
||||
s.NoError(err)
|
||||
s.Equal(int64(3), ret.Results.TopK)
|
||||
s.Equal([]int64{3, 3, 3}, ret.Results.Topks)
|
||||
}
|
||||
// nq=3, all input is empty
|
||||
{
|
||||
nq := int64(3)
|
||||
searchData1 := &milvuspb.SearchResults{
|
||||
Results: genSearchResultData(nq, 0, schemapb.DataType_Int64, "ts", 102),
|
||||
}
|
||||
|
||||
searchData2 := &milvuspb.SearchResults{
|
||||
Results: genSearchResultData(nq, 0, schemapb.DataType_Int64, "ts", 102),
|
||||
}
|
||||
ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false}, []*milvuspb.SearchResults{searchData1, searchData2})
|
||||
s.NoError(err)
|
||||
s.Equal(int64(3), ret.Results.TopK)
|
||||
s.Equal([]int64{0, 0, 0}, ret.Results.Topks)
|
||||
}
|
||||
// nq=3, has empty input
|
||||
{
|
||||
nq := int64(3)
|
||||
searchData1 := &milvuspb.SearchResults{
|
||||
Results: genSearchResultData(nq, 0, schemapb.DataType_Int64, "ts", 102),
|
||||
}
|
||||
|
||||
searchData2 := &milvuspb.SearchResults{
|
||||
Results: genSearchResultData(nq, 10, schemapb.DataType_Int64, "ts", 102),
|
||||
}
|
||||
ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false}, []*milvuspb.SearchResults{searchData1, searchData2})
|
||||
s.NoError(err)
|
||||
s.Equal(int64(3), ret.Results.TopK)
|
||||
s.Equal([]int64{3, 3, 3}, ret.Results.Topks)
|
||||
}
|
||||
}
|
||||
107
internal/util/function/rerank/rerank_base.go
Normal file
107
internal/util/function/rerank/rerank_base.go
Normal file
@ -0,0 +1,107 @@
|
||||
/*
|
||||
* # Licensed to the LF AI & Data foundation under one
|
||||
* # or more contributor license agreements. See the NOTICE file
|
||||
* # distributed with this work for additional information
|
||||
* # regarding copyright ownership. The ASF licenses this file
|
||||
* # to you under the Apache License, Version 2.0 (the
|
||||
* # "License"); you may not use this file except in compliance
|
||||
* # with the License. You may obtain a copy of the License at
|
||||
* #
|
||||
* # http://www.apache.org/licenses/LICENSE-2.0
|
||||
* #
|
||||
* # Unless required by applicable law or agreed to in writing, software
|
||||
* # distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* # See the License for the specific language governing permissions and
|
||||
* # limitations under the License.
|
||||
*/
|
||||
|
||||
package rerank
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/samber/lo"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
)
|
||||
|
||||
// reranker is the function-param key whose value names the rerank
// implementation to use.
const reranker string = "reranker"
|
||||
|
||||
// searchParams carries the top-k and grouping settings of a search.
// A reranker may honor these settings or ignore them.
type searchParams struct {
	// limit is the top-k setting.
	limit int64

	// Grouping settings.
	groupByFieldId  int64
	groupSize       int64
	strictGroupSize bool
}
|
||||
|
||||
type RerankBase struct {
|
||||
coll *schemapb.CollectionSchema
|
||||
funcSchema *schemapb.FunctionSchema
|
||||
rerankerName string
|
||||
isSupportGroup bool
|
||||
|
||||
pkType schemapb.DataType
|
||||
inputFields []*schemapb.FieldSchema
|
||||
inputFieldIDs []int64
|
||||
|
||||
// TODO: The parameter is passed to the reranker, and the reranker decides whether to implement the parameter
|
||||
searchParams *searchParams
|
||||
}
|
||||
|
||||
func newRerankBase(coll *schemapb.CollectionSchema, funcSchema *schemapb.FunctionSchema, rerankerName string, isSupportGroup bool, pkType schemapb.DataType) (*RerankBase, error) {
|
||||
base := RerankBase{
|
||||
coll: coll,
|
||||
funcSchema: funcSchema,
|
||||
rerankerName: rerankerName,
|
||||
isSupportGroup: isSupportGroup,
|
||||
pkType: pkType,
|
||||
}
|
||||
|
||||
nameMap := lo.SliceToMap(coll.GetFields(), func(field *schemapb.FieldSchema) (string, *schemapb.FieldSchema) {
|
||||
return field.GetName(), field
|
||||
})
|
||||
|
||||
if len(funcSchema.GetOutputFieldNames()) != 0 {
|
||||
return nil, fmt.Errorf("Rerank function output field names should be empty")
|
||||
}
|
||||
|
||||
for _, name := range funcSchema.GetInputFieldNames() {
|
||||
if name == "" {
|
||||
return nil, fmt.Errorf("Rerank input field name cannot be empty string")
|
||||
}
|
||||
if lo.Count(funcSchema.GetInputFieldNames(), name) > 1 {
|
||||
return nil, fmt.Errorf("Each function input field should be used exactly once in the same function, input field: %s", name)
|
||||
}
|
||||
inputField, ok := nameMap[name]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("Function input field not found: %s", name)
|
||||
}
|
||||
if inputField.GetNullable() {
|
||||
return nil, fmt.Errorf("Function input field cannot be nullable: field %s", inputField.GetName())
|
||||
}
|
||||
base.inputFields = append(base.inputFields, inputField)
|
||||
base.inputFieldIDs = append(base.inputFieldIDs, inputField.FieldID)
|
||||
}
|
||||
return &base, nil
|
||||
}
|
||||
|
||||
func (base *RerankBase) GetInputFieldNames() []string {
|
||||
return base.funcSchema.InputFieldNames
|
||||
}
|
||||
|
||||
func (base *RerankBase) GetInputFieldIDs() []int64 {
|
||||
return base.inputFieldIDs
|
||||
}
|
||||
|
||||
func (base *RerankBase) IsSupportGroup() bool {
|
||||
return base.isSupportGroup
|
||||
}
|
||||
|
||||
func (base *RerankBase) GetRankName() string {
|
||||
return base.rerankerName
|
||||
}
|
||||
86
internal/util/function/rerank/util.go
Normal file
86
internal/util/function/rerank/util.go
Normal file
@ -0,0 +1,86 @@
|
||||
/*
|
||||
* # Licensed to the LF AI & Data foundation under one
|
||||
* # or more contributor license agreements. See the NOTICE file
|
||||
* # distributed with this work for additional information
|
||||
* # regarding copyright ownership. The ASF licenses this file
|
||||
* # to you under the Apache License, Version 2.0 (the
|
||||
* # "License"); you may not use this file except in compliance
|
||||
* # with the License. You may obtain a copy of the License at
|
||||
* #
|
||||
* # http://www.apache.org/licenses/LICENSE-2.0
|
||||
* #
|
||||
* # Unless required by applicable law or agreed to in writing, software
|
||||
* # distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* # See the License for the specific language governing permissions and
|
||||
* # limitations under the License.
|
||||
*/
|
||||
|
||||
package rerank
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
)
|
||||
|
||||
type milvusIDs[T int64 | string] struct {
|
||||
data []T
|
||||
}
|
||||
|
||||
func newMilvusIDs(ids *schemapb.IDs, pkType schemapb.DataType) any {
|
||||
switch pkType {
|
||||
case schemapb.DataType_Int64:
|
||||
return milvusIDs[int64]{ids.GetIntId().Data}
|
||||
case schemapb.DataType_String:
|
||||
return milvusIDs[string]{ids.GetStrId().Data}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// idSocres maps a primary key to its score.
type idSocres[T int64 | string] struct {
	idScoreMap map[T]float32
}

// newIdScores returns an empty, ready-to-use idSocres.
func newIdScores[T int64 | string]() *idSocres[T] {
	scores := idSocres[T]{idScoreMap: map[T]float32{}}
	return &scores
}

// exist reports whether a score has been recorded for id.
func (ids *idSocres[T]) exist(id T) bool {
	_, ok := ids.idScoreMap[id]
	return ok
}

// set records scores as the score of id, replacing any previous value.
func (ids *idSocres[T]) set(id T, scores float32) {
	ids.idScoreMap[id] = scores
}

// get returns the score recorded for id, or 0 when none has been recorded.
func (ids *idSocres[T]) get(id T) float32 {
	score := ids.idScoreMap[id]
	return score
}
|
||||
|
||||
// numberField is a typed view over a numeric column.
type numberField[T int32 | int64 | float32 | float64] struct {
	data []T
}

// GetFloat64 returns the element at idx converted to float64.
// It returns an error, rather than panicking, when idx is negative or not
// smaller than the number of elements.
func (n *numberField[T]) GetFloat64(idx int) (float64, error) {
	if idx < 0 {
		return 0.0, fmt.Errorf("Get field err, idx:%d is negative", idx)
	}
	if len(n.data) <= idx {
		return 0.0, fmt.Errorf("Get field err, idx:%d is larger than data size:%d", idx, len(n.data))
	}
	return float64(n.data[idx]), nil
}
|
||||
|
||||
func getNumberic(inputField *schemapb.FieldData) (any, error) {
|
||||
switch inputField.Type {
|
||||
case schemapb.DataType_Int8, schemapb.DataType_Int16, schemapb.DataType_Int32:
|
||||
return &numberField[int32]{inputField.GetScalars().GetIntData().Data}, nil
|
||||
case schemapb.DataType_Int64:
|
||||
return &numberField[int64]{inputField.GetScalars().GetLongData().Data}, nil
|
||||
case schemapb.DataType_Float:
|
||||
return &numberField[float32]{inputField.GetScalars().GetFloatData().Data}, nil
|
||||
case schemapb.DataType_Double:
|
||||
return &numberField[float64]{inputField.GetScalars().GetDoubleData().Data}, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("Unsupported field type:%s, only support numberic field", inputField.Type.String())
|
||||
}
|
||||
}
|
||||
@ -64,7 +64,7 @@ class TestMilvusClientHybridSearchInvalid(TestMilvusClientV2Base):
|
||||
ranker = WeightedRanker(0.2, 0.8)
|
||||
error = {ct.err_code: 100,
|
||||
ct.err_msg: f"collection not found[database=default][collection={name}]"}
|
||||
self.hybrid_search(client, name, [sub_search1], ranker, limit=default_limit,
|
||||
self.hybrid_search(client, name, [sub_search1, sub_search1], ranker, limit=default_limit,
|
||||
check_task=CheckTasks.err_res, check_items=error)
|
||||
self.drop_collection(client, collection_name)
|
||||
|
||||
@ -88,7 +88,7 @@ class TestMilvusClientHybridSearchInvalid(TestMilvusClientV2Base):
|
||||
ranker = WeightedRanker(0.2, 0.8)
|
||||
error = {ct.err_code: 100,
|
||||
ct.err_msg: f"collection not found[database=default][collection={name}]"}
|
||||
self.hybrid_search(client, name, [sub_search1], ranker, limit=default_limit,
|
||||
self.hybrid_search(client, name, [sub_search1, sub_search1], ranker, limit=default_limit,
|
||||
check_task=CheckTasks.err_res, check_items=error)
|
||||
self.drop_collection(client, collection_name)
|
||||
|
||||
@ -133,7 +133,7 @@ class TestMilvusClientHybridSearchInvalid(TestMilvusClientV2Base):
|
||||
ranker = WeightedRanker(0.2, 0.8)
|
||||
error = {ct.err_code: 100,
|
||||
ct.err_msg: f"collection not found[database=default][collection=1]"}
|
||||
self.hybrid_search(client, collection_name, [sub_search1], invalid_ranker, limit=default_limit,
|
||||
self.hybrid_search(client, collection_name, [sub_search1, sub_search1], invalid_ranker, limit=default_limit,
|
||||
check_task=CheckTasks.err_res, check_items=error)
|
||||
self.drop_collection(client, collection_name)
|
||||
|
||||
@ -156,7 +156,7 @@ class TestMilvusClientHybridSearchInvalid(TestMilvusClientV2Base):
|
||||
ranker = WeightedRanker(0.2, 0.8)
|
||||
error = {ct.err_code: 1,
|
||||
ct.err_msg: f"`limit` value {invalid_limit} is illegal"}
|
||||
self.hybrid_search(client, collection_name, [sub_search1], ranker, limit=invalid_limit,
|
||||
self.hybrid_search(client, collection_name, [sub_search1, sub_search1], ranker, limit=invalid_limit,
|
||||
check_task=CheckTasks.err_res, check_items=error)
|
||||
self.drop_collection(client, collection_name)
|
||||
|
||||
@ -179,7 +179,7 @@ class TestMilvusClientHybridSearchInvalid(TestMilvusClientV2Base):
|
||||
ranker = WeightedRanker(0.2, 0.8)
|
||||
error = {ct.err_code: 65535,
|
||||
ct.err_msg: "invalid max query result window, (offset+limit) should be in range [1, 16384], but got 16385"}
|
||||
self.hybrid_search(client, collection_name, [sub_search1], ranker, limit=invalid_limit,
|
||||
self.hybrid_search(client, collection_name, [sub_search1, sub_search1], ranker, limit=invalid_limit,
|
||||
check_task=CheckTasks.err_res, check_items=error)
|
||||
self.drop_collection(client, collection_name)
|
||||
|
||||
@ -202,7 +202,7 @@ class TestMilvusClientHybridSearchInvalid(TestMilvusClientV2Base):
|
||||
ranker = WeightedRanker(0.2, 0.8)
|
||||
error = {ct.err_code: 1,
|
||||
ct.err_msg: f"`output_fields` value {invalid_output_fields} is illegal"}
|
||||
self.hybrid_search(client, collection_name, [sub_search1], ranker, limit=default_limit,
|
||||
self.hybrid_search(client, collection_name, [sub_search1, sub_search1], ranker, limit=default_limit,
|
||||
output_fields=invalid_output_fields, check_task=CheckTasks.err_res, check_items=error)
|
||||
self.drop_collection(client, collection_name)
|
||||
|
||||
@ -226,7 +226,7 @@ class TestMilvusClientHybridSearchInvalid(TestMilvusClientV2Base):
|
||||
ranker = WeightedRanker(0.2, 0.8)
|
||||
error = {ct.err_code: 1,
|
||||
ct.err_msg: f"`partition_name_array` value {invalid_partition_names} is illegal"}
|
||||
self.hybrid_search(client, collection_name, [sub_search1], ranker, limit=default_limit,
|
||||
self.hybrid_search(client, collection_name, [sub_search1, sub_search1], ranker, limit=default_limit,
|
||||
partition_names=invalid_partition_names, check_task=CheckTasks.err_res, check_items=error)
|
||||
self.drop_collection(client, collection_name)
|
||||
|
||||
@ -249,7 +249,7 @@ class TestMilvusClientHybridSearchInvalid(TestMilvusClientV2Base):
|
||||
ranker = WeightedRanker(0.2, 0.8)
|
||||
error = {ct.err_code: 65535,
|
||||
ct.err_msg: f"partition name {invalid_partition_names} not found"}
|
||||
self.hybrid_search(client, collection_name, [sub_search1], ranker, limit=default_limit,
|
||||
self.hybrid_search(client, collection_name, [sub_search1, sub_search1], ranker, limit=default_limit,
|
||||
partition_names=[invalid_partition_names], check_task=CheckTasks.err_res,
|
||||
check_items=error)
|
||||
self.drop_collection(client, collection_name)
|
||||
@ -274,7 +274,7 @@ class TestMilvusClientHybridSearchInvalid(TestMilvusClientV2Base):
|
||||
error = {ct.err_code: 1100,
|
||||
ct.err_msg: f"failed to create query plan: failed to get field schema by name: "
|
||||
f"fieldName({not_exist_vector_field}) not found: invalid parameter"}
|
||||
self.hybrid_search(client, collection_name, [sub_search1], ranker, limit=default_limit,
|
||||
self.hybrid_search(client, collection_name, [sub_search1, sub_search1], ranker, limit=default_limit,
|
||||
check_task=CheckTasks.err_res, check_items=error)
|
||||
self.drop_collection(client, collection_name)
|
||||
|
||||
@ -485,4 +485,4 @@ class TestMilvusClientHybridSearchValid(TestMilvusClientV2Base):
|
||||
"nq": len(vectors_to_search),
|
||||
"ids": insert_ids,
|
||||
"limit": default_limit})
|
||||
self.drop_collection(client, collection_name)
|
||||
self.drop_collection(client, collection_name)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user