mirror of
https://gitee.com/milvus-io/milvus.git
synced 2025-12-07 01:28:27 +08:00
enhance: support balancer interface for streaming client to fetch streaming node information (#43969)
issue: #43968 - Add ListStreamingNode/GetWALDistribution to fetch streaming node info - Add SuspendRebalance/ResumeRebalance to enable or stop balance - Add FreezeNodeIDs/DefreezeNodeIDs to freeze target node Signed-off-by: chyezh <chyezh@outlook.com>
This commit is contained in:
parent
8e1ce15146
commit
082ca62ec1
@ -24,6 +24,7 @@ packages:
|
||||
interfaces:
|
||||
Client:
|
||||
BroadcastService:
|
||||
AssignmentService:
|
||||
github.com/milvus-io/milvus/internal/streamingcoord/client/broadcast:
|
||||
interfaces:
|
||||
Watcher:
|
||||
|
||||
83
internal/distributed/streaming/balancer.go
Normal file
83
internal/distributed/streaming/balancer.go
Normal file
@ -0,0 +1,83 @@
|
||||
package streaming
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"google.golang.org/protobuf/types/known/fieldmaskpb"
|
||||
|
||||
"github.com/milvus-io/milvus/pkg/v2/proto/streamingpb"
|
||||
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/merr"
|
||||
)
|
||||
|
||||
// balancerImpl exposes the balancer operations on top of an embedded
// *walAccesserImpl; its methods issue their calls through the embedded
// accesser's streamingCoordClient.
type balancerImpl struct {
|
||||
*walAccesserImpl
|
||||
}
|
||||
|
||||
// GetWALDistribution returns the wal distribution of the streaming node.
|
||||
func (b balancerImpl) ListStreamingNode(ctx context.Context) ([]types.StreamingNodeInfo, error) {
|
||||
assignments, err := b.streamingCoordClient.Assignment().GetLatestAssignments(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
nodes := make([]types.StreamingNodeInfo, 0, len(assignments.Assignments))
|
||||
for _, assignment := range assignments.Assignments {
|
||||
nodes = append(nodes, assignment.NodeInfo)
|
||||
}
|
||||
return nodes, nil
|
||||
}
|
||||
|
||||
// GetWALDistribution returns the wal distribution of the streaming node.
|
||||
func (b balancerImpl) GetWALDistribution(ctx context.Context, nodeID int64) (*types.StreamingNodeAssignment, error) {
|
||||
assignments, err := b.streamingCoordClient.Assignment().GetLatestAssignments(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, assignment := range assignments.Assignments {
|
||||
if assignment.NodeInfo.ServerID == nodeID {
|
||||
return &assignment, nil
|
||||
}
|
||||
}
|
||||
return nil, merr.WrapErrNodeNotFound(nodeID, "streaming node not found")
|
||||
}
|
||||
|
||||
func (b balancerImpl) SuspendRebalance(ctx context.Context) error {
|
||||
return b.streamingCoordClient.Assignment().UpdateWALBalancePolicy(ctx, &types.UpdateWALBalancePolicyRequest{
|
||||
Config: &streamingpb.WALBalancePolicyConfig{
|
||||
AllowRebalance: false,
|
||||
},
|
||||
UpdateMask: &fieldmaskpb.FieldMask{
|
||||
Paths: []string{types.UpdateMaskPathWALBalancePolicyAllowRebalance},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (b balancerImpl) ResumeRebalance(ctx context.Context) error {
|
||||
return b.streamingCoordClient.Assignment().UpdateWALBalancePolicy(ctx, &types.UpdateWALBalancePolicyRequest{
|
||||
Config: &streamingpb.WALBalancePolicyConfig{
|
||||
AllowRebalance: true,
|
||||
},
|
||||
UpdateMask: &fieldmaskpb.FieldMask{
|
||||
Paths: []string{types.UpdateMaskPathWALBalancePolicyAllowRebalance},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (b balancerImpl) FreezeNodeIDs(ctx context.Context, nodeIDs []int64) error {
|
||||
return b.streamingCoordClient.Assignment().UpdateWALBalancePolicy(ctx, &types.UpdateWALBalancePolicyRequest{
|
||||
UpdateMask: &fieldmaskpb.FieldMask{Paths: []string{}},
|
||||
Nodes: &streamingpb.WALBalancePolicyNodes{
|
||||
FreezeNodeIds: nodeIDs,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (b balancerImpl) DefreezeNodeIDs(ctx context.Context, nodeIDs []int64) error {
|
||||
return b.streamingCoordClient.Assignment().UpdateWALBalancePolicy(ctx, &types.UpdateWALBalancePolicyRequest{
|
||||
UpdateMask: &fieldmaskpb.FieldMask{Paths: []string{}},
|
||||
Nodes: &streamingpb.WALBalancePolicyNodes{
|
||||
DefreezeNodeIds: nodeIDs,
|
||||
},
|
||||
})
|
||||
}
|
||||
78
internal/distributed/streaming/balancer_test.go
Normal file
78
internal/distributed/streaming/balancer_test.go
Normal file
@ -0,0 +1,78 @@
|
||||
package streaming
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/mocks/streamingcoord/mock_client"
|
||||
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/merr"
|
||||
)
|
||||
|
||||
func TestBalancer(t *testing.T) {
|
||||
scClient := mock_client.NewMockClient(t)
|
||||
assignmentService := mock_client.NewMockAssignmentService(t)
|
||||
scClient.EXPECT().Assignment().Return(assignmentService)
|
||||
assignmentService.EXPECT().GetLatestAssignments(mock.Anything).Return(&types.VersionedStreamingNodeAssignments{
|
||||
Assignments: map[int64]types.StreamingNodeAssignment{
|
||||
1: {
|
||||
NodeInfo: types.StreamingNodeInfo{ServerID: 1},
|
||||
Channels: map[string]types.PChannelInfo{
|
||||
"v1": {},
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
|
||||
balancer := balancerImpl{
|
||||
walAccesserImpl: &walAccesserImpl{
|
||||
streamingCoordClient: scClient,
|
||||
},
|
||||
}
|
||||
|
||||
nodes, err := balancer.ListStreamingNode(context.Background())
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, len(nodes))
|
||||
assignment, err := balancer.GetWALDistribution(context.Background(), 1)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, len(assignment.Channels))
|
||||
|
||||
assignment, err = balancer.GetWALDistribution(context.Background(), 2)
|
||||
assert.True(t, errors.Is(err, merr.ErrNodeNotFound))
|
||||
assert.Nil(t, assignment)
|
||||
|
||||
assignmentService.EXPECT().GetLatestAssignments(mock.Anything).Unset()
|
||||
assignmentService.EXPECT().GetLatestAssignments(mock.Anything).Return(nil, errors.New("test"))
|
||||
nodes, err = balancer.ListStreamingNode(context.Background())
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, nodes)
|
||||
|
||||
assignment, err = balancer.GetWALDistribution(context.Background(), 1)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, assignment)
|
||||
|
||||
assignmentService.EXPECT().UpdateWALBalancePolicy(mock.Anything, mock.Anything).Return(nil)
|
||||
err = balancer.SuspendRebalance(context.Background())
|
||||
assert.NoError(t, err)
|
||||
err = balancer.ResumeRebalance(context.Background())
|
||||
assert.NoError(t, err)
|
||||
err = balancer.FreezeNodeIDs(context.Background(), []int64{1})
|
||||
assert.NoError(t, err)
|
||||
err = balancer.DefreezeNodeIDs(context.Background(), []int64{1})
|
||||
assert.NoError(t, err)
|
||||
|
||||
assignmentService.EXPECT().UpdateWALBalancePolicy(mock.Anything, mock.Anything).Unset()
|
||||
assignmentService.EXPECT().UpdateWALBalancePolicy(mock.Anything, mock.Anything).Return(errors.New("test"))
|
||||
err = balancer.SuspendRebalance(context.Background())
|
||||
assert.Error(t, err)
|
||||
err = balancer.ResumeRebalance(context.Background())
|
||||
assert.Error(t, err)
|
||||
err = balancer.FreezeNodeIDs(context.Background(), []int64{1})
|
||||
assert.Error(t, err)
|
||||
err = balancer.DefreezeNodeIDs(context.Background(), []int64{1})
|
||||
assert.Error(t, err)
|
||||
}
|
||||
@ -81,8 +81,33 @@ type Scanner interface {
|
||||
Close()
|
||||
}
|
||||
|
||||
// Balancer is the interface for inspecting and managing the balancer of the wal.
|
||||
type Balancer interface {
|
||||
// ListStreamingNode lists the info of the streaming nodes.
|
||||
ListStreamingNode(ctx context.Context) ([]types.StreamingNodeInfo, error)
|
||||
|
||||
// GetWALDistribution returns the wal distribution (assignment) of the
// streaming node identified by nodeID.
|
||||
GetWALDistribution(ctx context.Context, nodeID int64) (*types.StreamingNodeAssignment, error)
|
||||
|
||||
// SuspendRebalance suspends the rebalance of the wal.
|
||||
SuspendRebalance(ctx context.Context) error
|
||||
|
||||
// ResumeRebalance resumes the rebalance of the wal.
|
||||
ResumeRebalance(ctx context.Context) error
|
||||
|
||||
// FreezeNodeIDs freezes the given streaming nodes.
|
||||
// No wal will be assigned to a frozen node, and the wal currently on a frozen node will be removed from it.
|
||||
FreezeNodeIDs(ctx context.Context, nodeIDs []int64) error
|
||||
|
||||
// DefreezeNodeIDs defreezes the given streaming nodes.
|
||||
DefreezeNodeIDs(ctx context.Context, nodeIDs []int64) error
|
||||
}
|
||||
|
||||
// WALAccesser is the interfaces to interact with the milvus write ahead log.
|
||||
type WALAccesser interface {
|
||||
// Balancer returns the balancer management of the wal.
|
||||
Balancer() Balancer
|
||||
|
||||
// WALName returns the name of the wal.
|
||||
WALName() string
|
||||
|
||||
|
||||
@ -51,6 +51,32 @@ func SetupNoopWALForTest() {
|
||||
singleton = &noopWALAccesser{}
|
||||
}
|
||||
|
||||
// noopBalancer is a stateless Balancer stub: every method does nothing and
// returns zero values.
type noopBalancer struct{}
|
||||
|
||||
// ListStreamingNode is a no-op: it always returns a nil slice and a nil error.
func (n *noopBalancer) ListStreamingNode(ctx context.Context) ([]types.StreamingNodeInfo, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// GetWALDistribution is a no-op: it always returns a nil assignment and a nil
// error, so callers must not dereference the result without a nil check.
func (n *noopBalancer) GetWALDistribution(ctx context.Context, nodeID int64) (*types.StreamingNodeAssignment, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// SuspendRebalance is a no-op that always reports success.
func (n *noopBalancer) SuspendRebalance(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ResumeRebalance is a no-op that always reports success.
func (n *noopBalancer) ResumeRebalance(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// FreezeNodeIDs is a no-op: nodeIDs is ignored and nil is always returned.
func (n *noopBalancer) FreezeNodeIDs(ctx context.Context, nodeIDs []int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// DefreezeNodeIDs is a no-op: nodeIDs is ignored and nil is always returned.
func (n *noopBalancer) DefreezeNodeIDs(ctx context.Context, nodeIDs []int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type noopLocal struct{}
|
||||
|
||||
func (n *noopLocal) GetLatestMVCCTimestampIfLocal(ctx context.Context, vchannel string) (uint64, error) {
|
||||
@ -108,6 +134,10 @@ func (n *noopTxn) Rollback(ctx context.Context) error {
|
||||
|
||||
// noopWALAccesser is the stateless accesser that SetupNoopWALForTest installs
// as the package singleton.
type noopWALAccesser struct{}
|
||||
|
||||
func (n *noopWALAccesser) Balancer() Balancer {
|
||||
return &noopBalancer{}
|
||||
}
|
||||
|
||||
// WALName returns the fixed name "noop".
func (n *noopWALAccesser) WALName() string {
|
||||
return "noop"
|
||||
}
|
||||
|
||||
@ -59,6 +59,10 @@ type walAccesserImpl struct {
|
||||
dispatchExecutionPool *conc.Pool[struct{}]
|
||||
}
|
||||
|
||||
func (w *walAccesserImpl) Balancer() Balancer {
|
||||
return balancerImpl{w}
|
||||
}
|
||||
|
||||
// WALName returns the wal name chosen by util.MustSelectWALName.
// NOTE(review): the "Must" prefix suggests the helper panics when no wal name
// can be selected — confirm against the util package.
func (w *walAccesserImpl) WALName() string {
|
||||
return util.MustSelectWALName()
|
||||
}
|
||||
|
||||
@ -149,6 +149,53 @@ func (_c *MockWALAccesser_AppendMessagesWithOption_Call) RunAndReturn(run func(c
|
||||
return _c
|
||||
}
|
||||
|
||||
// Balancer provides a mock function with no fields
|
||||
func (_m *MockWALAccesser) Balancer() streaming.Balancer {
|
||||
ret := _m.Called()
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Balancer")
|
||||
}
|
||||
|
||||
var r0 streaming.Balancer
|
||||
if rf, ok := ret.Get(0).(func() streaming.Balancer); ok {
|
||||
r0 = rf()
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(streaming.Balancer)
|
||||
}
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// MockWALAccesser_Balancer_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Balancer'
|
||||
type MockWALAccesser_Balancer_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Balancer is a helper method to define mock.On call
|
||||
func (_e *MockWALAccesser_Expecter) Balancer() *MockWALAccesser_Balancer_Call {
|
||||
return &MockWALAccesser_Balancer_Call{Call: _e.mock.On("Balancer")}
|
||||
}
|
||||
|
||||
func (_c *MockWALAccesser_Balancer_Call) Run(run func()) *MockWALAccesser_Balancer_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run()
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockWALAccesser_Balancer_Call) Return(_a0 streaming.Balancer) *MockWALAccesser_Balancer_Call {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockWALAccesser_Balancer_Call) RunAndReturn(run func() streaming.Balancer) *MockWALAccesser_Balancer_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// Broadcast provides a mock function with no fields
|
||||
func (_m *MockWALAccesser) Broadcast() streaming.Broadcast {
|
||||
ret := _m.Called()
|
||||
|
||||
@ -0,0 +1,239 @@
|
||||
// Code generated by mockery v2.53.3. DO NOT EDIT.
|
||||
|
||||
package mock_client
|
||||
|
||||
import (
|
||||
context "context"
|
||||
|
||||
streamingpb "github.com/milvus-io/milvus/pkg/v2/proto/streamingpb"
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
|
||||
types "github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
|
||||
)
|
||||
|
||||
// MockAssignmentService is an autogenerated mock type for the AssignmentService type
|
||||
type MockAssignmentService struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
type MockAssignmentService_Expecter struct {
|
||||
mock *mock.Mock
|
||||
}
|
||||
|
||||
func (_m *MockAssignmentService) EXPECT() *MockAssignmentService_Expecter {
|
||||
return &MockAssignmentService_Expecter{mock: &_m.Mock}
|
||||
}
|
||||
|
||||
// AssignmentDiscover provides a mock function with given fields: ctx, cb
|
||||
func (_m *MockAssignmentService) AssignmentDiscover(ctx context.Context, cb func(*types.VersionedStreamingNodeAssignments) error) error {
|
||||
ret := _m.Called(ctx, cb)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for AssignmentDiscover")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, func(*types.VersionedStreamingNodeAssignments) error) error); ok {
|
||||
r0 = rf(ctx, cb)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// MockAssignmentService_AssignmentDiscover_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AssignmentDiscover'
|
||||
type MockAssignmentService_AssignmentDiscover_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// AssignmentDiscover is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - cb func(*types.VersionedStreamingNodeAssignments) error
|
||||
func (_e *MockAssignmentService_Expecter) AssignmentDiscover(ctx interface{}, cb interface{}) *MockAssignmentService_AssignmentDiscover_Call {
|
||||
return &MockAssignmentService_AssignmentDiscover_Call{Call: _e.mock.On("AssignmentDiscover", ctx, cb)}
|
||||
}
|
||||
|
||||
func (_c *MockAssignmentService_AssignmentDiscover_Call) Run(run func(ctx context.Context, cb func(*types.VersionedStreamingNodeAssignments) error)) *MockAssignmentService_AssignmentDiscover_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(context.Context), args[1].(func(*types.VersionedStreamingNodeAssignments) error))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockAssignmentService_AssignmentDiscover_Call) Return(_a0 error) *MockAssignmentService_AssignmentDiscover_Call {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockAssignmentService_AssignmentDiscover_Call) RunAndReturn(run func(context.Context, func(*types.VersionedStreamingNodeAssignments) error) error) *MockAssignmentService_AssignmentDiscover_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// GetLatestAssignments provides a mock function with given fields: ctx
|
||||
func (_m *MockAssignmentService) GetLatestAssignments(ctx context.Context) (*types.VersionedStreamingNodeAssignments, error) {
|
||||
ret := _m.Called(ctx)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for GetLatestAssignments")
|
||||
}
|
||||
|
||||
var r0 *types.VersionedStreamingNodeAssignments
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context) (*types.VersionedStreamingNodeAssignments, error)); ok {
|
||||
return rf(ctx)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context) *types.VersionedStreamingNodeAssignments); ok {
|
||||
r0 = rf(ctx)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(*types.VersionedStreamingNodeAssignments)
|
||||
}
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context) error); ok {
|
||||
r1 = rf(ctx)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// MockAssignmentService_GetLatestAssignments_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetLatestAssignments'
|
||||
type MockAssignmentService_GetLatestAssignments_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// GetLatestAssignments is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
func (_e *MockAssignmentService_Expecter) GetLatestAssignments(ctx interface{}) *MockAssignmentService_GetLatestAssignments_Call {
|
||||
return &MockAssignmentService_GetLatestAssignments_Call{Call: _e.mock.On("GetLatestAssignments", ctx)}
|
||||
}
|
||||
|
||||
func (_c *MockAssignmentService_GetLatestAssignments_Call) Run(run func(ctx context.Context)) *MockAssignmentService_GetLatestAssignments_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(context.Context))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockAssignmentService_GetLatestAssignments_Call) Return(_a0 *types.VersionedStreamingNodeAssignments, _a1 error) *MockAssignmentService_GetLatestAssignments_Call {
|
||||
_c.Call.Return(_a0, _a1)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockAssignmentService_GetLatestAssignments_Call) RunAndReturn(run func(context.Context) (*types.VersionedStreamingNodeAssignments, error)) *MockAssignmentService_GetLatestAssignments_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// ReportAssignmentError provides a mock function with given fields: ctx, pchannel, err
|
||||
func (_m *MockAssignmentService) ReportAssignmentError(ctx context.Context, pchannel types.PChannelInfo, err error) error {
|
||||
ret := _m.Called(ctx, pchannel, err)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for ReportAssignmentError")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, types.PChannelInfo, error) error); ok {
|
||||
r0 = rf(ctx, pchannel, err)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// MockAssignmentService_ReportAssignmentError_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ReportAssignmentError'
|
||||
type MockAssignmentService_ReportAssignmentError_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// ReportAssignmentError is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - pchannel types.PChannelInfo
|
||||
// - err error
|
||||
func (_e *MockAssignmentService_Expecter) ReportAssignmentError(ctx interface{}, pchannel interface{}, err interface{}) *MockAssignmentService_ReportAssignmentError_Call {
|
||||
return &MockAssignmentService_ReportAssignmentError_Call{Call: _e.mock.On("ReportAssignmentError", ctx, pchannel, err)}
|
||||
}
|
||||
|
||||
func (_c *MockAssignmentService_ReportAssignmentError_Call) Run(run func(ctx context.Context, pchannel types.PChannelInfo, err error)) *MockAssignmentService_ReportAssignmentError_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(context.Context), args[1].(types.PChannelInfo), args[2].(error))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockAssignmentService_ReportAssignmentError_Call) Return(_a0 error) *MockAssignmentService_ReportAssignmentError_Call {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockAssignmentService_ReportAssignmentError_Call) RunAndReturn(run func(context.Context, types.PChannelInfo, error) error) *MockAssignmentService_ReportAssignmentError_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// UpdateWALBalancePolicy provides a mock function with given fields: ctx, req
|
||||
func (_m *MockAssignmentService) UpdateWALBalancePolicy(ctx context.Context, req *streamingpb.UpdateWALBalancePolicyRequest) error {
|
||||
ret := _m.Called(ctx, req)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for UpdateWALBalancePolicy")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *streamingpb.UpdateWALBalancePolicyRequest) error); ok {
|
||||
r0 = rf(ctx, req)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// MockAssignmentService_UpdateWALBalancePolicy_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateWALBalancePolicy'
|
||||
type MockAssignmentService_UpdateWALBalancePolicy_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// UpdateWALBalancePolicy is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - req *streamingpb.UpdateWALBalancePolicyRequest
|
||||
func (_e *MockAssignmentService_Expecter) UpdateWALBalancePolicy(ctx interface{}, req interface{}) *MockAssignmentService_UpdateWALBalancePolicy_Call {
|
||||
return &MockAssignmentService_UpdateWALBalancePolicy_Call{Call: _e.mock.On("UpdateWALBalancePolicy", ctx, req)}
|
||||
}
|
||||
|
||||
func (_c *MockAssignmentService_UpdateWALBalancePolicy_Call) Run(run func(ctx context.Context, req *streamingpb.UpdateWALBalancePolicyRequest)) *MockAssignmentService_UpdateWALBalancePolicy_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(context.Context), args[1].(*streamingpb.UpdateWALBalancePolicyRequest))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockAssignmentService_UpdateWALBalancePolicy_Call) Return(_a0 error) *MockAssignmentService_UpdateWALBalancePolicy_Call {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockAssignmentService_UpdateWALBalancePolicy_Call) RunAndReturn(run func(context.Context, *streamingpb.UpdateWALBalancePolicyRequest) error) *MockAssignmentService_UpdateWALBalancePolicy_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// NewMockAssignmentService creates a new instance of MockAssignmentService. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
// The first argument is typically a *testing.T value.
|
||||
func NewMockAssignmentService(t interface {
|
||||
mock.TestingT
|
||||
Cleanup(func())
|
||||
}) *MockAssignmentService {
|
||||
mock := &MockAssignmentService{}
|
||||
mock.Mock.Test(t)
|
||||
|
||||
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||
|
||||
return mock
|
||||
}
|
||||
@ -5,9 +5,11 @@ package mock_balancer
|
||||
import (
|
||||
context "context"
|
||||
|
||||
syncutil "github.com/milvus-io/milvus/pkg/v2/util/syncutil"
|
||||
streamingpb "github.com/milvus-io/milvus/pkg/v2/proto/streamingpb"
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
|
||||
syncutil "github.com/milvus-io/milvus/pkg/v2/util/syncutil"
|
||||
|
||||
types "github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
|
||||
|
||||
typeutil "github.com/milvus-io/milvus/pkg/v2/util/typeutil"
|
||||
@ -241,6 +243,53 @@ func (_c *MockBalancer_Trigger_Call) RunAndReturn(run func(context.Context) erro
|
||||
return _c
|
||||
}
|
||||
|
||||
// UpdateBalancePolicy provides a mock function with given fields: ctx, req
|
||||
func (_m *MockBalancer) UpdateBalancePolicy(ctx context.Context, req *streamingpb.UpdateWALBalancePolicyRequest) error {
|
||||
ret := _m.Called(ctx, req)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for UpdateBalancePolicy")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *streamingpb.UpdateWALBalancePolicyRequest) error); ok {
|
||||
r0 = rf(ctx, req)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// MockBalancer_UpdateBalancePolicy_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateBalancePolicy'
|
||||
type MockBalancer_UpdateBalancePolicy_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// UpdateBalancePolicy is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - req *streamingpb.UpdateWALBalancePolicyRequest
|
||||
func (_e *MockBalancer_Expecter) UpdateBalancePolicy(ctx interface{}, req interface{}) *MockBalancer_UpdateBalancePolicy_Call {
|
||||
return &MockBalancer_UpdateBalancePolicy_Call{Call: _e.mock.On("UpdateBalancePolicy", ctx, req)}
|
||||
}
|
||||
|
||||
func (_c *MockBalancer_UpdateBalancePolicy_Call) Run(run func(ctx context.Context, req *streamingpb.UpdateWALBalancePolicyRequest)) *MockBalancer_UpdateBalancePolicy_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(context.Context), args[1].(*streamingpb.UpdateWALBalancePolicyRequest))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockBalancer_UpdateBalancePolicy_Call) Return(_a0 error) *MockBalancer_UpdateBalancePolicy_Call {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockBalancer_UpdateBalancePolicy_Call) RunAndReturn(run func(context.Context, *streamingpb.UpdateWALBalancePolicyRequest) error) *MockBalancer_UpdateBalancePolicy_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// WatchChannelAssignments provides a mock function with given fields: ctx, cb
|
||||
func (_m *MockBalancer) WatchChannelAssignments(ctx context.Context, cb func(typeutil.VersionInt64Pair, []types.PChannelInfoAssigned) error) error {
|
||||
ret := _m.Called(ctx, cb)
|
||||
|
||||
@ -47,6 +47,30 @@ type AssignmentServiceImpl struct {
|
||||
logger *log.MLogger
|
||||
}
|
||||
|
||||
// GetLatestAssignments returns the latest assignment discovery result.
|
||||
func (c *AssignmentServiceImpl) GetLatestAssignments(ctx context.Context) (*types.VersionedStreamingNodeAssignments, error) {
|
||||
if !c.lifetime.Add(typeutil.LifetimeStateWorking) {
|
||||
return nil, status.NewOnShutdownError("assignment service client is closing")
|
||||
}
|
||||
defer c.lifetime.Done()
|
||||
|
||||
return c.watcher.GetLatestDiscover(ctx)
|
||||
}
|
||||
|
||||
func (c *AssignmentServiceImpl) UpdateWALBalancePolicy(ctx context.Context, req *streamingpb.UpdateWALBalancePolicyRequest) error {
|
||||
if !c.lifetime.Add(typeutil.LifetimeStateWorking) {
|
||||
return status.NewOnShutdownError("assignment service client is closing")
|
||||
}
|
||||
defer c.lifetime.Done()
|
||||
|
||||
service, err := c.service.GetService(c.ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = service.UpdateWALBalancePolicy(ctx, req)
|
||||
return err
|
||||
}
|
||||
|
||||
// AssignmentDiscover watches the assignment discovery.
|
||||
func (c *AssignmentServiceImpl) AssignmentDiscover(ctx context.Context, cb func(*types.VersionedStreamingNodeAssignments) error) error {
|
||||
if !c.lifetime.Add(typeutil.LifetimeStateWorking) {
|
||||
|
||||
@ -24,6 +24,7 @@ func TestAssignmentService(t *testing.T) {
|
||||
s.EXPECT().GetService(mock.Anything).Return(c, nil)
|
||||
cc := mock_streamingpb.NewMockStreamingCoordAssignmentService_AssignmentDiscoverClient(t)
|
||||
c.EXPECT().AssignmentDiscover(mock.Anything).Return(cc, nil)
|
||||
c.EXPECT().UpdateWALBalancePolicy(mock.Anything, mock.Anything).Return(&streamingpb.UpdateWALBalancePolicyResponse{}, nil)
|
||||
k := 0
|
||||
closeCh := make(chan struct{})
|
||||
cc.EXPECT().Send(mock.Anything).Return(nil)
|
||||
@ -93,6 +94,13 @@ func TestAssignmentService(t *testing.T) {
|
||||
assert.ErrorIs(t, err, context.DeadlineExceeded)
|
||||
assert.True(t, finalAssignments.Version.EQ(typeutil.VersionInt64Pair{Global: 2, Local: 3}))
|
||||
|
||||
assign, err := assignmentService.GetLatestAssignments(ctx)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, assign.Version.EQ(typeutil.VersionInt64Pair{Global: 2, Local: 3}))
|
||||
|
||||
err = assignmentService.UpdateWALBalancePolicy(ctx, &streamingpb.UpdateWALBalancePolicyRequest{})
|
||||
assert.NoError(t, err)
|
||||
|
||||
assignmentService.ReportAssignmentError(ctx, types.PChannelInfo{Name: "c1", Term: 1}, errors.New("test"))
|
||||
|
||||
// Repeated report error at the same term should be ignored.
|
||||
@ -114,4 +122,12 @@ func TestAssignmentService(t *testing.T) {
|
||||
err = assignmentService.ReportAssignmentError(ctx, types.PChannelInfo{Name: "c1", Term: 1}, errors.New("test"))
|
||||
se = status.AsStreamingError(err)
|
||||
assert.Equal(t, streamingpb.StreamingCode_STREAMING_CODE_ON_SHUTDOWN, se.Code)
|
||||
|
||||
assignmentService.GetLatestAssignments(ctx)
|
||||
se = status.AsStreamingError(err)
|
||||
assert.Equal(t, streamingpb.StreamingCode_STREAMING_CODE_ON_SHUTDOWN, se.Code)
|
||||
|
||||
assignmentService.UpdateWALBalancePolicy(ctx, &streamingpb.UpdateWALBalancePolicyRequest{})
|
||||
se = status.AsStreamingError(err)
|
||||
assert.Equal(t, streamingpb.StreamingCode_STREAMING_CODE_ON_SHUTDOWN, se.Code)
|
||||
}
|
||||
|
||||
@ -30,6 +30,18 @@ type watcher struct {
|
||||
lastVersionedAssignment types.VersionedStreamingNodeAssignments
|
||||
}
|
||||
|
||||
func (w *watcher) GetLatestDiscover(ctx context.Context) (*types.VersionedStreamingNodeAssignments, error) {
|
||||
w.cond.L.Lock()
|
||||
for w.lastVersionedAssignment.Version.Global == -1 && w.lastVersionedAssignment.Version.Local == -1 {
|
||||
if err := w.cond.Wait(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
last := w.lastVersionedAssignment
|
||||
w.cond.L.Unlock()
|
||||
return &last, nil
|
||||
}
|
||||
|
||||
// AssignmentDiscover watches the assignment discovery.
|
||||
func (w *watcher) AssignmentDiscover(ctx context.Context, cb func(*types.VersionedStreamingNodeAssignments) error) error {
|
||||
w.cond.L.Lock()
|
||||
|
||||
@ -32,6 +32,12 @@ var _ Client = (*clientImpl)(nil)
|
||||
// AssignmentService is the interface of assignment service.
type AssignmentService interface {
|
||||
// AssignmentDiscoverWatcher is embedded to watch the assignment discovery.
|
||||
types.AssignmentDiscoverWatcher
|
||||
|
||||
// GetLatestAssignments returns the latest assignment discovery result.
|
||||
GetLatestAssignments(ctx context.Context) (*types.VersionedStreamingNodeAssignments, error)
|
||||
|
||||
// UpdateWALBalancePolicy is used to update the WAL balance policy.
|
||||
UpdateWALBalancePolicy(ctx context.Context, req *streamingpb.UpdateWALBalancePolicyRequest) error
|
||||
}
|
||||
|
||||
// BroadcastService is the interface of broadcast service.
|
||||
|
||||
@ -5,6 +5,7 @@ import (
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
|
||||
"github.com/milvus-io/milvus/pkg/v2/proto/streamingpb"
|
||||
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/syncutil"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
|
||||
@ -20,6 +21,9 @@ var (
|
||||
// Balancer is a local component, it should promise all channel can be assigned, and reach the final consistency.
|
||||
// Balancer should be thread safe.
|
||||
type Balancer interface {
|
||||
// UpdateBalancePolicy update the balance policy.
|
||||
UpdateBalancePolicy(ctx context.Context, req *streamingpb.UpdateWALBalancePolicyRequest) error
|
||||
|
||||
// RegisterStreamingEnabledNotifier registers a notifier into the balancer.
|
||||
// If the error is returned, the balancer is closed.
|
||||
// Otherwise, the following rules are applied:
|
||||
|
||||
@ -50,6 +50,7 @@ func RecoverBalancer(
|
||||
policy: policy,
|
||||
reqCh: make(chan *request, 5),
|
||||
backgroundTaskNotifier: syncutil.NewAsyncTaskNotifier[struct{}](),
|
||||
freezeNodes: typeutil.NewSet[int64](),
|
||||
}
|
||||
b.SetLogger(logger)
|
||||
ready260Future, err := b.checkIfAllNodeGreaterThan260AndWatch(ctx)
|
||||
@ -71,6 +72,7 @@ type balancerImpl struct {
|
||||
policy Policy // policy is the balance policy, TODO: should be dynamic in future.
|
||||
reqCh chan *request // reqCh is the request channel, send the operation to background task.
|
||||
backgroundTaskNotifier *syncutil.AsyncTaskNotifier[struct{}] // backgroundTaskNotifier is used to communicate with the background task.
|
||||
freezeNodes typeutil.Set[int64] // freezeNodes is the nodes that will be frozen, no more wal will be assigned to these nodes and wal will be removed from these nodes.
|
||||
}
|
||||
|
||||
// RegisterStreamingEnabledNotifier registers a notifier into the balancer.
|
||||
@ -95,6 +97,19 @@ func (b *balancerImpl) WatchChannelAssignments(ctx context.Context, cb func(vers
|
||||
return b.channelMetaManager.WatchAssignmentResult(ctx, cb)
|
||||
}
|
||||
|
||||
// UpdateBalancePolicy update the balance policy.
|
||||
func (b *balancerImpl) UpdateBalancePolicy(ctx context.Context, req *types.UpdateWALBalancePolicyRequest) error {
|
||||
if !b.lifetime.Add(typeutil.LifetimeStateWorking) {
|
||||
return status.NewOnShutdownError("balancer is closing")
|
||||
}
|
||||
defer b.lifetime.Done()
|
||||
|
||||
ctx, cancel := contextutil.MergeContext(ctx, b.ctx)
|
||||
defer cancel()
|
||||
return b.sendRequestAndWaitFinish(ctx, newOpUpdateBalancePolicy(ctx, req))
|
||||
}
|
||||
|
||||
// MarkAsUnavailable mark the pchannels as unavailable.
|
||||
func (b *balancerImpl) MarkAsUnavailable(ctx context.Context, pChannels []types.PChannelInfo) error {
|
||||
if !b.lifetime.Add(typeutil.LifetimeStateWorking) {
|
||||
return status.NewOnShutdownError("balancer is closing")
|
||||
@ -330,9 +345,9 @@ func (b *balancerImpl) balance(ctx context.Context) (bool, error) {
|
||||
pchannelView := b.channelMetaManager.CurrentPChannelsView()
|
||||
|
||||
b.Logger().Info("collect all status...")
|
||||
nodeStatus, err := resource.Resource().StreamingNodeManagerClient().CollectAllStatus(ctx)
|
||||
nodeStatus, err := b.fetchStreamingNodeStatus(ctx)
|
||||
if err != nil {
|
||||
return false, errors.Wrap(err, "fail to collect all status")
|
||||
return false, err
|
||||
}
|
||||
|
||||
// call the balance strategy to generate the expected layout.
|
||||
@ -360,6 +375,29 @@ func (b *balancerImpl) balance(ctx context.Context) (bool, error) {
|
||||
return true, b.applyBalanceResultToStreamingNode(ctx, modifiedChannels)
|
||||
}
|
||||
|
||||
// fetchStreamingNodeStatus fetch the streaming node status.
|
||||
func (b *balancerImpl) fetchStreamingNodeStatus(ctx context.Context) (map[int64]*types.StreamingNodeStatus, error) {
|
||||
nodeStatus, err := resource.Resource().StreamingNodeManagerClient().CollectAllStatus(ctx)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "fail to collect all status")
|
||||
}
|
||||
|
||||
// mark the frozen node as frozen in the node status.
|
||||
for _, node := range nodeStatus {
|
||||
if b.freezeNodes.Contain(node.ServerID) && node.IsHealthy() {
|
||||
node.Err = types.ErrFrozen
|
||||
}
|
||||
}
|
||||
// clean up the freeze node that has been removed from session.
|
||||
for serverID := range b.freezeNodes {
|
||||
if _, ok := nodeStatus[serverID]; !ok {
|
||||
b.Logger().Info("freeze node has been removed from session", zap.Int64("serverID", serverID))
|
||||
b.freezeNodes.Remove(serverID)
|
||||
}
|
||||
}
|
||||
return nodeStatus, nil
|
||||
}
|
||||
|
||||
// applyBalanceResultToStreamingNode apply the balance result to streaming node.
|
||||
func (b *balancerImpl) applyBalanceResultToStreamingNode(ctx context.Context, modifiedChannels map[types.ChannelID]*channel.PChannelMeta) error {
|
||||
b.Logger().Info("balance result need to be applied...", zap.Int("modifiedChannelCount", len(modifiedChannels)))
|
||||
|
||||
@ -11,6 +11,7 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"go.uber.org/atomic"
|
||||
"google.golang.org/protobuf/types/known/fieldmaskpb"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/mocks/mock_metastore"
|
||||
"github.com/milvus-io/milvus/internal/mocks/streamingnode/client/mock_manager"
|
||||
@ -18,10 +19,10 @@ import (
|
||||
"github.com/milvus-io/milvus/internal/streamingcoord/server/balancer/channel"
|
||||
_ "github.com/milvus-io/milvus/internal/streamingcoord/server/balancer/policy"
|
||||
"github.com/milvus-io/milvus/internal/streamingcoord/server/resource"
|
||||
kvfactory "github.com/milvus-io/milvus/internal/util/dependency/kv"
|
||||
"github.com/milvus-io/milvus/internal/util/sessionutil"
|
||||
"github.com/milvus-io/milvus/pkg/v2/proto/streamingpb"
|
||||
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/etcd"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/syncutil"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/tsoutil"
|
||||
@ -30,12 +31,7 @@ import (
|
||||
|
||||
func TestBalancer(t *testing.T) {
|
||||
paramtable.Init()
|
||||
err := etcd.InitEtcdServer(true, "", t.TempDir(), "stdout", "info")
|
||||
assert.NoError(t, err)
|
||||
defer etcd.StopEtcdServer()
|
||||
|
||||
etcdClient, err := etcd.GetEmbedEtcdClient()
|
||||
assert.NoError(t, err)
|
||||
etcdClient, _ := kvfactory.GetEtcdAndPath()
|
||||
channel.ResetStaticPChannelStatsManager()
|
||||
channel.RecoverPChannelStatsManager([]string{})
|
||||
|
||||
@ -184,18 +180,65 @@ func TestBalancer(t *testing.T) {
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
assert.False(t, f.Ready())
|
||||
|
||||
assert.True(t, paramtable.Get().StreamingCfg.WALBalancerPolicyAllowRebalance.GetAsBool())
|
||||
err = b.UpdateBalancePolicy(ctx, &streamingpb.UpdateWALBalancePolicyRequest{
|
||||
Config: &streamingpb.WALBalancePolicyConfig{
|
||||
AllowRebalance: false,
|
||||
},
|
||||
Nodes: &streamingpb.WALBalancePolicyNodes{
|
||||
FreezeNodeIds: []int64{1},
|
||||
DefreezeNodeIds: []int64{},
|
||||
},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, paramtable.Get().StreamingCfg.WALBalancerPolicyAllowRebalance.GetAsBool())
|
||||
b.Trigger(ctx)
|
||||
err = b.WatchChannelAssignments(ctx, func(version typeutil.VersionInt64Pair, relations []types.PChannelInfoAssigned) error {
|
||||
for _, relation := range relations {
|
||||
if relation.Node.ServerID == 1 {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return doneErr
|
||||
})
|
||||
assert.ErrorIs(t, err, doneErr)
|
||||
|
||||
err = b.UpdateBalancePolicy(ctx, &streamingpb.UpdateWALBalancePolicyRequest{
|
||||
Config: &streamingpb.WALBalancePolicyConfig{
|
||||
AllowRebalance: true,
|
||||
},
|
||||
UpdateMask: &fieldmaskpb.FieldMask{
|
||||
Paths: []string{types.UpdateMaskPathWALBalancePolicyAllowRebalance},
|
||||
},
|
||||
})
|
||||
assert.True(t, paramtable.Get().StreamingCfg.WALBalancerPolicyAllowRebalance.GetAsBool())
|
||||
assert.NoError(t, err)
|
||||
b.Trigger(ctx)
|
||||
|
||||
err = b.UpdateBalancePolicy(ctx, &streamingpb.UpdateWALBalancePolicyRequest{
|
||||
Config: &streamingpb.WALBalancePolicyConfig{
|
||||
AllowRebalance: false,
|
||||
},
|
||||
UpdateMask: &fieldmaskpb.FieldMask{
|
||||
Paths: []string{},
|
||||
},
|
||||
Nodes: &streamingpb.WALBalancePolicyNodes{
|
||||
FreezeNodeIds: []int64{},
|
||||
DefreezeNodeIds: []int64{1},
|
||||
},
|
||||
})
|
||||
assert.True(t, paramtable.Get().StreamingCfg.WALBalancerPolicyAllowRebalance.GetAsBool())
|
||||
assert.NoError(t, err)
|
||||
b.Trigger(ctx)
|
||||
|
||||
b.Close()
|
||||
assert.ErrorIs(t, f.Get(), balancer.ErrBalancerClosed)
|
||||
}
|
||||
|
||||
func TestBalancer_WithRecoveryLag(t *testing.T) {
|
||||
paramtable.Init()
|
||||
err := etcd.InitEtcdServer(true, "", t.TempDir(), "stdout", "info")
|
||||
assert.NoError(t, err)
|
||||
defer etcd.StopEtcdServer()
|
||||
|
||||
etcdClient, err := etcd.GetEmbedEtcdClient()
|
||||
assert.NoError(t, err)
|
||||
etcdClient, _ := kvfactory.GetEtcdAndPath()
|
||||
channel.ResetStaticPChannelStatsManager()
|
||||
channel.RecoverPChannelStatsManager([]string{})
|
||||
|
||||
|
||||
@ -2,8 +2,12 @@ package balancer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strconv"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/syncutil"
|
||||
)
|
||||
|
||||
@ -17,6 +21,42 @@ type request struct {
|
||||
// requestApply is a request operation to be executed.
|
||||
type requestApply func(impl *balancerImpl)
|
||||
|
||||
// newOpUpdateBalancePolicy is a operation to update the balance policy.
|
||||
func newOpUpdateBalancePolicy(ctx context.Context, req *types.UpdateWALBalancePolicyRequest) *request {
|
||||
future := syncutil.NewFuture[error]()
|
||||
return &request{
|
||||
ctx: ctx,
|
||||
apply: func(impl *balancerImpl) {
|
||||
if req.UpdateMask != nil {
|
||||
// if there's a update mask, only update the fields in the update mask.
|
||||
for _, field := range req.UpdateMask.Paths {
|
||||
switch field {
|
||||
case types.UpdateMaskPathWALBalancePolicyAllowRebalance:
|
||||
updateAllowRebalance(impl, req.GetConfig().GetAllowRebalance())
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// otherwise update all fields.
|
||||
updateAllowRebalance(impl, req.GetConfig().GetAllowRebalance())
|
||||
}
|
||||
// apply the freeze streaming nodes.
|
||||
if len(req.GetNodes().GetFreezeNodeIds()) > 0 || len(req.GetNodes().GetDefreezeNodeIds()) > 0 {
|
||||
impl.Logger().Info("update freeze nodes", zap.Int64s("freezeNodeIDs", req.GetNodes().GetFreezeNodeIds()), zap.Int64s("defreezeNodeIDs", req.GetNodes().GetDefreezeNodeIds()))
|
||||
impl.freezeNodes.Insert(req.GetNodes().GetFreezeNodeIds()...)
|
||||
impl.freezeNodes.Remove(req.GetNodes().GetDefreezeNodeIds()...)
|
||||
}
|
||||
future.Set(nil)
|
||||
},
|
||||
future: future,
|
||||
}
|
||||
}
|
||||
|
||||
// updateAllowRebalance update the allow rebalance.
|
||||
func updateAllowRebalance(impl *balancerImpl, allowRebalance bool) {
|
||||
old := paramtable.Get().StreamingCfg.WALBalancerPolicyAllowRebalance.SwapTempValue(strconv.FormatBool(allowRebalance))
|
||||
impl.Logger().Info("update allow_rebalance", zap.Bool("new", allowRebalance), zap.String("old", old))
|
||||
}
|
||||
|
||||
// newOpMarkAsUnavailable is an operation to mark some channels as unavailable.
|
||||
func newOpMarkAsUnavailable(ctx context.Context, pChannels []types.PChannelInfo) *request {
|
||||
future := syncutil.NewFuture[error]()
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/streamingcoord/server/balancer"
|
||||
@ -44,3 +46,16 @@ func (s *assignmentServiceImpl) AssignmentDiscover(server streamingpb.StreamingC
|
||||
}
|
||||
return discover.NewAssignmentDiscoverServer(balancer, server).Execute()
|
||||
}
|
||||
|
||||
// UpdateWALBalancePolicy is used to update the WAL balance policy.
|
||||
func (s *assignmentServiceImpl) UpdateWALBalancePolicy(ctx context.Context, req *streamingpb.UpdateWALBalancePolicyRequest) (*streamingpb.UpdateWALBalancePolicyResponse, error) {
|
||||
balancer, err := s.balancer.GetWithContext(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = balancer.UpdateBalancePolicy(ctx, req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &streamingpb.UpdateWALBalancePolicyResponse{}, nil
|
||||
}
|
||||
|
||||
@ -98,6 +98,80 @@ func (_c *MockStreamingCoordAssignmentServiceClient_AssignmentDiscover_Call) Run
|
||||
return _c
|
||||
}
|
||||
|
||||
// UpdateWALBalancePolicy provides a mock function with given fields: ctx, in, opts
|
||||
func (_m *MockStreamingCoordAssignmentServiceClient) UpdateWALBalancePolicy(ctx context.Context, in *streamingpb.UpdateWALBalancePolicyRequest, opts ...grpc.CallOption) (*streamingpb.UpdateWALBalancePolicyResponse, error) {
|
||||
_va := make([]interface{}, len(opts))
|
||||
for _i := range opts {
|
||||
_va[_i] = opts[_i]
|
||||
}
|
||||
var _ca []interface{}
|
||||
_ca = append(_ca, ctx, in)
|
||||
_ca = append(_ca, _va...)
|
||||
ret := _m.Called(_ca...)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for UpdateWALBalancePolicy")
|
||||
}
|
||||
|
||||
var r0 *streamingpb.UpdateWALBalancePolicyResponse
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *streamingpb.UpdateWALBalancePolicyRequest, ...grpc.CallOption) (*streamingpb.UpdateWALBalancePolicyResponse, error)); ok {
|
||||
return rf(ctx, in, opts...)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *streamingpb.UpdateWALBalancePolicyRequest, ...grpc.CallOption) *streamingpb.UpdateWALBalancePolicyResponse); ok {
|
||||
r0 = rf(ctx, in, opts...)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(*streamingpb.UpdateWALBalancePolicyResponse)
|
||||
}
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, *streamingpb.UpdateWALBalancePolicyRequest, ...grpc.CallOption) error); ok {
|
||||
r1 = rf(ctx, in, opts...)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// MockStreamingCoordAssignmentServiceClient_UpdateWALBalancePolicy_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateWALBalancePolicy'
|
||||
type MockStreamingCoordAssignmentServiceClient_UpdateWALBalancePolicy_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// UpdateWALBalancePolicy is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - in *streamingpb.UpdateWALBalancePolicyRequest
|
||||
// - opts ...grpc.CallOption
|
||||
func (_e *MockStreamingCoordAssignmentServiceClient_Expecter) UpdateWALBalancePolicy(ctx interface{}, in interface{}, opts ...interface{}) *MockStreamingCoordAssignmentServiceClient_UpdateWALBalancePolicy_Call {
|
||||
return &MockStreamingCoordAssignmentServiceClient_UpdateWALBalancePolicy_Call{Call: _e.mock.On("UpdateWALBalancePolicy",
|
||||
append([]interface{}{ctx, in}, opts...)...)}
|
||||
}
|
||||
|
||||
func (_c *MockStreamingCoordAssignmentServiceClient_UpdateWALBalancePolicy_Call) Run(run func(ctx context.Context, in *streamingpb.UpdateWALBalancePolicyRequest, opts ...grpc.CallOption)) *MockStreamingCoordAssignmentServiceClient_UpdateWALBalancePolicy_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
variadicArgs := make([]grpc.CallOption, len(args)-2)
|
||||
for i, a := range args[2:] {
|
||||
if a != nil {
|
||||
variadicArgs[i] = a.(grpc.CallOption)
|
||||
}
|
||||
}
|
||||
run(args[0].(context.Context), args[1].(*streamingpb.UpdateWALBalancePolicyRequest), variadicArgs...)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockStreamingCoordAssignmentServiceClient_UpdateWALBalancePolicy_Call) Return(_a0 *streamingpb.UpdateWALBalancePolicyResponse, _a1 error) *MockStreamingCoordAssignmentServiceClient_UpdateWALBalancePolicy_Call {
|
||||
_c.Call.Return(_a0, _a1)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockStreamingCoordAssignmentServiceClient_UpdateWALBalancePolicy_Call) RunAndReturn(run func(context.Context, *streamingpb.UpdateWALBalancePolicyRequest, ...grpc.CallOption) (*streamingpb.UpdateWALBalancePolicyResponse, error)) *MockStreamingCoordAssignmentServiceClient_UpdateWALBalancePolicy_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// NewMockStreamingCoordAssignmentServiceClient creates a new instance of MockStreamingCoordAssignmentServiceClient. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
// The first argument is typically a *testing.T value.
|
||||
func NewMockStreamingCoordAssignmentServiceClient(t interface {
|
||||
|
||||
@ -10,6 +10,7 @@ import "milvus.proto";
|
||||
import "schema.proto";
|
||||
import "google/protobuf/empty.proto";
|
||||
import "google/protobuf/any.proto";
|
||||
import "google/protobuf/field_mask.proto";
|
||||
|
||||
//
|
||||
// Common
|
||||
@ -143,6 +144,10 @@ message BroadcastAckResponse {
|
||||
// Server: log coord. Running on every log node.
|
||||
// Client: all log publish/consuming node.
|
||||
service StreamingCoordAssignmentService {
|
||||
// UpdateWALBalancePolicy is used to update the WAL balance policy.
|
||||
// The policy is used to control the balance of the WAL.
|
||||
rpc UpdateWALBalancePolicy(UpdateWALBalancePolicyRequest) returns (UpdateWALBalancePolicyResponse) {};
|
||||
|
||||
// AssignmentDiscover is used to discover all log nodes managed by the
|
||||
// streamingcoord. Channel assignment information will be pushed to client
|
||||
// by stream.
|
||||
@ -150,6 +155,24 @@ service StreamingCoordAssignmentService {
|
||||
returns (stream AssignmentDiscoverResponse) {}
|
||||
}
|
||||
|
||||
// UpdateWALBalancePolicyRequest is the request to update the WAL balance policy.
|
||||
message UpdateWALBalancePolicyRequest {
|
||||
WALBalancePolicyConfig config = 1;
|
||||
WALBalancePolicyNodes nodes = 2;
|
||||
google.protobuf.FieldMask update_mask = 3;
|
||||
}
|
||||
|
||||
message WALBalancePolicyConfig {
|
||||
bool allow_rebalance = 1;
|
||||
}
|
||||
|
||||
message WALBalancePolicyNodes {
|
||||
repeated int64 freeze_node_ids = 1; // nodes that will be frozen.
|
||||
repeated int64 defreeze_node_ids = 2; // nodes that will be defrozen.
|
||||
}
|
||||
|
||||
message UpdateWALBalancePolicyResponse {}
|
||||
|
||||
// AssignmentDiscoverRequest is the request of Discovery
|
||||
message AssignmentDiscoverRequest {
|
||||
oneof command {
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -239,6 +239,7 @@ var StreamingCoordBroadcastService_ServiceDesc = grpc.ServiceDesc{
|
||||
}
|
||||
|
||||
const (
|
||||
StreamingCoordAssignmentService_UpdateWALBalancePolicy_FullMethodName = "/milvus.proto.streaming.StreamingCoordAssignmentService/UpdateWALBalancePolicy"
|
||||
StreamingCoordAssignmentService_AssignmentDiscover_FullMethodName = "/milvus.proto.streaming.StreamingCoordAssignmentService/AssignmentDiscover"
|
||||
)
|
||||
|
||||
@ -246,6 +247,9 @@ const (
|
||||
//
|
||||
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
|
||||
type StreamingCoordAssignmentServiceClient interface {
|
||||
// UpdateWALBalancePolicy is used to update the WAL balance policy.
|
||||
// The policy is used to control the balance of the WAL.
|
||||
UpdateWALBalancePolicy(ctx context.Context, in *UpdateWALBalancePolicyRequest, opts ...grpc.CallOption) (*UpdateWALBalancePolicyResponse, error)
|
||||
// AssignmentDiscover is used to discover all log nodes managed by the
|
||||
// streamingcoord. Channel assignment information will be pushed to client
|
||||
// by stream.
|
||||
@ -260,6 +264,15 @@ func NewStreamingCoordAssignmentServiceClient(cc grpc.ClientConnInterface) Strea
|
||||
return &streamingCoordAssignmentServiceClient{cc}
|
||||
}
|
||||
|
||||
func (c *streamingCoordAssignmentServiceClient) UpdateWALBalancePolicy(ctx context.Context, in *UpdateWALBalancePolicyRequest, opts ...grpc.CallOption) (*UpdateWALBalancePolicyResponse, error) {
|
||||
out := new(UpdateWALBalancePolicyResponse)
|
||||
err := c.cc.Invoke(ctx, StreamingCoordAssignmentService_UpdateWALBalancePolicy_FullMethodName, in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *streamingCoordAssignmentServiceClient) AssignmentDiscover(ctx context.Context, opts ...grpc.CallOption) (StreamingCoordAssignmentService_AssignmentDiscoverClient, error) {
|
||||
stream, err := c.cc.NewStream(ctx, &StreamingCoordAssignmentService_ServiceDesc.Streams[0], StreamingCoordAssignmentService_AssignmentDiscover_FullMethodName, opts...)
|
||||
if err != nil {
|
||||
@ -295,6 +308,9 @@ func (x *streamingCoordAssignmentServiceAssignmentDiscoverClient) Recv() (*Assig
|
||||
// All implementations should embed UnimplementedStreamingCoordAssignmentServiceServer
|
||||
// for forward compatibility
|
||||
type StreamingCoordAssignmentServiceServer interface {
|
||||
// UpdateWALBalancePolicy is used to update the WAL balance policy.
|
||||
// The policy is used to control the balance of the WAL.
|
||||
UpdateWALBalancePolicy(context.Context, *UpdateWALBalancePolicyRequest) (*UpdateWALBalancePolicyResponse, error)
|
||||
// AssignmentDiscover is used to discover all log nodes managed by the
|
||||
// streamingcoord. Channel assignment information will be pushed to client
|
||||
// by stream.
|
||||
@ -305,6 +321,9 @@ type StreamingCoordAssignmentServiceServer interface {
|
||||
type UnimplementedStreamingCoordAssignmentServiceServer struct {
|
||||
}
|
||||
|
||||
func (UnimplementedStreamingCoordAssignmentServiceServer) UpdateWALBalancePolicy(context.Context, *UpdateWALBalancePolicyRequest) (*UpdateWALBalancePolicyResponse, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method UpdateWALBalancePolicy not implemented")
|
||||
}
|
||||
func (UnimplementedStreamingCoordAssignmentServiceServer) AssignmentDiscover(StreamingCoordAssignmentService_AssignmentDiscoverServer) error {
|
||||
return status.Errorf(codes.Unimplemented, "method AssignmentDiscover not implemented")
|
||||
}
|
||||
@ -320,6 +339,24 @@ func RegisterStreamingCoordAssignmentServiceServer(s grpc.ServiceRegistrar, srv
|
||||
s.RegisterService(&StreamingCoordAssignmentService_ServiceDesc, srv)
|
||||
}
|
||||
|
||||
func _StreamingCoordAssignmentService_UpdateWALBalancePolicy_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(UpdateWALBalancePolicyRequest)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(StreamingCoordAssignmentServiceServer).UpdateWALBalancePolicy(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: StreamingCoordAssignmentService_UpdateWALBalancePolicy_FullMethodName,
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(StreamingCoordAssignmentServiceServer).UpdateWALBalancePolicy(ctx, req.(*UpdateWALBalancePolicyRequest))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _StreamingCoordAssignmentService_AssignmentDiscover_Handler(srv interface{}, stream grpc.ServerStream) error {
|
||||
return srv.(StreamingCoordAssignmentServiceServer).AssignmentDiscover(&streamingCoordAssignmentServiceAssignmentDiscoverServer{stream})
|
||||
}
|
||||
@ -352,7 +389,12 @@ func (x *streamingCoordAssignmentServiceAssignmentDiscoverServer) Recv() (*Assig
|
||||
var StreamingCoordAssignmentService_ServiceDesc = grpc.ServiceDesc{
|
||||
ServiceName: "milvus.proto.streaming.StreamingCoordAssignmentService",
|
||||
HandlerType: (*StreamingCoordAssignmentServiceServer)(nil),
|
||||
Methods: []grpc.MethodDesc{},
|
||||
Methods: []grpc.MethodDesc{
|
||||
{
|
||||
MethodName: "UpdateWALBalancePolicy",
|
||||
Handler: _StreamingCoordAssignmentService_UpdateWALBalancePolicy_Handler,
|
||||
},
|
||||
},
|
||||
Streams: []grpc.StreamDesc{
|
||||
{
|
||||
StreamName: "AssignmentDiscover",
|
||||
|
||||
9
pkg/streaming/util/types/balance_config.go
Normal file
9
pkg/streaming/util/types/balance_config.go
Normal file
@ -0,0 +1,9 @@
|
||||
package types
|
||||
|
||||
import "github.com/milvus-io/milvus/pkg/v2/proto/streamingpb"
|
||||
|
||||
const (
|
||||
UpdateMaskPathWALBalancePolicyAllowRebalance = "config.allow_rebalance"
|
||||
)
|
||||
|
||||
type UpdateWALBalancePolicyRequest = streamingpb.UpdateWALBalancePolicyRequest
|
||||
@ -13,6 +13,7 @@ import (
|
||||
var (
|
||||
ErrStopping = errors.New("streaming node is stopping")
|
||||
ErrNotAlive = errors.New("streaming node is not alive")
|
||||
ErrFrozen = errors.New("streaming node is frozen")
|
||||
)
|
||||
|
||||
// AssignmentDiscoverWatcher is the interface for watching the assignment discovery.
|
||||
|
||||
@ -160,7 +160,6 @@ func (pi *ParamItem) getWithRaw() (result, raw string, err error) {
|
||||
|
||||
// SwapTempValue sets the value for this ParamItem,
|
||||
// Once value set, ParamItem will use the value instead of underlying config manager.
|
||||
// Usage: should only use for unittest, swap empty string will remove the value.
|
||||
func (pi *ParamItem) SwapTempValue(s string) string {
|
||||
if s == "" {
|
||||
if old := pi.tempValue.Swap(nil); old != nil {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user