enhance: refactor rrf and weighted rerank (#42154)

Related issue: https://github.com/milvus-io/milvus/issues/35856
Signed-off-by: junjie.jiang <junjie.jiang@zilliz.com>

parent f3fe117840
commit f1a4526bac
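The refactor removes the proxy-side reScorer types (deleted file below) and routes legacy rrf/weighted requests through the rerank package's FunctionScore via NewFunctionScoreWithlegacy, which is added further down in this diff. A minimal sketch of the legacy payload and that translation call, assuming a collection schema is already at hand; the package name, helper function, and the rerank import path are illustrative, not part of this commit:

package example // illustrative only

import (
    "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
    "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
    "github.com/milvus-io/milvus/internal/util/function/rerank" // assumed import path for the rerank package in this diff
)

// legacyRerankExample is a hypothetical helper; only NewFunctionScoreWithlegacy
// and the "strategy"/"params" keys are taken from this commit.
func legacyRerankExample(collSchema *schemapb.CollectionSchema) (*rerank.FunctionScore, error) {
    rankParams := []*commonpb.KeyValuePair{
        {Key: "strategy", Value: "rrf"},     // legacy rank strategy: "rrf" or "weighted"
        {Key: "params", Value: `{"k": 60}`}, // legacy JSON params; 60 matches defaultRRFParamsValue
    }
    return rerank.NewFunctionScoreWithlegacy(collSchema, rankParams)
}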
@@ -1,246 +0,0 @@
package proxy

import (
    "context"
    "fmt"
    "math"
    "reflect"
    "strings"

    "github.com/cockroachdb/errors"
    "go.uber.org/zap"

    "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
    "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
    "github.com/milvus-io/milvus/internal/json"
    "github.com/milvus-io/milvus/pkg/v2/log"
    "github.com/milvus-io/milvus/pkg/v2/util/funcutil"
    "github.com/milvus-io/milvus/pkg/v2/util/merr"
    "github.com/milvus-io/milvus/pkg/v2/util/metric"
)

type rankType int

const (
    invalidRankType  rankType = iota // invalidRankType = 0
    rrfRankType                      // rrfRankType = 1
    weightedRankType                 // weightedRankType = 2
    udfExprRankType                  // udfExprRankType = 3
)

var rankTypeMap = map[string]rankType{
    "invalid":  invalidRankType,
    "rrf":      rrfRankType,
    "weighted": weightedRankType,
    "expr":     udfExprRankType,
}

type reScorer interface {
    name() string
    scorerType() rankType
    reScore(input *milvuspb.SearchResults)
    setMetricType(metricType string)
    getMetricType() string
}

type baseScorer struct {
    scorerName string
    metricType string
}

func (bs *baseScorer) name() string {
    return bs.scorerName
}

func (bs *baseScorer) setMetricType(metricType string) {
    bs.metricType = metricType
}

func (bs *baseScorer) getMetricType() string {
    return bs.metricType
}

type rrfScorer struct {
    baseScorer
    k float32
}

func (rs *rrfScorer) reScore(input *milvuspb.SearchResults) {
    index := 0
    for _, topk := range input.Results.GetTopks() {
        for i := int64(0); i < topk; i++ {
            input.Results.Scores[index] = 1 / (rs.k + float32(i+1))
            index++
        }
    }
}

func (rs *rrfScorer) scorerType() rankType {
    return rrfRankType
}

type weightedScorer struct {
    baseScorer
    weight    float32
    normScore bool
}

type activateFunc func(float32) float32

func (ws *weightedScorer) getActivateFunc() activateFunc {
    if !ws.normScore {
        return func(distance float32) float32 {
            return distance
        }
    }
    mUpper := strings.ToUpper(ws.getMetricType())
    isCosine := mUpper == strings.ToUpper(metric.COSINE)
    isIP := mUpper == strings.ToUpper(metric.IP)
    isBM25 := mUpper == strings.ToUpper(metric.BM25)
    if isCosine {
        f := func(distance float32) float32 {
            return (1 + distance) * 0.5
        }
        return f
    }

    if isIP {
        f := func(distance float32) float32 {
            return 0.5 + float32(math.Atan(float64(distance)))/math.Pi
        }
        return f
    }

    if isBM25 {
        f := func(distance float32) float32 {
            return 2 * float32(math.Atan(float64(distance))) / math.Pi
        }
        return f
    }

    f := func(distance float32) float32 {
        return 1.0 - 2*float32(math.Atan(float64(distance)))/math.Pi
    }
    return f
}

func (ws *weightedScorer) reScore(input *milvuspb.SearchResults) {
    activateF := ws.getActivateFunc()
    for i, distance := range input.Results.GetScores() {
        input.Results.Scores[i] = ws.weight * activateF(distance)
    }
}

func (ws *weightedScorer) scorerType() rankType {
    return weightedRankType
}

func NewReScorers(ctx context.Context, reqCnt int, rankParams []*commonpb.KeyValuePair) ([]reScorer, error) {
    if reqCnt == 0 {
        return []reScorer{}, nil
    }

    log := log.Ctx(ctx)
    res := make([]reScorer, reqCnt)
    rankTypeStr, err := funcutil.GetAttrByKeyFromRepeatedKV(RankTypeKey, rankParams)
    if err != nil {
        log.Info("rank strategy not specified, use rrf instead")
        // if not set rank strategy, use rrf rank as default
        for i := 0; i < reqCnt; i++ {
            res[i] = &rrfScorer{
                baseScorer: baseScorer{
                    scorerName: "rrf",
                },
                k: float32(defaultRRFParamsValue),
            }
        }
        return res, nil
    }

    if _, ok := rankTypeMap[rankTypeStr]; !ok {
        return nil, errors.Errorf("unsupported rank type %s", rankTypeStr)
    }

    paramStr, err := funcutil.GetAttrByKeyFromRepeatedKV(RankParamsKey, rankParams)
    if err != nil {
        return nil, errors.New(RankParamsKey + " not found in rank_params")
    }

    var params map[string]interface{}
    err = json.Unmarshal([]byte(paramStr), &params)
    if err != nil {
        return nil, err
    }

    switch rankTypeMap[rankTypeStr] {
    case rrfRankType:
        _, ok := params[RRFParamsKey]
        if !ok {
            return nil, errors.New(RRFParamsKey + " not found in rank_params")
        }
        var k float64
        if reflect.ValueOf(params[RRFParamsKey]).CanFloat() {
            k = reflect.ValueOf(params[RRFParamsKey]).Float()
        } else {
            return nil, errors.New("The type of rank param k should be float")
        }
        if k <= 0 || k >= maxRRFParamsValue {
            return nil, errors.New(fmt.Sprintf("The rank params k should be in range (0, %d)", maxRRFParamsValue))
        }
        log.Debug("rrf params", zap.Float64("k", k))
        for i := 0; i < reqCnt; i++ {
            res[i] = &rrfScorer{
                baseScorer: baseScorer{
                    scorerName: "rrf",
                },
                k: float32(k),
            }
        }
    case weightedRankType:
        if _, ok := params[WeightsParamsKey]; !ok {
            return nil, errors.New(WeightsParamsKey + " not found in rank_params")
        }
        // normalize scores by default
        normScore := true
        if _, ok := params[NormScoreKey]; ok {
            normScore = params[NormScoreKey].(bool)
        }
        weights := make([]float32, 0)
        switch reflect.TypeOf(params[WeightsParamsKey]).Kind() {
        case reflect.Slice:
            rs := reflect.ValueOf(params[WeightsParamsKey])
            for i := 0; i < rs.Len(); i++ {
                v := rs.Index(i).Elem()
                if v.CanFloat() {
                    weight := v.Float()
                    if weight < 0 || weight > 1 {
                        return nil, errors.New("rank param weight should be in range [0, 1]")
                    }
                    weights = append(weights, float32(weight))
                } else {
                    return nil, errors.New("The type of rank param weight should be float")
                }
            }
        default:
            return nil, errors.New("The weights param should be an array")
        }

        log.Debug("weights params", zap.Any("weights", weights), zap.Bool("norm_score", normScore))
        if reqCnt != len(weights) {
            return nil, merr.WrapErrParameterInvalid(fmt.Sprint(reqCnt), fmt.Sprint(len(weights)), "the length of weights param mismatch with ann search requests")
        }
        for i := 0; i < reqCnt; i++ {
            res[i] = &weightedScorer{
                baseScorer: baseScorer{
                    scorerName: "weighted",
                },
                weight:    weights[i],
                normScore: normScore,
            }
        }
    default:
        return nil, errors.Errorf("unsupported rank type %s", rankTypeStr)
    }

    return res, nil
}
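For reference, the scorers removed above compute the following; this is a restatement of the code, not new behaviour. Rank $i$ is 1-based within each sub-request, $d$ is the raw distance or similarity, $w$ is the per-request weight, and $k$ defaults to 60 via defaultRRFParamsValue:

$$
s^{\mathrm{rrf}}_i = \frac{1}{k + i},
\qquad
s^{\mathrm{weighted}} = w \cdot \hat d,
\quad
\hat d =
\begin{cases}
(1 + d)/2 & \text{COSINE} \\
1/2 + \arctan(d)/\pi & \text{IP} \\
2\arctan(d)/\pi & \text{BM25} \\
1 - 2\arctan(d)/\pi & \text{other metrics} \\
d & \text{norm\_score = false}
\end{cases}
$$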
@@ -1,146 +0,0 @@
package proxy

import (
    "context"
    "testing"

    "github.com/stretchr/testify/assert"

    "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
    "github.com/milvus-io/milvus/internal/json"
)

func TestRescorer(t *testing.T) {
    t.Run("default scorer", func(t *testing.T) {
        rescorers, err := NewReScorers(context.TODO(), 2, nil)
        assert.NoError(t, err)
        assert.Equal(t, 2, len(rescorers))
        assert.Equal(t, rrfRankType, rescorers[0].scorerType())
    })

    t.Run("rrf without param", func(t *testing.T) {
        params := make(map[string]float64)
        b, err := json.Marshal(params)
        assert.NoError(t, err)
        rankParams := []*commonpb.KeyValuePair{
            {Key: RankTypeKey, Value: "rrf"},
            {Key: RankParamsKey, Value: string(b)},
        }

        _, err = NewReScorers(context.TODO(), 2, rankParams)
        assert.Error(t, err)
        assert.Contains(t, err.Error(), "k not found in rank_params")
    })

    t.Run("rrf param out of range", func(t *testing.T) {
        params := make(map[string]float64)
        params[RRFParamsKey] = -1
        b, err := json.Marshal(params)
        assert.NoError(t, err)
        rankParams := []*commonpb.KeyValuePair{
            {Key: RankTypeKey, Value: "rrf"},
            {Key: RankParamsKey, Value: string(b)},
        }

        _, err = NewReScorers(context.TODO(), 2, rankParams)
        assert.Error(t, err)

        params[RRFParamsKey] = maxRRFParamsValue + 1
        b, err = json.Marshal(params)
        assert.NoError(t, err)
        rankParams = []*commonpb.KeyValuePair{
            {Key: RankTypeKey, Value: "rrf"},
            {Key: RankParamsKey, Value: string(b)},
        }

        _, err = NewReScorers(context.TODO(), 2, rankParams)
        assert.Error(t, err)
    })

    t.Run("rrf", func(t *testing.T) {
        params := make(map[string]float64)
        params[RRFParamsKey] = 61
        b, err := json.Marshal(params)
        assert.NoError(t, err)
        rankParams := []*commonpb.KeyValuePair{
            {Key: RankTypeKey, Value: "rrf"},
            {Key: RankParamsKey, Value: string(b)},
        }

        rescorers, err := NewReScorers(context.TODO(), 2, rankParams)
        assert.NoError(t, err)
        assert.Equal(t, 2, len(rescorers))
        assert.Equal(t, rrfRankType, rescorers[0].scorerType())
        assert.Equal(t, float32(61), rescorers[0].(*rrfScorer).k)
    })

    t.Run("weights without param", func(t *testing.T) {
        params := make(map[string][]float64)
        b, err := json.Marshal(params)
        assert.NoError(t, err)
        rankParams := []*commonpb.KeyValuePair{
            {Key: RankTypeKey, Value: "weighted"},
            {Key: RankParamsKey, Value: string(b)},
        }

        _, err = NewReScorers(context.TODO(), 2, rankParams)
        assert.Error(t, err)
        assert.Contains(t, err.Error(), "not found in rank_params")
    })

    t.Run("weights out of range", func(t *testing.T) {
        weights := []float64{1.2, 2.3}
        params := make(map[string][]float64)
        params[WeightsParamsKey] = weights
        b, err := json.Marshal(params)
        assert.NoError(t, err)
        rankParams := []*commonpb.KeyValuePair{
            {Key: RankTypeKey, Value: "weighted"},
            {Key: RankParamsKey, Value: string(b)},
        }

        _, err = NewReScorers(context.TODO(), 2, rankParams)
        assert.Error(t, err)
        assert.Contains(t, err.Error(), "rank param weight should be in range [0, 1]")
    })

    t.Run("weights with norm_score false", func(t *testing.T) {
        weights := []float64{0.5, 0.2}
        params := make(map[string]interface{})
        params[WeightsParamsKey] = weights
        params[NormScoreKey] = false
        b, err := json.Marshal(params)
        assert.NoError(t, err)
        rankParams := []*commonpb.KeyValuePair{
            {Key: RankTypeKey, Value: "weighted"},
            {Key: RankParamsKey, Value: string(b)},
        }

        rescorers, err := NewReScorers(context.TODO(), 2, rankParams)
        assert.NoError(t, err)
        assert.Equal(t, 2, len(rescorers))
        assert.Equal(t, weightedRankType, rescorers[0].scorerType())
        assert.Equal(t, float32(weights[0]), rescorers[0].(*weightedScorer).weight)
        assert.False(t, rescorers[0].(*weightedScorer).normScore)
    })

    t.Run("weights", func(t *testing.T) {
        weights := []float64{0.5, 0.2}
        params := make(map[string]interface{})
        params[WeightsParamsKey] = weights
        b, err := json.Marshal(params)
        assert.NoError(t, err)
        rankParams := []*commonpb.KeyValuePair{
            {Key: RankTypeKey, Value: "weighted"},
            {Key: RankParamsKey, Value: string(b)},
        }

        rescorers, err := NewReScorers(context.TODO(), 2, rankParams)
        assert.NoError(t, err)
        assert.Equal(t, 2, len(rescorers))
        assert.Equal(t, weightedRankType, rescorers[0].scorerType())
        assert.Equal(t, float32(weights[0]), rescorers[0].(*weightedScorer).weight)
        // normalize scores by default
        assert.True(t, rescorers[0].(*weightedScorer).normScore)
    })
}
@@ -3,8 +3,6 @@ package proxy
import (
    "context"
    "fmt"
    "math"
    "sort"

    "github.com/cockroachdb/errors"
    "go.uber.org/zap"
@@ -432,21 +430,6 @@ func reduceSearchResultDataNoGroupBy(ctx context.Context, subSearchResultData []
    return ret, nil
}

func rankSearchResultData(ctx context.Context,
    nq int64,
    params *rankParams,
    pkType schemapb.DataType,
    searchResults []*milvuspb.SearchResults,
    groupByFieldID int64,
    groupSize int64,
    groupScorer func(group *Group) error,
) (*milvuspb.SearchResults, error) {
    if groupByFieldID > 0 {
        return rankSearchResultDataByGroup(ctx, nq, params, pkType, searchResults, groupScorer, groupSize)
    }
    return rankSearchResultDataByPk(ctx, nq, params, pkType, searchResults)
}

func compareKey(keyI interface{}, keyJ interface{}) bool {
    switch keyI.(type) {
    case int64:
@@ -457,213 +440,6 @@ func compareKey(keyI interface{}, keyJ interface{}) bool {
    return false
}

func GetGroupScorer(scorerType string) (func(group *Group) error, error) {
    switch scorerType {
    case MaxScorer:
        return func(group *Group) error {
            group.finalScore = group.maxScore
            return nil
        }, nil
    case SumScorer:
        return func(group *Group) error {
            group.finalScore = group.sumScore
            return nil
        }, nil
    case AvgScorer:
        return func(group *Group) error {
            if len(group.idList) == 0 {
                return merr.WrapErrParameterInvalid(1, len(group.idList),
                    "input group for score must have at least one id, must be sth wrong within code")
            }
            group.finalScore = group.sumScore / float32(len(group.idList))
            return nil
        }, nil
    default:
        return nil, merr.WrapErrParameterInvalidMsg("input group scorer type: %s is not supported!", scorerType)
    }
}

type Group struct {
    idList     []interface{}
    scoreList  []float32
    groupVal   interface{}
    maxScore   float32
    sumScore   float32
    finalScore float32
}

func rankSearchResultDataByGroup(ctx context.Context,
    nq int64,
    params *rankParams,
    pkType schemapb.DataType,
    searchResults []*milvuspb.SearchResults,
    groupScorer func(group *Group) error,
    groupSize int64,
) (*milvuspb.SearchResults, error) {
    tr := timerecord.NewTimeRecorder("rankSearchResultDataByGroup")
    defer func() {
        tr.CtxElapse(ctx, "done")
    }()
    offset, limit, roundDecimal := params.offset, params.limit, params.roundDecimal
    // in the context of group by, the meaning for offset/limit/top refers to related numbers of group
    groupTopK := limit + offset
    log.Ctx(ctx).Debug("rankSearchResultDataByGroup",
        zap.Int("len(searchResults)", len(searchResults)),
        zap.Int64("nq", nq),
        zap.Int64("offset", offset),
        zap.Int64("limit", limit))

    var ret *milvuspb.SearchResults
    if ret = initSearchResults(nq, limit); len(searchResults) == 0 {
        return ret, nil
    }

    totalCount := limit * groupSize
    if err := setupIdListForSearchResult(ret, pkType, totalCount); err != nil {
        return ret, err
    }

    type accumulateIDGroupVal struct {
        accumulatedScore float32
        groupVal         interface{}
    }

    accumulatedScores := make([]map[interface{}]*accumulateIDGroupVal, nq)
    for i := int64(0); i < nq; i++ {
        accumulatedScores[i] = make(map[interface{}]*accumulateIDGroupVal)
    }
    groupByDataType := searchResults[0].GetResults().GetGroupByFieldValue().GetType()
    for _, result := range searchResults {
        scores := result.GetResults().GetScores()
        start := 0
        // milvus has limits for the value range of nq and limit
        // no matter on 32-bit and 64-bit platform, converting nq and topK into int is safe
        groupByValIterator := typeutil.GetDataIterator(result.GetResults().GetGroupByFieldValue())
        for i := 0; i < int(nq); i++ {
            realTopK := int(result.GetResults().Topks[i])
            for j := start; j < start+realTopK; j++ {
                id := typeutil.GetPK(result.GetResults().GetIds(), int64(j))
                groupByVal := groupByValIterator(j)
                if accumulatedScores[i][id] != nil {
                    accumulatedScores[i][id].accumulatedScore += scores[j]
                } else {
                    accumulatedScores[i][id] = &accumulateIDGroupVal{accumulatedScore: scores[j], groupVal: groupByVal}
                }
            }
            start += realTopK
        }
    }

    gpFieldBuilder, err := typeutil.NewFieldDataBuilder(groupByDataType, true, int(limit))
    if err != nil {
        return ret, err
    }
    for i := int64(0); i < nq; i++ {
        idSet := accumulatedScores[i]
        keys := make([]interface{}, 0)
        for key := range idSet {
            keys = append(keys, key)
        }

        // sort id by score
        big := func(i, j int) bool {
            scoreItemI := idSet[keys[i]]
            scoreItemJ := idSet[keys[j]]
            if scoreItemI.accumulatedScore == scoreItemJ.accumulatedScore {
                return compareKey(keys[i], keys[j])
            }
            return scoreItemI.accumulatedScore > scoreItemJ.accumulatedScore
        }
        sort.Slice(keys, big)

        // separate keys into buckets according to groupVal
        buckets := make(map[interface{}]*Group)
        for _, key := range keys {
            scoreItem := idSet[key]
            groupVal := scoreItem.groupVal
            if buckets[groupVal] == nil {
                buckets[groupVal] = &Group{
                    idList:    make([]interface{}, 0),
                    scoreList: make([]float32, 0),
                    groupVal:  groupVal,
                }
            }
            if int64(len(buckets[groupVal].idList)) >= groupSize {
                // only consider group size results in each group
                continue
            }
            buckets[groupVal].idList = append(buckets[groupVal].idList, key)
            buckets[groupVal].scoreList = append(buckets[groupVal].scoreList, scoreItem.accumulatedScore)
            if scoreItem.accumulatedScore > buckets[groupVal].maxScore {
                buckets[groupVal].maxScore = scoreItem.accumulatedScore
            }
            buckets[groupVal].sumScore += scoreItem.accumulatedScore
        }
        if int64(len(buckets)) <= offset {
            ret.Results.Topks = append(ret.Results.Topks, 0)
            continue
        }

        groupList := make([]*Group, len(buckets))
        idx := 0
        for _, group := range buckets {
            groupScorer(group)
            groupList[idx] = group
            idx += 1
        }
        sort.Slice(groupList, func(i, j int) bool {
            if groupList[i].finalScore == groupList[j].finalScore {
                if len(groupList[i].idList) == len(groupList[j].idList) {
                    // if final score and size of group are both equal
                    // choose the group with smaller first key
                    // here, it's guaranteed all group having at least one id in the idList
                    return compareKey(groupList[i].idList[0], groupList[j].idList[0])
                }
                // choose the larger group when scores are equal
                return len(groupList[i].idList) > len(groupList[j].idList)
            }
            return groupList[i].finalScore > groupList[j].finalScore
        })

        if int64(len(groupList)) > groupTopK {
            groupList = groupList[:groupTopK]
        }
        returnedRowNum := 0
        for index := int(offset); index < len(groupList); index++ {
            group := groupList[index]
            for i, score := range group.scoreList {
                // idList and scoreList must have same length
                typeutil.AppendPKs(ret.Results.Ids, group.idList[i])
                if roundDecimal != -1 {
                    multiplier := math.Pow(10.0, float64(roundDecimal))
                    score = float32(math.Floor(float64(score)*multiplier+0.5) / multiplier)
                }
                ret.Results.Scores = append(ret.Results.Scores, score)
                gpFieldBuilder.Add(group.groupVal)
            }
            returnedRowNum += len(group.idList)
        }
        ret.Results.Topks = append(ret.Results.Topks, int64(returnedRowNum))
    }

    ret.Results.GroupByFieldValue = gpFieldBuilder.Build()
    return ret, nil
}

func initSearchResults(nq int64, limit int64) *milvuspb.SearchResults {
    return &milvuspb.SearchResults{
        Status: merr.Success(),
        Results: &schemapb.SearchResultData{
            NumQueries: nq,
            TopK:       limit,
            FieldsData: make([]*schemapb.FieldData, 0),
            Scores:     []float32{},
            Ids:        &schemapb.IDs{},
            Topks:      []int64{},
        },
    }
}

func setupIdListForSearchResult(searchResult *milvuspb.SearchResults, pkType schemapb.DataType, capacity int64) error {
    switch pkType {
    case schemapb.DataType_Int64:
@@ -684,94 +460,6 @@ func setupIdListForSearchResult(searchResult *milvuspb.SearchResults, pkType sch
    return nil
}

func rankSearchResultDataByPk(ctx context.Context,
    nq int64,
    params *rankParams,
    pkType schemapb.DataType,
    searchResults []*milvuspb.SearchResults,
) (*milvuspb.SearchResults, error) {
    tr := timerecord.NewTimeRecorder("rankSearchResultDataByPk")
    defer func() {
        tr.CtxElapse(ctx, "done")
    }()

    offset, limit, roundDecimal := params.offset, params.limit, params.roundDecimal
    topk := limit + offset
    log.Ctx(ctx).Debug("rankSearchResultDataByPk",
        zap.Int("len(searchResults)", len(searchResults)),
        zap.Int64("nq", nq),
        zap.Int64("offset", offset),
        zap.Int64("limit", limit))

    var ret *milvuspb.SearchResults
    if ret = initSearchResults(nq, limit); len(searchResults) == 0 {
        return ret, nil
    }

    if err := setupIdListForSearchResult(ret, pkType, limit); err != nil {
        return ret, nil
    }

    // []map[id]score
    accumulatedScores := make([]map[interface{}]float32, nq)
    for i := int64(0); i < nq; i++ {
        accumulatedScores[i] = make(map[interface{}]float32)
    }

    for _, result := range searchResults {
        scores := result.GetResults().GetScores()
        start := int64(0)
        for i := int64(0); i < nq; i++ {
            realTopk := result.GetResults().Topks[i]
            for j := start; j < start+realTopk; j++ {
                id := typeutil.GetPK(result.GetResults().GetIds(), j)
                accumulatedScores[i][id] += scores[j]
            }
            start += realTopk
        }
    }

    for i := int64(0); i < nq; i++ {
        idSet := accumulatedScores[i]
        keys := make([]interface{}, 0)
        for key := range idSet {
            keys = append(keys, key)
        }
        if int64(len(keys)) <= offset {
            ret.Results.Topks = append(ret.Results.Topks, 0)
            continue
        }

        // sort id by score
        big := func(i, j int) bool {
            if idSet[keys[i]] == idSet[keys[j]] {
                return compareKey(keys[i], keys[j])
            }
            return idSet[keys[i]] > idSet[keys[j]]
        }

        sort.Slice(keys, big)

        if int64(len(keys)) > topk {
            keys = keys[:topk]
        }

        // set real topk
        ret.Results.Topks = append(ret.Results.Topks, int64(len(keys))-offset)
        // append id and score
        for index := offset; index < int64(len(keys)); index++ {
            typeutil.AppendPKs(ret.Results.Ids, keys[index])
            score := idSet[keys[index]]
            if roundDecimal != -1 {
                multiplier := math.Pow(10.0, float64(roundDecimal))
                score = float32(math.Floor(float64(score)*multiplier+0.5) / multiplier)
            }
            ret.Results.Scores = append(ret.Results.Scores, score)
        }
    }
    return ret, nil
}

func fillInEmptyResult(numQueries int64) *milvuspb.SearchResults {
    return &milvuspb.SearchResults{
        Status: merr.Success("search result is empty"),
@@ -6,7 +6,6 @@ import (

    "github.com/stretchr/testify/suite"

    "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
    "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
)

@@ -52,86 +51,6 @@ func genTestDataSearchResultsData() []*schemapb.SearchResultData {
    return []*schemapb.SearchResultData{searchResultData1, searchResultData2}
}

func (struts *SearchReduceUtilTestSuite) TestRankByGroup() {
    data := genTestDataSearchResultsData()
    searchResults := []*milvuspb.SearchResults{
        {Results: data[0]},
        {Results: data[1]},
    }

    nq := int64(1)
    limit := int64(3)
    offset := int64(0)
    roundDecimal := int64(1)
    groupSize := int64(3)
    groupByFieldId := int64(101)
    rankParams := &rankParams{limit: limit, offset: offset, roundDecimal: roundDecimal}

    {
        // test for sum group scorer
        scorerType := "sum"
        groupScorer, _ := GetGroupScorer(scorerType)
        rankedRes, err := rankSearchResultData(context.Background(), nq, rankParams, schemapb.DataType_VarChar, searchResults, groupByFieldId, groupSize, groupScorer)
        struts.NoError(err)
        struts.Equal([]string{"5", "2", "3", "17", "12", "13", "7", "15", "1"}, rankedRes.GetResults().GetIds().GetStrId().Data)
        struts.Equal([]float32{0.5, 0.4, 0.4, 0.7, 0.3, 0.3, 0.6, 0.4, 0.3}, rankedRes.GetResults().GetScores())
        struts.Equal([]string{"bbb", "bbb", "bbb", "www", "www", "www", "aaa", "aaa", "aaa"}, rankedRes.GetResults().GetGroupByFieldValue().GetScalars().GetStringData().Data)
    }

    {
        // test for max group scorer
        scorerType := "max"
        groupScorer, _ := GetGroupScorer(scorerType)
        rankedRes, err := rankSearchResultData(context.Background(), nq, rankParams, schemapb.DataType_VarChar, searchResults, groupByFieldId, groupSize, groupScorer)
        struts.NoError(err)
        struts.Equal([]string{"17", "12", "13", "7", "15", "1", "5", "2", "3"}, rankedRes.GetResults().GetIds().GetStrId().Data)
        struts.Equal([]float32{0.7, 0.3, 0.3, 0.6, 0.4, 0.3, 0.5, 0.4, 0.4}, rankedRes.GetResults().GetScores())
        struts.Equal([]string{"www", "www", "www", "aaa", "aaa", "aaa", "bbb", "bbb", "bbb"}, rankedRes.GetResults().GetGroupByFieldValue().GetScalars().GetStringData().Data)
    }

    {
        // test for avg group scorer
        scorerType := "avg"
        groupScorer, _ := GetGroupScorer(scorerType)
        rankedRes, err := rankSearchResultData(context.Background(), nq, rankParams, schemapb.DataType_VarChar, searchResults, groupByFieldId, groupSize, groupScorer)
        struts.NoError(err)
        struts.Equal([]string{"5", "2", "3", "17", "12", "13", "7", "15", "1"}, rankedRes.GetResults().GetIds().GetStrId().Data)
        struts.Equal([]float32{0.5, 0.4, 0.4, 0.7, 0.3, 0.3, 0.6, 0.4, 0.3}, rankedRes.GetResults().GetScores())
        struts.Equal([]string{"bbb", "bbb", "bbb", "www", "www", "www", "aaa", "aaa", "aaa"}, rankedRes.GetResults().GetGroupByFieldValue().GetScalars().GetStringData().Data)
    }

    {
        // test for offset for ranking group
        scorerType := "avg"
        groupScorer, _ := GetGroupScorer(scorerType)
        rankParams.offset = 2
        rankedRes, err := rankSearchResultData(context.Background(), nq, rankParams, schemapb.DataType_VarChar, searchResults, groupByFieldId, groupSize, groupScorer)
        struts.NoError(err)
        struts.Equal([]string{"7", "15", "1", "4", "6", "14"}, rankedRes.GetResults().GetIds().GetStrId().Data)
        struts.Equal([]float32{0.6, 0.4, 0.3, 0.5, 0.3, 0.3}, rankedRes.GetResults().GetScores())
        struts.Equal([]string{"aaa", "aaa", "aaa", "ccc", "ccc", "ccc"}, rankedRes.GetResults().GetGroupByFieldValue().GetScalars().GetStringData().Data)
    }

    {
        // test for offset exceeding the count of final groups
        scorerType := "avg"
        groupScorer, _ := GetGroupScorer(scorerType)
        rankParams.offset = 4
        rankedRes, err := rankSearchResultData(context.Background(), nq, rankParams, schemapb.DataType_VarChar, searchResults, groupByFieldId, groupSize, groupScorer)
        struts.NoError(err)
        struts.Equal([]string{}, rankedRes.GetResults().GetIds().GetStrId().Data)
        struts.Equal([]float32{}, rankedRes.GetResults().GetScores())
    }

    {
        // test for invalid group scorer
        scorerType := "xxx"
        groupScorer, err := GetGroupScorer(scorerType)
        struts.Error(err)
        struts.Nil(groupScorer)
    }
}

func (struts *SearchReduceUtilTestSuite) TestReduceSearchResult() {
    data := genTestDataSearchResultsData()

@@ -537,6 +537,14 @@ func parseRankParams(rankParamsPair []*commonpb.KeyValuePair, schema *schemapb.C
    }, nil
}

func getGroupScorerStr(params []*commonpb.KeyValuePair) string {
    groupScorerStr, err := funcutil.GetAttrByKeyFromRepeatedKV(RankGroupScorer, params)
    if err != nil {
        groupScorerStr = MaxScorer
    }
    return groupScorerStr
}

func convertHybridSearchToSearch(req *milvuspb.HybridSearchRequest) *milvuspb.SearchRequest {
    ret := &milvuspb.SearchRequest{
        Base: req.GetBase(),
@@ -88,10 +88,6 @@ type searchTask struct {
    queryInfos      []*planpb.QueryInfo
    relatedDataSize int64

    // Will be deprecated, use functionScore after milvus 2.6
    reScorers   []reScorer
    groupScorer func(group *Group) error

    // New reranker functions
    functionScore *rerank.FunctionScore
    rankParams    *rankParams
@@ -378,22 +374,10 @@ func (t *searchTask) initAdvancedSearchRequest(ctx context.Context) error {
            return err
        }
    } else {
        t.reScorers, err = NewReScorers(ctx, len(t.request.GetSubReqs()), t.request.GetSearchParams())
        if err != nil {
            log.Info("generate reScorer failed", zap.Any("params", t.request.GetSearchParams()), zap.Error(err))
        if t.functionScore, err = rerank.NewFunctionScoreWithlegacy(t.schema.CollectionSchema, t.request.GetSearchParams()); err != nil {
            log.Warn("Failed to create function by legacy info", zap.Error(err))
            return err
        }

        // set up groupScorer for hybridsearch+groupBy
        groupScorerStr, err := funcutil.GetAttrByKeyFromRepeatedKV(RankGroupScorer, t.request.GetSearchParams())
        if err != nil {
            groupScorerStr = MaxScorer
        }
        groupScorer, err := GetGroupScorer(groupScorerStr)
        if err != nil {
            return err
        }
        t.groupScorer = groupScorer
    }

    t.needRequery = len(t.request.OutputFields) > 0 || len(t.functionScore.GetAllInputFieldNames()) > 0
@@ -544,22 +528,12 @@ func (t *searchTask) advancedPostProcess(ctx context.Context, span trace.Span, t
            return err
        }

        if t.functionScore == nil {
            t.reScorers[index].setMetricType(subMetricType)
            t.reScorers[index].reScore(result)
        }
        searchMetrics = append(searchMetrics, subMetricType)
        multipleMilvusResults[index] = result
    }

    if t.functionScore == nil {
        if err := t.rank(ctx, span, multipleMilvusResults); err != nil {
            return err
        }
    } else {
        if err := t.hybridSearchRank(ctx, span, multipleMilvusResults, searchMetrics); err != nil {
            return err
        }
    if err := t.hybridSearchRank(ctx, span, multipleMilvusResults, searchMetrics); err != nil {
        return err
    }

    t.result.Results.FieldsData = lo.Filter(t.result.Results.FieldsData, func(fieldData *schemapb.FieldData, i int) bool {
@@ -583,43 +557,6 @@ func (t *searchTask) fillResult() {
    t.fillInFieldInfo()
}

// TODO: Old version rerank: rrf/weighted, subsequent unified rerank implementation
func (t *searchTask) rank(ctx context.Context, span trace.Span, multipleMilvusResults []*milvuspb.SearchResults) error {
    primaryFieldSchema, err := t.schema.GetPkField()
    if err != nil {
        log.Warn("failed to get primary field schema", zap.Error(err))
        return err
    }
    if t.result, err = rankSearchResultData(ctx, t.SearchRequest.GetNq(),
        t.rankParams,
        primaryFieldSchema.GetDataType(),
        multipleMilvusResults,
        t.SearchRequest.GetGroupByFieldId(),
        t.SearchRequest.GetGroupSize(),
        t.groupScorer); err != nil {
        log.Warn("rank search result failed", zap.Error(err))
        return err
    }

    if t.needRequery {
        if t.requeryFunc == nil {
            t.requeryFunc = requeryImpl
        }
        queryResult, err := t.requeryFunc(t, span, t.result.Results.Ids, t.translatedOutputFields)
        if err != nil {
            log.Warn("failed to requery", zap.Error(err))
            return err
        }
        fields, err := t.reorganizeRequeryResults(ctx, queryResult.GetFieldsData(), []*schemapb.IDs{t.result.Results.Ids})
        if err != nil {
            return err
        }
        t.result.Results.FieldsData = fields[0]
    }

    return nil
}

func mergeIDs(idsList []*schemapb.IDs) (*schemapb.IDs, int) {
    uniqueIDs := &schemapb.IDs{}
    count := 0
@@ -659,10 +596,10 @@ func (t *searchTask) hybridSearchRank(ctx context.Context, span trace.Span, mult
    processRerank := func(ctx context.Context, results []*milvuspb.SearchResults) (*milvuspb.SearchResults, error) {
        ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-call-rerank-function-udf")
        defer sp.End()

        groupScorerStr := getGroupScorerStr(t.request.GetSearchParams())
        params := rerank.NewSearchParams(
            t.Nq, t.rankParams.limit, t.rankParams.offset, t.rankParams.roundDecimal,
            t.rankParams.groupByFieldId, t.rankParams.groupSize, t.rankParams.strictGroupSize, searchMetrics,
            t.rankParams.groupByFieldId, t.rankParams.groupSize, t.rankParams.strictGroupSize, groupScorerStr, searchMetrics,
        )
        return t.functionScore.Process(ctx, params, results)
    }
@@ -703,6 +640,7 @@
    for i := 0; i < len(multipleMilvusResults); i++ {
        multipleMilvusResults[i].Results.FieldsData = fields[i]
    }

    if t.result, err = processRerank(ctx, multipleMilvusResults); err != nil {
        return err
    }
@@ -838,8 +776,9 @@ func (t *searchTask) searchPostProcess(ctx context.Context, span trace.Span, toR
    {
        ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-call-rerank-function-udf")
        defer sp.End()
        groupScorerStr := getGroupScorerStr(t.request.GetSearchParams())
        params := rerank.NewSearchParams(t.Nq, t.SearchRequest.GetTopk(), t.SearchRequest.GetOffset(),
            t.queryInfos[0].RoundDecimal, t.queryInfos[0].GroupByFieldId, t.queryInfos[0].GroupSize, t.queryInfos[0].StrictGroupSize, []string{metricType})
            t.queryInfos[0].RoundDecimal, t.queryInfos[0].GroupByFieldId, t.queryInfos[0].GroupSize, t.queryInfos[0].StrictGroupSize, groupScorerStr, []string{metricType})
        // rank only returns id and score
        if t.result, err = t.functionScore.Process(ctx, params, []*milvuspb.SearchResults{result}); err != nil {
            return err

@@ -1000,7 +1000,7 @@ func TestSearchTask_PreExecute(t *testing.T) {
        require.Equal(t, typeutil.ZeroTimestamp, st.TimeoutTimestamp)
        enqueueTs := uint64(100000)
        st.SetTs(enqueueTs)
        assert.ErrorContains(t, st.PreExecute(ctx), "Current rerank does not support grouping search")
        assert.NoError(t, st.PreExecute(ctx))
    })
}

@@ -77,9 +77,6 @@ const (
    // DefaultStringIndexType name of default index type for varChar/string field
    DefaultStringIndexType = indexparamcheck.IndexINVERTED

    defaultRRFParamsValue = 60
    maxRRFParamsValue     = 16384
)

var logger = log.L().WithOptions(zap.Fields(zap.String("role", typeutil.ProxyRole)))
@@ -427,7 +424,6 @@ func validateMaxCapacityPerRow(collectionName string, field *schemapb.FieldSchem
    if !exist {
        return fmt.Errorf("type param(max_capacity) should be specified for array field %s of collection %s", field.GetName(), collectionName)
    }

    return nil
}

@@ -1,3 +1,6 @@
//go:build test
// +build test

/*
 * # Licensed to the LF AI & Data foundation under one
 * # or more contributor license agreements. See the NOTICE file
@@ -22,11 +25,13 @@ import (
    "context"
    "encoding/json"
    "io"
    "math"
    "net/http"
    "net/http/httptest"

    "github.com/aws/aws-sdk-go-v2/service/bedrockruntime"

    "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
    "github.com/milvus-io/milvus/internal/util/function/models/ali"
    "github.com/milvus-io/milvus/internal/util/function/models/cohere"
    "github.com/milvus-io/milvus/internal/util/function/models/openai"
@@ -34,6 +39,7 @@
    "github.com/milvus-io/milvus/internal/util/function/models/tei"
    "github.com/milvus-io/milvus/internal/util/function/models/vertexai"
    "github.com/milvus-io/milvus/internal/util/function/models/voyageai"
    "github.com/milvus-io/milvus/pkg/v2/util/testutils"
)

const TestModel string = "TestModel"
@@ -247,3 +253,56 @@ func (c *MockBedrockClient) InvokeModel(ctx context.Context, params *bedrockrunt
    body, _ := json.Marshal(resp)
    return &bedrockruntime.InvokeModelOutput{Body: body}, nil
}

func GenSearchResultData(nq int64, topk int64, dType schemapb.DataType, fieldName string, fieldId int64) *schemapb.SearchResultData {
    tops := make([]int64, nq)
    for i := 0; i < int(nq); i++ {
        tops[i] = topk
    }
    fieldsData := []*schemapb.FieldData{}
    if fieldName != "" {
        fieldsData = []*schemapb.FieldData{testutils.GenerateScalarFieldData(dType, fieldName, int(nq*topk))}
        fieldsData[0].FieldId = fieldId
    }

    data := &schemapb.SearchResultData{
        NumQueries: nq,
        TopK:       topk,
        Scores:     testutils.GenerateFloat32Array(int(nq * topk)),
        Ids: &schemapb.IDs{
            IdField: &schemapb.IDs_IntId{
                IntId: &schemapb.LongArray{
                    Data: testutils.GenerateInt64Array(int(nq * topk)),
                },
            },
        },
        Topks:      tops,
        FieldsData: fieldsData,
    }
    return data
}

func GenSearchResultDataWithGrouping(nq int64, topk int64, dType schemapb.DataType, fieldName string, fieldId int64, groupingName string, groupingId int64, groupSize int64) *schemapb.SearchResultData {
    data := GenSearchResultData(nq, topk*groupSize, dType, fieldName, fieldId)
    values := make([]int64, 0)
    for i := int64(0); i < nq*topk*groupSize; i += groupSize {
        for j := int64(0); j < groupSize; j++ {
            values = append(values, i)
        }
    }
    groupingField := testutils.GenerateScalarFieldDataWithValue(schemapb.DataType_Int64, groupingName, groupingId, values)
    data.GroupByFieldValue = groupingField
    return data
}

func FloatsAlmostEqual(a, b []float32, epsilon float32) bool {
    if len(a) != len(b) {
        return false
    }
    for i := range a {
        if float32(math.Abs(float64(a[i]-b[i]))) > epsilon {
            return false
        }
    }
    return true
}

@@ -55,7 +55,7 @@ type DecayFunction[T PKType, R int32 | int64 | float32 | float64] struct {
}

func newDecayFunction(collSchema *schemapb.CollectionSchema, funcSchema *schemapb.FunctionSchema) (Reranker, error) {
    base, err := newRerankBase(collSchema, funcSchema, decayFunctionName, false)
    base, err := newRerankBase(collSchema, funcSchema, decayFunctionName, true)
    if err != nil {
        return nil, err
    }
@@ -168,7 +168,7 @@ func toGreaterScore(score float32, metricType string) float32 {
    }
}

func (decay *DecayFunction[T, R]) processOneSearchData(ctx context.Context, searchParams *SearchParams, cols []*columns) *IDScores[T] {
func (decay *DecayFunction[T, R]) processOneSearchData(ctx context.Context, searchParams *SearchParams, cols []*columns, idGroup map[any]any) (*IDScores[T], error) {
    srcScores := maxMerge[T](cols)
    decayScores := map[T]float32{}
    for _, col := range cols {
@@ -186,7 +186,10 @@ func (decay *DecayFunction[T, R]) processOneSearchData(ctx context.Context, sear
    for id := range decayScores {
        decayScores[id] = decayScores[id] * srcScores[id]
    }
    return newIDScores(decayScores, searchParams)
    if searchParams.isGrouping() {
        return newGroupingIDScores(decayScores, searchParams, idGroup)
    }
    return newIDScores(decayScores, searchParams), nil
}

func (decay *DecayFunction[T, R]) Process(ctx context.Context, searchParams *SearchParams, inputs *rerankInputs) (*rerankOutputs, error) {
@@ -198,7 +201,10 @@ func (decay *DecayFunction[T, R]) Process(ctx context.Context, searchParams *Sea
            col.scores[j] = toGreaterScore(score, metricType)
        }
    }
    idScore := decay.processOneSearchData(ctx, searchParams, cols)
    idScore, err := decay.processOneSearchData(ctx, searchParams, cols, inputs.idGroupValue)
    if err != nil {
        return nil, err
    }
    appendResult(outputs, idScore.ids, idScore.scores)
    }
    return outputs, nil

@@ -27,7 +27,7 @@ import (

    "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
    "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
    "github.com/milvus-io/milvus/pkg/v2/util/testutils"
    "github.com/milvus-io/milvus/internal/util/function"
)

func TestDecayFunction(t *testing.T) {
@@ -260,8 +260,8 @@ func (s *DecayFunctionSuite) TestRerankProcess() {
    nq := int64(1)
    f, err := newDecayFunction(schema, functionSchema)
    s.NoError(err)
    inputs, _ := newRerankInputs([]*schemapb.SearchResultData{}, f.GetInputFieldIDs())
    ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false, []string{"COSINE"}}, inputs)
    inputs, _ := newRerankInputs([]*schemapb.SearchResultData{}, f.GetInputFieldIDs(), false)
    ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE"}), inputs)
    s.NoError(err)
    s.Equal(int64(3), ret.searchResultData.TopK)
    s.Equal([]int64{}, ret.searchResultData.Topks)
@@ -271,10 +271,10 @@
    {
        nq := int64(1)
        f, err := newDecayFunction(schema, functionSchema)
        data := genSearchResultData(nq, 10, schemapb.DataType_Int64, "noExist", 1000)
        data := function.GenSearchResultData(nq, 10, schemapb.DataType_Int64, "noExist", 1000)
        s.NoError(err)

        _, err = newRerankInputs([]*schemapb.SearchResultData{data}, f.GetInputFieldIDs())
        _, err = newRerankInputs([]*schemapb.SearchResultData{data}, f.GetInputFieldIDs(), false)
        s.ErrorContains(err, "Search reaults mismatch rerank inputs")
    }

@@ -289,9 +289,9 @@
    nq := int64(1)
    f, err := newDecayFunction(schema, functionSchema)
    s.NoError(err)
    data := genSearchResultData(nq, 10, schemapb.DataType_Int64, "ts", 102)
    inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data}, f.GetInputFieldIDs())
    ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false, []string{"COSINE"}}, inputs)
    data := function.GenSearchResultData(nq, 10, schemapb.DataType_Int64, "ts", 102)
    inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data}, f.GetInputFieldIDs(), false)
    ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE"}), inputs)
    s.NoError(err)
    s.Equal([]int64{3}, ret.searchResultData.Topks)
    s.Equal(int64(3), ret.searchResultData.TopK)
@@ -302,9 +302,9 @@
    nq := int64(3)
    f, err := newDecayFunction(schema, functionSchema)
    s.NoError(err)
    data := genSearchResultData(nq, 10, schemapb.DataType_Int64, "ts", 102)
    inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data}, f.GetInputFieldIDs())
    ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false, []string{"COSINE"}}, inputs)
    data := function.GenSearchResultData(nq, 10, schemapb.DataType_Int64, "ts", 102)
    inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data}, f.GetInputFieldIDs(), false)
    ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE"}), inputs)
    s.NoError(err)
    s.Equal([]int64{3, 3, 3}, ret.searchResultData.Topks)
    s.Equal(int64(3), ret.searchResultData.TopK)
@@ -329,11 +329,11 @@
    f, err := newDecayFunction(schema, functionSchema2)
    s.NoError(err)
    // ts/id data: 0 - 9
    data1 := genSearchResultData(nq, 10, schemapb.DataType_Int64, "ts", 102)
    data1 := function.GenSearchResultData(nq, 10, schemapb.DataType_Int64, "ts", 102)
    // empty
    data2 := genSearchResultData(nq, 0, schemapb.DataType_Int64, "ts", 102)
    inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data1, data2}, f.GetInputFieldIDs())
    ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false, []string{"COSINE", "COSINE"}}, inputs)
    data2 := function.GenSearchResultData(nq, 0, schemapb.DataType_Int64, "ts", 102)
    inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data1, data2}, f.GetInputFieldIDs(), false)
    ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE", "COSINE"}), inputs)
    s.NoError(err)
    s.Equal([]int64{3}, ret.searchResultData.Topks)
    s.Equal(int64(3), ret.searchResultData.TopK)
@@ -345,11 +345,12 @@
    f, err := newDecayFunction(schema, functionSchema2)
    s.NoError(err)
    // ts/id data: 0 - 9
    data1 := genSearchResultData(nq, 10, schemapb.DataType_Int64, "ts", 102)
    data1 := function.GenSearchResultData(nq, 10, schemapb.DataType_Int64, "ts", 102)
    // ts/id data: 0 - 3
    data2 := genSearchResultData(nq, 4, schemapb.DataType_Int64, "ts", 102)
    inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data1, data2}, f.GetInputFieldIDs())
    ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false, []string{"COSINE", "COSINE"}}, inputs)
    data2 := function.GenSearchResultData(nq, 4, schemapb.DataType_Int64, "ts", 102)
    inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data1, data2}, f.GetInputFieldIDs(), false)

    ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE", "COSINE"}), inputs)
    s.NoError(err)
    s.Equal([]int64{3}, ret.searchResultData.Topks)
    s.Equal(int64(3), ret.searchResultData.TopK)
@@ -363,13 +364,13 @@
    // nq1 ts/id data: 0 - 9
    // nq2 ts/id data: 10 - 19
    // nq3 ts/id data: 20 - 29
    data1 := genSearchResultData(nq, 10, schemapb.DataType_Int64, "ts", 102)
    data1 := function.GenSearchResultData(nq, 10, schemapb.DataType_Int64, "ts", 102)
    // nq1 ts/id data: 0 - 3
    // nq2 ts/id data: 4 - 7
    // nq3 ts/id data: 8 - 11
    data2 := genSearchResultData(nq, 4, schemapb.DataType_Int64, "ts", 102)
    inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data1, data2}, f.GetInputFieldIDs())
    ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, 1, -1, 1, false, []string{"COSINE", "COSINE"}}, inputs)
    data2 := function.GenSearchResultData(nq, 4, schemapb.DataType_Int64, "ts", 102)
    inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data1, data2}, f.GetInputFieldIDs(), false)
    ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, 1, -1, 1, false, "", []string{"COSINE", "COSINE"}), inputs)
    s.NoError(err)
    s.Equal([]int64{3, 3, 3}, ret.searchResultData.Topks)
    s.Equal(int64(3), ret.searchResultData.TopK)
@@ -390,26 +391,3 @@ func (s *DecayFunctionSuite) TestDecay() {
    s.Equal(linearDecay(0, 1, 0.5, 5, 5), 1.0)
    s.Less(linearDecay(0, 1, 0.5, 5, 6), 1.0)
}

func genSearchResultData(nq int64, topk int64, dType schemapb.DataType, fieldName string, fieldId int64) *schemapb.SearchResultData {
    tops := make([]int64, nq)
    for i := 0; i < int(nq); i++ {
        tops[i] = topk
    }
    data := &schemapb.SearchResultData{
        NumQueries: nq,
        TopK:       topk,
        Scores:     testutils.GenerateFloat32Array(int(nq * topk)),
        Ids: &schemapb.IDs{
            IdField: &schemapb.IDs_IntId{
                IntId: &schemapb.LongArray{
                    Data: testutils.GenerateInt64Array(int(nq * topk)),
                },
            },
        },
        Topks:      tops,
        FieldsData: []*schemapb.FieldData{testutils.GenerateScalarFieldData(dType, fieldName, int(nq*topk))},
    }
    data.FieldsData[0].FieldId = fieldId
    return data
}

@@ -20,38 +20,79 @@ package rerank

import (
    "context"
    "encoding/json"
    "fmt"
    "reflect"
    "strconv"
    "strings"

    "github.com/samber/lo"

    "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
    "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
    "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
    "github.com/milvus-io/milvus/pkg/v2/util/funcutil"
    "github.com/milvus-io/milvus/pkg/v2/util/merr"
)

const (
    decayFunctionName string = "decay"
    modelFunctionName string = "model"
    rrfName           string = "rrf"
    weightedName      string = "weighted"
)

const (
    maxScorer string = "max"
    sumScorer string = "sum"
    avgScorer string = "avg"
)

// legacy rrf/weighted rerank configs

const (
    legacyRankTypeKey   = "strategy"
    legacyRankParamsKey = "params"
)

type rankType int

const (
    invalidRankType  rankType = iota // invalidRankType = 0
    rrfRankType                      // rrfRankType = 1
    weightedRankType                 // weightedRankType = 2
)

var rankTypeMap = map[string]rankType{
    "invalid":  invalidRankType,
    "rrf":      rrfRankType,
    "weighted": weightedRankType,
}

type SearchParams struct {
    nq           int64
    limit        int64
    offset       int64
    roundDecimal int64

    // TODO: supports group search
    groupByFieldId  int64
    groupSize       int64
    strictGroupSize bool
    groupScore      string

    searchMetrics []string
}

func NewSearchParams(nq, limit, offset, roundDecimal, groupByFieldId, groupSize int64, strictGroupSize bool, searchMetrics []string) *SearchParams {
func (s *SearchParams) isGrouping() bool {
    return s.groupByFieldId > 0
}

func NewSearchParams(nq, limit, offset, roundDecimal, groupByFieldId, groupSize int64, strictGroupSize bool, groupScore string, searchMetrics []string) *SearchParams {
    if groupScore == "" {
        groupScore = maxScorer
    }
    return &SearchParams{
        nq, limit, offset, roundDecimal, groupByFieldId, groupSize, strictGroupSize, searchMetrics,
        nq, limit, offset, roundDecimal, groupByFieldId, groupSize, strictGroupSize, groupScore, searchMetrics,
    }
}

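A brief usage sketch of the extended constructor above, mirroring the test calls in this commit; the argument values are illustrative only. The new groupScore argument selects the group scorer ("max", "sum", or "avg"), and an empty string falls back to maxScorer:

    // Hypothetical call inside the rerank package:
    params := NewSearchParams(1, 3, 0, -1, -1, 1, false, "", []string{"COSINE"})
    // params.groupScore is now "max" (maxScorer), and params.isGrouping() is
    // false because groupByFieldId is -1.

@@ -95,6 +136,10 @@ func createFunction(collSchema *schemapb.CollectionSchema, funcSchema *schemapb.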
@ -95,6 +136,10 @@ func createFunction(collSchema *schemapb.CollectionSchema, funcSchema *schemapb.
|
||||
rerankFunc, newRerankErr = newDecayFunction(collSchema, funcSchema)
|
||||
case modelFunctionName:
|
||||
rerankFunc, newRerankErr = newModelFunction(collSchema, funcSchema)
|
||||
case rrfName:
|
||||
rerankFunc, newRerankErr = newRRFFunction(collSchema, funcSchema)
|
||||
case weightedName:
|
||||
rerankFunc, newRerankErr = newWeightedFunction(collSchema, funcSchema)
|
||||
default:
|
||||
return nil, fmt.Errorf("Unsupported rerank function: [%s] , list of supported [%s,%s]", rerankerName, decayFunctionName, modelFunctionName)
|
||||
}
|
||||
@ -117,6 +162,68 @@ func NewFunctionScore(collSchema *schemapb.CollectionSchema, funcScoreSchema *sc
|
||||
return funcScore, nil
|
||||
}
|
||||
|
||||
func NewFunctionScoreWithlegacy(collSchema *schemapb.CollectionSchema, rankParams []*commonpb.KeyValuePair) (*FunctionScore, error) {
|
||||
var params map[string]interface{}
|
||||
rankTypeStr, err := funcutil.GetAttrByKeyFromRepeatedKV(legacyRankTypeKey, rankParams)
|
||||
if err != nil {
|
||||
rankTypeStr = "rrf"
|
||||
params = make(map[string]interface{}, 0)
|
||||
} else {
|
||||
if _, ok := rankTypeMap[rankTypeStr]; !ok {
|
||||
return nil, fmt.Errorf("unsupported rank type %s", rankTypeStr)
|
||||
}
|
||||
paramStr, err := funcutil.GetAttrByKeyFromRepeatedKV(legacyRankParamsKey, rankParams)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("params" + " not found in rank_params")
|
||||
}
|
||||
err = json.Unmarshal([]byte(paramStr), ¶ms)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Parse rerank params failed, err: %s", err)
|
||||
}
|
||||
}
|
||||
fSchema := schemapb.FunctionSchema{
|
||||
Type: schemapb.FunctionType_Rerank,
|
||||
InputFieldNames: []string{},
|
||||
OutputFieldNames: []string{},
|
||||
Params: []*commonpb.KeyValuePair{},
|
||||
}
|
||||
switch rankTypeMap[rankTypeStr] {
|
||||
case rrfRankType:
|
||||
fSchema.Params = append(fSchema.Params, &commonpb.KeyValuePair{Key: reranker, Value: rrfName})
|
||||
if v, ok := params[RRFParamsKey]; ok {
|
||||
if reflect.ValueOf(params[RRFParamsKey]).CanFloat() {
|
||||
k := reflect.ValueOf(v).Float()
|
||||
fSchema.Params = append(fSchema.Params, &commonpb.KeyValuePair{Key: RRFParamsKey, Value: strconv.FormatFloat(k, 'f', -1, 64)})
|
||||
} else {
|
||||
return nil, fmt.Errorf("The type of rank param k should be float")
|
||||
}
|
||||
}
|
||||
case weightedRankType:
|
||||
fSchema.Params = append(fSchema.Params, &commonpb.KeyValuePair{Key: reranker, Value: weightedName})
|
||||
if v, ok := params[WeightsParamsKey]; ok {
|
||||
if d, err := json.Marshal(v); err != nil {
|
||||
return nil, fmt.Errorf("The weights param should be an array")
|
||||
} else {
|
||||
fSchema.Params = append(fSchema.Params, &commonpb.KeyValuePair{Key: WeightsParamsKey, Value: string(d)})
|
||||
}
|
||||
}
|
||||
if normScore, ok := params[NormScoreKey]; ok {
|
||||
if ns, ok := normScore.(bool); ok {
|
||||
fSchema.Params = append(fSchema.Params, &commonpb.KeyValuePair{Key: NormScoreKey, Value: strconv.FormatBool(ns)})
|
||||
} else {
|
||||
return nil, fmt.Errorf("Weighted rerank err, norm_score should been bool type, but [norm_score:%s]'s type is %T", normScore, normScore)
|
||||
}
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported rank type %s", rankTypeStr)
|
||||
}
|
||||
funcScore := &FunctionScore{}
|
||||
if funcScore.reranker, err = createFunction(collSchema, &fSchema); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return funcScore, nil
|
||||
}
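For illustration, a usage sketch of this legacy path in the same style as the tests later in this change (the weight values and the collSchema variable are made up for the example):

// Illustrative only; mirrors how the legacy rank_params are translated.
rankParams := []*commonpb.KeyValuePair{
	{Key: legacyRankTypeKey, Value: "weighted"},
	{Key: legacyRankParamsKey, Value: `{"weights": [0.3, 0.7], "norm_score": true}`},
}
// NewFunctionScoreWithlegacy builds a FunctionSchema with reranker=weighted plus
// the weights/norm_score params, then hands it to createFunction.
fs, err := NewFunctionScoreWithlegacy(collSchema, rankParams)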
|
||||
|
||||
func (fScore *FunctionScore) Process(ctx context.Context, searchParams *SearchParams, multipleMilvusResults []*milvuspb.SearchResults) (*milvuspb.SearchResults, error) {
|
||||
if len(multipleMilvusResults) == 0 {
|
||||
return &milvuspb.SearchResults{
|
||||
@ -137,7 +244,7 @@ func (fScore *FunctionScore) Process(ctx context.Context, searchParams *SearchPa
|
||||
})
|
||||
|
||||
// rankResult only has scores
|
||||
inputs, err := newRerankInputs(allSearchResultData, fScore.reranker.GetInputFieldIDs())
|
||||
inputs, err := newRerankInputs(allSearchResultData, fScore.reranker.GetInputFieldIDs(), searchParams.isGrouping())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -20,6 +20,7 @@ package rerank
|
||||
|
||||
import (
|
||||
"context"
|
||||
"math"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/suite"
|
||||
@ -27,6 +28,7 @@ import (
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/util/function"
|
||||
)
|
||||
|
||||
func TestFunctionScore(t *testing.T) {
|
||||
@ -75,7 +77,7 @@ func (s *FunctionScoreSuite) TestNewFunctionScore() {
|
||||
s.NoError(err)
|
||||
s.Equal([]string{"ts"}, f.GetAllInputFieldNames())
|
||||
s.Equal([]int64{102}, f.GetAllInputFieldIDs())
|
||||
s.Equal(false, f.IsSupportGroup())
|
||||
s.Equal(true, f.IsSupportGroup())
|
||||
s.Equal("decay", f.reranker.GetRankName())
|
||||
|
||||
{
|
||||
@ -152,7 +154,7 @@ func (s *FunctionScoreSuite) TestFunctionScoreProcess() {
|
||||
// empty inputs
|
||||
{
|
||||
nq := int64(1)
|
||||
ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false, []string{"COSINE"}}, []*milvuspb.SearchResults{})
|
||||
ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE"}), []*milvuspb.SearchResults{})
|
||||
s.NoError(err)
|
||||
s.Equal(int64(3), ret.Results.TopK)
|
||||
s.Equal(0, len(ret.Results.FieldsData))
|
||||
@ -162,11 +164,11 @@ func (s *FunctionScoreSuite) TestFunctionScoreProcess() {
|
||||
// nq = 1
|
||||
{
|
||||
nq := int64(1)
|
||||
data := genSearchResultData(nq, 10, schemapb.DataType_Int64, "ts", 102)
|
||||
data := function.GenSearchResultData(nq, 10, schemapb.DataType_Int64, "ts", 102)
|
||||
searchData := &milvuspb.SearchResults{
|
||||
Results: data,
|
||||
}
|
||||
ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false, []string{"COSINE"}}, []*milvuspb.SearchResults{searchData})
|
||||
ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE"}), []*milvuspb.SearchResults{searchData})
|
||||
s.NoError(err)
|
||||
s.Equal(int64(3), ret.Results.TopK)
|
||||
s.Equal([]int64{3}, ret.Results.Topks)
|
||||
@ -174,11 +176,11 @@ func (s *FunctionScoreSuite) TestFunctionScoreProcess() {
|
||||
// nq=1, input is empty
|
||||
{
|
||||
nq := int64(1)
|
||||
data := genSearchResultData(nq, 0, schemapb.DataType_Int64, "ts", 102)
|
||||
data := function.GenSearchResultData(nq, 0, schemapb.DataType_Int64, "ts", 102)
|
||||
searchData := &milvuspb.SearchResults{
|
||||
Results: data,
|
||||
}
|
||||
ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false, []string{"COSINE"}}, []*milvuspb.SearchResults{searchData})
|
||||
ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE"}), []*milvuspb.SearchResults{searchData})
|
||||
s.NoError(err)
|
||||
s.Equal(int64(3), ret.Results.TopK)
|
||||
s.Equal([]int64{0}, ret.Results.Topks)
|
||||
@ -186,11 +188,11 @@ func (s *FunctionScoreSuite) TestFunctionScoreProcess() {
|
||||
// nq=3
|
||||
{
|
||||
nq := int64(3)
|
||||
data := genSearchResultData(nq, 10, schemapb.DataType_Int64, "ts", 102)
|
||||
data := function.GenSearchResultData(nq, 10, schemapb.DataType_Int64, "ts", 102)
|
||||
searchData := &milvuspb.SearchResults{
|
||||
Results: data,
|
||||
}
|
||||
ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false, []string{"COSINE"}}, []*milvuspb.SearchResults{searchData})
|
||||
ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE"}), []*milvuspb.SearchResults{searchData})
|
||||
s.NoError(err)
|
||||
s.Equal(int64(3), ret.Results.TopK)
|
||||
s.Equal([]int64{3, 3, 3}, ret.Results.Topks)
|
||||
@ -198,11 +200,11 @@ func (s *FunctionScoreSuite) TestFunctionScoreProcess() {
|
||||
// nq=3, all input is empty
|
||||
{
|
||||
nq := int64(3)
|
||||
data := genSearchResultData(nq, 0, schemapb.DataType_Int64, "ts", 102)
|
||||
data := function.GenSearchResultData(nq, 0, schemapb.DataType_Int64, "ts", 102)
|
||||
searchData := &milvuspb.SearchResults{
|
||||
Results: data,
|
||||
}
|
||||
ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false, []string{"COSINE"}}, []*milvuspb.SearchResults{searchData})
|
||||
ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE"}), []*milvuspb.SearchResults{searchData})
|
||||
s.NoError(err)
|
||||
s.Equal(int64(3), ret.Results.TopK)
|
||||
s.Equal([]int64{0, 0, 0}, ret.Results.Topks)
|
||||
@ -213,13 +215,13 @@ func (s *FunctionScoreSuite) TestFunctionScoreProcess() {
|
||||
{
|
||||
nq := int64(1)
|
||||
searchData1 := &milvuspb.SearchResults{
|
||||
Results: genSearchResultData(nq, 10, schemapb.DataType_Int64, "ts", 102),
|
||||
Results: function.GenSearchResultData(nq, 10, schemapb.DataType_Int64, "ts", 102),
|
||||
}
|
||||
|
||||
searchData2 := &milvuspb.SearchResults{
|
||||
Results: genSearchResultData(nq, 20, schemapb.DataType_Int64, "ts", 102),
|
||||
Results: function.GenSearchResultData(nq, 20, schemapb.DataType_Int64, "ts", 102),
|
||||
}
|
||||
ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false, []string{"COSINE", "COSINE"}}, []*milvuspb.SearchResults{searchData1, searchData2})
|
||||
ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE", "COSINE"}), []*milvuspb.SearchResults{searchData1, searchData2})
|
||||
s.NoError(err)
|
||||
s.Equal(int64(3), ret.Results.TopK)
|
||||
s.Equal([]int64{3}, ret.Results.Topks)
|
||||
@ -228,13 +230,13 @@ func (s *FunctionScoreSuite) TestFunctionScoreProcess() {
|
||||
{
|
||||
nq := int64(1)
|
||||
searchData1 := &milvuspb.SearchResults{
|
||||
Results: genSearchResultData(nq, 0, schemapb.DataType_Int64, "ts", 102),
|
||||
Results: function.GenSearchResultData(nq, 0, schemapb.DataType_Int64, "ts", 102),
|
||||
}
|
||||
|
||||
searchData2 := &milvuspb.SearchResults{
|
||||
Results: genSearchResultData(nq, 0, schemapb.DataType_Int64, "ts", 102),
|
||||
Results: function.GenSearchResultData(nq, 0, schemapb.DataType_Int64, "ts", 102),
|
||||
}
|
||||
ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false, []string{"COSINE", "COSINE"}}, []*milvuspb.SearchResults{searchData1, searchData2})
|
||||
ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE", "COSINE"}), []*milvuspb.SearchResults{searchData1, searchData2})
|
||||
s.NoError(err)
|
||||
s.Equal(int64(3), ret.Results.TopK)
|
||||
s.Equal([]int64{0}, ret.Results.Topks)
|
||||
@ -243,13 +245,13 @@ func (s *FunctionScoreSuite) TestFunctionScoreProcess() {
|
||||
{
|
||||
nq := int64(1)
|
||||
searchData1 := &milvuspb.SearchResults{
|
||||
Results: genSearchResultData(nq, 0, schemapb.DataType_Int64, "ts", 102),
|
||||
Results: function.GenSearchResultData(nq, 0, schemapb.DataType_Int64, "ts", 102),
|
||||
}
|
||||
|
||||
searchData2 := &milvuspb.SearchResults{
|
||||
Results: genSearchResultData(nq, 10, schemapb.DataType_Int64, "ts", 102),
|
||||
Results: function.GenSearchResultData(nq, 10, schemapb.DataType_Int64, "ts", 102),
|
||||
}
|
||||
ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false, []string{"COSINE", "COSINE"}}, []*milvuspb.SearchResults{searchData1, searchData2})
|
||||
ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE", "COSINE"}), []*milvuspb.SearchResults{searchData1, searchData2})
|
||||
s.NoError(err)
|
||||
s.Equal(int64(3), ret.Results.TopK)
|
||||
s.Equal([]int64{3}, ret.Results.Topks)
|
||||
@ -258,13 +260,13 @@ func (s *FunctionScoreSuite) TestFunctionScoreProcess() {
|
||||
{
|
||||
nq := int64(3)
|
||||
searchData1 := &milvuspb.SearchResults{
|
||||
Results: genSearchResultData(nq, 10, schemapb.DataType_Int64, "ts", 102),
|
||||
Results: function.GenSearchResultData(nq, 10, schemapb.DataType_Int64, "ts", 102),
|
||||
}
|
||||
|
||||
searchData2 := &milvuspb.SearchResults{
|
||||
Results: genSearchResultData(nq, 20, schemapb.DataType_Int64, "ts", 102),
|
||||
Results: function.GenSearchResultData(nq, 20, schemapb.DataType_Int64, "ts", 102),
|
||||
}
|
||||
ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false, []string{"COSINE", "COSINE"}}, []*milvuspb.SearchResults{searchData1, searchData2})
|
||||
ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE", "COSINE"}), []*milvuspb.SearchResults{searchData1, searchData2})
|
||||
s.NoError(err)
|
||||
s.Equal(int64(3), ret.Results.TopK)
|
||||
s.Equal([]int64{3, 3, 3}, ret.Results.Topks)
|
||||
@ -273,13 +275,13 @@ func (s *FunctionScoreSuite) TestFunctionScoreProcess() {
|
||||
{
|
||||
nq := int64(3)
|
||||
searchData1 := &milvuspb.SearchResults{
|
||||
Results: genSearchResultData(nq, 0, schemapb.DataType_Int64, "ts", 102),
|
||||
Results: function.GenSearchResultData(nq, 0, schemapb.DataType_Int64, "ts", 102),
|
||||
}
|
||||
|
||||
searchData2 := &milvuspb.SearchResults{
|
||||
Results: genSearchResultData(nq, 0, schemapb.DataType_Int64, "ts", 102),
|
||||
Results: function.GenSearchResultData(nq, 0, schemapb.DataType_Int64, "ts", 102),
|
||||
}
|
||||
ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false, []string{"COSINE", "COSINE"}}, []*milvuspb.SearchResults{searchData1, searchData2})
|
||||
ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE", "COSINE"}), []*milvuspb.SearchResults{searchData1, searchData2})
|
||||
s.NoError(err)
|
||||
s.Equal(int64(3), ret.Results.TopK)
|
||||
s.Equal([]int64{0, 0, 0}, ret.Results.Topks)
|
||||
@ -288,15 +290,131 @@ func (s *FunctionScoreSuite) TestFunctionScoreProcess() {
|
||||
{
|
||||
nq := int64(3)
|
||||
searchData1 := &milvuspb.SearchResults{
|
||||
Results: genSearchResultData(nq, 0, schemapb.DataType_Int64, "ts", 102),
|
||||
Results: function.GenSearchResultData(nq, 0, schemapb.DataType_Int64, "ts", 102),
|
||||
}
|
||||
|
||||
searchData2 := &milvuspb.SearchResults{
|
||||
Results: genSearchResultData(nq, 10, schemapb.DataType_Int64, "ts", 102),
|
||||
Results: function.GenSearchResultData(nq, 10, schemapb.DataType_Int64, "ts", 102),
|
||||
}
|
||||
ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false, []string{"COSINE", "COSINE"}}, []*milvuspb.SearchResults{searchData1, searchData2})
|
||||
ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE", "COSINE"}), []*milvuspb.SearchResults{searchData1, searchData2})
|
||||
s.NoError(err)
|
||||
s.Equal(int64(3), ret.Results.TopK)
|
||||
s.Equal([]int64{3, 3, 3}, ret.Results.Topks)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *FunctionScoreSuite) TestlegacyFunction() {
|
||||
schema := &schemapb.CollectionSchema{
|
||||
Name: "test",
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{FieldID: 100, Name: "pk", DataType: schemapb.DataType_Int64, IsPrimaryKey: true},
|
||||
{FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar},
|
||||
{
|
||||
FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{Key: "dim", Value: "4"},
|
||||
},
|
||||
},
|
||||
{FieldID: 102, Name: "ts", DataType: schemapb.DataType_Int64},
|
||||
},
|
||||
}
|
||||
{
|
||||
rankParams := []*commonpb.KeyValuePair{}
|
||||
f, err := NewFunctionScoreWithlegacy(schema, rankParams)
|
||||
s.NoError(err)
|
||||
s.Equal(f.reranker.GetRankName(), rrfName)
|
||||
}
|
||||
{
|
||||
rankParams := []*commonpb.KeyValuePair{
|
||||
{Key: legacyRankTypeKey, Value: "invalid"},
|
||||
{Key: legacyRankParamsKey, Value: `{"k": "v"}`},
|
||||
}
|
||||
_, err := NewFunctionScoreWithlegacy(schema, rankParams)
|
||||
s.ErrorContains(err, "unsupported rank type")
|
||||
}
|
||||
{
|
||||
rankParams := []*commonpb.KeyValuePair{
|
||||
{Key: legacyRankTypeKey, Value: "rrf"},
|
||||
{Key: legacyRankParamsKey, Value: "invalid"},
|
||||
}
|
||||
_, err := NewFunctionScoreWithlegacy(schema, rankParams)
|
||||
s.ErrorContains(err, "Parse rerank params failed")
|
||||
}
|
||||
{
|
||||
rankParams := []*commonpb.KeyValuePair{
|
||||
{Key: legacyRankTypeKey, Value: "rrf"},
|
||||
{Key: legacyRankParamsKey, Value: `{"k": "invalid"}`},
|
||||
}
|
||||
_, err := NewFunctionScoreWithlegacy(schema, rankParams)
|
||||
s.ErrorContains(err, "The type of rank param k should be float")
|
||||
}
|
||||
{
|
||||
rankParams := []*commonpb.KeyValuePair{
|
||||
{Key: legacyRankTypeKey, Value: "rrf"},
|
||||
{Key: legacyRankParamsKey, Value: `{"k": 1.0}`},
|
||||
}
|
||||
_, err := NewFunctionScoreWithlegacy(schema, rankParams)
|
||||
s.NoError(err)
|
||||
}
|
||||
{
|
||||
rankParams := []*commonpb.KeyValuePair{
|
||||
{Key: legacyRankTypeKey, Value: "weighted"},
|
||||
{Key: legacyRankParamsKey, Value: `{"weights": [1.0]}`},
|
||||
}
|
||||
f, err := NewFunctionScoreWithlegacy(schema, rankParams)
|
||||
s.NoError(err)
|
||||
s.Equal(f.reranker.GetRankName(), weightedName)
|
||||
}
|
||||
{
|
||||
rankParams := []*commonpb.KeyValuePair{
|
||||
{Key: legacyRankTypeKey, Value: "weighted"},
|
||||
{Key: legacyRankParamsKey, Value: `{"weights": [1.0], "norm_score": "Invalid"}`},
|
||||
}
|
||||
_, err := NewFunctionScoreWithlegacy(schema, rankParams)
|
||||
s.ErrorContains(err, "Weighted rerank err, norm_score should been bool type")
|
||||
}
|
||||
{
|
||||
rankParams := []*commonpb.KeyValuePair{
|
||||
{Key: legacyRankTypeKey, Value: "weighted"},
|
||||
{Key: legacyRankParamsKey, Value: `{"weights": [1.0], "norm_score": false}`},
|
||||
}
|
||||
_, err := NewFunctionScoreWithlegacy(schema, rankParams)
|
||||
s.NoError(err)
|
||||
}
|
||||
{
|
||||
rankParams := []*commonpb.KeyValuePair{
|
||||
{Key: legacyRankTypeKey, Value: "weighted"},
|
||||
{Key: legacyRankParamsKey, Value: `{"weights": [1.0], "norm_score": "false"}`},
|
||||
}
|
||||
_, err := NewFunctionScoreWithlegacy(schema, rankParams)
|
||||
s.ErrorContains(err, "Weighted rerank err, norm_score should been bool type")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *FunctionScoreSuite) TestFunctionUtil() {
|
||||
g1 := &Group[int64]{
|
||||
idList: []int64{1, 2, 3},
|
||||
scoreList: []float32{1.0, 2.0, 3.0},
|
||||
groupVal: 3,
|
||||
maxScore: 3.0,
|
||||
sumScore: 6.0,
|
||||
}
|
||||
s1, err := groupScore(g1, maxScorer)
|
||||
s.NoError(err)
|
||||
s.True(math.Abs(float64(s1-3.0)) < 0.001)
|
||||
|
||||
s2, err := groupScore(g1, sumScorer)
|
||||
s.NoError(err)
|
||||
s.True(math.Abs(float64(s2-6.0)) < 0.001)
|
||||
|
||||
s3, err := groupScore(g1, avgScorer)
|
||||
s.NoError(err)
|
||||
s.True(math.Abs(float64(s3-2.0)) < 0.001)
|
||||
|
||||
_, err = groupScore(g1, "NotSupported")
|
||||
s.ErrorContains(err, "is not supported")
|
||||
|
||||
g1.idList = []int64{}
|
||||
_, err = groupScore(g1, avgScorer)
|
||||
s.ErrorContains(err, "input group for score must have at least one id, must be sth wrong within code")
|
||||
}
|
||||
|
||||
@ -296,7 +296,7 @@ type ModelFunction[T PKType] struct {
|
||||
}
|
||||
|
||||
func newModelFunction(collSchema *schemapb.CollectionSchema, funcSchema *schemapb.FunctionSchema) (Reranker, error) {
|
||||
base, err := newRerankBase(collSchema, funcSchema, decayFunctionName, false)
|
||||
base, err := newRerankBase(collSchema, funcSchema, decayFunctionName, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -333,7 +333,7 @@ func newModelFunction(collSchema *schemapb.CollectionSchema, funcSchema *schemap
|
||||
}
|
||||
}
|
||||
|
||||
func (model *ModelFunction[T]) processOneSearchData(ctx context.Context, searchParams *SearchParams, query string, cols []*columns) (*IDScores[T], error) {
|
||||
func (model *ModelFunction[T]) processOneSearchData(ctx context.Context, searchParams *SearchParams, query string, cols []*columns, idGroup map[any]any) (*IDScores[T], error) {
|
||||
uniqueData := make(map[T]string)
|
||||
for _, col := range cols {
|
||||
texts := col.data[0].([]string)
|
||||
@ -359,6 +359,9 @@ func (model *ModelFunction[T]) processOneSearchData(ctx context.Context, searchP
|
||||
for idx, id := range ids {
|
||||
rerankScores[id] = scores[idx]
|
||||
}
|
||||
if searchParams.isGrouping() {
|
||||
return newGroupingIDScores(rerankScores, searchParams, idGroup)
|
||||
}
|
||||
return newIDScores(rerankScores, searchParams), nil
|
||||
}
|
||||
|
||||
@ -368,7 +371,7 @@ func (model *ModelFunction[T]) Process(ctx context.Context, searchParams *Search
|
||||
}
|
||||
outputs := newRerankOutputs(searchParams)
|
||||
for idx, cols := range inputs.data {
|
||||
idScore, err := model.processOneSearchData(ctx, searchParams, model.queries[idx], cols)
|
||||
idScore, err := model.processOneSearchData(ctx, searchParams, model.queries[idx], cols, inputs.idGroupValue)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -375,8 +375,8 @@ func (s *RerankModelSuite) TestRerankProcess() {
|
||||
nq := int64(1)
|
||||
f, err := newModelFunction(schema, functionSchema)
|
||||
s.NoError(err)
|
||||
inputs, _ := newRerankInputs([]*schemapb.SearchResultData{}, f.GetInputFieldIDs())
|
||||
ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false, []string{"COSINE"}}, inputs)
|
||||
inputs, _ := newRerankInputs([]*schemapb.SearchResultData{}, f.GetInputFieldIDs(), false)
|
||||
ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE"}), inputs)
|
||||
s.NoError(err)
|
||||
s.Equal(int64(3), ret.searchResultData.TopK)
|
||||
s.Equal([]int64{}, ret.searchResultData.Topks)
|
||||
@ -386,10 +386,10 @@ func (s *RerankModelSuite) TestRerankProcess() {
|
||||
{
|
||||
nq := int64(1)
|
||||
f, err := newModelFunction(schema, functionSchema)
|
||||
data := genSearchResultData(nq, 10, schemapb.DataType_Int64, "noExist", 1000)
|
||||
data := function.GenSearchResultData(nq, 10, schemapb.DataType_Int64, "noExist", 1000)
|
||||
s.NoError(err)
|
||||
|
||||
_, err = newRerankInputs([]*schemapb.SearchResultData{data}, f.GetInputFieldIDs())
|
||||
_, err = newRerankInputs([]*schemapb.SearchResultData{data}, f.GetInputFieldIDs(), false)
|
||||
s.ErrorContains(err, "Search reaults mismatch rerank inputs")
|
||||
}
|
||||
}
|
||||
@ -430,18 +430,18 @@ func (s *RerankModelSuite) TestRerankProcess() {
|
||||
{
|
||||
f, err := newModelFunction(schema, functionSchema)
|
||||
s.NoError(err)
|
||||
data := genSearchResultData(nq, 10, schemapb.DataType_VarChar, "text", 101)
|
||||
inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data}, f.GetInputFieldIDs())
|
||||
_, err = f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false, []string{"COSINE"}}, inputs)
|
||||
data := function.GenSearchResultData(nq, 10, schemapb.DataType_VarChar, "text", 101)
|
||||
inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data}, f.GetInputFieldIDs(), false)
|
||||
_, err = f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE"}), inputs)
|
||||
s.ErrorContains(err, "nq must equal to queries size, but got nq [1], queries size [2]")
|
||||
}
|
||||
{
|
||||
functionSchema.Params[2].Value = `["q1"]`
|
||||
f, err := newModelFunction(schema, functionSchema)
|
||||
s.NoError(err)
|
||||
data := genSearchResultData(nq, 10, schemapb.DataType_VarChar, "text", 101)
|
||||
inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data}, f.GetInputFieldIDs())
|
||||
ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 0, -1, -1, 1, false, []string{"COSINE"}}, inputs)
|
||||
data := function.GenSearchResultData(nq, 10, schemapb.DataType_VarChar, "text", 101)
|
||||
inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data}, f.GetInputFieldIDs(), false)
|
||||
ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 0, -1, -1, 1, false, "", []string{"COSINE"}), inputs)
|
||||
s.NoError(err)
|
||||
s.Equal([]int64{3}, ret.searchResultData.Topks)
|
||||
s.Equal(int64(3), ret.searchResultData.TopK)
|
||||
@ -452,9 +452,9 @@ func (s *RerankModelSuite) TestRerankProcess() {
|
||||
functionSchema.Params[2].Value = `["q1", "q2", "q3"]`
|
||||
f, err := newModelFunction(schema, functionSchema)
|
||||
s.NoError(err)
|
||||
data := genSearchResultData(nq, 10, schemapb.DataType_VarChar, "text", 101)
|
||||
inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data}, f.GetInputFieldIDs())
|
||||
ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false, []string{"COSINE", "COSINE", "COSINE"}}, inputs)
|
||||
data := function.GenSearchResultData(nq, 10, schemapb.DataType_VarChar, "text", 101)
|
||||
inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data}, f.GetInputFieldIDs(), false)
|
||||
ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE", "COSINE", "COSINE"}), inputs)
|
||||
s.NoError(err)
|
||||
s.Equal([]int64{3, 3, 3}, ret.searchResultData.Topks)
|
||||
s.Equal(int64(3), ret.searchResultData.TopK)
|
||||
@ -468,11 +468,11 @@ func (s *RerankModelSuite) TestRerankProcess() {
|
||||
functionSchema.Params[2].Value = `["q1"]`
|
||||
f, err := newModelFunction(schema, functionSchema)
|
||||
s.NoError(err)
|
||||
data1 := genSearchResultData(nq, 10, schemapb.DataType_VarChar, "text", 101)
|
||||
data1 := function.GenSearchResultData(nq, 10, schemapb.DataType_VarChar, "text", 101)
|
||||
// empty
|
||||
data2 := genSearchResultData(nq, 0, schemapb.DataType_VarChar, "text", 101)
|
||||
inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data1, data2}, f.GetInputFieldIDs())
|
||||
ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false, []string{"COSINE"}}, inputs)
|
||||
data2 := function.GenSearchResultData(nq, 0, schemapb.DataType_VarChar, "text", 101)
|
||||
inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data1, data2}, f.GetInputFieldIDs(), false)
|
||||
ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE"}), inputs)
|
||||
s.NoError(err)
|
||||
s.Equal([]int64{3}, ret.searchResultData.Topks)
|
||||
s.Equal(int64(3), ret.searchResultData.TopK)
|
||||
@ -483,11 +483,11 @@ func (s *RerankModelSuite) TestRerankProcess() {
|
||||
f, err := newModelFunction(schema, functionSchema)
|
||||
s.NoError(err)
|
||||
// ts/id data: 0 - 9
|
||||
data1 := genSearchResultData(nq, 10, schemapb.DataType_VarChar, "text", 101)
|
||||
data1 := function.GenSearchResultData(nq, 10, schemapb.DataType_VarChar, "text", 101)
|
||||
// ts/id data: 0 - 3
|
||||
data2 := genSearchResultData(nq, 4, schemapb.DataType_VarChar, "text", 101)
|
||||
inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data1, data2}, f.GetInputFieldIDs())
|
||||
ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false, []string{"COSINE", "COSINE"}}, inputs)
|
||||
data2 := function.GenSearchResultData(nq, 4, schemapb.DataType_VarChar, "text", 101)
|
||||
inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data1, data2}, f.GetInputFieldIDs(), false)
|
||||
ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE", "COSINE"}), inputs)
|
||||
s.NoError(err)
|
||||
s.Equal([]int64{3}, ret.searchResultData.Topks)
|
||||
s.Equal(int64(3), ret.searchResultData.TopK)
|
||||
@ -498,10 +498,10 @@ func (s *RerankModelSuite) TestRerankProcess() {
|
||||
functionSchema.Params[2].Value = `["q1", "q2", "q3"]`
|
||||
f, err := newModelFunction(schema, functionSchema)
|
||||
s.NoError(err)
|
||||
data1 := genSearchResultData(nq, 10, schemapb.DataType_VarChar, "text", 101)
|
||||
data2 := genSearchResultData(nq, 4, schemapb.DataType_VarChar, "text", 101)
|
||||
inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data1, data2}, f.GetInputFieldIDs())
|
||||
ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, 1, -1, 1, false, []string{"COSINE", "COSINE", "COSINE"}}, inputs)
|
||||
data1 := function.GenSearchResultData(nq, 10, schemapb.DataType_VarChar, "text", 101)
|
||||
data2 := function.GenSearchResultData(nq, 4, schemapb.DataType_VarChar, "text", 101)
|
||||
inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data1, data2}, f.GetInputFieldIDs(), false)
|
||||
ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, 1, -1, 1, false, "", []string{"COSINE", "COSINE", "COSINE"}), inputs)
|
||||
s.NoError(err)
|
||||
s.Equal([]int64{3, 3, 3}, ret.searchResultData.Topks)
|
||||
s.Equal(int64(3), ret.searchResultData.TopK)
|
||||
|
||||
101
internal/util/function/rerank/rrf_function.go
Normal file
@ -0,0 +1,101 @@
/*
* # Licensed to the LF AI & Data foundation under one
* # or more contributor license agreements. See the NOTICE file
* # distributed with this work for additional information
* # regarding copyright ownership. The ASF licenses this file
* # to you under the Apache License, Version 2.0 (the
* # "License"); you may not use this file except in compliance
* # with the License. You may obtain a copy of the License at
* #
* # http://www.apache.org/licenses/LICENSE-2.0
* #
* # Unless required by applicable law or agreed to in writing, software
* # distributed under the License is distributed on an "AS IS" BASIS,
* # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* # See the License for the specific language governing permissions and
* # limitations under the License.
*/

package rerank

import (
"context"
"fmt"
"strconv"
"strings"

"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
)

const (
RRFParamsKey string = "k"

defaultRRFParamsValue float64 = 60
)

type RRFFunction[T PKType] struct {
RerankBase

k float32
}

func newRRFFunction(collSchema *schemapb.CollectionSchema, funcSchema *schemapb.FunctionSchema) (Reranker, error) {
base, err := newRerankBase(collSchema, funcSchema, rrfName, true)
if err != nil {
return nil, err
}

if len(base.GetInputFieldNames()) != 0 {
return nil, fmt.Errorf("The rrf function does not support input parameters, but got %s", base.GetInputFieldNames())
}

k := float64(defaultRRFParamsValue)
for _, param := range funcSchema.Params {
if strings.ToLower(param.Key) == RRFParamsKey {
if k, err = strconv.ParseFloat(param.Value, 64); err != nil {
return nil, fmt.Errorf("Param k:%s is not a number", param.Value)
}
}
}
if k <= 0 || k >= 16384 {
return nil, fmt.Errorf("The rank params k should be in range (0, %d)", 16384)
}
if base.pkType == schemapb.DataType_Int64 {
return &RRFFunction[int64]{RerankBase: *base, k: float32(k)}, nil
} else {
return &RRFFunction[string]{RerankBase: *base, k: float32(k)}, nil
}
}

func (rrf *RRFFunction[T]) processOneSearchData(ctx context.Context, searchParams *SearchParams, cols []*columns, idGroup map[any]any) (*IDScores[T], error) {
rrfScores := map[T]float32{}
for _, col := range cols {
if col.size == 0 {
continue
}
ids := col.ids.([]T)
for idx, id := range ids {
if score, ok := rrfScores[id]; !ok {
rrfScores[id] = 1 / (rrf.k + float32(idx+1))
} else {
rrfScores[id] = score + 1/(rrf.k+float32(idx+1))
}
}
}
if searchParams.isGrouping() {
return newGroupingIDScores(rrfScores, searchParams, idGroup)
}
return newIDScores(rrfScores, searchParams), nil
}

func (rrf *RRFFunction[T]) Process(ctx context.Context, searchParams *SearchParams, inputs *rerankInputs) (*rerankOutputs, error) {
outputs := newRerankOutputs(searchParams)
for _, cols := range inputs.data {
idScore, err := rrf.processOneSearchData(ctx, searchParams, cols, inputs.idGroupValue)
if err != nil {
return nil, err
}
appendResult(outputs, idScore.ids, idScore.scores)
}
return outputs, nil
}
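As a side note on the scoring rule above: each result list contributes 1/(k + rank) for every id it contains, and contributions are summed across lists. A minimal self-contained sketch of that fusion (an illustrative helper, not part of this change):

package main

import "fmt"

// rrfFuse is an illustrative standalone helper that mirrors
// RRFFunction.processOneSearchData: sum 1/(k + rank) over every
// ranked id list an id appears in (rank is 1-based).
func rrfFuse(rankedLists [][]int64, k float32) map[int64]float32 {
	scores := map[int64]float32{}
	for rankIdx := range rankedLists {
		for idx, id := range rankedLists[rankIdx] {
			scores[id] += 1 / (k + float32(idx+1))
		}
	}
	return scores
}

func main() {
	lists := [][]int64{{7, 3, 9}, {3, 7, 1}}
	// ids 3 and 7 appear in both lists, so they collect two contributions each.
	fmt.Println(rrfFuse(lists, 60))
}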
250
internal/util/function/rerank/rrf_function_test.go
Normal file
@ -0,0 +1,250 @@
|
||||
/*
|
||||
* # Licensed to the LF AI & Data foundation under one
|
||||
* # or more contributor license agreements. See the NOTICE file
|
||||
* # distributed with this work for additional information
|
||||
* # regarding copyright ownership. The ASF licenses this file
|
||||
* # to you under the Apache License, Version 2.0 (the
|
||||
* # "License"); you may not use this file except in compliance
|
||||
* # with the License. You may obtain a copy of the License at
|
||||
* #
|
||||
* # http://www.apache.org/licenses/LICENSE-2.0
|
||||
* #
|
||||
* # Unless required by applicable law or agreed to in writing, software
|
||||
* # distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* # See the License for the specific language governing permissions and
|
||||
* # limitations under the License.
|
||||
*/
|
||||
|
||||
package rerank
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/suite"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/util/function"
|
||||
)
|
||||
|
||||
func TestRRFFunction(t *testing.T) {
|
||||
suite.Run(t, new(RRFFunctionSuite))
|
||||
}
|
||||
|
||||
type RRFFunctionSuite struct {
|
||||
suite.Suite
|
||||
}
|
||||
|
||||
func (s *RRFFunctionSuite) TestNewRRFFuction() {
|
||||
schema := &schemapb.CollectionSchema{
|
||||
Name: "test",
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{FieldID: 100, Name: "pk", DataType: schemapb.DataType_Int64, IsPrimaryKey: true},
|
||||
{FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar},
|
||||
{
|
||||
FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{Key: "dim", Value: "4"},
|
||||
},
|
||||
},
|
||||
{FieldID: 102, Name: "ts", DataType: schemapb.DataType_Int64},
|
||||
},
|
||||
}
|
||||
functionSchema := &schemapb.FunctionSchema{
|
||||
Name: "test",
|
||||
Type: schemapb.FunctionType_Rerank,
|
||||
InputFieldNames: []string{},
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: RRFParamsKey, Value: "70"},
|
||||
},
|
||||
}
|
||||
|
||||
{
|
||||
_, err := newRRFFunction(schema, functionSchema)
|
||||
s.NoError(err)
|
||||
}
|
||||
{
|
||||
schema.Fields[0] = &schemapb.FieldSchema{FieldID: 100, Name: "pk", DataType: schemapb.DataType_VarChar, IsPrimaryKey: true}
|
||||
_, err := newRRFFunction(schema, functionSchema)
|
||||
s.NoError(err)
|
||||
}
|
||||
{
|
||||
functionSchema.Params[0] = &commonpb.KeyValuePair{Key: RRFParamsKey, Value: "NotNum"}
|
||||
_, err := newRRFFunction(schema, functionSchema)
|
||||
s.ErrorContains(err, "is not a number")
|
||||
functionSchema.Params[0] = &commonpb.KeyValuePair{Key: RRFParamsKey, Value: "-1"}
|
||||
_, err = newRRFFunction(schema, functionSchema)
|
||||
s.ErrorContains(err, "he rank params k should be in range")
|
||||
functionSchema.Params[0] = &commonpb.KeyValuePair{Key: RRFParamsKey, Value: "100"}
|
||||
}
|
||||
{
|
||||
functionSchema.InputFieldNames = []string{"ts"}
|
||||
_, err := newRRFFunction(schema, functionSchema)
|
||||
s.ErrorContains(err, "The rrf function does not support input parameters")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *RRFFunctionSuite) TestRRFFuctionProcess() {
|
||||
schema := &schemapb.CollectionSchema{
|
||||
Name: "test",
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{FieldID: 100, Name: "pk", DataType: schemapb.DataType_Int64, IsPrimaryKey: true},
|
||||
{FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar},
|
||||
{
|
||||
FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{Key: "dim", Value: "4"},
|
||||
},
|
||||
},
|
||||
{FieldID: 102, Name: "ts", DataType: schemapb.DataType_Int64},
|
||||
},
|
||||
}
|
||||
functionSchema := &schemapb.FunctionSchema{
|
||||
Name: "test",
|
||||
Type: schemapb.FunctionType_Rerank,
|
||||
InputFieldNames: []string{},
|
||||
Params: []*commonpb.KeyValuePair{},
|
||||
}
|
||||
|
||||
// empty
|
||||
{
|
||||
nq := int64(1)
|
||||
f, err := newRRFFunction(schema, functionSchema)
|
||||
s.NoError(err)
|
||||
inputs, _ := newRerankInputs([]*schemapb.SearchResultData{}, f.GetInputFieldIDs(), false)
|
||||
ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE"}), inputs)
|
||||
s.NoError(err)
|
||||
s.Equal(int64(3), ret.searchResultData.TopK)
|
||||
s.Equal([]int64{}, ret.searchResultData.Topks)
|
||||
}
|
||||
|
||||
// singleSearchResultData
|
||||
// nq = 1
|
||||
{
|
||||
nq := int64(1)
|
||||
f, err := newRRFFunction(schema, functionSchema)
|
||||
s.NoError(err)
|
||||
data := function.GenSearchResultData(nq, 10, schemapb.DataType_Int64, "", 0)
|
||||
inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data}, f.GetInputFieldIDs(), false)
|
||||
ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE"}), inputs)
|
||||
s.NoError(err)
|
||||
s.Equal([]int64{3}, ret.searchResultData.Topks)
|
||||
s.Equal(int64(3), ret.searchResultData.TopK)
|
||||
s.Equal([]int64{2, 3, 4}, ret.searchResultData.Ids.GetIntId().Data)
|
||||
}
|
||||
// nq = 3
|
||||
{
|
||||
nq := int64(3)
|
||||
f, err := newRRFFunction(schema, functionSchema)
|
||||
s.NoError(err)
|
||||
data := function.GenSearchResultData(nq, 10, schemapb.DataType_Int64, "", 0)
|
||||
inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data}, f.GetInputFieldIDs(), false)
|
||||
ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE"}), inputs)
|
||||
s.NoError(err)
|
||||
s.Equal([]int64{3, 3, 3}, ret.searchResultData.Topks)
|
||||
s.Equal(int64(3), ret.searchResultData.TopK)
|
||||
s.Equal([]int64{2, 3, 4, 12, 13, 14, 22, 23, 24}, ret.searchResultData.Ids.GetIntId().Data)
|
||||
}
|
||||
|
||||
// has empty inputs
|
||||
{
|
||||
nq := int64(1)
|
||||
f, err := newRRFFunction(schema, functionSchema)
|
||||
s.NoError(err)
|
||||
// id data: 0 - 9
|
||||
data1 := function.GenSearchResultData(nq, 10, schemapb.DataType_Int64, "", 0)
|
||||
// empty
|
||||
data2 := function.GenSearchResultData(nq, 0, schemapb.DataType_Int64, "", 0)
|
||||
inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data1, data2}, f.GetInputFieldIDs(), false)
|
||||
ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 0, -1, -1, 1, false, "", []string{"COSINE", "COSINE"}), inputs)
|
||||
s.NoError(err)
|
||||
s.Equal([]int64{3}, ret.searchResultData.Topks)
|
||||
s.Equal(int64(3), ret.searchResultData.TopK)
|
||||
s.Equal([]int64{0, 1, 2}, ret.searchResultData.Ids.GetIntId().Data)
|
||||
}
|
||||
// nq = 1
|
||||
{
|
||||
nq := int64(1)
|
||||
f, err := newRRFFunction(schema, functionSchema)
|
||||
s.NoError(err)
|
||||
// id data: 0 - 9
|
||||
data1 := function.GenSearchResultData(nq, 10, schemapb.DataType_Int64, "", 0)
|
||||
// id data: 0 - 3
|
||||
data2 := function.GenSearchResultData(nq, 4, schemapb.DataType_Int64, "", 0)
|
||||
inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data1, data2}, f.GetInputFieldIDs(), false)
|
||||
ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE", "COSINE"}), inputs)
|
||||
s.NoError(err)
|
||||
s.Equal([]int64{3}, ret.searchResultData.Topks)
|
||||
s.Equal(int64(3), ret.searchResultData.TopK)
|
||||
s.Equal([]int64{2, 3, 4}, ret.searchResultData.Ids.GetIntId().Data)
|
||||
}
|
||||
// nq = 3
|
||||
{
|
||||
nq := int64(3)
|
||||
f, err := newRRFFunction(schema, functionSchema)
|
||||
s.NoError(err)
|
||||
// nq1 id data: 0 - 9
|
||||
// nq2 id data: 10 - 19
|
||||
// nq3 id data: 20 - 29
|
||||
data1 := function.GenSearchResultData(nq, 10, schemapb.DataType_Int64, "", 0)
|
||||
// nq1 id data: 0 - 3
|
||||
// nq2 id data: 4 - 7
|
||||
// nq3 id data: 8 - 11
|
||||
data2 := function.GenSearchResultData(nq, 4, schemapb.DataType_Int64, "", 0)
|
||||
inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data1, data2}, f.GetInputFieldIDs(), false)
|
||||
ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE", "COSINE"}), inputs)
|
||||
s.NoError(err)
|
||||
s.Equal([]int64{3, 3, 3}, ret.searchResultData.Topks)
|
||||
s.Equal(int64(3), ret.searchResultData.TopK)
|
||||
s.Equal([]int64{2, 3, 4, 5, 11, 6, 9, 21, 10}, ret.searchResultData.Ids.GetIntId().Data)
|
||||
}
|
||||
// nq = 3, grouping = true, grouping size = 1
|
||||
{
|
||||
nq := int64(3)
|
||||
f, err := newRRFFunction(schema, functionSchema)
|
||||
s.NoError(err)
|
||||
// nq1 id data: 0 - 9
|
||||
// nq2 id data: 10 - 19
|
||||
// nq3 id data: 20 - 29
|
||||
data1 := function.GenSearchResultDataWithGrouping(nq, 10, schemapb.DataType_Int64, "", 0, "ts", 102, 1)
|
||||
// nq1 id data: 0 - 3
|
||||
// nq2 id data: 4 - 7
|
||||
// nq3 id data: 8 - 11
|
||||
data2 := function.GenSearchResultDataWithGrouping(nq, 4, schemapb.DataType_Int64, "", 0, "ts", 102, 1)
|
||||
inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data1, data2}, f.GetInputFieldIDs(), true)
|
||||
ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, 102, 1, true, "", []string{"COSINE", "COSINE"}), inputs)
|
||||
s.NoError(err)
|
||||
s.Equal([]int64{3, 3, 3}, ret.searchResultData.Topks)
|
||||
s.Equal(int64(3), ret.searchResultData.TopK)
|
||||
s.Equal([]int64{2, 3, 4, 5, 11, 6, 9, 21, 10}, ret.searchResultData.Ids.GetIntId().Data)
|
||||
}
|
||||
|
||||
// nq = 3, grouping = true, grouping size = 3
|
||||
{
|
||||
nq := int64(3)
|
||||
f, err := newRRFFunction(schema, functionSchema)
|
||||
s.NoError(err)
|
||||
|
||||
// nq1 id data: 0 - 29, group value: 0,0,0,1,1,1, ... , 9,9,9
|
||||
// nq2 id data: 30 - 59, group value: 10,10,10,11,11,11, ... , 19,19,19
|
||||
// nq3 id data: 60 - 99, group value: 20,20,20,21,21,21, ... , 29,29,29
|
||||
data1 := function.GenSearchResultDataWithGrouping(nq, 10, schemapb.DataType_Int64, "", 0, "ts", 102, 3)
|
||||
// nq1 id data: 0 - 11, group value: 0,0,0,1,1,1,2,2,2,3,3,3,
|
||||
// nq2 id data: 12 - 23, group value: 4,4,4,5,5,5,6,6,6,7,7,7
|
||||
// nq3 id data: 24 - 35, group value: 8,8,8,9,9,9,10,10,10,11,11,11
|
||||
data2 := function.GenSearchResultDataWithGrouping(nq, 4, schemapb.DataType_Int64, "", 0, "ts", 102, 3)
|
||||
inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data1, data2}, f.GetInputFieldIDs(), true)
|
||||
ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, 3, 102, 3, true, "", []string{"COSINE", "COSINE"}), inputs)
|
||||
s.NoError(err)
|
||||
s.Equal([]int64{9, 9, 9}, ret.searchResultData.Topks)
|
||||
s.Equal(int64(9), ret.searchResultData.TopK)
|
||||
s.Equal([]int64{
|
||||
6, 7, 8, 9, 10, 11, 12, 13, 14,
|
||||
15, 16, 17, 33, 34, 35, 18, 19, 20,
|
||||
27, 28, 29, 63, 64, 65, 30, 31, 32,
|
||||
},
|
||||
ret.searchResultData.Ids.GetIntId().Data)
|
||||
}
|
||||
}
|
||||
@ -24,6 +24,8 @@ import (
|
||||
"sort"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
|
||||
)
|
||||
|
||||
type PKType interface {
|
||||
@ -40,9 +42,9 @@ type columns struct {
|
||||
|
||||
type rerankInputs struct {
|
||||
// nqs,searchResultsIndex
|
||||
data [][]*columns
|
||||
|
||||
nq int64
|
||||
data [][]*columns
|
||||
idGroupValue map[any]any
|
||||
nq int64
|
||||
|
||||
// There is only fieldId in schemapb.SearchResultData, but no fieldName
|
||||
inputFieldIds []int64
|
||||
@ -69,7 +71,7 @@ func organizeFieldIdData(multipSearchResultData []*schemapb.SearchResultData, in
|
||||
return multipIdField, nil
|
||||
}
|
||||
|
||||
func newRerankInputs(multipSearchResultData []*schemapb.SearchResultData, inputFieldIds []int64) (*rerankInputs, error) {
|
||||
func newRerankInputs(multipSearchResultData []*schemapb.SearchResultData, inputFieldIds []int64, isGrouping bool) (*rerankInputs, error) {
|
||||
if len(multipSearchResultData) == 0 {
|
||||
return &rerankInputs{}, nil
|
||||
}
|
||||
@ -84,27 +86,34 @@ func newRerankInputs(multipSearchResultData []*schemapb.SearchResultData, inputF
|
||||
cols[i] = make([]*columns, len(multipSearchResultData))
|
||||
}
|
||||
for retIdx, searchResult := range multipSearchResultData {
|
||||
for _, fieldId := range inputFieldIds {
|
||||
fieldData := multipIdField[retIdx][fieldId]
|
||||
start := int64(0)
|
||||
for i := int64(0); i < nq; i++ {
|
||||
size := searchResult.Topks[i]
|
||||
start := int64(0)
|
||||
for i := int64(0); i < nq; i++ {
|
||||
size := searchResult.Topks[i]
|
||||
if cols[i][retIdx] == nil {
|
||||
cols[i][retIdx] = &columns{}
|
||||
cols[i][retIdx].size = size
|
||||
cols[i][retIdx].ids = getIds(searchResult.Ids, start, size)
|
||||
cols[i][retIdx].scores = searchResult.Scores[start : start+size]
|
||||
}
|
||||
for _, fieldId := range inputFieldIds {
|
||||
fieldData := multipIdField[retIdx][fieldId]
|
||||
d, err := getField(fieldData, start, size)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if cols[i][retIdx] == nil {
|
||||
cols[i][retIdx] = &columns{}
|
||||
cols[i][retIdx].size = size
|
||||
cols[i][retIdx].ids = getIds(searchResult.Ids, start, size)
|
||||
cols[i][retIdx].scores = searchResult.Scores[start : start+size]
|
||||
}
|
||||
cols[i][retIdx].data = append(cols[i][retIdx].data, d)
|
||||
start += size
|
||||
}
|
||||
start += size
|
||||
}
|
||||
}
|
||||
return &rerankInputs{cols, nq, inputFieldIds}, nil
|
||||
if isGrouping {
|
||||
idGroup, err := genIdGroupingMap(multipSearchResultData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &rerankInputs{cols, idGroup, nq, inputFieldIds}, nil
|
||||
}
|
||||
return &rerankInputs{cols, nil, nq, inputFieldIds}, nil
|
||||
}
|
||||
|
||||
func (inputs *rerankInputs) numOfQueries() int64 {
|
||||
@ -116,9 +125,13 @@ type rerankOutputs struct {
|
||||
}
|
||||
|
||||
func newRerankOutputs(searchParams *SearchParams) *rerankOutputs {
|
||||
topk := searchParams.limit
|
||||
if searchParams.isGrouping() {
|
||||
topk = topk * searchParams.groupSize
|
||||
}
|
||||
ret := &schemapb.SearchResultData{
|
||||
NumQueries: searchParams.nq,
|
||||
TopK: searchParams.limit,
|
||||
TopK: topk,
|
||||
FieldsData: make([]*schemapb.FieldData, 0),
|
||||
Scores: []float32{},
|
||||
Ids: &schemapb.IDs{},
|
||||
@ -153,28 +166,11 @@ func appendResult[T PKType](outputs *rerankOutputs, ids []T, scores []float32) {
|
||||
}
|
||||
|
||||
type IDScores[T PKType] struct {
|
||||
// idScores map[T]float32
|
||||
ids []T
|
||||
scores []float32
|
||||
size int64
|
||||
}
|
||||
|
||||
// func (s *IDScores[T]) GetSortedIdScores() ([]T, []float32) {
|
||||
// ids := make([]T, 0, s.size)
|
||||
// big := func(i, j int) bool {
|
||||
// if s.idScores[ids[i]] == s.idScores[ids[j]] {
|
||||
// return ids[i] < ids[j]
|
||||
// }
|
||||
// return s.idScores[ids[i]] > s.idScores[ids[j]]
|
||||
// }
|
||||
// sort.Slice(ids, big)
|
||||
// scores := make([]float32, 0, s.size)
|
||||
// for _, id := range ids {
|
||||
// scores = append(scores, s.idScores[id])
|
||||
// }
|
||||
// return ids, scores
|
||||
// }
|
||||
|
||||
func newIDScores[T PKType](idScores map[T]float32, searchParams *SearchParams) *IDScores[T] {
|
||||
ids := make([]T, 0, len(idScores))
|
||||
for id := range idScores {
|
||||
@ -209,6 +205,120 @@ func newIDScores[T PKType](idScores map[T]float32, searchParams *SearchParams) *
|
||||
return &ret
|
||||
}
|
||||
|
||||
func genIDGroupValueMap[T PKType]() map[T]any {
|
||||
return nil
|
||||
}
|
||||
|
||||
func groupScore[T PKType](group *Group[T], scorerType string) (float32, error) {
|
||||
switch scorerType {
|
||||
case maxScorer:
|
||||
return group.maxScore, nil
|
||||
case sumScorer:
|
||||
return group.sumScore, nil
|
||||
case avgScorer:
|
||||
if len(group.idList) == 0 {
|
||||
return 0, merr.WrapErrParameterInvalid(1, len(group.idList),
|
||||
"input group for score must have at least one id, must be sth wrong within code")
|
||||
}
|
||||
return group.sumScore / float32(len(group.idList)), nil
|
||||
default:
|
||||
return 0, merr.WrapErrParameterInvalidMsg("input group scorer type: %s is not supported!", scorerType)
|
||||
}
|
||||
}
|
||||
|
||||
type Group[T PKType] struct {
|
||||
idList []T
|
||||
scoreList []float32
|
||||
groupVal any
|
||||
maxScore float32
|
||||
sumScore float32
|
||||
finalScore float32
|
||||
}
|
||||
|
||||
func newGroupingIDScores[T PKType](idScores map[T]float32, searchParams *SearchParams, idGroup map[any]any) (*IDScores[T], error) {
|
||||
ids := make([]T, 0, len(idScores))
|
||||
for id := range idScores {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
|
||||
sort.Slice(ids, func(i, j int) bool {
|
||||
if idScores[ids[i]] == idScores[ids[j]] {
|
||||
return ids[i] < ids[j]
|
||||
}
|
||||
return idScores[ids[i]] > idScores[ids[j]]
|
||||
})
|
||||
|
||||
buckets := make(map[interface{}]*Group[T])
|
||||
for _, id := range ids {
|
||||
score := idScores[id]
|
||||
groupVal := idGroup[id]
|
||||
if buckets[groupVal] == nil {
|
||||
buckets[groupVal] = &Group[T]{
|
||||
idList: make([]T, 0),
|
||||
scoreList: make([]float32, 0),
|
||||
groupVal: groupVal,
|
||||
}
|
||||
}
|
||||
if int64(len(buckets[groupVal].idList)) >= searchParams.groupSize {
|
||||
continue
|
||||
}
|
||||
buckets[groupVal].idList = append(buckets[groupVal].idList, id)
|
||||
buckets[groupVal].scoreList = append(buckets[groupVal].scoreList, idScores[id])
|
||||
if score > buckets[groupVal].maxScore {
|
||||
buckets[groupVal].maxScore = score
|
||||
}
|
||||
buckets[groupVal].sumScore += score
|
||||
}
|
||||
|
||||
groupList := make([]*Group[T], len(buckets))
|
||||
idx := 0
|
||||
var err error
|
||||
for _, group := range buckets {
|
||||
if group.finalScore, err = groupScore(group, searchParams.groupScore); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
groupList[idx] = group
|
||||
idx += 1
|
||||
}
|
||||
sort.Slice(groupList, func(i, j int) bool {
|
||||
if groupList[i].finalScore == groupList[j].finalScore {
|
||||
if len(groupList[i].idList) == len(groupList[j].idList) {
|
||||
// if final score and size of group are both equal
|
||||
// choose the group with smaller first key
|
||||
// here, it's guaranteed all group having at least one id in the idList
|
||||
return groupList[i].idList[0] < groupList[j].idList[0]
|
||||
}
|
||||
// choose the larger group when scores are equal
|
||||
return len(groupList[i].idList) > len(groupList[j].idList)
|
||||
}
|
||||
return groupList[i].finalScore > groupList[j].finalScore
|
||||
})
|
||||
|
||||
if int64(len(groupList)) > searchParams.limit+searchParams.offset {
|
||||
groupList = groupList[:searchParams.limit+searchParams.offset]
|
||||
}
|
||||
|
||||
ret := IDScores[T]{
|
||||
make([]T, 0, searchParams.limit),
|
||||
make([]float32, 0, searchParams.limit),
|
||||
0,
|
||||
}
|
||||
for index := int(searchParams.offset); index < len(groupList); index++ {
|
||||
group := groupList[index]
|
||||
for i, score := range group.scoreList {
|
||||
// idList and scoreList must have same length
|
||||
if searchParams.roundDecimal != -1 {
|
||||
multiplier := math.Pow(10.0, float64(searchParams.roundDecimal))
|
||||
score = float32(math.Floor(float64(score)*multiplier+0.5) / multiplier)
|
||||
}
|
||||
ret.scores = append(ret.scores, score)
|
||||
ret.ids = append(ret.ids, group.idList[i])
|
||||
}
|
||||
}
|
||||
ret.size = int64(len(ret.ids))
|
||||
return &ret, nil
|
||||
}
|
||||
|
||||
func getField(inputField *schemapb.FieldData, start int64, size int64) (any, error) {
|
||||
switch inputField.Type {
|
||||
case schemapb.DataType_Int8, schemapb.DataType_Int16, schemapb.DataType_Int32:
|
||||
@ -299,3 +409,22 @@ func getPKType(collSchema *schemapb.CollectionSchema) (schemapb.DataType, error)
|
||||
}
|
||||
return pkType, nil
|
||||
}
|
||||
|
||||
func genIdGroupingMap(multipSearchResultData []*schemapb.SearchResultData) (map[any]any, error) {
|
||||
idGroupValue := map[any]any{}
|
||||
for _, result := range multipSearchResultData {
|
||||
if result.GetGroupByFieldValue() == nil {
|
||||
return nil, fmt.Errorf("Group value is nil")
|
||||
}
|
||||
size := typeutil.GetSizeOfIDs(result.Ids)
|
||||
groupIter := typeutil.GetDataIterator(result.GetGroupByFieldValue())
|
||||
for i := 0; i < size; i++ {
|
||||
groupByVal := groupIter(i)
|
||||
id := typeutil.GetPK(result.Ids, int64(i))
|
||||
if _, exist := idGroupValue[id]; !exist {
|
||||
idGroupValue[id] = groupByVal
|
||||
}
|
||||
}
|
||||
}
|
||||
return idGroupValue, nil
|
||||
}
|
||||
|
||||
154
internal/util/function/rerank/weighted_function.go
Normal file
@ -0,0 +1,154 @@
|
||||
/*
|
||||
* # Licensed to the LF AI & Data foundation under one
|
||||
* # or more contributor license agreements. See the NOTICE file
|
||||
* # distributed with this work for additional information
|
||||
* # regarding copyright ownership. The ASF licenses this file
|
||||
* # to you under the Apache License, Version 2.0 (the
|
||||
* # "License"); you may not use this file except in compliance
|
||||
* # with the License. You may obtain a copy of the License at
|
||||
* #
|
||||
* # http://www.apache.org/licenses/LICENSE-2.0
|
||||
* #
|
||||
* # Unless required by applicable law or agreed to in writing, software
|
||||
* # distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* # See the License for the specific language governing permissions and
|
||||
* # limitations under the License.
|
||||
*/
|
||||
|
||||
package rerank
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/metric"
|
||||
)
|
||||
|
||||
const (
|
||||
WeightsParamsKey string = "weights"
|
||||
NormScoreKey string = "norm_score"
|
||||
)
|
||||
|
||||
type WeightedFunction[T PKType] struct {
|
||||
RerankBase
|
||||
|
||||
weight []float32
|
||||
needNorm bool
|
||||
}
|
||||
|
||||
func newWeightedFunction(collSchema *schemapb.CollectionSchema, funcSchema *schemapb.FunctionSchema) (Reranker, error) {
|
||||
base, err := newRerankBase(collSchema, funcSchema, weightedName, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(base.GetInputFieldNames()) != 0 {
|
||||
return nil, fmt.Errorf("The weighted function does not support input parameters, but got %s", base.GetInputFieldNames())
|
||||
}
|
||||
|
||||
var weights []float32
|
||||
needNorm := false
|
||||
for _, param := range funcSchema.Params {
|
||||
switch strings.ToLower(param.Key) {
|
||||
case WeightsParamsKey:
|
||||
if err := json.Unmarshal([]byte(param.Value), &weights); err != nil {
|
||||
return nil, fmt.Errorf("Parse %s param failed, weight should be []float, bug got: %s", WeightsParamsKey, param.Value)
|
||||
}
|
||||
for _, weight := range weights {
|
||||
if weight < 0 || weight > 1 {
|
||||
return nil, fmt.Errorf("rank param weight should be in range [0, 1]")
|
||||
}
|
||||
}
|
||||
case NormScoreKey:
|
||||
if needNorm, err = strconv.ParseBool(param.Value); err != nil {
|
||||
return nil, fmt.Errorf("%s params must be true/false, bug got %s", NormScoreKey, param.Value)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(weights) == 0 {
|
||||
return nil, fmt.Errorf(WeightsParamsKey + " not found")
|
||||
}
|
||||
if base.pkType == schemapb.DataType_Int64 {
|
||||
return &WeightedFunction[int64]{RerankBase: *base, weight: weights, needNorm: needNorm}, nil
|
||||
} else {
|
||||
return &WeightedFunction[string]{RerankBase: *base, weight: weights, needNorm: needNorm}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (weighted *WeightedFunction[T]) processOneSearchData(ctx context.Context, searchParams *SearchParams, cols []*columns, idGroup map[any]any) (*IDScores[T], error) {
|
||||
if len(cols) != len(weighted.weight) {
|
||||
return nil, merr.WrapErrParameterInvalid(fmt.Sprint(len(cols)), fmt.Sprint(len(weighted.weight)), "the length of weights param mismatch with ann search requests")
|
||||
}
|
||||
weightedScores := map[T]float32{}
|
||||
for i, col := range cols {
|
||||
if col.size == 0 {
|
||||
continue
|
||||
}
|
||||
normFunc := getNormalizeFunc(weighted.needNorm, searchParams.searchMetrics[i])
|
||||
ids := col.ids.([]T)
|
||||
for j, id := range ids {
|
||||
if score, ok := weightedScores[id]; !ok {
|
||||
weightedScores[id] = weighted.weight[i] * normFunc(col.scores[j])
|
||||
} else {
|
||||
weightedScores[id] = score + weighted.weight[i]*normFunc(col.scores[j])
|
||||
}
|
||||
}
|
||||
}
|
||||
if searchParams.isGrouping() {
|
||||
return newGroupingIDScores(weightedScores, searchParams, idGroup)
|
||||
}
|
||||
return newIDScores(weightedScores, searchParams), nil
|
||||
}
|
||||
|
||||
func (weighted *WeightedFunction[T]) Process(ctx context.Context, searchParams *SearchParams, inputs *rerankInputs) (*rerankOutputs, error) {
|
||||
outputs := newRerankOutputs(searchParams)
|
||||
for _, cols := range inputs.data {
|
||||
for i, col := range cols {
|
||||
metricType := searchParams.searchMetrics[i]
|
||||
for j, score := range col.scores {
|
||||
col.scores[j] = toGreaterScore(score, metricType)
|
||||
}
|
||||
}
|
||||
idScore, err := weighted.processOneSearchData(ctx, searchParams, cols, inputs.idGroupValue)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
appendResult(outputs, idScore.ids, idScore.scores)
|
||||
}
|
||||
return outputs, nil
|
||||
}
|
||||
|
||||
type normalizeFunc func(float32) float32
|
||||
|
||||
func getNormalizeFunc(normScore bool, metrics string) normalizeFunc {
|
||||
if !normScore {
|
||||
return func(distance float32) float32 {
|
||||
return distance
|
||||
}
|
||||
}
|
||||
switch metrics {
|
||||
case metric.COSINE:
|
||||
return func(distance float32) float32 {
|
||||
return (1 + distance) * 0.5
|
||||
}
|
||||
case metric.IP:
|
||||
return func(distance float32) float32 {
|
||||
return 0.5 + float32(math.Atan(float64(distance)))/math.Pi
|
||||
}
|
||||
case metric.BM25:
|
||||
return func(distance float32) float32 {
|
||||
return 2 * float32(math.Atan(float64(distance))) / math.Pi
|
||||
}
|
||||
default:
|
||||
return func(distance float32) float32 {
|
||||
return 1.0 - 2*float32(math.Atan(float64(distance)))/math.Pi
|
||||
}
|
||||
}
|
||||
}
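As a rough intuition for the normalization above: cosine similarity is mapped from [-1, 1] to [0, 1], IP and BM25 are squashed with arctan, and the default branch (distance metrics such as L2) maps smaller distances to larger scores. A small self-contained check of those mappings, with arbitrarily picked values:

package main

import (
	"fmt"
	"math"
)

// normalize is an illustrative standalone mirror of getNormalizeFunc's
// branches, so scores from different metrics land on a comparable scale.
func normalize(metricType string, distance float32) float32 {
	switch metricType {
	case "COSINE":
		return (1 + distance) * 0.5
	case "IP":
		return 0.5 + float32(math.Atan(float64(distance)))/math.Pi
	case "BM25":
		return 2 * float32(math.Atan(float64(distance))) / math.Pi
	default: // e.g. L2: smaller distance -> larger score
		return 1.0 - 2*float32(math.Atan(float64(distance)))/math.Pi
	}
}

func main() {
	fmt.Println(normalize("COSINE", 0.8), normalize("IP", 12), normalize("BM25", 7), normalize("L2", 0.3))
}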
|
||||
298
internal/util/function/rerank/weighted_function_test.go
Normal file
@ -0,0 +1,298 @@
|
||||
/*
 * # Licensed to the LF AI & Data foundation under one
 * # or more contributor license agreements. See the NOTICE file
 * # distributed with this work for additional information
 * # regarding copyright ownership. The ASF licenses this file
 * # to you under the Apache License, Version 2.0 (the
 * # "License"); you may not use this file except in compliance
 * # with the License. You may obtain a copy of the License at
 * #
 * #     http://www.apache.org/licenses/LICENSE-2.0
 * #
 * # Unless required by applicable law or agreed to in writing, software
 * # distributed under the License is distributed on an "AS IS" BASIS,
 * # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * # See the License for the specific language governing permissions and
 * # limitations under the License.
 */

package rerank

import (
	"context"
	"math"
	"testing"

	"github.com/stretchr/testify/suite"

	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
	"github.com/milvus-io/milvus/internal/util/function"
	"github.com/milvus-io/milvus/pkg/v2/util/metric"
)

func TestWeightedFunction(t *testing.T) {
	suite.Run(t, new(WeightedFunctionSuite))
}

type WeightedFunctionSuite struct {
	suite.Suite
}

func (s *WeightedFunctionSuite) TestNewWeightedFunction() {
	schema := &schemapb.CollectionSchema{
		Name: "test",
		Fields: []*schemapb.FieldSchema{
			{FieldID: 100, Name: "pk", DataType: schemapb.DataType_Int64, IsPrimaryKey: true},
			{FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar},
			{
				FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector,
				TypeParams: []*commonpb.KeyValuePair{
					{Key: "dim", Value: "4"},
				},
			},
			{FieldID: 102, Name: "ts", DataType: schemapb.DataType_Int64},
		},
	}
	functionSchema := &schemapb.FunctionSchema{
		Name: "test",
		Type: schemapb.FunctionType_Rerank,
		InputFieldNames: []string{},
		Params: []*commonpb.KeyValuePair{
			{Key: WeightsParamsKey, Value: `[0.1, 0.9]`},
		},
	}

	{
		_, err := newWeightedFunction(schema, functionSchema)
		s.NoError(err)
	}
	{
		schema.Fields[0] = &schemapb.FieldSchema{FieldID: 100, Name: "pk", DataType: schemapb.DataType_VarChar, IsPrimaryKey: true}
		_, err := newWeightedFunction(schema, functionSchema)
		s.NoError(err)
	}
	{
		functionSchema.Params[0] = &commonpb.KeyValuePair{Key: WeightsParamsKey, Value: "NotNum"}
		_, err := newWeightedFunction(schema, functionSchema)
		s.ErrorContains(err, "param failed, weight should be []float")
	}
	{
		functionSchema.Params[0] = &commonpb.KeyValuePair{Key: WeightsParamsKey, Value: `[10]`}
		_, err := newWeightedFunction(schema, functionSchema)
		s.ErrorContains(err, "rank param weight should be in range [0, 1]")
	}
	{
		functionSchema.Params[0] = &commonpb.KeyValuePair{Key: "NotExist", Value: `[10]`}
		_, err := newWeightedFunction(schema, functionSchema)
		s.ErrorContains(err, "not found")
		functionSchema.Params[0] = &commonpb.KeyValuePair{Key: WeightsParamsKey, Value: `[0.1, 0.9]`}
	}
	{
		functionSchema.InputFieldNames = []string{"ts"}
		_, err := newWeightedFunction(schema, functionSchema)
		s.ErrorContains(err, "The weighted function does not support input parameters,")
	}
}

func (s *WeightedFunctionSuite) TestWeightedFunctionProcess() {
	schema := &schemapb.CollectionSchema{
		Name: "test",
		Fields: []*schemapb.FieldSchema{
			{FieldID: 100, Name: "pk", DataType: schemapb.DataType_Int64, IsPrimaryKey: true},
			{FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar},
			{
				FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector,
				TypeParams: []*commonpb.KeyValuePair{
					{Key: "dim", Value: "4"},
				},
			},
			{FieldID: 102, Name: "ts", DataType: schemapb.DataType_Int64},
		},
	}
	functionSchema := &schemapb.FunctionSchema{
		Name: "test",
		Type: schemapb.FunctionType_Rerank,
		InputFieldNames: []string{},
		Params: []*commonpb.KeyValuePair{
			{Key: WeightsParamsKey, Value: `[0.1]`},
		},
	}

	// empty
	{
		nq := int64(1)
		f, err := newWeightedFunction(schema, functionSchema)
		s.NoError(err)
		inputs, _ := newRerankInputs([]*schemapb.SearchResultData{}, f.GetInputFieldIDs(), false)
		ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE"}), inputs)
		s.NoError(err)
		s.Equal(int64(3), ret.searchResultData.TopK)
		s.Equal([]int64{}, ret.searchResultData.Topks)
	}

	// singleSearchResultData
	// nq = 1
	{
		nq := int64(1)
		f, err := newWeightedFunction(schema, functionSchema)
		s.NoError(err)
		data := function.GenSearchResultData(nq, 10, schemapb.DataType_Int64, "", 0)
		inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data}, f.GetInputFieldIDs(), false)
		ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE"}), inputs)
		s.NoError(err)
		s.Equal([]int64{3}, ret.searchResultData.Topks)
		s.Equal(int64(3), ret.searchResultData.TopK)
		s.Equal([]int64{7, 6, 5}, ret.searchResultData.Ids.GetIntId().Data)
	}
	// nq = 3
	{
		nq := int64(3)
		f, err := newWeightedFunction(schema, functionSchema)
		s.NoError(err)
		data := function.GenSearchResultData(nq, 10, schemapb.DataType_Int64, "", 0)
		inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data}, f.GetInputFieldIDs(), false)
		ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 0, -1, -1, 1, false, "", []string{"COSINE"}), inputs)
		s.NoError(err)
		s.Equal([]int64{3, 3, 3}, ret.searchResultData.Topks)
		s.Equal(int64(3), ret.searchResultData.TopK)
		s.Equal([]int64{9, 8, 7, 19, 18, 17, 29, 28, 27}, ret.searchResultData.Ids.GetIntId().Data)
	}

	// number of weights not equal to number of search requests
	functionSchema.Params[0] = &commonpb.KeyValuePair{Key: WeightsParamsKey, Value: `[0.1, 0.9]`}
	{
		nq := int64(1)
		f, err := newWeightedFunction(schema, functionSchema)
		s.NoError(err)
		data := function.GenSearchResultData(nq, 10, schemapb.DataType_Int64, "", 0)
		inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data}, f.GetInputFieldIDs(), false)
		_, err = f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE", "COSINE"}), inputs)
		s.ErrorContains(err, "the length of weights param mismatch with ann search requests")
	}
	// has empty inputs
	{
		nq := int64(1)
		f, err := newWeightedFunction(schema, functionSchema)
		s.NoError(err)
		// id data: 0 - 9
		data1 := function.GenSearchResultData(nq, 10, schemapb.DataType_Int64, "", 0)
		// empty
		data2 := function.GenSearchResultData(nq, 0, schemapb.DataType_Int64, "", 0)
		inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data1, data2}, f.GetInputFieldIDs(), false)
		ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 0, -1, -1, 1, false, "", []string{"COSINE", "COSINE"}), inputs)
		s.NoError(err)
		s.Equal([]int64{3}, ret.searchResultData.Topks)
		s.Equal(int64(3), ret.searchResultData.TopK)
		s.Equal([]int64{9, 8, 7}, ret.searchResultData.Ids.GetIntId().Data)
		s.True(function.FloatsAlmostEqual([]float32{0.9, 0.8, 0.7}, ret.searchResultData.Scores, 0.001))
	}
	// nq = 1
	{
		nq := int64(1)
		f, err := newWeightedFunction(schema, functionSchema)
		s.NoError(err)
		// id data: 0 - 9
		data1 := function.GenSearchResultData(nq, 10, schemapb.DataType_Int64, "", 0)
		// id data: 0 - 3
		data2 := function.GenSearchResultData(nq, 4, schemapb.DataType_Int64, "", 0)
		inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data1, data2}, f.GetInputFieldIDs(), false)
		ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE", "COSINE"}), inputs)
		s.NoError(err)
		s.Equal([]int64{3}, ret.searchResultData.Topks)
		s.Equal(int64(3), ret.searchResultData.TopK)
		s.Equal([]int64{1, 9, 8}, ret.searchResultData.Ids.GetIntId().Data)
		s.True(function.FloatsAlmostEqual([]float32{1, 0.9, 0.8}, ret.searchResultData.Scores, 0.001))
	}
	// nq = 3
	{
		nq := int64(3)
		f, err := newWeightedFunction(schema, functionSchema)
		s.NoError(err)
		// nq1 id data: 0 - 9
		// nq2 id data: 10 - 19
		// nq3 id data: 20 - 29
		data1 := function.GenSearchResultData(nq, 10, schemapb.DataType_Int64, "", 0)
		// nq1 id data: 0 - 3
		// nq2 id data: 4 - 7
		// nq3 id data: 8 - 11
		data2 := function.GenSearchResultData(nq, 4, schemapb.DataType_Int64, "", 0)
		inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data1, data2}, f.GetInputFieldIDs(), false)
		ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 0, 1, -1, 1, false, "", []string{"COSINE", "COSINE"}), inputs)
		s.NoError(err)
		s.Equal([]int64{3, 3, 3}, ret.searchResultData.Topks)
		s.Equal(int64(3), ret.searchResultData.TopK)
		s.Equal([]int64{3, 2, 1, 7, 6, 5, 11, 10, 9}, ret.searchResultData.Ids.GetIntId().Data)
		s.True(function.FloatsAlmostEqual([]float32{3, 2, 1, 6.3, 5.4, 4.5, 9.9, 9, 8.1}, ret.searchResultData.Scores, 0.001))
	}
	// nq = 3, grouping = true, grouping size = 1
	{
		nq := int64(3)
		f, err := newWeightedFunction(schema, functionSchema)
		s.NoError(err)
		// nq1 id data: 0 - 9
		// nq2 id data: 10 - 19
		// nq3 id data: 20 - 29
		data1 := function.GenSearchResultDataWithGrouping(nq, 10, schemapb.DataType_Int64, "", 0, "ts", 102, 1)
		// nq1 id data: 0 - 3
		// nq2 id data: 4 - 7
		// nq3 id data: 8 - 11
		data2 := function.GenSearchResultDataWithGrouping(nq, 4, schemapb.DataType_Int64, "", 0, "ts", 102, 1)
		inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data1, data2}, f.GetInputFieldIDs(), true)
		ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 0, 1, 102, 1, true, "", []string{"COSINE", "COSINE"}), inputs)
		s.NoError(err)
		s.Equal([]int64{3, 3, 3}, ret.searchResultData.Topks)
		s.Equal([]int64{3, 2, 1, 7, 6, 5, 11, 10, 9}, ret.searchResultData.Ids.GetIntId().Data)
		s.True(function.FloatsAlmostEqual([]float32{3, 2, 1, 6.3, 5.4, 4.5, 9.9, 9, 8.1}, ret.searchResultData.Scores, 0.001))
	}

	// nq = 3, grouping = true, grouping size = 3
	{
		nq := int64(3)
		f, err := newWeightedFunction(schema, functionSchema)
		s.NoError(err)

		// nq1 id data: 0 - 29, group value: 0,0,0,1,1,1, ... , 9,9,9
		// nq2 id data: 30 - 59, group value: 10,10,10,11,11,11, ... , 19,19,19
		// nq3 id data: 60 - 99, group value: 20,20,20,21,21,21, ... , 29,29,29
		data1 := function.GenSearchResultDataWithGrouping(nq, 10, schemapb.DataType_Int64, "", 0, "ts", 102, 3)
		// nq1 id data: 0 - 11, group value: 0,0,0,1,1,1,2,2,2,3,3,3,
		// nq2 id data: 12 - 23, group value: 4,4,4,5,5,5,6,6,6,7,7,7
		// nq3 id data: 24 - 35, group value: 8,8,8,9,9,9,10,10,10,11,11,11
		data2 := function.GenSearchResultDataWithGrouping(nq, 4, schemapb.DataType_Int64, "", 0, "ts", 102, 3)
		inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data1, data2}, f.GetInputFieldIDs(), true)
		ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, 1, 102, 3, true, "", []string{"COSINE", "COSINE"}), inputs)
		s.NoError(err)
		s.Equal([]int64{9, 9, 9}, ret.searchResultData.Topks)
		s.Equal(int64(9), ret.searchResultData.TopK)
		s.Equal([]int64{
			5, 4, 3, 29, 28, 27, 26, 25, 24,
			17, 16, 15, 14, 13, 12, 59, 58, 57,
			29, 28, 27, 26, 25, 24, 89, 88, 87,
		},
			ret.searchResultData.Ids.GetIntId().Data)
	}
}

func (s *WeightedFunctionSuite) TestWeightedFunctionNormalize() {
	{
		f := getNormalizeFunc(false, metric.COSINE)
		s.Equal(float32(1.0), f(1.0))
	}
	{
		f := getNormalizeFunc(true, metric.COSINE)
		s.Equal(float32((1+1.0)*0.5), f(1))
	}
	{
		f := getNormalizeFunc(true, metric.IP)
		s.Equal(0.5+float32(math.Atan(float64(1.0)))/math.Pi, f(1))
	}
	{
		f := getNormalizeFunc(true, metric.BM25)
		s.Equal(float32(2*math.Atan(float64(1.0)))/math.Pi, f(1.0))
	}
	{
		f := getNormalizeFunc(true, metric.L2)
		s.Equal((1.0 - 2*float32(math.Atan(float64(1.0)))/math.Pi), f(1.0))
	}
}
@ -213,6 +213,7 @@ func (s *PartialSearchTestSuit) TestAllNodeDownOnSingleReplica() {
	for _, qn := range s.Cluster.GetAllQueryNodes() {
		qn.Stop()
	}
	time.Sleep(2 * time.Second)
	s.Cluster.AddQueryNode()

	time.Sleep(20 * time.Second)

@ -1588,11 +1588,7 @@ class TestMilvusClientSearchInvalid(TestMilvusClientV2Base):
            }
        )
        vectors_to_search = rng.random((1, dim))
        error = {ct.err_code: 1100,
                 ct.err_msg: f"Current rerank does not support grouping search: invalid parameter"}
        self.search(client, collection_name, vectors_to_search, ranker=my_rerank_fn,
                    group_by_field=ct.default_reranker_field_name,
                    check_task=CheckTasks.err_res, check_items=error)
        self.search(client, collection_name, vectors_to_search, ranker=my_rerank_fn, group_by_field=ct.default_reranker_field_name)

    @pytest.mark.tags(CaseLabel.L1)
    def test_milvus_client_search_with_reranker_on_dynamic_fields(self):