diff --git a/internal/proxy/task_insert.go b/internal/proxy/task_insert.go index e610a6894b..dd91cac2ee 100644 --- a/internal/proxy/task_insert.go +++ b/internal/proxy/task_insert.go @@ -163,13 +163,17 @@ func (it *insertTask) PreExecute(ctx context.Context) error { // Calculate embedding fields if function.HasNonBM25Functions(schema.CollectionSchema.Functions, []int64{}) { + ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-Insert-call-function-udf") + defer sp.End() exec, err := function.NewFunctionExecutor(schema.CollectionSchema) if err != nil { return err } - if err := exec.ProcessInsert(it.insertMsg); err != nil { + sp.AddEvent("Create-function-udf") + if err := exec.ProcessInsert(ctx, it.insertMsg); err != nil { return err } + sp.AddEvent("Call-function-udf") } rowNums := uint32(it.insertMsg.NRows()) // set insertTask.rowIDs diff --git a/internal/proxy/task_search.go b/internal/proxy/task_search.go index f60cdaf35b..7a09a3a8a8 100644 --- a/internal/proxy/task_search.go +++ b/internal/proxy/task_search.go @@ -427,13 +427,17 @@ func (t *searchTask) initAdvancedSearchRequest(ctx context.Context) error { var err error if function.HasNonBM25Functions(t.schema.CollectionSchema.Functions, queryFieldIds) { + ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-AdvancedSearch-call-function-udf") + defer sp.End() exec, err := function.NewFunctionExecutor(t.schema.CollectionSchema) if err != nil { return err } - if err := exec.ProcessSearch(t.SearchRequest); err != nil { + sp.AddEvent("Create-function-udf") + if err := exec.ProcessSearch(ctx, t.SearchRequest); err != nil { return err } + sp.AddEvent("Call-function-udf") } t.SearchRequest.GroupByFieldId = t.rankParams.GetGroupByFieldId() @@ -516,13 +520,17 @@ func (t *searchTask) initSearchRequest(ctx context.Context) error { t.SearchRequest.GroupSize = queryInfo.GroupSize if function.HasNonBM25Functions(t.schema.CollectionSchema.Functions, []int64{queryInfo.GetQueryFieldId()}) { + ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-Search-call-function-udf") + defer sp.End() exec, err := function.NewFunctionExecutor(t.schema.CollectionSchema) if err != nil { return err } - if err := exec.ProcessSearch(t.SearchRequest); err != nil { + sp.AddEvent("Create-function-udf") + if err := exec.ProcessSearch(ctx, t.SearchRequest); err != nil { return err } + sp.AddEvent("Call-function-udf") } log.Debug("proxy init search request", zap.Int64s("plan.OutputFieldIds", plan.GetOutputFieldIds()), diff --git a/internal/proxy/task_test.go b/internal/proxy/task_test.go index f663b99b36..49305f7759 100644 --- a/internal/proxy/task_test.go +++ b/internal/proxy/task_test.go @@ -20,7 +20,6 @@ import ( "bytes" "context" "encoding/binary" - "fmt" "math/rand" "strconv" "testing" @@ -40,6 +39,7 @@ import ( "github.com/milvus-io/milvus/internal/json" "github.com/milvus-io/milvus/internal/mocks" "github.com/milvus-io/milvus/internal/querycoordv2/meta" + "github.com/milvus-io/milvus/internal/util/function" "github.com/milvus-io/milvus/pkg/v2/common" "github.com/milvus-io/milvus/pkg/v2/mq/msgstream" "github.com/milvus-io/milvus/pkg/v2/proto/indexpb" @@ -1025,7 +1025,8 @@ func TestCreateCollectionTask(t *testing.T) { }) t.Run("collection with embedding function ", func(t *testing.T) { - fmt.Println(schema) + ts := function.CreateOpenAIEmbeddingServer() + defer ts.Close() schema.Functions = []*schemapb.FunctionSchema{ { Name: "test", @@ -1036,6 +1037,8 @@ func TestCreateCollectionTask(t *testing.T) { {Key: "provider", Value: "openai"}, {Key: "model_name", Value: "text-embedding-ada-002"}, {Key: "api_key", Value: "mock"}, + {Key: "url", Value: ts.URL}, + {Key: "dim", Value: "128"}, }, }, } diff --git a/internal/proxy/task_upsert.go b/internal/proxy/task_upsert.go index 537ac7a2ae..75552c7f4b 100644 --- a/internal/proxy/task_upsert.go +++ b/internal/proxy/task_upsert.go @@ -148,6 +148,8 @@ func (it *upsertTask) OnEnqueue() error { } func (it *upsertTask) insertPreExecute(ctx context.Context) error { + ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-Upsert-insertPreExecute") + defer sp.End() collectionName := it.upsertMsg.InsertMsg.CollectionName if err := validateCollectionName(collectionName); err != nil { log.Ctx(ctx).Error("valid collection name failed", zap.String("collectionName", collectionName), zap.Error(err)) @@ -156,13 +158,17 @@ func (it *upsertTask) insertPreExecute(ctx context.Context) error { // Calculate embedding fields if function.HasNonBM25Functions(it.schema.CollectionSchema.Functions, []int64{}) { + ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-Proxy-Upsert-insertPreExecute-call-function-udf") + defer sp.End() exec, err := function.NewFunctionExecutor(it.schema.CollectionSchema) if err != nil { return err } - if err := exec.ProcessInsert(it.upsertMsg.InsertMsg); err != nil { + sp.AddEvent("Create-function-udf") + if err := exec.ProcessInsert(ctx, it.upsertMsg.InsertMsg); err != nil { return err } + sp.AddEvent("Call-function-udf") } rowNums := uint32(it.upsertMsg.InsertMsg.NRows()) // set upsertTask.insertRequest.rowIDs diff --git a/internal/proxy/util.go b/internal/proxy/util.go index fc53ed47c1..fa8ddce39e 100644 --- a/internal/proxy/util.go +++ b/internal/proxy/util.go @@ -720,8 +720,8 @@ func validateFunction(coll *schemapb.CollectionSchema) error { return nil } -func checkFunctionOutputField(function *schemapb.FunctionSchema, fields []*schemapb.FieldSchema) error { - switch function.GetType() { +func checkFunctionOutputField(fSchema *schemapb.FunctionSchema, fields []*schemapb.FieldSchema) error { + switch fSchema.GetType() { case schemapb.FunctionType_BM25: if len(fields) != 1 { return fmt.Errorf("BM25 function only need 1 output field, but got %d", len(fields)) @@ -731,8 +731,8 @@ func checkFunctionOutputField(function *schemapb.FunctionSchema, fields []*schem return fmt.Errorf("BM25 function output field must be a SparseFloatVector field, but got %s", fields[0].DataType.String()) } case schemapb.FunctionType_TextEmbedding: - if len(fields) != 1 || (fields[0].DataType != schemapb.DataType_FloatVector && fields[0].DataType != schemapb.DataType_Int8Vector) { - return fmt.Errorf("TextEmbedding function output field must be a FloatVector or Int8Vector field") + if err := function.TextEmbeddingOutputsCheck(fields); err != nil { + return err } default: return fmt.Errorf("check output field for unknown function type") diff --git a/internal/proxy/util_test.go b/internal/proxy/util_test.go index 5449b1ff80..4eaebafb70 100644 --- a/internal/proxy/util_test.go +++ b/internal/proxy/util_test.go @@ -37,6 +37,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/json" "github.com/milvus-io/milvus/internal/mocks" + "github.com/milvus-io/milvus/internal/util/function" "github.com/milvus-io/milvus/pkg/v2/common" "github.com/milvus-io/milvus/pkg/v2/log" "github.com/milvus-io/milvus/pkg/v2/mq/msgstream" @@ -2850,6 +2851,8 @@ func TestValidateFunction(t *testing.T) { func TestValidateModelFunction(t *testing.T) { t.Run("Valid model function schema", func(t *testing.T) { + ts := function.CreateOpenAIEmbeddingServer() + defer ts.Close() schema := &schemapb.CollectionSchema{ Fields: []*schemapb.FieldSchema{ {Name: "input_field", DataType: schemapb.DataType_VarChar, TypeParams: []*commonpb.KeyValuePair{{Key: "enable_analyzer", Value: "true"}}}, @@ -2877,7 +2880,7 @@ func TestValidateModelFunction(t *testing.T) { {Key: "provider", Value: "openai"}, {Key: "model_name", Value: "text-embedding-ada-002"}, {Key: "api_key", Value: "mock"}, - {Key: "url", Value: "mock_url"}, + {Key: "url", Value: ts.URL}, {Key: "dim", Value: "4"}, }, }, diff --git a/internal/util/function/ali_embedding_provider.go b/internal/util/function/ali_embedding_provider.go index 5afa262845..7866ad795f 100644 --- a/internal/util/function/ali_embedding_provider.go +++ b/internal/util/function/ali_embedding_provider.go @@ -60,7 +60,8 @@ func NewAliDashScopeEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functio if err != nil { return nil, err } - var apiKey, url, modelName string + apiKey, url := parseAKAndURL(functionSchema.Params) + var modelName string var dim int64 for _, param := range functionSchema.Params { @@ -72,26 +73,17 @@ func NewAliDashScopeEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functio if err != nil { return nil, err } - case apiKeyParamKey: - apiKey = param.Value - case embeddingURLParamKey: - url = param.Value default: } } - if modelName != TextEmbeddingV1 && modelName != TextEmbeddingV2 && modelName != TextEmbeddingV3 { - return nil, fmt.Errorf("Unsupported model: %s, only support [%s, %s, %s]", - modelName, TextEmbeddingV1, TextEmbeddingV2, TextEmbeddingV3) - } - c, err := createAliEmbeddingClient(apiKey, url) if err != nil { return nil, err } maxBatch := 25 - if modelName == TextEmbeddingV3 { + if modelName == "text-embedding-v3" { maxBatch = 6 } diff --git a/internal/util/function/alitext_embedding_provider_test.go b/internal/util/function/alitext_embedding_provider_test.go index 989d514093..368520c719 100644 --- a/internal/util/function/alitext_embedding_provider_test.go +++ b/internal/util/function/alitext_embedding_provider_test.go @@ -69,9 +69,9 @@ func createAliProvider(url string, schema *schemapb.FieldSchema, providerName st InputFieldIds: []int64{101}, OutputFieldIds: []int64{102}, Params: []*commonpb.KeyValuePair{ - {Key: modelNameParamKey, Value: TextEmbeddingV3}, - {Key: apiKeyParamKey, Value: "mock"}, + {Key: modelNameParamKey, Value: TestModel}, {Key: embeddingURLParamKey, Value: url}, + {Key: apiKeyParamKey, Value: "mock"}, {Key: dimParamKey, Value: "4"}, }, } @@ -85,8 +85,8 @@ func createAliProvider(url string, schema *schemapb.FieldSchema, providerName st func (s *AliTextEmbeddingProviderSuite) TestEmbedding() { ts := CreateAliEmbeddingServer() - defer ts.Close() + for _, provderName := range s.providers { provder, err := createAliProvider(ts.URL, s.schema.Fields[2], provderName) s.NoError(err) @@ -186,18 +186,13 @@ func (s *AliTextEmbeddingProviderSuite) TestNewAliDashScopeEmbeddingProvider() { InputFieldIds: []int64{101}, OutputFieldIds: []int64{102}, Params: []*commonpb.KeyValuePair{ - {Key: modelNameParamKey, Value: "UnkownModels"}, + {Key: modelNameParamKey, Value: TestModel}, {Key: apiKeyParamKey, Value: "mock"}, - {Key: embeddingURLParamKey, Value: "mock"}, {Key: dimParamKey, Value: "4"}, }, } + // invalid dim + functionSchema.Params[2] = &commonpb.KeyValuePair{Key: dimParamKey, Value: "Invalid"} _, err := NewAliDashScopeEmbeddingProvider(s.schema.Fields[2], functionSchema) s.Error(err) - - // invalid dim - functionSchema.Params[0] = &commonpb.KeyValuePair{Key: modelNameParamKey, Value: TextEmbeddingV3} - functionSchema.Params[3] = &commonpb.KeyValuePair{Key: dimParamKey, Value: "Invalid"} - _, err = NewAliDashScopeEmbeddingProvider(s.schema.Fields[2], functionSchema) - s.Error(err) } diff --git a/internal/util/function/bedrock_embedding_provider.go b/internal/util/function/bedrock_embedding_provider.go index b46b2400ac..030eb8ad96 100644 --- a/internal/util/function/bedrock_embedding_provider.go +++ b/internal/util/function/bedrock_embedding_provider.go @@ -98,9 +98,13 @@ func NewBedrockEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSche return nil, err } case awsAKIdParamKey: - awsAccessKeyId = param.Value + if strings.ToLower(os.Getenv(enableConfigAKAndURL)) != "false" { + awsAccessKeyId = param.Value + } case awsSAKParamKey: - awsSecretAccessKey = param.Value + if strings.ToLower(os.Getenv(enableConfigAKAndURL)) != "false" { + awsSecretAccessKey = param.Value + } case regionParamKey: region = param.Value case normalizeParamKey: @@ -116,10 +120,6 @@ func NewBedrockEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSche } } - if modelName != BedRockTitanTextEmbeddingsV2 { - return nil, fmt.Errorf("Unsupported model: %s, only support [%s]", - modelName, BedRockTitanTextEmbeddingsV2) - } var client BedrockClient if c == nil { client, err = createBedRockEmbeddingClient(awsAccessKeyId, awsSecretAccessKey, region) diff --git a/internal/util/function/bedrock_text_embedding_provider_test.go b/internal/util/function/bedrock_text_embedding_provider_test.go index 64ac5462f0..a7478673a1 100644 --- a/internal/util/function/bedrock_text_embedding_provider_test.go +++ b/internal/util/function/bedrock_text_embedding_provider_test.go @@ -64,7 +64,7 @@ func createBedrockProvider(schema *schemapb.FieldSchema, providerName string, di InputFieldIds: []int64{101}, OutputFieldIds: []int64{102}, Params: []*commonpb.KeyValuePair{ - {Key: modelNameParamKey, Value: BedRockTitanTextEmbeddingsV2}, + {Key: modelNameParamKey, Value: TestModel}, {Key: apiKeyParamKey, Value: "mock"}, {Key: dimParamKey, Value: "4"}, }, @@ -143,7 +143,7 @@ func (s *BedrockTextEmbeddingProviderSuite) TestNewBedrockEmbeddingProvider() { InputFieldIds: []int64{101}, OutputFieldIds: []int64{102}, Params: []*commonpb.KeyValuePair{ - {Key: modelNameParamKey, Value: BedRockTitanTextEmbeddingsV2}, + {Key: modelNameParamKey, Value: TestModel}, {Key: awsAKIdParamKey, Value: "mock"}, {Key: awsSAKParamKey, Value: "mock"}, {Key: regionParamKey, Value: "mock"}, @@ -164,14 +164,8 @@ func (s *BedrockTextEmbeddingProviderSuite) TestNewBedrockEmbeddingProvider() { _, err = NewBedrockEmbeddingProvider(fieldSchema, functionSchema, nil) s.Error(err) - // invalid model name - functionSchema.Params[5] = &commonpb.KeyValuePair{Key: normalizeParamKey, Value: "true"} - functionSchema.Params[0] = &commonpb.KeyValuePair{Key: modelNameParamKey, Value: "UnkownModel"} - _, err = NewBedrockEmbeddingProvider(fieldSchema, functionSchema, nil) - s.Error(err) - // invalid dim - functionSchema.Params[0] = &commonpb.KeyValuePair{Key: modelNameParamKey, Value: BedRockTitanTextEmbeddingsV2} + functionSchema.Params[0] = &commonpb.KeyValuePair{Key: modelNameParamKey, Value: TestModel} functionSchema.Params[0] = &commonpb.KeyValuePair{Key: dimParamKey, Value: "Invalid"} _, err = NewBedrockEmbeddingProvider(fieldSchema, functionSchema, nil) s.Error(err) diff --git a/internal/util/function/cohere_embedding_provider.go b/internal/util/function/cohere_embedding_provider.go index a9373b47a2..fd34923b22 100644 --- a/internal/util/function/cohere_embedding_provider.go +++ b/internal/util/function/cohere_embedding_provider.go @@ -62,16 +62,13 @@ func NewCohereEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchem if err != nil { return nil, err } - var apiKey, url, modelName string + apiKey, url := parseAKAndURL(functionSchema.Params) + var modelName string truncate := "END" for _, param := range functionSchema.Params { switch strings.ToLower(param.Key) { case modelNameParamKey: modelName = param.Value - case apiKeyParamKey: - apiKey = param.Value - case embeddingURLParamKey: - url = param.Value case truncateParamKey: if param.Value != "NONE" && param.Value != "START" && param.Value != "END" { return nil, fmt.Errorf("Illegal parameters, %s only supports [NONE, START, END]", truncateParamKey) @@ -81,11 +78,6 @@ func NewCohereEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchem } } - if modelName != embedEnglishV30 && modelName != embedMultilingualV30 && modelName != embedEnglishLightV30 && modelName != embedMultilingualLightV30 && modelName != embedEnglishV20 && modelName != embedEnglishLightV20 && modelName != embedMultilingualV20 { - return nil, fmt.Errorf("Unsupported model: %s, only support [%s, %s, %s, %s, %s, %s, %s]", - modelName, embedEnglishV30, embedMultilingualV30, embedEnglishLightV30, embedMultilingualLightV30, embedEnglishV20, embedEnglishLightV20, embedMultilingualV20) - } - c, err := createCohereEmbeddingClient(apiKey, url) if err != nil { return nil, err @@ -103,12 +95,6 @@ func NewCohereEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchem return "int8" }() - if outputType == "int8" { - if modelName != embedEnglishV30 && modelName != embedMultilingualV30 && modelName != embedEnglishLightV30 && modelName != embedMultilingualLightV30 { - return nil, fmt.Errorf("Cohere text embedding model: [%s] doesn't supports int8. Valid for only v3 models.", modelName) - } - } - provider := CohereEmbeddingProvider{ client: c, fieldDim: fieldDim, @@ -132,7 +118,8 @@ func (provider *CohereEmbeddingProvider) FieldDim() int64 { // Specifies the type of input passed to the model. Required for embedding models v3 and higher. func (provider *CohereEmbeddingProvider) getInputType(mode TextEmbeddingMode) string { - if provider.modelName == embedEnglishV20 || provider.modelName == embedEnglishLightV20 || provider.modelName == embedMultilingualV20 { + // v2 models not support instructor + if strings.HasSuffix(provider.modelName, "v2.0") { return "" } if mode == InsertMode { diff --git a/internal/util/function/cohere_embedding_provider_test.go b/internal/util/function/cohere_embedding_provider_test.go index 0371cb203d..cfc5f10bd5 100644 --- a/internal/util/function/cohere_embedding_provider_test.go +++ b/internal/util/function/cohere_embedding_provider_test.go @@ -69,7 +69,7 @@ func createCohereProvider(url string, schema *schemapb.FieldSchema, providerName InputFieldIds: []int64{101}, OutputFieldIds: []int64{102}, Params: []*commonpb.KeyValuePair{ - {Key: modelNameParamKey, Value: embedEnglishLightV30}, + {Key: modelNameParamKey, Value: TestModel}, {Key: apiKeyParamKey, Value: "mock"}, {Key: embeddingURLParamKey, Value: url}, }, @@ -143,25 +143,6 @@ func (s *CohereTextEmbeddingProviderSuite) TestEmbeddingInt8() { s.Equal([][]int8{{0, 1, 2, 3}, {1, 2, 3, 4}, {2, 3, 4, 5}}, ret) } } - - // Invalid model name - { - functionSchema := &schemapb.FunctionSchema{ - Name: "test", - Type: schemapb.FunctionType_TextEmbedding, - InputFieldNames: []string{"text"}, - OutputFieldNames: []string{"vector"}, - InputFieldIds: []int64{101}, - OutputFieldIds: []int64{102}, - Params: []*commonpb.KeyValuePair{ - {Key: modelNameParamKey, Value: embedEnglishLightV20}, - {Key: apiKeyParamKey, Value: "mock"}, - {Key: embeddingURLParamKey, Value: ts.URL}, - }, - } - _, err := NewCohereEmbeddingProvider(int8VecField, functionSchema) - s.Error(err) - } } func (s *CohereTextEmbeddingProviderSuite) TestEmbeddingDimNotMatch() { @@ -278,7 +259,7 @@ func (s *CohereTextEmbeddingProviderSuite) TestNewCohereProvider() { InputFieldIds: []int64{101}, OutputFieldIds: []int64{102}, Params: []*commonpb.KeyValuePair{ - {Key: modelNameParamKey, Value: embedEnglishLightV20}, + {Key: modelNameParamKey, Value: TestModel}, {Key: apiKeyParamKey, Value: "mock"}, }, } @@ -296,12 +277,6 @@ func (s *CohereTextEmbeddingProviderSuite) TestNewCohereProvider() { functionSchema.Params[2].Value = "Unknow" _, err = NewCohereEmbeddingProvider(s.schema.Fields[2], functionSchema) s.Error(err) - - // Invalid ModelName - functionSchema.Params[2].Value = "END" - functionSchema.Params[0].Value = "Unknow" - _, err = NewCohereEmbeddingProvider(s.schema.Fields[2], functionSchema) - s.Error(err) } func (s *CohereTextEmbeddingProviderSuite) TestGetInputType() { @@ -313,7 +288,7 @@ func (s *CohereTextEmbeddingProviderSuite) TestGetInputType() { InputFieldIds: []int64{101}, OutputFieldIds: []int64{102}, Params: []*commonpb.KeyValuePair{ - {Key: modelNameParamKey, Value: embedEnglishLightV20}, + {Key: modelNameParamKey, Value: "model-v2.0"}, {Key: apiKeyParamKey, Value: "mock"}, }, } @@ -323,7 +298,7 @@ func (s *CohereTextEmbeddingProviderSuite) TestGetInputType() { s.Equal(provider.getInputType(InsertMode), "") s.Equal(provider.getInputType(SearchMode), "") - functionSchema.Params[0].Value = embedEnglishLightV30 + functionSchema.Params[0].Value = "model-v3.0" provider, err = NewCohereEmbeddingProvider(s.schema.Fields[2], functionSchema) s.NoError(err) s.Equal(provider.getInputType(InsertMode), "search_document") diff --git a/internal/util/function/common.go b/internal/util/function/common.go index dd836da32c..fd34a0d8a8 100644 --- a/internal/util/function/common.go +++ b/internal/util/function/common.go @@ -20,8 +20,11 @@ package function import ( "fmt" + "os" "strconv" + "strings" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" ) @@ -51,20 +54,12 @@ const ( // ali text embedding const ( - TextEmbeddingV1 string = "text-embedding-v1" - TextEmbeddingV2 string = "text-embedding-v2" - TextEmbeddingV3 string = "text-embedding-v3" - dashscopeAKEnvStr string = "MILVUSAI_DASHSCOPE_API_KEY" ) // openai/azure text embedding const ( - TextEmbeddingAda002 string = "text-embedding-ada-002" - TextEmbedding3Small string = "text-embedding-3-small" - TextEmbedding3Large string = "text-embedding-3-large" - openaiAKEnvStr string = "MILVUSAI_OPENAI_API_KEY" azureOpenaiAKEnvStr string = "MILVUSAI_AZURE_OPENAI_API_KEY" @@ -76,11 +71,10 @@ const ( // bedrock emebdding const ( - BedRockTitanTextEmbeddingsV2 string = "amazon.titan-embed-text-v2:0" - awsAKIdParamKey string = "aws_access_key_id" - awsSAKParamKey string = "aws_secret_access_key" - regionParamKey string = "regin" - normalizeParamKey string = "normalize" + awsAKIdParamKey string = "aws_access_key_id" + awsSAKParamKey string = "aws_secret_access_key" + regionParamKey string = "regin" + normalizeParamKey string = "normalize" bedrockAccessKeyId string = "MILVUSAI_BEDROCK_ACCESS_KEY_ID" bedrockSAKEnvStr string = "MILVUSAI_BEDROCK_SECRET_ACCESS_KEY" @@ -93,48 +87,24 @@ const ( projectIDParamKey string = "projectid" taskTypeParamKey string = "task" - textEmbedding005 string = "text-embedding-005" - textMultilingualEmbedding002 string = "text-multilingual-embedding-002" - vertexServiceAccountJSONEnv string = "MILVUSAI_GOOGLE_APPLICATION_CREDENTIALS" ) // voyageAI const ( - voyage3Large string = "voyage-3-large" - voyage3 string = "voyage-3" - voyage3Lite string = "voyage-3-lite" - voyageCode3 string = "voyage-code-3" - voyageFinance2 string = "voyage-finance-2" - voyageLaw2 string = "voyage-law-2" - voyageCode2 string = "voyage-code-2" - - voyageAIAKEnvStr string = "MILVUSAI_VOYAGEAI_API_KEY" + truncationParamKey string = "truncation" + voyageAIAKEnvStr string = "MILVUSAI_VOYAGEAI_API_KEY" ) // cohere const ( - embedEnglishV30 string = "embed-english-v3.0" - embedMultilingualV30 string = "embed-multilingual-v3.0" - embedEnglishLightV30 string = "embed-english-light-v3.0" - embedMultilingualLightV30 string = "embed-multilingual-light-v3.0" - embedEnglishV20 string = "embed-english-v2.0" - embedEnglishLightV20 string = "embed-english-light-v2.0" - embedMultilingualV20 string = "embed-multilingual-v2.0" - cohereAIAKEnvStr string = "MILVUSAI_COHERE_API_KEY" ) // siliconflow const ( - bAAIBgeLargeZhV15 string = "BAAI/bge-large-zh-v1.5" - bAAIBgeLargeEhV15 string = "BAAI/bge-large-eh-v1.5" - neteaseYoudaoBceEmbeddingBasev1 string = "netease-youdao/bce-embedding-base_v1" - bAAIBgeM3 string = "BAAI/bge-m3" - proBAAIBgeM3 string = "Pro/BAAI/bge-m3 " - siliconflowAKEnvStr string = "MILVUSAI_SILICONFLOW_API_KEY" ) @@ -150,6 +120,23 @@ const ( enableTeiEnvStr string = "MILVUSAI_ENABLE_TEI" ) +const enableConfigAKAndURL string = "ENABLE_CONFIG_AK_AND_URL" + +func parseAKAndURL(params []*commonpb.KeyValuePair) (string, string) { + var apiKey, url string + if strings.ToLower(os.Getenv(enableConfigAKAndURL)) != "false" { + for _, param := range params { + switch strings.ToLower(param.Key) { + case apiKeyParamKey: + apiKey = param.Value + case embeddingURLParamKey: + url = param.Value + } + } + } + return apiKey, url +} + func parseAndCheckFieldDim(dimStr string, fieldDim int64, fieldName string) (int64, error) { dim, err := strconv.ParseInt(dimStr, 10, 64) if err != nil { diff --git a/internal/util/function/function_base.go b/internal/util/function/function_base.go index aabcfdf5c0..9738c24f78 100644 --- a/internal/util/function/function_base.go +++ b/internal/util/function/function_base.go @@ -20,6 +20,7 @@ package function import ( "fmt" + "strings" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" ) @@ -27,6 +28,22 @@ import ( type FunctionBase struct { schema *schemapb.FunctionSchema outputFields []*schemapb.FieldSchema + + collectionName string + functionTypeName string + functionName string + provider string +} + +func getProvider(functionSchema *schemapb.FunctionSchema) (string, error) { + for _, param := range functionSchema.Params { + switch strings.ToLower(param.Key) { + case Provider: + return strings.ToLower(param.Value), nil + default: + } + } + return "", fmt.Errorf("The text embedding service provider parameter:[%s] was not found", Provider) } func NewFunctionBase(coll *schemapb.CollectionSchema, fSchema *schemapb.FunctionSchema) (*FunctionBase, error) { @@ -45,6 +62,15 @@ func NewFunctionBase(coll *schemapb.CollectionSchema, fSchema *schemapb.Function return &base, fmt.Errorf("The collection [%s]'s information is wrong, function [%s]'s outputs does not match the schema", coll.Name, fSchema.Name) } + + provider, err := getProvider(fSchema) + if err != nil { + return nil, err + } + base.collectionName = coll.Name + base.functionName = fSchema.Name + base.provider = provider + base.functionTypeName = fSchema.GetType().String() return &base, nil } diff --git a/internal/util/function/function_executor.go b/internal/util/function/function_executor.go index bfd12dedf8..66913d1dd8 100644 --- a/internal/util/function/function_executor.go +++ b/internal/util/function/function_executor.go @@ -19,7 +19,9 @@ package function import ( + "context" "fmt" + "strconv" "sync" "google.golang.org/protobuf/proto" @@ -27,18 +29,26 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/storage" + "github.com/milvus-io/milvus/pkg/v2/metrics" "github.com/milvus-io/milvus/pkg/v2/mq/msgstream" "github.com/milvus-io/milvus/pkg/v2/proto/internalpb" "github.com/milvus-io/milvus/pkg/v2/util/merr" + "github.com/milvus-io/milvus/pkg/v2/util/paramtable" + "github.com/milvus-io/milvus/pkg/v2/util/timerecord" ) type Runner interface { GetSchema() *schemapb.FunctionSchema GetOutputFields() []*schemapb.FieldSchema + GetCollectionName() string + GetFunctionTypeName() string + GetFunctionName() string + GetFunctionProvider() string + Check() error MaxBatch() int - ProcessInsert(inputs []*schemapb.FieldData) ([]*schemapb.FieldData, error) - ProcessSearch(placeholderGroup *commonpb.PlaceholderGroup) (*commonpb.PlaceholderGroup, error) + ProcessInsert(ctx context.Context, inputs []*schemapb.FieldData) ([]*schemapb.FieldData, error) + ProcessSearch(ctx context.Context, placeholderGroup *commonpb.PlaceholderGroup) (*commonpb.PlaceholderGroup, error) ProcessBulkInsert(inputs []storage.FieldData) (map[storage.FieldID]storage.FieldData, error) } @@ -64,9 +74,18 @@ func createFunction(coll *schemapb.CollectionSchema, schema *schemapb.FunctionSc // Since bm25 and embedding are implemented in different ways, the bm25 function is not verified here. func ValidateFunctions(schema *schemapb.CollectionSchema) error { for _, fSchema := range schema.Functions { - if _, err := createFunction(schema, fSchema); err != nil { + f, err := createFunction(schema, fSchema) + if err != nil { return err } + + // ignore bm25 function + if f == nil { + continue + } + if err := f.Check(); err != nil { + return fmt.Errorf("Check function [%s:%s] failed, the err is: %v", fSchema.Name, fSchema.GetType().String(), err) + } } return nil } @@ -87,7 +106,7 @@ func NewFunctionExecutor(schema *schemapb.CollectionSchema) (*FunctionExecutor, return executor, nil } -func (executor *FunctionExecutor) processSingleFunction(runner Runner, msg *msgstream.InsertMsg) ([]*schemapb.FieldData, error) { +func (executor *FunctionExecutor) processSingleFunction(ctx context.Context, runner Runner, msg *msgstream.InsertMsg) ([]*schemapb.FieldData, error) { inputs := make([]*schemapb.FieldData, 0, len(runner.GetSchema().GetInputFieldNames())) for _, name := range runner.GetSchema().GetInputFieldNames() { for _, field := range msg.FieldsData { @@ -100,14 +119,18 @@ func (executor *FunctionExecutor) processSingleFunction(runner Runner, msg *msgs return nil, fmt.Errorf("Input field not found") } - outputs, err := runner.ProcessInsert(inputs) + tr := timerecord.NewTimeRecorder("function ProcessInsert") + outputs, err := runner.ProcessInsert(ctx, inputs) if err != nil { return nil, err } + + metrics.ProxyFunctionlatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), runner.GetCollectionName(), runner.GetFunctionTypeName(), runner.GetFunctionProvider(), runner.GetFunctionName()).Observe(float64(tr.RecordSpan().Milliseconds())) + tr.CtxElapse(ctx, "function ProcessInsert done") return outputs, nil } -func (executor *FunctionExecutor) ProcessInsert(msg *msgstream.InsertMsg) error { +func (executor *FunctionExecutor) ProcessInsert(ctx context.Context, msg *msgstream.InsertMsg) error { numRows := msg.NumRows for _, runner := range executor.runners { if numRows > uint64(runner.MaxBatch()) { @@ -122,7 +145,7 @@ func (executor *FunctionExecutor) ProcessInsert(msg *msgstream.InsertMsg) error wg.Add(1) go func(runner Runner) { defer wg.Done() - data, err := executor.processSingleFunction(runner, msg) + data, err := executor.processSingleFunction(ctx, runner, msg) if err != nil { errChan <- err return @@ -149,7 +172,7 @@ func (executor *FunctionExecutor) ProcessInsert(msg *msgstream.InsertMsg) error return nil } -func (executor *FunctionExecutor) processSingleSearch(runner Runner, placeholderGroup []byte) ([]byte, error) { +func (executor *FunctionExecutor) processSingleSearch(ctx context.Context, runner Runner, placeholderGroup []byte) ([]byte, error) { pb := &commonpb.PlaceholderGroup{} proto.Unmarshal(placeholderGroup, pb) if len(pb.Placeholders) != 1 { @@ -158,14 +181,18 @@ func (executor *FunctionExecutor) processSingleSearch(runner Runner, placeholder if pb.Placeholders[0].Type != commonpb.PlaceholderType_VarChar { return placeholderGroup, nil } - res, err := runner.ProcessSearch(pb) + + tr := timerecord.NewTimeRecorder("function ProcessSearch") + res, err := runner.ProcessSearch(ctx, pb) if err != nil { return nil, err } + metrics.ProxyFunctionlatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), runner.GetCollectionName(), runner.GetFunctionTypeName(), runner.GetFunctionProvider(), runner.GetFunctionName()).Observe(float64(tr.RecordSpan().Milliseconds())) + tr.CtxElapse(ctx, "function ProcessSearch done") return proto.Marshal(res) } -func (executor *FunctionExecutor) prcessSearch(req *internalpb.SearchRequest) error { +func (executor *FunctionExecutor) prcessSearch(ctx context.Context, req *internalpb.SearchRequest) error { runner, exist := executor.runners[req.FieldId] if !exist { return fmt.Errorf("Can not found function in field %d", req.FieldId) @@ -173,7 +200,7 @@ func (executor *FunctionExecutor) prcessSearch(req *internalpb.SearchRequest) er if req.Nq > int64(runner.MaxBatch()) { return fmt.Errorf("Nq [%d] > function [%s]'s max batch [%d]", req.Nq, runner.GetSchema().Name, runner.MaxBatch()) } - if newHolder, err := executor.processSingleSearch(runner, req.GetPlaceholderGroup()); err != nil { + if newHolder, err := executor.processSingleSearch(ctx, runner, req.GetPlaceholderGroup()); err != nil { return err } else { req.PlaceholderGroup = newHolder @@ -181,7 +208,7 @@ func (executor *FunctionExecutor) prcessSearch(req *internalpb.SearchRequest) er return nil } -func (executor *FunctionExecutor) prcessAdvanceSearch(req *internalpb.SearchRequest) error { +func (executor *FunctionExecutor) prcessAdvanceSearch(ctx context.Context, req *internalpb.SearchRequest) error { outputs := make(chan map[int64][]byte, len(req.GetSubReqs())) errChan := make(chan error, len(req.GetSubReqs())) var wg sync.WaitGroup @@ -193,7 +220,7 @@ func (executor *FunctionExecutor) prcessAdvanceSearch(req *internalpb.SearchRequ wg.Add(1) go func(runner Runner, idx int64, placeholderGroup []byte) { defer wg.Done() - if newHolder, err := executor.processSingleSearch(runner, placeholderGroup); err != nil { + if newHolder, err := executor.processSingleSearch(ctx, runner, placeholderGroup); err != nil { errChan <- err } else { outputs <- map[int64][]byte{idx: newHolder} @@ -216,11 +243,11 @@ func (executor *FunctionExecutor) prcessAdvanceSearch(req *internalpb.SearchRequ return nil } -func (executor *FunctionExecutor) ProcessSearch(req *internalpb.SearchRequest) error { +func (executor *FunctionExecutor) ProcessSearch(ctx context.Context, req *internalpb.SearchRequest) error { if !req.IsAdvanced { - return executor.prcessSearch(req) + return executor.prcessSearch(ctx, req) } - return executor.prcessAdvanceSearch(req) + return executor.prcessAdvanceSearch(ctx, req) } func (executor *FunctionExecutor) processSingleBulkInsert(runner Runner, data *storage.InsertData) (map[storage.FieldID]storage.FieldData, error) { diff --git a/internal/util/function/function_executor_test.go b/internal/util/function/function_executor_test.go index 6c21a97d80..3906a342f7 100644 --- a/internal/util/function/function_executor_test.go +++ b/internal/util/function/function_executor_test.go @@ -19,6 +19,7 @@ package function import ( + "context" "encoding/json" "io" "net/http" @@ -148,7 +149,7 @@ func (s *FunctionExecutorSuite) TestExecutor() { exec, err := NewFunctionExecutor(schema) s.NoError(err) msg := s.createMsg([]string{"sentence", "sentence"}) - exec.ProcessInsert(msg) + exec.ProcessInsert(context.Background(), msg) s.Equal(len(msg.FieldsData), 3) } @@ -183,7 +184,7 @@ func (s *FunctionExecutorSuite) TestErrorEmbedding() { exec, err := NewFunctionExecutor(schema) s.NoError(err) msg := s.createMsg([]string{"sentence", "sentence"}) - err = exec.ProcessInsert(msg) + err = exec.ProcessInsert(context.Background(), msg) s.Error(err) } @@ -225,7 +226,7 @@ func (s *FunctionExecutorSuite) TestInternalPrcessSearch() { IsAdvanced: false, FieldId: 102, } - err = exec.ProcessSearch(req) + err = exec.ProcessSearch(context.Background(), req) s.NoError(err) // No function found @@ -235,7 +236,7 @@ func (s *FunctionExecutorSuite) TestInternalPrcessSearch() { IsAdvanced: false, FieldId: 111, } - err = exec.ProcessSearch(req) + err = exec.ProcessSearch(context.Background(), req) s.Error(err) // Large search nq @@ -245,7 +246,7 @@ func (s *FunctionExecutorSuite) TestInternalPrcessSearch() { IsAdvanced: false, FieldId: 102, } - err = exec.ProcessSearch(req) + err = exec.ProcessSearch(context.Background(), req) s.Error(err) } @@ -277,12 +278,12 @@ func (s *FunctionExecutorSuite) TestInternalPrcessSearch() { IsAdvanced: true, SubReqs: []*internalpb.SubSearchRequest{subReq}, } - err = exec.ProcessSearch(req) + err = exec.ProcessSearch(context.Background(), req) s.NoError(err) // Large nq subReq.Nq = 1000 - err = exec.ProcessSearch(req) + err = exec.ProcessSearch(context.Background(), req) s.Error(err) } } @@ -318,7 +319,7 @@ func (s *FunctionExecutorSuite) TestInternalPrcessSearchFailed() { IsAdvanced: false, FieldId: 102, } - err = exec.ProcessSearch(req) + err = exec.ProcessSearch(context.Background(), req) s.Error(err) } // AdvanceSearch @@ -332,7 +333,7 @@ func (s *FunctionExecutorSuite) TestInternalPrcessSearchFailed() { IsAdvanced: true, SubReqs: []*internalpb.SubSearchRequest{subReq}, } - err = exec.ProcessSearch(req) + err = exec.ProcessSearch(context.Background(), req) s.Error(err) } } diff --git a/internal/util/function/mock_embedding_service.go b/internal/util/function/mock_embedding_service.go index ff7a699dd3..ea28303859 100644 --- a/internal/util/function/mock_embedding_service.go +++ b/internal/util/function/mock_embedding_service.go @@ -36,18 +36,7 @@ import ( "github.com/milvus-io/milvus/internal/util/function/models/voyageai" ) -// func mockEmbedding(texts []string, dim int) [][]float32 { -// embeddings := make([][]float32, 0) -// for i := 0; i < len(texts); i++ { -// f := float32(i) -// emb := make([]float32, 0) -// for j := 0; j < dim; j++ { -// emb = append(emb, f+float32(j)*0.1) -// } -// embeddings = append(embeddings, emb) -// } -// return embeddings -// } +const TestModel string = "TestModel" func mockEmbedding[T int8 | float32](texts []string, dim int) [][]T { embeddings := make([][]T, 0) diff --git a/internal/util/function/models/ali/ali_dashscope_text_embedding.go b/internal/util/function/models/ali/ali_dashscope_text_embedding.go index f59b5973c8..119983e8cc 100644 --- a/internal/util/function/models/ali/ali_dashscope_text_embedding.go +++ b/internal/util/function/models/ali/ali_dashscope_text_embedding.go @@ -17,7 +17,6 @@ package ali import ( - "bytes" "context" "encoding/json" "fmt" @@ -135,14 +134,11 @@ func (c *AliDashScopeEmbedding) Embedding(modelName string, texts []string, dim ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutSec)*time.Second) defer cancel() - req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.url, bytes.NewBuffer(data)) - if err != nil { - return nil, err + headers := map[string]string{ + "Content-Type": "application/json", + "Authorization": fmt.Sprintf("Bearer %s", c.apiKey), } - - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.apiKey)) - body, err := utils.RetrySend(req, 3) + body, err := utils.RetrySend(ctx, data, http.MethodPost, c.url, headers, 3, 1) if err != nil { return nil, err } diff --git a/internal/util/function/models/cohere/cohere_text_embedding.go b/internal/util/function/models/cohere/cohere_text_embedding.go index 05958fe196..cca6c08172 100644 --- a/internal/util/function/models/cohere/cohere_text_embedding.go +++ b/internal/util/function/models/cohere/cohere_text_embedding.go @@ -17,7 +17,6 @@ package cohere import ( - "bytes" "context" "encoding/json" "fmt" @@ -100,15 +99,12 @@ func (c *CohereEmbedding) Embedding(modelName string, texts []string, inputType ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutSec)*time.Second) defer cancel() - req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.url, bytes.NewBuffer(data)) - if err != nil { - return nil, err + headers := map[string]string{ + "accept": "application/json", + "Content-Type": "application/json", + "Authorization": fmt.Sprintf("bearer %s", c.apiKey), } - - req.Header.Set("accept", "application/json") - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", fmt.Sprintf("bearer %s", c.apiKey)) - body, err := utils.RetrySend(req, 3) + body, err := utils.RetrySend(ctx, data, http.MethodPost, c.url, headers, 3, 1) if err != nil { return nil, err } diff --git a/internal/util/function/models/openai/openai_embedding.go b/internal/util/function/models/openai/openai_embedding.go index c81604f25f..54bf01e9fb 100644 --- a/internal/util/function/models/openai/openai_embedding.go +++ b/internal/util/function/models/openai/openai_embedding.go @@ -17,7 +17,6 @@ package openai import ( - "bytes" "context" "encoding/json" "fmt" @@ -148,18 +147,10 @@ func (c *openAIBase) embedding(url string, headers map[string]string, modelName ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutSec)*time.Second) defer cancel() - req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewBuffer(data)) + body, err := utils.RetrySend(ctx, data, http.MethodPost, url, headers, 3, 1) if err != nil { return nil, err } - for key, value := range headers { - req.Header.Set(key, value) - } - body, err := utils.RetrySend(req, 3) - if err != nil { - return nil, err - } - var res EmbeddingResponse err = json.Unmarshal(body, &res) if err != nil { diff --git a/internal/util/function/models/openai/openai_embedding_test.go b/internal/util/function/models/openai/openai_embedding_test.go index 87f44b4ea6..cb6f03a3e7 100644 --- a/internal/util/function/models/openai/openai_embedding_test.go +++ b/internal/util/function/models/openai/openai_embedding_test.go @@ -218,7 +218,8 @@ func TestEmbeddingFailed(t *testing.T) { func TestTimeout(t *testing.T) { var st int32 = 0 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - time.Sleep(3 * time.Second) + // (Timeout 1s + Wait 1s) * Retry 3 + time.Sleep(6 * time.Second) atomic.AddInt32(&st, 1) w.WriteHeader(http.StatusUnauthorized) })) diff --git a/internal/util/function/models/siliconflow/siliconflow_text_embedding.go b/internal/util/function/models/siliconflow/siliconflow_text_embedding.go index f58b962efc..515dd7d1e3 100644 --- a/internal/util/function/models/siliconflow/siliconflow_text_embedding.go +++ b/internal/util/function/models/siliconflow/siliconflow_text_embedding.go @@ -17,7 +17,6 @@ package siliconflow import ( - "bytes" "context" "encoding/json" "fmt" @@ -115,14 +114,12 @@ func (c *SiliconflowEmbedding) Embedding(modelName string, texts []string, encod ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutSec)*time.Second) defer cancel() - req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.url, bytes.NewBuffer(data)) - if err != nil { - return nil, err - } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.apiKey)) - body, err := utils.RetrySend(req, 3) + headers := map[string]string{ + "Content-Type": "application/json", + "Authorization": fmt.Sprintf("Bearer %s", c.apiKey), + } + body, err := utils.RetrySend(ctx, data, http.MethodPost, c.url, headers, 3, 1) if err != nil { return nil, err } diff --git a/internal/util/function/models/tei/tei.go b/internal/util/function/models/tei/tei.go index e5e1f00553..31386ad76e 100644 --- a/internal/util/function/models/tei/tei.go +++ b/internal/util/function/models/tei/tei.go @@ -17,7 +17,6 @@ package tei import ( - "bytes" "context" "encoding/json" "fmt" @@ -46,10 +45,10 @@ func NewTEIEmbeddingClient(apiKey string, endpoint string) (*TEIEmbedding, error return nil, err } if base.Scheme != "http" && base.Scheme != "https" { - return nil, fmt.Errorf("%s is not a valid http/https link", endpoint) + return nil, fmt.Errorf("endpoint: [%s] is not a valid http/https link", endpoint) } if base.Host == "" { - return nil, fmt.Errorf("%s is not a valid http/https link", endpoint) + return nil, fmt.Errorf("endpoint: [%s] is not a valid http/https link", endpoint) } base.Path = "/embed" @@ -88,17 +87,13 @@ func (c *TEIEmbedding) Embedding(texts []string, truncate bool, truncationDirect ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutSec)*time.Second) defer cancel() - req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.url, bytes.NewBuffer(data)) - if err != nil { - return nil, err + headers := map[string]string{ + "Content-Type": "application/json", } - - req.Header.Set("Content-Type", "application/json") if c.apiKey != "" { - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.apiKey)) + headers["Authorization"] = fmt.Sprintf("Bearer %s", c.apiKey) } - - body, err := utils.RetrySend(req, 3) + body, err := utils.RetrySend(ctx, data, http.MethodPost, c.url, headers, 3, 1) if err != nil { return nil, err } diff --git a/internal/util/function/models/utils/embedding_util.go b/internal/util/function/models/utils/embedding_util.go index 7ea94789e0..df854ac8bb 100644 --- a/internal/util/function/models/utils/embedding_util.go +++ b/internal/util/function/models/utils/embedding_util.go @@ -17,9 +17,12 @@ package utils import ( + "bytes" + "context" "fmt" "io" "net/http" + "time" ) const DefaultTimeout int64 = 30 @@ -33,23 +36,31 @@ func send(req *http.Request) ([]byte, error) { body, err := io.ReadAll(resp.Body) if err != nil { - return nil, err + return nil, fmt.Errorf("Call service faild, read response failed, errs:[%v]", err) } if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("Call %s faild, errs:[%s, %s]", req.URL, resp.Status, body) + return nil, fmt.Errorf("Call service faild, errs:[%s, %s]", resp.Status, body) } return body, nil } -func RetrySend(req *http.Request, maxRetries int) ([]byte, error) { +func RetrySend(ctx context.Context, data []byte, httpMethod string, url string, headers map[string]string, maxRetries int, retryDelay int) ([]byte, error) { var err error - var res []byte + var body []byte for i := 0; i < maxRetries; i++ { - res, err = send(req) - if err == nil { - return res, nil + req, reqErr := http.NewRequestWithContext(ctx, httpMethod, url, bytes.NewBuffer(data)) + if reqErr != nil { + return nil, reqErr } + for k, v := range headers { + req.Header.Set(k, v) + } + body, err = send(req) + if err == nil { + return body, nil + } + time.Sleep(time.Duration(retryDelay) * time.Second) } return nil, err } diff --git a/internal/util/function/models/vertexai/vertexai_text_embedding.go b/internal/util/function/models/vertexai/vertexai_text_embedding.go index dc7d9ea503..6b1afd294a 100644 --- a/internal/util/function/models/vertexai/vertexai_text_embedding.go +++ b/internal/util/function/models/vertexai/vertexai_text_embedding.go @@ -17,7 +17,6 @@ package vertexai import ( - "bytes" "context" "encoding/json" "fmt" @@ -134,10 +133,6 @@ func (c *VertexAIEmbedding) Embedding(modelName string, texts []string, dim int6 ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutSec)*time.Second) defer cancel() - req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.url, bytes.NewBuffer(data)) - if err != nil { - return nil, err - } var token string if c.token != "" { token = c.token @@ -148,9 +143,11 @@ func (c *VertexAIEmbedding) Embedding(modelName string, texts []string, dim int6 } } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) - body, err := utils.RetrySend(req, 3) + headers := map[string]string{ + "Content-Type": "application/json", + "Authorization": fmt.Sprintf("Bearer %s", token), + } + body, err := utils.RetrySend(ctx, data, http.MethodPost, c.url, headers, 3, 1) if err != nil { return nil, err } diff --git a/internal/util/function/models/voyageai/voyageai_text_embedding.go b/internal/util/function/models/voyageai/voyageai_text_embedding.go index 0aea35427e..67a1400766 100644 --- a/internal/util/function/models/voyageai/voyageai_text_embedding.go +++ b/internal/util/function/models/voyageai/voyageai_text_embedding.go @@ -17,7 +17,6 @@ package voyageai import ( - "bytes" "context" "encoding/json" "fmt" @@ -108,7 +107,7 @@ func (c *VoyageAIEmbedding) Check() error { return nil } -func (c *VoyageAIEmbedding) Embedding(modelName string, texts []string, dim int, textType string, outputType string, timeoutSec int64) (any, error) { +func (c *VoyageAIEmbedding) Embedding(modelName string, texts []string, dim int, textType string, outputType string, truncation bool, timeoutSec int64) (any, error) { if outputType != "float" && outputType != "int8" { return nil, fmt.Errorf("Voyageai: unsupport output type: [%s], only support float and int8", outputType) } @@ -117,9 +116,11 @@ func (c *VoyageAIEmbedding) Embedding(modelName string, texts []string, dim int, r.Input = texts r.InputType = textType r.OutputDtype = outputType + r.Truncation = truncation if dim != 0 { r.OutputDimension = int64(dim) } + data, err := json.Marshal(r) if err != nil { return nil, err @@ -131,14 +132,11 @@ func (c *VoyageAIEmbedding) Embedding(modelName string, texts []string, dim int, ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutSec)*time.Second) defer cancel() - req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.url, bytes.NewBuffer(data)) - if err != nil { - return nil, err + headers := map[string]string{ + "Content-Type": "application/json", + "Authorization": fmt.Sprintf("Bearer %s", c.apiKey), } - - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.apiKey)) - body, err := utils.RetrySend(req, 3) + body, err := utils.RetrySend(ctx, data, http.MethodPost, c.url, headers, 3, 1) if err != nil { return nil, err } diff --git a/internal/util/function/models/voyageai/voyageai_text_embedding_test.go b/internal/util/function/models/voyageai/voyageai_text_embedding_test.go index b44ba36acd..20b0b08b8a 100644 --- a/internal/util/function/models/voyageai/voyageai_text_embedding_test.go +++ b/internal/util/function/models/voyageai/voyageai_text_embedding_test.go @@ -98,7 +98,7 @@ func TestEmbeddingOK(t *testing.T) { c := NewVoyageAIEmbeddingClient("mock_key", url) err := c.Check() assert.True(t, err == nil) - r, err := c.Embedding("voyage-3", []string{"sentence"}, 0, "query", "float", 0) + r, err := c.Embedding("voyage-3", []string{"sentence"}, 0, "query", "float", true, 0) ret := r.(*EmbeddingResponse[float32]) assert.True(t, err == nil) assert.Equal(t, ret.Data[0].Index, 0) @@ -158,7 +158,7 @@ func TestEmbeddingInt8Embed(t *testing.T) { c := NewVoyageAIEmbeddingClient("mock_key", url) err := c.Check() assert.True(t, err == nil) - r, err := c.Embedding("voyage-3", []string{"sentence"}, 0, "query", "int8", 0) + r, err := c.Embedding("voyage-3", []string{"sentence"}, 0, "query", "int8", false, 0) ret := r.(*EmbeddingResponse[int8]) assert.True(t, err == nil) assert.Equal(t, ret.Data[0].Index, 0) @@ -169,7 +169,7 @@ func TestEmbeddingInt8Embed(t *testing.T) { assert.Equal(t, ret.Data[1].Embedding, []int8{3, 4}) assert.Equal(t, ret.Data[2].Embedding, []int8{5, 6}) - _, err = c.Embedding("voyage-3", []string{"sentence"}, 0, "query", "unknow", 0) + _, err = c.Embedding("voyage-3", []string{"sentence"}, 0, "query", "unknow", true, 0) assert.Error(t, err) } } @@ -186,7 +186,7 @@ func TestEmbeddingFailed(t *testing.T) { c := NewVoyageAIEmbeddingClient("mock_key", url) err := c.Check() assert.True(t, err == nil) - _, err = c.Embedding("voyage-3", []string{"sentence"}, 0, "query", "float", 0) + _, err = c.Embedding("voyage-3", []string{"sentence"}, 0, "query", "float", false, 0) assert.True(t, err != nil) } } diff --git a/internal/util/function/openai_embedding_provider.go b/internal/util/function/openai_embedding_provider.go index 9da72b6c9d..8dda40b53e 100644 --- a/internal/util/function/openai_embedding_provider.go +++ b/internal/util/function/openai_embedding_provider.go @@ -81,7 +81,8 @@ func newOpenAIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchem if err != nil { return nil, err } - var apiKey, url, modelName, user string + apiKey, url := parseAKAndURL(functionSchema.Params) + var modelName, user string var dim int64 for _, param := range functionSchema.Params { @@ -95,21 +96,12 @@ func newOpenAIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchem } case userParamKey: user = param.Value - case apiKeyParamKey: - apiKey = param.Value - case embeddingURLParamKey: - url = param.Value default: } } var c openai.OpenAIEmbeddingInterface if !isAzure { - if modelName != TextEmbeddingAda002 && modelName != TextEmbedding3Small && modelName != TextEmbedding3Large { - return nil, fmt.Errorf("Unsupported model: %s, only support [%s, %s, %s]", - modelName, TextEmbeddingAda002, TextEmbedding3Small, TextEmbedding3Large) - } - c, err = createOpenAIEmbeddingClient(apiKey, url) if err != nil { return nil, err diff --git a/internal/util/function/siliconflow_embedding_provider.go b/internal/util/function/siliconflow_embedding_provider.go index a5b36fd953..32ee5ae3cf 100644 --- a/internal/util/function/siliconflow_embedding_provider.go +++ b/internal/util/function/siliconflow_embedding_provider.go @@ -60,25 +60,17 @@ func NewSiliconflowEmbeddingProvider(fieldSchema *schemapb.FieldSchema, function if err != nil { return nil, err } - var apiKey, url, modelName string + apiKey, url := parseAKAndURL(functionSchema.Params) + var modelName string for _, param := range functionSchema.Params { switch strings.ToLower(param.Key) { case modelNameParamKey: modelName = param.Value - case apiKeyParamKey: - apiKey = param.Value - case embeddingURLParamKey: - url = param.Value default: } } - if modelName != bAAIBgeLargeZhV15 && modelName != bAAIBgeLargeEhV15 && modelName != neteaseYoudaoBceEmbeddingBasev1 && modelName != bAAIBgeM3 && modelName != proBAAIBgeM3 { - return nil, fmt.Errorf("Unsupported model: %s, only support [%s, %s, %s, %s, %s]", - modelName, bAAIBgeLargeZhV15, bAAIBgeLargeEhV15, neteaseYoudaoBceEmbeddingBasev1, bAAIBgeM3, proBAAIBgeM3) - } - c, err := createSiliconflowEmbeddingClient(apiKey, url) if err != nil { return nil, err diff --git a/internal/util/function/siliconflow_embedding_provider_test.go b/internal/util/function/siliconflow_embedding_provider_test.go index bb294fc79f..cb278df669 100644 --- a/internal/util/function/siliconflow_embedding_provider_test.go +++ b/internal/util/function/siliconflow_embedding_provider_test.go @@ -69,7 +69,7 @@ func createSiliconflowProvider(url string, schema *schemapb.FieldSchema, provide InputFieldIds: []int64{101}, OutputFieldIds: []int64{102}, Params: []*commonpb.KeyValuePair{ - {Key: modelNameParamKey, Value: bAAIBgeLargeEhV15}, + {Key: modelNameParamKey, Value: TestModel}, {Key: apiKeyParamKey, Value: "mock"}, {Key: embeddingURLParamKey, Value: url}, }, @@ -187,7 +187,7 @@ func (s *SiliconflowTextEmbeddingProviderSuite) TestNewSiliconflowEmbeddingProvi InputFieldIds: []int64{101}, OutputFieldIds: []int64{102}, Params: []*commonpb.KeyValuePair{ - {Key: modelNameParamKey, Value: bAAIBgeLargeEhV15}, + {Key: modelNameParamKey, Value: TestModel}, {Key: apiKeyParamKey, Value: "mock"}, {Key: embeddingURLParamKey, Value: "mock"}, }, @@ -196,11 +196,4 @@ func (s *SiliconflowTextEmbeddingProviderSuite) TestNewSiliconflowEmbeddingProvi s.NoError(err) s.Equal(provider.FieldDim(), int64(4)) s.True(provider.MaxBatch() > 0) - - // Invalid model - { - functionSchema.Params[0] = &commonpb.KeyValuePair{Key: modelNameParamKey, Value: "UnkownModel"} - _, err := NewSiliconflowEmbeddingProvider(s.schema.Fields[2], functionSchema) - s.Error(err) - } } diff --git a/internal/util/function/text_embedding_function.go b/internal/util/function/text_embedding_function.go index c5ee65d2bd..510464796d 100644 --- a/internal/util/function/text_embedding_function.go +++ b/internal/util/function/text_embedding_function.go @@ -19,8 +19,9 @@ package function import ( + "context" "fmt" - "strings" + "reflect" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" @@ -44,6 +45,13 @@ const ( teiProvider string = "tei" ) +func TextEmbeddingOutputsCheck(fields []*schemapb.FieldSchema) error { + if len(fields) != 1 || (fields[0].DataType != schemapb.DataType_FloatVector && fields[0].DataType != schemapb.DataType_Int8Vector) { + return fmt.Errorf("TextEmbedding function output field must be a FloatVector or Int8Vector field") + } + return nil +} + // Text embedding for retrieval task type textEmbeddingProvider interface { MaxBatch() int @@ -51,17 +59,6 @@ type textEmbeddingProvider interface { FieldDim() int64 } -func getProvider(functionSchema *schemapb.FunctionSchema) (string, error) { - for _, param := range functionSchema.Params { - switch strings.ToLower(param.Key) { - case Provider: - return strings.ToLower(param.Value), nil - default: - } - } - return "", fmt.Errorf("The text embedding service provider parameter:[%s] was not found", Provider) -} - type TextEmbeddingFunction struct { FunctionBase @@ -82,20 +79,13 @@ func NewTextEmbeddingFunction(coll *schemapb.CollectionSchema, functionSchema *s return nil, err } - if base.outputFields[0].DataType != schemapb.DataType_FloatVector && base.outputFields[0].DataType != schemapb.DataType_Int8Vector { - return nil, fmt.Errorf("Text embedding function's output field not match, needs [%s, %s], got [%s]", - schemapb.DataType_name[int32(schemapb.DataType_FloatVector)], - schemapb.DataType_name[int32(schemapb.DataType_Int8Vector)], - schemapb.DataType_name[int32(base.outputFields[0].DataType)]) - } - - provider, err := getProvider(functionSchema) - if err != nil { + if err := TextEmbeddingOutputsCheck(base.outputFields); err != nil { return nil, err } + var embP textEmbeddingProvider var newProviderErr error - switch provider { + switch base.provider { case openAIProvider: embP, newProviderErr = NewOpenAIEmbeddingProvider(base.outputFields[0], functionSchema) case azureOpenAIProvider: @@ -115,7 +105,7 @@ func NewTextEmbeddingFunction(coll *schemapb.CollectionSchema, functionSchema *s case teiProvider: embP, newProviderErr = NewTEIEmbeddingProvider(base.outputFields[0], functionSchema) default: - return nil, fmt.Errorf("Unsupported text embedding service provider: [%s] , list of supported [%s, %s, %s, %s, %s, %s, %s, %s, %s]", provider, openAIProvider, azureOpenAIProvider, aliDashScopeProvider, bedrockProvider, vertexAIProvider, voyageAIProvider, cohereProvider, siliconflowProvider, teiProvider) + return nil, fmt.Errorf("Unsupported text embedding service provider: [%s] , list of supported [%s, %s, %s, %s, %s, %s, %s, %s, %s]", base.provider, openAIProvider, azureOpenAIProvider, aliDashScopeProvider, bedrockProvider, vertexAIProvider, voyageAIProvider, cohereProvider, siliconflowProvider, teiProvider) } if newProviderErr != nil { @@ -127,10 +117,46 @@ func NewTextEmbeddingFunction(coll *schemapb.CollectionSchema, functionSchema *s }, nil } +func (runner *TextEmbeddingFunction) Check() error { + embds, err := runner.embProvider.CallEmbedding([]string{"check"}, InsertMode) + if err != nil { + return err + } + dim := 0 + switch embds := embds.(type) { + case [][]float32: + dim = len(embds[0]) + case [][]int8: + dim = len(embds[0]) + default: + return fmt.Errorf("Unsupport embedding type: %s", reflect.TypeOf(embds).String()) + } + if dim != int(runner.embProvider.FieldDim()) { + return fmt.Errorf("The dim set in the schema is inconsistent with the dim of the model, dim in schema is %d, dim of model is %d", runner.embProvider.FieldDim(), dim) + } + return nil +} + func (runner *TextEmbeddingFunction) MaxBatch() int { return runner.embProvider.MaxBatch() } +func (runner *TextEmbeddingFunction) GetCollectionName() string { + return runner.collectionName +} + +func (runner *TextEmbeddingFunction) GetFunctionProvider() string { + return runner.provider +} + +func (runner *TextEmbeddingFunction) GetFunctionTypeName() string { + return runner.functionTypeName +} + +func (runner *TextEmbeddingFunction) GetFunctionName() string { + return runner.functionName +} + func (runner *TextEmbeddingFunction) packToFieldData(embds any) ([]*schemapb.FieldData, error) { var outputField schemapb.FieldData outputField.FieldId = runner.GetOutputFields()[0].FieldID @@ -174,7 +200,7 @@ func (runner *TextEmbeddingFunction) packToFieldData(embds any) ([]*schemapb.Fie return []*schemapb.FieldData{&outputField}, nil } -func (runner *TextEmbeddingFunction) ProcessInsert(inputs []*schemapb.FieldData) ([]*schemapb.FieldData, error) { +func (runner *TextEmbeddingFunction) ProcessInsert(ctx context.Context, inputs []*schemapb.FieldData) ([]*schemapb.FieldData, error) { if len(inputs) != 1 { return nil, fmt.Errorf("Text embedding function only receives one input field, but got [%d]", len(inputs)) } @@ -199,7 +225,7 @@ func (runner *TextEmbeddingFunction) ProcessInsert(inputs []*schemapb.FieldData) return runner.packToFieldData(embds) } -func (runner *TextEmbeddingFunction) ProcessSearch(placeholderGroup *commonpb.PlaceholderGroup) (*commonpb.PlaceholderGroup, error) { +func (runner *TextEmbeddingFunction) ProcessSearch(ctx context.Context, placeholderGroup *commonpb.PlaceholderGroup) (*commonpb.PlaceholderGroup, error) { texts := funcutil.GetVarCharFromPlaceholder(placeholderGroup.Placeholders[0]) // Already checked externally numRows := len(texts) if numRows > runner.MaxBatch() { diff --git a/internal/util/function/text_embedding_function_test.go b/internal/util/function/text_embedding_function_test.go index aee2e0f6ae..a34131679d 100644 --- a/internal/util/function/text_embedding_function_test.go +++ b/internal/util/function/text_embedding_function_test.go @@ -19,6 +19,7 @@ package function import ( + "context" "strings" "testing" @@ -126,7 +127,7 @@ func (s *TextEmbeddingFunctionSuite) TestProcessInsert() { { data := createData([]string{"sentence"}) - ret, err2 := runner.ProcessInsert(data) + ret, err2 := runner.ProcessInsert(context.Background(), data) s.NoError(err2) s.Equal(1, len(ret)) s.Equal(int64(4), ret[0].GetVectors().Dim) @@ -134,7 +135,7 @@ func (s *TextEmbeddingFunctionSuite) TestProcessInsert() { } { data := createData([]string{"sentence 1", "sentence 2", "sentence 3"}) - ret, _ := runner.ProcessInsert(data) + ret, _ := runner.ProcessInsert(context.Background(), data) s.Equal([]float32{0.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 4.0, 2.0, 3.0, 4.0, 5.0}, ret[0].GetVectors().GetFloatVector().Data) } } @@ -158,7 +159,7 @@ func (s *TextEmbeddingFunctionSuite) TestProcessInsert() { { data := createData([]string{"sentence"}) - ret, err2 := runner.ProcessInsert(data) + ret, err2 := runner.ProcessInsert(context.Background(), data) s.NoError(err2) s.Equal(1, len(ret)) s.Equal(int64(4), ret[0].GetVectors().Dim) @@ -166,7 +167,7 @@ func (s *TextEmbeddingFunctionSuite) TestProcessInsert() { } { data := createData([]string{"sentence 1", "sentence 2", "sentence 3"}) - ret, _ := runner.ProcessInsert(data) + ret, _ := runner.ProcessInsert(context.Background(), data) s.Equal([]float32{0.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 4.0, 2.0, 3.0, 4.0, 5.0}, ret[0].GetVectors().GetFloatVector().Data) } } @@ -185,7 +186,7 @@ func (s *TextEmbeddingFunctionSuite) TestAliEmbedding() { OutputFieldIds: []int64{102}, Params: []*commonpb.KeyValuePair{ {Key: Provider, Value: aliDashScopeProvider}, - {Key: modelNameParamKey, Value: TextEmbeddingV3}, + {Key: modelNameParamKey, Value: TestModel}, {Key: dimParamKey, Value: "4"}, {Key: apiKeyParamKey, Value: "mock"}, {Key: embeddingURLParamKey, Value: ts.URL}, @@ -195,7 +196,7 @@ func (s *TextEmbeddingFunctionSuite) TestAliEmbedding() { { data := createData([]string{"sentence"}) - ret, err2 := runner.ProcessInsert(data) + ret, err2 := runner.ProcessInsert(context.Background(), data) s.NoError(err2) s.Equal(1, len(ret)) s.Equal(int64(4), ret[0].GetVectors().Dim) @@ -203,7 +204,7 @@ func (s *TextEmbeddingFunctionSuite) TestAliEmbedding() { } { data := createData([]string{"sentence 1", "sentence 2", "sentence 3"}) - ret, _ := runner.ProcessInsert(data) + ret, _ := runner.ProcessInsert(context.Background(), data) s.Equal([]float32{0.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 4.0, 2.0, 3.0, 4.0, 5.0}, ret[0].GetVectors().GetFloatVector().Data) } @@ -226,7 +227,7 @@ func (s *TextEmbeddingFunctionSuite) TestAliEmbedding() { } data = append(data, &f) data = append(data, &f) - _, err := runner.ProcessInsert(data) + _, err := runner.ProcessInsert(context.Background(), data) s.Error(err) } @@ -242,7 +243,7 @@ func (s *TextEmbeddingFunctionSuite) TestAliEmbedding() { }, } data = append(data, &f) - _, err := runner.ProcessInsert(data) + _, err := runner.ProcessInsert(context.Background(), data) s.Error(err) } // empty input @@ -257,7 +258,7 @@ func (s *TextEmbeddingFunctionSuite) TestAliEmbedding() { }, } data = append(data, &f) - _, err := runner.ProcessInsert(data) + _, err := runner.ProcessInsert(context.Background(), data) s.Error(err) } // large input data @@ -278,7 +279,7 @@ func (s *TextEmbeddingFunctionSuite) TestAliEmbedding() { }, } data = append(data, &f) - _, err := runner.ProcessInsert(data) + _, err := runner.ProcessInsert(context.Background(), data) s.Error(err) } } @@ -377,26 +378,6 @@ func (s *TextEmbeddingFunctionSuite) TestRunnerParamsErr() { s.Error(err) } - // error model name - { - _, err := NewTextEmbeddingFunction(s.schema, &schemapb.FunctionSchema{ - Name: "test", - Type: schemapb.FunctionType_TextEmbedding, - InputFieldNames: []string{"text"}, - OutputFieldNames: []string{"vector"}, - InputFieldIds: []int64{101}, - OutputFieldIds: []int64{102}, - Params: []*commonpb.KeyValuePair{ - {Key: Provider, Value: openAIProvider}, - {Key: modelNameParamKey, Value: "text-embedding-ada-004"}, - {Key: dimParamKey, Value: "4"}, - {Key: apiKeyParamKey, Value: "mock"}, - {Key: embeddingURLParamKey, Value: "mock"}, - }, - }) - s.Error(err) - } - // no openai api key { _, err := NewTextEmbeddingFunction(s.schema, &schemapb.FunctionSchema{ @@ -426,7 +407,7 @@ func (s *TextEmbeddingFunctionSuite) TestNewTextEmbeddings() { OutputFieldIds: []int64{102}, Params: []*commonpb.KeyValuePair{ {Key: Provider, Value: bedrockProvider}, - {Key: modelNameParamKey, Value: BedRockTitanTextEmbeddingsV2}, + {Key: modelNameParamKey, Value: TestModel}, {Key: awsAKIdParamKey, Value: "mock"}, {Key: awsSAKParamKey, Value: "mock"}, {Key: regionParamKey, Value: "mock"}, @@ -450,7 +431,7 @@ func (s *TextEmbeddingFunctionSuite) TestNewTextEmbeddings() { OutputFieldIds: []int64{102}, Params: []*commonpb.KeyValuePair{ {Key: Provider, Value: aliDashScopeProvider}, - {Key: modelNameParamKey, Value: TextEmbeddingV1}, + {Key: modelNameParamKey, Value: TestModel}, {Key: apiKeyParamKey, Value: "mock"}, }, } @@ -472,7 +453,7 @@ func (s *TextEmbeddingFunctionSuite) TestNewTextEmbeddings() { OutputFieldIds: []int64{102}, Params: []*commonpb.KeyValuePair{ {Key: Provider, Value: voyageAIProvider}, - {Key: modelNameParamKey, Value: voyage3}, + {Key: modelNameParamKey, Value: TestModel}, {Key: apiKeyParamKey, Value: "mock"}, }, } @@ -494,7 +475,7 @@ func (s *TextEmbeddingFunctionSuite) TestNewTextEmbeddings() { OutputFieldIds: []int64{102}, Params: []*commonpb.KeyValuePair{ {Key: Provider, Value: siliconflowProvider}, - {Key: modelNameParamKey, Value: bAAIBgeLargeZhV15}, + {Key: modelNameParamKey, Value: TestModel}, {Key: apiKeyParamKey, Value: "mock"}, }, } @@ -516,7 +497,7 @@ func (s *TextEmbeddingFunctionSuite) TestNewTextEmbeddings() { OutputFieldIds: []int64{102}, Params: []*commonpb.KeyValuePair{ {Key: Provider, Value: cohereProvider}, - {Key: modelNameParamKey, Value: embedEnglishLightV20}, + {Key: modelNameParamKey, Value: TestModel}, {Key: apiKeyParamKey, Value: "mock"}, }, } @@ -640,7 +621,7 @@ func (s *TextEmbeddingFunctionSuite) TestProcessSearchFloat32() { s.NoError(err) placeholderGroup := commonpb.PlaceholderGroup{} proto.Unmarshal(placeholderGroupBytes, &placeholderGroup) - _, err = runner.ProcessSearch(&placeholderGroup) + _, err = runner.ProcessSearch(context.Background(), &placeholderGroup) s.Error(err) } @@ -665,7 +646,7 @@ func (s *TextEmbeddingFunctionSuite) TestProcessSearchFloat32() { s.NoError(err) placeholderGroup := commonpb.PlaceholderGroup{} proto.Unmarshal(placeholderGroupBytes, &placeholderGroup) - _, err = runner.ProcessSearch(&placeholderGroup) + _, err = runner.ProcessSearch(context.Background(), &placeholderGroup) s.NoError(err) } } @@ -696,7 +677,7 @@ func (s *TextEmbeddingFunctionSuite) TestProcessInsertInt8() { OutputFieldIds: []int64{102}, Params: []*commonpb.KeyValuePair{ {Key: Provider, Value: cohereProvider}, - {Key: modelNameParamKey, Value: embedEnglishV30}, + {Key: modelNameParamKey, Value: TestModel}, {Key: dimParamKey, Value: "4"}, {Key: apiKeyParamKey, Value: "mock"}, {Key: embeddingURLParamKey, Value: ts.URL}, @@ -706,7 +687,7 @@ func (s *TextEmbeddingFunctionSuite) TestProcessInsertInt8() { { data := createData([]string{"sentence"}) - ret, err2 := runner.ProcessInsert(data) + ret, err2 := runner.ProcessInsert(context.Background(), data) s.NoError(err2) s.Equal(1, len(ret)) s.Equal(int64(4), ret[0].GetVectors().Dim) @@ -743,7 +724,7 @@ func (s *TextEmbeddingFunctionSuite) TestUnsupportedVec() { OutputFieldIds: []int64{102}, Params: []*commonpb.KeyValuePair{ {Key: Provider, Value: cohereProvider}, - {Key: modelNameParamKey, Value: embedEnglishV30}, + {Key: modelNameParamKey, Value: TestModel}, {Key: dimParamKey, Value: "4"}, {Key: apiKeyParamKey, Value: "mock"}, {Key: embeddingURLParamKey, Value: "mock"}, @@ -778,7 +759,7 @@ func (s *TextEmbeddingFunctionSuite) TestProcessSearchInt8() { OutputFieldIds: []int64{102}, Params: []*commonpb.KeyValuePair{ {Key: Provider, Value: cohereProvider}, - {Key: modelNameParamKey, Value: embedEnglishV30}, + {Key: modelNameParamKey, Value: TestModel}, {Key: dimParamKey, Value: "4"}, {Key: apiKeyParamKey, Value: "mock"}, {Key: embeddingURLParamKey, Value: ts.URL}, @@ -807,7 +788,7 @@ func (s *TextEmbeddingFunctionSuite) TestProcessSearchInt8() { s.NoError(err) placeholderGroup := commonpb.PlaceholderGroup{} proto.Unmarshal(placeholderGroupBytes, &placeholderGroup) - _, err = runner.ProcessSearch(&placeholderGroup) + _, err = runner.ProcessSearch(context.Background(), &placeholderGroup) s.NoError(err) } } @@ -880,7 +861,7 @@ func (s *TextEmbeddingFunctionSuite) TestProcessBulkInsertInt8() { OutputFieldIds: []int64{102}, Params: []*commonpb.KeyValuePair{ {Key: Provider, Value: cohereProvider}, - {Key: modelNameParamKey, Value: embedEnglishV30}, + {Key: modelNameParamKey, Value: TestModel}, {Key: dimParamKey, Value: "4"}, {Key: apiKeyParamKey, Value: "mock"}, {Key: embeddingURLParamKey, Value: ts.URL}, diff --git a/internal/util/function/vertexai_embedding_provider.go b/internal/util/function/vertexai_embedding_provider.go index d00b7e1340..ab6566fc37 100644 --- a/internal/util/function/vertexai_embedding_provider.go +++ b/internal/util/function/vertexai_embedding_provider.go @@ -42,7 +42,7 @@ func getVertexAIJsonKey() ([]byte, error) { jsonKeyPath := os.Getenv(vertexServiceAccountJSONEnv) jsonKey, err := os.ReadFile(jsonKeyPath) if err != nil { - vtxKey.initErr = fmt.Errorf("Read service account json file failed, %v", err) + vtxKey.initErr = fmt.Errorf("Vertexai: read service account json file failed, %v", err) return } vtxKey.jsonKey = jsonKey @@ -56,16 +56,6 @@ const ( vertexAISTS string = "STS" ) -func checkTask(modelName string, task string) error { - if task != vertexAIDocRetrival && task != vertexAICodeRetrival && task != vertexAISTS { - return fmt.Errorf("Unsupport task %s, the supported list: [%s, %s, %s]", task, vertexAIDocRetrival, vertexAICodeRetrival, vertexAISTS) - } - if modelName == textMultilingualEmbedding002 && task == vertexAICodeRetrival { - return fmt.Errorf("Model %s doesn't support %s task", textMultilingualEmbedding002, vertexAICodeRetrival) - } - return nil -} - type VertexAIEmbeddingProvider struct { fieldDim int64 @@ -117,18 +107,11 @@ func NewVertexAIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSch if task == "" { task = vertexAIDocRetrival } - if err := checkTask(modelName, task); err != nil { - return nil, err - } if location == "" { location = "us-central1" } - if modelName != textEmbedding005 && modelName != textMultilingualEmbedding002 { - return nil, fmt.Errorf("Unsupported model: %s, only support [%s, %s]", - modelName, textEmbedding005, textMultilingualEmbedding002) - } url := fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:predict", location, projectID, location, modelName) var client *vertexai.VertexAIEmbedding if c == nil { diff --git a/internal/util/function/vertexai_embedding_provider_test.go b/internal/util/function/vertexai_embedding_provider_test.go index 4a3e134782..4018db58eb 100644 --- a/internal/util/function/vertexai_embedding_provider_test.go +++ b/internal/util/function/vertexai_embedding_provider_test.go @@ -67,7 +67,7 @@ func createVertexAIProvider(url string, schema *schemapb.FieldSchema) (textEmbed InputFieldIds: []int64{101}, OutputFieldIds: []int64{102}, Params: []*commonpb.KeyValuePair{ - {Key: modelNameParamKey, Value: textEmbedding005}, + {Key: modelNameParamKey, Value: TestModel}, {Key: locationParamKey, Value: "mock_local"}, {Key: projectIDParamKey, Value: "mock_id"}, {Key: taskTypeParamKey, Value: vertexAICodeRetrival}, @@ -174,21 +174,6 @@ func (s *VertexAITextEmbeddingProviderSuite) TestEmbeddingNubmerNotMatch() { s.Error(err2) } -func (s *VertexAITextEmbeddingProviderSuite) TestCheckVertexAITask() { - err := checkTask(textMultilingualEmbedding002, "UnkownTask") - s.Error(err) - - // textMultilingualEmbedding002 not support vertexAICodeRetrival task - err = checkTask(textMultilingualEmbedding002, vertexAICodeRetrival) - s.Error(err) - - err = checkTask(textEmbedding005, vertexAICodeRetrival) - s.NoError(err) - - err = checkTask(textMultilingualEmbedding002, vertexAISTS) - s.NoError(err) -} - func (s *VertexAITextEmbeddingProviderSuite) TestGetVertexAIJsonKey() { os.Setenv(vertexServiceAccountJSONEnv, "ErrorPath") defer os.Unsetenv(vertexServiceAccountJSONEnv) @@ -205,7 +190,7 @@ func (s *VertexAITextEmbeddingProviderSuite) TestGetTaskType() { InputFieldIds: []int64{101}, OutputFieldIds: []int64{102}, Params: []*commonpb.KeyValuePair{ - {Key: modelNameParamKey, Value: textEmbedding005}, + {Key: modelNameParamKey, Value: TestModel}, {Key: projectIDParamKey, Value: "mock_id"}, {Key: dimParamKey, Value: "4"}, }, @@ -234,13 +219,6 @@ func (s *VertexAITextEmbeddingProviderSuite) TestGetTaskType() { s.Equal(provider.getTaskType(InsertMode), "SEMANTIC_SIMILARITY") s.Equal(provider.getTaskType(SearchMode), "SEMANTIC_SIMILARITY") } - - // invalid task - { - functionSchema.Params[3] = &commonpb.KeyValuePair{Key: taskTypeParamKey, Value: "UnkownTask"} - _, err := NewVertexAIEmbeddingProvider(s.schema.Fields[2], functionSchema, mockClient) - s.Error(err) - } } func (s *VertexAITextEmbeddingProviderSuite) TestCreateVertexAIEmbeddingClient() { @@ -259,7 +237,7 @@ func (s *VertexAITextEmbeddingProviderSuite) TestNewVertexAIEmbeddingProvider() InputFieldIds: []int64{101}, OutputFieldIds: []int64{102}, Params: []*commonpb.KeyValuePair{ - {Key: modelNameParamKey, Value: textEmbedding005}, + {Key: modelNameParamKey, Value: TestModel}, {Key: projectIDParamKey, Value: "mock_id"}, {Key: dimParamKey, Value: "4"}, }, @@ -269,9 +247,4 @@ func (s *VertexAITextEmbeddingProviderSuite) TestNewVertexAIEmbeddingProvider() s.NoError(err) s.True(provider.MaxBatch() > 0) s.Equal(provider.FieldDim(), int64(4)) - - // check model name - functionSchema.Params[0] = &commonpb.KeyValuePair{Key: modelNameParamKey, Value: "UnkownModel"} - _, err = NewVertexAIEmbeddingProvider(s.schema.Fields[2], functionSchema, mockClient) - s.Error(err) } diff --git a/internal/util/function/voyageai_embedding_provider.go b/internal/util/function/voyageai_embedding_provider.go index 2537a3a560..3e3240712d 100644 --- a/internal/util/function/voyageai_embedding_provider.go +++ b/internal/util/function/voyageai_embedding_provider.go @@ -21,6 +21,7 @@ package function import ( "fmt" "os" + "strconv" "strings" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" @@ -34,6 +35,7 @@ type VoyageAIEmbeddingProvider struct { client *voyageai.VoyageAIEmbedding modelName string embedDimParam int64 + truncate bool embdType embeddingType outputType string @@ -62,8 +64,10 @@ func NewVoyageAIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSch if err != nil { return nil, err } - var apiKey, url, modelName string + apiKey, url := parseAKAndURL(functionSchema.Params) + var modelName string dim := int64(0) + truncate := false for _, param := range functionSchema.Params { switch strings.ToLower(param.Key) { @@ -75,27 +79,14 @@ func NewVoyageAIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSch if err != nil { return nil, err } - case apiKeyParamKey: - apiKey = param.Value - case embeddingURLParamKey: - url = param.Value + case truncationParamKey: + if truncate, err = strconv.ParseBool(param.Value); err != nil { + return nil, fmt.Errorf("[%s param's value: %s] is invalid, only supports: [true/false]", truncationParamKey, param.Value) + } default: } } - if modelName != voyage3Large && modelName != voyage3 && modelName != voyage3Lite && modelName != voyageCode3 && modelName != voyageFinance2 && modelName != voyageLaw2 && modelName != voyageCode2 { - return nil, fmt.Errorf("Unsupported model: %s, only support [%s, %s, %s, %s, %s, %s, %s]", - modelName, voyage3Large, voyage3, voyage3Lite, voyageCode3, voyageFinance2, voyageLaw2, voyageCode2) - } - - if dim != 0 { - if modelName != voyage3Large && modelName != voyageCode3 { - return nil, fmt.Errorf("VoyageAI text embedding model: [%s] doesn't supports dim parameter, only [%s, %s] support it.", modelName, voyage3, voyageCode3) - } - if dim != 1024 && dim != 256 && dim != 512 && dim != 2048 { - return nil, fmt.Errorf("VoyageAI text embedding model's dim only supports 2048, 1024 (default), 512, and 256.") - } - } c, err := createVoyageAIEmbeddingClient(apiKey, url) if err != nil { return nil, err @@ -113,16 +104,11 @@ func NewVoyageAIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSch return "int8" }() - if outputType == "int8" { - if modelName != voyage3Large && modelName != voyageCode3 { - return nil, fmt.Errorf("VoyageAI text embedding model: [%s] doesn't supports int8 output_dtype, only [%s, %s] support it.", modelName, voyage3, voyageCode3) - } - } - provider := VoyageAIEmbeddingProvider{ client: c, fieldDim: fieldDim, modelName: modelName, + truncate: truncate, embedDimParam: dim, embdType: embdType, outputType: outputType, @@ -155,7 +141,7 @@ func (provider *VoyageAIEmbeddingProvider) CallEmbedding(texts []string, mode Te if end > numRows { end = numRows } - r, err := provider.client.Embedding(provider.modelName, texts[i:end], int(provider.embedDimParam), textType, provider.outputType, provider.timeoutSec) + r, err := provider.client.Embedding(provider.modelName, texts[i:end], int(provider.embedDimParam), textType, provider.outputType, provider.truncate, provider.timeoutSec) if err != nil { return nil, err } diff --git a/internal/util/function/voyageai_embedding_provider_test.go b/internal/util/function/voyageai_embedding_provider_test.go index 15100ccbfb..0f8ccab7e4 100644 --- a/internal/util/function/voyageai_embedding_provider_test.go +++ b/internal/util/function/voyageai_embedding_provider_test.go @@ -69,7 +69,7 @@ func createVoyageAIProvider(url string, schema *schemapb.FieldSchema, providerNa InputFieldIds: []int64{101}, OutputFieldIds: []int64{102}, Params: []*commonpb.KeyValuePair{ - {Key: modelNameParamKey, Value: voyage3Large}, + {Key: modelNameParamKey, Value: TestModel}, {Key: apiKeyParamKey, Value: "mock"}, {Key: embeddingURLParamKey, Value: url}, {Key: dimParamKey, Value: "1024"}, @@ -136,26 +136,6 @@ func (s *VoyageAITextEmbeddingProviderSuite) TestEmbeddingIn8() { s.NoError(err) } } - - // Invalid model name - { - functionSchema := &schemapb.FunctionSchema{ - Name: "test", - Type: schemapb.FunctionType_TextEmbedding, - InputFieldNames: []string{"text"}, - OutputFieldNames: []string{"vector"}, - InputFieldIds: []int64{101}, - OutputFieldIds: []int64{102}, - Params: []*commonpb.KeyValuePair{ - {Key: modelNameParamKey, Value: voyage3Lite}, - {Key: apiKeyParamKey, Value: "mock"}, - {Key: embeddingURLParamKey, Value: ts.URL}, - {Key: dimParamKey, Value: "1024"}, - }, - } - _, err := NewCohereEmbeddingProvider(int8VecField, functionSchema) - s.Error(err) - } } func (s *VoyageAITextEmbeddingProviderSuite) TestEmbeddingDimNotMatch() { @@ -318,10 +298,11 @@ func (s *VoyageAITextEmbeddingProviderSuite) TestNewVoyageAIEmbeddingProvider() InputFieldIds: []int64{101}, OutputFieldIds: []int64{102}, Params: []*commonpb.KeyValuePair{ - {Key: modelNameParamKey, Value: voyage3Large}, + {Key: modelNameParamKey, Value: TestModel}, {Key: apiKeyParamKey, Value: "mock"}, {Key: embeddingURLParamKey, Value: "mock"}, {Key: dimParamKey, Value: "1024"}, + {Key: truncationParamKey, Value: "true"}, }, } provider, err := NewVoyageAIEmbeddingProvider(s.schema.Fields[2], functionSchema) @@ -329,11 +310,12 @@ func (s *VoyageAITextEmbeddingProviderSuite) TestNewVoyageAIEmbeddingProvider() s.Equal(provider.FieldDim(), int64(1024)) s.True(provider.MaxBatch() > 0) - // Invalid model + // Invalid truncation { - functionSchema.Params[0] = &commonpb.KeyValuePair{Key: modelNameParamKey, Value: "UnkownModel"} + functionSchema.Params[4] = &commonpb.KeyValuePair{Key: truncationParamKey, Value: "Invalid"} _, err := NewVoyageAIEmbeddingProvider(s.schema.Fields[2], functionSchema) s.Error(err) + functionSchema.Params[4] = &commonpb.KeyValuePair{Key: truncationParamKey, Value: "false"} } // Invalid dim diff --git a/pkg/metrics/metrics.go b/pkg/metrics/metrics.go index ff0f064849..aea3edd4ea 100644 --- a/pkg/metrics/metrics.go +++ b/pkg/metrics/metrics.go @@ -127,6 +127,11 @@ const ( cgoTypeLabelName = `cgo_type` queueTypeLabelName = `queue_type` + // model function/UDF labels + functionTypeName = "function_type_name" + functionProvider = "function_provider" + functionName = "function_name" + // entities label LoadedLabel = "loaded" NumEntitiesAllLabel = "all" diff --git a/pkg/metrics/proxy_metrics.go b/pkg/metrics/proxy_metrics.go index 24e6688196..77e5020726 100644 --- a/pkg/metrics/proxy_metrics.go +++ b/pkg/metrics/proxy_metrics.go @@ -445,6 +445,15 @@ var ( Help: "the latency of parse expression", Buckets: buckets, }, []string{nodeIDLabelName, functionLabelName, statusLabelName}) + // ProxyFunctionlatency records the latency of function + ProxyFunctionlatency = prometheus.NewHistogramVec( + prometheus.HistogramOpts{ + Namespace: milvusNamespace, + Subsystem: typeutil.ProxyRole, + Name: "function_udf_call_latency", + Help: "latency of function call", + Buckets: buckets, + }, []string{nodeIDLabelName, collectionName, functionTypeName, functionProvider, functionName}) ) // RegisterProxy registers Proxy metrics @@ -512,6 +521,8 @@ func RegisterProxy(registry *prometheus.Registry) { registry.MustRegister(ProxyParseExpressionLatency) + registry.MustRegister(ProxyFunctionlatency) + RegisterStreamingServiceClient(registry) }