From 6d4961b9789c0ee52dd19c69fc3c287594503f9c Mon Sep 17 00:00:00 2001 From: wei liu Date: Fri, 19 Sep 2025 17:46:01 +0800 Subject: [PATCH] enhance: Refactor balance checker with priority queue (#43992) issue: #43858 Refactor the balance checker implementation to use priority queues for managing collection balance operations, improving processing efficiency and order control. Changes include: - Export priority queue interfaces (Item, BaseItem, PriorityQueue) - Replace collection round-robin with priority-based queue system - Add BalanceCheckCollectionMaxCount configuration parameter - Optimize balance task generation with batch processing limits - Refactor processBalanceQueue method for different strategies - Enhance test coverage with comprehensive unit tests The new priority queue system processes collections based on row count or collection ID order, providing better control over balance operation priorities and resource utilization. --------- Signed-off-by: Wei Liu --- .../querycoordv2/balance/priority_queue.go | 32 +- .../balance/priority_queue_test.go | 38 +- .../balance/rowcount_based_balancer.go | 20 +- .../balance/score_based_balancer.go | 16 +- .../querycoordv2/checkers/balance_checker.go | 623 ++++-- .../checkers/balance_checker_test.go | 1776 ++++++++--------- pkg/util/paramtable/component_param.go | 11 + pkg/util/paramtable/component_param_test.go | 2 + 8 files changed, 1366 insertions(+), 1152 deletions(-) diff --git a/internal/querycoordv2/balance/priority_queue.go b/internal/querycoordv2/balance/priority_queue.go index 4182f0a73a..465181697f 100644 --- a/internal/querycoordv2/balance/priority_queue.go +++ b/internal/querycoordv2/balance/priority_queue.go @@ -20,24 +20,24 @@ import ( "container/heap" ) -type item interface { +type Item interface { getPriority() int setPriority(priority int) } -type baseItem struct { +type BaseItem struct { priority int } -func (b *baseItem) getPriority() int { +func (b *BaseItem) getPriority() int { return b.priority 
} -func (b *baseItem) setPriority(priority int) { +func (b *BaseItem) setPriority(priority int) { b.priority = priority } -type heapQueue []item +type heapQueue []Item func (hq heapQueue) Len() int { return len(hq) @@ -52,7 +52,7 @@ func (hq heapQueue) Swap(i, j int) { } func (hq *heapQueue) Push(x any) { - i := x.(item) + i := x.(Item) *hq = append(*hq, i) } @@ -64,22 +64,30 @@ func (hq *heapQueue) Pop() any { return ret } -type priorityQueue struct { +type PriorityQueue struct { heapQueue } -func newPriorityQueue() priorityQueue { +func NewPriorityQueue() PriorityQueue { hq := make(heapQueue, 0) heap.Init(&hq) - return priorityQueue{ + return PriorityQueue{ heapQueue: hq, } } -func (pq *priorityQueue) push(item item) { +func NewPriorityQueuePtr() *PriorityQueue { + hq := make(heapQueue, 0) + heap.Init(&hq) + return &PriorityQueue{ + heapQueue: hq, + } +} + +func (pq *PriorityQueue) Push(item Item) { heap.Push(&pq.heapQueue, item) } -func (pq *priorityQueue) pop() item { - return heap.Pop(&pq.heapQueue).(item) +func (pq *PriorityQueue) Pop() Item { + return heap.Pop(&pq.heapQueue).(Item) } diff --git a/internal/querycoordv2/balance/priority_queue_test.go b/internal/querycoordv2/balance/priority_queue_test.go index eb1fc53d3d..5ead1d16dd 100644 --- a/internal/querycoordv2/balance/priority_queue_test.go +++ b/internal/querycoordv2/balance/priority_queue_test.go @@ -23,74 +23,74 @@ import ( ) func TestMinPriorityQueue(t *testing.T) { - pq := newPriorityQueue() + pq := NewPriorityQueue() for i := 0; i < 5; i++ { priority := i % 3 nodeItem := newNodeItem(priority, int64(i)) - pq.push(&nodeItem) + pq.Push(&nodeItem) } - item := pq.pop() + item := pq.Pop() assert.Equal(t, item.getPriority(), 0) assert.Equal(t, item.(*nodeItem).nodeID, int64(0)) - item = pq.pop() + item = pq.Pop() assert.Equal(t, item.getPriority(), 0) assert.Equal(t, item.(*nodeItem).nodeID, int64(3)) - item = pq.pop() + item = pq.Pop() assert.Equal(t, item.getPriority(), 1) assert.Equal(t, 
item.(*nodeItem).nodeID, int64(1)) - item = pq.pop() + item = pq.Pop() assert.Equal(t, item.getPriority(), 1) assert.Equal(t, item.(*nodeItem).nodeID, int64(4)) - item = pq.pop() + item = pq.Pop() assert.Equal(t, item.getPriority(), 2) println(item.getPriority()) assert.Equal(t, item.(*nodeItem).nodeID, int64(2)) } func TestPopPriorityQueue(t *testing.T) { - pq := newPriorityQueue() + pq := NewPriorityQueue() for i := 0; i < 1; i++ { priority := 1 nodeItem := newNodeItem(priority, int64(i)) - pq.push(&nodeItem) + pq.Push(&nodeItem) } - item := pq.pop() + item := pq.Pop() assert.Equal(t, item.getPriority(), 1) assert.Equal(t, item.(*nodeItem).nodeID, int64(0)) - pq.push(item) + pq.Push(item) // if it's round robin, but not working - item = pq.pop() + item = pq.Pop() assert.Equal(t, item.getPriority(), 1) assert.Equal(t, item.(*nodeItem).nodeID, int64(0)) } func TestMaxPriorityQueue(t *testing.T) { - pq := newPriorityQueue() + pq := NewPriorityQueue() for i := 0; i < 5; i++ { priority := i % 3 nodeItem := newNodeItem(-priority, int64(i)) - pq.push(&nodeItem) + pq.Push(&nodeItem) } - item := pq.pop() + item := pq.Pop() assert.Equal(t, item.getPriority(), -2) assert.Equal(t, item.(*nodeItem).nodeID, int64(2)) - item = pq.pop() + item = pq.Pop() assert.Equal(t, item.getPriority(), -1) assert.Equal(t, item.(*nodeItem).nodeID, int64(4)) - item = pq.pop() + item = pq.Pop() assert.Equal(t, item.getPriority(), -1) assert.Equal(t, item.(*nodeItem).nodeID, int64(1)) - item = pq.pop() + item = pq.Pop() assert.Equal(t, item.getPriority(), 0) assert.Equal(t, item.(*nodeItem).nodeID, int64(3)) - item = pq.pop() + item = pq.Pop() assert.Equal(t, item.getPriority(), 0) assert.Equal(t, item.(*nodeItem).nodeID, int64(0)) } diff --git a/internal/querycoordv2/balance/rowcount_based_balancer.go b/internal/querycoordv2/balance/rowcount_based_balancer.go index 141ca6a1ff..01cd9647fb 100644 --- a/internal/querycoordv2/balance/rowcount_based_balancer.go +++ 
b/internal/querycoordv2/balance/rowcount_based_balancer.go @@ -57,9 +57,9 @@ func (b *RowCountBasedBalancer) AssignSegment(ctx context.Context, collectionID if len(nodeItems) == 0 { return nil } - queue := newPriorityQueue() + queue := NewPriorityQueue() for _, item := range nodeItems { - queue.push(item) + queue.Push(item) } sort.Slice(segments, func(i, j int) bool { @@ -70,7 +70,7 @@ func (b *RowCountBasedBalancer) AssignSegment(ctx context.Context, collectionID plans := make([]SegmentAssignPlan, 0, len(segments)) for _, s := range segments { // pick the node with the least row count and allocate to it. - ni := queue.pop().(*nodeItem) + ni := queue.Pop().(*nodeItem) plan := SegmentAssignPlan{ From: -1, To: ni.nodeID, @@ -82,7 +82,7 @@ func (b *RowCountBasedBalancer) AssignSegment(ctx context.Context, collectionID } // change node's score and push back ni.AddCurrentScoreDelta(float64(s.GetNumOfRows())) - queue.push(ni) + queue.Push(ni) } return plans } @@ -108,9 +108,9 @@ func (b *RowCountBasedBalancer) AssignChannel(ctx context.Context, collectionID return nil } - queue := newPriorityQueue() + queue := NewPriorityQueue() for _, item := range nodeItems { - queue.push(item) + queue.Push(item) } plans := make([]ChannelAssignPlan, 0) @@ -125,7 +125,7 @@ func (b *RowCountBasedBalancer) AssignChannel(ctx context.Context, collectionID } if ni == nil { // pick the node with the least channel num and allocate to it. 
- ni = queue.pop().(*nodeItem) + ni = queue.Pop().(*nodeItem) } plan := ChannelAssignPlan{ From: -1, @@ -135,7 +135,7 @@ func (b *RowCountBasedBalancer) AssignChannel(ctx context.Context, collectionID plans = append(plans, plan) // change node's score and push back ni.AddCurrentScoreDelta(1) - queue.push(ni) + queue.Push(ni) } return plans } @@ -408,7 +408,7 @@ func NewRowCountBasedBalancer( } type nodeItem struct { - baseItem + BaseItem fmt.Stringer nodeID int64 assignedScore float64 @@ -417,7 +417,7 @@ type nodeItem struct { func newNodeItem(currentScore int, nodeID int64) nodeItem { return nodeItem{ - baseItem: baseItem{}, + BaseItem: BaseItem{}, nodeID: nodeID, currentScore: float64(currentScore), } diff --git a/internal/querycoordv2/balance/score_based_balancer.go b/internal/querycoordv2/balance/score_based_balancer.go index 7de55f2be0..9dcf10bf70 100644 --- a/internal/querycoordv2/balance/score_based_balancer.go +++ b/internal/querycoordv2/balance/score_based_balancer.go @@ -79,9 +79,9 @@ func (b *ScoreBasedBalancer) assignSegment(br *balanceReport, collectionID int64 return nil } - queue := newPriorityQueue() + queue := NewPriorityQueue() for _, item := range nodeItemsMap { - queue.push(item) + queue.Push(item) } // sort segments by segment row count, if segment has same row count, sort by node's score @@ -100,9 +100,9 @@ func (b *ScoreBasedBalancer) assignSegment(br *balanceReport, collectionID int64 for _, s := range segments { func(s *meta.Segment) { // for each segment, pick the node with the least score - targetNode := queue.pop().(*nodeItem) + targetNode := queue.Pop().(*nodeItem) // make sure candidate is always push back - defer queue.push(targetNode) + defer queue.Push(targetNode) scoreChanges := b.calculateSegmentScore(s) sourceNode := nodeItemsMap[s.Node] @@ -173,9 +173,9 @@ func (b *ScoreBasedBalancer) assignChannel(br *balanceReport, collectionID int64 return nil } - queue := newPriorityQueue() + queue := NewPriorityQueue() for _, item := range 
nodeItemsMap { - queue.push(item) + queue.Push(item) } plans := make([]ChannelAssignPlan, 0, len(channels)) for _, ch := range channels { @@ -193,10 +193,10 @@ func (b *ScoreBasedBalancer) assignChannel(br *balanceReport, collectionID int64 } // for each channel, pick the node with the least score if targetNode == nil { - targetNode = queue.pop().(*nodeItem) + targetNode = queue.Pop().(*nodeItem) } // make sure candidate is always push back - defer queue.push(targetNode) + defer queue.Push(targetNode) scoreChanges := b.calculateChannelScore(ch, collectionID) sourceNode := nodeItemsMap[ch.Node] diff --git a/internal/querycoordv2/checkers/balance_checker.go b/internal/querycoordv2/checkers/balance_checker.go index f9f656754f..72dbc980d9 100644 --- a/internal/querycoordv2/checkers/balance_checker.go +++ b/internal/querycoordv2/checkers/balance_checker.go @@ -18,7 +18,6 @@ package checkers import ( "context" - "sort" "strings" "time" @@ -36,22 +35,111 @@ import ( "github.com/milvus-io/milvus/pkg/v2/log" "github.com/milvus-io/milvus/pkg/v2/proto/querypb" "github.com/milvus-io/milvus/pkg/v2/util/paramtable" - "github.com/milvus-io/milvus/pkg/v2/util/typeutil" ) +// balanceConfig holds all configuration parameters for balance operations. +// This configuration controls how balance tasks are generated and executed. +type balanceConfig struct { + // segmentBatchSize specifies the maximum number of segment balance tasks to generate in one round + segmentBatchSize int + // channelBatchSize specifies the maximum number of channel balance tasks to generate in one round + channelBatchSize int + // balanceOnMultipleCollections determines whether to balance multiple collections in one round. 
+ // If false, only balance one collection at a time to avoid resource contention + balanceOnMultipleCollections bool + // maxCheckCollectionCount limits the maximum number of collections to check in one round + // to prevent long-running balance operations + maxCheckCollectionCount int + // autoBalanceInterval controls how frequently automatic balance operations are triggered + autoBalanceInterval time.Duration + // segmentTaskTimeout specifies the timeout for segment balance tasks + segmentTaskTimeout time.Duration + // channelTaskTimeout specifies the timeout for channel balance tasks + channelTaskTimeout time.Duration +} + +// This method fetches all balance-related configuration parameters from the global +// parameter table and returns a balanceConfig struct for use in balance operations. +func (b *BalanceChecker) loadBalanceConfig() balanceConfig { + return balanceConfig{ + segmentBatchSize: paramtable.Get().QueryCoordCfg.BalanceSegmentBatchSize.GetAsInt(), + channelBatchSize: paramtable.Get().QueryCoordCfg.BalanceChannelBatchSize.GetAsInt(), + balanceOnMultipleCollections: paramtable.Get().QueryCoordCfg.EnableBalanceOnMultipleCollections.GetAsBool(), + maxCheckCollectionCount: paramtable.Get().QueryCoordCfg.BalanceCheckCollectionMaxCount.GetAsInt(), + autoBalanceInterval: paramtable.Get().QueryCoordCfg.AutoBalanceInterval.GetAsDuration(time.Millisecond), + segmentTaskTimeout: paramtable.Get().QueryCoordCfg.SegmentTaskTimeout.GetAsDuration(time.Millisecond), + channelTaskTimeout: paramtable.Get().QueryCoordCfg.ChannelTaskTimeout.GetAsDuration(time.Millisecond), + } +} + +// collectionBalanceItem represents a collection in the balance priority queue. +// Each item contains collection metadata and is used to determine the order +// in which collections should be processed for balance operations. 
+type collectionBalanceItem struct { + *balance.BaseItem + balancePriority int + + // collectionID and rowCount are used to calculate the priority + collectionID int64 + rowCount int + sortOrder string +} + +// The priority is determined by the BalanceTriggerOrder configuration: +// - "byrowcount": Higher row count collections get higher priority (processed first) +// - "bycollectionid": Collections with smaller IDs get higher priority +func newCollectionBalanceItem(collectionID int64, rowCount int, sortOrder string) *collectionBalanceItem { + priority := 0 + if sortOrder == "bycollectionid" { + priority = int(collectionID) + } else { + priority = -rowCount + } + + return &collectionBalanceItem{ + BaseItem: &balance.BaseItem{}, + collectionID: collectionID, + rowCount: rowCount, + sortOrder: sortOrder, + balancePriority: priority, + } +} + +func (c *collectionBalanceItem) getPriority() int { + return c.balancePriority +} + +func (c *collectionBalanceItem) setPriority(priority int) { + c.balancePriority = priority +} + // BalanceChecker checks the cluster distribution and generates balance tasks. +// It is responsible for monitoring the load distribution across query nodes and +// generating segment/channel move tasks to maintain optimal balance. +// +// The BalanceChecker operates in two modes: +// 1. Stopping Balance: High-priority balance for nodes that are being stopped or read-only nodes +// 2. Normal Balance: Regular automatic balance operations to optimize cluster performance +// +// Both modes use priority queues to determine the order in which collections are processed. 
type BalanceChecker struct { *checkerActivation - meta *meta.Meta - nodeManager *session.NodeManager - scheduler task.Scheduler - targetMgr meta.TargetManagerInterface + meta *meta.Meta + nodeManager *session.NodeManager + scheduler task.Scheduler + targetMgr meta.TargetManagerInterface + // getBalancerFunc returns the appropriate balancer for generating balance plans getBalancerFunc GetBalancerFunc - normalBalanceCollectionsCurrentRound typeutil.UniqueSet - stoppingBalanceCollectionsCurrentRound typeutil.UniqueSet + // normalBalanceQueue maintains collections pending normal balance operations, + // ordered by priority (row count or collection ID) + normalBalanceQueue *balance.PriorityQueue + // stoppingBalanceQueue maintains collections pending stopping balance operations, + // used when nodes are being gracefully stopped + stoppingBalanceQueue *balance.PriorityQueue - // record auto balance ts + // autoBalanceTs records the timestamp of the last auto balance operation + // to ensure balance operations don't happen too frequently autoBalanceTs time.Time } @@ -62,14 +150,14 @@ func NewBalanceChecker(meta *meta.Meta, getBalancerFunc GetBalancerFunc, ) *BalanceChecker { return &BalanceChecker{ - checkerActivation: newCheckerActivation(), - meta: meta, - targetMgr: targetMgr, - nodeManager: nodeMgr, - normalBalanceCollectionsCurrentRound: typeutil.NewUniqueSet(), - stoppingBalanceCollectionsCurrentRound: typeutil.NewUniqueSet(), - scheduler: scheduler, - getBalancerFunc: getBalancerFunc, + checkerActivation: newCheckerActivation(), + meta: meta, + targetMgr: targetMgr, + nodeManager: nodeMgr, + normalBalanceQueue: balance.NewPriorityQueuePtr(), + stoppingBalanceQueue: balance.NewPriorityQueuePtr(), + scheduler: scheduler, + getBalancerFunc: getBalancerFunc, } } @@ -81,6 +169,12 @@ func (b *BalanceChecker) Description() string { return "BalanceChecker checks the cluster distribution and generates balance tasks" } +// readyToCheck determines if a collection is ready for 
balance operations. +// A collection is considered ready if: +// 1. It exists in the metadata +// 2. It has either a current target or next target defined +// +// Returns true if the collection is ready for balance operations. func (b *BalanceChecker) readyToCheck(ctx context.Context, collectionID int64) bool { metaExist := (b.meta.GetCollection(ctx, collectionID) != nil) targetExist := b.targetMgr.IsNextTargetExist(ctx, collectionID) || b.targetMgr.IsCurrentTargetExist(ctx, collectionID, common.AllPartitionsID) @@ -88,117 +182,151 @@ func (b *BalanceChecker) readyToCheck(ctx context.Context, collectionID int64) b return metaExist && targetExist } -func (b *BalanceChecker) getReplicaForStoppingBalance(ctx context.Context) []int64 { - hasUnbalancedCollection := false - defer func() { - if !hasUnbalancedCollection { - b.stoppingBalanceCollectionsCurrentRound.Clear() - log.RatedDebug(10, "BalanceChecker has triggered stopping balance for all "+ - "collections in one round, clear collectionIDs for this round") - } - }() +type ReadyForBalanceFilter func(ctx context.Context, collectionID int64) bool +// filterCollectionForBalance filters all collections using the provided filter functions. +// Only collections that pass ALL filter criteria will be included in the result. +// This is used to select collections eligible for balance operations based on +// various conditions like load status, target readiness, etc. +// Returns a slice of collection IDs that pass all filter criteria. 
+func (b *BalanceChecker) filterCollectionForBalance(ctx context.Context, filter ...ReadyForBalanceFilter) []int64 { ids := b.meta.GetAll(ctx) - // Sort collections using the configured sort order - ids = b.sortCollections(ctx, ids) - - if paramtable.Get().QueryCoordCfg.EnableStoppingBalance.GetAsBool() { - for _, cid := range ids { - // if target and meta isn't ready, skip balance this collection - if !b.readyToCheck(ctx, cid) { - continue - } - if b.stoppingBalanceCollectionsCurrentRound.Contain(cid) { - continue - } - - replicas := b.meta.ReplicaManager.GetByCollection(ctx, cid) - stoppingReplicas := make([]int64, 0) - for _, replica := range replicas { - // If there are some delegator work on query node, we need to balance channel to streamingnode forcely. - channelRONodes := make([]int64, 0) - if streamingutil.IsStreamingServiceEnabled() { - _, channelRONodes = utils.GetChannelRWAndRONodesFor260(replica, b.nodeManager) - } - if replica.RONodesCount()+replica.ROSQNodesCount() > 0 || len(channelRONodes) > 0 { - stoppingReplicas = append(stoppingReplicas, replica.GetID()) - } - } - if len(stoppingReplicas) > 0 { - hasUnbalancedCollection = true - b.stoppingBalanceCollectionsCurrentRound.Insert(cid) - return stoppingReplicas + ret := make([]int64, 0) + for _, cid := range ids { + shouldInclude := true + for _, f := range filter { + if !f(ctx, cid) { + shouldInclude = false + break } } + if shouldInclude { + ret = append(ret, cid) + } } - - // finish current round for stopping balance if no unbalanced collection - hasUnbalancedCollection = false - return nil + return ret } -func (b *BalanceChecker) getReplicaForNormalBalance(ctx context.Context) []int64 { - hasUnbalancedCollection := false - defer func() { - if !hasUnbalancedCollection { - b.normalBalanceCollectionsCurrentRound.Clear() - log.RatedDebug(10, "BalanceChecker has triggered normal balance for all "+ - "collections in one round, clear collectionIDs for this round") - } - }() - - // 1. 
no stopping balance and auto balance is disabled, return empty collections for balance - // 2. when balancer isn't active, skip auto balance - if !Params.QueryCoordCfg.AutoBalance.GetAsBool() || !b.IsActive() { - // finish current round for normal balance if normal balance isn't triggered - hasUnbalancedCollection = false - return nil +// constructStoppingBalanceQueue creates and populates the stopping balance priority queue. +// This queue contains collections that need balance operations due to nodes being stopped. +// Collections are ordered by priority (row count or collection ID based on configuration). +// +// Returns a new priority queue with all eligible collections for stopping balance. +// Note: cause stopping balance need to move out all data from the node, so we need to check all collections. +func (b *BalanceChecker) constructStoppingBalanceQueue(ctx context.Context) *balance.PriorityQueue { + sortOrder := strings.ToLower(Params.QueryCoordCfg.BalanceTriggerOrder.GetValue()) + if sortOrder == "" { + sortOrder = "byrowcount" // Default to ByRowCount } - ids := b.meta.GetAll(ctx) - // all replicas belonging to loading collection will be skipped - loadedCollections := lo.Filter(ids, func(cid int64, _ int) bool { + ret := b.filterCollectionForBalance(ctx, b.readyToCheck) + pq := balance.NewPriorityQueuePtr() + for _, cid := range ret { + rowCount := b.targetMgr.GetCollectionRowCount(ctx, cid, meta.CurrentTargetFirst) + item := newCollectionBalanceItem(cid, int(rowCount), sortOrder) + pq.Push(item) + } + b.stoppingBalanceQueue = pq + return pq +} + +// constructNormalBalanceQueue creates and populates the normal balance priority queue. +// This queue contains loaded collections that are ready for regular balance operations. +// Collections must meet multiple criteria: +// 1. Be ready for balance operations (metadata and target exist) +// 2. Have loaded status (actively serving queries) +// 3. 
Have current target ready (consistent state) +// +// Returns a new priority queue with all eligible collections for normal balance. +func (b *BalanceChecker) constructNormalBalanceQueue(ctx context.Context) *balance.PriorityQueue { + filterLoadedCollections := func(ctx context.Context, cid int64) bool { collection := b.meta.GetCollection(ctx, cid) return collection != nil && collection.GetStatus() == querypb.LoadStatus_Loaded - }) - - // Before performing balancing, check the CurrentTarget/LeaderView/Distribution for all collections. - // If any collection has unready info, skip the balance operation to avoid inconsistencies. - notReadyCollections := lo.Filter(loadedCollections, func(cid int64, _ int) bool { - // todo: should also check distribution and leader view in the future - return !b.targetMgr.IsCurrentTargetReady(ctx, cid) - }) - if len(notReadyCollections) > 0 { - // finish current round for normal balance if any collection isn't ready - hasUnbalancedCollection = false - log.RatedInfo(10, "skip normal balance, cause collection not ready for balance", zap.Int64s("collectionIDs", notReadyCollections)) - return nil } - // Sort collections using the configured sort order - loadedCollections = b.sortCollections(ctx, loadedCollections) - - // iterator one normal collection in one round - normalReplicasToBalance := make([]int64, 0) - for _, cid := range loadedCollections { - if b.normalBalanceCollectionsCurrentRound.Contain(cid) { - log.RatedDebug(10, "BalanceChecker is balancing this collection, skip balancing in this round", - zap.Int64("collectionID", cid)) - continue - } - hasUnbalancedCollection = true - b.normalBalanceCollectionsCurrentRound.Insert(cid) - for _, replica := range b.meta.ReplicaManager.GetByCollection(ctx, cid) { - normalReplicasToBalance = append(normalReplicasToBalance, replica.GetID()) - } - break + filterTargetReadyCollections := func(ctx context.Context, cid int64) bool { + return b.targetMgr.IsCurrentTargetReady(ctx, cid) } - return 
normalReplicasToBalance + + sortOrder := strings.ToLower(Params.QueryCoordCfg.BalanceTriggerOrder.GetValue()) + if sortOrder == "" { + sortOrder = "byrowcount" // Default to ByRowCount + } + + ret := b.filterCollectionForBalance(ctx, b.readyToCheck, filterLoadedCollections, filterTargetReadyCollections) + pq := balance.NewPriorityQueuePtr() + for _, cid := range ret { + rowCount := b.targetMgr.GetCollectionRowCount(ctx, cid, meta.CurrentTargetFirst) + item := newCollectionBalanceItem(cid, int(rowCount), sortOrder) + pq.Push(item) + } + b.normalBalanceQueue = pq + return pq } -func (b *BalanceChecker) balanceReplicas(ctx context.Context, replicaIDs []int64) ([]balance.SegmentAssignPlan, []balance.ChannelAssignPlan) { +// getReplicaForStoppingBalance returns replicas that need stopping balance operations. +// A replica needs stopping balance if it has: +// 1. Read-only (RO) nodes that need to be drained +// 2. Read-only streaming query (ROSQ) nodes that need to be drained +// 3. Channel read-only nodes when streaming service is enabled +// +// These replicas need immediate attention to move data off nodes that are being stopped. +// +// Returns a slice of replica IDs that need stopping balance operations. 
+func (b *BalanceChecker) getReplicaForStoppingBalance(ctx context.Context, collectionID int64) []int64 { + filterReplicaWithRONodes := func(replica *meta.Replica, _ int) bool { + channelRONodes := make([]int64, 0) + if streamingutil.IsStreamingServiceEnabled() { + _, channelRONodes = utils.GetChannelRWAndRONodesFor260(replica, b.nodeManager) + } + return replica.RONodesCount()+replica.ROSQNodesCount() > 0 || len(channelRONodes) > 0 + } + + // filter replicas with RONodes or channelRONodes + replicas := b.meta.ReplicaManager.GetByCollection(ctx, collectionID) + ret := make([]int64, 0) + for _, replica := range replicas { + if filterReplicaWithRONodes(replica, 0) { + ret = append(ret, replica.GetID()) + } + } + return ret +} + +// getReplicaForNormalBalance returns all replicas for a collection for normal balance operations. +// Unlike stopping balance, normal balance considers all replicas regardless of their node status. +// This allows for comprehensive load balancing across the entire collection. +// +// Returns a slice of all replica IDs for the collection. +func (b *BalanceChecker) getReplicaForNormalBalance(ctx context.Context, collectionID int64) []int64 { + replicas := b.meta.ReplicaManager.GetByCollection(ctx, collectionID) + return lo.Map(replicas, func(replica *meta.Replica, _ int) int64 { + return replica.GetID() + }) +} + +// generateBalanceTasksFromReplicas generates balance tasks for the given replicas. +// This method is the core of the balance operation that: +// 1. Uses the balancer to create segment and channel assignment plans +// 2. Converts these plans into executable tasks +// 3. 
Sets appropriate priorities and reasons for the tasks +// +// The process involves: +// - Getting balance plans from the configured balancer for each replica +// - Creating segment move tasks from segment assignment plans +// - Creating channel move tasks from channel assignment plans +// - Setting task metadata (priority, reason, timeout) +// +// Returns: +// - segmentTasks: tasks for moving segments between nodes +// - channelTasks: tasks for moving channels between nodes +func (b *BalanceChecker) generateBalanceTasksFromReplicas(ctx context.Context, replicas []int64, config balanceConfig) ([]task.Task, []task.Task) { + if len(replicas) == 0 { + return nil, nil + } + segmentPlans, channelPlans := make([]balance.SegmentAssignPlan, 0), make([]balance.ChannelAssignPlan, 0) - for _, rid := range replicaIDs { + for _, rid := range replicas { replica := b.meta.ReplicaManager.Get(ctx, rid) if replica == nil { continue @@ -210,62 +338,103 @@ func (b *BalanceChecker) balanceReplicas(ctx context.Context, replicaIDs []int64 balance.PrintNewBalancePlans(replica.GetCollectionID(), replica.GetID(), sPlans, cPlans) } } - return segmentPlans, channelPlans -} - -// Notice: balance checker will generate tasks for multiple collections in one round, -// so generated tasks will be submitted to scheduler directly, and return nil -func (b *BalanceChecker) Check(ctx context.Context) []task.Task { - segmentBatchSize := paramtable.Get().QueryCoordCfg.BalanceSegmentBatchSize.GetAsInt() - channelBatchSize := paramtable.Get().QueryCoordCfg.BalanceChannelBatchSize.GetAsInt() - balanceOnMultipleCollections := paramtable.Get().QueryCoordCfg.EnableBalanceOnMultipleCollections.GetAsBool() segmentTasks := make([]task.Task, 0) channelTasks := make([]task.Task, 0) - - generateBalanceTaskForReplicas := func(replicas []int64) { - segmentPlans, channelPlans := b.balanceReplicas(ctx, replicas) - tasks := balance.CreateSegmentTasksFromPlans(ctx, b.ID(), 
Params.QueryCoordCfg.SegmentTaskTimeout.GetAsDuration(time.Millisecond), segmentPlans) - task.SetPriority(task.TaskPriorityLow, tasks...) - task.SetReason("segment unbalanced", tasks...) - segmentTasks = append(segmentTasks, tasks...) - - tasks = balance.CreateChannelTasksFromPlans(ctx, b.ID(), Params.QueryCoordCfg.ChannelTaskTimeout.GetAsDuration(time.Millisecond), channelPlans) - task.SetReason("channel unbalanced", tasks...) - channelTasks = append(channelTasks, tasks...) - } - - stoppingReplicas := b.getReplicaForStoppingBalance(ctx) - if len(stoppingReplicas) > 0 { - // check for stopping balance first - generateBalanceTaskForReplicas(stoppingReplicas) - // iterate all collection to find a collection to balance - for len(segmentTasks) < segmentBatchSize && len(channelTasks) < channelBatchSize && b.stoppingBalanceCollectionsCurrentRound.Len() > 0 { - if !balanceOnMultipleCollections && (len(segmentTasks) > 0 || len(channelTasks) > 0) { - // if balance on multiple collections is disabled, and there are already some tasks, break - break - } - replicasToBalance := b.getReplicaForStoppingBalance(ctx) - generateBalanceTaskForReplicas(replicasToBalance) - } - } else { - // then check for auto balance - if time.Since(b.autoBalanceTs) > paramtable.Get().QueryCoordCfg.AutoBalanceInterval.GetAsDuration(time.Millisecond) { - b.autoBalanceTs = time.Now() - replicasToBalance := b.getReplicaForNormalBalance(ctx) - generateBalanceTaskForReplicas(replicasToBalance) - // iterate all collection to find a collection to balance - for len(segmentTasks) < segmentBatchSize && len(channelTasks) < channelBatchSize && b.normalBalanceCollectionsCurrentRound.Len() > 0 { - if !balanceOnMultipleCollections && (len(segmentTasks) > 0 || len(channelTasks) > 0) { - // if balance on multiple collections is disabled, and there are already some tasks, break - break - } - replicasToBalance := b.getReplicaForNormalBalance(ctx) - generateBalanceTaskForReplicas(replicasToBalance) - } + // Create 
segment tasks with error handling + if len(segmentPlans) > 0 { + tasks := balance.CreateSegmentTasksFromPlans(ctx, b.ID(), config.segmentTaskTimeout, segmentPlans) + if len(tasks) > 0 { + task.SetPriority(task.TaskPriorityLow, tasks...) + task.SetReason("segment unbalanced", tasks...) + segmentTasks = append(segmentTasks, tasks...) } } + // Create channel tasks with error handling + if len(channelPlans) > 0 { + tasks := balance.CreateChannelTasksFromPlans(ctx, b.ID(), config.channelTaskTimeout, channelPlans) + if len(tasks) > 0 { + task.SetReason("channel unbalanced", tasks...) + channelTasks = append(channelTasks, tasks...) + } + } + + return segmentTasks, channelTasks +} + +// processBalanceQueue processes balance queue with common logic for both normal and stopping balance. +// This is a template method that implements the core queue processing algorithm while allowing +// different balance types to provide their own specific logic through function parameters. +// +// The method implements several safeguards: +// 1. Batch size limits to prevent generating too many tasks at once +// 2. Collection count limits to prevent long-running operations +// 3. Multi-collection balance control to avoid resource contention +// +// Processing flow: +// 1. Get or construct the priority queue for collections +// 2. Pop collections from queue in priority order +// 3. Get replicas that need balance for the collection +// 4. Generate balance tasks for those replicas +// 5. 
Accumulate tasks until batch limits are reached +// +// Parameters: +// - ctx: context for the operation +// - getReplicasFunc: function to get replicas for a collection (normal vs stopping) +// - constructQueueFunc: function to construct a new priority queue if needed +// - getQueueFunc: function to get the existing priority queue +// - config: balance configuration with batch sizes and limits +// +// Returns: +// - generatedSegmentTaskNum: number of generated segment balance tasks +// - generatedChannelTaskNum: number of generated channel balance tasks +func (b *BalanceChecker) processBalanceQueue( + ctx context.Context, + getReplicasFunc func(context.Context, int64) []int64, + constructQueueFunc func(context.Context) *balance.PriorityQueue, + getQueueFunc func() *balance.PriorityQueue, + config balanceConfig, +) (int, int) { + checkCollectionCount := 0 + pq := getQueueFunc() + if pq == nil || pq.Len() == 0 { + pq = constructQueueFunc(ctx) + } + + generatedSegmentTaskNum := 0 + generatedChannelTaskNum := 0 + + for generatedSegmentTaskNum < config.segmentBatchSize && + generatedChannelTaskNum < config.channelBatchSize && + checkCollectionCount < config.maxCheckCollectionCount && + pq.Len() > 0 { + // Break if balanceOnMultipleCollections is disabled and we already have tasks + if !config.balanceOnMultipleCollections && (generatedSegmentTaskNum > 0 || generatedChannelTaskNum > 0) { + log.Debug("Balance on multiple collections disabled, stopping after first collection") + break + } + + item := pq.Pop().(*collectionBalanceItem) + checkCollectionCount++ + + replicasToBalance := getReplicasFunc(ctx, item.collectionID) + if len(replicasToBalance) == 0 { + continue + } + + newSegmentTasks, newChannelTasks := b.generateBalanceTasksFromReplicas(ctx, replicasToBalance, config) + generatedSegmentTaskNum += len(newSegmentTasks) + generatedChannelTaskNum += len(newChannelTasks) + b.submitTasks(newSegmentTasks, newChannelTasks) + } + return generatedSegmentTaskNum, 
generatedChannelTaskNum +} + +// submitTasks submits the generated balance tasks to the scheduler for execution. +// This method handles the final step of the balance process by adding all +// generated tasks to the task scheduler, which will execute them asynchronously. +func (b *BalanceChecker) submitTasks(segmentTasks, channelTasks []task.Task) { for _, task := range segmentTasks { b.scheduler.Add(task) } @@ -273,45 +442,99 @@ func (b *BalanceChecker) Check(ctx context.Context) []task.Task { for _, task := range channelTasks { b.scheduler.Add(task) } +} +// Check is the main entry point for balance operations. +// This method implements a two-phase balance strategy with clear priorities: +// +// **Phase 1: Stopping Balance (Higher Priority)** +// - Handles nodes that are being gracefully stopped +// - Moves data off read-only nodes to active nodes +// - Critical for maintaining service availability during node shutdowns +// - Runs immediately when stopping nodes are detected +// +// **Phase 2: Normal Balance (Lower Priority)** +// - Performs regular load balancing to optimize cluster performance +// - Runs periodically based on autoBalanceInterval configuration +// - Considers all collections and distributes load evenly +// - Skipped if stopping balance tasks were generated +// +// **Key Design Decisions:** +// 1. Tasks are submitted directly to scheduler and nil is returned +// (unlike other checkers that return tasks to caller) +// 2. Stopping balance always takes precedence over normal balance +// 3. Performance monitoring alerts for operations > 100ms +// 4. Configuration is loaded fresh each time to respect dynamic updates +// +// **Return Value:** +// Always returns nil because tasks are submitted directly to the scheduler. +// This design allows the balance checker to handle multiple collections +// and large numbers of tasks efficiently. 
+// +// **Performance Monitoring:** +// The method tracks execution time and logs warnings for slow operations +// to help identify performance bottlenecks in large clusters. +func (b *BalanceChecker) Check(ctx context.Context) []task.Task { + // Skip balance operations if the checker is not active + if !b.IsActive() { + return nil + } + + // Performance monitoring: track execution time + start := time.Now() + defer func() { + duration := time.Since(start) + if duration > 100*time.Millisecond { + log.Info("Balance check too slow", zap.Duration("duration", duration)) + } + }() + + // Load current configuration to respect dynamic parameter changes + config := b.loadBalanceConfig() + + // Phase 1: Process stopping balance first (higher priority) + // This handles nodes that are being gracefully stopped and need immediate attention + if paramtable.Get().QueryCoordCfg.EnableStoppingBalance.GetAsBool() { + generatedSegmentTaskNum, generatedChannelTaskNum := b.processBalanceQueue(ctx, + b.getReplicaForStoppingBalance, + b.constructStoppingBalanceQueue, + func() *balance.PriorityQueue { return b.stoppingBalanceQueue }, + config) + + if generatedSegmentTaskNum > 0 || generatedChannelTaskNum > 0 { + // clean up the normal balance queue when stopping balance generated tasks + // make sure that next time when trigger normal balance, a new normal balance round will be started + b.normalBalanceQueue = nil + + return nil + } + } + + // Phase 2: Process normal balance if no stopping balance was needed + // This handles regular load balancing operations for cluster optimization + if paramtable.Get().QueryCoordCfg.AutoBalance.GetAsBool() { + // Respect the auto balance interval to prevent too frequent operations + if time.Since(b.autoBalanceTs) <= config.autoBalanceInterval { + return nil + } + + generatedSegmentTaskNum, generatedChannelTaskNum := b.processBalanceQueue(ctx, + b.getReplicaForNormalBalance, + b.constructNormalBalanceQueue, + func() *balance.PriorityQueue { return 
b.normalBalanceQueue }, + config) + + // Submit normal balance tasks if any were generated + // Update the auto balance timestamp to enforce the interval + if generatedSegmentTaskNum > 0 || generatedChannelTaskNum > 0 { + b.autoBalanceTs = time.Now() + + // clean up the stopping balance queue when normal balance generated tasks + // make sure that next time when trigger stopping balance, a new stopping balance round will be started + b.stoppingBalanceQueue = nil + } + } + + // Always return nil as tasks are submitted directly to scheduler return nil } - -func (b *BalanceChecker) sortCollections(ctx context.Context, collections []int64) []int64 { - sortOrder := strings.ToLower(Params.QueryCoordCfg.BalanceTriggerOrder.GetValue()) - if sortOrder == "" { - sortOrder = "byrowcount" // Default to ByRowCount - } - - collectionRowCountMap := make(map[int64]int64) - for _, cid := range collections { - collectionRowCountMap[cid] = b.targetMgr.GetCollectionRowCount(ctx, cid, meta.CurrentTargetFirst) - } - - // Define sorting functions - sortByRowCount := func(i, j int) bool { - rowCount1 := collectionRowCountMap[collections[i]] - rowCount2 := collectionRowCountMap[collections[j]] - return rowCount1 > rowCount2 || (rowCount1 == rowCount2 && collections[i] < collections[j]) - } - - sortByCollectionID := func(i, j int) bool { - return collections[i] < collections[j] - } - - // Select the appropriate sorting function - var sortFunc func(i, j int) bool - switch sortOrder { - case "byrowcount": - sortFunc = sortByRowCount - case "bycollectionid": - sortFunc = sortByCollectionID - default: - log.Warn("Invalid balance sort order configuration, using default ByRowCount", zap.String("sortOrder", sortOrder)) - sortFunc = sortByRowCount - } - - // Sort the collections - sort.Slice(collections, sortFunc) - return collections -} diff --git a/internal/querycoordv2/checkers/balance_checker_test.go b/internal/querycoordv2/checkers/balance_checker_test.go index 280e9ea22b..cdda585f08 100644 
--- a/internal/querycoordv2/checkers/balance_checker_test.go +++ b/internal/querycoordv2/checkers/balance_checker_test.go @@ -21,993 +21,963 @@ import ( "testing" "time" - "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/suite" - "go.uber.org/atomic" + "github.com/bytedance/mockey" + "github.com/stretchr/testify/assert" - etcdkv "github.com/milvus-io/milvus/internal/kv/etcd" - "github.com/milvus-io/milvus/internal/metastore/kv/querycoord" "github.com/milvus-io/milvus/internal/querycoordv2/balance" "github.com/milvus-io/milvus/internal/querycoordv2/meta" - . "github.com/milvus-io/milvus/internal/querycoordv2/params" "github.com/milvus-io/milvus/internal/querycoordv2/session" "github.com/milvus-io/milvus/internal/querycoordv2/task" "github.com/milvus-io/milvus/internal/querycoordv2/utils" - "github.com/milvus-io/milvus/pkg/v2/kv" + "github.com/milvus-io/milvus/internal/util/streamingutil" "github.com/milvus-io/milvus/pkg/v2/proto/datapb" - "github.com/milvus-io/milvus/pkg/v2/proto/querypb" - "github.com/milvus-io/milvus/pkg/v2/util/etcd" "github.com/milvus-io/milvus/pkg/v2/util/paramtable" ) -type BalanceCheckerTestSuite struct { - suite.Suite - kv kv.MetaKv - checker *BalanceChecker - balancer *balance.MockBalancer - meta *meta.Meta - broker *meta.MockBroker - nodeMgr *session.NodeManager - scheduler *task.MockScheduler - targetMgr meta.TargetManagerInterface +// createMockPriorityQueue creates a mock priority queue for testing +func createMockPriorityQueue() *balance.PriorityQueue { + return balance.NewPriorityQueuePtr() } -func (suite *BalanceCheckerTestSuite) SetupSuite() { - paramtable.Init() +// Helper function to create a test BalanceChecker +func createTestBalanceChecker() *BalanceChecker { + metaInstance := &meta.Meta{ + CollectionManager: meta.NewCollectionManager(nil), + } + targetMgr := meta.NewTargetManager(nil, nil) + nodeMgr := &session.NodeManager{} + scheduler := task.NewScheduler(context.Background(), nil, nil, nil, nil, nil, nil) + 
balancer := balance.NewScoreBasedBalancer(nil, nil, nil, nil, nil) + getBalancerFunc := func() balance.Balance { return balancer } + + return NewBalanceChecker(metaInstance, targetMgr, nodeMgr, scheduler, getBalancerFunc) } -func (suite *BalanceCheckerTestSuite) SetupTest() { - var err error - config := GenerateEtcdConfig() - cli, err := etcd.GetEtcdClient( - config.UseEmbedEtcd.GetAsBool(), - config.EtcdUseSSL.GetAsBool(), - config.Endpoints.GetAsStrings(), - config.EtcdTLSCert.GetValue(), - config.EtcdTLSKey.GetValue(), - config.EtcdTLSCACert.GetValue(), - config.EtcdTLSMinVersion.GetValue()) - suite.Require().NoError(err) - suite.kv = etcdkv.NewEtcdKV(cli, config.MetaRootPath.GetValue()) +// ============================================================================= +// Basic Interface Tests +// ============================================================================= - // meta - store := querycoord.NewCatalog(suite.kv) - idAllocator := RandomIncrementIDAllocator() - suite.nodeMgr = session.NewNodeManager() - suite.meta = meta.NewMeta(idAllocator, store, suite.nodeMgr) - suite.broker = meta.NewMockBroker(suite.T()) - suite.scheduler = task.NewMockScheduler(suite.T()) - suite.scheduler.EXPECT().Add(mock.Anything).Return(nil).Maybe() - suite.targetMgr = meta.NewTargetManager(suite.broker, suite.meta) +func TestBalanceChecker_ID(t *testing.T) { + checker := createTestBalanceChecker() - suite.balancer = balance.NewMockBalancer(suite.T()) - suite.checker = NewBalanceChecker(suite.meta, suite.targetMgr, suite.nodeMgr, suite.scheduler, func() balance.Balance { return suite.balancer }) + id := checker.ID() + assert.Equal(t, utils.BalanceChecker, id) } -func (suite *BalanceCheckerTestSuite) TearDownTest() { - suite.kv.Close() +func TestBalanceChecker_Description(t *testing.T) { + checker := createTestBalanceChecker() + + desc := checker.Description() + assert.Equal(t, "BalanceChecker checks the cluster distribution and generates balance tasks", desc) } -func (suite 
*BalanceCheckerTestSuite) TestAutoBalanceConf() { +// ============================================================================= +// Configuration Tests +// ============================================================================= + +func TestBalanceChecker_LoadBalanceConfig(t *testing.T) { + checker := createTestBalanceChecker() + + // Mock paramtable.Get function + mockParamGet := mockey.Mock(paramtable.Get).Return(&paramtable.ComponentParam{}).Build() + defer mockParamGet.UnPatch() + + // Mock various param item methods + mockGetAsInt := mockey.Mock((*paramtable.ParamItem).GetAsInt).Return(5).Build() + defer mockGetAsInt.UnPatch() + + mockGetAsBool := mockey.Mock((*paramtable.ParamItem).GetAsBool).Return(true).Build() + defer mockGetAsBool.UnPatch() + + mockGetAsDuration := mockey.Mock((*paramtable.ParamItem).GetAsDuration).Return(5 * time.Second).Build() + defer mockGetAsDuration.UnPatch() + + config := checker.loadBalanceConfig() + + // Verify config structure is returned + assert.IsType(t, balanceConfig{}, config) +} + +// ============================================================================= +// Collection Balance Item Tests +// ============================================================================= + +func TestNewCollectionBalanceItem(t *testing.T) { + collectionID := int64(100) + rowCount := 1000 + sortOrder := "byrowcount" + + item := newCollectionBalanceItem(collectionID, rowCount, sortOrder) + + assert.Equal(t, collectionID, item.collectionID) + assert.Equal(t, rowCount, item.rowCount) + assert.Equal(t, sortOrder, item.sortOrder) +} + +func TestCollectionBalanceItem_GetPriority_ByRowCount(t *testing.T) { + tests := []struct { + name string + rowCount int + sortOrder string + expected int + }{ + {"ByRowCount", 1000, "byrowcount", -1000}, + {"Default", 500, "", -500}, // default to byrowcount + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + item := newCollectionBalanceItem(1, tt.rowCount, tt.sortOrder) + +
priority := item.getPriority() + assert.Equal(t, tt.expected, priority) + }) + } +} + +func TestCollectionBalanceItem_GetPriority_ByCollectionID(t *testing.T) { + collectionID := int64(123) + item := newCollectionBalanceItem(collectionID, 1000, "bycollectionid") + + priority := item.getPriority() + assert.Equal(t, int(collectionID), priority) +} + +func TestCollectionBalanceItem_SetPriority(t *testing.T) { + item := newCollectionBalanceItem(1, 100, "byrowcount") + + item.setPriority(200) + + assert.Equal(t, 200, item.getPriority()) +} + +// ============================================================================= +// Collection Filtering Tests +// ============================================================================= + +func TestBalanceChecker_ReadyToCheck_Success(t *testing.T) { + checker := createTestBalanceChecker() ctx := context.Background() - // set up nodes info - nodeID1, nodeID2 := 1, 2 - suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ - NodeID: int64(nodeID1), - Address: "localhost", - Hostname: "localhost", - })) - suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ - NodeID: int64(nodeID2), - Address: "localhost", - Hostname: "localhost", - })) - suite.checker.meta.ResourceManager.HandleNodeUp(ctx, int64(nodeID1)) - suite.checker.meta.ResourceManager.HandleNodeUp(ctx, int64(nodeID2)) + collectionID := int64(1) - // set collections meta - segments := []*datapb.SegmentInfo{ - { - ID: 1, - PartitionID: 1, - InsertChannel: "test-insert-channel", - }, - } - channels := []*datapb.VchannelInfo{ - { - CollectionID: 1, - ChannelName: "test-insert-channel", - }, - } - suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, mock.Anything).Return(channels, segments, nil) + // Mock meta.GetCollection to return a valid collection + mockGetCollection := mockey.Mock(mockey.GetMethod(checker.meta.CollectionManager, "GetCollection")).Return(&meta.Collection{}).Build() + defer mockGetCollection.UnPatch() - // set collections meta - 
cid1, replicaID1, partitionID1 := 1, 1, 1 - collection1 := utils.CreateTestCollection(int64(cid1), int32(replicaID1)) - collection1.Status = querypb.LoadStatus_Loaded - replica1 := utils.CreateTestReplica(int64(replicaID1), int64(cid1), []int64{int64(nodeID1), int64(nodeID2)}) - partition1 := utils.CreateTestPartition(int64(cid1), int64(partitionID1)) - suite.checker.meta.CollectionManager.PutCollection(ctx, collection1, partition1) - suite.checker.meta.ReplicaManager.Put(ctx, replica1) - suite.targetMgr.UpdateCollectionNextTarget(ctx, int64(cid1)) - suite.targetMgr.UpdateCollectionCurrentTarget(ctx, int64(cid1)) + // Mock target manager methods + mockIsNextTargetExist := mockey.Mock(mockey.GetMethod(checker.targetMgr, "IsNextTargetExist")).Return(true).Build() + defer mockIsNextTargetExist.UnPatch() - cid2, replicaID2, partitionID2 := 2, 2, 2 - collection2 := utils.CreateTestCollection(int64(cid2), int32(replicaID2)) - collection2.Status = querypb.LoadStatus_Loaded - replica2 := utils.CreateTestReplica(int64(replicaID2), int64(cid2), []int64{int64(nodeID1), int64(nodeID2)}) - partition2 := utils.CreateTestPartition(int64(cid2), int64(partitionID2)) - suite.checker.meta.CollectionManager.PutCollection(ctx, collection2, partition2) - suite.checker.meta.ReplicaManager.Put(ctx, replica2) - suite.targetMgr.UpdateCollectionNextTarget(ctx, int64(cid2)) - suite.targetMgr.UpdateCollectionCurrentTarget(ctx, int64(cid2)) - - // test disable auto balance - paramtable.Get().Save(Params.QueryCoordCfg.AutoBalance.Key, "false") - suite.scheduler.EXPECT().GetSegmentTaskNum().Maybe().Return(func() int { - return 0 - }) - replicasToBalance := suite.checker.getReplicaForNormalBalance(ctx) - suite.Empty(replicasToBalance) - segPlans, _ := suite.checker.balanceReplicas(ctx, replicasToBalance) - suite.Empty(segPlans) - - // test enable auto balance - paramtable.Get().Save(Params.QueryCoordCfg.AutoBalance.Key, "true") - idsToBalance := []int64{int64(replicaID1)} - replicasToBalance = 
suite.checker.getReplicaForNormalBalance(ctx) - suite.ElementsMatch(idsToBalance, replicasToBalance) - // next round - idsToBalance = []int64{int64(replicaID2)} - replicasToBalance = suite.checker.getReplicaForNormalBalance(ctx) - suite.ElementsMatch(idsToBalance, replicasToBalance) - // final round - replicasToBalance = suite.checker.getReplicaForNormalBalance(ctx) - suite.Empty(replicasToBalance) + result := checker.readyToCheck(ctx, collectionID) + assert.True(t, result) } -func (suite *BalanceCheckerTestSuite) TestBusyScheduler() { +func TestBalanceChecker_ReadyToCheck_NoMeta(t *testing.T) { + checker := createTestBalanceChecker() ctx := context.Background() - // set up nodes info - nodeID1, nodeID2 := 1, 2 - suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ - NodeID: int64(nodeID1), - Address: "localhost", - Hostname: "localhost", - })) - suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ - NodeID: int64(nodeID2), - Address: "localhost", - Hostname: "localhost", - })) - suite.checker.meta.ResourceManager.HandleNodeUp(ctx, int64(nodeID1)) - suite.checker.meta.ResourceManager.HandleNodeUp(ctx, int64(nodeID2)) + collectionID := int64(1) - segments := []*datapb.SegmentInfo{ - { - ID: 1, - PartitionID: 1, - InsertChannel: "test-insert-channel", - }, - } - channels := []*datapb.VchannelInfo{ - { - CollectionID: 1, - ChannelName: "test-insert-channel", - }, - } - suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, mock.Anything).Return(channels, segments, nil) + // Mock meta.GetCollection to return nil + mockGetCollection := mockey.Mock((*meta.Meta).GetCollection).Return(nil).Build() + defer mockGetCollection.UnPatch() - // set collections meta - cid1, replicaID1, partitionID1 := 1, 1, 1 - collection1 := utils.CreateTestCollection(int64(cid1), int32(replicaID1)) - collection1.Status = querypb.LoadStatus_Loaded - replica1 := utils.CreateTestReplica(int64(replicaID1), int64(cid1), []int64{int64(nodeID1), int64(nodeID2)}) - partition1 := 
utils.CreateTestPartition(int64(cid1), int64(partitionID1)) - suite.checker.meta.CollectionManager.PutCollection(ctx, collection1, partition1) - suite.checker.meta.ReplicaManager.Put(ctx, replica1) - suite.targetMgr.UpdateCollectionNextTarget(ctx, int64(cid1)) - suite.targetMgr.UpdateCollectionCurrentTarget(ctx, int64(cid1)) + // Mock target manager methods to return false + mockIsNextTargetExist := mockey.Mock(mockey.GetMethod(checker.targetMgr, "IsNextTargetExist")).Return(false).Build() + defer mockIsNextTargetExist.UnPatch() - cid2, replicaID2, partitionID2 := 2, 2, 2 - collection2 := utils.CreateTestCollection(int64(cid2), int32(replicaID2)) - collection2.Status = querypb.LoadStatus_Loaded - replica2 := utils.CreateTestReplica(int64(replicaID2), int64(cid2), []int64{int64(nodeID1), int64(nodeID2)}) - partition2 := utils.CreateTestPartition(int64(cid2), int64(partitionID2)) - suite.checker.meta.CollectionManager.PutCollection(ctx, collection2, partition2) - suite.checker.meta.ReplicaManager.Put(ctx, replica2) - suite.targetMgr.UpdateCollectionNextTarget(ctx, int64(cid2)) - suite.targetMgr.UpdateCollectionCurrentTarget(ctx, int64(cid2)) + mockIsCurrentTargetExist := mockey.Mock(mockey.GetMethod(checker.targetMgr, "IsCurrentTargetExist")).Return(false).Build() + defer mockIsCurrentTargetExist.UnPatch() - // test scheduler busy - paramtable.Get().Save(Params.QueryCoordCfg.AutoBalance.Key, "true") - suite.scheduler.EXPECT().GetSegmentTaskNum().Maybe().Return(func() int { - return 1 - }) - replicasToBalance := suite.checker.getReplicaForNormalBalance(ctx) - suite.Len(replicasToBalance, 1) + result := checker.readyToCheck(ctx, collectionID) + assert.False(t, result) } -func (suite *BalanceCheckerTestSuite) TestStoppingBalance() { +func TestBalanceChecker_ReadyToCheck_NoTarget(t *testing.T) { + checker := createTestBalanceChecker() ctx := context.Background() - // set up nodes info, stopping node1 - nodeID1, nodeID2 := 1, 2 - 
suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ - NodeID: int64(nodeID1), - Address: "localhost", - Hostname: "localhost", - })) - suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ - NodeID: int64(nodeID2), - Address: "localhost", - Hostname: "localhost", - })) - suite.nodeMgr.Stopping(int64(nodeID1)) - suite.checker.meta.ResourceManager.HandleNodeUp(ctx, int64(nodeID1)) - suite.checker.meta.ResourceManager.HandleNodeUp(ctx, int64(nodeID2)) + collectionID := int64(1) - segments := []*datapb.SegmentInfo{ - { - ID: 1, - PartitionID: 1, - InsertChannel: "test-insert-channel", - }, + // Mock meta.GetCollection to return a valid collection + mockGetCollection := mockey.Mock((*meta.Meta).GetCollection).Return(&meta.Collection{}).Build() + defer mockGetCollection.UnPatch() + + // Mock target manager methods to return false + mockIsNextTargetExist := mockey.Mock(mockey.GetMethod(checker.targetMgr, "IsNextTargetExist")).Return(false).Build() + defer mockIsNextTargetExist.UnPatch() + + mockIsCurrentTargetExist := mockey.Mock(mockey.GetMethod(checker.targetMgr, "IsCurrentTargetExist")).Return(false).Build() + defer mockIsCurrentTargetExist.UnPatch() + + result := checker.readyToCheck(ctx, collectionID) + assert.False(t, result) +} + +func TestBalanceChecker_FilterCollectionForBalance_Success(t *testing.T) { + checker := createTestBalanceChecker() + ctx := context.Background() + + // Mock meta.GetAll to return collection IDs + collectionIDs := []int64{1, 2, 3} + mockGetAll := mockey.Mock((*meta.CollectionManager).GetAll).Return(collectionIDs).Build() + defer mockGetAll.UnPatch() + + // Create filters that pass all collections + passAllFilter := func(ctx context.Context, collectionID int64) bool { + return true } - channels := []*datapb.VchannelInfo{ - { - CollectionID: 1, - ChannelName: "test-insert-channel", - }, + + result := checker.filterCollectionForBalance(ctx, passAllFilter) + assert.Equal(t, collectionIDs, result) +} + +func 
TestBalanceChecker_FilterCollectionForBalance_WithFiltering(t *testing.T) { + checker := createTestBalanceChecker() + ctx := context.Background() + + // Mock meta.GetAll to return collection IDs + collectionIDs := []int64{1, 2, 3, 4} + mockGetAll := mockey.Mock((*meta.CollectionManager).GetAll).Return(collectionIDs).Build() + defer mockGetAll.UnPatch() + + // Create filters: only even numbers pass first filter, only > 2 pass second filter + evenFilter := func(ctx context.Context, collectionID int64) bool { + return collectionID%2 == 0 } - suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, mock.Anything).Return(channels, segments, nil) - // set collections meta - cid1, replicaID1, partitionID1 := 1, 1, 1 - collection1 := utils.CreateTestCollection(int64(cid1), int32(replicaID1)) - collection1.Status = querypb.LoadStatus_Loaded - replica1 := utils.CreateTestReplica(int64(replicaID1), int64(cid1), []int64{int64(nodeID1), int64(nodeID2)}) - partition1 := utils.CreateTestPartition(int64(cid1), int64(partitionID1)) - suite.checker.meta.CollectionManager.PutCollection(ctx, collection1, partition1) - suite.checker.meta.ReplicaManager.Put(ctx, replica1) - suite.targetMgr.UpdateCollectionNextTarget(ctx, int64(cid1)) - suite.targetMgr.UpdateCollectionCurrentTarget(ctx, int64(cid1)) + greaterThanTwoFilter := func(ctx context.Context, collectionID int64) bool { + return collectionID > 2 + } - cid2, replicaID2, partitionID2 := 2, 2, 2 - collection2 := utils.CreateTestCollection(int64(cid2), int32(replicaID2)) - collection2.Status = querypb.LoadStatus_Loaded - replica2 := utils.CreateTestReplica(int64(replicaID2), int64(cid2), []int64{int64(nodeID1), int64(nodeID2)}) - partition2 := utils.CreateTestPartition(int64(cid2), int64(partitionID2)) - suite.checker.meta.CollectionManager.PutCollection(ctx, collection2, partition2) - suite.checker.meta.ReplicaManager.Put(ctx, replica2) - suite.targetMgr.UpdateCollectionNextTarget(ctx, int64(cid2)) - 
suite.targetMgr.UpdateCollectionCurrentTarget(ctx, int64(cid2)) + result := checker.filterCollectionForBalance(ctx, evenFilter, greaterThanTwoFilter) + // Only collection 4 should pass both filters (even AND > 2) + assert.Equal(t, []int64{4}, result) +} - mr1 := replica1.CopyForWrite() - mr1.AddRONode(1) - suite.checker.meta.ReplicaManager.Put(ctx, mr1.IntoReplica()) +func TestBalanceChecker_FilterCollectionForBalance_EmptyResult(t *testing.T) { + checker := createTestBalanceChecker() + ctx := context.Background() - mr2 := replica2.CopyForWrite() - mr2.AddRONode(1) - suite.checker.meta.ReplicaManager.Put(ctx, mr2.IntoReplica()) + // Mock meta.GetAll to return collection IDs + collectionIDs := []int64{1, 2, 3} + mockGetAll := mockey.Mock((*meta.CollectionManager).GetAll).Return(collectionIDs).Build() + defer mockGetAll.UnPatch() - // test stopping balance - // First round: check replica1 - idsToBalance := []int64{int64(replicaID1)} - replicasToBalance := suite.checker.getReplicaForStoppingBalance(ctx) - suite.ElementsMatch(idsToBalance, replicasToBalance) - suite.True(suite.checker.stoppingBalanceCollectionsCurrentRound.Contain(int64(cid1))) - suite.False(suite.checker.stoppingBalanceCollectionsCurrentRound.Contain(int64(cid2))) + // Create filter that rejects all + rejectAllFilter := func(ctx context.Context, collectionID int64) bool { + return false + } - // Second round: should skip replica1, check replica2 - idsToBalance = []int64{int64(replicaID2)} - replicasToBalance = suite.checker.getReplicaForStoppingBalance(ctx) - suite.ElementsMatch(idsToBalance, replicasToBalance) - suite.True(suite.checker.stoppingBalanceCollectionsCurrentRound.Contain(int64(cid1))) - suite.True(suite.checker.stoppingBalanceCollectionsCurrentRound.Contain(int64(cid2))) + result := checker.filterCollectionForBalance(ctx, rejectAllFilter) + assert.Empty(t, result) +} - // Third round: all collections checked, should return nil and clear the set - replicasToBalance = 
suite.checker.getReplicaForStoppingBalance(ctx) - suite.Empty(replicasToBalance) - suite.False(suite.checker.stoppingBalanceCollectionsCurrentRound.Contain(int64(cid1))) - suite.False(suite.checker.stoppingBalanceCollectionsCurrentRound.Contain(int64(cid2))) +// ============================================================================= +// Queue Construction Tests +// ============================================================================= - // reset meta for Check test - suite.checker.stoppingBalanceCollectionsCurrentRound.Clear() - mr1 = replica1.CopyForWrite() - mr1.AddRONode(1) - suite.checker.meta.ReplicaManager.Put(ctx, mr1.IntoReplica()) +func TestBalanceChecker_ConstructStoppingBalanceQueue(t *testing.T) { + checker := createTestBalanceChecker() + ctx := context.Background() - // checker check - segPlans, chanPlans := make([]balance.SegmentAssignPlan, 0), make([]balance.ChannelAssignPlan, 0) - mockPlan := balance.SegmentAssignPlan{ - Segment: utils.CreateTestSegment(1, 1, 1, 1, 1, "1"), - Replica: meta.NilReplica, + // Mock filterCollectionForBalance result + collectionIDs := []int64{1, 2} + mockFilterCollections := mockey.Mock((*BalanceChecker).filterCollectionForBalance).Return(collectionIDs).Build() + defer mockFilterCollections.UnPatch() + + // Mock target manager GetCollectionRowCount + mockGetRowCount := mockey.Mock(mockey.GetMethod(checker.targetMgr, "GetCollectionRowCount")).Return(int64(100)).Build() + defer mockGetRowCount.UnPatch() + + // Mock paramtable for sort order + mockParamValue := mockey.Mock((*paramtable.ParamItem).GetValue).Return("byrowcount").Build() + defer mockParamValue.UnPatch() + + result := checker.constructStoppingBalanceQueue(ctx) + assert.Equal(t, result.Len(), 2) +} + +func TestBalanceChecker_ConstructNormalBalanceQueue(t *testing.T) { + checker := createTestBalanceChecker() + ctx := context.Background() + + // Mock filterCollectionForBalance result + collectionIDs := []int64{1, 2} + mockFilterCollections := 
mockey.Mock((*BalanceChecker).filterCollectionForBalance).Return(collectionIDs).Build() + defer mockFilterCollections.UnPatch() + + // Mock target manager GetCollectionRowCount + mockGetRowCount := mockey.Mock(mockey.GetMethod(checker.targetMgr, "GetCollectionRowCount")).Return(int64(100)).Build() + defer mockGetRowCount.UnPatch() + + // Mock paramtable for sort order + mockParamValue := mockey.Mock((*paramtable.ParamItem).GetValue).Return("byrowcount").Build() + defer mockParamValue.UnPatch() + + result := checker.constructNormalBalanceQueue(ctx) + assert.Equal(t, result.Len(), 2) +} + +// ============================================================================= +// Replica Getting Tests +// ============================================================================= + +func TestBalanceChecker_GetReplicaForStoppingBalance_WithRONodes(t *testing.T) { + checker := createTestBalanceChecker() + ctx := context.Background() + collectionID := int64(1) + + // Create mock replicas + replica1 := &meta.Replica{} + replica2 := &meta.Replica{} + replicas := []*meta.Replica{replica1, replica2} + + // Mock ReplicaManager.GetByCollection + mockGetByCollection := mockey.Mock(mockey.GetMethod(checker.meta.ReplicaManager, "GetByCollection")).Return(replicas).Build() + defer mockGetByCollection.UnPatch() + + // Mock replica methods - replica1 has RO nodes, replica2 doesn't + mockRONodesCount1 := mockey.Mock((*meta.Replica).RONodesCount).Return(1).Build() + defer mockRONodesCount1.UnPatch() + + mockROSQNodesCount1 := mockey.Mock((*meta.Replica).ROSQNodesCount).Return(0).Build() + defer mockROSQNodesCount1.UnPatch() + + mockGetID1 := mockey.Mock((*meta.Replica).GetID).Return(int64(101)).Build() + defer mockGetID1.UnPatch() + + // Skip streaming service mock for simplicity + + result := checker.getReplicaForStoppingBalance(ctx, collectionID) + // Should return replica1 ID since it has RO nodes + assert.Contains(t, result, int64(101)) +} + +func 
TestBalanceChecker_GetReplicaForStoppingBalance_NoRONodes(t *testing.T) { + checker := createTestBalanceChecker() + ctx := context.Background() + collectionID := int64(1) + + // Create mock replicas + replica1 := &meta.Replica{} + replicas := []*meta.Replica{replica1} + + // Mock ReplicaManager.GetByCollection + mockGetByCollection := mockey.Mock(mockey.GetMethod(checker.meta.ReplicaManager, "GetByCollection")).Return(replicas).Build() + defer mockGetByCollection.UnPatch() + + // Mock replica methods - no RO nodes + mockRONodesCount := mockey.Mock((*meta.Replica).RONodesCount).Return(0).Build() + defer mockRONodesCount.UnPatch() + + mockROSQNodesCount := mockey.Mock((*meta.Replica).ROSQNodesCount).Return(0).Build() + defer mockROSQNodesCount.UnPatch() + + // Skip streaming service mock for simplicity + + result := checker.getReplicaForStoppingBalance(ctx, collectionID) + assert.Empty(t, result) +} + +func TestBalanceChecker_GetReplicaForNormalBalance(t *testing.T) { + checker := createTestBalanceChecker() + ctx := context.Background() + collectionID := int64(1) + + // Create mock replicas + replica1 := &meta.Replica{} + replica2 := &meta.Replica{} + replicas := []*meta.Replica{replica1, replica2} + + // Mock ReplicaManager.GetByCollection + mockGetByCollection := mockey.Mock(mockey.GetMethod(checker.meta.ReplicaManager, "GetByCollection")).Return(replicas).Build() + defer mockGetByCollection.UnPatch() + + // Mock replica GetID methods + mockGetID := mockey.Mock((*meta.Replica).GetID).Return(mockey.Sequence(101).Times(1).Then(102)).Build() + defer mockGetID.UnPatch() + + result := checker.getReplicaForNormalBalance(ctx, collectionID) + expectedIDs := []int64{101, 102} + assert.ElementsMatch(t, expectedIDs, result) +} + +// ============================================================================= +// Task Generation Tests +// ============================================================================= + +func 
TestBalanceChecker_GenerateBalanceTasksFromReplicas_EmptyReplicas(t *testing.T) { + checker := createTestBalanceChecker() + ctx := context.Background() + config := balanceConfig{} + + segmentTasks, channelTasks := checker.generateBalanceTasksFromReplicas(ctx, []int64{}, config) + + assert.Empty(t, segmentTasks) + assert.Empty(t, channelTasks) +} + +func TestBalanceChecker_GenerateBalanceTasksFromReplicas_Success(t *testing.T) { + checker := createTestBalanceChecker() + ctx := context.Background() + config := balanceConfig{ + segmentTaskTimeout: 30 * time.Second, + channelTaskTimeout: 30 * time.Second, + } + replicaIDs := []int64{101} + + // Create mock replica + mockReplica := &meta.Replica{} + + // Mock ReplicaManager.Get + mockReplicaGet := mockey.Mock(mockey.GetMethod(checker.meta.ReplicaManager, "Get")).Return(mockReplica).Build() + defer mockReplicaGet.UnPatch() + + // Create mock balance plans + segmentPlan := balance.SegmentAssignPlan{ + Segment: &meta.Segment{SegmentInfo: &datapb.SegmentInfo{ID: 1}}, + Replica: mockReplica, From: 1, To: 2, } - segPlans = append(segPlans, mockPlan) - suite.balancer.EXPECT().BalanceReplica(mock.Anything, mock.Anything).Return(segPlans, chanPlans) - - tasks := make([]task.Task, 0) - suite.scheduler.ExpectedCalls = nil - suite.scheduler.EXPECT().Add(mock.Anything).RunAndReturn(func(task task.Task) error { - tasks = append(tasks, task) - return nil - }) - suite.checker.Check(context.TODO()) - suite.Len(tasks, 2) -} - -func (suite *BalanceCheckerTestSuite) TestTargetNotReady() { - ctx := context.Background() - // set up nodes info, stopping node1 - nodeID1, nodeID2 := int64(1), int64(2) - suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ - NodeID: nodeID1, - Address: "localhost", - Hostname: "localhost", - })) - suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ - NodeID: nodeID2, - Address: "localhost", - Hostname: "localhost", - })) - suite.nodeMgr.Stopping(nodeID1) - 
suite.checker.meta.ResourceManager.HandleNodeUp(ctx, nodeID1) - suite.checker.meta.ResourceManager.HandleNodeUp(ctx, nodeID2) - - mockTarget := meta.NewMockTargetManager(suite.T()) - suite.checker.targetMgr = mockTarget - - // set collections meta - cid1, replicaID1, partitionID1 := 1, 1, 1 - collection1 := utils.CreateTestCollection(int64(cid1), int32(replicaID1)) - collection1.Status = querypb.LoadStatus_Loaded - replica1 := utils.CreateTestReplica(int64(replicaID1), int64(cid1), []int64{nodeID1, nodeID2}) - partition1 := utils.CreateTestPartition(int64(cid1), int64(partitionID1)) - suite.checker.meta.CollectionManager.PutCollection(ctx, collection1, partition1) - suite.checker.meta.ReplicaManager.Put(ctx, replica1) - - cid2, replicaID2, partitionID2 := 2, 2, 2 - collection2 := utils.CreateTestCollection(int64(cid2), int32(replicaID2)) - collection2.Status = querypb.LoadStatus_Loaded - replica2 := utils.CreateTestReplica(int64(replicaID2), int64(cid2), []int64{nodeID1, nodeID2}) - partition2 := utils.CreateTestPartition(int64(cid2), int64(partitionID2)) - suite.checker.meta.CollectionManager.PutCollection(ctx, collection2, partition2) - suite.checker.meta.ReplicaManager.Put(ctx, replica2) - - // test normal balance when one collection has unready target - mockTarget.EXPECT().IsNextTargetExist(mock.Anything, mock.Anything).Return(true) - mockTarget.EXPECT().IsCurrentTargetReady(mock.Anything, mock.Anything).Return(false) - mockTarget.EXPECT().GetCollectionRowCount(mock.Anything, mock.Anything, mock.Anything).Return(100).Maybe() - replicasToBalance := suite.checker.getReplicaForNormalBalance(ctx) - suite.Len(replicasToBalance, 0) - - // test stopping balance with target not ready - mockTarget.ExpectedCalls = nil - mockTarget.EXPECT().IsNextTargetExist(mock.Anything, mock.Anything).Return(false) - mockTarget.EXPECT().IsCurrentTargetExist(mock.Anything, int64(cid1), mock.Anything).Return(true).Maybe() - mockTarget.EXPECT().IsCurrentTargetExist(mock.Anything, 
int64(cid2), mock.Anything).Return(false).Maybe() - mockTarget.EXPECT().GetCollectionRowCount(mock.Anything, mock.Anything, mock.Anything).Return(100).Maybe() - mr1 := replica1.CopyForWrite() - mr1.AddRONode(1) - suite.checker.meta.ReplicaManager.Put(ctx, mr1.IntoReplica()) - - mr2 := replica2.CopyForWrite() - mr2.AddRONode(1) - suite.checker.meta.ReplicaManager.Put(ctx, mr2.IntoReplica()) - - idsToBalance := []int64{int64(replicaID1)} - replicasToBalance = suite.checker.getReplicaForStoppingBalance(ctx) - suite.ElementsMatch(idsToBalance, replicasToBalance) -} - -func (suite *BalanceCheckerTestSuite) TestAutoBalanceInterval() { - ctx := context.Background() - // set up nodes info - nodeID1, nodeID2 := 1, 2 - suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ - NodeID: int64(nodeID1), - Address: "localhost", - Hostname: "localhost", - })) - suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ - NodeID: int64(nodeID2), - Address: "localhost", - Hostname: "localhost", - })) - suite.checker.meta.ResourceManager.HandleNodeUp(ctx, int64(nodeID1)) - suite.checker.meta.ResourceManager.HandleNodeUp(ctx, int64(nodeID2)) - - segments := []*datapb.SegmentInfo{ - { - ID: 1, - PartitionID: 1, - InsertChannel: "test-insert-channel", - }, - { - ID: 2, - PartitionID: 1, - InsertChannel: "test-insert-channel", - }, + channelPlan := balance.ChannelAssignPlan{ + Channel: &meta.DmChannel{VchannelInfo: &datapb.VchannelInfo{ChannelName: "test"}}, + Replica: mockReplica, + From: 1, + To: 2, } - channels := []*datapb.VchannelInfo{ - { - CollectionID: 1, - ChannelName: "test-insert-channel", - }, + + mockBalancer := mockey.Mock(checker.getBalancerFunc).To(func() balance.Balance { + return balance.NewScoreBasedBalancer(nil, nil, nil, nil, nil) + }).Build() + defer mockBalancer.UnPatch() + + // Mock balancer.BalanceReplica + mockBalanceReplica := mockey.Mock((*balance.ScoreBasedBalancer).BalanceReplica).Return( + []balance.SegmentAssignPlan{segmentPlan}, + 
[]balance.ChannelAssignPlan{channelPlan}, + ).Build() + defer mockBalanceReplica.UnPatch() + + // Mock balance.CreateSegmentTasksFromPlans + mockSegmentTask := &task.SegmentTask{} + mockCreateSegmentTasks := mockey.Mock(balance.CreateSegmentTasksFromPlans).Return([]task.Task{mockSegmentTask}).Build() + defer mockCreateSegmentTasks.UnPatch() + + // Mock balance.CreateChannelTasksFromPlans + mockChannelTask := &task.ChannelTask{} + mockCreateChannelTasks := mockey.Mock(balance.CreateChannelTasksFromPlans).Return([]task.Task{mockChannelTask}).Build() + defer mockCreateChannelTasks.UnPatch() + + // Mock task.SetPriority and task.SetReason + mockSetPriority := mockey.Mock(task.SetPriority).Return().Build() + defer mockSetPriority.UnPatch() + + mockSetReason := mockey.Mock(task.SetReason).Return().Build() + defer mockSetReason.UnPatch() + + // Mock balance.PrintNewBalancePlans + mockPrintPlans := mockey.Mock(balance.PrintNewBalancePlans).Return().Build() + defer mockPrintPlans.UnPatch() + + segmentTasks, channelTasks := checker.generateBalanceTasksFromReplicas(ctx, replicaIDs, config) + + assert.Len(t, segmentTasks, 1) + assert.Len(t, channelTasks, 1) + assert.Equal(t, mockSegmentTask, segmentTasks[0]) + assert.Equal(t, mockChannelTask, channelTasks[0]) +} + +func TestBalanceChecker_GenerateBalanceTasksFromReplicas_NilReplica(t *testing.T) { + checker := createTestBalanceChecker() + ctx := context.Background() + config := balanceConfig{} + replicaIDs := []int64{101} + + // Mock ReplicaManager.Get to return nil + mockReplicaGet := mockey.Mock(mockey.GetMethod(checker.meta.ReplicaManager, "Get")).Return(nil).Build() + defer mockReplicaGet.UnPatch() + + segmentTasks, channelTasks := checker.generateBalanceTasksFromReplicas(ctx, replicaIDs, config) + + assert.Empty(t, segmentTasks) + assert.Empty(t, channelTasks) +} + +// ============================================================================= +// Task Submission Tests +// 
============================================================================= + +func TestBalanceChecker_SubmitTasks(t *testing.T) { + checker := createTestBalanceChecker() + + // Create mock tasks + segmentTask := &task.SegmentTask{} + channelTask := &task.ChannelTask{} + segmentTasks := []task.Task{segmentTask} + channelTasks := []task.Task{channelTask} + + // Mock scheduler.Add + mockSchedulerAdd := mockey.Mock(mockey.GetMethod(checker.scheduler, "Add")).Return(nil).Build() + defer mockSchedulerAdd.UnPatch() + + checker.submitTasks(segmentTasks, channelTasks) + + // Verify scheduler.Add was called for both tasks + // This is implicit verification through mockey call tracking +} + +func TestBalanceChecker_SubmitTasks_EmptyTasks(t *testing.T) { + checker := createTestBalanceChecker() + + // Mock scheduler.Add - should not be called + mockSchedulerAdd := mockey.Mock(mockey.GetMethod(checker.scheduler, "Add")).Return(nil).Build() + defer mockSchedulerAdd.UnPatch() + + checker.submitTasks([]task.Task{}, []task.Task{}) + + // No assertions needed - just ensuring no panic with empty tasks +} + +// ============================================================================= +// Main Check Method Tests +// ============================================================================= + +func TestBalanceChecker_Check_InactiveChecker(t *testing.T) { + checker := createTestBalanceChecker() + ctx := context.Background() + + // Mock IsActive to return false + mockIsActive := mockey.Mock((*checkerActivation).IsActive).Return(false).Build() + defer mockIsActive.UnPatch() + + result := checker.Check(ctx) + assert.Nil(t, result) +} + +func TestBalanceChecker_Check_StoppingBalanceEnabled(t *testing.T) { + checker := createTestBalanceChecker() + ctx := context.Background() + + // Mock IsActive to return true + mockIsActive := mockey.Mock((*checkerActivation).IsActive).Return(true).Build() + defer mockIsActive.UnPatch() + + // Mock paramtable for enabling stopping balance + 
mockParamGet := mockey.Mock(paramtable.Get).Return(&paramtable.ComponentParam{}).Build() + defer mockParamGet.UnPatch() + + mockStoppingBalanceEnabled := mockey.Mock(mockey.GetMethod(&paramtable.ParamItem{}, "GetAsBool")).Return(true).Build() + defer mockStoppingBalanceEnabled.UnPatch() + + // Mock loadBalanceConfig + config := balanceConfig{ + segmentBatchSize: 5, + channelBatchSize: 5, } - suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, mock.Anything).Return(channels, segments, nil) + mockLoadConfig := mockey.Mock((*BalanceChecker).loadBalanceConfig).Return(config).Build() + defer mockLoadConfig.UnPatch() - // set collections meta - cid1, replicaID1, partitionID1 := 1, 1, 1 - collection1 := utils.CreateTestCollection(int64(cid1), int32(replicaID1)) - collection1.Status = querypb.LoadStatus_Loaded - replica1 := utils.CreateTestReplica(int64(replicaID1), int64(cid1), []int64{int64(nodeID1), int64(nodeID2)}) - partition1 := utils.CreateTestPartition(int64(cid1), int64(partitionID1)) - suite.checker.meta.CollectionManager.PutCollection(ctx, collection1, partition1) - suite.checker.meta.ReplicaManager.Put(ctx, replica1) - suite.targetMgr.UpdateCollectionNextTarget(ctx, int64(cid1)) - suite.targetMgr.UpdateCollectionCurrentTarget(ctx, int64(cid1)) + // Mock processBalanceQueue to return tasks + mockProcessQueue := mockey.Mock((*BalanceChecker).processBalanceQueue).Return( + 1, 0, // segment tasks, channel tasks + ).Build() + defer mockProcessQueue.UnPatch() - funcCallCounter := atomic.NewInt64(0) - suite.balancer.EXPECT().BalanceReplica(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, r *meta.Replica) ([]balance.SegmentAssignPlan, []balance.ChannelAssignPlan) { - funcCallCounter.Inc() - return nil, nil - }) - - // first auto balance should be triggered - suite.checker.Check(ctx) - suite.Equal(funcCallCounter.Load(), int64(1)) - - // second auto balance won't be triggered due to autoBalanceInterval == 3s - suite.checker.Check(ctx) - 
suite.Equal(funcCallCounter.Load(), int64(1)) - - // set autoBalanceInterval == 1, sleep 1s, auto balance should be triggered - paramtable.Get().Save(paramtable.Get().QueryCoordCfg.AutoBalanceInterval.Key, "1000") - paramtable.Get().Reset(paramtable.Get().QueryCoordCfg.AutoBalanceInterval.Key) - time.Sleep(1 * time.Second) - suite.checker.Check(ctx) - suite.Equal(funcCallCounter.Load(), int64(1)) + result := checker.Check(ctx) + assert.Nil(t, result) // Always returns nil as tasks are submitted directly + assert.Nil(t, checker.normalBalanceQueue) } -func (suite *BalanceCheckerTestSuite) TestBalanceOrder() { +func TestBalanceChecker_Check_NormalBalanceEnabled(t *testing.T) { + checker := createTestBalanceChecker() ctx := context.Background() - nodeID1, nodeID2 := int64(1), int64(2) - // set collections meta - cid1, replicaID1, partitionID1 := 1, 1, 1 - collection1 := utils.CreateTestCollection(int64(cid1), int32(replicaID1)) - collection1.Status = querypb.LoadStatus_Loaded - replica1 := utils.CreateTestReplica(int64(replicaID1), int64(cid1), []int64{nodeID1, nodeID2}) - partition1 := utils.CreateTestPartition(int64(cid1), int64(partitionID1)) - suite.checker.meta.CollectionManager.PutCollection(ctx, collection1, partition1) - suite.checker.meta.ReplicaManager.Put(ctx, replica1) + // Set autoBalanceTs to allow normal balance + checker.autoBalanceTs = time.Time{} - cid2, replicaID2, partitionID2 := 2, 2, 2 - collection2 := utils.CreateTestCollection(int64(cid2), int32(replicaID2)) - collection2.Status = querypb.LoadStatus_Loaded - replica2 := utils.CreateTestReplica(int64(replicaID2), int64(cid2), []int64{nodeID1, nodeID2}) - partition2 := utils.CreateTestPartition(int64(cid2), int64(partitionID2)) - suite.checker.meta.CollectionManager.PutCollection(ctx, collection2, partition2) - suite.checker.meta.ReplicaManager.Put(ctx, replica2) + // Mock IsActive to return true + mockIsActive := mockey.Mock((*checkerActivation).IsActive).Return(true).Build() + defer 
mockIsActive.UnPatch() - // mock collection row count - mockTargetManager := meta.NewMockTargetManager(suite.T()) - suite.checker.targetMgr = mockTargetManager - mockTargetManager.EXPECT().GetCollectionRowCount(mock.Anything, int64(cid1), mock.Anything).Return(int64(100)).Maybe() - mockTargetManager.EXPECT().GetCollectionRowCount(mock.Anything, int64(cid2), mock.Anything).Return(int64(200)).Maybe() - mockTargetManager.EXPECT().IsCurrentTargetReady(mock.Anything, mock.Anything).Return(true).Maybe() - mockTargetManager.EXPECT().IsNextTargetExist(mock.Anything, mock.Anything).Return(true).Maybe() + // Mock paramtable - stopping balance disabled, auto balance enabled + mockParamGet := mockey.Mock(paramtable.Get).Return(&paramtable.ComponentParam{}).Build() + defer mockParamGet.UnPatch() - // mock stopping node - mr1 := replica1.CopyForWrite() - mr1.AddRONode(nodeID1) - suite.checker.meta.ReplicaManager.Put(ctx, mr1.IntoReplica()) - mr2 := replica2.CopyForWrite() - mr2.AddRONode(nodeID2) - suite.checker.meta.ReplicaManager.Put(ctx, mr2.IntoReplica()) + // return false for stopping balance enabled, true for auto balance enabled + mockParams := mockey.Mock(mockey.GetMethod(&paramtable.ParamItem{}, "GetAsBool")).Return(mockey.Sequence(false).Times(1).Then(true)).Build() + defer mockParams.UnPatch() - // test normal balance order - replicas := suite.checker.getReplicaForNormalBalance(ctx) - suite.Equal(replicas, []int64{int64(replicaID2)}) + // Mock loadBalanceConfig + config := balanceConfig{ + segmentBatchSize: 5, + channelBatchSize: 5, + autoBalanceInterval: 1 * time.Second, + } + mockLoadConfig := mockey.Mock((*BalanceChecker).loadBalanceConfig).Return(config).Build() + defer mockLoadConfig.UnPatch() - // test stopping balance order - replicas = suite.checker.getReplicaForStoppingBalance(ctx) - suite.Equal(replicas, []int64{int64(replicaID2)}) + // Mock processBalanceQueue to return tasks + mockProcessQueue := mockey.Mock((*BalanceChecker).processBalanceQueue).Return( + 0, 1, 
// segment tasks, channel tasks + ).Build() + defer mockProcessQueue.UnPatch() - // mock collection row count - mockTargetManager.ExpectedCalls = nil - mockTargetManager.EXPECT().GetCollectionRowCount(mock.Anything, int64(cid1), mock.Anything).Return(int64(200)).Maybe() - mockTargetManager.EXPECT().GetCollectionRowCount(mock.Anything, int64(cid2), mock.Anything).Return(int64(100)).Maybe() - mockTargetManager.EXPECT().IsCurrentTargetReady(mock.Anything, mock.Anything).Return(true).Maybe() - mockTargetManager.EXPECT().IsNextTargetExist(mock.Anything, mock.Anything).Return(true).Maybe() - - // test normal balance order - replicas = suite.checker.getReplicaForNormalBalance(ctx) - suite.Equal(replicas, []int64{int64(replicaID1)}) - - // test stopping balance order - replicas = suite.checker.getReplicaForStoppingBalance(ctx) - suite.Equal(replicas, []int64{int64(replicaID1)}) + result := checker.Check(ctx) + assert.Nil(t, result) // Always returns nil as tasks are submitted directly } -func (suite *BalanceCheckerTestSuite) TestSortCollections() { +// ============================================================================= +// ProcessBalanceQueue Tests +// ============================================================================= + +func TestBalanceChecker_ProcessBalanceQueue_Success(t *testing.T) { + checker := createTestBalanceChecker() ctx := context.Background() - // Set up test collections - cid1, cid2, cid3 := int64(1), int64(2), int64(3) + // Create mock balance config + config := balanceConfig{ + segmentBatchSize: 5, + channelBatchSize: 3, + maxCheckCollectionCount: 5, + balanceOnMultipleCollections: true, + } - // Mock the target manager for row count returns - mockTargetManager := meta.NewMockTargetManager(suite.T()) - suite.checker.targetMgr = mockTargetManager + // Create mock priority queue + mockQueue := createMockPriorityQueue() + // Use real priority queue for simplicity - // Collection 1: Low ID, High row count - 
mockTargetManager.EXPECT().GetCollectionRowCount(mock.Anything, cid1, mock.Anything).Return(int64(300)).Maybe() + mockQueue.Push(newCollectionBalanceItem(1, 100, "byrowcount")) + mockQueue.Push(newCollectionBalanceItem(2, 100, "byrowcount")) + mockQueue.Push(newCollectionBalanceItem(3, 100, "byrowcount")) + mockQueue.Push(newCollectionBalanceItem(4, 100, "byrowcount")) + mockQueue.Push(newCollectionBalanceItem(5, 100, "byrowcount")) - // Collection 2: Middle ID, Low row count - mockTargetManager.EXPECT().GetCollectionRowCount(mock.Anything, cid2, mock.Anything).Return(int64(100)).Maybe() + // Mock getReplicasFunc + getReplicasFunc := func(ctx context.Context, collectionID int64) []int64 { + return []int64{101, 102} + } - // Collection 3: High ID, Middle row count - mockTargetManager.EXPECT().GetCollectionRowCount(mock.Anything, cid3, mock.Anything).Return(int64(200)).Maybe() + // Mock constructQueueFunc + constructQueueFunc := func(ctx context.Context) *balance.PriorityQueue { + return mockQueue + } - collections := []int64{cid1, cid2, cid3} + // Mock getQueueFunc + getQueueFunc := func() *balance.PriorityQueue { + return mockQueue + } - // Test ByRowCount sorting (default) - paramtable.Get().Save(Params.QueryCoordCfg.BalanceTriggerOrder.Key, "ByRowCount") - sortedCollections := suite.checker.sortCollections(ctx, collections) - suite.Equal([]int64{cid1, cid3, cid2}, sortedCollections, "Collections should be sorted by row count (highest first)") + // Mock generateBalanceTasksFromReplicas + mockSegmentTask := &task.SegmentTask{} + mockChannelTask := &task.ChannelTask{} + mockGenerateTasks := mockey.Mock((*BalanceChecker).generateBalanceTasksFromReplicas).Return( + []task.Task{mockSegmentTask}, []task.Task{mockChannelTask}, + ).Build() + defer mockGenerateTasks.UnPatch() - // Test ByCollectionID sorting - paramtable.Get().Save(Params.QueryCoordCfg.BalanceTriggerOrder.Key, "ByCollectionID") - sortedCollections = suite.checker.sortCollections(ctx, collections) - 
suite.Equal([]int64{cid1, cid2, cid3}, sortedCollections, "Collections should be sorted by collection ID (ascending)") + // mock submit tasks + mockSubmitTasks := mockey.Mock((*BalanceChecker).submitTasks).Return().Build() + defer mockSubmitTasks.UnPatch() - // Test with empty sort order (should default to ByRowCount) - paramtable.Get().Save(Params.QueryCoordCfg.BalanceTriggerOrder.Key, "") - sortedCollections = suite.checker.sortCollections(ctx, collections) - suite.Equal([]int64{cid1, cid3, cid2}, sortedCollections, "Should default to ByRowCount when sort order is empty") + segmentTasks, channelTasks := checker.processBalanceQueue( + ctx, getReplicasFunc, constructQueueFunc, getQueueFunc, config, + ) - // Test with invalid sort order (should default to ByRowCount) - paramtable.Get().Save(Params.QueryCoordCfg.BalanceTriggerOrder.Key, "InvalidOrder") - sortedCollections = suite.checker.sortCollections(ctx, collections) - suite.Equal([]int64{cid1, cid3, cid2}, sortedCollections, "Should default to ByRowCount when sort order is invalid") + assert.Equal(t, 3, segmentTasks) + assert.Equal(t, 3, channelTasks) +} - // Test with mixed case (should be case-insensitive) - paramtable.Get().Save(Params.QueryCoordCfg.BalanceTriggerOrder.Key, "bYcOlLeCtIoNiD") - sortedCollections = suite.checker.sortCollections(ctx, collections) - suite.Equal([]int64{cid1, cid2, cid3}, sortedCollections, "Should handle case-insensitive sort order names") +func TestBalanceChecker_ProcessBalanceQueue_EmptyQueue(t *testing.T) { + checker := createTestBalanceChecker() + ctx := context.Background() + + config := balanceConfig{ + segmentBatchSize: 5, + channelBatchSize: 3, + } + + // Create empty mock priority queue + mockQueue := createMockPriorityQueue() + // Use real priority queue for empty queue testing + + getReplicasFunc := func(ctx context.Context, collectionID int64) []int64 { + return []int64{101} + } + + constructQueueFunc := func(ctx context.Context) *balance.PriorityQueue { + return 
mockQueue + } + + getQueueFunc := func() *balance.PriorityQueue { + return mockQueue + } + + segmentTasks, channelTasks := checker.processBalanceQueue( + ctx, getReplicasFunc, constructQueueFunc, getQueueFunc, config, + ) + + assert.Equal(t, 0, segmentTasks) + assert.Equal(t, 0, channelTasks) +} + +func TestBalanceChecker_ProcessBalanceQueue_BatchSizeLimit(t *testing.T) { + checker := createTestBalanceChecker() + ctx := context.Background() + + // Set small batch sizes to test limits + config := balanceConfig{ + segmentBatchSize: 1, // Only allow 1 segment task + channelBatchSize: 1, // Only allow 1 channel task + maxCheckCollectionCount: 10, + balanceOnMultipleCollections: true, + } + + // Test batch size limits with simplified logic + + // Create mock priority queue with 2 items + mockQueue := createMockPriorityQueue() + // Use real priority queue for batch size testing + + mockQueue.Push(newCollectionBalanceItem(1, 100, "byrowcount")) + mockQueue.Push(newCollectionBalanceItem(2, 100, "byrowcount")) + + getReplicasFunc := func(ctx context.Context, collectionID int64) []int64 { + return []int64{101} + } + + constructQueueFunc := func(ctx context.Context) *balance.PriorityQueue { + return mockQueue + } + + getQueueFunc := func() *balance.PriorityQueue { + return mockQueue + } + + // Mock generateBalanceTasksFromReplicas to return multiple tasks + mockSegmentTask1 := &task.SegmentTask{} + mockSegmentTask2 := &task.SegmentTask{} + mockChannelTask1 := &task.ChannelTask{} + mockChannelTask2 := &task.ChannelTask{} + + mockGenerateTasks := mockey.Mock((*BalanceChecker).generateBalanceTasksFromReplicas).Return(mockey.Sequence( + []task.Task{mockSegmentTask1}, []task.Task{mockChannelTask1}, + ).Times(1).Then( + []task.Task{mockSegmentTask2}, []task.Task{mockChannelTask2}, + )).Build() + defer mockGenerateTasks.UnPatch() + + // mock submit tasks + mockSubmitTasks := mockey.Mock((*BalanceChecker).submitTasks).Return().Build() + defer mockSubmitTasks.UnPatch() + + 
segmentTasks, channelTasks := checker.processBalanceQueue( + ctx, getReplicasFunc, constructQueueFunc, getQueueFunc, config, + ) + + // Should stop after first collection due to batch size limits + assert.Equal(t, 1, segmentTasks) + assert.Equal(t, 1, channelTasks) +} + +func TestBalanceChecker_ProcessBalanceQueue_MultiCollectionDisabled(t *testing.T) { + checker := createTestBalanceChecker() + ctx := context.Background() + + config := balanceConfig{ + segmentBatchSize: 10, + channelBatchSize: 10, + maxCheckCollectionCount: 10, + balanceOnMultipleCollections: false, // Disabled + } + + mockQueue := createMockPriorityQueue() + // Use real priority queue for multi-collection testing + + getReplicasFunc := func(ctx context.Context, collectionID int64) []int64 { + return []int64{101} + } + + constructQueueFunc := func(ctx context.Context) *balance.PriorityQueue { + return mockQueue + } + + getQueueFunc := func() *balance.PriorityQueue { + return mockQueue + } + + mockQueue.Push(newCollectionBalanceItem(1, 100, "byrowcount")) + + // Mock generateBalanceTasksFromReplicas to return tasks + mockSegmentTask := &task.SegmentTask{} + mockGenerateTasks := mockey.Mock((*BalanceChecker).generateBalanceTasksFromReplicas).Return( + []task.Task{mockSegmentTask}, []task.Task{}, + ).Build() + defer mockGenerateTasks.UnPatch() + + // mock submit tasks + mockSubmitTasks := mockey.Mock((*BalanceChecker).submitTasks).Return().Build() + defer mockSubmitTasks.UnPatch() + + segmentTasks, channelTasks := checker.processBalanceQueue( + ctx, getReplicasFunc, constructQueueFunc, getQueueFunc, config, + ) + + // Should stop after first collection due to multi-collection disabled + assert.Equal(t, 1, segmentTasks) + assert.Equal(t, 0, channelTasks) +} + +func TestBalanceChecker_ProcessBalanceQueue_NoReplicasToBalance(t *testing.T) { + checker := createTestBalanceChecker() + ctx := context.Background() + + config := balanceConfig{ + segmentBatchSize: 5, + channelBatchSize: 5, + 
maxCheckCollectionCount: 5, + balanceOnMultipleCollections: true, + } + + mockQueue := createMockPriorityQueue() + // Use real priority queue for simplicity + + // getReplicasFunc returns empty slice + getReplicasFunc := func(ctx context.Context, collectionID int64) []int64 { + return []int64{} // No replicas + } + + constructQueueFunc := func(ctx context.Context) *balance.PriorityQueue { + return mockQueue + } + + getQueueFunc := func() *balance.PriorityQueue { + return mockQueue + } + + segmentTasks, channelTasks := checker.processBalanceQueue( + ctx, getReplicasFunc, constructQueueFunc, getQueueFunc, config, + ) + + assert.Equal(t, 0, segmentTasks) + assert.Equal(t, 0, channelTasks) +} + +// ============================================================================= +// Performance and Edge Case Tests +// ============================================================================= + +func TestBalanceChecker_CollectionBalanceItem_EdgeCases(t *testing.T) { + // Test with zero row count + item := newCollectionBalanceItem(1, 0, "byrowcount") + assert.Equal(t, 0, item.getPriority()) + + // Test with negative collection ID + item = newCollectionBalanceItem(-1, 100, "bycollectionid") + assert.Equal(t, -1, item.getPriority()) + + // Test with very large values + item = newCollectionBalanceItem(9223372036854775807, 2147483647, "byrowcount") + assert.Equal(t, -2147483647, item.getPriority()) + + // Test with empty sort order (should default to byrowcount) + item = newCollectionBalanceItem(5, 100, "") + assert.Equal(t, -100, item.getPriority()) + + // Test with invalid sort order (should default to byrowcount) + item = newCollectionBalanceItem(5, 100, "invalid") + assert.Equal(t, -100, item.getPriority()) +} + +func TestBalanceChecker_FilterCollectionForBalance_EdgeCases(t *testing.T) { + checker := createTestBalanceChecker() + ctx := context.Background() // Test with empty collection list - emptyCollections := []int64{} - sortedCollections = 
suite.checker.sortCollections(ctx, emptyCollections) - suite.Equal([]int64{}, sortedCollections, "Should handle empty collection list") -} + mockGetAll := mockey.Mock((*meta.CollectionManager).GetAll).Return([]int64{}).Build() + defer mockGetAll.UnPatch() -func (suite *BalanceCheckerTestSuite) TestSortCollectionsIntegration() { - ctx := context.Background() - - // Set up test collections and nodes - nodeID1 := int64(1) - suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ - NodeID: nodeID1, - Address: "localhost", - Hostname: "localhost", - })) - suite.checker.meta.ResourceManager.HandleNodeUp(ctx, nodeID1) - - // Create two collections to ensure sorting is triggered - cid1, replicaID1 := int64(1), int64(101) - collection1 := utils.CreateTestCollection(cid1, int32(replicaID1)) - collection1.Status = querypb.LoadStatus_Loaded - replica1 := utils.CreateTestReplica(replicaID1, cid1, []int64{nodeID1}) - suite.checker.meta.CollectionManager.PutCollection(ctx, collection1) - suite.checker.meta.ReplicaManager.Put(ctx, replica1) - - // Add a second collection with different characteristics - cid2, replicaID2 := int64(2), int64(102) - collection2 := utils.CreateTestCollection(cid2, int32(replicaID2)) - collection2.Status = querypb.LoadStatus_Loaded - replica2 := utils.CreateTestReplica(replicaID2, cid2, []int64{nodeID1}) - suite.checker.meta.CollectionManager.PutCollection(ctx, collection2) - suite.checker.meta.ReplicaManager.Put(ctx, replica2) - - // Mock target manager - mockTargetManager := meta.NewMockTargetManager(suite.T()) - suite.checker.targetMgr = mockTargetManager - - // Setup different row counts to test sorting - // Collection 1 has more rows than Collection 2 - var getRowCountCallCount int - mockTargetManager.On("GetCollectionRowCount", mock.Anything, mock.Anything, mock.Anything). 
- Return(func(ctx context.Context, collectionID int64, scope int32) int64 { - getRowCountCallCount++ - if collectionID == cid1 { - return 200 // More rows in collection 1 - } - return 100 // Fewer rows in collection 2 - }) - - mockTargetManager.On("IsCurrentTargetReady", mock.Anything, mock.Anything).Return(true) - mockTargetManager.On("IsNextTargetExist", mock.Anything, mock.Anything).Return(true) - - // Configure for testing - paramtable.Get().Save(Params.QueryCoordCfg.AutoBalance.Key, "true") - paramtable.Get().Save(Params.QueryCoordCfg.BalanceTriggerOrder.Key, "ByRowCount") - - // Clear first to avoid previous test state - suite.checker.normalBalanceCollectionsCurrentRound.Clear() - - // Call normal balance - _ = suite.checker.getReplicaForNormalBalance(ctx) - - // Verify GetCollectionRowCount was called at least twice (once for each collection) - // This confirms that the collections were sorted - suite.True(getRowCountCallCount >= 2, "GetCollectionRowCount should be called at least twice during normal balance") - - // Reset counter and test stopping balance - getRowCountCallCount = 0 - - // Set up for stopping balance test - mr1 := replica1.CopyForWrite() - mr1.AddRONode(nodeID1) - suite.checker.meta.ReplicaManager.Put(ctx, mr1.IntoReplica()) - - mr2 := replica2.CopyForWrite() - mr2.AddRONode(nodeID1) - suite.checker.meta.ReplicaManager.Put(ctx, mr2.IntoReplica()) - - paramtable.Get().Save(Params.QueryCoordCfg.EnableStoppingBalance.Key, "true") - - // Call stopping balance - _ = suite.checker.getReplicaForStoppingBalance(ctx) - - // Verify GetCollectionRowCount was called at least twice during stopping balance - suite.True(getRowCountCallCount >= 2, "GetCollectionRowCount should be called at least twice during stopping balance") -} - -func (suite *BalanceCheckerTestSuite) TestBalanceTriggerOrder() { - ctx := context.Background() - - // Set up nodes - nodeID1, nodeID2 := int64(1), int64(2) - suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ - 
NodeID: nodeID1, - Address: "localhost", - Hostname: "localhost", - })) - suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ - NodeID: nodeID2, - Address: "localhost", - Hostname: "localhost", - })) - suite.checker.meta.ResourceManager.HandleNodeUp(ctx, nodeID1) - suite.checker.meta.ResourceManager.HandleNodeUp(ctx, nodeID2) - - // Create collections with different row counts - cid1, replicaID1 := int64(1), int64(101) - collection1 := utils.CreateTestCollection(cid1, int32(replicaID1)) - collection1.Status = querypb.LoadStatus_Loaded - replica1 := utils.CreateTestReplica(replicaID1, cid1, []int64{nodeID1, nodeID2}) - suite.checker.meta.CollectionManager.PutCollection(ctx, collection1) - suite.checker.meta.ReplicaManager.Put(ctx, replica1) - - cid2, replicaID2 := int64(2), int64(102) - collection2 := utils.CreateTestCollection(cid2, int32(replicaID2)) - collection2.Status = querypb.LoadStatus_Loaded - replica2 := utils.CreateTestReplica(replicaID2, cid2, []int64{nodeID1, nodeID2}) - suite.checker.meta.CollectionManager.PutCollection(ctx, collection2) - suite.checker.meta.ReplicaManager.Put(ctx, replica2) - - cid3, replicaID3 := int64(3), int64(103) - collection3 := utils.CreateTestCollection(cid3, int32(replicaID3)) - collection3.Status = querypb.LoadStatus_Loaded - replica3 := utils.CreateTestReplica(replicaID3, cid3, []int64{nodeID1, nodeID2}) - suite.checker.meta.CollectionManager.PutCollection(ctx, collection3) - suite.checker.meta.ReplicaManager.Put(ctx, replica3) - - // Mock the target manager - mockTargetManager := meta.NewMockTargetManager(suite.T()) - suite.checker.targetMgr = mockTargetManager - - // Set row counts: Collection 1 (highest), Collection 3 (middle), Collection 2 (lowest) - mockTargetManager.EXPECT().GetCollectionRowCount(mock.Anything, cid1, mock.Anything).Return(int64(300)).Maybe() - mockTargetManager.EXPECT().GetCollectionRowCount(mock.Anything, cid2, mock.Anything).Return(int64(100)).Maybe() - 
mockTargetManager.EXPECT().GetCollectionRowCount(mock.Anything, cid3, mock.Anything).Return(int64(200)).Maybe() - - // Mark the current target as ready - mockTargetManager.EXPECT().IsCurrentTargetReady(mock.Anything, mock.Anything).Return(true).Maybe() - mockTargetManager.EXPECT().IsNextTargetExist(mock.Anything, mock.Anything).Return(true).Maybe() - mockTargetManager.EXPECT().IsCurrentTargetExist(mock.Anything, mock.Anything, mock.Anything).Return(true).Maybe() - - // Enable auto balance - paramtable.Get().Save(Params.QueryCoordCfg.AutoBalance.Key, "true") - - // Test with ByRowCount order (default) - paramtable.Get().Save(Params.QueryCoordCfg.BalanceTriggerOrder.Key, "ByRowCount") - suite.checker.normalBalanceCollectionsCurrentRound.Clear() - - // Normal balance should pick the collection with highest row count first - replicas := suite.checker.getReplicaForNormalBalance(ctx) - suite.Contains(replicas, replicaID1, "Should balance collection with highest row count first") - - // Add stopping nodes to test stopping balance - mr1 := replica1.CopyForWrite() - mr1.AddRONode(nodeID1) - suite.checker.meta.ReplicaManager.Put(ctx, mr1.IntoReplica()) - - mr2 := replica2.CopyForWrite() - mr2.AddRONode(nodeID1) - suite.checker.meta.ReplicaManager.Put(ctx, mr2.IntoReplica()) - - mr3 := replica3.CopyForWrite() - mr3.AddRONode(nodeID1) - suite.checker.meta.ReplicaManager.Put(ctx, mr3.IntoReplica()) - - // Enable stopping balance - paramtable.Get().Save(Params.QueryCoordCfg.EnableStoppingBalance.Key, "true") - - // Stopping balance should also pick the collection with highest row count first - replicas = suite.checker.getReplicaForStoppingBalance(ctx) - suite.Contains(replicas, replicaID1, "Stopping balance should prioritize collection with highest row count") - - // Test with ByCollectionID order - paramtable.Get().Save(Params.QueryCoordCfg.BalanceTriggerOrder.Key, "ByCollectionID") - suite.checker.normalBalanceCollectionsCurrentRound.Clear() - - // Normal balance should pick 
the collection with lowest ID first - replicas = suite.checker.getReplicaForNormalBalance(ctx) - suite.Contains(replicas, replicaID1, "Should balance collection with lowest ID first") - - suite.checker.stoppingBalanceCollectionsCurrentRound.Clear() - // Stopping balance should also pick the collection with lowest ID first - replicas = suite.checker.getReplicaForStoppingBalance(ctx) - suite.Contains(replicas, replicaID1, "Stopping balance should prioritize collection with lowest ID") -} - -func (suite *BalanceCheckerTestSuite) TestHasUnbalancedCollectionFlag() { - ctx := context.Background() - - // Set up nodes - nodeID1, nodeID2 := int64(1), int64(2) - suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ - NodeID: nodeID1, - Address: "localhost", - Hostname: "localhost", - })) - suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ - NodeID: nodeID2, - Address: "localhost", - Hostname: "localhost", - })) - suite.checker.meta.ResourceManager.HandleNodeUp(ctx, nodeID1) - suite.checker.meta.ResourceManager.HandleNodeUp(ctx, nodeID2) - - // Create collection - cid1, replicaID1 := int64(1), int64(101) - collection1 := utils.CreateTestCollection(cid1, int32(replicaID1)) - collection1.Status = querypb.LoadStatus_Loaded - replica1 := utils.CreateTestReplica(replicaID1, cid1, []int64{nodeID1, nodeID2}) - suite.checker.meta.CollectionManager.PutCollection(ctx, collection1) - suite.checker.meta.ReplicaManager.Put(ctx, replica1) - - // Mock the target manager - mockTargetManager := meta.NewMockTargetManager(suite.T()) - suite.checker.targetMgr = mockTargetManager - mockTargetManager.EXPECT().GetCollectionRowCount(mock.Anything, mock.Anything, mock.Anything).Return(int64(100)).Maybe() - - // 1. 
Test normal balance with auto balance disabled - paramtable.Get().Save(Params.QueryCoordCfg.AutoBalance.Key, "false") - - // The collections set should be initially empty - suite.checker.normalBalanceCollectionsCurrentRound.Clear() - suite.Equal(0, suite.checker.normalBalanceCollectionsCurrentRound.Len()) - - // Get replicas - should return nil and keep the set empty - replicas := suite.checker.getReplicaForNormalBalance(ctx) - suite.Empty(replicas) - suite.Equal(0, suite.checker.normalBalanceCollectionsCurrentRound.Len(), - "normalBalanceCollectionsCurrentRound should remain empty when auto balance is disabled") - - // 2. Test normal balance when targetMgr.IsCurrentTargetReady returns false - paramtable.Get().Save(Params.QueryCoordCfg.AutoBalance.Key, "true") - mockTargetManager.EXPECT().IsNextTargetExist(mock.Anything, mock.Anything).Return(true).Maybe() - mockTargetManager.EXPECT().IsCurrentTargetReady(mock.Anything, mock.Anything).Return(false).Maybe() - - // The collections set should be initially empty - suite.checker.normalBalanceCollectionsCurrentRound.Clear() - suite.Equal(0, suite.checker.normalBalanceCollectionsCurrentRound.Len()) - - // Get replicas - should return nil and keep the set empty because of not ready targets - replicas = suite.checker.getReplicaForNormalBalance(ctx) - suite.Empty(replicas) - suite.Equal(0, suite.checker.normalBalanceCollectionsCurrentRound.Len(), - "normalBalanceCollectionsCurrentRound should remain empty when targets are not ready") - - // 3. 
Test stopping balance when there are no RO nodes - paramtable.Get().Save(Params.QueryCoordCfg.EnableStoppingBalance.Key, "true") - mockTargetManager.EXPECT().IsNextTargetExist(mock.Anything, mock.Anything).Return(true).Maybe() - mockTargetManager.EXPECT().IsCurrentTargetExist(mock.Anything, mock.Anything, mock.Anything).Return(true).Maybe() - - // The collections set should be initially empty - suite.checker.stoppingBalanceCollectionsCurrentRound.Clear() - suite.Equal(0, suite.checker.stoppingBalanceCollectionsCurrentRound.Len()) - - // Get replicas - should return nil and keep the set empty because there are no RO nodes - replicas = suite.checker.getReplicaForStoppingBalance(ctx) - suite.Empty(replicas) - suite.Equal(0, suite.checker.stoppingBalanceCollectionsCurrentRound.Len(), - "stoppingBalanceCollectionsCurrentRound should remain empty when there are no RO nodes") - - // 4. Test stopping balance with RO nodes - // Add a RO node to the replica - mr1 := replica1.CopyForWrite() - mr1.AddRONode(nodeID1) - suite.checker.meta.ReplicaManager.Put(ctx, mr1.IntoReplica()) - - // The collections set should be initially empty - suite.checker.stoppingBalanceCollectionsCurrentRound.Clear() - suite.Equal(0, suite.checker.stoppingBalanceCollectionsCurrentRound.Len()) - - // Get replicas - should return the replica ID and add the collection to the set - replicas = suite.checker.getReplicaForStoppingBalance(ctx) - suite.Equal([]int64{replicaID1}, replicas) - suite.Equal(1, suite.checker.stoppingBalanceCollectionsCurrentRound.Len()) - suite.True(suite.checker.stoppingBalanceCollectionsCurrentRound.Contain(cid1), - "stoppingBalanceCollectionsCurrentRound should contain the collection when it has RO nodes") -} - -func (suite *BalanceCheckerTestSuite) TestCheckBatchSizesAndMultiCollection() { - ctx := context.Background() - - // Set up nodes - nodeID1, nodeID2 := int64(1), int64(2) - suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ - NodeID: nodeID1, - Address: 
"localhost", - Hostname: "localhost", - })) - suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ - NodeID: nodeID2, - Address: "localhost", - Hostname: "localhost", - })) - suite.checker.meta.ResourceManager.HandleNodeUp(ctx, nodeID1) - suite.checker.meta.ResourceManager.HandleNodeUp(ctx, nodeID2) - - // Create 3 collections - for i := 1; i <= 3; i++ { - cid := int64(i) - replicaID := int64(100 + i) - - collection := utils.CreateTestCollection(cid, int32(replicaID)) - collection.Status = querypb.LoadStatus_Loaded - replica := utils.CreateTestReplica(replicaID, cid, []int64{}) - mutableReplica := replica.CopyForWrite() - mutableReplica.AddRWNode(nodeID1) - mutableReplica.AddRONode(nodeID2) - - suite.checker.meta.CollectionManager.PutCollection(ctx, collection) - suite.checker.meta.ReplicaManager.Put(ctx, mutableReplica.IntoReplica()) + passAllFilter := func(ctx context.Context, collectionID int64) bool { + return true } - // Mock target manager - mockTargetManager := meta.NewMockTargetManager(suite.T()) - suite.checker.targetMgr = mockTargetManager + result := checker.filterCollectionForBalance(ctx, passAllFilter) + assert.Empty(t, result) - // All collections have same row count for simplicity - mockTargetManager.EXPECT().GetCollectionRowCount(mock.Anything, mock.Anything, mock.Anything).Return(int64(100)).Maybe() - mockTargetManager.EXPECT().IsCurrentTargetReady(mock.Anything, mock.Anything).Return(true).Maybe() - mockTargetManager.EXPECT().IsNextTargetExist(mock.Anything, mock.Anything).Return(true).Maybe() - mockTargetManager.EXPECT().IsCurrentTargetExist(mock.Anything, mock.Anything, mock.Anything).Return(true).Maybe() + // Test with no filters + collectionIDs := []int64{1, 2, 3} + mockGetAll.UnPatch() + mockGetAll = mockey.Mock((*meta.CollectionManager).GetAll).Return(collectionIDs).Build() + defer mockGetAll.UnPatch() - // For each collection, return different segment plans - suite.balancer.EXPECT().BalanceReplica(mock.Anything, 
mock.AnythingOfType("*meta.Replica")).RunAndReturn( - func(ctx context.Context, replica *meta.Replica) ([]balance.SegmentAssignPlan, []balance.ChannelAssignPlan) { - // Create 2 segment plans and 1 channel plan per replica - collID := replica.GetCollectionID() - segPlans := make([]balance.SegmentAssignPlan, 0) - chanPlans := make([]balance.ChannelAssignPlan, 0) - - // Create 2 segment plans - for j := 1; j <= 2; j++ { - segID := collID*100 + int64(j) - segPlan := balance.SegmentAssignPlan{ - Segment: utils.CreateTestSegment(segID, collID, 1, 1, 1, "test-channel"), - Replica: replica, - From: nodeID1, - To: nodeID2, - } - segPlans = append(segPlans, segPlan) - } - - // Create 1 channel plan - chanPlan := balance.ChannelAssignPlan{ - Channel: &meta.DmChannel{ - VchannelInfo: &datapb.VchannelInfo{ - CollectionID: collID, - ChannelName: "test-channel", - }, - }, - Replica: replica, - From: nodeID1, - To: nodeID2, - } - chanPlans = append(chanPlans, chanPlan) - - return segPlans, chanPlans - }).Maybe() - - // Add tasks to check batch size limits - var addedTasks []task.Task - suite.scheduler.ExpectedCalls = nil - suite.scheduler.EXPECT().Add(mock.Anything).RunAndReturn(func(t task.Task) error { - addedTasks = append(addedTasks, t) - return nil - }).Maybe() - - // Test 1: Balance with multiple collections disabled - paramtable.Get().Save(Params.QueryCoordCfg.AutoBalance.Key, "true") - paramtable.Get().Save(Params.QueryCoordCfg.EnableBalanceOnMultipleCollections.Key, "false") - // Set batch sizes to large values to test single-collection case - paramtable.Get().Save(Params.QueryCoordCfg.BalanceSegmentBatchSize.Key, "10") - paramtable.Get().Save(Params.QueryCoordCfg.BalanceChannelBatchSize.Key, "10") - - // Reset test state - suite.checker.stoppingBalanceCollectionsCurrentRound.Clear() - suite.checker.autoBalanceTs = time.Time{} // Reset to trigger auto balance - addedTasks = nil - - // Run the Check method - suite.checker.Check(ctx) - - // Should have tasks for a single 
collection (2 segment tasks + 1 channel task) - suite.Equal(3, len(addedTasks), "Should have tasks for a single collection when multiple collections balance is disabled") - - // Test 2: Balance with multiple collections enabled - paramtable.Get().Save(Params.QueryCoordCfg.EnableBalanceOnMultipleCollections.Key, "true") - - // Reset test state - suite.checker.autoBalanceTs = time.Time{} - suite.checker.stoppingBalanceCollectionsCurrentRound.Clear() - addedTasks = nil - - // Run the Check method - suite.checker.Check(ctx) - - // Should have tasks for all collections (3 collections * (2 segment tasks + 1 channel task) = 9 tasks) - suite.Equal(9, len(addedTasks), "Should have tasks for all collections when multiple collections balance is enabled") - - // Test 3: Batch size limits - paramtable.Get().Save(Params.QueryCoordCfg.BalanceSegmentBatchSize.Key, "2") - paramtable.Get().Save(Params.QueryCoordCfg.BalanceChannelBatchSize.Key, "1") - - // Reset test state - suite.checker.stoppingBalanceCollectionsCurrentRound.Clear() - addedTasks = nil - - // Run the Check method - suite.checker.Check(ctx) - - // Should respect batch size limits: 2 segment tasks + 1 channel task = 3 tasks - suite.Equal(3, len(addedTasks), "Should respect batch size limits") - - // Count segment tasks and channel tasks - segmentTaskCount := 0 - channelTaskCount := 0 - for _, t := range addedTasks { - if _, ok := t.(*task.SegmentTask); ok { - segmentTaskCount++ - } else { - channelTaskCount++ - } - } - - suite.LessOrEqual(segmentTaskCount, 2, "Should have at most 2 segment tasks due to batch size limit") - suite.LessOrEqual(channelTaskCount, 1, "Should have at most 1 channel task due to batch size limit") + result = checker.filterCollectionForBalance(ctx) + assert.Equal(t, collectionIDs, result) // No filters means all pass } -func TestBalanceCheckerSuite(t *testing.T) { - suite.Run(t, new(BalanceCheckerTestSuite)) +// ============================================================================= +// 
Streaming Service Tests +// ============================================================================= + +func TestBalanceChecker_GetReplicaForStoppingBalance_WithStreamingService(t *testing.T) { + checker := createTestBalanceChecker() + ctx := context.Background() + collectionID := int64(1) + + // Create mock replicas + replica1 := &meta.Replica{} + replicas := []*meta.Replica{replica1} + + // Mock ReplicaManager.GetByCollection + mockGetByCollection := mockey.Mock(mockey.GetMethod(checker.meta.ReplicaManager, "GetByCollection")).Return(replicas).Build() + defer mockGetByCollection.UnPatch() + + // Mock replica methods - no RO nodes but has streaming channel RO nodes + mockRONodesCount := mockey.Mock((*meta.Replica).RONodesCount).Return(0).Build() + defer mockRONodesCount.UnPatch() + + mockROSQNodesCount := mockey.Mock((*meta.Replica).ROSQNodesCount).Return(0).Build() + defer mockROSQNodesCount.UnPatch() + + mockGetID := mockey.Mock((*meta.Replica).GetID).Return(int64(101)).Build() + defer mockGetID.UnPatch() + + // streaming service mocks for simplicity + mockIsStreamingServiceEnabled := mockey.Mock(streamingutil.IsStreamingServiceEnabled).Return(true).Build() + defer mockIsStreamingServiceEnabled.UnPatch() + mockGetChannelRWAndRONodesFor260 := mockey.Mock(utils.GetChannelRWAndRONodesFor260).Return([]int64{}, []int64{1}).Build() + defer mockGetChannelRWAndRONodesFor260.UnPatch() + + result := checker.getReplicaForStoppingBalance(ctx, collectionID) + // Should return replica1 ID since it has channel RO nodes + assert.Equal(t, []int64{101}, result) +} + +// ============================================================================= +// Error Handling Tests +// ============================================================================= + +func TestBalanceChecker_Check_TimeoutWarning(t *testing.T) { + checker := createTestBalanceChecker() + ctx := context.Background() + + // Mock IsActive to return true + mockIsActive := 
mockey.Mock((*checkerActivation).IsActive).Return(true).Build() + defer mockIsActive.UnPatch() + + mockProcessBalanceQueue := mockey.Mock((*BalanceChecker).processBalanceQueue).To( + func(ctx context.Context, + getReplicasFunc func(ctx context.Context, collectionID int64) []int64, + constructQueueFunc func(ctx context.Context) *balance.PriorityQueue, + getQueueFunc func() *balance.PriorityQueue, config balanceConfig, + ) (int, int) { + time.Sleep(150 * time.Millisecond) + return 0, 0 + }).Build() + defer mockProcessBalanceQueue.UnPatch() + + start := time.Now() + result := checker.Check(ctx) + duration := time.Since(start) + + assert.Nil(t, result) + assert.Greater(t, duration, 100*time.Millisecond) // Should trigger log } diff --git a/pkg/util/paramtable/component_param.go b/pkg/util/paramtable/component_param.go index 2735df9122..a5c3ca3a3e 100644 --- a/pkg/util/paramtable/component_param.go +++ b/pkg/util/paramtable/component_param.go @@ -2347,6 +2347,8 @@ type queryCoordConfig struct { // query node task parallelism factor QueryNodeTaskParallelismFactor ParamItem `refreshable:"true"` + + BalanceCheckCollectionMaxCount ParamItem `refreshable:"true"` } func (p *queryCoordConfig) init(base *BaseTable) { @@ -2979,6 +2981,15 @@ If this parameter is set false, Milvus simply searches the growing segments with Export: false, } p.QueryNodeTaskParallelismFactor.Init(base.mgr) + + p.BalanceCheckCollectionMaxCount = ParamItem{ + Key: "queryCoord.balanceCheckCollectionMaxCount", + Version: "2.6.2", + DefaultValue: "100", + Doc: "the max collection count for each balance check", + Export: false, + } + p.BalanceCheckCollectionMaxCount.Init(base.mgr) } // ///////////////////////////////////////////////////////////////////////////// diff --git a/pkg/util/paramtable/component_param_test.go b/pkg/util/paramtable/component_param_test.go index fef1f7216f..7b8746d85f 100644 --- a/pkg/util/paramtable/component_param_test.go +++ b/pkg/util/paramtable/component_param_test.go @@ -393,6 
+393,8 @@ func TestComponentParam(t *testing.T) { assert.Equal(t, 1, Params.QueryNodeTaskParallelismFactor.GetAsInt()) params.Save("queryCoord.queryNodeTaskParallelismFactor", "2") assert.Equal(t, 2, Params.QueryNodeTaskParallelismFactor.GetAsInt()) + + assert.Equal(t, 100, Params.BalanceCheckCollectionMaxCount.GetAsInt()) }) t.Run("test queryNodeConfig", func(t *testing.T) {