diff --git a/internal/proxy/task.go b/internal/proxy/task.go index 7aa1457fd1..ffd79b67af 100644 --- a/internal/proxy/task.go +++ b/internal/proxy/task.go @@ -1549,11 +1549,6 @@ func (t *loadCollectionTask) PreExecute(ctx context.Context) error { return err } - // To compat with LoadCollcetion before Milvus@2.1 - if t.ReplicaNumber == 0 { - t.ReplicaNumber = 1 - } - return nil } diff --git a/internal/querycoordv2/meta/coordinator_broker.go b/internal/querycoordv2/meta/coordinator_broker.go index cbcb9fced7..2df54688af 100644 --- a/internal/querycoordv2/meta/coordinator_broker.go +++ b/internal/querycoordv2/meta/coordinator_broker.go @@ -30,7 +30,9 @@ import ( "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/indexpb" "github.com/milvus-io/milvus/internal/proto/querypb" + "github.com/milvus-io/milvus/internal/proto/rootcoordpb" "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/commonpbutil" "github.com/milvus-io/milvus/pkg/util/merr" @@ -47,6 +49,8 @@ type Broker interface { GetSegmentInfo(ctx context.Context, segmentID ...UniqueID) (*datapb.GetSegmentInfoResponse, error) GetIndexInfo(ctx context.Context, collectionID UniqueID, segmentID UniqueID) ([]*querypb.FieldIndexInfo, error) GetRecoveryInfoV2(ctx context.Context, collectionID UniqueID, partitionIDs ...UniqueID) ([]*datapb.VchannelInfo, []*datapb.SegmentInfo, error) + DescribeDatabase(ctx context.Context, dbName string) (*rootcoordpb.DescribeDatabaseResponse, error) + GetCollectionLoadInfo(ctx context.Context, collectionID UniqueID) ([]string, int64, error) } type CoordinatorBroker struct { @@ -83,6 +87,48 @@ func (broker *CoordinatorBroker) DescribeCollection(ctx context.Context, collect return resp, nil } +func (broker *CoordinatorBroker) DescribeDatabase(ctx context.Context, dbName string) (*rootcoordpb.DescribeDatabaseResponse, error) { + ctx, cancel := context.WithTimeout(ctx, paramtable.Get().QueryCoordCfg.BrokerTimeout.GetAsDuration(time.Millisecond)) + defer cancel() + + req := &rootcoordpb.DescribeDatabaseRequest{ + Base: commonpbutil.NewMsgBase( + commonpbutil.WithMsgType(commonpb.MsgType_DescribeCollection), + ), + DbName: dbName, + } + resp, err := broker.rootCoord.DescribeDatabase(ctx, req) + if err := merr.CheckRPCCall(resp, err); err != nil { + log.Ctx(ctx).Warn("failed to describe database", zap.Error(err)) + return nil, err + } + return resp, nil +} + +// try to get database level replica_num and resource groups, return (resource_groups, replica_num, error) +func (broker *CoordinatorBroker) GetCollectionLoadInfo(ctx context.Context, collectionID UniqueID) ([]string, int64, error) { + // to do by weiliu1031: querycoord should cache mappings: collectionID->dbName + collectionInfo, err := broker.DescribeCollection(ctx, collectionID) + if err != nil { + return nil, 0, err + } + + dbInfo, err := broker.DescribeDatabase(ctx, collectionInfo.GetDbName()) + if err != nil { + return nil, 0, err + } + replicaNum, err := common.DatabaseLevelReplicaNumber(dbInfo.GetProperties()) + if err != nil { + return nil, 0, err + } + rgs, err := common.DatabaseLevelResourceGroups(dbInfo.GetProperties()) + if err != nil { + return nil, 0, err + } + + return rgs, replicaNum, nil +} + func (broker *CoordinatorBroker) GetPartitions(ctx context.Context, collectionID UniqueID) ([]UniqueID, error) { ctx, cancel := context.WithTimeout(ctx, paramtable.Get().QueryCoordCfg.BrokerTimeout.GetAsDuration(time.Millisecond)) defer cancel() diff --git a/internal/querycoordv2/meta/coordinator_broker_test.go b/internal/querycoordv2/meta/coordinator_broker_test.go index 476a997dd2..778268f7ce 100644 --- a/internal/querycoordv2/meta/coordinator_broker_test.go +++ b/internal/querycoordv2/meta/coordinator_broker_test.go @@ -18,6 +18,7 @@ package meta import ( "context" + "strings" "testing" "github.com/cockroachdb/errors" @@ -32,6 +33,8 @@ import ( "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/indexpb" "github.com/milvus-io/milvus/internal/proto/querypb" + "github.com/milvus-io/milvus/internal/proto/rootcoordpb" + "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" ) @@ -490,6 +493,90 @@ func (s *CoordinatorBrokerDataCoordSuite) TestGetIndexInfo() { }) } +func (s *CoordinatorBrokerRootCoordSuite) TestDescribeDatabase() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + s.Run("normal_case", func() { + s.rootcoord.EXPECT().DescribeDatabase(mock.Anything, mock.Anything). + Return(&rootcoordpb.DescribeDatabaseResponse{ + Status: merr.Success(), + }, nil) + _, err := s.broker.DescribeDatabase(ctx, "fake_db1") + s.NoError(err) + s.resetMock() + }) + + s.Run("rootcoord_return_error", func() { + s.rootcoord.EXPECT().DescribeDatabase(mock.Anything, mock.Anything).Return(nil, errors.New("fake error")) + _, err := s.broker.DescribeDatabase(ctx, "fake_db1") + s.Error(err) + s.resetMock() + }) + + s.Run("rootcoord_return_failure_status", func() { + s.rootcoord.EXPECT().DescribeDatabase(mock.Anything, mock.Anything). + Return(&rootcoordpb.DescribeDatabaseResponse{ + Status: merr.Status(errors.New("fake error")), + }, nil) + _, err := s.broker.DescribeDatabase(ctx, "fake_db1") + s.Error(err) + s.resetMock() + }) + + s.Run("rootcoord_return_unimplemented", func() { + s.rootcoord.EXPECT().DescribeDatabase(mock.Anything, mock.Anything).Return(nil, merr.ErrServiceUnimplemented) + _, err := s.broker.DescribeDatabase(ctx, "fake_db1") + s.Error(err) + s.resetMock() + }) +} + +func (s *CoordinatorBrokerRootCoordSuite) TestGetCollectionLoadInfo() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + s.Run("normal_case", func() { + s.rootcoord.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{ + DbName: "fake_db1", + }, nil) + s.rootcoord.EXPECT().DescribeDatabase(mock.Anything, mock.Anything). + Return(&rootcoordpb.DescribeDatabaseResponse{ + Status: merr.Success(), + Properties: []*commonpb.KeyValuePair{ + { + Key: common.DatabaseReplicaNumber, + Value: "3", + }, + { + Key: common.DatabaseResourceGroups, + Value: strings.Join([]string{"rg1", "rg2"}, ","), + }, + }, + }, nil) + rgs, replicas, err := s.broker.GetCollectionLoadInfo(ctx, 1) + s.NoError(err) + s.Equal(int64(3), replicas) + s.Contains(rgs, "rg1") + s.Contains(rgs, "rg2") + s.resetMock() + }) + + s.Run("props not set", func() { + s.rootcoord.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{ + DbName: "fake_db1", + }, nil) + s.rootcoord.EXPECT().DescribeDatabase(mock.Anything, mock.Anything). + Return(&rootcoordpb.DescribeDatabaseResponse{ + Status: merr.Success(), + Properties: []*commonpb.KeyValuePair{}, + }, nil) + _, _, err := s.broker.GetCollectionLoadInfo(ctx, 1) + s.Error(err) + s.resetMock() + }) +} + func TestCoordinatorBroker(t *testing.T) { suite.Run(t, new(CoordinatorBrokerRootCoordSuite)) suite.Run(t, new(CoordinatorBrokerDataCoordSuite)) diff --git a/internal/querycoordv2/meta/mock_broker.go b/internal/querycoordv2/meta/mock_broker.go index ff35489855..a940aff58b 100644 --- a/internal/querycoordv2/meta/mock_broker.go +++ b/internal/querycoordv2/meta/mock_broker.go @@ -13,6 +13,8 @@ import ( mock "github.com/stretchr/testify/mock" querypb "github.com/milvus-io/milvus/internal/proto/querypb" + + rootcoordpb "github.com/milvus-io/milvus/internal/proto/rootcoordpb" ) // MockBroker is an autogenerated mock type for the Broker type @@ -83,6 +85,123 @@ func (_c *MockBroker_DescribeCollection_Call) RunAndReturn(run func(context.Cont return _c } +// DescribeDatabase provides a mock function with given fields: ctx, dbName +func (_m *MockBroker) DescribeDatabase(ctx context.Context, dbName string) (*rootcoordpb.DescribeDatabaseResponse, error) { + ret := _m.Called(ctx, dbName) + + var r0 *rootcoordpb.DescribeDatabaseResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (*rootcoordpb.DescribeDatabaseResponse, error)); ok { + return rf(ctx, dbName) + } + if rf, ok := ret.Get(0).(func(context.Context, string) *rootcoordpb.DescribeDatabaseResponse); ok { + r0 = rf(ctx, dbName) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*rootcoordpb.DescribeDatabaseResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, dbName) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockBroker_DescribeDatabase_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DescribeDatabase' +type MockBroker_DescribeDatabase_Call struct { + *mock.Call +} + +// DescribeDatabase is a helper method to define mock.On call +// - ctx context.Context +// - dbName string +func (_e *MockBroker_Expecter) DescribeDatabase(ctx interface{}, dbName interface{}) *MockBroker_DescribeDatabase_Call { + return &MockBroker_DescribeDatabase_Call{Call: _e.mock.On("DescribeDatabase", ctx, dbName)} +} + +func (_c *MockBroker_DescribeDatabase_Call) Run(run func(ctx context.Context, dbName string)) *MockBroker_DescribeDatabase_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string)) + }) + return _c +} + +func (_c *MockBroker_DescribeDatabase_Call) Return(_a0 *rootcoordpb.DescribeDatabaseResponse, _a1 error) *MockBroker_DescribeDatabase_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockBroker_DescribeDatabase_Call) RunAndReturn(run func(context.Context, string) (*rootcoordpb.DescribeDatabaseResponse, error)) *MockBroker_DescribeDatabase_Call { + _c.Call.Return(run) + return _c +} + +// GetCollectionLoadInfo provides a mock function with given fields: ctx, collectionID +func (_m *MockBroker) GetCollectionLoadInfo(ctx context.Context, collectionID int64) ([]string, int64, error) { + ret := _m.Called(ctx, collectionID) + + var r0 []string + var r1 int64 + var r2 error + if rf, ok := ret.Get(0).(func(context.Context, int64) ([]string, int64, error)); ok { + return rf(ctx, collectionID) + } + if rf, ok := ret.Get(0).(func(context.Context, int64) []string); ok { + r0 = rf(ctx, collectionID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]string) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, int64) int64); ok { + r1 = rf(ctx, collectionID) + } else { + r1 = ret.Get(1).(int64) + } + + if rf, ok := ret.Get(2).(func(context.Context, int64) error); ok { + r2 = rf(ctx, collectionID) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + +// MockBroker_GetCollectionLoadInfo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetCollectionLoadInfo' +type MockBroker_GetCollectionLoadInfo_Call struct { + *mock.Call +} + +// GetCollectionLoadInfo is a helper method to define mock.On call +// - ctx context.Context +// - collectionID int64 +func (_e *MockBroker_Expecter) GetCollectionLoadInfo(ctx interface{}, collectionID interface{}) *MockBroker_GetCollectionLoadInfo_Call { + return &MockBroker_GetCollectionLoadInfo_Call{Call: _e.mock.On("GetCollectionLoadInfo", ctx, collectionID)} +} + +func (_c *MockBroker_GetCollectionLoadInfo_Call) Run(run func(ctx context.Context, collectionID int64)) *MockBroker_GetCollectionLoadInfo_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(int64)) + }) + return _c +} + +func (_c *MockBroker_GetCollectionLoadInfo_Call) Return(_a0 []string, _a1 int64, _a2 error) *MockBroker_GetCollectionLoadInfo_Call { + _c.Call.Return(_a0, _a1, _a2) + return _c +} + +func (_c *MockBroker_GetCollectionLoadInfo_Call) RunAndReturn(run func(context.Context, int64) ([]string, int64, error)) *MockBroker_GetCollectionLoadInfo_Call { + _c.Call.Return(run) + return _c +} + // GetIndexInfo provides a mock function with given fields: ctx, collectionID, segmentID func (_m *MockBroker) GetIndexInfo(ctx context.Context, collectionID int64, segmentID int64) ([]*querypb.FieldIndexInfo, error) { ret := _m.Called(ctx, collectionID, segmentID) diff --git a/internal/querycoordv2/server_test.go b/internal/querycoordv2/server_test.go index f71172fd89..78c2fdb89b 100644 --- a/internal/querycoordv2/server_test.go +++ b/internal/querycoordv2/server_test.go @@ -436,17 +436,19 @@ func (suite *ServerSuite) loadAll() { for _, collection := range suite.collections { if suite.loadTypes[collection] == querypb.LoadType_LoadCollection { req := &querypb.LoadCollectionRequest{ - CollectionID: collection, - ReplicaNumber: suite.replicaNumber[collection], + CollectionID: collection, + ReplicaNumber: suite.replicaNumber[collection], + ResourceGroups: []string{meta.DefaultResourceGroupName}, } resp, err := suite.server.LoadCollection(ctx, req) suite.NoError(err) suite.Equal(commonpb.ErrorCode_Success, resp.ErrorCode) } else { req := &querypb.LoadPartitionsRequest{ - CollectionID: collection, - PartitionIDs: suite.partitions[collection], - ReplicaNumber: suite.replicaNumber[collection], + CollectionID: collection, + PartitionIDs: suite.partitions[collection], + ReplicaNumber: suite.replicaNumber[collection], + ResourceGroups: []string{meta.DefaultResourceGroupName}, } resp, err := suite.server.LoadPartitions(ctx, req) suite.NoError(err) diff --git a/internal/querycoordv2/services.go b/internal/querycoordv2/services.go index dea9817a27..b64f09921f 100644 --- a/internal/querycoordv2/services.go +++ b/internal/querycoordv2/services.go @@ -215,6 +215,24 @@ func (s *Server) LoadCollection(ctx context.Context, req *querypb.LoadCollection return merr.Status(err), nil } + if req.GetReplicaNumber() <= 0 || len(req.GetResourceGroups()) == 0 { + // when replica number or resource groups is not set, use database level config + rgs, replicas, err := s.broker.GetCollectionLoadInfo(ctx, req.GetCollectionID()) + if err != nil { + log.Warn("failed to get data base level load info", zap.Error(err)) + } + + if req.GetReplicaNumber() <= 0 { + log.Info("load collection use database level replica number", zap.Int64("databaseLevelReplicaNum", replicas)) + req.ReplicaNumber = int32(replicas) + } + + if len(req.GetResourceGroups()) == 0 { + log.Info("load collection use database level resource groups", zap.Strings("databaseLevelResourceGroups", rgs)) + req.ResourceGroups = rgs + } + } + if err := s.checkResourceGroup(req.GetCollectionID(), req.GetResourceGroups()); err != nil { msg := "failed to load collection" log.Warn(msg, zap.Error(err)) @@ -316,6 +334,24 @@ func (s *Server) LoadPartitions(ctx context.Context, req *querypb.LoadPartitions return merr.Status(err), nil } + if req.GetReplicaNumber() <= 0 || len(req.GetResourceGroups()) == 0 { + // when replica number or resource groups is not set, use database level config + rgs, replicas, err := s.broker.GetCollectionLoadInfo(ctx, req.GetCollectionID()) + if err != nil { + log.Warn("failed to get data base level load info", zap.Error(err)) + } + + if req.GetReplicaNumber() <= 0 { + log.Info("load collection use database level replica number", zap.Int64("databaseLevelReplicaNum", replicas)) + req.ReplicaNumber = int32(replicas) + } + + if len(req.GetResourceGroups()) == 0 { + log.Info("load collection use database level resource groups", zap.Strings("databaseLevelResourceGroups", rgs)) + req.ResourceGroups = rgs + } + } + if err := s.checkResourceGroup(req.GetCollectionID(), req.GetResourceGroups()); err != nil { msg := "failed to load partitions" log.Warn(msg, zap.Error(err)) diff --git a/internal/querycoordv2/services_test.go b/internal/querycoordv2/services_test.go index 744004fd8f..e4fb877d01 100644 --- a/internal/querycoordv2/services_test.go +++ b/internal/querycoordv2/services_test.go @@ -207,6 +207,8 @@ func (suite *ServiceSuite) SetupTest() { } suite.server.UpdateStateCode(commonpb.StateCode_Healthy) + + suite.broker.EXPECT().GetCollectionLoadInfo(mock.Anything, mock.Anything).Return([]string{meta.DefaultResourceGroupName}, 1, nil).Maybe() } func (suite *ServiceSuite) TestShowCollections() { diff --git a/pkg/common/common.go b/pkg/common/common.go index 2b9ebc4d82..ea148b03b7 100644 --- a/pkg/common/common.go +++ b/pkg/common/common.go @@ -18,6 +18,8 @@ package common import ( "encoding/binary" + "fmt" + "strconv" "strings" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" @@ -134,6 +136,10 @@ const ( CollectionDiskQuotaKey = "collection.diskProtection.diskQuota.mb" PartitionDiskQuotaKey = "partition.diskProtection.diskQuota.mb" + + // database level properties + DatabaseReplicaNumber = "database.replica.number" + DatabaseResourceGroups = "database.resource_groups" ) // common properties @@ -205,3 +211,38 @@ const ( // LatestVerision is the magic number for watch latest revision LatestRevision = int64(-1) ) + +func DatabaseLevelReplicaNumber(kvs []*commonpb.KeyValuePair) (int64, error) { + for _, kv := range kvs { + if kv.Key == DatabaseReplicaNumber { + replicaNum, err := strconv.ParseInt(kv.Value, 10, 64) + if err != nil { + return 0, fmt.Errorf("invalid database property: [key=%s] [value=%s]", kv.Key, kv.Value) + } + + return replicaNum, nil + } + } + + return 0, fmt.Errorf("database property not found: %s", DatabaseReplicaNumber) +} + +func DatabaseLevelResourceGroups(kvs []*commonpb.KeyValuePair) ([]string, error) { + for _, kv := range kvs { + if kv.Key == DatabaseResourceGroups { + invalidPropValue := fmt.Errorf("invalid database property: [key=%s] [value=%s]", kv.Key, kv.Value) + if len(kv.Value) == 0 { + return nil, invalidPropValue + } + + rgs := strings.Split(kv.Value, ",") + if len(rgs) == 0 { + return nil, invalidPropValue + } + + return rgs, nil + } + } + + return nil, fmt.Errorf("database property not found: %s", DatabaseResourceGroups) +} diff --git a/pkg/common/common_test.go b/pkg/common/common_test.go index 7228b1b6ab..2dc31e33fb 100644 --- a/pkg/common/common_test.go +++ b/pkg/common/common_test.go @@ -1,9 +1,12 @@ package common import ( + "strings" "testing" "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" ) func TestIsSystemField(t *testing.T) { @@ -38,3 +41,50 @@ func TestIsSystemField(t *testing.T) { }) } } + +func TestDatabaseProperties(t *testing.T) { + props := []*commonpb.KeyValuePair{ + { + Key: DatabaseReplicaNumber, + Value: "3", + }, + { + Key: DatabaseResourceGroups, + Value: strings.Join([]string{"rg1", "rg2"}, ","), + }, + } + + replicaNum, err := DatabaseLevelReplicaNumber(props) + assert.NoError(t, err) + assert.Equal(t, int64(3), replicaNum) + + rgs, err := DatabaseLevelResourceGroups(props) + assert.NoError(t, err) + assert.Contains(t, rgs, "rg1") + assert.Contains(t, rgs, "rg2") + + // test prop not found + _, err = DatabaseLevelReplicaNumber(nil) + assert.Error(t, err) + + _, err = DatabaseLevelResourceGroups(nil) + assert.Error(t, err) + + // test invalid prop value + + props = []*commonpb.KeyValuePair{ + { + Key: DatabaseReplicaNumber, + Value: "xxxx", + }, + { + Key: DatabaseResourceGroups, + Value: "", + }, + } + _, err = DatabaseLevelReplicaNumber(props) + assert.Error(t, err) + + _, err = DatabaseLevelResourceGroups(props) + assert.Error(t, err) +} diff --git a/tests/integration/replicas/load/load_test.go b/tests/integration/replicas/load/load_test.go new file mode 100644 index 0000000000..837a634c53 --- /dev/null +++ b/tests/integration/replicas/load/load_test.go @@ -0,0 +1,187 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package balance + +import ( + "context" + "fmt" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/suite" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus-proto/go-api/v2/rgpb" + "github.com/milvus-io/milvus/internal/proto/querypb" + "github.com/milvus-io/milvus/internal/querycoordv2/meta" + "github.com/milvus-io/milvus/pkg/common" + "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/tests/integration" +) + +const ( + dim = 128 + dbName = "" + collectionName = "test_load_collection" +) + +type LoadTestSuite struct { + integration.MiniClusterSuite +} + +func (s *LoadTestSuite) SetupSuite() { + paramtable.Init() + paramtable.Get().Save(paramtable.Get().QueryCoordCfg.BalanceCheckInterval.Key, "1000") + paramtable.Get().Save(paramtable.Get().QueryNodeCfg.GracefulStopTimeout.Key, "1") + + s.Require().NoError(s.SetupEmbedEtcd()) +} + +func (s *LoadTestSuite) loadCollection(collectionName string, replica int, rgs []string) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // load + loadStatus, err := s.Cluster.Proxy.LoadCollection(ctx, &milvuspb.LoadCollectionRequest{ + DbName: dbName, + CollectionName: collectionName, + ReplicaNumber: int32(replica), + ResourceGroups: rgs, + }) + s.NoError(err) + s.True(merr.Ok(loadStatus)) + s.WaitForLoad(ctx, collectionName) +} + +func (s *LoadTestSuite) releaseCollection(collectionName string) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // load + status, err := s.Cluster.Proxy.ReleaseCollection(ctx, &milvuspb.ReleaseCollectionRequest{ + DbName: dbName, + CollectionName: collectionName, + }) + s.NoError(err) + s.True(merr.Ok(status)) +} + +func (s *LoadTestSuite) TestLoadWithDatabaseLevelConfig() { + ctx := context.Background() + s.CreateCollectionWithConfiguration(ctx, &integration.CreateCollectionConfig{ + DBName: dbName, + Dim: dim, + CollectionName: collectionName, + ChannelNum: 1, + SegmentNum: 3, + RowNumPerSegment: 2000, + }) + + // prepare resource groups + rgNum := 3 + rgs := make([]string, 0) + for i := 0; i < rgNum; i++ { + rgs = append(rgs, fmt.Sprintf("rg_%d", i)) + s.Cluster.QueryCoord.CreateResourceGroup(ctx, &milvuspb.CreateResourceGroupRequest{ + ResourceGroup: rgs[i], + Config: &rgpb.ResourceGroupConfig{ + Requests: &rgpb.ResourceGroupLimit{ + NodeNum: 1, + }, + Limits: &rgpb.ResourceGroupLimit{ + NodeNum: 1, + }, + + TransferFrom: []*rgpb.ResourceGroupTransfer{ + { + ResourceGroup: meta.DefaultResourceGroupName, + }, + }, + TransferTo: []*rgpb.ResourceGroupTransfer{ + { + ResourceGroup: meta.DefaultResourceGroupName, + }, + }, + }, + }) + } + + resp, err := s.Cluster.QueryCoord.ListResourceGroups(ctx, &milvuspb.ListResourceGroupsRequest{}) + s.NoError(err) + s.True(merr.Ok(resp.GetStatus())) + s.Len(resp.GetResourceGroups(), rgNum+1) + + for i := 1; i < rgNum; i++ { + s.Cluster.AddQueryNode() + } + + s.Eventually(func() bool { + matchCounter := 0 + for _, rg := range rgs { + resp1, err := s.Cluster.QueryCoord.DescribeResourceGroup(ctx, &querypb.DescribeResourceGroupRequest{ + ResourceGroup: rg, + }) + s.NoError(err) + s.True(merr.Ok(resp.GetStatus())) + if len(resp1.ResourceGroup.Nodes) == 1 { + matchCounter += 1 + } + } + return matchCounter == rgNum + }, 30*time.Second, time.Second) + + status, err := s.Cluster.Proxy.AlterDatabase(ctx, &milvuspb.AlterDatabaseRequest{ + DbName: "default", + Properties: []*commonpb.KeyValuePair{ + { + Key: common.DatabaseReplicaNumber, + Value: "3", + }, + { + Key: common.DatabaseResourceGroups, + Value: strings.Join(rgs, ","), + }, + }, + }) + s.NoError(err) + s.True(merr.Ok(status)) + + resp1, err := s.Cluster.Proxy.DescribeDatabase(ctx, &milvuspb.DescribeDatabaseRequest{ + DbName: "default", + }) + s.NoError(err) + s.True(merr.Ok(resp1.Status)) + s.Len(resp1.GetProperties(), 2) + + // load collection without specified replica and rgs + s.loadCollection(collectionName, 0, nil) + resp2, err := s.Cluster.Proxy.GetReplicas(ctx, &milvuspb.GetReplicasRequest{ + DbName: dbName, + CollectionName: collectionName, + }) + s.NoError(err) + s.True(merr.Ok(resp2.Status)) + s.Len(resp2.GetReplicas(), 3) + s.releaseCollection(collectionName) +} + +func TestReplicas(t *testing.T) { + suite.Run(t, new(LoadTestSuite)) +}