enhance: support replicate message in wal. (#44456)

issue: #44123

- Support replicate messages in the WAL of Milvus.
- Support CDC replication recovery from the WAL (see the recovery sketch after the commit metadata below).
- Fix several CDC replicator bugs.

Signed-off-by: chyezh <chyezh@outlook.com>
Zhen Ye, 2025-09-22 17:06:11 +08:00, committed by GitHub
parent edd250ffef
commit c171280f63
59 changed files with 2142 additions and 196 deletions
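Before the per-file diffs, a minimal sketch of the recovery flow this change enables: on restart, a CDC replicator loads its replicate checkpoint and resumes the WAL scan from the checkpoint's message ID, filtering out anything at or below the checkpoint's time tick. The snippet mirrors the replicator changes below; names such as sourceChannel and ch are illustrative, not part of the patch.

	// Resume replication from a persisted checkpoint (sketch).
	// A checkpoint pairs a WAL message ID (where to seek) with a time tick
	// (what has already been replicated and must be skipped).
	cp, err := r.getReplicateCheckpoint()
	if err != nil {
		return err
	}
	scanner := streaming.WAL().Read(r.ctx, streaming.ReadOption{
		PChannel:       sourceChannel,
		DeliverPolicy:  options.DeliverPolicyStartFrom(cp.MessageID),
		DeliverFilters: []options.DeliverFilter{options.DeliverFilterTimeTickGT(cp.TimeTick)},
		MessageHandler: ch,
	})
	defer scanner.Close()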

go.mod

@@ -86,6 +86,7 @@ require (
	google.golang.org/protobuf v1.36.5
	gopkg.in/yaml.v3 v3.0.1
	mosn.io/holmes v1.0.2
+	mosn.io/pkg v0.0.0-20211217101631-d914102d1baf
)

require (
@@ -151,6 +152,8 @@ require (
	github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
	github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2 // indirect
	github.com/docker/go-units v0.5.0 // indirect
+	github.com/dubbogo/getty v1.3.4 // indirect
+	github.com/dubbogo/gost v1.11.16 // indirect
	github.com/dustin/go-humanize v1.0.1 // indirect
	github.com/dvsekhvalnov/jose2go v1.6.0 // indirect
	github.com/ebitengine/purego v0.8.1 // indirect
@@ -194,6 +197,7 @@ require (
	github.com/ianlancetaylor/cgosymbolizer v0.0.0-20221217025313-27d3c9f66b6a // indirect
	github.com/jonboulle/clockwork v0.2.2 // indirect
	github.com/json-iterator/go v1.1.12 // indirect
+	github.com/k0kubun/pp v3.0.1+incompatible // indirect
	github.com/klauspost/asmfmt v1.3.2 // indirect
	github.com/klauspost/cpuid/v2 v2.2.8 // indirect
	github.com/kr/pretty v0.3.1 // indirect
@@ -298,7 +302,6 @@ require (
	k8s.io/klog/v2 v2.130.1 // indirect
	k8s.io/utils v0.0.0-20250321185631-1f6e0b77f77e // indirect
	mosn.io/api v0.0.0-20210204052134-5b9a826795fd // indirect
-	mosn.io/pkg v0.0.0-20211217101631-d914102d1baf // indirect
	sigs.k8s.io/json v0.0.0-20241010143419-9aa6b5e7a4b3 // indirect
	sigs.k8s.io/structured-merge-diff/v4 v4.4.2 // indirect
	sigs.k8s.io/yaml v1.4.0 // indirect

go.sum

@@ -303,9 +303,11 @@ github.com/docker/go-connections v0.5.0 h1:USnMq7hx7gwdVZq1L49hLXaFtUdTADjXGp+uj
github.com/docker/go-connections v0.5.0/go.mod h1:ov60Kzw0kKElRwhNs9UlUHAE/F9Fe6GLaXnqyDdmEXc=
github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4=
github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
+github.com/dubbogo/getty v1.3.4 h1:5TvH213pnSIKYzY7IK8TT/r6yr5uPTB/U6YNLT+GsU0=
github.com/dubbogo/getty v1.3.4/go.mod h1:36f+gH/ekaqcDWKbxNBQk9b9HXcGtaI6YHxp4YTntX8=
github.com/dubbogo/go-zookeeper v1.0.3/go.mod h1:fn6n2CAEer3novYgk9ULLwAjuV8/g4DdC2ENwRb6E+c=
github.com/dubbogo/gost v1.5.2/go.mod h1:pPTjVyoJan3aPxBPNUX0ADkXjPibLo+/Ib0/fADXSG8=
+github.com/dubbogo/gost v1.11.16 h1:fvOw8aKQ0BuUYuD+MaXAYFvT7tg2l7WAS5SL5gZJpFs=
github.com/dubbogo/gost v1.11.16/go.mod h1:vIcP9rqz2KsXHPjsAwIUtfJIJjppQLQDcYaZTy/61jI=
github.com/dubbogo/jsonparser v1.0.1/go.mod h1:tYAtpctvSP/tWw4MeelsowSPgXQRVHHWbqL6ynps8jU=
github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk=
@@ -657,6 +659,7 @@ github.com/juju/cmd v0.0.0-20171107070456-e74f39857ca0/go.mod h1:yWJQHl73rdSX4DH
github.com/juju/collections v0.0.0-20200605021417-0d0ec82b7271/go.mod h1:5XgO71dV1JClcOJE+4dzdn4HrI5LiyKd7PlVG6eZYhY=
github.com/juju/errors v0.0.0-20150916125642-1b5e39b83d18/go.mod h1:W54LbzXuIE0boCoNJfwqpmkKJ1O4TCTZMetAt6jGk7Q=
github.com/juju/errors v0.0.0-20190930114154-d42613fe1ab9/go.mod h1:W54LbzXuIE0boCoNJfwqpmkKJ1O4TCTZMetAt6jGk7Q=
+github.com/juju/errors v0.0.0-20200330140219-3fe23663418f h1:MCOvExGLpaSIzLYB4iQXEHP4jYVU6vmzLNQPdMVrxnM=
github.com/juju/errors v0.0.0-20200330140219-3fe23663418f/go.mod h1:W54LbzXuIE0boCoNJfwqpmkKJ1O4TCTZMetAt6jGk7Q=
github.com/juju/gnuflag v0.0.0-20171113085948-2ce1bb71843d/go.mod h1:2PavIy+JPciBPrBUjwbNvtwB6RQlve+hkpll6QSNmOE=
github.com/juju/httpprof v0.0.0-20141217160036-14bf14c30767/go.mod h1:+MaLYz4PumRkkyHYeXJ2G5g5cIW0sli2bOfpmbaMV/g=
@@ -681,7 +684,9 @@ github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7V
github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM=
github.com/jung-kurt/gofpdf v1.0.0/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+r32jIOes=
github.com/jung-kurt/gofpdf v1.0.3-0.20190309125859-24315acbbda5/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+r32jIOes=
+github.com/k0kubun/colorstring v0.0.0-20150214042306-9440f1994b88 h1:uC1QfSlInpQF+M0ao65imhwqKnz3Q2z/d8PWZRMQvDM=
github.com/k0kubun/colorstring v0.0.0-20150214042306-9440f1994b88/go.mod h1:3w7q1U84EfirKl04SVQ/s7nPm1ZPhiXd34z40TNz36k=
+github.com/k0kubun/pp v3.0.1+incompatible h1:3tqvf7QgUnZ5tXO6pNAZlrvHgl6DvifjDrd9g2S9Z40=
github.com/k0kubun/pp v3.0.1+incompatible/go.mod h1:GWse8YhT0p8pT4ir3ZgBbfZild3tgzSScAn6HmfYukg=
github.com/kataras/golog v0.0.10/go.mod h1:yJ8YKCmyL+nWjERB90Qwn+bdyBZsaQwU3bTVFgkFIp8=
github.com/kataras/iris/v12 v12.1.8/go.mod h1:LMYy4VlP67TQ3Zgriz8RE2h2kMZV2SgMYbq3UhfoFmE=


@@ -57,6 +57,10 @@ packages:
      InterceptorWithReady:
      InterceptorWithMetrics:
      InterceptorBuilder:
+  github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/replicate/replicates:
+    interfaces:
+      ReplicatesManager:
+      ReplicateAcker:
  github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/shard/shards:
    interfaces:
      ShardManager:


@@ -64,20 +64,24 @@ func (c *controller) Start() {
func (c *controller) Stop() {
	c.stopOnce.Do(func() {
		log.Ctx(c.ctx).Info("CDC controller stopping...")
+		// TODO: sheep, gracefully stop the replicators
		close(c.stopChan)
		c.wg.Wait()
+		resource.Resource().ReplicateManagerClient().Close()
		log.Ctx(c.ctx).Info("CDC controller stopped")
	})
}

func (c *controller) run() {
-	replicatePChannels, err := resource.Resource().ReplicationCatalog().ListReplicatePChannels(c.ctx)
+	targetReplicatePChannels, err := resource.Resource().ReplicationCatalog().ListReplicatePChannels(c.ctx)
	if err != nil {
		log.Ctx(c.ctx).Error("failed to get replicate pchannels", zap.Error(err))
		return
	}
-	for _, replicatePChannel := range replicatePChannels {
+	// create replicators for all replicate pchannels
+	for _, replicatePChannel := range targetReplicatePChannels {
		resource.Resource().ReplicateManagerClient().CreateReplicator(replicatePChannel)
	}
+	// remove out of target replicators
+	resource.Resource().ReplicateManagerClient().RemoveOutOfTargetReplicators(targetReplicatePChannels)
}


@@ -30,6 +30,7 @@ import (
func TestController_StartAndStop(t *testing.T) {
	mockReplicateManagerClient := replication.NewMockReplicateManagerClient(t)
+	mockReplicateManagerClient.EXPECT().Close().Return()
	resource.InitForTest(t,
		resource.OptReplicateManagerClient(mockReplicateManagerClient),
	)
@@ -45,6 +46,7 @@ func TestController_StartAndStop(t *testing.T) {
func TestController_Run(t *testing.T) {
	mockReplicateManagerClient := replication.NewMockReplicateManagerClient(t)
+	mockReplicateManagerClient.EXPECT().Close().Return()
	replicatePChannels := []*streamingpb.ReplicatePChannelMeta{
		{
@@ -55,6 +57,7 @@ func TestController_Run(t *testing.T) {
	mockReplicationCatalog := mock_metastore.NewMockReplicationCatalog(t)
	mockReplicationCatalog.EXPECT().ListReplicatePChannels(mock.Anything).Return(replicatePChannels, nil)
	mockReplicateManagerClient.EXPECT().CreateReplicator(replicatePChannels[0]).Return()
+	mockReplicateManagerClient.EXPECT().RemoveOutOfTargetReplicators(replicatePChannels).Return()
	resource.InitForTest(t,
		resource.OptReplicateManagerClient(mockReplicateManagerClient),
		resource.OptReplicationCatalog(mockReplicationCatalog),
@@ -68,6 +71,7 @@ func TestController_Run(t *testing.T) {
func TestController_RunError(t *testing.T) {
	mockReplicateManagerClient := replication.NewMockReplicateManagerClient(t)
+	mockReplicateManagerClient.EXPECT().Close().Return()
	mockReplicationCatalog := mock_metastore.NewMockReplicationCatalog(t)
	mockReplicationCatalog.EXPECT().ListReplicatePChannels(mock.Anything).Return(nil, assert.AnError)


@@ -20,6 +20,38 @@ func (_m *MockReplicateManagerClient) EXPECT() *MockReplicateManagerClient_Expec
	return &MockReplicateManagerClient_Expecter{mock: &_m.Mock}
}

+// Close provides a mock function with no fields
+func (_m *MockReplicateManagerClient) Close() {
+	_m.Called()
+}
+
+// MockReplicateManagerClient_Close_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Close'
+type MockReplicateManagerClient_Close_Call struct {
+	*mock.Call
+}
+
+// Close is a helper method to define mock.On call
+func (_e *MockReplicateManagerClient_Expecter) Close() *MockReplicateManagerClient_Close_Call {
+	return &MockReplicateManagerClient_Close_Call{Call: _e.mock.On("Close")}
+}
+
+func (_c *MockReplicateManagerClient_Close_Call) Run(run func()) *MockReplicateManagerClient_Close_Call {
+	_c.Call.Run(func(args mock.Arguments) {
+		run()
+	})
+	return _c
+}
+
+func (_c *MockReplicateManagerClient_Close_Call) Return() *MockReplicateManagerClient_Close_Call {
+	_c.Call.Return()
+	return _c
+}
+
+func (_c *MockReplicateManagerClient_Close_Call) RunAndReturn(run func()) *MockReplicateManagerClient_Close_Call {
+	_c.Run(run)
+	return _c
+}
+
// CreateReplicator provides a mock function with given fields: replicateInfo
func (_m *MockReplicateManagerClient) CreateReplicator(replicateInfo *streamingpb.ReplicatePChannelMeta) {
	_m.Called(replicateInfo)
@@ -53,6 +85,39 @@ func (_c *MockReplicateManagerClient_CreateReplicator_Call) RunAndReturn(run fun
	return _c
}

+// RemoveOutOfTargetReplicators provides a mock function with given fields: targetReplicatePChannels
+func (_m *MockReplicateManagerClient) RemoveOutOfTargetReplicators(targetReplicatePChannels []*streamingpb.ReplicatePChannelMeta) {
+	_m.Called(targetReplicatePChannels)
+}
+
+// MockReplicateManagerClient_RemoveOutOfTargetReplicators_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RemoveOutOfTargetReplicators'
+type MockReplicateManagerClient_RemoveOutOfTargetReplicators_Call struct {
+	*mock.Call
+}
+
+// RemoveOutOfTargetReplicators is a helper method to define mock.On call
+//   - targetReplicatePChannels []*streamingpb.ReplicatePChannelMeta
+func (_e *MockReplicateManagerClient_Expecter) RemoveOutOfTargetReplicators(targetReplicatePChannels interface{}) *MockReplicateManagerClient_RemoveOutOfTargetReplicators_Call {
+	return &MockReplicateManagerClient_RemoveOutOfTargetReplicators_Call{Call: _e.mock.On("RemoveOutOfTargetReplicators", targetReplicatePChannels)}
+}
+
+func (_c *MockReplicateManagerClient_RemoveOutOfTargetReplicators_Call) Run(run func(targetReplicatePChannels []*streamingpb.ReplicatePChannelMeta)) *MockReplicateManagerClient_RemoveOutOfTargetReplicators_Call {
+	_c.Call.Run(func(args mock.Arguments) {
+		run(args[0].([]*streamingpb.ReplicatePChannelMeta))
+	})
+	return _c
+}
+
+func (_c *MockReplicateManagerClient_RemoveOutOfTargetReplicators_Call) Return() *MockReplicateManagerClient_RemoveOutOfTargetReplicators_Call {
+	_c.Call.Return()
+	return _c
+}
+
+func (_c *MockReplicateManagerClient_RemoveOutOfTargetReplicators_Call) RunAndReturn(run func([]*streamingpb.ReplicatePChannelMeta)) *MockReplicateManagerClient_RemoveOutOfTargetReplicators_Call {
+	_c.Run(run)
+	return _c
+}
+
// NewMockReplicateManagerClient creates a new instance of MockReplicateManagerClient. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewMockReplicateManagerClient(t interface {


@@ -22,4 +22,10 @@ import "github.com/milvus-io/milvus/pkg/v2/proto/streamingpb"
type ReplicateManagerClient interface {
	// CreateReplicator creates a new replicator for the replicate pchannel.
	CreateReplicator(replicateInfo *streamingpb.ReplicatePChannelMeta)
+	// RemoveOutOfTargetReplicators removes replicators that are not in the target replicate pchannels.
+	RemoveOutOfTargetReplicators(targetReplicatePChannels []*streamingpb.ReplicatePChannelMeta)
+	// Close closes the replicate manager client.
+	Close()
}
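A hedged sketch of how the CDC controller is expected to drive this interface on each reconcile pass; the reconcile function and its parameter names are illustrative (the actual wiring lives in the controller shown above), and the import paths are assumptions about the package layout:

	import (
		"context"

		"github.com/milvus-io/milvus/internal/cdc/replication"
		"github.com/milvus-io/milvus/internal/metastore"
	)

	// Reconcile running replicators against the catalog's target set:
	// ensure a replicator exists for every target pchannel, then drop any
	// replicator whose pchannel is no longer a target.
	func reconcile(ctx context.Context, catalog metastore.ReplicationCatalog, client replication.ReplicateManagerClient) error {
		targets, err := catalog.ListReplicatePChannels(ctx)
		if err != nil {
			return err
		}
		for _, target := range targets {
			client.CreateReplicator(target) // no-op if the replicator already exists
		}
		client.RemoveOutOfTargetReplicators(targets)
		return nil
	}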


@@ -27,13 +27,13 @@ import (
	"github.com/milvus-io/milvus/internal/cdc/replication/replicatestream"
	"github.com/milvus-io/milvus/internal/cdc/resource"
	"github.com/milvus-io/milvus/internal/distributed/streaming"
+	"github.com/milvus-io/milvus/internal/streamingnode/server/wal/utility"
	"github.com/milvus-io/milvus/pkg/v2/log"
	"github.com/milvus-io/milvus/pkg/v2/proto/streamingpb"
	"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
	"github.com/milvus-io/milvus/pkg/v2/streaming/util/message/adaptor"
	"github.com/milvus-io/milvus/pkg/v2/streaming/util/options"
	"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
-	"github.com/milvus-io/milvus/pkg/v2/util/replicateutil"
	"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
@@ -108,23 +108,15 @@ func (r *channelReplicator) replicateLoop() error {
		zap.String("sourceChannel", r.replicateInfo.GetSourceChannelName()),
		zap.String("targetChannel", r.replicateInfo.GetTargetChannelName()),
	)
-	startFrom, err := r.getReplicateStartMessageID()
+	cp, err := r.getReplicateCheckpoint()
	if err != nil {
		return err
	}
	ch := make(adaptor.ChanMessageHandler, scannerHandlerChanSize)
-	var deliverPolicy options.DeliverPolicy
-	if startFrom == nil {
-		// No checkpoint found, seek from the earliest position
-		deliverPolicy = options.DeliverPolicyAll()
-	} else {
-		// Seek from the checkpoint
-		deliverPolicy = options.DeliverPolicyStartFrom(startFrom)
-	}
	scanner := streaming.WAL().Read(r.ctx, streaming.ReadOption{
		PChannel:       r.replicateInfo.GetSourceChannelName(),
-		DeliverPolicy:  deliverPolicy,
-		DeliverFilters: []options.DeliverFilter{},
+		DeliverPolicy:  options.DeliverPolicyStartFrom(cp.MessageID),
+		DeliverFilters: []options.DeliverFilter{options.DeliverFilterTimeTickGT(cp.TimeTick)},
		MessageHandler: ch,
	})
	defer scanner.Close()
@@ -132,7 +124,7 @@ func (r *channelReplicator) replicateLoop() error {
	rsc := r.createRscFunc(r.ctx, r.replicateInfo)
	defer rsc.Close()

-	logger.Info("start replicate channel loop", zap.Any("startFrom", startFrom))
+	logger.Info("start replicate channel loop", zap.Any("startFrom", cp))
	for {
		select {
@@ -142,7 +134,9 @@ func (r *channelReplicator) replicateLoop() error {
		case msg := <-ch:
			// TODO: Should be done at streamingnode.
			if msg.MessageType().IsSelfControlled() {
+				if msg.MessageType() != message.MessageTypeTimeTick {
					logger.Debug("skip self-controlled message", log.FieldMessage(msg))
+				}
				continue
			}
			err := rsc.Replicate(msg)
@@ -150,18 +144,11 @@ func (r *channelReplicator) replicateLoop() error {
				panic(fmt.Sprintf("replicate message failed due to unrecoverable error: %v", err))
			}
			logger.Debug("replicate message success", log.FieldMessage(msg))
-			if msg.MessageType() == message.MessageTypeAlterReplicateConfig {
-				roleChanged := r.handlePutReplicateConfigMessage(msg)
-				if roleChanged {
-					// Role changed, return and stop replicate.
-					return nil
-				}
-			}
		}
	}
}

-func (r *channelReplicator) getReplicateStartMessageID() (message.MessageID, error) {
+func (r *channelReplicator) getReplicateCheckpoint() (*utility.ReplicateCheckpoint, error) {
	logger := log.With(
		zap.String("sourceChannel", r.replicateInfo.GetSourceChannelName()),
		zap.String("targetChannel", r.replicateInfo.GetTargetChannelName()),
@@ -189,41 +176,20 @@ func (r *channelReplicator) getReplicateStartMessageID() (message.MessageID, err
		}
	}
	if checkpoint == nil || checkpoint.MessageId == nil {
-		logger.Info("channel not found in replicate info, will start from the beginning")
-		return nil, nil
+		initializedCheckpoint := utility.NewReplicateCheckpointFromProto(r.replicateInfo.InitializedCheckpoint)
+		logger.Info("channel not found in replicate info, will start from the beginning",
+			zap.Stringer("messageID", initializedCheckpoint.MessageID),
+			zap.Uint64("timeTick", initializedCheckpoint.TimeTick),
+		)
+		return initializedCheckpoint, nil
	}
-	startFrom := message.MustUnmarshalMessageID(checkpoint.GetMessageId())
+	cp := utility.NewReplicateCheckpointFromProto(checkpoint)
	logger.Info("replicate messages from position",
-		zap.Any("checkpoint", checkpoint),
-		zap.Any("startFromMessageID", startFrom),
+		zap.Stringer("messageID", cp.MessageID),
+		zap.Uint64("timeTick", cp.TimeTick),
	)
-	return startFrom, nil
+	return cp, nil
}

-func (r *channelReplicator) handlePutReplicateConfigMessage(msg message.ImmutableMessage) (roleChanged bool) {
-	logger := log.With(
-		zap.String("sourceChannel", r.replicateInfo.GetSourceChannelName()),
-		zap.String("targetChannel", r.replicateInfo.GetTargetChannelName()),
-	)
-	logger.Info("handle PutReplicateConfigMessage", log.FieldMessage(msg))
-	prcMsg := message.MustAsImmutableAlterReplicateConfigMessageV2(msg)
-	replicateConfig := prcMsg.Header().ReplicateConfiguration
-	currentClusterID := paramtable.Get().CommonCfg.ClusterPrefix.GetValue()
-	currentCluster := replicateutil.MustNewConfigHelper(currentClusterID, replicateConfig).GetCurrentCluster()
-	if currentCluster.Role() == replicateutil.RolePrimary {
-		logger.Info("primary cluster, skip handle PutReplicateConfigMessage")
-		return false
-	}
-	// Current cluster role changed, not primary cluster,
-	// we need to remove the replicate pchannel.
-	err := resource.Resource().ReplicationCatalog().RemoveReplicatePChannel(r.ctx,
-		r.replicateInfo.GetSourceChannelName(), r.replicateInfo.GetTargetChannelName())
-	if err != nil {
-		panic(fmt.Sprintf("failed to remove replicate pchannel: %v", err))
-	}
-	logger.Info("handle PutReplicateConfigMessage done, replicate pchannel removed")
-	return true
-}

func (r *channelReplicator) StopReplicate() {


@@ -18,8 +18,10 @@ package replicatemanager
import (
	"context"
+	"fmt"
	"strings"

+	"github.com/samber/lo"
	"go.uber.org/zap"

	"github.com/milvus-io/milvus/pkg/v2/log"
@@ -33,15 +35,21 @@ type replicateManager struct {
	// replicators is a map of replicate pchannel name to ChannelReplicator.
	replicators map[string]Replicator
+	replicatorPChannels map[string]*streamingpb.ReplicatePChannelMeta
}

func NewReplicateManager() *replicateManager {
	return &replicateManager{
		ctx:         context.Background(),
		replicators: make(map[string]Replicator),
+		replicatorPChannels: make(map[string]*streamingpb.ReplicatePChannelMeta),
	}
}

+func bindReplicatorKey(replicateInfo *streamingpb.ReplicatePChannelMeta) string {
+	return fmt.Sprintf("%s_%s", replicateInfo.GetSourceChannelName(), replicateInfo.GetTargetChannelName())
+}
+
func (r *replicateManager) CreateReplicator(replicateInfo *streamingpb.ReplicatePChannelMeta) {
	logger := log.With(
		zap.String("sourceChannel", replicateInfo.GetSourceChannelName()),
@@ -52,13 +60,36 @@ func (r *replicateManager) CreateReplicator(replicateInfo *streamingpb.Replicate
		// current cluster is not source cluster, skip create replicator
		return
	}
-	_, ok := r.replicators[replicateInfo.GetSourceChannelName()]
+	replicatorKey := bindReplicatorKey(replicateInfo)
+	_, ok := r.replicators[replicatorKey]
	if ok {
		logger.Debug("replicator already exists, skip create replicator")
		return
	}
	replicator := NewChannelReplicator(replicateInfo)
	replicator.StartReplicate()
-	r.replicators[replicateInfo.GetSourceChannelName()] = replicator
+	r.replicators[replicatorKey] = replicator
+	r.replicatorPChannels[replicatorKey] = replicateInfo
	logger.Info("created replicator for replicate pchannel")
}

+func (r *replicateManager) RemoveOutOfTargetReplicators(targetReplicatePChannels []*streamingpb.ReplicatePChannelMeta) {
+	targets := lo.KeyBy(targetReplicatePChannels, bindReplicatorKey)
+	for replicatorKey, replicator := range r.replicators {
+		if pchannelMeta, ok := targets[replicatorKey]; !ok {
+			replicator.StopReplicate()
+			delete(r.replicators, replicatorKey)
+			delete(r.replicatorPChannels, replicatorKey)
+			log.Info("removed replicator due to out of target",
+				zap.String("sourceChannel", pchannelMeta.GetSourceChannelName()),
+				zap.String("targetChannel", pchannelMeta.GetTargetChannelName()),
+			)
+		}
+	}
+}
+
+func (r *replicateManager) Close() {
+	for _, replicator := range r.replicators {
+		replicator.StopReplicate()
+	}
+}
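Why the manager now keys replicators by the source/target pair rather than by source channel alone: one source pchannel may replicate to several target clusters, and each pairing needs its own replicator. A small illustration using the bindReplicatorKey helper added above (channel names are made up):

	func demoKeys() {
		a := &streamingpb.ReplicatePChannelMeta{SourceChannelName: "ch-1", TargetChannelName: "by-dev-ch-1"}
		b := &streamingpb.ReplicatePChannelMeta{SourceChannelName: "ch-1", TargetChannelName: "by-rec-ch-1"}
		fmt.Println(bindReplicatorKey(a)) // "ch-1_by-dev-ch-1"
		fmt.Println(bindReplicatorKey(b)) // "ch-1_by-rec-ch-1": a second, distinct replicator entry
	}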


@@ -60,7 +60,7 @@ func TestReplicateManager_CreateReplicator(t *testing.T) {
	// Verify replicator was created
	assert.Equal(t, 1, len(manager.replicators))
-	replicator, exists := manager.replicators["test-source-channel-1"]
+	replicator, exists := manager.replicators["test-source-channel-1_test-target-channel-1"]
	assert.True(t, exists)
	assert.NotNil(t, replicator)
@@ -77,12 +77,12 @@ func TestReplicateManager_CreateReplicator(t *testing.T) {
	// Verify second replicator was created
	assert.Equal(t, 2, len(manager.replicators))
-	replicator2, exists := manager.replicators["test-source-channel-2"]
+	replicator2, exists := manager.replicators["test-source-channel-2_test-target-channel-2"]
	assert.True(t, exists)
	assert.NotNil(t, replicator2)

	// Verify first replicator still exists
-	replicator1, exists := manager.replicators["test-source-channel-1"]
+	replicator1, exists := manager.replicators["test-source-channel-1_test-target-channel-1"]
	assert.True(t, exists)
	assert.NotNil(t, replicator1)
}


@@ -17,10 +17,13 @@
package replicatestream

import (
+	"time"

	"github.com/milvus-io/milvus/pkg/v2/metrics"
	streamingpb "github.com/milvus-io/milvus/pkg/v2/proto/streamingpb"
	message "github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
	"github.com/milvus-io/milvus/pkg/v2/util/timerecord"
+	"github.com/milvus-io/milvus/pkg/v2/util/tsoutil"
	"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
@@ -83,6 +86,14 @@ func (m *replicateMetrics) OnConfirmed(msg message.ImmutableMessage) {
		m.replicateInfo.GetSourceChannelName(),
		m.replicateInfo.GetTargetChannelName(),
	).Observe(float64(replicateDuration.Milliseconds()))
+	now := time.Now()
+	confirmedTime := tsoutil.PhysicalTime(msg.TimeTick())
+	lag := now.Sub(confirmedTime)
+	metrics.CDCReplicateLag.WithLabelValues(
+		m.replicateInfo.GetSourceChannelName(),
+		m.replicateInfo.GetTargetChannelName(),
+	).Set(float64(lag.Milliseconds()))
}
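The new lag gauge converts the confirmed message's hybrid timestamp into wall-clock lag. A self-contained sketch of that conversion, assuming the usual Milvus TSO layout (wall-clock milliseconds in the high bits above 18 logical bits; physicalTime is a hypothetical stand-in for tsoutil.PhysicalTime):

	package main

	import (
		"fmt"
		"time"
	)

	const logicalBits = 18 // low bits carry a logical counter, high bits Unix milliseconds

	// physicalTime extracts the wall-clock component of a hybrid TSO timestamp.
	func physicalTime(ts uint64) time.Time {
		ms := int64(ts >> logicalBits)
		return time.Unix(0, ms*int64(time.Millisecond))
	}

	func main() {
		// Fabricate a time tick whose physical part is two seconds in the past.
		tick := uint64(time.Now().Add(-2*time.Second).UnixMilli()) << logicalBits
		lag := time.Since(physicalTime(tick)) // this is what feeds the CDCReplicateLag gauge
		fmt.Printf("replication lag: %dms\n", lag.Milliseconds())
	}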
func (m *replicateMetrics) OnConnect() {
@@ -93,29 +104,29 @@ func (m *replicateMetrics) OnConnect() {
}

func (m *replicateMetrics) OnDisconnect() {
-	clusterID := m.replicateInfo.GetTargetCluster().GetClusterId()
+	targetClusterID := m.replicateInfo.GetTargetCluster().GetClusterId()
	metrics.CDCStreamRPCConnections.WithLabelValues(
-		clusterID,
+		targetClusterID,
		metrics.CDCStatusConnected,
	).Dec()
	metrics.CDCStreamRPCConnections.WithLabelValues(
-		clusterID,
+		targetClusterID,
		metrics.CDCStatusDisconnected,
	).Inc()
}

func (m *replicateMetrics) OnReconnect() {
-	clusterID := m.replicateInfo.GetTargetCluster().GetClusterId()
+	targetClusterID := m.replicateInfo.GetTargetCluster().GetClusterId()
	metrics.CDCStreamRPCConnections.WithLabelValues(
-		clusterID,
+		targetClusterID,
		metrics.CDCStatusDisconnected,
	).Dec()
	metrics.CDCStreamRPCConnections.WithLabelValues(
-		clusterID,
+		targetClusterID,
		metrics.CDCStatusConnected,
	).Inc()
	metrics.CDCStreamRPCReconnectTimes.WithLabelValues(
-		clusterID,
+		targetClusterID,
	).Inc()
}


@@ -18,6 +18,7 @@ package replicatestream
import (
	"context"
+	"fmt"
	"sync"
	"time"
@@ -32,6 +33,7 @@ import (
	"github.com/milvus-io/milvus/pkg/v2/proto/streamingpb"
	"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
	"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
+	"github.com/milvus-io/milvus/pkg/v2/util/replicateutil"
)

const pendingMessageQueueLength = 128
@@ -86,13 +88,14 @@ func (r *replicateStreamClient) startInternal() {
	backoff.MaxElapsedTime = 0
	backoff.Reset()

-	disconnect := func(stopCh chan struct{}, err error) {
+	disconnect := func(stopCh chan struct{}, err error) (reconnect bool) {
		r.metrics.OnDisconnect()
		close(stopCh)
		r.client.CloseSend()
		r.wg.Wait()
		time.Sleep(backoff.NextBackOff())
		log.Warn("restart replicate stream client", zap.Error(err))
+		return err != nil
	}

	for {
@@ -131,9 +134,15 @@ func (r *replicateStreamClient) startInternal() {
			r.wg.Wait()
			return
		case err := <-sendErrCh:
-			disconnect(stopCh, err)
+			reconnect := disconnect(stopCh, err)
+			if !reconnect {
+				return
+			}
		case err := <-recvErrCh:
-			disconnect(stopCh, err)
+			reconnect := disconnect(stopCh, err)
+			if !reconnect {
+				return
+			}
		}
	}
}
@@ -280,6 +289,13 @@ func (r *replicateStreamClient) recvLoop(stopCh <-chan struct{}) error {
		if lastConfirmedMessageInfo != nil {
			messages := r.pendingMessages.CleanupConfirmedMessages(lastConfirmedMessageInfo.GetConfirmedTimeTick())
			for _, msg := range messages {
+				if msg.MessageType() == message.MessageTypeAlterReplicateConfig {
+					roleChanged := r.handleAlterReplicateConfigMessage(msg)
+					if roleChanged {
+						// Role changed, return and stop replicate.
+						return nil
+					}
+				}
				r.metrics.OnConfirmed(msg)
			}
		}
@@ -287,6 +303,32 @@
	}
}

+func (r *replicateStreamClient) handleAlterReplicateConfigMessage(msg message.ImmutableMessage) (roleChanged bool) {
+	logger := log.With(
+		zap.String("sourceChannel", r.replicateInfo.GetSourceChannelName()),
+		zap.String("targetChannel", r.replicateInfo.GetTargetChannelName()),
+	)
+	logger.Info("handle AlterReplicateConfigMessage", log.FieldMessage(msg))
+	prcMsg := message.MustAsImmutableAlterReplicateConfigMessageV2(msg)
+	replicateConfig := prcMsg.Header().ReplicateConfiguration
+	currentClusterID := paramtable.Get().CommonCfg.ClusterPrefix.GetValue()
+	currentCluster := replicateutil.MustNewConfigHelper(currentClusterID, replicateConfig).GetCurrentCluster()
+	_, err := currentCluster.GetTargetChannel(r.replicateInfo.GetSourceChannelName(),
+		r.replicateInfo.GetTargetCluster().GetClusterId())
+	if err != nil {
+		// Cannot find the target channel, it means that the `current->target` topology edge is removed,
+		// so we need to remove the replicate pchannel and stop replicate.
+		err := resource.Resource().ReplicationCatalog().RemoveReplicatePChannel(r.ctx, r.replicateInfo)
+		if err != nil {
+			panic(fmt.Sprintf("failed to remove replicate pchannel: %v", err))
+		}
+		logger.Info("handle AlterReplicateConfigMessage done, replicate pchannel removed")
+		return true
+	}
+	logger.Info("target channel found, skip handle AlterReplicateConfigMessage")
+	return false
+}
+
func (r *replicateStreamClient) Close() {
	r.cancel()
	r.wg.Wait()
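The check in handleAlterReplicateConfigMessage boils down to: after a configuration change, a replicator survives only if the current->target edge still exists in the new replication topology. A toy model of that decision (the map-based config is an illustration, not the actual ReplicateConfiguration proto):

	// edges maps a source cluster ID to the cluster IDs it replicates to.
	func edgeExists(edges map[string][]string, current, target string) bool {
		for _, t := range edges[current] {
			if t == target {
				return true
			}
		}
		return false
	}

	// If AlterReplicateConfig removes the "by-dev" -> "by-rec" edge,
	// edgeExists(edges, "by-dev", "by-rec") turns false; the replicator then
	// removes its pchannel entry and stops, mirroring the branch above.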


@@ -119,6 +119,16 @@ func (s replicateService) overwriteReplicateMessage(ctx context.Context, msg mes
			return nil, err
		}
	}
+	if funcutil.IsControlChannel(msg.VChannel()) {
+		assignments, err := s.streamingCoordClient.Assignment().GetLatestAssignments(ctx)
+		if err != nil {
+			return nil, err
+		}
+		if !strings.HasPrefix(msg.VChannel(), assignments.PChannelOfCChannel()) {
+			return nil, status.NewReplicateViolation("invalid control channel %s, expected pchannel %s", msg.VChannel(), assignments.PChannelOfCChannel())
+		}
+	}
	return msg, nil
}


@@ -212,7 +212,7 @@ type QueryCoordCatalog interface {
type ReplicationCatalog interface {
	// RemoveReplicatePChannel removes the replicate pchannel from metastore.
	// Remove the task of CDC replication task of current cluster, should be called when a CDC replication task is finished.
-	RemoveReplicatePChannel(ctx context.Context, sourceChannelName, targetChannelName string) error
+	RemoveReplicatePChannel(ctx context.Context, meta *streamingpb.ReplicatePChannelMeta) error
	// ListReplicatePChannels lists all replicate pchannels from metastore.
	// every ReplicatePChannelMeta is a task of CDC replication task of current cluster which is a source cluster in replication topology.
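With the new signature, callers pass the whole task meta rather than a pair of channel names. A brief usage sketch (the finishTask wrapper and its parameters are illustrative; field shapes match the catalog test further below):

	func finishTask(ctx context.Context, catalog metastore.ReplicationCatalog) error {
		meta := &streamingpb.ReplicatePChannelMeta{
			SourceChannelName: "source-channel-1",
			TargetChannelName: "target-channel-1",
			TargetCluster:     &commonpb.MilvusCluster{ClusterId: "target-cluster"},
		}
		// The catalog derives the metastore key from the target cluster ID
		// and the source channel name carried by the meta.
		return catalog.RemoveReplicatePChannel(ctx, meta)
	}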


@@ -216,8 +216,8 @@ func (c *catalog) GetReplicateConfiguration(ctx context.Context) (*streamingpb.R
	return config, nil
}

-func (c *catalog) RemoveReplicatePChannel(ctx context.Context, targetClusterID, sourceChannelName string) error {
-	key := buildReplicatePChannelPath(targetClusterID, sourceChannelName)
+func (c *catalog) RemoveReplicatePChannel(ctx context.Context, task *streamingpb.ReplicatePChannelMeta) error {
+	key := buildReplicatePChannelPath(task.GetTargetCluster().GetClusterId(), task.GetSourceChannelName())
	return c.metaKV.Remove(ctx, key)
}


@@ -242,7 +242,11 @@ func TestCatalog_ReplicationCatalog(t *testing.T) {
	assert.Equal(t, infos[1].GetTargetChannelName(), "target-channel-2")
	assert.Equal(t, infos[1].GetTargetCluster().GetClusterId(), "target-cluster")

-	err = catalog.RemoveReplicatePChannel(context.Background(), "target-cluster", "source-channel-1")
+	err = catalog.RemoveReplicatePChannel(context.Background(), &streamingpb.ReplicatePChannelMeta{
+		SourceChannelName: "source-channel-1",
+		TargetChannelName: "target-channel-1",
+		TargetCluster:     &commonpb.MilvusCluster{ClusterId: "target-cluster"},
+	})
	assert.NoError(t, err)

	infos, err = catalog.ListReplicatePChannels(context.Background())


@@ -81,17 +81,17 @@ func (_c *MockReplicationCatalog_ListReplicatePChannels_Call) RunAndReturn(run f
	return _c
}

-// RemoveReplicatePChannel provides a mock function with given fields: ctx, sourceChannelName, targetChannelName
-func (_m *MockReplicationCatalog) RemoveReplicatePChannel(ctx context.Context, sourceChannelName string, targetChannelName string) error {
-	ret := _m.Called(ctx, sourceChannelName, targetChannelName)
+// RemoveReplicatePChannel provides a mock function with given fields: ctx, meta
+func (_m *MockReplicationCatalog) RemoveReplicatePChannel(ctx context.Context, meta *streamingpb.ReplicatePChannelMeta) error {
+	ret := _m.Called(ctx, meta)

	if len(ret) == 0 {
		panic("no return value specified for RemoveReplicatePChannel")
	}

	var r0 error
-	if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok {
-		r0 = rf(ctx, sourceChannelName, targetChannelName)
+	if rf, ok := ret.Get(0).(func(context.Context, *streamingpb.ReplicatePChannelMeta) error); ok {
+		r0 = rf(ctx, meta)
	} else {
		r0 = ret.Error(0)
	}
@@ -106,15 +106,14 @@ type MockReplicationCatalog_RemoveReplicatePChannel_Call struct {
// RemoveReplicatePChannel is a helper method to define mock.On call
//   - ctx context.Context
-//   - sourceChannelName string
-//   - targetChannelName string
-func (_e *MockReplicationCatalog_Expecter) RemoveReplicatePChannel(ctx interface{}, sourceChannelName interface{}, targetChannelName interface{}) *MockReplicationCatalog_RemoveReplicatePChannel_Call {
-	return &MockReplicationCatalog_RemoveReplicatePChannel_Call{Call: _e.mock.On("RemoveReplicatePChannel", ctx, sourceChannelName, targetChannelName)}
+//   - meta *streamingpb.ReplicatePChannelMeta
+func (_e *MockReplicationCatalog_Expecter) RemoveReplicatePChannel(ctx interface{}, meta interface{}) *MockReplicationCatalog_RemoveReplicatePChannel_Call {
+	return &MockReplicationCatalog_RemoveReplicatePChannel_Call{Call: _e.mock.On("RemoveReplicatePChannel", ctx, meta)}
}

-func (_c *MockReplicationCatalog_RemoveReplicatePChannel_Call) Run(run func(ctx context.Context, sourceChannelName string, targetChannelName string)) *MockReplicationCatalog_RemoveReplicatePChannel_Call {
+func (_c *MockReplicationCatalog_RemoveReplicatePChannel_Call) Run(run func(ctx context.Context, meta *streamingpb.ReplicatePChannelMeta)) *MockReplicationCatalog_RemoveReplicatePChannel_Call {
	_c.Call.Run(func(args mock.Arguments) {
-		run(args[0].(context.Context), args[1].(string), args[2].(string))
+		run(args[0].(context.Context), args[1].(*streamingpb.ReplicatePChannelMeta))
	})
	return _c
}
@@ -124,7 +123,7 @@ func (_c *MockReplicationCatalog_RemoveReplicatePChannel_Call) Return(_a0 error)
	return _c
}

-func (_c *MockReplicationCatalog_RemoveReplicatePChannel_Call) RunAndReturn(run func(context.Context, string, string) error) *MockReplicationCatalog_RemoveReplicatePChannel_Call {
+func (_c *MockReplicationCatalog_RemoveReplicatePChannel_Call) RunAndReturn(run func(context.Context, *streamingpb.ReplicatePChannelMeta) error) *MockReplicationCatalog_RemoveReplicatePChannel_Call {
	_c.Call.Return(run)
	return _c
}


@@ -371,17 +371,17 @@ func (_c *MockStreamingCoordCataLog_ListReplicatePChannels_Call) RunAndReturn(ru
	return _c
}

-// RemoveReplicatePChannel provides a mock function with given fields: ctx, sourceChannelName, targetChannelName
-func (_m *MockStreamingCoordCataLog) RemoveReplicatePChannel(ctx context.Context, sourceChannelName string, targetChannelName string) error {
-	ret := _m.Called(ctx, sourceChannelName, targetChannelName)
+// RemoveReplicatePChannel provides a mock function with given fields: ctx, meta
+func (_m *MockStreamingCoordCataLog) RemoveReplicatePChannel(ctx context.Context, meta *streamingpb.ReplicatePChannelMeta) error {
+	ret := _m.Called(ctx, meta)

	if len(ret) == 0 {
		panic("no return value specified for RemoveReplicatePChannel")
	}

	var r0 error
-	if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok {
-		r0 = rf(ctx, sourceChannelName, targetChannelName)
+	if rf, ok := ret.Get(0).(func(context.Context, *streamingpb.ReplicatePChannelMeta) error); ok {
+		r0 = rf(ctx, meta)
	} else {
		r0 = ret.Error(0)
	}
@@ -396,15 +396,14 @@ type MockStreamingCoordCataLog_RemoveReplicatePChannel_Call struct {
// RemoveReplicatePChannel is a helper method to define mock.On call
//   - ctx context.Context
-//   - sourceChannelName string
-//   - targetChannelName string
-func (_e *MockStreamingCoordCataLog_Expecter) RemoveReplicatePChannel(ctx interface{}, sourceChannelName interface{}, targetChannelName interface{}) *MockStreamingCoordCataLog_RemoveReplicatePChannel_Call {
-	return &MockStreamingCoordCataLog_RemoveReplicatePChannel_Call{Call: _e.mock.On("RemoveReplicatePChannel", ctx, sourceChannelName, targetChannelName)}
+//   - meta *streamingpb.ReplicatePChannelMeta
+func (_e *MockStreamingCoordCataLog_Expecter) RemoveReplicatePChannel(ctx interface{}, meta interface{}) *MockStreamingCoordCataLog_RemoveReplicatePChannel_Call {
+	return &MockStreamingCoordCataLog_RemoveReplicatePChannel_Call{Call: _e.mock.On("RemoveReplicatePChannel", ctx, meta)}
}

-func (_c *MockStreamingCoordCataLog_RemoveReplicatePChannel_Call) Run(run func(ctx context.Context, sourceChannelName string, targetChannelName string)) *MockStreamingCoordCataLog_RemoveReplicatePChannel_Call {
+func (_c *MockStreamingCoordCataLog_RemoveReplicatePChannel_Call) Run(run func(ctx context.Context, meta *streamingpb.ReplicatePChannelMeta)) *MockStreamingCoordCataLog_RemoveReplicatePChannel_Call {
	_c.Call.Run(func(args mock.Arguments) {
-		run(args[0].(context.Context), args[1].(string), args[2].(string))
+		run(args[0].(context.Context), args[1].(*streamingpb.ReplicatePChannelMeta))
	})
	return _c
}
@@ -414,7 +413,7 @@ func (_c *MockStreamingCoordCataLog_RemoveReplicatePChannel_Call) Return(_a0 err
	return _c
}

-func (_c *MockStreamingCoordCataLog_RemoveReplicatePChannel_Call) RunAndReturn(run func(context.Context, string, string) error) *MockStreamingCoordCataLog_RemoveReplicatePChannel_Call {
+func (_c *MockStreamingCoordCataLog_RemoveReplicatePChannel_Call) RunAndReturn(run func(context.Context, *streamingpb.ReplicatePChannelMeta) error) *MockStreamingCoordCataLog_RemoveReplicatePChannel_Call {
	_c.Call.Return(run)
	return _c
}


@@ -301,6 +301,63 @@ func (_c *MockWAL_GetLatestMVCCTimestamp_Call) RunAndReturn(run func(context.Con
	return _c
}

+// GetReplicateCheckpoint provides a mock function with no fields
+func (_m *MockWAL) GetReplicateCheckpoint() (*wal.ReplicateCheckpoint, error) {
+	ret := _m.Called()
+
+	if len(ret) == 0 {
+		panic("no return value specified for GetReplicateCheckpoint")
+	}
+
+	var r0 *wal.ReplicateCheckpoint
+	var r1 error
+	if rf, ok := ret.Get(0).(func() (*wal.ReplicateCheckpoint, error)); ok {
+		return rf()
+	}
+	if rf, ok := ret.Get(0).(func() *wal.ReplicateCheckpoint); ok {
+		r0 = rf()
+	} else {
+		if ret.Get(0) != nil {
+			r0 = ret.Get(0).(*wal.ReplicateCheckpoint)
+		}
+	}
+	if rf, ok := ret.Get(1).(func() error); ok {
+		r1 = rf()
+	} else {
+		r1 = ret.Error(1)
+	}
+	return r0, r1
+}
+
+// MockWAL_GetReplicateCheckpoint_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetReplicateCheckpoint'
+type MockWAL_GetReplicateCheckpoint_Call struct {
+	*mock.Call
+}
+
+// GetReplicateCheckpoint is a helper method to define mock.On call
+func (_e *MockWAL_Expecter) GetReplicateCheckpoint() *MockWAL_GetReplicateCheckpoint_Call {
+	return &MockWAL_GetReplicateCheckpoint_Call{Call: _e.mock.On("GetReplicateCheckpoint")}
+}
+
+func (_c *MockWAL_GetReplicateCheckpoint_Call) Run(run func()) *MockWAL_GetReplicateCheckpoint_Call {
+	_c.Call.Run(func(args mock.Arguments) {
+		run()
+	})
+	return _c
+}
+
+func (_c *MockWAL_GetReplicateCheckpoint_Call) Return(_a0 *wal.ReplicateCheckpoint, _a1 error) *MockWAL_GetReplicateCheckpoint_Call {
+	_c.Call.Return(_a0, _a1)
+	return _c
+}
+
+func (_c *MockWAL_GetReplicateCheckpoint_Call) RunAndReturn(run func() (*wal.ReplicateCheckpoint, error)) *MockWAL_GetReplicateCheckpoint_Call {
+	_c.Call.Return(run)
+	return _c
+}
+
// IsAvailable provides a mock function with no fields
func (_m *MockWAL) IsAvailable() bool {
	ret := _m.Called()


@@ -0,0 +1,65 @@
// Code generated by mockery v2.53.3. DO NOT EDIT.
package mock_replicates
import mock "github.com/stretchr/testify/mock"
// MockReplicateAcker is an autogenerated mock type for the ReplicateAcker type
type MockReplicateAcker struct {
mock.Mock
}
type MockReplicateAcker_Expecter struct {
mock *mock.Mock
}
func (_m *MockReplicateAcker) EXPECT() *MockReplicateAcker_Expecter {
return &MockReplicateAcker_Expecter{mock: &_m.Mock}
}
// Ack provides a mock function with given fields: err
func (_m *MockReplicateAcker) Ack(err error) {
_m.Called(err)
}
// MockReplicateAcker_Ack_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Ack'
type MockReplicateAcker_Ack_Call struct {
*mock.Call
}
// Ack is a helper method to define mock.On call
// - err error
func (_e *MockReplicateAcker_Expecter) Ack(err interface{}) *MockReplicateAcker_Ack_Call {
return &MockReplicateAcker_Ack_Call{Call: _e.mock.On("Ack", err)}
}
func (_c *MockReplicateAcker_Ack_Call) Run(run func(err error)) *MockReplicateAcker_Ack_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(error))
})
return _c
}
func (_c *MockReplicateAcker_Ack_Call) Return() *MockReplicateAcker_Ack_Call {
_c.Call.Return()
return _c
}
func (_c *MockReplicateAcker_Ack_Call) RunAndReturn(run func(error)) *MockReplicateAcker_Ack_Call {
_c.Run(run)
return _c
}
// NewMockReplicateAcker creates a new instance of MockReplicateAcker. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewMockReplicateAcker(t interface {
mock.TestingT
Cleanup(func())
}) *MockReplicateAcker {
mock := &MockReplicateAcker{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}


@@ -0,0 +1,252 @@
// Code generated by mockery v2.53.3. DO NOT EDIT.
package mock_replicates
import (
context "context"
message "github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
mock "github.com/stretchr/testify/mock"
replicates "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/replicate/replicates"
replicateutil "github.com/milvus-io/milvus/pkg/v2/util/replicateutil"
utility "github.com/milvus-io/milvus/internal/streamingnode/server/wal/utility"
)
// MockReplicatesManager is an autogenerated mock type for the ReplicatesManager type
type MockReplicatesManager struct {
mock.Mock
}
type MockReplicatesManager_Expecter struct {
mock *mock.Mock
}
func (_m *MockReplicatesManager) EXPECT() *MockReplicatesManager_Expecter {
return &MockReplicatesManager_Expecter{mock: &_m.Mock}
}
// BeginReplicateMessage provides a mock function with given fields: ctx, msg
func (_m *MockReplicatesManager) BeginReplicateMessage(ctx context.Context, msg message.MutableMessage) (replicates.ReplicateAcker, error) {
ret := _m.Called(ctx, msg)
if len(ret) == 0 {
panic("no return value specified for BeginReplicateMessage")
}
var r0 replicates.ReplicateAcker
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, message.MutableMessage) (replicates.ReplicateAcker, error)); ok {
return rf(ctx, msg)
}
if rf, ok := ret.Get(0).(func(context.Context, message.MutableMessage) replicates.ReplicateAcker); ok {
r0 = rf(ctx, msg)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(replicates.ReplicateAcker)
}
}
if rf, ok := ret.Get(1).(func(context.Context, message.MutableMessage) error); ok {
r1 = rf(ctx, msg)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockReplicatesManager_BeginReplicateMessage_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'BeginReplicateMessage'
type MockReplicatesManager_BeginReplicateMessage_Call struct {
*mock.Call
}
// BeginReplicateMessage is a helper method to define mock.On call
// - ctx context.Context
// - msg message.MutableMessage
func (_e *MockReplicatesManager_Expecter) BeginReplicateMessage(ctx interface{}, msg interface{}) *MockReplicatesManager_BeginReplicateMessage_Call {
return &MockReplicatesManager_BeginReplicateMessage_Call{Call: _e.mock.On("BeginReplicateMessage", ctx, msg)}
}
func (_c *MockReplicatesManager_BeginReplicateMessage_Call) Run(run func(ctx context.Context, msg message.MutableMessage)) *MockReplicatesManager_BeginReplicateMessage_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(message.MutableMessage))
})
return _c
}
func (_c *MockReplicatesManager_BeginReplicateMessage_Call) Return(_a0 replicates.ReplicateAcker, _a1 error) *MockReplicatesManager_BeginReplicateMessage_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockReplicatesManager_BeginReplicateMessage_Call) RunAndReturn(run func(context.Context, message.MutableMessage) (replicates.ReplicateAcker, error)) *MockReplicatesManager_BeginReplicateMessage_Call {
_c.Call.Return(run)
return _c
}
// GetReplicateCheckpoint provides a mock function with no fields
func (_m *MockReplicatesManager) GetReplicateCheckpoint() (*utility.ReplicateCheckpoint, error) {
ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for GetReplicateCheckpoint")
}
var r0 *utility.ReplicateCheckpoint
var r1 error
if rf, ok := ret.Get(0).(func() (*utility.ReplicateCheckpoint, error)); ok {
return rf()
}
if rf, ok := ret.Get(0).(func() *utility.ReplicateCheckpoint); ok {
r0 = rf()
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*utility.ReplicateCheckpoint)
}
}
if rf, ok := ret.Get(1).(func() error); ok {
r1 = rf()
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockReplicatesManager_GetReplicateCheckpoint_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetReplicateCheckpoint'
type MockReplicatesManager_GetReplicateCheckpoint_Call struct {
*mock.Call
}
// GetReplicateCheckpoint is a helper method to define mock.On call
func (_e *MockReplicatesManager_Expecter) GetReplicateCheckpoint() *MockReplicatesManager_GetReplicateCheckpoint_Call {
return &MockReplicatesManager_GetReplicateCheckpoint_Call{Call: _e.mock.On("GetReplicateCheckpoint")}
}
func (_c *MockReplicatesManager_GetReplicateCheckpoint_Call) Run(run func()) *MockReplicatesManager_GetReplicateCheckpoint_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockReplicatesManager_GetReplicateCheckpoint_Call) Return(_a0 *utility.ReplicateCheckpoint, _a1 error) *MockReplicatesManager_GetReplicateCheckpoint_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockReplicatesManager_GetReplicateCheckpoint_Call) RunAndReturn(run func() (*utility.ReplicateCheckpoint, error)) *MockReplicatesManager_GetReplicateCheckpoint_Call {
_c.Call.Return(run)
return _c
}
// Role provides a mock function with no fields
func (_m *MockReplicatesManager) Role() replicateutil.Role {
ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for Role")
}
var r0 replicateutil.Role
if rf, ok := ret.Get(0).(func() replicateutil.Role); ok {
r0 = rf()
} else {
r0 = ret.Get(0).(replicateutil.Role)
}
return r0
}
// MockReplicatesManager_Role_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Role'
type MockReplicatesManager_Role_Call struct {
*mock.Call
}
// Role is a helper method to define mock.On call
func (_e *MockReplicatesManager_Expecter) Role() *MockReplicatesManager_Role_Call {
return &MockReplicatesManager_Role_Call{Call: _e.mock.On("Role")}
}
func (_c *MockReplicatesManager_Role_Call) Run(run func()) *MockReplicatesManager_Role_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockReplicatesManager_Role_Call) Return(_a0 replicateutil.Role) *MockReplicatesManager_Role_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockReplicatesManager_Role_Call) RunAndReturn(run func() replicateutil.Role) *MockReplicatesManager_Role_Call {
_c.Call.Return(run)
return _c
}
// SwitchReplicateMode provides a mock function with given fields: ctx, msg
func (_m *MockReplicatesManager) SwitchReplicateMode(ctx context.Context, msg message.MutableAlterReplicateConfigMessageV2) error {
ret := _m.Called(ctx, msg)
if len(ret) == 0 {
panic("no return value specified for SwitchReplicateMode")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, message.MutableAlterReplicateConfigMessageV2) error); ok {
r0 = rf(ctx, msg)
} else {
r0 = ret.Error(0)
}
return r0
}
// MockReplicatesManager_SwitchReplicateMode_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SwitchReplicateMode'
type MockReplicatesManager_SwitchReplicateMode_Call struct {
*mock.Call
}
// SwitchReplicateMode is a helper method to define mock.On call
// - ctx context.Context
// - msg message.MutableAlterReplicateConfigMessageV2
func (_e *MockReplicatesManager_Expecter) SwitchReplicateMode(ctx interface{}, msg interface{}) *MockReplicatesManager_SwitchReplicateMode_Call {
return &MockReplicatesManager_SwitchReplicateMode_Call{Call: _e.mock.On("SwitchReplicateMode", ctx, msg)}
}
func (_c *MockReplicatesManager_SwitchReplicateMode_Call) Run(run func(ctx context.Context, msg message.MutableAlterReplicateConfigMessageV2)) *MockReplicatesManager_SwitchReplicateMode_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(message.MutableAlterReplicateConfigMessageV2))
})
return _c
}
func (_c *MockReplicatesManager_SwitchReplicateMode_Call) Return(_a0 error) *MockReplicatesManager_SwitchReplicateMode_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockReplicatesManager_SwitchReplicateMode_Call) RunAndReturn(run func(context.Context, message.MutableAlterReplicateConfigMessageV2) error) *MockReplicatesManager_SwitchReplicateMode_Call {
_c.Call.Return(run)
return _c
}
// NewMockReplicatesManager creates a new instance of MockReplicatesManager. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewMockReplicatesManager(t interface {
mock.TestingT
Cleanup(func())
},
) *MockReplicatesManager {
mock := &MockReplicatesManager{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}
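A minimal sketch of wiring this generated mock into a test via the Expecter API above (the checkpoint values and cluster/pchannel names are made up for illustration):

	manager := NewMockReplicatesManager(t)
	manager.EXPECT().GetReplicateCheckpoint().Return(&utility.ReplicateCheckpoint{
		ClusterID: "primary",                 // hypothetical source cluster id
		PChannel:  "primary-rootcoord-dml_0", // hypothetical source pchannel
		TimeTick:  42,
	}, nil)
	cp, err := manager.GetReplicateCheckpoint()
	// cp and err carry the stubbed values; the cleanup hook registered in
	// NewMockReplicatesManager asserts that every expectation was exercised.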

View File

@@ -34,7 +34,7 @@ func CreateReplicateServer(streamServer milvuspb.MilvusService_CreateReplicateSt
 type ReplicateStreamServer struct {
 	clusterID    string
 	streamServer milvuspb.MilvusService_CreateReplicateStreamServer
-	replicateRespCh chan *milvuspb.ReplicateResponse // All processing messages result should sent from theses channel.
+	replicateRespCh chan *milvuspb.ReplicateResponse
 	wg           sync.WaitGroup
 }
@@ -111,7 +111,6 @@ func (p *ReplicateStreamServer) recvLoop() (err error) {
 // handleReplicateMessage handles the replicate message request.
 func (p *ReplicateStreamServer) handleReplicateMessage(req *milvuspb.ReplicateRequest_ReplicateMessage) error {
-	// TODO: sheep, update metrics.
 	p.wg.Add(1)
 	defer p.wg.Done()
 	reqMsg := req.ReplicateMessage.GetMessage()

View File

@@ -32,7 +32,7 @@ func NewAssignmentService(
 		listenerTotal: metrics.StreamingCoordAssignmentListenerTotal.WithLabelValues(paramtable.GetStringNodeID()),
 	}
 	// TODO: after recovering from wal, add it to here.
-	// registry.RegisterPutReplicateConfigV2AckCallback(assignmentService.putReplicateConfiguration)
+	// registry.RegisterAlterReplicateConfigV2AckCallback(assignmentService.AlterReplicateConfiguration)
 	return assignmentService
 }
@@ -83,7 +83,7 @@ func (s *assignmentServiceImpl) UpdateReplicateConfiguration(ctx context.Context
 	}
 	// TODO: After recovering from wal, remove the operation here.
-	if err := s.putReplicateConfiguration(ctx, mockMessages...); err != nil {
+	if err := s.AlterReplicateConfiguration(ctx, mockMessages...); err != nil {
 		return nil, err
 	}
 	return &streamingpb.UpdateReplicateConfigurationResponse{}, nil
@@ -130,9 +130,9 @@ func (s *assignmentServiceImpl) validateReplicateConfiguration(ctx context.Conte
 	return b, nil
 }
-// putReplicateConfiguration puts the replicate configuration into the balancer.
+// AlterReplicateConfiguration puts the replicate configuration into the balancer.
 // It's a callback function of the broadcast service.
-func (s *assignmentServiceImpl) putReplicateConfiguration(ctx context.Context, msgs ...message.ImmutableAlterReplicateConfigMessageV2) error {
+func (s *assignmentServiceImpl) AlterReplicateConfiguration(ctx context.Context, msgs ...message.ImmutableAlterReplicateConfigMessageV2) error {
 	balancer, err := s.balancer.GetWithContext(ctx)
 	if err != nil {
 		return err

View File

@@ -13,6 +13,7 @@ import (
 	"github.com/milvus-io/milvus/internal/streamingnode/client/handler/producer"
 	"github.com/milvus-io/milvus/internal/streamingnode/client/handler/registry"
 	"github.com/milvus-io/milvus/internal/streamingnode/server/wal"
+	"github.com/milvus-io/milvus/internal/streamingnode/server/wal/utility"
 	"github.com/milvus-io/milvus/internal/util/streamingutil/service/balancer/picker"
 	"github.com/milvus-io/milvus/internal/util/streamingutil/service/lazygrpc"
 	"github.com/milvus-io/milvus/internal/util/streamingutil/service/resolver"
@@ -65,16 +66,41 @@ func (hc *handlerClientImpl) GetLatestMVCCTimestampIfLocal(ctx context.Context,
 	return w.GetLatestMVCCTimestamp(ctx, vchannel)
 }
-// GetReplicateCheckpoint returns the WAL checkpoint that will be used to create scanner.
-func (hc *handlerClientImpl) GetReplicateCheckpoint(ctx context.Context, channelName string) (*wal.ReplicateCheckpoint, error) {
+// GetReplicateCheckpoint gets the replicate checkpoint of the wal.
+func (hc *handlerClientImpl) GetReplicateCheckpoint(ctx context.Context, pchannel string) (*wal.ReplicateCheckpoint, error) {
 	if !hc.lifetime.Add(typeutil.LifetimeStateWorking) {
 		return nil, ErrClientClosed
 	}
 	defer hc.lifetime.Done()
-	return nil, nil
-	// TODO: sheep, implement it.
+	logger := log.With(zap.String("pchannel", pchannel), zap.String("handler", "replicate checkpoint"))
+	cp, err := hc.createHandlerAfterStreamingNodeReady(ctx, logger, pchannel, func(ctx context.Context, assign *types.PChannelInfoAssigned) (any, error) {
+		if assign.Channel.AccessMode != types.AccessModeRW {
+			return nil, errors.New("replicate checkpoint can only be read for RW channel")
+		}
+		localWAL, err := registry.GetLocalAvailableWAL(assign.Channel)
+		if err == nil {
+			return localWAL.GetReplicateCheckpoint()
+		}
+		if !shouldUseRemoteWAL(err) {
+			return nil, err
+		}
+		handlerService, err := hc.service.GetService(ctx)
+		if err != nil {
+			return nil, err
+		}
+		resp, err := handlerService.GetReplicateCheckpoint(ctx, &streamingpb.GetReplicateCheckpointRequest{
+			Pchannel: types.NewProtoFromPChannelInfo(assign.Channel),
+		})
+		if err != nil {
+			return nil, err
+		}
+		return utility.NewReplicateCheckpointFromProto(resp.Checkpoint), nil
+	})
+	if err != nil {
+		return nil, err
+	}
+	return cp.(*wal.ReplicateCheckpoint), nil
 }
 // GetWALMetricsIfLocal gets the metrics of the local wal.

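Note the lookup order in GetReplicateCheckpoint above: the client first tries an in-process wal via registry.GetLocalAvailableWAL and falls back to the remote handler service only when the wal is not local. A minimal caller-side sketch (hc is an already-constructed handler client; the pchannel name is hypothetical):

	cp, err := hc.GetReplicateCheckpoint(ctx, "by-dev-rootcoord-dml_0")
	if err != nil {
		return err
	}
	// cp.MessageID and cp.TimeTick mark where CDC replication should resume on this pchannel.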
View File

@@ -8,6 +8,7 @@ import (
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/mock"
+	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
 	"github.com/milvus-io/milvus/internal/mocks/streamingnode/client/handler/mock_assignment"
 	"github.com/milvus-io/milvus/internal/mocks/streamingnode/client/handler/mock_consumer"
 	"github.com/milvus-io/milvus/internal/mocks/streamingnode/client/handler/mock_producer"
@@ -34,6 +35,14 @@ func TestHandlerClient(t *testing.T) {
 	service := mock_lazygrpc.NewMockService[streamingpb.StreamingNodeHandlerServiceClient](t)
 	handlerServiceClient := mock_streamingpb.NewMockStreamingNodeHandlerServiceClient(t)
+	handlerServiceClient.EXPECT().GetReplicateCheckpoint(mock.Anything, mock.Anything).Return(&streamingpb.GetReplicateCheckpointResponse{
+		Checkpoint: &commonpb.ReplicateCheckpoint{
+			ClusterId: "pchannel",
+			Pchannel:  "pchannel",
+			MessageId: nil,
+			TimeTick:  0,
+		},
+	}, nil)
 	service.EXPECT().GetService(mock.Anything).Return(handlerServiceClient, nil)
 	rb := mock_resolver.NewMockBuilder(t)
 	rb.EXPECT().Close().Run(func() {})
@@ -91,6 +100,10 @@ func TestHandlerClient(t *testing.T) {
 	producer2.Close()
 	producer3.Close()
+	rcp, err := handler.GetReplicateCheckpoint(ctx, "pchannel")
+	assert.NoError(t, err)
+	assert.NotNil(t, rcp)
 	handler.GetLatestMVCCTimestampIfLocal(ctx, "pchannel")
 	producer4, err := handler.CreateProducer(ctx, &ProducerOptions{PChannel: "pchannel"})
 	assert.NoError(t, err)

View File

@@ -13,6 +13,7 @@ import (
 	"github.com/milvus-io/milvus/internal/streamingnode/server/resource"
 	"github.com/milvus-io/milvus/internal/streamingnode/server/wal"
 	"github.com/milvus-io/milvus/internal/streamingnode/server/wal/recovery"
+	"github.com/milvus-io/milvus/internal/streamingnode/server/wal/utility"
 	"github.com/milvus-io/milvus/pkg/v2/log"
 	"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
 	"github.com/milvus-io/milvus/pkg/v2/streaming/util/message/adaptor"
@@ -160,7 +161,7 @@ func (impl *WALFlusherImpl) buildFlusherComponents(ctx context.Context, l wal.WA
 		impl.RecoveryStorage.UpdateFlusherCheckpoint(mp.ChannelName, &recovery.WALCheckpoint{
 			MessageID: messageID,
 			TimeTick:  mp.Timestamp,
-			Magic:     recovery.RecoveryMagicStreamingInitialized,
+			Magic:     utility.RecoveryMagicStreamingInitialized,
 		})
 	})
 	go cpUpdater.Start()

View File

@@ -7,6 +7,7 @@ import (
 	"github.com/milvus-io/milvus/internal/streamingnode/server/service/handler/producer"
 	"github.com/milvus-io/milvus/internal/streamingnode/server/walmanager"
 	"github.com/milvus-io/milvus/pkg/v2/proto/streamingpb"
+	"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
 )
 var _ HandlerService = (*handlerServiceImpl)(nil)
@@ -27,12 +28,22 @@ type HandlerService = streamingpb.StreamingNodeHandlerServiceServer
 // 2. wait wal handling result and transform it into grpc response (convert error into grpc error)
 // 3. send response to client.
 type handlerServiceImpl struct {
+	streamingpb.UnimplementedStreamingNodeHandlerServiceServer
 	walManager walmanager.Manager
 }
-// GetReplicateCheckpoint returns the WAL checkpoint that will be used to create scanner
+// GetReplicateCheckpoint returns the replicate checkpoint of the wal.
 func (hs *handlerServiceImpl) GetReplicateCheckpoint(ctx context.Context, req *streamingpb.GetReplicateCheckpointRequest) (*streamingpb.GetReplicateCheckpointResponse, error) {
-	panic("not implemented") // TODO: sheep, implement it.
+	wal, err := hs.walManager.GetAvailableWAL(types.NewPChannelInfoFromProto(req.GetPchannel()))
+	if err != nil {
+		return nil, err
+	}
+	cp, err := wal.GetReplicateCheckpoint()
+	if err != nil {
+		return nil, err
+	}
+	return &streamingpb.GetReplicateCheckpointResponse{Checkpoint: cp.IntoProto()}, nil
 }
 // Produce creates a new producer for the channel on this log node.

View File

@@ -10,6 +10,7 @@ import (
 	"github.com/milvus-io/milvus/internal/streamingnode/server/resource"
 	"github.com/milvus-io/milvus/internal/streamingnode/server/wal"
 	"github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors"
+	"github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/replicate/replicates"
 	"github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/shard/shards"
 	"github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/txn"
 	"github.com/milvus-io/milvus/internal/streamingnode/server/wal/recovery"
@@ -17,6 +18,7 @@ import (
 	"github.com/milvus-io/milvus/pkg/v2/log"
 	"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
 	"github.com/milvus-io/milvus/pkg/v2/streaming/walimpls"
+	"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
 	"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
 )
@@ -107,6 +109,15 @@ func (o *openerAdaptorImpl) openRWWAL(ctx context.Context, l walimpls.WALImpls,
 		InitialRecoverSnapshot: snapshot,
 		TxnManager:             param.TxnManager,
 	})
+	if param.ReplicateManager, err = replicates.RecoverReplicateManager(
+		&replicates.ReplicateManagerRecoverParam{
+			ChannelInfo:            param.ChannelInfo,
+			CurrentClusterID:       paramtable.Get().CommonCfg.ClusterPrefix.GetValue(),
+			InitialRecoverSnapshot: snapshot,
+		},
+	); err != nil {
+		return nil, err
+	}
 	// start the flusher to flush and generate recovery info.
 	var flusher *flusherimpl.WALFlusherImpl

View File

@@ -50,6 +50,10 @@ func (w *roWALAdaptorImpl) GetLatestMVCCTimestamp(ctx context.Context, vchannel
 	panic("we cannot acquire lastest mvcc timestamp from a read only wal")
 }
+func (w *roWALAdaptorImpl) GetReplicateCheckpoint() (*wal.ReplicateCheckpoint, error) {
+	panic("we cannot get replicate checkpoint from a read only wal")
+}
 // Append writes a record to the log.
 func (w *roWALAdaptorImpl) Append(ctx context.Context, msg message.MutableMessage) (*wal.AppendResult, error) {
 	panic("we cannot append message into a read only wal")

View File

@@ -121,6 +121,16 @@ func (w *walAdaptorImpl) GetLatestMVCCTimestamp(ctx context.Context, vchannel st
 	return currentMVCC.Timetick, nil
 }
+// GetReplicateCheckpoint returns the replicate checkpoint of the wal.
+func (w *walAdaptorImpl) GetReplicateCheckpoint() (*utility.ReplicateCheckpoint, error) {
+	if !w.lifetime.Add(typeutil.LifetimeStateWorking) {
+		return nil, status.NewOnShutdownError("wal is on shutdown")
+	}
+	defer w.lifetime.Done()
+	return w.param.ReplicateManager.GetReplicateCheckpoint()
+}
 // Append writes a record to the log.
 func (w *walAdaptorImpl) Append(ctx context.Context, msg message.MutableMessage) (*wal.AppendResult, error) {
 	if !w.lifetime.Add(typeutil.LifetimeStateWorking) {

View File

@@ -22,6 +22,7 @@ import (
 	"github.com/milvus-io/milvus/internal/streamingnode/server/wal"
 	"github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/lock"
 	"github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/redo"
+	"github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/replicate"
 	"github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/shard"
 	"github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/timetick"
 	"github.com/milvus-io/milvus/internal/streamingnode/server/wal/registry"
@@ -57,6 +58,7 @@ func TestWAL(t *testing.T) {
 	b := registry.MustGetBuilder(message.WALNameTest,
 		redo.NewInterceptorBuilder(),
 		lock.NewInterceptorBuilder(),
+		replicate.NewInterceptorBuilder(),
 		timetick.NewInterceptorBuilder(),
 		shard.NewInterceptorBuilder(),
 	)
@@ -181,6 +183,10 @@ func (f *testOneWALFramework) Run() {
 }
 func (f *testOneWALFramework) testReadAndWrite(ctx context.Context, rwWAL wal.WAL, roWAL wal.ROWAL) {
+	cp, err := rwWAL.GetReplicateCheckpoint()
+	assert.True(f.t, status.AsStreamingError(err).IsReplicateViolation())
+	assert.Nil(f.t, cp)
 	f.testSendCreateCollection(ctx, rwWAL)
 	defer f.testSendDropCollection(ctx, rwWAL)

View File

@@ -4,6 +4,7 @@ import (
 	"context"
 	"github.com/milvus-io/milvus/internal/streamingnode/server/wal"
+	"github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/replicate/replicates"
 	"github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/shard/shards"
 	"github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/timetick/mvcc"
 	"github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/txn"
@@ -29,6 +30,7 @@ type InterceptorBuildParam struct {
 	InitialRecoverSnapshot *recovery.RecoverySnapshot // The initial recover snapshot for the wal, used to recover the wal state.
 	TxnManager             *txn.TxnManager            // The transaction manager for the wal, used to manage the transactions.
 	ShardManager           shards.ShardManager        // The shard manager for the wal, used to manage the shards, segment assignment, partition.
+	ReplicateManager       replicates.ReplicateManager // The replicates manager for the wal, used to manage the replicates.
 }
 // Clear release the resources in the interceptor build param.

View File

@@ -16,6 +16,7 @@ type interceptorBuilder struct{}
 // Build creates a new redo interceptor.
 func (b *interceptorBuilder) Build(param *interceptors.InterceptorBuildParam) interceptors.Interceptor {
 	return &lockAppendInterceptor{
+		channel:        param.ChannelInfo,
 		vchannelLocker: lock.NewKeyLock[string](),
 		txnManager:     param.TxnManager,
 	}

View File

@@ -2,14 +2,18 @@ package lock
 import (
 	"context"
+	"sync"
 	"github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors"
 	"github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/txn"
 	"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
+	"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
 	"github.com/milvus-io/milvus/pkg/v2/util/lock"
 )
 type lockAppendInterceptor struct {
+	channel        types.PChannelInfo
+	glock          sync.RWMutex // glock is a wal level lock, it will acquire a highest level lock for wal.
 	vchannelLocker *lock.KeyLock[string]
 	txnManager     *txn.TxnManager
 }
@@ -26,6 +30,14 @@ func (r *lockAppendInterceptor) acquireLockGuard(_ context.Context, msg message.
 	// Acquire the write lock for the vchannel.
 	vchannel := msg.VChannel()
 	if msg.MessageType().IsExclusiveRequired() {
+		if vchannel == "" || vchannel == r.channel.Name {
+			r.glock.Lock()
+			return func() {
+				// fail all transactions at all vchannels.
+				r.txnManager.FailTxnAtVChannel("")
+				r.glock.Unlock()
+			}
+		} else {
 		r.vchannelLocker.Lock(vchannel)
 		return func() {
 			// For exclusive messages, we need to fail all transactions at the vchannel.
@@ -40,9 +52,12 @@ func (r *lockAppendInterceptor) acquireLockGuard(_ context.Context, msg message.
 			r.vchannelLocker.Unlock(vchannel)
 		}
 	}
+	}
+	r.glock.RLock()
 	r.vchannelLocker.RLock(vchannel)
 	return func() {
 		r.vchannelLocker.RUnlock(vchannel)
+		r.glock.RUnlock()
 	}
 }

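The hunks above layer a wal-wide sync.RWMutex (glock) on top of the per-vchannel key locks: exclusive messages addressed to the whole pchannel (an empty vchannel or the pchannel name itself, e.g. AlterReplicateConfig) take glock exclusively and fail every open transaction, while ordinary appends hold glock in read mode next to their vchannel lock. A self-contained sketch of the same two-tier pattern under hypothetical names (the real code uses lock.KeyLock for the inner tier):

	import "sync"

	type tieredLock struct {
		global sync.RWMutex             // wal-level lock
		mu     sync.Mutex               // protects perKey
		perKey map[string]*sync.RWMutex // vchannel-level locks
	}

	func (t *tieredLock) keyLock(vchannel string) *sync.RWMutex {
		t.mu.Lock()
		defer t.mu.Unlock()
		if t.perKey == nil {
			t.perKey = make(map[string]*sync.RWMutex)
		}
		if _, ok := t.perKey[vchannel]; !ok {
			t.perKey[vchannel] = &sync.RWMutex{}
		}
		return t.perKey[vchannel]
	}

	// acquire returns an unlock function, mirroring the guard style of acquireLockGuard.
	func (t *tieredLock) acquire(vchannel string, exclusive, walWide bool) func() {
		if exclusive && walWide {
			t.global.Lock() // blocks every other append on the wal
			return t.global.Unlock
		}
		t.global.RLock() // wal-wide writers are excluded while this append runs
		l := t.keyLock(vchannel)
		if exclusive {
			l.Lock()
			return func() { l.Unlock(); t.global.RUnlock() }
		}
		l.RLock()
		return func() { l.RUnlock(); t.global.RUnlock() }
	}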
View File

@@ -0,0 +1,15 @@
package replicate
import "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors"
func NewInterceptorBuilder() interceptors.InterceptorBuilder {
return &interceptorBuilder{}
}
type interceptorBuilder struct{}
func (b *interceptorBuilder) Build(param *interceptors.InterceptorBuildParam) interceptors.Interceptor {
return &replicateInterceptor{
replicateManager: param.ReplicateManager,
}
}

View File

@@ -0,0 +1,51 @@
package replicate
import (
"context"
"github.com/cockroachdb/errors"
"github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors"
"github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/replicate/replicates"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
)
const interceptorName = "replicate"
type replicateInterceptor struct {
replicateManager replicates.ReplicateManager
}
func (impl *replicateInterceptor) Name() string {
return interceptorName
}
func (impl *replicateInterceptor) DoAppend(ctx context.Context, msg message.MutableMessage, appendOp interceptors.Append) (msgID message.MessageID, err error) {
if msg.MessageType() == message.MessageTypeAlterReplicateConfig {
// An AlterReplicateConfig message is protected by the wal-level lock, so it's safe to switch the replicate mode here.
alterReplicateConfig := message.MustAsMutableAlterReplicateConfigMessageV2(msg)
if err := impl.replicateManager.SwitchReplicateMode(ctx, alterReplicateConfig); err != nil {
return nil, err
}
return appendOp(ctx, msg)
}
// Begin to replicate the message.
acker, err := impl.replicateManager.BeginReplicateMessage(ctx, msg)
if errors.Is(err, replicates.ErrNotHandledByReplicateManager) {
// the message is not handled by replicate manager, write it into wal directly.
return appendOp(ctx, msg)
}
if err != nil {
return nil, err
}
defer func() {
acker.Ack(err)
}()
return appendOp(ctx, msg)
}
func (impl *replicateInterceptor) Close() {
}

View File

@@ -0,0 +1,71 @@
package replicate
import (
"context"
"testing"
"github.com/cockroachdb/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/milvus-io/milvus/internal/mocks/streamingnode/server/wal/interceptors/replicate/mock_replicates"
"github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors"
"github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/replicate/replicates"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/streaming/walimpls/impls/walimplstest"
)
func TestReplicateInterceptor(t *testing.T) {
manager := mock_replicates.NewMockReplicatesManager(t)
acker := mock_replicates.NewMockReplicateAcker(t)
manager.EXPECT().SwitchReplicateMode(mock.Anything, mock.Anything).Return(nil)
manager.EXPECT().BeginReplicateMessage(mock.Anything, mock.Anything).Return(acker, nil)
acker.EXPECT().Ack(mock.Anything).Return()
interceptor := NewInterceptorBuilder().Build(&interceptors.InterceptorBuildParam{
ReplicateManager: manager,
})
mutableMsg := message.NewAlterReplicateConfigMessageBuilderV2().
WithHeader(&message.AlterReplicateConfigMessageHeader{}).
WithBody(&message.AlterReplicateConfigMessageBody{}).
WithAllVChannel().
MustBuildMutable()
msgID, err := interceptor.DoAppend(context.Background(), mutableMsg, func(ctx context.Context, msg message.MutableMessage) (message.MessageID, error) {
return walimplstest.NewTestMessageID(1), nil
})
assert.NoError(t, err)
assert.NotNil(t, msgID)
mutableMsg2 := message.NewCreateDatabaseMessageBuilderV2().
WithHeader(&message.CreateDatabaseMessageHeader{}).
WithBody(&message.CreateDatabaseMessageBody{}).
WithVChannel("test").
MustBuildMutable()
msgID2, err := interceptor.DoAppend(context.Background(), mutableMsg2, func(ctx context.Context, msg message.MutableMessage) (message.MessageID, error) {
return walimplstest.NewTestMessageID(2), nil
})
assert.NoError(t, err)
assert.NotNil(t, msgID2)
manager.EXPECT().BeginReplicateMessage(mock.Anything, mock.Anything).Unset()
manager.EXPECT().BeginReplicateMessage(mock.Anything, mock.Anything).Return(nil, replicates.ErrNotHandledByReplicateManager)
msgID3, err := interceptor.DoAppend(context.Background(), mutableMsg2, func(ctx context.Context, msg message.MutableMessage) (message.MessageID, error) {
return walimplstest.NewTestMessageID(3), nil
})
assert.NoError(t, err)
assert.NotNil(t, msgID3)
manager.EXPECT().BeginReplicateMessage(mock.Anything, mock.Anything).Unset()
manager.EXPECT().BeginReplicateMessage(mock.Anything, mock.Anything).Return(nil, errors.New("test"))
msgID4, err := interceptor.DoAppend(context.Background(), mutableMsg2, func(ctx context.Context, msg message.MutableMessage) (message.MessageID, error) {
return walimplstest.NewTestMessageID(4), nil
})
assert.Error(t, err)
assert.Nil(t, msgID4)
interceptor.Close()
}

View File

@@ -0,0 +1,240 @@
package replicates
import (
"context"
"sync"
"github.com/cockroachdb/errors"
"google.golang.org/protobuf/encoding/protojson"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus/internal/streamingnode/server/wal/recovery"
"github.com/milvus-io/milvus/internal/streamingnode/server/wal/utility"
"github.com/milvus-io/milvus/internal/util/streamingutil/status"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
"github.com/milvus-io/milvus/pkg/v2/util/replicateutil"
)
// ErrNotHandledByReplicateManager is a special error to indicate that the message should not be handled by the replicate manager.
var ErrNotHandledByReplicateManager = errors.New("not handled by replicate manager")
// ReplicateManagerRecoverParam is the parameter for recovering the replicate manager.
type ReplicateManagerRecoverParam struct {
ChannelInfo types.PChannelInfo
CurrentClusterID string
InitialRecoverSnapshot *recovery.RecoverySnapshot // the initial recover snapshot of the replicate manager.
}
// RecoverReplicateManager recovers the replicate manager from the initial recover snapshot.
// If the wal is in replicating mode, it also recovers the replicate state.
func RecoverReplicateManager(param *ReplicateManagerRecoverParam) (ReplicateManager, error) {
replicateConfigHelper, err := replicateutil.NewConfigHelper(param.CurrentClusterID, param.InitialRecoverSnapshot.Checkpoint.ReplicateConfig)
if err != nil {
return nil, newReplicateViolationErrorForConfig(param.InitialRecoverSnapshot.Checkpoint.ReplicateConfig, err)
}
rm := &replicatesManagerImpl{
mu: sync.Mutex{},
currentClusterID: param.CurrentClusterID,
pchannel: param.ChannelInfo,
replicateConfigHelper: replicateConfigHelper,
}
if !rm.isPrimaryRole() {
// if current cluster is not the primary role,
// recover the secondary state for it.
if rm.secondaryState, err = recoverSecondaryState(param); err != nil {
return nil, err
}
}
return rm, nil
}
// replicatesManagerImpl is the implementation of the replicates manager.
type replicatesManagerImpl struct {
mu sync.Mutex
pchannel types.PChannelInfo
currentClusterID string
replicateConfigHelper *replicateutil.ConfigHelper
secondaryState *secondaryState // if the current cluster is not the primary role, it will have secondaryState.
}
// SwitchReplicateMode switches the replicates manager between replicating mode and non-replicating mode.
func (impl *replicatesManagerImpl) SwitchReplicateMode(_ context.Context, msg message.MutableAlterReplicateConfigMessageV2) error {
impl.mu.Lock()
defer impl.mu.Unlock()
newCfg := msg.Header().ReplicateConfiguration
newGraph, err := replicateutil.NewConfigHelper(impl.currentClusterID, newCfg)
if err != nil {
return newReplicateViolationErrorForConfig(newCfg, err)
}
incomingCurrentClusterConfig := newGraph.GetCurrentCluster()
switch incomingCurrentClusterConfig.Role() {
case replicateutil.RolePrimary:
// drop the replicating state if the current cluster is switched to primary.
impl.secondaryState = nil
case replicateutil.RoleSecondary:
if impl.isPrimaryRole() || impl.secondaryState.SourceClusterID() != incomingCurrentClusterConfig.SourceCluster().GetClusterId() {
// Only update the replicating state when the current cluster switch from primary to secondary,
// or the source cluster is changed.
impl.secondaryState = newSecondaryState(
incomingCurrentClusterConfig.SourceCluster().GetClusterId(),
incomingCurrentClusterConfig.MustGetSourceChannel(impl.pchannel.Name),
)
}
}
impl.replicateConfigHelper = newGraph
return nil
}
func (impl *replicatesManagerImpl) BeginReplicateMessage(ctx context.Context, msg message.MutableMessage) (g ReplicateAcker, err error) {
rh := msg.ReplicateHeader()
// some message types, like timetick, create segment and flush, are generated by the wal itself.
// they should never be handled by the replicates manager.
if msg.MessageType().IsSelfControlled() {
if rh != nil {
return nil, status.NewIgnoreOperation("wal self-controlled message cannot be replicated")
}
return nil, ErrNotHandledByReplicateManager
}
impl.mu.Lock()
defer func() {
if err != nil {
impl.mu.Unlock()
}
}()
switch impl.getRole() {
case replicateutil.RolePrimary:
if rh != nil {
return nil, status.NewReplicateViolation("replicate message cannot be received in primary role")
}
return nil, ErrNotHandledByReplicateManager
case replicateutil.RoleSecondary:
if rh == nil {
return nil, status.NewReplicateViolation("non-replicate message cannot be received in secondary role")
}
return impl.beginReplicateMessage(ctx, msg)
default:
panic("unreachable: invalid role")
}
}
// GetReplicateCheckpoint gets the replicate checkpoint.
func (impl *replicatesManagerImpl) GetReplicateCheckpoint() (*utility.ReplicateCheckpoint, error) {
impl.mu.Lock()
defer impl.mu.Unlock()
if impl.isPrimaryRole() {
return nil, status.NewReplicateViolation("wal is not a secondary cluster in replicating topology")
}
return impl.secondaryState.GetCheckpoint(), nil
}
// beginReplicateMessage begins the replicate message operation.
func (impl *replicatesManagerImpl) beginReplicateMessage(ctx context.Context, msg message.MutableMessage) (ReplicateAcker, error) {
rh := msg.ReplicateHeader()
if rh.ClusterID != impl.secondaryState.SourceClusterID() {
return nil, status.NewReplicateViolation("cluster id mismatch, current: %s, expected: %s", rh.ClusterID, impl.secondaryState.SourceClusterID())
}
// if the incoming message's time tick is less than the checkpoint's time tick,
// it means that the message has been written to the wal, so it can be ignored.
// txn message will share same time tick, so we only filter with <, it will be deduplicated by the txnHelper.
isTxnBody := msg.TxnContext() != nil && msg.MessageType() != message.MessageTypeBeginTxn
if (isTxnBody && rh.TimeTick < impl.secondaryState.GetCheckpoint().TimeTick) || (!isTxnBody && rh.TimeTick <= impl.secondaryState.GetCheckpoint().TimeTick) {
return nil, status.NewIgnoreOperation("message is too old, message_id: %s, time_tick: %d, txn: %t, current time tick: %d",
rh.MessageID, rh.TimeTick, isTxnBody, impl.secondaryState.GetCheckpoint().TimeTick)
}
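// Illustrative numbers: with a checkpoint time tick of 10, a re-replicated
// non-txn message at time tick 10 is ignored (<=), while a txn body at time
// tick 10 is still accepted (<), because every message of a txn shares the
// commit time tick and duplicated txn bodies are filtered by the txn helper.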
if msg.TxnContext() != nil {
return impl.startReplicateTxnMessage(ctx, msg, rh)
}
return impl.startReplicateNonTxnMessage(ctx, msg, rh)
}
// startReplicateTxnMessage starts the replicate txn message operation.
func (impl *replicatesManagerImpl) startReplicateTxnMessage(_ context.Context, msg message.MutableMessage, rh *message.ReplicateHeader) (ReplicateAcker, error) {
txn := msg.TxnContext()
switch msg.MessageType() {
case message.MessageTypeBeginTxn:
if err := impl.secondaryState.StartBegin(txn, rh); err != nil {
return nil, err
}
return replicateAckerImpl(func(err error) {
if err == nil {
impl.secondaryState.BeginDone(txn)
}
impl.mu.Unlock()
}), nil
case message.MessageTypeCommitTxn:
if err := impl.secondaryState.StartCommit(txn); err != nil {
return nil, err
}
// only update the checkpoint when the txn is committed.
return replicateAckerImpl(func(err error) {
if err == nil {
impl.secondaryState.CommitDone(txn)
impl.secondaryState.PushForwardCheckpoint(rh.TimeTick, rh.LastConfirmedMessageID)
}
impl.mu.Unlock()
}), nil
case message.MessageTypeRollbackTxn:
panic("unreachable: rollback txn message should never be replicated when wal is on replicating mode")
default:
if err := impl.secondaryState.AddNewMessage(txn, rh); err != nil {
return nil, err
}
return replicateAckerImpl(func(err error) {
if err == nil {
impl.secondaryState.AddNewMessageDone(rh)
}
impl.mu.Unlock()
}), nil
}
}
// startReplicateNonTxnMessage starts the replicate non-txn message operation.
func (impl *replicatesManagerImpl) startReplicateNonTxnMessage(_ context.Context, _ message.MutableMessage, rh *message.ReplicateHeader) (ReplicateAcker, error) {
if impl.secondaryState.CurrentTxn() != nil {
return nil, status.NewReplicateViolation(
"txn is in progress, so the incoming message must be txn message, current txn: %d",
impl.secondaryState.CurrentTxn().TxnID,
)
}
return replicateAckerImpl(func(err error) {
if err == nil {
impl.secondaryState.PushForwardCheckpoint(rh.TimeTick, rh.LastConfirmedMessageID)
}
impl.mu.Unlock()
}), nil
}
// Role returns the role of the current cluster in the replicate topology.
func (impl *replicatesManagerImpl) Role() replicateutil.Role {
impl.mu.Lock()
defer impl.mu.Unlock()
return impl.getRole()
}
// getRole returns the role of the current cluster in the replicate topology.
func (impl *replicatesManagerImpl) getRole() replicateutil.Role {
if impl.replicateConfigHelper == nil {
return replicateutil.RolePrimary
}
return impl.replicateConfigHelper.MustGetCluster(impl.currentClusterID).Role()
}
// isPrimaryRole checks if the current cluster is the primary role.
func (impl *replicatesManagerImpl) isPrimaryRole() bool {
return impl.getRole() == replicateutil.RolePrimary
}
// newReplicateViolationErrorForConfig creates a new replicate violation error for the given configuration and error.
func newReplicateViolationErrorForConfig(cfg *commonpb.ReplicateConfiguration, err error) error {
bytes, _ := protojson.Marshal(cfg)
return status.NewReplicateViolation("when creating replicate graph, %s, %s", string(bytes), err.Error())
}

View File

@@ -0,0 +1,48 @@
package replicates
import (
"context"
"github.com/milvus-io/milvus/internal/streamingnode/server/wal/utility"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/util/replicateutil"
)
type replicateAckerImpl func(err error)
func (r replicateAckerImpl) Ack(err error) {
r(err)
}
// ReplicateAcker is a guard for replicate message.
type ReplicateAcker interface {
// Ack acknowledges the replicate message operation is done.
// It will push forward the in-memory checkpoint if the err is nil.
Ack(err error)
}
// ReplicateManager manages the replicate operation on one wal.
// There are two states:
// 1. primary: the wal only accepts non-replicate messages.
// 2. secondary: the wal only accepts replicate messages.
type ReplicateManager interface {
// Role returns the role of the replicate manager.
Role() replicateutil.Role
// SwitchReplicateMode switches the replicate mode.
// the following cases may happen:
// 1. primary->secondary: transits into replicating mode; messages without a replicate header will be rejected.
// 2. primary->primary: nothing happens.
// 3. secondary->primary: transits into non-replicating mode; the secondary replica state (remote cluster replicating checkpoint...) is dropped.
// 4. secondary->secondary with the source cluster changed: the previous remote cluster replicating checkpoint is dropped.
// 5. secondary->secondary without the source cluster changed: nothing happens.
SwitchReplicateMode(ctx context.Context, msg message.MutableAlterReplicateConfigMessageV2) error
// BeginReplicateMessage begins the replication of a single replicated message.
// The returned ReplicateAcker's Ack method must be called whenever no error is returned.
BeginReplicateMessage(ctx context.Context, msg message.MutableMessage) (ReplicateAcker, error)
// GetReplicateCheckpoint gets the current replicate checkpoint.
// It returns a ReplicateViolation error if the wal is not in replicating mode.
GetReplicateCheckpoint() (*utility.ReplicateCheckpoint, error)
}

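Taken together, the append-path contract looks roughly like this (a sketch, not the actual wal code; appendToWAL is a hypothetical stand-in for the real append operation):

	func appendWithReplication(ctx context.Context, rm replicates.ReplicateManager, msg message.MutableMessage) (message.MessageID, error) {
		acker, err := rm.BeginReplicateMessage(ctx, msg)
		if errors.Is(err, replicates.ErrNotHandledByReplicateManager) {
			return appendToWAL(ctx, msg) // plain write path, no replication bookkeeping
		}
		if err != nil {
			return nil, err // replicate violation or an ignorable duplicate
		}
		msgID, appendErr := appendToWAL(ctx, msg)
		acker.Ack(appendErr) // a nil appendErr pushes the replicate checkpoint forward
		return msgID, appendErr
	}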
View File

@@ -0,0 +1,495 @@
package replicates
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus/internal/streamingnode/server/wal/metricsutil"
"github.com/milvus-io/milvus/internal/streamingnode/server/wal/recovery"
"github.com/milvus-io/milvus/internal/streamingnode/server/wal/utility"
"github.com/milvus-io/milvus/internal/util/streamingutil/status"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
"github.com/milvus-io/milvus/pkg/v2/streaming/walimpls/impls/walimplstest"
"github.com/milvus-io/milvus/pkg/v2/util/replicateutil"
)
func TestNonReplicateManager(t *testing.T) {
rm, err := RecoverReplicateManager(&ReplicateManagerRecoverParam{
ChannelInfo: types.PChannelInfo{
Name: "test1-rootcoord-dml_0",
Term: 1,
},
CurrentClusterID: "test1",
InitialRecoverSnapshot: &recovery.RecoverySnapshot{
Checkpoint: &utility.WALCheckpoint{
MessageID: walimplstest.NewTestMessageID(1),
TimeTick: 1,
ReplicateCheckpoint: nil,
ReplicateConfig: nil,
},
},
})
assert.NoError(t, err)
assert.Equal(t, rm.Role(), replicateutil.RolePrimary)
testSwitchReplicateMode(t, rm, "test1", "test2")
testMessageOnPrimary(t, rm)
testMessageOnSecondary(t, rm)
}
func TestPrimaryReplicateManager(t *testing.T) {
rm, err := RecoverReplicateManager(&ReplicateManagerRecoverParam{
ChannelInfo: types.PChannelInfo{
Name: "test1-rootcoord-dml_0",
Term: 1,
},
CurrentClusterID: "test1",
InitialRecoverSnapshot: &recovery.RecoverySnapshot{
Checkpoint: &utility.WALCheckpoint{
MessageID: walimplstest.NewTestMessageID(1),
TimeTick: 1,
ReplicateCheckpoint: nil,
ReplicateConfig: newReplicateConfiguration("test1", "test2"),
},
},
})
assert.NoError(t, err)
assert.Equal(t, rm.Role(), replicateutil.RolePrimary)
testSwitchReplicateMode(t, rm, "test1", "test2")
testMessageOnPrimary(t, rm)
testMessageOnSecondary(t, rm)
}
func TestSecondaryReplicateManager(t *testing.T) {
rm, err := RecoverReplicateManager(&ReplicateManagerRecoverParam{
ChannelInfo: types.PChannelInfo{
Name: "test1-rootcoord-dml_0",
Term: 1,
},
CurrentClusterID: "test1",
InitialRecoverSnapshot: &recovery.RecoverySnapshot{
Checkpoint: &utility.WALCheckpoint{
MessageID: walimplstest.NewTestMessageID(1),
TimeTick: 1,
ReplicateCheckpoint: &utility.ReplicateCheckpoint{
ClusterID: "test2",
PChannel: "test2-rootcoord-dml_0",
MessageID: walimplstest.NewTestMessageID(1),
TimeTick: 1,
},
ReplicateConfig: newReplicateConfiguration("test2", "test1"),
},
TxnBuffer: utility.NewTxnBuffer(log.With(), metricsutil.NewScanMetrics(types.PChannelInfo{}).NewScannerMetrics()),
},
})
assert.NoError(t, err)
assert.Equal(t, rm.Role(), replicateutil.RoleSecondary)
testSwitchReplicateMode(t, rm, "test1", "test2")
testMessageOnPrimary(t, rm)
testMessageOnSecondary(t, rm)
}
func TestSecondaryReplicateManagerWithTxn(t *testing.T) {
txnBuffer := utility.NewTxnBuffer(log.With(), metricsutil.NewScanMetrics(types.PChannelInfo{}).NewScannerMetrics())
txnMsgs := newReplicateTxnMessage("test1", "test2", 2)
for _, msg := range txnMsgs[0:3] {
immutableMsg := msg.WithTimeTick(3).IntoImmutableMessage(walimplstest.NewTestMessageID(1))
txnBuffer.HandleImmutableMessages([]message.ImmutableMessage{immutableMsg}, msg.TimeTick())
}
rm, err := RecoverReplicateManager(&ReplicateManagerRecoverParam{
ChannelInfo: types.PChannelInfo{
Name: "test1-rootcoord-dml_0",
Term: 1,
},
CurrentClusterID: "test1",
InitialRecoverSnapshot: &recovery.RecoverySnapshot{
Checkpoint: &utility.WALCheckpoint{
MessageID: walimplstest.NewTestMessageID(1),
TimeTick: 1,
ReplicateCheckpoint: &utility.ReplicateCheckpoint{
ClusterID: "test2",
PChannel: "test2-rootcoord-dml_0",
MessageID: walimplstest.NewTestMessageID(1),
TimeTick: 1,
},
ReplicateConfig: newReplicateConfiguration("test2", "test1"),
},
TxnBuffer: txnBuffer,
},
})
assert.NoError(t, err)
assert.Equal(t, rm.Role(), replicateutil.RoleSecondary)
committed := false
for _, msg := range newReplicateTxnMessage("test1", "test2", 2) {
g, err := rm.BeginReplicateMessage(context.Background(), msg)
if msg.MessageType() == message.MessageTypeCommitTxn && !committed {
assert.NoError(t, err)
assert.NotNil(t, g)
g.Ack(nil)
committed = true
} else {
assert.True(t, status.AsStreamingError(err).IsIgnoredOperation())
assert.Nil(t, g)
}
}
}
func testSwitchReplicateMode(t *testing.T, rm ReplicateManager, primaryClusterID, secondaryClusterID string) {
ctx := context.Background()
// switch to primary
err := rm.SwitchReplicateMode(ctx, newAlterReplicateConfigMessage(primaryClusterID, secondaryClusterID))
assert.NoError(t, err)
assert.Equal(t, rm.Role(), replicateutil.RolePrimary)
cp, err := rm.GetReplicateCheckpoint()
assert.True(t, status.AsStreamingError(err).IsReplicateViolation())
assert.Nil(t, cp)
// idempotent switch to primary
err = rm.SwitchReplicateMode(ctx, newAlterReplicateConfigMessage(primaryClusterID, secondaryClusterID))
assert.NoError(t, err)
assert.Equal(t, rm.Role(), replicateutil.RolePrimary)
cp, err = rm.GetReplicateCheckpoint()
assert.True(t, status.AsStreamingError(err).IsReplicateViolation())
assert.Nil(t, cp)
// switch to secondary
err = rm.SwitchReplicateMode(ctx, newAlterReplicateConfigMessage(secondaryClusterID, primaryClusterID))
assert.NoError(t, err)
assert.Equal(t, rm.Role(), replicateutil.RoleSecondary)
cp, err = rm.GetReplicateCheckpoint()
assert.NoError(t, err)
assert.Equal(t, cp.ClusterID, secondaryClusterID)
assert.Equal(t, cp.PChannel, secondaryClusterID+"-rootcoord-dml_0")
assert.Nil(t, cp.MessageID)
assert.Equal(t, cp.TimeTick, uint64(0))
// idempotent switch to secondary
err = rm.SwitchReplicateMode(ctx, newAlterReplicateConfigMessage(secondaryClusterID, primaryClusterID))
assert.NoError(t, err)
assert.Equal(t, rm.Role(), replicateutil.RoleSecondary)
cp, err = rm.GetReplicateCheckpoint()
assert.NoError(t, err)
assert.Equal(t, cp.ClusterID, secondaryClusterID)
assert.Equal(t, cp.PChannel, secondaryClusterID+"-rootcoord-dml_0")
assert.Nil(t, cp.MessageID)
assert.Equal(t, cp.TimeTick, uint64(0))
// switch back to primary
err = rm.SwitchReplicateMode(ctx, newAlterReplicateConfigMessage(primaryClusterID, secondaryClusterID))
assert.NoError(t, err)
assert.Equal(t, rm.Role(), replicateutil.RolePrimary)
cp, err = rm.GetReplicateCheckpoint()
assert.True(t, status.AsStreamingError(err).IsReplicateViolation())
assert.Nil(t, cp)
// idempotent switch back to primary
err = rm.SwitchReplicateMode(ctx, newAlterReplicateConfigMessage(primaryClusterID, secondaryClusterID))
assert.NoError(t, err)
assert.Equal(t, rm.Role(), replicateutil.RolePrimary)
cp, err = rm.GetReplicateCheckpoint()
assert.True(t, status.AsStreamingError(err).IsReplicateViolation())
assert.Nil(t, cp)
// switch back to secondary
err = rm.SwitchReplicateMode(ctx, newAlterReplicateConfigMessage(secondaryClusterID, primaryClusterID))
assert.NoError(t, err)
assert.Equal(t, rm.Role(), replicateutil.RoleSecondary)
cp, err = rm.GetReplicateCheckpoint()
assert.NoError(t, err)
assert.Equal(t, cp.ClusterID, secondaryClusterID)
assert.Equal(t, cp.PChannel, secondaryClusterID+"-rootcoord-dml_0")
assert.Nil(t, cp.MessageID)
assert.Equal(t, cp.TimeTick, uint64(0))
// idempotent switch back to secondary
err = rm.SwitchReplicateMode(ctx, newAlterReplicateConfigMessage(secondaryClusterID, primaryClusterID))
assert.NoError(t, err)
assert.Equal(t, rm.Role(), replicateutil.RoleSecondary)
cp, err = rm.GetReplicateCheckpoint()
assert.NoError(t, err)
assert.Equal(t, cp.ClusterID, secondaryClusterID)
assert.Equal(t, cp.PChannel, secondaryClusterID+"-rootcoord-dml_0")
assert.Nil(t, cp.MessageID)
assert.Equal(t, cp.TimeTick, uint64(0))
// add a new cluster and switch to primary
err = rm.SwitchReplicateMode(ctx, newAlterReplicateConfigMessage(primaryClusterID, secondaryClusterID, "test3"))
assert.NoError(t, err)
assert.Equal(t, rm.Role(), replicateutil.RolePrimary)
cp, err = rm.GetReplicateCheckpoint()
assert.True(t, status.AsStreamingError(err).IsReplicateViolation())
assert.Nil(t, cp)
// idempotent add a new cluster and switch to primary
err = rm.SwitchReplicateMode(ctx, newAlterReplicateConfigMessage(primaryClusterID, secondaryClusterID, "test3"))
assert.NoError(t, err)
assert.Equal(t, rm.Role(), replicateutil.RolePrimary)
cp, err = rm.GetReplicateCheckpoint()
assert.True(t, status.AsStreamingError(err).IsReplicateViolation())
assert.Nil(t, cp)
// add a new cluster and switch to secondary
err = rm.SwitchReplicateMode(ctx, newAlterReplicateConfigMessage(secondaryClusterID, primaryClusterID, "test3"))
assert.NoError(t, err)
assert.Equal(t, rm.Role(), replicateutil.RoleSecondary)
cp, err = rm.GetReplicateCheckpoint()
assert.NoError(t, err)
assert.Equal(t, cp.ClusterID, secondaryClusterID)
assert.Equal(t, cp.PChannel, secondaryClusterID+"-rootcoord-dml_0")
assert.Nil(t, cp.MessageID)
assert.Equal(t, cp.TimeTick, uint64(0))
// idempotent add a new cluster and switch to secondary
err = rm.SwitchReplicateMode(ctx, newAlterReplicateConfigMessage(secondaryClusterID, primaryClusterID, "test3"))
assert.NoError(t, err)
assert.Equal(t, rm.Role(), replicateutil.RoleSecondary)
cp, err = rm.GetReplicateCheckpoint()
assert.NoError(t, err)
assert.Equal(t, cp.ClusterID, secondaryClusterID)
assert.Equal(t, cp.PChannel, secondaryClusterID+"-rootcoord-dml_0")
assert.Nil(t, cp.MessageID)
assert.Equal(t, cp.TimeTick, uint64(0))
// switch the primary
err = rm.SwitchReplicateMode(ctx, newAlterReplicateConfigMessage("test3", primaryClusterID, secondaryClusterID))
assert.NoError(t, err)
assert.Equal(t, rm.Role(), replicateutil.RoleSecondary)
cp, err = rm.GetReplicateCheckpoint()
assert.NoError(t, err)
assert.Equal(t, cp.ClusterID, "test3")
assert.Equal(t, cp.PChannel, "test3-rootcoord-dml_0")
assert.Nil(t, cp.MessageID)
assert.Equal(t, cp.TimeTick, uint64(0))
// idempotent switch the primary
err = rm.SwitchReplicateMode(ctx, newAlterReplicateConfigMessage("test3", primaryClusterID, secondaryClusterID))
assert.NoError(t, err)
assert.Equal(t, rm.Role(), replicateutil.RoleSecondary)
cp, err = rm.GetReplicateCheckpoint()
assert.NoError(t, err)
assert.Equal(t, cp.ClusterID, "test3")
assert.Equal(t, cp.PChannel, "test3-rootcoord-dml_0")
assert.Nil(t, cp.MessageID)
assert.Equal(t, cp.TimeTick, uint64(0))
}
func testMessageOnPrimary(t *testing.T, rm ReplicateManager) {
// switch to primary
err := rm.SwitchReplicateMode(context.Background(), newAlterReplicateConfigMessage("test1", "test2"))
assert.NoError(t, err)
// Test self-controlled message
g, err := rm.BeginReplicateMessage(context.Background(), message.NewCreateSegmentMessageBuilderV2().
WithHeader(&message.CreateSegmentMessageHeader{}).
WithBody(&message.CreateSegmentMessageBody{}).
WithVChannel("test1-rootcoord-dml_0").
MustBuildMutable())
assert.ErrorIs(t, err, ErrNotHandledByReplicateManager)
assert.Nil(t, g)
// Test non-replicate message
msg := newNonReplicateMessage("test1")
g, err = rm.BeginReplicateMessage(context.Background(), msg)
assert.ErrorIs(t, err, ErrNotHandledByReplicateManager)
assert.Nil(t, g)
// Test replicate message
msg = newReplicateMessage("test1", "test2")
g, err = rm.BeginReplicateMessage(context.Background(), msg)
assert.True(t, status.AsStreamingError(err).IsReplicateViolation())
assert.Nil(t, g)
}
func testMessageOnSecondary(t *testing.T, rm ReplicateManager) {
// switch to secondary
err := rm.SwitchReplicateMode(context.Background(), newAlterReplicateConfigMessage("test2", "test1"))
assert.NoError(t, err)
// Test wrong cluster replicates
msg := newReplicateMessage("test1", "test3")
g, err := rm.BeginReplicateMessage(context.Background(), msg)
assert.True(t, status.AsStreamingError(err).IsReplicateViolation())
assert.Nil(t, g)
// Test self-controlled message
g, err = rm.BeginReplicateMessage(context.Background(), message.NewCreateSegmentMessageBuilderV2().
WithHeader(&message.CreateSegmentMessageHeader{}).
WithBody(&message.CreateSegmentMessageBody{}).
WithVChannel("test1-rootcoord-dml_0").
MustBuildMutable())
assert.ErrorIs(t, err, ErrNotHandledByReplicateManager)
assert.Nil(t, g)
// Test non-replicate message
msg = newNonReplicateMessage("test1")
g, err = rm.BeginReplicateMessage(context.Background(), msg)
assert.True(t, status.AsStreamingError(err).IsReplicateViolation())
assert.Nil(t, g)
// Test replicate message
msg = newReplicateMessage("test1", "test2")
g, err = rm.BeginReplicateMessage(context.Background(), msg)
assert.NoError(t, err)
assert.NotNil(t, g)
g.Ack(nil)
// Test replicate message
msg = newReplicateMessage("test1", "test2")
g, err = rm.BeginReplicateMessage(context.Background(), msg)
assert.True(t, status.AsStreamingError(err).IsIgnoredOperation())
assert.Nil(t, g)
for idx, msg := range newReplicateTxnMessage("test1", "test2", 2) {
g, err = rm.BeginReplicateMessage(context.Background(), msg)
if idx%2 == 0 {
assert.NoError(t, err)
assert.NotNil(t, g)
g.Ack(nil)
} else {
assert.True(t, status.AsStreamingError(err).IsIgnoredOperation())
assert.Nil(t, g)
}
}
msg = newReplicateMessage("test1", "test2", 2)
g, err = rm.BeginReplicateMessage(context.Background(), msg)
assert.True(t, status.AsStreamingError(err).IsIgnoredOperation())
assert.Nil(t, g)
g, err = rm.BeginReplicateMessage(context.Background(), newReplicateTxnMessage("test1", "test2", 2)[0])
assert.True(t, status.AsStreamingError(err).IsIgnoredOperation())
assert.Nil(t, g)
}
// newReplicateConfiguration creates a valid replicate configuration for testing
func newReplicateConfiguration(primaryClusterID string, secondaryClusterID ...string) *commonpb.ReplicateConfiguration {
clusters := []*commonpb.MilvusCluster{
{ClusterId: primaryClusterID, Pchannels: []string{primaryClusterID + "-rootcoord-dml_0", primaryClusterID + "-rootcoord-dml_1"}},
}
crossClusterTopology := []*commonpb.CrossClusterTopology{}
for _, secondaryClusterID := range secondaryClusterID {
clusters = append(clusters, &commonpb.MilvusCluster{ClusterId: secondaryClusterID, Pchannels: []string{secondaryClusterID + "-rootcoord-dml_0", secondaryClusterID + "-rootcoord-dml_1"}})
crossClusterTopology = append(crossClusterTopology, &commonpb.CrossClusterTopology{SourceClusterId: primaryClusterID, TargetClusterId: secondaryClusterID})
}
return &commonpb.ReplicateConfiguration{
Clusters: clusters,
CrossClusterTopology: crossClusterTopology,
}
}
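// For example, newReplicateConfiguration("test1", "test2") yields two clusters
// (test1 and test2), each with two pchannels, and one topology edge
// test1 -> test2, i.e. test1 acts as primary and test2 as secondary.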
func newAlterReplicateConfigMessage(primaryClusterID string, secondaryClusterID ...string) message.MutableAlterReplicateConfigMessageV2 {
return message.MustAsMutableAlterReplicateConfigMessageV2(message.NewAlterReplicateConfigMessageBuilderV2().
WithHeader(&message.AlterReplicateConfigMessageHeader{
ReplicateConfiguration: newReplicateConfiguration(primaryClusterID, secondaryClusterID...),
}).
WithBody(&message.AlterReplicateConfigMessageBody{}).
WithVChannel(primaryClusterID + "-rootcoord-dml_0").
MustBuildMutable())
}
func newNonReplicateMessage(clusterID string) message.MutableMessage {
return message.NewCreateDatabaseMessageBuilderV2().
WithHeader(&message.CreateDatabaseMessageHeader{}).
WithBody(&message.CreateDatabaseMessageBody{}).
WithVChannel(clusterID + "-rootcoord-dml_0").
MustBuildMutable()
}
func newReplicateMessage(clusterID string, sourceClusterID string, timetick ...uint64) message.MutableMessage {
tt := uint64(1)
if len(timetick) > 0 {
tt = timetick[0]
}
msg := message.NewCreateDatabaseMessageBuilderV2().
WithHeader(&message.CreateDatabaseMessageHeader{}).
WithBody(&message.CreateDatabaseMessageBody{}).
WithVChannel(sourceClusterID + "-rootcoord-dml_0").
MustBuildMutable().
WithTimeTick(tt).
WithLastConfirmed(walimplstest.NewTestMessageID(1)).
IntoImmutableMessage(walimplstest.NewTestMessageID(1))
replicateMsg := message.NewReplicateMessage(
sourceClusterID,
msg.IntoImmutableMessageProto(),
)
replicateMsg.OverwriteReplicateVChannel(
clusterID + "-rootcoord-dml_0",
)
return replicateMsg
}
func newImmutableTxnMessage(clusterID string, timetick ...uint64) []message.ImmutableMessage {
tt := uint64(1)
if len(timetick) > 0 {
tt = timetick[0]
}
immutables := []message.ImmutableMessage{
message.NewBeginTxnMessageBuilderV2().
WithHeader(&message.BeginTxnMessageHeader{}).
WithBody(&message.BeginTxnMessageBody{}).
WithVChannel(clusterID + "-rootcoord-dml_0").
MustBuildMutable().
WithTxnContext(message.TxnContext{
TxnID: message.TxnID(1),
Keepalive: message.TxnKeepaliveInfinite,
}).
WithTimeTick(tt).
WithLastConfirmed(walimplstest.NewTestMessageID(1)).
IntoImmutableMessage(walimplstest.NewTestMessageID(1)),
message.NewCreateDatabaseMessageBuilderV2().
WithHeader(&message.CreateDatabaseMessageHeader{}).
WithBody(&message.CreateDatabaseMessageBody{}).
WithVChannel(clusterID + "-rootcoord-dml_0").
MustBuildMutable().
WithTxnContext(message.TxnContext{
TxnID: message.TxnID(1),
Keepalive: message.TxnKeepaliveInfinite,
}).
WithTimeTick(tt).
WithLastConfirmed(walimplstest.NewTestMessageID(1)).
IntoImmutableMessage(walimplstest.NewTestMessageID(1)),
message.NewCommitTxnMessageBuilderV2().
WithHeader(&message.CommitTxnMessageHeader{}).
WithBody(&message.CommitTxnMessageBody{}).
WithVChannel(clusterID + "-rootcoord-dml_0").
MustBuildMutable().
WithTxnContext(message.TxnContext{
TxnID: message.TxnID(1),
Keepalive: message.TxnKeepaliveInfinite,
}).
WithTimeTick(tt).
WithLastConfirmed(walimplstest.NewTestMessageID(1)).
IntoImmutableMessage(walimplstest.NewTestMessageID(1)),
}
return immutables
}
func newReplicateTxnMessage(clusterID string, sourceClusterID string, timetick ...uint64) []message.MutableMessage {
immutables := newImmutableTxnMessage(sourceClusterID, timetick...)
replicateMsgs := []message.MutableMessage{}
for _, immutable := range immutables {
replicateMsg := message.NewReplicateMessage(
sourceClusterID,
immutable.IntoImmutableMessageProto(),
)
replicateMsg.OverwriteReplicateVChannel(
clusterID + "-rootcoord-dml_0",
)
replicateMsgs = append(replicateMsgs, replicateMsg)
// test the idempotency
replicateMsgs = append(replicateMsgs, replicateMsg)
}
return replicateMsgs
}

View File

@ -0,0 +1,76 @@
package replicates
import (
"github.com/milvus-io/milvus/internal/streamingnode/server/wal/utility"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
)
// newSecondaryState creates a new secondary state.
func newSecondaryState(sourceClusterID string, sourcePChannel string) *secondaryState {
return &secondaryState{
checkpoint: &utility.ReplicateCheckpoint{
ClusterID: sourceClusterID,
PChannel: sourcePChannel,
MessageID: nil,
TimeTick: 0,
},
replicateTxnHelper: newReplicateTxnHelper(),
}
}
// recoverSecondaryState recovers the secondary state from the recover param.
func recoverSecondaryState(param *ReplicateManagerRecoverParam) (*secondaryState, error) {
txnHelper := newReplicateTxnHelper()
sourceClusterID := param.InitialRecoverSnapshot.Checkpoint.ReplicateCheckpoint.ClusterID
// recover the txn helper.
uncommittedTxnBuilders := param.InitialRecoverSnapshot.TxnBuffer.GetUncommittedMessageBuilder()
for _, builder := range uncommittedTxnBuilders {
begin, body := builder.Messages()
replicateHeader := begin.ReplicateHeader()
// skip txn builders that are not replicated or were replicated from a different cluster.
if replicateHeader == nil || replicateHeader.ClusterID != sourceClusterID {
continue
}
// at most one uncommitted txn builder can match the source cluster.
if err := txnHelper.StartBegin(begin.TxnContext(), begin.ReplicateHeader()); err != nil {
return nil, err
}
txnHelper.BeginDone(begin.TxnContext())
for _, msg := range body {
if err := txnHelper.AddNewMessage(msg.TxnContext(), msg.ReplicateHeader()); err != nil {
return nil, err
}
txnHelper.AddNewMessageDone(msg.ReplicateHeader())
}
}
return &secondaryState{
checkpoint: param.InitialRecoverSnapshot.Checkpoint.ReplicateCheckpoint,
replicateTxnHelper: txnHelper,
}, nil
}
// secondaryState describes the state of the secondary role.
type secondaryState struct {
checkpoint *utility.ReplicateCheckpoint
*replicateTxnHelper // if not nil, the txn replicating operation is in progress.
}
// SourceClusterID returns the source cluster id of the secondary state.
func (s *secondaryState) SourceClusterID() string {
return s.checkpoint.ClusterID
}
// GetCheckpoint returns the checkpoint of the secondary state.
func (s *secondaryState) GetCheckpoint() *utility.ReplicateCheckpoint {
return s.checkpoint
}
// PushForwardCheckpoint pushes forward the checkpoint.
func (s *secondaryState) PushForwardCheckpoint(timetick uint64, lastConfirmedMessageID message.MessageID) error {
if timetick <= s.checkpoint.TimeTick {
return nil
}
s.checkpoint.TimeTick = timetick
s.checkpoint.MessageID = lastConfirmedMessageID
return nil
}
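Not part of the diff: a minimal usage sketch of the secondary state above, assuming the package-internal helpers; exampleSecondaryCheckpoint and the literal cluster/pchannel names are illustrative only.
package replicates
import (
"github.com/milvus-io/milvus/pkg/v2/streaming/walimpls/impls/walimplstest"
)
// exampleSecondaryCheckpoint shows that PushForwardCheckpoint is monotonic:
// an older time tick never moves the checkpoint backwards.
func exampleSecondaryCheckpoint() {
state := newSecondaryState("primary", "primary-rootcoord-dml_0")
_ = state.PushForwardCheckpoint(10, walimplstest.NewTestMessageID(7))
_ = state.PushForwardCheckpoint(5, walimplstest.NewTestMessageID(8)) // no-op: 5 <= 10
cp := state.GetCheckpoint() // cp.TimeTick == 10, cp.MessageID is the test id 7
_ = cp
}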

View File

@ -0,0 +1,76 @@
package replicates
import (
"github.com/milvus-io/milvus/internal/util/streamingutil/status"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
// newReplicateTxnHelper creates a new replicate txn helper.
func newReplicateTxnHelper() *replicateTxnHelper {
return &replicateTxnHelper{
currentTxn: nil,
messageIDs: typeutil.NewSet[string](),
}
}
// replicateTxnHelper is a helper for replicating a txn.
// It is used to handle and deduplicate the txn messages.
type replicateTxnHelper struct {
currentTxn *message.TxnContext
messageIDs typeutil.Set[string]
}
// CurrentTxn returns the current txn context.
func (s *replicateTxnHelper) CurrentTxn() *message.TxnContext {
return s.currentTxn
}
func (s *replicateTxnHelper) StartBegin(txn *message.TxnContext, replicateHeader *message.ReplicateHeader) error {
if s.currentTxn != nil {
if s.currentTxn.TxnID == txn.TxnID {
return status.NewIgnoreOperation("txn message is already in progress, txnID: %d", s.currentTxn.TxnID)
}
return status.NewReplicateViolation("begin txn violation, txnID: %d, incoming: %d", s.currentTxn.TxnID, txn.TxnID)
}
return nil
}
func (s *replicateTxnHelper) BeginDone(txn *message.TxnContext) {
s.currentTxn = txn
s.messageIDs = typeutil.NewSet[string]()
}
func (s *replicateTxnHelper) AddNewMessage(txn *message.TxnContext, replicateHeader *message.ReplicateHeader) error {
if s.currentTxn == nil {
return status.NewReplicateViolation("add new txn message without new txn, incoming: %d", s.currentTxn.TxnID, txn.TxnID)
}
if s.currentTxn.TxnID != txn.TxnID {
return status.NewReplicateViolation("add new txn message with different txn, current: %d, incoming: %d", s.currentTxn.TxnID, txn.TxnID)
}
if s.messageIDs.Contain(replicateHeader.MessageID.Marshal()) {
return status.NewIgnoreOperation("txn message is already in progress, txnID: %d, messageID: %d", s.currentTxn.TxnID, replicateHeader.MessageID)
}
return nil
}
func (s *replicateTxnHelper) AddNewMessageDone(replicateHeader *message.ReplicateHeader) {
s.messageIDs.Insert(replicateHeader.MessageID.Marshal())
}
func (s *replicateTxnHelper) StartCommit(txn *message.TxnContext) error {
if s.currentTxn == nil {
return status.NewIgnoreOperation("commit txn without txn, maybe already committed, txnID: %d", txn.TxnID)
}
if s.currentTxn.TxnID != txn.TxnID {
return status.NewReplicateViolation("commit txn with different txn, current: %d, incoming: %d", s.currentTxn.TxnID, txn.TxnID)
}
s.currentTxn = nil
s.messageIDs = nil
return nil
}
func (s *replicateTxnHelper) CommitDone(txn *message.TxnContext) {
s.currentTxn = nil
s.messageIDs = nil
}
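Not part of the diff: a hedged sketch of the intended call order for the helper above. Each Start* call validates and deduplicates (returning an IgnoreOperation error on idempotent replay) and the matching *Done call records the transition once the message is actually appended; exampleReplicateTxnLifecycle is an illustrative name and hdr is assumed non-nil.
package replicates
import (
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
)
func exampleReplicateTxnLifecycle(txn *message.TxnContext, hdr *message.ReplicateHeader) error {
helper := newReplicateTxnHelper()
if err := helper.StartBegin(txn, hdr); err != nil {
return err // IgnoreOperation if this txn already began on a replayed message.
}
helper.BeginDone(txn)
if err := helper.AddNewMessage(txn, hdr); err != nil {
return err // IgnoreOperation if this message id was already added.
}
helper.AddNewMessageDone(hdr)
if err := helper.StartCommit(txn); err != nil {
return err // IgnoreOperation if the txn was already committed.
}
helper.CommitDone(txn)
return nil
}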

View File

@ -76,7 +76,7 @@ func (w *segmentFlushWorker) do() {
}
nextInterval := backoff.NextBackOff()
-w.Logger().Info("failed to allocate new growing segment, retrying", zap.Duration("nextInterval", nextInterval), zap.Error(err))
w.Logger().Info("failed to flush new growing segment, retrying", zap.Duration("nextInterval", nextInterval), zap.Error(err))
select {
case <-w.ctx.Done():
w.Logger().Info("flush segment canceled", zap.Error(w.ctx.Err()))

View File

@ -16,6 +16,7 @@ import (
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message" "github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types" "github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
"github.com/milvus-io/milvus/pkg/v2/streaming/walimpls/impls/rmq" "github.com/milvus-io/milvus/pkg/v2/streaming/walimpls/impls/rmq"
"github.com/milvus-io/milvus/pkg/v2/streaming/walimpls/impls/walimplstest"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable" "github.com/milvus-io/milvus/pkg/v2/util/paramtable"
"github.com/milvus-io/milvus/pkg/v2/util/tsoutil" "github.com/milvus-io/milvus/pkg/v2/util/tsoutil"
) )
@ -211,6 +212,32 @@ func TestWithContext(t *testing.T) {
assert.NotNil(t, session) assert.NotNil(t, session)
} }
func TestManagerFromReplicateMessage(t *testing.T) {
resource.InitForTest(t)
manager := NewTxnManager(types.PChannelInfo{Name: "test"}, nil)
immutableMsg := message.NewBeginTxnMessageBuilderV2().
WithVChannel("v1").
WithHeader(&message.BeginTxnMessageHeader{
KeepaliveMilliseconds: 10 * time.Millisecond.Milliseconds(),
}).
WithBody(&message.BeginTxnMessageBody{}).
MustBuildMutable().
WithTimeTick(1).
WithLastConfirmed(walimplstest.NewTestMessageID(1)).
WithTxnContext(message.TxnContext{
TxnID: 18,
Keepalive: 10 * time.Millisecond,
}).
IntoImmutableMessage(walimplstest.NewTestMessageID(1))
replicateMsg := message.NewReplicateMessage("test2", immutableMsg.IntoImmutableMessageProto()).WithTimeTick(2)
session, err := manager.BeginNewTxn(context.Background(), message.MustAsMutableBeginTxnMessageV2(replicateMsg))
assert.NoError(t, err)
assert.NotNil(t, session)
assert.Equal(t, message.TxnID(18), session.TxnContext().TxnID)
assert.Equal(t, message.TxnKeepaliveInfinite, session.TxnContext().Keepalive)
}
func newBeginTxnMessage(timetick uint64, keepalive time.Duration) message.MutableBeginTxnMessageV2 {
return newBeginTxnMessageWithVChannel("v1", timetick, keepalive)
}

View File

@ -79,16 +79,8 @@ func (m *TxnManager) RecoverDone() <-chan struct{} {
func (m *TxnManager) BeginNewTxn(ctx context.Context, msg message.MutableBeginTxnMessageV2) (*TxnSession, error) {
timetick := msg.TimeTick()
vchannel := msg.VChannel()
-keepalive := time.Duration(msg.Header().KeepaliveMilliseconds) * time.Millisecond
-if keepalive == 0 {
-// If keepalive is 0, the txn set the keepalive with default keepalive.
-keepalive = paramtable.Get().StreamingCfg.TxnDefaultKeepaliveTimeout.GetAsDurationByParse()
-}
-if keepalive < 1*time.Millisecond {
-return nil, status.NewInvaildArgument("keepalive must be greater than 1ms")
-}
-id, err := resource.Resource().IDAllocator().Allocate(ctx)
txnCtx, err := m.buildTxnContext(ctx, msg)
if err != nil {
return nil, err
}
@ -100,23 +92,49 @@ func (m *TxnManager) BeginNewTxn(ctx context.Context, msg message.MutableBeginTx
if m.closed != nil {
return nil, status.NewTransactionExpired("manager closed")
}
-txnCtx := message.TxnContext{
-TxnID: message.TxnID(id),
-Keepalive: keepalive,
-}
-session := newTxnSession(vchannel, txnCtx, timetick, m.metrics.BeginTxn())
session := newTxnSession(vchannel, *txnCtx, timetick, m.metrics.BeginTxn())
m.sessions[session.TxnContext().TxnID] = session
return session, nil
}
// buildTxnContext builds the txn context from the message.
func (m *TxnManager) buildTxnContext(ctx context.Context, msg message.MutableBeginTxnMessageV2) (*message.TxnContext, error) {
if msg.ReplicateHeader() != nil {
// reuse the txn context if replicated.
// If the message is replicated, it should never be expired, so we set the keepalive to infinite.
return &message.TxnContext{
TxnID: msg.TxnContext().TxnID,
Keepalive: message.TxnKeepaliveInfinite,
}, nil
}
keepalive := time.Duration(msg.Header().KeepaliveMilliseconds) * time.Millisecond
if keepalive == 0 {
// If keepalive is 0, fall back to the default keepalive from the streaming config.
keepalive = paramtable.Get().StreamingCfg.TxnDefaultKeepaliveTimeout.GetAsDurationByParse()
}
if keepalive < 1*time.Millisecond {
return nil, status.NewInvaildArgument("keepalive must be greater than 1ms")
}
id, err := resource.Resource().IDAllocator().Allocate(ctx)
if err != nil {
return nil, err
}
return &message.TxnContext{
TxnID: message.TxnID(id),
Keepalive: keepalive,
}, nil
}
// FailTxnAtVChannel fails all transactions at the specified vchannel.
// If the vchannel is empty, it will fail all transactions.
func (m *TxnManager) FailTxnAtVChannel(vchannel string) {
// avoid the txn to be committed.
m.mu.Lock()
defer m.mu.Unlock()
ids := make([]int64, 0, len(m.sessions))
for id, session := range m.sessions {
-if session.VChannel() == vchannel {
if vchannel == "" || session.VChannel() == vchannel {
session.Cleanup()
delete(m.sessions, id)
delete(m.recoveredSessions, id)

View File

@ -158,7 +158,7 @@ func (r *recoveryStorageImpl) initializeRecoverInfo(ctx context.Context, channel
checkpoint := &streamingpb.WALCheckpoint{ checkpoint := &streamingpb.WALCheckpoint{
MessageId: untilMessage.LastConfirmedMessageID().IntoProto(), MessageId: untilMessage.LastConfirmedMessageID().IntoProto(),
TimeTick: untilMessage.TimeTick(), TimeTick: untilMessage.TimeTick(),
RecoveryMagic: RecoveryMagicStreamingInitialized, RecoveryMagic: utility.RecoveryMagicStreamingInitialized,
} }
if err := resource.Resource().StreamingNodeCatalog().SaveConsumeCheckpoint(ctx, channelInfo.Name, checkpoint); err != nil { if err := resource.Resource().StreamingNodeCatalog().SaveConsumeCheckpoint(ctx, channelInfo.Name, checkpoint); err != nil {
return nil, errors.Wrap(err, "failed to save checkpoint to catalog") return nil, errors.Wrap(err, "failed to save checkpoint to catalog")

View File

@ -16,6 +16,7 @@ import (
"github.com/milvus-io/milvus/internal/mocks" "github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/internal/mocks/mock_metastore" "github.com/milvus-io/milvus/internal/mocks/mock_metastore"
"github.com/milvus-io/milvus/internal/streamingnode/server/resource" "github.com/milvus-io/milvus/internal/streamingnode/server/resource"
"github.com/milvus-io/milvus/internal/streamingnode/server/wal/utility"
internaltypes "github.com/milvus-io/milvus/internal/types" internaltypes "github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/pkg/v2/proto/datapb" "github.com/milvus-io/milvus/pkg/v2/proto/datapb"
"github.com/milvus-io/milvus/pkg/v2/proto/etcdpb" "github.com/milvus-io/milvus/pkg/v2/proto/etcdpb"
@ -49,7 +50,7 @@ func TestInitRecoveryInfoFromMeta(t *testing.T) {
&streamingpb.WALCheckpoint{ &streamingpb.WALCheckpoint{
MessageId: rmq.NewRmqID(1).IntoProto(), MessageId: rmq.NewRmqID(1).IntoProto(),
TimeTick: 1, TimeTick: 1,
RecoveryMagic: RecoveryMagicStreamingInitialized, RecoveryMagic: utility.RecoveryMagicStreamingInitialized,
}, nil) }, nil)
resource.InitForTest(t, resource.OptStreamingNodeCatalog(snCatalog)) resource.InitForTest(t, resource.OptStreamingNodeCatalog(snCatalog))
channel := types.PChannelInfo{Name: "test_channel"} channel := types.PChannelInfo{Name: "test_channel"}
@ -60,7 +61,7 @@ func TestInitRecoveryInfoFromMeta(t *testing.T) {
err := rs.recoverRecoveryInfoFromMeta(context.Background(), channel, lastConfirmed.IntoImmutableMessage(rmq.NewRmqID(1)))
assert.NoError(t, err)
assert.NotNil(t, rs.checkpoint)
-assert.Equal(t, RecoveryMagicStreamingInitialized, rs.checkpoint.Magic)
assert.Equal(t, utility.RecoveryMagicStreamingInitialized, rs.checkpoint.Magic)
assert.True(t, rs.checkpoint.MessageID.EQ(rmq.NewRmqID(1)))
}

View File

@ -11,12 +11,14 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/distributed/streaming" "github.com/milvus-io/milvus/internal/distributed/streaming"
"github.com/milvus-io/milvus/internal/streamingnode/server/resource" "github.com/milvus-io/milvus/internal/streamingnode/server/resource"
"github.com/milvus-io/milvus/internal/streamingnode/server/wal/utility"
"github.com/milvus-io/milvus/pkg/v2/log" "github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/proto/streamingpb" "github.com/milvus-io/milvus/pkg/v2/proto/streamingpb"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message" "github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types" "github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
"github.com/milvus-io/milvus/pkg/v2/util/funcutil" "github.com/milvus-io/milvus/pkg/v2/util/funcutil"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable" "github.com/milvus-io/milvus/pkg/v2/util/paramtable"
"github.com/milvus-io/milvus/pkg/v2/util/replicateutil"
"github.com/milvus-io/milvus/pkg/v2/util/syncutil" "github.com/milvus-io/milvus/pkg/v2/util/syncutil"
) )
@ -69,6 +71,7 @@ func newRecoveryStorage(channel types.PChannelInfo) *recoveryStorageImpl {
backgroundTaskNotifier: syncutil.NewAsyncTaskNotifier[struct{}](),
cfg: cfg,
mu: sync.Mutex{},
currentClusterID: paramtable.Get().CommonCfg.ClusterPrefix.GetValue(),
channel: channel,
dirtyCounter: 0,
persistNotifier: make(chan struct{}, 1),
@ -84,6 +87,7 @@ type recoveryStorageImpl struct {
backgroundTaskNotifier *syncutil.AsyncTaskNotifier[struct{}]
cfg *config
mu sync.Mutex
currentClusterID string
channel types.PChannelInfo
segments map[int64]*segmentRecoveryInfo
vchannels map[string]*vchannelRecoveryInfo
@ -225,8 +229,7 @@ func (r *recoveryStorageImpl) observeMessage(msg message.ImmutableMessage) {
}
r.handleMessage(msg)
-r.checkpoint.TimeTick = msg.TimeTick()
-r.checkpoint.MessageID = msg.LastConfirmedMessageID()
r.updateCheckpoint(msg)
r.metrics.ObServeInMemMetrics(r.checkpoint.TimeTick)
if !msg.IsPersisted() {
@ -239,6 +242,52 @@ func (r *recoveryStorageImpl) observeMessage(msg message.ImmutableMessage) {
}
}
// updateCheckpoint updates the checkpoint of the recovery storage.
func (r *recoveryStorageImpl) updateCheckpoint(msg message.ImmutableMessage) {
if msg.MessageType() == message.MessageTypeAlterReplicateConfig {
cfg := message.MustAsImmutableAlterReplicateConfigMessageV2(msg)
r.checkpoint.ReplicateConfig = cfg.Header().ReplicateConfiguration
clusterRole := replicateutil.MustNewConfigHelper(r.currentClusterID, cfg.Header().ReplicateConfiguration).GetCurrentCluster()
switch clusterRole.Role() {
case replicateutil.RolePrimary:
r.checkpoint.ReplicateCheckpoint = nil
case replicateutil.RoleSecondary:
// Update the replicate checkpoint if the cluster role is secondary.
sourceClusterID := clusterRole.SourceCluster().GetClusterId()
sourcePChannel := clusterRole.MustGetSourceChannel(r.channel.Name)
if r.checkpoint.ReplicateCheckpoint == nil || r.checkpoint.ReplicateCheckpoint.ClusterID != sourceClusterID {
r.checkpoint.ReplicateCheckpoint = &utility.ReplicateCheckpoint{
ClusterID: sourceClusterID,
PChannel: sourcePChannel,
MessageID: nil,
TimeTick: 0,
}
}
}
}
r.checkpoint.MessageID = msg.LastConfirmedMessageID()
r.checkpoint.TimeTick = msg.TimeTick()
// update the replicate checkpoint.
replicateHeader := msg.ReplicateHeader()
if replicateHeader == nil {
return
}
if r.checkpoint.ReplicateCheckpoint == nil {
r.detectInconsistency(msg, "replicate checkpoint is nil when incoming replicate message")
return
}
if replicateHeader.ClusterID != r.checkpoint.ReplicateCheckpoint.ClusterID {
r.detectInconsistency(msg,
"replicate header cluster id mismatch",
zap.String("expected", r.checkpoint.ReplicateCheckpoint.ClusterID),
zap.String("actual", replicateHeader.ClusterID))
return
}
r.checkpoint.ReplicateCheckpoint.MessageID = replicateHeader.LastConfirmedMessageID
r.checkpoint.ReplicateCheckpoint.TimeTick = replicateHeader.TimeTick
}
// The incoming message id is always sorted with timetick.
func (r *recoveryStorageImpl) handleMessage(msg message.ImmutableMessage) {
if msg.VChannel() != "" && msg.MessageType() != message.MessageTypeCreateCollection &&

View File

@ -0,0 +1,105 @@
package recovery
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
"github.com/milvus-io/milvus/pkg/v2/streaming/walimpls/impls/walimplstest"
)
func TestUpdateCheckpoint(t *testing.T) {
rs := &recoveryStorageImpl{
currentClusterID: "test1",
channel: types.PChannelInfo{Name: "test1-rootcoord-dml_0"},
checkpoint: &WALCheckpoint{},
metrics: newRecoveryStorageMetrics(types.PChannelInfo{Name: "test1-rootcoord-dml_0"}),
}
rs.updateCheckpoint(newAlterReplicateConfigMessage("test1", []string{"test2"}, 1, walimplstest.NewTestMessageID(1)))
assert.Nil(t, rs.checkpoint.ReplicateCheckpoint)
assert.Equal(t, rs.checkpoint.MessageID, walimplstest.NewTestMessageID(1))
assert.Equal(t, rs.checkpoint.TimeTick, uint64(1))
rs.updateCheckpoint(newAlterReplicateConfigMessage("test2", []string{"test1"}, 1, walimplstest.NewTestMessageID(1)))
assert.NotNil(t, rs.checkpoint.ReplicateCheckpoint)
assert.Equal(t, rs.checkpoint.ReplicateCheckpoint.ClusterID, "test2")
assert.Equal(t, rs.checkpoint.ReplicateCheckpoint.PChannel, "test2-rootcoord-dml_0")
assert.Nil(t, rs.checkpoint.ReplicateCheckpoint.MessageID)
assert.Zero(t, rs.checkpoint.ReplicateCheckpoint.TimeTick)
replicateMsg := message.NewReplicateMessage("test3", message.NewCreateDatabaseMessageBuilderV2().
WithHeader(&message.CreateDatabaseMessageHeader{}).
WithBody(&message.CreateDatabaseMessageBody{}).
WithVChannel("test3-rootcoord-dml_0").
MustBuildMutable().
WithTimeTick(3).
WithLastConfirmed(walimplstest.NewTestMessageID(10)).
IntoImmutableMessage(walimplstest.NewTestMessageID(20)).IntoImmutableMessageProto())
replicateMsg.OverwriteReplicateVChannel("test1-rootcoord-dml_0")
immutableReplicateMsg := replicateMsg.WithTimeTick(4).
WithLastConfirmed(walimplstest.NewTestMessageID(11)).
IntoImmutableMessage(walimplstest.NewTestMessageID(22))
rs.updateCheckpoint(immutableReplicateMsg)
// update with wrong clusterID.
rs.updateCheckpoint(immutableReplicateMsg)
assert.NotNil(t, rs.checkpoint.ReplicateCheckpoint)
assert.Equal(t, rs.checkpoint.ReplicateCheckpoint.ClusterID, "test2")
assert.Equal(t, rs.checkpoint.ReplicateCheckpoint.PChannel, "test2-rootcoord-dml_0")
assert.Nil(t, rs.checkpoint.ReplicateCheckpoint.MessageID)
assert.Zero(t, rs.checkpoint.ReplicateCheckpoint.TimeTick)
rs.updateCheckpoint(newAlterReplicateConfigMessage("test3", []string{"test2", "test1"}, 1, walimplstest.NewTestMessageID(1)))
assert.NotNil(t, rs.checkpoint.ReplicateCheckpoint)
assert.Equal(t, rs.checkpoint.ReplicateCheckpoint.ClusterID, "test3")
assert.Equal(t, rs.checkpoint.ReplicateCheckpoint.PChannel, "test3-rootcoord-dml_0")
assert.Nil(t, rs.checkpoint.ReplicateCheckpoint.MessageID)
assert.Zero(t, rs.checkpoint.ReplicateCheckpoint.TimeTick)
// update with right clusterID.
rs.updateCheckpoint(immutableReplicateMsg)
assert.NotNil(t, rs.checkpoint.ReplicateCheckpoint)
assert.Equal(t, rs.checkpoint.MessageID, walimplstest.NewTestMessageID(11))
assert.Equal(t, rs.checkpoint.TimeTick, uint64(4))
assert.Equal(t, rs.checkpoint.ReplicateCheckpoint.ClusterID, "test3")
assert.Equal(t, rs.checkpoint.ReplicateCheckpoint.PChannel, "test3-rootcoord-dml_0")
assert.True(t, rs.checkpoint.ReplicateCheckpoint.MessageID.EQ(walimplstest.NewTestMessageID(10)))
assert.Equal(t, rs.checkpoint.ReplicateCheckpoint.TimeTick, uint64(3))
rs.updateCheckpoint(newAlterReplicateConfigMessage("test1", []string{"test2"}, 1, walimplstest.NewTestMessageID(1)))
assert.Nil(t, rs.checkpoint.ReplicateCheckpoint)
rs.updateCheckpoint(immutableReplicateMsg)
}
// newAlterReplicateConfigMessage creates a new alter replicate config message.
func newAlterReplicateConfigMessage(primaryClusterID string, secondaryClusterID []string, timetick uint64, messageID message.MessageID) message.ImmutableMessage {
return message.NewAlterReplicateConfigMessageBuilderV2().
WithHeader(&message.AlterReplicateConfigMessageHeader{
ReplicateConfiguration: newReplicateConfiguration(primaryClusterID, secondaryClusterID...),
}).
WithBody(&message.AlterReplicateConfigMessageBody{}).
WithVChannel("test1-rootcoord-dml_0").
MustBuildMutable().
WithTimeTick(timetick).
WithLastConfirmed(messageID).
IntoImmutableMessage(walimplstest.NewTestMessageID(10086))
}
// newReplicateConfiguration creates a valid replicate configuration for testing
func newReplicateConfiguration(primaryClusterID string, secondaryClusterID ...string) *commonpb.ReplicateConfiguration {
clusters := []*commonpb.MilvusCluster{
{ClusterId: primaryClusterID, Pchannels: []string{primaryClusterID + "-rootcoord-dml_0", primaryClusterID + "-rootcoord-dml_1"}},
}
crossClusterTopology := []*commonpb.CrossClusterTopology{}
for _, secondaryClusterID := range secondaryClusterID {
clusters = append(clusters, &commonpb.MilvusCluster{ClusterId: secondaryClusterID, Pchannels: []string{secondaryClusterID + "-rootcoord-dml_0", secondaryClusterID + "-rootcoord-dml_1"}})
crossClusterTopology = append(crossClusterTopology, &commonpb.CrossClusterTopology{SourceClusterId: primaryClusterID, TargetClusterId: secondaryClusterID})
}
return &commonpb.ReplicateConfiguration{
Clusters: clusters,
CrossClusterTopology: crossClusterTopology,
}
}

View File

@ -7,6 +7,7 @@ import (
"github.com/stretchr/testify/mock" "github.com/stretchr/testify/mock"
"github.com/milvus-io/milvus/internal/streamingnode/server/wal/utility"
"github.com/milvus-io/milvus/pkg/v2/mocks/streaming/mock_walimpls" "github.com/milvus-io/milvus/pkg/v2/mocks/streaming/mock_walimpls"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types" "github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
"github.com/milvus-io/milvus/pkg/v2/streaming/walimpls/impls/rmq" "github.com/milvus-io/milvus/pkg/v2/streaming/walimpls/impls/rmq"
@ -23,7 +24,7 @@ func TestTruncator(t *testing.T) {
truncator := newSamplingTruncator(&WALCheckpoint{
MessageID: rmq.NewRmqID(1),
TimeTick: 1,
-Magic: RecoveryMagicStreamingInitialized,
Magic: utility.RecoveryMagicStreamingInitialized,
}, w, newRecoveryStorageMetrics(types.PChannelInfo{Name: "test", Term: 1}))
for i := 0; i < 20; i++ {
@ -32,7 +33,7 @@ func TestTruncator(t *testing.T) {
truncator.SampleCheckpoint(&WALCheckpoint{
MessageID: rmq.NewRmqID(int64(i)),
TimeTick: tsoutil.ComposeTSByTime(time.Now(), 0),
-Magic: RecoveryMagicStreamingInitialized,
Magic: utility.RecoveryMagicStreamingInitialized,
})
}
}

View File

@ -13,30 +13,21 @@ const (
// NewWALCheckpointFromProto creates a new WALCheckpoint from a protobuf message.
func NewWALCheckpointFromProto(cp *streamingpb.WALCheckpoint) *WALCheckpoint {
-wcp := &WALCheckpoint{
if cp == nil {
return nil
}
return &WALCheckpoint{
MessageID: message.MustUnmarshalMessageID(cp.MessageId),
TimeTick: cp.TimeTick,
Magic: cp.RecoveryMagic,
ReplicateConfig: cp.ReplicateConfig,
ReplicateCheckpoint: NewReplicateCheckpointFromProto(cp.ReplicateCheckpoint),
}
-if cp.ReplicateCheckpoint != nil {
-var messageID message.MessageID
-if cp.ReplicateCheckpoint.MessageId != nil {
-messageID = message.MustUnmarshalMessageID(cp.ReplicateCheckpoint.MessageId)
-}
-wcp.ReplicateCheckpoint = &ReplicateCheckpoint{
-ClusterID: cp.ReplicateCheckpoint.ClusterId,
-PChannel: cp.ReplicateCheckpoint.Pchannel,
-MessageID: messageID,
-TimeTick: cp.ReplicateCheckpoint.TimeTick,
-}
-}
-return wcp
}
// WALCheckpoint represents a consume checkpoint in the Write-Ahead Log (WAL).
type WALCheckpoint struct {
MessageID message.MessageID // should always be non-nil.
TimeTick uint64
Magic int64
ReplicateCheckpoint *ReplicateCheckpoint
@ -45,15 +36,16 @@ type WALCheckpoint struct {
// IntoProto converts the WALCheckpoint to a protobuf message.
func (c *WALCheckpoint) IntoProto() *streamingpb.WALCheckpoint {
-cp := &streamingpb.WALCheckpoint{
-MessageId: c.MessageID.IntoProto(),
if c == nil {
return nil
}
return &streamingpb.WALCheckpoint{
MessageId: message.MustMarshalMessageID(c.MessageID),
TimeTick: c.TimeTick,
RecoveryMagic: c.Magic,
ReplicateConfig: c.ReplicateConfig,
ReplicateCheckpoint: c.ReplicateCheckpoint.IntoProto(),
}
-if c.ReplicateCheckpoint != nil {
-cp.ReplicateCheckpoint = c.ReplicateCheckpoint.IntoProto()
-}
-return cp
}
// Clone creates a new WALCheckpoint with the same values as the original.
@ -62,16 +54,20 @@ func (c *WALCheckpoint) Clone() *WALCheckpoint {
MessageID: c.MessageID,
TimeTick: c.TimeTick,
Magic: c.Magic,
ReplicateConfig: c.ReplicateConfig,
ReplicateCheckpoint: c.ReplicateCheckpoint.Clone(),
}
}
// NewReplicateCheckpointFromProto creates a new ReplicateCheckpoint from a protobuf message.
func NewReplicateCheckpointFromProto(cp *commonpb.ReplicateCheckpoint) *ReplicateCheckpoint {
if cp == nil {
return nil
}
return &ReplicateCheckpoint{
-MessageID: message.MustUnmarshalMessageID(cp.MessageId),
ClusterID: cp.ClusterId,
PChannel: cp.Pchannel,
MessageID: message.MustUnmarshalMessageID(cp.MessageId),
TimeTick: cp.TimeTick,
}
}
@ -81,7 +77,7 @@ func NewReplicateCheckpointFromProto(cp *commonpb.ReplicateCheckpoint) *Replicat
type ReplicateCheckpoint struct {
ClusterID string // the cluster id of the source cluster.
PChannel string // the pchannel of the source cluster.
-MessageID message.MessageID // the last confirmed message id of the last replicated message.
MessageID message.MessageID // the last confirmed message id of the last replicated message, may be nil when initializing.
TimeTick uint64 // the time tick of the last replicated message.
}
@ -90,14 +86,10 @@ func (c *ReplicateCheckpoint) IntoProto() *commonpb.ReplicateCheckpoint {
if c == nil {
return nil
}
-var messageID *commonpb.MessageID
-if c.MessageID != nil {
-messageID = c.MessageID.IntoProto()
-}
return &commonpb.ReplicateCheckpoint{
ClusterId: c.ClusterID,
Pchannel: c.PChannel,
-MessageId: messageID,
MessageId: message.MustMarshalMessageID(c.MessageID),
TimeTick: c.TimeTick,
}
}

View File

@ -11,6 +11,9 @@ import (
)
func TestNewWALCheckpointFromProto(t *testing.T) {
assert.Nil(t, NewWALCheckpointFromProto(nil))
assert.Nil(t, NewWALCheckpointFromProto(nil).IntoProto())
messageID := rmq.NewRmqID(1)
timeTick := uint64(12345)
recoveryMagic := int64(1)
@ -59,4 +62,25 @@ func TestNewWALCheckpointFromProto(t *testing.T) {
assert.Equal(t, uint64(123456), newCheckpoint.ReplicateCheckpoint.TimeTick)
assert.True(t, rmq.NewRmqID(2).EQ(newCheckpoint.ReplicateCheckpoint.MessageID))
assert.NotNil(t, newCheckpoint.ReplicateConfig)
proto = newCheckpoint.IntoProto()
checkpoint2 = NewWALCheckpointFromProto(proto)
assert.True(t, messageID.EQ(checkpoint2.MessageID))
assert.Equal(t, timeTick, checkpoint2.TimeTick)
assert.Equal(t, recoveryMagic, checkpoint2.Magic)
assert.Equal(t, "by-dev", checkpoint2.ReplicateCheckpoint.ClusterID)
assert.Equal(t, "p1", checkpoint2.ReplicateCheckpoint.PChannel)
assert.Equal(t, uint64(123456), checkpoint2.ReplicateCheckpoint.TimeTick)
assert.True(t, rmq.NewRmqID(2).EQ(checkpoint2.ReplicateCheckpoint.MessageID))
assert.NotNil(t, checkpoint2.ReplicateConfig)
checkpoint2 = newCheckpoint.Clone()
assert.True(t, messageID.EQ(checkpoint2.MessageID))
assert.Equal(t, timeTick, checkpoint2.TimeTick)
assert.Equal(t, recoveryMagic, checkpoint2.Magic)
assert.Equal(t, "by-dev", checkpoint2.ReplicateCheckpoint.ClusterID)
assert.Equal(t, "p1", checkpoint2.ReplicateCheckpoint.PChannel)
assert.Equal(t, uint64(123456), checkpoint2.ReplicateCheckpoint.TimeTick)
assert.True(t, rmq.NewRmqID(2).EQ(checkpoint2.ReplicateCheckpoint.MessageID))
assert.NotNil(t, checkpoint2.ReplicateConfig)
}

View File

@ -21,6 +21,13 @@ type WAL interface {
// GetLatestMVCCTimestamp get the latest mvcc timestamp of the wal at vchannel.
GetLatestMVCCTimestamp(ctx context.Context, vchannel string) (uint64, error)
// GetReplicateCheckpoint returns the replicate checkpoint of the wal.
// If the wal is not on replicating mode, it will return ReplicateViolationError.
// If the wal is on replicating mode, it will return the replicate checkpoint of the wal.
// If the wal has just been initialized into replica mode and has not replicated
// any message yet, the message id of the replicate checkpoint will be nil.
GetReplicateCheckpoint() (*ReplicateCheckpoint, error)
// Append writes a record to the log.
Append(ctx context.Context, msg message.MutableMessage) (*AppendResult, error)
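Not part of the diff: a hypothetical caller sketch of the documented contract, written as if in the same package as the WAL interface; resumeFromReplicateCheckpoint is an illustrative name and the status package is assumed to be the streamingutil status package used elsewhere in this change.
// resumeFromReplicateCheckpoint distinguishes "wal is primary" (a replicate
// violation error, nothing to resume) from a genuine failure.
func resumeFromReplicateCheckpoint(w WAL) (*ReplicateCheckpoint, error) {
cp, err := w.GetReplicateCheckpoint()
if err != nil {
if status.AsStreamingError(err).IsReplicateViolation() {
// The wal is running as primary; there is no replicate checkpoint.
return nil, nil
}
return nil, err
}
// cp.MessageID may be nil when the wal was just switched into replica mode
// and has not replicated any message yet.
return cp, nil
}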

View File

@ -9,6 +9,7 @@ import (
"github.com/milvus-io/milvus/internal/streamingnode/server/wal" "github.com/milvus-io/milvus/internal/streamingnode/server/wal"
"github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/lock" "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/lock"
"github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/redo" "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/redo"
"github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/replicate"
"github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/shard" "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/shard"
"github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/timetick" "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/timetick"
"github.com/milvus-io/milvus/internal/streamingnode/server/wal/registry" "github.com/milvus-io/milvus/internal/streamingnode/server/wal/registry"
@ -28,6 +29,7 @@ func OpenManager() (Manager, error) {
opener, err := registry.MustGetBuilder(walName,
redo.NewInterceptorBuilder(),
lock.NewInterceptorBuilder(),
replicate.NewInterceptorBuilder(),
timetick.NewInterceptorBuilder(),
shard.NewInterceptorBuilder(),
).Build()

View File

@ -56,10 +56,15 @@ func (e *StreamingError) IsSkippedOperation() bool {
// Stop resuming retry and report to user.
func (e *StreamingError) IsUnrecoverable() bool {
return e.Code == streamingpb.StreamingCode_STREAMING_CODE_UNRECOVERABLE ||
-e.Code == streamingpb.StreamingCode_STREAMING_CODE_REPLICATE_VIOLATION ||
e.IsReplicateViolation() ||
e.IsTxnUnavilable()
}
// IsReplicateViolation returns true if the error is caused by replicate violation.
func (e *StreamingError) IsReplicateViolation() bool {
return e.Code == streamingpb.StreamingCode_STREAMING_CODE_REPLICATE_VIOLATION
}
// IsTxnUnavilable returns true if the transaction is unavailable.
func (e *StreamingError) IsTxnUnavilable() bool {
return e.Code == streamingpb.StreamingCode_STREAMING_CODE_TRANSACTION_EXPIRED ||

View File

@ -74,7 +74,7 @@ var CDCReplicateEndToEndLatency = prometheus.NewHistogramVec(
Namespace: milvusNamespace,
Subsystem: typeutil.CDCRole,
Name: CDCMetricReplicateEndToEndLatency,
-Help: "End-to-end latency from a single message being read from Source WAL to being written to Target WAL and receiving an ack",
Help: "End-to-end latency in milliseconds from a single message being read from Source WAL to being written to Target WAL and receiving an ack",
Buckets: buckets,
}, []string{
CDCLabelSourceChannelName,
@ -82,13 +82,12 @@ var CDCReplicateEndToEndLatency = prometheus.NewHistogramVec(
},
)
-// TODO: sheep
var CDCReplicateLag = prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Namespace: milvusNamespace,
Subsystem: typeutil.CDCRole,
Name: CDCMetricReplicateLag,
-Help: "Lag between the latest message in Source and the latest message in Target",
Help: "Lag in milliseconds between the latest synced Source message and the current time",
}, []string{
CDCLabelSourceChannelName,
CDCLabelTargetChannelName,

View File

@ -78,7 +78,7 @@ type MutableMessage interface {
// VChannel returns the virtual channel of current message.
// Available only when the message's version greater than 0.
-// Return "" if message is can be seen by all vchannels on the pchannel.
// Return "" or the pchannel name if the message can be seen by all vchannels on the pchannel.
VChannel() string
// WithBarrierTimeTick sets the barrier time tick of current message.

View File

@ -27,8 +27,19 @@ func RegisterMessageIDUnmsarshaler(walName WALName, unmarshaler MessageIDUnmarsh
// MessageIDUnmarshaler is the unmarshaler for message id.
type MessageIDUnmarshaler = func(b string) (MessageID, error)
// MustMarshalMessageID marshal the message id, panic if failed.
func MustMarshalMessageID(msgID MessageID) *commonpb.MessageID {
if msgID == nil {
return nil
}
return msgID.IntoProto()
}
// MustUnmarshalMessageID unmarshal the message id, panic if failed.
func MustUnmarshalMessageID(msgID *commonpb.MessageID) MessageID {
if msgID == nil {
return nil
}
id, err := UnmarshalMessageID(msgID)
if err != nil {
panic(fmt.Sprintf("unmarshal message id failed: %s, wal: %s, bytes: %s", err.Error(), msgID.WALName.String(), msgID.Id))
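Not part of the diff: a small sketch showing why both helpers are nil-safe; the checkpoint conversions above rely on a nil message id round-tripping cleanly. exampleNilMessageIDRoundTrip is an illustrative name, written as if in the same message package.
func exampleNilMessageIDRoundTrip() MessageID {
var id MessageID // nil: replica mode initialized, nothing replicated yet.
pb := MustMarshalMessageID(id) // returns a nil proto instead of panicking.
return MustUnmarshalMessageID(pb) // returns a nil MessageID for a nil proto.
}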

View File

@ -177,6 +177,13 @@ func (m *messageImpl) OverwriteReplicateVChannel(vchannel string, broadcastVChan
panic("should not happen on broadcast header proto") panic("should not happen on broadcast header proto")
} }
m.properties.Set(messageBroadcastHeader, bhVal) m.properties.Set(messageBroadcastHeader, bhVal)
// overwrite the txn keepalive to infinite if it's a replicated message,
// because replicated message is already committed, so it should never be expired.
if txnCtx := m.TxnContext(); txnCtx != nil {
txnCtx.Keepalive = TxnKeepaliveInfinite
m.WithTxnContext(*txnCtx)
}
}
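Not part of the diff: a hedged sketch, mirroring the test helpers earlier in this change, of the keepalive overwrite from a caller's point of view; the cluster and channel names are illustrative.
immutable := message.NewBeginTxnMessageBuilderV2().
WithHeader(&message.BeginTxnMessageHeader{}).
WithBody(&message.BeginTxnMessageBody{}).
WithVChannel("source-rootcoord-dml_0").
MustBuildMutable().
WithTxnContext(message.TxnContext{TxnID: message.TxnID(1), Keepalive: 10 * time.Millisecond}).
WithTimeTick(1).
WithLastConfirmed(walimplstest.NewTestMessageID(1)).
IntoImmutableMessage(walimplstest.NewTestMessageID(1))
replicated := message.NewReplicateMessage("source", immutable.IntoImmutableMessageProto())
replicated.OverwriteReplicateVChannel("target-rootcoord-dml_0")
// The txn keepalive is now TxnKeepaliveInfinite: the target cluster will never
// expire a transaction that the source cluster already committed.
txnCtx := replicated.TxnContext() // txnCtx.Keepalive == message.TxnKeepaliveInfinite
_ = txnCtx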
// OverwriteBroadcastHeader overwrites the broadcast header of the message.