enhance: add support for controlling function output field insertion (#44162)

#44053

Signed-off-by: junjie.jiang <junjie.jiang@zilliz.com>
This commit is contained in:
junjiejiangjjj 2025-09-24 17:26:04 +08:00 committed by GitHub
parent 2ed7d35783
commit f07979f91d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 217 additions and 54 deletions

View File

@ -505,6 +505,7 @@ func (s *SchedulerSuite) TestScheduler_ImportFileWithFunction() {
}, },
}, },
}, },
Properties: []*commonpb.KeyValuePair{{Key: common.CollectionAllowInsertNonBM25FunctionOutputs, Value: "true"}},
} }
var once sync.Once var once sync.Once

View File

@ -411,10 +411,11 @@ func FillDynamicData(schema *schemapb.CollectionSchema, data *storage.InsertData
} }
func RunEmbeddingFunction(task *ImportTask, data *storage.InsertData) error { func RunEmbeddingFunction(task *ImportTask, data *storage.InsertData) error {
if err := RunBm25Function(task, data); err != nil { if err := RunDenseEmbedding(task, data); err != nil {
return err return err
} }
if err := RunDenseEmbedding(task, data); err != nil {
if err := RunBm25Function(task, data); err != nil {
return err return err
} }
return nil return nil
@ -422,8 +423,14 @@ func RunEmbeddingFunction(task *ImportTask, data *storage.InsertData) error {
func RunDenseEmbedding(task *ImportTask, data *storage.InsertData) error { func RunDenseEmbedding(task *ImportTask, data *storage.InsertData) error {
schema := task.GetSchema() schema := task.GetSchema()
allowNonBM25Outputs := common.GetCollectionAllowInsertNonBM25FunctionOutputs(schema.Properties)
fieldIDs := lo.Keys(data.Data)
needProcessFunctions, err := typeutil.GetNeedProcessFunctions(fieldIDs, schema.Functions, allowNonBM25Outputs, false)
if err != nil {
return err
}
if embedding.HasNonBM25Functions(schema.Functions, []int64{}) { if embedding.HasNonBM25Functions(schema.Functions, []int64{}) {
exec, err := embedding.NewFunctionExecutor(schema) exec, err := embedding.NewFunctionExecutor(schema, needProcessFunctions)
if err != nil { if err != nil {
return err return err
} }

View File

@ -11,7 +11,6 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/allocator" "github.com/milvus-io/milvus/internal/allocator"
"github.com/milvus-io/milvus/internal/util/function/embedding"
"github.com/milvus-io/milvus/pkg/v2/common" "github.com/milvus-io/milvus/pkg/v2/common"
"github.com/milvus-io/milvus/pkg/v2/log" "github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/metrics" "github.com/milvus-io/milvus/pkg/v2/metrics"
@ -165,20 +164,10 @@ func (it *insertTask) PreExecute(ctx context.Context) error {
} }
it.schema = schema.CollectionSchema it.schema = schema.CollectionSchema
// Calculate embedding fields if err := genFunctionFields(ctx, it.insertMsg, schema, false); err != nil {
if embedding.HasNonBM25Functions(schema.CollectionSchema.Functions, []int64{}) {
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-Insert-call-function-udf")
defer sp.End()
exec, err := embedding.NewFunctionExecutor(schema.CollectionSchema)
if err != nil {
return err return err
} }
sp.AddEvent("Create-function-udf")
if err := exec.ProcessInsert(ctx, it.insertMsg); err != nil {
return err
}
sp.AddEvent("Call-function-udf")
}
rowNums := uint32(it.insertMsg.NRows()) rowNums := uint32(it.insertMsg.NRows())
// set insertTask.rowIDs // set insertTask.rowIDs
var rowIDBegin UniqueID var rowIDBegin UniqueID

View File

@ -472,7 +472,7 @@ func (t *searchTask) initAdvancedSearchRequest(ctx context.Context) error {
if embedding.HasNonBM25Functions(t.schema.CollectionSchema.Functions, queryFieldIDs) { if embedding.HasNonBM25Functions(t.schema.CollectionSchema.Functions, queryFieldIDs) {
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-AdvancedSearch-call-function-udf") ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-AdvancedSearch-call-function-udf")
defer sp.End() defer sp.End()
exec, err := embedding.NewFunctionExecutor(t.schema.CollectionSchema) exec, err := embedding.NewFunctionExecutor(t.schema.CollectionSchema, nil)
if err != nil { if err != nil {
return err return err
} }
@ -591,7 +591,7 @@ func (t *searchTask) initSearchRequest(ctx context.Context) error {
if embedding.HasNonBM25Functions(t.schema.CollectionSchema.Functions, []int64{queryInfo.GetQueryFieldId()}) { if embedding.HasNonBM25Functions(t.schema.CollectionSchema.Functions, []int64{queryInfo.GetQueryFieldId()}) {
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-Search-call-function-udf") ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-Search-call-function-udf")
defer sp.End() defer sp.End()
exec, err := embedding.NewFunctionExecutor(t.schema.CollectionSchema) exec, err := embedding.NewFunctionExecutor(t.schema.CollectionSchema, nil)
if err != nil { if err != nil {
return err return err
} }

View File

@ -32,7 +32,6 @@ import (
"github.com/milvus-io/milvus/internal/allocator" "github.com/milvus-io/milvus/internal/allocator"
"github.com/milvus-io/milvus/internal/parser/planparserv2" "github.com/milvus-io/milvus/internal/parser/planparserv2"
"github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/function/embedding"
"github.com/milvus-io/milvus/pkg/v2/common" "github.com/milvus-io/milvus/pkg/v2/common"
"github.com/milvus-io/milvus/pkg/v2/log" "github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/metrics" "github.com/milvus-io/milvus/pkg/v2/metrics"
@ -354,6 +353,7 @@ func (it *upsertTask) queryPreExecute(ctx context.Context) error {
// 2. merge field data on update semantic // 2. merge field data on update semantic
it.deletePKs = &schemapb.IDs{} it.deletePKs = &schemapb.IDs{}
it.insertFieldData = typeutil.PrepareResultFieldData(existFieldData, int64(upsertIDSize)) it.insertFieldData = typeutil.PrepareResultFieldData(existFieldData, int64(upsertIDSize))
if len(updateIdxInUpsert) > 0 { if len(updateIdxInUpsert) > 0 {
// Note: For fields containing default values, default values need to be set according to valid data during insertion, // Note: For fields containing default values, default values need to be set according to valid data during insertion,
// but query results fields do not set valid data when returning default value fields, // but query results fields do not set valid data when returning default value fields,
@ -738,10 +738,7 @@ func (it *upsertTask) insertPreExecute(ctx context.Context) error {
return err return err
} }
bm25Fields := typeutil.NewSet[string](GetFunctionOutputFields(it.schema.CollectionSchema)...) bm25Fields := typeutil.NewSet[string](GetBM25FunctionOutputFields(it.schema.CollectionSchema)...)
// Calculate embedding fields
if embedding.HasNonBM25Functions(it.schema.CollectionSchema.Functions, []int64{}) {
if it.req.PartialUpdate { if it.req.PartialUpdate {
// remove the old bm25 fields // remove the old bm25 fields
ret := make([]*schemapb.FieldData, 0) ret := make([]*schemapb.FieldData, 0)
@ -753,18 +750,7 @@ func (it *upsertTask) insertPreExecute(ctx context.Context) error {
} }
it.upsertMsg.InsertMsg.FieldsData = ret it.upsertMsg.InsertMsg.FieldsData = ret
} }
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-Proxy-Upsert-insertPreExecute-call-function-udf")
defer sp.End()
exec, err := embedding.NewFunctionExecutor(it.schema.CollectionSchema)
if err != nil {
return err
}
sp.AddEvent("Create-function-udf")
if err := exec.ProcessInsert(ctx, it.upsertMsg.InsertMsg); err != nil {
return err
}
sp.AddEvent("Call-function-udf")
}
rowNums := uint32(it.upsertMsg.InsertMsg.NRows()) rowNums := uint32(it.upsertMsg.InsertMsg.NRows())
// set upsertTask.insertRequest.rowIDs // set upsertTask.insertRequest.rowIDs
tr := timerecord.NewTimeRecorder("applyPK") tr := timerecord.NewTimeRecorder("applyPK")
@ -808,8 +794,7 @@ func (it *upsertTask) insertPreExecute(ctx context.Context) error {
} }
} }
err := checkAndFlattenStructFieldData(it.schema.CollectionSchema, it.upsertMsg.InsertMsg) if err := checkAndFlattenStructFieldData(it.schema.CollectionSchema, it.upsertMsg.InsertMsg); err != nil {
if err != nil {
return err return err
} }
@ -817,6 +802,7 @@ func (it *upsertTask) insertPreExecute(ctx context.Context) error {
// use the passed pk as new pk when autoID == false // use the passed pk as new pk when autoID == false
// automatically generate pk as new pk when autoID == true // automatically generate pk as new pk when autoID == true
var err error
it.result.IDs, it.oldIDs, err = checkUpsertPrimaryFieldData(allFields, it.schema.CollectionSchema, it.upsertMsg.InsertMsg) it.result.IDs, it.oldIDs, err = checkUpsertPrimaryFieldData(allFields, it.schema.CollectionSchema, it.upsertMsg.InsertMsg)
log := log.Ctx(ctx).With(zap.String("collectionName", it.upsertMsg.InsertMsg.CollectionName)) log := log.Ctx(ctx).With(zap.String("collectionName", it.upsertMsg.InsertMsg.CollectionName))
if err != nil { if err != nil {
@ -1044,6 +1030,10 @@ func (it *upsertTask) PreExecute(ctx context.Context) error {
return merr.WrapErrParameterInvalid("invalid num_rows", fmt.Sprint(it.req.NumRows), "num_rows should be greater than 0") return merr.WrapErrParameterInvalid("invalid num_rows", fmt.Sprint(it.req.NumRows), "num_rows should be greater than 0")
} }
if err := genFunctionFields(ctx, it.upsertMsg.InsertMsg, it.schema, it.req.GetPartialUpdate()); err != nil {
return err
}
if it.req.GetPartialUpdate() { if it.req.GetPartialUpdate() {
err = it.queryPreExecute(ctx) err = it.queryPreExecute(ctx)
if err != nil { if err != nil {

View File

@ -501,6 +501,8 @@ func TestUpsertTask_Function(t *testing.T) {
schema: info, schema: info,
result: &milvuspb.MutationResult{}, result: &milvuspb.MutationResult{},
} }
err = genFunctionFields(task.ctx, task.upsertMsg.InsertMsg, task.schema, task.req.GetPartialUpdate())
assert.NoError(t, err)
err = task.insertPreExecute(ctx) err = task.insertPreExecute(ctx)
assert.NoError(t, err) assert.NoError(t, err)

View File

@ -27,6 +27,7 @@ import (
"github.com/cockroachdb/errors" "github.com/cockroachdb/errors"
"github.com/samber/lo" "github.com/samber/lo"
"go.opentelemetry.io/otel"
"go.uber.org/zap" "go.uber.org/zap"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
@ -2624,6 +2625,16 @@ func GetFunctionOutputFields(collSchema *schemapb.CollectionSchema) []string {
return fields return fields
} }
// GetBM25FunctionOutputFields collects the output field names of every
// function in collSchema whose type is BM25, in schema order. Functions of any
// other type are ignored. The result is an empty, non-nil slice when the
// schema has no BM25 function.
func GetBM25FunctionOutputFields(collSchema *schemapb.CollectionSchema) []string {
	names := []string{}
	for _, fn := range collSchema.Functions {
		if fn.Type != schemapb.FunctionType_BM25 {
			continue
		}
		names = append(names, fn.OutputFieldNames...)
	}
	return names
}
func getCollectionTTL(pairs []*commonpb.KeyValuePair) uint64 { func getCollectionTTL(pairs []*commonpb.KeyValuePair) uint64 {
properties := make(map[string]string) properties := make(map[string]string)
for _, pair := range pairs { for _, pair := range pairs {
@ -2967,6 +2978,35 @@ func extractFieldsFromResults(results []*schemapb.FieldData, precedenceTimezone
} }
fieldData.Type = schemapb.DataType_Array fieldData.Type = schemapb.DataType_Array
} }
return nil
}
func genFunctionFields(ctx context.Context, insertMsg *msgstream.InsertMsg, schema *schemaInfo, partialUpdate bool) error {
allowNonBM25Outputs := common.GetCollectionAllowInsertNonBM25FunctionOutputs(schema.Properties)
fieldIDs := lo.Map(insertMsg.FieldsData, func(fieldData *schemapb.FieldData, _ int) int64 {
id, _ := schema.MapFieldID(fieldData.FieldName)
return id
})
// Since PartialUpdate is supported, the field_data here may not be complete
needProcessFunctions, err := typeutil.GetNeedProcessFunctions(fieldIDs, schema.Functions, allowNonBM25Outputs, partialUpdate)
if err != nil {
log.Ctx(ctx).Warn("Check upsert field error,", zap.String("collectionName", schema.Name), zap.Error(err))
return err
}
if embedding.HasNonBM25Functions(schema.CollectionSchema.Functions, []int64{}) {
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-genFunctionFields-call-function-udf")
defer sp.End()
exec, err := embedding.NewFunctionExecutor(schema.CollectionSchema, needProcessFunctions)
if err != nil {
return err
}
sp.AddEvent("Create-function-udf")
if err := exec.ProcessInsert(ctx, insertMsg); err != nil {
return err
}
sp.AddEvent("Call-function-udf")
}
return nil return nil
} }

View File

@ -91,11 +91,14 @@ func ValidateFunctions(schema *schemapb.CollectionSchema) error {
return nil return nil
} }
func NewFunctionExecutor(schema *schemapb.CollectionSchema) (*FunctionExecutor, error) { func NewFunctionExecutor(schema *schemapb.CollectionSchema, functions []*schemapb.FunctionSchema) (*FunctionExecutor, error) {
executor := &FunctionExecutor{ executor := &FunctionExecutor{
runners: make(map[int64]Runner), runners: make(map[int64]Runner),
} }
for _, fSchema := range schema.Functions { if functions == nil {
functions = schema.Functions
}
for _, fSchema := range functions {
runner, err := createFunction(schema, fSchema) runner, err := createFunction(schema, fSchema)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -162,7 +162,7 @@ func (s *FunctionExecutorSuite) TestExecutor() {
ts := CreateOpenAIEmbeddingServer() ts := CreateOpenAIEmbeddingServer()
defer ts.Close() defer ts.Close()
schema := s.creataSchema(ts.URL) schema := s.creataSchema(ts.URL)
exec, err := NewFunctionExecutor(schema) exec, err := NewFunctionExecutor(schema, nil)
s.NoError(err) s.NoError(err)
msg := s.createMsg([]string{"sentence", "sentence"}) msg := s.createMsg([]string{"sentence", "sentence"})
exec.ProcessInsert(context.Background(), msg) exec.ProcessInsert(context.Background(), msg)
@ -197,7 +197,7 @@ func (s *FunctionExecutorSuite) TestErrorEmbedding() {
})) }))
defer ts.Close() defer ts.Close()
schema := s.creataSchema(ts.URL) schema := s.creataSchema(ts.URL)
exec, err := NewFunctionExecutor(schema) exec, err := NewFunctionExecutor(schema, nil)
s.NoError(err) s.NoError(err)
msg := s.createMsg([]string{"sentence", "sentence"}) msg := s.createMsg([]string{"sentence", "sentence"})
err = exec.ProcessInsert(context.Background(), msg) err = exec.ProcessInsert(context.Background(), msg)
@ -207,7 +207,7 @@ func (s *FunctionExecutorSuite) TestErrorEmbedding() {
func (s *FunctionExecutorSuite) TestErrorSchema() { func (s *FunctionExecutorSuite) TestErrorSchema() {
schema := s.creataSchema("http://localhost") schema := s.creataSchema("http://localhost")
schema.Functions[0].Type = schemapb.FunctionType_Unknown schema.Functions[0].Type = schemapb.FunctionType_Unknown
_, err := NewFunctionExecutor(schema) _, err := NewFunctionExecutor(schema, nil)
s.Error(err) s.Error(err)
} }
@ -215,7 +215,7 @@ func (s *FunctionExecutorSuite) TestInternalPrcessSearch() {
ts := CreateOpenAIEmbeddingServer() ts := CreateOpenAIEmbeddingServer()
defer ts.Close() defer ts.Close()
schema := s.creataSchema(ts.URL) schema := s.creataSchema(ts.URL)
exec, err := NewFunctionExecutor(schema) exec, err := NewFunctionExecutor(schema, nil)
s.NoError(err) s.NoError(err)
{ {
@ -309,7 +309,7 @@ func (s *FunctionExecutorSuite) TestInternalPrcessSearchFailed() {
defer ts.Close() defer ts.Close()
schema := s.creataSchema(ts.URL) schema := s.creataSchema(ts.URL)
exec, err := NewFunctionExecutor(schema) exec, err := NewFunctionExecutor(schema, nil)
s.NoError(err) s.NoError(err)
f := &schemapb.FieldData{ f := &schemapb.FieldData{
Type: schemapb.DataType_VarChar, Type: schemapb.DataType_VarChar,

View File

@ -185,6 +185,12 @@ const (
CollectionAutoCompactionKey = "collection.autocompaction.enabled" CollectionAutoCompactionKey = "collection.autocompaction.enabled"
CollectionDescription = "collection.description" CollectionDescription = "collection.description"
// Note:
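// By default, function output fields cannot be included in inserted data.
// Setting this option to "true" allows inserted data to carry the output
// fields of non-BM25 functions.
// The `bm25` function output field is always disallowed
// and is not controlled by this option.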
CollectionAllowInsertNonBM25FunctionOutputs = "collection.function.allowInsertNonBM25FunctionOutputs"
// rate limit // rate limit
CollectionInsertRateMaxKey = "collection.insertRate.max.mb" CollectionInsertRateMaxKey = "collection.insertRate.max.mb"
CollectionInsertRateMinKey = "collection.insertRate.min.mb" CollectionInsertRateMinKey = "collection.insertRate.min.mb"
@ -545,6 +551,16 @@ func AllocAutoID(allocFunc func(uint32) (int64, int64, error), rowNum uint32, cl
return idStart | int64(reversed), idEnd | int64(reversed), nil return idStart | int64(reversed), idEnd | int64(reversed), nil
} }
// GetCollectionAllowInsertNonBM25FunctionOutputs reports whether the
// CollectionAllowInsertNonBM25FunctionOutputs property is enabled in kvs.
//
// Only the first pair with that key is consulted. The result is false when the
// key is absent or when its value cannot be parsed by strconv.ParseBool.
func GetCollectionAllowInsertNonBM25FunctionOutputs(kvs []*commonpb.KeyValuePair) bool {
	for i := range kvs {
		if kvs[i].Key != CollectionAllowInsertNonBM25FunctionOutputs {
			continue
		}
		// An unparsable value is treated as "not enabled".
		allowed, err := strconv.ParseBool(kvs[i].Value)
		return err == nil && allowed
	}
	return false
}
func IsAllowInsertAutoID(kvs ...*commonpb.KeyValuePair) (bool, bool) { func IsAllowInsertAutoID(kvs ...*commonpb.KeyValuePair) (bool, bool) {
for _, kv := range kvs { for _, kv := range kvs {
if kv.Key == AllowInsertAutoIDKey { if kv.Key == AllowInsertAutoIDKey {

View File

@ -297,3 +297,19 @@ func TestAllocAutoID(t *testing.T) {
assert.EqualValues(t, 0b0100, start>>60) assert.EqualValues(t, 0b0100, start>>60)
assert.EqualValues(t, 0b0100, end>>60) assert.EqualValues(t, 0b0100, end>>60)
} }
func TestFunctionProperty(t *testing.T) {
assert.False(t, GetCollectionAllowInsertNonBM25FunctionOutputs([]*commonpb.KeyValuePair{}))
assert.False(t, GetCollectionAllowInsertNonBM25FunctionOutputs(
[]*commonpb.KeyValuePair{{Key: "other", Value: "test"}}),
)
assert.False(t, GetCollectionAllowInsertNonBM25FunctionOutputs(
[]*commonpb.KeyValuePair{{Key: CollectionAllowInsertNonBM25FunctionOutputs, Value: "false"}}),
)
assert.False(t, GetCollectionAllowInsertNonBM25FunctionOutputs(
[]*commonpb.KeyValuePair{{Key: CollectionAllowInsertNonBM25FunctionOutputs, Value: "test"}}),
)
assert.True(t, GetCollectionAllowInsertNonBM25FunctionOutputs(
[]*commonpb.KeyValuePair{{Key: CollectionAllowInsertNonBM25FunctionOutputs, Value: "true"}}),
)
}

View File

@ -23,6 +23,7 @@ import (
"fmt" "fmt"
"math" "math"
"reflect" "reflect"
"slices"
"sort" "sort"
"strconv" "strconv"
"unsafe" "unsafe"
@ -2305,3 +2306,51 @@ func EstimateSparseVectorNNZFromPlaceholderGroup(placeholderGroup []byte, nq int
overheadBytes := math.Max(10, float64(nq*3)) overheadBytes := math.Max(10, float64(nq*3))
return (len(placeholderGroup) - int(overheadBytes)) / 8 return (len(placeholderGroup) - int(overheadBytes)) / 8
} }
// GetNeedProcessFunctions selects the functions that still have to be executed
// for a write request that carries the fields identified by fieldIDs.
//
// Parameters:
//   - fieldIDs: IDs of the fields present in the request data.
//   - functions: the collection's function schemas.
//   - allowNonBM25Outputs: whether the request may carry output fields of
//     non-BM25 functions.
//   - partialUpdate: when true, a function is selected only if at least one of
//     its input fields is present in fieldIDs.
//
// Behavior:
//   - If functions is empty it is returned unchanged.
//   - If fieldIDs contains an output field of a BM25 function, an error is
//     returned regardless of allowNonBM25Outputs.
//   - If fieldIDs contains an output field of any other function and
//     allowNonBM25Outputs is false, an error is returned; otherwise that
//     function is dropped from the result because its output was supplied.
//
// The returned functions follow the order of the functions argument, so the
// result is deterministic across calls. On success with a non-empty functions
// argument the result is always a non-nil slice (possibly empty).
func GetNeedProcessFunctions(fieldIDs []int64, functions []*schemapb.FunctionSchema, allowNonBM25Outputs bool, partialUpdate bool) ([]*schemapb.FunctionSchema, error) {
	if len(functions) == 0 {
		return functions, nil
	}

	// fieldIDFuncMapping: output field ID -> function producing it.
	// funCandidate: function name -> function still eligible for processing.
	fieldIDFuncMapping := make(map[int64]*schemapb.FunctionSchema)
	funCandidate := make(map[string]*schemapb.FunctionSchema, len(functions))
	for _, functionSchema := range functions {
		funCandidate[functionSchema.Name] = functionSchema
		for _, fieldID := range functionSchema.OutputFieldIds {
			fieldIDFuncMapping[fieldID] = functionSchema
		}
	}

	// Error messages are kept verbatim; existing tests match on their text.
	for _, fieldID := range fieldIDs {
		f, exists := fieldIDFuncMapping[fieldID]
		if !exists {
			continue
		}
		if f.Type == schemapb.FunctionType_BM25 {
			return nil, fmt.Errorf("Attempt to insert bm25 function output field")
		}
		if !allowNonBM25Outputs {
			return nil, fmt.Errorf("Insert data has function output field, but collection's property `collection.function.allowInsertNonBM25FunctionOutputs` is not enable")
		}
		// The caller supplied this function's output, so it must not run.
		delete(funCandidate, f.Name)
	}

	// Must stay non-nil even when empty: a nil function list makes
	// NewFunctionExecutor fall back to every function in the schema.
	needProcessFunctions := make([]*schemapb.FunctionSchema, 0, len(funCandidate))

	// Walk the schema slice instead of the candidate map so that the result
	// order is stable; map iteration order is unspecified in Go.
	for _, functionSchema := range functions {
		candidate, ok := funCandidate[functionSchema.Name]
		if !ok || candidate != functionSchema {
			continue
		}
		// Each surviving candidate is emitted at most once.
		delete(funCandidate, functionSchema.Name)

		if partialUpdate {
			// If some input exists, push it down to the function for processing
			hasInput := slices.ContainsFunc(functionSchema.InputFieldIds, func(inputID int64) bool {
				return slices.Contains(fieldIDs, inputID)
			})
			if !hasInput {
				continue
			}
		}
		needProcessFunctions = append(needProcessFunctions, functionSchema)
	}
	return needProcessFunctions, nil
}

View File

@ -4468,3 +4468,53 @@ func TestUpdateFieldData_IndexFix(t *testing.T) {
assert.Equal(t, updateSparseData.Contents[1], updatedContents[2]) assert.Equal(t, updateSparseData.Contents[1], updatedContents[2])
}) })
} }
// TestGetNeedProcessFunctions covers GetNeedProcessFunctions for: no
// functions, user-supplied function outputs with the allow flag off and on,
// several functions where only one output is supplied, BM25 outputs, and the
// partial-update path.
func TestGetNeedProcessFunctions(t *testing.T) {
	t.Run("no functions", func(t *testing.T) {
		got, err := GetNeedProcessFunctions([]int64{}, []*schemapb.FunctionSchema{}, false, false)
		assert.Len(t, got, 0)
		assert.NoError(t, err)
	})

	t.Run("output supplied along with another field", func(t *testing.T) {
		funcs := []*schemapb.FunctionSchema{{Name: "test_func", OutputFieldIds: []int64{1}}}

		_, err := GetNeedProcessFunctions([]int64{1, 2}, funcs, false, false)
		assert.ErrorContains(t, err, "Insert data has function output field")

		got, err := GetNeedProcessFunctions([]int64{1, 2}, funcs, true, false)
		assert.NoError(t, err)
		assert.Len(t, got, 0)
	})

	t.Run("only the output is supplied", func(t *testing.T) {
		funcs := []*schemapb.FunctionSchema{{Name: "test_func", OutputFieldIds: []int64{1}}}

		_, err := GetNeedProcessFunctions([]int64{1}, funcs, false, false)
		assert.ErrorContains(t, err, "Insert data has function output field")

		got, err := GetNeedProcessFunctions([]int64{1}, funcs, true, false)
		assert.NoError(t, err)
		assert.Len(t, got, 0)
	})

	t.Run("one of two outputs supplied", func(t *testing.T) {
		funcs := []*schemapb.FunctionSchema{
			{Name: "test_func", OutputFieldIds: []int64{1}},
			{Name: "test_func2", OutputFieldIds: []int64{2}},
		}

		_, err := GetNeedProcessFunctions([]int64{1}, funcs, false, false)
		assert.Error(t, err)

		got, err := GetNeedProcessFunctions([]int64{1}, funcs, true, false)
		assert.NoError(t, err)
		assert.Len(t, got, 1)
		assert.Equal(t, "test_func2", got[0].Name)
	})

	t.Run("bm25 output is always rejected", func(t *testing.T) {
		funcs := []*schemapb.FunctionSchema{
			{Name: "test_func", Type: schemapb.FunctionType_BM25, OutputFieldIds: []int64{1}},
		}

		_, err := GetNeedProcessFunctions([]int64{1}, funcs, true, false)
		assert.ErrorContains(t, err, "Attempt to insert bm25 function output field")
	})

	t.Run("partial update", func(t *testing.T) {
		funcs := []*schemapb.FunctionSchema{
			{Name: "test_func", InputFieldIds: []int64{1, 2}, OutputFieldIds: []int64{3}},
			{Name: "test_func2", InputFieldIds: []int64{1}, OutputFieldIds: []int64{2}},
		}

		_, err := GetNeedProcessFunctions([]int64{1, 2}, funcs, false, true)
		assert.Error(t, err)

		got, err := GetNeedProcessFunctions([]int64{1, 2}, funcs, true, true)
		assert.NoError(t, err)
		assert.Len(t, got, 1)
		assert.Equal(t, "test_func", got[0].Name)
	})
}