diff --git a/internal/proxy/simple_rate_limiter.go b/internal/proxy/simple_rate_limiter.go
index d9bbc46f6d..10ca38dcf3 100644
--- a/internal/proxy/simple_rate_limiter.go
+++ b/internal/proxy/simple_rate_limiter.go
@@ -194,13 +194,13 @@ func (m *SimpleLimiter) SetRates(rootLimiter *proxypb.LimiterNode) error {
 		collectionConfigs = getDefaultLimiterConfig(internalpb.RateScope_Collection)
 		partitionConfigs  = getDefaultLimiterConfig(internalpb.RateScope_Partition)
 	)
-	initLimiter(m.rateLimiter.GetRootLimiters(), clusterConfigs)
-	m.rateLimiter.GetRootLimiters().GetChildren().Range(func(_ int64, dbLimiter *rlinternal.RateLimiterNode) bool {
-		initLimiter(dbLimiter, databaseConfigs)
-		dbLimiter.GetChildren().Range(func(_ int64, collLimiter *rlinternal.RateLimiterNode) bool {
-			initLimiter(collLimiter, collectionConfigs)
-			collLimiter.GetChildren().Range(func(_ int64, partitionLimiter *rlinternal.RateLimiterNode) bool {
-				initLimiter(partitionLimiter, partitionConfigs)
+	initLimiter("cluster", m.rateLimiter.GetRootLimiters(), clusterConfigs)
+	m.rateLimiter.GetRootLimiters().GetChildren().Range(func(dbID int64, dbLimiter *rlinternal.RateLimiterNode) bool {
+		initLimiter(fmt.Sprintf("db-%d", dbID), dbLimiter, databaseConfigs)
+		dbLimiter.GetChildren().Range(func(collectionID int64, collLimiter *rlinternal.RateLimiterNode) bool {
+			initLimiter(fmt.Sprintf("collection-%d", collectionID), collLimiter, collectionConfigs)
+			collLimiter.GetChildren().Range(func(partitionID int64, partitionLimiter *rlinternal.RateLimiterNode) bool {
+				initLimiter(fmt.Sprintf("partition-%d", partitionID), partitionLimiter, partitionConfigs)
 				return true
 			})
 			return true
@@ -216,7 +216,7 @@ func (m *SimpleLimiter) SetRates(rootLimiter *proxypb.LimiterNode) error {
 	return nil
 }
 
-func initLimiter(rln *rlinternal.RateLimiterNode, rateLimiterConfigs map[internalpb.RateType]*paramtable.ParamItem) {
+func initLimiter(source string, rln *rlinternal.RateLimiterNode, rateLimiterConfigs map[internalpb.RateType]*paramtable.ParamItem) {
 	for rt, p := range rateLimiterConfigs {
 		newLimit := ratelimitutil.Limit(p.GetAsFloat())
 		burst := p.GetAsFloat() // use rate as burst, because SimpleLimiter is with punishment mechanism, burst is insignificant.
@@ -233,6 +233,7 @@ func initLimiter(rln *rlinternal.RateLimiterNode, rateLimiterConfigs map[interna
 		}
 		if updated {
 			log.Debug("RateLimiter register for rateType",
+				zap.String("source", source),
 				zap.String("rateType", internalpb.RateType_name[(int32(rt))]),
 				zap.String("rateLimit", newLimit.String()),
 				zap.String("burst", fmt.Sprintf("%v", burst)))
@@ -246,28 +247,28 @@ func initLimiter(rln *rlinternal.RateLimiterNode, rateLimiterConfigs map[interna
 func newClusterLimiter() *rlinternal.RateLimiterNode {
 	clusterRateLimiters := rlinternal.NewRateLimiterNode(internalpb.RateScope_Cluster)
 	clusterLimiterConfigs := getDefaultLimiterConfig(internalpb.RateScope_Cluster)
-	initLimiter(clusterRateLimiters, clusterLimiterConfigs)
+	initLimiter(internalpb.RateScope_Cluster.String(), clusterRateLimiters, clusterLimiterConfigs)
 	return clusterRateLimiters
 }
 
 func newDatabaseLimiter() *rlinternal.RateLimiterNode {
 	dbRateLimiters := rlinternal.NewRateLimiterNode(internalpb.RateScope_Database)
 	databaseLimiterConfigs := getDefaultLimiterConfig(internalpb.RateScope_Database)
-	initLimiter(dbRateLimiters, databaseLimiterConfigs)
+	initLimiter(internalpb.RateScope_Database.String(), dbRateLimiters, databaseLimiterConfigs)
 	return dbRateLimiters
 }
 
 func newCollectionLimiters() *rlinternal.RateLimiterNode {
 	collectionRateLimiters := rlinternal.NewRateLimiterNode(internalpb.RateScope_Collection)
 	collectionLimiterConfigs := getDefaultLimiterConfig(internalpb.RateScope_Collection)
-	initLimiter(collectionRateLimiters, collectionLimiterConfigs)
+	initLimiter(internalpb.RateScope_Collection.String(), collectionRateLimiters, collectionLimiterConfigs)
 	return collectionRateLimiters
 }
 
 func newPartitionLimiters() *rlinternal.RateLimiterNode {
 	partRateLimiters := rlinternal.NewRateLimiterNode(internalpb.RateScope_Partition)
 	partitionLimiterConfigs := getDefaultLimiterConfig(internalpb.RateScope_Partition)
-	initLimiter(partRateLimiters, partitionLimiterConfigs)
+	initLimiter(internalpb.RateScope_Partition.String(), partRateLimiters, partitionLimiterConfigs)
 	return partRateLimiters
 }
 
diff --git a/internal/util/ratelimitutil/rate_limiter_tree.go b/internal/util/ratelimitutil/rate_limiter_tree.go
index a2db0eb3e0..083b42ae39 100644
--- a/internal/util/ratelimitutil/rate_limiter_tree.go
+++ b/internal/util/ratelimitutil/rate_limiter_tree.go
@@ -156,6 +156,8 @@ func (rln *RateLimiterNode) GetID() int64 {
 	return rln.id
 }
 
+const clearInvalidNodeInterval = 1 * time.Minute
+
 // RateLimiterTree is implemented based on RateLimiterNode to operate multilevel rate limiters
 //
 // it contains the following four levels generally:
@@ -167,11 +169,13 @@ func (rln *RateLimiterNode) GetID() int64 {
 type RateLimiterTree struct {
 	root *RateLimiterNode
 	mu   sync.RWMutex
+
+	lastClearTime time.Time
 }
 
 // NewRateLimiterTree returns a new RateLimiterTree.
 func NewRateLimiterTree(root *RateLimiterNode) *RateLimiterTree {
-	return &RateLimiterTree{root: root}
+	return &RateLimiterTree{root: root, lastClearTime: time.Now()}
 }
 
 // GetRootLimiters get root limiters
@@ -183,6 +187,13 @@ func (m *RateLimiterTree) ClearInvalidLimiterNode(req *proxypb.LimiterNode) {
 	m.mu.Lock()
 	defer m.mu.Unlock()
 
+	if time.Since(m.lastClearTime) < clearInvalidNodeInterval {
+		return
+	}
+	defer func() {
+		m.lastClearTime = time.Now()
+	}()
+
 	reqDBLimits := req.GetChildren()
 	removeDBLimits := make([]int64, 0)
 	m.GetRootLimiters().GetChildren().Range(func(key int64, _ *RateLimiterNode) bool {
diff --git a/internal/util/ratelimitutil/rate_limiter_tree_test.go b/internal/util/ratelimitutil/rate_limiter_tree_test.go
index 85d7b2e53e..383cf07a81 100644
--- a/internal/util/ratelimitutil/rate_limiter_tree_test.go
+++ b/internal/util/ratelimitutil/rate_limiter_tree_test.go
@@ -19,6 +19,7 @@ package ratelimitutil
 import (
 	"strings"
 	"testing"
+	"time"
 
 	"github.com/cockroachdb/errors"
 	"github.com/stretchr/testify/assert"
@@ -153,6 +154,7 @@ func TestRateLimiterNodeGetQuotaExceededError(t *testing.T) {
 func TestRateLimiterTreeClearInvalidLimiterNode(t *testing.T) {
 	root := NewRateLimiterNode(internalpb.RateScope_Cluster)
 	tree := NewRateLimiterTree(root)
+	tree.lastClearTime = time.Now().Add(-1 * clearInvalidNodeInterval * 2)
 
 	generateNodeFFunc := func(level internalpb.RateScope) func() *RateLimiterNode {
 		return func() *RateLimiterNode {
diff --git a/tests/integration/ratelimit/flush_test.go b/tests/integration/ratelimit/flush_test.go
new file mode 100644
index 0000000000..04233ee329
--- /dev/null
+++ b/tests/integration/ratelimit/flush_test.go
@@ -0,0 +1,125 @@
+// Licensed to the LF AI & Data foundation under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ratelimit
+
+import (
+	"context"
+	"testing"
+
+	"github.com/stretchr/testify/suite"
+	"go.uber.org/zap"
+	"google.golang.org/protobuf/proto"
+
+	"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
+	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
+	"github.com/milvus-io/milvus/pkg/common"
+	"github.com/milvus-io/milvus/pkg/log"
+	"github.com/milvus-io/milvus/pkg/util/funcutil"
+	"github.com/milvus-io/milvus/pkg/util/merr"
+	"github.com/milvus-io/milvus/pkg/util/metric"
+	"github.com/milvus-io/milvus/tests/integration"
+)
+
+type FlushSuite struct {
+	integration.MiniClusterSuite
+
+	indexType  string
+	metricType string
+	vecType    schemapb.DataType
+}
+
+func (s *FlushSuite) TestFlush() {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+	c := s.Cluster
+
+	const (
+		dim    = 128
+		dbName = ""
+		rowNum = 3000
+	)
+
+	s.indexType = integration.IndexFaissIvfFlat
+	s.metricType = metric.L2
+	s.vecType = schemapb.DataType_FloatVector
+
+	collectionName := "TestFlush_" + funcutil.GenRandomStr()
+
+	schema := integration.ConstructSchemaOfVecDataType(collectionName, dim, true, s.vecType)
+	marshaledSchema, err := proto.Marshal(schema)
+	s.NoError(err)
+
+	createCollectionStatus, err := c.MilvusClient.CreateCollection(ctx, &milvuspb.CreateCollectionRequest{
+		DbName:         dbName,
+		CollectionName: collectionName,
+		Schema:         marshaledSchema,
+		ShardsNum:      common.DefaultShardsNum,
+	})
+	err = merr.CheckRPCCall(createCollectionStatus, err)
+	s.NoError(err)
+
+	log.Info("CreateCollection result", zap.Any("createCollectionStatus", createCollectionStatus))
+	showCollectionsResp, err := c.MilvusClient.ShowCollections(ctx, &milvuspb.ShowCollectionsRequest{})
+	err = merr.CheckRPCCall(showCollectionsResp.GetStatus(), err)
+	s.NoError(err)
+	log.Info("ShowCollections result", zap.Any("showCollectionsResp", showCollectionsResp))
+
+	var fVecColumn *schemapb.FieldData
+	if s.vecType == schemapb.DataType_SparseFloatVector {
+		fVecColumn = integration.NewSparseFloatVectorFieldData(integration.SparseFloatVecField, rowNum)
+	} else {
+		fVecColumn = integration.NewFloatVectorFieldData(integration.FloatVecField, rowNum, dim)
+	}
+	hashKeys := integration.GenerateHashKeys(rowNum)
+	insertResult, err := c.MilvusClient.Insert(ctx, &milvuspb.InsertRequest{
+		DbName:         dbName,
+		CollectionName: collectionName,
+		FieldsData:     []*schemapb.FieldData{fVecColumn},
+		HashKeys:       hashKeys,
+		NumRows:        uint32(rowNum),
+	})
+	err = merr.CheckRPCCall(insertResult.GetStatus(), err)
+	s.NoError(err)
+
+	// flush 1
+	flushResp, err := c.MilvusClient.Flush(ctx, &milvuspb.FlushRequest{
+		DbName:          dbName,
+		CollectionNames: []string{collectionName},
+	})
+	err = merr.CheckRPCCall(flushResp.GetStatus(), err)
+	s.NoError(err)
+
+	// flush 2
+	flushResp, err = c.MilvusClient.Flush(ctx, &milvuspb.FlushRequest{
+		DbName:          dbName,
+		CollectionNames: []string{collectionName},
+	})
+	s.NoError(err)
+	s.True(merr.ErrServiceRateLimit.Is(merr.Error(flushResp.GetStatus())))
+
+	status, err := c.MilvusClient.DropCollection(ctx, &milvuspb.DropCollectionRequest{
+		CollectionName: collectionName,
+	})
+	err = merr.CheckRPCCall(status, err)
+	s.NoError(err)
+
+	log.Info("TestFlush succeed")
+}
+
+func TestFlush(t *testing.T) {
+	suite.Run(t, new(FlushSuite))
+}
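
Note on the rate_limiter_tree.go change: ClearInvalidLimiterNode now returns early unless at least one minute (clearInvalidNodeInterval) has passed since the previous pass, and it stamps lastClearTime via defer only when a pass actually runs; the new integration test relies on the related flush rate limiter by expecting the second Flush call to come back with ErrServiceRateLimit. The following standalone Go sketch is not part of the patch; names such as throttledCleaner, ClearInvalid, and clearInterval are illustrative only, but the one-minute interval, the early return, the deferred timestamp update, and the test trick of back-dating the timestamp mirror what the diff does.

```go
// Standalone sketch of the time-based throttle pattern used in the patch:
// a cleanup method that is a no-op unless a minimum interval has elapsed
// since the last pass that actually ran.
package main

import (
	"fmt"
	"sync"
	"time"
)

// clearInterval mirrors clearInvalidNodeInterval from the diff (1 minute).
const clearInterval = 1 * time.Minute

type throttledCleaner struct {
	mu            sync.Mutex
	lastClearTime time.Time
}

// ClearInvalid runs cleanup at most once per clearInterval and reports
// whether it ran. As in the patch, lastClearTime is updated via defer,
// i.e. only after a pass that was not skipped.
func (c *throttledCleaner) ClearInvalid(cleanup func()) bool {
	c.mu.Lock()
	defer c.mu.Unlock()

	if time.Since(c.lastClearTime) < clearInterval {
		return false // skipped: too soon since the previous pass
	}
	defer func() { c.lastClearTime = time.Now() }()

	cleanup()
	return true
}

func main() {
	// Seed with time.Now(), as NewRateLimiterTree does after the patch.
	c := &throttledCleaner{lastClearTime: time.Now()}
	ran := c.ClearInvalid(func() { fmt.Println("cleanup ran") })
	fmt.Println("first call executed:", ran) // false: within the first minute

	// A test can force a pass by back-dating the timestamp, as
	// TestRateLimiterTreeClearInvalidLimiterNode does in the diff.
	c.lastClearTime = time.Now().Add(-2 * clearInterval)
	ran = c.ClearInvalid(func() { fmt.Println("cleanup ran") })
	fmt.Println("second call executed:", ran) // true
}
```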