mirror of
https://gitee.com/milvus-io/milvus.git
synced 2025-12-06 17:18:35 +08:00
79 lines
2.7 KiB
Go
79 lines
2.7 KiB
Go
package streaming
|
|
|
|
import (
|
|
"context"
|
|
"testing"
|
|
|
|
"github.com/cockroachdb/errors"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/mock"
|
|
|
|
"github.com/milvus-io/milvus/internal/mocks/streamingcoord/mock_client"
|
|
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
|
|
"github.com/milvus-io/milvus/pkg/v2/util/merr"
|
|
)
|
|
|
|
func TestBalancer(t *testing.T) {
|
|
scClient := mock_client.NewMockClient(t)
|
|
assignmentService := mock_client.NewMockAssignmentService(t)
|
|
scClient.EXPECT().Assignment().Return(assignmentService)
|
|
assignmentService.EXPECT().GetLatestAssignments(mock.Anything).Return(&types.VersionedStreamingNodeAssignments{
|
|
Assignments: map[int64]types.StreamingNodeAssignment{
|
|
1: {
|
|
NodeInfo: types.StreamingNodeInfo{ServerID: 1},
|
|
Channels: map[string]types.PChannelInfo{
|
|
"v1": {},
|
|
},
|
|
},
|
|
},
|
|
}, nil)
|
|
|
|
balancer := balancerImpl{
|
|
walAccesserImpl: &walAccesserImpl{
|
|
streamingCoordClient: scClient,
|
|
},
|
|
}
|
|
|
|
nodes, err := balancer.ListStreamingNode(context.Background())
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, 1, len(nodes))
|
|
assignment, err := balancer.GetWALDistribution(context.Background(), 1)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, 1, len(assignment.Channels))
|
|
|
|
assignment, err = balancer.GetWALDistribution(context.Background(), 2)
|
|
assert.True(t, errors.Is(err, merr.ErrNodeNotFound))
|
|
assert.Nil(t, assignment)
|
|
|
|
assignmentService.EXPECT().GetLatestAssignments(mock.Anything).Unset()
|
|
assignmentService.EXPECT().GetLatestAssignments(mock.Anything).Return(nil, errors.New("test"))
|
|
nodes, err = balancer.ListStreamingNode(context.Background())
|
|
assert.Error(t, err)
|
|
assert.Nil(t, nodes)
|
|
|
|
assignment, err = balancer.GetWALDistribution(context.Background(), 1)
|
|
assert.Error(t, err)
|
|
assert.Nil(t, assignment)
|
|
|
|
assignmentService.EXPECT().UpdateWALBalancePolicy(mock.Anything, mock.Anything).Return(&types.UpdateWALBalancePolicyResponse{}, nil)
|
|
err = balancer.SuspendRebalance(context.Background())
|
|
assert.NoError(t, err)
|
|
err = balancer.ResumeRebalance(context.Background())
|
|
assert.NoError(t, err)
|
|
err = balancer.FreezeNodeIDs(context.Background(), []int64{1})
|
|
assert.NoError(t, err)
|
|
err = balancer.DefreezeNodeIDs(context.Background(), []int64{1})
|
|
assert.NoError(t, err)
|
|
|
|
assignmentService.EXPECT().UpdateWALBalancePolicy(mock.Anything, mock.Anything).Unset()
|
|
assignmentService.EXPECT().UpdateWALBalancePolicy(mock.Anything, mock.Anything).Return(nil, errors.New("test"))
|
|
err = balancer.SuspendRebalance(context.Background())
|
|
assert.Error(t, err)
|
|
err = balancer.ResumeRebalance(context.Background())
|
|
assert.Error(t, err)
|
|
err = balancer.FreezeNodeIDs(context.Background(), []int64{1})
|
|
assert.Error(t, err)
|
|
err = balancer.DefreezeNodeIDs(context.Background(), []int64{1})
|
|
assert.Error(t, err)
|
|
}
|