diff --git a/internal/proxy/task_hybrid_search.go b/internal/proxy/task_hybrid_search.go index 0589295876..4c8b4cc74f 100644 --- a/internal/proxy/task_hybrid_search.go +++ b/internal/proxy/task_hybrid_search.go @@ -21,6 +21,7 @@ import ( "github.com/milvus-io/milvus/pkg/util/commonpbutil" "github.com/milvus-io/milvus/pkg/util/funcutil" "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/metric" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/timerecord" "github.com/milvus-io/milvus/pkg/util/typeutil" @@ -395,8 +396,10 @@ func (t *hybridSearchTask) PostExecute(ctx context.Context) error { return err } + metricType := "" t.queryChannelsTs = make(map[string]uint64) for _, r := range t.resultBuf.Collect() { + metricType = r.GetResults()[0].GetMetricType() for ch, ts := range r.GetChannelsMvcc() { t.queryChannelsTs[ch] = ts } @@ -416,6 +419,7 @@ func (t *hybridSearchTask) PostExecute(ctx context.Context) error { t.result, err = rankSearchResultData(ctx, 1, t.rankParams, primaryFieldSchema.GetDataType(), + metricType, t.multipleRecallResults.Collect()) if err != nil { log.Warn("rank search result failed", zap.Error(err)) @@ -468,6 +472,7 @@ func rankSearchResultData(ctx context.Context, nq int64, params *rankParams, pkType schemapb.DataType, + metricType string, searchResults []*milvuspb.SearchResults, ) (*milvuspb.SearchResults, error) { tr := timerecord.NewTimeRecorder("rankSearchResultData") @@ -483,7 +488,8 @@ func rankSearchResultData(ctx context.Context, zap.Int("len(searchResults)", len(searchResults)), zap.Int64("nq", nq), zap.Int64("offset", offset), - zap.Int64("limit", limit)) + zap.Int64("limit", limit), + zap.String("metric type", metricType)) ret := &milvuspb.SearchResults{ Status: merr.Success(), @@ -546,9 +552,18 @@ func rankSearchResultData(ctx context.Context, } // sort id by score - sort.Slice(keys, func(i, j int) bool { - return idSet[keys[i]] >= idSet[keys[j]] - }) + var less func(i, j int) bool + if metric.PositivelyRelated(metricType) { + less = func(i, j int) bool { + return idSet[keys[i]] > idSet[keys[j]] + } + } else { + less = func(i, j int) bool { + return idSet[keys[i]] < idSet[keys[j]] + } + } + + sort.Slice(keys, less) if int64(len(keys)) > topk { keys = keys[:topk]