diff --git a/Makefile b/Makefile index 6f3a0c1fb2..8eb4d1fde8 100644 --- a/Makefile +++ b/Makefile @@ -457,6 +457,7 @@ generate-mockery-datacoord: getdeps $(INSTALL_PATH)/mockery --name=CompactionMeta --dir=internal/datacoord --filename=mock_compaction_meta.go --output=internal/datacoord --structname=MockCompactionMeta --with-expecter --inpackage $(INSTALL_PATH)/mockery --name=Scheduler --dir=internal/datacoord --filename=mock_scheduler.go --output=internal/datacoord --structname=MockScheduler --with-expecter --inpackage $(INSTALL_PATH)/mockery --name=ChannelManager --dir=internal/datacoord --filename=mock_channelmanager.go --output=internal/datacoord --structname=MockChannelManager --with-expecter --inpackage + $(INSTALL_PATH)/mockery --name=Broker --dir=internal/datacoord/broker --filename=mock_coordinator_broker.go --output=internal/datacoord/broker --structname=MockBroker --with-expecter --inpackage generate-mockery-datanode: getdeps $(INSTALL_PATH)/mockery --name=Allocator --dir=$(PWD)/internal/datanode/allocator --output=$(PWD)/internal/datanode/allocator --filename=mock_allocator.go --with-expecter --structname=MockAllocator --outpkg=allocator --inpackage diff --git a/internal/datacoord/broker/mock_coordinator_broker.go b/internal/datacoord/broker/mock_coordinator_broker.go index 48854205f2..c0e817b6ad 100644 --- a/internal/datacoord/broker/mock_coordinator_broker.go +++ b/internal/datacoord/broker/mock_coordinator_broker.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.30.1. DO NOT EDIT. +// Code generated by mockery v2.32.4. DO NOT EDIT. package broker @@ -77,6 +77,59 @@ func (_c *MockBroker_DescribeCollectionInternal_Call) RunAndReturn(run func(cont return _c } +// GetDatabaseID provides a mock function with given fields: ctx, dbName +func (_m *MockBroker) GetDatabaseID(ctx context.Context, dbName string) (int64, error) { + ret := _m.Called(ctx, dbName) + + var r0 int64 + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (int64, error)); ok { + return rf(ctx, dbName) + } + if rf, ok := ret.Get(0).(func(context.Context, string) int64); ok { + r0 = rf(ctx, dbName) + } else { + r0 = ret.Get(0).(int64) + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, dbName) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockBroker_GetDatabaseID_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetDatabaseID' +type MockBroker_GetDatabaseID_Call struct { + *mock.Call +} + +// GetDatabaseID is a helper method to define mock.On call +// - ctx context.Context +// - dbName string +func (_e *MockBroker_Expecter) GetDatabaseID(ctx interface{}, dbName interface{}) *MockBroker_GetDatabaseID_Call { + return &MockBroker_GetDatabaseID_Call{Call: _e.mock.On("GetDatabaseID", ctx, dbName)} +} + +func (_c *MockBroker_GetDatabaseID_Call) Run(run func(ctx context.Context, dbName string)) *MockBroker_GetDatabaseID_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string)) + }) + return _c +} + +func (_c *MockBroker_GetDatabaseID_Call) Return(_a0 int64, _a1 error) *MockBroker_GetDatabaseID_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockBroker_GetDatabaseID_Call) RunAndReturn(run func(context.Context, string) (int64, error)) *MockBroker_GetDatabaseID_Call { + _c.Call.Return(run) + return _c +} + // HasCollection provides a mock function with given fields: ctx, collectionID func (_m *MockBroker) HasCollection(ctx context.Context, collectionID 
int64) (bool, error) { ret := _m.Called(ctx, collectionID) diff --git a/internal/datacoord/import_util.go b/internal/datacoord/import_util.go index d94350768d..0dabfd524d 100644 --- a/internal/datacoord/import_util.go +++ b/internal/datacoord/import_util.go @@ -263,7 +263,7 @@ func CheckDiskQuota(job ImportJob, meta *meta, imeta ImportMeta) (int64, error) } err := merr.WrapErrServiceQuotaExceeded("disk quota exceeded, please allocate more resources") - totalUsage, collectionsUsage := meta.GetCollectionBinlogSize() + totalUsage, collectionsUsage, _ := meta.GetCollectionBinlogSize() tasks := imeta.GetTaskBy(WithJob(job.GetJobID()), WithType(PreImportTaskType)) files := make([]*datapb.ImportFileStats, 0) diff --git a/internal/datacoord/meta.go b/internal/datacoord/meta.go index fa885abade..4f7eb0c176 100644 --- a/internal/datacoord/meta.go +++ b/internal/datacoord/meta.go @@ -85,6 +85,7 @@ type collectionInfo struct { Properties map[string]string CreatedAt Timestamp DatabaseName string + DatabaseID int64 } // NewMeta creates meta from provided `kv.TxnKV` @@ -200,6 +201,7 @@ func (m *meta) GetClonedCollectionInfo(collectionID UniqueID) *collectionInfo { StartPositions: common.CloneKeyDataPairs(coll.StartPositions), Properties: clonedProperties, DatabaseName: coll.DatabaseName, + DatabaseID: coll.DatabaseID, } return cloneColl @@ -257,10 +259,11 @@ func (m *meta) GetNumRowsOfCollection(collectionID UniqueID) int64 { } // GetCollectionBinlogSize returns the total binlog size and binlog size of collections. -func (m *meta) GetCollectionBinlogSize() (int64, map[UniqueID]int64) { +func (m *meta) GetCollectionBinlogSize() (int64, map[UniqueID]int64, map[UniqueID]map[UniqueID]int64) { m.RLock() defer m.RUnlock() collectionBinlogSize := make(map[UniqueID]int64) + partitionBinlogSize := make(map[UniqueID]map[UniqueID]int64) collectionRowsNum := make(map[UniqueID]map[commonpb.SegmentState]int64) segments := m.segments.GetSegments() var total int64 @@ -270,6 +273,13 @@ func (m *meta) GetCollectionBinlogSize() (int64, map[UniqueID]int64) { total += segmentSize collectionBinlogSize[segment.GetCollectionID()] += segmentSize + partBinlogSize, ok := partitionBinlogSize[segment.GetCollectionID()] + if !ok { + partBinlogSize = make(map[int64]int64) + partitionBinlogSize[segment.GetCollectionID()] = partBinlogSize + } + partBinlogSize[segment.GetPartitionID()] += segmentSize + coll, ok := m.collections[segment.GetCollectionID()] if ok { metrics.DataCoordStoredBinlogSize.WithLabelValues(coll.DatabaseName, @@ -294,7 +304,7 @@ func (m *meta) GetCollectionBinlogSize() (int64, map[UniqueID]int64) { } } } - return total, collectionBinlogSize + return total, collectionBinlogSize, partitionBinlogSize } func (m *meta) GetAllCollectionNumRows() map[int64]int64 { diff --git a/internal/datacoord/meta_test.go b/internal/datacoord/meta_test.go index 62e277262d..4f6724fd60 100644 --- a/internal/datacoord/meta_test.go +++ b/internal/datacoord/meta_test.go @@ -603,13 +603,13 @@ func TestMeta_Basic(t *testing.T) { assert.NoError(t, err) // check TotalBinlogSize - total, collectionBinlogSize := meta.GetCollectionBinlogSize() + total, collectionBinlogSize, _ := meta.GetCollectionBinlogSize() assert.Len(t, collectionBinlogSize, 1) assert.Equal(t, int64(size0+size1), collectionBinlogSize[collID]) assert.Equal(t, int64(size0+size1), total) meta.collections[collID] = collInfo - total, collectionBinlogSize = meta.GetCollectionBinlogSize() + total, collectionBinlogSize, _ = meta.GetCollectionBinlogSize() assert.Len(t, 
collectionBinlogSize, 1) assert.Equal(t, int64(size0+size1), collectionBinlogSize[collID]) assert.Equal(t, int64(size0+size1), total) diff --git a/internal/datacoord/metrics_info.go b/internal/datacoord/metrics_info.go index 61bc46bae2..1085a08f4f 100644 --- a/internal/datacoord/metrics_info.go +++ b/internal/datacoord/metrics_info.go @@ -37,10 +37,11 @@ import ( // getQuotaMetrics returns DataCoordQuotaMetrics. func (s *Server) getQuotaMetrics() *metricsinfo.DataCoordQuotaMetrics { - total, colSizes := s.meta.GetCollectionBinlogSize() + total, colSizes, partSizes := s.meta.GetCollectionBinlogSize() return &metricsinfo.DataCoordQuotaMetrics{ TotalBinlogSize: total, CollectionBinlogSize: colSizes, + PartitionsBinlogSize: partSizes, } } diff --git a/internal/datacoord/mock_test.go b/internal/datacoord/mock_test.go index c6a16b180f..a5e77c6e30 100644 --- a/internal/datacoord/mock_test.go +++ b/internal/datacoord/mock_test.go @@ -329,6 +329,15 @@ type mockRootCoordClient struct { cnt int64 } +func (m *mockRootCoordClient) DescribeDatabase(ctx context.Context, in *rootcoordpb.DescribeDatabaseRequest, opts ...grpc.CallOption) (*rootcoordpb.DescribeDatabaseResponse, error) { + return &rootcoordpb.DescribeDatabaseResponse{ + Status: merr.Success(), + DbID: 1, + DbName: "default", + CreatedTimestamp: 1, + }, nil +} + func (m *mockRootCoordClient) Close() error { // TODO implement me panic("implement me") diff --git a/internal/datacoord/server.go b/internal/datacoord/server.go index 80f922dd27..d6042f1927 100644 --- a/internal/datacoord/server.go +++ b/internal/datacoord/server.go @@ -1159,6 +1159,7 @@ func (s *Server) loadCollectionFromRootCoord(ctx context.Context, collectionID i Properties: properties, CreatedAt: resp.GetCreatedTimestamp(), DatabaseName: resp.GetDbName(), + DatabaseID: resp.GetDbId(), } s.meta.AddCollection(collInfo) return nil diff --git a/internal/datacoord/services.go b/internal/datacoord/services.go index 44c5bcafc9..74305fc50c 100644 --- a/internal/datacoord/services.go +++ b/internal/datacoord/services.go @@ -1551,6 +1551,7 @@ func (s *Server) BroadcastAlteredCollection(ctx context.Context, req *datapb.Alt Partitions: req.GetPartitionIDs(), StartPositions: req.GetStartPositions(), Properties: properties, + DatabaseID: req.GetDbID(), } s.meta.AddCollection(collInfo) return merr.Success(), nil diff --git a/internal/datanode/rate_collector.go b/internal/datanode/rate_collector.go index b7052c3bb1..10e76ea548 100644 --- a/internal/datanode/rate_collector.go +++ b/internal/datanode/rate_collector.go @@ -47,7 +47,7 @@ func initGlobalRateCollector() error { // newRateCollector returns a new rateCollector. 
func newRateCollector() (*rateCollector, error) { - rc, err := ratelimitutil.NewRateCollector(ratelimitutil.DefaultWindow, ratelimitutil.DefaultGranularity) + rc, err := ratelimitutil.NewRateCollector(ratelimitutil.DefaultWindow, ratelimitutil.DefaultGranularity, false) if err != nil { return nil, err } diff --git a/internal/distributed/rootcoord/client/client.go b/internal/distributed/rootcoord/client/client.go index 21a0a22da2..13c70b5a00 100644 --- a/internal/distributed/rootcoord/client/client.go +++ b/internal/distributed/rootcoord/client/client.go @@ -648,3 +648,14 @@ func (c *Client) ListDatabases(ctx context.Context, in *milvuspb.ListDatabasesRe } return ret.(*milvuspb.ListDatabasesResponse), err } + +func (c *Client) DescribeDatabase(ctx context.Context, req *rootcoordpb.DescribeDatabaseRequest, opts ...grpc.CallOption) (*rootcoordpb.DescribeDatabaseResponse, error) { + req = typeutil.Clone(req) + commonpbutil.UpdateMsgBase( + req.GetBase(), + commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID(), commonpbutil.WithTargetID(c.grpcClient.GetNodeID())), + ) + return wrapGrpcCall(ctx, c, func(client rootcoordpb.RootCoordClient) (*rootcoordpb.DescribeDatabaseResponse, error) { + return client.DescribeDatabase(ctx, req) + }) +} diff --git a/internal/distributed/rootcoord/service.go b/internal/distributed/rootcoord/service.go index 07ff7602bd..2d6e3f6155 100644 --- a/internal/distributed/rootcoord/service.go +++ b/internal/distributed/rootcoord/service.go @@ -77,6 +77,10 @@ type Server struct { newQueryCoordClient func() types.QueryCoordClient } +func (s *Server) DescribeDatabase(ctx context.Context, request *rootcoordpb.DescribeDatabaseRequest) (*rootcoordpb.DescribeDatabaseResponse, error) { + return s.rootCoord.DescribeDatabase(ctx, request) +} + func (s *Server) CreateDatabase(ctx context.Context, request *milvuspb.CreateDatabaseRequest) (*commonpb.Status, error) { return s.rootCoord.CreateDatabase(ctx, request) } diff --git a/internal/mocks/mock_rootcoord.go b/internal/mocks/mock_rootcoord.go index 1e6b629af1..34dc55c0ed 100644 --- a/internal/mocks/mock_rootcoord.go +++ b/internal/mocks/mock_rootcoord.go @@ -861,6 +861,61 @@ func (_c *RootCoord_DescribeCollectionInternal_Call) RunAndReturn(run func(conte return _c } +// DescribeDatabase provides a mock function with given fields: _a0, _a1 +func (_m *RootCoord) DescribeDatabase(_a0 context.Context, _a1 *rootcoordpb.DescribeDatabaseRequest) (*rootcoordpb.DescribeDatabaseResponse, error) { + ret := _m.Called(_a0, _a1) + + var r0 *rootcoordpb.DescribeDatabaseResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *rootcoordpb.DescribeDatabaseRequest) (*rootcoordpb.DescribeDatabaseResponse, error)); ok { + return rf(_a0, _a1) + } + if rf, ok := ret.Get(0).(func(context.Context, *rootcoordpb.DescribeDatabaseRequest) *rootcoordpb.DescribeDatabaseResponse); ok { + r0 = rf(_a0, _a1) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*rootcoordpb.DescribeDatabaseResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *rootcoordpb.DescribeDatabaseRequest) error); ok { + r1 = rf(_a0, _a1) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// RootCoord_DescribeDatabase_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DescribeDatabase' +type RootCoord_DescribeDatabase_Call struct { + *mock.Call +} + +// DescribeDatabase is a helper method to define mock.On call +// - _a0 context.Context +// - _a1 *rootcoordpb.DescribeDatabaseRequest +func (_e 
*RootCoord_Expecter) DescribeDatabase(_a0 interface{}, _a1 interface{}) *RootCoord_DescribeDatabase_Call { + return &RootCoord_DescribeDatabase_Call{Call: _e.mock.On("DescribeDatabase", _a0, _a1)} +} + +func (_c *RootCoord_DescribeDatabase_Call) Run(run func(_a0 context.Context, _a1 *rootcoordpb.DescribeDatabaseRequest)) *RootCoord_DescribeDatabase_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*rootcoordpb.DescribeDatabaseRequest)) + }) + return _c +} + +func (_c *RootCoord_DescribeDatabase_Call) Return(_a0 *rootcoordpb.DescribeDatabaseResponse, _a1 error) *RootCoord_DescribeDatabase_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *RootCoord_DescribeDatabase_Call) RunAndReturn(run func(context.Context, *rootcoordpb.DescribeDatabaseRequest) (*rootcoordpb.DescribeDatabaseResponse, error)) *RootCoord_DescribeDatabase_Call { + _c.Call.Return(run) + return _c +} + // DropAlias provides a mock function with given fields: _a0, _a1 func (_m *RootCoord) DropAlias(_a0 context.Context, _a1 *milvuspb.DropAliasRequest) (*commonpb.Status, error) { ret := _m.Called(_a0, _a1) diff --git a/internal/mocks/mock_rootcoord_client.go b/internal/mocks/mock_rootcoord_client.go index e9ee4246e4..83852dc4fc 100644 --- a/internal/mocks/mock_rootcoord_client.go +++ b/internal/mocks/mock_rootcoord_client.go @@ -1124,6 +1124,76 @@ func (_c *MockRootCoordClient_DescribeCollectionInternal_Call) RunAndReturn(run return _c } +// DescribeDatabase provides a mock function with given fields: ctx, in, opts +func (_m *MockRootCoordClient) DescribeDatabase(ctx context.Context, in *rootcoordpb.DescribeDatabaseRequest, opts ...grpc.CallOption) (*rootcoordpb.DescribeDatabaseResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *rootcoordpb.DescribeDatabaseResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *rootcoordpb.DescribeDatabaseRequest, ...grpc.CallOption) (*rootcoordpb.DescribeDatabaseResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *rootcoordpb.DescribeDatabaseRequest, ...grpc.CallOption) *rootcoordpb.DescribeDatabaseResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*rootcoordpb.DescribeDatabaseResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *rootcoordpb.DescribeDatabaseRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockRootCoordClient_DescribeDatabase_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DescribeDatabase' +type MockRootCoordClient_DescribeDatabase_Call struct { + *mock.Call +} + +// DescribeDatabase is a helper method to define mock.On call +// - ctx context.Context +// - in *rootcoordpb.DescribeDatabaseRequest +// - opts ...grpc.CallOption +func (_e *MockRootCoordClient_Expecter) DescribeDatabase(ctx interface{}, in interface{}, opts ...interface{}) *MockRootCoordClient_DescribeDatabase_Call { + return &MockRootCoordClient_DescribeDatabase_Call{Call: _e.mock.On("DescribeDatabase", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockRootCoordClient_DescribeDatabase_Call) Run(run func(ctx context.Context, in *rootcoordpb.DescribeDatabaseRequest, opts ...grpc.CallOption)) *MockRootCoordClient_DescribeDatabase_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*rootcoordpb.DescribeDatabaseRequest), variadicArgs...) + }) + return _c +} + +func (_c *MockRootCoordClient_DescribeDatabase_Call) Return(_a0 *rootcoordpb.DescribeDatabaseResponse, _a1 error) *MockRootCoordClient_DescribeDatabase_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockRootCoordClient_DescribeDatabase_Call) RunAndReturn(run func(context.Context, *rootcoordpb.DescribeDatabaseRequest, ...grpc.CallOption) (*rootcoordpb.DescribeDatabaseResponse, error)) *MockRootCoordClient_DescribeDatabase_Call { + _c.Call.Return(run) + return _c +} + // DropAlias provides a mock function with given fields: ctx, in, opts func (_m *MockRootCoordClient) DropAlias(ctx context.Context, in *milvuspb.DropAliasRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { _va := make([]interface{}, len(opts)) diff --git a/internal/proto/data_coord.proto b/internal/proto/data_coord.proto index fd6007a0dd..508f440233 100644 --- a/internal/proto/data_coord.proto +++ b/internal/proto/data_coord.proto @@ -634,6 +634,7 @@ message AlterCollectionRequest { repeated int64 partitionIDs = 3; repeated common.KeyDataPair start_positions = 4; repeated common.KeyValuePair properties = 5; + int64 dbID = 6; } message GcConfirmRequest { diff --git a/internal/proto/internal.proto b/internal/proto/internal.proto index 93544a5288..6715af58d9 100644 --- a/internal/proto/internal.proto +++ b/internal/proto/internal.proto @@ -266,6 +266,13 @@ message ShowConfigurationsResponse { repeated common.KeyValuePair configuations = 2; } +enum RateScope { + Cluster = 0; + Database = 1; + Collection = 2; + Partition = 3; +} + enum RateType { DDLCollection = 0; DDLPartition = 1; diff --git a/internal/proto/proxy.proto b/internal/proto/proxy.proto index 5d02d1b4c2..a2ced624be 100644 --- a/internal/proto/proxy.proto +++ b/internal/proto/proxy.proto @@ -58,6 +58,7 @@ message RefreshPolicyInfoCacheRequest { string opKey = 3; } +// Deprecated: use ClusterLimiter instead message CollectionRate { int64 collection = 1; repeated internal.Rate rates = 2; @@ -65,9 +66,27 @@ message CollectionRate { repeated common.ErrorCode codes = 4; } +message LimiterNode { + // self limiter information + Limiter limiter = 1; + // db id -> db limiter + // collection id -> collection limiter + // partition id -> partition limiter + map<int64, LimiterNode> children = 2; +} + +message Limiter { + repeated 
internal.Rate rates = 1; + // we can't use a map to store quota states and error codes, because key in map fields cannot be enum types + repeated milvus.QuotaState states = 2; + repeated common.ErrorCode codes = 3; +} + message SetRatesRequest { common.MsgBase base = 1; + // deprecated repeated CollectionRate rates = 2; + LimiterNode rootLimiter = 3; } message ListClientInfosRequest { diff --git a/internal/proto/root_coord.proto b/internal/proto/root_coord.proto index e576346e69..f197d23b06 100644 --- a/internal/proto/root_coord.proto +++ b/internal/proto/root_coord.proto @@ -140,6 +140,7 @@ service RootCoord { rpc CreateDatabase(milvus.CreateDatabaseRequest) returns (common.Status) {} rpc DropDatabase(milvus.DropDatabaseRequest) returns (common.Status) {} rpc ListDatabases(milvus.ListDatabasesRequest) returns (milvus.ListDatabasesResponse) {} + rpc DescribeDatabase(DescribeDatabaseRequest) returns(DescribeDatabaseResponse){} } message AllocTimestampRequest { @@ -206,3 +207,14 @@ message GetCredentialResponse { string password = 3; } +message DescribeDatabaseRequest { + common.MsgBase base = 1; + string db_name = 2; +} + +message DescribeDatabaseResponse { + common.Status status = 1; + string db_name = 2; + int64 dbID = 3; + uint64 created_timestamp = 4; +} diff --git a/internal/proxy/impl.go b/internal/proxy/impl.go index 0c16197a65..59a01c3f5d 100644 --- a/internal/proxy/impl.go +++ b/internal/proxy/impl.go @@ -56,6 +56,8 @@ import ( "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/metricsinfo" "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/ratelimitutil" + "github.com/milvus-io/milvus/pkg/util/requestutil" "github.com/milvus-io/milvus/pkg/util/retry" "github.com/milvus-io/milvus/pkg/util/timerecord" "github.com/milvus-io/milvus/pkg/util/tsoutil" @@ -160,8 +162,10 @@ func (node *Proxy) InvalidateCollectionMetaCache(ctx context.Context, request *p for _, alias := range aliasName { metrics.CleanupProxyCollectionMetrics(paramtable.GetNodeID(), alias) } + DeregisterSubLabel(ratelimitutil.GetCollectionSubLabel(request.GetDbName(), request.GetCollectionName())) } else if msgType == commonpb.MsgType_DropDatabase { metrics.CleanupProxyDBMetrics(paramtable.GetNodeID(), request.GetDbName()) + DeregisterSubLabel(ratelimitutil.GetDBSubLabel(request.GetDbName())) } log.Info("complete to invalidate collection meta cache") @@ -289,6 +293,7 @@ func (node *Proxy) DropDatabase(ctx context.Context, request *milvuspb.DropDatab } log.Info(rpcDone(method)) + DeregisterSubLabel(ratelimitutil.GetDBSubLabel(request.GetDbName())) metrics.ProxyFunctionCall.WithLabelValues( strconv.FormatInt(paramtable.GetNodeID(), 10), method, @@ -527,6 +532,7 @@ func (node *Proxy) DropCollection(ctx context.Context, request *milvuspb.DropCol zap.Uint64("BeginTs", dct.BeginTs()), zap.Uint64("EndTs", dct.EndTs()), ) + DeregisterSubLabel(ratelimitutil.GetCollectionSubLabel(request.GetDbName(), request.GetCollectionName())) metrics.ProxyFunctionCall.WithLabelValues( strconv.FormatInt(paramtable.GetNodeID(), 10), @@ -2680,11 +2686,11 @@ func (node *Proxy) Upsert(ctx context.Context, request *milvuspb.UpsertRequest) hookutil.FailCntKey: len(it.result.ErrIndex), }) SetReportValue(it.result.GetStatus(), v) - - rateCol.Add(internalpb.RateType_DMLUpsert.String(), float64(it.upsertMsg.InsertMsg.Size()+it.upsertMsg.DeleteMsg.Size())) if merr.Ok(it.result.GetStatus()) { metrics.ProxyReportValue.WithLabelValues(nodeID, hookutil.OpTypeUpsert, dbName, 
username).Add(float64(v)) } + + rateCol.Add(internalpb.RateType_DMLUpsert.String(), float64(it.upsertMsg.InsertMsg.Size()+it.upsertMsg.DeleteMsg.Size())) metrics.ProxyFunctionCall.WithLabelValues(nodeID, method, metrics.SuccessLabel, dbName, collectionName).Inc() successCnt := it.result.UpsertCnt - int64(len(it.result.ErrIndex)) @@ -2700,6 +2706,19 @@ func (node *Proxy) Upsert(ctx context.Context, request *milvuspb.UpsertRequest) return it.result, nil } +func GetDBAndCollectionRateSubLabels(req any) []string { + subLabels := make([]string, 2) + dbName, _ := requestutil.GetDbNameFromRequest(req) + if dbName != "" { + subLabels[0] = ratelimitutil.GetDBSubLabel(dbName.(string)) + } + collectionName, _ := requestutil.GetCollectionNameFromRequest(req) + if collectionName != "" { + subLabels[1] = ratelimitutil.GetCollectionSubLabel(dbName.(string), collectionName.(string)) + } + return subLabels +} + // Search searches the most similar records of requests. func (node *Proxy) Search(ctx context.Context, request *milvuspb.SearchRequest) (*milvuspb.SearchResults, error) { var err error @@ -2734,7 +2753,8 @@ func (node *Proxy) search(ctx context.Context, request *milvuspb.SearchRequest) request.GetCollectionName(), ).Add(float64(request.GetNq())) - rateCol.Add(internalpb.RateType_DQLSearch.String(), float64(request.GetNq())) + subLabels := GetDBAndCollectionRateSubLabels(request) + rateCol.Add(internalpb.RateType_DQLSearch.String(), float64(request.GetNq()), subLabels...) if err := merr.CheckHealthy(node.GetStateCode()); err != nil { return &milvuspb.SearchResults{ @@ -2909,8 +2929,9 @@ func (node *Proxy) search(ctx context.Context, request *milvuspb.SearchRequest) if merr.Ok(qt.result.GetStatus()) { metrics.ProxyReportValue.WithLabelValues(nodeID, hookutil.OpTypeSearch, dbName, username).Add(float64(v)) } + metrics.ProxyReadReqSendBytes.WithLabelValues(nodeID).Add(float64(sentSize)) - rateCol.Add(metricsinfo.ReadResultThroughput, float64(sentSize)) + rateCol.Add(metricsinfo.ReadResultThroughput, float64(sentSize), subLabels...) } return qt.result, nil } @@ -2941,6 +2962,13 @@ func (node *Proxy) hybridSearch(ctx context.Context, request *milvuspb.HybridSea request.GetCollectionName(), ).Add(float64(receiveSize)) + subLabels := GetDBAndCollectionRateSubLabels(request) + allNQ := int64(0) + for _, searchRequest := range request.Requests { + allNQ += searchRequest.GetNq() + } + rateCol.Add(internalpb.RateType_DQLSearch.String(), float64(allNQ), subLabels...) + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { return &milvuspb.SearchResults{ Status: merr.Status(err), @@ -3098,8 +3126,9 @@ func (node *Proxy) hybridSearch(ctx context.Context, request *milvuspb.HybridSea if merr.Ok(qt.result.GetStatus()) { metrics.ProxyReportValue.WithLabelValues(nodeID, hookutil.OpTypeHybridSearch, dbName, username).Add(float64(v)) } + metrics.ProxyReadReqSendBytes.WithLabelValues(nodeID).Add(float64(sentSize)) - rateCol.Add(metricsinfo.ReadResultThroughput, float64(sentSize)) + rateCol.Add(metricsinfo.ReadResultThroughput, float64(sentSize), subLabels...) } return qt.result, nil } @@ -3246,7 +3275,8 @@ func (node *Proxy) query(ctx context.Context, qt *queryTask) (*milvuspb.QueryRes request.GetCollectionName(), ).Add(float64(1)) - rateCol.Add(internalpb.RateType_DQLQuery.String(), 1) + subLabels := GetDBAndCollectionRateSubLabels(request) + rateCol.Add(internalpb.RateType_DQLQuery.String(), 1, subLabels...) 
if err := merr.CheckHealthy(node.GetStateCode()); err != nil { return &milvuspb.QueryResults{ @@ -3364,7 +3394,7 @@ func (node *Proxy) query(ctx context.Context, qt *queryTask) (*milvuspb.QueryRes ).Observe(float64(tr.ElapseSpan().Milliseconds())) sentSize := proto.Size(qt.result) - rateCol.Add(metricsinfo.ReadResultThroughput, float64(sentSize)) + rateCol.Add(metricsinfo.ReadResultThroughput, float64(sentSize), subLabels...) metrics.ProxyReadReqSendBytes.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Add(float64(sentSize)) return qt.result, nil @@ -5092,7 +5122,7 @@ func (node *Proxy) SetRates(ctx context.Context, request *proxypb.SetRatesReques return resp, nil } - err := node.multiRateLimiter.SetRates(request.GetRates()) + err := node.simpleLimiter.SetRates(request.GetRootLimiter()) // TODO: set multiple rate limiter rates if err != nil { resp = merr.Status(err) @@ -5162,12 +5192,9 @@ func (node *Proxy) CheckHealth(ctx context.Context, request *milvuspb.CheckHealt }, nil } - states, reasons := node.multiRateLimiter.GetQuotaStates() return &milvuspb.CheckHealthResponse{ - Status: merr.Success(), - QuotaStates: states, - Reasons: reasons, - IsHealthy: true, + Status: merr.Success(), + IsHealthy: true, }, nil } @@ -5978,3 +6005,10 @@ func (node *Proxy) ListImports(ctx context.Context, req *internalpb.ListImportsR metrics.ProxyReqLatency.WithLabelValues(nodeID, method).Observe(float64(tr.ElapseSpan().Milliseconds())) return resp, nil } + +// DeregisterSubLabel must add the sub-labels here if using other labels for the sub-labels +func DeregisterSubLabel(subLabel string) { + rateCol.DeregisterSubLabel(internalpb.RateType_DQLQuery.String(), subLabel) + rateCol.DeregisterSubLabel(internalpb.RateType_DQLSearch.String(), subLabel) + rateCol.DeregisterSubLabel(metricsinfo.ReadResultThroughput, subLabel) +} diff --git a/internal/proxy/impl_test.go b/internal/proxy/impl_test.go index 18eb04d9d9..b6bcb47aa8 100644 --- a/internal/proxy/impl_test.go +++ b/internal/proxy/impl_test.go @@ -63,6 +63,7 @@ func TestProxy_InvalidateCollectionMetaCache_remove_stream(t *testing.T) { chMgr.EXPECT().removeDMLStream(mock.Anything).Return() node := &Proxy{chMgr: chMgr} + _ = node.initRateCollector() node.UpdateStateCode(commonpb.StateCode_Healthy) ctx := context.Background() @@ -78,7 +79,7 @@ func TestProxy_InvalidateCollectionMetaCache_remove_stream(t *testing.T) { func TestProxy_CheckHealth(t *testing.T) { t.Run("not healthy", func(t *testing.T) { node := &Proxy{session: &sessionutil.Session{SessionRaw: sessionutil.SessionRaw{ServerID: 1}}} - node.multiRateLimiter = NewMultiRateLimiter() + node.simpleLimiter = NewSimpleLimiter() node.UpdateStateCode(commonpb.StateCode_Abnormal) ctx := context.Background() resp, err := node.CheckHealth(ctx, &milvuspb.CheckHealthRequest{}) @@ -96,7 +97,7 @@ func TestProxy_CheckHealth(t *testing.T) { dataCoord: NewDataCoordMock(), session: &sessionutil.Session{SessionRaw: sessionutil.SessionRaw{ServerID: 1}}, } - node.multiRateLimiter = NewMultiRateLimiter() + node.simpleLimiter = NewSimpleLimiter() node.UpdateStateCode(commonpb.StateCode_Healthy) ctx := context.Background() resp, err := node.CheckHealth(ctx, &milvuspb.CheckHealthRequest{}) @@ -129,7 +130,7 @@ func TestProxy_CheckHealth(t *testing.T) { queryCoord: qc, dataCoord: dataCoordMock, } - node.multiRateLimiter = NewMultiRateLimiter() + node.simpleLimiter = NewSimpleLimiter() node.UpdateStateCode(commonpb.StateCode_Healthy) ctx := context.Background() resp, err := node.CheckHealth(ctx, 
&milvuspb.CheckHealthRequest{}) @@ -146,7 +147,7 @@ func TestProxy_CheckHealth(t *testing.T) { dataCoord: NewDataCoordMock(), queryCoord: qc, } - node.multiRateLimiter = NewMultiRateLimiter() + node.simpleLimiter = NewSimpleLimiter() node.UpdateStateCode(commonpb.StateCode_Healthy) resp, err := node.CheckHealth(context.Background(), &milvuspb.CheckHealthRequest{}) assert.NoError(t, err) @@ -156,18 +157,30 @@ func TestProxy_CheckHealth(t *testing.T) { states := []milvuspb.QuotaState{milvuspb.QuotaState_DenyToWrite, milvuspb.QuotaState_DenyToRead} codes := []commonpb.ErrorCode{commonpb.ErrorCode_MemoryQuotaExhausted, commonpb.ErrorCode_ForceDeny} - node.multiRateLimiter.SetRates([]*proxypb.CollectionRate{ - { - Collection: 1, - States: states, - Codes: codes, + err = node.simpleLimiter.SetRates(&proxypb.LimiterNode{ + Limiter: &proxypb.Limiter{}, + // db level + Children: map[int64]*proxypb.LimiterNode{ + 1: { + Limiter: &proxypb.Limiter{}, + // collection level + Children: map[int64]*proxypb.LimiterNode{ + 100: { + Limiter: &proxypb.Limiter{ + States: states, + Codes: codes, + }, + Children: make(map[int64]*proxypb.LimiterNode), + }, + }, + }, }, }) + assert.NoError(t, err) + resp, err = node.CheckHealth(context.Background(), &milvuspb.CheckHealthRequest{}) assert.NoError(t, err) assert.Equal(t, true, resp.IsHealthy) - assert.Equal(t, 2, len(resp.GetQuotaStates())) - assert.Equal(t, 2, len(resp.GetReasons())) }) } @@ -229,7 +242,7 @@ func TestProxy_ResourceGroup(t *testing.T) { node, err := NewProxy(ctx, factory) assert.NoError(t, err) - node.multiRateLimiter = NewMultiRateLimiter() + node.simpleLimiter = NewSimpleLimiter() node.UpdateStateCode(commonpb.StateCode_Healthy) qc := mocks.NewMockQueryCoordClient(t) @@ -321,7 +334,7 @@ func TestProxy_InvalidResourceGroupName(t *testing.T) { node, err := NewProxy(ctx, factory) assert.NoError(t, err) - node.multiRateLimiter = NewMultiRateLimiter() + node.simpleLimiter = NewSimpleLimiter() node.UpdateStateCode(commonpb.StateCode_Healthy) qc := mocks.NewMockQueryCoordClient(t) @@ -922,7 +935,7 @@ func TestProxyCreateDatabase(t *testing.T) { node.tsoAllocator = &timestampAllocator{ tso: newMockTimestampAllocatorInterface(), } - node.multiRateLimiter = NewMultiRateLimiter() + node.simpleLimiter = NewSimpleLimiter() node.UpdateStateCode(commonpb.StateCode_Healthy) node.sched, err = newTaskScheduler(ctx, node.tsoAllocator, node.factory) node.sched.ddQueue.setMaxTaskNum(10) @@ -977,11 +990,12 @@ func TestProxyDropDatabase(t *testing.T) { ctx := context.Background() node, err := NewProxy(ctx, factory) + node.initRateCollector() assert.NoError(t, err) node.tsoAllocator = &timestampAllocator{ tso: newMockTimestampAllocatorInterface(), } - node.multiRateLimiter = NewMultiRateLimiter() + node.simpleLimiter = NewSimpleLimiter() node.UpdateStateCode(commonpb.StateCode_Healthy) node.sched, err = newTaskScheduler(ctx, node.tsoAllocator, node.factory) node.sched.ddQueue.setMaxTaskNum(10) @@ -1040,7 +1054,7 @@ func TestProxyListDatabase(t *testing.T) { node.tsoAllocator = &timestampAllocator{ tso: newMockTimestampAllocatorInterface(), } - node.multiRateLimiter = NewMultiRateLimiter() + node.simpleLimiter = NewSimpleLimiter() node.UpdateStateCode(commonpb.StateCode_Healthy) node.sched, err = newTaskScheduler(ctx, node.tsoAllocator, node.factory) node.sched.ddQueue.setMaxTaskNum(10) diff --git a/internal/proxy/meta_cache.go b/internal/proxy/meta_cache.go index 1a5326746e..92c0f8a6b2 100644 --- a/internal/proxy/meta_cache.go +++ b/internal/proxy/meta_cache.go @@ -89,10 +89,10 @@ type 
Cache interface { RemoveDatabase(ctx context.Context, database string) HasDatabase(ctx context.Context, database string) bool + GetDatabaseInfo(ctx context.Context, database string) (*databaseInfo, error) // AllocID is only using on requests that need to skip timestamp allocation, don't overuse it. AllocID(ctx context.Context) (int64, error) } - type collectionBasicInfo struct { collID typeutil.UniqueID createdTimestamp uint64 @@ -109,6 +109,11 @@ type collectionInfo struct { consistencyLevel commonpb.ConsistencyLevel } +type databaseInfo struct { + dbID typeutil.UniqueID + createdTimestamp uint64 +} + // schemaInfo is a helper function wraps *schemapb.CollectionSchema // with extra fields mapping and methods type schemaInfo struct { @@ -244,17 +249,19 @@ type MetaCache struct { rootCoord types.RootCoordClient queryCoord types.QueryCoordClient - collInfo map[string]map[string]*collectionInfo // database -> collectionName -> collection_info - collLeader map[string]map[string]*shardLeaders // database -> collectionName -> collection_leaders - dbInfo map[string]map[typeutil.UniqueID]string // database -> collectionID -> collectionName - credMap map[string]*internalpb.CredentialInfo // cache for credential, lazy load - privilegeInfos map[string]struct{} // privileges cache - userToRoles map[string]map[string]struct{} // user to role cache - mu sync.RWMutex - credMut sync.RWMutex - leaderMut sync.RWMutex - shardMgr shardClientMgr - sfGlobal conc.Singleflight[*collectionInfo] + dbInfo map[string]*databaseInfo // database -> db_info + collInfo map[string]map[string]*collectionInfo // database -> collectionName -> collection_info + collLeader map[string]map[string]*shardLeaders // database -> collectionName -> collection_leaders + dbCollectionInfo map[string]map[typeutil.UniqueID]string // database -> collectionID -> collectionName + credMap map[string]*internalpb.CredentialInfo // cache for credential, lazy load + privilegeInfos map[string]struct{} // privileges cache + userToRoles map[string]map[string]struct{} // user to role cache + mu sync.RWMutex + credMut sync.RWMutex + leaderMut sync.RWMutex + shardMgr shardClientMgr + sfGlobal conc.Singleflight[*collectionInfo] + sfDB conc.Singleflight[*databaseInfo] IDStart int64 IDCount int64 @@ -287,15 +294,16 @@ func InitMetaCache(ctx context.Context, rootCoord types.RootCoordClient, queryCo // NewMetaCache creates a MetaCache with provided RootCoord and QueryNode func NewMetaCache(rootCoord types.RootCoordClient, queryCoord types.QueryCoordClient, shardMgr shardClientMgr) (*MetaCache, error) { return &MetaCache{ - rootCoord: rootCoord, - queryCoord: queryCoord, - collInfo: map[string]map[string]*collectionInfo{}, - collLeader: map[string]map[string]*shardLeaders{}, - dbInfo: map[string]map[typeutil.UniqueID]string{}, - credMap: map[string]*internalpb.CredentialInfo{}, - shardMgr: shardMgr, - privilegeInfos: map[string]struct{}{}, - userToRoles: map[string]map[string]struct{}{}, + rootCoord: rootCoord, + queryCoord: queryCoord, + dbInfo: map[string]*databaseInfo{}, + collInfo: map[string]map[string]*collectionInfo{}, + collLeader: map[string]map[string]*shardLeaders{}, + dbCollectionInfo: map[string]map[typeutil.UniqueID]string{}, + credMap: map[string]*internalpb.CredentialInfo{}, + shardMgr: shardMgr, + privilegeInfos: map[string]struct{}{}, + userToRoles: map[string]map[string]struct{}{}, }, nil } @@ -510,7 +518,7 @@ func (m *MetaCache) innerGetCollectionByID(collectionID int64) (string, string) m.mu.RLock() defer m.mu.RUnlock() - for database, db 
:= range m.dbInfo { + for database, db := range m.dbCollectionInfo { name, ok := db[collectionID] if ok { return database, name @@ -554,7 +562,7 @@ func (m *MetaCache) updateDBInfo(ctx context.Context) error { m.mu.Lock() defer m.mu.Unlock() - m.dbInfo = dbInfo + m.dbCollectionInfo = dbInfo return nil } @@ -739,6 +747,19 @@ func (m *MetaCache) showPartitions(ctx context.Context, dbName string, collectio return partitions, nil } +func (m *MetaCache) describeDatabase(ctx context.Context, dbName string) (*rootcoordpb.DescribeDatabaseResponse, error) { + req := &rootcoordpb.DescribeDatabaseRequest{ + DbName: dbName, + } + + resp, err := m.rootCoord.DescribeDatabase(ctx, req) + if err = merr.CheckRPCCall(resp, err); err != nil { + return nil, err + } + + return resp, nil +} + // parsePartitionsInfo parse partitionInfo list to partitionInfos struct. // prepare all name to id & info map // try parse partition names to partitionKey index. @@ -1084,6 +1105,7 @@ func (m *MetaCache) RemoveDatabase(ctx context.Context, database string) { m.mu.Lock() defer m.mu.Unlock() delete(m.collInfo, database) + delete(m.dbInfo, database) } func (m *MetaCache) HasDatabase(ctx context.Context, database string) bool { @@ -1093,6 +1115,41 @@ func (m *MetaCache) HasDatabase(ctx context.Context, database string) bool { return ok } +func (m *MetaCache) GetDatabaseInfo(ctx context.Context, database string) (*databaseInfo, error) { + dbInfo := m.safeGetDBInfo(database) + if dbInfo != nil { + return dbInfo, nil + } + + dbInfo, err, _ := m.sfDB.Do(database, func() (*databaseInfo, error) { + resp, err := m.describeDatabase(ctx, database) + if err != nil { + return nil, err + } + + m.mu.Lock() + defer m.mu.Unlock() + dbInfo := &databaseInfo{ + dbID: resp.GetDbID(), + createdTimestamp: resp.GetCreatedTimestamp(), + } + m.dbInfo[database] = dbInfo + return dbInfo, nil + }) + + return dbInfo, err +} + +func (m *MetaCache) safeGetDBInfo(database string) *databaseInfo { + m.mu.RLock() + defer m.mu.RUnlock() + db, ok := m.dbInfo[database] + if !ok { + return nil + } + return db +} + func (m *MetaCache) AllocID(ctx context.Context) (int64, error) { m.IDLock.Lock() defer m.IDLock.Unlock() diff --git a/internal/proxy/meta_cache_test.go b/internal/proxy/meta_cache_test.go index df2a65fbf6..e93bf7c57f 100644 --- a/internal/proxy/meta_cache_test.go +++ b/internal/proxy/meta_cache_test.go @@ -817,6 +817,49 @@ func TestMetaCache_Database(t *testing.T) { assert.Equal(t, CheckDatabase(ctx, dbName), true) } +func TestGetDatabaseInfo(t *testing.T) { + t.Run("success", func(t *testing.T) { + ctx := context.Background() + rootCoord := mocks.NewMockRootCoordClient(t) + queryCoord := &mocks.MockQueryCoordClient{} + shardMgr := newShardClientMgr() + cache, err := NewMetaCache(rootCoord, queryCoord, shardMgr) + assert.NoError(t, err) + + rootCoord.EXPECT().DescribeDatabase(mock.Anything, mock.Anything).Return(&rootcoordpb.DescribeDatabaseResponse{ + Status: merr.Success(), + DbID: 1, + DbName: "default", + }, nil).Once() + { + dbInfo, err := cache.GetDatabaseInfo(ctx, "default") + assert.NoError(t, err) + assert.Equal(t, UniqueID(1), dbInfo.dbID) + } + + { + dbInfo, err := cache.GetDatabaseInfo(ctx, "default") + assert.NoError(t, err) + assert.Equal(t, UniqueID(1), dbInfo.dbID) + } + }) + + t.Run("error", func(t *testing.T) { + ctx := context.Background() + rootCoord := mocks.NewMockRootCoordClient(t) + queryCoord := &mocks.MockQueryCoordClient{} + shardMgr := newShardClientMgr() + cache, err := NewMetaCache(rootCoord, queryCoord, shardMgr) + 
assert.NoError(t, err) + + rootCoord.EXPECT().DescribeDatabase(mock.Anything, mock.Anything).Return(&rootcoordpb.DescribeDatabaseResponse{ + Status: merr.Status(errors.New("mock error: describe database")), + }, nil).Once() + _, err = cache.GetDatabaseInfo(ctx, "default") + assert.Error(t, err) + }) +} + func TestMetaCache_AllocID(t *testing.T) { ctx := context.Background() queryCoord := &mocks.MockQueryCoordClient{} @@ -935,9 +978,9 @@ func TestGlobalMetaCache_UpdateDBInfo(t *testing.T) { }, nil).Once() err := cache.updateDBInfo(ctx) assert.NoError(t, err) - assert.Len(t, cache.dbInfo, 1) - assert.Len(t, cache.dbInfo["db1"], 1) - assert.Equal(t, "collection1", cache.dbInfo["db1"][1]) + assert.Len(t, cache.dbCollectionInfo, 1) + assert.Len(t, cache.dbCollectionInfo["db1"], 1) + assert.Equal(t, "collection1", cache.dbCollectionInfo["db1"][1]) }) } diff --git a/internal/proxy/metrics_info.go b/internal/proxy/metrics_info.go index c02fae5aa2..109ca02211 100644 --- a/internal/proxy/metrics_info.go +++ b/internal/proxy/metrics_info.go @@ -50,12 +50,29 @@ func getQuotaMetrics() (*metricsinfo.ProxyQuotaMetrics, error) { Rate: rate, }) } + + getSubLabelRateMetric := func(label string) { + rates, err2 := rateCol.RateSubLabel(label, ratelimitutil.DefaultAvgDuration) + if err2 != nil { + err = err2 + return + } + for s, f := range rates { + rms = append(rms, metricsinfo.RateMetric{ + Label: s, + Rate: f, + }) + } + } getRateMetric(internalpb.RateType_DMLInsert.String()) getRateMetric(internalpb.RateType_DMLUpsert.String()) getRateMetric(internalpb.RateType_DMLDelete.String()) getRateMetric(internalpb.RateType_DQLSearch.String()) + getSubLabelRateMetric(internalpb.RateType_DQLSearch.String()) getRateMetric(internalpb.RateType_DQLQuery.String()) + getSubLabelRateMetric(internalpb.RateType_DQLQuery.String()) getRateMetric(metricsinfo.ReadResultThroughput) + getSubLabelRateMetric(metricsinfo.ReadResultThroughput) if err != nil { return nil, err } diff --git a/internal/proxy/mock_cache.go b/internal/proxy/mock_cache.go index 6d61e17df0..06c496d2eb 100644 --- a/internal/proxy/mock_cache.go +++ b/internal/proxy/mock_cache.go @@ -450,6 +450,61 @@ func (_c *MockCache_GetCredentialInfo_Call) RunAndReturn(run func(context.Contex return _c } +// GetDatabaseInfo provides a mock function with given fields: ctx, database +func (_m *MockCache) GetDatabaseInfo(ctx context.Context, database string) (*databaseInfo, error) { + ret := _m.Called(ctx, database) + + var r0 *databaseInfo + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (*databaseInfo, error)); ok { + return rf(ctx, database) + } + if rf, ok := ret.Get(0).(func(context.Context, string) *databaseInfo); ok { + r0 = rf(ctx, database) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*databaseInfo) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, database) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockCache_GetDatabaseInfo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetDatabaseInfo' +type MockCache_GetDatabaseInfo_Call struct { + *mock.Call +} + +// GetDatabaseInfo is a helper method to define mock.On call +// - ctx context.Context +// - database string +func (_e *MockCache_Expecter) GetDatabaseInfo(ctx interface{}, database interface{}) *MockCache_GetDatabaseInfo_Call { + return &MockCache_GetDatabaseInfo_Call{Call: _e.mock.On("GetDatabaseInfo", ctx, database)} +} + +func (_c *MockCache_GetDatabaseInfo_Call) Run(run 
func(ctx context.Context, database string)) *MockCache_GetDatabaseInfo_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string)) + }) + return _c +} + +func (_c *MockCache_GetDatabaseInfo_Call) Return(_a0 *databaseInfo, _a1 error) *MockCache_GetDatabaseInfo_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockCache_GetDatabaseInfo_Call) RunAndReturn(run func(context.Context, string) (*databaseInfo, error)) *MockCache_GetDatabaseInfo_Call { + _c.Call.Return(run) + return _c +} + // GetPartitionID provides a mock function with given fields: ctx, database, collectionName, partitionName func (_m *MockCache) GetPartitionID(ctx context.Context, database string, collectionName string, partitionName string) (int64, error) { ret := _m.Called(ctx, database, collectionName, partitionName) diff --git a/internal/proxy/multi_rate_limiter.go b/internal/proxy/multi_rate_limiter.go deleted file mode 100644 index 7744c4d38b..0000000000 --- a/internal/proxy/multi_rate_limiter.go +++ /dev/null @@ -1,375 +0,0 @@ -// Licensed to the LF AI & Data foundation under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package proxy - -import ( - "context" - "fmt" - "strconv" - "sync" - "time" - - "go.uber.org/zap" - - "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" - "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" - "github.com/milvus-io/milvus/internal/proto/internalpb" - "github.com/milvus-io/milvus/internal/proto/proxypb" - "github.com/milvus-io/milvus/pkg/config" - "github.com/milvus-io/milvus/pkg/log" - "github.com/milvus-io/milvus/pkg/metrics" - "github.com/milvus-io/milvus/pkg/util/merr" - "github.com/milvus-io/milvus/pkg/util/paramtable" - "github.com/milvus-io/milvus/pkg/util/ratelimitutil" - "github.com/milvus-io/milvus/pkg/util/typeutil" -) - -var QuotaErrorString = map[commonpb.ErrorCode]string{ - commonpb.ErrorCode_ForceDeny: "the writing has been deactivated by the administrator", - commonpb.ErrorCode_MemoryQuotaExhausted: "memory quota exceeded, please allocate more resources", - commonpb.ErrorCode_DiskQuotaExhausted: "disk quota exceeded, please allocate more resources", - commonpb.ErrorCode_TimeTickLongDelay: "time tick long delay", -} - -func GetQuotaErrorString(errCode commonpb.ErrorCode) string { - return QuotaErrorString[errCode] -} - -// MultiRateLimiter includes multilevel rate limiters, such as global rateLimiter, -// collection level rateLimiter and so on. It also implements Limiter interface. -type MultiRateLimiter struct { - quotaStatesMu sync.RWMutex - // for DML and DQL - collectionLimiters map[int64]*rateLimiter - // for DDL - globalDDLLimiter *rateLimiter -} - -// NewMultiRateLimiter returns a new MultiRateLimiter. 
-func NewMultiRateLimiter() *MultiRateLimiter { - m := &MultiRateLimiter{ - collectionLimiters: make(map[int64]*rateLimiter, 0), - globalDDLLimiter: newRateLimiter(true), - } - return m -} - -// Check checks if request would be limited or denied. -func (m *MultiRateLimiter) Check(collectionIDs []int64, rt internalpb.RateType, n int) error { - if !Params.QuotaConfig.QuotaAndLimitsEnabled.GetAsBool() { - return nil - } - - m.quotaStatesMu.RLock() - defer m.quotaStatesMu.RUnlock() - - checkFunc := func(limiter *rateLimiter) error { - if limiter == nil { - return nil - } - - limit, rate := limiter.limit(rt, n) - if rate == 0 { - return limiter.getQuotaExceededError(rt) - } - if limit { - return limiter.getRateLimitError(rate) - } - return nil - } - - // first, check global level rate limits - ret := checkFunc(m.globalDDLLimiter) - - // second check collection level rate limits - // only dml, dql and flush have collection level rate limits - if ret == nil && len(collectionIDs) > 0 && !isNotCollectionLevelLimitRequest(rt) { - // store done limiters to cancel them when error occurs. - doneLimiters := make([]*rateLimiter, 0, len(collectionIDs)+1) - doneLimiters = append(doneLimiters, m.globalDDLLimiter) - - for _, collectionID := range collectionIDs { - ret = checkFunc(m.collectionLimiters[collectionID]) - if ret != nil { - for _, limiter := range doneLimiters { - limiter.cancel(rt, n) - } - break - } - doneLimiters = append(doneLimiters, m.collectionLimiters[collectionID]) - } - } - return ret -} - -func isNotCollectionLevelLimitRequest(rt internalpb.RateType) bool { - // Most ddl is global level, only DDLFlush will be applied at collection - switch rt { - case internalpb.RateType_DDLCollection, internalpb.RateType_DDLPartition, internalpb.RateType_DDLIndex, - internalpb.RateType_DDLCompaction: - return true - default: - return false - } -} - -// GetQuotaStates returns quota states. -func (m *MultiRateLimiter) GetQuotaStates() ([]milvuspb.QuotaState, []string) { - m.quotaStatesMu.RLock() - defer m.quotaStatesMu.RUnlock() - serviceStates := make(map[milvuspb.QuotaState]typeutil.Set[commonpb.ErrorCode]) - - // deduplicate same (state, code) pair from different collection - for _, limiter := range m.collectionLimiters { - limiter.quotaStates.Range(func(state milvuspb.QuotaState, errCode commonpb.ErrorCode) bool { - if serviceStates[state] == nil { - serviceStates[state] = typeutil.NewSet[commonpb.ErrorCode]() - } - serviceStates[state].Insert(errCode) - return true - }) - } - - states := make([]milvuspb.QuotaState, 0) - reasons := make([]string, 0) - for state, errCodes := range serviceStates { - for errCode := range errCodes { - states = append(states, state) - reasons = append(reasons, GetQuotaErrorString(errCode)) - } - } - - return states, reasons -} - -// SetQuotaStates sets quota states for MultiRateLimiter. 
-func (m *MultiRateLimiter) SetRates(rates []*proxypb.CollectionRate) error { - m.quotaStatesMu.Lock() - defer m.quotaStatesMu.Unlock() - collectionSet := typeutil.NewUniqueSet() - for _, collectionRates := range rates { - collectionSet.Insert(collectionRates.Collection) - rateLimiter, ok := m.collectionLimiters[collectionRates.GetCollection()] - if !ok { - rateLimiter = newRateLimiter(false) - } - err := rateLimiter.setRates(collectionRates) - if err != nil { - return err - } - m.collectionLimiters[collectionRates.GetCollection()] = rateLimiter - } - - // remove dropped collection's rate limiter - for collectionID := range m.collectionLimiters { - if !collectionSet.Contain(collectionID) { - delete(m.collectionLimiters, collectionID) - } - } - return nil -} - -// rateLimiter implements Limiter. -type rateLimiter struct { - limiters *typeutil.ConcurrentMap[internalpb.RateType, *ratelimitutil.Limiter] - quotaStates *typeutil.ConcurrentMap[milvuspb.QuotaState, commonpb.ErrorCode] -} - -// newRateLimiter returns a new RateLimiter. -func newRateLimiter(globalLevel bool) *rateLimiter { - rl := &rateLimiter{ - limiters: typeutil.NewConcurrentMap[internalpb.RateType, *ratelimitutil.Limiter](), - quotaStates: typeutil.NewConcurrentMap[milvuspb.QuotaState, commonpb.ErrorCode](), - } - rl.registerLimiters(globalLevel) - return rl -} - -// limit returns true, the request will be rejected. -// Otherwise, the request will pass. -func (rl *rateLimiter) limit(rt internalpb.RateType, n int) (bool, float64) { - limit, ok := rl.limiters.Get(rt) - if !ok { - return false, -1 - } - return !limit.AllowN(time.Now(), n), float64(limit.Limit()) -} - -func (rl *rateLimiter) cancel(rt internalpb.RateType, n int) { - limit, ok := rl.limiters.Get(rt) - if !ok { - return - } - limit.Cancel(n) -} - -func (rl *rateLimiter) setRates(collectionRate *proxypb.CollectionRate) error { - log := log.Ctx(context.TODO()).WithRateGroup("proxy.rateLimiter", 1.0, 60.0).With( - zap.Int64("proxyNodeID", paramtable.GetNodeID()), - zap.Int64("CollectionID", collectionRate.Collection), - ) - for _, r := range collectionRate.GetRates() { - if limit, ok := rl.limiters.Get(r.GetRt()); ok { - limit.SetLimit(ratelimitutil.Limit(r.GetR())) - setRateGaugeByRateType(r.GetRt(), paramtable.GetNodeID(), collectionRate.Collection, r.GetR()) - } else { - return fmt.Errorf("unregister rateLimiter for rateType %s", r.GetRt().String()) - } - log.RatedDebug(30, "current collection rates in proxy", - zap.String("rateType", r.Rt.String()), - zap.String("rateLimit", ratelimitutil.Limit(r.GetR()).String()), - ) - } - - // clear old quota states - rl.quotaStates = typeutil.NewConcurrentMap[milvuspb.QuotaState, commonpb.ErrorCode]() - for i := 0; i < len(collectionRate.GetStates()); i++ { - rl.quotaStates.Insert(collectionRate.States[i], collectionRate.Codes[i]) - log.RatedWarn(30, "Proxy set collection quota states", - zap.String("state", collectionRate.GetStates()[i].String()), - zap.String("reason", collectionRate.GetCodes()[i].String()), - ) - } - - return nil -} - -func (rl *rateLimiter) getQuotaExceededError(rt internalpb.RateType) error { - switch rt { - case internalpb.RateType_DMLInsert, internalpb.RateType_DMLUpsert, internalpb.RateType_DMLDelete, internalpb.RateType_DMLBulkLoad: - if errCode, ok := rl.quotaStates.Get(milvuspb.QuotaState_DenyToWrite); ok { - return merr.WrapErrServiceQuotaExceeded(GetQuotaErrorString(errCode)) - } - case internalpb.RateType_DQLSearch, internalpb.RateType_DQLQuery: - if errCode, ok := 
rl.quotaStates.Get(milvuspb.QuotaState_DenyToRead); ok { - return merr.WrapErrServiceQuotaExceeded(GetQuotaErrorString(errCode)) - } - } - return nil -} - -func (rl *rateLimiter) getRateLimitError(rate float64) error { - return merr.WrapErrServiceRateLimit(rate, "request is rejected by grpc RateLimiter middleware, please retry later") -} - -// setRateGaugeByRateType sets ProxyLimiterRate metrics. -func setRateGaugeByRateType(rateType internalpb.RateType, nodeID int64, collectionID int64, rate float64) { - if ratelimitutil.Limit(rate) == ratelimitutil.Inf { - return - } - nodeIDStr := strconv.FormatInt(nodeID, 10) - collectionIDStr := strconv.FormatInt(collectionID, 10) - switch rateType { - case internalpb.RateType_DMLInsert: - metrics.ProxyLimiterRate.WithLabelValues(nodeIDStr, collectionIDStr, metrics.InsertLabel).Set(rate) - case internalpb.RateType_DMLUpsert: - metrics.ProxyLimiterRate.WithLabelValues(nodeIDStr, collectionIDStr, metrics.UpsertLabel).Set(rate) - case internalpb.RateType_DMLDelete: - metrics.ProxyLimiterRate.WithLabelValues(nodeIDStr, collectionIDStr, metrics.DeleteLabel).Set(rate) - case internalpb.RateType_DQLSearch: - metrics.ProxyLimiterRate.WithLabelValues(nodeIDStr, collectionIDStr, metrics.SearchLabel).Set(rate) - case internalpb.RateType_DQLQuery: - metrics.ProxyLimiterRate.WithLabelValues(nodeIDStr, collectionIDStr, metrics.QueryLabel).Set(rate) - } -} - -// registerLimiters register limiter for all rate types. -func (rl *rateLimiter) registerLimiters(globalLevel bool) { - log := log.Ctx(context.TODO()).WithRateGroup("proxy.rateLimiter", 1.0, 60.0) - quotaConfig := &Params.QuotaConfig - for rt := range internalpb.RateType_name { - var r *paramtable.ParamItem - switch internalpb.RateType(rt) { - case internalpb.RateType_DDLCollection: - r = &quotaConfig.DDLCollectionRate - case internalpb.RateType_DDLPartition: - r = &quotaConfig.DDLPartitionRate - case internalpb.RateType_DDLIndex: - r = &quotaConfig.MaxIndexRate - case internalpb.RateType_DDLFlush: - if globalLevel { - r = &quotaConfig.MaxFlushRate - } else { - r = &quotaConfig.MaxFlushRatePerCollection - } - case internalpb.RateType_DDLCompaction: - r = &quotaConfig.MaxCompactionRate - case internalpb.RateType_DMLInsert: - if globalLevel { - r = &quotaConfig.DMLMaxInsertRate - } else { - r = &quotaConfig.DMLMaxInsertRatePerCollection - } - case internalpb.RateType_DMLUpsert: - if globalLevel { - r = &quotaConfig.DMLMaxUpsertRate - } else { - r = &quotaConfig.DMLMaxUpsertRatePerCollection - } - case internalpb.RateType_DMLDelete: - if globalLevel { - r = &quotaConfig.DMLMaxDeleteRate - } else { - r = &quotaConfig.DMLMaxDeleteRatePerCollection - } - case internalpb.RateType_DMLBulkLoad: - if globalLevel { - r = &quotaConfig.DMLMaxBulkLoadRate - } else { - r = &quotaConfig.DMLMaxBulkLoadRatePerCollection - } - case internalpb.RateType_DQLSearch: - if globalLevel { - r = &quotaConfig.DQLMaxSearchRate - } else { - r = &quotaConfig.DQLMaxSearchRatePerCollection - } - case internalpb.RateType_DQLQuery: - if globalLevel { - r = &quotaConfig.DQLMaxQueryRate - } else { - r = &quotaConfig.DQLMaxQueryRatePerCollection - } - } - limit := ratelimitutil.Limit(r.GetAsFloat()) - burst := r.GetAsFloat() // use rate as burst, because Limiter is with punishment mechanism, burst is insignificant. 
- rl.limiters.GetOrInsert(internalpb.RateType(rt), ratelimitutil.NewLimiter(limit, burst)) - onEvent := func(rateType internalpb.RateType) func(*config.Event) { - return func(event *config.Event) { - f, err := strconv.ParseFloat(r.Formatter(event.Value), 64) - if err != nil { - log.Info("Error format for rateLimit", - zap.String("rateType", rateType.String()), - zap.String("key", event.Key), - zap.String("value", event.Value), - zap.Error(err)) - return - } - limit, ok := rl.limiters.Get(rateType) - if !ok { - return - } - limit.SetLimit(ratelimitutil.Limit(f)) - } - }(internalpb.RateType(rt)) - paramtable.Get().Watch(r.Key, config.NewHandler(fmt.Sprintf("rateLimiter-%d", rt), onEvent)) - log.RatedDebug(30, "RateLimiter register for rateType", - zap.String("rateType", internalpb.RateType_name[rt]), - zap.String("rateLimit", ratelimitutil.Limit(r.GetAsFloat()).String()), - zap.String("burst", fmt.Sprintf("%v", burst))) - } -} diff --git a/internal/proxy/multi_rate_limiter_test.go b/internal/proxy/multi_rate_limiter_test.go deleted file mode 100644 index 489bf2a9a3..0000000000 --- a/internal/proxy/multi_rate_limiter_test.go +++ /dev/null @@ -1,338 +0,0 @@ -// Licensed to the LF AI & Data foundation under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
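Context for the interface change running through this diff: the deleted MultiRateLimiter exposed Check(collectionIDs []int64, rt, n), while the SimpleLimiter introduced later takes a database ID plus a collection-to-partition map. A minimal caller-side sketch of the new types.Limiter signature follows; the helper name, IDs and payload size are illustrative only and not part of this change.

package proxy

import (
	"github.com/milvus-io/milvus/internal/proto/internalpb"
	"github.com/milvus-io/milvus/internal/types"
)

// checkInsertQuota is a hypothetical helper showing how a caller scopes a DML
// insert check to one database, collection and partition under the new API.
func checkInsertQuota(limiter types.Limiter, dbID, collectionID, partitionID int64, payloadBytes int) error {
	collectionIDToPartIDs := map[int64][]int64{collectionID: {partitionID}}
	return limiter.Check(dbID, collectionIDToPartIDs, internalpb.RateType_DMLInsert, payloadBytes)
}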
- -package proxy - -import ( - "context" - "fmt" - "math" - "math/rand" - "testing" - "time" - - "github.com/stretchr/testify/assert" - - "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" - "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" - "github.com/milvus-io/milvus/internal/proto/internalpb" - "github.com/milvus-io/milvus/internal/proto/proxypb" - "github.com/milvus-io/milvus/pkg/util/etcd" - "github.com/milvus-io/milvus/pkg/util/merr" - "github.com/milvus-io/milvus/pkg/util/paramtable" - "github.com/milvus-io/milvus/pkg/util/ratelimitutil" -) - -func TestMultiRateLimiter(t *testing.T) { - collectionID := int64(1) - t.Run("test multiRateLimiter", func(t *testing.T) { - bak := Params.QuotaConfig.QuotaAndLimitsEnabled.GetValue() - paramtable.Get().Save(Params.QuotaConfig.QuotaAndLimitsEnabled.Key, "true") - multiLimiter := NewMultiRateLimiter() - multiLimiter.collectionLimiters[collectionID] = newRateLimiter(false) - for _, rt := range internalpb.RateType_value { - if isNotCollectionLevelLimitRequest(internalpb.RateType(rt)) { - multiLimiter.globalDDLLimiter.limiters.Insert(internalpb.RateType(rt), ratelimitutil.NewLimiter(ratelimitutil.Limit(5), 1)) - } else { - multiLimiter.collectionLimiters[collectionID].limiters.Insert(internalpb.RateType(rt), ratelimitutil.NewLimiter(ratelimitutil.Limit(1000), 1)) - } - } - for _, rt := range internalpb.RateType_value { - if isNotCollectionLevelLimitRequest(internalpb.RateType(rt)) { - err := multiLimiter.Check([]int64{collectionID}, internalpb.RateType(rt), 1) - assert.NoError(t, err) - err = multiLimiter.Check([]int64{collectionID}, internalpb.RateType(rt), 5) - assert.NoError(t, err) - err = multiLimiter.Check([]int64{collectionID}, internalpb.RateType(rt), 5) - assert.ErrorIs(t, err, merr.ErrServiceRateLimit) - } else { - err := multiLimiter.Check([]int64{collectionID}, internalpb.RateType(rt), 1) - assert.NoError(t, err) - err = multiLimiter.Check([]int64{collectionID}, internalpb.RateType(rt), math.MaxInt) - assert.NoError(t, err) - err = multiLimiter.Check([]int64{collectionID}, internalpb.RateType(rt), math.MaxInt) - assert.ErrorIs(t, err, merr.ErrServiceRateLimit) - } - } - Params.Save(Params.QuotaConfig.QuotaAndLimitsEnabled.Key, bak) - }) - - t.Run("test global static limit", func(t *testing.T) { - bak := Params.QuotaConfig.QuotaAndLimitsEnabled.GetValue() - paramtable.Get().Save(Params.QuotaConfig.QuotaAndLimitsEnabled.Key, "true") - multiLimiter := NewMultiRateLimiter() - multiLimiter.collectionLimiters[1] = newRateLimiter(false) - multiLimiter.collectionLimiters[2] = newRateLimiter(false) - multiLimiter.collectionLimiters[3] = newRateLimiter(false) - for _, rt := range internalpb.RateType_value { - if isNotCollectionLevelLimitRequest(internalpb.RateType(rt)) { - multiLimiter.globalDDLLimiter.limiters.Insert(internalpb.RateType(rt), ratelimitutil.NewLimiter(ratelimitutil.Limit(5), 1)) - } else { - multiLimiter.globalDDLLimiter.limiters.Insert(internalpb.RateType(rt), ratelimitutil.NewLimiter(ratelimitutil.Limit(2), 1)) - multiLimiter.collectionLimiters[1].limiters.Insert(internalpb.RateType(rt), ratelimitutil.NewLimiter(ratelimitutil.Limit(2), 1)) - multiLimiter.collectionLimiters[2].limiters.Insert(internalpb.RateType(rt), ratelimitutil.NewLimiter(ratelimitutil.Limit(2), 1)) - multiLimiter.collectionLimiters[3].limiters.Insert(internalpb.RateType(rt), ratelimitutil.NewLimiter(ratelimitutil.Limit(2), 1)) - } - } - for _, rt := range internalpb.RateType_value { - if internalpb.RateType(rt) == internalpb.RateType_DDLFlush { - 
err := multiLimiter.Check([]int64{1, 2, 3}, internalpb.RateType(rt), 1) - assert.NoError(t, err) - err = multiLimiter.Check([]int64{1, 2, 3}, internalpb.RateType(rt), 5) - assert.NoError(t, err) - err = multiLimiter.Check([]int64{1, 2, 3}, internalpb.RateType(rt), 5) - assert.ErrorIs(t, err, merr.ErrServiceRateLimit) - } else if isNotCollectionLevelLimitRequest(internalpb.RateType(rt)) { - err := multiLimiter.Check([]int64{1}, internalpb.RateType(rt), 1) - assert.NoError(t, err) - err = multiLimiter.Check([]int64{1}, internalpb.RateType(rt), 5) - assert.NoError(t, err) - err = multiLimiter.Check([]int64{1}, internalpb.RateType(rt), 5) - assert.ErrorIs(t, err, merr.ErrServiceRateLimit) - } else { - err := multiLimiter.Check([]int64{1}, internalpb.RateType(rt), 1) - assert.NoError(t, err) - err = multiLimiter.Check([]int64{2}, internalpb.RateType(rt), 1) - assert.NoError(t, err) - err = multiLimiter.Check([]int64{3}, internalpb.RateType(rt), 1) - assert.ErrorIs(t, err, merr.ErrServiceRateLimit) - } - } - Params.Save(Params.QuotaConfig.QuotaAndLimitsEnabled.Key, bak) - }) - - t.Run("not enable quotaAndLimit", func(t *testing.T) { - multiLimiter := NewMultiRateLimiter() - multiLimiter.collectionLimiters[collectionID] = newRateLimiter(false) - bak := Params.QuotaConfig.QuotaAndLimitsEnabled.GetValue() - paramtable.Get().Save(Params.QuotaConfig.QuotaAndLimitsEnabled.Key, "false") - for _, rt := range internalpb.RateType_value { - err := multiLimiter.Check([]int64{collectionID}, internalpb.RateType(rt), 1) - assert.NoError(t, err) - } - Params.Save(Params.QuotaConfig.QuotaAndLimitsEnabled.Key, bak) - }) - - t.Run("test limit", func(t *testing.T) { - run := func(insertRate float64) { - bakInsertRate := Params.QuotaConfig.DMLMaxInsertRate.GetValue() - paramtable.Get().Save(Params.QuotaConfig.DMLMaxInsertRate.Key, fmt.Sprintf("%f", insertRate)) - multiLimiter := NewMultiRateLimiter() - bak := Params.QuotaConfig.QuotaAndLimitsEnabled.GetValue() - paramtable.Get().Save(Params.QuotaConfig.QuotaAndLimitsEnabled.Key, "true") - err := multiLimiter.Check([]int64{collectionID}, internalpb.RateType_DMLInsert, 1*1024*1024) - assert.NoError(t, err) - Params.Save(Params.QuotaConfig.QuotaAndLimitsEnabled.Key, bak) - Params.Save(Params.QuotaConfig.DMLMaxInsertRate.Key, bakInsertRate) - } - run(math.MaxFloat64) - run(math.MaxFloat64 / 1.2) - run(math.MaxFloat64 / 2) - run(math.MaxFloat64 / 3) - run(math.MaxFloat64 / 10000) - }) - - t.Run("test set rates", func(t *testing.T) { - multiLimiter := NewMultiRateLimiter() - zeroRates := make([]*internalpb.Rate, 0, len(internalpb.RateType_value)) - for _, rt := range internalpb.RateType_value { - zeroRates = append(zeroRates, &internalpb.Rate{ - Rt: internalpb.RateType(rt), R: 0, - }) - } - - err := multiLimiter.SetRates([]*proxypb.CollectionRate{ - { - Collection: 1, - Rates: zeroRates, - }, - { - Collection: 2, - Rates: zeroRates, - }, - }) - assert.NoError(t, err) - }) - - t.Run("test quota states", func(t *testing.T) { - multiLimiter := NewMultiRateLimiter() - zeroRates := make([]*internalpb.Rate, 0, len(internalpb.RateType_value)) - for _, rt := range internalpb.RateType_value { - zeroRates = append(zeroRates, &internalpb.Rate{ - Rt: internalpb.RateType(rt), R: 0, - }) - } - - err := multiLimiter.SetRates([]*proxypb.CollectionRate{ - { - Collection: 1, - Rates: zeroRates, - States: []milvuspb.QuotaState{ - milvuspb.QuotaState_DenyToWrite, - }, - Codes: []commonpb.ErrorCode{ - commonpb.ErrorCode_DiskQuotaExhausted, - }, - }, - { - Collection: 2, - Rates: zeroRates, - 
- States: []milvuspb.QuotaState{ - milvuspb.QuotaState_DenyToRead, - }, - Codes: []commonpb.ErrorCode{ - commonpb.ErrorCode_ForceDeny, - }, - }, - }) - assert.NoError(t, err) - - states, codes := multiLimiter.GetQuotaStates() - assert.Len(t, states, 2) - assert.Len(t, codes, 2) - assert.Contains(t, codes, GetQuotaErrorString(commonpb.ErrorCode_DiskQuotaExhausted)) - assert.Contains(t, codes, GetQuotaErrorString(commonpb.ErrorCode_ForceDeny)) - }) -} - -func TestRateLimiter(t *testing.T) { - t.Run("test limit", func(t *testing.T) { - paramtable.Get().CleanEvent() - limiter := newRateLimiter(false) - for _, rt := range internalpb.RateType_value { - limiter.limiters.Insert(internalpb.RateType(rt), ratelimitutil.NewLimiter(ratelimitutil.Limit(1000), 1)) - } - for _, rt := range internalpb.RateType_value { - ok, _ := limiter.limit(internalpb.RateType(rt), 1) - assert.False(t, ok) - ok, _ = limiter.limit(internalpb.RateType(rt), math.MaxInt) - assert.False(t, ok) - ok, _ = limiter.limit(internalpb.RateType(rt), math.MaxInt) - assert.True(t, ok) - } - }) - - t.Run("test setRates", func(t *testing.T) { - paramtable.Get().CleanEvent() - limiter := newRateLimiter(false) - for _, rt := range internalpb.RateType_value { - limiter.limiters.Insert(internalpb.RateType(rt), ratelimitutil.NewLimiter(ratelimitutil.Limit(1000), 1)) - } - - zeroRates := make([]*internalpb.Rate, 0, len(internalpb.RateType_value)) - for _, rt := range internalpb.RateType_value { - zeroRates = append(zeroRates, &internalpb.Rate{ - Rt: internalpb.RateType(rt), R: 0, - }) - } - err := limiter.setRates(&proxypb.CollectionRate{ - Collection: 1, - Rates: zeroRates, - }) - assert.NoError(t, err) - for _, rt := range internalpb.RateType_value { - for i := 0; i < 100; i++ { - ok, _ := limiter.limit(internalpb.RateType(rt), 1) - assert.True(t, ok) - } - } - - err = limiter.setRates(&proxypb.CollectionRate{ - Collection: 1, - States: []milvuspb.QuotaState{milvuspb.QuotaState_DenyToRead, milvuspb.QuotaState_DenyToWrite}, - Codes: []commonpb.ErrorCode{commonpb.ErrorCode_DiskQuotaExhausted, commonpb.ErrorCode_DiskQuotaExhausted}, - }) - assert.NoError(t, err) - assert.Equal(t, limiter.quotaStates.Len(), 2) - - err = limiter.setRates(&proxypb.CollectionRate{ - Collection: 1, - States: []milvuspb.QuotaState{}, - }) - assert.NoError(t, err) - assert.Equal(t, limiter.quotaStates.Len(), 0) - }) - - t.Run("test get error code", func(t *testing.T) { - paramtable.Get().CleanEvent() - limiter := newRateLimiter(false) - for _, rt := range internalpb.RateType_value { - limiter.limiters.Insert(internalpb.RateType(rt), ratelimitutil.NewLimiter(ratelimitutil.Limit(1000), 1)) - } - - zeroRates := make([]*internalpb.Rate, 0, len(internalpb.RateType_value)) - for _, rt := range internalpb.RateType_value { - zeroRates = append(zeroRates, &internalpb.Rate{ - Rt: internalpb.RateType(rt), R: 0, - }) - } - err := limiter.setRates(&proxypb.CollectionRate{ - Collection: 1, - Rates: zeroRates, - States: []milvuspb.QuotaState{ - milvuspb.QuotaState_DenyToWrite, - milvuspb.QuotaState_DenyToRead, - }, - Codes: []commonpb.ErrorCode{ - commonpb.ErrorCode_DiskQuotaExhausted, - commonpb.ErrorCode_ForceDeny, - }, - }) - assert.NoError(t, err) - assert.Error(t, limiter.getQuotaExceededError(internalpb.RateType_DQLQuery)) - assert.Error(t, limiter.getQuotaExceededError(internalpb.RateType_DMLInsert)) - }) - - t.Run("tests refresh rate by config", func(t *testing.T) { - paramtable.Get().CleanEvent() - limiter := newRateLimiter(false) - - etcdCli, _ := etcd.GetEtcdClient( - 
Params.EtcdCfg.UseEmbedEtcd.GetAsBool(), - Params.EtcdCfg.EtcdUseSSL.GetAsBool(), - Params.EtcdCfg.Endpoints.GetAsStrings(), - Params.EtcdCfg.EtcdTLSCert.GetValue(), - Params.EtcdCfg.EtcdTLSKey.GetValue(), - Params.EtcdCfg.EtcdTLSCACert.GetValue(), - Params.EtcdCfg.EtcdTLSMinVersion.GetValue()) - - Params.Save(Params.QuotaConfig.DDLLimitEnabled.Key, "true") - defer Params.Reset(Params.QuotaConfig.DDLLimitEnabled.Key) - Params.Save(Params.QuotaConfig.DMLLimitEnabled.Key, "true") - defer Params.Reset(Params.QuotaConfig.DMLLimitEnabled.Key) - ctx := context.Background() - // avoid production precision issues when comparing 0-terminated numbers - newRate := fmt.Sprintf("%.2f1", rand.Float64()) - etcdCli.KV.Put(ctx, "by-dev/config/quotaAndLimits/ddl/collectionRate", newRate) - defer etcdCli.KV.Delete(ctx, "by-dev/config/quotaAndLimits/ddl/collectionRate") - etcdCli.KV.Put(ctx, "by-dev/config/quotaAndLimits/ddl/partitionRate", "invalid") - defer etcdCli.KV.Delete(ctx, "by-dev/config/quotaAndLimits/ddl/partitionRate") - etcdCli.KV.Put(ctx, "by-dev/config/quotaAndLimits/dml/insertRate/collection/max", "8") - defer etcdCli.KV.Delete(ctx, "by-dev/config/quotaAndLimits/dml/insertRate/collection/max") - - assert.Eventually(t, func() bool { - limit, _ := limiter.limiters.Get(internalpb.RateType_DDLCollection) - return newRate == limit.Limit().String() - }, 20*time.Second, time.Second) - - limit, _ := limiter.limiters.Get(internalpb.RateType_DDLPartition) - assert.Equal(t, "+inf", limit.Limit().String()) - - limit, _ = limiter.limiters.Get(internalpb.RateType_DMLInsert) - assert.Equal(t, "8.388608e+06", limit.Limit().String()) - }) -} diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index b216a8542e..8bd24331fe 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -93,7 +93,7 @@ type Proxy struct { dataCoord types.DataCoordClient queryCoord types.QueryCoordClient - multiRateLimiter *MultiRateLimiter + simpleLimiter *SimpleLimiter chMgr channelsMgr @@ -147,7 +147,7 @@ func NewProxy(ctx context.Context, factory dependency.Factory) (*Proxy, error) { factory: factory, searchResultCh: make(chan *internalpb.SearchResults, n), shardMgr: mgr, - multiRateLimiter: NewMultiRateLimiter(), + simpleLimiter: NewSimpleLimiter(), lbPolicy: lbPolicy, resourceManager: resourceManager, replicateStreamManager: replicateStreamManager, @@ -197,7 +197,7 @@ func (node *Proxy) initSession() error { // initRateCollector creates and starts rateCollector in Proxy. func (node *Proxy) initRateCollector() error { var err error - rateCol, err = ratelimitutil.NewRateCollector(ratelimitutil.DefaultWindow, ratelimitutil.DefaultGranularity) + rateCol, err = ratelimitutil.NewRateCollector(ratelimitutil.DefaultWindow, ratelimitutil.DefaultGranularity, true) if err != nil { return err } @@ -542,8 +542,8 @@ func (node *Proxy) SetQueryNodeCreator(f func(ctx context.Context, addr string, // GetRateLimiter returns the rateLimiter in Proxy. 
func (node *Proxy) GetRateLimiter() (types.Limiter, error) { - if node.multiRateLimiter == nil { + if node.simpleLimiter == nil { return nil, fmt.Errorf("nil rate limiter in Proxy") } - return node.multiRateLimiter, nil + return node.simpleLimiter, nil } diff --git a/internal/proxy/proxy_test.go b/internal/proxy/proxy_test.go index 4f3b6bfe20..566ba86f26 100644 --- a/internal/proxy/proxy_test.go +++ b/internal/proxy/proxy_test.go @@ -298,8 +298,7 @@ func (s *proxyTestServer) startGrpc(ctx context.Context, wg *sync.WaitGroup, p * ctx, cancel := context.WithCancel(ctx) defer cancel() - multiLimiter := NewMultiRateLimiter() - s.multiRateLimiter = multiLimiter + s.simpleLimiter = NewSimpleLimiter() opts := tracer.GetInterceptorOpts() s.grpcServer = grpc.NewServer( @@ -309,7 +308,7 @@ func (s *proxyTestServer) startGrpc(ctx context.Context, wg *sync.WaitGroup, p * grpc.MaxSendMsgSize(p.ServerMaxSendSize.GetAsInt()), grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer( otelgrpc.UnaryServerInterceptor(opts...), - RateLimitInterceptor(multiLimiter), + RateLimitInterceptor(s.simpleLimiter), )), grpc.StreamInterceptor(otelgrpc.StreamServerInterceptor(opts...))) proxypb.RegisterProxyServer(s.grpcServer, s) diff --git a/internal/proxy/rate_limit_interceptor.go b/internal/proxy/rate_limit_interceptor.go index 15a019286e..541fcebdc7 100644 --- a/internal/proxy/rate_limit_interceptor.go +++ b/internal/proxy/rate_limit_interceptor.go @@ -23,25 +23,29 @@ import ( "strconv" "github.com/golang/protobuf/proto" + "go.uber.org/zap" "google.golang.org/grpc" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/requestutil" ) // RateLimitInterceptor returns a new unary server interceptors that performs request rate limiting. 
func RateLimitInterceptor(limiter types.Limiter) grpc.UnaryServerInterceptor { return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { - collectionIDs, rt, n, err := getRequestInfo(req) + dbID, collectionIDToPartIDs, rt, n, err := getRequestInfo(ctx, req) if err != nil { + log.RatedWarn(10, "failed to get request info", zap.Error(err)) return handler(ctx, req) } - err = limiter.Check(collectionIDs, rt, n) + err = limiter.Check(dbID, collectionIDToPartIDs, rt, n) nodeID := strconv.FormatInt(paramtable.GetNodeID(), 10) metrics.ProxyRateLimitReqCount.WithLabelValues(nodeID, rt.String(), metrics.TotalLabel).Inc() if err != nil { @@ -56,72 +60,146 @@ func RateLimitInterceptor(limiter types.Limiter) grpc.UnaryServerInterceptor { } } +type reqPartName interface { + requestutil.DBNameGetter + requestutil.CollectionNameGetter + requestutil.PartitionNameGetter +} + +type reqPartNames interface { + requestutil.DBNameGetter + requestutil.CollectionNameGetter + requestutil.PartitionNamesGetter +} + +type reqCollName interface { + requestutil.DBNameGetter + requestutil.CollectionNameGetter +} + +func getCollectionAndPartitionID(ctx context.Context, r reqPartName) (int64, map[int64][]int64, error) { + db, err := globalMetaCache.GetDatabaseInfo(ctx, r.GetDbName()) + if err != nil { + return 0, nil, err + } + collectionID, err := globalMetaCache.GetCollectionID(ctx, r.GetDbName(), r.GetCollectionName()) + if err != nil { + return 0, nil, err + } + part, err := globalMetaCache.GetPartitionInfo(ctx, r.GetDbName(), r.GetCollectionName(), r.GetPartitionName()) + if err != nil { + return 0, nil, err + } + return db.dbID, map[int64][]int64{collectionID: {part.partitionID}}, nil +} + +func getCollectionAndPartitionIDs(ctx context.Context, r reqPartNames) (int64, map[int64][]int64, error) { + db, err := globalMetaCache.GetDatabaseInfo(ctx, r.GetDbName()) + if err != nil { + return 0, nil, err + } + collectionID, err := globalMetaCache.GetCollectionID(ctx, r.GetDbName(), r.GetCollectionName()) + if err != nil { + return 0, nil, err + } + parts := make([]int64, len(r.GetPartitionNames())) + for i, s := range r.GetPartitionNames() { + part, err := globalMetaCache.GetPartitionInfo(ctx, r.GetDbName(), r.GetCollectionName(), s) + if err != nil { + return 0, nil, err + } + parts[i] = part.partitionID + } + + return db.dbID, map[int64][]int64{collectionID: parts}, nil +} + +func getCollectionID(r reqCollName) (int64, map[int64][]int64) { + db, _ := globalMetaCache.GetDatabaseInfo(context.TODO(), r.GetDbName()) + collectionID, _ := globalMetaCache.GetCollectionID(context.TODO(), r.GetDbName(), r.GetCollectionName()) + return db.dbID, map[int64][]int64{collectionID: {}} +} + // getRequestInfo returns collection name and rateType of request and return tokens needed. 
-func getRequestInfo(req interface{}) ([]int64, internalpb.RateType, int, error) { +func getRequestInfo(ctx context.Context, req interface{}) (int64, map[int64][]int64, internalpb.RateType, int, error) { switch r := req.(type) { case *milvuspb.InsertRequest: - collectionID, _ := globalMetaCache.GetCollectionID(context.TODO(), r.GetDbName(), r.GetCollectionName()) - return []int64{collectionID}, internalpb.RateType_DMLInsert, proto.Size(r), nil + dbID, collToPartIDs, err := getCollectionAndPartitionID(ctx, req.(reqPartName)) + return dbID, collToPartIDs, internalpb.RateType_DMLInsert, proto.Size(r), err case *milvuspb.UpsertRequest: - collectionID, _ := globalMetaCache.GetCollectionID(context.TODO(), r.GetDbName(), r.GetCollectionName()) - return []int64{collectionID}, internalpb.RateType_DMLUpsert, proto.Size(r), nil + dbID, collToPartIDs, err := getCollectionAndPartitionID(ctx, req.(reqPartName)) + return dbID, collToPartIDs, internalpb.RateType_DMLUpsert, proto.Size(r), err case *milvuspb.DeleteRequest: - collectionID, _ := globalMetaCache.GetCollectionID(context.TODO(), r.GetDbName(), r.GetCollectionName()) - return []int64{collectionID}, internalpb.RateType_DMLDelete, proto.Size(r), nil + dbID, collToPartIDs, err := getCollectionAndPartitionID(ctx, req.(reqPartName)) + return dbID, collToPartIDs, internalpb.RateType_DMLDelete, proto.Size(r), err case *milvuspb.ImportRequest: - collectionID, _ := globalMetaCache.GetCollectionID(context.TODO(), r.GetDbName(), r.GetCollectionName()) - return []int64{collectionID}, internalpb.RateType_DMLBulkLoad, proto.Size(r), nil + dbID, collToPartIDs, err := getCollectionAndPartitionID(ctx, req.(reqPartName)) + return dbID, collToPartIDs, internalpb.RateType_DMLBulkLoad, proto.Size(r), err case *milvuspb.SearchRequest: - collectionID, _ := globalMetaCache.GetCollectionID(context.TODO(), r.GetDbName(), r.GetCollectionName()) - return []int64{collectionID}, internalpb.RateType_DQLSearch, int(r.GetNq()), nil + dbID, collToPartIDs, err := getCollectionAndPartitionIDs(ctx, req.(reqPartNames)) + return dbID, collToPartIDs, internalpb.RateType_DQLSearch, int(r.GetNq()), err case *milvuspb.QueryRequest: - collectionID, _ := globalMetaCache.GetCollectionID(context.TODO(), r.GetDbName(), r.GetCollectionName()) - return []int64{collectionID}, internalpb.RateType_DQLQuery, 1, nil // think of the query request's nq as 1 + dbID, collToPartIDs, err := getCollectionAndPartitionIDs(ctx, req.(reqPartNames)) + return dbID, collToPartIDs, internalpb.RateType_DQLQuery, 1, err // think of the query request's nq as 1 case *milvuspb.CreateCollectionRequest: - collectionID, _ := globalMetaCache.GetCollectionID(context.TODO(), r.GetDbName(), r.GetCollectionName()) - return []int64{collectionID}, internalpb.RateType_DDLCollection, 1, nil + dbID, collToPartIDs := getCollectionID(req.(reqCollName)) + return dbID, collToPartIDs, internalpb.RateType_DDLCollection, 1, nil case *milvuspb.DropCollectionRequest: - collectionID, _ := globalMetaCache.GetCollectionID(context.TODO(), r.GetDbName(), r.GetCollectionName()) - return []int64{collectionID}, internalpb.RateType_DDLCollection, 1, nil + dbID, collToPartIDs := getCollectionID(req.(reqCollName)) + return dbID, collToPartIDs, internalpb.RateType_DDLCollection, 1, nil case *milvuspb.LoadCollectionRequest: - collectionID, _ := globalMetaCache.GetCollectionID(context.TODO(), r.GetDbName(), r.GetCollectionName()) - return []int64{collectionID}, internalpb.RateType_DDLCollection, 1, nil + dbID, collToPartIDs := 
getCollectionID(req.(reqCollName)) + return dbID, collToPartIDs, internalpb.RateType_DDLCollection, 1, nil case *milvuspb.ReleaseCollectionRequest: - collectionID, _ := globalMetaCache.GetCollectionID(context.TODO(), r.GetDbName(), r.GetCollectionName()) - return []int64{collectionID}, internalpb.RateType_DDLCollection, 1, nil + dbID, collToPartIDs := getCollectionID(req.(reqCollName)) + return dbID, collToPartIDs, internalpb.RateType_DDLCollection, 1, nil case *milvuspb.CreatePartitionRequest: - collectionID, _ := globalMetaCache.GetCollectionID(context.TODO(), r.GetDbName(), r.GetCollectionName()) - return []int64{collectionID}, internalpb.RateType_DDLPartition, 1, nil + dbID, collToPartIDs := getCollectionID(req.(reqCollName)) + return dbID, collToPartIDs, internalpb.RateType_DDLPartition, 1, nil case *milvuspb.DropPartitionRequest: - collectionID, _ := globalMetaCache.GetCollectionID(context.TODO(), r.GetDbName(), r.GetCollectionName()) - return []int64{collectionID}, internalpb.RateType_DDLPartition, 1, nil + dbID, collToPartIDs := getCollectionID(req.(reqCollName)) + return dbID, collToPartIDs, internalpb.RateType_DDLPartition, 1, nil case *milvuspb.LoadPartitionsRequest: - collectionID, _ := globalMetaCache.GetCollectionID(context.TODO(), r.GetDbName(), r.GetCollectionName()) - return []int64{collectionID}, internalpb.RateType_DDLPartition, 1, nil + dbID, collToPartIDs := getCollectionID(req.(reqCollName)) + return dbID, collToPartIDs, internalpb.RateType_DDLPartition, 1, nil case *milvuspb.ReleasePartitionsRequest: - collectionID, _ := globalMetaCache.GetCollectionID(context.TODO(), r.GetDbName(), r.GetCollectionName()) - return []int64{collectionID}, internalpb.RateType_DDLPartition, 1, nil + dbID, collToPartIDs := getCollectionID(req.(reqCollName)) + return dbID, collToPartIDs, internalpb.RateType_DDLPartition, 1, nil case *milvuspb.CreateIndexRequest: - collectionID, _ := globalMetaCache.GetCollectionID(context.TODO(), r.GetDbName(), r.GetCollectionName()) - return []int64{collectionID}, internalpb.RateType_DDLIndex, 1, nil + dbID, collToPartIDs := getCollectionID(req.(reqCollName)) + return dbID, collToPartIDs, internalpb.RateType_DDLIndex, 1, nil case *milvuspb.DropIndexRequest: - collectionID, _ := globalMetaCache.GetCollectionID(context.TODO(), r.GetDbName(), r.GetCollectionName()) - return []int64{collectionID}, internalpb.RateType_DDLIndex, 1, nil + dbID, collToPartIDs := getCollectionID(req.(reqCollName)) + return dbID, collToPartIDs, internalpb.RateType_DDLIndex, 1, nil case *milvuspb.FlushRequest: - collectionIDs := make([]int64, 0, len(r.GetCollectionNames())) + db, err := globalMetaCache.GetDatabaseInfo(ctx, r.GetDbName()) + if err != nil { + return 0, map[int64][]int64{}, 0, 0, err + } + + collToPartIDs := make(map[int64][]int64, 0) for _, collectionName := range r.GetCollectionNames() { - collectionID, _ := globalMetaCache.GetCollectionID(context.TODO(), r.GetDbName(), collectionName) - collectionIDs = append(collectionIDs, collectionID) + collectionID, err := globalMetaCache.GetCollectionID(ctx, r.GetDbName(), collectionName) + if err != nil { + return 0, map[int64][]int64{}, 0, 0, err + } + collToPartIDs[collectionID] = []int64{} } - return collectionIDs, internalpb.RateType_DDLFlush, 1, nil + return db.dbID, collToPartIDs, internalpb.RateType_DDLFlush, 1, nil case *milvuspb.ManualCompactionRequest: - return nil, internalpb.RateType_DDLCompaction, 1, nil - // TODO: support more request - default: - if req == nil { - return nil, 0, 0, fmt.Errorf("null request") + 
dbName := GetCurDBNameFromContextOrDefault(ctx) + dbInfo, err := globalMetaCache.GetDatabaseInfo(ctx, dbName) + if err != nil { + return 0, map[int64][]int64{}, 0, 0, err } - return nil, 0, 0, fmt.Errorf("unsupported request type %s", reflect.TypeOf(req).Name()) + return dbInfo.dbID, map[int64][]int64{ + r.GetCollectionID(): {}, + }, internalpb.RateType_DDLCompaction, 1, nil + default: // TODO: support more request + if req == nil { + return 0, map[int64][]int64{}, 0, 0, fmt.Errorf("null request") + } + return 0, map[int64][]int64{}, 0, 0, fmt.Errorf("unsupported request type %s", reflect.TypeOf(req).Name()) } } diff --git a/internal/proxy/rate_limit_interceptor_test.go b/internal/proxy/rate_limit_interceptor_test.go index 1271587620..a018d5f307 100644 --- a/internal/proxy/rate_limit_interceptor_test.go +++ b/internal/proxy/rate_limit_interceptor_test.go @@ -20,6 +20,7 @@ import ( "context" "testing" + "github.com/cockroachdb/errors" "github.com/golang/protobuf/proto" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" @@ -38,7 +39,7 @@ type limiterMock struct { quotaStateReasons []commonpb.ErrorCode } -func (l *limiterMock) Check(collection []int64, rt internalpb.RateType, n int) error { +func (l *limiterMock) Check(dbID int64, collectionIDToPartIDs map[int64][]int64, rt internalpb.RateType, n int) error { if l.rate == 0 { return merr.ErrServiceQuotaExceeded } @@ -51,119 +52,173 @@ func (l *limiterMock) Check(collection []int64, rt internalpb.RateType, n int) e func TestRateLimitInterceptor(t *testing.T) { t.Run("test getRequestInfo", func(t *testing.T) { mockCache := NewMockCache(t) - mockCache.On("GetCollectionID", - mock.Anything, // context.Context - mock.AnythingOfType("string"), - mock.AnythingOfType("string"), - ).Return(int64(0), nil) + mockCache.EXPECT().GetCollectionID(mock.Anything, mock.Anything, mock.Anything).Return(int64(1), nil) + mockCache.EXPECT().GetPartitionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&partitionInfo{ + name: "p1", + partitionID: 10, + createdTimestamp: 10001, + createdUtcTimestamp: 10002, + }, nil) + mockCache.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{ + dbID: 100, + createdTimestamp: 1, + }, nil) globalMetaCache = mockCache - collection, rt, size, err := getRequestInfo(&milvuspb.InsertRequest{}) + database, col2part, rt, size, err := getRequestInfo(context.Background(), &milvuspb.InsertRequest{}) assert.NoError(t, err) assert.Equal(t, proto.Size(&milvuspb.InsertRequest{}), size) assert.Equal(t, internalpb.RateType_DMLInsert, rt) - assert.ElementsMatch(t, collection, []int64{int64(0)}) + assert.Equal(t, database, int64(100)) + assert.True(t, len(col2part) == 1) + assert.Equal(t, int64(10), col2part[1][0]) - collection, rt, size, err = getRequestInfo(&milvuspb.UpsertRequest{}) + database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.UpsertRequest{}) assert.NoError(t, err) assert.Equal(t, proto.Size(&milvuspb.InsertRequest{}), size) assert.Equal(t, internalpb.RateType_DMLUpsert, rt) - assert.ElementsMatch(t, collection, []int64{int64(0)}) + assert.Equal(t, database, int64(100)) + assert.True(t, len(col2part) == 1) + assert.Equal(t, int64(10), col2part[1][0]) - collection, rt, size, err = getRequestInfo(&milvuspb.DeleteRequest{}) + database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.DeleteRequest{}) assert.NoError(t, err) assert.Equal(t, proto.Size(&milvuspb.DeleteRequest{}), size) assert.Equal(t, 
internalpb.RateType_DMLDelete, rt) - assert.ElementsMatch(t, collection, []int64{int64(0)}) + assert.Equal(t, database, int64(100)) + assert.True(t, len(col2part) == 1) + assert.Equal(t, int64(10), col2part[1][0]) - collection, rt, size, err = getRequestInfo(&milvuspb.ImportRequest{}) + database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.ImportRequest{}) assert.NoError(t, err) assert.Equal(t, proto.Size(&milvuspb.ImportRequest{}), size) assert.Equal(t, internalpb.RateType_DMLBulkLoad, rt) - assert.ElementsMatch(t, collection, []int64{int64(0)}) + assert.Equal(t, database, int64(100)) + assert.True(t, len(col2part) == 1) + assert.Equal(t, int64(10), col2part[1][0]) - collection, rt, size, err = getRequestInfo(&milvuspb.SearchRequest{Nq: 5}) + database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.SearchRequest{ + Nq: 5, + PartitionNames: []string{ + "p1", + }, + }) assert.NoError(t, err) assert.Equal(t, 5, size) assert.Equal(t, internalpb.RateType_DQLSearch, rt) - assert.ElementsMatch(t, collection, []int64{int64(0)}) + assert.Equal(t, database, int64(100)) + assert.Equal(t, 1, len(col2part)) + assert.Equal(t, 1, len(col2part[1])) - collection, rt, size, err = getRequestInfo(&milvuspb.QueryRequest{}) + database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.QueryRequest{}) assert.NoError(t, err) assert.Equal(t, 1, size) assert.Equal(t, internalpb.RateType_DQLQuery, rt) - assert.ElementsMatch(t, collection, []int64{int64(0)}) + assert.Equal(t, database, int64(100)) + assert.Equal(t, 1, len(col2part)) + assert.Equal(t, 0, len(col2part[1])) - collection, rt, size, err = getRequestInfo(&milvuspb.CreateCollectionRequest{}) + database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.CreateCollectionRequest{}) assert.NoError(t, err) assert.Equal(t, 1, size) assert.Equal(t, internalpb.RateType_DDLCollection, rt) - assert.ElementsMatch(t, collection, []int64{int64(0)}) + assert.Equal(t, database, int64(100)) + assert.Equal(t, 1, len(col2part)) + assert.Equal(t, 0, len(col2part[1])) - collection, rt, size, err = getRequestInfo(&milvuspb.LoadCollectionRequest{}) + database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.LoadCollectionRequest{}) assert.NoError(t, err) assert.Equal(t, 1, size) assert.Equal(t, internalpb.RateType_DDLCollection, rt) - assert.ElementsMatch(t, collection, []int64{int64(0)}) + assert.Equal(t, database, int64(100)) + assert.Equal(t, 1, len(col2part)) + assert.Equal(t, 0, len(col2part[1])) - collection, rt, size, err = getRequestInfo(&milvuspb.ReleaseCollectionRequest{}) + database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.ReleaseCollectionRequest{}) assert.NoError(t, err) assert.Equal(t, 1, size) assert.Equal(t, internalpb.RateType_DDLCollection, rt) - assert.ElementsMatch(t, collection, []int64{int64(0)}) + assert.Equal(t, database, int64(100)) + assert.Equal(t, 1, len(col2part)) + assert.Equal(t, 0, len(col2part[1])) - collection, rt, size, err = getRequestInfo(&milvuspb.DropCollectionRequest{}) + database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.DropCollectionRequest{}) assert.NoError(t, err) assert.Equal(t, 1, size) assert.Equal(t, internalpb.RateType_DDLCollection, rt) - assert.ElementsMatch(t, collection, []int64{int64(0)}) + assert.Equal(t, database, int64(100)) + assert.Equal(t, 1, len(col2part)) + assert.Equal(t, 0, len(col2part[1])) - collection, rt, size, err = 
getRequestInfo(&milvuspb.CreatePartitionRequest{}) + database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.CreatePartitionRequest{}) assert.NoError(t, err) assert.Equal(t, 1, size) assert.Equal(t, internalpb.RateType_DDLPartition, rt) - assert.ElementsMatch(t, collection, []int64{int64(0)}) + assert.Equal(t, database, int64(100)) + assert.Equal(t, 1, len(col2part)) + assert.Equal(t, 0, len(col2part[1])) - collection, rt, size, err = getRequestInfo(&milvuspb.LoadPartitionsRequest{}) + database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.LoadPartitionsRequest{}) assert.NoError(t, err) assert.Equal(t, 1, size) assert.Equal(t, internalpb.RateType_DDLPartition, rt) - assert.ElementsMatch(t, collection, []int64{int64(0)}) + assert.Equal(t, database, int64(100)) + assert.Equal(t, 1, len(col2part)) + assert.Equal(t, 0, len(col2part[1])) - collection, rt, size, err = getRequestInfo(&milvuspb.ReleasePartitionsRequest{}) + database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.ReleasePartitionsRequest{}) assert.NoError(t, err) assert.Equal(t, 1, size) assert.Equal(t, internalpb.RateType_DDLPartition, rt) - assert.ElementsMatch(t, collection, []int64{int64(0)}) + assert.Equal(t, database, int64(100)) + assert.Equal(t, 1, len(col2part)) + assert.Equal(t, 0, len(col2part[1])) - collection, rt, size, err = getRequestInfo(&milvuspb.DropPartitionRequest{}) + database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.DropPartitionRequest{}) assert.NoError(t, err) assert.Equal(t, 1, size) assert.Equal(t, internalpb.RateType_DDLPartition, rt) - assert.ElementsMatch(t, collection, []int64{int64(0)}) + assert.Equal(t, database, int64(100)) + assert.Equal(t, 1, len(col2part)) + assert.Equal(t, 0, len(col2part[1])) - collection, rt, size, err = getRequestInfo(&milvuspb.CreateIndexRequest{}) + database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.CreateIndexRequest{}) assert.NoError(t, err) assert.Equal(t, 1, size) assert.Equal(t, internalpb.RateType_DDLIndex, rt) - assert.ElementsMatch(t, collection, []int64{int64(0)}) + assert.Equal(t, database, int64(100)) + assert.Equal(t, 1, len(col2part)) + assert.Equal(t, 0, len(col2part[1])) - collection, rt, size, err = getRequestInfo(&milvuspb.DropIndexRequest{}) + database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.DropIndexRequest{}) assert.NoError(t, err) assert.Equal(t, 1, size) assert.Equal(t, internalpb.RateType_DDLIndex, rt) - assert.ElementsMatch(t, collection, []int64{int64(0)}) + assert.Equal(t, database, int64(100)) + assert.Equal(t, 1, len(col2part)) + assert.Equal(t, 0, len(col2part[1])) - collection, rt, size, err = getRequestInfo(&milvuspb.FlushRequest{}) + database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.FlushRequest{ + CollectionNames: []string{ + "col1", + }, + }) assert.NoError(t, err) assert.Equal(t, 1, size) assert.Equal(t, internalpb.RateType_DDLFlush, rt) - assert.Len(t, collection, 0) + assert.Equal(t, database, int64(100)) + assert.Equal(t, 1, len(col2part)) - collection, rt, size, err = getRequestInfo(&milvuspb.ManualCompactionRequest{}) + database, _, rt, size, err = getRequestInfo(context.Background(), &milvuspb.ManualCompactionRequest{}) assert.NoError(t, err) assert.Equal(t, 1, size) assert.Equal(t, internalpb.RateType_DDLCompaction, rt) - assert.Len(t, collection, 0) + assert.Equal(t, database, int64(100)) + + _, _, _, _, err = 
getRequestInfo(context.Background(), nil) + assert.Error(t, err) + + _, _, _, _, err = getRequestInfo(context.Background(), &milvuspb.CalcDistanceRequest{}) + assert.Error(t, err) }) t.Run("test getFailedResponse", func(t *testing.T) { @@ -190,11 +245,17 @@ func TestRateLimitInterceptor(t *testing.T) { t.Run("test RateLimitInterceptor", func(t *testing.T) { mockCache := NewMockCache(t) - mockCache.On("GetCollectionID", - mock.Anything, // context.Context - mock.AnythingOfType("string"), - mock.AnythingOfType("string"), - ).Return(int64(0), nil) + mockCache.EXPECT().GetCollectionID(mock.Anything, mock.Anything, mock.Anything).Return(int64(1), nil) + mockCache.EXPECT().GetPartitionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&partitionInfo{ + name: "p1", + partitionID: 10, + createdTimestamp: 10001, + createdUtcTimestamp: 10002, + }, nil) + mockCache.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{ + dbID: 100, + createdTimestamp: 1, + }, nil) globalMetaCache = mockCache limiter := limiterMock{rate: 100} @@ -224,4 +285,158 @@ func TestRateLimitInterceptor(t *testing.T) { assert.Equal(t, commonpb.ErrorCode_ForceDeny, rsp.(*milvuspb.MutationResult).GetStatus().GetErrorCode()) assert.NoError(t, err) }) + + t.Run("request info fail", func(t *testing.T) { + mockCache := NewMockCache(t) + mockCache.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(nil, errors.New("mock error: get database info")) + originCache := globalMetaCache + globalMetaCache = mockCache + defer func() { + globalMetaCache = originCache + }() + + limiter := limiterMock{rate: 100} + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return &milvuspb.MutationResult{ + Status: merr.Success(), + }, nil + } + serverInfo := &grpc.UnaryServerInfo{FullMethod: "MockFullMethod"} + + limiter.limit = true + interceptorFun := RateLimitInterceptor(&limiter) + rsp, err := interceptorFun(context.Background(), &milvuspb.InsertRequest{}, serverInfo, handler) + assert.Equal(t, commonpb.ErrorCode_Success, rsp.(*milvuspb.MutationResult).GetStatus().GetErrorCode()) + assert.NoError(t, err) + }) +} + +func TestGetInfo(t *testing.T) { + mockCache := NewMockCache(t) + ctx := context.Background() + originCache := globalMetaCache + globalMetaCache = mockCache + defer func() { + globalMetaCache = originCache + }() + + t.Run("fail to get database", func(t *testing.T) { + mockCache.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(nil, errors.New("mock error: get database info")).Times(4) + { + _, _, err := getCollectionAndPartitionID(ctx, &milvuspb.InsertRequest{ + DbName: "foo", + CollectionName: "coo", + PartitionName: "p1", + }) + assert.Error(t, err) + } + { + _, _, err := getCollectionAndPartitionIDs(ctx, &milvuspb.SearchRequest{ + DbName: "foo", + CollectionName: "coo", + PartitionNames: []string{"p1"}, + }) + assert.Error(t, err) + } + { + _, _, _, _, err := getRequestInfo(ctx, &milvuspb.FlushRequest{ + DbName: "foo", + }) + assert.Error(t, err) + } + { + _, _, _, _, err := getRequestInfo(ctx, &milvuspb.ManualCompactionRequest{}) + assert.Error(t, err) + } + }) + + t.Run("fail to get collection", func(t *testing.T) { + mockCache.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{ + dbID: 100, + createdTimestamp: 1, + }, nil).Times(3) + mockCache.EXPECT().GetCollectionID(mock.Anything, mock.Anything, mock.Anything).Return(int64(0), errors.New("mock error: get collection id")).Times(3) + { + _, _, err := 
getCollectionAndPartitionID(ctx, &milvuspb.InsertRequest{ + DbName: "foo", + CollectionName: "coo", + PartitionName: "p1", + }) + assert.Error(t, err) + } + { + _, _, err := getCollectionAndPartitionIDs(ctx, &milvuspb.SearchRequest{ + DbName: "foo", + CollectionName: "coo", + PartitionNames: []string{"p1"}, + }) + assert.Error(t, err) + } + { + _, _, _, _, err := getRequestInfo(ctx, &milvuspb.FlushRequest{ + DbName: "foo", + CollectionNames: []string{"coo"}, + }) + assert.Error(t, err) + } + }) + + t.Run("fail to get partition", func(t *testing.T) { + mockCache.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{ + dbID: 100, + createdTimestamp: 1, + }, nil).Twice() + mockCache.EXPECT().GetCollectionID(mock.Anything, mock.Anything, mock.Anything).Return(int64(1), nil).Twice() + mockCache.EXPECT().GetPartitionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, errors.New("mock error: get partition info")).Twice() + { + _, _, err := getCollectionAndPartitionID(ctx, &milvuspb.InsertRequest{ + DbName: "foo", + CollectionName: "coo", + PartitionName: "p1", + }) + assert.Error(t, err) + } + { + _, _, err := getCollectionAndPartitionIDs(ctx, &milvuspb.SearchRequest{ + DbName: "foo", + CollectionName: "coo", + PartitionNames: []string{"p1"}, + }) + assert.Error(t, err) + } + }) + + t.Run("success", func(t *testing.T) { + mockCache.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{ + dbID: 100, + createdTimestamp: 1, + }, nil).Twice() + mockCache.EXPECT().GetCollectionID(mock.Anything, mock.Anything, mock.Anything).Return(int64(10), nil).Twice() + mockCache.EXPECT().GetPartitionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&partitionInfo{ + name: "p1", + partitionID: 100, + }, nil) + { + db, col2par, err := getCollectionAndPartitionID(ctx, &milvuspb.InsertRequest{ + DbName: "foo", + CollectionName: "coo", + PartitionName: "p1", + }) + assert.NoError(t, err) + assert.Equal(t, int64(100), db) + assert.NotNil(t, col2par[10]) + assert.Equal(t, int64(100), col2par[10][0]) + } + { + db, col2par, err := getCollectionAndPartitionIDs(ctx, &milvuspb.SearchRequest{ + DbName: "foo", + CollectionName: "coo", + PartitionNames: []string{"p1"}, + }) + assert.NoError(t, err) + assert.Equal(t, int64(100), db) + assert.NotNil(t, col2par[10]) + assert.Equal(t, int64(100), col2par[10][0]) + } + }) } diff --git a/internal/proxy/rootcoord_mock_test.go b/internal/proxy/rootcoord_mock_test.go index a9f258a8a1..53d639e91c 100644 --- a/internal/proxy/rootcoord_mock_test.go +++ b/internal/proxy/rootcoord_mock_test.go @@ -1111,6 +1111,10 @@ func (coord *RootCoordMock) RenameCollection(ctx context.Context, req *milvuspb. 
return &commonpb.Status{}, nil } +func (coord *RootCoordMock) DescribeDatabase(ctx context.Context, in *rootcoordpb.DescribeDatabaseRequest, opts ...grpc.CallOption) (*rootcoordpb.DescribeDatabaseResponse, error) { + return &rootcoordpb.DescribeDatabaseResponse{}, nil +} + type DescribeCollectionFunc func(ctx context.Context, request *milvuspb.DescribeCollectionRequest, opts ...grpc.CallOption) (*milvuspb.DescribeCollectionResponse, error) type ShowPartitionsFunc func(ctx context.Context, request *milvuspb.ShowPartitionsRequest, opts ...grpc.CallOption) (*milvuspb.ShowPartitionsResponse, error) diff --git a/internal/proxy/simple_rate_limiter.go b/internal/proxy/simple_rate_limiter.go new file mode 100644 index 0000000000..97fd8c9c2a --- /dev/null +++ b/internal/proxy/simple_rate_limiter.go @@ -0,0 +1,344 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proxy + +import ( + "context" + "fmt" + "strconv" + "sync" + + "go.uber.org/zap" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus/internal/proto/internalpb" + "github.com/milvus-io/milvus/internal/proto/proxypb" + "github.com/milvus-io/milvus/internal/util/quota" + rlinternal "github.com/milvus-io/milvus/internal/util/ratelimitutil" + "github.com/milvus-io/milvus/pkg/config" + "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/metrics" + "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/ratelimitutil" + "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +// SimpleLimiter is implemented based on Limiter interface +type SimpleLimiter struct { + quotaStatesMu sync.RWMutex + rateLimiter *rlinternal.RateLimiterTree +} + +// NewSimpleLimiter returns a new SimpleLimiter. +func NewSimpleLimiter() *SimpleLimiter { + rootRateLimiter := newClusterLimiter() + m := &SimpleLimiter{rateLimiter: rlinternal.NewRateLimiterTree(rootRateLimiter)} + return m +} + +// Check checks if request would be limited or denied. +func (m *SimpleLimiter) Check(dbID int64, collectionIDToPartIDs map[int64][]int64, rt internalpb.RateType, n int) error { + if !Params.QuotaConfig.QuotaAndLimitsEnabled.GetAsBool() { + return nil + } + + m.quotaStatesMu.RLock() + defer m.quotaStatesMu.RUnlock() + + // 1. check global(cluster) level rate limits + clusterRateLimiters := m.rateLimiter.GetRootLimiters() + ret := clusterRateLimiters.Check(rt, n) + + if ret != nil { + clusterRateLimiters.Cancel(rt, n) + return ret + } + + // store done limiters to cancel them when error occurs. 
+ doneLimiters := make([]*rlinternal.RateLimiterNode, 0) + doneLimiters = append(doneLimiters, clusterRateLimiters) + + cancelAllLimiters := func() { + for _, limiter := range doneLimiters { + limiter.Cancel(rt, n) + } + } + + // 2. check database level rate limits + if ret == nil { + dbRateLimiters := m.rateLimiter.GetOrCreateDatabaseLimiters(dbID, newDatabaseLimiter) + ret = dbRateLimiters.Check(rt, n) + if ret != nil { + cancelAllLimiters() + return ret + } + doneLimiters = append(doneLimiters, dbRateLimiters) + } + + // 3. check collection level rate limits + if ret == nil && len(collectionIDToPartIDs) > 0 && !isNotCollectionLevelLimitRequest(rt) { + for collectionID := range collectionIDToPartIDs { + // only dml and dql have collection level rate limits + collectionRateLimiters := m.rateLimiter.GetOrCreateCollectionLimiters(dbID, collectionID, + newDatabaseLimiter, newCollectionLimiters) + ret = collectionRateLimiters.Check(rt, n) + if ret != nil { + cancelAllLimiters() + return ret + } + doneLimiters = append(doneLimiters, collectionRateLimiters) + } + } + + // 4. check partition level rate limits + if ret == nil && len(collectionIDToPartIDs) > 0 { + for collectionID, partitionIDs := range collectionIDToPartIDs { + for _, partID := range partitionIDs { + partitionRateLimiters := m.rateLimiter.GetOrCreatePartitionLimiters(dbID, collectionID, partID, + newDatabaseLimiter, newCollectionLimiters, newPartitionLimiters) + ret = partitionRateLimiters.Check(rt, n) + if ret != nil { + cancelAllLimiters() + return ret + } + doneLimiters = append(doneLimiters, partitionRateLimiters) + } + } + } + + return ret +} + +func isNotCollectionLevelLimitRequest(rt internalpb.RateType) bool { + // Most ddl is global level, only DDLFlush will be applied at collection + switch rt { + case internalpb.RateType_DDLCollection, + internalpb.RateType_DDLPartition, + internalpb.RateType_DDLIndex, + internalpb.RateType_DDLCompaction: + return true + default: + return false + } +} + +// GetQuotaStates returns quota states. +func (m *SimpleLimiter) GetQuotaStates() ([]milvuspb.QuotaState, []string) { + m.quotaStatesMu.RLock() + defer m.quotaStatesMu.RUnlock() + serviceStates := make(map[milvuspb.QuotaState]typeutil.Set[commonpb.ErrorCode]) + + rlinternal.TraverseRateLimiterTree(m.rateLimiter.GetRootLimiters(), nil, + func(node *rlinternal.RateLimiterNode, state milvuspb.QuotaState, errCode commonpb.ErrorCode) bool { + if serviceStates[state] == nil { + serviceStates[state] = typeutil.NewSet[commonpb.ErrorCode]() + } + serviceStates[state].Insert(errCode) + return true + }) + + states := make([]milvuspb.QuotaState, 0) + reasons := make([]string, 0) + for state, errCodes := range serviceStates { + for errCode := range errCodes { + states = append(states, state) + reasons = append(reasons, ratelimitutil.GetQuotaErrorString(errCode)) + } + } + + return states, reasons +} + +// SetRates sets quota states for SimpleLimiter. 
+func (m *SimpleLimiter) SetRates(rootLimiter *proxypb.LimiterNode) error { + m.quotaStatesMu.Lock() + defer m.quotaStatesMu.Unlock() + if err := m.updateRateLimiter(rootLimiter); err != nil { + return err + } + + m.rateLimiter.ClearInvalidLimiterNode(rootLimiter) + return nil +} + +func initLimiter(rln *rlinternal.RateLimiterNode, rateLimiterConfigs map[internalpb.RateType]*paramtable.ParamItem) { + log := log.Ctx(context.TODO()).WithRateGroup("proxy.rateLimiter", 1.0, 60.0) + for rt, p := range rateLimiterConfigs { + limit := ratelimitutil.Limit(p.GetAsFloat()) + burst := p.GetAsFloat() // use rate as burst, because SimpleLimiter is with punishment mechanism, burst is insignificant. + rln.GetLimiters().GetOrInsert(rt, ratelimitutil.NewLimiter(limit, burst)) + onEvent := func(rateType internalpb.RateType, formatFunc func(originValue string) string) func(*config.Event) { + return func(event *config.Event) { + f, err := strconv.ParseFloat(formatFunc(event.Value), 64) + if err != nil { + log.Info("Error format for rateLimit", + zap.String("rateType", rateType.String()), + zap.String("key", event.Key), + zap.String("value", event.Value), + zap.Error(err)) + return + } + l, ok := rln.GetLimiters().Get(rateType) + if !ok { + log.Info("rateLimiter not found for rateType", zap.String("rateType", rateType.String())) + return + } + l.SetLimit(ratelimitutil.Limit(f)) + } + }(rt, p.Formatter) + paramtable.Get().Watch(p.Key, config.NewHandler(fmt.Sprintf("rateLimiter-%d", rt), onEvent)) + log.RatedDebug(30, "RateLimiter register for rateType", + zap.String("rateType", internalpb.RateType_name[(int32(rt))]), + zap.String("rateLimit", ratelimitutil.Limit(p.GetAsFloat()).String()), + zap.String("burst", fmt.Sprintf("%v", burst))) + } +} + +// newClusterLimiter init limiter of cluster level for all rate types and rate scopes. +// Cluster rate limiter doesn't support to accumulate metrics dynamically, it only uses +// configurations as limit values. 
+func newClusterLimiter() *rlinternal.RateLimiterNode { + clusterRateLimiters := rlinternal.NewRateLimiterNode(internalpb.RateScope_Cluster) + clusterLimiterConfigs := getDefaultLimiterConfig(internalpb.RateScope_Cluster) + initLimiter(clusterRateLimiters, clusterLimiterConfigs) + return clusterRateLimiters +} + +func newDatabaseLimiter() *rlinternal.RateLimiterNode { + dbRateLimiters := rlinternal.NewRateLimiterNode(internalpb.RateScope_Database) + databaseLimiterConfigs := getDefaultLimiterConfig(internalpb.RateScope_Database) + initLimiter(dbRateLimiters, databaseLimiterConfigs) + return dbRateLimiters +} + +func newCollectionLimiters() *rlinternal.RateLimiterNode { + collectionRateLimiters := rlinternal.NewRateLimiterNode(internalpb.RateScope_Collection) + collectionLimiterConfigs := getDefaultLimiterConfig(internalpb.RateScope_Collection) + initLimiter(collectionRateLimiters, collectionLimiterConfigs) + return collectionRateLimiters +} + +func newPartitionLimiters() *rlinternal.RateLimiterNode { + partRateLimiters := rlinternal.NewRateLimiterNode(internalpb.RateScope_Partition) + collectionLimiterConfigs := getDefaultLimiterConfig(internalpb.RateScope_Partition) + initLimiter(partRateLimiters, collectionLimiterConfigs) + return partRateLimiters +} + +func (m *SimpleLimiter) updateLimiterNode(req *proxypb.Limiter, node *rlinternal.RateLimiterNode, sourceID string) error { + curLimiters := node.GetLimiters() + for _, rate := range req.GetRates() { + limit, ok := curLimiters.Get(rate.GetRt()) + if !ok { + return fmt.Errorf("unregister rateLimiter for rateType %s", rate.GetRt().String()) + } + limit.SetLimit(ratelimitutil.Limit(rate.GetR())) + setRateGaugeByRateType(rate.GetRt(), paramtable.GetNodeID(), sourceID, rate.GetR()) + } + quotaStates := typeutil.NewConcurrentMap[milvuspb.QuotaState, commonpb.ErrorCode]() + states := req.GetStates() + codes := req.GetCodes() + for i, state := range states { + quotaStates.Insert(state, codes[i]) + } + node.SetQuotaStates(quotaStates) + return nil +} + +func (m *SimpleLimiter) updateRateLimiter(reqRootLimiterNode *proxypb.LimiterNode) error { + reqClusterLimiter := reqRootLimiterNode.GetLimiter() + clusterLimiter := m.rateLimiter.GetRootLimiters() + err := m.updateLimiterNode(reqClusterLimiter, clusterLimiter, "cluster") + if err != nil { + log.Warn("update cluster rate limiters failed", zap.Error(err)) + return err + } + + getDBSourceID := func(dbID int64) string { + return fmt.Sprintf("db.%d", dbID) + } + getCollectionSourceID := func(collectionID int64) string { + return fmt.Sprintf("collection.%d", collectionID) + } + getPartitionSourceID := func(partitionID int64) string { + return fmt.Sprintf("partition.%d", partitionID) + } + + for dbID, reqDBRateLimiters := range reqRootLimiterNode.GetChildren() { + // update database rate limiters + dbRateLimiters := m.rateLimiter.GetOrCreateDatabaseLimiters(dbID, newDatabaseLimiter) + err := m.updateLimiterNode(reqDBRateLimiters.GetLimiter(), dbRateLimiters, getDBSourceID(dbID)) + if err != nil { + log.Warn("update database rate limiters failed", zap.Error(err)) + return err + } + + // update collection rate limiters + for collectionID, reqCollectionRateLimiter := range reqDBRateLimiters.GetChildren() { + collectionRateLimiter := m.rateLimiter.GetOrCreateCollectionLimiters(dbID, collectionID, + newDatabaseLimiter, newCollectionLimiters) + err := m.updateLimiterNode(reqCollectionRateLimiter.GetLimiter(), collectionRateLimiter, + getCollectionSourceID(collectionID)) + if err != nil { + log.Warn("update 
collection rate limiters failed", zap.Error(err)) + return err + } + + // update partition rate limiters + for partitionID, reqPartitionRateLimiters := range reqCollectionRateLimiter.GetChildren() { + partitionRateLimiter := m.rateLimiter.GetOrCreatePartitionLimiters(dbID, collectionID, partitionID, + newDatabaseLimiter, newCollectionLimiters, newPartitionLimiters) + + err := m.updateLimiterNode(reqPartitionRateLimiters.GetLimiter(), partitionRateLimiter, + getPartitionSourceID(partitionID)) + if err != nil { + log.Warn("update partition rate limiters failed", zap.Error(err)) + return err + } + } + } + } + + return nil +} + +// setRateGaugeByRateType sets ProxyLimiterRate metrics. +func setRateGaugeByRateType(rateType internalpb.RateType, nodeID int64, sourceID string, rate float64) { + if ratelimitutil.Limit(rate) == ratelimitutil.Inf { + return + } + nodeIDStr := strconv.FormatInt(nodeID, 10) + metrics.ProxyLimiterRate.WithLabelValues(nodeIDStr, sourceID, rateType.String()).Set(rate) +} + +func getDefaultLimiterConfig(scope internalpb.RateScope) map[internalpb.RateType]*paramtable.ParamItem { + return quota.GetQuotaConfigMap(scope) +} + +func IsDDLRequest(rt internalpb.RateType) bool { + switch rt { + case internalpb.RateType_DDLCollection, + internalpb.RateType_DDLPartition, + internalpb.RateType_DDLIndex, + internalpb.RateType_DDLFlush, + internalpb.RateType_DDLCompaction: + return true + default: + return false + } +} diff --git a/internal/proxy/simple_rate_limiter_test.go b/internal/proxy/simple_rate_limiter_test.go new file mode 100644 index 0000000000..f5427797b6 --- /dev/null +++ b/internal/proxy/simple_rate_limiter_test.go @@ -0,0 +1,415 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
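SimpleLimiter.Check above walks the limiter tree from the cluster level down through database, collection and partition, and on the first rejection it returns the tokens already taken by the levels that had admitted the request. A condensed, self-contained sketch of that check-and-rollback pattern follows; the level type and package name are stand-ins, not the real rlinternal.RateLimiterNode API.

package ratelimitsketch

import "fmt"

// level is a stand-in for one node in the limiter tree
// (cluster, database, collection or partition).
type level struct {
	name  string
	allow func(n int) bool // consume n tokens if available
	undo  func(n int)      // return n tokens
}

// checkAll admits a request only if every level admits it; on the first
// rejection it rolls back the tokens already taken at the levels above.
func checkAll(levels []level, n int) error {
	admitted := make([]level, 0, len(levels))
	for _, lv := range levels {
		if !lv.allow(n) {
			for _, a := range admitted {
				a.undo(n)
			}
			return fmt.Errorf("rate limited at %s level", lv.name)
		}
		admitted = append(admitted, lv)
	}
	return nil
}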
+ +package proxy + +import ( + "context" + "fmt" + "math" + "math/rand" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus/internal/proto/internalpb" + "github.com/milvus-io/milvus/internal/proto/proxypb" + rlinternal "github.com/milvus-io/milvus/internal/util/ratelimitutil" + "github.com/milvus-io/milvus/pkg/util/etcd" + "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/ratelimitutil" +) + +func TestSimpleRateLimiter(t *testing.T) { + collectionID := int64(1) + collectionIDToPartIDs := map[int64][]int64{collectionID: {}} + t.Run("test simpleRateLimiter", func(t *testing.T) { + bak := Params.QuotaConfig.QuotaAndLimitsEnabled.GetValue() + paramtable.Get().Save(Params.QuotaConfig.QuotaAndLimitsEnabled.Key, "true") + + simpleLimiter := NewSimpleLimiter() + clusterRateLimiters := simpleLimiter.rateLimiter.GetRootLimiters() + + simpleLimiter.rateLimiter.GetOrCreateCollectionLimiters(0, collectionID, newDatabaseLimiter, + func() *rlinternal.RateLimiterNode { + collectionRateLimiters := rlinternal.NewRateLimiterNode(internalpb.RateScope_Cluster) + + for _, rt := range internalpb.RateType_value { + if IsDDLRequest(internalpb.RateType(rt)) { + clusterRateLimiters.GetLimiters(). + Insert(internalpb.RateType(rt), ratelimitutil.NewLimiter(ratelimitutil.Limit(5), 1)) + } else { + collectionRateLimiters.GetLimiters(). + Insert(internalpb.RateType(rt), ratelimitutil.NewLimiter(ratelimitutil.Limit(1000), 1)) + } + } + + return collectionRateLimiters + }) + + for _, rt := range internalpb.RateType_value { + if IsDDLRequest(internalpb.RateType(rt)) { + err := simpleLimiter.Check(0, collectionIDToPartIDs, internalpb.RateType(rt), 1) + assert.NoError(t, err) + err = simpleLimiter.Check(0, collectionIDToPartIDs, internalpb.RateType(rt), 5) + assert.NoError(t, err) + err = simpleLimiter.Check(0, collectionIDToPartIDs, internalpb.RateType(rt), 5) + assert.ErrorIs(t, err, merr.ErrServiceRateLimit) + } else { + err := simpleLimiter.Check(0, collectionIDToPartIDs, internalpb.RateType(rt), 1) + assert.NoError(t, err) + err = simpleLimiter.Check(0, collectionIDToPartIDs, internalpb.RateType(rt), math.MaxInt) + assert.NoError(t, err) + err = simpleLimiter.Check(0, collectionIDToPartIDs, internalpb.RateType(rt), math.MaxInt) + assert.ErrorIs(t, err, merr.ErrServiceRateLimit) + } + } + Params.Save(Params.QuotaConfig.QuotaAndLimitsEnabled.Key, bak) + }) + + t.Run("test global static limit", func(t *testing.T) { + bak := Params.QuotaConfig.QuotaAndLimitsEnabled.GetValue() + paramtable.Get().Save(Params.QuotaConfig.QuotaAndLimitsEnabled.Key, "true") + simpleLimiter := NewSimpleLimiter() + clusterRateLimiters := simpleLimiter.rateLimiter.GetRootLimiters() + + collectionIDToPartIDs := map[int64][]int64{ + 1: {}, + 2: {}, + 3: {}, + } + + for i := 1; i <= 3; i++ { + simpleLimiter.rateLimiter.GetOrCreateCollectionLimiters(0, int64(i), newDatabaseLimiter, + func() *rlinternal.RateLimiterNode { + collectionRateLimiters := rlinternal.NewRateLimiterNode(internalpb.RateScope_Cluster) + + for _, rt := range internalpb.RateType_value { + if IsDDLRequest(internalpb.RateType(rt)) { + clusterRateLimiters.GetLimiters(). + Insert(internalpb.RateType(rt), ratelimitutil.NewLimiter(ratelimitutil.Limit(5), 1)) + } else { + clusterRateLimiters.GetLimiters(). 
+ Insert(internalpb.RateType(rt), ratelimitutil.NewLimiter(ratelimitutil.Limit(2), 1)) + collectionRateLimiters.GetLimiters(). + Insert(internalpb.RateType(rt), ratelimitutil.NewLimiter(ratelimitutil.Limit(2), 1)) + } + } + + return collectionRateLimiters + }) + } + + for _, rt := range internalpb.RateType_value { + if IsDDLRequest(internalpb.RateType(rt)) { + err := simpleLimiter.Check(0, collectionIDToPartIDs, internalpb.RateType(rt), 1) + assert.NoError(t, err) + err = simpleLimiter.Check(0, collectionIDToPartIDs, internalpb.RateType(rt), 5) + assert.NoError(t, err) + err = simpleLimiter.Check(0, collectionIDToPartIDs, internalpb.RateType(rt), 5) + assert.ErrorIs(t, err, merr.ErrServiceRateLimit) + } else { + err := simpleLimiter.Check(0, collectionIDToPartIDs, internalpb.RateType(rt), 1) + assert.NoError(t, err) + err = simpleLimiter.Check(0, collectionIDToPartIDs, internalpb.RateType(rt), 1) + assert.NoError(t, err) + err = simpleLimiter.Check(0, collectionIDToPartIDs, internalpb.RateType(rt), 1) + assert.ErrorIs(t, err, merr.ErrServiceRateLimit) + } + } + Params.Save(Params.QuotaConfig.QuotaAndLimitsEnabled.Key, bak) + }) + + t.Run("not enable quotaAndLimit", func(t *testing.T) { + simpleLimiter := NewSimpleLimiter() + bak := Params.QuotaConfig.QuotaAndLimitsEnabled.GetValue() + paramtable.Get().Save(Params.QuotaConfig.QuotaAndLimitsEnabled.Key, "false") + for _, rt := range internalpb.RateType_value { + err := simpleLimiter.Check(0, nil, internalpb.RateType(rt), 1) + assert.NoError(t, err) + } + Params.Save(Params.QuotaConfig.QuotaAndLimitsEnabled.Key, bak) + }) + + t.Run("test limit", func(t *testing.T) { + run := func(insertRate float64) { + bakInsertRate := Params.QuotaConfig.DMLMaxInsertRate.GetValue() + paramtable.Get().Save(Params.QuotaConfig.DMLMaxInsertRate.Key, fmt.Sprintf("%f", insertRate)) + simpleLimiter := NewSimpleLimiter() + bak := Params.QuotaConfig.QuotaAndLimitsEnabled.GetValue() + paramtable.Get().Save(Params.QuotaConfig.QuotaAndLimitsEnabled.Key, "true") + err := simpleLimiter.Check(0, nil, internalpb.RateType_DMLInsert, 1*1024*1024) + assert.NoError(t, err) + Params.Save(Params.QuotaConfig.QuotaAndLimitsEnabled.Key, bak) + Params.Save(Params.QuotaConfig.DMLMaxInsertRate.Key, bakInsertRate) + } + run(math.MaxFloat64) + run(math.MaxFloat64 / 1.2) + run(math.MaxFloat64 / 2) + run(math.MaxFloat64 / 3) + run(math.MaxFloat64 / 10000) + }) + + t.Run("test set rates", func(t *testing.T) { + simpleLimiter := NewSimpleLimiter() + zeroRates := getZeroCollectionRates() + + err := simpleLimiter.SetRates(newCollectionLimiterNode(map[int64]*proxypb.LimiterNode{ + 1: { + Limiter: &proxypb.Limiter{ + Rates: zeroRates, + }, + Children: make(map[int64]*proxypb.LimiterNode), + }, + 2: { + Limiter: &proxypb.Limiter{ + Rates: zeroRates, + }, + Children: make(map[int64]*proxypb.LimiterNode), + }, + })) + + assert.NoError(t, err) + }) + + t.Run("test quota states", func(t *testing.T) { + simpleLimiter := NewSimpleLimiter() + err := simpleLimiter.SetRates(newCollectionLimiterNode(map[int64]*proxypb.LimiterNode{ + 1: { + // collection limiter + Limiter: &proxypb.Limiter{ + Rates: getZeroCollectionRates(), + States: []milvuspb.QuotaState{milvuspb.QuotaState_DenyToWrite, milvuspb.QuotaState_DenyToRead}, + Codes: []commonpb.ErrorCode{commonpb.ErrorCode_DiskQuotaExhausted, commonpb.ErrorCode_ForceDeny}, + }, + Children: make(map[int64]*proxypb.LimiterNode), + }, + })) + + assert.NoError(t, err) + + states, codes := simpleLimiter.GetQuotaStates() + assert.Len(t, states, 2) + assert.Len(t, 
codes, 2) + assert.Contains(t, codes, ratelimitutil.GetQuotaErrorString(commonpb.ErrorCode_DiskQuotaExhausted)) + assert.Contains(t, codes, ratelimitutil.GetQuotaErrorString(commonpb.ErrorCode_ForceDeny)) + }) +} + +func getZeroRates() []*internalpb.Rate { + zeroRates := make([]*internalpb.Rate, 0, len(internalpb.RateType_value)) + for _, rt := range internalpb.RateType_value { + zeroRates = append(zeroRates, &internalpb.Rate{ + Rt: internalpb.RateType(rt), R: 0, + }) + } + return zeroRates +} + +func getZeroCollectionRates() []*internalpb.Rate { + collectionRate := []internalpb.RateType{ + internalpb.RateType_DMLInsert, + internalpb.RateType_DMLDelete, + internalpb.RateType_DMLBulkLoad, + internalpb.RateType_DQLSearch, + internalpb.RateType_DQLQuery, + internalpb.RateType_DDLFlush, + } + zeroRates := make([]*internalpb.Rate, 0, len(collectionRate)) + for _, rt := range collectionRate { + zeroRates = append(zeroRates, &internalpb.Rate{ + Rt: rt, R: 0, + }) + } + return zeroRates +} + +func newCollectionLimiterNode(collectionLimiterNodes map[int64]*proxypb.LimiterNode) *proxypb.LimiterNode { + return &proxypb.LimiterNode{ + // cluster limiter + Limiter: &proxypb.Limiter{}, + // db level + Children: map[int64]*proxypb.LimiterNode{ + 0: { + // db limiter + Limiter: &proxypb.Limiter{}, + // collection level + Children: collectionLimiterNodes, + }, + }, + } +} + +func TestRateLimiter(t *testing.T) { + t.Run("test limit", func(t *testing.T) { + simpleLimiter := NewSimpleLimiter() + rootLimiters := simpleLimiter.rateLimiter.GetRootLimiters() + for _, rt := range internalpb.RateType_value { + rootLimiters.GetLimiters().Insert(internalpb.RateType(rt), ratelimitutil.NewLimiter(ratelimitutil.Limit(1000), 1)) + } + for _, rt := range internalpb.RateType_value { + ok, _ := rootLimiters.Limit(internalpb.RateType(rt), 1) + assert.False(t, ok) + ok, _ = rootLimiters.Limit(internalpb.RateType(rt), math.MaxInt) + assert.False(t, ok) + ok, _ = rootLimiters.Limit(internalpb.RateType(rt), math.MaxInt) + assert.True(t, ok) + } + }) + + t.Run("test setRates", func(t *testing.T) { + simpleLimiter := NewSimpleLimiter() + + collectionRateLimiters := simpleLimiter.rateLimiter.GetOrCreateCollectionLimiters(0, int64(1), newDatabaseLimiter, + func() *rlinternal.RateLimiterNode { + collectionRateLimiters := rlinternal.NewRateLimiterNode(internalpb.RateScope_Cluster) + for _, rt := range internalpb.RateType_value { + collectionRateLimiters.GetLimiters().Insert(internalpb.RateType(rt), + ratelimitutil.NewLimiter(ratelimitutil.Limit(1000), 1)) + } + + return collectionRateLimiters + }) + + err := simpleLimiter.SetRates(newCollectionLimiterNode(map[int64]*proxypb.LimiterNode{ + 1: { + // collection limiter + Limiter: &proxypb.Limiter{ + Rates: getZeroRates(), + }, + Children: make(map[int64]*proxypb.LimiterNode), + }, + })) + + assert.NoError(t, err) + + for _, rt := range internalpb.RateType_value { + for i := 0; i < 100; i++ { + ok, _ := collectionRateLimiters.Limit(internalpb.RateType(rt), 1) + assert.True(t, ok) + } + } + + err = simpleLimiter.SetRates(newCollectionLimiterNode(map[int64]*proxypb.LimiterNode{ + 1: { + // collection limiter + Limiter: &proxypb.Limiter{ + States: []milvuspb.QuotaState{milvuspb.QuotaState_DenyToRead, milvuspb.QuotaState_DenyToWrite}, + Codes: []commonpb.ErrorCode{commonpb.ErrorCode_DiskQuotaExhausted, commonpb.ErrorCode_DiskQuotaExhausted}, + }, + Children: make(map[int64]*proxypb.LimiterNode), + }, + })) + + collectionRateLimiter := simpleLimiter.rateLimiter.GetCollectionLimiters(0, 1) + 
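// GetCollectionLimiters returns the node created earlier in this test; SetRates only mutated its rates and quota states in place +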
assert.NotNil(t, collectionRateLimiter) + assert.NoError(t, err) + assert.Equal(t, collectionRateLimiter.GetQuotaStates().Len(), 2) + + err = simpleLimiter.SetRates(newCollectionLimiterNode(map[int64]*proxypb.LimiterNode{ + 1: { + // collection limiter + Limiter: &proxypb.Limiter{ + States: []milvuspb.QuotaState{}, + }, + Children: make(map[int64]*proxypb.LimiterNode), + }, + })) + + assert.NoError(t, err) + assert.Equal(t, collectionRateLimiter.GetQuotaStates().Len(), 0) + }) + + t.Run("test get error code", func(t *testing.T) { + simpleLimiter := NewSimpleLimiter() + + collectionRateLimiters := simpleLimiter.rateLimiter.GetOrCreateCollectionLimiters(0, int64(1), newDatabaseLimiter, + func() *rlinternal.RateLimiterNode { + collectionRateLimiters := rlinternal.NewRateLimiterNode(internalpb.RateScope_Cluster) + for _, rt := range internalpb.RateType_value { + collectionRateLimiters.GetLimiters().Insert(internalpb.RateType(rt), + ratelimitutil.NewLimiter(ratelimitutil.Limit(1000), 1)) + } + + return collectionRateLimiters + }) + + err := simpleLimiter.SetRates(newCollectionLimiterNode(map[int64]*proxypb.LimiterNode{ + 1: { + // collection limiter + Limiter: &proxypb.Limiter{ + Rates: getZeroRates(), + States: []milvuspb.QuotaState{ + milvuspb.QuotaState_DenyToWrite, + milvuspb.QuotaState_DenyToRead, + }, + Codes: []commonpb.ErrorCode{ + commonpb.ErrorCode_DiskQuotaExhausted, + commonpb.ErrorCode_ForceDeny, + }, + }, + Children: make(map[int64]*proxypb.LimiterNode), + }, + })) + + assert.NoError(t, err) + assert.Error(t, collectionRateLimiters.GetQuotaExceededError(internalpb.RateType_DQLQuery)) + assert.Error(t, collectionRateLimiters.GetQuotaExceededError(internalpb.RateType_DMLInsert)) + }) + + t.Run("tests refresh rate by config", func(t *testing.T) { + simpleLimiter := NewSimpleLimiter() + clusterRateLimiter := simpleLimiter.rateLimiter.GetRootLimiters() + etcdCli, _ := etcd.GetEtcdClient( + Params.EtcdCfg.UseEmbedEtcd.GetAsBool(), + Params.EtcdCfg.EtcdUseSSL.GetAsBool(), + Params.EtcdCfg.Endpoints.GetAsStrings(), + Params.EtcdCfg.EtcdTLSCert.GetValue(), + Params.EtcdCfg.EtcdTLSKey.GetValue(), + Params.EtcdCfg.EtcdTLSCACert.GetValue(), + Params.EtcdCfg.EtcdTLSMinVersion.GetValue()) + + Params.Save(Params.QuotaConfig.DDLLimitEnabled.Key, "true") + defer Params.Reset(Params.QuotaConfig.DDLLimitEnabled.Key) + Params.Save(Params.QuotaConfig.DMLLimitEnabled.Key, "true") + defer Params.Reset(Params.QuotaConfig.DMLLimitEnabled.Key) + ctx := context.Background() + // avoid production precision issues when comparing 0-terminated numbers + r := rand.Float64() + newRate := fmt.Sprintf("%.2f", r) + etcdCli.KV.Put(ctx, "by-dev/config/quotaAndLimits/ddl/collectionRate", newRate) + defer etcdCli.KV.Delete(ctx, "by-dev/config/quotaAndLimits/ddl/collectionRate") + etcdCli.KV.Put(ctx, "by-dev/config/quotaAndLimits/ddl/partitionRate", "invalid") + defer etcdCli.KV.Delete(ctx, "by-dev/config/quotaAndLimits/ddl/partitionRate") + etcdCli.KV.Put(ctx, "by-dev/config/quotaAndLimits/dml/insertRate/max", "8") + defer etcdCli.KV.Delete(ctx, "by-dev/config/quotaAndLimits/dml/insertRate/max") + + assert.Eventually(t, func() bool { + limit, _ := clusterRateLimiter.GetLimiters().Get(internalpb.RateType_DDLCollection) + return math.Abs(r-float64(limit.Limit())) < 0.01 + }, 10*time.Second, 1*time.Second) + + limit, _ := clusterRateLimiter.GetLimiters().Get(internalpb.RateType_DDLPartition) + assert.Equal(t, "+inf", limit.Limit().String()) + + limit, _ = clusterRateLimiter.GetLimiters().Get(internalpb.RateType_DMLInsert) 
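+ // the "8" stored for dml insertRate is interpreted as MB/s, so the limiter should now report 8*1024*1024 bytes/s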
+ assert.True(t, math.Abs(8*1024*1024-float64(limit.Limit())) < 0.01) + }) +} diff --git a/internal/querynodev2/collector/collector.go b/internal/querynodev2/collector/collector.go index 797a29d319..66e48ba20e 100644 --- a/internal/querynodev2/collector/collector.go +++ b/internal/querynodev2/collector/collector.go @@ -59,7 +59,7 @@ func ConstructLabel(subs ...string) string { func init() { var err error - Rate, err = ratelimitutil.NewRateCollector(ratelimitutil.DefaultWindow, ratelimitutil.DefaultGranularity) + Rate, err = ratelimitutil.NewRateCollector(ratelimitutil.DefaultWindow, ratelimitutil.DefaultGranularity, false) if err != nil { log.Fatal("failed to initialize querynode rate collector", zap.Error(err)) } diff --git a/internal/rootcoord/broker.go b/internal/rootcoord/broker.go index 51e80f9d0b..264679ee77 100644 --- a/internal/rootcoord/broker.go +++ b/internal/rootcoord/broker.go @@ -234,6 +234,11 @@ func (b *ServerBroker) BroadcastAlteredCollection(ctx context.Context, req *milv return err } + db, err := b.s.meta.GetDatabaseByName(ctx, req.GetDbName(), typeutil.MaxTimestamp) + if err != nil { + return err + } + partitionIDs := make([]int64, len(colMeta.Partitions)) for _, p := range colMeta.Partitions { partitionIDs = append(partitionIDs, p.PartitionID) @@ -249,6 +254,7 @@ func (b *ServerBroker) BroadcastAlteredCollection(ctx context.Context, req *milv PartitionIDs: partitionIDs, StartPositions: colMeta.StartPositions, Properties: req.GetProperties(), + DbID: db.ID, } resp, err := b.s.dataCoord.BroadcastAlteredCollection(ctx, dcReq) diff --git a/internal/rootcoord/broker_test.go b/internal/rootcoord/broker_test.go index eeae4e83e0..5accc2b1b1 100644 --- a/internal/rootcoord/broker_test.go +++ b/internal/rootcoord/broker_test.go @@ -29,6 +29,7 @@ import ( "github.com/milvus-io/milvus/internal/metastore/model" "github.com/milvus-io/milvus/internal/mocks" "github.com/milvus-io/milvus/internal/proto/datapb" + pb "github.com/milvus-io/milvus/internal/proto/etcdpb" "github.com/milvus-io/milvus/internal/proto/indexpb" mockrootcoord "github.com/milvus-io/milvus/internal/rootcoord/mocks" "github.com/milvus-io/milvus/pkg/util/merr" @@ -239,6 +240,7 @@ func TestServerBroker_BroadcastAlteredCollection(t *testing.T) { mock.Anything, mock.Anything, ).Return(collMeta, nil) + mockGetDatabase(meta) c.meta = meta b := newServerBroker(c) ctx := context.Background() @@ -256,6 +258,7 @@ func TestServerBroker_BroadcastAlteredCollection(t *testing.T) { mock.Anything, mock.Anything, ).Return(collMeta, nil) + mockGetDatabase(meta) c.meta = meta b := newServerBroker(c) ctx := context.Background() @@ -273,6 +276,7 @@ func TestServerBroker_BroadcastAlteredCollection(t *testing.T) { mock.Anything, mock.Anything, ).Return(collMeta, nil) + mockGetDatabase(meta) c.meta = meta b := newServerBroker(c) ctx := context.Background() @@ -327,3 +331,11 @@ func TestServerBroker_GcConfirm(t *testing.T) { assert.True(t, broker.GcConfirm(context.Background(), 100, 10000)) }) } + +func mockGetDatabase(meta *mockrootcoord.IMetaTable) { + db := model.NewDatabase(1, "default", pb.DatabaseState_DatabaseCreated) + meta.EXPECT().GetDatabaseByName(mock.Anything, mock.Anything, mock.Anything). + Return(db, nil).Maybe() + meta.EXPECT().GetDatabaseByID(mock.Anything, mock.Anything, mock.Anything). 
+ Return(db, nil).Maybe() +} diff --git a/internal/rootcoord/describe_collection_task.go b/internal/rootcoord/describe_collection_task.go index 1413015c4f..8a9da97a9f 100644 --- a/internal/rootcoord/describe_collection_task.go +++ b/internal/rootcoord/describe_collection_task.go @@ -44,6 +44,7 @@ func (t *describeCollectionTask) Execute(ctx context.Context) (err error) { if err != nil { return err } + aliases := t.core.meta.ListAliasesByID(coll.CollectionID) db, err := t.core.meta.GetDatabaseByID(ctx, coll.DBID, t.GetTs()) if err != nil { diff --git a/internal/rootcoord/describe_db_task.go b/internal/rootcoord/describe_db_task.go new file mode 100644 index 0000000000..ec63cbf448 --- /dev/null +++ b/internal/rootcoord/describe_db_task.go @@ -0,0 +1,56 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package rootcoord + +import ( + "context" + + "github.com/milvus-io/milvus/internal/proto/rootcoordpb" + "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +// describeDBTask describe database request task +type describeDBTask struct { + baseTask + Req *rootcoordpb.DescribeDatabaseRequest + Rsp *rootcoordpb.DescribeDatabaseResponse + allowUnavailable bool +} + +func (t *describeDBTask) Prepare(ctx context.Context) error { + return nil +} + +// Execute task execution +func (t *describeDBTask) Execute(ctx context.Context) (err error) { + db, err := t.core.meta.GetDatabaseByName(ctx, t.Req.GetDbName(), typeutil.MaxTimestamp) + if err != nil { + t.Rsp = &rootcoordpb.DescribeDatabaseResponse{ + Status: merr.Status(err), + } + return err + } + + t.Rsp = &rootcoordpb.DescribeDatabaseResponse{ + Status: merr.Success(), + DbID: db.ID, + DbName: db.Name, + CreatedTimestamp: db.CreatedTime, + } + return nil +} diff --git a/internal/rootcoord/describe_db_task_test.go b/internal/rootcoord/describe_db_task_test.go new file mode 100644 index 0000000000..9d86708d92 --- /dev/null +++ b/internal/rootcoord/describe_db_task_test.go @@ -0,0 +1,88 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +package rootcoord + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus/internal/metastore/model" + "github.com/milvus-io/milvus/internal/proto/rootcoordpb" + mockrootcoord "github.com/milvus-io/milvus/internal/rootcoord/mocks" + "github.com/milvus-io/milvus/pkg/util" +) + +func Test_describeDatabaseTask_Execute(t *testing.T) { + t.Run("failed to get database by name", func(t *testing.T) { + core := newTestCore(withInvalidMeta()) + task := &describeDBTask{ + baseTask: newBaseTask(context.Background(), core), + Req: &rootcoordpb.DescribeDatabaseRequest{ + DbName: "testDB", + }, + } + err := task.Execute(context.Background()) + assert.Error(t, err) + assert.NotNil(t, task.Rsp) + assert.NotNil(t, task.Rsp.Status) + }) + + t.Run("describe with empty database name", func(t *testing.T) { + meta := mockrootcoord.NewIMetaTable(t) + meta.EXPECT().GetDatabaseByName(mock.Anything, mock.Anything, mock.Anything). + Return(model.NewDefaultDatabase(), nil) + core := newTestCore(withMeta(meta)) + + task := &describeDBTask{ + baseTask: newBaseTask(context.Background(), core), + Req: &rootcoordpb.DescribeDatabaseRequest{}, + } + err := task.Execute(context.Background()) + assert.NoError(t, err) + assert.NotNil(t, task.Rsp) + assert.Equal(t, task.Rsp.GetStatus().GetErrorCode(), commonpb.ErrorCode_Success) + assert.Equal(t, util.DefaultDBName, task.Rsp.GetDbName()) + assert.Equal(t, util.DefaultDBID, task.Rsp.GetDbID()) + }) + + t.Run("describe with specified database name", func(t *testing.T) { + meta := mockrootcoord.NewIMetaTable(t) + meta.EXPECT().GetDatabaseByName(mock.Anything, mock.Anything, mock.Anything). 
+ Return(&model.Database{ + Name: "db1", + ID: 100, + CreatedTime: 1, + }, nil) + core := newTestCore(withMeta(meta)) + + task := &describeDBTask{ + baseTask: newBaseTask(context.Background(), core), + Req: &rootcoordpb.DescribeDatabaseRequest{DbName: "db1"}, + } + err := task.Execute(context.Background()) + assert.NoError(t, err) + assert.NotNil(t, task.Rsp) + assert.Equal(t, task.Rsp.GetStatus().GetCode(), int32(commonpb.ErrorCode_Success)) + assert.Equal(t, "db1", task.Rsp.GetDbName()) + assert.Equal(t, int64(100), task.Rsp.GetDbID()) + assert.Equal(t, uint64(1), task.Rsp.GetCreatedTimestamp()) + }) +} diff --git a/internal/rootcoord/meta_table.go b/internal/rootcoord/meta_table.go index b416c2c93b..0ec6489ba6 100644 --- a/internal/rootcoord/meta_table.go +++ b/internal/rootcoord/meta_table.go @@ -55,6 +55,7 @@ type IMetaTable interface { RemoveCollection(ctx context.Context, collectionID UniqueID, ts Timestamp) error GetCollectionByName(ctx context.Context, dbName string, collectionName string, ts Timestamp) (*model.Collection, error) GetCollectionByID(ctx context.Context, dbName string, collectionID UniqueID, ts Timestamp, allowUnavailable bool) (*model.Collection, error) + GetCollectionByIDWithMaxTs(ctx context.Context, collectionID UniqueID) (*model.Collection, error) ListCollections(ctx context.Context, dbName string, ts Timestamp, onlyAvail bool) ([]*model.Collection, error) ListAllAvailCollections(ctx context.Context) map[int64][]int64 ListCollectionPhysicalChannels() map[typeutil.UniqueID][]string @@ -362,7 +363,7 @@ func (mt *MetaTable) getDatabaseByNameInternal(_ context.Context, dbName string, db, ok := mt.dbName2Meta[dbName] if !ok { - return nil, fmt.Errorf("database:%s not found", dbName) + return nil, merr.WrapErrDatabaseNotFound(dbName) } return db, nil @@ -519,12 +520,12 @@ func filterUnavailable(coll *model.Collection) *model.Collection { } // getLatestCollectionByIDInternal should be called with ts = typeutil.MaxTimestamp -func (mt *MetaTable) getLatestCollectionByIDInternal(ctx context.Context, collectionID UniqueID, allowAvailable bool) (*model.Collection, error) { +func (mt *MetaTable) getLatestCollectionByIDInternal(ctx context.Context, collectionID UniqueID, allowUnavailable bool) (*model.Collection, error) { coll, ok := mt.collID2Meta[collectionID] if !ok || coll == nil { return nil, merr.WrapErrCollectionNotFound(collectionID) } - if allowAvailable { + if allowUnavailable { return coll.Clone(), nil } if !coll.Available() { @@ -623,6 +624,11 @@ func (mt *MetaTable) GetCollectionByID(ctx context.Context, dbName string, colle return mt.getCollectionByIDInternal(ctx, dbName, collectionID, ts, allowUnavailable) } +// GetCollectionByIDWithMaxTs get collection, dbName can be ignored if ts is max timestamps +func (mt *MetaTable) GetCollectionByIDWithMaxTs(ctx context.Context, collectionID UniqueID) (*model.Collection, error) { + return mt.GetCollectionByID(ctx, "", collectionID, typeutil.MaxTimestamp, false) +} + func (mt *MetaTable) ListAllAvailCollections(ctx context.Context) map[int64][]int64 { mt.ddLock.RLock() defer mt.ddLock.RUnlock() diff --git a/internal/rootcoord/mock_test.go b/internal/rootcoord/mock_test.go index 4015874285..f17ff9d1e2 100644 --- a/internal/rootcoord/mock_test.go +++ b/internal/rootcoord/mock_test.go @@ -95,6 +95,11 @@ type mockMetaTable struct { DropGrantFunc func(tenant string, role *milvuspb.RoleEntity) error ListPolicyFunc func(tenant string) ([]string, error) ListUserRoleFunc func(tenant string) ([]string, error) + DescribeDatabaseFunc 
func(ctx context.Context, dbName string) (*model.Database, error) +} + +func (m mockMetaTable) GetDatabaseByName(ctx context.Context, dbName string, ts Timestamp) (*model.Database, error) { + return m.DescribeDatabaseFunc(ctx, dbName) } func (m mockMetaTable) ListDatabases(ctx context.Context, ts typeutil.Timestamp) ([]*model.Database, error) { @@ -516,6 +521,9 @@ func withInvalidMeta() Opt { meta.ListAliasesFunc = func(ctx context.Context, dbName, collectionName string, ts Timestamp) ([]string, error) { return nil, errors.New("error mock ListAliases") } + meta.DescribeDatabaseFunc = func(ctx context.Context, dbName string) (*model.Database, error) { + return nil, errors.New("error mock DescribeDatabase") + } return withMeta(meta) } diff --git a/internal/rootcoord/mocks/meta_table.go b/internal/rootcoord/mocks/meta_table.go index 5e87bcfb0b..26ee021238 100644 --- a/internal/rootcoord/mocks/meta_table.go +++ b/internal/rootcoord/mocks/meta_table.go @@ -843,6 +843,61 @@ func (_c *IMetaTable_GetCollectionByID_Call) RunAndReturn(run func(context.Conte return _c } +// GetCollectionByIDWithMaxTs provides a mock function with given fields: ctx, collectionID +func (_m *IMetaTable) GetCollectionByIDWithMaxTs(ctx context.Context, collectionID int64) (*model.Collection, error) { + ret := _m.Called(ctx, collectionID) + + var r0 *model.Collection + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, int64) (*model.Collection, error)); ok { + return rf(ctx, collectionID) + } + if rf, ok := ret.Get(0).(func(context.Context, int64) *model.Collection); ok { + r0 = rf(ctx, collectionID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*model.Collection) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, int64) error); ok { + r1 = rf(ctx, collectionID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// IMetaTable_GetCollectionByIDWithMaxTs_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetCollectionByIDWithMaxTs' +type IMetaTable_GetCollectionByIDWithMaxTs_Call struct { + *mock.Call +} + +// GetCollectionByIDWithMaxTs is a helper method to define mock.On call +// - ctx context.Context +// - collectionID int64 +func (_e *IMetaTable_Expecter) GetCollectionByIDWithMaxTs(ctx interface{}, collectionID interface{}) *IMetaTable_GetCollectionByIDWithMaxTs_Call { + return &IMetaTable_GetCollectionByIDWithMaxTs_Call{Call: _e.mock.On("GetCollectionByIDWithMaxTs", ctx, collectionID)} +} + +func (_c *IMetaTable_GetCollectionByIDWithMaxTs_Call) Run(run func(ctx context.Context, collectionID int64)) *IMetaTable_GetCollectionByIDWithMaxTs_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(int64)) + }) + return _c +} + +func (_c *IMetaTable_GetCollectionByIDWithMaxTs_Call) Return(_a0 *model.Collection, _a1 error) *IMetaTable_GetCollectionByIDWithMaxTs_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *IMetaTable_GetCollectionByIDWithMaxTs_Call) RunAndReturn(run func(context.Context, int64) (*model.Collection, error)) *IMetaTable_GetCollectionByIDWithMaxTs_Call { + _c.Call.Return(run) + return _c +} + // GetCollectionByName provides a mock function with given fields: ctx, dbName, collectionName, ts func (_m *IMetaTable) GetCollectionByName(ctx context.Context, dbName string, collectionName string, ts uint64) (*model.Collection, error) { ret := _m.Called(ctx, dbName, collectionName, ts) diff --git a/internal/rootcoord/quota_center.go b/internal/rootcoord/quota_center.go index 15b48cbc92..a0a1bbc24d 
100644 --- a/internal/rootcoord/quota_center.go +++ b/internal/rootcoord/quota_center.go @@ -21,6 +21,7 @@ import ( "fmt" "math" "strconv" + "strings" "sync" "time" @@ -36,6 +37,8 @@ import ( "github.com/milvus-io/milvus/internal/tso" "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/util/proxyutil" + "github.com/milvus-io/milvus/internal/util/quota" + rlinternal "github.com/milvus-io/milvus/internal/util/ratelimitutil" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" @@ -65,9 +68,44 @@ const Inf = ratelimitutil.Inf type Limit = ratelimitutil.Limit -type collectionRates = map[internalpb.RateType]Limit +func GetInfLimiter(_ internalpb.RateType) *ratelimitutil.Limiter { + // It indicates an infinite limiter with burst is 0 + return ratelimitutil.NewLimiter(Inf, 0) +} -type collectionStates = map[milvuspb.QuotaState]commonpb.ErrorCode +func GetEarliestLimiter() *ratelimitutil.Limiter { + // It indicates an earliest limiter with burst is 0 + return ratelimitutil.NewLimiter(0, 0) +} + +type opType int + +const ( + ddl opType = iota + dml + dql + allOps +) + +var ddlRateTypes = typeutil.NewSet( + internalpb.RateType_DDLCollection, + internalpb.RateType_DDLPartition, + internalpb.RateType_DDLIndex, + internalpb.RateType_DDLFlush, + internalpb.RateType_DDLCompaction, +) + +var dmlRateTypes = typeutil.NewSet( + internalpb.RateType_DMLInsert, + internalpb.RateType_DMLUpsert, + internalpb.RateType_DMLDelete, + internalpb.RateType_DMLBulkLoad, +) + +var dqlRateTypes = typeutil.NewSet( + internalpb.RateType_DQLSearch, + internalpb.RateType_DQLQuery, +) // QuotaCenter manages the quota and limitations of the whole cluster, // it receives metrics info from DataNodes, QueryNodes and Proxies, and @@ -105,11 +143,17 @@ type QuotaCenter struct { dataCoordMetrics *metricsinfo.DataCoordQuotaMetrics totalBinlogSize int64 - readableCollections []int64 - writableCollections []int64 + readableCollections map[int64]map[int64][]int64 // db id -> collection id -> partition id + writableCollections map[int64]map[int64][]int64 // db id -> collection id -> partition id + dbs *typeutil.ConcurrentMap[string, int64] // db name -> db id + collections *typeutil.ConcurrentMap[string, int64] // db id + collection name -> collection id + + // this is a transitional data structure to cache db id for each collection. + // TODO many metrics information only have collection id currently, it can be removed after db id add into all metrics. + collectionIDToDBID *typeutil.ConcurrentMap[int64, int64] // collection id -> db id + + rateLimiter *rlinternal.RateLimiterTree - currentRates map[int64]collectionRates - quotaStates map[int64]collectionStates tsoAllocator tso.Allocator rateAllocateStrategy RateAllocateStrategy @@ -120,25 +164,108 @@ type QuotaCenter struct { } // NewQuotaCenter returns a new QuotaCenter. 
-func NewQuotaCenter(proxies proxyutil.ProxyClientManagerInterface, queryCoord types.QueryCoordClient, dataCoord types.DataCoordClient, tsoAllocator tso.Allocator, meta IMetaTable) *QuotaCenter { +func NewQuotaCenter(proxies proxyutil.ProxyClientManagerInterface, queryCoord types.QueryCoordClient, + dataCoord types.DataCoordClient, tsoAllocator tso.Allocator, meta IMetaTable, +) *QuotaCenter { ctx, cancel := context.WithCancel(context.TODO()) - return &QuotaCenter{ - ctx: ctx, - cancel: cancel, - proxies: proxies, - queryCoord: queryCoord, - dataCoord: dataCoord, - currentRates: make(map[int64]map[internalpb.RateType]Limit), - quotaStates: make(map[int64]map[milvuspb.QuotaState]commonpb.ErrorCode), - tsoAllocator: tsoAllocator, - meta: meta, - readableCollections: make([]int64, 0), - writableCollections: make([]int64, 0), - + q := &QuotaCenter{ + ctx: ctx, + cancel: cancel, + proxies: proxies, + queryCoord: queryCoord, + dataCoord: dataCoord, + tsoAllocator: tsoAllocator, + meta: meta, + readableCollections: make(map[int64]map[int64][]int64, 0), + writableCollections: make(map[int64]map[int64][]int64, 0), + rateLimiter: rlinternal.NewRateLimiterTree(initInfLimiter(internalpb.RateScope_Cluster, allOps)), rateAllocateStrategy: DefaultRateAllocateStrategy, stopChan: make(chan struct{}), } + q.clearMetrics() + return q +} + +func initInfLimiter(rateScope internalpb.RateScope, opType opType) *rlinternal.RateLimiterNode { + return initLimiter(GetInfLimiter, rateScope, opType) +} + +func newParamLimiterFunc(rateScope internalpb.RateScope, opType opType) func() *rlinternal.RateLimiterNode { + return func() *rlinternal.RateLimiterNode { + return initLimiter(func(rt internalpb.RateType) *ratelimitutil.Limiter { + limitVal := quota.GetQuotaValue(rateScope, rt, Params) + return ratelimitutil.NewLimiter(Limit(limitVal), 0) + }, rateScope, opType) + } +} + +func newParamLimiterFuncWithLimitFunc(rateScope internalpb.RateScope, + opType opType, + limitFunc func(internalpb.RateType) Limit, +) func() *rlinternal.RateLimiterNode { + return func() *rlinternal.RateLimiterNode { + return initLimiter(func(rt internalpb.RateType) *ratelimitutil.Limiter { + limitVal := limitFunc(rt) + return ratelimitutil.NewLimiter(limitVal, 0) + }, rateScope, opType) + } +} + +func initLimiter(limiterFunc func(internalpb.RateType) *ratelimitutil.Limiter, rateScope internalpb.RateScope, opType opType) *rlinternal.RateLimiterNode { + rateLimiters := rlinternal.NewRateLimiterNode(rateScope) + getRateTypes(rateScope, opType).Range(func(rt internalpb.RateType) bool { + rateLimiters.GetLimiters().GetOrInsert(rt, limiterFunc(rt)) + return true + }) + return rateLimiters +} + +func updateLimiter(node *rlinternal.RateLimiterNode, limiter *ratelimitutil.Limiter, rateScope internalpb.RateScope, opType opType) { + if node == nil { + log.Warn("update limiter failed, node is nil", zap.Any("rateScope", rateScope), zap.Any("opType", opType)) + return + } + limiters := node.GetLimiters() + getRateTypes(rateScope, opType).Range(func(rt internalpb.RateType) bool { + originLimiter, ok := limiters.Get(rt) + if !ok { + log.Warn("update limiter failed, limiter not found", + zap.Any("rateScope", rateScope), + zap.Any("opType", opType), + zap.Any("rateType", rt)) + return true + } + originLimiter.SetLimit(limiter.Limit()) + return true + }) +} + +func getRateTypes(scope internalpb.RateScope, opType opType) typeutil.Set[internalpb.RateType] { + var allRateTypes typeutil.Set[internalpb.RateType] + switch scope { + case internalpb.RateScope_Cluster: + 
fallthrough + case internalpb.RateScope_Database: + allRateTypes = ddlRateTypes.Union(dmlRateTypes).Union(dqlRateTypes) + case internalpb.RateScope_Collection: + allRateTypes = typeutil.NewSet(internalpb.RateType_DDLFlush).Union(dmlRateTypes).Union(dqlRateTypes) + case internalpb.RateScope_Partition: + allRateTypes = dmlRateTypes.Union(dqlRateTypes) + default: + panic("Unknown rate scope:" + scope.String()) + } + + switch opType { + case ddl: + return ddlRateTypes.Intersection(allRateTypes) + case dml: + return dmlRateTypes.Intersection(allRateTypes) + case dql: + return dqlRateTypes.Intersection(allRateTypes) + default: + return allRateTypes + } } func (q *QuotaCenter) Start() { @@ -160,9 +287,9 @@ func (q *QuotaCenter) run() { log.Info("QuotaCenter exit") return case <-ticker.C: - err := q.syncMetrics() + err := q.collectMetrics() if err != nil { - log.Warn("quotaCenter sync metrics failed", zap.Error(err)) + log.Warn("quotaCenter collect metrics failed", zap.Error(err)) break } err = q.calculateRates() @@ -170,9 +297,9 @@ func (q *QuotaCenter) run() { log.Warn("quotaCenter calculate rates failed", zap.Error(err)) break } - err = q.setRates() + err = q.sendRatesToProxy() if err != nil { - log.Warn("quotaCenter setRates failed", zap.Error(err)) + log.Warn("quotaCenter send rates to proxy failed", zap.Error(err)) } q.recordMetrics() } @@ -195,6 +322,9 @@ func (q *QuotaCenter) clearMetrics() { q.dataNodeMetrics = make(map[UniqueID]*metricsinfo.DataNodeQuotaMetrics, 0) q.queryNodeMetrics = make(map[UniqueID]*metricsinfo.QueryNodeQuotaMetrics, 0) q.proxyMetrics = make(map[UniqueID]*metricsinfo.ProxyQuotaMetrics, 0) + q.collectionIDToDBID = typeutil.NewConcurrentMap[int64, int64]() + q.collections = typeutil.NewConcurrentMap[string, int64]() + q.dbs = typeutil.NewConcurrentMap[string, int64]() } func updateNumEntitiesLoaded(current map[int64]int64, qn *metricsinfo.QueryNodeCollectionMetrics) map[int64]int64 { @@ -204,57 +334,25 @@ func updateNumEntitiesLoaded(current map[int64]int64, qn *metricsinfo.QueryNodeC return current } -func (q *QuotaCenter) reportNumEntitiesLoaded(numEntitiesLoaded map[int64]int64) { - for collectionID, num := range numEntitiesLoaded { - info, err := q.meta.GetCollectionByID(context.Background(), "", collectionID, typeutil.MaxTimestamp, false) - if err != nil { - log.Warn("failed to get collection info by its id, ignore to report loaded num entities", - zap.Int64("collection", collectionID), - zap.Int64("num_entities_loaded", num), - zap.Error(err), - ) - continue - } - metrics.RootCoordNumEntities.WithLabelValues(info.Name, metrics.LoadedLabel).Set(float64(num)) - } +func FormatCollectionKey(dbID int64, collectionName string) string { + return fmt.Sprintf("%d.%s", dbID, collectionName) } -func (q *QuotaCenter) reportDataCoordCollectionMetrics(dc *metricsinfo.DataCoordCollectionMetrics) { - for collectionID, collection := range dc.Collections { - info, err := q.meta.GetCollectionByID(context.Background(), "", collectionID, typeutil.MaxTimestamp, false) - if err != nil { - log.Warn("failed to get collection info by its id, ignore to report total_num_entities/indexed_entities", - zap.Int64("collection", collectionID), - zap.Int64("num_entities_total", collection.NumEntitiesTotal), - zap.Int("lenOfIndexedInfo", len(collection.IndexInfo)), - zap.Error(err), - ) - continue - } - metrics.RootCoordNumEntities.WithLabelValues(info.Name, metrics.TotalLabel).Set(float64(collection.NumEntitiesTotal)) - fields := lo.KeyBy(info.Fields, func(v *model.Field) int64 { return v.FieldID 
}) - for _, indexInfo := range collection.IndexInfo { - if _, ok := fields[indexInfo.FieldID]; !ok { - log.Warn("field id not found, ignore to report indexed num entities", - zap.Int64("collection", collectionID), - zap.Int64("field", indexInfo.FieldID), - ) - continue - } - field := fields[indexInfo.FieldID] - metrics.RootCoordIndexedNumEntities.WithLabelValues( - info.Name, - indexInfo.IndexName, - strconv.FormatBool(typeutil.IsVectorType(field.DataType))).Set(float64(indexInfo.NumEntitiesIndexed)) - } +func SplitCollectionKey(key string) (dbID int64, collectionName string) { + splits := strings.Split(key, ".") + if len(splits) == 2 { + dbID, _ = strconv.ParseInt(splits[0], 10, 64) + collectionName = splits[1] } + return } -// syncMetrics sends GetMetrics requests to DataCoord and QueryCoord to sync the metrics in DataNodes and QueryNodes. -func (q *QuotaCenter) syncMetrics() error { +// collectMetrics sends GetMetrics requests to DataCoord and QueryCoord to sync the metrics in DataNodes and QueryNodes. +func (q *QuotaCenter) collectMetrics() error { oldDataNodes := typeutil.NewSet(lo.Keys(q.dataNodeMetrics)...) oldQueryNodes := typeutil.NewSet(lo.Keys(q.queryNodeMetrics)...) q.clearMetrics() + ctx, cancel := context.WithTimeout(q.ctx, GetMetricsTimeout) defer cancel() @@ -264,12 +362,10 @@ func (q *QuotaCenter) syncMetrics() error { return err } - numEntitiesLoaded := make(map[int64]int64) - // get Query cluster metrics group.Go(func() error { rsp, err := q.queryCoord.GetMetrics(ctx, req) - if err := merr.CheckRPCCall(rsp, err); err != nil { + if err = merr.CheckRPCCall(rsp, err); err != nil { return err } queryCoordTopology := &metricsinfo.QueryCoordTopology{} @@ -279,6 +375,7 @@ func (q *QuotaCenter) syncMetrics() error { } collections := typeutil.NewUniqueSet() + numEntitiesLoaded := make(map[int64]int64) for _, queryNodeMetric := range queryCoordTopology.Cluster.ConnectedNodes { if queryNodeMetric.QuotaMetrics != nil { oldQueryNodes.Remove(queryNodeMetric.ID) @@ -289,14 +386,36 @@ func (q *QuotaCenter) syncMetrics() error { numEntitiesLoaded = updateNumEntitiesLoaded(numEntitiesLoaded, queryNodeMetric.CollectionMetrics) } } - q.readableCollections = collections.Collect() - q.reportNumEntitiesLoaded(numEntitiesLoaded) - return nil + + q.readableCollections = make(map[int64]map[int64][]int64, 0) + var rangeErr error + collections.Range(func(collectionID int64) bool { + coll, getErr := q.meta.GetCollectionByIDWithMaxTs(context.TODO(), collectionID) + if getErr != nil { + rangeErr = getErr + return false + } + collIDToPartIDs, ok := q.readableCollections[coll.DBID] + if !ok { + collIDToPartIDs = make(map[int64][]int64) + q.readableCollections[coll.DBID] = collIDToPartIDs + } + collIDToPartIDs[collectionID] = append(collIDToPartIDs[collectionID], + lo.Map(coll.Partitions, func(part *model.Partition, _ int) int64 { return part.PartitionID })...) 
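+ // cache collection->db ownership and the db-scoped collection name so later quota calculations can resolve IDs without another meta lookup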
+ q.collectionIDToDBID.Insert(collectionID, coll.DBID) + q.collections.Insert(FormatCollectionKey(coll.DBID, coll.Name), collectionID) + if numEntity, ok := numEntitiesLoaded[collectionID]; ok { + metrics.RootCoordNumEntities.WithLabelValues(coll.Name, metrics.LoadedLabel).Set(float64(numEntity)) + } + return true + }) + + return rangeErr }) // get Data cluster metrics group.Go(func() error { rsp, err := q.dataCoord.GetMetrics(ctx, req) - if err := merr.CheckRPCCall(rsp, err); err != nil { + if err = merr.CheckRPCCall(rsp, err); err != nil { return err } dataCoordTopology := &metricsinfo.DataCoordTopology{} @@ -305,10 +424,6 @@ func (q *QuotaCenter) syncMetrics() error { return err } - if dataCoordTopology.Cluster.Self.CollectionMetrics != nil { - q.reportDataCoordCollectionMetrics(dataCoordTopology.Cluster.Self.CollectionMetrics) - } - collections := typeutil.NewUniqueSet() for _, dataNodeMetric := range dataCoordTopology.Cluster.ConnectedDataNodes { if dataNodeMetric.QuotaMetrics != nil { @@ -317,13 +432,61 @@ func (q *QuotaCenter) syncMetrics() error { collections.Insert(dataNodeMetric.QuotaMetrics.Effect.CollectionIDs...) } } - q.writableCollections = collections.Collect() + q.diskMu.Lock() if dataCoordTopology.Cluster.Self.QuotaMetrics != nil { q.dataCoordMetrics = dataCoordTopology.Cluster.Self.QuotaMetrics } q.diskMu.Unlock() - return nil + + q.writableCollections = make(map[int64]map[int64][]int64, 0) + var collectionMetrics map[int64]*metricsinfo.DataCoordCollectionInfo + cm := dataCoordTopology.Cluster.Self.CollectionMetrics + if cm != nil { + collectionMetrics = cm.Collections + } + var rangeErr error + collections.Range(func(collectionID int64) bool { + var coll *model.Collection + coll, getErr := q.meta.GetCollectionByIDWithMaxTs(context.TODO(), collectionID) + if getErr != nil { + rangeErr = getErr + return false + } + + collIDToPartIDs, ok := q.writableCollections[coll.DBID] + if !ok { + collIDToPartIDs = make(map[int64][]int64) + q.writableCollections[coll.DBID] = collIDToPartIDs + } + collIDToPartIDs[collectionID] = append(collIDToPartIDs[collectionID], + lo.Map(coll.Partitions, func(part *model.Partition, _ int) int64 { return part.PartitionID })...) 
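+ // same id/name bookkeeping as the readable path; per-collection entity and index metrics are reported below from this meta snapshot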
+ q.collectionIDToDBID.Insert(collectionID, coll.DBID) + q.collections.Insert(FormatCollectionKey(coll.DBID, coll.Name), collectionID) + if collectionMetrics == nil { + return true + } + if datacoordCollectionMetric, ok := collectionMetrics[collectionID]; ok { + metrics.RootCoordNumEntities.WithLabelValues(coll.Name, metrics.TotalLabel).Set(float64(datacoordCollectionMetric.NumEntitiesTotal)) + fields := lo.KeyBy(coll.Fields, func(v *model.Field) int64 { return v.FieldID }) + for _, indexInfo := range datacoordCollectionMetric.IndexInfo { + if _, ok := fields[indexInfo.FieldID]; !ok { + log.Warn("field id not found, ignore to report indexed num entities", + zap.Int64("collection", collectionID), + zap.Int64("field", indexInfo.FieldID), + ) + continue + } + field := fields[indexInfo.FieldID] + metrics.RootCoordIndexedNumEntities.WithLabelValues( + coll.Name, + indexInfo.IndexName, + strconv.FormatBool(typeutil.IsVectorType(field.DataType))).Set(float64(indexInfo.NumEntitiesIndexed)) + } + } + return true + }) + return rangeErr }) // get Proxies metrics group.Go(func() error { @@ -344,6 +507,16 @@ func (q *QuotaCenter) syncMetrics() error { } return nil }) + group.Go(func() error { + dbs, err := q.meta.ListDatabases(ctx, typeutil.MaxTimestamp) + if err != nil { + return err + } + for _, db := range dbs { + q.dbs.Insert(db.Name, db.ID) + } + return nil + }) err = group.Wait() if err != nil { return err @@ -359,54 +532,95 @@ func (q *QuotaCenter) syncMetrics() error { } // forceDenyWriting sets dml rates to 0 to reject all dml requests. -func (q *QuotaCenter) forceDenyWriting(errorCode commonpb.ErrorCode, collections ...int64) { - if len(collections) == 0 && len(q.writableCollections) != 0 { - // default to all writable collections - collections = q.writableCollections +func (q *QuotaCenter) forceDenyWriting(errorCode commonpb.ErrorCode, cluster bool, dbIDs, collectionIDs []int64, col2partitionIDs map[int64][]int64) error { + if cluster { + clusterLimiters := q.rateLimiter.GetRootLimiters() + updateLimiter(clusterLimiters, GetEarliestLimiter(), internalpb.RateScope_Cluster, dml) + clusterLimiters.GetQuotaStates().Insert(milvuspb.QuotaState_DenyToWrite, errorCode) } - for _, collection := range collections { - if _, ok := q.currentRates[collection]; !ok { - q.currentRates[collection] = make(map[internalpb.RateType]Limit) - q.quotaStates[collection] = make(map[milvuspb.QuotaState]commonpb.ErrorCode) + + for _, dbID := range dbIDs { + dbLimiters := q.rateLimiter.GetDatabaseLimiters(dbID) + if dbLimiters == nil { + log.Warn("db limiter not found of db ID", zap.Int64("dbID", dbID)) + return fmt.Errorf("db limiter not found of db ID: %d", dbID) } - q.currentRates[collection][internalpb.RateType_DMLInsert] = 0 - q.currentRates[collection][internalpb.RateType_DMLUpsert] = 0 - q.currentRates[collection][internalpb.RateType_DMLDelete] = 0 - q.currentRates[collection][internalpb.RateType_DMLBulkLoad] = 0 - q.quotaStates[collection][milvuspb.QuotaState_DenyToWrite] = errorCode + updateLimiter(dbLimiters, GetEarliestLimiter(), internalpb.RateScope_Database, dml) + dbLimiters.GetQuotaStates().Insert(milvuspb.QuotaState_DenyToWrite, errorCode) } - log.RatedWarn(10, "QuotaCenter force to deny writing", - zap.Int64s("collectionIDs", collections), - zap.String("reason", errorCode.String())) + + for _, collectionID := range collectionIDs { + dbID, ok := q.collectionIDToDBID.Get(collectionID) + if !ok { + return fmt.Errorf("db ID not found of collection ID: %d", collectionID) + } + collectionLimiter := 
q.rateLimiter.GetCollectionLimiters(dbID, collectionID) + if collectionLimiter == nil { + log.Warn("collection limiter not found of collection ID", + zap.Int64("dbID", dbID), + zap.Int64("collectionID", collectionID)) + return fmt.Errorf("collection limiter not found of collection ID: %d", collectionID) + } + updateLimiter(collectionLimiter, GetEarliestLimiter(), internalpb.RateScope_Collection, dml) + collectionLimiter.GetQuotaStates().Insert(milvuspb.QuotaState_DenyToWrite, errorCode) + } + + for collectionID, partitionIDs := range col2partitionIDs { + for _, partitionID := range partitionIDs { + dbID, ok := q.collectionIDToDBID.Get(collectionID) + if !ok { + return fmt.Errorf("db ID not found of collection ID: %d", collectionID) + } + partitionLimiter := q.rateLimiter.GetPartitionLimiters(dbID, collectionID, partitionID) + if partitionLimiter == nil { + log.Warn("partition limiter not found of partition ID", + zap.Int64("dbID", dbID), + zap.Int64("collectionID", collectionID), + zap.Int64("partitionID", partitionID)) + return fmt.Errorf("partition limiter not found of partition ID: %d", partitionID) + } + updateLimiter(partitionLimiter, GetEarliestLimiter(), internalpb.RateScope_Partition, dml) + partitionLimiter.GetQuotaStates().Insert(milvuspb.QuotaState_DenyToWrite, errorCode) + } + } + + if cluster || len(dbIDs) > 0 || len(collectionIDs) > 0 || len(col2partitionIDs) > 0 { + log.RatedWarn(10, "QuotaCenter force to deny writing", + zap.Bool("cluster", cluster), + zap.Int64s("dbIDs", dbIDs), + zap.Int64s("collectionIDs", collectionIDs), + zap.Any("partitionIDs", col2partitionIDs), + zap.String("reason", errorCode.String())) + } + + return nil } // forceDenyReading sets dql rates to 0 to reject all dql requests. -func (q *QuotaCenter) forceDenyReading(errorCode commonpb.ErrorCode, collections ...int64) { - if len(collections) == 0 { - // default to all readable collections - collections = q.readableCollections - } - for _, collection := range collections { - if _, ok := q.currentRates[collection]; !ok { - q.currentRates[collection] = make(map[internalpb.RateType]Limit) - q.quotaStates[collection] = make(map[milvuspb.QuotaState]commonpb.ErrorCode) +func (q *QuotaCenter) forceDenyReading(errorCode commonpb.ErrorCode) { + var collectionIDs []int64 + for dbID, collectionIDToPartIDs := range q.readableCollections { + for collectionID := range collectionIDToPartIDs { + collectionLimiter := q.rateLimiter.GetCollectionLimiters(dbID, collectionID) + updateLimiter(collectionLimiter, GetEarliestLimiter(), internalpb.RateScope_Collection, dql) + collectionLimiter.GetQuotaStates().Insert(milvuspb.QuotaState_DenyToRead, errorCode) + collectionIDs = append(collectionIDs, collectionID) } - q.currentRates[collection][internalpb.RateType_DQLSearch] = 0 - q.currentRates[collection][internalpb.RateType_DQLQuery] = 0 - q.quotaStates[collection][milvuspb.QuotaState_DenyToRead] = errorCode } + log.Warn("QuotaCenter force to deny reading", - zap.Int64s("collectionIDs", collections), + zap.Int64s("collectionIDs", collectionIDs), zap.String("reason", errorCode.String())) } // getRealTimeRate return real time rate in Proxy. 
-func (q *QuotaCenter) getRealTimeRate(rateType internalpb.RateType) float64 { +func (q *QuotaCenter) getRealTimeRate(label string) float64 { var rate float64 for _, metric := range q.proxyMetrics { for _, r := range metric.Rms { - if r.Label == rateType.String() { + if r.Label == label { rate += r.Rate + break } } } @@ -414,24 +628,33 @@ func (q *QuotaCenter) getRealTimeRate(rateType internalpb.RateType) float64 { } // guaranteeMinRate make sure the rate will not be less than the min rate. -func (q *QuotaCenter) guaranteeMinRate(minRate float64, rateType internalpb.RateType, collections ...int64) { - for _, collection := range collections { - if minRate > 0 && q.currentRates[collection][rateType] < Limit(minRate) { - q.currentRates[collection][rateType] = Limit(minRate) - } +func (q *QuotaCenter) guaranteeMinRate(minRate float64, rt internalpb.RateType, rln *rlinternal.RateLimiterNode) { + v, ok := rln.GetLimiters().Get(rt) + if ok && minRate > 0 && v.Limit() < Limit(minRate) { + v.SetLimit(Limit(minRate)) } } // calculateReadRates calculates and sets dql rates. -func (q *QuotaCenter) calculateReadRates() { +func (q *QuotaCenter) calculateReadRates() error { log := log.Ctx(context.Background()).WithRateGroup("rootcoord.QuotaCenter", 1.0, 60.0) if Params.QuotaConfig.ForceDenyReading.GetAsBool() { q.forceDenyReading(commonpb.ErrorCode_ForceDeny) - return + return nil } limitCollectionSet := typeutil.NewUniqueSet() - enableQueueProtection := Params.QuotaConfig.QueueProtectionEnabled.GetAsBool() + limitDBNameSet := typeutil.NewSet[string]() + limitCollectionNameSet := typeutil.NewSet[string]() + clusterLimit := false + + formatCollctionRateKey := func(dbName, collectionName string) string { + return fmt.Sprintf("%s.%s", dbName, collectionName) + } + splitCollctionRateKey := func(key string) (string, string) { + parts := strings.Split(key, ".") + return parts[0], parts[1] + } // query latency queueLatencyThreshold := Params.QuotaConfig.QueueLatencyThreshold.GetAsDuration(time.Second) // enableQueueProtection && queueLatencyThreshold >= 0 means enable queue latency protection @@ -446,6 +669,7 @@ func (q *QuotaCenter) calculateReadRates() { } // queue length + enableQueueProtection := Params.QuotaConfig.QueueProtectionEnabled.GetAsBool() nqInQueueThreshold := Params.QuotaConfig.NQInQueueThreshold.GetAsInt64() if enableQueueProtection && nqInQueueThreshold >= 0 { // >= 0 means enable queue length protection @@ -465,56 +689,190 @@ func (q *QuotaCenter) calculateReadRates() { enableResultProtection := Params.QuotaConfig.ResultProtectionEnabled.GetAsBool() if enableResultProtection { maxRate := Params.QuotaConfig.MaxReadResultRate.GetAsFloat() + maxDBRate := Params.QuotaConfig.MaxReadResultRatePerDB.GetAsFloat() + maxCollectionRate := Params.QuotaConfig.MaxReadResultRatePerCollection.GetAsFloat() + rateCount := float64(0) + dbRateCount := make(map[string]float64) + collectionRateCount := make(map[string]float64) for _, metric := range q.proxyMetrics { for _, rm := range metric.Rms { if rm.Label == metricsinfo.ReadResultThroughput { rateCount += rm.Rate + continue + } + dbName, ok := ratelimitutil.GetDBFromSubLabel(metricsinfo.ReadResultThroughput, rm.Label) + if ok { + dbRateCount[dbName] += rm.Rate + continue + } + dbName, collectionName, ok := ratelimitutil.GetCollectionFromSubLabel(metricsinfo.ReadResultThroughput, rm.Label) + if ok { + collectionRateCount[formatCollctionRateKey(dbName, collectionName)] += rm.Rate + continue } } } if rateCount >= maxRate { - 
limitCollectionSet.Insert(q.readableCollections...) + clusterLimit = true + } + for s, f := range dbRateCount { + if f >= maxDBRate { + limitDBNameSet.Insert(s) + } + } + for s, f := range collectionRateCount { + if f >= maxCollectionRate { + limitCollectionNameSet.Insert(s) + } } } + dbIDs := make(map[int64]string, q.dbs.Len()) + collectionIDs := make(map[int64]string, q.collections.Len()) + q.dbs.Range(func(name string, id int64) bool { + dbIDs[id] = name + return true + }) + q.collections.Range(func(name string, id int64) bool { + _, collectionName := SplitCollectionKey(name) + collectionIDs[id] = collectionName + return true + }) + coolOffSpeed := Params.QuotaConfig.CoolOffSpeed.GetAsFloat() - coolOff := func(realTimeSearchRate float64, realTimeQueryRate float64, collections ...int64) { + coolOffCollectionID := func(collections ...int64) error { for _, collection := range collections { - if q.currentRates[collection][internalpb.RateType_DQLSearch] != Inf && realTimeSearchRate > 0 { - q.currentRates[collection][internalpb.RateType_DQLSearch] = Limit(realTimeSearchRate * coolOffSpeed) - log.RatedWarn(10, "QuotaCenter cool read rates off done", - zap.Int64("collectionID", collection), - zap.Any("searchRate", q.currentRates[collection][internalpb.RateType_DQLSearch])) + dbID, ok := q.collectionIDToDBID.Get(collection) + if !ok { + return fmt.Errorf("db ID not found of collection ID: %d", collection) } - if q.currentRates[collection][internalpb.RateType_DQLQuery] != Inf && realTimeQueryRate > 0 { - q.currentRates[collection][internalpb.RateType_DQLQuery] = Limit(realTimeQueryRate * coolOffSpeed) - log.RatedWarn(10, "QuotaCenter cool read rates off done", - zap.Int64("collectionID", collection), - zap.Any("queryRate", q.currentRates[collection][internalpb.RateType_DQLQuery])) + collectionLimiter := q.rateLimiter.GetCollectionLimiters(dbID, collection) + if collectionLimiter == nil { + return fmt.Errorf("collection limiter not found: %d", collection) } + dbName, ok := dbIDs[dbID] + if !ok { + return fmt.Errorf("db name not found of db ID: %d", dbID) + } + collectionName, ok := collectionIDs[collection] + if !ok { + return fmt.Errorf("collection name not found of collection ID: %d", collection) + } + + realTimeSearchRate := q.getRealTimeRate( + ratelimitutil.FormatSubLabel(internalpb.RateType_DQLSearch.String(), + ratelimitutil.GetCollectionSubLabel(dbName, collectionName))) + realTimeQueryRate := q.getRealTimeRate( + ratelimitutil.FormatSubLabel(internalpb.RateType_DQLQuery.String(), + ratelimitutil.GetCollectionSubLabel(dbName, collectionName))) + q.coolOffReading(realTimeSearchRate, realTimeQueryRate, coolOffSpeed, collectionLimiter, log) collectionProps := q.getCollectionLimitProperties(collection) - q.guaranteeMinRate(getCollectionRateLimitConfig(collectionProps, common.CollectionSearchRateMinKey), internalpb.RateType_DQLSearch, collection) - q.guaranteeMinRate(getCollectionRateLimitConfig(collectionProps, common.CollectionQueryRateMinKey), internalpb.RateType_DQLQuery, collection) + q.guaranteeMinRate(getCollectionRateLimitConfig(collectionProps, common.CollectionSearchRateMinKey), + internalpb.RateType_DQLSearch, collectionLimiter) + q.guaranteeMinRate(getCollectionRateLimitConfig(collectionProps, common.CollectionQueryRateMinKey), + internalpb.RateType_DQLQuery, collectionLimiter) } + return nil } - // TODO: unify search and query? 
- realTimeSearchRate := q.getRealTimeRate(internalpb.RateType_DQLSearch) - realTimeQueryRate := q.getRealTimeRate(internalpb.RateType_DQLQuery) - coolOff(realTimeSearchRate, realTimeQueryRate, limitCollectionSet.Collect()...) + if clusterLimit { + realTimeClusterSearchRate := q.getRealTimeRate(internalpb.RateType_DQLSearch.String()) + realTimeClusterQueryRate := q.getRealTimeRate(internalpb.RateType_DQLQuery.String()) + q.coolOffReading(realTimeClusterSearchRate, realTimeClusterQueryRate, coolOffSpeed, q.rateLimiter.GetRootLimiters(), log) + } + + var updateLimitErr error + limitDBNameSet.Range(func(name string) bool { + dbID, ok := q.dbs.Get(name) + if !ok { + log.Warn("db not found", zap.String("dbName", name)) + updateLimitErr = fmt.Errorf("db not found: %s", name) + return false + } + dbLimiter := q.rateLimiter.GetDatabaseLimiters(dbID) + if dbLimiter == nil { + log.Warn("database limiter not found", zap.Int64("dbID", dbID)) + updateLimitErr = fmt.Errorf("database limiter not found") + return false + } + + realTimeSearchRate := q.getRealTimeRate( + ratelimitutil.FormatSubLabel(internalpb.RateType_DQLSearch.String(), + ratelimitutil.GetDBSubLabel(name))) + realTimeQueryRate := q.getRealTimeRate( + ratelimitutil.FormatSubLabel(internalpb.RateType_DQLQuery.String(), + ratelimitutil.GetDBSubLabel(name))) + q.coolOffReading(realTimeSearchRate, realTimeQueryRate, coolOffSpeed, dbLimiter, log) + return true + }) + if updateLimitErr != nil { + return updateLimitErr + } + + limitCollectionNameSet.Range(func(name string) bool { + dbName, collectionName := splitCollctionRateKey(name) + dbID, ok := q.dbs.Get(dbName) + if !ok { + log.Warn("db not found", zap.String("dbName", dbName)) + updateLimitErr = fmt.Errorf("db not found: %s", dbName) + return false + } + collectionID, ok := q.collections.Get(FormatCollectionKey(dbID, collectionName)) + if !ok { + log.Warn("collection not found", zap.String("collectionName", name)) + updateLimitErr = fmt.Errorf("collection not found: %s", name) + return false + } + limitCollectionSet.Insert(collectionID) + return true + }) + if updateLimitErr != nil { + return updateLimitErr + } + + if updateLimitErr = coolOffCollectionID(limitCollectionSet.Collect()...); updateLimitErr != nil { + return updateLimitErr + } + + return nil +} + +func (q *QuotaCenter) coolOffReading(realTimeSearchRate, realTimeQueryRate, coolOffSpeed float64, + node *rlinternal.RateLimiterNode, mlog *log.MLogger, +) { + limiter := node.GetLimiters() + + v, ok := limiter.Get(internalpb.RateType_DQLSearch) + if ok && v.Limit() != Inf && realTimeSearchRate > 0 { + v.SetLimit(Limit(realTimeSearchRate * coolOffSpeed)) + mlog.RatedWarn(10, "QuotaCenter cool read rates off done", + zap.Any("level", node.Level()), + zap.Any("id", node.GetID()), + zap.Any("searchRate", v.Limit())) + } + + v, ok = limiter.Get(internalpb.RateType_DQLQuery) + if ok && v.Limit() != Inf && realTimeQueryRate > 0 { + v.SetLimit(Limit(realTimeQueryRate * coolOffSpeed)) + mlog.RatedWarn(10, "QuotaCenter cool read rates off done", + zap.Any("level", node.Level()), + zap.Any("id", node.GetID()), + zap.Any("queryRate", v.Limit())) + } } // calculateWriteRates calculates and sets dml rates. 
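// A short summary of the flow implemented below (descriptive only, based on the code in this hunk):
// when ForceDenyWriting is enabled the DML limiters are zeroed at the cluster (root) level and the
// function returns; otherwise checkDiskQuota runs first, a TSO is allocated to measure time-tick
// delay, and the per-collection factors (time-tick delay, memory water level, growing-segment size)
// are multiplied into the DMLInsert/DMLUpsert/DMLDelete limiters of each collection node in the
// rate-limiter tree (Inf limits are left untouched), with guaranteeMinRate enforcing the configured
// per-collection minimums. Collections throttled to zero by time-tick delay are force-denied with
// ErrorCode_TimeTickLongDelay, the remaining zero-factor collections with
// ErrorCode_MemoryQuotaExhausted.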
func (q *QuotaCenter) calculateWriteRates() error { log := log.Ctx(context.Background()).WithRateGroup("rootcoord.QuotaCenter", 1.0, 60.0) if Params.QuotaConfig.ForceDenyWriting.GetAsBool() { - q.forceDenyWriting(commonpb.ErrorCode_ForceDeny) - return nil + return q.forceDenyWriting(commonpb.ErrorCode_ForceDeny, true, nil, nil, nil) } - q.checkDiskQuota() + if err := q.checkDiskQuota(); err != nil { + return err + } ts, err := q.tsoAllocator.GenerateTSO(1) if err != nil { @@ -538,37 +896,68 @@ func (q *QuotaCenter) calculateWriteRates() error { growingSegFactors := q.getGrowingSegmentsSizeFactor() updateCollectionFactor(growingSegFactors) + ttCollections := make([]int64, 0) + memoryCollections := make([]int64, 0) + for collection, factor := range collectionFactors { metrics.RootCoordRateLimitRatio.WithLabelValues(fmt.Sprint(collection)).Set(1 - factor) if factor <= 0 { if _, ok := ttFactors[collection]; ok && factor == ttFactors[collection] { // factor comes from ttFactor - q.forceDenyWriting(commonpb.ErrorCode_TimeTickLongDelay, collection) + ttCollections = append(ttCollections, collection) } else { - // factor comes from memFactor or growingSegFactor, all about mem exhausted - q.forceDenyWriting(commonpb.ErrorCode_MemoryQuotaExhausted, collection) + memoryCollections = append(memoryCollections, collection) } } - if q.currentRates[collection][internalpb.RateType_DMLInsert] != Inf { - q.currentRates[collection][internalpb.RateType_DMLInsert] *= Limit(factor) + dbID, ok := q.collectionIDToDBID.Get(collection) + if !ok { + return fmt.Errorf("db ID not found of collection ID: %d", collection) } - if q.currentRates[collection][internalpb.RateType_DMLUpsert] != Inf { - q.currentRates[collection][internalpb.RateType_DMLUpsert] *= Limit(factor) + collectionLimiter := q.rateLimiter.GetCollectionLimiters(dbID, collection) + if collectionLimiter == nil { + return fmt.Errorf("collection limiter not found: %d", collection) } - if q.currentRates[collection][internalpb.RateType_DMLDelete] != Inf { - q.currentRates[collection][internalpb.RateType_DMLDelete] *= Limit(factor) + + limiter := collectionLimiter.GetLimiters() + for _, rt := range []internalpb.RateType{ + internalpb.RateType_DMLInsert, + internalpb.RateType_DMLUpsert, + internalpb.RateType_DMLDelete, + } { + v, ok := limiter.Get(rt) + if ok { + if v.Limit() != Inf { + v.SetLimit(v.Limit() * Limit(factor)) + } + } } collectionProps := q.getCollectionLimitProperties(collection) - q.guaranteeMinRate(getCollectionRateLimitConfig(collectionProps, common.CollectionInsertRateMinKey), internalpb.RateType_DMLInsert, collection) - q.guaranteeMinRate(getCollectionRateLimitConfig(collectionProps, common.CollectionUpsertRateMinKey), internalpb.RateType_DMLUpsert, collection) - q.guaranteeMinRate(getCollectionRateLimitConfig(collectionProps, common.CollectionDeleteRateMinKey), internalpb.RateType_DMLDelete, collection) + q.guaranteeMinRate(getCollectionRateLimitConfig(collectionProps, common.CollectionInsertRateMinKey), + internalpb.RateType_DMLInsert, collectionLimiter) + q.guaranteeMinRate(getCollectionRateLimitConfig(collectionProps, common.CollectionUpsertRateMinKey), + internalpb.RateType_DMLUpsert, collectionLimiter) + q.guaranteeMinRate(getCollectionRateLimitConfig(collectionProps, common.CollectionDeleteRateMinKey), + internalpb.RateType_DMLDelete, collectionLimiter) log.RatedDebug(10, "QuotaCenter cool write rates off done", zap.Int64("collectionID", collection), zap.Float64("factor", factor)) } + if len(ttCollections) > 0 { + if err = 
q.forceDenyWriting(commonpb.ErrorCode_TimeTickLongDelay, false, nil, ttCollections, nil); err != nil { + log.Warn("fail to force deny writing for time tick delay", zap.Error(err)) + return err + } + } + if len(memoryCollections) > 0 { + if err = q.forceDenyWriting(commonpb.ErrorCode_MemoryQuotaExhausted, false, nil, memoryCollections, nil); err != nil { + log.Warn("fail to force deny writing for memory quota", zap.Error(err)) + return err + } + } + return nil } @@ -765,69 +1154,86 @@ func (q *QuotaCenter) getGrowingSegmentsSizeFactor() map[int64]float64 { // calculateRates calculates target rates by different strategies. func (q *QuotaCenter) calculateRates() error { - q.resetAllCurrentRates() - - err := q.calculateWriteRates() + err := q.resetAllCurrentRates() if err != nil { + log.Warn("QuotaCenter resetAllCurrentRates failed", zap.Error(err)) + return err + } + + err = q.calculateWriteRates() + if err != nil { + log.Warn("QuotaCenter calculateWriteRates failed", zap.Error(err)) + return err + } + err = q.calculateReadRates() + if err != nil { + log.Warn("QuotaCenter calculateReadRates failed", zap.Error(err)) return err } - q.calculateReadRates() // log.Debug("QuotaCenter calculates rate done", zap.Any("rates", q.currentRates)) return nil } -func (q *QuotaCenter) resetAllCurrentRates() { - q.quotaStates = make(map[int64]map[milvuspb.QuotaState]commonpb.ErrorCode) - q.currentRates = map[int64]map[internalpb.RateType]ratelimitutil.Limit{} - for _, collection := range q.writableCollections { - q.resetCurrentRate(internalpb.RateType_DMLInsert, collection) - q.resetCurrentRate(internalpb.RateType_DMLUpsert, collection) - q.resetCurrentRate(internalpb.RateType_DMLDelete, collection) - q.resetCurrentRate(internalpb.RateType_DMLBulkLoad, collection) - } +func (q *QuotaCenter) resetAllCurrentRates() error { + q.rateLimiter = rlinternal.NewRateLimiterTree(initInfLimiter(internalpb.RateScope_Cluster, allOps)) + initLimiters := func(sourceCollections map[int64]map[int64][]int64) { + for dbID, collections := range sourceCollections { + for collectionID, partitionIDs := range collections { + getCollectionLimitVal := func(rateType internalpb.RateType) Limit { + limitVal, err := q.getCollectionMaxLimit(rateType, collectionID) + if err != nil { + return Limit(quota.GetQuotaValue(internalpb.RateScope_Collection, rateType, Params)) + } + return limitVal + } - for _, collection := range q.readableCollections { - q.resetCurrentRate(internalpb.RateType_DQLSearch, collection) - q.resetCurrentRate(internalpb.RateType_DQLQuery, collection) + for _, partitionID := range partitionIDs { + q.rateLimiter.GetOrCreatePartitionLimiters(dbID, collectionID, partitionID, + newParamLimiterFunc(internalpb.RateScope_Database, allOps), + newParamLimiterFuncWithLimitFunc(internalpb.RateScope_Collection, allOps, getCollectionLimitVal), + newParamLimiterFunc(internalpb.RateScope_Partition, allOps)) + } + if len(partitionIDs) == 0 { + q.rateLimiter.GetOrCreateCollectionLimiters(dbID, collectionID, + newParamLimiterFunc(internalpb.RateScope_Database, allOps), + newParamLimiterFuncWithLimitFunc(internalpb.RateScope_Collection, allOps, getCollectionLimitVal)) + } + } + if len(collections) == 0 { + q.rateLimiter.GetOrCreateDatabaseLimiters(dbID, newParamLimiterFunc(internalpb.RateScope_Database, allOps)) + } + } } + initLimiters(q.readableCollections) + initLimiters(q.writableCollections) + return nil } -// resetCurrentRates resets all current rates to configured rates. 
-func (q *QuotaCenter) resetCurrentRate(rt internalpb.RateType, collection int64) { - if q.currentRates[collection] == nil { - q.currentRates[collection] = make(map[internalpb.RateType]ratelimitutil.Limit) - } - - if q.quotaStates[collection] == nil { - q.quotaStates[collection] = make(map[milvuspb.QuotaState]commonpb.ErrorCode) - } - - collectionProps := q.getCollectionLimitProperties(collection) +// getCollectionMaxLimit get limit value from collection's properties. +func (q *QuotaCenter) getCollectionMaxLimit(rt internalpb.RateType, collectionID int64) (ratelimitutil.Limit, error) { + collectionProps := q.getCollectionLimitProperties(collectionID) switch rt { case internalpb.RateType_DMLInsert: - q.currentRates[collection][rt] = Limit(getCollectionRateLimitConfig(collectionProps, common.CollectionInsertRateMaxKey)) + return Limit(getCollectionRateLimitConfig(collectionProps, common.CollectionInsertRateMaxKey)), nil case internalpb.RateType_DMLUpsert: - q.currentRates[collection][rt] = Limit(getCollectionRateLimitConfig(collectionProps, common.CollectionUpsertRateMaxKey)) + return Limit(getCollectionRateLimitConfig(collectionProps, common.CollectionUpsertRateMaxKey)), nil case internalpb.RateType_DMLDelete: - q.currentRates[collection][rt] = Limit(getCollectionRateLimitConfig(collectionProps, common.CollectionDeleteRateMaxKey)) + return Limit(getCollectionRateLimitConfig(collectionProps, common.CollectionDeleteRateMaxKey)), nil case internalpb.RateType_DMLBulkLoad: - q.currentRates[collection][rt] = Limit(getCollectionRateLimitConfig(collectionProps, common.CollectionBulkLoadRateMaxKey)) + return Limit(getCollectionRateLimitConfig(collectionProps, common.CollectionBulkLoadRateMaxKey)), nil case internalpb.RateType_DQLSearch: - q.currentRates[collection][rt] = Limit(getCollectionRateLimitConfig(collectionProps, common.CollectionSearchRateMaxKey)) + return Limit(getCollectionRateLimitConfig(collectionProps, common.CollectionSearchRateMaxKey)), nil case internalpb.RateType_DQLQuery: - q.currentRates[collection][rt] = Limit(getCollectionRateLimitConfig(collectionProps, common.CollectionQueryRateMaxKey)) - } - if q.currentRates[collection][rt] < 0 { - q.currentRates[collection][rt] = Inf // no limit + return Limit(getCollectionRateLimitConfig(collectionProps, common.CollectionQueryRateMaxKey)), nil + default: + return 0, fmt.Errorf("unsupportd rate type:%s", rt.String()) } } func (q *QuotaCenter) getCollectionLimitProperties(collection int64) map[string]string { log := log.Ctx(context.Background()).WithRateGroup("rootcoord.QuotaCenter", 1.0, 60.0) - - // dbName can be ignored if ts is max timestamps - collectionInfo, err := q.meta.GetCollectionByID(context.TODO(), "", collection, typeutil.MaxTimestamp, false) + collectionInfo, err := q.meta.GetCollectionByIDWithMaxTs(context.TODO(), collection) if err != nil { log.RatedWarn(10, "failed to get rate limit properties from collection meta", zap.Int64("collectionID", collection), @@ -844,118 +1250,217 @@ func (q *QuotaCenter) getCollectionLimitProperties(collection int64) map[string] } // checkDiskQuota checks if disk quota exceeded. 
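// The check covers four scopes, as implemented below: total binlog size against DiskQuota (writing
// is denied cluster-wide and the function returns early), per-collection binlog size against
// DiskQuotaPerCollection or the collection's own disk-quota property, per-database usage
// (aggregated via collectionIDToDBID) against DiskQuotaPerDB, and per-partition usage against
// DiskQuotaPerPartition; the offending dbs, collections and partitions are then passed to
// forceDenyWriting with ErrorCode_DiskQuotaExhausted.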
-func (q *QuotaCenter) checkDiskQuota() { +func (q *QuotaCenter) checkDiskQuota() error { q.diskMu.Lock() defer q.diskMu.Unlock() if !Params.QuotaConfig.DiskProtectionEnabled.GetAsBool() { - return + return nil } if q.dataCoordMetrics == nil { - return + return nil } - collections := typeutil.NewUniqueSet() + totalDiskQuota := Params.QuotaConfig.DiskQuota.GetAsFloat() + total := q.dataCoordMetrics.TotalBinlogSize + if float64(total) >= totalDiskQuota { + err := q.forceDenyWriting(commonpb.ErrorCode_DiskQuotaExhausted, true, nil, nil, nil) + if err != nil { + log.Warn("fail to force deny writing", zap.Error(err)) + } + return err + } + + collectionDiskQuota := Params.QuotaConfig.DiskQuotaPerCollection.GetAsFloat() + dbSizeInfo := make(map[int64]int64) + collections := make([]int64, 0) for collection, binlogSize := range q.dataCoordMetrics.CollectionBinlogSize { collectionProps := q.getCollectionLimitProperties(collection) - colDiskQuota := getCollectionRateLimitConfig(collectionProps, common.CollectionDiskQuotaKey) + colDiskQuota := getRateLimitConfig(collectionProps, common.CollectionDiskQuotaKey, collectionDiskQuota) if float64(binlogSize) >= colDiskQuota { log.RatedWarn(10, "collection disk quota exceeded", zap.Int64("collection", collection), zap.Int64("coll disk usage", binlogSize), zap.Float64("coll disk quota", colDiskQuota)) - collections.Insert(collection) + collections = append(collections, collection) + } + dbID, ok := q.collectionIDToDBID.Get(collection) + if !ok { + log.Warn("cannot find db id for collection", zap.Int64("collection", collection)) + continue + } + dbSizeInfo[dbID] += binlogSize + } + + dbs := make([]int64, 0) + dbDiskQuota := Params.QuotaConfig.DiskQuotaPerDB.GetAsFloat() + for dbID, binlogSize := range dbSizeInfo { + if float64(binlogSize) >= dbDiskQuota { + log.RatedWarn(10, "db disk quota exceeded", + zap.Int64("db", dbID), + zap.Int64("db disk usage", binlogSize), + zap.Float64("db disk quota", dbDiskQuota)) + dbs = append(dbs, dbID) } } - if collections.Len() > 0 { - q.forceDenyWriting(commonpb.ErrorCode_DiskQuotaExhausted, collections.Collect()...) + + col2partitions := make(map[int64][]int64) + partitionDiskQuota := Params.QuotaConfig.DiskQuotaPerPartition.GetAsFloat() + for collection, partitions := range q.dataCoordMetrics.PartitionsBinlogSize { + for partition, binlogSize := range partitions { + if float64(binlogSize) >= partitionDiskQuota { + log.RatedWarn(10, "partition disk quota exceeded", + zap.Int64("collection", collection), + zap.Int64("partition", partition), + zap.Int64("part disk usage", binlogSize), + zap.Float64("part disk quota", partitionDiskQuota)) + col2partitions[collection] = append(col2partitions[collection], partition) + } + } } - total := q.dataCoordMetrics.TotalBinlogSize - if float64(total) >= totalDiskQuota { - log.RatedWarn(10, "total disk quota exceeded", - zap.Int64("total disk usage", total), - zap.Float64("total disk quota", totalDiskQuota)) - q.forceDenyWriting(commonpb.ErrorCode_DiskQuotaExhausted) + + err := q.forceDenyWriting(commonpb.ErrorCode_DiskQuotaExhausted, false, dbs, collections, col2partitions) + if err != nil { + log.Warn("fail to force deny writing", zap.Error(err)) + return err } q.totalBinlogSize = total + return nil } -// setRates notifies Proxies to set rates for different rate types. 
-func (q *QuotaCenter) setRates() error { - ctx, cancel := context.WithTimeout(q.ctx, SetRatesTimeout) - defer cancel() - - toCollectionRate := func(collection int64, currentRates map[internalpb.RateType]ratelimitutil.Limit) *proxypb.CollectionRate { - rates := make([]*internalpb.Rate, 0, len(q.currentRates)) - switch q.rateAllocateStrategy { - case Average: - proxyNum := q.proxies.GetProxyCount() - if proxyNum == 0 { - return nil +func (q *QuotaCenter) toRequestLimiter(limiter *rlinternal.RateLimiterNode) *proxypb.Limiter { + var rates []*internalpb.Rate + switch q.rateAllocateStrategy { + case Average: + proxyNum := q.proxies.GetProxyCount() + if proxyNum == 0 { + return nil + } + limiter.GetLimiters().Range(func(rt internalpb.RateType, limiter *ratelimitutil.Limiter) bool { + if !limiter.HasUpdated() { + return true } - for rt, r := range currentRates { - if r == Inf { - rates = append(rates, &internalpb.Rate{Rt: rt, R: float64(r)}) - } else { - rates = append(rates, &internalpb.Rate{Rt: rt, R: float64(r) / float64(proxyNum)}) + r := limiter.Limit() + if r != Inf { + rates = append(rates, &internalpb.Rate{Rt: rt, R: float64(r) / float64(proxyNum)}) + } + return true + }) + case ByRateWeight: + // TODO: support ByRateWeight + } + + size := limiter.GetQuotaStates().Len() + states := make([]milvuspb.QuotaState, 0, size) + codes := make([]commonpb.ErrorCode, 0, size) + + limiter.GetQuotaStates().Range(func(state milvuspb.QuotaState, code commonpb.ErrorCode) bool { + states = append(states, state) + codes = append(codes, code) + return true + }) + + return &proxypb.Limiter{ + Rates: rates, + States: states, + Codes: codes, + } +} + +func (q *QuotaCenter) toRatesRequest() *proxypb.SetRatesRequest { + clusterRateLimiter := q.rateLimiter.GetRootLimiters() + + // collect db rate limit if clusterRateLimiter has database limiter children + dbLimiters := make(map[int64]*proxypb.LimiterNode, clusterRateLimiter.GetChildren().Len()) + clusterRateLimiter.GetChildren().Range(func(dbID int64, dbRateLimiters *rlinternal.RateLimiterNode) bool { + dbLimiter := q.toRequestLimiter(dbRateLimiters) + + // collect collection rate limit if dbRateLimiters has collection limiter children + collectionLimiters := make(map[int64]*proxypb.LimiterNode, dbRateLimiters.GetChildren().Len()) + dbRateLimiters.GetChildren().Range(func(collectionID int64, collectionRateLimiters *rlinternal.RateLimiterNode) bool { + collectionLimiter := q.toRequestLimiter(collectionRateLimiters) + + // collect partitions rate limit if collectionRateLimiters has partition limiter children + partitionLimiters := make(map[int64]*proxypb.LimiterNode, collectionRateLimiters.GetChildren().Len()) + collectionRateLimiters.GetChildren().Range(func(partitionID int64, partitionRateLimiters *rlinternal.RateLimiterNode) bool { + partitionLimiters[partitionID] = &proxypb.LimiterNode{ + Limiter: q.toRequestLimiter(partitionRateLimiters), + Children: make(map[int64]*proxypb.LimiterNode, 0), } + return true + }) + + collectionLimiters[collectionID] = &proxypb.LimiterNode{ + Limiter: collectionLimiter, + Children: partitionLimiters, } + return true + }) - case ByRateWeight: - // TODO: support ByRateWeight + dbLimiters[dbID] = &proxypb.LimiterNode{ + Limiter: dbLimiter, + Children: collectionLimiters, } - return &proxypb.CollectionRate{ - Collection: collection, - Rates: rates, - States: lo.Keys(q.quotaStates[collection]), - Codes: lo.Values(q.quotaStates[collection]), - } + return true + }) + + clusterLimiter := &proxypb.LimiterNode{ + Limiter: 
q.toRequestLimiter(clusterRateLimiter), + Children: dbLimiters, } - collectionRates := make([]*proxypb.CollectionRate, 0) - for collection, rates := range q.currentRates { - collectionRates = append(collectionRates, toCollectionRate(collection, rates)) - } timestamp := tsoutil.ComposeTSByTime(time.Now(), 0) - req := &proxypb.SetRatesRequest{ + return &proxypb.SetRatesRequest{ Base: commonpbutil.NewMsgBase( commonpbutil.WithMsgID(int64(timestamp)), commonpbutil.WithTimeStamp(timestamp), ), - Rates: collectionRates, + Rates: []*proxypb.CollectionRate{}, + RootLimiter: clusterLimiter, } - return q.proxies.SetRates(ctx, req) +} + +// sendRatesToProxy notifies Proxies to set rates for different rate types. +func (q *QuotaCenter) sendRatesToProxy() error { + ctx, cancel := context.WithTimeout(context.Background(), SetRatesTimeout) + defer cancel() + return q.proxies.SetRates(ctx, q.toRatesRequest()) } // recordMetrics records metrics of quota states. func (q *QuotaCenter) recordMetrics() { - record := func(errorCode commonpb.ErrorCode) { - var hasException float64 = 0 - for collectionID, states := range q.quotaStates { - info, err := q.meta.GetCollectionByID(context.Background(), "", collectionID, typeutil.MaxTimestamp, false) - if err != nil { - log.Warn("failed to get collection info by its id, ignore to report quota states", - zap.Int64("collection", collectionID), - zap.Error(err), - ) - continue - } - dbm, err := q.meta.GetDatabaseByID(context.Background(), info.DBID, typeutil.MaxTimestamp) - if err != nil { - log.Warn("failed to get database name info by its id, ignore to report quota states", - zap.Int64("collection", collectionID), - zap.Error(err), - ) - continue - } + dbIDs := make(map[int64]string, q.dbs.Len()) + collectionIDs := make(map[int64]string, q.collections.Len()) + q.dbs.Range(func(name string, id int64) bool { + dbIDs[id] = name + return true + }) + q.collections.Range(func(name string, id int64) bool { + _, collectionName := SplitCollectionKey(name) + collectionIDs[id] = collectionName + return true + }) - for _, state := range states { - if state == errorCode { - hasException = 1 + record := func(errorCode commonpb.ErrorCode) { + rlinternal.TraverseRateLimiterTree(q.rateLimiter.GetRootLimiters(), nil, + func(node *rlinternal.RateLimiterNode, state milvuspb.QuotaState, errCode commonpb.ErrorCode) bool { + if errCode == errorCode { + var name string + switch node.Level() { + case internalpb.RateScope_Cluster: + name = "cluster" + case internalpb.RateScope_Database: + name = "db_" + dbIDs[node.GetID()] + case internalpb.RateScope_Collection: + name = "collection_" + collectionIDs[node.GetID()] + default: + return false + } + metrics.RootCoordQuotaStates.WithLabelValues(errorCode.String(), name).Set(1.0) + return false } - } - metrics.RootCoordQuotaStates.WithLabelValues(errorCode.String(), dbm.Name).Set(hasException) - } + return true + }) } record(commonpb.ErrorCode_MemoryQuotaExhausted) record(commonpb.ErrorCode_DiskQuotaExhausted) diff --git a/internal/rootcoord/quota_center_test.go b/internal/rootcoord/quota_center_test.go index 2f0cb4a45c..2c19189b68 100644 --- a/internal/rootcoord/quota_center_test.go +++ b/internal/rootcoord/quota_center_test.go @@ -18,6 +18,7 @@ package rootcoord import ( "context" + "encoding/json" "fmt" "math" "testing" @@ -37,6 +38,7 @@ import ( "github.com/milvus-io/milvus/internal/proto/internalpb" mockrootcoord "github.com/milvus-io/milvus/internal/rootcoord/mocks" "github.com/milvus-io/milvus/internal/util/proxyutil" + interalratelimitutil 
"github.com/milvus-io/milvus/internal/util/ratelimitutil" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/metrics" "github.com/milvus-io/milvus/pkg/util/merr" @@ -62,11 +64,22 @@ func TestQuotaCenter(t *testing.T) { dc := mocks.NewMockDataCoordClient(t) dc.EXPECT().GetMetrics(mock.Anything, mock.Anything).Return(nil, nil).Maybe() + collectionIDToPartitionIDs := map[int64][]int64{ + 1: {}, + 2: {}, + 3: {}, + } + + collectionIDToDBID := typeutil.NewConcurrentMap[int64, int64]() + collectionIDToDBID.Insert(1, 0) + collectionIDToDBID.Insert(2, 0) + collectionIDToDBID.Insert(3, 0) + t.Run("test QuotaCenter", func(t *testing.T) { qc := mocks.NewMockQueryCoordClient(t) meta := mockrootcoord.NewIMetaTable(t) - meta.EXPECT().GetCollectionByID(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, merr.ErrCollectionNotFound).Maybe() + meta.EXPECT().GetCollectionByIDWithMaxTs(mock.Anything, mock.Anything).Return(nil, merr.ErrCollectionNotFound).Maybe() quotaCenter := NewQuotaCenter(pcm, qc, dc, core.tsoAllocator, meta) quotaCenter.Start() time.Sleep(10 * time.Millisecond) @@ -98,6 +111,12 @@ func TestQuotaCenter(t *testing.T) { }) meta.EXPECT().GetCollectionByID(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, merr.ErrCollectionNotFound).Maybe() + meta.EXPECT().ListDatabases(mock.Anything, mock.Anything).Return([]*model.Database{ + { + Name: "default", + ID: 1, + }, + }, nil).Maybe() quotaCenter := NewQuotaCenter(pcm, qc, dc, core.tsoAllocator, meta) quotaCenter.Start() time.Sleep(3 * time.Second) @@ -108,17 +127,24 @@ func TestQuotaCenter(t *testing.T) { assert.True(t, time.Since(start).Seconds() <= 5) }) - t.Run("test syncMetrics", func(t *testing.T) { + t.Run("test collectMetrics", func(t *testing.T) { qc := mocks.NewMockQueryCoordClient(t) meta := mockrootcoord.NewIMetaTable(t) - meta.EXPECT().GetCollectionByID(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, merr.ErrCollectionNotFound).Maybe() + meta.EXPECT().GetCollectionByIDWithMaxTs(mock.Anything, mock.Anything).Return(nil, merr.ErrCollectionNotFound).Maybe() + meta.EXPECT().ListDatabases(mock.Anything, mock.Anything).Return([]*model.Database{ + { + Name: "default", + ID: 1, + }, + }, nil).Maybe() + qc.EXPECT().GetMetrics(mock.Anything, mock.Anything).Return(&milvuspb.GetMetricsResponse{Status: merr.Success()}, nil) quotaCenter := NewQuotaCenter(pcm, qc, dc, core.tsoAllocator, meta) - err = quotaCenter.syncMetrics() + err = quotaCenter.collectMetrics() assert.Error(t, err) // for empty response quotaCenter = NewQuotaCenter(pcm, qc, dc, core.tsoAllocator, meta) - err = quotaCenter.syncMetrics() + err = quotaCenter.collectMetrics() assert.Error(t, err) dc.ExpectedCalls = nil @@ -127,66 +153,264 @@ func TestQuotaCenter(t *testing.T) { }, nil) quotaCenter = NewQuotaCenter(pcm, qc, dc, core.tsoAllocator, meta) - err = quotaCenter.syncMetrics() + err = quotaCenter.collectMetrics() assert.Error(t, err) dc.ExpectedCalls = nil dc.EXPECT().GetMetrics(mock.Anything, mock.Anything).Return(nil, errors.New("mock error")) qc.EXPECT().GetMetrics(mock.Anything, mock.Anything).Return(nil, fmt.Errorf("mock err")) quotaCenter = NewQuotaCenter(pcm, qc, dc, core.tsoAllocator, meta) - err = quotaCenter.syncMetrics() + err = quotaCenter.collectMetrics() assert.Error(t, err) qc.EXPECT().GetMetrics(mock.Anything, mock.Anything).Return(&milvuspb.GetMetricsResponse{ Status: merr.Status(err), }, nil) quotaCenter = NewQuotaCenter(pcm, qc, 
dc, core.tsoAllocator, meta) - err = quotaCenter.syncMetrics() + err = quotaCenter.collectMetrics() assert.Error(t, err) }) - t.Run("test forceDeny", func(t *testing.T) { + t.Run("list database fail", func(t *testing.T) { + qc := mocks.NewMockQueryCoordClient(t) + dc2 := mocks.NewMockDataCoordClient(t) + pcm2 := proxyutil.NewMockProxyClientManager(t) + meta := mockrootcoord.NewIMetaTable(t) + + emptyQueryCoordTopology := &metricsinfo.QueryCoordTopology{} + queryBytes, _ := json.Marshal(emptyQueryCoordTopology) + qc.EXPECT().GetMetrics(mock.Anything, mock.Anything).Return(&milvuspb.GetMetricsResponse{ + Status: merr.Success(), + Response: string(queryBytes), + }, nil).Once() + emptyDataCoordTopology := &metricsinfo.DataCoordTopology{} + dataBytes, _ := json.Marshal(emptyDataCoordTopology) + dc2.EXPECT().GetMetrics(mock.Anything, mock.Anything).Return(&milvuspb.GetMetricsResponse{ + Status: merr.Success(), + Response: string(dataBytes), + }, nil).Once() + pcm2.EXPECT().GetProxyMetrics(mock.Anything).Return([]*milvuspb.GetMetricsResponse{}, nil).Once() + + meta.EXPECT().ListDatabases(mock.Anything, mock.Anything).Return(nil, errors.New("mock error")).Once() + quotaCenter := NewQuotaCenter(pcm2, qc, dc2, core.tsoAllocator, meta) + err = quotaCenter.collectMetrics() + assert.Error(t, err) + }) + + t.Run("get collection by id fail, querynode", func(t *testing.T) { + qc := mocks.NewMockQueryCoordClient(t) + dc2 := mocks.NewMockDataCoordClient(t) + pcm2 := proxyutil.NewMockProxyClientManager(t) + meta := mockrootcoord.NewIMetaTable(t) + + emptyQueryCoordTopology := &metricsinfo.QueryCoordTopology{ + Cluster: metricsinfo.QueryClusterTopology{ + ConnectedNodes: []metricsinfo.QueryNodeInfos{ + { + QuotaMetrics: &metricsinfo.QueryNodeQuotaMetrics{ + Effect: metricsinfo.NodeEffect{ + CollectionIDs: []int64{1000}, + }, + }, + CollectionMetrics: &metricsinfo.QueryNodeCollectionMetrics{ + CollectionRows: map[int64]int64{ + 1000: 100, + }, + }, + }, + }, + }, + } + queryBytes, _ := json.Marshal(emptyQueryCoordTopology) + qc.EXPECT().GetMetrics(mock.Anything, mock.Anything).Return(&milvuspb.GetMetricsResponse{ + Status: merr.Success(), + Response: string(queryBytes), + }, nil).Once() + emptyDataCoordTopology := &metricsinfo.DataCoordTopology{} + dataBytes, _ := json.Marshal(emptyDataCoordTopology) + dc2.EXPECT().GetMetrics(mock.Anything, mock.Anything).Return(&milvuspb.GetMetricsResponse{ + Status: merr.Success(), + Response: string(dataBytes), + }, nil).Once() + pcm2.EXPECT().GetProxyMetrics(mock.Anything).Return([]*milvuspb.GetMetricsResponse{}, nil).Once() + meta.EXPECT().ListDatabases(mock.Anything, mock.Anything).Return([]*model.Database{ + { + ID: 1, + Name: "default", + }, + }, nil).Once() + meta.EXPECT().GetCollectionByIDWithMaxTs(mock.Anything, mock.Anything).Return(nil, errors.New("mock err: get collection by id")).Once() + + quotaCenter := NewQuotaCenter(pcm2, qc, dc2, core.tsoAllocator, meta) + err = quotaCenter.collectMetrics() + assert.Error(t, err) + }) + + t.Run("get collection by id fail, datanode", func(t *testing.T) { + qc := mocks.NewMockQueryCoordClient(t) + dc2 := mocks.NewMockDataCoordClient(t) + pcm2 := proxyutil.NewMockProxyClientManager(t) + meta := mockrootcoord.NewIMetaTable(t) + + emptyQueryCoordTopology := &metricsinfo.QueryCoordTopology{} + queryBytes, _ := json.Marshal(emptyQueryCoordTopology) + qc.EXPECT().GetMetrics(mock.Anything, mock.Anything).Return(&milvuspb.GetMetricsResponse{ + Status: merr.Success(), + Response: string(queryBytes), + }, nil).Once() + 
emptyDataCoordTopology := &metricsinfo.DataCoordTopology{ + Cluster: metricsinfo.DataClusterTopology{ + ConnectedDataNodes: []metricsinfo.DataNodeInfos{ + { + QuotaMetrics: &metricsinfo.DataNodeQuotaMetrics{ + Effect: metricsinfo.NodeEffect{ + CollectionIDs: []int64{1000}, + }, + }, + }, + }, + }, + } + dataBytes, _ := json.Marshal(emptyDataCoordTopology) + dc2.EXPECT().GetMetrics(mock.Anything, mock.Anything).Return(&milvuspb.GetMetricsResponse{ + Status: merr.Success(), + Response: string(dataBytes), + }, nil).Once() + pcm2.EXPECT().GetProxyMetrics(mock.Anything).Return([]*milvuspb.GetMetricsResponse{}, nil).Once() + meta.EXPECT().ListDatabases(mock.Anything, mock.Anything).Return([]*model.Database{ + { + ID: 1, + Name: "default", + }, + }, nil).Once() + meta.EXPECT().GetCollectionByIDWithMaxTs(mock.Anything, mock.Anything).Return(nil, errors.New("mock err: get collection by id")).Once() + + quotaCenter := NewQuotaCenter(pcm2, qc, dc2, core.tsoAllocator, meta) + err = quotaCenter.collectMetrics() + assert.Error(t, err) + }) + + t.Run("test force deny reading collection", func(t *testing.T) { qc := mocks.NewMockQueryCoordClient(t) meta := mockrootcoord.NewIMetaTable(t) - meta.EXPECT().GetCollectionByID(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, merr.ErrCollectionNotFound).Maybe() + meta.EXPECT().GetCollectionByIDWithMaxTs(mock.Anything, mock.Anything).Return(nil, merr.ErrCollectionNotFound).Maybe() quotaCenter := NewQuotaCenter(pcm, qc, dc, core.tsoAllocator, meta) - quotaCenter.readableCollections = []int64{1, 2, 3} - quotaCenter.resetAllCurrentRates() - quotaCenter.forceDenyReading(commonpb.ErrorCode_ForceDeny, 1, 2, 3, 4) - for _, collection := range quotaCenter.readableCollections { - assert.Equal(t, Limit(0), quotaCenter.currentRates[collection][internalpb.RateType_DQLSearch]) - assert.Equal(t, Limit(0), quotaCenter.currentRates[collection][internalpb.RateType_DQLQuery]) - } - assert.Equal(t, Limit(0), quotaCenter.currentRates[4][internalpb.RateType_DQLSearch]) - assert.Equal(t, Limit(0), quotaCenter.currentRates[4][internalpb.RateType_DQLQuery]) - quotaCenter.writableCollections = []int64{1, 2, 3} - quotaCenter.resetAllCurrentRates() - quotaCenter.forceDenyWriting(commonpb.ErrorCode_ForceDeny, 1, 2, 3, 4) - for _, collection := range quotaCenter.writableCollections { - assert.Equal(t, Limit(0), quotaCenter.currentRates[collection][internalpb.RateType_DMLInsert]) - assert.Equal(t, Limit(0), quotaCenter.currentRates[collection][internalpb.RateType_DMLUpsert]) - assert.Equal(t, Limit(0), quotaCenter.currentRates[collection][internalpb.RateType_DMLDelete]) - assert.Equal(t, Limit(0), quotaCenter.currentRates[collection][internalpb.RateType_DMLBulkLoad]) + quotaCenter.readableCollections = map[int64]map[int64][]int64{ + 0: collectionIDToPartitionIDs, + } + err := quotaCenter.resetAllCurrentRates() + assert.NoError(t, err) + + Params.Save(Params.QuotaConfig.ForceDenyReading.Key, "true") + defer Params.Reset(Params.QuotaConfig.ForceDenyReading.Key) + quotaCenter.calculateReadRates() + + for collectionID := range collectionIDToPartitionIDs { + collectionLimiters := quotaCenter.rateLimiter.GetCollectionLimiters(0, collectionID) + assert.NotNil(t, collectionLimiters) + + limiters := collectionLimiters.GetLimiters() + assert.NotNil(t, limiters) + + for _, rt := range []internalpb.RateType{ + internalpb.RateType_DQLSearch, + internalpb.RateType_DQLQuery, + } { + ret, ok := limiters.Get(rt) + assert.True(t, ok) + assert.Equal(t, ret.Limit(), Limit(0)) 
+ } + } + }) + + t.Run("test force deny writing", func(t *testing.T) { + qc := mocks.NewMockQueryCoordClient(t) + meta := mockrootcoord.NewIMetaTable(t) + meta.EXPECT(). + GetCollectionByIDWithMaxTs(mock.Anything, mock.Anything). + Return(nil, merr.ErrCollectionNotFound). + Maybe() + + quotaCenter := NewQuotaCenter(pcm, qc, dc, core.tsoAllocator, meta) + quotaCenter.collectionIDToDBID = typeutil.NewConcurrentMap[int64, int64]() + quotaCenter.collectionIDToDBID.Insert(1, 0) + quotaCenter.collectionIDToDBID.Insert(2, 0) + quotaCenter.collectionIDToDBID.Insert(3, 0) + + quotaCenter.writableCollections = map[int64]map[int64][]int64{ + 0: collectionIDToPartitionIDs, + } + quotaCenter.writableCollections[0][1] = append(quotaCenter.writableCollections[0][1], 1000) + + err := quotaCenter.resetAllCurrentRates() + assert.NoError(t, err) + + err = quotaCenter.forceDenyWriting(commonpb.ErrorCode_ForceDeny, false, nil, []int64{4}, nil) + assert.Error(t, err) + + err = quotaCenter.forceDenyWriting(commonpb.ErrorCode_ForceDeny, false, nil, []int64{1, 2, 3}, map[int64][]int64{ + 1: {1000}, + }) + assert.NoError(t, err) + + for collectionID := range collectionIDToPartitionIDs { + collectionLimiters := quotaCenter.rateLimiter.GetCollectionLimiters(0, collectionID) + assert.NotNil(t, collectionLimiters) + + limiters := collectionLimiters.GetLimiters() + assert.NotNil(t, limiters) + + for _, rt := range []internalpb.RateType{ + internalpb.RateType_DMLInsert, + internalpb.RateType_DMLUpsert, + internalpb.RateType_DMLDelete, + internalpb.RateType_DMLBulkLoad, + } { + ret, ok := limiters.Get(rt) + assert.True(t, ok) + assert.Equal(t, ret.Limit(), Limit(0)) + } + } + + err = quotaCenter.forceDenyWriting(commonpb.ErrorCode_ForceDeny, false, []int64{0}, nil, nil) + assert.NoError(t, err) + dbLimiters := quotaCenter.rateLimiter.GetDatabaseLimiters(0) + assert.NotNil(t, dbLimiters) + limiters := dbLimiters.GetLimiters() + assert.NotNil(t, limiters) + for _, rt := range []internalpb.RateType{ + internalpb.RateType_DMLInsert, + internalpb.RateType_DMLUpsert, + internalpb.RateType_DMLDelete, + internalpb.RateType_DMLBulkLoad, + } { + ret, ok := limiters.Get(rt) + assert.True(t, ok) + assert.Equal(t, ret.Limit(), Limit(0)) } - assert.Equal(t, Limit(0), quotaCenter.currentRates[4][internalpb.RateType_DMLInsert]) - assert.Equal(t, Limit(0), quotaCenter.currentRates[4][internalpb.RateType_DMLUpsert]) - assert.Equal(t, Limit(0), quotaCenter.currentRates[4][internalpb.RateType_DMLDelete]) - assert.Equal(t, Limit(0), quotaCenter.currentRates[4][internalpb.RateType_DMLBulkLoad]) }) t.Run("test calculateRates", func(t *testing.T) { + forceBak := Params.QuotaConfig.ForceDenyWriting.GetValue() + paramtable.Get().Save(Params.QuotaConfig.ForceDenyWriting.Key, "false") + defer func() { + paramtable.Get().Save(Params.QuotaConfig.ForceDenyWriting.Key, forceBak) + }() + qc := mocks.NewMockQueryCoordClient(t) meta := mockrootcoord.NewIMetaTable(t) - meta.EXPECT().GetCollectionByID(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, merr.ErrCollectionNotFound).Maybe() + meta.EXPECT().GetCollectionByIDWithMaxTs(mock.Anything, mock.Anything).Return(nil, merr.ErrCollectionNotFound).Maybe() quotaCenter := NewQuotaCenter(pcm, qc, dc, core.tsoAllocator, meta) + quotaCenter.clearMetrics() err = quotaCenter.calculateRates() assert.NoError(t, err) alloc := newMockTsoAllocator() alloc.GenerateTSOF = func(count uint32) (typeutil.Timestamp, error) { - return 0, fmt.Errorf("mock err") + return 0, fmt.Errorf("mock tso 
err") } quotaCenter.tsoAllocator = alloc + quotaCenter.clearMetrics() err = quotaCenter.calculateRates() assert.Error(t, err) }) @@ -194,7 +418,7 @@ func TestQuotaCenter(t *testing.T) { t.Run("test getTimeTickDelayFactor factors", func(t *testing.T) { qc := mocks.NewMockQueryCoordClient(t) meta := mockrootcoord.NewIMetaTable(t) - meta.EXPECT().GetCollectionByID(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, merr.ErrCollectionNotFound).Maybe() + meta.EXPECT().GetCollectionByIDWithMaxTs(mock.Anything, mock.Anything).Return(nil, merr.ErrCollectionNotFound).Maybe() quotaCenter := NewQuotaCenter(pcm, qc, dc, core.tsoAllocator, meta) type ttCase struct { maxTtDelay time.Duration @@ -242,7 +466,7 @@ func TestQuotaCenter(t *testing.T) { t.Run("test TimeTickDelayFactor factors", func(t *testing.T) { qc := mocks.NewMockQueryCoordClient(t) meta := mockrootcoord.NewIMetaTable(t) - meta.EXPECT().GetCollectionByID(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, merr.ErrCollectionNotFound).Maybe() + meta.EXPECT().GetCollectionByIDWithMaxTs(mock.Anything, mock.Anything).Return(nil, merr.ErrCollectionNotFound).Maybe() quotaCenter := NewQuotaCenter(pcm, qc, dc, core.tsoAllocator, meta) type ttCase struct { delay time.Duration @@ -269,8 +493,11 @@ func TestQuotaCenter(t *testing.T) { paramtable.Get().Save(Params.QuotaConfig.DMLMinUpsertRatePerCollection.Key, "0.0") paramtable.Get().Save(Params.QuotaConfig.DMLMaxDeleteRatePerCollection.Key, "100.0") paramtable.Get().Save(Params.QuotaConfig.DMLMinDeleteRatePerCollection.Key, "0.0") - - quotaCenter.writableCollections = []int64{1, 2, 3} + forceBak := Params.QuotaConfig.ForceDenyWriting.GetValue() + paramtable.Get().Save(Params.QuotaConfig.ForceDenyWriting.Key, "false") + defer func() { + paramtable.Get().Save(Params.QuotaConfig.ForceDenyWriting.Key, forceBak) + }() alloc := newMockTsoAllocator() quotaCenter.tsoAllocator = alloc @@ -304,9 +531,21 @@ func TestQuotaCenter(t *testing.T) { }, }, } - quotaCenter.resetAllCurrentRates() - quotaCenter.calculateWriteRates() - deleteFactor := float64(quotaCenter.currentRates[1][internalpb.RateType_DMLDelete]) / Params.QuotaConfig.DMLMaxInsertRatePerCollection.GetAsFloat() + quotaCenter.writableCollections = map[int64]map[int64][]int64{ + 0: collectionIDToPartitionIDs, + } + quotaCenter.collectionIDToDBID = collectionIDToDBID + err = quotaCenter.resetAllCurrentRates() + assert.NoError(t, err) + + err = quotaCenter.calculateWriteRates() + assert.NoError(t, err) + + limit, ok := quotaCenter.rateLimiter.GetCollectionLimiters(0, 1).GetLimiters().Get(internalpb.RateType_DMLDelete) + assert.True(t, ok) + assert.NotNil(t, limit) + + deleteFactor := float64(limit.Limit()) / Params.QuotaConfig.DMLMaxInsertRatePerCollection.GetAsFloat() assert.True(t, math.Abs(deleteFactor-c.expectedFactor) < 0.01) } Params.Save(Params.QuotaConfig.MaxTimeTickDelay.Key, backup) @@ -315,13 +554,30 @@ func TestQuotaCenter(t *testing.T) { t.Run("test calculateReadRates", func(t *testing.T) { qc := mocks.NewMockQueryCoordClient(t) meta := mockrootcoord.NewIMetaTable(t) - meta.EXPECT().GetCollectionByID(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, merr.ErrCollectionNotFound).Maybe() + meta.EXPECT().GetCollectionByIDWithMaxTs(mock.Anything, mock.Anything).Return(nil, merr.ErrCollectionNotFound).Maybe() + meta.EXPECT().ListDatabases(mock.Anything, mock.Anything).Return([]*model.Database{ + { + ID: 0, + Name: "default", + }, + }, 
nil).Maybe() quotaCenter := NewQuotaCenter(pcm, qc, dc, core.tsoAllocator, meta) - quotaCenter.readableCollections = []int64{1, 2, 3} + quotaCenter.clearMetrics() + quotaCenter.collectionIDToDBID = collectionIDToDBID + quotaCenter.readableCollections = map[int64]map[int64][]int64{ + 0: collectionIDToPartitionIDs, + } + quotaCenter.dbs.Insert("default", 0) + quotaCenter.collections.Insert("0.col1", 1) + quotaCenter.collections.Insert("0.col2", 2) + quotaCenter.collections.Insert("0.col3", 3) + colSubLabel := ratelimitutil.GetCollectionSubLabel("default", "col1") quotaCenter.proxyMetrics = map[UniqueID]*metricsinfo.ProxyQuotaMetrics{ 1: {Rms: []metricsinfo.RateMetric{ {Label: internalpb.RateType_DQLSearch.String(), Rate: 100}, {Label: internalpb.RateType_DQLQuery.String(), Rate: 100}, + {Label: ratelimitutil.FormatSubLabel(internalpb.RateType_DQLSearch.String(), colSubLabel), Rate: 100}, + {Label: ratelimitutil.FormatSubLabel(internalpb.RateType_DQLQuery.String(), colSubLabel), Rate: 100}, }}, } @@ -331,7 +587,26 @@ func TestQuotaCenter(t *testing.T) { paramtable.Get().Save(Params.QuotaConfig.DQLLimitEnabled.Key, "true") paramtable.Get().Save(Params.QuotaConfig.DQLMaxQueryRatePerCollection.Key, "500") paramtable.Get().Save(Params.QuotaConfig.DQLMaxSearchRatePerCollection.Key, "500") - quotaCenter.resetAllCurrentRates() + + checkLimiter := func() { + for db, collections := range quotaCenter.readableCollections { + for collection := range collections { + if collection != 1 { + continue + } + limiters := quotaCenter.rateLimiter.GetCollectionLimiters(db, collection).GetLimiters() + searchLimit, _ := limiters.Get(internalpb.RateType_DQLSearch) + assert.Equal(t, Limit(100.0*0.9), searchLimit.Limit()) + + queryLimit, _ := limiters.Get(internalpb.RateType_DQLQuery) + assert.Equal(t, Limit(100.0*0.9), queryLimit.Limit()) + } + } + } + + err := quotaCenter.resetAllCurrentRates() + assert.NoError(t, err) + quotaCenter.queryNodeMetrics = map[UniqueID]*metricsinfo.QueryNodeQuotaMetrics{ 1: {SearchQueue: metricsinfo.ReadInfoInQueue{ AvgQueueDuration: Params.QuotaConfig.QueueLatencyThreshold.GetAsDuration(time.Second), @@ -340,62 +615,69 @@ func TestQuotaCenter(t *testing.T) { CollectionIDs: []int64{1, 2, 3}, }}, } - quotaCenter.calculateReadRates() - for _, collection := range quotaCenter.readableCollections { - assert.Equal(t, Limit(100.0*0.9), quotaCenter.currentRates[collection][internalpb.RateType_DQLSearch]) - assert.Equal(t, Limit(100.0*0.9), quotaCenter.currentRates[collection][internalpb.RateType_DQLQuery]) - } + + err = quotaCenter.calculateReadRates() + assert.NoError(t, err) + checkLimiter() paramtable.Get().Save(Params.QuotaConfig.NQInQueueThreshold.Key, "100") quotaCenter.queryNodeMetrics = map[UniqueID]*metricsinfo.QueryNodeQuotaMetrics{ - 1: {SearchQueue: metricsinfo.ReadInfoInQueue{ - UnsolvedQueue: Params.QuotaConfig.NQInQueueThreshold.GetAsInt64(), - }}, - } - quotaCenter.calculateReadRates() - for _, collection := range quotaCenter.readableCollections { - assert.Equal(t, Limit(100.0*0.9), quotaCenter.currentRates[collection][internalpb.RateType_DQLSearch]) - assert.Equal(t, Limit(100.0*0.9), quotaCenter.currentRates[collection][internalpb.RateType_DQLQuery]) + 1: { + SearchQueue: metricsinfo.ReadInfoInQueue{ + UnsolvedQueue: Params.QuotaConfig.NQInQueueThreshold.GetAsInt64(), + }, + }, } + err = quotaCenter.calculateReadRates() + assert.NoError(t, err) + checkLimiter() paramtable.Get().Save(Params.QuotaConfig.ResultProtectionEnabled.Key, "true") 
paramtable.Get().Save(Params.QuotaConfig.MaxReadResultRate.Key, "1") quotaCenter.proxyMetrics = map[UniqueID]*metricsinfo.ProxyQuotaMetrics{ - 1: {Rms: []metricsinfo.RateMetric{ - {Label: internalpb.RateType_DQLSearch.String(), Rate: 100}, - {Label: internalpb.RateType_DQLQuery.String(), Rate: 100}, - {Label: metricsinfo.ReadResultThroughput, Rate: 1.2}, - }}, + 1: { + Rms: []metricsinfo.RateMetric{ + {Label: internalpb.RateType_DQLSearch.String(), Rate: 100}, + {Label: internalpb.RateType_DQLQuery.String(), Rate: 100}, + {Label: ratelimitutil.FormatSubLabel(internalpb.RateType_DQLSearch.String(), colSubLabel), Rate: 100}, + {Label: ratelimitutil.FormatSubLabel(internalpb.RateType_DQLQuery.String(), colSubLabel), Rate: 100}, + {Label: metricsinfo.ReadResultThroughput, Rate: 1.2}, + }, + }, } quotaCenter.queryNodeMetrics = map[UniqueID]*metricsinfo.QueryNodeQuotaMetrics{1: {SearchQueue: metricsinfo.ReadInfoInQueue{}}} - quotaCenter.calculateReadRates() - for _, collection := range quotaCenter.readableCollections { - assert.Equal(t, Limit(100.0*0.9), quotaCenter.currentRates[collection][internalpb.RateType_DQLSearch]) - assert.Equal(t, Limit(100.0*0.9), quotaCenter.currentRates[collection][internalpb.RateType_DQLQuery]) - } + err = quotaCenter.calculateReadRates() + assert.NoError(t, err) + checkLimiter() }) t.Run("test calculateWriteRates", func(t *testing.T) { qc := mocks.NewMockQueryCoordClient(t) meta := mockrootcoord.NewIMetaTable(t) - meta.EXPECT().GetCollectionByID(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, merr.ErrCollectionNotFound).Maybe() + meta.EXPECT().GetCollectionByIDWithMaxTs(mock.Anything, mock.Anything).Return(nil, merr.ErrCollectionNotFound).Maybe() quotaCenter := NewQuotaCenter(pcm, qc, dc, core.tsoAllocator, meta) err = quotaCenter.calculateWriteRates() assert.NoError(t, err) // force deny - forceBak := Params.QuotaConfig.ForceDenyWriting.GetValue() paramtable.Get().Save(Params.QuotaConfig.ForceDenyWriting.Key, "true") - quotaCenter.writableCollections = []int64{1, 2, 3} + quotaCenter.writableCollections = map[int64]map[int64][]int64{ + 0: collectionIDToPartitionIDs, + } + quotaCenter.collectionIDToDBID = collectionIDToDBID + quotaCenter.collectionIDToDBID = collectionIDToDBID quotaCenter.resetAllCurrentRates() err = quotaCenter.calculateWriteRates() assert.NoError(t, err) - for _, collection := range quotaCenter.writableCollections { - assert.Equal(t, Limit(0), quotaCenter.currentRates[collection][internalpb.RateType_DMLInsert]) - assert.Equal(t, Limit(0), quotaCenter.currentRates[collection][internalpb.RateType_DMLUpsert]) - assert.Equal(t, Limit(0), quotaCenter.currentRates[collection][internalpb.RateType_DMLDelete]) - } - paramtable.Get().Save(Params.QuotaConfig.ForceDenyWriting.Key, forceBak) + limiters := quotaCenter.rateLimiter.GetRootLimiters().GetLimiters() + a, _ := limiters.Get(internalpb.RateType_DMLInsert) + assert.Equal(t, Limit(0), a.Limit()) + b, _ := limiters.Get(internalpb.RateType_DMLUpsert) + assert.Equal(t, Limit(0), b.Limit()) + c, _ := limiters.Get(internalpb.RateType_DMLDelete) + assert.Equal(t, Limit(0), c.Limit()) + + paramtable.Get().Reset(Params.QuotaConfig.ForceDenyWriting.Key) // disable tt delay protection disableTtBak := Params.QuotaConfig.TtProtectionEnabled.GetValue() @@ -411,14 +693,20 @@ func TestQuotaCenter(t *testing.T) { } err = quotaCenter.calculateWriteRates() assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_MemoryQuotaExhausted, 
quotaCenter.quotaStates[1][milvuspb.QuotaState_DenyToWrite]) + for db, collections := range quotaCenter.writableCollections { + for collection := range collections { + states := quotaCenter.rateLimiter.GetCollectionLimiters(db, collection).GetQuotaStates() + code, _ := states.Get(milvuspb.QuotaState_DenyToWrite) + assert.Equal(t, commonpb.ErrorCode_MemoryQuotaExhausted, code) + } + } paramtable.Get().Save(Params.QuotaConfig.TtProtectionEnabled.Key, disableTtBak) }) t.Run("test MemoryFactor factors", func(t *testing.T) { qc := mocks.NewMockQueryCoordClient(t) meta := mockrootcoord.NewIMetaTable(t) - meta.EXPECT().GetCollectionByID(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, merr.ErrCollectionNotFound).Maybe() + meta.EXPECT().GetCollectionByIDWithMaxTs(mock.Anything, mock.Anything).Return(nil, merr.ErrCollectionNotFound).Maybe() quotaCenter := NewQuotaCenter(pcm, qc, dc, core.tsoAllocator, meta) type memCase struct { lowWater float64 @@ -443,7 +731,9 @@ func TestQuotaCenter(t *testing.T) { {0.85, 0.95, 95, 100, 0}, } - quotaCenter.writableCollections = append(quotaCenter.writableCollections, 1, 2, 3) + quotaCenter.writableCollections = map[int64]map[int64][]int64{ + 0: collectionIDToPartitionIDs, + } for _, c := range memCases { paramtable.Get().Save(Params.QuotaConfig.QueryNodeMemoryLowWaterLevel.Key, fmt.Sprintf("%f", c.lowWater)) paramtable.Get().Save(Params.QuotaConfig.QueryNodeMemoryHighWaterLevel.Key, fmt.Sprintf("%f", c.highWater)) @@ -473,7 +763,7 @@ func TestQuotaCenter(t *testing.T) { t.Run("test GrowingSegmentsSize factors", func(t *testing.T) { qc := mocks.NewMockQueryCoordClient(t) meta := mockrootcoord.NewIMetaTable(t) - meta.EXPECT().GetCollectionByID(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, merr.ErrCollectionNotFound).Maybe() + meta.EXPECT().GetCollectionByIDWithMaxTs(mock.Anything, mock.Anything).Return(nil, merr.ErrCollectionNotFound).Maybe() quotaCenter := NewQuotaCenter(pcm, qc, dc, core.tsoAllocator, meta) defaultRatio := Params.QuotaConfig.GrowingSegmentsSizeMinRateRatio.GetAsFloat() tests := []struct { @@ -498,7 +788,9 @@ func TestQuotaCenter(t *testing.T) { {0.85, 0.95, 95, 100, defaultRatio}, } - quotaCenter.writableCollections = append(quotaCenter.writableCollections, 1, 2, 3) + quotaCenter.writableCollections = map[int64]map[int64][]int64{ + 0: collectionIDToPartitionIDs, + } paramtable.Get().Save(Params.QuotaConfig.GrowingSegmentsSizeProtectionEnabled.Key, "true") for _, test := range tests { paramtable.Get().Save(Params.QuotaConfig.GrowingSegmentsSizeLowWaterLevel.Key, fmt.Sprintf("%f", test.low)) @@ -528,26 +820,53 @@ func TestQuotaCenter(t *testing.T) { t.Run("test checkDiskQuota", func(t *testing.T) { qc := mocks.NewMockQueryCoordClient(t) meta := mockrootcoord.NewIMetaTable(t) - meta.EXPECT().GetCollectionByID(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, merr.ErrCollectionNotFound).Maybe() + meta.EXPECT().GetCollectionByIDWithMaxTs(mock.Anything, mock.Anything).Return(nil, merr.ErrCollectionNotFound).Maybe() quotaCenter := NewQuotaCenter(pcm, qc, dc, core.tsoAllocator, meta) quotaCenter.checkDiskQuota() - // total DiskQuota exceeded - quotaBackup := Params.QuotaConfig.DiskQuota.GetValue() - paramtable.Get().Save(Params.QuotaConfig.DiskQuota.Key, "99") - quotaCenter.dataCoordMetrics = &metricsinfo.DataCoordQuotaMetrics{ - TotalBinlogSize: 200 * 1024 * 1024, - CollectionBinlogSize: map[int64]int64{1: 100 * 1024 * 1024}, + 
checkLimiter := func(notEquals ...int64) { + for db, collections := range quotaCenter.writableCollections { + for collection := range collections { + limiters := quotaCenter.rateLimiter.GetCollectionLimiters(db, collection).GetLimiters() + if lo.Contains(notEquals, collection) { + a, _ := limiters.Get(internalpb.RateType_DMLInsert) + assert.NotEqual(t, Limit(0), a.Limit()) + b, _ := limiters.Get(internalpb.RateType_DMLUpsert) + assert.NotEqual(t, Limit(0), b.Limit()) + c, _ := limiters.Get(internalpb.RateType_DMLDelete) + assert.NotEqual(t, Limit(0), c.Limit()) + } else { + a, _ := limiters.Get(internalpb.RateType_DMLInsert) + assert.Equal(t, Limit(0), a.Limit()) + b, _ := limiters.Get(internalpb.RateType_DMLUpsert) + assert.Equal(t, Limit(0), b.Limit()) + c, _ := limiters.Get(internalpb.RateType_DMLDelete) + assert.Equal(t, Limit(0), c.Limit()) + } + } + } } - quotaCenter.writableCollections = []int64{1, 2, 3} + + // total DiskQuota exceeded + paramtable.Get().Save(Params.QuotaConfig.DiskQuota.Key, "99") + paramtable.Get().Save(Params.QuotaConfig.DiskQuotaPerCollection.Key, "90") + quotaCenter.dataCoordMetrics = &metricsinfo.DataCoordQuotaMetrics{ + TotalBinlogSize: 10 * 1024 * 1024, + CollectionBinlogSize: map[int64]int64{ + 1: 100 * 1024 * 1024, + 2: 100 * 1024 * 1024, + 3: 100 * 1024 * 1024, + }, + } + quotaCenter.writableCollections = map[int64]map[int64][]int64{ + 0: collectionIDToPartitionIDs, + } + quotaCenter.collectionIDToDBID = collectionIDToDBID quotaCenter.resetAllCurrentRates() quotaCenter.checkDiskQuota() - for _, collection := range quotaCenter.writableCollections { - assert.Equal(t, Limit(0), quotaCenter.currentRates[collection][internalpb.RateType_DMLInsert]) - assert.Equal(t, Limit(0), quotaCenter.currentRates[collection][internalpb.RateType_DMLUpsert]) - assert.Equal(t, Limit(0), quotaCenter.currentRates[collection][internalpb.RateType_DMLDelete]) - } - paramtable.Get().Save(Params.QuotaConfig.DiskQuota.Key, quotaBackup) + checkLimiter() + paramtable.Get().Reset(Params.QuotaConfig.DiskQuota.Key) + paramtable.Get().Reset(Params.QuotaConfig.DiskQuotaPerCollection.Key) // collection DiskQuota exceeded colQuotaBackup := Params.QuotaConfig.DiskQuotaPerCollection.GetValue() @@ -555,18 +874,12 @@ func TestQuotaCenter(t *testing.T) { quotaCenter.dataCoordMetrics = &metricsinfo.DataCoordQuotaMetrics{CollectionBinlogSize: map[int64]int64{ 1: 20 * 1024 * 1024, 2: 30 * 1024 * 1024, 3: 60 * 1024 * 1024, }} - quotaCenter.writableCollections = []int64{1, 2, 3} + quotaCenter.writableCollections = map[int64]map[int64][]int64{ + 0: collectionIDToPartitionIDs, + } quotaCenter.resetAllCurrentRates() quotaCenter.checkDiskQuota() - assert.NotEqual(t, Limit(0), quotaCenter.currentRates[1][internalpb.RateType_DMLInsert]) - assert.NotEqual(t, Limit(0), quotaCenter.currentRates[1][internalpb.RateType_DMLUpsert]) - assert.NotEqual(t, Limit(0), quotaCenter.currentRates[1][internalpb.RateType_DMLDelete]) - assert.Equal(t, Limit(0), quotaCenter.currentRates[2][internalpb.RateType_DMLInsert]) - assert.Equal(t, Limit(0), quotaCenter.currentRates[2][internalpb.RateType_DMLUpsert]) - assert.Equal(t, Limit(0), quotaCenter.currentRates[2][internalpb.RateType_DMLDelete]) - assert.Equal(t, Limit(0), quotaCenter.currentRates[3][internalpb.RateType_DMLInsert]) - assert.Equal(t, Limit(0), quotaCenter.currentRates[3][internalpb.RateType_DMLUpsert]) - assert.Equal(t, Limit(0), quotaCenter.currentRates[3][internalpb.RateType_DMLDelete]) + checkLimiter(1) 
paramtable.Get().Save(Params.QuotaConfig.DiskQuotaPerCollection.Key, colQuotaBackup) }) @@ -575,43 +888,59 @@ func TestQuotaCenter(t *testing.T) { pcm.EXPECT().GetProxyCount().Return(1) pcm.EXPECT().SetRates(mock.Anything, mock.Anything).Return(nil) meta := mockrootcoord.NewIMetaTable(t) - meta.EXPECT().GetCollectionByID(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, merr.ErrCollectionNotFound).Maybe() + meta.EXPECT().GetCollectionByIDWithMaxTs(mock.Anything, mock.Anything).Return(nil, merr.ErrCollectionNotFound).Maybe() quotaCenter := NewQuotaCenter(pcm, qc, dc, core.tsoAllocator, meta) + quotaCenter.writableCollections = map[int64]map[int64][]int64{ + 0: collectionIDToPartitionIDs, + } + quotaCenter.readableCollections = map[int64]map[int64][]int64{ + 0: collectionIDToPartitionIDs, + } quotaCenter.resetAllCurrentRates() collectionID := int64(1) - quotaCenter.currentRates[collectionID] = make(map[internalpb.RateType]ratelimitutil.Limit) - quotaCenter.quotaStates[collectionID] = make(map[milvuspb.QuotaState]commonpb.ErrorCode) - quotaCenter.currentRates[collectionID][internalpb.RateType_DMLInsert] = 100 - quotaCenter.quotaStates[collectionID][milvuspb.QuotaState_DenyToWrite] = commonpb.ErrorCode_MemoryQuotaExhausted - quotaCenter.quotaStates[collectionID][milvuspb.QuotaState_DenyToRead] = commonpb.ErrorCode_ForceDeny - err = quotaCenter.setRates() + limitNode := quotaCenter.rateLimiter.GetCollectionLimiters(0, collectionID) + limitNode.GetLimiters().Insert(internalpb.RateType_DMLInsert, ratelimitutil.NewLimiter(100, 100)) + limitNode.GetQuotaStates().Insert(milvuspb.QuotaState_DenyToWrite, commonpb.ErrorCode_MemoryQuotaExhausted) + limitNode.GetQuotaStates().Insert(milvuspb.QuotaState_DenyToRead, commonpb.ErrorCode_ForceDeny) + err = quotaCenter.sendRatesToProxy() assert.NoError(t, err) }) t.Run("test recordMetrics", func(t *testing.T) { qc := mocks.NewMockQueryCoordClient(t) meta := mockrootcoord.NewIMetaTable(t) - meta.EXPECT().GetCollectionByID(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, merr.ErrCollectionNotFound).Maybe() + meta.EXPECT().GetCollectionByIDWithMaxTs(mock.Anything, mock.Anything).Return(nil, merr.ErrCollectionNotFound).Maybe() quotaCenter := NewQuotaCenter(pcm, qc, dc, core.tsoAllocator, meta) + quotaCenter.writableCollections = map[int64]map[int64][]int64{ + 0: collectionIDToPartitionIDs, + } + quotaCenter.readableCollections = map[int64]map[int64][]int64{ + 0: collectionIDToPartitionIDs, + } + quotaCenter.resetAllCurrentRates() collectionID := int64(1) - quotaCenter.quotaStates[collectionID] = make(map[milvuspb.QuotaState]commonpb.ErrorCode) - quotaCenter.quotaStates[collectionID][milvuspb.QuotaState_DenyToWrite] = commonpb.ErrorCode_MemoryQuotaExhausted - quotaCenter.quotaStates[collectionID][milvuspb.QuotaState_DenyToRead] = commonpb.ErrorCode_ForceDeny + limitNode := quotaCenter.rateLimiter.GetCollectionLimiters(0, collectionID) + limitNode.GetQuotaStates().Insert(milvuspb.QuotaState_DenyToWrite, commonpb.ErrorCode_MemoryQuotaExhausted) + limitNode.GetQuotaStates().Insert(milvuspb.QuotaState_DenyToRead, commonpb.ErrorCode_ForceDeny) quotaCenter.recordMetrics() }) t.Run("test guaranteeMinRate", func(t *testing.T) { qc := mocks.NewMockQueryCoordClient(t) meta := mockrootcoord.NewIMetaTable(t) - meta.EXPECT().GetCollectionByID(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, merr.ErrCollectionNotFound).Maybe() + 
meta.EXPECT().GetCollectionByIDWithMaxTs(mock.Anything, mock.Anything).Return(nil, merr.ErrCollectionNotFound).Maybe() quotaCenter := NewQuotaCenter(pcm, qc, dc, core.tsoAllocator, meta) + quotaCenter.readableCollections = map[int64]map[int64][]int64{ + 0: collectionIDToPartitionIDs, + } quotaCenter.resetAllCurrentRates() minRate := Limit(100) collectionID := int64(1) - quotaCenter.currentRates[collectionID] = make(map[internalpb.RateType]ratelimitutil.Limit) - quotaCenter.currentRates[collectionID][internalpb.RateType_DQLSearch] = Limit(50) - quotaCenter.guaranteeMinRate(float64(minRate), internalpb.RateType_DQLSearch, 1) - assert.Equal(t, minRate, quotaCenter.currentRates[collectionID][internalpb.RateType_DQLSearch]) + limitNode := quotaCenter.rateLimiter.GetCollectionLimiters(0, collectionID) + limitNode.GetLimiters().Insert(internalpb.RateType_DQLSearch, ratelimitutil.NewLimiter(50, 50)) + quotaCenter.guaranteeMinRate(float64(minRate), internalpb.RateType_DQLSearch, limitNode) + limiter, _ := limitNode.GetLimiters().Get(internalpb.RateType_DQLSearch) + assert.EqualValues(t, minRate, limiter.Limit()) }) t.Run("test diskAllowance", func(t *testing.T) { @@ -632,7 +961,7 @@ func TestQuotaCenter(t *testing.T) { t.Run(test.name, func(t *testing.T) { collection := UniqueID(0) meta := mockrootcoord.NewIMetaTable(t) - meta.EXPECT().GetCollectionByID(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, merr.ErrCollectionNotFound).Maybe() + meta.EXPECT().GetCollectionByIDWithMaxTs(mock.Anything, mock.Anything).Return(nil, merr.ErrCollectionNotFound).Maybe() quotaCenter := NewQuotaCenter(pcm, nil, dc, core.tsoAllocator, meta) quotaCenter.resetAllCurrentRates() quotaBackup := Params.QuotaConfig.DiskQuota.GetValue() @@ -654,21 +983,33 @@ func TestQuotaCenter(t *testing.T) { t.Run("test reset current rates", func(t *testing.T) { meta := mockrootcoord.NewIMetaTable(t) - meta.EXPECT().GetCollectionByID(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, merr.ErrCollectionNotFound).Maybe() + meta.EXPECT().GetCollectionByIDWithMaxTs(mock.Anything, mock.Anything).Return(nil, merr.ErrCollectionNotFound).Maybe() quotaCenter := NewQuotaCenter(pcm, nil, dc, core.tsoAllocator, meta) - quotaCenter.readableCollections = []int64{1} - quotaCenter.writableCollections = []int64{1} + quotaCenter.readableCollections = map[int64]map[int64][]int64{ + 0: {1: {}}, + } + quotaCenter.writableCollections = map[int64]map[int64][]int64{ + 0: {1: {}}, + } + quotaCenter.collectionIDToDBID = collectionIDToDBID quotaCenter.resetAllCurrentRates() - assert.Equal(t, float64(quotaCenter.currentRates[1][internalpb.RateType_DMLInsert]), Params.QuotaConfig.DMLMaxInsertRatePerCollection.GetAsFloat()) - assert.Equal(t, float64(quotaCenter.currentRates[1][internalpb.RateType_DMLUpsert]), Params.QuotaConfig.DMLMaxUpsertRatePerCollection.GetAsFloat()) - assert.Equal(t, float64(quotaCenter.currentRates[1][internalpb.RateType_DMLDelete]), Params.QuotaConfig.DMLMaxDeleteRatePerCollection.GetAsFloat()) - assert.Equal(t, float64(quotaCenter.currentRates[1][internalpb.RateType_DMLBulkLoad]), Params.QuotaConfig.DMLMaxBulkLoadRatePerCollection.GetAsFloat()) - assert.Equal(t, float64(quotaCenter.currentRates[1][internalpb.RateType_DQLSearch]), Params.QuotaConfig.DQLMaxSearchRatePerCollection.GetAsFloat()) - assert.Equal(t, float64(quotaCenter.currentRates[1][internalpb.RateType_DQLQuery]), Params.QuotaConfig.DQLMaxQueryRatePerCollection.GetAsFloat()) + limiters := 
quotaCenter.rateLimiter.GetCollectionLimiters(0, 1).GetLimiters() + + getRate := func(m *typeutil.ConcurrentMap[internalpb.RateType, *ratelimitutil.Limiter], key internalpb.RateType) float64 { + v, _ := m.Get(key) + return float64(v.Limit()) + } + + assert.Equal(t, getRate(limiters, internalpb.RateType_DMLInsert), Params.QuotaConfig.DMLMaxInsertRatePerCollection.GetAsFloat()) + assert.Equal(t, getRate(limiters, internalpb.RateType_DMLUpsert), Params.QuotaConfig.DMLMaxUpsertRatePerCollection.GetAsFloat()) + assert.Equal(t, getRate(limiters, internalpb.RateType_DMLDelete), Params.QuotaConfig.DMLMaxDeleteRatePerCollection.GetAsFloat()) + assert.Equal(t, getRate(limiters, internalpb.RateType_DMLBulkLoad), Params.QuotaConfig.DMLMaxBulkLoadRatePerCollection.GetAsFloat()) + assert.Equal(t, getRate(limiters, internalpb.RateType_DQLSearch), Params.QuotaConfig.DQLMaxSearchRatePerCollection.GetAsFloat()) + assert.Equal(t, getRate(limiters, internalpb.RateType_DQLQuery), Params.QuotaConfig.DQLMaxQueryRatePerCollection.GetAsFloat()) meta.ExpectedCalls = nil - meta.EXPECT().GetCollectionByID(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&model.Collection{ + meta.EXPECT().GetCollectionByIDWithMaxTs(mock.Anything, mock.Anything).Return(&model.Collection{ Properties: []*commonpb.KeyValuePair{ { Key: common.CollectionInsertRateMaxKey, @@ -701,12 +1042,13 @@ func TestQuotaCenter(t *testing.T) { }, }, nil) quotaCenter.resetAllCurrentRates() - assert.Equal(t, float64(quotaCenter.currentRates[1][internalpb.RateType_DMLInsert]), float64(1*1024*1024)) - assert.Equal(t, float64(quotaCenter.currentRates[1][internalpb.RateType_DMLDelete]), float64(2*1024*1024)) - assert.Equal(t, float64(quotaCenter.currentRates[1][internalpb.RateType_DMLBulkLoad]), float64(3*1024*1024)) - assert.Equal(t, float64(quotaCenter.currentRates[1][internalpb.RateType_DQLQuery]), float64(4)) - assert.Equal(t, float64(quotaCenter.currentRates[1][internalpb.RateType_DQLSearch]), float64(5)) - assert.Equal(t, float64(quotaCenter.currentRates[1][internalpb.RateType_DMLUpsert]), float64(6*1024*1024)) + limiters = quotaCenter.rateLimiter.GetCollectionLimiters(0, 1).GetLimiters() + assert.Equal(t, getRate(limiters, internalpb.RateType_DMLInsert), float64(1*1024*1024)) + assert.Equal(t, getRate(limiters, internalpb.RateType_DMLDelete), float64(2*1024*1024)) + assert.Equal(t, getRate(limiters, internalpb.RateType_DMLBulkLoad), float64(3*1024*1024)) + assert.Equal(t, getRate(limiters, internalpb.RateType_DQLQuery), float64(4)) + assert.Equal(t, getRate(limiters, internalpb.RateType_DQLSearch), float64(5)) + assert.Equal(t, getRate(limiters, internalpb.RateType_DMLUpsert), float64(6*1024*1024)) }) } @@ -764,6 +1106,14 @@ func (s *QuotaCenterSuite) TestSyncMetricsSuccess() { meta := s.meta core := s.core + call := meta.EXPECT().ListDatabases(mock.Anything, mock.Anything).Return([]*model.Database{ + { + ID: 1, + Name: "default", + }, + }, nil) + defer call.Unset() + s.Run("querycoord_cluster", func() { pcm.EXPECT().GetProxyMetrics(mock.Anything).Return(nil, nil).Once() dc.EXPECT().GetMetrics(mock.Anything, mock.Anything).Return(&milvuspb.GetMetricsResponse{ @@ -787,13 +1137,16 @@ func (s *QuotaCenterSuite) TestSyncMetricsSuccess() { Status: merr.Status(nil), Response: resp, }, nil).Once() + meta.EXPECT().GetCollectionByIDWithMaxTs(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, i int64) (*model.Collection, error) { + return &model.Collection{CollectionID: i, DBID: 1}, nil + }).Times(3) quotaCenter 
:= NewQuotaCenter(pcm, qc, dc, core.tsoAllocator, meta) - err = quotaCenter.syncMetrics() + err = quotaCenter.collectMetrics() s.Require().NoError(err) - s.ElementsMatch([]int64{100, 200, 300}, quotaCenter.readableCollections) + s.ElementsMatch([]int64{100, 200, 300}, lo.Keys(quotaCenter.readableCollections[1])) nodes := lo.Keys(quotaCenter.queryNodeMetrics) s.ElementsMatch([]int64{1, 2}, nodes) }) @@ -821,13 +1174,16 @@ func (s *QuotaCenterSuite) TestSyncMetricsSuccess() { Status: merr.Status(nil), Response: resp, }, nil).Once() + meta.EXPECT().GetCollectionByIDWithMaxTs(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, i int64) (*model.Collection, error) { + return &model.Collection{CollectionID: i, DBID: 1}, nil + }).Times(3) quotaCenter := NewQuotaCenter(pcm, qc, dc, core.tsoAllocator, meta) - err = quotaCenter.syncMetrics() + err = quotaCenter.collectMetrics() s.Require().NoError(err) - s.ElementsMatch([]int64{100, 200, 300}, quotaCenter.writableCollections) + s.ElementsMatch([]int64{100, 200, 300}, lo.Keys(quotaCenter.writableCollections[1])) nodes := lo.Keys(quotaCenter.dataNodeMetrics) s.ElementsMatch([]int64{1, 2}, nodes) }) @@ -839,6 +1195,13 @@ func (s *QuotaCenterSuite) TestSyncMetricsFailure() { qc := s.qc meta := s.meta core := s.core + call := meta.EXPECT().ListDatabases(mock.Anything, mock.Anything).Return([]*model.Database{ + { + ID: 1, + Name: "default", + }, + }, nil) + defer call.Unset() s.Run("querycoord_failure", func() { pcm.EXPECT().GetProxyMetrics(mock.Anything).Return(nil, nil).Once() @@ -850,7 +1213,7 @@ func (s *QuotaCenterSuite) TestSyncMetricsFailure() { quotaCenter := NewQuotaCenter(pcm, qc, dc, core.tsoAllocator, meta) - err := quotaCenter.syncMetrics() + err := quotaCenter.collectMetrics() s.Error(err) }) @@ -868,7 +1231,7 @@ func (s *QuotaCenterSuite) TestSyncMetricsFailure() { quotaCenter := NewQuotaCenter(pcm, qc, dc, core.tsoAllocator, meta) - err := quotaCenter.syncMetrics() + err := quotaCenter.collectMetrics() s.Error(err) }) @@ -882,7 +1245,7 @@ func (s *QuotaCenterSuite) TestSyncMetricsFailure() { dc.EXPECT().GetMetrics(mock.Anything, mock.Anything).Return(nil, errors.New("mocked")).Once() quotaCenter := NewQuotaCenter(pcm, qc, dc, core.tsoAllocator, meta) - err := quotaCenter.syncMetrics() + err := quotaCenter.collectMetrics() s.Error(err) }) @@ -899,7 +1262,7 @@ func (s *QuotaCenterSuite) TestSyncMetricsFailure() { }, nil).Once() quotaCenter := NewQuotaCenter(pcm, qc, dc, core.tsoAllocator, meta) - err := quotaCenter.syncMetrics() + err := quotaCenter.collectMetrics() s.Error(err) }) @@ -916,7 +1279,7 @@ func (s *QuotaCenterSuite) TestSyncMetricsFailure() { pcm.EXPECT().GetProxyMetrics(mock.Anything).Return(nil, errors.New("mocked")).Once() quotaCenter := NewQuotaCenter(pcm, qc, dc, core.tsoAllocator, meta) - err := quotaCenter.syncMetrics() + err := quotaCenter.collectMetrics() s.Error(err) }) @@ -938,7 +1301,7 @@ func (s *QuotaCenterSuite) TestSyncMetricsFailure() { }, nil).Once() quotaCenter := NewQuotaCenter(pcm, qc, dc, core.tsoAllocator, meta) - err := quotaCenter.syncMetrics() + err := quotaCenter.collectMetrics() s.Error(err) }) } @@ -950,6 +1313,19 @@ func (s *QuotaCenterSuite) TestNodeOffline() { meta := s.meta core := s.core + call := meta.EXPECT().GetCollectionByIDWithMaxTs(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, i int64) (*model.Collection, error) { + return &model.Collection{CollectionID: i, DBID: 1}, nil + }).Maybe() + defer call.Unset() + + dbCall := 
meta.EXPECT().ListDatabases(mock.Anything, mock.Anything).Return([]*model.Database{ + { + ID: 1, + Name: "default", + }, + }, nil) + defer dbCall.Unset() + metrics.RootCoordTtDelay.Reset() Params.Save(Params.QuotaConfig.TtProtectionEnabled.Key, "true") defer Params.Reset(Params.QuotaConfig.TtProtectionEnabled.Key) @@ -1020,7 +1396,7 @@ func (s *QuotaCenterSuite) TestNodeOffline() { }, nil).Once() quotaCenter := NewQuotaCenter(pcm, qc, dc, core.tsoAllocator, meta) - err = quotaCenter.syncMetrics() + err = quotaCenter.collectMetrics() s.Require().NoError(err) quotaCenter.getTimeTickDelayFactor(tsoutil.ComposeTSByTime(time.Now(), 0)) @@ -1071,7 +1447,7 @@ func (s *QuotaCenterSuite) TestNodeOffline() { Response: resp, }, nil).Once() - err = quotaCenter.syncMetrics() + err = quotaCenter.collectMetrics() s.Require().NoError(err) quotaCenter.getTimeTickDelayFactor(tsoutil.ComposeTSByTime(time.Now(), 0)) @@ -1081,3 +1457,391 @@ func (s *QuotaCenterSuite) TestNodeOffline() { func TestQuotaCenterSuite(t *testing.T) { suite.Run(t, new(QuotaCenterSuite)) } + +func TestUpdateLimiter(t *testing.T) { + t.Run("nil node", func(t *testing.T) { + updateLimiter(nil, nil, internalpb.RateScope_Database, dql) + }) + + t.Run("normal op", func(t *testing.T) { + node := interalratelimitutil.NewRateLimiterNode(internalpb.RateScope_Collection) + node.GetLimiters().Insert(internalpb.RateType_DQLSearch, ratelimitutil.NewLimiter(5, 5)) + newLimit := ratelimitutil.NewLimiter(10, 10) + updateLimiter(node, newLimit, internalpb.RateScope_Collection, dql) + + searchLimit, _ := node.GetLimiters().Get(internalpb.RateType_DQLSearch) + assert.Equal(t, Limit(10), searchLimit.Limit()) + }) +} + +func TestGetRateType(t *testing.T) { + t.Run("invalid rate type", func(t *testing.T) { + assert.Panics(t, func() { + getRateTypes(internalpb.RateScope(100), ddl) + }) + }) + + t.Run("ddl cluster scope", func(t *testing.T) { + a := getRateTypes(internalpb.RateScope_Cluster, ddl) + assert.Equal(t, 5, a.Len()) + }) +} + +func TestCalculateReadRates(t *testing.T) { + paramtable.Init() + ctx := context.Background() + + t.Run("cool off db", func(t *testing.T) { + qc := mocks.NewMockQueryCoordClient(t) + meta := mockrootcoord.NewIMetaTable(t) + pcm := proxyutil.NewMockProxyClientManager(t) + dc := mocks.NewMockDataCoordClient(t) + core, _ := NewCore(ctx, nil) + core.tsoAllocator = newMockTsoAllocator() + + meta.EXPECT().GetCollectionByIDWithMaxTs(mock.Anything, mock.Anything).Return(nil, errors.New("mock error")) + + Params.Save(Params.QuotaConfig.ForceDenyReading.Key, "false") + defer Params.Reset(Params.QuotaConfig.ForceDenyReading.Key) + + Params.Save(Params.QuotaConfig.ResultProtectionEnabled.Key, "true") + defer Params.Reset(Params.QuotaConfig.ResultProtectionEnabled.Key) + Params.Save(Params.QuotaConfig.MaxReadResultRate.Key, "50") + defer Params.Reset(Params.QuotaConfig.MaxReadResultRate.Key) + Params.Save(Params.QuotaConfig.MaxReadResultRatePerDB.Key, "30") + defer Params.Reset(Params.QuotaConfig.MaxReadResultRatePerDB.Key) + Params.Save(Params.QuotaConfig.MaxReadResultRatePerCollection.Key, "20") + defer Params.Reset(Params.QuotaConfig.MaxReadResultRatePerCollection.Key) + Params.Save(Params.QuotaConfig.CoolOffSpeed.Key, "0.8") + defer Params.Reset(Params.QuotaConfig.CoolOffSpeed.Key) + + Params.Save(Params.QuotaConfig.DQLLimitEnabled.Key, "true") + defer Params.Reset(Params.QuotaConfig.DQLLimitEnabled.Key) + Params.Save(Params.QuotaConfig.DQLMaxSearchRate.Key, "500") + defer Params.Reset(Params.QuotaConfig.DQLMaxSearchRate.Key) + 
Params.Save(Params.QuotaConfig.DQLMaxSearchRatePerDB.Key, "500") + defer Params.Reset(Params.QuotaConfig.DQLMaxSearchRatePerDB.Key) + Params.Save(Params.QuotaConfig.DQLMaxSearchRatePerCollection.Key, "500") + defer Params.Reset(Params.QuotaConfig.DQLMaxSearchRatePerCollection.Key) + + quotaCenter := NewQuotaCenter(pcm, qc, dc, core.tsoAllocator, meta) + quotaCenter.dbs = typeutil.NewConcurrentMap[string, int64]() + quotaCenter.collections = typeutil.NewConcurrentMap[string, int64]() + quotaCenter.collectionIDToDBID = typeutil.NewConcurrentMap[int64, int64]() + quotaCenter.dbs.Insert("default", 1) + quotaCenter.dbs.Insert("test", 2) + quotaCenter.collections.Insert("1.col1", 10) + quotaCenter.collections.Insert("2.col2", 20) + quotaCenter.collections.Insert("2.col3", 30) + quotaCenter.collectionIDToDBID.Insert(10, 1) + quotaCenter.collectionIDToDBID.Insert(20, 2) + quotaCenter.collectionIDToDBID.Insert(30, 2) + + searchLabel := internalpb.RateType_DQLSearch.String() + quotaCenter.queryNodeMetrics = map[UniqueID]*metricsinfo.QueryNodeQuotaMetrics{} + quotaCenter.proxyMetrics = map[UniqueID]*metricsinfo.ProxyQuotaMetrics{ + 1: { + Rms: []metricsinfo.RateMetric{ + { + Label: metricsinfo.ReadResultThroughput, + Rate: 40 * 1024 * 1024, + }, + { + Label: ratelimitutil.FormatSubLabel(metricsinfo.ReadResultThroughput, ratelimitutil.GetDBSubLabel("default")), + Rate: 20 * 1024 * 1024, + }, + { + Label: ratelimitutil.FormatSubLabel(metricsinfo.ReadResultThroughput, ratelimitutil.GetCollectionSubLabel("default", "col1")), + Rate: 15 * 1024 * 1024, + }, + { + Label: ratelimitutil.FormatSubLabel(metricsinfo.ReadResultThroughput, ratelimitutil.GetDBSubLabel("test")), + Rate: 20 * 1024 * 1024, + }, + { + Label: ratelimitutil.FormatSubLabel(metricsinfo.ReadResultThroughput, ratelimitutil.GetCollectionSubLabel("test", "col2")), + Rate: 10 * 1024 * 1024, + }, + { + Label: ratelimitutil.FormatSubLabel(metricsinfo.ReadResultThroughput, ratelimitutil.GetCollectionSubLabel("test", "col3")), + Rate: 10 * 1024 * 1024, + }, + { + Label: searchLabel, + Rate: 20, + }, + { + Label: ratelimitutil.FormatSubLabel(searchLabel, ratelimitutil.GetDBSubLabel("default")), + Rate: 10, + }, + { + Label: ratelimitutil.FormatSubLabel(searchLabel, ratelimitutil.GetDBSubLabel("test")), + Rate: 10, + }, + { + Label: ratelimitutil.FormatSubLabel(searchLabel, ratelimitutil.GetCollectionSubLabel("default", "col1")), + Rate: 10, + }, + { + Label: ratelimitutil.FormatSubLabel(searchLabel, ratelimitutil.GetCollectionSubLabel("test", "col2")), + Rate: 5, + }, + { + Label: ratelimitutil.FormatSubLabel(searchLabel, ratelimitutil.GetCollectionSubLabel("test", "col3")), + Rate: 5, + }, + }, + }, + 2: { + Rms: []metricsinfo.RateMetric{ + { + Label: metricsinfo.ReadResultThroughput, + Rate: 20 * 1024 * 1024, + }, + { + Label: ratelimitutil.FormatSubLabel(metricsinfo.ReadResultThroughput, ratelimitutil.GetDBSubLabel("default")), + Rate: 20 * 1024 * 1024, + }, + { + Label: ratelimitutil.FormatSubLabel(metricsinfo.ReadResultThroughput, ratelimitutil.GetCollectionSubLabel("default", "col1")), + Rate: 15 * 1024 * 1024, + }, + { + Label: searchLabel, + Rate: 20, + }, + { + Label: ratelimitutil.FormatSubLabel(searchLabel, ratelimitutil.GetDBSubLabel("default")), + Rate: 20, + }, + { + Label: ratelimitutil.FormatSubLabel(searchLabel, ratelimitutil.GetCollectionSubLabel("default", "col1")), + Rate: 10, + }, + }, + }, + } + + quotaCenter.rateLimiter.GetRootLimiters().GetLimiters().Insert(internalpb.RateType_DQLSearch, ratelimitutil.NewLimiter(500, 500)) + 
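The assertions that follow rely on this arithmetic: the read-result throughput summed over the two mocked proxies breaches the 50 MB/s cluster cap, the 30 MB/s cap for db "default", and the 20 MB/s cap for col1, so the observed DQLSearch rates at those nodes are scaled down by CoolOffSpeed (0.8), while db "test" and its collections keep the configured 500. A standalone worked example of the expected values (illustrative only):

package main

import "fmt"

func main() {
	coolOff := 0.8 // Params.QuotaConfig.CoolOffSpeed in this test

	// DQLSearch rates summed over the two mocked proxies above.
	clusterSearch := 20.0 + 20.0   // cluster-wide
	dbDefaultSearch := 10.0 + 20.0 // db "default"
	col1Search := 10.0 + 10.0      // collection "col1" in db "default"

	fmt.Println(clusterSearch * coolOff)   // 32 -> expected root limiter
	fmt.Println(dbDefaultSearch * coolOff) // 24 -> expected db(1) limiter
	fmt.Println(col1Search * coolOff)      // 16 -> expected collection(1, 10) limiter
}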
quotaCenter.rateLimiter.GetOrCreateCollectionLimiters(1, 10, + newParamLimiterFunc(internalpb.RateScope_Database, allOps), + newParamLimiterFunc(internalpb.RateScope_Collection, allOps)) + quotaCenter.rateLimiter.GetOrCreateCollectionLimiters(2, 20, + newParamLimiterFunc(internalpb.RateScope_Database, allOps), + newParamLimiterFunc(internalpb.RateScope_Collection, allOps)) + quotaCenter.rateLimiter.GetOrCreateCollectionLimiters(2, 30, + newParamLimiterFunc(internalpb.RateScope_Database, allOps), + newParamLimiterFunc(internalpb.RateScope_Collection, allOps)) + + err := quotaCenter.calculateReadRates() + assert.NoError(t, err) + + checkRate := func(rateNode *interalratelimitutil.RateLimiterNode, expectValue float64) { + searchRate, ok := rateNode.GetLimiters().Get(internalpb.RateType_DQLSearch) + assert.True(t, ok) + assert.EqualValues(t, expectValue, searchRate.Limit()) + } + + { + checkRate(quotaCenter.rateLimiter.GetRootLimiters(), float64(32)) // (20 + 20) * 0.8 + checkRate(quotaCenter.rateLimiter.GetDatabaseLimiters(1), float64(24)) // (20 + 10) * 0.8 + checkRate(quotaCenter.rateLimiter.GetDatabaseLimiters(2), float64(500)) // not cool off + checkRate(quotaCenter.rateLimiter.GetCollectionLimiters(1, 10), float64(16)) // (10 + 10) * 0.8 + checkRate(quotaCenter.rateLimiter.GetCollectionLimiters(2, 20), float64(500)) // not cool off + checkRate(quotaCenter.rateLimiter.GetCollectionLimiters(2, 30), float64(500)) // not cool off + } + }) +} + +func TestResetAllCurrentRates(t *testing.T) { + paramtable.Init() + ctx := context.Background() + + qc := mocks.NewMockQueryCoordClient(t) + meta := mockrootcoord.NewIMetaTable(t) + pcm := proxyutil.NewMockProxyClientManager(t) + dc := mocks.NewMockDataCoordClient(t) + core, _ := NewCore(ctx, nil) + core.tsoAllocator = newMockTsoAllocator() + + meta.EXPECT().GetCollectionByIDWithMaxTs(mock.Anything, mock.Anything).Return(nil, errors.New("mock error")) + + quotaCenter := NewQuotaCenter(pcm, qc, dc, core.tsoAllocator, meta) + quotaCenter.readableCollections = map[int64]map[int64][]int64{ + 1: {}, + } + quotaCenter.writableCollections = map[int64]map[int64][]int64{ + 2: { + 100: []int64{}, + }, + } + err := quotaCenter.resetAllCurrentRates() + assert.NoError(t, err) + + db1 := quotaCenter.rateLimiter.GetDatabaseLimiters(1) + assert.NotNil(t, db1) + db2 := quotaCenter.rateLimiter.GetDatabaseLimiters(2) + assert.NotNil(t, db2) + collection := quotaCenter.rateLimiter.GetCollectionLimiters(2, 100) + assert.NotNil(t, collection) +} + +func TestCheckDiskQuota(t *testing.T) { + paramtable.Init() + ctx := context.Background() + + t.Run("disk quota check disable", func(t *testing.T) { + qc := mocks.NewMockQueryCoordClient(t) + meta := mockrootcoord.NewIMetaTable(t) + pcm := proxyutil.NewMockProxyClientManager(t) + dc := mocks.NewMockDataCoordClient(t) + core, _ := NewCore(ctx, nil) + core.tsoAllocator = newMockTsoAllocator() + quotaCenter := NewQuotaCenter(pcm, qc, dc, core.tsoAllocator, meta) + + Params.Save(Params.QuotaConfig.DiskProtectionEnabled.Key, "false") + defer Params.Reset(Params.QuotaConfig.DiskProtectionEnabled.Key) + err := quotaCenter.checkDiskQuota() + assert.NoError(t, err) + }) + + t.Run("disk quota check enable", func(t *testing.T) { + Params.Save(Params.QuotaConfig.DiskProtectionEnabled.Key, "true") + defer Params.Reset(Params.QuotaConfig.DiskProtectionEnabled.Key) + Params.Save(Params.QuotaConfig.DiskQuota.Key, "150") + defer Params.Reset(Params.QuotaConfig.DiskQuota.Key) + Params.Save(Params.QuotaConfig.DiskQuotaPerDB.Key, "10") + defer 
Params.Reset(Params.QuotaConfig.DiskQuotaPerDB.Key) + Params.Save(Params.QuotaConfig.DiskQuotaPerCollection.Key, "10") + defer Params.Reset(Params.QuotaConfig.DiskQuotaPerCollection.Key) + Params.Save(Params.QuotaConfig.DiskQuotaPerPartition.Key, "10") + defer Params.Reset(Params.QuotaConfig.DiskQuotaPerPartition.Key) + + Params.Save(Params.QuotaConfig.DMLLimitEnabled.Key, "true") + defer Params.Reset(Params.QuotaConfig.DMLLimitEnabled.Key) + Params.Save(Params.QuotaConfig.DMLMaxInsertRate.Key, "10") + defer Params.Reset(Params.QuotaConfig.DMLMaxInsertRate.Key) + Params.Save(Params.QuotaConfig.DMLMaxInsertRatePerDB.Key, "10") + defer Params.Reset(Params.QuotaConfig.DMLMaxInsertRatePerDB.Key) + Params.Save(Params.QuotaConfig.DMLMaxInsertRatePerCollection.Key, "10") + defer Params.Reset(Params.QuotaConfig.DMLMaxInsertRatePerCollection.Key) + Params.Save(Params.QuotaConfig.DMLMaxInsertRatePerPartition.Key, "10") + defer Params.Reset(Params.QuotaConfig.DMLMaxInsertRatePerPartition.Key) + + qc := mocks.NewMockQueryCoordClient(t) + meta := mockrootcoord.NewIMetaTable(t) + pcm := proxyutil.NewMockProxyClientManager(t) + dc := mocks.NewMockDataCoordClient(t) + core, _ := NewCore(ctx, nil) + core.tsoAllocator = newMockTsoAllocator() + meta.EXPECT().GetCollectionByIDWithMaxTs(mock.Anything, mock.Anything).Return(nil, errors.New("mock error")) + + quotaCenter := NewQuotaCenter(pcm, qc, dc, core.tsoAllocator, meta) + quotaCenter.rateLimiter.GetRootLimiters().GetLimiters().Insert(internalpb.RateType_DMLInsert, ratelimitutil.NewLimiter(500, 500)) + quotaCenter.rateLimiter.GetOrCreatePartitionLimiters(1, 10, 100, + newParamLimiterFunc(internalpb.RateScope_Database, allOps), + newParamLimiterFunc(internalpb.RateScope_Collection, allOps), + newParamLimiterFunc(internalpb.RateScope_Partition, allOps), + ) + quotaCenter.rateLimiter.GetOrCreatePartitionLimiters(1, 10, 101, + newParamLimiterFunc(internalpb.RateScope_Database, allOps), + newParamLimiterFunc(internalpb.RateScope_Collection, allOps), + newParamLimiterFunc(internalpb.RateScope_Partition, allOps), + ) + quotaCenter.rateLimiter.GetOrCreatePartitionLimiters(2, 20, 200, + newParamLimiterFunc(internalpb.RateScope_Database, allOps), + newParamLimiterFunc(internalpb.RateScope_Collection, allOps), + newParamLimiterFunc(internalpb.RateScope_Partition, allOps), + ) + quotaCenter.rateLimiter.GetOrCreatePartitionLimiters(2, 30, 300, + newParamLimiterFunc(internalpb.RateScope_Database, allOps), + newParamLimiterFunc(internalpb.RateScope_Collection, allOps), + newParamLimiterFunc(internalpb.RateScope_Partition, allOps), + ) + + quotaCenter.dataCoordMetrics = &metricsinfo.DataCoordQuotaMetrics{ + TotalBinlogSize: 200 * 1024 * 1024, + CollectionBinlogSize: map[int64]int64{ + 10: 15 * 1024 * 1024, + 20: 6 * 1024 * 1024, + 30: 6 * 1024 * 1024, + }, + PartitionsBinlogSize: map[int64]map[int64]int64{ + 10: { + 100: 10 * 1024 * 1024, + 101: 5 * 1024 * 1024, + }, + 20: { + 200: 6 * 1024 * 1024, + }, + 30: { + 300: 6 * 1024 * 1024, + }, + }, + } + quotaCenter.collectionIDToDBID = typeutil.NewConcurrentMap[int64, int64]() + quotaCenter.collectionIDToDBID.Insert(10, 1) + quotaCenter.collectionIDToDBID.Insert(20, 2) + quotaCenter.collectionIDToDBID.Insert(30, 2) + + checkRate := func(rateNode *interalratelimitutil.RateLimiterNode, expectValue float64) { + insertRate, ok := rateNode.GetLimiters().Get(internalpb.RateType_DMLInsert) + assert.True(t, ok) + assert.EqualValues(t, expectValue, insertRate.Limit()) + } + + configQuotaValue := float64(10 * 1024 * 1024) + + { + err 
:= quotaCenter.checkDiskQuota() + assert.NoError(t, err) + checkRate(quotaCenter.rateLimiter.GetRootLimiters(), 0) + } + + { + Params.Save(Params.QuotaConfig.DiskQuota.Key, "999") + err := quotaCenter.checkDiskQuota() + assert.NoError(t, err) + checkRate(quotaCenter.rateLimiter.GetDatabaseLimiters(1), 0) + checkRate(quotaCenter.rateLimiter.GetDatabaseLimiters(2), 0) + checkRate(quotaCenter.rateLimiter.GetCollectionLimiters(1, 10), 0) + checkRate(quotaCenter.rateLimiter.GetCollectionLimiters(2, 20), configQuotaValue) + checkRate(quotaCenter.rateLimiter.GetCollectionLimiters(2, 30), configQuotaValue) + checkRate(quotaCenter.rateLimiter.GetPartitionLimiters(1, 10, 100), 0) + checkRate(quotaCenter.rateLimiter.GetPartitionLimiters(1, 10, 101), configQuotaValue) + checkRate(quotaCenter.rateLimiter.GetPartitionLimiters(2, 20, 200), configQuotaValue) + checkRate(quotaCenter.rateLimiter.GetPartitionLimiters(2, 30, 300), configQuotaValue) + } + }) +} + +func TestTORequestLimiter(t *testing.T) { + ctx := context.Background() + qc := mocks.NewMockQueryCoordClient(t) + meta := mockrootcoord.NewIMetaTable(t) + pcm := proxyutil.NewMockProxyClientManager(t) + dc := mocks.NewMockDataCoordClient(t) + core, _ := NewCore(ctx, nil) + core.tsoAllocator = newMockTsoAllocator() + + quotaCenter := NewQuotaCenter(pcm, qc, dc, core.tsoAllocator, meta) + pcm.EXPECT().GetProxyCount().Return(2) + limitNode := interalratelimitutil.NewRateLimiterNode(internalpb.RateScope_Cluster) + a := ratelimitutil.NewLimiter(500, 500) + a.SetLimit(200) + b := ratelimitutil.NewLimiter(100, 100) + limitNode.GetLimiters().Insert(internalpb.RateType_DMLInsert, a) + limitNode.GetLimiters().Insert(internalpb.RateType_DMLDelete, b) + limitNode.GetLimiters().Insert(internalpb.RateType_DMLBulkLoad, GetInfLimiter(internalpb.RateType_DMLBulkLoad)) + limitNode.GetQuotaStates().Insert(milvuspb.QuotaState_DenyToRead, commonpb.ErrorCode_ForceDeny) + + quotaCenter.rateAllocateStrategy = Average + proxyLimit := quotaCenter.toRequestLimiter(limitNode) + assert.Equal(t, 1, len(proxyLimit.Rates)) + assert.Equal(t, internalpb.RateType_DMLInsert, proxyLimit.Rates[0].Rt) + assert.Equal(t, float64(100), proxyLimit.Rates[0].R) + assert.Equal(t, 1, len(proxyLimit.States)) + assert.Equal(t, milvuspb.QuotaState_DenyToRead, proxyLimit.States[0]) + assert.Equal(t, 1, len(proxyLimit.Codes)) + assert.Equal(t, commonpb.ErrorCode_ForceDeny, proxyLimit.Codes[0]) +} diff --git a/internal/rootcoord/root_coord.go b/internal/rootcoord/root_coord.go index a7621eb613..7be954d95b 100644 --- a/internal/rootcoord/root_coord.go +++ b/internal/rootcoord/root_coord.go @@ -2669,6 +2669,40 @@ func (c *Core) RenameCollection(ctx context.Context, req *milvuspb.RenameCollect return merr.Success(), nil } +func (c *Core) DescribeDatabase(ctx context.Context, req *rootcoordpb.DescribeDatabaseRequest) (*rootcoordpb.DescribeDatabaseResponse, error) { + if err := merr.CheckHealthy(c.GetStateCode()); err != nil { + return &rootcoordpb.DescribeDatabaseResponse{Status: merr.Status(err)}, nil + } + + log := log.Ctx(ctx).With(zap.String("dbName", req.GetDbName())) + log.Info("received request to describe database ") + + metrics.RootCoordDDLReqCounter.WithLabelValues("DescribeDatabase", metrics.TotalLabel).Inc() + tr := timerecord.NewTimeRecorder("DescribeDatabase") + t := &describeDBTask{ + baseTask: newBaseTask(ctx, c), + Req: req, + } + + if err := c.scheduler.AddTask(t); err != nil { + log.Warn("failed to enqueue request to describe database", zap.Error(err)) + 
metrics.RootCoordDDLReqCounter.WithLabelValues("DescribeDatabase", metrics.FailLabel).Inc() + return &rootcoordpb.DescribeDatabaseResponse{Status: merr.Status(err)}, nil + } + + if err := t.WaitToFinish(); err != nil { + log.Warn("failed to describe database", zap.Uint64("ts", t.GetTs()), zap.Error(err)) + metrics.RootCoordDDLReqCounter.WithLabelValues("DescribeDatabase", metrics.FailLabel).Inc() + return &rootcoordpb.DescribeDatabaseResponse{Status: merr.Status(err)}, nil + } + + metrics.RootCoordDDLReqCounter.WithLabelValues("DescribeDatabase", metrics.SuccessLabel).Inc() + metrics.RootCoordDDLReqLatency.WithLabelValues("DescribeDatabase").Observe(float64(tr.ElapseSpan().Milliseconds())) + + log.Info("done to describe database", zap.Uint64("ts", t.GetTs())) + return t.Rsp, nil +} + func (c *Core) CheckHealth(ctx context.Context, in *milvuspb.CheckHealthRequest) (*milvuspb.CheckHealthResponse, error) { if err := merr.CheckHealthy(c.GetStateCode()); err != nil { return &milvuspb.CheckHealthResponse{ diff --git a/internal/rootcoord/root_coord_test.go b/internal/rootcoord/root_coord_test.go index 2061661e44..570f5d6f6b 100644 --- a/internal/rootcoord/root_coord_test.go +++ b/internal/rootcoord/root_coord_test.go @@ -1443,6 +1443,43 @@ func TestRootCoord_CheckHealth(t *testing.T) { }) } +func TestRootCoord_DescribeDatabase(t *testing.T) { + t.Run("not healthy", func(t *testing.T) { + ctx := context.Background() + c := newTestCore(withAbnormalCode()) + resp, err := c.DescribeDatabase(ctx, &rootcoordpb.DescribeDatabaseRequest{}) + assert.NoError(t, err) + assert.Error(t, merr.CheckRPCCall(resp.GetStatus(), nil)) + }) + + t.Run("add task failed", func(t *testing.T) { + ctx := context.Background() + c := newTestCore(withHealthyCode(), + withInvalidScheduler()) + resp, err := c.DescribeDatabase(ctx, &rootcoordpb.DescribeDatabaseRequest{}) + assert.NoError(t, err) + assert.Error(t, merr.CheckRPCCall(resp.GetStatus(), nil)) + }) + + t.Run("execute task failed", func(t *testing.T) { + ctx := context.Background() + c := newTestCore(withHealthyCode(), + withTaskFailScheduler()) + resp, err := c.DescribeDatabase(ctx, &rootcoordpb.DescribeDatabaseRequest{}) + assert.NoError(t, err) + assert.Error(t, merr.CheckRPCCall(resp.GetStatus(), nil)) + }) + + t.Run("run ok", func(t *testing.T) { + ctx := context.Background() + c := newTestCore(withHealthyCode(), + withValidScheduler()) + resp, err := c.DescribeDatabase(ctx, &rootcoordpb.DescribeDatabaseRequest{}) + assert.NoError(t, err) + assert.NoError(t, merr.CheckRPCCall(resp.GetStatus(), nil)) + }) +} + func TestRootCoord_RBACError(t *testing.T) { ctx := context.Background() c := newTestCore(withHealthyCode(), withInvalidMeta()) diff --git a/internal/rootcoord/util.go b/internal/rootcoord/util.go index 59f5b49a96..1ac9a4d4b9 100644 --- a/internal/rootcoord/util.go +++ b/internal/rootcoord/util.go @@ -138,13 +138,16 @@ func getCollectionRateLimitConfigDefaultValue(configKey string) float64 { return Params.QuotaConfig.DQLMinSearchRatePerCollection.GetAsFloat() case common.CollectionDiskQuotaKey: return Params.QuotaConfig.DiskQuotaPerCollection.GetAsFloat() - default: return float64(0) } } func getCollectionRateLimitConfig(properties map[string]string, configKey string) float64 { + return getRateLimitConfig(properties, configKey, getCollectionRateLimitConfigDefaultValue(configKey)) +} + +func getRateLimitConfig(properties map[string]string, configKey string, configValue float64) float64 { megaBytes2Bytes := func(v float64) float64 { return v * 1024.0 * 1024.0 } 
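The refactor above keeps getCollectionRateLimitConfig as a thin wrapper and moves the fallback logic into getRateLimitConfig, which takes the default as an argument so callers can supply any default, not just the per-collection one. A standalone sketch of that fallback behaviour (the key names here are made up, and the real helper also converts MB-valued keys to bytes):

package main

import (
	"fmt"
	"strconv"
)

// rateLimitOrDefault sketches the fallback: use the property when it parses to a
// non-negative number, otherwise return the caller-supplied default.
func rateLimitOrDefault(properties map[string]string, key string, def float64) float64 {
	v, ok := properties[key]
	if !ok {
		return def
	}
	rate, err := strconv.ParseFloat(v, 64)
	if err != nil || rate < 0 {
		return def
	}
	return rate
}

func main() {
	props := map[string]string{"some.rate.key": "1"}
	fmt.Println(rateLimitOrDefault(props, "some.rate.key", 100)) // 1: valid value wins
	fmt.Println(rateLimitOrDefault(props, "missing.key", 100))   // 100: missing key falls back
	fmt.Println(rateLimitOrDefault(map[string]string{"some.rate.key": "-1"}, "some.rate.key", 1)) // 1: negative value falls back
}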
@@ -189,15 +192,15 @@ func getCollectionRateLimitConfig(properties map[string]string, configKey string log.Warn("invalid configuration for collection dml rate", zap.String("config item", configKey), zap.String("config value", v)) - return getCollectionRateLimitConfigDefaultValue(configKey) + return configValue } rateInBytes := toBytesIfNecessary(rate) if rateInBytes < 0 { - return getCollectionRateLimitConfigDefaultValue(configKey) + return configValue } return rateInBytes } - return getCollectionRateLimitConfigDefaultValue(configKey) + return configValue } diff --git a/internal/rootcoord/util_test.go b/internal/rootcoord/util_test.go index de03271400..b3a42e3178 100644 --- a/internal/rootcoord/util_test.go +++ b/internal/rootcoord/util_test.go @@ -292,3 +292,27 @@ func Test_getCollectionRateLimitConfig(t *testing.T) { }) } } + +func TestGetRateLimitConfigErr(t *testing.T) { + key := common.CollectionQueryRateMaxKey + t.Run("negative value", func(t *testing.T) { + v := getRateLimitConfig(map[string]string{ + key: "-1", + }, key, 1) + assert.EqualValues(t, 1, v) + }) + + t.Run("valid value", func(t *testing.T) { + v := getRateLimitConfig(map[string]string{ + key: "1", + }, key, 100) + assert.EqualValues(t, 1, v) + }) + + t.Run("not exist value", func(t *testing.T) { + v := getRateLimitConfig(map[string]string{ + key: "1", + }, "b", 100) + assert.EqualValues(t, 100, v) + }) +} diff --git a/internal/types/types.go b/internal/types/types.go index 8264179a6d..27acc7cac3 100644 --- a/internal/types/types.go +++ b/internal/types/types.go @@ -37,7 +37,7 @@ import ( // If Limit function return true, the request will be rejected. // Otherwise, the request will pass. Limit also returns limit of limiter. type Limiter interface { - Check(collectionIDs []int64, rt internalpb.RateType, n int) error + Check(dbID int64, collectionIDToPartIDs map[int64][]int64, rt internalpb.RateType, n int) error } // Component is the interface all services implement diff --git a/internal/util/mock/grpc_rootcoord_client.go b/internal/util/mock/grpc_rootcoord_client.go index 4603c32c6c..e14c6da179 100644 --- a/internal/util/mock/grpc_rootcoord_client.go +++ b/internal/util/mock/grpc_rootcoord_client.go @@ -37,6 +37,10 @@ type GrpcRootCoordClient struct { Err error } +func (m *GrpcRootCoordClient) DescribeDatabase(ctx context.Context, in *rootcoordpb.DescribeDatabaseRequest, opts ...grpc.CallOption) (*rootcoordpb.DescribeDatabaseResponse, error) { + return &rootcoordpb.DescribeDatabaseResponse{}, m.Err +} + func (m *GrpcRootCoordClient) CreateDatabase(ctx context.Context, in *milvuspb.CreateDatabaseRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { return &commonpb.Status{}, m.Err } diff --git a/internal/util/quota/quota_constant.go b/internal/util/quota/quota_constant.go new file mode 100644 index 0000000000..0302e1fddc --- /dev/null +++ b/internal/util/quota/quota_constant.go @@ -0,0 +1,106 @@ +/* + * Licensed to the LF AI & Data foundation under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package quota + +import ( + "math" + "sync" + + "go.uber.org/zap" + + "github.com/milvus-io/milvus/internal/proto/internalpb" + "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/util/paramtable" +) + +var ( + initOnce sync.Once + limitConfigMap map[internalpb.RateScope]map[internalpb.RateType]*paramtable.ParamItem +) + +func initLimitConfigMaps() { + initOnce.Do(func() { + quotaConfig := ¶mtable.Get().QuotaConfig + limitConfigMap = map[internalpb.RateScope]map[internalpb.RateType]*paramtable.ParamItem{ + internalpb.RateScope_Cluster: { + internalpb.RateType_DDLCollection: "aConfig.DDLCollectionRate, + internalpb.RateType_DDLPartition: "aConfig.DDLPartitionRate, + internalpb.RateType_DDLIndex: "aConfig.MaxIndexRate, + internalpb.RateType_DDLFlush: "aConfig.MaxFlushRate, + internalpb.RateType_DDLCompaction: "aConfig.MaxCompactionRate, + internalpb.RateType_DMLInsert: "aConfig.DMLMaxInsertRate, + internalpb.RateType_DMLUpsert: "aConfig.DMLMaxUpsertRate, + internalpb.RateType_DMLDelete: "aConfig.DMLMaxDeleteRate, + internalpb.RateType_DMLBulkLoad: "aConfig.DMLMaxBulkLoadRate, + internalpb.RateType_DQLSearch: "aConfig.DQLMaxSearchRate, + internalpb.RateType_DQLQuery: "aConfig.DQLMaxQueryRate, + }, + internalpb.RateScope_Database: { + internalpb.RateType_DDLCollection: "aConfig.DDLCollectionRatePerDB, + internalpb.RateType_DDLPartition: "aConfig.DDLPartitionRatePerDB, + internalpb.RateType_DDLIndex: "aConfig.MaxIndexRatePerDB, + internalpb.RateType_DDLFlush: "aConfig.MaxFlushRatePerDB, + internalpb.RateType_DDLCompaction: "aConfig.MaxCompactionRatePerDB, + internalpb.RateType_DMLInsert: "aConfig.DMLMaxInsertRatePerDB, + internalpb.RateType_DMLUpsert: "aConfig.DMLMaxUpsertRatePerDB, + internalpb.RateType_DMLDelete: "aConfig.DMLMaxDeleteRatePerDB, + internalpb.RateType_DMLBulkLoad: "aConfig.DMLMaxBulkLoadRatePerDB, + internalpb.RateType_DQLSearch: "aConfig.DQLMaxSearchRatePerDB, + internalpb.RateType_DQLQuery: "aConfig.DQLMaxQueryRatePerDB, + }, + internalpb.RateScope_Collection: { + internalpb.RateType_DMLInsert: "aConfig.DMLMaxInsertRatePerCollection, + internalpb.RateType_DMLUpsert: "aConfig.DMLMaxUpsertRatePerCollection, + internalpb.RateType_DMLDelete: "aConfig.DMLMaxDeleteRatePerCollection, + internalpb.RateType_DMLBulkLoad: "aConfig.DMLMaxBulkLoadRatePerCollection, + internalpb.RateType_DQLSearch: "aConfig.DQLMaxSearchRatePerCollection, + internalpb.RateType_DQLQuery: "aConfig.DQLMaxQueryRatePerCollection, + internalpb.RateType_DDLFlush: "aConfig.MaxFlushRatePerCollection, + }, + internalpb.RateScope_Partition: { + internalpb.RateType_DMLInsert: "aConfig.DMLMaxInsertRatePerPartition, + internalpb.RateType_DMLUpsert: "aConfig.DMLMaxUpsertRatePerPartition, + internalpb.RateType_DMLDelete: "aConfig.DMLMaxDeleteRatePerPartition, + internalpb.RateType_DMLBulkLoad: "aConfig.DMLMaxBulkLoadRatePerPartition, + internalpb.RateType_DQLSearch: "aConfig.DQLMaxSearchRatePerPartition, + internalpb.RateType_DQLQuery: "aConfig.DQLMaxQueryRatePerPartition, + }, + } + }) +} + +func GetQuotaConfigMap(scope internalpb.RateScope) 
map[internalpb.RateType]*paramtable.ParamItem { + initLimitConfigMaps() + configMap, ok := limitConfigMap[scope] + if !ok { + log.Warn("Unknown rate scope", zap.Any("scope", scope)) + return make(map[internalpb.RateType]*paramtable.ParamItem) + } + return configMap +} + +func GetQuotaValue(scope internalpb.RateScope, rateType internalpb.RateType, params *paramtable.ComponentParam) float64 { + configMap := GetQuotaConfigMap(scope) + config, ok := configMap[rateType] + if !ok { + log.Warn("Unknown rate type", zap.Any("rateType", rateType)) + return math.MaxFloat64 + } + return config.GetAsFloat() +} diff --git a/internal/util/quota/quota_constant_test.go b/internal/util/quota/quota_constant_test.go new file mode 100644 index 0000000000..f8476bdf11 --- /dev/null +++ b/internal/util/quota/quota_constant_test.go @@ -0,0 +1,91 @@ +/* + * Licensed to the LF AI & Data foundation under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package quota + +import ( + "math" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus/internal/proto/internalpb" + "github.com/milvus-io/milvus/pkg/util/paramtable" +) + +func TestGetQuotaConfigMap(t *testing.T) { + paramtable.Init() + { + m := GetQuotaConfigMap(internalpb.RateScope_Cluster) + assert.Equal(t, 11, len(m)) + } + { + m := GetQuotaConfigMap(internalpb.RateScope_Database) + assert.Equal(t, 11, len(m)) + } + { + m := GetQuotaConfigMap(internalpb.RateScope_Collection) + assert.Equal(t, 7, len(m)) + } + { + m := GetQuotaConfigMap(internalpb.RateScope_Partition) + assert.Equal(t, 6, len(m)) + } + { + m := GetQuotaConfigMap(internalpb.RateScope(1000)) + assert.Equal(t, 0, len(m)) + } +} + +func TestGetQuotaValue(t *testing.T) { + paramtable.Init() + param := paramtable.Get() + param.Save(param.QuotaConfig.DDLLimitEnabled.Key, "true") + defer param.Reset(param.QuotaConfig.DDLLimitEnabled.Key) + param.Save(param.QuotaConfig.DMLLimitEnabled.Key, "true") + defer param.Reset(param.QuotaConfig.DMLLimitEnabled.Key) + + t.Run("cluster", func(t *testing.T) { + param.Save(param.QuotaConfig.DDLCollectionRate.Key, "10") + defer param.Reset(param.QuotaConfig.DDLCollectionRate.Key) + v := GetQuotaValue(internalpb.RateScope_Cluster, internalpb.RateType_DDLCollection, param) + assert.EqualValues(t, 10, v) + }) + t.Run("database", func(t *testing.T) { + param.Save(param.QuotaConfig.DDLCollectionRatePerDB.Key, "10") + defer param.Reset(param.QuotaConfig.DDLCollectionRatePerDB.Key) + v := GetQuotaValue(internalpb.RateScope_Database, internalpb.RateType_DDLCollection, param) + assert.EqualValues(t, 10, v) + }) + t.Run("collection", func(t *testing.T) { + param.Save(param.QuotaConfig.DMLMaxInsertRatePerCollection.Key, "10") + defer param.Reset(param.QuotaConfig.DMLMaxInsertRatePerCollection.Key) + v := GetQuotaValue(internalpb.RateScope_Collection, 
internalpb.RateType_DMLInsert, param) + assert.EqualValues(t, 10*1024*1024, v) + }) + t.Run("partition", func(t *testing.T) { + param.Save(param.QuotaConfig.DMLMaxInsertRatePerPartition.Key, "10") + defer param.Reset(param.QuotaConfig.DMLMaxInsertRatePerPartition.Key) + v := GetQuotaValue(internalpb.RateScope_Partition, internalpb.RateType_DMLInsert, param) + assert.EqualValues(t, 10*1024*1024, v) + }) + t.Run("unknown", func(t *testing.T) { + v := GetQuotaValue(internalpb.RateScope(1000), internalpb.RateType(1000), param) + assert.EqualValues(t, math.MaxFloat64, v) + }) +} diff --git a/internal/util/ratelimitutil/rate_limiter_tree.go b/internal/util/ratelimitutil/rate_limiter_tree.go new file mode 100644 index 0000000000..a2db0eb3e0 --- /dev/null +++ b/internal/util/ratelimitutil/rate_limiter_tree.go @@ -0,0 +1,336 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ratelimitutil + +import ( + "fmt" + "sync" + "time" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus/internal/proto/internalpb" + "github.com/milvus-io/milvus/internal/proto/proxypb" + "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/ratelimitutil" + "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +type RateLimiterNode struct { + limiters *typeutil.ConcurrentMap[internalpb.RateType, *ratelimitutil.Limiter] + quotaStates *typeutil.ConcurrentMap[milvuspb.QuotaState, commonpb.ErrorCode] + level internalpb.RateScope + + // db id, collection id or partition id, cluster id is 0 for the cluster level + id int64 + + // children will be databases if current level is cluster + // children will be collections if current level is database + // children will be partitions if current level is collection + children *typeutil.ConcurrentMap[int64, *RateLimiterNode] +} + +func NewRateLimiterNode(level internalpb.RateScope) *RateLimiterNode { + rln := &RateLimiterNode{ + limiters: typeutil.NewConcurrentMap[internalpb.RateType, *ratelimitutil.Limiter](), + quotaStates: typeutil.NewConcurrentMap[milvuspb.QuotaState, commonpb.ErrorCode](), + children: typeutil.NewConcurrentMap[int64, *RateLimiterNode](), + level: level, + } + return rln +} + +func (rln *RateLimiterNode) Level() internalpb.RateScope { + return rln.level +} + +// Limit returns true, the request will be rejected. +// Otherwise, the request will pass. 
+func (rln *RateLimiterNode) Limit(rt internalpb.RateType, n int) (bool, float64) { + limit, ok := rln.limiters.Get(rt) + if !ok { + return false, -1 + } + return !limit.AllowN(time.Now(), n), float64(limit.Limit()) +} + +func (rln *RateLimiterNode) Cancel(rt internalpb.RateType, n int) { + limit, ok := rln.limiters.Get(rt) + if !ok { + return + } + limit.Cancel(n) +} + +func (rln *RateLimiterNode) Check(rt internalpb.RateType, n int) error { + limit, rate := rln.Limit(rt, n) + if rate == 0 { + return rln.GetQuotaExceededError(rt) + } + if limit { + return rln.GetRateLimitError(rate) + } + return nil +} + +func (rln *RateLimiterNode) GetQuotaExceededError(rt internalpb.RateType) error { + switch rt { + case internalpb.RateType_DMLInsert, internalpb.RateType_DMLUpsert, internalpb.RateType_DMLDelete, internalpb.RateType_DMLBulkLoad: + if errCode, ok := rln.quotaStates.Get(milvuspb.QuotaState_DenyToWrite); ok { + return merr.WrapErrServiceQuotaExceeded(ratelimitutil.GetQuotaErrorString(errCode)) + } + case internalpb.RateType_DQLSearch, internalpb.RateType_DQLQuery: + if errCode, ok := rln.quotaStates.Get(milvuspb.QuotaState_DenyToRead); ok { + return merr.WrapErrServiceQuotaExceeded(ratelimitutil.GetQuotaErrorString(errCode)) + } + } + return merr.WrapErrServiceQuotaExceeded(fmt.Sprintf("rate type: %s", rt.String())) +} + +func (rln *RateLimiterNode) GetRateLimitError(rate float64) error { + return merr.WrapErrServiceRateLimit(rate, "request is rejected by grpc RateLimiter middleware, please retry later") +} + +func TraverseRateLimiterTree(root *RateLimiterNode, fn1 func(internalpb.RateType, *ratelimitutil.Limiter) bool, + fn2 func(node *RateLimiterNode, state milvuspb.QuotaState, errCode commonpb.ErrorCode) bool, +) { + if fn1 != nil { + root.limiters.Range(fn1) + } + + if fn2 != nil { + root.quotaStates.Range(func(state milvuspb.QuotaState, errCode commonpb.ErrorCode) bool { + return fn2(root, state, errCode) + }) + } + root.GetChildren().Range(func(key int64, child *RateLimiterNode) bool { + TraverseRateLimiterTree(child, fn1, fn2) + return true + }) +} + +func (rln *RateLimiterNode) AddChild(key int64, child *RateLimiterNode) { + rln.children.Insert(key, child) +} + +func (rln *RateLimiterNode) GetChild(key int64) *RateLimiterNode { + n, _ := rln.children.Get(key) + return n +} + +func (rln *RateLimiterNode) GetChildren() *typeutil.ConcurrentMap[int64, *RateLimiterNode] { + return rln.children +} + +func (rln *RateLimiterNode) GetLimiters() *typeutil.ConcurrentMap[internalpb.RateType, *ratelimitutil.Limiter] { + return rln.limiters +} + +func (rln *RateLimiterNode) SetLimiters(new *typeutil.ConcurrentMap[internalpb.RateType, *ratelimitutil.Limiter]) { + rln.limiters = new +} + +func (rln *RateLimiterNode) GetQuotaStates() *typeutil.ConcurrentMap[milvuspb.QuotaState, commonpb.ErrorCode] { + return rln.quotaStates +} + +func (rln *RateLimiterNode) SetQuotaStates(new *typeutil.ConcurrentMap[milvuspb.QuotaState, commonpb.ErrorCode]) { + rln.quotaStates = new +} + +func (rln *RateLimiterNode) GetID() int64 { + return rln.id +} + +// RateLimiterTree is implemented based on RateLimiterNode to operate multilevel rate limiters +// +// it contains the following four levels generally: +// +// -> global level +// -> database level +// -> collection level +// -> partition level +type RateLimiterTree struct { + root *RateLimiterNode + mu sync.RWMutex +} + +// NewRateLimiterTree returns a new RateLimiterTree.
+func NewRateLimiterTree(root *RateLimiterNode) *RateLimiterTree { + return &RateLimiterTree{root: root} +} + +// GetRootLimiters get root limiters +func (m *RateLimiterTree) GetRootLimiters() *RateLimiterNode { + return m.root +} + +func (m *RateLimiterTree) ClearInvalidLimiterNode(req *proxypb.LimiterNode) { + m.mu.Lock() + defer m.mu.Unlock() + + reqDBLimits := req.GetChildren() + removeDBLimits := make([]int64, 0) + m.GetRootLimiters().GetChildren().Range(func(key int64, _ *RateLimiterNode) bool { + if _, ok := reqDBLimits[key]; !ok { + removeDBLimits = append(removeDBLimits, key) + } + return true + }) + for _, dbID := range removeDBLimits { + m.GetRootLimiters().GetChildren().Remove(dbID) + } + + m.GetRootLimiters().GetChildren().Range(func(dbID int64, dbNode *RateLimiterNode) bool { + reqCollectionLimits := reqDBLimits[dbID].GetChildren() + removeCollectionLimits := make([]int64, 0) + dbNode.GetChildren().Range(func(key int64, _ *RateLimiterNode) bool { + if _, ok := reqCollectionLimits[key]; !ok { + removeCollectionLimits = append(removeCollectionLimits, key) + } + return true + }) + for _, collectionID := range removeCollectionLimits { + dbNode.GetChildren().Remove(collectionID) + } + return true + }) + + m.GetRootLimiters().GetChildren().Range(func(dbID int64, dbNode *RateLimiterNode) bool { + dbNode.GetChildren().Range(func(collectionID int64, collectionNode *RateLimiterNode) bool { + reqPartitionLimits := reqDBLimits[dbID].GetChildren()[collectionID].GetChildren() + removePartitionLimits := make([]int64, 0) + collectionNode.GetChildren().Range(func(key int64, _ *RateLimiterNode) bool { + if _, ok := reqPartitionLimits[key]; !ok { + removePartitionLimits = append(removePartitionLimits, key) + } + return true + }) + for _, partitionID := range removePartitionLimits { + collectionNode.GetChildren().Remove(partitionID) + } + return true + }) + return true + }) +} + +func (m *RateLimiterTree) GetDatabaseLimiters(dbID int64) *RateLimiterNode { + m.mu.RLock() + defer m.mu.RUnlock() + return m.root.GetChild(dbID) +} + +// GetOrCreateDatabaseLimiters get limiter of database level, or create a database limiter if it doesn't exist. +func (m *RateLimiterTree) GetOrCreateDatabaseLimiters(dbID int64, newDBRateLimiter func() *RateLimiterNode) *RateLimiterNode { + dbRateLimiters := m.GetDatabaseLimiters(dbID) + if dbRateLimiters != nil { + return dbRateLimiters + } + m.mu.Lock() + defer m.mu.Unlock() + if cur := m.root.GetChild(dbID); cur != nil { + return cur + } + dbRateLimiters = newDBRateLimiter() + dbRateLimiters.id = dbID + m.root.AddChild(dbID, dbRateLimiters) + return dbRateLimiters +} + +func (m *RateLimiterTree) GetCollectionLimiters(dbID, collectionID int64) *RateLimiterNode { + m.mu.RLock() + defer m.mu.RUnlock() + dbRateLimiters := m.root.GetChild(dbID) + + // database rate limiter not found + if dbRateLimiters == nil { + return nil + } + return dbRateLimiters.GetChild(collectionID) +} + +// GetOrCreateCollectionLimiters create limiter of collection level for all rate types and rate scopes. 
+// create a database rate limiters if db rate limiter does not exist +func (m *RateLimiterTree) GetOrCreateCollectionLimiters(dbID, collectionID int64, + newDBRateLimiter func() *RateLimiterNode, newCollectionRateLimiter func() *RateLimiterNode, +) *RateLimiterNode { + collectionRateLimiters := m.GetCollectionLimiters(dbID, collectionID) + if collectionRateLimiters != nil { + return collectionRateLimiters + } + + dbRateLimiters := m.GetOrCreateDatabaseLimiters(dbID, newDBRateLimiter) + m.mu.Lock() + defer m.mu.Unlock() + if cur := dbRateLimiters.GetChild(collectionID); cur != nil { + return cur + } + + collectionRateLimiters = newCollectionRateLimiter() + collectionRateLimiters.id = collectionID + dbRateLimiters.AddChild(collectionID, collectionRateLimiters) + return collectionRateLimiters +} + +// It checks if the rate limiters exist for the database, collection, and partition, +// returns the corresponding rate limiter tree. +func (m *RateLimiterTree) GetPartitionLimiters(dbID, collectionID, partitionID int64) *RateLimiterNode { + m.mu.RLock() + defer m.mu.RUnlock() + + dbRateLimiters := m.root.GetChild(dbID) + + // database rate limiter not found + if dbRateLimiters == nil { + return nil + } + + collectionRateLimiters := dbRateLimiters.GetChild(collectionID) + + // collection rate limiter not found + if collectionRateLimiters == nil { + return nil + } + + return collectionRateLimiters.GetChild(partitionID) +} + +// GetOrCreatePartitionLimiters create limiter of partition level for all rate types and rate scopes. +// create a database rate limiters if db rate limiter does not exist +// create a collection rate limiters if collection rate limiter does not exist +func (m *RateLimiterTree) GetOrCreatePartitionLimiters(dbID int64, collectionID int64, partitionID int64, + newDBRateLimiter func() *RateLimiterNode, newCollectionRateLimiter func() *RateLimiterNode, + newPartRateLimiter func() *RateLimiterNode, +) *RateLimiterNode { + partRateLimiters := m.GetPartitionLimiters(dbID, collectionID, partitionID) + if partRateLimiters != nil { + return partRateLimiters + } + + collectionRateLimiters := m.GetOrCreateCollectionLimiters(dbID, collectionID, newDBRateLimiter, newCollectionRateLimiter) + m.mu.Lock() + defer m.mu.Unlock() + if cur := collectionRateLimiters.GetChild(partitionID); cur != nil { + return cur + } + + partRateLimiters = newPartRateLimiter() + partRateLimiters.id = partitionID + collectionRateLimiters.AddChild(partitionID, partRateLimiters) + return partRateLimiters +} diff --git a/internal/util/ratelimitutil/rate_limiter_tree_test.go b/internal/util/ratelimitutil/rate_limiter_tree_test.go new file mode 100644 index 0000000000..0cdf8f2d7a --- /dev/null +++ b/internal/util/ratelimitutil/rate_limiter_tree_test.go @@ -0,0 +1,205 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +package ratelimitutil + +import ( + "strings" + "testing" + + "github.com/cockroachdb/errors" + "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus/internal/proto/internalpb" + "github.com/milvus-io/milvus/internal/proto/proxypb" + "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/ratelimitutil" + "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +func TestRateLimiterNode_AddAndGetChild(t *testing.T) { + rln := NewRateLimiterNode(internalpb.RateScope_Cluster) + child := NewRateLimiterNode(internalpb.RateScope_Cluster) + + // Positive test case + rln.AddChild(1, child) + if rln.GetChild(1) != child { + t.Error("AddChild did not add the child correctly") + } + + // Negative test case + invalidChild := &RateLimiterNode{} + rln.AddChild(2, child) + if rln.GetChild(2) == invalidChild { + t.Error("AddChild added an invalid child") + } +} + +func TestTraverseRateLimiterTree(t *testing.T) { + limiters := typeutil.NewConcurrentMap[internalpb.RateType, *ratelimitutil.Limiter]() + limiters.Insert(internalpb.RateType_DDLCollection, ratelimitutil.NewLimiter(ratelimitutil.Inf, 0)) + quotaStates := typeutil.NewConcurrentMap[milvuspb.QuotaState, commonpb.ErrorCode]() + quotaStates.Insert(milvuspb.QuotaState_DenyToWrite, commonpb.ErrorCode_ForceDeny) + + root := NewRateLimiterNode(internalpb.RateScope_Cluster) + root.SetLimiters(limiters) + root.SetQuotaStates(quotaStates) + + // Add a child to the root node + child := NewRateLimiterNode(internalpb.RateScope_Cluster) + child.SetLimiters(limiters) + child.SetQuotaStates(quotaStates) + root.AddChild(123, child) + + // Add a child to the root node + child2 := NewRateLimiterNode(internalpb.RateScope_Cluster) + child2.SetLimiters(limiters) + child2.SetQuotaStates(quotaStates) + child.AddChild(123, child2) + + // Positive test case for fn1 + var fn1Count int + fn1 := func(rateType internalpb.RateType, limiter *ratelimitutil.Limiter) bool { + fn1Count++ + return true + } + + // Negative test case for fn2 + var fn2Count int + fn2 := func(node *RateLimiterNode, state milvuspb.QuotaState, errCode commonpb.ErrorCode) bool { + fn2Count++ + return true + } + + // Call TraverseRateLimiterTree with fn1 and fn2 + TraverseRateLimiterTree(root, fn1, fn2) + + assert.Equal(t, 3, fn1Count) + assert.Equal(t, 3, fn2Count) +} + +func TestRateLimiterNodeCancel(t *testing.T) { + t.Run("cancel not exist type", func(t *testing.T) { + limitNode := NewRateLimiterNode(internalpb.RateScope_Cluster) + limitNode.Cancel(internalpb.RateType_DMLInsert, 10) + }) +} + +func TestRateLimiterNodeCheck(t *testing.T) { + t.Run("quota exceed", func(t *testing.T) { + limitNode := NewRateLimiterNode(internalpb.RateScope_Cluster) + limitNode.limiters.Insert(internalpb.RateType_DMLInsert, ratelimitutil.NewLimiter(0, 0)) + limitNode.quotaStates.Insert(milvuspb.QuotaState_DenyToWrite, commonpb.ErrorCode_ForceDeny) + err := limitNode.Check(internalpb.RateType_DMLInsert, 10) + assert.True(t, errors.Is(err, merr.ErrServiceQuotaExceeded)) + }) + + t.Run("rate limit", func(t *testing.T) { + limitNode := NewRateLimiterNode(internalpb.RateScope_Cluster) + limitNode.limiters.Insert(internalpb.RateType_DMLInsert, ratelimitutil.NewLimiter(0.01, 0.01)) + { + err := limitNode.Check(internalpb.RateType_DMLInsert, 1) + assert.NoError(t, err) + } 
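+ // The limiter above refills only 0.01 tokens per second, so once the first Check has been admitted, the identical Check below is expected to be rejected with ErrServiceRateLimit.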
+ { + err := limitNode.Check(internalpb.RateType_DMLInsert, 1) + assert.True(t, errors.Is(err, merr.ErrServiceRateLimit)) + } + }) +} + +func TestRateLimiterNodeGetQuotaExceededError(t *testing.T) { + t.Run("write", func(t *testing.T) { + limitNode := NewRateLimiterNode(internalpb.RateScope_Cluster) + limitNode.quotaStates.Insert(milvuspb.QuotaState_DenyToWrite, commonpb.ErrorCode_ForceDeny) + err := limitNode.GetQuotaExceededError(internalpb.RateType_DMLInsert) + assert.True(t, errors.Is(err, merr.ErrServiceQuotaExceeded)) + // reference: ratelimitutil.GetQuotaErrorString(errCode) + assert.True(t, strings.Contains(err.Error(), "deactivated")) + }) + + t.Run("read", func(t *testing.T) { + limitNode := NewRateLimiterNode(internalpb.RateScope_Cluster) + limitNode.quotaStates.Insert(milvuspb.QuotaState_DenyToRead, commonpb.ErrorCode_ForceDeny) + err := limitNode.GetQuotaExceededError(internalpb.RateType_DQLSearch) + assert.True(t, errors.Is(err, merr.ErrServiceQuotaExceeded)) + // reference: ratelimitutil.GetQuotaErrorString(errCode) + assert.True(t, strings.Contains(err.Error(), "deactivated")) + }) + + t.Run("unknown", func(t *testing.T) { + limitNode := NewRateLimiterNode(internalpb.RateScope_Cluster) + err := limitNode.GetQuotaExceededError(internalpb.RateType_DDLCompaction) + assert.True(t, errors.Is(err, merr.ErrServiceQuotaExceeded)) + assert.True(t, strings.Contains(err.Error(), "rate type")) + }) +} + +func TestRateLimiterTreeClearInvalidLimiterNode(t *testing.T) { + root := NewRateLimiterNode(internalpb.RateScope_Cluster) + tree := NewRateLimiterTree(root) + + generateNodeFFunc := func(level internalpb.RateScope) func() *RateLimiterNode { + return func() *RateLimiterNode { + return NewRateLimiterNode(level) + } + } + + tree.GetOrCreatePartitionLimiters(1, 10, 100, + generateNodeFFunc(internalpb.RateScope_Database), + generateNodeFFunc(internalpb.RateScope_Collection), + generateNodeFFunc(internalpb.RateScope_Partition), + ) + tree.GetOrCreatePartitionLimiters(1, 10, 200, + generateNodeFFunc(internalpb.RateScope_Database), + generateNodeFFunc(internalpb.RateScope_Collection), + generateNodeFFunc(internalpb.RateScope_Partition), + ) + tree.GetOrCreatePartitionLimiters(1, 20, 300, + generateNodeFFunc(internalpb.RateScope_Database), + generateNodeFFunc(internalpb.RateScope_Collection), + generateNodeFFunc(internalpb.RateScope_Partition), + ) + tree.GetOrCreatePartitionLimiters(2, 30, 400, + generateNodeFFunc(internalpb.RateScope_Database), + generateNodeFFunc(internalpb.RateScope_Collection), + generateNodeFFunc(internalpb.RateScope_Partition), + ) + + assert.Equal(t, 2, root.GetChildren().Len()) + assert.Equal(t, 2, root.GetChild(1).GetChildren().Len()) + assert.Equal(t, 2, root.GetChild(1).GetChild(10).GetChildren().Len()) + + tree.ClearInvalidLimiterNode(&proxypb.LimiterNode{ + Children: map[int64]*proxypb.LimiterNode{ + 1: { + Children: map[int64]*proxypb.LimiterNode{ + 10: { + Children: map[int64]*proxypb.LimiterNode{ + 100: {}, + }, + }, + }, + }, + }, + }) + + assert.Equal(t, 1, root.GetChildren().Len()) + assert.Equal(t, 1, root.GetChild(1).GetChildren().Len()) + assert.Equal(t, 1, root.GetChild(1).GetChild(10).GetChildren().Len()) +} diff --git a/pkg/common/common.go b/pkg/common/common.go index fb8a06c92c..a1f20d7f38 100644 --- a/pkg/common/common.go +++ b/pkg/common/common.go @@ -132,6 +132,8 @@ const ( CollectionSearchRateMaxKey = "collection.searchRate.max.vps" CollectionSearchRateMinKey = "collection.searchRate.min.vps" CollectionDiskQuotaKey = 
"collection.diskProtection.diskQuota.mb" + + PartitionDiskQuotaKey = "partition.diskProtection.diskQuota.mb" ) // common properties diff --git a/pkg/metrics/rootcoord_metrics.go b/pkg/metrics/rootcoord_metrics.go index ddd3044078..e50c9bece0 100644 --- a/pkg/metrics/rootcoord_metrics.go +++ b/pkg/metrics/rootcoord_metrics.go @@ -167,7 +167,7 @@ var ( Help: "The quota states of cluster", }, []string{ "quota_states", - "db_name", + "name", }) // RootCoordRateLimitRatio reflects the ratio of rate limit. diff --git a/pkg/util/metricsinfo/quota_metric.go b/pkg/util/metricsinfo/quota_metric.go index 84ffddf348..44f609c29f 100644 --- a/pkg/util/metricsinfo/quota_metric.go +++ b/pkg/util/metricsinfo/quota_metric.go @@ -87,6 +87,7 @@ type QueryNodeQuotaMetrics struct { type DataCoordQuotaMetrics struct { TotalBinlogSize int64 CollectionBinlogSize map[int64]int64 + PartitionsBinlogSize map[int64]map[int64]int64 } // DataNodeQuotaMetrics are metrics of DataNode. diff --git a/pkg/util/paramtable/quota_param.go b/pkg/util/paramtable/quota_param.go index f38af39bc1..9f5e448047 100644 --- a/pkg/util/paramtable/quota_param.go +++ b/pkg/util/paramtable/quota_param.go @@ -61,6 +61,12 @@ type quotaConfig struct { CompactionLimitEnabled ParamItem `refreshable:"true"` MaxCompactionRate ParamItem `refreshable:"true"` + DDLCollectionRatePerDB ParamItem `refreshable:"true"` + DDLPartitionRatePerDB ParamItem `refreshable:"true"` + MaxIndexRatePerDB ParamItem `refreshable:"true"` + MaxFlushRatePerDB ParamItem `refreshable:"true"` + MaxCompactionRatePerDB ParamItem `refreshable:"true"` + // dml DMLLimitEnabled ParamItem `refreshable:"true"` DMLMaxInsertRate ParamItem `refreshable:"true"` @@ -71,6 +77,14 @@ type quotaConfig struct { DMLMinDeleteRate ParamItem `refreshable:"true"` DMLMaxBulkLoadRate ParamItem `refreshable:"true"` DMLMinBulkLoadRate ParamItem `refreshable:"true"` + DMLMaxInsertRatePerDB ParamItem `refreshable:"true"` + DMLMinInsertRatePerDB ParamItem `refreshable:"true"` + DMLMaxUpsertRatePerDB ParamItem `refreshable:"true"` + DMLMinUpsertRatePerDB ParamItem `refreshable:"true"` + DMLMaxDeleteRatePerDB ParamItem `refreshable:"true"` + DMLMinDeleteRatePerDB ParamItem `refreshable:"true"` + DMLMaxBulkLoadRatePerDB ParamItem `refreshable:"true"` + DMLMinBulkLoadRatePerDB ParamItem `refreshable:"true"` DMLMaxInsertRatePerCollection ParamItem `refreshable:"true"` DMLMinInsertRatePerCollection ParamItem `refreshable:"true"` DMLMaxUpsertRatePerCollection ParamItem `refreshable:"true"` @@ -79,6 +93,14 @@ type quotaConfig struct { DMLMinDeleteRatePerCollection ParamItem `refreshable:"true"` DMLMaxBulkLoadRatePerCollection ParamItem `refreshable:"true"` DMLMinBulkLoadRatePerCollection ParamItem `refreshable:"true"` + DMLMaxInsertRatePerPartition ParamItem `refreshable:"true"` + DMLMinInsertRatePerPartition ParamItem `refreshable:"true"` + DMLMaxUpsertRatePerPartition ParamItem `refreshable:"true"` + DMLMinUpsertRatePerPartition ParamItem `refreshable:"true"` + DMLMaxDeleteRatePerPartition ParamItem `refreshable:"true"` + DMLMinDeleteRatePerPartition ParamItem `refreshable:"true"` + DMLMaxBulkLoadRatePerPartition ParamItem `refreshable:"true"` + DMLMinBulkLoadRatePerPartition ParamItem `refreshable:"true"` // dql DQLLimitEnabled ParamItem `refreshable:"true"` @@ -86,10 +108,18 @@ type quotaConfig struct { DQLMinSearchRate ParamItem `refreshable:"true"` DQLMaxQueryRate ParamItem `refreshable:"true"` DQLMinQueryRate ParamItem `refreshable:"true"` + DQLMaxSearchRatePerDB ParamItem `refreshable:"true"` + 
DQLMinSearchRatePerDB ParamItem `refreshable:"true"` + DQLMaxQueryRatePerDB ParamItem `refreshable:"true"` + DQLMinQueryRatePerDB ParamItem `refreshable:"true"` DQLMaxSearchRatePerCollection ParamItem `refreshable:"true"` DQLMinSearchRatePerCollection ParamItem `refreshable:"true"` DQLMaxQueryRatePerCollection ParamItem `refreshable:"true"` DQLMinQueryRatePerCollection ParamItem `refreshable:"true"` + DQLMaxSearchRatePerPartition ParamItem `refreshable:"true"` + DQLMinSearchRatePerPartition ParamItem `refreshable:"true"` + DQLMaxQueryRatePerPartition ParamItem `refreshable:"true"` + DQLMinQueryRatePerPartition ParamItem `refreshable:"true"` // limits MaxCollectionNum ParamItem `refreshable:"true"` @@ -114,16 +144,20 @@ type quotaConfig struct { GrowingSegmentsSizeHighWaterLevel ParamItem `refreshable:"true"` DiskProtectionEnabled ParamItem `refreshable:"true"` DiskQuota ParamItem `refreshable:"true"` + DiskQuotaPerDB ParamItem `refreshable:"true"` DiskQuotaPerCollection ParamItem `refreshable:"true"` + DiskQuotaPerPartition ParamItem `refreshable:"true"` // limit reading - ForceDenyReading ParamItem `refreshable:"true"` - QueueProtectionEnabled ParamItem `refreshable:"true"` - NQInQueueThreshold ParamItem `refreshable:"true"` - QueueLatencyThreshold ParamItem `refreshable:"true"` - ResultProtectionEnabled ParamItem `refreshable:"true"` - MaxReadResultRate ParamItem `refreshable:"true"` - CoolOffSpeed ParamItem `refreshable:"true"` + ForceDenyReading ParamItem `refreshable:"true"` + QueueProtectionEnabled ParamItem `refreshable:"true"` + NQInQueueThreshold ParamItem `refreshable:"true"` + QueueLatencyThreshold ParamItem `refreshable:"true"` + ResultProtectionEnabled ParamItem `refreshable:"true"` + MaxReadResultRate ParamItem `refreshable:"true"` + MaxReadResultRatePerDB ParamItem `refreshable:"true"` + MaxReadResultRatePerCollection ParamItem `refreshable:"true"` + CoolOffSpeed ParamItem `refreshable:"true"` } func (p *quotaConfig) init(base *BaseTable) { @@ -185,6 +219,25 @@ seconds, (0 ~ 65536)`, } p.DDLCollectionRate.Init(base.mgr) + p.DDLCollectionRatePerDB = ParamItem{ + Key: "quotaAndLimits.ddl.db.collectionRate", + Version: "2.4.1", + DefaultValue: max, + Formatter: func(v string) string { + if !p.DDLLimitEnabled.GetAsBool() { + return max + } + // [0 ~ Inf) + if getAsInt(v) < 0 { + return max + } + return v + }, + Doc: "qps of db level , default no limit, rate for CreateCollection, DropCollection, LoadCollection, ReleaseCollection", + Export: true, + } + p.DDLCollectionRatePerDB.Init(base.mgr) + p.DDLPartitionRate = ParamItem{ Key: "quotaAndLimits.ddl.partitionRate", Version: "2.2.0", @@ -204,6 +257,25 @@ seconds, (0 ~ 65536)`, } p.DDLPartitionRate.Init(base.mgr) + p.DDLPartitionRatePerDB = ParamItem{ + Key: "quotaAndLimits.ddl.db.partitionRate", + Version: "2.4.1", + DefaultValue: max, + Formatter: func(v string) string { + if !p.DDLLimitEnabled.GetAsBool() { + return max + } + // [0 ~ Inf) + if getAsInt(v) < 0 { + return max + } + return v + }, + Doc: "qps of db level, default no limit, rate for CreatePartition, DropPartition, LoadPartition, ReleasePartition", + Export: true, + } + p.DDLPartitionRatePerDB.Init(base.mgr) + p.IndexLimitEnabled = ParamItem{ Key: "quotaAndLimits.indexRate.enabled", Version: "2.2.0", @@ -231,6 +303,25 @@ seconds, (0 ~ 65536)`, } p.MaxIndexRate.Init(base.mgr) + p.MaxIndexRatePerDB = ParamItem{ + Key: "quotaAndLimits.indexRate.db.max", + Version: "2.4.1", + DefaultValue: max, + Formatter: func(v string) string { + if !p.IndexLimitEnabled.GetAsBool() { + 
return max + } + // [0 ~ Inf) + if getAsFloat(v) < 0 { + return max + } + return v + }, + Doc: "qps of db level, default no limit, rate for CreateIndex, DropIndex", + Export: true, + } + p.MaxIndexRatePerDB.Init(base.mgr) + p.FlushLimitEnabled = ParamItem{ Key: "quotaAndLimits.flushRate.enabled", Version: "2.2.0", @@ -258,6 +349,25 @@ seconds, (0 ~ 65536)`, } p.MaxFlushRate.Init(base.mgr) + p.MaxFlushRatePerDB = ParamItem{ + Key: "quotaAndLimits.flushRate.db.max", + Version: "2.4.1", + DefaultValue: max, + Formatter: func(v string) string { + if !p.FlushLimitEnabled.GetAsBool() { + return max + } + // [0 ~ Inf) + if getAsInt(v) < 0 { + return max + } + return v + }, + Doc: "qps of db level, default no limit, rate for flush", + Export: true, + } + p.MaxFlushRatePerDB.Init(base.mgr) + p.MaxFlushRatePerCollection = ParamItem{ Key: "quotaAndLimits.flushRate.collection.max", Version: "2.3.9", @@ -304,6 +414,25 @@ seconds, (0 ~ 65536)`, } p.MaxCompactionRate.Init(base.mgr) + p.MaxCompactionRatePerDB = ParamItem{ + Key: "quotaAndLimits.compactionRate.db.max", + Version: "2.4.1", + DefaultValue: max, + Formatter: func(v string) string { + if !p.CompactionLimitEnabled.GetAsBool() { + return max + } + // [0 ~ Inf) + if getAsInt(v) < 0 { + return max + } + return v + }, + Doc: "qps of db level, default no limit, rate for manualCompaction", + Export: true, + } + p.MaxCompactionRatePerDB.Init(base.mgr) + // dml p.DMLLimitEnabled = ParamItem{ Key: "quotaAndLimits.dml.enabled", @@ -359,6 +488,50 @@ The maximum rate will not be greater than ` + "max" + `.`, } p.DMLMinInsertRate.Init(base.mgr) + p.DMLMaxInsertRatePerDB = ParamItem{ + Key: "quotaAndLimits.dml.insertRate.db.max", + Version: "2.4.1", + DefaultValue: max, + Formatter: func(v string) string { + if !p.DMLLimitEnabled.GetAsBool() { + return max + } + rate := getAsFloat(v) + if math.Abs(rate-defaultMax) > 0.001 { // maxRate != defaultMax + rate = megaBytes2Bytes(rate) + } + // [0, inf) + if rate < 0 { + return p.DMLMaxInsertRate.GetValue() + } + return fmt.Sprintf("%f", rate) + }, + Doc: "MB/s, default no limit", + Export: true, + } + p.DMLMaxInsertRatePerDB.Init(base.mgr) + + p.DMLMinInsertRatePerDB = ParamItem{ + Key: "quotaAndLimits.dml.insertRate.db.min", + Version: "2.4.1", + DefaultValue: min, + Formatter: func(v string) string { + if !p.DMLLimitEnabled.GetAsBool() { + return min + } + rate := megaBytes2Bytes(getAsFloat(v)) + // [0, inf) + if rate < 0 { + return min + } + if !p.checkMinMaxLegal(rate, p.DMLMaxInsertRatePerDB.GetAsFloat()) { + return min + } + return fmt.Sprintf("%f", rate) + }, + } + p.DMLMinInsertRatePerDB.Init(base.mgr) + p.DMLMaxInsertRatePerCollection = ParamItem{ Key: "quotaAndLimits.dml.insertRate.collection.max", Version: "2.2.9", @@ -403,6 +576,50 @@ The maximum rate will not be greater than ` + "max" + `.`, } p.DMLMinInsertRatePerCollection.Init(base.mgr) + p.DMLMaxInsertRatePerPartition = ParamItem{ + Key: "quotaAndLimits.dml.insertRate.partition.max", + Version: "2.4.1", + DefaultValue: max, + Formatter: func(v string) string { + if !p.DMLLimitEnabled.GetAsBool() { + return max + } + rate := getAsFloat(v) + if math.Abs(rate-defaultMax) > 0.001 { // maxRate != defaultMax + rate = megaBytes2Bytes(rate) + } + // [0, inf) + if rate < 0 { + return p.DMLMaxInsertRate.GetValue() + } + return fmt.Sprintf("%f", rate) + }, + Doc: "MB/s, default no limit", + Export: true, + } + p.DMLMaxInsertRatePerPartition.Init(base.mgr) + + p.DMLMinInsertRatePerPartition = ParamItem{ + Key: "quotaAndLimits.dml.insertRate.partition.min", + 
Version: "2.4.1", + DefaultValue: min, + Formatter: func(v string) string { + if !p.DMLLimitEnabled.GetAsBool() { + return min + } + rate := megaBytes2Bytes(getAsFloat(v)) + // [0, inf) + if rate < 0 { + return min + } + if !p.checkMinMaxLegal(rate, p.DMLMaxInsertRatePerPartition.GetAsFloat()) { + return min + } + return fmt.Sprintf("%f", rate) + }, + } + p.DMLMinInsertRatePerPartition.Init(base.mgr) + p.DMLMaxUpsertRate = ParamItem{ Key: "quotaAndLimits.dml.upsertRate.max", Version: "2.3.0", @@ -447,6 +664,50 @@ The maximum rate will not be greater than ` + "max" + `.`, } p.DMLMinUpsertRate.Init(base.mgr) + p.DMLMaxUpsertRatePerDB = ParamItem{ + Key: "quotaAndLimits.dml.upsertRate.db.max", + Version: "2.4.1", + DefaultValue: max, + Formatter: func(v string) string { + if !p.DMLLimitEnabled.GetAsBool() { + return max + } + rate := getAsFloat(v) + if math.Abs(rate-defaultMax) > 0.001 { // maxRate != defaultMax + rate = megaBytes2Bytes(rate) + } + // [0, inf) + if rate < 0 { + return p.DMLMaxUpsertRate.GetValue() + } + return fmt.Sprintf("%f", rate) + }, + Doc: "MB/s, default no limit", + Export: true, + } + p.DMLMaxUpsertRatePerDB.Init(base.mgr) + + p.DMLMinUpsertRatePerDB = ParamItem{ + Key: "quotaAndLimits.dml.upsertRate.db.min", + Version: "2.4.1", + DefaultValue: min, + Formatter: func(v string) string { + if !p.DMLLimitEnabled.GetAsBool() { + return min + } + rate := megaBytes2Bytes(getAsFloat(v)) + // [0, inf) + if rate < 0 { + return min + } + if !p.checkMinMaxLegal(rate, p.DMLMaxUpsertRatePerDB.GetAsFloat()) { + return min + } + return fmt.Sprintf("%f", rate) + }, + } + p.DMLMinUpsertRatePerDB.Init(base.mgr) + p.DMLMaxUpsertRatePerCollection = ParamItem{ Key: "quotaAndLimits.dml.upsertRate.collection.max", Version: "2.3.0", @@ -491,6 +752,50 @@ The maximum rate will not be greater than ` + "max" + `.`, } p.DMLMinUpsertRatePerCollection.Init(base.mgr) + p.DMLMaxUpsertRatePerPartition = ParamItem{ + Key: "quotaAndLimits.dml.upsertRate.partition.max", + Version: "2.4.1", + DefaultValue: max, + Formatter: func(v string) string { + if !p.DMLLimitEnabled.GetAsBool() { + return max + } + rate := getAsFloat(v) + if math.Abs(rate-defaultMax) > 0.001 { // maxRate != defaultMax + rate = megaBytes2Bytes(rate) + } + // [0, inf) + if rate < 0 { + return p.DMLMaxUpsertRate.GetValue() + } + return fmt.Sprintf("%f", rate) + }, + Doc: "MB/s, default no limit", + Export: true, + } + p.DMLMaxUpsertRatePerPartition.Init(base.mgr) + + p.DMLMinUpsertRatePerPartition = ParamItem{ + Key: "quotaAndLimits.dml.upsertRate.partition.min", + Version: "2.4.1", + DefaultValue: min, + Formatter: func(v string) string { + if !p.DMLLimitEnabled.GetAsBool() { + return min + } + rate := megaBytes2Bytes(getAsFloat(v)) + // [0, inf) + if rate < 0 { + return min + } + if !p.checkMinMaxLegal(rate, p.DMLMaxUpsertRatePerPartition.GetAsFloat()) { + return min + } + return fmt.Sprintf("%f", rate) + }, + } + p.DMLMinUpsertRatePerPartition.Init(base.mgr) + p.DMLMaxDeleteRate = ParamItem{ Key: "quotaAndLimits.dml.deleteRate.max", Version: "2.2.0", @@ -535,6 +840,50 @@ The maximum rate will not be greater than ` + "max" + `.`, } p.DMLMinDeleteRate.Init(base.mgr) + p.DMLMaxDeleteRatePerDB = ParamItem{ + Key: "quotaAndLimits.dml.deleteRate.db.max", + Version: "2.4.1", + DefaultValue: max, + Formatter: func(v string) string { + if !p.DMLLimitEnabled.GetAsBool() { + return max + } + rate := getAsFloat(v) + if math.Abs(rate-defaultMax) > 0.001 { // maxRate != defaultMax + rate = megaBytes2Bytes(rate) + } + // [0, inf) + if rate < 0 { + 
return p.DMLMaxDeleteRate.GetValue() + } + return fmt.Sprintf("%f", rate) + }, + Doc: "MB/s, default no limit", + Export: true, + } + p.DMLMaxDeleteRatePerDB.Init(base.mgr) + + p.DMLMinDeleteRatePerDB = ParamItem{ + Key: "quotaAndLimits.dml.deleteRate.db.min", + Version: "2.4.1", + DefaultValue: min, + Formatter: func(v string) string { + if !p.DMLLimitEnabled.GetAsBool() { + return min + } + rate := megaBytes2Bytes(getAsFloat(v)) + // [0, inf) + if rate < 0 { + return min + } + if !p.checkMinMaxLegal(rate, p.DMLMaxDeleteRatePerDB.GetAsFloat()) { + return min + } + return fmt.Sprintf("%f", rate) + }, + } + p.DMLMinDeleteRatePerDB.Init(base.mgr) + p.DMLMaxDeleteRatePerCollection = ParamItem{ Key: "quotaAndLimits.dml.deleteRate.collection.max", Version: "2.2.9", @@ -579,6 +928,50 @@ The maximum rate will not be greater than ` + "max" + `.`, } p.DMLMinDeleteRatePerCollection.Init(base.mgr) + p.DMLMaxDeleteRatePerPartition = ParamItem{ + Key: "quotaAndLimits.dml.deleteRate.partition.max", + Version: "2.4.1", + DefaultValue: max, + Formatter: func(v string) string { + if !p.DMLLimitEnabled.GetAsBool() { + return max + } + rate := getAsFloat(v) + if math.Abs(rate-defaultMax) > 0.001 { // maxRate != defaultMax + rate = megaBytes2Bytes(rate) + } + // [0, inf) + if rate < 0 { + return p.DMLMaxDeleteRate.GetValue() + } + return fmt.Sprintf("%f", rate) + }, + Doc: "MB/s, default no limit", + Export: true, + } + p.DMLMaxDeleteRatePerPartition.Init(base.mgr) + + p.DMLMinDeleteRatePerPartition = ParamItem{ + Key: "quotaAndLimits.dml.deleteRate.partition.min", + Version: "2.4.1", + DefaultValue: min, + Formatter: func(v string) string { + if !p.DMLLimitEnabled.GetAsBool() { + return min + } + rate := megaBytes2Bytes(getAsFloat(v)) + // [0, inf) + if rate < 0 { + return min + } + if !p.checkMinMaxLegal(rate, p.DMLMaxDeleteRatePerPartition.GetAsFloat()) { + return min + } + return fmt.Sprintf("%f", rate) + }, + } + p.DMLMinDeleteRatePerPartition.Init(base.mgr) + p.DMLMaxBulkLoadRate = ParamItem{ Key: "quotaAndLimits.dml.bulkLoadRate.max", Version: "2.2.0", @@ -623,6 +1016,50 @@ The maximum rate will not be greater than ` + "max" + `.`, } p.DMLMinBulkLoadRate.Init(base.mgr) + p.DMLMaxBulkLoadRatePerDB = ParamItem{ + Key: "quotaAndLimits.dml.bulkLoadRate.db.max", + Version: "2.4.1", + DefaultValue: max, + Formatter: func(v string) string { + if !p.DMLLimitEnabled.GetAsBool() { + return max + } + rate := getAsFloat(v) + if math.Abs(rate-defaultMax) > 0.001 { // maxRate != defaultMax + rate = megaBytes2Bytes(rate) + } + // [0, inf) + if rate < 0 { + return p.DMLMaxBulkLoadRate.GetValue() + } + return fmt.Sprintf("%f", rate) + }, + Doc: "MB/s, default no limit, not support yet. 
TODO: limit db bulkLoad rate", + Export: true, + } + p.DMLMaxBulkLoadRatePerDB.Init(base.mgr) + + p.DMLMinBulkLoadRatePerDB = ParamItem{ + Key: "quotaAndLimits.dml.bulkLoadRate.db.min", + Version: "2.4.1", + DefaultValue: min, + Formatter: func(v string) string { + if !p.DMLLimitEnabled.GetAsBool() { + return min + } + rate := megaBytes2Bytes(getAsFloat(v)) + // [0, inf) + if rate < 0 { + return min + } + if !p.checkMinMaxLegal(rate, p.DMLMaxBulkLoadRatePerDB.GetAsFloat()) { + return min + } + return fmt.Sprintf("%f", rate) + }, + } + p.DMLMinBulkLoadRatePerDB.Init(base.mgr) + p.DMLMaxBulkLoadRatePerCollection = ParamItem{ Key: "quotaAndLimits.dml.bulkLoadRate.collection.max", Version: "2.2.9", @@ -667,6 +1104,50 @@ The maximum rate will not be greater than ` + "max" + `.`, } p.DMLMinBulkLoadRatePerCollection.Init(base.mgr) + p.DMLMaxBulkLoadRatePerPartition = ParamItem{ + Key: "quotaAndLimits.dml.bulkLoadRate.partition.max", + Version: "2.4.1", + DefaultValue: max, + Formatter: func(v string) string { + if !p.DMLLimitEnabled.GetAsBool() { + return max + } + rate := getAsFloat(v) + if math.Abs(rate-defaultMax) > 0.001 { // maxRate != defaultMax + rate = megaBytes2Bytes(rate) + } + // [0, inf) + if rate < 0 { + return p.DMLMaxBulkLoadRate.GetValue() + } + return fmt.Sprintf("%f", rate) + }, + Doc: "MB/s, default no limit, not support yet. TODO: limit partition bulkLoad rate", + Export: true, + } + p.DMLMaxBulkLoadRatePerPartition.Init(base.mgr) + + p.DMLMinBulkLoadRatePerPartition = ParamItem{ + Key: "quotaAndLimits.dml.bulkLoadRate.partition.min", + Version: "2.4.1", + DefaultValue: min, + Formatter: func(v string) string { + if !p.DMLLimitEnabled.GetAsBool() { + return min + } + rate := megaBytes2Bytes(getAsFloat(v)) + // [0, inf) + if rate < 0 { + return min + } + if !p.checkMinMaxLegal(rate, p.DMLMaxBulkLoadRatePerPartition.GetAsFloat()) { + return min + } + return fmt.Sprintf("%f", rate) + }, + } + p.DMLMinBulkLoadRatePerPartition.Init(base.mgr) + // dql p.DQLLimitEnabled = ParamItem{ Key: "quotaAndLimits.dql.enabled", @@ -718,6 +1199,46 @@ The maximum rate will not be greater than ` + "max" + `.`, } p.DQLMinSearchRate.Init(base.mgr) + p.DQLMaxSearchRatePerDB = ParamItem{ + Key: "quotaAndLimits.dql.searchRate.db.max", + Version: "2.4.1", + DefaultValue: max, + Formatter: func(v string) string { + if !p.DQLLimitEnabled.GetAsBool() { + return max + } + // [0, inf) + if getAsFloat(v) < 0 { + return p.DQLMaxSearchRate.GetValue() + } + return v + }, + Doc: "vps (vectors per second), default no limit", + Export: true, + } + p.DQLMaxSearchRatePerDB.Init(base.mgr) + + p.DQLMinSearchRatePerDB = ParamItem{ + Key: "quotaAndLimits.dql.searchRate.db.min", + Version: "2.4.1", + DefaultValue: min, + Formatter: func(v string) string { + if !p.DQLLimitEnabled.GetAsBool() { + return min + } + rate := getAsFloat(v) + // [0, inf) + if rate < 0 { + return min + } + if !p.checkMinMaxLegal(rate, p.DQLMaxSearchRatePerDB.GetAsFloat()) { + return min + } + return v + }, + } + p.DQLMinSearchRatePerDB.Init(base.mgr) + p.DQLMaxSearchRatePerCollection = ParamItem{ Key: "quotaAndLimits.dql.searchRate.collection.max", Version: "2.2.9", @@ -758,6 +1279,46 @@ The maximum rate will not be greater than ` + "max" + `.`, } p.DQLMinSearchRatePerCollection.Init(base.mgr) + p.DQLMaxSearchRatePerPartition = ParamItem{ + Key: "quotaAndLimits.dql.searchRate.partition.max", + Version: "2.4.1", + DefaultValue: max, + Formatter: func(v string) string { + if !p.DQLLimitEnabled.GetAsBool() { + return max + } + // [0, inf) + if 
getAsFloat(v) < 0 { + return p.DQLMaxSearchRate.GetValue() + } + return v + }, + Doc: "vps (vectors per second), default no limit", + Export: true, + } + p.DQLMaxSearchRatePerPartition.Init(base.mgr) + + p.DQLMinSearchRatePerPartition = ParamItem{ + Key: "quotaAndLimits.dql.searchRate.partition.min", + Version: "2.4.1", + DefaultValue: min, + Formatter: func(v string) string { + if !p.DQLLimitEnabled.GetAsBool() { + return min + } + rate := getAsFloat(v) + // [0, inf) + if rate < 0 { + return min + } + if !p.checkMinMaxLegal(rate, p.DQLMaxSearchRatePerPartition.GetAsFloat()) { + return min + } + return v + }, + } + p.DQLMinSearchRatePerPartition.Init(base.mgr) + p.DQLMaxQueryRate = ParamItem{ Key: "quotaAndLimits.dql.queryRate.max", Version: "2.2.0", @@ -798,6 +1359,46 @@ The maximum rate will not be greater than ` + "max" + `.`, } p.DQLMinQueryRate.Init(base.mgr) + p.DQLMaxQueryRatePerDB = ParamItem{ + Key: "quotaAndLimits.dql.queryRate.db.max", + Version: "2.4.1", + DefaultValue: max, + Formatter: func(v string) string { + if !p.DQLLimitEnabled.GetAsBool() { + return max + } + // [0, inf) + if getAsFloat(v) < 0 { + return p.DQLMaxQueryRate.GetValue() + } + return v + }, + Doc: "qps, default no limit", + Export: true, + } + p.DQLMaxQueryRatePerDB.Init(base.mgr) + + p.DQLMinQueryRatePerDB = ParamItem{ + Key: "quotaAndLimits.dql.queryRate.db.min", + Version: "2.4.1", + DefaultValue: min, + Formatter: func(v string) string { + if !p.DQLLimitEnabled.GetAsBool() { + return min + } + rate := getAsFloat(v) + // [0, inf) + if rate < 0 { + return min + } + if !p.checkMinMaxLegal(rate, p.DQLMaxQueryRatePerDB.GetAsFloat()) { + return min + } + return v + }, + } + p.DQLMinQueryRatePerDB.Init(base.mgr) + p.DQLMaxQueryRatePerCollection = ParamItem{ Key: "quotaAndLimits.dql.queryRate.collection.max", Version: "2.2.9", @@ -838,6 +1439,46 @@ The maximum rate will not be greater than ` + "max" + `.`, } p.DQLMinQueryRatePerCollection.Init(base.mgr) + p.DQLMaxQueryRatePerPartition = ParamItem{ + Key: "quotaAndLimits.dql.queryRate.partition.max", + Version: "2.4.1", + DefaultValue: max, + Formatter: func(v string) string { + if !p.DQLLimitEnabled.GetAsBool() { + return max + } + // [0, inf) + if getAsFloat(v) < 0 { + return p.DQLMaxQueryRate.GetValue() + } + return v + }, + Doc: "qps, default no limit", + Export: true, + } + p.DQLMaxQueryRatePerPartition.Init(base.mgr) + + p.DQLMinQueryRatePerPartition = ParamItem{ + Key: "quotaAndLimits.dql.queryRate.partition.min", + Version: "2.4.1", + DefaultValue: min, + Formatter: func(v string) string { + if !p.DQLLimitEnabled.GetAsBool() { + return min + } + rate := getAsFloat(v) + // [0, inf) + if rate < 0 { + return min + } + if !p.checkMinMaxLegal(rate, p.DQLMaxQueryRatePerPartition.GetAsFloat()) { + return min + } + return v + }, + } + p.DQLMinQueryRatePerPartition.Init(base.mgr) + // limits p.MaxCollectionNum = ParamItem{ Key: "quotaAndLimits.limits.maxCollectionNum", @@ -1132,6 +1773,27 @@ but the rate will not be lower than minRateRatio * dmlRate.`, } p.DiskQuota.Init(base.mgr) + p.DiskQuotaPerDB = ParamItem{ + Key: "quotaAndLimits.limitWriting.diskProtection.diskQuotaPerDB", + Version: "2.4.1", + DefaultValue: quota, + Formatter: func(v string) string { + if !p.DiskProtectionEnabled.GetAsBool() { + return max + } + level := getAsFloat(v) + // (0, +inf) + if level <= 0 { + return p.DiskQuota.GetValue() + } + // megabytes to bytes + return fmt.Sprintf("%f", megaBytes2Bytes(level)) + }, + Doc: "MB, (0, +inf), default no limit", + Export: true, + } + 
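+	// bind the new per-DB disk quota item to the param manager, mirroring the existing DiskQuota and DiskQuotaPerCollection items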
p.DiskQuotaPerDB.Init(base.mgr) + p.DiskQuotaPerCollection = ParamItem{ Key: "quotaAndLimits.limitWriting.diskProtection.diskQuotaPerCollection", Version: "2.2.8", @@ -1153,6 +1815,27 @@ but the rate will not be lower than minRateRatio * dmlRate.`, } p.DiskQuotaPerCollection.Init(base.mgr) + p.DiskQuotaPerPartition = ParamItem{ + Key: "quotaAndLimits.limitWriting.diskProtection.diskQuotaPerPartition", + Version: "2.4.1", + DefaultValue: quota, + Formatter: func(v string) string { + if !p.DiskProtectionEnabled.GetAsBool() { + return max + } + level := getAsFloat(v) + // (0, +inf) + if level <= 0 { + return p.DiskQuota.GetValue() + } + // megabytes to bytes + return fmt.Sprintf("%f", megaBytes2Bytes(level)) + }, + Doc: "MB, (0, +inf), default no limit", + Export: true, + } + p.DiskQuotaPerPartition.Init(base.mgr) + // limit reading p.ForceDenyReading = ParamItem{ Key: "quotaAndLimits.limitReading.forceDeny", @@ -1253,6 +1936,50 @@ MB/s, default no limit`, } p.MaxReadResultRate.Init(base.mgr) + p.MaxReadResultRatePerDB = ParamItem{ + Key: "quotaAndLimits.limitReading.resultProtection.maxReadResultRatePerDB", + Version: "2.4.1", + DefaultValue: max, + Formatter: func(v string) string { + if !p.ResultProtectionEnabled.GetAsBool() { + return max + } + rate := getAsFloat(v) + if math.Abs(rate-defaultMax) > 0.001 { // maxRate != defaultMax + return fmt.Sprintf("%f", megaBytes2Bytes(rate)) + } + // [0, inf) + if rate < 0 { + return max + } + return v + }, + Export: true, + } + p.MaxReadResultRatePerDB.Init(base.mgr) + + p.MaxReadResultRatePerCollection = ParamItem{ + Key: "quotaAndLimits.limitReading.resultProtection.maxReadResultRatePerCollection", + Version: "2.4.1", + DefaultValue: max, + Formatter: func(v string) string { + if !p.ResultProtectionEnabled.GetAsBool() { + return max + } + rate := getAsFloat(v) + if math.Abs(rate-defaultMax) > 0.001 { // maxRate != defaultMax + return fmt.Sprintf("%f", megaBytes2Bytes(rate)) + } + // [0, inf) + if rate < 0 { + return max + } + return v + }, + Export: true, + } + p.MaxReadResultRatePerCollection.Init(base.mgr) + const defaultSpeed = "0.9" p.CoolOffSpeed = ParamItem{ Key: "quotaAndLimits.limitReading.coolOffSpeed", diff --git a/pkg/util/ratelimitutil/limiter.go b/pkg/util/ratelimitutil/limiter.go index 6528f2bf49..d2f95b31b6 100644 --- a/pkg/util/ratelimitutil/limiter.go +++ b/pkg/util/ratelimitutil/limiter.go @@ -45,12 +45,13 @@ const Inf = Limit(math.MaxFloat64) // in bucket may be negative, and the latter events would be "punished", // any event should wait for the tokens to be filled to greater or equal to 0. type Limiter struct { - mu sync.Mutex + mu sync.RWMutex limit Limit burst float64 tokens float64 // last is the last time the limiter's tokens field was updated - last time.Time + last time.Time + hasUpdated bool } // NewLimiter returns a new Limiter that allows events up to rate r. @@ -63,13 +64,20 @@ func NewLimiter(r Limit, b float64) *Limiter { // Limit returns the maximum overall event rate. func (lim *Limiter) Limit() Limit { - lim.mu.Lock() - defer lim.mu.Unlock() + lim.mu.RLock() + defer lim.mu.RUnlock() return lim.limit } // AllowN reports whether n events may happen at time now. func (lim *Limiter) AllowN(now time.Time, n int) bool { + lim.mu.RLock() + if lim.limit == Inf { + lim.mu.RUnlock() + return true + } + lim.mu.RUnlock() + lim.mu.Lock() defer lim.mu.Unlock() @@ -119,6 +127,7 @@ func (lim *Limiter) SetLimit(newLimit Limit) { // use rate as burst, because Limiter is with punishment mechanism, burst is insignificant. 
lim.burst = float64(newLimit) } + lim.hasUpdated = true } // Cancel the AllowN operation and refund the tokens that have already been deducted by the limiter. @@ -128,6 +137,12 @@ func (lim *Limiter) Cancel(n int) { lim.tokens += float64(n) } +func (lim *Limiter) HasUpdated() bool { + lim.mu.RLock() + defer lim.mu.RUnlock() + return lim.hasUpdated +} + // advance calculates and returns an updated state for lim resulting from the passage of time. // lim is not changed. advance requires that lim.mu is held. func (lim *Limiter) advance(now time.Time) (newNow time.Time, newLast time.Time, newTokens float64) { diff --git a/pkg/util/ratelimitutil/rate_collector.go b/pkg/util/ratelimitutil/rate_collector.go index 0608c5924d..bf943b2512 100644 --- a/pkg/util/ratelimitutil/rate_collector.go +++ b/pkg/util/ratelimitutil/rate_collector.go @@ -19,8 +19,11 @@ package ratelimitutil import ( "fmt" "math" + "strings" "sync" "time" + + "github.com/samber/lo" ) const ( @@ -34,21 +37,22 @@ const ( type RateCollector struct { sync.Mutex - window time.Duration - granularity time.Duration - position int - values map[string][]float64 + window time.Duration + granularity time.Duration + position int + values map[string][]float64 + deprecatedSubLabels []lo.Tuple2[string, string] last time.Time } // NewRateCollector is shorthand for newRateCollector(window, granularity, time.Now()). -func NewRateCollector(window time.Duration, granularity time.Duration) (*RateCollector, error) { - return newRateCollector(window, granularity, time.Now()) +func NewRateCollector(window time.Duration, granularity time.Duration, enableSubLabel bool) (*RateCollector, error) { + return newRateCollector(window, granularity, time.Now(), enableSubLabel) } // newRateCollector returns a new RateCollector with given window and granularity. -func newRateCollector(window time.Duration, granularity time.Duration, now time.Time) (*RateCollector, error) { +func newRateCollector(window time.Duration, granularity time.Duration, now time.Time, enableSubLabel bool) (*RateCollector, error) { if window == 0 || granularity == 0 { return nil, fmt.Errorf("create RateCollector failed, window or granularity cannot be 0, window = %d, granularity = %d", window, granularity) } @@ -62,9 +66,52 @@ func newRateCollector(window time.Duration, granularity time.Duration, now time. values: make(map[string][]float64), last: now, } + + if enableSubLabel { + go rc.cleanDeprecateSubLabels() + } return rc, nil } +func (r *RateCollector) cleanDeprecateSubLabels() { + tick := time.NewTicker(r.window * 2) + defer tick.Stop() + for range tick.C { + r.Lock() + for _, labelInfo := range r.deprecatedSubLabels { + r.removeSubLabel(labelInfo) + } + r.Unlock() + } +} + +func (r *RateCollector) removeSubLabel(labelInfo lo.Tuple2[string, string]) { + label := labelInfo.A + subLabel := labelInfo.B + if subLabel == "" { + return + } + removeKeys := make([]string, 1) + removeKeys[0] = FormatSubLabel(label, subLabel) + + deleteCollectionSubLabelWithPrefix := func(dbName string) { + for key := range r.values { + if strings.HasPrefix(key, FormatSubLabel(label, GetCollectionSubLabel(dbName, ""))) { + removeKeys = append(removeKeys, key) + } + } + } + + parts := strings.Split(subLabel, ".") + if strings.HasPrefix(subLabel, GetDBSubLabel("")) { + dbName := parts[1] + deleteCollectionSubLabelWithPrefix(dbName) + } + for _, key := range removeKeys { + delete(r.values, key) + } +} + // Register init values of RateCollector for specified label. 
func (r *RateCollector) Register(label string) { r.Lock() @@ -81,21 +128,77 @@ func (r *RateCollector) Deregister(label string) { delete(r.values, label) } +func GetDBSubLabel(dbName string) string { + return fmt.Sprintf("db.%s", dbName) +} + +func GetCollectionSubLabel(dbName, collectionName string) string { + return fmt.Sprintf("collection.%s.%s", dbName, collectionName) +} + +func FormatSubLabel(label, subLabel string) string { + return fmt.Sprintf("%s-%s", label, subLabel) +} + +func GetDBFromSubLabel(label, fullLabel string) (string, bool) { + if !strings.HasPrefix(fullLabel, FormatSubLabel(label, GetDBSubLabel(""))) { + return "", false + } + return fullLabel[len(FormatSubLabel(label, GetDBSubLabel(""))):], true +} + +func GetCollectionFromSubLabel(label, fullLabel string) (string, string, bool) { + if !strings.HasPrefix(fullLabel, FormatSubLabel(label, "")) { + return "", "", false + } + subLabels := strings.Split(fullLabel[len(FormatSubLabel(label, "")):], ".") + if len(subLabels) != 3 || subLabels[0] != "collection" { + return "", "", false + } + + return subLabels[1], subLabels[2], true +} + +func (r *RateCollector) DeregisterSubLabel(label, subLabel string) { + r.Lock() + defer r.Unlock() + r.deprecatedSubLabels = append(r.deprecatedSubLabels, lo.Tuple2[string, string]{ + A: label, + B: subLabel, + }) +} + // Add is shorthand for add(label, value, time.Now()). -func (r *RateCollector) Add(label string, value float64) { - r.add(label, value, time.Now()) +func (r *RateCollector) Add(label string, value float64, subLabels ...string) { + r.add(label, value, time.Now(), subLabels...) } // add increases the current value of specified label. -func (r *RateCollector) add(label string, value float64, now time.Time) { +func (r *RateCollector) add(label string, value float64, now time.Time, subLabels ...string) { r.Lock() defer r.Unlock() r.update(now) if _, ok := r.values[label]; ok { r.values[label][r.position] += value + for _, subLabel := range subLabels { + r.unsafeAddForSubLabels(label, subLabel, value) + } } } +func (r *RateCollector) unsafeAddForSubLabels(label, subLabel string, value float64) { + if subLabel == "" { + return + } + sub := FormatSubLabel(label, subLabel) + if _, ok := r.values[sub]; ok { + r.values[sub][r.position] += value + return + } + r.values[sub] = make([]float64, int(r.window/r.granularity)) + r.values[sub][r.position] = value +} + // Max is shorthand for max(label, time.Now()). func (r *RateCollector) Max(label string, now time.Time) (float64, error) { return r.max(label, time.Now()) @@ -145,6 +248,26 @@ func (r *RateCollector) Rate(label string, duration time.Duration) (float64, err return r.rate(label, duration, time.Now()) } +func (r *RateCollector) RateSubLabel(label string, duration time.Duration) (map[string]float64, error) { + subLabelPrefix := FormatSubLabel(label, "") + subLabels := make(map[string]float64) + r.Lock() + for s := range r.values { + if strings.HasPrefix(s, subLabelPrefix) { + subLabels[s] = 0 + } + } + r.Unlock() + for s := range subLabels { + v, err := r.rate(s, duration, time.Now()) + if err != nil { + return nil, err + } + subLabels[s] = v + } + return subLabels, nil +} + // rate returns the latest mean value of the specified duration. 
func (r *RateCollector) rate(label string, duration time.Duration, now time.Time) (float64, error) { if duration > r.window { diff --git a/pkg/util/ratelimitutil/rate_collector_test.go b/pkg/util/ratelimitutil/rate_collector_test.go index 039a9c4853..068b1a4fe5 100644 --- a/pkg/util/ratelimitutil/rate_collector_test.go +++ b/pkg/util/ratelimitutil/rate_collector_test.go @@ -22,6 +22,7 @@ import ( "testing" "time" + "github.com/samber/lo" "github.com/stretchr/testify/assert" ) @@ -36,7 +37,7 @@ func TestRateCollector(t *testing.T) { ts100 = ts0.Add(time.Duration(100.0 * float64(time.Second))) ) - rc, err := newRateCollector(DefaultWindow, DefaultGranularity, ts0) + rc, err := newRateCollector(DefaultWindow, DefaultGranularity, ts0, false) assert.NoError(t, err) label := "mock_label" rc.Register(label) @@ -78,7 +79,7 @@ func TestRateCollector(t *testing.T) { ts31 = ts0.Add(time.Duration(3.1 * float64(time.Second))) ) - rc, err := newRateCollector(DefaultWindow, DefaultGranularity, ts0) + rc, err := newRateCollector(DefaultWindow, DefaultGranularity, ts0, false) assert.NoError(t, err) label := "mock_label" rc.Register(label) @@ -105,7 +106,7 @@ func TestRateCollector(t *testing.T) { start := tt.now() end := start.Add(testPeriod * time.Second) - rc, err := newRateCollector(DefaultWindow, DefaultGranularity, start) + rc, err := newRateCollector(DefaultWindow, DefaultGranularity, start, false) assert.NoError(t, err) label := "mock_label" rc.Register(label) @@ -138,3 +139,111 @@ func TestRateCollector(t *testing.T) { } }) } + +func TestRateSubLabel(t *testing.T) { + rateCollector, err := NewRateCollector(5*time.Second, time.Second, true) + assert.NoError(t, err) + + var ( + label = "search" + db = "hoo" + collection = "foo" + dbSubLabel = GetDBSubLabel(db) + collectionSubLabel = GetCollectionSubLabel(db, collection) + ts0 = time.Now() + ts10 = ts0.Add(time.Duration(1.0 * float64(time.Second))) + ts19 = ts0.Add(time.Duration(1.9 * float64(time.Second))) + ts20 = ts0.Add(time.Duration(2.0 * float64(time.Second))) + ts30 = ts0.Add(time.Duration(3.0 * float64(time.Second))) + ts40 = ts0.Add(time.Duration(4.0 * float64(time.Second))) + ) + + rateCollector.Register(label) + defer rateCollector.Deregister(label) + rateCollector.add(label, 10, ts0, dbSubLabel, collectionSubLabel) + rateCollector.add(label, 20, ts10, dbSubLabel, collectionSubLabel) + rateCollector.add(label, 30, ts19, dbSubLabel, collectionSubLabel) + rateCollector.add(label, 40, ts20, dbSubLabel, collectionSubLabel) + rateCollector.add(label, 50, ts30, dbSubLabel, collectionSubLabel) + rateCollector.add(label, 60, ts40, dbSubLabel, collectionSubLabel) + + time.Sleep(4 * time.Second) + + // 10 20+30 40 50 60 + { + avg, err := rateCollector.Rate(label, 3*time.Second) + assert.NoError(t, err) + assert.Equal(t, float64(50), avg) + } + { + avg, err := rateCollector.Rate(label, 5*time.Second) + assert.NoError(t, err) + assert.Equal(t, float64(42), avg) + } + { + avgs, err := rateCollector.RateSubLabel(label, 3*time.Second) + assert.NoError(t, err) + assert.Equal(t, 2, len(avgs)) + assert.Equal(t, float64(50), avgs[FormatSubLabel(label, dbSubLabel)]) + assert.Equal(t, float64(50), avgs[FormatSubLabel(label, collectionSubLabel)]) + } + + rateCollector.Add(label, 10, GetCollectionSubLabel(db, collection)) + rateCollector.Add(label, 10, GetCollectionSubLabel(db, "col2")) + + rateCollector.DeregisterSubLabel(label, GetCollectionSubLabel(db, "col2")) + rateCollector.DeregisterSubLabel(label, dbSubLabel) + + 
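+	// DeregisterSubLabel only queues sub-labels as deprecated; removeSubLabel with an empty sub-label is a no-op, and the queued entries are drained manually below instead of waiting for the background cleaner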
rateCollector.removeSubLabel(lo.Tuple2[string, string]{ + A: "aaa", + }) + + rateCollector.Lock() + for _, labelInfo := range rateCollector.deprecatedSubLabels { + rateCollector.removeSubLabel(labelInfo) + } + rateCollector.Unlock() + + { + _, ok := rateCollector.values[FormatSubLabel(label, dbSubLabel)] + assert.False(t, ok) + } + + { + _, ok := rateCollector.values[FormatSubLabel(label, collectionSubLabel)] + assert.False(t, ok) + } + + { + assert.Len(t, rateCollector.values, 1) + _, ok := rateCollector.values[label] + assert.True(t, ok) + } +} + +func TestLabelUtil(t *testing.T) { + assert.Equal(t, GetDBSubLabel("db"), "db.db") + assert.Equal(t, GetCollectionSubLabel("db", "collection"), "collection.db.collection") + { + db, ok := GetDBFromSubLabel("foo", FormatSubLabel("foo", GetDBSubLabel("db1"))) + assert.True(t, ok) + assert.Equal(t, "db1", db) + } + + { + _, ok := GetDBFromSubLabel("foo", "aaa") + assert.False(t, ok) + } + + { + db, col, ok := GetCollectionFromSubLabel("foo", FormatSubLabel("foo", GetCollectionSubLabel("db1", "col1"))) + assert.True(t, ok) + assert.Equal(t, "col1", col) + assert.Equal(t, "db1", db) + } + + { + _, _, ok := GetCollectionFromSubLabel("foo", "aaa") + assert.False(t, ok) + } +} diff --git a/pkg/util/ratelimitutil/utils.go b/pkg/util/ratelimitutil/utils.go new file mode 100644 index 0000000000..d72e28c22e --- /dev/null +++ b/pkg/util/ratelimitutil/utils.go @@ -0,0 +1,30 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package ratelimitutil + +import "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + +var QuotaErrorString = map[commonpb.ErrorCode]string{ + commonpb.ErrorCode_ForceDeny: "the writing has been deactivated by the administrator", + commonpb.ErrorCode_MemoryQuotaExhausted: "memory quota exceeded, please allocate more resources", + commonpb.ErrorCode_DiskQuotaExhausted: "disk quota exceeded, please allocate more resources", + commonpb.ErrorCode_TimeTickLongDelay: "time tick long delay", +} + +func GetQuotaErrorString(errCode commonpb.ErrorCode) string { + return QuotaErrorString[errCode] +} diff --git a/pkg/util/ratelimitutil/utils_test.go b/pkg/util/ratelimitutil/utils_test.go new file mode 100644 index 0000000000..4c0a7dc3ac --- /dev/null +++ b/pkg/util/ratelimitutil/utils_test.go @@ -0,0 +1,43 @@ +package ratelimitutil + +import ( + "testing" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" +) + +func TestGetQuotaErrorString(t *testing.T) { + tests := []struct { + name string + args commonpb.ErrorCode + want string + }{ + { + name: "Test ErrorCode_ForceDeny", + args: commonpb.ErrorCode_ForceDeny, + want: "the writing has been deactivated by the administrator", + }, + { + name: "Test ErrorCode_MemoryQuotaExhausted", + args: commonpb.ErrorCode_MemoryQuotaExhausted, + want: "memory quota exceeded, please allocate more resources", + }, + { + name: "Test ErrorCode_DiskQuotaExhausted", + args: commonpb.ErrorCode_DiskQuotaExhausted, + want: "disk quota exceeded, please allocate more resources", + }, + { + name: "Test ErrorCode_TimeTickLongDelay", + args: commonpb.ErrorCode_TimeTickLongDelay, + want: "time tick long delay", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := GetQuotaErrorString(tt.args); got != tt.want { + t.Errorf("GetQuotaErrorString() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/pkg/util/typeutil/set_test.go b/pkg/util/typeutil/set_test.go index 82bb69b81b..fafc84f975 100644 --- a/pkg/util/typeutil/set_test.go +++ b/pkg/util/typeutil/set_test.go @@ -30,6 +30,16 @@ func TestUniqueSet(t *testing.T) { assert.True(t, set.Contain(9)) assert.True(t, set.Contain(5, 7, 9)) + containFive := false + set.Range(func(i UniqueID) bool { + if i == 5 { + containFive = true + return false + } + return true + }) + assert.True(t, containFive) + set.Remove(7) assert.True(t, set.Contain(5)) assert.False(t, set.Contain(7))