diff --git a/internal/querycoordv2/balance/priority_queue.go b/internal/querycoordv2/balance/priority_queue.go index 4182f0a73a..465181697f 100644 --- a/internal/querycoordv2/balance/priority_queue.go +++ b/internal/querycoordv2/balance/priority_queue.go @@ -20,24 +20,24 @@ import ( "container/heap" ) -type item interface { +type Item interface { getPriority() int setPriority(priority int) } -type baseItem struct { +type BaseItem struct { priority int } -func (b *baseItem) getPriority() int { +func (b *BaseItem) getPriority() int { return b.priority } -func (b *baseItem) setPriority(priority int) { +func (b *BaseItem) setPriority(priority int) { b.priority = priority } -type heapQueue []item +type heapQueue []Item func (hq heapQueue) Len() int { return len(hq) @@ -52,7 +52,7 @@ func (hq heapQueue) Swap(i, j int) { } func (hq *heapQueue) Push(x any) { - i := x.(item) + i := x.(Item) *hq = append(*hq, i) } @@ -64,22 +64,30 @@ func (hq *heapQueue) Pop() any { return ret } -type priorityQueue struct { +type PriorityQueue struct { heapQueue } -func newPriorityQueue() priorityQueue { +func NewPriorityQueue() PriorityQueue { hq := make(heapQueue, 0) heap.Init(&hq) - return priorityQueue{ + return PriorityQueue{ heapQueue: hq, } } -func (pq *priorityQueue) push(item item) { +func NewPriorityQueuePtr() *PriorityQueue { + hq := make(heapQueue, 0) + heap.Init(&hq) + return &PriorityQueue{ + heapQueue: hq, + } +} + +func (pq *PriorityQueue) Push(item Item) { heap.Push(&pq.heapQueue, item) } -func (pq *priorityQueue) pop() item { - return heap.Pop(&pq.heapQueue).(item) +func (pq *PriorityQueue) Pop() Item { + return heap.Pop(&pq.heapQueue).(Item) } diff --git a/internal/querycoordv2/balance/priority_queue_test.go b/internal/querycoordv2/balance/priority_queue_test.go index eb1fc53d3d..5ead1d16dd 100644 --- a/internal/querycoordv2/balance/priority_queue_test.go +++ b/internal/querycoordv2/balance/priority_queue_test.go @@ -23,74 +23,74 @@ import ( ) func TestMinPriorityQueue(t *testing.T) { - pq := newPriorityQueue() + pq := NewPriorityQueue() for i := 0; i < 5; i++ { priority := i % 3 nodeItem := newNodeItem(priority, int64(i)) - pq.push(&nodeItem) + pq.Push(&nodeItem) } - item := pq.pop() + item := pq.Pop() assert.Equal(t, item.getPriority(), 0) assert.Equal(t, item.(*nodeItem).nodeID, int64(0)) - item = pq.pop() + item = pq.Pop() assert.Equal(t, item.getPriority(), 0) assert.Equal(t, item.(*nodeItem).nodeID, int64(3)) - item = pq.pop() + item = pq.Pop() assert.Equal(t, item.getPriority(), 1) assert.Equal(t, item.(*nodeItem).nodeID, int64(1)) - item = pq.pop() + item = pq.Pop() assert.Equal(t, item.getPriority(), 1) assert.Equal(t, item.(*nodeItem).nodeID, int64(4)) - item = pq.pop() + item = pq.Pop() assert.Equal(t, item.getPriority(), 2) println(item.getPriority()) assert.Equal(t, item.(*nodeItem).nodeID, int64(2)) } func TestPopPriorityQueue(t *testing.T) { - pq := newPriorityQueue() + pq := NewPriorityQueue() for i := 0; i < 1; i++ { priority := 1 nodeItem := newNodeItem(priority, int64(i)) - pq.push(&nodeItem) + pq.Push(&nodeItem) } - item := pq.pop() + item := pq.Pop() assert.Equal(t, item.getPriority(), 1) assert.Equal(t, item.(*nodeItem).nodeID, int64(0)) - pq.push(item) + pq.Push(item) // if it's round robin, but not working - item = pq.pop() + item = pq.Pop() assert.Equal(t, item.getPriority(), 1) assert.Equal(t, item.(*nodeItem).nodeID, int64(0)) } func TestMaxPriorityQueue(t *testing.T) { - pq := newPriorityQueue() + pq := NewPriorityQueue() for i := 0; i < 5; i++ { priority := i % 3 nodeItem := newNodeItem(-priority, int64(i)) - pq.push(&nodeItem) + pq.Push(&nodeItem) } - item := pq.pop() + item := pq.Pop() assert.Equal(t, item.getPriority(), -2) assert.Equal(t, item.(*nodeItem).nodeID, int64(2)) - item = pq.pop() + item = pq.Pop() assert.Equal(t, item.getPriority(), -1) assert.Equal(t, item.(*nodeItem).nodeID, int64(4)) - item = pq.pop() + item = pq.Pop() assert.Equal(t, item.getPriority(), -1) assert.Equal(t, item.(*nodeItem).nodeID, int64(1)) - item = pq.pop() + item = pq.Pop() assert.Equal(t, item.getPriority(), 0) assert.Equal(t, item.(*nodeItem).nodeID, int64(3)) - item = pq.pop() + item = pq.Pop() assert.Equal(t, item.getPriority(), 0) assert.Equal(t, item.(*nodeItem).nodeID, int64(0)) } diff --git a/internal/querycoordv2/balance/rowcount_based_balancer.go b/internal/querycoordv2/balance/rowcount_based_balancer.go index 141ca6a1ff..01cd9647fb 100644 --- a/internal/querycoordv2/balance/rowcount_based_balancer.go +++ b/internal/querycoordv2/balance/rowcount_based_balancer.go @@ -57,9 +57,9 @@ func (b *RowCountBasedBalancer) AssignSegment(ctx context.Context, collectionID if len(nodeItems) == 0 { return nil } - queue := newPriorityQueue() + queue := NewPriorityQueue() for _, item := range nodeItems { - queue.push(item) + queue.Push(item) } sort.Slice(segments, func(i, j int) bool { @@ -70,7 +70,7 @@ func (b *RowCountBasedBalancer) AssignSegment(ctx context.Context, collectionID plans := make([]SegmentAssignPlan, 0, len(segments)) for _, s := range segments { // pick the node with the least row count and allocate to it. - ni := queue.pop().(*nodeItem) + ni := queue.Pop().(*nodeItem) plan := SegmentAssignPlan{ From: -1, To: ni.nodeID, @@ -82,7 +82,7 @@ func (b *RowCountBasedBalancer) AssignSegment(ctx context.Context, collectionID } // change node's score and push back ni.AddCurrentScoreDelta(float64(s.GetNumOfRows())) - queue.push(ni) + queue.Push(ni) } return plans } @@ -108,9 +108,9 @@ func (b *RowCountBasedBalancer) AssignChannel(ctx context.Context, collectionID return nil } - queue := newPriorityQueue() + queue := NewPriorityQueue() for _, item := range nodeItems { - queue.push(item) + queue.Push(item) } plans := make([]ChannelAssignPlan, 0) @@ -125,7 +125,7 @@ func (b *RowCountBasedBalancer) AssignChannel(ctx context.Context, collectionID } if ni == nil { // pick the node with the least channel num and allocate to it. - ni = queue.pop().(*nodeItem) + ni = queue.Pop().(*nodeItem) } plan := ChannelAssignPlan{ From: -1, @@ -135,7 +135,7 @@ func (b *RowCountBasedBalancer) AssignChannel(ctx context.Context, collectionID plans = append(plans, plan) // change node's score and push back ni.AddCurrentScoreDelta(1) - queue.push(ni) + queue.Push(ni) } return plans } @@ -408,7 +408,7 @@ func NewRowCountBasedBalancer( } type nodeItem struct { - baseItem + BaseItem fmt.Stringer nodeID int64 assignedScore float64 @@ -417,7 +417,7 @@ type nodeItem struct { func newNodeItem(currentScore int, nodeID int64) nodeItem { return nodeItem{ - baseItem: baseItem{}, + BaseItem: BaseItem{}, nodeID: nodeID, currentScore: float64(currentScore), } diff --git a/internal/querycoordv2/balance/score_based_balancer.go b/internal/querycoordv2/balance/score_based_balancer.go index 7de55f2be0..9dcf10bf70 100644 --- a/internal/querycoordv2/balance/score_based_balancer.go +++ b/internal/querycoordv2/balance/score_based_balancer.go @@ -79,9 +79,9 @@ func (b *ScoreBasedBalancer) assignSegment(br *balanceReport, collectionID int64 return nil } - queue := newPriorityQueue() + queue := NewPriorityQueue() for _, item := range nodeItemsMap { - queue.push(item) + queue.Push(item) } // sort segments by segment row count, if segment has same row count, sort by node's score @@ -100,9 +100,9 @@ func (b *ScoreBasedBalancer) assignSegment(br *balanceReport, collectionID int64 for _, s := range segments { func(s *meta.Segment) { // for each segment, pick the node with the least score - targetNode := queue.pop().(*nodeItem) + targetNode := queue.Pop().(*nodeItem) // make sure candidate is always push back - defer queue.push(targetNode) + defer queue.Push(targetNode) scoreChanges := b.calculateSegmentScore(s) sourceNode := nodeItemsMap[s.Node] @@ -173,9 +173,9 @@ func (b *ScoreBasedBalancer) assignChannel(br *balanceReport, collectionID int64 return nil } - queue := newPriorityQueue() + queue := NewPriorityQueue() for _, item := range nodeItemsMap { - queue.push(item) + queue.Push(item) } plans := make([]ChannelAssignPlan, 0, len(channels)) for _, ch := range channels { @@ -193,10 +193,10 @@ func (b *ScoreBasedBalancer) assignChannel(br *balanceReport, collectionID int64 } // for each channel, pick the node with the least score if targetNode == nil { - targetNode = queue.pop().(*nodeItem) + targetNode = queue.Pop().(*nodeItem) } // make sure candidate is always push back - defer queue.push(targetNode) + defer queue.Push(targetNode) scoreChanges := b.calculateChannelScore(ch, collectionID) sourceNode := nodeItemsMap[ch.Node] diff --git a/internal/querycoordv2/checkers/balance_checker.go b/internal/querycoordv2/checkers/balance_checker.go index f9f656754f..72dbc980d9 100644 --- a/internal/querycoordv2/checkers/balance_checker.go +++ b/internal/querycoordv2/checkers/balance_checker.go @@ -18,7 +18,6 @@ package checkers import ( "context" - "sort" "strings" "time" @@ -36,22 +35,111 @@ import ( "github.com/milvus-io/milvus/pkg/v2/log" "github.com/milvus-io/milvus/pkg/v2/proto/querypb" "github.com/milvus-io/milvus/pkg/v2/util/paramtable" - "github.com/milvus-io/milvus/pkg/v2/util/typeutil" ) +// balanceConfig holds all configuration parameters for balance operations. +// This configuration controls how balance tasks are generated and executed. +type balanceConfig struct { + // segmentBatchSize specifies the maximum number of segment balance tasks to generate in one round + segmentBatchSize int + // channelBatchSize specifies the maximum number of channel balance tasks to generate in one round + channelBatchSize int + // balanceOnMultipleCollections determines whether to balance multiple collections in one round. + // If false, only balance one collection at a time to avoid resource contention + balanceOnMultipleCollections bool + // maxCheckCollectionCount limits the maximum number of collections to check in one round + // to prevent long-running balance operations + maxCheckCollectionCount int + // autoBalanceInterval controls how frequently automatic balance operations are triggered + autoBalanceInterval time.Duration + // segmentTaskTimeout specifies the timeout for segment balance tasks + segmentTaskTimeout time.Duration + // channelTaskTimeout specifies the timeout for channel balance tasks + channelTaskTimeout time.Duration +} + +// This method fetches all balance-related configuration parameters from the global +// parameter table and returns a balanceConfig struct for use in balance operations. +func (b *BalanceChecker) loadBalanceConfig() balanceConfig { + return balanceConfig{ + segmentBatchSize: paramtable.Get().QueryCoordCfg.BalanceSegmentBatchSize.GetAsInt(), + channelBatchSize: paramtable.Get().QueryCoordCfg.BalanceChannelBatchSize.GetAsInt(), + balanceOnMultipleCollections: paramtable.Get().QueryCoordCfg.EnableBalanceOnMultipleCollections.GetAsBool(), + maxCheckCollectionCount: paramtable.Get().QueryCoordCfg.BalanceCheckCollectionMaxCount.GetAsInt(), + autoBalanceInterval: paramtable.Get().QueryCoordCfg.AutoBalanceInterval.GetAsDuration(time.Millisecond), + segmentTaskTimeout: paramtable.Get().QueryCoordCfg.SegmentTaskTimeout.GetAsDuration(time.Millisecond), + channelTaskTimeout: paramtable.Get().QueryCoordCfg.ChannelTaskTimeout.GetAsDuration(time.Millisecond), + } +} + +// collectionBalanceItem represents a collection in the balance priority queue. +// Each item contains collection metadata and is used to determine the order +// in which collections should be processed for balance operations. +type collectionBalanceItem struct { + *balance.BaseItem + balancePriority int + + // collectionID and rowCount are used to calculate the priority + collectionID int64 + rowCount int + sortOrder string +} + +// The priority is determined by the BalanceTriggerOrder configuration: +// - "byrowcount": Higher row count collections get higher priority (processed first) +// - "bycollectionid": Collections with smaller IDs get higher priority +func newCollectionBalanceItem(collectionID int64, rowCount int, sortOrder string) *collectionBalanceItem { + priority := 0 + if sortOrder == "bycollectionid" { + priority = int(collectionID) + } else { + priority = -rowCount + } + + return &collectionBalanceItem{ + BaseItem: &balance.BaseItem{}, + collectionID: collectionID, + rowCount: rowCount, + sortOrder: sortOrder, + balancePriority: priority, + } +} + +func (c *collectionBalanceItem) getPriority() int { + return c.balancePriority +} + +func (c *collectionBalanceItem) setPriority(priority int) { + c.balancePriority = priority +} + // BalanceChecker checks the cluster distribution and generates balance tasks. +// It is responsible for monitoring the load distribution across query nodes and +// generating segment/channel move tasks to maintain optimal balance. +// +// The BalanceChecker operates in two modes: +// 1. Stopping Balance: High-priority balance for nodes that are being stopped or read-only nodes +// 2. Normal Balance: Regular automatic balance operations to optimize cluster performance +// +// Both modes use priority queues to determine the order in which collections are processed. type BalanceChecker struct { *checkerActivation - meta *meta.Meta - nodeManager *session.NodeManager - scheduler task.Scheduler - targetMgr meta.TargetManagerInterface + meta *meta.Meta + nodeManager *session.NodeManager + scheduler task.Scheduler + targetMgr meta.TargetManagerInterface + // getBalancerFunc returns the appropriate balancer for generating balance plans getBalancerFunc GetBalancerFunc - normalBalanceCollectionsCurrentRound typeutil.UniqueSet - stoppingBalanceCollectionsCurrentRound typeutil.UniqueSet + // normalBalanceQueue maintains collections pending normal balance operations, + // ordered by priority (row count or collection ID) + normalBalanceQueue *balance.PriorityQueue + // stoppingBalanceQueue maintains collections pending stopping balance operations, + // used when nodes are being gracefully stopped + stoppingBalanceQueue *balance.PriorityQueue - // record auto balance ts + // autoBalanceTs records the timestamp of the last auto balance operation + // to ensure balance operations don't happen too frequently autoBalanceTs time.Time } @@ -62,14 +150,14 @@ func NewBalanceChecker(meta *meta.Meta, getBalancerFunc GetBalancerFunc, ) *BalanceChecker { return &BalanceChecker{ - checkerActivation: newCheckerActivation(), - meta: meta, - targetMgr: targetMgr, - nodeManager: nodeMgr, - normalBalanceCollectionsCurrentRound: typeutil.NewUniqueSet(), - stoppingBalanceCollectionsCurrentRound: typeutil.NewUniqueSet(), - scheduler: scheduler, - getBalancerFunc: getBalancerFunc, + checkerActivation: newCheckerActivation(), + meta: meta, + targetMgr: targetMgr, + nodeManager: nodeMgr, + normalBalanceQueue: balance.NewPriorityQueuePtr(), + stoppingBalanceQueue: balance.NewPriorityQueuePtr(), + scheduler: scheduler, + getBalancerFunc: getBalancerFunc, } } @@ -81,6 +169,12 @@ func (b *BalanceChecker) Description() string { return "BalanceChecker checks the cluster distribution and generates balance tasks" } +// readyToCheck determines if a collection is ready for balance operations. +// A collection is considered ready if: +// 1. It exists in the metadata +// 2. It has either a current target or next target defined +// +// Returns true if the collection is ready for balance operations. func (b *BalanceChecker) readyToCheck(ctx context.Context, collectionID int64) bool { metaExist := (b.meta.GetCollection(ctx, collectionID) != nil) targetExist := b.targetMgr.IsNextTargetExist(ctx, collectionID) || b.targetMgr.IsCurrentTargetExist(ctx, collectionID, common.AllPartitionsID) @@ -88,117 +182,151 @@ func (b *BalanceChecker) readyToCheck(ctx context.Context, collectionID int64) b return metaExist && targetExist } -func (b *BalanceChecker) getReplicaForStoppingBalance(ctx context.Context) []int64 { - hasUnbalancedCollection := false - defer func() { - if !hasUnbalancedCollection { - b.stoppingBalanceCollectionsCurrentRound.Clear() - log.RatedDebug(10, "BalanceChecker has triggered stopping balance for all "+ - "collections in one round, clear collectionIDs for this round") - } - }() +type ReadyForBalanceFilter func(ctx context.Context, collectionID int64) bool +// filterCollectionForBalance filters all collections using the provided filter functions. +// Only collections that pass ALL filter criteria will be included in the result. +// This is used to select collections eligible for balance operations based on +// various conditions like load status, target readiness, etc. +// Returns a slice of collection IDs that pass all filter criteria. +func (b *BalanceChecker) filterCollectionForBalance(ctx context.Context, filter ...ReadyForBalanceFilter) []int64 { ids := b.meta.GetAll(ctx) - // Sort collections using the configured sort order - ids = b.sortCollections(ctx, ids) - - if paramtable.Get().QueryCoordCfg.EnableStoppingBalance.GetAsBool() { - for _, cid := range ids { - // if target and meta isn't ready, skip balance this collection - if !b.readyToCheck(ctx, cid) { - continue - } - if b.stoppingBalanceCollectionsCurrentRound.Contain(cid) { - continue - } - - replicas := b.meta.ReplicaManager.GetByCollection(ctx, cid) - stoppingReplicas := make([]int64, 0) - for _, replica := range replicas { - // If there are some delegator work on query node, we need to balance channel to streamingnode forcely. - channelRONodes := make([]int64, 0) - if streamingutil.IsStreamingServiceEnabled() { - _, channelRONodes = utils.GetChannelRWAndRONodesFor260(replica, b.nodeManager) - } - if replica.RONodesCount()+replica.ROSQNodesCount() > 0 || len(channelRONodes) > 0 { - stoppingReplicas = append(stoppingReplicas, replica.GetID()) - } - } - if len(stoppingReplicas) > 0 { - hasUnbalancedCollection = true - b.stoppingBalanceCollectionsCurrentRound.Insert(cid) - return stoppingReplicas + ret := make([]int64, 0) + for _, cid := range ids { + shouldInclude := true + for _, f := range filter { + if !f(ctx, cid) { + shouldInclude = false + break } } + if shouldInclude { + ret = append(ret, cid) + } } - - // finish current round for stopping balance if no unbalanced collection - hasUnbalancedCollection = false - return nil + return ret } -func (b *BalanceChecker) getReplicaForNormalBalance(ctx context.Context) []int64 { - hasUnbalancedCollection := false - defer func() { - if !hasUnbalancedCollection { - b.normalBalanceCollectionsCurrentRound.Clear() - log.RatedDebug(10, "BalanceChecker has triggered normal balance for all "+ - "collections in one round, clear collectionIDs for this round") - } - }() - - // 1. no stopping balance and auto balance is disabled, return empty collections for balance - // 2. when balancer isn't active, skip auto balance - if !Params.QueryCoordCfg.AutoBalance.GetAsBool() || !b.IsActive() { - // finish current round for normal balance if normal balance isn't triggered - hasUnbalancedCollection = false - return nil +// constructStoppingBalanceQueue creates and populates the stopping balance priority queue. +// This queue contains collections that need balance operations due to nodes being stopped. +// Collections are ordered by priority (row count or collection ID based on configuration). +// +// Returns a new priority queue with all eligible collections for stopping balance. +// Note: cause stopping balance need to move out all data from the node, so we need to check all collections. +func (b *BalanceChecker) constructStoppingBalanceQueue(ctx context.Context) *balance.PriorityQueue { + sortOrder := strings.ToLower(Params.QueryCoordCfg.BalanceTriggerOrder.GetValue()) + if sortOrder == "" { + sortOrder = "byrowcount" // Default to ByRowCount } - ids := b.meta.GetAll(ctx) - // all replicas belonging to loading collection will be skipped - loadedCollections := lo.Filter(ids, func(cid int64, _ int) bool { + ret := b.filterCollectionForBalance(ctx, b.readyToCheck) + pq := balance.NewPriorityQueuePtr() + for _, cid := range ret { + rowCount := b.targetMgr.GetCollectionRowCount(ctx, cid, meta.CurrentTargetFirst) + item := newCollectionBalanceItem(cid, int(rowCount), sortOrder) + pq.Push(item) + } + b.stoppingBalanceQueue = pq + return pq +} + +// constructNormalBalanceQueue creates and populates the normal balance priority queue. +// This queue contains loaded collections that are ready for regular balance operations. +// Collections must meet multiple criteria: +// 1. Be ready for balance operations (metadata and target exist) +// 2. Have loaded status (actively serving queries) +// 3. Have current target ready (consistent state) +// +// Returns a new priority queue with all eligible collections for normal balance. +func (b *BalanceChecker) constructNormalBalanceQueue(ctx context.Context) *balance.PriorityQueue { + filterLoadedCollections := func(ctx context.Context, cid int64) bool { collection := b.meta.GetCollection(ctx, cid) return collection != nil && collection.GetStatus() == querypb.LoadStatus_Loaded - }) - - // Before performing balancing, check the CurrentTarget/LeaderView/Distribution for all collections. - // If any collection has unready info, skip the balance operation to avoid inconsistencies. - notReadyCollections := lo.Filter(loadedCollections, func(cid int64, _ int) bool { - // todo: should also check distribution and leader view in the future - return !b.targetMgr.IsCurrentTargetReady(ctx, cid) - }) - if len(notReadyCollections) > 0 { - // finish current round for normal balance if any collection isn't ready - hasUnbalancedCollection = false - log.RatedInfo(10, "skip normal balance, cause collection not ready for balance", zap.Int64s("collectionIDs", notReadyCollections)) - return nil } - // Sort collections using the configured sort order - loadedCollections = b.sortCollections(ctx, loadedCollections) - - // iterator one normal collection in one round - normalReplicasToBalance := make([]int64, 0) - for _, cid := range loadedCollections { - if b.normalBalanceCollectionsCurrentRound.Contain(cid) { - log.RatedDebug(10, "BalanceChecker is balancing this collection, skip balancing in this round", - zap.Int64("collectionID", cid)) - continue - } - hasUnbalancedCollection = true - b.normalBalanceCollectionsCurrentRound.Insert(cid) - for _, replica := range b.meta.ReplicaManager.GetByCollection(ctx, cid) { - normalReplicasToBalance = append(normalReplicasToBalance, replica.GetID()) - } - break + filterTargetReadyCollections := func(ctx context.Context, cid int64) bool { + return b.targetMgr.IsCurrentTargetReady(ctx, cid) } - return normalReplicasToBalance + + sortOrder := strings.ToLower(Params.QueryCoordCfg.BalanceTriggerOrder.GetValue()) + if sortOrder == "" { + sortOrder = "byrowcount" // Default to ByRowCount + } + + ret := b.filterCollectionForBalance(ctx, b.readyToCheck, filterLoadedCollections, filterTargetReadyCollections) + pq := balance.NewPriorityQueuePtr() + for _, cid := range ret { + rowCount := b.targetMgr.GetCollectionRowCount(ctx, cid, meta.CurrentTargetFirst) + item := newCollectionBalanceItem(cid, int(rowCount), sortOrder) + pq.Push(item) + } + b.normalBalanceQueue = pq + return pq } -func (b *BalanceChecker) balanceReplicas(ctx context.Context, replicaIDs []int64) ([]balance.SegmentAssignPlan, []balance.ChannelAssignPlan) { +// getReplicaForStoppingBalance returns replicas that need stopping balance operations. +// A replica needs stopping balance if it has: +// 1. Read-only (RO) nodes that need to be drained +// 2. Read-only streaming query (ROSQ) nodes that need to be drained +// 3. Channel read-only nodes when streaming service is enabled +// +// These replicas need immediate attention to move data off nodes that are being stopped. +// +// Returns a slice of replica IDs that need stopping balance operations. +func (b *BalanceChecker) getReplicaForStoppingBalance(ctx context.Context, collectionID int64) []int64 { + filterReplicaWithRONodes := func(replica *meta.Replica, _ int) bool { + channelRONodes := make([]int64, 0) + if streamingutil.IsStreamingServiceEnabled() { + _, channelRONodes = utils.GetChannelRWAndRONodesFor260(replica, b.nodeManager) + } + return replica.RONodesCount()+replica.ROSQNodesCount() > 0 || len(channelRONodes) > 0 + } + + // filter replicas with RONodes or channelRONodes + replicas := b.meta.ReplicaManager.GetByCollection(ctx, collectionID) + ret := make([]int64, 0) + for _, replica := range replicas { + if filterReplicaWithRONodes(replica, 0) { + ret = append(ret, replica.GetID()) + } + } + return ret +} + +// getReplicaForNormalBalance returns all replicas for a collection for normal balance operations. +// Unlike stopping balance, normal balance considers all replicas regardless of their node status. +// This allows for comprehensive load balancing across the entire collection. +// +// Returns a slice of all replica IDs for the collection. +func (b *BalanceChecker) getReplicaForNormalBalance(ctx context.Context, collectionID int64) []int64 { + replicas := b.meta.ReplicaManager.GetByCollection(ctx, collectionID) + return lo.Map(replicas, func(replica *meta.Replica, _ int) int64 { + return replica.GetID() + }) +} + +// generateBalanceTasksFromReplicas generates balance tasks for the given replicas. +// This method is the core of the balance operation that: +// 1. Uses the balancer to create segment and channel assignment plans +// 2. Converts these plans into executable tasks +// 3. Sets appropriate priorities and reasons for the tasks +// +// The process involves: +// - Getting balance plans from the configured balancer for each replica +// - Creating segment move tasks from segment assignment plans +// - Creating channel move tasks from channel assignment plans +// - Setting task metadata (priority, reason, timeout) +// +// Returns: +// - segmentTasks: tasks for moving segments between nodes +// - channelTasks: tasks for moving channels between nodes +func (b *BalanceChecker) generateBalanceTasksFromReplicas(ctx context.Context, replicas []int64, config balanceConfig) ([]task.Task, []task.Task) { + if len(replicas) == 0 { + return nil, nil + } + segmentPlans, channelPlans := make([]balance.SegmentAssignPlan, 0), make([]balance.ChannelAssignPlan, 0) - for _, rid := range replicaIDs { + for _, rid := range replicas { replica := b.meta.ReplicaManager.Get(ctx, rid) if replica == nil { continue @@ -210,62 +338,103 @@ func (b *BalanceChecker) balanceReplicas(ctx context.Context, replicaIDs []int64 balance.PrintNewBalancePlans(replica.GetCollectionID(), replica.GetID(), sPlans, cPlans) } } - return segmentPlans, channelPlans -} - -// Notice: balance checker will generate tasks for multiple collections in one round, -// so generated tasks will be submitted to scheduler directly, and return nil -func (b *BalanceChecker) Check(ctx context.Context) []task.Task { - segmentBatchSize := paramtable.Get().QueryCoordCfg.BalanceSegmentBatchSize.GetAsInt() - channelBatchSize := paramtable.Get().QueryCoordCfg.BalanceChannelBatchSize.GetAsInt() - balanceOnMultipleCollections := paramtable.Get().QueryCoordCfg.EnableBalanceOnMultipleCollections.GetAsBool() segmentTasks := make([]task.Task, 0) channelTasks := make([]task.Task, 0) - - generateBalanceTaskForReplicas := func(replicas []int64) { - segmentPlans, channelPlans := b.balanceReplicas(ctx, replicas) - tasks := balance.CreateSegmentTasksFromPlans(ctx, b.ID(), Params.QueryCoordCfg.SegmentTaskTimeout.GetAsDuration(time.Millisecond), segmentPlans) - task.SetPriority(task.TaskPriorityLow, tasks...) - task.SetReason("segment unbalanced", tasks...) - segmentTasks = append(segmentTasks, tasks...) - - tasks = balance.CreateChannelTasksFromPlans(ctx, b.ID(), Params.QueryCoordCfg.ChannelTaskTimeout.GetAsDuration(time.Millisecond), channelPlans) - task.SetReason("channel unbalanced", tasks...) - channelTasks = append(channelTasks, tasks...) - } - - stoppingReplicas := b.getReplicaForStoppingBalance(ctx) - if len(stoppingReplicas) > 0 { - // check for stopping balance first - generateBalanceTaskForReplicas(stoppingReplicas) - // iterate all collection to find a collection to balance - for len(segmentTasks) < segmentBatchSize && len(channelTasks) < channelBatchSize && b.stoppingBalanceCollectionsCurrentRound.Len() > 0 { - if !balanceOnMultipleCollections && (len(segmentTasks) > 0 || len(channelTasks) > 0) { - // if balance on multiple collections is disabled, and there are already some tasks, break - break - } - replicasToBalance := b.getReplicaForStoppingBalance(ctx) - generateBalanceTaskForReplicas(replicasToBalance) - } - } else { - // then check for auto balance - if time.Since(b.autoBalanceTs) > paramtable.Get().QueryCoordCfg.AutoBalanceInterval.GetAsDuration(time.Millisecond) { - b.autoBalanceTs = time.Now() - replicasToBalance := b.getReplicaForNormalBalance(ctx) - generateBalanceTaskForReplicas(replicasToBalance) - // iterate all collection to find a collection to balance - for len(segmentTasks) < segmentBatchSize && len(channelTasks) < channelBatchSize && b.normalBalanceCollectionsCurrentRound.Len() > 0 { - if !balanceOnMultipleCollections && (len(segmentTasks) > 0 || len(channelTasks) > 0) { - // if balance on multiple collections is disabled, and there are already some tasks, break - break - } - replicasToBalance := b.getReplicaForNormalBalance(ctx) - generateBalanceTaskForReplicas(replicasToBalance) - } + // Create segment tasks with error handling + if len(segmentPlans) > 0 { + tasks := balance.CreateSegmentTasksFromPlans(ctx, b.ID(), config.segmentTaskTimeout, segmentPlans) + if len(tasks) > 0 { + task.SetPriority(task.TaskPriorityLow, tasks...) + task.SetReason("segment unbalanced", tasks...) + segmentTasks = append(segmentTasks, tasks...) } } + // Create channel tasks with error handling + if len(channelPlans) > 0 { + tasks := balance.CreateChannelTasksFromPlans(ctx, b.ID(), config.channelTaskTimeout, channelPlans) + if len(tasks) > 0 { + task.SetReason("channel unbalanced", tasks...) + channelTasks = append(channelTasks, tasks...) + } + } + + return segmentTasks, channelTasks +} + +// processBalanceQueue processes balance queue with common logic for both normal and stopping balance. +// This is a template method that implements the core queue processing algorithm while allowing +// different balance types to provide their own specific logic through function parameters. +// +// The method implements several safeguards: +// 1. Batch size limits to prevent generating too many tasks at once +// 2. Collection count limits to prevent long-running operations +// 3. Multi-collection balance control to avoid resource contention +// +// Processing flow: +// 1. Get or construct the priority queue for collections +// 2. Pop collections from queue in priority order +// 3. Get replicas that need balance for the collection +// 4. Generate balance tasks for those replicas +// 5. Accumulate tasks until batch limits are reached +// +// Parameters: +// - ctx: context for the operation +// - getReplicasFunc: function to get replicas for a collection (normal vs stopping) +// - constructQueueFunc: function to construct a new priority queue if needed +// - getQueueFunc: function to get the existing priority queue +// - config: balance configuration with batch sizes and limits +// +// Returns: +// - generatedSegmentTaskNum: number of generated segment balance tasks +// - generatedChannelTaskNum: number of generated channel balance tasks +func (b *BalanceChecker) processBalanceQueue( + ctx context.Context, + getReplicasFunc func(context.Context, int64) []int64, + constructQueueFunc func(context.Context) *balance.PriorityQueue, + getQueueFunc func() *balance.PriorityQueue, + config balanceConfig, +) (int, int) { + checkCollectionCount := 0 + pq := getQueueFunc() + if pq == nil || pq.Len() == 0 { + pq = constructQueueFunc(ctx) + } + + generatedSegmentTaskNum := 0 + generatedChannelTaskNum := 0 + + for generatedSegmentTaskNum < config.segmentBatchSize && + generatedChannelTaskNum < config.channelBatchSize && + checkCollectionCount < config.maxCheckCollectionCount && + pq.Len() > 0 { + // Break if balanceOnMultipleCollections is disabled and we already have tasks + if !config.balanceOnMultipleCollections && (generatedSegmentTaskNum > 0 || generatedChannelTaskNum > 0) { + log.Debug("Balance on multiple collections disabled, stopping after first collection") + break + } + + item := pq.Pop().(*collectionBalanceItem) + checkCollectionCount++ + + replicasToBalance := getReplicasFunc(ctx, item.collectionID) + if len(replicasToBalance) == 0 { + continue + } + + newSegmentTasks, newChannelTasks := b.generateBalanceTasksFromReplicas(ctx, replicasToBalance, config) + generatedSegmentTaskNum += len(newSegmentTasks) + generatedChannelTaskNum += len(newChannelTasks) + b.submitTasks(newSegmentTasks, newChannelTasks) + } + return generatedSegmentTaskNum, generatedChannelTaskNum +} + +// submitTasks submits the generated balance tasks to the scheduler for execution. +// This method handles the final step of the balance process by adding all +// generated tasks to the task scheduler, which will execute them asynchronously. +func (b *BalanceChecker) submitTasks(segmentTasks, channelTasks []task.Task) { for _, task := range segmentTasks { b.scheduler.Add(task) } @@ -273,45 +442,99 @@ func (b *BalanceChecker) Check(ctx context.Context) []task.Task { for _, task := range channelTasks { b.scheduler.Add(task) } +} +// Check is the main entry point for balance operations. +// This method implements a two-phase balance strategy with clear priorities: +// +// **Phase 1: Stopping Balance (Higher Priority)** +// - Handles nodes that are being gracefully stopped +// - Moves data off read-only nodes to active nodes +// - Critical for maintaining service availability during node shutdowns +// - Runs immediately when stopping nodes are detected +// +// **Phase 2: Normal Balance (Lower Priority)** +// - Performs regular load balancing to optimize cluster performance +// - Runs periodically based on autoBalanceInterval configuration +// - Considers all collections and distributes load evenly +// - Skipped if stopping balance tasks were generated +// +// **Key Design Decisions:** +// 1. Tasks are submitted directly to scheduler and nil is returned +// (unlike other checkers that return tasks to caller) +// 2. Stopping balance always takes precedence over normal balance +// 3. Performance monitoring alerts for operations > 100ms +// 4. Configuration is loaded fresh each time to respect dynamic updates +// +// **Return Value:** +// Always returns nil because tasks are submitted directly to the scheduler. +// This design allows the balance checker to handle multiple collections +// and large numbers of tasks efficiently. +// +// **Performance Monitoring:** +// The method tracks execution time and logs warnings for slow operations +// to help identify performance bottlenecks in large clusters. +func (b *BalanceChecker) Check(ctx context.Context) []task.Task { + // Skip balance operations if the checker is not active + if !b.IsActive() { + return nil + } + + // Performance monitoring: track execution time + start := time.Now() + defer func() { + duration := time.Since(start) + if duration > 100*time.Millisecond { + log.Info("Balance check too slow", zap.Duration("duration", duration)) + } + }() + + // Load current configuration to respect dynamic parameter changes + config := b.loadBalanceConfig() + + // Phase 1: Process stopping balance first (higher priority) + // This handles nodes that are being gracefully stopped and need immediate attention + if paramtable.Get().QueryCoordCfg.EnableStoppingBalance.GetAsBool() { + generatedSegmentTaskNum, generatedChannelTaskNum := b.processBalanceQueue(ctx, + b.getReplicaForStoppingBalance, + b.constructStoppingBalanceQueue, + func() *balance.PriorityQueue { return b.stoppingBalanceQueue }, + config) + + if generatedSegmentTaskNum > 0 || generatedChannelTaskNum > 0 { + // clean up the normal balance queue when stopping balance generated tasks + // make sure that next time when trigger normal balance, a new normal balance round will be started + b.normalBalanceQueue = nil + + return nil + } + } + + // Phase 2: Process normal balance if no stopping balance was needed + // This handles regular load balancing operations for cluster optimization + if paramtable.Get().QueryCoordCfg.AutoBalance.GetAsBool() { + // Respect the auto balance interval to prevent too frequent operations + if time.Since(b.autoBalanceTs) <= config.autoBalanceInterval { + return nil + } + + generatedSegmentTaskNum, generatedChannelTaskNum := b.processBalanceQueue(ctx, + b.getReplicaForNormalBalance, + b.constructNormalBalanceQueue, + func() *balance.PriorityQueue { return b.normalBalanceQueue }, + config) + + // Submit normal balance tasks if any were generated + // Update the auto balance timestamp to enforce the interval + if generatedSegmentTaskNum > 0 || generatedChannelTaskNum > 0 { + b.autoBalanceTs = time.Now() + + // clean up the stopping balance queue when normal balance generated tasks + // make sure that next time when trigger stopping balance, a new stopping balance round will be started + b.stoppingBalanceQueue = nil + } + } + + // Always return nil as tasks are submitted directly to scheduler return nil } - -func (b *BalanceChecker) sortCollections(ctx context.Context, collections []int64) []int64 { - sortOrder := strings.ToLower(Params.QueryCoordCfg.BalanceTriggerOrder.GetValue()) - if sortOrder == "" { - sortOrder = "byrowcount" // Default to ByRowCount - } - - collectionRowCountMap := make(map[int64]int64) - for _, cid := range collections { - collectionRowCountMap[cid] = b.targetMgr.GetCollectionRowCount(ctx, cid, meta.CurrentTargetFirst) - } - - // Define sorting functions - sortByRowCount := func(i, j int) bool { - rowCount1 := collectionRowCountMap[collections[i]] - rowCount2 := collectionRowCountMap[collections[j]] - return rowCount1 > rowCount2 || (rowCount1 == rowCount2 && collections[i] < collections[j]) - } - - sortByCollectionID := func(i, j int) bool { - return collections[i] < collections[j] - } - - // Select the appropriate sorting function - var sortFunc func(i, j int) bool - switch sortOrder { - case "byrowcount": - sortFunc = sortByRowCount - case "bycollectionid": - sortFunc = sortByCollectionID - default: - log.Warn("Invalid balance sort order configuration, using default ByRowCount", zap.String("sortOrder", sortOrder)) - sortFunc = sortByRowCount - } - - // Sort the collections - sort.Slice(collections, sortFunc) - return collections -} diff --git a/internal/querycoordv2/checkers/balance_checker_test.go b/internal/querycoordv2/checkers/balance_checker_test.go index 280e9ea22b..cdda585f08 100644 --- a/internal/querycoordv2/checkers/balance_checker_test.go +++ b/internal/querycoordv2/checkers/balance_checker_test.go @@ -21,993 +21,963 @@ import ( "testing" "time" - "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/suite" - "go.uber.org/atomic" + "github.com/bytedance/mockey" + "github.com/stretchr/testify/assert" - etcdkv "github.com/milvus-io/milvus/internal/kv/etcd" - "github.com/milvus-io/milvus/internal/metastore/kv/querycoord" "github.com/milvus-io/milvus/internal/querycoordv2/balance" "github.com/milvus-io/milvus/internal/querycoordv2/meta" - . "github.com/milvus-io/milvus/internal/querycoordv2/params" "github.com/milvus-io/milvus/internal/querycoordv2/session" "github.com/milvus-io/milvus/internal/querycoordv2/task" "github.com/milvus-io/milvus/internal/querycoordv2/utils" - "github.com/milvus-io/milvus/pkg/v2/kv" + "github.com/milvus-io/milvus/internal/util/streamingutil" "github.com/milvus-io/milvus/pkg/v2/proto/datapb" - "github.com/milvus-io/milvus/pkg/v2/proto/querypb" - "github.com/milvus-io/milvus/pkg/v2/util/etcd" "github.com/milvus-io/milvus/pkg/v2/util/paramtable" ) -type BalanceCheckerTestSuite struct { - suite.Suite - kv kv.MetaKv - checker *BalanceChecker - balancer *balance.MockBalancer - meta *meta.Meta - broker *meta.MockBroker - nodeMgr *session.NodeManager - scheduler *task.MockScheduler - targetMgr meta.TargetManagerInterface +// createMockPriorityQueue creates a mock priority queue for testing +func createMockPriorityQueue() *balance.PriorityQueue { + return balance.NewPriorityQueuePtr() } -func (suite *BalanceCheckerTestSuite) SetupSuite() { - paramtable.Init() +// Helper function to create a test BalanceChecker +func createTestBalanceChecker() *BalanceChecker { + metaInstance := &meta.Meta{ + CollectionManager: meta.NewCollectionManager(nil), + } + targetMgr := meta.NewTargetManager(nil, nil) + nodeMgr := &session.NodeManager{} + scheduler := task.NewScheduler(context.Background(), nil, nil, nil, nil, nil, nil) + balancer := balance.NewScoreBasedBalancer(nil, nil, nil, nil, nil) + getBalancerFunc := func() balance.Balance { return balancer } + + return NewBalanceChecker(metaInstance, targetMgr, nodeMgr, scheduler, getBalancerFunc) } -func (suite *BalanceCheckerTestSuite) SetupTest() { - var err error - config := GenerateEtcdConfig() - cli, err := etcd.GetEtcdClient( - config.UseEmbedEtcd.GetAsBool(), - config.EtcdUseSSL.GetAsBool(), - config.Endpoints.GetAsStrings(), - config.EtcdTLSCert.GetValue(), - config.EtcdTLSKey.GetValue(), - config.EtcdTLSCACert.GetValue(), - config.EtcdTLSMinVersion.GetValue()) - suite.Require().NoError(err) - suite.kv = etcdkv.NewEtcdKV(cli, config.MetaRootPath.GetValue()) +// ============================================================================= +// Basic Interface Tests +// ============================================================================= - // meta - store := querycoord.NewCatalog(suite.kv) - idAllocator := RandomIncrementIDAllocator() - suite.nodeMgr = session.NewNodeManager() - suite.meta = meta.NewMeta(idAllocator, store, suite.nodeMgr) - suite.broker = meta.NewMockBroker(suite.T()) - suite.scheduler = task.NewMockScheduler(suite.T()) - suite.scheduler.EXPECT().Add(mock.Anything).Return(nil).Maybe() - suite.targetMgr = meta.NewTargetManager(suite.broker, suite.meta) +func TestBalanceChecker_ID(t *testing.T) { + checker := createTestBalanceChecker() - suite.balancer = balance.NewMockBalancer(suite.T()) - suite.checker = NewBalanceChecker(suite.meta, suite.targetMgr, suite.nodeMgr, suite.scheduler, func() balance.Balance { return suite.balancer }) + id := checker.ID() + assert.Equal(t, utils.BalanceChecker, id) } -func (suite *BalanceCheckerTestSuite) TearDownTest() { - suite.kv.Close() +func TestBalanceChecker_Description(t *testing.T) { + checker := createTestBalanceChecker() + + desc := checker.Description() + assert.Equal(t, "BalanceChecker checks the cluster distribution and generates balance tasks", desc) } -func (suite *BalanceCheckerTestSuite) TestAutoBalanceConf() { +// ============================================================================= +// Configuration Tests +// ============================================================================= + +func TestBalanceChecker_LoadBalanceConfig(t *testing.T) { + checker := createTestBalanceChecker() + + // Mock paramtable.Get function + mockParamGet := mockey.Mock(paramtable.Get).Return(¶mtable.ComponentParam{}).Build() + defer mockParamGet.UnPatch() + + // Mock various param item methods + mockGetAsInt := mockey.Mock((*paramtable.ParamItem).GetAsInt).Return(5).Build() + defer mockGetAsInt.UnPatch() + + mockGetAsBool := mockey.Mock((*paramtable.ParamItem).GetAsBool).Return(true).Build() + defer mockGetAsBool.UnPatch() + + mockGetAsDuration := mockey.Mock((*paramtable.ParamItem).GetAsDuration).Return(5 * time.Second).Build() + defer mockGetAsDuration.UnPatch() + + config := checker.loadBalanceConfig() + + // Verify config structure is returned + assert.IsType(t, balanceConfig{}, config) +} + +// ============================================================================= +// Collection Balance Item Tests +// ============================================================================= + +func TestNewCollectionBalanceItem(t *testing.T) { + collectionID := int64(100) + rowCount := 1000 + sortOrder := "byrowcount" + + item := newCollectionBalanceItem(collectionID, rowCount, sortOrder) + + assert.Equal(t, collectionID, item.collectionID) + assert.Equal(t, rowCount, item.rowCount) + assert.Equal(t, sortOrder, item.sortOrder) +} + +func TestCollectionBalanceItem_GetPriority_ByRowCount(t *testing.T) { + tests := []struct { + name string + rowCount int + sortOrder string + expected int + }{ + {"ByRowCount", 1000, "byrowcount", -1000}, + {"Default", 500, "", -500}, // default to byrowcount + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + item := newCollectionBalanceItem(1, tt.rowCount, tt.sortOrder) + + priority := item.getPriority() + assert.Equal(t, tt.expected, priority) + }) + } +} + +func TestCollectionBalanceItem_GetPriority_ByCollectionID(t *testing.T) { + collectionID := int64(123) + item := newCollectionBalanceItem(collectionID, 1000, "bycollectionid") + + priority := item.getPriority() + assert.Equal(t, int(collectionID), priority) +} + +func TestCollectionBalanceItem_SetPriority(t *testing.T) { + item := newCollectionBalanceItem(1, 100, "byrowcount") + + item.setPriority(200) + + assert.Equal(t, 200, item.getPriority()) +} + +// ============================================================================= +// Collection Filtering Tests +// ============================================================================= + +func TestBalanceChecker_ReadyToCheck_Success(t *testing.T) { + checker := createTestBalanceChecker() ctx := context.Background() - // set up nodes info - nodeID1, nodeID2 := 1, 2 - suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ - NodeID: int64(nodeID1), - Address: "localhost", - Hostname: "localhost", - })) - suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ - NodeID: int64(nodeID2), - Address: "localhost", - Hostname: "localhost", - })) - suite.checker.meta.ResourceManager.HandleNodeUp(ctx, int64(nodeID1)) - suite.checker.meta.ResourceManager.HandleNodeUp(ctx, int64(nodeID2)) + collectionID := int64(1) - // set collections meta - segments := []*datapb.SegmentInfo{ - { - ID: 1, - PartitionID: 1, - InsertChannel: "test-insert-channel", - }, - } - channels := []*datapb.VchannelInfo{ - { - CollectionID: 1, - ChannelName: "test-insert-channel", - }, - } - suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, mock.Anything).Return(channels, segments, nil) + // Mock meta.GetCollection to return a valid collection + mockGetCollection := mockey.Mock(mockey.GetMethod(checker.meta.CollectionManager, "GetCollection")).Return(&meta.Collection{}).Build() + defer mockGetCollection.UnPatch() - // set collections meta - cid1, replicaID1, partitionID1 := 1, 1, 1 - collection1 := utils.CreateTestCollection(int64(cid1), int32(replicaID1)) - collection1.Status = querypb.LoadStatus_Loaded - replica1 := utils.CreateTestReplica(int64(replicaID1), int64(cid1), []int64{int64(nodeID1), int64(nodeID2)}) - partition1 := utils.CreateTestPartition(int64(cid1), int64(partitionID1)) - suite.checker.meta.CollectionManager.PutCollection(ctx, collection1, partition1) - suite.checker.meta.ReplicaManager.Put(ctx, replica1) - suite.targetMgr.UpdateCollectionNextTarget(ctx, int64(cid1)) - suite.targetMgr.UpdateCollectionCurrentTarget(ctx, int64(cid1)) + // Mock target manager methods + mockIsNextTargetExist := mockey.Mock(mockey.GetMethod(checker.targetMgr, "IsNextTargetExist")).Return(true).Build() + defer mockIsNextTargetExist.UnPatch() - cid2, replicaID2, partitionID2 := 2, 2, 2 - collection2 := utils.CreateTestCollection(int64(cid2), int32(replicaID2)) - collection2.Status = querypb.LoadStatus_Loaded - replica2 := utils.CreateTestReplica(int64(replicaID2), int64(cid2), []int64{int64(nodeID1), int64(nodeID2)}) - partition2 := utils.CreateTestPartition(int64(cid2), int64(partitionID2)) - suite.checker.meta.CollectionManager.PutCollection(ctx, collection2, partition2) - suite.checker.meta.ReplicaManager.Put(ctx, replica2) - suite.targetMgr.UpdateCollectionNextTarget(ctx, int64(cid2)) - suite.targetMgr.UpdateCollectionCurrentTarget(ctx, int64(cid2)) - - // test disable auto balance - paramtable.Get().Save(Params.QueryCoordCfg.AutoBalance.Key, "false") - suite.scheduler.EXPECT().GetSegmentTaskNum().Maybe().Return(func() int { - return 0 - }) - replicasToBalance := suite.checker.getReplicaForNormalBalance(ctx) - suite.Empty(replicasToBalance) - segPlans, _ := suite.checker.balanceReplicas(ctx, replicasToBalance) - suite.Empty(segPlans) - - // test enable auto balance - paramtable.Get().Save(Params.QueryCoordCfg.AutoBalance.Key, "true") - idsToBalance := []int64{int64(replicaID1)} - replicasToBalance = suite.checker.getReplicaForNormalBalance(ctx) - suite.ElementsMatch(idsToBalance, replicasToBalance) - // next round - idsToBalance = []int64{int64(replicaID2)} - replicasToBalance = suite.checker.getReplicaForNormalBalance(ctx) - suite.ElementsMatch(idsToBalance, replicasToBalance) - // final round - replicasToBalance = suite.checker.getReplicaForNormalBalance(ctx) - suite.Empty(replicasToBalance) + result := checker.readyToCheck(ctx, collectionID) + assert.True(t, result) } -func (suite *BalanceCheckerTestSuite) TestBusyScheduler() { +func TestBalanceChecker_ReadyToCheck_NoMeta(t *testing.T) { + checker := createTestBalanceChecker() ctx := context.Background() - // set up nodes info - nodeID1, nodeID2 := 1, 2 - suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ - NodeID: int64(nodeID1), - Address: "localhost", - Hostname: "localhost", - })) - suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ - NodeID: int64(nodeID2), - Address: "localhost", - Hostname: "localhost", - })) - suite.checker.meta.ResourceManager.HandleNodeUp(ctx, int64(nodeID1)) - suite.checker.meta.ResourceManager.HandleNodeUp(ctx, int64(nodeID2)) + collectionID := int64(1) - segments := []*datapb.SegmentInfo{ - { - ID: 1, - PartitionID: 1, - InsertChannel: "test-insert-channel", - }, - } - channels := []*datapb.VchannelInfo{ - { - CollectionID: 1, - ChannelName: "test-insert-channel", - }, - } - suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, mock.Anything).Return(channels, segments, nil) + // Mock meta.GetCollection to return nil + mockGetCollection := mockey.Mock((*meta.Meta).GetCollection).Return(nil).Build() + defer mockGetCollection.UnPatch() - // set collections meta - cid1, replicaID1, partitionID1 := 1, 1, 1 - collection1 := utils.CreateTestCollection(int64(cid1), int32(replicaID1)) - collection1.Status = querypb.LoadStatus_Loaded - replica1 := utils.CreateTestReplica(int64(replicaID1), int64(cid1), []int64{int64(nodeID1), int64(nodeID2)}) - partition1 := utils.CreateTestPartition(int64(cid1), int64(partitionID1)) - suite.checker.meta.CollectionManager.PutCollection(ctx, collection1, partition1) - suite.checker.meta.ReplicaManager.Put(ctx, replica1) - suite.targetMgr.UpdateCollectionNextTarget(ctx, int64(cid1)) - suite.targetMgr.UpdateCollectionCurrentTarget(ctx, int64(cid1)) + // Mock target manager methods to return false + mockIsNextTargetExist := mockey.Mock(mockey.GetMethod(checker.targetMgr, "IsNextTargetExist")).Return(false).Build() + defer mockIsNextTargetExist.UnPatch() - cid2, replicaID2, partitionID2 := 2, 2, 2 - collection2 := utils.CreateTestCollection(int64(cid2), int32(replicaID2)) - collection2.Status = querypb.LoadStatus_Loaded - replica2 := utils.CreateTestReplica(int64(replicaID2), int64(cid2), []int64{int64(nodeID1), int64(nodeID2)}) - partition2 := utils.CreateTestPartition(int64(cid2), int64(partitionID2)) - suite.checker.meta.CollectionManager.PutCollection(ctx, collection2, partition2) - suite.checker.meta.ReplicaManager.Put(ctx, replica2) - suite.targetMgr.UpdateCollectionNextTarget(ctx, int64(cid2)) - suite.targetMgr.UpdateCollectionCurrentTarget(ctx, int64(cid2)) + mockIsCurrentTargetExist := mockey.Mock(mockey.GetMethod(checker.targetMgr, "IsCurrentTargetExist")).Return(false).Build() + defer mockIsCurrentTargetExist.UnPatch() - // test scheduler busy - paramtable.Get().Save(Params.QueryCoordCfg.AutoBalance.Key, "true") - suite.scheduler.EXPECT().GetSegmentTaskNum().Maybe().Return(func() int { - return 1 - }) - replicasToBalance := suite.checker.getReplicaForNormalBalance(ctx) - suite.Len(replicasToBalance, 1) + result := checker.readyToCheck(ctx, collectionID) + assert.False(t, result) } -func (suite *BalanceCheckerTestSuite) TestStoppingBalance() { +func TestBalanceChecker_ReadyToCheck_NoTarget(t *testing.T) { + checker := createTestBalanceChecker() ctx := context.Background() - // set up nodes info, stopping node1 - nodeID1, nodeID2 := 1, 2 - suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ - NodeID: int64(nodeID1), - Address: "localhost", - Hostname: "localhost", - })) - suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ - NodeID: int64(nodeID2), - Address: "localhost", - Hostname: "localhost", - })) - suite.nodeMgr.Stopping(int64(nodeID1)) - suite.checker.meta.ResourceManager.HandleNodeUp(ctx, int64(nodeID1)) - suite.checker.meta.ResourceManager.HandleNodeUp(ctx, int64(nodeID2)) + collectionID := int64(1) - segments := []*datapb.SegmentInfo{ - { - ID: 1, - PartitionID: 1, - InsertChannel: "test-insert-channel", - }, + // Mock meta.GetCollection to return a valid collection + mockGetCollection := mockey.Mock((*meta.Meta).GetCollection).Return(&meta.Collection{}).Build() + defer mockGetCollection.UnPatch() + + // Mock target manager methods to return false + mockIsNextTargetExist := mockey.Mock(mockey.GetMethod(checker.targetMgr, "IsNextTargetExist")).Return(false).Build() + defer mockIsNextTargetExist.UnPatch() + + mockIsCurrentTargetExist := mockey.Mock(mockey.GetMethod(checker.targetMgr, "IsCurrentTargetExist")).Return(false).Build() + defer mockIsCurrentTargetExist.UnPatch() + + result := checker.readyToCheck(ctx, collectionID) + assert.False(t, result) +} + +func TestBalanceChecker_FilterCollectionForBalance_Success(t *testing.T) { + checker := createTestBalanceChecker() + ctx := context.Background() + + // Mock meta.GetAll to return collection IDs + collectionIDs := []int64{1, 2, 3} + mockGetAll := mockey.Mock((*meta.CollectionManager).GetAll).Return(collectionIDs).Build() + defer mockGetAll.UnPatch() + + // Create filters that pass all collections + passAllFilter := func(ctx context.Context, collectionID int64) bool { + return true } - channels := []*datapb.VchannelInfo{ - { - CollectionID: 1, - ChannelName: "test-insert-channel", - }, + + result := checker.filterCollectionForBalance(ctx, passAllFilter) + assert.Equal(t, collectionIDs, result) +} + +func TestBalanceChecker_FilterCollectionForBalance_WithFiltering(t *testing.T) { + checker := createTestBalanceChecker() + ctx := context.Background() + + // Mock meta.GetAll to return collection IDs + collectionIDs := []int64{1, 2, 3, 4} + mockGetAll := mockey.Mock((*meta.CollectionManager).GetAll).Return(collectionIDs).Build() + defer mockGetAll.UnPatch() + + // Create filters: only even numbers pass first filter, only > 2 pass second filter + evenFilter := func(ctx context.Context, collectionID int64) bool { + return collectionID%2 == 0 } - suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, mock.Anything).Return(channels, segments, nil) - // set collections meta - cid1, replicaID1, partitionID1 := 1, 1, 1 - collection1 := utils.CreateTestCollection(int64(cid1), int32(replicaID1)) - collection1.Status = querypb.LoadStatus_Loaded - replica1 := utils.CreateTestReplica(int64(replicaID1), int64(cid1), []int64{int64(nodeID1), int64(nodeID2)}) - partition1 := utils.CreateTestPartition(int64(cid1), int64(partitionID1)) - suite.checker.meta.CollectionManager.PutCollection(ctx, collection1, partition1) - suite.checker.meta.ReplicaManager.Put(ctx, replica1) - suite.targetMgr.UpdateCollectionNextTarget(ctx, int64(cid1)) - suite.targetMgr.UpdateCollectionCurrentTarget(ctx, int64(cid1)) + greaterThanTwoFilter := func(ctx context.Context, collectionID int64) bool { + return collectionID > 2 + } - cid2, replicaID2, partitionID2 := 2, 2, 2 - collection2 := utils.CreateTestCollection(int64(cid2), int32(replicaID2)) - collection2.Status = querypb.LoadStatus_Loaded - replica2 := utils.CreateTestReplica(int64(replicaID2), int64(cid2), []int64{int64(nodeID1), int64(nodeID2)}) - partition2 := utils.CreateTestPartition(int64(cid2), int64(partitionID2)) - suite.checker.meta.CollectionManager.PutCollection(ctx, collection2, partition2) - suite.checker.meta.ReplicaManager.Put(ctx, replica2) - suite.targetMgr.UpdateCollectionNextTarget(ctx, int64(cid2)) - suite.targetMgr.UpdateCollectionCurrentTarget(ctx, int64(cid2)) + result := checker.filterCollectionForBalance(ctx, evenFilter, greaterThanTwoFilter) + // Only collection 4 should pass both filters (even AND > 2) + assert.Equal(t, []int64{4}, result) +} - mr1 := replica1.CopyForWrite() - mr1.AddRONode(1) - suite.checker.meta.ReplicaManager.Put(ctx, mr1.IntoReplica()) +func TestBalanceChecker_FilterCollectionForBalance_EmptyResult(t *testing.T) { + checker := createTestBalanceChecker() + ctx := context.Background() - mr2 := replica2.CopyForWrite() - mr2.AddRONode(1) - suite.checker.meta.ReplicaManager.Put(ctx, mr2.IntoReplica()) + // Mock meta.GetAll to return collection IDs + collectionIDs := []int64{1, 2, 3} + mockGetAll := mockey.Mock((*meta.CollectionManager).GetAll).Return(collectionIDs).Build() + defer mockGetAll.UnPatch() - // test stopping balance - // First round: check replica1 - idsToBalance := []int64{int64(replicaID1)} - replicasToBalance := suite.checker.getReplicaForStoppingBalance(ctx) - suite.ElementsMatch(idsToBalance, replicasToBalance) - suite.True(suite.checker.stoppingBalanceCollectionsCurrentRound.Contain(int64(cid1))) - suite.False(suite.checker.stoppingBalanceCollectionsCurrentRound.Contain(int64(cid2))) + // Create filter that rejects all + rejectAllFilter := func(ctx context.Context, collectionID int64) bool { + return false + } - // Second round: should skip replica1, check replica2 - idsToBalance = []int64{int64(replicaID2)} - replicasToBalance = suite.checker.getReplicaForStoppingBalance(ctx) - suite.ElementsMatch(idsToBalance, replicasToBalance) - suite.True(suite.checker.stoppingBalanceCollectionsCurrentRound.Contain(int64(cid1))) - suite.True(suite.checker.stoppingBalanceCollectionsCurrentRound.Contain(int64(cid2))) + result := checker.filterCollectionForBalance(ctx, rejectAllFilter) + assert.Empty(t, result) +} - // Third round: all collections checked, should return nil and clear the set - replicasToBalance = suite.checker.getReplicaForStoppingBalance(ctx) - suite.Empty(replicasToBalance) - suite.False(suite.checker.stoppingBalanceCollectionsCurrentRound.Contain(int64(cid1))) - suite.False(suite.checker.stoppingBalanceCollectionsCurrentRound.Contain(int64(cid2))) +// ============================================================================= +// Queue Construction Tests +// ============================================================================= - // reset meta for Check test - suite.checker.stoppingBalanceCollectionsCurrentRound.Clear() - mr1 = replica1.CopyForWrite() - mr1.AddRONode(1) - suite.checker.meta.ReplicaManager.Put(ctx, mr1.IntoReplica()) +func TestBalanceChecker_ConstructStoppingBalanceQueue(t *testing.T) { + checker := createTestBalanceChecker() + ctx := context.Background() - // checker check - segPlans, chanPlans := make([]balance.SegmentAssignPlan, 0), make([]balance.ChannelAssignPlan, 0) - mockPlan := balance.SegmentAssignPlan{ - Segment: utils.CreateTestSegment(1, 1, 1, 1, 1, "1"), - Replica: meta.NilReplica, + // Mock filterCollectionForBalance result + collectionIDs := []int64{1, 2} + mockFilterCollections := mockey.Mock((*BalanceChecker).filterCollectionForBalance).Return(collectionIDs).Build() + defer mockFilterCollections.UnPatch() + + // Mock target manager GetCollectionRowCount + mockGetRowCount := mockey.Mock(mockey.GetMethod(checker.targetMgr, "GetCollectionRowCount")).Return(int64(100)).Build() + defer mockGetRowCount.UnPatch() + + // Mock paramtable for sort order + mockParamValue := mockey.Mock((*paramtable.ParamItem).GetValue).Return("byrowcount").Build() + defer mockParamValue.UnPatch() + + result := checker.constructStoppingBalanceQueue(ctx) + assert.Equal(t, result.Len(), 2) +} + +func TestBalanceChecker_ConstructNormalBalanceQueue(t *testing.T) { + checker := createTestBalanceChecker() + ctx := context.Background() + + // Mock filterCollectionForBalance result + collectionIDs := []int64{1, 2} + mockFilterCollections := mockey.Mock((*BalanceChecker).filterCollectionForBalance).Return(collectionIDs).Build() + defer mockFilterCollections.UnPatch() + + // Mock target manager GetCollectionRowCount + mockGetRowCount := mockey.Mock(mockey.GetMethod(checker.targetMgr, "GetCollectionRowCount")).Return(int64(100)).Build() + defer mockGetRowCount.UnPatch() + + // Mock paramtable for sort order + mockParamValue := mockey.Mock((*paramtable.ParamItem).GetValue).Return("byrowcount").Build() + defer mockParamValue.UnPatch() + + result := checker.constructNormalBalanceQueue(ctx) + assert.Equal(t, result.Len(), 2) +} + +// ============================================================================= +// Replica Getting Tests +// ============================================================================= + +func TestBalanceChecker_GetReplicaForStoppingBalance_WithRONodes(t *testing.T) { + checker := createTestBalanceChecker() + ctx := context.Background() + collectionID := int64(1) + + // Create mock replicas + replica1 := &meta.Replica{} + replica2 := &meta.Replica{} + replicas := []*meta.Replica{replica1, replica2} + + // Mock ReplicaManager.GetByCollection + mockGetByCollection := mockey.Mock(mockey.GetMethod(checker.meta.ReplicaManager, "GetByCollection")).Return(replicas).Build() + defer mockGetByCollection.UnPatch() + + // Mock replica methods - replica1 has RO nodes, replica2 doesn't + mockRONodesCount1 := mockey.Mock((*meta.Replica).RONodesCount).Return(1).Build() + defer mockRONodesCount1.UnPatch() + + mockROSQNodesCount1 := mockey.Mock((*meta.Replica).ROSQNodesCount).Return(0).Build() + defer mockROSQNodesCount1.UnPatch() + + mockGetID1 := mockey.Mock((*meta.Replica).GetID).Return(int64(101)).Build() + defer mockGetID1.UnPatch() + + // Skip streaming service mock for simplicity + + result := checker.getReplicaForStoppingBalance(ctx, collectionID) + // Should return replica1 ID since it has RO nodes + assert.Contains(t, result, int64(101)) +} + +func TestBalanceChecker_GetReplicaForStoppingBalance_NoRONodes(t *testing.T) { + checker := createTestBalanceChecker() + ctx := context.Background() + collectionID := int64(1) + + // Create mock replicas + replica1 := &meta.Replica{} + replicas := []*meta.Replica{replica1} + + // Mock ReplicaManager.GetByCollection + mockGetByCollection := mockey.Mock(mockey.GetMethod(checker.meta.ReplicaManager, "GetByCollection")).Return(replicas).Build() + defer mockGetByCollection.UnPatch() + + // Mock replica methods - no RO nodes + mockRONodesCount := mockey.Mock((*meta.Replica).RONodesCount).Return(0).Build() + defer mockRONodesCount.UnPatch() + + mockROSQNodesCount := mockey.Mock((*meta.Replica).ROSQNodesCount).Return(0).Build() + defer mockROSQNodesCount.UnPatch() + + // Skip streaming service mock for simplicity + + result := checker.getReplicaForStoppingBalance(ctx, collectionID) + assert.Empty(t, result) +} + +func TestBalanceChecker_GetReplicaForNormalBalance(t *testing.T) { + checker := createTestBalanceChecker() + ctx := context.Background() + collectionID := int64(1) + + // Create mock replicas + replica1 := &meta.Replica{} + replica2 := &meta.Replica{} + replicas := []*meta.Replica{replica1, replica2} + + // Mock ReplicaManager.GetByCollection + mockGetByCollection := mockey.Mock(mockey.GetMethod(checker.meta.ReplicaManager, "GetByCollection")).Return(replicas).Build() + defer mockGetByCollection.UnPatch() + + // Mock replica GetID methods + mockGetID := mockey.Mock((*meta.Replica).GetID).Return(mockey.Sequence(101).Times(1).Then(102)).Build() + defer mockGetID.UnPatch() + + result := checker.getReplicaForNormalBalance(ctx, collectionID) + expectedIDs := []int64{101, 102} + assert.ElementsMatch(t, expectedIDs, result) +} + +// ============================================================================= +// Task Generation Tests +// ============================================================================= + +func TestBalanceChecker_GenerateBalanceTasksFromReplicas_EmptyReplicas(t *testing.T) { + checker := createTestBalanceChecker() + ctx := context.Background() + config := balanceConfig{} + + segmentTasks, channelTasks := checker.generateBalanceTasksFromReplicas(ctx, []int64{}, config) + + assert.Empty(t, segmentTasks) + assert.Empty(t, channelTasks) +} + +func TestBalanceChecker_GenerateBalanceTasksFromReplicas_Success(t *testing.T) { + checker := createTestBalanceChecker() + ctx := context.Background() + config := balanceConfig{ + segmentTaskTimeout: 30 * time.Second, + channelTaskTimeout: 30 * time.Second, + } + replicaIDs := []int64{101} + + // Create mock replica + mockReplica := &meta.Replica{} + + // Mock ReplicaManager.Get + mockReplicaGet := mockey.Mock(mockey.GetMethod(checker.meta.ReplicaManager, "Get")).Return(mockReplica).Build() + defer mockReplicaGet.UnPatch() + + // Create mock balance plans + segmentPlan := balance.SegmentAssignPlan{ + Segment: &meta.Segment{SegmentInfo: &datapb.SegmentInfo{ID: 1}}, + Replica: mockReplica, From: 1, To: 2, } - segPlans = append(segPlans, mockPlan) - suite.balancer.EXPECT().BalanceReplica(mock.Anything, mock.Anything).Return(segPlans, chanPlans) - - tasks := make([]task.Task, 0) - suite.scheduler.ExpectedCalls = nil - suite.scheduler.EXPECT().Add(mock.Anything).RunAndReturn(func(task task.Task) error { - tasks = append(tasks, task) - return nil - }) - suite.checker.Check(context.TODO()) - suite.Len(tasks, 2) -} - -func (suite *BalanceCheckerTestSuite) TestTargetNotReady() { - ctx := context.Background() - // set up nodes info, stopping node1 - nodeID1, nodeID2 := int64(1), int64(2) - suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ - NodeID: nodeID1, - Address: "localhost", - Hostname: "localhost", - })) - suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ - NodeID: nodeID2, - Address: "localhost", - Hostname: "localhost", - })) - suite.nodeMgr.Stopping(nodeID1) - suite.checker.meta.ResourceManager.HandleNodeUp(ctx, nodeID1) - suite.checker.meta.ResourceManager.HandleNodeUp(ctx, nodeID2) - - mockTarget := meta.NewMockTargetManager(suite.T()) - suite.checker.targetMgr = mockTarget - - // set collections meta - cid1, replicaID1, partitionID1 := 1, 1, 1 - collection1 := utils.CreateTestCollection(int64(cid1), int32(replicaID1)) - collection1.Status = querypb.LoadStatus_Loaded - replica1 := utils.CreateTestReplica(int64(replicaID1), int64(cid1), []int64{nodeID1, nodeID2}) - partition1 := utils.CreateTestPartition(int64(cid1), int64(partitionID1)) - suite.checker.meta.CollectionManager.PutCollection(ctx, collection1, partition1) - suite.checker.meta.ReplicaManager.Put(ctx, replica1) - - cid2, replicaID2, partitionID2 := 2, 2, 2 - collection2 := utils.CreateTestCollection(int64(cid2), int32(replicaID2)) - collection2.Status = querypb.LoadStatus_Loaded - replica2 := utils.CreateTestReplica(int64(replicaID2), int64(cid2), []int64{nodeID1, nodeID2}) - partition2 := utils.CreateTestPartition(int64(cid2), int64(partitionID2)) - suite.checker.meta.CollectionManager.PutCollection(ctx, collection2, partition2) - suite.checker.meta.ReplicaManager.Put(ctx, replica2) - - // test normal balance when one collection has unready target - mockTarget.EXPECT().IsNextTargetExist(mock.Anything, mock.Anything).Return(true) - mockTarget.EXPECT().IsCurrentTargetReady(mock.Anything, mock.Anything).Return(false) - mockTarget.EXPECT().GetCollectionRowCount(mock.Anything, mock.Anything, mock.Anything).Return(100).Maybe() - replicasToBalance := suite.checker.getReplicaForNormalBalance(ctx) - suite.Len(replicasToBalance, 0) - - // test stopping balance with target not ready - mockTarget.ExpectedCalls = nil - mockTarget.EXPECT().IsNextTargetExist(mock.Anything, mock.Anything).Return(false) - mockTarget.EXPECT().IsCurrentTargetExist(mock.Anything, int64(cid1), mock.Anything).Return(true).Maybe() - mockTarget.EXPECT().IsCurrentTargetExist(mock.Anything, int64(cid2), mock.Anything).Return(false).Maybe() - mockTarget.EXPECT().GetCollectionRowCount(mock.Anything, mock.Anything, mock.Anything).Return(100).Maybe() - mr1 := replica1.CopyForWrite() - mr1.AddRONode(1) - suite.checker.meta.ReplicaManager.Put(ctx, mr1.IntoReplica()) - - mr2 := replica2.CopyForWrite() - mr2.AddRONode(1) - suite.checker.meta.ReplicaManager.Put(ctx, mr2.IntoReplica()) - - idsToBalance := []int64{int64(replicaID1)} - replicasToBalance = suite.checker.getReplicaForStoppingBalance(ctx) - suite.ElementsMatch(idsToBalance, replicasToBalance) -} - -func (suite *BalanceCheckerTestSuite) TestAutoBalanceInterval() { - ctx := context.Background() - // set up nodes info - nodeID1, nodeID2 := 1, 2 - suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ - NodeID: int64(nodeID1), - Address: "localhost", - Hostname: "localhost", - })) - suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ - NodeID: int64(nodeID2), - Address: "localhost", - Hostname: "localhost", - })) - suite.checker.meta.ResourceManager.HandleNodeUp(ctx, int64(nodeID1)) - suite.checker.meta.ResourceManager.HandleNodeUp(ctx, int64(nodeID2)) - - segments := []*datapb.SegmentInfo{ - { - ID: 1, - PartitionID: 1, - InsertChannel: "test-insert-channel", - }, - { - ID: 2, - PartitionID: 1, - InsertChannel: "test-insert-channel", - }, + channelPlan := balance.ChannelAssignPlan{ + Channel: &meta.DmChannel{VchannelInfo: &datapb.VchannelInfo{ChannelName: "test"}}, + Replica: mockReplica, + From: 1, + To: 2, } - channels := []*datapb.VchannelInfo{ - { - CollectionID: 1, - ChannelName: "test-insert-channel", - }, + + mockBalancer := mockey.Mock(checker.getBalancerFunc).To(func() balance.Balance { + return balance.NewScoreBasedBalancer(nil, nil, nil, nil, nil) + }).Build() + defer mockBalancer.UnPatch() + + // Mock balancer.BalanceReplica + mockBalanceReplica := mockey.Mock((*balance.ScoreBasedBalancer).BalanceReplica).Return( + []balance.SegmentAssignPlan{segmentPlan}, + []balance.ChannelAssignPlan{channelPlan}, + ).Build() + defer mockBalanceReplica.UnPatch() + + // Mock balance.CreateSegmentTasksFromPlans + mockSegmentTask := &task.SegmentTask{} + mockCreateSegmentTasks := mockey.Mock(balance.CreateSegmentTasksFromPlans).Return([]task.Task{mockSegmentTask}).Build() + defer mockCreateSegmentTasks.UnPatch() + + // Mock balance.CreateChannelTasksFromPlans + mockChannelTask := &task.ChannelTask{} + mockCreateChannelTasks := mockey.Mock(balance.CreateChannelTasksFromPlans).Return([]task.Task{mockChannelTask}).Build() + defer mockCreateChannelTasks.UnPatch() + + // Mock task.SetPriority and task.SetReason + mockSetPriority := mockey.Mock(task.SetPriority).Return().Build() + defer mockSetPriority.UnPatch() + + mockSetReason := mockey.Mock(task.SetReason).Return().Build() + defer mockSetReason.UnPatch() + + // Mock balance.PrintNewBalancePlans + mockPrintPlans := mockey.Mock(balance.PrintNewBalancePlans).Return().Build() + defer mockPrintPlans.UnPatch() + + segmentTasks, channelTasks := checker.generateBalanceTasksFromReplicas(ctx, replicaIDs, config) + + assert.Len(t, segmentTasks, 1) + assert.Len(t, channelTasks, 1) + assert.Equal(t, mockSegmentTask, segmentTasks[0]) + assert.Equal(t, mockChannelTask, channelTasks[0]) +} + +func TestBalanceChecker_GenerateBalanceTasksFromReplicas_NilReplica(t *testing.T) { + checker := createTestBalanceChecker() + ctx := context.Background() + config := balanceConfig{} + replicaIDs := []int64{101} + + // Mock ReplicaManager.Get to return nil + mockReplicaGet := mockey.Mock(mockey.GetMethod(checker.meta.ReplicaManager, "Get")).Return(nil).Build() + defer mockReplicaGet.UnPatch() + + segmentTasks, channelTasks := checker.generateBalanceTasksFromReplicas(ctx, replicaIDs, config) + + assert.Empty(t, segmentTasks) + assert.Empty(t, channelTasks) +} + +// ============================================================================= +// Task Submission Tests +// ============================================================================= + +func TestBalanceChecker_SubmitTasks(t *testing.T) { + checker := createTestBalanceChecker() + + // Create mock tasks + segmentTask := &task.SegmentTask{} + channelTask := &task.ChannelTask{} + segmentTasks := []task.Task{segmentTask} + channelTasks := []task.Task{channelTask} + + // Mock scheduler.Add + mockSchedulerAdd := mockey.Mock(mockey.GetMethod(checker.scheduler, "Add")).Return(nil).Build() + defer mockSchedulerAdd.UnPatch() + + checker.submitTasks(segmentTasks, channelTasks) + + // Verify scheduler.Add was called for both tasks + // This is implicit verification through mockey call tracking +} + +func TestBalanceChecker_SubmitTasks_EmptyTasks(t *testing.T) { + checker := createTestBalanceChecker() + + // Mock scheduler.Add - should not be called + mockSchedulerAdd := mockey.Mock(mockey.GetMethod(checker.scheduler, "Add")).Return(nil).Build() + defer mockSchedulerAdd.UnPatch() + + checker.submitTasks([]task.Task{}, []task.Task{}) + + // No assertions needed - just ensuring no panic with empty tasks +} + +// ============================================================================= +// Main Check Method Tests +// ============================================================================= + +func TestBalanceChecker_Check_InactiveChecker(t *testing.T) { + checker := createTestBalanceChecker() + ctx := context.Background() + + // Mock IsActive to return false + mockIsActive := mockey.Mock((*checkerActivation).IsActive).Return(false).Build() + defer mockIsActive.UnPatch() + + result := checker.Check(ctx) + assert.Nil(t, result) +} + +func TestBalanceChecker_Check_StoppingBalanceEnabled(t *testing.T) { + checker := createTestBalanceChecker() + ctx := context.Background() + + // Mock IsActive to return true + mockIsActive := mockey.Mock((*checkerActivation).IsActive).Return(true).Build() + defer mockIsActive.UnPatch() + + // Mock paramtable for enabling stopping balance + mockParamGet := mockey.Mock(paramtable.Get).Return(¶mtable.ComponentParam{}).Build() + defer mockParamGet.UnPatch() + + mockStoppingBalanceEnabled := mockey.Mock(mockey.GetMethod(¶mtable.ParamItem{}, "GetAsBool")).Return(true).Build() + defer mockStoppingBalanceEnabled.UnPatch() + + // Mock loadBalanceConfig + config := balanceConfig{ + segmentBatchSize: 5, + channelBatchSize: 5, } - suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, mock.Anything).Return(channels, segments, nil) + mockLoadConfig := mockey.Mock((*BalanceChecker).loadBalanceConfig).Return(config).Build() + defer mockLoadConfig.UnPatch() - // set collections meta - cid1, replicaID1, partitionID1 := 1, 1, 1 - collection1 := utils.CreateTestCollection(int64(cid1), int32(replicaID1)) - collection1.Status = querypb.LoadStatus_Loaded - replica1 := utils.CreateTestReplica(int64(replicaID1), int64(cid1), []int64{int64(nodeID1), int64(nodeID2)}) - partition1 := utils.CreateTestPartition(int64(cid1), int64(partitionID1)) - suite.checker.meta.CollectionManager.PutCollection(ctx, collection1, partition1) - suite.checker.meta.ReplicaManager.Put(ctx, replica1) - suite.targetMgr.UpdateCollectionNextTarget(ctx, int64(cid1)) - suite.targetMgr.UpdateCollectionCurrentTarget(ctx, int64(cid1)) + // Mock processBalanceQueue to return tasks + mockProcessQueue := mockey.Mock((*BalanceChecker).processBalanceQueue).Return( + 1, 0, // segment tasks, channel tasks + ).Build() + defer mockProcessQueue.UnPatch() - funcCallCounter := atomic.NewInt64(0) - suite.balancer.EXPECT().BalanceReplica(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, r *meta.Replica) ([]balance.SegmentAssignPlan, []balance.ChannelAssignPlan) { - funcCallCounter.Inc() - return nil, nil - }) - - // first auto balance should be triggered - suite.checker.Check(ctx) - suite.Equal(funcCallCounter.Load(), int64(1)) - - // second auto balance won't be triggered due to autoBalanceInterval == 3s - suite.checker.Check(ctx) - suite.Equal(funcCallCounter.Load(), int64(1)) - - // set autoBalanceInterval == 1, sleep 1s, auto balance should be triggered - paramtable.Get().Save(paramtable.Get().QueryCoordCfg.AutoBalanceInterval.Key, "1000") - paramtable.Get().Reset(paramtable.Get().QueryCoordCfg.AutoBalanceInterval.Key) - time.Sleep(1 * time.Second) - suite.checker.Check(ctx) - suite.Equal(funcCallCounter.Load(), int64(1)) + result := checker.Check(ctx) + assert.Nil(t, result) // Always returns nil as tasks are submitted directly + assert.Nil(t, checker.normalBalanceQueue) } -func (suite *BalanceCheckerTestSuite) TestBalanceOrder() { +func TestBalanceChecker_Check_NormalBalanceEnabled(t *testing.T) { + checker := createTestBalanceChecker() ctx := context.Background() - nodeID1, nodeID2 := int64(1), int64(2) - // set collections meta - cid1, replicaID1, partitionID1 := 1, 1, 1 - collection1 := utils.CreateTestCollection(int64(cid1), int32(replicaID1)) - collection1.Status = querypb.LoadStatus_Loaded - replica1 := utils.CreateTestReplica(int64(replicaID1), int64(cid1), []int64{nodeID1, nodeID2}) - partition1 := utils.CreateTestPartition(int64(cid1), int64(partitionID1)) - suite.checker.meta.CollectionManager.PutCollection(ctx, collection1, partition1) - suite.checker.meta.ReplicaManager.Put(ctx, replica1) + // Set autoBalanceTs to allow normal balance + checker.autoBalanceTs = time.Time{} - cid2, replicaID2, partitionID2 := 2, 2, 2 - collection2 := utils.CreateTestCollection(int64(cid2), int32(replicaID2)) - collection2.Status = querypb.LoadStatus_Loaded - replica2 := utils.CreateTestReplica(int64(replicaID2), int64(cid2), []int64{nodeID1, nodeID2}) - partition2 := utils.CreateTestPartition(int64(cid2), int64(partitionID2)) - suite.checker.meta.CollectionManager.PutCollection(ctx, collection2, partition2) - suite.checker.meta.ReplicaManager.Put(ctx, replica2) + // Mock IsActive to return true + mockIsActive := mockey.Mock((*checkerActivation).IsActive).Return(true).Build() + defer mockIsActive.UnPatch() - // mock collection row count - mockTargetManager := meta.NewMockTargetManager(suite.T()) - suite.checker.targetMgr = mockTargetManager - mockTargetManager.EXPECT().GetCollectionRowCount(mock.Anything, int64(cid1), mock.Anything).Return(int64(100)).Maybe() - mockTargetManager.EXPECT().GetCollectionRowCount(mock.Anything, int64(cid2), mock.Anything).Return(int64(200)).Maybe() - mockTargetManager.EXPECT().IsCurrentTargetReady(mock.Anything, mock.Anything).Return(true).Maybe() - mockTargetManager.EXPECT().IsNextTargetExist(mock.Anything, mock.Anything).Return(true).Maybe() + // Mock paramtable - stopping balance disabled, auto balance enabled + mockParamGet := mockey.Mock(paramtable.Get).Return(¶mtable.ComponentParam{}).Build() + defer mockParamGet.UnPatch() - // mock stopping node - mr1 := replica1.CopyForWrite() - mr1.AddRONode(nodeID1) - suite.checker.meta.ReplicaManager.Put(ctx, mr1.IntoReplica()) - mr2 := replica2.CopyForWrite() - mr2.AddRONode(nodeID2) - suite.checker.meta.ReplicaManager.Put(ctx, mr2.IntoReplica()) + // return false for stopping balance enabled, true for auto balance enabled + mockParams := mockey.Mock(mockey.GetMethod(¶mtable.ParamItem{}, "GetAsBool")).Return(mockey.Sequence(false).Times(1).Then(true)).Build() + defer mockParams.UnPatch() - // test normal balance order - replicas := suite.checker.getReplicaForNormalBalance(ctx) - suite.Equal(replicas, []int64{int64(replicaID2)}) + // Mock loadBalanceConfig + config := balanceConfig{ + segmentBatchSize: 5, + channelBatchSize: 5, + autoBalanceInterval: 1 * time.Second, + } + mockLoadConfig := mockey.Mock((*BalanceChecker).loadBalanceConfig).Return(config).Build() + defer mockLoadConfig.UnPatch() - // test stopping balance order - replicas = suite.checker.getReplicaForStoppingBalance(ctx) - suite.Equal(replicas, []int64{int64(replicaID2)}) + // Mock processBalanceQueue to return tasks + mockProcessQueue := mockey.Mock((*BalanceChecker).processBalanceQueue).Return( + 0, 1, // segment tasks, channel tasks + ).Build() + defer mockProcessQueue.UnPatch() - // mock collection row count - mockTargetManager.ExpectedCalls = nil - mockTargetManager.EXPECT().GetCollectionRowCount(mock.Anything, int64(cid1), mock.Anything).Return(int64(200)).Maybe() - mockTargetManager.EXPECT().GetCollectionRowCount(mock.Anything, int64(cid2), mock.Anything).Return(int64(100)).Maybe() - mockTargetManager.EXPECT().IsCurrentTargetReady(mock.Anything, mock.Anything).Return(true).Maybe() - mockTargetManager.EXPECT().IsNextTargetExist(mock.Anything, mock.Anything).Return(true).Maybe() - - // test normal balance order - replicas = suite.checker.getReplicaForNormalBalance(ctx) - suite.Equal(replicas, []int64{int64(replicaID1)}) - - // test stopping balance order - replicas = suite.checker.getReplicaForStoppingBalance(ctx) - suite.Equal(replicas, []int64{int64(replicaID1)}) + result := checker.Check(ctx) + assert.Nil(t, result) // Always returns nil as tasks are submitted directly } -func (suite *BalanceCheckerTestSuite) TestSortCollections() { +// ============================================================================= +// ProcessBalanceQueue Tests +// ============================================================================= + +func TestBalanceChecker_ProcessBalanceQueue_Success(t *testing.T) { + checker := createTestBalanceChecker() ctx := context.Background() - // Set up test collections - cid1, cid2, cid3 := int64(1), int64(2), int64(3) + // Create mock balance config + config := balanceConfig{ + segmentBatchSize: 5, + channelBatchSize: 3, + maxCheckCollectionCount: 5, + balanceOnMultipleCollections: true, + } - // Mock the target manager for row count returns - mockTargetManager := meta.NewMockTargetManager(suite.T()) - suite.checker.targetMgr = mockTargetManager + // Create mock priority queue + mockQueue := createMockPriorityQueue() + // Use real priority queue for simplicity - // Collection 1: Low ID, High row count - mockTargetManager.EXPECT().GetCollectionRowCount(mock.Anything, cid1, mock.Anything).Return(int64(300)).Maybe() + mockQueue.Push(newCollectionBalanceItem(1, 100, "byrowcount")) + mockQueue.Push(newCollectionBalanceItem(2, 100, "byrowcount")) + mockQueue.Push(newCollectionBalanceItem(3, 100, "byrowcount")) + mockQueue.Push(newCollectionBalanceItem(4, 100, "byrowcount")) + mockQueue.Push(newCollectionBalanceItem(5, 100, "byrowcount")) - // Collection 2: Middle ID, Low row count - mockTargetManager.EXPECT().GetCollectionRowCount(mock.Anything, cid2, mock.Anything).Return(int64(100)).Maybe() + // Mock getReplicasFunc + getReplicasFunc := func(ctx context.Context, collectionID int64) []int64 { + return []int64{101, 102} + } - // Collection 3: High ID, Middle row count - mockTargetManager.EXPECT().GetCollectionRowCount(mock.Anything, cid3, mock.Anything).Return(int64(200)).Maybe() + // Mock constructQueueFunc + constructQueueFunc := func(ctx context.Context) *balance.PriorityQueue { + return mockQueue + } - collections := []int64{cid1, cid2, cid3} + // Mock getQueueFunc + getQueueFunc := func() *balance.PriorityQueue { + return mockQueue + } - // Test ByRowCount sorting (default) - paramtable.Get().Save(Params.QueryCoordCfg.BalanceTriggerOrder.Key, "ByRowCount") - sortedCollections := suite.checker.sortCollections(ctx, collections) - suite.Equal([]int64{cid1, cid3, cid2}, sortedCollections, "Collections should be sorted by row count (highest first)") + // Mock generateBalanceTasksFromReplicas + mockSegmentTask := &task.SegmentTask{} + mockChannelTask := &task.ChannelTask{} + mockGenerateTasks := mockey.Mock((*BalanceChecker).generateBalanceTasksFromReplicas).Return( + []task.Task{mockSegmentTask}, []task.Task{mockChannelTask}, + ).Build() + defer mockGenerateTasks.UnPatch() - // Test ByCollectionID sorting - paramtable.Get().Save(Params.QueryCoordCfg.BalanceTriggerOrder.Key, "ByCollectionID") - sortedCollections = suite.checker.sortCollections(ctx, collections) - suite.Equal([]int64{cid1, cid2, cid3}, sortedCollections, "Collections should be sorted by collection ID (ascending)") + // mock submit tasks + mockSubmitTasks := mockey.Mock((*BalanceChecker).submitTasks).Return().Build() + defer mockSubmitTasks.UnPatch() - // Test with empty sort order (should default to ByRowCount) - paramtable.Get().Save(Params.QueryCoordCfg.BalanceTriggerOrder.Key, "") - sortedCollections = suite.checker.sortCollections(ctx, collections) - suite.Equal([]int64{cid1, cid3, cid2}, sortedCollections, "Should default to ByRowCount when sort order is empty") + segmentTasks, channelTasks := checker.processBalanceQueue( + ctx, getReplicasFunc, constructQueueFunc, getQueueFunc, config, + ) - // Test with invalid sort order (should default to ByRowCount) - paramtable.Get().Save(Params.QueryCoordCfg.BalanceTriggerOrder.Key, "InvalidOrder") - sortedCollections = suite.checker.sortCollections(ctx, collections) - suite.Equal([]int64{cid1, cid3, cid2}, sortedCollections, "Should default to ByRowCount when sort order is invalid") + assert.Equal(t, 3, segmentTasks) + assert.Equal(t, 3, channelTasks) +} - // Test with mixed case (should be case-insensitive) - paramtable.Get().Save(Params.QueryCoordCfg.BalanceTriggerOrder.Key, "bYcOlLeCtIoNiD") - sortedCollections = suite.checker.sortCollections(ctx, collections) - suite.Equal([]int64{cid1, cid2, cid3}, sortedCollections, "Should handle case-insensitive sort order names") +func TestBalanceChecker_ProcessBalanceQueue_EmptyQueue(t *testing.T) { + checker := createTestBalanceChecker() + ctx := context.Background() + + config := balanceConfig{ + segmentBatchSize: 5, + channelBatchSize: 3, + } + + // Create empty mock priority queue + mockQueue := createMockPriorityQueue() + // Use real priority queue for empty queue testing + + getReplicasFunc := func(ctx context.Context, collectionID int64) []int64 { + return []int64{101} + } + + constructQueueFunc := func(ctx context.Context) *balance.PriorityQueue { + return mockQueue + } + + getQueueFunc := func() *balance.PriorityQueue { + return mockQueue + } + + segmentTasks, channelTasks := checker.processBalanceQueue( + ctx, getReplicasFunc, constructQueueFunc, getQueueFunc, config, + ) + + assert.Equal(t, 0, segmentTasks) + assert.Equal(t, 0, channelTasks) +} + +func TestBalanceChecker_ProcessBalanceQueue_BatchSizeLimit(t *testing.T) { + checker := createTestBalanceChecker() + ctx := context.Background() + + // Set small batch sizes to test limits + config := balanceConfig{ + segmentBatchSize: 1, // Only allow 1 segment task + channelBatchSize: 1, // Only allow 1 channel task + maxCheckCollectionCount: 10, + balanceOnMultipleCollections: true, + } + + // Test batch size limits with simplified logic + + // Create mock priority queue with 2 items + mockQueue := createMockPriorityQueue() + // Use real priority queue for batch size testing + + mockQueue.Push(newCollectionBalanceItem(1, 100, "byrowcount")) + mockQueue.Push(newCollectionBalanceItem(2, 100, "byrowcount")) + + getReplicasFunc := func(ctx context.Context, collectionID int64) []int64 { + return []int64{101} + } + + constructQueueFunc := func(ctx context.Context) *balance.PriorityQueue { + return mockQueue + } + + getQueueFunc := func() *balance.PriorityQueue { + return mockQueue + } + + // Mock generateBalanceTasksFromReplicas to return multiple tasks + mockSegmentTask1 := &task.SegmentTask{} + mockSegmentTask2 := &task.SegmentTask{} + mockChannelTask1 := &task.ChannelTask{} + mockChannelTask2 := &task.ChannelTask{} + + mockGenerateTasks := mockey.Mock((*BalanceChecker).generateBalanceTasksFromReplicas).Return(mockey.Sequence( + []task.Task{mockSegmentTask1}, []task.Task{mockChannelTask1}, + ).Times(1).Then( + []task.Task{mockSegmentTask2}, []task.Task{mockChannelTask2}, + )).Build() + defer mockGenerateTasks.UnPatch() + + // mock submit tasks + mockSubmitTasks := mockey.Mock((*BalanceChecker).submitTasks).Return().Build() + defer mockSubmitTasks.UnPatch() + + segmentTasks, channelTasks := checker.processBalanceQueue( + ctx, getReplicasFunc, constructQueueFunc, getQueueFunc, config, + ) + + // Should stop after first collection due to batch size limits + assert.Equal(t, 1, segmentTasks) + assert.Equal(t, 1, channelTasks) +} + +func TestBalanceChecker_ProcessBalanceQueue_MultiCollectionDisabled(t *testing.T) { + checker := createTestBalanceChecker() + ctx := context.Background() + + config := balanceConfig{ + segmentBatchSize: 10, + channelBatchSize: 10, + maxCheckCollectionCount: 10, + balanceOnMultipleCollections: false, // Disabled + } + + mockQueue := createMockPriorityQueue() + // Use real priority queue for multi-collection testing + + getReplicasFunc := func(ctx context.Context, collectionID int64) []int64 { + return []int64{101} + } + + constructQueueFunc := func(ctx context.Context) *balance.PriorityQueue { + return mockQueue + } + + getQueueFunc := func() *balance.PriorityQueue { + return mockQueue + } + + mockQueue.Push(newCollectionBalanceItem(1, 100, "byrowcount")) + + // Mock generateBalanceTasksFromReplicas to return tasks + mockSegmentTask := &task.SegmentTask{} + mockGenerateTasks := mockey.Mock((*BalanceChecker).generateBalanceTasksFromReplicas).Return( + []task.Task{mockSegmentTask}, []task.Task{}, + ).Build() + defer mockGenerateTasks.UnPatch() + + // mock submit tasks + mockSubmitTasks := mockey.Mock((*BalanceChecker).submitTasks).Return().Build() + defer mockSubmitTasks.UnPatch() + + segmentTasks, channelTasks := checker.processBalanceQueue( + ctx, getReplicasFunc, constructQueueFunc, getQueueFunc, config, + ) + + // Should stop after first collection due to multi-collection disabled + assert.Equal(t, 1, segmentTasks) + assert.Equal(t, 0, channelTasks) +} + +func TestBalanceChecker_ProcessBalanceQueue_NoReplicasToBalance(t *testing.T) { + checker := createTestBalanceChecker() + ctx := context.Background() + + config := balanceConfig{ + segmentBatchSize: 5, + channelBatchSize: 5, + maxCheckCollectionCount: 5, + balanceOnMultipleCollections: true, + } + + mockQueue := createMockPriorityQueue() + // Use real priority queue for simplicity + + // getReplicasFunc returns empty slice + getReplicasFunc := func(ctx context.Context, collectionID int64) []int64 { + return []int64{} // No replicas + } + + constructQueueFunc := func(ctx context.Context) *balance.PriorityQueue { + return mockQueue + } + + getQueueFunc := func() *balance.PriorityQueue { + return mockQueue + } + + segmentTasks, channelTasks := checker.processBalanceQueue( + ctx, getReplicasFunc, constructQueueFunc, getQueueFunc, config, + ) + + assert.Equal(t, 0, segmentTasks) + assert.Equal(t, 0, channelTasks) +} + +// ============================================================================= +// Performance and Edge Case Tests +// ============================================================================= + +func TestBalanceChecker_CollectionBalanceItem_EdgeCases(t *testing.T) { + // Test with zero row count + item := newCollectionBalanceItem(1, 0, "byrowcount") + assert.Equal(t, 0, item.getPriority()) + + // Test with negative collection ID + item = newCollectionBalanceItem(-1, 100, "bycollectionid") + assert.Equal(t, -1, item.getPriority()) + + // Test with very large values + item = newCollectionBalanceItem(9223372036854775807, 2147483647, "byrowcount") + assert.Equal(t, -2147483647, item.getPriority()) + + // Test with empty sort order (should default to byrowcount) + item = newCollectionBalanceItem(5, 100, "") + assert.Equal(t, -100, item.getPriority()) + + // Test with invalid sort order (should default to byrowcount) + item = newCollectionBalanceItem(5, 100, "invalid") + assert.Equal(t, -100, item.getPriority()) +} + +func TestBalanceChecker_FilterCollectionForBalance_EdgeCases(t *testing.T) { + checker := createTestBalanceChecker() + ctx := context.Background() // Test with empty collection list - emptyCollections := []int64{} - sortedCollections = suite.checker.sortCollections(ctx, emptyCollections) - suite.Equal([]int64{}, sortedCollections, "Should handle empty collection list") -} + mockGetAll := mockey.Mock((*meta.CollectionManager).GetAll).Return([]int64{}).Build() + defer mockGetAll.UnPatch() -func (suite *BalanceCheckerTestSuite) TestSortCollectionsIntegration() { - ctx := context.Background() - - // Set up test collections and nodes - nodeID1 := int64(1) - suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ - NodeID: nodeID1, - Address: "localhost", - Hostname: "localhost", - })) - suite.checker.meta.ResourceManager.HandleNodeUp(ctx, nodeID1) - - // Create two collections to ensure sorting is triggered - cid1, replicaID1 := int64(1), int64(101) - collection1 := utils.CreateTestCollection(cid1, int32(replicaID1)) - collection1.Status = querypb.LoadStatus_Loaded - replica1 := utils.CreateTestReplica(replicaID1, cid1, []int64{nodeID1}) - suite.checker.meta.CollectionManager.PutCollection(ctx, collection1) - suite.checker.meta.ReplicaManager.Put(ctx, replica1) - - // Add a second collection with different characteristics - cid2, replicaID2 := int64(2), int64(102) - collection2 := utils.CreateTestCollection(cid2, int32(replicaID2)) - collection2.Status = querypb.LoadStatus_Loaded - replica2 := utils.CreateTestReplica(replicaID2, cid2, []int64{nodeID1}) - suite.checker.meta.CollectionManager.PutCollection(ctx, collection2) - suite.checker.meta.ReplicaManager.Put(ctx, replica2) - - // Mock target manager - mockTargetManager := meta.NewMockTargetManager(suite.T()) - suite.checker.targetMgr = mockTargetManager - - // Setup different row counts to test sorting - // Collection 1 has more rows than Collection 2 - var getRowCountCallCount int - mockTargetManager.On("GetCollectionRowCount", mock.Anything, mock.Anything, mock.Anything). - Return(func(ctx context.Context, collectionID int64, scope int32) int64 { - getRowCountCallCount++ - if collectionID == cid1 { - return 200 // More rows in collection 1 - } - return 100 // Fewer rows in collection 2 - }) - - mockTargetManager.On("IsCurrentTargetReady", mock.Anything, mock.Anything).Return(true) - mockTargetManager.On("IsNextTargetExist", mock.Anything, mock.Anything).Return(true) - - // Configure for testing - paramtable.Get().Save(Params.QueryCoordCfg.AutoBalance.Key, "true") - paramtable.Get().Save(Params.QueryCoordCfg.BalanceTriggerOrder.Key, "ByRowCount") - - // Clear first to avoid previous test state - suite.checker.normalBalanceCollectionsCurrentRound.Clear() - - // Call normal balance - _ = suite.checker.getReplicaForNormalBalance(ctx) - - // Verify GetCollectionRowCount was called at least twice (once for each collection) - // This confirms that the collections were sorted - suite.True(getRowCountCallCount >= 2, "GetCollectionRowCount should be called at least twice during normal balance") - - // Reset counter and test stopping balance - getRowCountCallCount = 0 - - // Set up for stopping balance test - mr1 := replica1.CopyForWrite() - mr1.AddRONode(nodeID1) - suite.checker.meta.ReplicaManager.Put(ctx, mr1.IntoReplica()) - - mr2 := replica2.CopyForWrite() - mr2.AddRONode(nodeID1) - suite.checker.meta.ReplicaManager.Put(ctx, mr2.IntoReplica()) - - paramtable.Get().Save(Params.QueryCoordCfg.EnableStoppingBalance.Key, "true") - - // Call stopping balance - _ = suite.checker.getReplicaForStoppingBalance(ctx) - - // Verify GetCollectionRowCount was called at least twice during stopping balance - suite.True(getRowCountCallCount >= 2, "GetCollectionRowCount should be called at least twice during stopping balance") -} - -func (suite *BalanceCheckerTestSuite) TestBalanceTriggerOrder() { - ctx := context.Background() - - // Set up nodes - nodeID1, nodeID2 := int64(1), int64(2) - suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ - NodeID: nodeID1, - Address: "localhost", - Hostname: "localhost", - })) - suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ - NodeID: nodeID2, - Address: "localhost", - Hostname: "localhost", - })) - suite.checker.meta.ResourceManager.HandleNodeUp(ctx, nodeID1) - suite.checker.meta.ResourceManager.HandleNodeUp(ctx, nodeID2) - - // Create collections with different row counts - cid1, replicaID1 := int64(1), int64(101) - collection1 := utils.CreateTestCollection(cid1, int32(replicaID1)) - collection1.Status = querypb.LoadStatus_Loaded - replica1 := utils.CreateTestReplica(replicaID1, cid1, []int64{nodeID1, nodeID2}) - suite.checker.meta.CollectionManager.PutCollection(ctx, collection1) - suite.checker.meta.ReplicaManager.Put(ctx, replica1) - - cid2, replicaID2 := int64(2), int64(102) - collection2 := utils.CreateTestCollection(cid2, int32(replicaID2)) - collection2.Status = querypb.LoadStatus_Loaded - replica2 := utils.CreateTestReplica(replicaID2, cid2, []int64{nodeID1, nodeID2}) - suite.checker.meta.CollectionManager.PutCollection(ctx, collection2) - suite.checker.meta.ReplicaManager.Put(ctx, replica2) - - cid3, replicaID3 := int64(3), int64(103) - collection3 := utils.CreateTestCollection(cid3, int32(replicaID3)) - collection3.Status = querypb.LoadStatus_Loaded - replica3 := utils.CreateTestReplica(replicaID3, cid3, []int64{nodeID1, nodeID2}) - suite.checker.meta.CollectionManager.PutCollection(ctx, collection3) - suite.checker.meta.ReplicaManager.Put(ctx, replica3) - - // Mock the target manager - mockTargetManager := meta.NewMockTargetManager(suite.T()) - suite.checker.targetMgr = mockTargetManager - - // Set row counts: Collection 1 (highest), Collection 3 (middle), Collection 2 (lowest) - mockTargetManager.EXPECT().GetCollectionRowCount(mock.Anything, cid1, mock.Anything).Return(int64(300)).Maybe() - mockTargetManager.EXPECT().GetCollectionRowCount(mock.Anything, cid2, mock.Anything).Return(int64(100)).Maybe() - mockTargetManager.EXPECT().GetCollectionRowCount(mock.Anything, cid3, mock.Anything).Return(int64(200)).Maybe() - - // Mark the current target as ready - mockTargetManager.EXPECT().IsCurrentTargetReady(mock.Anything, mock.Anything).Return(true).Maybe() - mockTargetManager.EXPECT().IsNextTargetExist(mock.Anything, mock.Anything).Return(true).Maybe() - mockTargetManager.EXPECT().IsCurrentTargetExist(mock.Anything, mock.Anything, mock.Anything).Return(true).Maybe() - - // Enable auto balance - paramtable.Get().Save(Params.QueryCoordCfg.AutoBalance.Key, "true") - - // Test with ByRowCount order (default) - paramtable.Get().Save(Params.QueryCoordCfg.BalanceTriggerOrder.Key, "ByRowCount") - suite.checker.normalBalanceCollectionsCurrentRound.Clear() - - // Normal balance should pick the collection with highest row count first - replicas := suite.checker.getReplicaForNormalBalance(ctx) - suite.Contains(replicas, replicaID1, "Should balance collection with highest row count first") - - // Add stopping nodes to test stopping balance - mr1 := replica1.CopyForWrite() - mr1.AddRONode(nodeID1) - suite.checker.meta.ReplicaManager.Put(ctx, mr1.IntoReplica()) - - mr2 := replica2.CopyForWrite() - mr2.AddRONode(nodeID1) - suite.checker.meta.ReplicaManager.Put(ctx, mr2.IntoReplica()) - - mr3 := replica3.CopyForWrite() - mr3.AddRONode(nodeID1) - suite.checker.meta.ReplicaManager.Put(ctx, mr3.IntoReplica()) - - // Enable stopping balance - paramtable.Get().Save(Params.QueryCoordCfg.EnableStoppingBalance.Key, "true") - - // Stopping balance should also pick the collection with highest row count first - replicas = suite.checker.getReplicaForStoppingBalance(ctx) - suite.Contains(replicas, replicaID1, "Stopping balance should prioritize collection with highest row count") - - // Test with ByCollectionID order - paramtable.Get().Save(Params.QueryCoordCfg.BalanceTriggerOrder.Key, "ByCollectionID") - suite.checker.normalBalanceCollectionsCurrentRound.Clear() - - // Normal balance should pick the collection with lowest ID first - replicas = suite.checker.getReplicaForNormalBalance(ctx) - suite.Contains(replicas, replicaID1, "Should balance collection with lowest ID first") - - suite.checker.stoppingBalanceCollectionsCurrentRound.Clear() - // Stopping balance should also pick the collection with lowest ID first - replicas = suite.checker.getReplicaForStoppingBalance(ctx) - suite.Contains(replicas, replicaID1, "Stopping balance should prioritize collection with lowest ID") -} - -func (suite *BalanceCheckerTestSuite) TestHasUnbalancedCollectionFlag() { - ctx := context.Background() - - // Set up nodes - nodeID1, nodeID2 := int64(1), int64(2) - suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ - NodeID: nodeID1, - Address: "localhost", - Hostname: "localhost", - })) - suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ - NodeID: nodeID2, - Address: "localhost", - Hostname: "localhost", - })) - suite.checker.meta.ResourceManager.HandleNodeUp(ctx, nodeID1) - suite.checker.meta.ResourceManager.HandleNodeUp(ctx, nodeID2) - - // Create collection - cid1, replicaID1 := int64(1), int64(101) - collection1 := utils.CreateTestCollection(cid1, int32(replicaID1)) - collection1.Status = querypb.LoadStatus_Loaded - replica1 := utils.CreateTestReplica(replicaID1, cid1, []int64{nodeID1, nodeID2}) - suite.checker.meta.CollectionManager.PutCollection(ctx, collection1) - suite.checker.meta.ReplicaManager.Put(ctx, replica1) - - // Mock the target manager - mockTargetManager := meta.NewMockTargetManager(suite.T()) - suite.checker.targetMgr = mockTargetManager - mockTargetManager.EXPECT().GetCollectionRowCount(mock.Anything, mock.Anything, mock.Anything).Return(int64(100)).Maybe() - - // 1. Test normal balance with auto balance disabled - paramtable.Get().Save(Params.QueryCoordCfg.AutoBalance.Key, "false") - - // The collections set should be initially empty - suite.checker.normalBalanceCollectionsCurrentRound.Clear() - suite.Equal(0, suite.checker.normalBalanceCollectionsCurrentRound.Len()) - - // Get replicas - should return nil and keep the set empty - replicas := suite.checker.getReplicaForNormalBalance(ctx) - suite.Empty(replicas) - suite.Equal(0, suite.checker.normalBalanceCollectionsCurrentRound.Len(), - "normalBalanceCollectionsCurrentRound should remain empty when auto balance is disabled") - - // 2. Test normal balance when targetMgr.IsCurrentTargetReady returns false - paramtable.Get().Save(Params.QueryCoordCfg.AutoBalance.Key, "true") - mockTargetManager.EXPECT().IsNextTargetExist(mock.Anything, mock.Anything).Return(true).Maybe() - mockTargetManager.EXPECT().IsCurrentTargetReady(mock.Anything, mock.Anything).Return(false).Maybe() - - // The collections set should be initially empty - suite.checker.normalBalanceCollectionsCurrentRound.Clear() - suite.Equal(0, suite.checker.normalBalanceCollectionsCurrentRound.Len()) - - // Get replicas - should return nil and keep the set empty because of not ready targets - replicas = suite.checker.getReplicaForNormalBalance(ctx) - suite.Empty(replicas) - suite.Equal(0, suite.checker.normalBalanceCollectionsCurrentRound.Len(), - "normalBalanceCollectionsCurrentRound should remain empty when targets are not ready") - - // 3. Test stopping balance when there are no RO nodes - paramtable.Get().Save(Params.QueryCoordCfg.EnableStoppingBalance.Key, "true") - mockTargetManager.EXPECT().IsNextTargetExist(mock.Anything, mock.Anything).Return(true).Maybe() - mockTargetManager.EXPECT().IsCurrentTargetExist(mock.Anything, mock.Anything, mock.Anything).Return(true).Maybe() - - // The collections set should be initially empty - suite.checker.stoppingBalanceCollectionsCurrentRound.Clear() - suite.Equal(0, suite.checker.stoppingBalanceCollectionsCurrentRound.Len()) - - // Get replicas - should return nil and keep the set empty because there are no RO nodes - replicas = suite.checker.getReplicaForStoppingBalance(ctx) - suite.Empty(replicas) - suite.Equal(0, suite.checker.stoppingBalanceCollectionsCurrentRound.Len(), - "stoppingBalanceCollectionsCurrentRound should remain empty when there are no RO nodes") - - // 4. Test stopping balance with RO nodes - // Add a RO node to the replica - mr1 := replica1.CopyForWrite() - mr1.AddRONode(nodeID1) - suite.checker.meta.ReplicaManager.Put(ctx, mr1.IntoReplica()) - - // The collections set should be initially empty - suite.checker.stoppingBalanceCollectionsCurrentRound.Clear() - suite.Equal(0, suite.checker.stoppingBalanceCollectionsCurrentRound.Len()) - - // Get replicas - should return the replica ID and add the collection to the set - replicas = suite.checker.getReplicaForStoppingBalance(ctx) - suite.Equal([]int64{replicaID1}, replicas) - suite.Equal(1, suite.checker.stoppingBalanceCollectionsCurrentRound.Len()) - suite.True(suite.checker.stoppingBalanceCollectionsCurrentRound.Contain(cid1), - "stoppingBalanceCollectionsCurrentRound should contain the collection when it has RO nodes") -} - -func (suite *BalanceCheckerTestSuite) TestCheckBatchSizesAndMultiCollection() { - ctx := context.Background() - - // Set up nodes - nodeID1, nodeID2 := int64(1), int64(2) - suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ - NodeID: nodeID1, - Address: "localhost", - Hostname: "localhost", - })) - suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ - NodeID: nodeID2, - Address: "localhost", - Hostname: "localhost", - })) - suite.checker.meta.ResourceManager.HandleNodeUp(ctx, nodeID1) - suite.checker.meta.ResourceManager.HandleNodeUp(ctx, nodeID2) - - // Create 3 collections - for i := 1; i <= 3; i++ { - cid := int64(i) - replicaID := int64(100 + i) - - collection := utils.CreateTestCollection(cid, int32(replicaID)) - collection.Status = querypb.LoadStatus_Loaded - replica := utils.CreateTestReplica(replicaID, cid, []int64{}) - mutableReplica := replica.CopyForWrite() - mutableReplica.AddRWNode(nodeID1) - mutableReplica.AddRONode(nodeID2) - - suite.checker.meta.CollectionManager.PutCollection(ctx, collection) - suite.checker.meta.ReplicaManager.Put(ctx, mutableReplica.IntoReplica()) + passAllFilter := func(ctx context.Context, collectionID int64) bool { + return true } - // Mock target manager - mockTargetManager := meta.NewMockTargetManager(suite.T()) - suite.checker.targetMgr = mockTargetManager + result := checker.filterCollectionForBalance(ctx, passAllFilter) + assert.Empty(t, result) - // All collections have same row count for simplicity - mockTargetManager.EXPECT().GetCollectionRowCount(mock.Anything, mock.Anything, mock.Anything).Return(int64(100)).Maybe() - mockTargetManager.EXPECT().IsCurrentTargetReady(mock.Anything, mock.Anything).Return(true).Maybe() - mockTargetManager.EXPECT().IsNextTargetExist(mock.Anything, mock.Anything).Return(true).Maybe() - mockTargetManager.EXPECT().IsCurrentTargetExist(mock.Anything, mock.Anything, mock.Anything).Return(true).Maybe() + // Test with no filters + collectionIDs := []int64{1, 2, 3} + mockGetAll.UnPatch() + mockGetAll = mockey.Mock((*meta.CollectionManager).GetAll).Return(collectionIDs).Build() + defer mockGetAll.UnPatch() - // For each collection, return different segment plans - suite.balancer.EXPECT().BalanceReplica(mock.Anything, mock.AnythingOfType("*meta.Replica")).RunAndReturn( - func(ctx context.Context, replica *meta.Replica) ([]balance.SegmentAssignPlan, []balance.ChannelAssignPlan) { - // Create 2 segment plans and 1 channel plan per replica - collID := replica.GetCollectionID() - segPlans := make([]balance.SegmentAssignPlan, 0) - chanPlans := make([]balance.ChannelAssignPlan, 0) - - // Create 2 segment plans - for j := 1; j <= 2; j++ { - segID := collID*100 + int64(j) - segPlan := balance.SegmentAssignPlan{ - Segment: utils.CreateTestSegment(segID, collID, 1, 1, 1, "test-channel"), - Replica: replica, - From: nodeID1, - To: nodeID2, - } - segPlans = append(segPlans, segPlan) - } - - // Create 1 channel plan - chanPlan := balance.ChannelAssignPlan{ - Channel: &meta.DmChannel{ - VchannelInfo: &datapb.VchannelInfo{ - CollectionID: collID, - ChannelName: "test-channel", - }, - }, - Replica: replica, - From: nodeID1, - To: nodeID2, - } - chanPlans = append(chanPlans, chanPlan) - - return segPlans, chanPlans - }).Maybe() - - // Add tasks to check batch size limits - var addedTasks []task.Task - suite.scheduler.ExpectedCalls = nil - suite.scheduler.EXPECT().Add(mock.Anything).RunAndReturn(func(t task.Task) error { - addedTasks = append(addedTasks, t) - return nil - }).Maybe() - - // Test 1: Balance with multiple collections disabled - paramtable.Get().Save(Params.QueryCoordCfg.AutoBalance.Key, "true") - paramtable.Get().Save(Params.QueryCoordCfg.EnableBalanceOnMultipleCollections.Key, "false") - // Set batch sizes to large values to test single-collection case - paramtable.Get().Save(Params.QueryCoordCfg.BalanceSegmentBatchSize.Key, "10") - paramtable.Get().Save(Params.QueryCoordCfg.BalanceChannelBatchSize.Key, "10") - - // Reset test state - suite.checker.stoppingBalanceCollectionsCurrentRound.Clear() - suite.checker.autoBalanceTs = time.Time{} // Reset to trigger auto balance - addedTasks = nil - - // Run the Check method - suite.checker.Check(ctx) - - // Should have tasks for a single collection (2 segment tasks + 1 channel task) - suite.Equal(3, len(addedTasks), "Should have tasks for a single collection when multiple collections balance is disabled") - - // Test 2: Balance with multiple collections enabled - paramtable.Get().Save(Params.QueryCoordCfg.EnableBalanceOnMultipleCollections.Key, "true") - - // Reset test state - suite.checker.autoBalanceTs = time.Time{} - suite.checker.stoppingBalanceCollectionsCurrentRound.Clear() - addedTasks = nil - - // Run the Check method - suite.checker.Check(ctx) - - // Should have tasks for all collections (3 collections * (2 segment tasks + 1 channel task) = 9 tasks) - suite.Equal(9, len(addedTasks), "Should have tasks for all collections when multiple collections balance is enabled") - - // Test 3: Batch size limits - paramtable.Get().Save(Params.QueryCoordCfg.BalanceSegmentBatchSize.Key, "2") - paramtable.Get().Save(Params.QueryCoordCfg.BalanceChannelBatchSize.Key, "1") - - // Reset test state - suite.checker.stoppingBalanceCollectionsCurrentRound.Clear() - addedTasks = nil - - // Run the Check method - suite.checker.Check(ctx) - - // Should respect batch size limits: 2 segment tasks + 1 channel task = 3 tasks - suite.Equal(3, len(addedTasks), "Should respect batch size limits") - - // Count segment tasks and channel tasks - segmentTaskCount := 0 - channelTaskCount := 0 - for _, t := range addedTasks { - if _, ok := t.(*task.SegmentTask); ok { - segmentTaskCount++ - } else { - channelTaskCount++ - } - } - - suite.LessOrEqual(segmentTaskCount, 2, "Should have at most 2 segment tasks due to batch size limit") - suite.LessOrEqual(channelTaskCount, 1, "Should have at most 1 channel task due to batch size limit") + result = checker.filterCollectionForBalance(ctx) + assert.Equal(t, collectionIDs, result) // No filters means all pass } -func TestBalanceCheckerSuite(t *testing.T) { - suite.Run(t, new(BalanceCheckerTestSuite)) +// ============================================================================= +// Streaming Service Tests +// ============================================================================= + +func TestBalanceChecker_GetReplicaForStoppingBalance_WithStreamingService(t *testing.T) { + checker := createTestBalanceChecker() + ctx := context.Background() + collectionID := int64(1) + + // Create mock replicas + replica1 := &meta.Replica{} + replicas := []*meta.Replica{replica1} + + // Mock ReplicaManager.GetByCollection + mockGetByCollection := mockey.Mock(mockey.GetMethod(checker.meta.ReplicaManager, "GetByCollection")).Return(replicas).Build() + defer mockGetByCollection.UnPatch() + + // Mock replica methods - no RO nodes but has streaming channel RO nodes + mockRONodesCount := mockey.Mock((*meta.Replica).RONodesCount).Return(0).Build() + defer mockRONodesCount.UnPatch() + + mockROSQNodesCount := mockey.Mock((*meta.Replica).ROSQNodesCount).Return(0).Build() + defer mockROSQNodesCount.UnPatch() + + mockGetID := mockey.Mock((*meta.Replica).GetID).Return(int64(101)).Build() + defer mockGetID.UnPatch() + + // streaming service mocks for simplicity + mockIsStreamingServiceEnabled := mockey.Mock(streamingutil.IsStreamingServiceEnabled).Return(true).Build() + defer mockIsStreamingServiceEnabled.UnPatch() + mockGetChannelRWAndRONodesFor260 := mockey.Mock(utils.GetChannelRWAndRONodesFor260).Return([]int64{}, []int64{1}).Build() + defer mockGetChannelRWAndRONodesFor260.UnPatch() + + result := checker.getReplicaForStoppingBalance(ctx, collectionID) + // Should return replica1 ID since it has channel RO nodes + assert.Equal(t, []int64{101}, result) +} + +// ============================================================================= +// Error Handling Tests +// ============================================================================= + +func TestBalanceChecker_Check_TimeoutWarning(t *testing.T) { + checker := createTestBalanceChecker() + ctx := context.Background() + + // Mock IsActive to return true + mockIsActive := mockey.Mock((*checkerActivation).IsActive).Return(true).Build() + defer mockIsActive.UnPatch() + + mockProcessBalanceQueue := mockey.Mock((*BalanceChecker).processBalanceQueue).To( + func(ctx context.Context, + getReplicasFunc func(ctx context.Context, collectionID int64) []int64, + constructQueueFunc func(ctx context.Context) *balance.PriorityQueue, + getQueueFunc func() *balance.PriorityQueue, config balanceConfig, + ) (int, int) { + time.Sleep(150 * time.Millisecond) + return 0, 0 + }).Build() + defer mockProcessBalanceQueue.UnPatch() + + start := time.Now() + result := checker.Check(ctx) + duration := time.Since(start) + + assert.Nil(t, result) + assert.Greater(t, duration, 100*time.Millisecond) // Should trigger log } diff --git a/pkg/util/paramtable/component_param.go b/pkg/util/paramtable/component_param.go index 2735df9122..a5c3ca3a3e 100644 --- a/pkg/util/paramtable/component_param.go +++ b/pkg/util/paramtable/component_param.go @@ -2347,6 +2347,8 @@ type queryCoordConfig struct { // query node task parallelism factor QueryNodeTaskParallelismFactor ParamItem `refreshable:"true"` + + BalanceCheckCollectionMaxCount ParamItem `refreshable:"true"` } func (p *queryCoordConfig) init(base *BaseTable) { @@ -2979,6 +2981,15 @@ If this parameter is set false, Milvus simply searches the growing segments with Export: false, } p.QueryNodeTaskParallelismFactor.Init(base.mgr) + + p.BalanceCheckCollectionMaxCount = ParamItem{ + Key: "queryCoord.balanceCheckCollectionMaxCount", + Version: "2.6.2", + DefaultValue: "100", + Doc: "the max collection count for each balance check", + Export: false, + } + p.BalanceCheckCollectionMaxCount.Init(base.mgr) } // ///////////////////////////////////////////////////////////////////////////// diff --git a/pkg/util/paramtable/component_param_test.go b/pkg/util/paramtable/component_param_test.go index fef1f7216f..7b8746d85f 100644 --- a/pkg/util/paramtable/component_param_test.go +++ b/pkg/util/paramtable/component_param_test.go @@ -393,6 +393,8 @@ func TestComponentParam(t *testing.T) { assert.Equal(t, 1, Params.QueryNodeTaskParallelismFactor.GetAsInt()) params.Save("queryCoord.queryNodeTaskParallelismFactor", "2") assert.Equal(t, 2, Params.QueryNodeTaskParallelismFactor.GetAsInt()) + + assert.Equal(t, 100, Params.BalanceCheckCollectionMaxCount.GetAsInt()) }) t.Run("test queryNodeConfig", func(t *testing.T) {