enhance: refactor embedding credentials manager (#41442)

https://github.com/milvus-io/milvus/issues/35856

Signed-off-by: junjie.jiang <junjie.jiang@zilliz.com>
This commit is contained in:
junjiejiangjjj 2025-04-24 14:34:38 +08:00 committed by GitHub
parent dbe54c2df8
commit e56adc121b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
34 changed files with 733 additions and 294 deletions

View File

@ -359,6 +359,20 @@ func WriteYaml(w io.Writer) {
name: "knowhere", name: "knowhere",
header: ` header: `
# Any configuration related to the knowhere vector search engine`, # Any configuration related to the knowhere vector search engine`,
},
{
name: "credential",
header: `
# credential configs, support apikey, AKSK, gcp credential
# examples:
# credential:
# your_apikey_crendential_name:
# apikey: # Your apikey credential
# your_aksk_crendential_name:
# access_key_id:
# secret_access_key:
# your_gcp_credential_name:
# credential_json:`,
}, },
{ {
name: "function", name: "function",

View File

@ -1219,32 +1219,50 @@ knowhere:
search: search:
beam_width_ratio: 4 # Ratio between the maximum number of IO requests per search iteration and CPU number beam_width_ratio: 4 # Ratio between the maximum number of IO requests per search iteration and CPU number
# credential configs, support apikey, AKSK, gcp credential
# examples:
# credential:
# your_apikey_crendential_name:
# apikey: # Your apikey credential
# your_aksk_crendential_name:
# access_key_id:
# secret_access_key:
# your_gcp_credential_name:
# credential_json:
credential:
aksk1:
access_key_id: # Your access_key_id
secret_access_key: # Your secret_access_key
apikey1:
apikey: # Your apikey credential
gcp1:
credential_json: # base64 based gcp credential data
# Any configuration related to functions # Any configuration related to functions
function: function:
textEmbedding: textEmbedding:
enableVerifiInfoInParams: true # Controls whether to allow configuration of apikey and model service url on function parameters
providers: providers:
azure_openai: azure_openai:
api_key: # Your azure openai embedding url, Default is the official embedding url credential: # The name in the crendential configuration item
resource_name: # Your azure openai resource name resource_name: # Your azure openai resource name
url: # Your azure openai api key url: # Your azure openai embedding url, Default is the official embedding url
bedrock: bedrock:
aws_access_key_id: # Your aws_access_key_id credential: # The name in the crendential configuration item
aws_secret_access_key: # Your aws_secret_access_key
cohere: cohere:
api_key: # Your cohere embedding url, Default is the official embedding url credential: # The name in the crendential configuration item
url: # Your cohere api key url: # Your cohere embedding url, Default is the official embedding url
dashscope: dashscope:
api_key: # Your dashscope embedding url, Default is the official embedding url credential: # The name in the crendential configuration item
url: # Your dashscope api key url: # Your dashscope embedding url, Default is the official embedding url
openai: openai:
api_key: # Your openai embedding url, Default is the official embedding url credential: # The name in the crendential configuration item
url: # Your openai api key url: # Your openai embedding url, Default is the official embedding url
siliconflow: siliconflow:
api_key: # Your siliconflow api key credential: # The name in the crendential configuration item
url: # Your siliconflow embedding url, Default is the official embedding url url: # Your siliconflow embedding url, Default is the official embedding url
tei: tei:
credential: # The name in the crendential configuration item
enable: true # Whether to enable TEI model service enable: true # Whether to enable TEI model service
vertexai: vertexai:
credentials_file_path: # Path to your google application credentials, change the file path to refresh the configuration credential: # The name in the crendential configuration item
url: # Your VertexAI embedding url url: # Your VertexAI embedding url

View File

@ -437,6 +437,13 @@ func (s *SchedulerSuite) TestScheduler_ImportFile() {
} }
func (s *SchedulerSuite) TestScheduler_ImportFileWithFunction() { func (s *SchedulerSuite) TestScheduler_ImportFileWithFunction() {
paramtable.Init()
paramtable.Get().CredentialCfg.Credential.GetFunc = func() map[string]string {
return map[string]string{
"mock.apikey": "mock",
}
}
s.syncMgr.EXPECT().SyncData(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, task syncmgr.Task, callbacks ...func(error) error) (*conc.Future[struct{}], error) { s.syncMgr.EXPECT().SyncData(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, task syncmgr.Task, callbacks ...func(error) error) (*conc.Future[struct{}], error) {
future := conc.Go(func() (struct{}, error) { future := conc.Go(func() (struct{}, error) {
return struct{}{}, nil return struct{}{}, nil
@ -445,6 +452,11 @@ func (s *SchedulerSuite) TestScheduler_ImportFileWithFunction() {
}) })
ts := function.CreateOpenAIEmbeddingServer() ts := function.CreateOpenAIEmbeddingServer()
defer ts.Close() defer ts.Close()
paramtable.Get().FunctionCfg.TextEmbeddingProviders.GetFunc = func() map[string]string {
return map[string]string{
"openai.url": ts.URL,
}
}
schema := &schemapb.CollectionSchema{ schema := &schemapb.CollectionSchema{
Fields: []*schemapb.FieldSchema{ Fields: []*schemapb.FieldSchema{
{ {
@ -484,8 +496,7 @@ func (s *SchedulerSuite) TestScheduler_ImportFileWithFunction() {
Params: []*commonpb.KeyValuePair{ Params: []*commonpb.KeyValuePair{
{Key: "provider", Value: "openai"}, {Key: "provider", Value: "openai"},
{Key: "model_name", Value: "text-embedding-ada-002"}, {Key: "model_name", Value: "text-embedding-ada-002"},
{Key: "api_key", Value: "mock"}, {Key: "credential", Value: "mock"},
{Key: "url", Value: ts.URL},
{Key: "dim", Value: "4"}, {Key: "dim", Value: "4"},
}, },
}, },

View File

@ -314,8 +314,20 @@ func TestMaxInsertSize(t *testing.T) {
} }
func TestInsertTask_Function(t *testing.T) { func TestInsertTask_Function(t *testing.T) {
paramtable.Init()
paramtable.Get().CredentialCfg.Credential.GetFunc = func() map[string]string {
return map[string]string{
"mock.apikey": "mock",
}
}
ts := function.CreateOpenAIEmbeddingServer() ts := function.CreateOpenAIEmbeddingServer()
defer ts.Close() defer ts.Close()
paramtable.Get().FunctionCfg.TextEmbeddingProviders.GetFunc = func() map[string]string {
return map[string]string{
"openai.url": ts.URL,
}
}
data := []*schemapb.FieldData{} data := []*schemapb.FieldData{}
f := schemapb.FieldData{ f := schemapb.FieldData{
Type: schemapb.DataType_VarChar, Type: schemapb.DataType_VarChar,
@ -365,8 +377,7 @@ func TestInsertTask_Function(t *testing.T) {
Params: []*commonpb.KeyValuePair{ Params: []*commonpb.KeyValuePair{
{Key: "provider", Value: "openai"}, {Key: "provider", Value: "openai"},
{Key: "model_name", Value: "text-embedding-ada-002"}, {Key: "model_name", Value: "text-embedding-ada-002"},
{Key: "api_key", Value: "mock"}, {Key: "credential", Value: "mock"},
{Key: "url", Value: ts.URL},
{Key: "dim", Value: "4"}, {Key: "dim", Value: "4"},
}, },
}, },

View File

@ -974,8 +974,19 @@ func TestSearchTask_PreExecute(t *testing.T) {
} }
func TestSearchTask_WithFunctions(t *testing.T) { func TestSearchTask_WithFunctions(t *testing.T) {
paramtable.Init()
paramtable.Get().CredentialCfg.Credential.GetFunc = func() map[string]string {
return map[string]string{
"mock.apikey": "mock",
}
}
ts := function.CreateOpenAIEmbeddingServer() ts := function.CreateOpenAIEmbeddingServer()
defer ts.Close() defer ts.Close()
paramtable.Get().FunctionCfg.TextEmbeddingProviders.GetFunc = func() map[string]string {
return map[string]string{
"openai.url": ts.URL,
}
}
collectionName := "TestSearchTask_function" collectionName := "TestSearchTask_function"
schema := &schemapb.CollectionSchema{ schema := &schemapb.CollectionSchema{
Name: collectionName, Name: collectionName,
@ -1016,8 +1027,7 @@ func TestSearchTask_WithFunctions(t *testing.T) {
Params: []*commonpb.KeyValuePair{ Params: []*commonpb.KeyValuePair{
{Key: "provider", Value: "openai"}, {Key: "provider", Value: "openai"},
{Key: "model_name", Value: "text-embedding-ada-002"}, {Key: "model_name", Value: "text-embedding-ada-002"},
{Key: "api_key", Value: "mock"}, {Key: "credential", Value: "mock"},
{Key: "url", Value: ts.URL},
{Key: "dim", Value: "4"}, {Key: "dim", Value: "4"},
}, },
}, },
@ -1031,8 +1041,7 @@ func TestSearchTask_WithFunctions(t *testing.T) {
Params: []*commonpb.KeyValuePair{ Params: []*commonpb.KeyValuePair{
{Key: "provider", Value: "openai"}, {Key: "provider", Value: "openai"},
{Key: "model_name", Value: "text-embedding-ada-002"}, {Key: "model_name", Value: "text-embedding-ada-002"},
{Key: "api_key", Value: "mock"}, {Key: "credential", Value: "mock"},
{Key: "url", Value: ts.URL},
{Key: "dim", Value: "4"}, {Key: "dim", Value: "4"},
}, },
}, },

View File

@ -1299,8 +1299,19 @@ func TestCreateCollectionTask(t *testing.T) {
}) })
t.Run("collection with embedding function ", func(t *testing.T) { t.Run("collection with embedding function ", func(t *testing.T) {
paramtable.Init()
paramtable.Get().CredentialCfg.Credential.GetFunc = func() map[string]string {
return map[string]string{
"mock.apikey": "mock",
}
}
ts := function.CreateOpenAIEmbeddingServer() ts := function.CreateOpenAIEmbeddingServer()
defer ts.Close() defer ts.Close()
paramtable.Get().FunctionCfg.TextEmbeddingProviders.GetFunc = func() map[string]string {
return map[string]string{
"openai.url": ts.URL,
}
}
schema.Functions = []*schemapb.FunctionSchema{ schema.Functions = []*schemapb.FunctionSchema{
{ {
Name: "test", Name: "test",
@ -1310,8 +1321,7 @@ func TestCreateCollectionTask(t *testing.T) {
Params: []*commonpb.KeyValuePair{ Params: []*commonpb.KeyValuePair{
{Key: "provider", Value: "openai"}, {Key: "provider", Value: "openai"},
{Key: "model_name", Value: "text-embedding-ada-002"}, {Key: "model_name", Value: "text-embedding-ada-002"},
{Key: "api_key", Value: "mock"}, {Key: "credential", Value: "mock"},
{Key: "url", Value: ts.URL},
{Key: "dim", Value: "128"}, {Key: "dim", Value: "128"},
}, },
}, },

View File

@ -34,6 +34,7 @@ import (
"github.com/milvus-io/milvus/pkg/v2/proto/rootcoordpb" "github.com/milvus-io/milvus/pkg/v2/proto/rootcoordpb"
"github.com/milvus-io/milvus/pkg/v2/util/commonpbutil" "github.com/milvus-io/milvus/pkg/v2/util/commonpbutil"
"github.com/milvus-io/milvus/pkg/v2/util/merr" "github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
"github.com/milvus-io/milvus/pkg/v2/util/testutils" "github.com/milvus-io/milvus/pkg/v2/util/testutils"
) )
@ -367,8 +368,20 @@ func TestUpsertTaskForReplicate(t *testing.T) {
} }
func TestUpsertTask_Function(t *testing.T) { func TestUpsertTask_Function(t *testing.T) {
paramtable.Init()
paramtable.Get().CredentialCfg.Credential.GetFunc = func() map[string]string {
return map[string]string{
"mock.apikey": "mock",
}
}
ts := function.CreateOpenAIEmbeddingServer() ts := function.CreateOpenAIEmbeddingServer()
defer ts.Close() defer ts.Close()
paramtable.Get().FunctionCfg.TextEmbeddingProviders.GetFunc = func() map[string]string {
return map[string]string{
"openai.url": ts.URL,
}
}
data := []*schemapb.FieldData{} data := []*schemapb.FieldData{}
f1 := schemapb.FieldData{ f1 := schemapb.FieldData{
Type: schemapb.DataType_Int64, Type: schemapb.DataType_Int64,
@ -434,8 +447,7 @@ func TestUpsertTask_Function(t *testing.T) {
Params: []*commonpb.KeyValuePair{ Params: []*commonpb.KeyValuePair{
{Key: "provider", Value: "openai"}, {Key: "provider", Value: "openai"},
{Key: "model_name", Value: "text-embedding-ada-002"}, {Key: "model_name", Value: "text-embedding-ada-002"},
{Key: "api_key", Value: "mock"}, {Key: "credential", Value: "mock"},
{Key: "url", Value: ts.URL},
{Key: "dim", Value: "4"}, {Key: "dim", Value: "4"},
}, },
}, },

View File

@ -2851,8 +2851,19 @@ func TestValidateFunction(t *testing.T) {
func TestValidateModelFunction(t *testing.T) { func TestValidateModelFunction(t *testing.T) {
t.Run("Valid model function schema", func(t *testing.T) { t.Run("Valid model function schema", func(t *testing.T) {
paramtable.Init()
paramtable.Get().CredentialCfg.Credential.GetFunc = func() map[string]string {
return map[string]string{
"mock.apikey": "mock",
}
}
ts := function.CreateOpenAIEmbeddingServer() ts := function.CreateOpenAIEmbeddingServer()
defer ts.Close() defer ts.Close()
paramtable.Get().FunctionCfg.TextEmbeddingProviders.GetFunc = func() map[string]string {
return map[string]string{
"openai.url": ts.URL,
}
}
schema := &schemapb.CollectionSchema{ schema := &schemapb.CollectionSchema{
Fields: []*schemapb.FieldSchema{ Fields: []*schemapb.FieldSchema{
{Name: "input_field", DataType: schemapb.DataType_VarChar, TypeParams: []*commonpb.KeyValuePair{{Key: "enable_analyzer", Value: "true"}}}, {Name: "input_field", DataType: schemapb.DataType_VarChar, TypeParams: []*commonpb.KeyValuePair{{Key: "enable_analyzer", Value: "true"}}},
@ -2879,8 +2890,7 @@ func TestValidateModelFunction(t *testing.T) {
Params: []*commonpb.KeyValuePair{ Params: []*commonpb.KeyValuePair{
{Key: "provider", Value: "openai"}, {Key: "provider", Value: "openai"},
{Key: "model_name", Value: "text-embedding-ada-002"}, {Key: "model_name", Value: "text-embedding-ada-002"},
{Key: "api_key", Value: "mock"}, {Key: "credential", Value: "mock"},
{Key: "url", Value: ts.URL},
{Key: "dim", Value: "4"}, {Key: "dim", Value: "4"},
}, },
}, },

View File

@ -0,0 +1,84 @@
/*
* # Licensed to the LF AI & Data foundation under one
* # or more contributor license agreements. See the NOTICE file
* # distributed with this work for additional information
* # regarding copyright ownership. The ASF licenses this file
* # to you under the Apache License, Version 2.0 (the
* # "License"); you may not use this file except in compliance
* # with the License. You may obtain a copy of the License at
* #
* # http://www.apache.org/licenses/LICENSE-2.0
* #
* # Unless required by applicable law or agreed to in writing, software
* # distributed under the License is distributed on an "AS IS" BASIS,
* # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* # See the License for the specific language governing permissions and
* # limitations under the License.
*/
package credentials
import (
"encoding/base64"
"fmt"
)
const (
APIKey string = "apikey"
AccessKeyId string = "access_key_id"
SecretAccessKey string = "secret_access_key"
// #nosec G101
CredentialJSON string = "credential_json"
)
// The current version only supports plain text, and cipher text will be supported later.
type CredentialsManager struct {
// key formats:
// {credentialName}.api_key
// {credentialName}.access_key_id
// {credentialName}.secret_access_key
// {credentialName}.credential_json
confMap map[string]string
}
func NewCredentialsManager(conf map[string]string) *CredentialsManager {
return &CredentialsManager{conf}
}
func (c *CredentialsManager) GetAPIKeyCredential(name string) (string, error) {
k := name + "." + APIKey
apikey, exist := c.confMap[k]
if !exist {
return "", fmt.Errorf("%s is not a apikey crediential, can not find key: %s", name, k)
}
return apikey, nil
}
func (c *CredentialsManager) GetAKSKCredential(name string) (string, string, error) {
IdKey := name + "." + AccessKeyId
accessKeyId, exist := c.confMap[IdKey]
if !exist {
return "", "", fmt.Errorf("%s is not a aksk crediential, can not find key: %s", name, IdKey)
}
AccessKey := name + "." + SecretAccessKey
secretAccessKey, exist := c.confMap[AccessKey]
if !exist {
return "", "", fmt.Errorf("%s is not a aksk crediential, can not find key: %s", name, AccessKey)
}
return accessKeyId, secretAccessKey, nil
}
func (c *CredentialsManager) GetGcpCredential(name string) ([]byte, error) {
k := name + "." + CredentialJSON
jsonByte, exist := c.confMap[k]
if !exist {
return nil, fmt.Errorf("%s is not a gcp crediential, can not find key: %s ", name, k)
}
decode, err := base64.StdEncoding.DecodeString(jsonByte)
if err != nil {
return nil, fmt.Errorf("Parse gcp credential:%s faild, err: %s", name, err)
}
return decode, nil
}

View File

@ -23,6 +23,7 @@ import (
"strings" "strings"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/util/credentials"
"github.com/milvus-io/milvus/internal/util/function/models/ali" "github.com/milvus-io/milvus/internal/util/function/models/ali"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil" "github.com/milvus-io/milvus/pkg/v2/util/typeutil"
) )
@ -41,7 +42,7 @@ type AliEmbeddingProvider struct {
func createAliEmbeddingClient(apiKey string, url string) (*ali.AliDashScopeEmbedding, error) { func createAliEmbeddingClient(apiKey string, url string) (*ali.AliDashScopeEmbedding, error) {
if apiKey == "" { if apiKey == "" {
return nil, fmt.Errorf("Missing credentials. Please pass `api_key`, or configure the %s environment variable in the Milvus service.", dashscopeAKEnvStr) return nil, fmt.Errorf("Missing credentials config or configure the %s environment variable in the Milvus service.", dashscopeAKEnvStr)
} }
if url == "" { if url == "" {
@ -51,12 +52,15 @@ func createAliEmbeddingClient(apiKey string, url string) (*ali.AliDashScopeEmbed
return c, nil return c, nil
} }
func NewAliDashScopeEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema, params map[string]string) (*AliEmbeddingProvider, error) { func NewAliDashScopeEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema, params map[string]string, credentials *credentials.CredentialsManager) (*AliEmbeddingProvider, error) {
fieldDim, err := typeutil.GetDim(fieldSchema) fieldDim, err := typeutil.GetDim(fieldSchema)
if err != nil { if err != nil {
return nil, err return nil, err
} }
apiKey, url := parseAKAndURL(functionSchema.Params, params, dashscopeAKEnvStr) apiKey, url, err := parseAKAndURL(credentials, functionSchema.Params, params, dashscopeAKEnvStr)
if err != nil {
return nil, err
}
var modelName string var modelName string
var dim int64 var dim int64

View File

@ -29,6 +29,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/util/credentials"
"github.com/milvus-io/milvus/internal/util/function/models/ali" "github.com/milvus-io/milvus/internal/util/function/models/ali"
) )
@ -69,14 +70,13 @@ func createAliProvider(url string, schema *schemapb.FieldSchema, providerName st
OutputFieldIds: []int64{102}, OutputFieldIds: []int64{102},
Params: []*commonpb.KeyValuePair{ Params: []*commonpb.KeyValuePair{
{Key: modelNameParamKey, Value: TestModel}, {Key: modelNameParamKey, Value: TestModel},
{Key: embeddingURLParamKey, Value: url}, {Key: credentialParamKey, Value: "mock"},
{Key: apiKeyParamKey, Value: "mock"},
{Key: dimParamKey, Value: "4"}, {Key: dimParamKey, Value: "4"},
}, },
} }
switch providerName { switch providerName {
case aliDashScopeProvider: case aliDashScopeProvider:
return NewAliDashScopeEmbeddingProvider(schema, functionSchema, map[string]string{}) return NewAliDashScopeEmbeddingProvider(schema, functionSchema, map[string]string{embeddingURLParamKey: url}, credentials.NewCredentialsManager(map[string]string{"mock.apikey": "mock"}))
default: default:
return nil, errors.New("Unknow provider") return nil, errors.New("Unknow provider")
} }
@ -181,12 +181,11 @@ func (s *AliTextEmbeddingProviderSuite) TestNewAliDashScopeEmbeddingProvider() {
OutputFieldIds: []int64{102}, OutputFieldIds: []int64{102},
Params: []*commonpb.KeyValuePair{ Params: []*commonpb.KeyValuePair{
{Key: modelNameParamKey, Value: TestModel}, {Key: modelNameParamKey, Value: TestModel},
{Key: apiKeyParamKey, Value: "mock"},
{Key: dimParamKey, Value: "4"}, {Key: dimParamKey, Value: "4"},
}, },
} }
// invalid dim // invalid dim
functionSchema.Params[2] = &commonpb.KeyValuePair{Key: dimParamKey, Value: "Invalid"} functionSchema.Params[1] = &commonpb.KeyValuePair{Key: dimParamKey, Value: "Invalid"}
_, err := NewAliDashScopeEmbeddingProvider(s.schema.Fields[2], functionSchema, map[string]string{}) _, err := NewAliDashScopeEmbeddingProvider(s.schema.Fields[2], functionSchema, map[string]string{}, credentials.NewCredentialsManager(map[string]string{"mock.apikey": "mock"}))
s.Error(err) s.Error(err)
} }

View File

@ -33,6 +33,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
milvusCredentials "github.com/milvus-io/milvus/internal/util/credentials"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil" "github.com/milvus-io/milvus/pkg/v2/util/typeutil"
) )
@ -54,10 +55,10 @@ type BedrockEmbeddingProvider struct {
func createBedRockEmbeddingClient(awsAccessKeyId string, awsSecretAccessKey string, region string) (*bedrockruntime.Client, error) { func createBedRockEmbeddingClient(awsAccessKeyId string, awsSecretAccessKey string, region string) (*bedrockruntime.Client, error) {
if awsAccessKeyId == "" { if awsAccessKeyId == "" {
return nil, fmt.Errorf("Missing credentials. Please pass `aws_access_key_id`, or configure the %s environment variable in the Milvus service.", bedrockAccessKeyId) return nil, fmt.Errorf("Missing credentials config or configure the %s environment variable in the Milvus service.", bedrockAccessKeyId)
} }
if awsSecretAccessKey == "" { if awsSecretAccessKey == "" {
return nil, fmt.Errorf("Missing credentials. Please pass `aws_secret_access_key`, or configure the %s environment variable in the Milvus service.", bedrockSAKEnvStr) return nil, fmt.Errorf("Missing credentials config or configure the %s environment variable in the Milvus service.", bedrockSAKEnvStr)
} }
if region == "" { if region == "" {
return nil, errors.New("Missing AWS Service region. Please pass `region` param.") return nil, errors.New("Missing AWS Service region. Please pass `region` param.")
@ -74,28 +75,29 @@ func createBedRockEmbeddingClient(awsAccessKeyId string, awsSecretAccessKey stri
return bedrockruntime.NewFromConfig(cfg), nil return bedrockruntime.NewFromConfig(cfg), nil
} }
func parseAccessInfo(params []*commonpb.KeyValuePair, confParams map[string]string) (string, string) { func parseAKSKInfo(credentials *milvusCredentials.CredentialsManager, params []*commonpb.KeyValuePair, confParams map[string]string) (string, string, error) {
// function param > env > yaml // function param > yaml > env
var awsAccessKeyId, awsSecretAccessKey string var awsAccessKeyId, awsSecretAccessKey string
var err error
// from function params
if isEnableVerifiInfoInParamsKey(confParams) {
for _, param := range params { for _, param := range params {
switch strings.ToLower(param.Key) { switch strings.ToLower(param.Key) {
case awsAKIdParamKey: case credentialParamKey:
awsAccessKeyId = param.Value credentialName := param.Value
case awsSAKParamKey: if awsAccessKeyId, awsSecretAccessKey, err = credentials.GetAKSKCredential(credentialName); err != nil {
awsSecretAccessKey = param.Value return "", "", err
} }
} }
} }
// from milvus.yaml // from milvus.yaml
if awsAccessKeyId == "" { if awsAccessKeyId == "" && awsSecretAccessKey == "" {
awsAccessKeyId = confParams[awsAKIdParamKey] credentialName := confParams[credentialParamKey]
if credentialName != "" {
if awsAccessKeyId, awsSecretAccessKey, err = credentials.GetAKSKCredential(credentialName); err != nil {
return "", "", err
}
} }
if awsSecretAccessKey == "" {
awsSecretAccessKey = confParams[awsSAKParamKey]
} }
// from env // from env
@ -106,10 +108,10 @@ func parseAccessInfo(params []*commonpb.KeyValuePair, confParams map[string]stri
awsSecretAccessKey = os.Getenv(bedrockSAKEnvStr) awsSecretAccessKey = os.Getenv(bedrockSAKEnvStr)
} }
return awsAccessKeyId, awsSecretAccessKey return awsAccessKeyId, awsSecretAccessKey, nil
} }
func NewBedrockEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema, c BedrockClient, params map[string]string) (*BedrockEmbeddingProvider, error) { func NewBedrockEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema, c BedrockClient, params map[string]string, credentials *milvusCredentials.CredentialsManager) (*BedrockEmbeddingProvider, error) {
fieldDim, err := typeutil.GetDim(fieldSchema) fieldDim, err := typeutil.GetDim(fieldSchema)
if err != nil { if err != nil {
return nil, err return nil, err
@ -142,7 +144,10 @@ func NewBedrockEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSche
} }
} }
awsAccessKeyId, awsSecretAccessKey := parseAccessInfo(functionSchema.Params, params) awsAccessKeyId, awsSecretAccessKey, err := parseAKSKInfo(credentials, functionSchema.Params, params)
if err != nil {
return nil, err
}
var client BedrockClient var client BedrockClient
if c == nil { if c == nil {

View File

@ -19,6 +19,7 @@
package function package function
import ( import (
"os"
"testing" "testing"
"github.com/cockroachdb/errors" "github.com/cockroachdb/errors"
@ -26,6 +27,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/util/credentials"
) )
func TestBedrockTextEmbeddingProvider(t *testing.T) { func TestBedrockTextEmbeddingProvider(t *testing.T) {
@ -65,13 +67,12 @@ func createBedrockProvider(schema *schemapb.FieldSchema, providerName string, di
OutputFieldIds: []int64{102}, OutputFieldIds: []int64{102},
Params: []*commonpb.KeyValuePair{ Params: []*commonpb.KeyValuePair{
{Key: modelNameParamKey, Value: TestModel}, {Key: modelNameParamKey, Value: TestModel},
{Key: apiKeyParamKey, Value: "mock"},
{Key: dimParamKey, Value: "4"}, {Key: dimParamKey, Value: "4"},
}, },
} }
switch providerName { switch providerName {
case bedrockProvider: case bedrockProvider:
return NewBedrockEmbeddingProvider(schema, functionSchema, &MockBedrockClient{dim: dim}, map[string]string{}) return NewBedrockEmbeddingProvider(schema, functionSchema, &MockBedrockClient{dim: dim}, map[string]string{}, credentials.NewCredentialsManager(map[string]string{"mock.access_key_id": "mock", "mock.secret_access_key": "mock"}))
default: default:
return nil, errors.New("Unknow provider") return nil, errors.New("Unknow provider")
} }
@ -110,6 +111,38 @@ func (s *BedrockTextEmbeddingProviderSuite) TestEmbeddingDimNotMatch() {
} }
} }
func (s *BedrockTextEmbeddingProviderSuite) TestParseCredentail() {
{
cred := credentials.NewCredentialsManager(map[string]string{})
ak, sk, err := parseAKSKInfo(cred, []*commonpb.KeyValuePair{}, map[string]string{})
s.Equal(ak, "")
s.Equal(sk, "")
s.NoError(err)
}
{
cred := credentials.NewCredentialsManager(map[string]string{})
_, _, err := parseAKSKInfo(cred, []*commonpb.KeyValuePair{}, map[string]string{"credential": "NotExist"})
s.ErrorContains(err, "is not a aksk crediential, can not find key")
}
{
cred := credentials.NewCredentialsManager(map[string]string{"mock.apikey": "mock"})
_, _, err := parseAKSKInfo(cred, []*commonpb.KeyValuePair{}, map[string]string{"credential": "mock"})
s.ErrorContains(err, "is not a aksk crediential, can not find key")
}
{
cred := credentials.NewCredentialsManager(map[string]string{"mock.access_key_id": "mock", "mock.secret_access_key": "mock"})
_, _, err := parseAKSKInfo(cred, []*commonpb.KeyValuePair{}, map[string]string{"credential": "mock"})
s.NoError(err)
}
{
os.Setenv(bedrockAccessKeyId, "mock")
os.Setenv(bedrockSAKEnvStr, "mock")
cred := credentials.NewCredentialsManager(map[string]string{})
_, _, err := parseAKSKInfo(cred, []*commonpb.KeyValuePair{}, map[string]string{})
s.NoError(err)
}
}
func (s *BedrockTextEmbeddingProviderSuite) TestCreateBedrockClient() { func (s *BedrockTextEmbeddingProviderSuite) TestCreateBedrockClient() {
_, err := createBedRockEmbeddingClient("", "", "") _, err := createBedRockEmbeddingClient("", "", "")
s.Error(err) s.Error(err)
@ -144,32 +177,31 @@ func (s *BedrockTextEmbeddingProviderSuite) TestNewBedrockEmbeddingProvider() {
OutputFieldIds: []int64{102}, OutputFieldIds: []int64{102},
Params: []*commonpb.KeyValuePair{ Params: []*commonpb.KeyValuePair{
{Key: modelNameParamKey, Value: TestModel}, {Key: modelNameParamKey, Value: TestModel},
{Key: awsAKIdParamKey, Value: "mock"}, {Key: credentialParamKey, Value: "mock"},
{Key: awsSAKParamKey, Value: "mock"},
{Key: regionParamKey, Value: "mock"}, {Key: regionParamKey, Value: "mock"},
{Key: dimParamKey, Value: "4"}, {Key: dimParamKey, Value: "4"},
{Key: normalizeParamKey, Value: "false"}, {Key: normalizeParamKey, Value: "false"},
}, },
} }
provider, err := NewBedrockEmbeddingProvider(fieldSchema, functionSchema, nil, map[string]string{}) provider, err := NewBedrockEmbeddingProvider(fieldSchema, functionSchema, nil, map[string]string{}, credentials.NewCredentialsManager(map[string]string{"mock.access_key_id": "mock", "mock.secret_access_key": "mock"}))
s.NoError(err) s.NoError(err)
s.True(provider.MaxBatch() > 0) s.True(provider.MaxBatch() > 0)
s.Equal(provider.FieldDim(), int64(4)) s.Equal(provider.FieldDim(), int64(4))
_, err = NewBedrockEmbeddingProvider(fieldSchema, functionSchema, nil, map[string]string{awsAKIdParamKey: "mock", awsSAKParamKey: "mock"}) _, err = NewBedrockEmbeddingProvider(fieldSchema, functionSchema, nil, map[string]string{credentialParamKey: "mock"}, credentials.NewCredentialsManager(map[string]string{"mock.access_key_id": "mock", "mock.secret_access_key": "mock"}))
s.NoError(err) s.NoError(err)
functionSchema.Params[5] = &commonpb.KeyValuePair{Key: normalizeParamKey, Value: "true"} functionSchema.Params[4] = &commonpb.KeyValuePair{Key: normalizeParamKey, Value: "true"}
_, err = NewBedrockEmbeddingProvider(fieldSchema, functionSchema, nil, map[string]string{}) _, err = NewBedrockEmbeddingProvider(fieldSchema, functionSchema, nil, map[string]string{}, credentials.NewCredentialsManager(map[string]string{"mock.access_key_id": "mock", "mock.secret_access_key": "mock"}))
s.NoError(err) s.NoError(err)
functionSchema.Params[5] = &commonpb.KeyValuePair{Key: normalizeParamKey, Value: "invalid"} functionSchema.Params[4] = &commonpb.KeyValuePair{Key: normalizeParamKey, Value: "invalid"}
_, err = NewBedrockEmbeddingProvider(fieldSchema, functionSchema, nil, map[string]string{}) _, err = NewBedrockEmbeddingProvider(fieldSchema, functionSchema, nil, map[string]string{}, credentials.NewCredentialsManager(map[string]string{"mock.access_key_id": "mock", "mock.secret_access_key": "mock"}))
s.Error(err) s.Error(err)
// invalid dim // invalid dim
functionSchema.Params[0] = &commonpb.KeyValuePair{Key: modelNameParamKey, Value: TestModel} functionSchema.Params[0] = &commonpb.KeyValuePair{Key: modelNameParamKey, Value: TestModel}
functionSchema.Params[0] = &commonpb.KeyValuePair{Key: dimParamKey, Value: "Invalid"} functionSchema.Params[0] = &commonpb.KeyValuePair{Key: dimParamKey, Value: "Invalid"}
_, err = NewBedrockEmbeddingProvider(fieldSchema, functionSchema, nil, map[string]string{}) _, err = NewBedrockEmbeddingProvider(fieldSchema, functionSchema, nil, map[string]string{}, credentials.NewCredentialsManager(map[string]string{"mock.access_key_id": "mock", "mock.secret_access_key": "mock"}))
s.Error(err) s.Error(err)
} }

View File

@ -23,6 +23,7 @@ import (
"strings" "strings"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/util/credentials"
"github.com/milvus-io/milvus/internal/util/function/models/cohere" "github.com/milvus-io/milvus/internal/util/function/models/cohere"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil" "github.com/milvus-io/milvus/pkg/v2/util/typeutil"
) )
@ -42,7 +43,7 @@ type CohereEmbeddingProvider struct {
func createCohereEmbeddingClient(apiKey string, url string) (*cohere.CohereEmbedding, error) { func createCohereEmbeddingClient(apiKey string, url string) (*cohere.CohereEmbedding, error) {
if apiKey == "" { if apiKey == "" {
return nil, fmt.Errorf("Missing credentials. Please pass `api_key`, or configure the %s environment variable in the Milvus service.", cohereAIAKEnvStr) return nil, fmt.Errorf("Missing credentials config or configure the %s environment variable in the Milvus service.", cohereAIAKEnvStr)
} }
if url == "" { if url == "" {
@ -53,12 +54,15 @@ func createCohereEmbeddingClient(apiKey string, url string) (*cohere.CohereEmbed
return c, nil return c, nil
} }
func NewCohereEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema, params map[string]string) (*CohereEmbeddingProvider, error) { func NewCohereEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema, params map[string]string, credentials *credentials.CredentialsManager) (*CohereEmbeddingProvider, error) {
fieldDim, err := typeutil.GetDim(fieldSchema) fieldDim, err := typeutil.GetDim(fieldSchema)
if err != nil { if err != nil {
return nil, err return nil, err
} }
apiKey, url := parseAKAndURL(functionSchema.Params, params, cohereAIAKEnvStr) apiKey, url, err := parseAKAndURL(credentials, functionSchema.Params, params, cohereAIAKEnvStr)
if err != nil {
return nil, err
}
var modelName string var modelName string
truncate := "END" truncate := "END"
for _, param := range functionSchema.Params { for _, param := range functionSchema.Params {

View File

@ -29,7 +29,9 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/util/credentials"
"github.com/milvus-io/milvus/internal/util/function/models/cohere" "github.com/milvus-io/milvus/internal/util/function/models/cohere"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
) )
func TestCohereTextEmbeddingProvider(t *testing.T) { func TestCohereTextEmbeddingProvider(t *testing.T) {
@ -43,6 +45,12 @@ type CohereTextEmbeddingProviderSuite struct {
} }
func (s *CohereTextEmbeddingProviderSuite) SetupTest() { func (s *CohereTextEmbeddingProviderSuite) SetupTest() {
paramtable.Init()
paramtable.Get().CredentialCfg.Credential.GetFunc = func() map[string]string {
return map[string]string{
"mock.apikey": "mock",
}
}
s.schema = &schemapb.CollectionSchema{ s.schema = &schemapb.CollectionSchema{
Name: "test", Name: "test",
Fields: []*schemapb.FieldSchema{ Fields: []*schemapb.FieldSchema{
@ -69,13 +77,12 @@ func createCohereProvider(url string, schema *schemapb.FieldSchema, providerName
OutputFieldIds: []int64{102}, OutputFieldIds: []int64{102},
Params: []*commonpb.KeyValuePair{ Params: []*commonpb.KeyValuePair{
{Key: modelNameParamKey, Value: TestModel}, {Key: modelNameParamKey, Value: TestModel},
{Key: apiKeyParamKey, Value: "mock"}, {Key: credentialParamKey, Value: "mock"},
{Key: embeddingURLParamKey, Value: url},
}, },
} }
switch providerName { switch providerName {
case cohereProvider: case cohereProvider:
return NewCohereEmbeddingProvider(schema, functionSchema, map[string]string{}) return NewCohereEmbeddingProvider(schema, functionSchema, map[string]string{embeddingURLParamKey: url}, credentials.NewCredentialsManager(map[string]string{"mock.apikey": "mock"}))
default: default:
return nil, errors.New("Unknow provider") return nil, errors.New("Unknow provider")
} }
@ -259,22 +266,22 @@ func (s *CohereTextEmbeddingProviderSuite) TestNewCohereProvider() {
OutputFieldIds: []int64{102}, OutputFieldIds: []int64{102},
Params: []*commonpb.KeyValuePair{ Params: []*commonpb.KeyValuePair{
{Key: modelNameParamKey, Value: TestModel}, {Key: modelNameParamKey, Value: TestModel},
{Key: apiKeyParamKey, Value: "mock"}, {Key: credentialParamKey, Value: "mock"},
}, },
} }
provider, err := NewCohereEmbeddingProvider(s.schema.Fields[2], functionSchema, map[string]string{}) provider, err := NewCohereEmbeddingProvider(s.schema.Fields[2], functionSchema, map[string]string{}, credentials.NewCredentialsManager(map[string]string{"mock.apikey": "mock"}))
s.NoError(err) s.NoError(err)
s.Equal(provider.truncate, "END") s.Equal(provider.truncate, "END")
functionSchema.Params = append(functionSchema.Params, &commonpb.KeyValuePair{Key: truncateParamKey, Value: "START"}) functionSchema.Params = append(functionSchema.Params, &commonpb.KeyValuePair{Key: truncateParamKey, Value: "START"})
provider, err = NewCohereEmbeddingProvider(s.schema.Fields[2], functionSchema, map[string]string{}) provider, err = NewCohereEmbeddingProvider(s.schema.Fields[2], functionSchema, map[string]string{}, credentials.NewCredentialsManager(map[string]string{"mock.apikey": "mock"}))
s.NoError(err) s.NoError(err)
s.Equal(provider.truncate, "START") s.Equal(provider.truncate, "START")
// Invalid truncateParam // Invalid truncateParam
functionSchema.Params[2].Value = "Unknow" functionSchema.Params[2].Value = "Unknow"
_, err = NewCohereEmbeddingProvider(s.schema.Fields[2], functionSchema, map[string]string{}) _, err = NewCohereEmbeddingProvider(s.schema.Fields[2], functionSchema, map[string]string{}, credentials.NewCredentialsManager(map[string]string{"mock.apikey": "mock"}))
s.Error(err) s.Error(err)
} }
@ -288,17 +295,17 @@ func (s *CohereTextEmbeddingProviderSuite) TestGetInputType() {
OutputFieldIds: []int64{102}, OutputFieldIds: []int64{102},
Params: []*commonpb.KeyValuePair{ Params: []*commonpb.KeyValuePair{
{Key: modelNameParamKey, Value: "model-v2.0"}, {Key: modelNameParamKey, Value: "model-v2.0"},
{Key: apiKeyParamKey, Value: "mock"}, {Key: credentialParamKey, Value: "mock"},
}, },
} }
provider, err := NewCohereEmbeddingProvider(s.schema.Fields[2], functionSchema, map[string]string{}) provider, err := NewCohereEmbeddingProvider(s.schema.Fields[2], functionSchema, map[string]string{}, credentials.NewCredentialsManager(map[string]string{"mock.apikey": "mock"}))
s.NoError(err) s.NoError(err)
s.Equal(provider.getInputType(InsertMode), "") s.Equal(provider.getInputType(InsertMode), "")
s.Equal(provider.getInputType(SearchMode), "") s.Equal(provider.getInputType(SearchMode), "")
functionSchema.Params[0].Value = "model-v3.0" functionSchema.Params[0].Value = "model-v3.0"
provider, err = NewCohereEmbeddingProvider(s.schema.Fields[2], functionSchema, map[string]string{}) provider, err = NewCohereEmbeddingProvider(s.schema.Fields[2], functionSchema, map[string]string{}, credentials.NewCredentialsManager(map[string]string{"mock.apikey": "mock"}))
s.NoError(err) s.NoError(err)
s.Equal(provider.getInputType(InsertMode), "search_document") s.Equal(provider.getInputType(InsertMode), "search_document")
s.Equal(provider.getInputType(SearchMode), "search_query") s.Equal(provider.getInputType(SearchMode), "search_query")

View File

@ -26,6 +26,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/util/credentials"
) )
type TextEmbeddingMode int type TextEmbeddingMode int
@ -48,9 +49,8 @@ const (
modelNameParamKey string = "model_name" modelNameParamKey string = "model_name"
dimParamKey string = "dim" dimParamKey string = "dim"
embeddingURLParamKey string = "url" embeddingURLParamKey string = "url"
apiKeyParamKey string = "api_key" credentialParamKey string = "credential"
truncateParamKey string = "truncate" truncateParamKey string = "truncate"
enableVerifiInfoInParamsKey string = "enableVerifiInfoInParams"
) )
// ali text embedding // ali text embedding
@ -72,8 +72,8 @@ const (
// bedrock emebdding // bedrock emebdding
const ( const (
awsAKIdParamKey string = "aws_access_key_id" // awsAKIdParamKey string = "aws_access_key_id"
awsSAKParamKey string = "aws_secret_access_key" // awsSAKParamKey string = "aws_secret_access_key"
regionParamKey string = "region" regionParamKey string = "region"
normalizeParamKey string = "normalize" normalizeParamKey string = "normalize"
@ -121,41 +121,29 @@ const (
enableTeiEnvStr string = "MILVUSAI_ENABLE_TEI" enableTeiEnvStr string = "MILVUSAI_ENABLE_TEI"
) )
const enableVerifiInfoInParams string = "ENABLE_VERIFI_INFO_IN_PARAMS" func parseAKAndURL(credentials *credentials.CredentialsManager, params []*commonpb.KeyValuePair, confParams map[string]string, apiKeyEnv string) (string, string, error) {
// function param > yaml > env
func isEnableVerifiInfoInParamsKey(confParams map[string]string) bool { var err error
enable := true
if strings.ToLower(confParams[enableVerifiInfoInParamsKey]) != "" {
// If enableVerifiInfoInParamsKey is configured in milvus.yaml, the configuration in milvus.yaml will be used.
enable, _ = strconv.ParseBool(confParams[enableVerifiInfoInParamsKey])
} else {
// If enableVerifiInfoInParamsKey is not configured in milvus.yaml, the configuration in env will be used.
if strings.ToLower(os.Getenv(enableVerifiInfoInParams)) != "" {
enable, _ = strconv.ParseBool(confParams[enableVerifiInfoInParamsKey])
}
}
return enable
}
func parseAKAndURL(params []*commonpb.KeyValuePair, confParams map[string]string, apiKeyEnv string) (string, string) {
// function param > env > yaml
var apiKey, url string var apiKey, url string
// from function params
if isEnableVerifiInfoInParamsKey(confParams) {
for _, param := range params { for _, param := range params {
switch strings.ToLower(param.Key) { switch strings.ToLower(param.Key) {
case apiKeyParamKey: case credentialParamKey:
apiKey = param.Value credentialName := param.Value
case embeddingURLParamKey: if apiKey, err = credentials.GetAPIKeyCredential(credentialName); err != nil {
url = param.Value return "", "", err
} }
} }
} }
// from milvus.yaml // from milvus.yaml
if apiKey == "" { if apiKey == "" {
apiKey = confParams[apiKeyParamKey] credentialName := confParams[credentialParamKey]
if credentialName != "" {
if apiKey, err = credentials.GetAPIKeyCredential(credentialName); err != nil {
return "", "", err
}
}
} }
if url == "" { if url == "" {
@ -166,7 +154,7 @@ func parseAKAndURL(params []*commonpb.KeyValuePair, confParams map[string]string
if apiKey == "" { if apiKey == "" {
apiKey = os.Getenv(apiKeyEnv) apiKey = os.Getenv(apiKeyEnv)
} }
return apiKey, url return apiKey, url, nil
} }
func parseAndCheckFieldDim(dimStr string, fieldDim int64, fieldName string) (int64, error) { func parseAndCheckFieldDim(dimStr string, fieldDim int64, fieldName string) (int64, error) {

View File

@ -49,9 +49,21 @@ type FunctionExecutorSuite struct {
func (s *FunctionExecutorSuite) SetupTest() { func (s *FunctionExecutorSuite) SetupTest() {
paramtable.Init() paramtable.Init()
paramtable.Get().CredentialCfg.Credential.GetFunc = func() map[string]string {
return map[string]string{
"mock.apikey": "mock",
}
}
} }
func (s *FunctionExecutorSuite) creataSchema(url string) *schemapb.CollectionSchema { func (s *FunctionExecutorSuite) creataSchema(url string) *schemapb.CollectionSchema {
paramtable.Get().FunctionCfg.TextEmbeddingProviders.GetFunc = func() map[string]string {
key := openAIProvider + "." + embeddingURLParamKey
return map[string]string{
key: url,
}
}
return &schemapb.CollectionSchema{ return &schemapb.CollectionSchema{
Name: "test", Name: "test",
Fields: []*schemapb.FieldSchema{ Fields: []*schemapb.FieldSchema{
@ -83,8 +95,7 @@ func (s *FunctionExecutorSuite) creataSchema(url string) *schemapb.CollectionSch
Params: []*commonpb.KeyValuePair{ Params: []*commonpb.KeyValuePair{
{Key: Provider, Value: openAIProvider}, {Key: Provider, Value: openAIProvider},
{Key: modelNameParamKey, Value: "text-embedding-ada-002"}, {Key: modelNameParamKey, Value: "text-embedding-ada-002"},
{Key: apiKeyParamKey, Value: "mock"}, {Key: credentialParamKey, Value: "mock"},
{Key: embeddingURLParamKey, Value: url},
{Key: dimParamKey, Value: "4"}, {Key: dimParamKey, Value: "4"},
}, },
}, },
@ -98,8 +109,7 @@ func (s *FunctionExecutorSuite) creataSchema(url string) *schemapb.CollectionSch
Params: []*commonpb.KeyValuePair{ Params: []*commonpb.KeyValuePair{
{Key: Provider, Value: openAIProvider}, {Key: Provider, Value: openAIProvider},
{Key: modelNameParamKey, Value: "text-embedding-ada-002"}, {Key: modelNameParamKey, Value: "text-embedding-ada-002"},
{Key: apiKeyParamKey, Value: "mock"}, {Key: credentialParamKey, Value: "mock"},
{Key: embeddingURLParamKey, Value: url},
{Key: dimParamKey, Value: "8"}, {Key: dimParamKey, Value: "8"},
}, },
}, },

View File

@ -24,6 +24,7 @@ import (
"strings" "strings"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/util/credentials"
"github.com/milvus-io/milvus/internal/util/function/models/openai" "github.com/milvus-io/milvus/internal/util/function/models/openai"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil" "github.com/milvus-io/milvus/pkg/v2/util/typeutil"
) )
@ -42,7 +43,7 @@ type OpenAIEmbeddingProvider struct {
func createOpenAIEmbeddingClient(apiKey string, url string) (*openai.OpenAIEmbeddingClient, error) { func createOpenAIEmbeddingClient(apiKey string, url string) (*openai.OpenAIEmbeddingClient, error) {
if apiKey == "" { if apiKey == "" {
return nil, fmt.Errorf("Missing credentials. Please pass `api_key`, or configure the %s environment variable in the Milvus service.", openaiAKEnvStr) return nil, fmt.Errorf("Missing credentials config or configure the %s environment variable in the Milvus service.", openaiAKEnvStr)
} }
if url == "" { if url == "" {
@ -55,7 +56,7 @@ func createOpenAIEmbeddingClient(apiKey string, url string) (*openai.OpenAIEmbed
func createAzureOpenAIEmbeddingClient(apiKey string, url string, resourceName string) (*openai.AzureOpenAIEmbeddingClient, error) { func createAzureOpenAIEmbeddingClient(apiKey string, url string, resourceName string) (*openai.AzureOpenAIEmbeddingClient, error) {
if apiKey == "" { if apiKey == "" {
return nil, fmt.Errorf("Missing credentials. Please pass `api_key`, or configure the %s environment variable in the Milvus service", azureOpenaiAKEnvStr) return nil, fmt.Errorf("Missing credentials config or configure the %s environment variable in the Milvus service", azureOpenaiAKEnvStr)
} }
if url == "" { if url == "" {
@ -73,7 +74,7 @@ func createAzureOpenAIEmbeddingClient(apiKey string, url string, resourceName st
return c, nil return c, nil
} }
func newOpenAIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema, params map[string]string, isAzure bool) (*OpenAIEmbeddingProvider, error) { func newOpenAIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema, params map[string]string, isAzure bool, credentials *credentials.CredentialsManager) (*OpenAIEmbeddingProvider, error) {
fieldDim, err := typeutil.GetDim(fieldSchema) fieldDim, err := typeutil.GetDim(fieldSchema)
if err != nil { if err != nil {
return nil, err return nil, err
@ -98,13 +99,19 @@ func newOpenAIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchem
var c openai.OpenAIEmbeddingInterface var c openai.OpenAIEmbeddingInterface
if !isAzure { if !isAzure {
apiKey, url := parseAKAndURL(functionSchema.Params, params, openaiAKEnvStr) apiKey, url, err := parseAKAndURL(credentials, functionSchema.Params, params, openaiAKEnvStr)
if err != nil {
return nil, err
}
c, err = createOpenAIEmbeddingClient(apiKey, url) c, err = createOpenAIEmbeddingClient(apiKey, url)
if err != nil { if err != nil {
return nil, err return nil, err
} }
} else { } else {
apiKey, url := parseAKAndURL(functionSchema.Params, params, azureOpenaiAKEnvStr) apiKey, url, err := parseAKAndURL(credentials, functionSchema.Params, params, azureOpenaiAKEnvStr)
if err != nil {
return nil, err
}
resourceName := params["resource_name"] resourceName := params["resource_name"]
c, err = createAzureOpenAIEmbeddingClient(apiKey, url, resourceName) c, err = createAzureOpenAIEmbeddingClient(apiKey, url, resourceName)
if err != nil { if err != nil {
@ -124,12 +131,12 @@ func newOpenAIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchem
return &provider, nil return &provider, nil
} }
func NewOpenAIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema, params map[string]string) (*OpenAIEmbeddingProvider, error) { func NewOpenAIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema, params map[string]string, credentials *credentials.CredentialsManager) (*OpenAIEmbeddingProvider, error) {
return newOpenAIEmbeddingProvider(fieldSchema, functionSchema, params, false) return newOpenAIEmbeddingProvider(fieldSchema, functionSchema, params, false, credentials)
} }
func NewAzureOpenAIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema, params map[string]string) (*OpenAIEmbeddingProvider, error) { func NewAzureOpenAIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema, params map[string]string, credentials *credentials.CredentialsManager) (*OpenAIEmbeddingProvider, error) {
return newOpenAIEmbeddingProvider(fieldSchema, functionSchema, params, true) return newOpenAIEmbeddingProvider(fieldSchema, functionSchema, params, true, credentials)
} }
func (provider *OpenAIEmbeddingProvider) MaxBatch() int { func (provider *OpenAIEmbeddingProvider) MaxBatch() int {

View File

@ -30,6 +30,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/util/credentials"
"github.com/milvus-io/milvus/internal/util/function/models/openai" "github.com/milvus-io/milvus/internal/util/function/models/openai"
) )
@ -70,16 +71,15 @@ func createOpenAIProvider(url string, schema *schemapb.FieldSchema, providerName
OutputFieldIds: []int64{102}, OutputFieldIds: []int64{102},
Params: []*commonpb.KeyValuePair{ Params: []*commonpb.KeyValuePair{
{Key: modelNameParamKey, Value: "text-embedding-ada-002"}, {Key: modelNameParamKey, Value: "text-embedding-ada-002"},
{Key: apiKeyParamKey, Value: "mock"}, {Key: credentialParamKey, Value: "mock"},
{Key: dimParamKey, Value: "4"}, {Key: dimParamKey, Value: "4"},
{Key: embeddingURLParamKey, Value: url},
}, },
} }
switch providerName { switch providerName {
case openAIProvider: case openAIProvider:
return NewOpenAIEmbeddingProvider(schema, functionSchema, map[string]string{}) return NewOpenAIEmbeddingProvider(schema, functionSchema, map[string]string{embeddingURLParamKey: url}, credentials.NewCredentialsManager(map[string]string{"mock.apikey": "mock"}))
case azureOpenAIProvider: case azureOpenAIProvider:
return NewAzureOpenAIEmbeddingProvider(schema, functionSchema, map[string]string{}) return NewAzureOpenAIEmbeddingProvider(schema, functionSchema, map[string]string{embeddingURLParamKey: url}, credentials.NewCredentialsManager(map[string]string{"mock.apikey": "mock"}))
default: default:
return nil, errors.New("Unknow provider") return nil, errors.New("Unknow provider")
} }

View File

@ -23,6 +23,7 @@ import (
"strings" "strings"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/util/credentials"
"github.com/milvus-io/milvus/internal/util/function/models/siliconflow" "github.com/milvus-io/milvus/internal/util/function/models/siliconflow"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil" "github.com/milvus-io/milvus/pkg/v2/util/typeutil"
) )
@ -40,7 +41,7 @@ type SiliconflowEmbeddingProvider struct {
func createSiliconflowEmbeddingClient(apiKey string, url string) (*siliconflow.SiliconflowEmbedding, error) { func createSiliconflowEmbeddingClient(apiKey string, url string) (*siliconflow.SiliconflowEmbedding, error) {
if apiKey == "" { if apiKey == "" {
return nil, fmt.Errorf("Missing credentials. Please pass `api_key`, or configure the %s environment variable in the Milvus service.", siliconflowAKEnvStr) return nil, fmt.Errorf("Missing credentials conifg or configure the %s environment variable in the Milvus service.", siliconflowAKEnvStr)
} }
if url == "" { if url == "" {
@ -51,12 +52,15 @@ func createSiliconflowEmbeddingClient(apiKey string, url string) (*siliconflow.S
return c, nil return c, nil
} }
func NewSiliconflowEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema, params map[string]string) (*SiliconflowEmbeddingProvider, error) { func NewSiliconflowEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema, params map[string]string, credentials *credentials.CredentialsManager) (*SiliconflowEmbeddingProvider, error) {
fieldDim, err := typeutil.GetDim(fieldSchema) fieldDim, err := typeutil.GetDim(fieldSchema)
if err != nil { if err != nil {
return nil, err return nil, err
} }
apiKey, url := parseAKAndURL(functionSchema.Params, params, siliconflowAKEnvStr) apiKey, url, err := parseAKAndURL(credentials, functionSchema.Params, params, siliconflowAKEnvStr)
if err != nil {
return nil, err
}
var modelName string var modelName string
for _, param := range functionSchema.Params { for _, param := range functionSchema.Params {

View File

@ -29,6 +29,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/util/credentials"
"github.com/milvus-io/milvus/internal/util/function/models/siliconflow" "github.com/milvus-io/milvus/internal/util/function/models/siliconflow"
) )
@ -69,13 +70,12 @@ func createSiliconflowProvider(url string, schema *schemapb.FieldSchema, provide
OutputFieldIds: []int64{102}, OutputFieldIds: []int64{102},
Params: []*commonpb.KeyValuePair{ Params: []*commonpb.KeyValuePair{
{Key: modelNameParamKey, Value: TestModel}, {Key: modelNameParamKey, Value: TestModel},
{Key: apiKeyParamKey, Value: "mock"}, {Key: credentialParamKey, Value: "mock"},
{Key: embeddingURLParamKey, Value: url},
}, },
} }
switch providerName { switch providerName {
case siliconflowProvider: case siliconflowProvider:
return NewSiliconflowEmbeddingProvider(schema, functionSchema, map[string]string{}) return NewSiliconflowEmbeddingProvider(schema, functionSchema, map[string]string{embeddingURLParamKey: url}, credentials.NewCredentialsManager(map[string]string{"mock.apikey": "mock"}))
default: default:
return nil, errors.New("Unknow provider") return nil, errors.New("Unknow provider")
} }
@ -182,11 +182,10 @@ func (s *SiliconflowTextEmbeddingProviderSuite) TestNewSiliconflowEmbeddingProvi
OutputFieldIds: []int64{102}, OutputFieldIds: []int64{102},
Params: []*commonpb.KeyValuePair{ Params: []*commonpb.KeyValuePair{
{Key: modelNameParamKey, Value: TestModel}, {Key: modelNameParamKey, Value: TestModel},
{Key: apiKeyParamKey, Value: "mock"}, {Key: credentialParamKey, Value: "mock"},
{Key: embeddingURLParamKey, Value: "mock"},
}, },
} }
provider, err := NewSiliconflowEmbeddingProvider(s.schema.Fields[2], functionSchema, map[string]string{}) provider, err := NewSiliconflowEmbeddingProvider(s.schema.Fields[2], functionSchema, map[string]string{embeddingURLParamKey: "mock"}, credentials.NewCredentialsManager(map[string]string{"mock.apikey": "mock"}))
s.NoError(err) s.NoError(err)
s.Equal(provider.FieldDim(), int64(4)) s.Equal(provider.FieldDim(), int64(4))
s.True(provider.MaxBatch() > 0) s.True(provider.MaxBatch() > 0)

View File

@ -27,6 +27,7 @@ import (
"github.com/cockroachdb/errors" "github.com/cockroachdb/errors"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/util/credentials"
"github.com/milvus-io/milvus/internal/util/function/models/tei" "github.com/milvus-io/milvus/internal/util/function/models/tei"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil" "github.com/milvus-io/milvus/pkg/v2/util/typeutil"
) )
@ -54,12 +55,12 @@ func createTEIEmbeddingClient(apiKey string, endpoint string) (*tei.TEIEmbedding
return tei.NewTEIEmbeddingClient(apiKey, endpoint) return tei.NewTEIEmbeddingClient(apiKey, endpoint)
} }
func NewTEIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema, params map[string]string) (*TeiEmbeddingProvider, error) { func NewTEIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema, params map[string]string, credentials *credentials.CredentialsManager) (*TeiEmbeddingProvider, error) {
fieldDim, err := typeutil.GetDim(fieldSchema) fieldDim, err := typeutil.GetDim(fieldSchema)
if err != nil { if err != nil {
return nil, err return nil, err
} }
var apiKey, endpoint, ingestionPrompt, searchPrompt string var endpoint, ingestionPrompt, searchPrompt string
// TEI default client batch size // TEI default client batch size
maxBatch := 32 maxBatch := 32
truncate := false truncate := false
@ -68,8 +69,6 @@ func NewTEIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *
for _, param := range functionSchema.Params { for _, param := range functionSchema.Params {
switch strings.ToLower(param.Key) { switch strings.ToLower(param.Key) {
case apiKeyParamKey:
apiKey = param.Value
case endpointParamKey: case endpointParamKey:
endpoint = param.Value endpoint = param.Value
case ingestionPromptParamKey: case ingestionPromptParamKey:
@ -92,6 +91,10 @@ func NewTEIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *
} }
} }
apiKey, _, err := parseAKAndURL(credentials, functionSchema.Params, params, "")
if err != nil {
return nil, err
}
c, err := createTEIEmbeddingClient(apiKey, endpoint) c, err := createTEIEmbeddingClient(apiKey, endpoint)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -30,6 +30,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/util/credentials"
) )
func TestTEITextEmbeddingProvider(t *testing.T) { func TestTEITextEmbeddingProvider(t *testing.T) {
@ -68,7 +69,7 @@ func createTEIProvider(url string, schema *schemapb.FieldSchema, providerName st
InputFieldIds: []int64{101}, InputFieldIds: []int64{101},
OutputFieldIds: []int64{102}, OutputFieldIds: []int64{102},
Params: []*commonpb.KeyValuePair{ Params: []*commonpb.KeyValuePair{
{Key: apiKeyParamKey, Value: "mock"}, {Key: credentialParamKey, Value: "mock"},
{Key: endpointParamKey, Value: url}, {Key: endpointParamKey, Value: url},
{Key: ingestionPromptParamKey, Value: "doc:"}, {Key: ingestionPromptParamKey, Value: "doc:"},
{Key: searchPromptParamKey, Value: "query:"}, {Key: searchPromptParamKey, Value: "query:"},
@ -76,7 +77,7 @@ func createTEIProvider(url string, schema *schemapb.FieldSchema, providerName st
} }
switch providerName { switch providerName {
case teiProvider: case teiProvider:
return NewTEIEmbeddingProvider(schema, functionSchema, map[string]string{}) return NewTEIEmbeddingProvider(schema, functionSchema, map[string]string{}, credentials.NewCredentialsManager(map[string]string{"mock.apikey": "mock"}))
default: default:
return nil, errors.New("Unknow provider") return nil, errors.New("Unknow provider")
} }
@ -168,11 +169,11 @@ func (s *TEITextEmbeddingProviderSuite) TestNewTEIEmbeddingProvider() {
InputFieldIds: []int64{101}, InputFieldIds: []int64{101},
OutputFieldIds: []int64{102}, OutputFieldIds: []int64{102},
Params: []*commonpb.KeyValuePair{ Params: []*commonpb.KeyValuePair{
{Key: apiKeyParamKey, Value: "mock"}, {Key: credentialParamKey, Value: "mock"},
{Key: endpointParamKey, Value: "http://mymock.com"}, {Key: endpointParamKey, Value: "http://mymock.com"},
}, },
} }
provider, err := NewTEIEmbeddingProvider(s.schema.Fields[2], functionSchema, map[string]string{}) provider, err := NewTEIEmbeddingProvider(s.schema.Fields[2], functionSchema, map[string]string{}, credentials.NewCredentialsManager(map[string]string{"mock.apikey": "mock"}))
s.NoError(err) s.NoError(err)
s.Equal(provider.FieldDim(), int64(4)) s.Equal(provider.FieldDim(), int64(4))
s.True(provider.MaxBatch() == 32*5) s.True(provider.MaxBatch() == 32*5)
@ -180,35 +181,35 @@ func (s *TEITextEmbeddingProviderSuite) TestNewTEIEmbeddingProvider() {
// Invalid truncate // Invalid truncate
{ {
functionSchema.Params = append(functionSchema.Params, &commonpb.KeyValuePair{Key: truncateParamKey, Value: "Invalid"}) functionSchema.Params = append(functionSchema.Params, &commonpb.KeyValuePair{Key: truncateParamKey, Value: "Invalid"})
_, err := NewTEIEmbeddingProvider(s.schema.Fields[2], functionSchema, map[string]string{}) _, err := NewTEIEmbeddingProvider(s.schema.Fields[2], functionSchema, map[string]string{}, credentials.NewCredentialsManager(map[string]string{"mock.apikey": "mock"}))
s.Error(err) s.Error(err)
} }
// Invalid truncationDirection // Invalid truncationDirection
{ {
functionSchema.Params[2] = &commonpb.KeyValuePair{Key: truncateParamKey, Value: "true"} functionSchema.Params[2] = &commonpb.KeyValuePair{Key: truncateParamKey, Value: "true"}
functionSchema.Params = append(functionSchema.Params, &commonpb.KeyValuePair{Key: truncationDirectionParamKey, Value: "Invalid"}) functionSchema.Params = append(functionSchema.Params, &commonpb.KeyValuePair{Key: truncationDirectionParamKey, Value: "Invalid"})
_, err := NewTEIEmbeddingProvider(s.schema.Fields[2], functionSchema, map[string]string{}) _, err := NewTEIEmbeddingProvider(s.schema.Fields[2], functionSchema, map[string]string{}, credentials.NewCredentialsManager(map[string]string{"mock.apikey": "mock"}))
s.Error(err) s.Error(err)
} }
// truncationDirection // truncationDirection
{ {
functionSchema.Params[3] = &commonpb.KeyValuePair{Key: truncationDirectionParamKey, Value: "Left"} functionSchema.Params[3] = &commonpb.KeyValuePair{Key: truncationDirectionParamKey, Value: "Left"}
_, err := NewTEIEmbeddingProvider(s.schema.Fields[2], functionSchema, map[string]string{}) _, err := NewTEIEmbeddingProvider(s.schema.Fields[2], functionSchema, map[string]string{}, credentials.NewCredentialsManager(map[string]string{"mock.apikey": "mock"}))
s.NoError(err) s.NoError(err)
} }
// Invalid max batch // Invalid max batch
{ {
functionSchema.Params = append(functionSchema.Params, &commonpb.KeyValuePair{Key: maxClientBatchSizeParamKey, Value: "Invalid"}) functionSchema.Params = append(functionSchema.Params, &commonpb.KeyValuePair{Key: maxClientBatchSizeParamKey, Value: "Invalid"})
_, err := NewTEIEmbeddingProvider(s.schema.Fields[2], functionSchema, map[string]string{}) _, err := NewTEIEmbeddingProvider(s.schema.Fields[2], functionSchema, map[string]string{}, credentials.NewCredentialsManager(map[string]string{"mock.apikey": "mock"}))
s.Error(err) s.Error(err)
} }
// Valid max batch // Valid max batch
{ {
functionSchema.Params[4] = &commonpb.KeyValuePair{Key: maxClientBatchSizeParamKey, Value: "128"} functionSchema.Params[4] = &commonpb.KeyValuePair{Key: maxClientBatchSizeParamKey, Value: "128"}
pv, err := NewTEIEmbeddingProvider(s.schema.Fields[2], functionSchema, map[string]string{}) pv, err := NewTEIEmbeddingProvider(s.schema.Fields[2], functionSchema, map[string]string{}, credentials.NewCredentialsManager(map[string]string{"mock.apikey": "mock"}))
s.NoError(err) s.NoError(err)
s.True(pv.MaxBatch() == 128*5) s.True(pv.MaxBatch() == 128*5)
} }

View File

@ -28,6 +28,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/util/credentials"
"github.com/milvus-io/milvus/pkg/v2/util/funcutil" "github.com/milvus-io/milvus/pkg/v2/util/funcutil"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable" "github.com/milvus-io/milvus/pkg/v2/util/paramtable"
) )
@ -98,25 +99,26 @@ func NewTextEmbeddingFunction(coll *schemapb.CollectionSchema, functionSchema *s
var embP textEmbeddingProvider var embP textEmbeddingProvider
var newProviderErr error var newProviderErr error
conf := paramtable.Get().FunctionCfg.GetTextEmbeddingProviderConfig(base.provider) conf := paramtable.Get().FunctionCfg.GetTextEmbeddingProviderConfig(base.provider)
credentials := credentials.NewCredentialsManager(paramtable.Get().CredentialCfg.GetCredentials())
switch base.provider { switch base.provider {
case openAIProvider: case openAIProvider:
embP, newProviderErr = NewOpenAIEmbeddingProvider(base.outputFields[0], functionSchema, conf) embP, newProviderErr = NewOpenAIEmbeddingProvider(base.outputFields[0], functionSchema, conf, credentials)
case azureOpenAIProvider: case azureOpenAIProvider:
embP, newProviderErr = NewAzureOpenAIEmbeddingProvider(base.outputFields[0], functionSchema, conf) embP, newProviderErr = NewAzureOpenAIEmbeddingProvider(base.outputFields[0], functionSchema, conf, credentials)
case bedrockProvider: case bedrockProvider:
embP, newProviderErr = NewBedrockEmbeddingProvider(base.outputFields[0], functionSchema, nil, conf) embP, newProviderErr = NewBedrockEmbeddingProvider(base.outputFields[0], functionSchema, nil, conf, credentials)
case aliDashScopeProvider: case aliDashScopeProvider:
embP, newProviderErr = NewAliDashScopeEmbeddingProvider(base.outputFields[0], functionSchema, conf) embP, newProviderErr = NewAliDashScopeEmbeddingProvider(base.outputFields[0], functionSchema, conf, credentials)
case vertexAIProvider: case vertexAIProvider:
embP, newProviderErr = NewVertexAIEmbeddingProvider(base.outputFields[0], functionSchema, nil, conf) embP, newProviderErr = NewVertexAIEmbeddingProvider(base.outputFields[0], functionSchema, nil, conf, credentials)
case voyageAIProvider: case voyageAIProvider:
embP, newProviderErr = NewVoyageAIEmbeddingProvider(base.outputFields[0], functionSchema, conf) embP, newProviderErr = NewVoyageAIEmbeddingProvider(base.outputFields[0], functionSchema, conf, credentials)
case cohereProvider: case cohereProvider:
embP, newProviderErr = NewCohereEmbeddingProvider(base.outputFields[0], functionSchema, conf) embP, newProviderErr = NewCohereEmbeddingProvider(base.outputFields[0], functionSchema, conf, credentials)
case siliconflowProvider: case siliconflowProvider:
embP, newProviderErr = NewSiliconflowEmbeddingProvider(base.outputFields[0], functionSchema, conf) embP, newProviderErr = NewSiliconflowEmbeddingProvider(base.outputFields[0], functionSchema, conf, credentials)
case teiProvider: case teiProvider:
embP, newProviderErr = NewTEIEmbeddingProvider(base.outputFields[0], functionSchema, conf) embP, newProviderErr = NewTEIEmbeddingProvider(base.outputFields[0], functionSchema, conf, credentials)
default: default:
return nil, fmt.Errorf("Unsupported text embedding service provider: [%s] , list of supported [%s, %s, %s, %s, %s, %s, %s, %s, %s]", base.provider, openAIProvider, azureOpenAIProvider, aliDashScopeProvider, bedrockProvider, vertexAIProvider, voyageAIProvider, cohereProvider, siliconflowProvider, teiProvider) return nil, fmt.Errorf("Unsupported text embedding service provider: [%s] , list of supported [%s, %s, %s, %s, %s, %s, %s, %s, %s]", base.provider, openAIProvider, azureOpenAIProvider, aliDashScopeProvider, bedrockProvider, vertexAIProvider, voyageAIProvider, cohereProvider, siliconflowProvider, teiProvider)
} }

View File

@ -29,6 +29,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/util/credentials"
"github.com/milvus-io/milvus/internal/util/testutil" "github.com/milvus-io/milvus/internal/util/testutil"
"github.com/milvus-io/milvus/pkg/v2/util/funcutil" "github.com/milvus-io/milvus/pkg/v2/util/funcutil"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable" "github.com/milvus-io/milvus/pkg/v2/util/paramtable"
@ -45,6 +46,13 @@ type TextEmbeddingFunctionSuite struct {
func (s *TextEmbeddingFunctionSuite) SetupTest() { func (s *TextEmbeddingFunctionSuite) SetupTest() {
paramtable.Init() paramtable.Init()
paramtable.Get().CredentialCfg.Credential.GetFunc = func() map[string]string {
return map[string]string{
"mock.apikey": "mock",
"mock.access_key_id": "mock",
"mock.secret_access_key": "mock",
}
}
s.schema = &schemapb.CollectionSchema{ s.schema = &schemapb.CollectionSchema{
Name: "test", Name: "test",
Fields: []*schemapb.FieldSchema{ Fields: []*schemapb.FieldSchema{
@ -92,8 +100,7 @@ func (s *TextEmbeddingFunctionSuite) TestInvalidProvider() {
{Key: Provider, Value: openAIProvider}, {Key: Provider, Value: openAIProvider},
{Key: modelNameParamKey, Value: "text-embedding-ada-002"}, {Key: modelNameParamKey, Value: "text-embedding-ada-002"},
{Key: dimParamKey, Value: "4"}, {Key: dimParamKey, Value: "4"},
{Key: apiKeyParamKey, Value: "mock"}, {Key: credentialParamKey, Value: "mock"},
{Key: embeddingURLParamKey, Value: "mock"},
}, },
} }
providerName, err := getProvider(fSchema) providerName, err := getProvider(fSchema)
@ -110,6 +117,12 @@ func (s *TextEmbeddingFunctionSuite) TestProcessInsert() {
ts := CreateOpenAIEmbeddingServer() ts := CreateOpenAIEmbeddingServer()
defer ts.Close() defer ts.Close()
{ {
paramtable.Get().FunctionCfg.TextEmbeddingProviders.GetFunc = func() map[string]string {
key := openAIProvider + "." + embeddingURLParamKey
return map[string]string{
key: ts.URL,
}
}
runner, err := NewTextEmbeddingFunction(s.schema, &schemapb.FunctionSchema{ runner, err := NewTextEmbeddingFunction(s.schema, &schemapb.FunctionSchema{
Name: "test", Name: "test",
Type: schemapb.FunctionType_TextEmbedding, Type: schemapb.FunctionType_TextEmbedding,
@ -121,8 +134,7 @@ func (s *TextEmbeddingFunctionSuite) TestProcessInsert() {
{Key: Provider, Value: openAIProvider}, {Key: Provider, Value: openAIProvider},
{Key: modelNameParamKey, Value: "text-embedding-ada-002"}, {Key: modelNameParamKey, Value: "text-embedding-ada-002"},
{Key: dimParamKey, Value: "4"}, {Key: dimParamKey, Value: "4"},
{Key: apiKeyParamKey, Value: "mock"}, {Key: credentialParamKey, Value: "mock"},
{Key: embeddingURLParamKey, Value: ts.URL},
}, },
}) })
s.NoError(err) s.NoError(err)
@ -142,6 +154,12 @@ func (s *TextEmbeddingFunctionSuite) TestProcessInsert() {
} }
} }
{ {
paramtable.Get().FunctionCfg.TextEmbeddingProviders.GetFunc = func() map[string]string {
key := azureOpenAIProvider + "." + embeddingURLParamKey
return map[string]string{
key: ts.URL,
}
}
runner, err := NewTextEmbeddingFunction(s.schema, &schemapb.FunctionSchema{ runner, err := NewTextEmbeddingFunction(s.schema, &schemapb.FunctionSchema{
Name: "test", Name: "test",
Type: schemapb.FunctionType_TextEmbedding, Type: schemapb.FunctionType_TextEmbedding,
@ -153,8 +171,7 @@ func (s *TextEmbeddingFunctionSuite) TestProcessInsert() {
{Key: Provider, Value: azureOpenAIProvider}, {Key: Provider, Value: azureOpenAIProvider},
{Key: modelNameParamKey, Value: "text-embedding-ada-002"}, {Key: modelNameParamKey, Value: "text-embedding-ada-002"},
{Key: dimParamKey, Value: "4"}, {Key: dimParamKey, Value: "4"},
{Key: apiKeyParamKey, Value: "mock"}, {Key: credentialParamKey, Value: "mock"},
{Key: embeddingURLParamKey, Value: ts.URL},
}, },
}) })
s.NoError(err) s.NoError(err)
@ -178,6 +195,12 @@ func (s *TextEmbeddingFunctionSuite) TestProcessInsert() {
func (s *TextEmbeddingFunctionSuite) TestAliEmbedding() { func (s *TextEmbeddingFunctionSuite) TestAliEmbedding() {
ts := CreateAliEmbeddingServer() ts := CreateAliEmbeddingServer()
defer ts.Close() defer ts.Close()
paramtable.Get().FunctionCfg.TextEmbeddingProviders.GetFunc = func() map[string]string {
key := aliDashScopeProvider + "." + embeddingURLParamKey
return map[string]string{
key: ts.URL,
}
}
runner, err := NewTextEmbeddingFunction(s.schema, &schemapb.FunctionSchema{ runner, err := NewTextEmbeddingFunction(s.schema, &schemapb.FunctionSchema{
Name: "test", Name: "test",
@ -190,8 +213,7 @@ func (s *TextEmbeddingFunctionSuite) TestAliEmbedding() {
{Key: Provider, Value: aliDashScopeProvider}, {Key: Provider, Value: aliDashScopeProvider},
{Key: modelNameParamKey, Value: TestModel}, {Key: modelNameParamKey, Value: TestModel},
{Key: dimParamKey, Value: "4"}, {Key: dimParamKey, Value: "4"},
{Key: apiKeyParamKey, Value: "mock"}, {Key: credentialParamKey, Value: "mock"},
{Key: embeddingURLParamKey, Value: ts.URL},
}, },
}) })
s.NoError(err) s.NoError(err)
@ -336,8 +358,7 @@ func (s *TextEmbeddingFunctionSuite) TestRunnerParamsErr() {
{Key: Provider, Value: openAIProvider}, {Key: Provider, Value: openAIProvider},
{Key: modelNameParamKey, Value: "text-embedding-ada-002"}, {Key: modelNameParamKey, Value: "text-embedding-ada-002"},
{Key: dimParamKey, Value: "4"}, {Key: dimParamKey, Value: "4"},
{Key: apiKeyParamKey, Value: "mock"}, {Key: credentialParamKey, Value: "mock"},
{Key: embeddingURLParamKey, Value: "mock"},
}, },
}) })
s.Error(err) s.Error(err)
@ -375,8 +396,7 @@ func (s *TextEmbeddingFunctionSuite) TestRunnerParamsErr() {
{Key: Provider, Value: openAIProvider}, {Key: Provider, Value: openAIProvider},
{Key: modelNameParamKey, Value: "text-embedding-ada-002"}, {Key: modelNameParamKey, Value: "text-embedding-ada-002"},
{Key: dimParamKey, Value: "4"}, {Key: dimParamKey, Value: "4"},
{Key: apiKeyParamKey, Value: "mock"}, {Key: credentialParamKey, Value: "mock"},
{Key: embeddingURLParamKey, Value: "mock"},
}, },
}) })
s.Error(err) s.Error(err)
@ -395,8 +415,7 @@ func (s *TextEmbeddingFunctionSuite) TestRunnerParamsErr() {
{Key: Provider, Value: openAIProvider}, {Key: Provider, Value: openAIProvider},
{Key: modelNameParamKey, Value: "text-embedding-ada-002"}, {Key: modelNameParamKey, Value: "text-embedding-ada-002"},
{Key: dimParamKey, Value: "4"}, {Key: dimParamKey, Value: "4"},
{Key: apiKeyParamKey, Value: "mock"}, {Key: credentialParamKey, Value: "mock"},
{Key: embeddingURLParamKey, Value: "mock"},
}, },
}) })
s.Error(err) s.Error(err)
@ -432,8 +451,7 @@ func (s *TextEmbeddingFunctionSuite) TestNewTextEmbeddings() {
Params: []*commonpb.KeyValuePair{ Params: []*commonpb.KeyValuePair{
{Key: Provider, Value: bedrockProvider}, {Key: Provider, Value: bedrockProvider},
{Key: modelNameParamKey, Value: TestModel}, {Key: modelNameParamKey, Value: TestModel},
{Key: awsAKIdParamKey, Value: "mock"}, {Key: credentialParamKey, Value: "mock"},
{Key: awsSAKParamKey, Value: "mock"},
{Key: regionParamKey, Value: "mock"}, {Key: regionParamKey, Value: "mock"},
}, },
} }
@ -456,7 +474,7 @@ func (s *TextEmbeddingFunctionSuite) TestNewTextEmbeddings() {
Params: []*commonpb.KeyValuePair{ Params: []*commonpb.KeyValuePair{
{Key: Provider, Value: aliDashScopeProvider}, {Key: Provider, Value: aliDashScopeProvider},
{Key: modelNameParamKey, Value: TestModel}, {Key: modelNameParamKey, Value: TestModel},
{Key: apiKeyParamKey, Value: "mock"}, {Key: credentialParamKey, Value: "mock"},
}, },
} }
@ -478,7 +496,7 @@ func (s *TextEmbeddingFunctionSuite) TestNewTextEmbeddings() {
Params: []*commonpb.KeyValuePair{ Params: []*commonpb.KeyValuePair{
{Key: Provider, Value: voyageAIProvider}, {Key: Provider, Value: voyageAIProvider},
{Key: modelNameParamKey, Value: TestModel}, {Key: modelNameParamKey, Value: TestModel},
{Key: apiKeyParamKey, Value: "mock"}, {Key: credentialParamKey, Value: "mock"},
}, },
} }
@ -500,7 +518,7 @@ func (s *TextEmbeddingFunctionSuite) TestNewTextEmbeddings() {
Params: []*commonpb.KeyValuePair{ Params: []*commonpb.KeyValuePair{
{Key: Provider, Value: siliconflowProvider}, {Key: Provider, Value: siliconflowProvider},
{Key: modelNameParamKey, Value: TestModel}, {Key: modelNameParamKey, Value: TestModel},
{Key: apiKeyParamKey, Value: "mock"}, {Key: credentialParamKey, Value: "mock"},
}, },
} }
@ -522,7 +540,7 @@ func (s *TextEmbeddingFunctionSuite) TestNewTextEmbeddings() {
Params: []*commonpb.KeyValuePair{ Params: []*commonpb.KeyValuePair{
{Key: Provider, Value: cohereProvider}, {Key: Provider, Value: cohereProvider},
{Key: modelNameParamKey, Value: TestModel}, {Key: modelNameParamKey, Value: TestModel},
{Key: apiKeyParamKey, Value: "mock"}, {Key: credentialParamKey, Value: "mock"},
}, },
} }
@ -607,6 +625,12 @@ func (s *TextEmbeddingFunctionSuite) TestNewTextEmbeddings() {
func (s *TextEmbeddingFunctionSuite) TestProcessSearchFloat32() { func (s *TextEmbeddingFunctionSuite) TestProcessSearchFloat32() {
ts := CreateOpenAIEmbeddingServer() ts := CreateOpenAIEmbeddingServer()
defer ts.Close() defer ts.Close()
paramtable.Get().FunctionCfg.TextEmbeddingProviders.GetFunc = func() map[string]string {
key := openAIProvider + "." + embeddingURLParamKey
return map[string]string{
key: ts.URL,
}
}
runner, err := NewTextEmbeddingFunction(s.schema, &schemapb.FunctionSchema{ runner, err := NewTextEmbeddingFunction(s.schema, &schemapb.FunctionSchema{
Name: "test", Name: "test",
Type: schemapb.FunctionType_TextEmbedding, Type: schemapb.FunctionType_TextEmbedding,
@ -618,8 +642,7 @@ func (s *TextEmbeddingFunctionSuite) TestProcessSearchFloat32() {
{Key: Provider, Value: openAIProvider}, {Key: Provider, Value: openAIProvider},
{Key: modelNameParamKey, Value: "text-embedding-ada-002"}, {Key: modelNameParamKey, Value: "text-embedding-ada-002"},
{Key: dimParamKey, Value: "4"}, {Key: dimParamKey, Value: "4"},
{Key: apiKeyParamKey, Value: "mock"}, {Key: credentialParamKey, Value: "mock"},
{Key: embeddingURLParamKey, Value: ts.URL},
}, },
}) })
s.NoError(err) s.NoError(err)
@ -692,6 +715,12 @@ func (s *TextEmbeddingFunctionSuite) TestProcessInsertInt8() {
}, },
}, },
} }
paramtable.Get().FunctionCfg.TextEmbeddingProviders.GetFunc = func() map[string]string {
key := cohereProvider + "." + embeddingURLParamKey
return map[string]string{
key: ts.URL,
}
}
runner, err := NewTextEmbeddingFunction(s.schema, &schemapb.FunctionSchema{ runner, err := NewTextEmbeddingFunction(s.schema, &schemapb.FunctionSchema{
Name: "test", Name: "test",
Type: schemapb.FunctionType_TextEmbedding, Type: schemapb.FunctionType_TextEmbedding,
@ -703,8 +732,7 @@ func (s *TextEmbeddingFunctionSuite) TestProcessInsertInt8() {
{Key: Provider, Value: cohereProvider}, {Key: Provider, Value: cohereProvider},
{Key: modelNameParamKey, Value: TestModel}, {Key: modelNameParamKey, Value: TestModel},
{Key: dimParamKey, Value: "4"}, {Key: dimParamKey, Value: "4"},
{Key: apiKeyParamKey, Value: "mock"}, {Key: credentialParamKey, Value: "mock"},
{Key: embeddingURLParamKey, Value: ts.URL},
}, },
}) })
s.NoError(err) s.NoError(err)
@ -750,8 +778,8 @@ func (s *TextEmbeddingFunctionSuite) TestUnsupportedVec() {
{Key: Provider, Value: cohereProvider}, {Key: Provider, Value: cohereProvider},
{Key: modelNameParamKey, Value: TestModel}, {Key: modelNameParamKey, Value: TestModel},
{Key: dimParamKey, Value: "4"}, {Key: dimParamKey, Value: "4"},
{Key: apiKeyParamKey, Value: "mock"}, {Key: credentialParamKey, Value: "mock"},
{Key: embeddingURLParamKey, Value: "mock"}, // {Key: embeddingURLParamKey, Value: "mock"},
}, },
}) })
s.Error(err) s.Error(err)
@ -774,6 +802,12 @@ func (s *TextEmbeddingFunctionSuite) TestProcessSearchInt8() {
}, },
}, },
} }
paramtable.Get().FunctionCfg.TextEmbeddingProviders.GetFunc = func() map[string]string {
key := cohereProvider + "." + embeddingURLParamKey
return map[string]string{
key: ts.URL,
}
}
runner, err := NewTextEmbeddingFunction(s.schema, &schemapb.FunctionSchema{ runner, err := NewTextEmbeddingFunction(s.schema, &schemapb.FunctionSchema{
Name: "test", Name: "test",
Type: schemapb.FunctionType_TextEmbedding, Type: schemapb.FunctionType_TextEmbedding,
@ -785,8 +819,7 @@ func (s *TextEmbeddingFunctionSuite) TestProcessSearchInt8() {
{Key: Provider, Value: cohereProvider}, {Key: Provider, Value: cohereProvider},
{Key: modelNameParamKey, Value: TestModel}, {Key: modelNameParamKey, Value: TestModel},
{Key: dimParamKey, Value: "4"}, {Key: dimParamKey, Value: "4"},
{Key: apiKeyParamKey, Value: "mock"}, {Key: credentialParamKey, Value: "mock"},
{Key: embeddingURLParamKey, Value: ts.URL},
}, },
}) })
s.NoError(err) s.NoError(err)
@ -845,6 +878,12 @@ func (s *TextEmbeddingFunctionSuite) TestProcessSearchInt8() {
func (s *TextEmbeddingFunctionSuite) TestProcessBulkInsertFloat32() { func (s *TextEmbeddingFunctionSuite) TestProcessBulkInsertFloat32() {
ts := CreateOpenAIEmbeddingServer() ts := CreateOpenAIEmbeddingServer()
defer ts.Close() defer ts.Close()
paramtable.Get().FunctionCfg.TextEmbeddingProviders.GetFunc = func() map[string]string {
key := openAIProvider + "." + embeddingURLParamKey
return map[string]string{
key: ts.URL,
}
}
runner, err := NewTextEmbeddingFunction(s.schema, &schemapb.FunctionSchema{ runner, err := NewTextEmbeddingFunction(s.schema, &schemapb.FunctionSchema{
Name: "test", Name: "test",
Type: schemapb.FunctionType_TextEmbedding, Type: schemapb.FunctionType_TextEmbedding,
@ -856,8 +895,7 @@ func (s *TextEmbeddingFunctionSuite) TestProcessBulkInsertFloat32() {
{Key: Provider, Value: openAIProvider}, {Key: Provider, Value: openAIProvider},
{Key: modelNameParamKey, Value: "text-embedding-ada-002"}, {Key: modelNameParamKey, Value: "text-embedding-ada-002"},
{Key: dimParamKey, Value: "4"}, {Key: dimParamKey, Value: "4"},
{Key: apiKeyParamKey, Value: "mock"}, {Key: credentialParamKey, Value: "mock"},
{Key: embeddingURLParamKey, Value: ts.URL},
}, },
}) })
s.NoError(err) s.NoError(err)
@ -894,6 +932,26 @@ func (s *TextEmbeddingFunctionSuite) TestProcessBulkInsertFloat32() {
} }
} }
func (s *TextEmbeddingFunctionSuite) TestParseCredentail() {
{
cred := credentials.NewCredentialsManager(map[string]string{})
ak, url, err := parseAKAndURL(cred, []*commonpb.KeyValuePair{}, map[string]string{}, "")
s.Equal(ak, "")
s.Equal(url, "")
s.NoError(err)
}
{
cred := credentials.NewCredentialsManager(map[string]string{})
_, _, err := parseAKAndURL(cred, []*commonpb.KeyValuePair{}, map[string]string{"credential": "NotExist"}, "")
s.ErrorContains(err, "is not a apikey crediential, can not find key")
}
{
cred := credentials.NewCredentialsManager(map[string]string{"mock.apikey": "mock"})
_, _, err := parseAKAndURL(cred, []*commonpb.KeyValuePair{}, map[string]string{"credential": "mock"}, "")
s.NoError(err)
}
}
func (s *TextEmbeddingFunctionSuite) TestProcessBulkInsertInt8() { func (s *TextEmbeddingFunctionSuite) TestProcessBulkInsertInt8() {
ts := CreateCohereEmbeddingServer[int8]() ts := CreateCohereEmbeddingServer[int8]()
defer ts.Close() defer ts.Close()
@ -910,6 +968,12 @@ func (s *TextEmbeddingFunctionSuite) TestProcessBulkInsertInt8() {
}, },
}, },
} }
paramtable.Get().FunctionCfg.TextEmbeddingProviders.GetFunc = func() map[string]string {
key := cohereProvider + "." + embeddingURLParamKey
return map[string]string{
key: ts.URL,
}
}
runner, err := NewTextEmbeddingFunction(s.schema, &schemapb.FunctionSchema{ runner, err := NewTextEmbeddingFunction(s.schema, &schemapb.FunctionSchema{
Name: "test", Name: "test",
Type: schemapb.FunctionType_TextEmbedding, Type: schemapb.FunctionType_TextEmbedding,
@ -921,8 +985,7 @@ func (s *TextEmbeddingFunctionSuite) TestProcessBulkInsertInt8() {
{Key: Provider, Value: cohereProvider}, {Key: Provider, Value: cohereProvider},
{Key: modelNameParamKey, Value: TestModel}, {Key: modelNameParamKey, Value: TestModel},
{Key: dimParamKey, Value: "4"}, {Key: dimParamKey, Value: "4"},
{Key: apiKeyParamKey, Value: "mock"}, {Key: credentialParamKey, Value: "mock"},
{Key: embeddingURLParamKey, Value: ts.URL},
}, },
}) })
s.NoError(err) s.NoError(err)

View File

@ -26,7 +26,9 @@ import (
"github.com/cockroachdb/errors" "github.com/cockroachdb/errors"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/util/credentials"
"github.com/milvus-io/milvus/internal/util/function/models/vertexai" "github.com/milvus-io/milvus/internal/util/function/models/vertexai"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil" "github.com/milvus-io/milvus/pkg/v2/util/typeutil"
) )
@ -39,22 +41,15 @@ type vertexAIJsonKey struct {
var vtxKey vertexAIJsonKey var vtxKey vertexAIJsonKey
func getVertexAIJsonKey(credentialsFilePath string) ([]byte, error) { func getVertexAIJsonKey() ([]byte, error) {
vtxKey.mu.Lock() vtxKey.mu.Lock()
defer vtxKey.mu.Unlock() defer vtxKey.mu.Unlock()
var jsonKeyPath string jsonKeyPath := os.Getenv(vertexServiceAccountJSONEnv)
if credentialsFilePath == "" {
jsonKeyPath = os.Getenv(vertexServiceAccountJSONEnv)
} else {
jsonKeyPath = credentialsFilePath
}
if jsonKeyPath == "" { if jsonKeyPath == "" {
return nil, errors.New("VetexAI credentials file path is empty") return nil, errors.New("VetexAI credentials file path is empty")
} }
if vtxKey.filePath == jsonKeyPath { if vtxKey.filePath == jsonKeyPath {
// The file path remains unchanged, using the data in the cache
return vtxKey.jsonKey, nil return vtxKey.jsonKey, nil
} }
@ -65,6 +60,7 @@ func getVertexAIJsonKey(credentialsFilePath string) ([]byte, error) {
vtxKey.jsonKey = jsonKey vtxKey.jsonKey = jsonKey
vtxKey.filePath = jsonKeyPath vtxKey.filePath = jsonKeyPath
return vtxKey.jsonKey, nil return vtxKey.jsonKey, nil
} }
@ -86,16 +82,47 @@ type VertexAIEmbeddingProvider struct {
timeoutSec int64 timeoutSec int64
} }
func createVertexAIEmbeddingClient(url string, credentialsFilePath string) (*vertexai.VertexAIEmbedding, error) { func createVertexAIEmbeddingClient(url string, credentialsJSON []byte) (*vertexai.VertexAIEmbedding, error) {
jsonKey, err := getVertexAIJsonKey(credentialsFilePath) c := vertexai.NewVertexAIEmbedding(url, credentialsJSON, "https://www.googleapis.com/auth/cloud-platform", "")
if err != nil {
return nil, err
}
c := vertexai.NewVertexAIEmbedding(url, jsonKey, "https://www.googleapis.com/auth/cloud-platform", "")
return c, nil return c, nil
} }
func NewVertexAIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema, c *vertexai.VertexAIEmbedding, params map[string]string) (*VertexAIEmbeddingProvider, error) { func parseGcpCredentialInfo(credentials *credentials.CredentialsManager, params []*commonpb.KeyValuePair, confParams map[string]string) ([]byte, error) {
// function param > yaml > env
var credentialsJSON []byte
var err error
for _, param := range params {
switch strings.ToLower(param.Key) {
case credentialParamKey:
credentialName := param.Value
if credentialsJSON, err = credentials.GetGcpCredential(credentialName); err != nil {
return nil, err
}
}
}
// from milvus.yaml
if credentialsJSON == nil {
credentialName := confParams[credentialParamKey]
if credentialName != "" {
if credentialsJSON, err = credentials.GetGcpCredential(credentialName); err != nil {
return nil, err
}
}
}
// from env
if credentialsJSON == nil {
credentialsJSON, err = getVertexAIJsonKey()
if err != nil {
return nil, err
}
}
return credentialsJSON, nil
}
func NewVertexAIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema, c *vertexai.VertexAIEmbedding, params map[string]string, credentials *credentials.CredentialsManager) (*VertexAIEmbeddingProvider, error) {
fieldDim, err := typeutil.GetDim(fieldSchema) fieldDim, err := typeutil.GetDim(fieldSchema)
if err != nil { if err != nil {
return nil, err return nil, err
@ -136,7 +163,11 @@ func NewVertexAIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSch
} }
var client *vertexai.VertexAIEmbedding var client *vertexai.VertexAIEmbedding
if c == nil { if c == nil {
client, err = createVertexAIEmbeddingClient(url, params["credentials_file_path"]) jsonKey, err := parseGcpCredentialInfo(credentials, functionSchema.Params, params)
if err != nil {
return nil, err
}
client, err = createVertexAIEmbeddingClient(url, jsonKey)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -29,6 +29,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/util/credentials"
"github.com/milvus-io/milvus/internal/util/function/models/vertexai" "github.com/milvus-io/milvus/internal/util/function/models/vertexai"
) )
@ -76,7 +77,7 @@ func createVertexAIProvider(url string, schema *schemapb.FieldSchema) (textEmbed
}, },
} }
mockClient := vertexai.NewVertexAIEmbedding(url, []byte{1, 2, 3}, "mock scope", "mock token") mockClient := vertexai.NewVertexAIEmbedding(url, []byte{1, 2, 3}, "mock scope", "mock token")
return NewVertexAIEmbeddingProvider(schema, functionSchema, mockClient, map[string]string{}) return NewVertexAIEmbeddingProvider(schema, functionSchema, mockClient, map[string]string{}, credentials.NewCredentialsManager(map[string]string{"mock.credential_json": "mock"}))
} }
func (s *VertexAITextEmbeddingProviderSuite) TestEmbedding() { func (s *VertexAITextEmbeddingProviderSuite) TestEmbedding() {
@ -177,7 +178,7 @@ func (s *VertexAITextEmbeddingProviderSuite) TestEmbeddingNubmerNotMatch() {
func (s *VertexAITextEmbeddingProviderSuite) TestGetVertexAIJsonKey() { func (s *VertexAITextEmbeddingProviderSuite) TestGetVertexAIJsonKey() {
os.Setenv(vertexServiceAccountJSONEnv, "ErrorPath") os.Setenv(vertexServiceAccountJSONEnv, "ErrorPath")
defer os.Unsetenv(vertexServiceAccountJSONEnv) defer os.Unsetenv(vertexServiceAccountJSONEnv)
_, err := getVertexAIJsonKey("") _, err := getVertexAIJsonKey()
s.Error(err) s.Error(err)
} }
@ -198,7 +199,7 @@ func (s *VertexAITextEmbeddingProviderSuite) TestGetTaskType() {
mockClient := vertexai.NewVertexAIEmbedding("mock_url", []byte{1, 2, 3}, "mock scope", "mock token") mockClient := vertexai.NewVertexAIEmbedding("mock_url", []byte{1, 2, 3}, "mock scope", "mock token")
{ {
provider, err := NewVertexAIEmbeddingProvider(s.schema.Fields[2], functionSchema, mockClient, map[string]string{}) provider, err := NewVertexAIEmbeddingProvider(s.schema.Fields[2], functionSchema, mockClient, map[string]string{}, credentials.NewCredentialsManager(map[string]string{"mock.credential_json": "mock"}))
s.NoError(err) s.NoError(err)
s.Equal(provider.getTaskType(InsertMode), "RETRIEVAL_DOCUMENT") s.Equal(provider.getTaskType(InsertMode), "RETRIEVAL_DOCUMENT")
s.Equal(provider.getTaskType(SearchMode), "RETRIEVAL_QUERY") s.Equal(provider.getTaskType(SearchMode), "RETRIEVAL_QUERY")
@ -206,7 +207,7 @@ func (s *VertexAITextEmbeddingProviderSuite) TestGetTaskType() {
{ {
functionSchema.Params = append(functionSchema.Params, &commonpb.KeyValuePair{Key: taskTypeParamKey, Value: vertexAICodeRetrival}) functionSchema.Params = append(functionSchema.Params, &commonpb.KeyValuePair{Key: taskTypeParamKey, Value: vertexAICodeRetrival})
provider, err := NewVertexAIEmbeddingProvider(s.schema.Fields[2], functionSchema, mockClient, map[string]string{}) provider, err := NewVertexAIEmbeddingProvider(s.schema.Fields[2], functionSchema, mockClient, map[string]string{}, credentials.NewCredentialsManager(map[string]string{"mock.credential_json": "mock"}))
s.NoError(err) s.NoError(err)
s.Equal(provider.getTaskType(InsertMode), "RETRIEVAL_DOCUMENT") s.Equal(provider.getTaskType(InsertMode), "RETRIEVAL_DOCUMENT")
s.Equal(provider.getTaskType(SearchMode), "CODE_RETRIEVAL_QUERY") s.Equal(provider.getTaskType(SearchMode), "CODE_RETRIEVAL_QUERY")
@ -214,20 +215,13 @@ func (s *VertexAITextEmbeddingProviderSuite) TestGetTaskType() {
{ {
functionSchema.Params[3] = &commonpb.KeyValuePair{Key: taskTypeParamKey, Value: vertexAISTS} functionSchema.Params[3] = &commonpb.KeyValuePair{Key: taskTypeParamKey, Value: vertexAISTS}
provider, err := NewVertexAIEmbeddingProvider(s.schema.Fields[2], functionSchema, mockClient, map[string]string{}) provider, err := NewVertexAIEmbeddingProvider(s.schema.Fields[2], functionSchema, mockClient, map[string]string{}, credentials.NewCredentialsManager(map[string]string{"mock.credential_json": "mock"}))
s.NoError(err) s.NoError(err)
s.Equal(provider.getTaskType(InsertMode), "SEMANTIC_SIMILARITY") s.Equal(provider.getTaskType(InsertMode), "SEMANTIC_SIMILARITY")
s.Equal(provider.getTaskType(SearchMode), "SEMANTIC_SIMILARITY") s.Equal(provider.getTaskType(SearchMode), "SEMANTIC_SIMILARITY")
} }
} }
func (s *VertexAITextEmbeddingProviderSuite) TestCreateVertexAIEmbeddingClient() {
os.Setenv(vertexServiceAccountJSONEnv, "ErrorPath")
defer os.Unsetenv(vertexServiceAccountJSONEnv)
_, err := createVertexAIEmbeddingClient("https://mock_url.com", "")
s.Error(err)
}
func (s *VertexAITextEmbeddingProviderSuite) TestNewVertexAIEmbeddingProvider() { func (s *VertexAITextEmbeddingProviderSuite) TestNewVertexAIEmbeddingProvider() {
functionSchema := &schemapb.FunctionSchema{ functionSchema := &schemapb.FunctionSchema{
Name: "test", Name: "test",
@ -243,8 +237,40 @@ func (s *VertexAITextEmbeddingProviderSuite) TestNewVertexAIEmbeddingProvider()
}, },
} }
mockClient := vertexai.NewVertexAIEmbedding("mock_url", []byte{1, 2, 3}, "mock scope", "mock token") mockClient := vertexai.NewVertexAIEmbedding("mock_url", []byte{1, 2, 3}, "mock scope", "mock token")
provider, err := NewVertexAIEmbeddingProvider(s.schema.Fields[2], functionSchema, mockClient, map[string]string{}) provider, err := NewVertexAIEmbeddingProvider(s.schema.Fields[2], functionSchema, mockClient, map[string]string{}, credentials.NewCredentialsManager(map[string]string{"mock.credential_json": "mock"}))
s.NoError(err) s.NoError(err)
s.True(provider.MaxBatch() > 0) s.True(provider.MaxBatch() > 0)
s.Equal(provider.FieldDim(), int64(4)) s.Equal(provider.FieldDim(), int64(4))
} }
func (s *VertexAITextEmbeddingProviderSuite) TestParseCredentail() {
{
cred := credentials.NewCredentialsManager(map[string]string{})
data, err := parseGcpCredentialInfo(cred, []*commonpb.KeyValuePair{}, map[string]string{})
s.Nil(data)
s.ErrorContains(err, "VetexAI credentials file path is empty")
}
{
os.Setenv(vertexServiceAccountJSONEnv, "mock.json")
defer os.Unsetenv(vertexServiceAccountJSONEnv)
cred := credentials.NewCredentialsManager(map[string]string{})
data, err := parseGcpCredentialInfo(cred, []*commonpb.KeyValuePair{}, map[string]string{})
s.Nil(data)
s.ErrorContains(err, "Vertexai: read credentials file failed")
}
{
cred := credentials.NewCredentialsManager(map[string]string{})
_, err := parseGcpCredentialInfo(cred, []*commonpb.KeyValuePair{}, map[string]string{"credential": "noExist"})
s.ErrorContains(err, "is not a gcp crediential, can not find key")
}
{
cred := credentials.NewCredentialsManager(map[string]string{"mock.credential_json": "NotBase64"})
_, err := parseGcpCredentialInfo(cred, []*commonpb.KeyValuePair{}, map[string]string{"credential": "mock"})
s.ErrorContains(err, "Parse gcp credential")
}
{
cred := credentials.NewCredentialsManager(map[string]string{"mock.credential_json": "bW9jaw=="})
_, err := parseGcpCredentialInfo(cred, []*commonpb.KeyValuePair{}, map[string]string{"credential": "mock"})
s.NoError(err)
}
}

View File

@ -24,6 +24,7 @@ import (
"strings" "strings"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/util/credentials"
"github.com/milvus-io/milvus/internal/util/function/models/voyageai" "github.com/milvus-io/milvus/internal/util/function/models/voyageai"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil" "github.com/milvus-io/milvus/pkg/v2/util/typeutil"
) )
@ -44,7 +45,7 @@ type VoyageAIEmbeddingProvider struct {
func createVoyageAIEmbeddingClient(apiKey string, url string) (*voyageai.VoyageAIEmbedding, error) { func createVoyageAIEmbeddingClient(apiKey string, url string) (*voyageai.VoyageAIEmbedding, error) {
if apiKey == "" { if apiKey == "" {
return nil, fmt.Errorf("Missing credentials. Please pass `api_key`, or configure the %s environment variable in the Milvus service.", voyageAIAKEnvStr) return nil, fmt.Errorf("Missing credentials config or configure the %s environment variable in the Milvus service.", voyageAIAKEnvStr)
} }
if url == "" { if url == "" {
@ -55,12 +56,15 @@ func createVoyageAIEmbeddingClient(apiKey string, url string) (*voyageai.VoyageA
return c, nil return c, nil
} }
func NewVoyageAIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema, params map[string]string) (*VoyageAIEmbeddingProvider, error) { func NewVoyageAIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema, params map[string]string, credentials *credentials.CredentialsManager) (*VoyageAIEmbeddingProvider, error) {
fieldDim, err := typeutil.GetDim(fieldSchema) fieldDim, err := typeutil.GetDim(fieldSchema)
if err != nil { if err != nil {
return nil, err return nil, err
} }
apiKey, url := parseAKAndURL(functionSchema.Params, params, voyageAIAKEnvStr) apiKey, url, err := parseAKAndURL(credentials, functionSchema.Params, params, voyageAIAKEnvStr)
if err != nil {
return nil, err
}
var modelName string var modelName string
dim := int64(0) dim := int64(0)
truncate := false truncate := false

View File

@ -29,6 +29,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/util/credentials"
"github.com/milvus-io/milvus/internal/util/function/models/voyageai" "github.com/milvus-io/milvus/internal/util/function/models/voyageai"
) )
@ -69,14 +70,13 @@ func createVoyageAIProvider(url string, schema *schemapb.FieldSchema, providerNa
OutputFieldIds: []int64{102}, OutputFieldIds: []int64{102},
Params: []*commonpb.KeyValuePair{ Params: []*commonpb.KeyValuePair{
{Key: modelNameParamKey, Value: TestModel}, {Key: modelNameParamKey, Value: TestModel},
{Key: apiKeyParamKey, Value: "mock"}, {Key: credentialParamKey, Value: "mock"},
{Key: embeddingURLParamKey, Value: url},
{Key: dimParamKey, Value: "1024"}, {Key: dimParamKey, Value: "1024"},
}, },
} }
switch providerName { switch providerName {
case voyageAIProvider: case voyageAIProvider:
return NewVoyageAIEmbeddingProvider(schema, functionSchema, map[string]string{}) return NewVoyageAIEmbeddingProvider(schema, functionSchema, map[string]string{embeddingURLParamKey: url}, credentials.NewCredentialsManager(map[string]string{"mock.apikey": "mock"}))
default: default:
return nil, errors.New("Unknow provider") return nil, errors.New("Unknow provider")
} }
@ -293,36 +293,35 @@ func (s *VoyageAITextEmbeddingProviderSuite) TestNewVoyageAIEmbeddingProvider()
OutputFieldIds: []int64{102}, OutputFieldIds: []int64{102},
Params: []*commonpb.KeyValuePair{ Params: []*commonpb.KeyValuePair{
{Key: modelNameParamKey, Value: TestModel}, {Key: modelNameParamKey, Value: TestModel},
{Key: apiKeyParamKey, Value: "mock"}, {Key: credentialParamKey, Value: "mock"},
{Key: embeddingURLParamKey, Value: "mock"},
{Key: dimParamKey, Value: "1024"}, {Key: dimParamKey, Value: "1024"},
{Key: truncationParamKey, Value: "true"}, {Key: truncationParamKey, Value: "true"},
}, },
} }
provider, err := NewVoyageAIEmbeddingProvider(s.schema.Fields[2], functionSchema, map[string]string{}) provider, err := NewVoyageAIEmbeddingProvider(s.schema.Fields[2], functionSchema, map[string]string{embeddingURLParamKey: "mock"}, credentials.NewCredentialsManager(map[string]string{"mock.apikey": "mock"}))
s.NoError(err) s.NoError(err)
s.Equal(provider.FieldDim(), int64(1024)) s.Equal(provider.FieldDim(), int64(1024))
s.True(provider.MaxBatch() > 0) s.True(provider.MaxBatch() > 0)
// Invalid truncation // Invalid truncation
{ {
functionSchema.Params[4] = &commonpb.KeyValuePair{Key: truncationParamKey, Value: "Invalid"} functionSchema.Params[3] = &commonpb.KeyValuePair{Key: truncationParamKey, Value: "Invalid"}
_, err := NewVoyageAIEmbeddingProvider(s.schema.Fields[2], functionSchema, map[string]string{}) _, err := NewVoyageAIEmbeddingProvider(s.schema.Fields[2], functionSchema, map[string]string{}, credentials.NewCredentialsManager(map[string]string{"mock.apikey": "mock"}))
s.Error(err) s.Error(err)
functionSchema.Params[4] = &commonpb.KeyValuePair{Key: truncationParamKey, Value: "false"} functionSchema.Params[3] = &commonpb.KeyValuePair{Key: truncationParamKey, Value: "false"}
} }
// Invalid dim // Invalid dim
{ {
functionSchema.Params[3] = &commonpb.KeyValuePair{Key: dimParamKey, Value: "9"} functionSchema.Params[2] = &commonpb.KeyValuePair{Key: dimParamKey, Value: "9"}
_, err := NewVoyageAIEmbeddingProvider(s.schema.Fields[2], functionSchema, map[string]string{}) _, err := NewVoyageAIEmbeddingProvider(s.schema.Fields[2], functionSchema, map[string]string{}, credentials.NewCredentialsManager(map[string]string{"mock.apikey": "mock"}))
s.Error(err) s.Error(err)
} }
// Invalid dim type // Invalid dim type
{ {
functionSchema.Params[3] = &commonpb.KeyValuePair{Key: dimParamKey, Value: "Invalied"} functionSchema.Params[2] = &commonpb.KeyValuePair{Key: dimParamKey, Value: "Invalied"}
_, err := NewVoyageAIEmbeddingProvider(s.schema.Fields[2], functionSchema, map[string]string{}) _, err := NewVoyageAIEmbeddingProvider(s.schema.Fields[2], functionSchema, map[string]string{}, credentials.NewCredentialsManager(map[string]string{"mock.apikey": "mock"}))
s.Error(err) s.Error(err)
} }
} }

View File

@ -86,6 +86,7 @@ type ComponentParam struct {
RbacConfig rbacConfig RbacConfig rbacConfig
StreamingCfg streamingConfig StreamingCfg streamingConfig
FunctionCfg functionConfig FunctionCfg functionConfig
CredentialCfg credentialConfig
InternalTLSCfg InternalTLSConfig InternalTLSCfg InternalTLSConfig
@ -142,6 +143,7 @@ func (p *ComponentParam) init(bt *BaseTable) {
p.GpuConfig.init(bt) p.GpuConfig.init(bt)
p.KnowhereConfig.init(bt) p.KnowhereConfig.init(bt)
p.FunctionCfg.init(bt) p.FunctionCfg.init(bt)
p.CredentialCfg.init(bt)
p.InternalTLSCfg.Init(bt) p.InternalTLSCfg.Init(bt)

View File

@ -0,0 +1,48 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package paramtable
type credentialConfig struct {
Credential ParamGroup `refreshable:"true"`
}
func (p *credentialConfig) init(base *BaseTable) {
p.Credential = ParamGroup{
KeyPrefix: "credential.",
Version: "2.6.0",
Export: true,
DocFunc: func(key string) string {
switch key {
case "apikey1.apikey":
return "Your apikey credential"
case "aksk1.access_key_id":
return "Your access_key_id"
case "aksk1.secret_access_key":
return "Your secret_access_key"
case "gcp1.credential_json":
return "base64 based gcp credential data"
default:
return ""
}
},
}
p.Credential.Init(base.mgr)
}
func (p *credentialConfig) GetCredentials() map[string]string {
return p.Credential.GetValue()
}

View File

@ -0,0 +1,39 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package paramtable
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestCredentialConfig(t *testing.T) {
params := ComponentParam{}
params.Init(NewBaseTable(SkipRemote(true)))
cfg := &params.CredentialCfg
keys := []string{
"aksk1.access_key_id",
"aksk1.secret_access_key",
"apikey1.apikey",
"gcp1.credential_json",
}
for _, key := range keys {
assert.True(t, cfg.Credential.GetDoc(key) != "")
}
assert.True(t, cfg.Credential.GetDoc("Unknow") == "")
}

View File

@ -21,20 +21,10 @@ import (
) )
type functionConfig struct { type functionConfig struct {
TextEmbeddingEnableVerifiInfoInParams ParamItem `refreshable:"true"`
TextEmbeddingProviders ParamGroup `refreshable:"true"` TextEmbeddingProviders ParamGroup `refreshable:"true"`
} }
func (p *functionConfig) init(base *BaseTable) { func (p *functionConfig) init(base *BaseTable) {
p.TextEmbeddingEnableVerifiInfoInParams = ParamItem{
Key: "function.textEmbedding.enableVerifiInfoInParams",
Version: "2.6.0",
DefaultValue: "true",
Export: true,
Doc: "Controls whether to allow configuration of apikey and model service url on function parameters",
}
p.TextEmbeddingEnableVerifiInfoInParams.Init(base.mgr)
p.TextEmbeddingProviders = ParamGroup{ p.TextEmbeddingProviders = ParamGroup{
KeyPrefix: "function.textEmbedding.providers.", KeyPrefix: "function.textEmbedding.providers.",
Version: "2.6.0", Version: "2.6.0",
@ -43,40 +33,40 @@ func (p *functionConfig) init(base *BaseTable) {
switch key { switch key {
case "tei.enable": case "tei.enable":
return "Whether to enable TEI model service" return "Whether to enable TEI model service"
case "azure_openai.api_key": case "tei.credential":
return "Your azure openai embedding url, Default is the official embedding url" return "The name in the crendential configuration item"
case "azure_openai.credential":
return "The name in the crendential configuration item"
case "azure_openai.url": case "azure_openai.url":
return "Your azure openai api key" return "Your azure openai embedding url, Default is the official embedding url"
case "azure_openai.resource_name": case "azure_openai.resource_name":
return "Your azure openai resource name" return "Your azure openai resource name"
case "openai.api_key": case "openai.credential":
return "Your openai embedding url, Default is the official embedding url" return "The name in the crendential configuration item"
case "openai.url": case "openai.url":
return "Your openai api key" return "Your openai embedding url, Default is the official embedding url"
case "dashscope.api_key": case "dashscope.credential":
return "Your dashscope embedding url, Default is the official embedding url" return "The name in the crendential configuration item"
case "dashscope.url": case "dashscope.url":
return "Your dashscope api key" return "Your dashscope embedding url, Default is the official embedding url"
case "cohere.api_key": case "cohere.credential":
return "Your cohere embedding url, Default is the official embedding url" return "The name in the crendential configuration item"
case "cohere.url": case "cohere.url":
return "Your cohere api key" return "Your cohere embedding url, Default is the official embedding url"
case "voyageai.api_key": case "voyageai.credential":
return "Your voyageai embedding url, Default is the official embedding url" return "The name in the crendential configuration item"
case "voyageai.url": case "voyageai.url":
return "Your voyageai api key" return "Your voyageai embedding url, Default is the official embedding url"
case "siliconflow.url": case "siliconflow.url":
return "Your siliconflow embedding url, Default is the official embedding url" return "Your siliconflow embedding url, Default is the official embedding url"
case "siliconflow.api_key": case "siliconflow.credential":
return "Your siliconflow api key" return "The name in the crendential configuration item"
case "bedrock.aws_access_key_id": case "bedrock.credential":
return "Your aws_access_key_id" return "The name in the crendential configuration item"
case "bedrock.aws_secret_access_key":
return "Your aws_secret_access_key"
case "vertexai.url": case "vertexai.url":
return "Your VertexAI embedding url" return "Your VertexAI embedding url"
case "vertexai.credentials_file_path": case "vertexai.credential":
return "Path to your google application credentials, change the file path to refresh the configuration" return "The name in the crendential configuration item"
default: default:
return "" return ""
} }
@ -100,6 +90,5 @@ func (p *functionConfig) GetTextEmbeddingProviderConfig(providerName string) map
matchedParam[strings.TrimPrefix(k, prefix)] = v matchedParam[strings.TrimPrefix(k, prefix)] = v
} }
} }
matchedParam["enableVerifiInfoInParams"] = p.TextEmbeddingEnableVerifiInfoInParams.GetValue()
return matchedParam return matchedParam
} }

View File

@ -26,40 +26,34 @@ func TestFunctionConfig(t *testing.T) {
params := ComponentParam{} params := ComponentParam{}
params.Init(NewBaseTable(SkipRemote(true))) params.Init(NewBaseTable(SkipRemote(true)))
cfg := &params.FunctionCfg cfg := &params.FunctionCfg
notExistProvider := cfg.GetTextEmbeddingProviderConfig("notExist")
// Only has enableVerifiInfoInParams config
assert.Equal(t, len(notExistProvider), 1)
teiConf := cfg.GetTextEmbeddingProviderConfig("tei") teiConf := cfg.GetTextEmbeddingProviderConfig("tei")
assert.Equal(t, teiConf["enable"], "true") assert.Equal(t, teiConf["enable"], "true")
assert.Equal(t, teiConf["enableVerifiInfoInParams"], "true")
openaiConf := cfg.GetTextEmbeddingProviderConfig("openai") openaiConf := cfg.GetTextEmbeddingProviderConfig("openai")
assert.Equal(t, openaiConf["api_key"], "") assert.Equal(t, openaiConf["credential"], "")
assert.Equal(t, openaiConf["url"], "") assert.Equal(t, openaiConf["url"], "")
assert.Equal(t, openaiConf["enableVerifiInfoInParams"], "true")
keys := []string{ keys := []string{
"tei.enable", "tei.enable",
"azure_openai.api_key", "tei.credential",
"azure_openai.credential",
"azure_openai.url", "azure_openai.url",
"azure_openai.resource_name", "azure_openai.resource_name",
"openai.api_key", "openai.credential",
"openai.url", "openai.url",
"dashscope.api_key", "dashscope.credential",
"dashscope.url", "dashscope.url",
"cohere.api_key", "cohere.credential",
"cohere.url", "cohere.url",
"voyageai.api_key", "voyageai.credential",
"voyageai.url", "voyageai.url",
"siliconflow.url", "siliconflow.url",
"siliconflow.api_key", "siliconflow.credential",
"bedrock.aws_access_key_id", "bedrock.credential",
"bedrock.aws_secret_access_key",
"vertexai.url", "vertexai.url",
"vertexai.credentials_file_path", "vertexai.credential",
} }
for _, key := range keys { for _, key := range keys {
assert.True(t, cfg.TextEmbeddingProviders.GetDoc(key) != "") assert.True(t, cfg.TextEmbeddingProviders.GetDoc(key) != "")
} }
assert.True(t, cfg.TextEmbeddingProviders.GetDoc("Unknow") == "")
} }