diff --git a/go.mod b/go.mod
index bc5f8c1207..c0c887c1ef 100644
--- a/go.mod
+++ b/go.mod
@@ -86,6 +86,7 @@ require (
 	google.golang.org/protobuf v1.36.5
 	gopkg.in/yaml.v3 v3.0.1
 	mosn.io/holmes v1.0.2
+	mosn.io/pkg v0.0.0-20211217101631-d914102d1baf
 )
 
 require (
@@ -151,6 +152,8 @@ require (
 	github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
 	github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2 // indirect
 	github.com/docker/go-units v0.5.0 // indirect
+	github.com/dubbogo/getty v1.3.4 // indirect
+	github.com/dubbogo/gost v1.11.16 // indirect
 	github.com/dustin/go-humanize v1.0.1 // indirect
 	github.com/dvsekhvalnov/jose2go v1.6.0 // indirect
 	github.com/ebitengine/purego v0.8.1 // indirect
@@ -194,6 +197,7 @@ require (
 	github.com/ianlancetaylor/cgosymbolizer v0.0.0-20221217025313-27d3c9f66b6a // indirect
 	github.com/jonboulle/clockwork v0.2.2 // indirect
 	github.com/json-iterator/go v1.1.12 // indirect
+	github.com/k0kubun/pp v3.0.1+incompatible // indirect
 	github.com/klauspost/asmfmt v1.3.2 // indirect
 	github.com/klauspost/cpuid/v2 v2.2.8 // indirect
 	github.com/kr/pretty v0.3.1 // indirect
@@ -298,7 +302,6 @@ require (
 	k8s.io/klog/v2 v2.130.1 // indirect
 	k8s.io/utils v0.0.0-20250321185631-1f6e0b77f77e // indirect
 	mosn.io/api v0.0.0-20210204052134-5b9a826795fd // indirect
-	mosn.io/pkg v0.0.0-20211217101631-d914102d1baf // indirect
 	sigs.k8s.io/json v0.0.0-20241010143419-9aa6b5e7a4b3 // indirect
 	sigs.k8s.io/structured-merge-diff/v4 v4.4.2 // indirect
 	sigs.k8s.io/yaml v1.4.0 // indirect
diff --git a/go.sum b/go.sum
index 7365902c09..5ddf646ca4 100644
--- a/go.sum
+++ b/go.sum
@@ -303,9 +303,11 @@ github.com/docker/go-connections v0.5.0 h1:USnMq7hx7gwdVZq1L49hLXaFtUdTADjXGp+uj
 github.com/docker/go-connections v0.5.0/go.mod h1:ov60Kzw0kKElRwhNs9UlUHAE/F9Fe6GLaXnqyDdmEXc=
 github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4=
 github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
+github.com/dubbogo/getty v1.3.4 h1:5TvH213pnSIKYzY7IK8TT/r6yr5uPTB/U6YNLT+GsU0=
 github.com/dubbogo/getty v1.3.4/go.mod h1:36f+gH/ekaqcDWKbxNBQk9b9HXcGtaI6YHxp4YTntX8=
 github.com/dubbogo/go-zookeeper v1.0.3/go.mod h1:fn6n2CAEer3novYgk9ULLwAjuV8/g4DdC2ENwRb6E+c=
 github.com/dubbogo/gost v1.5.2/go.mod h1:pPTjVyoJan3aPxBPNUX0ADkXjPibLo+/Ib0/fADXSG8=
+github.com/dubbogo/gost v1.11.16 h1:fvOw8aKQ0BuUYuD+MaXAYFvT7tg2l7WAS5SL5gZJpFs=
 github.com/dubbogo/gost v1.11.16/go.mod h1:vIcP9rqz2KsXHPjsAwIUtfJIJjppQLQDcYaZTy/61jI=
 github.com/dubbogo/jsonparser v1.0.1/go.mod h1:tYAtpctvSP/tWw4MeelsowSPgXQRVHHWbqL6ynps8jU=
 github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk=
@@ -657,6 +659,7 @@ github.com/juju/cmd v0.0.0-20171107070456-e74f39857ca0/go.mod h1:yWJQHl73rdSX4DH
 github.com/juju/collections v0.0.0-20200605021417-0d0ec82b7271/go.mod h1:5XgO71dV1JClcOJE+4dzdn4HrI5LiyKd7PlVG6eZYhY=
 github.com/juju/errors v0.0.0-20150916125642-1b5e39b83d18/go.mod h1:W54LbzXuIE0boCoNJfwqpmkKJ1O4TCTZMetAt6jGk7Q=
 github.com/juju/errors v0.0.0-20190930114154-d42613fe1ab9/go.mod h1:W54LbzXuIE0boCoNJfwqpmkKJ1O4TCTZMetAt6jGk7Q=
+github.com/juju/errors v0.0.0-20200330140219-3fe23663418f h1:MCOvExGLpaSIzLYB4iQXEHP4jYVU6vmzLNQPdMVrxnM=
 github.com/juju/errors v0.0.0-20200330140219-3fe23663418f/go.mod h1:W54LbzXuIE0boCoNJfwqpmkKJ1O4TCTZMetAt6jGk7Q=
 github.com/juju/gnuflag v0.0.0-20171113085948-2ce1bb71843d/go.mod h1:2PavIy+JPciBPrBUjwbNvtwB6RQlve+hkpll6QSNmOE=
 github.com/juju/httpprof v0.0.0-20141217160036-14bf14c30767/go.mod h1:+MaLYz4PumRkkyHYeXJ2G5g5cIW0sli2bOfpmbaMV/g=
@@ -681,7 +684,9 @@ github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7V
 github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM=
 github.com/jung-kurt/gofpdf v1.0.0/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+r32jIOes=
 github.com/jung-kurt/gofpdf v1.0.3-0.20190309125859-24315acbbda5/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+r32jIOes=
+github.com/k0kubun/colorstring v0.0.0-20150214042306-9440f1994b88 h1:uC1QfSlInpQF+M0ao65imhwqKnz3Q2z/d8PWZRMQvDM=
 github.com/k0kubun/colorstring v0.0.0-20150214042306-9440f1994b88/go.mod h1:3w7q1U84EfirKl04SVQ/s7nPm1ZPhiXd34z40TNz36k=
+github.com/k0kubun/pp v3.0.1+incompatible h1:3tqvf7QgUnZ5tXO6pNAZlrvHgl6DvifjDrd9g2S9Z40=
 github.com/k0kubun/pp v3.0.1+incompatible/go.mod h1:GWse8YhT0p8pT4ir3ZgBbfZild3tgzSScAn6HmfYukg=
 github.com/kataras/golog v0.0.10/go.mod h1:yJ8YKCmyL+nWjERB90Qwn+bdyBZsaQwU3bTVFgkFIp8=
 github.com/kataras/iris/v12 v12.1.8/go.mod h1:LMYy4VlP67TQ3Zgriz8RE2h2kMZV2SgMYbq3UhfoFmE=
diff --git a/internal/.mockery.yaml b/internal/.mockery.yaml
index e6e9d1838c..8af207d675 100644
--- a/internal/.mockery.yaml
+++ b/internal/.mockery.yaml
@@ -57,6 +57,10 @@ packages:
       InterceptorWithReady:
       InterceptorWithMetrics:
       InterceptorBuilder:
+  github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/replicate/replicates:
+    interfaces:
+      ReplicatesManager:
+      ReplicateAcker:
   github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/shard/shards:
     interfaces:
       ShardManager:
diff --git a/internal/cdc/controller/controllerimpl/controller_impl.go b/internal/cdc/controller/controllerimpl/controller_impl.go
index b6225c9f1f..75abd9003c 100644
--- a/internal/cdc/controller/controllerimpl/controller_impl.go
+++ b/internal/cdc/controller/controllerimpl/controller_impl.go
@@ -64,20 +64,24 @@ func (c *controller) Start() {
 func (c *controller) Stop() {
 	c.stopOnce.Do(func() {
 		log.Ctx(c.ctx).Info("CDC controller stopping...")
-		// TODO: sheep, gracefully stop the replicators
 		close(c.stopChan)
 		c.wg.Wait()
+		resource.Resource().ReplicateManagerClient().Close()
 		log.Ctx(c.ctx).Info("CDC controller stopped")
 	})
 }
 
 func (c *controller) run() {
-	replicatePChannels, err := resource.Resource().ReplicationCatalog().ListReplicatePChannels(c.ctx)
+	targetReplicatePChannels, err := resource.Resource().ReplicationCatalog().ListReplicatePChannels(c.ctx)
 	if err != nil {
 		log.Ctx(c.ctx).Error("failed to get replicate pchannels", zap.Error(err))
 		return
 	}
-	for _, replicatePChannel := range replicatePChannels {
+	// create replicators for all replicate pchannels
+	for _, replicatePChannel := range targetReplicatePChannels {
 		resource.Resource().ReplicateManagerClient().CreateReplicator(replicatePChannel)
 	}
+
+	// remove out of target replicators
+	resource.Resource().ReplicateManagerClient().RemoveOutOfTargetReplicators(targetReplicatePChannels)
 }
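The Stop() change above replaces the old TODO with a concrete shutdown order: signal the run loop, wait for it to drain, then close the replicate manager client, so no goroutine can observe a closed client. A minimal sketch of that ordering, using hypothetical names rather than the Milvus types:

```go
package main

import (
	"fmt"
	"sync"
)

type worker struct {
	stopOnce sync.Once
	stopChan chan struct{}
	wg       sync.WaitGroup
}

func (w *worker) run() {
	defer w.wg.Done()
	<-w.stopChan // block until stop is signaled
}

func (w *worker) Stop(closeClient func()) {
	w.stopOnce.Do(func() {
		close(w.stopChan) // 1. signal all loops to exit
		w.wg.Wait()       // 2. wait until no loop can touch the client
		closeClient()     // 3. only now is it safe to release shared resources
	})
}

func main() {
	w := &worker{stopChan: make(chan struct{})}
	w.wg.Add(1)
	go w.run()
	w.Stop(func() { fmt.Println("client closed") })
}
```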
diff --git a/internal/cdc/controller/controllerimpl/controller_impl_test.go b/internal/cdc/controller/controllerimpl/controller_impl_test.go
index c60873dd7c..7945f2c7cc 100644
--- a/internal/cdc/controller/controllerimpl/controller_impl_test.go
+++ b/internal/cdc/controller/controllerimpl/controller_impl_test.go
@@ -30,6 +30,7 @@ import (
 
 func TestController_StartAndStop(t *testing.T) {
 	mockReplicateManagerClient := replication.NewMockReplicateManagerClient(t)
+	mockReplicateManagerClient.EXPECT().Close().Return()
 	resource.InitForTest(t,
 		resource.OptReplicateManagerClient(mockReplicateManagerClient),
 	)
@@ -45,6 +46,7 @@ func TestController_StartAndStop(t *testing.T) {
 
 func TestController_Run(t *testing.T) {
 	mockReplicateManagerClient := replication.NewMockReplicateManagerClient(t)
+	mockReplicateManagerClient.EXPECT().Close().Return()
 
 	replicatePChannels := []*streamingpb.ReplicatePChannelMeta{
 		{
@@ -55,6 +57,7 @@ func TestController_Run(t *testing.T) {
 	mockReplicationCatalog := mock_metastore.NewMockReplicationCatalog(t)
 	mockReplicationCatalog.EXPECT().ListReplicatePChannels(mock.Anything).Return(replicatePChannels, nil)
 	mockReplicateManagerClient.EXPECT().CreateReplicator(replicatePChannels[0]).Return()
+	mockReplicateManagerClient.EXPECT().RemoveOutOfTargetReplicators(replicatePChannels).Return()
 	resource.InitForTest(t,
 		resource.OptReplicateManagerClient(mockReplicateManagerClient),
 		resource.OptReplicationCatalog(mockReplicationCatalog),
@@ -68,6 +71,7 @@ func TestController_Run(t *testing.T) {
 
 func TestController_RunError(t *testing.T) {
 	mockReplicateManagerClient := replication.NewMockReplicateManagerClient(t)
+	mockReplicateManagerClient.EXPECT().Close().Return()
 	mockReplicationCatalog := mock_metastore.NewMockReplicationCatalog(t)
 	mockReplicationCatalog.EXPECT().ListReplicatePChannels(mock.Anything).Return(nil, assert.AnError)
diff --git a/internal/cdc/replication/mock_replicate_manager_client.go b/internal/cdc/replication/mock_replicate_manager_client.go
index 7943136ce7..88053815ff 100644
--- a/internal/cdc/replication/mock_replicate_manager_client.go
+++ b/internal/cdc/replication/mock_replicate_manager_client.go
@@ -20,6 +20,38 @@ func (_m *MockReplicateManagerClient) EXPECT() *MockReplicateManagerClient_Expec
 	return &MockReplicateManagerClient_Expecter{mock: &_m.Mock}
 }
 
+// Close provides a mock function with no fields
+func (_m *MockReplicateManagerClient) Close() {
+	_m.Called()
+}
+
+// MockReplicateManagerClient_Close_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Close'
+type MockReplicateManagerClient_Close_Call struct {
+	*mock.Call
+}
+
+// Close is a helper method to define mock.On call
+func (_e *MockReplicateManagerClient_Expecter) Close() *MockReplicateManagerClient_Close_Call {
+	return &MockReplicateManagerClient_Close_Call{Call: _e.mock.On("Close")}
+}
+
+func (_c *MockReplicateManagerClient_Close_Call) Run(run func()) *MockReplicateManagerClient_Close_Call {
+	_c.Call.Run(func(args mock.Arguments) {
+		run()
+	})
+	return _c
+}
+
+func (_c *MockReplicateManagerClient_Close_Call) Return() *MockReplicateManagerClient_Close_Call {
+	_c.Call.Return()
+	return _c
+}
+
+func (_c *MockReplicateManagerClient_Close_Call) RunAndReturn(run func()) *MockReplicateManagerClient_Close_Call {
+	_c.Run(run)
+	return _c
+}
+
 // CreateReplicator provides a mock function with given fields: replicateInfo
 func (_m *MockReplicateManagerClient) CreateReplicator(replicateInfo *streamingpb.ReplicatePChannelMeta) {
 	_m.Called(replicateInfo)
@@ -53,6 +85,39 @@ func (_c *MockReplicateManagerClient_CreateReplicator_Call) RunAndReturn(run fun
 	return _c
 }
 
+// RemoveOutOfTargetReplicators provides a mock function with given fields: targetReplicatePChannels
+func (_m *MockReplicateManagerClient) RemoveOutOfTargetReplicators(targetReplicatePChannels []*streamingpb.ReplicatePChannelMeta) {
+	_m.Called(targetReplicatePChannels)
+}
+
+// MockReplicateManagerClient_RemoveOutOfTargetReplicators_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RemoveOutOfTargetReplicators'
+type MockReplicateManagerClient_RemoveOutOfTargetReplicators_Call struct {
+	*mock.Call
+}
+
+// RemoveOutOfTargetReplicators is a helper method to define mock.On call
+//   - targetReplicatePChannels []*streamingpb.ReplicatePChannelMeta
+func (_e *MockReplicateManagerClient_Expecter) RemoveOutOfTargetReplicators(targetReplicatePChannels interface{}) *MockReplicateManagerClient_RemoveOutOfTargetReplicators_Call {
+	return &MockReplicateManagerClient_RemoveOutOfTargetReplicators_Call{Call: _e.mock.On("RemoveOutOfTargetReplicators", targetReplicatePChannels)}
+}
+
+func (_c *MockReplicateManagerClient_RemoveOutOfTargetReplicators_Call) Run(run func(targetReplicatePChannels []*streamingpb.ReplicatePChannelMeta)) *MockReplicateManagerClient_RemoveOutOfTargetReplicators_Call {
+	_c.Call.Run(func(args mock.Arguments) {
+		run(args[0].([]*streamingpb.ReplicatePChannelMeta))
+	})
+	return _c
+}
+
+func (_c *MockReplicateManagerClient_RemoveOutOfTargetReplicators_Call) Return() *MockReplicateManagerClient_RemoveOutOfTargetReplicators_Call {
+	_c.Call.Return()
+	return _c
+}
+
+func (_c *MockReplicateManagerClient_RemoveOutOfTargetReplicators_Call) RunAndReturn(run func([]*streamingpb.ReplicatePChannelMeta)) *MockReplicateManagerClient_RemoveOutOfTargetReplicators_Call {
+	_c.Run(run)
+	return _c
+}
+
 // NewMockReplicateManagerClient creates a new instance of MockReplicateManagerClient. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
 // The first argument is typically a *testing.T value.
 func NewMockReplicateManagerClient(t interface {
diff --git a/internal/cdc/replication/replicate_manager_client.go b/internal/cdc/replication/replicate_manager_client.go
index 7b069abb6a..6b40f64e31 100644
--- a/internal/cdc/replication/replicate_manager_client.go
+++ b/internal/cdc/replication/replicate_manager_client.go
@@ -22,4 +22,10 @@ import "github.com/milvus-io/milvus/pkg/v2/proto/streamingpb"
 type ReplicateManagerClient interface {
 	// CreateReplicator creates a new replicator for the replicate pchannel.
 	CreateReplicator(replicateInfo *streamingpb.ReplicatePChannelMeta)
+
+	// RemoveOutOfTargetReplicators removes replicators that are not in the target replicate pchannels.
+	RemoveOutOfTargetReplicators(targetReplicatePChannels []*streamingpb.ReplicatePChannelMeta)
+
+	// Close closes the replicate manager client.
+	Close()
 }
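Taken together, the three methods turn the controller's periodic run() into a level-triggered reconcile: CreateReplicator is a no-op for channels it already tracks, RemoveOutOfTargetReplicators prunes the rest, and Close tears everything down. A hedged sketch of a caller driving the interface; listTargets is a hypothetical stand-in for ReplicationCatalog.ListReplicatePChannels:

```go
package replication // sketch only; mirrors how controller.run() uses the interface

import (
	"context"

	"github.com/milvus-io/milvus/pkg/v2/proto/streamingpb"
)

// reconcile performs one pass of the desired-state loop.
func reconcile(ctx context.Context, client ReplicateManagerClient,
	listTargets func(context.Context) ([]*streamingpb.ReplicatePChannelMeta, error),
) error {
	targets, err := listTargets(ctx)
	if err != nil {
		return err
	}
	for _, t := range targets {
		client.CreateReplicator(t) // idempotent: existing replicators are kept
	}
	client.RemoveOutOfTargetReplicators(targets) // prune replicators no longer desired
	return nil
}
```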
diff --git a/internal/cdc/replication/replicatemanager/channel_replicator.go b/internal/cdc/replication/replicatemanager/channel_replicator.go
index 23a31816d1..4a0387d799 100644
--- a/internal/cdc/replication/replicatemanager/channel_replicator.go
+++ b/internal/cdc/replication/replicatemanager/channel_replicator.go
@@ -27,13 +27,13 @@ import (
 	"github.com/milvus-io/milvus/internal/cdc/replication/replicatestream"
 	"github.com/milvus-io/milvus/internal/cdc/resource"
 	"github.com/milvus-io/milvus/internal/distributed/streaming"
+	"github.com/milvus-io/milvus/internal/streamingnode/server/wal/utility"
 	"github.com/milvus-io/milvus/pkg/v2/log"
 	"github.com/milvus-io/milvus/pkg/v2/proto/streamingpb"
 	"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
 	"github.com/milvus-io/milvus/pkg/v2/streaming/util/message/adaptor"
 	"github.com/milvus-io/milvus/pkg/v2/streaming/util/options"
 	"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
-	"github.com/milvus-io/milvus/pkg/v2/util/replicateutil"
 	"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
 )
 
@@ -108,23 +108,15 @@ func (r *channelReplicator) replicateLoop() error {
 		zap.String("sourceChannel", r.replicateInfo.GetSourceChannelName()),
 		zap.String("targetChannel", r.replicateInfo.GetTargetChannelName()),
 	)
-	startFrom, err := r.getReplicateStartMessageID()
+	cp, err := r.getReplicateCheckpoint()
 	if err != nil {
 		return err
 	}
 	ch := make(adaptor.ChanMessageHandler, scannerHandlerChanSize)
-	var deliverPolicy options.DeliverPolicy
-	if startFrom == nil {
-		// No checkpoint found, seek from the earliest position
-		deliverPolicy = options.DeliverPolicyAll()
-	} else {
-		// Seek from the checkpoint
-		deliverPolicy = options.DeliverPolicyStartFrom(startFrom)
-	}
 	scanner := streaming.WAL().Read(r.ctx, streaming.ReadOption{
 		PChannel:       r.replicateInfo.GetSourceChannelName(),
-		DeliverPolicy:  deliverPolicy,
-		DeliverFilters: []options.DeliverFilter{},
+		DeliverPolicy:  options.DeliverPolicyStartFrom(cp.MessageID),
+		DeliverFilters: []options.DeliverFilter{options.DeliverFilterTimeTickGT(cp.TimeTick)},
 		MessageHandler: ch,
 	})
 	defer scanner.Close()
@@ -132,7 +124,7 @@ func (r *channelReplicator) replicateLoop() error {
 	rsc := r.createRscFunc(r.ctx, r.replicateInfo)
 	defer rsc.Close()
 
-	logger.Info("start replicate channel loop", zap.Any("startFrom", startFrom))
+	logger.Info("start replicate channel loop", zap.Any("startFrom", cp))
 
 	for {
 		select {
@@ -142,7 +134,9 @@ func (r *channelReplicator) replicateLoop() error {
 		case msg := <-ch:
 			// TODO: Should be done at streamingnode.
 			if msg.MessageType().IsSelfControlled() {
-				logger.Debug("skip self-controlled message", log.FieldMessage(msg))
+				if msg.MessageType() != message.MessageTypeTimeTick {
+					logger.Debug("skip self-controlled message", log.FieldMessage(msg))
+				}
 				continue
 			}
 			err := rsc.Replicate(msg)
@@ -150,18 +144,11 @@ func (r *channelReplicator) replicateLoop() error {
 				panic(fmt.Sprintf("replicate message failed due to unrecoverable error: %v", err))
 			}
 			logger.Debug("replicate message success", log.FieldMessage(msg))
-			if msg.MessageType() == message.MessageTypeAlterReplicateConfig {
-				roleChanged := r.handlePutReplicateConfigMessage(msg)
-				if roleChanged {
-					// Role changed, return and stop replicate.
-					return nil
-				}
-			}
 		}
 	}
 }
 
-func (r *channelReplicator) getReplicateStartMessageID() (message.MessageID, error) {
+func (r *channelReplicator) getReplicateCheckpoint() (*utility.ReplicateCheckpoint, error) {
 	logger := log.With(
 		zap.String("sourceChannel", r.replicateInfo.GetSourceChannelName()),
 		zap.String("targetChannel", r.replicateInfo.GetTargetChannelName()),
@@ -189,41 +176,20 @@ func (r *channelReplicator) getReplicateStartMessageID() (message.MessageID, err
 		}
 	}
 	if checkpoint == nil || checkpoint.MessageId == nil {
-		logger.Info("channel not found in replicate info, will start from the beginning")
-		return nil, nil
+		initializedCheckpoint := utility.NewReplicateCheckpointFromProto(r.replicateInfo.InitializedCheckpoint)
+		logger.Info("channel not found in replicate info, will start from the beginning",
+			zap.Stringer("messageID", initializedCheckpoint.MessageID),
+			zap.Uint64("timeTick", initializedCheckpoint.TimeTick),
+		)
+		return initializedCheckpoint, nil
 	}
-	startFrom := message.MustUnmarshalMessageID(checkpoint.GetMessageId())
+	cp := utility.NewReplicateCheckpointFromProto(checkpoint)
 	logger.Info("replicate messages from position",
-		zap.Any("checkpoint", checkpoint),
-		zap.Any("startFromMessageID", startFrom),
+		zap.Stringer("messageID", cp.MessageID),
+		zap.Uint64("timeTick", cp.TimeTick),
 	)
-	return startFrom, nil
-}
-
-func (r *channelReplicator) handlePutReplicateConfigMessage(msg message.ImmutableMessage) (roleChanged bool) {
-	logger := log.With(
-		zap.String("sourceChannel", r.replicateInfo.GetSourceChannelName()),
-		zap.String("targetChannel", r.replicateInfo.GetTargetChannelName()),
-	)
-	logger.Info("handle PutReplicateConfigMessage", log.FieldMessage(msg))
-	prcMsg := message.MustAsImmutableAlterReplicateConfigMessageV2(msg)
-	replicateConfig := prcMsg.Header().ReplicateConfiguration
-	currentClusterID := paramtable.Get().CommonCfg.ClusterPrefix.GetValue()
-	currentCluster := replicateutil.MustNewConfigHelper(currentClusterID, replicateConfig).GetCurrentCluster()
-	if currentCluster.Role() == replicateutil.RolePrimary {
-		logger.Info("primary cluster, skip handle PutReplicateConfigMessage")
-		return false
-	}
-	// Current cluster role changed, not primary cluster,
-	// we need to remove the replicate pchannel.
-	err := resource.Resource().ReplicationCatalog().RemoveReplicatePChannel(r.ctx,
-		r.replicateInfo.GetSourceChannelName(), r.replicateInfo.GetTargetChannelName())
-	if err != nil {
-		panic(fmt.Sprintf("failed to remove replicate pchannel: %v", err))
-	}
-	logger.Info("handle PutReplicateConfigMessage done, replicate pchannel removed")
-	return true
+	return cp, nil
 }
 
 func (r *channelReplicator) StopReplicate() {
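The scanner setup above changes the resume semantics: instead of "checkpoint message ID, or start from the beginning", there is now always a checkpoint (the persisted one, or the InitializedCheckpoint carried by the ReplicatePChannelMeta), and a time-tick filter rides along. Assuming DeliverPolicyStartFrom is inclusive of the checkpoint message, it is the DeliverFilterTimeTickGT filter that keeps already-confirmed messages from being replicated twice. A toy illustration of that rule:

```go
// Sketch of the resume rule, under the inclusivity assumption stated above.
package main

import "fmt"

type checkpoint struct {
	timeTick uint64
}

// shouldReplicate mirrors options.DeliverFilterTimeTickGT(cp.TimeTick):
// only messages strictly newer than the checkpoint time tick pass through.
func shouldReplicate(cp checkpoint, msgTimeTick uint64) bool {
	return msgTimeTick > cp.timeTick
}

func main() {
	cp := checkpoint{timeTick: 100}
	for _, tt := range []uint64{99, 100, 101} {
		fmt.Printf("timeTick=%d replicate=%v\n", tt, shouldReplicate(cp, tt))
	}
	// 99 and 100 are filtered out; only 101 is replicated again.
}
```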
diff --git a/internal/cdc/replication/replicatemanager/replicate_manager.go b/internal/cdc/replication/replicatemanager/replicate_manager.go
index 69da7d8d3e..226dcf6ba5 100644
--- a/internal/cdc/replication/replicatemanager/replicate_manager.go
+++ b/internal/cdc/replication/replicatemanager/replicate_manager.go
@@ -18,8 +18,10 @@ package replicatemanager
 
 import (
 	"context"
+	"fmt"
 	"strings"
 
+	"github.com/samber/lo"
 	"go.uber.org/zap"
 
 	"github.com/milvus-io/milvus/pkg/v2/log"
@@ -32,16 +34,22 @@ type replicateManager struct {
 	ctx context.Context
 	// replicators is a map of replicate pchannel name to ChannelReplicator.
-	replicators map[string]Replicator
+	replicators         map[string]Replicator
+	replicatorPChannels map[string]*streamingpb.ReplicatePChannelMeta
 }
 
 func NewReplicateManager() *replicateManager {
 	return &replicateManager{
-		ctx:         context.Background(),
-		replicators: make(map[string]Replicator),
+		ctx:                 context.Background(),
+		replicators:         make(map[string]Replicator),
+		replicatorPChannels: make(map[string]*streamingpb.ReplicatePChannelMeta),
 	}
 }
 
+func bindReplicatorKey(replicateInfo *streamingpb.ReplicatePChannelMeta) string {
+	return fmt.Sprintf("%s_%s", replicateInfo.GetSourceChannelName(), replicateInfo.GetTargetChannelName())
+}
+
 func (r *replicateManager) CreateReplicator(replicateInfo *streamingpb.ReplicatePChannelMeta) {
 	logger := log.With(
 		zap.String("sourceChannel", replicateInfo.GetSourceChannelName()),
@@ -52,13 +60,36 @@ func (r *replicateManager) CreateReplicator(replicateInfo *streamingpb.Replicate
 		// current cluster is not source cluster, skip create replicator
 		return
 	}
-	_, ok := r.replicators[replicateInfo.GetSourceChannelName()]
+	replicatorKey := bindReplicatorKey(replicateInfo)
+	_, ok := r.replicators[replicatorKey]
 	if ok {
 		logger.Debug("replicator already exists, skip create replicator")
 		return
 	}
 	replicator := NewChannelReplicator(replicateInfo)
 	replicator.StartReplicate()
-	r.replicators[replicateInfo.GetSourceChannelName()] = replicator
+	r.replicators[replicatorKey] = replicator
+	r.replicatorPChannels[replicatorKey] = replicateInfo
 	logger.Info("created replicator for replicate pchannel")
 }
+
+func (r *replicateManager) RemoveOutOfTargetReplicators(targetReplicatePChannels []*streamingpb.ReplicatePChannelMeta) {
+	targets := lo.KeyBy(targetReplicatePChannels, bindReplicatorKey)
+	for replicatorKey, replicator := range r.replicators {
+		if _, ok := targets[replicatorKey]; !ok {
+			// look up the stored meta for logging; the key is absent from `targets` here.
+			pchannelMeta := r.replicatorPChannels[replicatorKey]
+			replicator.StopReplicate()
+			delete(r.replicators, replicatorKey)
+			delete(r.replicatorPChannels, replicatorKey)
+			log.Info("removed replicator due to out of target",
+				zap.String("sourceChannel", pchannelMeta.GetSourceChannelName()),
+				zap.String("targetChannel", pchannelMeta.GetTargetChannelName()),
+			)
+		}
+	}
+}
+
+func (r *replicateManager) Close() {
+	for _, replicator := range r.replicators {
+		replicator.StopReplicate()
+	}
+}
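Keying replicators by source channel alone breaks down as soon as one source pchannel replicates to several target clusters; the new `source_target` composite key keeps one replicator per topology edge. A toy demonstration of the collision the old scheme allowed, with minimal hypothetical types:

```go
package main

import "fmt"

type edge struct{ source, target string }

func key(e edge) string { return e.source + "_" + e.target }

func main() {
	edges := []edge{
		{"pchannel-1", "cluster-a-pchannel-1"},
		{"pchannel-1", "cluster-b-pchannel-1"}, // same source, second target
	}
	replicators := map[string]edge{}
	for _, e := range edges {
		replicators[key(e)] = e // both edges survive under the composite key
	}
	fmt.Println(len(replicators)) // 2; a source-only key would silently keep 1
}
```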
manager.replicators["test-source-channel-1_test-target-channel-1"] assert.True(t, exists) assert.NotNil(t, replicator1) } diff --git a/internal/cdc/replication/replicatestream/metrics.go b/internal/cdc/replication/replicatestream/metrics.go index 7f3946a189..e52d5a97a8 100644 --- a/internal/cdc/replication/replicatestream/metrics.go +++ b/internal/cdc/replication/replicatestream/metrics.go @@ -17,10 +17,13 @@ package replicatestream import ( + "time" + "github.com/milvus-io/milvus/pkg/v2/metrics" streamingpb "github.com/milvus-io/milvus/pkg/v2/proto/streamingpb" message "github.com/milvus-io/milvus/pkg/v2/streaming/util/message" "github.com/milvus-io/milvus/pkg/v2/util/timerecord" + "github.com/milvus-io/milvus/pkg/v2/util/tsoutil" "github.com/milvus-io/milvus/pkg/v2/util/typeutil" ) @@ -83,6 +86,14 @@ func (m *replicateMetrics) OnConfirmed(msg message.ImmutableMessage) { m.replicateInfo.GetSourceChannelName(), m.replicateInfo.GetTargetChannelName(), ).Observe(float64(replicateDuration.Milliseconds())) + + now := time.Now() + confirmedTime := tsoutil.PhysicalTime(msg.TimeTick()) + lag := now.Sub(confirmedTime) + metrics.CDCReplicateLag.WithLabelValues( + m.replicateInfo.GetSourceChannelName(), + m.replicateInfo.GetTargetChannelName(), + ).Set(float64(lag.Milliseconds())) } func (m *replicateMetrics) OnConnect() { @@ -93,29 +104,29 @@ func (m *replicateMetrics) OnConnect() { } func (m *replicateMetrics) OnDisconnect() { - clusterID := m.replicateInfo.GetTargetCluster().GetClusterId() + targetClusterID := m.replicateInfo.GetTargetCluster().GetClusterId() metrics.CDCStreamRPCConnections.WithLabelValues( - clusterID, + targetClusterID, metrics.CDCStatusConnected, ).Dec() metrics.CDCStreamRPCConnections.WithLabelValues( - clusterID, + targetClusterID, metrics.CDCStatusDisconnected, ).Inc() } func (m *replicateMetrics) OnReconnect() { - clusterID := m.replicateInfo.GetTargetCluster().GetClusterId() + targetClusterID := m.replicateInfo.GetTargetCluster().GetClusterId() metrics.CDCStreamRPCConnections.WithLabelValues( - clusterID, + targetClusterID, metrics.CDCStatusDisconnected, ).Dec() metrics.CDCStreamRPCConnections.WithLabelValues( - clusterID, + targetClusterID, metrics.CDCStatusConnected, ).Inc() metrics.CDCStreamRPCReconnectTimes.WithLabelValues( - clusterID, + targetClusterID, ).Inc() } diff --git a/internal/cdc/replication/replicatestream/replicate_stream_client_impl.go b/internal/cdc/replication/replicatestream/replicate_stream_client_impl.go index 1c6a203a46..60ada98061 100644 --- a/internal/cdc/replication/replicatestream/replicate_stream_client_impl.go +++ b/internal/cdc/replication/replicatestream/replicate_stream_client_impl.go @@ -18,6 +18,7 @@ package replicatestream import ( "context" + "fmt" "sync" "time" @@ -32,6 +33,7 @@ import ( "github.com/milvus-io/milvus/pkg/v2/proto/streamingpb" "github.com/milvus-io/milvus/pkg/v2/streaming/util/message" "github.com/milvus-io/milvus/pkg/v2/util/paramtable" + "github.com/milvus-io/milvus/pkg/v2/util/replicateutil" ) const pendingMessageQueueLength = 128 @@ -86,13 +88,14 @@ func (r *replicateStreamClient) startInternal() { backoff.MaxElapsedTime = 0 backoff.Reset() - disconnect := func(stopCh chan struct{}, err error) { + disconnect := func(stopCh chan struct{}, err error) (reconnect bool) { r.metrics.OnDisconnect() close(stopCh) r.client.CloseSend() r.wg.Wait() time.Sleep(backoff.NextBackOff()) log.Warn("restart replicate stream client", zap.Error(err)) + return err != nil } for { @@ -131,9 +134,15 @@ func (r *replicateStreamClient) 
diff --git a/internal/cdc/replication/replicatestream/replicate_stream_client_impl.go b/internal/cdc/replication/replicatestream/replicate_stream_client_impl.go
index 1c6a203a46..60ada98061 100644
--- a/internal/cdc/replication/replicatestream/replicate_stream_client_impl.go
+++ b/internal/cdc/replication/replicatestream/replicate_stream_client_impl.go
@@ -18,6 +18,7 @@ package replicatestream
 
 import (
 	"context"
+	"fmt"
 	"sync"
 	"time"
 
@@ -32,6 +33,7 @@ import (
 	"github.com/milvus-io/milvus/pkg/v2/proto/streamingpb"
 	"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
 	"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
+	"github.com/milvus-io/milvus/pkg/v2/util/replicateutil"
 )
 
 const pendingMessageQueueLength = 128
@@ -86,13 +88,14 @@ func (r *replicateStreamClient) startInternal() {
 	backoff.MaxElapsedTime = 0
 	backoff.Reset()
 
-	disconnect := func(stopCh chan struct{}, err error) {
+	disconnect := func(stopCh chan struct{}, err error) (reconnect bool) {
 		r.metrics.OnDisconnect()
 		close(stopCh)
 		r.client.CloseSend()
 		r.wg.Wait()
 		time.Sleep(backoff.NextBackOff())
 		log.Warn("restart replicate stream client", zap.Error(err))
+		return err != nil
 	}
 
 	for {
@@ -131,9 +134,15 @@ func (r *replicateStreamClient) startInternal() {
 			r.wg.Wait()
 			return
 		case err := <-sendErrCh:
-			disconnect(stopCh, err)
+			reconnect := disconnect(stopCh, err)
+			if !reconnect {
+				return
+			}
 		case err := <-recvErrCh:
-			disconnect(stopCh, err)
+			reconnect := disconnect(stopCh, err)
+			if !reconnect {
+				return
+			}
 		}
 	}
 }
@@ -280,6 +289,13 @@ func (r *replicateStreamClient) recvLoop(stopCh <-chan struct{}) error {
 			if lastConfirmedMessageInfo != nil {
 				messages := r.pendingMessages.CleanupConfirmedMessages(lastConfirmedMessageInfo.GetConfirmedTimeTick())
 				for _, msg := range messages {
+					if msg.MessageType() == message.MessageTypeAlterReplicateConfig {
+						roleChanged := r.handleAlterReplicateConfigMessage(msg)
+						if roleChanged {
+							// Role changed, return and stop replicate.
+							return nil
+						}
+					}
 					r.metrics.OnConfirmed(msg)
 				}
 			}
@@ -287,6 +303,32 @@ func (r *replicateStreamClient) recvLoop(stopCh <-chan struct{}) error {
 	}
 }
 
+func (r *replicateStreamClient) handleAlterReplicateConfigMessage(msg message.ImmutableMessage) (roleChanged bool) {
+	logger := log.With(
+		zap.String("sourceChannel", r.replicateInfo.GetSourceChannelName()),
+		zap.String("targetChannel", r.replicateInfo.GetTargetChannelName()),
+	)
+	logger.Info("handle AlterReplicateConfigMessage", log.FieldMessage(msg))
+	prcMsg := message.MustAsImmutableAlterReplicateConfigMessageV2(msg)
+	replicateConfig := prcMsg.Header().ReplicateConfiguration
+	currentClusterID := paramtable.Get().CommonCfg.ClusterPrefix.GetValue()
+	currentCluster := replicateutil.MustNewConfigHelper(currentClusterID, replicateConfig).GetCurrentCluster()
+	_, err := currentCluster.GetTargetChannel(r.replicateInfo.GetSourceChannelName(),
+		r.replicateInfo.GetTargetCluster().GetClusterId())
+	if err != nil {
+		// Cannot find the target channel, it means that the `current->target` topology edge is removed,
+		// so we need to remove the replicate pchannel and stop replicate.
+		err := resource.Resource().ReplicationCatalog().RemoveReplicatePChannel(r.ctx, r.replicateInfo)
+		if err != nil {
+			panic(fmt.Sprintf("failed to remove replicate pchannel: %v", err))
+		}
+		logger.Info("handle AlterReplicateConfigMessage done, replicate pchannel removed")
+		return true
+	}
+	logger.Info("target channel found, skip handle AlterReplicateConfigMessage")
+	return false
+}
+
 func (r *replicateStreamClient) Close() {
 	r.cancel()
 	r.wg.Wait()
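Compared with the removed handlePutReplicateConfigMessage, the decision above no longer keys off the cluster role alone: it asks whether the current→target edge still exists in the new replicate configuration, and only when it is gone does it remove the pchannel and stop. A toy version of that edge test, with hypothetical config types standing in for replicateutil's helper:

```go
package main

import "fmt"

type topology map[string][]string // cluster -> replication targets

func edgeExists(t topology, current, target string) bool {
	for _, tgt := range t[current] {
		if tgt == target {
			return true
		}
	}
	return false
}

func main() {
	cfg := topology{"cluster-a": {"cluster-b"}}
	fmt.Println(edgeExists(cfg, "cluster-a", "cluster-b")) // true: keep replicating
	fmt.Println(edgeExists(cfg, "cluster-a", "cluster-c")) // false: remove pchannel, stop
}
```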
diff --git a/internal/distributed/streaming/replicate_service.go b/internal/distributed/streaming/replicate_service.go
index 0b63ffcad3..e454803d1b 100644
--- a/internal/distributed/streaming/replicate_service.go
+++ b/internal/distributed/streaming/replicate_service.go
@@ -119,6 +119,16 @@ func (s replicateService) overwriteReplicateMessage(ctx context.Context, msg mes
 			return nil, err
 		}
 	}
+
+	if funcutil.IsControlChannel(msg.VChannel()) {
+		assignments, err := s.streamingCoordClient.Assignment().GetLatestAssignments(ctx)
+		if err != nil {
+			return nil, err
+		}
+		if !strings.HasPrefix(msg.VChannel(), assignments.PChannelOfCChannel()) {
+			return nil, status.NewReplicateViolation("invalid control channel %s, expected pchannel %s", msg.VChannel(), assignments.PChannelOfCChannel())
+		}
+	}
 	return msg, nil
 }
diff --git a/internal/metastore/catalog.go b/internal/metastore/catalog.go
index 157980cccf..742025b288 100644
--- a/internal/metastore/catalog.go
+++ b/internal/metastore/catalog.go
@@ -212,7 +212,7 @@ type QueryCoordCatalog interface {
 
 type ReplicationCatalog interface {
 	// RemoveReplicatePChannel removes the replicate pchannel from metastore.
 	// Remove the task of CDC replication task of current cluster, should be called when a CDC replication task is finished.
-	RemoveReplicatePChannel(ctx context.Context, sourceChannelName, targetChannelName string) error
+	RemoveReplicatePChannel(ctx context.Context, meta *streamingpb.ReplicatePChannelMeta) error
 
 	// ListReplicatePChannels lists all replicate pchannels from metastore.
 	// every ReplicatePChannelMeta is a task of CDC replication task of current cluster which is a source cluster in replication topology.
diff --git a/internal/metastore/kv/streamingcoord/kv_catalog.go b/internal/metastore/kv/streamingcoord/kv_catalog.go
index ac409a64d6..2ac153bd48 100644
--- a/internal/metastore/kv/streamingcoord/kv_catalog.go
+++ b/internal/metastore/kv/streamingcoord/kv_catalog.go
@@ -216,8 +216,8 @@ func (c *catalog) GetReplicateConfiguration(ctx context.Context) (*streamingpb.R
 	return config, nil
 }
 
-func (c *catalog) RemoveReplicatePChannel(ctx context.Context, targetClusterID, sourceChannelName string) error {
-	key := buildReplicatePChannelPath(targetClusterID, sourceChannelName)
+func (c *catalog) RemoveReplicatePChannel(ctx context.Context, task *streamingpb.ReplicatePChannelMeta) error {
+	key := buildReplicatePChannelPath(task.GetTargetCluster().GetClusterId(), task.GetSourceChannelName())
 	return c.metaKV.Remove(ctx, key)
 }
diff --git a/internal/metastore/kv/streamingcoord/kv_catalog_test.go b/internal/metastore/kv/streamingcoord/kv_catalog_test.go
index c11771db2d..9c4bcd136f 100644
--- a/internal/metastore/kv/streamingcoord/kv_catalog_test.go
+++ b/internal/metastore/kv/streamingcoord/kv_catalog_test.go
@@ -242,7 +242,11 @@ func TestCatalog_ReplicationCatalog(t *testing.T) {
 	assert.Equal(t, infos[1].GetTargetChannelName(), "target-channel-2")
 	assert.Equal(t, infos[1].GetTargetCluster().GetClusterId(), "target-cluster")
 
-	err = catalog.RemoveReplicatePChannel(context.Background(), "target-cluster", "source-channel-1")
+	err = catalog.RemoveReplicatePChannel(context.Background(), &streamingpb.ReplicatePChannelMeta{
+		SourceChannelName: "source-channel-1",
+		TargetChannelName: "target-channel-1",
+		TargetCluster:     &commonpb.MilvusCluster{ClusterId: "target-cluster"},
+	})
 	assert.NoError(t, err)
 
 	infos, err = catalog.ListReplicatePChannels(context.Background())
diff --git a/internal/mocks/mock_metastore/mock_ReplicationCatalog.go b/internal/mocks/mock_metastore/mock_ReplicationCatalog.go
index b0740e5c64..f26a51961f 100644
--- a/internal/mocks/mock_metastore/mock_ReplicationCatalog.go
+++ b/internal/mocks/mock_metastore/mock_ReplicationCatalog.go
@@ -81,17 +81,17 @@ func (_c *MockReplicationCatalog_ListReplicatePChannels_Call) RunAndReturn(run f
 	return _c
 }
 
-// RemoveReplicatePChannel provides a mock function with given fields: ctx, sourceChannelName, targetChannelName
-func (_m *MockReplicationCatalog) RemoveReplicatePChannel(ctx context.Context, sourceChannelName string, targetChannelName string) error {
-	ret := _m.Called(ctx, sourceChannelName, targetChannelName)
+// RemoveReplicatePChannel provides a mock function with given fields: ctx, meta
+func (_m *MockReplicationCatalog) RemoveReplicatePChannel(ctx context.Context, meta *streamingpb.ReplicatePChannelMeta) error {
+	ret := _m.Called(ctx, meta)
 
 	if len(ret) == 0 {
 		panic("no return value specified for RemoveReplicatePChannel")
 	}
 
 	var r0 error
-	if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok {
-		r0 = rf(ctx, sourceChannelName, targetChannelName)
+	if rf, ok := ret.Get(0).(func(context.Context, *streamingpb.ReplicatePChannelMeta) error); ok {
+		r0 = rf(ctx, meta)
 	} else {
 		r0 = ret.Error(0)
 	}
@@ -106,15 +106,14 @@ type MockReplicationCatalog_RemoveReplicatePChannel_Call struct {
 
 // RemoveReplicatePChannel is a helper method to define mock.On call
 //   - ctx context.Context
-//   - sourceChannelName string
-//   - targetChannelName string
-func (_e *MockReplicationCatalog_Expecter) RemoveReplicatePChannel(ctx interface{}, sourceChannelName interface{}, targetChannelName interface{}) *MockReplicationCatalog_RemoveReplicatePChannel_Call {
-	return &MockReplicationCatalog_RemoveReplicatePChannel_Call{Call: _e.mock.On("RemoveReplicatePChannel", ctx, sourceChannelName, targetChannelName)}
+//   - meta *streamingpb.ReplicatePChannelMeta
+func (_e *MockReplicationCatalog_Expecter) RemoveReplicatePChannel(ctx interface{}, meta interface{}) *MockReplicationCatalog_RemoveReplicatePChannel_Call {
+	return &MockReplicationCatalog_RemoveReplicatePChannel_Call{Call: _e.mock.On("RemoveReplicatePChannel", ctx, meta)}
 }
 
-func (_c *MockReplicationCatalog_RemoveReplicatePChannel_Call) Run(run func(ctx context.Context, sourceChannelName string, targetChannelName string)) *MockReplicationCatalog_RemoveReplicatePChannel_Call {
+func (_c *MockReplicationCatalog_RemoveReplicatePChannel_Call) Run(run func(ctx context.Context, meta *streamingpb.ReplicatePChannelMeta)) *MockReplicationCatalog_RemoveReplicatePChannel_Call {
 	_c.Call.Run(func(args mock.Arguments) {
-		run(args[0].(context.Context), args[1].(string), args[2].(string))
+		run(args[0].(context.Context), args[1].(*streamingpb.ReplicatePChannelMeta))
 	})
 	return _c
 }
@@ -124,7 +123,7 @@ func (_c *MockReplicationCatalog_RemoveReplicatePChannel_Call) Return(_a0 error)
 	return _c
 }
 
-func (_c *MockReplicationCatalog_RemoveReplicatePChannel_Call) RunAndReturn(run func(context.Context, string, string) error) *MockReplicationCatalog_RemoveReplicatePChannel_Call {
+func (_c *MockReplicationCatalog_RemoveReplicatePChannel_Call) RunAndReturn(run func(context.Context, *streamingpb.ReplicatePChannelMeta) error) *MockReplicationCatalog_RemoveReplicatePChannel_Call {
 	_c.Call.Return(run)
 	return _c
 }
diff --git a/internal/mocks/mock_metastore/mock_StreamingCoordCataLog.go b/internal/mocks/mock_metastore/mock_StreamingCoordCataLog.go
index 3ded0842a1..b80e0d789d 100644
--- a/internal/mocks/mock_metastore/mock_StreamingCoordCataLog.go
+++ b/internal/mocks/mock_metastore/mock_StreamingCoordCataLog.go
@@ -371,17 +371,17 @@ func (_c *MockStreamingCoordCataLog_ListReplicatePChannels_Call) RunAndReturn(ru
 	return _c
 }
 
-// RemoveReplicatePChannel provides a mock function with given fields: ctx, sourceChannelName, targetChannelName
-func (_m *MockStreamingCoordCataLog) RemoveReplicatePChannel(ctx context.Context, sourceChannelName string, targetChannelName string) error {
-	ret := _m.Called(ctx, sourceChannelName, targetChannelName)
+// RemoveReplicatePChannel provides a mock function with given fields: ctx, meta
+func (_m *MockStreamingCoordCataLog) RemoveReplicatePChannel(ctx context.Context, meta *streamingpb.ReplicatePChannelMeta) error {
+	ret := _m.Called(ctx, meta)
 
 	if len(ret) == 0 {
 		panic("no return value specified for RemoveReplicatePChannel")
 	}
 
 	var r0 error
-	if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok {
-		r0 = rf(ctx, sourceChannelName, targetChannelName)
+	if rf, ok := ret.Get(0).(func(context.Context, *streamingpb.ReplicatePChannelMeta) error); ok {
+		r0 = rf(ctx, meta)
 	} else {
 		r0 = ret.Error(0)
 	}
@@ -396,15 +396,14 @@ type MockStreamingCoordCataLog_RemoveReplicatePChannel_Call struct {
 
 // RemoveReplicatePChannel is a helper method to define mock.On call
 //   - ctx context.Context
-//   - sourceChannelName string
-//   - targetChannelName string
-func (_e *MockStreamingCoordCataLog_Expecter) RemoveReplicatePChannel(ctx interface{}, sourceChannelName interface{}, targetChannelName interface{}) *MockStreamingCoordCataLog_RemoveReplicatePChannel_Call {
-	return &MockStreamingCoordCataLog_RemoveReplicatePChannel_Call{Call: _e.mock.On("RemoveReplicatePChannel", ctx, sourceChannelName, targetChannelName)}
+//   - meta *streamingpb.ReplicatePChannelMeta
+func (_e *MockStreamingCoordCataLog_Expecter) RemoveReplicatePChannel(ctx interface{}, meta interface{}) *MockStreamingCoordCataLog_RemoveReplicatePChannel_Call {
+	return &MockStreamingCoordCataLog_RemoveReplicatePChannel_Call{Call: _e.mock.On("RemoveReplicatePChannel", ctx, meta)}
 }
 
-func (_c *MockStreamingCoordCataLog_RemoveReplicatePChannel_Call) Run(run func(ctx context.Context, sourceChannelName string, targetChannelName string)) *MockStreamingCoordCataLog_RemoveReplicatePChannel_Call {
+func (_c *MockStreamingCoordCataLog_RemoveReplicatePChannel_Call) Run(run func(ctx context.Context, meta *streamingpb.ReplicatePChannelMeta)) *MockStreamingCoordCataLog_RemoveReplicatePChannel_Call {
 	_c.Call.Run(func(args mock.Arguments) {
-		run(args[0].(context.Context), args[1].(string), args[2].(string))
+		run(args[0].(context.Context), args[1].(*streamingpb.ReplicatePChannelMeta))
 	})
 	return _c
 }
@@ -414,7 +413,7 @@ func (_c *MockStreamingCoordCataLog_RemoveReplicatePChannel_Call) Return(_a0 err
 	return _c
 }
 
-func (_c *MockStreamingCoordCataLog_RemoveReplicatePChannel_Call) RunAndReturn(run func(context.Context, string, string) error) *MockStreamingCoordCataLog_RemoveReplicatePChannel_Call {
+func (_c *MockStreamingCoordCataLog_RemoveReplicatePChannel_Call) RunAndReturn(run func(context.Context, *streamingpb.ReplicatePChannelMeta) error) *MockStreamingCoordCataLog_RemoveReplicatePChannel_Call {
 	_c.Call.Return(run)
 	return _c
 }
diff --git a/internal/mocks/streamingnode/server/mock_wal/mock_WAL.go b/internal/mocks/streamingnode/server/mock_wal/mock_WAL.go
index 6a74cf209a..39f6ecc1f3 100644
--- a/internal/mocks/streamingnode/server/mock_wal/mock_WAL.go
+++ b/internal/mocks/streamingnode/server/mock_wal/mock_WAL.go
@@ -301,6 +301,63 @@ func (_c *MockWAL_GetLatestMVCCTimestamp_Call) RunAndReturn(run func(context.Con
 	return _c
 }
 
+// GetReplicateCheckpoint provides a mock function with no fields
+func (_m *MockWAL) GetReplicateCheckpoint() (*wal.ReplicateCheckpoint, error) {
+	ret := _m.Called()
+
+	if len(ret) == 0 {
+		panic("no return value specified for GetReplicateCheckpoint")
+	}
+
+	var r0 *wal.ReplicateCheckpoint
+	var r1 error
+	if rf, ok := ret.Get(0).(func() (*wal.ReplicateCheckpoint, error)); ok {
+		return rf()
+	}
+	if rf, ok := ret.Get(0).(func() *wal.ReplicateCheckpoint); ok {
+		r0 = rf()
+	} else {
+		if ret.Get(0) != nil {
+			r0 = ret.Get(0).(*wal.ReplicateCheckpoint)
+		}
+	}
+
+	if rf, ok := ret.Get(1).(func() error); ok {
+		r1 = rf()
+	} else {
+		r1 = ret.Error(1)
+	}
+
+	return r0, r1
+}
+
+// MockWAL_GetReplicateCheckpoint_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetReplicateCheckpoint'
+type MockWAL_GetReplicateCheckpoint_Call struct {
+	*mock.Call
+}
+
+// GetReplicateCheckpoint is a helper method to define mock.On call
+func (_e *MockWAL_Expecter) GetReplicateCheckpoint() *MockWAL_GetReplicateCheckpoint_Call {
+	return &MockWAL_GetReplicateCheckpoint_Call{Call: _e.mock.On("GetReplicateCheckpoint")}
+}
+
+func (_c *MockWAL_GetReplicateCheckpoint_Call) Run(run func()) *MockWAL_GetReplicateCheckpoint_Call {
+	_c.Call.Run(func(args mock.Arguments) {
+		run()
+	})
+	return _c
+}
+
+func (_c *MockWAL_GetReplicateCheckpoint_Call) Return(_a0 *wal.ReplicateCheckpoint, _a1 error) *MockWAL_GetReplicateCheckpoint_Call {
+	_c.Call.Return(_a0, _a1)
+	return _c
+}
+
+func (_c *MockWAL_GetReplicateCheckpoint_Call) RunAndReturn(run func() (*wal.ReplicateCheckpoint, error)) *MockWAL_GetReplicateCheckpoint_Call {
+	_c.Call.Return(run)
+	return _c
+}
+
 // IsAvailable provides a mock function with no fields
 func (_m *MockWAL) IsAvailable() bool {
 	ret := _m.Called()
diff --git a/internal/mocks/streamingnode/server/wal/interceptors/replicate/mock_replicates/mock_ReplicateAcker.go b/internal/mocks/streamingnode/server/wal/interceptors/replicate/mock_replicates/mock_ReplicateAcker.go
new file mode 100644
index 0000000000..79625a104e
--- /dev/null
+++ b/internal/mocks/streamingnode/server/wal/interceptors/replicate/mock_replicates/mock_ReplicateAcker.go
@@ -0,0 +1,65 @@
+// Code generated by mockery v2.53.3. DO NOT EDIT.
+
+package mock_replicates
+
+import mock "github.com/stretchr/testify/mock"
+
+// MockReplicateAcker is an autogenerated mock type for the ReplicateAcker type
+type MockReplicateAcker struct {
+	mock.Mock
+}
+
+type MockReplicateAcker_Expecter struct {
+	mock *mock.Mock
+}
+
+func (_m *MockReplicateAcker) EXPECT() *MockReplicateAcker_Expecter {
+	return &MockReplicateAcker_Expecter{mock: &_m.Mock}
+}
+
+// Ack provides a mock function with given fields: err
+func (_m *MockReplicateAcker) Ack(err error) {
+	_m.Called(err)
+}
+
+// MockReplicateAcker_Ack_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Ack'
+type MockReplicateAcker_Ack_Call struct {
+	*mock.Call
+}
+
+// Ack is a helper method to define mock.On call
+//   - err error
+func (_e *MockReplicateAcker_Expecter) Ack(err interface{}) *MockReplicateAcker_Ack_Call {
+	return &MockReplicateAcker_Ack_Call{Call: _e.mock.On("Ack", err)}
+}
+
+func (_c *MockReplicateAcker_Ack_Call) Run(run func(err error)) *MockReplicateAcker_Ack_Call {
+	_c.Call.Run(func(args mock.Arguments) {
+		run(args[0].(error))
+	})
+	return _c
+}
+
+func (_c *MockReplicateAcker_Ack_Call) Return() *MockReplicateAcker_Ack_Call {
+	_c.Call.Return()
+	return _c
+}
+
+func (_c *MockReplicateAcker_Ack_Call) RunAndReturn(run func(error)) *MockReplicateAcker_Ack_Call {
+	_c.Run(run)
+	return _c
+}
+
+// NewMockReplicateAcker creates a new instance of MockReplicateAcker. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
+// The first argument is typically a *testing.T value.
+func NewMockReplicateAcker(t interface {
+	mock.TestingT
+	Cleanup(func())
+}) *MockReplicateAcker {
+	mock := &MockReplicateAcker{}
+	mock.Mock.Test(t)
+
+	t.Cleanup(func() { mock.AssertExpectations(t) })
+
+	return mock
+}
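For reference, the generated expecter API is consumed like this in a test (a sketch; the real interceptor tests live beside the replicates package):

```go
package mock_replicates_test

import (
	"testing"

	"github.com/milvus-io/milvus/internal/mocks/streamingnode/server/wal/interceptors/replicate/mock_replicates"
)

func TestAckIsForwarded(t *testing.T) {
	acker := mock_replicates.NewMockReplicateAcker(t) // asserts expectations on cleanup
	acker.EXPECT().Ack(nil).Return()                  // expect exactly one Ack(nil)

	acker.Ack(nil) // code under test would do this after a successful append
}
```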
diff --git a/internal/mocks/streamingnode/server/wal/interceptors/replicate/mock_replicates/mock_ReplicatesManager.go b/internal/mocks/streamingnode/server/wal/interceptors/replicate/mock_replicates/mock_ReplicatesManager.go
new file mode 100644
index 0000000000..1a8a2969de
--- /dev/null
+++ b/internal/mocks/streamingnode/server/wal/interceptors/replicate/mock_replicates/mock_ReplicatesManager.go
@@ -0,0 +1,252 @@
+// Code generated by mockery v2.53.3. DO NOT EDIT.
+
+package mock_replicates
+
+import (
+	context "context"
+
+	message "github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
+	mock "github.com/stretchr/testify/mock"
+
+	replicates "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/replicate/replicates"
+
+	replicateutil "github.com/milvus-io/milvus/pkg/v2/util/replicateutil"
+
+	utility "github.com/milvus-io/milvus/internal/streamingnode/server/wal/utility"
+)
+
+// MockReplicatesManager is an autogenerated mock type for the ReplicatesManager type
+type MockReplicatesManager struct {
+	mock.Mock
+}
+
+type MockReplicatesManager_Expecter struct {
+	mock *mock.Mock
+}
+
+func (_m *MockReplicatesManager) EXPECT() *MockReplicatesManager_Expecter {
+	return &MockReplicatesManager_Expecter{mock: &_m.Mock}
+}
+
+// BeginReplicateMessage provides a mock function with given fields: ctx, msg
+func (_m *MockReplicatesManager) BeginReplicateMessage(ctx context.Context, msg message.MutableMessage) (replicates.ReplicateAcker, error) {
+	ret := _m.Called(ctx, msg)
+
+	if len(ret) == 0 {
+		panic("no return value specified for BeginReplicateMessage")
+	}
+
+	var r0 replicates.ReplicateAcker
+	var r1 error
+	if rf, ok := ret.Get(0).(func(context.Context, message.MutableMessage) (replicates.ReplicateAcker, error)); ok {
+		return rf(ctx, msg)
+	}
+	if rf, ok := ret.Get(0).(func(context.Context, message.MutableMessage) replicates.ReplicateAcker); ok {
+		r0 = rf(ctx, msg)
+	} else {
+		if ret.Get(0) != nil {
+			r0 = ret.Get(0).(replicates.ReplicateAcker)
+		}
+	}
+
+	if rf, ok := ret.Get(1).(func(context.Context, message.MutableMessage) error); ok {
+		r1 = rf(ctx, msg)
+	} else {
+		r1 = ret.Error(1)
+	}
+
+	return r0, r1
+}
+
+// MockReplicatesManager_BeginReplicateMessage_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'BeginReplicateMessage'
+type MockReplicatesManager_BeginReplicateMessage_Call struct {
+	*mock.Call
+}
+
+// BeginReplicateMessage is a helper method to define mock.On call
+//   - ctx context.Context
+//   - msg message.MutableMessage
+func (_e *MockReplicatesManager_Expecter) BeginReplicateMessage(ctx interface{}, msg interface{}) *MockReplicatesManager_BeginReplicateMessage_Call {
+	return &MockReplicatesManager_BeginReplicateMessage_Call{Call: _e.mock.On("BeginReplicateMessage", ctx, msg)}
+}
+
+func (_c *MockReplicatesManager_BeginReplicateMessage_Call) Run(run func(ctx context.Context, msg message.MutableMessage)) *MockReplicatesManager_BeginReplicateMessage_Call {
+	_c.Call.Run(func(args mock.Arguments) {
+		run(args[0].(context.Context), args[1].(message.MutableMessage))
+	})
+	return _c
+}
+
+func (_c *MockReplicatesManager_BeginReplicateMessage_Call) Return(_a0 replicates.ReplicateAcker, _a1 error) *MockReplicatesManager_BeginReplicateMessage_Call {
+	_c.Call.Return(_a0, _a1)
+	return _c
+}
+
+func (_c *MockReplicatesManager_BeginReplicateMessage_Call) RunAndReturn(run func(context.Context, message.MutableMessage) (replicates.ReplicateAcker, error)) *MockReplicatesManager_BeginReplicateMessage_Call {
+	_c.Call.Return(run)
+	return _c
+}
+
+// GetReplicateCheckpoint provides a mock function with no fields
+func (_m *MockReplicatesManager) GetReplicateCheckpoint() (*utility.ReplicateCheckpoint, error) {
+	ret := _m.Called()
+
+	if len(ret) == 0 {
+		panic("no return value specified for GetReplicateCheckpoint")
+	}
+
+	var r0 *utility.ReplicateCheckpoint
+	var r1 error
+	if rf, ok := ret.Get(0).(func() (*utility.ReplicateCheckpoint, error)); ok {
+		return rf()
+	}
+	if rf, ok := ret.Get(0).(func() *utility.ReplicateCheckpoint); ok {
+		r0 = rf()
+	} else {
+		if ret.Get(0) != nil {
+			r0 = ret.Get(0).(*utility.ReplicateCheckpoint)
+		}
+	}
+
+	if rf, ok := ret.Get(1).(func() error); ok {
+		r1 = rf()
+	} else {
+		r1 = ret.Error(1)
+	}
+
+	return r0, r1
+}
+
+// MockReplicatesManager_GetReplicateCheckpoint_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetReplicateCheckpoint'
+type MockReplicatesManager_GetReplicateCheckpoint_Call struct {
+	*mock.Call
+}
+
+// GetReplicateCheckpoint is a helper method to define mock.On call
+func (_e *MockReplicatesManager_Expecter) GetReplicateCheckpoint() *MockReplicatesManager_GetReplicateCheckpoint_Call {
+	return &MockReplicatesManager_GetReplicateCheckpoint_Call{Call: _e.mock.On("GetReplicateCheckpoint")}
+}
+
+func (_c *MockReplicatesManager_GetReplicateCheckpoint_Call) Run(run func()) *MockReplicatesManager_GetReplicateCheckpoint_Call {
+	_c.Call.Run(func(args mock.Arguments) {
+		run()
+	})
+	return _c
+}
+
+func (_c *MockReplicatesManager_GetReplicateCheckpoint_Call) Return(_a0 *utility.ReplicateCheckpoint, _a1 error) *MockReplicatesManager_GetReplicateCheckpoint_Call {
+	_c.Call.Return(_a0, _a1)
+	return _c
+}
+
+func (_c *MockReplicatesManager_GetReplicateCheckpoint_Call) RunAndReturn(run func() (*utility.ReplicateCheckpoint, error)) *MockReplicatesManager_GetReplicateCheckpoint_Call {
+	_c.Call.Return(run)
+	return _c
+}
+
+// Role provides a mock function with no fields
+func (_m *MockReplicatesManager) Role() replicateutil.Role {
+	ret := _m.Called()
+
+	if len(ret) == 0 {
+		panic("no return value specified for Role")
+	}
+
+	var r0 replicateutil.Role
+	if rf, ok := ret.Get(0).(func() replicateutil.Role); ok {
+		r0 = rf()
+	} else {
+		r0 = ret.Get(0).(replicateutil.Role)
+	}
+
+	return r0
+}
+
+// MockReplicatesManager_Role_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Role'
+type MockReplicatesManager_Role_Call struct {
+	*mock.Call
+}
+
+// Role is a helper method to define mock.On call
+func (_e *MockReplicatesManager_Expecter) Role() *MockReplicatesManager_Role_Call {
+	return &MockReplicatesManager_Role_Call{Call: _e.mock.On("Role")}
+}
+
+func (_c *MockReplicatesManager_Role_Call) Run(run func()) *MockReplicatesManager_Role_Call {
+	_c.Call.Run(func(args mock.Arguments) {
+		run()
+	})
+	return _c
+}
+
+func (_c *MockReplicatesManager_Role_Call) Return(_a0 replicateutil.Role) *MockReplicatesManager_Role_Call {
+	_c.Call.Return(_a0)
+	return _c
+}
+
+func (_c *MockReplicatesManager_Role_Call) RunAndReturn(run func() replicateutil.Role) *MockReplicatesManager_Role_Call {
+	_c.Call.Return(run)
+	return _c
+}
+
+// SwitchReplicateMode provides a mock function with given fields: ctx, msg
+func (_m *MockReplicatesManager) SwitchReplicateMode(ctx context.Context, msg message.MutableAlterReplicateConfigMessageV2) error {
+	ret := _m.Called(ctx, msg)
+
+	if len(ret) == 0 {
+		panic("no return value specified for SwitchReplicateMode")
+	}
+
+	var r0 error
+	if rf, ok := ret.Get(0).(func(context.Context, message.MutableAlterReplicateConfigMessageV2) error); ok {
+		r0 = rf(ctx, msg)
+	} else {
+		r0 = ret.Error(0)
+	}
+
+	return r0
+}
+
+// MockReplicatesManager_SwitchReplicateMode_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SwitchReplicateMode'
+type MockReplicatesManager_SwitchReplicateMode_Call struct {
+	*mock.Call
+}
+
+// SwitchReplicateMode is a helper method to define mock.On call
+//   - ctx context.Context
+//   - msg message.MutableAlterReplicateConfigMessageV2
+func (_e *MockReplicatesManager_Expecter) SwitchReplicateMode(ctx interface{}, msg interface{}) *MockReplicatesManager_SwitchReplicateMode_Call {
+	return &MockReplicatesManager_SwitchReplicateMode_Call{Call: _e.mock.On("SwitchReplicateMode", ctx, msg)}
+}
+
+func (_c *MockReplicatesManager_SwitchReplicateMode_Call) Run(run func(ctx context.Context, msg message.MutableAlterReplicateConfigMessageV2)) *MockReplicatesManager_SwitchReplicateMode_Call {
+	_c.Call.Run(func(args mock.Arguments) {
+		run(args[0].(context.Context), args[1].(message.MutableAlterReplicateConfigMessageV2))
+	})
+	return _c
+}
+
+func (_c *MockReplicatesManager_SwitchReplicateMode_Call) Return(_a0 error) *MockReplicatesManager_SwitchReplicateMode_Call {
+	_c.Call.Return(_a0)
+	return _c
+}
+
+func (_c *MockReplicatesManager_SwitchReplicateMode_Call) RunAndReturn(run func(context.Context, message.MutableAlterReplicateConfigMessageV2) error) *MockReplicatesManager_SwitchReplicateMode_Call {
+	_c.Call.Return(run)
+	return _c
+}
+
+// NewMockReplicatesManager creates a new instance of MockReplicatesManager. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
+// The first argument is typically a *testing.T value.
+func NewMockReplicatesManager(t interface {
+	mock.TestingT
+	Cleanup(func())
+},
+) *MockReplicatesManager {
+	mock := &MockReplicatesManager{}
+	mock.Mock.Test(t)
+
+	t.Cleanup(func() { mock.AssertExpectations(t) })
+
+	return mock
+}
diff --git a/internal/proxy/replicate/replicate_stream_server.go b/internal/proxy/replicate/replicate_stream_server.go
index adc3b37ea9..b646c24002 100644
--- a/internal/proxy/replicate/replicate_stream_server.go
+++ b/internal/proxy/replicate/replicate_stream_server.go
@@ -34,7 +34,7 @@ func CreateReplicateServer(streamServer milvuspb.MilvusService_CreateReplicateSt
 type ReplicateStreamServer struct {
 	clusterID       string
 	streamServer    milvuspb.MilvusService_CreateReplicateStreamServer
-	replicateRespCh chan *milvuspb.ReplicateResponse // All processing messages result should sent from theses channel.
+	replicateRespCh chan *milvuspb.ReplicateResponse
 	wg              sync.WaitGroup
 }
 
@@ -111,7 +111,6 @@ func (p *ReplicateStreamServer) recvLoop() (err error) {
 
 // handleReplicateMessage handles the replicate message request.
 func (p *ReplicateStreamServer) handleReplicateMessage(req *milvuspb.ReplicateRequest_ReplicateMessage) error {
-	// TODO: sheep, update metrics.
 	p.wg.Add(1)
 	defer p.wg.Done()
 	reqMsg := req.ReplicateMessage.GetMessage()
diff --git a/internal/streamingcoord/server/service/assignment.go b/internal/streamingcoord/server/service/assignment.go
index c40db68a0d..c067238c4a 100644
--- a/internal/streamingcoord/server/service/assignment.go
+++ b/internal/streamingcoord/server/service/assignment.go
@@ -32,7 +32,7 @@ func NewAssignmentService(
 		listenerTotal: metrics.StreamingCoordAssignmentListenerTotal.WithLabelValues(paramtable.GetStringNodeID()),
 	}
 	// TODO: after recovering from wal, add it to here.
-	// registry.RegisterPutReplicateConfigV2AckCallback(assignmentService.putReplicateConfiguration)
+	// registry.RegisterAlterReplicateConfigV2AckCallback(assignmentService.AlterReplicateConfiguration)
 	return assignmentService
 }
 
@@ -83,7 +83,7 @@ func (s *assignmentServiceImpl) UpdateReplicateConfiguration(ctx context.Context
 	}
 
 	// TODO: After recovering from wal, remove the operation here.
-	if err := s.putReplicateConfiguration(ctx, mockMessages...); err != nil {
+	if err := s.AlterReplicateConfiguration(ctx, mockMessages...); err != nil {
 		return nil, err
 	}
 	return &streamingpb.UpdateReplicateConfigurationResponse{}, nil
 }
@@ -130,9 +130,9 @@ func (s *assignmentServiceImpl) validateReplicateConfiguration(ctx context.Conte
 	return b, nil
 }
 
-// putReplicateConfiguration puts the replicate configuration into the balancer.
+// AlterReplicateConfiguration puts the replicate configuration into the balancer.
 // It's a callback function of the broadcast service.
-func (s *assignmentServiceImpl) putReplicateConfiguration(ctx context.Context, msgs ...message.ImmutableAlterReplicateConfigMessageV2) error {
+func (s *assignmentServiceImpl) AlterReplicateConfiguration(ctx context.Context, msgs ...message.ImmutableAlterReplicateConfigMessageV2) error {
 	balancer, err := s.balancer.GetWithContext(ctx)
 	if err != nil {
 		return err
diff --git a/internal/streamingnode/client/handler/handler_client_impl.go b/internal/streamingnode/client/handler/handler_client_impl.go
index ed031648d5..4786e834c9 100644
--- a/internal/streamingnode/client/handler/handler_client_impl.go
+++ b/internal/streamingnode/client/handler/handler_client_impl.go
@@ -13,6 +13,7 @@ import (
 	"github.com/milvus-io/milvus/internal/streamingnode/client/handler/producer"
 	"github.com/milvus-io/milvus/internal/streamingnode/client/handler/registry"
 	"github.com/milvus-io/milvus/internal/streamingnode/server/wal"
+	"github.com/milvus-io/milvus/internal/streamingnode/server/wal/utility"
 	"github.com/milvus-io/milvus/internal/util/streamingutil/service/balancer/picker"
 	"github.com/milvus-io/milvus/internal/util/streamingutil/service/lazygrpc"
 	"github.com/milvus-io/milvus/internal/util/streamingutil/service/resolver"
@@ -65,16 +66,41 @@ func (hc *handlerClientImpl) GetLatestMVCCTimestampIfLocal(ctx context.Context,
 	return w.GetLatestMVCCTimestamp(ctx, vchannel)
 }
 
-// GetReplicateCheckpoint returns the WAL checkpoint that will be used to create scanner.
-func (hc *handlerClientImpl) GetReplicateCheckpoint(ctx context.Context, channelName string) (*wal.ReplicateCheckpoint, error) {
+// GetReplicateCheckpoint gets the replicate checkpoint of the wal.
+func (hc *handlerClientImpl) GetReplicateCheckpoint(ctx context.Context, pchannel string) (*wal.ReplicateCheckpoint, error) {
 	if !hc.lifetime.Add(typeutil.LifetimeStateWorking) {
 		return nil, ErrClientClosed
 	}
 	defer hc.lifetime.Done()
-	return nil, nil
-
-	// TODO: sheep, implement it.
+ logger := log.With(zap.String("pchannel", pchannel), zap.String("handler", "replicate checkpoint")) + cp, err := hc.createHandlerAfterStreamingNodeReady(ctx, logger, pchannel, func(ctx context.Context, assign *types.PChannelInfoAssigned) (any, error) { + if assign.Channel.AccessMode != types.AccessModeRW { + return nil, errors.New("replicate checkpoint can only be read for RW channel") + } + localWAL, err := registry.GetLocalAvailableWAL(assign.Channel) + if err == nil { + return localWAL.GetReplicateCheckpoint() + } + if !shouldUseRemoteWAL(err) { + return nil, err + } + handlerService, err := hc.service.GetService(ctx) + if err != nil { + return nil, err + } + resp, err := handlerService.GetReplicateCheckpoint(ctx, &streamingpb.GetReplicateCheckpointRequest{ + Pchannel: types.NewProtoFromPChannelInfo(assign.Channel), + }) + if err != nil { + return nil, err + } + return utility.NewReplicateCheckpointFromProto(resp.Checkpoint), nil + }) + if err != nil { + return nil, err + } + return cp.(*wal.ReplicateCheckpoint), nil } // GetWALMetricsIfLocal gets the metrics of the local wal. diff --git a/internal/streamingnode/client/handler/handler_client_test.go b/internal/streamingnode/client/handler/handler_client_test.go index 397046ff27..ab8f6fce66 100644 --- a/internal/streamingnode/client/handler/handler_client_test.go +++ b/internal/streamingnode/client/handler/handler_client_test.go @@ -8,6 +8,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus/internal/mocks/streamingnode/client/handler/mock_assignment" "github.com/milvus-io/milvus/internal/mocks/streamingnode/client/handler/mock_consumer" "github.com/milvus-io/milvus/internal/mocks/streamingnode/client/handler/mock_producer" @@ -34,6 +35,14 @@ func TestHandlerClient(t *testing.T) { service := mock_lazygrpc.NewMockService[streamingpb.StreamingNodeHandlerServiceClient](t) handlerServiceClient := mock_streamingpb.NewMockStreamingNodeHandlerServiceClient(t) + handlerServiceClient.EXPECT().GetReplicateCheckpoint(mock.Anything, mock.Anything).Return(&streamingpb.GetReplicateCheckpointResponse{ + Checkpoint: &commonpb.ReplicateCheckpoint{ + ClusterId: "pchannel", + Pchannel: "pchannel", + MessageId: nil, + TimeTick: 0, + }, + }, nil) service.EXPECT().GetService(mock.Anything).Return(handlerServiceClient, nil) rb := mock_resolver.NewMockBuilder(t) rb.EXPECT().Close().Run(func() {}) @@ -91,6 +100,10 @@ func TestHandlerClient(t *testing.T) { producer2.Close() producer3.Close() + rcp, err := handler.GetReplicateCheckpoint(ctx, "pchannel") + assert.NoError(t, err) + assert.NotNil(t, rcp) + handler.GetLatestMVCCTimestampIfLocal(ctx, "pchannel") producer4, err := handler.CreateProducer(ctx, &ProducerOptions{PChannel: "pchannel"}) assert.NoError(t, err) diff --git a/internal/streamingnode/server/flusher/flusherimpl/wal_flusher.go b/internal/streamingnode/server/flusher/flusherimpl/wal_flusher.go index 7cd78b9c7f..7a47d8cf5b 100644 --- a/internal/streamingnode/server/flusher/flusherimpl/wal_flusher.go +++ b/internal/streamingnode/server/flusher/flusherimpl/wal_flusher.go @@ -13,6 +13,7 @@ import ( "github.com/milvus-io/milvus/internal/streamingnode/server/resource" "github.com/milvus-io/milvus/internal/streamingnode/server/wal" "github.com/milvus-io/milvus/internal/streamingnode/server/wal/recovery" + "github.com/milvus-io/milvus/internal/streamingnode/server/wal/utility" "github.com/milvus-io/milvus/pkg/v2/log" 
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message" "github.com/milvus-io/milvus/pkg/v2/streaming/util/message/adaptor" @@ -160,7 +161,7 @@ func (impl *WALFlusherImpl) buildFlusherComponents(ctx context.Context, l wal.WA impl.RecoveryStorage.UpdateFlusherCheckpoint(mp.ChannelName, &recovery.WALCheckpoint{ MessageID: messageID, TimeTick: mp.Timestamp, - Magic: recovery.RecoveryMagicStreamingInitialized, + Magic: utility.RecoveryMagicStreamingInitialized, }) }) go cpUpdater.Start() diff --git a/internal/streamingnode/server/service/handler.go b/internal/streamingnode/server/service/handler.go index 53e3626e43..e971ad4222 100644 --- a/internal/streamingnode/server/service/handler.go +++ b/internal/streamingnode/server/service/handler.go @@ -7,6 +7,7 @@ import ( "github.com/milvus-io/milvus/internal/streamingnode/server/service/handler/producer" "github.com/milvus-io/milvus/internal/streamingnode/server/walmanager" "github.com/milvus-io/milvus/pkg/v2/proto/streamingpb" + "github.com/milvus-io/milvus/pkg/v2/streaming/util/types" ) var _ HandlerService = (*handlerServiceImpl)(nil) @@ -27,12 +28,22 @@ type HandlerService = streamingpb.StreamingNodeHandlerServiceServer // 2. wait wal handling result and transform it into grpc response (convert error into grpc error) // 3. send response to client. type handlerServiceImpl struct { + streamingpb.UnimplementedStreamingNodeHandlerServiceServer + walManager walmanager.Manager } -// GetReplicateCheckpoint returns the WAL checkpoint that will be used to create scanner +// GetReplicateCheckpoint returns the replicate checkpoint of the wal. func (hs *handlerServiceImpl) GetReplicateCheckpoint(ctx context.Context, req *streamingpb.GetReplicateCheckpointRequest) (*streamingpb.GetReplicateCheckpointResponse, error) { - panic("not implemented") // TODO: sheep, implement it. + wal, err := hs.walManager.GetAvailableWAL(types.NewPChannelInfoFromProto(req.GetPchannel())) + if err != nil { + return nil, err + } + cp, err := wal.GetReplicateCheckpoint() + if err != nil { + return nil, err + } + return &streamingpb.GetReplicateCheckpointResponse{Checkpoint: cp.IntoProto()}, nil } // Produce creates a new producer for the channel on this log node. 
diff --git a/internal/streamingnode/server/wal/adaptor/opener.go b/internal/streamingnode/server/wal/adaptor/opener.go index 1f1849dfc3..eb264caea3 100644 --- a/internal/streamingnode/server/wal/adaptor/opener.go +++ b/internal/streamingnode/server/wal/adaptor/opener.go @@ -10,6 +10,7 @@ import ( "github.com/milvus-io/milvus/internal/streamingnode/server/resource" "github.com/milvus-io/milvus/internal/streamingnode/server/wal" "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors" + "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/replicate/replicates" "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/shard/shards" "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/txn" "github.com/milvus-io/milvus/internal/streamingnode/server/wal/recovery" @@ -17,6 +18,7 @@ import ( "github.com/milvus-io/milvus/pkg/v2/log" "github.com/milvus-io/milvus/pkg/v2/streaming/util/types" "github.com/milvus-io/milvus/pkg/v2/streaming/walimpls" + "github.com/milvus-io/milvus/pkg/v2/util/paramtable" "github.com/milvus-io/milvus/pkg/v2/util/typeutil" ) @@ -107,6 +109,15 @@ func (o *openerAdaptorImpl) openRWWAL(ctx context.Context, l walimpls.WALImpls, InitialRecoverSnapshot: snapshot, TxnManager: param.TxnManager, }) + if param.ReplicateManager, err = replicates.RecoverReplicateManager( + &replicates.ReplicateManagerRecoverParam{ + ChannelInfo: param.ChannelInfo, + CurrentClusterID: paramtable.Get().CommonCfg.ClusterPrefix.GetValue(), + InitialRecoverSnapshot: snapshot, + }, + ); err != nil { + return nil, err + } // start the flusher to flush and generate recovery info. var flusher *flusherimpl.WALFlusherImpl diff --git a/internal/streamingnode/server/wal/adaptor/ro_wal_adaptor.go b/internal/streamingnode/server/wal/adaptor/ro_wal_adaptor.go index 4fbd727226..d1a5e19431 100644 --- a/internal/streamingnode/server/wal/adaptor/ro_wal_adaptor.go +++ b/internal/streamingnode/server/wal/adaptor/ro_wal_adaptor.go @@ -50,6 +50,10 @@ func (w *roWALAdaptorImpl) GetLatestMVCCTimestamp(ctx context.Context, vchannel panic("we cannot acquire lastest mvcc timestamp from a read only wal") } +func (w *roWALAdaptorImpl) GetReplicateCheckpoint() (*wal.ReplicateCheckpoint, error) { + panic("we cannot get replicate checkpoint from a read only wal") +} + // Append writes a record to the log. func (w *roWALAdaptorImpl) Append(ctx context.Context, msg message.MutableMessage) (*wal.AppendResult, error) { panic("we cannot append message into a read only wal") diff --git a/internal/streamingnode/server/wal/adaptor/wal_adaptor.go b/internal/streamingnode/server/wal/adaptor/wal_adaptor.go index 3c68501175..35f8379429 100644 --- a/internal/streamingnode/server/wal/adaptor/wal_adaptor.go +++ b/internal/streamingnode/server/wal/adaptor/wal_adaptor.go @@ -121,6 +121,16 @@ func (w *walAdaptorImpl) GetLatestMVCCTimestamp(ctx context.Context, vchannel st return currentMVCC.Timetick, nil } +// GetReplicateCheckpoint returns the replicate checkpoint of the wal. +func (w *walAdaptorImpl) GetReplicateCheckpoint() (*utility.ReplicateCheckpoint, error) { + if !w.lifetime.Add(typeutil.LifetimeStateWorking) { + return nil, status.NewOnShutdownError("wal is on shutdown") + } + defer w.lifetime.Done() + + return w.param.ReplicateManager.GetReplicateCheckpoint() +} + // Append writes a record to the log. 
func (w *walAdaptorImpl) Append(ctx context.Context, msg message.MutableMessage) (*wal.AppendResult, error) { if !w.lifetime.Add(typeutil.LifetimeStateWorking) { diff --git a/internal/streamingnode/server/wal/adaptor/wal_test.go b/internal/streamingnode/server/wal/adaptor/wal_test.go index b3444cf638..02aa63a398 100644 --- a/internal/streamingnode/server/wal/adaptor/wal_test.go +++ b/internal/streamingnode/server/wal/adaptor/wal_test.go @@ -22,6 +22,7 @@ import ( "github.com/milvus-io/milvus/internal/streamingnode/server/wal" "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/lock" "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/redo" + "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/replicate" "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/shard" "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/timetick" "github.com/milvus-io/milvus/internal/streamingnode/server/wal/registry" @@ -57,6 +58,7 @@ func TestWAL(t *testing.T) { b := registry.MustGetBuilder(message.WALNameTest, redo.NewInterceptorBuilder(), lock.NewInterceptorBuilder(), + replicate.NewInterceptorBuilder(), timetick.NewInterceptorBuilder(), shard.NewInterceptorBuilder(), ) @@ -181,6 +183,10 @@ func (f *testOneWALFramework) Run() { } func (f *testOneWALFramework) testReadAndWrite(ctx context.Context, rwWAL wal.WAL, roWAL wal.ROWAL) { + cp, err := rwWAL.GetReplicateCheckpoint() + assert.True(f.t, status.AsStreamingError(err).IsReplicateViolation()) + assert.Nil(f.t, cp) + f.testSendCreateCollection(ctx, rwWAL) defer f.testSendDropCollection(ctx, rwWAL) diff --git a/internal/streamingnode/server/wal/interceptors/interceptor.go b/internal/streamingnode/server/wal/interceptors/interceptor.go index d76b52bfc7..846d4f6a71 100644 --- a/internal/streamingnode/server/wal/interceptors/interceptor.go +++ b/internal/streamingnode/server/wal/interceptors/interceptor.go @@ -4,6 +4,7 @@ import ( "context" "github.com/milvus-io/milvus/internal/streamingnode/server/wal" + "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/replicate/replicates" "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/shard/shards" "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/timetick/mvcc" "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/txn" @@ -22,13 +23,14 @@ type ( // InterceptorBuildParam is the parameter to build a interceptor. type InterceptorBuildParam struct { ChannelInfo types.PChannelInfo - WAL *syncutil.Future[wal.WAL] // The wal final object, can be used after interceptor is ready. - LastTimeTickMessage message.ImmutableMessage // The last time tick message in wal. - WriteAheadBuffer *wab.WriteAheadBuffer // The write ahead buffer for the wal, used to erase the subscription of underlying wal. - MVCCManager *mvcc.MVCCManager // The MVCC manager for the wal, can be used to get the latest mvcc timetick. - InitialRecoverSnapshot *recovery.RecoverySnapshot // The initial recover snapshot for the wal, used to recover the wal state. - TxnManager *txn.TxnManager // The transaction manager for the wal, used to manage the transactions. - ShardManager shards.ShardManager // The shard manager for the wal, used to manage the shards, segment assignment, partition. + WAL *syncutil.Future[wal.WAL] // The wal final object, can be used after interceptor is ready. 
+	LastTimeTickMessage     message.ImmutableMessage    // The last time tick message in wal.
+	WriteAheadBuffer        *wab.WriteAheadBuffer       // The write ahead buffer for the wal, used to erase the subscription of underlying wal.
+	MVCCManager             *mvcc.MVCCManager           // The MVCC manager for the wal, can be used to get the latest mvcc timetick.
+	InitialRecoverSnapshot  *recovery.RecoverySnapshot  // The initial recover snapshot for the wal, used to recover the wal state.
+	TxnManager              *txn.TxnManager             // The transaction manager for the wal, used to manage the transactions.
+	ShardManager            shards.ShardManager         // The shard manager for the wal, used to manage the shards, segment assignment, partition.
+	ReplicateManager        replicates.ReplicateManager // The replicate manager for the wal, used to manage the replicate state.
 }
 // Clear release the resources in the interceptor build param.
diff --git a/internal/streamingnode/server/wal/interceptors/lock/builder.go b/internal/streamingnode/server/wal/interceptors/lock/builder.go
index 7eb53bdcda..15c45ed31b 100644
--- a/internal/streamingnode/server/wal/interceptors/lock/builder.go
+++ b/internal/streamingnode/server/wal/interceptors/lock/builder.go
@@ -16,6 +16,7 @@ type interceptorBuilder struct{}
 // Build creates a new redo interceptor.
 func (b *interceptorBuilder) Build(param *interceptors.InterceptorBuildParam) interceptors.Interceptor {
 	return &lockAppendInterceptor{
+		channel:        param.ChannelInfo,
 		vchannelLocker: lock.NewKeyLock[string](),
 		txnManager:     param.TxnManager,
 	}
diff --git a/internal/streamingnode/server/wal/interceptors/lock/lock_interceptor.go b/internal/streamingnode/server/wal/interceptors/lock/lock_interceptor.go
index d8f0c63b84..ab2a11b688 100644
--- a/internal/streamingnode/server/wal/interceptors/lock/lock_interceptor.go
+++ b/internal/streamingnode/server/wal/interceptors/lock/lock_interceptor.go
@@ -2,14 +2,18 @@ package lock
 import (
 	"context"
+	"sync"
 	"github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors"
 	"github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/txn"
 	"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
+	"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
 	"github.com/milvus-io/milvus/pkg/v2/util/lock"
 )
 type lockAppendInterceptor struct {
+	channel        types.PChannelInfo
+	glock          sync.RWMutex // glock is a wal-level lock; wal-scoped exclusive messages take it as the highest-level lock.
 	vchannelLocker *lock.KeyLock[string]
 	txnManager     *txn.TxnManager
 }
@@ -26,23 +30,34 @@ func (r *lockAppendInterceptor) acquireLockGuard(_ context.Context, msg message.
 	// Acquire the write lock for the vchannel.
 	vchannel := msg.VChannel()
 	if msg.MessageType().IsExclusiveRequired() {
-		r.vchannelLocker.Lock(vchannel)
-		return func() {
-			// For exclusive messages, we need to fail all transactions at the vchannel.
-			// Otherwise, the transaction message may cross the exclusive message.
-			// e.g. an exclusive message like `ManualFlush` happens, it will flush all the growing segment.
-			// But the transaction insert message that use those segments may not be committed,
-			// if we allow it to be committed, a insert message can be seen after the manual flush message, lead to the wrong wal message order.
-			// So we need to fail all transactions at the vchannel, it will be retried at client side with new txn.
-			//
-			// the append operation of exclusive message should be low rate, so it's acceptable to fail all transactions at the vchannel.
-			r.txnManager.FailTxnAtVChannel(vchannel)
-			r.vchannelLocker.Unlock(vchannel)
+		if vchannel == "" || vchannel == r.channel.Name {
+			r.glock.Lock()
+			return func() {
+				// fail all transactions at all vchannels.
+				r.txnManager.FailTxnAtVChannel("")
+				r.glock.Unlock()
+			}
+		} else {
+			r.vchannelLocker.Lock(vchannel)
+			return func() {
+				// For exclusive messages, we need to fail all transactions at the vchannel.
+				// Otherwise, a transaction message may cross the exclusive message.
+				// e.g. when an exclusive message like `ManualFlush` happens, it flushes all the growing segments.
+				// But a transaction insert message that uses those segments may not be committed yet;
+				// if we allowed it to be committed, an insert message could be seen after the manual flush message, leading to a wrong wal message order.
+				// So we fail all transactions at the vchannel; they will be retried at the client side with a new txn.
+				//
+				// exclusive messages are appended at a low rate, so it's acceptable to fail all transactions at the vchannel.
+				r.txnManager.FailTxnAtVChannel(vchannel)
+				r.vchannelLocker.Unlock(vchannel)
+			}
+		}
 	}
+	r.glock.RLock()
 	r.vchannelLocker.RLock(vchannel)
 	return func() {
 		r.vchannelLocker.RUnlock(vchannel)
+		r.glock.RUnlock()
 	}
 }
diff --git a/internal/streamingnode/server/wal/interceptors/replicate/builder.go b/internal/streamingnode/server/wal/interceptors/replicate/builder.go
new file mode 100644
index 0000000000..e77be520ac
--- /dev/null
+++ b/internal/streamingnode/server/wal/interceptors/replicate/builder.go
@@ -0,0 +1,15 @@
+package replicate
+
+import "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors"
+
+func NewInterceptorBuilder() interceptors.InterceptorBuilder {
+	return &interceptorBuilder{}
+}
+
+type interceptorBuilder struct{}
+
+func (b *interceptorBuilder) Build(param *interceptors.InterceptorBuildParam) interceptors.Interceptor {
+	return &replicateInterceptor{
+		replicateManager: param.ReplicateManager,
+	}
+}
diff --git a/internal/streamingnode/server/wal/interceptors/replicate/replicate_interceptor.go b/internal/streamingnode/server/wal/interceptors/replicate/replicate_interceptor.go
new file mode 100644
index 0000000000..d563393472
--- /dev/null
+++ b/internal/streamingnode/server/wal/interceptors/replicate/replicate_interceptor.go
@@ -0,0 +1,51 @@
+package replicate
+
+import (
+	"context"
+
+	"github.com/cockroachdb/errors"
+
+	"github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors"
+	"github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/replicate/replicates"
+	"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
+)
+
+const interceptorName = "replicate"
+
+type replicateInterceptor struct {
+	replicateManager replicates.ReplicateManager
+}
+
+func (impl *replicateInterceptor) Name() string {
+	return interceptorName
+}
+
+func (impl *replicateInterceptor) DoAppend(ctx context.Context, msg message.MutableMessage, appendOp interceptors.Append) (msgID message.MessageID, err error) {
+	if msg.MessageType() == message.MessageTypeAlterReplicateConfig {
+		// An AlterReplicateConfig message is protected by the wal-level lock,
+		// so it's safe to switch the replicate mode before appending it.
+		alterReplicateConfig := message.MustAsMutableAlterReplicateConfigMessageV2(msg)
+		if err := impl.replicateManager.SwitchReplicateMode(ctx, alterReplicateConfig); err != nil {
+			return nil, err
+		}
+		return appendOp(ctx, msg)
+	}
+
+	// Begin to replicate the message.
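+	// The manager decides from the current role and the replicate header whether this
+	// message is in its scope; ErrNotHandledByReplicateManager means the message bypasses
+	// the manager and is appended into the wal directly.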
+ acker, err := impl.replicateManager.BeginReplicateMessage(ctx, msg) + if errors.Is(err, replicates.ErrNotHandledByReplicateManager) { + // the message is not handled by replicate manager, write it into wal directly. + return appendOp(ctx, msg) + } + if err != nil { + return nil, err + } + + defer func() { + acker.Ack(err) + }() + return appendOp(ctx, msg) +} + +func (impl *replicateInterceptor) Close() { +} diff --git a/internal/streamingnode/server/wal/interceptors/replicate/replicate_interceptor_test.go b/internal/streamingnode/server/wal/interceptors/replicate/replicate_interceptor_test.go new file mode 100644 index 0000000000..42d29bd54e --- /dev/null +++ b/internal/streamingnode/server/wal/interceptors/replicate/replicate_interceptor_test.go @@ -0,0 +1,71 @@ +package replicate + +import ( + "context" + "testing" + + "github.com/cockroachdb/errors" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + + "github.com/milvus-io/milvus/internal/mocks/streamingnode/server/wal/interceptors/replicate/mock_replicates" + "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors" + "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/replicate/replicates" + "github.com/milvus-io/milvus/pkg/v2/streaming/util/message" + "github.com/milvus-io/milvus/pkg/v2/streaming/walimpls/impls/walimplstest" +) + +func TestReplicateInterceptor(t *testing.T) { + manager := mock_replicates.NewMockReplicatesManager(t) + acker := mock_replicates.NewMockReplicateAcker(t) + manager.EXPECT().SwitchReplicateMode(mock.Anything, mock.Anything).Return(nil) + manager.EXPECT().BeginReplicateMessage(mock.Anything, mock.Anything).Return(acker, nil) + acker.EXPECT().Ack(mock.Anything).Return() + + interceptor := NewInterceptorBuilder().Build(&interceptors.InterceptorBuildParam{ + ReplicateManager: manager, + }) + mutableMsg := message.NewAlterReplicateConfigMessageBuilderV2(). + WithHeader(&message.AlterReplicateConfigMessageHeader{}). + WithBody(&message.AlterReplicateConfigMessageBody{}). + WithAllVChannel(). + MustBuildMutable() + + msgID, err := interceptor.DoAppend(context.Background(), mutableMsg, func(ctx context.Context, msg message.MutableMessage) (message.MessageID, error) { + return walimplstest.NewTestMessageID(1), nil + }) + assert.NoError(t, err) + assert.NotNil(t, msgID) + + mutableMsg2 := message.NewCreateDatabaseMessageBuilderV2(). + WithHeader(&message.CreateDatabaseMessageHeader{}). + WithBody(&message.CreateDatabaseMessageBody{}). + WithVChannel("test"). 
+		MustBuildMutable()
+
+	msgID2, err := interceptor.DoAppend(context.Background(), mutableMsg2, func(ctx context.Context, msg message.MutableMessage) (message.MessageID, error) {
+		return walimplstest.NewTestMessageID(2), nil
+	})
+	assert.NoError(t, err)
+	assert.NotNil(t, msgID2)
+
+	manager.EXPECT().BeginReplicateMessage(mock.Anything, mock.Anything).Unset()
+	manager.EXPECT().BeginReplicateMessage(mock.Anything, mock.Anything).Return(nil, replicates.ErrNotHandledByReplicateManager)
+
+	msgID3, err := interceptor.DoAppend(context.Background(), mutableMsg2, func(ctx context.Context, msg message.MutableMessage) (message.MessageID, error) {
+		return walimplstest.NewTestMessageID(3), nil
+	})
+	assert.NoError(t, err)
+	assert.NotNil(t, msgID3)
+
+	manager.EXPECT().BeginReplicateMessage(mock.Anything, mock.Anything).Unset()
+	manager.EXPECT().BeginReplicateMessage(mock.Anything, mock.Anything).Return(nil, errors.New("test"))
+
+	msgID4, err := interceptor.DoAppend(context.Background(), mutableMsg2, func(ctx context.Context, msg message.MutableMessage) (message.MessageID, error) {
+		return walimplstest.NewTestMessageID(4), nil
+	})
+	assert.Error(t, err)
+	assert.Nil(t, msgID4)
+
+	interceptor.Close()
+}
diff --git a/internal/streamingnode/server/wal/interceptors/replicate/replicates/impl.go b/internal/streamingnode/server/wal/interceptors/replicate/replicates/impl.go
new file mode 100644
index 0000000000..2e775c653c
--- /dev/null
+++ b/internal/streamingnode/server/wal/interceptors/replicate/replicates/impl.go
@@ -0,0 +1,240 @@
+package replicates
+
+import (
+	"context"
+	"sync"
+
+	"github.com/cockroachdb/errors"
+	"google.golang.org/protobuf/encoding/protojson"
+
+	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
+	"github.com/milvus-io/milvus/internal/streamingnode/server/wal/recovery"
+	"github.com/milvus-io/milvus/internal/streamingnode/server/wal/utility"
+	"github.com/milvus-io/milvus/internal/util/streamingutil/status"
+	"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
+	"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
+	"github.com/milvus-io/milvus/pkg/v2/util/replicateutil"
+)
+
+// ErrNotHandledByReplicateManager is a special error to indicate that the message should not be handled by the replicate manager.
+var ErrNotHandledByReplicateManager = errors.New("not handled by replicate manager")
+
+// ReplicateManagerRecoverParam is the parameter for recovering the replicate manager.
+type ReplicateManagerRecoverParam struct {
+	ChannelInfo            types.PChannelInfo
+	CurrentClusterID       string
+	InitialRecoverSnapshot *recovery.RecoverySnapshot // the initial recover snapshot of the replicate manager.
+}
+
+// RecoverReplicateManager recovers the replicate manager from the initial recover snapshot.
+// If the wal is in replicating mode, it also recovers the secondary replicate state.
+func RecoverReplicateManager(param *ReplicateManagerRecoverParam) (ReplicateManager, error) {
+	replicateConfigHelper, err := replicateutil.NewConfigHelper(param.CurrentClusterID, param.InitialRecoverSnapshot.Checkpoint.ReplicateConfig)
+	if err != nil {
+		return nil, newReplicateViolationErrorForConfig(param.InitialRecoverSnapshot.Checkpoint.ReplicateConfig, err)
+	}
+	rm := &replicatesManagerImpl{
+		mu:                    sync.Mutex{},
+		currentClusterID:      param.CurrentClusterID,
+		pchannel:              param.ChannelInfo,
+		replicateConfigHelper: replicateConfigHelper,
+	}
+	if !rm.isPrimaryRole() {
+		// if the current cluster is not the primary role,
+		// recover the secondary state for it.
+		if rm.secondaryState, err = recoverSecondaryState(param); err != nil {
+			return nil, err
+		}
+	}
+	return rm, nil
+}
+
+// replicatesManagerImpl is the implementation of the replicates manager.
+type replicatesManagerImpl struct {
+	mu                    sync.Mutex
+	pchannel              types.PChannelInfo
+	currentClusterID      string
+	replicateConfigHelper *replicateutil.ConfigHelper
+	secondaryState        *secondaryState // if the current cluster is not the primary role, it will have secondaryState.
+}
+
+// SwitchReplicateMode switches the replicates manager between replicating mode and non-replicating mode.
+func (impl *replicatesManagerImpl) SwitchReplicateMode(_ context.Context, msg message.MutableAlterReplicateConfigMessageV2) error {
+	impl.mu.Lock()
+	defer impl.mu.Unlock()
+
+	newCfg := msg.Header().ReplicateConfiguration
+	newGraph, err := replicateutil.NewConfigHelper(impl.currentClusterID, newCfg)
+	if err != nil {
+		return newReplicateViolationErrorForConfig(newCfg, err)
+	}
+	incomingCurrentClusterConfig := newGraph.GetCurrentCluster()
+	switch incomingCurrentClusterConfig.Role() {
+	case replicateutil.RolePrimary:
+		// drop the replicating state if the current cluster is switched to primary.
+		impl.secondaryState = nil
+	case replicateutil.RoleSecondary:
+		if impl.isPrimaryRole() || impl.secondaryState.SourceClusterID() != incomingCurrentClusterConfig.SourceCluster().GetClusterId() {
+			// Only update the replicating state when the current cluster switches from primary to secondary,
+			// or when the source cluster is changed.
+			impl.secondaryState = newSecondaryState(
+				incomingCurrentClusterConfig.SourceCluster().GetClusterId(),
+				incomingCurrentClusterConfig.MustGetSourceChannel(impl.pchannel.Name),
+			)
+		}
+	}
+	impl.replicateConfigHelper = newGraph
+	return nil
+}
+
+func (impl *replicatesManagerImpl) BeginReplicateMessage(ctx context.Context, msg message.MutableMessage) (g ReplicateAcker, err error) {
+	rh := msg.ReplicateHeader()
+	// some message types, like timetick, create segment and flush, are generated by the wal itself;
+	// they should never be handled by the replicates manager.
+	if msg.MessageType().IsSelfControlled() {
+		if rh != nil {
+			return nil, status.NewIgnoreOperation("wal self-controlled message cannot be replicated")
+		}
+		return nil, ErrNotHandledByReplicateManager
+	}
+
+	impl.mu.Lock()
+	defer func() {
+		if err != nil {
+			impl.mu.Unlock()
+		}
+	}()
+
+	switch impl.getRole() {
+	case replicateutil.RolePrimary:
+		if rh != nil {
+			return nil, status.NewReplicateViolation("replicate message cannot be received in primary role")
+		}
+		return nil, ErrNotHandledByReplicateManager
+	case replicateutil.RoleSecondary:
+		if rh == nil {
+			return nil, status.NewReplicateViolation("non-replicate message cannot be received in secondary role")
+		}
+		return impl.beginReplicateMessage(ctx, msg)
+	default:
+		panic("unreachable: invalid role")
+	}
+}
+
+// GetReplicateCheckpoint gets the replicate checkpoint.
+func (impl *replicatesManagerImpl) GetReplicateCheckpoint() (*utility.ReplicateCheckpoint, error) {
+	impl.mu.Lock()
+	defer impl.mu.Unlock()
+
+	if impl.isPrimaryRole() {
+		return nil, status.NewReplicateViolation("wal is not a secondary cluster in replicating topology")
+	}
+	return impl.secondaryState.GetCheckpoint(), nil
+}
+
+// beginReplicateMessage begins the replicate message operation.
+func (impl *replicatesManagerImpl) beginReplicateMessage(ctx context.Context, msg message.MutableMessage) (ReplicateAcker, error) {
+	rh := msg.ReplicateHeader()
+	if rh.ClusterID != impl.secondaryState.SourceClusterID() {
+		return nil, status.NewReplicateViolation("cluster id mismatch, current: %s, expected: %s", rh.ClusterID, impl.secondaryState.SourceClusterID())
+	}
+
+	// if the incoming message's time tick is not newer than the checkpoint's time tick,
+	// the message has already been written to the wal, so it can be ignored.
+	// txn body messages share the txn's time tick, so they are only filtered with <;
+	// duplicates inside the txn are removed by the txnHelper.
+	isTxnBody := msg.TxnContext() != nil && msg.MessageType() != message.MessageTypeBeginTxn
+	if (isTxnBody && rh.TimeTick < impl.secondaryState.GetCheckpoint().TimeTick) || (!isTxnBody && rh.TimeTick <= impl.secondaryState.GetCheckpoint().TimeTick) {
+		return nil, status.NewIgnoreOperation("message is too old, message_id: %s, time_tick: %d, txn: %t, current time tick: %d",
+			rh.MessageID, rh.TimeTick, isTxnBody, impl.secondaryState.GetCheckpoint().TimeTick)
+	}
+
+	if msg.TxnContext() != nil {
+		return impl.startReplicateTxnMessage(ctx, msg, rh)
+	}
+	return impl.startReplicateNonTxnMessage(ctx, msg, rh)
+}
+
+// startReplicateTxnMessage starts the replicate txn message operation.
+func (impl *replicatesManagerImpl) startReplicateTxnMessage(_ context.Context, msg message.MutableMessage, rh *message.ReplicateHeader) (ReplicateAcker, error) {
+	txn := msg.TxnContext()
+	switch msg.MessageType() {
+	case message.MessageTypeBeginTxn:
+		if err := impl.secondaryState.StartBegin(txn, rh); err != nil {
+			return nil, err
+		}
+		return replicateAckerImpl(func(err error) {
+			if err == nil {
+				impl.secondaryState.BeginDone(txn)
+			}
+			impl.mu.Unlock()
+		}), nil
+	case message.MessageTypeCommitTxn:
+		if err := impl.secondaryState.StartCommit(txn); err != nil {
+			return nil, err
+		}
+		// only update the checkpoint when the txn is committed.
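+		// txn body messages only record their ids for deduplication; the replicate
+		// checkpoint advances here, once the whole txn is known to be durable.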
+		return replicateAckerImpl(func(err error) {
+			if err == nil {
+				impl.secondaryState.CommitDone(txn)
+				impl.secondaryState.PushForwardCheckpoint(rh.TimeTick, rh.LastConfirmedMessageID)
+			}
+			impl.mu.Unlock()
+		}), nil
+	case message.MessageTypeRollbackTxn:
+		panic("unreachable: rollback txn message should never be replicated when wal is in replicating mode")
+	default:
+		if err := impl.secondaryState.AddNewMessage(txn, rh); err != nil {
+			return nil, err
+		}
+		return replicateAckerImpl(func(err error) {
+			if err == nil {
+				impl.secondaryState.AddNewMessageDone(rh)
+			}
+			impl.mu.Unlock()
+		}), nil
+	}
+}
+
+// startReplicateNonTxnMessage starts the replicate non-txn message operation.
+func (impl *replicatesManagerImpl) startReplicateNonTxnMessage(_ context.Context, _ message.MutableMessage, rh *message.ReplicateHeader) (ReplicateAcker, error) {
+	if impl.secondaryState.CurrentTxn() != nil {
+		return nil, status.NewReplicateViolation(
+			"txn is in progress, so the incoming message must be a txn message, current txn: %d",
+			impl.secondaryState.CurrentTxn().TxnID,
+		)
+	}
+	return replicateAckerImpl(func(err error) {
+		if err == nil {
+			impl.secondaryState.PushForwardCheckpoint(rh.TimeTick, rh.LastConfirmedMessageID)
+		}
+		impl.mu.Unlock()
+	}), nil
+}
+
+// Role returns the role of the current cluster in the replicate topology.
+func (impl *replicatesManagerImpl) Role() replicateutil.Role {
+	impl.mu.Lock()
+	defer impl.mu.Unlock()
+
+	return impl.getRole()
+}
+
+// getRole returns the role of the current cluster in the replicate topology.
+func (impl *replicatesManagerImpl) getRole() replicateutil.Role {
+	if impl.replicateConfigHelper == nil {
+		return replicateutil.RolePrimary
+	}
+	return impl.replicateConfigHelper.MustGetCluster(impl.currentClusterID).Role()
+}
+
+// isPrimaryRole checks if the current cluster is the primary role.
+func (impl *replicatesManagerImpl) isPrimaryRole() bool {
+	return impl.getRole() == replicateutil.RolePrimary
+}
+
+// newReplicateViolationErrorForConfig creates a new replicate violation error for the given configuration and error.
+func newReplicateViolationErrorForConfig(cfg *commonpb.ReplicateConfiguration, err error) error {
+	bytes, _ := protojson.Marshal(cfg)
+	return status.NewReplicateViolation("when creating replicate graph, %s, %s", string(bytes), err.Error())
+}
diff --git a/internal/streamingnode/server/wal/interceptors/replicate/replicates/manager.go b/internal/streamingnode/server/wal/interceptors/replicate/replicates/manager.go
new file mode 100644
index 0000000000..595e26a52c
--- /dev/null
+++ b/internal/streamingnode/server/wal/interceptors/replicate/replicates/manager.go
@@ -0,0 +1,48 @@
+package replicates
+
+import (
+	"context"
+
+	"github.com/milvus-io/milvus/internal/streamingnode/server/wal/utility"
+	"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
+	"github.com/milvus-io/milvus/pkg/v2/util/replicateutil"
+)
+
+type replicateAckerImpl func(err error)
+
+func (r replicateAckerImpl) Ack(err error) {
+	r(err)
+}
+
+// ReplicateAcker is a guard for a replicate message.
+type ReplicateAcker interface {
+	// Ack acknowledges that the replicate message operation is done.
+	// It will push forward the in-memory checkpoint if the err is nil.
+	Ack(err error)
+}
+
+// ReplicateManager manages the replicate operation on one wal.
+// There are two states:
+// 1. primary: the wal only accepts non-replicate messages.
+// 2. secondary: the wal only accepts replicate messages.
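+// The role is derived from the replicate configuration carried by the latest
+// AlterReplicateConfig message; a wal without any replicate configuration acts as primary.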
+type ReplicateManager interface {
+	// Role returns the role of the replicate manager.
+	Role() replicateutil.Role
+
+	// SwitchReplicateMode switches the replicate mode.
+	// The following cases can happen:
+	// 1. primary->secondary: transits into replicating mode; messages without a replicate header will be rejected.
+	// 2. primary->primary: nothing happens.
+	// 3. secondary->primary: transits into non-replicating mode; the secondary replica state (remote cluster replicating checkpoint...) will be dropped.
+	// 4. secondary->secondary with the source cluster changed: the previous remote cluster replicating checkpoint will be dropped.
+	// 5. secondary->secondary with the source cluster unchanged: nothing happens.
+	SwitchReplicateMode(ctx context.Context, msg message.MutableAlterReplicateConfigMessageV2) error
+
+	// BeginReplicateMessage begins the handling of one replicated message.
+	// ReplicateAcker's Ack method must be called if this returns without error.
+	BeginReplicateMessage(ctx context.Context, msg message.MutableMessage) (ReplicateAcker, error)
+
+	// GetReplicateCheckpoint gets the current replicate checkpoint.
+	// It returns a ReplicateViolation error if the wal is not in replicating mode.
+	GetReplicateCheckpoint() (*utility.ReplicateCheckpoint, error)
+}
diff --git a/internal/streamingnode/server/wal/interceptors/replicate/replicates/manager_test.go b/internal/streamingnode/server/wal/interceptors/replicate/replicates/manager_test.go
new file mode 100644
index 0000000000..1e99fafa2d
--- /dev/null
+++ b/internal/streamingnode/server/wal/interceptors/replicate/replicates/manager_test.go
@@ -0,0 +1,495 @@
+package replicates
+
+import (
+	"context"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+
+	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
+	"github.com/milvus-io/milvus/internal/streamingnode/server/wal/metricsutil"
+	"github.com/milvus-io/milvus/internal/streamingnode/server/wal/recovery"
+	"github.com/milvus-io/milvus/internal/streamingnode/server/wal/utility"
+	"github.com/milvus-io/milvus/internal/util/streamingutil/status"
+	"github.com/milvus-io/milvus/pkg/v2/log"
+	"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
+	"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
+	"github.com/milvus-io/milvus/pkg/v2/streaming/walimpls/impls/walimplstest"
+	"github.com/milvus-io/milvus/pkg/v2/util/replicateutil"
+)
+
+func TestNonReplicateManager(t *testing.T) {
+	rm, err := RecoverReplicateManager(&ReplicateManagerRecoverParam{
+		ChannelInfo: types.PChannelInfo{
+			Name: "test1-rootcoord-dml_0",
+			Term: 1,
+		},
+		CurrentClusterID: "test1",
+		InitialRecoverSnapshot: &recovery.RecoverySnapshot{
+			Checkpoint: &utility.WALCheckpoint{
+				MessageID:           walimplstest.NewTestMessageID(1),
+				TimeTick:            1,
+				ReplicateCheckpoint: nil,
+				ReplicateConfig:     nil,
+			},
+		},
+	})
+	assert.NoError(t, err)
+	assert.Equal(t, rm.Role(), replicateutil.RolePrimary)
+
+	testSwitchReplicateMode(t, rm, "test1", "test2")
+	testMessageOnPrimary(t, rm)
+	testMessageOnSecondary(t, rm)
+}
+
+func TestPrimaryReplicateManager(t *testing.T) {
+	rm, err := RecoverReplicateManager(&ReplicateManagerRecoverParam{
+		ChannelInfo: types.PChannelInfo{
+			Name: "test1-rootcoord-dml_0",
+			Term: 1,
+		},
+		CurrentClusterID: "test1",
+		InitialRecoverSnapshot: &recovery.RecoverySnapshot{
+			Checkpoint: &utility.WALCheckpoint{
+				MessageID:           walimplstest.NewTestMessageID(1),
+				TimeTick:            1,
+				ReplicateCheckpoint: nil,
+				ReplicateConfig:     newReplicateConfiguration("test1", "test2"),
+			},
+
}, + }) + assert.NoError(t, err) + assert.Equal(t, rm.Role(), replicateutil.RolePrimary) + + testSwitchReplicateMode(t, rm, "test1", "test2") + testMessageOnPrimary(t, rm) + testMessageOnSecondary(t, rm) +} + +func TestSecondaryReplicateManager(t *testing.T) { + rm, err := RecoverReplicateManager(&ReplicateManagerRecoverParam{ + ChannelInfo: types.PChannelInfo{ + Name: "test1-rootcoord-dml_0", + Term: 1, + }, + CurrentClusterID: "test1", + InitialRecoverSnapshot: &recovery.RecoverySnapshot{ + Checkpoint: &utility.WALCheckpoint{ + MessageID: walimplstest.NewTestMessageID(1), + TimeTick: 1, + ReplicateCheckpoint: &utility.ReplicateCheckpoint{ + ClusterID: "test2", + PChannel: "test2-rootcoord-dml_0", + MessageID: walimplstest.NewTestMessageID(1), + TimeTick: 1, + }, + ReplicateConfig: newReplicateConfiguration("test2", "test1"), + }, + TxnBuffer: utility.NewTxnBuffer(log.With(), metricsutil.NewScanMetrics(types.PChannelInfo{}).NewScannerMetrics()), + }, + }) + assert.NoError(t, err) + assert.Equal(t, rm.Role(), replicateutil.RoleSecondary) + + testSwitchReplicateMode(t, rm, "test1", "test2") + testMessageOnPrimary(t, rm) + testMessageOnSecondary(t, rm) +} + +func TestSecondaryReplicateManagerWithTxn(t *testing.T) { + txnBuffer := utility.NewTxnBuffer(log.With(), metricsutil.NewScanMetrics(types.PChannelInfo{}).NewScannerMetrics()) + txnMsgs := newReplicateTxnMessage("test1", "test2", 2) + + for _, msg := range txnMsgs[0:3] { + immutableMsg := msg.WithTimeTick(3).IntoImmutableMessage(walimplstest.NewTestMessageID(1)) + txnBuffer.HandleImmutableMessages([]message.ImmutableMessage{immutableMsg}, msg.TimeTick()) + } + + rm, err := RecoverReplicateManager(&ReplicateManagerRecoverParam{ + ChannelInfo: types.PChannelInfo{ + Name: "test1-rootcoord-dml_0", + Term: 1, + }, + CurrentClusterID: "test1", + InitialRecoverSnapshot: &recovery.RecoverySnapshot{ + Checkpoint: &utility.WALCheckpoint{ + MessageID: walimplstest.NewTestMessageID(1), + TimeTick: 1, + ReplicateCheckpoint: &utility.ReplicateCheckpoint{ + ClusterID: "test2", + PChannel: "test2-rootcoord-dml_0", + MessageID: walimplstest.NewTestMessageID(1), + TimeTick: 1, + }, + ReplicateConfig: newReplicateConfiguration("test2", "test1"), + }, + TxnBuffer: txnBuffer, + }, + }) + assert.NoError(t, err) + assert.Equal(t, rm.Role(), replicateutil.RoleSecondary) + + committed := false + for _, msg := range newReplicateTxnMessage("test1", "test2", 2) { + g, err := rm.BeginReplicateMessage(context.Background(), msg) + if msg.MessageType() == message.MessageTypeCommitTxn && !committed { + assert.NoError(t, err) + assert.NotNil(t, g) + g.Ack(nil) + committed = true + } else { + assert.True(t, status.AsStreamingError(err).IsIgnoredOperation()) + assert.Nil(t, g) + } + } +} + +func testSwitchReplicateMode(t *testing.T, rm ReplicateManager, primaryClusterID, secondaryClusterID string) { + ctx := context.Background() + + // switch to primary + err := rm.SwitchReplicateMode(ctx, newAlterReplicateConfigMessage(primaryClusterID, secondaryClusterID)) + assert.NoError(t, err) + assert.Equal(t, rm.Role(), replicateutil.RolePrimary) + cp, err := rm.GetReplicateCheckpoint() + assert.True(t, status.AsStreamingError(err).IsReplicateViolation()) + assert.Nil(t, cp) + + // idempotent switch to primary + err = rm.SwitchReplicateMode(ctx, newAlterReplicateConfigMessage(primaryClusterID, secondaryClusterID)) + assert.NoError(t, err) + assert.Equal(t, rm.Role(), replicateutil.RolePrimary) + cp, err = rm.GetReplicateCheckpoint() + assert.True(t, 
status.AsStreamingError(err).IsReplicateViolation()) + assert.Nil(t, cp) + + // switch to secondary + err = rm.SwitchReplicateMode(ctx, newAlterReplicateConfigMessage(secondaryClusterID, primaryClusterID)) + assert.NoError(t, err) + assert.Equal(t, rm.Role(), replicateutil.RoleSecondary) + cp, err = rm.GetReplicateCheckpoint() + assert.NoError(t, err) + assert.Equal(t, cp.ClusterID, secondaryClusterID) + assert.Equal(t, cp.PChannel, secondaryClusterID+"-rootcoord-dml_0") + assert.Nil(t, cp.MessageID) + assert.Equal(t, cp.TimeTick, uint64(0)) + + // idempotent switch to secondary + err = rm.SwitchReplicateMode(ctx, newAlterReplicateConfigMessage(secondaryClusterID, primaryClusterID)) + assert.NoError(t, err) + assert.Equal(t, rm.Role(), replicateutil.RoleSecondary) + cp, err = rm.GetReplicateCheckpoint() + assert.NoError(t, err) + assert.Equal(t, cp.ClusterID, secondaryClusterID) + assert.Equal(t, cp.PChannel, secondaryClusterID+"-rootcoord-dml_0") + assert.Nil(t, cp.MessageID) + assert.Equal(t, cp.TimeTick, uint64(0)) + + // switch back to primary + err = rm.SwitchReplicateMode(ctx, newAlterReplicateConfigMessage(primaryClusterID, secondaryClusterID)) + assert.NoError(t, err) + assert.Equal(t, rm.Role(), replicateutil.RolePrimary) + cp, err = rm.GetReplicateCheckpoint() + assert.True(t, status.AsStreamingError(err).IsReplicateViolation()) + assert.Nil(t, cp) + + // idempotent switch back to primary + err = rm.SwitchReplicateMode(ctx, newAlterReplicateConfigMessage(primaryClusterID, secondaryClusterID)) + assert.NoError(t, err) + assert.Equal(t, rm.Role(), replicateutil.RolePrimary) + cp, err = rm.GetReplicateCheckpoint() + assert.True(t, status.AsStreamingError(err).IsReplicateViolation()) + assert.Nil(t, cp) + + // switch back to secondary + err = rm.SwitchReplicateMode(ctx, newAlterReplicateConfigMessage(secondaryClusterID, primaryClusterID)) + assert.NoError(t, err) + assert.Equal(t, rm.Role(), replicateutil.RoleSecondary) + cp, err = rm.GetReplicateCheckpoint() + assert.NoError(t, err) + assert.Equal(t, cp.ClusterID, secondaryClusterID) + assert.Equal(t, cp.PChannel, secondaryClusterID+"-rootcoord-dml_0") + assert.Nil(t, cp.MessageID) + assert.Equal(t, cp.TimeTick, uint64(0)) + + // idempotent switch back to secondary + err = rm.SwitchReplicateMode(ctx, newAlterReplicateConfigMessage(secondaryClusterID, primaryClusterID)) + assert.NoError(t, err) + assert.Equal(t, rm.Role(), replicateutil.RoleSecondary) + cp, err = rm.GetReplicateCheckpoint() + assert.NoError(t, err) + assert.Equal(t, cp.ClusterID, secondaryClusterID) + assert.Equal(t, cp.PChannel, secondaryClusterID+"-rootcoord-dml_0") + assert.Nil(t, cp.MessageID) + assert.Equal(t, cp.TimeTick, uint64(0)) + + // add a new cluster and switch to primary + err = rm.SwitchReplicateMode(ctx, newAlterReplicateConfigMessage(primaryClusterID, secondaryClusterID, "test3")) + assert.NoError(t, err) + assert.Equal(t, rm.Role(), replicateutil.RolePrimary) + cp, err = rm.GetReplicateCheckpoint() + assert.True(t, status.AsStreamingError(err).IsReplicateViolation()) + assert.Nil(t, cp) + + // idempotent add a new cluster and switch to primary + err = rm.SwitchReplicateMode(ctx, newAlterReplicateConfigMessage(primaryClusterID, secondaryClusterID, "test3")) + assert.NoError(t, err) + assert.Equal(t, rm.Role(), replicateutil.RolePrimary) + cp, err = rm.GetReplicateCheckpoint() + assert.True(t, status.AsStreamingError(err).IsReplicateViolation()) + assert.Nil(t, cp) + + // add a new cluster and switch to secondary + err = rm.SwitchReplicateMode(ctx, 
newAlterReplicateConfigMessage(secondaryClusterID, primaryClusterID, "test3")) + assert.NoError(t, err) + assert.Equal(t, rm.Role(), replicateutil.RoleSecondary) + cp, err = rm.GetReplicateCheckpoint() + assert.NoError(t, err) + assert.Equal(t, cp.ClusterID, secondaryClusterID) + assert.Equal(t, cp.PChannel, secondaryClusterID+"-rootcoord-dml_0") + assert.Nil(t, cp.MessageID) + assert.Equal(t, cp.TimeTick, uint64(0)) + + // idempotent add a new cluster and switch to secondary + err = rm.SwitchReplicateMode(ctx, newAlterReplicateConfigMessage(secondaryClusterID, primaryClusterID, "test3")) + assert.NoError(t, err) + assert.Equal(t, rm.Role(), replicateutil.RoleSecondary) + cp, err = rm.GetReplicateCheckpoint() + assert.NoError(t, err) + assert.Equal(t, cp.ClusterID, secondaryClusterID) + assert.Equal(t, cp.PChannel, secondaryClusterID+"-rootcoord-dml_0") + assert.Nil(t, cp.MessageID) + assert.Equal(t, cp.TimeTick, uint64(0)) + + // switch the primary + err = rm.SwitchReplicateMode(ctx, newAlterReplicateConfigMessage("test3", primaryClusterID, secondaryClusterID)) + assert.NoError(t, err) + assert.Equal(t, rm.Role(), replicateutil.RoleSecondary) + cp, err = rm.GetReplicateCheckpoint() + assert.NoError(t, err) + assert.Equal(t, cp.ClusterID, "test3") + assert.Equal(t, cp.PChannel, "test3-rootcoord-dml_0") + assert.Nil(t, cp.MessageID) + assert.Equal(t, cp.TimeTick, uint64(0)) + + // idempotent switch the primary + err = rm.SwitchReplicateMode(ctx, newAlterReplicateConfigMessage("test3", primaryClusterID, secondaryClusterID)) + assert.NoError(t, err) + assert.Equal(t, rm.Role(), replicateutil.RoleSecondary) + cp, err = rm.GetReplicateCheckpoint() + assert.NoError(t, err) + assert.Equal(t, cp.ClusterID, "test3") + assert.Equal(t, cp.PChannel, "test3-rootcoord-dml_0") + assert.Nil(t, cp.MessageID) + assert.Equal(t, cp.TimeTick, uint64(0)) +} + +func testMessageOnPrimary(t *testing.T, rm ReplicateManager) { + // switch to primary + err := rm.SwitchReplicateMode(context.Background(), newAlterReplicateConfigMessage("test1", "test2")) + assert.NoError(t, err) + + // Test self-controlled message + g, err := rm.BeginReplicateMessage(context.Background(), message.NewCreateSegmentMessageBuilderV2(). + WithHeader(&message.CreateSegmentMessageHeader{}). + WithBody(&message.CreateSegmentMessageBody{}). + WithVChannel("test1-rootcoord-dml_0"). + MustBuildMutable()) + assert.ErrorIs(t, err, ErrNotHandledByReplicateManager) + assert.Nil(t, g) + + // Test non-replicate message + msg := newNonReplicateMessage("test1") + g, err = rm.BeginReplicateMessage(context.Background(), msg) + assert.ErrorIs(t, err, ErrNotHandledByReplicateManager) + assert.Nil(t, g) + + // Test replicate message + msg = newReplicateMessage("test1", "test2") + g, err = rm.BeginReplicateMessage(context.Background(), msg) + assert.True(t, status.AsStreamingError(err).IsReplicateViolation()) + assert.Nil(t, g) +} + +func testMessageOnSecondary(t *testing.T, rm ReplicateManager) { + // switch to secondary + err := rm.SwitchReplicateMode(context.Background(), newAlterReplicateConfigMessage("test2", "test1")) + assert.NoError(t, err) + + // Test wrong cluster replicates + msg := newReplicateMessage("test1", "test3") + g, err := rm.BeginReplicateMessage(context.Background(), msg) + assert.True(t, status.AsStreamingError(err).IsReplicateViolation()) + assert.Nil(t, g) + + // Test self-controlled message + g, err = rm.BeginReplicateMessage(context.Background(), message.NewCreateSegmentMessageBuilderV2(). 
+ WithHeader(&message.CreateSegmentMessageHeader{}). + WithBody(&message.CreateSegmentMessageBody{}). + WithVChannel("test1-rootcoord-dml_0"). + MustBuildMutable()) + assert.ErrorIs(t, err, ErrNotHandledByReplicateManager) + assert.Nil(t, g) + + // Test non-replicate message + msg = newNonReplicateMessage("test1") + g, err = rm.BeginReplicateMessage(context.Background(), msg) + assert.True(t, status.AsStreamingError(err).IsReplicateViolation()) + assert.Nil(t, g) + + // Test replicate message + msg = newReplicateMessage("test1", "test2") + g, err = rm.BeginReplicateMessage(context.Background(), msg) + assert.NoError(t, err) + assert.NotNil(t, g) + g.Ack(nil) + + // Test replicate message + msg = newReplicateMessage("test1", "test2") + g, err = rm.BeginReplicateMessage(context.Background(), msg) + assert.True(t, status.AsStreamingError(err).IsIgnoredOperation()) + assert.Nil(t, g) + + for idx, msg := range newReplicateTxnMessage("test1", "test2", 2) { + g, err = rm.BeginReplicateMessage(context.Background(), msg) + if idx%2 == 0 { + assert.NoError(t, err) + assert.NotNil(t, g) + g.Ack(nil) + } else { + assert.True(t, status.AsStreamingError(err).IsIgnoredOperation()) + assert.Nil(t, g) + } + } + + msg = newReplicateMessage("test1", "test2", 2) + g, err = rm.BeginReplicateMessage(context.Background(), msg) + assert.True(t, status.AsStreamingError(err).IsIgnoredOperation()) + assert.Nil(t, g) + + g, err = rm.BeginReplicateMessage(context.Background(), newReplicateTxnMessage("test1", "test2", 2)[0]) + assert.True(t, status.AsStreamingError(err).IsIgnoredOperation()) + assert.Nil(t, g) +} + +// newReplicateConfiguration creates a valid replicate configuration for testing +func newReplicateConfiguration(primaryClusterID string, secondaryClusterID ...string) *commonpb.ReplicateConfiguration { + clusters := []*commonpb.MilvusCluster{ + {ClusterId: primaryClusterID, Pchannels: []string{primaryClusterID + "-rootcoord-dml_0", primaryClusterID + "-rootcoord-dml_1"}}, + } + crossClusterTopology := []*commonpb.CrossClusterTopology{} + for _, secondaryClusterID := range secondaryClusterID { + clusters = append(clusters, &commonpb.MilvusCluster{ClusterId: secondaryClusterID, Pchannels: []string{secondaryClusterID + "-rootcoord-dml_0", secondaryClusterID + "-rootcoord-dml_1"}}) + crossClusterTopology = append(crossClusterTopology, &commonpb.CrossClusterTopology{SourceClusterId: primaryClusterID, TargetClusterId: secondaryClusterID}) + } + return &commonpb.ReplicateConfiguration{ + Clusters: clusters, + CrossClusterTopology: crossClusterTopology, + } +} + +func newAlterReplicateConfigMessage(primaryClusterID string, secondaryClusterID ...string) message.MutableAlterReplicateConfigMessageV2 { + return message.MustAsMutableAlterReplicateConfigMessageV2(message.NewAlterReplicateConfigMessageBuilderV2(). + WithHeader(&message.AlterReplicateConfigMessageHeader{ + ReplicateConfiguration: newReplicateConfiguration(primaryClusterID, secondaryClusterID...), + }). + WithBody(&message.AlterReplicateConfigMessageBody{}). + WithVChannel(primaryClusterID + "-rootcoord-dml_0"). + MustBuildMutable()) +} + +func newNonReplicateMessage(clusterID string) message.MutableMessage { + return message.NewCreateDatabaseMessageBuilderV2(). + WithHeader(&message.CreateDatabaseMessageHeader{}). + WithBody(&message.CreateDatabaseMessageBody{}). + WithVChannel(clusterID + "-rootcoord-dml_0"). 
+ MustBuildMutable() +} + +func newReplicateMessage(clusterID string, sourceClusterID string, timetick ...uint64) message.MutableMessage { + tt := uint64(1) + if len(timetick) > 0 { + tt = timetick[0] + } + msg := message.NewCreateDatabaseMessageBuilderV2(). + WithHeader(&message.CreateDatabaseMessageHeader{}). + WithBody(&message.CreateDatabaseMessageBody{}). + WithVChannel(sourceClusterID + "-rootcoord-dml_0"). + MustBuildMutable(). + WithTimeTick(tt). + WithLastConfirmed(walimplstest.NewTestMessageID(1)). + IntoImmutableMessage(walimplstest.NewTestMessageID(1)) + + replicateMsg := message.NewReplicateMessage( + sourceClusterID, + msg.IntoImmutableMessageProto(), + ) + replicateMsg.OverwriteReplicateVChannel( + clusterID + "-rootcoord-dml_0", + ) + return replicateMsg +} + +func newImmutableTxnMessage(clusterID string, timetick ...uint64) []message.ImmutableMessage { + tt := uint64(1) + if len(timetick) > 0 { + tt = timetick[0] + } + immutables := []message.ImmutableMessage{ + message.NewBeginTxnMessageBuilderV2(). + WithHeader(&message.BeginTxnMessageHeader{}). + WithBody(&message.BeginTxnMessageBody{}). + WithVChannel(clusterID + "-rootcoord-dml_0"). + MustBuildMutable(). + WithTxnContext(message.TxnContext{ + TxnID: message.TxnID(1), + Keepalive: message.TxnKeepaliveInfinite, + }). + WithTimeTick(tt). + WithLastConfirmed(walimplstest.NewTestMessageID(1)). + IntoImmutableMessage(walimplstest.NewTestMessageID(1)), + message.NewCreateDatabaseMessageBuilderV2(). + WithHeader(&message.CreateDatabaseMessageHeader{}). + WithBody(&message.CreateDatabaseMessageBody{}). + WithVChannel(clusterID + "-rootcoord-dml_0"). + MustBuildMutable(). + WithTxnContext(message.TxnContext{ + TxnID: message.TxnID(1), + Keepalive: message.TxnKeepaliveInfinite, + }). + WithTimeTick(tt). + WithLastConfirmed(walimplstest.NewTestMessageID(1)). + IntoImmutableMessage(walimplstest.NewTestMessageID(1)), + message.NewCommitTxnMessageBuilderV2(). + WithHeader(&message.CommitTxnMessageHeader{}). + WithBody(&message.CommitTxnMessageBody{}). + WithVChannel(clusterID + "-rootcoord-dml_0"). + MustBuildMutable(). + WithTxnContext(message.TxnContext{ + TxnID: message.TxnID(1), + Keepalive: message.TxnKeepaliveInfinite, + }). + WithTimeTick(tt). + WithLastConfirmed(walimplstest.NewTestMessageID(1)). + IntoImmutableMessage(walimplstest.NewTestMessageID(1)), + } + return immutables +} + +func newReplicateTxnMessage(clusterID string, sourceClusterID string, timetick ...uint64) []message.MutableMessage { + immutables := newImmutableTxnMessage(sourceClusterID, timetick...) 
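+	// Each message is appended twice below (see "test the idempotency"), so callers
+	// of this helper also exercise the dedup path of the replicate txn helper.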
+	replicateMsgs := []message.MutableMessage{}
+	for _, immutable := range immutables {
+		replicateMsg := message.NewReplicateMessage(
+			sourceClusterID,
+			immutable.IntoImmutableMessageProto(),
+		)
+		replicateMsg.OverwriteReplicateVChannel(
+			clusterID + "-rootcoord-dml_0",
+		)
+		replicateMsgs = append(replicateMsgs, replicateMsg)
+		// test the idempotency
+		replicateMsgs = append(replicateMsgs, replicateMsg)
+	}
+	return replicateMsgs
+}
diff --git a/internal/streamingnode/server/wal/interceptors/replicate/replicates/secondary_state.go b/internal/streamingnode/server/wal/interceptors/replicate/replicates/secondary_state.go
new file mode 100644
index 0000000000..283725e68b
--- /dev/null
+++ b/internal/streamingnode/server/wal/interceptors/replicate/replicates/secondary_state.go
@@ -0,0 +1,76 @@
+package replicates
+
+import (
+	"github.com/milvus-io/milvus/internal/streamingnode/server/wal/utility"
+	"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
+)
+
+// newSecondaryState creates a new secondary state.
+func newSecondaryState(sourceClusterID string, sourcePChannel string) *secondaryState {
+	return &secondaryState{
+		checkpoint: &utility.ReplicateCheckpoint{
+			ClusterID: sourceClusterID,
+			PChannel:  sourcePChannel,
+			MessageID: nil,
+			TimeTick:  0,
+		},
+		replicateTxnHelper: newReplicateTxnHelper(),
+	}
+}
+
+// recoverSecondaryState recovers the secondary state from the recover param.
+func recoverSecondaryState(param *ReplicateManagerRecoverParam) (*secondaryState, error) {
+	txnHelper := newReplicateTxnHelper()
+	sourceClusterID := param.InitialRecoverSnapshot.Checkpoint.ReplicateCheckpoint.ClusterID
+	// recover the txn helper.
+	uncommittedTxnBuilders := param.InitialRecoverSnapshot.TxnBuffer.GetUncommittedMessageBuilder()
+	for _, builder := range uncommittedTxnBuilders {
+		begin, body := builder.Messages()
+		replicateHeader := begin.ReplicateHeader()
+		// filter out the txn builders that were replicated from another cluster or were not replicated at all.
+		if replicateHeader == nil || replicateHeader.ClusterID != sourceClusterID {
+			continue
+		}
+		// there will be only one uncommitted txn builder.
+		if err := txnHelper.StartBegin(begin.TxnContext(), begin.ReplicateHeader()); err != nil {
+			return nil, err
+		}
+		txnHelper.BeginDone(begin.TxnContext())
+		for _, msg := range body {
+			if err := txnHelper.AddNewMessage(msg.TxnContext(), msg.ReplicateHeader()); err != nil {
+				return nil, err
+			}
+			txnHelper.AddNewMessageDone(msg.ReplicateHeader())
+		}
+	}
+	return &secondaryState{
+		checkpoint:         param.InitialRecoverSnapshot.Checkpoint.ReplicateCheckpoint,
+		replicateTxnHelper: txnHelper,
+	}, nil
+}
+
+// secondaryState describes the state of the secondary role.
+type secondaryState struct {
+	checkpoint          *utility.ReplicateCheckpoint
+	*replicateTxnHelper // tracks the in-progress replicated txn, if any.
+}
+
+// SourceClusterID returns the source cluster id of the secondary state.
+func (s *secondaryState) SourceClusterID() string {
+	return s.checkpoint.ClusterID
+}
+
+// GetCheckpoint returns the checkpoint of the secondary state.
+func (s *secondaryState) GetCheckpoint() *utility.ReplicateCheckpoint {
+	return s.checkpoint
+}
+
+// PushForwardCheckpoint pushes forward the checkpoint.
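+// The checkpoint is monotonic: a time tick not newer than the current one is ignored,
+// so out-of-order acks can never move the checkpoint backwards.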
+func (s *secondaryState) PushForwardCheckpoint(timetick uint64, lastConfirmedMessageID message.MessageID) error {
+	if timetick <= s.checkpoint.TimeTick {
+		return nil
+	}
+	s.checkpoint.TimeTick = timetick
+	s.checkpoint.MessageID = lastConfirmedMessageID
+	return nil
+}
diff --git a/internal/streamingnode/server/wal/interceptors/replicate/replicates/txn.go b/internal/streamingnode/server/wal/interceptors/replicate/replicates/txn.go
new file mode 100644
index 0000000000..2ce4cc641b
--- /dev/null
+++ b/internal/streamingnode/server/wal/interceptors/replicate/replicates/txn.go
@@ -0,0 +1,76 @@
+package replicates
+
+import (
+	"github.com/milvus-io/milvus/internal/util/streamingutil/status"
+	"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
+	"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
+)
+
+// newReplicateTxnHelper creates a new replicate txn helper.
+func newReplicateTxnHelper() *replicateTxnHelper {
+	return &replicateTxnHelper{
+		currentTxn: nil,
+		messageIDs: typeutil.NewSet[string](),
+	}
+}
+
+// replicateTxnHelper is a helper for replicating a txn.
+// It is used to handle and deduplicate the txn messages.
+type replicateTxnHelper struct {
+	currentTxn *message.TxnContext
+	messageIDs typeutil.Set[string]
+}
+
+// CurrentTxn returns the current txn context.
+func (s *replicateTxnHelper) CurrentTxn() *message.TxnContext {
+	return s.currentTxn
+}
+
+func (s *replicateTxnHelper) StartBegin(txn *message.TxnContext, replicateHeader *message.ReplicateHeader) error {
+	if s.currentTxn != nil {
+		if s.currentTxn.TxnID == txn.TxnID {
+			return status.NewIgnoreOperation("txn is already in progress, txnID: %d", s.currentTxn.TxnID)
+		}
+		return status.NewReplicateViolation("begin txn violation, current: %d, incoming: %d", s.currentTxn.TxnID, txn.TxnID)
+	}
+	return nil
+}
+
+func (s *replicateTxnHelper) BeginDone(txn *message.TxnContext) {
+	s.currentTxn = txn
+	s.messageIDs = typeutil.NewSet[string]()
+}
+
+func (s *replicateTxnHelper) AddNewMessage(txn *message.TxnContext, replicateHeader *message.ReplicateHeader) error {
+	if s.currentTxn == nil {
+		return status.NewReplicateViolation("add new txn message without txn in progress, incoming: %d", txn.TxnID)
+	}
+	if s.currentTxn.TxnID != txn.TxnID {
+		return status.NewReplicateViolation("add new txn message with different txn, current: %d, incoming: %d", s.currentTxn.TxnID, txn.TxnID)
+	}
+	if s.messageIDs.Contain(replicateHeader.MessageID.Marshal()) {
+		return status.NewIgnoreOperation("txn message is already handled, txnID: %d, messageID: %s", s.currentTxn.TxnID, replicateHeader.MessageID.Marshal())
+	}
+	return nil
+}
+
+func (s *replicateTxnHelper) AddNewMessageDone(replicateHeader *message.ReplicateHeader) {
+	s.messageIDs.Insert(replicateHeader.MessageID.Marshal())
+}
+
+func (s *replicateTxnHelper) StartCommit(txn *message.TxnContext) error {
+	if s.currentTxn == nil {
+		return status.NewIgnoreOperation("commit txn without txn, maybe already committed, txnID: %d", txn.TxnID)
+	}
+	if s.currentTxn.TxnID != txn.TxnID {
+		return status.NewReplicateViolation("commit txn with different txn, current: %d, incoming: %d", s.currentTxn.TxnID, txn.TxnID)
+	}
+	s.currentTxn = nil
+	s.messageIDs = nil
+	return nil
+}
+
+func (s *replicateTxnHelper) CommitDone(txn *message.TxnContext) {
+	s.currentTxn = nil
+	s.messageIDs = nil
+}
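Aside (not part of the diff): a sketch of how an interceptor is expected to drive the dedup contract above. `helper`, `txnCtx` and `header` stand in for interceptor state, and status.AsStreamingError is assumed to classify the returned error:

    if err := helper.StartBegin(txnCtx, header); err != nil {
        if status.AsStreamingError(err).IsSkippedOperation() {
            return nil // duplicate begin replayed from the source: safe to ignore
        }
        return err // replicate violation: interleaved txn, surfaces as unrecoverable
    }
    helper.BeginDone(txnCtx)

diff --git a/internal/streamingnode/server/wal/interceptors/shard/shards/segment_flush_worker.go b/internal/streamingnode/server/wal/interceptors/shard/shards/segment_flush_worker.go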
index f958bdd39f..93edcf22a1 100644
--- a/internal/streamingnode/server/wal/interceptors/shard/shards/segment_flush_worker.go
+++ b/internal/streamingnode/server/wal/interceptors/shard/shards/segment_flush_worker.go
@@ -76,7 +76,7 @@ func (w *segmentFlushWorker) do() {
 		}
 
 		nextInterval := backoff.NextBackOff()
-		w.Logger().Info("failed to allocate new growing segment, retrying", zap.Duration("nextInterval", nextInterval), zap.Error(err))
+		w.Logger().Info("failed to flush growing segment, retrying", zap.Duration("nextInterval", nextInterval), zap.Error(err))
 		select {
 		case <-w.ctx.Done():
 			w.Logger().Info("flush segment canceled", zap.Error(w.ctx.Err()))
diff --git a/internal/streamingnode/server/wal/interceptors/txn/session_test.go b/internal/streamingnode/server/wal/interceptors/txn/session_test.go
index 01f95257bb..526259a6e0 100644
--- a/internal/streamingnode/server/wal/interceptors/txn/session_test.go
+++ b/internal/streamingnode/server/wal/interceptors/txn/session_test.go
@@ -16,6 +16,7 @@ import (
 	"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
 	"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
 	"github.com/milvus-io/milvus/pkg/v2/streaming/walimpls/impls/rmq"
+	"github.com/milvus-io/milvus/pkg/v2/streaming/walimpls/impls/walimplstest"
 	"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
 	"github.com/milvus-io/milvus/pkg/v2/util/tsoutil"
 )
@@ -211,6 +212,32 @@ func TestWithContext(t *testing.T) {
 	assert.NotNil(t, session)
 }
 
+func TestManagerFromReplicateMessage(t *testing.T) {
+	resource.InitForTest(t)
+	manager := NewTxnManager(types.PChannelInfo{Name: "test"}, nil)
+	immutableMsg := message.NewBeginTxnMessageBuilderV2().
+		WithVChannel("v1").
+		WithHeader(&message.BeginTxnMessageHeader{
+			KeepaliveMilliseconds: 10 * time.Millisecond.Milliseconds(),
+		}).
+		WithBody(&message.BeginTxnMessageBody{}).
+		MustBuildMutable().
+		WithTimeTick(1).
+		WithLastConfirmed(walimplstest.NewTestMessageID(1)).
+		WithTxnContext(message.TxnContext{
+			TxnID:     18,
+			Keepalive: 10 * time.Millisecond,
+		}).
+		IntoImmutableMessage(walimplstest.NewTestMessageID(1))
+	replicateMsg := message.NewReplicateMessage("test2", immutableMsg.IntoImmutableMessageProto()).WithTimeTick(2)
+
+	session, err := manager.BeginNewTxn(context.Background(), message.MustAsMutableBeginTxnMessageV2(replicateMsg))
+	assert.NoError(t, err)
+	assert.NotNil(t, session)
+	assert.Equal(t, message.TxnID(18), session.TxnContext().TxnID)
+	assert.Equal(t, message.TxnKeepaliveInfinite, session.TxnContext().Keepalive)
+}
+
 func newBeginTxnMessage(timetick uint64, keepalive time.Duration) message.MutableBeginTxnMessageV2 {
 	return newBeginTxnMessageWithVChannel("v1", timetick, keepalive)
 }
diff --git a/internal/streamingnode/server/wal/interceptors/txn/txn_manager.go b/internal/streamingnode/server/wal/interceptors/txn/txn_manager.go
index 7c48b8794c..4ca2648801 100644
--- a/internal/streamingnode/server/wal/interceptors/txn/txn_manager.go
+++ b/internal/streamingnode/server/wal/interceptors/txn/txn_manager.go
@@ -79,16 +79,8 @@ func (m *TxnManager) RecoverDone() <-chan struct{} {
 func (m *TxnManager) BeginNewTxn(ctx context.Context, msg message.MutableBeginTxnMessageV2) (*TxnSession, error) {
 	timetick := msg.TimeTick()
 	vchannel := msg.VChannel()
-	keepalive := time.Duration(msg.Header().KeepaliveMilliseconds) * time.Millisecond
-	if keepalive == 0 {
-		// If keepalive is 0, the txn set the keepalive with default keepalive.
-		keepalive = paramtable.Get().StreamingCfg.TxnDefaultKeepaliveTimeout.GetAsDurationByParse()
-	}
-	if keepalive < 1*time.Millisecond {
-		return nil, status.NewInvaildArgument("keepalive must be greater than 1ms")
-	}
-	id, err := resource.Resource().IDAllocator().Allocate(ctx)
+	txnCtx, err := m.buildTxnContext(ctx, msg)
 	if err != nil {
 		return nil, err
 	}
@@ -100,23 +92,49 @@
 	if m.closed != nil {
 		return nil, status.NewTransactionExpired("manager closed")
 	}
-	txnCtx := message.TxnContext{
-		TxnID:     message.TxnID(id),
-		Keepalive: keepalive,
-	}
-	session := newTxnSession(vchannel, txnCtx, timetick, m.metrics.BeginTxn())
+	session := newTxnSession(vchannel, *txnCtx, timetick, m.metrics.BeginTxn())
 	m.sessions[session.TxnContext().TxnID] = session
 	return session, nil
 }
 
+// buildTxnContext builds the txn context from the message.
+func (m *TxnManager) buildTxnContext(ctx context.Context, msg message.MutableBeginTxnMessageV2) (*message.TxnContext, error) {
+	if msg.ReplicateHeader() != nil {
+		// Reuse the txn context carried by the replicated message.
+		// A replicated message is already committed on the source cluster,
+		// so it must never expire; pin the keepalive to infinite.
+		return &message.TxnContext{
+			TxnID:     msg.TxnContext().TxnID,
+			Keepalive: message.TxnKeepaliveInfinite,
+		}, nil
+	}
+
+	keepalive := time.Duration(msg.Header().KeepaliveMilliseconds) * time.Millisecond
+	if keepalive == 0 {
+		// If keepalive is 0, fall back to the default keepalive timeout.
+		keepalive = paramtable.Get().StreamingCfg.TxnDefaultKeepaliveTimeout.GetAsDurationByParse()
+	}
+	if keepalive < 1*time.Millisecond {
+		return nil, status.NewInvaildArgument("keepalive must be greater than 1ms")
+	}
+	id, err := resource.Resource().IDAllocator().Allocate(ctx)
+	if err != nil {
+		return nil, err
+	}
+	return &message.TxnContext{
+		TxnID:     message.TxnID(id),
+		Keepalive: keepalive,
+	}, nil
+}
+
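Aside (not part of the diff): the two paths of buildTxnContext at a glance, reusing identifiers from session_test.go above; `immutableBegin` is a placeholder begin-txn message carrying TxnID 18:

    replicated := message.NewReplicateMessage("source", immutableBegin.IntoImmutableMessageProto())
    s1, _ := manager.BeginNewTxn(ctx, message.MustAsMutableBeginTxnMessageV2(replicated))
    // s1 keeps TxnID 18 and is pinned to message.TxnKeepaliveInfinite.

    s2, _ := manager.BeginNewTxn(ctx, newBeginTxnMessage(1, 10*time.Millisecond))
    // s2 gets a freshly allocated TxnID and the requested 10ms keepalive.

 // FailTxnAtVChannel fails all transactions at the specified vchannel.
+// If the vchannel is empty, it will fail all transactions.
 func (m *TxnManager) FailTxnAtVChannel(vchannel string) {
 	// avoid the txn to be committed.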
 	m.mu.Lock()
 	defer m.mu.Unlock()
 	ids := make([]int64, 0, len(m.sessions))
 	for id, session := range m.sessions {
-		if session.VChannel() == vchannel {
+		if vchannel == "" || session.VChannel() == vchannel {
 			session.Cleanup()
 			delete(m.sessions, id)
 			delete(m.recoveredSessions, id)
diff --git a/internal/streamingnode/server/wal/recovery/recovery_persisted.go b/internal/streamingnode/server/wal/recovery/recovery_persisted.go
index 1df0a14678..6b7a4ef342 100644
--- a/internal/streamingnode/server/wal/recovery/recovery_persisted.go
+++ b/internal/streamingnode/server/wal/recovery/recovery_persisted.go
@@ -158,7 +158,7 @@ func (r *recoveryStorageImpl) initializeRecoverInfo(ctx context.Context, channel
 	checkpoint := &streamingpb.WALCheckpoint{
 		MessageId:     untilMessage.LastConfirmedMessageID().IntoProto(),
 		TimeTick:      untilMessage.TimeTick(),
-		RecoveryMagic: RecoveryMagicStreamingInitialized,
+		RecoveryMagic: utility.RecoveryMagicStreamingInitialized,
 	}
 	if err := resource.Resource().StreamingNodeCatalog().SaveConsumeCheckpoint(ctx, channelInfo.Name, checkpoint); err != nil {
 		return nil, errors.Wrap(err, "failed to save checkpoint to catalog")
diff --git a/internal/streamingnode/server/wal/recovery/recovery_persisted_test.go b/internal/streamingnode/server/wal/recovery/recovery_persisted_test.go
index 64289c943c..ad75278ee8 100644
--- a/internal/streamingnode/server/wal/recovery/recovery_persisted_test.go
+++ b/internal/streamingnode/server/wal/recovery/recovery_persisted_test.go
@@ -16,6 +16,7 @@ import (
 	"github.com/milvus-io/milvus/internal/mocks"
 	"github.com/milvus-io/milvus/internal/mocks/mock_metastore"
 	"github.com/milvus-io/milvus/internal/streamingnode/server/resource"
+	"github.com/milvus-io/milvus/internal/streamingnode/server/wal/utility"
 	internaltypes "github.com/milvus-io/milvus/internal/types"
 	"github.com/milvus-io/milvus/pkg/v2/proto/datapb"
 	"github.com/milvus-io/milvus/pkg/v2/proto/etcdpb"
@@ -49,7 +50,7 @@ func TestInitRecoveryInfoFromMeta(t *testing.T) {
 		&streamingpb.WALCheckpoint{
 			MessageId:     rmq.NewRmqID(1).IntoProto(),
 			TimeTick:      1,
-			RecoveryMagic: RecoveryMagicStreamingInitialized,
+			RecoveryMagic: utility.RecoveryMagicStreamingInitialized,
 		}, nil)
 	resource.InitForTest(t, resource.OptStreamingNodeCatalog(snCatalog))
 	channel := types.PChannelInfo{Name: "test_channel"}
@@ -60,7 +61,7 @@ func TestInitRecoveryInfoFromMeta(t *testing.T) {
 	err := rs.recoverRecoveryInfoFromMeta(context.Background(), channel, lastConfirmed.IntoImmutableMessage(rmq.NewRmqID(1)))
 	assert.NoError(t, err)
 	assert.NotNil(t, rs.checkpoint)
-	assert.Equal(t, RecoveryMagicStreamingInitialized, rs.checkpoint.Magic)
+	assert.Equal(t, utility.RecoveryMagicStreamingInitialized, rs.checkpoint.Magic)
 	assert.True(t, rs.checkpoint.MessageID.EQ(rmq.NewRmqID(1)))
 }
diff --git a/internal/streamingnode/server/wal/recovery/recovery_storage_impl.go b/internal/streamingnode/server/wal/recovery/recovery_storage_impl.go
index 7b9403a194..0e66dea85d 100644
--- a/internal/streamingnode/server/wal/recovery/recovery_storage_impl.go
+++ b/internal/streamingnode/server/wal/recovery/recovery_storage_impl.go
@@ -11,12 +11,14 @@ import (
 	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
 	"github.com/milvus-io/milvus/internal/distributed/streaming"
 	"github.com/milvus-io/milvus/internal/streamingnode/server/resource"
+	"github.com/milvus-io/milvus/internal/streamingnode/server/wal/utility"
 	"github.com/milvus-io/milvus/pkg/v2/log"
 	"github.com/milvus-io/milvus/pkg/v2/proto/streamingpb"
 	"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
 	"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
 	"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
 	"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
+	"github.com/milvus-io/milvus/pkg/v2/util/replicateutil"
 	"github.com/milvus-io/milvus/pkg/v2/util/syncutil"
 )
@@ -69,6 +71,7 @@ func newRecoveryStorage(channel types.PChannelInfo) *recoveryStorageImpl {
 		backgroundTaskNotifier: syncutil.NewAsyncTaskNotifier[struct{}](),
 		cfg:                    cfg,
 		mu:                     sync.Mutex{},
+		currentClusterID:       paramtable.Get().CommonCfg.ClusterPrefix.GetValue(),
 		channel:                channel,
 		dirtyCounter:           0,
 		persistNotifier:        make(chan struct{}, 1),
@@ -84,6 +87,7 @@ type recoveryStorageImpl struct {
 	backgroundTaskNotifier *syncutil.AsyncTaskNotifier[struct{}]
 	cfg                    *config
 	mu                     sync.Mutex
+	currentClusterID       string
 	channel                types.PChannelInfo
 	segments               map[int64]*segmentRecoveryInfo
 	vchannels              map[string]*vchannelRecoveryInfo
@@ -225,8 +229,7 @@ func (r *recoveryStorageImpl) observeMessage(msg message.ImmutableMessage) {
 	}
 
 	r.handleMessage(msg)
-	r.checkpoint.TimeTick = msg.TimeTick()
-	r.checkpoint.MessageID = msg.LastConfirmedMessageID()
+	r.updateCheckpoint(msg)
 	r.metrics.ObServeInMemMetrics(r.checkpoint.TimeTick)
 
 	if !msg.IsPersisted() {
@@ -239,6 +242,52 @@
 	}
 }
 
+// updateCheckpoint updates the checkpoint of the recovery storage.
+func (r *recoveryStorageImpl) updateCheckpoint(msg message.ImmutableMessage) {
+	if msg.MessageType() == message.MessageTypeAlterReplicateConfig {
+		cfg := message.MustAsImmutableAlterReplicateConfigMessageV2(msg)
+		r.checkpoint.ReplicateConfig = cfg.Header().ReplicateConfiguration
+		clusterRole := replicateutil.MustNewConfigHelper(r.currentClusterID, cfg.Header().ReplicateConfiguration).GetCurrentCluster()
+		switch clusterRole.Role() {
+		case replicateutil.RolePrimary:
+			r.checkpoint.ReplicateCheckpoint = nil
+		case replicateutil.RoleSecondary:
+			// Update the replicate checkpoint if the cluster role is secondary.
+			sourceClusterID := clusterRole.SourceCluster().GetClusterId()
+			sourcePChannel := clusterRole.MustGetSourceChannel(r.channel.Name)
+			if r.checkpoint.ReplicateCheckpoint == nil || r.checkpoint.ReplicateCheckpoint.ClusterID != sourceClusterID {
+				r.checkpoint.ReplicateCheckpoint = &utility.ReplicateCheckpoint{
+					ClusterID: sourceClusterID,
+					PChannel:  sourcePChannel,
+					MessageID: nil,
+					TimeTick:  0,
+				}
+			}
+		}
+	}
+	r.checkpoint.MessageID = msg.LastConfirmedMessageID()
+	r.checkpoint.TimeTick = msg.TimeTick()
+
+	// update the replicate checkpoint.
+	replicateHeader := msg.ReplicateHeader()
+	if replicateHeader == nil {
+		return
+	}
+	if r.checkpoint.ReplicateCheckpoint == nil {
+		r.detectInconsistency(msg, "replicate checkpoint is nil when a replicate message arrives")
+		return
+	}
+	if replicateHeader.ClusterID != r.checkpoint.ReplicateCheckpoint.ClusterID {
+		r.detectInconsistency(msg,
+			"replicate header cluster id mismatch",
+			zap.String("expected", r.checkpoint.ReplicateCheckpoint.ClusterID),
+			zap.String("actual", replicateHeader.ClusterID))
+		return
+	}
+	r.checkpoint.ReplicateCheckpoint.MessageID = replicateHeader.LastConfirmedMessageID
+	r.checkpoint.ReplicateCheckpoint.TimeTick = replicateHeader.TimeTick
+}
+
 // The incoming message id is always sorted with timetick.
 func (r *recoveryStorageImpl) handleMessage(msg message.ImmutableMessage) {
 	if msg.VChannel() != "" &&
 		msg.MessageType() != message.MessageTypeCreateCollection &&
diff --git a/internal/streamingnode/server/wal/recovery/replicate_checkpoint_test.go b/internal/streamingnode/server/wal/recovery/replicate_checkpoint_test.go
new file mode 100644
index 0000000000..db6a21bfe0
--- /dev/null
+++ b/internal/streamingnode/server/wal/recovery/replicate_checkpoint_test.go
@@ -0,0 +1,105 @@
+package recovery
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+
+	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
+	"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
+	"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
+	"github.com/milvus-io/milvus/pkg/v2/streaming/walimpls/impls/walimplstest"
+)
+
+func TestUpdateCheckpoint(t *testing.T) {
+	rs := &recoveryStorageImpl{
+		currentClusterID: "test1",
+		channel:          types.PChannelInfo{Name: "test1-rootcoord-dml_0"},
+		checkpoint:       &WALCheckpoint{},
+		metrics:          newRecoveryStorageMetrics(types.PChannelInfo{Name: "test1-rootcoord-dml_0"}),
+	}
+
+	rs.updateCheckpoint(newAlterReplicateConfigMessage("test1", []string{"test2"}, 1, walimplstest.NewTestMessageID(1)))
+	assert.Nil(t, rs.checkpoint.ReplicateCheckpoint)
+	assert.Equal(t, rs.checkpoint.MessageID, walimplstest.NewTestMessageID(1))
+	assert.Equal(t, rs.checkpoint.TimeTick, uint64(1))
+	rs.updateCheckpoint(newAlterReplicateConfigMessage("test2", []string{"test1"}, 1, walimplstest.NewTestMessageID(1)))
+	assert.NotNil(t, rs.checkpoint.ReplicateCheckpoint)
+	assert.Equal(t, rs.checkpoint.ReplicateCheckpoint.ClusterID, "test2")
+	assert.Equal(t, rs.checkpoint.ReplicateCheckpoint.PChannel, "test2-rootcoord-dml_0")
+	assert.Nil(t, rs.checkpoint.ReplicateCheckpoint.MessageID)
+	assert.Zero(t, rs.checkpoint.ReplicateCheckpoint.TimeTick)
+
+	replicateMsg := message.NewReplicateMessage("test3", message.NewCreateDatabaseMessageBuilderV2().
+		WithHeader(&message.CreateDatabaseMessageHeader{}).
+		WithBody(&message.CreateDatabaseMessageBody{}).
+		WithVChannel("test3-rootcoord-dml_0").
+		MustBuildMutable().
+		WithTimeTick(3).
+		WithLastConfirmed(walimplstest.NewTestMessageID(10)).
+		IntoImmutableMessage(walimplstest.NewTestMessageID(20)).IntoImmutableMessageProto())
+	replicateMsg.OverwriteReplicateVChannel("test1-rootcoord-dml_0")
+	immutableReplicateMsg := replicateMsg.WithTimeTick(4).
+		WithLastConfirmed(walimplstest.NewTestMessageID(11)).
+		IntoImmutableMessage(walimplstest.NewTestMessageID(22))
+
+	// update with wrong clusterID (twice): the replicate checkpoint still
+	// expects test2, so the message replicated from test3 must be ignored.
+	rs.updateCheckpoint(immutableReplicateMsg)
+	rs.updateCheckpoint(immutableReplicateMsg)
+	assert.NotNil(t, rs.checkpoint.ReplicateCheckpoint)
+	assert.Equal(t, rs.checkpoint.ReplicateCheckpoint.ClusterID, "test2")
+	assert.Equal(t, rs.checkpoint.ReplicateCheckpoint.PChannel, "test2-rootcoord-dml_0")
+	assert.Nil(t, rs.checkpoint.ReplicateCheckpoint.MessageID)
+	assert.Zero(t, rs.checkpoint.ReplicateCheckpoint.TimeTick)
+
+	rs.updateCheckpoint(newAlterReplicateConfigMessage("test3", []string{"test2", "test1"}, 1, walimplstest.NewTestMessageID(1)))
+	assert.NotNil(t, rs.checkpoint.ReplicateCheckpoint)
+	assert.Equal(t, rs.checkpoint.ReplicateCheckpoint.ClusterID, "test3")
+	assert.Equal(t, rs.checkpoint.ReplicateCheckpoint.PChannel, "test3-rootcoord-dml_0")
+	assert.Nil(t, rs.checkpoint.ReplicateCheckpoint.MessageID)
+	assert.Zero(t, rs.checkpoint.ReplicateCheckpoint.TimeTick)
+
+	// update with right clusterID.
+	rs.updateCheckpoint(immutableReplicateMsg)
+	assert.NotNil(t, rs.checkpoint.ReplicateCheckpoint)
+	assert.Equal(t, rs.checkpoint.MessageID, walimplstest.NewTestMessageID(11))
+	assert.Equal(t, rs.checkpoint.TimeTick, uint64(4))
+	assert.Equal(t, rs.checkpoint.ReplicateCheckpoint.ClusterID, "test3")
+	assert.Equal(t, rs.checkpoint.ReplicateCheckpoint.PChannel, "test3-rootcoord-dml_0")
+	assert.True(t, rs.checkpoint.ReplicateCheckpoint.MessageID.EQ(walimplstest.NewTestMessageID(10)))
+	assert.Equal(t, rs.checkpoint.ReplicateCheckpoint.TimeTick, uint64(3))
+
+	rs.updateCheckpoint(newAlterReplicateConfigMessage("test1", []string{"test2"}, 1, walimplstest.NewTestMessageID(1)))
+	assert.Nil(t, rs.checkpoint.ReplicateCheckpoint)
+	// back to primary: a replicate message arriving now is an inconsistency and must not panic.
+	rs.updateCheckpoint(immutableReplicateMsg)
+}
+
+// newAlterReplicateConfigMessage creates a new alter replicate config message.
+func newAlterReplicateConfigMessage(primaryClusterID string, secondaryClusterIDs []string, timetick uint64, messageID message.MessageID) message.ImmutableMessage {
+	return message.NewAlterReplicateConfigMessageBuilderV2().
+		WithHeader(&message.AlterReplicateConfigMessageHeader{
+			ReplicateConfiguration: newReplicateConfiguration(primaryClusterID, secondaryClusterIDs...),
+		}).
+		WithBody(&message.AlterReplicateConfigMessageBody{}).
+		WithVChannel("test1-rootcoord-dml_0").
+		MustBuildMutable().
+		WithTimeTick(timetick).
+		WithLastConfirmed(messageID).
+		IntoImmutableMessage(walimplstest.NewTestMessageID(10086))
+}
+
+// newReplicateConfiguration creates a valid replicate configuration for testing.
+func newReplicateConfiguration(primaryClusterID string, secondaryClusterIDs ...string) *commonpb.ReplicateConfiguration {
+	clusters := []*commonpb.MilvusCluster{
+		{ClusterId: primaryClusterID, Pchannels: []string{primaryClusterID + "-rootcoord-dml_0", primaryClusterID + "-rootcoord-dml_1"}},
+	}
+	crossClusterTopology := []*commonpb.CrossClusterTopology{}
+	for _, secondaryClusterID := range secondaryClusterIDs {
+		clusters = append(clusters, &commonpb.MilvusCluster{ClusterId: secondaryClusterID, Pchannels: []string{secondaryClusterID + "-rootcoord-dml_0", secondaryClusterID + "-rootcoord-dml_1"}})
+		crossClusterTopology = append(crossClusterTopology, &commonpb.CrossClusterTopology{SourceClusterId: primaryClusterID, TargetClusterId: secondaryClusterID})
+	}
+	return &commonpb.ReplicateConfiguration{
+		Clusters:             clusters,
+		CrossClusterTopology: crossClusterTopology,
+	}
+}
diff --git a/internal/streamingnode/server/wal/recovery/wal_truncator_test.go b/internal/streamingnode/server/wal/recovery/wal_truncator_test.go
index 0cd81a5b44..e01155d6c8 100644
--- a/internal/streamingnode/server/wal/recovery/wal_truncator_test.go
+++ b/internal/streamingnode/server/wal/recovery/wal_truncator_test.go
@@ -7,6 +7,7 @@ import (
 
 	"github.com/stretchr/testify/mock"
 
+	"github.com/milvus-io/milvus/internal/streamingnode/server/wal/utility"
 	"github.com/milvus-io/milvus/pkg/v2/mocks/streaming/mock_walimpls"
 	"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
 	"github.com/milvus-io/milvus/pkg/v2/streaming/walimpls/impls/rmq"
@@ -23,7 +24,7 @@ func TestTruncator(t *testing.T) {
 	truncator := newSamplingTruncator(&WALCheckpoint{
 		MessageID: rmq.NewRmqID(1),
 		TimeTick:  1,
-		Magic:     RecoveryMagicStreamingInitialized,
+		Magic:     utility.RecoveryMagicStreamingInitialized,
 	}, w, newRecoveryStorageMetrics(types.PChannelInfo{Name: "test", Term: 1}))
 
 	for i := 0; i < 20; i++ {
@@ -32,7 +33,7 @@ func TestTruncator(t *testing.T) {
 		truncator.SampleCheckpoint(&WALCheckpoint{
 			MessageID: rmq.NewRmqID(int64(i)),
 			TimeTick:  tsoutil.ComposeTSByTime(time.Now(), 0),
-			Magic:     RecoveryMagicStreamingInitialized,
+			Magic:     utility.RecoveryMagicStreamingInitialized,
 		})
 	}
 }
diff --git a/internal/streamingnode/server/wal/utility/checkpoint.go b/internal/streamingnode/server/wal/utility/checkpoint.go
index 39c0fce419..aa2b523ea9 100644
--- a/internal/streamingnode/server/wal/utility/checkpoint.go
+++ b/internal/streamingnode/server/wal/utility/checkpoint.go
@@ -13,30 +13,21 @@ const (
 
 // NewWALCheckpointFromProto creates a new WALCheckpoint from a protobuf message.
 func NewWALCheckpointFromProto(cp *streamingpb.WALCheckpoint) *WALCheckpoint {
-	wcp := &WALCheckpoint{
-		MessageID:       message.MustUnmarshalMessageID(cp.MessageId),
-		TimeTick:        cp.TimeTick,
-		Magic:           cp.RecoveryMagic,
-		ReplicateConfig: cp.ReplicateConfig,
+	if cp == nil {
+		return nil
 	}
-	if cp.ReplicateCheckpoint != nil {
-		var messageID message.MessageID
-		if cp.ReplicateCheckpoint.MessageId != nil {
-			messageID = message.MustUnmarshalMessageID(cp.ReplicateCheckpoint.MessageId)
-		}
-		wcp.ReplicateCheckpoint = &ReplicateCheckpoint{
-			ClusterID: cp.ReplicateCheckpoint.ClusterId,
-			PChannel:  cp.ReplicateCheckpoint.Pchannel,
-			MessageID: messageID,
-			TimeTick:  cp.ReplicateCheckpoint.TimeTick,
-		}
+	return &WALCheckpoint{
+		MessageID:           message.MustUnmarshalMessageID(cp.MessageId),
+		TimeTick:            cp.TimeTick,
+		Magic:               cp.RecoveryMagic,
+		ReplicateConfig:     cp.ReplicateConfig,
+		ReplicateCheckpoint: NewReplicateCheckpointFromProto(cp.ReplicateCheckpoint),
 	}
-	return wcp
 }
 
 // WALCheckpoint represents a consume checkpoint in the Write-Ahead Log (WAL).
 type WALCheckpoint struct {
-	MessageID           message.MessageID
+	MessageID           message.MessageID // must never be nil.
 	TimeTick            uint64
 	Magic               int64
 	ReplicateCheckpoint *ReplicateCheckpoint
@@ -45,15 +36,16 @@ type WALCheckpoint struct {
 
 // IntoProto converts the WALCheckpoint to a protobuf message.
 func (c *WALCheckpoint) IntoProto() *streamingpb.WALCheckpoint {
-	cp := &streamingpb.WALCheckpoint{
-		MessageId:     c.MessageID.IntoProto(),
-		TimeTick:      c.TimeTick,
-		RecoveryMagic: c.Magic,
+	if c == nil {
+		return nil
 	}
-	if c.ReplicateCheckpoint != nil {
-		cp.ReplicateCheckpoint = c.ReplicateCheckpoint.IntoProto()
+	return &streamingpb.WALCheckpoint{
+		MessageId:           message.MustMarshalMessageID(c.MessageID),
+		TimeTick:            c.TimeTick,
+		RecoveryMagic:       c.Magic,
+		ReplicateConfig:     c.ReplicateConfig,
+		ReplicateCheckpoint: c.ReplicateCheckpoint.IntoProto(),
 	}
-	return cp
 }
 
 // Clone creates a new WALCheckpoint with the same values as the original.
@@ -62,16 +54,20 @@ func (c *WALCheckpoint) Clone() *WALCheckpoint {
 		MessageID:           c.MessageID,
 		TimeTick:            c.TimeTick,
 		Magic:               c.Magic,
+		ReplicateConfig:     c.ReplicateConfig,
 		ReplicateCheckpoint: c.ReplicateCheckpoint.Clone(),
 	}
 }
 
 // NewReplicateCheckpointFromProto creates a new ReplicateCheckpoint from a protobuf message.
 func NewReplicateCheckpointFromProto(cp *commonpb.ReplicateCheckpoint) *ReplicateCheckpoint {
+	if cp == nil {
+		return nil
+	}
 	return &ReplicateCheckpoint{
+		MessageID: message.MustUnmarshalMessageID(cp.MessageId),
 		ClusterID: cp.ClusterId,
 		PChannel:  cp.Pchannel,
-		MessageID: message.MustUnmarshalMessageID(cp.MessageId),
 		TimeTick:  cp.TimeTick,
 	}
 }
@@ -81,7 +77,7 @@ func NewReplicateCheckpointFromProto(cp *commonpb.ReplicateCheckpoint) *Replicat
 type ReplicateCheckpoint struct {
 	ClusterID string // the cluster id of the source cluster.
 	PChannel  string // the pchannel of the source cluster.
-	MessageID message.MessageID // the last confirmed message id of the last replicated message.
+	MessageID message.MessageID // the last confirmed message id of the last replicated message, may be nil when initializing.
 	TimeTick  uint64            // the time tick of the last replicated message.
 }
 
@@ -90,14 +86,10 @@ func (c *ReplicateCheckpoint) IntoProto() *commonpb.ReplicateCheckpoint {
 	if c == nil {
 		return nil
 	}
-	var messageID *commonpb.MessageID
-	if c.MessageID != nil {
-		messageID = c.MessageID.IntoProto()
-	}
 	return &commonpb.ReplicateCheckpoint{
 		ClusterId: c.ClusterID,
 		Pchannel:  c.PChannel,
-		MessageId: messageID,
+		MessageId: message.MustMarshalMessageID(c.MessageID),
 		TimeTick:  c.TimeTick,
 	}
 }
diff --git a/internal/streamingnode/server/wal/utility/checkpoint_test.go b/internal/streamingnode/server/wal/utility/checkpoint_test.go
index b52c3c76bf..dffd72866a 100644
--- a/internal/streamingnode/server/wal/utility/checkpoint_test.go
+++ b/internal/streamingnode/server/wal/utility/checkpoint_test.go
@@ -11,6 +11,9 @@ import (
 )
 
 func TestNewWALCheckpointFromProto(t *testing.T) {
+	assert.Nil(t, NewWALCheckpointFromProto(nil))
+	assert.Nil(t, NewWALCheckpointFromProto(nil).IntoProto())
+
 	messageID := rmq.NewRmqID(1)
 	timeTick := uint64(12345)
 	recoveryMagic := int64(1)
@@ -59,4 +62,25 @@ func TestNewWALCheckpointFromProto(t *testing.T) {
 	assert.Equal(t, uint64(123456), newCheckpoint.ReplicateCheckpoint.TimeTick)
 	assert.True(t, rmq.NewRmqID(2).EQ(newCheckpoint.ReplicateCheckpoint.MessageID))
 	assert.NotNil(t, newCheckpoint.ReplicateConfig)
+
+	proto = newCheckpoint.IntoProto()
+	checkpoint2 = NewWALCheckpointFromProto(proto)
+	assert.True(t, messageID.EQ(checkpoint2.MessageID))
+	assert.Equal(t, timeTick, checkpoint2.TimeTick)
+	assert.Equal(t, recoveryMagic, checkpoint2.Magic)
+	assert.Equal(t, "by-dev", checkpoint2.ReplicateCheckpoint.ClusterID)
+	assert.Equal(t, "p1", checkpoint2.ReplicateCheckpoint.PChannel)
+	assert.Equal(t, uint64(123456), checkpoint2.ReplicateCheckpoint.TimeTick)
+	assert.True(t, rmq.NewRmqID(2).EQ(checkpoint2.ReplicateCheckpoint.MessageID))
+	assert.NotNil(t, checkpoint2.ReplicateConfig)
+
+	checkpoint2 = newCheckpoint.Clone()
+	assert.True(t, messageID.EQ(checkpoint2.MessageID))
+	assert.Equal(t, timeTick, checkpoint2.TimeTick)
+	assert.Equal(t, recoveryMagic, checkpoint2.Magic)
+	assert.Equal(t, "by-dev", checkpoint2.ReplicateCheckpoint.ClusterID)
+	assert.Equal(t, "p1", checkpoint2.ReplicateCheckpoint.PChannel)
+	assert.Equal(t, uint64(123456), checkpoint2.ReplicateCheckpoint.TimeTick)
+	assert.True(t, rmq.NewRmqID(2).EQ(checkpoint2.ReplicateCheckpoint.MessageID))
+	assert.NotNil(t, checkpoint2.ReplicateConfig)
 }
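Aside (not part of the diff): with the nil guards above, a checkpoint now survives a proto round trip even before it is initialized — a minimal sketch, assuming the utility package names:

    var cp *utility.WALCheckpoint              // not yet initialized
    pb := cp.IntoProto()                       // nil, no panic
    cp = utility.NewWALCheckpointFromProto(pb) // nil again, symmetric

diff --git a/internal/streamingnode/server/wal/wal.go b/internal/streamingnode/server/wal/wal.go
index 6153ca85a3..a5bb59e6c4 100644
--- a/internal/streamingnode/server/wal/wal.go
+++ b/internal/streamingnode/server/wal/wal.go
@@ -21,6 +21,13 @@ type WAL interface {
 	// GetLatestMVCCTimestamp get the latest mvcc timestamp of the wal at vchannel.
 	GetLatestMVCCTimestamp(ctx context.Context, vchannel string) (uint64, error)
 
+	// GetReplicateCheckpoint returns the replicate checkpoint of the wal.
+	// If the wal is not in replicating mode, it returns a ReplicateViolationError.
+	// If the wal is in replicating mode, it returns the replicate checkpoint of the wal.
+	// If the wal has just been initialized into replica (secondary) mode and has not
+	// replicated any message yet, the message id of the replicate checkpoint will be nil.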
+	GetReplicateCheckpoint() (*ReplicateCheckpoint, error)
+
 	// Append writes a record to the log.
 	Append(ctx context.Context, msg message.MutableMessage) (*AppendResult, error)
diff --git a/internal/streamingnode/server/walmanager/manager_impl.go b/internal/streamingnode/server/walmanager/manager_impl.go
index 7d84bbf579..cae359ca8e 100644
--- a/internal/streamingnode/server/walmanager/manager_impl.go
+++ b/internal/streamingnode/server/walmanager/manager_impl.go
@@ -9,6 +9,7 @@ import (
 	"github.com/milvus-io/milvus/internal/streamingnode/server/wal"
 	"github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/lock"
 	"github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/redo"
+	"github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/replicate"
 	"github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/shard"
 	"github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/timetick"
 	"github.com/milvus-io/milvus/internal/streamingnode/server/wal/registry"
@@ -28,6 +29,7 @@ func OpenManager() (Manager, error) {
 	opener, err := registry.MustGetBuilder(walName,
 		redo.NewInterceptorBuilder(),
 		lock.NewInterceptorBuilder(),
+		replicate.NewInterceptorBuilder(),
 		timetick.NewInterceptorBuilder(),
 		shard.NewInterceptorBuilder(),
 	).Build()
diff --git a/internal/util/streamingutil/status/streaming_error.go b/internal/util/streamingutil/status/streaming_error.go
index 43997005e3..4fa407b6bb 100644
--- a/internal/util/streamingutil/status/streaming_error.go
+++ b/internal/util/streamingutil/status/streaming_error.go
@@ -56,10 +56,15 @@ func (e *StreamingError) IsSkippedOperation() bool {
 // Stop resuming retry and report to user.
 func (e *StreamingError) IsUnrecoverable() bool {
 	return e.Code == streamingpb.StreamingCode_STREAMING_CODE_UNRECOVERABLE ||
-		e.Code == streamingpb.StreamingCode_STREAMING_CODE_REPLICATE_VIOLATION ||
+		e.IsReplicateViolation() ||
 		e.IsTxnUnavilable()
 }
 
+// IsReplicateViolation returns true if the error is caused by replicate violation.
+func (e *StreamingError) IsReplicateViolation() bool {
+	return e.Code == streamingpb.StreamingCode_STREAMING_CODE_REPLICATE_VIOLATION
+}
+
 // IsTxnUnavilable returns true if the transaction is unavailable.
 func (e *StreamingError) IsTxnUnavilable() bool {
 	return e.Code == streamingpb.StreamingCode_STREAMING_CODE_TRANSACTION_EXPIRED ||
diff --git a/pkg/metrics/cdc_metrics.go b/pkg/metrics/cdc_metrics.go
index c34316f0c2..fa6da60d53 100644
--- a/pkg/metrics/cdc_metrics.go
+++ b/pkg/metrics/cdc_metrics.go
@@ -74,7 +74,7 @@ var CDCReplicateEndToEndLatency = prometheus.NewHistogramVec(
 		Namespace: milvusNamespace,
 		Subsystem: typeutil.CDCRole,
 		Name:      CDCMetricReplicateEndToEndLatency,
-		Help:      "End-to-end latency from a single message being read from Source WAL to being written to Target WAL and receiving an ack",
+		Help:      "End-to-end latency in milliseconds from a single message being read from Source WAL to being written to Target WAL and receiving an ack",
 		Buckets:   buckets,
 	}, []string{
 		CDCLabelSourceChannelName,
@@ -82,13 +82,12 @@ var CDCReplicateEndToEndLatency = prometheus.NewHistogramVec(
 	},
 )
 
-// TODO: sheep
 var CDCReplicateLag = prometheus.NewGaugeVec(
 	prometheus.GaugeOpts{
 		Namespace: milvusNamespace,
 		Subsystem: typeutil.CDCRole,
 		Name:      CDCMetricReplicateLag,
-		Help:      "Lag between the latest message in Source and the latest message in Target",
+		Help:      "Lag in milliseconds between the latest synced Source message and the current time",
 	}, []string{
 		CDCLabelSourceChannelName,
 		CDCLabelTargetChannelName,
diff --git a/pkg/streaming/util/message/message.go b/pkg/streaming/util/message/message.go
index a9c7b34e38..684a084e49 100644
--- a/pkg/streaming/util/message/message.go
+++ b/pkg/streaming/util/message/message.go
@@ -78,7 +78,7 @@ type MutableMessage interface {
 
 	// VChannel returns the virtual channel of current message.
 	// Available only when the message's version greater than 0.
-	// Return "" if message is can be seen by all vchannels on the pchannel.
+	// Return "" or the pchannel name if the message can be seen by all vchannels on the pchannel.
 	VChannel() string
 
 	// WithBarrierTimeTick sets the barrier time tick of current message.
diff --git a/pkg/streaming/util/message/message_id.go b/pkg/streaming/util/message/message_id.go
index 9497b2922d..02a69a7025 100644
--- a/pkg/streaming/util/message/message_id.go
+++ b/pkg/streaming/util/message/message_id.go
@@ -27,8 +27,19 @@ func RegisterMessageIDUnmsarshaler(walName WALName, unmarshaler MessageIDUnmarsh
 // MessageIDUnmarshaler is the unmarshaler for message id.
 type MessageIDUnmarshaler = func(b string) (MessageID, error)
 
+// MustMarshalMessageID marshals the message id into a proto, returning nil for a nil message id.
+func MustMarshalMessageID(msgID MessageID) *commonpb.MessageID {
+	if msgID == nil {
+		return nil
+	}
+	return msgID.IntoProto()
+}
+
 // MustUnmarshalMessageID unmarshal the message id, panic if failed.
 func MustUnmarshalMessageID(msgID *commonpb.MessageID) MessageID {
+	if msgID == nil {
+		return nil
+	}
 	id, err := UnmarshalMessageID(msgID)
 	if err != nil {
 		panic(fmt.Sprintf("unmarshal message id failed: %s, wal: %s, bytes: %s", err.Error(), msgID.WALName.String(), msgID.Id))
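Aside (not part of the diff): the marshal/unmarshal helpers are now symmetric around nil, which is what keeps ReplicateCheckpoint.IntoProto safe for a checkpoint that has not replicated anything yet:

    var id message.MessageID                // nil: nothing replicated yet
    pb := message.MustMarshalMessageID(id)  // nil, no panic
    id = message.MustUnmarshalMessageID(pb) // nil again

diff --git a/pkg/streaming/util/message/message_impl.go b/pkg/streaming/util/message/message_impl.go
index a979548272..75886a3121 100644
--- a/pkg/streaming/util/message/message_impl.go
+++ b/pkg/streaming/util/message/message_impl.go
@@ -177,6 +177,13 @@ func (m *messageImpl) OverwriteReplicateVChannel(vchannel string, broadcastVChan
 		panic("should not happen on broadcast header proto")
 	}
 	m.properties.Set(messageBroadcastHeader, bhVal)
+
+	// Overwrite the txn keepalive to infinite if it's a replicated message:
+	// a replicated message is already committed on the source cluster, so its txn must never expire.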
+	if txnCtx := m.TxnContext(); txnCtx != nil {
+		txnCtx.Keepalive = TxnKeepaliveInfinite
+		m.WithTxnContext(*txnCtx)
+	}
 }
 
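Aside (not part of the diff): the end-to-end effect of OverwriteReplicateVChannel on a txn message, mirroring the newReplicateMessage test helper; `immutableBegin` is a placeholder begin-txn message:

    msg := message.NewReplicateMessage("source", immutableBegin.IntoImmutableMessageProto())
    msg.OverwriteReplicateVChannel("target-rootcoord-dml_0")
    // msg.VChannel() == "target-rootcoord-dml_0"
    // msg.TxnContext().Keepalive == message.TxnKeepaliveInfinite (pinned above)

 // OverwriteBroadcastHeader overwrites the broadcast header of the message.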