enhance: refactor rrf and weighted rerank (#42154)

https://github.com/milvus-io/milvus/issues/35856

Signed-off-by: junjie.jiang <junjie.jiang@zilliz.com>
This commit is contained in:
junjiejiangjjj 2025-06-10 18:08:35 +08:00 committed by GitHub
parent f3fe117840
commit f1a4526bac
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
22 changed files with 1368 additions and 1010 deletions

View File

@ -1,246 +0,0 @@
package proxy
import (
"context"
"fmt"
"math"
"reflect"
"strings"
"github.com/cockroachdb/errors"
"go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/internal/json"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/metric"
)
type rankType int
const (
invalidRankType rankType = iota // invalidRankType = 0
rrfRankType // rrfRankType = 1
weightedRankType // weightedRankType = 2
udfExprRankType // udfExprRankType = 3
)
var rankTypeMap = map[string]rankType{
"invalid": invalidRankType,
"rrf": rrfRankType,
"weighted": weightedRankType,
"expr": udfExprRankType,
}
type reScorer interface {
name() string
scorerType() rankType
reScore(input *milvuspb.SearchResults)
setMetricType(metricType string)
getMetricType() string
}
type baseScorer struct {
scorerName string
metricType string
}
func (bs *baseScorer) name() string {
return bs.scorerName
}
func (bs *baseScorer) setMetricType(metricType string) {
bs.metricType = metricType
}
func (bs *baseScorer) getMetricType() string {
return bs.metricType
}
type rrfScorer struct {
baseScorer
k float32
}
func (rs *rrfScorer) reScore(input *milvuspb.SearchResults) {
index := 0
for _, topk := range input.Results.GetTopks() {
for i := int64(0); i < topk; i++ {
input.Results.Scores[index] = 1 / (rs.k + float32(i+1))
index++
}
}
}
func (rs *rrfScorer) scorerType() rankType {
return rrfRankType
}
type weightedScorer struct {
baseScorer
weight float32
normScore bool
}
type activateFunc func(float32) float32
func (ws *weightedScorer) getActivateFunc() activateFunc {
if !ws.normScore {
return func(distance float32) float32 {
return distance
}
}
mUpper := strings.ToUpper(ws.getMetricType())
isCosine := mUpper == strings.ToUpper(metric.COSINE)
isIP := mUpper == strings.ToUpper(metric.IP)
isBM25 := mUpper == strings.ToUpper(metric.BM25)
if isCosine {
f := func(distance float32) float32 {
return (1 + distance) * 0.5
}
return f
}
if isIP {
f := func(distance float32) float32 {
return 0.5 + float32(math.Atan(float64(distance)))/math.Pi
}
return f
}
if isBM25 {
f := func(distance float32) float32 {
return 2 * float32(math.Atan(float64(distance))) / math.Pi
}
return f
}
f := func(distance float32) float32 {
return 1.0 - 2*float32(math.Atan(float64(distance)))/math.Pi
}
return f
}
func (ws *weightedScorer) reScore(input *milvuspb.SearchResults) {
activateF := ws.getActivateFunc()
for i, distance := range input.Results.GetScores() {
input.Results.Scores[i] = ws.weight * activateF(distance)
}
}
func (ws *weightedScorer) scorerType() rankType {
return weightedRankType
}
func NewReScorers(ctx context.Context, reqCnt int, rankParams []*commonpb.KeyValuePair) ([]reScorer, error) {
if reqCnt == 0 {
return []reScorer{}, nil
}
log := log.Ctx(ctx)
res := make([]reScorer, reqCnt)
rankTypeStr, err := funcutil.GetAttrByKeyFromRepeatedKV(RankTypeKey, rankParams)
if err != nil {
log.Info("rank strategy not specified, use rrf instead")
// if not set rank strategy, use rrf rank as default
for i := 0; i < reqCnt; i++ {
res[i] = &rrfScorer{
baseScorer: baseScorer{
scorerName: "rrf",
},
k: float32(defaultRRFParamsValue),
}
}
return res, nil
}
if _, ok := rankTypeMap[rankTypeStr]; !ok {
return nil, errors.Errorf("unsupported rank type %s", rankTypeStr)
}
paramStr, err := funcutil.GetAttrByKeyFromRepeatedKV(RankParamsKey, rankParams)
if err != nil {
return nil, errors.New(RankParamsKey + " not found in rank_params")
}
var params map[string]interface{}
err = json.Unmarshal([]byte(paramStr), &params)
if err != nil {
return nil, err
}
switch rankTypeMap[rankTypeStr] {
case rrfRankType:
_, ok := params[RRFParamsKey]
if !ok {
return nil, errors.New(RRFParamsKey + " not found in rank_params")
}
var k float64
if reflect.ValueOf(params[RRFParamsKey]).CanFloat() {
k = reflect.ValueOf(params[RRFParamsKey]).Float()
} else {
return nil, errors.New("The type of rank param k should be float")
}
if k <= 0 || k >= maxRRFParamsValue {
return nil, errors.New(fmt.Sprintf("The rank params k should be in range (0, %d)", maxRRFParamsValue))
}
log.Debug("rrf params", zap.Float64("k", k))
for i := 0; i < reqCnt; i++ {
res[i] = &rrfScorer{
baseScorer: baseScorer{
scorerName: "rrf",
},
k: float32(k),
}
}
case weightedRankType:
if _, ok := params[WeightsParamsKey]; !ok {
return nil, errors.New(WeightsParamsKey + " not found in rank_params")
}
// normalize scores by default
normScore := true
if _, ok := params[NormScoreKey]; ok {
normScore = params[NormScoreKey].(bool)
}
weights := make([]float32, 0)
switch reflect.TypeOf(params[WeightsParamsKey]).Kind() {
case reflect.Slice:
rs := reflect.ValueOf(params[WeightsParamsKey])
for i := 0; i < rs.Len(); i++ {
v := rs.Index(i).Elem()
if v.CanFloat() {
weight := v.Float()
if weight < 0 || weight > 1 {
return nil, errors.New("rank param weight should be in range [0, 1]")
}
weights = append(weights, float32(weight))
} else {
return nil, errors.New("The type of rank param weight should be float")
}
}
default:
return nil, errors.New("The weights param should be an array")
}
log.Debug("weights params", zap.Any("weights", weights), zap.Bool("norm_score", normScore))
if reqCnt != len(weights) {
return nil, merr.WrapErrParameterInvalid(fmt.Sprint(reqCnt), fmt.Sprint(len(weights)), "the length of weights param mismatch with ann search requests")
}
for i := 0; i < reqCnt; i++ {
res[i] = &weightedScorer{
baseScorer: baseScorer{
scorerName: "weighted",
},
weight: weights[i],
normScore: normScore,
}
}
default:
return nil, errors.Errorf("unsupported rank type %s", rankTypeStr)
}
return res, nil
}

View File

@ -1,146 +0,0 @@
package proxy
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus/internal/json"
)
func TestRescorer(t *testing.T) {
t.Run("default scorer", func(t *testing.T) {
rescorers, err := NewReScorers(context.TODO(), 2, nil)
assert.NoError(t, err)
assert.Equal(t, 2, len(rescorers))
assert.Equal(t, rrfRankType, rescorers[0].scorerType())
})
t.Run("rrf without param", func(t *testing.T) {
params := make(map[string]float64)
b, err := json.Marshal(params)
assert.NoError(t, err)
rankParams := []*commonpb.KeyValuePair{
{Key: RankTypeKey, Value: "rrf"},
{Key: RankParamsKey, Value: string(b)},
}
_, err = NewReScorers(context.TODO(), 2, rankParams)
assert.Error(t, err)
assert.Contains(t, err.Error(), "k not found in rank_params")
})
t.Run("rrf param out of range", func(t *testing.T) {
params := make(map[string]float64)
params[RRFParamsKey] = -1
b, err := json.Marshal(params)
assert.NoError(t, err)
rankParams := []*commonpb.KeyValuePair{
{Key: RankTypeKey, Value: "rrf"},
{Key: RankParamsKey, Value: string(b)},
}
_, err = NewReScorers(context.TODO(), 2, rankParams)
assert.Error(t, err)
params[RRFParamsKey] = maxRRFParamsValue + 1
b, err = json.Marshal(params)
assert.NoError(t, err)
rankParams = []*commonpb.KeyValuePair{
{Key: RankTypeKey, Value: "rrf"},
{Key: RankParamsKey, Value: string(b)},
}
_, err = NewReScorers(context.TODO(), 2, rankParams)
assert.Error(t, err)
})
t.Run("rrf", func(t *testing.T) {
params := make(map[string]float64)
params[RRFParamsKey] = 61
b, err := json.Marshal(params)
assert.NoError(t, err)
rankParams := []*commonpb.KeyValuePair{
{Key: RankTypeKey, Value: "rrf"},
{Key: RankParamsKey, Value: string(b)},
}
rescorers, err := NewReScorers(context.TODO(), 2, rankParams)
assert.NoError(t, err)
assert.Equal(t, 2, len(rescorers))
assert.Equal(t, rrfRankType, rescorers[0].scorerType())
assert.Equal(t, float32(61), rescorers[0].(*rrfScorer).k)
})
t.Run("weights without param", func(t *testing.T) {
params := make(map[string][]float64)
b, err := json.Marshal(params)
assert.NoError(t, err)
rankParams := []*commonpb.KeyValuePair{
{Key: RankTypeKey, Value: "weighted"},
{Key: RankParamsKey, Value: string(b)},
}
_, err = NewReScorers(context.TODO(), 2, rankParams)
assert.Error(t, err)
assert.Contains(t, err.Error(), "not found in rank_params")
})
t.Run("weights out of range", func(t *testing.T) {
weights := []float64{1.2, 2.3}
params := make(map[string][]float64)
params[WeightsParamsKey] = weights
b, err := json.Marshal(params)
assert.NoError(t, err)
rankParams := []*commonpb.KeyValuePair{
{Key: RankTypeKey, Value: "weighted"},
{Key: RankParamsKey, Value: string(b)},
}
_, err = NewReScorers(context.TODO(), 2, rankParams)
assert.Error(t, err)
assert.Contains(t, err.Error(), "rank param weight should be in range [0, 1]")
})
t.Run("weights with norm_score false", func(t *testing.T) {
weights := []float64{0.5, 0.2}
params := make(map[string]interface{})
params[WeightsParamsKey] = weights
params[NormScoreKey] = false
b, err := json.Marshal(params)
assert.NoError(t, err)
rankParams := []*commonpb.KeyValuePair{
{Key: RankTypeKey, Value: "weighted"},
{Key: RankParamsKey, Value: string(b)},
}
rescorers, err := NewReScorers(context.TODO(), 2, rankParams)
assert.NoError(t, err)
assert.Equal(t, 2, len(rescorers))
assert.Equal(t, weightedRankType, rescorers[0].scorerType())
assert.Equal(t, float32(weights[0]), rescorers[0].(*weightedScorer).weight)
assert.False(t, rescorers[0].(*weightedScorer).normScore)
})
t.Run("weights", func(t *testing.T) {
weights := []float64{0.5, 0.2}
params := make(map[string]interface{})
params[WeightsParamsKey] = weights
b, err := json.Marshal(params)
assert.NoError(t, err)
rankParams := []*commonpb.KeyValuePair{
{Key: RankTypeKey, Value: "weighted"},
{Key: RankParamsKey, Value: string(b)},
}
rescorers, err := NewReScorers(context.TODO(), 2, rankParams)
assert.NoError(t, err)
assert.Equal(t, 2, len(rescorers))
assert.Equal(t, weightedRankType, rescorers[0].scorerType())
assert.Equal(t, float32(weights[0]), rescorers[0].(*weightedScorer).weight)
// normalize scores by default
assert.True(t, rescorers[0].(*weightedScorer).normScore)
})
}

View File

@ -3,8 +3,6 @@ package proxy
import (
"context"
"fmt"
"math"
"sort"
"github.com/cockroachdb/errors"
"go.uber.org/zap"
@ -432,21 +430,6 @@ func reduceSearchResultDataNoGroupBy(ctx context.Context, subSearchResultData []
return ret, nil
}
func rankSearchResultData(ctx context.Context,
nq int64,
params *rankParams,
pkType schemapb.DataType,
searchResults []*milvuspb.SearchResults,
groupByFieldID int64,
groupSize int64,
groupScorer func(group *Group) error,
) (*milvuspb.SearchResults, error) {
if groupByFieldID > 0 {
return rankSearchResultDataByGroup(ctx, nq, params, pkType, searchResults, groupScorer, groupSize)
}
return rankSearchResultDataByPk(ctx, nq, params, pkType, searchResults)
}
func compareKey(keyI interface{}, keyJ interface{}) bool {
switch keyI.(type) {
case int64:
@ -457,213 +440,6 @@ func compareKey(keyI interface{}, keyJ interface{}) bool {
return false
}
func GetGroupScorer(scorerType string) (func(group *Group) error, error) {
switch scorerType {
case MaxScorer:
return func(group *Group) error {
group.finalScore = group.maxScore
return nil
}, nil
case SumScorer:
return func(group *Group) error {
group.finalScore = group.sumScore
return nil
}, nil
case AvgScorer:
return func(group *Group) error {
if len(group.idList) == 0 {
return merr.WrapErrParameterInvalid(1, len(group.idList),
"input group for score must have at least one id, must be sth wrong within code")
}
group.finalScore = group.sumScore / float32(len(group.idList))
return nil
}, nil
default:
return nil, merr.WrapErrParameterInvalidMsg("input group scorer type: %s is not supported!", scorerType)
}
}
type Group struct {
idList []interface{}
scoreList []float32
groupVal interface{}
maxScore float32
sumScore float32
finalScore float32
}
func rankSearchResultDataByGroup(ctx context.Context,
nq int64,
params *rankParams,
pkType schemapb.DataType,
searchResults []*milvuspb.SearchResults,
groupScorer func(group *Group) error,
groupSize int64,
) (*milvuspb.SearchResults, error) {
tr := timerecord.NewTimeRecorder("rankSearchResultDataByGroup")
defer func() {
tr.CtxElapse(ctx, "done")
}()
offset, limit, roundDecimal := params.offset, params.limit, params.roundDecimal
// in the context of group by, the meaning for offset/limit/top refers to related numbers of group
groupTopK := limit + offset
log.Ctx(ctx).Debug("rankSearchResultDataByGroup",
zap.Int("len(searchResults)", len(searchResults)),
zap.Int64("nq", nq),
zap.Int64("offset", offset),
zap.Int64("limit", limit))
var ret *milvuspb.SearchResults
if ret = initSearchResults(nq, limit); len(searchResults) == 0 {
return ret, nil
}
totalCount := limit * groupSize
if err := setupIdListForSearchResult(ret, pkType, totalCount); err != nil {
return ret, err
}
type accumulateIDGroupVal struct {
accumulatedScore float32
groupVal interface{}
}
accumulatedScores := make([]map[interface{}]*accumulateIDGroupVal, nq)
for i := int64(0); i < nq; i++ {
accumulatedScores[i] = make(map[interface{}]*accumulateIDGroupVal)
}
groupByDataType := searchResults[0].GetResults().GetGroupByFieldValue().GetType()
for _, result := range searchResults {
scores := result.GetResults().GetScores()
start := 0
// milvus has limits for the value range of nq and limit
// no matter on 32-bit and 64-bit platform, converting nq and topK into int is safe
groupByValIterator := typeutil.GetDataIterator(result.GetResults().GetGroupByFieldValue())
for i := 0; i < int(nq); i++ {
realTopK := int(result.GetResults().Topks[i])
for j := start; j < start+realTopK; j++ {
id := typeutil.GetPK(result.GetResults().GetIds(), int64(j))
groupByVal := groupByValIterator(j)
if accumulatedScores[i][id] != nil {
accumulatedScores[i][id].accumulatedScore += scores[j]
} else {
accumulatedScores[i][id] = &accumulateIDGroupVal{accumulatedScore: scores[j], groupVal: groupByVal}
}
}
start += realTopK
}
}
gpFieldBuilder, err := typeutil.NewFieldDataBuilder(groupByDataType, true, int(limit))
if err != nil {
return ret, err
}
for i := int64(0); i < nq; i++ {
idSet := accumulatedScores[i]
keys := make([]interface{}, 0)
for key := range idSet {
keys = append(keys, key)
}
// sort id by score
big := func(i, j int) bool {
scoreItemI := idSet[keys[i]]
scoreItemJ := idSet[keys[j]]
if scoreItemI.accumulatedScore == scoreItemJ.accumulatedScore {
return compareKey(keys[i], keys[j])
}
return scoreItemI.accumulatedScore > scoreItemJ.accumulatedScore
}
sort.Slice(keys, big)
// separate keys into buckets according to groupVal
buckets := make(map[interface{}]*Group)
for _, key := range keys {
scoreItem := idSet[key]
groupVal := scoreItem.groupVal
if buckets[groupVal] == nil {
buckets[groupVal] = &Group{
idList: make([]interface{}, 0),
scoreList: make([]float32, 0),
groupVal: groupVal,
}
}
if int64(len(buckets[groupVal].idList)) >= groupSize {
// only consider group size results in each group
continue
}
buckets[groupVal].idList = append(buckets[groupVal].idList, key)
buckets[groupVal].scoreList = append(buckets[groupVal].scoreList, scoreItem.accumulatedScore)
if scoreItem.accumulatedScore > buckets[groupVal].maxScore {
buckets[groupVal].maxScore = scoreItem.accumulatedScore
}
buckets[groupVal].sumScore += scoreItem.accumulatedScore
}
if int64(len(buckets)) <= offset {
ret.Results.Topks = append(ret.Results.Topks, 0)
continue
}
groupList := make([]*Group, len(buckets))
idx := 0
for _, group := range buckets {
groupScorer(group)
groupList[idx] = group
idx += 1
}
sort.Slice(groupList, func(i, j int) bool {
if groupList[i].finalScore == groupList[j].finalScore {
if len(groupList[i].idList) == len(groupList[j].idList) {
// if final score and size of group are both equal
// choose the group with smaller first key
// here, it's guaranteed all group having at least one id in the idList
return compareKey(groupList[i].idList[0], groupList[j].idList[0])
}
// choose the larger group when scores are equal
return len(groupList[i].idList) > len(groupList[j].idList)
}
return groupList[i].finalScore > groupList[j].finalScore
})
if int64(len(groupList)) > groupTopK {
groupList = groupList[:groupTopK]
}
returnedRowNum := 0
for index := int(offset); index < len(groupList); index++ {
group := groupList[index]
for i, score := range group.scoreList {
// idList and scoreList must have same length
typeutil.AppendPKs(ret.Results.Ids, group.idList[i])
if roundDecimal != -1 {
multiplier := math.Pow(10.0, float64(roundDecimal))
score = float32(math.Floor(float64(score)*multiplier+0.5) / multiplier)
}
ret.Results.Scores = append(ret.Results.Scores, score)
gpFieldBuilder.Add(group.groupVal)
}
returnedRowNum += len(group.idList)
}
ret.Results.Topks = append(ret.Results.Topks, int64(returnedRowNum))
}
ret.Results.GroupByFieldValue = gpFieldBuilder.Build()
return ret, nil
}
func initSearchResults(nq int64, limit int64) *milvuspb.SearchResults {
return &milvuspb.SearchResults{
Status: merr.Success(),
Results: &schemapb.SearchResultData{
NumQueries: nq,
TopK: limit,
FieldsData: make([]*schemapb.FieldData, 0),
Scores: []float32{},
Ids: &schemapb.IDs{},
Topks: []int64{},
},
}
}
func setupIdListForSearchResult(searchResult *milvuspb.SearchResults, pkType schemapb.DataType, capacity int64) error {
switch pkType {
case schemapb.DataType_Int64:
@ -684,94 +460,6 @@ func setupIdListForSearchResult(searchResult *milvuspb.SearchResults, pkType sch
return nil
}
func rankSearchResultDataByPk(ctx context.Context,
nq int64,
params *rankParams,
pkType schemapb.DataType,
searchResults []*milvuspb.SearchResults,
) (*milvuspb.SearchResults, error) {
tr := timerecord.NewTimeRecorder("rankSearchResultDataByPk")
defer func() {
tr.CtxElapse(ctx, "done")
}()
offset, limit, roundDecimal := params.offset, params.limit, params.roundDecimal
topk := limit + offset
log.Ctx(ctx).Debug("rankSearchResultDataByPk",
zap.Int("len(searchResults)", len(searchResults)),
zap.Int64("nq", nq),
zap.Int64("offset", offset),
zap.Int64("limit", limit))
var ret *milvuspb.SearchResults
if ret = initSearchResults(nq, limit); len(searchResults) == 0 {
return ret, nil
}
if err := setupIdListForSearchResult(ret, pkType, limit); err != nil {
return ret, nil
}
// []map[id]score
accumulatedScores := make([]map[interface{}]float32, nq)
for i := int64(0); i < nq; i++ {
accumulatedScores[i] = make(map[interface{}]float32)
}
for _, result := range searchResults {
scores := result.GetResults().GetScores()
start := int64(0)
for i := int64(0); i < nq; i++ {
realTopk := result.GetResults().Topks[i]
for j := start; j < start+realTopk; j++ {
id := typeutil.GetPK(result.GetResults().GetIds(), j)
accumulatedScores[i][id] += scores[j]
}
start += realTopk
}
}
for i := int64(0); i < nq; i++ {
idSet := accumulatedScores[i]
keys := make([]interface{}, 0)
for key := range idSet {
keys = append(keys, key)
}
if int64(len(keys)) <= offset {
ret.Results.Topks = append(ret.Results.Topks, 0)
continue
}
// sort id by score
big := func(i, j int) bool {
if idSet[keys[i]] == idSet[keys[j]] {
return compareKey(keys[i], keys[j])
}
return idSet[keys[i]] > idSet[keys[j]]
}
sort.Slice(keys, big)
if int64(len(keys)) > topk {
keys = keys[:topk]
}
// set real topk
ret.Results.Topks = append(ret.Results.Topks, int64(len(keys))-offset)
// append id and score
for index := offset; index < int64(len(keys)); index++ {
typeutil.AppendPKs(ret.Results.Ids, keys[index])
score := idSet[keys[index]]
if roundDecimal != -1 {
multiplier := math.Pow(10.0, float64(roundDecimal))
score = float32(math.Floor(float64(score)*multiplier+0.5) / multiplier)
}
ret.Results.Scores = append(ret.Results.Scores, score)
}
}
return ret, nil
}
func fillInEmptyResult(numQueries int64) *milvuspb.SearchResults {
return &milvuspb.SearchResults{
Status: merr.Success("search result is empty"),

View File

@ -6,7 +6,6 @@ import (
"github.com/stretchr/testify/suite"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
)
@ -52,86 +51,6 @@ func genTestDataSearchResultsData() []*schemapb.SearchResultData {
return []*schemapb.SearchResultData{searchResultData1, searchResultData2}
}
func (struts *SearchReduceUtilTestSuite) TestRankByGroup() {
data := genTestDataSearchResultsData()
searchResults := []*milvuspb.SearchResults{
{Results: data[0]},
{Results: data[1]},
}
nq := int64(1)
limit := int64(3)
offset := int64(0)
roundDecimal := int64(1)
groupSize := int64(3)
groupByFieldId := int64(101)
rankParams := &rankParams{limit: limit, offset: offset, roundDecimal: roundDecimal}
{
// test for sum group scorer
scorerType := "sum"
groupScorer, _ := GetGroupScorer(scorerType)
rankedRes, err := rankSearchResultData(context.Background(), nq, rankParams, schemapb.DataType_VarChar, searchResults, groupByFieldId, groupSize, groupScorer)
struts.NoError(err)
struts.Equal([]string{"5", "2", "3", "17", "12", "13", "7", "15", "1"}, rankedRes.GetResults().GetIds().GetStrId().Data)
struts.Equal([]float32{0.5, 0.4, 0.4, 0.7, 0.3, 0.3, 0.6, 0.4, 0.3}, rankedRes.GetResults().GetScores())
struts.Equal([]string{"bbb", "bbb", "bbb", "www", "www", "www", "aaa", "aaa", "aaa"}, rankedRes.GetResults().GetGroupByFieldValue().GetScalars().GetStringData().Data)
}
{
// test for max group scorer
scorerType := "max"
groupScorer, _ := GetGroupScorer(scorerType)
rankedRes, err := rankSearchResultData(context.Background(), nq, rankParams, schemapb.DataType_VarChar, searchResults, groupByFieldId, groupSize, groupScorer)
struts.NoError(err)
struts.Equal([]string{"17", "12", "13", "7", "15", "1", "5", "2", "3"}, rankedRes.GetResults().GetIds().GetStrId().Data)
struts.Equal([]float32{0.7, 0.3, 0.3, 0.6, 0.4, 0.3, 0.5, 0.4, 0.4}, rankedRes.GetResults().GetScores())
struts.Equal([]string{"www", "www", "www", "aaa", "aaa", "aaa", "bbb", "bbb", "bbb"}, rankedRes.GetResults().GetGroupByFieldValue().GetScalars().GetStringData().Data)
}
{
// test for avg group scorer
scorerType := "avg"
groupScorer, _ := GetGroupScorer(scorerType)
rankedRes, err := rankSearchResultData(context.Background(), nq, rankParams, schemapb.DataType_VarChar, searchResults, groupByFieldId, groupSize, groupScorer)
struts.NoError(err)
struts.Equal([]string{"5", "2", "3", "17", "12", "13", "7", "15", "1"}, rankedRes.GetResults().GetIds().GetStrId().Data)
struts.Equal([]float32{0.5, 0.4, 0.4, 0.7, 0.3, 0.3, 0.6, 0.4, 0.3}, rankedRes.GetResults().GetScores())
struts.Equal([]string{"bbb", "bbb", "bbb", "www", "www", "www", "aaa", "aaa", "aaa"}, rankedRes.GetResults().GetGroupByFieldValue().GetScalars().GetStringData().Data)
}
{
// test for offset for ranking group
scorerType := "avg"
groupScorer, _ := GetGroupScorer(scorerType)
rankParams.offset = 2
rankedRes, err := rankSearchResultData(context.Background(), nq, rankParams, schemapb.DataType_VarChar, searchResults, groupByFieldId, groupSize, groupScorer)
struts.NoError(err)
struts.Equal([]string{"7", "15", "1", "4", "6", "14"}, rankedRes.GetResults().GetIds().GetStrId().Data)
struts.Equal([]float32{0.6, 0.4, 0.3, 0.5, 0.3, 0.3}, rankedRes.GetResults().GetScores())
struts.Equal([]string{"aaa", "aaa", "aaa", "ccc", "ccc", "ccc"}, rankedRes.GetResults().GetGroupByFieldValue().GetScalars().GetStringData().Data)
}
{
// test for offset exceeding the count of final groups
scorerType := "avg"
groupScorer, _ := GetGroupScorer(scorerType)
rankParams.offset = 4
rankedRes, err := rankSearchResultData(context.Background(), nq, rankParams, schemapb.DataType_VarChar, searchResults, groupByFieldId, groupSize, groupScorer)
struts.NoError(err)
struts.Equal([]string{}, rankedRes.GetResults().GetIds().GetStrId().Data)
struts.Equal([]float32{}, rankedRes.GetResults().GetScores())
}
{
// test for invalid group scorer
scorerType := "xxx"
groupScorer, err := GetGroupScorer(scorerType)
struts.Error(err)
struts.Nil(groupScorer)
}
}
func (struts *SearchReduceUtilTestSuite) TestReduceSearchResult() {
data := genTestDataSearchResultsData()

View File

@ -537,6 +537,14 @@ func parseRankParams(rankParamsPair []*commonpb.KeyValuePair, schema *schemapb.C
}, nil
}
func getGroupScorerStr(params []*commonpb.KeyValuePair) string {
groupScorerStr, err := funcutil.GetAttrByKeyFromRepeatedKV(RankGroupScorer, params)
if err != nil {
groupScorerStr = MaxScorer
}
return groupScorerStr
}
func convertHybridSearchToSearch(req *milvuspb.HybridSearchRequest) *milvuspb.SearchRequest {
ret := &milvuspb.SearchRequest{
Base: req.GetBase(),

View File

@ -88,10 +88,6 @@ type searchTask struct {
queryInfos []*planpb.QueryInfo
relatedDataSize int64
// Will be deprecated, use functionScore after milvus 2.6
reScorers []reScorer
groupScorer func(group *Group) error
// New reranker functions
functionScore *rerank.FunctionScore
rankParams *rankParams
@ -378,22 +374,10 @@ func (t *searchTask) initAdvancedSearchRequest(ctx context.Context) error {
return err
}
} else {
t.reScorers, err = NewReScorers(ctx, len(t.request.GetSubReqs()), t.request.GetSearchParams())
if err != nil {
log.Info("generate reScorer failed", zap.Any("params", t.request.GetSearchParams()), zap.Error(err))
if t.functionScore, err = rerank.NewFunctionScoreWithlegacy(t.schema.CollectionSchema, t.request.GetSearchParams()); err != nil {
log.Warn("Failed to create function by legacy info", zap.Error(err))
return err
}
// set up groupScorer for hybridsearch+groupBy
groupScorerStr, err := funcutil.GetAttrByKeyFromRepeatedKV(RankGroupScorer, t.request.GetSearchParams())
if err != nil {
groupScorerStr = MaxScorer
}
groupScorer, err := GetGroupScorer(groupScorerStr)
if err != nil {
return err
}
t.groupScorer = groupScorer
}
t.needRequery = len(t.request.OutputFields) > 0 || len(t.functionScore.GetAllInputFieldNames()) > 0
@ -544,22 +528,12 @@ func (t *searchTask) advancedPostProcess(ctx context.Context, span trace.Span, t
return err
}
if t.functionScore == nil {
t.reScorers[index].setMetricType(subMetricType)
t.reScorers[index].reScore(result)
}
searchMetrics = append(searchMetrics, subMetricType)
multipleMilvusResults[index] = result
}
if t.functionScore == nil {
if err := t.rank(ctx, span, multipleMilvusResults); err != nil {
return err
}
} else {
if err := t.hybridSearchRank(ctx, span, multipleMilvusResults, searchMetrics); err != nil {
return err
}
if err := t.hybridSearchRank(ctx, span, multipleMilvusResults, searchMetrics); err != nil {
return err
}
t.result.Results.FieldsData = lo.Filter(t.result.Results.FieldsData, func(fieldData *schemapb.FieldData, i int) bool {
@ -583,43 +557,6 @@ func (t *searchTask) fillResult() {
t.fillInFieldInfo()
}
// TODO: Old version rerank: rrf/weighted, subsequent unified rerank implementation
func (t *searchTask) rank(ctx context.Context, span trace.Span, multipleMilvusResults []*milvuspb.SearchResults) error {
primaryFieldSchema, err := t.schema.GetPkField()
if err != nil {
log.Warn("failed to get primary field schema", zap.Error(err))
return err
}
if t.result, err = rankSearchResultData(ctx, t.SearchRequest.GetNq(),
t.rankParams,
primaryFieldSchema.GetDataType(),
multipleMilvusResults,
t.SearchRequest.GetGroupByFieldId(),
t.SearchRequest.GetGroupSize(),
t.groupScorer); err != nil {
log.Warn("rank search result failed", zap.Error(err))
return err
}
if t.needRequery {
if t.requeryFunc == nil {
t.requeryFunc = requeryImpl
}
queryResult, err := t.requeryFunc(t, span, t.result.Results.Ids, t.translatedOutputFields)
if err != nil {
log.Warn("failed to requery", zap.Error(err))
return err
}
fields, err := t.reorganizeRequeryResults(ctx, queryResult.GetFieldsData(), []*schemapb.IDs{t.result.Results.Ids})
if err != nil {
return err
}
t.result.Results.FieldsData = fields[0]
}
return nil
}
func mergeIDs(idsList []*schemapb.IDs) (*schemapb.IDs, int) {
uniqueIDs := &schemapb.IDs{}
count := 0
@ -659,10 +596,10 @@ func (t *searchTask) hybridSearchRank(ctx context.Context, span trace.Span, mult
processRerank := func(ctx context.Context, results []*milvuspb.SearchResults) (*milvuspb.SearchResults, error) {
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-call-rerank-function-udf")
defer sp.End()
groupScorerStr := getGroupScorerStr(t.request.GetSearchParams())
params := rerank.NewSearchParams(
t.Nq, t.rankParams.limit, t.rankParams.offset, t.rankParams.roundDecimal,
t.rankParams.groupByFieldId, t.rankParams.groupSize, t.rankParams.strictGroupSize, searchMetrics,
t.rankParams.groupByFieldId, t.rankParams.groupSize, t.rankParams.strictGroupSize, groupScorerStr, searchMetrics,
)
return t.functionScore.Process(ctx, params, results)
}
@ -703,6 +640,7 @@ func (t *searchTask) hybridSearchRank(ctx context.Context, span trace.Span, mult
for i := 0; i < len(multipleMilvusResults); i++ {
multipleMilvusResults[i].Results.FieldsData = fields[i]
}
if t.result, err = processRerank(ctx, multipleMilvusResults); err != nil {
return err
}
@ -838,8 +776,9 @@ func (t *searchTask) searchPostProcess(ctx context.Context, span trace.Span, toR
{
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-call-rerank-function-udf")
defer sp.End()
groupScorerStr := getGroupScorerStr(t.request.GetSearchParams())
params := rerank.NewSearchParams(t.Nq, t.SearchRequest.GetTopk(), t.SearchRequest.GetOffset(),
t.queryInfos[0].RoundDecimal, t.queryInfos[0].GroupByFieldId, t.queryInfos[0].GroupSize, t.queryInfos[0].StrictGroupSize, []string{metricType})
t.queryInfos[0].RoundDecimal, t.queryInfos[0].GroupByFieldId, t.queryInfos[0].GroupSize, t.queryInfos[0].StrictGroupSize, groupScorerStr, []string{metricType})
// rank only returns id and score
if t.result, err = t.functionScore.Process(ctx, params, []*milvuspb.SearchResults{result}); err != nil {
return err

View File

@ -1000,7 +1000,7 @@ func TestSearchTask_PreExecute(t *testing.T) {
require.Equal(t, typeutil.ZeroTimestamp, st.TimeoutTimestamp)
enqueueTs := uint64(100000)
st.SetTs(enqueueTs)
assert.ErrorContains(t, st.PreExecute(ctx), "Current rerank does not support grouping search")
assert.NoError(t, st.PreExecute(ctx))
})
}

View File

@ -77,9 +77,6 @@ const (
// DefaultStringIndexType name of default index type for varChar/string field
DefaultStringIndexType = indexparamcheck.IndexINVERTED
defaultRRFParamsValue = 60
maxRRFParamsValue = 16384
)
var logger = log.L().WithOptions(zap.Fields(zap.String("role", typeutil.ProxyRole)))
@ -427,7 +424,6 @@ func validateMaxCapacityPerRow(collectionName string, field *schemapb.FieldSchem
if !exist {
return fmt.Errorf("type param(max_capacity) should be specified for array field %s of collection %s", field.GetName(), collectionName)
}
return nil
}

View File

@ -1,3 +1,6 @@
//go:build test
// +build test
/*
* # Licensed to the LF AI & Data foundation under one
* # or more contributor license agreements. See the NOTICE file
@ -22,11 +25,13 @@ import (
"context"
"encoding/json"
"io"
"math"
"net/http"
"net/http/httptest"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/util/function/models/ali"
"github.com/milvus-io/milvus/internal/util/function/models/cohere"
"github.com/milvus-io/milvus/internal/util/function/models/openai"
@ -34,6 +39,7 @@ import (
"github.com/milvus-io/milvus/internal/util/function/models/tei"
"github.com/milvus-io/milvus/internal/util/function/models/vertexai"
"github.com/milvus-io/milvus/internal/util/function/models/voyageai"
"github.com/milvus-io/milvus/pkg/v2/util/testutils"
)
const TestModel string = "TestModel"
@ -247,3 +253,56 @@ func (c *MockBedrockClient) InvokeModel(ctx context.Context, params *bedrockrunt
body, _ := json.Marshal(resp)
return &bedrockruntime.InvokeModelOutput{Body: body}, nil
}
func GenSearchResultData(nq int64, topk int64, dType schemapb.DataType, fieldName string, fieldId int64) *schemapb.SearchResultData {
tops := make([]int64, nq)
for i := 0; i < int(nq); i++ {
tops[i] = topk
}
fieldsData := []*schemapb.FieldData{}
if fieldName != "" {
fieldsData = []*schemapb.FieldData{testutils.GenerateScalarFieldData(dType, fieldName, int(nq*topk))}
fieldsData[0].FieldId = fieldId
}
data := &schemapb.SearchResultData{
NumQueries: nq,
TopK: topk,
Scores: testutils.GenerateFloat32Array(int(nq * topk)),
Ids: &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: testutils.GenerateInt64Array(int(nq * topk)),
},
},
},
Topks: tops,
FieldsData: fieldsData,
}
return data
}
func GenSearchResultDataWithGrouping(nq int64, topk int64, dType schemapb.DataType, fieldName string, fieldId int64, groupingName string, groupingId int64, groupSize int64) *schemapb.SearchResultData {
data := GenSearchResultData(nq, topk*groupSize, dType, fieldName, fieldId)
values := make([]int64, 0)
for i := int64(0); i < nq*topk*groupSize; i += groupSize {
for j := int64(0); j < groupSize; j++ {
values = append(values, i)
}
}
groupingField := testutils.GenerateScalarFieldDataWithValue(schemapb.DataType_Int64, groupingName, groupingId, values)
data.GroupByFieldValue = groupingField
return data
}
func FloatsAlmostEqual(a, b []float32, epsilon float32) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if float32(math.Abs(float64(a[i]-b[i]))) > epsilon {
return false
}
}
return true
}

View File

@ -55,7 +55,7 @@ type DecayFunction[T PKType, R int32 | int64 | float32 | float64] struct {
}
func newDecayFunction(collSchema *schemapb.CollectionSchema, funcSchema *schemapb.FunctionSchema) (Reranker, error) {
base, err := newRerankBase(collSchema, funcSchema, decayFunctionName, false)
base, err := newRerankBase(collSchema, funcSchema, decayFunctionName, true)
if err != nil {
return nil, err
}
@ -168,7 +168,7 @@ func toGreaterScore(score float32, metricType string) float32 {
}
}
func (decay *DecayFunction[T, R]) processOneSearchData(ctx context.Context, searchParams *SearchParams, cols []*columns) *IDScores[T] {
func (decay *DecayFunction[T, R]) processOneSearchData(ctx context.Context, searchParams *SearchParams, cols []*columns, idGroup map[any]any) (*IDScores[T], error) {
srcScores := maxMerge[T](cols)
decayScores := map[T]float32{}
for _, col := range cols {
@ -186,7 +186,10 @@ func (decay *DecayFunction[T, R]) processOneSearchData(ctx context.Context, sear
for id := range decayScores {
decayScores[id] = decayScores[id] * srcScores[id]
}
return newIDScores(decayScores, searchParams)
if searchParams.isGrouping() {
return newGroupingIDScores(decayScores, searchParams, idGroup)
}
return newIDScores(decayScores, searchParams), nil
}
func (decay *DecayFunction[T, R]) Process(ctx context.Context, searchParams *SearchParams, inputs *rerankInputs) (*rerankOutputs, error) {
@ -198,7 +201,10 @@ func (decay *DecayFunction[T, R]) Process(ctx context.Context, searchParams *Sea
col.scores[j] = toGreaterScore(score, metricType)
}
}
idScore := decay.processOneSearchData(ctx, searchParams, cols)
idScore, err := decay.processOneSearchData(ctx, searchParams, cols, inputs.idGroupValue)
if err != nil {
return nil, err
}
appendResult(outputs, idScore.ids, idScore.scores)
}
return outputs, nil

View File

@ -27,7 +27,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/pkg/v2/util/testutils"
"github.com/milvus-io/milvus/internal/util/function"
)
func TestDecayFunction(t *testing.T) {
@ -260,8 +260,8 @@ func (s *DecayFunctionSuite) TestRerankProcess() {
nq := int64(1)
f, err := newDecayFunction(schema, functionSchema)
s.NoError(err)
inputs, _ := newRerankInputs([]*schemapb.SearchResultData{}, f.GetInputFieldIDs())
ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false, []string{"COSINE"}}, inputs)
inputs, _ := newRerankInputs([]*schemapb.SearchResultData{}, f.GetInputFieldIDs(), false)
ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE"}), inputs)
s.NoError(err)
s.Equal(int64(3), ret.searchResultData.TopK)
s.Equal([]int64{}, ret.searchResultData.Topks)
@ -271,10 +271,10 @@ func (s *DecayFunctionSuite) TestRerankProcess() {
{
nq := int64(1)
f, err := newDecayFunction(schema, functionSchema)
data := genSearchResultData(nq, 10, schemapb.DataType_Int64, "noExist", 1000)
data := function.GenSearchResultData(nq, 10, schemapb.DataType_Int64, "noExist", 1000)
s.NoError(err)
_, err = newRerankInputs([]*schemapb.SearchResultData{data}, f.GetInputFieldIDs())
_, err = newRerankInputs([]*schemapb.SearchResultData{data}, f.GetInputFieldIDs(), false)
s.ErrorContains(err, "Search reaults mismatch rerank inputs")
}
@ -289,9 +289,9 @@ func (s *DecayFunctionSuite) TestRerankProcess() {
nq := int64(1)
f, err := newDecayFunction(schema, functionSchema)
s.NoError(err)
data := genSearchResultData(nq, 10, schemapb.DataType_Int64, "ts", 102)
inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data}, f.GetInputFieldIDs())
ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false, []string{"COSINE"}}, inputs)
data := function.GenSearchResultData(nq, 10, schemapb.DataType_Int64, "ts", 102)
inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data}, f.GetInputFieldIDs(), false)
ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE"}), inputs)
s.NoError(err)
s.Equal([]int64{3}, ret.searchResultData.Topks)
s.Equal(int64(3), ret.searchResultData.TopK)
@ -302,9 +302,9 @@ func (s *DecayFunctionSuite) TestRerankProcess() {
nq := int64(3)
f, err := newDecayFunction(schema, functionSchema)
s.NoError(err)
data := genSearchResultData(nq, 10, schemapb.DataType_Int64, "ts", 102)
inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data}, f.GetInputFieldIDs())
ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false, []string{"COSINE"}}, inputs)
data := function.GenSearchResultData(nq, 10, schemapb.DataType_Int64, "ts", 102)
inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data}, f.GetInputFieldIDs(), false)
ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE"}), inputs)
s.NoError(err)
s.Equal([]int64{3, 3, 3}, ret.searchResultData.Topks)
s.Equal(int64(3), ret.searchResultData.TopK)
@ -329,11 +329,11 @@ func (s *DecayFunctionSuite) TestRerankProcess() {
f, err := newDecayFunction(schema, functionSchema2)
s.NoError(err)
// ts/id data: 0 - 9
data1 := genSearchResultData(nq, 10, schemapb.DataType_Int64, "ts", 102)
data1 := function.GenSearchResultData(nq, 10, schemapb.DataType_Int64, "ts", 102)
// empty
data2 := genSearchResultData(nq, 0, schemapb.DataType_Int64, "ts", 102)
inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data1, data2}, f.GetInputFieldIDs())
ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false, []string{"COSINE", "COSINE"}}, inputs)
data2 := function.GenSearchResultData(nq, 0, schemapb.DataType_Int64, "ts", 102)
inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data1, data2}, f.GetInputFieldIDs(), false)
ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE", "COSINE"}), inputs)
s.NoError(err)
s.Equal([]int64{3}, ret.searchResultData.Topks)
s.Equal(int64(3), ret.searchResultData.TopK)
@ -345,11 +345,12 @@ func (s *DecayFunctionSuite) TestRerankProcess() {
f, err := newDecayFunction(schema, functionSchema2)
s.NoError(err)
// ts/id data: 0 - 9
data1 := genSearchResultData(nq, 10, schemapb.DataType_Int64, "ts", 102)
data1 := function.GenSearchResultData(nq, 10, schemapb.DataType_Int64, "ts", 102)
// ts/id data: 0 - 3
data2 := genSearchResultData(nq, 4, schemapb.DataType_Int64, "ts", 102)
inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data1, data2}, f.GetInputFieldIDs())
ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false, []string{"COSINE", "COSINE"}}, inputs)
data2 := function.GenSearchResultData(nq, 4, schemapb.DataType_Int64, "ts", 102)
inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data1, data2}, f.GetInputFieldIDs(), false)
ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE", "COSINE"}), inputs)
s.NoError(err)
s.Equal([]int64{3}, ret.searchResultData.Topks)
s.Equal(int64(3), ret.searchResultData.TopK)
@ -363,13 +364,13 @@ func (s *DecayFunctionSuite) TestRerankProcess() {
// nq1 ts/id data: 0 - 9
// nq2 ts/id data: 10 - 19
// nq3 ts/id data: 20 - 29
data1 := genSearchResultData(nq, 10, schemapb.DataType_Int64, "ts", 102)
data1 := function.GenSearchResultData(nq, 10, schemapb.DataType_Int64, "ts", 102)
// nq1 ts/id data: 0 - 3
// nq2 ts/id data: 4 - 7
// nq3 ts/id data: 8 - 11
data2 := genSearchResultData(nq, 4, schemapb.DataType_Int64, "ts", 102)
inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data1, data2}, f.GetInputFieldIDs())
ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, 1, -1, 1, false, []string{"COSINE", "COSINE"}}, inputs)
data2 := function.GenSearchResultData(nq, 4, schemapb.DataType_Int64, "ts", 102)
inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data1, data2}, f.GetInputFieldIDs(), false)
ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, 1, -1, 1, false, "", []string{"COSINE", "COSINE"}), inputs)
s.NoError(err)
s.Equal([]int64{3, 3, 3}, ret.searchResultData.Topks)
s.Equal(int64(3), ret.searchResultData.TopK)
@ -390,26 +391,3 @@ func (s *DecayFunctionSuite) TestDecay() {
s.Equal(linearDecay(0, 1, 0.5, 5, 5), 1.0)
s.Less(linearDecay(0, 1, 0.5, 5, 6), 1.0)
}
func genSearchResultData(nq int64, topk int64, dType schemapb.DataType, fieldName string, fieldId int64) *schemapb.SearchResultData {
tops := make([]int64, nq)
for i := 0; i < int(nq); i++ {
tops[i] = topk
}
data := &schemapb.SearchResultData{
NumQueries: nq,
TopK: topk,
Scores: testutils.GenerateFloat32Array(int(nq * topk)),
Ids: &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: testutils.GenerateInt64Array(int(nq * topk)),
},
},
},
Topks: tops,
FieldsData: []*schemapb.FieldData{testutils.GenerateScalarFieldData(dType, fieldName, int(nq*topk))},
}
data.FieldsData[0].FieldId = fieldId
return data
}

View File

@ -20,38 +20,79 @@ package rerank
import (
"context"
"encoding/json"
"fmt"
"reflect"
"strconv"
"strings"
"github.com/samber/lo"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
)
const (
decayFunctionName string = "decay"
modelFunctionName string = "model"
rrfName string = "rrf"
weightedName string = "weighted"
)
const (
maxScorer string = "max"
sumScorer string = "sum"
avgScorer string = "avg"
)
// legacy rrf/weighted rerank configs
const (
legacyRankTypeKey = "strategy"
legacyRankParamsKey = "params"
)
type rankType int
const (
invalidRankType rankType = iota // invalidRankType = 0
rrfRankType // rrfRankType = 1
weightedRankType // weightedRankType = 2
)
var rankTypeMap = map[string]rankType{
"invalid": invalidRankType,
"rrf": rrfRankType,
"weighted": weightedRankType,
}
type SearchParams struct {
nq int64
limit int64
offset int64
roundDecimal int64
// TODO: supports group search
groupByFieldId int64
groupSize int64
strictGroupSize bool
groupScore string
searchMetrics []string
}
func NewSearchParams(nq, limit, offset, roundDecimal, groupByFieldId, groupSize int64, strictGroupSize bool, searchMetrics []string) *SearchParams {
func (s *SearchParams) isGrouping() bool {
return s.groupByFieldId > 0
}
func NewSearchParams(nq, limit, offset, roundDecimal, groupByFieldId, groupSize int64, strictGroupSize bool, groupScore string, searchMetrics []string) *SearchParams {
if groupScore == "" {
groupScore = maxScorer
}
return &SearchParams{
nq, limit, offset, roundDecimal, groupByFieldId, groupSize, strictGroupSize, searchMetrics,
nq, limit, offset, roundDecimal, groupByFieldId, groupSize, strictGroupSize, groupScore, searchMetrics,
}
}
@ -95,6 +136,10 @@ func createFunction(collSchema *schemapb.CollectionSchema, funcSchema *schemapb.
rerankFunc, newRerankErr = newDecayFunction(collSchema, funcSchema)
case modelFunctionName:
rerankFunc, newRerankErr = newModelFunction(collSchema, funcSchema)
case rrfName:
rerankFunc, newRerankErr = newRRFFunction(collSchema, funcSchema)
case weightedName:
rerankFunc, newRerankErr = newWeightedFunction(collSchema, funcSchema)
default:
return nil, fmt.Errorf("Unsupported rerank function: [%s] , list of supported [%s,%s]", rerankerName, decayFunctionName, modelFunctionName)
}
@ -117,6 +162,68 @@ func NewFunctionScore(collSchema *schemapb.CollectionSchema, funcScoreSchema *sc
return funcScore, nil
}
func NewFunctionScoreWithlegacy(collSchema *schemapb.CollectionSchema, rankParams []*commonpb.KeyValuePair) (*FunctionScore, error) {
var params map[string]interface{}
rankTypeStr, err := funcutil.GetAttrByKeyFromRepeatedKV(legacyRankTypeKey, rankParams)
if err != nil {
rankTypeStr = "rrf"
params = make(map[string]interface{}, 0)
} else {
if _, ok := rankTypeMap[rankTypeStr]; !ok {
return nil, fmt.Errorf("unsupported rank type %s", rankTypeStr)
}
paramStr, err := funcutil.GetAttrByKeyFromRepeatedKV(legacyRankParamsKey, rankParams)
if err != nil {
return nil, fmt.Errorf("params" + " not found in rank_params")
}
err = json.Unmarshal([]byte(paramStr), &params)
if err != nil {
return nil, fmt.Errorf("Parse rerank params failed, err: %s", err)
}
}
fSchema := schemapb.FunctionSchema{
Type: schemapb.FunctionType_Rerank,
InputFieldNames: []string{},
OutputFieldNames: []string{},
Params: []*commonpb.KeyValuePair{},
}
switch rankTypeMap[rankTypeStr] {
case rrfRankType:
fSchema.Params = append(fSchema.Params, &commonpb.KeyValuePair{Key: reranker, Value: rrfName})
if v, ok := params[RRFParamsKey]; ok {
if reflect.ValueOf(params[RRFParamsKey]).CanFloat() {
k := reflect.ValueOf(v).Float()
fSchema.Params = append(fSchema.Params, &commonpb.KeyValuePair{Key: RRFParamsKey, Value: strconv.FormatFloat(k, 'f', -1, 64)})
} else {
return nil, fmt.Errorf("The type of rank param k should be float")
}
}
case weightedRankType:
fSchema.Params = append(fSchema.Params, &commonpb.KeyValuePair{Key: reranker, Value: weightedName})
if v, ok := params[WeightsParamsKey]; ok {
if d, err := json.Marshal(v); err != nil {
return nil, fmt.Errorf("The weights param should be an array")
} else {
fSchema.Params = append(fSchema.Params, &commonpb.KeyValuePair{Key: WeightsParamsKey, Value: string(d)})
}
}
if normScore, ok := params[NormScoreKey]; ok {
if ns, ok := normScore.(bool); ok {
fSchema.Params = append(fSchema.Params, &commonpb.KeyValuePair{Key: NormScoreKey, Value: strconv.FormatBool(ns)})
} else {
return nil, fmt.Errorf("Weighted rerank err, norm_score should been bool type, but [norm_score:%s]'s type is %T", normScore, normScore)
}
}
default:
return nil, fmt.Errorf("unsupported rank type %s", rankTypeStr)
}
funcScore := &FunctionScore{}
if funcScore.reranker, err = createFunction(collSchema, &fSchema); err != nil {
return nil, err
}
return funcScore, nil
}
func (fScore *FunctionScore) Process(ctx context.Context, searchParams *SearchParams, multipleMilvusResults []*milvuspb.SearchResults) (*milvuspb.SearchResults, error) {
if len(multipleMilvusResults) == 0 {
return &milvuspb.SearchResults{
@ -137,7 +244,7 @@ func (fScore *FunctionScore) Process(ctx context.Context, searchParams *SearchPa
})
// rankResult only has scores
inputs, err := newRerankInputs(allSearchResultData, fScore.reranker.GetInputFieldIDs())
inputs, err := newRerankInputs(allSearchResultData, fScore.reranker.GetInputFieldIDs(), searchParams.isGrouping())
if err != nil {
return nil, err
}

View File

@ -20,6 +20,7 @@ package rerank
import (
"context"
"math"
"testing"
"github.com/stretchr/testify/suite"
@ -27,6 +28,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/util/function"
)
func TestFunctionScore(t *testing.T) {
@ -75,7 +77,7 @@ func (s *FunctionScoreSuite) TestNewFunctionScore() {
s.NoError(err)
s.Equal([]string{"ts"}, f.GetAllInputFieldNames())
s.Equal([]int64{102}, f.GetAllInputFieldIDs())
s.Equal(false, f.IsSupportGroup())
s.Equal(true, f.IsSupportGroup())
s.Equal("decay", f.reranker.GetRankName())
{
@ -152,7 +154,7 @@ func (s *FunctionScoreSuite) TestFunctionScoreProcess() {
// empty inputs
{
nq := int64(1)
ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false, []string{"COSINE"}}, []*milvuspb.SearchResults{})
ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE"}), []*milvuspb.SearchResults{})
s.NoError(err)
s.Equal(int64(3), ret.Results.TopK)
s.Equal(0, len(ret.Results.FieldsData))
@ -162,11 +164,11 @@ func (s *FunctionScoreSuite) TestFunctionScoreProcess() {
// nq = 1
{
nq := int64(1)
data := genSearchResultData(nq, 10, schemapb.DataType_Int64, "ts", 102)
data := function.GenSearchResultData(nq, 10, schemapb.DataType_Int64, "ts", 102)
searchData := &milvuspb.SearchResults{
Results: data,
}
ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false, []string{"COSINE"}}, []*milvuspb.SearchResults{searchData})
ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE"}), []*milvuspb.SearchResults{searchData})
s.NoError(err)
s.Equal(int64(3), ret.Results.TopK)
s.Equal([]int64{3}, ret.Results.Topks)
@ -174,11 +176,11 @@ func (s *FunctionScoreSuite) TestFunctionScoreProcess() {
// nq=1, input is empty
{
nq := int64(1)
data := genSearchResultData(nq, 0, schemapb.DataType_Int64, "ts", 102)
data := function.GenSearchResultData(nq, 0, schemapb.DataType_Int64, "ts", 102)
searchData := &milvuspb.SearchResults{
Results: data,
}
ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false, []string{"COSINE"}}, []*milvuspb.SearchResults{searchData})
ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE"}), []*milvuspb.SearchResults{searchData})
s.NoError(err)
s.Equal(int64(3), ret.Results.TopK)
s.Equal([]int64{0}, ret.Results.Topks)
@ -186,11 +188,11 @@ func (s *FunctionScoreSuite) TestFunctionScoreProcess() {
// nq=3
{
nq := int64(3)
data := genSearchResultData(nq, 10, schemapb.DataType_Int64, "ts", 102)
data := function.GenSearchResultData(nq, 10, schemapb.DataType_Int64, "ts", 102)
searchData := &milvuspb.SearchResults{
Results: data,
}
ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false, []string{"COSINE"}}, []*milvuspb.SearchResults{searchData})
ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE"}), []*milvuspb.SearchResults{searchData})
s.NoError(err)
s.Equal(int64(3), ret.Results.TopK)
s.Equal([]int64{3, 3, 3}, ret.Results.Topks)
@ -198,11 +200,11 @@ func (s *FunctionScoreSuite) TestFunctionScoreProcess() {
// nq=3, all input is empty
{
nq := int64(3)
data := genSearchResultData(nq, 0, schemapb.DataType_Int64, "ts", 102)
data := function.GenSearchResultData(nq, 0, schemapb.DataType_Int64, "ts", 102)
searchData := &milvuspb.SearchResults{
Results: data,
}
ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false, []string{"COSINE"}}, []*milvuspb.SearchResults{searchData})
ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE"}), []*milvuspb.SearchResults{searchData})
s.NoError(err)
s.Equal(int64(3), ret.Results.TopK)
s.Equal([]int64{0, 0, 0}, ret.Results.Topks)
@ -213,13 +215,13 @@ func (s *FunctionScoreSuite) TestFunctionScoreProcess() {
{
nq := int64(1)
searchData1 := &milvuspb.SearchResults{
Results: genSearchResultData(nq, 10, schemapb.DataType_Int64, "ts", 102),
Results: function.GenSearchResultData(nq, 10, schemapb.DataType_Int64, "ts", 102),
}
searchData2 := &milvuspb.SearchResults{
Results: genSearchResultData(nq, 20, schemapb.DataType_Int64, "ts", 102),
Results: function.GenSearchResultData(nq, 20, schemapb.DataType_Int64, "ts", 102),
}
ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false, []string{"COSINE", "COSINE"}}, []*milvuspb.SearchResults{searchData1, searchData2})
ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE", "COSINE"}), []*milvuspb.SearchResults{searchData1, searchData2})
s.NoError(err)
s.Equal(int64(3), ret.Results.TopK)
s.Equal([]int64{3}, ret.Results.Topks)
@ -228,13 +230,13 @@ func (s *FunctionScoreSuite) TestFunctionScoreProcess() {
{
nq := int64(1)
searchData1 := &milvuspb.SearchResults{
Results: genSearchResultData(nq, 0, schemapb.DataType_Int64, "ts", 102),
Results: function.GenSearchResultData(nq, 0, schemapb.DataType_Int64, "ts", 102),
}
searchData2 := &milvuspb.SearchResults{
Results: genSearchResultData(nq, 0, schemapb.DataType_Int64, "ts", 102),
Results: function.GenSearchResultData(nq, 0, schemapb.DataType_Int64, "ts", 102),
}
ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false, []string{"COSINE", "COSINE"}}, []*milvuspb.SearchResults{searchData1, searchData2})
ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE", "COSINE"}), []*milvuspb.SearchResults{searchData1, searchData2})
s.NoError(err)
s.Equal(int64(3), ret.Results.TopK)
s.Equal([]int64{0}, ret.Results.Topks)
@ -243,13 +245,13 @@ func (s *FunctionScoreSuite) TestFunctionScoreProcess() {
{
nq := int64(1)
searchData1 := &milvuspb.SearchResults{
Results: genSearchResultData(nq, 0, schemapb.DataType_Int64, "ts", 102),
Results: function.GenSearchResultData(nq, 0, schemapb.DataType_Int64, "ts", 102),
}
searchData2 := &milvuspb.SearchResults{
Results: genSearchResultData(nq, 10, schemapb.DataType_Int64, "ts", 102),
Results: function.GenSearchResultData(nq, 10, schemapb.DataType_Int64, "ts", 102),
}
ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false, []string{"COSINE", "COSINE"}}, []*milvuspb.SearchResults{searchData1, searchData2})
ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE", "COSINE"}), []*milvuspb.SearchResults{searchData1, searchData2})
s.NoError(err)
s.Equal(int64(3), ret.Results.TopK)
s.Equal([]int64{3}, ret.Results.Topks)
@ -258,13 +260,13 @@ func (s *FunctionScoreSuite) TestFunctionScoreProcess() {
{
nq := int64(3)
searchData1 := &milvuspb.SearchResults{
Results: genSearchResultData(nq, 10, schemapb.DataType_Int64, "ts", 102),
Results: function.GenSearchResultData(nq, 10, schemapb.DataType_Int64, "ts", 102),
}
searchData2 := &milvuspb.SearchResults{
Results: genSearchResultData(nq, 20, schemapb.DataType_Int64, "ts", 102),
Results: function.GenSearchResultData(nq, 20, schemapb.DataType_Int64, "ts", 102),
}
ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false, []string{"COSINE", "COSINE"}}, []*milvuspb.SearchResults{searchData1, searchData2})
ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE", "COSINE"}), []*milvuspb.SearchResults{searchData1, searchData2})
s.NoError(err)
s.Equal(int64(3), ret.Results.TopK)
s.Equal([]int64{3, 3, 3}, ret.Results.Topks)
@ -273,13 +275,13 @@ func (s *FunctionScoreSuite) TestFunctionScoreProcess() {
{
nq := int64(3)
searchData1 := &milvuspb.SearchResults{
Results: genSearchResultData(nq, 0, schemapb.DataType_Int64, "ts", 102),
Results: function.GenSearchResultData(nq, 0, schemapb.DataType_Int64, "ts", 102),
}
searchData2 := &milvuspb.SearchResults{
Results: genSearchResultData(nq, 0, schemapb.DataType_Int64, "ts", 102),
Results: function.GenSearchResultData(nq, 0, schemapb.DataType_Int64, "ts", 102),
}
ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false, []string{"COSINE", "COSINE"}}, []*milvuspb.SearchResults{searchData1, searchData2})
ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE", "COSINE"}), []*milvuspb.SearchResults{searchData1, searchData2})
s.NoError(err)
s.Equal(int64(3), ret.Results.TopK)
s.Equal([]int64{0, 0, 0}, ret.Results.Topks)
@ -288,15 +290,131 @@ func (s *FunctionScoreSuite) TestFunctionScoreProcess() {
{
nq := int64(3)
searchData1 := &milvuspb.SearchResults{
Results: genSearchResultData(nq, 0, schemapb.DataType_Int64, "ts", 102),
Results: function.GenSearchResultData(nq, 0, schemapb.DataType_Int64, "ts", 102),
}
searchData2 := &milvuspb.SearchResults{
Results: genSearchResultData(nq, 10, schemapb.DataType_Int64, "ts", 102),
Results: function.GenSearchResultData(nq, 10, schemapb.DataType_Int64, "ts", 102),
}
ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false, []string{"COSINE", "COSINE"}}, []*milvuspb.SearchResults{searchData1, searchData2})
ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE", "COSINE"}), []*milvuspb.SearchResults{searchData1, searchData2})
s.NoError(err)
s.Equal(int64(3), ret.Results.TopK)
s.Equal([]int64{3, 3, 3}, ret.Results.Topks)
}
}
func (s *FunctionScoreSuite) TestlegacyFunction() {
schema := &schemapb.CollectionSchema{
Name: "test",
Fields: []*schemapb.FieldSchema{
{FieldID: 100, Name: "pk", DataType: schemapb.DataType_Int64, IsPrimaryKey: true},
{FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar},
{
FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{
{Key: "dim", Value: "4"},
},
},
{FieldID: 102, Name: "ts", DataType: schemapb.DataType_Int64},
},
}
{
rankParams := []*commonpb.KeyValuePair{}
f, err := NewFunctionScoreWithlegacy(schema, rankParams)
s.NoError(err)
s.Equal(f.reranker.GetRankName(), rrfName)
}
{
rankParams := []*commonpb.KeyValuePair{
{Key: legacyRankTypeKey, Value: "invalid"},
{Key: legacyRankParamsKey, Value: `{"k": "v"}`},
}
_, err := NewFunctionScoreWithlegacy(schema, rankParams)
s.ErrorContains(err, "unsupported rank type")
}
{
rankParams := []*commonpb.KeyValuePair{
{Key: legacyRankTypeKey, Value: "rrf"},
{Key: legacyRankParamsKey, Value: "invalid"},
}
_, err := NewFunctionScoreWithlegacy(schema, rankParams)
s.ErrorContains(err, "Parse rerank params failed")
}
{
rankParams := []*commonpb.KeyValuePair{
{Key: legacyRankTypeKey, Value: "rrf"},
{Key: legacyRankParamsKey, Value: `{"k": "invalid"}`},
}
_, err := NewFunctionScoreWithlegacy(schema, rankParams)
s.ErrorContains(err, "The type of rank param k should be float")
}
{
rankParams := []*commonpb.KeyValuePair{
{Key: legacyRankTypeKey, Value: "rrf"},
{Key: legacyRankParamsKey, Value: `{"k": 1.0}`},
}
_, err := NewFunctionScoreWithlegacy(schema, rankParams)
s.NoError(err)
}
{
rankParams := []*commonpb.KeyValuePair{
{Key: legacyRankTypeKey, Value: "weighted"},
{Key: legacyRankParamsKey, Value: `{"weights": [1.0]}`},
}
f, err := NewFunctionScoreWithlegacy(schema, rankParams)
s.NoError(err)
s.Equal(f.reranker.GetRankName(), weightedName)
}
{
rankParams := []*commonpb.KeyValuePair{
{Key: legacyRankTypeKey, Value: "weighted"},
{Key: legacyRankParamsKey, Value: `{"weights": [1.0], "norm_score": "Invalid"}`},
}
_, err := NewFunctionScoreWithlegacy(schema, rankParams)
s.ErrorContains(err, "Weighted rerank err, norm_score should been bool type")
}
{
rankParams := []*commonpb.KeyValuePair{
{Key: legacyRankTypeKey, Value: "weighted"},
{Key: legacyRankParamsKey, Value: `{"weights": [1.0], "norm_score": false}`},
}
_, err := NewFunctionScoreWithlegacy(schema, rankParams)
s.NoError(err)
}
{
rankParams := []*commonpb.KeyValuePair{
{Key: legacyRankTypeKey, Value: "weighted"},
{Key: legacyRankParamsKey, Value: `{"weights": [1.0], "norm_score": "false"}`},
}
_, err := NewFunctionScoreWithlegacy(schema, rankParams)
s.ErrorContains(err, "Weighted rerank err, norm_score should been bool type")
}
}
func (s *FunctionScoreSuite) TestFunctionUtil() {
g1 := &Group[int64]{
idList: []int64{1, 2, 3},
scoreList: []float32{1.0, 2.0, 3.0},
groupVal: 3,
maxScore: 3.0,
sumScore: 6.0,
}
s1, err := groupScore(g1, maxScorer)
s.NoError(err)
s.True(math.Abs(float64(s1-3.0)) < 0.001)
s2, err := groupScore(g1, sumScorer)
s.NoError(err)
s.True(math.Abs(float64(s2-6.0)) < 0.001)
s3, err := groupScore(g1, avgScorer)
s.NoError(err)
s.True(math.Abs(float64(s3-2.0)) < 0.001)
_, err = groupScore(g1, "NotSupported")
s.ErrorContains(err, "is not supported")
g1.idList = []int64{}
_, err = groupScore(g1, avgScorer)
s.ErrorContains(err, "input group for score must have at least one id, must be sth wrong within code")
}

View File

@ -296,7 +296,7 @@ type ModelFunction[T PKType] struct {
}
func newModelFunction(collSchema *schemapb.CollectionSchema, funcSchema *schemapb.FunctionSchema) (Reranker, error) {
base, err := newRerankBase(collSchema, funcSchema, decayFunctionName, false)
base, err := newRerankBase(collSchema, funcSchema, decayFunctionName, true)
if err != nil {
return nil, err
}
@ -333,7 +333,7 @@ func newModelFunction(collSchema *schemapb.CollectionSchema, funcSchema *schemap
}
}
func (model *ModelFunction[T]) processOneSearchData(ctx context.Context, searchParams *SearchParams, query string, cols []*columns) (*IDScores[T], error) {
func (model *ModelFunction[T]) processOneSearchData(ctx context.Context, searchParams *SearchParams, query string, cols []*columns, idGroup map[any]any) (*IDScores[T], error) {
uniqueData := make(map[T]string)
for _, col := range cols {
texts := col.data[0].([]string)
@ -359,6 +359,9 @@ func (model *ModelFunction[T]) processOneSearchData(ctx context.Context, searchP
for idx, id := range ids {
rerankScores[id] = scores[idx]
}
if searchParams.isGrouping() {
return newGroupingIDScores(rerankScores, searchParams, idGroup)
}
return newIDScores(rerankScores, searchParams), nil
}
@ -368,7 +371,7 @@ func (model *ModelFunction[T]) Process(ctx context.Context, searchParams *Search
}
outputs := newRerankOutputs(searchParams)
for idx, cols := range inputs.data {
idScore, err := model.processOneSearchData(ctx, searchParams, model.queries[idx], cols)
idScore, err := model.processOneSearchData(ctx, searchParams, model.queries[idx], cols, inputs.idGroupValue)
if err != nil {
return nil, err
}

View File

@ -375,8 +375,8 @@ func (s *RerankModelSuite) TestRerankProcess() {
nq := int64(1)
f, err := newModelFunction(schema, functionSchema)
s.NoError(err)
inputs, _ := newRerankInputs([]*schemapb.SearchResultData{}, f.GetInputFieldIDs())
ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false, []string{"COSINE"}}, inputs)
inputs, _ := newRerankInputs([]*schemapb.SearchResultData{}, f.GetInputFieldIDs(), false)
ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE"}), inputs)
s.NoError(err)
s.Equal(int64(3), ret.searchResultData.TopK)
s.Equal([]int64{}, ret.searchResultData.Topks)
@ -386,10 +386,10 @@ func (s *RerankModelSuite) TestRerankProcess() {
{
nq := int64(1)
f, err := newModelFunction(schema, functionSchema)
data := genSearchResultData(nq, 10, schemapb.DataType_Int64, "noExist", 1000)
data := function.GenSearchResultData(nq, 10, schemapb.DataType_Int64, "noExist", 1000)
s.NoError(err)
_, err = newRerankInputs([]*schemapb.SearchResultData{data}, f.GetInputFieldIDs())
_, err = newRerankInputs([]*schemapb.SearchResultData{data}, f.GetInputFieldIDs(), false)
s.ErrorContains(err, "Search reaults mismatch rerank inputs")
}
}
@ -430,18 +430,18 @@ func (s *RerankModelSuite) TestRerankProcess() {
{
f, err := newModelFunction(schema, functionSchema)
s.NoError(err)
data := genSearchResultData(nq, 10, schemapb.DataType_VarChar, "text", 101)
inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data}, f.GetInputFieldIDs())
_, err = f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false, []string{"COSINE"}}, inputs)
data := function.GenSearchResultData(nq, 10, schemapb.DataType_VarChar, "text", 101)
inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data}, f.GetInputFieldIDs(), false)
_, err = f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE"}), inputs)
s.ErrorContains(err, "nq must equal to queries size, but got nq [1], queries size [2]")
}
{
functionSchema.Params[2].Value = `["q1"]`
f, err := newModelFunction(schema, functionSchema)
s.NoError(err)
data := genSearchResultData(nq, 10, schemapb.DataType_VarChar, "text", 101)
inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data}, f.GetInputFieldIDs())
ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 0, -1, -1, 1, false, []string{"COSINE"}}, inputs)
data := function.GenSearchResultData(nq, 10, schemapb.DataType_VarChar, "text", 101)
inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data}, f.GetInputFieldIDs(), false)
ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 0, -1, -1, 1, false, "", []string{"COSINE"}), inputs)
s.NoError(err)
s.Equal([]int64{3}, ret.searchResultData.Topks)
s.Equal(int64(3), ret.searchResultData.TopK)
@ -452,9 +452,9 @@ func (s *RerankModelSuite) TestRerankProcess() {
functionSchema.Params[2].Value = `["q1", "q2", "q3"]`
f, err := newModelFunction(schema, functionSchema)
s.NoError(err)
data := genSearchResultData(nq, 10, schemapb.DataType_VarChar, "text", 101)
inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data}, f.GetInputFieldIDs())
ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false, []string{"COSINE", "COSINE", "COSINE"}}, inputs)
data := function.GenSearchResultData(nq, 10, schemapb.DataType_VarChar, "text", 101)
inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data}, f.GetInputFieldIDs(), false)
ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE", "COSINE", "COSINE"}), inputs)
s.NoError(err)
s.Equal([]int64{3, 3, 3}, ret.searchResultData.Topks)
s.Equal(int64(3), ret.searchResultData.TopK)
@ -468,11 +468,11 @@ func (s *RerankModelSuite) TestRerankProcess() {
functionSchema.Params[2].Value = `["q1"]`
f, err := newModelFunction(schema, functionSchema)
s.NoError(err)
data1 := genSearchResultData(nq, 10, schemapb.DataType_VarChar, "text", 101)
data1 := function.GenSearchResultData(nq, 10, schemapb.DataType_VarChar, "text", 101)
// empty
data2 := genSearchResultData(nq, 0, schemapb.DataType_VarChar, "text", 101)
inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data1, data2}, f.GetInputFieldIDs())
ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false, []string{"COSINE"}}, inputs)
data2 := function.GenSearchResultData(nq, 0, schemapb.DataType_VarChar, "text", 101)
inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data1, data2}, f.GetInputFieldIDs(), false)
ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE"}), inputs)
s.NoError(err)
s.Equal([]int64{3}, ret.searchResultData.Topks)
s.Equal(int64(3), ret.searchResultData.TopK)
@ -483,11 +483,11 @@ func (s *RerankModelSuite) TestRerankProcess() {
f, err := newModelFunction(schema, functionSchema)
s.NoError(err)
// ts/id data: 0 - 9
data1 := genSearchResultData(nq, 10, schemapb.DataType_VarChar, "text", 101)
data1 := function.GenSearchResultData(nq, 10, schemapb.DataType_VarChar, "text", 101)
// ts/id data: 0 - 3
data2 := genSearchResultData(nq, 4, schemapb.DataType_VarChar, "text", 101)
inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data1, data2}, f.GetInputFieldIDs())
ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false, []string{"COSINE", "COSINE"}}, inputs)
data2 := function.GenSearchResultData(nq, 4, schemapb.DataType_VarChar, "text", 101)
inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data1, data2}, f.GetInputFieldIDs(), false)
ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE", "COSINE"}), inputs)
s.NoError(err)
s.Equal([]int64{3}, ret.searchResultData.Topks)
s.Equal(int64(3), ret.searchResultData.TopK)
@ -498,10 +498,10 @@ func (s *RerankModelSuite) TestRerankProcess() {
functionSchema.Params[2].Value = `["q1", "q2", "q3"]`
f, err := newModelFunction(schema, functionSchema)
s.NoError(err)
data1 := genSearchResultData(nq, 10, schemapb.DataType_VarChar, "text", 101)
data2 := genSearchResultData(nq, 4, schemapb.DataType_VarChar, "text", 101)
inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data1, data2}, f.GetInputFieldIDs())
ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, 1, -1, 1, false, []string{"COSINE", "COSINE", "COSINE"}}, inputs)
data1 := function.GenSearchResultData(nq, 10, schemapb.DataType_VarChar, "text", 101)
data2 := function.GenSearchResultData(nq, 4, schemapb.DataType_VarChar, "text", 101)
inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data1, data2}, f.GetInputFieldIDs(), false)
ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, 1, -1, 1, false, "", []string{"COSINE", "COSINE", "COSINE"}), inputs)
s.NoError(err)
s.Equal([]int64{3, 3, 3}, ret.searchResultData.Topks)
s.Equal(int64(3), ret.searchResultData.TopK)

View File

@ -0,0 +1,101 @@
/*
* # Licensed to the LF AI & Data foundation under one
* # or more contributor license agreements. See the NOTICE file
* # distributed with this work for additional information
* # regarding copyright ownership. The ASF licenses this file
* # to you under the Apache License, Version 2.0 (the
* # "License"); you may not use this file except in compliance
* # with the License. You may obtain a copy of the License at
* #
* # http://www.apache.org/licenses/LICENSE-2.0
* #
* # Unless required by applicable law or agreed to in writing, software
* # distributed under the License is distributed on an "AS IS" BASIS,
* # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* # See the License for the specific language governing permissions and
* # limitations under the License.
*/
package rerank
import (
"context"
"fmt"
"strconv"
"strings"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
)
const (
RRFParamsKey string = "k"
defaultRRFParamsValue float64 = 60
)
type RRFFunction[T PKType] struct {
RerankBase
k float32
}
func newRRFFunction(collSchema *schemapb.CollectionSchema, funcSchema *schemapb.FunctionSchema) (Reranker, error) {
base, err := newRerankBase(collSchema, funcSchema, rrfName, true)
if err != nil {
return nil, err
}
if len(base.GetInputFieldNames()) != 0 {
return nil, fmt.Errorf("The rrf function does not support input parameters, but got %s", base.GetInputFieldNames())
}
k := float64(defaultRRFParamsValue)
for _, param := range funcSchema.Params {
if strings.ToLower(param.Key) == RRFParamsKey {
if k, err = strconv.ParseFloat(param.Value, 64); err != nil {
return nil, fmt.Errorf("Param k:%s is not a number", param.Value)
}
}
}
if k <= 0 || k >= 16384 {
return nil, fmt.Errorf("The rank params k should be in range (0, %d)", 16384)
}
if base.pkType == schemapb.DataType_Int64 {
return &RRFFunction[int64]{RerankBase: *base, k: float32(k)}, nil
} else {
return &RRFFunction[string]{RerankBase: *base, k: float32(k)}, nil
}
}
func (rrf *RRFFunction[T]) processOneSearchData(ctx context.Context, searchParams *SearchParams, cols []*columns, idGroup map[any]any) (*IDScores[T], error) {
rrfScores := map[T]float32{}
for _, col := range cols {
if col.size == 0 {
continue
}
ids := col.ids.([]T)
for idx, id := range ids {
if score, ok := rrfScores[id]; !ok {
rrfScores[id] = 1 / (rrf.k + float32(idx+1))
} else {
rrfScores[id] = score + 1/(rrf.k+float32(idx+1))
}
}
}
if searchParams.isGrouping() {
return newGroupingIDScores(rrfScores, searchParams, idGroup)
}
return newIDScores(rrfScores, searchParams), nil
}
func (rrf *RRFFunction[T]) Process(ctx context.Context, searchParams *SearchParams, inputs *rerankInputs) (*rerankOutputs, error) {
outputs := newRerankOutputs(searchParams)
for _, cols := range inputs.data {
idScore, err := rrf.processOneSearchData(ctx, searchParams, cols, inputs.idGroupValue)
if err != nil {
return nil, err
}
appendResult(outputs, idScore.ids, idScore.scores)
}
return outputs, nil
}

View File

@ -0,0 +1,250 @@
/*
* # Licensed to the LF AI & Data foundation under one
* # or more contributor license agreements. See the NOTICE file
* # distributed with this work for additional information
* # regarding copyright ownership. The ASF licenses this file
* # to you under the Apache License, Version 2.0 (the
* # "License"); you may not use this file except in compliance
* # with the License. You may obtain a copy of the License at
* #
* # http://www.apache.org/licenses/LICENSE-2.0
* #
* # Unless required by applicable law or agreed to in writing, software
* # distributed under the License is distributed on an "AS IS" BASIS,
* # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* # See the License for the specific language governing permissions and
* # limitations under the License.
*/
package rerank
import (
"context"
"testing"
"github.com/stretchr/testify/suite"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/util/function"
)
func TestRRFFunction(t *testing.T) {
suite.Run(t, new(RRFFunctionSuite))
}
type RRFFunctionSuite struct {
suite.Suite
}
func (s *RRFFunctionSuite) TestNewRRFFuction() {
schema := &schemapb.CollectionSchema{
Name: "test",
Fields: []*schemapb.FieldSchema{
{FieldID: 100, Name: "pk", DataType: schemapb.DataType_Int64, IsPrimaryKey: true},
{FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar},
{
FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{
{Key: "dim", Value: "4"},
},
},
{FieldID: 102, Name: "ts", DataType: schemapb.DataType_Int64},
},
}
functionSchema := &schemapb.FunctionSchema{
Name: "test",
Type: schemapb.FunctionType_Rerank,
InputFieldNames: []string{},
Params: []*commonpb.KeyValuePair{
{Key: RRFParamsKey, Value: "70"},
},
}
{
_, err := newRRFFunction(schema, functionSchema)
s.NoError(err)
}
{
schema.Fields[0] = &schemapb.FieldSchema{FieldID: 100, Name: "pk", DataType: schemapb.DataType_VarChar, IsPrimaryKey: true}
_, err := newRRFFunction(schema, functionSchema)
s.NoError(err)
}
{
functionSchema.Params[0] = &commonpb.KeyValuePair{Key: RRFParamsKey, Value: "NotNum"}
_, err := newRRFFunction(schema, functionSchema)
s.ErrorContains(err, "is not a number")
functionSchema.Params[0] = &commonpb.KeyValuePair{Key: RRFParamsKey, Value: "-1"}
_, err = newRRFFunction(schema, functionSchema)
s.ErrorContains(err, "he rank params k should be in range")
functionSchema.Params[0] = &commonpb.KeyValuePair{Key: RRFParamsKey, Value: "100"}
}
{
functionSchema.InputFieldNames = []string{"ts"}
_, err := newRRFFunction(schema, functionSchema)
s.ErrorContains(err, "The rrf function does not support input parameters")
}
}
func (s *RRFFunctionSuite) TestRRFFuctionProcess() {
schema := &schemapb.CollectionSchema{
Name: "test",
Fields: []*schemapb.FieldSchema{
{FieldID: 100, Name: "pk", DataType: schemapb.DataType_Int64, IsPrimaryKey: true},
{FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar},
{
FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{
{Key: "dim", Value: "4"},
},
},
{FieldID: 102, Name: "ts", DataType: schemapb.DataType_Int64},
},
}
functionSchema := &schemapb.FunctionSchema{
Name: "test",
Type: schemapb.FunctionType_Rerank,
InputFieldNames: []string{},
Params: []*commonpb.KeyValuePair{},
}
// empty
{
nq := int64(1)
f, err := newRRFFunction(schema, functionSchema)
s.NoError(err)
inputs, _ := newRerankInputs([]*schemapb.SearchResultData{}, f.GetInputFieldIDs(), false)
ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE"}), inputs)
s.NoError(err)
s.Equal(int64(3), ret.searchResultData.TopK)
s.Equal([]int64{}, ret.searchResultData.Topks)
}
// singleSearchResultData
// nq = 1
{
nq := int64(1)
f, err := newRRFFunction(schema, functionSchema)
s.NoError(err)
data := function.GenSearchResultData(nq, 10, schemapb.DataType_Int64, "", 0)
inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data}, f.GetInputFieldIDs(), false)
ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE"}), inputs)
s.NoError(err)
s.Equal([]int64{3}, ret.searchResultData.Topks)
s.Equal(int64(3), ret.searchResultData.TopK)
s.Equal([]int64{2, 3, 4}, ret.searchResultData.Ids.GetIntId().Data)
}
// nq = 3
{
nq := int64(3)
f, err := newRRFFunction(schema, functionSchema)
s.NoError(err)
data := function.GenSearchResultData(nq, 10, schemapb.DataType_Int64, "", 0)
inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data}, f.GetInputFieldIDs(), false)
ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE"}), inputs)
s.NoError(err)
s.Equal([]int64{3, 3, 3}, ret.searchResultData.Topks)
s.Equal(int64(3), ret.searchResultData.TopK)
s.Equal([]int64{2, 3, 4, 12, 13, 14, 22, 23, 24}, ret.searchResultData.Ids.GetIntId().Data)
}
// has empty inputs
{
nq := int64(1)
f, err := newRRFFunction(schema, functionSchema)
s.NoError(err)
// id data: 0 - 9
data1 := function.GenSearchResultData(nq, 10, schemapb.DataType_Int64, "", 0)
// empty
data2 := function.GenSearchResultData(nq, 0, schemapb.DataType_Int64, "", 0)
inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data1, data2}, f.GetInputFieldIDs(), false)
ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 0, -1, -1, 1, false, "", []string{"COSINE", "COSINE"}), inputs)
s.NoError(err)
s.Equal([]int64{3}, ret.searchResultData.Topks)
s.Equal(int64(3), ret.searchResultData.TopK)
s.Equal([]int64{0, 1, 2}, ret.searchResultData.Ids.GetIntId().Data)
}
// nq = 1
{
nq := int64(1)
f, err := newRRFFunction(schema, functionSchema)
s.NoError(err)
// id data: 0 - 9
data1 := function.GenSearchResultData(nq, 10, schemapb.DataType_Int64, "", 0)
// id data: 0 - 3
data2 := function.GenSearchResultData(nq, 4, schemapb.DataType_Int64, "", 0)
inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data1, data2}, f.GetInputFieldIDs(), false)
ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE", "COSINE"}), inputs)
s.NoError(err)
s.Equal([]int64{3}, ret.searchResultData.Topks)
s.Equal(int64(3), ret.searchResultData.TopK)
s.Equal([]int64{2, 3, 4}, ret.searchResultData.Ids.GetIntId().Data)
}
// // nq = 3
{
nq := int64(3)
f, err := newRRFFunction(schema, functionSchema)
s.NoError(err)
// nq1 id data: 0 - 9
// nq2 id data: 10 - 19
// nq3 id data: 20 - 29
data1 := function.GenSearchResultData(nq, 10, schemapb.DataType_Int64, "", 0)
// nq1 id data: 0 - 3
// nq2 id data: 4 - 7
// nq3 id data: 8 - 11
data2 := function.GenSearchResultData(nq, 4, schemapb.DataType_Int64, "", 0)
inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data1, data2}, f.GetInputFieldIDs(), false)
ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE", "COSINE"}), inputs)
s.NoError(err)
s.Equal([]int64{3, 3, 3}, ret.searchResultData.Topks)
s.Equal(int64(3), ret.searchResultData.TopK)
s.Equal([]int64{2, 3, 4, 5, 11, 6, 9, 21, 10}, ret.searchResultData.Ids.GetIntId().Data)
}
// // nq = 3 grouping = true, grouping size = 1
{
nq := int64(3)
f, err := newRRFFunction(schema, functionSchema)
s.NoError(err)
// nq1 id data: 0 - 9
// nq2 id data: 10 - 19
// nq3 id data: 20 - 29
data1 := function.GenSearchResultDataWithGrouping(nq, 10, schemapb.DataType_Int64, "", 0, "ts", 102, 1)
// nq1 id data: 0 - 3
// nq2 id data: 4 - 7
// nq3 id data: 8 - 11
data2 := function.GenSearchResultDataWithGrouping(nq, 4, schemapb.DataType_Int64, "", 0, "ts", 102, 1)
inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data1, data2}, f.GetInputFieldIDs(), true)
ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, 102, 1, true, "", []string{"COSINE", "COSINE"}), inputs)
s.NoError(err)
s.Equal([]int64{3, 3, 3}, ret.searchResultData.Topks)
s.Equal(int64(3), ret.searchResultData.TopK)
s.Equal([]int64{2, 3, 4, 5, 11, 6, 9, 21, 10}, ret.searchResultData.Ids.GetIntId().Data)
}
// // nq = 3 grouping = true, grouping size = 3
{
nq := int64(3)
f, err := newRRFFunction(schema, functionSchema)
s.NoError(err)
// nq1 id data: 0 - 29, group value: 0,0,0,1,1,1, ... , 9,9,9
// nq2 id data: 30 - 59, group value: 10,10,10,11,11,11, ... , 19,19,19
// nq3 id data: 60 - 99, group value: 20,20,20,21,21,21, ... , 29,29,29
data1 := function.GenSearchResultDataWithGrouping(nq, 10, schemapb.DataType_Int64, "", 0, "ts", 102, 3)
// nq1 id data: 0 - 11, group value: 0,0,0,1,1,1,2,2,2,3,3,3,
// nq2 id data: 12 - 23, group value: 4,4,4,5,5,5,6,6,6,7,7,7
// nq3 id data: 24 - 35, group value: 8,8,8,9,9,9,10,10,10,11,11,11
data2 := function.GenSearchResultDataWithGrouping(nq, 4, schemapb.DataType_Int64, "", 0, "ts", 102, 3)
inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data1, data2}, f.GetInputFieldIDs(), true)
ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, 3, 102, 3, true, "", []string{"COSINE", "COSINE"}), inputs)
s.NoError(err)
s.Equal([]int64{9, 9, 9}, ret.searchResultData.Topks)
s.Equal(int64(9), ret.searchResultData.TopK)
s.Equal([]int64{
6, 7, 8, 9, 10, 11, 12, 13, 14,
15, 16, 17, 33, 34, 35, 18, 19, 20,
27, 28, 29, 63, 64, 65, 30, 31, 32,
},
ret.searchResultData.Ids.GetIntId().Data)
}
}

View File

@ -24,6 +24,8 @@ import (
"sort"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
type PKType interface {
@ -40,9 +42,9 @@ type columns struct {
type rerankInputs struct {
// nqs,searchResultsIndex
data [][]*columns
nq int64
data [][]*columns
idGroupValue map[any]any
nq int64
// There is only fieldId in schemapb.SearchResultData, but no fieldName
inputFieldIds []int64
@ -69,7 +71,7 @@ func organizeFieldIdData(multipSearchResultData []*schemapb.SearchResultData, in
return multipIdField, nil
}
func newRerankInputs(multipSearchResultData []*schemapb.SearchResultData, inputFieldIds []int64) (*rerankInputs, error) {
func newRerankInputs(multipSearchResultData []*schemapb.SearchResultData, inputFieldIds []int64, isGrouping bool) (*rerankInputs, error) {
if len(multipSearchResultData) == 0 {
return &rerankInputs{}, nil
}
@ -84,27 +86,34 @@ func newRerankInputs(multipSearchResultData []*schemapb.SearchResultData, inputF
cols[i] = make([]*columns, len(multipSearchResultData))
}
for retIdx, searchResult := range multipSearchResultData {
for _, fieldId := range inputFieldIds {
fieldData := multipIdField[retIdx][fieldId]
start := int64(0)
for i := int64(0); i < nq; i++ {
size := searchResult.Topks[i]
start := int64(0)
for i := int64(0); i < nq; i++ {
size := searchResult.Topks[i]
if cols[i][retIdx] == nil {
cols[i][retIdx] = &columns{}
cols[i][retIdx].size = size
cols[i][retIdx].ids = getIds(searchResult.Ids, start, size)
cols[i][retIdx].scores = searchResult.Scores[start : start+size]
}
for _, fieldId := range inputFieldIds {
fieldData := multipIdField[retIdx][fieldId]
d, err := getField(fieldData, start, size)
if err != nil {
return nil, err
}
if cols[i][retIdx] == nil {
cols[i][retIdx] = &columns{}
cols[i][retIdx].size = size
cols[i][retIdx].ids = getIds(searchResult.Ids, start, size)
cols[i][retIdx].scores = searchResult.Scores[start : start+size]
}
cols[i][retIdx].data = append(cols[i][retIdx].data, d)
start += size
}
start += size
}
}
return &rerankInputs{cols, nq, inputFieldIds}, nil
if isGrouping {
idGroup, err := genIdGroupingMap(multipSearchResultData)
if err != nil {
return nil, err
}
return &rerankInputs{cols, idGroup, nq, inputFieldIds}, nil
}
return &rerankInputs{cols, nil, nq, inputFieldIds}, nil
}
func (inputs *rerankInputs) numOfQueries() int64 {
@ -116,9 +125,13 @@ type rerankOutputs struct {
}
func newRerankOutputs(searchParams *SearchParams) *rerankOutputs {
topk := searchParams.limit
if searchParams.isGrouping() {
topk = topk * searchParams.groupSize
}
ret := &schemapb.SearchResultData{
NumQueries: searchParams.nq,
TopK: searchParams.limit,
TopK: topk,
FieldsData: make([]*schemapb.FieldData, 0),
Scores: []float32{},
Ids: &schemapb.IDs{},
@ -153,28 +166,11 @@ func appendResult[T PKType](outputs *rerankOutputs, ids []T, scores []float32) {
}
type IDScores[T PKType] struct {
// idScores map[T]float32
ids []T
scores []float32
size int64
}
// func (s *IDScores[T]) GetSortedIdScores() ([]T, []float32) {
// ids := make([]T, 0, s.size)
// big := func(i, j int) bool {
// if s.idScores[ids[i]] == s.idScores[ids[j]] {
// return ids[i] < ids[j]
// }
// return s.idScores[ids[i]] > s.idScores[ids[j]]
// }
// sort.Slice(ids, big)
// scores := make([]float32, 0, s.size)
// for _, id := range ids {
// scores = append(scores, s.idScores[id])
// }
// return ids, scores
// }
func newIDScores[T PKType](idScores map[T]float32, searchParams *SearchParams) *IDScores[T] {
ids := make([]T, 0, len(idScores))
for id := range idScores {
@ -209,6 +205,120 @@ func newIDScores[T PKType](idScores map[T]float32, searchParams *SearchParams) *
return &ret
}
func genIDGroupValueMap[T PKType]() map[T]any {
return nil
}
func groupScore[T PKType](group *Group[T], scorerType string) (float32, error) {
switch scorerType {
case maxScorer:
return group.maxScore, nil
case sumScorer:
return group.sumScore, nil
case avgScorer:
if len(group.idList) == 0 {
return 0, merr.WrapErrParameterInvalid(1, len(group.idList),
"input group for score must have at least one id, must be sth wrong within code")
}
return group.sumScore / float32(len(group.idList)), nil
default:
return 0, merr.WrapErrParameterInvalidMsg("input group scorer type: %s is not supported!", scorerType)
}
}
type Group[T PKType] struct {
idList []T
scoreList []float32
groupVal any
maxScore float32
sumScore float32
finalScore float32
}
func newGroupingIDScores[T PKType](idScores map[T]float32, searchParams *SearchParams, idGroup map[any]any) (*IDScores[T], error) {
ids := make([]T, 0, len(idScores))
for id := range idScores {
ids = append(ids, id)
}
sort.Slice(ids, func(i, j int) bool {
if idScores[ids[i]] == idScores[ids[j]] {
return ids[i] < ids[j]
}
return idScores[ids[i]] > idScores[ids[j]]
})
buckets := make(map[interface{}]*Group[T])
for _, id := range ids {
score := idScores[id]
groupVal := idGroup[id]
if buckets[groupVal] == nil {
buckets[groupVal] = &Group[T]{
idList: make([]T, 0),
scoreList: make([]float32, 0),
groupVal: groupVal,
}
}
if int64(len(buckets[groupVal].idList)) >= searchParams.groupSize {
continue
}
buckets[groupVal].idList = append(buckets[groupVal].idList, id)
buckets[groupVal].scoreList = append(buckets[groupVal].scoreList, idScores[id])
if score > buckets[groupVal].maxScore {
buckets[groupVal].maxScore = score
}
buckets[groupVal].sumScore += score
}
groupList := make([]*Group[T], len(buckets))
idx := 0
var err error
for _, group := range buckets {
if group.finalScore, err = groupScore(group, searchParams.groupScore); err != nil {
return nil, err
}
groupList[idx] = group
idx += 1
}
sort.Slice(groupList, func(i, j int) bool {
if groupList[i].finalScore == groupList[j].finalScore {
if len(groupList[i].idList) == len(groupList[j].idList) {
// if final score and size of group are both equal
// choose the group with smaller first key
// here, it's guaranteed all group having at least one id in the idList
return groupList[i].idList[0] < groupList[j].idList[0]
}
// choose the larger group when scores are equal
return len(groupList[i].idList) > len(groupList[j].idList)
}
return groupList[i].finalScore > groupList[j].finalScore
})
if int64(len(groupList)) > searchParams.limit+searchParams.offset {
groupList = groupList[:searchParams.limit+searchParams.offset]
}
ret := IDScores[T]{
make([]T, 0, searchParams.limit),
make([]float32, 0, searchParams.limit),
0,
}
for index := int(searchParams.offset); index < len(groupList); index++ {
group := groupList[index]
for i, score := range group.scoreList {
// idList and scoreList must have same length
if searchParams.roundDecimal != -1 {
multiplier := math.Pow(10.0, float64(searchParams.roundDecimal))
score = float32(math.Floor(float64(score)*multiplier+0.5) / multiplier)
}
ret.scores = append(ret.scores, score)
ret.ids = append(ret.ids, group.idList[i])
}
}
ret.size = int64(len(ret.ids))
return &ret, nil
}
func getField(inputField *schemapb.FieldData, start int64, size int64) (any, error) {
switch inputField.Type {
case schemapb.DataType_Int8, schemapb.DataType_Int16, schemapb.DataType_Int32:
@ -299,3 +409,22 @@ func getPKType(collSchema *schemapb.CollectionSchema) (schemapb.DataType, error)
}
return pkType, nil
}
func genIdGroupingMap(multipSearchResultData []*schemapb.SearchResultData) (map[any]any, error) {
idGroupValue := map[any]any{}
for _, result := range multipSearchResultData {
if result.GetGroupByFieldValue() == nil {
return nil, fmt.Errorf("Group value is nil")
}
size := typeutil.GetSizeOfIDs(result.Ids)
groupIter := typeutil.GetDataIterator(result.GetGroupByFieldValue())
for i := 0; i < size; i++ {
groupByVal := groupIter(i)
id := typeutil.GetPK(result.Ids, int64(i))
if _, exist := idGroupValue[id]; !exist {
idGroupValue[id] = groupByVal
}
}
}
return idGroupValue, nil
}

View File

@ -0,0 +1,154 @@
/*
* # Licensed to the LF AI & Data foundation under one
* # or more contributor license agreements. See the NOTICE file
* # distributed with this work for additional information
* # regarding copyright ownership. The ASF licenses this file
* # to you under the Apache License, Version 2.0 (the
* # "License"); you may not use this file except in compliance
* # with the License. You may obtain a copy of the License at
* #
* # http://www.apache.org/licenses/LICENSE-2.0
* #
* # Unless required by applicable law or agreed to in writing, software
* # distributed under the License is distributed on an "AS IS" BASIS,
* # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* # See the License for the specific language governing permissions and
* # limitations under the License.
*/
package rerank
import (
"context"
"encoding/json"
"fmt"
"math"
"strconv"
"strings"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/metric"
)
const (
WeightsParamsKey string = "weights"
NormScoreKey string = "norm_score"
)
type WeightedFunction[T PKType] struct {
RerankBase
weight []float32
needNorm bool
}
func newWeightedFunction(collSchema *schemapb.CollectionSchema, funcSchema *schemapb.FunctionSchema) (Reranker, error) {
base, err := newRerankBase(collSchema, funcSchema, weightedName, true)
if err != nil {
return nil, err
}
if len(base.GetInputFieldNames()) != 0 {
return nil, fmt.Errorf("The weighted function does not support input parameters, but got %s", base.GetInputFieldNames())
}
var weights []float32
needNorm := false
for _, param := range funcSchema.Params {
switch strings.ToLower(param.Key) {
case WeightsParamsKey:
if err := json.Unmarshal([]byte(param.Value), &weights); err != nil {
return nil, fmt.Errorf("Parse %s param failed, weight should be []float, bug got: %s", WeightsParamsKey, param.Value)
}
for _, weight := range weights {
if weight < 0 || weight > 1 {
return nil, fmt.Errorf("rank param weight should be in range [0, 1]")
}
}
case NormScoreKey:
if needNorm, err = strconv.ParseBool(param.Value); err != nil {
return nil, fmt.Errorf("%s params must be true/false, bug got %s", NormScoreKey, param.Value)
}
}
}
if len(weights) == 0 {
return nil, fmt.Errorf(WeightsParamsKey + " not found")
}
if base.pkType == schemapb.DataType_Int64 {
return &WeightedFunction[int64]{RerankBase: *base, weight: weights, needNorm: needNorm}, nil
} else {
return &WeightedFunction[string]{RerankBase: *base, weight: weights, needNorm: needNorm}, nil
}
}
func (weighted *WeightedFunction[T]) processOneSearchData(ctx context.Context, searchParams *SearchParams, cols []*columns, idGroup map[any]any) (*IDScores[T], error) {
if len(cols) != len(weighted.weight) {
return nil, merr.WrapErrParameterInvalid(fmt.Sprint(len(cols)), fmt.Sprint(len(weighted.weight)), "the length of weights param mismatch with ann search requests")
}
weightedScores := map[T]float32{}
for i, col := range cols {
if col.size == 0 {
continue
}
normFunc := getNormalizeFunc(weighted.needNorm, searchParams.searchMetrics[i])
ids := col.ids.([]T)
for j, id := range ids {
if score, ok := weightedScores[id]; !ok {
weightedScores[id] = weighted.weight[i] * normFunc(col.scores[j])
} else {
weightedScores[id] = score + weighted.weight[i]*normFunc(col.scores[j])
}
}
}
if searchParams.isGrouping() {
return newGroupingIDScores(weightedScores, searchParams, idGroup)
}
return newIDScores(weightedScores, searchParams), nil
}
func (weighted *WeightedFunction[T]) Process(ctx context.Context, searchParams *SearchParams, inputs *rerankInputs) (*rerankOutputs, error) {
outputs := newRerankOutputs(searchParams)
for _, cols := range inputs.data {
for i, col := range cols {
metricType := searchParams.searchMetrics[i]
for j, score := range col.scores {
col.scores[j] = toGreaterScore(score, metricType)
}
}
idScore, err := weighted.processOneSearchData(ctx, searchParams, cols, inputs.idGroupValue)
if err != nil {
return nil, err
}
appendResult(outputs, idScore.ids, idScore.scores)
}
return outputs, nil
}
type normalizeFunc func(float32) float32
func getNormalizeFunc(normScore bool, metrics string) normalizeFunc {
if !normScore {
return func(distance float32) float32 {
return distance
}
}
switch metrics {
case metric.COSINE:
return func(distance float32) float32 {
return (1 + distance) * 0.5
}
case metric.IP:
return func(distance float32) float32 {
return 0.5 + float32(math.Atan(float64(distance)))/math.Pi
}
case metric.BM25:
return func(distance float32) float32 {
return 2 * float32(math.Atan(float64(distance))) / math.Pi
}
default:
return func(distance float32) float32 {
return 1.0 - 2*float32(math.Atan(float64(distance)))/math.Pi
}
}
}

View File

@ -0,0 +1,298 @@
/*
* # Licensed to the LF AI & Data foundation under one
* # or more contributor license agreements. See the NOTICE file
* # distributed with this work for additional information
* # regarding copyright ownership. The ASF licenses this file
* # to you under the Apache License, Version 2.0 (the
* # "License"); you may not use this file except in compliance
* # with the License. You may obtain a copy of the License at
* #
* # http://www.apache.org/licenses/LICENSE-2.0
* #
* # Unless required by applicable law or agreed to in writing, software
* # distributed under the License is distributed on an "AS IS" BASIS,
* # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* # See the License for the specific language governing permissions and
* # limitations under the License.
*/
package rerank
import (
"context"
"math"
"testing"
"github.com/stretchr/testify/suite"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/util/function"
"github.com/milvus-io/milvus/pkg/v2/util/metric"
)
func TestWeightedFunction(t *testing.T) {
suite.Run(t, new(WeightedFunctionSuite))
}
type WeightedFunctionSuite struct {
suite.Suite
}
func (s *WeightedFunctionSuite) TestNewWeightedFuction() {
schema := &schemapb.CollectionSchema{
Name: "test",
Fields: []*schemapb.FieldSchema{
{FieldID: 100, Name: "pk", DataType: schemapb.DataType_Int64, IsPrimaryKey: true},
{FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar},
{
FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{
{Key: "dim", Value: "4"},
},
},
{FieldID: 102, Name: "ts", DataType: schemapb.DataType_Int64},
},
}
functionSchema := &schemapb.FunctionSchema{
Name: "test",
Type: schemapb.FunctionType_Rerank,
InputFieldNames: []string{},
Params: []*commonpb.KeyValuePair{
{Key: WeightsParamsKey, Value: `[0.1, 0.9]`},
},
}
{
_, err := newWeightedFunction(schema, functionSchema)
s.NoError(err)
}
{
schema.Fields[0] = &schemapb.FieldSchema{FieldID: 100, Name: "pk", DataType: schemapb.DataType_VarChar, IsPrimaryKey: true}
_, err := newWeightedFunction(schema, functionSchema)
s.NoError(err)
}
{
functionSchema.Params[0] = &commonpb.KeyValuePair{Key: WeightsParamsKey, Value: "NotNum"}
_, err := newWeightedFunction(schema, functionSchema)
s.ErrorContains(err, "param failed, weight should be []float")
}
{
functionSchema.Params[0] = &commonpb.KeyValuePair{Key: WeightsParamsKey, Value: `[10]`}
_, err := newWeightedFunction(schema, functionSchema)
s.ErrorContains(err, "rank param weight should be in range [0, 1]")
}
{
functionSchema.Params[0] = &commonpb.KeyValuePair{Key: "NotExist", Value: `[10]`}
_, err := newWeightedFunction(schema, functionSchema)
s.ErrorContains(err, "not found")
functionSchema.Params[0] = &commonpb.KeyValuePair{Key: WeightsParamsKey, Value: `[0.1, 0.9]`}
}
{
functionSchema.InputFieldNames = []string{"ts"}
_, err := newWeightedFunction(schema, functionSchema)
s.ErrorContains(err, "The weighted function does not support input parameters,")
}
}
func (s *WeightedFunctionSuite) TestWeightedFuctionProcess() {
schema := &schemapb.CollectionSchema{
Name: "test",
Fields: []*schemapb.FieldSchema{
{FieldID: 100, Name: "pk", DataType: schemapb.DataType_Int64, IsPrimaryKey: true},
{FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar},
{
FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{
{Key: "dim", Value: "4"},
},
},
{FieldID: 102, Name: "ts", DataType: schemapb.DataType_Int64},
},
}
functionSchema := &schemapb.FunctionSchema{
Name: "test",
Type: schemapb.FunctionType_Rerank,
InputFieldNames: []string{},
Params: []*commonpb.KeyValuePair{
{Key: WeightsParamsKey, Value: `[0.1]`},
},
}
// empty
{
nq := int64(1)
f, err := newWeightedFunction(schema, functionSchema)
s.NoError(err)
inputs, _ := newRerankInputs([]*schemapb.SearchResultData{}, f.GetInputFieldIDs(), false)
ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE"}), inputs)
s.NoError(err)
s.Equal(int64(3), ret.searchResultData.TopK)
s.Equal([]int64{}, ret.searchResultData.Topks)
}
// singleSearchResultData
// nq = 1
{
nq := int64(1)
f, err := newWeightedFunction(schema, functionSchema)
s.NoError(err)
data := function.GenSearchResultData(nq, 10, schemapb.DataType_Int64, "", 0)
inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data}, f.GetInputFieldIDs(), false)
ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE"}), inputs)
s.NoError(err)
s.Equal([]int64{3}, ret.searchResultData.Topks)
s.Equal(int64(3), ret.searchResultData.TopK)
s.Equal([]int64{7, 6, 5}, ret.searchResultData.Ids.GetIntId().Data)
}
// nq = 3
{
nq := int64(3)
f, err := newWeightedFunction(schema, functionSchema)
s.NoError(err)
data := function.GenSearchResultData(nq, 10, schemapb.DataType_Int64, "", 0)
inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data}, f.GetInputFieldIDs(), false)
ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 0, -1, -1, 1, false, "", []string{"COSINE"}), inputs)
s.NoError(err)
s.Equal([]int64{3, 3, 3}, ret.searchResultData.Topks)
s.Equal(int64(3), ret.searchResultData.TopK)
s.Equal([]int64{9, 8, 7, 19, 18, 17, 29, 28, 27}, ret.searchResultData.Ids.GetIntId().Data)
}
// number of weigts not equal to search data
functionSchema.Params[0] = &commonpb.KeyValuePair{Key: WeightsParamsKey, Value: `[0.1, 0.9]`}
{
nq := int64(1)
f, err := newWeightedFunction(schema, functionSchema)
s.NoError(err)
data := function.GenSearchResultData(nq, 10, schemapb.DataType_Int64, "", 0)
inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data}, f.GetInputFieldIDs(), false)
_, err = f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE", "COSINE"}), inputs)
s.ErrorContains(err, "the length of weights param mismatch with ann search requests")
}
// has empty inputs
{
nq := int64(1)
f, err := newWeightedFunction(schema, functionSchema)
s.NoError(err)
// id data: 0 - 9
data1 := function.GenSearchResultData(nq, 10, schemapb.DataType_Int64, "", 0)
// empty
data2 := function.GenSearchResultData(nq, 0, schemapb.DataType_Int64, "", 0)
inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data1, data2}, f.GetInputFieldIDs(), false)
ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 0, -1, -1, 1, false, "", []string{"COSINE", "COSINE"}), inputs)
s.NoError(err)
s.Equal([]int64{3}, ret.searchResultData.Topks)
s.Equal(int64(3), ret.searchResultData.TopK)
s.Equal([]int64{9, 8, 7}, ret.searchResultData.Ids.GetIntId().Data)
s.True(function.FloatsAlmostEqual([]float32{0.9, 0.8, 0.7}, ret.searchResultData.Scores, 0.001))
}
// nq = 1
{
nq := int64(1)
f, err := newWeightedFunction(schema, functionSchema)
s.NoError(err)
// id data: 0 - 9
data1 := function.GenSearchResultData(nq, 10, schemapb.DataType_Int64, "", 0)
// id data: 0 - 3
data2 := function.GenSearchResultData(nq, 4, schemapb.DataType_Int64, "", 0)
inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data1, data2}, f.GetInputFieldIDs(), false)
ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE", "COSINE"}), inputs)
s.NoError(err)
s.Equal([]int64{3}, ret.searchResultData.Topks)
s.Equal(int64(3), ret.searchResultData.TopK)
s.Equal([]int64{1, 9, 8}, ret.searchResultData.Ids.GetIntId().Data)
s.True(function.FloatsAlmostEqual([]float32{1, 0.9, 0.8}, ret.searchResultData.Scores, 0.001))
}
// // nq = 3
{
nq := int64(3)
f, err := newWeightedFunction(schema, functionSchema)
s.NoError(err)
// nq1 id data: 0 - 9
// nq2 id data: 10 - 19
// nq3 id data: 20 - 29
data1 := function.GenSearchResultData(nq, 10, schemapb.DataType_Int64, "", 0)
// nq1 id data: 0 - 3
// nq2 id data: 4 - 7
// nq3 id data: 8 - 11
data2 := function.GenSearchResultData(nq, 4, schemapb.DataType_Int64, "", 0)
inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data1, data2}, f.GetInputFieldIDs(), false)
ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 0, 1, -1, 1, false, "", []string{"COSINE", "COSINE"}), inputs)
s.NoError(err)
s.Equal([]int64{3, 3, 3}, ret.searchResultData.Topks)
s.Equal(int64(3), ret.searchResultData.TopK)
s.Equal([]int64{3, 2, 1, 7, 6, 5, 11, 10, 9}, ret.searchResultData.Ids.GetIntId().Data)
s.True(function.FloatsAlmostEqual([]float32{3, 2, 1, 6.3, 5.4, 4.5, 9.9, 9, 8.1}, ret.searchResultData.Scores, 0.001))
}
// // nq = 3 grouping = true, grouping size = 1
{
nq := int64(3)
f, err := newWeightedFunction(schema, functionSchema)
s.NoError(err)
// nq1 id data: 0 - 9
// nq2 id data: 10 - 19
// nq3 id data: 20 - 29
data1 := function.GenSearchResultDataWithGrouping(nq, 10, schemapb.DataType_Int64, "", 0, "ts", 102, 1)
// nq1 id data: 0 - 3
// nq2 id data: 4 - 7
// nq3 id data: 8 - 11
data2 := function.GenSearchResultDataWithGrouping(nq, 4, schemapb.DataType_Int64, "", 0, "ts", 102, 1)
inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data1, data2}, f.GetInputFieldIDs(), true)
ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 0, 1, 102, 1, true, "", []string{"COSINE", "COSINE"}), inputs)
s.NoError(err)
s.Equal([]int64{3, 3, 3}, ret.searchResultData.Topks)
s.Equal([]int64{3, 2, 1, 7, 6, 5, 11, 10, 9}, ret.searchResultData.Ids.GetIntId().Data)
s.True(function.FloatsAlmostEqual([]float32{3, 2, 1, 6.3, 5.4, 4.5, 9.9, 9, 8.1}, ret.searchResultData.Scores, 0.001))
}
// // nq = 3 grouping = true, grouping size = 3
{
nq := int64(3)
f, err := newWeightedFunction(schema, functionSchema)
s.NoError(err)
// nq1 id data: 0 - 29, group value: 0,0,0,1,1,1, ... , 9,9,9
// nq2 id data: 30 - 59, group value: 10,10,10,11,11,11, ... , 19,19,19
// nq3 id data: 60 - 99, group value: 20,20,20,21,21,21, ... , 29,29,29
data1 := function.GenSearchResultDataWithGrouping(nq, 10, schemapb.DataType_Int64, "", 0, "ts", 102, 3)
// nq1 id data: 0 - 11, group value: 0,0,0,1,1,1,2,2,2,3,3,3,
// nq2 id data: 12 - 23, group value: 4,4,4,5,5,5,6,6,6,7,7,7
// nq3 id data: 24 - 35, group value: 8,8,8,9,9,9,10,10,10,11,11,11
data2 := function.GenSearchResultDataWithGrouping(nq, 4, schemapb.DataType_Int64, "", 0, "ts", 102, 3)
inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data1, data2}, f.GetInputFieldIDs(), true)
ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, 1, 102, 3, true, "", []string{"COSINE", "COSINE"}), inputs)
s.NoError(err)
s.Equal([]int64{9, 9, 9}, ret.searchResultData.Topks)
s.Equal(int64(9), ret.searchResultData.TopK)
s.Equal([]int64{
5, 4, 3, 29, 28, 27, 26, 25, 24,
17, 16, 15, 14, 13, 12, 59, 58, 57,
29, 28, 27, 26, 25, 24, 89, 88, 87,
},
ret.searchResultData.Ids.GetIntId().Data)
}
}
func (s *WeightedFunctionSuite) TestWeightedFuctionNormalize() {
{
f := getNormalizeFunc(false, metric.COSINE)
s.Equal(float32(1.0), f(1.0))
}
{
f := getNormalizeFunc(true, metric.COSINE)
s.Equal(float32((1+1.0)*0.5), f(1))
}
{
f := getNormalizeFunc(true, metric.IP)
s.Equal(0.5+float32(math.Atan(float64(1.0)))/math.Pi, f(1))
}
{
f := getNormalizeFunc(true, metric.BM25)
s.Equal(float32(2*math.Atan(float64(1.0)))/math.Pi, f(1.0))
}
{
f := getNormalizeFunc(true, metric.L2)
s.Equal((1.0 - 2*float32(math.Atan(float64(1.0)))/math.Pi), f(1.0))
}
}

View File

@ -213,6 +213,7 @@ func (s *PartialSearchTestSuit) TestAllNodeDownOnSingleReplica() {
for _, qn := range s.Cluster.GetAllQueryNodes() {
qn.Stop()
}
time.Sleep(2 * time.Second)
s.Cluster.AddQueryNode()
time.Sleep(20 * time.Second)

View File

@ -1588,11 +1588,7 @@ class TestMilvusClientSearchInvalid(TestMilvusClientV2Base):
}
)
vectors_to_search = rng.random((1, dim))
error = {ct.err_code: 1100,
ct.err_msg: f"Current rerank does not support grouping search: invalid parameter"}
self.search(client, collection_name, vectors_to_search, ranker=my_rerank_fn,
group_by_field=ct.default_reranker_field_name,
check_task=CheckTasks.err_res, check_items=error)
self.search(client, collection_name, vectors_to_search, ranker=my_rerank_fn, group_by_field=ct.default_reranker_field_name)
@pytest.mark.tags(CaseLabel.L1)
def test_milvus_client_search_with_reranker_on_dynamic_fields(self):