milvus/internal/rootcoord/create_collection_task.go
XuanYang-cn 471b3e4e09
fix: Move Init and Remove EZ logic out of metatable (#45827)
This will avoid endless retry CreateDatabase/DropDatabase when
cipherPlugin fails in the new DDL framework.

See also: #45826

---------

Signed-off-by: yangxuan <xuan.yang@zilliz.com>
2025-11-25 20:01:06 +08:00

// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package rootcoord

import (
"context"
"fmt"
"strconv"

"github.com/cockroachdb/errors"
"github.com/twpayne/go-geom/encoding/wkb"
"github.com/twpayne/go-geom/encoding/wkt"
"go.uber.org/zap"

"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/coordinator/snmanager"
"github.com/milvus-io/milvus/internal/json"
"github.com/milvus-io/milvus/internal/streamingcoord/server/balancer"
"github.com/milvus-io/milvus/internal/util/hookutil"
"github.com/milvus-io/milvus/pkg/v2/common"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/proto/querypb"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
"github.com/milvus-io/milvus/pkg/v2/util/timestamptz"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
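
// createCollectionTask carries a CreateCollectionRequest together with the
// CreateCollection message header and body that Prepare populates before the
// DDL is executed.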
type createCollectionTask struct {
*Core
Req *milvuspb.CreateCollectionRequest
header *message.CreateCollectionMessageHeader
body *message.CreateCollectionRequest
}
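
// validate checks the request against the shard number limits, the per-DB and
// cluster-wide collection quotas, and the general capacity
// (collections * shards * partitions).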
func (t *createCollectionTask) validate(ctx context.Context) error {
if t.Req == nil {
return errors.New("empty request")
}
Params := paramtable.Get()
// 1. check shard number
shardsNum := t.Req.GetShardsNum()
var cfgMaxShardNum int32
if Params.CommonCfg.PreCreatedTopicEnabled.GetAsBool() {
cfgMaxShardNum = int32(len(Params.CommonCfg.TopicNames.GetAsStrings()))
} else {
cfgMaxShardNum = Params.RootCoordCfg.DmlChannelNum.GetAsInt32()
}
if shardsNum > cfgMaxShardNum {
return fmt.Errorf("shard num (%d) exceeds max configuration (%d)", shardsNum, cfgMaxShardNum)
}
cfgShardLimit := Params.ProxyCfg.MaxShardNum.GetAsInt32()
if shardsNum > cfgShardLimit {
return fmt.Errorf("shard num (%d) exceeds system limit (%d)", shardsNum, cfgShardLimit)
}
// 2. check db-collection capacity
db2CollIDs := t.meta.ListAllAvailCollections(ctx)
if err := t.checkMaxCollectionsPerDB(ctx, db2CollIDs); err != nil {
return err
}
// 3. check total collection number
totalCollections := 0
for _, collIDs := range db2CollIDs {
totalCollections += len(collIDs)
}
maxCollectionNum := Params.QuotaConfig.MaxCollectionNum.GetAsInt()
if totalCollections >= maxCollectionNum {
log.Ctx(ctx).Warn("unable to create collection because the number of collections has reached the limit", zap.Int("max_collection_num", maxCollectionNum))
return merr.WrapErrCollectionNumLimitExceeded(t.Req.GetDbName(), maxCollectionNum)
}
// 4. check collection * shard * partition
var newPartNum int64 = 1
if t.Req.GetNumPartitions() > 0 {
newPartNum = t.Req.GetNumPartitions()
}
return checkGeneralCapacity(ctx, 1, newPartNum, t.Req.GetShardsNum(), t.Core)
}
// checkMaxCollectionsPerDB enforces the per-DB collection number limit; DB properties take precedence over quota configurations for max collections.
func (t *createCollectionTask) checkMaxCollectionsPerDB(ctx context.Context, db2CollIDs map[int64][]int64) error {
Params := paramtable.Get()
collIDs, ok := db2CollIDs[t.header.DbId]
if !ok {
log.Ctx(ctx).Warn("cannot find DB ID", zap.String("collection", t.Req.GetCollectionName()), zap.String("dbName", t.Req.GetDbName()))
return merr.WrapErrDatabaseNotFound(t.Req.GetDbName(), "failed to create collection")
}
db, err := t.meta.GetDatabaseByName(ctx, t.Req.GetDbName(), typeutil.MaxTimestamp)
if err != nil {
log.Ctx(ctx).Warn("failed to get database by name", zap.String("collection", t.Req.GetCollectionName()), zap.String("dbName", t.Req.GetDbName()), zap.Error(err))
return merr.WrapErrDatabaseNotFound(t.Req.GetDbName(), "failed to create collection")
}
check := func(maxColNumPerDB int) error {
if len(collIDs) >= maxColNumPerDB {
log.Ctx(ctx).Warn("unable to create collection because the number of collections has reached the per-DB limit", zap.Int("maxCollectionNumPerDB", maxColNumPerDB))
return merr.WrapErrCollectionNumLimitExceeded(t.Req.GetDbName(), maxColNumPerDB)
}
return nil
}
maxColNumPerDBStr := db.GetProperty(common.DatabaseMaxCollectionsKey)
if maxColNumPerDBStr != "" {
maxColNumPerDB, err := strconv.Atoi(maxColNumPerDBStr)
if err != nil {
log.Ctx(ctx).Warn("parse value of property fail", zap.String("key", common.DatabaseMaxCollectionsKey),
zap.String("value", maxColNumPerDBStr), zap.Error(err))
return fmt.Errorf("parse value of property fail, key:%s, value:%s", common.DatabaseMaxCollectionsKey, maxColNumPerDBStr)
}
return check(maxColNumPerDB)
}
maxColNumPerDB := Params.QuotaConfig.MaxCollectionNumPerDB.GetAsInt()
return check(maxColNumPerDB)
}
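
// checkGeometryDefaultValue verifies that a geometry default value is valid WKT
// that can be re-encoded as WKB.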
func checkGeometryDefaultValue(value string) error {
geomT, err := wkt.Unmarshal(value)
if err != nil {
log.Warn("invalid default value for geometry field", zap.Error(err))
return merr.WrapErrParameterInvalidMsg("invalid default value for geometry field")
}
_, err = wkb.Marshal(geomT, wkb.NDR)
if err != nil {
log.Warn("invalid default value for geometry field", zap.Error(err))
return merr.WrapErrParameterInvalidMsg("invalid default value for geometry field")
}
return nil
}
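
// hasSystemFields reports whether any field in the schema uses one of the given
// reserved system field names.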
func hasSystemFields(schema *schemapb.CollectionSchema, systemFields []string) bool {
for _, f := range schema.GetFields() {
if funcutil.SliceContain(systemFields, f.GetName()) {
return true
}
}
return false
}
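
// validateSchema checks that the schema name matches the request, that field,
// struct-array-field, and default-value definitions are legal, that no reserved
// system field name is used, and that all analyzer params pass validation on a
// streaming node.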
func (t *createCollectionTask) validateSchema(ctx context.Context, schema *schemapb.CollectionSchema) error {
log := log.Ctx(ctx).With(zap.String("collectionName", t.Req.CollectionName))
if t.Req.GetCollectionName() != schema.GetName() {
log.Error("collection name does not match schema name", zap.String("schemaName", schema.Name))
msg := fmt.Sprintf("collection name = %s, schema.Name=%s", t.Req.GetCollectionName(), schema.Name)
return merr.WrapErrParameterInvalid("collection name matches schema name", "don't match", msg)
}
if err := checkFieldSchema(schema.GetFields()); err != nil {
return err
}
// Validate and rewrite timestamptz default values
if err := timestamptz.CheckAndRewriteTimestampTzDefaultValue(schema); err != nil {
return err
}
if err := checkStructArrayFieldSchema(schema.GetStructArrayFields()); err != nil {
return err
}
if hasSystemFields(schema, []string{RowIDFieldName, TimeStampFieldName, MetaFieldName, NamespaceFieldName}) {
log.Error("schema contains system field",
zap.String("RowIDFieldName", RowIDFieldName),
zap.String("TimeStampFieldName", TimeStampFieldName),
zap.String("MetaFieldName", MetaFieldName),
zap.String("NamespaceFieldName", NamespaceFieldName))
msg := fmt.Sprintf("schema contains system field: %s, %s, %s, %s", RowIDFieldName, TimeStampFieldName, MetaFieldName, NamespaceFieldName)
return merr.WrapErrParameterInvalid("schema without system fields", "contains", msg)
}
if err := validateStructArrayFieldDataType(schema.GetStructArrayFields()); err != nil {
return err
}
// check that analyzer params are valid
analyzerInfos := make([]*querypb.AnalyzerInfo, 0)
for _, field := range schema.GetFields() {
err := validateAnalyzer(schema, field, &analyzerInfos)
if err != nil {
return err
}
}
// validate analyzer params on any streaming node
if len(analyzerInfos) > 0 {
resp, err := t.mixCoord.ValidateAnalyzer(t.ctx, &querypb.ValidateAnalyzerRequest{
AnalyzerInfos: analyzerInfos,
})
if err != nil {
return err
}
if err := merr.Error(resp); err != nil {
return err
}
}
return validateFieldDataType(schema.GetFields())
}
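
// assignFieldAndFunctionID allocates field IDs sequentially from
// StartOfUserFieldID (struct-array fields and their sub-fields share the same
// sequence) and resolves each function's input/output field names into the
// newly assigned IDs.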
func (t *createCollectionTask) assignFieldAndFunctionID(schema *schemapb.CollectionSchema) error {
name2id := map[string]int64{}
idx := 0
for _, field := range schema.GetFields() {
field.FieldID = int64(idx + StartOfUserFieldID)
idx++
name2id[field.GetName()] = field.GetFieldID()
}
for _, structArrayField := range schema.GetStructArrayFields() {
structArrayField.FieldID = int64(idx + StartOfUserFieldID)
idx++
for _, field := range structArrayField.GetFields() {
field.FieldID = int64(idx + StartOfUserFieldID)
idx++
// Also register sub-field names in name2id map
name2id[field.GetName()] = field.GetFieldID()
}
}
for fidx, function := range schema.GetFunctions() {
function.InputFieldIds = make([]int64, len(function.InputFieldNames))
function.Id = int64(fidx) + StartOfUserFunctionID
for idx, name := range function.InputFieldNames {
fieldId, ok := name2id[name]
if !ok {
return fmt.Errorf("input field %s of function %s not found", name, function.GetName())
}
function.InputFieldIds[idx] = fieldId
}
function.OutputFieldIds = make([]int64, len(function.OutputFieldNames))
for idx, name := range function.OutputFieldNames {
fieldId, ok := name2id[name]
if !ok {
return fmt.Errorf("output field %s of function %s not found", name, function.GetName())
}
function.OutputFieldIds[idx] = fieldId
}
}
return nil
}
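
// appendDynamicField appends the hidden JSON meta field that backs dynamic
// schema when EnableDynamicField is set.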
func (t *createCollectionTask) appendDynamicField(ctx context.Context, schema *schemapb.CollectionSchema) {
if schema.EnableDynamicField {
schema.Fields = append(schema.Fields, &schemapb.FieldSchema{
Name: MetaFieldName,
Description: "dynamic schema",
DataType: schemapb.DataType_JSON,
IsDynamic: true,
})
log.Ctx(ctx).Info("append dynamic field", zap.String("collection", schema.Name))
}
}
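
// appendConsistencyLevel persists the request's consistency level into the
// collection properties: a valid existing entry is kept, an invalid one is
// overwritten, and the level is appended when absent.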
func (t *createCollectionTask) appendConsistencyLevel() {
if ok, _ := getConsistencyLevel(t.Req.Properties...); ok {
return
}
for _, p := range t.Req.Properties {
if p.GetKey() == common.ConsistencyLevel {
// if there's already a consistency level, overwrite it.
p.Value = strconv.Itoa(int(t.Req.ConsistencyLevel))
return
}
}
// append consistency level into schema properties
t.Req.Properties = append(t.Req.Properties, &commonpb.KeyValuePair{
Key: common.ConsistencyLevel,
Value: strconv.Itoa(int(t.Req.ConsistencyLevel)),
})
}
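
// handleNamespaceField appends the namespace system field as a partition key
// when the namespace property is enabled. Namespace mode requires partition key
// isolation and is incompatible with a user-defined partition key field.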
func (t *createCollectionTask) handleNamespaceField(ctx context.Context, schema *schemapb.CollectionSchema) error {
if !Params.CommonCfg.EnableNamespace.GetAsBool() {
return nil
}
hasIsolation := hasIsolationProperty(t.Req.Properties...)
_, err := typeutil.GetPartitionKeyFieldSchema(schema)
hasPartitionKey := err == nil
enabled, has, err := common.ParseNamespaceProp(t.Req.Properties...)
if err != nil {
return err
}
if !has || !enabled {
return nil
}
if hasIsolation {
iso, err := common.IsPartitionKeyIsolationKvEnabled(t.Req.Properties...)
if err != nil {
return err
}
if !iso {
return merr.WrapErrCollectionIllegalSchema(t.Req.CollectionName,
"partition key isolation must be enabled when namespace is enabled")
}
}
if hasPartitionKey {
return merr.WrapErrParameterInvalidMsg("namespace is not supported with partition key mode")
}
schema.Fields = append(schema.Fields, &schemapb.FieldSchema{
Name: common.NamespaceFieldName,
IsPartitionKey: true,
DataType: schemapb.DataType_VarChar,
TypeParams: []*commonpb.KeyValuePair{
{Key: common.MaxLengthKey, Value: fmt.Sprintf("%d", paramtable.Get().ProxyCfg.MaxVarCharLength.GetAsInt())},
},
})
t.Req.Properties = append(t.Req.Properties, &commonpb.KeyValuePair{
Key: common.PartitionKeyIsolationKey,
Value: "true",
})
log.Ctx(ctx).Info("added namespace field",
zap.String("collectionName", t.Req.CollectionName),
zap.String("fieldName", common.NamespaceFieldName))
return nil
}
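
// hasIsolationProperty reports whether the partition key isolation property is
// present in the given properties.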
func hasIsolationProperty(props ...*commonpb.KeyValuePair) bool {
for _, p := range props {
if p.GetKey() == common.PartitionKeyIsolationKey {
return true
}
}
return false
}
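
// appendSysFields appends the reserved RowID and Timestamp int64 system fields
// to the schema.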
func (t *createCollectionTask) appendSysFields(schema *schemapb.CollectionSchema) {
schema.Fields = append(schema.Fields, &schemapb.FieldSchema{
FieldID: int64(RowIDField),
Name: RowIDFieldName,
IsPrimaryKey: false,
Description: "row id",
DataType: schemapb.DataType_Int64,
})
schema.Fields = append(schema.Fields, &schemapb.FieldSchema{
FieldID: int64(TimeStampField),
Name: TimeStampFieldName,
IsPrimaryKey: false,
Description: "time stamp",
DataType: schemapb.DataType_Int64,
})
}
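
// prepareSchema validates the user schema, appends the consistency level,
// dynamic field, and namespace field as needed, assigns field and function IDs,
// validates the timezone property, and finally appends the system fields.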
func (t *createCollectionTask) prepareSchema(ctx context.Context) error {
if err := t.validateSchema(ctx, t.body.CollectionSchema); err != nil {
return err
}
t.appendConsistencyLevel()
t.appendDynamicField(ctx, t.body.CollectionSchema)
if err := t.handleNamespaceField(ctx, t.body.CollectionSchema); err != nil {
return err
}
if err := t.assignFieldAndFunctionID(t.body.CollectionSchema); err != nil {
return err
}
// Validate timezone
tz, exist := funcutil.TryGetAttrByKeyFromRepeatedKV(common.TimezoneKey, t.Req.GetProperties())
if exist && !timestamptz.IsTimezoneValid(tz) {
return merr.WrapErrParameterInvalidMsg("unknown or invalid IANA Time Zone ID: %s", tz)
}
// Set properties for persistence
t.body.CollectionSchema.Properties = t.Req.GetProperties()
t.body.CollectionSchema.Version = 0
t.appendSysFields(t.body.CollectionSchema)
return nil
}
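
// assignCollectionID allocates a new collection ID and stores it in both the
// message header and body.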
func (t *createCollectionTask) assignCollectionID() error {
var err error
t.header.CollectionId, err = t.idAllocator.AllocOne()
t.body.CollectionID = t.header.CollectionId
return err
}
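
// assignPartitionIDs determines the partition names and allocates their IDs. In
// partition key mode, NumPartitions physical partitions named
// "<defaultPartitionName>_<i>" are created (e.g. "_default_0", "_default_1",
// ..., assuming the default partition name is configured as "_default");
// otherwise a single default partition is created for compatibility with
// versions <= 2.2.8.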
func (t *createCollectionTask) assignPartitionIDs(ctx context.Context) error {
Params := paramtable.Get()
partitionNames := make([]string, 0, t.Req.GetNumPartitions())
defaultPartitionName := Params.CommonCfg.DefaultPartitionName.GetValue()
if _, err := typeutil.GetPartitionKeyFieldSchema(t.body.CollectionSchema); err == nil {
// creating multiple partitions is only allowed in partition key mode
partitionNums := t.Req.GetNumPartitions()
// double check: the number of physical partitions must be greater than 0
if partitionNums <= 0 {
return errors.New("the number of partitions must be greater than 0 when the partition key is used")
}
cfgMaxPartitionNum := Params.RootCoordCfg.MaxPartitionNum.GetAsInt64()
if partitionNums > cfgMaxPartitionNum {
return fmt.Errorf("partition number (%d) exceeds max configuration (%d), collection: %s",
partitionNums, cfgMaxPartitionNum, t.Req.CollectionName)
}
for i := int64(0); i < partitionNums; i++ {
partitionNames = append(partitionNames, fmt.Sprintf("%s_%d", defaultPartitionName, i))
}
} else {
// compatible with old versions <= 2.2.8
partitionNames = append(partitionNames, defaultPartitionName)
}
// allocate partition ids
start, end, err := t.idAllocator.Alloc(uint32(len(partitionNames)))
if err != nil {
return err
}
t.header.PartitionIds = make([]int64, len(partitionNames))
t.body.PartitionIDs = make([]int64, len(partitionNames))
for i := start; i < end; i++ {
t.header.PartitionIds[i-start] = i
t.body.PartitionIDs[i-start] = i
}
t.body.PartitionNames = partitionNames
log.Ctx(ctx).Info("assign partitions when create collection",
zap.String("collectionName", t.Req.GetCollectionName()),
zap.Int64s("partitionIds", t.header.PartitionIds),
zap.Strings("partitionNames", t.body.PartitionNames))
return nil
}
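
// assignChannels allocates one virtual channel per shard from the streaming
// node manager and records both the virtual channels and the physical channels
// derived from them.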
func (t *createCollectionTask) assignChannels(ctx context.Context) error {
vchannels, err := snmanager.StaticStreamingNodeManager.AllocVirtualChannels(ctx, balancer.AllocVChannelParam{
CollectionID: t.header.GetCollectionId(),
Num: int(t.Req.GetShardsNum()),
})
if err != nil {
return err
}
for _, vchannel := range vchannels {
t.body.PhysicalChannelNames = append(t.body.PhysicalChannelNames, funcutil.ToPhysicalChannel(vchannel))
t.body.VirtualChannelNames = append(t.body.VirtualChannelNames, vchannel)
}
return nil
}
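
// Prepare builds the CreateCollection message: it resolves the database,
// inherits the timezone from the database when the request does not set one,
// appends the database's encryption (EZ) properties if any, validates the
// request and schema, and assigns the collection ID, partition IDs, and
// virtual channels.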
func (t *createCollectionTask) Prepare(ctx context.Context) error {
t.body.Base = &commonpb.MsgBase{
MsgType: commonpb.MsgType_CreateCollection,
}
db, err := t.meta.GetDatabaseByName(ctx, t.Req.GetDbName(), typeutil.MaxTimestamp)
if err != nil {
return err
}
// set collection timezone
properties := t.Req.GetProperties()
_, ok := funcutil.TryGetAttrByKeyFromRepeatedKV(common.TimezoneKey, properties)
if !ok {
dbTz, ok2 := funcutil.TryGetAttrByKeyFromRepeatedKV(common.TimezoneKey, db.Properties)
if !ok2 {
dbTz = common.DefaultTimezone
}
timezoneKV := &commonpb.KeyValuePair{Key: common.TimezoneKey, Value: dbTz}
t.Req.Properties = append(properties, timezoneKV)
}
if ezProps := hookutil.GetEzPropByDBProperties(db.Properties); ezProps != nil {
t.Req.Properties = append(t.Req.Properties, ezProps)
}
t.header.DbId = db.ID
t.body.DbID = t.header.DbId
if err := t.validate(ctx); err != nil {
return err
}
if err := t.prepareSchema(ctx); err != nil {
return err
}
if err := t.assignCollectionID(); err != nil {
return err
}
if err := t.assignPartitionIDs(ctx); err != nil {
return err
}
if err := t.assignChannels(ctx); err != nil {
return err
}
return t.validateIfCollectionExists(ctx)
}
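
// validateIfCollectionExists rejects collection names that conflict with an
// existing alias and makes creation idempotent: an existing collection that is
// exactly equal to the new one yields errIgnoredCreateCollection, while a
// same-name collection with different parameters is an error.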
func (t *createCollectionTask) validateIfCollectionExists(ctx context.Context) error {
// Check if the collection name duplicates an alias.
if _, err := t.meta.DescribeAlias(ctx, t.Req.GetDbName(), t.Req.GetCollectionName(), typeutil.MaxTimestamp); err == nil {
err2 := fmt.Errorf("collection name [%s] conflicts with an existing alias, please choose a unique name", t.Req.GetCollectionName())
log.Ctx(ctx).Warn("create collection failed", zap.String("database", t.Req.GetDbName()), zap.Error(err2))
return err2
}
// Check if the collection already exists.
existedCollInfo, err := t.meta.GetCollectionByName(ctx, t.Req.GetDbName(), t.Req.GetCollectionName(), typeutil.MaxTimestamp)
if err == nil {
newCollInfo := newCollectionModel(t.header, t.body, 0)
if equal := existedCollInfo.Equal(*newCollInfo); !equal {
return fmt.Errorf("create duplicate collection with different parameters, collection: %s", t.Req.GetCollectionName())
}
return errIgnoredCreateCollection
}
return nil
}
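
// validateMultiAnalyzerParams validates the multi-analyzer JSON params of a
// field and collects the per-analyzer infos for remote validation. An
// illustrative (hypothetical) params value could look like:
//
//   {
//     "by_field": "lang",
//     "alias": {"zh-CN": "zh"},
//     "analyzers": {"default": {...}, "en": {...}, "zh": {...}}
//   }
//
// by_field must name an existing VarChar field, and analyzers must contain a
// "default" entry used for values that match no other analyzer.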
func validateMultiAnalyzerParams(params string, coll *schemapb.CollectionSchema, fieldSchema *schemapb.FieldSchema, infos *[]*querypb.AnalyzerInfo) error {
var m map[string]json.RawMessage
var analyzerMap map[string]json.RawMessage
var mFieldName string
err := json.Unmarshal([]byte(params), &m)
if err != nil {
return err
}
mfield, ok := m["by_field"]
if !ok {
return fmt.Errorf("multi analyzer params now must set by_field to specify with field decide analyzer")
}
err = json.Unmarshal(mfield, &mFieldName)
if err != nil {
return fmt.Errorf("multi analyzer params by_field must be a string, got: %s", mfield)
}
// check that the by_field target exists
fieldExist := false
for _, field := range coll.GetFields() {
if field.GetName() == mFieldName {
// only string (VarChar) fields are supported for now
if field.GetDataType() != schemapb.DataType_VarChar {
return fmt.Errorf("multi analyzer currently only supports string fields, but field %s is not a string", field.GetName())
}
fieldExist = true
break
}
}
if !fieldExist {
return fmt.Errorf("multi analyzer dependent field %s not exist in collection %s", string(mfield), coll.GetName())
}
if value, ok := m["alias"]; ok {
mapping := map[string]string{}
err = json.Unmarshal(value, &mapping)
if err != nil {
return fmt.Errorf("multi analyzer alias must be string map but now: %s", value)
}
}
analyzers, ok := m["analyzers"]
if !ok {
return fmt.Errorf("multi analyzer params must set analyzers ")
}
err = json.Unmarshal(analyzers, &analyzerMap)
if err != nil {
return fmt.Errorf("unmarshal analyzers failed: %s", err)
}
hasDefault := false
for name, bytes := range analyzerMap {
*infos = append(*infos, &querypb.AnalyzerInfo{
Name: name,
Field: fieldSchema.GetName(),
Params: string(bytes),
})
if name == "default" {
hasDefault = true
}
}
if !hasDefault {
return fmt.Errorf("multi analyzer must set default analyzer for all unknown value")
}
return nil
}
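
// validateAnalyzer checks the analyzer configuration of a single field: fields
// that enable match or feed a BM25 function must enable the analyzer, multi
// analyzer params are mutually exclusive with plain analyzer params, and every
// analyzer found is collected into analyzerInfos for validation.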
func validateAnalyzer(collSchema *schemapb.CollectionSchema, fieldSchema *schemapb.FieldSchema, analyzerInfos *[]*querypb.AnalyzerInfo) error {
h := typeutil.CreateFieldSchemaHelper(fieldSchema)
if !h.EnableMatch() && !typeutil.IsBm25FunctionInputField(collSchema, fieldSchema) {
return nil
}
if !h.EnableAnalyzer() {
return fmt.Errorf("field %s which has enable_match or is input of BM25 function must also enable_analyzer", fieldSchema.Name)
}
if params, ok := h.GetMultiAnalyzerParams(); ok {
if h.EnableMatch() {
return fmt.Errorf("multi analyzer now only support for bm25, but now field %s enable match", fieldSchema.Name)
}
if h.HasAnalyzerParams() {
return fmt.Errorf("field %s analyzer params should be none if has multi analyzer params", fieldSchema.Name)
}
return validateMultiAnalyzerParams(params, collSchema, fieldSchema, analyzerInfos)
}
for _, kv := range fieldSchema.GetTypeParams() {
if kv.GetKey() == "analyzer_params" {
*analyzerInfos = append(*analyzerInfos, &querypb.AnalyzerInfo{
Field: fieldSchema.GetName(),
Params: kv.GetValue(),
})
}
}
// return nil when using the default analyzer
return nil
}