enhance: support balancing multiple collections in a single trigger (#41875)

issue: #41874
- Optimize balance_checker to support balancing multiple collections
simultaneously
- Add new parameters for segment and channel balancing batch sizes
- Add enableBalanceOnMultipleCollections parameter
- Update tests for balance checker

This change improves resource utilization by allowing the system to
balance multiple collections in a single trigger with configurable batch
sizes.

Signed-off-by: Wei Liu <wei.liu@zilliz.com>
This commit is contained in:
wei liu 2025-05-21 21:38:25 +08:00 committed by GitHub
parent 9f866dd7c3
commit 4e1208f4f6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 259 additions and 31 deletions

View File

@ -92,7 +92,7 @@ func (b *RoundRobinBalancer) AssignSegment(ctx context.Context, collectionID int
return cnt1+delta1 < cnt2+delta2
})
balanceBatchSize := paramtable.Get().QueryCoordCfg.CollectionBalanceSegmentBatchSize.GetAsInt()
balanceBatchSize := paramtable.Get().QueryCoordCfg.BalanceSegmentBatchSize.GetAsInt()
ret := make([]SegmentAssignPlan, 0, len(segments))
for i, s := range segments {
plan := SegmentAssignPlan{

View File

@ -65,7 +65,7 @@ func (b *RowCountBasedBalancer) AssignSegment(ctx context.Context, collectionID
return segments[i].GetNumOfRows() > segments[j].GetNumOfRows()
})
balanceBatchSize := paramtable.Get().QueryCoordCfg.CollectionBalanceSegmentBatchSize.GetAsInt()
balanceBatchSize := paramtable.Get().QueryCoordCfg.BalanceSegmentBatchSize.GetAsInt()
plans := make([]SegmentAssignPlan, 0, len(segments))
for _, s := range segments {
// pick the node with the least row count and allocate to it.

View File

@ -69,7 +69,7 @@ func (b *ScoreBasedBalancer) assignSegment(br *balanceReport, collectionID int64
}
return normalNode
})
balanceBatchSize = paramtable.Get().QueryCoordCfg.CollectionBalanceSegmentBatchSize.GetAsInt()
balanceBatchSize = paramtable.Get().QueryCoordCfg.BalanceSegmentBatchSize.GetAsInt()
}
// calculate each node's score
@ -163,7 +163,7 @@ func (b *ScoreBasedBalancer) assignChannel(br *balanceReport, collectionID int64
}
return normalNode
})
balanceBatchSize = paramtable.Get().QueryCoordCfg.CollectionBalanceChannelBatchSize.GetAsInt()
balanceBatchSize = paramtable.Get().QueryCoordCfg.BalanceChannelBatchSize.GetAsInt()
}
// calculate each node's score
@ -653,7 +653,7 @@ func (b *ScoreBasedBalancer) genChannelPlan(ctx context.Context, br *balanceRepo
channelDist[node] = b.dist.ChannelDistManager.GetByFilter(meta.WithCollectionID2Channel(replica.GetCollectionID()), meta.WithNodeID2Channel(node))
}
balanceBatchSize := paramtable.Get().QueryCoordCfg.CollectionBalanceSegmentBatchSize.GetAsInt()
balanceBatchSize := paramtable.Get().QueryCoordCfg.BalanceSegmentBatchSize.GetAsInt()
// find the segment from the node which has more score than the average
channelsToMove := make([]*meta.DmChannel, 0)
for node, channels := range channelDist {

View File

@ -1371,8 +1371,8 @@ func (suite *ScoreBasedBalancerTestSuite) TestBalanceChannelOnDifferentQN() {
suite.balancer.meta.ResourceManager.HandleNodeUp(ctx, 2)
utils.RecoverAllCollection(balancer.meta)
paramtable.Get().Save(paramtable.Get().QueryCoordCfg.CollectionBalanceChannelBatchSize.Key, "10")
defer paramtable.Get().Reset(paramtable.Get().QueryCoordCfg.CollectionBalanceChannelBatchSize.Key)
paramtable.Get().Save(paramtable.Get().QueryCoordCfg.BalanceChannelBatchSize.Key, "10")
defer paramtable.Get().Reset(paramtable.Get().QueryCoordCfg.BalanceChannelBatchSize.Key)
// test balance channel on same query node
_, channelPlans = suite.getCollectionBalancePlans(balancer, collectionID)

View File

@ -108,10 +108,9 @@ func (b *BalanceChecker) getReplicaForStoppingBalance(ctx context.Context) []int
continue
}
if b.stoppingBalanceCollectionsCurrentRound.Contain(cid) {
log.RatedDebug(10, "BalanceChecker is balancing this collection, skip balancing in this round",
zap.Int64("collectionID", cid))
continue
}
replicas := b.meta.ReplicaManager.GetByCollection(ctx, cid)
stoppingReplicas := make([]int64, 0)
for _, replica := range replicas {
@ -208,42 +207,70 @@ func (b *BalanceChecker) balanceReplicas(ctx context.Context, replicaIDs []int64
return segmentPlans, channelPlans
}
// Check generates balance tasks for the current round and hands them to the scheduler.
//
// Stopping balance (replicas returned by getReplicaForStoppingBalance) is checked first;
// only when it yields no replicas is normal balance considered, and then only if more
// than AutoBalanceInterval has elapsed since autoBalanceTs.
//
// Notice: balance checker will generate tasks for multiple collections in one round,
// so generated tasks will be submitted to scheduler directly, and return nil
func (b *BalanceChecker) Check(ctx context.Context) []task.Task {
var segmentPlans []balance.SegmentAssignPlan
var channelPlans []balance.ChannelAssignPlan
// per-round task limits and the multi-collection switch are read from QueryCoordCfg on every call
segmentBatchSize := paramtable.Get().QueryCoordCfg.BalanceSegmentBatchSize.GetAsInt()
channelBatchSize := paramtable.Get().QueryCoordCfg.BalanceChannelBatchSize.GetAsInt()
balanceOnMultipleCollections := paramtable.Get().QueryCoordCfg.EnableBalanceOnMultipleCollections.GetAsBool()
// tasks accumulated across all collections handled in this call
segmentTasks := make([]task.Task, 0)
channelTasks := make([]task.Task, 0)
// generateBalanceTaskForReplicas computes balance plans for the given replicas, converts them
// into segment and channel tasks, and appends them to segmentTasks / channelTasks.
// Segment tasks are marked with low priority; both kinds get a reason string.
generateBalanceTaskForReplicas := func(replicas []int64) {
segmentPlans, channelPlans := b.balanceReplicas(ctx, replicas)
tasks := balance.CreateSegmentTasksFromPlans(ctx, b.ID(), Params.QueryCoordCfg.SegmentTaskTimeout.GetAsDuration(time.Millisecond), segmentPlans)
task.SetPriority(task.TaskPriorityLow, tasks...)
task.SetReason("segment unbalanced", tasks...)
segmentTasks = append(segmentTasks, tasks...)
tasks = balance.CreateChannelTasksFromPlans(ctx, b.ID(), Params.QueryCoordCfg.ChannelTaskTimeout.GetAsDuration(time.Millisecond), channelPlans)
task.SetReason("channel unbalanced", tasks...)
channelTasks = append(channelTasks, tasks...)
}
stoppingReplicas := b.getReplicaForStoppingBalance(ctx)
if len(stoppingReplicas) > 0 {
// check for stopping balance first
segmentPlans, channelPlans = b.balanceReplicas(ctx, stoppingReplicas)
generateBalanceTaskForReplicas(stoppingReplicas)
// iterate all collection to find a collection to balance
for len(segmentPlans) == 0 && len(channelPlans) == 0 && b.stoppingBalanceCollectionsCurrentRound.Len() > 0 {
replicasToBalance := b.getReplicaForStoppingBalance(ctx)
segmentPlans, channelPlans = b.balanceReplicas(ctx, replicasToBalance)
// the loop ends as soon as EITHER task list reaches its batch size, or
// stoppingBalanceCollectionsCurrentRound becomes empty
for len(segmentTasks) < segmentBatchSize && len(channelTasks) < channelBatchSize && b.stoppingBalanceCollectionsCurrentRound.Len() > 0 {
if !balanceOnMultipleCollections && (len(segmentTasks) > 0 || len(channelTasks) > 0) {
// if balance on multiple collections is disabled, and there are already some tasks, break
break
}
// NOTE(review): this repeats the channel-count check already made by the loop condition
if len(channelTasks) < channelBatchSize {
replicasToBalance := b.getReplicaForStoppingBalance(ctx)
generateBalanceTaskForReplicas(replicasToBalance)
}
}
} else {
// then check for auto balance
if time.Since(b.autoBalanceTs) > paramtable.Get().QueryCoordCfg.AutoBalanceInterval.GetAsDuration(time.Millisecond) {
b.autoBalanceTs = time.Now()
replicasToBalance := b.getReplicaForNormalBalance(ctx)
segmentPlans, channelPlans = b.balanceReplicas(ctx, replicasToBalance)
generateBalanceTaskForReplicas(replicasToBalance)
// iterate all collection to find a collection to balance
for len(segmentPlans) == 0 && len(channelPlans) == 0 && b.normalBalanceCollectionsCurrentRound.Len() > 0 {
// same stop conditions as the stopping-balance loop, driven by
// normalBalanceCollectionsCurrentRound instead
for len(segmentTasks) < segmentBatchSize && len(channelTasks) < channelBatchSize && b.normalBalanceCollectionsCurrentRound.Len() > 0 {
if !balanceOnMultipleCollections && (len(segmentTasks) > 0 || len(channelTasks) > 0) {
// if balance on multiple collections is disabled, and there are already some tasks, break
break
}
replicasToBalance := b.getReplicaForNormalBalance(ctx)
segmentPlans, channelPlans = b.balanceReplicas(ctx, replicasToBalance)
generateBalanceTaskForReplicas(replicasToBalance)
}
}
}
ret := make([]task.Task, 0)
tasks := balance.CreateSegmentTasksFromPlans(ctx, b.ID(), Params.QueryCoordCfg.SegmentTaskTimeout.GetAsDuration(time.Millisecond), segmentPlans)
task.SetPriority(task.TaskPriorityLow, tasks...)
task.SetReason("segment unbalanced", tasks...)
ret = append(ret, tasks...)
// submit segment tasks to the scheduler directly; the value returned by Add is not checked here
for _, task := range segmentTasks {
b.scheduler.Add(task)
}
tasks = balance.CreateChannelTasksFromPlans(ctx, b.ID(), Params.QueryCoordCfg.ChannelTaskTimeout.GetAsDuration(time.Millisecond), channelPlans)
task.SetReason("channel unbalanced", tasks...)
ret = append(ret, tasks...)
return ret
// channel tasks are submitted after all segment tasks
for _, task := range channelTasks {
b.scheduler.Add(task)
}
// tasks were already handed to the scheduler above, so there is nothing to return
return nil
}
func (b *BalanceChecker) sortCollections(ctx context.Context, collections []int64) []int64 {
@ -252,10 +279,15 @@ func (b *BalanceChecker) sortCollections(ctx context.Context, collections []int6
sortOrder = "byrowcount" // Default to ByRowCount
}
collectionRowCountMap := make(map[int64]int64)
for _, cid := range collections {
collectionRowCountMap[cid] = b.targetMgr.GetCollectionRowCount(ctx, cid, meta.CurrentTargetFirst)
}
// Define sorting functions
sortByRowCount := func(i, j int) bool {
rowCount1 := b.targetMgr.GetCollectionRowCount(ctx, collections[i], meta.CurrentTargetFirst)
rowCount2 := b.targetMgr.GetCollectionRowCount(ctx, collections[j], meta.CurrentTargetFirst)
rowCount1 := collectionRowCountMap[collections[i]]
rowCount2 := collectionRowCountMap[collections[j]]
return rowCount1 > rowCount2 || (rowCount1 == rowCount2 && collections[i] < collections[j])
}

View File

@ -77,6 +77,7 @@ func (suite *BalanceCheckerTestSuite) SetupTest() {
suite.meta = meta.NewMeta(idAllocator, store, suite.nodeMgr)
suite.broker = meta.NewMockBroker(suite.T())
suite.scheduler = task.NewMockScheduler(suite.T())
suite.scheduler.EXPECT().Add(mock.Anything).Return(nil).Maybe()
suite.targetMgr = meta.NewTargetManager(suite.broker, suite.meta)
suite.balancer = balance.NewMockBalancer(suite.T())
@ -326,8 +327,15 @@ func (suite *BalanceCheckerTestSuite) TestStoppingBalance() {
}
segPlans = append(segPlans, mockPlan)
suite.balancer.EXPECT().BalanceReplica(mock.Anything, mock.Anything).Return(segPlans, chanPlans)
tasks := suite.checker.Check(context.TODO())
suite.Len(tasks, 1)
tasks := make([]task.Task, 0)
suite.scheduler.ExpectedCalls = nil
suite.scheduler.EXPECT().Add(mock.Anything).RunAndReturn(func(task task.Task) error {
tasks = append(tasks, task)
return nil
})
suite.checker.Check(context.TODO())
suite.Len(tasks, 2)
}
func (suite *BalanceCheckerTestSuite) TestTargetNotReady() {
@ -850,6 +858,156 @@ func (suite *BalanceCheckerTestSuite) TestHasUnbalancedCollectionFlag() {
"stoppingBalanceCollectionsCurrentRound should contain the collection when it has RO nodes")
}
// TestCheckBatchSizesAndMultiCollection verifies how BalanceChecker.Check batches tasks
// across collections. Three loaded collections each have one replica with an RW node and
// an RO node, and the mocked balancer returns 2 segment plans and 1 channel plan per
// replica. The test then checks the number of tasks handed to the scheduler when:
//  1. balancing multiple collections is disabled,
//  2. balancing multiple collections is enabled with large batch sizes,
//  3. the batch sizes are reduced to 2 segment tasks / 1 channel task.
func (suite *BalanceCheckerTestSuite) TestCheckBatchSizesAndMultiCollection() {
	ctx := context.Background()

	// Restore every parameter this test overrides so the values do not leak into
	// other tests that share the global paramtable.
	defer paramtable.Get().Reset(Params.QueryCoordCfg.AutoBalance.Key)
	defer paramtable.Get().Reset(Params.QueryCoordCfg.EnableBalanceOnMultipleCollections.Key)
	defer paramtable.Get().Reset(Params.QueryCoordCfg.BalanceSegmentBatchSize.Key)
	defer paramtable.Get().Reset(Params.QueryCoordCfg.BalanceChannelBatchSize.Key)

	// Set up nodes
	nodeID1, nodeID2 := int64(1), int64(2)
	suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{
		NodeID:   nodeID1,
		Address:  "localhost",
		Hostname: "localhost",
	}))
	suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{
		NodeID:   nodeID2,
		Address:  "localhost",
		Hostname: "localhost",
	}))
	suite.checker.meta.ResourceManager.HandleNodeUp(ctx, nodeID1)
	suite.checker.meta.ResourceManager.HandleNodeUp(ctx, nodeID2)

	// Create 3 collections
	for i := 1; i <= 3; i++ {
		cid := int64(i)
		replicaID := int64(100 + i)
		collection := utils.CreateTestCollection(cid, int32(replicaID))
		collection.Status = querypb.LoadStatus_Loaded
		replica := utils.CreateTestReplica(replicaID, cid, []int64{})
		mutableReplica := replica.CopyForWrite()
		mutableReplica.AddRWNode(nodeID1)
		mutableReplica.AddRONode(nodeID2)
		suite.checker.meta.CollectionManager.PutCollection(ctx, collection)
		suite.checker.meta.ReplicaManager.Put(ctx, mutableReplica.IntoReplica())
	}

	// Mock target manager
	mockTargetManager := meta.NewMockTargetManager(suite.T())
	suite.checker.targetMgr = mockTargetManager
	// All collections have same row count for simplicity
	mockTargetManager.EXPECT().GetCollectionRowCount(mock.Anything, mock.Anything, mock.Anything).Return(int64(100)).Maybe()
	mockTargetManager.EXPECT().IsCurrentTargetReady(mock.Anything, mock.Anything).Return(true).Maybe()
	mockTargetManager.EXPECT().IsNextTargetExist(mock.Anything, mock.Anything).Return(true).Maybe()
	mockTargetManager.EXPECT().IsCurrentTargetExist(mock.Anything, mock.Anything, mock.Anything).Return(true).Maybe()

	// For each collection, return different segment plans
	suite.balancer.EXPECT().BalanceReplica(mock.Anything, mock.AnythingOfType("*meta.Replica")).RunAndReturn(
		func(ctx context.Context, replica *meta.Replica) ([]balance.SegmentAssignPlan, []balance.ChannelAssignPlan) {
			// Create 2 segment plans and 1 channel plan per replica
			collID := replica.GetCollectionID()
			segPlans := make([]balance.SegmentAssignPlan, 0)
			chanPlans := make([]balance.ChannelAssignPlan, 0)
			// Create 2 segment plans
			for j := 1; j <= 2; j++ {
				segID := collID*100 + int64(j)
				segPlan := balance.SegmentAssignPlan{
					Segment: utils.CreateTestSegment(segID, collID, 1, 1, 1, "test-channel"),
					Replica: replica,
					From:    nodeID1,
					To:      nodeID2,
				}
				segPlans = append(segPlans, segPlan)
			}
			// Create 1 channel plan
			chanPlan := balance.ChannelAssignPlan{
				Channel: &meta.DmChannel{
					VchannelInfo: &datapb.VchannelInfo{
						CollectionID: collID,
						ChannelName:  "test-channel",
					},
				},
				Replica: replica,
				From:    nodeID1,
				To:      nodeID2,
			}
			chanPlans = append(chanPlans, chanPlan)
			return segPlans, chanPlans
		}).Maybe()

	// Record every task handed to the scheduler so the batch size limits can be checked
	var addedTasks []task.Task
	suite.scheduler.ExpectedCalls = nil
	suite.scheduler.EXPECT().Add(mock.Anything).RunAndReturn(func(t task.Task) error {
		addedTasks = append(addedTasks, t)
		return nil
	}).Maybe()

	// Test 1: Balance with multiple collections disabled
	paramtable.Get().Save(Params.QueryCoordCfg.AutoBalance.Key, "true")
	paramtable.Get().Save(Params.QueryCoordCfg.EnableBalanceOnMultipleCollections.Key, "false")
	// Set batch sizes to large values to test single-collection case
	paramtable.Get().Save(Params.QueryCoordCfg.BalanceSegmentBatchSize.Key, "10")
	paramtable.Get().Save(Params.QueryCoordCfg.BalanceChannelBatchSize.Key, "10")
	// Reset test state
	suite.checker.stoppingBalanceCollectionsCurrentRound.Clear()
	suite.checker.autoBalanceTs = time.Time{} // Reset to trigger auto balance
	addedTasks = nil
	// Run the Check method
	suite.checker.Check(ctx)
	// Should have tasks for a single collection (2 segment tasks + 1 channel task)
	suite.Len(addedTasks, 3, "Should have tasks for a single collection when multiple collections balance is disabled")

	// Test 2: Balance with multiple collections enabled
	paramtable.Get().Save(Params.QueryCoordCfg.EnableBalanceOnMultipleCollections.Key, "true")
	// Reset test state
	suite.checker.autoBalanceTs = time.Time{}
	suite.checker.stoppingBalanceCollectionsCurrentRound.Clear()
	addedTasks = nil
	// Run the Check method
	suite.checker.Check(ctx)
	// Should have tasks for all collections (3 collections * (2 segment tasks + 1 channel task) = 9 tasks)
	suite.Len(addedTasks, 9, "Should have tasks for all collections when multiple collections balance is enabled")

	// Test 3: Batch size limits
	paramtable.Get().Save(Params.QueryCoordCfg.BalanceSegmentBatchSize.Key, "2")
	paramtable.Get().Save(Params.QueryCoordCfg.BalanceChannelBatchSize.Key, "1")
	// Reset test state
	suite.checker.stoppingBalanceCollectionsCurrentRound.Clear()
	addedTasks = nil
	// Run the Check method
	suite.checker.Check(ctx)
	// Should respect batch size limits: 2 segment tasks + 1 channel task = 3 tasks
	suite.Len(addedTasks, 3, "Should respect batch size limits")

	// Count segment tasks and channel tasks; anything that is not a segment task
	// is counted as a channel task.
	segmentTaskCount := 0
	channelTaskCount := 0
	for _, t := range addedTasks {
		if _, ok := t.(*task.SegmentTask); ok {
			segmentTaskCount++
		} else {
			channelTaskCount++
		}
	}
	suite.LessOrEqual(segmentTaskCount, 2, "Should have at most 2 segment tasks due to batch size limit")
	suite.LessOrEqual(channelTaskCount, 1, "Should have at most 1 channel task due to batch size limit")
}
func TestBalanceCheckerSuite(t *testing.T) {
suite.Run(t, new(BalanceCheckerTestSuite))
}

View File

@ -2091,6 +2091,11 @@ type queryCoordConfig struct {
UpdateCollectionLoadStatusInterval ParamItem `refreshable:"false"`
ClusterLevelLoadReplicaNumber ParamItem `refreshable:"true"`
ClusterLevelLoadResourceGroups ParamItem `refreshable:"true"`
// balance batch size in one trigger
BalanceSegmentBatchSize ParamItem `refreshable:"true"`
BalanceChannelBatchSize ParamItem `refreshable:"true"`
EnableBalanceOnMultipleCollections ParamItem `refreshable:"true"`
}
func (p *queryCoordConfig) init(base *BaseTable) {
@ -2685,6 +2690,35 @@ If this parameter is set false, Milvus simply searches the growing segments with
Export: true,
}
p.AutoBalanceInterval.Init(base.mgr)
p.BalanceSegmentBatchSize = ParamItem{
Key: "queryCoord.balanceSegmentBatchSize",
FallbackKeys: []string{"queryCoord.collectionBalanceSegmentBatchSize"},
Version: "2.5.14",
DefaultValue: "5",
Doc: "the max balance task number for segment at each round, which is used for queryCoord to trigger balance on multiple collections",
Export: false,
}
p.BalanceSegmentBatchSize.Init(base.mgr)
p.BalanceChannelBatchSize = ParamItem{
Key: "queryCoord.balanceChannelBatchSize",
FallbackKeys: []string{"queryCoord.collectionBalanceChannelBatchSize"},
Version: "2.5.14",
DefaultValue: "1",
Doc: "the max balance task number for channel at each round, which is used for queryCoord to trigger balance on multiple collections",
Export: false,
}
p.BalanceChannelBatchSize.Init(base.mgr)
p.EnableBalanceOnMultipleCollections = ParamItem{
Key: "queryCoord.enableBalanceOnMultipleCollections",
Version: "2.5.14",
DefaultValue: "true",
Doc: "whether enable trigger balance on multiple collections at one time",
Export: false,
}
p.EnableBalanceOnMultipleCollections.Init(base.mgr)
}
// /////////////////////////////////////////////////////////////////////////////

View File

@ -381,6 +381,10 @@ func TestComponentParam(t *testing.T) {
assert.Equal(t, 10, Params.CollectionChannelCountFactor.GetAsInt())
assert.Equal(t, 3000, Params.AutoBalanceInterval.GetAsInt())
assert.Equal(t, 5, Params.BalanceSegmentBatchSize.GetAsInt())
assert.Equal(t, 1, Params.BalanceChannelBatchSize.GetAsInt())
assert.Equal(t, true, Params.EnableBalanceOnMultipleCollections.GetAsBool())
})
t.Run("test queryNodeConfig", func(t *testing.T) {