From f1a4526bac16d63868ccc5cea3d4fc6eea3b7e8d Mon Sep 17 00:00:00 2001 From: junjiejiangjjj Date: Tue, 10 Jun 2025 18:08:35 +0800 Subject: [PATCH] enhance: refactor rrf and weighted rerank (#42154) https://github.com/milvus-io/milvus/issues/35856 Signed-off-by: junjie.jiang --- internal/proxy/reScorer.go | 246 -------------- internal/proxy/reScorer_test.go | 146 -------- internal/proxy/search_reduce_util.go | 312 ------------------ internal/proxy/search_reduce_util_test.go | 81 ----- internal/proxy/search_util.go | 8 + internal/proxy/task_search.go | 79 +---- internal/proxy/task_search_test.go | 2 +- internal/proxy/util.go | 4 - .../util/function/mock_embedding_service.go | 59 ++++ .../util/function/rerank/decay_function.go | 14 +- .../function/rerank/decay_function_test.go | 70 ++-- .../util/function/rerank/function_score.go | 115 ++++++- .../function/rerank/function_score_test.go | 174 ++++++++-- .../util/function/rerank/model_function.go | 9 +- .../function/rerank/model_function_test.go | 50 +-- internal/util/function/rerank/rrf_function.go | 101 ++++++ .../util/function/rerank/rrf_function_test.go | 250 ++++++++++++++ internal/util/function/rerank/util.go | 199 +++++++++-- .../util/function/rerank/weighted_function.go | 154 +++++++++ .../function/rerank/weighted_function_test.go | 298 +++++++++++++++++ .../partial_result_on_node_down_test.go | 1 + .../test_milvus_client_search.py | 6 +- 22 files changed, 1368 insertions(+), 1010 deletions(-) delete mode 100644 internal/proxy/reScorer.go delete mode 100644 internal/proxy/reScorer_test.go create mode 100644 internal/util/function/rerank/rrf_function.go create mode 100644 internal/util/function/rerank/rrf_function_test.go create mode 100644 internal/util/function/rerank/weighted_function.go create mode 100644 internal/util/function/rerank/weighted_function_test.go diff --git a/internal/proxy/reScorer.go b/internal/proxy/reScorer.go deleted file mode 100644 index 730415bfb0..0000000000 --- a/internal/proxy/reScorer.go +++ /dev/null @@ -1,246 +0,0 @@ -package proxy - -import ( - "context" - "fmt" - "math" - "reflect" - "strings" - - "github.com/cockroachdb/errors" - "go.uber.org/zap" - - "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" - "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" - "github.com/milvus-io/milvus/internal/json" - "github.com/milvus-io/milvus/pkg/v2/log" - "github.com/milvus-io/milvus/pkg/v2/util/funcutil" - "github.com/milvus-io/milvus/pkg/v2/util/merr" - "github.com/milvus-io/milvus/pkg/v2/util/metric" -) - -type rankType int - -const ( - invalidRankType rankType = iota // invalidRankType = 0 - rrfRankType // rrfRankType = 1 - weightedRankType // weightedRankType = 2 - udfExprRankType // udfExprRankType = 3 -) - -var rankTypeMap = map[string]rankType{ - "invalid": invalidRankType, - "rrf": rrfRankType, - "weighted": weightedRankType, - "expr": udfExprRankType, -} - -type reScorer interface { - name() string - scorerType() rankType - reScore(input *milvuspb.SearchResults) - setMetricType(metricType string) - getMetricType() string -} - -type baseScorer struct { - scorerName string - metricType string -} - -func (bs *baseScorer) name() string { - return bs.scorerName -} - -func (bs *baseScorer) setMetricType(metricType string) { - bs.metricType = metricType -} - -func (bs *baseScorer) getMetricType() string { - return bs.metricType -} - -type rrfScorer struct { - baseScorer - k float32 -} - -func (rs *rrfScorer) reScore(input *milvuspb.SearchResults) { - index := 0 - for _, topk := range 
input.Results.GetTopks() { - for i := int64(0); i < topk; i++ { - input.Results.Scores[index] = 1 / (rs.k + float32(i+1)) - index++ - } - } -} - -func (rs *rrfScorer) scorerType() rankType { - return rrfRankType -} - -type weightedScorer struct { - baseScorer - weight float32 - normScore bool -} - -type activateFunc func(float32) float32 - -func (ws *weightedScorer) getActivateFunc() activateFunc { - if !ws.normScore { - return func(distance float32) float32 { - return distance - } - } - mUpper := strings.ToUpper(ws.getMetricType()) - isCosine := mUpper == strings.ToUpper(metric.COSINE) - isIP := mUpper == strings.ToUpper(metric.IP) - isBM25 := mUpper == strings.ToUpper(metric.BM25) - if isCosine { - f := func(distance float32) float32 { - return (1 + distance) * 0.5 - } - return f - } - - if isIP { - f := func(distance float32) float32 { - return 0.5 + float32(math.Atan(float64(distance)))/math.Pi - } - return f - } - - if isBM25 { - f := func(distance float32) float32 { - return 2 * float32(math.Atan(float64(distance))) / math.Pi - } - return f - } - - f := func(distance float32) float32 { - return 1.0 - 2*float32(math.Atan(float64(distance)))/math.Pi - } - return f -} - -func (ws *weightedScorer) reScore(input *milvuspb.SearchResults) { - activateF := ws.getActivateFunc() - for i, distance := range input.Results.GetScores() { - input.Results.Scores[i] = ws.weight * activateF(distance) - } -} - -func (ws *weightedScorer) scorerType() rankType { - return weightedRankType -} - -func NewReScorers(ctx context.Context, reqCnt int, rankParams []*commonpb.KeyValuePair) ([]reScorer, error) { - if reqCnt == 0 { - return []reScorer{}, nil - } - - log := log.Ctx(ctx) - res := make([]reScorer, reqCnt) - rankTypeStr, err := funcutil.GetAttrByKeyFromRepeatedKV(RankTypeKey, rankParams) - if err != nil { - log.Info("rank strategy not specified, use rrf instead") - // if not set rank strategy, use rrf rank as default - for i := 0; i < reqCnt; i++ { - res[i] = &rrfScorer{ - baseScorer: baseScorer{ - scorerName: "rrf", - }, - k: float32(defaultRRFParamsValue), - } - } - return res, nil - } - - if _, ok := rankTypeMap[rankTypeStr]; !ok { - return nil, errors.Errorf("unsupported rank type %s", rankTypeStr) - } - - paramStr, err := funcutil.GetAttrByKeyFromRepeatedKV(RankParamsKey, rankParams) - if err != nil { - return nil, errors.New(RankParamsKey + " not found in rank_params") - } - - var params map[string]interface{} - err = json.Unmarshal([]byte(paramStr), ¶ms) - if err != nil { - return nil, err - } - - switch rankTypeMap[rankTypeStr] { - case rrfRankType: - _, ok := params[RRFParamsKey] - if !ok { - return nil, errors.New(RRFParamsKey + " not found in rank_params") - } - var k float64 - if reflect.ValueOf(params[RRFParamsKey]).CanFloat() { - k = reflect.ValueOf(params[RRFParamsKey]).Float() - } else { - return nil, errors.New("The type of rank param k should be float") - } - if k <= 0 || k >= maxRRFParamsValue { - return nil, errors.New(fmt.Sprintf("The rank params k should be in range (0, %d)", maxRRFParamsValue)) - } - log.Debug("rrf params", zap.Float64("k", k)) - for i := 0; i < reqCnt; i++ { - res[i] = &rrfScorer{ - baseScorer: baseScorer{ - scorerName: "rrf", - }, - k: float32(k), - } - } - case weightedRankType: - if _, ok := params[WeightsParamsKey]; !ok { - return nil, errors.New(WeightsParamsKey + " not found in rank_params") - } - // normalize scores by default - normScore := true - if _, ok := params[NormScoreKey]; ok { - normScore = params[NormScoreKey].(bool) - } - weights := 
make([]float32, 0) - switch reflect.TypeOf(params[WeightsParamsKey]).Kind() { - case reflect.Slice: - rs := reflect.ValueOf(params[WeightsParamsKey]) - for i := 0; i < rs.Len(); i++ { - v := rs.Index(i).Elem() - if v.CanFloat() { - weight := v.Float() - if weight < 0 || weight > 1 { - return nil, errors.New("rank param weight should be in range [0, 1]") - } - weights = append(weights, float32(weight)) - } else { - return nil, errors.New("The type of rank param weight should be float") - } - } - default: - return nil, errors.New("The weights param should be an array") - } - - log.Debug("weights params", zap.Any("weights", weights), zap.Bool("norm_score", normScore)) - if reqCnt != len(weights) { - return nil, merr.WrapErrParameterInvalid(fmt.Sprint(reqCnt), fmt.Sprint(len(weights)), "the length of weights param mismatch with ann search requests") - } - for i := 0; i < reqCnt; i++ { - res[i] = &weightedScorer{ - baseScorer: baseScorer{ - scorerName: "weighted", - }, - weight: weights[i], - normScore: normScore, - } - } - default: - return nil, errors.Errorf("unsupported rank type %s", rankTypeStr) - } - - return res, nil -} diff --git a/internal/proxy/reScorer_test.go b/internal/proxy/reScorer_test.go deleted file mode 100644 index 1f16768c47..0000000000 --- a/internal/proxy/reScorer_test.go +++ /dev/null @@ -1,146 +0,0 @@ -package proxy - -import ( - "context" - "testing" - - "github.com/stretchr/testify/assert" - - "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" - "github.com/milvus-io/milvus/internal/json" -) - -func TestRescorer(t *testing.T) { - t.Run("default scorer", func(t *testing.T) { - rescorers, err := NewReScorers(context.TODO(), 2, nil) - assert.NoError(t, err) - assert.Equal(t, 2, len(rescorers)) - assert.Equal(t, rrfRankType, rescorers[0].scorerType()) - }) - - t.Run("rrf without param", func(t *testing.T) { - params := make(map[string]float64) - b, err := json.Marshal(params) - assert.NoError(t, err) - rankParams := []*commonpb.KeyValuePair{ - {Key: RankTypeKey, Value: "rrf"}, - {Key: RankParamsKey, Value: string(b)}, - } - - _, err = NewReScorers(context.TODO(), 2, rankParams) - assert.Error(t, err) - assert.Contains(t, err.Error(), "k not found in rank_params") - }) - - t.Run("rrf param out of range", func(t *testing.T) { - params := make(map[string]float64) - params[RRFParamsKey] = -1 - b, err := json.Marshal(params) - assert.NoError(t, err) - rankParams := []*commonpb.KeyValuePair{ - {Key: RankTypeKey, Value: "rrf"}, - {Key: RankParamsKey, Value: string(b)}, - } - - _, err = NewReScorers(context.TODO(), 2, rankParams) - assert.Error(t, err) - - params[RRFParamsKey] = maxRRFParamsValue + 1 - b, err = json.Marshal(params) - assert.NoError(t, err) - rankParams = []*commonpb.KeyValuePair{ - {Key: RankTypeKey, Value: "rrf"}, - {Key: RankParamsKey, Value: string(b)}, - } - - _, err = NewReScorers(context.TODO(), 2, rankParams) - assert.Error(t, err) - }) - - t.Run("rrf", func(t *testing.T) { - params := make(map[string]float64) - params[RRFParamsKey] = 61 - b, err := json.Marshal(params) - assert.NoError(t, err) - rankParams := []*commonpb.KeyValuePair{ - {Key: RankTypeKey, Value: "rrf"}, - {Key: RankParamsKey, Value: string(b)}, - } - - rescorers, err := NewReScorers(context.TODO(), 2, rankParams) - assert.NoError(t, err) - assert.Equal(t, 2, len(rescorers)) - assert.Equal(t, rrfRankType, rescorers[0].scorerType()) - assert.Equal(t, float32(61), rescorers[0].(*rrfScorer).k) - }) - - t.Run("weights without param", func(t *testing.T) { - params := 
make(map[string][]float64) - b, err := json.Marshal(params) - assert.NoError(t, err) - rankParams := []*commonpb.KeyValuePair{ - {Key: RankTypeKey, Value: "weighted"}, - {Key: RankParamsKey, Value: string(b)}, - } - - _, err = NewReScorers(context.TODO(), 2, rankParams) - assert.Error(t, err) - assert.Contains(t, err.Error(), "not found in rank_params") - }) - - t.Run("weights out of range", func(t *testing.T) { - weights := []float64{1.2, 2.3} - params := make(map[string][]float64) - params[WeightsParamsKey] = weights - b, err := json.Marshal(params) - assert.NoError(t, err) - rankParams := []*commonpb.KeyValuePair{ - {Key: RankTypeKey, Value: "weighted"}, - {Key: RankParamsKey, Value: string(b)}, - } - - _, err = NewReScorers(context.TODO(), 2, rankParams) - assert.Error(t, err) - assert.Contains(t, err.Error(), "rank param weight should be in range [0, 1]") - }) - - t.Run("weights with norm_score false", func(t *testing.T) { - weights := []float64{0.5, 0.2} - params := make(map[string]interface{}) - params[WeightsParamsKey] = weights - params[NormScoreKey] = false - b, err := json.Marshal(params) - assert.NoError(t, err) - rankParams := []*commonpb.KeyValuePair{ - {Key: RankTypeKey, Value: "weighted"}, - {Key: RankParamsKey, Value: string(b)}, - } - - rescorers, err := NewReScorers(context.TODO(), 2, rankParams) - assert.NoError(t, err) - assert.Equal(t, 2, len(rescorers)) - assert.Equal(t, weightedRankType, rescorers[0].scorerType()) - assert.Equal(t, float32(weights[0]), rescorers[0].(*weightedScorer).weight) - assert.False(t, rescorers[0].(*weightedScorer).normScore) - }) - - t.Run("weights", func(t *testing.T) { - weights := []float64{0.5, 0.2} - params := make(map[string]interface{}) - params[WeightsParamsKey] = weights - b, err := json.Marshal(params) - assert.NoError(t, err) - rankParams := []*commonpb.KeyValuePair{ - {Key: RankTypeKey, Value: "weighted"}, - {Key: RankParamsKey, Value: string(b)}, - } - - rescorers, err := NewReScorers(context.TODO(), 2, rankParams) - assert.NoError(t, err) - assert.Equal(t, 2, len(rescorers)) - assert.Equal(t, weightedRankType, rescorers[0].scorerType()) - assert.Equal(t, float32(weights[0]), rescorers[0].(*weightedScorer).weight) - // normalize scores by default - assert.True(t, rescorers[0].(*weightedScorer).normScore) - }) -} diff --git a/internal/proxy/search_reduce_util.go b/internal/proxy/search_reduce_util.go index ca50927cbe..a12b6d2102 100644 --- a/internal/proxy/search_reduce_util.go +++ b/internal/proxy/search_reduce_util.go @@ -3,8 +3,6 @@ package proxy import ( "context" "fmt" - "math" - "sort" "github.com/cockroachdb/errors" "go.uber.org/zap" @@ -432,21 +430,6 @@ func reduceSearchResultDataNoGroupBy(ctx context.Context, subSearchResultData [] return ret, nil } -func rankSearchResultData(ctx context.Context, - nq int64, - params *rankParams, - pkType schemapb.DataType, - searchResults []*milvuspb.SearchResults, - groupByFieldID int64, - groupSize int64, - groupScorer func(group *Group) error, -) (*milvuspb.SearchResults, error) { - if groupByFieldID > 0 { - return rankSearchResultDataByGroup(ctx, nq, params, pkType, searchResults, groupScorer, groupSize) - } - return rankSearchResultDataByPk(ctx, nq, params, pkType, searchResults) -} - func compareKey(keyI interface{}, keyJ interface{}) bool { switch keyI.(type) { case int64: @@ -457,213 +440,6 @@ func compareKey(keyI interface{}, keyJ interface{}) bool { return false } -func GetGroupScorer(scorerType string) (func(group *Group) error, error) { - switch scorerType { - case 
MaxScorer: - return func(group *Group) error { - group.finalScore = group.maxScore - return nil - }, nil - case SumScorer: - return func(group *Group) error { - group.finalScore = group.sumScore - return nil - }, nil - case AvgScorer: - return func(group *Group) error { - if len(group.idList) == 0 { - return merr.WrapErrParameterInvalid(1, len(group.idList), - "input group for score must have at least one id, must be sth wrong within code") - } - group.finalScore = group.sumScore / float32(len(group.idList)) - return nil - }, nil - default: - return nil, merr.WrapErrParameterInvalidMsg("input group scorer type: %s is not supported!", scorerType) - } -} - -type Group struct { - idList []interface{} - scoreList []float32 - groupVal interface{} - maxScore float32 - sumScore float32 - finalScore float32 -} - -func rankSearchResultDataByGroup(ctx context.Context, - nq int64, - params *rankParams, - pkType schemapb.DataType, - searchResults []*milvuspb.SearchResults, - groupScorer func(group *Group) error, - groupSize int64, -) (*milvuspb.SearchResults, error) { - tr := timerecord.NewTimeRecorder("rankSearchResultDataByGroup") - defer func() { - tr.CtxElapse(ctx, "done") - }() - offset, limit, roundDecimal := params.offset, params.limit, params.roundDecimal - // in the context of group by, the meaning for offset/limit/top refers to related numbers of group - groupTopK := limit + offset - log.Ctx(ctx).Debug("rankSearchResultDataByGroup", - zap.Int("len(searchResults)", len(searchResults)), - zap.Int64("nq", nq), - zap.Int64("offset", offset), - zap.Int64("limit", limit)) - - var ret *milvuspb.SearchResults - if ret = initSearchResults(nq, limit); len(searchResults) == 0 { - return ret, nil - } - - totalCount := limit * groupSize - if err := setupIdListForSearchResult(ret, pkType, totalCount); err != nil { - return ret, err - } - - type accumulateIDGroupVal struct { - accumulatedScore float32 - groupVal interface{} - } - - accumulatedScores := make([]map[interface{}]*accumulateIDGroupVal, nq) - for i := int64(0); i < nq; i++ { - accumulatedScores[i] = make(map[interface{}]*accumulateIDGroupVal) - } - groupByDataType := searchResults[0].GetResults().GetGroupByFieldValue().GetType() - for _, result := range searchResults { - scores := result.GetResults().GetScores() - start := 0 - // milvus has limits for the value range of nq and limit - // no matter on 32-bit and 64-bit platform, converting nq and topK into int is safe - groupByValIterator := typeutil.GetDataIterator(result.GetResults().GetGroupByFieldValue()) - for i := 0; i < int(nq); i++ { - realTopK := int(result.GetResults().Topks[i]) - for j := start; j < start+realTopK; j++ { - id := typeutil.GetPK(result.GetResults().GetIds(), int64(j)) - groupByVal := groupByValIterator(j) - if accumulatedScores[i][id] != nil { - accumulatedScores[i][id].accumulatedScore += scores[j] - } else { - accumulatedScores[i][id] = &accumulateIDGroupVal{accumulatedScore: scores[j], groupVal: groupByVal} - } - } - start += realTopK - } - } - - gpFieldBuilder, err := typeutil.NewFieldDataBuilder(groupByDataType, true, int(limit)) - if err != nil { - return ret, err - } - for i := int64(0); i < nq; i++ { - idSet := accumulatedScores[i] - keys := make([]interface{}, 0) - for key := range idSet { - keys = append(keys, key) - } - - // sort id by score - big := func(i, j int) bool { - scoreItemI := idSet[keys[i]] - scoreItemJ := idSet[keys[j]] - if scoreItemI.accumulatedScore == scoreItemJ.accumulatedScore { - return compareKey(keys[i], keys[j]) - } - return 
scoreItemI.accumulatedScore > scoreItemJ.accumulatedScore - } - sort.Slice(keys, big) - - // separate keys into buckets according to groupVal - buckets := make(map[interface{}]*Group) - for _, key := range keys { - scoreItem := idSet[key] - groupVal := scoreItem.groupVal - if buckets[groupVal] == nil { - buckets[groupVal] = &Group{ - idList: make([]interface{}, 0), - scoreList: make([]float32, 0), - groupVal: groupVal, - } - } - if int64(len(buckets[groupVal].idList)) >= groupSize { - // only consider group size results in each group - continue - } - buckets[groupVal].idList = append(buckets[groupVal].idList, key) - buckets[groupVal].scoreList = append(buckets[groupVal].scoreList, scoreItem.accumulatedScore) - if scoreItem.accumulatedScore > buckets[groupVal].maxScore { - buckets[groupVal].maxScore = scoreItem.accumulatedScore - } - buckets[groupVal].sumScore += scoreItem.accumulatedScore - } - if int64(len(buckets)) <= offset { - ret.Results.Topks = append(ret.Results.Topks, 0) - continue - } - - groupList := make([]*Group, len(buckets)) - idx := 0 - for _, group := range buckets { - groupScorer(group) - groupList[idx] = group - idx += 1 - } - sort.Slice(groupList, func(i, j int) bool { - if groupList[i].finalScore == groupList[j].finalScore { - if len(groupList[i].idList) == len(groupList[j].idList) { - // if final score and size of group are both equal - // choose the group with smaller first key - // here, it's guaranteed all group having at least one id in the idList - return compareKey(groupList[i].idList[0], groupList[j].idList[0]) - } - // choose the larger group when scores are equal - return len(groupList[i].idList) > len(groupList[j].idList) - } - return groupList[i].finalScore > groupList[j].finalScore - }) - - if int64(len(groupList)) > groupTopK { - groupList = groupList[:groupTopK] - } - returnedRowNum := 0 - for index := int(offset); index < len(groupList); index++ { - group := groupList[index] - for i, score := range group.scoreList { - // idList and scoreList must have same length - typeutil.AppendPKs(ret.Results.Ids, group.idList[i]) - if roundDecimal != -1 { - multiplier := math.Pow(10.0, float64(roundDecimal)) - score = float32(math.Floor(float64(score)*multiplier+0.5) / multiplier) - } - ret.Results.Scores = append(ret.Results.Scores, score) - gpFieldBuilder.Add(group.groupVal) - } - returnedRowNum += len(group.idList) - } - ret.Results.Topks = append(ret.Results.Topks, int64(returnedRowNum)) - } - - ret.Results.GroupByFieldValue = gpFieldBuilder.Build() - return ret, nil -} - -func initSearchResults(nq int64, limit int64) *milvuspb.SearchResults { - return &milvuspb.SearchResults{ - Status: merr.Success(), - Results: &schemapb.SearchResultData{ - NumQueries: nq, - TopK: limit, - FieldsData: make([]*schemapb.FieldData, 0), - Scores: []float32{}, - Ids: &schemapb.IDs{}, - Topks: []int64{}, - }, - } -} - func setupIdListForSearchResult(searchResult *milvuspb.SearchResults, pkType schemapb.DataType, capacity int64) error { switch pkType { case schemapb.DataType_Int64: @@ -684,94 +460,6 @@ func setupIdListForSearchResult(searchResult *milvuspb.SearchResults, pkType sch return nil } -func rankSearchResultDataByPk(ctx context.Context, - nq int64, - params *rankParams, - pkType schemapb.DataType, - searchResults []*milvuspb.SearchResults, -) (*milvuspb.SearchResults, error) { - tr := timerecord.NewTimeRecorder("rankSearchResultDataByPk") - defer func() { - tr.CtxElapse(ctx, "done") - }() - - offset, limit, roundDecimal := params.offset, params.limit, params.roundDecimal - 
topk := limit + offset - log.Ctx(ctx).Debug("rankSearchResultDataByPk", - zap.Int("len(searchResults)", len(searchResults)), - zap.Int64("nq", nq), - zap.Int64("offset", offset), - zap.Int64("limit", limit)) - - var ret *milvuspb.SearchResults - if ret = initSearchResults(nq, limit); len(searchResults) == 0 { - return ret, nil - } - - if err := setupIdListForSearchResult(ret, pkType, limit); err != nil { - return ret, nil - } - - // []map[id]score - accumulatedScores := make([]map[interface{}]float32, nq) - for i := int64(0); i < nq; i++ { - accumulatedScores[i] = make(map[interface{}]float32) - } - - for _, result := range searchResults { - scores := result.GetResults().GetScores() - start := int64(0) - for i := int64(0); i < nq; i++ { - realTopk := result.GetResults().Topks[i] - for j := start; j < start+realTopk; j++ { - id := typeutil.GetPK(result.GetResults().GetIds(), j) - accumulatedScores[i][id] += scores[j] - } - start += realTopk - } - } - - for i := int64(0); i < nq; i++ { - idSet := accumulatedScores[i] - keys := make([]interface{}, 0) - for key := range idSet { - keys = append(keys, key) - } - if int64(len(keys)) <= offset { - ret.Results.Topks = append(ret.Results.Topks, 0) - continue - } - - // sort id by score - big := func(i, j int) bool { - if idSet[keys[i]] == idSet[keys[j]] { - return compareKey(keys[i], keys[j]) - } - return idSet[keys[i]] > idSet[keys[j]] - } - - sort.Slice(keys, big) - - if int64(len(keys)) > topk { - keys = keys[:topk] - } - - // set real topk - ret.Results.Topks = append(ret.Results.Topks, int64(len(keys))-offset) - // append id and score - for index := offset; index < int64(len(keys)); index++ { - typeutil.AppendPKs(ret.Results.Ids, keys[index]) - score := idSet[keys[index]] - if roundDecimal != -1 { - multiplier := math.Pow(10.0, float64(roundDecimal)) - score = float32(math.Floor(float64(score)*multiplier+0.5) / multiplier) - } - ret.Results.Scores = append(ret.Results.Scores, score) - } - } - return ret, nil -} - func fillInEmptyResult(numQueries int64) *milvuspb.SearchResults { return &milvuspb.SearchResults{ Status: merr.Success("search result is empty"), diff --git a/internal/proxy/search_reduce_util_test.go b/internal/proxy/search_reduce_util_test.go index 3f56b524c1..084a8f0079 100644 --- a/internal/proxy/search_reduce_util_test.go +++ b/internal/proxy/search_reduce_util_test.go @@ -6,7 +6,6 @@ import ( "github.com/stretchr/testify/suite" - "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" ) @@ -52,86 +51,6 @@ func genTestDataSearchResultsData() []*schemapb.SearchResultData { return []*schemapb.SearchResultData{searchResultData1, searchResultData2} } -func (struts *SearchReduceUtilTestSuite) TestRankByGroup() { - data := genTestDataSearchResultsData() - searchResults := []*milvuspb.SearchResults{ - {Results: data[0]}, - {Results: data[1]}, - } - - nq := int64(1) - limit := int64(3) - offset := int64(0) - roundDecimal := int64(1) - groupSize := int64(3) - groupByFieldId := int64(101) - rankParams := &rankParams{limit: limit, offset: offset, roundDecimal: roundDecimal} - - { - // test for sum group scorer - scorerType := "sum" - groupScorer, _ := GetGroupScorer(scorerType) - rankedRes, err := rankSearchResultData(context.Background(), nq, rankParams, schemapb.DataType_VarChar, searchResults, groupByFieldId, groupSize, groupScorer) - struts.NoError(err) - struts.Equal([]string{"5", "2", "3", "17", "12", "13", "7", "15", "1"}, rankedRes.GetResults().GetIds().GetStrId().Data) - 
struts.Equal([]float32{0.5, 0.4, 0.4, 0.7, 0.3, 0.3, 0.6, 0.4, 0.3}, rankedRes.GetResults().GetScores()) - struts.Equal([]string{"bbb", "bbb", "bbb", "www", "www", "www", "aaa", "aaa", "aaa"}, rankedRes.GetResults().GetGroupByFieldValue().GetScalars().GetStringData().Data) - } - - { - // test for max group scorer - scorerType := "max" - groupScorer, _ := GetGroupScorer(scorerType) - rankedRes, err := rankSearchResultData(context.Background(), nq, rankParams, schemapb.DataType_VarChar, searchResults, groupByFieldId, groupSize, groupScorer) - struts.NoError(err) - struts.Equal([]string{"17", "12", "13", "7", "15", "1", "5", "2", "3"}, rankedRes.GetResults().GetIds().GetStrId().Data) - struts.Equal([]float32{0.7, 0.3, 0.3, 0.6, 0.4, 0.3, 0.5, 0.4, 0.4}, rankedRes.GetResults().GetScores()) - struts.Equal([]string{"www", "www", "www", "aaa", "aaa", "aaa", "bbb", "bbb", "bbb"}, rankedRes.GetResults().GetGroupByFieldValue().GetScalars().GetStringData().Data) - } - - { - // test for avg group scorer - scorerType := "avg" - groupScorer, _ := GetGroupScorer(scorerType) - rankedRes, err := rankSearchResultData(context.Background(), nq, rankParams, schemapb.DataType_VarChar, searchResults, groupByFieldId, groupSize, groupScorer) - struts.NoError(err) - struts.Equal([]string{"5", "2", "3", "17", "12", "13", "7", "15", "1"}, rankedRes.GetResults().GetIds().GetStrId().Data) - struts.Equal([]float32{0.5, 0.4, 0.4, 0.7, 0.3, 0.3, 0.6, 0.4, 0.3}, rankedRes.GetResults().GetScores()) - struts.Equal([]string{"bbb", "bbb", "bbb", "www", "www", "www", "aaa", "aaa", "aaa"}, rankedRes.GetResults().GetGroupByFieldValue().GetScalars().GetStringData().Data) - } - - { - // test for offset for ranking group - scorerType := "avg" - groupScorer, _ := GetGroupScorer(scorerType) - rankParams.offset = 2 - rankedRes, err := rankSearchResultData(context.Background(), nq, rankParams, schemapb.DataType_VarChar, searchResults, groupByFieldId, groupSize, groupScorer) - struts.NoError(err) - struts.Equal([]string{"7", "15", "1", "4", "6", "14"}, rankedRes.GetResults().GetIds().GetStrId().Data) - struts.Equal([]float32{0.6, 0.4, 0.3, 0.5, 0.3, 0.3}, rankedRes.GetResults().GetScores()) - struts.Equal([]string{"aaa", "aaa", "aaa", "ccc", "ccc", "ccc"}, rankedRes.GetResults().GetGroupByFieldValue().GetScalars().GetStringData().Data) - } - - { - // test for offset exceeding the count of final groups - scorerType := "avg" - groupScorer, _ := GetGroupScorer(scorerType) - rankParams.offset = 4 - rankedRes, err := rankSearchResultData(context.Background(), nq, rankParams, schemapb.DataType_VarChar, searchResults, groupByFieldId, groupSize, groupScorer) - struts.NoError(err) - struts.Equal([]string{}, rankedRes.GetResults().GetIds().GetStrId().Data) - struts.Equal([]float32{}, rankedRes.GetResults().GetScores()) - } - - { - // test for invalid group scorer - scorerType := "xxx" - groupScorer, err := GetGroupScorer(scorerType) - struts.Error(err) - struts.Nil(groupScorer) - } -} - func (struts *SearchReduceUtilTestSuite) TestReduceSearchResult() { data := genTestDataSearchResultsData() diff --git a/internal/proxy/search_util.go b/internal/proxy/search_util.go index 3261868303..706e372dcb 100644 --- a/internal/proxy/search_util.go +++ b/internal/proxy/search_util.go @@ -537,6 +537,14 @@ func parseRankParams(rankParamsPair []*commonpb.KeyValuePair, schema *schemapb.C }, nil } +func getGroupScorerStr(params []*commonpb.KeyValuePair) string { + groupScorerStr, err := funcutil.GetAttrByKeyFromRepeatedKV(RankGroupScorer, params) + if err != nil { 
+ groupScorerStr = MaxScorer + } + return groupScorerStr +} + func convertHybridSearchToSearch(req *milvuspb.HybridSearchRequest) *milvuspb.SearchRequest { ret := &milvuspb.SearchRequest{ Base: req.GetBase(), diff --git a/internal/proxy/task_search.go b/internal/proxy/task_search.go index 759804f431..6368a26ac3 100644 --- a/internal/proxy/task_search.go +++ b/internal/proxy/task_search.go @@ -88,10 +88,6 @@ type searchTask struct { queryInfos []*planpb.QueryInfo relatedDataSize int64 - // Will be deprecated, use functionScore after milvus 2.6 - reScorers []reScorer - groupScorer func(group *Group) error - // New reranker functions functionScore *rerank.FunctionScore rankParams *rankParams @@ -378,22 +374,10 @@ func (t *searchTask) initAdvancedSearchRequest(ctx context.Context) error { return err } } else { - t.reScorers, err = NewReScorers(ctx, len(t.request.GetSubReqs()), t.request.GetSearchParams()) - if err != nil { - log.Info("generate reScorer failed", zap.Any("params", t.request.GetSearchParams()), zap.Error(err)) + if t.functionScore, err = rerank.NewFunctionScoreWithlegacy(t.schema.CollectionSchema, t.request.GetSearchParams()); err != nil { + log.Warn("Failed to create function by legacy info", zap.Error(err)) return err } - - // set up groupScorer for hybridsearch+groupBy - groupScorerStr, err := funcutil.GetAttrByKeyFromRepeatedKV(RankGroupScorer, t.request.GetSearchParams()) - if err != nil { - groupScorerStr = MaxScorer - } - groupScorer, err := GetGroupScorer(groupScorerStr) - if err != nil { - return err - } - t.groupScorer = groupScorer } t.needRequery = len(t.request.OutputFields) > 0 || len(t.functionScore.GetAllInputFieldNames()) > 0 @@ -544,22 +528,12 @@ func (t *searchTask) advancedPostProcess(ctx context.Context, span trace.Span, t return err } - if t.functionScore == nil { - t.reScorers[index].setMetricType(subMetricType) - t.reScorers[index].reScore(result) - } searchMetrics = append(searchMetrics, subMetricType) multipleMilvusResults[index] = result } - if t.functionScore == nil { - if err := t.rank(ctx, span, multipleMilvusResults); err != nil { - return err - } - } else { - if err := t.hybridSearchRank(ctx, span, multipleMilvusResults, searchMetrics); err != nil { - return err - } + if err := t.hybridSearchRank(ctx, span, multipleMilvusResults, searchMetrics); err != nil { + return err } t.result.Results.FieldsData = lo.Filter(t.result.Results.FieldsData, func(fieldData *schemapb.FieldData, i int) bool { @@ -583,43 +557,6 @@ func (t *searchTask) fillResult() { t.fillInFieldInfo() } -// TODO: Old version rerank: rrf/weighted, subsequent unified rerank implementation -func (t *searchTask) rank(ctx context.Context, span trace.Span, multipleMilvusResults []*milvuspb.SearchResults) error { - primaryFieldSchema, err := t.schema.GetPkField() - if err != nil { - log.Warn("failed to get primary field schema", zap.Error(err)) - return err - } - if t.result, err = rankSearchResultData(ctx, t.SearchRequest.GetNq(), - t.rankParams, - primaryFieldSchema.GetDataType(), - multipleMilvusResults, - t.SearchRequest.GetGroupByFieldId(), - t.SearchRequest.GetGroupSize(), - t.groupScorer); err != nil { - log.Warn("rank search result failed", zap.Error(err)) - return err - } - - if t.needRequery { - if t.requeryFunc == nil { - t.requeryFunc = requeryImpl - } - queryResult, err := t.requeryFunc(t, span, t.result.Results.Ids, t.translatedOutputFields) - if err != nil { - log.Warn("failed to requery", zap.Error(err)) - return err - } - fields, err := t.reorganizeRequeryResults(ctx, 
queryResult.GetFieldsData(), []*schemapb.IDs{t.result.Results.Ids}) - if err != nil { - return err - } - t.result.Results.FieldsData = fields[0] - } - - return nil -} - func mergeIDs(idsList []*schemapb.IDs) (*schemapb.IDs, int) { uniqueIDs := &schemapb.IDs{} count := 0 @@ -659,10 +596,10 @@ func (t *searchTask) hybridSearchRank(ctx context.Context, span trace.Span, mult processRerank := func(ctx context.Context, results []*milvuspb.SearchResults) (*milvuspb.SearchResults, error) { ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-call-rerank-function-udf") defer sp.End() - + groupScorerStr := getGroupScorerStr(t.request.GetSearchParams()) params := rerank.NewSearchParams( t.Nq, t.rankParams.limit, t.rankParams.offset, t.rankParams.roundDecimal, - t.rankParams.groupByFieldId, t.rankParams.groupSize, t.rankParams.strictGroupSize, searchMetrics, + t.rankParams.groupByFieldId, t.rankParams.groupSize, t.rankParams.strictGroupSize, groupScorerStr, searchMetrics, ) return t.functionScore.Process(ctx, params, results) } @@ -703,6 +640,7 @@ func (t *searchTask) hybridSearchRank(ctx context.Context, span trace.Span, mult for i := 0; i < len(multipleMilvusResults); i++ { multipleMilvusResults[i].Results.FieldsData = fields[i] } + if t.result, err = processRerank(ctx, multipleMilvusResults); err != nil { return err } @@ -838,8 +776,9 @@ func (t *searchTask) searchPostProcess(ctx context.Context, span trace.Span, toR { ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-call-rerank-function-udf") defer sp.End() + groupScorerStr := getGroupScorerStr(t.request.GetSearchParams()) params := rerank.NewSearchParams(t.Nq, t.SearchRequest.GetTopk(), t.SearchRequest.GetOffset(), - t.queryInfos[0].RoundDecimal, t.queryInfos[0].GroupByFieldId, t.queryInfos[0].GroupSize, t.queryInfos[0].StrictGroupSize, []string{metricType}) + t.queryInfos[0].RoundDecimal, t.queryInfos[0].GroupByFieldId, t.queryInfos[0].GroupSize, t.queryInfos[0].StrictGroupSize, groupScorerStr, []string{metricType}) // rank only returns id and score if t.result, err = t.functionScore.Process(ctx, params, []*milvuspb.SearchResults{result}); err != nil { return err diff --git a/internal/proxy/task_search_test.go b/internal/proxy/task_search_test.go index 9b4c6dc5fa..a6b24084ff 100644 --- a/internal/proxy/task_search_test.go +++ b/internal/proxy/task_search_test.go @@ -1000,7 +1000,7 @@ func TestSearchTask_PreExecute(t *testing.T) { require.Equal(t, typeutil.ZeroTimestamp, st.TimeoutTimestamp) enqueueTs := uint64(100000) st.SetTs(enqueueTs) - assert.ErrorContains(t, st.PreExecute(ctx), "Current rerank does not support grouping search") + assert.NoError(t, st.PreExecute(ctx)) }) } diff --git a/internal/proxy/util.go b/internal/proxy/util.go index d4b3c7daf5..664a8e414d 100644 --- a/internal/proxy/util.go +++ b/internal/proxy/util.go @@ -77,9 +77,6 @@ const ( // DefaultStringIndexType name of default index type for varChar/string field DefaultStringIndexType = indexparamcheck.IndexINVERTED - - defaultRRFParamsValue = 60 - maxRRFParamsValue = 16384 ) var logger = log.L().WithOptions(zap.Fields(zap.String("role", typeutil.ProxyRole))) @@ -427,7 +424,6 @@ func validateMaxCapacityPerRow(collectionName string, field *schemapb.FieldSchem if !exist { return fmt.Errorf("type param(max_capacity) should be specified for array field %s of collection %s", field.GetName(), collectionName) } - return nil } diff --git a/internal/util/function/mock_embedding_service.go b/internal/util/function/mock_embedding_service.go index 
ea28303859..f6ed8d236b 100644 --- a/internal/util/function/mock_embedding_service.go +++ b/internal/util/function/mock_embedding_service.go @@ -1,3 +1,6 @@ +//go:build test +// +build test + /* * # Licensed to the LF AI & Data foundation under one * # or more contributor license agreements. See the NOTICE file @@ -22,11 +25,13 @@ import ( "context" "encoding/json" "io" + "math" "net/http" "net/http/httptest" "github.com/aws/aws-sdk-go-v2/service/bedrockruntime" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/util/function/models/ali" "github.com/milvus-io/milvus/internal/util/function/models/cohere" "github.com/milvus-io/milvus/internal/util/function/models/openai" @@ -34,6 +39,7 @@ import ( "github.com/milvus-io/milvus/internal/util/function/models/tei" "github.com/milvus-io/milvus/internal/util/function/models/vertexai" "github.com/milvus-io/milvus/internal/util/function/models/voyageai" + "github.com/milvus-io/milvus/pkg/v2/util/testutils" ) const TestModel string = "TestModel" @@ -247,3 +253,56 @@ func (c *MockBedrockClient) InvokeModel(ctx context.Context, params *bedrockrunt body, _ := json.Marshal(resp) return &bedrockruntime.InvokeModelOutput{Body: body}, nil } + +func GenSearchResultData(nq int64, topk int64, dType schemapb.DataType, fieldName string, fieldId int64) *schemapb.SearchResultData { + tops := make([]int64, nq) + for i := 0; i < int(nq); i++ { + tops[i] = topk + } + fieldsData := []*schemapb.FieldData{} + if fieldName != "" { + fieldsData = []*schemapb.FieldData{testutils.GenerateScalarFieldData(dType, fieldName, int(nq*topk))} + fieldsData[0].FieldId = fieldId + } + + data := &schemapb.SearchResultData{ + NumQueries: nq, + TopK: topk, + Scores: testutils.GenerateFloat32Array(int(nq * topk)), + Ids: &schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{ + Data: testutils.GenerateInt64Array(int(nq * topk)), + }, + }, + }, + Topks: tops, + FieldsData: fieldsData, + } + return data +} + +func GenSearchResultDataWithGrouping(nq int64, topk int64, dType schemapb.DataType, fieldName string, fieldId int64, groupingName string, groupingId int64, groupSize int64) *schemapb.SearchResultData { + data := GenSearchResultData(nq, topk*groupSize, dType, fieldName, fieldId) + values := make([]int64, 0) + for i := int64(0); i < nq*topk*groupSize; i += groupSize { + for j := int64(0); j < groupSize; j++ { + values = append(values, i) + } + } + groupingField := testutils.GenerateScalarFieldDataWithValue(schemapb.DataType_Int64, groupingName, groupingId, values) + data.GroupByFieldValue = groupingField + return data +} + +func FloatsAlmostEqual(a, b []float32, epsilon float32) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if float32(math.Abs(float64(a[i]-b[i]))) > epsilon { + return false + } + } + return true +} diff --git a/internal/util/function/rerank/decay_function.go b/internal/util/function/rerank/decay_function.go index 7ae9db1a72..97fb1b4508 100644 --- a/internal/util/function/rerank/decay_function.go +++ b/internal/util/function/rerank/decay_function.go @@ -55,7 +55,7 @@ type DecayFunction[T PKType, R int32 | int64 | float32 | float64] struct { } func newDecayFunction(collSchema *schemapb.CollectionSchema, funcSchema *schemapb.FunctionSchema) (Reranker, error) { - base, err := newRerankBase(collSchema, funcSchema, decayFunctionName, false) + base, err := newRerankBase(collSchema, funcSchema, decayFunctionName, true) if err != nil { return nil, err } @@ -168,7 +168,7 @@ func 
toGreaterScore(score float32, metricType string) float32 { } } -func (decay *DecayFunction[T, R]) processOneSearchData(ctx context.Context, searchParams *SearchParams, cols []*columns) *IDScores[T] { +func (decay *DecayFunction[T, R]) processOneSearchData(ctx context.Context, searchParams *SearchParams, cols []*columns, idGroup map[any]any) (*IDScores[T], error) { srcScores := maxMerge[T](cols) decayScores := map[T]float32{} for _, col := range cols { @@ -186,7 +186,10 @@ func (decay *DecayFunction[T, R]) processOneSearchData(ctx context.Context, sear for id := range decayScores { decayScores[id] = decayScores[id] * srcScores[id] } - return newIDScores(decayScores, searchParams) + if searchParams.isGrouping() { + return newGroupingIDScores(decayScores, searchParams, idGroup) + } + return newIDScores(decayScores, searchParams), nil } func (decay *DecayFunction[T, R]) Process(ctx context.Context, searchParams *SearchParams, inputs *rerankInputs) (*rerankOutputs, error) { @@ -198,7 +201,10 @@ func (decay *DecayFunction[T, R]) Process(ctx context.Context, searchParams *Sea col.scores[j] = toGreaterScore(score, metricType) } } - idScore := decay.processOneSearchData(ctx, searchParams, cols) + idScore, err := decay.processOneSearchData(ctx, searchParams, cols, inputs.idGroupValue) + if err != nil { + return nil, err + } appendResult(outputs, idScore.ids, idScore.scores) } return outputs, nil diff --git a/internal/util/function/rerank/decay_function_test.go b/internal/util/function/rerank/decay_function_test.go index deae1fdd02..3641032aff 100644 --- a/internal/util/function/rerank/decay_function_test.go +++ b/internal/util/function/rerank/decay_function_test.go @@ -27,7 +27,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" - "github.com/milvus-io/milvus/pkg/v2/util/testutils" + "github.com/milvus-io/milvus/internal/util/function" ) func TestDecayFunction(t *testing.T) { @@ -260,8 +260,8 @@ func (s *DecayFunctionSuite) TestRerankProcess() { nq := int64(1) f, err := newDecayFunction(schema, functionSchema) s.NoError(err) - inputs, _ := newRerankInputs([]*schemapb.SearchResultData{}, f.GetInputFieldIDs()) - ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false, []string{"COSINE"}}, inputs) + inputs, _ := newRerankInputs([]*schemapb.SearchResultData{}, f.GetInputFieldIDs(), false) + ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE"}), inputs) s.NoError(err) s.Equal(int64(3), ret.searchResultData.TopK) s.Equal([]int64{}, ret.searchResultData.Topks) @@ -271,10 +271,10 @@ func (s *DecayFunctionSuite) TestRerankProcess() { { nq := int64(1) f, err := newDecayFunction(schema, functionSchema) - data := genSearchResultData(nq, 10, schemapb.DataType_Int64, "noExist", 1000) + data := function.GenSearchResultData(nq, 10, schemapb.DataType_Int64, "noExist", 1000) s.NoError(err) - _, err = newRerankInputs([]*schemapb.SearchResultData{data}, f.GetInputFieldIDs()) + _, err = newRerankInputs([]*schemapb.SearchResultData{data}, f.GetInputFieldIDs(), false) s.ErrorContains(err, "Search reaults mismatch rerank inputs") } @@ -289,9 +289,9 @@ func (s *DecayFunctionSuite) TestRerankProcess() { nq := int64(1) f, err := newDecayFunction(schema, functionSchema) s.NoError(err) - data := genSearchResultData(nq, 10, schemapb.DataType_Int64, "ts", 102) - inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data}, f.GetInputFieldIDs()) - ret, err := 
f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false, []string{"COSINE"}}, inputs) + data := function.GenSearchResultData(nq, 10, schemapb.DataType_Int64, "ts", 102) + inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data}, f.GetInputFieldIDs(), false) + ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE"}), inputs) s.NoError(err) s.Equal([]int64{3}, ret.searchResultData.Topks) s.Equal(int64(3), ret.searchResultData.TopK) @@ -302,9 +302,9 @@ func (s *DecayFunctionSuite) TestRerankProcess() { nq := int64(3) f, err := newDecayFunction(schema, functionSchema) s.NoError(err) - data := genSearchResultData(nq, 10, schemapb.DataType_Int64, "ts", 102) - inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data}, f.GetInputFieldIDs()) - ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false, []string{"COSINE"}}, inputs) + data := function.GenSearchResultData(nq, 10, schemapb.DataType_Int64, "ts", 102) + inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data}, f.GetInputFieldIDs(), false) + ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE"}), inputs) s.NoError(err) s.Equal([]int64{3, 3, 3}, ret.searchResultData.Topks) s.Equal(int64(3), ret.searchResultData.TopK) @@ -329,11 +329,11 @@ func (s *DecayFunctionSuite) TestRerankProcess() { f, err := newDecayFunction(schema, functionSchema2) s.NoError(err) // ts/id data: 0 - 9 - data1 := genSearchResultData(nq, 10, schemapb.DataType_Int64, "ts", 102) + data1 := function.GenSearchResultData(nq, 10, schemapb.DataType_Int64, "ts", 102) // empty - data2 := genSearchResultData(nq, 0, schemapb.DataType_Int64, "ts", 102) - inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data1, data2}, f.GetInputFieldIDs()) - ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false, []string{"COSINE", "COSINE"}}, inputs) + data2 := function.GenSearchResultData(nq, 0, schemapb.DataType_Int64, "ts", 102) + inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data1, data2}, f.GetInputFieldIDs(), false) + ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE", "COSINE"}), inputs) s.NoError(err) s.Equal([]int64{3}, ret.searchResultData.Topks) s.Equal(int64(3), ret.searchResultData.TopK) @@ -345,11 +345,12 @@ func (s *DecayFunctionSuite) TestRerankProcess() { f, err := newDecayFunction(schema, functionSchema2) s.NoError(err) // ts/id data: 0 - 9 - data1 := genSearchResultData(nq, 10, schemapb.DataType_Int64, "ts", 102) + data1 := function.GenSearchResultData(nq, 10, schemapb.DataType_Int64, "ts", 102) // ts/id data: 0 - 3 - data2 := genSearchResultData(nq, 4, schemapb.DataType_Int64, "ts", 102) - inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data1, data2}, f.GetInputFieldIDs()) - ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false, []string{"COSINE", "COSINE"}}, inputs) + data2 := function.GenSearchResultData(nq, 4, schemapb.DataType_Int64, "ts", 102) + inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data1, data2}, f.GetInputFieldIDs(), false) + + ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE", "COSINE"}), inputs) s.NoError(err) s.Equal([]int64{3}, ret.searchResultData.Topks) s.Equal(int64(3), ret.searchResultData.TopK) @@ -363,13 +364,13 @@ func (s *DecayFunctionSuite) 
TestRerankProcess() { // nq1 ts/id data: 0 - 9 // nq2 ts/id data: 10 - 19 // nq3 ts/id data: 20 - 29 - data1 := genSearchResultData(nq, 10, schemapb.DataType_Int64, "ts", 102) + data1 := function.GenSearchResultData(nq, 10, schemapb.DataType_Int64, "ts", 102) // nq1 ts/id data: 0 - 3 // nq2 ts/id data: 4 - 7 // nq3 ts/id data: 8 - 11 - data2 := genSearchResultData(nq, 4, schemapb.DataType_Int64, "ts", 102) - inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data1, data2}, f.GetInputFieldIDs()) - ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, 1, -1, 1, false, []string{"COSINE", "COSINE"}}, inputs) + data2 := function.GenSearchResultData(nq, 4, schemapb.DataType_Int64, "ts", 102) + inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data1, data2}, f.GetInputFieldIDs(), false) + ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, 1, -1, 1, false, "", []string{"COSINE", "COSINE"}), inputs) s.NoError(err) s.Equal([]int64{3, 3, 3}, ret.searchResultData.Topks) s.Equal(int64(3), ret.searchResultData.TopK) @@ -390,26 +391,3 @@ func (s *DecayFunctionSuite) TestDecay() { s.Equal(linearDecay(0, 1, 0.5, 5, 5), 1.0) s.Less(linearDecay(0, 1, 0.5, 5, 6), 1.0) } - -func genSearchResultData(nq int64, topk int64, dType schemapb.DataType, fieldName string, fieldId int64) *schemapb.SearchResultData { - tops := make([]int64, nq) - for i := 0; i < int(nq); i++ { - tops[i] = topk - } - data := &schemapb.SearchResultData{ - NumQueries: nq, - TopK: topk, - Scores: testutils.GenerateFloat32Array(int(nq * topk)), - Ids: &schemapb.IDs{ - IdField: &schemapb.IDs_IntId{ - IntId: &schemapb.LongArray{ - Data: testutils.GenerateInt64Array(int(nq * topk)), - }, - }, - }, - Topks: tops, - FieldsData: []*schemapb.FieldData{testutils.GenerateScalarFieldData(dType, fieldName, int(nq*topk))}, - } - data.FieldsData[0].FieldId = fieldId - return data -} diff --git a/internal/util/function/rerank/function_score.go b/internal/util/function/rerank/function_score.go index ea15fdd8c3..de3065a89a 100644 --- a/internal/util/function/rerank/function_score.go +++ b/internal/util/function/rerank/function_score.go @@ -20,38 +20,79 @@ package rerank import ( "context" + "encoding/json" "fmt" + "reflect" + "strconv" "strings" "github.com/samber/lo" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/pkg/v2/util/funcutil" "github.com/milvus-io/milvus/pkg/v2/util/merr" ) const ( decayFunctionName string = "decay" modelFunctionName string = "model" + rrfName string = "rrf" + weightedName string = "weighted" ) +const ( + maxScorer string = "max" + sumScorer string = "sum" + avgScorer string = "avg" +) + +// legacy rrf/weighted rerank configs + +const ( + legacyRankTypeKey = "strategy" + legacyRankParamsKey = "params" +) + +type rankType int + +const ( + invalidRankType rankType = iota // invalidRankType = 0 + rrfRankType // rrfRankType = 1 + weightedRankType // weightedRankType = 2 +) + +var rankTypeMap = map[string]rankType{ + "invalid": invalidRankType, + "rrf": rrfRankType, + "weighted": weightedRankType, +} + type SearchParams struct { nq int64 limit int64 offset int64 roundDecimal int64 - // TODO: supports group search groupByFieldId int64 groupSize int64 strictGroupSize bool + groupScore string searchMetrics []string } -func NewSearchParams(nq, limit, offset, roundDecimal, groupByFieldId, groupSize int64, strictGroupSize bool, searchMetrics 
[]string) *SearchParams { +func (s *SearchParams) isGrouping() bool { + return s.groupByFieldId > 0 +} + +func NewSearchParams(nq, limit, offset, roundDecimal, groupByFieldId, groupSize int64, strictGroupSize bool, groupScore string, searchMetrics []string) *SearchParams { + if groupScore == "" { + groupScore = maxScorer + } return &SearchParams{ - nq, limit, offset, roundDecimal, groupByFieldId, groupSize, strictGroupSize, searchMetrics, + nq, limit, offset, roundDecimal, groupByFieldId, groupSize, strictGroupSize, groupScore, searchMetrics, } } @@ -95,6 +136,10 @@ func createFunction(collSchema *schemapb.CollectionSchema, funcSchema *schemapb. rerankFunc, newRerankErr = newDecayFunction(collSchema, funcSchema) case modelFunctionName: rerankFunc, newRerankErr = newModelFunction(collSchema, funcSchema) + case rrfName: + rerankFunc, newRerankErr = newRRFFunction(collSchema, funcSchema) + case weightedName: + rerankFunc, newRerankErr = newWeightedFunction(collSchema, funcSchema) default: return nil, fmt.Errorf("Unsupported rerank function: [%s] , list of supported [%s,%s]", rerankerName, decayFunctionName, modelFunctionName) } @@ -117,6 +162,68 @@ func NewFunctionScore(collSchema *schemapb.CollectionSchema, funcScoreSchema *sc return funcScore, nil } +func NewFunctionScoreWithlegacy(collSchema *schemapb.CollectionSchema, rankParams []*commonpb.KeyValuePair) (*FunctionScore, error) { + var params map[string]interface{} + rankTypeStr, err := funcutil.GetAttrByKeyFromRepeatedKV(legacyRankTypeKey, rankParams) + if err != nil { + rankTypeStr = "rrf" + params = make(map[string]interface{}, 0) + } else { + if _, ok := rankTypeMap[rankTypeStr]; !ok { + return nil, fmt.Errorf("unsupported rank type %s", rankTypeStr) + } + paramStr, err := funcutil.GetAttrByKeyFromRepeatedKV(legacyRankParamsKey, rankParams) + if err != nil { + return nil, fmt.Errorf("params" + " not found in rank_params") + } + err = json.Unmarshal([]byte(paramStr), ¶ms) + if err != nil { + return nil, fmt.Errorf("Parse rerank params failed, err: %s", err) + } + } + fSchema := schemapb.FunctionSchema{ + Type: schemapb.FunctionType_Rerank, + InputFieldNames: []string{}, + OutputFieldNames: []string{}, + Params: []*commonpb.KeyValuePair{}, + } + switch rankTypeMap[rankTypeStr] { + case rrfRankType: + fSchema.Params = append(fSchema.Params, &commonpb.KeyValuePair{Key: reranker, Value: rrfName}) + if v, ok := params[RRFParamsKey]; ok { + if reflect.ValueOf(params[RRFParamsKey]).CanFloat() { + k := reflect.ValueOf(v).Float() + fSchema.Params = append(fSchema.Params, &commonpb.KeyValuePair{Key: RRFParamsKey, Value: strconv.FormatFloat(k, 'f', -1, 64)}) + } else { + return nil, fmt.Errorf("The type of rank param k should be float") + } + } + case weightedRankType: + fSchema.Params = append(fSchema.Params, &commonpb.KeyValuePair{Key: reranker, Value: weightedName}) + if v, ok := params[WeightsParamsKey]; ok { + if d, err := json.Marshal(v); err != nil { + return nil, fmt.Errorf("The weights param should be an array") + } else { + fSchema.Params = append(fSchema.Params, &commonpb.KeyValuePair{Key: WeightsParamsKey, Value: string(d)}) + } + } + if normScore, ok := params[NormScoreKey]; ok { + if ns, ok := normScore.(bool); ok { + fSchema.Params = append(fSchema.Params, &commonpb.KeyValuePair{Key: NormScoreKey, Value: strconv.FormatBool(ns)}) + } else { + return nil, fmt.Errorf("Weighted rerank err, norm_score should been bool type, but [norm_score:%s]'s type is %T", normScore, normScore) + } + } + default: + return nil, fmt.Errorf("unsupported 
rank type %s", rankTypeStr) + } + funcScore := &FunctionScore{} + if funcScore.reranker, err = createFunction(collSchema, &fSchema); err != nil { + return nil, err + } + return funcScore, nil +} + func (fScore *FunctionScore) Process(ctx context.Context, searchParams *SearchParams, multipleMilvusResults []*milvuspb.SearchResults) (*milvuspb.SearchResults, error) { if len(multipleMilvusResults) == 0 { return &milvuspb.SearchResults{ @@ -137,7 +244,7 @@ func (fScore *FunctionScore) Process(ctx context.Context, searchParams *SearchPa }) // rankResult only has scores - inputs, err := newRerankInputs(allSearchResultData, fScore.reranker.GetInputFieldIDs()) + inputs, err := newRerankInputs(allSearchResultData, fScore.reranker.GetInputFieldIDs(), searchParams.isGrouping()) if err != nil { return nil, err } diff --git a/internal/util/function/rerank/function_score_test.go b/internal/util/function/rerank/function_score_test.go index 7243dcddf5..09cb7b2d01 100644 --- a/internal/util/function/rerank/function_score_test.go +++ b/internal/util/function/rerank/function_score_test.go @@ -20,6 +20,7 @@ package rerank import ( "context" + "math" "testing" "github.com/stretchr/testify/suite" @@ -27,6 +28,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/util/function" ) func TestFunctionScore(t *testing.T) { @@ -75,7 +77,7 @@ func (s *FunctionScoreSuite) TestNewFunctionScore() { s.NoError(err) s.Equal([]string{"ts"}, f.GetAllInputFieldNames()) s.Equal([]int64{102}, f.GetAllInputFieldIDs()) - s.Equal(false, f.IsSupportGroup()) + s.Equal(true, f.IsSupportGroup()) s.Equal("decay", f.reranker.GetRankName()) { @@ -152,7 +154,7 @@ func (s *FunctionScoreSuite) TestFunctionScoreProcess() { // empty inputs { nq := int64(1) - ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false, []string{"COSINE"}}, []*milvuspb.SearchResults{}) + ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE"}), []*milvuspb.SearchResults{}) s.NoError(err) s.Equal(int64(3), ret.Results.TopK) s.Equal(0, len(ret.Results.FieldsData)) @@ -162,11 +164,11 @@ func (s *FunctionScoreSuite) TestFunctionScoreProcess() { // nq = 1 { nq := int64(1) - data := genSearchResultData(nq, 10, schemapb.DataType_Int64, "ts", 102) + data := function.GenSearchResultData(nq, 10, schemapb.DataType_Int64, "ts", 102) searchData := &milvuspb.SearchResults{ Results: data, } - ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false, []string{"COSINE"}}, []*milvuspb.SearchResults{searchData}) + ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE"}), []*milvuspb.SearchResults{searchData}) s.NoError(err) s.Equal(int64(3), ret.Results.TopK) s.Equal([]int64{3}, ret.Results.Topks) @@ -174,11 +176,11 @@ func (s *FunctionScoreSuite) TestFunctionScoreProcess() { // nq=1, input is empty { nq := int64(1) - data := genSearchResultData(nq, 0, schemapb.DataType_Int64, "ts", 102) + data := function.GenSearchResultData(nq, 0, schemapb.DataType_Int64, "ts", 102) searchData := &milvuspb.SearchResults{ Results: data, } - ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false, []string{"COSINE"}}, []*milvuspb.SearchResults{searchData}) + ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", 
[]string{"COSINE"}), []*milvuspb.SearchResults{searchData}) s.NoError(err) s.Equal(int64(3), ret.Results.TopK) s.Equal([]int64{0}, ret.Results.Topks) @@ -186,11 +188,11 @@ func (s *FunctionScoreSuite) TestFunctionScoreProcess() { // nq=3 { nq := int64(3) - data := genSearchResultData(nq, 10, schemapb.DataType_Int64, "ts", 102) + data := function.GenSearchResultData(nq, 10, schemapb.DataType_Int64, "ts", 102) searchData := &milvuspb.SearchResults{ Results: data, } - ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false, []string{"COSINE"}}, []*milvuspb.SearchResults{searchData}) + ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE"}), []*milvuspb.SearchResults{searchData}) s.NoError(err) s.Equal(int64(3), ret.Results.TopK) s.Equal([]int64{3, 3, 3}, ret.Results.Topks) @@ -198,11 +200,11 @@ func (s *FunctionScoreSuite) TestFunctionScoreProcess() { // nq=3, all input is empty { nq := int64(3) - data := genSearchResultData(nq, 0, schemapb.DataType_Int64, "ts", 102) + data := function.GenSearchResultData(nq, 0, schemapb.DataType_Int64, "ts", 102) searchData := &milvuspb.SearchResults{ Results: data, } - ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false, []string{"COSINE"}}, []*milvuspb.SearchResults{searchData}) + ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE"}), []*milvuspb.SearchResults{searchData}) s.NoError(err) s.Equal(int64(3), ret.Results.TopK) s.Equal([]int64{0, 0, 0}, ret.Results.Topks) @@ -213,13 +215,13 @@ func (s *FunctionScoreSuite) TestFunctionScoreProcess() { { nq := int64(1) searchData1 := &milvuspb.SearchResults{ - Results: genSearchResultData(nq, 10, schemapb.DataType_Int64, "ts", 102), + Results: function.GenSearchResultData(nq, 10, schemapb.DataType_Int64, "ts", 102), } searchData2 := &milvuspb.SearchResults{ - Results: genSearchResultData(nq, 20, schemapb.DataType_Int64, "ts", 102), + Results: function.GenSearchResultData(nq, 20, schemapb.DataType_Int64, "ts", 102), } - ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false, []string{"COSINE", "COSINE"}}, []*milvuspb.SearchResults{searchData1, searchData2}) + ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE", "COSINE"}), []*milvuspb.SearchResults{searchData1, searchData2}) s.NoError(err) s.Equal(int64(3), ret.Results.TopK) s.Equal([]int64{3}, ret.Results.Topks) @@ -228,13 +230,13 @@ func (s *FunctionScoreSuite) TestFunctionScoreProcess() { { nq := int64(1) searchData1 := &milvuspb.SearchResults{ - Results: genSearchResultData(nq, 0, schemapb.DataType_Int64, "ts", 102), + Results: function.GenSearchResultData(nq, 0, schemapb.DataType_Int64, "ts", 102), } searchData2 := &milvuspb.SearchResults{ - Results: genSearchResultData(nq, 0, schemapb.DataType_Int64, "ts", 102), + Results: function.GenSearchResultData(nq, 0, schemapb.DataType_Int64, "ts", 102), } - ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false, []string{"COSINE", "COSINE"}}, []*milvuspb.SearchResults{searchData1, searchData2}) + ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE", "COSINE"}), []*milvuspb.SearchResults{searchData1, searchData2}) s.NoError(err) s.Equal(int64(3), ret.Results.TopK) s.Equal([]int64{0}, ret.Results.Topks) @@ -243,13 +245,13 @@ func (s *FunctionScoreSuite) 
TestFunctionScoreProcess() { { nq := int64(1) searchData1 := &milvuspb.SearchResults{ - Results: genSearchResultData(nq, 0, schemapb.DataType_Int64, "ts", 102), + Results: function.GenSearchResultData(nq, 0, schemapb.DataType_Int64, "ts", 102), } searchData2 := &milvuspb.SearchResults{ - Results: genSearchResultData(nq, 10, schemapb.DataType_Int64, "ts", 102), + Results: function.GenSearchResultData(nq, 10, schemapb.DataType_Int64, "ts", 102), } - ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false, []string{"COSINE", "COSINE"}}, []*milvuspb.SearchResults{searchData1, searchData2}) + ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE", "COSINE"}), []*milvuspb.SearchResults{searchData1, searchData2}) s.NoError(err) s.Equal(int64(3), ret.Results.TopK) s.Equal([]int64{3}, ret.Results.Topks) @@ -258,13 +260,13 @@ func (s *FunctionScoreSuite) TestFunctionScoreProcess() { { nq := int64(3) searchData1 := &milvuspb.SearchResults{ - Results: genSearchResultData(nq, 10, schemapb.DataType_Int64, "ts", 102), + Results: function.GenSearchResultData(nq, 10, schemapb.DataType_Int64, "ts", 102), } searchData2 := &milvuspb.SearchResults{ - Results: genSearchResultData(nq, 20, schemapb.DataType_Int64, "ts", 102), + Results: function.GenSearchResultData(nq, 20, schemapb.DataType_Int64, "ts", 102), } - ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false, []string{"COSINE", "COSINE"}}, []*milvuspb.SearchResults{searchData1, searchData2}) + ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE", "COSINE"}), []*milvuspb.SearchResults{searchData1, searchData2}) s.NoError(err) s.Equal(int64(3), ret.Results.TopK) s.Equal([]int64{3, 3, 3}, ret.Results.Topks) @@ -273,13 +275,13 @@ func (s *FunctionScoreSuite) TestFunctionScoreProcess() { { nq := int64(3) searchData1 := &milvuspb.SearchResults{ - Results: genSearchResultData(nq, 0, schemapb.DataType_Int64, "ts", 102), + Results: function.GenSearchResultData(nq, 0, schemapb.DataType_Int64, "ts", 102), } searchData2 := &milvuspb.SearchResults{ - Results: genSearchResultData(nq, 0, schemapb.DataType_Int64, "ts", 102), + Results: function.GenSearchResultData(nq, 0, schemapb.DataType_Int64, "ts", 102), } - ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false, []string{"COSINE", "COSINE"}}, []*milvuspb.SearchResults{searchData1, searchData2}) + ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE", "COSINE"}), []*milvuspb.SearchResults{searchData1, searchData2}) s.NoError(err) s.Equal(int64(3), ret.Results.TopK) s.Equal([]int64{0, 0, 0}, ret.Results.Topks) @@ -288,15 +290,131 @@ func (s *FunctionScoreSuite) TestFunctionScoreProcess() { { nq := int64(3) searchData1 := &milvuspb.SearchResults{ - Results: genSearchResultData(nq, 0, schemapb.DataType_Int64, "ts", 102), + Results: function.GenSearchResultData(nq, 0, schemapb.DataType_Int64, "ts", 102), } searchData2 := &milvuspb.SearchResults{ - Results: genSearchResultData(nq, 10, schemapb.DataType_Int64, "ts", 102), + Results: function.GenSearchResultData(nq, 10, schemapb.DataType_Int64, "ts", 102), } - ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false, []string{"COSINE", "COSINE"}}, []*milvuspb.SearchResults{searchData1, searchData2}) + ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", 
[]string{"COSINE", "COSINE"}), []*milvuspb.SearchResults{searchData1, searchData2}) s.NoError(err) s.Equal(int64(3), ret.Results.TopK) s.Equal([]int64{3, 3, 3}, ret.Results.Topks) } } + +func (s *FunctionScoreSuite) TestlegacyFunction() { + schema := &schemapb.CollectionSchema{ + Name: "test", + Fields: []*schemapb.FieldSchema{ + {FieldID: 100, Name: "pk", DataType: schemapb.DataType_Int64, IsPrimaryKey: true}, + {FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar}, + { + FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + {Key: "dim", Value: "4"}, + }, + }, + {FieldID: 102, Name: "ts", DataType: schemapb.DataType_Int64}, + }, + } + { + rankParams := []*commonpb.KeyValuePair{} + f, err := NewFunctionScoreWithlegacy(schema, rankParams) + s.NoError(err) + s.Equal(f.reranker.GetRankName(), rrfName) + } + { + rankParams := []*commonpb.KeyValuePair{ + {Key: legacyRankTypeKey, Value: "invalid"}, + {Key: legacyRankParamsKey, Value: `{"k": "v"}`}, + } + _, err := NewFunctionScoreWithlegacy(schema, rankParams) + s.ErrorContains(err, "unsupported rank type") + } + { + rankParams := []*commonpb.KeyValuePair{ + {Key: legacyRankTypeKey, Value: "rrf"}, + {Key: legacyRankParamsKey, Value: "invalid"}, + } + _, err := NewFunctionScoreWithlegacy(schema, rankParams) + s.ErrorContains(err, "Parse rerank params failed") + } + { + rankParams := []*commonpb.KeyValuePair{ + {Key: legacyRankTypeKey, Value: "rrf"}, + {Key: legacyRankParamsKey, Value: `{"k": "invalid"}`}, + } + _, err := NewFunctionScoreWithlegacy(schema, rankParams) + s.ErrorContains(err, "The type of rank param k should be float") + } + { + rankParams := []*commonpb.KeyValuePair{ + {Key: legacyRankTypeKey, Value: "rrf"}, + {Key: legacyRankParamsKey, Value: `{"k": 1.0}`}, + } + _, err := NewFunctionScoreWithlegacy(schema, rankParams) + s.NoError(err) + } + { + rankParams := []*commonpb.KeyValuePair{ + {Key: legacyRankTypeKey, Value: "weighted"}, + {Key: legacyRankParamsKey, Value: `{"weights": [1.0]}`}, + } + f, err := NewFunctionScoreWithlegacy(schema, rankParams) + s.NoError(err) + s.Equal(f.reranker.GetRankName(), weightedName) + } + { + rankParams := []*commonpb.KeyValuePair{ + {Key: legacyRankTypeKey, Value: "weighted"}, + {Key: legacyRankParamsKey, Value: `{"weights": [1.0], "norm_score": "Invalid"}`}, + } + _, err := NewFunctionScoreWithlegacy(schema, rankParams) + s.ErrorContains(err, "Weighted rerank err, norm_score should been bool type") + } + { + rankParams := []*commonpb.KeyValuePair{ + {Key: legacyRankTypeKey, Value: "weighted"}, + {Key: legacyRankParamsKey, Value: `{"weights": [1.0], "norm_score": false}`}, + } + _, err := NewFunctionScoreWithlegacy(schema, rankParams) + s.NoError(err) + } + { + rankParams := []*commonpb.KeyValuePair{ + {Key: legacyRankTypeKey, Value: "weighted"}, + {Key: legacyRankParamsKey, Value: `{"weights": [1.0], "norm_score": "false"}`}, + } + _, err := NewFunctionScoreWithlegacy(schema, rankParams) + s.ErrorContains(err, "Weighted rerank err, norm_score should been bool type") + } +} + +func (s *FunctionScoreSuite) TestFunctionUtil() { + g1 := &Group[int64]{ + idList: []int64{1, 2, 3}, + scoreList: []float32{1.0, 2.0, 3.0}, + groupVal: 3, + maxScore: 3.0, + sumScore: 6.0, + } + s1, err := groupScore(g1, maxScorer) + s.NoError(err) + s.True(math.Abs(float64(s1-3.0)) < 0.001) + + s2, err := groupScore(g1, sumScorer) + s.NoError(err) + s.True(math.Abs(float64(s2-6.0)) < 0.001) + + s3, err := groupScore(g1, avgScorer) + s.NoError(err) + 
s.True(math.Abs(float64(s3-2.0)) < 0.001) + + _, err = groupScore(g1, "NotSupported") + s.ErrorContains(err, "is not supported") + + g1.idList = []int64{} + _, err = groupScore(g1, avgScorer) + s.ErrorContains(err, "input group for score must have at least one id, must be sth wrong within code") +} diff --git a/internal/util/function/rerank/model_function.go b/internal/util/function/rerank/model_function.go index 79e0205196..00d26eb9a4 100644 --- a/internal/util/function/rerank/model_function.go +++ b/internal/util/function/rerank/model_function.go @@ -296,7 +296,7 @@ type ModelFunction[T PKType] struct { } func newModelFunction(collSchema *schemapb.CollectionSchema, funcSchema *schemapb.FunctionSchema) (Reranker, error) { - base, err := newRerankBase(collSchema, funcSchema, decayFunctionName, false) + base, err := newRerankBase(collSchema, funcSchema, decayFunctionName, true) if err != nil { return nil, err } @@ -333,7 +333,7 @@ func newModelFunction(collSchema *schemapb.CollectionSchema, funcSchema *schemap } } -func (model *ModelFunction[T]) processOneSearchData(ctx context.Context, searchParams *SearchParams, query string, cols []*columns) (*IDScores[T], error) { +func (model *ModelFunction[T]) processOneSearchData(ctx context.Context, searchParams *SearchParams, query string, cols []*columns, idGroup map[any]any) (*IDScores[T], error) { uniqueData := make(map[T]string) for _, col := range cols { texts := col.data[0].([]string) @@ -359,6 +359,9 @@ func (model *ModelFunction[T]) processOneSearchData(ctx context.Context, searchP for idx, id := range ids { rerankScores[id] = scores[idx] } + if searchParams.isGrouping() { + return newGroupingIDScores(rerankScores, searchParams, idGroup) + } return newIDScores(rerankScores, searchParams), nil } @@ -368,7 +371,7 @@ func (model *ModelFunction[T]) Process(ctx context.Context, searchParams *Search } outputs := newRerankOutputs(searchParams) for idx, cols := range inputs.data { - idScore, err := model.processOneSearchData(ctx, searchParams, model.queries[idx], cols) + idScore, err := model.processOneSearchData(ctx, searchParams, model.queries[idx], cols, inputs.idGroupValue) if err != nil { return nil, err } diff --git a/internal/util/function/rerank/model_function_test.go b/internal/util/function/rerank/model_function_test.go index 4e3f709f8e..762d952fa9 100644 --- a/internal/util/function/rerank/model_function_test.go +++ b/internal/util/function/rerank/model_function_test.go @@ -375,8 +375,8 @@ func (s *RerankModelSuite) TestRerankProcess() { nq := int64(1) f, err := newModelFunction(schema, functionSchema) s.NoError(err) - inputs, _ := newRerankInputs([]*schemapb.SearchResultData{}, f.GetInputFieldIDs()) - ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false, []string{"COSINE"}}, inputs) + inputs, _ := newRerankInputs([]*schemapb.SearchResultData{}, f.GetInputFieldIDs(), false) + ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE"}), inputs) s.NoError(err) s.Equal(int64(3), ret.searchResultData.TopK) s.Equal([]int64{}, ret.searchResultData.Topks) @@ -386,10 +386,10 @@ func (s *RerankModelSuite) TestRerankProcess() { { nq := int64(1) f, err := newModelFunction(schema, functionSchema) - data := genSearchResultData(nq, 10, schemapb.DataType_Int64, "noExist", 1000) + data := function.GenSearchResultData(nq, 10, schemapb.DataType_Int64, "noExist", 1000) s.NoError(err) - _, err = newRerankInputs([]*schemapb.SearchResultData{data}, f.GetInputFieldIDs()) + _, 
err = newRerankInputs([]*schemapb.SearchResultData{data}, f.GetInputFieldIDs(), false) s.ErrorContains(err, "Search reaults mismatch rerank inputs") } } @@ -430,18 +430,18 @@ func (s *RerankModelSuite) TestRerankProcess() { { f, err := newModelFunction(schema, functionSchema) s.NoError(err) - data := genSearchResultData(nq, 10, schemapb.DataType_VarChar, "text", 101) - inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data}, f.GetInputFieldIDs()) - _, err = f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false, []string{"COSINE"}}, inputs) + data := function.GenSearchResultData(nq, 10, schemapb.DataType_VarChar, "text", 101) + inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data}, f.GetInputFieldIDs(), false) + _, err = f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE"}), inputs) s.ErrorContains(err, "nq must equal to queries size, but got nq [1], queries size [2]") } { functionSchema.Params[2].Value = `["q1"]` f, err := newModelFunction(schema, functionSchema) s.NoError(err) - data := genSearchResultData(nq, 10, schemapb.DataType_VarChar, "text", 101) - inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data}, f.GetInputFieldIDs()) - ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 0, -1, -1, 1, false, []string{"COSINE"}}, inputs) + data := function.GenSearchResultData(nq, 10, schemapb.DataType_VarChar, "text", 101) + inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data}, f.GetInputFieldIDs(), false) + ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 0, -1, -1, 1, false, "", []string{"COSINE"}), inputs) s.NoError(err) s.Equal([]int64{3}, ret.searchResultData.Topks) s.Equal(int64(3), ret.searchResultData.TopK) @@ -452,9 +452,9 @@ func (s *RerankModelSuite) TestRerankProcess() { functionSchema.Params[2].Value = `["q1", "q2", "q3"]` f, err := newModelFunction(schema, functionSchema) s.NoError(err) - data := genSearchResultData(nq, 10, schemapb.DataType_VarChar, "text", 101) - inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data}, f.GetInputFieldIDs()) - ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false, []string{"COSINE", "COSINE", "COSINE"}}, inputs) + data := function.GenSearchResultData(nq, 10, schemapb.DataType_VarChar, "text", 101) + inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data}, f.GetInputFieldIDs(), false) + ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE", "COSINE", "COSINE"}), inputs) s.NoError(err) s.Equal([]int64{3, 3, 3}, ret.searchResultData.Topks) s.Equal(int64(3), ret.searchResultData.TopK) @@ -468,11 +468,11 @@ func (s *RerankModelSuite) TestRerankProcess() { functionSchema.Params[2].Value = `["q1"]` f, err := newModelFunction(schema, functionSchema) s.NoError(err) - data1 := genSearchResultData(nq, 10, schemapb.DataType_VarChar, "text", 101) + data1 := function.GenSearchResultData(nq, 10, schemapb.DataType_VarChar, "text", 101) // empty - data2 := genSearchResultData(nq, 0, schemapb.DataType_VarChar, "text", 101) - inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data1, data2}, f.GetInputFieldIDs()) - ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false, []string{"COSINE"}}, inputs) + data2 := function.GenSearchResultData(nq, 0, schemapb.DataType_VarChar, "text", 101) + inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data1, data2}, f.GetInputFieldIDs(), false) + ret, 
err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE"}), inputs) s.NoError(err) s.Equal([]int64{3}, ret.searchResultData.Topks) s.Equal(int64(3), ret.searchResultData.TopK) @@ -483,11 +483,11 @@ func (s *RerankModelSuite) TestRerankProcess() { f, err := newModelFunction(schema, functionSchema) s.NoError(err) // ts/id data: 0 - 9 - data1 := genSearchResultData(nq, 10, schemapb.DataType_VarChar, "text", 101) + data1 := function.GenSearchResultData(nq, 10, schemapb.DataType_VarChar, "text", 101) // ts/id data: 0 - 3 - data2 := genSearchResultData(nq, 4, schemapb.DataType_VarChar, "text", 101) - inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data1, data2}, f.GetInputFieldIDs()) - ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false, []string{"COSINE", "COSINE"}}, inputs) + data2 := function.GenSearchResultData(nq, 4, schemapb.DataType_VarChar, "text", 101) + inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data1, data2}, f.GetInputFieldIDs(), false) + ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE", "COSINE"}), inputs) s.NoError(err) s.Equal([]int64{3}, ret.searchResultData.Topks) s.Equal(int64(3), ret.searchResultData.TopK) @@ -498,10 +498,10 @@ func (s *RerankModelSuite) TestRerankProcess() { functionSchema.Params[2].Value = `["q1", "q2", "q3"]` f, err := newModelFunction(schema, functionSchema) s.NoError(err) - data1 := genSearchResultData(nq, 10, schemapb.DataType_VarChar, "text", 101) - data2 := genSearchResultData(nq, 4, schemapb.DataType_VarChar, "text", 101) - inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data1, data2}, f.GetInputFieldIDs()) - ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, 1, -1, 1, false, []string{"COSINE", "COSINE", "COSINE"}}, inputs) + data1 := function.GenSearchResultData(nq, 10, schemapb.DataType_VarChar, "text", 101) + data2 := function.GenSearchResultData(nq, 4, schemapb.DataType_VarChar, "text", 101) + inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data1, data2}, f.GetInputFieldIDs(), false) + ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, 1, -1, 1, false, "", []string{"COSINE", "COSINE", "COSINE"}), inputs) s.NoError(err) s.Equal([]int64{3, 3, 3}, ret.searchResultData.Topks) s.Equal(int64(3), ret.searchResultData.TopK) diff --git a/internal/util/function/rerank/rrf_function.go b/internal/util/function/rerank/rrf_function.go new file mode 100644 index 0000000000..9df02832be --- /dev/null +++ b/internal/util/function/rerank/rrf_function.go @@ -0,0 +1,101 @@ +/* + * # Licensed to the LF AI & Data foundation under one + * # or more contributor license agreements. See the NOTICE file + * # distributed with this work for additional information + * # regarding copyright ownership. The ASF licenses this file + * # to you under the Apache License, Version 2.0 (the + * # "License"); you may not use this file except in compliance + * # with the License. You may obtain a copy of the License at + * # + * # http://www.apache.org/licenses/LICENSE-2.0 + * # + * # Unless required by applicable law or agreed to in writing, software + * # distributed under the License is distributed on an "AS IS" BASIS, + * # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * # See the License for the specific language governing permissions and + * # limitations under the License. 
+ */ + +package rerank + +import ( + "context" + "fmt" + "strconv" + "strings" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" +) + +const ( + RRFParamsKey string = "k" + + defaultRRFParamsValue float64 = 60 +) + +type RRFFunction[T PKType] struct { + RerankBase + + k float32 +} + +func newRRFFunction(collSchema *schemapb.CollectionSchema, funcSchema *schemapb.FunctionSchema) (Reranker, error) { + base, err := newRerankBase(collSchema, funcSchema, rrfName, true) + if err != nil { + return nil, err + } + + if len(base.GetInputFieldNames()) != 0 { + return nil, fmt.Errorf("The rrf function does not support input parameters, but got %s", base.GetInputFieldNames()) + } + + k := float64(defaultRRFParamsValue) + for _, param := range funcSchema.Params { + if strings.ToLower(param.Key) == RRFParamsKey { + if k, err = strconv.ParseFloat(param.Value, 64); err != nil { + return nil, fmt.Errorf("Param k:%s is not a number", param.Value) + } + } + } + if k <= 0 || k >= 16384 { + return nil, fmt.Errorf("The rank params k should be in range (0, %d)", 16384) + } + if base.pkType == schemapb.DataType_Int64 { + return &RRFFunction[int64]{RerankBase: *base, k: float32(k)}, nil + } else { + return &RRFFunction[string]{RerankBase: *base, k: float32(k)}, nil + } +} + +func (rrf *RRFFunction[T]) processOneSearchData(ctx context.Context, searchParams *SearchParams, cols []*columns, idGroup map[any]any) (*IDScores[T], error) { + rrfScores := map[T]float32{} + for _, col := range cols { + if col.size == 0 { + continue + } + ids := col.ids.([]T) + for idx, id := range ids { + if score, ok := rrfScores[id]; !ok { + rrfScores[id] = 1 / (rrf.k + float32(idx+1)) + } else { + rrfScores[id] = score + 1/(rrf.k+float32(idx+1)) + } + } + } + if searchParams.isGrouping() { + return newGroupingIDScores(rrfScores, searchParams, idGroup) + } + return newIDScores(rrfScores, searchParams), nil +} + +func (rrf *RRFFunction[T]) Process(ctx context.Context, searchParams *SearchParams, inputs *rerankInputs) (*rerankOutputs, error) { + outputs := newRerankOutputs(searchParams) + for _, cols := range inputs.data { + idScore, err := rrf.processOneSearchData(ctx, searchParams, cols, inputs.idGroupValue) + if err != nil { + return nil, err + } + appendResult(outputs, idScore.ids, idScore.scores) + } + return outputs, nil +} diff --git a/internal/util/function/rerank/rrf_function_test.go b/internal/util/function/rerank/rrf_function_test.go new file mode 100644 index 0000000000..0d08abbd35 --- /dev/null +++ b/internal/util/function/rerank/rrf_function_test.go @@ -0,0 +1,250 @@ +/* + * # Licensed to the LF AI & Data foundation under one + * # or more contributor license agreements. See the NOTICE file + * # distributed with this work for additional information + * # regarding copyright ownership. The ASF licenses this file + * # to you under the Apache License, Version 2.0 (the + * # "License"); you may not use this file except in compliance + * # with the License. You may obtain a copy of the License at + * # + * # http://www.apache.org/licenses/LICENSE-2.0 + * # + * # Unless required by applicable law or agreed to in writing, software + * # distributed under the License is distributed on an "AS IS" BASIS, + * # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * # See the License for the specific language governing permissions and + * # limitations under the License. 
+ */ + +package rerank + +import ( + "context" + "testing" + + "github.com/stretchr/testify/suite" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/util/function" +) + +func TestRRFFunction(t *testing.T) { + suite.Run(t, new(RRFFunctionSuite)) +} + +type RRFFunctionSuite struct { + suite.Suite +} + +func (s *RRFFunctionSuite) TestNewRRFFuction() { + schema := &schemapb.CollectionSchema{ + Name: "test", + Fields: []*schemapb.FieldSchema{ + {FieldID: 100, Name: "pk", DataType: schemapb.DataType_Int64, IsPrimaryKey: true}, + {FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar}, + { + FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + {Key: "dim", Value: "4"}, + }, + }, + {FieldID: 102, Name: "ts", DataType: schemapb.DataType_Int64}, + }, + } + functionSchema := &schemapb.FunctionSchema{ + Name: "test", + Type: schemapb.FunctionType_Rerank, + InputFieldNames: []string{}, + Params: []*commonpb.KeyValuePair{ + {Key: RRFParamsKey, Value: "70"}, + }, + } + + { + _, err := newRRFFunction(schema, functionSchema) + s.NoError(err) + } + { + schema.Fields[0] = &schemapb.FieldSchema{FieldID: 100, Name: "pk", DataType: schemapb.DataType_VarChar, IsPrimaryKey: true} + _, err := newRRFFunction(schema, functionSchema) + s.NoError(err) + } + { + functionSchema.Params[0] = &commonpb.KeyValuePair{Key: RRFParamsKey, Value: "NotNum"} + _, err := newRRFFunction(schema, functionSchema) + s.ErrorContains(err, "is not a number") + functionSchema.Params[0] = &commonpb.KeyValuePair{Key: RRFParamsKey, Value: "-1"} + _, err = newRRFFunction(schema, functionSchema) + s.ErrorContains(err, "he rank params k should be in range") + functionSchema.Params[0] = &commonpb.KeyValuePair{Key: RRFParamsKey, Value: "100"} + } + { + functionSchema.InputFieldNames = []string{"ts"} + _, err := newRRFFunction(schema, functionSchema) + s.ErrorContains(err, "The rrf function does not support input parameters") + } +} + +func (s *RRFFunctionSuite) TestRRFFuctionProcess() { + schema := &schemapb.CollectionSchema{ + Name: "test", + Fields: []*schemapb.FieldSchema{ + {FieldID: 100, Name: "pk", DataType: schemapb.DataType_Int64, IsPrimaryKey: true}, + {FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar}, + { + FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + {Key: "dim", Value: "4"}, + }, + }, + {FieldID: 102, Name: "ts", DataType: schemapb.DataType_Int64}, + }, + } + functionSchema := &schemapb.FunctionSchema{ + Name: "test", + Type: schemapb.FunctionType_Rerank, + InputFieldNames: []string{}, + Params: []*commonpb.KeyValuePair{}, + } + + // empty + { + nq := int64(1) + f, err := newRRFFunction(schema, functionSchema) + s.NoError(err) + inputs, _ := newRerankInputs([]*schemapb.SearchResultData{}, f.GetInputFieldIDs(), false) + ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE"}), inputs) + s.NoError(err) + s.Equal(int64(3), ret.searchResultData.TopK) + s.Equal([]int64{}, ret.searchResultData.Topks) + } + + // singleSearchResultData + // nq = 1 + { + nq := int64(1) + f, err := newRRFFunction(schema, functionSchema) + s.NoError(err) + data := function.GenSearchResultData(nq, 10, schemapb.DataType_Int64, "", 0) + inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data}, f.GetInputFieldIDs(), false) + ret, err := 
f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE"}), inputs) + s.NoError(err) + s.Equal([]int64{3}, ret.searchResultData.Topks) + s.Equal(int64(3), ret.searchResultData.TopK) + s.Equal([]int64{2, 3, 4}, ret.searchResultData.Ids.GetIntId().Data) + } + // nq = 3 + { + nq := int64(3) + f, err := newRRFFunction(schema, functionSchema) + s.NoError(err) + data := function.GenSearchResultData(nq, 10, schemapb.DataType_Int64, "", 0) + inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data}, f.GetInputFieldIDs(), false) + ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE"}), inputs) + s.NoError(err) + s.Equal([]int64{3, 3, 3}, ret.searchResultData.Topks) + s.Equal(int64(3), ret.searchResultData.TopK) + s.Equal([]int64{2, 3, 4, 12, 13, 14, 22, 23, 24}, ret.searchResultData.Ids.GetIntId().Data) + } + + // has empty inputs + { + nq := int64(1) + f, err := newRRFFunction(schema, functionSchema) + s.NoError(err) + // id data: 0 - 9 + data1 := function.GenSearchResultData(nq, 10, schemapb.DataType_Int64, "", 0) + // empty + data2 := function.GenSearchResultData(nq, 0, schemapb.DataType_Int64, "", 0) + inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data1, data2}, f.GetInputFieldIDs(), false) + ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 0, -1, -1, 1, false, "", []string{"COSINE", "COSINE"}), inputs) + s.NoError(err) + s.Equal([]int64{3}, ret.searchResultData.Topks) + s.Equal(int64(3), ret.searchResultData.TopK) + s.Equal([]int64{0, 1, 2}, ret.searchResultData.Ids.GetIntId().Data) + } + // nq = 1 + { + nq := int64(1) + f, err := newRRFFunction(schema, functionSchema) + s.NoError(err) + // id data: 0 - 9 + data1 := function.GenSearchResultData(nq, 10, schemapb.DataType_Int64, "", 0) + // id data: 0 - 3 + data2 := function.GenSearchResultData(nq, 4, schemapb.DataType_Int64, "", 0) + inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data1, data2}, f.GetInputFieldIDs(), false) + ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE", "COSINE"}), inputs) + s.NoError(err) + s.Equal([]int64{3}, ret.searchResultData.Topks) + s.Equal(int64(3), ret.searchResultData.TopK) + s.Equal([]int64{2, 3, 4}, ret.searchResultData.Ids.GetIntId().Data) + } + // // nq = 3 + { + nq := int64(3) + f, err := newRRFFunction(schema, functionSchema) + s.NoError(err) + // nq1 id data: 0 - 9 + // nq2 id data: 10 - 19 + // nq3 id data: 20 - 29 + data1 := function.GenSearchResultData(nq, 10, schemapb.DataType_Int64, "", 0) + // nq1 id data: 0 - 3 + // nq2 id data: 4 - 7 + // nq3 id data: 8 - 11 + data2 := function.GenSearchResultData(nq, 4, schemapb.DataType_Int64, "", 0) + inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data1, data2}, f.GetInputFieldIDs(), false) + ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE", "COSINE"}), inputs) + s.NoError(err) + s.Equal([]int64{3, 3, 3}, ret.searchResultData.Topks) + s.Equal(int64(3), ret.searchResultData.TopK) + s.Equal([]int64{2, 3, 4, 5, 11, 6, 9, 21, 10}, ret.searchResultData.Ids.GetIntId().Data) + } + // // nq = 3, grouping = true, grouping size = 1 + { + nq := int64(3) + f, err := newRRFFunction(schema, functionSchema) + s.NoError(err) + // nq1 id data: 0 - 9 + // nq2 id data: 10 - 19 + // nq3 id data: 20 - 29 + data1 := function.GenSearchResultDataWithGrouping(nq, 10, schemapb.DataType_Int64, "", 0, "ts", 102, 1) + 
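// A minimal sketch of the RRF rule exercised by these tests: a hit at rank r in any of the
// fused result lists contributes 1/(k+r) to its id's score, and ids returned by several
// requests accumulate contributions. rrfFuse and its arguments are illustrative names only.
rrfFuse := func(k float32, rankedLists [][]int64) map[int64]float32 {
	scores := map[int64]float32{}
	for _, list := range rankedLists {
		for rank, id := range list {
			scores[id] += 1 / (k + float32(rank+1))
		}
	}
	return scores
}
// With the default k=60, ids 3 and 4 outrank 2 and 5 in rrfFuse(60, [][]int64{{2, 3, 4}, {3, 4, 5}})
// because they appear in both lists.
_ = rrfFuse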
// nq1 id data: 0 - 3 + // nq2 id data: 4 - 7 + // nq3 id data: 8 - 11 + data2 := function.GenSearchResultDataWithGrouping(nq, 4, schemapb.DataType_Int64, "", 0, "ts", 102, 1) + inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data1, data2}, f.GetInputFieldIDs(), true) + ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, 102, 1, true, "", []string{"COSINE", "COSINE"}), inputs) + s.NoError(err) + s.Equal([]int64{3, 3, 3}, ret.searchResultData.Topks) + s.Equal(int64(3), ret.searchResultData.TopK) + s.Equal([]int64{2, 3, 4, 5, 11, 6, 9, 21, 10}, ret.searchResultData.Ids.GetIntId().Data) + } + + // // nq = 3, grouping = true, grouping size = 3 + { + nq := int64(3) + f, err := newRRFFunction(schema, functionSchema) + s.NoError(err) + + // nq1 id data: 0 - 29, group value: 0,0,0,1,1,1, ... , 9,9,9 + // nq2 id data: 30 - 59, group value: 10,10,10,11,11,11, ... , 19,19,19 + // nq3 id data: 60 - 99, group value: 20,20,20,21,21,21, ... , 29,29,29 + data1 := function.GenSearchResultDataWithGrouping(nq, 10, schemapb.DataType_Int64, "", 0, "ts", 102, 3) + // nq1 id data: 0 - 11, group value: 0,0,0,1,1,1,2,2,2,3,3,3, + // nq2 id data: 12 - 23, group value: 4,4,4,5,5,5,6,6,6,7,7,7 + // nq3 id data: 24 - 35, group value: 8,8,8,9,9,9,10,10,10,11,11,11 + data2 := function.GenSearchResultDataWithGrouping(nq, 4, schemapb.DataType_Int64, "", 0, "ts", 102, 3) + inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data1, data2}, f.GetInputFieldIDs(), true) + ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, 3, 102, 3, true, "", []string{"COSINE", "COSINE"}), inputs) + s.NoError(err) + s.Equal([]int64{9, 9, 9}, ret.searchResultData.Topks) + s.Equal(int64(9), ret.searchResultData.TopK) + s.Equal([]int64{ + 6, 7, 8, 9, 10, 11, 12, 13, 14, + 15, 16, 17, 33, 34, 35, 18, 19, 20, + 27, 28, 29, 63, 64, 65, 30, 31, 32, + }, + ret.searchResultData.Ids.GetIntId().Data) + } +} diff --git a/internal/util/function/rerank/util.go b/internal/util/function/rerank/util.go index 6fe8b1d6a1..f075b8cc99 100644 --- a/internal/util/function/rerank/util.go +++ b/internal/util/function/rerank/util.go @@ -24,6 +24,8 @@ import ( "sort" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/pkg/v2/util/merr" + "github.com/milvus-io/milvus/pkg/v2/util/typeutil" ) type PKType interface { @@ -40,9 +42,9 @@ type columns struct { type rerankInputs struct { // nqs,searchResultsIndex - data [][]*columns - - nq int64 + data [][]*columns + idGroupValue map[any]any + nq int64 // There is only fieldId in schemapb.SearchResultData, but no fieldName inputFieldIds []int64 @@ -69,7 +71,7 @@ func organizeFieldIdData(multipSearchResultData []*schemapb.SearchResultData, in return multipIdField, nil } -func newRerankInputs(multipSearchResultData []*schemapb.SearchResultData, inputFieldIds []int64) (*rerankInputs, error) { +func newRerankInputs(multipSearchResultData []*schemapb.SearchResultData, inputFieldIds []int64, isGrouping bool) (*rerankInputs, error) { if len(multipSearchResultData) == 0 { return &rerankInputs{}, nil } @@ -84,27 +86,34 @@ func newRerankInputs(multipSearchResultData []*schemapb.SearchResultData, inputF cols[i] = make([]*columns, len(multipSearchResultData)) } for retIdx, searchResult := range multipSearchResultData { - for _, fieldId := range inputFieldIds { - fieldData := multipIdField[retIdx][fieldId] - start := int64(0) - for i := int64(0); i < nq; i++ { - size := searchResult.Topks[i] + start := int64(0) + for i := int64(0); i < nq; i++ { + size 
:= searchResult.Topks[i] + if cols[i][retIdx] == nil { + cols[i][retIdx] = &columns{} + cols[i][retIdx].size = size + cols[i][retIdx].ids = getIds(searchResult.Ids, start, size) + cols[i][retIdx].scores = searchResult.Scores[start : start+size] + } + for _, fieldId := range inputFieldIds { + fieldData := multipIdField[retIdx][fieldId] d, err := getField(fieldData, start, size) if err != nil { return nil, err } - if cols[i][retIdx] == nil { - cols[i][retIdx] = &columns{} - cols[i][retIdx].size = size - cols[i][retIdx].ids = getIds(searchResult.Ids, start, size) - cols[i][retIdx].scores = searchResult.Scores[start : start+size] - } cols[i][retIdx].data = append(cols[i][retIdx].data, d) - start += size } + start += size } } - return &rerankInputs{cols, nq, inputFieldIds}, nil + if isGrouping { + idGroup, err := genIdGroupingMap(multipSearchResultData) + if err != nil { + return nil, err + } + return &rerankInputs{cols, idGroup, nq, inputFieldIds}, nil + } + return &rerankInputs{cols, nil, nq, inputFieldIds}, nil } func (inputs *rerankInputs) numOfQueries() int64 { @@ -116,9 +125,13 @@ type rerankOutputs struct { } func newRerankOutputs(searchParams *SearchParams) *rerankOutputs { + topk := searchParams.limit + if searchParams.isGrouping() { + topk = topk * searchParams.groupSize + } ret := &schemapb.SearchResultData{ NumQueries: searchParams.nq, - TopK: searchParams.limit, + TopK: topk, FieldsData: make([]*schemapb.FieldData, 0), Scores: []float32{}, Ids: &schemapb.IDs{}, @@ -153,28 +166,11 @@ func appendResult[T PKType](outputs *rerankOutputs, ids []T, scores []float32) { } type IDScores[T PKType] struct { - // idScores map[T]float32 ids []T scores []float32 size int64 } -// func (s *IDScores[T]) GetSortedIdScores() ([]T, []float32) { -// ids := make([]T, 0, s.size) -// big := func(i, j int) bool { -// if s.idScores[ids[i]] == s.idScores[ids[j]] { -// return ids[i] < ids[j] -// } -// return s.idScores[ids[i]] > s.idScores[ids[j]] -// } -// sort.Slice(ids, big) -// scores := make([]float32, 0, s.size) -// for _, id := range ids { -// scores = append(scores, s.idScores[id]) -// } -// return ids, scores -// } - func newIDScores[T PKType](idScores map[T]float32, searchParams *SearchParams) *IDScores[T] { ids := make([]T, 0, len(idScores)) for id := range idScores { @@ -209,6 +205,120 @@ func newIDScores[T PKType](idScores map[T]float32, searchParams *SearchParams) * return &ret } +func genIDGroupValueMap[T PKType]() map[T]any { + return nil +} + +func groupScore[T PKType](group *Group[T], scorerType string) (float32, error) { + switch scorerType { + case maxScorer: + return group.maxScore, nil + case sumScorer: + return group.sumScore, nil + case avgScorer: + if len(group.idList) == 0 { + return 0, merr.WrapErrParameterInvalid(1, len(group.idList), + "input group for score must have at least one id, must be sth wrong within code") + } + return group.sumScore / float32(len(group.idList)), nil + default: + return 0, merr.WrapErrParameterInvalidMsg("input group scorer type: %s is not supported!", scorerType) + } +} + +type Group[T PKType] struct { + idList []T + scoreList []float32 + groupVal any + maxScore float32 + sumScore float32 + finalScore float32 +} + +func newGroupingIDScores[T PKType](idScores map[T]float32, searchParams *SearchParams, idGroup map[any]any) (*IDScores[T], error) { + ids := make([]T, 0, len(idScores)) + for id := range idScores { + ids = append(ids, id) + } + + sort.Slice(ids, func(i, j int) bool { + if idScores[ids[i]] == idScores[ids[j]] { + return ids[i] < ids[j] + } + 
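// ids with equal scores were ordered by ascending id above; otherwise sort by score, descending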
return idScores[ids[i]] > idScores[ids[j]] + }) + + buckets := make(map[interface{}]*Group[T]) + for _, id := range ids { + score := idScores[id] + groupVal := idGroup[id] + if buckets[groupVal] == nil { + buckets[groupVal] = &Group[T]{ + idList: make([]T, 0), + scoreList: make([]float32, 0), + groupVal: groupVal, + } + } + if int64(len(buckets[groupVal].idList)) >= searchParams.groupSize { + continue + } + buckets[groupVal].idList = append(buckets[groupVal].idList, id) + buckets[groupVal].scoreList = append(buckets[groupVal].scoreList, idScores[id]) + if score > buckets[groupVal].maxScore { + buckets[groupVal].maxScore = score + } + buckets[groupVal].sumScore += score + } + + groupList := make([]*Group[T], len(buckets)) + idx := 0 + var err error + for _, group := range buckets { + if group.finalScore, err = groupScore(group, searchParams.groupScore); err != nil { + return nil, err + } + groupList[idx] = group + idx += 1 + } + sort.Slice(groupList, func(i, j int) bool { + if groupList[i].finalScore == groupList[j].finalScore { + if len(groupList[i].idList) == len(groupList[j].idList) { + // if final score and size of group are both equal + // choose the group with smaller first key + // here, it's guaranteed all group having at least one id in the idList + return groupList[i].idList[0] < groupList[j].idList[0] + } + // choose the larger group when scores are equal + return len(groupList[i].idList) > len(groupList[j].idList) + } + return groupList[i].finalScore > groupList[j].finalScore + }) + + if int64(len(groupList)) > searchParams.limit+searchParams.offset { + groupList = groupList[:searchParams.limit+searchParams.offset] + } + + ret := IDScores[T]{ + make([]T, 0, searchParams.limit), + make([]float32, 0, searchParams.limit), + 0, + } + for index := int(searchParams.offset); index < len(groupList); index++ { + group := groupList[index] + for i, score := range group.scoreList { + // idList and scoreList must have same length + if searchParams.roundDecimal != -1 { + multiplier := math.Pow(10.0, float64(searchParams.roundDecimal)) + score = float32(math.Floor(float64(score)*multiplier+0.5) / multiplier) + } + ret.scores = append(ret.scores, score) + ret.ids = append(ret.ids, group.idList[i]) + } + } + ret.size = int64(len(ret.ids)) + return &ret, nil +} + func getField(inputField *schemapb.FieldData, start int64, size int64) (any, error) { switch inputField.Type { case schemapb.DataType_Int8, schemapb.DataType_Int16, schemapb.DataType_Int32: @@ -299,3 +409,22 @@ func getPKType(collSchema *schemapb.CollectionSchema) (schemapb.DataType, error) } return pkType, nil } + +func genIdGroupingMap(multipSearchResultData []*schemapb.SearchResultData) (map[any]any, error) { + idGroupValue := map[any]any{} + for _, result := range multipSearchResultData { + if result.GetGroupByFieldValue() == nil { + return nil, fmt.Errorf("Group value is nil") + } + size := typeutil.GetSizeOfIDs(result.Ids) + groupIter := typeutil.GetDataIterator(result.GetGroupByFieldValue()) + for i := 0; i < size; i++ { + groupByVal := groupIter(i) + id := typeutil.GetPK(result.Ids, int64(i)) + if _, exist := idGroupValue[id]; !exist { + idGroupValue[id] = groupByVal + } + } + } + return idGroupValue, nil +} diff --git a/internal/util/function/rerank/weighted_function.go b/internal/util/function/rerank/weighted_function.go new file mode 100644 index 0000000000..3aa03c0c64 --- /dev/null +++ b/internal/util/function/rerank/weighted_function.go @@ -0,0 +1,154 @@ +/* + * # Licensed to the LF AI & Data foundation under one + * # or 
more contributor license agreements. See the NOTICE file + * # distributed with this work for additional information + * # regarding copyright ownership. The ASF licenses this file + * # to you under the Apache License, Version 2.0 (the + * # "License"); you may not use this file except in compliance + * # with the License. You may obtain a copy of the License at + * # + * # http://www.apache.org/licenses/LICENSE-2.0 + * # + * # Unless required by applicable law or agreed to in writing, software + * # distributed under the License is distributed on an "AS IS" BASIS, + * # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * # See the License for the specific language governing permissions and + * # limitations under the License. + */ + +package rerank + +import ( + "context" + "encoding/json" + "fmt" + "math" + "strconv" + "strings" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/pkg/v2/util/merr" + "github.com/milvus-io/milvus/pkg/v2/util/metric" +) + +const ( + WeightsParamsKey string = "weights" + NormScoreKey string = "norm_score" +) + +type WeightedFunction[T PKType] struct { + RerankBase + + weight []float32 + needNorm bool +} + +func newWeightedFunction(collSchema *schemapb.CollectionSchema, funcSchema *schemapb.FunctionSchema) (Reranker, error) { + base, err := newRerankBase(collSchema, funcSchema, weightedName, true) + if err != nil { + return nil, err + } + + if len(base.GetInputFieldNames()) != 0 { + return nil, fmt.Errorf("The weighted function does not support input parameters, but got %s", base.GetInputFieldNames()) + } + + var weights []float32 + needNorm := false + for _, param := range funcSchema.Params { + switch strings.ToLower(param.Key) { + case WeightsParamsKey: + if err := json.Unmarshal([]byte(param.Value), &weights); err != nil { + return nil, fmt.Errorf("Parse %s param failed, weight should be []float, bug got: %s", WeightsParamsKey, param.Value) + } + for _, weight := range weights { + if weight < 0 || weight > 1 { + return nil, fmt.Errorf("rank param weight should be in range [0, 1]") + } + } + case NormScoreKey: + if needNorm, err = strconv.ParseBool(param.Value); err != nil { + return nil, fmt.Errorf("%s params must be true/false, bug got %s", NormScoreKey, param.Value) + } + } + } + if len(weights) == 0 { + return nil, fmt.Errorf(WeightsParamsKey + " not found") + } + if base.pkType == schemapb.DataType_Int64 { + return &WeightedFunction[int64]{RerankBase: *base, weight: weights, needNorm: needNorm}, nil + } else { + return &WeightedFunction[string]{RerankBase: *base, weight: weights, needNorm: needNorm}, nil + } +} + +func (weighted *WeightedFunction[T]) processOneSearchData(ctx context.Context, searchParams *SearchParams, cols []*columns, idGroup map[any]any) (*IDScores[T], error) { + if len(cols) != len(weighted.weight) { + return nil, merr.WrapErrParameterInvalid(fmt.Sprint(len(cols)), fmt.Sprint(len(weighted.weight)), "the length of weights param mismatch with ann search requests") + } + weightedScores := map[T]float32{} + for i, col := range cols { + if col.size == 0 { + continue + } + normFunc := getNormalizeFunc(weighted.needNorm, searchParams.searchMetrics[i]) + ids := col.ids.([]T) + for j, id := range ids { + if score, ok := weightedScores[id]; !ok { + weightedScores[id] = weighted.weight[i] * normFunc(col.scores[j]) + } else { + weightedScores[id] = score + weighted.weight[i]*normFunc(col.scores[j]) + } + } + } + if searchParams.isGrouping() { + return 
newGroupingIDScores(weightedScores, searchParams, idGroup) + } + return newIDScores(weightedScores, searchParams), nil +} + +func (weighted *WeightedFunction[T]) Process(ctx context.Context, searchParams *SearchParams, inputs *rerankInputs) (*rerankOutputs, error) { + outputs := newRerankOutputs(searchParams) + for _, cols := range inputs.data { + for i, col := range cols { + metricType := searchParams.searchMetrics[i] + for j, score := range col.scores { + col.scores[j] = toGreaterScore(score, metricType) + } + } + idScore, err := weighted.processOneSearchData(ctx, searchParams, cols, inputs.idGroupValue) + if err != nil { + return nil, err + } + appendResult(outputs, idScore.ids, idScore.scores) + } + return outputs, nil +} + +type normalizeFunc func(float32) float32 + +func getNormalizeFunc(normScore bool, metrics string) normalizeFunc { + if !normScore { + return func(distance float32) float32 { + return distance + } + } + switch metrics { + case metric.COSINE: + return func(distance float32) float32 { + return (1 + distance) * 0.5 + } + case metric.IP: + return func(distance float32) float32 { + return 0.5 + float32(math.Atan(float64(distance)))/math.Pi + } + case metric.BM25: + return func(distance float32) float32 { + return 2 * float32(math.Atan(float64(distance))) / math.Pi + } + default: + return func(distance float32) float32 { + return 1.0 - 2*float32(math.Atan(float64(distance)))/math.Pi + } + } +} diff --git a/internal/util/function/rerank/weighted_function_test.go b/internal/util/function/rerank/weighted_function_test.go new file mode 100644 index 0000000000..7947a1ee68 --- /dev/null +++ b/internal/util/function/rerank/weighted_function_test.go @@ -0,0 +1,298 @@ +/* + * # Licensed to the LF AI & Data foundation under one + * # or more contributor license agreements. See the NOTICE file + * # distributed with this work for additional information + * # regarding copyright ownership. The ASF licenses this file + * # to you under the Apache License, Version 2.0 (the + * # "License"); you may not use this file except in compliance + * # with the License. You may obtain a copy of the License at + * # + * # http://www.apache.org/licenses/LICENSE-2.0 + * # + * # Unless required by applicable law or agreed to in writing, software + * # distributed under the License is distributed on an "AS IS" BASIS, + * # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * # See the License for the specific language governing permissions and + * # limitations under the License. 
+ */ + +package rerank + +import ( + "context" + "math" + "testing" + + "github.com/stretchr/testify/suite" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/util/function" + "github.com/milvus-io/milvus/pkg/v2/util/metric" +) + +func TestWeightedFunction(t *testing.T) { + suite.Run(t, new(WeightedFunctionSuite)) +} + +type WeightedFunctionSuite struct { + suite.Suite +} + +func (s *WeightedFunctionSuite) TestNewWeightedFuction() { + schema := &schemapb.CollectionSchema{ + Name: "test", + Fields: []*schemapb.FieldSchema{ + {FieldID: 100, Name: "pk", DataType: schemapb.DataType_Int64, IsPrimaryKey: true}, + {FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar}, + { + FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + {Key: "dim", Value: "4"}, + }, + }, + {FieldID: 102, Name: "ts", DataType: schemapb.DataType_Int64}, + }, + } + functionSchema := &schemapb.FunctionSchema{ + Name: "test", + Type: schemapb.FunctionType_Rerank, + InputFieldNames: []string{}, + Params: []*commonpb.KeyValuePair{ + {Key: WeightsParamsKey, Value: `[0.1, 0.9]`}, + }, + } + + { + _, err := newWeightedFunction(schema, functionSchema) + s.NoError(err) + } + { + schema.Fields[0] = &schemapb.FieldSchema{FieldID: 100, Name: "pk", DataType: schemapb.DataType_VarChar, IsPrimaryKey: true} + _, err := newWeightedFunction(schema, functionSchema) + s.NoError(err) + } + { + functionSchema.Params[0] = &commonpb.KeyValuePair{Key: WeightsParamsKey, Value: "NotNum"} + _, err := newWeightedFunction(schema, functionSchema) + s.ErrorContains(err, "param failed, weight should be []float") + } + { + functionSchema.Params[0] = &commonpb.KeyValuePair{Key: WeightsParamsKey, Value: `[10]`} + _, err := newWeightedFunction(schema, functionSchema) + s.ErrorContains(err, "rank param weight should be in range [0, 1]") + } + { + functionSchema.Params[0] = &commonpb.KeyValuePair{Key: "NotExist", Value: `[10]`} + _, err := newWeightedFunction(schema, functionSchema) + s.ErrorContains(err, "not found") + functionSchema.Params[0] = &commonpb.KeyValuePair{Key: WeightsParamsKey, Value: `[0.1, 0.9]`} + } + { + functionSchema.InputFieldNames = []string{"ts"} + _, err := newWeightedFunction(schema, functionSchema) + s.ErrorContains(err, "The weighted function does not support input parameters,") + } +} + +func (s *WeightedFunctionSuite) TestWeightedFuctionProcess() { + schema := &schemapb.CollectionSchema{ + Name: "test", + Fields: []*schemapb.FieldSchema{ + {FieldID: 100, Name: "pk", DataType: schemapb.DataType_Int64, IsPrimaryKey: true}, + {FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar}, + { + FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + {Key: "dim", Value: "4"}, + }, + }, + {FieldID: 102, Name: "ts", DataType: schemapb.DataType_Int64}, + }, + } + functionSchema := &schemapb.FunctionSchema{ + Name: "test", + Type: schemapb.FunctionType_Rerank, + InputFieldNames: []string{}, + Params: []*commonpb.KeyValuePair{ + {Key: WeightsParamsKey, Value: `[0.1]`}, + }, + } + + // empty + { + nq := int64(1) + f, err := newWeightedFunction(schema, functionSchema) + s.NoError(err) + inputs, _ := newRerankInputs([]*schemapb.SearchResultData{}, f.GetInputFieldIDs(), false) + ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE"}), inputs) + s.NoError(err) + 
s.Equal(int64(3), ret.searchResultData.TopK) + s.Equal([]int64{}, ret.searchResultData.Topks) + } + + // singleSearchResultData + // nq = 1 + { + nq := int64(1) + f, err := newWeightedFunction(schema, functionSchema) + s.NoError(err) + data := function.GenSearchResultData(nq, 10, schemapb.DataType_Int64, "", 0) + inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data}, f.GetInputFieldIDs(), false) + ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE"}), inputs) + s.NoError(err) + s.Equal([]int64{3}, ret.searchResultData.Topks) + s.Equal(int64(3), ret.searchResultData.TopK) + s.Equal([]int64{7, 6, 5}, ret.searchResultData.Ids.GetIntId().Data) + } + // nq = 3 + { + nq := int64(3) + f, err := newWeightedFunction(schema, functionSchema) + s.NoError(err) + data := function.GenSearchResultData(nq, 10, schemapb.DataType_Int64, "", 0) + inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data}, f.GetInputFieldIDs(), false) + ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 0, -1, -1, 1, false, "", []string{"COSINE"}), inputs) + s.NoError(err) + s.Equal([]int64{3, 3, 3}, ret.searchResultData.Topks) + s.Equal(int64(3), ret.searchResultData.TopK) + s.Equal([]int64{9, 8, 7, 19, 18, 17, 29, 28, 27}, ret.searchResultData.Ids.GetIntId().Data) + } + + // number of weigts not equal to search data + functionSchema.Params[0] = &commonpb.KeyValuePair{Key: WeightsParamsKey, Value: `[0.1, 0.9]`} + { + nq := int64(1) + f, err := newWeightedFunction(schema, functionSchema) + s.NoError(err) + data := function.GenSearchResultData(nq, 10, schemapb.DataType_Int64, "", 0) + inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data}, f.GetInputFieldIDs(), false) + _, err = f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE", "COSINE"}), inputs) + s.ErrorContains(err, "the length of weights param mismatch with ann search requests") + } + // has empty inputs + { + nq := int64(1) + f, err := newWeightedFunction(schema, functionSchema) + s.NoError(err) + // id data: 0 - 9 + data1 := function.GenSearchResultData(nq, 10, schemapb.DataType_Int64, "", 0) + // empty + data2 := function.GenSearchResultData(nq, 0, schemapb.DataType_Int64, "", 0) + inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data1, data2}, f.GetInputFieldIDs(), false) + ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 0, -1, -1, 1, false, "", []string{"COSINE", "COSINE"}), inputs) + s.NoError(err) + s.Equal([]int64{3}, ret.searchResultData.Topks) + s.Equal(int64(3), ret.searchResultData.TopK) + s.Equal([]int64{9, 8, 7}, ret.searchResultData.Ids.GetIntId().Data) + s.True(function.FloatsAlmostEqual([]float32{0.9, 0.8, 0.7}, ret.searchResultData.Scores, 0.001)) + } + // nq = 1 + { + nq := int64(1) + f, err := newWeightedFunction(schema, functionSchema) + s.NoError(err) + // id data: 0 - 9 + data1 := function.GenSearchResultData(nq, 10, schemapb.DataType_Int64, "", 0) + // id data: 0 - 3 + data2 := function.GenSearchResultData(nq, 4, schemapb.DataType_Int64, "", 0) + inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data1, data2}, f.GetInputFieldIDs(), false) + ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE", "COSINE"}), inputs) + s.NoError(err) + s.Equal([]int64{3}, ret.searchResultData.Topks) + s.Equal(int64(3), ret.searchResultData.TopK) + s.Equal([]int64{1, 9, 8}, ret.searchResultData.Ids.GetIntId().Data) + 
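// A minimal sketch of the weighted rule checked here, assuming norm_score is off (the default
// in this test) so raw scores are combined directly: an id's fused score is the weighted sum of
// its scores over the requests that returned it. weightedFuse is an illustrative name only.
weightedFuse := func(weights []float32, perRequestScores []map[int64]float32) map[int64]float32 {
	fused := map[int64]float32{}
	for i, scores := range perRequestScores {
		for id, score := range scores {
			fused[id] += weights[i] * score
		}
	}
	return fused
}
_ = weightedFuse
// Judging from the expected values below, id 1 is returned by both requests and fuses to
// 0.1*1.0 + 0.9*1.0 = 1.0, while ids 9 and 8 come only from the first request (0.1*9 = 0.9, 0.1*8 = 0.8).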
s.True(function.FloatsAlmostEqual([]float32{1, 0.9, 0.8}, ret.searchResultData.Scores, 0.001)) + } + // // nq = 3 + { + nq := int64(3) + f, err := newWeightedFunction(schema, functionSchema) + s.NoError(err) + // nq1 id data: 0 - 9 + // nq2 id data: 10 - 19 + // nq3 id data: 20 - 29 + data1 := function.GenSearchResultData(nq, 10, schemapb.DataType_Int64, "", 0) + // nq1 id data: 0 - 3 + // nq2 id data: 4 - 7 + // nq3 id data: 8 - 11 + data2 := function.GenSearchResultData(nq, 4, schemapb.DataType_Int64, "", 0) + inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data1, data2}, f.GetInputFieldIDs(), false) + ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 0, 1, -1, 1, false, "", []string{"COSINE", "COSINE"}), inputs) + s.NoError(err) + s.Equal([]int64{3, 3, 3}, ret.searchResultData.Topks) + s.Equal(int64(3), ret.searchResultData.TopK) + s.Equal([]int64{3, 2, 1, 7, 6, 5, 11, 10, 9}, ret.searchResultData.Ids.GetIntId().Data) + s.True(function.FloatsAlmostEqual([]float32{3, 2, 1, 6.3, 5.4, 4.5, 9.9, 9, 8.1}, ret.searchResultData.Scores, 0.001)) + } + // // nq = 3, grouping = true, grouping size = 1 + { + nq := int64(3) + f, err := newWeightedFunction(schema, functionSchema) + s.NoError(err) + // nq1 id data: 0 - 9 + // nq2 id data: 10 - 19 + // nq3 id data: 20 - 29 + data1 := function.GenSearchResultDataWithGrouping(nq, 10, schemapb.DataType_Int64, "", 0, "ts", 102, 1) + // nq1 id data: 0 - 3 + // nq2 id data: 4 - 7 + // nq3 id data: 8 - 11 + data2 := function.GenSearchResultDataWithGrouping(nq, 4, schemapb.DataType_Int64, "", 0, "ts", 102, 1) + inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data1, data2}, f.GetInputFieldIDs(), true) + ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 0, 1, 102, 1, true, "", []string{"COSINE", "COSINE"}), inputs) + s.NoError(err) + s.Equal([]int64{3, 3, 3}, ret.searchResultData.Topks) + s.Equal([]int64{3, 2, 1, 7, 6, 5, 11, 10, 9}, ret.searchResultData.Ids.GetIntId().Data) + s.True(function.FloatsAlmostEqual([]float32{3, 2, 1, 6.3, 5.4, 4.5, 9.9, 9, 8.1}, ret.searchResultData.Scores, 0.001)) + } + + // // nq = 3, grouping = true, grouping size = 3 + { + nq := int64(3) + f, err := newWeightedFunction(schema, functionSchema) + s.NoError(err) + + // nq1 id data: 0 - 29, group value: 0,0,0,1,1,1, ... , 9,9,9 + // nq2 id data: 30 - 59, group value: 10,10,10,11,11,11, ... , 19,19,19 + // nq3 id data: 60 - 99, group value: 20,20,20,21,21,21, ... 
, 29,29,29 + data1 := function.GenSearchResultDataWithGrouping(nq, 10, schemapb.DataType_Int64, "", 0, "ts", 102, 3) + // nq1 id data: 0 - 11, group value: 0,0,0,1,1,1,2,2,2,3,3,3, + // nq2 id data: 12 - 23, group value: 4,4,4,5,5,5,6,6,6,7,7,7 + // nq3 id data: 24 - 35, group value: 8,8,8,9,9,9,10,10,10,11,11,11 + data2 := function.GenSearchResultDataWithGrouping(nq, 4, schemapb.DataType_Int64, "", 0, "ts", 102, 3) + inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data1, data2}, f.GetInputFieldIDs(), true) + ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, 1, 102, 3, true, "", []string{"COSINE", "COSINE"}), inputs) + s.NoError(err) + s.Equal([]int64{9, 9, 9}, ret.searchResultData.Topks) + s.Equal(int64(9), ret.searchResultData.TopK) + s.Equal([]int64{ + 5, 4, 3, 29, 28, 27, 26, 25, 24, + 17, 16, 15, 14, 13, 12, 59, 58, 57, + 29, 28, 27, 26, 25, 24, 89, 88, 87, + }, + ret.searchResultData.Ids.GetIntId().Data) + } +} + +func (s *WeightedFunctionSuite) TestWeightedFuctionNormalize() { + { + f := getNormalizeFunc(false, metric.COSINE) + s.Equal(float32(1.0), f(1.0)) + } + { + f := getNormalizeFunc(true, metric.COSINE) + s.Equal(float32((1+1.0)*0.5), f(1)) + } + { + f := getNormalizeFunc(true, metric.IP) + s.Equal(0.5+float32(math.Atan(float64(1.0)))/math.Pi, f(1)) + } + { + f := getNormalizeFunc(true, metric.BM25) + s.Equal(float32(2*math.Atan(float64(1.0)))/math.Pi, f(1.0)) + } + { + f := getNormalizeFunc(true, metric.L2) + s.Equal((1.0 - 2*float32(math.Atan(float64(1.0)))/math.Pi), f(1.0)) + } +} diff --git a/tests/integration/partialsearch/partial_result_on_node_down_test.go b/tests/integration/partialsearch/partial_result_on_node_down_test.go index eb61da56c2..e7d2536814 100644 --- a/tests/integration/partialsearch/partial_result_on_node_down_test.go +++ b/tests/integration/partialsearch/partial_result_on_node_down_test.go @@ -213,6 +213,7 @@ func (s *PartialSearchTestSuit) TestAllNodeDownOnSingleReplica() { for _, qn := range s.Cluster.GetAllQueryNodes() { qn.Stop() } + time.Sleep(2 * time.Second) s.Cluster.AddQueryNode() time.Sleep(20 * time.Second) diff --git a/tests/python_client/milvus_client/test_milvus_client_search.py b/tests/python_client/milvus_client/test_milvus_client_search.py index da89988872..4f2f99c650 100644 --- a/tests/python_client/milvus_client/test_milvus_client_search.py +++ b/tests/python_client/milvus_client/test_milvus_client_search.py @@ -1588,11 +1588,7 @@ class TestMilvusClientSearchInvalid(TestMilvusClientV2Base): } ) vectors_to_search = rng.random((1, dim)) - error = {ct.err_code: 1100, - ct.err_msg: f"Current rerank does not support grouping search: invalid parameter"} - self.search(client, collection_name, vectors_to_search, ranker=my_rerank_fn, - group_by_field=ct.default_reranker_field_name, - check_task=CheckTasks.err_res, check_items=error) + self.search(client, collection_name, vectors_to_search, ranker=my_rerank_fn, group_by_field=ct.default_reranker_field_name) @pytest.mark.tags(CaseLabel.L1) def test_milvus_client_search_with_reranker_on_dynamic_fields(self):
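The Python-side change above drops the expected "Current rerank does not support grouping search" error, since grouping search is now handled by the rerank path. Below is a rough, self-contained Go sketch of the per-group scoring this relies on, mirroring groupScore in internal/util/function/rerank/util.go; groupFinalScore, the scorer strings, and the main wrapper are illustrative names, not part of the patch.

package main

import "fmt"

// groupFinalScore reduces the member scores of one group to a single score,
// using max, sum, or average, as the grouping rerank does per group bucket.
func groupFinalScore(memberScores []float32, scorer string) float32 {
	if len(memberScores) == 0 {
		return 0 // the real groupScore returns an error for an empty group
	}
	var maxScore, sumScore float32
	for i, s := range memberScores {
		if i == 0 || s > maxScore {
			maxScore = s
		}
		sumScore += s
	}
	switch scorer {
	case "max":
		return maxScore
	case "sum":
		return sumScore
	default: // "avg"
		return sumScore / float32(len(memberScores))
	}
}

func main() {
	// A group holding scores {1.0, 2.0, 3.0} scores 3.0 (max), 6.0 (sum) or 2.0 (avg),
	// matching the expectations in TestFunctionUtil above.
	fmt.Println(groupFinalScore([]float32{1, 2, 3}, "avg"))
}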