feat: Add function running monitoring (#40358)

#35856 
#40004 
1. Optimize model verification logic
2. Add profiling code

Signed-off-by: junjie.jiang <junjie.jiang@zilliz.com>
This commit is contained in:
junjiejiangjjj 2025-03-10 22:28:05 +08:00 committed by GitHub
parent 0a7e692b6f
commit 359e7efd8e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
38 changed files with 343 additions and 440 deletions

View File

@ -163,13 +163,17 @@ func (it *insertTask) PreExecute(ctx context.Context) error {
// Calculate embedding fields
if function.HasNonBM25Functions(schema.CollectionSchema.Functions, []int64{}) {
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-Insert-call-function-udf")
defer sp.End()
exec, err := function.NewFunctionExecutor(schema.CollectionSchema)
if err != nil {
return err
}
if err := exec.ProcessInsert(it.insertMsg); err != nil {
sp.AddEvent("Create-function-udf")
if err := exec.ProcessInsert(ctx, it.insertMsg); err != nil {
return err
}
sp.AddEvent("Call-function-udf")
}
rowNums := uint32(it.insertMsg.NRows())
// set insertTask.rowIDs

View File

@ -427,13 +427,17 @@ func (t *searchTask) initAdvancedSearchRequest(ctx context.Context) error {
var err error
if function.HasNonBM25Functions(t.schema.CollectionSchema.Functions, queryFieldIds) {
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-AdvancedSearch-call-function-udf")
defer sp.End()
exec, err := function.NewFunctionExecutor(t.schema.CollectionSchema)
if err != nil {
return err
}
if err := exec.ProcessSearch(t.SearchRequest); err != nil {
sp.AddEvent("Create-function-udf")
if err := exec.ProcessSearch(ctx, t.SearchRequest); err != nil {
return err
}
sp.AddEvent("Call-function-udf")
}
t.SearchRequest.GroupByFieldId = t.rankParams.GetGroupByFieldId()
@ -516,13 +520,17 @@ func (t *searchTask) initSearchRequest(ctx context.Context) error {
t.SearchRequest.GroupSize = queryInfo.GroupSize
if function.HasNonBM25Functions(t.schema.CollectionSchema.Functions, []int64{queryInfo.GetQueryFieldId()}) {
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-Search-call-function-udf")
defer sp.End()
exec, err := function.NewFunctionExecutor(t.schema.CollectionSchema)
if err != nil {
return err
}
if err := exec.ProcessSearch(t.SearchRequest); err != nil {
sp.AddEvent("Create-function-udf")
if err := exec.ProcessSearch(ctx, t.SearchRequest); err != nil {
return err
}
sp.AddEvent("Call-function-udf")
}
log.Debug("proxy init search request",
zap.Int64s("plan.OutputFieldIds", plan.GetOutputFieldIds()),

View File

@ -20,7 +20,6 @@ import (
"bytes"
"context"
"encoding/binary"
"fmt"
"math/rand"
"strconv"
"testing"
@ -40,6 +39,7 @@ import (
"github.com/milvus-io/milvus/internal/json"
"github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/internal/querycoordv2/meta"
"github.com/milvus-io/milvus/internal/util/function"
"github.com/milvus-io/milvus/pkg/v2/common"
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
"github.com/milvus-io/milvus/pkg/v2/proto/indexpb"
@ -1025,7 +1025,8 @@ func TestCreateCollectionTask(t *testing.T) {
})
t.Run("collection with embedding function ", func(t *testing.T) {
fmt.Println(schema)
ts := function.CreateOpenAIEmbeddingServer()
defer ts.Close()
schema.Functions = []*schemapb.FunctionSchema{
{
Name: "test",
@ -1036,6 +1037,8 @@ func TestCreateCollectionTask(t *testing.T) {
{Key: "provider", Value: "openai"},
{Key: "model_name", Value: "text-embedding-ada-002"},
{Key: "api_key", Value: "mock"},
{Key: "url", Value: ts.URL},
{Key: "dim", Value: "128"},
},
},
}

View File

@ -148,6 +148,8 @@ func (it *upsertTask) OnEnqueue() error {
}
func (it *upsertTask) insertPreExecute(ctx context.Context) error {
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-Upsert-insertPreExecute")
defer sp.End()
collectionName := it.upsertMsg.InsertMsg.CollectionName
if err := validateCollectionName(collectionName); err != nil {
log.Ctx(ctx).Error("valid collection name failed", zap.String("collectionName", collectionName), zap.Error(err))
@ -156,13 +158,17 @@ func (it *upsertTask) insertPreExecute(ctx context.Context) error {
// Calculate embedding fields
if function.HasNonBM25Functions(it.schema.CollectionSchema.Functions, []int64{}) {
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-Proxy-Upsert-insertPreExecute-call-function-udf")
defer sp.End()
exec, err := function.NewFunctionExecutor(it.schema.CollectionSchema)
if err != nil {
return err
}
if err := exec.ProcessInsert(it.upsertMsg.InsertMsg); err != nil {
sp.AddEvent("Create-function-udf")
if err := exec.ProcessInsert(ctx, it.upsertMsg.InsertMsg); err != nil {
return err
}
sp.AddEvent("Call-function-udf")
}
rowNums := uint32(it.upsertMsg.InsertMsg.NRows())
// set upsertTask.insertRequest.rowIDs

View File

@ -720,8 +720,8 @@ func validateFunction(coll *schemapb.CollectionSchema) error {
return nil
}
func checkFunctionOutputField(function *schemapb.FunctionSchema, fields []*schemapb.FieldSchema) error {
switch function.GetType() {
func checkFunctionOutputField(fSchema *schemapb.FunctionSchema, fields []*schemapb.FieldSchema) error {
switch fSchema.GetType() {
case schemapb.FunctionType_BM25:
if len(fields) != 1 {
return fmt.Errorf("BM25 function only need 1 output field, but got %d", len(fields))
@ -731,8 +731,8 @@ func checkFunctionOutputField(function *schemapb.FunctionSchema, fields []*schem
return fmt.Errorf("BM25 function output field must be a SparseFloatVector field, but got %s", fields[0].DataType.String())
}
case schemapb.FunctionType_TextEmbedding:
if len(fields) != 1 || (fields[0].DataType != schemapb.DataType_FloatVector && fields[0].DataType != schemapb.DataType_Int8Vector) {
return fmt.Errorf("TextEmbedding function output field must be a FloatVector or Int8Vector field")
if err := function.TextEmbeddingOutputsCheck(fields); err != nil {
return err
}
default:
return fmt.Errorf("check output field for unknown function type")

View File

@ -37,6 +37,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/json"
"github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/internal/util/function"
"github.com/milvus-io/milvus/pkg/v2/common"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
@ -2850,6 +2851,8 @@ func TestValidateFunction(t *testing.T) {
func TestValidateModelFunction(t *testing.T) {
t.Run("Valid model function schema", func(t *testing.T) {
ts := function.CreateOpenAIEmbeddingServer()
defer ts.Close()
schema := &schemapb.CollectionSchema{
Fields: []*schemapb.FieldSchema{
{Name: "input_field", DataType: schemapb.DataType_VarChar, TypeParams: []*commonpb.KeyValuePair{{Key: "enable_analyzer", Value: "true"}}},
@ -2877,7 +2880,7 @@ func TestValidateModelFunction(t *testing.T) {
{Key: "provider", Value: "openai"},
{Key: "model_name", Value: "text-embedding-ada-002"},
{Key: "api_key", Value: "mock"},
{Key: "url", Value: "mock_url"},
{Key: "url", Value: ts.URL},
{Key: "dim", Value: "4"},
},
},

View File

@ -60,7 +60,8 @@ func NewAliDashScopeEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functio
if err != nil {
return nil, err
}
var apiKey, url, modelName string
apiKey, url := parseAKAndURL(functionSchema.Params)
var modelName string
var dim int64
for _, param := range functionSchema.Params {
@ -72,26 +73,17 @@ func NewAliDashScopeEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functio
if err != nil {
return nil, err
}
case apiKeyParamKey:
apiKey = param.Value
case embeddingURLParamKey:
url = param.Value
default:
}
}
if modelName != TextEmbeddingV1 && modelName != TextEmbeddingV2 && modelName != TextEmbeddingV3 {
return nil, fmt.Errorf("Unsupported model: %s, only support [%s, %s, %s]",
modelName, TextEmbeddingV1, TextEmbeddingV2, TextEmbeddingV3)
}
c, err := createAliEmbeddingClient(apiKey, url)
if err != nil {
return nil, err
}
maxBatch := 25
if modelName == TextEmbeddingV3 {
if modelName == "text-embedding-v3" {
maxBatch = 6
}

View File

@ -69,9 +69,9 @@ func createAliProvider(url string, schema *schemapb.FieldSchema, providerName st
InputFieldIds: []int64{101},
OutputFieldIds: []int64{102},
Params: []*commonpb.KeyValuePair{
{Key: modelNameParamKey, Value: TextEmbeddingV3},
{Key: apiKeyParamKey, Value: "mock"},
{Key: modelNameParamKey, Value: TestModel},
{Key: embeddingURLParamKey, Value: url},
{Key: apiKeyParamKey, Value: "mock"},
{Key: dimParamKey, Value: "4"},
},
}
@ -85,8 +85,8 @@ func createAliProvider(url string, schema *schemapb.FieldSchema, providerName st
func (s *AliTextEmbeddingProviderSuite) TestEmbedding() {
ts := CreateAliEmbeddingServer()
defer ts.Close()
for _, provderName := range s.providers {
provder, err := createAliProvider(ts.URL, s.schema.Fields[2], provderName)
s.NoError(err)
@ -186,18 +186,13 @@ func (s *AliTextEmbeddingProviderSuite) TestNewAliDashScopeEmbeddingProvider() {
InputFieldIds: []int64{101},
OutputFieldIds: []int64{102},
Params: []*commonpb.KeyValuePair{
{Key: modelNameParamKey, Value: "UnkownModels"},
{Key: modelNameParamKey, Value: TestModel},
{Key: apiKeyParamKey, Value: "mock"},
{Key: embeddingURLParamKey, Value: "mock"},
{Key: dimParamKey, Value: "4"},
},
}
// invalid dim
functionSchema.Params[2] = &commonpb.KeyValuePair{Key: dimParamKey, Value: "Invalid"}
_, err := NewAliDashScopeEmbeddingProvider(s.schema.Fields[2], functionSchema)
s.Error(err)
// invalid dim
functionSchema.Params[0] = &commonpb.KeyValuePair{Key: modelNameParamKey, Value: TextEmbeddingV3}
functionSchema.Params[3] = &commonpb.KeyValuePair{Key: dimParamKey, Value: "Invalid"}
_, err = NewAliDashScopeEmbeddingProvider(s.schema.Fields[2], functionSchema)
s.Error(err)
}

View File

@ -98,9 +98,13 @@ func NewBedrockEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSche
return nil, err
}
case awsAKIdParamKey:
awsAccessKeyId = param.Value
if strings.ToLower(os.Getenv(enableConfigAKAndURL)) != "false" {
awsAccessKeyId = param.Value
}
case awsSAKParamKey:
awsSecretAccessKey = param.Value
if strings.ToLower(os.Getenv(enableConfigAKAndURL)) != "false" {
awsSecretAccessKey = param.Value
}
case regionParamKey:
region = param.Value
case normalizeParamKey:
@ -116,10 +120,6 @@ func NewBedrockEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSche
}
}
if modelName != BedRockTitanTextEmbeddingsV2 {
return nil, fmt.Errorf("Unsupported model: %s, only support [%s]",
modelName, BedRockTitanTextEmbeddingsV2)
}
var client BedrockClient
if c == nil {
client, err = createBedRockEmbeddingClient(awsAccessKeyId, awsSecretAccessKey, region)

View File

@ -64,7 +64,7 @@ func createBedrockProvider(schema *schemapb.FieldSchema, providerName string, di
InputFieldIds: []int64{101},
OutputFieldIds: []int64{102},
Params: []*commonpb.KeyValuePair{
{Key: modelNameParamKey, Value: BedRockTitanTextEmbeddingsV2},
{Key: modelNameParamKey, Value: TestModel},
{Key: apiKeyParamKey, Value: "mock"},
{Key: dimParamKey, Value: "4"},
},
@ -143,7 +143,7 @@ func (s *BedrockTextEmbeddingProviderSuite) TestNewBedrockEmbeddingProvider() {
InputFieldIds: []int64{101},
OutputFieldIds: []int64{102},
Params: []*commonpb.KeyValuePair{
{Key: modelNameParamKey, Value: BedRockTitanTextEmbeddingsV2},
{Key: modelNameParamKey, Value: TestModel},
{Key: awsAKIdParamKey, Value: "mock"},
{Key: awsSAKParamKey, Value: "mock"},
{Key: regionParamKey, Value: "mock"},
@ -164,14 +164,8 @@ func (s *BedrockTextEmbeddingProviderSuite) TestNewBedrockEmbeddingProvider() {
_, err = NewBedrockEmbeddingProvider(fieldSchema, functionSchema, nil)
s.Error(err)
// invalid model name
functionSchema.Params[5] = &commonpb.KeyValuePair{Key: normalizeParamKey, Value: "true"}
functionSchema.Params[0] = &commonpb.KeyValuePair{Key: modelNameParamKey, Value: "UnkownModel"}
_, err = NewBedrockEmbeddingProvider(fieldSchema, functionSchema, nil)
s.Error(err)
// invalid dim
functionSchema.Params[0] = &commonpb.KeyValuePair{Key: modelNameParamKey, Value: BedRockTitanTextEmbeddingsV2}
functionSchema.Params[0] = &commonpb.KeyValuePair{Key: modelNameParamKey, Value: TestModel}
functionSchema.Params[0] = &commonpb.KeyValuePair{Key: dimParamKey, Value: "Invalid"}
_, err = NewBedrockEmbeddingProvider(fieldSchema, functionSchema, nil)
s.Error(err)

View File

@ -62,16 +62,13 @@ func NewCohereEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchem
if err != nil {
return nil, err
}
var apiKey, url, modelName string
apiKey, url := parseAKAndURL(functionSchema.Params)
var modelName string
truncate := "END"
for _, param := range functionSchema.Params {
switch strings.ToLower(param.Key) {
case modelNameParamKey:
modelName = param.Value
case apiKeyParamKey:
apiKey = param.Value
case embeddingURLParamKey:
url = param.Value
case truncateParamKey:
if param.Value != "NONE" && param.Value != "START" && param.Value != "END" {
return nil, fmt.Errorf("Illegal parameters, %s only supports [NONE, START, END]", truncateParamKey)
@ -81,11 +78,6 @@ func NewCohereEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchem
}
}
if modelName != embedEnglishV30 && modelName != embedMultilingualV30 && modelName != embedEnglishLightV30 && modelName != embedMultilingualLightV30 && modelName != embedEnglishV20 && modelName != embedEnglishLightV20 && modelName != embedMultilingualV20 {
return nil, fmt.Errorf("Unsupported model: %s, only support [%s, %s, %s, %s, %s, %s, %s]",
modelName, embedEnglishV30, embedMultilingualV30, embedEnglishLightV30, embedMultilingualLightV30, embedEnglishV20, embedEnglishLightV20, embedMultilingualV20)
}
c, err := createCohereEmbeddingClient(apiKey, url)
if err != nil {
return nil, err
@ -103,12 +95,6 @@ func NewCohereEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchem
return "int8"
}()
if outputType == "int8" {
if modelName != embedEnglishV30 && modelName != embedMultilingualV30 && modelName != embedEnglishLightV30 && modelName != embedMultilingualLightV30 {
return nil, fmt.Errorf("Cohere text embedding model: [%s] doesn't supports int8. Valid for only v3 models.", modelName)
}
}
provider := CohereEmbeddingProvider{
client: c,
fieldDim: fieldDim,
@ -132,7 +118,8 @@ func (provider *CohereEmbeddingProvider) FieldDim() int64 {
// Specifies the type of input passed to the model. Required for embedding models v3 and higher.
func (provider *CohereEmbeddingProvider) getInputType(mode TextEmbeddingMode) string {
if provider.modelName == embedEnglishV20 || provider.modelName == embedEnglishLightV20 || provider.modelName == embedMultilingualV20 {
// v2 models not support instructor
if strings.HasSuffix(provider.modelName, "v2.0") {
return ""
}
if mode == InsertMode {

View File

@ -69,7 +69,7 @@ func createCohereProvider(url string, schema *schemapb.FieldSchema, providerName
InputFieldIds: []int64{101},
OutputFieldIds: []int64{102},
Params: []*commonpb.KeyValuePair{
{Key: modelNameParamKey, Value: embedEnglishLightV30},
{Key: modelNameParamKey, Value: TestModel},
{Key: apiKeyParamKey, Value: "mock"},
{Key: embeddingURLParamKey, Value: url},
},
@ -143,25 +143,6 @@ func (s *CohereTextEmbeddingProviderSuite) TestEmbeddingInt8() {
s.Equal([][]int8{{0, 1, 2, 3}, {1, 2, 3, 4}, {2, 3, 4, 5}}, ret)
}
}
// Invalid model name
{
functionSchema := &schemapb.FunctionSchema{
Name: "test",
Type: schemapb.FunctionType_TextEmbedding,
InputFieldNames: []string{"text"},
OutputFieldNames: []string{"vector"},
InputFieldIds: []int64{101},
OutputFieldIds: []int64{102},
Params: []*commonpb.KeyValuePair{
{Key: modelNameParamKey, Value: embedEnglishLightV20},
{Key: apiKeyParamKey, Value: "mock"},
{Key: embeddingURLParamKey, Value: ts.URL},
},
}
_, err := NewCohereEmbeddingProvider(int8VecField, functionSchema)
s.Error(err)
}
}
func (s *CohereTextEmbeddingProviderSuite) TestEmbeddingDimNotMatch() {
@ -278,7 +259,7 @@ func (s *CohereTextEmbeddingProviderSuite) TestNewCohereProvider() {
InputFieldIds: []int64{101},
OutputFieldIds: []int64{102},
Params: []*commonpb.KeyValuePair{
{Key: modelNameParamKey, Value: embedEnglishLightV20},
{Key: modelNameParamKey, Value: TestModel},
{Key: apiKeyParamKey, Value: "mock"},
},
}
@ -296,12 +277,6 @@ func (s *CohereTextEmbeddingProviderSuite) TestNewCohereProvider() {
functionSchema.Params[2].Value = "Unknow"
_, err = NewCohereEmbeddingProvider(s.schema.Fields[2], functionSchema)
s.Error(err)
// Invalid ModelName
functionSchema.Params[2].Value = "END"
functionSchema.Params[0].Value = "Unknow"
_, err = NewCohereEmbeddingProvider(s.schema.Fields[2], functionSchema)
s.Error(err)
}
func (s *CohereTextEmbeddingProviderSuite) TestGetInputType() {
@ -313,7 +288,7 @@ func (s *CohereTextEmbeddingProviderSuite) TestGetInputType() {
InputFieldIds: []int64{101},
OutputFieldIds: []int64{102},
Params: []*commonpb.KeyValuePair{
{Key: modelNameParamKey, Value: embedEnglishLightV20},
{Key: modelNameParamKey, Value: "model-v2.0"},
{Key: apiKeyParamKey, Value: "mock"},
},
}
@ -323,7 +298,7 @@ func (s *CohereTextEmbeddingProviderSuite) TestGetInputType() {
s.Equal(provider.getInputType(InsertMode), "")
s.Equal(provider.getInputType(SearchMode), "")
functionSchema.Params[0].Value = embedEnglishLightV30
functionSchema.Params[0].Value = "model-v3.0"
provider, err = NewCohereEmbeddingProvider(s.schema.Fields[2], functionSchema)
s.NoError(err)
s.Equal(provider.getInputType(InsertMode), "search_document")

View File

@ -20,8 +20,11 @@ package function
import (
"fmt"
"os"
"strconv"
"strings"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
)
@ -51,20 +54,12 @@ const (
// ali text embedding
const (
TextEmbeddingV1 string = "text-embedding-v1"
TextEmbeddingV2 string = "text-embedding-v2"
TextEmbeddingV3 string = "text-embedding-v3"
dashscopeAKEnvStr string = "MILVUSAI_DASHSCOPE_API_KEY"
)
// openai/azure text embedding
const (
TextEmbeddingAda002 string = "text-embedding-ada-002"
TextEmbedding3Small string = "text-embedding-3-small"
TextEmbedding3Large string = "text-embedding-3-large"
openaiAKEnvStr string = "MILVUSAI_OPENAI_API_KEY"
azureOpenaiAKEnvStr string = "MILVUSAI_AZURE_OPENAI_API_KEY"
@ -76,11 +71,10 @@ const (
// bedrock emebdding
const (
BedRockTitanTextEmbeddingsV2 string = "amazon.titan-embed-text-v2:0"
awsAKIdParamKey string = "aws_access_key_id"
awsSAKParamKey string = "aws_secret_access_key"
regionParamKey string = "regin"
normalizeParamKey string = "normalize"
awsAKIdParamKey string = "aws_access_key_id"
awsSAKParamKey string = "aws_secret_access_key"
regionParamKey string = "regin"
normalizeParamKey string = "normalize"
bedrockAccessKeyId string = "MILVUSAI_BEDROCK_ACCESS_KEY_ID"
bedrockSAKEnvStr string = "MILVUSAI_BEDROCK_SECRET_ACCESS_KEY"
@ -93,48 +87,24 @@ const (
projectIDParamKey string = "projectid"
taskTypeParamKey string = "task"
textEmbedding005 string = "text-embedding-005"
textMultilingualEmbedding002 string = "text-multilingual-embedding-002"
vertexServiceAccountJSONEnv string = "MILVUSAI_GOOGLE_APPLICATION_CREDENTIALS"
)
// voyageAI
const (
voyage3Large string = "voyage-3-large"
voyage3 string = "voyage-3"
voyage3Lite string = "voyage-3-lite"
voyageCode3 string = "voyage-code-3"
voyageFinance2 string = "voyage-finance-2"
voyageLaw2 string = "voyage-law-2"
voyageCode2 string = "voyage-code-2"
voyageAIAKEnvStr string = "MILVUSAI_VOYAGEAI_API_KEY"
truncationParamKey string = "truncation"
voyageAIAKEnvStr string = "MILVUSAI_VOYAGEAI_API_KEY"
)
// cohere
const (
embedEnglishV30 string = "embed-english-v3.0"
embedMultilingualV30 string = "embed-multilingual-v3.0"
embedEnglishLightV30 string = "embed-english-light-v3.0"
embedMultilingualLightV30 string = "embed-multilingual-light-v3.0"
embedEnglishV20 string = "embed-english-v2.0"
embedEnglishLightV20 string = "embed-english-light-v2.0"
embedMultilingualV20 string = "embed-multilingual-v2.0"
cohereAIAKEnvStr string = "MILVUSAI_COHERE_API_KEY"
)
// siliconflow
const (
bAAIBgeLargeZhV15 string = "BAAI/bge-large-zh-v1.5"
bAAIBgeLargeEhV15 string = "BAAI/bge-large-eh-v1.5"
neteaseYoudaoBceEmbeddingBasev1 string = "netease-youdao/bce-embedding-base_v1"
bAAIBgeM3 string = "BAAI/bge-m3"
proBAAIBgeM3 string = "Pro/BAAI/bge-m3 "
siliconflowAKEnvStr string = "MILVUSAI_SILICONFLOW_API_KEY"
)
@ -150,6 +120,23 @@ const (
enableTeiEnvStr string = "MILVUSAI_ENABLE_TEI"
)
const enableConfigAKAndURL string = "ENABLE_CONFIG_AK_AND_URL"
func parseAKAndURL(params []*commonpb.KeyValuePair) (string, string) {
var apiKey, url string
if strings.ToLower(os.Getenv(enableConfigAKAndURL)) != "false" {
for _, param := range params {
switch strings.ToLower(param.Key) {
case apiKeyParamKey:
apiKey = param.Value
case embeddingURLParamKey:
url = param.Value
}
}
}
return apiKey, url
}
func parseAndCheckFieldDim(dimStr string, fieldDim int64, fieldName string) (int64, error) {
dim, err := strconv.ParseInt(dimStr, 10, 64)
if err != nil {

View File

@ -20,6 +20,7 @@ package function
import (
"fmt"
"strings"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
)
@ -27,6 +28,22 @@ import (
type FunctionBase struct {
schema *schemapb.FunctionSchema
outputFields []*schemapb.FieldSchema
collectionName string
functionTypeName string
functionName string
provider string
}
func getProvider(functionSchema *schemapb.FunctionSchema) (string, error) {
for _, param := range functionSchema.Params {
switch strings.ToLower(param.Key) {
case Provider:
return strings.ToLower(param.Value), nil
default:
}
}
return "", fmt.Errorf("The text embedding service provider parameter:[%s] was not found", Provider)
}
func NewFunctionBase(coll *schemapb.CollectionSchema, fSchema *schemapb.FunctionSchema) (*FunctionBase, error) {
@ -45,6 +62,15 @@ func NewFunctionBase(coll *schemapb.CollectionSchema, fSchema *schemapb.Function
return &base, fmt.Errorf("The collection [%s]'s information is wrong, function [%s]'s outputs does not match the schema",
coll.Name, fSchema.Name)
}
provider, err := getProvider(fSchema)
if err != nil {
return nil, err
}
base.collectionName = coll.Name
base.functionName = fSchema.Name
base.provider = provider
base.functionTypeName = fSchema.GetType().String()
return &base, nil
}

View File

@ -19,7 +19,9 @@
package function
import (
"context"
"fmt"
"strconv"
"sync"
"google.golang.org/protobuf/proto"
@ -27,18 +29,26 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/pkg/v2/metrics"
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
"github.com/milvus-io/milvus/pkg/v2/util/timerecord"
)
type Runner interface {
GetSchema() *schemapb.FunctionSchema
GetOutputFields() []*schemapb.FieldSchema
GetCollectionName() string
GetFunctionTypeName() string
GetFunctionName() string
GetFunctionProvider() string
Check() error
MaxBatch() int
ProcessInsert(inputs []*schemapb.FieldData) ([]*schemapb.FieldData, error)
ProcessSearch(placeholderGroup *commonpb.PlaceholderGroup) (*commonpb.PlaceholderGroup, error)
ProcessInsert(ctx context.Context, inputs []*schemapb.FieldData) ([]*schemapb.FieldData, error)
ProcessSearch(ctx context.Context, placeholderGroup *commonpb.PlaceholderGroup) (*commonpb.PlaceholderGroup, error)
ProcessBulkInsert(inputs []storage.FieldData) (map[storage.FieldID]storage.FieldData, error)
}
@ -64,9 +74,18 @@ func createFunction(coll *schemapb.CollectionSchema, schema *schemapb.FunctionSc
// Since bm25 and embedding are implemented in different ways, the bm25 function is not verified here.
func ValidateFunctions(schema *schemapb.CollectionSchema) error {
for _, fSchema := range schema.Functions {
if _, err := createFunction(schema, fSchema); err != nil {
f, err := createFunction(schema, fSchema)
if err != nil {
return err
}
// ignore bm25 function
if f == nil {
continue
}
if err := f.Check(); err != nil {
return fmt.Errorf("Check function [%s:%s] failed, the err is: %v", fSchema.Name, fSchema.GetType().String(), err)
}
}
return nil
}
@ -87,7 +106,7 @@ func NewFunctionExecutor(schema *schemapb.CollectionSchema) (*FunctionExecutor,
return executor, nil
}
func (executor *FunctionExecutor) processSingleFunction(runner Runner, msg *msgstream.InsertMsg) ([]*schemapb.FieldData, error) {
func (executor *FunctionExecutor) processSingleFunction(ctx context.Context, runner Runner, msg *msgstream.InsertMsg) ([]*schemapb.FieldData, error) {
inputs := make([]*schemapb.FieldData, 0, len(runner.GetSchema().GetInputFieldNames()))
for _, name := range runner.GetSchema().GetInputFieldNames() {
for _, field := range msg.FieldsData {
@ -100,14 +119,18 @@ func (executor *FunctionExecutor) processSingleFunction(runner Runner, msg *msgs
return nil, fmt.Errorf("Input field not found")
}
outputs, err := runner.ProcessInsert(inputs)
tr := timerecord.NewTimeRecorder("function ProcessInsert")
outputs, err := runner.ProcessInsert(ctx, inputs)
if err != nil {
return nil, err
}
metrics.ProxyFunctionlatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), runner.GetCollectionName(), runner.GetFunctionTypeName(), runner.GetFunctionProvider(), runner.GetFunctionName()).Observe(float64(tr.RecordSpan().Milliseconds()))
tr.CtxElapse(ctx, "function ProcessInsert done")
return outputs, nil
}
func (executor *FunctionExecutor) ProcessInsert(msg *msgstream.InsertMsg) error {
func (executor *FunctionExecutor) ProcessInsert(ctx context.Context, msg *msgstream.InsertMsg) error {
numRows := msg.NumRows
for _, runner := range executor.runners {
if numRows > uint64(runner.MaxBatch()) {
@ -122,7 +145,7 @@ func (executor *FunctionExecutor) ProcessInsert(msg *msgstream.InsertMsg) error
wg.Add(1)
go func(runner Runner) {
defer wg.Done()
data, err := executor.processSingleFunction(runner, msg)
data, err := executor.processSingleFunction(ctx, runner, msg)
if err != nil {
errChan <- err
return
@ -149,7 +172,7 @@ func (executor *FunctionExecutor) ProcessInsert(msg *msgstream.InsertMsg) error
return nil
}
func (executor *FunctionExecutor) processSingleSearch(runner Runner, placeholderGroup []byte) ([]byte, error) {
func (executor *FunctionExecutor) processSingleSearch(ctx context.Context, runner Runner, placeholderGroup []byte) ([]byte, error) {
pb := &commonpb.PlaceholderGroup{}
proto.Unmarshal(placeholderGroup, pb)
if len(pb.Placeholders) != 1 {
@ -158,14 +181,18 @@ func (executor *FunctionExecutor) processSingleSearch(runner Runner, placeholder
if pb.Placeholders[0].Type != commonpb.PlaceholderType_VarChar {
return placeholderGroup, nil
}
res, err := runner.ProcessSearch(pb)
tr := timerecord.NewTimeRecorder("function ProcessSearch")
res, err := runner.ProcessSearch(ctx, pb)
if err != nil {
return nil, err
}
metrics.ProxyFunctionlatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), runner.GetCollectionName(), runner.GetFunctionTypeName(), runner.GetFunctionProvider(), runner.GetFunctionName()).Observe(float64(tr.RecordSpan().Milliseconds()))
tr.CtxElapse(ctx, "function ProcessSearch done")
return proto.Marshal(res)
}
func (executor *FunctionExecutor) prcessSearch(req *internalpb.SearchRequest) error {
func (executor *FunctionExecutor) prcessSearch(ctx context.Context, req *internalpb.SearchRequest) error {
runner, exist := executor.runners[req.FieldId]
if !exist {
return fmt.Errorf("Can not found function in field %d", req.FieldId)
@ -173,7 +200,7 @@ func (executor *FunctionExecutor) prcessSearch(req *internalpb.SearchRequest) er
if req.Nq > int64(runner.MaxBatch()) {
return fmt.Errorf("Nq [%d] > function [%s]'s max batch [%d]", req.Nq, runner.GetSchema().Name, runner.MaxBatch())
}
if newHolder, err := executor.processSingleSearch(runner, req.GetPlaceholderGroup()); err != nil {
if newHolder, err := executor.processSingleSearch(ctx, runner, req.GetPlaceholderGroup()); err != nil {
return err
} else {
req.PlaceholderGroup = newHolder
@ -181,7 +208,7 @@ func (executor *FunctionExecutor) prcessSearch(req *internalpb.SearchRequest) er
return nil
}
func (executor *FunctionExecutor) prcessAdvanceSearch(req *internalpb.SearchRequest) error {
func (executor *FunctionExecutor) prcessAdvanceSearch(ctx context.Context, req *internalpb.SearchRequest) error {
outputs := make(chan map[int64][]byte, len(req.GetSubReqs()))
errChan := make(chan error, len(req.GetSubReqs()))
var wg sync.WaitGroup
@ -193,7 +220,7 @@ func (executor *FunctionExecutor) prcessAdvanceSearch(req *internalpb.SearchRequ
wg.Add(1)
go func(runner Runner, idx int64, placeholderGroup []byte) {
defer wg.Done()
if newHolder, err := executor.processSingleSearch(runner, placeholderGroup); err != nil {
if newHolder, err := executor.processSingleSearch(ctx, runner, placeholderGroup); err != nil {
errChan <- err
} else {
outputs <- map[int64][]byte{idx: newHolder}
@ -216,11 +243,11 @@ func (executor *FunctionExecutor) prcessAdvanceSearch(req *internalpb.SearchRequ
return nil
}
func (executor *FunctionExecutor) ProcessSearch(req *internalpb.SearchRequest) error {
func (executor *FunctionExecutor) ProcessSearch(ctx context.Context, req *internalpb.SearchRequest) error {
if !req.IsAdvanced {
return executor.prcessSearch(req)
return executor.prcessSearch(ctx, req)
}
return executor.prcessAdvanceSearch(req)
return executor.prcessAdvanceSearch(ctx, req)
}
func (executor *FunctionExecutor) processSingleBulkInsert(runner Runner, data *storage.InsertData) (map[storage.FieldID]storage.FieldData, error) {

View File

@ -19,6 +19,7 @@
package function
import (
"context"
"encoding/json"
"io"
"net/http"
@ -148,7 +149,7 @@ func (s *FunctionExecutorSuite) TestExecutor() {
exec, err := NewFunctionExecutor(schema)
s.NoError(err)
msg := s.createMsg([]string{"sentence", "sentence"})
exec.ProcessInsert(msg)
exec.ProcessInsert(context.Background(), msg)
s.Equal(len(msg.FieldsData), 3)
}
@ -183,7 +184,7 @@ func (s *FunctionExecutorSuite) TestErrorEmbedding() {
exec, err := NewFunctionExecutor(schema)
s.NoError(err)
msg := s.createMsg([]string{"sentence", "sentence"})
err = exec.ProcessInsert(msg)
err = exec.ProcessInsert(context.Background(), msg)
s.Error(err)
}
@ -225,7 +226,7 @@ func (s *FunctionExecutorSuite) TestInternalPrcessSearch() {
IsAdvanced: false,
FieldId: 102,
}
err = exec.ProcessSearch(req)
err = exec.ProcessSearch(context.Background(), req)
s.NoError(err)
// No function found
@ -235,7 +236,7 @@ func (s *FunctionExecutorSuite) TestInternalPrcessSearch() {
IsAdvanced: false,
FieldId: 111,
}
err = exec.ProcessSearch(req)
err = exec.ProcessSearch(context.Background(), req)
s.Error(err)
// Large search nq
@ -245,7 +246,7 @@ func (s *FunctionExecutorSuite) TestInternalPrcessSearch() {
IsAdvanced: false,
FieldId: 102,
}
err = exec.ProcessSearch(req)
err = exec.ProcessSearch(context.Background(), req)
s.Error(err)
}
@ -277,12 +278,12 @@ func (s *FunctionExecutorSuite) TestInternalPrcessSearch() {
IsAdvanced: true,
SubReqs: []*internalpb.SubSearchRequest{subReq},
}
err = exec.ProcessSearch(req)
err = exec.ProcessSearch(context.Background(), req)
s.NoError(err)
// Large nq
subReq.Nq = 1000
err = exec.ProcessSearch(req)
err = exec.ProcessSearch(context.Background(), req)
s.Error(err)
}
}
@ -318,7 +319,7 @@ func (s *FunctionExecutorSuite) TestInternalPrcessSearchFailed() {
IsAdvanced: false,
FieldId: 102,
}
err = exec.ProcessSearch(req)
err = exec.ProcessSearch(context.Background(), req)
s.Error(err)
}
// AdvanceSearch
@ -332,7 +333,7 @@ func (s *FunctionExecutorSuite) TestInternalPrcessSearchFailed() {
IsAdvanced: true,
SubReqs: []*internalpb.SubSearchRequest{subReq},
}
err = exec.ProcessSearch(req)
err = exec.ProcessSearch(context.Background(), req)
s.Error(err)
}
}

View File

@ -36,18 +36,7 @@ import (
"github.com/milvus-io/milvus/internal/util/function/models/voyageai"
)
// func mockEmbedding(texts []string, dim int) [][]float32 {
// embeddings := make([][]float32, 0)
// for i := 0; i < len(texts); i++ {
// f := float32(i)
// emb := make([]float32, 0)
// for j := 0; j < dim; j++ {
// emb = append(emb, f+float32(j)*0.1)
// }
// embeddings = append(embeddings, emb)
// }
// return embeddings
// }
const TestModel string = "TestModel"
func mockEmbedding[T int8 | float32](texts []string, dim int) [][]T {
embeddings := make([][]T, 0)

View File

@ -17,7 +17,6 @@
package ali
import (
"bytes"
"context"
"encoding/json"
"fmt"
@ -135,14 +134,11 @@ func (c *AliDashScopeEmbedding) Embedding(modelName string, texts []string, dim
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutSec)*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.url, bytes.NewBuffer(data))
if err != nil {
return nil, err
headers := map[string]string{
"Content-Type": "application/json",
"Authorization": fmt.Sprintf("Bearer %s", c.apiKey),
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.apiKey))
body, err := utils.RetrySend(req, 3)
body, err := utils.RetrySend(ctx, data, http.MethodPost, c.url, headers, 3, 1)
if err != nil {
return nil, err
}

View File

@ -17,7 +17,6 @@
package cohere
import (
"bytes"
"context"
"encoding/json"
"fmt"
@ -100,15 +99,12 @@ func (c *CohereEmbedding) Embedding(modelName string, texts []string, inputType
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutSec)*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.url, bytes.NewBuffer(data))
if err != nil {
return nil, err
headers := map[string]string{
"accept": "application/json",
"Content-Type": "application/json",
"Authorization": fmt.Sprintf("bearer %s", c.apiKey),
}
req.Header.Set("accept", "application/json")
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", fmt.Sprintf("bearer %s", c.apiKey))
body, err := utils.RetrySend(req, 3)
body, err := utils.RetrySend(ctx, data, http.MethodPost, c.url, headers, 3, 1)
if err != nil {
return nil, err
}

View File

@ -17,7 +17,6 @@
package openai
import (
"bytes"
"context"
"encoding/json"
"fmt"
@ -148,18 +147,10 @@ func (c *openAIBase) embedding(url string, headers map[string]string, modelName
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutSec)*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewBuffer(data))
body, err := utils.RetrySend(ctx, data, http.MethodPost, url, headers, 3, 1)
if err != nil {
return nil, err
}
for key, value := range headers {
req.Header.Set(key, value)
}
body, err := utils.RetrySend(req, 3)
if err != nil {
return nil, err
}
var res EmbeddingResponse
err = json.Unmarshal(body, &res)
if err != nil {

View File

@ -218,7 +218,8 @@ func TestEmbeddingFailed(t *testing.T) {
func TestTimeout(t *testing.T) {
var st int32 = 0
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(3 * time.Second)
// (Timeout 1s + Wait 1s) * Retry 3
time.Sleep(6 * time.Second)
atomic.AddInt32(&st, 1)
w.WriteHeader(http.StatusUnauthorized)
}))

View File

@ -17,7 +17,6 @@
package siliconflow
import (
"bytes"
"context"
"encoding/json"
"fmt"
@ -115,14 +114,12 @@ func (c *SiliconflowEmbedding) Embedding(modelName string, texts []string, encod
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutSec)*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.url, bytes.NewBuffer(data))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.apiKey))
body, err := utils.RetrySend(req, 3)
headers := map[string]string{
"Content-Type": "application/json",
"Authorization": fmt.Sprintf("Bearer %s", c.apiKey),
}
body, err := utils.RetrySend(ctx, data, http.MethodPost, c.url, headers, 3, 1)
if err != nil {
return nil, err
}

View File

@ -17,7 +17,6 @@
package tei
import (
"bytes"
"context"
"encoding/json"
"fmt"
@ -46,10 +45,10 @@ func NewTEIEmbeddingClient(apiKey string, endpoint string) (*TEIEmbedding, error
return nil, err
}
if base.Scheme != "http" && base.Scheme != "https" {
return nil, fmt.Errorf("%s is not a valid http/https link", endpoint)
return nil, fmt.Errorf("endpoint: [%s] is not a valid http/https link", endpoint)
}
if base.Host == "" {
return nil, fmt.Errorf("%s is not a valid http/https link", endpoint)
return nil, fmt.Errorf("endpoint: [%s] is not a valid http/https link", endpoint)
}
base.Path = "/embed"
@ -88,17 +87,13 @@ func (c *TEIEmbedding) Embedding(texts []string, truncate bool, truncationDirect
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutSec)*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.url, bytes.NewBuffer(data))
if err != nil {
return nil, err
headers := map[string]string{
"Content-Type": "application/json",
}
req.Header.Set("Content-Type", "application/json")
if c.apiKey != "" {
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.apiKey))
headers["Authorization"] = fmt.Sprintf("Bearer %s", c.apiKey)
}
body, err := utils.RetrySend(req, 3)
body, err := utils.RetrySend(ctx, data, http.MethodPost, c.url, headers, 3, 1)
if err != nil {
return nil, err
}

View File

@ -17,9 +17,12 @@
package utils
import (
"bytes"
"context"
"fmt"
"io"
"net/http"
"time"
)
const DefaultTimeout int64 = 30
@ -33,23 +36,31 @@ func send(req *http.Request) ([]byte, error) {
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
return nil, fmt.Errorf("Call service faild, read response failed, errs:[%v]", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("Call %s faild, errs:[%s, %s]", req.URL, resp.Status, body)
return nil, fmt.Errorf("Call service faild, errs:[%s, %s]", resp.Status, body)
}
return body, nil
}
func RetrySend(req *http.Request, maxRetries int) ([]byte, error) {
func RetrySend(ctx context.Context, data []byte, httpMethod string, url string, headers map[string]string, maxRetries int, retryDelay int) ([]byte, error) {
var err error
var res []byte
var body []byte
for i := 0; i < maxRetries; i++ {
res, err = send(req)
if err == nil {
return res, nil
req, reqErr := http.NewRequestWithContext(ctx, httpMethod, url, bytes.NewBuffer(data))
if reqErr != nil {
return nil, reqErr
}
for k, v := range headers {
req.Header.Set(k, v)
}
body, err = send(req)
if err == nil {
return body, nil
}
time.Sleep(time.Duration(retryDelay) * time.Second)
}
return nil, err
}

View File

@ -17,7 +17,6 @@
package vertexai
import (
"bytes"
"context"
"encoding/json"
"fmt"
@ -134,10 +133,6 @@ func (c *VertexAIEmbedding) Embedding(modelName string, texts []string, dim int6
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutSec)*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.url, bytes.NewBuffer(data))
if err != nil {
return nil, err
}
var token string
if c.token != "" {
token = c.token
@ -148,9 +143,11 @@ func (c *VertexAIEmbedding) Embedding(modelName string, texts []string, dim int6
}
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
body, err := utils.RetrySend(req, 3)
headers := map[string]string{
"Content-Type": "application/json",
"Authorization": fmt.Sprintf("Bearer %s", token),
}
body, err := utils.RetrySend(ctx, data, http.MethodPost, c.url, headers, 3, 1)
if err != nil {
return nil, err
}

View File

@ -17,7 +17,6 @@
package voyageai
import (
"bytes"
"context"
"encoding/json"
"fmt"
@ -108,7 +107,7 @@ func (c *VoyageAIEmbedding) Check() error {
return nil
}
func (c *VoyageAIEmbedding) Embedding(modelName string, texts []string, dim int, textType string, outputType string, timeoutSec int64) (any, error) {
func (c *VoyageAIEmbedding) Embedding(modelName string, texts []string, dim int, textType string, outputType string, truncation bool, timeoutSec int64) (any, error) {
if outputType != "float" && outputType != "int8" {
return nil, fmt.Errorf("Voyageai: unsupport output type: [%s], only support float and int8", outputType)
}
@ -117,9 +116,11 @@ func (c *VoyageAIEmbedding) Embedding(modelName string, texts []string, dim int,
r.Input = texts
r.InputType = textType
r.OutputDtype = outputType
r.Truncation = truncation
if dim != 0 {
r.OutputDimension = int64(dim)
}
data, err := json.Marshal(r)
if err != nil {
return nil, err
@ -131,14 +132,11 @@ func (c *VoyageAIEmbedding) Embedding(modelName string, texts []string, dim int,
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutSec)*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.url, bytes.NewBuffer(data))
if err != nil {
return nil, err
headers := map[string]string{
"Content-Type": "application/json",
"Authorization": fmt.Sprintf("Bearer %s", c.apiKey),
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.apiKey))
body, err := utils.RetrySend(req, 3)
body, err := utils.RetrySend(ctx, data, http.MethodPost, c.url, headers, 3, 1)
if err != nil {
return nil, err
}

View File

@ -98,7 +98,7 @@ func TestEmbeddingOK(t *testing.T) {
c := NewVoyageAIEmbeddingClient("mock_key", url)
err := c.Check()
assert.True(t, err == nil)
r, err := c.Embedding("voyage-3", []string{"sentence"}, 0, "query", "float", 0)
r, err := c.Embedding("voyage-3", []string{"sentence"}, 0, "query", "float", true, 0)
ret := r.(*EmbeddingResponse[float32])
assert.True(t, err == nil)
assert.Equal(t, ret.Data[0].Index, 0)
@ -158,7 +158,7 @@ func TestEmbeddingInt8Embed(t *testing.T) {
c := NewVoyageAIEmbeddingClient("mock_key", url)
err := c.Check()
assert.True(t, err == nil)
r, err := c.Embedding("voyage-3", []string{"sentence"}, 0, "query", "int8", 0)
r, err := c.Embedding("voyage-3", []string{"sentence"}, 0, "query", "int8", false, 0)
ret := r.(*EmbeddingResponse[int8])
assert.True(t, err == nil)
assert.Equal(t, ret.Data[0].Index, 0)
@ -169,7 +169,7 @@ func TestEmbeddingInt8Embed(t *testing.T) {
assert.Equal(t, ret.Data[1].Embedding, []int8{3, 4})
assert.Equal(t, ret.Data[2].Embedding, []int8{5, 6})
_, err = c.Embedding("voyage-3", []string{"sentence"}, 0, "query", "unknow", 0)
_, err = c.Embedding("voyage-3", []string{"sentence"}, 0, "query", "unknow", true, 0)
assert.Error(t, err)
}
}
@ -186,7 +186,7 @@ func TestEmbeddingFailed(t *testing.T) {
c := NewVoyageAIEmbeddingClient("mock_key", url)
err := c.Check()
assert.True(t, err == nil)
_, err = c.Embedding("voyage-3", []string{"sentence"}, 0, "query", "float", 0)
_, err = c.Embedding("voyage-3", []string{"sentence"}, 0, "query", "float", false, 0)
assert.True(t, err != nil)
}
}

View File

@ -81,7 +81,8 @@ func newOpenAIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchem
if err != nil {
return nil, err
}
var apiKey, url, modelName, user string
apiKey, url := parseAKAndURL(functionSchema.Params)
var modelName, user string
var dim int64
for _, param := range functionSchema.Params {
@ -95,21 +96,12 @@ func newOpenAIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchem
}
case userParamKey:
user = param.Value
case apiKeyParamKey:
apiKey = param.Value
case embeddingURLParamKey:
url = param.Value
default:
}
}
var c openai.OpenAIEmbeddingInterface
if !isAzure {
if modelName != TextEmbeddingAda002 && modelName != TextEmbedding3Small && modelName != TextEmbedding3Large {
return nil, fmt.Errorf("Unsupported model: %s, only support [%s, %s, %s]",
modelName, TextEmbeddingAda002, TextEmbedding3Small, TextEmbedding3Large)
}
c, err = createOpenAIEmbeddingClient(apiKey, url)
if err != nil {
return nil, err

View File

@ -60,25 +60,17 @@ func NewSiliconflowEmbeddingProvider(fieldSchema *schemapb.FieldSchema, function
if err != nil {
return nil, err
}
var apiKey, url, modelName string
apiKey, url := parseAKAndURL(functionSchema.Params)
var modelName string
for _, param := range functionSchema.Params {
switch strings.ToLower(param.Key) {
case modelNameParamKey:
modelName = param.Value
case apiKeyParamKey:
apiKey = param.Value
case embeddingURLParamKey:
url = param.Value
default:
}
}
if modelName != bAAIBgeLargeZhV15 && modelName != bAAIBgeLargeEhV15 && modelName != neteaseYoudaoBceEmbeddingBasev1 && modelName != bAAIBgeM3 && modelName != proBAAIBgeM3 {
return nil, fmt.Errorf("Unsupported model: %s, only support [%s, %s, %s, %s, %s]",
modelName, bAAIBgeLargeZhV15, bAAIBgeLargeEhV15, neteaseYoudaoBceEmbeddingBasev1, bAAIBgeM3, proBAAIBgeM3)
}
c, err := createSiliconflowEmbeddingClient(apiKey, url)
if err != nil {
return nil, err

View File

@ -69,7 +69,7 @@ func createSiliconflowProvider(url string, schema *schemapb.FieldSchema, provide
InputFieldIds: []int64{101},
OutputFieldIds: []int64{102},
Params: []*commonpb.KeyValuePair{
{Key: modelNameParamKey, Value: bAAIBgeLargeEhV15},
{Key: modelNameParamKey, Value: TestModel},
{Key: apiKeyParamKey, Value: "mock"},
{Key: embeddingURLParamKey, Value: url},
},
@ -187,7 +187,7 @@ func (s *SiliconflowTextEmbeddingProviderSuite) TestNewSiliconflowEmbeddingProvi
InputFieldIds: []int64{101},
OutputFieldIds: []int64{102},
Params: []*commonpb.KeyValuePair{
{Key: modelNameParamKey, Value: bAAIBgeLargeEhV15},
{Key: modelNameParamKey, Value: TestModel},
{Key: apiKeyParamKey, Value: "mock"},
{Key: embeddingURLParamKey, Value: "mock"},
},
@ -196,11 +196,4 @@ func (s *SiliconflowTextEmbeddingProviderSuite) TestNewSiliconflowEmbeddingProvi
s.NoError(err)
s.Equal(provider.FieldDim(), int64(4))
s.True(provider.MaxBatch() > 0)
// Invalid model
{
functionSchema.Params[0] = &commonpb.KeyValuePair{Key: modelNameParamKey, Value: "UnkownModel"}
_, err := NewSiliconflowEmbeddingProvider(s.schema.Fields[2], functionSchema)
s.Error(err)
}
}

View File

@ -19,8 +19,9 @@
package function
import (
"context"
"fmt"
"strings"
"reflect"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
@ -44,6 +45,13 @@ const (
teiProvider string = "tei"
)
func TextEmbeddingOutputsCheck(fields []*schemapb.FieldSchema) error {
if len(fields) != 1 || (fields[0].DataType != schemapb.DataType_FloatVector && fields[0].DataType != schemapb.DataType_Int8Vector) {
return fmt.Errorf("TextEmbedding function output field must be a FloatVector or Int8Vector field")
}
return nil
}
// Text embedding for retrieval task
type textEmbeddingProvider interface {
MaxBatch() int
@ -51,17 +59,6 @@ type textEmbeddingProvider interface {
FieldDim() int64
}
func getProvider(functionSchema *schemapb.FunctionSchema) (string, error) {
for _, param := range functionSchema.Params {
switch strings.ToLower(param.Key) {
case Provider:
return strings.ToLower(param.Value), nil
default:
}
}
return "", fmt.Errorf("The text embedding service provider parameter:[%s] was not found", Provider)
}
type TextEmbeddingFunction struct {
FunctionBase
@ -82,20 +79,13 @@ func NewTextEmbeddingFunction(coll *schemapb.CollectionSchema, functionSchema *s
return nil, err
}
if base.outputFields[0].DataType != schemapb.DataType_FloatVector && base.outputFields[0].DataType != schemapb.DataType_Int8Vector {
return nil, fmt.Errorf("Text embedding function's output field not match, needs [%s, %s], got [%s]",
schemapb.DataType_name[int32(schemapb.DataType_FloatVector)],
schemapb.DataType_name[int32(schemapb.DataType_Int8Vector)],
schemapb.DataType_name[int32(base.outputFields[0].DataType)])
}
provider, err := getProvider(functionSchema)
if err != nil {
if err := TextEmbeddingOutputsCheck(base.outputFields); err != nil {
return nil, err
}
var embP textEmbeddingProvider
var newProviderErr error
switch provider {
switch base.provider {
case openAIProvider:
embP, newProviderErr = NewOpenAIEmbeddingProvider(base.outputFields[0], functionSchema)
case azureOpenAIProvider:
@ -115,7 +105,7 @@ func NewTextEmbeddingFunction(coll *schemapb.CollectionSchema, functionSchema *s
case teiProvider:
embP, newProviderErr = NewTEIEmbeddingProvider(base.outputFields[0], functionSchema)
default:
return nil, fmt.Errorf("Unsupported text embedding service provider: [%s] , list of supported [%s, %s, %s, %s, %s, %s, %s, %s, %s]", provider, openAIProvider, azureOpenAIProvider, aliDashScopeProvider, bedrockProvider, vertexAIProvider, voyageAIProvider, cohereProvider, siliconflowProvider, teiProvider)
return nil, fmt.Errorf("Unsupported text embedding service provider: [%s] , list of supported [%s, %s, %s, %s, %s, %s, %s, %s, %s]", base.provider, openAIProvider, azureOpenAIProvider, aliDashScopeProvider, bedrockProvider, vertexAIProvider, voyageAIProvider, cohereProvider, siliconflowProvider, teiProvider)
}
if newProviderErr != nil {
@ -127,10 +117,46 @@ func NewTextEmbeddingFunction(coll *schemapb.CollectionSchema, functionSchema *s
}, nil
}
func (runner *TextEmbeddingFunction) Check() error {
embds, err := runner.embProvider.CallEmbedding([]string{"check"}, InsertMode)
if err != nil {
return err
}
dim := 0
switch embds := embds.(type) {
case [][]float32:
dim = len(embds[0])
case [][]int8:
dim = len(embds[0])
default:
return fmt.Errorf("Unsupport embedding type: %s", reflect.TypeOf(embds).String())
}
if dim != int(runner.embProvider.FieldDim()) {
return fmt.Errorf("The dim set in the schema is inconsistent with the dim of the model, dim in schema is %d, dim of model is %d", runner.embProvider.FieldDim(), dim)
}
return nil
}
func (runner *TextEmbeddingFunction) MaxBatch() int {
return runner.embProvider.MaxBatch()
}
func (runner *TextEmbeddingFunction) GetCollectionName() string {
return runner.collectionName
}
func (runner *TextEmbeddingFunction) GetFunctionProvider() string {
return runner.provider
}
func (runner *TextEmbeddingFunction) GetFunctionTypeName() string {
return runner.functionTypeName
}
func (runner *TextEmbeddingFunction) GetFunctionName() string {
return runner.functionName
}
func (runner *TextEmbeddingFunction) packToFieldData(embds any) ([]*schemapb.FieldData, error) {
var outputField schemapb.FieldData
outputField.FieldId = runner.GetOutputFields()[0].FieldID
@ -174,7 +200,7 @@ func (runner *TextEmbeddingFunction) packToFieldData(embds any) ([]*schemapb.Fie
return []*schemapb.FieldData{&outputField}, nil
}
func (runner *TextEmbeddingFunction) ProcessInsert(inputs []*schemapb.FieldData) ([]*schemapb.FieldData, error) {
func (runner *TextEmbeddingFunction) ProcessInsert(ctx context.Context, inputs []*schemapb.FieldData) ([]*schemapb.FieldData, error) {
if len(inputs) != 1 {
return nil, fmt.Errorf("Text embedding function only receives one input field, but got [%d]", len(inputs))
}
@ -199,7 +225,7 @@ func (runner *TextEmbeddingFunction) ProcessInsert(inputs []*schemapb.FieldData)
return runner.packToFieldData(embds)
}
func (runner *TextEmbeddingFunction) ProcessSearch(placeholderGroup *commonpb.PlaceholderGroup) (*commonpb.PlaceholderGroup, error) {
func (runner *TextEmbeddingFunction) ProcessSearch(ctx context.Context, placeholderGroup *commonpb.PlaceholderGroup) (*commonpb.PlaceholderGroup, error) {
texts := funcutil.GetVarCharFromPlaceholder(placeholderGroup.Placeholders[0]) // Already checked externally
numRows := len(texts)
if numRows > runner.MaxBatch() {

View File

@ -19,6 +19,7 @@
package function
import (
"context"
"strings"
"testing"
@ -126,7 +127,7 @@ func (s *TextEmbeddingFunctionSuite) TestProcessInsert() {
{
data := createData([]string{"sentence"})
ret, err2 := runner.ProcessInsert(data)
ret, err2 := runner.ProcessInsert(context.Background(), data)
s.NoError(err2)
s.Equal(1, len(ret))
s.Equal(int64(4), ret[0].GetVectors().Dim)
@ -134,7 +135,7 @@ func (s *TextEmbeddingFunctionSuite) TestProcessInsert() {
}
{
data := createData([]string{"sentence 1", "sentence 2", "sentence 3"})
ret, _ := runner.ProcessInsert(data)
ret, _ := runner.ProcessInsert(context.Background(), data)
s.Equal([]float32{0.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 4.0, 2.0, 3.0, 4.0, 5.0}, ret[0].GetVectors().GetFloatVector().Data)
}
}
@ -158,7 +159,7 @@ func (s *TextEmbeddingFunctionSuite) TestProcessInsert() {
{
data := createData([]string{"sentence"})
ret, err2 := runner.ProcessInsert(data)
ret, err2 := runner.ProcessInsert(context.Background(), data)
s.NoError(err2)
s.Equal(1, len(ret))
s.Equal(int64(4), ret[0].GetVectors().Dim)
@ -166,7 +167,7 @@ func (s *TextEmbeddingFunctionSuite) TestProcessInsert() {
}
{
data := createData([]string{"sentence 1", "sentence 2", "sentence 3"})
ret, _ := runner.ProcessInsert(data)
ret, _ := runner.ProcessInsert(context.Background(), data)
s.Equal([]float32{0.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 4.0, 2.0, 3.0, 4.0, 5.0}, ret[0].GetVectors().GetFloatVector().Data)
}
}
@ -185,7 +186,7 @@ func (s *TextEmbeddingFunctionSuite) TestAliEmbedding() {
OutputFieldIds: []int64{102},
Params: []*commonpb.KeyValuePair{
{Key: Provider, Value: aliDashScopeProvider},
{Key: modelNameParamKey, Value: TextEmbeddingV3},
{Key: modelNameParamKey, Value: TestModel},
{Key: dimParamKey, Value: "4"},
{Key: apiKeyParamKey, Value: "mock"},
{Key: embeddingURLParamKey, Value: ts.URL},
@ -195,7 +196,7 @@ func (s *TextEmbeddingFunctionSuite) TestAliEmbedding() {
{
data := createData([]string{"sentence"})
ret, err2 := runner.ProcessInsert(data)
ret, err2 := runner.ProcessInsert(context.Background(), data)
s.NoError(err2)
s.Equal(1, len(ret))
s.Equal(int64(4), ret[0].GetVectors().Dim)
@ -203,7 +204,7 @@ func (s *TextEmbeddingFunctionSuite) TestAliEmbedding() {
}
{
data := createData([]string{"sentence 1", "sentence 2", "sentence 3"})
ret, _ := runner.ProcessInsert(data)
ret, _ := runner.ProcessInsert(context.Background(), data)
s.Equal([]float32{0.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 4.0, 2.0, 3.0, 4.0, 5.0}, ret[0].GetVectors().GetFloatVector().Data)
}
@ -226,7 +227,7 @@ func (s *TextEmbeddingFunctionSuite) TestAliEmbedding() {
}
data = append(data, &f)
data = append(data, &f)
_, err := runner.ProcessInsert(data)
_, err := runner.ProcessInsert(context.Background(), data)
s.Error(err)
}
@ -242,7 +243,7 @@ func (s *TextEmbeddingFunctionSuite) TestAliEmbedding() {
},
}
data = append(data, &f)
_, err := runner.ProcessInsert(data)
_, err := runner.ProcessInsert(context.Background(), data)
s.Error(err)
}
// empty input
@ -257,7 +258,7 @@ func (s *TextEmbeddingFunctionSuite) TestAliEmbedding() {
},
}
data = append(data, &f)
_, err := runner.ProcessInsert(data)
_, err := runner.ProcessInsert(context.Background(), data)
s.Error(err)
}
// large input data
@ -278,7 +279,7 @@ func (s *TextEmbeddingFunctionSuite) TestAliEmbedding() {
},
}
data = append(data, &f)
_, err := runner.ProcessInsert(data)
_, err := runner.ProcessInsert(context.Background(), data)
s.Error(err)
}
}
@ -377,26 +378,6 @@ func (s *TextEmbeddingFunctionSuite) TestRunnerParamsErr() {
s.Error(err)
}
// error model name
{
_, err := NewTextEmbeddingFunction(s.schema, &schemapb.FunctionSchema{
Name: "test",
Type: schemapb.FunctionType_TextEmbedding,
InputFieldNames: []string{"text"},
OutputFieldNames: []string{"vector"},
InputFieldIds: []int64{101},
OutputFieldIds: []int64{102},
Params: []*commonpb.KeyValuePair{
{Key: Provider, Value: openAIProvider},
{Key: modelNameParamKey, Value: "text-embedding-ada-004"},
{Key: dimParamKey, Value: "4"},
{Key: apiKeyParamKey, Value: "mock"},
{Key: embeddingURLParamKey, Value: "mock"},
},
})
s.Error(err)
}
// no openai api key
{
_, err := NewTextEmbeddingFunction(s.schema, &schemapb.FunctionSchema{
@ -426,7 +407,7 @@ func (s *TextEmbeddingFunctionSuite) TestNewTextEmbeddings() {
OutputFieldIds: []int64{102},
Params: []*commonpb.KeyValuePair{
{Key: Provider, Value: bedrockProvider},
{Key: modelNameParamKey, Value: BedRockTitanTextEmbeddingsV2},
{Key: modelNameParamKey, Value: TestModel},
{Key: awsAKIdParamKey, Value: "mock"},
{Key: awsSAKParamKey, Value: "mock"},
{Key: regionParamKey, Value: "mock"},
@ -450,7 +431,7 @@ func (s *TextEmbeddingFunctionSuite) TestNewTextEmbeddings() {
OutputFieldIds: []int64{102},
Params: []*commonpb.KeyValuePair{
{Key: Provider, Value: aliDashScopeProvider},
{Key: modelNameParamKey, Value: TextEmbeddingV1},
{Key: modelNameParamKey, Value: TestModel},
{Key: apiKeyParamKey, Value: "mock"},
},
}
@ -472,7 +453,7 @@ func (s *TextEmbeddingFunctionSuite) TestNewTextEmbeddings() {
OutputFieldIds: []int64{102},
Params: []*commonpb.KeyValuePair{
{Key: Provider, Value: voyageAIProvider},
{Key: modelNameParamKey, Value: voyage3},
{Key: modelNameParamKey, Value: TestModel},
{Key: apiKeyParamKey, Value: "mock"},
},
}
@ -494,7 +475,7 @@ func (s *TextEmbeddingFunctionSuite) TestNewTextEmbeddings() {
OutputFieldIds: []int64{102},
Params: []*commonpb.KeyValuePair{
{Key: Provider, Value: siliconflowProvider},
{Key: modelNameParamKey, Value: bAAIBgeLargeZhV15},
{Key: modelNameParamKey, Value: TestModel},
{Key: apiKeyParamKey, Value: "mock"},
},
}
@ -516,7 +497,7 @@ func (s *TextEmbeddingFunctionSuite) TestNewTextEmbeddings() {
OutputFieldIds: []int64{102},
Params: []*commonpb.KeyValuePair{
{Key: Provider, Value: cohereProvider},
{Key: modelNameParamKey, Value: embedEnglishLightV20},
{Key: modelNameParamKey, Value: TestModel},
{Key: apiKeyParamKey, Value: "mock"},
},
}
@ -640,7 +621,7 @@ func (s *TextEmbeddingFunctionSuite) TestProcessSearchFloat32() {
s.NoError(err)
placeholderGroup := commonpb.PlaceholderGroup{}
proto.Unmarshal(placeholderGroupBytes, &placeholderGroup)
_, err = runner.ProcessSearch(&placeholderGroup)
_, err = runner.ProcessSearch(context.Background(), &placeholderGroup)
s.Error(err)
}
@ -665,7 +646,7 @@ func (s *TextEmbeddingFunctionSuite) TestProcessSearchFloat32() {
s.NoError(err)
placeholderGroup := commonpb.PlaceholderGroup{}
proto.Unmarshal(placeholderGroupBytes, &placeholderGroup)
_, err = runner.ProcessSearch(&placeholderGroup)
_, err = runner.ProcessSearch(context.Background(), &placeholderGroup)
s.NoError(err)
}
}
@ -696,7 +677,7 @@ func (s *TextEmbeddingFunctionSuite) TestProcessInsertInt8() {
OutputFieldIds: []int64{102},
Params: []*commonpb.KeyValuePair{
{Key: Provider, Value: cohereProvider},
{Key: modelNameParamKey, Value: embedEnglishV30},
{Key: modelNameParamKey, Value: TestModel},
{Key: dimParamKey, Value: "4"},
{Key: apiKeyParamKey, Value: "mock"},
{Key: embeddingURLParamKey, Value: ts.URL},
@ -706,7 +687,7 @@ func (s *TextEmbeddingFunctionSuite) TestProcessInsertInt8() {
{
data := createData([]string{"sentence"})
ret, err2 := runner.ProcessInsert(data)
ret, err2 := runner.ProcessInsert(context.Background(), data)
s.NoError(err2)
s.Equal(1, len(ret))
s.Equal(int64(4), ret[0].GetVectors().Dim)
@ -743,7 +724,7 @@ func (s *TextEmbeddingFunctionSuite) TestUnsupportedVec() {
OutputFieldIds: []int64{102},
Params: []*commonpb.KeyValuePair{
{Key: Provider, Value: cohereProvider},
{Key: modelNameParamKey, Value: embedEnglishV30},
{Key: modelNameParamKey, Value: TestModel},
{Key: dimParamKey, Value: "4"},
{Key: apiKeyParamKey, Value: "mock"},
{Key: embeddingURLParamKey, Value: "mock"},
@ -778,7 +759,7 @@ func (s *TextEmbeddingFunctionSuite) TestProcessSearchInt8() {
OutputFieldIds: []int64{102},
Params: []*commonpb.KeyValuePair{
{Key: Provider, Value: cohereProvider},
{Key: modelNameParamKey, Value: embedEnglishV30},
{Key: modelNameParamKey, Value: TestModel},
{Key: dimParamKey, Value: "4"},
{Key: apiKeyParamKey, Value: "mock"},
{Key: embeddingURLParamKey, Value: ts.URL},
@ -807,7 +788,7 @@ func (s *TextEmbeddingFunctionSuite) TestProcessSearchInt8() {
s.NoError(err)
placeholderGroup := commonpb.PlaceholderGroup{}
proto.Unmarshal(placeholderGroupBytes, &placeholderGroup)
_, err = runner.ProcessSearch(&placeholderGroup)
_, err = runner.ProcessSearch(context.Background(), &placeholderGroup)
s.NoError(err)
}
}
@ -880,7 +861,7 @@ func (s *TextEmbeddingFunctionSuite) TestProcessBulkInsertInt8() {
OutputFieldIds: []int64{102},
Params: []*commonpb.KeyValuePair{
{Key: Provider, Value: cohereProvider},
{Key: modelNameParamKey, Value: embedEnglishV30},
{Key: modelNameParamKey, Value: TestModel},
{Key: dimParamKey, Value: "4"},
{Key: apiKeyParamKey, Value: "mock"},
{Key: embeddingURLParamKey, Value: ts.URL},

View File

@ -42,7 +42,7 @@ func getVertexAIJsonKey() ([]byte, error) {
jsonKeyPath := os.Getenv(vertexServiceAccountJSONEnv)
jsonKey, err := os.ReadFile(jsonKeyPath)
if err != nil {
vtxKey.initErr = fmt.Errorf("Read service account json file failed, %v", err)
vtxKey.initErr = fmt.Errorf("Vertexai: read service account json file failed, %v", err)
return
}
vtxKey.jsonKey = jsonKey
@ -56,16 +56,6 @@ const (
vertexAISTS string = "STS"
)
func checkTask(modelName string, task string) error {
if task != vertexAIDocRetrival && task != vertexAICodeRetrival && task != vertexAISTS {
return fmt.Errorf("Unsupport task %s, the supported list: [%s, %s, %s]", task, vertexAIDocRetrival, vertexAICodeRetrival, vertexAISTS)
}
if modelName == textMultilingualEmbedding002 && task == vertexAICodeRetrival {
return fmt.Errorf("Model %s doesn't support %s task", textMultilingualEmbedding002, vertexAICodeRetrival)
}
return nil
}
type VertexAIEmbeddingProvider struct {
fieldDim int64
@ -117,18 +107,11 @@ func NewVertexAIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSch
if task == "" {
task = vertexAIDocRetrival
}
if err := checkTask(modelName, task); err != nil {
return nil, err
}
if location == "" {
location = "us-central1"
}
if modelName != textEmbedding005 && modelName != textMultilingualEmbedding002 {
return nil, fmt.Errorf("Unsupported model: %s, only support [%s, %s]",
modelName, textEmbedding005, textMultilingualEmbedding002)
}
url := fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:predict", location, projectID, location, modelName)
var client *vertexai.VertexAIEmbedding
if c == nil {

View File

@ -67,7 +67,7 @@ func createVertexAIProvider(url string, schema *schemapb.FieldSchema) (textEmbed
InputFieldIds: []int64{101},
OutputFieldIds: []int64{102},
Params: []*commonpb.KeyValuePair{
{Key: modelNameParamKey, Value: textEmbedding005},
{Key: modelNameParamKey, Value: TestModel},
{Key: locationParamKey, Value: "mock_local"},
{Key: projectIDParamKey, Value: "mock_id"},
{Key: taskTypeParamKey, Value: vertexAICodeRetrival},
@ -174,21 +174,6 @@ func (s *VertexAITextEmbeddingProviderSuite) TestEmbeddingNubmerNotMatch() {
s.Error(err2)
}
func (s *VertexAITextEmbeddingProviderSuite) TestCheckVertexAITask() {
err := checkTask(textMultilingualEmbedding002, "UnkownTask")
s.Error(err)
// textMultilingualEmbedding002 not support vertexAICodeRetrival task
err = checkTask(textMultilingualEmbedding002, vertexAICodeRetrival)
s.Error(err)
err = checkTask(textEmbedding005, vertexAICodeRetrival)
s.NoError(err)
err = checkTask(textMultilingualEmbedding002, vertexAISTS)
s.NoError(err)
}
func (s *VertexAITextEmbeddingProviderSuite) TestGetVertexAIJsonKey() {
os.Setenv(vertexServiceAccountJSONEnv, "ErrorPath")
defer os.Unsetenv(vertexServiceAccountJSONEnv)
@ -205,7 +190,7 @@ func (s *VertexAITextEmbeddingProviderSuite) TestGetTaskType() {
InputFieldIds: []int64{101},
OutputFieldIds: []int64{102},
Params: []*commonpb.KeyValuePair{
{Key: modelNameParamKey, Value: textEmbedding005},
{Key: modelNameParamKey, Value: TestModel},
{Key: projectIDParamKey, Value: "mock_id"},
{Key: dimParamKey, Value: "4"},
},
@ -234,13 +219,6 @@ func (s *VertexAITextEmbeddingProviderSuite) TestGetTaskType() {
s.Equal(provider.getTaskType(InsertMode), "SEMANTIC_SIMILARITY")
s.Equal(provider.getTaskType(SearchMode), "SEMANTIC_SIMILARITY")
}
// invalid task
{
functionSchema.Params[3] = &commonpb.KeyValuePair{Key: taskTypeParamKey, Value: "UnkownTask"}
_, err := NewVertexAIEmbeddingProvider(s.schema.Fields[2], functionSchema, mockClient)
s.Error(err)
}
}
func (s *VertexAITextEmbeddingProviderSuite) TestCreateVertexAIEmbeddingClient() {
@ -259,7 +237,7 @@ func (s *VertexAITextEmbeddingProviderSuite) TestNewVertexAIEmbeddingProvider()
InputFieldIds: []int64{101},
OutputFieldIds: []int64{102},
Params: []*commonpb.KeyValuePair{
{Key: modelNameParamKey, Value: textEmbedding005},
{Key: modelNameParamKey, Value: TestModel},
{Key: projectIDParamKey, Value: "mock_id"},
{Key: dimParamKey, Value: "4"},
},
@ -269,9 +247,4 @@ func (s *VertexAITextEmbeddingProviderSuite) TestNewVertexAIEmbeddingProvider()
s.NoError(err)
s.True(provider.MaxBatch() > 0)
s.Equal(provider.FieldDim(), int64(4))
// check model name
functionSchema.Params[0] = &commonpb.KeyValuePair{Key: modelNameParamKey, Value: "UnkownModel"}
_, err = NewVertexAIEmbeddingProvider(s.schema.Fields[2], functionSchema, mockClient)
s.Error(err)
}

View File

@ -21,6 +21,7 @@ package function
import (
"fmt"
"os"
"strconv"
"strings"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
@ -34,6 +35,7 @@ type VoyageAIEmbeddingProvider struct {
client *voyageai.VoyageAIEmbedding
modelName string
embedDimParam int64
truncate bool
embdType embeddingType
outputType string
@ -62,8 +64,10 @@ func NewVoyageAIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSch
if err != nil {
return nil, err
}
var apiKey, url, modelName string
apiKey, url := parseAKAndURL(functionSchema.Params)
var modelName string
dim := int64(0)
truncate := false
for _, param := range functionSchema.Params {
switch strings.ToLower(param.Key) {
@ -75,27 +79,14 @@ func NewVoyageAIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSch
if err != nil {
return nil, err
}
case apiKeyParamKey:
apiKey = param.Value
case embeddingURLParamKey:
url = param.Value
case truncationParamKey:
if truncate, err = strconv.ParseBool(param.Value); err != nil {
return nil, fmt.Errorf("[%s param's value: %s] is invalid, only supports: [true/false]", truncationParamKey, param.Value)
}
default:
}
}
if modelName != voyage3Large && modelName != voyage3 && modelName != voyage3Lite && modelName != voyageCode3 && modelName != voyageFinance2 && modelName != voyageLaw2 && modelName != voyageCode2 {
return nil, fmt.Errorf("Unsupported model: %s, only support [%s, %s, %s, %s, %s, %s, %s]",
modelName, voyage3Large, voyage3, voyage3Lite, voyageCode3, voyageFinance2, voyageLaw2, voyageCode2)
}
if dim != 0 {
if modelName != voyage3Large && modelName != voyageCode3 {
return nil, fmt.Errorf("VoyageAI text embedding model: [%s] doesn't supports dim parameter, only [%s, %s] support it.", modelName, voyage3, voyageCode3)
}
if dim != 1024 && dim != 256 && dim != 512 && dim != 2048 {
return nil, fmt.Errorf("VoyageAI text embedding model's dim only supports 2048, 1024 (default), 512, and 256.")
}
}
c, err := createVoyageAIEmbeddingClient(apiKey, url)
if err != nil {
return nil, err
@ -113,16 +104,11 @@ func NewVoyageAIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSch
return "int8"
}()
if outputType == "int8" {
if modelName != voyage3Large && modelName != voyageCode3 {
return nil, fmt.Errorf("VoyageAI text embedding model: [%s] doesn't supports int8 output_dtype, only [%s, %s] support it.", modelName, voyage3, voyageCode3)
}
}
provider := VoyageAIEmbeddingProvider{
client: c,
fieldDim: fieldDim,
modelName: modelName,
truncate: truncate,
embedDimParam: dim,
embdType: embdType,
outputType: outputType,
@ -155,7 +141,7 @@ func (provider *VoyageAIEmbeddingProvider) CallEmbedding(texts []string, mode Te
if end > numRows {
end = numRows
}
r, err := provider.client.Embedding(provider.modelName, texts[i:end], int(provider.embedDimParam), textType, provider.outputType, provider.timeoutSec)
r, err := provider.client.Embedding(provider.modelName, texts[i:end], int(provider.embedDimParam), textType, provider.outputType, provider.truncate, provider.timeoutSec)
if err != nil {
return nil, err
}

View File

@ -69,7 +69,7 @@ func createVoyageAIProvider(url string, schema *schemapb.FieldSchema, providerNa
InputFieldIds: []int64{101},
OutputFieldIds: []int64{102},
Params: []*commonpb.KeyValuePair{
{Key: modelNameParamKey, Value: voyage3Large},
{Key: modelNameParamKey, Value: TestModel},
{Key: apiKeyParamKey, Value: "mock"},
{Key: embeddingURLParamKey, Value: url},
{Key: dimParamKey, Value: "1024"},
@ -136,26 +136,6 @@ func (s *VoyageAITextEmbeddingProviderSuite) TestEmbeddingIn8() {
s.NoError(err)
}
}
// Invalid model name
{
functionSchema := &schemapb.FunctionSchema{
Name: "test",
Type: schemapb.FunctionType_TextEmbedding,
InputFieldNames: []string{"text"},
OutputFieldNames: []string{"vector"},
InputFieldIds: []int64{101},
OutputFieldIds: []int64{102},
Params: []*commonpb.KeyValuePair{
{Key: modelNameParamKey, Value: voyage3Lite},
{Key: apiKeyParamKey, Value: "mock"},
{Key: embeddingURLParamKey, Value: ts.URL},
{Key: dimParamKey, Value: "1024"},
},
}
_, err := NewCohereEmbeddingProvider(int8VecField, functionSchema)
s.Error(err)
}
}
func (s *VoyageAITextEmbeddingProviderSuite) TestEmbeddingDimNotMatch() {
@ -318,10 +298,11 @@ func (s *VoyageAITextEmbeddingProviderSuite) TestNewVoyageAIEmbeddingProvider()
InputFieldIds: []int64{101},
OutputFieldIds: []int64{102},
Params: []*commonpb.KeyValuePair{
{Key: modelNameParamKey, Value: voyage3Large},
{Key: modelNameParamKey, Value: TestModel},
{Key: apiKeyParamKey, Value: "mock"},
{Key: embeddingURLParamKey, Value: "mock"},
{Key: dimParamKey, Value: "1024"},
{Key: truncationParamKey, Value: "true"},
},
}
provider, err := NewVoyageAIEmbeddingProvider(s.schema.Fields[2], functionSchema)
@ -329,11 +310,12 @@ func (s *VoyageAITextEmbeddingProviderSuite) TestNewVoyageAIEmbeddingProvider()
s.Equal(provider.FieldDim(), int64(1024))
s.True(provider.MaxBatch() > 0)
// Invalid model
// Invalid truncation
{
functionSchema.Params[0] = &commonpb.KeyValuePair{Key: modelNameParamKey, Value: "UnkownModel"}
functionSchema.Params[4] = &commonpb.KeyValuePair{Key: truncationParamKey, Value: "Invalid"}
_, err := NewVoyageAIEmbeddingProvider(s.schema.Fields[2], functionSchema)
s.Error(err)
functionSchema.Params[4] = &commonpb.KeyValuePair{Key: truncationParamKey, Value: "false"}
}
// Invalid dim

View File

@ -127,6 +127,11 @@ const (
cgoTypeLabelName = `cgo_type`
queueTypeLabelName = `queue_type`
// model function/UDF labels
functionTypeName = "function_type_name"
functionProvider = "function_provider"
functionName = "function_name"
// entities label
LoadedLabel = "loaded"
NumEntitiesAllLabel = "all"

View File

@ -445,6 +445,15 @@ var (
Help: "the latency of parse expression",
Buckets: buckets,
}, []string{nodeIDLabelName, functionLabelName, statusLabelName})
// ProxyFunctionlatency records the latency of function
ProxyFunctionlatency = prometheus.NewHistogramVec(
prometheus.HistogramOpts{
Namespace: milvusNamespace,
Subsystem: typeutil.ProxyRole,
Name: "function_udf_call_latency",
Help: "latency of function call",
Buckets: buckets,
}, []string{nodeIDLabelName, collectionName, functionTypeName, functionProvider, functionName})
)
// RegisterProxy registers Proxy metrics
@ -512,6 +521,8 @@ func RegisterProxy(registry *prometheus.Registry) {
registry.MustRegister(ProxyParseExpressionLatency)
registry.MustRegister(ProxyFunctionlatency)
RegisterStreamingServiceClient(registry)
}