enhance: Add embedding model and schema field type checks (#46421)

https://github.com/milvus-io/milvus/issues/46415

- Add output type validation when creating functions
- Fix improper error handling in bulk insert tasks

Signed-off-by: junjie.jiang <junjie.jiang@zilliz.com>
This commit is contained in:
junjiejiangjjj 2025-12-19 11:05:19 +08:00 committed by GitHub
parent 7e4f87e351
commit 617a77b0bd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 56 additions and 1 deletions

View File

@ -244,6 +244,7 @@ func (t *ImportTask) importFile(reader importutilv2.Reader) error {
if !importutilv2.IsBackup(t.req.GetOptions()) {
err = RunEmbeddingFunction(t, data)
if err != nil {
log.Warn("run embedding function failed", WrapLogFields(t, zap.Error(err))...)
return err
}
}

View File

@ -424,6 +424,7 @@ func FillDynamicData(schema *schemapb.CollectionSchema, data *storage.InsertData
}
func RunEmbeddingFunction(task *ImportTask, data *storage.InsertData) error {
log.Info("start to run embedding function")
if err := RunDenseEmbedding(task, data); err != nil {
return err
}
@ -435,14 +436,18 @@ func RunEmbeddingFunction(task *ImportTask, data *storage.InsertData) error {
}
func RunDenseEmbedding(task *ImportTask, data *storage.InsertData) error {
log.Info("start to run dense embedding")
schema := task.GetSchema()
allowNonBM25Outputs := common.GetCollectionAllowInsertNonBM25FunctionOutputs(schema.Properties)
log.Info("allowNonBM25Outputs", zap.Any("allowNonBM25Outputs", allowNonBM25Outputs))
fieldIDs := lo.Keys(data.Data)
needProcessFunctions, err := typeutil.GetNeedProcessFunctions(fieldIDs, schema.Functions, allowNonBM25Outputs, false)
if err != nil {
return err
}
log.Info("needProcessFunctions", zap.Any("needProcessFunctions", needProcessFunctions))
if embedding.HasNonBM25Functions(schema.Functions, []int64{}) {
log.Info("has non bm25 functions")
extraInfo := &models.ModelExtraInfo{
ClusterID: task.req.ClusterID,
DBName: task.req.Schema.DbName,
@ -454,11 +459,13 @@ func RunDenseEmbedding(task *ImportTask, data *storage.InsertData) error {
if err := exec.ProcessBulkInsert(context.Background(), data); err != nil {
return err
}
log.Info("end to run dense embedding")
}
return nil
}
func RunBm25Function(task *ImportTask, data *storage.InsertData) error {
log.Info("start to run bm25 function")
fns := task.GetSchema().GetFunctions()
for _, fn := range fns {
runner, err := function.NewFunctionRunner(task.GetSchema(), fn)

View File

@ -277,7 +277,7 @@ func (executor *FunctionExecutor) ProcessBulkInsert(ctx context.Context, data *s
for _, runner := range executor.runners {
output, err := executor.processSingleBulkInsert(ctx, runner, data)
if err != nil {
return nil
return err
}
for k, v := range output {
data.Data[k] = v

View File

@ -161,8 +161,14 @@ func (runner *TextEmbeddingFunction) Check(ctx context.Context) error {
switch embds := embds.(type) {
case [][]float32:
dim = len(embds[0])
if runner.GetOutputFields()[0].DataType != schemapb.DataType_FloatVector {
return fmt.Errorf("Embedding model output and field type mismatch, model output is %s, field type is %s", schemapb.DataType_name[int32(schemapb.DataType_FloatVector)], schemapb.DataType_name[int32(runner.GetOutputFields()[0].DataType)])
}
case [][]int8:
dim = len(embds[0])
if runner.GetOutputFields()[0].DataType != schemapb.DataType_Int8Vector {
return fmt.Errorf("Embedding model output and field type mismatch, model output is %s, field type is %s", schemapb.DataType_name[int32(schemapb.DataType_Int8Vector)], schemapb.DataType_name[int32(runner.GetOutputFields()[0].DataType)])
}
default:
return fmt.Errorf("Unsupport embedding type: %s", reflect.TypeOf(embds).String())
}

View File

@ -1023,3 +1023,44 @@ func (s *TextEmbeddingFunctionSuite) TestDisable() {
}, &models.ModelExtraInfo{ClusterID: "test-cluster", DBName: "test-db"})
s.ErrorContains(err, "Text embedding model provider [openai] is disabled")
}
func (s *TextEmbeddingFunctionSuite) TestCheck() {
schema := &schemapb.CollectionSchema{
Name: "test",
Fields: []*schemapb.FieldSchema{
{FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64},
{FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar},
{
FieldID: 102, Name: "vector", DataType: schemapb.DataType_Int8Vector,
TypeParams: []*commonpb.KeyValuePair{
{Key: "dim", Value: "4"},
},
},
},
}
ts := CreateOpenAIEmbeddingServer()
defer ts.Close()
paramtable.Get().FunctionCfg.TextEmbeddingProviders.GetFunc = func() map[string]string {
key := openAIProvider + "." + models.URLParamKey
return map[string]string{
key: ts.URL,
}
}
runner, err := NewTextEmbeddingFunction(schema, &schemapb.FunctionSchema{
Name: "test",
Type: schemapb.FunctionType_TextEmbedding,
InputFieldNames: []string{"text"},
OutputFieldNames: []string{"vector"},
InputFieldIds: []int64{101},
OutputFieldIds: []int64{102},
Params: []*commonpb.KeyValuePair{
{Key: Provider, Value: openAIProvider},
{Key: models.ModelNameParamKey, Value: "text-embedding-ada-002"},
{Key: models.DimParamKey, Value: "4"},
{Key: models.CredentialParamKey, Value: "mock"},
},
}, &models.ModelExtraInfo{ClusterID: "test-cluster", DBName: "test-db"})
s.NoError(err)
err = runner.Check(context.Background())
s.ErrorContains(err, "Embedding model output and field type mismatch, model output is FloatVector, field type is Int8Vector")
}