enhance: Balance the collection with the largest row count first (#40297)

issue: #37651
this PR enable to balance the collection with largest row count first,
to avoid temporary migration of small table data to new nodes during
their onboarding, only to be moved out again after the large table
balance, which would cause unnecessary load.

---------

Signed-off-by: Wei Liu <wei.liu@zilliz.com>
This commit is contained in:
wei liu 2025-03-31 16:00:19 +08:00 committed by GitHub
parent 15ec7bae4d
commit c02892e9fb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 433 additions and 14 deletions

View File

@ -19,6 +19,7 @@ package checkers
import ( import (
"context" "context"
"sort" "sort"
"strings"
"time" "time"
"github.com/samber/lo" "github.com/samber/lo"
@ -85,26 +86,28 @@ func (b *BalanceChecker) readyToCheck(ctx context.Context, collectionID int64) b
func (b *BalanceChecker) getReplicaForStoppingBalance(ctx context.Context) []int64 { func (b *BalanceChecker) getReplicaForStoppingBalance(ctx context.Context) []int64 {
ids := b.meta.GetAll(ctx) ids := b.meta.GetAll(ctx)
// Sort collections using the configured sort order
ids = b.sortCollections(ctx, ids)
if paramtable.Get().QueryCoordCfg.EnableStoppingBalance.GetAsBool() { if paramtable.Get().QueryCoordCfg.EnableStoppingBalance.GetAsBool() {
// balance collections influenced by stopping nodes
stoppingReplicas := make([]int64, 0)
for _, cid := range ids { for _, cid := range ids {
// if target and meta isn't ready, skip balance this collection // if target and meta isn't ready, skip balance this collection
if !b.readyToCheck(ctx, cid) { if !b.readyToCheck(ctx, cid) {
continue continue
} }
replicas := b.meta.ReplicaManager.GetByCollection(ctx, cid) replicas := b.meta.ReplicaManager.GetByCollection(ctx, cid)
stoppingReplicas := make([]int64, 0)
for _, replica := range replicas { for _, replica := range replicas {
if replica.RONodesCount() > 0 { if replica.RONodesCount() > 0 {
stoppingReplicas = append(stoppingReplicas, replica.GetID()) stoppingReplicas = append(stoppingReplicas, replica.GetID())
} }
} }
}
// do stopping balance only in this round
if len(stoppingReplicas) > 0 { if len(stoppingReplicas) > 0 {
return stoppingReplicas return stoppingReplicas
} }
} }
}
return nil return nil
} }
@ -123,9 +126,6 @@ func (b *BalanceChecker) getReplicaForNormalBalance(ctx context.Context) []int64
collection := b.meta.GetCollection(ctx, cid) collection := b.meta.GetCollection(ctx, cid)
return collection != nil && collection.GetStatus() == querypb.LoadStatus_Loaded return collection != nil && collection.GetStatus() == querypb.LoadStatus_Loaded
}) })
sort.Slice(loadedCollections, func(i, j int) bool {
return loadedCollections[i] < loadedCollections[j]
})
// Before performing balancing, check the CurrentTarget/LeaderView/Distribution for all collections. // Before performing balancing, check the CurrentTarget/LeaderView/Distribution for all collections.
// If any collection has unready info, skip the balance operation to avoid inconsistencies. // If any collection has unready info, skip the balance operation to avoid inconsistencies.
@ -138,6 +138,9 @@ func (b *BalanceChecker) getReplicaForNormalBalance(ctx context.Context) []int64
return nil return nil
} }
// Sort collections using the configured sort order
loadedCollections = b.sortCollections(ctx, loadedCollections)
// iterator one normal collection in one round // iterator one normal collection in one round
normalReplicasToBalance := make([]int64, 0) normalReplicasToBalance := make([]int64, 0)
hasUnbalancedCollection := false hasUnbalancedCollection := false
@ -187,6 +190,11 @@ func (b *BalanceChecker) Check(ctx context.Context) []task.Task {
if len(stoppingReplicas) > 0 { if len(stoppingReplicas) > 0 {
// check for stopping balance first // check for stopping balance first
segmentPlans, channelPlans = b.balanceReplicas(ctx, stoppingReplicas) segmentPlans, channelPlans = b.balanceReplicas(ctx, stoppingReplicas)
// iterate all collection to find a collection to balance
for len(segmentPlans) == 0 && len(channelPlans) == 0 && b.normalBalanceCollectionsCurrentRound.Len() > 0 {
replicasToBalance := b.getReplicaForStoppingBalance(ctx)
segmentPlans, channelPlans = b.balanceReplicas(ctx, replicasToBalance)
}
} else { } else {
// then check for auto balance // then check for auto balance
if time.Since(b.autoBalanceTs) > paramtable.Get().QueryCoordCfg.AutoBalanceInterval.GetAsDuration(time.Millisecond) { if time.Since(b.autoBalanceTs) > paramtable.Get().QueryCoordCfg.AutoBalanceInterval.GetAsDuration(time.Millisecond) {
@ -212,3 +220,37 @@ func (b *BalanceChecker) Check(ctx context.Context) []task.Task {
ret = append(ret, tasks...) ret = append(ret, tasks...)
return ret return ret
} }
func (b *BalanceChecker) sortCollections(ctx context.Context, collections []int64) []int64 {
sortOrder := strings.ToLower(Params.QueryCoordCfg.BalanceTriggerOrder.GetValue())
if sortOrder == "" {
sortOrder = "byrowcount" // Default to ByRowCount
}
// Define sorting functions
sortByRowCount := func(i, j int) bool {
rowCount1 := b.targetMgr.GetCollectionRowCount(ctx, collections[i], meta.CurrentTargetFirst)
rowCount2 := b.targetMgr.GetCollectionRowCount(ctx, collections[j], meta.CurrentTargetFirst)
return rowCount1 > rowCount2 || (rowCount1 == rowCount2 && collections[i] < collections[j])
}
sortByCollectionID := func(i, j int) bool {
return collections[i] < collections[j]
}
// Select the appropriate sorting function
var sortFunc func(i, j int) bool
switch sortOrder {
case "byrowcount":
sortFunc = sortByRowCount
case "bycollectionid":
sortFunc = sortByCollectionID
default:
log.Warn("Invalid balance sort order configuration, using default ByRowCount", zap.String("sortOrder", sortOrder))
sortFunc = sortByRowCount
}
// Sort the collections
sort.Slice(collections, sortFunc)
return collections
}

View File

@ -49,7 +49,7 @@ type BalanceCheckerTestSuite struct {
broker *meta.MockBroker broker *meta.MockBroker
nodeMgr *session.NodeManager nodeMgr *session.NodeManager
scheduler *task.MockScheduler scheduler *task.MockScheduler
targetMgr *meta.TargetManager targetMgr meta.TargetManagerInterface
} }
func (suite *BalanceCheckerTestSuite) SetupSuite() { func (suite *BalanceCheckerTestSuite) SetupSuite() {
@ -290,7 +290,7 @@ func (suite *BalanceCheckerTestSuite) TestStoppingBalance() {
suite.checker.meta.ReplicaManager.Put(ctx, mr2.IntoReplica()) suite.checker.meta.ReplicaManager.Put(ctx, mr2.IntoReplica())
// test stopping balance // test stopping balance
idsToBalance := []int64{int64(replicaID1), int64(replicaID2)} idsToBalance := []int64{int64(replicaID1)}
replicasToBalance := suite.checker.getReplicaForStoppingBalance(ctx) replicasToBalance := suite.checker.getReplicaForStoppingBalance(ctx)
suite.ElementsMatch(idsToBalance, replicasToBalance) suite.ElementsMatch(idsToBalance, replicasToBalance)
@ -305,7 +305,7 @@ func (suite *BalanceCheckerTestSuite) TestStoppingBalance() {
segPlans = append(segPlans, mockPlan) segPlans = append(segPlans, mockPlan)
suite.balancer.EXPECT().BalanceReplica(mock.Anything, mock.Anything).Return(segPlans, chanPlans) suite.balancer.EXPECT().BalanceReplica(mock.Anything, mock.Anything).Return(segPlans, chanPlans)
tasks := suite.checker.Check(context.TODO()) tasks := suite.checker.Check(context.TODO())
suite.Len(tasks, 2) suite.Len(tasks, 1)
} }
func (suite *BalanceCheckerTestSuite) TestTargetNotReady() { func (suite *BalanceCheckerTestSuite) TestTargetNotReady() {
@ -349,14 +349,16 @@ func (suite *BalanceCheckerTestSuite) TestTargetNotReady() {
// test normal balance when one collection has unready target // test normal balance when one collection has unready target
mockTarget.EXPECT().IsNextTargetExist(mock.Anything, mock.Anything).Return(true) mockTarget.EXPECT().IsNextTargetExist(mock.Anything, mock.Anything).Return(true)
mockTarget.EXPECT().IsCurrentTargetReady(mock.Anything, mock.Anything).Return(false) mockTarget.EXPECT().IsCurrentTargetReady(mock.Anything, mock.Anything).Return(false)
mockTarget.EXPECT().GetCollectionRowCount(mock.Anything, mock.Anything, mock.Anything).Return(100).Maybe()
replicasToBalance := suite.checker.getReplicaForNormalBalance(ctx) replicasToBalance := suite.checker.getReplicaForNormalBalance(ctx)
suite.Len(replicasToBalance, 0) suite.Len(replicasToBalance, 0)
// test stopping balance with target not ready // test stopping balance with target not ready
mockTarget.ExpectedCalls = nil mockTarget.ExpectedCalls = nil
mockTarget.EXPECT().IsNextTargetExist(mock.Anything, mock.Anything).Return(false) mockTarget.EXPECT().IsNextTargetExist(mock.Anything, mock.Anything).Return(false)
mockTarget.EXPECT().IsCurrentTargetExist(mock.Anything, int64(cid1), mock.Anything).Return(true) mockTarget.EXPECT().IsCurrentTargetExist(mock.Anything, int64(cid1), mock.Anything).Return(true).Maybe()
mockTarget.EXPECT().IsCurrentTargetExist(mock.Anything, int64(cid2), mock.Anything).Return(false) mockTarget.EXPECT().IsCurrentTargetExist(mock.Anything, int64(cid2), mock.Anything).Return(false).Maybe()
mockTarget.EXPECT().GetCollectionRowCount(mock.Anything, mock.Anything, mock.Anything).Return(100).Maybe()
mr1 := replica1.CopyForWrite() mr1 := replica1.CopyForWrite()
mr1.AddRONode(1) mr1.AddRONode(1)
suite.checker.meta.ReplicaManager.Put(ctx, mr1.IntoReplica()) suite.checker.meta.ReplicaManager.Put(ctx, mr1.IntoReplica())
@ -440,6 +442,299 @@ func (suite *BalanceCheckerTestSuite) TestAutoBalanceInterval() {
suite.Equal(funcCallCounter.Load(), int64(1)) suite.Equal(funcCallCounter.Load(), int64(1))
} }
func (suite *BalanceCheckerTestSuite) TestBalanceOrder() {
ctx := context.Background()
nodeID1, nodeID2 := int64(1), int64(2)
// set collections meta
cid1, replicaID1, partitionID1 := 1, 1, 1
collection1 := utils.CreateTestCollection(int64(cid1), int32(replicaID1))
collection1.Status = querypb.LoadStatus_Loaded
replica1 := utils.CreateTestReplica(int64(replicaID1), int64(cid1), []int64{nodeID1, nodeID2})
partition1 := utils.CreateTestPartition(int64(cid1), int64(partitionID1))
suite.checker.meta.CollectionManager.PutCollection(ctx, collection1, partition1)
suite.checker.meta.ReplicaManager.Put(ctx, replica1)
cid2, replicaID2, partitionID2 := 2, 2, 2
collection2 := utils.CreateTestCollection(int64(cid2), int32(replicaID2))
collection2.Status = querypb.LoadStatus_Loaded
replica2 := utils.CreateTestReplica(int64(replicaID2), int64(cid2), []int64{nodeID1, nodeID2})
partition2 := utils.CreateTestPartition(int64(cid2), int64(partitionID2))
suite.checker.meta.CollectionManager.PutCollection(ctx, collection2, partition2)
suite.checker.meta.ReplicaManager.Put(ctx, replica2)
// mock collection row count
mockTargetManager := meta.NewMockTargetManager(suite.T())
suite.checker.targetMgr = mockTargetManager
mockTargetManager.EXPECT().GetCollectionRowCount(mock.Anything, int64(cid1), mock.Anything).Return(int64(100)).Maybe()
mockTargetManager.EXPECT().GetCollectionRowCount(mock.Anything, int64(cid2), mock.Anything).Return(int64(200)).Maybe()
mockTargetManager.EXPECT().IsCurrentTargetReady(mock.Anything, mock.Anything).Return(true).Maybe()
mockTargetManager.EXPECT().IsNextTargetExist(mock.Anything, mock.Anything).Return(true).Maybe()
// mock stopping node
mr1 := replica1.CopyForWrite()
mr1.AddRONode(nodeID1)
suite.checker.meta.ReplicaManager.Put(ctx, mr1.IntoReplica())
mr2 := replica2.CopyForWrite()
mr2.AddRONode(nodeID2)
suite.checker.meta.ReplicaManager.Put(ctx, mr2.IntoReplica())
// test normal balance order
replicas := suite.checker.getReplicaForNormalBalance(ctx)
suite.Equal(replicas, []int64{int64(replicaID2)})
// test stopping balance order
replicas = suite.checker.getReplicaForStoppingBalance(ctx)
suite.Equal(replicas, []int64{int64(replicaID2)})
// mock collection row count
mockTargetManager.ExpectedCalls = nil
mockTargetManager.EXPECT().GetCollectionRowCount(mock.Anything, int64(cid1), mock.Anything).Return(int64(200)).Maybe()
mockTargetManager.EXPECT().GetCollectionRowCount(mock.Anything, int64(cid2), mock.Anything).Return(int64(100)).Maybe()
mockTargetManager.EXPECT().IsCurrentTargetReady(mock.Anything, mock.Anything).Return(true).Maybe()
mockTargetManager.EXPECT().IsNextTargetExist(mock.Anything, mock.Anything).Return(true).Maybe()
// test normal balance order
replicas = suite.checker.getReplicaForNormalBalance(ctx)
suite.Equal(replicas, []int64{int64(replicaID1)})
// test stopping balance order
replicas = suite.checker.getReplicaForStoppingBalance(ctx)
suite.Equal(replicas, []int64{int64(replicaID1)})
}
func (suite *BalanceCheckerTestSuite) TestSortCollections() {
ctx := context.Background()
// Set up test collections
cid1, cid2, cid3 := int64(1), int64(2), int64(3)
// Mock the target manager for row count returns
mockTargetManager := meta.NewMockTargetManager(suite.T())
suite.checker.targetMgr = mockTargetManager
// Collection 1: Low ID, High row count
mockTargetManager.EXPECT().GetCollectionRowCount(mock.Anything, cid1, mock.Anything).Return(int64(300)).Maybe()
// Collection 2: Middle ID, Low row count
mockTargetManager.EXPECT().GetCollectionRowCount(mock.Anything, cid2, mock.Anything).Return(int64(100)).Maybe()
// Collection 3: High ID, Middle row count
mockTargetManager.EXPECT().GetCollectionRowCount(mock.Anything, cid3, mock.Anything).Return(int64(200)).Maybe()
collections := []int64{cid1, cid2, cid3}
// Test ByRowCount sorting (default)
paramtable.Get().Save(Params.QueryCoordCfg.BalanceTriggerOrder.Key, "ByRowCount")
sortedCollections := suite.checker.sortCollections(ctx, collections)
suite.Equal([]int64{cid1, cid3, cid2}, sortedCollections, "Collections should be sorted by row count (highest first)")
// Test ByCollectionID sorting
paramtable.Get().Save(Params.QueryCoordCfg.BalanceTriggerOrder.Key, "ByCollectionID")
sortedCollections = suite.checker.sortCollections(ctx, collections)
suite.Equal([]int64{cid1, cid2, cid3}, sortedCollections, "Collections should be sorted by collection ID (ascending)")
// Test with empty sort order (should default to ByRowCount)
paramtable.Get().Save(Params.QueryCoordCfg.BalanceTriggerOrder.Key, "")
sortedCollections = suite.checker.sortCollections(ctx, collections)
suite.Equal([]int64{cid1, cid3, cid2}, sortedCollections, "Should default to ByRowCount when sort order is empty")
// Test with invalid sort order (should default to ByRowCount)
paramtable.Get().Save(Params.QueryCoordCfg.BalanceTriggerOrder.Key, "InvalidOrder")
sortedCollections = suite.checker.sortCollections(ctx, collections)
suite.Equal([]int64{cid1, cid3, cid2}, sortedCollections, "Should default to ByRowCount when sort order is invalid")
// Test with mixed case (should be case-insensitive)
paramtable.Get().Save(Params.QueryCoordCfg.BalanceTriggerOrder.Key, "bYcOlLeCtIoNiD")
sortedCollections = suite.checker.sortCollections(ctx, collections)
suite.Equal([]int64{cid1, cid2, cid3}, sortedCollections, "Should handle case-insensitive sort order names")
// Test with empty collection list
emptyCollections := []int64{}
sortedCollections = suite.checker.sortCollections(ctx, emptyCollections)
suite.Equal([]int64{}, sortedCollections, "Should handle empty collection list")
}
func (suite *BalanceCheckerTestSuite) TestSortCollectionsIntegration() {
ctx := context.Background()
// Set up test collections and nodes
nodeID1 := int64(1)
suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{
NodeID: nodeID1,
Address: "localhost",
Hostname: "localhost",
}))
suite.checker.meta.ResourceManager.HandleNodeUp(ctx, nodeID1)
// Create two collections to ensure sorting is triggered
cid1, replicaID1 := int64(1), int64(101)
collection1 := utils.CreateTestCollection(cid1, int32(replicaID1))
collection1.Status = querypb.LoadStatus_Loaded
replica1 := utils.CreateTestReplica(replicaID1, cid1, []int64{nodeID1})
suite.checker.meta.CollectionManager.PutCollection(ctx, collection1)
suite.checker.meta.ReplicaManager.Put(ctx, replica1)
// Add a second collection with different characteristics
cid2, replicaID2 := int64(2), int64(102)
collection2 := utils.CreateTestCollection(cid2, int32(replicaID2))
collection2.Status = querypb.LoadStatus_Loaded
replica2 := utils.CreateTestReplica(replicaID2, cid2, []int64{nodeID1})
suite.checker.meta.CollectionManager.PutCollection(ctx, collection2)
suite.checker.meta.ReplicaManager.Put(ctx, replica2)
// Mock target manager
mockTargetManager := meta.NewMockTargetManager(suite.T())
suite.checker.targetMgr = mockTargetManager
// Setup different row counts to test sorting
// Collection 1 has more rows than Collection 2
var getRowCountCallCount int
mockTargetManager.On("GetCollectionRowCount", mock.Anything, mock.Anything, mock.Anything).
Return(func(ctx context.Context, collectionID int64, scope int32) int64 {
getRowCountCallCount++
if collectionID == cid1 {
return 200 // More rows in collection 1
}
return 100 // Fewer rows in collection 2
})
mockTargetManager.On("IsCurrentTargetReady", mock.Anything, mock.Anything).Return(true)
mockTargetManager.On("IsNextTargetExist", mock.Anything, mock.Anything).Return(true)
// Configure for testing
paramtable.Get().Save(Params.QueryCoordCfg.AutoBalance.Key, "true")
paramtable.Get().Save(Params.QueryCoordCfg.BalanceTriggerOrder.Key, "ByRowCount")
// Clear first to avoid previous test state
suite.checker.normalBalanceCollectionsCurrentRound.Clear()
// Call normal balance
_ = suite.checker.getReplicaForNormalBalance(ctx)
// Verify GetCollectionRowCount was called at least twice (once for each collection)
// This confirms that the collections were sorted
suite.True(getRowCountCallCount >= 2, "GetCollectionRowCount should be called at least twice during normal balance")
// Reset counter and test stopping balance
getRowCountCallCount = 0
// Set up for stopping balance test
mr1 := replica1.CopyForWrite()
mr1.AddRONode(nodeID1)
suite.checker.meta.ReplicaManager.Put(ctx, mr1.IntoReplica())
mr2 := replica2.CopyForWrite()
mr2.AddRONode(nodeID1)
suite.checker.meta.ReplicaManager.Put(ctx, mr2.IntoReplica())
paramtable.Get().Save(Params.QueryCoordCfg.EnableStoppingBalance.Key, "true")
// Call stopping balance
_ = suite.checker.getReplicaForStoppingBalance(ctx)
// Verify GetCollectionRowCount was called at least twice during stopping balance
suite.True(getRowCountCallCount >= 2, "GetCollectionRowCount should be called at least twice during stopping balance")
}
func (suite *BalanceCheckerTestSuite) TestBalanceTriggerOrder() {
ctx := context.Background()
// Set up nodes
nodeID1, nodeID2 := int64(1), int64(2)
suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{
NodeID: nodeID1,
Address: "localhost",
Hostname: "localhost",
}))
suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{
NodeID: nodeID2,
Address: "localhost",
Hostname: "localhost",
}))
suite.checker.meta.ResourceManager.HandleNodeUp(ctx, nodeID1)
suite.checker.meta.ResourceManager.HandleNodeUp(ctx, nodeID2)
// Create collections with different row counts
cid1, replicaID1 := int64(1), int64(101)
collection1 := utils.CreateTestCollection(cid1, int32(replicaID1))
collection1.Status = querypb.LoadStatus_Loaded
replica1 := utils.CreateTestReplica(replicaID1, cid1, []int64{nodeID1, nodeID2})
suite.checker.meta.CollectionManager.PutCollection(ctx, collection1)
suite.checker.meta.ReplicaManager.Put(ctx, replica1)
cid2, replicaID2 := int64(2), int64(102)
collection2 := utils.CreateTestCollection(cid2, int32(replicaID2))
collection2.Status = querypb.LoadStatus_Loaded
replica2 := utils.CreateTestReplica(replicaID2, cid2, []int64{nodeID1, nodeID2})
suite.checker.meta.CollectionManager.PutCollection(ctx, collection2)
suite.checker.meta.ReplicaManager.Put(ctx, replica2)
cid3, replicaID3 := int64(3), int64(103)
collection3 := utils.CreateTestCollection(cid3, int32(replicaID3))
collection3.Status = querypb.LoadStatus_Loaded
replica3 := utils.CreateTestReplica(replicaID3, cid3, []int64{nodeID1, nodeID2})
suite.checker.meta.CollectionManager.PutCollection(ctx, collection3)
suite.checker.meta.ReplicaManager.Put(ctx, replica3)
// Mock the target manager
mockTargetManager := meta.NewMockTargetManager(suite.T())
suite.checker.targetMgr = mockTargetManager
// Set row counts: Collection 1 (highest), Collection 3 (middle), Collection 2 (lowest)
mockTargetManager.EXPECT().GetCollectionRowCount(mock.Anything, cid1, mock.Anything).Return(int64(300)).Maybe()
mockTargetManager.EXPECT().GetCollectionRowCount(mock.Anything, cid2, mock.Anything).Return(int64(100)).Maybe()
mockTargetManager.EXPECT().GetCollectionRowCount(mock.Anything, cid3, mock.Anything).Return(int64(200)).Maybe()
// Mark the current target as ready
mockTargetManager.EXPECT().IsCurrentTargetReady(mock.Anything, mock.Anything).Return(true).Maybe()
mockTargetManager.EXPECT().IsNextTargetExist(mock.Anything, mock.Anything).Return(true).Maybe()
mockTargetManager.EXPECT().IsCurrentTargetExist(mock.Anything, mock.Anything, mock.Anything).Return(true).Maybe()
// Enable auto balance
paramtable.Get().Save(Params.QueryCoordCfg.AutoBalance.Key, "true")
// Test with ByRowCount order (default)
paramtable.Get().Save(Params.QueryCoordCfg.BalanceTriggerOrder.Key, "ByRowCount")
suite.checker.normalBalanceCollectionsCurrentRound.Clear()
// Normal balance should pick the collection with highest row count first
replicas := suite.checker.getReplicaForNormalBalance(ctx)
suite.Contains(replicas, replicaID1, "Should balance collection with highest row count first")
// Add stopping nodes to test stopping balance
mr1 := replica1.CopyForWrite()
mr1.AddRONode(nodeID1)
suite.checker.meta.ReplicaManager.Put(ctx, mr1.IntoReplica())
mr2 := replica2.CopyForWrite()
mr2.AddRONode(nodeID1)
suite.checker.meta.ReplicaManager.Put(ctx, mr2.IntoReplica())
mr3 := replica3.CopyForWrite()
mr3.AddRONode(nodeID1)
suite.checker.meta.ReplicaManager.Put(ctx, mr3.IntoReplica())
// Enable stopping balance
paramtable.Get().Save(Params.QueryCoordCfg.EnableStoppingBalance.Key, "true")
// Stopping balance should also pick the collection with highest row count first
replicas = suite.checker.getReplicaForStoppingBalance(ctx)
suite.Contains(replicas, replicaID1, "Stopping balance should prioritize collection with highest row count")
// Test with ByCollectionID order
paramtable.Get().Save(Params.QueryCoordCfg.BalanceTriggerOrder.Key, "ByCollectionID")
suite.checker.normalBalanceCollectionsCurrentRound.Clear()
// Normal balance should pick the collection with lowest ID first
replicas = suite.checker.getReplicaForNormalBalance(ctx)
suite.Contains(replicas, replicaID1, "Should balance collection with lowest ID first")
// Stopping balance should also pick the collection with lowest ID first
replicas = suite.checker.getReplicaForStoppingBalance(ctx)
suite.Contains(replicas, replicaID1, "Stopping balance should prioritize collection with lowest ID")
}
func TestBalanceCheckerSuite(t *testing.T) { func TestBalanceCheckerSuite(t *testing.T) {
suite.Run(t, new(BalanceCheckerTestSuite)) suite.Run(t, new(BalanceCheckerTestSuite))
} }

View File

@ -74,6 +74,54 @@ func (_c *MockTargetManager_CanSegmentBeMoved_Call) RunAndReturn(run func(contex
return _c return _c
} }
// GetCollectionRowCount provides a mock function with given fields: ctx, collectionID, scope
func (_m *MockTargetManager) GetCollectionRowCount(ctx context.Context, collectionID int64, scope int32) int64 {
ret := _m.Called(ctx, collectionID, scope)
if len(ret) == 0 {
panic("no return value specified for GetCollectionRowCount")
}
var r0 int64
if rf, ok := ret.Get(0).(func(context.Context, int64, int32) int64); ok {
r0 = rf(ctx, collectionID, scope)
} else {
r0 = ret.Get(0).(int64)
}
return r0
}
// MockTargetManager_GetCollectionRowCount_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetCollectionRowCount'
type MockTargetManager_GetCollectionRowCount_Call struct {
*mock.Call
}
// GetCollectionRowCount is a helper method to define mock.On call
// - ctx context.Context
// - collectionID int64
// - scope int32
func (_e *MockTargetManager_Expecter) GetCollectionRowCount(ctx interface{}, collectionID interface{}, scope interface{}) *MockTargetManager_GetCollectionRowCount_Call {
return &MockTargetManager_GetCollectionRowCount_Call{Call: _e.mock.On("GetCollectionRowCount", ctx, collectionID, scope)}
}
func (_c *MockTargetManager_GetCollectionRowCount_Call) Run(run func(ctx context.Context, collectionID int64, scope int32)) *MockTargetManager_GetCollectionRowCount_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(int64), args[2].(int32))
})
return _c
}
func (_c *MockTargetManager_GetCollectionRowCount_Call) Return(_a0 int64) *MockTargetManager_GetCollectionRowCount_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockTargetManager_GetCollectionRowCount_Call) RunAndReturn(run func(context.Context, int64, int32) int64) *MockTargetManager_GetCollectionRowCount_Call {
_c.Call.Return(run)
return _c
}
// GetCollectionTargetVersion provides a mock function with given fields: ctx, collectionID, scope // GetCollectionTargetVersion provides a mock function with given fields: ctx, collectionID, scope
func (_m *MockTargetManager) GetCollectionTargetVersion(ctx context.Context, collectionID int64, scope int32) int64 { func (_m *MockTargetManager) GetCollectionTargetVersion(ctx context.Context, collectionID int64, scope int32) int64 {
ret := _m.Called(ctx, collectionID, scope) ret := _m.Called(ctx, collectionID, scope)

View File

@ -42,11 +42,15 @@ type CollectionTarget struct {
// record target status, if target has been save before milvus v2.4.19, then the target will lack of segment info. // record target status, if target has been save before milvus v2.4.19, then the target will lack of segment info.
lackSegmentInfo bool lackSegmentInfo bool
// cache collection total row count
totalRowCount int64
} }
func NewCollectionTarget(segments map[int64]*datapb.SegmentInfo, dmChannels map[string]*DmChannel, partitionIDs []int64) *CollectionTarget { func NewCollectionTarget(segments map[int64]*datapb.SegmentInfo, dmChannels map[string]*DmChannel, partitionIDs []int64) *CollectionTarget {
channel2Segments := make(map[string][]*datapb.SegmentInfo, len(dmChannels)) channel2Segments := make(map[string][]*datapb.SegmentInfo, len(dmChannels))
partition2Segments := make(map[int64][]*datapb.SegmentInfo, len(partitionIDs)) partition2Segments := make(map[int64][]*datapb.SegmentInfo, len(partitionIDs))
totalRowCount := int64(0)
for _, segment := range segments { for _, segment := range segments {
channel := segment.GetInsertChannel() channel := segment.GetInsertChannel()
if _, ok := channel2Segments[channel]; !ok { if _, ok := channel2Segments[channel]; !ok {
@ -58,6 +62,7 @@ func NewCollectionTarget(segments map[int64]*datapb.SegmentInfo, dmChannels map[
partition2Segments[partitionID] = make([]*datapb.SegmentInfo, 0) partition2Segments[partitionID] = make([]*datapb.SegmentInfo, 0)
} }
partition2Segments[partitionID] = append(partition2Segments[partitionID], segment) partition2Segments[partitionID] = append(partition2Segments[partitionID], segment)
totalRowCount += segment.GetNumOfRows()
} }
return &CollectionTarget{ return &CollectionTarget{
segments: segments, segments: segments,
@ -66,6 +71,7 @@ func NewCollectionTarget(segments map[int64]*datapb.SegmentInfo, dmChannels map[
dmChannels: dmChannels, dmChannels: dmChannels,
partitions: typeutil.NewSet(partitionIDs...), partitions: typeutil.NewSet(partitionIDs...),
version: time.Now().UnixNano(), version: time.Now().UnixNano(),
totalRowCount: totalRowCount,
} }
} }
@ -77,6 +83,7 @@ func FromPbCollectionTarget(target *querypb.CollectionTarget) *CollectionTarget
var partitions []int64 var partitions []int64
lackSegmentInfo := false lackSegmentInfo := false
totalRowCount := int64(0)
for _, t := range target.GetChannelTargets() { for _, t := range target.GetChannelTargets() {
if _, ok := channel2Segments[t.GetChannelName()]; !ok { if _, ok := channel2Segments[t.GetChannelName()]; !ok {
channel2Segments[t.GetChannelName()] = make([]*datapb.SegmentInfo, 0) channel2Segments[t.GetChannelName()] = make([]*datapb.SegmentInfo, 0)
@ -100,6 +107,7 @@ func FromPbCollectionTarget(target *querypb.CollectionTarget) *CollectionTarget
segments[segment.GetID()] = info segments[segment.GetID()] = info
channel2Segments[t.GetChannelName()] = append(channel2Segments[t.GetChannelName()], info) channel2Segments[t.GetChannelName()] = append(channel2Segments[t.GetChannelName()], info)
partition2Segments[partition.GetPartitionID()] = append(partition2Segments[partition.GetPartitionID()], info) partition2Segments[partition.GetPartitionID()] = append(partition2Segments[partition.GetPartitionID()], info)
totalRowCount += segment.GetNumOfRows()
} }
partitions = append(partitions, partition.GetPartitionID()) partitions = append(partitions, partition.GetPartitionID())
} }
@ -128,6 +136,7 @@ func FromPbCollectionTarget(target *querypb.CollectionTarget) *CollectionTarget
partitions: typeutil.NewSet(partitions...), partitions: typeutil.NewSet(partitions...),
version: target.GetVersion(), version: target.GetVersion(),
lackSegmentInfo: lackSegmentInfo, lackSegmentInfo: lackSegmentInfo,
totalRowCount: totalRowCount,
} }
} }
@ -222,6 +231,10 @@ func (p *CollectionTarget) Ready() bool {
return !p.lackSegmentInfo return !p.lackSegmentInfo
} }
func (p *CollectionTarget) GetRowCount() int64 {
return p.totalRowCount
}
type target struct { type target struct {
keyLock *lock.KeyLock[int64] // guards updateCollectionTarget keyLock *lock.KeyLock[int64] // guards updateCollectionTarget
// just maintain target at collection level // just maintain target at collection level

View File

@ -72,6 +72,7 @@ type TargetManagerInterface interface {
GetTargetJSON(ctx context.Context, scope TargetScope, collectionID int64) string GetTargetJSON(ctx context.Context, scope TargetScope, collectionID int64) string
GetPartitions(ctx context.Context, collectionID int64, scope TargetScope) ([]int64, error) GetPartitions(ctx context.Context, collectionID int64, scope TargetScope) ([]int64, error)
IsCurrentTargetReady(ctx context.Context, collectionID int64) bool IsCurrentTargetReady(ctx context.Context, collectionID int64) bool
GetCollectionRowCount(ctx context.Context, collectionID int64, scope TargetScope) int64
} }
type TargetManager struct { type TargetManager struct {
@ -609,3 +610,11 @@ func (mgr *TargetManager) IsCurrentTargetReady(ctx context.Context, collectionID
return target.Ready() return target.Ready()
} }
func (mgr *TargetManager) GetCollectionRowCount(ctx context.Context, collectionID int64, scope TargetScope) int64 {
target := mgr.getCollectionTarget(scope, collectionID)
if len(target) == 0 {
return 0
}
return target[0].GetRowCount()
}

View File

@ -620,6 +620,7 @@ func (suite *TargetManagerSuite) TestRecover() {
suite.Equal(int64(100), segment.GetNumOfRows()) suite.Equal(int64(100), segment.GetNumOfRows())
} }
suite.True(target.Ready()) suite.True(target.Ready())
suite.Equal(int64(200), target.GetRowCount())
// after recover, target info should be cleaned up // after recover, target info should be cleaned up
targets, err := suite.catalog.GetCollectionTargets(ctx) targets, err := suite.catalog.GetCollectionTargets(ctx)

View File

@ -1969,6 +1969,7 @@ type queryCoordConfig struct {
AutoBalance ParamItem `refreshable:"true"` AutoBalance ParamItem `refreshable:"true"`
AutoBalanceChannel ParamItem `refreshable:"true"` AutoBalanceChannel ParamItem `refreshable:"true"`
Balancer ParamItem `refreshable:"true"` Balancer ParamItem `refreshable:"true"`
BalanceTriggerOrder ParamItem `refreshable:"true"`
GlobalRowCountFactor ParamItem `refreshable:"true"` GlobalRowCountFactor ParamItem `refreshable:"true"`
ScoreUnbalanceTolerationFactor ParamItem `refreshable:"true"` ScoreUnbalanceTolerationFactor ParamItem `refreshable:"true"`
ReverseUnbalanceTolerationFactor ParamItem `refreshable:"true"` ReverseUnbalanceTolerationFactor ParamItem `refreshable:"true"`
@ -2106,6 +2107,16 @@ If this parameter is set false, Milvus simply searches the growing segments with
} }
p.Balancer.Init(base.mgr) p.Balancer.Init(base.mgr)
p.BalanceTriggerOrder = ParamItem{
Key: "queryCoord.balanceTriggerOrder",
Version: "2.5.8",
DefaultValue: "ByRowCount",
PanicIfEmpty: false,
Doc: "sorting order for collection balancing, options: ByRowCount, ByCollectionID",
Export: false,
}
p.BalanceTriggerOrder.Init(base.mgr)
p.GlobalRowCountFactor = ParamItem{ p.GlobalRowCountFactor = ParamItem{
Key: "queryCoord.globalRowCountFactor", Key: "queryCoord.globalRowCountFactor",
Version: "2.0.0", Version: "2.0.0",