mirror of
https://gitee.com/milvus-io/milvus.git
synced 2026-01-06 02:42:53 +08:00
enhance: Add embedding model and schema field type checks (#46421)
https://github.com/milvus-io/milvus/issues/46415
- Add output type validation when creating functions
- Fix improper error handling in bulk insert tasks
Signed-off-by: junjie.jiang <junjie.jiang@zilliz.com>
This commit is contained in:
parent
7e4f87e351
commit
617a77b0bd
@ -244,6 +244,7 @@ func (t *ImportTask) importFile(reader importutilv2.Reader) error {
|
||||
if !importutilv2.IsBackup(t.req.GetOptions()) {
|
||||
err = RunEmbeddingFunction(t, data)
|
||||
if err != nil {
|
||||
log.Warn("run embedding function failed", WrapLogFields(t, zap.Error(err))...)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
@ -424,6 +424,7 @@ func FillDynamicData(schema *schemapb.CollectionSchema, data *storage.InsertData
|
||||
}
|
||||
|
||||
func RunEmbeddingFunction(task *ImportTask, data *storage.InsertData) error {
|
||||
log.Info("start to run embedding function")
|
||||
if err := RunDenseEmbedding(task, data); err != nil {
|
||||
return err
|
||||
}
|
||||
@ -435,14 +436,18 @@ func RunEmbeddingFunction(task *ImportTask, data *storage.InsertData) error {
|
||||
}
|
||||
|
||||
func RunDenseEmbedding(task *ImportTask, data *storage.InsertData) error {
|
||||
log.Info("start to run dense embedding")
|
||||
schema := task.GetSchema()
|
||||
allowNonBM25Outputs := common.GetCollectionAllowInsertNonBM25FunctionOutputs(schema.Properties)
|
||||
log.Info("allowNonBM25Outputs", zap.Any("allowNonBM25Outputs", allowNonBM25Outputs))
|
||||
fieldIDs := lo.Keys(data.Data)
|
||||
needProcessFunctions, err := typeutil.GetNeedProcessFunctions(fieldIDs, schema.Functions, allowNonBM25Outputs, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
log.Info("needProcessFunctions", zap.Any("needProcessFunctions", needProcessFunctions))
|
||||
if embedding.HasNonBM25Functions(schema.Functions, []int64{}) {
|
||||
log.Info("has non bm25 functions")
|
||||
extraInfo := &models.ModelExtraInfo{
|
||||
ClusterID: task.req.ClusterID,
|
||||
DBName: task.req.Schema.DbName,
|
||||
@ -454,11 +459,13 @@ func RunDenseEmbedding(task *ImportTask, data *storage.InsertData) error {
|
||||
if err := exec.ProcessBulkInsert(context.Background(), data); err != nil {
|
||||
return err
|
||||
}
|
||||
log.Info("end to run dense embedding")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func RunBm25Function(task *ImportTask, data *storage.InsertData) error {
|
||||
log.Info("start to run bm25 function")
|
||||
fns := task.GetSchema().GetFunctions()
|
||||
for _, fn := range fns {
|
||||
runner, err := function.NewFunctionRunner(task.GetSchema(), fn)
|
||||
|
||||
@ -277,7 +277,7 @@ func (executor *FunctionExecutor) ProcessBulkInsert(ctx context.Context, data *s
|
||||
for _, runner := range executor.runners {
|
||||
output, err := executor.processSingleBulkInsert(ctx, runner, data)
|
||||
if err != nil {
|
||||
- return nil
|
||||
+ return err
|
||||
}
|
||||
for k, v := range output {
|
||||
data.Data[k] = v
|
||||
|
||||
@ -161,8 +161,14 @@ func (runner *TextEmbeddingFunction) Check(ctx context.Context) error {
|
||||
switch embds := embds.(type) {
|
||||
case [][]float32:
|
||||
dim = len(embds[0])
|
||||
if runner.GetOutputFields()[0].DataType != schemapb.DataType_FloatVector {
|
||||
return fmt.Errorf("Embedding model output and field type mismatch, model output is %s, field type is %s", schemapb.DataType_name[int32(schemapb.DataType_FloatVector)], schemapb.DataType_name[int32(runner.GetOutputFields()[0].DataType)])
|
||||
}
|
||||
case [][]int8:
|
||||
dim = len(embds[0])
|
||||
if runner.GetOutputFields()[0].DataType != schemapb.DataType_Int8Vector {
|
||||
return fmt.Errorf("Embedding model output and field type mismatch, model output is %s, field type is %s", schemapb.DataType_name[int32(schemapb.DataType_Int8Vector)], schemapb.DataType_name[int32(runner.GetOutputFields()[0].DataType)])
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("Unsupport embedding type: %s", reflect.TypeOf(embds).String())
|
||||
}
|
||||
|
||||
@ -1023,3 +1023,44 @@ func (s *TextEmbeddingFunctionSuite) TestDisable() {
|
||||
}, &models.ModelExtraInfo{ClusterID: "test-cluster", DBName: "test-db"})
|
||||
s.ErrorContains(err, "Text embedding model provider [openai] is disabled")
|
||||
}
|
||||
|
||||
// TestCheck verifies that Check reports a type mismatch when the function's
// output field is declared as Int8Vector but the model output is FloatVector.
// The schema below declares the output field "vector" (FieldID 102) as
// Int8Vector with dim 4, and the test expects Check to return an error that
// names both types.
func (s *TextEmbeddingFunctionSuite) TestCheck() {
|
||||
schema := &schemapb.CollectionSchema{
|
||||
Name: "test",
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64},
|
||||
{FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar},
|
||||
{
|
||||
FieldID: 102, Name: "vector", DataType: schemapb.DataType_Int8Vector,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{Key: "dim", Value: "4"},
|
||||
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
// Start the test embedding server and close it when the test returns.
// NOTE(review): presumably this server replies with float32 embeddings,
// which is what the expected error message below implies — confirm.
ts := CreateOpenAIEmbeddingServer()
|
||||
defer ts.Close()
|
||||
// Point the openai provider's URL parameter at the test server.
paramtable.Get().FunctionCfg.TextEmbeddingProviders.GetFunc = func() map[string]string {
|
||||
key := openAIProvider + "." + models.URLParamKey
|
||||
return map[string]string{
|
||||
key: ts.URL,
|
||||
}
|
||||
}
|
||||
// Build a text embedding function mapping field "text" (101) to "vector" (102).
runner, err := NewTextEmbeddingFunction(schema, &schemapb.FunctionSchema{
|
||||
Name: "test",
|
||||
Type: schemapb.FunctionType_TextEmbedding,
|
||||
InputFieldNames: []string{"text"},
|
||||
OutputFieldNames: []string{"vector"},
|
||||
InputFieldIds: []int64{101},
|
||||
OutputFieldIds: []int64{102},
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: Provider, Value: openAIProvider},
|
||||
{Key: models.ModelNameParamKey, Value: "text-embedding-ada-002"},
|
||||
{Key: models.DimParamKey, Value: "4"},
|
||||
{Key: models.CredentialParamKey, Value: "mock"},
|
||||
},
|
||||
}, &models.ModelExtraInfo{ClusterID: "test-cluster", DBName: "test-db"})
|
||||
// Construction itself is expected to succeed; the mismatch is reported by Check.
s.NoError(err)
|
||||
err = runner.Check(context.Background())
|
||||
s.ErrorContains(err, "Embedding model output and field type mismatch, model output is FloatVector, field type is Int8Vector")
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user