milvus/internal/util/function/embedding/zilliz_embedding_provider_test.go
junjiejiangjjj d3164e8030
feat: add configurable batch factor and runtime check bypass for embedding functions (#45592)
https://github.com/milvus-io/milvus/issues/45544
- Add batch_factor configuration parameter (default: 5) to control
embedding provider batch sizes
- Add disable_func_runtime_check property to bypass function validation
during collection creation
- Add database interceptor support for AddCollectionFunction,
AlterCollectionFunction, and DropCollectionFunction requests

Signed-off-by: junjie.jiang <junjie.jiang@zilliz.com>
2025-11-20 19:55:04 +08:00

560 lines
16 KiB
Go

/*
* # Licensed to the LF AI & Data foundation under one
* # or more contributor license agreements. See the NOTICE file
* # distributed with this work for additional information
* # regarding copyright ownership. The ASF licenses this file
* # to you under the Apache License, Version 2.0 (the
* # "License"); you may not use this file except in compliance
* # with the License. You may obtain a copy of the License at
* #
* # http://www.apache.org/licenses/LICENSE-2.0
* #
* # Unless required by applicable law or agreed to in writing, software
* # distributed under the License is distributed on an "AS IS" BASIS,
* # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* # See the License for the specific language governing permissions and
* # limitations under the License.
*/
package embedding
// This file contains unit tests for the ZillizEmbeddingProvider.
// Due to the dependency on the ZillizClient which requires gRPC connections,
// these tests focus on testing the logic that can be tested in isolation:
// - Parameter extraction and validation
// - Method behavior (MaxBatch, FieldDim)
// - Batching logic
// - Input type parameter setting
// - Edge cases and constants
import (
"context"
"testing"
"github.com/bytedance/mockey"
"github.com/cockroachdb/errors"
"github.com/stretchr/testify/suite"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/util/function/models"
"github.com/milvus-io/milvus/internal/util/function/models/zilliz"
)
// TestZillizEmbeddingProvider is the `go test` entry point; it hands a fresh
// ZillizEmbeddingProviderSuite to testify's suite runner.
func TestZillizEmbeddingProvider(t *testing.T) {
	suite.Run(t, &ZillizEmbeddingProviderSuite{})
}
// ZillizEmbeddingProviderSuite groups the ZillizEmbeddingProvider tests and
// holds the fixtures that SetupTest assigns.
type ZillizEmbeddingProviderSuite struct {
suite.Suite
// fieldSchema is the FloatVector output field (dim "4") passed to the constructor.
fieldSchema *schemapb.FieldSchema
// functionSchema is the function definition carrying the deployment ID and a custom param.
functionSchema *schemapb.FunctionSchema
// params is the string map passed to the constructor; SetupTest stores only "api_key" in it.
params map[string]string
// extraInfo carries the cluster ID, database name and batch factor used by the tests.
extraInfo *models.ModelExtraInfo
}
// SetupTest assigns fresh fixtures to the suite: a 4-dimensional FloatVector
// output field, a function schema mapping field 101 ("text") to field 102
// ("vector"), an API-key parameter map, and the model extra info.
func (s *ZillizEmbeddingProviderSuite) SetupTest() {
	const (
		inputFieldID  int64 = 101
		outputFieldID int64 = 102
	)

	s.extraInfo = &models.ModelExtraInfo{
		ClusterID:   "test-cluster-id",
		DBName:      "test-db",
		BatchFactor: 5,
	}
	s.params = map[string]string{"api_key": "test-api-key"}

	dimParam := &commonpb.KeyValuePair{Key: "dim", Value: "4"}
	s.fieldSchema = &schemapb.FieldSchema{
		FieldID:    outputFieldID,
		Name:       "vector",
		DataType:   schemapb.DataType_FloatVector,
		TypeParams: []*commonpb.KeyValuePair{dimParam},
	}

	functionParams := []*commonpb.KeyValuePair{
		{Key: models.ModelDeploymentIDKey, Value: "test-deployment-id"},
		{Key: "custom_param", Value: "custom_value"},
	}
	s.functionSchema = &schemapb.FunctionSchema{
		Name:             "test_zilliz_embedding",
		Type:             schemapb.FunctionType_Unknown,
		InputFieldNames:  []string{"text"},
		OutputFieldNames: []string{"vector"},
		InputFieldIds:    []int64{inputFieldID},
		OutputFieldIds:   []int64{outputFieldID},
		Params:           functionParams,
	}
}
// TestParameterExtraction checks that function-schema params split into the
// deployment ID and a map of every remaining key.
//
// NOTE(review): the split is re-implemented locally; it is intended to mirror
// NewZillizEmbeddingProvider but does not call it — confirm they stay in sync.
func (s *ZillizEmbeddingProviderSuite) TestParameterExtraction() {
	schema := &schemapb.FunctionSchema{
		Params: []*commonpb.KeyValuePair{
			{Key: models.ModelDeploymentIDKey, Value: "test-deployment"},
			{Key: "model_param1", Value: "value1"},
			{Key: "model_param2", Value: "value2"},
		},
	}

	deploymentID := ""
	extras := make(map[string]string)
	for _, kv := range schema.Params {
		// The deployment ID key is pulled out; everything else is a model param.
		if kv.Key == models.ModelDeploymentIDKey {
			deploymentID = kv.Value
			continue
		}
		extras[kv.Key] = kv.Value
	}

	s.Equal("test-deployment", deploymentID)
	s.Equal("value1", extras["model_param1"])
	s.Equal("value2", extras["model_param2"])
	s.NotContains(extras, models.ModelDeploymentIDKey)
}
// TestNewZillizEmbeddingProvider_InvalidDimension expects the constructor to
// return an error and a nil provider when the field's "dim" type param is the
// non-numeric string "invalid".
func (s *ZillizEmbeddingProviderSuite) TestNewZillizEmbeddingProvider_InvalidDimension() {
	badDim := []*commonpb.KeyValuePair{{Key: "dim", Value: "invalid"}}
	schema := &schemapb.FieldSchema{
		FieldID:    102,
		Name:       "vector",
		DataType:   schemapb.DataType_FloatVector,
		TypeParams: badDim,
	}

	p, err := NewZillizEmbeddingProvider(schema, s.functionSchema, s.params, s.extraInfo)

	s.Error(err)
	s.Nil(p)
}
// TestMaxBatch expects MaxBatch to report the per-request batch size
// multiplied by the configured batch factor (64 * 5 here).
func (s *ZillizEmbeddingProviderSuite) TestMaxBatch() {
	const (
		perRequest = 64
		factor     = 5
	)
	p := &ZillizEmbeddingProvider{
		maxBatch:  perRequest,
		extraInfo: &models.ModelExtraInfo{BatchFactor: factor},
	}

	s.Equal(factor*perRequest, p.MaxBatch())
}
// TestFieldDim expects FieldDim to return the provider's fieldDim as an int64.
func (s *ZillizEmbeddingProviderSuite) TestFieldDim() {
	p := &ZillizEmbeddingProvider{fieldDim: 128}

	s.Equal(int64(128), p.FieldDim())
}
// TestBatchingLogic walks 25 rows in steps of 10 and expects the half-open
// ranges [0,10), [10,20) and [20,25).
//
// NOTE(review): the loop is a local simulation of the batching in
// CallEmbedding; the provider itself is not exercised here.
func (s *ZillizEmbeddingProviderSuite) TestBatchingLogic() {
	type span struct{ start, end int }
	const (
		total = 25
		step  = 10
	)

	var batches []span
	for lo := 0; lo < total; lo += step {
		hi := lo + step
		if hi > total {
			// The final batch is clamped to the row count.
			hi = total
		}
		batches = append(batches, span{start: lo, end: hi})
	}

	want := []span{{0, 10}, {10, 20}, {20, 25}}
	s.Len(batches, 3)
	for i, w := range want {
		s.Equal(w.start, batches[i].start)
		s.Equal(w.end, batches[i].end)
	}
}
// TestInputTypeParameterSetting writes "query" and then "document" under the
// "input_type" key of modelParams and reads each value back.
//
// NOTE(review): this only exercises map assignment on the struct field; it
// does not go through CallEmbedding.
func (s *ZillizEmbeddingProviderSuite) TestInputTypeParameterSetting() {
	const key = "input_type"
	p := &ZillizEmbeddingProvider{modelParams: map[string]string{}}

	// "query" stands for SearchMode, "document" for InsertMode.
	for _, v := range []string{"query", "document"} {
		p.modelParams[key] = v
		s.Equal(v, p.modelParams[key])
	}
}
// TestDefaultValues records the defaults the constructor is expected to use
// (batch size 64, timeout 30 seconds).
//
// NOTE(review): both assertions compare local literals with themselves, so
// this cannot fail if the constructor's defaults change; it serves as
// documentation only.
func (s *ZillizEmbeddingProviderSuite) TestDefaultValues() {
	var (
		wantMaxBatch   = 64
		wantTimeoutSec = int64(30)
	)

	s.Equal(64, wantMaxBatch)
	s.Equal(int64(30), wantTimeoutSec)
}
// TestEdgeCases counts how many iterations the batching loop makes for zero
// rows, exactly one full batch, and one row past a full batch.
func (s *ZillizEmbeddingProviderSuite) TestEdgeCases() {
	// countBatches replays the stride loop and returns its iteration count.
	countBatches := func(numRows, maxBatch int) int {
		n := 0
		for off := 0; off < numRows; off += maxBatch {
			n++
		}
		return n
	}

	cases := []struct{ numRows, maxBatch, want int }{
		{0, 10, 0},  // no texts
		{10, 10, 1}, // exactly one batch
		{11, 10, 2}, // one more than the batch size
	}
	for _, c := range cases {
		s.Equal(c.want, countBatches(c.numRows, c.maxBatch))
	}
}
// TestModeConstants pins the numeric values of the embedding modes:
// InsertMode is 0 and SearchMode is 1.
func (s *ZillizEmbeddingProviderSuite) TestModeConstants() {
	insertWant, searchWant := models.TextEmbeddingMode(0), models.TextEmbeddingMode(1)

	s.Equal(insertWant, models.InsertMode)
	s.Equal(searchWant, models.SearchMode)
}
// TestConstantValues checks that the suite's extraInfo fixture carries the
// expected cluster ID and database name.
func (s *ZillizEmbeddingProviderSuite) TestConstantValues() {
	info := s.extraInfo

	s.Equal("test-cluster-id", info.ClusterID)
	s.Equal("test-db", info.DBName)
}
// TestNewZillizEmbeddingProviderWithInvalidFieldSchema expects the constructor
// to return an error for the schema built below.
//
// NOTE(review): despite the name, this schema has the same values as the
// SetupTest fixture (FloatVector, dim "4"). Presumably the error comes from
// another step of the constructor (e.g. client creation) — confirm the cause
// and rename the test if so.
func (s *ZillizEmbeddingProviderSuite) TestNewZillizEmbeddingProviderWithInvalidFieldSchema() {
	typeParams := []*commonpb.KeyValuePair{{Key: "dim", Value: "4"}}
	schema := &schemapb.FieldSchema{
		FieldID:    102,
		Name:       "vector",
		DataType:   schemapb.DataType_FloatVector,
		TypeParams: typeParams,
	}

	_, err := NewZillizEmbeddingProvider(schema, s.functionSchema, s.params, s.extraInfo)

	s.Error(err)
}
// CallEmbedding tests
// TestCallEmbedding_SearchMode verifies that CallEmbedding in SearchMode passes
// input_type "query" to the client and returns the client's vectors.
//
// Preconditions (no error, correct result type, correct length) use
// s.Require() so that a wrong result stops the test with a clear failure
// instead of panicking on the slice indexing that follows.
func (s *ZillizEmbeddingProviderSuite) TestCallEmbedding_SearchMode() {
	provider := &ZillizEmbeddingProvider{
		client:      &zilliz.ZillizClient{},
		fieldDim:    4,
		maxBatch:    10,
		modelParams: make(map[string]string),
		extraInfo:   s.extraInfo,
	}

	// Patch the client method so the test needs no real gRPC connection.
	mock := mockey.Mock((*zilliz.ZillizClient).Embedding).To(func(_ *zilliz.ZillizClient, _ context.Context, batch []string, params map[string]string) ([][]float32, error) {
		// SearchMode must be announced to the service as "query".
		s.Equal("query", params["input_type"])

		vectors := make([][]float32, len(batch))
		for i := range batch {
			vectors[i] = []float32{1.0, 2.0, 3.0, 4.0}
		}
		return vectors, nil
	}).Build()
	defer mock.UnPatch()

	result, err := provider.CallEmbedding(context.Background(), []string{"hello", "world"}, models.SearchMode)
	s.Require().NoError(err)
	s.NotNil(result)

	embeddings, ok := result.([][]float32)
	s.Require().True(ok)
	s.Require().Len(embeddings, 2)
	s.Equal([]float32{1.0, 2.0, 3.0, 4.0}, embeddings[0])
	s.Equal([]float32{1.0, 2.0, 3.0, 4.0}, embeddings[1])

	// The provider's own param map must carry the mode as well.
	s.Equal("query", provider.modelParams["input_type"])
}
// TestCallEmbedding_InsertMode verifies that CallEmbedding in InsertMode passes
// input_type "document" to the client and returns the client's vectors.
//
// Preconditions (no error, correct result type, correct length) use
// s.Require() so that a wrong result stops the test with a clear failure
// instead of panicking on the slice indexing that follows.
func (s *ZillizEmbeddingProviderSuite) TestCallEmbedding_InsertMode() {
	provider := &ZillizEmbeddingProvider{
		client:      &zilliz.ZillizClient{},
		fieldDim:    4,
		maxBatch:    10,
		modelParams: make(map[string]string),
		extraInfo:   s.extraInfo,
	}

	// Patch the client method so the test needs no real gRPC connection.
	mock := mockey.Mock((*zilliz.ZillizClient).Embedding).To(func(_ *zilliz.ZillizClient, _ context.Context, batch []string, params map[string]string) ([][]float32, error) {
		// InsertMode must be announced to the service as "document".
		s.Equal("document", params["input_type"])

		vectors := make([][]float32, len(batch))
		for i := range batch {
			vectors[i] = []float32{2.0, 3.0, 4.0, 5.0}
		}
		return vectors, nil
	}).Build()
	defer mock.UnPatch()

	result, err := provider.CallEmbedding(context.Background(), []string{"document1", "document2"}, models.InsertMode)
	s.Require().NoError(err)
	s.NotNil(result)

	embeddings, ok := result.([][]float32)
	s.Require().True(ok)
	s.Require().Len(embeddings, 2)
	s.Equal([]float32{2.0, 3.0, 4.0, 5.0}, embeddings[0])
	s.Equal([]float32{2.0, 3.0, 4.0, 5.0}, embeddings[1])

	// The provider's own param map must carry the mode as well.
	s.Equal("document", provider.modelParams["input_type"])
}
// TestCallEmbedding_Batching verifies that five texts with maxBatch 3 reach
// the client as two calls (3 texts, then 2) and that the returned vectors are
// concatenated in call order.
//
// Preconditions (no error, correct result type, correct length) use
// s.Require() so that a short or mistyped result stops the test with a clear
// failure instead of panicking on the slice indexing that follows.
func (s *ZillizEmbeddingProviderSuite) TestCallEmbedding_Batching() {
	provider := &ZillizEmbeddingProvider{
		client:      &zilliz.ZillizClient{},
		maxBatch:    3, // small on purpose, to force more than one client call
		modelParams: make(map[string]string),
		extraInfo:   s.extraInfo,
	}
	inputs := []string{"text1", "text2", "text3", "text4", "text5"}

	callCount := 0
	mock := mockey.Mock((*zilliz.ZillizClient).Embedding).To(func(_ *zilliz.ZillizClient, _ context.Context, batch []string, _ map[string]string) ([][]float32, error) {
		callCount++
		switch callCount {
		case 1:
			s.Len(batch, 3)
			s.Equal([]string{"text1", "text2", "text3"}, batch)
		case 2:
			s.Len(batch, 2)
			s.Equal([]string{"text4", "text5"}, batch)
		}

		// Encode (call number, index within batch) so ordering can be checked.
		vectors := make([][]float32, len(batch))
		for i := range batch {
			vectors[i] = []float32{float32(callCount), float32(i), 0.0, 0.0}
		}
		return vectors, nil
	}).Build()
	defer mock.UnPatch()

	result, err := provider.CallEmbedding(context.Background(), inputs, models.InsertMode)
	s.Require().NoError(err)
	s.NotNil(result)

	// Two client calls means the input was split into two batches.
	s.Equal(2, callCount)

	embeddings, ok := result.([][]float32)
	s.Require().True(ok)
	s.Require().Len(embeddings, 5)

	// Vectors from the first batch.
	s.Equal([]float32{1.0, 0.0, 0.0, 0.0}, embeddings[0])
	s.Equal([]float32{1.0, 1.0, 0.0, 0.0}, embeddings[1])
	s.Equal([]float32{1.0, 2.0, 0.0, 0.0}, embeddings[2])
	// Vectors from the second batch.
	s.Equal([]float32{2.0, 0.0, 0.0, 0.0}, embeddings[3])
	s.Equal([]float32{2.0, 1.0, 0.0, 0.0}, embeddings[4])
}
// TestCallEmbedding_Error verifies that a client failure is surfaced by
// CallEmbedding together with a nil result.
//
// The returned error is matched with ErrorIs rather than Equal, so the
// assertion states the actual contract (the client error is in the returned
// error's chain) and does not depend on structural equality of error values.
func (s *ZillizEmbeddingProviderSuite) TestCallEmbedding_Error() {
	provider := &ZillizEmbeddingProvider{
		client:      &zilliz.ZillizClient{},
		fieldDim:    4,
		maxBatch:    10,
		modelParams: make(map[string]string),
		extraInfo:   s.extraInfo,
	}

	expectedError := errors.New("embedding service error")
	// Patch the client method so every call fails with expectedError.
	mock := mockey.Mock((*zilliz.ZillizClient).Embedding).To(func(_ *zilliz.ZillizClient, _ context.Context, _ []string, _ map[string]string) ([][]float32, error) {
		return nil, expectedError
	}).Build()
	defer mock.UnPatch()

	result, err := provider.CallEmbedding(context.Background(), []string{"hello", "world"}, models.SearchMode)

	s.Error(err)
	s.Nil(result)
	s.ErrorIs(err, expectedError)
}
// TestCallEmbedding_EmptyTexts verifies that an empty input slice produces a
// non-nil, empty [][]float32 result without calling the client at all.
func (s *ZillizEmbeddingProviderSuite) TestCallEmbedding_EmptyTexts() {
	provider := &ZillizEmbeddingProvider{
		client:      &zilliz.ZillizClient{},
		fieldDim:    4,
		maxBatch:    10,
		modelParams: make(map[string]string),
		extraInfo:   s.extraInfo,
	}
	noTexts := []string{}

	// Count invocations of the patched client method.
	invocations := 0
	patch := mockey.Mock((*zilliz.ZillizClient).Embedding).To(func(_ *zilliz.ZillizClient, _ context.Context, _ []string, _ map[string]string) ([][]float32, error) {
		invocations++
		return [][]float32{}, nil
	}).Build()
	defer patch.UnPatch()

	out, err := provider.CallEmbedding(context.Background(), noTexts, models.InsertMode)
	s.NoError(err)
	s.NotNil(out)

	// Nothing to embed, so the client must not have been reached.
	s.Equal(0, invocations)

	vectors, isFloatMatrix := out.([][]float32)
	s.True(isFloatMatrix)
	s.Len(vectors, 0)
}
// TestCallEmbedding_SingleBatch verifies that an input whose length equals
// maxBatch (5) is sent to the client in exactly one call, with input_type
// "query" for SearchMode.
//
// Preconditions (no error, correct result type, correct length) use
// s.Require() so that a wrong result stops the test with a clear failure
// instead of panicking while the vectors are inspected.
func (s *ZillizEmbeddingProviderSuite) TestCallEmbedding_SingleBatch() {
	provider := &ZillizEmbeddingProvider{
		client:      &zilliz.ZillizClient{},
		fieldDim:    4,
		maxBatch:    5,
		modelParams: make(map[string]string),
		extraInfo:   s.extraInfo,
	}
	inputs := []string{"text1", "text2", "text3", "text4", "text5"} // exactly maxBatch texts

	callCount := 0
	mock := mockey.Mock((*zilliz.ZillizClient).Embedding).To(func(_ *zilliz.ZillizClient, _ context.Context, batch []string, params map[string]string) ([][]float32, error) {
		callCount++
		s.Len(batch, 5)
		s.Equal("query", params["input_type"])

		// The first component carries the row index so ordering can be checked.
		vectors := make([][]float32, len(batch))
		for i := range batch {
			vectors[i] = []float32{float32(i), 1.0, 2.0, 3.0}
		}
		return vectors, nil
	}).Build()
	defer mock.UnPatch()

	result, err := provider.CallEmbedding(context.Background(), inputs, models.SearchMode)
	s.Require().NoError(err)
	s.NotNil(result)

	// One full batch means exactly one client call.
	s.Equal(1, callCount)

	embeddings, ok := result.([][]float32)
	s.Require().True(ok)
	s.Require().Len(embeddings, 5)
	for i, got := range embeddings {
		s.Equal([]float32{float32(i), 1.0, 2.0, 3.0}, got)
	}
}
// TestCallEmbedding_ModelParamsPreservation verifies that model params already
// present on the provider reach the client unchanged, with "input_type" added
// alongside them, and that the provider's map holds all three afterwards.
func (s *ZillizEmbeddingProviderSuite) TestCallEmbedding_ModelParamsPreservation() {
	preset := map[string]string{
		"existing_param": "existing_value",
		"another_param":  "another_value",
	}
	provider := &ZillizEmbeddingProvider{
		client:      &zilliz.ZillizClient{},
		fieldDim:    4,
		maxBatch:    10,
		modelParams: preset,
		extraInfo:   s.extraInfo,
	}

	patch := mockey.Mock((*zilliz.ZillizClient).Embedding).To(func(_ *zilliz.ZillizClient, _ context.Context, _ []string, sent map[string]string) ([][]float32, error) {
		// The two preset entries plus input_type: three keys in total.
		s.Equal("existing_value", sent["existing_param"])
		s.Equal("another_value", sent["another_param"])
		s.Equal("query", sent["input_type"])
		s.Len(sent, 3)
		return [][]float32{{1.0, 2.0, 3.0, 4.0}}, nil
	}).Build()
	defer patch.UnPatch()

	out, err := provider.CallEmbedding(context.Background(), []string{"test"}, models.SearchMode)
	s.NoError(err)
	s.NotNil(out)

	// The provider's own map was updated with the mode and kept the rest.
	s.Equal("query", provider.modelParams["input_type"])
	s.Equal("existing_value", provider.modelParams["existing_param"])
	s.Equal("another_value", provider.modelParams["another_param"])
}