mirror of
https://gitee.com/milvus-io/milvus.git
synced 2026-01-07 19:31:51 +08:00
feat: Add function running monitoring (#40358)
#35856 #40004 1. Optimize model verification logic 2. Add profiling code Signed-off-by: junjie.jiang <junjie.jiang@zilliz.com>
This commit is contained in:
parent
0a7e692b6f
commit
359e7efd8e
@ -163,13 +163,17 @@ func (it *insertTask) PreExecute(ctx context.Context) error {
|
||||
|
||||
// Calculate embedding fields
|
||||
if function.HasNonBM25Functions(schema.CollectionSchema.Functions, []int64{}) {
|
||||
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-Insert-call-function-udf")
|
||||
defer sp.End()
|
||||
exec, err := function.NewFunctionExecutor(schema.CollectionSchema)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := exec.ProcessInsert(it.insertMsg); err != nil {
|
||||
sp.AddEvent("Create-function-udf")
|
||||
if err := exec.ProcessInsert(ctx, it.insertMsg); err != nil {
|
||||
return err
|
||||
}
|
||||
sp.AddEvent("Call-function-udf")
|
||||
}
|
||||
rowNums := uint32(it.insertMsg.NRows())
|
||||
// set insertTask.rowIDs
|
||||
|
||||
@ -427,13 +427,17 @@ func (t *searchTask) initAdvancedSearchRequest(ctx context.Context) error {
|
||||
|
||||
var err error
|
||||
if function.HasNonBM25Functions(t.schema.CollectionSchema.Functions, queryFieldIds) {
|
||||
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-AdvancedSearch-call-function-udf")
|
||||
defer sp.End()
|
||||
exec, err := function.NewFunctionExecutor(t.schema.CollectionSchema)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := exec.ProcessSearch(t.SearchRequest); err != nil {
|
||||
sp.AddEvent("Create-function-udf")
|
||||
if err := exec.ProcessSearch(ctx, t.SearchRequest); err != nil {
|
||||
return err
|
||||
}
|
||||
sp.AddEvent("Call-function-udf")
|
||||
}
|
||||
|
||||
t.SearchRequest.GroupByFieldId = t.rankParams.GetGroupByFieldId()
|
||||
@ -516,13 +520,17 @@ func (t *searchTask) initSearchRequest(ctx context.Context) error {
|
||||
t.SearchRequest.GroupSize = queryInfo.GroupSize
|
||||
|
||||
if function.HasNonBM25Functions(t.schema.CollectionSchema.Functions, []int64{queryInfo.GetQueryFieldId()}) {
|
||||
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-Search-call-function-udf")
|
||||
defer sp.End()
|
||||
exec, err := function.NewFunctionExecutor(t.schema.CollectionSchema)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := exec.ProcessSearch(t.SearchRequest); err != nil {
|
||||
sp.AddEvent("Create-function-udf")
|
||||
if err := exec.ProcessSearch(ctx, t.SearchRequest); err != nil {
|
||||
return err
|
||||
}
|
||||
sp.AddEvent("Call-function-udf")
|
||||
}
|
||||
log.Debug("proxy init search request",
|
||||
zap.Int64s("plan.OutputFieldIds", plan.GetOutputFieldIds()),
|
||||
|
||||
@ -20,7 +20,6 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"strconv"
|
||||
"testing"
|
||||
@ -40,6 +39,7 @@ import (
|
||||
"github.com/milvus-io/milvus/internal/json"
|
||||
"github.com/milvus-io/milvus/internal/mocks"
|
||||
"github.com/milvus-io/milvus/internal/querycoordv2/meta"
|
||||
"github.com/milvus-io/milvus/internal/util/function"
|
||||
"github.com/milvus-io/milvus/pkg/v2/common"
|
||||
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
|
||||
"github.com/milvus-io/milvus/pkg/v2/proto/indexpb"
|
||||
@ -1025,7 +1025,8 @@ func TestCreateCollectionTask(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("collection with embedding function ", func(t *testing.T) {
|
||||
fmt.Println(schema)
|
||||
ts := function.CreateOpenAIEmbeddingServer()
|
||||
defer ts.Close()
|
||||
schema.Functions = []*schemapb.FunctionSchema{
|
||||
{
|
||||
Name: "test",
|
||||
@ -1036,6 +1037,8 @@ func TestCreateCollectionTask(t *testing.T) {
|
||||
{Key: "provider", Value: "openai"},
|
||||
{Key: "model_name", Value: "text-embedding-ada-002"},
|
||||
{Key: "api_key", Value: "mock"},
|
||||
{Key: "url", Value: ts.URL},
|
||||
{Key: "dim", Value: "128"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@ -148,6 +148,8 @@ func (it *upsertTask) OnEnqueue() error {
|
||||
}
|
||||
|
||||
func (it *upsertTask) insertPreExecute(ctx context.Context) error {
|
||||
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-Upsert-insertPreExecute")
|
||||
defer sp.End()
|
||||
collectionName := it.upsertMsg.InsertMsg.CollectionName
|
||||
if err := validateCollectionName(collectionName); err != nil {
|
||||
log.Ctx(ctx).Error("valid collection name failed", zap.String("collectionName", collectionName), zap.Error(err))
|
||||
@ -156,13 +158,17 @@ func (it *upsertTask) insertPreExecute(ctx context.Context) error {
|
||||
|
||||
// Calculate embedding fields
|
||||
if function.HasNonBM25Functions(it.schema.CollectionSchema.Functions, []int64{}) {
|
||||
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-Proxy-Upsert-insertPreExecute-call-function-udf")
|
||||
defer sp.End()
|
||||
exec, err := function.NewFunctionExecutor(it.schema.CollectionSchema)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := exec.ProcessInsert(it.upsertMsg.InsertMsg); err != nil {
|
||||
sp.AddEvent("Create-function-udf")
|
||||
if err := exec.ProcessInsert(ctx, it.upsertMsg.InsertMsg); err != nil {
|
||||
return err
|
||||
}
|
||||
sp.AddEvent("Call-function-udf")
|
||||
}
|
||||
rowNums := uint32(it.upsertMsg.InsertMsg.NRows())
|
||||
// set upsertTask.insertRequest.rowIDs
|
||||
|
||||
@ -720,8 +720,8 @@ func validateFunction(coll *schemapb.CollectionSchema) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func checkFunctionOutputField(function *schemapb.FunctionSchema, fields []*schemapb.FieldSchema) error {
|
||||
switch function.GetType() {
|
||||
func checkFunctionOutputField(fSchema *schemapb.FunctionSchema, fields []*schemapb.FieldSchema) error {
|
||||
switch fSchema.GetType() {
|
||||
case schemapb.FunctionType_BM25:
|
||||
if len(fields) != 1 {
|
||||
return fmt.Errorf("BM25 function only need 1 output field, but got %d", len(fields))
|
||||
@ -731,8 +731,8 @@ func checkFunctionOutputField(function *schemapb.FunctionSchema, fields []*schem
|
||||
return fmt.Errorf("BM25 function output field must be a SparseFloatVector field, but got %s", fields[0].DataType.String())
|
||||
}
|
||||
case schemapb.FunctionType_TextEmbedding:
|
||||
if len(fields) != 1 || (fields[0].DataType != schemapb.DataType_FloatVector && fields[0].DataType != schemapb.DataType_Int8Vector) {
|
||||
return fmt.Errorf("TextEmbedding function output field must be a FloatVector or Int8Vector field")
|
||||
if err := function.TextEmbeddingOutputsCheck(fields); err != nil {
|
||||
return err
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("check output field for unknown function type")
|
||||
|
||||
@ -37,6 +37,7 @@ import (
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/json"
|
||||
"github.com/milvus-io/milvus/internal/mocks"
|
||||
"github.com/milvus-io/milvus/internal/util/function"
|
||||
"github.com/milvus-io/milvus/pkg/v2/common"
|
||||
"github.com/milvus-io/milvus/pkg/v2/log"
|
||||
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
|
||||
@ -2850,6 +2851,8 @@ func TestValidateFunction(t *testing.T) {
|
||||
|
||||
func TestValidateModelFunction(t *testing.T) {
|
||||
t.Run("Valid model function schema", func(t *testing.T) {
|
||||
ts := function.CreateOpenAIEmbeddingServer()
|
||||
defer ts.Close()
|
||||
schema := &schemapb.CollectionSchema{
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{Name: "input_field", DataType: schemapb.DataType_VarChar, TypeParams: []*commonpb.KeyValuePair{{Key: "enable_analyzer", Value: "true"}}},
|
||||
@ -2877,7 +2880,7 @@ func TestValidateModelFunction(t *testing.T) {
|
||||
{Key: "provider", Value: "openai"},
|
||||
{Key: "model_name", Value: "text-embedding-ada-002"},
|
||||
{Key: "api_key", Value: "mock"},
|
||||
{Key: "url", Value: "mock_url"},
|
||||
{Key: "url", Value: ts.URL},
|
||||
{Key: "dim", Value: "4"},
|
||||
},
|
||||
},
|
||||
|
||||
@ -60,7 +60,8 @@ func NewAliDashScopeEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functio
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var apiKey, url, modelName string
|
||||
apiKey, url := parseAKAndURL(functionSchema.Params)
|
||||
var modelName string
|
||||
var dim int64
|
||||
|
||||
for _, param := range functionSchema.Params {
|
||||
@ -72,26 +73,17 @@ func NewAliDashScopeEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functio
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
case apiKeyParamKey:
|
||||
apiKey = param.Value
|
||||
case embeddingURLParamKey:
|
||||
url = param.Value
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
if modelName != TextEmbeddingV1 && modelName != TextEmbeddingV2 && modelName != TextEmbeddingV3 {
|
||||
return nil, fmt.Errorf("Unsupported model: %s, only support [%s, %s, %s]",
|
||||
modelName, TextEmbeddingV1, TextEmbeddingV2, TextEmbeddingV3)
|
||||
}
|
||||
|
||||
c, err := createAliEmbeddingClient(apiKey, url)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
maxBatch := 25
|
||||
if modelName == TextEmbeddingV3 {
|
||||
if modelName == "text-embedding-v3" {
|
||||
maxBatch = 6
|
||||
}
|
||||
|
||||
|
||||
@ -69,9 +69,9 @@ func createAliProvider(url string, schema *schemapb.FieldSchema, providerName st
|
||||
InputFieldIds: []int64{101},
|
||||
OutputFieldIds: []int64{102},
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: modelNameParamKey, Value: TextEmbeddingV3},
|
||||
{Key: apiKeyParamKey, Value: "mock"},
|
||||
{Key: modelNameParamKey, Value: TestModel},
|
||||
{Key: embeddingURLParamKey, Value: url},
|
||||
{Key: apiKeyParamKey, Value: "mock"},
|
||||
{Key: dimParamKey, Value: "4"},
|
||||
},
|
||||
}
|
||||
@ -85,8 +85,8 @@ func createAliProvider(url string, schema *schemapb.FieldSchema, providerName st
|
||||
|
||||
func (s *AliTextEmbeddingProviderSuite) TestEmbedding() {
|
||||
ts := CreateAliEmbeddingServer()
|
||||
|
||||
defer ts.Close()
|
||||
|
||||
for _, provderName := range s.providers {
|
||||
provder, err := createAliProvider(ts.URL, s.schema.Fields[2], provderName)
|
||||
s.NoError(err)
|
||||
@ -186,18 +186,13 @@ func (s *AliTextEmbeddingProviderSuite) TestNewAliDashScopeEmbeddingProvider() {
|
||||
InputFieldIds: []int64{101},
|
||||
OutputFieldIds: []int64{102},
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: modelNameParamKey, Value: "UnkownModels"},
|
||||
{Key: modelNameParamKey, Value: TestModel},
|
||||
{Key: apiKeyParamKey, Value: "mock"},
|
||||
{Key: embeddingURLParamKey, Value: "mock"},
|
||||
{Key: dimParamKey, Value: "4"},
|
||||
},
|
||||
}
|
||||
// invalid dim
|
||||
functionSchema.Params[2] = &commonpb.KeyValuePair{Key: dimParamKey, Value: "Invalid"}
|
||||
_, err := NewAliDashScopeEmbeddingProvider(s.schema.Fields[2], functionSchema)
|
||||
s.Error(err)
|
||||
|
||||
// invalid dim
|
||||
functionSchema.Params[0] = &commonpb.KeyValuePair{Key: modelNameParamKey, Value: TextEmbeddingV3}
|
||||
functionSchema.Params[3] = &commonpb.KeyValuePair{Key: dimParamKey, Value: "Invalid"}
|
||||
_, err = NewAliDashScopeEmbeddingProvider(s.schema.Fields[2], functionSchema)
|
||||
s.Error(err)
|
||||
}
|
||||
|
||||
@ -98,9 +98,13 @@ func NewBedrockEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSche
|
||||
return nil, err
|
||||
}
|
||||
case awsAKIdParamKey:
|
||||
awsAccessKeyId = param.Value
|
||||
if strings.ToLower(os.Getenv(enableConfigAKAndURL)) != "false" {
|
||||
awsAccessKeyId = param.Value
|
||||
}
|
||||
case awsSAKParamKey:
|
||||
awsSecretAccessKey = param.Value
|
||||
if strings.ToLower(os.Getenv(enableConfigAKAndURL)) != "false" {
|
||||
awsSecretAccessKey = param.Value
|
||||
}
|
||||
case regionParamKey:
|
||||
region = param.Value
|
||||
case normalizeParamKey:
|
||||
@ -116,10 +120,6 @@ func NewBedrockEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSche
|
||||
}
|
||||
}
|
||||
|
||||
if modelName != BedRockTitanTextEmbeddingsV2 {
|
||||
return nil, fmt.Errorf("Unsupported model: %s, only support [%s]",
|
||||
modelName, BedRockTitanTextEmbeddingsV2)
|
||||
}
|
||||
var client BedrockClient
|
||||
if c == nil {
|
||||
client, err = createBedRockEmbeddingClient(awsAccessKeyId, awsSecretAccessKey, region)
|
||||
|
||||
@ -64,7 +64,7 @@ func createBedrockProvider(schema *schemapb.FieldSchema, providerName string, di
|
||||
InputFieldIds: []int64{101},
|
||||
OutputFieldIds: []int64{102},
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: modelNameParamKey, Value: BedRockTitanTextEmbeddingsV2},
|
||||
{Key: modelNameParamKey, Value: TestModel},
|
||||
{Key: apiKeyParamKey, Value: "mock"},
|
||||
{Key: dimParamKey, Value: "4"},
|
||||
},
|
||||
@ -143,7 +143,7 @@ func (s *BedrockTextEmbeddingProviderSuite) TestNewBedrockEmbeddingProvider() {
|
||||
InputFieldIds: []int64{101},
|
||||
OutputFieldIds: []int64{102},
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: modelNameParamKey, Value: BedRockTitanTextEmbeddingsV2},
|
||||
{Key: modelNameParamKey, Value: TestModel},
|
||||
{Key: awsAKIdParamKey, Value: "mock"},
|
||||
{Key: awsSAKParamKey, Value: "mock"},
|
||||
{Key: regionParamKey, Value: "mock"},
|
||||
@ -164,14 +164,8 @@ func (s *BedrockTextEmbeddingProviderSuite) TestNewBedrockEmbeddingProvider() {
|
||||
_, err = NewBedrockEmbeddingProvider(fieldSchema, functionSchema, nil)
|
||||
s.Error(err)
|
||||
|
||||
// invalid model name
|
||||
functionSchema.Params[5] = &commonpb.KeyValuePair{Key: normalizeParamKey, Value: "true"}
|
||||
functionSchema.Params[0] = &commonpb.KeyValuePair{Key: modelNameParamKey, Value: "UnkownModel"}
|
||||
_, err = NewBedrockEmbeddingProvider(fieldSchema, functionSchema, nil)
|
||||
s.Error(err)
|
||||
|
||||
// invalid dim
|
||||
functionSchema.Params[0] = &commonpb.KeyValuePair{Key: modelNameParamKey, Value: BedRockTitanTextEmbeddingsV2}
|
||||
functionSchema.Params[0] = &commonpb.KeyValuePair{Key: modelNameParamKey, Value: TestModel}
|
||||
functionSchema.Params[0] = &commonpb.KeyValuePair{Key: dimParamKey, Value: "Invalid"}
|
||||
_, err = NewBedrockEmbeddingProvider(fieldSchema, functionSchema, nil)
|
||||
s.Error(err)
|
||||
|
||||
@ -62,16 +62,13 @@ func NewCohereEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchem
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var apiKey, url, modelName string
|
||||
apiKey, url := parseAKAndURL(functionSchema.Params)
|
||||
var modelName string
|
||||
truncate := "END"
|
||||
for _, param := range functionSchema.Params {
|
||||
switch strings.ToLower(param.Key) {
|
||||
case modelNameParamKey:
|
||||
modelName = param.Value
|
||||
case apiKeyParamKey:
|
||||
apiKey = param.Value
|
||||
case embeddingURLParamKey:
|
||||
url = param.Value
|
||||
case truncateParamKey:
|
||||
if param.Value != "NONE" && param.Value != "START" && param.Value != "END" {
|
||||
return nil, fmt.Errorf("Illegal parameters, %s only supports [NONE, START, END]", truncateParamKey)
|
||||
@ -81,11 +78,6 @@ func NewCohereEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchem
|
||||
}
|
||||
}
|
||||
|
||||
if modelName != embedEnglishV30 && modelName != embedMultilingualV30 && modelName != embedEnglishLightV30 && modelName != embedMultilingualLightV30 && modelName != embedEnglishV20 && modelName != embedEnglishLightV20 && modelName != embedMultilingualV20 {
|
||||
return nil, fmt.Errorf("Unsupported model: %s, only support [%s, %s, %s, %s, %s, %s, %s]",
|
||||
modelName, embedEnglishV30, embedMultilingualV30, embedEnglishLightV30, embedMultilingualLightV30, embedEnglishV20, embedEnglishLightV20, embedMultilingualV20)
|
||||
}
|
||||
|
||||
c, err := createCohereEmbeddingClient(apiKey, url)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -103,12 +95,6 @@ func NewCohereEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchem
|
||||
return "int8"
|
||||
}()
|
||||
|
||||
if outputType == "int8" {
|
||||
if modelName != embedEnglishV30 && modelName != embedMultilingualV30 && modelName != embedEnglishLightV30 && modelName != embedMultilingualLightV30 {
|
||||
return nil, fmt.Errorf("Cohere text embedding model: [%s] doesn't supports int8. Valid for only v3 models.", modelName)
|
||||
}
|
||||
}
|
||||
|
||||
provider := CohereEmbeddingProvider{
|
||||
client: c,
|
||||
fieldDim: fieldDim,
|
||||
@ -132,7 +118,8 @@ func (provider *CohereEmbeddingProvider) FieldDim() int64 {
|
||||
|
||||
// Specifies the type of input passed to the model. Required for embedding models v3 and higher.
|
||||
func (provider *CohereEmbeddingProvider) getInputType(mode TextEmbeddingMode) string {
|
||||
if provider.modelName == embedEnglishV20 || provider.modelName == embedEnglishLightV20 || provider.modelName == embedMultilingualV20 {
|
||||
// v2 models not support instructor
|
||||
if strings.HasSuffix(provider.modelName, "v2.0") {
|
||||
return ""
|
||||
}
|
||||
if mode == InsertMode {
|
||||
|
||||
@ -69,7 +69,7 @@ func createCohereProvider(url string, schema *schemapb.FieldSchema, providerName
|
||||
InputFieldIds: []int64{101},
|
||||
OutputFieldIds: []int64{102},
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: modelNameParamKey, Value: embedEnglishLightV30},
|
||||
{Key: modelNameParamKey, Value: TestModel},
|
||||
{Key: apiKeyParamKey, Value: "mock"},
|
||||
{Key: embeddingURLParamKey, Value: url},
|
||||
},
|
||||
@ -143,25 +143,6 @@ func (s *CohereTextEmbeddingProviderSuite) TestEmbeddingInt8() {
|
||||
s.Equal([][]int8{{0, 1, 2, 3}, {1, 2, 3, 4}, {2, 3, 4, 5}}, ret)
|
||||
}
|
||||
}
|
||||
|
||||
// Invalid model name
|
||||
{
|
||||
functionSchema := &schemapb.FunctionSchema{
|
||||
Name: "test",
|
||||
Type: schemapb.FunctionType_TextEmbedding,
|
||||
InputFieldNames: []string{"text"},
|
||||
OutputFieldNames: []string{"vector"},
|
||||
InputFieldIds: []int64{101},
|
||||
OutputFieldIds: []int64{102},
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: modelNameParamKey, Value: embedEnglishLightV20},
|
||||
{Key: apiKeyParamKey, Value: "mock"},
|
||||
{Key: embeddingURLParamKey, Value: ts.URL},
|
||||
},
|
||||
}
|
||||
_, err := NewCohereEmbeddingProvider(int8VecField, functionSchema)
|
||||
s.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *CohereTextEmbeddingProviderSuite) TestEmbeddingDimNotMatch() {
|
||||
@ -278,7 +259,7 @@ func (s *CohereTextEmbeddingProviderSuite) TestNewCohereProvider() {
|
||||
InputFieldIds: []int64{101},
|
||||
OutputFieldIds: []int64{102},
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: modelNameParamKey, Value: embedEnglishLightV20},
|
||||
{Key: modelNameParamKey, Value: TestModel},
|
||||
{Key: apiKeyParamKey, Value: "mock"},
|
||||
},
|
||||
}
|
||||
@ -296,12 +277,6 @@ func (s *CohereTextEmbeddingProviderSuite) TestNewCohereProvider() {
|
||||
functionSchema.Params[2].Value = "Unknow"
|
||||
_, err = NewCohereEmbeddingProvider(s.schema.Fields[2], functionSchema)
|
||||
s.Error(err)
|
||||
|
||||
// Invalid ModelName
|
||||
functionSchema.Params[2].Value = "END"
|
||||
functionSchema.Params[0].Value = "Unknow"
|
||||
_, err = NewCohereEmbeddingProvider(s.schema.Fields[2], functionSchema)
|
||||
s.Error(err)
|
||||
}
|
||||
|
||||
func (s *CohereTextEmbeddingProviderSuite) TestGetInputType() {
|
||||
@ -313,7 +288,7 @@ func (s *CohereTextEmbeddingProviderSuite) TestGetInputType() {
|
||||
InputFieldIds: []int64{101},
|
||||
OutputFieldIds: []int64{102},
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: modelNameParamKey, Value: embedEnglishLightV20},
|
||||
{Key: modelNameParamKey, Value: "model-v2.0"},
|
||||
{Key: apiKeyParamKey, Value: "mock"},
|
||||
},
|
||||
}
|
||||
@ -323,7 +298,7 @@ func (s *CohereTextEmbeddingProviderSuite) TestGetInputType() {
|
||||
s.Equal(provider.getInputType(InsertMode), "")
|
||||
s.Equal(provider.getInputType(SearchMode), "")
|
||||
|
||||
functionSchema.Params[0].Value = embedEnglishLightV30
|
||||
functionSchema.Params[0].Value = "model-v3.0"
|
||||
provider, err = NewCohereEmbeddingProvider(s.schema.Fields[2], functionSchema)
|
||||
s.NoError(err)
|
||||
s.Equal(provider.getInputType(InsertMode), "search_document")
|
||||
|
||||
@ -20,8 +20,11 @@ package function
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
)
|
||||
|
||||
@ -51,20 +54,12 @@ const (
|
||||
|
||||
// ali text embedding
|
||||
const (
|
||||
TextEmbeddingV1 string = "text-embedding-v1"
|
||||
TextEmbeddingV2 string = "text-embedding-v2"
|
||||
TextEmbeddingV3 string = "text-embedding-v3"
|
||||
|
||||
dashscopeAKEnvStr string = "MILVUSAI_DASHSCOPE_API_KEY"
|
||||
)
|
||||
|
||||
// openai/azure text embedding
|
||||
|
||||
const (
|
||||
TextEmbeddingAda002 string = "text-embedding-ada-002"
|
||||
TextEmbedding3Small string = "text-embedding-3-small"
|
||||
TextEmbedding3Large string = "text-embedding-3-large"
|
||||
|
||||
openaiAKEnvStr string = "MILVUSAI_OPENAI_API_KEY"
|
||||
|
||||
azureOpenaiAKEnvStr string = "MILVUSAI_AZURE_OPENAI_API_KEY"
|
||||
@ -76,11 +71,10 @@ const (
|
||||
// bedrock embedding
|
||||
|
||||
const (
|
||||
BedRockTitanTextEmbeddingsV2 string = "amazon.titan-embed-text-v2:0"
|
||||
awsAKIdParamKey string = "aws_access_key_id"
|
||||
awsSAKParamKey string = "aws_secret_access_key"
|
||||
regionParamKey string = "regin"
|
||||
normalizeParamKey string = "normalize"
|
||||
awsAKIdParamKey string = "aws_access_key_id"
|
||||
awsSAKParamKey string = "aws_secret_access_key"
|
||||
regionParamKey string = "regin"
|
||||
normalizeParamKey string = "normalize"
|
||||
|
||||
bedrockAccessKeyId string = "MILVUSAI_BEDROCK_ACCESS_KEY_ID"
|
||||
bedrockSAKEnvStr string = "MILVUSAI_BEDROCK_SECRET_ACCESS_KEY"
|
||||
@ -93,48 +87,24 @@ const (
|
||||
projectIDParamKey string = "projectid"
|
||||
taskTypeParamKey string = "task"
|
||||
|
||||
textEmbedding005 string = "text-embedding-005"
|
||||
textMultilingualEmbedding002 string = "text-multilingual-embedding-002"
|
||||
|
||||
vertexServiceAccountJSONEnv string = "MILVUSAI_GOOGLE_APPLICATION_CREDENTIALS"
|
||||
)
|
||||
|
||||
// voyageAI
|
||||
const (
|
||||
voyage3Large string = "voyage-3-large"
|
||||
voyage3 string = "voyage-3"
|
||||
voyage3Lite string = "voyage-3-lite"
|
||||
voyageCode3 string = "voyage-code-3"
|
||||
voyageFinance2 string = "voyage-finance-2"
|
||||
voyageLaw2 string = "voyage-law-2"
|
||||
voyageCode2 string = "voyage-code-2"
|
||||
|
||||
voyageAIAKEnvStr string = "MILVUSAI_VOYAGEAI_API_KEY"
|
||||
truncationParamKey string = "truncation"
|
||||
voyageAIAKEnvStr string = "MILVUSAI_VOYAGEAI_API_KEY"
|
||||
)
|
||||
|
||||
// cohere
|
||||
|
||||
const (
|
||||
embedEnglishV30 string = "embed-english-v3.0"
|
||||
embedMultilingualV30 string = "embed-multilingual-v3.0"
|
||||
embedEnglishLightV30 string = "embed-english-light-v3.0"
|
||||
embedMultilingualLightV30 string = "embed-multilingual-light-v3.0"
|
||||
embedEnglishV20 string = "embed-english-v2.0"
|
||||
embedEnglishLightV20 string = "embed-english-light-v2.0"
|
||||
embedMultilingualV20 string = "embed-multilingual-v2.0"
|
||||
|
||||
cohereAIAKEnvStr string = "MILVUSAI_COHERE_API_KEY"
|
||||
)
|
||||
|
||||
// siliconflow
|
||||
|
||||
const (
|
||||
bAAIBgeLargeZhV15 string = "BAAI/bge-large-zh-v1.5"
|
||||
bAAIBgeLargeEhV15 string = "BAAI/bge-large-eh-v1.5"
|
||||
neteaseYoudaoBceEmbeddingBasev1 string = "netease-youdao/bce-embedding-base_v1"
|
||||
bAAIBgeM3 string = "BAAI/bge-m3"
|
||||
proBAAIBgeM3 string = "Pro/BAAI/bge-m3 "
|
||||
|
||||
siliconflowAKEnvStr string = "MILVUSAI_SILICONFLOW_API_KEY"
|
||||
)
|
||||
|
||||
@ -150,6 +120,23 @@ const (
|
||||
enableTeiEnvStr string = "MILVUSAI_ENABLE_TEI"
|
||||
)
|
||||
|
||||
const enableConfigAKAndURL string = "ENABLE_CONFIG_AK_AND_URL"
|
||||
|
||||
func parseAKAndURL(params []*commonpb.KeyValuePair) (string, string) {
|
||||
var apiKey, url string
|
||||
if strings.ToLower(os.Getenv(enableConfigAKAndURL)) != "false" {
|
||||
for _, param := range params {
|
||||
switch strings.ToLower(param.Key) {
|
||||
case apiKeyParamKey:
|
||||
apiKey = param.Value
|
||||
case embeddingURLParamKey:
|
||||
url = param.Value
|
||||
}
|
||||
}
|
||||
}
|
||||
return apiKey, url
|
||||
}
|
||||
|
||||
func parseAndCheckFieldDim(dimStr string, fieldDim int64, fieldName string) (int64, error) {
|
||||
dim, err := strconv.ParseInt(dimStr, 10, 64)
|
||||
if err != nil {
|
||||
|
||||
@ -20,6 +20,7 @@ package function
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
)
|
||||
@ -27,6 +28,22 @@ import (
|
||||
// FunctionBase carries the data shared by function runners: the function
// schema, its output fields, and the identifying strings filled in by
// NewFunctionBase.
type FunctionBase struct {
|
||||
// NOTE(review): presumably the schema of the function this base belongs to — confirm in NewFunctionBase.
schema *schemapb.FunctionSchema
|
||||
// NOTE(review): presumably the collection fields the function writes to — confirm in NewFunctionBase.
outputFields []*schemapb.FieldSchema
|
||||
|
||||
// Name of the owning collection (coll.Name).
collectionName string
|
||||
// String form of the function type (fSchema.GetType().String()).
functionTypeName string
|
||||
// Name of the function (fSchema.Name).
functionName string
|
||||
// Lower-cased provider parameter value, as returned by getProvider.
provider string
|
||||
}
|
||||
|
||||
func getProvider(functionSchema *schemapb.FunctionSchema) (string, error) {
|
||||
for _, param := range functionSchema.Params {
|
||||
switch strings.ToLower(param.Key) {
|
||||
case Provider:
|
||||
return strings.ToLower(param.Value), nil
|
||||
default:
|
||||
}
|
||||
}
|
||||
return "", fmt.Errorf("The text embedding service provider parameter:[%s] was not found", Provider)
|
||||
}
|
||||
|
||||
func NewFunctionBase(coll *schemapb.CollectionSchema, fSchema *schemapb.FunctionSchema) (*FunctionBase, error) {
|
||||
@ -45,6 +62,15 @@ func NewFunctionBase(coll *schemapb.CollectionSchema, fSchema *schemapb.Function
|
||||
return &base, fmt.Errorf("The collection [%s]'s information is wrong, function [%s]'s outputs does not match the schema",
|
||||
coll.Name, fSchema.Name)
|
||||
}
|
||||
|
||||
provider, err := getProvider(fSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
base.collectionName = coll.Name
|
||||
base.functionName = fSchema.Name
|
||||
base.provider = provider
|
||||
base.functionTypeName = fSchema.GetType().String()
|
||||
return &base, nil
|
||||
}
|
||||
|
||||
|
||||
@ -19,7 +19,9 @@
|
||||
package function
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"sync"
|
||||
|
||||
"google.golang.org/protobuf/proto"
|
||||
@ -27,18 +29,26 @@ import (
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/storage"
|
||||
"github.com/milvus-io/milvus/pkg/v2/metrics"
|
||||
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
|
||||
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/timerecord"
|
||||
)
|
||||
|
||||
type Runner interface {
|
||||
GetSchema() *schemapb.FunctionSchema
|
||||
GetOutputFields() []*schemapb.FieldSchema
|
||||
GetCollectionName() string
|
||||
GetFunctionTypeName() string
|
||||
GetFunctionName() string
|
||||
GetFunctionProvider() string
|
||||
Check() error
|
||||
|
||||
MaxBatch() int
|
||||
ProcessInsert(inputs []*schemapb.FieldData) ([]*schemapb.FieldData, error)
|
||||
ProcessSearch(placeholderGroup *commonpb.PlaceholderGroup) (*commonpb.PlaceholderGroup, error)
|
||||
ProcessInsert(ctx context.Context, inputs []*schemapb.FieldData) ([]*schemapb.FieldData, error)
|
||||
ProcessSearch(ctx context.Context, placeholderGroup *commonpb.PlaceholderGroup) (*commonpb.PlaceholderGroup, error)
|
||||
ProcessBulkInsert(inputs []storage.FieldData) (map[storage.FieldID]storage.FieldData, error)
|
||||
}
|
||||
|
||||
@ -64,9 +74,18 @@ func createFunction(coll *schemapb.CollectionSchema, schema *schemapb.FunctionSc
|
||||
// Since bm25 and embedding are implemented in different ways, the bm25 function is not verified here.
|
||||
func ValidateFunctions(schema *schemapb.CollectionSchema) error {
|
||||
for _, fSchema := range schema.Functions {
|
||||
if _, err := createFunction(schema, fSchema); err != nil {
|
||||
f, err := createFunction(schema, fSchema)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// ignore bm25 function
|
||||
if f == nil {
|
||||
continue
|
||||
}
|
||||
if err := f.Check(); err != nil {
|
||||
return fmt.Errorf("Check function [%s:%s] failed, the err is: %v", fSchema.Name, fSchema.GetType().String(), err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@ -87,7 +106,7 @@ func NewFunctionExecutor(schema *schemapb.CollectionSchema) (*FunctionExecutor,
|
||||
return executor, nil
|
||||
}
|
||||
|
||||
func (executor *FunctionExecutor) processSingleFunction(runner Runner, msg *msgstream.InsertMsg) ([]*schemapb.FieldData, error) {
|
||||
func (executor *FunctionExecutor) processSingleFunction(ctx context.Context, runner Runner, msg *msgstream.InsertMsg) ([]*schemapb.FieldData, error) {
|
||||
inputs := make([]*schemapb.FieldData, 0, len(runner.GetSchema().GetInputFieldNames()))
|
||||
for _, name := range runner.GetSchema().GetInputFieldNames() {
|
||||
for _, field := range msg.FieldsData {
|
||||
@ -100,14 +119,18 @@ func (executor *FunctionExecutor) processSingleFunction(runner Runner, msg *msgs
|
||||
return nil, fmt.Errorf("Input field not found")
|
||||
}
|
||||
|
||||
outputs, err := runner.ProcessInsert(inputs)
|
||||
tr := timerecord.NewTimeRecorder("function ProcessInsert")
|
||||
outputs, err := runner.ProcessInsert(ctx, inputs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
metrics.ProxyFunctionlatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), runner.GetCollectionName(), runner.GetFunctionTypeName(), runner.GetFunctionProvider(), runner.GetFunctionName()).Observe(float64(tr.RecordSpan().Milliseconds()))
|
||||
tr.CtxElapse(ctx, "function ProcessInsert done")
|
||||
return outputs, nil
|
||||
}
|
||||
|
||||
func (executor *FunctionExecutor) ProcessInsert(msg *msgstream.InsertMsg) error {
|
||||
func (executor *FunctionExecutor) ProcessInsert(ctx context.Context, msg *msgstream.InsertMsg) error {
|
||||
numRows := msg.NumRows
|
||||
for _, runner := range executor.runners {
|
||||
if numRows > uint64(runner.MaxBatch()) {
|
||||
@ -122,7 +145,7 @@ func (executor *FunctionExecutor) ProcessInsert(msg *msgstream.InsertMsg) error
|
||||
wg.Add(1)
|
||||
go func(runner Runner) {
|
||||
defer wg.Done()
|
||||
data, err := executor.processSingleFunction(runner, msg)
|
||||
data, err := executor.processSingleFunction(ctx, runner, msg)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
@ -149,7 +172,7 @@ func (executor *FunctionExecutor) ProcessInsert(msg *msgstream.InsertMsg) error
|
||||
return nil
|
||||
}
|
||||
|
||||
func (executor *FunctionExecutor) processSingleSearch(runner Runner, placeholderGroup []byte) ([]byte, error) {
|
||||
func (executor *FunctionExecutor) processSingleSearch(ctx context.Context, runner Runner, placeholderGroup []byte) ([]byte, error) {
|
||||
pb := &commonpb.PlaceholderGroup{}
|
||||
proto.Unmarshal(placeholderGroup, pb)
|
||||
if len(pb.Placeholders) != 1 {
|
||||
@ -158,14 +181,18 @@ func (executor *FunctionExecutor) processSingleSearch(runner Runner, placeholder
|
||||
if pb.Placeholders[0].Type != commonpb.PlaceholderType_VarChar {
|
||||
return placeholderGroup, nil
|
||||
}
|
||||
res, err := runner.ProcessSearch(pb)
|
||||
|
||||
tr := timerecord.NewTimeRecorder("function ProcessSearch")
|
||||
res, err := runner.ProcessSearch(ctx, pb)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
metrics.ProxyFunctionlatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), runner.GetCollectionName(), runner.GetFunctionTypeName(), runner.GetFunctionProvider(), runner.GetFunctionName()).Observe(float64(tr.RecordSpan().Milliseconds()))
|
||||
tr.CtxElapse(ctx, "function ProcessSearch done")
|
||||
return proto.Marshal(res)
|
||||
}
|
||||
|
||||
func (executor *FunctionExecutor) prcessSearch(req *internalpb.SearchRequest) error {
|
||||
func (executor *FunctionExecutor) prcessSearch(ctx context.Context, req *internalpb.SearchRequest) error {
|
||||
runner, exist := executor.runners[req.FieldId]
|
||||
if !exist {
|
||||
return fmt.Errorf("Can not found function in field %d", req.FieldId)
|
||||
@ -173,7 +200,7 @@ func (executor *FunctionExecutor) prcessSearch(req *internalpb.SearchRequest) er
|
||||
if req.Nq > int64(runner.MaxBatch()) {
|
||||
return fmt.Errorf("Nq [%d] > function [%s]'s max batch [%d]", req.Nq, runner.GetSchema().Name, runner.MaxBatch())
|
||||
}
|
||||
if newHolder, err := executor.processSingleSearch(runner, req.GetPlaceholderGroup()); err != nil {
|
||||
if newHolder, err := executor.processSingleSearch(ctx, runner, req.GetPlaceholderGroup()); err != nil {
|
||||
return err
|
||||
} else {
|
||||
req.PlaceholderGroup = newHolder
|
||||
@ -181,7 +208,7 @@ func (executor *FunctionExecutor) prcessSearch(req *internalpb.SearchRequest) er
|
||||
return nil
|
||||
}
|
||||
|
||||
func (executor *FunctionExecutor) prcessAdvanceSearch(req *internalpb.SearchRequest) error {
|
||||
func (executor *FunctionExecutor) prcessAdvanceSearch(ctx context.Context, req *internalpb.SearchRequest) error {
|
||||
outputs := make(chan map[int64][]byte, len(req.GetSubReqs()))
|
||||
errChan := make(chan error, len(req.GetSubReqs()))
|
||||
var wg sync.WaitGroup
|
||||
@ -193,7 +220,7 @@ func (executor *FunctionExecutor) prcessAdvanceSearch(req *internalpb.SearchRequ
|
||||
wg.Add(1)
|
||||
go func(runner Runner, idx int64, placeholderGroup []byte) {
|
||||
defer wg.Done()
|
||||
if newHolder, err := executor.processSingleSearch(runner, placeholderGroup); err != nil {
|
||||
if newHolder, err := executor.processSingleSearch(ctx, runner, placeholderGroup); err != nil {
|
||||
errChan <- err
|
||||
} else {
|
||||
outputs <- map[int64][]byte{idx: newHolder}
|
||||
@ -216,11 +243,11 @@ func (executor *FunctionExecutor) prcessAdvanceSearch(req *internalpb.SearchRequ
|
||||
return nil
|
||||
}
|
||||
|
||||
func (executor *FunctionExecutor) ProcessSearch(req *internalpb.SearchRequest) error {
|
||||
func (executor *FunctionExecutor) ProcessSearch(ctx context.Context, req *internalpb.SearchRequest) error {
|
||||
if !req.IsAdvanced {
|
||||
return executor.prcessSearch(req)
|
||||
return executor.prcessSearch(ctx, req)
|
||||
}
|
||||
return executor.prcessAdvanceSearch(req)
|
||||
return executor.prcessAdvanceSearch(ctx, req)
|
||||
}
|
||||
|
||||
func (executor *FunctionExecutor) processSingleBulkInsert(runner Runner, data *storage.InsertData) (map[storage.FieldID]storage.FieldData, error) {
|
||||
|
||||
@ -19,6 +19,7 @@
|
||||
package function
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
@ -148,7 +149,7 @@ func (s *FunctionExecutorSuite) TestExecutor() {
|
||||
exec, err := NewFunctionExecutor(schema)
|
||||
s.NoError(err)
|
||||
msg := s.createMsg([]string{"sentence", "sentence"})
|
||||
exec.ProcessInsert(msg)
|
||||
exec.ProcessInsert(context.Background(), msg)
|
||||
s.Equal(len(msg.FieldsData), 3)
|
||||
}
|
||||
|
||||
@ -183,7 +184,7 @@ func (s *FunctionExecutorSuite) TestErrorEmbedding() {
|
||||
exec, err := NewFunctionExecutor(schema)
|
||||
s.NoError(err)
|
||||
msg := s.createMsg([]string{"sentence", "sentence"})
|
||||
err = exec.ProcessInsert(msg)
|
||||
err = exec.ProcessInsert(context.Background(), msg)
|
||||
s.Error(err)
|
||||
}
|
||||
|
||||
@ -225,7 +226,7 @@ func (s *FunctionExecutorSuite) TestInternalPrcessSearch() {
|
||||
IsAdvanced: false,
|
||||
FieldId: 102,
|
||||
}
|
||||
err = exec.ProcessSearch(req)
|
||||
err = exec.ProcessSearch(context.Background(), req)
|
||||
s.NoError(err)
|
||||
|
||||
// No function found
|
||||
@ -235,7 +236,7 @@ func (s *FunctionExecutorSuite) TestInternalPrcessSearch() {
|
||||
IsAdvanced: false,
|
||||
FieldId: 111,
|
||||
}
|
||||
err = exec.ProcessSearch(req)
|
||||
err = exec.ProcessSearch(context.Background(), req)
|
||||
s.Error(err)
|
||||
|
||||
// Large search nq
|
||||
@ -245,7 +246,7 @@ func (s *FunctionExecutorSuite) TestInternalPrcessSearch() {
|
||||
IsAdvanced: false,
|
||||
FieldId: 102,
|
||||
}
|
||||
err = exec.ProcessSearch(req)
|
||||
err = exec.ProcessSearch(context.Background(), req)
|
||||
s.Error(err)
|
||||
}
|
||||
|
||||
@ -277,12 +278,12 @@ func (s *FunctionExecutorSuite) TestInternalPrcessSearch() {
|
||||
IsAdvanced: true,
|
||||
SubReqs: []*internalpb.SubSearchRequest{subReq},
|
||||
}
|
||||
err = exec.ProcessSearch(req)
|
||||
err = exec.ProcessSearch(context.Background(), req)
|
||||
s.NoError(err)
|
||||
|
||||
// Large nq
|
||||
subReq.Nq = 1000
|
||||
err = exec.ProcessSearch(req)
|
||||
err = exec.ProcessSearch(context.Background(), req)
|
||||
s.Error(err)
|
||||
}
|
||||
}
|
||||
@ -318,7 +319,7 @@ func (s *FunctionExecutorSuite) TestInternalPrcessSearchFailed() {
|
||||
IsAdvanced: false,
|
||||
FieldId: 102,
|
||||
}
|
||||
err = exec.ProcessSearch(req)
|
||||
err = exec.ProcessSearch(context.Background(), req)
|
||||
s.Error(err)
|
||||
}
|
||||
// AdvanceSearch
|
||||
@ -332,7 +333,7 @@ func (s *FunctionExecutorSuite) TestInternalPrcessSearchFailed() {
|
||||
IsAdvanced: true,
|
||||
SubReqs: []*internalpb.SubSearchRequest{subReq},
|
||||
}
|
||||
err = exec.ProcessSearch(req)
|
||||
err = exec.ProcessSearch(context.Background(), req)
|
||||
s.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
@ -36,18 +36,7 @@ import (
|
||||
"github.com/milvus-io/milvus/internal/util/function/models/voyageai"
|
||||
)
|
||||
|
||||
// func mockEmbedding(texts []string, dim int) [][]float32 {
|
||||
// embeddings := make([][]float32, 0)
|
||||
// for i := 0; i < len(texts); i++ {
|
||||
// f := float32(i)
|
||||
// emb := make([]float32, 0)
|
||||
// for j := 0; j < dim; j++ {
|
||||
// emb = append(emb, f+float32(j)*0.1)
|
||||
// }
|
||||
// embeddings = append(embeddings, emb)
|
||||
// }
|
||||
// return embeddings
|
||||
// }
|
||||
const TestModel string = "TestModel"
|
||||
|
||||
func mockEmbedding[T int8 | float32](texts []string, dim int) [][]T {
|
||||
embeddings := make([][]T, 0)
|
||||
|
||||
@ -17,7 +17,6 @@
|
||||
package ali
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
@ -135,14 +134,11 @@ func (c *AliDashScopeEmbedding) Embedding(modelName string, texts []string, dim
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutSec)*time.Second)
|
||||
defer cancel()
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.url, bytes.NewBuffer(data))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
headers := map[string]string{
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": fmt.Sprintf("Bearer %s", c.apiKey),
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.apiKey))
|
||||
body, err := utils.RetrySend(req, 3)
|
||||
body, err := utils.RetrySend(ctx, data, http.MethodPost, c.url, headers, 3, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -17,7 +17,6 @@
|
||||
package cohere
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
@ -100,15 +99,12 @@ func (c *CohereEmbedding) Embedding(modelName string, texts []string, inputType
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutSec)*time.Second)
|
||||
defer cancel()
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.url, bytes.NewBuffer(data))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
headers := map[string]string{
|
||||
"accept": "application/json",
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": fmt.Sprintf("bearer %s", c.apiKey),
|
||||
}
|
||||
|
||||
req.Header.Set("accept", "application/json")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", fmt.Sprintf("bearer %s", c.apiKey))
|
||||
body, err := utils.RetrySend(req, 3)
|
||||
body, err := utils.RetrySend(ctx, data, http.MethodPost, c.url, headers, 3, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -17,7 +17,6 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
@ -148,18 +147,10 @@ func (c *openAIBase) embedding(url string, headers map[string]string, modelName
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutSec)*time.Second)
|
||||
defer cancel()
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewBuffer(data))
|
||||
body, err := utils.RetrySend(ctx, data, http.MethodPost, url, headers, 3, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for key, value := range headers {
|
||||
req.Header.Set(key, value)
|
||||
}
|
||||
body, err := utils.RetrySend(req, 3)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var res EmbeddingResponse
|
||||
err = json.Unmarshal(body, &res)
|
||||
if err != nil {
|
||||
|
||||
@ -218,7 +218,8 @@ func TestEmbeddingFailed(t *testing.T) {
|
||||
func TestTimeout(t *testing.T) {
|
||||
var st int32 = 0
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
time.Sleep(3 * time.Second)
|
||||
// (Timeout 1s + Wait 1s) * Retry 3
|
||||
time.Sleep(6 * time.Second)
|
||||
atomic.AddInt32(&st, 1)
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
}))
|
||||
|
||||
@ -17,7 +17,6 @@
|
||||
package siliconflow
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
@ -115,14 +114,12 @@ func (c *SiliconflowEmbedding) Embedding(modelName string, texts []string, encod
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutSec)*time.Second)
|
||||
defer cancel()
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.url, bytes.NewBuffer(data))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.apiKey))
|
||||
body, err := utils.RetrySend(req, 3)
|
||||
headers := map[string]string{
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": fmt.Sprintf("Bearer %s", c.apiKey),
|
||||
}
|
||||
body, err := utils.RetrySend(ctx, data, http.MethodPost, c.url, headers, 3, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -17,7 +17,6 @@
|
||||
package tei
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
@ -46,10 +45,10 @@ func NewTEIEmbeddingClient(apiKey string, endpoint string) (*TEIEmbedding, error
|
||||
return nil, err
|
||||
}
|
||||
if base.Scheme != "http" && base.Scheme != "https" {
|
||||
return nil, fmt.Errorf("%s is not a valid http/https link", endpoint)
|
||||
return nil, fmt.Errorf("endpoint: [%s] is not a valid http/https link", endpoint)
|
||||
}
|
||||
if base.Host == "" {
|
||||
return nil, fmt.Errorf("%s is not a valid http/https link", endpoint)
|
||||
return nil, fmt.Errorf("endpoint: [%s] is not a valid http/https link", endpoint)
|
||||
}
|
||||
|
||||
base.Path = "/embed"
|
||||
@ -88,17 +87,13 @@ func (c *TEIEmbedding) Embedding(texts []string, truncate bool, truncationDirect
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutSec)*time.Second)
|
||||
defer cancel()
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.url, bytes.NewBuffer(data))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
headers := map[string]string{
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if c.apiKey != "" {
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.apiKey))
|
||||
headers["Authorization"] = fmt.Sprintf("Bearer %s", c.apiKey)
|
||||
}
|
||||
|
||||
body, err := utils.RetrySend(req, 3)
|
||||
body, err := utils.RetrySend(ctx, data, http.MethodPost, c.url, headers, 3, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -17,9 +17,12 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
const DefaultTimeout int64 = 30
|
||||
@ -33,23 +36,31 @@ func send(req *http.Request) ([]byte, error) {
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("Call service faild, read response failed, errs:[%v]", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("Call %s faild, errs:[%s, %s]", req.URL, resp.Status, body)
|
||||
return nil, fmt.Errorf("Call service faild, errs:[%s, %s]", resp.Status, body)
|
||||
}
|
||||
return body, nil
|
||||
}
|
||||
|
||||
func RetrySend(req *http.Request, maxRetries int) ([]byte, error) {
|
||||
func RetrySend(ctx context.Context, data []byte, httpMethod string, url string, headers map[string]string, maxRetries int, retryDelay int) ([]byte, error) {
|
||||
var err error
|
||||
var res []byte
|
||||
var body []byte
|
||||
for i := 0; i < maxRetries; i++ {
|
||||
res, err = send(req)
|
||||
if err == nil {
|
||||
return res, nil
|
||||
req, reqErr := http.NewRequestWithContext(ctx, httpMethod, url, bytes.NewBuffer(data))
|
||||
if reqErr != nil {
|
||||
return nil, reqErr
|
||||
}
|
||||
for k, v := range headers {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
body, err = send(req)
|
||||
if err == nil {
|
||||
return body, nil
|
||||
}
|
||||
time.Sleep(time.Duration(retryDelay) * time.Second)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -17,7 +17,6 @@
|
||||
package vertexai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
@ -134,10 +133,6 @@ func (c *VertexAIEmbedding) Embedding(modelName string, texts []string, dim int6
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutSec)*time.Second)
|
||||
defer cancel()
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.url, bytes.NewBuffer(data))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var token string
|
||||
if c.token != "" {
|
||||
token = c.token
|
||||
@ -148,9 +143,11 @@ func (c *VertexAIEmbedding) Embedding(modelName string, texts []string, dim int6
|
||||
}
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
|
||||
body, err := utils.RetrySend(req, 3)
|
||||
headers := map[string]string{
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": fmt.Sprintf("Bearer %s", token),
|
||||
}
|
||||
body, err := utils.RetrySend(ctx, data, http.MethodPost, c.url, headers, 3, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -17,7 +17,6 @@
|
||||
package voyageai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
@ -108,7 +107,7 @@ func (c *VoyageAIEmbedding) Check() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *VoyageAIEmbedding) Embedding(modelName string, texts []string, dim int, textType string, outputType string, timeoutSec int64) (any, error) {
|
||||
func (c *VoyageAIEmbedding) Embedding(modelName string, texts []string, dim int, textType string, outputType string, truncation bool, timeoutSec int64) (any, error) {
|
||||
if outputType != "float" && outputType != "int8" {
|
||||
return nil, fmt.Errorf("Voyageai: unsupport output type: [%s], only support float and int8", outputType)
|
||||
}
|
||||
@ -117,9 +116,11 @@ func (c *VoyageAIEmbedding) Embedding(modelName string, texts []string, dim int,
|
||||
r.Input = texts
|
||||
r.InputType = textType
|
||||
r.OutputDtype = outputType
|
||||
r.Truncation = truncation
|
||||
if dim != 0 {
|
||||
r.OutputDimension = int64(dim)
|
||||
}
|
||||
|
||||
data, err := json.Marshal(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -131,14 +132,11 @@ func (c *VoyageAIEmbedding) Embedding(modelName string, texts []string, dim int,
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutSec)*time.Second)
|
||||
defer cancel()
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.url, bytes.NewBuffer(data))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
headers := map[string]string{
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": fmt.Sprintf("Bearer %s", c.apiKey),
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.apiKey))
|
||||
body, err := utils.RetrySend(req, 3)
|
||||
body, err := utils.RetrySend(ctx, data, http.MethodPost, c.url, headers, 3, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -98,7 +98,7 @@ func TestEmbeddingOK(t *testing.T) {
|
||||
c := NewVoyageAIEmbeddingClient("mock_key", url)
|
||||
err := c.Check()
|
||||
assert.True(t, err == nil)
|
||||
r, err := c.Embedding("voyage-3", []string{"sentence"}, 0, "query", "float", 0)
|
||||
r, err := c.Embedding("voyage-3", []string{"sentence"}, 0, "query", "float", true, 0)
|
||||
ret := r.(*EmbeddingResponse[float32])
|
||||
assert.True(t, err == nil)
|
||||
assert.Equal(t, ret.Data[0].Index, 0)
|
||||
@ -158,7 +158,7 @@ func TestEmbeddingInt8Embed(t *testing.T) {
|
||||
c := NewVoyageAIEmbeddingClient("mock_key", url)
|
||||
err := c.Check()
|
||||
assert.True(t, err == nil)
|
||||
r, err := c.Embedding("voyage-3", []string{"sentence"}, 0, "query", "int8", 0)
|
||||
r, err := c.Embedding("voyage-3", []string{"sentence"}, 0, "query", "int8", false, 0)
|
||||
ret := r.(*EmbeddingResponse[int8])
|
||||
assert.True(t, err == nil)
|
||||
assert.Equal(t, ret.Data[0].Index, 0)
|
||||
@ -169,7 +169,7 @@ func TestEmbeddingInt8Embed(t *testing.T) {
|
||||
assert.Equal(t, ret.Data[1].Embedding, []int8{3, 4})
|
||||
assert.Equal(t, ret.Data[2].Embedding, []int8{5, 6})
|
||||
|
||||
_, err = c.Embedding("voyage-3", []string{"sentence"}, 0, "query", "unknow", 0)
|
||||
_, err = c.Embedding("voyage-3", []string{"sentence"}, 0, "query", "unknow", true, 0)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
}
|
||||
@ -186,7 +186,7 @@ func TestEmbeddingFailed(t *testing.T) {
|
||||
c := NewVoyageAIEmbeddingClient("mock_key", url)
|
||||
err := c.Check()
|
||||
assert.True(t, err == nil)
|
||||
_, err = c.Embedding("voyage-3", []string{"sentence"}, 0, "query", "float", 0)
|
||||
_, err = c.Embedding("voyage-3", []string{"sentence"}, 0, "query", "float", false, 0)
|
||||
assert.True(t, err != nil)
|
||||
}
|
||||
}
|
||||
|
||||
@ -81,7 +81,8 @@ func newOpenAIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchem
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var apiKey, url, modelName, user string
|
||||
apiKey, url := parseAKAndURL(functionSchema.Params)
|
||||
var modelName, user string
|
||||
var dim int64
|
||||
|
||||
for _, param := range functionSchema.Params {
|
||||
@ -95,21 +96,12 @@ func newOpenAIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchem
|
||||
}
|
||||
case userParamKey:
|
||||
user = param.Value
|
||||
case apiKeyParamKey:
|
||||
apiKey = param.Value
|
||||
case embeddingURLParamKey:
|
||||
url = param.Value
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
var c openai.OpenAIEmbeddingInterface
|
||||
if !isAzure {
|
||||
if modelName != TextEmbeddingAda002 && modelName != TextEmbedding3Small && modelName != TextEmbedding3Large {
|
||||
return nil, fmt.Errorf("Unsupported model: %s, only support [%s, %s, %s]",
|
||||
modelName, TextEmbeddingAda002, TextEmbedding3Small, TextEmbedding3Large)
|
||||
}
|
||||
|
||||
c, err = createOpenAIEmbeddingClient(apiKey, url)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@ -60,25 +60,17 @@ func NewSiliconflowEmbeddingProvider(fieldSchema *schemapb.FieldSchema, function
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var apiKey, url, modelName string
|
||||
apiKey, url := parseAKAndURL(functionSchema.Params)
|
||||
var modelName string
|
||||
|
||||
for _, param := range functionSchema.Params {
|
||||
switch strings.ToLower(param.Key) {
|
||||
case modelNameParamKey:
|
||||
modelName = param.Value
|
||||
case apiKeyParamKey:
|
||||
apiKey = param.Value
|
||||
case embeddingURLParamKey:
|
||||
url = param.Value
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
if modelName != bAAIBgeLargeZhV15 && modelName != bAAIBgeLargeEhV15 && modelName != neteaseYoudaoBceEmbeddingBasev1 && modelName != bAAIBgeM3 && modelName != proBAAIBgeM3 {
|
||||
return nil, fmt.Errorf("Unsupported model: %s, only support [%s, %s, %s, %s, %s]",
|
||||
modelName, bAAIBgeLargeZhV15, bAAIBgeLargeEhV15, neteaseYoudaoBceEmbeddingBasev1, bAAIBgeM3, proBAAIBgeM3)
|
||||
}
|
||||
|
||||
c, err := createSiliconflowEmbeddingClient(apiKey, url)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@ -69,7 +69,7 @@ func createSiliconflowProvider(url string, schema *schemapb.FieldSchema, provide
|
||||
InputFieldIds: []int64{101},
|
||||
OutputFieldIds: []int64{102},
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: modelNameParamKey, Value: bAAIBgeLargeEhV15},
|
||||
{Key: modelNameParamKey, Value: TestModel},
|
||||
{Key: apiKeyParamKey, Value: "mock"},
|
||||
{Key: embeddingURLParamKey, Value: url},
|
||||
},
|
||||
@ -187,7 +187,7 @@ func (s *SiliconflowTextEmbeddingProviderSuite) TestNewSiliconflowEmbeddingProvi
|
||||
InputFieldIds: []int64{101},
|
||||
OutputFieldIds: []int64{102},
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: modelNameParamKey, Value: bAAIBgeLargeEhV15},
|
||||
{Key: modelNameParamKey, Value: TestModel},
|
||||
{Key: apiKeyParamKey, Value: "mock"},
|
||||
{Key: embeddingURLParamKey, Value: "mock"},
|
||||
},
|
||||
@ -196,11 +196,4 @@ func (s *SiliconflowTextEmbeddingProviderSuite) TestNewSiliconflowEmbeddingProvi
|
||||
s.NoError(err)
|
||||
s.Equal(provider.FieldDim(), int64(4))
|
||||
s.True(provider.MaxBatch() > 0)
|
||||
|
||||
// Invalid model
|
||||
{
|
||||
functionSchema.Params[0] = &commonpb.KeyValuePair{Key: modelNameParamKey, Value: "UnkownModel"}
|
||||
_, err := NewSiliconflowEmbeddingProvider(s.schema.Fields[2], functionSchema)
|
||||
s.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
@ -19,8 +19,9 @@
|
||||
package function
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"reflect"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
@ -44,6 +45,13 @@ const (
|
||||
teiProvider string = "tei"
|
||||
)
|
||||
|
||||
func TextEmbeddingOutputsCheck(fields []*schemapb.FieldSchema) error {
|
||||
if len(fields) != 1 || (fields[0].DataType != schemapb.DataType_FloatVector && fields[0].DataType != schemapb.DataType_Int8Vector) {
|
||||
return fmt.Errorf("TextEmbedding function output field must be a FloatVector or Int8Vector field")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Text embedding for retrieval task
|
||||
type textEmbeddingProvider interface {
|
||||
MaxBatch() int
|
||||
@ -51,17 +59,6 @@ type textEmbeddingProvider interface {
|
||||
FieldDim() int64
|
||||
}
|
||||
|
||||
func getProvider(functionSchema *schemapb.FunctionSchema) (string, error) {
|
||||
for _, param := range functionSchema.Params {
|
||||
switch strings.ToLower(param.Key) {
|
||||
case Provider:
|
||||
return strings.ToLower(param.Value), nil
|
||||
default:
|
||||
}
|
||||
}
|
||||
return "", fmt.Errorf("The text embedding service provider parameter:[%s] was not found", Provider)
|
||||
}
|
||||
|
||||
type TextEmbeddingFunction struct {
|
||||
FunctionBase
|
||||
|
||||
@ -82,20 +79,13 @@ func NewTextEmbeddingFunction(coll *schemapb.CollectionSchema, functionSchema *s
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if base.outputFields[0].DataType != schemapb.DataType_FloatVector && base.outputFields[0].DataType != schemapb.DataType_Int8Vector {
|
||||
return nil, fmt.Errorf("Text embedding function's output field not match, needs [%s, %s], got [%s]",
|
||||
schemapb.DataType_name[int32(schemapb.DataType_FloatVector)],
|
||||
schemapb.DataType_name[int32(schemapb.DataType_Int8Vector)],
|
||||
schemapb.DataType_name[int32(base.outputFields[0].DataType)])
|
||||
}
|
||||
|
||||
provider, err := getProvider(functionSchema)
|
||||
if err != nil {
|
||||
if err := TextEmbeddingOutputsCheck(base.outputFields); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var embP textEmbeddingProvider
|
||||
var newProviderErr error
|
||||
switch provider {
|
||||
switch base.provider {
|
||||
case openAIProvider:
|
||||
embP, newProviderErr = NewOpenAIEmbeddingProvider(base.outputFields[0], functionSchema)
|
||||
case azureOpenAIProvider:
|
||||
@ -115,7 +105,7 @@ func NewTextEmbeddingFunction(coll *schemapb.CollectionSchema, functionSchema *s
|
||||
case teiProvider:
|
||||
embP, newProviderErr = NewTEIEmbeddingProvider(base.outputFields[0], functionSchema)
|
||||
default:
|
||||
return nil, fmt.Errorf("Unsupported text embedding service provider: [%s] , list of supported [%s, %s, %s, %s, %s, %s, %s, %s, %s]", provider, openAIProvider, azureOpenAIProvider, aliDashScopeProvider, bedrockProvider, vertexAIProvider, voyageAIProvider, cohereProvider, siliconflowProvider, teiProvider)
|
||||
return nil, fmt.Errorf("Unsupported text embedding service provider: [%s] , list of supported [%s, %s, %s, %s, %s, %s, %s, %s, %s]", base.provider, openAIProvider, azureOpenAIProvider, aliDashScopeProvider, bedrockProvider, vertexAIProvider, voyageAIProvider, cohereProvider, siliconflowProvider, teiProvider)
|
||||
}
|
||||
|
||||
if newProviderErr != nil {
|
||||
@ -127,10 +117,46 @@ func NewTextEmbeddingFunction(coll *schemapb.CollectionSchema, functionSchema *s
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (runner *TextEmbeddingFunction) Check() error {
|
||||
embds, err := runner.embProvider.CallEmbedding([]string{"check"}, InsertMode)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
dim := 0
|
||||
switch embds := embds.(type) {
|
||||
case [][]float32:
|
||||
dim = len(embds[0])
|
||||
case [][]int8:
|
||||
dim = len(embds[0])
|
||||
default:
|
||||
return fmt.Errorf("Unsupport embedding type: %s", reflect.TypeOf(embds).String())
|
||||
}
|
||||
if dim != int(runner.embProvider.FieldDim()) {
|
||||
return fmt.Errorf("The dim set in the schema is inconsistent with the dim of the model, dim in schema is %d, dim of model is %d", runner.embProvider.FieldDim(), dim)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// MaxBatch returns the maximum batch size reported by the runner's embedding
// provider.
func (runner *TextEmbeddingFunction) MaxBatch() int {
|
||||
return runner.embProvider.MaxBatch()
|
||||
}
|
||||
|
||||
// GetCollectionName returns the collectionName value held by the runner.
func (runner *TextEmbeddingFunction) GetCollectionName() string {
|
||||
return runner.collectionName
|
||||
}
|
||||
|
||||
// GetFunctionProvider returns the provider value held by the runner, i.e. the
// same value NewTextEmbeddingFunction switches on to select an embedding
// provider implementation.
func (runner *TextEmbeddingFunction) GetFunctionProvider() string {
|
||||
return runner.provider
|
||||
}
|
||||
|
||||
// GetFunctionTypeName returns the functionTypeName value held by the runner.
func (runner *TextEmbeddingFunction) GetFunctionTypeName() string {
|
||||
return runner.functionTypeName
|
||||
}
|
||||
|
||||
// GetFunctionName returns the functionName value held by the runner.
func (runner *TextEmbeddingFunction) GetFunctionName() string {
|
||||
return runner.functionName
|
||||
}
|
||||
|
||||
func (runner *TextEmbeddingFunction) packToFieldData(embds any) ([]*schemapb.FieldData, error) {
|
||||
var outputField schemapb.FieldData
|
||||
outputField.FieldId = runner.GetOutputFields()[0].FieldID
|
||||
@ -174,7 +200,7 @@ func (runner *TextEmbeddingFunction) packToFieldData(embds any) ([]*schemapb.Fie
|
||||
return []*schemapb.FieldData{&outputField}, nil
|
||||
}
|
||||
|
||||
func (runner *TextEmbeddingFunction) ProcessInsert(inputs []*schemapb.FieldData) ([]*schemapb.FieldData, error) {
|
||||
func (runner *TextEmbeddingFunction) ProcessInsert(ctx context.Context, inputs []*schemapb.FieldData) ([]*schemapb.FieldData, error) {
|
||||
if len(inputs) != 1 {
|
||||
return nil, fmt.Errorf("Text embedding function only receives one input field, but got [%d]", len(inputs))
|
||||
}
|
||||
@ -199,7 +225,7 @@ func (runner *TextEmbeddingFunction) ProcessInsert(inputs []*schemapb.FieldData)
|
||||
return runner.packToFieldData(embds)
|
||||
}
|
||||
|
||||
func (runner *TextEmbeddingFunction) ProcessSearch(placeholderGroup *commonpb.PlaceholderGroup) (*commonpb.PlaceholderGroup, error) {
|
||||
func (runner *TextEmbeddingFunction) ProcessSearch(ctx context.Context, placeholderGroup *commonpb.PlaceholderGroup) (*commonpb.PlaceholderGroup, error) {
|
||||
texts := funcutil.GetVarCharFromPlaceholder(placeholderGroup.Placeholders[0]) // Already checked externally
|
||||
numRows := len(texts)
|
||||
if numRows > runner.MaxBatch() {
|
||||
|
||||
@ -19,6 +19,7 @@
|
||||
package function
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
@ -126,7 +127,7 @@ func (s *TextEmbeddingFunctionSuite) TestProcessInsert() {
|
||||
|
||||
{
|
||||
data := createData([]string{"sentence"})
|
||||
ret, err2 := runner.ProcessInsert(data)
|
||||
ret, err2 := runner.ProcessInsert(context.Background(), data)
|
||||
s.NoError(err2)
|
||||
s.Equal(1, len(ret))
|
||||
s.Equal(int64(4), ret[0].GetVectors().Dim)
|
||||
@ -134,7 +135,7 @@ func (s *TextEmbeddingFunctionSuite) TestProcessInsert() {
|
||||
}
|
||||
{
|
||||
data := createData([]string{"sentence 1", "sentence 2", "sentence 3"})
|
||||
ret, _ := runner.ProcessInsert(data)
|
||||
ret, _ := runner.ProcessInsert(context.Background(), data)
|
||||
s.Equal([]float32{0.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 4.0, 2.0, 3.0, 4.0, 5.0}, ret[0].GetVectors().GetFloatVector().Data)
|
||||
}
|
||||
}
|
||||
@ -158,7 +159,7 @@ func (s *TextEmbeddingFunctionSuite) TestProcessInsert() {
|
||||
|
||||
{
|
||||
data := createData([]string{"sentence"})
|
||||
ret, err2 := runner.ProcessInsert(data)
|
||||
ret, err2 := runner.ProcessInsert(context.Background(), data)
|
||||
s.NoError(err2)
|
||||
s.Equal(1, len(ret))
|
||||
s.Equal(int64(4), ret[0].GetVectors().Dim)
|
||||
@ -166,7 +167,7 @@ func (s *TextEmbeddingFunctionSuite) TestProcessInsert() {
|
||||
}
|
||||
{
|
||||
data := createData([]string{"sentence 1", "sentence 2", "sentence 3"})
|
||||
ret, _ := runner.ProcessInsert(data)
|
||||
ret, _ := runner.ProcessInsert(context.Background(), data)
|
||||
s.Equal([]float32{0.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 4.0, 2.0, 3.0, 4.0, 5.0}, ret[0].GetVectors().GetFloatVector().Data)
|
||||
}
|
||||
}
|
||||
@ -185,7 +186,7 @@ func (s *TextEmbeddingFunctionSuite) TestAliEmbedding() {
|
||||
OutputFieldIds: []int64{102},
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: Provider, Value: aliDashScopeProvider},
|
||||
{Key: modelNameParamKey, Value: TextEmbeddingV3},
|
||||
{Key: modelNameParamKey, Value: TestModel},
|
||||
{Key: dimParamKey, Value: "4"},
|
||||
{Key: apiKeyParamKey, Value: "mock"},
|
||||
{Key: embeddingURLParamKey, Value: ts.URL},
|
||||
@ -195,7 +196,7 @@ func (s *TextEmbeddingFunctionSuite) TestAliEmbedding() {
|
||||
|
||||
{
|
||||
data := createData([]string{"sentence"})
|
||||
ret, err2 := runner.ProcessInsert(data)
|
||||
ret, err2 := runner.ProcessInsert(context.Background(), data)
|
||||
s.NoError(err2)
|
||||
s.Equal(1, len(ret))
|
||||
s.Equal(int64(4), ret[0].GetVectors().Dim)
|
||||
@ -203,7 +204,7 @@ func (s *TextEmbeddingFunctionSuite) TestAliEmbedding() {
|
||||
}
|
||||
{
|
||||
data := createData([]string{"sentence 1", "sentence 2", "sentence 3"})
|
||||
ret, _ := runner.ProcessInsert(data)
|
||||
ret, _ := runner.ProcessInsert(context.Background(), data)
|
||||
s.Equal([]float32{0.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 4.0, 2.0, 3.0, 4.0, 5.0}, ret[0].GetVectors().GetFloatVector().Data)
|
||||
}
|
||||
|
||||
@ -226,7 +227,7 @@ func (s *TextEmbeddingFunctionSuite) TestAliEmbedding() {
|
||||
}
|
||||
data = append(data, &f)
|
||||
data = append(data, &f)
|
||||
_, err := runner.ProcessInsert(data)
|
||||
_, err := runner.ProcessInsert(context.Background(), data)
|
||||
s.Error(err)
|
||||
}
|
||||
|
||||
@ -242,7 +243,7 @@ func (s *TextEmbeddingFunctionSuite) TestAliEmbedding() {
|
||||
},
|
||||
}
|
||||
data = append(data, &f)
|
||||
_, err := runner.ProcessInsert(data)
|
||||
_, err := runner.ProcessInsert(context.Background(), data)
|
||||
s.Error(err)
|
||||
}
|
||||
// empty input
|
||||
@ -257,7 +258,7 @@ func (s *TextEmbeddingFunctionSuite) TestAliEmbedding() {
|
||||
},
|
||||
}
|
||||
data = append(data, &f)
|
||||
_, err := runner.ProcessInsert(data)
|
||||
_, err := runner.ProcessInsert(context.Background(), data)
|
||||
s.Error(err)
|
||||
}
|
||||
// large input data
|
||||
@ -278,7 +279,7 @@ func (s *TextEmbeddingFunctionSuite) TestAliEmbedding() {
|
||||
},
|
||||
}
|
||||
data = append(data, &f)
|
||||
_, err := runner.ProcessInsert(data)
|
||||
_, err := runner.ProcessInsert(context.Background(), data)
|
||||
s.Error(err)
|
||||
}
|
||||
}
|
||||
@ -377,26 +378,6 @@ func (s *TextEmbeddingFunctionSuite) TestRunnerParamsErr() {
|
||||
s.Error(err)
|
||||
}
|
||||
|
||||
// error model name
|
||||
{
|
||||
_, err := NewTextEmbeddingFunction(s.schema, &schemapb.FunctionSchema{
|
||||
Name: "test",
|
||||
Type: schemapb.FunctionType_TextEmbedding,
|
||||
InputFieldNames: []string{"text"},
|
||||
OutputFieldNames: []string{"vector"},
|
||||
InputFieldIds: []int64{101},
|
||||
OutputFieldIds: []int64{102},
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: Provider, Value: openAIProvider},
|
||||
{Key: modelNameParamKey, Value: "text-embedding-ada-004"},
|
||||
{Key: dimParamKey, Value: "4"},
|
||||
{Key: apiKeyParamKey, Value: "mock"},
|
||||
{Key: embeddingURLParamKey, Value: "mock"},
|
||||
},
|
||||
})
|
||||
s.Error(err)
|
||||
}
|
||||
|
||||
// no openai api key
|
||||
{
|
||||
_, err := NewTextEmbeddingFunction(s.schema, &schemapb.FunctionSchema{
|
||||
@ -426,7 +407,7 @@ func (s *TextEmbeddingFunctionSuite) TestNewTextEmbeddings() {
|
||||
OutputFieldIds: []int64{102},
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: Provider, Value: bedrockProvider},
|
||||
{Key: modelNameParamKey, Value: BedRockTitanTextEmbeddingsV2},
|
||||
{Key: modelNameParamKey, Value: TestModel},
|
||||
{Key: awsAKIdParamKey, Value: "mock"},
|
||||
{Key: awsSAKParamKey, Value: "mock"},
|
||||
{Key: regionParamKey, Value: "mock"},
|
||||
@ -450,7 +431,7 @@ func (s *TextEmbeddingFunctionSuite) TestNewTextEmbeddings() {
|
||||
OutputFieldIds: []int64{102},
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: Provider, Value: aliDashScopeProvider},
|
||||
{Key: modelNameParamKey, Value: TextEmbeddingV1},
|
||||
{Key: modelNameParamKey, Value: TestModel},
|
||||
{Key: apiKeyParamKey, Value: "mock"},
|
||||
},
|
||||
}
|
||||
@ -472,7 +453,7 @@ func (s *TextEmbeddingFunctionSuite) TestNewTextEmbeddings() {
|
||||
OutputFieldIds: []int64{102},
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: Provider, Value: voyageAIProvider},
|
||||
{Key: modelNameParamKey, Value: voyage3},
|
||||
{Key: modelNameParamKey, Value: TestModel},
|
||||
{Key: apiKeyParamKey, Value: "mock"},
|
||||
},
|
||||
}
|
||||
@ -494,7 +475,7 @@ func (s *TextEmbeddingFunctionSuite) TestNewTextEmbeddings() {
|
||||
OutputFieldIds: []int64{102},
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: Provider, Value: siliconflowProvider},
|
||||
{Key: modelNameParamKey, Value: bAAIBgeLargeZhV15},
|
||||
{Key: modelNameParamKey, Value: TestModel},
|
||||
{Key: apiKeyParamKey, Value: "mock"},
|
||||
},
|
||||
}
|
||||
@ -516,7 +497,7 @@ func (s *TextEmbeddingFunctionSuite) TestNewTextEmbeddings() {
|
||||
OutputFieldIds: []int64{102},
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: Provider, Value: cohereProvider},
|
||||
{Key: modelNameParamKey, Value: embedEnglishLightV20},
|
||||
{Key: modelNameParamKey, Value: TestModel},
|
||||
{Key: apiKeyParamKey, Value: "mock"},
|
||||
},
|
||||
}
|
||||
@ -640,7 +621,7 @@ func (s *TextEmbeddingFunctionSuite) TestProcessSearchFloat32() {
|
||||
s.NoError(err)
|
||||
placeholderGroup := commonpb.PlaceholderGroup{}
|
||||
proto.Unmarshal(placeholderGroupBytes, &placeholderGroup)
|
||||
_, err = runner.ProcessSearch(&placeholderGroup)
|
||||
_, err = runner.ProcessSearch(context.Background(), &placeholderGroup)
|
||||
s.Error(err)
|
||||
}
|
||||
|
||||
@ -665,7 +646,7 @@ func (s *TextEmbeddingFunctionSuite) TestProcessSearchFloat32() {
|
||||
s.NoError(err)
|
||||
placeholderGroup := commonpb.PlaceholderGroup{}
|
||||
proto.Unmarshal(placeholderGroupBytes, &placeholderGroup)
|
||||
_, err = runner.ProcessSearch(&placeholderGroup)
|
||||
_, err = runner.ProcessSearch(context.Background(), &placeholderGroup)
|
||||
s.NoError(err)
|
||||
}
|
||||
}
|
||||
@ -696,7 +677,7 @@ func (s *TextEmbeddingFunctionSuite) TestProcessInsertInt8() {
|
||||
OutputFieldIds: []int64{102},
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: Provider, Value: cohereProvider},
|
||||
{Key: modelNameParamKey, Value: embedEnglishV30},
|
||||
{Key: modelNameParamKey, Value: TestModel},
|
||||
{Key: dimParamKey, Value: "4"},
|
||||
{Key: apiKeyParamKey, Value: "mock"},
|
||||
{Key: embeddingURLParamKey, Value: ts.URL},
|
||||
@ -706,7 +687,7 @@ func (s *TextEmbeddingFunctionSuite) TestProcessInsertInt8() {
|
||||
|
||||
{
|
||||
data := createData([]string{"sentence"})
|
||||
ret, err2 := runner.ProcessInsert(data)
|
||||
ret, err2 := runner.ProcessInsert(context.Background(), data)
|
||||
s.NoError(err2)
|
||||
s.Equal(1, len(ret))
|
||||
s.Equal(int64(4), ret[0].GetVectors().Dim)
|
||||
@ -743,7 +724,7 @@ func (s *TextEmbeddingFunctionSuite) TestUnsupportedVec() {
|
||||
OutputFieldIds: []int64{102},
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: Provider, Value: cohereProvider},
|
||||
{Key: modelNameParamKey, Value: embedEnglishV30},
|
||||
{Key: modelNameParamKey, Value: TestModel},
|
||||
{Key: dimParamKey, Value: "4"},
|
||||
{Key: apiKeyParamKey, Value: "mock"},
|
||||
{Key: embeddingURLParamKey, Value: "mock"},
|
||||
@ -778,7 +759,7 @@ func (s *TextEmbeddingFunctionSuite) TestProcessSearchInt8() {
|
||||
OutputFieldIds: []int64{102},
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: Provider, Value: cohereProvider},
|
||||
{Key: modelNameParamKey, Value: embedEnglishV30},
|
||||
{Key: modelNameParamKey, Value: TestModel},
|
||||
{Key: dimParamKey, Value: "4"},
|
||||
{Key: apiKeyParamKey, Value: "mock"},
|
||||
{Key: embeddingURLParamKey, Value: ts.URL},
|
||||
@ -807,7 +788,7 @@ func (s *TextEmbeddingFunctionSuite) TestProcessSearchInt8() {
|
||||
s.NoError(err)
|
||||
placeholderGroup := commonpb.PlaceholderGroup{}
|
||||
proto.Unmarshal(placeholderGroupBytes, &placeholderGroup)
|
||||
_, err = runner.ProcessSearch(&placeholderGroup)
|
||||
_, err = runner.ProcessSearch(context.Background(), &placeholderGroup)
|
||||
s.NoError(err)
|
||||
}
|
||||
}
|
||||
@ -880,7 +861,7 @@ func (s *TextEmbeddingFunctionSuite) TestProcessBulkInsertInt8() {
|
||||
OutputFieldIds: []int64{102},
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: Provider, Value: cohereProvider},
|
||||
{Key: modelNameParamKey, Value: embedEnglishV30},
|
||||
{Key: modelNameParamKey, Value: TestModel},
|
||||
{Key: dimParamKey, Value: "4"},
|
||||
{Key: apiKeyParamKey, Value: "mock"},
|
||||
{Key: embeddingURLParamKey, Value: ts.URL},
|
||||
|
||||
@ -42,7 +42,7 @@ func getVertexAIJsonKey() ([]byte, error) {
|
||||
jsonKeyPath := os.Getenv(vertexServiceAccountJSONEnv)
|
||||
jsonKey, err := os.ReadFile(jsonKeyPath)
|
||||
if err != nil {
|
||||
vtxKey.initErr = fmt.Errorf("Read service account json file failed, %v", err)
|
||||
vtxKey.initErr = fmt.Errorf("Vertexai: read service account json file failed, %v", err)
|
||||
return
|
||||
}
|
||||
vtxKey.jsonKey = jsonKey
|
||||
@ -56,16 +56,6 @@ const (
|
||||
vertexAISTS string = "STS"
|
||||
)
|
||||
|
||||
func checkTask(modelName string, task string) error {
|
||||
if task != vertexAIDocRetrival && task != vertexAICodeRetrival && task != vertexAISTS {
|
||||
return fmt.Errorf("Unsupport task %s, the supported list: [%s, %s, %s]", task, vertexAIDocRetrival, vertexAICodeRetrival, vertexAISTS)
|
||||
}
|
||||
if modelName == textMultilingualEmbedding002 && task == vertexAICodeRetrival {
|
||||
return fmt.Errorf("Model %s doesn't support %s task", textMultilingualEmbedding002, vertexAICodeRetrival)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type VertexAIEmbeddingProvider struct {
|
||||
fieldDim int64
|
||||
|
||||
@ -117,18 +107,11 @@ func NewVertexAIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSch
|
||||
if task == "" {
|
||||
task = vertexAIDocRetrival
|
||||
}
|
||||
if err := checkTask(modelName, task); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if location == "" {
|
||||
location = "us-central1"
|
||||
}
|
||||
|
||||
if modelName != textEmbedding005 && modelName != textMultilingualEmbedding002 {
|
||||
return nil, fmt.Errorf("Unsupported model: %s, only support [%s, %s]",
|
||||
modelName, textEmbedding005, textMultilingualEmbedding002)
|
||||
}
|
||||
url := fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:predict", location, projectID, location, modelName)
|
||||
var client *vertexai.VertexAIEmbedding
|
||||
if c == nil {
|
||||
|
||||
@ -67,7 +67,7 @@ func createVertexAIProvider(url string, schema *schemapb.FieldSchema) (textEmbed
|
||||
InputFieldIds: []int64{101},
|
||||
OutputFieldIds: []int64{102},
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: modelNameParamKey, Value: textEmbedding005},
|
||||
{Key: modelNameParamKey, Value: TestModel},
|
||||
{Key: locationParamKey, Value: "mock_local"},
|
||||
{Key: projectIDParamKey, Value: "mock_id"},
|
||||
{Key: taskTypeParamKey, Value: vertexAICodeRetrival},
|
||||
@ -174,21 +174,6 @@ func (s *VertexAITextEmbeddingProviderSuite) TestEmbeddingNubmerNotMatch() {
|
||||
s.Error(err2)
|
||||
}
|
||||
|
||||
func (s *VertexAITextEmbeddingProviderSuite) TestCheckVertexAITask() {
|
||||
err := checkTask(textMultilingualEmbedding002, "UnkownTask")
|
||||
s.Error(err)
|
||||
|
||||
// textMultilingualEmbedding002 not support vertexAICodeRetrival task
|
||||
err = checkTask(textMultilingualEmbedding002, vertexAICodeRetrival)
|
||||
s.Error(err)
|
||||
|
||||
err = checkTask(textEmbedding005, vertexAICodeRetrival)
|
||||
s.NoError(err)
|
||||
|
||||
err = checkTask(textMultilingualEmbedding002, vertexAISTS)
|
||||
s.NoError(err)
|
||||
}
|
||||
|
||||
func (s *VertexAITextEmbeddingProviderSuite) TestGetVertexAIJsonKey() {
|
||||
os.Setenv(vertexServiceAccountJSONEnv, "ErrorPath")
|
||||
defer os.Unsetenv(vertexServiceAccountJSONEnv)
|
||||
@ -205,7 +190,7 @@ func (s *VertexAITextEmbeddingProviderSuite) TestGetTaskType() {
|
||||
InputFieldIds: []int64{101},
|
||||
OutputFieldIds: []int64{102},
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: modelNameParamKey, Value: textEmbedding005},
|
||||
{Key: modelNameParamKey, Value: TestModel},
|
||||
{Key: projectIDParamKey, Value: "mock_id"},
|
||||
{Key: dimParamKey, Value: "4"},
|
||||
},
|
||||
@ -234,13 +219,6 @@ func (s *VertexAITextEmbeddingProviderSuite) TestGetTaskType() {
|
||||
s.Equal(provider.getTaskType(InsertMode), "SEMANTIC_SIMILARITY")
|
||||
s.Equal(provider.getTaskType(SearchMode), "SEMANTIC_SIMILARITY")
|
||||
}
|
||||
|
||||
// invalid task
|
||||
{
|
||||
functionSchema.Params[3] = &commonpb.KeyValuePair{Key: taskTypeParamKey, Value: "UnkownTask"}
|
||||
_, err := NewVertexAIEmbeddingProvider(s.schema.Fields[2], functionSchema, mockClient)
|
||||
s.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *VertexAITextEmbeddingProviderSuite) TestCreateVertexAIEmbeddingClient() {
|
||||
@ -259,7 +237,7 @@ func (s *VertexAITextEmbeddingProviderSuite) TestNewVertexAIEmbeddingProvider()
|
||||
InputFieldIds: []int64{101},
|
||||
OutputFieldIds: []int64{102},
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: modelNameParamKey, Value: textEmbedding005},
|
||||
{Key: modelNameParamKey, Value: TestModel},
|
||||
{Key: projectIDParamKey, Value: "mock_id"},
|
||||
{Key: dimParamKey, Value: "4"},
|
||||
},
|
||||
@ -269,9 +247,4 @@ func (s *VertexAITextEmbeddingProviderSuite) TestNewVertexAIEmbeddingProvider()
|
||||
s.NoError(err)
|
||||
s.True(provider.MaxBatch() > 0)
|
||||
s.Equal(provider.FieldDim(), int64(4))
|
||||
|
||||
// check model name
|
||||
functionSchema.Params[0] = &commonpb.KeyValuePair{Key: modelNameParamKey, Value: "UnkownModel"}
|
||||
_, err = NewVertexAIEmbeddingProvider(s.schema.Fields[2], functionSchema, mockClient)
|
||||
s.Error(err)
|
||||
}
|
||||
|
||||
@ -21,6 +21,7 @@ package function
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
@ -34,6 +35,7 @@ type VoyageAIEmbeddingProvider struct {
|
||||
client *voyageai.VoyageAIEmbedding
|
||||
modelName string
|
||||
embedDimParam int64
|
||||
truncate bool
|
||||
embdType embeddingType
|
||||
outputType string
|
||||
|
||||
@ -62,8 +64,10 @@ func NewVoyageAIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSch
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var apiKey, url, modelName string
|
||||
apiKey, url := parseAKAndURL(functionSchema.Params)
|
||||
var modelName string
|
||||
dim := int64(0)
|
||||
truncate := false
|
||||
|
||||
for _, param := range functionSchema.Params {
|
||||
switch strings.ToLower(param.Key) {
|
||||
@ -75,27 +79,14 @@ func NewVoyageAIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSch
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
case apiKeyParamKey:
|
||||
apiKey = param.Value
|
||||
case embeddingURLParamKey:
|
||||
url = param.Value
|
||||
case truncationParamKey:
|
||||
if truncate, err = strconv.ParseBool(param.Value); err != nil {
|
||||
return nil, fmt.Errorf("[%s param's value: %s] is invalid, only supports: [true/false]", truncationParamKey, param.Value)
|
||||
}
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
if modelName != voyage3Large && modelName != voyage3 && modelName != voyage3Lite && modelName != voyageCode3 && modelName != voyageFinance2 && modelName != voyageLaw2 && modelName != voyageCode2 {
|
||||
return nil, fmt.Errorf("Unsupported model: %s, only support [%s, %s, %s, %s, %s, %s, %s]",
|
||||
modelName, voyage3Large, voyage3, voyage3Lite, voyageCode3, voyageFinance2, voyageLaw2, voyageCode2)
|
||||
}
|
||||
|
||||
if dim != 0 {
|
||||
if modelName != voyage3Large && modelName != voyageCode3 {
|
||||
return nil, fmt.Errorf("VoyageAI text embedding model: [%s] doesn't supports dim parameter, only [%s, %s] support it.", modelName, voyage3, voyageCode3)
|
||||
}
|
||||
if dim != 1024 && dim != 256 && dim != 512 && dim != 2048 {
|
||||
return nil, fmt.Errorf("VoyageAI text embedding model's dim only supports 2048, 1024 (default), 512, and 256.")
|
||||
}
|
||||
}
|
||||
c, err := createVoyageAIEmbeddingClient(apiKey, url)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -113,16 +104,11 @@ func NewVoyageAIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSch
|
||||
return "int8"
|
||||
}()
|
||||
|
||||
if outputType == "int8" {
|
||||
if modelName != voyage3Large && modelName != voyageCode3 {
|
||||
return nil, fmt.Errorf("VoyageAI text embedding model: [%s] doesn't supports int8 output_dtype, only [%s, %s] support it.", modelName, voyage3, voyageCode3)
|
||||
}
|
||||
}
|
||||
|
||||
provider := VoyageAIEmbeddingProvider{
|
||||
client: c,
|
||||
fieldDim: fieldDim,
|
||||
modelName: modelName,
|
||||
truncate: truncate,
|
||||
embedDimParam: dim,
|
||||
embdType: embdType,
|
||||
outputType: outputType,
|
||||
@ -155,7 +141,7 @@ func (provider *VoyageAIEmbeddingProvider) CallEmbedding(texts []string, mode Te
|
||||
if end > numRows {
|
||||
end = numRows
|
||||
}
|
||||
r, err := provider.client.Embedding(provider.modelName, texts[i:end], int(provider.embedDimParam), textType, provider.outputType, provider.timeoutSec)
|
||||
r, err := provider.client.Embedding(provider.modelName, texts[i:end], int(provider.embedDimParam), textType, provider.outputType, provider.truncate, provider.timeoutSec)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -69,7 +69,7 @@ func createVoyageAIProvider(url string, schema *schemapb.FieldSchema, providerNa
|
||||
InputFieldIds: []int64{101},
|
||||
OutputFieldIds: []int64{102},
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: modelNameParamKey, Value: voyage3Large},
|
||||
{Key: modelNameParamKey, Value: TestModel},
|
||||
{Key: apiKeyParamKey, Value: "mock"},
|
||||
{Key: embeddingURLParamKey, Value: url},
|
||||
{Key: dimParamKey, Value: "1024"},
|
||||
@ -136,26 +136,6 @@ func (s *VoyageAITextEmbeddingProviderSuite) TestEmbeddingIn8() {
|
||||
s.NoError(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Invalid model name
|
||||
{
|
||||
functionSchema := &schemapb.FunctionSchema{
|
||||
Name: "test",
|
||||
Type: schemapb.FunctionType_TextEmbedding,
|
||||
InputFieldNames: []string{"text"},
|
||||
OutputFieldNames: []string{"vector"},
|
||||
InputFieldIds: []int64{101},
|
||||
OutputFieldIds: []int64{102},
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: modelNameParamKey, Value: voyage3Lite},
|
||||
{Key: apiKeyParamKey, Value: "mock"},
|
||||
{Key: embeddingURLParamKey, Value: ts.URL},
|
||||
{Key: dimParamKey, Value: "1024"},
|
||||
},
|
||||
}
|
||||
_, err := NewCohereEmbeddingProvider(int8VecField, functionSchema)
|
||||
s.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *VoyageAITextEmbeddingProviderSuite) TestEmbeddingDimNotMatch() {
|
||||
@ -318,10 +298,11 @@ func (s *VoyageAITextEmbeddingProviderSuite) TestNewVoyageAIEmbeddingProvider()
|
||||
InputFieldIds: []int64{101},
|
||||
OutputFieldIds: []int64{102},
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: modelNameParamKey, Value: voyage3Large},
|
||||
{Key: modelNameParamKey, Value: TestModel},
|
||||
{Key: apiKeyParamKey, Value: "mock"},
|
||||
{Key: embeddingURLParamKey, Value: "mock"},
|
||||
{Key: dimParamKey, Value: "1024"},
|
||||
{Key: truncationParamKey, Value: "true"},
|
||||
},
|
||||
}
|
||||
provider, err := NewVoyageAIEmbeddingProvider(s.schema.Fields[2], functionSchema)
|
||||
@ -329,11 +310,12 @@ func (s *VoyageAITextEmbeddingProviderSuite) TestNewVoyageAIEmbeddingProvider()
|
||||
s.Equal(provider.FieldDim(), int64(1024))
|
||||
s.True(provider.MaxBatch() > 0)
|
||||
|
||||
// Invalid model
|
||||
// Invalid truncation
|
||||
{
|
||||
functionSchema.Params[0] = &commonpb.KeyValuePair{Key: modelNameParamKey, Value: "UnkownModel"}
|
||||
functionSchema.Params[4] = &commonpb.KeyValuePair{Key: truncationParamKey, Value: "Invalid"}
|
||||
_, err := NewVoyageAIEmbeddingProvider(s.schema.Fields[2], functionSchema)
|
||||
s.Error(err)
|
||||
functionSchema.Params[4] = &commonpb.KeyValuePair{Key: truncationParamKey, Value: "false"}
|
||||
}
|
||||
|
||||
// Invalid dim
|
||||
|
||||
@ -127,6 +127,11 @@ const (
|
||||
cgoTypeLabelName = `cgo_type`
|
||||
queueTypeLabelName = `queue_type`
|
||||
|
||||
// model function/UDF labels
|
||||
functionTypeName = "function_type_name"
|
||||
functionProvider = "function_provider"
|
||||
functionName = "function_name"
|
||||
|
||||
// entities label
|
||||
LoadedLabel = "loaded"
|
||||
NumEntitiesAllLabel = "all"
|
||||
|
||||
@ -445,6 +445,15 @@ var (
|
||||
Help: "the latency of parse expression",
|
||||
Buckets: buckets,
|
||||
}, []string{nodeIDLabelName, functionLabelName, statusLabelName})
|
||||
// ProxyFunctionlatency records the latency of function
|
||||
ProxyFunctionlatency = prometheus.NewHistogramVec(
|
||||
prometheus.HistogramOpts{
|
||||
Namespace: milvusNamespace,
|
||||
Subsystem: typeutil.ProxyRole,
|
||||
Name: "function_udf_call_latency",
|
||||
Help: "latency of function call",
|
||||
Buckets: buckets,
|
||||
}, []string{nodeIDLabelName, collectionName, functionTypeName, functionProvider, functionName})
|
||||
)
|
||||
|
||||
// RegisterProxy registers Proxy metrics
|
||||
@ -512,6 +521,8 @@ func RegisterProxy(registry *prometheus.Registry) {
|
||||
|
||||
registry.MustRegister(ProxyParseExpressionLatency)
|
||||
|
||||
registry.MustRegister(ProxyFunctionlatency)
|
||||
|
||||
RegisterStreamingServiceClient(registry)
|
||||
}
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user