milvus/internal/querycoordv2/checkers/balance_checker_test.go
wei liu 38833b0e1d
fix: Fix deactivate balance checker also stops stopping balance (#44834)
issue: #43858
Fix the issue introduced in PR #43992 where deactivating the balance
checker incorrectly stops stopping balance operations.

Changes:
- Move IsActive() check after stopping balance logic
- Only skip normal balance when checker is inactive
- Allow stopping balance to proceed regardless of checker state

This ensures stopping balance can execute even when the balance checker
is deactivated.

---------

Signed-off-by: Wei Liu <wei.liu@zilliz.com>
2025-10-15 15:50:04 +08:00

1249 lines
44 KiB
Go

// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package checkers
import (
"context"
"testing"
"time"
"github.com/bytedance/mockey"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus/internal/querycoordv2/balance"
"github.com/milvus-io/milvus/internal/querycoordv2/meta"
"github.com/milvus-io/milvus/internal/querycoordv2/session"
"github.com/milvus-io/milvus/internal/querycoordv2/task"
"github.com/milvus-io/milvus/internal/querycoordv2/utils"
"github.com/milvus-io/milvus/internal/util/streamingutil"
"github.com/milvus-io/milvus/pkg/v2/proto/datapb"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
)
// createMockPriorityQueue hands back a new priority queue obtained from
// balance.NewPriorityQueuePtr so tests can push collection balance items
// into it.
func createMockPriorityQueue() *balance.PriorityQueue {
	queue := balance.NewPriorityQueuePtr()
	return queue
}
// createTestBalanceChecker assembles a BalanceChecker for tests. All
// dependencies are constructed with nil collaborators; only the
// CollectionManager of the meta instance is populated. The balancer getter
// always returns the same score-based balancer instance.
func createTestBalanceChecker() *BalanceChecker {
	m := &meta.Meta{CollectionManager: meta.NewCollectionManager(nil)}
	targets := meta.NewTargetManager(nil, nil)
	nodes := &session.NodeManager{}
	sched := task.NewScheduler(context.Background(), nil, nil, nil, nil, nil, nil)

	scoreBalancer := balance.NewScoreBasedBalancer(nil, nil, nil, nil, nil)
	balancerGetter := func() balance.Balance {
		return scoreBalancer
	}

	return NewBalanceChecker(m, targets, nodes, sched, balancerGetter)
}
// =============================================================================
// Basic Interface Tests
// =============================================================================
// TestBalanceChecker_ID checks that the checker identifies itself as
// utils.BalanceChecker.
func TestBalanceChecker_ID(t *testing.T) {
	bc := createTestBalanceChecker()
	assert.Equal(t, utils.BalanceChecker, bc.ID())
}
// TestBalanceChecker_Description pins the human-readable description string
// returned by the checker.
func TestBalanceChecker_Description(t *testing.T) {
	bc := createTestBalanceChecker()
	assert.Equal(t, "BalanceChecker checks the cluster distribution and generates balance tasks", bc.Description())
}
// =============================================================================
// Configuration Tests
// =============================================================================
// TestBalanceChecker_LoadBalanceConfig patches paramtable access so that every
// parameter lookup yields a canned value, then checks that loadBalanceConfig
// returns a balanceConfig value.
func TestBalanceChecker_LoadBalanceConfig(t *testing.T) {
	bc := createTestBalanceChecker()

	// paramtable.Get hands back an empty ComponentParam.
	getPatch := mockey.Mock(paramtable.Get).Return(&paramtable.ComponentParam{}).Build()
	defer getPatch.UnPatch()

	// The typed ParamItem getters return fixed values.
	intPatch := mockey.Mock((*paramtable.ParamItem).GetAsInt).Return(5).Build()
	defer intPatch.UnPatch()
	boolPatch := mockey.Mock((*paramtable.ParamItem).GetAsBool).Return(true).Build()
	defer boolPatch.UnPatch()
	durationPatch := mockey.Mock((*paramtable.ParamItem).GetAsDuration).Return(5 * time.Second).Build()
	defer durationPatch.UnPatch()

	// Only the type of the result is verified, not individual field values.
	assert.IsType(t, balanceConfig{}, bc.loadBalanceConfig())
}
// =============================================================================
// Collection Balance Item Tests
// =============================================================================
// TestNewCollectionBalanceItem verifies that the constructor stores the
// collection ID, row count and sort order it is given.
func TestNewCollectionBalanceItem(t *testing.T) {
	wantID := int64(100)
	wantRows := 1000
	wantOrder := "byrowcount"

	got := newCollectionBalanceItem(wantID, wantRows, wantOrder)

	assert.Equal(t, wantID, got.collectionID)
	assert.Equal(t, wantRows, got.rowCount)
	assert.Equal(t, wantOrder, got.sortOrder)
}
// TestCollectionBalanceItem_GetPriority_ByRowCount checks that the priority is
// the negated row count, both for the explicit "byrowcount" order and for an
// empty sort order.
func TestCollectionBalanceItem_GetPriority_ByRowCount(t *testing.T) {
	type testCase struct {
		name      string
		rowCount  int
		sortOrder string
		expected  int
	}

	cases := []testCase{
		{name: "ByRowCount", rowCount: 1000, sortOrder: "byrowcount", expected: -1000},
		// An empty sort order is expected to behave like "byrowcount".
		{name: "Default", rowCount: 500, sortOrder: "", expected: -500},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := newCollectionBalanceItem(1, tc.rowCount, tc.sortOrder).getPriority()
			assert.Equal(t, tc.expected, got)
		})
	}
}
// TestCollectionBalanceItem_GetPriority_ByCollectionID checks that the
// "bycollectionid" sort order uses the collection ID as the priority.
func TestCollectionBalanceItem_GetPriority_ByCollectionID(t *testing.T) {
	id := int64(123)
	got := newCollectionBalanceItem(id, 1000, "bycollectionid").getPriority()
	assert.Equal(t, int(id), got)
}
// TestCollectionBalanceItem_SetPriority checks that a priority written with
// setPriority is the one reported by getPriority afterwards.
func TestCollectionBalanceItem_SetPriority(t *testing.T) {
	entry := newCollectionBalanceItem(1, 100, "byrowcount")

	entry.setPriority(200)

	got := entry.getPriority()
	assert.Equal(t, 200, got)
}
// =============================================================================
// Collection Filtering Tests
// =============================================================================
// TestBalanceChecker_ReadyToCheck_Success expects readyToCheck to report true
// when the collection lookup yields a collection and the next target exists.
func TestBalanceChecker_ReadyToCheck_Success(t *testing.T) {
	bc := createTestBalanceChecker()

	// The collection lookup yields a non-nil collection.
	collectionPatch := mockey.Mock(mockey.GetMethod(bc.meta.CollectionManager, "GetCollection")).Return(&meta.Collection{}).Build()
	defer collectionPatch.UnPatch()

	// The target manager reports that a next target exists.
	nextTargetPatch := mockey.Mock(mockey.GetMethod(bc.targetMgr, "IsNextTargetExist")).Return(true).Build()
	defer nextTargetPatch.UnPatch()

	ready := bc.readyToCheck(context.Background(), int64(1))
	assert.True(t, ready)
}
// TestBalanceChecker_ReadyToCheck_NoMeta expects readyToCheck to report false
// when the collection lookup is patched to return nil and neither target
// exists.
func TestBalanceChecker_ReadyToCheck_NoMeta(t *testing.T) {
	bc := createTestBalanceChecker()

	// The collection lookup on meta.Meta returns nil.
	collectionPatch := mockey.Mock((*meta.Meta).GetCollection).Return(nil).Build()
	defer collectionPatch.UnPatch()

	// Neither the next nor the current target exists.
	nextTargetPatch := mockey.Mock(mockey.GetMethod(bc.targetMgr, "IsNextTargetExist")).Return(false).Build()
	defer nextTargetPatch.UnPatch()
	currentTargetPatch := mockey.Mock(mockey.GetMethod(bc.targetMgr, "IsCurrentTargetExist")).Return(false).Build()
	defer currentTargetPatch.UnPatch()

	ready := bc.readyToCheck(context.Background(), int64(1))
	assert.False(t, ready)
}
// TestBalanceChecker_ReadyToCheck_NoTarget expects readyToCheck to report
// false when a collection is available but neither the next nor the current
// target exists.
func TestBalanceChecker_ReadyToCheck_NoTarget(t *testing.T) {
	bc := createTestBalanceChecker()

	// The collection lookup on meta.Meta yields a non-nil collection.
	collectionPatch := mockey.Mock((*meta.Meta).GetCollection).Return(&meta.Collection{}).Build()
	defer collectionPatch.UnPatch()

	// Neither the next nor the current target exists.
	nextTargetPatch := mockey.Mock(mockey.GetMethod(bc.targetMgr, "IsNextTargetExist")).Return(false).Build()
	defer nextTargetPatch.UnPatch()
	currentTargetPatch := mockey.Mock(mockey.GetMethod(bc.targetMgr, "IsCurrentTargetExist")).Return(false).Build()
	defer currentTargetPatch.UnPatch()

	ready := bc.readyToCheck(context.Background(), int64(1))
	assert.False(t, ready)
}
// TestBalanceChecker_FilterCollectionForBalance_Success checks that a filter
// accepting everything leaves the collection list untouched.
func TestBalanceChecker_FilterCollectionForBalance_Success(t *testing.T) {
	bc := createTestBalanceChecker()

	// CollectionManager.GetAll reports three collections.
	allIDs := []int64{1, 2, 3}
	getAllPatch := mockey.Mock((*meta.CollectionManager).GetAll).Return(allIDs).Build()
	defer getAllPatch.UnPatch()

	// Filter that accepts every collection.
	acceptAll := func(context.Context, int64) bool { return true }

	got := bc.filterCollectionForBalance(context.Background(), acceptAll)
	assert.Equal(t, allIDs, got)
}
// TestBalanceChecker_FilterCollectionForBalance_WithFiltering checks that a
// collection is kept only when it satisfies every supplied filter.
func TestBalanceChecker_FilterCollectionForBalance_WithFiltering(t *testing.T) {
	bc := createTestBalanceChecker()

	// CollectionManager.GetAll reports four collections.
	allIDs := []int64{1, 2, 3, 4}
	getAllPatch := mockey.Mock((*meta.CollectionManager).GetAll).Return(allIDs).Build()
	defer getAllPatch.UnPatch()

	// First filter keeps even IDs, second keeps IDs above two.
	isEven := func(_ context.Context, id int64) bool {
		return id%2 == 0
	}
	aboveTwo := func(_ context.Context, id int64) bool {
		return id > 2
	}

	got := bc.filterCollectionForBalance(context.Background(), isEven, aboveTwo)

	// 4 is the only ID that is both even and greater than two.
	assert.Equal(t, []int64{4}, got)
}
// TestBalanceChecker_FilterCollectionForBalance_EmptyResult checks that a
// filter rejecting everything produces an empty result.
func TestBalanceChecker_FilterCollectionForBalance_EmptyResult(t *testing.T) {
	bc := createTestBalanceChecker()

	// CollectionManager.GetAll reports three collections.
	allIDs := []int64{1, 2, 3}
	getAllPatch := mockey.Mock((*meta.CollectionManager).GetAll).Return(allIDs).Build()
	defer getAllPatch.UnPatch()

	// Filter that rejects every collection.
	rejectAll := func(context.Context, int64) bool { return false }

	got := bc.filterCollectionForBalance(context.Background(), rejectAll)
	assert.Empty(t, got)
}
// =============================================================================
// Queue Construction Tests
// =============================================================================
// TestBalanceChecker_ConstructStoppingBalanceQueue checks that the stopping
// balance queue contains one entry per collection returned by the (patched)
// collection filter.
func TestBalanceChecker_ConstructStoppingBalanceQueue(t *testing.T) {
	checker := createTestBalanceChecker()
	ctx := context.Background()

	// filterCollectionForBalance is patched to yield two collections.
	collectionIDs := []int64{1, 2}
	mockFilterCollections := mockey.Mock((*BalanceChecker).filterCollectionForBalance).Return(collectionIDs).Build()
	defer mockFilterCollections.UnPatch()

	// Every collection reports the same row count.
	mockGetRowCount := mockey.Mock(mockey.GetMethod(checker.targetMgr, "GetCollectionRowCount")).Return(int64(100)).Build()
	defer mockGetRowCount.UnPatch()

	// The sort-order parameter reads as "byrowcount".
	mockParamValue := mockey.Mock((*paramtable.ParamItem).GetValue).Return("byrowcount").Build()
	defer mockParamValue.UnPatch()

	result := checker.constructStoppingBalanceQueue(ctx)

	// testify expects (expected, actual); keeping that order makes failure
	// messages report the right side as "expected".
	assert.Equal(t, len(collectionIDs), result.Len())
}
// TestBalanceChecker_ConstructNormalBalanceQueue checks that the normal
// balance queue contains one entry per collection returned by the (patched)
// collection filter.
func TestBalanceChecker_ConstructNormalBalanceQueue(t *testing.T) {
	checker := createTestBalanceChecker()
	ctx := context.Background()

	// filterCollectionForBalance is patched to yield two collections.
	collectionIDs := []int64{1, 2}
	mockFilterCollections := mockey.Mock((*BalanceChecker).filterCollectionForBalance).Return(collectionIDs).Build()
	defer mockFilterCollections.UnPatch()

	// Every collection reports the same row count.
	mockGetRowCount := mockey.Mock(mockey.GetMethod(checker.targetMgr, "GetCollectionRowCount")).Return(int64(100)).Build()
	defer mockGetRowCount.UnPatch()

	// The sort-order parameter reads as "byrowcount".
	mockParamValue := mockey.Mock((*paramtable.ParamItem).GetValue).Return("byrowcount").Build()
	defer mockParamValue.UnPatch()

	result := checker.constructNormalBalanceQueue(ctx)

	// testify expects (expected, actual); keeping that order makes failure
	// messages report the right side as "expected".
	assert.Equal(t, len(collectionIDs), result.Len())
}
// =============================================================================
// Replica Getting Tests
// =============================================================================
// TestBalanceChecker_GetReplicaForStoppingBalance_WithRONodes expects replica
// ID 101 in the result when replicas report a non-zero RO node count.
func TestBalanceChecker_GetReplicaForStoppingBalance_WithRONodes(t *testing.T) {
	bc := createTestBalanceChecker()

	// Two zero-value replicas are returned for the collection.
	replicaSet := []*meta.Replica{{}, {}}
	byCollectionPatch := mockey.Mock(mockey.GetMethod(bc.meta.ReplicaManager, "GetByCollection")).Return(replicaSet).Build()
	defer byCollectionPatch.UnPatch()

	// The patches apply to the Replica type, so every replica reports one RO
	// node, zero RO streaming-query nodes, and ID 101.
	roNodesPatch := mockey.Mock((*meta.Replica).RONodesCount).Return(1).Build()
	defer roNodesPatch.UnPatch()
	roSQNodesPatch := mockey.Mock((*meta.Replica).ROSQNodesCount).Return(0).Build()
	defer roSQNodesPatch.UnPatch()
	idPatch := mockey.Mock((*meta.Replica).GetID).Return(int64(101)).Build()
	defer idPatch.UnPatch()

	// The streaming service is intentionally left unpatched.
	got := bc.getReplicaForStoppingBalance(context.Background(), int64(1))

	// A replica with RO nodes must be selected.
	assert.Contains(t, got, int64(101))
}
// TestBalanceChecker_GetReplicaForStoppingBalance_NoRONodes expects an empty
// result when the only replica reports zero RO nodes of either kind.
func TestBalanceChecker_GetReplicaForStoppingBalance_NoRONodes(t *testing.T) {
	bc := createTestBalanceChecker()

	// A single zero-value replica is returned for the collection.
	replicaSet := []*meta.Replica{{}}
	byCollectionPatch := mockey.Mock(mockey.GetMethod(bc.meta.ReplicaManager, "GetByCollection")).Return(replicaSet).Build()
	defer byCollectionPatch.UnPatch()

	// Both RO node counters read as zero.
	roNodesPatch := mockey.Mock((*meta.Replica).RONodesCount).Return(0).Build()
	defer roNodesPatch.UnPatch()
	roSQNodesPatch := mockey.Mock((*meta.Replica).ROSQNodesCount).Return(0).Build()
	defer roSQNodesPatch.UnPatch()

	// The streaming service is intentionally left unpatched.
	got := bc.getReplicaForStoppingBalance(context.Background(), int64(1))
	assert.Empty(t, got)
}
// TestBalanceChecker_GetReplicaForNormalBalance expects the IDs of all
// replicas of the collection to be returned, in any order.
func TestBalanceChecker_GetReplicaForNormalBalance(t *testing.T) {
	bc := createTestBalanceChecker()

	// Two zero-value replicas are returned for the collection.
	replicaSet := []*meta.Replica{{}, {}}
	byCollectionPatch := mockey.Mock(mockey.GetMethod(bc.meta.ReplicaManager, "GetByCollection")).Return(replicaSet).Build()
	defer byCollectionPatch.UnPatch()

	// GetID yields 101 on the first call and 102 afterwards.
	idPatch := mockey.Mock((*meta.Replica).GetID).Return(mockey.Sequence(101).Times(1).Then(102)).Build()
	defer idPatch.UnPatch()

	got := bc.getReplicaForNormalBalance(context.Background(), int64(1))

	want := []int64{101, 102}
	assert.ElementsMatch(t, want, got)
}
// =============================================================================
// Task Generation Tests
// =============================================================================
// TestBalanceChecker_GenerateBalanceTasksFromReplicas_EmptyReplicas expects no
// segment or channel tasks when no replica IDs are supplied.
func TestBalanceChecker_GenerateBalanceTasksFromReplicas_EmptyReplicas(t *testing.T) {
	bc := createTestBalanceChecker()

	segTasks, chTasks := bc.generateBalanceTasksFromReplicas(context.Background(), []int64{}, balanceConfig{})

	assert.Empty(t, segTasks)
	assert.Empty(t, chTasks)
}
// TestBalanceChecker_GenerateBalanceTasksFromReplicas_Success patches the
// balancer and the task factories so that one segment plan and one channel
// plan are produced, then expects exactly the canned segment and channel
// tasks to be returned.
func TestBalanceChecker_GenerateBalanceTasksFromReplicas_Success(t *testing.T) {
	bc := createTestBalanceChecker()
	cfg := balanceConfig{
		segmentTaskTimeout: 30 * time.Second,
		channelTaskTimeout: 30 * time.Second,
	}

	// ReplicaManager.Get resolves any ID to this replica.
	replica := &meta.Replica{}
	replicaGetPatch := mockey.Mock(mockey.GetMethod(bc.meta.ReplicaManager, "Get")).Return(replica).Build()
	defer replicaGetPatch.UnPatch()

	// Plans the patched balancer will report.
	segPlans := []balance.SegmentAssignPlan{{
		Segment: &meta.Segment{SegmentInfo: &datapb.SegmentInfo{ID: 1}},
		Replica: replica,
		From:    1,
		To:      2,
	}}
	chPlans := []balance.ChannelAssignPlan{{
		Channel: &meta.DmChannel{VchannelInfo: &datapb.VchannelInfo{ChannelName: "test"}},
		Replica: replica,
		From:    1,
		To:      2,
	}}

	balancerPatch := mockey.Mock(bc.getBalancerFunc).To(func() balance.Balance {
		return balance.NewScoreBasedBalancer(nil, nil, nil, nil, nil)
	}).Build()
	defer balancerPatch.UnPatch()

	// BalanceReplica returns the canned plans.
	balanceReplicaPatch := mockey.Mock((*balance.ScoreBasedBalancer).BalanceReplica).Return(segPlans, chPlans).Build()
	defer balanceReplicaPatch.UnPatch()

	// The task factories return one canned task each.
	wantSegTask := &task.SegmentTask{}
	segFactoryPatch := mockey.Mock(balance.CreateSegmentTasksFromPlans).Return([]task.Task{wantSegTask}).Build()
	defer segFactoryPatch.UnPatch()

	wantChTask := &task.ChannelTask{}
	chFactoryPatch := mockey.Mock(balance.CreateChannelTasksFromPlans).Return([]task.Task{wantChTask}).Build()
	defer chFactoryPatch.UnPatch()

	// Task decoration and plan printing are patched to do nothing.
	priorityPatch := mockey.Mock(task.SetPriority).Return().Build()
	defer priorityPatch.UnPatch()
	reasonPatch := mockey.Mock(task.SetReason).Return().Build()
	defer reasonPatch.UnPatch()
	printPatch := mockey.Mock(balance.PrintNewBalancePlans).Return().Build()
	defer printPatch.UnPatch()

	segTasks, chTasks := bc.generateBalanceTasksFromReplicas(context.Background(), []int64{101}, cfg)

	assert.Len(t, segTasks, 1)
	assert.Len(t, chTasks, 1)
	assert.Equal(t, wantSegTask, segTasks[0])
	assert.Equal(t, wantChTask, chTasks[0])
}
// TestBalanceChecker_GenerateBalanceTasksFromReplicas_NilReplica expects no
// tasks when the replica lookup returns nil for the requested ID.
func TestBalanceChecker_GenerateBalanceTasksFromReplicas_NilReplica(t *testing.T) {
	bc := createTestBalanceChecker()

	// ReplicaManager.Get cannot resolve the replica.
	replicaGetPatch := mockey.Mock(mockey.GetMethod(bc.meta.ReplicaManager, "Get")).Return(nil).Build()
	defer replicaGetPatch.UnPatch()

	segTasks, chTasks := bc.generateBalanceTasksFromReplicas(context.Background(), []int64{101}, balanceConfig{})

	assert.Empty(t, segTasks)
	assert.Empty(t, chTasks)
}
// =============================================================================
// Task Submission Tests
// =============================================================================
// TestBalanceChecker_SubmitTasks runs submitTasks with one segment task and
// one channel task while scheduler.Add is patched to succeed.
func TestBalanceChecker_SubmitTasks(t *testing.T) {
	bc := createTestBalanceChecker()

	// One zero-value task of each kind.
	segTasks := []task.Task{&task.SegmentTask{}}
	chTasks := []task.Task{&task.ChannelTask{}}

	// scheduler.Add is patched to return nil.
	addPatch := mockey.Mock(mockey.GetMethod(bc.scheduler, "Add")).Return(nil).Build()
	defer addPatch.UnPatch()

	bc.submitTasks(segTasks, chTasks)

	// NOTE(review): nothing is asserted here; the test only shows that
	// submitTasks completes without panicking. The number of Add calls is
	// not checked.
}
// TestBalanceChecker_SubmitTasks_EmptyTasks runs submitTasks with empty task
// slices to make sure it does not panic.
func TestBalanceChecker_SubmitTasks_EmptyTasks(t *testing.T) {
	bc := createTestBalanceChecker()

	// scheduler.Add is patched; with no tasks it is not expected to be invoked.
	addPatch := mockey.Mock(mockey.GetMethod(bc.scheduler, "Add")).Return(nil).Build()
	defer addPatch.UnPatch()

	var (
		noSegTasks = []task.Task{}
		noChTasks  = []task.Task{}
	)
	bc.submitTasks(noSegTasks, noChTasks)

	// Intentionally no assertions: completing without a panic is the check.
}
// =============================================================================
// Main Check Method Tests
// =============================================================================
// TestBalanceChecker_Check_InactiveChecker exercises Check while
// checkerActivation.IsActive is patched to report false. Stopping balance is
// expected to run anyway, whereas normal balance is expected to be skipped.
func TestBalanceChecker_Check_InactiveChecker(t *testing.T) {
	t.Run("StoppingBalanceRunsWhenInactive", func(t *testing.T) {
		bc := createTestBalanceChecker()

		// The checker reports itself as inactive.
		activePatch := mockey.Mock((*checkerActivation).IsActive).Return(false).Build()
		defer activePatch.UnPatch()

		paramGetPatch := mockey.Mock(paramtable.Get).Return(&paramtable.ComponentParam{}).Build()
		defer paramGetPatch.UnPatch()

		// Every boolean parameter reads as true, so stopping balance is enabled.
		boolPatch := mockey.Mock((*paramtable.ParamItem).GetAsBool).Return(true).Build()
		defer boolPatch.UnPatch()

		// loadBalanceConfig yields a fixed configuration.
		cfg := balanceConfig{segmentBatchSize: 5, channelBatchSize: 5}
		configPatch := mockey.Mock((*BalanceChecker).loadBalanceConfig).Return(cfg).Build()
		defer configPatch.UnPatch()

		// Record whether processBalanceQueue was reached and pretend it
		// generated one segment task.
		queueProcessed := false
		processPatch := mockey.Mock((*BalanceChecker).processBalanceQueue).To(
			func(ctx context.Context,
				getReplicasFunc func(context.Context, int64) []int64,
				constructQueueFunc func(context.Context) *balance.PriorityQueue,
				getQueueFunc func() *balance.PriorityQueue, config balanceConfig,
			) (int, int) {
				queueProcessed = true
				return 1, 0
			}).Build()
		defer processPatch.UnPatch()

		tasks := bc.Check(context.Background())

		// Stopping balance must run regardless of the activation state.
		assert.True(t, queueProcessed, "Stopping balance should run even when checker is inactive")
		assert.Nil(t, tasks)
		assert.Nil(t, bc.normalBalanceQueue, "Normal balance queue should be cleared when stopping balance generates tasks")
	})

	t.Run("NormalBalanceSkippedWhenInactive", func(t *testing.T) {
		bc := createTestBalanceChecker()

		// The checker reports itself as inactive.
		activePatch := mockey.Mock((*checkerActivation).IsActive).Return(false).Build()
		defer activePatch.UnPatch()

		paramGetPatch := mockey.Mock(paramtable.Get).Return(&paramtable.ComponentParam{}).Build()
		defer paramGetPatch.UnPatch()

		// Every boolean parameter reads as false, so stopping balance is
		// disabled and only the normal balance path remains.
		boolPatch := mockey.Mock((*paramtable.ParamItem).GetAsBool).Return(false).Build()
		defer boolPatch.UnPatch()

		// loadBalanceConfig yields a fixed configuration.
		cfg := balanceConfig{segmentBatchSize: 5, channelBatchSize: 5}
		configPatch := mockey.Mock((*BalanceChecker).loadBalanceConfig).Return(cfg).Build()
		defer configPatch.UnPatch()

		// Record whether processBalanceQueue was reached at all.
		queueProcessed := false
		processPatch := mockey.Mock((*BalanceChecker).processBalanceQueue).To(
			func(ctx context.Context,
				getReplicasFunc func(context.Context, int64) []int64,
				constructQueueFunc func(context.Context) *balance.PriorityQueue,
				getQueueFunc func() *balance.PriorityQueue, config balanceConfig,
			) (int, int) {
				queueProcessed = true
				return 0, 0
			}).Build()
		defer processPatch.UnPatch()

		tasks := bc.Check(context.Background())

		// With stopping balance disabled and the checker inactive, no queue
		// should be processed.
		assert.False(t, queueProcessed, "Normal balance should not run when checker is inactive")
		assert.Nil(t, tasks)
	})
}
// TestBalanceChecker_Check_StoppingBalanceEnabled exercises Check with the
// checker active and stopping balance enabled: once with stopping balance
// producing tasks, and once with it producing none so that normal balance
// gets a turn.
func TestBalanceChecker_Check_StoppingBalanceEnabled(t *testing.T) {
	t.Run("StoppingBalanceGeneratesTasksAndClearsNormalQueue", func(t *testing.T) {
		bc := createTestBalanceChecker()

		// Seed the normal balance queue so its clearing can be observed.
		bc.normalBalanceQueue = createMockPriorityQueue()
		bc.normalBalanceQueue.Push(newCollectionBalanceItem(1, 100, "byrowcount"))

		// The checker reports itself as active.
		activePatch := mockey.Mock((*checkerActivation).IsActive).Return(true).Build()
		defer activePatch.UnPatch()

		paramGetPatch := mockey.Mock(paramtable.Get).Return(&paramtable.ComponentParam{}).Build()
		defer paramGetPatch.UnPatch()

		// Every boolean parameter reads as true, so stopping balance is enabled.
		boolPatch := mockey.Mock((*paramtable.ParamItem).GetAsBool).Return(true).Build()
		defer boolPatch.UnPatch()

		// loadBalanceConfig yields a fixed configuration.
		cfg := balanceConfig{segmentBatchSize: 5, channelBatchSize: 5}
		configPatch := mockey.Mock((*BalanceChecker).loadBalanceConfig).Return(cfg).Build()
		defer configPatch.UnPatch()

		// Record that processBalanceQueue ran and report one segment task.
		queueProcessed := false
		processPatch := mockey.Mock((*BalanceChecker).processBalanceQueue).To(
			func(ctx context.Context,
				getReplicasFunc func(context.Context, int64) []int64,
				constructQueueFunc func(context.Context) *balance.PriorityQueue,
				getQueueFunc func() *balance.PriorityQueue, config balanceConfig,
			) (int, int) {
				queueProcessed = true
				return 1, 0
			}).Build()
		defer processPatch.UnPatch()

		tasks := bc.Check(context.Background())

		assert.True(t, queueProcessed, "Stopping balance should have been called")
		assert.Nil(t, tasks, "Check should always return nil")
		assert.Nil(t, bc.normalBalanceQueue, "Normal balance queue should be cleared when stopping balance generates tasks")
	})

	t.Run("StoppingBalanceNoTasksAllowsNormalBalance", func(t *testing.T) {
		bc := createTestBalanceChecker()

		// A zero timestamp means the auto-balance interval has long elapsed.
		bc.autoBalanceTs = time.Time{}

		// The checker reports itself as active.
		activePatch := mockey.Mock((*checkerActivation).IsActive).Return(true).Build()
		defer activePatch.UnPatch()

		paramGetPatch := mockey.Mock(paramtable.Get).Return(&paramtable.ComponentParam{}).Build()
		defer paramGetPatch.UnPatch()

		// Every boolean parameter reads as true: both stopping balance and
		// auto balance are enabled.
		boolPatch := mockey.Mock((*paramtable.ParamItem).GetAsBool).Return(true).Build()
		defer boolPatch.UnPatch()

		// loadBalanceConfig yields a fixed configuration.
		cfg := balanceConfig{
			segmentBatchSize:    5,
			channelBatchSize:    5,
			autoBalanceInterval: 1 * time.Second,
		}
		configPatch := mockey.Mock((*BalanceChecker).loadBalanceConfig).Return(cfg).Build()
		defer configPatch.UnPatch()

		// The first invocation (stopping balance) reports no tasks; any later
		// invocation (normal balance) reports one channel task.
		invocations := 0
		processPatch := mockey.Mock((*BalanceChecker).processBalanceQueue).To(
			func(ctx context.Context,
				getReplicasFunc func(context.Context, int64) []int64,
				constructQueueFunc func(context.Context) *balance.PriorityQueue,
				getQueueFunc func() *balance.PriorityQueue, config balanceConfig,
			) (int, int) {
				invocations++
				if invocations != 1 {
					return 0, 1
				}
				return 0, 0
			}).Build()
		defer processPatch.UnPatch()

		tasks := bc.Check(context.Background())

		// Both balance kinds must have been attempted.
		assert.Equal(t, 2, invocations, "Both stopping balance and normal balance should be called")
		assert.Nil(t, tasks)
		assert.Nil(t, bc.stoppingBalanceQueue, "Stopping balance queue should be cleared when normal balance generates tasks")
	})
}
// TestBalanceChecker_Check_NormalBalanceEnabled exercises the normal balance
// path of Check with the checker active and stopping balance disabled. It
// covers task generation, throttling by autoBalanceInterval, and the case
// where auto balance is disabled.
func TestBalanceChecker_Check_NormalBalanceEnabled(t *testing.T) {
	t.Run("NormalBalanceGeneratesTasks", func(t *testing.T) {
		checker := createTestBalanceChecker()
		ctx := context.Background()

		// A zero timestamp means the auto-balance interval has long elapsed.
		checker.autoBalanceTs = time.Time{}

		// Seed the stopping balance queue so its clearing can be observed.
		checker.stoppingBalanceQueue = createMockPriorityQueue()
		checker.stoppingBalanceQueue.Push(newCollectionBalanceItem(1, 100, "byrowcount"))

		// The checker reports itself as active.
		mockIsActive := mockey.Mock((*checkerActivation).IsActive).Return(true).Build()
		defer mockIsActive.UnPatch()

		mockParamGet := mockey.Mock(paramtable.Get).Return(&paramtable.ComponentParam{}).Build()
		defer mockParamGet.UnPatch()

		// First GetAsBool call (stopping balance) yields false, later calls
		// (auto balance) yield true.
		mockParams := mockey.Mock((*paramtable.ParamItem).GetAsBool).Return(mockey.Sequence(false).Times(1).Then(true)).Build()
		defer mockParams.UnPatch()

		// loadBalanceConfig yields a fixed configuration.
		config := balanceConfig{
			segmentBatchSize:    5,
			channelBatchSize:    5,
			autoBalanceInterval: 1 * time.Second,
		}
		mockLoadConfig := mockey.Mock((*BalanceChecker).loadBalanceConfig).Return(config).Build()
		defer mockLoadConfig.UnPatch()

		// Record that processBalanceQueue ran and report one channel task.
		normalBalanceCalled := false
		originalTs := checker.autoBalanceTs
		mockProcessQueue := mockey.Mock((*BalanceChecker).processBalanceQueue).To(
			func(ctx context.Context,
				getReplicasFunc func(context.Context, int64) []int64,
				constructQueueFunc func(context.Context) *balance.PriorityQueue,
				getQueueFunc func() *balance.PriorityQueue, config balanceConfig,
			) (int, int) {
				normalBalanceCalled = true
				return 0, 1
			}).Build()
		defer mockProcessQueue.UnPatch()

		result := checker.Check(ctx)

		assert.True(t, normalBalanceCalled, "Normal balance should have been called")
		assert.Nil(t, result, "Check should always return nil")
		assert.Nil(t, checker.stoppingBalanceQueue, "Stopping balance queue should be cleared when normal balance generates tasks")
		assert.True(t, checker.autoBalanceTs.After(originalTs), "autoBalanceTs should be updated when tasks are generated")
	})

	// The subtest name previously contained a stray space, which the testing
	// package rewrites to an underscore; it now matches the sibling names.
	t.Run("NormalBalanceRespectsIntervalThrottle", func(t *testing.T) {
		checker := createTestBalanceChecker()
		ctx := context.Background()

		// A fresh timestamp means the auto-balance interval has not elapsed.
		checker.autoBalanceTs = time.Now()

		// The checker reports itself as active.
		mockIsActive := mockey.Mock((*checkerActivation).IsActive).Return(true).Build()
		defer mockIsActive.UnPatch()

		mockParamGet := mockey.Mock(paramtable.Get).Return(&paramtable.ComponentParam{}).Build()
		defer mockParamGet.UnPatch()

		// First GetAsBool call (stopping balance) yields false, later calls
		// (auto balance) yield true.
		mockParams := mockey.Mock((*paramtable.ParamItem).GetAsBool).Return(mockey.Sequence(false).Times(1).Then(true)).Build()
		defer mockParams.UnPatch()

		// The interval is long enough that it cannot elapse during the test.
		config := balanceConfig{
			segmentBatchSize:    5,
			channelBatchSize:    5,
			autoBalanceInterval: 10 * time.Second,
		}
		mockLoadConfig := mockey.Mock((*BalanceChecker).loadBalanceConfig).Return(config).Build()
		defer mockLoadConfig.UnPatch()

		// Record whether processBalanceQueue was reached.
		normalBalanceCalled := false
		mockProcessQueue := mockey.Mock((*BalanceChecker).processBalanceQueue).To(
			func(ctx context.Context,
				getReplicasFunc func(context.Context, int64) []int64,
				constructQueueFunc func(context.Context) *balance.PriorityQueue,
				getQueueFunc func() *balance.PriorityQueue, config balanceConfig,
			) (int, int) {
				normalBalanceCalled = true
				return 0, 1
			}).Build()
		defer mockProcessQueue.UnPatch()

		result := checker.Check(ctx)

		// Throttling must prevent normal balance from running.
		assert.False(t, normalBalanceCalled, "Normal balance should respect autoBalanceInterval throttle")
		assert.Nil(t, result)
	})

	t.Run("NormalBalanceSkippedWhenAutoBalanceDisabled", func(t *testing.T) {
		checker := createTestBalanceChecker()
		ctx := context.Background()

		// A zero timestamp means the interval is not what blocks balancing.
		checker.autoBalanceTs = time.Time{}

		// The checker reports itself as active.
		mockIsActive := mockey.Mock((*checkerActivation).IsActive).Return(true).Build()
		defer mockIsActive.UnPatch()

		mockParamGet := mockey.Mock(paramtable.Get).Return(&paramtable.ComponentParam{}).Build()
		defer mockParamGet.UnPatch()

		// Every boolean parameter reads as false: both stopping balance and
		// auto balance are disabled.
		mockParams := mockey.Mock((*paramtable.ParamItem).GetAsBool).Return(false).Build()
		defer mockParams.UnPatch()

		// loadBalanceConfig yields a fixed configuration.
		config := balanceConfig{
			segmentBatchSize:    5,
			channelBatchSize:    5,
			autoBalanceInterval: 1 * time.Second,
		}
		mockLoadConfig := mockey.Mock((*BalanceChecker).loadBalanceConfig).Return(config).Build()
		defer mockLoadConfig.UnPatch()

		// Record whether processBalanceQueue was reached.
		normalBalanceCalled := false
		mockProcessQueue := mockey.Mock((*BalanceChecker).processBalanceQueue).To(
			func(ctx context.Context,
				getReplicasFunc func(context.Context, int64) []int64,
				constructQueueFunc func(context.Context) *balance.PriorityQueue,
				getQueueFunc func() *balance.PriorityQueue, config balanceConfig,
			) (int, int) {
				normalBalanceCalled = true
				return 0, 1
			}).Build()
		defer mockProcessQueue.UnPatch()

		result := checker.Check(ctx)

		// With auto balance disabled no queue should be processed.
		assert.False(t, normalBalanceCalled, "Normal balance should be skipped when AutoBalance is disabled")
		assert.Nil(t, result)
	})
}
// =============================================================================
// ProcessBalanceQueue Tests
// =============================================================================
// TestBalanceChecker_ProcessBalanceQueue_Success drives processBalanceQueue
// over a queue holding five collections while task generation and submission
// are stubbed, and checks the segment/channel task counts it reports.
func TestBalanceChecker_ProcessBalanceQueue_Success(t *testing.T) {
	checker := createTestBalanceChecker()
	ctx := context.Background()

	// Queue collections 1..5, all with the same row count and sort order.
	pq := createMockPriorityQueue()
	for collectionID := int64(1); collectionID <= 5; collectionID++ {
		pq.Push(newCollectionBalanceItem(collectionID, 100, "byrowcount"))
	}

	// Every collection reports the same two replicas.
	replicasOf := func(ctx context.Context, collectionID int64) []int64 {
		return []int64{101, 102}
	}
	// Both queue accessors hand back the pre-filled queue.
	buildQueue := func(ctx context.Context) *balance.PriorityQueue { return pq }
	currentQueue := func() *balance.PriorityQueue { return pq }

	// Each generation call yields exactly one segment task and one channel task.
	generatePatch := mockey.Mock((*BalanceChecker).generateBalanceTasksFromReplicas).Return(
		[]task.Task{&task.SegmentTask{}}, []task.Task{&task.ChannelTask{}},
	).Build()
	defer generatePatch.UnPatch()

	// Task submission is stubbed out to a no-op.
	submitPatch := mockey.Mock((*BalanceChecker).submitTasks).Return().Build()
	defer submitPatch.UnPatch()

	segmentTasks, channelTasks := checker.processBalanceQueue(
		ctx, replicasOf, buildQueue, currentQueue,
		balanceConfig{
			segmentBatchSize:             5,
			channelBatchSize:             3,
			maxCheckCollectionCount:      5,
			balanceOnMultipleCollections: true,
		},
	)

	// Three tasks of each kind are expected; presumably processing stops once
	// the channel batch size (3) is reached — confirm against processBalanceQueue.
	assert.Equal(t, 3, segmentTasks)
	assert.Equal(t, 3, channelTasks)
}
// TestBalanceChecker_ProcessBalanceQueue_EmptyQueue checks that
// processBalanceQueue reports zero segment and channel tasks when both queue
// accessors return a queue that nothing was pushed into.
func TestBalanceChecker_ProcessBalanceQueue_EmptyQueue(t *testing.T) {
	checker := createTestBalanceChecker()
	ctx := context.Background()

	// Nothing is pushed, so the queue stays as createMockPriorityQueue built it.
	pq := createMockPriorityQueue()

	replicasOf := func(ctx context.Context, collectionID int64) []int64 {
		return []int64{101}
	}
	buildQueue := func(ctx context.Context) *balance.PriorityQueue { return pq }
	currentQueue := func() *balance.PriorityQueue { return pq }

	segmentTasks, channelTasks := checker.processBalanceQueue(
		ctx, replicasOf, buildQueue, currentQueue,
		balanceConfig{
			segmentBatchSize: 5,
			channelBatchSize: 3,
		},
	)

	assert.Equal(t, 0, segmentTasks)
	assert.Equal(t, 0, channelTasks)
}
// TestBalanceChecker_ProcessBalanceQueue_BatchSizeLimit checks that with
// segment and channel batch sizes of 1, processBalanceQueue reports a single
// task of each kind even though two collections are queued.
func TestBalanceChecker_ProcessBalanceQueue_BatchSizeLimit(t *testing.T) {
	checker := createTestBalanceChecker()
	ctx := context.Background()

	// Two collections are queued, one more than the batch sizes allow for.
	pq := createMockPriorityQueue()
	pq.Push(newCollectionBalanceItem(1, 100, "byrowcount"))
	pq.Push(newCollectionBalanceItem(2, 100, "byrowcount"))

	replicasOf := func(ctx context.Context, collectionID int64) []int64 {
		return []int64{101}
	}
	buildQueue := func(ctx context.Context) *balance.PriorityQueue { return pq }
	currentQueue := func() *balance.PriorityQueue { return pq }

	// The first generation call returns one pair of tasks; any later call
	// returns a second, distinct pair.
	firstBatch := mockey.Sequence(
		[]task.Task{&task.SegmentTask{}}, []task.Task{&task.ChannelTask{}},
	).Times(1)
	generatePatch := mockey.Mock((*BalanceChecker).generateBalanceTasksFromReplicas).Return(firstBatch.Then(
		[]task.Task{&task.SegmentTask{}}, []task.Task{&task.ChannelTask{}},
	)).Build()
	defer generatePatch.UnPatch()

	// Task submission is stubbed out to a no-op.
	submitPatch := mockey.Mock((*BalanceChecker).submitTasks).Return().Build()
	defer submitPatch.UnPatch()

	segmentTasks, channelTasks := checker.processBalanceQueue(
		ctx, replicasOf, buildQueue, currentQueue,
		balanceConfig{
			segmentBatchSize:             1, // room for a single segment task
			channelBatchSize:             1, // room for a single channel task
			maxCheckCollectionCount:      10,
			balanceOnMultipleCollections: true,
		},
	)

	// Only the first collection's tasks are counted before the limits stop the run.
	assert.Equal(t, 1, segmentTasks)
	assert.Equal(t, 1, channelTasks)
}
// TestBalanceChecker_ProcessBalanceQueue_MultiCollectionDisabled checks that
// with balanceOnMultipleCollections set to false, processBalanceQueue stops
// after the first collection that produces tasks.
//
// Two collections are queued and every generation call returns one segment
// task, so the expected totals (1 segment, 0 channel) and the single replica
// lookup only hold if the second collection is left untouched. With a single
// queued collection the assertions would pass whether or not the
// multi-collection guard works.
func TestBalanceChecker_ProcessBalanceQueue_MultiCollectionDisabled(t *testing.T) {
	checker := createTestBalanceChecker()
	ctx := context.Background()
	config := balanceConfig{
		segmentBatchSize:             10,
		channelBatchSize:             10,
		maxCheckCollectionCount:      10,
		balanceOnMultipleCollections: false, // Disabled
	}

	// Batch sizes and the collection-count limit are far above two, so only
	// the multi-collection switch can stop the run early.
	mockQueue := createMockPriorityQueue()
	mockQueue.Push(newCollectionBalanceItem(1, 100, "byrowcount"))
	mockQueue.Push(newCollectionBalanceItem(2, 100, "byrowcount"))

	// Count replica lookups to observe how many collections were visited.
	replicaLookups := 0
	getReplicasFunc := func(ctx context.Context, collectionID int64) []int64 {
		replicaLookups++
		return []int64{101}
	}
	constructQueueFunc := func(ctx context.Context) *balance.PriorityQueue {
		return mockQueue
	}
	getQueueFunc := func() *balance.PriorityQueue {
		return mockQueue
	}

	// Every generation call yields one segment task and no channel tasks.
	mockSegmentTask := &task.SegmentTask{}
	mockGenerateTasks := mockey.Mock((*BalanceChecker).generateBalanceTasksFromReplicas).Return(
		[]task.Task{mockSegmentTask}, []task.Task{},
	).Build()
	defer mockGenerateTasks.UnPatch()

	// Task submission is stubbed out to a no-op.
	mockSubmitTasks := mockey.Mock((*BalanceChecker).submitTasks).Return().Build()
	defer mockSubmitTasks.UnPatch()

	segmentTasks, channelTasks := checker.processBalanceQueue(
		ctx, getReplicasFunc, constructQueueFunc, getQueueFunc, config,
	)

	// Should stop after first collection due to multi-collection disabled
	assert.Equal(t, 1, segmentTasks)
	assert.Equal(t, 0, channelTasks)
	assert.Equal(t, 1, replicaLookups, "second collection must not be visited when multi-collection balance is disabled")
}
// TestBalanceChecker_ProcessBalanceQueue_NoReplicasToBalance checks that a
// queued collection whose replica lookup returns an empty slice contributes no
// balance tasks.
//
// One collection is queued so that getReplicasFunc is actually consulted; with
// an empty queue this test would not differ from the empty-queue test. Task
// generation is stubbed to return a task, so the zero counts only hold if
// processBalanceQueue skips generation for a collection without replicas.
// NOTE(review): that skip is assumed from this test's name and intent —
// confirm against processBalanceQueue.
func TestBalanceChecker_ProcessBalanceQueue_NoReplicasToBalance(t *testing.T) {
	checker := createTestBalanceChecker()
	ctx := context.Background()
	config := balanceConfig{
		segmentBatchSize:             5,
		channelBatchSize:             5,
		maxCheckCollectionCount:      5,
		balanceOnMultipleCollections: true,
	}

	mockQueue := createMockPriorityQueue()
	mockQueue.Push(newCollectionBalanceItem(1, 100, "byrowcount"))

	// getReplicasFunc returns an empty slice and records that it was consulted.
	replicasLookedUp := false
	getReplicasFunc := func(ctx context.Context, collectionID int64) []int64 {
		replicasLookedUp = true
		return []int64{} // No replicas
	}
	constructQueueFunc := func(ctx context.Context) *balance.PriorityQueue {
		return mockQueue
	}
	getQueueFunc := func() *balance.PriorityQueue {
		return mockQueue
	}

	// If generation were reached despite the empty replica list, this stub
	// would add a segment task and the count assertions below would fail.
	mockGenerateTasks := mockey.Mock((*BalanceChecker).generateBalanceTasksFromReplicas).Return(
		[]task.Task{&task.SegmentTask{}}, []task.Task{},
	).Build()
	defer mockGenerateTasks.UnPatch()

	// Task submission is stubbed out to a no-op.
	mockSubmitTasks := mockey.Mock((*BalanceChecker).submitTasks).Return().Build()
	defer mockSubmitTasks.UnPatch()

	segmentTasks, channelTasks := checker.processBalanceQueue(
		ctx, getReplicasFunc, constructQueueFunc, getQueueFunc, config,
	)

	assert.True(t, replicasLookedUp, "replica lookup should run for the queued collection")
	assert.Equal(t, 0, segmentTasks)
	assert.Equal(t, 0, channelTasks)
}
// =============================================================================
// Performance and Edge Case Tests
// =============================================================================
func TestBalanceChecker_CollectionBalanceItem_EdgeCases(t *testing.T) {
// Test with zero row count
item := newCollectionBalanceItem(1, 0, "byrowcount")
assert.Equal(t, 0, item.getPriority())
// Test with negative collection ID
item = newCollectionBalanceItem(-1, 100, "bycollectionid")
assert.Equal(t, -1, item.getPriority())
// Test with very large values
item = newCollectionBalanceItem(9223372036854775807, 2147483647, "byrowcount")
assert.Equal(t, -2147483647, item.getPriority())
// Test with empty sort order (should default to byrowcount)
item = newCollectionBalanceItem(5, 100, "")
assert.Equal(t, -100, item.getPriority())
// Test with invalid sort order (should default to byrowcount)
item = newCollectionBalanceItem(5, 100, "invalid")
assert.Equal(t, -100, item.getPriority())
}
// TestBalanceChecker_FilterCollectionForBalance_EdgeCases covers two boundary
// uses of filterCollectionForBalance: an empty collection list with a filter
// that accepts everything, and a non-empty list with no filters at all.
func TestBalanceChecker_FilterCollectionForBalance_EdgeCases(t *testing.T) {
	checker := createTestBalanceChecker()
	ctx := context.Background()

	// Scenario 1: the collection manager reports no collections.
	emptyPatch := mockey.Mock((*meta.CollectionManager).GetAll).Return([]int64{}).Build()
	defer emptyPatch.UnPatch()

	acceptAll := func(ctx context.Context, collectionID int64) bool { return true }
	assert.Empty(t, checker.filterCollectionForBalance(ctx, acceptAll))

	// Scenario 2: three collections and no filters. The first patch is removed
	// before GetAll is patched again with the new return value.
	emptyPatch.UnPatch()
	wantIDs := []int64{1, 2, 3}
	populatedPatch := mockey.Mock((*meta.CollectionManager).GetAll).Return(wantIDs).Build()
	defer populatedPatch.UnPatch()

	// No filters means all pass.
	gotIDs := checker.filterCollectionForBalance(ctx)
	assert.Equal(t, wantIDs, gotIDs)
}
// =============================================================================
// Streaming Service Tests
// =============================================================================
// TestBalanceChecker_GetReplicaForStoppingBalance_WithStreamingService checks
// that getReplicaForStoppingBalance returns a replica's ID when the streaming
// service is reported as enabled and the channel-node lookup is stubbed to
// return a non-empty second list, even though both replica RO counters are 0.
func TestBalanceChecker_GetReplicaForStoppingBalance_WithStreamingService(t *testing.T) {
	checker := createTestBalanceChecker()
	ctx := context.Background()
	const collectionID = int64(1)

	// The replica manager yields a single zero-value replica for the collection.
	byCollectionPatch := mockey.Mock(mockey.GetMethod(checker.meta.ReplicaManager, "GetByCollection")).
		Return([]*meta.Replica{{}}).Build()
	defer byCollectionPatch.UnPatch()

	// Replica accessors: both RO counters are 0 and the replica ID is 101.
	roCountPatch := mockey.Mock((*meta.Replica).RONodesCount).Return(0).Build()
	defer roCountPatch.UnPatch()
	roSQCountPatch := mockey.Mock((*meta.Replica).ROSQNodesCount).Return(0).Build()
	defer roSQCountPatch.UnPatch()
	idPatch := mockey.Mock((*meta.Replica).GetID).Return(int64(101)).Build()
	defer idPatch.UnPatch()

	// Streaming service is on; the channel-node helper returns an empty first
	// list and []int64{1} as its second list (presumably the RO nodes, going
	// by the helper's name).
	streamingPatch := mockey.Mock(streamingutil.IsStreamingServiceEnabled).Return(true).Build()
	defer streamingPatch.UnPatch()
	channelNodesPatch := mockey.Mock(utils.GetChannelRWAndRONodesFor260).Return([]int64{}, []int64{1}).Build()
	defer channelNodesPatch.UnPatch()

	// The only replica (ID 101) must be selected.
	got := checker.getReplicaForStoppingBalance(ctx, collectionID)
	assert.Equal(t, []int64{101}, got)
}
// =============================================================================
// Error Handling Tests
// =============================================================================
// TestBalanceChecker_Check_TimeoutWarning runs Check with processBalanceQueue
// replaced by a stub that sleeps for 150ms, and verifies that Check returns
// nil and takes longer than 100ms.
func TestBalanceChecker_Check_TimeoutWarning(t *testing.T) {
	checker := createTestBalanceChecker()
	ctx := context.Background()

	// Force the checker to report itself as active.
	activePatch := mockey.Mock((*checkerActivation).IsActive).Return(true).Build()
	defer activePatch.UnPatch()

	// Stand-in for queue processing: stall, then report that no tasks were generated.
	slowQueue := func(ctx context.Context,
		getReplicasFunc func(ctx context.Context, collectionID int64) []int64,
		constructQueueFunc func(ctx context.Context) *balance.PriorityQueue,
		getQueueFunc func() *balance.PriorityQueue, config balanceConfig,
	) (int, int) {
		time.Sleep(150 * time.Millisecond)
		return 0, 0
	}
	queuePatch := mockey.Mock((*BalanceChecker).processBalanceQueue).To(slowQueue).Build()
	defer queuePatch.UnPatch()

	begin := time.Now()
	result := checker.Check(ctx)
	elapsed := time.Since(begin)

	assert.Nil(t, result)
	// Slow enough that Check's slow-run log is expected to fire; the log itself
	// is not asserted here.
	assert.Greater(t, elapsed, 100*time.Millisecond)
}