enhance: add IsRebalanceSuspended interface for wal balancer (#44026)

issue: #43968

Signed-off-by: chyezh <chyezh@outlook.com>
This commit is contained in:
Zhen Ye 2025-08-24 09:19:47 +08:00 committed by GitHub
parent d6b78193cb
commit d0e3a33c37
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 900 additions and 801 deletions

View File

@ -42,8 +42,21 @@ func (b balancerImpl) GetWALDistribution(ctx context.Context, nodeID int64) (*ty
return nil, merr.WrapErrNodeNotFound(nodeID, "streaming node not found") return nil, merr.WrapErrNodeNotFound(nodeID, "streaming node not found")
} }
// IsRebalanceSuspended returns whether the rebalance of the wal is suspended.
func (b balancerImpl) IsRebalanceSuspended(ctx context.Context) (bool, error) {
// Update nothing, just fetch the current policy back.
policy, err := b.streamingCoordClient.Assignment().UpdateWALBalancePolicy(ctx, &types.UpdateWALBalancePolicyRequest{
Config: &streamingpb.WALBalancePolicyConfig{},
UpdateMask: &fieldmaskpb.FieldMask{},
})
if err != nil {
return false, err
}
return !policy.GetConfig().GetAllowRebalance(), nil
}
func (b balancerImpl) SuspendRebalance(ctx context.Context) error { func (b balancerImpl) SuspendRebalance(ctx context.Context) error {
return b.streamingCoordClient.Assignment().UpdateWALBalancePolicy(ctx, &types.UpdateWALBalancePolicyRequest{ _, err := b.streamingCoordClient.Assignment().UpdateWALBalancePolicy(ctx, &types.UpdateWALBalancePolicyRequest{
Config: &streamingpb.WALBalancePolicyConfig{ Config: &streamingpb.WALBalancePolicyConfig{
AllowRebalance: false, AllowRebalance: false,
}, },
@ -51,10 +64,11 @@ func (b balancerImpl) SuspendRebalance(ctx context.Context) error {
Paths: []string{types.UpdateMaskPathWALBalancePolicyAllowRebalance}, Paths: []string{types.UpdateMaskPathWALBalancePolicyAllowRebalance},
}, },
}) })
return err
} }
func (b balancerImpl) ResumeRebalance(ctx context.Context) error { func (b balancerImpl) ResumeRebalance(ctx context.Context) error {
return b.streamingCoordClient.Assignment().UpdateWALBalancePolicy(ctx, &types.UpdateWALBalancePolicyRequest{ _, err := b.streamingCoordClient.Assignment().UpdateWALBalancePolicy(ctx, &types.UpdateWALBalancePolicyRequest{
Config: &streamingpb.WALBalancePolicyConfig{ Config: &streamingpb.WALBalancePolicyConfig{
AllowRebalance: true, AllowRebalance: true,
}, },
@ -62,22 +76,25 @@ func (b balancerImpl) ResumeRebalance(ctx context.Context) error {
Paths: []string{types.UpdateMaskPathWALBalancePolicyAllowRebalance}, Paths: []string{types.UpdateMaskPathWALBalancePolicyAllowRebalance},
}, },
}) })
return err
} }
func (b balancerImpl) FreezeNodeIDs(ctx context.Context, nodeIDs []int64) error { func (b balancerImpl) FreezeNodeIDs(ctx context.Context, nodeIDs []int64) error {
return b.streamingCoordClient.Assignment().UpdateWALBalancePolicy(ctx, &types.UpdateWALBalancePolicyRequest{ _, err := b.streamingCoordClient.Assignment().UpdateWALBalancePolicy(ctx, &types.UpdateWALBalancePolicyRequest{
UpdateMask: &fieldmaskpb.FieldMask{Paths: []string{}}, UpdateMask: &fieldmaskpb.FieldMask{Paths: []string{}},
Nodes: &streamingpb.WALBalancePolicyNodes{ Nodes: &streamingpb.WALBalancePolicyNodes{
FreezeNodeIds: nodeIDs, FreezeNodeIds: nodeIDs,
}, },
}) })
return err
} }
func (b balancerImpl) DefreezeNodeIDs(ctx context.Context, nodeIDs []int64) error { func (b balancerImpl) DefreezeNodeIDs(ctx context.Context, nodeIDs []int64) error {
return b.streamingCoordClient.Assignment().UpdateWALBalancePolicy(ctx, &types.UpdateWALBalancePolicyRequest{ _, err := b.streamingCoordClient.Assignment().UpdateWALBalancePolicy(ctx, &types.UpdateWALBalancePolicyRequest{
UpdateMask: &fieldmaskpb.FieldMask{Paths: []string{}}, UpdateMask: &fieldmaskpb.FieldMask{Paths: []string{}},
Nodes: &streamingpb.WALBalancePolicyNodes{ Nodes: &streamingpb.WALBalancePolicyNodes{
DefreezeNodeIds: nodeIDs, DefreezeNodeIds: nodeIDs,
}, },
}) })
return err
} }

View File

@ -55,7 +55,7 @@ func TestBalancer(t *testing.T) {
assert.Error(t, err) assert.Error(t, err)
assert.Nil(t, assignment) assert.Nil(t, assignment)
assignmentService.EXPECT().UpdateWALBalancePolicy(mock.Anything, mock.Anything).Return(nil) assignmentService.EXPECT().UpdateWALBalancePolicy(mock.Anything, mock.Anything).Return(&types.UpdateWALBalancePolicyResponse{}, nil)
err = balancer.SuspendRebalance(context.Background()) err = balancer.SuspendRebalance(context.Background())
assert.NoError(t, err) assert.NoError(t, err)
err = balancer.ResumeRebalance(context.Background()) err = balancer.ResumeRebalance(context.Background())
@ -66,7 +66,7 @@ func TestBalancer(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assignmentService.EXPECT().UpdateWALBalancePolicy(mock.Anything, mock.Anything).Unset() assignmentService.EXPECT().UpdateWALBalancePolicy(mock.Anything, mock.Anything).Unset()
assignmentService.EXPECT().UpdateWALBalancePolicy(mock.Anything, mock.Anything).Return(errors.New("test")) assignmentService.EXPECT().UpdateWALBalancePolicy(mock.Anything, mock.Anything).Return(nil, errors.New("test"))
err = balancer.SuspendRebalance(context.Background()) err = balancer.SuspendRebalance(context.Background())
assert.Error(t, err) assert.Error(t, err)
err = balancer.ResumeRebalance(context.Background()) err = balancer.ResumeRebalance(context.Background())

View File

@ -89,6 +89,9 @@ type Balancer interface {
// GetWALDistribution returns the wal distribution of the streaming node. // GetWALDistribution returns the wal distribution of the streaming node.
GetWALDistribution(ctx context.Context, nodeID int64) (*types.StreamingNodeAssignment, error) GetWALDistribution(ctx context.Context, nodeID int64) (*types.StreamingNodeAssignment, error)
// IsRebalanceSuspended returns whether the rebalance of the wal is suspended.
IsRebalanceSuspended(ctx context.Context) (bool, error)
// SuspendRebalance suspends the rebalance of the wal. // SuspendRebalance suspends the rebalance of the wal.
SuspendRebalance(ctx context.Context) error SuspendRebalance(ctx context.Context) error

View File

@ -61,6 +61,10 @@ func (n *noopBalancer) GetWALDistribution(ctx context.Context, nodeID int64) (*t
return nil, nil return nil, nil
} }
func (n *noopBalancer) IsRebalanceSuspended(ctx context.Context) (bool, error) {
return false, nil
}
func (n *noopBalancer) SuspendRebalance(ctx context.Context) error { func (n *noopBalancer) SuspendRebalance(ctx context.Context) error {
return nil return nil
} }

View File

@ -5,10 +5,8 @@ package mock_client
import ( import (
context "context" context "context"
streamingpb "github.com/milvus-io/milvus/pkg/v2/proto/streamingpb"
mock "github.com/stretchr/testify/mock"
types "github.com/milvus-io/milvus/pkg/v2/streaming/util/types" types "github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
mock "github.com/stretchr/testify/mock"
) )
// MockAssignmentService is an autogenerated mock type for the AssignmentService type // MockAssignmentService is an autogenerated mock type for the AssignmentService type
@ -178,21 +176,33 @@ func (_c *MockAssignmentService_ReportAssignmentError_Call) RunAndReturn(run fun
} }
// UpdateWALBalancePolicy provides a mock function with given fields: ctx, req // UpdateWALBalancePolicy provides a mock function with given fields: ctx, req
func (_m *MockAssignmentService) UpdateWALBalancePolicy(ctx context.Context, req *streamingpb.UpdateWALBalancePolicyRequest) error { func (_m *MockAssignmentService) UpdateWALBalancePolicy(ctx context.Context, req *types.UpdateWALBalancePolicyRequest) (*types.UpdateWALBalancePolicyResponse, error) {
ret := _m.Called(ctx, req) ret := _m.Called(ctx, req)
if len(ret) == 0 { if len(ret) == 0 {
panic("no return value specified for UpdateWALBalancePolicy") panic("no return value specified for UpdateWALBalancePolicy")
} }
var r0 error var r0 *types.UpdateWALBalancePolicyResponse
if rf, ok := ret.Get(0).(func(context.Context, *streamingpb.UpdateWALBalancePolicyRequest) error); ok { var r1 error
if rf, ok := ret.Get(0).(func(context.Context, *types.UpdateWALBalancePolicyRequest) (*types.UpdateWALBalancePolicyResponse, error)); ok {
return rf(ctx, req)
}
if rf, ok := ret.Get(0).(func(context.Context, *types.UpdateWALBalancePolicyRequest) *types.UpdateWALBalancePolicyResponse); ok {
r0 = rf(ctx, req) r0 = rf(ctx, req)
} else { } else {
r0 = ret.Error(0) if ret.Get(0) != nil {
r0 = ret.Get(0).(*types.UpdateWALBalancePolicyResponse)
}
} }
return r0 if rf, ok := ret.Get(1).(func(context.Context, *types.UpdateWALBalancePolicyRequest) error); ok {
r1 = rf(ctx, req)
} else {
r1 = ret.Error(1)
}
return r0, r1
} }
// MockAssignmentService_UpdateWALBalancePolicy_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateWALBalancePolicy' // MockAssignmentService_UpdateWALBalancePolicy_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateWALBalancePolicy'
@ -202,24 +212,24 @@ type MockAssignmentService_UpdateWALBalancePolicy_Call struct {
// UpdateWALBalancePolicy is a helper method to define mock.On call // UpdateWALBalancePolicy is a helper method to define mock.On call
// - ctx context.Context // - ctx context.Context
// - req *streamingpb.UpdateWALBalancePolicyRequest // - req *types.UpdateWALBalancePolicyRequest
func (_e *MockAssignmentService_Expecter) UpdateWALBalancePolicy(ctx interface{}, req interface{}) *MockAssignmentService_UpdateWALBalancePolicy_Call { func (_e *MockAssignmentService_Expecter) UpdateWALBalancePolicy(ctx interface{}, req interface{}) *MockAssignmentService_UpdateWALBalancePolicy_Call {
return &MockAssignmentService_UpdateWALBalancePolicy_Call{Call: _e.mock.On("UpdateWALBalancePolicy", ctx, req)} return &MockAssignmentService_UpdateWALBalancePolicy_Call{Call: _e.mock.On("UpdateWALBalancePolicy", ctx, req)}
} }
func (_c *MockAssignmentService_UpdateWALBalancePolicy_Call) Run(run func(ctx context.Context, req *streamingpb.UpdateWALBalancePolicyRequest)) *MockAssignmentService_UpdateWALBalancePolicy_Call { func (_c *MockAssignmentService_UpdateWALBalancePolicy_Call) Run(run func(ctx context.Context, req *types.UpdateWALBalancePolicyRequest)) *MockAssignmentService_UpdateWALBalancePolicy_Call {
_c.Call.Run(func(args mock.Arguments) { _c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(*streamingpb.UpdateWALBalancePolicyRequest)) run(args[0].(context.Context), args[1].(*types.UpdateWALBalancePolicyRequest))
}) })
return _c return _c
} }
func (_c *MockAssignmentService_UpdateWALBalancePolicy_Call) Return(_a0 error) *MockAssignmentService_UpdateWALBalancePolicy_Call { func (_c *MockAssignmentService_UpdateWALBalancePolicy_Call) Return(_a0 *types.UpdateWALBalancePolicyResponse, _a1 error) *MockAssignmentService_UpdateWALBalancePolicy_Call {
_c.Call.Return(_a0) _c.Call.Return(_a0, _a1)
return _c return _c
} }
func (_c *MockAssignmentService_UpdateWALBalancePolicy_Call) RunAndReturn(run func(context.Context, *streamingpb.UpdateWALBalancePolicyRequest) error) *MockAssignmentService_UpdateWALBalancePolicy_Call { func (_c *MockAssignmentService_UpdateWALBalancePolicy_Call) RunAndReturn(run func(context.Context, *types.UpdateWALBalancePolicyRequest) (*types.UpdateWALBalancePolicyResponse, error)) *MockAssignmentService_UpdateWALBalancePolicy_Call {
_c.Call.Return(run) _c.Call.Return(run)
return _c return _c
} }

View File

@ -244,21 +244,33 @@ func (_c *MockBalancer_Trigger_Call) RunAndReturn(run func(context.Context) erro
} }
// UpdateBalancePolicy provides a mock function with given fields: ctx, req // UpdateBalancePolicy provides a mock function with given fields: ctx, req
func (_m *MockBalancer) UpdateBalancePolicy(ctx context.Context, req *streamingpb.UpdateWALBalancePolicyRequest) error { func (_m *MockBalancer) UpdateBalancePolicy(ctx context.Context, req *streamingpb.UpdateWALBalancePolicyRequest) (*streamingpb.UpdateWALBalancePolicyResponse, error) {
ret := _m.Called(ctx, req) ret := _m.Called(ctx, req)
if len(ret) == 0 { if len(ret) == 0 {
panic("no return value specified for UpdateBalancePolicy") panic("no return value specified for UpdateBalancePolicy")
} }
var r0 error var r0 *streamingpb.UpdateWALBalancePolicyResponse
if rf, ok := ret.Get(0).(func(context.Context, *streamingpb.UpdateWALBalancePolicyRequest) error); ok { var r1 error
if rf, ok := ret.Get(0).(func(context.Context, *streamingpb.UpdateWALBalancePolicyRequest) (*streamingpb.UpdateWALBalancePolicyResponse, error)); ok {
return rf(ctx, req)
}
if rf, ok := ret.Get(0).(func(context.Context, *streamingpb.UpdateWALBalancePolicyRequest) *streamingpb.UpdateWALBalancePolicyResponse); ok {
r0 = rf(ctx, req) r0 = rf(ctx, req)
} else { } else {
r0 = ret.Error(0) if ret.Get(0) != nil {
r0 = ret.Get(0).(*streamingpb.UpdateWALBalancePolicyResponse)
}
} }
return r0 if rf, ok := ret.Get(1).(func(context.Context, *streamingpb.UpdateWALBalancePolicyRequest) error); ok {
r1 = rf(ctx, req)
} else {
r1 = ret.Error(1)
}
return r0, r1
} }
// MockBalancer_UpdateBalancePolicy_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateBalancePolicy' // MockBalancer_UpdateBalancePolicy_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateBalancePolicy'
@ -280,12 +292,12 @@ func (_c *MockBalancer_UpdateBalancePolicy_Call) Run(run func(ctx context.Contex
return _c return _c
} }
func (_c *MockBalancer_UpdateBalancePolicy_Call) Return(_a0 error) *MockBalancer_UpdateBalancePolicy_Call { func (_c *MockBalancer_UpdateBalancePolicy_Call) Return(_a0 *streamingpb.UpdateWALBalancePolicyResponse, _a1 error) *MockBalancer_UpdateBalancePolicy_Call {
_c.Call.Return(_a0) _c.Call.Return(_a0, _a1)
return _c return _c
} }
func (_c *MockBalancer_UpdateBalancePolicy_Call) RunAndReturn(run func(context.Context, *streamingpb.UpdateWALBalancePolicyRequest) error) *MockBalancer_UpdateBalancePolicy_Call { func (_c *MockBalancer_UpdateBalancePolicy_Call) RunAndReturn(run func(context.Context, *streamingpb.UpdateWALBalancePolicyRequest) (*streamingpb.UpdateWALBalancePolicyResponse, error)) *MockBalancer_UpdateBalancePolicy_Call {
_c.Call.Return(run) _c.Call.Return(run)
return _c return _c
} }

View File

@ -57,18 +57,17 @@ func (c *AssignmentServiceImpl) GetLatestAssignments(ctx context.Context) (*type
return c.watcher.GetLatestDiscover(ctx) return c.watcher.GetLatestDiscover(ctx)
} }
func (c *AssignmentServiceImpl) UpdateWALBalancePolicy(ctx context.Context, req *streamingpb.UpdateWALBalancePolicyRequest) error { func (c *AssignmentServiceImpl) UpdateWALBalancePolicy(ctx context.Context, req *streamingpb.UpdateWALBalancePolicyRequest) (*types.UpdateWALBalancePolicyResponse, error) {
if !c.lifetime.Add(typeutil.LifetimeStateWorking) { if !c.lifetime.Add(typeutil.LifetimeStateWorking) {
return status.NewOnShutdownError("assignment service client is closing") return nil, status.NewOnShutdownError("assignment service client is closing")
} }
defer c.lifetime.Done() defer c.lifetime.Done()
service, err := c.service.GetService(c.ctx) service, err := c.service.GetService(c.ctx)
if err != nil { if err != nil {
return err return nil, err
} }
_, err = service.UpdateWALBalancePolicy(ctx, req) return service.UpdateWALBalancePolicy(ctx, req)
return err
} }
// AssignmentDiscover watches the assignment discovery. // AssignmentDiscover watches the assignment discovery.

View File

@ -98,8 +98,9 @@ func TestAssignmentService(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.True(t, assign.Version.EQ(typeutil.VersionInt64Pair{Global: 2, Local: 3})) assert.True(t, assign.Version.EQ(typeutil.VersionInt64Pair{Global: 2, Local: 3}))
err = assignmentService.UpdateWALBalancePolicy(ctx, &streamingpb.UpdateWALBalancePolicyRequest{}) resp, err := assignmentService.UpdateWALBalancePolicy(ctx, &streamingpb.UpdateWALBalancePolicyRequest{})
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, resp)
assignmentService.ReportAssignmentError(ctx, types.PChannelInfo{Name: "c1", Term: 1}, errors.New("test")) assignmentService.ReportAssignmentError(ctx, types.PChannelInfo{Name: "c1", Term: 1}, errors.New("test"))

View File

@ -37,7 +37,8 @@ type AssignmentService interface {
GetLatestAssignments(ctx context.Context) (*types.VersionedStreamingNodeAssignments, error) GetLatestAssignments(ctx context.Context) (*types.VersionedStreamingNodeAssignments, error)
// UpdateWALBalancePolicy is used to update the WAL balance policy. // UpdateWALBalancePolicy is used to update the WAL balance policy.
UpdateWALBalancePolicy(ctx context.Context, req *streamingpb.UpdateWALBalancePolicyRequest) error // Return the WAL balance policy after the update.
UpdateWALBalancePolicy(ctx context.Context, req *types.UpdateWALBalancePolicyRequest) (*types.UpdateWALBalancePolicyResponse, error)
} }
// BroadcastService is the interface of broadcast service. // BroadcastService is the interface of broadcast service.

View File

@ -22,7 +22,7 @@ var (
// Balancer should be thread safe. // Balancer should be thread safe.
type Balancer interface { type Balancer interface {
// UpdateBalancePolicy update the balance policy. // UpdateBalancePolicy update the balance policy.
UpdateBalancePolicy(ctx context.Context, req *streamingpb.UpdateWALBalancePolicyRequest) error UpdateBalancePolicy(ctx context.Context, req *streamingpb.UpdateWALBalancePolicyRequest) (*streamingpb.UpdateWALBalancePolicyResponse, error)
// RegisterStreamingEnabledNotifier registers a notifier into the balancer. // RegisterStreamingEnabledNotifier registers a notifier into the balancer.
// If the error is returned, the balancer is closed. // If the error is returned, the balancer is closed.

View File

@ -98,15 +98,19 @@ func (b *balancerImpl) WatchChannelAssignments(ctx context.Context, cb func(vers
} }
// UpdateBalancePolicy update the balance policy. // UpdateBalancePolicy update the balance policy.
func (b *balancerImpl) UpdateBalancePolicy(ctx context.Context, req *types.UpdateWALBalancePolicyRequest) error { func (b *balancerImpl) UpdateBalancePolicy(ctx context.Context, req *types.UpdateWALBalancePolicyRequest) (*types.UpdateWALBalancePolicyResponse, error) {
if !b.lifetime.Add(typeutil.LifetimeStateWorking) { if !b.lifetime.Add(typeutil.LifetimeStateWorking) {
return status.NewOnShutdownError("balancer is closing") return nil, status.NewOnShutdownError("balancer is closing")
} }
defer b.lifetime.Done() defer b.lifetime.Done()
ctx, cancel := contextutil.MergeContext(ctx, b.ctx) ctx, cancel := contextutil.MergeContext(ctx, b.ctx)
defer cancel() defer cancel()
return b.sendRequestAndWaitFinish(ctx, newOpUpdateBalancePolicy(ctx, req)) resp, err := b.sendRequestAndWaitFinish(ctx, newOpUpdateBalancePolicy(ctx, req))
if err != nil {
return nil, err
}
return resp.(*types.UpdateWALBalancePolicyResponse), nil
} }
// MarkAsUnavailable mark the pchannels as unavailable. // MarkAsUnavailable mark the pchannels as unavailable.
@ -118,7 +122,8 @@ func (b *balancerImpl) MarkAsUnavailable(ctx context.Context, pChannels []types.
ctx, cancel := contextutil.MergeContext(ctx, b.ctx) ctx, cancel := contextutil.MergeContext(ctx, b.ctx)
defer cancel() defer cancel()
return b.sendRequestAndWaitFinish(ctx, newOpMarkAsUnavailable(ctx, pChannels)) _, err := b.sendRequestAndWaitFinish(ctx, newOpMarkAsUnavailable(ctx, pChannels))
return err
} }
// Trigger trigger a re-balance. // Trigger trigger a re-balance.
@ -130,17 +135,19 @@ func (b *balancerImpl) Trigger(ctx context.Context) error {
ctx, cancel := contextutil.MergeContext(ctx, b.ctx) ctx, cancel := contextutil.MergeContext(ctx, b.ctx)
defer cancel() defer cancel()
return b.sendRequestAndWaitFinish(ctx, newOpTrigger(ctx)) _, err := b.sendRequestAndWaitFinish(ctx, newOpTrigger(ctx))
return err
} }
// sendRequestAndWaitFinish send a request to the background task and wait for it to finish. // sendRequestAndWaitFinish send a request to the background task and wait for it to finish.
func (b *balancerImpl) sendRequestAndWaitFinish(ctx context.Context, newReq *request) error { func (b *balancerImpl) sendRequestAndWaitFinish(ctx context.Context, newReq *request) (any, error) {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return ctx.Err() return nil, ctx.Err()
case b.reqCh <- newReq: case b.reqCh <- newReq:
} }
return newReq.future.Get() resp := newReq.future.Get()
return resp.resp, resp.err
} }
// Close close the balancer. // Close close the balancer.

View File

@ -181,7 +181,7 @@ func TestBalancer(t *testing.T) {
assert.False(t, f.Ready()) assert.False(t, f.Ready())
assert.True(t, paramtable.Get().StreamingCfg.WALBalancerPolicyAllowRebalance.GetAsBool()) assert.True(t, paramtable.Get().StreamingCfg.WALBalancerPolicyAllowRebalance.GetAsBool())
err = b.UpdateBalancePolicy(ctx, &streamingpb.UpdateWALBalancePolicyRequest{ resp, err := b.UpdateBalancePolicy(ctx, &streamingpb.UpdateWALBalancePolicyRequest{
Config: &streamingpb.WALBalancePolicyConfig{ Config: &streamingpb.WALBalancePolicyConfig{
AllowRebalance: false, AllowRebalance: false,
}, },
@ -191,6 +191,8 @@ func TestBalancer(t *testing.T) {
}, },
}) })
assert.NoError(t, err) assert.NoError(t, err)
assert.ElementsMatch(t, []int64{1}, resp.FreezeNodeIds)
assert.False(t, resp.Config.AllowRebalance)
assert.False(t, paramtable.Get().StreamingCfg.WALBalancerPolicyAllowRebalance.GetAsBool()) assert.False(t, paramtable.Get().StreamingCfg.WALBalancerPolicyAllowRebalance.GetAsBool())
b.Trigger(ctx) b.Trigger(ctx)
err = b.WatchChannelAssignments(ctx, func(version typeutil.VersionInt64Pair, relations []types.PChannelInfoAssigned) error { err = b.WatchChannelAssignments(ctx, func(version typeutil.VersionInt64Pair, relations []types.PChannelInfoAssigned) error {
@ -203,7 +205,7 @@ func TestBalancer(t *testing.T) {
}) })
assert.ErrorIs(t, err, doneErr) assert.ErrorIs(t, err, doneErr)
err = b.UpdateBalancePolicy(ctx, &streamingpb.UpdateWALBalancePolicyRequest{ resp, err = b.UpdateBalancePolicy(ctx, &streamingpb.UpdateWALBalancePolicyRequest{
Config: &streamingpb.WALBalancePolicyConfig{ Config: &streamingpb.WALBalancePolicyConfig{
AllowRebalance: true, AllowRebalance: true,
}, },
@ -211,11 +213,12 @@ func TestBalancer(t *testing.T) {
Paths: []string{types.UpdateMaskPathWALBalancePolicyAllowRebalance}, Paths: []string{types.UpdateMaskPathWALBalancePolicyAllowRebalance},
}, },
}) })
assert.True(t, resp.Config.AllowRebalance)
assert.True(t, paramtable.Get().StreamingCfg.WALBalancerPolicyAllowRebalance.GetAsBool()) assert.True(t, paramtable.Get().StreamingCfg.WALBalancerPolicyAllowRebalance.GetAsBool())
assert.NoError(t, err) assert.NoError(t, err)
b.Trigger(ctx) b.Trigger(ctx)
err = b.UpdateBalancePolicy(ctx, &streamingpb.UpdateWALBalancePolicyRequest{ resp, err = b.UpdateBalancePolicy(ctx, &streamingpb.UpdateWALBalancePolicyRequest{
Config: &streamingpb.WALBalancePolicyConfig{ Config: &streamingpb.WALBalancePolicyConfig{
AllowRebalance: false, AllowRebalance: false,
}, },
@ -227,6 +230,8 @@ func TestBalancer(t *testing.T) {
DefreezeNodeIds: []int64{1}, DefreezeNodeIds: []int64{1},
}, },
}) })
assert.True(t, resp.Config.AllowRebalance)
assert.Empty(t, resp.FreezeNodeIds)
assert.True(t, paramtable.Get().StreamingCfg.WALBalancerPolicyAllowRebalance.GetAsBool()) assert.True(t, paramtable.Get().StreamingCfg.WALBalancerPolicyAllowRebalance.GetAsBool())
assert.NoError(t, err) assert.NoError(t, err)
b.Trigger(ctx) b.Trigger(ctx)

View File

@ -6,16 +6,22 @@ import (
"go.uber.org/zap" "go.uber.org/zap"
"github.com/milvus-io/milvus/pkg/v2/proto/streamingpb"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types" "github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable" "github.com/milvus-io/milvus/pkg/v2/util/paramtable"
"github.com/milvus-io/milvus/pkg/v2/util/syncutil" "github.com/milvus-io/milvus/pkg/v2/util/syncutil"
) )
type response struct {
resp any
err error
}
// request is a operation request. // request is a operation request.
type request struct { type request struct {
ctx context.Context ctx context.Context
apply requestApply apply requestApply
future *syncutil.Future[error] future *syncutil.Future[response]
} }
// requestApply is a request operation to be executed. // requestApply is a request operation to be executed.
@ -23,7 +29,7 @@ type requestApply func(impl *balancerImpl)
// newOpUpdateBalancePolicy is a operation to update the balance policy. // newOpUpdateBalancePolicy is a operation to update the balance policy.
func newOpUpdateBalancePolicy(ctx context.Context, req *types.UpdateWALBalancePolicyRequest) *request { func newOpUpdateBalancePolicy(ctx context.Context, req *types.UpdateWALBalancePolicyRequest) *request {
future := syncutil.NewFuture[error]() future := syncutil.NewFuture[response]()
return &request{ return &request{
ctx: ctx, ctx: ctx,
apply: func(impl *balancerImpl) { apply: func(impl *balancerImpl) {
@ -45,7 +51,12 @@ func newOpUpdateBalancePolicy(ctx context.Context, req *types.UpdateWALBalancePo
impl.freezeNodes.Insert(req.GetNodes().GetFreezeNodeIds()...) impl.freezeNodes.Insert(req.GetNodes().GetFreezeNodeIds()...)
impl.freezeNodes.Remove(req.GetNodes().GetDefreezeNodeIds()...) impl.freezeNodes.Remove(req.GetNodes().GetDefreezeNodeIds()...)
} }
future.Set(nil) future.Set(response{resp: &types.UpdateWALBalancePolicyResponse{
Config: &streamingpb.WALBalancePolicyConfig{
AllowRebalance: paramtable.Get().StreamingCfg.WALBalancerPolicyAllowRebalance.GetAsBool(),
},
FreezeNodeIds: impl.freezeNodes.Collect(),
}, err: nil})
}, },
future: future, future: future,
} }
@ -59,11 +70,12 @@ func updateAllowRebalance(impl *balancerImpl, allowRebalance bool) {
// newOpMarkAsUnavailable is a operation to mark some channels as unavailable. // newOpMarkAsUnavailable is a operation to mark some channels as unavailable.
func newOpMarkAsUnavailable(ctx context.Context, pChannels []types.PChannelInfo) *request { func newOpMarkAsUnavailable(ctx context.Context, pChannels []types.PChannelInfo) *request {
future := syncutil.NewFuture[error]() future := syncutil.NewFuture[response]()
return &request{ return &request{
ctx: ctx, ctx: ctx,
apply: func(impl *balancerImpl) { apply: func(impl *balancerImpl) {
future.Set(impl.channelMetaManager.MarkAsUnavailable(ctx, pChannels)) err := impl.channelMetaManager.MarkAsUnavailable(ctx, pChannels)
future.Set(response{err: err})
}, },
future: future, future: future,
} }
@ -71,11 +83,11 @@ func newOpMarkAsUnavailable(ctx context.Context, pChannels []types.PChannelInfo)
// newOpTrigger is a operation to trigger a re-balance operation. // newOpTrigger is a operation to trigger a re-balance operation.
func newOpTrigger(ctx context.Context) *request { func newOpTrigger(ctx context.Context) *request {
future := syncutil.NewFuture[error]() future := syncutil.NewFuture[response]()
return &request{ return &request{
ctx: ctx, ctx: ctx,
apply: func(impl *balancerImpl) { apply: func(impl *balancerImpl) {
future.Set(nil) future.Set(response{})
}, },
future: future, future: future,
} }

View File

@ -54,8 +54,5 @@ func (s *assignmentServiceImpl) UpdateWALBalancePolicy(ctx context.Context, req
return nil, err return nil, err
} }
if err = balancer.UpdateBalancePolicy(ctx, req); err != nil { return balancer.UpdateBalancePolicy(ctx, req)
return nil, err
}
return &streamingpb.UpdateWALBalancePolicyResponse{}, nil
} }

View File

@ -171,7 +171,10 @@ message WALBalancePolicyNodes {
repeated int64 defreeze_node_ids = 2; // nodes that will be defrozen. repeated int64 defreeze_node_ids = 2; // nodes that will be defrozen.
} }
message UpdateWALBalancePolicyResponse {} message UpdateWALBalancePolicyResponse {
WALBalancePolicyConfig config = 1; // return current configuration of WAL balance policy.
repeated int64 freeze_node_ids = 2; // nodes that are frozen.
}
// AssignmentDiscoverRequest is the request of Discovery // AssignmentDiscoverRequest is the request of Discovery
message AssignmentDiscoverRequest { message AssignmentDiscoverRequest {

File diff suppressed because it is too large Load Diff

View File

@ -6,4 +6,7 @@ const (
UpdateMaskPathWALBalancePolicyAllowRebalance = "config.allow_rebalance" UpdateMaskPathWALBalancePolicyAllowRebalance = "config.allow_rebalance"
) )
type UpdateWALBalancePolicyRequest = streamingpb.UpdateWALBalancePolicyRequest type (
UpdateWALBalancePolicyRequest = streamingpb.UpdateWALBalancePolicyRequest
UpdateWALBalancePolicyResponse = streamingpb.UpdateWALBalancePolicyResponse
)