From cbb79dae8c81fcd8353af3c201ee502eabd9c428 Mon Sep 17 00:00:00 2001 From: Xiaofan <83447078+xiaofan-luan@users.noreply.github.com> Date: Thu, 27 Nov 2025 07:45:08 +0800 Subject: [PATCH] enhance: [2.6] add robust handle etcd servercrash (#45633) issue: #45303 pr: #45304 fix milvus pod may restart when etcd pod start --------- Signed-off-by: xiaofanluan --- .gitignore | 3 + Makefile | 1 + cmd/milvus/mck.go | 5 +- cmd/tools/migration/backend/etcd.go | 3 +- cmd/tools/migration/migration/runner.go | 7 +- cmd/tools/migration/mmap/tool/main.go | 6 +- configs/milvus.yaml | 4 +- .../allocator/mock_global_id_allocator.go | 18 +- internal/datacoord/server.go | 25 +- internal/datacoord/server_test.go | 11 +- internal/distributed/connection_manager.go | 4 +- internal/distributed/datanode/service.go | 3 +- internal/distributed/mixcoord/service.go | 3 +- internal/distributed/proxy/service.go | 3 +- internal/distributed/querynode/service.go | 3 +- internal/kv/etcd/metakv_factory.go | 3 +- internal/mocks/mock_grpc_client.go | 46 ++- internal/querycoordv2/server.go | 19 +- internal/tso/mocks/allocator.go | 32 +- .../util/dependency/kv/kv_client_handler.go | 4 +- internal/util/dependency/mock_factory.go | 20 +- .../proxyutil/mock_proxy_client_manager.go | 52 +++- internal/util/proxyutil/mock_proxy_watcher.go | 14 +- internal/util/sessionutil/mock_session.go | 126 ++++++-- .../util/sessionutil/mock_session_watcher.go | 111 +++++++ internal/util/sessionutil/session.go | 9 +- internal/util/sessionutil/session_util.go | 287 ++++++++++-------- .../util/sessionutil/session_util_test.go | 55 +++- pkg/streaming/walimpls/impls/wp/builder.go | 3 +- pkg/util/etcd/etcd_util.go | 57 +++- pkg/util/paramtable/component_param.go | 2 +- pkg/util/paramtable/service_param.go | 86 ++++-- 32 files changed, 744 insertions(+), 281 deletions(-) create mode 100644 internal/util/sessionutil/mock_session_watcher.go diff --git a/.gitignore b/.gitignore index 7838d8d279..7a59f15b32 100644 --- a/.gitignore +++ 
b/.gitignore @@ -117,3 +117,6 @@ WARP.md # Antlr .antlr + +# Gocache +**/.gocache/ diff --git a/Makefile b/Makefile index 4d144696aa..ae13823f1c 100644 --- a/Makefile +++ b/Makefile @@ -518,6 +518,7 @@ generate-mockery-utils: getdeps # tso.Allocator $(INSTALL_PATH)/mockery --name=Allocator --dir=internal/tso --output=internal/tso/mocks --filename=allocator.go --with-expecter --structname=Allocator --outpkg=mocktso $(INSTALL_PATH)/mockery --name=SessionInterface --dir=$(PWD)/internal/util/sessionutil --output=$(PWD)/internal/util/sessionutil --filename=mock_session.go --with-expecter --structname=MockSession --inpackage + $(INSTALL_PATH)/mockery --name=SessionWatcher --dir=$(PWD)/internal/util/sessionutil --output=$(PWD)/internal/util/sessionutil --filename=mock_session_watcher.go --with-expecter --structname=MockSessionWatcher --inpackage $(INSTALL_PATH)/mockery --name=GrpcClient --dir=$(PWD)/internal/util/grpcclient --output=$(PWD)/internal/mocks --filename=mock_grpc_client.go --with-expecter --structname=MockGrpcClient # proxy_client_manager.go $(INSTALL_PATH)/mockery --name=ProxyClientManagerInterface --dir=$(PWD)/internal/util/proxyutil --output=$(PWD)/internal/util/proxyutil --filename=mock_proxy_client_manager.go --with-expecter --structname=MockProxyClientManager --inpackage diff --git a/cmd/milvus/mck.go b/cmd/milvus/mck.go index 87416f5169..a8af1e079d 100644 --- a/cmd/milvus/mck.go +++ b/cmd/milvus/mck.go @@ -216,7 +216,7 @@ func (c *mck) connectEctd() { var err error log := log.Ctx(context.TODO()) if c.etcdIP != "" { - etcdCli, err = etcd.GetRemoteEtcdClient([]string{c.etcdIP}) + etcdCli, err = etcd.GetRemoteEtcdClient([]string{c.etcdIP}, c.params.EtcdCfg.ClientOptions()...) 
} else { etcdCli, err = etcd.CreateEtcdClient( c.params.EtcdCfg.UseEmbedEtcd.GetAsBool(), @@ -228,7 +228,8 @@ func (c *mck) connectEctd() { c.params.EtcdCfg.EtcdTLSCert.GetValue(), c.params.EtcdCfg.EtcdTLSKey.GetValue(), c.params.EtcdCfg.EtcdTLSCACert.GetValue(), - c.params.EtcdCfg.EtcdTLSMinVersion.GetValue()) + c.params.EtcdCfg.EtcdTLSMinVersion.GetValue(), + c.params.EtcdCfg.ClientOptions()...) } if err != nil { log.Fatal("failed to connect to etcd", zap.Error(err)) diff --git a/cmd/tools/migration/backend/etcd.go b/cmd/tools/migration/backend/etcd.go index a370a02c31..3764978e7c 100644 --- a/cmd/tools/migration/backend/etcd.go +++ b/cmd/tools/migration/backend/etcd.go @@ -32,7 +32,8 @@ func newEtcdBasedBackend(cfg *configs.MilvusConfig) (*etcdBasedBackend, error) { cfg.EtcdCfg.EtcdTLSCert.GetValue(), cfg.EtcdCfg.EtcdTLSKey.GetValue(), cfg.EtcdCfg.EtcdTLSCACert.GetValue(), - cfg.EtcdCfg.EtcdTLSMinVersion.GetValue()) + cfg.EtcdCfg.EtcdTLSMinVersion.GetValue(), + cfg.EtcdCfg.ClientOptions()...) 
if err != nil { return nil, err } diff --git a/cmd/tools/migration/migration/runner.go b/cmd/tools/migration/migration/runner.go index 555799f17a..89defd88de 100644 --- a/cmd/tools/migration/migration/runner.go +++ b/cmd/tools/migration/migration/runner.go @@ -49,12 +49,12 @@ func (r *Runner) watchByPrefix(prefix string) { _, revision, err := r.session.GetSessions(prefix) fn := func() { r.Stop() } console.AbnormalExitIf(err, r.backupFinished.Load(), console.AddCallbacks(fn)) - eventCh := r.session.WatchServices(prefix, revision, nil) + watcher := r.session.WatchServices(prefix, revision, nil) for { select { case <-r.ctx.Done(): return - case event := <-eventCh: + case event := <-watcher.EventChannel(): msg := fmt.Sprintf("session up/down, exit migration, event type: %s, session: %s", event.EventType.String(), event.Session.String()) console.AbnormalExit(r.backupFinished.Load(), msg, console.AddCallbacks(fn)) } @@ -79,7 +79,8 @@ func (r *Runner) initEtcdCli() { r.cfg.EtcdCfg.EtcdTLSCert.GetValue(), r.cfg.EtcdCfg.EtcdTLSKey.GetValue(), r.cfg.EtcdCfg.EtcdTLSCACert.GetValue(), - r.cfg.EtcdCfg.EtcdTLSMinVersion.GetValue()) + r.cfg.EtcdCfg.EtcdTLSMinVersion.GetValue(), + r.cfg.EtcdCfg.ClientOptions()...) console.AbnormalExitIf(err, r.backupFinished.Load()) r.etcdCli = cli } diff --git a/cmd/tools/migration/mmap/tool/main.go b/cmd/tools/migration/mmap/tool/main.go index edf04df0df..180bc5fa69 100644 --- a/cmd/tools/migration/mmap/tool/main.go +++ b/cmd/tools/migration/mmap/tool/main.go @@ -75,7 +75,8 @@ func prepareTsoAllocator() tso.Allocator { etcdConfig.EtcdTLSCert.GetValue(), etcdConfig.EtcdTLSKey.GetValue(), etcdConfig.EtcdTLSCACert.GetValue(), - etcdConfig.EtcdTLSMinVersion.GetValue()) + etcdConfig.EtcdTLSMinVersion.GetValue(), + etcdConfig.ClientOptions()...) 
if err != nil { panic(err) } @@ -109,7 +110,8 @@ func metaKVCreator() (kv.MetaKv, error) { etcdConfig.EtcdTLSCert.GetValue(), etcdConfig.EtcdTLSKey.GetValue(), etcdConfig.EtcdTLSCACert.GetValue(), - etcdConfig.EtcdTLSMinVersion.GetValue()) + etcdConfig.EtcdTLSMinVersion.GetValue(), + etcdConfig.ClientOptions()...) if err != nil { panic(err) } diff --git a/configs/milvus.yaml b/configs/milvus.yaml index 004515ebfb..512dca90ed 100644 --- a/configs/milvus.yaml +++ b/configs/milvus.yaml @@ -53,6 +53,8 @@ etcd: # We recommend using version 1.2 and above. tlsMinVersion: 1.3 requestTimeout: 10000 # Etcd operation timeout in milliseconds + dialKeepAliveTime: 3000 # Interval in milliseconds for gRPC dial keepalive pings sent to etcd endpoints. + dialKeepAliveTimeout: 2000 # Timeout in milliseconds waiting for keepalive responses before marking the connection as unhealthy. use: embed: false # Whether to enable embedded Etcd (an in-process EtcdServer). data: @@ -988,7 +990,7 @@ common: internaltlsEnabled: false tlsMode: 0 session: - ttl: 30 # ttl value when session granting a lease to register service + ttl: 15 # ttl value when session granting a lease to register service retryTimes: 30 # retry times when session sending etcd requests locks: metrics: diff --git a/internal/allocator/mock_global_id_allocator.go b/internal/allocator/mock_global_id_allocator.go index b1f0b4d246..aedaee9710 100644 --- a/internal/allocator/mock_global_id_allocator.go +++ b/internal/allocator/mock_global_id_allocator.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.32.4. DO NOT EDIT. +// Code generated by mockery v2.53.3. DO NOT EDIT. 
package allocator @@ -21,6 +21,10 @@ func (_m *MockGlobalIDAllocator) EXPECT() *MockGlobalIDAllocator_Expecter { func (_m *MockGlobalIDAllocator) Alloc(count uint32) (int64, int64, error) { ret := _m.Called(count) + if len(ret) == 0 { + panic("no return value specified for Alloc") + } + var r0 int64 var r1 int64 var r2 error @@ -76,10 +80,14 @@ func (_c *MockGlobalIDAllocator_Alloc_Call) RunAndReturn(run func(uint32) (int64 return _c } -// AllocOne provides a mock function with given fields: +// AllocOne provides a mock function with no fields func (_m *MockGlobalIDAllocator) AllocOne() (int64, error) { ret := _m.Called() + if len(ret) == 0 { + panic("no return value specified for AllocOne") + } + var r0 int64 var r1 error if rf, ok := ret.Get(0).(func() (int64, error)); ok { @@ -127,10 +135,14 @@ func (_c *MockGlobalIDAllocator_AllocOne_Call) RunAndReturn(run func() (int64, e return _c } -// Initialize provides a mock function with given fields: +// Initialize provides a mock function with no fields func (_m *MockGlobalIDAllocator) Initialize() error { ret := _m.Called() + if len(ret) == 0 { + panic("no return value specified for Initialize") + } + var r0 error if rf, ok := ret.Get(0).(func() error); ok { r0 = rf() diff --git a/internal/datacoord/server.go b/internal/datacoord/server.go index 04753bab67..e52c1f5673 100644 --- a/internal/datacoord/server.go +++ b/internal/datacoord/server.go @@ -141,11 +141,10 @@ type Server struct { notifyIndexChan chan UniqueID factory dependency.Factory - session sessionutil.SessionInterface - icSession sessionutil.SessionInterface - dnEventCh <-chan *sessionutil.SessionEvent - // qcEventCh <-chan *sessionutil.SessionEvent - qnEventCh <-chan *sessionutil.SessionEvent + session sessionutil.SessionInterface + icSession sessionutil.SessionInterface + dnSessionWatcher sessionutil.SessionWatcher + qnSessionWatcher sessionutil.SessionWatcher enableActiveStandBy bool activateFunc func() error @@ -529,7 +528,7 @@ func (s *Server) 
initServiceDiscovery() error { } log.Info("DataCoord Cluster Manager start up successfully") - s.dnEventCh = s.session.WatchServicesWithVersionRange(typeutil.DataNodeRole, r, rev+1, s.rewatchDataNodes) + s.dnSessionWatcher = s.session.WatchServicesWithVersionRange(typeutil.DataNodeRole, r, rev+1, s.rewatchDataNodes) } s.indexEngineVersionManager = newIndexEngineVersionManager() @@ -539,7 +538,7 @@ func (s *Server) initServiceDiscovery() error { return err } s.rewatchQueryNodes(qnSessions) - s.qnEventCh = s.session.WatchServicesWithVersionRange(typeutil.QueryNodeRole, r, qnRevision+1, s.rewatchQueryNodes) + s.qnSessionWatcher = s.session.WatchServicesWithVersionRange(typeutil.QueryNodeRole, r, qnRevision+1, s.rewatchQueryNodes) return nil } @@ -796,7 +795,7 @@ func (s *Server) watchService(ctx context.Context) { case <-ctx.Done(): log.Info("watch service shutdown") return - case event, ok := <-s.dnEventCh: + case event, ok := <-s.dnSessionWatcher.EventChannel(): if !ok { s.stopServiceWatch() return @@ -809,7 +808,7 @@ func (s *Server) watchService(ctx context.Context) { }() return } - case event, ok := <-s.qnEventCh: + case event, ok := <-s.qnSessionWatcher.EventChannel(): if !ok { s.stopServiceWatch() return @@ -1051,6 +1050,14 @@ func (s *Server) Stop() error { s.analyzeInspector.Stop() log.Info("datacoord analyze inspector stopped") + if s.dnSessionWatcher != nil { + s.dnSessionWatcher.Stop() + } + + if s.qnSessionWatcher != nil { + s.qnSessionWatcher.Stop() + } + if s.session != nil { s.session.Stop() } diff --git a/internal/datacoord/server_test.go b/internal/datacoord/server_test.go index aa448b33dd..ed75c66d41 100644 --- a/internal/datacoord/server_test.go +++ b/internal/datacoord/server_test.go @@ -735,7 +735,12 @@ func TestService_WatchServices(t *testing.T) { svr.serverLoopWg.Add(1) ech := make(chan *sessionutil.SessionEvent) - svr.dnEventCh = ech + mockDnWatcher := sessionutil.NewMockSessionWatcher(t) + mockDnWatcher.EXPECT().EventChannel().Return(ech) + 
svr.dnSessionWatcher = mockDnWatcher + mockQnWatcher := sessionutil.NewMockSessionWatcher(t) + mockQnWatcher.EXPECT().EventChannel().Return(nil) + svr.qnSessionWatcher = mockQnWatcher flag := false closed := false @@ -762,7 +767,9 @@ func TestService_WatchServices(t *testing.T) { ech = make(chan *sessionutil.SessionEvent) flag = false - svr.dnEventCh = ech + mockDnWatcher = sessionutil.NewMockSessionWatcher(t) + mockDnWatcher.EXPECT().EventChannel().Return(ech) + svr.dnSessionWatcher = mockDnWatcher ctx, cancel := context.WithCancel(context.Background()) svr.serverLoopWg.Add(1) diff --git a/internal/distributed/connection_manager.go b/internal/distributed/connection_manager.go index 18c568c79a..0730b498e7 100644 --- a/internal/distributed/connection_manager.go +++ b/internal/distributed/connection_manager.go @@ -119,8 +119,8 @@ func (cm *ConnectionManager) AddDependency(roleName string) error { } } - eventChannel := cm.session.WatchServices(roleName, rev, nil) - go cm.processEvent(eventChannel) + watcher := cm.session.WatchServices(roleName, rev, nil) + go cm.processEvent(watcher.EventChannel()) return nil } diff --git a/internal/distributed/datanode/service.go b/internal/distributed/datanode/service.go index d7f236861b..f82f83df30 100644 --- a/internal/distributed/datanode/service.go +++ b/internal/distributed/datanode/service.go @@ -228,7 +228,8 @@ func (s *Server) init() error { etcdConfig.EtcdTLSCert.GetValue(), etcdConfig.EtcdTLSKey.GetValue(), etcdConfig.EtcdTLSCACert.GetValue(), - etcdConfig.EtcdTLSMinVersion.GetValue()) + etcdConfig.EtcdTLSMinVersion.GetValue(), + etcdConfig.ClientOptions()...) 
if err != nil { log.Error("failed to connect to etcd", zap.Error(err)) return err diff --git a/internal/distributed/mixcoord/service.go b/internal/distributed/mixcoord/service.go index 7495c40c0f..eda85490ee 100644 --- a/internal/distributed/mixcoord/service.go +++ b/internal/distributed/mixcoord/service.go @@ -141,7 +141,8 @@ func (s *Server) init() error { etcdConfig.EtcdTLSCert.GetValue(), etcdConfig.EtcdTLSKey.GetValue(), etcdConfig.EtcdTLSCACert.GetValue(), - etcdConfig.EtcdTLSMinVersion.GetValue()) + etcdConfig.EtcdTLSMinVersion.GetValue(), + etcdConfig.ClientOptions()...) if err != nil { log.Warn("MixCoord connect to etcd failed", zap.Error(err)) return err diff --git a/internal/distributed/proxy/service.go b/internal/distributed/proxy/service.go index f4a4f5e63a..41691bac29 100644 --- a/internal/distributed/proxy/service.go +++ b/internal/distributed/proxy/service.go @@ -445,7 +445,8 @@ func (s *Server) init() error { etcdConfig.EtcdTLSCert.GetValue(), etcdConfig.EtcdTLSKey.GetValue(), etcdConfig.EtcdTLSCACert.GetValue(), - etcdConfig.EtcdTLSMinVersion.GetValue()) + etcdConfig.EtcdTLSMinVersion.GetValue(), + etcdConfig.ClientOptions()...) if err != nil { log.Debug("Proxy connect to etcd failed", zap.Error(err)) return err diff --git a/internal/distributed/querynode/service.go b/internal/distributed/querynode/service.go index 07b9f9d98a..d2c7906c00 100644 --- a/internal/distributed/querynode/service.go +++ b/internal/distributed/querynode/service.go @@ -122,7 +122,8 @@ func (s *Server) init() error { etcdConfig.EtcdTLSCert.GetValue(), etcdConfig.EtcdTLSKey.GetValue(), etcdConfig.EtcdTLSCACert.GetValue(), - etcdConfig.EtcdTLSMinVersion.GetValue()) + etcdConfig.EtcdTLSMinVersion.GetValue(), + etcdConfig.ClientOptions()...) 
if err != nil { log.Debug("QueryNode connect to etcd failed", zap.Error(err)) return err diff --git a/internal/kv/etcd/metakv_factory.go b/internal/kv/etcd/metakv_factory.go index 9cad46de11..55cc7d46a8 100644 --- a/internal/kv/etcd/metakv_factory.go +++ b/internal/kv/etcd/metakv_factory.go @@ -63,7 +63,8 @@ func NewWatchKVFactory(rootPath string, etcdCfg *paramtable.EtcdConfig) (kv.Watc etcdCfg.EtcdTLSCert.GetValue(), etcdCfg.EtcdTLSKey.GetValue(), etcdCfg.EtcdTLSCACert.GetValue(), - etcdCfg.EtcdTLSMinVersion.GetValue()) + etcdCfg.EtcdTLSMinVersion.GetValue(), + etcdCfg.ClientOptions()...) if err != nil { return nil, err } diff --git a/internal/mocks/mock_grpc_client.go b/internal/mocks/mock_grpc_client.go index 383f8ff330..bcf4589f0e 100644 --- a/internal/mocks/mock_grpc_client.go +++ b/internal/mocks/mock_grpc_client.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.32.4. DO NOT EDIT. +// Code generated by mockery v2.53.3. DO NOT EDIT. package mocks @@ -33,6 +33,10 @@ func (_m *MockGrpcClient[T]) EXPECT() *MockGrpcClient_Expecter[T] { func (_m *MockGrpcClient[T]) Call(ctx context.Context, caller func(T) (interface{}, error)) (interface{}, error) { ret := _m.Called(ctx, caller) + if len(ret) == 0 { + panic("no return value specified for Call") + } + var r0 interface{} var r1 error if rf, ok := ret.Get(0).(func(context.Context, func(T) (interface{}, error)) (interface{}, error)); ok { @@ -84,10 +88,14 @@ func (_c *MockGrpcClient_Call_Call[T]) RunAndReturn(run func(context.Context, fu return _c } -// Close provides a mock function with given fields: +// Close provides a mock function with no fields func (_m *MockGrpcClient[T]) Close() error { ret := _m.Called() + if len(ret) == 0 { + panic("no return value specified for Close") + } + var r0 error if rf, ok := ret.Get(0).(func() error); ok { r0 = rf() @@ -125,7 +133,7 @@ func (_c *MockGrpcClient_Close_Call[T]) RunAndReturn(run func() error) *MockGrpc return _c } -// EnableEncryption provides a mock function with 
given fields: +// EnableEncryption provides a mock function with no fields func (_m *MockGrpcClient[T]) EnableEncryption() { _m.Called() } @@ -153,14 +161,18 @@ func (_c *MockGrpcClient_EnableEncryption_Call[T]) Return() *MockGrpcClient_Enab } func (_c *MockGrpcClient_EnableEncryption_Call[T]) RunAndReturn(run func()) *MockGrpcClient_EnableEncryption_Call[T] { - _c.Call.Return(run) + _c.Run(run) return _c } -// GetNodeID provides a mock function with given fields: +// GetNodeID provides a mock function with no fields func (_m *MockGrpcClient[T]) GetNodeID() int64 { ret := _m.Called() + if len(ret) == 0 { + panic("no return value specified for GetNodeID") + } + var r0 int64 if rf, ok := ret.Get(0).(func() int64); ok { r0 = rf() @@ -198,10 +210,14 @@ func (_c *MockGrpcClient_GetNodeID_Call[T]) RunAndReturn(run func() int64) *Mock return _c } -// GetRole provides a mock function with given fields: +// GetRole provides a mock function with no fields func (_m *MockGrpcClient[T]) GetRole() string { ret := _m.Called() + if len(ret) == 0 { + panic("no return value specified for GetRole") + } + var r0 string if rf, ok := ret.Get(0).(func() string); ok { r0 = rf() @@ -243,6 +259,10 @@ func (_c *MockGrpcClient_GetRole_Call[T]) RunAndReturn(run func() string) *MockG func (_m *MockGrpcClient[T]) ReCall(ctx context.Context, caller func(T) (interface{}, error)) (interface{}, error) { ret := _m.Called(ctx, caller) + if len(ret) == 0 { + panic("no return value specified for ReCall") + } + var r0 interface{} var r1 error if rf, ok := ret.Get(0).(func(context.Context, func(T) (interface{}, error)) (interface{}, error)); ok { @@ -323,7 +343,7 @@ func (_c *MockGrpcClient_SetGetAddrFunc_Call[T]) Return() *MockGrpcClient_SetGet } func (_c *MockGrpcClient_SetGetAddrFunc_Call[T]) RunAndReturn(run func(func() (string, error))) *MockGrpcClient_SetGetAddrFunc_Call[T] { - _c.Call.Return(run) + _c.Run(run) return _c } @@ -356,7 +376,7 @@ func (_c *MockGrpcClient_SetInternalTLSCertPool_Call[T]) 
Return() *MockGrpcClien } func (_c *MockGrpcClient_SetInternalTLSCertPool_Call[T]) RunAndReturn(run func(*x509.CertPool)) *MockGrpcClient_SetInternalTLSCertPool_Call[T] { - _c.Call.Return(run) + _c.Run(run) return _c } @@ -389,7 +409,7 @@ func (_c *MockGrpcClient_SetInternalTLSServerName_Call[T]) Return() *MockGrpcCli } func (_c *MockGrpcClient_SetInternalTLSServerName_Call[T]) RunAndReturn(run func(string)) *MockGrpcClient_SetInternalTLSServerName_Call[T] { - _c.Call.Return(run) + _c.Run(run) return _c } @@ -422,7 +442,7 @@ func (_c *MockGrpcClient_SetNewGrpcClientFunc_Call[T]) Return() *MockGrpcClient_ } func (_c *MockGrpcClient_SetNewGrpcClientFunc_Call[T]) RunAndReturn(run func(func(*grpc.ClientConn) T)) *MockGrpcClient_SetNewGrpcClientFunc_Call[T] { - _c.Call.Return(run) + _c.Run(run) return _c } @@ -455,7 +475,7 @@ func (_c *MockGrpcClient_SetNodeID_Call[T]) Return() *MockGrpcClient_SetNodeID_C } func (_c *MockGrpcClient_SetNodeID_Call[T]) RunAndReturn(run func(int64)) *MockGrpcClient_SetNodeID_Call[T] { - _c.Call.Return(run) + _c.Run(run) return _c } @@ -488,7 +508,7 @@ func (_c *MockGrpcClient_SetRole_Call[T]) Return() *MockGrpcClient_SetRole_Call[ } func (_c *MockGrpcClient_SetRole_Call[T]) RunAndReturn(run func(string)) *MockGrpcClient_SetRole_Call[T] { - _c.Call.Return(run) + _c.Run(run) return _c } @@ -521,7 +541,7 @@ func (_c *MockGrpcClient_SetSession_Call[T]) Return() *MockGrpcClient_SetSession } func (_c *MockGrpcClient_SetSession_Call[T]) RunAndReturn(run func(*sessionutil.Session)) *MockGrpcClient_SetSession_Call[T] { - _c.Call.Return(run) + _c.Run(run) return _c } diff --git a/internal/querycoordv2/server.go b/internal/querycoordv2/server.go index 41f9caf115..53682a4440 100644 --- a/internal/querycoordv2/server.go +++ b/internal/querycoordv2/server.go @@ -79,6 +79,8 @@ type Server struct { tikvCli *txnkv.Client address string session sessionutil.SessionInterface + sessionWatcher sessionutil.SessionWatcher + sessionWatcherMu sync.Mutex kv 
kv.MetaKv idAllocator func() (int64, error) metricsCacheManager *metricsinfo.MetricsCacheManager @@ -589,12 +591,19 @@ func (s *Server) Stop() error { s.cluster.Stop() } + s.sessionWatcherMu.Lock() + if s.sessionWatcher != nil { + s.sessionWatcher.Stop() + } + s.sessionWatcherMu.Unlock() + + s.cancel() + s.wg.Wait() + if s.session != nil { s.session.Stop() } - s.cancel() - s.wg.Wait() log.Info("QueryCoord stop successfully") return nil } @@ -634,14 +643,16 @@ func (s *Server) watchNodes(revision int64) { log := log.Ctx(s.ctx) defer s.wg.Done() - eventChan := s.session.WatchServices(typeutil.QueryNodeRole, revision+1, s.rewatchNodes) + s.sessionWatcherMu.Lock() + s.sessionWatcher = s.session.WatchServices(typeutil.QueryNodeRole, revision+1, s.rewatchNodes) + s.sessionWatcherMu.Unlock() for { select { case <-s.ctx.Done(): log.Info("stop watching nodes, QueryCoord stopped") return - case event, ok := <-eventChan: + case event, ok := <-s.sessionWatcher.EventChannel(): if !ok { // ErrCompacted is handled inside SessionWatcher log.Warn("Session Watcher channel closed", zap.Int64("serverID", paramtable.GetNodeID())) diff --git a/internal/tso/mocks/allocator.go b/internal/tso/mocks/allocator.go index 2d5d7308d9..f0a8a6244f 100644 --- a/internal/tso/mocks/allocator.go +++ b/internal/tso/mocks/allocator.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.32.4. DO NOT EDIT. +// Code generated by mockery v2.53.3. DO NOT EDIT. 
package mocktso @@ -25,6 +25,10 @@ func (_m *Allocator) EXPECT() *Allocator_Expecter { func (_m *Allocator) GenerateTSO(count uint32) (uint64, error) { ret := _m.Called(count) + if len(ret) == 0 { + panic("no return value specified for GenerateTSO") + } + var r0 uint64 var r1 error if rf, ok := ret.Get(0).(func(uint32) (uint64, error)); ok { @@ -73,10 +77,14 @@ func (_c *Allocator_GenerateTSO_Call) RunAndReturn(run func(uint32) (uint64, err return _c } -// GetLastSavedTime provides a mock function with given fields: +// GetLastSavedTime provides a mock function with no fields func (_m *Allocator) GetLastSavedTime() time.Time { ret := _m.Called() + if len(ret) == 0 { + panic("no return value specified for GetLastSavedTime") + } + var r0 time.Time if rf, ok := ret.Get(0).(func() time.Time); ok { r0 = rf() @@ -114,10 +122,14 @@ func (_c *Allocator_GetLastSavedTime_Call) RunAndReturn(run func() time.Time) *A return _c } -// Initialize provides a mock function with given fields: +// Initialize provides a mock function with no fields func (_m *Allocator) Initialize() error { ret := _m.Called() + if len(ret) == 0 { + panic("no return value specified for Initialize") + } + var r0 error if rf, ok := ret.Get(0).(func() error); ok { r0 = rf() @@ -155,7 +167,7 @@ func (_c *Allocator_Initialize_Call) RunAndReturn(run func() error) *Allocator_I return _c } -// Reset provides a mock function with given fields: +// Reset provides a mock function with no fields func (_m *Allocator) Reset() { _m.Called() } @@ -183,7 +195,7 @@ func (_c *Allocator_Reset_Call) Return() *Allocator_Reset_Call { } func (_c *Allocator_Reset_Call) RunAndReturn(run func()) *Allocator_Reset_Call { - _c.Call.Return(run) + _c.Run(run) return _c } @@ -191,6 +203,10 @@ func (_c *Allocator_Reset_Call) RunAndReturn(run func()) *Allocator_Reset_Call { func (_m *Allocator) SetTSO(_a0 uint64) error { ret := _m.Called(_a0) + if len(ret) == 0 { + panic("no return value specified for SetTSO") + } + var r0 error if rf, ok 
:= ret.Get(0).(func(uint64) error); ok { r0 = rf(_a0) @@ -229,10 +245,14 @@ func (_c *Allocator_SetTSO_Call) RunAndReturn(run func(uint64) error) *Allocator return _c } -// UpdateTSO provides a mock function with given fields: +// UpdateTSO provides a mock function with no fields func (_m *Allocator) UpdateTSO() error { ret := _m.Called() + if len(ret) == 0 { + panic("no return value specified for UpdateTSO") + } + var r0 error if rf, ok := ret.Get(0).(func() error); ok { r0 = rf() diff --git a/internal/util/dependency/kv/kv_client_handler.go b/internal/util/dependency/kv/kv_client_handler.go index 267a8e3092..43e0d09698 100644 --- a/internal/util/dependency/kv/kv_client_handler.go +++ b/internal/util/dependency/kv/kv_client_handler.go @@ -62,6 +62,7 @@ func getEtcdAndPath() (*clientv3.Client, string) { // Function that calls the Etcd constructor func createEtcdClient() (*clientv3.Client, error) { cfg := &paramtable.Get().ServiceParam + options := cfg.EtcdCfg.ClientOptions() return etcd.CreateEtcdClient( cfg.EtcdCfg.UseEmbedEtcd.GetAsBool(), cfg.EtcdCfg.EtcdEnableAuth.GetAsBool(), @@ -72,5 +73,6 @@ func createEtcdClient() (*clientv3.Client, error) { cfg.EtcdCfg.EtcdTLSCert.GetValue(), cfg.EtcdCfg.EtcdTLSKey.GetValue(), cfg.EtcdCfg.EtcdTLSCACert.GetValue(), - cfg.EtcdCfg.EtcdTLSMinVersion.GetValue()) + cfg.EtcdCfg.EtcdTLSMinVersion.GetValue(), + options...) } diff --git a/internal/util/dependency/mock_factory.go b/internal/util/dependency/mock_factory.go index 444020a203..bcb43b5708 100644 --- a/internal/util/dependency/mock_factory.go +++ b/internal/util/dependency/mock_factory.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.32.4. DO NOT EDIT. +// Code generated by mockery v2.53.3. DO NOT EDIT. 
package dependency @@ -55,7 +55,7 @@ func (_c *MockFactory_Init_Call) Return() *MockFactory_Init_Call { } func (_c *MockFactory_Init_Call) RunAndReturn(run func(*paramtable.ComponentParam)) *MockFactory_Init_Call { - _c.Call.Return(run) + _c.Run(run) return _c } @@ -63,6 +63,10 @@ func (_c *MockFactory_Init_Call) RunAndReturn(run func(*paramtable.ComponentPara func (_m *MockFactory) NewMsgStream(ctx context.Context) (msgstream.MsgStream, error) { ret := _m.Called(ctx) + if len(ret) == 0 { + panic("no return value specified for NewMsgStream") + } + var r0 msgstream.MsgStream var r1 error if rf, ok := ret.Get(0).(func(context.Context) (msgstream.MsgStream, error)); ok { @@ -117,6 +121,10 @@ func (_c *MockFactory_NewMsgStream_Call) RunAndReturn(run func(context.Context) func (_m *MockFactory) NewMsgStreamDisposer(ctx context.Context) func([]string, string) error { ret := _m.Called(ctx) + if len(ret) == 0 { + panic("no return value specified for NewMsgStreamDisposer") + } + var r0 func([]string, string) error if rf, ok := ret.Get(0).(func(context.Context) func([]string, string) error); ok { r0 = rf(ctx) @@ -161,6 +169,10 @@ func (_c *MockFactory_NewMsgStreamDisposer_Call) RunAndReturn(run func(context.C func (_m *MockFactory) NewPersistentStorageChunkManager(ctx context.Context) (storage.ChunkManager, error) { ret := _m.Called(ctx) + if len(ret) == 0 { + panic("no return value specified for NewPersistentStorageChunkManager") + } + var r0 storage.ChunkManager var r1 error if rf, ok := ret.Get(0).(func(context.Context) (storage.ChunkManager, error)); ok { @@ -215,6 +227,10 @@ func (_c *MockFactory_NewPersistentStorageChunkManager_Call) RunAndReturn(run fu func (_m *MockFactory) NewTtMsgStream(ctx context.Context) (msgstream.MsgStream, error) { ret := _m.Called(ctx) + if len(ret) == 0 { + panic("no return value specified for NewTtMsgStream") + } + var r0 msgstream.MsgStream var r1 error if rf, ok := ret.Get(0).(func(context.Context) (msgstream.MsgStream, error)); ok { 
diff --git a/internal/util/proxyutil/mock_proxy_client_manager.go b/internal/util/proxyutil/mock_proxy_client_manager.go index e2244e6b76..1e729da365 100644 --- a/internal/util/proxyutil/mock_proxy_client_manager.go +++ b/internal/util/proxyutil/mock_proxy_client_manager.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.32.4. DO NOT EDIT. +// Code generated by mockery v2.53.3. DO NOT EDIT. package proxyutil @@ -59,7 +59,7 @@ func (_c *MockProxyClientManager_AddProxyClient_Call) Return() *MockProxyClientM } func (_c *MockProxyClientManager_AddProxyClient_Call) RunAndReturn(run func(*sessionutil.Session)) *MockProxyClientManager_AddProxyClient_Call { - _c.Call.Return(run) + _c.Run(run) return _c } @@ -92,7 +92,7 @@ func (_c *MockProxyClientManager_AddProxyClients_Call) Return() *MockProxyClient } func (_c *MockProxyClientManager_AddProxyClients_Call) RunAndReturn(run func([]*sessionutil.Session)) *MockProxyClientManager_AddProxyClients_Call { - _c.Call.Return(run) + _c.Run(run) return _c } @@ -125,7 +125,7 @@ func (_c *MockProxyClientManager_DelProxyClient_Call) Return() *MockProxyClientM } func (_c *MockProxyClientManager_DelProxyClient_Call) RunAndReturn(run func(*sessionutil.Session)) *MockProxyClientManager_DelProxyClient_Call { - _c.Call.Return(run) + _c.Run(run) return _c } @@ -133,6 +133,10 @@ func (_c *MockProxyClientManager_DelProxyClient_Call) RunAndReturn(run func(*ses func (_m *MockProxyClientManager) GetComponentStates(ctx context.Context) (map[int64]*milvuspb.ComponentStates, error) { ret := _m.Called(ctx) + if len(ret) == 0 { + panic("no return value specified for GetComponentStates") + } + var r0 map[int64]*milvuspb.ComponentStates var r1 error if rf, ok := ret.Get(0).(func(context.Context) (map[int64]*milvuspb.ComponentStates, error)); ok { @@ -183,10 +187,14 @@ func (_c *MockProxyClientManager_GetComponentStates_Call) RunAndReturn(run func( return _c } -// GetProxyClients provides a mock function with given fields: +// GetProxyClients provides a 
mock function with no fields func (_m *MockProxyClientManager) GetProxyClients() *typeutil.ConcurrentMap[int64, types.ProxyClient] { ret := _m.Called() + if len(ret) == 0 { + panic("no return value specified for GetProxyClients") + } + var r0 *typeutil.ConcurrentMap[int64, types.ProxyClient] if rf, ok := ret.Get(0).(func() *typeutil.ConcurrentMap[int64, types.ProxyClient]); ok { r0 = rf() @@ -226,10 +234,14 @@ func (_c *MockProxyClientManager_GetProxyClients_Call) RunAndReturn(run func() * return _c } -// GetProxyCount provides a mock function with given fields: +// GetProxyCount provides a mock function with no fields func (_m *MockProxyClientManager) GetProxyCount() int { ret := _m.Called() + if len(ret) == 0 { + panic("no return value specified for GetProxyCount") + } + var r0 int if rf, ok := ret.Get(0).(func() int); ok { r0 = rf() @@ -271,6 +283,10 @@ func (_c *MockProxyClientManager_GetProxyCount_Call) RunAndReturn(run func() int func (_m *MockProxyClientManager) GetProxyMetrics(ctx context.Context) ([]*milvuspb.GetMetricsResponse, error) { ret := _m.Called(ctx) + if len(ret) == 0 { + panic("no return value specified for GetProxyMetrics") + } + var r0 []*milvuspb.GetMetricsResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context) ([]*milvuspb.GetMetricsResponse, error)); ok { @@ -332,6 +348,10 @@ func (_m *MockProxyClientManager) InvalidateCollectionMetaCache(ctx context.Cont _ca = append(_ca, _va...) ret := _m.Called(_ca...) + if len(ret) == 0 { + panic("no return value specified for InvalidateCollectionMetaCache") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context, *proxypb.InvalidateCollMetaCacheRequest, ...ExpireCacheOpt) error); ok { r0 = rf(ctx, request, opts...) 
@@ -383,6 +403,10 @@ func (_c *MockProxyClientManager_InvalidateCollectionMetaCache_Call) RunAndRetur func (_m *MockProxyClientManager) InvalidateCredentialCache(ctx context.Context, request *proxypb.InvalidateCredCacheRequest) error { ret := _m.Called(ctx, request) + if len(ret) == 0 { + panic("no return value specified for InvalidateCredentialCache") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context, *proxypb.InvalidateCredCacheRequest) error); ok { r0 = rf(ctx, request) @@ -426,6 +450,10 @@ func (_c *MockProxyClientManager_InvalidateCredentialCache_Call) RunAndReturn(ru func (_m *MockProxyClientManager) InvalidateShardLeaderCache(ctx context.Context, request *proxypb.InvalidateShardLeaderCacheRequest) error { ret := _m.Called(ctx, request) + if len(ret) == 0 { + panic("no return value specified for InvalidateShardLeaderCache") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context, *proxypb.InvalidateShardLeaderCacheRequest) error); ok { r0 = rf(ctx, request) @@ -469,6 +497,10 @@ func (_c *MockProxyClientManager_InvalidateShardLeaderCache_Call) RunAndReturn(r func (_m *MockProxyClientManager) RefreshPolicyInfoCache(ctx context.Context, req *proxypb.RefreshPolicyInfoCacheRequest) error { ret := _m.Called(ctx, req) + if len(ret) == 0 { + panic("no return value specified for RefreshPolicyInfoCache") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context, *proxypb.RefreshPolicyInfoCacheRequest) error); ok { r0 = rf(ctx, req) @@ -512,6 +544,10 @@ func (_c *MockProxyClientManager_RefreshPolicyInfoCache_Call) RunAndReturn(run f func (_m *MockProxyClientManager) SetRates(ctx context.Context, request *proxypb.SetRatesRequest) error { ret := _m.Called(ctx, request) + if len(ret) == 0 { + panic("no return value specified for SetRates") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context, *proxypb.SetRatesRequest) error); ok { r0 = rf(ctx, request) @@ -555,6 +591,10 @@ func (_c *MockProxyClientManager_SetRates_Call) 
RunAndReturn(run func(context.Co func (_m *MockProxyClientManager) UpdateCredentialCache(ctx context.Context, request *proxypb.UpdateCredCacheRequest) error { ret := _m.Called(ctx, request) + if len(ret) == 0 { + panic("no return value specified for UpdateCredentialCache") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context, *proxypb.UpdateCredCacheRequest) error); ok { r0 = rf(ctx, request) diff --git a/internal/util/proxyutil/mock_proxy_watcher.go b/internal/util/proxyutil/mock_proxy_watcher.go index 0aed0cb5b6..4f79d65a8c 100644 --- a/internal/util/proxyutil/mock_proxy_watcher.go +++ b/internal/util/proxyutil/mock_proxy_watcher.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.32.4. DO NOT EDIT. +// Code generated by mockery v2.53.3. DO NOT EDIT. package proxyutil @@ -64,7 +64,7 @@ func (_c *MockProxyWatcher_AddSessionFunc_Call) Return() *MockProxyWatcher_AddSe } func (_c *MockProxyWatcher_AddSessionFunc_Call) RunAndReturn(run func(...func(*sessionutil.Session))) *MockProxyWatcher_AddSessionFunc_Call { - _c.Call.Return(run) + _c.Run(run) return _c } @@ -110,11 +110,11 @@ func (_c *MockProxyWatcher_DelSessionFunc_Call) Return() *MockProxyWatcher_DelSe } func (_c *MockProxyWatcher_DelSessionFunc_Call) RunAndReturn(run func(...func(*sessionutil.Session))) *MockProxyWatcher_DelSessionFunc_Call { - _c.Call.Return(run) + _c.Run(run) return _c } -// Stop provides a mock function with given fields: +// Stop provides a mock function with no fields func (_m *MockProxyWatcher) Stop() { _m.Called() } @@ -142,7 +142,7 @@ func (_c *MockProxyWatcher_Stop_Call) Return() *MockProxyWatcher_Stop_Call { } func (_c *MockProxyWatcher_Stop_Call) RunAndReturn(run func()) *MockProxyWatcher_Stop_Call { - _c.Call.Return(run) + _c.Run(run) return _c } @@ -150,6 +150,10 @@ func (_c *MockProxyWatcher_Stop_Call) RunAndReturn(run func()) *MockProxyWatcher func (_m *MockProxyWatcher) WatchProxy(ctx context.Context) error { ret := _m.Called(ctx) + if len(ret) == 0 { + panic("no 
return value specified for WatchProxy") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context) error); ok { r0 = rf(ctx) diff --git a/internal/util/sessionutil/mock_session.go b/internal/util/sessionutil/mock_session.go index a4d47020ec..d6bedd3402 100644 --- a/internal/util/sessionutil/mock_session.go +++ b/internal/util/sessionutil/mock_session.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.32.4. DO NOT EDIT. +// Code generated by mockery v2.53.3. DO NOT EDIT. package sessionutil @@ -24,10 +24,14 @@ func (_m *MockSession) EXPECT() *MockSession_Expecter { return &MockSession_Expecter{mock: &_m.Mock} } -// Disconnected provides a mock function with given fields: +// Disconnected provides a mock function with no fields func (_m *MockSession) Disconnected() bool { ret := _m.Called() + if len(ret) == 0 { + panic("no return value specified for Disconnected") + } + var r0 bool if rf, ok := ret.Get(0).(func() bool); ok { r0 = rf() @@ -69,6 +73,10 @@ func (_c *MockSession_Disconnected_Call) RunAndReturn(run func() bool) *MockSess func (_m *MockSession) ForceActiveStandby(activateFunc func() error) error { ret := _m.Called(activateFunc) + if len(ret) == 0 { + panic("no return value specified for ForceActiveStandby") + } + var r0 error if rf, ok := ret.Get(0).(func(func() error) error); ok { r0 = rf(activateFunc) @@ -107,10 +115,14 @@ func (_c *MockSession_ForceActiveStandby_Call) RunAndReturn(run func(func() erro return _c } -// GetAddress provides a mock function with given fields: +// GetAddress provides a mock function with no fields func (_m *MockSession) GetAddress() string { ret := _m.Called() + if len(ret) == 0 { + panic("no return value specified for GetAddress") + } + var r0 string if rf, ok := ret.Get(0).(func() string); ok { r0 = rf() @@ -148,10 +160,14 @@ func (_c *MockSession_GetAddress_Call) RunAndReturn(run func() string) *MockSess return _c } -// GetServerID provides a mock function with given fields: +// GetServerID provides a mock function 
with no fields func (_m *MockSession) GetServerID() int64 { ret := _m.Called() + if len(ret) == 0 { + panic("no return value specified for GetServerID") + } + var r0 int64 if rf, ok := ret.Get(0).(func() int64); ok { r0 = rf() @@ -193,6 +209,10 @@ func (_c *MockSession_GetServerID_Call) RunAndReturn(run func() int64) *MockSess func (_m *MockSession) GetSessions(prefix string) (map[string]*Session, int64, error) { ret := _m.Called(prefix) + if len(ret) == 0 { + panic("no return value specified for GetSessions") + } + var r0 map[string]*Session var r1 int64 var r2 error @@ -254,6 +274,10 @@ func (_c *MockSession_GetSessions_Call) RunAndReturn(run func(string) (map[strin func (_m *MockSession) GetSessionsWithVersionRange(prefix string, r semver.Range) (map[string]*Session, int64, error) { ret := _m.Called(prefix, r) + if len(ret) == 0 { + panic("no return value specified for GetSessionsWithVersionRange") + } + var r0 map[string]*Session var r1 int64 var r2 error @@ -312,10 +336,14 @@ func (_c *MockSession_GetSessionsWithVersionRange_Call) RunAndReturn(run func(st return _c } -// GoingStop provides a mock function with given fields: +// GoingStop provides a mock function with no fields func (_m *MockSession) GoingStop() error { ret := _m.Called() + if len(ret) == 0 { + panic("no return value specified for GoingStop") + } + var r0 error if rf, ok := ret.Get(0).(func() error); ok { r0 = rf() @@ -385,14 +413,18 @@ func (_c *MockSession_Init_Call) Return() *MockSession_Init_Call { } func (_c *MockSession_Init_Call) RunAndReturn(run func(string, string, bool, bool)) *MockSession_Init_Call { - _c.Call.Return(run) + _c.Run(run) return _c } -// IsTriggerKill provides a mock function with given fields: +// IsTriggerKill provides a mock function with no fields func (_m *MockSession) IsTriggerKill() bool { ret := _m.Called() + if len(ret) == 0 { + panic("no return value specified for IsTriggerKill") + } + var r0 bool if rf, ok := ret.Get(0).(func() bool); ok { r0 = rf() @@ 
-460,14 +492,18 @@ func (_c *MockSession_LivenessCheck_Call) Return() *MockSession_LivenessCheck_Ca } func (_c *MockSession_LivenessCheck_Call) RunAndReturn(run func(context.Context, func())) *MockSession_LivenessCheck_Call { - _c.Call.Return(run) + _c.Run(run) return _c } -// MarshalJSON provides a mock function with given fields: +// MarshalJSON provides a mock function with no fields func (_m *MockSession) MarshalJSON() ([]byte, error) { ret := _m.Called() + if len(ret) == 0 { + panic("no return value specified for MarshalJSON") + } + var r0 []byte var r1 error if rf, ok := ret.Get(0).(func() ([]byte, error)); ok { @@ -521,6 +557,10 @@ func (_c *MockSession_MarshalJSON_Call) RunAndReturn(run func() ([]byte, error)) func (_m *MockSession) ProcessActiveStandBy(activateFunc func() error) error { ret := _m.Called(activateFunc) + if len(ret) == 0 { + panic("no return value specified for ProcessActiveStandBy") + } + var r0 error if rf, ok := ret.Get(0).(func(func() error) error); ok { r0 = rf(activateFunc) @@ -559,7 +599,7 @@ func (_c *MockSession_ProcessActiveStandBy_Call) RunAndReturn(run func(func() er return _c } -// Register provides a mock function with given fields: +// Register provides a mock function with no fields func (_m *MockSession) Register() { _m.Called() } @@ -587,14 +627,18 @@ func (_c *MockSession_Register_Call) Return() *MockSession_Register_Call { } func (_c *MockSession_Register_Call) RunAndReturn(run func()) *MockSession_Register_Call { - _c.Call.Return(run) + _c.Run(run) return _c } -// Registered provides a mock function with given fields: +// Registered provides a mock function with no fields func (_m *MockSession) Registered() bool { ret := _m.Called() + if len(ret) == 0 { + panic("no return value specified for Registered") + } + var r0 bool if rf, ok := ret.Get(0).(func() bool); ok { r0 = rf() @@ -661,7 +705,7 @@ func (_c *MockSession_Revoke_Call) Return() *MockSession_Revoke_Call { } func (_c *MockSession_Revoke_Call) RunAndReturn(run 
func(time.Duration)) *MockSession_Revoke_Call { - _c.Call.Return(run) + _c.Run(run) return _c } @@ -694,7 +738,7 @@ func (_c *MockSession_SetDisconnected_Call) Return() *MockSession_SetDisconnecte } func (_c *MockSession_SetDisconnected_Call) RunAndReturn(run func(bool)) *MockSession_SetDisconnected_Call { - _c.Call.Return(run) + _c.Run(run) return _c } @@ -727,11 +771,11 @@ func (_c *MockSession_SetEnableActiveStandBy_Call) Return() *MockSession_SetEnab } func (_c *MockSession_SetEnableActiveStandBy_Call) RunAndReturn(run func(bool)) *MockSession_SetEnableActiveStandBy_Call { - _c.Call.Return(run) + _c.Run(run) return _c } -// Stop provides a mock function with given fields: +// Stop provides a mock function with no fields func (_m *MockSession) Stop() { _m.Called() } @@ -759,14 +803,18 @@ func (_c *MockSession_Stop_Call) Return() *MockSession_Stop_Call { } func (_c *MockSession_Stop_Call) RunAndReturn(run func()) *MockSession_Stop_Call { - _c.Call.Return(run) + _c.Run(run) return _c } -// String provides a mock function with given fields: +// String provides a mock function with no fields func (_m *MockSession) String() string { ret := _m.Called() + if len(ret) == 0 { + panic("no return value specified for String") + } + var r0 string if rf, ok := ret.Get(0).(func() string); ok { r0 = rf() @@ -808,6 +856,10 @@ func (_c *MockSession_String_Call) RunAndReturn(run func() string) *MockSession_ func (_m *MockSession) UnmarshalJSON(data []byte) error { ret := _m.Called(data) + if len(ret) == 0 { + panic("no return value specified for UnmarshalJSON") + } + var r0 error if rf, ok := ret.Get(0).(func([]byte) error); ok { r0 = rf(data) @@ -875,20 +927,24 @@ func (_c *MockSession_UpdateRegistered_Call) Return() *MockSession_UpdateRegiste } func (_c *MockSession_UpdateRegistered_Call) RunAndReturn(run func(bool)) *MockSession_UpdateRegistered_Call { - _c.Call.Return(run) + _c.Run(run) return _c } // WatchServices provides a mock function with given fields: prefix, revision, 
rewatch -func (_m *MockSession) WatchServices(prefix string, revision int64, rewatch Rewatch) <-chan *SessionEvent { +func (_m *MockSession) WatchServices(prefix string, revision int64, rewatch Rewatch) SessionWatcher { ret := _m.Called(prefix, revision, rewatch) - var r0 <-chan *SessionEvent - if rf, ok := ret.Get(0).(func(string, int64, Rewatch) <-chan *SessionEvent); ok { + if len(ret) == 0 { + panic("no return value specified for WatchServices") + } + + var r0 SessionWatcher + if rf, ok := ret.Get(0).(func(string, int64, Rewatch) SessionWatcher); ok { r0 = rf(prefix, revision, rewatch) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(<-chan *SessionEvent) + r0 = ret.Get(0).(SessionWatcher) } } @@ -915,26 +971,30 @@ func (_c *MockSession_WatchServices_Call) Run(run func(prefix string, revision i return _c } -func (_c *MockSession_WatchServices_Call) Return(eventChannel <-chan *SessionEvent) *MockSession_WatchServices_Call { - _c.Call.Return(eventChannel) +func (_c *MockSession_WatchServices_Call) Return(watcher SessionWatcher) *MockSession_WatchServices_Call { + _c.Call.Return(watcher) return _c } -func (_c *MockSession_WatchServices_Call) RunAndReturn(run func(string, int64, Rewatch) <-chan *SessionEvent) *MockSession_WatchServices_Call { +func (_c *MockSession_WatchServices_Call) RunAndReturn(run func(string, int64, Rewatch) SessionWatcher) *MockSession_WatchServices_Call { _c.Call.Return(run) return _c } // WatchServicesWithVersionRange provides a mock function with given fields: prefix, r, revision, rewatch -func (_m *MockSession) WatchServicesWithVersionRange(prefix string, r semver.Range, revision int64, rewatch Rewatch) <-chan *SessionEvent { +func (_m *MockSession) WatchServicesWithVersionRange(prefix string, r semver.Range, revision int64, rewatch Rewatch) SessionWatcher { ret := _m.Called(prefix, r, revision, rewatch) - var r0 <-chan *SessionEvent - if rf, ok := ret.Get(0).(func(string, semver.Range, int64, Rewatch) <-chan *SessionEvent); ok { + if 
len(ret) == 0 { + panic("no return value specified for WatchServicesWithVersionRange") + } + + var r0 SessionWatcher + if rf, ok := ret.Get(0).(func(string, semver.Range, int64, Rewatch) SessionWatcher); ok { r0 = rf(prefix, r, revision, rewatch) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(<-chan *SessionEvent) + r0 = ret.Get(0).(SessionWatcher) } } @@ -962,12 +1022,12 @@ func (_c *MockSession_WatchServicesWithVersionRange_Call) Run(run func(prefix st return _c } -func (_c *MockSession_WatchServicesWithVersionRange_Call) Return(eventChannel <-chan *SessionEvent) *MockSession_WatchServicesWithVersionRange_Call { - _c.Call.Return(eventChannel) +func (_c *MockSession_WatchServicesWithVersionRange_Call) Return(watcher SessionWatcher) *MockSession_WatchServicesWithVersionRange_Call { + _c.Call.Return(watcher) return _c } -func (_c *MockSession_WatchServicesWithVersionRange_Call) RunAndReturn(run func(string, semver.Range, int64, Rewatch) <-chan *SessionEvent) *MockSession_WatchServicesWithVersionRange_Call { +func (_c *MockSession_WatchServicesWithVersionRange_Call) RunAndReturn(run func(string, semver.Range, int64, Rewatch) SessionWatcher) *MockSession_WatchServicesWithVersionRange_Call { _c.Call.Return(run) return _c } diff --git a/internal/util/sessionutil/mock_session_watcher.go b/internal/util/sessionutil/mock_session_watcher.go new file mode 100644 index 0000000000..668a7bf4a4 --- /dev/null +++ b/internal/util/sessionutil/mock_session_watcher.go @@ -0,0 +1,111 @@ +// Code generated by mockery v2.53.3. DO NOT EDIT. 
+ +package sessionutil + +import mock "github.com/stretchr/testify/mock" + +// MockSessionWatcher is an autogenerated mock type for the SessionWatcher type +type MockSessionWatcher struct { + mock.Mock +} + +type MockSessionWatcher_Expecter struct { + mock *mock.Mock +} + +func (_m *MockSessionWatcher) EXPECT() *MockSessionWatcher_Expecter { + return &MockSessionWatcher_Expecter{mock: &_m.Mock} +} + +// EventChannel provides a mock function with no fields +func (_m *MockSessionWatcher) EventChannel() <-chan *SessionEvent { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for EventChannel") + } + + var r0 <-chan *SessionEvent + if rf, ok := ret.Get(0).(func() <-chan *SessionEvent); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(<-chan *SessionEvent) + } + } + + return r0 +} + +// MockSessionWatcher_EventChannel_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'EventChannel' +type MockSessionWatcher_EventChannel_Call struct { + *mock.Call +} + +// EventChannel is a helper method to define mock.On call +func (_e *MockSessionWatcher_Expecter) EventChannel() *MockSessionWatcher_EventChannel_Call { + return &MockSessionWatcher_EventChannel_Call{Call: _e.mock.On("EventChannel")} +} + +func (_c *MockSessionWatcher_EventChannel_Call) Run(run func()) *MockSessionWatcher_EventChannel_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockSessionWatcher_EventChannel_Call) Return(_a0 <-chan *SessionEvent) *MockSessionWatcher_EventChannel_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockSessionWatcher_EventChannel_Call) RunAndReturn(run func() <-chan *SessionEvent) *MockSessionWatcher_EventChannel_Call { + _c.Call.Return(run) + return _c +} + +// Stop provides a mock function with no fields +func (_m *MockSessionWatcher) Stop() { + _m.Called() +} + +// MockSessionWatcher_Stop_Call is a *mock.Call that shadows Run/Return methods with 
type explicit version for method 'Stop' +type MockSessionWatcher_Stop_Call struct { + *mock.Call +} + +// Stop is a helper method to define mock.On call +func (_e *MockSessionWatcher_Expecter) Stop() *MockSessionWatcher_Stop_Call { + return &MockSessionWatcher_Stop_Call{Call: _e.mock.On("Stop")} +} + +func (_c *MockSessionWatcher_Stop_Call) Run(run func()) *MockSessionWatcher_Stop_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockSessionWatcher_Stop_Call) Return() *MockSessionWatcher_Stop_Call { + _c.Call.Return() + return _c +} + +func (_c *MockSessionWatcher_Stop_Call) RunAndReturn(run func()) *MockSessionWatcher_Stop_Call { + _c.Run(run) + return _c +} + +// NewMockSessionWatcher creates a new instance of MockSessionWatcher. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockSessionWatcher(t interface { + mock.TestingT + Cleanup(func()) +}) *MockSessionWatcher { + mock := &MockSessionWatcher{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/util/sessionutil/session.go b/internal/util/sessionutil/session.go index eae304d583..83908a2cd2 100644 --- a/internal/util/sessionutil/session.go +++ b/internal/util/sessionutil/session.go @@ -34,8 +34,8 @@ type SessionInterface interface { GetSessionsWithVersionRange(prefix string, r semver.Range) (map[string]*Session, int64, error) GoingStop() error - WatchServices(prefix string, revision int64, rewatch Rewatch) (eventChannel <-chan *SessionEvent) - WatchServicesWithVersionRange(prefix string, r semver.Range, revision int64, rewatch Rewatch) (eventChannel <-chan *SessionEvent) + WatchServices(prefix string, revision int64, rewatch Rewatch) (watcher SessionWatcher) + WatchServicesWithVersionRange(prefix string, r semver.Range, revision int64, rewatch Rewatch) (watcher SessionWatcher) 
LivenessCheck(ctx context.Context, callback func()) Stop() Revoke(timeout time.Duration) @@ -51,3 +51,8 @@ type SessionInterface interface { GetServerID() int64 IsTriggerKill() bool } + +type SessionWatcher interface { + EventChannel() <-chan *SessionEvent + Stop() +} diff --git a/internal/util/sessionutil/session_util.go b/internal/util/sessionutil/session_util.go index d640a21fec..d6c5a42705 100644 --- a/internal/util/sessionutil/session_util.go +++ b/internal/util/sessionutil/session_util.go @@ -145,7 +145,7 @@ type Session struct { ctx context.Context // When outside context done, Session cancels its goroutines first, then uses // keepAliveCancel to cancel the etcd KeepAlive - keepAliveLock sync.Mutex + keepAliveMu sync.Mutex keepAliveCancel context.CancelFunc keepAliveCtx context.Context @@ -282,24 +282,11 @@ func NewSessionWithEtcd(ctx context.Context, metaRoot string, client *clientv3.C session.apply(opts...) session.UpdateRegistered(false) - - connectEtcdFn := func() error { - log.Ctx(ctx).Debug("Session try to connect to etcd") - ctx2, cancel2 := context.WithTimeout(session.ctx, 5*time.Second) - defer cancel2() - if _, err := client.Get(ctx2, "health"); err != nil { - return err - } - session.etcdCli = client - return nil - } - err := retry.Do(ctx, connectEtcdFn, retry.Attempts(100)) - if err != nil { - log.Ctx(ctx).Warn("failed to initialize session", - zap.Error(err)) - return nil - } - log.Ctx(ctx).Debug("Session connect to etcd success") + session.etcdCli = client + log.Ctx(ctx).Info("Successfully connected to etcd for session", + zap.String("metaRoot", metaRoot), + zap.String("hostName", hostName), + ) return session } @@ -333,7 +320,7 @@ func (s *Session) Register() { panic(err) } s.liveCh = make(chan struct{}) - s.processKeepAliveResponse(ch) + s.startKeepAliveLoop(ch) s.UpdateRegistered(true) } @@ -453,6 +440,9 @@ func (s *Session) initWatchSessionCh(ctx context.Context) error { ) ctx, cancel := context.WithCancel(ctx) + if old := 
s.watchCancel.Load(); old != nil { + (*old)() + } s.watchCancel.Store(&cancel) err = retry.Do(ctx, func() error { @@ -461,6 +451,7 @@ func (s *Session) initWatchSessionCh(ctx context.Context) error { }, retry.Attempts(uint(s.sessionRetryTimes))) if err != nil { log.Warn("fail to get the session key from the etcd", zap.Error(err)) + cancel() return err } s.watchSessionKeyCh = s.etcdCli.Watch(ctx, s.getSessionKey(), clientv3.WithRev(getResp.Header.Revision)) @@ -494,13 +485,14 @@ func (s *Session) registerService() (<-chan *clientv3.LeaseKeepAliveResponse, er registerFn := func() error { resp, err := s.etcdCli.Grant(s.ctx, s.sessionTTL) if err != nil { - log.Error("register service", zap.Error(err)) + log.Error("register service: failed to grant lease from etcd", zap.Error(err)) return err } s.LeaseID = &resp.ID sessionJSON, err := json.Marshal(s) if err != nil { + log.Error("register service: failed to marshal session", zap.Error(err)) return err } @@ -511,24 +503,23 @@ func (s *Session) registerService() (<-chan *clientv3.LeaseKeepAliveResponse, er 0)). 
Then(clientv3.OpPut(completeKey, string(sessionJSON), clientv3.WithLease(resp.ID))).Commit() if err != nil { - log.Warn("register on etcd error, check the availability of etcd ", zap.Error(err)) + log.Warn("register on etcd error, check the availability of etcd", zap.Error(err)) return err } - if txnResp != nil && !txnResp.Succeeded { s.handleRestart(completeKey) return fmt.Errorf("function CompareAndSwap error for compare is false for key: %s", s.ServerName) } log.Info("put session key into etcd", zap.String("key", completeKey), zap.String("value", string(sessionJSON))) - keepAliveCtx, keepAliveCancel := context.WithCancel(context.Background()) - s.keepAliveCtx = keepAliveCtx - s.keepAliveCancel = keepAliveCancel - ch, err = s.etcdCli.KeepAlive(keepAliveCtx, resp.ID) + ctx, cancel := context.WithCancel(s.ctx) + ch, err = s.etcdCli.KeepAlive(ctx, resp.ID) if err != nil { - log.Warn("go error during keeping alive with etcd", zap.Error(err)) + log.Warn("failed to keep alive with etcd", zap.Int64("lease ID", int64(resp.ID)), zap.Error(err)) + cancel() return err } + s.setKeepAliveContext(ctx, cancel) log.Info("Service registered successfully", zap.String("ServerName", s.ServerName), zap.Int64("serverID", s.ServerID)) return nil } @@ -572,88 +563,107 @@ func (s *Session) handleRestart(key string) { // processKeepAliveResponse processes the response of etcd keepAlive interface // If keepAlive fails for unexpected error, it will send a signal to the channel. 
func (s *Session) processKeepAliveResponse(ch <-chan *clientv3.LeaseKeepAliveResponse) { - s.wg.Add(1) - go func() { - defer s.wg.Done() - for { - select { - case <-s.ctx.Done(): - log.Warn("keep alive", zap.Error(errors.New("context done"))) - s.cancelKeepAlive() - return - case resp, ok := <-ch: - if !ok { - log.Warn("session keepalive channel closed") + defer s.wg.Done() + for { + select { + case <-s.ctx.Done(): + log.Warn("session context canceled, stop keepalive") + s.cancelKeepAlive(true) + return - // if keep alive is canceled, keepAliveCtx.Err() will return a non-nil error - if s.keepAliveCtx.Err() != nil { - s.safeCloseLiveCh() - return - } + case resp, ok := <-ch: + if !ok || resp == nil { + log.Warn("keepalive channel closed", + zap.Bool("stopped", s.isStopped.Load()), + zap.String("serverName", s.ServerName)) - log.Info("keepAlive channel close caused by etcd, try to KeepAliveOnce", zap.String("serverName", s.ServerName)) - s.keepAliveLock.Lock() - defer s.keepAliveLock.Unlock() - // have to KeepAliveOnce before KeepAlive because KeepAlive won't throw error even when lease OT - var keepAliveOnceResp *clientv3.LeaseKeepAliveResponse - s.keepAliveCancel() - s.keepAliveCtx, s.keepAliveCancel = context.WithCancel(context.Background()) - err := retry.Do(s.ctx, func() error { - ctx, cancel := context.WithTimeout(s.keepAliveCtx, time.Second*10) - defer cancel() - resp, err := s.etcdCli.KeepAliveOnce(ctx, *s.LeaseID) - keepAliveOnceResp = resp - return err - }, retry.Attempts(3)) - if err != nil { - log.Warn("fail to retry keepAliveOnce", zap.String("serverName", s.ServerName), zap.Int64("LeaseID", int64(*s.LeaseID)), zap.Error(err)) - s.safeCloseLiveCh() - return - } - log.Info("succeed to KeepAliveOnce", zap.String("serverName", s.ServerName), zap.Int64("LeaseID", int64(*s.LeaseID)), zap.Any("resp", keepAliveOnceResp)) - - var chNew <-chan *clientv3.LeaseKeepAliveResponse - keepAliveFunc := func() error { - var err1 error - chNew, err1 = 
s.etcdCli.KeepAlive(s.keepAliveCtx, *s.LeaseID) - return err1 - } - err = fnWithTimeout(keepAliveFunc, time.Second*10) - if err != nil { - log.Warn("fail to retry keepAlive", zap.Error(err)) - s.safeCloseLiveCh() - return - } - go s.processKeepAliveResponse(chNew) + if s.isStopped.Load() { + s.safeCloseLiveCh() return } - if resp == nil { - log.Warn("session keepalive response failed") + + s.cancelKeepAlive(false) + + // this is just to make sure etcd is alive and the lease can be kept alive ASAP + _, err := s.etcdCli.KeepAliveOnce(s.ctx, *s.LeaseID) + if err != nil { + log.Info("failed to keep alive", zap.String("serverName", s.ServerName), zap.Error(err)) s.safeCloseLiveCh() + return } + + var ( + newCh <-chan *clientv3.LeaseKeepAliveResponse + newCtx context.Context + newCancel context.CancelFunc + ) + + err = retry.Do(s.ctx, func() error { + ctx, cancel := context.WithCancel(s.ctx) + ch, err := s.etcdCli.KeepAlive(ctx, *s.LeaseID) + if err != nil { + cancel() + log.Warn("failed to keep alive with etcd", zap.Error(err)) + return err + } + newCh = ch + newCtx = ctx + newCancel = cancel + return nil + }, retry.Attempts(3)) + if err != nil { + log.Warn("failed to retry keepAlive", + zap.String("serverName", s.ServerName), + zap.Error(err)) + s.safeCloseLiveCh() + return + } + log.Info("retry keep alive success", zap.String("serverName", s.ServerName)) + s.setKeepAliveContext(newCtx, newCancel) + ch = newCh + continue } } - }() + } } -func fnWithTimeout(fn func() error, d time.Duration) error { - if d != 0 { - resultChan := make(chan bool) - var err1 error - go func() { - err1 = fn() - resultChan <- true - }() +func (s *Session) startKeepAliveLoop(ch <-chan *clientv3.LeaseKeepAliveResponse) { + s.wg.Add(1) + go s.processKeepAliveResponse(ch) +} - select { - case <-resultChan: - log.Ctx(context.TODO()).Debug("retry func success") - case <-time.After(d): - return errors.New("func timed out") +func (s *Session) setKeepAliveContext(ctx context.Context, cancel
context.CancelFunc) { + s.keepAliveMu.Lock() + s.keepAliveCtx = ctx + s.keepAliveCancel = cancel + s.keepAliveMu.Unlock() +} + +// cancelKeepAlive cancels the keepAlive context and sets a flag to control whether keepAlive retry is allowed. +func (s *Session) cancelKeepAlive(isStop bool) { + var cancel context.CancelFunc + s.keepAliveMu.Lock() + defer func() { + s.keepAliveMu.Unlock() + if cancel != nil { + cancel() } - return err1 + }() + + // only process the first time + if s.isStopped.Load() { + return } - return fn() + + // Add a variable to signal whether keepAlive retry is allowed. + // If isStop is true, disable keepAlive retry. + if isStop { + s.isStopped.Store(true) + } + + cancel = s.keepAliveCancel + s.keepAliveCtx = nil + s.keepAliveCancel = nil } // GetSessions will get all sessions registered in etcd. @@ -754,22 +764,34 @@ type SessionEvent struct { } type sessionWatcher struct { - s *Session - rch clientv3.WatchChan - eventCh chan *SessionEvent - prefix string - rewatch Rewatch - validate func(*Session) bool + s *Session + cancel context.CancelFunc + rch clientv3.WatchChan + eventCh chan *SessionEvent + prefix string + rewatch Rewatch + validate func(*Session) bool + wg sync.WaitGroup + closeOnce sync.Once } -func (w *sessionWatcher) start() { +func (w *sessionWatcher) closeEventCh() { + w.closeOnce.Do(func() { + close(w.eventCh) + }) +} + +func (w *sessionWatcher) start(ctx context.Context) { + w.wg.Add(1) go func() { + defer w.wg.Done() for { select { - case <-w.s.ctx.Done(): + case <-ctx.Done(): return case wresp, ok := <-w.rch: if !ok { + w.closeEventCh() log.Warn("session watch channel closed") return } @@ -779,6 +801,11 @@ func (w *sessionWatcher) start() { }() } +func (w *sessionWatcher) Stop() { + w.cancel() + w.wg.Wait() +} + // WatchServices watches the service's up and down in etcd, and sends event to // eventChannel.
// prefix is a parameter to know which service to watch and can be obtained in @@ -787,17 +814,19 @@ func (w *sessionWatcher) start() { // in GetSessions. // If a server up, an event will be add to channel with eventType SessionAddType. // If a server down, an event will be add to channel with eventType SessionDelType. -func (s *Session) WatchServices(prefix string, revision int64, rewatch Rewatch) (eventChannel <-chan *SessionEvent) { +func (s *Session) WatchServices(prefix string, revision int64, rewatch Rewatch) (watcher SessionWatcher) { + ctx, cancel := context.WithCancel(s.ctx) w := &sessionWatcher{ s: s, + cancel: cancel, eventCh: make(chan *SessionEvent, 100), rch: s.etcdCli.Watch(s.ctx, path.Join(s.metaRoot, DefaultServiceRoot, prefix), clientv3.WithPrefix(), clientv3.WithPrevKV(), clientv3.WithRev(revision)), prefix: prefix, rewatch: rewatch, validate: func(s *Session) bool { return true }, } - w.start() - return w.eventCh + w.start(ctx) + return w } // WatchServicesWithVersionRange watches the service's up and down in etcd, and sends event to event Channel. @@ -806,17 +835,19 @@ func (s *Session) WatchServices(prefix string, revision int64, rewatch Rewatch) // revision is a etcd reversion to prevent missing key events and can be obtained in GetSessions. // If a server up, an event will be add to channel with eventType SessionAddType. // If a server down, an event will be add to channel with eventType SessionDelType. 
-func (s *Session) WatchServicesWithVersionRange(prefix string, r semver.Range, revision int64, rewatch Rewatch) (eventChannel <-chan *SessionEvent) { +func (s *Session) WatchServicesWithVersionRange(prefix string, r semver.Range, revision int64, rewatch Rewatch) (watcher SessionWatcher) { + ctx, cancel := context.WithCancel(s.ctx) w := &sessionWatcher{ s: s, + cancel: cancel, eventCh: make(chan *SessionEvent, 100), rch: s.etcdCli.Watch(s.ctx, path.Join(s.metaRoot, DefaultServiceRoot, prefix), clientv3.WithPrefix(), clientv3.WithPrevKV(), clientv3.WithRev(revision)), prefix: prefix, rewatch: rewatch, validate: func(s *Session) bool { return r(s.Version) }, } - w.start() - return w.eventCh + w.start(ctx) + return w } func (w *sessionWatcher) handleWatchResponse(wresp clientv3.WatchResponse) { @@ -875,14 +906,14 @@ func (w *sessionWatcher) handleWatchErr(err error) error { if err != v3rpc.ErrCompacted { // close event channel log.Warn("Watch service found error", zap.Error(err)) - close(w.eventCh) + w.closeEventCh() return err } sessions, revision, err := w.s.GetSessions(w.prefix) if err != nil { log.Warn("GetSession before rewatch failed", zap.String("prefix", w.prefix), zap.Error(err)) - close(w.eventCh) + w.closeEventCh() return err } // rewatch is nil, no logic to handle @@ -893,7 +924,7 @@ func (w *sessionWatcher) handleWatchErr(err error) error { } if err != nil { log.Warn("WatchServices rewatch failed", zap.String("prefix", w.prefix), zap.Error(err)) - close(w.eventCh) + w.closeEventCh() return err } @@ -901,6 +932,10 @@ func (w *sessionWatcher) handleWatchErr(err error) error { return nil } +func (w *sessionWatcher) EventChannel() <-chan *SessionEvent { + return w.eventCh +} + // LivenessCheck performs liveness check with provided context and channel // ctx controls the liveness check loop // ch is the liveness signal channel, ch is closed only when the session is expired @@ -909,7 +944,7 @@ func (s *Session) LivenessCheck(ctx context.Context, callback 
func()) { err := s.initWatchSessionCh(ctx) if err != nil { log.Error("failed to get session for liveness check", zap.Error(err)) - s.cancelKeepAlive() + s.cancelKeepAlive(true) if callback != nil { go callback() } @@ -942,12 +977,12 @@ func (s *Session) LivenessCheck(ctx context.Context, callback func()) { case <-ctx.Done(): log.Warn("liveness exits due to context done") // cancel the etcd keepAlive context - s.cancelKeepAlive() + s.cancelKeepAlive(true) return case resp, ok := <-s.watchSessionKeyCh: if !ok { log.Warn("watch session key channel closed") - s.cancelKeepAlive() + s.cancelKeepAlive(true) return } if resp.Err() != nil { @@ -955,14 +990,14 @@ func (s *Session) LivenessCheck(ctx context.Context, callback func()) { if resp.Err() != v3rpc.ErrCompacted { // close event channel log.Warn("Watch service found error", zap.Error(resp.Err())) - s.cancelKeepAlive() + s.cancelKeepAlive(true) return } log.Warn("Watch service found compacted error", zap.Error(resp.Err())) err := s.initWatchSessionCh(ctx) if err != nil { log.Warn("failed to get session during reconnecting", zap.Error(err)) - s.cancelKeepAlive() + s.cancelKeepAlive(true) } continue } @@ -972,7 +1007,7 @@ func (s *Session) LivenessCheck(ctx context.Context, callback func()) { log.Info("register session success", zap.String("role", s.ServerName), zap.String("key", string(event.Kv.Key))) case mvccpb.DELETE: log.Info("session key is deleted, exit...", zap.String("role", s.ServerName), zap.String("key", string(event.Kv.Key))) - s.cancelKeepAlive() + s.cancelKeepAlive(true) } } } @@ -980,18 +1015,10 @@ func (s *Session) LivenessCheck(ctx context.Context, callback func()) { }() } -func (s *Session) cancelKeepAlive() { - s.keepAliveLock.Lock() - defer s.keepAliveLock.Unlock() - if s.keepAliveCancel != nil { - s.keepAliveCancel() - } -} - func (s *Session) Stop() { - s.isStopped.Store(true) + log.Info("session stopping", zap.String("serverName", s.ServerName)) s.Revoke(time.Second) - s.cancelKeepAlive() + 
s.cancelKeepAlive(true) s.wg.Wait() } diff --git a/internal/util/sessionutil/session_util_test.go b/internal/util/sessionutil/session_util_test.go index d43fea206b..29d7c30ddd 100644 --- a/internal/util/sessionutil/session_util_test.go +++ b/internal/util/sessionutil/session_util_test.go @@ -158,7 +158,7 @@ func TestUpdateSessions(t *testing.T) { sessions, rev, err := s.GetSessions("test") assert.NoError(t, err) assert.Equal(t, len(sessions), 0) - eventCh := s.WatchServices("test", rev, nil) + watcher := s.WatchServices("test", rev, nil) sList := []*Session{} @@ -203,7 +203,7 @@ LOOP: select { case <-ch: t.FailNow() - case sessionEvent := <-eventCh: + case sessionEvent := <-watcher.EventChannel(): if sessionEvent.EventType == SessionAddEvent { addEventLen++ @@ -616,7 +616,7 @@ func (suite *SessionWithVersionSuite) TestWatchServicesWithVersionRange() { _, rev, err := s.GetSessionsWithVersionRange(suite.serverName, r) suite.Require().NoError(err) - ch := s.WatchServicesWithVersionRange(suite.serverName, r, rev, nil) + watcher := s.WatchServicesWithVersionRange(suite.serverName, r, rev, nil) // remove all sessions go func() { @@ -626,7 +626,7 @@ func (suite *SessionWithVersionSuite) TestWatchServicesWithVersionRange() { }() select { - case evt := <-ch: + case evt := <-watcher.EventChannel(): suite.Equal(suite.sessions[1].ServerID, evt.Session.ServerID) case <-time.After(time.Second): suite.Fail("no event received, failing") @@ -677,7 +677,7 @@ func TestSessionProcessActiveStandBy(t *testing.T) { log.Debug("Session 1 livenessCheck callback") flag = true close(signal) - s1.cancelKeepAlive() + s1.cancelKeepAlive(true) }) assert.False(t, s1.isStandby.Load().(bool)) @@ -1025,10 +1025,10 @@ func (s *SessionSuite) TestKeepAliveRetryActiveCancel() { ch, err := session.registerService() s.Require().NoError(err) session.liveCh = make(chan struct{}) - session.processKeepAliveResponse(ch) + session.startKeepAliveLoop(ch) session.LivenessCheck(ctx, nil) // active cancel, should 
not retry connect - session.cancelKeepAlive() + session.cancelKeepAlive(true) // wait workers exit session.wg.Wait() @@ -1049,7 +1049,7 @@ func (s *SessionSuite) TestKeepAliveRetryChannelClose() { session.liveCh = make(chan struct{}) closeChan := make(chan *clientv3.LeaseKeepAliveResponse) sendChan := (<-chan *clientv3.LeaseKeepAliveResponse)(closeChan) - session.processKeepAliveResponse(sendChan) + session.startKeepAliveLoop(sendChan) session.LivenessCheck(ctx, nil) // close channel, should retry connect close(closeChan) @@ -1092,3 +1092,42 @@ func (s *SessionSuite) TestGetSessions() { func TestSessionSuite(t *testing.T) { suite.Run(t, new(SessionSuite)) } + +func (s *SessionSuite) TestKeepAliveCancelWithoutStop() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + session := NewSessionWithEtcd(ctx, s.metaRoot, s.client) + session.Init("test", "normal", false, false) + _, err := session.registerService() + assert.NoError(s.T(), err) + + // Override liveCh and LeaseKeepAliveResponse channel for testing + session.liveCh = make(chan struct{}) + kaCh := make(chan *clientv3.LeaseKeepAliveResponse) + session.startKeepAliveLoop(kaCh) + + session.keepAliveMu.Lock() + cancelOld := session.keepAliveCancel + session.keepAliveCancel = func() { + // only cancel, not setting isStopped, to simulate not "stop" + } + session.keepAliveMu.Unlock() + if cancelOld != nil { + cancelOld() + } + + // send a nil (simulate closed keepalive channel) + go func() { + kaCh <- nil + }() + + // Give time for retry logic to trigger + time.Sleep(200 * time.Millisecond) + + // should not be disconnected, session could recover + assert.False(s.T(), session.Disconnected()) + + // Routine clean up + session.Stop() +} diff --git a/pkg/streaming/walimpls/impls/wp/builder.go b/pkg/streaming/walimpls/impls/wp/builder.go index c9b85a8874..0b0d89556a 100644 --- a/pkg/streaming/walimpls/impls/wp/builder.go +++ b/pkg/streaming/walimpls/impls/wp/builder.go @@ -167,7 +167,8 @@ func 
(b *builderImpl) getEtcdClient(ctx context.Context) (*clientv3.Client, erro etcdConfig.EtcdTLSCert.GetValue(), etcdConfig.EtcdTLSKey.GetValue(), etcdConfig.EtcdTLSCACert.GetValue(), - etcdConfig.EtcdTLSMinVersion.GetValue()) + etcdConfig.EtcdTLSMinVersion.GetValue(), + etcdConfig.ClientOptions()...) if err != nil { log.Warn("Woodpecker create connection to etcd failed", zap.Error(err)) return nil, err diff --git a/pkg/util/etcd/etcd_util.go b/pkg/util/etcd/etcd_util.go index dac7695ca7..5ad72c7b6d 100644 --- a/pkg/util/etcd/etcd_util.go +++ b/pkg/util/etcd/etcd_util.go @@ -38,6 +38,28 @@ import ( "github.com/milvus-io/milvus/pkg/v2/util" ) +type ClientOption func(*clientv3.Config) + +// WithDialKeepAlive configures gRPC keepalive and autosync behaviors for the etcd client. +func WithDialKeepAlive(dialKeepAliveTime, dialKeepAliveTimeout time.Duration) ClientOption { + return func(cfg *clientv3.Config) { + if dialKeepAliveTime > 0 { + cfg.DialKeepAliveTime = dialKeepAliveTime + } + if dialKeepAliveTimeout > 0 { + cfg.DialKeepAliveTimeout = dialKeepAliveTimeout + } + } +} + +func applyClientOptions(cfg *clientv3.Config, opts ...ClientOption) { + for _, opt := range opts { + if opt != nil { + opt(cfg) + } + } +} + // GetEtcdClient returns etcd client // should only used for test func GetEtcdClient( @@ -48,6 +70,7 @@ func GetEtcdClient( keyFile string, caCertFile string, minVersion string, + opts ...ClientOption, ) (*clientv3.Client, error) { log.Info("create etcd client", zap.Bool("useEmbedEtcd", useEmbedEtcd), @@ -58,24 +81,26 @@ func GetEtcdClient( return GetEmbedEtcdClient() } if useSSL { - return GetRemoteEtcdSSLClient(endpoints, certFile, keyFile, caCertFile, minVersion) + return GetRemoteEtcdSSLClient(endpoints, certFile, keyFile, caCertFile, minVersion, opts...) } - return GetRemoteEtcdClient(endpoints) + return GetRemoteEtcdClient(endpoints, opts...) 
} // GetRemoteEtcdClient returns client of remote etcd by given endpoints -func GetRemoteEtcdClient(endpoints []string) (*clientv3.Client, error) { - return clientv3.New(clientv3.Config{ +func GetRemoteEtcdClient(endpoints []string, opts ...ClientOption) (*clientv3.Client, error) { + cfg := clientv3.Config{ Endpoints: endpoints, DialTimeout: 5 * time.Second, DialOptions: []grpc.DialOption{ grpc.WithBlock(), }, - }) + } + applyClientOptions(&cfg, opts...) + return clientv3.New(cfg) } -func GetRemoteEtcdClientWithAuth(endpoints []string, userName, password string) (*clientv3.Client, error) { - return clientv3.New(clientv3.Config{ +func GetRemoteEtcdClientWithAuth(endpoints []string, userName, password string, opts ...ClientOption) (*clientv3.Client, error) { + cfg := clientv3.Config{ Endpoints: endpoints, DialTimeout: 5 * time.Second, Username: userName, @@ -83,15 +108,17 @@ func GetRemoteEtcdClientWithAuth(endpoints []string, userName, password string) DialOptions: []grpc.DialOption{ grpc.WithBlock(), }, - }) + } + applyClientOptions(&cfg, opts...) + return clientv3.New(cfg) } -func GetRemoteEtcdSSLClient(endpoints []string, certFile string, keyFile string, caCertFile string, minVersion string) (*clientv3.Client, error) { +func GetRemoteEtcdSSLClient(endpoints []string, certFile string, keyFile string, caCertFile string, minVersion string, opts ...ClientOption) (*clientv3.Client, error) { var cfg clientv3.Config - return GetRemoteEtcdSSLClientWithCfg(endpoints, certFile, keyFile, caCertFile, minVersion, cfg) + return GetRemoteEtcdSSLClientWithCfg(endpoints, certFile, keyFile, caCertFile, minVersion, cfg, opts...) 
} -func GetRemoteEtcdSSLClientWithCfg(endpoints []string, certFile string, keyFile string, caCertFile string, minVersion string, cfg clientv3.Config) (*clientv3.Client, error) { +func GetRemoteEtcdSSLClientWithCfg(endpoints []string, certFile string, keyFile string, caCertFile string, minVersion string, cfg clientv3.Config, opts ...ClientOption) (*clientv3.Client, error) { cfg.Endpoints = endpoints cfg.DialTimeout = 5 * time.Second cert, err := tls.LoadX509KeyPair(certFile, keyFile) @@ -130,6 +157,7 @@ func GetRemoteEtcdSSLClientWithCfg(endpoints []string, certFile string, keyFile } cfg.DialOptions = append(cfg.DialOptions, grpc.WithBlock()) + applyClientOptions(&cfg, opts...) return clientv3.New(cfg) } @@ -145,18 +173,19 @@ func CreateEtcdClient( keyFile string, caCertFile string, minVersion string, + opts ...ClientOption, ) (*clientv3.Client, error) { if !enableAuth || useEmbedEtcd { - return GetEtcdClient(useEmbedEtcd, useSSL, endpoints, certFile, keyFile, caCertFile, minVersion) + return GetEtcdClient(useEmbedEtcd, useSSL, endpoints, certFile, keyFile, caCertFile, minVersion, opts...) } log.Info("create etcd client(enable auth)", zap.Bool("useSSL", useSSL), zap.Any("endpoints", endpoints), zap.String("minVersion", minVersion)) if useSSL { - return GetRemoteEtcdSSLClientWithCfg(endpoints, certFile, keyFile, caCertFile, minVersion, clientv3.Config{Username: userName, Password: password}) + return GetRemoteEtcdSSLClientWithCfg(endpoints, certFile, keyFile, caCertFile, minVersion, clientv3.Config{Username: userName, Password: password}, opts...) } - return GetRemoteEtcdClientWithAuth(endpoints, userName, password) + return GetRemoteEtcdClientWithAuth(endpoints, userName, password, opts...) 
} func min(a, b int) int { diff --git a/pkg/util/paramtable/component_param.go b/pkg/util/paramtable/component_param.go index ca5108c0af..0331b25013 100644 --- a/pkg/util/paramtable/component_param.go +++ b/pkg/util/paramtable/component_param.go @@ -47,7 +47,7 @@ const ( DefaultMiddlePriorityThreadCoreCoefficient = 5 DefaultLowPriorityThreadCoreCoefficient = 1 - DefaultSessionTTL = 30 // s + DefaultSessionTTL = 10 // s DefaultSessionRetryTimes = 30 DefaultMaxDegree = 56 diff --git a/pkg/util/paramtable/service_param.go b/pkg/util/paramtable/service_param.go index ee3f28579b..c556860ae9 100644 --- a/pkg/util/paramtable/service_param.go +++ b/pkg/util/paramtable/service_param.go @@ -23,12 +23,14 @@ import ( "path" "strconv" "strings" + "time" "go.uber.org/zap" "github.com/milvus-io/milvus/pkg/v2/log" "github.com/milvus-io/milvus/pkg/v2/metrics" "github.com/milvus-io/milvus/pkg/v2/util" + "github.com/milvus-io/milvus/pkg/v2/util/etcd" "github.com/milvus-io/milvus/pkg/v2/util/metricsinfo" ) @@ -90,20 +92,22 @@ func (p *ServiceParam) WoodpeckerEnable() bool { // --- etcd --- type EtcdConfig struct { // --- ETCD --- - Endpoints ParamItem `refreshable:"false"` - RootPath ParamItem `refreshable:"false"` - MetaSubPath ParamItem `refreshable:"false"` - KvSubPath ParamItem `refreshable:"false"` - MetaRootPath CompositeParamItem `refreshable:"false"` - KvRootPath CompositeParamItem `refreshable:"false"` - EtcdLogLevel ParamItem `refreshable:"false"` - EtcdLogPath ParamItem `refreshable:"false"` - EtcdUseSSL ParamItem `refreshable:"false"` - EtcdTLSCert ParamItem `refreshable:"false"` - EtcdTLSKey ParamItem `refreshable:"false"` - EtcdTLSCACert ParamItem `refreshable:"false"` - EtcdTLSMinVersion ParamItem `refreshable:"false"` - RequestTimeout ParamItem `refreshable:"false"` + Endpoints ParamItem `refreshable:"false"` + RootPath ParamItem `refreshable:"false"` + MetaSubPath ParamItem `refreshable:"false"` + KvSubPath ParamItem `refreshable:"false"` + MetaRootPath 
CompositeParamItem `refreshable:"false"` + KvRootPath CompositeParamItem `refreshable:"false"` + EtcdLogLevel ParamItem `refreshable:"false"` + EtcdLogPath ParamItem `refreshable:"false"` + EtcdUseSSL ParamItem `refreshable:"false"` + EtcdTLSCert ParamItem `refreshable:"false"` + EtcdTLSKey ParamItem `refreshable:"false"` + EtcdTLSCACert ParamItem `refreshable:"false"` + EtcdTLSMinVersion ParamItem `refreshable:"false"` + RequestTimeout ParamItem `refreshable:"false"` + DialKeepAliveTime ParamItem `refreshable:"false"` + DialKeepAliveTimeout ParamItem `refreshable:"false"` // --- Embed ETCD --- UseEmbedEtcd ParamItem `refreshable:"false"` @@ -286,6 +290,24 @@ We recommend using version 1.2 and above.`, } p.RequestTimeout.Init(base.mgr) + p.DialKeepAliveTime = ParamItem{ + Key: "etcd.dialKeepAliveTime", + DefaultValue: "3000", + Version: "2.6.6", + Doc: `Interval in milliseconds for gRPC dial keepalive pings sent to etcd endpoints.`, + Export: true, + } + p.DialKeepAliveTime.Init(base.mgr) + + p.DialKeepAliveTimeout = ParamItem{ + Key: "etcd.dialKeepAliveTimeout", + DefaultValue: "2000", + Version: "2.6.6", + Doc: `Timeout in milliseconds waiting for keepalive responses before marking the connection as unhealthy.`, + Export: true, + } + p.DialKeepAliveTimeout.Init(base.mgr) + p.EtcdEnableAuth = ParamItem{ Key: "etcd.auth.enabled", DefaultValue: "false", @@ -318,17 +340,31 @@ We recommend using version 1.2 and above.`, func (p *EtcdConfig) GetAll() map[string]string { return map[string]string{ - "etcd.endpoints": p.Endpoints.GetValue(), - "etcd.metaRootPath": p.MetaRootPath.GetValue(), - "etcd.ssl.enabled": p.EtcdUseSSL.GetValue(), - "etcd.ssl.tlsCert": p.EtcdTLSCert.GetValue(), - "etcd.ssl.tlsKey": p.EtcdTLSKey.GetValue(), - "etcd.ssl.tlsCACert": p.EtcdTLSCACert.GetValue(), - "etcd.ssl.tlsMinVersion": p.EtcdTLSMinVersion.GetValue(), - "etcd.requestTimeout": p.RequestTimeout.GetValue(), - "etcd.auth.enabled": p.EtcdEnableAuth.GetValue(), - "etcd.auth.userName": 
p.EtcdAuthUserName.GetValue(), - "etcd.auth.password": p.EtcdAuthPassword.GetValue(), + "etcd.endpoints": p.Endpoints.GetValue(), + "etcd.metaRootPath": p.MetaRootPath.GetValue(), + "etcd.ssl.enabled": p.EtcdUseSSL.GetValue(), + "etcd.ssl.tlsCert": p.EtcdTLSCert.GetValue(), + "etcd.ssl.tlsKey": p.EtcdTLSKey.GetValue(), + "etcd.ssl.tlsCACert": p.EtcdTLSCACert.GetValue(), + "etcd.ssl.tlsMinVersion": p.EtcdTLSMinVersion.GetValue(), + "etcd.requestTimeout": p.RequestTimeout.GetValue(), + "etcd.dialKeepAliveTime": p.DialKeepAliveTime.GetValue(), + "etcd.dialKeepAliveTimeout": p.DialKeepAliveTimeout.GetValue(), + "etcd.auth.enabled": p.EtcdEnableAuth.GetValue(), + "etcd.auth.userName": p.EtcdAuthUserName.GetValue(), + "etcd.auth.password": p.EtcdAuthPassword.GetValue(), + } +} + +func (p *EtcdConfig) ClientOptions() []etcd.ClientOption { + dialKeepAliveTime := p.DialKeepAliveTime.GetAsDuration(time.Millisecond) + dialKeepAliveTimeout := p.DialKeepAliveTimeout.GetAsDuration(time.Millisecond) + + if dialKeepAliveTime <= 0 && dialKeepAliveTimeout <= 0 { + return nil + } + return []etcd.ClientOption{ + etcd.WithDialKeepAlive(dialKeepAliveTime, dialKeepAliveTimeout), } }