diff --git a/internal/querycoordv2/balance/balance.go b/internal/querycoordv2/balance/balance.go
index eca2cc1e45..9c076e67ce 100644
--- a/internal/querycoordv2/balance/balance.go
+++ b/internal/querycoordv2/balance/balance.go
@@ -92,7 +92,7 @@ func (b *RoundRobinBalancer) AssignSegment(ctx context.Context, collectionID int
 		return cnt1+delta1 < cnt2+delta2
 	})
 
-	balanceBatchSize := paramtable.Get().QueryCoordCfg.CollectionBalanceSegmentBatchSize.GetAsInt()
+	balanceBatchSize := paramtable.Get().QueryCoordCfg.BalanceSegmentBatchSize.GetAsInt()
 	ret := make([]SegmentAssignPlan, 0, len(segments))
 	for i, s := range segments {
 		plan := SegmentAssignPlan{
diff --git a/internal/querycoordv2/balance/rowcount_based_balancer.go b/internal/querycoordv2/balance/rowcount_based_balancer.go
index abfdaa5e04..7e41e31048 100644
--- a/internal/querycoordv2/balance/rowcount_based_balancer.go
+++ b/internal/querycoordv2/balance/rowcount_based_balancer.go
@@ -65,7 +65,7 @@ func (b *RowCountBasedBalancer) AssignSegment(ctx context.Context, collectionID
 		return segments[i].GetNumOfRows() > segments[j].GetNumOfRows()
 	})
 
-	balanceBatchSize := paramtable.Get().QueryCoordCfg.CollectionBalanceSegmentBatchSize.GetAsInt()
+	balanceBatchSize := paramtable.Get().QueryCoordCfg.BalanceSegmentBatchSize.GetAsInt()
 	plans := make([]SegmentAssignPlan, 0, len(segments))
 	for _, s := range segments {
 		// pick the node with the least row count and allocate to it.
diff --git a/internal/querycoordv2/balance/score_based_balancer.go b/internal/querycoordv2/balance/score_based_balancer.go
index e233ca59be..77a46e4b91 100644
--- a/internal/querycoordv2/balance/score_based_balancer.go
+++ b/internal/querycoordv2/balance/score_based_balancer.go
@@ -69,7 +69,7 @@ func (b *ScoreBasedBalancer) assignSegment(br *balanceReport, collectionID int64
 			}
 			return normalNode
 		})
-		balanceBatchSize = paramtable.Get().QueryCoordCfg.CollectionBalanceSegmentBatchSize.GetAsInt()
+		balanceBatchSize = paramtable.Get().QueryCoordCfg.BalanceSegmentBatchSize.GetAsInt()
 	}
 
 	// calculate each node's score
@@ -163,7 +163,7 @@ func (b *ScoreBasedBalancer) assignChannel(br *balanceReport, collectionID int64
 			}
 			return normalNode
 		})
-		balanceBatchSize = paramtable.Get().QueryCoordCfg.CollectionBalanceChannelBatchSize.GetAsInt()
+		balanceBatchSize = paramtable.Get().QueryCoordCfg.BalanceChannelBatchSize.GetAsInt()
 	}
 
 	// calculate each node's score
@@ -653,7 +653,7 @@ func (b *ScoreBasedBalancer) genChannelPlan(ctx context.Context, br *balanceRepo
 		channelDist[node] = b.dist.ChannelDistManager.GetByFilter(meta.WithCollectionID2Channel(replica.GetCollectionID()), meta.WithNodeID2Channel(node))
 	}
 
-	balanceBatchSize := paramtable.Get().QueryCoordCfg.CollectionBalanceSegmentBatchSize.GetAsInt()
+	balanceBatchSize := paramtable.Get().QueryCoordCfg.BalanceSegmentBatchSize.GetAsInt()
 	// find the segment from the node which has more score than the average
 	channelsToMove := make([]*meta.DmChannel, 0)
 	for node, channels := range channelDist {
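Editor's note: each balancer above reads the renamed batch-size param and stops emitting assignment plans once the per-round batch is full. A minimal, self-contained sketch of that cap-per-round pattern follows; the `segment`/`assignPlan` types and the round-robin placement are hypothetical stand-ins for illustration, not Milvus APIs.

```go
package main

import "fmt"

// segment and assignPlan are hypothetical stand-ins for the Milvus types;
// only the early exit on the batch size mirrors the balancers in this patch.
type segment struct{ id, rows int64 }

type assignPlan struct {
	segmentID int64
	to        int64
}

// assignWithBatchCap emits at most batchSize plans per call, mirroring how
// balanceBatchSize limits the plans produced in one AssignSegment round.
func assignWithBatchCap(segments []segment, nodes []int64, batchSize int) []assignPlan {
	plans := make([]assignPlan, 0, len(segments))
	for i, s := range segments {
		if i >= batchSize {
			break // batch budget for this round is used up; the rest waits for the next round
		}
		to := nodes[i%len(nodes)] // naive round-robin placement; the real balancers score nodes
		plans = append(plans, assignPlan{segmentID: s.id, to: to})
	}
	return plans
}

func main() {
	segs := []segment{{1, 100}, {2, 90}, {3, 80}}
	fmt.Println(assignWithBatchCap(segs, []int64{10, 11}, 2)) // only 2 plans are emitted
}
```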
diff --git a/internal/querycoordv2/balance/score_based_balancer_test.go b/internal/querycoordv2/balance/score_based_balancer_test.go
index ec4a32c226..70638086b1 100644
--- a/internal/querycoordv2/balance/score_based_balancer_test.go
+++ b/internal/querycoordv2/balance/score_based_balancer_test.go
@@ -1371,8 +1371,8 @@ func (suite *ScoreBasedBalancerTestSuite) TestBalanceChannelOnDifferentQN() {
 	suite.balancer.meta.ResourceManager.HandleNodeUp(ctx, 2)
 	utils.RecoverAllCollection(balancer.meta)
 
-	paramtable.Get().Save(paramtable.Get().QueryCoordCfg.CollectionBalanceChannelBatchSize.Key, "10")
-	defer paramtable.Get().Reset(paramtable.Get().QueryCoordCfg.CollectionBalanceChannelBatchSize.Key)
+	paramtable.Get().Save(paramtable.Get().QueryCoordCfg.BalanceChannelBatchSize.Key, "10")
+	defer paramtable.Get().Reset(paramtable.Get().QueryCoordCfg.BalanceChannelBatchSize.Key)
 
 	// test balance channel on same query node
 	_, channelPlans = suite.getCollectionBalancePlans(balancer, collectionID)
diff --git a/internal/querycoordv2/checkers/balance_checker.go b/internal/querycoordv2/checkers/balance_checker.go
index de398fad32..5c2aa04944 100644
--- a/internal/querycoordv2/checkers/balance_checker.go
+++ b/internal/querycoordv2/checkers/balance_checker.go
@@ -108,10 +108,9 @@ func (b *BalanceChecker) getReplicaForStoppingBalance(ctx context.Context) []int
 			continue
 		}
 		if b.stoppingBalanceCollectionsCurrentRound.Contain(cid) {
-			log.RatedDebug(10, "BalanceChecker is balancing this collection, skip balancing in this round",
-				zap.Int64("collectionID", cid))
 			continue
 		}
+
 		replicas := b.meta.ReplicaManager.GetByCollection(ctx, cid)
 		stoppingReplicas := make([]int64, 0)
 		for _, replica := range replicas {
@@ -208,42 +207,70 @@ func (b *BalanceChecker) balanceReplicas(ctx context.Context, replicaIDs []int64
 	return segmentPlans, channelPlans
 }
 
+// Notice: the balance checker may generate tasks for multiple collections in one round,
+// so the generated tasks are submitted to the scheduler directly and Check returns nil.
 func (b *BalanceChecker) Check(ctx context.Context) []task.Task {
-	var segmentPlans []balance.SegmentAssignPlan
-	var channelPlans []balance.ChannelAssignPlan
+	segmentBatchSize := paramtable.Get().QueryCoordCfg.BalanceSegmentBatchSize.GetAsInt()
+	channelBatchSize := paramtable.Get().QueryCoordCfg.BalanceChannelBatchSize.GetAsInt()
+	balanceOnMultipleCollections := paramtable.Get().QueryCoordCfg.EnableBalanceOnMultipleCollections.GetAsBool()
+
+	segmentTasks := make([]task.Task, 0)
+	channelTasks := make([]task.Task, 0)
+
+	generateBalanceTaskForReplicas := func(replicas []int64) {
+		segmentPlans, channelPlans := b.balanceReplicas(ctx, replicas)
+		tasks := balance.CreateSegmentTasksFromPlans(ctx, b.ID(), Params.QueryCoordCfg.SegmentTaskTimeout.GetAsDuration(time.Millisecond), segmentPlans)
+		task.SetPriority(task.TaskPriorityLow, tasks...)
+		task.SetReason("segment unbalanced", tasks...)
+		segmentTasks = append(segmentTasks, tasks...)
+
+		tasks = balance.CreateChannelTasksFromPlans(ctx, b.ID(), Params.QueryCoordCfg.ChannelTaskTimeout.GetAsDuration(time.Millisecond), channelPlans)
+		task.SetReason("channel unbalanced", tasks...)
+		channelTasks = append(channelTasks, tasks...)
+	}
+
 	stoppingReplicas := b.getReplicaForStoppingBalance(ctx)
 	if len(stoppingReplicas) > 0 {
 		// check for stopping balance first
-		segmentPlans, channelPlans = b.balanceReplicas(ctx, stoppingReplicas)
+		generateBalanceTaskForReplicas(stoppingReplicas)
 		// iterate all collection to find a collection to balance
-		for len(segmentPlans) == 0 && len(channelPlans) == 0 && b.stoppingBalanceCollectionsCurrentRound.Len() > 0 {
-			replicasToBalance := b.getReplicaForStoppingBalance(ctx)
-			segmentPlans, channelPlans = b.balanceReplicas(ctx, replicasToBalance)
+		for len(segmentTasks) < segmentBatchSize && len(channelTasks) < channelBatchSize && b.stoppingBalanceCollectionsCurrentRound.Len() > 0 {
+			if !balanceOnMultipleCollections && (len(segmentTasks) > 0 || len(channelTasks) > 0) {
+				// balancing multiple collections is disabled and some tasks already exist, stop here
+				break
+			}
+			if len(channelTasks) < channelBatchSize {
+				replicasToBalance := b.getReplicaForStoppingBalance(ctx)
+				generateBalanceTaskForReplicas(replicasToBalance)
+			}
 		}
 	} else {
 		// then check for auto balance
 		if time.Since(b.autoBalanceTs) > paramtable.Get().QueryCoordCfg.AutoBalanceInterval.GetAsDuration(time.Millisecond) {
 			b.autoBalanceTs = time.Now()
 			replicasToBalance := b.getReplicaForNormalBalance(ctx)
-			segmentPlans, channelPlans = b.balanceReplicas(ctx, replicasToBalance)
+			generateBalanceTaskForReplicas(replicasToBalance)
 			// iterate all collection to find a collection to balance
-			for len(segmentPlans) == 0 && len(channelPlans) == 0 && b.normalBalanceCollectionsCurrentRound.Len() > 0 {
+			for len(segmentTasks) < segmentBatchSize && len(channelTasks) < channelBatchSize && b.normalBalanceCollectionsCurrentRound.Len() > 0 {
+				if !balanceOnMultipleCollections && (len(segmentTasks) > 0 || len(channelTasks) > 0) {
+					// balancing multiple collections is disabled and some tasks already exist, stop here
+					break
+				}
 				replicasToBalance := b.getReplicaForNormalBalance(ctx)
-				segmentPlans, channelPlans = b.balanceReplicas(ctx, replicasToBalance)
+				generateBalanceTaskForReplicas(replicasToBalance)
 			}
 		}
 	}
 
-	ret := make([]task.Task, 0)
-	tasks := balance.CreateSegmentTasksFromPlans(ctx, b.ID(), Params.QueryCoordCfg.SegmentTaskTimeout.GetAsDuration(time.Millisecond), segmentPlans)
-	task.SetPriority(task.TaskPriorityLow, tasks...)
-	task.SetReason("segment unbalanced", tasks...)
-	ret = append(ret, tasks...)
+	for _, task := range segmentTasks {
+		b.scheduler.Add(task)
+	}
 
-	tasks = balance.CreateChannelTasksFromPlans(ctx, b.ID(), Params.QueryCoordCfg.ChannelTaskTimeout.GetAsDuration(time.Millisecond), channelPlans)
-	task.SetReason("channel unbalanced", tasks...)
-	ret = append(ret, tasks...)
-
-	return ret
+	for _, task := range channelTasks {
+		b.scheduler.Add(task)
+	}
+
+	return nil
 }
 
 func (b *BalanceChecker) sortCollections(ctx context.Context, collections []int64) []int64 {
@@ -252,10 +279,15 @@ func (b *BalanceChecker) sortCollections(ctx context.Context, collections []int6
 		sortOrder = "byrowcount" // Default to ByRowCount
 	}
 
+	collectionRowCountMap := make(map[int64]int64)
+	for _, cid := range collections {
+		collectionRowCountMap[cid] = b.targetMgr.GetCollectionRowCount(ctx, cid, meta.CurrentTargetFirst)
+	}
+
 	// Define sorting functions
 	sortByRowCount := func(i, j int) bool {
-		rowCount1 := b.targetMgr.GetCollectionRowCount(ctx, collections[i], meta.CurrentTargetFirst)
-		rowCount2 := b.targetMgr.GetCollectionRowCount(ctx, collections[j], meta.CurrentTargetFirst)
+		rowCount1 := collectionRowCountMap[collections[i]]
+		rowCount2 := collectionRowCountMap[collections[j]]
 		return rowCount1 > rowCount2 || (rowCount1 == rowCount2 && collections[i] < collections[j])
 	}
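Editor's note: the reworked `Check` accumulates tasks across collections until either batch budget is spent, honors the single-collection switch, hands everything to the scheduler, and returns nil. The sketch below condenses that flow with hypothetical `task`/`scheduler` types so the control flow can be read in isolation; it is not the QueryCoord implementation.

```go
package main

import "fmt"

// Hypothetical stand-ins for the QueryCoord task and scheduler types.
type task struct{ kind, collection string }

type scheduler struct{ added []task }

func (s *scheduler) Add(t task) { s.added = append(s.added, t) }

// generator returns the segment tasks and channel tasks one collection would produce.
type generator func(collection string) ([]task, []task)

func check(collections []string, gen generator, sched *scheduler,
	segmentBatchSize, channelBatchSize int, multiCollection bool,
) {
	var segmentTasks, channelTasks []task
	for _, c := range collections {
		if len(segmentTasks) >= segmentBatchSize || len(channelTasks) >= channelBatchSize {
			break // a batch budget is exhausted for this round
		}
		if !multiCollection && (len(segmentTasks) > 0 || len(channelTasks) > 0) {
			break // single-collection mode: stop after the first collection that produced tasks
		}
		seg, ch := gen(c)
		segmentTasks = append(segmentTasks, seg...)
		channelTasks = append(channelTasks, ch...)
	}
	// tasks go straight to the scheduler; the real Check then returns nil
	for _, t := range segmentTasks {
		sched.Add(t)
	}
	for _, t := range channelTasks {
		sched.Add(t)
	}
}

func main() {
	gen := func(c string) ([]task, []task) {
		return []task{{"segment", c}, {"segment", c}}, []task{{"channel", c}}
	}
	s := &scheduler{}
	check([]string{"coll1", "coll2", "coll3"}, gen, s, 2, 1, true)
	fmt.Println(len(s.added)) // 3: the first collection already fills both budgets
}
```

Note that, as in the patch, a collection's tasks are appended wholesale, so a round can slightly overshoot the configured batch size before the loop stops.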
diff --git a/internal/querycoordv2/checkers/balance_checker_test.go b/internal/querycoordv2/checkers/balance_checker_test.go
index 729f9760a6..280e9ea22b 100644
--- a/internal/querycoordv2/checkers/balance_checker_test.go
+++ b/internal/querycoordv2/checkers/balance_checker_test.go
@@ -77,6 +77,7 @@ func (suite *BalanceCheckerTestSuite) SetupTest() {
 	suite.meta = meta.NewMeta(idAllocator, store, suite.nodeMgr)
 	suite.broker = meta.NewMockBroker(suite.T())
 	suite.scheduler = task.NewMockScheduler(suite.T())
+	suite.scheduler.EXPECT().Add(mock.Anything).Return(nil).Maybe()
 
 	suite.targetMgr = meta.NewTargetManager(suite.broker, suite.meta)
 	suite.balancer = balance.NewMockBalancer(suite.T())
@@ -326,8 +327,15 @@ func (suite *BalanceCheckerTestSuite) TestStoppingBalance() {
 	}
 	segPlans = append(segPlans, mockPlan)
 	suite.balancer.EXPECT().BalanceReplica(mock.Anything, mock.Anything).Return(segPlans, chanPlans)
-	tasks := suite.checker.Check(context.TODO())
-	suite.Len(tasks, 1)
+
+	tasks := make([]task.Task, 0)
+	suite.scheduler.ExpectedCalls = nil
+	suite.scheduler.EXPECT().Add(mock.Anything).RunAndReturn(func(task task.Task) error {
+		tasks = append(tasks, task)
+		return nil
+	})
+	suite.checker.Check(context.TODO())
+	suite.Len(tasks, 2)
 }
 
 func (suite *BalanceCheckerTestSuite) TestTargetNotReady() {
@@ -850,6 +858,156 @@ func (suite *BalanceCheckerTestSuite) TestHasUnbalancedCollectionFlag() {
 		"stoppingBalanceCollectionsCurrentRound should contain the collection when it has RO nodes")
 }
 
+func (suite *BalanceCheckerTestSuite) TestCheckBatchSizesAndMultiCollection() {
+	ctx := context.Background()
+
+	// Set up nodes
+	nodeID1, nodeID2 := int64(1), int64(2)
+	suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{
+		NodeID:   nodeID1,
+		Address:  "localhost",
+		Hostname: "localhost",
+	}))
+	suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{
+		NodeID:   nodeID2,
+		Address:  "localhost",
+		Hostname: "localhost",
+	}))
+	suite.checker.meta.ResourceManager.HandleNodeUp(ctx, nodeID1)
+	suite.checker.meta.ResourceManager.HandleNodeUp(ctx, nodeID2)
+
+	// Create 3 collections
+	for i := 1; i <= 3; i++ {
+		cid := int64(i)
+		replicaID := int64(100 + i)
+
+		collection := utils.CreateTestCollection(cid, int32(replicaID))
+		collection.Status = querypb.LoadStatus_Loaded
+		replica := utils.CreateTestReplica(replicaID, cid, []int64{})
+		mutableReplica := replica.CopyForWrite()
+		mutableReplica.AddRWNode(nodeID1)
+		mutableReplica.AddRONode(nodeID2)
+
+		suite.checker.meta.CollectionManager.PutCollection(ctx, collection)
+		suite.checker.meta.ReplicaManager.Put(ctx, mutableReplica.IntoReplica())
+	}
+
+	// Mock target manager
+	mockTargetManager := meta.NewMockTargetManager(suite.T())
+	suite.checker.targetMgr = mockTargetManager
+
+	// All collections have the same row count for simplicity
+	mockTargetManager.EXPECT().GetCollectionRowCount(mock.Anything, mock.Anything, mock.Anything).Return(int64(100)).Maybe()
+	mockTargetManager.EXPECT().IsCurrentTargetReady(mock.Anything, mock.Anything).Return(true).Maybe()
+	mockTargetManager.EXPECT().IsNextTargetExist(mock.Anything, mock.Anything).Return(true).Maybe()
+	mockTargetManager.EXPECT().IsCurrentTargetExist(mock.Anything, mock.Anything, mock.Anything).Return(true).Maybe()
+
+	// For each collection, return its own segment plans and channel plan
+	suite.balancer.EXPECT().BalanceReplica(mock.Anything, mock.AnythingOfType("*meta.Replica")).RunAndReturn(
+		func(ctx context.Context, replica *meta.Replica) ([]balance.SegmentAssignPlan, []balance.ChannelAssignPlan) {
+			// Create 2 segment plans and 1 channel plan per replica
+			collID := replica.GetCollectionID()
+			segPlans := make([]balance.SegmentAssignPlan, 0)
+			chanPlans := make([]balance.ChannelAssignPlan, 0)
+
+			// Create 2 segment plans
+			for j := 1; j <= 2; j++ {
+				segID := collID*100 + int64(j)
+				segPlan := balance.SegmentAssignPlan{
+					Segment: utils.CreateTestSegment(segID, collID, 1, 1, 1, "test-channel"),
+					Replica: replica,
+					From:    nodeID1,
+					To:      nodeID2,
+				}
+				segPlans = append(segPlans, segPlan)
+			}
+
+			// Create 1 channel plan
+			chanPlan := balance.ChannelAssignPlan{
+				Channel: &meta.DmChannel{
+					VchannelInfo: &datapb.VchannelInfo{
+						CollectionID: collID,
+						ChannelName:  "test-channel",
+					},
+				},
+				Replica: replica,
+				From:    nodeID1,
+				To:      nodeID2,
+			}
+			chanPlans = append(chanPlans, chanPlan)
+
+			return segPlans, chanPlans
+		}).Maybe()
+
+	// Capture the tasks handed to the scheduler so batch size limits can be checked
+	var addedTasks []task.Task
+	suite.scheduler.ExpectedCalls = nil
+	suite.scheduler.EXPECT().Add(mock.Anything).RunAndReturn(func(t task.Task) error {
+		addedTasks = append(addedTasks, t)
+		return nil
+	}).Maybe()
+
+	// Test 1: balance with multiple collections disabled
+	paramtable.Get().Save(Params.QueryCoordCfg.AutoBalance.Key, "true")
+	paramtable.Get().Save(Params.QueryCoordCfg.EnableBalanceOnMultipleCollections.Key, "false")
+	// Set batch sizes to large values to test the single-collection case
+	paramtable.Get().Save(Params.QueryCoordCfg.BalanceSegmentBatchSize.Key, "10")
+	paramtable.Get().Save(Params.QueryCoordCfg.BalanceChannelBatchSize.Key, "10")
+
+	// Reset test state
+	suite.checker.stoppingBalanceCollectionsCurrentRound.Clear()
+	suite.checker.autoBalanceTs = time.Time{} // Reset to trigger auto balance
+	addedTasks = nil
+
+	// Run the Check method
+	suite.checker.Check(ctx)
+
+	// Should have tasks for a single collection (2 segment tasks + 1 channel task)
+	suite.Equal(3, len(addedTasks), "Should have tasks for a single collection when multiple collections balance is disabled")
+
+	// Test 2: balance with multiple collections enabled
+	paramtable.Get().Save(Params.QueryCoordCfg.EnableBalanceOnMultipleCollections.Key, "true")
+
+	// Reset test state
+	suite.checker.autoBalanceTs = time.Time{}
+	suite.checker.stoppingBalanceCollectionsCurrentRound.Clear()
+	addedTasks = nil
+
+	// Run the Check method
+	suite.checker.Check(ctx)
+
+	// Should have tasks for all collections (3 collections * (2 segment tasks + 1 channel task) = 9 tasks)
+	suite.Equal(9, len(addedTasks), "Should have tasks for all collections when multiple collections balance is enabled")
+
+	// Test 3: batch size limits
+	paramtable.Get().Save(Params.QueryCoordCfg.BalanceSegmentBatchSize.Key, "2")
+	paramtable.Get().Save(Params.QueryCoordCfg.BalanceChannelBatchSize.Key, "1")
+
+	// Reset test state
+	suite.checker.stoppingBalanceCollectionsCurrentRound.Clear()
+	addedTasks = nil
+
+	// Run the Check method
+	suite.checker.Check(ctx)
+
+	// Should respect batch size limits: 2 segment tasks + 1 channel task = 3 tasks
+	suite.Equal(3, len(addedTasks), "Should respect batch size limits")
+
+	// Count segment tasks and channel tasks
+	segmentTaskCount := 0
+	channelTaskCount := 0
+	for _, t := range addedTasks {
+		if _, ok := t.(*task.SegmentTask); ok {
+			segmentTaskCount++
+		} else {
+			channelTaskCount++
+		}
+	}
+
+	suite.LessOrEqual(segmentTaskCount, 2, "Should have at most 2 segment tasks due to batch size limit")
+	suite.LessOrEqual(channelTaskCount, 1, "Should have at most 1 channel task due to batch size limit")
+}
+
 func TestBalanceCheckerSuite(t *testing.T) {
 	suite.Run(t, new(BalanceCheckerTestSuite))
 }
diff --git a/pkg/util/paramtable/component_param.go b/pkg/util/paramtable/component_param.go
index df808215b2..d36c0ee72d 100644
--- a/pkg/util/paramtable/component_param.go
+++ b/pkg/util/paramtable/component_param.go
@@ -2091,6 +2091,11 @@ type queryCoordConfig struct {
 	UpdateCollectionLoadStatusInterval ParamItem `refreshable:"false"`
 	ClusterLevelLoadReplicaNumber      ParamItem `refreshable:"true"`
 	ClusterLevelLoadResourceGroups     ParamItem `refreshable:"true"`
+
+	// balance batch sizes for one balance trigger round
+	BalanceSegmentBatchSize            ParamItem `refreshable:"true"`
+	BalanceChannelBatchSize            ParamItem `refreshable:"true"`
+	EnableBalanceOnMultipleCollections ParamItem `refreshable:"true"`
 }
 
 func (p *queryCoordConfig) init(base *BaseTable) {
@@ -2685,6 +2690,35 @@ If this parameter is set false, Milvus simply searches the growing segments with
 		Export: true,
 	}
 	p.AutoBalanceInterval.Init(base.mgr)
+
+	p.BalanceSegmentBatchSize = ParamItem{
+		Key:          "queryCoord.balanceSegmentBatchSize",
+		FallbackKeys: []string{"queryCoord.collectionBalanceSegmentBatchSize"},
+		Version:      "2.5.14",
+		DefaultValue: "5",
+		Doc:          "the max number of segment balance tasks generated in one round, used when queryCoord triggers balance on multiple collections",
+		Export:       false,
+	}
+	p.BalanceSegmentBatchSize.Init(base.mgr)
+
+	p.BalanceChannelBatchSize = ParamItem{
+		Key:          "queryCoord.balanceChannelBatchSize",
+		FallbackKeys: []string{"queryCoord.collectionBalanceChannelBatchSize"},
+		Version:      "2.5.14",
+		DefaultValue: "1",
+		Doc:          "the max number of channel balance tasks generated in one round, used when queryCoord triggers balance on multiple collections",
+		Export:       false,
+	}
+	p.BalanceChannelBatchSize.Init(base.mgr)
+
+	p.EnableBalanceOnMultipleCollections = ParamItem{
+		Key:          "queryCoord.enableBalanceOnMultipleCollections",
+		Version:      "2.5.14",
+		DefaultValue: "true",
+		Doc:          "whether to allow balancing multiple collections in one trigger round",
+		Export:       false,
+	}
+	p.EnableBalanceOnMultipleCollections.Init(base.mgr)
 }
 
 // /////////////////////////////////////////////////////////////////////////////
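Editor's note: the three new params are declared `refreshable:"true"`, so they can be changed at runtime. A quick sketch of inspecting and overriding them through paramtable follows; `Init`, `Get`, `Save`, and `Reset` are the same calls used in the tests above, while the import path assumes the 2.5.x module layout.

```go
package main

import (
	"fmt"

	"github.com/milvus-io/milvus/pkg/util/paramtable"
)

func main() {
	paramtable.Init()
	pt := paramtable.Get()

	// Defaults introduced by this patch: 5, 1, true.
	fmt.Println(pt.QueryCoordCfg.BalanceSegmentBatchSize.GetAsInt())
	fmt.Println(pt.QueryCoordCfg.BalanceChannelBatchSize.GetAsInt())
	fmt.Println(pt.QueryCoordCfg.EnableBalanceOnMultipleCollections.GetAsBool())

	// Override for the current process, then restore the default.
	pt.Save(pt.QueryCoordCfg.BalanceSegmentBatchSize.Key, "8")
	fmt.Println(pt.QueryCoordCfg.BalanceSegmentBatchSize.GetAsInt()) // 8
	pt.Reset(pt.QueryCoordCfg.BalanceSegmentBatchSize.Key)
}
```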
diff --git a/pkg/util/paramtable/component_param_test.go b/pkg/util/paramtable/component_param_test.go
index 813cbd47f7..4b2ead31d6 100644
--- a/pkg/util/paramtable/component_param_test.go
+++ b/pkg/util/paramtable/component_param_test.go
@@ -381,6 +381,10 @@ func TestComponentParam(t *testing.T) {
 		assert.Equal(t, 10, Params.CollectionChannelCountFactor.GetAsInt())
 		assert.Equal(t, 3000, Params.AutoBalanceInterval.GetAsInt())
+
+		assert.Equal(t, 5, Params.BalanceSegmentBatchSize.GetAsInt())
+		assert.Equal(t, 1, Params.BalanceChannelBatchSize.GetAsInt())
+		assert.Equal(t, true, Params.EnableBalanceOnMultipleCollections.GetAsBool())
 	})
 
 	t.Run("test queryNodeConfig", func(t *testing.T) {
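Editor's note: the renamed batch-size params keep the old `queryCoord.collectionBalance*` keys as `FallbackKeys`, so existing deployments that still set the old keys should keep working. The sketch below illustrates that expectation; it assumes ParamItem resolves the primary Key first, then FallbackKeys, then DefaultValue, which is not verified by this patch.

```go
package main

import (
	"fmt"

	"github.com/milvus-io/milvus/pkg/util/paramtable"
)

func main() {
	paramtable.Init()
	pt := paramtable.Get()

	// A deployment that still sets the pre-rename key...
	pt.Save("queryCoord.collectionBalanceSegmentBatchSize", "3")
	defer pt.Reset("queryCoord.collectionBalanceSegmentBatchSize")

	// ...is expected to be honoured through FallbackKeys (assumption, see note above).
	fmt.Println(pt.QueryCoordCfg.BalanceSegmentBatchSize.GetAsInt()) // 3 if the fallback applies
}
```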