enhance: optimize decay function with configurable score merging and normalization (#44066)

- Add configurable score merge functions (max, avg, sum) for decay
reranking
- Introduce norm_score parameter to control score normalization behavior
- Refactor score normalization logic into reusable utility functions

#44051

Signed-off-by: junjie.jiang <junjie.jiang@zilliz.com>
This commit is contained in:
junjiejiangjjj 2025-09-23 14:18:06 +08:00 committed by GitHub
parent 1b7562a766
commit 71563d5d0e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 402 additions and 86 deletions

View File

@ -26,7 +26,6 @@ import (
"strings"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/pkg/v2/util/metric"
)
const (
@ -35,6 +34,9 @@ const (
offsetKey string = "offset"
decayKey string = "decay"
functionKey string = "function"
normsScorekey string = "norm_score"
scoreMode string = "score_mode"
)
const (
@ -51,6 +53,8 @@ type DecayFunction[T PKType, R int32 | int64 | float32 | float64] struct {
scale float64
offset float64
decay float64
needNorm bool
scoreFunc scoreMergeFunc[T]
reScorer decayReScorer
}
@ -97,7 +101,7 @@ func newDecayFunction(collSchema *schemapb.CollectionSchema, funcSchema *schemap
// T: PK Type, R: field type
func newFunction[T PKType, R int32 | int64 | float32 | float64](base *RerankBase, funcSchema *schemapb.FunctionSchema) (Reranker, error) {
var err error
decayFunc := &DecayFunction[T, R]{RerankBase: *base, offset: 0, decay: 0.5}
decayFunc := &DecayFunction[T, R]{RerankBase: *base, offset: 0, decay: 0.5, needNorm: false, scoreFunc: maxMerge[T]}
orginInit := false
scaleInit := false
for _, param := range funcSchema.Params {
@ -122,6 +126,18 @@ func newFunction[T PKType, R int32 | int64 | float32 | float64](base *RerankBase
if decayFunc.decay, err = strconv.ParseFloat(param.Value, 64); err != nil {
return nil, fmt.Errorf("Param decay:%s is not a number", param.Value)
}
case normsScorekey:
if needNorm, err := strconv.ParseBool(param.Value); err != nil {
return nil, fmt.Errorf("%s params must be true/false, bug got %s", normsScorekey, param.Value)
} else {
decayFunc.needNorm = needNorm
}
case scoreMode:
if f, err := getMergeFunc[T](param.Value); err != nil {
return nil, err
} else {
decayFunc.scoreFunc = f
}
default:
}
}
@ -159,17 +175,8 @@ func newFunction[T PKType, R int32 | int64 | float32 | float64](base *RerankBase
return decayFunc, nil
}
// toGreaterScore converts a raw distance into a "larger is more similar"
// score. Similarity metrics (COSINE/IP/BM25) already follow that order and
// pass through unchanged; any other metric (e.g. L2 distance) is mapped
// through an arctan-based transform that reverses the ordering.
func toGreaterScore(score float32, metricType string) float32 {
	upper := strings.ToUpper(metricType)
	if upper == metric.COSINE || upper == metric.IP || upper == metric.BM25 {
		return score
	}
	return 1.0 - 2*float32(math.Atan(float64(score)))/math.Pi
}
func (decay *DecayFunction[T, R]) processOneSearchData(ctx context.Context, searchParams *SearchParams, cols []*columns, idGroup map[any]any) (*IDScores[T], error) {
srcScores := maxMerge[T](cols)
srcScores := decay.scoreFunc(cols)
decayScores := map[T]float32{}
for _, col := range cols {
if col.size == 0 {
@ -189,16 +196,16 @@ func (decay *DecayFunction[T, R]) processOneSearchData(ctx context.Context, sear
if searchParams.isGrouping() {
return newGroupingIDScores(decayScores, searchParams, idGroup)
}
return newIDScores(decayScores, searchParams), nil
return newIDScores(decayScores, searchParams, true), nil
}
func (decay *DecayFunction[T, R]) Process(ctx context.Context, searchParams *SearchParams, inputs *rerankInputs) (*rerankOutputs, error) {
outputs := newRerankOutputs(searchParams)
for _, cols := range inputs.data {
for i, col := range cols {
metricType := searchParams.searchMetrics[i]
normFunc := getNormalizeFunc(decay.needNorm, searchParams.searchMetrics[i], true)
for j, score := range col.scores {
col.scores[j] = toGreaterScore(score, metricType)
col.scores[j] = normFunc(score)
}
}
idScore, err := decay.processOneSearchData(ctx, searchParams, cols, inputs.idGroupValue)

View File

@ -73,6 +73,31 @@ func (s *DecayFunctionSuite) TestNewDecayErrors() {
_, err := newDecayFunction(schema, functionSchema)
s.ErrorContains(err, "Rerank function output field names should be empty")
}
{
functionSchema := &schemapb.FunctionSchema{
Name: "test",
Type: schemapb.FunctionType_Rerank,
InputFieldNames: []string{"ts"},
OutputFieldNames: []string{},
Params: []*commonpb.KeyValuePair{
{Key: originKey, Value: "4"},
{Key: scaleKey, Value: "4"},
{Key: offsetKey, Value: "4"},
{Key: decayKey, Value: "0.5"},
{Key: functionKey, Value: "gauss"},
{Key: normsScorekey, Value: "unknow"},
{Key: scoreMode, Value: "avg"},
},
}
_, err := newDecayFunction(schema, functionSchema)
s.ErrorContains(err, "must be true/false")
functionSchema.Params[5] = &commonpb.KeyValuePair{Key: normsScorekey, Value: "true"}
functionSchema.Params[6] = &commonpb.KeyValuePair{Key: scoreMode, Value: "unknow"}
_, err = newDecayFunction(schema, functionSchema)
s.ErrorContains(err, "Unsupport score mode")
}
{
functionSchema.OutputFieldNames = []string{}

View File

@ -180,7 +180,7 @@ func (model *ModelFunction[T]) processOneSearchData(ctx context.Context, searchP
if searchParams.isGrouping() {
return newGroupingIDScores(rerankScores, searchParams, idGroup)
}
return newIDScores(rerankScores, searchParams), nil
return newIDScores(rerankScores, searchParams, true), nil
}
func (model *ModelFunction[T]) Process(ctx context.Context, searchParams *SearchParams, inputs *rerankInputs) (*rerankOutputs, error) {

View File

@ -85,7 +85,7 @@ func (rrf *RRFFunction[T]) processOneSearchData(ctx context.Context, searchParam
if searchParams.isGrouping() {
return newGroupingIDScores(rrfScores, searchParams, idGroup)
}
return newIDScores(rrfScores, searchParams), nil
return newIDScores(rrfScores, searchParams, true), nil
}
func (rrf *RRFFunction[T]) Process(ctx context.Context, searchParams *SearchParams, inputs *rerankInputs) (*rerankOutputs, error) {

View File

@ -22,9 +22,11 @@ import (
"fmt"
"math"
"sort"
"strings"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/metric"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
@ -174,7 +176,7 @@ type IDScores[T PKType] struct {
size int64
}
func newIDScores[T PKType](idScores map[T]float32, searchParams *SearchParams) *IDScores[T] {
func newIDScores[T PKType](idScores map[T]float32, searchParams *SearchParams, descendingOrder bool) *IDScores[T] {
ids := make([]T, 0, len(idScores))
for id := range idScores {
ids = append(ids, id)
@ -184,7 +186,11 @@ func newIDScores[T PKType](idScores map[T]float32, searchParams *SearchParams) *
if idScores[ids[i]] == idScores[ids[j]] {
return ids[i] < ids[j]
}
return idScores[ids[i]] > idScores[ids[j]]
if descendingOrder {
return idScores[ids[i]] > idScores[ids[j]]
} else {
return idScores[ids[i]] < idScores[ids[j]]
}
})
topk := searchParams.offset + searchParams.limit
if int64(len(ids)) > topk {
@ -208,10 +214,6 @@ func newIDScores[T PKType](idScores map[T]float32, searchParams *SearchParams) *
return &ret
}
// genIDGroupValueMap returns a nil id -> group-value map; callers treat nil
// as "no grouping information available".
func genIDGroupValueMap[T PKType]() map[T]any {
	var groups map[T]any
	return groups
}
func groupScore[T PKType](group *Group[T], scorerType string) (float32, error) {
switch scorerType {
case maxScorer:
@ -383,6 +385,21 @@ func getIds(ids *schemapb.IDs, start int64, size int64) any {
return nil
}
// scoreMergeFunc combines the per-request score columns into a single
// id -> score map.
type scoreMergeFunc[T PKType] func(cols []*columns) map[T]float32

// getMergeFunc resolves a score_mode name (case-insensitive) to its merge
// implementation. Supported modes: max, avg, sum.
func getMergeFunc[T PKType](name string) (scoreMergeFunc[T], error) {
	mergers := map[string]scoreMergeFunc[T]{
		"max": maxMerge[T],
		"avg": avgMerge[T],
		"sum": sumMerge[T],
	}
	if f, ok := mergers[strings.ToLower(name)]; ok {
		return f, nil
	}
	return nil, fmt.Errorf("Unsupport score mode: [%s], only supports: [max, avg, sum]", name)
}
func maxMerge[T PKType](cols []*columns) map[T]float32 {
srcScores := make(map[T]float32)
@ -404,6 +421,54 @@ func maxMerge[T PKType](cols []*columns) map[T]float32 {
return srcScores
}
// avgMerge merges multiple score columns into one map by averaging the
// scores of every id across all columns it appears in. Each Pair tracks the
// running sum (A) and occurrence count (B) for one id.
func avgMerge[T PKType](cols []*columns) map[T]float32 {
	accum := make(map[T]*typeutil.Pair[float32, int32])
	for _, col := range cols {
		// Empty columns contribute nothing.
		if col.size == 0 {
			continue
		}
		ids := col.ids.([]T)
		for i, id := range ids {
			entry, seen := accum[id]
			if !seen {
				p := typeutil.NewPair[float32, int32](col.scores[i], 1)
				accum[id] = &p
				continue
			}
			entry.A += col.scores[i]
			entry.B++
		}
	}
	avg := make(map[T]float32, len(accum))
	for id, entry := range accum {
		avg[id] = entry.A / float32(entry.B)
	}
	return avg
}
// sumMerge merges multiple score columns into one map by summing the scores
// of every id across all columns it appears in.
func sumMerge[T PKType](cols []*columns) map[T]float32 {
	srcScores := make(map[T]float32)
	for _, col := range cols {
		// Empty columns contribute nothing.
		if col.size == 0 {
			continue
		}
		scores := col.scores
		ids := col.ids.([]T)
		for idx, id := range ids {
			// The map's float32 zero value makes first-seen and
			// already-seen ids the same operation — no exists-check
			// branch needed.
			srcScores[id] += scores[idx]
		}
	}
	return srcScores
}
func getPKType(collSchema *schemapb.CollectionSchema) (schemapb.DataType, error) {
pkType := schemapb.DataType_None
for _, field := range collSchema.Fields {
@ -436,3 +501,75 @@ func genIdGroupingMap(multipSearchResultData []*schemapb.SearchResultData) (map[
}
return idGroupValue, nil
}
// normalizeFunc maps a raw search distance/score to a rescored value.
type normalizeFunc func(float32) float32

// getNormalizeFunc returns the score transform for one metric type.
//
// When normScore is false, scores pass through unchanged unless toGreater is
// set, in which case "smaller is better" metrics (anything other than
// COSINE/IP/BM25, e.g. L2) are flipped into descending order. When normScore
// is true, each metric is normalized into a comparable range.
func getNormalizeFunc(normScore bool, metrics string, toGreater bool) normalizeFunc {
	identity := func(distance float32) float32 {
		return distance
	}
	// Arctan-based transform flipping "smaller is better" distances into
	// descending scores; shared by both normalized and unnormalized paths.
	flip := func(distance float32) float32 {
		return 1.0 - 2*float32(math.Atan(float64(distance)))/math.Pi
	}
	m := strings.ToUpper(metrics)
	if !normScore {
		if !toGreater {
			return identity
		}
		switch m {
		case metric.COSINE, metric.IP, metric.BM25:
			return identity
		default:
			return flip
		}
	}
	switch m {
	case metric.COSINE:
		// COSINE in [-1, 1] -> [0, 1].
		return func(distance float32) float32 {
			return (1 + distance) * 0.5
		}
	case metric.IP:
		return func(distance float32) float32 {
			return 0.5 + float32(math.Atan(float64(distance)))/math.Pi
		}
	case metric.BM25:
		return func(distance float32) float32 {
			return 2 * float32(math.Atan(float64(distance))) / math.Pi
		}
	default:
		return flip
	}
}
// classifyMetricsOrder inspects the given metrics and determines
// whether they contain mixed types and what the sorting order should be.
// (The doc comment previously named a nonexistent analyzeMetricsType; Go doc
// convention requires it to start with the function's actual name.)
//
// Parameters:
//
//	metrics - A list of metric names (e.g., COSINE, IP, BM25).
//
// Returns:
//
//	mixed - true if the input contains both "larger-is-more-similar"
//	        and "smaller-is-more-similar" metrics; false otherwise.
//	sortDescending - true if results should be sorted in descending order
//	        (larger value = more similar, e.g., COSINE, IP, BM25);
//	        false if results should be sorted in ascending order
//	        (smaller value = more similar, e.g., L2 distance).
func classifyMetricsOrder(metrics []string) (mixed bool, sortDescending bool) {
	var hasLargerIsBetter, hasSmallerIsBetter bool
	for _, m := range metrics {
		switch strings.ToUpper(m) {
		case metric.COSINE, metric.IP, metric.BM25:
			hasLargerIsBetter = true
		default:
			hasSmallerIsBetter = true
		}
	}
	if hasLargerIsBetter && hasSmallerIsBetter {
		// Mixed metrics force descending order.
		return true, true
	}
	// An empty metrics list reports (false, true), matching the original
	// counter-based behavior.
	return false, !hasSmallerIsBetter
}

View File

@ -0,0 +1,204 @@
/*
* # Licensed to the LF AI & Data foundation under one
* # or more contributor license agreements. See the NOTICE file
* # distributed with this work for additional information
* # regarding copyright ownership. The ASF licenses this file
* # to you under the Apache License, Version 2.0 (the
* # "License"); you may not use this file except in compliance
* # with the License. You may obtain a copy of the License at
* #
* # http://www.apache.org/licenses/LICENSE-2.0
* #
* # Unless required by applicable law or agreed to in writing, software
* # distributed under the License is distributed on an "AS IS" BASIS,
* # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* # See the License for the specific language governing permissions and
* # limitations under the License.
*/
package rerank
import (
"math"
"testing"
"github.com/stretchr/testify/suite"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/pkg/v2/util/metric"
)
// TestUtil wires UtilSuite into the standard `go test` runner.
func TestUtil(t *testing.T) {
	suite.Run(t, new(UtilSuite))
}
// UtilSuite hosts unit tests for the rerank utility helpers
// (score merging, normalization, metrics classification).
type UtilSuite struct {
	suite.Suite
	// schema/providers appear to be suite fixtures; no test in this view
	// reads them — NOTE(review): confirm they are still needed.
	schema    *schemapb.CollectionSchema
	providers []string
}
// mockCols builds num score columns over the fixed ids {1, 2, 3, 4};
// column i carries scores {1+i, 2+i, 3+i, 4+i}.
// NOTE(review): size (10) does not match len(ids) (4); the merge functions
// only test size != 0, so this works, but confirm the mismatch is intended.
func mockCols(num int) []*columns {
	cols := make([]*columns, 0, num)
	for i := 0; i < num; i++ {
		base := float32(i)
		cols = append(cols, &columns{
			size:   10,
			ids:    []int64{1, 2, 3, 4},
			scores: []float32{base + 1.0, base + 2.0, base + 3.0, base + 4.0},
		})
	}
	return cols
}
// TestScoreMode verifies merge-function resolution and the max/avg/sum merge
// semantics over zero, one, and three mock columns (table-driven).
func (s *UtilSuite) TestScoreMode() {
	// Unknown mode names are rejected.
	_, err := getMergeFunc[int64]("test")
	s.ErrorContains(err, "Unsupport score mode")

	cases := []struct {
		mode      string
		oneCol    map[int64]float32
		threeCols map[int64]float32
	}{
		{"avg", map[int64]float32{1: 1.0, 2: 2.0, 3: 3.0, 4: 4.0}, map[int64]float32{1: 2.0, 2: 3.0, 3: 4.0, 4: 5.0}},
		{"max", map[int64]float32{1: 1.0, 2: 2.0, 3: 3.0, 4: 4.0}, map[int64]float32{1: 3.0, 2: 4.0, 3: 5.0, 4: 6.0}},
		{"sum", map[int64]float32{1: 1.0, 2: 2.0, 3: 3.0, 4: 4.0}, map[int64]float32{1: 6.0, 2: 9.0, 3: 12.0, 4: 15.0}},
	}
	for _, tc := range cases {
		f, err := getMergeFunc[int64](tc.mode)
		s.NoError(err)
		// No input columns -> empty result.
		s.Equal(0, len(f(mockCols(0))))
		// A single column merges to itself under every mode.
		s.Equal(tc.oneCol, f(mockCols(1)))
		s.Equal(tc.threeCols, f(mockCols(3)))
	}
}
// TestFuctionNormalize checks getNormalizeFunc for every metric against
// combinations of norm_score and toGreater, probing with input 1.0.
// Table-driven; two exact-duplicate IP cases from the original were dropped.
func (s *UtilSuite) TestFuctionNormalize() {
	atan1 := float32(math.Atan(float64(1.0)))
	// Expected outputs for input 1.0, using the same expressions as the
	// implementation so float32 comparisons are exact.
	identity := float32(1.0)
	cosineNorm := float32((1 + 1.0) * 0.5)
	ipNorm := 0.5 + atan1/math.Pi
	bm25Norm := 2 * atan1 / math.Pi
	l2Flip := 1.0 - 2*atan1/math.Pi
	cases := []struct {
		normScore bool
		metrics   string
		toGreater bool
		want      float32
	}{
		{false, metric.COSINE, false, identity},
		{true, metric.COSINE, true, cosineNorm},
		{false, metric.COSINE, true, identity},
		{true, metric.COSINE, false, cosineNorm},
		{false, metric.IP, true, identity},
		{true, metric.IP, false, ipNorm},
		{false, metric.BM25, false, identity},
		{true, metric.BM25, false, bm25Norm},
		{false, metric.BM25, true, identity},
		{true, metric.BM25, true, bm25Norm},
		{false, metric.L2, true, l2Flip},
		{true, metric.L2, true, l2Flip},
		{false, metric.L2, false, identity},
		{true, metric.L2, false, l2Flip},
	}
	for _, tc := range cases {
		f := getNormalizeFunc(tc.normScore, tc.metrics, tc.toGreater)
		s.Equal(tc.want, f(1.0))
	}
}
// TestIsCrossMetrics checks classifyMetricsOrder for pure descending-order
// metric sets, pure ascending sets, and mixed sets (which force descending).
func (s *UtilSuite) TestIsCrossMetrics() {
	cases := []struct {
		metrics    []string
		mixed      bool
		descending bool
	}{
		{[]string{metric.BM25}, false, true},
		{[]string{metric.BM25, metric.COSINE, metric.IP}, false, true},
		{[]string{metric.L2}, false, false},
		{[]string{metric.L2, metric.BM25}, true, true},
		{[]string{metric.L2, metric.COSINE}, true, true},
		{[]string{metric.L2, metric.IP}, true, true},
	}
	for _, tc := range cases {
		mixed, descending := classifyMetricsOrder(tc.metrics)
		s.Equal(tc.mixed, mixed)
		s.Equal(tc.descending, descending)
	}
}

View File

@ -22,13 +22,11 @@ import (
"context"
"encoding/json"
"fmt"
"math"
"strconv"
"strings"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/metric"
)
const (
@ -87,11 +85,14 @@ func (weighted *WeightedFunction[T]) processOneSearchData(ctx context.Context, s
return nil, merr.WrapErrParameterInvalid(fmt.Sprint(len(cols)), fmt.Sprint(len(weighted.weight)), "the length of weights param mismatch with ann search requests")
}
weightedScores := map[T]float32{}
isMixd, descendingOrder := classifyMetricsOrder(searchParams.searchMetrics)
for i, col := range cols {
if col.size == 0 {
continue
}
normFunc := getNormalizeFunc(weighted.needNorm, searchParams.searchMetrics[i])
// If it is a mixed metric (L2 + IP), with both large to small sorting and small to large sorting,
// force the small to large sorting scores to be converted to large to small sorting
normFunc := getNormalizeFunc(weighted.needNorm, searchParams.searchMetrics[i], isMixd)
ids := col.ids.([]T)
for j, id := range ids {
if score, ok := weightedScores[id]; !ok {
@ -104,18 +105,13 @@ func (weighted *WeightedFunction[T]) processOneSearchData(ctx context.Context, s
if searchParams.isGrouping() {
return newGroupingIDScores(weightedScores, searchParams, idGroup)
}
return newIDScores(weightedScores, searchParams), nil
// If normlize is set, the final result is sorted from largest to smallest, otherwise it depends on descendingOrder
return newIDScores(weightedScores, searchParams, weighted.needNorm || descendingOrder), nil
}
func (weighted *WeightedFunction[T]) Process(ctx context.Context, searchParams *SearchParams, inputs *rerankInputs) (*rerankOutputs, error) {
outputs := newRerankOutputs(searchParams)
for _, cols := range inputs.data {
for i, col := range cols {
metricType := searchParams.searchMetrics[i]
for j, score := range col.scores {
col.scores[j] = toGreaterScore(score, metricType)
}
}
idScore, err := weighted.processOneSearchData(ctx, searchParams, cols, inputs.idGroupValue)
if err != nil {
return nil, err
@ -124,31 +120,3 @@ func (weighted *WeightedFunction[T]) Process(ctx context.Context, searchParams *
}
return outputs, nil
}
// normalizeFunc maps a raw search distance/score to a rescored value.
type normalizeFunc func(float32) float32

// getNormalizeFunc returns the per-metric normalizer used to bring scores
// into a comparable range; when normScore is false the score passes through
// unchanged.
func getNormalizeFunc(normScore bool, metrics string) normalizeFunc {
	passthrough := func(distance float32) float32 {
		return distance
	}
	if !normScore {
		return passthrough
	}
	switch metrics {
	case metric.COSINE:
		// COSINE in [-1, 1] -> [0, 1].
		return func(distance float32) float32 {
			return (1 + distance) * 0.5
		}
	case metric.IP:
		return func(distance float32) float32 {
			return 0.5 + float32(math.Atan(float64(distance)))/math.Pi
		}
	case metric.BM25:
		return func(distance float32) float32 {
			return 2 * float32(math.Atan(float64(distance))) / math.Pi
		}
	default:
		// Distance metrics (e.g. L2): arctan transform flips ordering.
		return func(distance float32) float32 {
			return 1.0 - 2*float32(math.Atan(float64(distance)))/math.Pi
		}
	}
}

View File

@ -20,7 +20,6 @@ package rerank
import (
"context"
"math"
"testing"
"github.com/stretchr/testify/suite"
@ -28,7 +27,6 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/util/function/embedding"
"github.com/milvus-io/milvus/pkg/v2/util/metric"
)
func TestWeightedFunction(t *testing.T) {
@ -273,26 +271,3 @@ func (s *WeightedFunctionSuite) TestWeightedFuctionProcess() {
ret.searchResultData.Ids.GetIntId().Data)
}
}
// TestWeightedFuctionNormalize checks getNormalizeFunc per metric with
// input 1.0, table-driven; expected values reuse the implementation's
// expressions so float32 comparisons are exact.
func (s *WeightedFunctionSuite) TestWeightedFuctionNormalize() {
	cases := []struct {
		normScore bool
		metrics   string
		want      float32
	}{
		{false, metric.COSINE, 1.0},
		{true, metric.COSINE, (1 + 1.0) * 0.5},
		{true, metric.IP, 0.5 + float32(math.Atan(float64(1.0)))/math.Pi},
		{true, metric.BM25, float32(2*math.Atan(float64(1.0))) / math.Pi},
		{true, metric.L2, 1.0 - 2*float32(math.Atan(float64(1.0)))/math.Pi},
	}
	for _, tc := range cases {
		f := getNormalizeFunc(tc.normScore, tc.metrics)
		s.Equal(tc.want, f(1.0))
	}
}