mirror of
https://gitee.com/milvus-io/milvus.git
synced 2025-12-07 01:28:27 +08:00
enhance: add support for controlling function output field insertion (#44162)
#44053 Signed-off-by: junjie.jiang <junjie.jiang@zilliz.com>
This commit is contained in:
parent
2ed7d35783
commit
f07979f91d
@ -505,6 +505,7 @@ func (s *SchedulerSuite) TestScheduler_ImportFileWithFunction() {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
Properties: []*commonpb.KeyValuePair{{Key: common.CollectionAllowInsertNonBM25FunctionOutputs, Value: "true"}},
|
||||||
}
|
}
|
||||||
|
|
||||||
var once sync.Once
|
var once sync.Once
|
||||||
|
|||||||
@ -411,10 +411,11 @@ func FillDynamicData(schema *schemapb.CollectionSchema, data *storage.InsertData
|
|||||||
}
|
}
|
||||||
|
|
||||||
func RunEmbeddingFunction(task *ImportTask, data *storage.InsertData) error {
|
func RunEmbeddingFunction(task *ImportTask, data *storage.InsertData) error {
|
||||||
if err := RunBm25Function(task, data); err != nil {
|
if err := RunDenseEmbedding(task, data); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := RunDenseEmbedding(task, data); err != nil {
|
|
||||||
|
if err := RunBm25Function(task, data); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@ -422,8 +423,14 @@ func RunEmbeddingFunction(task *ImportTask, data *storage.InsertData) error {
|
|||||||
|
|
||||||
func RunDenseEmbedding(task *ImportTask, data *storage.InsertData) error {
|
func RunDenseEmbedding(task *ImportTask, data *storage.InsertData) error {
|
||||||
schema := task.GetSchema()
|
schema := task.GetSchema()
|
||||||
|
allowNonBM25Outputs := common.GetCollectionAllowInsertNonBM25FunctionOutputs(schema.Properties)
|
||||||
|
fieldIDs := lo.Keys(data.Data)
|
||||||
|
needProcessFunctions, err := typeutil.GetNeedProcessFunctions(fieldIDs, schema.Functions, allowNonBM25Outputs, false)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
if embedding.HasNonBM25Functions(schema.Functions, []int64{}) {
|
if embedding.HasNonBM25Functions(schema.Functions, []int64{}) {
|
||||||
exec, err := embedding.NewFunctionExecutor(schema)
|
exec, err := embedding.NewFunctionExecutor(schema, needProcessFunctions)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
@ -11,7 +11,6 @@ import (
|
|||||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||||
"github.com/milvus-io/milvus/internal/allocator"
|
"github.com/milvus-io/milvus/internal/allocator"
|
||||||
"github.com/milvus-io/milvus/internal/util/function/embedding"
|
|
||||||
"github.com/milvus-io/milvus/pkg/v2/common"
|
"github.com/milvus-io/milvus/pkg/v2/common"
|
||||||
"github.com/milvus-io/milvus/pkg/v2/log"
|
"github.com/milvus-io/milvus/pkg/v2/log"
|
||||||
"github.com/milvus-io/milvus/pkg/v2/metrics"
|
"github.com/milvus-io/milvus/pkg/v2/metrics"
|
||||||
@ -165,20 +164,10 @@ func (it *insertTask) PreExecute(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
it.schema = schema.CollectionSchema
|
it.schema = schema.CollectionSchema
|
||||||
|
|
||||||
// Calculate embedding fields
|
if err := genFunctionFields(ctx, it.insertMsg, schema, false); err != nil {
|
||||||
if embedding.HasNonBM25Functions(schema.CollectionSchema.Functions, []int64{}) {
|
|
||||||
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-Insert-call-function-udf")
|
|
||||||
defer sp.End()
|
|
||||||
exec, err := embedding.NewFunctionExecutor(schema.CollectionSchema)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
sp.AddEvent("Create-function-udf")
|
|
||||||
if err := exec.ProcessInsert(ctx, it.insertMsg); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
sp.AddEvent("Call-function-udf")
|
|
||||||
}
|
|
||||||
rowNums := uint32(it.insertMsg.NRows())
|
rowNums := uint32(it.insertMsg.NRows())
|
||||||
// set insertTask.rowIDs
|
// set insertTask.rowIDs
|
||||||
var rowIDBegin UniqueID
|
var rowIDBegin UniqueID
|
||||||
|
|||||||
@ -472,7 +472,7 @@ func (t *searchTask) initAdvancedSearchRequest(ctx context.Context) error {
|
|||||||
if embedding.HasNonBM25Functions(t.schema.CollectionSchema.Functions, queryFieldIDs) {
|
if embedding.HasNonBM25Functions(t.schema.CollectionSchema.Functions, queryFieldIDs) {
|
||||||
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-AdvancedSearch-call-function-udf")
|
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-AdvancedSearch-call-function-udf")
|
||||||
defer sp.End()
|
defer sp.End()
|
||||||
exec, err := embedding.NewFunctionExecutor(t.schema.CollectionSchema)
|
exec, err := embedding.NewFunctionExecutor(t.schema.CollectionSchema, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -591,7 +591,7 @@ func (t *searchTask) initSearchRequest(ctx context.Context) error {
|
|||||||
if embedding.HasNonBM25Functions(t.schema.CollectionSchema.Functions, []int64{queryInfo.GetQueryFieldId()}) {
|
if embedding.HasNonBM25Functions(t.schema.CollectionSchema.Functions, []int64{queryInfo.GetQueryFieldId()}) {
|
||||||
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-Search-call-function-udf")
|
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-Search-call-function-udf")
|
||||||
defer sp.End()
|
defer sp.End()
|
||||||
exec, err := embedding.NewFunctionExecutor(t.schema.CollectionSchema)
|
exec, err := embedding.NewFunctionExecutor(t.schema.CollectionSchema, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
@ -32,7 +32,6 @@ import (
|
|||||||
"github.com/milvus-io/milvus/internal/allocator"
|
"github.com/milvus-io/milvus/internal/allocator"
|
||||||
"github.com/milvus-io/milvus/internal/parser/planparserv2"
|
"github.com/milvus-io/milvus/internal/parser/planparserv2"
|
||||||
"github.com/milvus-io/milvus/internal/types"
|
"github.com/milvus-io/milvus/internal/types"
|
||||||
"github.com/milvus-io/milvus/internal/util/function/embedding"
|
|
||||||
"github.com/milvus-io/milvus/pkg/v2/common"
|
"github.com/milvus-io/milvus/pkg/v2/common"
|
||||||
"github.com/milvus-io/milvus/pkg/v2/log"
|
"github.com/milvus-io/milvus/pkg/v2/log"
|
||||||
"github.com/milvus-io/milvus/pkg/v2/metrics"
|
"github.com/milvus-io/milvus/pkg/v2/metrics"
|
||||||
@ -354,6 +353,7 @@ func (it *upsertTask) queryPreExecute(ctx context.Context) error {
|
|||||||
// 2. merge field data on update semantic
|
// 2. merge field data on update semantic
|
||||||
it.deletePKs = &schemapb.IDs{}
|
it.deletePKs = &schemapb.IDs{}
|
||||||
it.insertFieldData = typeutil.PrepareResultFieldData(existFieldData, int64(upsertIDSize))
|
it.insertFieldData = typeutil.PrepareResultFieldData(existFieldData, int64(upsertIDSize))
|
||||||
|
|
||||||
if len(updateIdxInUpsert) > 0 {
|
if len(updateIdxInUpsert) > 0 {
|
||||||
// Note: For fields containing default values, default values need to be set according to valid data during insertion,
|
// Note: For fields containing default values, default values need to be set according to valid data during insertion,
|
||||||
// but query results fields do not set valid data when returning default value fields,
|
// but query results fields do not set valid data when returning default value fields,
|
||||||
@ -738,10 +738,7 @@ func (it *upsertTask) insertPreExecute(ctx context.Context) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
bm25Fields := typeutil.NewSet[string](GetFunctionOutputFields(it.schema.CollectionSchema)...)
|
bm25Fields := typeutil.NewSet[string](GetBM25FunctionOutputFields(it.schema.CollectionSchema)...)
|
||||||
// Calculate embedding fields
|
|
||||||
|
|
||||||
if embedding.HasNonBM25Functions(it.schema.CollectionSchema.Functions, []int64{}) {
|
|
||||||
if it.req.PartialUpdate {
|
if it.req.PartialUpdate {
|
||||||
// remove the old bm25 fields
|
// remove the old bm25 fields
|
||||||
ret := make([]*schemapb.FieldData, 0)
|
ret := make([]*schemapb.FieldData, 0)
|
||||||
@ -753,18 +750,7 @@ func (it *upsertTask) insertPreExecute(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
it.upsertMsg.InsertMsg.FieldsData = ret
|
it.upsertMsg.InsertMsg.FieldsData = ret
|
||||||
}
|
}
|
||||||
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-Proxy-Upsert-insertPreExecute-call-function-udf")
|
|
||||||
defer sp.End()
|
|
||||||
exec, err := embedding.NewFunctionExecutor(it.schema.CollectionSchema)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
sp.AddEvent("Create-function-udf")
|
|
||||||
if err := exec.ProcessInsert(ctx, it.upsertMsg.InsertMsg); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
sp.AddEvent("Call-function-udf")
|
|
||||||
}
|
|
||||||
rowNums := uint32(it.upsertMsg.InsertMsg.NRows())
|
rowNums := uint32(it.upsertMsg.InsertMsg.NRows())
|
||||||
// set upsertTask.insertRequest.rowIDs
|
// set upsertTask.insertRequest.rowIDs
|
||||||
tr := timerecord.NewTimeRecorder("applyPK")
|
tr := timerecord.NewTimeRecorder("applyPK")
|
||||||
@ -808,8 +794,7 @@ func (it *upsertTask) insertPreExecute(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err := checkAndFlattenStructFieldData(it.schema.CollectionSchema, it.upsertMsg.InsertMsg)
|
if err := checkAndFlattenStructFieldData(it.schema.CollectionSchema, it.upsertMsg.InsertMsg); err != nil {
|
||||||
if err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -817,6 +802,7 @@ func (it *upsertTask) insertPreExecute(ctx context.Context) error {
|
|||||||
|
|
||||||
// use the passed pk as new pk when autoID == false
|
// use the passed pk as new pk when autoID == false
|
||||||
// automatically generate pk as new pk when autoID == true
|
// automatically generate pk as new pk when autoID == true
|
||||||
|
var err error
|
||||||
it.result.IDs, it.oldIDs, err = checkUpsertPrimaryFieldData(allFields, it.schema.CollectionSchema, it.upsertMsg.InsertMsg)
|
it.result.IDs, it.oldIDs, err = checkUpsertPrimaryFieldData(allFields, it.schema.CollectionSchema, it.upsertMsg.InsertMsg)
|
||||||
log := log.Ctx(ctx).With(zap.String("collectionName", it.upsertMsg.InsertMsg.CollectionName))
|
log := log.Ctx(ctx).With(zap.String("collectionName", it.upsertMsg.InsertMsg.CollectionName))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -1044,6 +1030,10 @@ func (it *upsertTask) PreExecute(ctx context.Context) error {
|
|||||||
return merr.WrapErrParameterInvalid("invalid num_rows", fmt.Sprint(it.req.NumRows), "num_rows should be greater than 0")
|
return merr.WrapErrParameterInvalid("invalid num_rows", fmt.Sprint(it.req.NumRows), "num_rows should be greater than 0")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := genFunctionFields(ctx, it.upsertMsg.InsertMsg, it.schema, it.req.GetPartialUpdate()); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
if it.req.GetPartialUpdate() {
|
if it.req.GetPartialUpdate() {
|
||||||
err = it.queryPreExecute(ctx)
|
err = it.queryPreExecute(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@ -501,6 +501,8 @@ func TestUpsertTask_Function(t *testing.T) {
|
|||||||
schema: info,
|
schema: info,
|
||||||
result: &milvuspb.MutationResult{},
|
result: &milvuspb.MutationResult{},
|
||||||
}
|
}
|
||||||
|
err = genFunctionFields(task.ctx, task.upsertMsg.InsertMsg, task.schema, task.req.GetPartialUpdate())
|
||||||
|
assert.NoError(t, err)
|
||||||
err = task.insertPreExecute(ctx)
|
err = task.insertPreExecute(ctx)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
|||||||
@ -27,6 +27,7 @@ import (
|
|||||||
|
|
||||||
"github.com/cockroachdb/errors"
|
"github.com/cockroachdb/errors"
|
||||||
"github.com/samber/lo"
|
"github.com/samber/lo"
|
||||||
|
"go.opentelemetry.io/otel"
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
"google.golang.org/grpc/metadata"
|
"google.golang.org/grpc/metadata"
|
||||||
@ -2624,6 +2625,16 @@ func GetFunctionOutputFields(collSchema *schemapb.CollectionSchema) []string {
|
|||||||
return fields
|
return fields
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GetBM25FunctionOutputFields(collSchema *schemapb.CollectionSchema) []string {
|
||||||
|
fields := make([]string, 0)
|
||||||
|
for _, fSchema := range collSchema.Functions {
|
||||||
|
if fSchema.Type == schemapb.FunctionType_BM25 {
|
||||||
|
fields = append(fields, fSchema.OutputFieldNames...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return fields
|
||||||
|
}
|
||||||
|
|
||||||
func getCollectionTTL(pairs []*commonpb.KeyValuePair) uint64 {
|
func getCollectionTTL(pairs []*commonpb.KeyValuePair) uint64 {
|
||||||
properties := make(map[string]string)
|
properties := make(map[string]string)
|
||||||
for _, pair := range pairs {
|
for _, pair := range pairs {
|
||||||
@ -2967,6 +2978,35 @@ func extractFieldsFromResults(results []*schemapb.FieldData, precedenceTimezone
|
|||||||
}
|
}
|
||||||
fieldData.Type = schemapb.DataType_Array
|
fieldData.Type = schemapb.DataType_Array
|
||||||
}
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func genFunctionFields(ctx context.Context, insertMsg *msgstream.InsertMsg, schema *schemaInfo, partialUpdate bool) error {
|
||||||
|
allowNonBM25Outputs := common.GetCollectionAllowInsertNonBM25FunctionOutputs(schema.Properties)
|
||||||
|
fieldIDs := lo.Map(insertMsg.FieldsData, func(fieldData *schemapb.FieldData, _ int) int64 {
|
||||||
|
id, _ := schema.MapFieldID(fieldData.FieldName)
|
||||||
|
return id
|
||||||
|
})
|
||||||
|
|
||||||
|
// Since PartialUpdate is supported, the field_data here may not be complete
|
||||||
|
needProcessFunctions, err := typeutil.GetNeedProcessFunctions(fieldIDs, schema.Functions, allowNonBM25Outputs, partialUpdate)
|
||||||
|
if err != nil {
|
||||||
|
log.Ctx(ctx).Warn("Check upsert field error,", zap.String("collectionName", schema.Name), zap.Error(err))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if embedding.HasNonBM25Functions(schema.CollectionSchema.Functions, []int64{}) {
|
||||||
|
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-genFunctionFields-call-function-udf")
|
||||||
|
defer sp.End()
|
||||||
|
exec, err := embedding.NewFunctionExecutor(schema.CollectionSchema, needProcessFunctions)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
sp.AddEvent("Create-function-udf")
|
||||||
|
if err := exec.ProcessInsert(ctx, insertMsg); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
sp.AddEvent("Call-function-udf")
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@ -91,11 +91,14 @@ func ValidateFunctions(schema *schemapb.CollectionSchema) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewFunctionExecutor(schema *schemapb.CollectionSchema) (*FunctionExecutor, error) {
|
func NewFunctionExecutor(schema *schemapb.CollectionSchema, functions []*schemapb.FunctionSchema) (*FunctionExecutor, error) {
|
||||||
executor := &FunctionExecutor{
|
executor := &FunctionExecutor{
|
||||||
runners: make(map[int64]Runner),
|
runners: make(map[int64]Runner),
|
||||||
}
|
}
|
||||||
for _, fSchema := range schema.Functions {
|
if functions == nil {
|
||||||
|
functions = schema.Functions
|
||||||
|
}
|
||||||
|
for _, fSchema := range functions {
|
||||||
runner, err := createFunction(schema, fSchema)
|
runner, err := createFunction(schema, fSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
@ -162,7 +162,7 @@ func (s *FunctionExecutorSuite) TestExecutor() {
|
|||||||
ts := CreateOpenAIEmbeddingServer()
|
ts := CreateOpenAIEmbeddingServer()
|
||||||
defer ts.Close()
|
defer ts.Close()
|
||||||
schema := s.creataSchema(ts.URL)
|
schema := s.creataSchema(ts.URL)
|
||||||
exec, err := NewFunctionExecutor(schema)
|
exec, err := NewFunctionExecutor(schema, nil)
|
||||||
s.NoError(err)
|
s.NoError(err)
|
||||||
msg := s.createMsg([]string{"sentence", "sentence"})
|
msg := s.createMsg([]string{"sentence", "sentence"})
|
||||||
exec.ProcessInsert(context.Background(), msg)
|
exec.ProcessInsert(context.Background(), msg)
|
||||||
@ -197,7 +197,7 @@ func (s *FunctionExecutorSuite) TestErrorEmbedding() {
|
|||||||
}))
|
}))
|
||||||
defer ts.Close()
|
defer ts.Close()
|
||||||
schema := s.creataSchema(ts.URL)
|
schema := s.creataSchema(ts.URL)
|
||||||
exec, err := NewFunctionExecutor(schema)
|
exec, err := NewFunctionExecutor(schema, nil)
|
||||||
s.NoError(err)
|
s.NoError(err)
|
||||||
msg := s.createMsg([]string{"sentence", "sentence"})
|
msg := s.createMsg([]string{"sentence", "sentence"})
|
||||||
err = exec.ProcessInsert(context.Background(), msg)
|
err = exec.ProcessInsert(context.Background(), msg)
|
||||||
@ -207,7 +207,7 @@ func (s *FunctionExecutorSuite) TestErrorEmbedding() {
|
|||||||
func (s *FunctionExecutorSuite) TestErrorSchema() {
|
func (s *FunctionExecutorSuite) TestErrorSchema() {
|
||||||
schema := s.creataSchema("http://localhost")
|
schema := s.creataSchema("http://localhost")
|
||||||
schema.Functions[0].Type = schemapb.FunctionType_Unknown
|
schema.Functions[0].Type = schemapb.FunctionType_Unknown
|
||||||
_, err := NewFunctionExecutor(schema)
|
_, err := NewFunctionExecutor(schema, nil)
|
||||||
s.Error(err)
|
s.Error(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -215,7 +215,7 @@ func (s *FunctionExecutorSuite) TestInternalPrcessSearch() {
|
|||||||
ts := CreateOpenAIEmbeddingServer()
|
ts := CreateOpenAIEmbeddingServer()
|
||||||
defer ts.Close()
|
defer ts.Close()
|
||||||
schema := s.creataSchema(ts.URL)
|
schema := s.creataSchema(ts.URL)
|
||||||
exec, err := NewFunctionExecutor(schema)
|
exec, err := NewFunctionExecutor(schema, nil)
|
||||||
s.NoError(err)
|
s.NoError(err)
|
||||||
|
|
||||||
{
|
{
|
||||||
@ -309,7 +309,7 @@ func (s *FunctionExecutorSuite) TestInternalPrcessSearchFailed() {
|
|||||||
defer ts.Close()
|
defer ts.Close()
|
||||||
|
|
||||||
schema := s.creataSchema(ts.URL)
|
schema := s.creataSchema(ts.URL)
|
||||||
exec, err := NewFunctionExecutor(schema)
|
exec, err := NewFunctionExecutor(schema, nil)
|
||||||
s.NoError(err)
|
s.NoError(err)
|
||||||
f := &schemapb.FieldData{
|
f := &schemapb.FieldData{
|
||||||
Type: schemapb.DataType_VarChar,
|
Type: schemapb.DataType_VarChar,
|
||||||
|
|||||||
@ -185,6 +185,12 @@ const (
|
|||||||
CollectionAutoCompactionKey = "collection.autocompaction.enabled"
|
CollectionAutoCompactionKey = "collection.autocompaction.enabled"
|
||||||
CollectionDescription = "collection.description"
|
CollectionDescription = "collection.description"
|
||||||
|
|
||||||
|
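// CollectionAllowInsertNonBM25FunctionOutputs controls whether inserted data
// may carry the output fields of non-BM25 functions.
// By default (unset or not "true"), function output fields cannot be included
// in inserted data. When set to "true", output fields of non-BM25 functions
// may be supplied by the caller, and those functions are then not executed.
// The `bm25` function output field is always disallowed and is not controlled
// by this option.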
|
||||||
|
CollectionAllowInsertNonBM25FunctionOutputs = "collection.function.allowInsertNonBM25FunctionOutputs"
|
||||||
|
|
||||||
// rate limit
|
// rate limit
|
||||||
CollectionInsertRateMaxKey = "collection.insertRate.max.mb"
|
CollectionInsertRateMaxKey = "collection.insertRate.max.mb"
|
||||||
CollectionInsertRateMinKey = "collection.insertRate.min.mb"
|
CollectionInsertRateMinKey = "collection.insertRate.min.mb"
|
||||||
@ -545,6 +551,16 @@ func AllocAutoID(allocFunc func(uint32) (int64, int64, error), rowNum uint32, cl
|
|||||||
return idStart | int64(reversed), idEnd | int64(reversed), nil
|
return idStart | int64(reversed), idEnd | int64(reversed), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GetCollectionAllowInsertNonBM25FunctionOutputs(kvs []*commonpb.KeyValuePair) bool {
|
||||||
|
for _, kv := range kvs {
|
||||||
|
if kv.Key == CollectionAllowInsertNonBM25FunctionOutputs {
|
||||||
|
enable, _ := strconv.ParseBool(kv.Value)
|
||||||
|
return enable
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
func IsAllowInsertAutoID(kvs ...*commonpb.KeyValuePair) (bool, bool) {
|
func IsAllowInsertAutoID(kvs ...*commonpb.KeyValuePair) (bool, bool) {
|
||||||
for _, kv := range kvs {
|
for _, kv := range kvs {
|
||||||
if kv.Key == AllowInsertAutoIDKey {
|
if kv.Key == AllowInsertAutoIDKey {
|
||||||
|
|||||||
@ -297,3 +297,19 @@ func TestAllocAutoID(t *testing.T) {
|
|||||||
assert.EqualValues(t, 0b0100, start>>60)
|
assert.EqualValues(t, 0b0100, start>>60)
|
||||||
assert.EqualValues(t, 0b0100, end>>60)
|
assert.EqualValues(t, 0b0100, end>>60)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestFunctionProperty(t *testing.T) {
|
||||||
|
assert.False(t, GetCollectionAllowInsertNonBM25FunctionOutputs([]*commonpb.KeyValuePair{}))
|
||||||
|
assert.False(t, GetCollectionAllowInsertNonBM25FunctionOutputs(
|
||||||
|
[]*commonpb.KeyValuePair{{Key: "other", Value: "test"}}),
|
||||||
|
)
|
||||||
|
assert.False(t, GetCollectionAllowInsertNonBM25FunctionOutputs(
|
||||||
|
[]*commonpb.KeyValuePair{{Key: CollectionAllowInsertNonBM25FunctionOutputs, Value: "false"}}),
|
||||||
|
)
|
||||||
|
assert.False(t, GetCollectionAllowInsertNonBM25FunctionOutputs(
|
||||||
|
[]*commonpb.KeyValuePair{{Key: CollectionAllowInsertNonBM25FunctionOutputs, Value: "test"}}),
|
||||||
|
)
|
||||||
|
assert.True(t, GetCollectionAllowInsertNonBM25FunctionOutputs(
|
||||||
|
[]*commonpb.KeyValuePair{{Key: CollectionAllowInsertNonBM25FunctionOutputs, Value: "true"}}),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|||||||
@ -23,6 +23,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"math"
|
"math"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"slices"
|
||||||
"sort"
|
"sort"
|
||||||
"strconv"
|
"strconv"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
@ -2305,3 +2306,51 @@ func EstimateSparseVectorNNZFromPlaceholderGroup(placeholderGroup []byte, nq int
|
|||||||
overheadBytes := math.Max(10, float64(nq*3))
|
overheadBytes := math.Max(10, float64(nq*3))
|
||||||
return (len(placeholderGroup) - int(overheadBytes)) / 8
|
return (len(placeholderGroup) - int(overheadBytes)) / 8
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GetNeedProcessFunctions(fieldIDs []int64, functions []*schemapb.FunctionSchema, allowNonBM25Outputs bool, partialUpdate bool) ([]*schemapb.FunctionSchema, error) {
|
||||||
|
if len(functions) == 0 {
|
||||||
|
return functions, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
fieldIDFuncMapping := map[int64]*schemapb.FunctionSchema{}
|
||||||
|
funCandidate := map[string]*schemapb.FunctionSchema{}
|
||||||
|
|
||||||
|
for _, functionSchema := range functions {
|
||||||
|
funCandidate[functionSchema.Name] = functionSchema
|
||||||
|
for _, fieldID := range functionSchema.OutputFieldIds {
|
||||||
|
fieldIDFuncMapping[fieldID] = functionSchema
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, fieldID := range fieldIDs {
|
||||||
|
if f, exists := fieldIDFuncMapping[fieldID]; exists {
|
||||||
|
if f.Type == schemapb.FunctionType_BM25 {
|
||||||
|
return nil, fmt.Errorf("Attempt to insert bm25 function output field")
|
||||||
|
}
|
||||||
|
if !allowNonBM25Outputs {
|
||||||
|
return nil, fmt.Errorf("Insert data has function output field, but collection's property `collection.function.allowInsertNonBM25FunctionOutputs` is not enable")
|
||||||
|
}
|
||||||
|
delete(funCandidate, f.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
needProcessFunctions := []*schemapb.FunctionSchema{}
|
||||||
|
for _, functionSchema := range funCandidate {
|
||||||
|
if partialUpdate {
|
||||||
|
// If some input exists, push it down to the function for processing
|
||||||
|
allInputNotExist := true
|
||||||
|
for _, inputID := range functionSchema.InputFieldIds {
|
||||||
|
if slices.Contains(fieldIDs, inputID) {
|
||||||
|
allInputNotExist = false
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !allInputNotExist {
|
||||||
|
needProcessFunctions = append(needProcessFunctions, functionSchema)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
needProcessFunctions = append(needProcessFunctions, functionSchema)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return needProcessFunctions, nil
|
||||||
|
}
|
||||||
|
|||||||
@ -4468,3 +4468,53 @@ func TestUpdateFieldData_IndexFix(t *testing.T) {
|
|||||||
assert.Equal(t, updateSparseData.Contents[1], updatedContents[2])
|
assert.Equal(t, updateSparseData.Contents[1], updatedContents[2])
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGetNeedProcessFunctions(t *testing.T) {
|
||||||
|
{
|
||||||
|
f, err := GetNeedProcessFunctions([]int64{}, []*schemapb.FunctionSchema{}, false, false)
|
||||||
|
assert.Len(t, f, 0)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
{
|
||||||
|
fs := []*schemapb.FunctionSchema{{Name: "test_func", OutputFieldIds: []int64{1}}}
|
||||||
|
_, err := GetNeedProcessFunctions([]int64{1, 2}, fs, false, false)
|
||||||
|
assert.ErrorContains(t, err, "Insert data has function output field")
|
||||||
|
f, err := GetNeedProcessFunctions([]int64{1, 2}, fs, true, false)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Len(t, f, 0)
|
||||||
|
}
|
||||||
|
{
|
||||||
|
fs := []*schemapb.FunctionSchema{{Name: "test_func", OutputFieldIds: []int64{1}}}
|
||||||
|
_, err := GetNeedProcessFunctions([]int64{1}, fs, false, false)
|
||||||
|
assert.ErrorContains(t, err, "Insert data has function output field")
|
||||||
|
f, err := GetNeedProcessFunctions([]int64{1}, fs, true, false)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Len(t, f, 0)
|
||||||
|
}
|
||||||
|
{
|
||||||
|
fs := []*schemapb.FunctionSchema{{Name: "test_func", OutputFieldIds: []int64{1}}, {Name: "test_func2", OutputFieldIds: []int64{2}}}
|
||||||
|
_, err := GetNeedProcessFunctions([]int64{1}, fs, false, false)
|
||||||
|
assert.Error(t, err)
|
||||||
|
f, err := GetNeedProcessFunctions([]int64{1}, fs, true, false)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Len(t, f, 1)
|
||||||
|
assert.Equal(t, f[0].Name, "test_func2")
|
||||||
|
}
|
||||||
|
{
|
||||||
|
fs := []*schemapb.FunctionSchema{{Name: "test_func", Type: schemapb.FunctionType_BM25, OutputFieldIds: []int64{1}}}
|
||||||
|
_, err := GetNeedProcessFunctions([]int64{1}, fs, true, false)
|
||||||
|
assert.ErrorContains(t, err, "Attempt to insert bm25 function output field")
|
||||||
|
}
|
||||||
|
{
|
||||||
|
fs := []*schemapb.FunctionSchema{
|
||||||
|
{Name: "test_func", InputFieldIds: []int64{1, 2}, OutputFieldIds: []int64{3}},
|
||||||
|
{Name: "test_func2", InputFieldIds: []int64{1}, OutputFieldIds: []int64{2}},
|
||||||
|
}
|
||||||
|
_, err := GetNeedProcessFunctions([]int64{1, 2}, fs, false, true)
|
||||||
|
assert.Error(t, err)
|
||||||
|
f, err := GetNeedProcessFunctions([]int64{1, 2}, fs, true, true)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Len(t, f, 1)
|
||||||
|
assert.Equal(t, f[0].Name, "test_func")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user