mirror of
https://gitee.com/milvus-io/milvus.git
synced 2025-12-07 09:38:39 +08:00
feat: Add function config (#40534)
#35856 1. Add function-related configuration in milvus.yaml 2. Add null and empty value check to TextEmbeddingFunction Signed-off-by: junjie.jiang <junjie.jiang@zilliz.com>
This commit is contained in:
parent
16efcda5c4
commit
fe81c7baae
@ -351,6 +351,11 @@ func WriteYaml(w io.Writer) {
|
|||||||
header: `
|
header: `
|
||||||
# Any configuration related to the knowhere vector search engine`,
|
# Any configuration related to the knowhere vector search engine`,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "function",
|
||||||
|
header: `
|
||||||
|
# Any configuration related to functions`,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
marshller := YamlMarshaller{w, groups, result}
|
marshller := YamlMarshaller{w, groups, result}
|
||||||
marshller.writeYamlRecursive(lo.Filter(result, func(d DocContent, _ int) bool {
|
marshller.writeYamlRecursive(lo.Filter(result, func(d DocContent, _ int) bool {
|
||||||
|
|||||||
@ -1157,3 +1157,33 @@ knowhere:
|
|||||||
search_list_size: 100 # Size of the candidate list during building graph
|
search_list_size: 100 # Size of the candidate list during building graph
|
||||||
search:
|
search:
|
||||||
beam_width_ratio: 4 # Ratio between the maximum number of IO requests per search iteration and CPU number
|
beam_width_ratio: 4 # Ratio between the maximum number of IO requests per search iteration and CPU number
|
||||||
|
|
||||||
|
# Any configuration related to functions
|
||||||
|
function:
|
||||||
|
textEmbedding:
|
||||||
|
enableVerifiInfoInParams: true # Controls whether to allow configuration of apikey and model service url on function parameters
|
||||||
|
providers:
|
||||||
|
azure_openai:
|
||||||
|
api_key: # Your azure openai api key
|
||||||
|
resource_name: # Your azure openai resource name
|
||||||
|
url: # Your azure openai embedding url, Default is the official embedding url
|
||||||
|
bedrock:
|
||||||
|
aws_access_key_id: # Your aws_access_key_id
|
||||||
|
aws_secret_access_key: # Your aws_secret_access_key
|
||||||
|
cohere:
|
||||||
|
api_key: # Your cohere api key
|
||||||
|
url: # Your cohere embedding url, Default is the official embedding url
|
||||||
|
dashscope:
|
||||||
|
api_key: # Your dashscope api key
|
||||||
|
url: # Your dashscope embedding url, Default is the official embedding url
|
||||||
|
openai:
|
||||||
|
api_key: # Your openai api key
|
||||||
|
url: # Your openai embedding url, Default is the official embedding url
|
||||||
|
siliconflow:
|
||||||
|
api_key: # Your siliconflow api key
|
||||||
|
url: # Your siliconflow embedding url, Default is the official embedding url
|
||||||
|
tei:
|
||||||
|
enable: true # Whether to enable TEI model service
|
||||||
|
vertexai:
|
||||||
|
credentials_file_path: # Path to your google application credentials, change the file path to refresh the configuration
|
||||||
|
url: # Your VertexAI embedding url
|
||||||
|
|||||||
@ -20,7 +20,6 @@ package function
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||||
@ -41,9 +40,6 @@ type AliEmbeddingProvider struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func createAliEmbeddingClient(apiKey string, url string) (*ali.AliDashScopeEmbedding, error) {
|
func createAliEmbeddingClient(apiKey string, url string) (*ali.AliDashScopeEmbedding, error) {
|
||||||
if apiKey == "" {
|
|
||||||
apiKey = os.Getenv(dashscopeAKEnvStr)
|
|
||||||
}
|
|
||||||
if apiKey == "" {
|
if apiKey == "" {
|
||||||
return nil, fmt.Errorf("Missing credentials. Please pass `api_key`, or configure the %s environment variable in the Milvus service.", dashscopeAKEnvStr)
|
return nil, fmt.Errorf("Missing credentials. Please pass `api_key`, or configure the %s environment variable in the Milvus service.", dashscopeAKEnvStr)
|
||||||
}
|
}
|
||||||
@ -55,12 +51,12 @@ func createAliEmbeddingClient(apiKey string, url string) (*ali.AliDashScopeEmbed
|
|||||||
return c, nil
|
return c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewAliDashScopeEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema) (*AliEmbeddingProvider, error) {
|
func NewAliDashScopeEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema, params map[string]string) (*AliEmbeddingProvider, error) {
|
||||||
fieldDim, err := typeutil.GetDim(fieldSchema)
|
fieldDim, err := typeutil.GetDim(fieldSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
apiKey, url := parseAKAndURL(functionSchema.Params)
|
apiKey, url := parseAKAndURL(functionSchema.Params, params, dashscopeAKEnvStr)
|
||||||
var modelName string
|
var modelName string
|
||||||
var dim int64
|
var dim int64
|
||||||
|
|
||||||
|
|||||||
@ -23,7 +23,6 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"os"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/suite"
|
"github.com/stretchr/testify/suite"
|
||||||
@ -77,7 +76,7 @@ func createAliProvider(url string, schema *schemapb.FieldSchema, providerName st
|
|||||||
}
|
}
|
||||||
switch providerName {
|
switch providerName {
|
||||||
case aliDashScopeProvider:
|
case aliDashScopeProvider:
|
||||||
return NewAliDashScopeEmbeddingProvider(schema, functionSchema)
|
return NewAliDashScopeEmbeddingProvider(schema, functionSchema, map[string]string{})
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("Unknow provider")
|
return nil, fmt.Errorf("Unknow provider")
|
||||||
}
|
}
|
||||||
@ -170,11 +169,6 @@ func (s *AliTextEmbeddingProviderSuite) TestEmbeddingNumberNotMatch() {
|
|||||||
func (s *AliTextEmbeddingProviderSuite) TestCreateAliEmbeddingClient() {
|
func (s *AliTextEmbeddingProviderSuite) TestCreateAliEmbeddingClient() {
|
||||||
_, err := createAliEmbeddingClient("", "")
|
_, err := createAliEmbeddingClient("", "")
|
||||||
s.Error(err)
|
s.Error(err)
|
||||||
|
|
||||||
os.Setenv(dashscopeAKEnvStr, "mock_key")
|
|
||||||
defer os.Unsetenv(dashscopeAKEnvStr)
|
|
||||||
_, err = createAliEmbeddingClient("", "")
|
|
||||||
s.NoError(err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *AliTextEmbeddingProviderSuite) TestNewAliDashScopeEmbeddingProvider() {
|
func (s *AliTextEmbeddingProviderSuite) TestNewAliDashScopeEmbeddingProvider() {
|
||||||
@ -193,6 +187,6 @@ func (s *AliTextEmbeddingProviderSuite) TestNewAliDashScopeEmbeddingProvider() {
|
|||||||
}
|
}
|
||||||
// invalid dim
|
// invalid dim
|
||||||
functionSchema.Params[2] = &commonpb.KeyValuePair{Key: dimParamKey, Value: "Invalid"}
|
functionSchema.Params[2] = &commonpb.KeyValuePair{Key: dimParamKey, Value: "Invalid"}
|
||||||
_, err := NewAliDashScopeEmbeddingProvider(s.schema.Fields[2], functionSchema)
|
_, err := NewAliDashScopeEmbeddingProvider(s.schema.Fields[2], functionSchema, map[string]string{})
|
||||||
s.Error(err)
|
s.Error(err)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -30,6 +30,7 @@ import (
|
|||||||
"github.com/aws/aws-sdk-go-v2/credentials"
|
"github.com/aws/aws-sdk-go-v2/credentials"
|
||||||
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
|
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
|
||||||
|
|
||||||
|
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||||
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
|
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
|
||||||
)
|
)
|
||||||
@ -51,16 +52,9 @@ type BedrockEmbeddingProvider struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func createBedRockEmbeddingClient(awsAccessKeyId string, awsSecretAccessKey string, region string) (*bedrockruntime.Client, error) {
|
func createBedRockEmbeddingClient(awsAccessKeyId string, awsSecretAccessKey string, region string) (*bedrockruntime.Client, error) {
|
||||||
if awsAccessKeyId == "" {
|
|
||||||
awsAccessKeyId = os.Getenv(bedrockAccessKeyId)
|
|
||||||
}
|
|
||||||
if awsAccessKeyId == "" {
|
if awsAccessKeyId == "" {
|
||||||
return nil, fmt.Errorf("Missing credentials. Please pass `aws_access_key_id`, or configure the %s environment variable in the Milvus service.", bedrockAccessKeyId)
|
return nil, fmt.Errorf("Missing credentials. Please pass `aws_access_key_id`, or configure the %s environment variable in the Milvus service.", bedrockAccessKeyId)
|
||||||
}
|
}
|
||||||
|
|
||||||
if awsSecretAccessKey == "" {
|
|
||||||
awsSecretAccessKey = os.Getenv(bedrockSAKEnvStr)
|
|
||||||
}
|
|
||||||
if awsSecretAccessKey == "" {
|
if awsSecretAccessKey == "" {
|
||||||
return nil, fmt.Errorf("Missing credentials. Please pass `aws_secret_access_key`, or configure the %s environment variable in the Milvus service.", bedrockSAKEnvStr)
|
return nil, fmt.Errorf("Missing credentials. Please pass `aws_secret_access_key`, or configure the %s environment variable in the Milvus service.", bedrockSAKEnvStr)
|
||||||
}
|
}
|
||||||
@ -79,12 +73,47 @@ func createBedRockEmbeddingClient(awsAccessKeyId string, awsSecretAccessKey stri
|
|||||||
return bedrockruntime.NewFromConfig(cfg), nil
|
return bedrockruntime.NewFromConfig(cfg), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewBedrockEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema, c BedrockClient) (*BedrockEmbeddingProvider, error) {
|
func parseAccessInfo(params []*commonpb.KeyValuePair, confParams map[string]string) (string, string) {
|
||||||
|
// function param > env > yaml
|
||||||
|
var awsAccessKeyId, awsSecretAccessKey string
|
||||||
|
|
||||||
|
// from function params
|
||||||
|
if isEnableVerifiInfoInParamsKey(confParams) {
|
||||||
|
for _, param := range params {
|
||||||
|
switch strings.ToLower(param.Key) {
|
||||||
|
case awsAKIdParamKey:
|
||||||
|
awsAccessKeyId = param.Value
|
||||||
|
case awsSAKParamKey:
|
||||||
|
awsSecretAccessKey = param.Value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// from milvus.yaml
|
||||||
|
if awsAccessKeyId == "" {
|
||||||
|
awsAccessKeyId = confParams[awsAKIdParamKey]
|
||||||
|
}
|
||||||
|
if awsSecretAccessKey == "" {
|
||||||
|
awsSecretAccessKey = confParams[awsSAKParamKey]
|
||||||
|
}
|
||||||
|
|
||||||
|
// from env
|
||||||
|
if awsAccessKeyId == "" {
|
||||||
|
awsAccessKeyId = os.Getenv(bedrockAccessKeyId)
|
||||||
|
}
|
||||||
|
if awsSecretAccessKey == "" {
|
||||||
|
awsSecretAccessKey = os.Getenv(bedrockSAKEnvStr)
|
||||||
|
}
|
||||||
|
|
||||||
|
return awsAccessKeyId, awsSecretAccessKey
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewBedrockEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema, c BedrockClient, params map[string]string) (*BedrockEmbeddingProvider, error) {
|
||||||
fieldDim, err := typeutil.GetDim(fieldSchema)
|
fieldDim, err := typeutil.GetDim(fieldSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
var awsAccessKeyId, awsSecretAccessKey, region, modelName string
|
var region, modelName string
|
||||||
var dim int64
|
var dim int64
|
||||||
normalize := true
|
normalize := true
|
||||||
|
|
||||||
@ -97,14 +126,6 @@ func NewBedrockEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSche
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
case awsAKIdParamKey:
|
|
||||||
if strings.ToLower(os.Getenv(enableConfigAKAndURL)) != "false" {
|
|
||||||
awsAccessKeyId = param.Value
|
|
||||||
}
|
|
||||||
case awsSAKParamKey:
|
|
||||||
if strings.ToLower(os.Getenv(enableConfigAKAndURL)) != "false" {
|
|
||||||
awsSecretAccessKey = param.Value
|
|
||||||
}
|
|
||||||
case regionParamKey:
|
case regionParamKey:
|
||||||
region = param.Value
|
region = param.Value
|
||||||
case normalizeParamKey:
|
case normalizeParamKey:
|
||||||
@ -120,6 +141,8 @@ func NewBedrockEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSche
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
awsAccessKeyId, awsSecretAccessKey := parseAccessInfo(functionSchema.Params, params)
|
||||||
|
|
||||||
var client BedrockClient
|
var client BedrockClient
|
||||||
if c == nil {
|
if c == nil {
|
||||||
client, err = createBedRockEmbeddingClient(awsAccessKeyId, awsSecretAccessKey, region)
|
client, err = createBedRockEmbeddingClient(awsAccessKeyId, awsSecretAccessKey, region)
|
||||||
|
|||||||
@ -71,7 +71,7 @@ func createBedrockProvider(schema *schemapb.FieldSchema, providerName string, di
|
|||||||
}
|
}
|
||||||
switch providerName {
|
switch providerName {
|
||||||
case bedrockProvider:
|
case bedrockProvider:
|
||||||
return NewBedrockEmbeddingProvider(schema, functionSchema, &MockBedrockClient{dim: dim})
|
return NewBedrockEmbeddingProvider(schema, functionSchema, &MockBedrockClient{dim: dim}, map[string]string{})
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("Unknow provider")
|
return nil, fmt.Errorf("Unknow provider")
|
||||||
}
|
}
|
||||||
@ -151,22 +151,25 @@ func (s *BedrockTextEmbeddingProviderSuite) TestNewBedrockEmbeddingProvider() {
|
|||||||
{Key: normalizeParamKey, Value: "false"},
|
{Key: normalizeParamKey, Value: "false"},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
provider, err := NewBedrockEmbeddingProvider(fieldSchema, functionSchema, nil)
|
provider, err := NewBedrockEmbeddingProvider(fieldSchema, functionSchema, nil, map[string]string{})
|
||||||
s.NoError(err)
|
s.NoError(err)
|
||||||
s.True(provider.MaxBatch() > 0)
|
s.True(provider.MaxBatch() > 0)
|
||||||
s.Equal(provider.FieldDim(), int64(4))
|
s.Equal(provider.FieldDim(), int64(4))
|
||||||
|
|
||||||
|
_, err = NewBedrockEmbeddingProvider(fieldSchema, functionSchema, nil, map[string]string{awsAKIdParamKey: "mock", awsSAKParamKey: "mock"})
|
||||||
|
s.NoError(err)
|
||||||
|
|
||||||
functionSchema.Params[5] = &commonpb.KeyValuePair{Key: normalizeParamKey, Value: "true"}
|
functionSchema.Params[5] = &commonpb.KeyValuePair{Key: normalizeParamKey, Value: "true"}
|
||||||
_, err = NewBedrockEmbeddingProvider(fieldSchema, functionSchema, nil)
|
_, err = NewBedrockEmbeddingProvider(fieldSchema, functionSchema, nil, map[string]string{})
|
||||||
s.NoError(err)
|
s.NoError(err)
|
||||||
|
|
||||||
functionSchema.Params[5] = &commonpb.KeyValuePair{Key: normalizeParamKey, Value: "invalid"}
|
functionSchema.Params[5] = &commonpb.KeyValuePair{Key: normalizeParamKey, Value: "invalid"}
|
||||||
_, err = NewBedrockEmbeddingProvider(fieldSchema, functionSchema, nil)
|
_, err = NewBedrockEmbeddingProvider(fieldSchema, functionSchema, nil, map[string]string{})
|
||||||
s.Error(err)
|
s.Error(err)
|
||||||
|
|
||||||
// invalid dim
|
// invalid dim
|
||||||
functionSchema.Params[0] = &commonpb.KeyValuePair{Key: modelNameParamKey, Value: TestModel}
|
functionSchema.Params[0] = &commonpb.KeyValuePair{Key: modelNameParamKey, Value: TestModel}
|
||||||
functionSchema.Params[0] = &commonpb.KeyValuePair{Key: dimParamKey, Value: "Invalid"}
|
functionSchema.Params[0] = &commonpb.KeyValuePair{Key: dimParamKey, Value: "Invalid"}
|
||||||
_, err = NewBedrockEmbeddingProvider(fieldSchema, functionSchema, nil)
|
_, err = NewBedrockEmbeddingProvider(fieldSchema, functionSchema, nil, map[string]string{})
|
||||||
s.Error(err)
|
s.Error(err)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -20,7 +20,6 @@ package function
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||||
@ -42,9 +41,6 @@ type CohereEmbeddingProvider struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func createCohereEmbeddingClient(apiKey string, url string) (*cohere.CohereEmbedding, error) {
|
func createCohereEmbeddingClient(apiKey string, url string) (*cohere.CohereEmbedding, error) {
|
||||||
if apiKey == "" {
|
|
||||||
apiKey = os.Getenv(cohereAIAKEnvStr)
|
|
||||||
}
|
|
||||||
if apiKey == "" {
|
if apiKey == "" {
|
||||||
return nil, fmt.Errorf("Missing credentials. Please pass `api_key`, or configure the %s environment variable in the Milvus service.", cohereAIAKEnvStr)
|
return nil, fmt.Errorf("Missing credentials. Please pass `api_key`, or configure the %s environment variable in the Milvus service.", cohereAIAKEnvStr)
|
||||||
}
|
}
|
||||||
@ -57,12 +53,12 @@ func createCohereEmbeddingClient(apiKey string, url string) (*cohere.CohereEmbed
|
|||||||
return c, nil
|
return c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewCohereEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema) (*CohereEmbeddingProvider, error) {
|
func NewCohereEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema, params map[string]string) (*CohereEmbeddingProvider, error) {
|
||||||
fieldDim, err := typeutil.GetDim(fieldSchema)
|
fieldDim, err := typeutil.GetDim(fieldSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
apiKey, url := parseAKAndURL(functionSchema.Params)
|
apiKey, url := parseAKAndURL(functionSchema.Params, params, cohereAIAKEnvStr)
|
||||||
var modelName string
|
var modelName string
|
||||||
truncate := "END"
|
truncate := "END"
|
||||||
for _, param := range functionSchema.Params {
|
for _, param := range functionSchema.Params {
|
||||||
|
|||||||
@ -23,7 +23,6 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"os"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/suite"
|
"github.com/stretchr/testify/suite"
|
||||||
@ -76,7 +75,7 @@ func createCohereProvider(url string, schema *schemapb.FieldSchema, providerName
|
|||||||
}
|
}
|
||||||
switch providerName {
|
switch providerName {
|
||||||
case cohereProvider:
|
case cohereProvider:
|
||||||
return NewCohereEmbeddingProvider(schema, functionSchema)
|
return NewCohereEmbeddingProvider(schema, functionSchema, map[string]string{})
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("Unknow provider")
|
return nil, fmt.Errorf("Unknow provider")
|
||||||
}
|
}
|
||||||
@ -264,18 +263,18 @@ func (s *CohereTextEmbeddingProviderSuite) TestNewCohereProvider() {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
provider, err := NewCohereEmbeddingProvider(s.schema.Fields[2], functionSchema)
|
provider, err := NewCohereEmbeddingProvider(s.schema.Fields[2], functionSchema, map[string]string{})
|
||||||
s.NoError(err)
|
s.NoError(err)
|
||||||
s.Equal(provider.truncate, "END")
|
s.Equal(provider.truncate, "END")
|
||||||
|
|
||||||
functionSchema.Params = append(functionSchema.Params, &commonpb.KeyValuePair{Key: truncateParamKey, Value: "START"})
|
functionSchema.Params = append(functionSchema.Params, &commonpb.KeyValuePair{Key: truncateParamKey, Value: "START"})
|
||||||
provider, err = NewCohereEmbeddingProvider(s.schema.Fields[2], functionSchema)
|
provider, err = NewCohereEmbeddingProvider(s.schema.Fields[2], functionSchema, map[string]string{})
|
||||||
s.NoError(err)
|
s.NoError(err)
|
||||||
s.Equal(provider.truncate, "START")
|
s.Equal(provider.truncate, "START")
|
||||||
|
|
||||||
// Invalid truncateParam
|
// Invalid truncateParam
|
||||||
functionSchema.Params[2].Value = "Unknow"
|
functionSchema.Params[2].Value = "Unknow"
|
||||||
_, err = NewCohereEmbeddingProvider(s.schema.Fields[2], functionSchema)
|
_, err = NewCohereEmbeddingProvider(s.schema.Fields[2], functionSchema, map[string]string{})
|
||||||
s.Error(err)
|
s.Error(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -293,13 +292,13 @@ func (s *CohereTextEmbeddingProviderSuite) TestGetInputType() {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
provider, err := NewCohereEmbeddingProvider(s.schema.Fields[2], functionSchema)
|
provider, err := NewCohereEmbeddingProvider(s.schema.Fields[2], functionSchema, map[string]string{})
|
||||||
s.NoError(err)
|
s.NoError(err)
|
||||||
s.Equal(provider.getInputType(InsertMode), "")
|
s.Equal(provider.getInputType(InsertMode), "")
|
||||||
s.Equal(provider.getInputType(SearchMode), "")
|
s.Equal(provider.getInputType(SearchMode), "")
|
||||||
|
|
||||||
functionSchema.Params[0].Value = "model-v3.0"
|
functionSchema.Params[0].Value = "model-v3.0"
|
||||||
provider, err = NewCohereEmbeddingProvider(s.schema.Fields[2], functionSchema)
|
provider, err = NewCohereEmbeddingProvider(s.schema.Fields[2], functionSchema, map[string]string{})
|
||||||
s.NoError(err)
|
s.NoError(err)
|
||||||
s.Equal(provider.getInputType(InsertMode), "search_document")
|
s.Equal(provider.getInputType(InsertMode), "search_document")
|
||||||
s.Equal(provider.getInputType(SearchMode), "search_query")
|
s.Equal(provider.getInputType(SearchMode), "search_query")
|
||||||
@ -308,12 +307,6 @@ func (s *CohereTextEmbeddingProviderSuite) TestGetInputType() {
|
|||||||
func (s *CohereTextEmbeddingProviderSuite) TestCreateCohereEmbeddingClient() {
|
func (s *CohereTextEmbeddingProviderSuite) TestCreateCohereEmbeddingClient() {
|
||||||
_, err := createCohereEmbeddingClient("", "")
|
_, err := createCohereEmbeddingClient("", "")
|
||||||
s.Error(err)
|
s.Error(err)
|
||||||
|
|
||||||
os.Setenv(cohereAIAKEnvStr, "mockKey")
|
|
||||||
defer os.Unsetenv(openaiAKEnvStr)
|
|
||||||
|
|
||||||
_, err = createCohereEmbeddingClient("", "")
|
|
||||||
s.NoError(err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *CohereTextEmbeddingProviderSuite) TestRuntimeDimNotMatch() {
|
func (s *CohereTextEmbeddingProviderSuite) TestRuntimeDimNotMatch() {
|
||||||
|
|||||||
@ -45,11 +45,12 @@ const (
|
|||||||
|
|
||||||
// common params
|
// common params
|
||||||
const (
|
const (
|
||||||
modelNameParamKey string = "model_name"
|
modelNameParamKey string = "model_name"
|
||||||
dimParamKey string = "dim"
|
dimParamKey string = "dim"
|
||||||
embeddingURLParamKey string = "url"
|
embeddingURLParamKey string = "url"
|
||||||
apiKeyParamKey string = "api_key"
|
apiKeyParamKey string = "api_key"
|
||||||
truncateParamKey string = "truncate"
|
truncateParamKey string = "truncate"
|
||||||
|
enableVerifiInfoInParamsKey string = "enableVerifiInfoInParams"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ali text embedding
|
// ali text embedding
|
||||||
@ -73,7 +74,7 @@ const (
|
|||||||
const (
|
const (
|
||||||
awsAKIdParamKey string = "aws_access_key_id"
|
awsAKIdParamKey string = "aws_access_key_id"
|
||||||
awsSAKParamKey string = "aws_secret_access_key"
|
awsSAKParamKey string = "aws_secret_access_key"
|
||||||
regionParamKey string = "regin"
|
regionParamKey string = "region"
|
||||||
normalizeParamKey string = "normalize"
|
normalizeParamKey string = "normalize"
|
||||||
|
|
||||||
bedrockAccessKeyId string = "MILVUSAI_BEDROCK_ACCESS_KEY_ID"
|
bedrockAccessKeyId string = "MILVUSAI_BEDROCK_ACCESS_KEY_ID"
|
||||||
@ -120,11 +121,28 @@ const (
|
|||||||
enableTeiEnvStr string = "MILVUSAI_ENABLE_TEI"
|
enableTeiEnvStr string = "MILVUSAI_ENABLE_TEI"
|
||||||
)
|
)
|
||||||
|
|
||||||
const enableConfigAKAndURL string = "ENABLE_CONFIG_AK_AND_URL"
|
const enableVerifiInfoInParams string = "ENABLE_VERIFI_INFO_IN_PARAMS"
|
||||||
|
|
||||||
func parseAKAndURL(params []*commonpb.KeyValuePair) (string, string) {
|
func isEnableVerifiInfoInParamsKey(confParams map[string]string) bool {
|
||||||
|
enable := true
|
||||||
|
if strings.ToLower(confParams[enableVerifiInfoInParamsKey]) != "" {
|
||||||
|
// If enableVerifiInfoInParamsKey is configured in milvus.yaml, the configuration in milvus.yaml will be used.
|
||||||
|
enable, _ = strconv.ParseBool(confParams[enableVerifiInfoInParamsKey])
|
||||||
|
} else {
|
||||||
|
// If enableVerifiInfoInParamsKey is not configured in milvus.yaml, the configuration in env will be used.
|
||||||
|
if strings.ToLower(os.Getenv(enableVerifiInfoInParams)) != "" {
|
||||||
|
enable, _ = strconv.ParseBool(confParams[enableVerifiInfoInParamsKey])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return enable
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseAKAndURL(params []*commonpb.KeyValuePair, confParams map[string]string, apiKeyEnv string) (string, string) {
|
||||||
|
// function param > yaml > env
|
||||||
var apiKey, url string
|
var apiKey, url string
|
||||||
if strings.ToLower(os.Getenv(enableConfigAKAndURL)) != "false" {
|
|
||||||
|
// from function params
|
||||||
|
if isEnableVerifiInfoInParamsKey(confParams) {
|
||||||
for _, param := range params {
|
for _, param := range params {
|
||||||
switch strings.ToLower(param.Key) {
|
switch strings.ToLower(param.Key) {
|
||||||
case apiKeyParamKey:
|
case apiKeyParamKey:
|
||||||
@ -134,6 +152,20 @@ func parseAKAndURL(params []*commonpb.KeyValuePair) (string, string) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// from milvus.yaml
|
||||||
|
if apiKey == "" {
|
||||||
|
apiKey = confParams[apiKeyParamKey]
|
||||||
|
}
|
||||||
|
|
||||||
|
if url == "" {
|
||||||
|
url = confParams[embeddingURLParamKey]
|
||||||
|
}
|
||||||
|
|
||||||
|
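// from env; the url cannot be configured via env. NOTE(review): the assignment below stores the env value in url, not apiKey; this looks like a bug, confirm.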
|
||||||
|
if apiKey == "" {
|
||||||
|
url = os.Getenv(apiKeyEnv)
|
||||||
|
}
|
||||||
return apiKey, url
|
return apiKey, url
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -36,6 +36,7 @@ import (
|
|||||||
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
|
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
|
||||||
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
|
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
|
||||||
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
|
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
|
||||||
|
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestFunctionExecutor(t *testing.T) {
|
func TestFunctionExecutor(t *testing.T) {
|
||||||
@ -46,6 +47,10 @@ type FunctionExecutorSuite struct {
|
|||||||
suite.Suite
|
suite.Suite
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *FunctionExecutorSuite) SetupTest() {
|
||||||
|
paramtable.Init()
|
||||||
|
}
|
||||||
|
|
||||||
func (s *FunctionExecutorSuite) creataSchema(url string) *schemapb.CollectionSchema {
|
func (s *FunctionExecutorSuite) creataSchema(url string) *schemapb.CollectionSchema {
|
||||||
return &schemapb.CollectionSchema{
|
return &schemapb.CollectionSchema{
|
||||||
Name: "test",
|
Name: "test",
|
||||||
|
|||||||
@ -36,11 +36,11 @@ func send(req *http.Request) ([]byte, error) {
|
|||||||
|
|
||||||
body, err := io.ReadAll(resp.Body)
|
body, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("Call service faild, read response failed, errs:[%v]", err)
|
return nil, fmt.Errorf("Call service failed, read response failed, errs:[%v]", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
return nil, fmt.Errorf("Call service faild, errs:[%s, %s]", resp.Status, body)
|
return nil, fmt.Errorf("Call service failed, errs:[%s, %s]", resp.Status, body)
|
||||||
}
|
}
|
||||||
return body, nil
|
return body, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@ -41,9 +41,6 @@ type OpenAIEmbeddingProvider struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func createOpenAIEmbeddingClient(apiKey string, url string) (*openai.OpenAIEmbeddingClient, error) {
|
func createOpenAIEmbeddingClient(apiKey string, url string) (*openai.OpenAIEmbeddingClient, error) {
|
||||||
if apiKey == "" {
|
|
||||||
apiKey = os.Getenv(openaiAKEnvStr)
|
|
||||||
}
|
|
||||||
if apiKey == "" {
|
if apiKey == "" {
|
||||||
return nil, fmt.Errorf("Missing credentials. Please pass `api_key`, or configure the %s environment variable in the Milvus service.", openaiAKEnvStr)
|
return nil, fmt.Errorf("Missing credentials. Please pass `api_key`, or configure the %s environment variable in the Milvus service.", openaiAKEnvStr)
|
||||||
}
|
}
|
||||||
@ -56,16 +53,16 @@ func createOpenAIEmbeddingClient(apiKey string, url string) (*openai.OpenAIEmbed
|
|||||||
return c, nil
|
return c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func createAzureOpenAIEmbeddingClient(apiKey string, url string) (*openai.AzureOpenAIEmbeddingClient, error) {
|
func createAzureOpenAIEmbeddingClient(apiKey string, url string, resourceName string) (*openai.AzureOpenAIEmbeddingClient, error) {
|
||||||
if apiKey == "" {
|
|
||||||
apiKey = os.Getenv(azureOpenaiAKEnvStr)
|
|
||||||
}
|
|
||||||
if apiKey == "" {
|
if apiKey == "" {
|
||||||
return nil, fmt.Errorf("Missing credentials. Please pass `api_key`, or configure the %s environment variable in the Milvus service", azureOpenaiAKEnvStr)
|
return nil, fmt.Errorf("Missing credentials. Please pass `api_key`, or configure the %s environment variable in the Milvus service", azureOpenaiAKEnvStr)
|
||||||
}
|
}
|
||||||
|
|
||||||
if url == "" {
|
if url == "" {
|
||||||
if resourceName := os.Getenv(azureOpenaiResourceName); resourceName != "" {
|
if resourceName == "" {
|
||||||
|
resourceName = os.Getenv(azureOpenaiResourceName)
|
||||||
|
}
|
||||||
|
if resourceName != "" {
|
||||||
url = fmt.Sprintf("https://%s.openai.azure.com", resourceName)
|
url = fmt.Sprintf("https://%s.openai.azure.com", resourceName)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -76,15 +73,14 @@ func createAzureOpenAIEmbeddingClient(apiKey string, url string) (*openai.AzureO
|
|||||||
return c, nil
|
return c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func newOpenAIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema, isAzure bool) (*OpenAIEmbeddingProvider, error) {
|
func newOpenAIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema, params map[string]string, isAzure bool) (*OpenAIEmbeddingProvider, error) {
|
||||||
fieldDim, err := typeutil.GetDim(fieldSchema)
|
fieldDim, err := typeutil.GetDim(fieldSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
apiKey, url := parseAKAndURL(functionSchema.Params)
|
|
||||||
var modelName, user string
|
var modelName, user string
|
||||||
var dim int64
|
var dim int64
|
||||||
|
|
||||||
for _, param := range functionSchema.Params {
|
for _, param := range functionSchema.Params {
|
||||||
switch strings.ToLower(param.Key) {
|
switch strings.ToLower(param.Key) {
|
||||||
case modelNameParamKey:
|
case modelNameParamKey:
|
||||||
@ -102,12 +98,15 @@ func newOpenAIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchem
|
|||||||
|
|
||||||
var c openai.OpenAIEmbeddingInterface
|
var c openai.OpenAIEmbeddingInterface
|
||||||
if !isAzure {
|
if !isAzure {
|
||||||
|
apiKey, url := parseAKAndURL(functionSchema.Params, params, openaiAKEnvStr)
|
||||||
c, err = createOpenAIEmbeddingClient(apiKey, url)
|
c, err = createOpenAIEmbeddingClient(apiKey, url)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
c, err = createAzureOpenAIEmbeddingClient(apiKey, url)
|
apiKey, url := parseAKAndURL(functionSchema.Params, params, azureOpenaiAKEnvStr)
|
||||||
|
resourceName := params["resource_name"]
|
||||||
|
c, err = createAzureOpenAIEmbeddingClient(apiKey, url, resourceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -125,12 +124,12 @@ func newOpenAIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchem
|
|||||||
return &provider, nil
|
return &provider, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewOpenAIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema) (*OpenAIEmbeddingProvider, error) {
|
func NewOpenAIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema, params map[string]string) (*OpenAIEmbeddingProvider, error) {
|
||||||
return newOpenAIEmbeddingProvider(fieldSchema, functionSchema, false)
|
return newOpenAIEmbeddingProvider(fieldSchema, functionSchema, params, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewAzureOpenAIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema) (*OpenAIEmbeddingProvider, error) {
|
func NewAzureOpenAIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema, params map[string]string) (*OpenAIEmbeddingProvider, error) {
|
||||||
return newOpenAIEmbeddingProvider(fieldSchema, functionSchema, true)
|
return newOpenAIEmbeddingProvider(fieldSchema, functionSchema, params, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (provider *OpenAIEmbeddingProvider) MaxBatch() int {
|
func (provider *OpenAIEmbeddingProvider) MaxBatch() int {
|
||||||
|
|||||||
@ -77,9 +77,9 @@ func createOpenAIProvider(url string, schema *schemapb.FieldSchema, providerName
|
|||||||
}
|
}
|
||||||
switch providerName {
|
switch providerName {
|
||||||
case openAIProvider:
|
case openAIProvider:
|
||||||
return NewOpenAIEmbeddingProvider(schema, functionSchema)
|
return NewOpenAIEmbeddingProvider(schema, functionSchema, map[string]string{})
|
||||||
case azureOpenAIProvider:
|
case azureOpenAIProvider:
|
||||||
return NewAzureOpenAIEmbeddingProvider(schema, functionSchema)
|
return NewAzureOpenAIEmbeddingProvider(schema, functionSchema, map[string]string{})
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("Unknow provider")
|
return nil, fmt.Errorf("Unknow provider")
|
||||||
}
|
}
|
||||||
@ -181,27 +181,15 @@ func (s *OpenAITextEmbeddingProviderSuite) TestEmbeddingNubmerNotMatch() {
|
|||||||
func (s *OpenAITextEmbeddingProviderSuite) TestCreateOpenAIEmbeddingClient() {
|
func (s *OpenAITextEmbeddingProviderSuite) TestCreateOpenAIEmbeddingClient() {
|
||||||
_, err := createOpenAIEmbeddingClient("", "")
|
_, err := createOpenAIEmbeddingClient("", "")
|
||||||
s.Error(err)
|
s.Error(err)
|
||||||
|
|
||||||
os.Setenv(openaiAKEnvStr, "mockKey")
|
|
||||||
defer os.Unsetenv(openaiAKEnvStr)
|
|
||||||
|
|
||||||
_, err = createOpenAIEmbeddingClient("", "")
|
|
||||||
s.NoError(err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *OpenAITextEmbeddingProviderSuite) TestCreateAzureOpenAIEmbeddingClient() {
|
func (s *OpenAITextEmbeddingProviderSuite) TestCreateAzureOpenAIEmbeddingClient() {
|
||||||
_, err := createAzureOpenAIEmbeddingClient("", "")
|
_, err := createAzureOpenAIEmbeddingClient("", "", "")
|
||||||
s.Error(err)
|
|
||||||
|
|
||||||
os.Setenv(azureOpenaiAKEnvStr, "mockKey")
|
|
||||||
defer os.Unsetenv(azureOpenaiAKEnvStr)
|
|
||||||
|
|
||||||
_, err = createAzureOpenAIEmbeddingClient("", "")
|
|
||||||
s.Error(err)
|
s.Error(err)
|
||||||
|
|
||||||
os.Setenv(azureOpenaiResourceName, "mockResource")
|
os.Setenv(azureOpenaiResourceName, "mockResource")
|
||||||
defer os.Unsetenv(azureOpenaiResourceName)
|
defer os.Unsetenv(azureOpenaiResourceName)
|
||||||
|
|
||||||
_, err = createAzureOpenAIEmbeddingClient("", "")
|
_, err = createAzureOpenAIEmbeddingClient("mock", "", "")
|
||||||
s.NoError(err)
|
s.NoError(err)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -20,7 +20,6 @@ package function
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||||
@ -40,9 +39,6 @@ type SiliconflowEmbeddingProvider struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func createSiliconflowEmbeddingClient(apiKey string, url string) (*siliconflow.SiliconflowEmbedding, error) {
|
func createSiliconflowEmbeddingClient(apiKey string, url string) (*siliconflow.SiliconflowEmbedding, error) {
|
||||||
if apiKey == "" {
|
|
||||||
apiKey = os.Getenv(siliconflowAKEnvStr)
|
|
||||||
}
|
|
||||||
if apiKey == "" {
|
if apiKey == "" {
|
||||||
return nil, fmt.Errorf("Missing credentials. Please pass `api_key`, or configure the %s environment variable in the Milvus service.", siliconflowAKEnvStr)
|
return nil, fmt.Errorf("Missing credentials. Please pass `api_key`, or configure the %s environment variable in the Milvus service.", siliconflowAKEnvStr)
|
||||||
}
|
}
|
||||||
@ -55,12 +51,12 @@ func createSiliconflowEmbeddingClient(apiKey string, url string) (*siliconflow.S
|
|||||||
return c, nil
|
return c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSiliconflowEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema) (*SiliconflowEmbeddingProvider, error) {
|
func NewSiliconflowEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema, params map[string]string) (*SiliconflowEmbeddingProvider, error) {
|
||||||
fieldDim, err := typeutil.GetDim(fieldSchema)
|
fieldDim, err := typeutil.GetDim(fieldSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
apiKey, url := parseAKAndURL(functionSchema.Params)
|
apiKey, url := parseAKAndURL(functionSchema.Params, params, siliconflowAKEnvStr)
|
||||||
var modelName string
|
var modelName string
|
||||||
|
|
||||||
for _, param := range functionSchema.Params {
|
for _, param := range functionSchema.Params {
|
||||||
|
|||||||
@ -23,7 +23,6 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"os"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/suite"
|
"github.com/stretchr/testify/suite"
|
||||||
@ -76,7 +75,7 @@ func createSiliconflowProvider(url string, schema *schemapb.FieldSchema, provide
|
|||||||
}
|
}
|
||||||
switch providerName {
|
switch providerName {
|
||||||
case siliconflowProvider:
|
case siliconflowProvider:
|
||||||
return NewSiliconflowEmbeddingProvider(schema, functionSchema)
|
return NewSiliconflowEmbeddingProvider(schema, functionSchema, map[string]string{})
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("Unknow provider")
|
return nil, fmt.Errorf("Unknow provider")
|
||||||
}
|
}
|
||||||
@ -171,11 +170,6 @@ func (s *SiliconflowTextEmbeddingProviderSuite) TestEmbeddingNumberNotMatch() {
|
|||||||
func (s *SiliconflowTextEmbeddingProviderSuite) TestCreateSiliconflowEmbeddingClient() {
|
func (s *SiliconflowTextEmbeddingProviderSuite) TestCreateSiliconflowEmbeddingClient() {
|
||||||
_, err := createSiliconflowEmbeddingClient("", "")
|
_, err := createSiliconflowEmbeddingClient("", "")
|
||||||
s.Error(err)
|
s.Error(err)
|
||||||
|
|
||||||
os.Setenv(siliconflowAKEnvStr, "mockKey")
|
|
||||||
defer os.Unsetenv(siliconflowAKEnvStr)
|
|
||||||
_, err = createSiliconflowEmbeddingClient("", "")
|
|
||||||
s.NoError(err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SiliconflowTextEmbeddingProviderSuite) TestNewSiliconflowEmbeddingProvider() {
|
func (s *SiliconflowTextEmbeddingProviderSuite) TestNewSiliconflowEmbeddingProvider() {
|
||||||
@ -192,7 +186,7 @@ func (s *SiliconflowTextEmbeddingProviderSuite) TestNewSiliconflowEmbeddingProvi
|
|||||||
{Key: embeddingURLParamKey, Value: "mock"},
|
{Key: embeddingURLParamKey, Value: "mock"},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
provider, err := NewSiliconflowEmbeddingProvider(s.schema.Fields[2], functionSchema)
|
provider, err := NewSiliconflowEmbeddingProvider(s.schema.Fields[2], functionSchema, map[string]string{})
|
||||||
s.NoError(err)
|
s.NoError(err)
|
||||||
s.Equal(provider.FieldDim(), int64(4))
|
s.Equal(provider.FieldDim(), int64(4))
|
||||||
s.True(provider.MaxBatch() > 0)
|
s.True(provider.MaxBatch() > 0)
|
||||||
|
|||||||
@ -52,7 +52,7 @@ func createTEIEmbeddingClient(apiKey string, endpoint string) (*tei.TEIEmbedding
|
|||||||
return tei.NewTEIEmbeddingClient(apiKey, endpoint)
|
return tei.NewTEIEmbeddingClient(apiKey, endpoint)
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewTEIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema) (*TeiEmbeddingProvider, error) {
|
func NewTEIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema, params map[string]string) (*TeiEmbeddingProvider, error) {
|
||||||
fieldDim, err := typeutil.GetDim(fieldSchema)
|
fieldDim, err := typeutil.GetDim(fieldSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
@ -76,7 +76,7 @@ func createTEIProvider(url string, schema *schemapb.FieldSchema, providerName st
|
|||||||
}
|
}
|
||||||
switch providerName {
|
switch providerName {
|
||||||
case teiProvider:
|
case teiProvider:
|
||||||
return NewTEIEmbeddingProvider(schema, functionSchema)
|
return NewTEIEmbeddingProvider(schema, functionSchema, map[string]string{})
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("Unknow provider")
|
return nil, fmt.Errorf("Unknow provider")
|
||||||
}
|
}
|
||||||
@ -172,7 +172,7 @@ func (s *TEITextEmbeddingProviderSuite) TestNewTEIEmbeddingProvider() {
|
|||||||
{Key: endpointParamKey, Value: "http://mymock.com"},
|
{Key: endpointParamKey, Value: "http://mymock.com"},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
provider, err := NewTEIEmbeddingProvider(s.schema.Fields[2], functionSchema)
|
provider, err := NewTEIEmbeddingProvider(s.schema.Fields[2], functionSchema, map[string]string{})
|
||||||
s.NoError(err)
|
s.NoError(err)
|
||||||
s.Equal(provider.FieldDim(), int64(4))
|
s.Equal(provider.FieldDim(), int64(4))
|
||||||
s.True(provider.MaxBatch() == 32*5)
|
s.True(provider.MaxBatch() == 32*5)
|
||||||
@ -180,35 +180,35 @@ func (s *TEITextEmbeddingProviderSuite) TestNewTEIEmbeddingProvider() {
|
|||||||
// Invalid truncate
|
// Invalid truncate
|
||||||
{
|
{
|
||||||
functionSchema.Params = append(functionSchema.Params, &commonpb.KeyValuePair{Key: truncateParamKey, Value: "Invalid"})
|
functionSchema.Params = append(functionSchema.Params, &commonpb.KeyValuePair{Key: truncateParamKey, Value: "Invalid"})
|
||||||
_, err := NewTEIEmbeddingProvider(s.schema.Fields[2], functionSchema)
|
_, err := NewTEIEmbeddingProvider(s.schema.Fields[2], functionSchema, map[string]string{})
|
||||||
s.Error(err)
|
s.Error(err)
|
||||||
}
|
}
|
||||||
// Invalid truncationDirection
|
// Invalid truncationDirection
|
||||||
{
|
{
|
||||||
functionSchema.Params[2] = &commonpb.KeyValuePair{Key: truncateParamKey, Value: "true"}
|
functionSchema.Params[2] = &commonpb.KeyValuePair{Key: truncateParamKey, Value: "true"}
|
||||||
functionSchema.Params = append(functionSchema.Params, &commonpb.KeyValuePair{Key: truncationDirectionParamKey, Value: "Invalid"})
|
functionSchema.Params = append(functionSchema.Params, &commonpb.KeyValuePair{Key: truncationDirectionParamKey, Value: "Invalid"})
|
||||||
_, err := NewTEIEmbeddingProvider(s.schema.Fields[2], functionSchema)
|
_, err := NewTEIEmbeddingProvider(s.schema.Fields[2], functionSchema, map[string]string{})
|
||||||
s.Error(err)
|
s.Error(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// truncationDirection
|
// truncationDirection
|
||||||
{
|
{
|
||||||
functionSchema.Params[3] = &commonpb.KeyValuePair{Key: truncationDirectionParamKey, Value: "Left"}
|
functionSchema.Params[3] = &commonpb.KeyValuePair{Key: truncationDirectionParamKey, Value: "Left"}
|
||||||
_, err := NewTEIEmbeddingProvider(s.schema.Fields[2], functionSchema)
|
_, err := NewTEIEmbeddingProvider(s.schema.Fields[2], functionSchema, map[string]string{})
|
||||||
s.NoError(err)
|
s.NoError(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Invalid max batch
|
// Invalid max batch
|
||||||
{
|
{
|
||||||
functionSchema.Params = append(functionSchema.Params, &commonpb.KeyValuePair{Key: maxClientBatchSizeParamKey, Value: "Invalid"})
|
functionSchema.Params = append(functionSchema.Params, &commonpb.KeyValuePair{Key: maxClientBatchSizeParamKey, Value: "Invalid"})
|
||||||
_, err := NewTEIEmbeddingProvider(s.schema.Fields[2], functionSchema)
|
_, err := NewTEIEmbeddingProvider(s.schema.Fields[2], functionSchema, map[string]string{})
|
||||||
s.Error(err)
|
s.Error(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Valid max batch
|
// Valid max batch
|
||||||
{
|
{
|
||||||
functionSchema.Params[4] = &commonpb.KeyValuePair{Key: maxClientBatchSizeParamKey, Value: "128"}
|
functionSchema.Params[4] = &commonpb.KeyValuePair{Key: maxClientBatchSizeParamKey, Value: "128"}
|
||||||
pv, err := NewTEIEmbeddingProvider(s.schema.Fields[2], functionSchema)
|
pv, err := NewTEIEmbeddingProvider(s.schema.Fields[2], functionSchema, map[string]string{})
|
||||||
s.NoError(err)
|
s.NoError(err)
|
||||||
s.True(pv.MaxBatch() == 128*5)
|
s.True(pv.MaxBatch() == 128*5)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -27,6 +27,7 @@ import (
|
|||||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||||
"github.com/milvus-io/milvus/internal/storage"
|
"github.com/milvus-io/milvus/internal/storage"
|
||||||
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
|
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
|
||||||
|
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -45,6 +46,15 @@ const (
|
|||||||
teiProvider string = "tei"
|
teiProvider string = "tei"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func hasEmptyString(texts []string) bool {
|
||||||
|
for _, text := range texts {
|
||||||
|
if text == "" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
func TextEmbeddingOutputsCheck(fields []*schemapb.FieldSchema) error {
|
func TextEmbeddingOutputsCheck(fields []*schemapb.FieldSchema) error {
|
||||||
if len(fields) != 1 || (fields[0].DataType != schemapb.DataType_FloatVector && fields[0].DataType != schemapb.DataType_Int8Vector) {
|
if len(fields) != 1 || (fields[0].DataType != schemapb.DataType_FloatVector && fields[0].DataType != schemapb.DataType_Int8Vector) {
|
||||||
return fmt.Errorf("TextEmbedding function output field must be a FloatVector or Int8Vector field")
|
return fmt.Errorf("TextEmbedding function output field must be a FloatVector or Int8Vector field")
|
||||||
@ -85,25 +95,26 @@ func NewTextEmbeddingFunction(coll *schemapb.CollectionSchema, functionSchema *s
|
|||||||
|
|
||||||
var embP textEmbeddingProvider
|
var embP textEmbeddingProvider
|
||||||
var newProviderErr error
|
var newProviderErr error
|
||||||
|
conf := paramtable.Get().FunctionCfg.GetTextEmbeddingProviderConfig(base.provider)
|
||||||
switch base.provider {
|
switch base.provider {
|
||||||
case openAIProvider:
|
case openAIProvider:
|
||||||
embP, newProviderErr = NewOpenAIEmbeddingProvider(base.outputFields[0], functionSchema)
|
embP, newProviderErr = NewOpenAIEmbeddingProvider(base.outputFields[0], functionSchema, conf)
|
||||||
case azureOpenAIProvider:
|
case azureOpenAIProvider:
|
||||||
embP, newProviderErr = NewAzureOpenAIEmbeddingProvider(base.outputFields[0], functionSchema)
|
embP, newProviderErr = NewAzureOpenAIEmbeddingProvider(base.outputFields[0], functionSchema, conf)
|
||||||
case bedrockProvider:
|
case bedrockProvider:
|
||||||
embP, newProviderErr = NewBedrockEmbeddingProvider(base.outputFields[0], functionSchema, nil)
|
embP, newProviderErr = NewBedrockEmbeddingProvider(base.outputFields[0], functionSchema, nil, conf)
|
||||||
case aliDashScopeProvider:
|
case aliDashScopeProvider:
|
||||||
embP, newProviderErr = NewAliDashScopeEmbeddingProvider(base.outputFields[0], functionSchema)
|
embP, newProviderErr = NewAliDashScopeEmbeddingProvider(base.outputFields[0], functionSchema, conf)
|
||||||
case vertexAIProvider:
|
case vertexAIProvider:
|
||||||
embP, newProviderErr = NewVertexAIEmbeddingProvider(base.outputFields[0], functionSchema, nil)
|
embP, newProviderErr = NewVertexAIEmbeddingProvider(base.outputFields[0], functionSchema, nil, conf)
|
||||||
case voyageAIProvider:
|
case voyageAIProvider:
|
||||||
embP, newProviderErr = NewVoyageAIEmbeddingProvider(base.outputFields[0], functionSchema)
|
embP, newProviderErr = NewVoyageAIEmbeddingProvider(base.outputFields[0], functionSchema, conf)
|
||||||
case cohereProvider:
|
case cohereProvider:
|
||||||
embP, newProviderErr = NewCohereEmbeddingProvider(base.outputFields[0], functionSchema)
|
embP, newProviderErr = NewCohereEmbeddingProvider(base.outputFields[0], functionSchema, conf)
|
||||||
case siliconflowProvider:
|
case siliconflowProvider:
|
||||||
embP, newProviderErr = NewSiliconflowEmbeddingProvider(base.outputFields[0], functionSchema)
|
embP, newProviderErr = NewSiliconflowEmbeddingProvider(base.outputFields[0], functionSchema, conf)
|
||||||
case teiProvider:
|
case teiProvider:
|
||||||
embP, newProviderErr = NewTEIEmbeddingProvider(base.outputFields[0], functionSchema)
|
embP, newProviderErr = NewTEIEmbeddingProvider(base.outputFields[0], functionSchema, conf)
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("Unsupported text embedding service provider: [%s] , list of supported [%s, %s, %s, %s, %s, %s, %s, %s, %s]", base.provider, openAIProvider, azureOpenAIProvider, aliDashScopeProvider, bedrockProvider, vertexAIProvider, voyageAIProvider, cohereProvider, siliconflowProvider, teiProvider)
|
return nil, fmt.Errorf("Unsupported text embedding service provider: [%s] , list of supported [%s, %s, %s, %s, %s, %s, %s, %s, %s]", base.provider, openAIProvider, azureOpenAIProvider, aliDashScopeProvider, bedrockProvider, vertexAIProvider, voyageAIProvider, cohereProvider, siliconflowProvider, teiProvider)
|
||||||
}
|
}
|
||||||
@ -213,10 +224,16 @@ func (runner *TextEmbeddingFunction) ProcessInsert(ctx context.Context, inputs [
|
|||||||
if texts == nil {
|
if texts == nil {
|
||||||
return nil, fmt.Errorf("Input texts is empty")
|
return nil, fmt.Errorf("Input texts is empty")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// make sure all texts are not empty
|
||||||
|
if hasEmptyString(texts) {
|
||||||
|
return nil, fmt.Errorf("There is an empty string in the input data, TextEmbedding function does not support empty text")
|
||||||
|
}
|
||||||
numRows := len(texts)
|
numRows := len(texts)
|
||||||
if numRows > runner.MaxBatch() {
|
if numRows > runner.MaxBatch() {
|
||||||
return nil, fmt.Errorf("Embedding supports up to [%d] pieces of data at a time, got [%d]", runner.MaxBatch(), numRows)
|
return nil, fmt.Errorf("Embedding supports up to [%d] pieces of data at a time, got [%d]", runner.MaxBatch(), numRows)
|
||||||
}
|
}
|
||||||
|
|
||||||
embds, err := runner.embProvider.CallEmbedding(texts, InsertMode)
|
embds, err := runner.embProvider.CallEmbedding(texts, InsertMode)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -231,6 +248,10 @@ func (runner *TextEmbeddingFunction) ProcessSearch(ctx context.Context, placehol
|
|||||||
if numRows > runner.MaxBatch() {
|
if numRows > runner.MaxBatch() {
|
||||||
return nil, fmt.Errorf("Embedding supports up to [%d] pieces of data at a time, got [%d]", runner.MaxBatch(), numRows)
|
return nil, fmt.Errorf("Embedding supports up to [%d] pieces of data at a time, got [%d]", runner.MaxBatch(), numRows)
|
||||||
}
|
}
|
||||||
|
// make sure all texts are not empty
|
||||||
|
if hasEmptyString(texts) {
|
||||||
|
return nil, fmt.Errorf("There is an empty string in the queries, TextEmbedding function does not support empty text")
|
||||||
|
}
|
||||||
embds, err := runner.embProvider.CallEmbedding(texts, SearchMode)
|
embds, err := runner.embProvider.CallEmbedding(texts, SearchMode)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -257,6 +278,11 @@ func (runner *TextEmbeddingFunction) ProcessBulkInsert(inputs []storage.FieldDat
|
|||||||
return nil, fmt.Errorf("Input texts is empty")
|
return nil, fmt.Errorf("Input texts is empty")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// make sure all texts are not empty
|
||||||
|
// In storage.FieldData, null is also stored as an empty string
|
||||||
|
if hasEmptyString(texts) {
|
||||||
|
return nil, fmt.Errorf("There is an empty string in the input data, TextEmbedding function does not support empty text")
|
||||||
|
}
|
||||||
embds, err := runner.embProvider.CallEmbedding(texts, InsertMode)
|
embds, err := runner.embProvider.CallEmbedding(texts, InsertMode)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
@ -31,6 +31,7 @@ import (
|
|||||||
"github.com/milvus-io/milvus/internal/storage"
|
"github.com/milvus-io/milvus/internal/storage"
|
||||||
"github.com/milvus-io/milvus/internal/util/testutil"
|
"github.com/milvus-io/milvus/internal/util/testutil"
|
||||||
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
|
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
|
||||||
|
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestTextEmbeddingFunction(t *testing.T) {
|
func TestTextEmbeddingFunction(t *testing.T) {
|
||||||
@ -43,6 +44,7 @@ type TextEmbeddingFunctionSuite struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *TextEmbeddingFunctionSuite) SetupTest() {
|
func (s *TextEmbeddingFunctionSuite) SetupTest() {
|
||||||
|
paramtable.Init()
|
||||||
s.schema = &schemapb.CollectionSchema{
|
s.schema = &schemapb.CollectionSchema{
|
||||||
Name: "test",
|
Name: "test",
|
||||||
Fields: []*schemapb.FieldSchema{
|
Fields: []*schemapb.FieldSchema{
|
||||||
@ -272,7 +274,29 @@ func (s *TextEmbeddingFunctionSuite) TestAliEmbedding() {
|
|||||||
Scalars: &schemapb.ScalarField{
|
Scalars: &schemapb.ScalarField{
|
||||||
Data: &schemapb.ScalarField_StringData{
|
Data: &schemapb.ScalarField_StringData{
|
||||||
StringData: &schemapb.StringArray{
|
StringData: &schemapb.StringArray{
|
||||||
Data: strings.Split(strings.Repeat("Element,", 1000), ","),
|
Data: strings.Split(strings.Repeat("Element,", 1000), ",")[:999],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
data = append(data, &f)
|
||||||
|
_, err := runner.ProcessInsert(context.Background(), data)
|
||||||
|
s.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// empty string
|
||||||
|
{
|
||||||
|
data := []*schemapb.FieldData{}
|
||||||
|
f := schemapb.FieldData{
|
||||||
|
Type: schemapb.DataType_VarChar,
|
||||||
|
FieldId: 101,
|
||||||
|
IsDynamic: false,
|
||||||
|
Field: &schemapb.FieldData_Scalars{
|
||||||
|
Scalars: &schemapb.ScalarField{
|
||||||
|
Data: &schemapb.ScalarField_StringData{
|
||||||
|
StringData: &schemapb.StringArray{
|
||||||
|
Data: strings.Split(strings.Repeat("Element,", 10), ","),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@ -610,7 +634,7 @@ func (s *TextEmbeddingFunctionSuite) TestProcessSearchFloat32() {
|
|||||||
Scalars: &schemapb.ScalarField{
|
Scalars: &schemapb.ScalarField{
|
||||||
Data: &schemapb.ScalarField_StringData{
|
Data: &schemapb.ScalarField_StringData{
|
||||||
StringData: &schemapb.StringArray{
|
StringData: &schemapb.StringArray{
|
||||||
Data: strings.Split(strings.Repeat("Element,", 1000), ","),
|
Data: strings.Split(strings.Repeat("Element,", 1000), ",")[0:999],
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@ -635,7 +659,7 @@ func (s *TextEmbeddingFunctionSuite) TestProcessSearchFloat32() {
|
|||||||
Scalars: &schemapb.ScalarField{
|
Scalars: &schemapb.ScalarField{
|
||||||
Data: &schemapb.ScalarField_StringData{
|
Data: &schemapb.ScalarField_StringData{
|
||||||
StringData: &schemapb.StringArray{
|
StringData: &schemapb.StringArray{
|
||||||
Data: strings.Split(strings.Repeat("Element,", 100), ","),
|
Data: strings.Split(strings.Repeat("Element,", 100), ",")[:99],
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@ -777,7 +801,7 @@ func (s *TextEmbeddingFunctionSuite) TestProcessSearchInt8() {
|
|||||||
Scalars: &schemapb.ScalarField{
|
Scalars: &schemapb.ScalarField{
|
||||||
Data: &schemapb.ScalarField_StringData{
|
Data: &schemapb.ScalarField_StringData{
|
||||||
StringData: &schemapb.StringArray{
|
StringData: &schemapb.StringArray{
|
||||||
Data: strings.Split(strings.Repeat("Element,", 100), ","),
|
Data: strings.Split(strings.Repeat("Element,", 100), ",")[:99],
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@ -791,6 +815,31 @@ func (s *TextEmbeddingFunctionSuite) TestProcessSearchInt8() {
|
|||||||
_, err = runner.ProcessSearch(context.Background(), &placeholderGroup)
|
_, err = runner.ProcessSearch(context.Background(), &placeholderGroup)
|
||||||
s.NoError(err)
|
s.NoError(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// empty text
|
||||||
|
{
|
||||||
|
f := &schemapb.FieldData{
|
||||||
|
Type: schemapb.DataType_VarChar,
|
||||||
|
FieldId: 101,
|
||||||
|
IsDynamic: false,
|
||||||
|
Field: &schemapb.FieldData_Scalars{
|
||||||
|
Scalars: &schemapb.ScalarField{
|
||||||
|
Data: &schemapb.ScalarField_StringData{
|
||||||
|
StringData: &schemapb.StringArray{
|
||||||
|
Data: strings.Split(strings.Repeat("Element,", 100), ","),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
placeholderGroupBytes, err := funcutil.FieldDataToPlaceholderGroupBytes(f)
|
||||||
|
s.NoError(err)
|
||||||
|
placeholderGroup := commonpb.PlaceholderGroup{}
|
||||||
|
proto.Unmarshal(placeholderGroupBytes, &placeholderGroup)
|
||||||
|
_, err = runner.ProcessSearch(context.Background(), &placeholderGroup)
|
||||||
|
s.Error(err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *TextEmbeddingFunctionSuite) TestProcessBulkInsertFloat32() {
|
func (s *TextEmbeddingFunctionSuite) TestProcessBulkInsertFloat32() {
|
||||||
@ -834,6 +883,15 @@ func (s *TextEmbeddingFunctionSuite) TestProcessBulkInsertFloat32() {
|
|||||||
_, err := runner.ProcessBulkInsert(input)
|
_, err := runner.ProcessBulkInsert(input)
|
||||||
s.Error(err)
|
s.Error(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// empty texts
|
||||||
|
{
|
||||||
|
input := []storage.FieldData{data.Data[101]}
|
||||||
|
err := input[0].AppendRow("")
|
||||||
|
s.NoError(err)
|
||||||
|
_, err = runner.ProcessBulkInsert(input)
|
||||||
|
s.Error(err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *TextEmbeddingFunctionSuite) TestProcessBulkInsertInt8() {
|
func (s *TextEmbeddingFunctionSuite) TestProcessBulkInsertInt8() {
|
||||||
|
|||||||
@ -30,24 +30,40 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type vertexAIJsonKey struct {
|
type vertexAIJsonKey struct {
|
||||||
jsonKey []byte
|
mu sync.Mutex
|
||||||
once sync.Once
|
filePath string
|
||||||
initErr error
|
jsonKey []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
var vtxKey vertexAIJsonKey
|
var vtxKey vertexAIJsonKey
|
||||||
|
|
||||||
func getVertexAIJsonKey() ([]byte, error) {
|
func getVertexAIJsonKey(credentialsFilePath string) ([]byte, error) {
|
||||||
vtxKey.once.Do(func() {
|
vtxKey.mu.Lock()
|
||||||
jsonKeyPath := os.Getenv(vertexServiceAccountJSONEnv)
|
defer vtxKey.mu.Unlock()
|
||||||
jsonKey, err := os.ReadFile(jsonKeyPath)
|
|
||||||
if err != nil {
|
var jsonKeyPath string
|
||||||
vtxKey.initErr = fmt.Errorf("Vertexai: read service account json file failed, %v", err)
|
if credentialsFilePath == "" {
|
||||||
return
|
jsonKeyPath = os.Getenv(vertexServiceAccountJSONEnv)
|
||||||
}
|
} else {
|
||||||
vtxKey.jsonKey = jsonKey
|
jsonKeyPath = credentialsFilePath
|
||||||
})
|
}
|
||||||
return vtxKey.jsonKey, vtxKey.initErr
|
if jsonKeyPath == "" {
|
||||||
|
return nil, fmt.Errorf("VetexAI credentials file path is empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
if vtxKey.filePath == jsonKeyPath {
|
||||||
|
// The file path remains unchanged, using the data in the cache
|
||||||
|
return vtxKey.jsonKey, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonKey, err := os.ReadFile(jsonKeyPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("Vertexai: read credentials file failed, %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
vtxKey.jsonKey = jsonKey
|
||||||
|
vtxKey.filePath = jsonKeyPath
|
||||||
|
return vtxKey.jsonKey, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -68,8 +84,8 @@ type VertexAIEmbeddingProvider struct {
|
|||||||
timeoutSec int64
|
timeoutSec int64
|
||||||
}
|
}
|
||||||
|
|
||||||
func createVertexAIEmbeddingClient(url string) (*vertexai.VertexAIEmbedding, error) {
|
func createVertexAIEmbeddingClient(url string, credentialsFilePath string) (*vertexai.VertexAIEmbedding, error) {
|
||||||
jsonKey, err := getVertexAIJsonKey()
|
jsonKey, err := getVertexAIJsonKey(credentialsFilePath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -77,7 +93,7 @@ func createVertexAIEmbeddingClient(url string) (*vertexai.VertexAIEmbedding, err
|
|||||||
return c, nil
|
return c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewVertexAIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema, c *vertexai.VertexAIEmbedding) (*VertexAIEmbeddingProvider, error) {
|
func NewVertexAIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema, c *vertexai.VertexAIEmbedding, params map[string]string) (*VertexAIEmbeddingProvider, error) {
|
||||||
fieldDim, err := typeutil.GetDim(fieldSchema)
|
fieldDim, err := typeutil.GetDim(fieldSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -112,10 +128,13 @@ func NewVertexAIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSch
|
|||||||
location = "us-central1"
|
location = "us-central1"
|
||||||
}
|
}
|
||||||
|
|
||||||
url := fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:predict", location, projectID, location, modelName)
|
url := params["url"]
|
||||||
|
if url == "" {
|
||||||
|
url = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:predict", location, projectID, location, modelName)
|
||||||
|
}
|
||||||
var client *vertexai.VertexAIEmbedding
|
var client *vertexai.VertexAIEmbedding
|
||||||
if c == nil {
|
if c == nil {
|
||||||
client, err = createVertexAIEmbeddingClient(url)
|
client, err = createVertexAIEmbeddingClient(url, params["credentials_file_path"])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@ -76,7 +76,7 @@ func createVertexAIProvider(url string, schema *schemapb.FieldSchema) (textEmbed
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
mockClient := vertexai.NewVertexAIEmbedding(url, []byte{1, 2, 3}, "mock scope", "mock token")
|
mockClient := vertexai.NewVertexAIEmbedding(url, []byte{1, 2, 3}, "mock scope", "mock token")
|
||||||
return NewVertexAIEmbeddingProvider(schema, functionSchema, mockClient)
|
return NewVertexAIEmbeddingProvider(schema, functionSchema, mockClient, map[string]string{})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *VertexAITextEmbeddingProviderSuite) TestEmbedding() {
|
func (s *VertexAITextEmbeddingProviderSuite) TestEmbedding() {
|
||||||
@ -177,7 +177,7 @@ func (s *VertexAITextEmbeddingProviderSuite) TestEmbeddingNubmerNotMatch() {
|
|||||||
func (s *VertexAITextEmbeddingProviderSuite) TestGetVertexAIJsonKey() {
|
func (s *VertexAITextEmbeddingProviderSuite) TestGetVertexAIJsonKey() {
|
||||||
os.Setenv(vertexServiceAccountJSONEnv, "ErrorPath")
|
os.Setenv(vertexServiceAccountJSONEnv, "ErrorPath")
|
||||||
defer os.Unsetenv(vertexServiceAccountJSONEnv)
|
defer os.Unsetenv(vertexServiceAccountJSONEnv)
|
||||||
_, err := getVertexAIJsonKey()
|
_, err := getVertexAIJsonKey("")
|
||||||
s.Error(err)
|
s.Error(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -198,7 +198,7 @@ func (s *VertexAITextEmbeddingProviderSuite) TestGetTaskType() {
|
|||||||
mockClient := vertexai.NewVertexAIEmbedding("mock_url", []byte{1, 2, 3}, "mock scope", "mock token")
|
mockClient := vertexai.NewVertexAIEmbedding("mock_url", []byte{1, 2, 3}, "mock scope", "mock token")
|
||||||
|
|
||||||
{
|
{
|
||||||
provider, err := NewVertexAIEmbeddingProvider(s.schema.Fields[2], functionSchema, mockClient)
|
provider, err := NewVertexAIEmbeddingProvider(s.schema.Fields[2], functionSchema, mockClient, map[string]string{})
|
||||||
s.NoError(err)
|
s.NoError(err)
|
||||||
s.Equal(provider.getTaskType(InsertMode), "RETRIEVAL_DOCUMENT")
|
s.Equal(provider.getTaskType(InsertMode), "RETRIEVAL_DOCUMENT")
|
||||||
s.Equal(provider.getTaskType(SearchMode), "RETRIEVAL_QUERY")
|
s.Equal(provider.getTaskType(SearchMode), "RETRIEVAL_QUERY")
|
||||||
@ -206,7 +206,7 @@ func (s *VertexAITextEmbeddingProviderSuite) TestGetTaskType() {
|
|||||||
|
|
||||||
{
|
{
|
||||||
functionSchema.Params = append(functionSchema.Params, &commonpb.KeyValuePair{Key: taskTypeParamKey, Value: vertexAICodeRetrival})
|
functionSchema.Params = append(functionSchema.Params, &commonpb.KeyValuePair{Key: taskTypeParamKey, Value: vertexAICodeRetrival})
|
||||||
provider, err := NewVertexAIEmbeddingProvider(s.schema.Fields[2], functionSchema, mockClient)
|
provider, err := NewVertexAIEmbeddingProvider(s.schema.Fields[2], functionSchema, mockClient, map[string]string{})
|
||||||
s.NoError(err)
|
s.NoError(err)
|
||||||
s.Equal(provider.getTaskType(InsertMode), "RETRIEVAL_DOCUMENT")
|
s.Equal(provider.getTaskType(InsertMode), "RETRIEVAL_DOCUMENT")
|
||||||
s.Equal(provider.getTaskType(SearchMode), "CODE_RETRIEVAL_QUERY")
|
s.Equal(provider.getTaskType(SearchMode), "CODE_RETRIEVAL_QUERY")
|
||||||
@ -214,7 +214,7 @@ func (s *VertexAITextEmbeddingProviderSuite) TestGetTaskType() {
|
|||||||
|
|
||||||
{
|
{
|
||||||
functionSchema.Params[3] = &commonpb.KeyValuePair{Key: taskTypeParamKey, Value: vertexAISTS}
|
functionSchema.Params[3] = &commonpb.KeyValuePair{Key: taskTypeParamKey, Value: vertexAISTS}
|
||||||
provider, err := NewVertexAIEmbeddingProvider(s.schema.Fields[2], functionSchema, mockClient)
|
provider, err := NewVertexAIEmbeddingProvider(s.schema.Fields[2], functionSchema, mockClient, map[string]string{})
|
||||||
s.NoError(err)
|
s.NoError(err)
|
||||||
s.Equal(provider.getTaskType(InsertMode), "SEMANTIC_SIMILARITY")
|
s.Equal(provider.getTaskType(InsertMode), "SEMANTIC_SIMILARITY")
|
||||||
s.Equal(provider.getTaskType(SearchMode), "SEMANTIC_SIMILARITY")
|
s.Equal(provider.getTaskType(SearchMode), "SEMANTIC_SIMILARITY")
|
||||||
@ -224,7 +224,7 @@ func (s *VertexAITextEmbeddingProviderSuite) TestGetTaskType() {
|
|||||||
func (s *VertexAITextEmbeddingProviderSuite) TestCreateVertexAIEmbeddingClient() {
|
func (s *VertexAITextEmbeddingProviderSuite) TestCreateVertexAIEmbeddingClient() {
|
||||||
os.Setenv(vertexServiceAccountJSONEnv, "ErrorPath")
|
os.Setenv(vertexServiceAccountJSONEnv, "ErrorPath")
|
||||||
defer os.Unsetenv(vertexServiceAccountJSONEnv)
|
defer os.Unsetenv(vertexServiceAccountJSONEnv)
|
||||||
_, err := createVertexAIEmbeddingClient("https://mock_url.com")
|
_, err := createVertexAIEmbeddingClient("https://mock_url.com", "")
|
||||||
s.Error(err)
|
s.Error(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -243,7 +243,7 @@ func (s *VertexAITextEmbeddingProviderSuite) TestNewVertexAIEmbeddingProvider()
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
mockClient := vertexai.NewVertexAIEmbedding("mock_url", []byte{1, 2, 3}, "mock scope", "mock token")
|
mockClient := vertexai.NewVertexAIEmbedding("mock_url", []byte{1, 2, 3}, "mock scope", "mock token")
|
||||||
provider, err := NewVertexAIEmbeddingProvider(s.schema.Fields[2], functionSchema, mockClient)
|
provider, err := NewVertexAIEmbeddingProvider(s.schema.Fields[2], functionSchema, mockClient, map[string]string{})
|
||||||
s.NoError(err)
|
s.NoError(err)
|
||||||
s.True(provider.MaxBatch() > 0)
|
s.True(provider.MaxBatch() > 0)
|
||||||
s.Equal(provider.FieldDim(), int64(4))
|
s.Equal(provider.FieldDim(), int64(4))
|
||||||
|
|||||||
@ -20,7 +20,6 @@ package function
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@ -44,9 +43,6 @@ type VoyageAIEmbeddingProvider struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func createVoyageAIEmbeddingClient(apiKey string, url string) (*voyageai.VoyageAIEmbedding, error) {
|
func createVoyageAIEmbeddingClient(apiKey string, url string) (*voyageai.VoyageAIEmbedding, error) {
|
||||||
if apiKey == "" {
|
|
||||||
apiKey = os.Getenv(voyageAIAKEnvStr)
|
|
||||||
}
|
|
||||||
if apiKey == "" {
|
if apiKey == "" {
|
||||||
return nil, fmt.Errorf("Missing credentials. Please pass `api_key`, or configure the %s environment variable in the Milvus service.", voyageAIAKEnvStr)
|
return nil, fmt.Errorf("Missing credentials. Please pass `api_key`, or configure the %s environment variable in the Milvus service.", voyageAIAKEnvStr)
|
||||||
}
|
}
|
||||||
@ -59,12 +55,12 @@ func createVoyageAIEmbeddingClient(apiKey string, url string) (*voyageai.VoyageA
|
|||||||
return c, nil
|
return c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewVoyageAIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema) (*VoyageAIEmbeddingProvider, error) {
|
func NewVoyageAIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema, params map[string]string) (*VoyageAIEmbeddingProvider, error) {
|
||||||
fieldDim, err := typeutil.GetDim(fieldSchema)
|
fieldDim, err := typeutil.GetDim(fieldSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
apiKey, url := parseAKAndURL(functionSchema.Params)
|
apiKey, url := parseAKAndURL(functionSchema.Params, params, voyageAIAKEnvStr)
|
||||||
var modelName string
|
var modelName string
|
||||||
dim := int64(0)
|
dim := int64(0)
|
||||||
truncate := false
|
truncate := false
|
||||||
|
|||||||
@ -23,7 +23,6 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"os"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/suite"
|
"github.com/stretchr/testify/suite"
|
||||||
@ -77,7 +76,7 @@ func createVoyageAIProvider(url string, schema *schemapb.FieldSchema, providerNa
|
|||||||
}
|
}
|
||||||
switch providerName {
|
switch providerName {
|
||||||
case voyageAIProvider:
|
case voyageAIProvider:
|
||||||
return NewVoyageAIEmbeddingProvider(schema, functionSchema)
|
return NewVoyageAIEmbeddingProvider(schema, functionSchema, map[string]string{})
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("Unknow provider")
|
return nil, fmt.Errorf("Unknow provider")
|
||||||
}
|
}
|
||||||
@ -282,11 +281,6 @@ func (s *VoyageAITextEmbeddingProviderSuite) TestEmbeddingNumberNotMatch() {
|
|||||||
func (s *VoyageAITextEmbeddingProviderSuite) TestCreateVoyageAIEmbeddingClient() {
|
func (s *VoyageAITextEmbeddingProviderSuite) TestCreateVoyageAIEmbeddingClient() {
|
||||||
_, err := createVoyageAIEmbeddingClient("", "")
|
_, err := createVoyageAIEmbeddingClient("", "")
|
||||||
s.Error(err)
|
s.Error(err)
|
||||||
|
|
||||||
os.Setenv(voyageAIAKEnvStr, "mockKey")
|
|
||||||
defer os.Unsetenv(voyageAIAKEnvStr)
|
|
||||||
_, err = createVoyageAIEmbeddingClient("", "")
|
|
||||||
s.NoError(err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *VoyageAITextEmbeddingProviderSuite) TestNewVoyageAIEmbeddingProvider() {
|
func (s *VoyageAITextEmbeddingProviderSuite) TestNewVoyageAIEmbeddingProvider() {
|
||||||
@ -305,7 +299,7 @@ func (s *VoyageAITextEmbeddingProviderSuite) TestNewVoyageAIEmbeddingProvider()
|
|||||||
{Key: truncationParamKey, Value: "true"},
|
{Key: truncationParamKey, Value: "true"},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
provider, err := NewVoyageAIEmbeddingProvider(s.schema.Fields[2], functionSchema)
|
provider, err := NewVoyageAIEmbeddingProvider(s.schema.Fields[2], functionSchema, map[string]string{})
|
||||||
s.NoError(err)
|
s.NoError(err)
|
||||||
s.Equal(provider.FieldDim(), int64(1024))
|
s.Equal(provider.FieldDim(), int64(1024))
|
||||||
s.True(provider.MaxBatch() > 0)
|
s.True(provider.MaxBatch() > 0)
|
||||||
@ -313,7 +307,7 @@ func (s *VoyageAITextEmbeddingProviderSuite) TestNewVoyageAIEmbeddingProvider()
|
|||||||
// Invalid truncation
|
// Invalid truncation
|
||||||
{
|
{
|
||||||
functionSchema.Params[4] = &commonpb.KeyValuePair{Key: truncationParamKey, Value: "Invalid"}
|
functionSchema.Params[4] = &commonpb.KeyValuePair{Key: truncationParamKey, Value: "Invalid"}
|
||||||
_, err := NewVoyageAIEmbeddingProvider(s.schema.Fields[2], functionSchema)
|
_, err := NewVoyageAIEmbeddingProvider(s.schema.Fields[2], functionSchema, map[string]string{})
|
||||||
s.Error(err)
|
s.Error(err)
|
||||||
functionSchema.Params[4] = &commonpb.KeyValuePair{Key: truncationParamKey, Value: "false"}
|
functionSchema.Params[4] = &commonpb.KeyValuePair{Key: truncationParamKey, Value: "false"}
|
||||||
}
|
}
|
||||||
@ -321,14 +315,14 @@ func (s *VoyageAITextEmbeddingProviderSuite) TestNewVoyageAIEmbeddingProvider()
|
|||||||
// Invalid dim
|
// Invalid dim
|
||||||
{
|
{
|
||||||
functionSchema.Params[3] = &commonpb.KeyValuePair{Key: dimParamKey, Value: "9"}
|
functionSchema.Params[3] = &commonpb.KeyValuePair{Key: dimParamKey, Value: "9"}
|
||||||
_, err := NewVoyageAIEmbeddingProvider(s.schema.Fields[2], functionSchema)
|
_, err := NewVoyageAIEmbeddingProvider(s.schema.Fields[2], functionSchema, map[string]string{})
|
||||||
s.Error(err)
|
s.Error(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Invalid dim type
|
// Invalid dim type
|
||||||
{
|
{
|
||||||
functionSchema.Params[3] = &commonpb.KeyValuePair{Key: dimParamKey, Value: "Invalied"}
|
functionSchema.Params[3] = &commonpb.KeyValuePair{Key: dimParamKey, Value: "Invalied"}
|
||||||
_, err := NewVoyageAIEmbeddingProvider(s.schema.Fields[2], functionSchema)
|
_, err := NewVoyageAIEmbeddingProvider(s.schema.Fields[2], functionSchema, map[string]string{})
|
||||||
s.Error(err)
|
s.Error(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -84,6 +84,7 @@ type ComponentParam struct {
|
|||||||
RoleCfg roleConfig
|
RoleCfg roleConfig
|
||||||
RbacConfig rbacConfig
|
RbacConfig rbacConfig
|
||||||
StreamingCfg streamingConfig
|
StreamingCfg streamingConfig
|
||||||
|
FunctionCfg functionConfig
|
||||||
|
|
||||||
InternalTLSCfg InternalTLSConfig
|
InternalTLSCfg InternalTLSConfig
|
||||||
|
|
||||||
@ -138,6 +139,7 @@ func (p *ComponentParam) init(bt *BaseTable) {
|
|||||||
p.RbacConfig.init(bt)
|
p.RbacConfig.init(bt)
|
||||||
p.GpuConfig.init(bt)
|
p.GpuConfig.init(bt)
|
||||||
p.KnowhereConfig.init(bt)
|
p.KnowhereConfig.init(bt)
|
||||||
|
p.FunctionCfg.init(bt)
|
||||||
|
|
||||||
p.InternalTLSCfg.Init(bt)
|
p.InternalTLSCfg.Init(bt)
|
||||||
|
|
||||||
|
|||||||
105
pkg/util/paramtable/function_param.go
Normal file
105
pkg/util/paramtable/function_param.go
Normal file
@ -0,0 +1,105 @@
|
|||||||
|
// Licensed to the LF AI & Data foundation under one
|
||||||
|
// or more contributor license agreements. See the NOTICE file
|
||||||
|
// distributed with this work for additional information
|
||||||
|
// regarding copyright ownership. The ASF licenses this file
|
||||||
|
// to you under the Apache License, Version 2.0 (the
|
||||||
|
// "License"); you may not use this file except in compliance
|
||||||
|
// with the License. You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package paramtable
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type functionConfig struct {
|
||||||
|
TextEmbeddingEnableVerifiInfoInParams ParamItem `refreshable:"true"`
|
||||||
|
TextEmbeddingProviders ParamGroup `refreshable:"true"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *functionConfig) init(base *BaseTable) {
|
||||||
|
p.TextEmbeddingEnableVerifiInfoInParams = ParamItem{
|
||||||
|
Key: "function.textEmbedding.enableVerifiInfoInParams",
|
||||||
|
Version: "2.6.0",
|
||||||
|
DefaultValue: "true",
|
||||||
|
Export: true,
|
||||||
|
Doc: "Controls whether to allow configuration of apikey and model service url on function parameters",
|
||||||
|
}
|
||||||
|
p.TextEmbeddingEnableVerifiInfoInParams.Init(base.mgr)
|
||||||
|
|
||||||
|
p.TextEmbeddingProviders = ParamGroup{
|
||||||
|
KeyPrefix: "function.textEmbedding.providers.",
|
||||||
|
Version: "2.6.0",
|
||||||
|
Export: true,
|
||||||
|
DocFunc: func(key string) string {
|
||||||
|
switch key {
|
||||||
|
case "tei.enable":
|
||||||
|
return "Whether to enable TEI model service"
|
||||||
|
case "azure_openai.api_key":
|
||||||
|
return "Your azure openai embedding url, Default is the official embedding url"
|
||||||
|
case "azure_openai.url":
|
||||||
|
return "Your azure openai api key"
|
||||||
|
case "azure_openai.resource_name":
|
||||||
|
return "Your azure openai resource name"
|
||||||
|
case "openai.api_key":
|
||||||
|
return "Your openai embedding url, Default is the official embedding url"
|
||||||
|
case "openai.url":
|
||||||
|
return "Your openai api key"
|
||||||
|
case "dashscope.api_key":
|
||||||
|
return "Your dashscope embedding url, Default is the official embedding url"
|
||||||
|
case "dashscope.url":
|
||||||
|
return "Your dashscope api key"
|
||||||
|
case "cohere.api_key":
|
||||||
|
return "Your cohere embedding url, Default is the official embedding url"
|
||||||
|
case "cohere.url":
|
||||||
|
return "Your cohere api key"
|
||||||
|
case "voyageai.api_key":
|
||||||
|
return "Your voyageai embedding url, Default is the official embedding url"
|
||||||
|
case "voyageai.url":
|
||||||
|
return "Your voyageai api key"
|
||||||
|
case "siliconflow.url":
|
||||||
|
return "Your siliconflow embedding url, Default is the official embedding url"
|
||||||
|
case "siliconflow.api_key":
|
||||||
|
return "Your siliconflow api key"
|
||||||
|
case "bedrock.aws_access_key_id":
|
||||||
|
return "Your aws_access_key_id"
|
||||||
|
case "bedrock.aws_secret_access_key":
|
||||||
|
return "Your aws_secret_access_key"
|
||||||
|
case "vertexai.url":
|
||||||
|
return "Your VertexAI embedding url"
|
||||||
|
case "vertexai.credentials_file_path":
|
||||||
|
return "Path to your google application credentials, change the file path to refresh the configuration"
|
||||||
|
default:
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
p.TextEmbeddingProviders.Init(base.mgr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// textEmbeddingKey is the name of the text-embedding configuration section.
const textEmbeddingKey string = "textEmbedding"
|
||||||
|
|
||||||
|
func (p *functionConfig) GetTextEmbeddingProviderConfig(providerName string) map[string]string {
|
||||||
|
matchedParam := make(map[string]string)
|
||||||
|
|
||||||
|
params := p.TextEmbeddingProviders.GetValue()
|
||||||
|
prefix := providerName + "."
|
||||||
|
|
||||||
|
for k, v := range params {
|
||||||
|
if strings.HasPrefix(k, prefix) {
|
||||||
|
matchedParam[strings.TrimPrefix(k, prefix)] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
matchedParam["enableVerifiInfoInParams"] = p.TextEmbeddingEnableVerifiInfoInParams.GetValue()
|
||||||
|
return matchedParam
|
||||||
|
}
|
||||||
65
pkg/util/paramtable/function_param_test.go
Normal file
65
pkg/util/paramtable/function_param_test.go
Normal file
@ -0,0 +1,65 @@
|
|||||||
|
// Licensed to the LF AI & Data foundation under one
|
||||||
|
// or more contributor license agreements. See the NOTICE file
|
||||||
|
// distributed with this work for additional information
|
||||||
|
// regarding copyright ownership. The ASF licenses this file
|
||||||
|
// to you under the Apache License, Version 2.0 (the
|
||||||
|
// "License"); you may not use this file except in compliance
|
||||||
|
// with the License. You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package paramtable
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestFunctionConfig(t *testing.T) {
|
||||||
|
params := ComponentParam{}
|
||||||
|
params.Init(NewBaseTable(SkipRemote(true)))
|
||||||
|
cfg := ¶ms.FunctionCfg
|
||||||
|
notExistProvider := cfg.GetTextEmbeddingProviderConfig("notExist")
|
||||||
|
|
||||||
|
// Only has enableVerifiInfoInParams config
|
||||||
|
assert.Equal(t, len(notExistProvider), 1)
|
||||||
|
|
||||||
|
teiConf := cfg.GetTextEmbeddingProviderConfig("tei")
|
||||||
|
assert.Equal(t, teiConf["enable"], "true")
|
||||||
|
assert.Equal(t, teiConf["enableVerifiInfoInParams"], "true")
|
||||||
|
openaiConf := cfg.GetTextEmbeddingProviderConfig("openai")
|
||||||
|
assert.Equal(t, openaiConf["api_key"], "")
|
||||||
|
assert.Equal(t, openaiConf["url"], "")
|
||||||
|
assert.Equal(t, openaiConf["enableVerifiInfoInParams"], "true")
|
||||||
|
|
||||||
|
keys := []string{
|
||||||
|
"tei.enable",
|
||||||
|
"azure_openai.api_key",
|
||||||
|
"azure_openai.url",
|
||||||
|
"azure_openai.resource_name",
|
||||||
|
"openai.api_key",
|
||||||
|
"openai.url",
|
||||||
|
"dashscope.api_key",
|
||||||
|
"dashscope.url",
|
||||||
|
"cohere.api_key",
|
||||||
|
"cohere.url",
|
||||||
|
"voyageai.api_key",
|
||||||
|
"voyageai.url",
|
||||||
|
"siliconflow.url",
|
||||||
|
"siliconflow.api_key",
|
||||||
|
"bedrock.aws_access_key_id",
|
||||||
|
"bedrock.aws_secret_access_key",
|
||||||
|
"vertexai.url",
|
||||||
|
"vertexai.credentials_file_path",
|
||||||
|
}
|
||||||
|
for _, key := range keys {
|
||||||
|
assert.True(t, cfg.TextEmbeddingProviders.GetDoc(key) != "")
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -346,6 +346,7 @@ class TestInsertWithTextEmbeddingNegative(TestcaseBase):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
@pytest.mark.tags(CaseLabel.L1)
|
@pytest.mark.tags(CaseLabel.L1)
|
||||||
|
@pytest.mark.skip("not support empty document now")
|
||||||
def test_insert_with_text_embedding_empty_document(self, tei_endpoint):
|
def test_insert_with_text_embedding_empty_document(self, tei_endpoint):
|
||||||
"""
|
"""
|
||||||
target: test insert data with empty document
|
target: test insert data with empty document
|
||||||
@ -389,6 +390,7 @@ class TestInsertWithTextEmbeddingNegative(TestcaseBase):
|
|||||||
assert collection_w.num_entities == 0
|
assert collection_w.num_entities == 0
|
||||||
|
|
||||||
@pytest.mark.tags(CaseLabel.L1)
|
@pytest.mark.tags(CaseLabel.L1)
|
||||||
|
@pytest.mark.skip("TODO")
|
||||||
def test_insert_with_text_embedding_long_document(self, tei_endpoint):
|
def test_insert_with_text_embedding_long_document(self, tei_endpoint):
|
||||||
"""
|
"""
|
||||||
target: test insert data with long document
|
target: test insert data with long document
|
||||||
@ -663,6 +665,7 @@ class TestSearchWithTextEmbeddingNegative(TestcaseBase):
|
|||||||
|
|
||||||
@pytest.mark.tags(CaseLabel.L1)
|
@pytest.mark.tags(CaseLabel.L1)
|
||||||
@pytest.mark.parametrize("query", ["empty_query", "long_query"])
|
@pytest.mark.parametrize("query", ["empty_query", "long_query"])
|
||||||
|
@pytest.mark.skip("not support empty query now")
|
||||||
def test_search_with_text_embedding_negative_query(self, query, tei_endpoint):
|
def test_search_with_text_embedding_negative_query(self, query, tei_endpoint):
|
||||||
"""
|
"""
|
||||||
target: test search with empty query or long query
|
target: test search with empty query or long query
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user