enhance: broadcaster will lock resource until message acked (#44508)

issue: #43897

- Return LastConfirmedMessageID from the WAL append operation.
- Add a resource-key-based locker for the broadcast/ack operation to
protect the coordinator state while executing DDL.
- The resource-key-based lock is held until the broadcast operation is
acked (see the sketch after this list).
- ResourceKey supports shared and exclusive locks.
- Add FastAck, which executes the ack right after the broadcast is done,
to speed up DDL.
- The ack callback now receives the broadcast message result.
- Add tombstones to the broadcaster to avoid repeatedly committing DDL
and the ABA issue.
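
A minimal sketch of the intended coordinator-side call pattern, based on the
`BroadcastAPI` interface and the `broadcast` singleton package added in this
change. The wrapper function, its signature, and the exact import path of the
`broadcast` package are illustrative assumptions, not code from this PR:

```go
package example

import (
	"context"

	"github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster/broadcast"
	"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
	"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
)

// broadcastDDL is a hypothetical helper showing the new flow: lock the resource
// keys first, then broadcast; the broadcaster keeps the resource-key lock until
// the ack callback of the message has completed.
func broadcastDDL(ctx context.Context, msg message.BroadcastMutableMessage, keys ...message.ResourceKey) (*types.BroadcastAppendResult, error) {
	// Acquire the resource keys (the broadcaster appends the shared cluster
	// resource key automatically) and get a broadcast handle.
	api, err := broadcast.StartBroadcastWithResourceKeys(ctx, keys...)
	if err != nil {
		return nil, err
	}
	result, err := api.Broadcast(ctx, msg)
	if err != nil {
		// Assumption: Close releases the handle when the broadcast was not
		// issued; after a successful broadcast the resource keys stay locked
		// until the ack callback is done.
		api.Close()
		return nil, err
	}
	return result, nil
}
```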

---------

Signed-off-by: chyezh <chyezh@outlook.com>
Zhen Ye 2025-09-24 20:58:05 +08:00 committed by GitHub
parent 1b20e956be
commit 19e5e9f910
GPG Key ID: B5690EEEBB952194
59 changed files with 2159 additions and 546 deletions

View File

@ -20,6 +20,7 @@ packages:
github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster:
interfaces:
Broadcaster:
BroadcastAPI:
AppendOperator:
Watcher:
github.com/milvus-io/milvus/internal/streamingcoord/client:

View File

@ -8,6 +8,7 @@ import (
"go.uber.org/zap"
"github.com/milvus-io/milvus/internal/streamingcoord/server/balancer"
"github.com/milvus-io/milvus/internal/streamingcoord/server/balancer/balance"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
@ -23,7 +24,6 @@ var ErrStreamingServiceNotReady = errors.New("streaming service is not ready, ma
func newStreamingNodeManager() *StreamingNodeManager {
snm := &StreamingNodeManager{
notifier: syncutil.NewAsyncTaskNotifier[struct{}](),
balancer: syncutil.NewFuture[balancer.Balancer](),
cond: syncutil.NewContextCond(&sync.Mutex{}),
latestAssignments: make(map[string]types.PChannelInfoAssigned),
nodeChangedNotifier: syncutil.NewVersionedNotifier(),
@ -64,8 +64,6 @@ func (s *StreamingReadyNotifier) IsReady() bool {
// StreamingNodeManager is exclusive with ResourceManager.
type StreamingNodeManager struct {
notifier *syncutil.AsyncTaskNotifier[struct{}]
balancer *syncutil.Future[balancer.Balancer]
// The coord is merged after 2.6, so we don't need to make distribution safe.
cond *syncutil.ContextCond
latestAssignments map[string]types.PChannelInfoAssigned // The latest assignments info got from streaming coord balance module.
nodeChangedNotifier *syncutil.VersionedNotifier // used to notify that node in streaming node manager has been changed.
@ -73,14 +71,18 @@ type StreamingNodeManager struct {
// GetBalancer returns the balancer of the streaming node manager.
func (s *StreamingNodeManager) GetBalancer() balancer.Balancer {
return s.balancer.Get()
b, err := balance.GetWithContext(context.Background())
if err != nil {
panic(err)
}
return b
}
// GetLatestWALLocated returns the server id of the node that the wal of the vChannel is located.
// Return -1 and error if the vchannel is not found or context is canceled.
func (s *StreamingNodeManager) GetLatestWALLocated(ctx context.Context, vchannel string) (int64, error) {
pchannel := funcutil.ToPhysicalChannel(vchannel)
balancer, err := s.balancer.GetWithContext(ctx)
balancer, err := balance.GetWithContext(ctx)
if err != nil {
return -1, err
}
@ -107,7 +109,7 @@ func (s *StreamingNodeManager) CheckIfStreamingServiceReady(ctx context.Context)
// RegisterStreamingEnabledNotifier registers a notifier into the balancer.
func (s *StreamingNodeManager) RegisterStreamingEnabledListener(ctx context.Context, notifier *StreamingReadyNotifier) error {
balancer, err := s.balancer.GetWithContext(ctx)
balancer, err := balance.GetWithContext(ctx)
if err != nil {
return err
}
@ -134,7 +136,7 @@ func (s *StreamingNodeManager) GetWALLocated(vChannel string) int64 {
// GetStreamingQueryNodeIDs returns the server ids of the streaming query nodes.
func (s *StreamingNodeManager) GetStreamingQueryNodeIDs() typeutil.UniqueSet {
balancer, err := s.balancer.GetWithContext(context.Background())
balancer, err := balance.GetWithContext(context.Background())
if err != nil {
panic(err)
}
@ -154,15 +156,10 @@ func (s *StreamingNodeManager) ListenNodeChanged() *syncutil.VersionedListener {
return s.nodeChangedNotifier.Listen(syncutil.VersionedListenAtEarliest)
}
// SetBalancerReady set the balancer ready for the streaming node manager from streamingcoord initialization.
func (s *StreamingNodeManager) SetBalancerReady(b balancer.Balancer) {
s.balancer.Set(b)
}
func (s *StreamingNodeManager) execute() (err error) {
defer s.notifier.Finish(struct{}{})
b, err := s.balancer.GetWithContext(s.notifier.Context())
b, err := balance.GetWithContext(s.notifier.Context())
if err != nil {
return errors.Wrap(err, "failed to wait balancer ready")
}
@ -182,3 +179,8 @@ func (s *StreamingNodeManager) execute() (err error) {
}
}
}
func (s *StreamingNodeManager) Close() {
s.notifier.Cancel()
s.notifier.BlockUntilFinish()
}

View File

@ -9,6 +9,7 @@ import (
"github.com/milvus-io/milvus/internal/mocks/streamingcoord/server/mock_balancer"
"github.com/milvus-io/milvus/internal/streamingcoord/server/balancer"
"github.com/milvus-io/milvus/internal/streamingcoord/server/balancer/balance"
"github.com/milvus-io/milvus/pkg/v2/proto/streamingpb"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
@ -20,17 +21,18 @@ type pChannelInfoAssigned struct {
}
func TestStreamingNodeManager(t *testing.T) {
StaticStreamingNodeManager.Close()
m := newStreamingNodeManager()
b := mock_balancer.NewMockBalancer(t)
ch := make(chan pChannelInfoAssigned, 1)
b.EXPECT().GetAllStreamingNodes(mock.Anything).Return(map[int64]*types.StreamingNodeInfo{}, nil)
b.EXPECT().WatchChannelAssignments(mock.Anything, mock.Anything).Run(
func(ctx context.Context, cb balancer.WatchChannelAssignmentsCallback) {
b.EXPECT().WatchChannelAssignments(mock.Anything, mock.Anything).RunAndReturn(
func(ctx context.Context, cb balancer.WatchChannelAssignmentsCallback) error {
for {
select {
case <-ctx.Done():
return
return ctx.Err()
case p := <-ch:
cb(balancer.WatchChannelAssignmentsCallbackParam{
Version: p.version,
@ -41,7 +43,7 @@ func TestStreamingNodeManager(t *testing.T) {
}
})
b.EXPECT().RegisterStreamingEnabledNotifier(mock.Anything).Return()
m.SetBalancerReady(b)
balance.Register(b)
streamingNodes := m.GetStreamingQueryNodeIDs()
assert.Empty(t, streamingNodes)

View File

@ -11,10 +11,13 @@ import (
"github.com/milvus-io/milvus/internal/mocks/streamingcoord/server/mock_balancer"
"github.com/milvus-io/milvus/internal/streamingcoord/server/balancer"
"github.com/milvus-io/milvus/internal/streamingcoord/server/balancer/balance"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
)
func ResetStreamingNodeManager() {
StaticStreamingNodeManager.Close()
balance.ResetBalancer()
StaticStreamingNodeManager = newStreamingNodeManager()
}
@ -26,5 +29,5 @@ func ResetDoNothingStreamingNodeManager(t *testing.T) {
return ctx.Err()
}).Maybe()
b.EXPECT().GetAllStreamingNodes(mock.Anything).Return(map[int64]*types.StreamingNodeInfo{}, nil).Maybe()
StaticStreamingNodeManager.SetBalancerReady(b)
balance.Register(b)
}

View File

@ -353,17 +353,24 @@ func (s *Server) initDataCoord() error {
// initMessageCallback initializes the message callback.
// TODO: we should build a ddl framework to handle the message ack callback for ddl messages
func (s *Server) initMessageCallback() {
registry.RegisterDropPartitionV1AckCallback(func(ctx context.Context, msg message.ImmutableDropPartitionMessageV1) error {
return s.NotifyDropPartition(ctx, msg.VChannel(), []int64{msg.Header().PartitionId})
registry.RegisterDropPartitionV1AckCallback(func(ctx context.Context, result message.BroadcastResultDropPartitionMessageV1) error {
partitionID := result.Message.Header().PartitionId
for _, vchannel := range result.GetVChannelsWithoutControlChannel() {
if err := s.NotifyDropPartition(ctx, vchannel, []int64{partitionID}); err != nil {
return err
}
}
return nil
})
registry.RegisterImportV1AckCallback(func(ctx context.Context, msg message.ImmutableImportMessageV1) error {
body := msg.MustBody()
registry.RegisterImportV1AckCallback(func(ctx context.Context, result message.BroadcastResultImportMessageV1) error {
body := result.Message.MustBody()
vchannels := result.GetVChannelsWithoutControlChannel()
importResp, err := s.ImportV2(ctx, &internalpb.ImportRequestInternal{
CollectionID: body.GetCollectionID(),
CollectionName: body.GetCollectionName(),
PartitionIDs: body.GetPartitionIDs(),
ChannelNames: []string{msg.VChannel()},
ChannelNames: vchannels,
Schema: body.GetSchema(),
Files: lo.Map(body.GetFiles(), func(file *msgpb.ImportFile, _ int) *internalpb.ImportFile {
return &internalpb.ImportFile{

View File

@ -61,7 +61,7 @@ import (
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
"github.com/milvus-io/milvus/pkg/v2/proto/workerpb"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/streaming/walimpls/impls/rmq"
"github.com/milvus-io/milvus/pkg/v2/streaming/walimpls/impls/walimplstest"
"github.com/milvus-io/milvus/pkg/v2/util/etcd"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/metricsinfo"
@ -2887,8 +2887,7 @@ func TestServer_InitMessageCallback(t *testing.T) {
server.initMessageCallback()
// Test DropPartition message callback
dropPartitionMsg, err := message.NewDropPartitionMessageBuilderV1().
WithVChannel("test_channel").
dropPartitionMsg := message.NewDropPartitionMessageBuilderV1().
WithHeader(&message.DropPartitionMessageHeader{
CollectionId: 1,
PartitionId: 1,
@ -2898,9 +2897,15 @@ func TestServer_InitMessageCallback(t *testing.T) {
MsgType: commonpb.MsgType_DropPartition,
},
}).
BuildMutable()
assert.NoError(t, err)
err = registry.CallMessageAckCallback(ctx, dropPartitionMsg.IntoImmutableMessage(rmq.NewRmqID(1)))
WithBroadcast([]string{"test_channel"}, message.NewImportJobIDResourceKey(1)).
MustBuildBroadcast()
err := registry.CallMessageAckCallback(ctx, dropPartitionMsg, map[string]*message.AppendResult{
"test_channel": {
MessageID: walimplstest.NewTestMessageID(1),
LastConfirmedMessageID: walimplstest.NewTestMessageID(1),
TimeTick: 1,
},
})
assert.Error(t, err) // server not healthy
// Test Import message check callback
@ -2918,16 +2923,22 @@ func TestServer_InitMessageCallback(t *testing.T) {
assert.NoError(t, err)
// Test Import message ack callback
importMsg, err := message.NewImportMessageBuilderV1().
WithVChannel("test_channel").
importMsg := message.NewImportMessageBuilderV1().
WithHeader(&message.ImportMessageHeader{}).
WithBody(&msgpb.ImportMsg{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Import,
},
}).
BuildMutable()
assert.NoError(t, err)
err = registry.CallMessageAckCallback(ctx, importMsg.IntoImmutableMessage(rmq.NewRmqID(1)))
WithBroadcast([]string{"test_channel"}, resourceKey).
MustBuildBroadcast()
err = registry.CallMessageAckCallback(ctx, importMsg, map[string]*message.AppendResult{
"test_channel": {
MessageID: walimplstest.NewTestMessageID(1),
LastConfirmedMessageID: walimplstest.NewTestMessageID(1),
TimeTick: 1,
},
},
)
assert.Error(t, err) // server not healthy
}

View File

@ -12,6 +12,7 @@ import (
"github.com/milvus-io/milvus/internal/coordinator/snmanager"
"github.com/milvus-io/milvus/internal/mocks/streamingcoord/server/mock_balancer"
"github.com/milvus-io/milvus/internal/streamingcoord/server/balancer"
"github.com/milvus-io/milvus/internal/streamingcoord/server/balancer/balance"
"github.com/milvus-io/milvus/pkg/v2/proto/streamingpb"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
@ -54,7 +55,7 @@ func TestBalancer(t *testing.T) {
})
snmanager.ResetStreamingNodeManager()
snmanager.StaticStreamingNodeManager.SetBalancerReady(sbalancer)
balance.Register(sbalancer)
balancer := balancerImpl{
walAccesserImpl: &walAccesserImpl{},

View File

@ -36,7 +36,7 @@ import (
// │   └── cluster-2-pchannel-2
func NewCataLog(metaKV kv.MetaKv) metastore.StreamingCoordCataLog {
return &catalog{
metaKV: metaKV,
metaKV: kv.NewReliableWriteMetaKv(metaKV),
}
}

View File

@ -13,6 +13,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus/pkg/v2/mocks/mock_kv"
"github.com/milvus-io/milvus/pkg/v2/proto/streamingpb"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
)
func TestCatalog(t *testing.T) {
@ -128,18 +129,6 @@ func TestCatalog(t *testing.T) {
tasks, err = catalog.ListBroadcastTask(context.Background())
assert.Error(t, err)
assert.Nil(t, tasks)
kv.EXPECT().MultiSave(mock.Anything, mock.Anything).Unset()
kv.EXPECT().MultiSave(mock.Anything, mock.Anything).Return(errors.New("save error"))
kv.EXPECT().Save(mock.Anything, mock.Anything, mock.Anything).Unset()
kv.EXPECT().Save(mock.Anything, mock.Anything, mock.Anything).Return(errors.New("save error"))
err = catalog.SavePChannels(context.Background(), []*streamingpb.PChannelMeta{{
Channel: &streamingpb.PChannelInfo{Name: "test", Term: 1},
Node: &streamingpb.StreamingNodeInfo{ServerId: 1},
}})
assert.Error(t, err)
err = catalog.SaveBroadcastTask(context.Background(), 1, &streamingpb.BroadcastTask{})
assert.Error(t, err)
}
func TestCatalog_ReplicationCatalog(t *testing.T) {
@ -255,4 +244,11 @@ func TestCatalog_ReplicationCatalog(t *testing.T) {
assert.Equal(t, infos[0].GetSourceChannelName(), "source-channel-2")
assert.Equal(t, infos[0].GetTargetChannelName(), "target-channel-2")
assert.Equal(t, infos[0].GetTargetCluster().GetClusterId(), "target-cluster")
kv.EXPECT().Load(mock.Anything, mock.Anything).Unset()
kv.EXPECT().Load(mock.Anything, mock.Anything).Return("", merr.ErrIoKeyNotFound)
cfg, err = catalog.GetReplicateConfiguration(context.Background())
assert.NoError(t, err)
assert.Nil(t, cfg)
}

View File

@ -11,6 +11,8 @@ import (
mock "github.com/stretchr/testify/mock"
replicateutil "github.com/milvus-io/milvus/pkg/v2/util/replicateutil"
streamingpb "github.com/milvus-io/milvus/pkg/v2/proto/streamingpb"
syncutil "github.com/milvus-io/milvus/pkg/v2/util/syncutil"
@ -315,6 +317,51 @@ func (_c *MockBalancer_RegisterStreamingEnabledNotifier_Call) RunAndReturn(run f
return _c
}
// ReplicateRole provides a mock function with no fields
func (_m *MockBalancer) ReplicateRole() replicateutil.Role {
ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for ReplicateRole")
}
var r0 replicateutil.Role
if rf, ok := ret.Get(0).(func() replicateutil.Role); ok {
r0 = rf()
} else {
r0 = ret.Get(0).(replicateutil.Role)
}
return r0
}
// MockBalancer_ReplicateRole_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ReplicateRole'
type MockBalancer_ReplicateRole_Call struct {
*mock.Call
}
// ReplicateRole is a helper method to define mock.On call
func (_e *MockBalancer_Expecter) ReplicateRole() *MockBalancer_ReplicateRole_Call {
return &MockBalancer_ReplicateRole_Call{Call: _e.mock.On("ReplicateRole")}
}
func (_c *MockBalancer_ReplicateRole_Call) Run(run func()) *MockBalancer_ReplicateRole_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockBalancer_ReplicateRole_Call) Return(_a0 replicateutil.Role) *MockBalancer_ReplicateRole_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockBalancer_ReplicateRole_Call) RunAndReturn(run func() replicateutil.Role) *MockBalancer_ReplicateRole_Call {
_c.Call.Return(run)
return _c
}
// Trigger provides a mock function with given fields: ctx
func (_m *MockBalancer) Trigger(ctx context.Context) error {
ret := _m.Called(ctx)

View File

@ -0,0 +1,130 @@
// Code generated by mockery v2.53.3. DO NOT EDIT.
package mock_broadcaster
import (
context "context"
message "github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
mock "github.com/stretchr/testify/mock"
types "github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
)
// MockBroadcastAPI is an autogenerated mock type for the BroadcastAPI type
type MockBroadcastAPI struct {
mock.Mock
}
type MockBroadcastAPI_Expecter struct {
mock *mock.Mock
}
func (_m *MockBroadcastAPI) EXPECT() *MockBroadcastAPI_Expecter {
return &MockBroadcastAPI_Expecter{mock: &_m.Mock}
}
// Broadcast provides a mock function with given fields: ctx, msg
func (_m *MockBroadcastAPI) Broadcast(ctx context.Context, msg message.BroadcastMutableMessage) (*types.BroadcastAppendResult, error) {
ret := _m.Called(ctx, msg)
if len(ret) == 0 {
panic("no return value specified for Broadcast")
}
var r0 *types.BroadcastAppendResult
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, message.BroadcastMutableMessage) (*types.BroadcastAppendResult, error)); ok {
return rf(ctx, msg)
}
if rf, ok := ret.Get(0).(func(context.Context, message.BroadcastMutableMessage) *types.BroadcastAppendResult); ok {
r0 = rf(ctx, msg)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*types.BroadcastAppendResult)
}
}
if rf, ok := ret.Get(1).(func(context.Context, message.BroadcastMutableMessage) error); ok {
r1 = rf(ctx, msg)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockBroadcastAPI_Broadcast_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Broadcast'
type MockBroadcastAPI_Broadcast_Call struct {
*mock.Call
}
// Broadcast is a helper method to define mock.On call
// - ctx context.Context
// - msg message.BroadcastMutableMessage
func (_e *MockBroadcastAPI_Expecter) Broadcast(ctx interface{}, msg interface{}) *MockBroadcastAPI_Broadcast_Call {
return &MockBroadcastAPI_Broadcast_Call{Call: _e.mock.On("Broadcast", ctx, msg)}
}
func (_c *MockBroadcastAPI_Broadcast_Call) Run(run func(ctx context.Context, msg message.BroadcastMutableMessage)) *MockBroadcastAPI_Broadcast_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(message.BroadcastMutableMessage))
})
return _c
}
func (_c *MockBroadcastAPI_Broadcast_Call) Return(_a0 *types.BroadcastAppendResult, _a1 error) *MockBroadcastAPI_Broadcast_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockBroadcastAPI_Broadcast_Call) RunAndReturn(run func(context.Context, message.BroadcastMutableMessage) (*types.BroadcastAppendResult, error)) *MockBroadcastAPI_Broadcast_Call {
_c.Call.Return(run)
return _c
}
// Close provides a mock function with no fields
func (_m *MockBroadcastAPI) Close() {
_m.Called()
}
// MockBroadcastAPI_Close_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Close'
type MockBroadcastAPI_Close_Call struct {
*mock.Call
}
// Close is a helper method to define mock.On call
func (_e *MockBroadcastAPI_Expecter) Close() *MockBroadcastAPI_Close_Call {
return &MockBroadcastAPI_Close_Call{Call: _e.mock.On("Close")}
}
func (_c *MockBroadcastAPI_Close_Call) Run(run func()) *MockBroadcastAPI_Close_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockBroadcastAPI_Close_Call) Return() *MockBroadcastAPI_Close_Call {
_c.Call.Return()
return _c
}
func (_c *MockBroadcastAPI_Close_Call) RunAndReturn(run func()) *MockBroadcastAPI_Close_Call {
_c.Run(run)
return _c
}
// NewMockBroadcastAPI creates a new instance of MockBroadcastAPI. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewMockBroadcastAPI(t interface {
mock.TestingT
Cleanup(func())
}) *MockBroadcastAPI {
mock := &MockBroadcastAPI{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}

View File

@ -5,10 +5,11 @@ package mock_broadcaster
import (
context "context"
message "github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
mock "github.com/stretchr/testify/mock"
broadcaster "github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster"
types "github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
message "github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
mock "github.com/stretchr/testify/mock"
)
// MockBroadcaster is an autogenerated mock type for the Broadcaster type
@ -71,65 +72,6 @@ func (_c *MockBroadcaster_Ack_Call) RunAndReturn(run func(context.Context, messa
return _c
}
// Broadcast provides a mock function with given fields: ctx, msg
func (_m *MockBroadcaster) Broadcast(ctx context.Context, msg message.BroadcastMutableMessage) (*types.BroadcastAppendResult, error) {
ret := _m.Called(ctx, msg)
if len(ret) == 0 {
panic("no return value specified for Broadcast")
}
var r0 *types.BroadcastAppendResult
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, message.BroadcastMutableMessage) (*types.BroadcastAppendResult, error)); ok {
return rf(ctx, msg)
}
if rf, ok := ret.Get(0).(func(context.Context, message.BroadcastMutableMessage) *types.BroadcastAppendResult); ok {
r0 = rf(ctx, msg)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*types.BroadcastAppendResult)
}
}
if rf, ok := ret.Get(1).(func(context.Context, message.BroadcastMutableMessage) error); ok {
r1 = rf(ctx, msg)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockBroadcaster_Broadcast_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Broadcast'
type MockBroadcaster_Broadcast_Call struct {
*mock.Call
}
// Broadcast is a helper method to define mock.On call
// - ctx context.Context
// - msg message.BroadcastMutableMessage
func (_e *MockBroadcaster_Expecter) Broadcast(ctx interface{}, msg interface{}) *MockBroadcaster_Broadcast_Call {
return &MockBroadcaster_Broadcast_Call{Call: _e.mock.On("Broadcast", ctx, msg)}
}
func (_c *MockBroadcaster_Broadcast_Call) Run(run func(ctx context.Context, msg message.BroadcastMutableMessage)) *MockBroadcaster_Broadcast_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(message.BroadcastMutableMessage))
})
return _c
}
func (_c *MockBroadcaster_Broadcast_Call) Return(_a0 *types.BroadcastAppendResult, _a1 error) *MockBroadcaster_Broadcast_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockBroadcaster_Broadcast_Call) RunAndReturn(run func(context.Context, message.BroadcastMutableMessage) (*types.BroadcastAppendResult, error)) *MockBroadcaster_Broadcast_Call {
_c.Call.Return(run)
return _c
}
// Close provides a mock function with no fields
func (_m *MockBroadcaster) Close() {
_m.Called()
@ -210,6 +152,79 @@ func (_c *MockBroadcaster_LegacyAck_Call) RunAndReturn(run func(context.Context,
return _c
}
// WithResourceKeys provides a mock function with given fields: ctx, resourceKeys
func (_m *MockBroadcaster) WithResourceKeys(ctx context.Context, resourceKeys ...message.ResourceKey) (broadcaster.BroadcastAPI, error) {
_va := make([]interface{}, len(resourceKeys))
for _i := range resourceKeys {
_va[_i] = resourceKeys[_i]
}
var _ca []interface{}
_ca = append(_ca, ctx)
_ca = append(_ca, _va...)
ret := _m.Called(_ca...)
if len(ret) == 0 {
panic("no return value specified for WithResourceKeys")
}
var r0 broadcaster.BroadcastAPI
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, ...message.ResourceKey) (broadcaster.BroadcastAPI, error)); ok {
return rf(ctx, resourceKeys...)
}
if rf, ok := ret.Get(0).(func(context.Context, ...message.ResourceKey) broadcaster.BroadcastAPI); ok {
r0 = rf(ctx, resourceKeys...)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(broadcaster.BroadcastAPI)
}
}
if rf, ok := ret.Get(1).(func(context.Context, ...message.ResourceKey) error); ok {
r1 = rf(ctx, resourceKeys...)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockBroadcaster_WithResourceKeys_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'WithResourceKeys'
type MockBroadcaster_WithResourceKeys_Call struct {
*mock.Call
}
// WithResourceKeys is a helper method to define mock.On call
// - ctx context.Context
// - resourceKeys ...message.ResourceKey
func (_e *MockBroadcaster_Expecter) WithResourceKeys(ctx interface{}, resourceKeys ...interface{}) *MockBroadcaster_WithResourceKeys_Call {
return &MockBroadcaster_WithResourceKeys_Call{Call: _e.mock.On("WithResourceKeys",
append([]interface{}{ctx}, resourceKeys...)...)}
}
func (_c *MockBroadcaster_WithResourceKeys_Call) Run(run func(ctx context.Context, resourceKeys ...message.ResourceKey)) *MockBroadcaster_WithResourceKeys_Call {
_c.Call.Run(func(args mock.Arguments) {
variadicArgs := make([]message.ResourceKey, len(args)-1)
for i, a := range args[1:] {
if a != nil {
variadicArgs[i] = a.(message.ResourceKey)
}
}
run(args[0].(context.Context), variadicArgs...)
})
return _c
}
func (_c *MockBroadcaster_WithResourceKeys_Call) Return(_a0 broadcaster.BroadcastAPI, _a1 error) *MockBroadcaster_WithResourceKeys_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockBroadcaster_WithResourceKeys_Call) RunAndReturn(run func(context.Context, ...message.ResourceKey) (broadcaster.BroadcastAPI, error)) *MockBroadcaster_WithResourceKeys_Call {
_c.Call.Return(run)
return _c
}
// NewMockBroadcaster creates a new instance of MockBroadcaster. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewMockBroadcaster(t interface {

View File

@ -7,11 +7,11 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/milvus-io/milvus/internal/coordinator/snmanager"
"github.com/milvus-io/milvus/internal/mocks/streamingcoord/server/mock_balancer"
"github.com/milvus-io/milvus/internal/querycoordv2/meta"
"github.com/milvus-io/milvus/internal/querycoordv2/session"
"github.com/milvus-io/milvus/internal/streamingcoord/server/balancer"
"github.com/milvus-io/milvus/internal/streamingcoord/server/balancer/balance"
"github.com/milvus-io/milvus/pkg/v2/proto/datapb"
"github.com/milvus-io/milvus/pkg/v2/proto/streamingpb"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
@ -50,7 +50,7 @@ func TestAssignChannelToWALLocatedFirst(t *testing.T) {
<-ctx.Done()
return context.Cause(ctx)
})
snmanager.StaticStreamingNodeManager.SetBalancerReady(b)
balance.Register(b)
channels := []*meta.DmChannel{
{VchannelInfo: &datapb.VchannelInfo{ChannelName: "pchannel_v1"}},

View File

@ -34,6 +34,7 @@ import (
"github.com/milvus-io/milvus/internal/querycoordv2/session"
"github.com/milvus-io/milvus/internal/querycoordv2/utils"
"github.com/milvus-io/milvus/internal/streamingcoord/server/balancer"
"github.com/milvus-io/milvus/internal/streamingcoord/server/balancer/balance"
"github.com/milvus-io/milvus/internal/util/streamingutil"
"github.com/milvus-io/milvus/pkg/v2/kv"
"github.com/milvus-io/milvus/pkg/v2/proto/datapb"
@ -242,7 +243,7 @@ func (suite *ReplicaObserverSuite) TestCheckSQnodesInReplica() {
return pchans[0], nil
}
})
snmanager.StaticStreamingNodeManager.SetBalancerReady(b)
balance.Register(b)
suite.observer = NewReplicaObserver(suite.meta, suite.distMgr)
suite.observer.Start()

View File

@ -32,6 +32,7 @@ import (
"github.com/milvus-io/milvus/internal/mocks/streamingcoord/server/mock_balancer"
mockrootcoord "github.com/milvus-io/milvus/internal/rootcoord/mocks"
"github.com/milvus-io/milvus/internal/streamingcoord/server/balancer"
"github.com/milvus-io/milvus/internal/streamingcoord/server/balancer/balance"
mocktso "github.com/milvus-io/milvus/internal/tso/mocks"
"github.com/milvus-io/milvus/internal/util/streamingutil"
"github.com/milvus-io/milvus/pkg/v2/common"
@ -551,14 +552,15 @@ func TestGcPartitionData(t *testing.T) {
snmanager.ResetStreamingNodeManager()
b := mock_balancer.NewMockBalancer(t)
b.EXPECT().WatchChannelAssignments(mock.Anything, mock.Anything).Run(
func(ctx context.Context, cb balancer.WatchChannelAssignmentsCallback) {
b.EXPECT().WatchChannelAssignments(mock.Anything, mock.Anything).RunAndReturn(
func(ctx context.Context, cb balancer.WatchChannelAssignmentsCallback) error {
<-ctx.Done()
return ctx.Err()
})
b.EXPECT().RegisterStreamingEnabledNotifier(mock.Anything).Run(func(notifier *syncutil.AsyncTaskNotifier[struct{}]) {
notifier.Cancel()
})
snmanager.StaticStreamingNodeManager.SetBalancerReady(b)
balance.Register(b)
wal := mock_streaming.NewMockWALAccesser(t)
broadcast := mock_streaming.NewMockBroadcast(t)

View File

@ -0,0 +1,25 @@
package balance
import (
"context"
"github.com/milvus-io/milvus/internal/streamingcoord/server/balancer"
"github.com/milvus-io/milvus/pkg/v2/util/syncutil"
)
var singleton = syncutil.NewFuture[balancer.Balancer]()
func Register(balancer balancer.Balancer) {
singleton.Set(balancer)
}
func GetWithContext(ctx context.Context) (balancer.Balancer, error) {
return singleton.GetWithContext(ctx)
}
func Release() {
if !singleton.Ready() {
return
}
singleton.Get().Close()
}
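
For reference, a hedged sketch of how a caller consumes the new balance singleton, mirroring the call-site changes above; the helper function itself is illustrative and not part of this change:

```go
package example

import (
	"context"

	"github.com/milvus-io/milvus/internal/streamingcoord/server/balancer/balance"
	"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
)

// listStreamingNodes is a hypothetical caller: instead of holding a per-manager
// syncutil.Future, it blocks on the process-wide singleton until streamingcoord
// initialization registers the balancer via balance.Register(...).
func listStreamingNodes(ctx context.Context) (map[int64]*types.StreamingNodeInfo, error) {
	b, err := balance.GetWithContext(ctx)
	if err != nil {
		return nil, err
	}
	return b.GetAllStreamingNodes(ctx)
}
```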

View File

@ -0,0 +1,13 @@
//go:build test
// +build test
package balance
import (
"github.com/milvus-io/milvus/internal/streamingcoord/server/balancer"
"github.com/milvus-io/milvus/pkg/v2/util/syncutil"
)
func ResetBalancer() {
singleton = syncutil.NewFuture[balancer.Balancer]()
}

View File

@ -9,6 +9,7 @@ import (
"github.com/milvus-io/milvus/pkg/v2/proto/streamingpb"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
"github.com/milvus-io/milvus/pkg/v2/util/replicateutil"
"github.com/milvus-io/milvus/pkg/v2/util/syncutil"
)
@ -36,6 +37,9 @@ type Balancer interface {
// UpdateBalancePolicy update the balance policy.
UpdateBalancePolicy(ctx context.Context, req *streamingpb.UpdateWALBalancePolicyRequest) (*streamingpb.UpdateWALBalancePolicyResponse, error)
// ReplicateRole returns the replicate role of the balancer.
ReplicateRole() replicateutil.Role
// RegisterStreamingEnabledNotifier registers a notifier into the balancer.
// If the error is returned, the balancer is closed.
// Otherwise, the following rules are applied:

View File

@ -19,6 +19,7 @@ import (
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
"github.com/milvus-io/milvus/pkg/v2/util/contextutil"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
"github.com/milvus-io/milvus/pkg/v2/util/replicateutil"
"github.com/milvus-io/milvus/pkg/v2/util/syncutil"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
@ -93,6 +94,11 @@ func (b *balancerImpl) GetLatestChannelAssignment() (*WatchChannelAssignmentsCal
return b.channelMetaManager.GetLatestChannelAssignment()
}
// ReplicateRole returns the replicate role of the balancer.
func (b *balancerImpl) ReplicateRole() replicateutil.Role {
return b.channelMetaManager.ReplicateRole()
}
// GetAllStreamingNodes fetches all streaming node info.
func (b *balancerImpl) GetAllStreamingNodes(ctx context.Context) (map[int64]*types.StreamingNodeInfo, error) {
return resource.Resource().StreamingNodeManagerClient().GetAllStreamingNodes(ctx)

View File

@ -171,6 +171,17 @@ func (cm *ChannelManager) IsStreamingEnabledOnce() bool {
return cm.streamingVersion != nil
}
// ReplicateRole returns the replicate role of the channel manager.
func (cm *ChannelManager) ReplicateRole() replicateutil.Role {
cm.cond.L.Lock()
defer cm.cond.L.Unlock()
if cm.replicateConfig == nil {
return replicateutil.RolePrimary
}
return cm.replicateConfig.GetCurrentCluster().Role()
}
// TriggerWatchUpdate triggers the watch update.
// Because current watch must see new incoming streaming node right away,
// so a watch updating trigger will be called if there's new incoming streaming node.

View File

@ -0,0 +1,194 @@
package broadcaster
import (
"context"
"sort"
"time"
"github.com/cenkalti/backoff/v4"
"go.uber.org/zap"
"github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster/registry"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/proto/streamingpb"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/util/syncutil"
)
// newAckCallbackScheduler creates a new ack callback scheduler.
func newAckCallbackScheduler(logger *log.MLogger) *ackCallbackScheduler {
s := &ackCallbackScheduler{
notifier: syncutil.NewAsyncTaskNotifier[struct{}](),
pending: make(chan *broadcastTask, 16),
triggerChan: make(chan struct{}, 1),
rkLocker: newResourceKeyLocker(newBroadcasterMetrics()),
tombstoneScheduler: newTombstoneScheduler(logger),
}
s.SetLogger(logger)
return s
}
type ackCallbackScheduler struct {
log.Binder
notifier *syncutil.AsyncTaskNotifier[struct{}]
pending chan *broadcastTask
triggerChan chan struct{}
tombstoneScheduler *tombstoneScheduler
pendingAckedTasks []*broadcastTask // should already be sorted by broadcastID
// For tasks holding conflicting resource keys (protected by the resource-key lock),
// the broadcastID is always increasing:
// the task with the smaller broadcastID happens before the task with the larger broadcastID.
// Meanwhile, the timetick order on any vchannel shared by two such tasks matches the order of their broadcastIDs,
// so the task with the smaller broadcastID is always acked before the one with the larger broadcastID.
// Therefore we can execute the tasks in broadcastID order to guarantee that the ack order matches the WAL order.
rkLocker *resourceKeyLocker // used to lock the resource keys of the ack operation.
// It is not the same instance as the resourceKeyLocker in the broadcastTaskManager,
// because it is only used to check whether a resource key is locked when acking.
// For a primary milvus cluster it makes no difference, since the execution order is already protected by the broadcastTaskManager.
// But for a secondary milvus cluster, this rkLocker is needed to protect the resource keys when acking, so the execution order is not broken.
}
// Initialize initializes the ack scheduler with a list of broadcast tasks.
func (s *ackCallbackScheduler) Initialize(tasks []*broadcastTask, tombstoneIDs []uint64, bm *broadcastTaskManager) {
// when initializing, the tasks in recovery info may be out of order, so we need to sort them by the broadcastID.
sortByBroadcastID(tasks)
s.tombstoneScheduler.Initialize(bm, tombstoneIDs)
s.pendingAckedTasks = tasks
go s.background()
}
// AddTask adds a new broadcast task into the ack scheduler.
func (s *ackCallbackScheduler) AddTask(task *broadcastTask) {
select {
case <-s.notifier.Context().Done():
panic("unreachable: ack scheduler is closing when adding new task")
case s.pending <- task:
}
}
// Close closes the ack scheduler.
func (s *ackCallbackScheduler) Close() {
s.notifier.Cancel()
s.notifier.BlockUntilFinish()
// close the tombstone scheduler after the ack scheduler is closed.
s.tombstoneScheduler.Close()
}
// background is the background task of the ack scheduler.
func (s *ackCallbackScheduler) background() {
defer func() {
s.notifier.Finish(struct{}{})
s.Logger().Info("ack scheduler background exit")
}()
s.Logger().Info("ack scheduler background start")
for {
s.triggerAckCallback()
select {
case <-s.notifier.Context().Done():
return
case task := <-s.pending:
s.addBroadcastTask(task)
case <-s.triggerChan:
}
}
}
// addBroadcastTask adds a broadcast task into the pending acked tasks.
func (s *ackCallbackScheduler) addBroadcastTask(task *broadcastTask) error {
s.pendingAckedTasks = append(s.pendingAckedTasks, task)
sortByBroadcastID(s.pendingAckedTasks) // This is a redundant operation:
// at runtime, tasks with conflicting resource keys already arrive in broadcastID order.
return nil
}
// triggerAckCallback triggers the ack callback.
func (s *ackCallbackScheduler) triggerAckCallback() {
pendingTasks := make([]*broadcastTask, 0, len(s.pendingAckedTasks))
for _, task := range s.pendingAckedTasks {
if task.State() != streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_PENDING &&
task.State() != streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_WAIT_ACK &&
task.State() != streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_REPLICATED {
s.Logger().Info("task cannot be acked, skip the ack callback", zap.Uint64("broadcastID", task.Header().BroadcastID))
continue
}
g, err := s.rkLocker.FastLock(task.Header().ResourceKeys.Collect()...)
if err != nil {
s.Logger().Warn("lock is occupied, delay the ack callback", zap.Uint64("broadcastID", task.Header().BroadcastID), zap.Error(err))
pendingTasks = append(pendingTasks, task)
continue
}
// Execute the ack callback in background.
go s.doAckCallback(task, g)
}
s.pendingAckedTasks = pendingTasks
}
// doAckCallback executes the ack callback.
func (s *ackCallbackScheduler) doAckCallback(bt *broadcastTask, g *lockGuards) (err error) {
defer func() {
g.Unlock()
s.triggerChan <- struct{}{}
if err == nil {
s.Logger().Info("execute ack callback done", zap.Uint64("broadcastID", bt.Header().BroadcastID))
} else {
s.Logger().Warn("execute ack callback failed", zap.Uint64("broadcastID", bt.Header().BroadcastID), zap.Error(err))
}
}()
s.Logger().Info("start to execute ack callback", zap.Uint64("broadcastID", bt.Header().BroadcastID))
msg, result := bt.BroadcastResult()
makeMap := make(map[string]*message.AppendResult, len(result))
for vchannel, result := range result {
makeMap[vchannel] = &message.AppendResult{
MessageID: result.MessageID,
LastConfirmedMessageID: result.LastConfirmedMessageID,
TimeTick: result.TimeTick,
}
}
// call the ack callback until done.
if err := s.callMessageAckCallbackUntilDone(s.notifier.Context(), msg, makeMap); err != nil {
return err
}
if err := bt.MarkAckCallbackDone(s.notifier.Context()); err != nil {
// The catalog write is reliable, so we can mark the ack callback done without retrying.
return err
}
s.tombstoneScheduler.AddPending(bt.Header().BroadcastID)
return nil
}
// callMessageAckCallbackUntilDone calls the message ack callback until done.
func (s *ackCallbackScheduler) callMessageAckCallbackUntilDone(ctx context.Context, msg message.BroadcastMutableMessage, result map[string]*message.AppendResult) error {
backoff := backoff.NewExponentialBackOff()
backoff.InitialInterval = 10 * time.Millisecond
backoff.MaxInterval = 10 * time.Second
backoff.MaxElapsedTime = 0
backoff.Reset()
for {
err := registry.CallMessageAckCallback(ctx, msg, result)
if err == nil {
return nil
}
nextInterval := backoff.NextBackOff()
s.Logger().Warn("failed to call message ack callback, wait for retry...",
log.FieldMessage(msg),
zap.Duration("nextInterval", nextInterval),
zap.Error(err))
select {
case <-ctx.Done():
return ctx.Err()
case <-time.After(nextInterval):
}
}
}
func sortByBroadcastID(tasks []*broadcastTask) {
sort.Slice(tasks, func(i, j int) bool {
return tasks[i].Header().BroadcastID < tasks[j].Header().BroadcastID
})
}

View File

@ -0,0 +1,38 @@
package broadcast
import (
"context"
"github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/util/syncutil"
)
var singleton = syncutil.NewFuture[broadcaster.Broadcaster]()
// Register registers the broadcaster.
func Register(broadcaster broadcaster.Broadcaster) {
singleton.Set(broadcaster)
}
// GetWithContext gets the broadcaster with context.
func GetWithContext(ctx context.Context) (broadcaster.Broadcaster, error) {
return singleton.GetWithContext(ctx)
}
// StartBroadcastWithResourceKeys starts a broadcast with resource keys.
func StartBroadcastWithResourceKeys(ctx context.Context, resourceKeys ...message.ResourceKey) (broadcaster.BroadcastAPI, error) {
broadcaster, err := singleton.GetWithContext(ctx)
if err != nil {
return nil, err
}
return broadcaster.WithResourceKeys(ctx, resourceKeys...)
}
// Release releases the broadcaster.
func Release() {
if !singleton.Ready() {
return
}
singleton.Get().Close()
}

View File

@ -0,0 +1,13 @@
//go:build test
// +build test
package broadcast
import (
"github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster"
"github.com/milvus-io/milvus/pkg/v2/util/syncutil"
)
func ResetBroadcaster() {
singleton = syncutil.NewFuture[broadcaster.Broadcaster]()
}

View File

@ -2,89 +2,180 @@ package broadcaster
import (
"context"
"fmt"
"sync"
"github.com/cockroachdb/errors"
"go.uber.org/zap"
"github.com/milvus-io/milvus/internal/streamingcoord/server/balancer/balance"
"github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster/registry"
"github.com/milvus-io/milvus/internal/streamingcoord/server/resource"
"github.com/milvus-io/milvus/internal/util/streamingutil/status"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/proto/messagespb"
"github.com/milvus-io/milvus/pkg/v2/proto/streamingpb"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/util/syncutil"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
"github.com/milvus-io/milvus/pkg/v2/util/replicateutil"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
// RecoverBroadcaster recovers the broadcaster from the recovery info.
func RecoverBroadcaster(ctx context.Context) (Broadcaster, error) {
tasks, err := resource.Resource().StreamingCatalog().ListBroadcastTask(ctx)
if err != nil {
return nil, err
}
return newBroadcastTaskManager(tasks), nil
}
// newBroadcastTaskManager creates a new broadcast task manager with recovery info.
func newBroadcastTaskManager(protos []*streamingpb.BroadcastTask) (*broadcastTaskManager, []*pendingBroadcastTask) {
// return the manager, the pending broadcast tasks and the pending ack callback tasks.
func newBroadcastTaskManager(protos []*streamingpb.BroadcastTask) *broadcastTaskManager {
logger := resource.Resource().Logger().With(log.FieldComponent("broadcaster"))
metrics := newBroadcasterMetrics()
rkLocker := newResourceKeyLocker(metrics)
ackScheduler := newAckCallbackScheduler(logger)
recoveryTasks := make([]*broadcastTask, 0, len(protos))
for _, proto := range protos {
t := newBroadcastTaskFromProto(proto, metrics)
t.SetLogger(logger.With(zap.Uint64("broadcastID", t.header.BroadcastID)))
t := newBroadcastTaskFromProto(proto, metrics, ackScheduler)
t.SetLogger(logger)
recoveryTasks = append(recoveryTasks, t)
}
rks := make(map[message.ResourceKey]uint64, len(recoveryTasks))
tasks := make(map[uint64]*broadcastTask, len(recoveryTasks))
pendingTasks := make([]*pendingBroadcastTask, 0, len(recoveryTasks))
pendingAckCallbackTasks := make([]*broadcastTask, 0, len(recoveryTasks))
tombstoneIDs := make([]uint64, 0, len(recoveryTasks))
for _, task := range recoveryTasks {
for rk := range task.header.ResourceKeys {
if oldTaskID, ok := rks[rk]; ok {
panic(fmt.Sprintf("unreachable: dirty recovery info in metastore, broadcast ids: [%d, %d]", oldTaskID, task.header.BroadcastID))
switch task.task.State {
case streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_PENDING, streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_WAIT_ACK:
guards, err := rkLocker.FastLock(task.Header().ResourceKeys.Collect()...)
if err != nil {
panic(err)
}
rks[rk] = task.header.BroadcastID
metrics.IncomingResourceKey(rk.Domain)
task.WithResourceKeyLockGuards(guards)
if newPending := newPendingBroadcastTask(task); newPending != nil {
// if there are pending messages that have not been appended yet, continue appending them.
pendingTasks = append(pendingTasks, newPending)
} else {
// if there are no pending messages, add the task to the pending ack callback tasks.
pendingAckCallbackTasks = append(pendingAckCallbackTasks, task)
}
tasks[task.header.BroadcastID] = task
if task.task.State == streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_PENDING {
// only the task is pending need to be reexecuted.
pendingTasks = append(pendingTasks, newPendingBroadcastTask(task))
case streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_REPLICATED:
// The task is recovered from the remote cluster, so it doesn't hold the resource lock.
// but the task execution order should be protected by the order of broadcastID (by ackCallbackScheduler)
if isAllDone(task.task) {
pendingAckCallbackTasks = append(pendingAckCallbackTasks, task)
}
case streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_TOMBSTONE:
tombstoneIDs = append(tombstoneIDs, task.Header().BroadcastID)
}
tasks[task.Header().BroadcastID] = task
}
m := &broadcastTaskManager{
Binder: log.Binder{},
cond: syncutil.NewContextCond(&sync.Mutex{}),
lifetime: typeutil.NewLifetime(),
mu: &sync.Mutex{},
tasks: tasks,
resourceKeys: rks,
resourceKeyLocker: rkLocker,
metrics: metrics,
broadcastScheduler: newBroadcasterScheduler(pendingTasks, logger),
ackScheduler: ackScheduler,
}
// add the pending ack callback tasks into the ack scheduler.
ackScheduler.Initialize(pendingAckCallbackTasks, tombstoneIDs, m)
m.SetLogger(logger)
return m, pendingTasks
return m
}
// broadcastTaskManager is the manager of the broadcast task.
type broadcastTaskManager struct {
log.Binder
cond *syncutil.ContextCond
lifetime *typeutil.Lifetime
mu *sync.Mutex
tasks map[uint64]*broadcastTask // map the broadcastID to the broadcastTaskState
resourceKeys map[message.ResourceKey]uint64 // map the resource key to the broadcastID
tombstoneTasks []uint64 // the broadcastID of the tombstone tasks
resourceKeyLocker *resourceKeyLocker
metrics *broadcasterMetrics
broadcastScheduler *broadcasterScheduler // the scheduler of the broadcast task
ackScheduler *ackCallbackScheduler // the scheduler of the ack task
}
// AddTask adds a new broadcast task into the manager.
func (bm *broadcastTaskManager) AddTask(ctx context.Context, msg message.BroadcastMutableMessage) (*pendingBroadcastTask, error) {
var err error
if msg, err = bm.assignID(ctx, msg); err != nil {
return nil, err
}
task, err := bm.addBroadcastTask(ctx, msg)
if err != nil {
return nil, err
}
return newPendingBroadcastTask(task), nil
}
// assignID assigns the broadcast id to the message.
func (bm *broadcastTaskManager) assignID(ctx context.Context, msg message.BroadcastMutableMessage) (message.BroadcastMutableMessage, error) {
// WithResourceKeys acquires the resource keys for the broadcast task.
func (bm *broadcastTaskManager) WithResourceKeys(ctx context.Context, resourceKeys ...message.ResourceKey) (BroadcastAPI, error) {
id, err := resource.Resource().IDAllocator().Allocate(ctx)
if err != nil {
return nil, errors.Wrapf(err, "allocate new id failed")
}
msg = msg.WithBroadcastID(id)
return msg, nil
resourceKeys = bm.appendSharedClusterRK(resourceKeys...)
guards, err := bm.resourceKeyLocker.Lock(resourceKeys...)
if err != nil {
return nil, err
}
if err := bm.checkClusterRole(ctx); err != nil {
// unlock the guards if the cluster role is not primary.
guards.Unlock()
return nil, err
}
return &broadcasterWithRK{
broadcaster: bm,
broadcastID: id,
guards: guards,
}, nil
}
// checkClusterRole checks if the cluster status is primary, otherwise return error.
func (bm *broadcastTaskManager) checkClusterRole(ctx context.Context) error {
// Check if the cluster status is primary, otherwise return error.
b, err := balance.GetWithContext(ctx)
if err != nil {
return err
}
if b.ReplicateRole() != replicateutil.RolePrimary {
return status.NewReplicateViolation("cluster is not primary, cannot do any DDL/DCL")
}
return nil
}
// appendSharedClusterRK appends the shared cluster resource key to the resource keys.
// shared cluster resource key is required for all broadcast messages.
func (bm *broadcastTaskManager) appendSharedClusterRK(resourceKeys ...message.ResourceKey) []message.ResourceKey {
for _, rk := range resourceKeys {
if rk.Domain == messagespb.ResourceDomain_ResourceDomainCluster {
return resourceKeys
}
}
return append(resourceKeys, message.NewSharedClusterResourceKey())
}
// broadcast broadcasts the message to all vchannels.
// it will block until the message is broadcasted to all vchannels
func (bm *broadcastTaskManager) broadcast(ctx context.Context, msg message.BroadcastMutableMessage, broadcastID uint64, guards *lockGuards) (*types.BroadcastAppendResult, error) {
if !bm.lifetime.Add(typeutil.LifetimeStateWorking) {
guards.Unlock()
return nil, status.NewOnShutdownError("broadcaster is closing")
}
defer bm.lifetime.Done()
// check if the message is valid to be broadcasted.
// TODO: the message check callback should not be a component of the broadcaster;
// it should be removed after the import operation refactoring.
if err := registry.CallMessageCheckCallback(ctx, msg); err != nil {
guards.Unlock()
return nil, err
}
task := bm.addBroadcastTask(msg, broadcastID, guards)
pendingTask := newPendingBroadcastTask(task)
// Add it into broadcast scheduler to broadcast the message into all vchannels.
return bm.broadcastScheduler.AddTask(ctx, pendingTask)
}
// LegacyAck is the legacy ack function for the broadcast task.
@ -105,72 +196,90 @@ func (bm *broadcastTaskManager) LegacyAck(ctx context.Context, broadcastID uint6
// Ack acknowledges the message at the specified vchannel.
func (bm *broadcastTaskManager) Ack(ctx context.Context, msg message.ImmutableMessage) error {
if err := registry.CallMessageAckCallback(ctx, msg); err != nil {
bm.Logger().Warn("message ack callback failed", log.FieldMessage(msg), zap.Error(err))
return err
if !bm.lifetime.Add(typeutil.LifetimeStateWorking) {
return status.NewOnShutdownError("broadcaster is closing")
}
bm.Logger().Warn("message ack callback success", log.FieldMessage(msg))
defer bm.lifetime.Done()
broadcastID := msg.BroadcastHeader().BroadcastID
vchannel := msg.VChannel()
task, ok := bm.getBroadcastTaskByID(broadcastID)
t, ok := bm.getOrCreateBroadcastTask(msg)
if !ok {
bm.Logger().Warn("broadcast task not found, it may already acked, ignore the request", zap.Uint64("broadcastID", broadcastID), zap.String("vchannel", vchannel))
bm.Logger().Debug(
"task is tombstone, ignored the ack request",
zap.Uint64("broadcastID", msg.BroadcastHeader().BroadcastID),
zap.String("vchannel", msg.VChannel()))
return nil
}
if err := task.Ack(ctx, msg); err != nil {
return t.Ack(ctx, msg)
}
// DropTombstone drops the tombstone task from the manager.
func (bm *broadcastTaskManager) DropTombstone(ctx context.Context, broadcastID uint64) error {
if !bm.lifetime.Add(typeutil.LifetimeStateWorking) {
return status.NewOnShutdownError("broadcaster is closing")
}
defer bm.lifetime.Done()
t, ok := bm.getBroadcastTaskByID(broadcastID)
if !ok {
bm.Logger().Debug("task is not found, ignored the drop tombstone request", zap.Uint64("broadcastID", broadcastID))
return nil
}
if err := t.DropTombstone(ctx); err != nil {
return err
}
if task.State() == streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_DONE {
bm.removeBroadcastTask(broadcastID)
}
return nil
}
// ReleaseResourceKeys releases the resource keys by the broadcastID.
func (bm *broadcastTaskManager) ReleaseResourceKeys(broadcastID uint64) {
bm.cond.LockAndBroadcast()
defer bm.cond.L.Unlock()
// Close closes the broadcast task manager.
func (bm *broadcastTaskManager) Close() {
bm.lifetime.SetState(typeutil.LifetimeStateStopped)
bm.lifetime.Wait()
bm.removeResourceKeys(broadcastID)
bm.broadcastScheduler.Close()
bm.ackScheduler.Close()
}
// addBroadcastTask adds the broadcast task into the manager.
func (bm *broadcastTaskManager) addBroadcastTask(ctx context.Context, msg message.BroadcastMutableMessage) (*broadcastTask, error) {
newIncomingTask := newBroadcastTaskFromBroadcastMessage(msg, bm.metrics)
header := newIncomingTask.Header()
newIncomingTask.SetLogger(bm.Logger().With(zap.Uint64("broadcastID", header.BroadcastID)))
func (bm *broadcastTaskManager) addBroadcastTask(msg message.BroadcastMutableMessage, broadcastID uint64, guards *lockGuards) *broadcastTask {
msg = msg.OverwriteBroadcastHeader(broadcastID, guards.ResourceKeys()...)
newIncomingTask := newBroadcastTaskFromBroadcastMessage(msg, bm.metrics, bm.ackScheduler)
newIncomingTask.SetLogger(bm.Logger())
newIncomingTask.WithResourceKeyLockGuards(guards)
bm.cond.L.Lock()
for bm.checkIfResourceKeyExist(header) {
if err := bm.cond.Wait(ctx); err != nil {
return nil, err
}
bm.mu.Lock()
bm.tasks[broadcastID] = newIncomingTask
bm.mu.Unlock()
return newIncomingTask
}
// setup the resource keys to make resource exclusive held.
for key := range header.ResourceKeys {
bm.resourceKeys[key] = header.BroadcastID
bm.metrics.IncomingResourceKey(key.Domain)
// getOrCreateBroadcastTask returns the task by the broadcastID
// return false if the task is tombstone.
// if the task is not found, it will create a new task.
func (bm *broadcastTaskManager) getOrCreateBroadcastTask(msg message.ImmutableMessage) (*broadcastTask, bool) {
bm.mu.Lock()
defer bm.mu.Unlock()
bh := msg.BroadcastHeader()
t, ok := bm.tasks[msg.BroadcastHeader().BroadcastID]
if ok {
return t, t.State() != streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_TOMBSTONE
}
bm.tasks[header.BroadcastID] = newIncomingTask
bm.cond.L.Unlock()
return newIncomingTask, nil
if msg.ReplicateHeader() == nil {
bm.Logger().Warn("try to recover task from the wal from non-replicate message, ignore it")
return nil, false
}
func (bm *broadcastTaskManager) checkIfResourceKeyExist(header *message.BroadcastHeader) bool {
for key := range header.ResourceKeys {
if _, ok := bm.resourceKeys[key]; ok {
return true
}
}
return false
newBroadcastTask := newBroadcastTaskFromImmutableMessage(msg, bm.metrics, bm.ackScheduler)
newBroadcastTask.SetLogger(bm.Logger())
bm.tasks[bh.BroadcastID] = newBroadcastTask
return newBroadcastTask, true
}
// getBroadcastTaskByID return the task by the broadcastID.
func (bm *broadcastTaskManager) getBroadcastTaskByID(broadcastID uint64) (*broadcastTask, bool) {
bm.cond.L.Lock()
defer bm.cond.L.Unlock()
bm.mu.Lock()
defer bm.mu.Unlock()
t, ok := bm.tasks[broadcastID]
return t, ok
@ -178,22 +287,8 @@ func (bm *broadcastTaskManager) getBroadcastTaskByID(broadcastID uint64) (*broad
// removeBroadcastTask removes the broadcast task by the broadcastID.
func (bm *broadcastTaskManager) removeBroadcastTask(broadcastID uint64) {
bm.cond.LockAndBroadcast()
defer bm.cond.L.Unlock()
bm.mu.Lock()
defer bm.mu.Unlock()
bm.removeResourceKeys(broadcastID)
delete(bm.tasks, broadcastID)
}
// removeResourceKeys removes the resource keys by the broadcastID.
func (bm *broadcastTaskManager) removeResourceKeys(broadcastID uint64) {
task, ok := bm.tasks[broadcastID]
if !ok {
return
}
// remove the related resource keys
for key := range task.header.ResourceKeys {
delete(bm.resourceKeys, key)
bm.metrics.GoneResourceKey(key.Domain)
}
}

View File

@ -7,11 +7,7 @@ import (
"go.uber.org/zap"
"github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster/registry"
"github.com/milvus-io/milvus/internal/streamingcoord/server/resource"
"github.com/milvus-io/milvus/internal/util/streamingutil/status"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
"github.com/milvus-io/milvus/pkg/v2/util/contextutil"
"github.com/milvus-io/milvus/pkg/v2/util/hardware"
@ -20,32 +16,25 @@ import (
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
func RecoverBroadcaster(
ctx context.Context,
) (Broadcaster, error) {
tasks, err := resource.Resource().StreamingCatalog().ListBroadcastTask(ctx)
if err != nil {
return nil, err
}
manager, pendings := newBroadcastTaskManager(tasks)
b := &broadcasterImpl{
manager: manager,
lifetime: typeutil.NewLifetime(),
// newBroadcasterScheduler creates a new broadcaster scheduler.
func newBroadcasterScheduler(pendings []*pendingBroadcastTask, logger *log.MLogger) *broadcasterScheduler {
b := &broadcasterScheduler{
backgroundTaskNotifier: syncutil.NewAsyncTaskNotifier[struct{}](),
pendings: pendings,
backoffs: typeutil.NewHeap[*pendingBroadcastTask](&pendingBroadcastTaskArray{}),
backoffChan: make(chan *pendingBroadcastTask),
pendingChan: make(chan *pendingBroadcastTask),
backoffChan: make(chan *pendingBroadcastTask),
workerChan: make(chan *pendingBroadcastTask),
}
b.SetLogger(logger)
go b.execute()
return b, nil
return b
}
// broadcasterImpl is the implementation of Broadcaster
type broadcasterImpl struct {
manager *broadcastTaskManager
lifetime *typeutil.Lifetime
// broadcasterScheduler is the implementation of Broadcaster
type broadcasterScheduler struct {
log.Binder
backgroundTaskNotifier *syncutil.AsyncTaskNotifier[struct{}]
pendings []*pendingBroadcastTask
backoffs typeutil.Heap[*pendingBroadcastTask]
@ -54,87 +43,33 @@ type broadcasterImpl struct {
workerChan chan *pendingBroadcastTask
}
// Broadcast broadcasts the message to all channels.
func (b *broadcasterImpl) Broadcast(ctx context.Context, msg message.BroadcastMutableMessage) (result *types.BroadcastAppendResult, err error) {
if !b.lifetime.Add(typeutil.LifetimeStateWorking) {
return nil, status.NewOnShutdownError("broadcaster is closing")
}
defer func() {
b.lifetime.Done()
if err != nil {
b.Logger().Warn("broadcast message failed", zap.Error(err))
return
}
}()
// We need to check if the message is valid before adding it to the broadcaster.
// TODO: add resource key lock here to avoid state race condition.
// TODO: add all ddl to check operation here after ddl framework is ready.
if err := registry.CallMessageCheckCallback(ctx, msg); err != nil {
b.Logger().Warn("check message ack callback failed", zap.Error(err))
return nil, err
}
t, err := b.manager.AddTask(ctx, msg)
if err != nil {
return nil, err
}
func (b *broadcasterScheduler) AddTask(ctx context.Context, task *pendingBroadcastTask) (*types.BroadcastAppendResult, error) {
select {
case <-b.backgroundTaskNotifier.Context().Done():
// We can only check the background context but not the request context here,
// because the new incoming task must be delivered to the background task queue;
// otherwise the broadcaster is closing.
return nil, status.NewOnShutdownError("broadcaster is closing")
case b.pendingChan <- t:
panic("unreachable: broadcaster is closing when adding new task")
case b.pendingChan <- task:
}
// Wait on both the request context and the background task context.
ctx, _ = contextutil.MergeContext(ctx, b.backgroundTaskNotifier.Context())
r, err := t.BlockUntilTaskDone(ctx)
// wait for all the vchannels acked.
result, err := task.BlockUntilAllAck(ctx)
if err != nil {
return nil, err
}
// wait for all the vchannels acked.
if err := t.BlockUntilAllAck(ctx); err != nil {
return nil, err
}
return r, nil
return result, nil
}
func (b *broadcasterImpl) LegacyAck(ctx context.Context, broadcastID uint64, vchannel string) error {
if !b.lifetime.Add(typeutil.LifetimeStateWorking) {
return status.NewOnShutdownError("broadcaster is closing")
}
defer b.lifetime.Done()
return b.manager.LegacyAck(ctx, broadcastID, vchannel)
}
// Ack acknowledges the message at the specified vchannel.
func (b *broadcasterImpl) Ack(ctx context.Context, msg message.ImmutableMessage) error {
if !b.lifetime.Add(typeutil.LifetimeStateWorking) {
return status.NewOnShutdownError("broadcaster is closing")
}
defer b.lifetime.Done()
return b.manager.Ack(ctx, msg)
}
func (b *broadcasterImpl) Close() {
b.lifetime.SetState(typeutil.LifetimeStateStopped)
b.lifetime.Wait()
func (b *broadcasterScheduler) Close() {
b.backgroundTaskNotifier.Cancel()
b.backgroundTaskNotifier.BlockUntilFinish()
}
func (b *broadcasterImpl) Logger() *log.MLogger {
return b.manager.Logger()
}
// execute the broadcaster
func (b *broadcasterImpl) execute() {
func (b *broadcasterScheduler) execute() {
workers := int(float64(hardware.GetCPUNum()) * paramtable.Get().StreamingCfg.WALBroadcasterConcurrencyRatio.GetAsFloat())
if workers < 1 {
workers = 1
@ -162,7 +97,7 @@ func (b *broadcasterImpl) execute() {
b.dispatch()
}
func (b *broadcasterImpl) dispatch() {
func (b *broadcasterScheduler) dispatch() {
for {
var workerChan chan *pendingBroadcastTask
var nextTask *pendingBroadcastTask
@ -203,7 +138,7 @@ func (b *broadcasterImpl) dispatch() {
}
}
func (b *broadcasterImpl) worker(no int) {
func (b *broadcasterScheduler) worker(no int) {
logger := b.Logger().With(zap.Int("workerNo", no))
defer func() {
logger.Info("broadcaster worker exit")
@ -222,8 +157,6 @@ func (b *broadcasterImpl) worker(no int) {
case b.backoffChan <- task:
}
}
// All messages of the broadcast task are sent; release the resource keys so that other tasks with the same resource keys can apply their operations.
b.manager.ReleaseResourceKeys(task.Header().BroadcastID)
}
}
}
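The dispatch loop above leans on a standard Go trick: a send on a nil channel blocks forever, so the worker-send case of the select can be switched off whenever no task is ready. A self-contained sketch of that pattern with toy types (not the actual broadcaster code):
package main

import "fmt"

type task struct{ id int }

// dispatch forwards queued tasks to a worker channel; the send case is
// disabled (nil channel) whenever no task is ready, which is the same
// pattern the broadcasterScheduler dispatch loop relies on.
func dispatch(pending <-chan *task, worker chan<- *task, done <-chan struct{}) {
	var next *task
	queue := []*task{}
	for {
		if next == nil && len(queue) > 0 {
			next, queue = queue[0], queue[1:]
		}
		var out chan<- *task // nil disables the send case below
		if next != nil {
			out = worker
		}
		select {
		case <-done:
			return
		case t := <-pending:
			queue = append(queue, t)
		case out <- next:
			next = nil
		}
	}
}

func main() {
	pending := make(chan *task)
	worker := make(chan *task)
	done := make(chan struct{})
	go dispatch(pending, worker, done)
	pending <- &task{id: 1}
	fmt.Println((<-worker).id) // prints 1
	close(done)
}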

View File

@ -12,65 +12,125 @@ import (
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/proto/streamingpb"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
)
// newBroadcastTaskFromProto creates a new broadcast task from the proto.
func newBroadcastTaskFromProto(proto *streamingpb.BroadcastTask, metrics *broadcasterMetrics) *broadcastTask {
func newBroadcastTaskFromProto(proto *streamingpb.BroadcastTask, metrics *broadcasterMetrics, ackCallbackScheduler *ackCallbackScheduler) *broadcastTask {
m := metrics.NewBroadcastTask(proto.GetState())
msg := message.NewBroadcastMutableMessageBeforeAppend(proto.Message.Payload, proto.Message.Properties)
bh := msg.BroadcastHeader()
bt := &broadcastTask{
mu: sync.Mutex{},
header: bh,
msg: msg,
task: proto,
recoverPersisted: true, // the task is recovered from the recovery info, so it's persisted.
dirty: true, // the task is recovered from the recovery info, so it's persisted.
metrics: m,
ackCallbackScheduler: ackCallbackScheduler,
allAcked: make(chan struct{}),
}
if isAllDone(proto) {
if proto.State == streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_TOMBSTONE {
close(bt.allAcked)
}
return bt
}
// newBroadcastTaskFromBroadcastMessage creates a new broadcast task from the broadcast message.
func newBroadcastTaskFromBroadcastMessage(msg message.BroadcastMutableMessage, metrics *broadcasterMetrics) *broadcastTask {
func newBroadcastTaskFromBroadcastMessage(msg message.BroadcastMutableMessage, metrics *broadcasterMetrics, ackCallbackScheduler *ackCallbackScheduler) *broadcastTask {
m := metrics.NewBroadcastTask(streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_PENDING)
header := msg.BroadcastHeader()
bt := &broadcastTask{
Binder: log.Binder{},
mu: sync.Mutex{},
header: header,
msg: msg,
task: &streamingpb.BroadcastTask{
Message: msg.IntoMessageProto(),
State: streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_PENDING,
AckedVchannelBitmap: make([]byte, len(header.VChannels)),
AckedCheckpoints: make([]*streamingpb.AckedCheckpoint, len(header.VChannels)),
},
recoverPersisted: false,
dirty: false,
metrics: m,
ackCallbackScheduler: ackCallbackScheduler,
allAcked: make(chan struct{}),
}
if isAllDone(bt.task) {
if bt.task.State == streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_TOMBSTONE {
close(bt.allAcked)
}
return bt
}
// newBroadcastTaskFromImmutableMessage creates a new broadcast task from the immutable message.
func newBroadcastTaskFromImmutableMessage(msg message.ImmutableMessage, metrics *broadcasterMetrics, ackCallbackScheduler *ackCallbackScheduler) *broadcastTask {
broadcastMsg := msg.IntoBroadcastMutableMessage()
task := newBroadcastTaskFromBroadcastMessage(broadcastMsg, metrics, ackCallbackScheduler)
// if the task is created from the immutable message, it has already been broadcasted, so transfer its state to replicated.
task.task.State = streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_REPLICATED
task.metrics.ToState(streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_REPLICATED)
return task
}
// broadcastTask is the state of the broadcast task.
type broadcastTask struct {
log.Binder
mu sync.Mutex
header *message.BroadcastHeader
msg message.BroadcastMutableMessage
task *streamingpb.BroadcastTask
recoverPersisted bool // a flag to indicate that the task has been persisted into the recovery info and can be recovered.
dirty bool // a flag to indicate that the task has been modified and needs to be saved into the recovery info.
metrics *taskMetricsGuard
allAcked chan struct{}
guards *lockGuards
ackCallbackScheduler *ackCallbackScheduler
}
// SetLogger sets the logger of the broadcast task.
func (b *broadcastTask) SetLogger(logger *log.MLogger) {
b.Binder.SetLogger(logger.With(log.FieldMessage(b.msg)))
}
// WithResourceKeyLockGuards sets the lock guards for the broadcast task.
func (b *broadcastTask) WithResourceKeyLockGuards(guards *lockGuards) {
b.mu.Lock()
defer b.mu.Unlock()
if b.guards != nil {
panic("broadcast task already has lock guards")
}
b.guards = guards
}
// BroadcastResult returns the broadcast result of the broadcast task.
func (b *broadcastTask) BroadcastResult() (message.BroadcastMutableMessage, map[string]*types.AppendResult) {
b.mu.Lock()
defer b.mu.Unlock()
vchannels := b.msg.BroadcastHeader().VChannels
result := make(map[string]*types.AppendResult, len(vchannels))
for idx, vchannel := range vchannels {
if b.task.AckedCheckpoints == nil {
// forward compatible with the old version.
result[vchannel] = &types.AppendResult{
MessageID: nil,
LastConfirmedMessageID: nil,
TimeTick: 0,
}
continue
}
cp := b.task.AckedCheckpoints[idx]
if cp == nil || cp.TimeTick == 0 {
panic("unreachable: BroadcastResult is called before the broadcast task is acked")
}
result[vchannel] = &types.AppendResult{
MessageID: message.MustUnmarshalMessageID(cp.MessageId),
LastConfirmedMessageID: message.MustUnmarshalMessageID(cp.LastConfirmedMessageId),
TimeTick: cp.TimeTick,
}
}
return b.msg, result
}
// Header returns the header of the broadcast task.
func (b *broadcastTask) Header() *message.BroadcastHeader {
// header is an immutable field, no need to lock.
return b.header
return b.msg.BroadcastHeader()
}
// State returns the State of the broadcast task.
@ -92,7 +152,7 @@ func (b *broadcastTask) PendingBroadcastMessages() []message.MutableMessage {
// filter out the vchannels that have been acked.
pendingMessages := make([]message.MutableMessage, 0, len(msgs))
for i, msg := range msgs {
if b.task.AckedVchannelBitmap[i] != 0 {
if b.task.AckedVchannelBitmap[i] != 0 || (b.task.AckedCheckpoints != nil && b.task.AckedCheckpoints[i] != nil) {
continue
}
pendingMessages = append(pendingMessages, msg)
@ -105,84 +165,113 @@ func (b *broadcastTask) InitializeRecovery(ctx context.Context) error {
b.mu.Lock()
defer b.mu.Unlock()
if b.recoverPersisted {
return nil
}
if err := b.saveTask(ctx, b.task, b.Logger()); err != nil {
if err := b.saveTaskIfDirty(ctx, b.Logger()); err != nil {
return err
}
b.recoverPersisted = true
return nil
}
// GetImmutableMessageFromVChannel gets the immutable message from the vchannel.
// If the vchannel is already acked, it returns nil.
func (b *broadcastTask) GetImmutableMessageFromVChannel(vchannel string) message.ImmutableMessage {
b.mu.Lock()
defer b.mu.Unlock()
return b.getImmutableMessageFromVChannel(vchannel, nil)
}
func (b *broadcastTask) getImmutableMessageFromVChannel(vchannel string, result *types.AppendResult) message.ImmutableMessage {
msg := message.NewBroadcastMutableMessageBeforeAppend(b.task.Message.Payload, b.task.Message.Properties)
msgs := msg.SplitIntoMutableMessage()
for _, msg := range msgs {
if msg.VChannel() == vchannel {
// The legacy message doesn't have a timetick, so we need to set it to 0.
return msg.WithTimeTick(0).IntoImmutableMessage(nil)
timetick := uint64(0)
var messageID message.MessageID
var lastConfirmedMessageID message.MessageID
if result != nil {
messageID = result.MessageID
timetick = result.TimeTick
lastConfirmedMessageID = result.LastConfirmedMessageID
}
// The legacy message doesn't have a last confirmed message id/timetick/message id,
// so we just mock an unsafe message here.
if lastConfirmedMessageID == nil {
return msg.WithTimeTick(timetick).WithLastConfirmedUseMessageID().IntoImmutableMessage(messageID)
}
return msg.WithTimeTick(timetick).WithLastConfirmed(lastConfirmedMessageID).IntoImmutableMessage(messageID)
}
}
return nil
}
// Ack acknowledges the message at the specified vchannel.
func (b *broadcastTask) Ack(ctx context.Context, msg message.ImmutableMessage) error {
// Ack acknowledges the messages at the specified vchannels; repeated acks are ignored.
func (b *broadcastTask) Ack(ctx context.Context, msgs ...message.ImmutableMessage) (err error) {
b.mu.Lock()
defer b.mu.Unlock()
task, ok := b.copyAndSetVChannelAcked(msg.VChannel())
if !ok {
return nil
return b.ack(ctx, msgs...)
}
// We should always save the task after it is acked,
// even if the task is marked as done in memory,
// because the task is set as done in memory before the recovery info is saved.
if err := b.saveTask(ctx, task, b.Logger().With(zap.String("ackVChannel", msg.VChannel()))); err != nil {
// ack acknowledges the message at the specified vchannel.
func (b *broadcastTask) ack(ctx context.Context, msgs ...message.ImmutableMessage) (err error) {
b.copyAndSetAckedCheckpoints(msgs...)
if !b.dirty {
return nil
}
if err := b.saveTaskIfDirty(ctx, b.Logger()); err != nil {
return err
}
b.task = task
if isAllDone(task) {
if isAllDone(b.task) {
b.ackCallbackScheduler.AddTask(b)
b.metrics.ObserveAckAll()
close(b.allAcked)
}
return nil
}
// BlockUntilAllAck blocks until all the vchannels are acked.
func (b *broadcastTask) BlockUntilAllAck(ctx context.Context) error {
func (b *broadcastTask) BlockUntilAllAck(ctx context.Context) (*types.BroadcastAppendResult, error) {
select {
case <-ctx.Done():
return ctx.Err()
return nil, ctx.Err()
case <-b.allAcked:
return nil
_, result := b.BroadcastResult()
return &types.BroadcastAppendResult{
BroadcastID: b.Header().BroadcastID,
AppendResults: result,
}, nil
}
}
// copyAndSetVChannelAcked copies the task and sets the vchannel as acked.
// if the vchannel is already acked, it returns nil and false.
func (b *broadcastTask) copyAndSetVChannelAcked(vchannel string) (*streamingpb.BroadcastTask, bool) {
// copyAndSetAckedCheckpoints copies the task and sets the acked checkpoints.
func (b *broadcastTask) copyAndSetAckedCheckpoints(msgs ...message.ImmutableMessage) {
task := proto.Clone(b.task).(*streamingpb.BroadcastTask)
for _, msg := range msgs {
vchannel := msg.VChannel()
idx, err := findIdxOfVChannel(vchannel, b.Header().VChannels)
if err != nil {
panic(err)
}
if task.AckedVchannelBitmap[idx] != 0 {
return nil, false
if len(task.AckedVchannelBitmap) == 0 {
task.AckedVchannelBitmap = make([]byte, len(b.Header().VChannels))
}
if len(task.AckedCheckpoints) == 0 {
task.AckedCheckpoints = make([]*streamingpb.AckedCheckpoint, len(b.Header().VChannels))
}
if cp := task.AckedCheckpoints[idx]; cp != nil && cp.TimeTick != 0 {
// after proto.Clone, cp is never nil, so we also need to check the time tick.
continue
}
// the ack result is dirty, so we need to set the dirty flag to true.
b.dirty = true
task.AckedVchannelBitmap[idx] = 1
if isAllDone(task) {
// All vchannels are acked; mark the task as done even if there are still pending messages in flight.
// The pending messages are repeated send operations and can be ignored.
task.State = streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_DONE
task.AckedCheckpoints[idx] = &streamingpb.AckedCheckpoint{
MessageId: msg.MessageID().IntoProto(),
LastConfirmedMessageId: msg.LastConfirmedMessageID().IntoProto(),
TimeTick: msg.TimeTick(),
}
return task, true
}
// update current task state.
b.task = task
}
// findIdxOfVChannel finds the index of the vchannel in the broadcast task.
@ -195,33 +284,34 @@ func findIdxOfVChannel(vchannel string, vchannels []string) (int, error) {
return -1, errors.Errorf("unreachable: vchannel is %s not found in the broadcast task", vchannel)
}
// BroadcastDone marks the broadcast operation as done.
func (b *broadcastTask) BroadcastDone(ctx context.Context) error {
// FastAck triggers a fast ack operation when the broadcast operation is done.
func (b *broadcastTask) FastAck(ctx context.Context, broadcastResult map[string]*types.AppendResult) error {
// Broadcast operation is done.
b.metrics.ObserveBroadcastDone()
b.mu.Lock()
defer b.mu.Unlock()
task := b.copyAndMarkBroadcastDone()
if err := b.saveTask(ctx, task, b.Logger()); err != nil {
return err
// Normally we need to wait for the streamingnode to ack the message;
// however, once the message has been written into the wal, its result is determined,
// so we can perform a fast ack here to speed up the ack operation.
msgs := make([]message.ImmutableMessage, 0, len(broadcastResult))
for vchannel := range broadcastResult {
msgs = append(msgs, b.getImmutableMessageFromVChannel(vchannel, broadcastResult[vchannel]))
}
b.task = task
b.metrics.ObserveBroadcastDone()
return nil
return b.ack(ctx, msgs...)
}
// copyAndMarkBroadcastDone copies the task and mark the broadcast task as done.
// !!! The ack state of the task should not be removed, because the task acts as a lock hint for the resource keys held by a broadcast operation.
// It can be removed only after the broadcast message is acked by all the vchannels.
func (b *broadcastTask) copyAndMarkBroadcastDone() *streamingpb.BroadcastTask {
task := proto.Clone(b.task).(*streamingpb.BroadcastTask)
if isAllDone(task) {
// If all vchannels are acked, mark the task as done.
task.State = streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_DONE
} else {
// There's no more pending message, mark the task as wait ack.
task.State = streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_WAIT_ACK
}
return task
// DropTombstone drops the tombstone of the broadcast task.
// It will remove the tombstone of the broadcast task in recovery storage.
// After the tombstone is dropped, idempotency and deduplication can no longer be guaranteed.
func (b *broadcastTask) DropTombstone(ctx context.Context) error {
b.mu.Lock()
defer b.mu.Unlock()
b.task.State = streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_DONE
b.dirty = true
return b.saveTaskIfDirty(ctx, b.Logger())
}
// isAllDone checks if all the vchannels are acked.
@ -243,14 +333,44 @@ func ackedCount(task *streamingpb.BroadcastTask) int {
return count
}
// saveTask saves the broadcast task recovery info.
func (b *broadcastTask) saveTask(ctx context.Context, task *streamingpb.BroadcastTask, logger *log.MLogger) error {
logger = logger.With(zap.String("state", task.State.String()), zap.Int("ackedVChannelCount", ackedCount(task)))
if err := resource.Resource().StreamingCatalog().SaveBroadcastTask(ctx, b.header.BroadcastID, task); err != nil {
logger.Warn("save broadcast task failed", zap.Error(err))
// MarkAckCallbackDone marks the ack callback as done.
func (b *broadcastTask) MarkAckCallbackDone(ctx context.Context) error {
b.mu.Lock()
defer b.mu.Unlock()
if b.task.State != streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_TOMBSTONE {
b.task.State = streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_TOMBSTONE
close(b.allAcked)
b.dirty = true
}
if err := b.saveTaskIfDirty(ctx, b.Logger()); err != nil {
return err
}
logger.Info("save broadcast task done")
b.metrics.ToState(task.State)
if b.guards != nil {
// release the resource key lock if done.
// if the broadcast task is recovered from the remote cluster by replication,
// it doesn't hold the resource key lock, so skip it.
b.guards.Unlock()
}
return nil
}
// saveTaskIfDirty saves the broadcast task recovery info if the task is dirty.
func (b *broadcastTask) saveTaskIfDirty(ctx context.Context, logger *log.MLogger) error {
if !b.dirty {
return nil
}
b.dirty = false
logger = logger.With(zap.String("state", b.task.State.String()), zap.Int("ackedVChannelCount", ackedCount(b.task)))
if err := resource.Resource().StreamingCatalog().SaveBroadcastTask(ctx, b.Header().BroadcastID, b.task); err != nil {
logger.Warn("save broadcast task failed", zap.Error(err))
if ctx.Err() != nil {
panic("critical error: the save broadcast task is failed before the context is done")
}
return err
}
b.metrics.ToState(b.task.State)
logger.Info("save broadcast task done")
return nil
}
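To make the checkpoint bookkeeping above easier to follow, here is a stripped-down, dependency-free sketch of the same idea: one slot per vchannel, an ack fills its slot at most once and marks the state dirty, and the task is complete when every slot is filled. The types are simplified stand-ins, not the streamingpb structures.
package sketch

// ackState mirrors the per-vchannel bookkeeping of broadcastTask in a
// simplified form: checkpoints[i] == 0 means vchannel i is not acked yet.
type ackState struct {
	vchannels   []string
	checkpoints []uint64 // acked time ticks, 0 = not acked
}

// ack records an ack for one vchannel. It returns true when the state
// changed (i.e. the caller should persist it), false for repeated acks.
func (s *ackState) ack(vchannel string, timeTick uint64) bool {
	for i, v := range s.vchannels {
		if v != vchannel {
			continue
		}
		if s.checkpoints[i] != 0 {
			return false // already acked, nothing new to persist
		}
		s.checkpoints[i] = timeTick
		return true
	}
	return false // unknown vchannel; the real code treats this as unreachable
}

// allDone reports whether every vchannel has been acked.
func (s *ackState) allDone() bool {
	for _, tt := range s.checkpoints {
		if tt == 0 {
			return false
		}
	}
	return true
}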

View File

@ -8,8 +8,10 @@ import (
)
type Broadcaster interface {
// Broadcast broadcasts the message to all channels.
Broadcast(ctx context.Context, msg message.BroadcastMutableMessage) (*types.BroadcastAppendResult, error)
// WithResourceKeys sets the resource keys of the broadcast operation.
// It acquires locks on the resource keys and returns the broadcast API.
// Once the broadcast API is returned, its Close() method should be called to release the resources safely.
WithResourceKeys(ctx context.Context, resourceKeys ...message.ResourceKey) (BroadcastAPI, error)
// LegacyAck is the legacy ack interface for the 2.6.0 import message.
LegacyAck(ctx context.Context, broadcastID uint64, vchannel string) error
@ -21,6 +23,14 @@ type Broadcaster interface {
Close()
}
type BroadcastAPI interface {
// Broadcast broadcasts the message to all channels.
Broadcast(ctx context.Context, msg message.BroadcastMutableMessage) (*types.BroadcastAppendResult, error)
// Close releases the resource keys that the broadcast API holds.
Close()
}
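A minimal usage sketch of the two-step API above, written as if inside this package with a recovered Broadcaster in hand; the collection name and the helper itself are illustrative and error handling is abbreviated:
// broadcastWithLock shows the expected call sequence: acquire the resource
// key locks, broadcast, and rely on Close to release the keys if Broadcast
// was never reached. Once Broadcast succeeds, the guards are handed over to
// the task and released only after the broadcast is fully acked
// (see MarkAckCallbackDone above).
func broadcastWithLock(ctx context.Context, bc Broadcaster, msg message.BroadcastMutableMessage) (*types.BroadcastAppendResult, error) {
	api, err := bc.WithResourceKeys(ctx, message.NewCollectionNameResourceKey("c1"))
	if err != nil {
		return nil, err
	}
	defer api.Close()
	return api.Broadcast(ctx, msg)
}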
// AppendOperator is used to append messages, there's only two implement of this interface:
// 1. streaming.WAL()
// 2. old msgstream interface [deprecated]

View File

@ -15,6 +15,9 @@ import (
"github.com/milvus-io/milvus/internal/distributed/streaming"
"github.com/milvus-io/milvus/internal/mocks/distributed/mock_streaming"
"github.com/milvus-io/milvus/internal/mocks/mock_metastore"
"github.com/milvus-io/milvus/internal/mocks/streamingcoord/server/mock_balancer"
"github.com/milvus-io/milvus/internal/streamingcoord/server/balancer"
"github.com/milvus-io/milvus/internal/streamingcoord/server/balancer/balance"
"github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster/registry"
"github.com/milvus-io/milvus/internal/streamingcoord/server/resource"
internaltypes "github.com/milvus-io/milvus/internal/types"
@ -26,6 +29,7 @@ import (
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
"github.com/milvus-io/milvus/pkg/v2/streaming/walimpls/impls/walimplstest"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
"github.com/milvus-io/milvus/pkg/v2/util/replicateutil"
"github.com/milvus-io/milvus/pkg/v2/util/syncutil"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
@ -33,6 +37,20 @@ import (
func TestBroadcaster(t *testing.T) {
registry.ResetRegistration()
paramtable.Init()
paramtable.Get().StreamingCfg.WALBroadcasterTombstoneCheckInternal.SwapTempValue("10ms")
paramtable.Get().StreamingCfg.WALBroadcasterTombstoneMaxCount.SwapTempValue("2")
paramtable.Get().StreamingCfg.WALBroadcasterTombstoneMaxLifetime.SwapTempValue("20ms")
mb := mock_balancer.NewMockBalancer(t)
mb.EXPECT().ReplicateRole().Return(replicateutil.RolePrimary)
mb.EXPECT().WatchChannelAssignments(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, cb balancer.WatchChannelAssignmentsCallback) error {
time.Sleep(100 * time.Second)
return nil
})
balance.Register(mb)
registry.RegisterDropCollectionV1AckCallback(func(ctx context.Context, msg message.BroadcastResultDropCollectionMessageV1) error {
return nil
})
meta := mock_metastore.NewMockStreamingCoordCataLog(t)
meta.EXPECT().ListBroadcastTask(mock.Anything).
@ -57,17 +75,16 @@ func TestBroadcaster(t *testing.T) {
createNewBroadcastMsg([]string{"v1", "v2", "v3"},
message.NewCollectionNameResourceKey("c3"),
message.NewCollectionNameResourceKey("c4")).WithBroadcastID(7),
streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_WAIT_ACK,
streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_REPLICATED,
[]byte{0x00, 0x00, 0x00}),
}, nil
}).Times(1)
done := typeutil.NewConcurrentSet[uint64]()
meta.EXPECT().SaveBroadcastTask(mock.Anything, mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, broadcastID uint64, bt *streamingpb.BroadcastTask) error {
// may failure
if rand.Int31n(10) < 3 {
return errors.New("save task failed")
if ctx.Err() != nil {
return ctx.Err()
}
if bt.State == streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_DONE {
if bt.State == streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_TOMBSTONE {
done.Insert(broadcastID)
}
return nil
@ -84,7 +101,7 @@ func TestBroadcaster(t *testing.T) {
assert.NoError(t, err)
assert.NotNil(t, bc)
assert.Eventually(t, func() bool {
return appended.Load() == 9 && len(done.Collect()) == 6 // only one task is done,
return appended.Load() == 9 && len(done.Collect()) == 6
}, 30*time.Second, 10*time.Millisecond)
// only task 7 is not done.
@ -103,14 +120,12 @@ func TestBroadcaster(t *testing.T) {
// Test broadcast here.
broadcastWithSameRK := func() {
var result *types.BroadcastAppendResult
for {
var err error
result, err = bc.Broadcast(context.Background(), createNewBroadcastMsg([]string{"v1", "v2", "v3"}, message.NewCollectionNameResourceKey("c7")))
if err == nil {
break
}
}
b, err := bc.WithResourceKeys(context.Background(), message.NewCollectionNameResourceKey("c7"))
assert.NoError(t, err)
result, err = b.Broadcast(context.Background(), createNewBroadcastMsg([]string{"v1", "v2", "v3"}, message.NewCollectionNameResourceKey("c7")))
assert.Equal(t, len(result.AppendResults), 3)
assert.NoError(t, err)
}
go broadcastWithSameRK()
go broadcastWithSameRK()
@ -119,8 +134,19 @@ func TestBroadcaster(t *testing.T) {
return appended.Load() == 15 && len(done.Collect()) == 9
}, 30*time.Second, 10*time.Millisecond)
// Test close before broadcast
broadcastAPI, err := bc.WithResourceKeys(context.Background(), message.NewExclusiveClusterResourceKey())
assert.NoError(t, err)
broadcastAPI.Close()
broadcastAPI, err = bc.WithResourceKeys(context.Background(), message.NewExclusiveClusterResourceKey())
assert.NoError(t, err)
broadcastAPI.Close()
bc.Close()
_, err = bc.Broadcast(context.Background(), nil)
broadcastAPI, err = bc.WithResourceKeys(context.Background())
assert.NoError(t, err)
_, err = broadcastAPI.Broadcast(context.Background(), nil)
assert.Error(t, err)
err = bc.Ack(context.Background(), mock_message.NewMockImmutableMessage(t))
assert.Error(t, err)
@ -128,13 +154,17 @@ func TestBroadcaster(t *testing.T) {
func ack(t *testing.T, broadcaster Broadcaster, broadcastID uint64, vchannel string) {
for {
msg := mock_message.NewMockImmutableMessage(t)
msg.EXPECT().VChannel().Return(vchannel)
msg.EXPECT().MessageTypeWithVersion().Return(message.MessageTypeTimeTickV1)
msg.EXPECT().BroadcastHeader().Return(&message.BroadcastHeader{
BroadcastID: broadcastID,
})
msg.EXPECT().MarshalLogObject(mock.Anything).Return(nil).Maybe()
msg := message.NewDropCollectionMessageBuilderV1().
WithHeader(&message.DropCollectionMessageHeader{}).
WithBody(&msgpb.DropCollectionRequest{}).
WithBroadcast([]string{vchannel}).
MustBuildBroadcast().
WithBroadcastID(broadcastID).
SplitIntoMutableMessage()[0].
WithTimeTick(100).
WithLastConfirmed(walimplstest.NewTestMessageID(1)).
IntoImmutableMessage(walimplstest.NewTestMessageID(1))
if err := broadcaster.Ack(context.Background(), msg); err == nil {
break
}
@ -215,6 +245,18 @@ func createNewWaitAckBroadcastTaskFromMessage(
bitmap []byte,
) *streamingpb.BroadcastTask {
pb := msg.IntoMessageProto()
acks := make([]*streamingpb.AckedCheckpoint, len(bitmap))
for i := 0; i < len(bitmap); i++ {
if bitmap[i] != 0 {
messageID := walimplstest.NewTestMessageID(int64(i))
lastConfirmedMessageID := walimplstest.NewTestMessageID(int64(i))
acks[i] = &streamingpb.AckedCheckpoint{
MessageId: messageID.IntoProto(),
LastConfirmedMessageId: lastConfirmedMessageID.IntoProto(),
TimeTick: 1,
}
}
}
return &streamingpb.BroadcastTask{
Message: &messagespb.Message{
Payload: pb.Payload,
@ -222,5 +264,6 @@ func createNewWaitAckBroadcastTaskFromMessage(
},
State: state,
AckedVchannelBitmap: bitmap,
AckedCheckpoints: acks,
}
}

View File

@ -0,0 +1,27 @@
package broadcaster
import (
"context"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
)
type broadcasterWithRK struct {
broadcaster *broadcastTaskManager
broadcastID uint64
guards *lockGuards
}
func (b *broadcasterWithRK) Broadcast(ctx context.Context, msg message.BroadcastMutableMessage) (*types.BroadcastAppendResult, error) {
// consume the guards when Broadcast is called so that Close cannot unlock them again.
guards := b.guards
b.guards = nil
return b.broadcaster.broadcast(ctx, msg, b.broadcastID, guards)
}
func (b *broadcasterWithRK) Close() {
if b.guards != nil {
b.guards.Unlock()
}
}

View File

@ -10,22 +10,21 @@ import (
"github.com/milvus-io/milvus/internal/distributed/streaming"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
"github.com/milvus-io/milvus/pkg/v2/util/syncutil"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
var errBroadcastTaskIsNotDone = errors.New("broadcast task is not done")
// newPendingBroadcastTask creates a new pendingBroadcastTask.
func newPendingBroadcastTask(
task *broadcastTask,
) *pendingBroadcastTask {
func newPendingBroadcastTask(task *broadcastTask) *pendingBroadcastTask {
msgs := task.PendingBroadcastMessages()
if len(msgs) == 0 {
return nil
}
return &pendingBroadcastTask{
broadcastTask: task,
pendingMessages: msgs,
appendResult: make(map[string]*types.AppendResult, len(msgs)),
future: syncutil.NewFuture[*types.BroadcastAppendResult](),
BackoffWithInstant: typeutil.NewBackoffWithInstant(typeutil.BackoffTimerConfig{
Default: 10 * time.Second,
Backoff: typeutil.BackoffConfig{
@ -42,8 +41,6 @@ type pendingBroadcastTask struct {
*broadcastTask
pendingMessages []message.MutableMessage
appendResult map[string]*types.AppendResult
future *syncutil.Future[*types.BroadcastAppendResult]
metrics *taskMetricsGuard
*typeutil.BackoffWithInstant
}
@ -53,7 +50,6 @@ type pendingBroadcastTask struct {
func (b *pendingBroadcastTask) Execute(ctx context.Context) error {
if err := b.broadcastTask.InitializeRecovery(ctx); err != nil {
b.Logger().Warn("broadcast task initialize recovery failed", zap.Error(err))
b.UpdateInstantWithNextBackOff()
return err
}
@ -70,17 +66,12 @@ func (b *pendingBroadcastTask) Execute(ctx context.Context) error {
b.appendResult[b.pendingMessages[idx].VChannel()] = resp.AppendResult
}
b.pendingMessages = newPendings
if len(newPendings) == 0 {
b.future.Set(&types.BroadcastAppendResult{
BroadcastID: b.header.BroadcastID,
AppendResults: b.appendResult,
})
}
b.Logger().Info("broadcast task make a new broadcast done", zap.Int("backoffRetryMessages", len(b.pendingMessages)))
}
if len(b.pendingMessages) == 0 {
if err := b.broadcastTask.BroadcastDone(ctx); err != nil {
b.UpdateInstantWithNextBackOff()
// trigger a fast ack operation when the broadcast operation is done.
if err := b.broadcastTask.FastAck(ctx, b.appendResult); err != nil {
b.Logger().Warn("broadcast task save task failed", zap.Error(err))
return err
}
return nil
@ -89,11 +80,6 @@ func (b *pendingBroadcastTask) Execute(ctx context.Context) error {
return errBroadcastTaskIsNotDone
}
// BlockUntilTaskDone blocks until the task is done.
func (b *pendingBroadcastTask) BlockUntilTaskDone(ctx context.Context) (*types.BroadcastAppendResult, error) {
return b.future.GetWithContext(ctx)
}
// pendingBroadcastTaskArray is a heap of pendingBroadcastTask.
type pendingBroadcastTaskArray []*pendingBroadcastTask

View File

@ -13,8 +13,8 @@ import (
// MessageAckCallback is the callback function for the message type.
type (
MessageAckCallback[H proto.Message, B proto.Message] = func(ctx context.Context, params message.SpecializedImmutableMessage[H, B]) error
messageInnerAckCallback = func(ctx context.Context, msgs message.ImmutableMessage) error
MessageAckCallback[H proto.Message, B proto.Message] = func(ctx context.Context, result message.BroadcastResult[H, B]) error
messageInnerAckCallback = func(ctx context.Context, msg message.BroadcastMutableMessage, result map[string]*message.AppendResult) error
)
// messageAckCallbacks is the map of message type to the callback function.
@ -31,15 +31,18 @@ func registerMessageAckCallback[H proto.Message, B proto.Message](callback Messa
// only for test; the callback should be registered once and only once
return
}
future.Set(func(ctx context.Context, msgs message.ImmutableMessage) error {
specializedMsg := message.MustAsSpecializedImmutableMessage[H, B](msgs)
return callback(ctx, specializedMsg)
future.Set(func(ctx context.Context, msgs message.BroadcastMutableMessage, result map[string]*message.AppendResult) error {
return callback(ctx, message.BroadcastResult[H, B]{
Message: message.MustAsSpecializedBroadcastMessage[H, B](msgs),
Results: result,
})
})
}
// CallMessageAckCallback calls the callback function for the message type.
func CallMessageAckCallback(ctx context.Context, msg message.ImmutableMessage) error {
callbackFuture, ok := messageAckCallbacks[msg.MessageTypeWithVersion()]
func CallMessageAckCallback(ctx context.Context, msg message.BroadcastMutableMessage, result map[string]*message.AppendResult) error {
version := msg.MessageTypeWithVersion()
callbackFuture, ok := messageAckCallbacks[version]
if !ok {
// No callback needs to be called, return nil
return nil
@ -48,5 +51,5 @@ func CallMessageAckCallback(ctx context.Context, msg message.ImmutableMessage) e
if err != nil {
return errors.Wrap(err, "when waiting callback registered")
}
return callback(ctx, msg)
return callback(ctx, msg, result)
}
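For context, registering an ack callback against the new result-based signature looks roughly like the following. The registrar name mirrors the one used in the broadcaster tests in this commit; the wrapping function, package context, and imports are assumptions for illustration.
func registerAckCallbacks() {
	// The callback now receives the broadcast message together with the
	// per-vchannel append results (message id, last confirmed id, time tick)
	// instead of a single immutable message.
	registry.RegisterDropCollectionV1AckCallback(func(ctx context.Context, result message.BroadcastResultDropCollectionMessageV1) error {
		// run the coord-side DDL commit for the dropped collection here
		return nil
	})
}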

View File

@ -9,7 +9,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/streaming/walimpls/impls/rmq"
"github.com/milvus-io/milvus/pkg/v2/streaming/walimpls/impls/walimplstest"
)
func TestMessageCallbackRegistration(t *testing.T) {
@ -18,7 +18,7 @@ func TestMessageCallbackRegistration(t *testing.T) {
// Test registering a callback
called := false
callback := func(ctx context.Context, msg message.ImmutableDropPartitionMessageV1) error {
callback := func(ctx context.Context, msg message.BroadcastResultDropPartitionMessageV1) error {
called = true
return nil
}
@ -34,13 +34,17 @@ func TestMessageCallbackRegistration(t *testing.T) {
msg := message.NewDropPartitionMessageBuilderV1().
WithHeader(&message.DropPartitionMessageHeader{}).
WithBody(&message.DropPartitionRequest{}).
WithVChannel("v1").
MustBuildMutable().
WithTimeTick(1).
IntoImmutableMessage(rmq.NewRmqID(1))
WithBroadcast([]string{"v1"}).
MustBuildBroadcast()
// Call the callback
err := CallMessageAckCallback(context.Background(), msg)
err := CallMessageAckCallback(context.Background(), msg, map[string]*message.AppendResult{
"v1": {
MessageID: walimplstest.NewTestMessageID(1),
LastConfirmedMessageID: walimplstest.NewTestMessageID(1),
TimeTick: 1,
},
})
assert.NoError(t, err)
assert.True(t, called)
@ -48,7 +52,7 @@ func TestMessageCallbackRegistration(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
defer cancel()
err = CallMessageAckCallback(ctx, msg)
err = CallMessageAckCallback(ctx, msg, nil)
assert.Error(t, err)
assert.True(t, errors.Is(err, context.DeadlineExceeded))
}

View File

@ -0,0 +1,139 @@
package broadcaster
import (
"sort"
"github.com/cockroachdb/errors"
"github.com/milvus-io/milvus/pkg/v2/proto/messagespb"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/util/lock"
)
// errFastLockFailed is returned when a fast lock attempt fails.
var errFastLockFailed = errors.New("fast lock failed")
// newResourceKeyLocker creates a new resource key locker.
func newResourceKeyLocker(metrics *broadcasterMetrics) *resourceKeyLocker {
return &resourceKeyLocker{
inner: lock.NewKeyLock[resourceLockKey](),
}
}
// newResourceLockKey creates a new resource lock key.
func newResourceLockKey(key message.ResourceKey) resourceLockKey {
return resourceLockKey{
Domain: key.Domain,
Key: key.Key,
}
}
// resourceLockKey is the key for the resource lock.
type resourceLockKey struct {
Domain messagespb.ResourceDomain
Key string
}
// resourceKeyLocker is the locker for the resource keys.
// It's a low-performance implementation, but the broadcaster is only used for low-frequency DDL,
// so it's acceptable to use this implementation.
type resourceKeyLocker struct {
inner *lock.KeyLock[resourceLockKey]
}
// lockGuards is the guards for multiple resource keys.
type lockGuards struct {
guards []*lockGuard
}
// ResourceKeys returns the resource keys.
func (l *lockGuards) ResourceKeys() []message.ResourceKey {
keys := make([]message.ResourceKey, 0, len(l.guards))
for _, guard := range l.guards {
keys = append(keys, guard.key)
}
return keys
}
// append appends the guard to the guards.
func (l *lockGuards) append(guard *lockGuard) {
l.guards = append(l.guards, guard)
}
// Unlock unlocks the resource keys.
func (l *lockGuards) Unlock() {
// release the locks in reverse order to avoid deadlock.
for i := len(l.guards) - 1; i >= 0; i-- {
l.guards[i].Unlock()
}
l.guards = nil
}
// lockGuard is the guard for the resource key.
type lockGuard struct {
locker *resourceKeyLocker
key message.ResourceKey
}
// Unlock unlocks the resource key.
func (l *lockGuard) Unlock() {
l.locker.unlockWithKey(l.key)
}
// FastLock locks the resource keys without waiting.
// It returns an error if any resource key is already locked.
func (r *resourceKeyLocker) FastLock(keys ...message.ResourceKey) (*lockGuards, error) {
sortResourceKeys(keys)
g := &lockGuards{}
for _, key := range keys {
var locked bool
if key.Shared {
locked = r.inner.TryRLock(newResourceLockKey(key))
} else {
locked = r.inner.TryLock(newResourceLockKey(key))
}
if locked {
g.append(&lockGuard{locker: r, key: key})
continue
}
g.Unlock()
return nil, errors.Wrapf(errFastLockFailed, "fast lock failed at resource key %s", key.String())
}
return g, nil
}
// Lock locks the resource keys.
func (r *resourceKeyLocker) Lock(keys ...message.ResourceKey) (*lockGuards, error) {
// lock the keys in order to avoid deadlock.
sortResourceKeys(keys)
g := &lockGuards{}
for _, key := range keys {
if key.Shared {
r.inner.RLock(newResourceLockKey(key))
} else {
r.inner.Lock(newResourceLockKey(key))
}
g.append(&lockGuard{locker: r, key: key})
}
return g, nil
}
// unlockWithKey unlocks the resource key.
func (r *resourceKeyLocker) unlockWithKey(key message.ResourceKey) {
if key.Shared {
r.inner.RUnlock(newResourceLockKey(key))
return
}
r.inner.Unlock(newResourceLockKey(key))
}
// sortResourceKeys sorts the resource keys.
func sortResourceKeys(keys []message.ResourceKey) {
sort.Slice(keys, func(i, j int) bool {
if keys[i].Domain != keys[j].Domain {
return keys[i].Domain < keys[j].Domain
}
return keys[i].Key < keys[j].Key
})
}
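To make the intended usage concrete, a short sketch written as if inside the broadcaster package, using only the helpers defined above and the key constructors that appear in the accompanying tests; the helper function itself is illustrative:
// lockerUsageSketch shows the expected call pattern: keys are acquired
// together (sorted internally to avoid deadlock), the returned guards
// release everything, and FastLock fails instead of blocking.
func lockerUsageSketch() {
	locker := newResourceKeyLocker(newBroadcasterMetrics())

	guards, _ := locker.Lock(
		message.NewCollectionNameResourceKey("c1"),
		message.NewSharedDBNameResourceKey("db1"),
	)

	// While "c1" is held, a FastLock on the same key is expected to return
	// an error wrapping errFastLockFailed rather than wait (see the fast
	// lock test below).
	if _, err := locker.FastLock(message.NewCollectionNameResourceKey("c1")); err == nil {
		panic("expected fast lock to fail while the key is held")
	}

	// Unlock releases the guards in reverse acquisition order.
	guards.Unlock()
}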

View File

@ -0,0 +1,147 @@
package broadcaster
import (
"fmt"
"math/rand"
"testing"
"time"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
)
func TestResourceKeyLocker(t *testing.T) {
t.Run("concurrent lock/unlock", func(t *testing.T) {
locker := newResourceKeyLocker(newBroadcasterMetrics())
const numGoroutines = 10
const numKeys = 5
const numIterations = 100
// Create a set of test keys
keys := make([]message.ResourceKey, numKeys*2)
for i := 0; i < numKeys; i++ {
keys[i] = message.NewExclusiveCollectionNameResourceKey("test", fmt.Sprintf("test_collection_%d", i))
keys[i+numKeys] = message.NewSharedDBNameResourceKey("test")
}
rand.Shuffle(len(keys), func(i, j int) {
keys[i], keys[j] = keys[j], keys[i]
})
// Start multiple goroutines trying to lock/unlock the same keys
done := make(chan bool)
for i := 0; i < numGoroutines; i++ {
go func(id uint64) {
for j := 0; j < numIterations; j++ {
// Try to lock random subset of keys
right := rand.Intn(numKeys)
left := 0
if right > 0 {
left = rand.Intn(right)
}
keysToLock := make([]message.ResourceKey, right-left)
for i := left; i < right; i++ {
keysToLock[i-left] = keys[i]
}
rand.Shuffle(len(keysToLock), func(i, j int) {
keysToLock[i], keysToLock[j] = keysToLock[j], keysToLock[i]
})
n := rand.Intn(10)
if n < 3 {
// Lock the keys
guards, err := locker.Lock(keysToLock...)
if err != nil {
t.Errorf("Failed to lock keys: %v", err)
return
}
// Hold lock briefly
time.Sleep(time.Millisecond)
// Unlock the keys
guards.Unlock()
} else {
guards, err := locker.Lock(keysToLock...)
if err == nil {
guards.Unlock()
}
}
}
done <- true
}(uint64(i))
}
// Wait for all goroutines to complete
for i := 0; i < numGoroutines; i++ {
<-done
}
})
t.Run("deadlock prevention", func(t *testing.T) {
locker := newResourceKeyLocker(newBroadcasterMetrics())
key1 := message.NewCollectionNameResourceKey("test_collection_1")
key2 := message.NewCollectionNameResourceKey("test_collection_2")
// Create two goroutines that try to lock resources in different orders
done := make(chan bool)
go func() {
for i := 0; i < 100; i++ {
// Lock key1 then key2
guards, err := locker.Lock(key1, key2)
if err != nil {
t.Errorf("Failed to lock keys in order 1->2: %v", err)
return
}
time.Sleep(time.Millisecond)
guards.Unlock()
}
done <- true
}()
go func() {
for i := 0; i < 100; i++ {
// Lock key2 then key1
guards, err := locker.Lock(key2, key1)
if err != nil {
t.Errorf("Failed to lock keys in order 2->1: %v", err)
return
}
time.Sleep(time.Millisecond)
guards.Unlock()
}
done <- true
}()
// Wait for both goroutines with timeout
for i := 0; i < 2; i++ {
select {
case <-done:
// Goroutine completed successfully
case <-time.After(5 * time.Second):
t.Fatal("Deadlock detected - goroutines did not complete in time")
}
}
})
t.Run("fast lock", func(t *testing.T) {
locker := newResourceKeyLocker(newBroadcasterMetrics())
key := message.NewCollectionNameResourceKey("test_collection")
// First fast lock should succeed
guards1, err := locker.FastLock(key)
if err != nil {
t.Fatalf("First FastLock failed: %v", err)
}
// Second fast lock should fail
_, err = locker.FastLock(key)
if err == nil {
t.Fatal("Second FastLock should have failed")
}
// After unlock, fast lock should succeed again
guards1.Unlock()
guards2, err := locker.FastLock(key)
if err != nil {
t.Fatalf("FastLock after unlock failed: %v", err)
}
guards2.Unlock()
})
}

View File

@ -0,0 +1,124 @@
package broadcaster
import (
"sort"
"time"
"go.uber.org/zap"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
"github.com/milvus-io/milvus/pkg/v2/util/syncutil"
)
// tombstoneItem is a tombstone item with an expiration time.
type tombstoneItem struct {
broadcastID uint64
createTime time.Time // the time when the tombstone was created; on recovery, createTime is reset to the current time, which is acceptable.
}
// tombstoneScheduler is a scheduler for the tombstone.
type tombstoneScheduler struct {
log.Binder
notifier *syncutil.AsyncTaskNotifier[struct{}]
pending chan uint64
bm *broadcastTaskManager
tombstones []tombstoneItem
}
// newTombstoneScheduler creates a new tombstone scheduler.
func newTombstoneScheduler(logger *log.MLogger) *tombstoneScheduler {
ts := &tombstoneScheduler{
notifier: syncutil.NewAsyncTaskNotifier[struct{}](),
pending: make(chan uint64),
}
ts.SetLogger(logger)
return ts
}
// Initialize initializes the tombstone scheduler.
func (s *tombstoneScheduler) Initialize(bm *broadcastTaskManager, tombstoneBroadcastIDs []uint64) {
sort.Slice(tombstoneBroadcastIDs, func(i, j int) bool {
return tombstoneBroadcastIDs[i] < tombstoneBroadcastIDs[j]
})
s.bm = bm
s.tombstones = make([]tombstoneItem, 0, len(tombstoneBroadcastIDs))
for _, broadcastID := range tombstoneBroadcastIDs {
s.tombstones = append(s.tombstones, tombstoneItem{
broadcastID: broadcastID,
createTime: time.Now(),
})
}
go s.background()
}
// AddPending adds a pending tombstone to the scheduler.
func (s *tombstoneScheduler) AddPending(broadcastID uint64) {
select {
case <-s.notifier.Context().Done():
panic("unreachable: tombstone scheduler is closing when adding pending tombstone")
case s.pending <- broadcastID:
}
}
// Close closes the tombstone scheduler.
func (s *tombstoneScheduler) Close() {
s.notifier.Cancel()
s.notifier.BlockUntilFinish()
}
// background is the background goroutine of the tombstone scheduler.
func (s *tombstoneScheduler) background() {
defer func() {
s.notifier.Finish(struct{}{})
s.Logger().Info("tombstone scheduler background exit")
}()
s.Logger().Info("tombstone scheduler background start")
tombstoneGCInterval := paramtable.Get().StreamingCfg.WALBroadcasterTombstoneCheckInternal.GetAsDurationByParse()
ticker := time.NewTicker(tombstoneGCInterval)
defer ticker.Stop()
for {
s.triggerGCTombstone()
select {
case <-s.notifier.Context().Done():
return
case broadcastID := <-s.pending:
s.tombstones = append(s.tombstones, tombstoneItem{
broadcastID: broadcastID,
createTime: time.Now(),
})
case <-ticker.C:
}
}
}
// triggerGCTombstone triggers the garbage collection of the tombstone.
func (s *tombstoneScheduler) triggerGCTombstone() {
maxTombstoneLifetime := paramtable.Get().StreamingCfg.WALBroadcasterTombstoneMaxLifetime.GetAsDurationByParse()
maxTombstoneCount := paramtable.Get().StreamingCfg.WALBroadcasterTombstoneMaxCount.GetAsInt()
expiredTime := time.Now().Add(-maxTombstoneLifetime)
expiredOffset := 0
if len(s.tombstones) > maxTombstoneCount {
expiredOffset = len(s.tombstones) - maxTombstoneCount
}
s.Logger().Info("triggerGCTombstone",
zap.Int("tombstone count", len(s.tombstones)),
zap.Int("expired offset", expiredOffset),
zap.Time("expired time", expiredTime))
for idx, tombstone := range s.tombstones {
// drop tombstones until the first one that is both within the count budget and not yet expired.
if idx >= expiredOffset && tombstone.createTime.After(expiredTime) {
s.tombstones = s.tombstones[idx:]
return
}
if err := s.bm.DropTombstone(s.notifier.Context(), tombstone.broadcastID); err != nil {
s.Logger().Error("failed to drop tombstone", zap.Error(err))
s.tombstones = s.tombstones[idx:]
return
}
}
}
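The retention decision above combines two bounds: keep at most WALBroadcasterTombstoneMaxCount tombstones and drop anything older than WALBroadcasterTombstoneMaxLifetime. A dependency-free sketch of that decision; the helper is illustrative, not part of the commit:
package sketch

import "time"

// keepFrom returns the index of the first tombstone to keep; everything
// before it is eligible for dropping. This mirrors the expiredOffset /
// expiredTime computation in triggerGCTombstone above.
func keepFrom(createTimes []time.Time, maxCount int, maxLifetime time.Duration, now time.Time) int {
	expiredTime := now.Add(-maxLifetime)
	expiredOffset := 0
	if len(createTimes) > maxCount {
		expiredOffset = len(createTimes) - maxCount
	}
	for idx, ct := range createTimes {
		// Keep a tombstone once it is inside the count budget and has not
		// outlived the maximum lifetime.
		if idx >= expiredOffset && ct.After(expiredTime) {
			return idx
		}
	}
	return len(createTimes)
}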

View File

@ -4,8 +4,6 @@ import (
clientv3 "go.etcd.io/etcd/client/v3"
"github.com/milvus-io/milvus/internal/metastore/kv/streamingcoord"
"github.com/milvus-io/milvus/internal/streamingcoord/server/balancer"
"github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster"
"github.com/milvus-io/milvus/internal/streamingcoord/server/resource"
"github.com/milvus-io/milvus/internal/streamingcoord/server/service"
"github.com/milvus-io/milvus/internal/types"
@ -52,14 +50,10 @@ func (s *ServerBuilder) Build() *Server {
resource.OptStreamingCatalog(streamingcoord.NewCataLog(s.metaKV)),
resource.OptMixCoordClient(s.mixCoordClient),
)
balancer := syncutil.NewFuture[balancer.Balancer]()
broadcaster := syncutil.NewFuture[broadcaster.Broadcaster]()
return &Server{
logger: resource.Resource().Logger().With(log.FieldComponent("server")),
session: s.session,
assignmentService: service.NewAssignmentService(balancer),
broadcastService: service.NewBroadcastService(broadcaster),
balancer: balancer,
broadcaster: broadcaster,
assignmentService: service.NewAssignmentService(),
broadcastService: service.NewBroadcastService(),
}
}

View File

@ -6,10 +6,11 @@ import (
"go.uber.org/zap"
"google.golang.org/grpc"
"github.com/milvus-io/milvus/internal/coordinator/snmanager"
"github.com/milvus-io/milvus/internal/streamingcoord/server/balancer"
"github.com/milvus-io/milvus/internal/streamingcoord/server/balancer/balance"
_ "github.com/milvus-io/milvus/internal/streamingcoord/server/balancer/policy" // register the balancer policy
"github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster"
"github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster/broadcast"
"github.com/milvus-io/milvus/internal/streamingcoord/server/resource"
"github.com/milvus-io/milvus/internal/streamingcoord/server/service"
"github.com/milvus-io/milvus/internal/util/sessionutil"
@ -17,7 +18,6 @@ import (
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/proto/streamingpb"
"github.com/milvus-io/milvus/pkg/v2/util/conc"
"github.com/milvus-io/milvus/pkg/v2/util/syncutil"
)
// Server is the streamingcoord server.
@ -30,10 +30,6 @@ type Server struct {
// service level variables.
assignmentService service.AssignmentService
broadcastService service.BroadcastService
// basic component variables can be used at service level.
balancer *syncutil.Future[balancer.Balancer]
broadcaster *syncutil.Future[broadcaster.Broadcaster]
}
// Init initializes the streamingcoord server.
@ -60,8 +56,7 @@ func (s *Server) initBasicComponent(ctx context.Context) (err error) {
s.logger.Warn("recover balancer failed", zap.Error(err))
return struct{}{}, err
}
s.balancer.Set(balancer)
snmanager.StaticStreamingNodeManager.SetBalancerReady(balancer)
balance.Register(balancer)
s.logger.Info("recover balancer done")
return struct{}{}, nil
}))
@ -74,7 +69,7 @@ func (s *Server) initBasicComponent(ctx context.Context) (err error) {
s.logger.Warn("recover broadcaster failed", zap.Error(err))
return struct{}{}, err
}
s.broadcaster.Set(broadcaster)
broadcast.Register(broadcaster)
s.logger.Info("recover broadcaster done")
return struct{}{}, nil
}))
@ -89,18 +84,10 @@ func (s *Server) RegisterGRPCService(grpcServer *grpc.Server) {
// Close closes the streamingcoord server.
func (s *Server) Stop() {
if s.balancer.Ready() {
s.logger.Info("start close balancer...")
s.balancer.Get().Close()
} else {
s.logger.Info("balancer not ready, skip close")
}
if s.broadcaster.Ready() {
balance.Release()
s.logger.Info("start close broadcaster...")
s.broadcaster.Get().Close()
} else {
s.logger.Info("broadcaster not ready, skip close")
}
broadcast.Release()
s.logger.Info("release streamingcoord resource...")
resource.Release()
s.logger.Info("streamingcoord server stopped")

View File

@ -8,7 +8,7 @@ import (
"go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus/internal/streamingcoord/server/balancer"
"github.com/milvus-io/milvus/internal/streamingcoord/server/balancer/balance"
"github.com/milvus-io/milvus/internal/streamingcoord/server/balancer/channel"
"github.com/milvus-io/milvus/internal/streamingcoord/server/service/discover"
"github.com/milvus-io/milvus/pkg/v2/log"
@ -18,17 +18,13 @@ import (
"github.com/milvus-io/milvus/pkg/v2/streaming/walimpls/impls/rmq"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
"github.com/milvus-io/milvus/pkg/v2/util/replicateutil"
"github.com/milvus-io/milvus/pkg/v2/util/syncutil"
)
var _ streamingpb.StreamingCoordAssignmentServiceServer = (*assignmentServiceImpl)(nil)
// NewAssignmentService returns a new assignment service.
func NewAssignmentService(
balancer *syncutil.Future[balancer.Balancer],
) streamingpb.StreamingCoordAssignmentServiceServer {
func NewAssignmentService() streamingpb.StreamingCoordAssignmentServiceServer {
assignmentService := &assignmentServiceImpl{
balancer: balancer,
listenerTotal: metrics.StreamingCoordAssignmentListenerTotal.WithLabelValues(paramtable.GetStringNodeID()),
}
// TODO: after recovering from wal, add it to here.
@ -44,7 +40,6 @@ type AssignmentService interface {
type assignmentServiceImpl struct {
streamingpb.UnimplementedStreamingCoordAssignmentServiceServer
balancer *syncutil.Future[balancer.Balancer]
listenerTotal prometheus.Gauge
}
@ -53,7 +48,7 @@ func (s *assignmentServiceImpl) AssignmentDiscover(server streamingpb.StreamingC
s.listenerTotal.Inc()
defer s.listenerTotal.Dec()
balancer, err := s.balancer.GetWithContext(server.Context())
balancer, err := balance.GetWithContext(server.Context())
if err != nil {
return err
}
@ -91,7 +86,7 @@ func (s *assignmentServiceImpl) UpdateReplicateConfiguration(ctx context.Context
// validateReplicateConfiguration validates the replicate configuration.
func (s *assignmentServiceImpl) validateReplicateConfiguration(ctx context.Context, config *commonpb.ReplicateConfiguration) (message.BroadcastMutableMessage, error) {
balancer, err := s.balancer.GetWithContext(ctx)
balancer, err := balance.GetWithContext(ctx)
if err != nil {
return nil, err
}
@ -135,7 +130,7 @@ func (s *assignmentServiceImpl) validateReplicateConfiguration(ctx context.Conte
// AlterReplicateConfiguration puts the replicate configuration into the balancer.
// It's a callback function of the broadcast service.
func (s *assignmentServiceImpl) AlterReplicateConfiguration(ctx context.Context, msgs ...message.ImmutableAlterReplicateConfigMessageV2) error {
balancer, err := s.balancer.GetWithContext(ctx)
balancer, err := balance.GetWithContext(ctx)
if err != nil {
return err
}
@ -144,7 +139,7 @@ func (s *assignmentServiceImpl) AlterReplicateConfiguration(ctx context.Context,
// UpdateWALBalancePolicy is used to update the WAL balance policy.
func (s *assignmentServiceImpl) UpdateWALBalancePolicy(ctx context.Context, req *streamingpb.UpdateWALBalancePolicyRequest) (*streamingpb.UpdateWALBalancePolicyResponse, error) {
balancer, err := s.balancer.GetWithContext(ctx)
balancer, err := balance.GetWithContext(ctx)
if err != nil {
return nil, err
}

View File

@ -3,10 +3,9 @@ package service
import (
"context"
"github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster"
"github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster/broadcast"
"github.com/milvus-io/milvus/pkg/v2/proto/streamingpb"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/util/syncutil"
)
// BroadcastService is the interface of the broadcast service.
@ -15,30 +14,31 @@ type BroadcastService interface {
}
// NewBroadcastService creates a new broadcast service.
func NewBroadcastService(bc *syncutil.Future[broadcaster.Broadcaster]) BroadcastService {
return &broadcastServceImpl{
broadcaster: bc,
}
func NewBroadcastService() BroadcastService {
return &broadcastServceImpl{}
}
// broadcastServceImpl is the implementation of the broadcast service.
type broadcastServceImpl struct {
broadcaster *syncutil.Future[broadcaster.Broadcaster]
}
type broadcastServceImpl struct{}
// Broadcast broadcasts the message to all channels.
func (s *broadcastServceImpl) Broadcast(ctx context.Context, req *streamingpb.BroadcastRequest) (*streamingpb.BroadcastResponse, error) {
broadcaster, err := s.broadcaster.GetWithContext(ctx)
msg := message.NewBroadcastMutableMessageBeforeAppend(req.Message.Payload, req.Message.Properties)
api, err := broadcast.StartBroadcastWithResourceKeys(ctx, msg.BroadcastHeader().ResourceKeys.Collect()...)
if err != nil {
return nil, err
}
results, err := broadcaster.Broadcast(ctx, message.NewBroadcastMutableMessageBeforeAppend(req.Message.Payload, req.Message.Properties))
results, err := api.Broadcast(ctx, msg)
if err != nil {
return nil, err
}
protoResult := make(map[string]*streamingpb.ProduceMessageResponseResult, len(results.AppendResults))
for vchannel, result := range results.AppendResults {
protoResult[vchannel] = result.IntoProto()
protoResult[vchannel] = &streamingpb.ProduceMessageResponseResult{
Id: result.MessageID.IntoProto(),
Timetick: result.TimeTick,
LastConfirmedId: result.LastConfirmedMessageID.IntoProto(),
}
}
return &streamingpb.BroadcastResponse{
BroadcastId: results.BroadcastID,
@ -48,7 +48,7 @@ func (s *broadcastServceImpl) Broadcast(ctx context.Context, req *streamingpb.Br
// Ack acknowledges the message at the specified vchannel.
func (s *broadcastServceImpl) Ack(ctx context.Context, req *streamingpb.BroadcastAckRequest) (*streamingpb.BroadcastAckResponse, error) {
broadcaster, err := s.broadcaster.GetWithContext(ctx)
broadcaster, err := broadcast.GetWithContext(ctx)
if err != nil {
return nil, err
}

View File

@ -7,10 +7,12 @@ import (
"github.com/stretchr/testify/mock"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
"github.com/milvus-io/milvus/internal/mocks/streamingcoord/server/mock_broadcaster"
"github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster"
"github.com/milvus-io/milvus/pkg/v2/proto/messagespb"
"github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster/broadcast"
"github.com/milvus-io/milvus/pkg/v2/proto/streamingpb"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
"github.com/milvus-io/milvus/pkg/v2/streaming/walimpls/impls/walimplstest"
"github.com/milvus-io/milvus/pkg/v2/util/syncutil"
@ -18,17 +20,24 @@ import (
func TestBroadcastService(t *testing.T) {
fb := syncutil.NewFuture[broadcaster.Broadcaster]()
mba := mock_broadcaster.NewMockBroadcastAPI(t)
mb := mock_broadcaster.NewMockBroadcaster(t)
fb.Set(mb)
mb.EXPECT().Broadcast(mock.Anything, mock.Anything).Return(&types.BroadcastAppendResult{}, nil)
mba.EXPECT().Broadcast(mock.Anything, mock.Anything).Return(&types.BroadcastAppendResult{}, nil)
mb.EXPECT().WithResourceKeys(mock.Anything, mock.Anything).Return(mba, nil)
mb.EXPECT().Ack(mock.Anything, mock.Anything).Return(nil)
mb.EXPECT().LegacyAck(mock.Anything, mock.Anything, mock.Anything).Return(nil)
service := NewBroadcastService(fb)
broadcast.Register(mb)
msg := message.NewCreateCollectionMessageBuilderV1().
WithHeader(&message.CreateCollectionMessageHeader{}).
WithBody(&msgpb.CreateCollectionRequest{}).
WithBroadcast([]string{"v1"}, message.NewCollectionNameResourceKey("r1")).
MustBuildBroadcast()
service := NewBroadcastService()
service.Broadcast(context.Background(), &streamingpb.BroadcastRequest{
Message: &messagespb.Message{
Payload: []byte("payload"),
Properties: map[string]string{"_bh": "1"},
},
Message: msg.IntoMessageProto(),
})
service.Ack(context.Background(), &streamingpb.BroadcastAckRequest{
BroadcastId: 1,

View File

@ -300,11 +300,16 @@ func (p *producerImpl) recvLoop() (err error) {
case *streamingpb.ProduceMessageResponse_Result:
msgID, err := message.UnmarshalMessageID(produceResp.Result.GetId())
if err != nil {
return err
return errors.Wrap(err, "failed to unmarshal message id")
}
lcMsgID, err := message.UnmarshalMessageID(produceResp.Result.GetLastConfirmedId())
if err != nil {
return errors.Wrap(err, "failed to unmarshal last confirmed message id")
}
result = produceResponse{
result: &types.AppendResult{
MessageID: msgID,
LastConfirmedMessageID: lcMsgID,
TimeTick: produceResp.Result.GetTimetick(),
TxnCtx: message.NewTxnContextFromProto(produceResp.Result.GetTxnContext()),
Extra: produceResp.Result.GetExtra(),

View File

@ -87,6 +87,7 @@ func TestProducer(t *testing.T) {
Response: &streamingpb.ProduceMessageResponse_Result{
Result: &streamingpb.ProduceMessageResponseResult{
Id: walimplstest.NewTestMessageID(1).IntoProto(),
LastConfirmedId: walimplstest.NewTestMessageID(1).IntoProto(),
},
},
},

View File

@ -202,6 +202,7 @@ func TestProduceServerRecvArm(t *testing.T) {
msgID := walimplstest.NewTestMessageID(1)
f(&wal.AppendResult{
MessageID: msgID,
LastConfirmedMessageID: msgID,
TimeTick: 100,
}, nil)
})

View File

@ -199,6 +199,7 @@ func (w *walAdaptorImpl) Append(ctx context.Context, msg message.MutableMessage)
// unwrap the messageID if needed.
r := &wal.AppendResult{
MessageID: messageID,
LastConfirmedMessageID: extraAppendResult.LastConfirmedMessageID,
TimeTick: extraAppendResult.TimeTick,
TxnCtx: extraAppendResult.TxnCtx,
Extra: extra,

View File

@ -45,6 +45,7 @@ func (impl *timeTickAppendInterceptor) DoAppend(ctx context.Context, msg message
ackManager := impl.operator.AckManager()
var txnSession *txn.TxnSession
var immutableMsg message.ImmutableMessage
if msg.MessageType() != message.MessageTypeTimeTick {
// Allocate new timestamp acker for message.
var acker *ack.Acker
@ -69,7 +70,7 @@ func (impl *timeTickAppendInterceptor) DoAppend(ctx context.Context, msg message
return
}
acker.Ack(
ack.OptImmutableMessage(msg.IntoImmutableMessage(msgID)),
ack.OptImmutableMessage(immutableMsg),
ack.OptTxnSession(txnSession),
)
}()
@ -115,8 +116,10 @@ func (impl *timeTickAppendInterceptor) DoAppend(ctx context.Context, msg message
if txnSession != nil {
ctx = txn.WithTxnSession(ctx, txnSession)
}
msgID, err = impl.appendMsg(ctx, msg, append)
return
if immutableMsg, err = impl.appendMsg(ctx, msg, append); err != nil {
return nil, err
}
return immutableMsg.MessageID(), nil
}
// GracefulClose implements InterceptorWithGracefulClose.
@ -207,12 +210,14 @@ func (impl *timeTickAppendInterceptor) appendMsg(
ctx context.Context,
msg message.MutableMessage,
append func(context.Context, message.MutableMessage) (message.MessageID, error),
) (message.MessageID, error) {
) (message.ImmutableMessage, error) {
msgID, err := append(ctx, msg)
if err != nil {
return nil, err
}
utility.ReplaceAppendResultTimeTick(ctx, msg.TimeTick())
utility.ReplaceAppendResultTxnContext(ctx, msg.TxnContext())
return msgID, nil
immutableMsg := msg.IntoImmutableMessage(msgID)
utility.ReplaceAppendResultTimeTick(ctx, immutableMsg.TimeTick())
utility.ReplaceAppendResultLastConfirmedMessageID(ctx, immutableMsg.LastConfirmedMessageID())
utility.ReplaceAppendResultTxnContext(ctx, immutableMsg.TxnContext())
return immutableMsg, nil
}

View File

@ -81,6 +81,7 @@ func (m *AppendMetrics) IntoLogFields() []zap.Field {
fields = append(fields, zap.Error(m.err))
} else {
fields = append(fields, zap.String("messageID", m.result.MessageID.String()))
fields = append(fields, zap.String("lcMessageID", m.result.LastConfirmedMessageID.String()))
fields = append(fields, zap.Uint64("timetick", m.result.TimeTick))
if m.result.TxnCtx != nil {
fields = append(fields, zap.Int64("txnID", int64(m.result.TxnCtx.TxnID)))

View File

@ -25,6 +25,7 @@ type ExtraAppendResult struct {
TimeTick uint64
TxnCtx *message.TxnContext
Extra protoreflect.ProtoMessage
LastConfirmedMessageID message.MessageID
}
// NotPersistedHint is the hint of not persisted message.
@ -66,6 +67,12 @@ func ModifyAppendResultExtra[M protoreflect.ProtoMessage](ctx context.Context, m
result.(*ExtraAppendResult).Extra = new
}
// ReplaceAppendResultLastConfirmedMessageID sets the last confirmed message id into the context.
func ReplaceAppendResultLastConfirmedMessageID(ctx context.Context, lastConfirmedMessageID message.MessageID) {
result := ctx.Value(extraAppendResultValue)
result.(*ExtraAppendResult).LastConfirmedMessageID = lastConfirmedMessageID
}
// ReplaceAppendResultTimeTick set time tick to context
func ReplaceAppendResultTimeTick(ctx context.Context, timeTick uint64) {
result := ctx.Value(extraAppendResultValue)

View File

@ -83,6 +83,18 @@ func TestReplaceAppendResultTxnContext(t *testing.T) {
assert.Equal(t, retrievedResult.TxnCtx.TxnID, newTxnCtx.TxnID)
}
func TestReplaceAppendResultLastConfirmedMessageID(t *testing.T) {
ctx := context.Background()
result := &ExtraAppendResult{LastConfirmedMessageID: walimplstest.NewTestMessageID(1)}
ctx = WithExtraAppendResult(ctx, result)
newLastConfirmedMessageID := walimplstest.NewTestMessageID(2)
ReplaceAppendResultLastConfirmedMessageID(ctx, newLastConfirmedMessageID)
retrievedResult := ctx.Value(extraAppendResultValue).(*ExtraAppendResult)
assert.True(t, retrievedResult.LastConfirmedMessageID.EQ(newLastConfirmedMessageID))
}
func TestWithFlushFromOldArch(t *testing.T) {
ctx := context.Background()
assert.False(t, GetFlushFromOldArch(ctx))

View File

@ -3,6 +3,7 @@ package testutil
import (
"github.com/milvus-io/milvus/internal/coordinator/snmanager"
"github.com/milvus-io/milvus/internal/streamingcoord/server/balancer/channel"
"github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster/broadcast"
"github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster/registry"
registry2 "github.com/milvus-io/milvus/internal/streamingnode/client/handler/registry"
)
@ -12,4 +13,5 @@ func ResetEnvironment() {
registry.ResetRegistration()
snmanager.ResetStreamingNodeManager()
registry2.ResetRegisterLocalWALManager()
broadcast.ResetBroadcaster()
}

View File

@ -0,0 +1,104 @@
package kv
import (
"context"
"time"
"github.com/cenkalti/backoff/v4"
"go.uber.org/zap"
"github.com/milvus-io/milvus/pkg/v2/kv/predicates"
"github.com/milvus-io/milvus/pkg/v2/log"
)
var _ MetaKv = (*ReliableWriteMetaKv)(nil)
// NewReliableWriteMetaKv wraps kv into a ReliableWriteMetaKv; if kv is already a ReliableWriteMetaKv, it is returned as-is.
func NewReliableWriteMetaKv(kv MetaKv) MetaKv {
if _, ok := kv.(*ReliableWriteMetaKv); ok {
return kv
}
return &ReliableWriteMetaKv{
Binder: log.Binder{},
MetaKv: kv,
}
}
// ReliableWriteMetaKv is a wrapper of MetaKv that ensures data is written reliably.
// It retries a meta-write operation until it succeeds or the context is canceled or times out.
// It is useful to guarantee that the metadata in memory stays consistent with the underlying meta storage.
type ReliableWriteMetaKv struct {
log.Binder
MetaKv
}
func (kv *ReliableWriteMetaKv) Save(ctx context.Context, key, value string) error {
return kv.retryWithBackoff(ctx, func(ctx context.Context) error {
return kv.MetaKv.Save(ctx, key, value)
})
}
func (kv *ReliableWriteMetaKv) MultiSave(ctx context.Context, kvs map[string]string) error {
return kv.retryWithBackoff(ctx, func(ctx context.Context) error {
return kv.MetaKv.MultiSave(ctx, kvs)
})
}
func (kv *ReliableWriteMetaKv) Remove(ctx context.Context, key string) error {
return kv.retryWithBackoff(ctx, func(ctx context.Context) error {
return kv.MetaKv.Remove(ctx, key)
})
}
func (kv *ReliableWriteMetaKv) MultiRemove(ctx context.Context, keys []string) error {
return kv.retryWithBackoff(ctx, func(ctx context.Context) error {
return kv.MetaKv.MultiRemove(ctx, keys)
})
}
func (kv *ReliableWriteMetaKv) MultiSaveAndRemove(ctx context.Context, saves map[string]string, removals []string, preds ...predicates.Predicate) error {
return kv.retryWithBackoff(ctx, func(ctx context.Context) error {
return kv.MetaKv.MultiSaveAndRemove(ctx, saves, removals, preds...)
})
}
func (kv *ReliableWriteMetaKv) MultiSaveAndRemoveWithPrefix(ctx context.Context, saves map[string]string, removals []string, preds ...predicates.Predicate) error {
return kv.retryWithBackoff(ctx, func(ctx context.Context) error {
return kv.MetaKv.MultiSaveAndRemoveWithPrefix(ctx, saves, removals, preds...)
})
}
func (kv *ReliableWriteMetaKv) CompareVersionAndSwap(ctx context.Context, key string, version int64, target string) (bool, error) {
var result bool
err := kv.retryWithBackoff(ctx, func(ctx context.Context) error {
var err error
result, err = kv.MetaKv.CompareVersionAndSwap(ctx, key, version, target)
return err
})
return result, err
}
// retryWithBackoff retries fn with exponential backoff until it succeeds or the context is done.
func (kv *ReliableWriteMetaKv) retryWithBackoff(ctx context.Context, fn func(ctx context.Context) error) error {
backoff := backoff.NewExponentialBackOff()
backoff.InitialInterval = 10 * time.Millisecond
backoff.MaxInterval = 1 * time.Second
backoff.MaxElapsedTime = 0
backoff.Reset()
for {
err := fn(ctx)
if err == nil {
return nil
}
if ctx.Err() != nil {
return ctx.Err()
}
nextInterval := backoff.NextBackOff()
select {
case <-ctx.Done():
return ctx.Err()
case <-time.After(nextInterval):
kv.Logger().Warn("failed to persist operation, wait for retry...", zap.Duration("nextRetryInterval", nextInterval), zap.Error(err))
}
}
}
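
Not part of the diff: a minimal usage sketch of the wrapper above. The metaKV argument and the persistAssignment helper are hypothetical, and the import path is assumed to be the pkg/v2/kv package that defines MetaKv; transient write failures are retried inside Save, so the only error that can surface here comes from the caller's context.

package example

import (
    "context"
    "time"

    "github.com/milvus-io/milvus/pkg/v2/kv"
)

// persistAssignment sketches how coord code could rely on ReliableWriteMetaKv:
// Save keeps retrying transient failures and only gives up once ctx is done.
func persistAssignment(metaKV kv.MetaKv, key, value string) error {
    ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
    defer cancel()
    reliable := kv.NewReliableWriteMetaKv(metaKV) // no-op wrap if metaKV is already reliable
    return reliable.Save(ctx, key, value)
}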

View File

@ -0,0 +1,127 @@
package kv
import (
"context"
"sync"
"testing"
"time"
"github.com/cockroachdb/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"go.uber.org/atomic"
"github.com/milvus-io/milvus/pkg/v2/kv/predicates"
"github.com/milvus-io/milvus/pkg/v2/mocks/mock_kv"
)
func TestReliableWriteMetaKv(t *testing.T) {
kv := mock_kv.NewMockMetaKv(t)
fail := atomic.NewBool(true)
kv.EXPECT().Save(context.TODO(), mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, s1, s2 string) error {
if !fail.Load() {
return nil
}
return errors.New("test")
})
kv.EXPECT().MultiSave(context.TODO(), mock.Anything).RunAndReturn(func(ctx context.Context, kvs map[string]string) error {
if !fail.Load() {
return nil
}
return errors.New("test")
})
kv.EXPECT().Remove(context.TODO(), mock.Anything).RunAndReturn(func(ctx context.Context, key string) error {
if !fail.Load() {
return nil
}
return errors.New("test")
})
kv.EXPECT().MultiRemove(context.TODO(), mock.Anything).RunAndReturn(func(ctx context.Context, keys []string) error {
if !fail.Load() {
return nil
}
return errors.New("test")
})
kv.EXPECT().MultiSaveAndRemove(context.TODO(), mock.Anything, mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, saves map[string]string, removals []string, preds ...predicates.Predicate) error {
if !fail.Load() {
return nil
}
return errors.New("test")
})
kv.EXPECT().MultiSaveAndRemoveWithPrefix(context.TODO(), mock.Anything, mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, saves map[string]string, removals []string, preds ...predicates.Predicate) error {
if !fail.Load() {
return nil
}
return errors.New("test")
})
kv.EXPECT().CompareVersionAndSwap(mock.Anything, mock.Anything, mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, key string, version int64, target string) (bool, error) {
if !fail.Load() {
return false, nil
}
return false, errors.New("test")
})
rkv := NewReliableWriteMetaKv(kv)
wg := sync.WaitGroup{}
wg.Add(7)
success := atomic.NewInt32(0)
go func() {
defer wg.Done()
err := rkv.Save(context.TODO(), "test", "test")
if err == nil {
success.Add(1)
}
}()
go func() {
defer wg.Done()
err := rkv.MultiSave(context.TODO(), map[string]string{"test": "test"})
if err == nil {
success.Add(1)
}
}()
go func() {
defer wg.Done()
err := rkv.Remove(context.TODO(), "test")
if err == nil {
success.Add(1)
}
}()
go func() {
defer wg.Done()
err := rkv.MultiRemove(context.TODO(), []string{"test"})
if err == nil {
success.Add(1)
}
}()
go func() {
defer wg.Done()
err := rkv.MultiSaveAndRemove(context.TODO(), map[string]string{"test": "test"}, []string{"test"})
if err == nil {
success.Add(1)
}
}()
go func() {
defer wg.Done()
err := rkv.MultiSaveAndRemoveWithPrefix(context.TODO(), map[string]string{"test": "test"}, []string{"test"})
if err == nil {
success.Add(1)
}
}()
go func() {
defer wg.Done()
_, err := rkv.CompareVersionAndSwap(context.TODO(), "test", 0, "test")
if err == nil {
success.Add(1)
}
}()
time.Sleep(1 * time.Second)
fail.Store(false)
wg.Wait()
assert.Equal(t, int32(7), success.Load())
fail.Store(true)
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
_, err := rkv.CompareVersionAndSwap(ctx, "test", 0, "test")
assert.ErrorIs(t, err, context.DeadlineExceeded)
}

View File

@ -44,6 +44,17 @@ func (br *BroadcastResult[H, B]) GetControlChannelResult() *AppendResult {
return nil
}
// GetVChannelsWithoutControlChannel returns the vchannels, excluding the control channel.
func (br *BroadcastResult[H, B]) GetVChannelsWithoutControlChannel() []string {
vchannels := make([]string, 0, len(br.Results))
for vchannel := range br.Results {
if !funcutil.IsControlChannel(vchannel) {
vchannels = append(vchannels, vchannel)
}
}
return vchannels
}
// AppendResult is the result of append operation.
type AppendResult struct {
MessageID MessageID

View File

@ -0,0 +1,23 @@
package message
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
)
func TestBroadcastResult(t *testing.T) {
r := BroadcastResult[*CreateDatabaseMessageHeader, *CreateDatabaseMessageBody]{
Message: nil,
Results: map[string]*AppendResult{
"v1": {},
"v2": {},
"abc" + funcutil.ControlChannelSuffix: {},
},
}
assert.ElementsMatch(t, []string{"v1", "v2"}, r.GetVChannelsWithoutControlChannel())
assert.NotNil(t, r.GetControlChannelResult())
}

View File

@ -24,6 +24,10 @@ type AppendResult struct {
// MessageID is generated by underlying walimpls.
MessageID message.MessageID
// LastConfirmedMessageID is the last confirmed message id.
// Starting from this message id, a reader can consume every message whose timetick is greater than the TimeTick in the response.
LastConfirmedMessageID message.MessageID
// TimeTick is the time tick of the message.
// Set by timetick interceptor.
TimeTick uint64
@ -51,6 +55,7 @@ func (r *AppendResult) IntoProto() *streamingpb.ProduceMessageResponseResult {
Timetick: r.TimeTick,
TxnContext: r.TxnCtx.IntoProto(),
Extra: r.Extra,
LastConfirmedId: r.LastConfirmedMessageID.IntoProto(),
}
}
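
An illustrative sketch (not from the diff) of the contract documented for LastConfirmedMessageID above: replay starts from the last confirmed id and keeps only messages whose timetick exceeds the append result's TimeTick. The scannedMsg type and the scanFrom callback are hypothetical stand-ins for a WAL reader.

package example

import "github.com/milvus-io/milvus/pkg/v2/streaming/util/message"

// scannedMsg is a hypothetical view of a message read back from the WAL.
type scannedMsg struct {
    ID       message.MessageID
    TimeTick uint64
}

// messagesAfterAppend scans from the last confirmed id and drops everything whose
// timetick is not greater than the append's TimeTick, mirroring the documented contract.
func messagesAfterAppend(lastConfirmed message.MessageID, appendTimeTick uint64,
    scanFrom func(message.MessageID) []scannedMsg,
) []scannedMsg {
    kept := make([]scannedMsg, 0)
    for _, m := range scanFrom(lastConfirmed) {
        if m.TimeTick > appendTimeTick {
            kept = append(kept, m)
        }
    }
    return kept
}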

View File

@ -45,6 +45,7 @@ func TestAppendResult_IntoProto(t *testing.T) {
MessageID: msgID,
TimeTick: 12345,
TxnCtx: &message.TxnContext{TxnID: 1},
LastConfirmedMessageID: msgID,
}
protoResult := result.IntoProto()
@ -52,6 +53,7 @@ func TestAppendResult_IntoProto(t *testing.T) {
assert.Equal(t, "1", protoResult.Id.Id)
assert.Equal(t, uint64(12345), protoResult.Timetick)
assert.Equal(t, int64(1), protoResult.TxnContext.TxnId)
assert.Equal(t, "1", protoResult.LastConfirmedId.Id)
}
func TestAppendResponses_MaxTimeTick(t *testing.T) {

View File

@ -109,6 +109,47 @@ func (k *KeyLock[K]) Lock(key K) {
}
}
// TryLock tries to acquire the exclusive lock of the key without blocking.
// It returns true if the lock is acquired, false otherwise.
func (k *KeyLock[K]) TryLock(key K) bool {
k.keyLocksMutex.Lock()
// update the key map
if keyLock, ok := k.refLocks[key]; ok {
keyLock.ref()
k.keyLocksMutex.Unlock()
locked := keyLock.mutex.TryLock()
if !locked {
k.keyLocksMutex.Lock()
keyLock.unref()
if keyLock.refCounter == 0 {
_ = refLockPoolPool.ReturnObject(ctx, keyLock)
delete(k.refLocks, key)
}
k.keyLocksMutex.Unlock()
}
return locked
} else {
obj, err := refLockPoolPool.BorrowObject(ctx)
if err != nil {
log.Ctx(ctx).Error("BorrowObject failed", zap.Error(err))
k.keyLocksMutex.Unlock()
return false
}
newKLock := obj.(*RefLock)
// newKLock := newRefLock()
locked := newKLock.mutex.TryLock()
if !locked {
_ = refLockPoolPool.ReturnObject(ctx, newKLock)
k.keyLocksMutex.Unlock()
return false
}
k.refLocks[key] = newKLock
newKLock.ref()
k.keyLocksMutex.Unlock()
return true
}
}
func (k *KeyLock[K]) Unlock(lockedKey K) {
k.keyLocksMutex.Lock()
defer k.keyLocksMutex.Unlock()
@ -151,6 +192,47 @@ func (k *KeyLock[K]) RLock(key K) {
}
}
// TryRLock tries to acquire the shared (read) lock of the key without blocking.
// It returns true if the lock is acquired, false otherwise.
func (k *KeyLock[K]) TryRLock(key K) bool {
k.keyLocksMutex.Lock()
// update the key map
if keyLock, ok := k.refLocks[key]; ok {
keyLock.ref()
k.keyLocksMutex.Unlock()
locked := keyLock.mutex.TryRLock()
if !locked {
k.keyLocksMutex.Lock()
keyLock.unref()
if keyLock.refCounter == 0 {
_ = refLockPoolPool.ReturnObject(ctx, keyLock)
delete(k.refLocks, key)
}
k.keyLocksMutex.Unlock()
}
return locked
} else {
obj, err := refLockPoolPool.BorrowObject(ctx)
if err != nil {
log.Ctx(ctx).Error("BorrowObject failed", zap.Error(err))
k.keyLocksMutex.Unlock()
return false
}
newKLock := obj.(*RefLock)
// newKLock := newRefLock()
locked := newKLock.mutex.TryRLock()
if !locked {
_ = refLockPoolPool.ReturnObject(ctx, newKLock)
k.keyLocksMutex.Unlock()
return false
}
k.refLocks[key] = newKLock
newKLock.ref()
k.keyLocksMutex.Unlock()
return true
}
}
func (k *KeyLock[K]) RUnlock(lockedKey K) {
k.keyLocksMutex.Lock()
defer k.keyLocksMutex.Unlock()

View File

@ -82,3 +82,50 @@ func TestNewKeyLock(t *testing.T) {
keyLock.keyLocksMutex.Unlock()
assert.Equal(t, 0, keyLen)
}
func TestKeyLockTryLock(t *testing.T) {
keyLock := NewKeyLock[string]()
ok := keyLock.TryLock("a")
assert.True(t, ok)
ok = keyLock.TryLock("b")
assert.True(t, ok)
ok = keyLock.TryLock("a")
assert.False(t, ok)
ok = keyLock.TryLock("b")
assert.False(t, ok)
ok = keyLock.TryRLock("a")
assert.False(t, ok)
ok = keyLock.TryRLock("b")
assert.False(t, ok)
assert.Equal(t, 2, keyLock.size())
keyLock.Unlock("a")
keyLock.Unlock("b")
assert.Zero(t, keyLock.size())
ok = keyLock.TryRLock("a")
assert.True(t, ok)
ok = keyLock.TryRLock("b")
assert.True(t, ok)
ok = keyLock.TryLock("a")
assert.False(t, ok)
ok = keyLock.TryLock("b")
assert.False(t, ok)
ok = keyLock.TryRLock("a")
assert.True(t, ok)
ok = keyLock.TryRLock("b")
assert.True(t, ok)
assert.Equal(t, 2, keyLock.size())
keyLock.RUnlock("a")
keyLock.RUnlock("b")
assert.Equal(t, 2, keyLock.size())
keyLock.RUnlock("a")
keyLock.RUnlock("b")
assert.Equal(t, 0, keyLock.size())
}
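
A brief sketch (not from the diff) of how a caller might use the new non-blocking variants; withExclusive is a hypothetical helper assumed to live in the same package as KeyLock.

// withExclusive runs work only if the per-key write lock can be taken immediately,
// releasing it afterwards; it reports whether the work actually ran.
func withExclusive(kl *KeyLock[string], key string, work func()) bool {
    if !kl.TryLock(key) {
        return false
    }
    defer kl.Unlock(key)
    work()
    return true
}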

View File

@ -6148,6 +6148,9 @@ type streamingConfig struct {
// broadcaster
WALBroadcasterConcurrencyRatio ParamItem `refreshable:"false"`
WALBroadcasterTombstoneCheckInternal ParamItem `refreshable:"true"`
WALBroadcasterTombstoneMaxCount ParamItem `refreshable:"true"`
WALBroadcasterTombstoneMaxLifetime ParamItem `refreshable:"true"`
// txn
TxnDefaultKeepaliveTimeout ParamItem `refreshable:"true"`
@ -6327,6 +6330,39 @@ it also determine the depth of depth first search method that is used to find th
}
p.WALBroadcasterConcurrencyRatio.Init(base.mgr)
p.WALBroadcasterTombstoneCheckInternal = ParamItem{
Key: "streaming.walBroadcaster.tombstone.checkInternal",
Version: "2.6.0",
Doc: `The interval of tombstone garbage collection, 5m by default.
Tombstones are used to reject duplicate submissions of DDL messages;
keeping too few tombstones may lead to ABA issues in the state of the Milvus cluster.`,
DefaultValue: "5m",
Export: false,
}
p.WALBroadcasterTombstoneCheckInternal.Init(base.mgr)
p.WALBroadcasterTombstoneMaxCount = ParamItem{
Key: "streaming.walBroadcaster.tombstone.maxCount",
Version: "2.6.0",
Doc: `The max count of tombstones, 256 by default.
Tombstones are used to reject duplicate submissions of DDL messages;
keeping too few tombstones may lead to ABA issues in the state of the Milvus cluster.`,
DefaultValue: "256",
Export: false,
}
p.WALBroadcasterTombstoneMaxCount.Init(base.mgr)
p.WALBroadcasterTombstoneMaxLifetime = ParamItem{
Key: "streaming.walBroadcaster.tombstone.maxLifetime",
Version: "2.6.0",
Doc: `The max lifetime of a tombstone, 30m by default.
Tombstones are used to reject duplicate submissions of DDL messages;
keeping too few tombstones may lead to ABA issues in the state of the Milvus cluster.`,
DefaultValue: "30m",
Export: false,
}
p.WALBroadcasterTombstoneMaxLifetime.Init(base.mgr)
// txn
p.TxnDefaultKeepaliveTimeout = ParamItem{
Key: "streaming.txn.defaultKeepaliveTimeout",

View File

@ -665,6 +665,9 @@ func TestComponentParam(t *testing.T) {
assert.Equal(t, 3, params.StreamingCfg.WALBalancerPolicyVChannelFairRebalanceMaxStep.GetAsInt())
assert.Equal(t, 30*time.Second, params.StreamingCfg.WALBalancerOperationTimeout.GetAsDurationByParse())
assert.Equal(t, 1.0, params.StreamingCfg.WALBroadcasterConcurrencyRatio.GetAsFloat())
assert.Equal(t, 5*time.Minute, params.StreamingCfg.WALBroadcasterTombstoneCheckInternal.GetAsDurationByParse())
assert.Equal(t, 256, params.StreamingCfg.WALBroadcasterTombstoneMaxCount.GetAsInt())
assert.Equal(t, 30*time.Minute, params.StreamingCfg.WALBroadcasterTombstoneMaxLifetime.GetAsDurationByParse())
assert.Equal(t, 10*time.Second, params.StreamingCfg.TxnDefaultKeepaliveTimeout.GetAsDurationByParse())
assert.Equal(t, 30*time.Second, params.StreamingCfg.WALWriteAheadBufferKeepalive.GetAsDurationByParse())
assert.Equal(t, int64(64*1024*1024), params.StreamingCfg.WALWriteAheadBufferCapacity.GetAsSize())
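
For reference, a hedged sketch of reading the new tombstone knobs through the same accessors the test above exercises; the paramtable import path, the global paramtable.Get() accessor, and the tombstoneSettings helper are assumptions, not part of the diff.

package example

import (
    "time"

    "github.com/milvus-io/milvus/pkg/v2/util/paramtable"
)

// tombstoneSettings gathers the three tombstone knobs; note that the accessor keeps
// the "CheckInternal" spelling used by the parameter definition itself.
func tombstoneSettings() (checkEvery, maxLifetime time.Duration, maxCount int) {
    params := paramtable.Get()
    return params.StreamingCfg.WALBroadcasterTombstoneCheckInternal.GetAsDurationByParse(),
        params.StreamingCfg.WALBroadcasterTombstoneMaxLifetime.GetAsDurationByParse(),
        params.StreamingCfg.WALBroadcasterTombstoneMaxCount.GetAsInt()
}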