diff --git a/configs/milvus.yaml b/configs/milvus.yaml index 3ea2120c31..e2ed510722 100644 --- a/configs/milvus.yaml +++ b/configs/milvus.yaml @@ -159,7 +159,8 @@ proxy: maxNameLength: 255 # Maximum length of name for a collection or alias maxFieldNum: 256 # Maximum number of fields in a collection maxDimension: 32768 # Maximum dimension of a vector - maxShardNum: 256 # Maximum number of shards in a collection + # It's strongly DISCOURAGED to set `maxShardNum` > 64. + maxShardNum: 64 # Maximum number of shards in a collection maxTaskNum: 1024 # max task number of proxy task queue # please adjust in embedded Milvus: false ginLogging: true # Whether to produce gin logs. diff --git a/internal/rootcoord/create_collection_task.go b/internal/rootcoord/create_collection_task.go index e771866d8e..e9986385c4 100644 --- a/internal/rootcoord/create_collection_task.go +++ b/internal/rootcoord/create_collection_task.go @@ -51,9 +51,18 @@ func (t *createCollectionTask) validate() error { return err } - if t.Req.GetShardsNum() >= maxShardNum { - return fmt.Errorf("shard num (%d) exceeds limit (%d)", t.Req.GetShardsNum(), maxShardNum) + shardsNum := int64(t.Req.GetShardsNum()) + + cfgMaxShardNum := Params.RootCoordCfg.DmlChannelNum + if shardsNum >= cfgMaxShardNum { + return fmt.Errorf("shard num (%d) exceeds max configuration (%d)", shardsNum, cfgMaxShardNum) } + + cfgShardLimit := int64(Params.ProxyCfg.MaxShardNum) + if shardsNum > cfgShardLimit { + return fmt.Errorf("shard num (%d) exceeds system limit (%d)", shardsNum, cfgShardLimit) + } + return nil } diff --git a/internal/rootcoord/create_collection_task_test.go b/internal/rootcoord/create_collection_task_test.go index fa0ffd8e31..b75540c904 100644 --- a/internal/rootcoord/create_collection_task_test.go +++ b/internal/rootcoord/create_collection_task_test.go @@ -40,11 +40,39 @@ func Test_createCollectionTask_validate(t *testing.T) { assert.Error(t, err) }) - t.Run("shard num exceeds limit", func(t *testing.T) { + t.Run("shard num exceeds configuration", func(t *testing.T) { + cfgMaxShardNum := Params.RootCoordCfg.DmlChannelNum + restoreCfg := func() { Params.RootCoordCfg.DmlChannelNum = cfgMaxShardNum } + defer restoreCfg() + + Params.RootCoordCfg.DmlChannelNum = 1 + task := createCollectionTask{ Req: &milvuspb.CreateCollectionRequest{ Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_CreateCollection}, - ShardsNum: maxShardNum + 1, + ShardsNum: 2, + }, + } + err := task.validate() + assert.Error(t, err) + }) + + t.Run("shard num exceeds limit", func(t *testing.T) { + cfgMaxShardNum := Params.RootCoordCfg.DmlChannelNum + cfgShardLimit := Params.ProxyCfg.MaxShardNum + restoreCfg := func() { + Params.RootCoordCfg.DmlChannelNum = cfgMaxShardNum + Params.ProxyCfg.MaxShardNum = cfgShardLimit + } + defer restoreCfg() + + Params.RootCoordCfg.DmlChannelNum = 100 + Params.ProxyCfg.MaxShardNum = 4 + + task := createCollectionTask{ + Req: &milvuspb.CreateCollectionRequest{ + Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_CreateCollection}, + ShardsNum: 8, }, } err := task.validate()