Zhen Ye d0e3a33c37
enhance: add IsRebalanceSuspended interface for wal balancer (#44026)
issue: #43968

Signed-off-by: chyezh <chyezh@outlook.com>
2025-08-24 09:19:47 +08:00

79 lines
2.7 KiB
Go

package streaming
import (
"context"
"testing"
"github.com/cockroachdb/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/milvus-io/milvus/internal/mocks/streamingcoord/mock_client"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
)
// TestBalancer drives balancerImpl against a mocked streaming-coord client.
//
// It covers:
//   - ListStreamingNode and GetWALDistribution when the assignment service
//     reports a single node (ServerID 1) owning one channel;
//   - GetWALDistribution for a node absent from the snapshot, which must
//     yield an error matching merr.ErrNodeNotFound and a nil result;
//   - both read paths when GetLatestAssignments fails;
//   - the balance-policy mutators (SuspendRebalance, ResumeRebalance,
//     FreezeNodeIDs, DefreezeNodeIDs) when UpdateWALBalancePolicy first
//     succeeds and then fails.
func TestBalancer(t *testing.T) {
	ctx := context.Background()

	coordClient := mock_client.NewMockClient(t)
	assignSvc := mock_client.NewMockAssignmentService(t)
	coordClient.EXPECT().Assignment().Return(assignSvc)

	// Snapshot with exactly one streaming node (ServerID 1) owning channel "v1".
	snapshot := &types.VersionedStreamingNodeAssignments{
		Assignments: map[int64]types.StreamingNodeAssignment{
			1: {
				NodeInfo: types.StreamingNodeInfo{ServerID: 1},
				Channels: map[string]types.PChannelInfo{
					"v1": {},
				},
			},
		},
	}
	assignSvc.EXPECT().GetLatestAssignments(mock.Anything).Return(snapshot, nil)

	bal := balancerImpl{
		walAccesserImpl: &walAccesserImpl{streamingCoordClient: coordClient},
	}

	// Happy path: the single node is listed and its distribution is returned.
	nodes, err := bal.ListStreamingNode(ctx)
	assert.NoError(t, err)
	assert.Len(t, nodes, 1)

	dist, err := bal.GetWALDistribution(ctx, 1)
	assert.NoError(t, err)
	assert.Len(t, dist.Channels, 1)

	// A node missing from the snapshot is reported as not found.
	dist, err = bal.GetWALDistribution(ctx, 2)
	assert.True(t, errors.Is(err, merr.ErrNodeNotFound))
	assert.Nil(t, dist)

	// Replace the expectation so fetching assignments fails; both read
	// paths must surface the error and return nil.
	assignSvc.EXPECT().GetLatestAssignments(mock.Anything).Unset()
	assignSvc.EXPECT().GetLatestAssignments(mock.Anything).Return(nil, errors.New("test"))

	nodes, err = bal.ListStreamingNode(ctx)
	assert.Error(t, err)
	assert.Nil(t, nodes)

	dist, err = bal.GetWALDistribution(ctx, 1)
	assert.Error(t, err)
	assert.Nil(t, dist)

	// The balance-policy mutators, invoked in this fixed order for both the
	// success and the failure scenario below.
	policyOps := []func(context.Context) error{
		func(c context.Context) error { return bal.SuspendRebalance(c) },
		func(c context.Context) error { return bal.ResumeRebalance(c) },
		func(c context.Context) error { return bal.FreezeNodeIDs(c, []int64{1}) },
		func(c context.Context) error { return bal.DefreezeNodeIDs(c, []int64{1}) },
	}

	assignSvc.EXPECT().UpdateWALBalancePolicy(mock.Anything, mock.Anything).Return(&types.UpdateWALBalancePolicyResponse{}, nil)
	for _, op := range policyOps {
		assert.NoError(t, op(ctx))
	}

	// Swap in a failing UpdateWALBalancePolicy; every mutator must now error.
	assignSvc.EXPECT().UpdateWALBalancePolicy(mock.Anything, mock.Anything).Unset()
	assignSvc.EXPECT().UpdateWALBalancePolicy(mock.Anything, mock.Anything).Return(nil, errors.New("test"))
	for _, op := range policyOps {
		assert.Error(t, op(ctx))
	}
}