enhance: support balancer interface for streaming client to fetch streaming node information (#43969)

issue: #43968

- Add ListStreamingNode/GetWALDistribution to  fetch streaming node info
- Add SuspendRebalance/ResumeRebalance to enable or stop balance
- Add FreezeNodeIDs/DefreezeNodeIDs to freeze target node

Signed-off-by: chyezh <chyezh@outlook.com>
This commit is contained in:
Zhen Ye 2025-08-21 15:55:47 +08:00 committed by GitHub
parent 8e1ce15146
commit 082ca62ec1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
25 changed files with 1745 additions and 541 deletions

View File

@ -24,6 +24,7 @@ packages:
interfaces: interfaces:
Client: Client:
BroadcastService: BroadcastService:
AssignmentService:
github.com/milvus-io/milvus/internal/streamingcoord/client/broadcast: github.com/milvus-io/milvus/internal/streamingcoord/client/broadcast:
interfaces: interfaces:
Watcher: Watcher:

View File

@ -0,0 +1,83 @@
package streaming
import (
"context"
"google.golang.org/protobuf/types/known/fieldmaskpb"
"github.com/milvus-io/milvus/pkg/v2/proto/streamingpb"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
)
type balancerImpl struct {
*walAccesserImpl
}
// GetWALDistribution returns the wal distribution of the streaming node.
func (b balancerImpl) ListStreamingNode(ctx context.Context) ([]types.StreamingNodeInfo, error) {
assignments, err := b.streamingCoordClient.Assignment().GetLatestAssignments(ctx)
if err != nil {
return nil, err
}
nodes := make([]types.StreamingNodeInfo, 0, len(assignments.Assignments))
for _, assignment := range assignments.Assignments {
nodes = append(nodes, assignment.NodeInfo)
}
return nodes, nil
}
// GetWALDistribution returns the wal distribution of the streaming node.
func (b balancerImpl) GetWALDistribution(ctx context.Context, nodeID int64) (*types.StreamingNodeAssignment, error) {
assignments, err := b.streamingCoordClient.Assignment().GetLatestAssignments(ctx)
if err != nil {
return nil, err
}
for _, assignment := range assignments.Assignments {
if assignment.NodeInfo.ServerID == nodeID {
return &assignment, nil
}
}
return nil, merr.WrapErrNodeNotFound(nodeID, "streaming node not found")
}
func (b balancerImpl) SuspendRebalance(ctx context.Context) error {
return b.streamingCoordClient.Assignment().UpdateWALBalancePolicy(ctx, &types.UpdateWALBalancePolicyRequest{
Config: &streamingpb.WALBalancePolicyConfig{
AllowRebalance: false,
},
UpdateMask: &fieldmaskpb.FieldMask{
Paths: []string{types.UpdateMaskPathWALBalancePolicyAllowRebalance},
},
})
}
func (b balancerImpl) ResumeRebalance(ctx context.Context) error {
return b.streamingCoordClient.Assignment().UpdateWALBalancePolicy(ctx, &types.UpdateWALBalancePolicyRequest{
Config: &streamingpb.WALBalancePolicyConfig{
AllowRebalance: true,
},
UpdateMask: &fieldmaskpb.FieldMask{
Paths: []string{types.UpdateMaskPathWALBalancePolicyAllowRebalance},
},
})
}
func (b balancerImpl) FreezeNodeIDs(ctx context.Context, nodeIDs []int64) error {
return b.streamingCoordClient.Assignment().UpdateWALBalancePolicy(ctx, &types.UpdateWALBalancePolicyRequest{
UpdateMask: &fieldmaskpb.FieldMask{Paths: []string{}},
Nodes: &streamingpb.WALBalancePolicyNodes{
FreezeNodeIds: nodeIDs,
},
})
}
func (b balancerImpl) DefreezeNodeIDs(ctx context.Context, nodeIDs []int64) error {
return b.streamingCoordClient.Assignment().UpdateWALBalancePolicy(ctx, &types.UpdateWALBalancePolicyRequest{
UpdateMask: &fieldmaskpb.FieldMask{Paths: []string{}},
Nodes: &streamingpb.WALBalancePolicyNodes{
DefreezeNodeIds: nodeIDs,
},
})
}

View File

@ -0,0 +1,78 @@
package streaming
import (
"context"
"testing"
"github.com/cockroachdb/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/milvus-io/milvus/internal/mocks/streamingcoord/mock_client"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
)
func TestBalancer(t *testing.T) {
scClient := mock_client.NewMockClient(t)
assignmentService := mock_client.NewMockAssignmentService(t)
scClient.EXPECT().Assignment().Return(assignmentService)
assignmentService.EXPECT().GetLatestAssignments(mock.Anything).Return(&types.VersionedStreamingNodeAssignments{
Assignments: map[int64]types.StreamingNodeAssignment{
1: {
NodeInfo: types.StreamingNodeInfo{ServerID: 1},
Channels: map[string]types.PChannelInfo{
"v1": {},
},
},
},
}, nil)
balancer := balancerImpl{
walAccesserImpl: &walAccesserImpl{
streamingCoordClient: scClient,
},
}
nodes, err := balancer.ListStreamingNode(context.Background())
assert.NoError(t, err)
assert.Equal(t, 1, len(nodes))
assignment, err := balancer.GetWALDistribution(context.Background(), 1)
assert.NoError(t, err)
assert.Equal(t, 1, len(assignment.Channels))
assignment, err = balancer.GetWALDistribution(context.Background(), 2)
assert.True(t, errors.Is(err, merr.ErrNodeNotFound))
assert.Nil(t, assignment)
assignmentService.EXPECT().GetLatestAssignments(mock.Anything).Unset()
assignmentService.EXPECT().GetLatestAssignments(mock.Anything).Return(nil, errors.New("test"))
nodes, err = balancer.ListStreamingNode(context.Background())
assert.Error(t, err)
assert.Nil(t, nodes)
assignment, err = balancer.GetWALDistribution(context.Background(), 1)
assert.Error(t, err)
assert.Nil(t, assignment)
assignmentService.EXPECT().UpdateWALBalancePolicy(mock.Anything, mock.Anything).Return(nil)
err = balancer.SuspendRebalance(context.Background())
assert.NoError(t, err)
err = balancer.ResumeRebalance(context.Background())
assert.NoError(t, err)
err = balancer.FreezeNodeIDs(context.Background(), []int64{1})
assert.NoError(t, err)
err = balancer.DefreezeNodeIDs(context.Background(), []int64{1})
assert.NoError(t, err)
assignmentService.EXPECT().UpdateWALBalancePolicy(mock.Anything, mock.Anything).Unset()
assignmentService.EXPECT().UpdateWALBalancePolicy(mock.Anything, mock.Anything).Return(errors.New("test"))
err = balancer.SuspendRebalance(context.Background())
assert.Error(t, err)
err = balancer.ResumeRebalance(context.Background())
assert.Error(t, err)
err = balancer.FreezeNodeIDs(context.Background(), []int64{1})
assert.Error(t, err)
err = balancer.DefreezeNodeIDs(context.Background(), []int64{1})
assert.Error(t, err)
}

View File

@ -81,8 +81,33 @@ type Scanner interface {
Close() Close()
} }
// Balancer is the interface for managing the balancer of the wal.
type Balancer interface {
// ListStreamingNode lists the streaming node.
ListStreamingNode(ctx context.Context) ([]types.StreamingNodeInfo, error)
// GetWALDistribution returns the wal distribution of the streaming node.
GetWALDistribution(ctx context.Context, nodeID int64) (*types.StreamingNodeAssignment, error)
// SuspendRebalance suspends the rebalance of the wal.
SuspendRebalance(ctx context.Context) error
// ResumeRebalance resumes the rebalance of the wal.
ResumeRebalance(ctx context.Context) error
// FreezeNodeIDs freezes the streaming node.
// The wal will not be assigned to the frozen nodes and the wal will be removed from the frozen nodes.
FreezeNodeIDs(ctx context.Context, nodeIDs []int64) error
// DefreezeNodeIDs defreezes the streaming node.
DefreezeNodeIDs(ctx context.Context, nodeIDs []int64) error
}
// WALAccesser is the interfaces to interact with the milvus write ahead log. // WALAccesser is the interfaces to interact with the milvus write ahead log.
type WALAccesser interface { type WALAccesser interface {
// Balancer returns the balancer management of the wal.
Balancer() Balancer
// WALName returns the name of the wal. // WALName returns the name of the wal.
WALName() string WALName() string

View File

@ -51,6 +51,32 @@ func SetupNoopWALForTest() {
singleton = &noopWALAccesser{} singleton = &noopWALAccesser{}
} }
type noopBalancer struct{}
func (n *noopBalancer) ListStreamingNode(ctx context.Context) ([]types.StreamingNodeInfo, error) {
return nil, nil
}
func (n *noopBalancer) GetWALDistribution(ctx context.Context, nodeID int64) (*types.StreamingNodeAssignment, error) {
return nil, nil
}
func (n *noopBalancer) SuspendRebalance(ctx context.Context) error {
return nil
}
func (n *noopBalancer) ResumeRebalance(ctx context.Context) error {
return nil
}
func (n *noopBalancer) FreezeNodeIDs(ctx context.Context, nodeIDs []int64) error {
return nil
}
func (n *noopBalancer) DefreezeNodeIDs(ctx context.Context, nodeIDs []int64) error {
return nil
}
type noopLocal struct{} type noopLocal struct{}
func (n *noopLocal) GetLatestMVCCTimestampIfLocal(ctx context.Context, vchannel string) (uint64, error) { func (n *noopLocal) GetLatestMVCCTimestampIfLocal(ctx context.Context, vchannel string) (uint64, error) {
@ -108,6 +134,10 @@ func (n *noopTxn) Rollback(ctx context.Context) error {
type noopWALAccesser struct{} type noopWALAccesser struct{}
func (n *noopWALAccesser) Balancer() Balancer {
return &noopBalancer{}
}
func (n *noopWALAccesser) WALName() string { func (n *noopWALAccesser) WALName() string {
return "noop" return "noop"
} }

View File

@ -59,6 +59,10 @@ type walAccesserImpl struct {
dispatchExecutionPool *conc.Pool[struct{}] dispatchExecutionPool *conc.Pool[struct{}]
} }
func (w *walAccesserImpl) Balancer() Balancer {
return balancerImpl{w}
}
func (w *walAccesserImpl) WALName() string { func (w *walAccesserImpl) WALName() string {
return util.MustSelectWALName() return util.MustSelectWALName()
} }

View File

@ -149,6 +149,53 @@ func (_c *MockWALAccesser_AppendMessagesWithOption_Call) RunAndReturn(run func(c
return _c return _c
} }
// Balancer provides a mock function with no fields
func (_m *MockWALAccesser) Balancer() streaming.Balancer {
ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for Balancer")
}
var r0 streaming.Balancer
if rf, ok := ret.Get(0).(func() streaming.Balancer); ok {
r0 = rf()
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(streaming.Balancer)
}
}
return r0
}
// MockWALAccesser_Balancer_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Balancer'
type MockWALAccesser_Balancer_Call struct {
*mock.Call
}
// Balancer is a helper method to define mock.On call
func (_e *MockWALAccesser_Expecter) Balancer() *MockWALAccesser_Balancer_Call {
return &MockWALAccesser_Balancer_Call{Call: _e.mock.On("Balancer")}
}
func (_c *MockWALAccesser_Balancer_Call) Run(run func()) *MockWALAccesser_Balancer_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockWALAccesser_Balancer_Call) Return(_a0 streaming.Balancer) *MockWALAccesser_Balancer_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockWALAccesser_Balancer_Call) RunAndReturn(run func() streaming.Balancer) *MockWALAccesser_Balancer_Call {
_c.Call.Return(run)
return _c
}
// Broadcast provides a mock function with no fields // Broadcast provides a mock function with no fields
func (_m *MockWALAccesser) Broadcast() streaming.Broadcast { func (_m *MockWALAccesser) Broadcast() streaming.Broadcast {
ret := _m.Called() ret := _m.Called()

View File

@ -0,0 +1,239 @@
// Code generated by mockery v2.53.3. DO NOT EDIT.
package mock_client
import (
context "context"
streamingpb "github.com/milvus-io/milvus/pkg/v2/proto/streamingpb"
mock "github.com/stretchr/testify/mock"
types "github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
)
// MockAssignmentService is an autogenerated mock type for the AssignmentService type
type MockAssignmentService struct {
mock.Mock
}
type MockAssignmentService_Expecter struct {
mock *mock.Mock
}
func (_m *MockAssignmentService) EXPECT() *MockAssignmentService_Expecter {
return &MockAssignmentService_Expecter{mock: &_m.Mock}
}
// AssignmentDiscover provides a mock function with given fields: ctx, cb
func (_m *MockAssignmentService) AssignmentDiscover(ctx context.Context, cb func(*types.VersionedStreamingNodeAssignments) error) error {
ret := _m.Called(ctx, cb)
if len(ret) == 0 {
panic("no return value specified for AssignmentDiscover")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, func(*types.VersionedStreamingNodeAssignments) error) error); ok {
r0 = rf(ctx, cb)
} else {
r0 = ret.Error(0)
}
return r0
}
// MockAssignmentService_AssignmentDiscover_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AssignmentDiscover'
type MockAssignmentService_AssignmentDiscover_Call struct {
*mock.Call
}
// AssignmentDiscover is a helper method to define mock.On call
// - ctx context.Context
// - cb func(*types.VersionedStreamingNodeAssignments) error
func (_e *MockAssignmentService_Expecter) AssignmentDiscover(ctx interface{}, cb interface{}) *MockAssignmentService_AssignmentDiscover_Call {
return &MockAssignmentService_AssignmentDiscover_Call{Call: _e.mock.On("AssignmentDiscover", ctx, cb)}
}
func (_c *MockAssignmentService_AssignmentDiscover_Call) Run(run func(ctx context.Context, cb func(*types.VersionedStreamingNodeAssignments) error)) *MockAssignmentService_AssignmentDiscover_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(func(*types.VersionedStreamingNodeAssignments) error))
})
return _c
}
func (_c *MockAssignmentService_AssignmentDiscover_Call) Return(_a0 error) *MockAssignmentService_AssignmentDiscover_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockAssignmentService_AssignmentDiscover_Call) RunAndReturn(run func(context.Context, func(*types.VersionedStreamingNodeAssignments) error) error) *MockAssignmentService_AssignmentDiscover_Call {
_c.Call.Return(run)
return _c
}
// GetLatestAssignments provides a mock function with given fields: ctx
func (_m *MockAssignmentService) GetLatestAssignments(ctx context.Context) (*types.VersionedStreamingNodeAssignments, error) {
ret := _m.Called(ctx)
if len(ret) == 0 {
panic("no return value specified for GetLatestAssignments")
}
var r0 *types.VersionedStreamingNodeAssignments
var r1 error
if rf, ok := ret.Get(0).(func(context.Context) (*types.VersionedStreamingNodeAssignments, error)); ok {
return rf(ctx)
}
if rf, ok := ret.Get(0).(func(context.Context) *types.VersionedStreamingNodeAssignments); ok {
r0 = rf(ctx)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*types.VersionedStreamingNodeAssignments)
}
}
if rf, ok := ret.Get(1).(func(context.Context) error); ok {
r1 = rf(ctx)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockAssignmentService_GetLatestAssignments_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetLatestAssignments'
type MockAssignmentService_GetLatestAssignments_Call struct {
*mock.Call
}
// GetLatestAssignments is a helper method to define mock.On call
// - ctx context.Context
func (_e *MockAssignmentService_Expecter) GetLatestAssignments(ctx interface{}) *MockAssignmentService_GetLatestAssignments_Call {
return &MockAssignmentService_GetLatestAssignments_Call{Call: _e.mock.On("GetLatestAssignments", ctx)}
}
func (_c *MockAssignmentService_GetLatestAssignments_Call) Run(run func(ctx context.Context)) *MockAssignmentService_GetLatestAssignments_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context))
})
return _c
}
func (_c *MockAssignmentService_GetLatestAssignments_Call) Return(_a0 *types.VersionedStreamingNodeAssignments, _a1 error) *MockAssignmentService_GetLatestAssignments_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockAssignmentService_GetLatestAssignments_Call) RunAndReturn(run func(context.Context) (*types.VersionedStreamingNodeAssignments, error)) *MockAssignmentService_GetLatestAssignments_Call {
_c.Call.Return(run)
return _c
}
// ReportAssignmentError provides a mock function with given fields: ctx, pchannel, err
func (_m *MockAssignmentService) ReportAssignmentError(ctx context.Context, pchannel types.PChannelInfo, err error) error {
ret := _m.Called(ctx, pchannel, err)
if len(ret) == 0 {
panic("no return value specified for ReportAssignmentError")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, types.PChannelInfo, error) error); ok {
r0 = rf(ctx, pchannel, err)
} else {
r0 = ret.Error(0)
}
return r0
}
// MockAssignmentService_ReportAssignmentError_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ReportAssignmentError'
type MockAssignmentService_ReportAssignmentError_Call struct {
*mock.Call
}
// ReportAssignmentError is a helper method to define mock.On call
// - ctx context.Context
// - pchannel types.PChannelInfo
// - err error
func (_e *MockAssignmentService_Expecter) ReportAssignmentError(ctx interface{}, pchannel interface{}, err interface{}) *MockAssignmentService_ReportAssignmentError_Call {
return &MockAssignmentService_ReportAssignmentError_Call{Call: _e.mock.On("ReportAssignmentError", ctx, pchannel, err)}
}
func (_c *MockAssignmentService_ReportAssignmentError_Call) Run(run func(ctx context.Context, pchannel types.PChannelInfo, err error)) *MockAssignmentService_ReportAssignmentError_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(types.PChannelInfo), args[2].(error))
})
return _c
}
func (_c *MockAssignmentService_ReportAssignmentError_Call) Return(_a0 error) *MockAssignmentService_ReportAssignmentError_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockAssignmentService_ReportAssignmentError_Call) RunAndReturn(run func(context.Context, types.PChannelInfo, error) error) *MockAssignmentService_ReportAssignmentError_Call {
_c.Call.Return(run)
return _c
}
// UpdateWALBalancePolicy provides a mock function with given fields: ctx, req
func (_m *MockAssignmentService) UpdateWALBalancePolicy(ctx context.Context, req *streamingpb.UpdateWALBalancePolicyRequest) error {
ret := _m.Called(ctx, req)
if len(ret) == 0 {
panic("no return value specified for UpdateWALBalancePolicy")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, *streamingpb.UpdateWALBalancePolicyRequest) error); ok {
r0 = rf(ctx, req)
} else {
r0 = ret.Error(0)
}
return r0
}
// MockAssignmentService_UpdateWALBalancePolicy_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateWALBalancePolicy'
type MockAssignmentService_UpdateWALBalancePolicy_Call struct {
*mock.Call
}
// UpdateWALBalancePolicy is a helper method to define mock.On call
// - ctx context.Context
// - req *streamingpb.UpdateWALBalancePolicyRequest
func (_e *MockAssignmentService_Expecter) UpdateWALBalancePolicy(ctx interface{}, req interface{}) *MockAssignmentService_UpdateWALBalancePolicy_Call {
return &MockAssignmentService_UpdateWALBalancePolicy_Call{Call: _e.mock.On("UpdateWALBalancePolicy", ctx, req)}
}
func (_c *MockAssignmentService_UpdateWALBalancePolicy_Call) Run(run func(ctx context.Context, req *streamingpb.UpdateWALBalancePolicyRequest)) *MockAssignmentService_UpdateWALBalancePolicy_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(*streamingpb.UpdateWALBalancePolicyRequest))
})
return _c
}
func (_c *MockAssignmentService_UpdateWALBalancePolicy_Call) Return(_a0 error) *MockAssignmentService_UpdateWALBalancePolicy_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockAssignmentService_UpdateWALBalancePolicy_Call) RunAndReturn(run func(context.Context, *streamingpb.UpdateWALBalancePolicyRequest) error) *MockAssignmentService_UpdateWALBalancePolicy_Call {
_c.Call.Return(run)
return _c
}
// NewMockAssignmentService creates a new instance of MockAssignmentService. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewMockAssignmentService(t interface {
mock.TestingT
Cleanup(func())
}) *MockAssignmentService {
mock := &MockAssignmentService{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}

View File

@ -5,9 +5,11 @@ package mock_balancer
import ( import (
context "context" context "context"
syncutil "github.com/milvus-io/milvus/pkg/v2/util/syncutil" streamingpb "github.com/milvus-io/milvus/pkg/v2/proto/streamingpb"
mock "github.com/stretchr/testify/mock" mock "github.com/stretchr/testify/mock"
syncutil "github.com/milvus-io/milvus/pkg/v2/util/syncutil"
types "github.com/milvus-io/milvus/pkg/v2/streaming/util/types" types "github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
typeutil "github.com/milvus-io/milvus/pkg/v2/util/typeutil" typeutil "github.com/milvus-io/milvus/pkg/v2/util/typeutil"
@ -241,6 +243,53 @@ func (_c *MockBalancer_Trigger_Call) RunAndReturn(run func(context.Context) erro
return _c return _c
} }
// UpdateBalancePolicy provides a mock function with given fields: ctx, req
func (_m *MockBalancer) UpdateBalancePolicy(ctx context.Context, req *streamingpb.UpdateWALBalancePolicyRequest) error {
ret := _m.Called(ctx, req)
if len(ret) == 0 {
panic("no return value specified for UpdateBalancePolicy")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, *streamingpb.UpdateWALBalancePolicyRequest) error); ok {
r0 = rf(ctx, req)
} else {
r0 = ret.Error(0)
}
return r0
}
// MockBalancer_UpdateBalancePolicy_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateBalancePolicy'
type MockBalancer_UpdateBalancePolicy_Call struct {
*mock.Call
}
// UpdateBalancePolicy is a helper method to define mock.On call
// - ctx context.Context
// - req *streamingpb.UpdateWALBalancePolicyRequest
func (_e *MockBalancer_Expecter) UpdateBalancePolicy(ctx interface{}, req interface{}) *MockBalancer_UpdateBalancePolicy_Call {
return &MockBalancer_UpdateBalancePolicy_Call{Call: _e.mock.On("UpdateBalancePolicy", ctx, req)}
}
func (_c *MockBalancer_UpdateBalancePolicy_Call) Run(run func(ctx context.Context, req *streamingpb.UpdateWALBalancePolicyRequest)) *MockBalancer_UpdateBalancePolicy_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(*streamingpb.UpdateWALBalancePolicyRequest))
})
return _c
}
func (_c *MockBalancer_UpdateBalancePolicy_Call) Return(_a0 error) *MockBalancer_UpdateBalancePolicy_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockBalancer_UpdateBalancePolicy_Call) RunAndReturn(run func(context.Context, *streamingpb.UpdateWALBalancePolicyRequest) error) *MockBalancer_UpdateBalancePolicy_Call {
_c.Call.Return(run)
return _c
}
// WatchChannelAssignments provides a mock function with given fields: ctx, cb // WatchChannelAssignments provides a mock function with given fields: ctx, cb
func (_m *MockBalancer) WatchChannelAssignments(ctx context.Context, cb func(typeutil.VersionInt64Pair, []types.PChannelInfoAssigned) error) error { func (_m *MockBalancer) WatchChannelAssignments(ctx context.Context, cb func(typeutil.VersionInt64Pair, []types.PChannelInfoAssigned) error) error {
ret := _m.Called(ctx, cb) ret := _m.Called(ctx, cb)

View File

@ -47,6 +47,30 @@ type AssignmentServiceImpl struct {
logger *log.MLogger logger *log.MLogger
} }
// GetLatestAssignments returns the latest assignment discovery result.
func (c *AssignmentServiceImpl) GetLatestAssignments(ctx context.Context) (*types.VersionedStreamingNodeAssignments, error) {
if !c.lifetime.Add(typeutil.LifetimeStateWorking) {
return nil, status.NewOnShutdownError("assignment service client is closing")
}
defer c.lifetime.Done()
return c.watcher.GetLatestDiscover(ctx)
}
func (c *AssignmentServiceImpl) UpdateWALBalancePolicy(ctx context.Context, req *streamingpb.UpdateWALBalancePolicyRequest) error {
if !c.lifetime.Add(typeutil.LifetimeStateWorking) {
return status.NewOnShutdownError("assignment service client is closing")
}
defer c.lifetime.Done()
service, err := c.service.GetService(c.ctx)
if err != nil {
return err
}
_, err = service.UpdateWALBalancePolicy(ctx, req)
return err
}
// AssignmentDiscover watches the assignment discovery. // AssignmentDiscover watches the assignment discovery.
func (c *AssignmentServiceImpl) AssignmentDiscover(ctx context.Context, cb func(*types.VersionedStreamingNodeAssignments) error) error { func (c *AssignmentServiceImpl) AssignmentDiscover(ctx context.Context, cb func(*types.VersionedStreamingNodeAssignments) error) error {
if !c.lifetime.Add(typeutil.LifetimeStateWorking) { if !c.lifetime.Add(typeutil.LifetimeStateWorking) {

View File

@ -24,6 +24,7 @@ func TestAssignmentService(t *testing.T) {
s.EXPECT().GetService(mock.Anything).Return(c, nil) s.EXPECT().GetService(mock.Anything).Return(c, nil)
cc := mock_streamingpb.NewMockStreamingCoordAssignmentService_AssignmentDiscoverClient(t) cc := mock_streamingpb.NewMockStreamingCoordAssignmentService_AssignmentDiscoverClient(t)
c.EXPECT().AssignmentDiscover(mock.Anything).Return(cc, nil) c.EXPECT().AssignmentDiscover(mock.Anything).Return(cc, nil)
c.EXPECT().UpdateWALBalancePolicy(mock.Anything, mock.Anything).Return(&streamingpb.UpdateWALBalancePolicyResponse{}, nil)
k := 0 k := 0
closeCh := make(chan struct{}) closeCh := make(chan struct{})
cc.EXPECT().Send(mock.Anything).Return(nil) cc.EXPECT().Send(mock.Anything).Return(nil)
@ -93,6 +94,13 @@ func TestAssignmentService(t *testing.T) {
assert.ErrorIs(t, err, context.DeadlineExceeded) assert.ErrorIs(t, err, context.DeadlineExceeded)
assert.True(t, finalAssignments.Version.EQ(typeutil.VersionInt64Pair{Global: 2, Local: 3})) assert.True(t, finalAssignments.Version.EQ(typeutil.VersionInt64Pair{Global: 2, Local: 3}))
assign, err := assignmentService.GetLatestAssignments(ctx)
assert.NoError(t, err)
assert.True(t, assign.Version.EQ(typeutil.VersionInt64Pair{Global: 2, Local: 3}))
err = assignmentService.UpdateWALBalancePolicy(ctx, &streamingpb.UpdateWALBalancePolicyRequest{})
assert.NoError(t, err)
assignmentService.ReportAssignmentError(ctx, types.PChannelInfo{Name: "c1", Term: 1}, errors.New("test")) assignmentService.ReportAssignmentError(ctx, types.PChannelInfo{Name: "c1", Term: 1}, errors.New("test"))
// Repeated report error at the same term should be ignored. // Repeated report error at the same term should be ignored.
@ -114,4 +122,12 @@ func TestAssignmentService(t *testing.T) {
err = assignmentService.ReportAssignmentError(ctx, types.PChannelInfo{Name: "c1", Term: 1}, errors.New("test")) err = assignmentService.ReportAssignmentError(ctx, types.PChannelInfo{Name: "c1", Term: 1}, errors.New("test"))
se = status.AsStreamingError(err) se = status.AsStreamingError(err)
assert.Equal(t, streamingpb.StreamingCode_STREAMING_CODE_ON_SHUTDOWN, se.Code) assert.Equal(t, streamingpb.StreamingCode_STREAMING_CODE_ON_SHUTDOWN, se.Code)
assignmentService.GetLatestAssignments(ctx)
se = status.AsStreamingError(err)
assert.Equal(t, streamingpb.StreamingCode_STREAMING_CODE_ON_SHUTDOWN, se.Code)
assignmentService.UpdateWALBalancePolicy(ctx, &streamingpb.UpdateWALBalancePolicyRequest{})
se = status.AsStreamingError(err)
assert.Equal(t, streamingpb.StreamingCode_STREAMING_CODE_ON_SHUTDOWN, se.Code)
} }

View File

@ -30,6 +30,18 @@ type watcher struct {
lastVersionedAssignment types.VersionedStreamingNodeAssignments lastVersionedAssignment types.VersionedStreamingNodeAssignments
} }
func (w *watcher) GetLatestDiscover(ctx context.Context) (*types.VersionedStreamingNodeAssignments, error) {
w.cond.L.Lock()
for w.lastVersionedAssignment.Version.Global == -1 && w.lastVersionedAssignment.Version.Local == -1 {
if err := w.cond.Wait(ctx); err != nil {
return nil, err
}
}
last := w.lastVersionedAssignment
w.cond.L.Unlock()
return &last, nil
}
// AssignmentDiscover watches the assignment discovery. // AssignmentDiscover watches the assignment discovery.
func (w *watcher) AssignmentDiscover(ctx context.Context, cb func(*types.VersionedStreamingNodeAssignments) error) error { func (w *watcher) AssignmentDiscover(ctx context.Context, cb func(*types.VersionedStreamingNodeAssignments) error) error {
w.cond.L.Lock() w.cond.L.Lock()

View File

@ -32,6 +32,12 @@ var _ Client = (*clientImpl)(nil)
type AssignmentService interface { type AssignmentService interface {
// AssignmentDiscover is used to watches the assignment discovery. // AssignmentDiscover is used to watches the assignment discovery.
types.AssignmentDiscoverWatcher types.AssignmentDiscoverWatcher
// GetLatestAssignments returns the latest assignment discovery result.
GetLatestAssignments(ctx context.Context) (*types.VersionedStreamingNodeAssignments, error)
// UpdateWALBalancePolicy is used to update the WAL balance policy.
UpdateWALBalancePolicy(ctx context.Context, req *streamingpb.UpdateWALBalancePolicyRequest) error
} }
// BroadcastService is the interface of broadcast service. // BroadcastService is the interface of broadcast service.

View File

@ -5,6 +5,7 @@ import (
"github.com/cockroachdb/errors" "github.com/cockroachdb/errors"
"github.com/milvus-io/milvus/pkg/v2/proto/streamingpb"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types" "github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
"github.com/milvus-io/milvus/pkg/v2/util/syncutil" "github.com/milvus-io/milvus/pkg/v2/util/syncutil"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil" "github.com/milvus-io/milvus/pkg/v2/util/typeutil"
@ -20,6 +21,9 @@ var (
// Balancer is a local component, it should promise all channel can be assigned, and reach the final consistency. // Balancer is a local component, it should promise all channel can be assigned, and reach the final consistency.
// Balancer should be thread safe. // Balancer should be thread safe.
type Balancer interface { type Balancer interface {
// UpdateBalancePolicy update the balance policy.
UpdateBalancePolicy(ctx context.Context, req *streamingpb.UpdateWALBalancePolicyRequest) error
// RegisterStreamingEnabledNotifier registers a notifier into the balancer. // RegisterStreamingEnabledNotifier registers a notifier into the balancer.
// If the error is returned, the balancer is closed. // If the error is returned, the balancer is closed.
// Otherwise, the following rules are applied: // Otherwise, the following rules are applied:

View File

@ -50,6 +50,7 @@ func RecoverBalancer(
policy: policy, policy: policy,
reqCh: make(chan *request, 5), reqCh: make(chan *request, 5),
backgroundTaskNotifier: syncutil.NewAsyncTaskNotifier[struct{}](), backgroundTaskNotifier: syncutil.NewAsyncTaskNotifier[struct{}](),
freezeNodes: typeutil.NewSet[int64](),
} }
b.SetLogger(logger) b.SetLogger(logger)
ready260Future, err := b.checkIfAllNodeGreaterThan260AndWatch(ctx) ready260Future, err := b.checkIfAllNodeGreaterThan260AndWatch(ctx)
@ -71,6 +72,7 @@ type balancerImpl struct {
policy Policy // policy is the balance policy, TODO: should be dynamic in future. policy Policy // policy is the balance policy, TODO: should be dynamic in future.
reqCh chan *request // reqCh is the request channel, send the operation to background task. reqCh chan *request // reqCh is the request channel, send the operation to background task.
backgroundTaskNotifier *syncutil.AsyncTaskNotifier[struct{}] // backgroundTaskNotifier is used to conmunicate with the background task. backgroundTaskNotifier *syncutil.AsyncTaskNotifier[struct{}] // backgroundTaskNotifier is used to conmunicate with the background task.
freezeNodes typeutil.Set[int64] // freezeNodes is the nodes that will be frozen, no more wal will be assigned to these nodes and wal will be removed from these nodes.
} }
// RegisterStreamingEnabledNotifier registers a notifier into the balancer. // RegisterStreamingEnabledNotifier registers a notifier into the balancer.
@ -95,6 +97,19 @@ func (b *balancerImpl) WatchChannelAssignments(ctx context.Context, cb func(vers
return b.channelMetaManager.WatchAssignmentResult(ctx, cb) return b.channelMetaManager.WatchAssignmentResult(ctx, cb)
} }
// UpdateBalancePolicy update the balance policy.
func (b *balancerImpl) UpdateBalancePolicy(ctx context.Context, req *types.UpdateWALBalancePolicyRequest) error {
if !b.lifetime.Add(typeutil.LifetimeStateWorking) {
return status.NewOnShutdownError("balancer is closing")
}
defer b.lifetime.Done()
ctx, cancel := contextutil.MergeContext(ctx, b.ctx)
defer cancel()
return b.sendRequestAndWaitFinish(ctx, newOpUpdateBalancePolicy(ctx, req))
}
// MarkAsUnavailable mark the pchannels as unavailable.
func (b *balancerImpl) MarkAsUnavailable(ctx context.Context, pChannels []types.PChannelInfo) error { func (b *balancerImpl) MarkAsUnavailable(ctx context.Context, pChannels []types.PChannelInfo) error {
if !b.lifetime.Add(typeutil.LifetimeStateWorking) { if !b.lifetime.Add(typeutil.LifetimeStateWorking) {
return status.NewOnShutdownError("balancer is closing") return status.NewOnShutdownError("balancer is closing")
@ -330,9 +345,9 @@ func (b *balancerImpl) balance(ctx context.Context) (bool, error) {
pchannelView := b.channelMetaManager.CurrentPChannelsView() pchannelView := b.channelMetaManager.CurrentPChannelsView()
b.Logger().Info("collect all status...") b.Logger().Info("collect all status...")
nodeStatus, err := resource.Resource().StreamingNodeManagerClient().CollectAllStatus(ctx) nodeStatus, err := b.fetchStreamingNodeStatus(ctx)
if err != nil { if err != nil {
return false, errors.Wrap(err, "fail to collect all status") return false, err
} }
// call the balance strategy to generate the expected layout. // call the balance strategy to generate the expected layout.
@ -360,6 +375,29 @@ func (b *balancerImpl) balance(ctx context.Context) (bool, error) {
return true, b.applyBalanceResultToStreamingNode(ctx, modifiedChannels) return true, b.applyBalanceResultToStreamingNode(ctx, modifiedChannels)
} }
// fetchStreamingNodeStatus fetch the streaming node status.
func (b *balancerImpl) fetchStreamingNodeStatus(ctx context.Context) (map[int64]*types.StreamingNodeStatus, error) {
nodeStatus, err := resource.Resource().StreamingNodeManagerClient().CollectAllStatus(ctx)
if err != nil {
return nil, errors.Wrap(err, "fail to collect all status")
}
// mark the frozen node as frozen in the node status.
for _, node := range nodeStatus {
if b.freezeNodes.Contain(node.ServerID) && node.IsHealthy() {
node.Err = types.ErrFrozen
}
}
// clean up the freeze node that has been removed from session.
for serverID := range b.freezeNodes {
if _, ok := nodeStatus[serverID]; !ok {
b.Logger().Info("freeze node has been removed from session", zap.Int64("serverID", serverID))
b.freezeNodes.Remove(serverID)
}
}
return nodeStatus, nil
}
// applyBalanceResultToStreamingNode apply the balance result to streaming node. // applyBalanceResultToStreamingNode apply the balance result to streaming node.
func (b *balancerImpl) applyBalanceResultToStreamingNode(ctx context.Context, modifiedChannels map[types.ChannelID]*channel.PChannelMeta) error { func (b *balancerImpl) applyBalanceResultToStreamingNode(ctx context.Context, modifiedChannels map[types.ChannelID]*channel.PChannelMeta) error {
b.Logger().Info("balance result need to be applied...", zap.Int("modifiedChannelCount", len(modifiedChannels))) b.Logger().Info("balance result need to be applied...", zap.Int("modifiedChannelCount", len(modifiedChannels)))

View File

@ -11,6 +11,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock" "github.com/stretchr/testify/mock"
"go.uber.org/atomic" "go.uber.org/atomic"
"google.golang.org/protobuf/types/known/fieldmaskpb"
"github.com/milvus-io/milvus/internal/mocks/mock_metastore" "github.com/milvus-io/milvus/internal/mocks/mock_metastore"
"github.com/milvus-io/milvus/internal/mocks/streamingnode/client/mock_manager" "github.com/milvus-io/milvus/internal/mocks/streamingnode/client/mock_manager"
@ -18,10 +19,10 @@ import (
"github.com/milvus-io/milvus/internal/streamingcoord/server/balancer/channel" "github.com/milvus-io/milvus/internal/streamingcoord/server/balancer/channel"
_ "github.com/milvus-io/milvus/internal/streamingcoord/server/balancer/policy" _ "github.com/milvus-io/milvus/internal/streamingcoord/server/balancer/policy"
"github.com/milvus-io/milvus/internal/streamingcoord/server/resource" "github.com/milvus-io/milvus/internal/streamingcoord/server/resource"
kvfactory "github.com/milvus-io/milvus/internal/util/dependency/kv"
"github.com/milvus-io/milvus/internal/util/sessionutil" "github.com/milvus-io/milvus/internal/util/sessionutil"
"github.com/milvus-io/milvus/pkg/v2/proto/streamingpb" "github.com/milvus-io/milvus/pkg/v2/proto/streamingpb"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types" "github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
"github.com/milvus-io/milvus/pkg/v2/util/etcd"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable" "github.com/milvus-io/milvus/pkg/v2/util/paramtable"
"github.com/milvus-io/milvus/pkg/v2/util/syncutil" "github.com/milvus-io/milvus/pkg/v2/util/syncutil"
"github.com/milvus-io/milvus/pkg/v2/util/tsoutil" "github.com/milvus-io/milvus/pkg/v2/util/tsoutil"
@ -30,12 +31,7 @@ import (
func TestBalancer(t *testing.T) { func TestBalancer(t *testing.T) {
paramtable.Init() paramtable.Init()
err := etcd.InitEtcdServer(true, "", t.TempDir(), "stdout", "info") etcdClient, _ := kvfactory.GetEtcdAndPath()
assert.NoError(t, err)
defer etcd.StopEtcdServer()
etcdClient, err := etcd.GetEmbedEtcdClient()
assert.NoError(t, err)
channel.ResetStaticPChannelStatsManager() channel.ResetStaticPChannelStatsManager()
channel.RecoverPChannelStatsManager([]string{}) channel.RecoverPChannelStatsManager([]string{})
@ -184,18 +180,65 @@ func TestBalancer(t *testing.T) {
time.Sleep(20 * time.Millisecond) time.Sleep(20 * time.Millisecond)
assert.False(t, f.Ready()) assert.False(t, f.Ready())
assert.True(t, paramtable.Get().StreamingCfg.WALBalancerPolicyAllowRebalance.GetAsBool())
err = b.UpdateBalancePolicy(ctx, &streamingpb.UpdateWALBalancePolicyRequest{
Config: &streamingpb.WALBalancePolicyConfig{
AllowRebalance: false,
},
Nodes: &streamingpb.WALBalancePolicyNodes{
FreezeNodeIds: []int64{1},
DefreezeNodeIds: []int64{},
},
})
assert.NoError(t, err)
assert.False(t, paramtable.Get().StreamingCfg.WALBalancerPolicyAllowRebalance.GetAsBool())
b.Trigger(ctx)
err = b.WatchChannelAssignments(ctx, func(version typeutil.VersionInt64Pair, relations []types.PChannelInfoAssigned) error {
for _, relation := range relations {
if relation.Node.ServerID == 1 {
return nil
}
}
return doneErr
})
assert.ErrorIs(t, err, doneErr)
err = b.UpdateBalancePolicy(ctx, &streamingpb.UpdateWALBalancePolicyRequest{
Config: &streamingpb.WALBalancePolicyConfig{
AllowRebalance: true,
},
UpdateMask: &fieldmaskpb.FieldMask{
Paths: []string{types.UpdateMaskPathWALBalancePolicyAllowRebalance},
},
})
assert.True(t, paramtable.Get().StreamingCfg.WALBalancerPolicyAllowRebalance.GetAsBool())
assert.NoError(t, err)
b.Trigger(ctx)
err = b.UpdateBalancePolicy(ctx, &streamingpb.UpdateWALBalancePolicyRequest{
Config: &streamingpb.WALBalancePolicyConfig{
AllowRebalance: false,
},
UpdateMask: &fieldmaskpb.FieldMask{
Paths: []string{},
},
Nodes: &streamingpb.WALBalancePolicyNodes{
FreezeNodeIds: []int64{},
DefreezeNodeIds: []int64{1},
},
})
assert.True(t, paramtable.Get().StreamingCfg.WALBalancerPolicyAllowRebalance.GetAsBool())
assert.NoError(t, err)
b.Trigger(ctx)
b.Close() b.Close()
assert.ErrorIs(t, f.Get(), balancer.ErrBalancerClosed) assert.ErrorIs(t, f.Get(), balancer.ErrBalancerClosed)
} }
func TestBalancer_WithRecoveryLag(t *testing.T) { func TestBalancer_WithRecoveryLag(t *testing.T) {
paramtable.Init() paramtable.Init()
err := etcd.InitEtcdServer(true, "", t.TempDir(), "stdout", "info")
assert.NoError(t, err)
defer etcd.StopEtcdServer()
etcdClient, err := etcd.GetEmbedEtcdClient() etcdClient, _ := kvfactory.GetEtcdAndPath()
assert.NoError(t, err)
channel.ResetStaticPChannelStatsManager() channel.ResetStaticPChannelStatsManager()
channel.RecoverPChannelStatsManager([]string{}) channel.RecoverPChannelStatsManager([]string{})

View File

@ -2,8 +2,12 @@ package balancer
import ( import (
"context" "context"
"strconv"
"go.uber.org/zap"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types" "github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
"github.com/milvus-io/milvus/pkg/v2/util/syncutil" "github.com/milvus-io/milvus/pkg/v2/util/syncutil"
) )
@ -17,6 +21,42 @@ type request struct {
// requestApply is a request operation to be executed. // requestApply is a request operation to be executed.
type requestApply func(impl *balancerImpl) type requestApply func(impl *balancerImpl)
// newOpUpdateBalancePolicy is a operation to update the balance policy.
func newOpUpdateBalancePolicy(ctx context.Context, req *types.UpdateWALBalancePolicyRequest) *request {
future := syncutil.NewFuture[error]()
return &request{
ctx: ctx,
apply: func(impl *balancerImpl) {
if req.UpdateMask != nil {
// if there's a update mask, only update the fields in the update mask.
for _, field := range req.UpdateMask.Paths {
switch field {
case types.UpdateMaskPathWALBalancePolicyAllowRebalance:
updateAllowRebalance(impl, req.GetConfig().GetAllowRebalance())
}
}
} else {
// otherwise update all fields.
updateAllowRebalance(impl, req.GetConfig().GetAllowRebalance())
}
// apply the freeze streaming nodes.
if len(req.GetNodes().GetFreezeNodeIds()) > 0 || len(req.GetNodes().GetDefreezeNodeIds()) > 0 {
impl.Logger().Info("update freeze nodes", zap.Int64s("freezeNodeIDs", req.GetNodes().GetFreezeNodeIds()), zap.Int64s("defreezeNodeIDs", req.GetNodes().GetDefreezeNodeIds()))
impl.freezeNodes.Insert(req.GetNodes().GetFreezeNodeIds()...)
impl.freezeNodes.Remove(req.GetNodes().GetDefreezeNodeIds()...)
}
future.Set(nil)
},
future: future,
}
}
// updateAllowRebalance update the allow rebalance.
func updateAllowRebalance(impl *balancerImpl, allowRebalance bool) {
old := paramtable.Get().StreamingCfg.WALBalancerPolicyAllowRebalance.SwapTempValue(strconv.FormatBool(allowRebalance))
impl.Logger().Info("update allow_rebalance", zap.Bool("new", allowRebalance), zap.String("old", old))
}
// newOpMarkAsUnavailable is a operation to mark some channels as unavailable. // newOpMarkAsUnavailable is a operation to mark some channels as unavailable.
func newOpMarkAsUnavailable(ctx context.Context, pChannels []types.PChannelInfo) *request { func newOpMarkAsUnavailable(ctx context.Context, pChannels []types.PChannelInfo) *request {
future := syncutil.NewFuture[error]() future := syncutil.NewFuture[error]()

View File

@ -1,6 +1,8 @@
package service package service
import ( import (
"context"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
"github.com/milvus-io/milvus/internal/streamingcoord/server/balancer" "github.com/milvus-io/milvus/internal/streamingcoord/server/balancer"
@ -44,3 +46,16 @@ func (s *assignmentServiceImpl) AssignmentDiscover(server streamingpb.StreamingC
} }
return discover.NewAssignmentDiscoverServer(balancer, server).Execute() return discover.NewAssignmentDiscoverServer(balancer, server).Execute()
} }
// UpdateWALBalancePolicy is used to update the WAL balance policy.
func (s *assignmentServiceImpl) UpdateWALBalancePolicy(ctx context.Context, req *streamingpb.UpdateWALBalancePolicyRequest) (*streamingpb.UpdateWALBalancePolicyResponse, error) {
balancer, err := s.balancer.GetWithContext(ctx)
if err != nil {
return nil, err
}
if err = balancer.UpdateBalancePolicy(ctx, req); err != nil {
return nil, err
}
return &streamingpb.UpdateWALBalancePolicyResponse{}, nil
}

View File

@ -98,6 +98,80 @@ func (_c *MockStreamingCoordAssignmentServiceClient_AssignmentDiscover_Call) Run
return _c return _c
} }
// UpdateWALBalancePolicy provides a mock function with given fields: ctx, in, opts
func (_m *MockStreamingCoordAssignmentServiceClient) UpdateWALBalancePolicy(ctx context.Context, in *streamingpb.UpdateWALBalancePolicyRequest, opts ...grpc.CallOption) (*streamingpb.UpdateWALBalancePolicyResponse, error) {
_va := make([]interface{}, len(opts))
for _i := range opts {
_va[_i] = opts[_i]
}
var _ca []interface{}
_ca = append(_ca, ctx, in)
_ca = append(_ca, _va...)
ret := _m.Called(_ca...)
if len(ret) == 0 {
panic("no return value specified for UpdateWALBalancePolicy")
}
var r0 *streamingpb.UpdateWALBalancePolicyResponse
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, *streamingpb.UpdateWALBalancePolicyRequest, ...grpc.CallOption) (*streamingpb.UpdateWALBalancePolicyResponse, error)); ok {
return rf(ctx, in, opts...)
}
if rf, ok := ret.Get(0).(func(context.Context, *streamingpb.UpdateWALBalancePolicyRequest, ...grpc.CallOption) *streamingpb.UpdateWALBalancePolicyResponse); ok {
r0 = rf(ctx, in, opts...)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*streamingpb.UpdateWALBalancePolicyResponse)
}
}
if rf, ok := ret.Get(1).(func(context.Context, *streamingpb.UpdateWALBalancePolicyRequest, ...grpc.CallOption) error); ok {
r1 = rf(ctx, in, opts...)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockStreamingCoordAssignmentServiceClient_UpdateWALBalancePolicy_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateWALBalancePolicy'
type MockStreamingCoordAssignmentServiceClient_UpdateWALBalancePolicy_Call struct {
*mock.Call
}
// UpdateWALBalancePolicy is a helper method to define mock.On call
// - ctx context.Context
// - in *streamingpb.UpdateWALBalancePolicyRequest
// - opts ...grpc.CallOption
func (_e *MockStreamingCoordAssignmentServiceClient_Expecter) UpdateWALBalancePolicy(ctx interface{}, in interface{}, opts ...interface{}) *MockStreamingCoordAssignmentServiceClient_UpdateWALBalancePolicy_Call {
return &MockStreamingCoordAssignmentServiceClient_UpdateWALBalancePolicy_Call{Call: _e.mock.On("UpdateWALBalancePolicy",
append([]interface{}{ctx, in}, opts...)...)}
}
func (_c *MockStreamingCoordAssignmentServiceClient_UpdateWALBalancePolicy_Call) Run(run func(ctx context.Context, in *streamingpb.UpdateWALBalancePolicyRequest, opts ...grpc.CallOption)) *MockStreamingCoordAssignmentServiceClient_UpdateWALBalancePolicy_Call {
_c.Call.Run(func(args mock.Arguments) {
variadicArgs := make([]grpc.CallOption, len(args)-2)
for i, a := range args[2:] {
if a != nil {
variadicArgs[i] = a.(grpc.CallOption)
}
}
run(args[0].(context.Context), args[1].(*streamingpb.UpdateWALBalancePolicyRequest), variadicArgs...)
})
return _c
}
func (_c *MockStreamingCoordAssignmentServiceClient_UpdateWALBalancePolicy_Call) Return(_a0 *streamingpb.UpdateWALBalancePolicyResponse, _a1 error) *MockStreamingCoordAssignmentServiceClient_UpdateWALBalancePolicy_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockStreamingCoordAssignmentServiceClient_UpdateWALBalancePolicy_Call) RunAndReturn(run func(context.Context, *streamingpb.UpdateWALBalancePolicyRequest, ...grpc.CallOption) (*streamingpb.UpdateWALBalancePolicyResponse, error)) *MockStreamingCoordAssignmentServiceClient_UpdateWALBalancePolicy_Call {
_c.Call.Return(run)
return _c
}
// NewMockStreamingCoordAssignmentServiceClient creates a new instance of MockStreamingCoordAssignmentServiceClient. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. // NewMockStreamingCoordAssignmentServiceClient creates a new instance of MockStreamingCoordAssignmentServiceClient. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value. // The first argument is typically a *testing.T value.
func NewMockStreamingCoordAssignmentServiceClient(t interface { func NewMockStreamingCoordAssignmentServiceClient(t interface {

View File

@ -10,6 +10,7 @@ import "milvus.proto";
import "schema.proto"; import "schema.proto";
import "google/protobuf/empty.proto"; import "google/protobuf/empty.proto";
import "google/protobuf/any.proto"; import "google/protobuf/any.proto";
import "google/protobuf/field_mask.proto";
// //
// Common // Common
@ -88,8 +89,8 @@ enum BroadcastTaskState {
// BroadcastTask is the task to broadcast the messake. // BroadcastTask is the task to broadcast the messake.
message BroadcastTask { message BroadcastTask {
messages.Message message = 1; // message to be broadcast. messages.Message message = 1; // message to be broadcast.
BroadcastTaskState state = 2; // state of the task. BroadcastTaskState state = 2; // state of the task.
bytes acked_vchannel_bitmap = 3; // given vchannels that have been acked, the size of bitmap is same with message.BroadcastHeader().VChannels. bytes acked_vchannel_bitmap = 3; // given vchannels that have been acked, the size of bitmap is same with message.BroadcastHeader().VChannels.
} }
@ -143,6 +144,10 @@ message BroadcastAckResponse {
// Server: log coord. Running on every log node. // Server: log coord. Running on every log node.
// Client: all log publish/consuming node. // Client: all log publish/consuming node.
service StreamingCoordAssignmentService { service StreamingCoordAssignmentService {
// UpdateWALBalancePolicy is used to update the WAL balance policy.
// The policy is used to control the balance of the WAL.
rpc UpdateWALBalancePolicy(UpdateWALBalancePolicyRequest) returns (UpdateWALBalancePolicyResponse) {};
// AssignmentDiscover is used to discover all log nodes managed by the // AssignmentDiscover is used to discover all log nodes managed by the
// streamingcoord. Channel assignment information will be pushed to client // streamingcoord. Channel assignment information will be pushed to client
// by stream. // by stream.
@ -150,6 +155,24 @@ service StreamingCoordAssignmentService {
returns (stream AssignmentDiscoverResponse) {} returns (stream AssignmentDiscoverResponse) {}
} }
// UpdateWALBalancePolicyRequest is the request to update the WAL balance policy.
message UpdateWALBalancePolicyRequest {
WALBalancePolicyConfig config = 1;
WALBalancePolicyNodes nodes = 2;
google.protobuf.FieldMask update_mask = 3;
}
message WALBalancePolicyConfig {
bool allow_rebalance = 1;
}
message WALBalancePolicyNodes {
repeated int64 freeze_node_ids = 1; // nodes that will be frozen.
repeated int64 defreeze_node_ids = 2; // nodes that will be defrozen.
}
message UpdateWALBalancePolicyResponse {}
// AssignmentDiscoverRequest is the request of Discovery // AssignmentDiscoverRequest is the request of Discovery
message AssignmentDiscoverRequest { message AssignmentDiscoverRequest {
oneof command { oneof command {

File diff suppressed because it is too large Load Diff

View File

@ -239,13 +239,17 @@ var StreamingCoordBroadcastService_ServiceDesc = grpc.ServiceDesc{
} }
const ( const (
StreamingCoordAssignmentService_AssignmentDiscover_FullMethodName = "/milvus.proto.streaming.StreamingCoordAssignmentService/AssignmentDiscover" StreamingCoordAssignmentService_UpdateWALBalancePolicy_FullMethodName = "/milvus.proto.streaming.StreamingCoordAssignmentService/UpdateWALBalancePolicy"
StreamingCoordAssignmentService_AssignmentDiscover_FullMethodName = "/milvus.proto.streaming.StreamingCoordAssignmentService/AssignmentDiscover"
) )
// StreamingCoordAssignmentServiceClient is the client API for StreamingCoordAssignmentService service. // StreamingCoordAssignmentServiceClient is the client API for StreamingCoordAssignmentService service.
// //
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. // For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
type StreamingCoordAssignmentServiceClient interface { type StreamingCoordAssignmentServiceClient interface {
// UpdateWALBalancePolicy is used to update the WAL balance policy.
// The policy is used to control the balance of the WAL.
UpdateWALBalancePolicy(ctx context.Context, in *UpdateWALBalancePolicyRequest, opts ...grpc.CallOption) (*UpdateWALBalancePolicyResponse, error)
// AssignmentDiscover is used to discover all log nodes managed by the // AssignmentDiscover is used to discover all log nodes managed by the
// streamingcoord. Channel assignment information will be pushed to client // streamingcoord. Channel assignment information will be pushed to client
// by stream. // by stream.
@ -260,6 +264,15 @@ func NewStreamingCoordAssignmentServiceClient(cc grpc.ClientConnInterface) Strea
return &streamingCoordAssignmentServiceClient{cc} return &streamingCoordAssignmentServiceClient{cc}
} }
func (c *streamingCoordAssignmentServiceClient) UpdateWALBalancePolicy(ctx context.Context, in *UpdateWALBalancePolicyRequest, opts ...grpc.CallOption) (*UpdateWALBalancePolicyResponse, error) {
out := new(UpdateWALBalancePolicyResponse)
err := c.cc.Invoke(ctx, StreamingCoordAssignmentService_UpdateWALBalancePolicy_FullMethodName, in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *streamingCoordAssignmentServiceClient) AssignmentDiscover(ctx context.Context, opts ...grpc.CallOption) (StreamingCoordAssignmentService_AssignmentDiscoverClient, error) { func (c *streamingCoordAssignmentServiceClient) AssignmentDiscover(ctx context.Context, opts ...grpc.CallOption) (StreamingCoordAssignmentService_AssignmentDiscoverClient, error) {
stream, err := c.cc.NewStream(ctx, &StreamingCoordAssignmentService_ServiceDesc.Streams[0], StreamingCoordAssignmentService_AssignmentDiscover_FullMethodName, opts...) stream, err := c.cc.NewStream(ctx, &StreamingCoordAssignmentService_ServiceDesc.Streams[0], StreamingCoordAssignmentService_AssignmentDiscover_FullMethodName, opts...)
if err != nil { if err != nil {
@ -295,6 +308,9 @@ func (x *streamingCoordAssignmentServiceAssignmentDiscoverClient) Recv() (*Assig
// All implementations should embed UnimplementedStreamingCoordAssignmentServiceServer // All implementations should embed UnimplementedStreamingCoordAssignmentServiceServer
// for forward compatibility // for forward compatibility
type StreamingCoordAssignmentServiceServer interface { type StreamingCoordAssignmentServiceServer interface {
// UpdateWALBalancePolicy is used to update the WAL balance policy.
// The policy is used to control the balance of the WAL.
UpdateWALBalancePolicy(context.Context, *UpdateWALBalancePolicyRequest) (*UpdateWALBalancePolicyResponse, error)
// AssignmentDiscover is used to discover all log nodes managed by the // AssignmentDiscover is used to discover all log nodes managed by the
// streamingcoord. Channel assignment information will be pushed to client // streamingcoord. Channel assignment information will be pushed to client
// by stream. // by stream.
@ -305,6 +321,9 @@ type StreamingCoordAssignmentServiceServer interface {
type UnimplementedStreamingCoordAssignmentServiceServer struct { type UnimplementedStreamingCoordAssignmentServiceServer struct {
} }
func (UnimplementedStreamingCoordAssignmentServiceServer) UpdateWALBalancePolicy(context.Context, *UpdateWALBalancePolicyRequest) (*UpdateWALBalancePolicyResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method UpdateWALBalancePolicy not implemented")
}
func (UnimplementedStreamingCoordAssignmentServiceServer) AssignmentDiscover(StreamingCoordAssignmentService_AssignmentDiscoverServer) error { func (UnimplementedStreamingCoordAssignmentServiceServer) AssignmentDiscover(StreamingCoordAssignmentService_AssignmentDiscoverServer) error {
return status.Errorf(codes.Unimplemented, "method AssignmentDiscover not implemented") return status.Errorf(codes.Unimplemented, "method AssignmentDiscover not implemented")
} }
@ -320,6 +339,24 @@ func RegisterStreamingCoordAssignmentServiceServer(s grpc.ServiceRegistrar, srv
s.RegisterService(&StreamingCoordAssignmentService_ServiceDesc, srv) s.RegisterService(&StreamingCoordAssignmentService_ServiceDesc, srv)
} }
func _StreamingCoordAssignmentService_UpdateWALBalancePolicy_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(UpdateWALBalancePolicyRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(StreamingCoordAssignmentServiceServer).UpdateWALBalancePolicy(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: StreamingCoordAssignmentService_UpdateWALBalancePolicy_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(StreamingCoordAssignmentServiceServer).UpdateWALBalancePolicy(ctx, req.(*UpdateWALBalancePolicyRequest))
}
return interceptor(ctx, in, info, handler)
}
func _StreamingCoordAssignmentService_AssignmentDiscover_Handler(srv interface{}, stream grpc.ServerStream) error { func _StreamingCoordAssignmentService_AssignmentDiscover_Handler(srv interface{}, stream grpc.ServerStream) error {
return srv.(StreamingCoordAssignmentServiceServer).AssignmentDiscover(&streamingCoordAssignmentServiceAssignmentDiscoverServer{stream}) return srv.(StreamingCoordAssignmentServiceServer).AssignmentDiscover(&streamingCoordAssignmentServiceAssignmentDiscoverServer{stream})
} }
@ -352,7 +389,12 @@ func (x *streamingCoordAssignmentServiceAssignmentDiscoverServer) Recv() (*Assig
var StreamingCoordAssignmentService_ServiceDesc = grpc.ServiceDesc{ var StreamingCoordAssignmentService_ServiceDesc = grpc.ServiceDesc{
ServiceName: "milvus.proto.streaming.StreamingCoordAssignmentService", ServiceName: "milvus.proto.streaming.StreamingCoordAssignmentService",
HandlerType: (*StreamingCoordAssignmentServiceServer)(nil), HandlerType: (*StreamingCoordAssignmentServiceServer)(nil),
Methods: []grpc.MethodDesc{}, Methods: []grpc.MethodDesc{
{
MethodName: "UpdateWALBalancePolicy",
Handler: _StreamingCoordAssignmentService_UpdateWALBalancePolicy_Handler,
},
},
Streams: []grpc.StreamDesc{ Streams: []grpc.StreamDesc{
{ {
StreamName: "AssignmentDiscover", StreamName: "AssignmentDiscover",

View File

@ -0,0 +1,9 @@
package types
import "github.com/milvus-io/milvus/pkg/v2/proto/streamingpb"
const (
UpdateMaskPathWALBalancePolicyAllowRebalance = "config.allow_rebalance"
)
type UpdateWALBalancePolicyRequest = streamingpb.UpdateWALBalancePolicyRequest

View File

@ -13,6 +13,7 @@ import (
var ( var (
ErrStopping = errors.New("streaming node is stopping") ErrStopping = errors.New("streaming node is stopping")
ErrNotAlive = errors.New("streaming node is not alive") ErrNotAlive = errors.New("streaming node is not alive")
ErrFrozen = errors.New("streaming node is frozen")
) )
// AssignmentDiscoverWatcher is the interface for watching the assignment discovery. // AssignmentDiscoverWatcher is the interface for watching the assignment discovery.

View File

@ -160,7 +160,6 @@ func (pi *ParamItem) getWithRaw() (result, raw string, err error) {
// SetTempValue set the value for this ParamItem, // SetTempValue set the value for this ParamItem,
// Once value set, ParamItem will use the value instead of underlying config manager. // Once value set, ParamItem will use the value instead of underlying config manager.
// Usage: should only use for unittest, swap empty string will remove the value.
func (pi *ParamItem) SwapTempValue(s string) string { func (pi *ParamItem) SwapTempValue(s string) string {
if s == "" { if s == "" {
if old := pi.tempValue.Swap(nil); old != nil { if old := pi.tempValue.Swap(nil); old != nil {