From 71dc135289edf4c69daa463af483bf1ba4697966 Mon Sep 17 00:00:00 2001 From: zhuwenxing Date: Tue, 19 Aug 2025 14:57:48 +0800 Subject: [PATCH] test: add rerank function testcases in go client (#43891) /kind improvement Signed-off-by: zhuwenxing --- client/entity/function.go | 13 +- client/entity/function_test.go | 30 + .../go_client/testcases/helper/test_setup.go | 7 +- .../testcases/reranker_function_test.go | 683 ++++++++++++++++++ 4 files changed, 731 insertions(+), 2 deletions(-) create mode 100644 tests/go_client/testcases/reranker_function_test.go diff --git a/client/entity/function.go b/client/entity/function.go index 1c172bbffd..26e691347f 100644 --- a/client/entity/function.go +++ b/client/entity/function.go @@ -17,7 +17,9 @@ package entity import ( + "encoding/json" "fmt" + "reflect" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" ) @@ -74,7 +76,16 @@ func (f *Function) WithType(funcType FunctionType) *Function { } func (f *Function) WithParam(key string, value any) *Function { - f.Params[key] = fmt.Sprintf("%v", value) + // Handle slices by converting to JSON format + if reflect.TypeOf(value).Kind() == reflect.Slice { + if jsonBytes, err := json.Marshal(value); err == nil { + f.Params[key] = string(jsonBytes) + } else { + f.Params[key] = fmt.Sprintf("%v", value) + } + } else { + f.Params[key] = fmt.Sprintf("%v", value) + } return f } diff --git a/client/entity/function_test.go b/client/entity/function_test.go index 6e73596bcd..5c7a73305f 100644 --- a/client/entity/function_test.go +++ b/client/entity/function_test.go @@ -46,3 +46,33 @@ func TestFunctionSchema(t *testing.T) { assert.Equal(t, function.Params, nf.Params) } } + +func TestFunctionWithParamSliceHandling(t *testing.T) { + // Test slice value handling - the main change in this PR + f := NewFunction().WithParam("slice_key", []string{"a", "b", "c"}) + assert.Equal(t, `["a","b","c"]`, f.Params["slice_key"]) + + // Test int slice + f = NewFunction().WithParam("int_slice_key", []int{1, 2, 3}) + assert.Equal(t, "[1,2,3]", f.Params["int_slice_key"]) + + // Test float slice + f = NewFunction().WithParam("float_slice_key", []float64{1.1, 2.2, 3.3}) + assert.Equal(t, "[1.1,2.2,3.3]", f.Params["float_slice_key"]) + + // Test empty slice + f = NewFunction().WithParam("empty_slice_key", []string{}) + assert.Equal(t, "[]", f.Params["empty_slice_key"]) + + // Test the JSON marshal error fallback path + type unmarshalableType struct { + Channel chan int + } + f = NewFunction().WithParam("complex_slice_key", []unmarshalableType{{Channel: make(chan int)}}) + // This should fallback to fmt.Sprintf since channels can't be JSON marshaled + assert.Contains(t, f.Params["complex_slice_key"], "0x") + + // Test non-slice value (existing behavior should remain unchanged) + f = NewFunction().WithParam("string_key", "string_value") + assert.Equal(t, "string_value", f.Params["string_key"]) +} diff --git a/tests/go_client/testcases/helper/test_setup.go b/tests/go_client/testcases/helper/test_setup.go index c27b7488a9..40e6c9224b 100644 --- a/tests/go_client/testcases/helper/test_setup.go +++ b/tests/go_client/testcases/helper/test_setup.go @@ -15,11 +15,12 @@ import ( ) var ( - addr = flag.String("addr", "localhost:19530", "server host and port") + addr = flag.String("addr", "http://localhost:19530", "server host and port") user = flag.String("user", "root", "user") password = flag.String("password", "Milvus", "password") logLevel = flag.String("log.level", "info", "log level for test") teiEndpoint = flag.String("tei_endpoint", "http://text-embeddings-service.milvus-ci.svc.cluster.local:80", "TEI service endpoint for text embedding tests") + teiRerankerEndpoint = flag.String("tei_reranker_uri", "http://text-rerank-service.milvus-ci.svc.cluster.local:80", "TEI reranker service endpoint") teiModelDim = flag.Int("tei_model_dim", 768, "Vector dimension for text embedding model") defaultClientConfig *client.ClientConfig ) @@ -48,6 +49,10 @@ func GetTEIEndpoint() string { return *teiEndpoint } +func GetTEIRerankerEndpoint() string { + return *teiRerankerEndpoint +} + func GetTEIModelDim() int { return *teiModelDim } diff --git a/tests/go_client/testcases/reranker_function_test.go b/tests/go_client/testcases/reranker_function_test.go new file mode 100644 index 0000000000..7a0c79efab --- /dev/null +++ b/tests/go_client/testcases/reranker_function_test.go @@ -0,0 +1,683 @@ +package testcases + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/milvus-io/milvus/client/v2/entity" + client "github.com/milvus-io/milvus/client/v2/milvusclient" + "github.com/milvus-io/milvus/tests/go_client/base" + "github.com/milvus-io/milvus/tests/go_client/common" + hp "github.com/milvus-io/milvus/tests/go_client/testcases/helper" +) + +const ( + defaultTimestamp = int64(1700000000) +) + +func createRerankFunctionTestCollection(ctx context.Context, t *testing.T, mc *base.MilvusClient, enableText bool) (*hp.CollectionPrepare, *entity.Schema) { + fields := hp.AllFields + if enableText { + fields = hp.FullTextSearch + } + + prepare, schema := hp.CollPrepare.CreateCollection(ctx, t, mc, hp.NewCreateCollectionParams(fields), + hp.TNewFieldsOption(), hp.TNewSchemaOption().TWithEnableDynamicField(true), hp.TWithConsistencyLevel(entity.ClStrong)) + + prepare.CreateIndex(ctx, t, mc, hp.TNewIndexParams(schema)) + prepare.Load(ctx, t, mc, hp.NewLoadParams(schema.CollectionName)) + + prepare.InsertData(ctx, t, mc, hp.NewInsertParams(schema), hp.TNewDataOption().TWithNb(common.DefaultNb)) + prepare.FlushData(ctx, t, mc, schema.CollectionName) + + return prepare, schema +} + +func generateTestQueries() []string { + return []string{ + "machine learning algorithms for time series forecasting", + "deep neural networks and artificial intelligence", + "data science and statistical analysis methods", + "computer vision and image recognition", + "natural language processing techniques", + } +} + +func validateRerankFunctionResults(t *testing.T, results []client.ResultSet, expectedLimit int) { + require.Greater(t, len(results), 0, "Should have search results") + for _, res := range results { + require.LessOrEqual(t, res.ResultCount, expectedLimit, "Result count should not exceed limit") + require.Equal(t, res.IDs.Len(), res.ResultCount, "IDs length should match result count") + require.Equal(t, len(res.Scores), res.ResultCount, "Scores length should match result count") + } +} + +func TestRerankFunctionWeighted(t *testing.T) { + t.Parallel() + ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout) + mc := hp.CreateDefaultMilvusClient(ctx, t) + + _, schema := createRerankFunctionTestCollection(ctx, t, mc, false) + + queryVec1 := hp.GenSearchVectors(common.DefaultNq, common.DefaultDim, entity.FieldTypeFloatVector) + queryVec2 := hp.GenSearchVectors(common.DefaultNq, common.DefaultDim, entity.FieldTypeFloat16Vector) + + testCases := []struct { + name string + weights []float64 + normScore bool + }{ + {"equal_weights", []float64{0.5, 0.5}, true}, + {"prefer_first", []float64{0.8, 0.2}, true}, + {"prefer_second", []float64{0.3, 0.7}, true}, + {"no_normalization", []float64{0.6, 0.4}, false}, + {"sum_not_one", []float64{0.3, 0.4}, true}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + weightedReranker := entity.NewFunction(). + WithName("test_weighted_"+tc.name). + WithType(entity.FunctionTypeRerank). + WithInputFields(). + WithParam("reranker", "weighted"). + WithParam("weights", tc.weights). + WithParam("norm_score", tc.normScore) + + annReq1 := client.NewAnnRequest(common.DefaultFloatVecFieldName, common.DefaultLimit, queryVec1...) + annReq2 := client.NewAnnRequest(common.DefaultFloat16VecFieldName, common.DefaultLimit, queryVec2...) + + results, err := mc.HybridSearch(ctx, client.NewHybridSearchOption( + schema.CollectionName, common.DefaultLimit, annReq1, annReq2, + ).WithFunctionRerankers(weightedReranker).WithOutputFields("*")) + + common.CheckErr(t, err, true) + validateRerankFunctionResults(t, results, common.DefaultLimit) + }) + } +} + +func TestRerankFunctionDecay(t *testing.T) { + t.Parallel() + ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout) + mc := hp.CreateDefaultMilvusClient(ctx, t) + + _, schema := createRerankFunctionTestCollection(ctx, t, mc, false) + + queryVec1 := hp.GenSearchVectors(common.DefaultNq, common.DefaultDim, entity.FieldTypeFloatVector) + queryVec2 := hp.GenSearchVectors(common.DefaultNq, common.DefaultDim, entity.FieldTypeFloat16Vector) + + testCases := []struct { + name string + function string + origin int64 + scale int + decay float64 + }{ + {"linear_decay", "linear", defaultTimestamp + 3600, 3600, 0.1}, + {"exp_decay", "exp", defaultTimestamp + 7200, 7200, 0.2}, + {"gauss_decay", "gauss", defaultTimestamp + 1800, 1800, 0.15}, + {"large_scale", "linear", defaultTimestamp, 86400, 0.05}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + decayReranker := entity.NewFunction(). + WithName("test_decay_"+tc.name). + WithType(entity.FunctionTypeRerank). + WithInputFields(common.DefaultInt64FieldName). + WithParam("reranker", "decay"). + WithParam("function", tc.function). + WithParam("origin", tc.origin). + WithParam("scale", tc.scale). + WithParam("decay", tc.decay) + + annReq1 := client.NewAnnRequest(common.DefaultFloatVecFieldName, common.DefaultLimit, queryVec1...) + annReq2 := client.NewAnnRequest(common.DefaultFloat16VecFieldName, common.DefaultLimit, queryVec2...) + + results, err := mc.HybridSearch(ctx, client.NewHybridSearchOption( + schema.CollectionName, common.DefaultLimit, annReq1, annReq2, + ).WithFunctionRerankers(decayReranker).WithOutputFields("*")) + + common.CheckErr(t, err, true) + validateRerankFunctionResults(t, results, common.DefaultLimit) + + for _, res := range results { + require.Greater(t, res.ResultCount, 0, "Should have results for decay reranker") + timestampCol := res.GetColumn(common.DefaultInt64FieldName) + require.NotNil(t, timestampCol, "Should have timestamp field in results") + } + }) + } +} + +func TestRerankFunctionModel(t *testing.T) { + ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout*3) + mc := hp.CreateDefaultMilvusClient(ctx, t) + + _, schema := createRerankFunctionTestCollection(ctx, t, mc, true) + + queryVec1 := hp.GenSearchVectors(common.DefaultNq, common.DefaultDim*2, entity.FieldTypeSparseVector) + queryVec2 := hp.GenSearchVectors(common.DefaultNq, common.DefaultDim, entity.FieldTypeSparseVector) + + queries := generateTestQueries() + + testCases := []struct { + name string + provider string + endpoint string + queries []string + }{ + {"tei_provider", "tei", hp.GetTEIRerankerEndpoint(), queries[:common.DefaultNq]}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + modelReranker := entity.NewFunction(). + WithName("test_model_"+tc.name). + WithType(entity.FunctionTypeRerank). + WithInputFields(common.DefaultTextFieldName). + WithParam("reranker", "model"). + WithParam("provider", tc.provider). + WithParam("queries", tc.queries). + WithParam("endpoint", tc.endpoint) + + annReq1 := client.NewAnnRequest(common.DefaultTextSparseVecFieldName, common.DefaultLimit, queryVec1...) + annReq2 := client.NewAnnRequest(common.DefaultTextSparseVecFieldName, common.DefaultLimit, queryVec2...) + + results, err := mc.HybridSearch(ctx, client.NewHybridSearchOption( + schema.CollectionName, common.DefaultLimit, annReq1, annReq2, + ).WithFunctionRerankers(modelReranker).WithOutputFields("*")) + + common.CheckErr(t, err, true) + validateRerankFunctionResults(t, results, common.DefaultLimit) + + for _, res := range results { + textCol := res.GetColumn(common.DefaultTextFieldName) + require.NotNil(t, textCol, "Should have text field in results") + } + }) + } +} + +func TestRerankFunctionInvalidParams(t *testing.T) { + t.Parallel() + ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout) + mc := hp.CreateDefaultMilvusClient(ctx, t) + + _, schema := createRerankFunctionTestCollection(ctx, t, mc, false) + + queryVec1 := hp.GenSearchVectors(1, common.DefaultDim, entity.FieldTypeFloatVector) + queryVec2 := hp.GenSearchVectors(1, common.DefaultDim, entity.FieldTypeFloat16Vector) + + annReq1 := client.NewAnnRequest(common.DefaultFloatVecFieldName, common.DefaultLimit, queryVec1...) + annReq2 := client.NewAnnRequest(common.DefaultFloat16VecFieldName, common.DefaultLimit, queryVec2...) + + testCases := []struct { + name string + function *entity.Function + expectedError string + }{ + { + "invalid_reranker_type", + entity.NewFunction(). + WithName("invalid_type"). + WithType(entity.FunctionTypeRerank). + WithInputFields(). + WithParam("reranker", "invalid_type"), + "Unsupported rerank function", + }, + { + "weighted_invalid_weights", + entity.NewFunction(). + WithName("invalid_weights"). + WithType(entity.FunctionTypeRerank). + WithInputFields(). + WithParam("reranker", "weighted"). + WithParam("weights", "invalid_format"), + "Parse weights param failed", + }, + { + "decay_missing_params", + entity.NewFunction(). + WithName("missing_params"). + WithType(entity.FunctionTypeRerank). + WithInputFields(common.DefaultInt64FieldName). + WithParam("reranker", "decay"), + "Decay function lost param", + }, + { + "model_missing_endpoint", + entity.NewFunction(). + WithName("missing_endpoint"). + WithType(entity.FunctionTypeRerank). + WithInputFields(common.DefaultVarcharFieldName). + WithParam("reranker", "model"). + WithParam("provider", "tei"), + "Rerank function lost params", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + _, err := mc.HybridSearch(ctx, client.NewHybridSearchOption( + schema.CollectionName, common.DefaultLimit, annReq1, annReq2, + ).WithFunctionRerankers(tc.function)) + + common.CheckErr(t, err, false, tc.expectedError) + }) + } +} + +func TestRerankFunctionMissingFields(t *testing.T) { + t.Parallel() + ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout) + mc := hp.CreateDefaultMilvusClient(ctx, t) + + _, schema := createRerankFunctionTestCollection(ctx, t, mc, false) + + queryVec1 := hp.GenSearchVectors(1, common.DefaultDim, entity.FieldTypeFloatVector) + queryVec2 := hp.GenSearchVectors(1, common.DefaultDim, entity.FieldTypeFloat16Vector) + + annReq1 := client.NewAnnRequest(common.DefaultFloatVecFieldName, common.DefaultLimit, queryVec1...) + annReq2 := client.NewAnnRequest(common.DefaultFloat16VecFieldName, common.DefaultLimit, queryVec2...) + + testCases := []struct { + name string + function *entity.Function + }{ + { + "decay_nonexistent_field", + entity.NewFunction(). + WithName("nonexistent_field"). + WithType(entity.FunctionTypeRerank). + WithInputFields("nonexistent_field"). + WithParam("reranker", "decay"). + WithParam("function", "linear"). + WithParam("origin", "1700000000"). + WithParam("scale", "3600"). + WithParam("decay", "0.1"), + }, + { + "model_nonexistent_field", + entity.NewFunction(). + WithName("nonexistent_text"). + WithType(entity.FunctionTypeRerank). + WithInputFields("nonexistent_text"). + WithParam("reranker", "model"). + WithParam("provider", "tei"). + WithParam("queries", []string{"test query"}). + WithParam("endpoint", hp.GetTEIRerankerEndpoint()), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + _, err := mc.HybridSearch(ctx, client.NewHybridSearchOption( + schema.CollectionName, common.DefaultLimit, annReq1, annReq2, + ).WithFunctionRerankers(tc.function)) + + common.CheckErr(t, err, false, "field not found", "nonexistent") + }) + } +} + +func TestRerankFunctionRRF(t *testing.T) { + t.Parallel() + ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout) + mc := hp.CreateDefaultMilvusClient(ctx, t) + + _, schema := createRerankFunctionTestCollection(ctx, t, mc, false) + + queryVec1 := hp.GenSearchVectors(common.DefaultNq, common.DefaultDim, entity.FieldTypeFloatVector) + queryVec2 := hp.GenSearchVectors(common.DefaultNq, common.DefaultDim, entity.FieldTypeFloat16Vector) + + testCases := []struct { + name string + k int + }{ + {"default_k", 60}, + {"small_k", 10}, + {"large_k", 100}, + {"very_large_k", 1000}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + rrfReranker := entity.NewFunction(). + WithName("test_rrf_"+tc.name). + WithType(entity.FunctionTypeRerank). + WithInputFields(). + WithParam("reranker", "rrf"). + WithParam("k", tc.k) + + annReq1 := client.NewAnnRequest(common.DefaultFloatVecFieldName, common.DefaultLimit, queryVec1...) + annReq2 := client.NewAnnRequest(common.DefaultFloat16VecFieldName, common.DefaultLimit, queryVec2...) + + results, err := mc.HybridSearch(ctx, client.NewHybridSearchOption( + schema.CollectionName, common.DefaultLimit, annReq1, annReq2, + ).WithFunctionRerankers(rrfReranker).WithOutputFields("*")) + + common.CheckErr(t, err, true) + validateRerankFunctionResults(t, results, common.DefaultLimit) + }) + } +} + +func TestRerankFunctionDecaySingleVector(t *testing.T) { + t.Parallel() + ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout) + mc := hp.CreateDefaultMilvusClient(ctx, t) + + _, schema := createRerankFunctionTestCollection(ctx, t, mc, false) + + queryVec := hp.GenSearchVectors(common.DefaultNq, common.DefaultDim, entity.FieldTypeFloatVector) + + testCases := []struct { + name string + function string + origin int64 + scale int + decay float64 + }{ + {"linear_decay_single", "linear", defaultTimestamp + 3600, 3600, 0.1}, + {"exp_decay_single", "exp", defaultTimestamp + 7200, 7200, 0.2}, + {"gauss_decay_single", "gauss", defaultTimestamp + 1800, 1800, 0.15}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + decayReranker := entity.NewFunction(). + WithName("test_decay_single_"+tc.name). + WithType(entity.FunctionTypeRerank). + WithInputFields(common.DefaultInt64FieldName). + WithParam("reranker", "decay"). + WithParam("function", tc.function). + WithParam("origin", tc.origin). + WithParam("scale", tc.scale). + WithParam("decay", tc.decay) + + results, err := mc.Search(ctx, client.NewSearchOption( + schema.CollectionName, common.DefaultLimit, queryVec, + ).WithANNSField(common.DefaultFloatVecFieldName). + WithFunctionReranker(decayReranker). + WithOutputFields("*")) + + common.CheckErr(t, err, true) + require.Greater(t, len(results), 0, "Should have search results") + + for _, res := range results { + require.LessOrEqual(t, res.ResultCount, common.DefaultLimit, "Result count should not exceed limit") + require.Greater(t, res.ResultCount, 0, "Should have results for decay reranker") + timestampCol := res.GetColumn(common.DefaultInt64FieldName) + require.NotNil(t, timestampCol, "Should have timestamp field in results") + } + }) + } +} + +func TestRerankFunctionModelSingleVector(t *testing.T) { + ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout*3) + mc := hp.CreateDefaultMilvusClient(ctx, t) + + _, schema := createRerankFunctionTestCollection(ctx, t, mc, true) + + queryVec := hp.GenSearchVectors(common.DefaultNq, common.DefaultDim*2, entity.FieldTypeSparseVector) + queries := generateTestQueries() + + testCases := []struct { + name string + provider string + endpoint string + queries []string + }{ + {"tei_provider_single", "tei", hp.GetTEIRerankerEndpoint(), queries[:common.DefaultNq]}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + modelReranker := entity.NewFunction(). + WithName("test_model_single_"+tc.name). + WithType(entity.FunctionTypeRerank). + WithInputFields(common.DefaultTextFieldName). + WithParam("reranker", "model"). + WithParam("provider", tc.provider). + WithParam("queries", tc.queries). + WithParam("endpoint", tc.endpoint) + + results, err := mc.Search(ctx, client.NewSearchOption( + schema.CollectionName, common.DefaultLimit, queryVec, + ).WithANNSField(common.DefaultTextSparseVecFieldName). + WithFunctionReranker(modelReranker). + WithOutputFields("*")) + + common.CheckErr(t, err, true) + require.Greater(t, len(results), 0, "Should have search results") + + for _, res := range results { + require.LessOrEqual(t, res.ResultCount, common.DefaultLimit, "Result count should not exceed limit") + textCol := res.GetColumn(common.DefaultTextFieldName) + require.NotNil(t, textCol, "Should have text field in results") + } + }) + } +} + +func TestRerankFunctionEmptyResults(t *testing.T) { + t.Parallel() + ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout) + mc := hp.CreateDefaultMilvusClient(ctx, t) + + _, schema := createRerankFunctionTestCollection(ctx, t, mc, false) + + queryVec1 := hp.GenSearchVectors(common.DefaultNq, common.DefaultDim, entity.FieldTypeFloatVector) + queryVec2 := hp.GenSearchVectors(common.DefaultNq, common.DefaultDim, entity.FieldTypeFloat16Vector) + + impossibleFilter := fmt.Sprintf("%s > %d", common.DefaultInt64FieldName, common.DefaultNb*10) + + weightedReranker := entity.NewFunction(). + WithName("test_empty_results"). + WithType(entity.FunctionTypeRerank). + WithInputFields(). + WithParam("reranker", "weighted"). + WithParam("weights", []float64{0.5, 0.5}). + WithParam("norm_score", true) + + annReq1 := client.NewAnnRequest(common.DefaultFloatVecFieldName, common.DefaultLimit, queryVec1...).WithFilter(impossibleFilter) + annReq2 := client.NewAnnRequest(common.DefaultFloat16VecFieldName, common.DefaultLimit, queryVec2...).WithFilter(impossibleFilter) + + results, err := mc.HybridSearch(ctx, client.NewHybridSearchOption( + schema.CollectionName, common.DefaultLimit, annReq1, annReq2, + ).WithFunctionRerankers(weightedReranker)) + + common.CheckErr(t, err, true) + require.Len(t, results, common.DefaultNq) + for _, res := range results { + require.Equal(t, 0, res.ResultCount, "Should have no results with impossible filter") + } +} + +func TestRerankFunctionWeightedNegative(t *testing.T) { + t.Parallel() + ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout) + mc := hp.CreateDefaultMilvusClient(ctx, t) + + _, schema := createRerankFunctionTestCollection(ctx, t, mc, false) + + queryVec1 := hp.GenSearchVectors(1, common.DefaultDim, entity.FieldTypeFloatVector) + queryVec2 := hp.GenSearchVectors(1, common.DefaultDim, entity.FieldTypeFloat16Vector) + + annReq1 := client.NewAnnRequest(common.DefaultFloatVecFieldName, common.DefaultLimit, queryVec1...) + annReq2 := client.NewAnnRequest(common.DefaultFloat16VecFieldName, common.DefaultLimit, queryVec2...) + + testCases := []struct { + name string + weights interface{} + normScore bool + expectedError string + }{ + {"invalid_weights_format", "invalid_format", true, "Parse weights param failed"}, + {"empty_weights", []float64{}, true, "weights not found"}, + {"negative_weights", []float64{-0.5, 0.5}, true, "rank param weight should be in range [0, 1]"}, + {"mismatched_weights_count", []float64{0.3}, true, "the length of weights param mismatch with ann search requests"}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + weightedReranker := entity.NewFunction(). + WithName("test_weighted_negative_"+tc.name). + WithType(entity.FunctionTypeRerank). + WithInputFields(). + WithParam("reranker", "weighted"). + WithParam("weights", tc.weights). + WithParam("norm_score", tc.normScore) + + _, err := mc.HybridSearch(ctx, client.NewHybridSearchOption( + schema.CollectionName, common.DefaultLimit, annReq1, annReq2, + ).WithFunctionRerankers(weightedReranker)) + + common.CheckErr(t, err, false, tc.expectedError) + }) + } +} + +func TestRerankFunctionRRFNegative(t *testing.T) { + t.Parallel() + ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout) + mc := hp.CreateDefaultMilvusClient(ctx, t) + + _, schema := createRerankFunctionTestCollection(ctx, t, mc, false) + + queryVec1 := hp.GenSearchVectors(1, common.DefaultDim, entity.FieldTypeFloatVector) + queryVec2 := hp.GenSearchVectors(1, common.DefaultDim, entity.FieldTypeFloat16Vector) + + annReq1 := client.NewAnnRequest(common.DefaultFloatVecFieldName, common.DefaultLimit, queryVec1...) + annReq2 := client.NewAnnRequest(common.DefaultFloat16VecFieldName, common.DefaultLimit, queryVec2...) + + testCases := []struct { + name string + k interface{} + expectedError string + }{ + {"negative_k", -10, "k should be in range (0, 16384)"}, + {"zero_k", 0, "k should be in range (0, 16384)"}, + {"too_large_k", 20000, "k should be in range (0, 16384)"}, + {"invalid_k_format", "invalid", "is not a number"}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + rrfReranker := entity.NewFunction(). + WithName("test_rrf_negative_"+tc.name). + WithType(entity.FunctionTypeRerank). + WithInputFields(). + WithParam("reranker", "rrf"). + WithParam("k", tc.k) + + _, err := mc.HybridSearch(ctx, client.NewHybridSearchOption( + schema.CollectionName, common.DefaultLimit, annReq1, annReq2, + ).WithFunctionRerankers(rrfReranker)) + + common.CheckErr(t, err, false, tc.expectedError) + }) + } +} + +func TestRerankFunctionDecayNegative(t *testing.T) { + t.Parallel() + ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout) + mc := hp.CreateDefaultMilvusClient(ctx, t) + + _, schema := createRerankFunctionTestCollection(ctx, t, mc, false) + + queryVec1 := hp.GenSearchVectors(1, common.DefaultDim, entity.FieldTypeFloatVector) + queryVec2 := hp.GenSearchVectors(1, common.DefaultDim, entity.FieldTypeFloat16Vector) + + annReq1 := client.NewAnnRequest(common.DefaultFloatVecFieldName, common.DefaultLimit, queryVec1...) + annReq2 := client.NewAnnRequest(common.DefaultFloat16VecFieldName, common.DefaultLimit, queryVec2...) + + testCases := []struct { + name string + function interface{} + origin interface{} + scale interface{} + decay interface{} + expectedError string + }{ + {"invalid_function_type", "invalid", defaultTimestamp, 3600, 0.1, "Invaild decay function"}, + {"negative_scale", "linear", defaultTimestamp, -3600, 0.1, "scale must > 0"}, + {"invalid_origin_format", "linear", "invalid", 3600, 0.1, "is not a number"}, + {"invalid_decay_range", "linear", defaultTimestamp, 3600, 1.5, "decay must 0 < decay < 1"}, + {"zero_decay", "linear", defaultTimestamp, 3600, 0.0, "decay must 0 < decay < 1"}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + decayReranker := entity.NewFunction(). + WithName("test_decay_negative_"+tc.name). + WithType(entity.FunctionTypeRerank). + WithInputFields(common.DefaultInt64FieldName). + WithParam("reranker", "decay"). + WithParam("function", tc.function). + WithParam("origin", tc.origin). + WithParam("scale", tc.scale). + WithParam("decay", tc.decay) + + _, err := mc.HybridSearch(ctx, client.NewHybridSearchOption( + schema.CollectionName, common.DefaultLimit, annReq1, annReq2, + ).WithFunctionRerankers(decayReranker)) + + common.CheckErr(t, err, false, tc.expectedError) + }) + } +} + +func TestRerankFunctionModelNegative(t *testing.T) { + t.Parallel() + ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout) + mc := hp.CreateDefaultMilvusClient(ctx, t) + + _, schema := createRerankFunctionTestCollection(ctx, t, mc, true) + + queryVec1 := hp.GenSearchVectors(common.DefaultNq, common.DefaultDim*2, entity.FieldTypeSparseVector) + queryVec2 := hp.GenSearchVectors(common.DefaultNq, common.DefaultDim, entity.FieldTypeSparseVector) + + queries := generateTestQueries() + + testCases := []struct { + name string + provider string + endpoint string + queries []string + expectedError string + }{ + {"invalid_endpoint", "tei", "http://invalid:8080", queries[:common.DefaultNq], "Call rerank model failed"}, + {"empty_endpoint", "tei", "", queries[:common.DefaultNq], "is not a valid http/https link"}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + modelReranker := entity.NewFunction(). + WithName("test_model_negative_"+tc.name). + WithType(entity.FunctionTypeRerank). + WithInputFields(common.DefaultTextFieldName). + WithParam("reranker", "model"). + WithParam("provider", tc.provider). + WithParam("queries", tc.queries). + WithParam("endpoint", tc.endpoint) + + annReq1 := client.NewAnnRequest(common.DefaultTextSparseVecFieldName, common.DefaultLimit, queryVec1...) + annReq2 := client.NewAnnRequest(common.DefaultTextSparseVecFieldName, common.DefaultLimit, queryVec2...) + + _, err := mc.HybridSearch(ctx, client.NewHybridSearchOption( + schema.CollectionName, common.DefaultLimit, annReq1, annReq2, + ).WithFunctionRerankers(modelReranker).WithOutputFields("*")) + + common.CheckErr(t, err, false, tc.expectedError) + }) + } +}