Mirror of https://gitee.com/milvus-io/milvus.git (synced 2025-12-06 17:18:35 +08:00)
enhance: optimize decay function with configurable score merging and normalization (#44066)

- Add configurable score merge functions (max, avg, sum) for decay reranking
- Introduce norm_score parameter to control score normalization behavior
- Refactor score normalization logic into reusable utility functions

#44051

Signed-off-by: junjie.jiang <junjie.jiang@zilliz.com>
parent: 1b7562a766
commit: 71563d5d0e
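For orientation before the diff: judging from the new constants and test fixtures in this commit, a decay reranker is configured through key/value params on a FunctionSchema, and the two new keys (norm_score, score_mode) slot in next to the existing decay params. A minimal sketch; the field name "ts" and all values are illustrative, mirroring the test fixtures below:

package main

import (
	"fmt"

	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
)

func main() {
	// Hypothetical decay-rerank configuration exercising the new keys.
	fn := &schemapb.FunctionSchema{
		Name:            "test",
		Type:            schemapb.FunctionType_Rerank,
		InputFieldNames: []string{"ts"},
		Params: []*commonpb.KeyValuePair{
			{Key: "origin", Value: "4"},
			{Key: "scale", Value: "4"},
			{Key: "offset", Value: "4"},
			{Key: "decay", Value: "0.5"},
			{Key: "function", Value: "gauss"},
			{Key: "norm_score", Value: "true"}, // new: normalize scores per metric type
			{Key: "score_mode", Value: "avg"},  // new: merge duplicate IDs by avg (default is max)
		},
	}
	fmt.Printf("%s configured with %d params\n", fn.Name, len(fn.Params))
}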
@@ -26,7 +26,6 @@ import (
-	"strings"

	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
	"github.com/milvus-io/milvus/pkg/v2/util/metric"
)

const (
@@ -35,6 +34,9 @@ const (
	offsetKey   string = "offset"
	decayKey    string = "decay"
	functionKey string = "function"
+
+	normsScorekey string = "norm_score"
+	scoreMode     string = "score_mode"
)

const (
@@ -51,6 +53,8 @@ type DecayFunction[T PKType, R int32 | int64 | float32 | float64] struct {
	scale  float64
	offset float64
	decay  float64
+	needNorm  bool
+	scoreFunc scoreMergeFunc[T]
	reScorer decayReScorer
}

@@ -97,7 +101,7 @@ func newDecayFunction(collSchema *schemapb.CollectionSchema, funcSchema *schemap
 // T: PK Type, R: field type
 func newFunction[T PKType, R int32 | int64 | float32 | float64](base *RerankBase, funcSchema *schemapb.FunctionSchema) (Reranker, error) {
	var err error
-	decayFunc := &DecayFunction[T, R]{RerankBase: *base, offset: 0, decay: 0.5}
+	decayFunc := &DecayFunction[T, R]{RerankBase: *base, offset: 0, decay: 0.5, needNorm: false, scoreFunc: maxMerge[T]}
	orginInit := false
	scaleInit := false
	for _, param := range funcSchema.Params {
@@ -122,6 +126,18 @@ func newFunction[T PKType, R int32 | int64 | float32 | float64](base *RerankBase
			if decayFunc.decay, err = strconv.ParseFloat(param.Value, 64); err != nil {
				return nil, fmt.Errorf("Param decay:%s is not a number", param.Value)
			}
+		case normsScorekey:
+			if needNorm, err := strconv.ParseBool(param.Value); err != nil {
+				return nil, fmt.Errorf("%s params must be true/false, but got %s", normsScorekey, param.Value)
+			} else {
+				decayFunc.needNorm = needNorm
+			}
+		case scoreMode:
+			if f, err := getMergeFunc[T](param.Value); err != nil {
+				return nil, err
+			} else {
+				decayFunc.scoreFunc = f
+			}
		default:
		}
	}
@@ -159,17 +175,8 @@ func newFunction[T PKType, R int32 | int64 | float32 | float64](base *RerankBase
	return decayFunc, nil
}

-func toGreaterScore(score float32, metricType string) float32 {
-	switch strings.ToUpper(metricType) {
-	case metric.COSINE, metric.IP, metric.BM25:
-		return score
-	default:
-		return 1.0 - 2*float32(math.Atan(float64(score)))/math.Pi
-	}
-}
-
 func (decay *DecayFunction[T, R]) processOneSearchData(ctx context.Context, searchParams *SearchParams, cols []*columns, idGroup map[any]any) (*IDScores[T], error) {
-	srcScores := maxMerge[T](cols)
+	srcScores := decay.scoreFunc(cols)
	decayScores := map[T]float32{}
	for _, col := range cols {
		if col.size == 0 {
@@ -189,16 +196,16 @@ func (decay *DecayFunction[T, R]) processOneSearchData(ctx context.Context, sear
	if searchParams.isGrouping() {
		return newGroupingIDScores(decayScores, searchParams, idGroup)
	}
-	return newIDScores(decayScores, searchParams), nil
+	return newIDScores(decayScores, searchParams, true), nil
}

 func (decay *DecayFunction[T, R]) Process(ctx context.Context, searchParams *SearchParams, inputs *rerankInputs) (*rerankOutputs, error) {
	outputs := newRerankOutputs(searchParams)
	for _, cols := range inputs.data {
		for i, col := range cols {
-			metricType := searchParams.searchMetrics[i]
+			normFunc := getNormalizeFunc(decay.needNorm, searchParams.searchMetrics[i], true)
			for j, score := range col.scores {
-				col.scores[j] = toGreaterScore(score, metricType)
+				col.scores[j] = normFunc(score)
			}
		}
		idScore, err := decay.processOneSearchData(ctx, searchParams, cols, inputs.idGroupValue)
@@ -73,6 +73,31 @@ func (s *DecayFunctionSuite) TestNewDecayErrors() {
		_, err := newDecayFunction(schema, functionSchema)
		s.ErrorContains(err, "Rerank function output field names should be empty")
	}
+	{
+		functionSchema := &schemapb.FunctionSchema{
+			Name:             "test",
+			Type:             schemapb.FunctionType_Rerank,
+			InputFieldNames:  []string{"ts"},
+			OutputFieldNames: []string{},
+			Params: []*commonpb.KeyValuePair{
+				{Key: originKey, Value: "4"},
+				{Key: scaleKey, Value: "4"},
+				{Key: offsetKey, Value: "4"},
+				{Key: decayKey, Value: "0.5"},
+				{Key: functionKey, Value: "gauss"},
+				{Key: normsScorekey, Value: "unknow"},
+				{Key: scoreMode, Value: "avg"},
+			},
+		}
+
+		_, err := newDecayFunction(schema, functionSchema)
+		s.ErrorContains(err, "must be true/false")
+
+		functionSchema.Params[5] = &commonpb.KeyValuePair{Key: normsScorekey, Value: "true"}
+		functionSchema.Params[6] = &commonpb.KeyValuePair{Key: scoreMode, Value: "unknow"}
+		_, err = newDecayFunction(schema, functionSchema)
+		s.ErrorContains(err, "Unsupported score mode")
+	}

	{
		functionSchema.OutputFieldNames = []string{}
@@ -180,7 +180,7 @@ func (model *ModelFunction[T]) processOneSearchData(ctx context.Context, searchP
	if searchParams.isGrouping() {
		return newGroupingIDScores(rerankScores, searchParams, idGroup)
	}
-	return newIDScores(rerankScores, searchParams), nil
+	return newIDScores(rerankScores, searchParams, true), nil
}

 func (model *ModelFunction[T]) Process(ctx context.Context, searchParams *SearchParams, inputs *rerankInputs) (*rerankOutputs, error) {
@@ -85,7 +85,7 @@ func (rrf *RRFFunction[T]) processOneSearchData(ctx context.Context, searchParam
	if searchParams.isGrouping() {
		return newGroupingIDScores(rrfScores, searchParams, idGroup)
	}
-	return newIDScores(rrfScores, searchParams), nil
+	return newIDScores(rrfScores, searchParams, true), nil
}

 func (rrf *RRFFunction[T]) Process(ctx context.Context, searchParams *SearchParams, inputs *rerankInputs) (*rerankOutputs, error) {
@@ -22,9 +22,11 @@ import (
	"fmt"
	"math"
	"sort"
+	"strings"

	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
	"github.com/milvus-io/milvus/pkg/v2/util/merr"
	"github.com/milvus-io/milvus/pkg/v2/util/metric"
+	"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)

@@ -174,7 +176,7 @@ type IDScores[T PKType] struct {
	size int64
}

-func newIDScores[T PKType](idScores map[T]float32, searchParams *SearchParams) *IDScores[T] {
+func newIDScores[T PKType](idScores map[T]float32, searchParams *SearchParams, descendingOrder bool) *IDScores[T] {
	ids := make([]T, 0, len(idScores))
	for id := range idScores {
		ids = append(ids, id)
@@ -184,7 +186,11 @@ func newIDScores[T PKType](idScores map[T]float32, searchParams *SearchParams) *
		if idScores[ids[i]] == idScores[ids[j]] {
			return ids[i] < ids[j]
		}
-		return idScores[ids[i]] > idScores[ids[j]]
+		if descendingOrder {
+			return idScores[ids[i]] > idScores[ids[j]]
+		} else {
+			return idScores[ids[i]] < idScores[ids[j]]
+		}
	})
	topk := searchParams.offset + searchParams.limit
	if int64(len(ids)) > topk {
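The comparator change keeps ties deterministic (equal scores fall back to ascending ID) while letting callers choose the score direction. A standalone sketch of the same sort logic, assuming ascending order as used for pure distance metrics like L2:

package main

import (
	"fmt"
	"sort"
)

func main() {
	idScores := map[int64]float32{1: 0.9, 2: 0.9, 3: 0.4}
	ids := []int64{1, 2, 3}
	descendingOrder := false // ascending: smaller distance = more similar
	sort.Slice(ids, func(i, j int) bool {
		if idScores[ids[i]] == idScores[ids[j]] {
			return ids[i] < ids[j] // tie-break by ascending ID
		}
		if descendingOrder {
			return idScores[ids[i]] > idScores[ids[j]]
		}
		return idScores[ids[i]] < idScores[ids[j]]
	})
	fmt.Println(ids) // [3 1 2]
}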
@@ -208,10 +214,6 @@ func newIDScores[T PKType](idScores map[T]float32, searchParams *SearchParams) *
	return &ret
}

-func genIDGroupValueMap[T PKType]() map[T]any {
-	return nil
-}
-
 func groupScore[T PKType](group *Group[T], scorerType string) (float32, error) {
	switch scorerType {
	case maxScorer:
@@ -383,6 +385,21 @@ func getIds(ids *schemapb.IDs, start int64, size int64) any {
	return nil
}

+type scoreMergeFunc[T PKType] func(cols []*columns) map[T]float32
+
+func getMergeFunc[T PKType](name string) (scoreMergeFunc[T], error) {
+	switch strings.ToLower(name) {
+	case "max":
+		return maxMerge[T], nil
+	case "avg":
+		return avgMerge[T], nil
+	case "sum":
+		return sumMerge[T], nil
+	default:
+		return nil, fmt.Errorf("Unsupported score mode: [%s], only supports: [max, avg, sum]", name)
+	}
+}
+
 func maxMerge[T PKType](cols []*columns) map[T]float32 {
	srcScores := make(map[T]float32)

@@ -404,6 +421,54 @@ func maxMerge[T PKType](cols []*columns) map[T]float32 {
	return srcScores
}

+func avgMerge[T PKType](cols []*columns) map[T]float32 {
+	srcScores := make(map[T]*typeutil.Pair[float32, int32])
+
+	for _, col := range cols {
+		if col.size == 0 {
+			continue
+		}
+		scores := col.scores
+		ids := col.ids.([]T)
+
+		for idx, id := range ids {
+			if _, ok := srcScores[id]; !ok {
+				p := typeutil.NewPair[float32, int32](scores[idx], 1)
+				srcScores[id] = &p
+			} else {
+				srcScores[id].A += scores[idx]
+				srcScores[id].B += 1
+			}
+		}
+	}
+	retScores := make(map[T]float32, len(srcScores))
+	for id, item := range srcScores {
+		retScores[id] = item.A / float32(item.B)
+	}
+	return retScores
+}
+
+func sumMerge[T PKType](cols []*columns) map[T]float32 {
+	srcScores := make(map[T]float32)
+
+	for _, col := range cols {
+		if col.size == 0 {
+			continue
+		}
+		scores := col.scores
+		ids := col.ids.([]T)
+
+		for idx, id := range ids {
+			if _, ok := srcScores[id]; !ok {
+				srcScores[id] = scores[idx]
+			} else {
+				srcScores[id] += scores[idx]
+			}
+		}
+	}
+	return srcScores
+}
+
 func getPKType(collSchema *schemapb.CollectionSchema) (schemapb.DataType, error) {
	pkType := schemapb.DataType_None
	for _, field := range collSchema.Fields {
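To make the three merge modes concrete, here is a standalone sketch using plain maps instead of the internal columns type. It reproduces the mockCols(3) fixture from the new tests below: three search columns, where column i scores ID k as k+i:

package main

import "fmt"

func main() {
	maxS := map[int64]float32{}
	sumS := map[int64]float32{}
	cnt := map[int64]int{}
	for i := 0; i < 3; i++ {
		for k := int64(1); k <= 4; k++ {
			score := float32(k) + float32(i)
			if old, ok := maxS[k]; !ok || score > old {
				maxS[k] = score // max: keep the best score per ID
			}
			sumS[k] += score // sum: accumulate scores per ID
			cnt[k]++
		}
	}
	avgS := map[int64]float32{}
	for k, s := range sumS {
		avgS[k] = s / float32(cnt[k]) // avg: mean of per-column scores
	}
	fmt.Println(maxS) // map[1:3 2:4 3:5 4:6]
	fmt.Println(avgS) // map[1:2 2:3 3:4 4:5]
	fmt.Println(sumS) // map[1:6 2:9 3:12 4:15]
}

These outputs match the expectations asserted in TestScoreMode in the new util_test.go.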
@@ -436,3 +501,75 @@ func genIdGroupingMap(multipSearchResultData []*schemapb.SearchResultData) (map[
	}
	return idGroupValue, nil
}
+
+type normalizeFunc func(float32) float32
+
+func getNormalizeFunc(normScore bool, metrics string, toGreater bool) normalizeFunc {
+	if !normScore {
+		if !toGreater {
+			return func(distance float32) float32 {
+				return distance
+			}
+		}
+		switch strings.ToUpper(metrics) {
+		case metric.COSINE, metric.IP, metric.BM25:
+			return func(distance float32) float32 {
+				return distance
+			}
+		default:
+			return func(distance float32) float32 {
+				return 1.0 - 2*float32(math.Atan(float64(distance)))/math.Pi
+			}
+		}
+	}
+	switch strings.ToUpper(metrics) {
+	case metric.COSINE:
+		return func(distance float32) float32 {
+			return (1 + distance) * 0.5
+		}
+	case metric.IP:
+		return func(distance float32) float32 {
+			return 0.5 + float32(math.Atan(float64(distance)))/math.Pi
+		}
+	case metric.BM25:
+		return func(distance float32) float32 {
+			return 2 * float32(math.Atan(float64(distance))) / math.Pi
+		}
+	default:
+		return func(distance float32) float32 {
+			return 1.0 - 2*float32(math.Atan(float64(distance)))/math.Pi
+		}
+	}
+}
+
+// classifyMetricsOrder inspects the given metrics and determines
+// whether they contain mixed types and what the sorting order should be.
+//
+// Parameters:
+//
+//	metrics - A list of metric names (e.g., COSINE, IP, BM25).
+//
+// Returns:
+//
+//	mixed - true if the input contains both "larger-is-more-similar"
+//	        and "smaller-is-more-similar" metrics; false otherwise.
+//	sortDescending - true if results should be sorted in descending order
+//	        (larger value = more similar, e.g., COSINE, IP, BM25);
+//	        false if results should be sorted in ascending order
+//	        (smaller value = more similar, e.g., L2 distance).
+func classifyMetricsOrder(metrics []string) (mixed bool, sortDescending bool) {
+	countLargerIsBetter := 0  // Larger value = more similar
+	countSmallerIsBetter := 0 // Smaller value = more similar
+	for _, m := range metrics {
+		switch strings.ToUpper(m) {
+		case metric.COSINE, metric.IP, metric.BM25:
+			countLargerIsBetter++
+		default:
+			countSmallerIsBetter++
+		}
+	}
+	if countLargerIsBetter > 0 && countSmallerIsBetter > 0 {
+		return true, true
+	}
+	return false, countSmallerIsBetter == 0
+}
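Plugging a raw value of 1.0 into the branches above gives the numbers asserted in the new util_test.go: COSINE maps to (1+1)*0.5 = 1.0, IP to 0.5 + atan(1)/pi = 0.75, BM25 to 2*atan(1)/pi = 0.5, and the L2/to-greater fallback to 1 - 2*atan(1)/pi = 0.5. A standalone check of the same formulas:

package main

import (
	"fmt"
	"math"
)

func main() {
	d := 1.0
	fmt.Println("COSINE:", (1+d)*0.5)                  // 1.0
	fmt.Println("IP:    ", 0.5+math.Atan(d)/math.Pi)   // 0.75
	fmt.Println("BM25:  ", 2*math.Atan(d)/math.Pi)     // 0.5
	fmt.Println("L2:    ", 1.0-2*math.Atan(d)/math.Pi) // 0.5
}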
internal/util/function/rerank/util_test.go (new file, 204 lines)
@@ -0,0 +1,204 @@
/*
 * # Licensed to the LF AI & Data foundation under one
 * # or more contributor license agreements. See the NOTICE file
 * # distributed with this work for additional information
 * # regarding copyright ownership. The ASF licenses this file
 * # to you under the Apache License, Version 2.0 (the
 * # "License"); you may not use this file except in compliance
 * # with the License. You may obtain a copy of the License at
 * #
 * #     http://www.apache.org/licenses/LICENSE-2.0
 * #
 * # Unless required by applicable law or agreed to in writing, software
 * # distributed under the License is distributed on an "AS IS" BASIS,
 * # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * # See the License for the specific language governing permissions and
 * # limitations under the License.
 */

package rerank

import (
	"math"
	"testing"

	"github.com/stretchr/testify/suite"

	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
	"github.com/milvus-io/milvus/pkg/v2/util/metric"
)

func TestUtil(t *testing.T) {
	suite.Run(t, new(UtilSuite))
}

type UtilSuite struct {
	suite.Suite
	schema    *schemapb.CollectionSchema
	providers []string
}

func mockCols(num int) []*columns {
	cols := []*columns{}
	for i := 0; i < num; i++ {
		c := columns{
			size:   10,
			ids:    []int64{1, 2, 3, 4},
			scores: []float32{1.0 + float32(i), 2.0 + float32(i), 3.0 + float32(i), 4.0 + float32(i)},
		}
		cols = append(cols, &c)
	}
	return cols
}

func (s *UtilSuite) TestScoreMode() {
	{
		_, err := getMergeFunc[int64]("test")
		s.ErrorContains(err, "Unsupported score mode")
	}
	{
		f, err := getMergeFunc[int64]("avg")
		s.NoError(err)
		cols := mockCols(0)
		r := f(cols)
		s.Equal(0, len(r))
		cols = mockCols(1)
		r = f(cols)
		s.Equal(r, map[int64]float32{1: 1.0, 2: 2.0, 3: 3.0, 4: 4.0})
		cols = mockCols(3)
		r = f(cols)
		s.Equal(r, map[int64]float32{1: 2.0, 2: 3.0, 3: 4.0, 4: 5.0})
	}
	{
		f, err := getMergeFunc[int64]("max")
		s.NoError(err)
		cols := mockCols(0)
		r := f(cols)
		s.Equal(0, len(r))
		cols = mockCols(1)
		r = f(cols)
		s.Equal(r, map[int64]float32{1: 1.0, 2: 2.0, 3: 3.0, 4: 4.0})
		cols = mockCols(3)
		r = f(cols)
		s.Equal(r, map[int64]float32{1: 3.0, 2: 4.0, 3: 5.0, 4: 6.0})
	}
	{
		f, err := getMergeFunc[int64]("sum")
		s.NoError(err)
		cols := mockCols(0)
		r := f(cols)
		s.Equal(0, len(r))
		cols = mockCols(1)
		r = f(cols)
		s.Equal(r, map[int64]float32{1: 1.0, 2: 2.0, 3: 3.0, 4: 4.0})
		cols = mockCols(3)
		r = f(cols)
		s.Equal(r, map[int64]float32{1: 6.0, 2: 9.0, 3: 12.0, 4: 15.0})
	}
}

func (s *UtilSuite) TestFuctionNormalize() {
	{
		f := getNormalizeFunc(false, metric.COSINE, false)
		s.Equal(float32(1.0), f(1.0))
	}
	{
		f := getNormalizeFunc(true, metric.COSINE, true)
		s.Equal(float32((1+1.0)*0.5), f(1))
	}
	{
		f := getNormalizeFunc(false, metric.COSINE, true)
		s.Equal(float32(1.0), f(1.0))
	}
	{
		f := getNormalizeFunc(true, metric.COSINE, false)
		s.Equal(float32((1+1.0)*0.5), f(1))
	}
	{
		f := getNormalizeFunc(false, metric.IP, true)
		s.Equal(float32(1.0), f(1.0))
	}
	{
		f := getNormalizeFunc(true, metric.IP, false)
		s.Equal(0.5+float32(math.Atan(float64(1.0)))/math.Pi, f(1))
	}
	{
		f := getNormalizeFunc(false, metric.IP, true)
		s.Equal(float32(1.0), f(1.0))
	}
	{
		f := getNormalizeFunc(true, metric.IP, false)
		s.Equal(0.5+float32(math.Atan(float64(1.0)))/math.Pi, f(1))
	}
	{
		f := getNormalizeFunc(false, metric.BM25, false)
		s.Equal(float32(1.0), f(1.0))
	}
	{
		f := getNormalizeFunc(true, metric.BM25, false)
		s.Equal(2*float32(math.Atan(float64(1.0)))/math.Pi, f(1.0))
	}
	{
		f := getNormalizeFunc(false, metric.BM25, true)
		s.Equal(float32(1.0), f(1.0))
	}
	{
		f := getNormalizeFunc(true, metric.BM25, true)
		s.Equal(2*float32(math.Atan(float64(1.0)))/math.Pi, f(1.0))
	}
	{
		f := getNormalizeFunc(false, metric.L2, true)
		s.Equal((1.0 - 2*float32(math.Atan(float64(1.0)))/math.Pi), f(1.0))
	}
	{
		f := getNormalizeFunc(true, metric.L2, true)
		s.Equal((1.0 - 2*float32(math.Atan(float64(1.0)))/math.Pi), f(1.0))
	}
	{
		f := getNormalizeFunc(false, metric.L2, false)
		s.Equal(float32(1.0), f(1.0))
	}
	{
		f := getNormalizeFunc(true, metric.L2, false)
		s.Equal((1.0 - 2*float32(math.Atan(float64(1.0)))/math.Pi), f(1.0))
	}
}

func (s *UtilSuite) TestIsCrossMetrics() {
	{
		metrics := []string{metric.BM25}
		mixed, descending := classifyMetricsOrder(metrics)
		s.False(mixed)
		s.True(descending)
	}
	{
		metrics := []string{metric.BM25, metric.COSINE, metric.IP}
		mixed, descending := classifyMetricsOrder(metrics)
		s.False(mixed)
		s.True(descending)
	}
	{
		metrics := []string{metric.L2}
		mixed, descending := classifyMetricsOrder(metrics)
		s.False(mixed)
		s.False(descending)
	}
	{
		metrics := []string{metric.L2, metric.BM25}
		mixed, descending := classifyMetricsOrder(metrics)
		s.True(mixed)
		s.True(descending)
	}
	{
		metrics := []string{metric.L2, metric.COSINE}
		mixed, descending := classifyMetricsOrder(metrics)
		s.True(mixed)
		s.True(descending)
	}
	{
		metrics := []string{metric.L2, metric.IP}
		mixed, descending := classifyMetricsOrder(metrics)
		s.True(mixed)
		s.True(descending)
	}
}
@@ -22,13 +22,11 @@ import (
	"context"
	"encoding/json"
	"fmt"
-	"math"
	"strconv"
	"strings"

	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
	"github.com/milvus-io/milvus/pkg/v2/util/merr"
-	"github.com/milvus-io/milvus/pkg/v2/util/metric"
)

const (
@@ -87,11 +85,14 @@ func (weighted *WeightedFunction[T]) processOneSearchData(ctx context.Context, s
		return nil, merr.WrapErrParameterInvalid(fmt.Sprint(len(cols)), fmt.Sprint(len(weighted.weight)), "the length of weights param mismatch with ann search requests")
	}
	weightedScores := map[T]float32{}
+	isMixd, descendingOrder := classifyMetricsOrder(searchParams.searchMetrics)
	for i, col := range cols {
		if col.size == 0 {
			continue
		}
-		normFunc := getNormalizeFunc(weighted.needNorm, searchParams.searchMetrics[i])
+		// If the metrics are mixed (e.g. L2 + IP), some columns sort large-to-small and others
+		// small-to-large, so force the small-to-large (distance-style) scores onto a large-to-small scale.
+		normFunc := getNormalizeFunc(weighted.needNorm, searchParams.searchMetrics[i], isMixd)
		ids := col.ids.([]T)
		for j, id := range ids {
			if score, ok := weightedScores[id]; !ok {
@@ -104,18 +105,13 @@ func (weighted *WeightedFunction[T]) processOneSearchData(ctx context.Context, s
	if searchParams.isGrouping() {
		return newGroupingIDScores(weightedScores, searchParams, idGroup)
	}
-	return newIDScores(weightedScores, searchParams), nil
+	// If normalize is set, the final result is sorted from largest to smallest; otherwise it depends on descendingOrder
+	return newIDScores(weightedScores, searchParams, weighted.needNorm || descendingOrder), nil
}

 func (weighted *WeightedFunction[T]) Process(ctx context.Context, searchParams *SearchParams, inputs *rerankInputs) (*rerankOutputs, error) {
	outputs := newRerankOutputs(searchParams)
	for _, cols := range inputs.data {
-		for i, col := range cols {
-			metricType := searchParams.searchMetrics[i]
-			for j, score := range col.scores {
-				col.scores[j] = toGreaterScore(score, metricType)
-			}
-		}
		idScore, err := weighted.processOneSearchData(ctx, searchParams, cols, inputs.idGroupValue)
		if err != nil {
			return nil, err
@@ -124,31 +120,3 @@ func (weighted *WeightedFunction[T]) Process(ctx context.Context, searchParams *
	}
	return outputs, nil
}
-
-type normalizeFunc func(float32) float32
-
-func getNormalizeFunc(normScore bool, metrics string) normalizeFunc {
-	if !normScore {
-		return func(distance float32) float32 {
-			return distance
-		}
-	}
-	switch metrics {
-	case metric.COSINE:
-		return func(distance float32) float32 {
-			return (1 + distance) * 0.5
-		}
-	case metric.IP:
-		return func(distance float32) float32 {
-			return 0.5 + float32(math.Atan(float64(distance)))/math.Pi
-		}
-	case metric.BM25:
-		return func(distance float32) float32 {
-			return 2 * float32(math.Atan(float64(distance))) / math.Pi
-		}
-	default:
-		return func(distance float32) float32 {
-			return 1.0 - 2*float32(math.Atan(float64(distance)))/math.Pi
-		}
-	}
-}
-
@@ -20,7 +20,6 @@ package rerank

 import (
	"context"
-	"math"
	"testing"

	"github.com/stretchr/testify/suite"
@@ -28,7 +27,6 @@ import (
	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
	"github.com/milvus-io/milvus/internal/util/function/embedding"
-	"github.com/milvus-io/milvus/pkg/v2/util/metric"
)

 func TestWeightedFunction(t *testing.T) {
@@ -273,26 +271,3 @@ func (s *WeightedFunctionSuite) TestWeightedFuctionProcess() {
			ret.searchResultData.Ids.GetIntId().Data)
	}
}
-
-func (s *WeightedFunctionSuite) TestWeightedFuctionNormalize() {
-	{
-		f := getNormalizeFunc(false, metric.COSINE)
-		s.Equal(float32(1.0), f(1.0))
-	}
-	{
-		f := getNormalizeFunc(true, metric.COSINE)
-		s.Equal(float32((1+1.0)*0.5), f(1))
-	}
-	{
-		f := getNormalizeFunc(true, metric.IP)
-		s.Equal(0.5+float32(math.Atan(float64(1.0)))/math.Pi, f(1))
-	}
-	{
-		f := getNormalizeFunc(true, metric.BM25)
-		s.Equal(float32(2*math.Atan(float64(1.0)))/math.Pi, f(1.0))
-	}
-	{
-		f := getNormalizeFunc(true, metric.L2)
-		s.Equal((1.0 - 2*float32(math.Atan(float64(1.0)))/math.Pi), f(1.0))
-	}
-}