diff --git a/internal/.mockery.yaml b/internal/.mockery.yaml index 8af207d675..23b687b757 100644 --- a/internal/.mockery.yaml +++ b/internal/.mockery.yaml @@ -20,6 +20,7 @@ packages: github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster: interfaces: Broadcaster: + BroadcastAPI: AppendOperator: Watcher: github.com/milvus-io/milvus/internal/streamingcoord/client: diff --git a/internal/coordinator/snmanager/streaming_node_manager.go b/internal/coordinator/snmanager/streaming_node_manager.go index 60c3dad50c..74200f2dd0 100644 --- a/internal/coordinator/snmanager/streaming_node_manager.go +++ b/internal/coordinator/snmanager/streaming_node_manager.go @@ -8,6 +8,7 @@ import ( "go.uber.org/zap" "github.com/milvus-io/milvus/internal/streamingcoord/server/balancer" + "github.com/milvus-io/milvus/internal/streamingcoord/server/balancer/balance" "github.com/milvus-io/milvus/pkg/v2/log" "github.com/milvus-io/milvus/pkg/v2/streaming/util/types" "github.com/milvus-io/milvus/pkg/v2/util/funcutil" @@ -23,7 +24,6 @@ var ErrStreamingServiceNotReady = errors.New("streaming service is not ready, ma func newStreamingNodeManager() *StreamingNodeManager { snm := &StreamingNodeManager{ notifier: syncutil.NewAsyncTaskNotifier[struct{}](), - balancer: syncutil.NewFuture[balancer.Balancer](), cond: syncutil.NewContextCond(&sync.Mutex{}), latestAssignments: make(map[string]types.PChannelInfoAssigned), nodeChangedNotifier: syncutil.NewVersionedNotifier(), @@ -63,9 +63,7 @@ func (s *StreamingReadyNotifier) IsReady() bool { // StreamingNodeManager is a manager for manage the querynode that embedded into streaming node. // StreamingNodeManager is exclusive with ResourceManager. type StreamingNodeManager struct { - notifier *syncutil.AsyncTaskNotifier[struct{}] - balancer *syncutil.Future[balancer.Balancer] - // The coord is merged after 2.6, so we don't need to make distribution safe. + notifier *syncutil.AsyncTaskNotifier[struct{}] cond *syncutil.ContextCond latestAssignments map[string]types.PChannelInfoAssigned // The latest assignments info got from streaming coord balance module. nodeChangedNotifier *syncutil.VersionedNotifier // used to notify that node in streaming node manager has been changed. @@ -73,14 +71,18 @@ type StreamingNodeManager struct { // GetBalancer returns the balancer of the streaming node manager. func (s *StreamingNodeManager) GetBalancer() balancer.Balancer { - return s.balancer.Get() + b, err := balance.GetWithContext(context.Background()) + if err != nil { + panic(err) + } + return b } // GetLatestWALLocated returns the server id of the node that the wal of the vChannel is located. // Return -1 and error if the vchannel is not found or context is canceled. func (s *StreamingNodeManager) GetLatestWALLocated(ctx context.Context, vchannel string) (int64, error) { pchannel := funcutil.ToPhysicalChannel(vchannel) - balancer, err := s.balancer.GetWithContext(ctx) + balancer, err := balance.GetWithContext(ctx) if err != nil { return -1, err } @@ -107,7 +109,7 @@ func (s *StreamingNodeManager) CheckIfStreamingServiceReady(ctx context.Context) // RegisterStreamingEnabledNotifier registers a notifier into the balancer. 
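For context, a minimal sketch of how a consumer is now expected to resolve the balancer through the balance singleton used above, instead of the removed per-manager future. Only balance.GetWithContext and GetAllStreamingNodes come from this diff; the function name and the 5-second timeout are illustrative assumptions.

package snmanagerexample // hypothetical package, for illustration only

import (
	"context"
	"time"

	"github.com/milvus-io/milvus/internal/streamingcoord/server/balancer/balance"
	"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
)

// waitStreamingNodes blocks until balance.Register has been called somewhere during
// streamingcoord startup (or the deadline expires), then queries the balancer.
func waitStreamingNodes(parent context.Context) (map[int64]*types.StreamingNodeInfo, error) {
	ctx, cancel := context.WithTimeout(parent, 5*time.Second) // illustrative timeout
	defer cancel()
	b, err := balance.GetWithContext(ctx) // resolves the syncutil.Future-backed singleton
	if err != nil {
		return nil, err // singleton was not registered before the deadline
	}
	return b.GetAllStreamingNodes(ctx)
}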
func (s *StreamingNodeManager) RegisterStreamingEnabledListener(ctx context.Context, notifier *StreamingReadyNotifier) error { - balancer, err := s.balancer.GetWithContext(ctx) + balancer, err := balance.GetWithContext(ctx) if err != nil { return err } @@ -134,7 +136,7 @@ func (s *StreamingNodeManager) GetWALLocated(vChannel string) int64 { // GetStreamingQueryNodeIDs returns the server ids of the streaming query nodes. func (s *StreamingNodeManager) GetStreamingQueryNodeIDs() typeutil.UniqueSet { - balancer, err := s.balancer.GetWithContext(context.Background()) + balancer, err := balance.GetWithContext(context.Background()) if err != nil { panic(err) } @@ -154,15 +156,10 @@ func (s *StreamingNodeManager) ListenNodeChanged() *syncutil.VersionedListener { return s.nodeChangedNotifier.Listen(syncutil.VersionedListenAtEarliest) } -// SetBalancerReady set the balancer ready for the streaming node manager from streamingcoord initialization. -func (s *StreamingNodeManager) SetBalancerReady(b balancer.Balancer) { - s.balancer.Set(b) -} - func (s *StreamingNodeManager) execute() (err error) { defer s.notifier.Finish(struct{}{}) - b, err := s.balancer.GetWithContext(s.notifier.Context()) + b, err := balance.GetWithContext(s.notifier.Context()) if err != nil { return errors.Wrap(err, "failed to wait balancer ready") } @@ -182,3 +179,8 @@ func (s *StreamingNodeManager) execute() (err error) { } } } + +func (s *StreamingNodeManager) Close() { + s.notifier.Cancel() + s.notifier.BlockUntilFinish() +} diff --git a/internal/coordinator/snmanager/streaming_node_manager_test.go b/internal/coordinator/snmanager/streaming_node_manager_test.go index e9a75a90bf..557bd6f5d4 100644 --- a/internal/coordinator/snmanager/streaming_node_manager_test.go +++ b/internal/coordinator/snmanager/streaming_node_manager_test.go @@ -9,6 +9,7 @@ import ( "github.com/milvus-io/milvus/internal/mocks/streamingcoord/server/mock_balancer" "github.com/milvus-io/milvus/internal/streamingcoord/server/balancer" + "github.com/milvus-io/milvus/internal/streamingcoord/server/balancer/balance" "github.com/milvus-io/milvus/pkg/v2/proto/streamingpb" "github.com/milvus-io/milvus/pkg/v2/streaming/util/types" "github.com/milvus-io/milvus/pkg/v2/util/typeutil" @@ -20,17 +21,18 @@ type pChannelInfoAssigned struct { } func TestStreamingNodeManager(t *testing.T) { + StaticStreamingNodeManager.Close() m := newStreamingNodeManager() b := mock_balancer.NewMockBalancer(t) ch := make(chan pChannelInfoAssigned, 1) b.EXPECT().GetAllStreamingNodes(mock.Anything).Return(map[int64]*types.StreamingNodeInfo{}, nil) - b.EXPECT().WatchChannelAssignments(mock.Anything, mock.Anything).Run( - func(ctx context.Context, cb balancer.WatchChannelAssignmentsCallback) { + b.EXPECT().WatchChannelAssignments(mock.Anything, mock.Anything).RunAndReturn( + func(ctx context.Context, cb balancer.WatchChannelAssignmentsCallback) error { for { select { case <-ctx.Done(): - return + return ctx.Err() case p := <-ch: cb(balancer.WatchChannelAssignmentsCallbackParam{ Version: p.version, @@ -41,7 +43,7 @@ func TestStreamingNodeManager(t *testing.T) { } }) b.EXPECT().RegisterStreamingEnabledNotifier(mock.Anything).Return() - m.SetBalancerReady(b) + balance.Register(b) streamingNodes := m.GetStreamingQueryNodeIDs() assert.Empty(t, streamingNodes) diff --git a/internal/coordinator/snmanager/test_utility.go b/internal/coordinator/snmanager/test_utility.go index d2db679747..3a27caee7e 100644 --- a/internal/coordinator/snmanager/test_utility.go +++ 
b/internal/coordinator/snmanager/test_utility.go @@ -11,10 +11,13 @@ import ( "github.com/milvus-io/milvus/internal/mocks/streamingcoord/server/mock_balancer" "github.com/milvus-io/milvus/internal/streamingcoord/server/balancer" + "github.com/milvus-io/milvus/internal/streamingcoord/server/balancer/balance" "github.com/milvus-io/milvus/pkg/v2/streaming/util/types" ) func ResetStreamingNodeManager() { + StaticStreamingNodeManager.Close() + balance.ResetBalancer() StaticStreamingNodeManager = newStreamingNodeManager() } @@ -26,5 +29,5 @@ func ResetDoNothingStreamingNodeManager(t *testing.T) { return ctx.Err() }).Maybe() b.EXPECT().GetAllStreamingNodes(mock.Anything).Return(map[int64]*types.StreamingNodeInfo{}, nil).Maybe() - StaticStreamingNodeManager.SetBalancerReady(b) + balance.Register(b) } diff --git a/internal/datacoord/server.go b/internal/datacoord/server.go index 4ab7c19cc3..5eb252b2cb 100644 --- a/internal/datacoord/server.go +++ b/internal/datacoord/server.go @@ -353,17 +353,24 @@ func (s *Server) initDataCoord() error { // initMessageCallback initializes the message callback. // TODO: we should build a ddl framework to handle the message ack callback for ddl messages func (s *Server) initMessageCallback() { - registry.RegisterDropPartitionV1AckCallback(func(ctx context.Context, msg message.ImmutableDropPartitionMessageV1) error { - return s.NotifyDropPartition(ctx, msg.VChannel(), []int64{msg.Header().PartitionId}) + registry.RegisterDropPartitionV1AckCallback(func(ctx context.Context, result message.BroadcastResultDropPartitionMessageV1) error { + partitionID := result.Message.Header().PartitionId + for _, vchannel := range result.GetVChannelsWithoutControlChannel() { + if err := s.NotifyDropPartition(ctx, vchannel, []int64{partitionID}); err != nil { + return err + } + } + return nil }) - registry.RegisterImportV1AckCallback(func(ctx context.Context, msg message.ImmutableImportMessageV1) error { - body := msg.MustBody() + registry.RegisterImportV1AckCallback(func(ctx context.Context, result message.BroadcastResultImportMessageV1) error { + body := result.Message.MustBody() + vchannels := result.GetVChannelsWithoutControlChannel() importResp, err := s.ImportV2(ctx, &internalpb.ImportRequestInternal{ CollectionID: body.GetCollectionID(), CollectionName: body.GetCollectionName(), PartitionIDs: body.GetPartitionIDs(), - ChannelNames: []string{msg.VChannel()}, + ChannelNames: vchannels, Schema: body.GetSchema(), Files: lo.Map(body.GetFiles(), func(file *msgpb.ImportFile, _ int) *internalpb.ImportFile { return &internalpb.ImportFile{ diff --git a/internal/datacoord/server_test.go b/internal/datacoord/server_test.go index 60a4019044..226b5c5857 100644 --- a/internal/datacoord/server_test.go +++ b/internal/datacoord/server_test.go @@ -61,7 +61,7 @@ import ( "github.com/milvus-io/milvus/pkg/v2/proto/internalpb" "github.com/milvus-io/milvus/pkg/v2/proto/workerpb" "github.com/milvus-io/milvus/pkg/v2/streaming/util/message" - "github.com/milvus-io/milvus/pkg/v2/streaming/walimpls/impls/rmq" + "github.com/milvus-io/milvus/pkg/v2/streaming/walimpls/impls/walimplstest" "github.com/milvus-io/milvus/pkg/v2/util/etcd" "github.com/milvus-io/milvus/pkg/v2/util/merr" "github.com/milvus-io/milvus/pkg/v2/util/metricsinfo" @@ -2887,8 +2887,7 @@ func TestServer_InitMessageCallback(t *testing.T) { server.initMessageCallback() // Test DropPartition message callback - dropPartitionMsg, err := message.NewDropPartitionMessageBuilderV1(). - WithVChannel("test_channel"). 
+ dropPartitionMsg := message.NewDropPartitionMessageBuilderV1(). WithHeader(&message.DropPartitionMessageHeader{ CollectionId: 1, PartitionId: 1, @@ -2898,9 +2897,15 @@ func TestServer_InitMessageCallback(t *testing.T) { MsgType: commonpb.MsgType_DropPartition, }, }). - BuildMutable() - assert.NoError(t, err) - err = registry.CallMessageAckCallback(ctx, dropPartitionMsg.IntoImmutableMessage(rmq.NewRmqID(1))) + WithBroadcast([]string{"test_channel"}, message.NewImportJobIDResourceKey(1)). + MustBuildBroadcast() + err := registry.CallMessageAckCallback(ctx, dropPartitionMsg, map[string]*message.AppendResult{ + "test_channel": { + MessageID: walimplstest.NewTestMessageID(1), + LastConfirmedMessageID: walimplstest.NewTestMessageID(1), + TimeTick: 1, + }, + }) assert.Error(t, err) // server not healthy // Test Import message check callback @@ -2918,16 +2923,22 @@ func TestServer_InitMessageCallback(t *testing.T) { assert.NoError(t, err) // Test Import message ack callback - importMsg, err := message.NewImportMessageBuilderV1(). - WithVChannel("test_channel"). + importMsg := message.NewImportMessageBuilderV1(). WithHeader(&message.ImportMessageHeader{}). WithBody(&msgpb.ImportMsg{ Base: &commonpb.MsgBase{ MsgType: commonpb.MsgType_Import, }, }). - BuildMutable() - assert.NoError(t, err) - err = registry.CallMessageAckCallback(ctx, importMsg.IntoImmutableMessage(rmq.NewRmqID(1))) + WithBroadcast([]string{"test_channel"}, resourceKey). + MustBuildBroadcast() + err = registry.CallMessageAckCallback(ctx, importMsg, map[string]*message.AppendResult{ + "test_channel": { + MessageID: walimplstest.NewTestMessageID(1), + LastConfirmedMessageID: walimplstest.NewTestMessageID(1), + TimeTick: 1, + }, + }, + ) assert.Error(t, err) // server not healthy } diff --git a/internal/distributed/streaming/balancer_test.go b/internal/distributed/streaming/balancer_test.go index e0038ae992..ac4a8074db 100644 --- a/internal/distributed/streaming/balancer_test.go +++ b/internal/distributed/streaming/balancer_test.go @@ -12,6 +12,7 @@ import ( "github.com/milvus-io/milvus/internal/coordinator/snmanager" "github.com/milvus-io/milvus/internal/mocks/streamingcoord/server/mock_balancer" "github.com/milvus-io/milvus/internal/streamingcoord/server/balancer" + "github.com/milvus-io/milvus/internal/streamingcoord/server/balancer/balance" "github.com/milvus-io/milvus/pkg/v2/proto/streamingpb" "github.com/milvus-io/milvus/pkg/v2/streaming/util/types" "github.com/milvus-io/milvus/pkg/v2/util/merr" @@ -54,7 +55,7 @@ func TestBalancer(t *testing.T) { }) snmanager.ResetStreamingNodeManager() - snmanager.StaticStreamingNodeManager.SetBalancerReady(sbalancer) + balance.Register(sbalancer) balancer := balancerImpl{ walAccesserImpl: &walAccesserImpl{}, diff --git a/internal/metastore/kv/streamingcoord/kv_catalog.go b/internal/metastore/kv/streamingcoord/kv_catalog.go index e38198b0ae..749defd3f0 100644 --- a/internal/metastore/kv/streamingcoord/kv_catalog.go +++ b/internal/metastore/kv/streamingcoord/kv_catalog.go @@ -36,7 +36,7 @@ import ( // │   └── cluster-2-pchannel-2 func NewCataLog(metaKV kv.MetaKv) metastore.StreamingCoordCataLog { return &catalog{ - metaKV: metaKV, + metaKV: kv.NewReliableWriteMetaKv(metaKV), } } diff --git a/internal/metastore/kv/streamingcoord/kv_catalog_test.go b/internal/metastore/kv/streamingcoord/kv_catalog_test.go index 9c4bcd136f..f72cd945cd 100644 --- a/internal/metastore/kv/streamingcoord/kv_catalog_test.go +++ b/internal/metastore/kv/streamingcoord/kv_catalog_test.go @@ -13,6 +13,7 @@ import ( 
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus/pkg/v2/mocks/mock_kv" "github.com/milvus-io/milvus/pkg/v2/proto/streamingpb" + "github.com/milvus-io/milvus/pkg/v2/util/merr" ) func TestCatalog(t *testing.T) { @@ -128,18 +129,6 @@ func TestCatalog(t *testing.T) { tasks, err = catalog.ListBroadcastTask(context.Background()) assert.Error(t, err) assert.Nil(t, tasks) - - kv.EXPECT().MultiSave(mock.Anything, mock.Anything).Unset() - kv.EXPECT().MultiSave(mock.Anything, mock.Anything).Return(errors.New("save error")) - kv.EXPECT().Save(mock.Anything, mock.Anything, mock.Anything).Unset() - kv.EXPECT().Save(mock.Anything, mock.Anything, mock.Anything).Return(errors.New("save error")) - err = catalog.SavePChannels(context.Background(), []*streamingpb.PChannelMeta{{ - Channel: &streamingpb.PChannelInfo{Name: "test", Term: 1}, - Node: &streamingpb.StreamingNodeInfo{ServerId: 1}, - }}) - assert.Error(t, err) - err = catalog.SaveBroadcastTask(context.Background(), 1, &streamingpb.BroadcastTask{}) - assert.Error(t, err) } func TestCatalog_ReplicationCatalog(t *testing.T) { @@ -255,4 +244,11 @@ func TestCatalog_ReplicationCatalog(t *testing.T) { assert.Equal(t, infos[0].GetSourceChannelName(), "source-channel-2") assert.Equal(t, infos[0].GetTargetChannelName(), "target-channel-2") assert.Equal(t, infos[0].GetTargetCluster().GetClusterId(), "target-cluster") + + kv.EXPECT().Load(mock.Anything, mock.Anything).Unset() + kv.EXPECT().Load(mock.Anything, mock.Anything).Return("", merr.ErrIoKeyNotFound) + + cfg, err = catalog.GetReplicateConfiguration(context.Background()) + assert.NoError(t, err) + assert.Nil(t, cfg) } diff --git a/internal/mocks/streamingcoord/server/mock_balancer/mock_Balancer.go b/internal/mocks/streamingcoord/server/mock_balancer/mock_Balancer.go index 91669b10ee..4e0029881c 100644 --- a/internal/mocks/streamingcoord/server/mock_balancer/mock_Balancer.go +++ b/internal/mocks/streamingcoord/server/mock_balancer/mock_Balancer.go @@ -11,6 +11,8 @@ import ( mock "github.com/stretchr/testify/mock" + replicateutil "github.com/milvus-io/milvus/pkg/v2/util/replicateutil" + streamingpb "github.com/milvus-io/milvus/pkg/v2/proto/streamingpb" syncutil "github.com/milvus-io/milvus/pkg/v2/util/syncutil" @@ -315,6 +317,51 @@ func (_c *MockBalancer_RegisterStreamingEnabledNotifier_Call) RunAndReturn(run f return _c } +// ReplicateRole provides a mock function with no fields +func (_m *MockBalancer) ReplicateRole() replicateutil.Role { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for ReplicateRole") + } + + var r0 replicateutil.Role + if rf, ok := ret.Get(0).(func() replicateutil.Role); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(replicateutil.Role) + } + + return r0 +} + +// MockBalancer_ReplicateRole_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ReplicateRole' +type MockBalancer_ReplicateRole_Call struct { + *mock.Call +} + +// ReplicateRole is a helper method to define mock.On call +func (_e *MockBalancer_Expecter) ReplicateRole() *MockBalancer_ReplicateRole_Call { + return &MockBalancer_ReplicateRole_Call{Call: _e.mock.On("ReplicateRole")} +} + +func (_c *MockBalancer_ReplicateRole_Call) Run(run func()) *MockBalancer_ReplicateRole_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockBalancer_ReplicateRole_Call) Return(_a0 replicateutil.Role) *MockBalancer_ReplicateRole_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c 
*MockBalancer_ReplicateRole_Call) RunAndReturn(run func() replicateutil.Role) *MockBalancer_ReplicateRole_Call { + _c.Call.Return(run) + return _c +} + // Trigger provides a mock function with given fields: ctx func (_m *MockBalancer) Trigger(ctx context.Context) error { ret := _m.Called(ctx) diff --git a/internal/mocks/streamingcoord/server/mock_broadcaster/mock_BroadcastAPI.go b/internal/mocks/streamingcoord/server/mock_broadcaster/mock_BroadcastAPI.go new file mode 100644 index 0000000000..f9e90c37d4 --- /dev/null +++ b/internal/mocks/streamingcoord/server/mock_broadcaster/mock_BroadcastAPI.go @@ -0,0 +1,130 @@ +// Code generated by mockery v2.53.3. DO NOT EDIT. + +package mock_broadcaster + +import ( + context "context" + + message "github.com/milvus-io/milvus/pkg/v2/streaming/util/message" + mock "github.com/stretchr/testify/mock" + + types "github.com/milvus-io/milvus/pkg/v2/streaming/util/types" +) + +// MockBroadcastAPI is an autogenerated mock type for the BroadcastAPI type +type MockBroadcastAPI struct { + mock.Mock +} + +type MockBroadcastAPI_Expecter struct { + mock *mock.Mock +} + +func (_m *MockBroadcastAPI) EXPECT() *MockBroadcastAPI_Expecter { + return &MockBroadcastAPI_Expecter{mock: &_m.Mock} +} + +// Broadcast provides a mock function with given fields: ctx, msg +func (_m *MockBroadcastAPI) Broadcast(ctx context.Context, msg message.BroadcastMutableMessage) (*types.BroadcastAppendResult, error) { + ret := _m.Called(ctx, msg) + + if len(ret) == 0 { + panic("no return value specified for Broadcast") + } + + var r0 *types.BroadcastAppendResult + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, message.BroadcastMutableMessage) (*types.BroadcastAppendResult, error)); ok { + return rf(ctx, msg) + } + if rf, ok := ret.Get(0).(func(context.Context, message.BroadcastMutableMessage) *types.BroadcastAppendResult); ok { + r0 = rf(ctx, msg) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*types.BroadcastAppendResult) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, message.BroadcastMutableMessage) error); ok { + r1 = rf(ctx, msg) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockBroadcastAPI_Broadcast_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Broadcast' +type MockBroadcastAPI_Broadcast_Call struct { + *mock.Call +} + +// Broadcast is a helper method to define mock.On call +// - ctx context.Context +// - msg message.BroadcastMutableMessage +func (_e *MockBroadcastAPI_Expecter) Broadcast(ctx interface{}, msg interface{}) *MockBroadcastAPI_Broadcast_Call { + return &MockBroadcastAPI_Broadcast_Call{Call: _e.mock.On("Broadcast", ctx, msg)} +} + +func (_c *MockBroadcastAPI_Broadcast_Call) Run(run func(ctx context.Context, msg message.BroadcastMutableMessage)) *MockBroadcastAPI_Broadcast_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(message.BroadcastMutableMessage)) + }) + return _c +} + +func (_c *MockBroadcastAPI_Broadcast_Call) Return(_a0 *types.BroadcastAppendResult, _a1 error) *MockBroadcastAPI_Broadcast_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockBroadcastAPI_Broadcast_Call) RunAndReturn(run func(context.Context, message.BroadcastMutableMessage) (*types.BroadcastAppendResult, error)) *MockBroadcastAPI_Broadcast_Call { + _c.Call.Return(run) + return _c +} + +// Close provides a mock function with no fields +func (_m *MockBroadcastAPI) Close() { + _m.Called() +} + +// MockBroadcastAPI_Close_Call is a *mock.Call that shadows 
Run/Return methods with type explicit version for method 'Close' +type MockBroadcastAPI_Close_Call struct { + *mock.Call +} + +// Close is a helper method to define mock.On call +func (_e *MockBroadcastAPI_Expecter) Close() *MockBroadcastAPI_Close_Call { + return &MockBroadcastAPI_Close_Call{Call: _e.mock.On("Close")} +} + +func (_c *MockBroadcastAPI_Close_Call) Run(run func()) *MockBroadcastAPI_Close_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockBroadcastAPI_Close_Call) Return() *MockBroadcastAPI_Close_Call { + _c.Call.Return() + return _c +} + +func (_c *MockBroadcastAPI_Close_Call) RunAndReturn(run func()) *MockBroadcastAPI_Close_Call { + _c.Run(run) + return _c +} + +// NewMockBroadcastAPI creates a new instance of MockBroadcastAPI. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockBroadcastAPI(t interface { + mock.TestingT + Cleanup(func()) +}) *MockBroadcastAPI { + mock := &MockBroadcastAPI{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/mocks/streamingcoord/server/mock_broadcaster/mock_Broadcaster.go b/internal/mocks/streamingcoord/server/mock_broadcaster/mock_Broadcaster.go index ac5a2a186a..fc370f2c72 100644 --- a/internal/mocks/streamingcoord/server/mock_broadcaster/mock_Broadcaster.go +++ b/internal/mocks/streamingcoord/server/mock_broadcaster/mock_Broadcaster.go @@ -5,10 +5,11 @@ package mock_broadcaster import ( context "context" - message "github.com/milvus-io/milvus/pkg/v2/streaming/util/message" - mock "github.com/stretchr/testify/mock" + broadcaster "github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster" - types "github.com/milvus-io/milvus/pkg/v2/streaming/util/types" + message "github.com/milvus-io/milvus/pkg/v2/streaming/util/message" + + mock "github.com/stretchr/testify/mock" ) // MockBroadcaster is an autogenerated mock type for the Broadcaster type @@ -71,65 +72,6 @@ func (_c *MockBroadcaster_Ack_Call) RunAndReturn(run func(context.Context, messa return _c } -// Broadcast provides a mock function with given fields: ctx, msg -func (_m *MockBroadcaster) Broadcast(ctx context.Context, msg message.BroadcastMutableMessage) (*types.BroadcastAppendResult, error) { - ret := _m.Called(ctx, msg) - - if len(ret) == 0 { - panic("no return value specified for Broadcast") - } - - var r0 *types.BroadcastAppendResult - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, message.BroadcastMutableMessage) (*types.BroadcastAppendResult, error)); ok { - return rf(ctx, msg) - } - if rf, ok := ret.Get(0).(func(context.Context, message.BroadcastMutableMessage) *types.BroadcastAppendResult); ok { - r0 = rf(ctx, msg) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(*types.BroadcastAppendResult) - } - } - - if rf, ok := ret.Get(1).(func(context.Context, message.BroadcastMutableMessage) error); ok { - r1 = rf(ctx, msg) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// MockBroadcaster_Broadcast_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Broadcast' -type MockBroadcaster_Broadcast_Call struct { - *mock.Call -} - -// Broadcast is a helper method to define mock.On call -// - ctx context.Context -// - msg message.BroadcastMutableMessage -func (_e *MockBroadcaster_Expecter) Broadcast(ctx interface{}, msg interface{}) *MockBroadcaster_Broadcast_Call { - return 
&MockBroadcaster_Broadcast_Call{Call: _e.mock.On("Broadcast", ctx, msg)} -} - -func (_c *MockBroadcaster_Broadcast_Call) Run(run func(ctx context.Context, msg message.BroadcastMutableMessage)) *MockBroadcaster_Broadcast_Call { - _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(message.BroadcastMutableMessage)) - }) - return _c -} - -func (_c *MockBroadcaster_Broadcast_Call) Return(_a0 *types.BroadcastAppendResult, _a1 error) *MockBroadcaster_Broadcast_Call { - _c.Call.Return(_a0, _a1) - return _c -} - -func (_c *MockBroadcaster_Broadcast_Call) RunAndReturn(run func(context.Context, message.BroadcastMutableMessage) (*types.BroadcastAppendResult, error)) *MockBroadcaster_Broadcast_Call { - _c.Call.Return(run) - return _c -} - // Close provides a mock function with no fields func (_m *MockBroadcaster) Close() { _m.Called() @@ -210,6 +152,79 @@ func (_c *MockBroadcaster_LegacyAck_Call) RunAndReturn(run func(context.Context, return _c } +// WithResourceKeys provides a mock function with given fields: ctx, resourceKeys +func (_m *MockBroadcaster) WithResourceKeys(ctx context.Context, resourceKeys ...message.ResourceKey) (broadcaster.BroadcastAPI, error) { + _va := make([]interface{}, len(resourceKeys)) + for _i := range resourceKeys { + _va[_i] = resourceKeys[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for WithResourceKeys") + } + + var r0 broadcaster.BroadcastAPI + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, ...message.ResourceKey) (broadcaster.BroadcastAPI, error)); ok { + return rf(ctx, resourceKeys...) + } + if rf, ok := ret.Get(0).(func(context.Context, ...message.ResourceKey) broadcaster.BroadcastAPI); ok { + r0 = rf(ctx, resourceKeys...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(broadcaster.BroadcastAPI) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, ...message.ResourceKey) error); ok { + r1 = rf(ctx, resourceKeys...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockBroadcaster_WithResourceKeys_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'WithResourceKeys' +type MockBroadcaster_WithResourceKeys_Call struct { + *mock.Call +} + +// WithResourceKeys is a helper method to define mock.On call +// - ctx context.Context +// - resourceKeys ...message.ResourceKey +func (_e *MockBroadcaster_Expecter) WithResourceKeys(ctx interface{}, resourceKeys ...interface{}) *MockBroadcaster_WithResourceKeys_Call { + return &MockBroadcaster_WithResourceKeys_Call{Call: _e.mock.On("WithResourceKeys", + append([]interface{}{ctx}, resourceKeys...)...)} +} + +func (_c *MockBroadcaster_WithResourceKeys_Call) Run(run func(ctx context.Context, resourceKeys ...message.ResourceKey)) *MockBroadcaster_WithResourceKeys_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]message.ResourceKey, len(args)-1) + for i, a := range args[1:] { + if a != nil { + variadicArgs[i] = a.(message.ResourceKey) + } + } + run(args[0].(context.Context), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockBroadcaster_WithResourceKeys_Call) Return(_a0 broadcaster.BroadcastAPI, _a1 error) *MockBroadcaster_WithResourceKeys_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockBroadcaster_WithResourceKeys_Call) RunAndReturn(run func(context.Context, ...message.ResourceKey) (broadcaster.BroadcastAPI, error)) *MockBroadcaster_WithResourceKeys_Call { + _c.Call.Return(run) + return _c +} + // NewMockBroadcaster creates a new instance of MockBroadcaster. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. // The first argument is typically a *testing.T value. func NewMockBroadcaster(t interface { diff --git a/internal/querycoordv2/balance/streaming_query_node_channel_helper_test.go b/internal/querycoordv2/balance/streaming_query_node_channel_helper_test.go index ac5f85a656..2d8135c849 100644 --- a/internal/querycoordv2/balance/streaming_query_node_channel_helper_test.go +++ b/internal/querycoordv2/balance/streaming_query_node_channel_helper_test.go @@ -7,11 +7,11 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" - "github.com/milvus-io/milvus/internal/coordinator/snmanager" "github.com/milvus-io/milvus/internal/mocks/streamingcoord/server/mock_balancer" "github.com/milvus-io/milvus/internal/querycoordv2/meta" "github.com/milvus-io/milvus/internal/querycoordv2/session" "github.com/milvus-io/milvus/internal/streamingcoord/server/balancer" + "github.com/milvus-io/milvus/internal/streamingcoord/server/balancer/balance" "github.com/milvus-io/milvus/pkg/v2/proto/datapb" "github.com/milvus-io/milvus/pkg/v2/proto/streamingpb" "github.com/milvus-io/milvus/pkg/v2/streaming/util/types" @@ -50,7 +50,7 @@ func TestAssignChannelToWALLocatedFirst(t *testing.T) { <-ctx.Done() return context.Cause(ctx) }) - snmanager.StaticStreamingNodeManager.SetBalancerReady(b) + balance.Register(b) channels := []*meta.DmChannel{ {VchannelInfo: &datapb.VchannelInfo{ChannelName: "pchannel_v1"}}, diff --git a/internal/querycoordv2/observers/replica_observer_test.go b/internal/querycoordv2/observers/replica_observer_test.go index 2d0f9dfd84..d7fe88321f 100644 --- a/internal/querycoordv2/observers/replica_observer_test.go +++ b/internal/querycoordv2/observers/replica_observer_test.go @@ -34,6 +34,7 @@ import ( "github.com/milvus-io/milvus/internal/querycoordv2/session" "github.com/milvus-io/milvus/internal/querycoordv2/utils" "github.com/milvus-io/milvus/internal/streamingcoord/server/balancer" + "github.com/milvus-io/milvus/internal/streamingcoord/server/balancer/balance" "github.com/milvus-io/milvus/internal/util/streamingutil" "github.com/milvus-io/milvus/pkg/v2/kv" "github.com/milvus-io/milvus/pkg/v2/proto/datapb" @@ -242,7 +243,7 @@ func (suite *ReplicaObserverSuite) TestCheckSQnodesInReplica() { return pchans[0], nil } }) - snmanager.StaticStreamingNodeManager.SetBalancerReady(b) + balance.Register(b) suite.observer = NewReplicaObserver(suite.meta, suite.distMgr) suite.observer.Start() diff --git a/internal/rootcoord/garbage_collector_test.go b/internal/rootcoord/garbage_collector_test.go index 9ac7dea09f..92a0068338 100644 --- a/internal/rootcoord/garbage_collector_test.go +++ b/internal/rootcoord/garbage_collector_test.go @@ -32,6 +32,7 @@ import ( "github.com/milvus-io/milvus/internal/mocks/streamingcoord/server/mock_balancer" mockrootcoord "github.com/milvus-io/milvus/internal/rootcoord/mocks" "github.com/milvus-io/milvus/internal/streamingcoord/server/balancer" + 
"github.com/milvus-io/milvus/internal/streamingcoord/server/balancer/balance" mocktso "github.com/milvus-io/milvus/internal/tso/mocks" "github.com/milvus-io/milvus/internal/util/streamingutil" "github.com/milvus-io/milvus/pkg/v2/common" @@ -551,14 +552,15 @@ func TestGcPartitionData(t *testing.T) { snmanager.ResetStreamingNodeManager() b := mock_balancer.NewMockBalancer(t) - b.EXPECT().WatchChannelAssignments(mock.Anything, mock.Anything).Run( - func(ctx context.Context, cb balancer.WatchChannelAssignmentsCallback) { + b.EXPECT().WatchChannelAssignments(mock.Anything, mock.Anything).RunAndReturn( + func(ctx context.Context, cb balancer.WatchChannelAssignmentsCallback) error { <-ctx.Done() + return ctx.Err() }) b.EXPECT().RegisterStreamingEnabledNotifier(mock.Anything).Run(func(notifier *syncutil.AsyncTaskNotifier[struct{}]) { notifier.Cancel() }) - snmanager.StaticStreamingNodeManager.SetBalancerReady(b) + balance.Register(b) wal := mock_streaming.NewMockWALAccesser(t) broadcast := mock_streaming.NewMockBroadcast(t) diff --git a/internal/streamingcoord/server/balancer/balance/singleton.go b/internal/streamingcoord/server/balancer/balance/singleton.go new file mode 100644 index 0000000000..2012dcee2d --- /dev/null +++ b/internal/streamingcoord/server/balancer/balance/singleton.go @@ -0,0 +1,25 @@ +package balance + +import ( + "context" + + "github.com/milvus-io/milvus/internal/streamingcoord/server/balancer" + "github.com/milvus-io/milvus/pkg/v2/util/syncutil" +) + +var singleton = syncutil.NewFuture[balancer.Balancer]() + +func Register(balancer balancer.Balancer) { + singleton.Set(balancer) +} + +func GetWithContext(ctx context.Context) (balancer.Balancer, error) { + return singleton.GetWithContext(ctx) +} + +func Release() { + if !singleton.Ready() { + return + } + singleton.Get().Close() +} diff --git a/internal/streamingcoord/server/balancer/balance/test_utility.go b/internal/streamingcoord/server/balancer/balance/test_utility.go new file mode 100644 index 0000000000..87f6ac4470 --- /dev/null +++ b/internal/streamingcoord/server/balancer/balance/test_utility.go @@ -0,0 +1,13 @@ +//go:build test +// +build test + +package balance + +import ( + "github.com/milvus-io/milvus/internal/streamingcoord/server/balancer" + "github.com/milvus-io/milvus/pkg/v2/util/syncutil" +) + +func ResetBalancer() { + singleton = syncutil.NewFuture[balancer.Balancer]() +} diff --git a/internal/streamingcoord/server/balancer/balancer.go b/internal/streamingcoord/server/balancer/balancer.go index 71dc65531f..0bce72b3c9 100644 --- a/internal/streamingcoord/server/balancer/balancer.go +++ b/internal/streamingcoord/server/balancer/balancer.go @@ -9,6 +9,7 @@ import ( "github.com/milvus-io/milvus/pkg/v2/proto/streamingpb" "github.com/milvus-io/milvus/pkg/v2/streaming/util/message" "github.com/milvus-io/milvus/pkg/v2/streaming/util/types" + "github.com/milvus-io/milvus/pkg/v2/util/replicateutil" "github.com/milvus-io/milvus/pkg/v2/util/syncutil" ) @@ -36,6 +37,9 @@ type Balancer interface { // UpdateBalancePolicy update the balance policy. UpdateBalancePolicy(ctx context.Context, req *streamingpb.UpdateWALBalancePolicyRequest) (*streamingpb.UpdateWALBalancePolicyResponse, error) + // ReplicateRole returns the replicate role of the balancer. + ReplicateRole() replicateutil.Role + // RegisterStreamingEnabledNotifier registers a notifier into the balancer. // If the error is returned, the balancer is closed. 
// Otherwise, the following rules are applied: diff --git a/internal/streamingcoord/server/balancer/balancer_impl.go b/internal/streamingcoord/server/balancer/balancer_impl.go index 667fda12fc..85125abf00 100644 --- a/internal/streamingcoord/server/balancer/balancer_impl.go +++ b/internal/streamingcoord/server/balancer/balancer_impl.go @@ -19,6 +19,7 @@ import ( "github.com/milvus-io/milvus/pkg/v2/streaming/util/types" "github.com/milvus-io/milvus/pkg/v2/util/contextutil" "github.com/milvus-io/milvus/pkg/v2/util/paramtable" + "github.com/milvus-io/milvus/pkg/v2/util/replicateutil" "github.com/milvus-io/milvus/pkg/v2/util/syncutil" "github.com/milvus-io/milvus/pkg/v2/util/typeutil" ) @@ -93,6 +94,11 @@ func (b *balancerImpl) GetLatestChannelAssignment() (*WatchChannelAssignmentsCal return b.channelMetaManager.GetLatestChannelAssignment() } +// ReplicateRole returns the replicate role of the balancer. +func (b *balancerImpl) ReplicateRole() replicateutil.Role { + return b.channelMetaManager.ReplicateRole() +} + // GetAllStreamingNodes fetches all streaming node info. func (b *balancerImpl) GetAllStreamingNodes(ctx context.Context) (map[int64]*types.StreamingNodeInfo, error) { return resource.Resource().StreamingNodeManagerClient().GetAllStreamingNodes(ctx) diff --git a/internal/streamingcoord/server/balancer/channel/manager.go b/internal/streamingcoord/server/balancer/channel/manager.go index 73893cd9a5..bd1f19e272 100644 --- a/internal/streamingcoord/server/balancer/channel/manager.go +++ b/internal/streamingcoord/server/balancer/channel/manager.go @@ -171,6 +171,17 @@ func (cm *ChannelManager) IsStreamingEnabledOnce() bool { return cm.streamingVersion != nil } +// ReplicateRole returns the replicate role of the channel manager. +func (cm *ChannelManager) ReplicateRole() replicateutil.Role { + cm.cond.L.Lock() + defer cm.cond.L.Unlock() + + if cm.replicateConfig == nil { + return replicateutil.RolePrimary + } + return cm.replicateConfig.GetCurrentCluster().Role() +} + // TriggerWatchUpdate triggers the watch update. // Because current watch must see new incoming streaming node right away, // so a watch updating trigger will be called if there's new incoming streaming node. diff --git a/internal/streamingcoord/server/broadcaster/ack_callback_scheduler.go b/internal/streamingcoord/server/broadcaster/ack_callback_scheduler.go new file mode 100644 index 0000000000..08a5c5a567 --- /dev/null +++ b/internal/streamingcoord/server/broadcaster/ack_callback_scheduler.go @@ -0,0 +1,194 @@ +package broadcaster + +import ( + "context" + "sort" + "time" + + "github.com/cenkalti/backoff/v4" + "go.uber.org/zap" + + "github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster/registry" + "github.com/milvus-io/milvus/pkg/v2/log" + "github.com/milvus-io/milvus/pkg/v2/proto/streamingpb" + "github.com/milvus-io/milvus/pkg/v2/streaming/util/message" + "github.com/milvus-io/milvus/pkg/v2/util/syncutil" +) + +// newAckCallbackScheduler creates a new ack callback scheduler. 
+func newAckCallbackScheduler(logger *log.MLogger) *ackCallbackScheduler { + s := &ackCallbackScheduler{ + notifier: syncutil.NewAsyncTaskNotifier[struct{}](), + pending: make(chan *broadcastTask, 16), + triggerChan: make(chan struct{}, 1), + rkLocker: newResourceKeyLocker(newBroadcasterMetrics()), + tombstoneScheduler: newTombstoneScheduler(logger), + } + s.SetLogger(logger) + return s +} + +type ackCallbackScheduler struct { + log.Binder + + notifier *syncutil.AsyncTaskNotifier[struct{}] + pending chan *broadcastTask + triggerChan chan struct{} + tombstoneScheduler *tombstoneScheduler + pendingAckedTasks []*broadcastTask // should already sorted by the broadcastID + // For the task that hold the conflicted resource-key (which is protected by the resource-key lock), + // broadcastID is always increasing, + // the task which broadcastID is smaller happens before the task which broadcastID is larger. + // Meanwhile the timetick order of any vchannel of those two tasks are same with the order of broadcastID, + // so the smaller broadcastID task is always acked before the larger broadcastID task. + // so we can exeucte the tasks by the order of the broadcastID to promise the ack order is same with wal order. + rkLocker *resourceKeyLocker // it is used to lock the resource-key of ack operation. + // it is not same instance with the resourceKeyLocker in the broadcastTaskManager. + // because it is just used to check if the resource-key is locked when acked. + // For primary milvus cluster, it makes no sense, because the execution order is already protected by the broadcastTaskManager. + // But for secondary milvus cluster, it is necessary to use this rkLocker to protect the resource-key when acked to avoid the execution order broken. +} + +// Initialize initializes the ack scheduler with a list of broadcast tasks. +func (s *ackCallbackScheduler) Initialize(tasks []*broadcastTask, tombstoneIDs []uint64, bm *broadcastTaskManager) { + // when initializing, the tasks in recovery info may be out of order, so we need to sort them by the broadcastID. + sortByBroadcastID(tasks) + s.tombstoneScheduler.Initialize(bm, tombstoneIDs) + s.pendingAckedTasks = tasks + go s.background() +} + +// AddTask adds a new broadcast task into the ack scheduler. +func (s *ackCallbackScheduler) AddTask(task *broadcastTask) { + select { + case <-s.notifier.Context().Done(): + panic("unreachable: ack scheduler is closing when adding new task") + case s.pending <- task: + } +} + +// Close closes the ack scheduler. +func (s *ackCallbackScheduler) Close() { + s.notifier.Cancel() + s.notifier.BlockUntilFinish() + + // close the tombstone scheduler after the ack scheduler is closed. + s.tombstoneScheduler.Close() +} + +// background is the background task of the ack scheduler. +func (s *ackCallbackScheduler) background() { + defer func() { + s.notifier.Finish(struct{}{}) + s.Logger().Info("ack scheduler background exit") + }() + s.Logger().Info("ack scheduler background start") + + for { + s.triggerAckCallback() + select { + case <-s.notifier.Context().Done(): + return + case task := <-s.pending: + s.addBroadcastTask(task) + case <-s.triggerChan: + } + } +} + +// addBroadcastTask adds a broadcast task into the pending acked tasks. 
+func (s *ackCallbackScheduler) addBroadcastTask(task *broadcastTask) error { + s.pendingAckedTasks = append(s.pendingAckedTasks, task) + sortByBroadcastID(s.pendingAckedTasks) // It's a redundant operation, + // once at runtime, the tasks are coming with the order of the broadcastID if they have the conflict resource-key. + return nil +} + +// triggerAckCallback triggers the ack callback. +func (s *ackCallbackScheduler) triggerAckCallback() { + pendingTasks := make([]*broadcastTask, 0, len(s.pendingAckedTasks)) + for _, task := range s.pendingAckedTasks { + if task.State() != streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_PENDING && + task.State() != streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_WAIT_ACK && + task.State() != streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_REPLICATED { + s.Logger().Info("task cannot be acked, skip the ack callback", zap.Uint64("broadcastID", task.Header().BroadcastID)) + continue + } + g, err := s.rkLocker.FastLock(task.Header().ResourceKeys.Collect()...) + if err != nil { + s.Logger().Warn("lock is occupied, delay the ack callback", zap.Uint64("broadcastID", task.Header().BroadcastID), zap.Error(err)) + pendingTasks = append(pendingTasks, task) + continue + } + // Execute the ack callback in background. + go s.doAckCallback(task, g) + } + s.pendingAckedTasks = pendingTasks +} + +// doAckCallback executes the ack callback. +func (s *ackCallbackScheduler) doAckCallback(bt *broadcastTask, g *lockGuards) (err error) { + defer func() { + g.Unlock() + s.triggerChan <- struct{}{} + if err == nil { + s.Logger().Info("execute ack callback done", zap.Uint64("broadcastID", bt.Header().BroadcastID)) + } else { + s.Logger().Warn("execute ack callback failed", zap.Uint64("broadcastID", bt.Header().BroadcastID), zap.Error(err)) + } + }() + s.Logger().Info("start to execute ack callback", zap.Uint64("broadcastID", bt.Header().BroadcastID)) + + msg, result := bt.BroadcastResult() + makeMap := make(map[string]*message.AppendResult, len(result)) + for vchannel, result := range result { + makeMap[vchannel] = &message.AppendResult{ + MessageID: result.MessageID, + LastConfirmedMessageID: result.LastConfirmedMessageID, + TimeTick: result.TimeTick, + } + } + + // call the ack callback until done. + if err := s.callMessageAckCallbackUntilDone(s.notifier.Context(), msg, makeMap); err != nil { + return err + } + if err := bt.MarkAckCallbackDone(s.notifier.Context()); err != nil { + // The catalog is reliable to write, so we can mark the ack callback done without retrying. + return err + } + s.tombstoneScheduler.AddPending(bt.Header().BroadcastID) + return nil +} + +// callMessageAckCallbackUntilDone calls the message ack callback until done. 
+func (s *ackCallbackScheduler) callMessageAckCallbackUntilDone(ctx context.Context, msg message.BroadcastMutableMessage, result map[string]*message.AppendResult) error { + backoff := backoff.NewExponentialBackOff() + backoff.InitialInterval = 10 * time.Millisecond + backoff.MaxInterval = 10 * time.Second + backoff.MaxElapsedTime = 0 + backoff.Reset() + + for { + err := registry.CallMessageAckCallback(ctx, msg, result) + if err == nil { + return nil + } + nextInterval := backoff.NextBackOff() + s.Logger().Warn("failed to call message ack callback, wait for retry...", + log.FieldMessage(msg), + zap.Duration("nextInterval", nextInterval), + zap.Error(err)) + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(nextInterval): + } + } +} + +func sortByBroadcastID(tasks []*broadcastTask) { + sort.Slice(tasks, func(i, j int) bool { + return tasks[i].Header().BroadcastID < tasks[j].Header().BroadcastID + }) +} diff --git a/internal/streamingcoord/server/broadcaster/broadcast/singleton.go b/internal/streamingcoord/server/broadcaster/broadcast/singleton.go new file mode 100644 index 0000000000..660f829e1a --- /dev/null +++ b/internal/streamingcoord/server/broadcaster/broadcast/singleton.go @@ -0,0 +1,38 @@ +package broadcast + +import ( + "context" + + "github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster" + "github.com/milvus-io/milvus/pkg/v2/streaming/util/message" + "github.com/milvus-io/milvus/pkg/v2/util/syncutil" +) + +var singleton = syncutil.NewFuture[broadcaster.Broadcaster]() + +// Register registers the broadcaster. +func Register(broadcaster broadcaster.Broadcaster) { + singleton.Set(broadcaster) +} + +// GetWithContext gets the broadcaster with context. +func GetWithContext(ctx context.Context) (broadcaster.Broadcaster, error) { + return singleton.GetWithContext(ctx) +} + +// StartBroadcastWithResourceKeys starts a broadcast with resource keys. +func StartBroadcastWithResourceKeys(ctx context.Context, resourceKeys ...message.ResourceKey) (broadcaster.BroadcastAPI, error) { + broadcaster, err := singleton.GetWithContext(ctx) + if err != nil { + return nil, err + } + return broadcaster.WithResourceKeys(ctx, resourceKeys...) +} + +// Release releases the broadcaster. 
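A hedged usage sketch of StartBroadcastWithResourceKeys defined above. Only the broadcast package functions, the BroadcastAPI methods, and message.NewImportJobIDResourceKey appear in this diff; the surrounding function and the jobID parameter are assumptions.

package broadcastexample // hypothetical package, for illustration only

import (
	"context"

	"github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster/broadcast"
	"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
	"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
)

// broadcastImportJob acquires the import job's resource key (the manager also adds the
// shared cluster key) and then broadcasts msg to all vchannels carried by the message.
func broadcastImportJob(ctx context.Context, jobID int64, msg message.BroadcastMutableMessage) (*types.BroadcastAppendResult, error) {
	api, err := broadcast.StartBroadcastWithResourceKeys(ctx, message.NewImportJobIDResourceKey(jobID))
	if err != nil {
		return nil, err // broadcaster not registered yet, resource keys unavailable, or cluster is not the replicate primary
	}
	// BroadcastAPI exposes Broadcast and Close (see the mock above); whether Close must also
	// be called after a successful Broadcast is not shown in this excerpt, so it is omitted here.
	return api.Broadcast(ctx, msg)
}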
+func Release() { + if !singleton.Ready() { + return + } + singleton.Get().Close() +} diff --git a/internal/streamingcoord/server/broadcaster/broadcast/test_utility.go b/internal/streamingcoord/server/broadcaster/broadcast/test_utility.go new file mode 100644 index 0000000000..13b4d2fab5 --- /dev/null +++ b/internal/streamingcoord/server/broadcaster/broadcast/test_utility.go @@ -0,0 +1,13 @@ +//go:build test +// +build test + +package broadcast + +import ( + "github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster" + "github.com/milvus-io/milvus/pkg/v2/util/syncutil" +) + +func ResetBroadcaster() { + singleton = syncutil.NewFuture[broadcaster.Broadcaster]() +} diff --git a/internal/streamingcoord/server/broadcaster/broadcast_manager.go b/internal/streamingcoord/server/broadcaster/broadcast_manager.go index 5df4468d38..d0df87086c 100644 --- a/internal/streamingcoord/server/broadcaster/broadcast_manager.go +++ b/internal/streamingcoord/server/broadcaster/broadcast_manager.go @@ -2,89 +2,180 @@ package broadcaster import ( "context" - "fmt" "sync" "github.com/cockroachdb/errors" "go.uber.org/zap" + "github.com/milvus-io/milvus/internal/streamingcoord/server/balancer/balance" "github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster/registry" "github.com/milvus-io/milvus/internal/streamingcoord/server/resource" + "github.com/milvus-io/milvus/internal/util/streamingutil/status" "github.com/milvus-io/milvus/pkg/v2/log" + "github.com/milvus-io/milvus/pkg/v2/proto/messagespb" "github.com/milvus-io/milvus/pkg/v2/proto/streamingpb" "github.com/milvus-io/milvus/pkg/v2/streaming/util/message" - "github.com/milvus-io/milvus/pkg/v2/util/syncutil" + "github.com/milvus-io/milvus/pkg/v2/streaming/util/types" + "github.com/milvus-io/milvus/pkg/v2/util/replicateutil" + "github.com/milvus-io/milvus/pkg/v2/util/typeutil" ) +// RecoverBroadcaster recovers the broadcaster from the recovery info. +func RecoverBroadcaster(ctx context.Context) (Broadcaster, error) { + tasks, err := resource.Resource().StreamingCatalog().ListBroadcastTask(ctx) + if err != nil { + return nil, err + } + return newBroadcastTaskManager(tasks), nil +} + // newBroadcastTaskManager creates a new broadcast task manager with recovery info. -func newBroadcastTaskManager(protos []*streamingpb.BroadcastTask) (*broadcastTaskManager, []*pendingBroadcastTask) { +// return the manager, the pending broadcast tasks and the pending ack callback tasks. 
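The streamingcoord server wiring itself is not part of this excerpt; below is a minimal sketch, under the assumption that startup registers both singletons, using only RecoverBroadcaster, broadcast.Register and balance.Register from this diff. The initStreamingCoord name and error handling are illustrative.

package wiringexample // hypothetical package, for illustration only

import (
	"context"

	"github.com/milvus-io/milvus/internal/streamingcoord/server/balancer"
	"github.com/milvus-io/milvus/internal/streamingcoord/server/balancer/balance"
	"github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster"
	"github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster/broadcast"
)

func initStreamingCoord(ctx context.Context, b balancer.Balancer) error {
	// Publish the balancer through the balance singleton; this replaces the removed
	// StreamingNodeManager.SetBalancerReady call sites shown earlier in the diff.
	balance.Register(b)

	// Recover broadcast tasks from the catalog and publish the broadcaster singleton.
	bc, err := broadcaster.RecoverBroadcaster(ctx)
	if err != nil {
		return err
	}
	broadcast.Register(bc)
	return nil
}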
+func newBroadcastTaskManager(protos []*streamingpb.BroadcastTask) *broadcastTaskManager { logger := resource.Resource().Logger().With(log.FieldComponent("broadcaster")) metrics := newBroadcasterMetrics() + rkLocker := newResourceKeyLocker(metrics) + ackScheduler := newAckCallbackScheduler(logger) recoveryTasks := make([]*broadcastTask, 0, len(protos)) for _, proto := range protos { - t := newBroadcastTaskFromProto(proto, metrics) - t.SetLogger(logger.With(zap.Uint64("broadcastID", t.header.BroadcastID))) + t := newBroadcastTaskFromProto(proto, metrics, ackScheduler) + t.SetLogger(logger) recoveryTasks = append(recoveryTasks, t) } - rks := make(map[message.ResourceKey]uint64, len(recoveryTasks)) tasks := make(map[uint64]*broadcastTask, len(recoveryTasks)) pendingTasks := make([]*pendingBroadcastTask, 0, len(recoveryTasks)) + pendingAckCallbackTasks := make([]*broadcastTask, 0, len(recoveryTasks)) + tombstoneIDs := make([]uint64, 0, len(recoveryTasks)) for _, task := range recoveryTasks { - for rk := range task.header.ResourceKeys { - if oldTaskID, ok := rks[rk]; ok { - panic(fmt.Sprintf("unreachable: dirty recovery info in metastore, broadcast ids: [%d, %d]", oldTaskID, task.header.BroadcastID)) + switch task.task.State { + case streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_PENDING, streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_WAIT_ACK: + guards, err := rkLocker.FastLock(task.Header().ResourceKeys.Collect()...) + if err != nil { + panic(err) } - rks[rk] = task.header.BroadcastID - metrics.IncomingResourceKey(rk.Domain) - } - tasks[task.header.BroadcastID] = task - if task.task.State == streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_PENDING { - // only the task is pending need to be reexecuted. - pendingTasks = append(pendingTasks, newPendingBroadcastTask(task)) + task.WithResourceKeyLockGuards(guards) + + if newPending := newPendingBroadcastTask(task); newPending != nil { + // if there's some pending messages that is not appended, it should be continued to be appended. + pendingTasks = append(pendingTasks, newPending) + } else { + // if there's no pending messages, it should be added to the pending ack callback tasks. + pendingAckCallbackTasks = append(pendingAckCallbackTasks, task) + } + case streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_REPLICATED: + // The task is recovered from the remote cluster, so it doesn't hold the resource lock. + // but the task execution order should be protected by the order of broadcastID (by ackCallbackScheduler) + if isAllDone(task.task) { + pendingAckCallbackTasks = append(pendingAckCallbackTasks, task) + } + case streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_TOMBSTONE: + tombstoneIDs = append(tombstoneIDs, task.Header().BroadcastID) } + tasks[task.Header().BroadcastID] = task } + m := &broadcastTaskManager{ - Binder: log.Binder{}, - cond: syncutil.NewContextCond(&sync.Mutex{}), - tasks: tasks, - resourceKeys: rks, - metrics: metrics, + lifetime: typeutil.NewLifetime(), + mu: &sync.Mutex{}, + tasks: tasks, + resourceKeyLocker: rkLocker, + metrics: metrics, + broadcastScheduler: newBroadcasterScheduler(pendingTasks, logger), + ackScheduler: ackScheduler, } + + // add the pending ack callback tasks into the ack scheduler. + ackScheduler.Initialize(pendingAckCallbackTasks, tombstoneIDs, m) m.SetLogger(logger) - return m, pendingTasks + return m } // broadcastTaskManager is the manager of the broadcast task. 
type broadcastTaskManager struct { log.Binder - cond *syncutil.ContextCond - tasks map[uint64]*broadcastTask // map the broadcastID to the broadcastTaskState - resourceKeys map[message.ResourceKey]uint64 // map the resource key to the broadcastID - metrics *broadcasterMetrics + + lifetime *typeutil.Lifetime + mu *sync.Mutex + tasks map[uint64]*broadcastTask // map the broadcastID to the broadcastTaskState + tombstoneTasks []uint64 // the broadcastID of the tombstone tasks + resourceKeyLocker *resourceKeyLocker + metrics *broadcasterMetrics + broadcastScheduler *broadcasterScheduler // the scheduler of the broadcast task + ackScheduler *ackCallbackScheduler // the scheduler of the ack task } -// AddTask adds a new broadcast task into the manager. -func (bm *broadcastTaskManager) AddTask(ctx context.Context, msg message.BroadcastMutableMessage) (*pendingBroadcastTask, error) { - var err error - if msg, err = bm.assignID(ctx, msg); err != nil { - return nil, err - } - task, err := bm.addBroadcastTask(ctx, msg) - if err != nil { - return nil, err - } - return newPendingBroadcastTask(task), nil -} - -// assignID assigns the broadcast id to the message. -func (bm *broadcastTaskManager) assignID(ctx context.Context, msg message.BroadcastMutableMessage) (message.BroadcastMutableMessage, error) { +// WithResourceKeys acquires the resource keys for the broadcast task. +func (bm *broadcastTaskManager) WithResourceKeys(ctx context.Context, resourceKeys ...message.ResourceKey) (BroadcastAPI, error) { id, err := resource.Resource().IDAllocator().Allocate(ctx) if err != nil { return nil, errors.Wrapf(err, "allocate new id failed") } - msg = msg.WithBroadcastID(id) - return msg, nil + + resourceKeys = bm.appendSharedClusterRK(resourceKeys...) + guards, err := bm.resourceKeyLocker.Lock(resourceKeys...) + if err != nil { + return nil, err + } + + if err := bm.checkClusterRole(ctx); err != nil { + // unlock the guards if the cluster role is not primary. + guards.Unlock() + return nil, err + } + return &broadcasterWithRK{ + broadcaster: bm, + broadcastID: id, + guards: guards, + }, nil +} + +// checkClusterRole checks if the cluster status is primary, otherwise return error. +func (bm *broadcastTaskManager) checkClusterRole(ctx context.Context) error { + // Check if the cluster status is primary, otherwise return error. + b, err := balance.GetWithContext(ctx) + if err != nil { + return err + } + if b.ReplicateRole() != replicateutil.RolePrimary { + return status.NewReplicateViolation("cluster is not primary, cannot do any DDL/DCL") + } + return nil +} + +// appendSharedClusterRK appends the shared cluster resource key to the resource keys. +// shared cluster resource key is required for all broadcast messages. +func (bm *broadcastTaskManager) appendSharedClusterRK(resourceKeys ...message.ResourceKey) []message.ResourceKey { + for _, rk := range resourceKeys { + if rk.Domain == messagespb.ResourceDomain_ResourceDomainCluster { + return resourceKeys + } + } + return append(resourceKeys, message.NewSharedClusterResourceKey()) +} + +// broadcast broadcasts the message to all vchannels. 
+// it will block until the message is broadcasted to all vchannels +func (bm *broadcastTaskManager) broadcast(ctx context.Context, msg message.BroadcastMutableMessage, broadcastID uint64, guards *lockGuards) (*types.BroadcastAppendResult, error) { + if !bm.lifetime.Add(typeutil.LifetimeStateWorking) { + guards.Unlock() + return nil, status.NewOnShutdownError("broadcaster is closing") + } + defer bm.lifetime.Done() + + // check if the message is valid to be broadcasted. + // TODO: the message check callback should not be an component of broadcaster, + // it should be removed after the import operation refactory. + if err := registry.CallMessageCheckCallback(ctx, msg); err != nil { + guards.Unlock() + return nil, err + } + + task := bm.addBroadcastTask(msg, broadcastID, guards) + pendingTask := newPendingBroadcastTask(task) + + // Add it into broadcast scheduler to broadcast the message into all vchannels. + return bm.broadcastScheduler.AddTask(ctx, pendingTask) } // LegacyAck is the legacy ack function for the broadcast task. @@ -105,72 +196,90 @@ func (bm *broadcastTaskManager) LegacyAck(ctx context.Context, broadcastID uint6 // Ack acknowledges the message at the specified vchannel. func (bm *broadcastTaskManager) Ack(ctx context.Context, msg message.ImmutableMessage) error { - if err := registry.CallMessageAckCallback(ctx, msg); err != nil { - bm.Logger().Warn("message ack callback failed", log.FieldMessage(msg), zap.Error(err)) - return err + if !bm.lifetime.Add(typeutil.LifetimeStateWorking) { + return status.NewOnShutdownError("broadcaster is closing") } - bm.Logger().Warn("message ack callback success", log.FieldMessage(msg)) + defer bm.lifetime.Done() - broadcastID := msg.BroadcastHeader().BroadcastID - vchannel := msg.VChannel() - task, ok := bm.getBroadcastTaskByID(broadcastID) + t, ok := bm.getOrCreateBroadcastTask(msg) if !ok { - bm.Logger().Warn("broadcast task not found, it may already acked, ignore the request", zap.Uint64("broadcastID", broadcastID), zap.String("vchannel", vchannel)) + bm.Logger().Debug( + "task is tombstone, ignored the ack request", + zap.Uint64("broadcastID", msg.BroadcastHeader().BroadcastID), + zap.String("vchannel", msg.VChannel())) return nil } - if err := task.Ack(ctx, msg); err != nil { + return t.Ack(ctx, msg) +} + +// DropTombstone drops the tombstone task from the manager. +func (bm *broadcastTaskManager) DropTombstone(ctx context.Context, broadcastID uint64) error { + if !bm.lifetime.Add(typeutil.LifetimeStateWorking) { + return status.NewOnShutdownError("broadcaster is closing") + } + defer bm.lifetime.Done() + + t, ok := bm.getBroadcastTaskByID(broadcastID) + if !ok { + bm.Logger().Debug("task is not found, ignored the drop tombstone request", zap.Uint64("broadcastID", broadcastID)) + return nil + } + if err := t.DropTombstone(ctx); err != nil { return err } - if task.State() == streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_DONE { - bm.removeBroadcastTask(broadcastID) - } + bm.removeBroadcastTask(broadcastID) return nil } -// ReleaseResourceKeys releases the resource keys by the broadcastID. -func (bm *broadcastTaskManager) ReleaseResourceKeys(broadcastID uint64) { - bm.cond.LockAndBroadcast() - defer bm.cond.L.Unlock() +// Close closes the broadcast task manager. +func (bm *broadcastTaskManager) Close() { + bm.lifetime.SetState(typeutil.LifetimeStateStopped) + bm.lifetime.Wait() - bm.removeResourceKeys(broadcastID) + bm.broadcastScheduler.Close() + bm.ackScheduler.Close() } // addBroadcastTask adds the broadcast task into the manager. 
-func (bm *broadcastTaskManager) addBroadcastTask(ctx context.Context, msg message.BroadcastMutableMessage) (*broadcastTask, error) { - newIncomingTask := newBroadcastTaskFromBroadcastMessage(msg, bm.metrics) - header := newIncomingTask.Header() - newIncomingTask.SetLogger(bm.Logger().With(zap.Uint64("broadcastID", header.BroadcastID))) +func (bm *broadcastTaskManager) addBroadcastTask(msg message.BroadcastMutableMessage, broadcastID uint64, guards *lockGuards) *broadcastTask { + msg = msg.OverwriteBroadcastHeader(broadcastID, guards.ResourceKeys()...) + newIncomingTask := newBroadcastTaskFromBroadcastMessage(msg, bm.metrics, bm.ackScheduler) + newIncomingTask.SetLogger(bm.Logger()) + newIncomingTask.WithResourceKeyLockGuards(guards) - bm.cond.L.Lock() - for bm.checkIfResourceKeyExist(header) { - if err := bm.cond.Wait(ctx); err != nil { - return nil, err - } - } - - // setup the resource keys to make resource exclusive held. - for key := range header.ResourceKeys { - bm.resourceKeys[key] = header.BroadcastID - bm.metrics.IncomingResourceKey(key.Domain) - } - bm.tasks[header.BroadcastID] = newIncomingTask - bm.cond.L.Unlock() - return newIncomingTask, nil + bm.mu.Lock() + bm.tasks[broadcastID] = newIncomingTask + bm.mu.Unlock() + return newIncomingTask } -func (bm *broadcastTaskManager) checkIfResourceKeyExist(header *message.BroadcastHeader) bool { - for key := range header.ResourceKeys { - if _, ok := bm.resourceKeys[key]; ok { - return true - } +// getOrCreateBroadcastTask returns the task by the broadcastID +// return false if the task is tombstone. +// if the task is not found, it will create a new task. +func (bm *broadcastTaskManager) getOrCreateBroadcastTask(msg message.ImmutableMessage) (*broadcastTask, bool) { + bm.mu.Lock() + defer bm.mu.Unlock() + + bh := msg.BroadcastHeader() + t, ok := bm.tasks[msg.BroadcastHeader().BroadcastID] + if ok { + return t, t.State() != streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_TOMBSTONE } - return false + if msg.ReplicateHeader() == nil { + bm.Logger().Warn("try to recover task from the wal from non-replicate message, ignore it") + return nil, false + } + + newBroadcastTask := newBroadcastTaskFromImmutableMessage(msg, bm.metrics, bm.ackScheduler) + newBroadcastTask.SetLogger(bm.Logger()) + bm.tasks[bh.BroadcastID] = newBroadcastTask + return newBroadcastTask, true } // getBroadcastTaskByID return the task by the broadcastID. func (bm *broadcastTaskManager) getBroadcastTaskByID(broadcastID uint64) (*broadcastTask, bool) { - bm.cond.L.Lock() - defer bm.cond.L.Unlock() + bm.mu.Lock() + defer bm.mu.Unlock() t, ok := bm.tasks[broadcastID] return t, ok @@ -178,22 +287,8 @@ func (bm *broadcastTaskManager) getBroadcastTaskByID(broadcastID uint64) (*broad // removeBroadcastTask removes the broadcast task by the broadcastID. func (bm *broadcastTaskManager) removeBroadcastTask(broadcastID uint64) { - bm.cond.LockAndBroadcast() - defer bm.cond.L.Unlock() + bm.mu.Lock() + defer bm.mu.Unlock() - bm.removeResourceKeys(broadcastID) delete(bm.tasks, broadcastID) } - -// removeResourceKeys removes the resource keys by the broadcastID. 
-func (bm *broadcastTaskManager) removeResourceKeys(broadcastID uint64) { - task, ok := bm.tasks[broadcastID] - if !ok { - return - } - // remove the related resource keys - for key := range task.header.ResourceKeys { - delete(bm.resourceKeys, key) - bm.metrics.GoneResourceKey(key.Domain) - } -} diff --git a/internal/streamingcoord/server/broadcaster/broadcaster_impl.go b/internal/streamingcoord/server/broadcaster/broadcast_scheduler.go similarity index 57% rename from internal/streamingcoord/server/broadcaster/broadcaster_impl.go rename to internal/streamingcoord/server/broadcaster/broadcast_scheduler.go index 0afea390be..9a9bff016a 100644 --- a/internal/streamingcoord/server/broadcaster/broadcaster_impl.go +++ b/internal/streamingcoord/server/broadcaster/broadcast_scheduler.go @@ -7,11 +7,7 @@ import ( "go.uber.org/zap" - "github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster/registry" - "github.com/milvus-io/milvus/internal/streamingcoord/server/resource" - "github.com/milvus-io/milvus/internal/util/streamingutil/status" "github.com/milvus-io/milvus/pkg/v2/log" - "github.com/milvus-io/milvus/pkg/v2/streaming/util/message" "github.com/milvus-io/milvus/pkg/v2/streaming/util/types" "github.com/milvus-io/milvus/pkg/v2/util/contextutil" "github.com/milvus-io/milvus/pkg/v2/util/hardware" @@ -20,32 +16,25 @@ import ( "github.com/milvus-io/milvus/pkg/v2/util/typeutil" ) -func RecoverBroadcaster( - ctx context.Context, -) (Broadcaster, error) { - tasks, err := resource.Resource().StreamingCatalog().ListBroadcastTask(ctx) - if err != nil { - return nil, err - } - manager, pendings := newBroadcastTaskManager(tasks) - b := &broadcasterImpl{ - manager: manager, - lifetime: typeutil.NewLifetime(), +// newBroadcasterScheduler creates a new broadcaster scheduler. +func newBroadcasterScheduler(pendings []*pendingBroadcastTask, logger *log.MLogger) *broadcasterScheduler { + b := &broadcasterScheduler{ backgroundTaskNotifier: syncutil.NewAsyncTaskNotifier[struct{}](), pendings: pendings, backoffs: typeutil.NewHeap[*pendingBroadcastTask](&pendingBroadcastTaskArray{}), - backoffChan: make(chan *pendingBroadcastTask), pendingChan: make(chan *pendingBroadcastTask), + backoffChan: make(chan *pendingBroadcastTask), workerChan: make(chan *pendingBroadcastTask), } + b.SetLogger(logger) go b.execute() - return b, nil + return b } -// broadcasterImpl is the implementation of Broadcaster -type broadcasterImpl struct { - manager *broadcastTaskManager - lifetime *typeutil.Lifetime +// broadcasterScheduler is the implementation of Broadcaster +type broadcasterScheduler struct { + log.Binder + backgroundTaskNotifier *syncutil.AsyncTaskNotifier[struct{}] pendings []*pendingBroadcastTask backoffs typeutil.Heap[*pendingBroadcastTask] @@ -54,87 +43,33 @@ type broadcasterImpl struct { workerChan chan *pendingBroadcastTask } -// Broadcast broadcasts the message to all channels. -func (b *broadcasterImpl) Broadcast(ctx context.Context, msg message.BroadcastMutableMessage) (result *types.BroadcastAppendResult, err error) { - if !b.lifetime.Add(typeutil.LifetimeStateWorking) { - return nil, status.NewOnShutdownError("broadcaster is closing") - } - defer func() { - b.lifetime.Done() - if err != nil { - b.Logger().Warn("broadcast message failed", zap.Error(err)) - return - } - }() - - // We need to check if the message is valid before adding it to the broadcaster. - // TODO: add resource key lock here to avoid state race condition. - // TODO: add all ddl to check operation here after ddl framework is ready. 
- if err := registry.CallMessageCheckCallback(ctx, msg); err != nil { - b.Logger().Warn("check message ack callback failed", zap.Error(err)) - return nil, err - } - - t, err := b.manager.AddTask(ctx, msg) - if err != nil { - return nil, err - } +func (b *broadcasterScheduler) AddTask(ctx context.Context, task *pendingBroadcastTask) (*types.BroadcastAppendResult, error) { select { case <-b.backgroundTaskNotifier.Context().Done(): // We can only check the background context but not the request context here. // Because we want the new incoming task must be delivered to the background task queue // otherwise the broadcaster is closing - return nil, status.NewOnShutdownError("broadcaster is closing") - case b.pendingChan <- t: + panic("unreachable: broadcaster is closing when adding new task") + case b.pendingChan <- task: } // Wait both request context and the background task context. ctx, _ = contextutil.MergeContext(ctx, b.backgroundTaskNotifier.Context()) - r, err := t.BlockUntilTaskDone(ctx) + // wait for all the vchannels acked. + result, err := task.BlockUntilAllAck(ctx) if err != nil { return nil, err } - - // wait for all the vchannels acked. - if err := t.BlockUntilAllAck(ctx); err != nil { - return nil, err - } - return r, nil + return result, nil } -func (b *broadcasterImpl) LegacyAck(ctx context.Context, broadcastID uint64, vchannel string) error { - if !b.lifetime.Add(typeutil.LifetimeStateWorking) { - return status.NewOnShutdownError("broadcaster is closing") - } - defer b.lifetime.Done() - - return b.manager.LegacyAck(ctx, broadcastID, vchannel) -} - -// Ack acknowledges the message at the specified vchannel. -func (b *broadcasterImpl) Ack(ctx context.Context, msg message.ImmutableMessage) error { - if !b.lifetime.Add(typeutil.LifetimeStateWorking) { - return status.NewOnShutdownError("broadcaster is closing") - } - defer b.lifetime.Done() - - return b.manager.Ack(ctx, msg) -} - -func (b *broadcasterImpl) Close() { - b.lifetime.SetState(typeutil.LifetimeStateStopped) - b.lifetime.Wait() - +func (b *broadcasterScheduler) Close() { b.backgroundTaskNotifier.Cancel() b.backgroundTaskNotifier.BlockUntilFinish() } -func (b *broadcasterImpl) Logger() *log.MLogger { - return b.manager.Logger() -} - // execute the broadcaster -func (b *broadcasterImpl) execute() { +func (b *broadcasterScheduler) execute() { workers := int(float64(hardware.GetCPUNum()) * paramtable.Get().StreamingCfg.WALBroadcasterConcurrencyRatio.GetAsFloat()) if workers < 1 { workers = 1 @@ -162,7 +97,7 @@ func (b *broadcasterImpl) execute() { b.dispatch() } -func (b *broadcasterImpl) dispatch() { +func (b *broadcasterScheduler) dispatch() { for { var workerChan chan *pendingBroadcastTask var nextTask *pendingBroadcastTask @@ -203,7 +138,7 @@ func (b *broadcasterImpl) dispatch() { } } -func (b *broadcasterImpl) worker(no int) { +func (b *broadcasterScheduler) worker(no int) { logger := b.Logger().With(zap.Int("workerNo", no)) defer func() { logger.Info("broadcaster worker exit") @@ -222,8 +157,6 @@ func (b *broadcasterImpl) worker(no int) { case b.backoffChan <- task: } } - // All message of broadcast task is sent, release the resource keys to let other task with same resource keys to apply operation. 
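broadcasterScheduler keeps the old dispatch shape: a single dispatch goroutine feeds a fixed pool of workers (sized from the CPU count and WALBroadcasterConcurrencyRatio), and tasks whose append fails are pushed back through a backoff channel for a later retry. The sketch below is a simplified, generic version of that pending/worker/backoff split; it uses a plain slice instead of the backoff heap and timer the real scheduler maintains, and the task type and execute function are placeholders:

package example

import "context"

type task struct{} // placeholder for a pending broadcast task

// runScheduler illustrates the split: one dispatch goroutine feeds a fixed
// pool of workers, and failed tasks are re-queued so they can be retried.
func runScheduler(ctx context.Context, workers int, execute func(context.Context, *task) error) chan *task {
	pendingChan := make(chan *task)
	workerChan := make(chan *task)
	backoffChan := make(chan *task)

	// Worker pool: execute tasks, report failures for a retry.
	for i := 0; i < workers; i++ {
		go func() {
			for {
				select {
				case <-ctx.Done():
					return
				case t := <-workerChan:
					if err := execute(ctx, t); err != nil {
						select {
						case <-ctx.Done():
							return
						case backoffChan <- t:
						}
					}
				}
			}
		}()
	}

	// Dispatch loop: accept new and retried tasks, hand them to idle workers.
	go func() {
		var queue []*task
		for {
			// A send on a nil channel blocks forever, so the worker case is
			// only armed while there is something to hand out.
			var out chan *task
			var next *task
			if len(queue) > 0 {
				out, next = workerChan, queue[0]
			}
			select {
			case <-ctx.Done():
				return
			case t := <-pendingChan:
				queue = append(queue, t)
			case t := <-backoffChan:
				queue = append(queue, t)
			case out <- next:
				queue = queue[1:]
			}
		}
	}()
	return pendingChan
}

The nil-channel trick in the dispatch loop appears to be the same idle mechanism the real dispatch function uses when it has no ready task.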
- b.manager.ReleaseResourceKeys(task.Header().BroadcastID) } } } diff --git a/internal/streamingcoord/server/broadcaster/broadcast_task.go b/internal/streamingcoord/server/broadcaster/broadcast_task.go index dd271a49a4..dc50c80f81 100644 --- a/internal/streamingcoord/server/broadcaster/broadcast_task.go +++ b/internal/streamingcoord/server/broadcaster/broadcast_task.go @@ -12,65 +12,125 @@ import ( "github.com/milvus-io/milvus/pkg/v2/log" "github.com/milvus-io/milvus/pkg/v2/proto/streamingpb" "github.com/milvus-io/milvus/pkg/v2/streaming/util/message" + "github.com/milvus-io/milvus/pkg/v2/streaming/util/types" ) // newBroadcastTaskFromProto creates a new broadcast task from the proto. -func newBroadcastTaskFromProto(proto *streamingpb.BroadcastTask, metrics *broadcasterMetrics) *broadcastTask { +func newBroadcastTaskFromProto(proto *streamingpb.BroadcastTask, metrics *broadcasterMetrics, ackCallbackScheduler *ackCallbackScheduler) *broadcastTask { m := metrics.NewBroadcastTask(proto.GetState()) msg := message.NewBroadcastMutableMessageBeforeAppend(proto.Message.Payload, proto.Message.Properties) - bh := msg.BroadcastHeader() bt := &broadcastTask{ - mu: sync.Mutex{}, - header: bh, - task: proto, - recoverPersisted: true, // the task is recovered from the recovery info, so it's persisted. - metrics: m, - allAcked: make(chan struct{}), + mu: sync.Mutex{}, + msg: msg, + task: proto, + dirty: true, // the task is recovered from the recovery info, so it's persisted. + metrics: m, + ackCallbackScheduler: ackCallbackScheduler, + allAcked: make(chan struct{}), } - if isAllDone(proto) { + if proto.State == streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_TOMBSTONE { close(bt.allAcked) } return bt } // newBroadcastTaskFromBroadcastMessage creates a new broadcast task from the broadcast message. -func newBroadcastTaskFromBroadcastMessage(msg message.BroadcastMutableMessage, metrics *broadcasterMetrics) *broadcastTask { +func newBroadcastTaskFromBroadcastMessage(msg message.BroadcastMutableMessage, metrics *broadcasterMetrics, ackCallbackScheduler *ackCallbackScheduler) *broadcastTask { m := metrics.NewBroadcastTask(streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_PENDING) header := msg.BroadcastHeader() bt := &broadcastTask{ Binder: log.Binder{}, mu: sync.Mutex{}, - header: header, + msg: msg, task: &streamingpb.BroadcastTask{ Message: msg.IntoMessageProto(), State: streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_PENDING, AckedVchannelBitmap: make([]byte, len(header.VChannels)), + AckedCheckpoints: make([]*streamingpb.AckedCheckpoint, len(header.VChannels)), }, - recoverPersisted: false, - metrics: m, - allAcked: make(chan struct{}), + dirty: false, + metrics: m, + ackCallbackScheduler: ackCallbackScheduler, + allAcked: make(chan struct{}), } - if isAllDone(bt.task) { + if bt.task.State == streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_TOMBSTONE { close(bt.allAcked) } return bt } +// newBroadcastTaskFromImmutableMessage creates a new broadcast task from the immutable message. +func newBroadcastTaskFromImmutableMessage(msg message.ImmutableMessage, metrics *broadcasterMetrics, ackCallbackScheduler *ackCallbackScheduler) *broadcastTask { + broadcastMsg := msg.IntoBroadcastMutableMessage() + task := newBroadcastTaskFromBroadcastMessage(broadcastMsg, metrics, ackCallbackScheduler) + // if the task is created from the immutable message, it already has been broadcasted, so transfer its state into recovered. 
+ task.task.State = streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_REPLICATED + task.metrics.ToState(streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_REPLICATED) + return task +} + // broadcastTask is the state of the broadcast task. type broadcastTask struct { log.Binder - mu sync.Mutex - header *message.BroadcastHeader - task *streamingpb.BroadcastTask - recoverPersisted bool // a flag to indicate that the task has been persisted into the recovery info and can be recovered. - metrics *taskMetricsGuard - allAcked chan struct{} + mu sync.Mutex + msg message.BroadcastMutableMessage + task *streamingpb.BroadcastTask + dirty bool // a flag to indicate that the task has been modified and needs to be saved into the recovery info. + metrics *taskMetricsGuard + allAcked chan struct{} + guards *lockGuards + ackCallbackScheduler *ackCallbackScheduler +} + +// SetLogger sets the logger of the broadcast task. +func (b *broadcastTask) SetLogger(logger *log.MLogger) { + b.Binder.SetLogger(logger.With(log.FieldMessage(b.msg))) +} + +// WithResourceKeyLockGuards sets the lock guards for the broadcast task. +func (b *broadcastTask) WithResourceKeyLockGuards(guards *lockGuards) { + b.mu.Lock() + defer b.mu.Unlock() + if b.guards != nil { + panic("broadcast task already has lock guards") + } + b.guards = guards +} + +// BroadcastResult returns the broadcast result of the broadcast task. +func (b *broadcastTask) BroadcastResult() (message.BroadcastMutableMessage, map[string]*types.AppendResult) { + b.mu.Lock() + defer b.mu.Unlock() + + vchannels := b.msg.BroadcastHeader().VChannels + result := make(map[string]*types.AppendResult, len(vchannels)) + for idx, vchannel := range vchannels { + if b.task.AckedCheckpoints == nil { + // forward compatible with the old version. + result[vchannel] = &types.AppendResult{ + MessageID: nil, + LastConfirmedMessageID: nil, + TimeTick: 0, + } + continue + } + cp := b.task.AckedCheckpoints[idx] + if cp == nil || cp.TimeTick == 0 { + panic("unreachable: BroadcastResult is called before the broadcast task is acked") + } + result[vchannel] = &types.AppendResult{ + MessageID: message.MustUnmarshalMessageID(cp.MessageId), + LastConfirmedMessageID: message.MustUnmarshalMessageID(cp.LastConfirmedMessageId), + TimeTick: cp.TimeTick, + } + } + return b.msg, result } // Header returns the header of the broadcast task. func (b *broadcastTask) Header() *message.BroadcastHeader { // header is a immutable field, no need to lock. - return b.header + return b.msg.BroadcastHeader() } // State returns the State of the broadcast task. @@ -92,7 +152,7 @@ func (b *broadcastTask) PendingBroadcastMessages() []message.MutableMessage { // filter out the vchannel that has been acked. pendingMessages := make([]message.MutableMessage, 0, len(msgs)) for i, msg := range msgs { - if b.task.AckedVchannelBitmap[i] != 0 { + if b.task.AckedVchannelBitmap[i] != 0 || (b.task.AckedCheckpoints != nil && b.task.AckedCheckpoints[i] != nil) { continue } pendingMessages = append(pendingMessages, msg) @@ -105,84 +165,113 @@ func (b *broadcastTask) InitializeRecovery(ctx context.Context) error { b.mu.Lock() defer b.mu.Unlock() - if b.recoverPersisted { - return nil - } - if err := b.saveTask(ctx, b.task, b.Logger()); err != nil { + if err := b.saveTaskIfDirty(ctx, b.Logger()); err != nil { return err } - b.recoverPersisted = true return nil } // GetImmutableMessageFromVChannel gets the immutable message from the vchannel. -// If the vchannel is already acked, it returns nil. 
func (b *broadcastTask) GetImmutableMessageFromVChannel(vchannel string) message.ImmutableMessage { b.mu.Lock() defer b.mu.Unlock() + return b.getImmutableMessageFromVChannel(vchannel, nil) +} + +func (b *broadcastTask) getImmutableMessageFromVChannel(vchannel string, result *types.AppendResult) message.ImmutableMessage { msg := message.NewBroadcastMutableMessageBeforeAppend(b.task.Message.Payload, b.task.Message.Properties) msgs := msg.SplitIntoMutableMessage() for _, msg := range msgs { if msg.VChannel() == vchannel { - // The legacy message don't have timetick, so we need to set it to 0. - return msg.WithTimeTick(0).IntoImmutableMessage(nil) + timetick := uint64(0) + var messageID message.MessageID + var lastConfirmedMessageID message.MessageID + if result != nil { + messageID = result.MessageID + timetick = result.TimeTick + lastConfirmedMessageID = result.LastConfirmedMessageID + } + // The legacy message don't have last confirmed message id/timetick/message id, + // so we just mock a unsafely message here. + if lastConfirmedMessageID == nil { + return msg.WithTimeTick(timetick).WithLastConfirmedUseMessageID().IntoImmutableMessage(messageID) + } + return msg.WithTimeTick(timetick).WithLastConfirmed(lastConfirmedMessageID).IntoImmutableMessage(messageID) } } return nil } // Ack acknowledges the message at the specified vchannel. -func (b *broadcastTask) Ack(ctx context.Context, msg message.ImmutableMessage) error { +// return true if all the vchannels are acked at first time, false if not. +func (b *broadcastTask) Ack(ctx context.Context, msgs ...message.ImmutableMessage) (err error) { b.mu.Lock() defer b.mu.Unlock() - task, ok := b.copyAndSetVChannelAcked(msg.VChannel()) - if !ok { + + return b.ack(ctx, msgs...) +} + +// ack acknowledges the message at the specified vchannel. +func (b *broadcastTask) ack(ctx context.Context, msgs ...message.ImmutableMessage) (err error) { + b.copyAndSetAckedCheckpoints(msgs...) + if !b.dirty { return nil } - - // We should always save the task after acked. - // Even if the task mark as done in memory. - // Because the task is set as done in memory before save the recovery info. - if err := b.saveTask(ctx, task, b.Logger().With(zap.String("ackVChannel", msg.VChannel()))); err != nil { + if err := b.saveTaskIfDirty(ctx, b.Logger()); err != nil { return err } - b.task = task - if isAllDone(task) { + if isAllDone(b.task) { + b.ackCallbackScheduler.AddTask(b) b.metrics.ObserveAckAll() - close(b.allAcked) } return nil } // BlockUntilAllAck blocks until all the vchannels are acked. -func (b *broadcastTask) BlockUntilAllAck(ctx context.Context) error { +func (b *broadcastTask) BlockUntilAllAck(ctx context.Context) (*types.BroadcastAppendResult, error) { select { case <-ctx.Done(): - return ctx.Err() + return nil, ctx.Err() case <-b.allAcked: - return nil + _, result := b.BroadcastResult() + return &types.BroadcastAppendResult{ + BroadcastID: b.Header().BroadcastID, + AppendResults: result, + }, nil } } -// copyAndSetVChannelAcked copies the task and set the vchannel as acked. -// if the vchannel is already acked, it returns nil and false. -func (b *broadcastTask) copyAndSetVChannelAcked(vchannel string) (*streamingpb.BroadcastTask, bool) { +// copyAndSetAckedCheckpoints copies the task and set the acked checkpoints. 
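getImmutableMessageFromVChannel rebuilds a per-vchannel immutable message from the stored broadcast payload, stamping it with the message id, time tick and last-confirmed id from the append result when one is available. For a concrete picture of what such a message looks like, the sketch below combines the builder chains that appear in the updated tests of this patch; the drop-partition message type, the ids and the time tick are placeholder values, not what the task itself stores:

package example

import (
	"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
	"github.com/milvus-io/milvus/pkg/v2/streaming/walimpls/impls/walimplstest"
)

// buildAckMessage builds a broadcast message, splits out the slice for one
// vchannel and turns it into an immutable message with a time tick and a
// last-confirmed id, the way the fast-ack path reconstructs acked messages.
func buildAckMessage(broadcastID uint64, vchannel string) message.ImmutableMessage {
	return message.NewDropPartitionMessageBuilderV1().
		WithHeader(&message.DropPartitionMessageHeader{}).
		WithBody(&message.DropPartitionRequest{}).
		WithBroadcast([]string{vchannel}).
		MustBuildBroadcast().
		WithBroadcastID(broadcastID).
		SplitIntoMutableMessage()[0].
		WithTimeTick(100).
		WithLastConfirmed(walimplstest.NewTestMessageID(1)).
		IntoImmutableMessage(walimplstest.NewTestMessageID(1))
}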
+func (b *broadcastTask) copyAndSetAckedCheckpoints(msgs ...message.ImmutableMessage) { task := proto.Clone(b.task).(*streamingpb.BroadcastTask) - idx, err := findIdxOfVChannel(vchannel, b.Header().VChannels) - if err != nil { - panic(err) + for _, msg := range msgs { + vchannel := msg.VChannel() + idx, err := findIdxOfVChannel(vchannel, b.Header().VChannels) + if err != nil { + panic(err) + } + if len(task.AckedVchannelBitmap) == 0 { + task.AckedVchannelBitmap = make([]byte, len(b.Header().VChannels)) + } + if len(task.AckedCheckpoints) == 0 { + task.AckedCheckpoints = make([]*streamingpb.AckedCheckpoint, len(b.Header().VChannels)) + } + if cp := task.AckedCheckpoints[idx]; cp != nil && cp.TimeTick != 0 { + // after proto.Clone, the cp is always not nil, so we also need to check the time tick. + continue + } + // the ack result is dirty, so we need to set the dirty flag to true. + b.dirty = true + task.AckedVchannelBitmap[idx] = 1 + task.AckedCheckpoints[idx] = &streamingpb.AckedCheckpoint{ + MessageId: msg.MessageID().IntoProto(), + LastConfirmedMessageId: msg.LastConfirmedMessageID().IntoProto(), + TimeTick: msg.TimeTick(), + } } - if task.AckedVchannelBitmap[idx] != 0 { - return nil, false - } - task.AckedVchannelBitmap[idx] = 1 - if isAllDone(task) { - // All vchannels are acked, mark the task as done, even if there are still pending messages on working. - // The pending messages is repeated sent operation, can be ignored. - task.State = streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_DONE - } - return task, true + // update current task state. + b.task = task } // findIdxOfVChannel finds the index of the vchannel in the broadcast task. @@ -195,33 +284,34 @@ func findIdxOfVChannel(vchannel string, vchannels []string) (int, error) { return -1, errors.Errorf("unreachable: vchannel is %s not found in the broadcast task", vchannel) } -// BroadcastDone marks the broadcast operation is done. -func (b *broadcastTask) BroadcastDone(ctx context.Context) error { +// FastAck trigger a fast ack operation when the broadcast operation is done. +func (b *broadcastTask) FastAck(ctx context.Context, broadcastResult map[string]*types.AppendResult) error { + // Broadcast operation is done. + b.metrics.ObserveBroadcastDone() + b.mu.Lock() defer b.mu.Unlock() - task := b.copyAndMarkBroadcastDone() - if err := b.saveTask(ctx, task, b.Logger()); err != nil { - return err + // because we need to wait for the streamingnode to ack the message, + // however, if the message is already write into wal, the message is determined, + // so we can make a fast ack operation here to speed up the ack operation. + msgs := make([]message.ImmutableMessage, 0, len(broadcastResult)) + for vchannel := range broadcastResult { + msgs = append(msgs, b.getImmutableMessageFromVChannel(vchannel, broadcastResult[vchannel])) } - b.task = task - b.metrics.ObserveBroadcastDone() - return nil + return b.ack(ctx, msgs...) } -// copyAndMarkBroadcastDone copies the task and mark the broadcast task as done. -// !!! The ack state of the task should not be removed, because the task is a lock-hint of resource key held by a broadcast operation. -// It can be removed only after the broadcast message is acked by all the vchannels. -func (b *broadcastTask) copyAndMarkBroadcastDone() *streamingpb.BroadcastTask { - task := proto.Clone(b.task).(*streamingpb.BroadcastTask) - if isAllDone(task) { - // If all vchannels are acked, mark the task as done. 
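copyAndSetAckedCheckpoints only flips the task's dirty flag when an ack actually changes state, so repeated acks of the same vchannel never force another catalog write; saveTaskIfDirty later persists the task once per batch of changes. A minimal sketch of that save-only-if-dirty idiom, with a placeholder save function standing in for the StreamingCatalog call:

package example

import "context"

// dirtyState sketches the dirty-flag persistence idiom: mutations only mark
// the in-memory state as dirty, and a single saveIfDirty step persists it.
type dirtyState struct {
	dirty bool
	save  func(ctx context.Context) error // placeholder for a SaveBroadcastTask call
}

func (s *dirtyState) mutate() {
	// ... change some in-memory state ...
	s.dirty = true
}

func (s *dirtyState) saveIfDirty(ctx context.Context) error {
	if !s.dirty {
		return nil
	}
	if err := s.save(ctx); err != nil {
		return err // note: the real task clears the flag before saving
	}
	s.dirty = false
	return nil
}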
- task.State = streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_DONE - } else { - // There's no more pending message, mark the task as wait ack. - task.State = streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_WAIT_ACK - } - return task +// DropTombstone drops the tombstone of the broadcast task. +// It will remove the tombstone of the broadcast task in recovery storage. +// After the tombstone is dropped, the idempotency and deduplication can not be guaranteed. +func (b *broadcastTask) DropTombstone(ctx context.Context) error { + b.mu.Lock() + defer b.mu.Unlock() + + b.task.State = streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_DONE + b.dirty = true + return b.saveTaskIfDirty(ctx, b.Logger()) } // isAllDone check if all the vchannels are acked. @@ -243,14 +333,44 @@ func ackedCount(task *streamingpb.BroadcastTask) int { return count } -// saveTask saves the broadcast task recovery info. -func (b *broadcastTask) saveTask(ctx context.Context, task *streamingpb.BroadcastTask, logger *log.MLogger) error { - logger = logger.With(zap.String("state", task.State.String()), zap.Int("ackedVChannelCount", ackedCount(task))) - if err := resource.Resource().StreamingCatalog().SaveBroadcastTask(ctx, b.header.BroadcastID, task); err != nil { - logger.Warn("save broadcast task failed", zap.Error(err)) +// MarkAckCallbackDone marks the ack callback is done. +func (b *broadcastTask) MarkAckCallbackDone(ctx context.Context) error { + b.mu.Lock() + defer b.mu.Unlock() + if b.task.State != streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_TOMBSTONE { + b.task.State = streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_TOMBSTONE + close(b.allAcked) + b.dirty = true + } + + if err := b.saveTaskIfDirty(ctx, b.Logger()); err != nil { return err } - logger.Info("save broadcast task done") - b.metrics.ToState(task.State) + + if b.guards != nil { + // release the resource key lock if done. + // if the broadcast task is recovered from the remote cluster by replication, + // it doesn't hold the resource key lock, so skip it. + b.guards.Unlock() + } + return nil +} + +// saveTaskIfDirty saves the broadcast task recovery info if the task is dirty. +func (b *broadcastTask) saveTaskIfDirty(ctx context.Context, logger *log.MLogger) error { + if !b.dirty { + return nil + } + b.dirty = false + logger = logger.With(zap.String("state", b.task.State.String()), zap.Int("ackedVChannelCount", ackedCount(b.task))) + if err := resource.Resource().StreamingCatalog().SaveBroadcastTask(ctx, b.Header().BroadcastID, b.task); err != nil { + logger.Warn("save broadcast task failed", zap.Error(err)) + if ctx.Err() != nil { + panic("critical error: the save broadcast task is failed before the context is done") + } + return err + } + b.metrics.ToState(b.task.State) + logger.Info("save broadcast task done") return nil } diff --git a/internal/streamingcoord/server/broadcaster/broadcaster.go b/internal/streamingcoord/server/broadcaster/broadcaster.go index 85385cf8a5..92e2440269 100644 --- a/internal/streamingcoord/server/broadcaster/broadcaster.go +++ b/internal/streamingcoord/server/broadcaster/broadcaster.go @@ -8,8 +8,10 @@ import ( ) type Broadcaster interface { - // Broadcast broadcasts the message to all channels. - Broadcast(ctx context.Context, msg message.BroadcastMutableMessage) (*types.BroadcastAppendResult, error) + // WithResourceKeys sets the resource keys of the broadcast operation. + // It will acquire locks of the resource keys and return the broadcast api. 
+ // Once the broadcast api is returned, the Close() method of the broadcast api should be called to release the resource safely. + WithResourceKeys(ctx context.Context, resourceKeys ...message.ResourceKey) (BroadcastAPI, error) // LegacyAck is the legacy ack interface for the 2.6.0 import message. LegacyAck(ctx context.Context, broadcastID uint64, vchannel string) error @@ -21,6 +23,14 @@ type Broadcaster interface { Close() } +type BroadcastAPI interface { + // Broadcast broadcasts the message to all channels. + Broadcast(ctx context.Context, msg message.BroadcastMutableMessage) (*types.BroadcastAppendResult, error) + + // Close releases the resource keys that broadcast api holds. + Close() +} + // AppendOperator is used to append messages, there's only two implement of this interface: // 1. streaming.WAL() // 2. old msgstream interface [deprecated] diff --git a/internal/streamingcoord/server/broadcaster/broadcaster_test.go b/internal/streamingcoord/server/broadcaster/broadcaster_test.go index 448f56e872..88f5231092 100644 --- a/internal/streamingcoord/server/broadcaster/broadcaster_test.go +++ b/internal/streamingcoord/server/broadcaster/broadcaster_test.go @@ -15,6 +15,9 @@ import ( "github.com/milvus-io/milvus/internal/distributed/streaming" "github.com/milvus-io/milvus/internal/mocks/distributed/mock_streaming" "github.com/milvus-io/milvus/internal/mocks/mock_metastore" + "github.com/milvus-io/milvus/internal/mocks/streamingcoord/server/mock_balancer" + "github.com/milvus-io/milvus/internal/streamingcoord/server/balancer" + "github.com/milvus-io/milvus/internal/streamingcoord/server/balancer/balance" "github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster/registry" "github.com/milvus-io/milvus/internal/streamingcoord/server/resource" internaltypes "github.com/milvus-io/milvus/internal/types" @@ -26,6 +29,7 @@ import ( "github.com/milvus-io/milvus/pkg/v2/streaming/util/types" "github.com/milvus-io/milvus/pkg/v2/streaming/walimpls/impls/walimplstest" "github.com/milvus-io/milvus/pkg/v2/util/paramtable" + "github.com/milvus-io/milvus/pkg/v2/util/replicateutil" "github.com/milvus-io/milvus/pkg/v2/util/syncutil" "github.com/milvus-io/milvus/pkg/v2/util/typeutil" ) @@ -33,6 +37,20 @@ import ( func TestBroadcaster(t *testing.T) { registry.ResetRegistration() paramtable.Init() + paramtable.Get().StreamingCfg.WALBroadcasterTombstoneCheckInternal.SwapTempValue("10ms") + paramtable.Get().StreamingCfg.WALBroadcasterTombstoneMaxCount.SwapTempValue("2") + paramtable.Get().StreamingCfg.WALBroadcasterTombstoneMaxLifetime.SwapTempValue("20ms") + + mb := mock_balancer.NewMockBalancer(t) + mb.EXPECT().ReplicateRole().Return(replicateutil.RolePrimary) + mb.EXPECT().WatchChannelAssignments(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, cb balancer.WatchChannelAssignmentsCallback) error { + time.Sleep(100 * time.Second) + return nil + }) + balance.Register(mb) + registry.RegisterDropCollectionV1AckCallback(func(ctx context.Context, msg message.BroadcastResultDropCollectionMessageV1) error { + return nil + }) meta := mock_metastore.NewMockStreamingCoordCataLog(t) meta.EXPECT().ListBroadcastTask(mock.Anything). 
@@ -57,17 +75,16 @@ func TestBroadcaster(t *testing.T) { createNewBroadcastMsg([]string{"v1", "v2", "v3"}, message.NewCollectionNameResourceKey("c3"), message.NewCollectionNameResourceKey("c4")).WithBroadcastID(7), - streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_WAIT_ACK, + streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_REPLICATED, []byte{0x00, 0x00, 0x00}), }, nil }).Times(1) done := typeutil.NewConcurrentSet[uint64]() meta.EXPECT().SaveBroadcastTask(mock.Anything, mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, broadcastID uint64, bt *streamingpb.BroadcastTask) error { - // may failure - if rand.Int31n(10) < 3 { - return errors.New("save task failed") + if ctx.Err() != nil { + return ctx.Err() } - if bt.State == streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_DONE { + if bt.State == streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_TOMBSTONE { done.Insert(broadcastID) } return nil @@ -84,7 +101,7 @@ func TestBroadcaster(t *testing.T) { assert.NoError(t, err) assert.NotNil(t, bc) assert.Eventually(t, func() bool { - return appended.Load() == 9 && len(done.Collect()) == 6 // only one task is done, + return appended.Load() == 9 && len(done.Collect()) == 6 }, 30*time.Second, 10*time.Millisecond) // only task 7 is not done. @@ -103,14 +120,12 @@ func TestBroadcaster(t *testing.T) { // Test broadcast here. broadcastWithSameRK := func() { var result *types.BroadcastAppendResult - for { - var err error - result, err = bc.Broadcast(context.Background(), createNewBroadcastMsg([]string{"v1", "v2", "v3"}, message.NewCollectionNameResourceKey("c7"))) - if err == nil { - break - } - } + var err error + b, err := bc.WithResourceKeys(context.Background(), message.NewCollectionNameResourceKey("c7")) + assert.NoError(t, err) + result, err = b.Broadcast(context.Background(), createNewBroadcastMsg([]string{"v1", "v2", "v3"}, message.NewCollectionNameResourceKey("c7"))) assert.Equal(t, len(result.AppendResults), 3) + assert.NoError(t, err) } go broadcastWithSameRK() go broadcastWithSameRK() @@ -119,8 +134,19 @@ func TestBroadcaster(t *testing.T) { return appended.Load() == 15 && len(done.Collect()) == 9 }, 30*time.Second, 10*time.Millisecond) + // Test close befor broadcast + broadcastAPI, err := bc.WithResourceKeys(context.Background(), message.NewExclusiveClusterResourceKey()) + assert.NoError(t, err) + broadcastAPI.Close() + + broadcastAPI, err = bc.WithResourceKeys(context.Background(), message.NewExclusiveClusterResourceKey()) + assert.NoError(t, err) + broadcastAPI.Close() + bc.Close() - _, err = bc.Broadcast(context.Background(), nil) + broadcastAPI, err = bc.WithResourceKeys(context.Background()) + assert.NoError(t, err) + _, err = broadcastAPI.Broadcast(context.Background(), nil) assert.Error(t, err) err = bc.Ack(context.Background(), mock_message.NewMockImmutableMessage(t)) assert.Error(t, err) @@ -128,13 +154,17 @@ func TestBroadcaster(t *testing.T) { func ack(t *testing.T, broadcaster Broadcaster, broadcastID uint64, vchannel string) { for { - msg := mock_message.NewMockImmutableMessage(t) - msg.EXPECT().VChannel().Return(vchannel) - msg.EXPECT().MessageTypeWithVersion().Return(message.MessageTypeTimeTickV1) - msg.EXPECT().BroadcastHeader().Return(&message.BroadcastHeader{ - BroadcastID: broadcastID, - }) - msg.EXPECT().MarshalLogObject(mock.Anything).Return(nil).Maybe() + msg := message.NewDropCollectionMessageBuilderV1(). + WithHeader(&message.DropCollectionMessageHeader{}). + WithBody(&msgpb.DropCollectionRequest{}). + WithBroadcast([]string{vchannel}). 
+ MustBuildBroadcast(). + WithBroadcastID(broadcastID). + SplitIntoMutableMessage()[0]. + WithTimeTick(100). + WithLastConfirmed(walimplstest.NewTestMessageID(1)). + IntoImmutableMessage(walimplstest.NewTestMessageID(1)) + if err := broadcaster.Ack(context.Background(), msg); err == nil { break } @@ -215,6 +245,18 @@ func createNewWaitAckBroadcastTaskFromMessage( bitmap []byte, ) *streamingpb.BroadcastTask { pb := msg.IntoMessageProto() + acks := make([]*streamingpb.AckedCheckpoint, len(bitmap)) + for i := 0; i < len(bitmap); i++ { + if bitmap[i] != 0 { + messageID := walimplstest.NewTestMessageID(int64(i)) + lastConfirmedMessageID := walimplstest.NewTestMessageID(int64(i)) + acks[i] = &streamingpb.AckedCheckpoint{ + MessageId: messageID.IntoProto(), + LastConfirmedMessageId: lastConfirmedMessageID.IntoProto(), + TimeTick: 1, + } + } + } return &streamingpb.BroadcastTask{ Message: &messagespb.Message{ Payload: pb.Payload, @@ -222,5 +264,6 @@ func createNewWaitAckBroadcastTaskFromMessage( }, State: state, AckedVchannelBitmap: bitmap, + AckedCheckpoints: acks, } } diff --git a/internal/streamingcoord/server/broadcaster/broadcaster_with_rk.go b/internal/streamingcoord/server/broadcaster/broadcaster_with_rk.go new file mode 100644 index 0000000000..934bcfb5b1 --- /dev/null +++ b/internal/streamingcoord/server/broadcaster/broadcaster_with_rk.go @@ -0,0 +1,27 @@ +package broadcaster + +import ( + "context" + + "github.com/milvus-io/milvus/pkg/v2/streaming/util/message" + "github.com/milvus-io/milvus/pkg/v2/streaming/util/types" +) + +type broadcasterWithRK struct { + broadcaster *broadcastTaskManager + broadcastID uint64 + guards *lockGuards +} + +func (b *broadcasterWithRK) Broadcast(ctx context.Context, msg message.BroadcastMutableMessage) (*types.BroadcastAppendResult, error) { + // consume the guards after the broadcast is called to avoid double unlock. + guards := b.guards + b.guards = nil + return b.broadcaster.broadcast(ctx, msg, b.broadcastID, guards) +} + +func (b *broadcasterWithRK) Close() { + if b.guards != nil { + b.guards.Unlock() + } +} diff --git a/internal/streamingcoord/server/broadcaster/task.go b/internal/streamingcoord/server/broadcaster/pending_broadcast_task.go similarity index 82% rename from internal/streamingcoord/server/broadcaster/task.go rename to internal/streamingcoord/server/broadcaster/pending_broadcast_task.go index b831247f3a..356e2207a1 100644 --- a/internal/streamingcoord/server/broadcaster/task.go +++ b/internal/streamingcoord/server/broadcaster/pending_broadcast_task.go @@ -10,22 +10,21 @@ import ( "github.com/milvus-io/milvus/internal/distributed/streaming" "github.com/milvus-io/milvus/pkg/v2/streaming/util/message" "github.com/milvus-io/milvus/pkg/v2/streaming/util/types" - "github.com/milvus-io/milvus/pkg/v2/util/syncutil" "github.com/milvus-io/milvus/pkg/v2/util/typeutil" ) var errBroadcastTaskIsNotDone = errors.New("broadcast task is not done") // newPendingBroadcastTask creates a new pendingBroadcastTask. 
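broadcasterWithRK is what callers get back from WithResourceKeys: it carries the acquired lock guards, hands them to the task manager on Broadcast, and releases them on Close if Broadcast was never reached. A sketch of the intended call pattern, assuming the Broadcaster and a prepared broadcast message are supplied by the caller (the collection-name resource key is only an example):

package example

import (
	"context"

	"github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster"
	"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
	"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
)

// broadcastWithLock acquires the resource-key locks first, then broadcasts.
func broadcastWithLock(ctx context.Context, bc broadcaster.Broadcaster, msg message.BroadcastMutableMessage) (*types.BroadcastAppendResult, error) {
	api, err := bc.WithResourceKeys(ctx, message.NewCollectionNameResourceKey("c7"))
	if err != nil {
		return nil, err
	}
	// Close releases the locks if we bail out before broadcasting; after a
	// successful Broadcast the guards are handed to the task and Close is a no-op.
	defer api.Close()

	return api.Broadcast(ctx, msg)
}

Because Broadcast consumes the guards, the deferred Close only matters on the early-return path.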
-func newPendingBroadcastTask( - task *broadcastTask, -) *pendingBroadcastTask { +func newPendingBroadcastTask(task *broadcastTask) *pendingBroadcastTask { msgs := task.PendingBroadcastMessages() + if len(msgs) == 0 { + return nil + } return &pendingBroadcastTask{ broadcastTask: task, pendingMessages: msgs, appendResult: make(map[string]*types.AppendResult, len(msgs)), - future: syncutil.NewFuture[*types.BroadcastAppendResult](), BackoffWithInstant: typeutil.NewBackoffWithInstant(typeutil.BackoffTimerConfig{ Default: 10 * time.Second, Backoff: typeutil.BackoffConfig{ @@ -42,8 +41,6 @@ type pendingBroadcastTask struct { *broadcastTask pendingMessages []message.MutableMessage appendResult map[string]*types.AppendResult - future *syncutil.Future[*types.BroadcastAppendResult] - metrics *taskMetricsGuard *typeutil.BackoffWithInstant } @@ -53,7 +50,6 @@ type pendingBroadcastTask struct { func (b *pendingBroadcastTask) Execute(ctx context.Context) error { if err := b.broadcastTask.InitializeRecovery(ctx); err != nil { b.Logger().Warn("broadcast task initialize recovery failed", zap.Error(err)) - b.UpdateInstantWithNextBackOff() return err } @@ -70,17 +66,12 @@ func (b *pendingBroadcastTask) Execute(ctx context.Context) error { b.appendResult[b.pendingMessages[idx].VChannel()] = resp.AppendResult } b.pendingMessages = newPendings - if len(newPendings) == 0 { - b.future.Set(&types.BroadcastAppendResult{ - BroadcastID: b.header.BroadcastID, - AppendResults: b.appendResult, - }) - } b.Logger().Info("broadcast task make a new broadcast done", zap.Int("backoffRetryMessages", len(b.pendingMessages))) } if len(b.pendingMessages) == 0 { - if err := b.broadcastTask.BroadcastDone(ctx); err != nil { - b.UpdateInstantWithNextBackOff() + // trigger a fast ack operation when the broadcast operation is done. + if err := b.broadcastTask.FastAck(ctx, b.appendResult); err != nil { + b.Logger().Warn("broadcast task save task failed", zap.Error(err)) return err } return nil @@ -89,11 +80,6 @@ func (b *pendingBroadcastTask) Execute(ctx context.Context) error { return errBroadcastTaskIsNotDone } -// BlockUntilTaskDone blocks until the task is done. -func (b *pendingBroadcastTask) BlockUntilTaskDone(ctx context.Context) (*types.BroadcastAppendResult, error) { - return b.future.GetWithContext(ctx) -} - // pendingBroadcastTaskArray is a heap of pendingBroadcastTask. type pendingBroadcastTaskArray []*pendingBroadcastTask diff --git a/internal/streamingcoord/server/broadcaster/registry/ack_message_callback.go b/internal/streamingcoord/server/broadcaster/registry/ack_message_callback.go index 449ebae025..a6ffe834c5 100644 --- a/internal/streamingcoord/server/broadcaster/registry/ack_message_callback.go +++ b/internal/streamingcoord/server/broadcaster/registry/ack_message_callback.go @@ -13,8 +13,8 @@ import ( // MessageAckCallback is the callback function for the message type. type ( - MessageAckCallback[H proto.Message, B proto.Message] = func(ctx context.Context, params message.SpecializedImmutableMessage[H, B]) error - messageInnerAckCallback = func(ctx context.Context, msgs message.ImmutableMessage) error + MessageAckCallback[H proto.Message, B proto.Message] = func(ctx context.Context, result message.BroadcastResult[H, B]) error + messageInnerAckCallback = func(ctx context.Context, msg message.BroadcastMutableMessage, result map[string]*message.AppendResult) error ) // messageAckCallbacks is the map of message type to the callback function. 
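With the signature change above, ack callbacks no longer receive a single immutable message; they receive the original broadcast message together with the per-vchannel append results. A sketch of how a coordinator component registers one under the new shape, reusing the drop-collection registration that also appears in the updated broadcaster test (the empty callback body is a placeholder):

package example

import (
	"context"

	"github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster/registry"
	"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
)

// registerCallback registers an ack callback under the new signature: the
// callback gets the broadcast message plus the acked checkpoints per vchannel.
func registerCallback() {
	registry.RegisterDropCollectionV1AckCallback(func(ctx context.Context, result message.BroadcastResultDropCollectionMessageV1) error {
		// result carries the specialized broadcast message and the per-vchannel
		// append results (message id, last confirmed id, time tick).
		return nil
	})
}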
@@ -31,15 +31,18 @@ func registerMessageAckCallback[H proto.Message, B proto.Message](callback Messa // only for test, the register callback should be called once and only once return } - future.Set(func(ctx context.Context, msgs message.ImmutableMessage) error { - specializedMsg := message.MustAsSpecializedImmutableMessage[H, B](msgs) - return callback(ctx, specializedMsg) + future.Set(func(ctx context.Context, msgs message.BroadcastMutableMessage, result map[string]*message.AppendResult) error { + return callback(ctx, message.BroadcastResult[H, B]{ + Message: message.MustAsSpecializedBroadcastMessage[H, B](msgs), + Results: result, + }) }) } // CallMessageAckCallback calls the callback function for the message type. -func CallMessageAckCallback(ctx context.Context, msg message.ImmutableMessage) error { - callbackFuture, ok := messageAckCallbacks[msg.MessageTypeWithVersion()] +func CallMessageAckCallback(ctx context.Context, msg message.BroadcastMutableMessage, result map[string]*message.AppendResult) error { + version := msg.MessageTypeWithVersion() + callbackFuture, ok := messageAckCallbacks[version] if !ok { // No callback need tobe called, return nil return nil @@ -48,5 +51,5 @@ func CallMessageAckCallback(ctx context.Context, msg message.ImmutableMessage) e if err != nil { return errors.Wrap(err, "when waiting callback registered") } - return callback(ctx, msg) + return callback(ctx, msg, result) } diff --git a/internal/streamingcoord/server/broadcaster/registry/ack_message_callback_test.go b/internal/streamingcoord/server/broadcaster/registry/ack_message_callback_test.go index 892d33b375..41c0313cf6 100644 --- a/internal/streamingcoord/server/broadcaster/registry/ack_message_callback_test.go +++ b/internal/streamingcoord/server/broadcaster/registry/ack_message_callback_test.go @@ -9,7 +9,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/milvus-io/milvus/pkg/v2/streaming/util/message" - "github.com/milvus-io/milvus/pkg/v2/streaming/walimpls/impls/rmq" + "github.com/milvus-io/milvus/pkg/v2/streaming/walimpls/impls/walimplstest" ) func TestMessageCallbackRegistration(t *testing.T) { @@ -18,7 +18,7 @@ func TestMessageCallbackRegistration(t *testing.T) { // Test registering a callback called := false - callback := func(ctx context.Context, msg message.ImmutableDropPartitionMessageV1) error { + callback := func(ctx context.Context, msg message.BroadcastResultDropPartitionMessageV1) error { called = true return nil } @@ -34,13 +34,17 @@ func TestMessageCallbackRegistration(t *testing.T) { msg := message.NewDropPartitionMessageBuilderV1(). WithHeader(&message.DropPartitionMessageHeader{}). WithBody(&message.DropPartitionRequest{}). - WithVChannel("v1"). - MustBuildMutable(). - WithTimeTick(1). - IntoImmutableMessage(rmq.NewRmqID(1)) + WithBroadcast([]string{"v1"}). 
+ MustBuildBroadcast() // Call the callback - err := CallMessageAckCallback(context.Background(), msg) + err := CallMessageAckCallback(context.Background(), msg, map[string]*message.AppendResult{ + "v1": { + MessageID: walimplstest.NewTestMessageID(1), + LastConfirmedMessageID: walimplstest.NewTestMessageID(1), + TimeTick: 1, + }, + }) assert.NoError(t, err) assert.True(t, called) @@ -48,7 +52,7 @@ func TestMessageCallbackRegistration(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond) defer cancel() - err = CallMessageAckCallback(ctx, msg) + err = CallMessageAckCallback(ctx, msg, nil) assert.Error(t, err) assert.True(t, errors.Is(err, context.DeadlineExceeded)) } diff --git a/internal/streamingcoord/server/broadcaster/resource_key_locker.go b/internal/streamingcoord/server/broadcaster/resource_key_locker.go new file mode 100644 index 0000000000..0bab3dc405 --- /dev/null +++ b/internal/streamingcoord/server/broadcaster/resource_key_locker.go @@ -0,0 +1,139 @@ +package broadcaster + +import ( + "sort" + + "github.com/cockroachdb/errors" + + "github.com/milvus-io/milvus/pkg/v2/proto/messagespb" + "github.com/milvus-io/milvus/pkg/v2/streaming/util/message" + "github.com/milvus-io/milvus/pkg/v2/util/lock" +) + +// errFastLockFailed is the error for fast lock failed. +var errFastLockFailed = errors.New("fast lock failed") + +// newResourceKeyLocker creates a new resource key locker. +func newResourceKeyLocker(metrics *broadcasterMetrics) *resourceKeyLocker { + return &resourceKeyLocker{ + inner: lock.NewKeyLock[resourceLockKey](), + } +} + +// newResourceLockKey creates a new resource lock key. +func newResourceLockKey(key message.ResourceKey) resourceLockKey { + return resourceLockKey{ + Domain: key.Domain, + Key: key.Key, + } +} + +// resourceLockKey is the key for the resource lock. +type resourceLockKey struct { + Domain messagespb.ResourceDomain + Key string +} + +// resourceKeyLocker is the locker for the resource keys. +// It's a low performance implementation, but the broadcaster is only used at low frequency of ddl. +// So it's acceptable to use this implementation. +type resourceKeyLocker struct { + inner *lock.KeyLock[resourceLockKey] +} + +// lockGuards is the guards for multiple resource keys. +type lockGuards struct { + guards []*lockGuard +} + +// ResourceKeys returns the resource keys. +func (l *lockGuards) ResourceKeys() []message.ResourceKey { + keys := make([]message.ResourceKey, 0, len(l.guards)) + for _, guard := range l.guards { + keys = append(keys, guard.key) + } + return keys +} + +// append appends the guard to the guards. +func (l *lockGuards) append(guard *lockGuard) { + l.guards = append(l.guards, guard) +} + +// Unlock unlocks the resource keys. +func (l *lockGuards) Unlock() { + // release the locks in reverse order to avoid deadlock. + for i := len(l.guards) - 1; i >= 0; i-- { + l.guards[i].Unlock() + } + l.guards = nil +} + +// lockGuard is the guard for the resource key. +type lockGuard struct { + locker *resourceKeyLocker + key message.ResourceKey +} + +// Unlock unlocks the resource key. +func (l *lockGuard) Unlock() { + l.locker.unlockWithKey(l.key) +} + +// FastLock locks the resource keys without waiting. +// return error if the resource key is already locked. 
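The resource key locker above replaces the old cond/resourceKeys map: exclusive keys take a write lock, shared keys take a read lock, and Lock always sorts the keys so two callers acquiring the same set in different orders cannot deadlock. A short usage sketch, written as if it sat next to resource_key_locker.go since the type is unexported (the collection names are placeholders):

package broadcaster

import (
	"fmt"

	"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
)

// exampleLockCollections shows the blocking and the fail-fast acquisition paths.
func exampleLockCollections() error {
	locker := newResourceKeyLocker(newBroadcasterMetrics())

	// Lock sorts the keys first, so concurrent callers that pass the same keys
	// in a different order cannot deadlock against each other.
	guards, err := locker.Lock(
		message.NewCollectionNameResourceKey("c1"),
		message.NewCollectionNameResourceKey("c2"),
	)
	if err != nil {
		return err
	}
	defer guards.Unlock()

	// FastLock fails immediately instead of blocking when a key is already held.
	if _, err := locker.FastLock(message.NewCollectionNameResourceKey("c1")); err != nil {
		fmt.Println("resource key already held:", err)
	}
	return nil
}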
+func (r *resourceKeyLocker) FastLock(keys ...message.ResourceKey) (*lockGuards, error) { + sortResourceKeys(keys) + + g := &lockGuards{} + for _, key := range keys { + var locked bool + if key.Shared { + locked = r.inner.TryRLock(newResourceLockKey(key)) + } else { + locked = r.inner.TryLock(newResourceLockKey(key)) + } + if locked { + g.append(&lockGuard{locker: r, key: key}) + continue + } + g.Unlock() + return nil, errors.Wrapf(errFastLockFailed, "fast lock failed at resource key %s", key.String()) + } + return g, nil +} + +// Lock locks the resource keys. +func (r *resourceKeyLocker) Lock(keys ...message.ResourceKey) (*lockGuards, error) { + // lock the keys in order to avoid deadlock. + sortResourceKeys(keys) + g := &lockGuards{} + for _, key := range keys { + if key.Shared { + r.inner.RLock(newResourceLockKey(key)) + } else { + r.inner.Lock(newResourceLockKey(key)) + } + g.append(&lockGuard{locker: r, key: key}) + } + return g, nil +} + +// unlockWithKey unlocks the resource key. +func (r *resourceKeyLocker) unlockWithKey(key message.ResourceKey) { + if key.Shared { + r.inner.RUnlock(newResourceLockKey(key)) + return + } + r.inner.Unlock(newResourceLockKey(key)) +} + +// sortResourceKeys sorts the resource keys. +func sortResourceKeys(keys []message.ResourceKey) { + sort.Slice(keys, func(i, j int) bool { + if keys[i].Domain != keys[j].Domain { + return keys[i].Domain < keys[j].Domain + } + return keys[i].Key < keys[j].Key + }) +} diff --git a/internal/streamingcoord/server/broadcaster/resource_key_locker_test.go b/internal/streamingcoord/server/broadcaster/resource_key_locker_test.go new file mode 100644 index 0000000000..ed0079cdb7 --- /dev/null +++ b/internal/streamingcoord/server/broadcaster/resource_key_locker_test.go @@ -0,0 +1,147 @@ +package broadcaster + +import ( + "fmt" + "math/rand" + "testing" + "time" + + "github.com/milvus-io/milvus/pkg/v2/streaming/util/message" +) + +func TestResourceKeyLocker(t *testing.T) { + t.Run("concurrent lock/unlock", func(t *testing.T) { + locker := newResourceKeyLocker(newBroadcasterMetrics()) + const numGoroutines = 10 + const numKeys = 5 + const numIterations = 100 + + // Create a set of test keys + keys := make([]message.ResourceKey, numKeys*2) + for i := 0; i < numKeys; i++ { + keys[i] = message.NewExclusiveCollectionNameResourceKey("test", fmt.Sprintf("test_collection_%d", i)) + keys[i+numKeys] = message.NewSharedDBNameResourceKey("test") + } + rand.Shuffle(len(keys), func(i, j int) { + keys[i], keys[j] = keys[j], keys[i] + }) + + // Start multiple goroutines trying to lock/unlock the same keys + done := make(chan bool) + for i := 0; i < numGoroutines; i++ { + go func(id uint64) { + for j := 0; j < numIterations; j++ { + // Try to lock random subset of keys + right := rand.Intn(numKeys) + left := 0 + if right > 0 { + left = rand.Intn(right) + } + keysToLock := make([]message.ResourceKey, right-left) + for i := left; i < right; i++ { + keysToLock[i-left] = keys[i] + } + rand.Shuffle(len(keysToLock), func(i, j int) { + keysToLock[i], keysToLock[j] = keysToLock[j], keysToLock[i] + }) + n := rand.Intn(10) + if n < 3 { + // Lock the keys + guards, err := locker.Lock(keysToLock...) + if err != nil { + t.Errorf("Failed to lock keys: %v", err) + return + } + // Hold lock briefly + time.Sleep(time.Millisecond) + + // Unlock the keys + guards.Unlock() + } else { + guards, err := locker.Lock(keysToLock...) 
+ if err == nil { + guards.Unlock() + } + } + } + done <- true + }(uint64(i)) + } + + // Wait for all goroutines to complete + for i := 0; i < numGoroutines; i++ { + <-done + } + }) + + t.Run("deadlock prevention", func(t *testing.T) { + locker := newResourceKeyLocker(newBroadcasterMetrics()) + key1 := message.NewCollectionNameResourceKey("test_collection_1") + key2 := message.NewCollectionNameResourceKey("test_collection_2") + + // Create two goroutines that try to lock resources in different orders + done := make(chan bool) + go func() { + for i := 0; i < 100; i++ { + // Lock key1 then key2 + guards, err := locker.Lock(key1, key2) + if err != nil { + t.Errorf("Failed to lock keys in order 1->2: %v", err) + return + } + time.Sleep(time.Millisecond) + guards.Unlock() + } + done <- true + }() + + go func() { + for i := 0; i < 100; i++ { + // Lock key2 then key1 + guards, err := locker.Lock(key2, key1) + if err != nil { + t.Errorf("Failed to lock keys in order 2->1: %v", err) + return + } + time.Sleep(time.Millisecond) + guards.Unlock() + } + done <- true + }() + + // Wait for both goroutines with timeout + for i := 0; i < 2; i++ { + select { + case <-done: + // Goroutine completed successfully + case <-time.After(5 * time.Second): + t.Fatal("Deadlock detected - goroutines did not complete in time") + } + } + }) + + t.Run("fast lock", func(t *testing.T) { + locker := newResourceKeyLocker(newBroadcasterMetrics()) + key := message.NewCollectionNameResourceKey("test_collection") + + // First fast lock should succeed + guards1, err := locker.FastLock(key) + if err != nil { + t.Fatalf("First FastLock failed: %v", err) + } + + // Second fast lock should fail + _, err = locker.FastLock(key) + if err == nil { + t.Fatal("Second FastLock should have failed") + } + + // After unlock, fast lock should succeed again + guards1.Unlock() + guards2, err := locker.FastLock(key) + if err != nil { + t.Fatalf("FastLock after unlock failed: %v", err) + } + guards2.Unlock() + }) +} diff --git a/internal/streamingcoord/server/broadcaster/tombstone_scheduler.go b/internal/streamingcoord/server/broadcaster/tombstone_scheduler.go new file mode 100644 index 0000000000..efa42aea12 --- /dev/null +++ b/internal/streamingcoord/server/broadcaster/tombstone_scheduler.go @@ -0,0 +1,124 @@ +package broadcaster + +import ( + "sort" + "time" + + "go.uber.org/zap" + + "github.com/milvus-io/milvus/pkg/v2/log" + "github.com/milvus-io/milvus/pkg/v2/util/paramtable" + "github.com/milvus-io/milvus/pkg/v2/util/syncutil" +) + +// tombstoneItem is a tombstone item with expired time. +type tombstoneItem struct { + broadcastID uint64 + createTime time.Time // the time when the tombstone is created, when recovery, the createTime will be reset to the current time, but it's ok. +} + +// tombstoneScheduler is a scheduler for the tombstone. +type tombstoneScheduler struct { + log.Binder + + notifier *syncutil.AsyncTaskNotifier[struct{}] + pending chan uint64 + bm *broadcastTaskManager + tombstones []tombstoneItem +} + +// newTombstoneScheduler creates a new tombstone scheduler. +func newTombstoneScheduler(logger *log.MLogger) *tombstoneScheduler { + ts := &tombstoneScheduler{ + notifier: syncutil.NewAsyncTaskNotifier[struct{}](), + pending: make(chan uint64), + } + ts.SetLogger(logger) + return ts +} + +// Initialize initializes the tombstone scheduler. 
+func (s *tombstoneScheduler) Initialize(bm *broadcastTaskManager, tombstoneBroadcastIDs []uint64) { + sort.Slice(tombstoneBroadcastIDs, func(i, j int) bool { + return tombstoneBroadcastIDs[i] < tombstoneBroadcastIDs[j] + }) + s.bm = bm + s.tombstones = make([]tombstoneItem, 0, len(tombstoneBroadcastIDs)) + for _, broadcastID := range tombstoneBroadcastIDs { + s.tombstones = append(s.tombstones, tombstoneItem{ + broadcastID: broadcastID, + createTime: time.Now(), + }) + } + go s.background() +} + +// AddPending adds a pending tombstone to the scheduler. +func (s *tombstoneScheduler) AddPending(broadcastID uint64) { + select { + case <-s.notifier.Context().Done(): + panic("unreachable: tombstone scheduler is closing when adding pending tombstone") + case s.pending <- broadcastID: + } +} + +// Close closes the tombstone scheduler. +func (s *tombstoneScheduler) Close() { + s.notifier.Cancel() + s.notifier.BlockUntilFinish() +} + +// background is the background goroutine of the tombstone scheduler. +func (s *tombstoneScheduler) background() { + defer func() { + s.notifier.Finish(struct{}{}) + s.Logger().Info("tombstone scheduler background exit") + }() + s.Logger().Info("tombstone scheduler background start") + + tombstoneGCInterval := paramtable.Get().StreamingCfg.WALBroadcasterTombstoneCheckInternal.GetAsDurationByParse() + ticker := time.NewTicker(tombstoneGCInterval) + defer ticker.Stop() + + for { + s.triggerGCTombstone() + select { + case <-s.notifier.Context().Done(): + return + case broadcastID := <-s.pending: + s.tombstones = append(s.tombstones, tombstoneItem{ + broadcastID: broadcastID, + createTime: time.Now(), + }) + case <-ticker.C: + } + } +} + +// triggerGCTombstone triggers the garbage collection of the tombstone. +func (s *tombstoneScheduler) triggerGCTombstone() { + maxTombstoneLifetime := paramtable.Get().StreamingCfg.WALBroadcasterTombstoneMaxLifetime.GetAsDurationByParse() + maxTombstoneCount := paramtable.Get().StreamingCfg.WALBroadcasterTombstoneMaxCount.GetAsInt() + + expiredTime := time.Now().Add(-maxTombstoneLifetime) + expiredOffset := 0 + if len(s.tombstones) > maxTombstoneCount { + expiredOffset = len(s.tombstones) - maxTombstoneCount + } + s.Logger().Info("triggerGCTombstone", + zap.Int("tombstone count", len(s.tombstones)), + zap.Int("expired offset", expiredOffset), + zap.Time("expired time", expiredTime)) + for idx, tombstone := range s.tombstones { + // drop tombstone until the expired time or until the expired offset. 
+ if idx >= expiredOffset && tombstone.createTime.After(expiredTime) { + s.tombstones = s.tombstones[idx:] + return + } + if err := s.bm.DropTombstone(s.notifier.Context(), tombstone.broadcastID); err != nil { + s.Logger().Error("failed to drop tombstone", zap.Error(err)) + s.tombstones = s.tombstones[idx:] + return + } + } +} diff --git a/internal/streamingcoord/server/builder.go b/internal/streamingcoord/server/builder.go index 78b11f545e..d782e6b739 100644 --- a/internal/streamingcoord/server/builder.go +++ b/internal/streamingcoord/server/builder.go @@ -4,8 +4,6 @@ import ( clientv3 "go.etcd.io/etcd/client/v3" "github.com/milvus-io/milvus/internal/metastore/kv/streamingcoord" - "github.com/milvus-io/milvus/internal/streamingcoord/server/balancer" - "github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster" "github.com/milvus-io/milvus/internal/streamingcoord/server/resource" "github.com/milvus-io/milvus/internal/streamingcoord/server/service" "github.com/milvus-io/milvus/internal/types" @@ -52,14 +50,10 @@ func (s *ServerBuilder) Build() *Server { resource.OptStreamingCatalog(streamingcoord.NewCataLog(s.metaKV)), resource.OptMixCoordClient(s.mixCoordClient), ) - balancer := syncutil.NewFuture[balancer.Balancer]() - broadcaster := syncutil.NewFuture[broadcaster.Broadcaster]() return &Server{ logger: resource.Resource().Logger().With(log.FieldComponent("server")), session: s.session, - assignmentService: service.NewAssignmentService(balancer), - broadcastService: service.NewBroadcastService(broadcaster), - balancer: balancer, - broadcaster: broadcaster, + assignmentService: service.NewAssignmentService(), + broadcastService: service.NewBroadcastService(), } } diff --git a/internal/streamingcoord/server/server.go b/internal/streamingcoord/server/server.go index 32884f5d79..6640c291ad 100644 --- a/internal/streamingcoord/server/server.go +++ b/internal/streamingcoord/server/server.go @@ -6,10 +6,11 @@ import ( "go.uber.org/zap" "google.golang.org/grpc" - "github.com/milvus-io/milvus/internal/coordinator/snmanager" "github.com/milvus-io/milvus/internal/streamingcoord/server/balancer" + "github.com/milvus-io/milvus/internal/streamingcoord/server/balancer/balance" _ "github.com/milvus-io/milvus/internal/streamingcoord/server/balancer/policy" // register the balancer policy "github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster" + "github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster/broadcast" "github.com/milvus-io/milvus/internal/streamingcoord/server/resource" "github.com/milvus-io/milvus/internal/streamingcoord/server/service" "github.com/milvus-io/milvus/internal/util/sessionutil" @@ -17,7 +18,6 @@ import ( "github.com/milvus-io/milvus/pkg/v2/log" "github.com/milvus-io/milvus/pkg/v2/proto/streamingpb" "github.com/milvus-io/milvus/pkg/v2/util/conc" - "github.com/milvus-io/milvus/pkg/v2/util/syncutil" ) // Server is the streamingcoord server. @@ -30,10 +30,6 @@ type Server struct { // service level variables. assignmentService service.AssignmentService broadcastService service.BroadcastService - - // basic component variables can be used at service level. - balancer *syncutil.Future[balancer.Balancer] - broadcaster *syncutil.Future[broadcaster.Broadcaster] } // Init initializes the streamingcoord server. 
@@ -60,8 +56,7 @@ func (s *Server) initBasicComponent(ctx context.Context) (err error) { s.logger.Warn("recover balancer failed", zap.Error(err)) return struct{}{}, err } - s.balancer.Set(balancer) - snmanager.StaticStreamingNodeManager.SetBalancerReady(balancer) + balance.Register(balancer) s.logger.Info("recover balancer done") return struct{}{}, nil })) @@ -74,7 +69,7 @@ func (s *Server) initBasicComponent(ctx context.Context) (err error) { s.logger.Warn("recover broadcaster failed", zap.Error(err)) return struct{}{}, err } - s.broadcaster.Set(broadcaster) + broadcast.Register(broadcaster) s.logger.Info("recover broadcaster done") return struct{}{}, nil })) @@ -89,18 +84,10 @@ func (s *Server) RegisterGRPCService(grpcServer *grpc.Server) { // Close closes the streamingcoord server. func (s *Server) Stop() { - if s.balancer.Ready() { - s.logger.Info("start close balancer...") - s.balancer.Get().Close() - } else { - s.logger.Info("balancer not ready, skip close") - } - if s.broadcaster.Ready() { - s.logger.Info("start close broadcaster...") - s.broadcaster.Get().Close() - } else { - s.logger.Info("broadcaster not ready, skip close") - } + s.logger.Info("start close balancer...") + balance.Release() + s.logger.Info("start close broadcaster...") + broadcast.Release() s.logger.Info("release streamingcoord resource...") resource.Release() s.logger.Info("streamingcoord server stopped") diff --git a/internal/streamingcoord/server/service/assignment.go b/internal/streamingcoord/server/service/assignment.go index 5fd33f5c27..61571d3a3f 100644 --- a/internal/streamingcoord/server/service/assignment.go +++ b/internal/streamingcoord/server/service/assignment.go @@ -8,7 +8,7 @@ import ( "go.uber.org/zap" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" - "github.com/milvus-io/milvus/internal/streamingcoord/server/balancer" + "github.com/milvus-io/milvus/internal/streamingcoord/server/balancer/balance" "github.com/milvus-io/milvus/internal/streamingcoord/server/balancer/channel" "github.com/milvus-io/milvus/internal/streamingcoord/server/service/discover" "github.com/milvus-io/milvus/pkg/v2/log" @@ -18,17 +18,13 @@ import ( "github.com/milvus-io/milvus/pkg/v2/streaming/walimpls/impls/rmq" "github.com/milvus-io/milvus/pkg/v2/util/paramtable" "github.com/milvus-io/milvus/pkg/v2/util/replicateutil" - "github.com/milvus-io/milvus/pkg/v2/util/syncutil" ) var _ streamingpb.StreamingCoordAssignmentServiceServer = (*assignmentServiceImpl)(nil) // NewAssignmentService returns a new assignment service. -func NewAssignmentService( - balancer *syncutil.Future[balancer.Balancer], -) streamingpb.StreamingCoordAssignmentServiceServer { +func NewAssignmentService() streamingpb.StreamingCoordAssignmentServiceServer { assignmentService := &assignmentServiceImpl{ - balancer: balancer, listenerTotal: metrics.StreamingCoordAssignmentListenerTotal.WithLabelValues(paramtable.GetStringNodeID()), } // TODO: after recovering from wal, add it to here. 
@@ -44,7 +40,6 @@ type AssignmentService interface { type assignmentServiceImpl struct { streamingpb.UnimplementedStreamingCoordAssignmentServiceServer - balancer *syncutil.Future[balancer.Balancer] listenerTotal prometheus.Gauge } @@ -53,7 +48,7 @@ func (s *assignmentServiceImpl) AssignmentDiscover(server streamingpb.StreamingC s.listenerTotal.Inc() defer s.listenerTotal.Dec() - balancer, err := s.balancer.GetWithContext(server.Context()) + balancer, err := balance.GetWithContext(server.Context()) if err != nil { return err } @@ -91,7 +86,7 @@ func (s *assignmentServiceImpl) UpdateReplicateConfiguration(ctx context.Context // validateReplicateConfiguration validates the replicate configuration. func (s *assignmentServiceImpl) validateReplicateConfiguration(ctx context.Context, config *commonpb.ReplicateConfiguration) (message.BroadcastMutableMessage, error) { - balancer, err := s.balancer.GetWithContext(ctx) + balancer, err := balance.GetWithContext(ctx) if err != nil { return nil, err } @@ -135,7 +130,7 @@ func (s *assignmentServiceImpl) validateReplicateConfiguration(ctx context.Conte // AlterReplicateConfiguration puts the replicate configuration into the balancer. // It's a callback function of the broadcast service. func (s *assignmentServiceImpl) AlterReplicateConfiguration(ctx context.Context, msgs ...message.ImmutableAlterReplicateConfigMessageV2) error { - balancer, err := s.balancer.GetWithContext(ctx) + balancer, err := balance.GetWithContext(ctx) if err != nil { return err } @@ -144,7 +139,7 @@ func (s *assignmentServiceImpl) AlterReplicateConfiguration(ctx context.Context, // UpdateWALBalancePolicy is used to update the WAL balance policy. func (s *assignmentServiceImpl) UpdateWALBalancePolicy(ctx context.Context, req *streamingpb.UpdateWALBalancePolicyRequest) (*streamingpb.UpdateWALBalancePolicyResponse, error) { - balancer, err := s.balancer.GetWithContext(ctx) + balancer, err := balance.GetWithContext(ctx) if err != nil { return nil, err } diff --git a/internal/streamingcoord/server/service/broadcast.go b/internal/streamingcoord/server/service/broadcast.go index f552692a66..2b75fbd03e 100644 --- a/internal/streamingcoord/server/service/broadcast.go +++ b/internal/streamingcoord/server/service/broadcast.go @@ -3,10 +3,9 @@ package service import ( "context" - "github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster" + "github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster/broadcast" "github.com/milvus-io/milvus/pkg/v2/proto/streamingpb" "github.com/milvus-io/milvus/pkg/v2/streaming/util/message" - "github.com/milvus-io/milvus/pkg/v2/util/syncutil" ) // BroadcastService is the interface of the broadcast service. @@ -15,30 +14,31 @@ type BroadcastService interface { } // NewBroadcastService creates a new broadcast service. -func NewBroadcastService(bc *syncutil.Future[broadcaster.Broadcaster]) BroadcastService { - return &broadcastServceImpl{ - broadcaster: bc, - } +func NewBroadcastService() BroadcastService { + return &broadcastServceImpl{} } // broadcastServiceeeeImpl is the implementation of the broadcast service. -type broadcastServceImpl struct { - broadcaster *syncutil.Future[broadcaster.Broadcaster] -} +type broadcastServceImpl struct{} // Broadcast broadcasts the message to all channels. 
func (s *broadcastServceImpl) Broadcast(ctx context.Context, req *streamingpb.BroadcastRequest) (*streamingpb.BroadcastResponse, error) { - broadcaster, err := s.broadcaster.GetWithContext(ctx) + msg := message.NewBroadcastMutableMessageBeforeAppend(req.Message.Payload, req.Message.Properties) + api, err := broadcast.StartBroadcastWithResourceKeys(ctx, msg.BroadcastHeader().ResourceKeys.Collect()...) if err != nil { return nil, err } - results, err := broadcaster.Broadcast(ctx, message.NewBroadcastMutableMessageBeforeAppend(req.Message.Payload, req.Message.Properties)) + results, err := api.Broadcast(ctx, msg) if err != nil { return nil, err } protoResult := make(map[string]*streamingpb.ProduceMessageResponseResult, len(results.AppendResults)) for vchannel, result := range results.AppendResults { - protoResult[vchannel] = result.IntoProto() + protoResult[vchannel] = &streamingpb.ProduceMessageResponseResult{ + Id: result.MessageID.IntoProto(), + Timetick: result.TimeTick, + LastConfirmedId: result.LastConfirmedMessageID.IntoProto(), + } } return &streamingpb.BroadcastResponse{ BroadcastId: results.BroadcastID, @@ -48,7 +48,7 @@ func (s *broadcastServceImpl) Broadcast(ctx context.Context, req *streamingpb.Br // Ack acknowledges the message at the specified vchannel. func (s *broadcastServceImpl) Ack(ctx context.Context, req *streamingpb.BroadcastAckRequest) (*streamingpb.BroadcastAckResponse, error) { - broadcaster, err := s.broadcaster.GetWithContext(ctx) + broadcaster, err := broadcast.GetWithContext(ctx) if err != nil { return nil, err } diff --git a/internal/streamingcoord/server/service/broadcast_test.go b/internal/streamingcoord/server/service/broadcast_test.go index 407172880b..32e6b9bbdf 100644 --- a/internal/streamingcoord/server/service/broadcast_test.go +++ b/internal/streamingcoord/server/service/broadcast_test.go @@ -7,10 +7,12 @@ import ( "github.com/stretchr/testify/mock" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus/internal/mocks/streamingcoord/server/mock_broadcaster" "github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster" - "github.com/milvus-io/milvus/pkg/v2/proto/messagespb" + "github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster/broadcast" "github.com/milvus-io/milvus/pkg/v2/proto/streamingpb" + "github.com/milvus-io/milvus/pkg/v2/streaming/util/message" "github.com/milvus-io/milvus/pkg/v2/streaming/util/types" "github.com/milvus-io/milvus/pkg/v2/streaming/walimpls/impls/walimplstest" "github.com/milvus-io/milvus/pkg/v2/util/syncutil" @@ -18,17 +20,24 @@ import ( func TestBroadcastService(t *testing.T) { fb := syncutil.NewFuture[broadcaster.Broadcaster]() + mba := mock_broadcaster.NewMockBroadcastAPI(t) mb := mock_broadcaster.NewMockBroadcaster(t) fb.Set(mb) - mb.EXPECT().Broadcast(mock.Anything, mock.Anything).Return(&types.BroadcastAppendResult{}, nil) + mba.EXPECT().Broadcast(mock.Anything, mock.Anything).Return(&types.BroadcastAppendResult{}, nil) + mb.EXPECT().WithResourceKeys(mock.Anything, mock.Anything).Return(mba, nil) mb.EXPECT().Ack(mock.Anything, mock.Anything).Return(nil) mb.EXPECT().LegacyAck(mock.Anything, mock.Anything, mock.Anything).Return(nil) - service := NewBroadcastService(fb) + broadcast.Register(mb) + + msg := message.NewCreateCollectionMessageBuilderV1(). + WithHeader(&message.CreateCollectionMessageHeader{}). + WithBody(&msgpb.CreateCollectionRequest{}). 
+ WithBroadcast([]string{"v1"}, message.NewCollectionNameResourceKey("r1")). + MustBuildBroadcast() + + service := NewBroadcastService() service.Broadcast(context.Background(), &streamingpb.BroadcastRequest{ - Message: &messagespb.Message{ - Payload: []byte("payload"), - Properties: map[string]string{"_bh": "1"}, - }, + Message: msg.IntoMessageProto(), }) service.Ack(context.Background(), &streamingpb.BroadcastAckRequest{ BroadcastId: 1, diff --git a/internal/streamingnode/client/handler/producer/producer_impl.go b/internal/streamingnode/client/handler/producer/producer_impl.go index cb64a9bbf4..fcd6cbe632 100644 --- a/internal/streamingnode/client/handler/producer/producer_impl.go +++ b/internal/streamingnode/client/handler/producer/producer_impl.go @@ -300,14 +300,19 @@ func (p *producerImpl) recvLoop() (err error) { case *streamingpb.ProduceMessageResponse_Result: msgID, err := message.UnmarshalMessageID(produceResp.Result.GetId()) if err != nil { - return err + return errors.Wrap(err, "failed to unmarshal message id") + } + lcMsgID, err := message.UnmarshalMessageID(produceResp.Result.GetLastConfirmedId()) + if err != nil { + return errors.Wrap(err, "failed to unmarshal last confirmed message id") } result = produceResponse{ result: &types.AppendResult{ - MessageID: msgID, - TimeTick: produceResp.Result.GetTimetick(), - TxnCtx: message.NewTxnContextFromProto(produceResp.Result.GetTxnContext()), - Extra: produceResp.Result.GetExtra(), + MessageID: msgID, + LastConfirmedMessageID: lcMsgID, + TimeTick: produceResp.Result.GetTimetick(), + TxnCtx: message.NewTxnContextFromProto(produceResp.Result.GetTxnContext()), + Extra: produceResp.Result.GetExtra(), }, } case *streamingpb.ProduceMessageResponse_Error: diff --git a/internal/streamingnode/client/handler/producer/producer_test.go b/internal/streamingnode/client/handler/producer/producer_test.go index 9867d4e8f6..d1fd585ed3 100644 --- a/internal/streamingnode/client/handler/producer/producer_test.go +++ b/internal/streamingnode/client/handler/producer/producer_test.go @@ -86,7 +86,8 @@ func TestProducer(t *testing.T) { RequestId: 2, Response: &streamingpb.ProduceMessageResponse_Result{ Result: &streamingpb.ProduceMessageResponseResult{ - Id: walimplstest.NewTestMessageID(1).IntoProto(), + Id: walimplstest.NewTestMessageID(1).IntoProto(), + LastConfirmedId: walimplstest.NewTestMessageID(1).IntoProto(), }, }, }, diff --git a/internal/streamingnode/server/service/handler/producer/produce_server_test.go b/internal/streamingnode/server/service/handler/producer/produce_server_test.go index 4af6e76b49..b84cabb49a 100644 --- a/internal/streamingnode/server/service/handler/producer/produce_server_test.go +++ b/internal/streamingnode/server/service/handler/producer/produce_server_test.go @@ -201,8 +201,9 @@ func TestProduceServerRecvArm(t *testing.T) { l.EXPECT().AppendAsync(mock.Anything, mock.Anything, mock.Anything).Run(func(ctx context.Context, mm message.MutableMessage, f func(*wal.AppendResult, error)) { msgID := walimplstest.NewTestMessageID(1) f(&wal.AppendResult{ - MessageID: msgID, - TimeTick: 100, + MessageID: msgID, + LastConfirmedMessageID: msgID, + TimeTick: 100, }, nil) }) l.EXPECT().IsAvailable().Return(true) diff --git a/internal/streamingnode/server/wal/adaptor/wal_adaptor.go b/internal/streamingnode/server/wal/adaptor/wal_adaptor.go index 35f8379429..438fe09624 100644 --- a/internal/streamingnode/server/wal/adaptor/wal_adaptor.go +++ b/internal/streamingnode/server/wal/adaptor/wal_adaptor.go @@ -198,10 +198,11 @@ func (w 
*walAdaptorImpl) Append(ctx context.Context, msg message.MutableMessage) // unwrap the messageID if needed. r := &wal.AppendResult{ - MessageID: messageID, - TimeTick: extraAppendResult.TimeTick, - TxnCtx: extraAppendResult.TxnCtx, - Extra: extra, + MessageID: messageID, + LastConfirmedMessageID: extraAppendResult.LastConfirmedMessageID, + TimeTick: extraAppendResult.TimeTick, + TxnCtx: extraAppendResult.TxnCtx, + Extra: extra, } appendMetrics.Done(r, nil) return r, nil diff --git a/internal/streamingnode/server/wal/interceptors/timetick/timetick_interceptor.go b/internal/streamingnode/server/wal/interceptors/timetick/timetick_interceptor.go index 268b5f2612..1762eece80 100644 --- a/internal/streamingnode/server/wal/interceptors/timetick/timetick_interceptor.go +++ b/internal/streamingnode/server/wal/interceptors/timetick/timetick_interceptor.go @@ -45,6 +45,7 @@ func (impl *timeTickAppendInterceptor) DoAppend(ctx context.Context, msg message ackManager := impl.operator.AckManager() var txnSession *txn.TxnSession + var immutableMsg message.ImmutableMessage if msg.MessageType() != message.MessageTypeTimeTick { // Allocate new timestamp acker for message. var acker *ack.Acker @@ -69,7 +70,7 @@ func (impl *timeTickAppendInterceptor) DoAppend(ctx context.Context, msg message return } acker.Ack( - ack.OptImmutableMessage(msg.IntoImmutableMessage(msgID)), + ack.OptImmutableMessage(immutableMsg), ack.OptTxnSession(txnSession), ) }() @@ -115,8 +116,10 @@ func (impl *timeTickAppendInterceptor) DoAppend(ctx context.Context, msg message if txnSession != nil { ctx = txn.WithTxnSession(ctx, txnSession) } - msgID, err = impl.appendMsg(ctx, msg, append) - return + if immutableMsg, err = impl.appendMsg(ctx, msg, append); err != nil { + return nil, err + } + return immutableMsg.MessageID(), nil } // GracefulClose implements InterceptorWithGracefulClose. 
@@ -207,12 +210,14 @@ func (impl *timeTickAppendInterceptor) appendMsg( ctx context.Context, msg message.MutableMessage, append func(context.Context, message.MutableMessage) (message.MessageID, error), -) (message.MessageID, error) { +) (message.ImmutableMessage, error) { msgID, err := append(ctx, msg) if err != nil { return nil, err } - utility.ReplaceAppendResultTimeTick(ctx, msg.TimeTick()) - utility.ReplaceAppendResultTxnContext(ctx, msg.TxnContext()) - return msgID, nil + immutableMsg := msg.IntoImmutableMessage(msgID) + utility.ReplaceAppendResultTimeTick(ctx, immutableMsg.TimeTick()) + utility.ReplaceAppendResultLastConfirmedMessageID(ctx, immutableMsg.LastConfirmedMessageID()) + utility.ReplaceAppendResultTxnContext(ctx, immutableMsg.TxnContext()) + return immutableMsg, nil } diff --git a/internal/streamingnode/server/wal/metricsutil/append.go b/internal/streamingnode/server/wal/metricsutil/append.go index ffb1241f9a..24b647e770 100644 --- a/internal/streamingnode/server/wal/metricsutil/append.go +++ b/internal/streamingnode/server/wal/metricsutil/append.go @@ -81,6 +81,7 @@ func (m *AppendMetrics) IntoLogFields() []zap.Field { fields = append(fields, zap.Error(m.err)) } else { fields = append(fields, zap.String("messageID", m.result.MessageID.String())) + fields = append(fields, zap.String("lcMessageID", m.result.LastConfirmedMessageID.String())) fields = append(fields, zap.Uint64("timetick", m.result.TimeTick)) if m.result.TxnCtx != nil { fields = append(fields, zap.Int64("txnID", int64(m.result.TxnCtx.TxnID))) diff --git a/internal/streamingnode/server/wal/utility/context.go b/internal/streamingnode/server/wal/utility/context.go index 76db912548..1d46fb3369 100644 --- a/internal/streamingnode/server/wal/utility/context.go +++ b/internal/streamingnode/server/wal/utility/context.go @@ -22,9 +22,10 @@ var ( // ExtraAppendResult is the extra append result. type ExtraAppendResult struct { - TimeTick uint64 - TxnCtx *message.TxnContext - Extra protoreflect.ProtoMessage + TimeTick uint64 + TxnCtx *message.TxnContext + Extra protoreflect.ProtoMessage + LastConfirmedMessageID message.MessageID } // NotPersistedHint is the hint of not persisted message. 
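The interceptor and adaptor changes above lean on the ExtraAppendResult holder that the append path stores in the request context: a pointer is attached before the interceptor chain runs, the timetick interceptor writes TimeTick, TxnContext, and now LastConfirmedMessageID back into it through the Replace helpers (one of which is added in the next hunk), and the WAL adaptor copies those fields into wal.AppendResult once the append returns. Below is a small self-contained sketch of that write-back-through-context pattern; the types and names are illustrative stand-ins rather than the Milvus ones.

package ctxsketch

import "context"

// extraAppendResult mimics the mutable holder the caller stores in the context.
type extraAppendResult struct {
	TimeTick               uint64
	LastConfirmedMessageID string
}

type resultKey struct{}

// withExtraAppendResult attaches the holder before the append starts.
func withExtraAppendResult(ctx context.Context, r *extraAppendResult) context.Context {
	return context.WithValue(ctx, resultKey{}, r)
}

// replaceLastConfirmed is what an interceptor would call once the appended
// message has been converted into its immutable form; it assumes the holder
// is present and panics otherwise, mirroring the real helpers.
func replaceLastConfirmed(ctx context.Context, id string) {
	ctx.Value(resultKey{}).(*extraAppendResult).LastConfirmedMessageID = id
}

// appendWithResult shows the round trip: attach, mutate downstream, read back.
func appendWithResult() (string, uint64) {
	holder := &extraAppendResult{}
	ctx := withExtraAppendResult(context.Background(), holder)

	// Deeper in the interceptor chain, after the append succeeds:
	holder.TimeTick = 100
	replaceLastConfirmed(ctx, "msg-1")

	// The caller (the WAL adaptor in the real code) reads both fields afterwards.
	return holder.LastConfirmedMessageID, holder.TimeTick
}

Storing a pointer in the context is what makes this work: context values themselves are immutable, but a pointer lets downstream interceptors publish results to the caller without threading extra return values through every interceptor signature.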
@@ -66,6 +67,12 @@ func ModifyAppendResultExtra[M protoreflect.ProtoMessage](ctx context.Context, m result.(*ExtraAppendResult).Extra = new } +// ReplaceAppendResultLastConfirmedMessageID set last confirmed message id to context +func ReplaceAppendResultLastConfirmedMessageID(ctx context.Context, lastConfirmedMessageID message.MessageID) { + result := ctx.Value(extraAppendResultValue) + result.(*ExtraAppendResult).LastConfirmedMessageID = lastConfirmedMessageID +} + // ReplaceAppendResultTimeTick set time tick to context func ReplaceAppendResultTimeTick(ctx context.Context, timeTick uint64) { result := ctx.Value(extraAppendResultValue) diff --git a/internal/streamingnode/server/wal/utility/context_test.go b/internal/streamingnode/server/wal/utility/context_test.go index f518b05d4a..28a1e4762f 100644 --- a/internal/streamingnode/server/wal/utility/context_test.go +++ b/internal/streamingnode/server/wal/utility/context_test.go @@ -83,6 +83,18 @@ func TestReplaceAppendResultTxnContext(t *testing.T) { assert.Equal(t, retrievedResult.TxnCtx.TxnID, newTxnCtx.TxnID) } +func TestReplaceAppendResultLastConfirmedMessageID(t *testing.T) { + ctx := context.Background() + result := &ExtraAppendResult{LastConfirmedMessageID: walimplstest.NewTestMessageID(1)} + ctx = WithExtraAppendResult(ctx, result) + + newLastConfirmedMessageID := walimplstest.NewTestMessageID(2) + ReplaceAppendResultLastConfirmedMessageID(ctx, newLastConfirmedMessageID) + + retrievedResult := ctx.Value(extraAppendResultValue).(*ExtraAppendResult) + assert.True(t, retrievedResult.LastConfirmedMessageID.EQ(newLastConfirmedMessageID)) +} + func TestWithFlushFromOldArch(t *testing.T) { ctx := context.Background() assert.False(t, GetFlushFromOldArch(ctx)) diff --git a/internal/util/testutil/reset_env.go b/internal/util/testutil/reset_env.go index 37aed2d772..87418c50bf 100644 --- a/internal/util/testutil/reset_env.go +++ b/internal/util/testutil/reset_env.go @@ -3,6 +3,7 @@ package testutil import ( "github.com/milvus-io/milvus/internal/coordinator/snmanager" "github.com/milvus-io/milvus/internal/streamingcoord/server/balancer/channel" + "github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster/broadcast" "github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster/registry" registry2 "github.com/milvus-io/milvus/internal/streamingnode/client/handler/registry" ) @@ -12,4 +13,5 @@ func ResetEnvironment() { registry.ResetRegistration() snmanager.ResetStreamingNodeManager() registry2.ResetRegisterLocalWALManager() + broadcast.ResetBroadcaster() } diff --git a/pkg/kv/reliable_write_meta_kv.go b/pkg/kv/reliable_write_meta_kv.go new file mode 100644 index 0000000000..5d8bda726b --- /dev/null +++ b/pkg/kv/reliable_write_meta_kv.go @@ -0,0 +1,104 @@ +package kv + +import ( + "context" + "time" + + "github.com/cenkalti/backoff/v4" + "go.uber.org/zap" + + "github.com/milvus-io/milvus/pkg/v2/kv/predicates" + "github.com/milvus-io/milvus/pkg/v2/log" +) + +var _ MetaKv = (*ReliableWriteMetaKv)(nil) + +// NewReliableWriteMetaKv returns a new ReliableWriteMetaKv if the kv is not a ReliableWriteMetaKv. +func NewReliableWriteMetaKv(kv MetaKv) MetaKv { + if _, ok := kv.(*ReliableWriteMetaKv); ok { + return kv + } + return &ReliableWriteMetaKv{ + Binder: log.Binder{}, + MetaKv: kv, + } +} + +// ReliableWriteMetaKv is a wrapper of MetaKv that ensures the data is written reliably. +// It will retry the metawrite operation until the data is written successfully or the context is timeout. 
+// It is useful for keeping the in-memory metadata consistent with the underlying meta storage.
+type ReliableWriteMetaKv struct {
+	log.Binder
+	MetaKv
+}
+
+func (kv *ReliableWriteMetaKv) Save(ctx context.Context, key, value string) error {
+	return kv.retryWithBackoff(ctx, func(ctx context.Context) error {
+		return kv.MetaKv.Save(ctx, key, value)
+	})
+}
+
+func (kv *ReliableWriteMetaKv) MultiSave(ctx context.Context, kvs map[string]string) error {
+	return kv.retryWithBackoff(ctx, func(ctx context.Context) error {
+		return kv.MetaKv.MultiSave(ctx, kvs)
+	})
+}
+
+func (kv *ReliableWriteMetaKv) Remove(ctx context.Context, key string) error {
+	return kv.retryWithBackoff(ctx, func(ctx context.Context) error {
+		return kv.MetaKv.Remove(ctx, key)
+	})
+}
+
+func (kv *ReliableWriteMetaKv) MultiRemove(ctx context.Context, keys []string) error {
+	return kv.retryWithBackoff(ctx, func(ctx context.Context) error {
+		return kv.MetaKv.MultiRemove(ctx, keys)
+	})
+}
+
+func (kv *ReliableWriteMetaKv) MultiSaveAndRemove(ctx context.Context, saves map[string]string, removals []string, preds ...predicates.Predicate) error {
+	return kv.retryWithBackoff(ctx, func(ctx context.Context) error {
+		return kv.MetaKv.MultiSaveAndRemove(ctx, saves, removals, preds...)
+	})
+}
+
+func (kv *ReliableWriteMetaKv) MultiSaveAndRemoveWithPrefix(ctx context.Context, saves map[string]string, removals []string, preds ...predicates.Predicate) error {
+	return kv.retryWithBackoff(ctx, func(ctx context.Context) error {
+		return kv.MetaKv.MultiSaveAndRemoveWithPrefix(ctx, saves, removals, preds...)
+	})
+}
+
+func (kv *ReliableWriteMetaKv) CompareVersionAndSwap(ctx context.Context, key string, version int64, target string) (bool, error) {
+	var result bool
+	err := kv.retryWithBackoff(ctx, func(ctx context.Context) error {
+		var err error
+		result, err = kv.MetaKv.CompareVersionAndSwap(ctx, key, version, target)
+		return err
+	})
+	return result, err
+}
+
+// retryWithBackoff retries fn with exponential backoff until it succeeds or the context is done.
+func (kv *ReliableWriteMetaKv) retryWithBackoff(ctx context.Context, fn func(ctx context.Context) error) error { + backoff := backoff.NewExponentialBackOff() + backoff.InitialInterval = 10 * time.Millisecond + backoff.MaxInterval = 1 * time.Second + backoff.MaxElapsedTime = 0 + backoff.Reset() + for { + err := fn(ctx) + if err == nil { + return nil + } + if ctx.Err() != nil { + return ctx.Err() + } + nextInterval := backoff.NextBackOff() + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(nextInterval): + kv.Logger().Warn("failed to persist operation, wait for retry...", zap.Duration("nextRetryInterval", nextInterval), zap.Error(err)) + } + } +} diff --git a/pkg/kv/reliable_write_meta_kv_test.go b/pkg/kv/reliable_write_meta_kv_test.go new file mode 100644 index 0000000000..c843453a8d --- /dev/null +++ b/pkg/kv/reliable_write_meta_kv_test.go @@ -0,0 +1,127 @@ +package kv + +import ( + "context" + "sync" + "testing" + "time" + + "github.com/cockroachdb/errors" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "go.uber.org/atomic" + + "github.com/milvus-io/milvus/pkg/v2/kv/predicates" + "github.com/milvus-io/milvus/pkg/v2/mocks/mock_kv" +) + +func TestReliableWriteMetaKv(t *testing.T) { + kv := mock_kv.NewMockMetaKv(t) + fail := atomic.NewBool(true) + kv.EXPECT().Save(context.TODO(), mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, s1, s2 string) error { + if !fail.Load() { + return nil + } + return errors.New("test") + }) + kv.EXPECT().MultiSave(context.TODO(), mock.Anything).RunAndReturn(func(ctx context.Context, kvs map[string]string) error { + if !fail.Load() { + return nil + } + return errors.New("test") + }) + kv.EXPECT().Remove(context.TODO(), mock.Anything).RunAndReturn(func(ctx context.Context, key string) error { + if !fail.Load() { + return nil + } + return errors.New("test") + }) + kv.EXPECT().MultiRemove(context.TODO(), mock.Anything).RunAndReturn(func(ctx context.Context, keys []string) error { + if !fail.Load() { + return nil + } + return errors.New("test") + }) + kv.EXPECT().MultiSaveAndRemove(context.TODO(), mock.Anything, mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, saves map[string]string, removals []string, preds ...predicates.Predicate) error { + if !fail.Load() { + return nil + } + return errors.New("test") + }) + kv.EXPECT().MultiSaveAndRemoveWithPrefix(context.TODO(), mock.Anything, mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, saves map[string]string, removals []string, preds ...predicates.Predicate) error { + if !fail.Load() { + return nil + } + return errors.New("test") + }) + kv.EXPECT().CompareVersionAndSwap(mock.Anything, mock.Anything, mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, key string, version int64, target string) (bool, error) { + if !fail.Load() { + return false, nil + } + return false, errors.New("test") + }) + rkv := NewReliableWriteMetaKv(kv) + wg := sync.WaitGroup{} + wg.Add(7) + success := atomic.NewInt32(0) + go func() { + defer wg.Done() + err := rkv.Save(context.TODO(), "test", "test") + if err == nil { + success.Add(1) + } + }() + go func() { + defer wg.Done() + err := rkv.MultiSave(context.TODO(), map[string]string{"test": "test"}) + if err == nil { + success.Add(1) + } + }() + go func() { + defer wg.Done() + err := rkv.Remove(context.TODO(), "test") + if err == nil { + success.Add(1) + } + }() + go func() { + defer wg.Done() + err := rkv.MultiRemove(context.TODO(), []string{"test"}) + if err == nil { + 
success.Add(1) + } + }() + go func() { + defer wg.Done() + err := rkv.MultiSaveAndRemove(context.TODO(), map[string]string{"test": "test"}, []string{"test"}) + if err == nil { + success.Add(1) + } + }() + go func() { + defer wg.Done() + err := rkv.MultiSaveAndRemoveWithPrefix(context.TODO(), map[string]string{"test": "test"}, []string{"test"}) + if err == nil { + success.Add(1) + } + }() + go func() { + defer wg.Done() + _, err := rkv.CompareVersionAndSwap(context.TODO(), "test", 0, "test") + if err == nil { + success.Add(1) + } + }() + time.Sleep(1 * time.Second) + fail.Store(false) + wg.Wait() + assert.Equal(t, int32(7), success.Load()) + + fail.Store(true) + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + _, err := rkv.CompareVersionAndSwap(ctx, "test", 0, "test") + assert.ErrorIs(t, err, context.DeadlineExceeded) +} diff --git a/pkg/streaming/util/message/broadcast.go b/pkg/streaming/util/message/broadcast.go index 99d94597ec..2e3f1c4eac 100644 --- a/pkg/streaming/util/message/broadcast.go +++ b/pkg/streaming/util/message/broadcast.go @@ -44,6 +44,17 @@ func (br *BroadcastResult[H, B]) GetControlChannelResult() *AppendResult { return nil } +// GetVChannelsWithoutControlChannel returns the vchannels without control channel. +func (br *BroadcastResult[H, B]) GetVChannelsWithoutControlChannel() []string { + vchannels := make([]string, 0, len(br.Results)) + for vchannel := range br.Results { + if !funcutil.IsControlChannel(vchannel) { + vchannels = append(vchannels, vchannel) + } + } + return vchannels +} + // AppendResult is the result of append operation. type AppendResult struct { MessageID MessageID diff --git a/pkg/streaming/util/message/broadcast_test.go b/pkg/streaming/util/message/broadcast_test.go new file mode 100644 index 0000000000..6a6226f65d --- /dev/null +++ b/pkg/streaming/util/message/broadcast_test.go @@ -0,0 +1,23 @@ +package message + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus/pkg/v2/util/funcutil" +) + +func TestBroadcastResult(t *testing.T) { + r := BroadcastResult[*CreateDatabaseMessageHeader, *CreateDatabaseMessageBody]{ + Message: nil, + Results: map[string]*AppendResult{ + "v1": {}, + "v2": {}, + "abc" + funcutil.ControlChannelSuffix: {}, + }, + } + + assert.ElementsMatch(t, []string{"v1", "v2"}, r.GetVChannelsWithoutControlChannel()) + assert.NotNil(t, r.GetControlChannelResult()) +} diff --git a/pkg/streaming/util/types/responses.go b/pkg/streaming/util/types/responses.go index d1a33e81e9..ea4ab5c1ca 100644 --- a/pkg/streaming/util/types/responses.go +++ b/pkg/streaming/util/types/responses.go @@ -24,6 +24,10 @@ type AppendResult struct { // MessageID is generated by underlying walimpls. MessageID message.MessageID + // LastConfirmedMessageID is the last confirmed message id. + // From these message id, the reader can read all the messages which timetick is greater than the TimeTick in response. + LastConfirmedMessageID message.MessageID + // TimeTick is the time tick of the message. // Set by timetick interceptor. TimeTick uint64 @@ -47,10 +51,11 @@ func (r *AppendResult) GetExtra(m proto.Message) error { // IntoProto converts the append result to proto. 
func (r *AppendResult) IntoProto() *streamingpb.ProduceMessageResponseResult { return &streamingpb.ProduceMessageResponseResult{ - Id: r.MessageID.IntoProto(), - Timetick: r.TimeTick, - TxnContext: r.TxnCtx.IntoProto(), - Extra: r.Extra, + Id: r.MessageID.IntoProto(), + Timetick: r.TimeTick, + TxnContext: r.TxnCtx.IntoProto(), + Extra: r.Extra, + LastConfirmedId: r.LastConfirmedMessageID.IntoProto(), } } diff --git a/pkg/streaming/util/types/responses_test.go b/pkg/streaming/util/types/responses_test.go index e6698b16c8..72c7a0b72c 100644 --- a/pkg/streaming/util/types/responses_test.go +++ b/pkg/streaming/util/types/responses_test.go @@ -42,9 +42,10 @@ func TestAppendResult_IntoProto(t *testing.T) { msgID := mock_message.NewMockMessageID(t) msgID.EXPECT().IntoProto().Return(&commonpb.MessageID{WALName: commonpb.WALName(message.WALNameTest), Id: "1"}) result := &AppendResult{ - MessageID: msgID, - TimeTick: 12345, - TxnCtx: &message.TxnContext{TxnID: 1}, + MessageID: msgID, + TimeTick: 12345, + TxnCtx: &message.TxnContext{TxnID: 1}, + LastConfirmedMessageID: msgID, } protoResult := result.IntoProto() @@ -52,6 +53,7 @@ func TestAppendResult_IntoProto(t *testing.T) { assert.Equal(t, "1", protoResult.Id.Id) assert.Equal(t, uint64(12345), protoResult.Timetick) assert.Equal(t, int64(1), protoResult.TxnContext.TxnId) + assert.Equal(t, "1", protoResult.LastConfirmedId.Id) } func TestAppendResponses_MaxTimeTick(t *testing.T) { diff --git a/pkg/util/lock/key_lock.go b/pkg/util/lock/key_lock.go index 699c327730..a4a26dc579 100644 --- a/pkg/util/lock/key_lock.go +++ b/pkg/util/lock/key_lock.go @@ -109,6 +109,47 @@ func (k *KeyLock[K]) Lock(key K) { } } +func (k *KeyLock[K]) TryLock(key K) bool { + k.keyLocksMutex.Lock() + // update the key map + if keyLock, ok := k.refLocks[key]; ok { + keyLock.ref() + + k.keyLocksMutex.Unlock() + locked := keyLock.mutex.TryLock() + if !locked { + k.keyLocksMutex.Lock() + keyLock.unref() + if keyLock.refCounter == 0 { + _ = refLockPoolPool.ReturnObject(ctx, keyLock) + delete(k.refLocks, key) + } + k.keyLocksMutex.Unlock() + } + return locked + } else { + obj, err := refLockPoolPool.BorrowObject(ctx) + if err != nil { + log.Ctx(ctx).Error("BorrowObject failed", zap.Error(err)) + k.keyLocksMutex.Unlock() + return false + } + newKLock := obj.(*RefLock) + // newKLock := newRefLock() + locked := newKLock.mutex.TryLock() + if !locked { + _ = refLockPoolPool.ReturnObject(ctx, newKLock) + k.keyLocksMutex.Unlock() + return false + } + k.refLocks[key] = newKLock + newKLock.ref() + + k.keyLocksMutex.Unlock() + return true + } +} + func (k *KeyLock[K]) Unlock(lockedKey K) { k.keyLocksMutex.Lock() defer k.keyLocksMutex.Unlock() @@ -151,6 +192,47 @@ func (k *KeyLock[K]) RLock(key K) { } } +func (k *KeyLock[K]) TryRLock(key K) bool { + k.keyLocksMutex.Lock() + // update the key map + if keyLock, ok := k.refLocks[key]; ok { + keyLock.ref() + + k.keyLocksMutex.Unlock() + locked := keyLock.mutex.TryRLock() + if !locked { + k.keyLocksMutex.Lock() + keyLock.unref() + if keyLock.refCounter == 0 { + _ = refLockPoolPool.ReturnObject(ctx, keyLock) + delete(k.refLocks, key) + } + k.keyLocksMutex.Unlock() + } + return locked + } else { + obj, err := refLockPoolPool.BorrowObject(ctx) + if err != nil { + log.Ctx(ctx).Error("BorrowObject failed", zap.Error(err)) + k.keyLocksMutex.Unlock() + return false + } + newKLock := obj.(*RefLock) + // newKLock := newRefLock() + locked := newKLock.mutex.TryRLock() + if !locked { + _ = refLockPoolPool.ReturnObject(ctx, newKLock) + k.keyLocksMutex.Unlock() + 
return false
+		}
+		k.refLocks[key] = newKLock
+		newKLock.ref()
+
+		k.keyLocksMutex.Unlock()
+		return true
+	}
+}
+
 func (k *KeyLock[K]) RUnlock(lockedKey K) {
 	k.keyLocksMutex.Lock()
 	defer k.keyLocksMutex.Unlock()
diff --git a/pkg/util/lock/key_lock_test.go b/pkg/util/lock/key_lock_test.go
index 755c23a232..4b2d726da2 100644
--- a/pkg/util/lock/key_lock_test.go
+++ b/pkg/util/lock/key_lock_test.go
@@ -82,3 +82,50 @@ func TestNewKeyLock(t *testing.T) {
 	keyLock.keyLocksMutex.Unlock()
 	assert.Equal(t, 0, keyLen)
 }
+
+func TestKeyLockTryLock(t *testing.T) {
+	keyLock := NewKeyLock[string]()
+	ok := keyLock.TryLock("a")
+	assert.True(t, ok)
+	ok = keyLock.TryLock("b")
+	assert.True(t, ok)
+
+	ok = keyLock.TryLock("a")
+	assert.False(t, ok)
+	ok = keyLock.TryLock("b")
+	assert.False(t, ok)
+
+	ok = keyLock.TryRLock("a")
+	assert.False(t, ok)
+	ok = keyLock.TryRLock("b")
+	assert.False(t, ok)
+
+	assert.Equal(t, 2, keyLock.size())
+	keyLock.Unlock("a")
+	keyLock.Unlock("b")
+	assert.Zero(t, keyLock.size())
+
+	ok = keyLock.TryRLock("a")
+	assert.True(t, ok)
+	ok = keyLock.TryRLock("b")
+	assert.True(t, ok)
+
+	ok = keyLock.TryLock("a")
+	assert.False(t, ok)
+	ok = keyLock.TryLock("b")
+	assert.False(t, ok)
+
+	ok = keyLock.TryRLock("a")
+	assert.True(t, ok)
+	ok = keyLock.TryRLock("b")
+	assert.True(t, ok)
+
+	assert.Equal(t, 2, keyLock.size())
+	keyLock.RUnlock("a")
+	keyLock.RUnlock("b")
+	assert.Equal(t, 2, keyLock.size())
+
+	keyLock.RUnlock("a")
+	keyLock.RUnlock("b")
+	assert.Equal(t, 0, keyLock.size())
+}
diff --git a/pkg/util/paramtable/component_param.go b/pkg/util/paramtable/component_param.go
index 54be41fd3f..b3bce3ca03 100644
--- a/pkg/util/paramtable/component_param.go
+++ b/pkg/util/paramtable/component_param.go
@@ -6147,7 +6147,10 @@ type streamingConfig struct {
 	WALBalancerPolicyVChannelFairRebalanceMaxStep ParamItem `refreshable:"true"`
 
 	// broadcaster
-	WALBroadcasterConcurrencyRatio ParamItem `refreshable:"false"`
+	WALBroadcasterConcurrencyRatio       ParamItem `refreshable:"false"`
+	WALBroadcasterTombstoneCheckInternal ParamItem `refreshable:"true"`
+	WALBroadcasterTombstoneMaxCount      ParamItem `refreshable:"true"`
+	WALBroadcasterTombstoneMaxLifetime   ParamItem `refreshable:"true"`
 
 	// txn
 	TxnDefaultKeepaliveTimeout ParamItem `refreshable:"true"`
@@ -6327,6 +6330,39 @@ it also determine the depth of depth first search method that is used to find th
 	}
 	p.WALBroadcasterConcurrencyRatio.Init(base.mgr)
 
+	p.WALBroadcasterTombstoneCheckInternal = ParamItem{
+		Key:     "streaming.walBroadcaster.tombstone.checkInternal",
+		Version: "2.6.0",
+		Doc: `The interval at which tombstones are garbage collected, 5m by default.
+Tombstones are used to reject duplicate submissions of DDL messages;
+keeping too few of them may lead to ABA issues in the state of the Milvus cluster.`,
+		DefaultValue: "5m",
+		Export:       false,
+	}
+	p.WALBroadcasterTombstoneCheckInternal.Init(base.mgr)
+
+	p.WALBroadcasterTombstoneMaxCount = ParamItem{
+		Key:     "streaming.walBroadcaster.tombstone.maxCount",
+		Version: "2.6.0",
+		Doc: `The maximum number of tombstones to keep, 256 by default.
+Tombstones are used to reject duplicate submissions of DDL messages;
+keeping too few of them may lead to ABA issues in the state of the Milvus cluster.`,
+		DefaultValue: "256",
+		Export:       false,
+	}
+	p.WALBroadcasterTombstoneMaxCount.Init(base.mgr)
+
+	p.WALBroadcasterTombstoneMaxLifetime = ParamItem{
+		Key:     "streaming.walBroadcaster.tombstone.maxLifetime",
+		Version: "2.6.0",
+		Doc: `The maximum lifetime of a tombstone, 30m by default.
+Tombstones are used to reject duplicate submissions of DDL messages;
+keeping too few of them may lead to ABA issues in the state of the Milvus cluster.`,
+		DefaultValue: "30m",
+		Export:       false,
+	}
+	p.WALBroadcasterTombstoneMaxLifetime.Init(base.mgr)
+
 	// txn
 	p.TxnDefaultKeepaliveTimeout = ParamItem{
 		Key:     "streaming.txn.defaultKeepaliveTimeout",
diff --git a/pkg/util/paramtable/component_param_test.go b/pkg/util/paramtable/component_param_test.go
index f47bdc5e52..431824377f 100644
--- a/pkg/util/paramtable/component_param_test.go
+++ b/pkg/util/paramtable/component_param_test.go
@@ -665,6 +665,9 @@ func TestComponentParam(t *testing.T) {
 		assert.Equal(t, 3, params.StreamingCfg.WALBalancerPolicyVChannelFairRebalanceMaxStep.GetAsInt())
 		assert.Equal(t, 30*time.Second, params.StreamingCfg.WALBalancerOperationTimeout.GetAsDurationByParse())
 		assert.Equal(t, 1.0, params.StreamingCfg.WALBroadcasterConcurrencyRatio.GetAsFloat())
+		assert.Equal(t, 5*time.Minute, params.StreamingCfg.WALBroadcasterTombstoneCheckInternal.GetAsDurationByParse())
+		assert.Equal(t, 256, params.StreamingCfg.WALBroadcasterTombstoneMaxCount.GetAsInt())
+		assert.Equal(t, 30*time.Minute, params.StreamingCfg.WALBroadcasterTombstoneMaxLifetime.GetAsDurationByParse())
 		assert.Equal(t, 10*time.Second, params.StreamingCfg.TxnDefaultKeepaliveTimeout.GetAsDurationByParse())
 		assert.Equal(t, 30*time.Second, params.StreamingCfg.WALWriteAheadBufferKeepalive.GetAsDurationByParse())
 		assert.Equal(t, int64(64*1024*1024), params.StreamingCfg.WALWriteAheadBufferCapacity.GetAsSize())
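The three tombstone parameters defined above drive the scheduler's triggerGCTombstone loop earlier in this diff: tombstones are kept oldest-first, everything beyond WALBroadcasterTombstoneMaxCount or older than WALBroadcasterTombstoneMaxLifetime is dropped through DropTombstone, and the check re-runs on every WALBroadcasterTombstoneCheckInternal tick as well as after each newly added pending tombstone, since the background loop invokes it at the top of every iteration. A pure-function restatement of the cutoff rule, for clarity only and not the production code:

package tombstonesketch

import "time"

type tombstone struct {
	broadcastID uint64
	createTime  time.Time
}

// gcCutoff returns how many of the leading (oldest) tombstones should be
// dropped: enough to get back under maxCount, plus any that have outlived
// maxLifetime. The slice is assumed to be ordered from oldest to newest,
// as the scheduler maintains it.
func gcCutoff(tombstones []tombstone, maxCount int, maxLifetime time.Duration, now time.Time) int {
	cutoff := 0
	if len(tombstones) > maxCount {
		cutoff = len(tombstones) - maxCount
	}
	expiredBefore := now.Add(-maxLifetime)
	for cutoff < len(tombstones) && !tombstones[cutoff].createTime.After(expiredBefore) {
		cutoff++
	}
	return cutoff
}

With the defaults, at most 256 tombstones are retained, any tombstone older than 30 minutes is released, and the check fires at least every 5 minutes.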