From fe81c7baaefa577d40bc2a48253d21b2326b72a2 Mon Sep 17 00:00:00 2001 From: junjiejiangjjj Date: Tue, 25 Mar 2025 10:06:24 +0800 Subject: [PATCH] feat: Add function config (#40534) #35856 1. Add function-related configuration in milvus.yaml 2. Add null and empty value check to TextEmbeddingFunction Signed-off-by: junjie.jiang --- cmd/tools/config/generate.go | 5 + configs/milvus.yaml | 30 +++++ .../util/function/ali_embedding_provider.go | 8 +- .../alitext_embedding_provider_test.go | 10 +- .../function/bedrock_embedding_provider.go | 57 +++++++--- .../bedrock_text_embedding_provider_test.go | 13 ++- .../function/cohere_embedding_provider.go | 8 +- .../cohere_embedding_provider_test.go | 19 +--- internal/util/function/common.go | 50 +++++++-- .../util/function/function_executor_test.go | 5 + .../function/models/utils/embedding_util.go | 4 +- .../function/openai_embedding_provider.go | 31 +++--- .../openai_text_embedding_provider_test.go | 20 +--- .../siliconflow_embedding_provider.go | 8 +- .../siliconflow_embedding_provider_test.go | 10 +- .../util/function/tei_embedding_provider.go | 2 +- .../function/tei_embedding_provider_test.go | 14 +-- .../util/function/text_embedding_function.go | 44 ++++++-- .../function/text_embedding_function_test.go | 66 ++++++++++- .../function/vertexai_embedding_provider.go | 57 ++++++---- .../vertexai_embedding_provider_test.go | 14 +-- .../function/voyageai_embedding_provider.go | 8 +- .../voyageai_embedding_provider_test.go | 16 +-- pkg/util/paramtable/component_param.go | 2 + pkg/util/paramtable/function_param.go | 105 ++++++++++++++++++ pkg/util/paramtable/function_param_test.go | 65 +++++++++++ .../test_text_embedding_function_e2e.py | 3 + 27 files changed, 498 insertions(+), 176 deletions(-) create mode 100644 pkg/util/paramtable/function_param.go create mode 100644 pkg/util/paramtable/function_param_test.go diff --git a/cmd/tools/config/generate.go b/cmd/tools/config/generate.go index 7b1f176fd8..0e012347bb 100644 --- 
a/cmd/tools/config/generate.go +++ b/cmd/tools/config/generate.go @@ -351,6 +351,11 @@ func WriteYaml(w io.Writer) { header: ` # Any configuration related to the knowhere vector search engine`, }, + { + name: "function", + header: ` +# Any configuration related to functions`, + }, } marshller := YamlMarshaller{w, groups, result} marshller.writeYamlRecursive(lo.Filter(result, func(d DocContent, _ int) bool { diff --git a/configs/milvus.yaml b/configs/milvus.yaml index 76042ab3d7..729dd5385f 100644 --- a/configs/milvus.yaml +++ b/configs/milvus.yaml @@ -1157,3 +1157,33 @@ knowhere: search_list_size: 100 # Size of the candidate list during building graph search: beam_width_ratio: 4 # Ratio between the maximum number of IO requests per search iteration and CPU number + +# Any configuration related to functions +function: + textEmbedding: + enableVerifiInfoInParams: true # Controls whether to allow configuration of apikey and model service url on function parameters + providers: + azure_openai: + api_key: # Your azure openai api key + resource_name: # Your azure openai resource name + url: # Your azure openai embedding url, Default is the official embedding url + bedrock: + aws_access_key_id: # Your aws_access_key_id + aws_secret_access_key: # Your aws_secret_access_key + cohere: + api_key: # Your cohere api key + url: # Your cohere embedding url, Default is the official embedding url + dashscope: + api_key: # Your dashscope api key + url: # Your dashscope embedding url, Default is the official embedding url + openai: + api_key: # Your openai api key + url: # Your openai embedding url, Default is the official embedding url + siliconflow: + api_key: # Your siliconflow api key + url: # Your siliconflow embedding url, Default is the official embedding url + tei: + enable: true # Whether to enable TEI model service + vertexai: + credentials_file_path: # Path to your google application credentials, change the file path to refresh the configuration + url: # Your VertexAI 
embedding url diff --git a/internal/util/function/ali_embedding_provider.go b/internal/util/function/ali_embedding_provider.go index 7866ad795f..73c119a665 100644 --- a/internal/util/function/ali_embedding_provider.go +++ b/internal/util/function/ali_embedding_provider.go @@ -20,7 +20,6 @@ package function import ( "fmt" - "os" "strings" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" @@ -41,9 +40,6 @@ type AliEmbeddingProvider struct { } func createAliEmbeddingClient(apiKey string, url string) (*ali.AliDashScopeEmbedding, error) { - if apiKey == "" { - apiKey = os.Getenv(dashscopeAKEnvStr) - } if apiKey == "" { return nil, fmt.Errorf("Missing credentials. Please pass `api_key`, or configure the %s environment variable in the Milvus service.", dashscopeAKEnvStr) } @@ -55,12 +51,12 @@ func createAliEmbeddingClient(apiKey string, url string) (*ali.AliDashScopeEmbed return c, nil } -func NewAliDashScopeEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema) (*AliEmbeddingProvider, error) { +func NewAliDashScopeEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema, params map[string]string) (*AliEmbeddingProvider, error) { fieldDim, err := typeutil.GetDim(fieldSchema) if err != nil { return nil, err } - apiKey, url := parseAKAndURL(functionSchema.Params) + apiKey, url := parseAKAndURL(functionSchema.Params, params, dashscopeAKEnvStr) var modelName string var dim int64 diff --git a/internal/util/function/alitext_embedding_provider_test.go b/internal/util/function/alitext_embedding_provider_test.go index 368520c719..fff23bccb0 100644 --- a/internal/util/function/alitext_embedding_provider_test.go +++ b/internal/util/function/alitext_embedding_provider_test.go @@ -23,7 +23,6 @@ import ( "fmt" "net/http" "net/http/httptest" - "os" "testing" "github.com/stretchr/testify/suite" @@ -77,7 +76,7 @@ func createAliProvider(url string, schema *schemapb.FieldSchema, providerName st } switch 
providerName { case aliDashScopeProvider: - return NewAliDashScopeEmbeddingProvider(schema, functionSchema) + return NewAliDashScopeEmbeddingProvider(schema, functionSchema, map[string]string{}) default: return nil, fmt.Errorf("Unknow provider") } @@ -170,11 +169,6 @@ func (s *AliTextEmbeddingProviderSuite) TestEmbeddingNumberNotMatch() { func (s *AliTextEmbeddingProviderSuite) TestCreateAliEmbeddingClient() { _, err := createAliEmbeddingClient("", "") s.Error(err) - - os.Setenv(dashscopeAKEnvStr, "mock_key") - defer os.Unsetenv(dashscopeAKEnvStr) - _, err = createAliEmbeddingClient("", "") - s.NoError(err) } func (s *AliTextEmbeddingProviderSuite) TestNewAliDashScopeEmbeddingProvider() { @@ -193,6 +187,6 @@ func (s *AliTextEmbeddingProviderSuite) TestNewAliDashScopeEmbeddingProvider() { } // invalid dim functionSchema.Params[2] = &commonpb.KeyValuePair{Key: dimParamKey, Value: "Invalid"} - _, err := NewAliDashScopeEmbeddingProvider(s.schema.Fields[2], functionSchema) + _, err := NewAliDashScopeEmbeddingProvider(s.schema.Fields[2], functionSchema, map[string]string{}) s.Error(err) } diff --git a/internal/util/function/bedrock_embedding_provider.go b/internal/util/function/bedrock_embedding_provider.go index 030eb8ad96..a21b0b6f4a 100644 --- a/internal/util/function/bedrock_embedding_provider.go +++ b/internal/util/function/bedrock_embedding_provider.go @@ -30,6 +30,7 @@ import ( "github.com/aws/aws-sdk-go-v2/credentials" "github.com/aws/aws-sdk-go-v2/service/bedrockruntime" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/pkg/v2/util/typeutil" ) @@ -51,16 +52,9 @@ type BedrockEmbeddingProvider struct { } func createBedRockEmbeddingClient(awsAccessKeyId string, awsSecretAccessKey string, region string) (*bedrockruntime.Client, error) { - if awsAccessKeyId == "" { - awsAccessKeyId = os.Getenv(bedrockAccessKeyId) - } if awsAccessKeyId == "" { return nil, fmt.Errorf("Missing 
credentials. Please pass `aws_access_key_id`, or configure the %s environment variable in the Milvus service.", bedrockAccessKeyId) } - - if awsSecretAccessKey == "" { - awsSecretAccessKey = os.Getenv(bedrockSAKEnvStr) - } if awsSecretAccessKey == "" { return nil, fmt.Errorf("Missing credentials. Please pass `aws_secret_access_key`, or configure the %s environment variable in the Milvus service.", bedrockSAKEnvStr) } @@ -79,12 +73,47 @@ func createBedRockEmbeddingClient(awsAccessKeyId string, awsSecretAccessKey stri return bedrockruntime.NewFromConfig(cfg), nil } -func NewBedrockEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema, c BedrockClient) (*BedrockEmbeddingProvider, error) { +func parseAccessInfo(params []*commonpb.KeyValuePair, confParams map[string]string) (string, string) { + // function param > yaml > env + var awsAccessKeyId, awsSecretAccessKey string + + // from function params + if isEnableVerifiInfoInParamsKey(confParams) { + for _, param := range params { + switch strings.ToLower(param.Key) { + case awsAKIdParamKey: + awsAccessKeyId = param.Value + case awsSAKParamKey: + awsSecretAccessKey = param.Value + } + } + } + + // from milvus.yaml + if awsAccessKeyId == "" { + awsAccessKeyId = confParams[awsAKIdParamKey] + } + if awsSecretAccessKey == "" { + awsSecretAccessKey = confParams[awsSAKParamKey] + } + + // from env + if awsAccessKeyId == "" { + awsAccessKeyId = os.Getenv(bedrockAccessKeyId) + } + if awsSecretAccessKey == "" { + awsSecretAccessKey = os.Getenv(bedrockSAKEnvStr) + } + + return awsAccessKeyId, awsSecretAccessKey +} + +func NewBedrockEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema, c BedrockClient, params map[string]string) (*BedrockEmbeddingProvider, error) { fieldDim, err := typeutil.GetDim(fieldSchema) if err != nil { return nil, err } - var awsAccessKeyId, awsSecretAccessKey, region, modelName string + var region, modelName string var dim int64 
normalize := true @@ -97,14 +126,6 @@ func NewBedrockEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSche if err != nil { return nil, err } - case awsAKIdParamKey: - if strings.ToLower(os.Getenv(enableConfigAKAndURL)) != "false" { - awsAccessKeyId = param.Value - } - case awsSAKParamKey: - if strings.ToLower(os.Getenv(enableConfigAKAndURL)) != "false" { - awsSecretAccessKey = param.Value - } case regionParamKey: region = param.Value case normalizeParamKey: @@ -120,6 +141,8 @@ func NewBedrockEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSche } } + awsAccessKeyId, awsSecretAccessKey := parseAccessInfo(functionSchema.Params, params) + var client BedrockClient if c == nil { client, err = createBedRockEmbeddingClient(awsAccessKeyId, awsSecretAccessKey, region) diff --git a/internal/util/function/bedrock_text_embedding_provider_test.go b/internal/util/function/bedrock_text_embedding_provider_test.go index a7478673a1..6f381d554b 100644 --- a/internal/util/function/bedrock_text_embedding_provider_test.go +++ b/internal/util/function/bedrock_text_embedding_provider_test.go @@ -71,7 +71,7 @@ func createBedrockProvider(schema *schemapb.FieldSchema, providerName string, di } switch providerName { case bedrockProvider: - return NewBedrockEmbeddingProvider(schema, functionSchema, &MockBedrockClient{dim: dim}) + return NewBedrockEmbeddingProvider(schema, functionSchema, &MockBedrockClient{dim: dim}, map[string]string{}) default: return nil, fmt.Errorf("Unknow provider") } @@ -151,22 +151,25 @@ func (s *BedrockTextEmbeddingProviderSuite) TestNewBedrockEmbeddingProvider() { {Key: normalizeParamKey, Value: "false"}, }, } - provider, err := NewBedrockEmbeddingProvider(fieldSchema, functionSchema, nil) + provider, err := NewBedrockEmbeddingProvider(fieldSchema, functionSchema, nil, map[string]string{}) s.NoError(err) s.True(provider.MaxBatch() > 0) s.Equal(provider.FieldDim(), int64(4)) + _, err = NewBedrockEmbeddingProvider(fieldSchema, functionSchema, 
nil, map[string]string{awsAKIdParamKey: "mock", awsSAKParamKey: "mock"}) + s.NoError(err) + functionSchema.Params[5] = &commonpb.KeyValuePair{Key: normalizeParamKey, Value: "true"} - _, err = NewBedrockEmbeddingProvider(fieldSchema, functionSchema, nil) + _, err = NewBedrockEmbeddingProvider(fieldSchema, functionSchema, nil, map[string]string{}) s.NoError(err) functionSchema.Params[5] = &commonpb.KeyValuePair{Key: normalizeParamKey, Value: "invalid"} - _, err = NewBedrockEmbeddingProvider(fieldSchema, functionSchema, nil) + _, err = NewBedrockEmbeddingProvider(fieldSchema, functionSchema, nil, map[string]string{}) s.Error(err) // invalid dim functionSchema.Params[0] = &commonpb.KeyValuePair{Key: modelNameParamKey, Value: TestModel} functionSchema.Params[0] = &commonpb.KeyValuePair{Key: dimParamKey, Value: "Invalid"} - _, err = NewBedrockEmbeddingProvider(fieldSchema, functionSchema, nil) + _, err = NewBedrockEmbeddingProvider(fieldSchema, functionSchema, nil, map[string]string{}) s.Error(err) } diff --git a/internal/util/function/cohere_embedding_provider.go b/internal/util/function/cohere_embedding_provider.go index fd34923b22..e59b4ea174 100644 --- a/internal/util/function/cohere_embedding_provider.go +++ b/internal/util/function/cohere_embedding_provider.go @@ -20,7 +20,6 @@ package function import ( "fmt" - "os" "strings" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" @@ -42,9 +41,6 @@ type CohereEmbeddingProvider struct { } func createCohereEmbeddingClient(apiKey string, url string) (*cohere.CohereEmbedding, error) { - if apiKey == "" { - apiKey = os.Getenv(cohereAIAKEnvStr) - } if apiKey == "" { return nil, fmt.Errorf("Missing credentials. 
Please pass `api_key`, or configure the %s environment variable in the Milvus service.", cohereAIAKEnvStr) } @@ -57,12 +53,12 @@ func createCohereEmbeddingClient(apiKey string, url string) (*cohere.CohereEmbed return c, nil } -func NewCohereEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema) (*CohereEmbeddingProvider, error) { +func NewCohereEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema, params map[string]string) (*CohereEmbeddingProvider, error) { fieldDim, err := typeutil.GetDim(fieldSchema) if err != nil { return nil, err } - apiKey, url := parseAKAndURL(functionSchema.Params) + apiKey, url := parseAKAndURL(functionSchema.Params, params, cohereAIAKEnvStr) var modelName string truncate := "END" for _, param := range functionSchema.Params { diff --git a/internal/util/function/cohere_embedding_provider_test.go b/internal/util/function/cohere_embedding_provider_test.go index cfc5f10bd5..50f2e9f64f 100644 --- a/internal/util/function/cohere_embedding_provider_test.go +++ b/internal/util/function/cohere_embedding_provider_test.go @@ -23,7 +23,6 @@ import ( "fmt" "net/http" "net/http/httptest" - "os" "testing" "github.com/stretchr/testify/suite" @@ -76,7 +75,7 @@ func createCohereProvider(url string, schema *schemapb.FieldSchema, providerName } switch providerName { case cohereProvider: - return NewCohereEmbeddingProvider(schema, functionSchema) + return NewCohereEmbeddingProvider(schema, functionSchema, map[string]string{}) default: return nil, fmt.Errorf("Unknow provider") } @@ -264,18 +263,18 @@ func (s *CohereTextEmbeddingProviderSuite) TestNewCohereProvider() { }, } - provider, err := NewCohereEmbeddingProvider(s.schema.Fields[2], functionSchema) + provider, err := NewCohereEmbeddingProvider(s.schema.Fields[2], functionSchema, map[string]string{}) s.NoError(err) s.Equal(provider.truncate, "END") functionSchema.Params = append(functionSchema.Params, &commonpb.KeyValuePair{Key: 
truncateParamKey, Value: "START"}) - provider, err = NewCohereEmbeddingProvider(s.schema.Fields[2], functionSchema) + provider, err = NewCohereEmbeddingProvider(s.schema.Fields[2], functionSchema, map[string]string{}) s.NoError(err) s.Equal(provider.truncate, "START") // Invalid truncateParam functionSchema.Params[2].Value = "Unknow" - _, err = NewCohereEmbeddingProvider(s.schema.Fields[2], functionSchema) + _, err = NewCohereEmbeddingProvider(s.schema.Fields[2], functionSchema, map[string]string{}) s.Error(err) } @@ -293,13 +292,13 @@ func (s *CohereTextEmbeddingProviderSuite) TestGetInputType() { }, } - provider, err := NewCohereEmbeddingProvider(s.schema.Fields[2], functionSchema) + provider, err := NewCohereEmbeddingProvider(s.schema.Fields[2], functionSchema, map[string]string{}) s.NoError(err) s.Equal(provider.getInputType(InsertMode), "") s.Equal(provider.getInputType(SearchMode), "") functionSchema.Params[0].Value = "model-v3.0" - provider, err = NewCohereEmbeddingProvider(s.schema.Fields[2], functionSchema) + provider, err = NewCohereEmbeddingProvider(s.schema.Fields[2], functionSchema, map[string]string{}) s.NoError(err) s.Equal(provider.getInputType(InsertMode), "search_document") s.Equal(provider.getInputType(SearchMode), "search_query") @@ -308,12 +307,6 @@ func (s *CohereTextEmbeddingProviderSuite) TestGetInputType() { func (s *CohereTextEmbeddingProviderSuite) TestCreateCohereEmbeddingClient() { _, err := createCohereEmbeddingClient("", "") s.Error(err) - - os.Setenv(cohereAIAKEnvStr, "mockKey") - defer os.Unsetenv(openaiAKEnvStr) - - _, err = createCohereEmbeddingClient("", "") - s.NoError(err) } func (s *CohereTextEmbeddingProviderSuite) TestRuntimeDimNotMatch() { diff --git a/internal/util/function/common.go b/internal/util/function/common.go index fd34a0d8a8..b9b80e13bc 100644 --- a/internal/util/function/common.go +++ b/internal/util/function/common.go @@ -45,11 +45,12 @@ const ( // common params const ( - modelNameParamKey string = "model_name" 
- dimParamKey string = "dim" - embeddingURLParamKey string = "url" - apiKeyParamKey string = "api_key" - truncateParamKey string = "truncate" + modelNameParamKey string = "model_name" + dimParamKey string = "dim" + embeddingURLParamKey string = "url" + apiKeyParamKey string = "api_key" + truncateParamKey string = "truncate" + enableVerifiInfoInParamsKey string = "enableVerifiInfoInParams" ) // ali text embedding @@ -73,7 +74,7 @@ const ( const ( awsAKIdParamKey string = "aws_access_key_id" awsSAKParamKey string = "aws_secret_access_key" - regionParamKey string = "regin" + regionParamKey string = "region" normalizeParamKey string = "normalize" bedrockAccessKeyId string = "MILVUSAI_BEDROCK_ACCESS_KEY_ID" @@ -120,11 +121,28 @@ const ( enableTeiEnvStr string = "MILVUSAI_ENABLE_TEI" ) -const enableConfigAKAndURL string = "ENABLE_CONFIG_AK_AND_URL" +const enableVerifiInfoInParams string = "ENABLE_VERIFI_INFO_IN_PARAMS" -func parseAKAndURL(params []*commonpb.KeyValuePair) (string, string) { +func isEnableVerifiInfoInParamsKey(confParams map[string]string) bool { + enable := true + if strings.ToLower(confParams[enableVerifiInfoInParamsKey]) != "" { + // If enableVerifiInfoInParamsKey is configured in milvus.yaml, the configuration in milvus.yaml will be used. + enable, _ = strconv.ParseBool(confParams[enableVerifiInfoInParamsKey]) + } else { + // If enableVerifiInfoInParamsKey is not configured in milvus.yaml, the configuration in env will be used. 
+ if strings.ToLower(os.Getenv(enableVerifiInfoInParams)) != "" { + enable, _ = strconv.ParseBool(os.Getenv(enableVerifiInfoInParams)) + } + } + return enable +} + +func parseAKAndURL(params []*commonpb.KeyValuePair, confParams map[string]string, apiKeyEnv string) (string, string) { + // function param > yaml > env var apiKey, url string - if strings.ToLower(os.Getenv(enableConfigAKAndURL)) != "false" { + + // from function params + if isEnableVerifiInfoInParamsKey(confParams) { for _, param := range params { switch strings.ToLower(param.Key) { case apiKeyParamKey: @@ -134,6 +152,20 @@ func parseAKAndURL(params []*commonpb.KeyValuePair) (string, string) { } } } + + // from milvus.yaml + if apiKey == "" { + apiKey = confParams[apiKeyParamKey] + } + + if url == "" { + url = confParams[embeddingURLParamKey] + } + + // from env, url doesn't support configuration in env + if apiKey == "" { + apiKey = os.Getenv(apiKeyEnv) + } return apiKey, url } diff --git a/internal/util/function/function_executor_test.go b/internal/util/function/function_executor_test.go index 3906a342f7..572c7ab0b0 100644 --- a/internal/util/function/function_executor_test.go +++ b/internal/util/function/function_executor_test.go @@ -36,6 +36,7 @@ import ( "github.com/milvus-io/milvus/pkg/v2/mq/msgstream" "github.com/milvus-io/milvus/pkg/v2/proto/internalpb" "github.com/milvus-io/milvus/pkg/v2/util/funcutil" + "github.com/milvus-io/milvus/pkg/v2/util/paramtable" ) func TestFunctionExecutor(t *testing.T) { @@ -46,6 +47,10 @@ type FunctionExecutorSuite struct { suite.Suite } +func (s *FunctionExecutorSuite) SetupTest() { + paramtable.Init() +} + func (s *FunctionExecutorSuite) creataSchema(url string) *schemapb.CollectionSchema { return &schemapb.CollectionSchema{ Name: "test", diff --git a/internal/util/function/models/utils/embedding_util.go b/internal/util/function/models/utils/embedding_util.go index df854ac8bb..d786842727 100644 --- a/internal/util/function/models/utils/embedding_util.go +++ 
b/internal/util/function/models/utils/embedding_util.go @@ -36,11 +36,11 @@ func send(req *http.Request) ([]byte, error) { body, err := io.ReadAll(resp.Body) if err != nil { - return nil, fmt.Errorf("Call service faild, read response failed, errs:[%v]", err) + return nil, fmt.Errorf("Call service failed, read response failed, errs:[%v]", err) } if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("Call service faild, errs:[%s, %s]", resp.Status, body) + return nil, fmt.Errorf("Call service failed, errs:[%s, %s]", resp.Status, body) } return body, nil } diff --git a/internal/util/function/openai_embedding_provider.go b/internal/util/function/openai_embedding_provider.go index 8dda40b53e..c5beaee409 100644 --- a/internal/util/function/openai_embedding_provider.go +++ b/internal/util/function/openai_embedding_provider.go @@ -41,9 +41,6 @@ type OpenAIEmbeddingProvider struct { } func createOpenAIEmbeddingClient(apiKey string, url string) (*openai.OpenAIEmbeddingClient, error) { - if apiKey == "" { - apiKey = os.Getenv(openaiAKEnvStr) - } if apiKey == "" { return nil, fmt.Errorf("Missing credentials. Please pass `api_key`, or configure the %s environment variable in the Milvus service.", openaiAKEnvStr) } @@ -56,16 +53,16 @@ func createOpenAIEmbeddingClient(apiKey string, url string) (*openai.OpenAIEmbed return c, nil } -func createAzureOpenAIEmbeddingClient(apiKey string, url string) (*openai.AzureOpenAIEmbeddingClient, error) { - if apiKey == "" { - apiKey = os.Getenv(azureOpenaiAKEnvStr) - } +func createAzureOpenAIEmbeddingClient(apiKey string, url string, resourceName string) (*openai.AzureOpenAIEmbeddingClient, error) { if apiKey == "" { return nil, fmt.Errorf("Missing credentials. 
Please pass `api_key`, or configure the %s environment variable in the Milvus service", azureOpenaiAKEnvStr) } if url == "" { - if resourceName := os.Getenv(azureOpenaiResourceName); resourceName != "" { + if resourceName == "" { + resourceName = os.Getenv(azureOpenaiResourceName) + } + if resourceName != "" { url = fmt.Sprintf("https://%s.openai.azure.com", resourceName) } } @@ -76,15 +73,14 @@ func createAzureOpenAIEmbeddingClient(apiKey string, url string) (*openai.AzureO return c, nil } -func newOpenAIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema, isAzure bool) (*OpenAIEmbeddingProvider, error) { +func newOpenAIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema, params map[string]string, isAzure bool) (*OpenAIEmbeddingProvider, error) { fieldDim, err := typeutil.GetDim(fieldSchema) if err != nil { return nil, err } - apiKey, url := parseAKAndURL(functionSchema.Params) + var modelName, user string var dim int64 - for _, param := range functionSchema.Params { switch strings.ToLower(param.Key) { case modelNameParamKey: @@ -102,12 +98,15 @@ func newOpenAIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchem var c openai.OpenAIEmbeddingInterface if !isAzure { + apiKey, url := parseAKAndURL(functionSchema.Params, params, openaiAKEnvStr) c, err = createOpenAIEmbeddingClient(apiKey, url) if err != nil { return nil, err } } else { - c, err = createAzureOpenAIEmbeddingClient(apiKey, url) + apiKey, url := parseAKAndURL(functionSchema.Params, params, azureOpenaiAKEnvStr) + resourceName := params["resource_name"] + c, err = createAzureOpenAIEmbeddingClient(apiKey, url, resourceName) if err != nil { return nil, err } @@ -125,12 +124,12 @@ func newOpenAIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchem return &provider, nil } -func NewOpenAIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema) (*OpenAIEmbeddingProvider, 
error) { - return newOpenAIEmbeddingProvider(fieldSchema, functionSchema, false) +func NewOpenAIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema, params map[string]string) (*OpenAIEmbeddingProvider, error) { + return newOpenAIEmbeddingProvider(fieldSchema, functionSchema, params, false) } -func NewAzureOpenAIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema) (*OpenAIEmbeddingProvider, error) { - return newOpenAIEmbeddingProvider(fieldSchema, functionSchema, true) +func NewAzureOpenAIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema, params map[string]string) (*OpenAIEmbeddingProvider, error) { + return newOpenAIEmbeddingProvider(fieldSchema, functionSchema, params, true) } func (provider *OpenAIEmbeddingProvider) MaxBatch() int { diff --git a/internal/util/function/openai_text_embedding_provider_test.go b/internal/util/function/openai_text_embedding_provider_test.go index 28c94901fd..2432bee688 100644 --- a/internal/util/function/openai_text_embedding_provider_test.go +++ b/internal/util/function/openai_text_embedding_provider_test.go @@ -77,9 +77,9 @@ func createOpenAIProvider(url string, schema *schemapb.FieldSchema, providerName } switch providerName { case openAIProvider: - return NewOpenAIEmbeddingProvider(schema, functionSchema) + return NewOpenAIEmbeddingProvider(schema, functionSchema, map[string]string{}) case azureOpenAIProvider: - return NewAzureOpenAIEmbeddingProvider(schema, functionSchema) + return NewAzureOpenAIEmbeddingProvider(schema, functionSchema, map[string]string{}) default: return nil, fmt.Errorf("Unknow provider") } @@ -181,27 +181,15 @@ func (s *OpenAITextEmbeddingProviderSuite) TestEmbeddingNubmerNotMatch() { func (s *OpenAITextEmbeddingProviderSuite) TestCreateOpenAIEmbeddingClient() { _, err := createOpenAIEmbeddingClient("", "") s.Error(err) - - os.Setenv(openaiAKEnvStr, "mockKey") - defer 
os.Unsetenv(openaiAKEnvStr) - - _, err = createOpenAIEmbeddingClient("", "") - s.NoError(err) } func (s *OpenAITextEmbeddingProviderSuite) TestCreateAzureOpenAIEmbeddingClient() { - _, err := createAzureOpenAIEmbeddingClient("", "") - s.Error(err) - - os.Setenv(azureOpenaiAKEnvStr, "mockKey") - defer os.Unsetenv(azureOpenaiAKEnvStr) - - _, err = createAzureOpenAIEmbeddingClient("", "") + _, err := createAzureOpenAIEmbeddingClient("", "", "") s.Error(err) os.Setenv(azureOpenaiResourceName, "mockResource") defer os.Unsetenv(azureOpenaiResourceName) - _, err = createAzureOpenAIEmbeddingClient("", "") + _, err = createAzureOpenAIEmbeddingClient("mock", "", "") s.NoError(err) } diff --git a/internal/util/function/siliconflow_embedding_provider.go b/internal/util/function/siliconflow_embedding_provider.go index 32ee5ae3cf..ba25a754a1 100644 --- a/internal/util/function/siliconflow_embedding_provider.go +++ b/internal/util/function/siliconflow_embedding_provider.go @@ -20,7 +20,6 @@ package function import ( "fmt" - "os" "strings" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" @@ -40,9 +39,6 @@ type SiliconflowEmbeddingProvider struct { } func createSiliconflowEmbeddingClient(apiKey string, url string) (*siliconflow.SiliconflowEmbedding, error) { - if apiKey == "" { - apiKey = os.Getenv(siliconflowAKEnvStr) - } if apiKey == "" { return nil, fmt.Errorf("Missing credentials. 
Please pass `api_key`, or configure the %s environment variable in the Milvus service.", siliconflowAKEnvStr) } @@ -55,12 +51,12 @@ func createSiliconflowEmbeddingClient(apiKey string, url string) (*siliconflow.S return c, nil } -func NewSiliconflowEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema) (*SiliconflowEmbeddingProvider, error) { +func NewSiliconflowEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema, params map[string]string) (*SiliconflowEmbeddingProvider, error) { fieldDim, err := typeutil.GetDim(fieldSchema) if err != nil { return nil, err } - apiKey, url := parseAKAndURL(functionSchema.Params) + apiKey, url := parseAKAndURL(functionSchema.Params, params, siliconflowAKEnvStr) var modelName string for _, param := range functionSchema.Params { diff --git a/internal/util/function/siliconflow_embedding_provider_test.go b/internal/util/function/siliconflow_embedding_provider_test.go index cb278df669..7ec4b204c7 100644 --- a/internal/util/function/siliconflow_embedding_provider_test.go +++ b/internal/util/function/siliconflow_embedding_provider_test.go @@ -23,7 +23,6 @@ import ( "fmt" "net/http" "net/http/httptest" - "os" "testing" "github.com/stretchr/testify/suite" @@ -76,7 +75,7 @@ func createSiliconflowProvider(url string, schema *schemapb.FieldSchema, provide } switch providerName { case siliconflowProvider: - return NewSiliconflowEmbeddingProvider(schema, functionSchema) + return NewSiliconflowEmbeddingProvider(schema, functionSchema, map[string]string{}) default: return nil, fmt.Errorf("Unknow provider") } @@ -171,11 +170,6 @@ func (s *SiliconflowTextEmbeddingProviderSuite) TestEmbeddingNumberNotMatch() { func (s *SiliconflowTextEmbeddingProviderSuite) TestCreateSiliconflowEmbeddingClient() { _, err := createSiliconflowEmbeddingClient("", "") s.Error(err) - - os.Setenv(siliconflowAKEnvStr, "mockKey") - defer os.Unsetenv(siliconflowAKEnvStr) - _, err = 
createSiliconflowEmbeddingClient("", "") - s.NoError(err) } func (s *SiliconflowTextEmbeddingProviderSuite) TestNewSiliconflowEmbeddingProvider() { @@ -192,7 +186,7 @@ func (s *SiliconflowTextEmbeddingProviderSuite) TestNewSiliconflowEmbeddingProvi {Key: embeddingURLParamKey, Value: "mock"}, }, } - provider, err := NewSiliconflowEmbeddingProvider(s.schema.Fields[2], functionSchema) + provider, err := NewSiliconflowEmbeddingProvider(s.schema.Fields[2], functionSchema, map[string]string{}) s.NoError(err) s.Equal(provider.FieldDim(), int64(4)) s.True(provider.MaxBatch() > 0) diff --git a/internal/util/function/tei_embedding_provider.go b/internal/util/function/tei_embedding_provider.go index 26cb719f8e..131e9c1839 100644 --- a/internal/util/function/tei_embedding_provider.go +++ b/internal/util/function/tei_embedding_provider.go @@ -52,7 +52,7 @@ func createTEIEmbeddingClient(apiKey string, endpoint string) (*tei.TEIEmbedding return tei.NewTEIEmbeddingClient(apiKey, endpoint) } -func NewTEIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema) (*TeiEmbeddingProvider, error) { +func NewTEIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema, params map[string]string) (*TeiEmbeddingProvider, error) { fieldDim, err := typeutil.GetDim(fieldSchema) if err != nil { return nil, err diff --git a/internal/util/function/tei_embedding_provider_test.go b/internal/util/function/tei_embedding_provider_test.go index 79e1490180..5e31174df7 100644 --- a/internal/util/function/tei_embedding_provider_test.go +++ b/internal/util/function/tei_embedding_provider_test.go @@ -76,7 +76,7 @@ func createTEIProvider(url string, schema *schemapb.FieldSchema, providerName st } switch providerName { case teiProvider: - return NewTEIEmbeddingProvider(schema, functionSchema) + return NewTEIEmbeddingProvider(schema, functionSchema, map[string]string{}) default: return nil, fmt.Errorf("Unknow provider") } @@ -172,7 +172,7 @@ 
func (s *TEITextEmbeddingProviderSuite) TestNewTEIEmbeddingProvider() { {Key: endpointParamKey, Value: "http://mymock.com"}, }, } - provider, err := NewTEIEmbeddingProvider(s.schema.Fields[2], functionSchema) + provider, err := NewTEIEmbeddingProvider(s.schema.Fields[2], functionSchema, map[string]string{}) s.NoError(err) s.Equal(provider.FieldDim(), int64(4)) s.True(provider.MaxBatch() == 32*5) @@ -180,35 +180,35 @@ func (s *TEITextEmbeddingProviderSuite) TestNewTEIEmbeddingProvider() { // Invalid truncate { functionSchema.Params = append(functionSchema.Params, &commonpb.KeyValuePair{Key: truncateParamKey, Value: "Invalid"}) - _, err := NewTEIEmbeddingProvider(s.schema.Fields[2], functionSchema) + _, err := NewTEIEmbeddingProvider(s.schema.Fields[2], functionSchema, map[string]string{}) s.Error(err) } // Invalid truncationDirection { functionSchema.Params[2] = &commonpb.KeyValuePair{Key: truncateParamKey, Value: "true"} functionSchema.Params = append(functionSchema.Params, &commonpb.KeyValuePair{Key: truncationDirectionParamKey, Value: "Invalid"}) - _, err := NewTEIEmbeddingProvider(s.schema.Fields[2], functionSchema) + _, err := NewTEIEmbeddingProvider(s.schema.Fields[2], functionSchema, map[string]string{}) s.Error(err) } // truncationDirection { functionSchema.Params[3] = &commonpb.KeyValuePair{Key: truncationDirectionParamKey, Value: "Left"} - _, err := NewTEIEmbeddingProvider(s.schema.Fields[2], functionSchema) + _, err := NewTEIEmbeddingProvider(s.schema.Fields[2], functionSchema, map[string]string{}) s.NoError(err) } // Invalid max batch { functionSchema.Params = append(functionSchema.Params, &commonpb.KeyValuePair{Key: maxClientBatchSizeParamKey, Value: "Invalid"}) - _, err := NewTEIEmbeddingProvider(s.schema.Fields[2], functionSchema) + _, err := NewTEIEmbeddingProvider(s.schema.Fields[2], functionSchema, map[string]string{}) s.Error(err) } // Valid max batch { functionSchema.Params[4] = &commonpb.KeyValuePair{Key: maxClientBatchSizeParamKey, Value: 
"128"} - pv, err := NewTEIEmbeddingProvider(s.schema.Fields[2], functionSchema) + pv, err := NewTEIEmbeddingProvider(s.schema.Fields[2], functionSchema, map[string]string{}) s.NoError(err) s.True(pv.MaxBatch() == 128*5) } diff --git a/internal/util/function/text_embedding_function.go b/internal/util/function/text_embedding_function.go index 510464796d..6359e228e9 100644 --- a/internal/util/function/text_embedding_function.go +++ b/internal/util/function/text_embedding_function.go @@ -27,6 +27,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/pkg/v2/util/funcutil" + "github.com/milvus-io/milvus/pkg/v2/util/paramtable" ) const ( @@ -45,6 +46,15 @@ const ( teiProvider string = "tei" ) +func hasEmptyString(texts []string) bool { + for _, text := range texts { + if text == "" { + return true + } + } + return false +} + func TextEmbeddingOutputsCheck(fields []*schemapb.FieldSchema) error { if len(fields) != 1 || (fields[0].DataType != schemapb.DataType_FloatVector && fields[0].DataType != schemapb.DataType_Int8Vector) { return fmt.Errorf("TextEmbedding function output field must be a FloatVector or Int8Vector field") @@ -85,25 +95,26 @@ func NewTextEmbeddingFunction(coll *schemapb.CollectionSchema, functionSchema *s var embP textEmbeddingProvider var newProviderErr error + conf := paramtable.Get().FunctionCfg.GetTextEmbeddingProviderConfig(base.provider) switch base.provider { case openAIProvider: - embP, newProviderErr = NewOpenAIEmbeddingProvider(base.outputFields[0], functionSchema) + embP, newProviderErr = NewOpenAIEmbeddingProvider(base.outputFields[0], functionSchema, conf) case azureOpenAIProvider: - embP, newProviderErr = NewAzureOpenAIEmbeddingProvider(base.outputFields[0], functionSchema) + embP, newProviderErr = NewAzureOpenAIEmbeddingProvider(base.outputFields[0], functionSchema, conf) case bedrockProvider: - embP, newProviderErr = 
NewBedrockEmbeddingProvider(base.outputFields[0], functionSchema, nil) + embP, newProviderErr = NewBedrockEmbeddingProvider(base.outputFields[0], functionSchema, nil, conf) case aliDashScopeProvider: - embP, newProviderErr = NewAliDashScopeEmbeddingProvider(base.outputFields[0], functionSchema) + embP, newProviderErr = NewAliDashScopeEmbeddingProvider(base.outputFields[0], functionSchema, conf) case vertexAIProvider: - embP, newProviderErr = NewVertexAIEmbeddingProvider(base.outputFields[0], functionSchema, nil) + embP, newProviderErr = NewVertexAIEmbeddingProvider(base.outputFields[0], functionSchema, nil, conf) case voyageAIProvider: - embP, newProviderErr = NewVoyageAIEmbeddingProvider(base.outputFields[0], functionSchema) + embP, newProviderErr = NewVoyageAIEmbeddingProvider(base.outputFields[0], functionSchema, conf) case cohereProvider: - embP, newProviderErr = NewCohereEmbeddingProvider(base.outputFields[0], functionSchema) + embP, newProviderErr = NewCohereEmbeddingProvider(base.outputFields[0], functionSchema, conf) case siliconflowProvider: - embP, newProviderErr = NewSiliconflowEmbeddingProvider(base.outputFields[0], functionSchema) + embP, newProviderErr = NewSiliconflowEmbeddingProvider(base.outputFields[0], functionSchema, conf) case teiProvider: - embP, newProviderErr = NewTEIEmbeddingProvider(base.outputFields[0], functionSchema) + embP, newProviderErr = NewTEIEmbeddingProvider(base.outputFields[0], functionSchema, conf) default: return nil, fmt.Errorf("Unsupported text embedding service provider: [%s] , list of supported [%s, %s, %s, %s, %s, %s, %s, %s, %s]", base.provider, openAIProvider, azureOpenAIProvider, aliDashScopeProvider, bedrockProvider, vertexAIProvider, voyageAIProvider, cohereProvider, siliconflowProvider, teiProvider) } @@ -213,10 +224,16 @@ func (runner *TextEmbeddingFunction) ProcessInsert(ctx context.Context, inputs [ if texts == nil { return nil, fmt.Errorf("Input texts is empty") } + + // make sure all texts are not empty + if 
hasEmptyString(texts) { + return nil, fmt.Errorf("There is an empty string in the input data, TextEmbedding function does not support empty text") + } numRows := len(texts) if numRows > runner.MaxBatch() { return nil, fmt.Errorf("Embedding supports up to [%d] pieces of data at a time, got [%d]", runner.MaxBatch(), numRows) } + embds, err := runner.embProvider.CallEmbedding(texts, InsertMode) if err != nil { return nil, err @@ -231,6 +248,10 @@ func (runner *TextEmbeddingFunction) ProcessSearch(ctx context.Context, placehol if numRows > runner.MaxBatch() { return nil, fmt.Errorf("Embedding supports up to [%d] pieces of data at a time, got [%d]", runner.MaxBatch(), numRows) } + // make sure all texts are not empty + if hasEmptyString(texts) { + return nil, fmt.Errorf("There is an empty string in the queries, TextEmbedding function does not support empty text") + } embds, err := runner.embProvider.CallEmbedding(texts, SearchMode) if err != nil { return nil, err @@ -257,6 +278,11 @@ func (runner *TextEmbeddingFunction) ProcessBulkInsert(inputs []storage.FieldDat return nil, fmt.Errorf("Input texts is empty") } + // make sure all texts are not empty + // In storage.FieldData, null is also stored as an empty string + if hasEmptyString(texts) { + return nil, fmt.Errorf("There is an empty string in the input data, TextEmbedding function does not support empty text") + } embds, err := runner.embProvider.CallEmbedding(texts, InsertMode) if err != nil { return nil, err diff --git a/internal/util/function/text_embedding_function_test.go b/internal/util/function/text_embedding_function_test.go index a34131679d..a60fd4d61b 100644 --- a/internal/util/function/text_embedding_function_test.go +++ b/internal/util/function/text_embedding_function_test.go @@ -31,6 +31,7 @@ import ( "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/internal/util/testutil" "github.com/milvus-io/milvus/pkg/v2/util/funcutil" + "github.com/milvus-io/milvus/pkg/v2/util/paramtable" 
) func TestTextEmbeddingFunction(t *testing.T) { @@ -43,6 +44,7 @@ type TextEmbeddingFunctionSuite struct { } func (s *TextEmbeddingFunctionSuite) SetupTest() { + paramtable.Init() s.schema = &schemapb.CollectionSchema{ Name: "test", Fields: []*schemapb.FieldSchema{ @@ -272,7 +274,29 @@ func (s *TextEmbeddingFunctionSuite) TestAliEmbedding() { Scalars: &schemapb.ScalarField{ Data: &schemapb.ScalarField_StringData{ StringData: &schemapb.StringArray{ - Data: strings.Split(strings.Repeat("Element,", 1000), ","), + Data: strings.Split(strings.Repeat("Element,", 1000), ",")[:999], + }, + }, + }, + }, + } + data = append(data, &f) + _, err := runner.ProcessInsert(context.Background(), data) + s.Error(err) + } + + // empty string + { + data := []*schemapb.FieldData{} + f := schemapb.FieldData{ + Type: schemapb.DataType_VarChar, + FieldId: 101, + IsDynamic: false, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_StringData{ + StringData: &schemapb.StringArray{ + Data: strings.Split(strings.Repeat("Element,", 10), ","), }, }, }, @@ -610,7 +634,7 @@ func (s *TextEmbeddingFunctionSuite) TestProcessSearchFloat32() { Scalars: &schemapb.ScalarField{ Data: &schemapb.ScalarField_StringData{ StringData: &schemapb.StringArray{ - Data: strings.Split(strings.Repeat("Element,", 1000), ","), + Data: strings.Split(strings.Repeat("Element,", 1000), ",")[0:999], }, }, }, @@ -635,7 +659,7 @@ func (s *TextEmbeddingFunctionSuite) TestProcessSearchFloat32() { Scalars: &schemapb.ScalarField{ Data: &schemapb.ScalarField_StringData{ StringData: &schemapb.StringArray{ - Data: strings.Split(strings.Repeat("Element,", 100), ","), + Data: strings.Split(strings.Repeat("Element,", 100), ",")[:99], }, }, }, @@ -777,7 +801,7 @@ func (s *TextEmbeddingFunctionSuite) TestProcessSearchInt8() { Scalars: &schemapb.ScalarField{ Data: &schemapb.ScalarField_StringData{ StringData: &schemapb.StringArray{ - Data: strings.Split(strings.Repeat("Element,", 100), 
","), + Data: strings.Split(strings.Repeat("Element,", 100), ",")[:99], }, }, }, @@ -791,6 +815,31 @@ func (s *TextEmbeddingFunctionSuite) TestProcessSearchInt8() { _, err = runner.ProcessSearch(context.Background(), &placeholderGroup) s.NoError(err) } + + // empty text + { + f := &schemapb.FieldData{ + Type: schemapb.DataType_VarChar, + FieldId: 101, + IsDynamic: false, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_StringData{ + StringData: &schemapb.StringArray{ + Data: strings.Split(strings.Repeat("Element,", 100), ","), + }, + }, + }, + }, + } + + placeholderGroupBytes, err := funcutil.FieldDataToPlaceholderGroupBytes(f) + s.NoError(err) + placeholderGroup := commonpb.PlaceholderGroup{} + proto.Unmarshal(placeholderGroupBytes, &placeholderGroup) + _, err = runner.ProcessSearch(context.Background(), &placeholderGroup) + s.Error(err) + } } func (s *TextEmbeddingFunctionSuite) TestProcessBulkInsertFloat32() { @@ -834,6 +883,15 @@ func (s *TextEmbeddingFunctionSuite) TestProcessBulkInsertFloat32() { _, err := runner.ProcessBulkInsert(input) s.Error(err) } + + // empty texts + { + input := []storage.FieldData{data.Data[101]} + err := input[0].AppendRow("") + s.NoError(err) + _, err = runner.ProcessBulkInsert(input) + s.Error(err) + } } func (s *TextEmbeddingFunctionSuite) TestProcessBulkInsertInt8() { diff --git a/internal/util/function/vertexai_embedding_provider.go b/internal/util/function/vertexai_embedding_provider.go index ab6566fc37..9b4abb2b9d 100644 --- a/internal/util/function/vertexai_embedding_provider.go +++ b/internal/util/function/vertexai_embedding_provider.go @@ -30,24 +30,40 @@ import ( ) type vertexAIJsonKey struct { - jsonKey []byte - once sync.Once - initErr error + mu sync.Mutex + filePath string + jsonKey []byte } var vtxKey vertexAIJsonKey -func getVertexAIJsonKey() ([]byte, error) { - vtxKey.once.Do(func() { - jsonKeyPath := os.Getenv(vertexServiceAccountJSONEnv) - jsonKey, err := 
os.ReadFile(jsonKeyPath) - if err != nil { - vtxKey.initErr = fmt.Errorf("Vertexai: read service account json file failed, %v", err) - return - } - vtxKey.jsonKey = jsonKey - }) - return vtxKey.jsonKey, vtxKey.initErr +func getVertexAIJsonKey(credentialsFilePath string) ([]byte, error) { + vtxKey.mu.Lock() + defer vtxKey.mu.Unlock() + + var jsonKeyPath string + if credentialsFilePath == "" { + jsonKeyPath = os.Getenv(vertexServiceAccountJSONEnv) + } else { + jsonKeyPath = credentialsFilePath + } + if jsonKeyPath == "" { + return nil, fmt.Errorf("VetexAI credentials file path is empty") + } + + if vtxKey.filePath == jsonKeyPath { + // The file path remains unchanged, using the data in the cache + return vtxKey.jsonKey, nil + } + + jsonKey, err := os.ReadFile(jsonKeyPath) + if err != nil { + return nil, fmt.Errorf("Vertexai: read credentials file failed, %v", err) + } + + vtxKey.jsonKey = jsonKey + vtxKey.filePath = jsonKeyPath + return vtxKey.jsonKey, nil } const ( @@ -68,8 +84,8 @@ type VertexAIEmbeddingProvider struct { timeoutSec int64 } -func createVertexAIEmbeddingClient(url string) (*vertexai.VertexAIEmbedding, error) { - jsonKey, err := getVertexAIJsonKey() +func createVertexAIEmbeddingClient(url string, credentialsFilePath string) (*vertexai.VertexAIEmbedding, error) { + jsonKey, err := getVertexAIJsonKey(credentialsFilePath) if err != nil { return nil, err } @@ -77,7 +93,7 @@ func createVertexAIEmbeddingClient(url string) (*vertexai.VertexAIEmbedding, err return c, nil } -func NewVertexAIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema, c *vertexai.VertexAIEmbedding) (*VertexAIEmbeddingProvider, error) { +func NewVertexAIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema, c *vertexai.VertexAIEmbedding, params map[string]string) (*VertexAIEmbeddingProvider, error) { fieldDim, err := typeutil.GetDim(fieldSchema) if err != nil { return nil, err @@ -112,10 +128,13 @@ func 
NewVertexAIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSch location = "us-central1" } - url := fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:predict", location, projectID, location, modelName) + url := params["url"] + if url == "" { + url = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:predict", location, projectID, location, modelName) + } var client *vertexai.VertexAIEmbedding if c == nil { - client, err = createVertexAIEmbeddingClient(url) + client, err = createVertexAIEmbeddingClient(url, params["credentials_file_path"]) if err != nil { return nil, err } diff --git a/internal/util/function/vertexai_embedding_provider_test.go b/internal/util/function/vertexai_embedding_provider_test.go index 4018db58eb..988531cb42 100644 --- a/internal/util/function/vertexai_embedding_provider_test.go +++ b/internal/util/function/vertexai_embedding_provider_test.go @@ -76,7 +76,7 @@ func createVertexAIProvider(url string, schema *schemapb.FieldSchema) (textEmbed }, } mockClient := vertexai.NewVertexAIEmbedding(url, []byte{1, 2, 3}, "mock scope", "mock token") - return NewVertexAIEmbeddingProvider(schema, functionSchema, mockClient) + return NewVertexAIEmbeddingProvider(schema, functionSchema, mockClient, map[string]string{}) } func (s *VertexAITextEmbeddingProviderSuite) TestEmbedding() { @@ -177,7 +177,7 @@ func (s *VertexAITextEmbeddingProviderSuite) TestEmbeddingNubmerNotMatch() { func (s *VertexAITextEmbeddingProviderSuite) TestGetVertexAIJsonKey() { os.Setenv(vertexServiceAccountJSONEnv, "ErrorPath") defer os.Unsetenv(vertexServiceAccountJSONEnv) - _, err := getVertexAIJsonKey() + _, err := getVertexAIJsonKey("") s.Error(err) } @@ -198,7 +198,7 @@ func (s *VertexAITextEmbeddingProviderSuite) TestGetTaskType() { mockClient := vertexai.NewVertexAIEmbedding("mock_url", []byte{1, 2, 3}, "mock scope", "mock token") { - provider, err := 
NewVertexAIEmbeddingProvider(s.schema.Fields[2], functionSchema, mockClient) + provider, err := NewVertexAIEmbeddingProvider(s.schema.Fields[2], functionSchema, mockClient, map[string]string{}) s.NoError(err) s.Equal(provider.getTaskType(InsertMode), "RETRIEVAL_DOCUMENT") s.Equal(provider.getTaskType(SearchMode), "RETRIEVAL_QUERY") @@ -206,7 +206,7 @@ func (s *VertexAITextEmbeddingProviderSuite) TestGetTaskType() { { functionSchema.Params = append(functionSchema.Params, &commonpb.KeyValuePair{Key: taskTypeParamKey, Value: vertexAICodeRetrival}) - provider, err := NewVertexAIEmbeddingProvider(s.schema.Fields[2], functionSchema, mockClient) + provider, err := NewVertexAIEmbeddingProvider(s.schema.Fields[2], functionSchema, mockClient, map[string]string{}) s.NoError(err) s.Equal(provider.getTaskType(InsertMode), "RETRIEVAL_DOCUMENT") s.Equal(provider.getTaskType(SearchMode), "CODE_RETRIEVAL_QUERY") @@ -214,7 +214,7 @@ func (s *VertexAITextEmbeddingProviderSuite) TestGetTaskType() { { functionSchema.Params[3] = &commonpb.KeyValuePair{Key: taskTypeParamKey, Value: vertexAISTS} - provider, err := NewVertexAIEmbeddingProvider(s.schema.Fields[2], functionSchema, mockClient) + provider, err := NewVertexAIEmbeddingProvider(s.schema.Fields[2], functionSchema, mockClient, map[string]string{}) s.NoError(err) s.Equal(provider.getTaskType(InsertMode), "SEMANTIC_SIMILARITY") s.Equal(provider.getTaskType(SearchMode), "SEMANTIC_SIMILARITY") @@ -224,7 +224,7 @@ func (s *VertexAITextEmbeddingProviderSuite) TestGetTaskType() { func (s *VertexAITextEmbeddingProviderSuite) TestCreateVertexAIEmbeddingClient() { os.Setenv(vertexServiceAccountJSONEnv, "ErrorPath") defer os.Unsetenv(vertexServiceAccountJSONEnv) - _, err := createVertexAIEmbeddingClient("https://mock_url.com") + _, err := createVertexAIEmbeddingClient("https://mock_url.com", "") s.Error(err) } @@ -243,7 +243,7 @@ func (s *VertexAITextEmbeddingProviderSuite) TestNewVertexAIEmbeddingProvider() }, } mockClient := 
vertexai.NewVertexAIEmbedding("mock_url", []byte{1, 2, 3}, "mock scope", "mock token") - provider, err := NewVertexAIEmbeddingProvider(s.schema.Fields[2], functionSchema, mockClient) + provider, err := NewVertexAIEmbeddingProvider(s.schema.Fields[2], functionSchema, mockClient, map[string]string{}) s.NoError(err) s.True(provider.MaxBatch() > 0) s.Equal(provider.FieldDim(), int64(4)) diff --git a/internal/util/function/voyageai_embedding_provider.go b/internal/util/function/voyageai_embedding_provider.go index 3e3240712d..ad6d3b4f13 100644 --- a/internal/util/function/voyageai_embedding_provider.go +++ b/internal/util/function/voyageai_embedding_provider.go @@ -20,7 +20,6 @@ package function import ( "fmt" - "os" "strconv" "strings" @@ -44,9 +43,6 @@ type VoyageAIEmbeddingProvider struct { } func createVoyageAIEmbeddingClient(apiKey string, url string) (*voyageai.VoyageAIEmbedding, error) { - if apiKey == "" { - apiKey = os.Getenv(voyageAIAKEnvStr) - } if apiKey == "" { return nil, fmt.Errorf("Missing credentials. 
Please pass `api_key`, or configure the %s environment variable in the Milvus service.", voyageAIAKEnvStr) } @@ -59,12 +55,12 @@ func createVoyageAIEmbeddingClient(apiKey string, url string) (*voyageai.VoyageA return c, nil } -func NewVoyageAIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema) (*VoyageAIEmbeddingProvider, error) { +func NewVoyageAIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema, params map[string]string) (*VoyageAIEmbeddingProvider, error) { fieldDim, err := typeutil.GetDim(fieldSchema) if err != nil { return nil, err } - apiKey, url := parseAKAndURL(functionSchema.Params) + apiKey, url := parseAKAndURL(functionSchema.Params, params, voyageAIAKEnvStr) var modelName string dim := int64(0) truncate := false diff --git a/internal/util/function/voyageai_embedding_provider_test.go b/internal/util/function/voyageai_embedding_provider_test.go index 0f8ccab7e4..84c594bc70 100644 --- a/internal/util/function/voyageai_embedding_provider_test.go +++ b/internal/util/function/voyageai_embedding_provider_test.go @@ -23,7 +23,6 @@ import ( "fmt" "net/http" "net/http/httptest" - "os" "testing" "github.com/stretchr/testify/suite" @@ -77,7 +76,7 @@ func createVoyageAIProvider(url string, schema *schemapb.FieldSchema, providerNa } switch providerName { case voyageAIProvider: - return NewVoyageAIEmbeddingProvider(schema, functionSchema) + return NewVoyageAIEmbeddingProvider(schema, functionSchema, map[string]string{}) default: return nil, fmt.Errorf("Unknow provider") } @@ -282,11 +281,6 @@ func (s *VoyageAITextEmbeddingProviderSuite) TestEmbeddingNumberNotMatch() { func (s *VoyageAITextEmbeddingProviderSuite) TestCreateVoyageAIEmbeddingClient() { _, err := createVoyageAIEmbeddingClient("", "") s.Error(err) - - os.Setenv(voyageAIAKEnvStr, "mockKey") - defer os.Unsetenv(voyageAIAKEnvStr) - _, err = createVoyageAIEmbeddingClient("", "") - s.NoError(err) } func (s 
*VoyageAITextEmbeddingProviderSuite) TestNewVoyageAIEmbeddingProvider() { @@ -305,7 +299,7 @@ func (s *VoyageAITextEmbeddingProviderSuite) TestNewVoyageAIEmbeddingProvider() {Key: truncationParamKey, Value: "true"}, }, } - provider, err := NewVoyageAIEmbeddingProvider(s.schema.Fields[2], functionSchema) + provider, err := NewVoyageAIEmbeddingProvider(s.schema.Fields[2], functionSchema, map[string]string{}) s.NoError(err) s.Equal(provider.FieldDim(), int64(1024)) s.True(provider.MaxBatch() > 0) @@ -313,7 +307,7 @@ func (s *VoyageAITextEmbeddingProviderSuite) TestNewVoyageAIEmbeddingProvider() // Invalid truncation { functionSchema.Params[4] = &commonpb.KeyValuePair{Key: truncationParamKey, Value: "Invalid"} - _, err := NewVoyageAIEmbeddingProvider(s.schema.Fields[2], functionSchema) + _, err := NewVoyageAIEmbeddingProvider(s.schema.Fields[2], functionSchema, map[string]string{}) s.Error(err) functionSchema.Params[4] = &commonpb.KeyValuePair{Key: truncationParamKey, Value: "false"} } @@ -321,14 +315,14 @@ func (s *VoyageAITextEmbeddingProviderSuite) TestNewVoyageAIEmbeddingProvider() // Invalid dim { functionSchema.Params[3] = &commonpb.KeyValuePair{Key: dimParamKey, Value: "9"} - _, err := NewVoyageAIEmbeddingProvider(s.schema.Fields[2], functionSchema) + _, err := NewVoyageAIEmbeddingProvider(s.schema.Fields[2], functionSchema, map[string]string{}) s.Error(err) } // Invalid dim type { functionSchema.Params[3] = &commonpb.KeyValuePair{Key: dimParamKey, Value: "Invalied"} - _, err := NewVoyageAIEmbeddingProvider(s.schema.Fields[2], functionSchema) + _, err := NewVoyageAIEmbeddingProvider(s.schema.Fields[2], functionSchema, map[string]string{}) s.Error(err) } } diff --git a/pkg/util/paramtable/component_param.go b/pkg/util/paramtable/component_param.go index dab6946511..1a13c599f3 100644 --- a/pkg/util/paramtable/component_param.go +++ b/pkg/util/paramtable/component_param.go @@ -84,6 +84,7 @@ type ComponentParam struct { RoleCfg roleConfig RbacConfig rbacConfig 
StreamingCfg streamingConfig + FunctionCfg functionConfig InternalTLSCfg InternalTLSConfig @@ -138,6 +139,7 @@ func (p *ComponentParam) init(bt *BaseTable) { p.RbacConfig.init(bt) p.GpuConfig.init(bt) p.KnowhereConfig.init(bt) + p.FunctionCfg.init(bt) p.InternalTLSCfg.Init(bt) diff --git a/pkg/util/paramtable/function_param.go b/pkg/util/paramtable/function_param.go new file mode 100644 index 0000000000..2e7650dec6 --- /dev/null +++ b/pkg/util/paramtable/function_param.go @@ -0,0 +1,105 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package paramtable + +import ( + "strings" +) + +type functionConfig struct { + TextEmbeddingEnableVerifiInfoInParams ParamItem `refreshable:"true"` + TextEmbeddingProviders ParamGroup `refreshable:"true"` +} + +func (p *functionConfig) init(base *BaseTable) { + p.TextEmbeddingEnableVerifiInfoInParams = ParamItem{ + Key: "function.textEmbedding.enableVerifiInfoInParams", + Version: "2.6.0", + DefaultValue: "true", + Export: true, + Doc: "Controls whether to allow configuration of apikey and model service url on function parameters", + } + p.TextEmbeddingEnableVerifiInfoInParams.Init(base.mgr) + + p.TextEmbeddingProviders = ParamGroup{ + KeyPrefix: "function.textEmbedding.providers.", + Version: "2.6.0", + Export: true, + DocFunc: func(key string) string { + switch key { + case "tei.enable": + return "Whether to enable TEI model service" + case "azure_openai.api_key": + return "Your azure openai embedding url, Default is the official embedding url" + case "azure_openai.url": + return "Your azure openai api key" + case "azure_openai.resource_name": + return "Your azure openai resource name" + case "openai.api_key": + return "Your openai embedding url, Default is the official embedding url" + case "openai.url": + return "Your openai api key" + case "dashscope.api_key": + return "Your dashscope embedding url, Default is the official embedding url" + case "dashscope.url": + return "Your dashscope api key" + case "cohere.api_key": + return "Your cohere embedding url, Default is the official embedding url" + case "cohere.url": + return "Your cohere api key" + case "voyageai.api_key": + return "Your voyageai embedding url, Default is the official embedding url" + case "voyageai.url": + return "Your voyageai api key" + case "siliconflow.url": + return "Your siliconflow embedding url, Default is the official embedding url" + case "siliconflow.api_key": + return "Your siliconflow api key" + case "bedrock.aws_access_key_id": + return "Your aws_access_key_id" + case 
"bedrock.aws_secret_access_key": + return "Your aws_secret_access_key" + case "vertexai.url": + return "Your VertexAI embedding url" + case "vertexai.credentials_file_path": + return "Path to your google application credentials, change the file path to refresh the configuration" + default: + return "" + } + }, + } + p.TextEmbeddingProviders.Init(base.mgr) +} + +const ( + textEmbeddingKey string = "textEmbedding" +) + +func (p *functionConfig) GetTextEmbeddingProviderConfig(providerName string) map[string]string { + matchedParam := make(map[string]string) + + params := p.TextEmbeddingProviders.GetValue() + prefix := providerName + "." + + for k, v := range params { + if strings.HasPrefix(k, prefix) { + matchedParam[strings.TrimPrefix(k, prefix)] = v + } + } + matchedParam["enableVerifiInfoInParams"] = p.TextEmbeddingEnableVerifiInfoInParams.GetValue() + return matchedParam +} diff --git a/pkg/util/paramtable/function_param_test.go b/pkg/util/paramtable/function_param_test.go new file mode 100644 index 0000000000..cf50a1c2ec --- /dev/null +++ b/pkg/util/paramtable/function_param_test.go @@ -0,0 +1,65 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package paramtable + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestFunctionConfig(t *testing.T) { + params := ComponentParam{} + params.Init(NewBaseTable(SkipRemote(true))) + cfg := ¶ms.FunctionCfg + notExistProvider := cfg.GetTextEmbeddingProviderConfig("notExist") + + // Only has enableVerifiInfoInParams config + assert.Equal(t, len(notExistProvider), 1) + + teiConf := cfg.GetTextEmbeddingProviderConfig("tei") + assert.Equal(t, teiConf["enable"], "true") + assert.Equal(t, teiConf["enableVerifiInfoInParams"], "true") + openaiConf := cfg.GetTextEmbeddingProviderConfig("openai") + assert.Equal(t, openaiConf["api_key"], "") + assert.Equal(t, openaiConf["url"], "") + assert.Equal(t, openaiConf["enableVerifiInfoInParams"], "true") + + keys := []string{ + "tei.enable", + "azure_openai.api_key", + "azure_openai.url", + "azure_openai.resource_name", + "openai.api_key", + "openai.url", + "dashscope.api_key", + "dashscope.url", + "cohere.api_key", + "cohere.url", + "voyageai.api_key", + "voyageai.url", + "siliconflow.url", + "siliconflow.api_key", + "bedrock.aws_access_key_id", + "bedrock.aws_secret_access_key", + "vertexai.url", + "vertexai.credentials_file_path", + } + for _, key := range keys { + assert.True(t, cfg.TextEmbeddingProviders.GetDoc(key) != "") + } +} diff --git a/tests/python_client/testcases/test_text_embedding_function_e2e.py b/tests/python_client/testcases/test_text_embedding_function_e2e.py index 945af63797..723796d878 100644 --- a/tests/python_client/testcases/test_text_embedding_function_e2e.py +++ b/tests/python_client/testcases/test_text_embedding_function_e2e.py @@ -346,6 +346,7 @@ class TestInsertWithTextEmbeddingNegative(TestcaseBase): """ @pytest.mark.tags(CaseLabel.L1) + @pytest.mark.skip("not support empty document now") def test_insert_with_text_embedding_empty_document(self, tei_endpoint): """ target: test insert data with empty document @@ -389,6 +390,7 @@ class 
TestInsertWithTextEmbeddingNegative(TestcaseBase): assert collection_w.num_entities == 0 @pytest.mark.tags(CaseLabel.L1) + @pytest.mark.skip("TODO") def test_insert_with_text_embedding_long_document(self, tei_endpoint): """ target: test insert data with long document @@ -663,6 +665,7 @@ class TestSearchWithTextEmbeddingNegative(TestcaseBase): @pytest.mark.tags(CaseLabel.L1) @pytest.mark.parametrize("query", ["empty_query", "long_query"]) + @pytest.mark.skip("not support empty query now") def test_search_with_text_embedding_negative_query(self, query, tei_endpoint): """ target: test search with empty query or long query