diff --git a/internal/datanode/importv2/task_import.go b/internal/datanode/importv2/task_import.go index 702a2d62d3..7491007f77 100644 --- a/internal/datanode/importv2/task_import.go +++ b/internal/datanode/importv2/task_import.go @@ -244,6 +244,7 @@ func (t *ImportTask) importFile(reader importutilv2.Reader) error { if !importutilv2.IsBackup(t.req.GetOptions()) { err = RunEmbeddingFunction(t, data) if err != nil { + log.Warn("run embedding function failed", WrapLogFields(t, zap.Error(err))...) return err } } diff --git a/internal/datanode/importv2/util.go b/internal/datanode/importv2/util.go index 551e015f1b..1e103f3bb2 100644 --- a/internal/datanode/importv2/util.go +++ b/internal/datanode/importv2/util.go @@ -424,6 +424,7 @@ func FillDynamicData(schema *schemapb.CollectionSchema, data *storage.InsertData } func RunEmbeddingFunction(task *ImportTask, data *storage.InsertData) error { + log.Info("start to run embedding function") if err := RunDenseEmbedding(task, data); err != nil { return err } @@ -435,14 +436,18 @@ func RunEmbeddingFunction(task *ImportTask, data *storage.InsertData) error { } func RunDenseEmbedding(task *ImportTask, data *storage.InsertData) error { + log.Info("start to run dense embedding") schema := task.GetSchema() allowNonBM25Outputs := common.GetCollectionAllowInsertNonBM25FunctionOutputs(schema.Properties) + log.Info("allowNonBM25Outputs", zap.Any("allowNonBM25Outputs", allowNonBM25Outputs)) fieldIDs := lo.Keys(data.Data) needProcessFunctions, err := typeutil.GetNeedProcessFunctions(fieldIDs, schema.Functions, allowNonBM25Outputs, false) if err != nil { return err } + log.Info("needProcessFunctions", zap.Any("needProcessFunctions", needProcessFunctions)) if embedding.HasNonBM25Functions(schema.Functions, []int64{}) { + log.Info("has non bm25 functions") extraInfo := &models.ModelExtraInfo{ ClusterID: task.req.ClusterID, DBName: task.req.Schema.DbName, @@ -454,11 +459,13 @@ func RunDenseEmbedding(task *ImportTask, data *storage.InsertData) error { if err := exec.ProcessBulkInsert(context.Background(), data); err != nil { return err } + log.Info("end to run dense embedding") } return nil } func RunBm25Function(task *ImportTask, data *storage.InsertData) error { + log.Info("start to run bm25 function") fns := task.GetSchema().GetFunctions() for _, fn := range fns { runner, err := function.NewFunctionRunner(task.GetSchema(), fn) diff --git a/internal/util/function/embedding/function_executor.go b/internal/util/function/embedding/function_executor.go index 3e0be2e4b2..1e6a350e7c 100644 --- a/internal/util/function/embedding/function_executor.go +++ b/internal/util/function/embedding/function_executor.go @@ -277,7 +277,7 @@ func (executor *FunctionExecutor) ProcessBulkInsert(ctx context.Context, data *s for _, runner := range executor.runners { output, err := executor.processSingleBulkInsert(ctx, runner, data) if err != nil { - return nil + return err } for k, v := range output { data.Data[k] = v diff --git a/internal/util/function/embedding/text_embedding_function.go b/internal/util/function/embedding/text_embedding_function.go index 1435336440..4146ef750c 100644 --- a/internal/util/function/embedding/text_embedding_function.go +++ b/internal/util/function/embedding/text_embedding_function.go @@ -161,8 +161,14 @@ func (runner *TextEmbeddingFunction) Check(ctx context.Context) error { switch embds := embds.(type) { case [][]float32: dim = len(embds[0]) + if runner.GetOutputFields()[0].DataType != schemapb.DataType_FloatVector { + return fmt.Errorf("Embedding model output and field type mismatch, model output is %s, field type is %s", schemapb.DataType_name[int32(schemapb.DataType_FloatVector)], schemapb.DataType_name[int32(runner.GetOutputFields()[0].DataType)]) + } case [][]int8: dim = len(embds[0]) + if runner.GetOutputFields()[0].DataType != schemapb.DataType_Int8Vector { + return fmt.Errorf("Embedding model output and field type mismatch, model output is %s, field type is %s", schemapb.DataType_name[int32(schemapb.DataType_Int8Vector)], schemapb.DataType_name[int32(runner.GetOutputFields()[0].DataType)]) + } default: return fmt.Errorf("Unsupport embedding type: %s", reflect.TypeOf(embds).String()) } diff --git a/internal/util/function/embedding/text_embedding_function_test.go b/internal/util/function/embedding/text_embedding_function_test.go index 8537abe6a4..68ce158c66 100644 --- a/internal/util/function/embedding/text_embedding_function_test.go +++ b/internal/util/function/embedding/text_embedding_function_test.go @@ -1023,3 +1023,44 @@ func (s *TextEmbeddingFunctionSuite) TestDisable() { }, &models.ModelExtraInfo{ClusterID: "test-cluster", DBName: "test-db"}) s.ErrorContains(err, "Text embedding model provider [openai] is disabled") } + +func (s *TextEmbeddingFunctionSuite) TestCheck() { + schema := &schemapb.CollectionSchema{ + Name: "test", + Fields: []*schemapb.FieldSchema{ + {FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64}, + {FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar}, + { + FieldID: 102, Name: "vector", DataType: schemapb.DataType_Int8Vector, + TypeParams: []*commonpb.KeyValuePair{ + {Key: "dim", Value: "4"}, + }, + }, + }, + } + ts := CreateOpenAIEmbeddingServer() + defer ts.Close() + paramtable.Get().FunctionCfg.TextEmbeddingProviders.GetFunc = func() map[string]string { + key := openAIProvider + "." + models.URLParamKey + return map[string]string{ + key: ts.URL, + } + } + runner, err := NewTextEmbeddingFunction(schema, &schemapb.FunctionSchema{ + Name: "test", + Type: schemapb.FunctionType_TextEmbedding, + InputFieldNames: []string{"text"}, + OutputFieldNames: []string{"vector"}, + InputFieldIds: []int64{101}, + OutputFieldIds: []int64{102}, + Params: []*commonpb.KeyValuePair{ + {Key: Provider, Value: openAIProvider}, + {Key: models.ModelNameParamKey, Value: "text-embedding-ada-002"}, + {Key: models.DimParamKey, Value: "4"}, + {Key: models.CredentialParamKey, Value: "mock"}, + }, + }, &models.ModelExtraInfo{ClusterID: "test-cluster", DBName: "test-db"}) + s.NoError(err) + err = runner.Check(context.Background()) + s.ErrorContains(err, "Embedding model output and field type mismatch, model output is FloatVector, field type is Int8Vector") +}