diff --git a/Makefile b/Makefile index b93df0012f..1e7947555d 100644 --- a/Makefile +++ b/Makefile @@ -533,6 +533,7 @@ generate-mockery-utils: getdeps # tso.Allocator $(INSTALL_PATH)/mockery --name=Allocator --dir=internal/tso --output=internal/tso/mocks --filename=allocator.go --with-expecter --structname=Allocator --outpkg=mocktso $(INSTALL_PATH)/mockery --name=SessionInterface --dir=$(PWD)/internal/util/sessionutil --output=$(PWD)/internal/util/sessionutil --filename=mock_session.go --with-expecter --structname=MockSession --inpackage + $(INSTALL_PATH)/mockery --name=SessionWatcher --dir=$(PWD)/internal/util/sessionutil --output=$(PWD)/internal/util/sessionutil --filename=mock_session_watcher.go --with-expecter --structname=MockSessionWatcher --inpackage $(INSTALL_PATH)/mockery --name=GrpcClient --dir=$(PWD)/internal/util/grpcclient --output=$(PWD)/internal/mocks --filename=mock_grpc_client.go --with-expecter --structname=MockGrpcClient # proxy_client_manager.go $(INSTALL_PATH)/mockery --name=ProxyClientManagerInterface --dir=$(PWD)/internal/util/proxyutil --output=$(PWD)/internal/util/proxyutil --filename=mock_proxy_client_manager.go --with-expecter --structname=MockProxyClientManager --inpackage diff --git a/cmd/tools/migration/migration/runner.go b/cmd/tools/migration/migration/runner.go index 23bc708fcd..89defd88de 100644 --- a/cmd/tools/migration/migration/runner.go +++ b/cmd/tools/migration/migration/runner.go @@ -49,12 +49,12 @@ func (r *Runner) watchByPrefix(prefix string) { _, revision, err := r.session.GetSessions(prefix) fn := func() { r.Stop() } console.AbnormalExitIf(err, r.backupFinished.Load(), console.AddCallbacks(fn)) - eventCh := r.session.WatchServices(prefix, revision, nil) + watcher := r.session.WatchServices(prefix, revision, nil) for { select { case <-r.ctx.Done(): return - case event := <-eventCh: + case event := <-watcher.EventChannel(): msg := fmt.Sprintf("session up/down, exit migration, event type: %s, session: %s", event.EventType.String(), event.Session.String()) console.AbnormalExit(r.backupFinished.Load(), msg, console.AddCallbacks(fn)) } diff --git a/internal/allocator/mock_global_id_allocator.go b/internal/allocator/mock_global_id_allocator.go index b1f0b4d246..aedaee9710 100644 --- a/internal/allocator/mock_global_id_allocator.go +++ b/internal/allocator/mock_global_id_allocator.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.32.4. DO NOT EDIT. +// Code generated by mockery v2.53.3. DO NOT EDIT. package allocator @@ -21,6 +21,10 @@ func (_m *MockGlobalIDAllocator) EXPECT() *MockGlobalIDAllocator_Expecter { func (_m *MockGlobalIDAllocator) Alloc(count uint32) (int64, int64, error) { ret := _m.Called(count) + if len(ret) == 0 { + panic("no return value specified for Alloc") + } + var r0 int64 var r1 int64 var r2 error @@ -76,10 +80,14 @@ func (_c *MockGlobalIDAllocator_Alloc_Call) RunAndReturn(run func(uint32) (int64 return _c } -// AllocOne provides a mock function with given fields: +// AllocOne provides a mock function with no fields func (_m *MockGlobalIDAllocator) AllocOne() (int64, error) { ret := _m.Called() + if len(ret) == 0 { + panic("no return value specified for AllocOne") + } + var r0 int64 var r1 error if rf, ok := ret.Get(0).(func() (int64, error)); ok { @@ -127,10 +135,14 @@ func (_c *MockGlobalIDAllocator_AllocOne_Call) RunAndReturn(run func() (int64, e return _c } -// Initialize provides a mock function with given fields: +// Initialize provides a mock function with no fields func (_m *MockGlobalIDAllocator) Initialize() error { ret := _m.Called() + if len(ret) == 0 { + panic("no return value specified for Initialize") + } + var r0 error if rf, ok := ret.Get(0).(func() error); ok { r0 = rf() diff --git a/internal/datacoord/server.go b/internal/datacoord/server.go index 96bfe59ba9..1f8a18de00 100644 --- a/internal/datacoord/server.go +++ b/internal/datacoord/server.go @@ -141,11 +141,10 @@ type Server struct { notifyIndexChan chan UniqueID factory dependency.Factory - session sessionutil.SessionInterface - icSession sessionutil.SessionInterface - dnEventCh <-chan *sessionutil.SessionEvent - // qcEventCh <-chan *sessionutil.SessionEvent - qnEventCh <-chan *sessionutil.SessionEvent + session sessionutil.SessionInterface + icSession sessionutil.SessionInterface + dnSessionWatcher sessionutil.SessionWatcher + qnSessionWatcher sessionutil.SessionWatcher enableActiveStandBy bool activateFunc func() error @@ -532,7 +531,7 @@ func (s *Server) initServiceDiscovery() error { } log.Info("DataCoord Cluster Manager start up successfully") - s.dnEventCh = s.session.WatchServicesWithVersionRange(typeutil.DataNodeRole, r, rev+1, s.rewatchDataNodes) + s.dnSessionWatcher = s.session.WatchServicesWithVersionRange(typeutil.DataNodeRole, r, rev+1, s.rewatchDataNodes) } s.indexEngineVersionManager = newIndexEngineVersionManager() @@ -542,7 +541,7 @@ func (s *Server) initServiceDiscovery() error { return err } s.rewatchQueryNodes(qnSessions) - s.qnEventCh = s.session.WatchServicesWithVersionRange(typeutil.QueryNodeRole, r, qnRevision+1, s.rewatchQueryNodes) + s.qnSessionWatcher = s.session.WatchServicesWithVersionRange(typeutil.QueryNodeRole, r, qnRevision+1, s.rewatchQueryNodes) return nil } @@ -799,7 +798,7 @@ func (s *Server) watchService(ctx context.Context) { case <-ctx.Done(): log.Info("watch service shutdown") return - case event, ok := <-s.dnEventCh: + case event, ok := <-s.dnSessionWatcher.EventChannel(): if !ok { s.stopServiceWatch() return @@ -812,7 +811,7 @@ func (s *Server) watchService(ctx context.Context) { }() return } - case event, ok := <-s.qnEventCh: + case event, ok := <-s.qnSessionWatcher.EventChannel(): if !ok { s.stopServiceWatch() return @@ -1054,6 +1053,14 @@ func (s *Server) Stop() error { s.analyzeInspector.Stop() log.Info("datacoord analyze inspector stopped") + if s.dnSessionWatcher != nil { + s.dnSessionWatcher.Stop() + } + + if s.qnSessionWatcher != nil { + s.qnSessionWatcher.Stop() + } + if s.session != nil { s.session.Stop() } diff --git a/internal/datacoord/server_test.go b/internal/datacoord/server_test.go index ceb65b258c..610f108c52 100644 --- a/internal/datacoord/server_test.go +++ b/internal/datacoord/server_test.go @@ -735,7 +735,12 @@ func TestService_WatchServices(t *testing.T) { svr.serverLoopWg.Add(1) ech := make(chan *sessionutil.SessionEvent) - svr.dnEventCh = ech + mockDnWatcher := sessionutil.NewMockSessionWatcher(t) + mockDnWatcher.EXPECT().EventChannel().Return(ech) + svr.dnSessionWatcher = mockDnWatcher + mockQnWatcher := sessionutil.NewMockSessionWatcher(t) + mockQnWatcher.EXPECT().EventChannel().Return(nil) + svr.qnSessionWatcher = mockQnWatcher flag := false closed := false @@ -762,7 +767,9 @@ func TestService_WatchServices(t *testing.T) { ech = make(chan *sessionutil.SessionEvent) flag = false - svr.dnEventCh = ech + mockDnWatcher = sessionutil.NewMockSessionWatcher(t) + mockDnWatcher.EXPECT().EventChannel().Return(ech) + svr.dnSessionWatcher = mockDnWatcher ctx, cancel := context.WithCancel(context.Background()) svr.serverLoopWg.Add(1) diff --git a/internal/distributed/connection_manager.go b/internal/distributed/connection_manager.go index 18c568c79a..0730b498e7 100644 --- a/internal/distributed/connection_manager.go +++ b/internal/distributed/connection_manager.go @@ -119,8 +119,8 @@ func (cm *ConnectionManager) AddDependency(roleName string) error { } } - eventChannel := cm.session.WatchServices(roleName, rev, nil) - go cm.processEvent(eventChannel) + watcher := cm.session.WatchServices(roleName, rev, nil) + go cm.processEvent(watcher.EventChannel()) return nil } diff --git a/internal/mocks/mock_grpc_client.go b/internal/mocks/mock_grpc_client.go index 383f8ff330..bcf4589f0e 100644 --- a/internal/mocks/mock_grpc_client.go +++ b/internal/mocks/mock_grpc_client.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.32.4. DO NOT EDIT. +// Code generated by mockery v2.53.3. DO NOT EDIT. package mocks @@ -33,6 +33,10 @@ func (_m *MockGrpcClient[T]) EXPECT() *MockGrpcClient_Expecter[T] { func (_m *MockGrpcClient[T]) Call(ctx context.Context, caller func(T) (interface{}, error)) (interface{}, error) { ret := _m.Called(ctx, caller) + if len(ret) == 0 { + panic("no return value specified for Call") + } + var r0 interface{} var r1 error if rf, ok := ret.Get(0).(func(context.Context, func(T) (interface{}, error)) (interface{}, error)); ok { @@ -84,10 +88,14 @@ func (_c *MockGrpcClient_Call_Call[T]) RunAndReturn(run func(context.Context, fu return _c } -// Close provides a mock function with given fields: +// Close provides a mock function with no fields func (_m *MockGrpcClient[T]) Close() error { ret := _m.Called() + if len(ret) == 0 { + panic("no return value specified for Close") + } + var r0 error if rf, ok := ret.Get(0).(func() error); ok { r0 = rf() @@ -125,7 +133,7 @@ func (_c *MockGrpcClient_Close_Call[T]) RunAndReturn(run func() error) *MockGrpc return _c } -// EnableEncryption provides a mock function with given fields: +// EnableEncryption provides a mock function with no fields func (_m *MockGrpcClient[T]) EnableEncryption() { _m.Called() } @@ -153,14 +161,18 @@ func (_c *MockGrpcClient_EnableEncryption_Call[T]) Return() *MockGrpcClient_Enab } func (_c *MockGrpcClient_EnableEncryption_Call[T]) RunAndReturn(run func()) *MockGrpcClient_EnableEncryption_Call[T] { - _c.Call.Return(run) + _c.Run(run) return _c } -// GetNodeID provides a mock function with given fields: +// GetNodeID provides a mock function with no fields func (_m *MockGrpcClient[T]) GetNodeID() int64 { ret := _m.Called() + if len(ret) == 0 { + panic("no return value specified for GetNodeID") + } + var r0 int64 if rf, ok := ret.Get(0).(func() int64); ok { r0 = rf() @@ -198,10 +210,14 @@ func (_c *MockGrpcClient_GetNodeID_Call[T]) RunAndReturn(run func() int64) *Mock return _c } -// GetRole provides a mock function with given fields: +// GetRole provides a mock function with no fields func (_m *MockGrpcClient[T]) GetRole() string { ret := _m.Called() + if len(ret) == 0 { + panic("no return value specified for GetRole") + } + var r0 string if rf, ok := ret.Get(0).(func() string); ok { r0 = rf() @@ -243,6 +259,10 @@ func (_c *MockGrpcClient_GetRole_Call[T]) RunAndReturn(run func() string) *MockG func (_m *MockGrpcClient[T]) ReCall(ctx context.Context, caller func(T) (interface{}, error)) (interface{}, error) { ret := _m.Called(ctx, caller) + if len(ret) == 0 { + panic("no return value specified for ReCall") + } + var r0 interface{} var r1 error if rf, ok := ret.Get(0).(func(context.Context, func(T) (interface{}, error)) (interface{}, error)); ok { @@ -323,7 +343,7 @@ func (_c *MockGrpcClient_SetGetAddrFunc_Call[T]) Return() *MockGrpcClient_SetGet } func (_c *MockGrpcClient_SetGetAddrFunc_Call[T]) RunAndReturn(run func(func() (string, error))) *MockGrpcClient_SetGetAddrFunc_Call[T] { - _c.Call.Return(run) + _c.Run(run) return _c } @@ -356,7 +376,7 @@ func (_c *MockGrpcClient_SetInternalTLSCertPool_Call[T]) Return() *MockGrpcClien } func (_c *MockGrpcClient_SetInternalTLSCertPool_Call[T]) RunAndReturn(run func(*x509.CertPool)) *MockGrpcClient_SetInternalTLSCertPool_Call[T] { - _c.Call.Return(run) + _c.Run(run) return _c } @@ -389,7 +409,7 @@ func (_c *MockGrpcClient_SetInternalTLSServerName_Call[T]) Return() *MockGrpcCli } func (_c *MockGrpcClient_SetInternalTLSServerName_Call[T]) RunAndReturn(run func(string)) *MockGrpcClient_SetInternalTLSServerName_Call[T] { - _c.Call.Return(run) + _c.Run(run) return _c } @@ -422,7 +442,7 @@ func (_c *MockGrpcClient_SetNewGrpcClientFunc_Call[T]) Return() *MockGrpcClient_ } func (_c *MockGrpcClient_SetNewGrpcClientFunc_Call[T]) RunAndReturn(run func(func(*grpc.ClientConn) T)) *MockGrpcClient_SetNewGrpcClientFunc_Call[T] { - _c.Call.Return(run) + _c.Run(run) return _c } @@ -455,7 +475,7 @@ func (_c *MockGrpcClient_SetNodeID_Call[T]) Return() *MockGrpcClient_SetNodeID_C } func (_c *MockGrpcClient_SetNodeID_Call[T]) RunAndReturn(run func(int64)) *MockGrpcClient_SetNodeID_Call[T] { - _c.Call.Return(run) + _c.Run(run) return _c } @@ -488,7 +508,7 @@ func (_c *MockGrpcClient_SetRole_Call[T]) Return() *MockGrpcClient_SetRole_Call[ } func (_c *MockGrpcClient_SetRole_Call[T]) RunAndReturn(run func(string)) *MockGrpcClient_SetRole_Call[T] { - _c.Call.Return(run) + _c.Run(run) return _c } @@ -521,7 +541,7 @@ func (_c *MockGrpcClient_SetSession_Call[T]) Return() *MockGrpcClient_SetSession } func (_c *MockGrpcClient_SetSession_Call[T]) RunAndReturn(run func(*sessionutil.Session)) *MockGrpcClient_SetSession_Call[T] { - _c.Call.Return(run) + _c.Run(run) return _c } diff --git a/internal/querycoordv2/server.go b/internal/querycoordv2/server.go index f2abdddacd..51567ac5d4 100644 --- a/internal/querycoordv2/server.go +++ b/internal/querycoordv2/server.go @@ -79,6 +79,8 @@ type Server struct { tikvCli *txnkv.Client address string session sessionutil.SessionInterface + sessionWatcher sessionutil.SessionWatcher + sessionWatcherMu sync.Mutex kv kv.MetaKv idAllocator func() (int64, error) metricsCacheManager *metricsinfo.MetricsCacheManager @@ -593,12 +595,19 @@ func (s *Server) Stop() error { s.cluster.Stop() } + s.sessionWatcherMu.Lock() + if s.sessionWatcher != nil { + s.sessionWatcher.Stop() + } + s.sessionWatcherMu.Unlock() + + s.cancel() + s.wg.Wait() + if s.session != nil { s.session.Stop() } - s.cancel() - s.wg.Wait() log.Info("QueryCoord stop successfully") return nil } @@ -638,14 +647,16 @@ func (s *Server) watchNodes(revision int64) { log := log.Ctx(s.ctx) defer s.wg.Done() - eventChan := s.session.WatchServices(typeutil.QueryNodeRole, revision+1, s.rewatchNodes) + s.sessionWatcherMu.Lock() + s.sessionWatcher = s.session.WatchServices(typeutil.QueryNodeRole, revision+1, s.rewatchNodes) + s.sessionWatcherMu.Unlock() for { select { case <-s.ctx.Done(): log.Info("stop watching nodes, QueryCoord stopped") return - case event, ok := <-eventChan: + case event, ok := <-s.sessionWatcher.EventChannel(): if !ok { // ErrCompacted is handled inside SessionWatcher log.Warn("Session Watcher channel closed", zap.Int64("serverID", paramtable.GetNodeID())) diff --git a/internal/tso/mocks/allocator.go b/internal/tso/mocks/allocator.go index 2d5d7308d9..f0a8a6244f 100644 --- a/internal/tso/mocks/allocator.go +++ b/internal/tso/mocks/allocator.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.32.4. DO NOT EDIT. +// Code generated by mockery v2.53.3. DO NOT EDIT. package mocktso @@ -25,6 +25,10 @@ func (_m *Allocator) EXPECT() *Allocator_Expecter { func (_m *Allocator) GenerateTSO(count uint32) (uint64, error) { ret := _m.Called(count) + if len(ret) == 0 { + panic("no return value specified for GenerateTSO") + } + var r0 uint64 var r1 error if rf, ok := ret.Get(0).(func(uint32) (uint64, error)); ok { @@ -73,10 +77,14 @@ func (_c *Allocator_GenerateTSO_Call) RunAndReturn(run func(uint32) (uint64, err return _c } -// GetLastSavedTime provides a mock function with given fields: +// GetLastSavedTime provides a mock function with no fields func (_m *Allocator) GetLastSavedTime() time.Time { ret := _m.Called() + if len(ret) == 0 { + panic("no return value specified for GetLastSavedTime") + } + var r0 time.Time if rf, ok := ret.Get(0).(func() time.Time); ok { r0 = rf() @@ -114,10 +122,14 @@ func (_c *Allocator_GetLastSavedTime_Call) RunAndReturn(run func() time.Time) *A return _c } -// Initialize provides a mock function with given fields: +// Initialize provides a mock function with no fields func (_m *Allocator) Initialize() error { ret := _m.Called() + if len(ret) == 0 { + panic("no return value specified for Initialize") + } + var r0 error if rf, ok := ret.Get(0).(func() error); ok { r0 = rf() @@ -155,7 +167,7 @@ func (_c *Allocator_Initialize_Call) RunAndReturn(run func() error) *Allocator_I return _c } -// Reset provides a mock function with given fields: +// Reset provides a mock function with no fields func (_m *Allocator) Reset() { _m.Called() } @@ -183,7 +195,7 @@ func (_c *Allocator_Reset_Call) Return() *Allocator_Reset_Call { } func (_c *Allocator_Reset_Call) RunAndReturn(run func()) *Allocator_Reset_Call { - _c.Call.Return(run) + _c.Run(run) return _c } @@ -191,6 +203,10 @@ func (_c *Allocator_Reset_Call) RunAndReturn(run func()) *Allocator_Reset_Call { func (_m *Allocator) SetTSO(_a0 uint64) error { ret := _m.Called(_a0) + if len(ret) == 0 { + panic("no return value specified for SetTSO") + } + var r0 error if rf, ok := ret.Get(0).(func(uint64) error); ok { r0 = rf(_a0) @@ -229,10 +245,14 @@ func (_c *Allocator_SetTSO_Call) RunAndReturn(run func(uint64) error) *Allocator return _c } -// UpdateTSO provides a mock function with given fields: +// UpdateTSO provides a mock function with no fields func (_m *Allocator) UpdateTSO() error { ret := _m.Called() + if len(ret) == 0 { + panic("no return value specified for UpdateTSO") + } + var r0 error if rf, ok := ret.Get(0).(func() error); ok { r0 = rf() diff --git a/internal/util/dependency/mock_factory.go b/internal/util/dependency/mock_factory.go index 444020a203..bcb43b5708 100644 --- a/internal/util/dependency/mock_factory.go +++ b/internal/util/dependency/mock_factory.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.32.4. DO NOT EDIT. +// Code generated by mockery v2.53.3. DO NOT EDIT. package dependency @@ -55,7 +55,7 @@ func (_c *MockFactory_Init_Call) Return() *MockFactory_Init_Call { } func (_c *MockFactory_Init_Call) RunAndReturn(run func(*paramtable.ComponentParam)) *MockFactory_Init_Call { - _c.Call.Return(run) + _c.Run(run) return _c } @@ -63,6 +63,10 @@ func (_c *MockFactory_Init_Call) RunAndReturn(run func(*paramtable.ComponentPara func (_m *MockFactory) NewMsgStream(ctx context.Context) (msgstream.MsgStream, error) { ret := _m.Called(ctx) + if len(ret) == 0 { + panic("no return value specified for NewMsgStream") + } + var r0 msgstream.MsgStream var r1 error if rf, ok := ret.Get(0).(func(context.Context) (msgstream.MsgStream, error)); ok { @@ -117,6 +121,10 @@ func (_c *MockFactory_NewMsgStream_Call) RunAndReturn(run func(context.Context) func (_m *MockFactory) NewMsgStreamDisposer(ctx context.Context) func([]string, string) error { ret := _m.Called(ctx) + if len(ret) == 0 { + panic("no return value specified for NewMsgStreamDisposer") + } + var r0 func([]string, string) error if rf, ok := ret.Get(0).(func(context.Context) func([]string, string) error); ok { r0 = rf(ctx) @@ -161,6 +169,10 @@ func (_c *MockFactory_NewMsgStreamDisposer_Call) RunAndReturn(run func(context.C func (_m *MockFactory) NewPersistentStorageChunkManager(ctx context.Context) (storage.ChunkManager, error) { ret := _m.Called(ctx) + if len(ret) == 0 { + panic("no return value specified for NewPersistentStorageChunkManager") + } + var r0 storage.ChunkManager var r1 error if rf, ok := ret.Get(0).(func(context.Context) (storage.ChunkManager, error)); ok { @@ -215,6 +227,10 @@ func (_c *MockFactory_NewPersistentStorageChunkManager_Call) RunAndReturn(run fu func (_m *MockFactory) NewTtMsgStream(ctx context.Context) (msgstream.MsgStream, error) { ret := _m.Called(ctx) + if len(ret) == 0 { + panic("no return value specified for NewTtMsgStream") + } + var r0 msgstream.MsgStream var r1 error if rf, ok := ret.Get(0).(func(context.Context) (msgstream.MsgStream, error)); ok { diff --git a/internal/util/proxyutil/mock_proxy_client_manager.go b/internal/util/proxyutil/mock_proxy_client_manager.go index e2244e6b76..1e729da365 100644 --- a/internal/util/proxyutil/mock_proxy_client_manager.go +++ b/internal/util/proxyutil/mock_proxy_client_manager.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.32.4. DO NOT EDIT. +// Code generated by mockery v2.53.3. DO NOT EDIT. package proxyutil @@ -59,7 +59,7 @@ func (_c *MockProxyClientManager_AddProxyClient_Call) Return() *MockProxyClientM } func (_c *MockProxyClientManager_AddProxyClient_Call) RunAndReturn(run func(*sessionutil.Session)) *MockProxyClientManager_AddProxyClient_Call { - _c.Call.Return(run) + _c.Run(run) return _c } @@ -92,7 +92,7 @@ func (_c *MockProxyClientManager_AddProxyClients_Call) Return() *MockProxyClient } func (_c *MockProxyClientManager_AddProxyClients_Call) RunAndReturn(run func([]*sessionutil.Session)) *MockProxyClientManager_AddProxyClients_Call { - _c.Call.Return(run) + _c.Run(run) return _c } @@ -125,7 +125,7 @@ func (_c *MockProxyClientManager_DelProxyClient_Call) Return() *MockProxyClientM } func (_c *MockProxyClientManager_DelProxyClient_Call) RunAndReturn(run func(*sessionutil.Session)) *MockProxyClientManager_DelProxyClient_Call { - _c.Call.Return(run) + _c.Run(run) return _c } @@ -133,6 +133,10 @@ func (_c *MockProxyClientManager_DelProxyClient_Call) RunAndReturn(run func(*ses func (_m *MockProxyClientManager) GetComponentStates(ctx context.Context) (map[int64]*milvuspb.ComponentStates, error) { ret := _m.Called(ctx) + if len(ret) == 0 { + panic("no return value specified for GetComponentStates") + } + var r0 map[int64]*milvuspb.ComponentStates var r1 error if rf, ok := ret.Get(0).(func(context.Context) (map[int64]*milvuspb.ComponentStates, error)); ok { @@ -183,10 +187,14 @@ func (_c *MockProxyClientManager_GetComponentStates_Call) RunAndReturn(run func( return _c } -// GetProxyClients provides a mock function with given fields: +// GetProxyClients provides a mock function with no fields func (_m *MockProxyClientManager) GetProxyClients() *typeutil.ConcurrentMap[int64, types.ProxyClient] { ret := _m.Called() + if len(ret) == 0 { + panic("no return value specified for GetProxyClients") + } + var r0 *typeutil.ConcurrentMap[int64, types.ProxyClient] if rf, ok := ret.Get(0).(func() *typeutil.ConcurrentMap[int64, types.ProxyClient]); ok { r0 = rf() @@ -226,10 +234,14 @@ func (_c *MockProxyClientManager_GetProxyClients_Call) RunAndReturn(run func() * return _c } -// GetProxyCount provides a mock function with given fields: +// GetProxyCount provides a mock function with no fields func (_m *MockProxyClientManager) GetProxyCount() int { ret := _m.Called() + if len(ret) == 0 { + panic("no return value specified for GetProxyCount") + } + var r0 int if rf, ok := ret.Get(0).(func() int); ok { r0 = rf() @@ -271,6 +283,10 @@ func (_c *MockProxyClientManager_GetProxyCount_Call) RunAndReturn(run func() int func (_m *MockProxyClientManager) GetProxyMetrics(ctx context.Context) ([]*milvuspb.GetMetricsResponse, error) { ret := _m.Called(ctx) + if len(ret) == 0 { + panic("no return value specified for GetProxyMetrics") + } + var r0 []*milvuspb.GetMetricsResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context) ([]*milvuspb.GetMetricsResponse, error)); ok { @@ -332,6 +348,10 @@ func (_m *MockProxyClientManager) InvalidateCollectionMetaCache(ctx context.Cont _ca = append(_ca, _va...) ret := _m.Called(_ca...) + if len(ret) == 0 { + panic("no return value specified for InvalidateCollectionMetaCache") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context, *proxypb.InvalidateCollMetaCacheRequest, ...ExpireCacheOpt) error); ok { r0 = rf(ctx, request, opts...) @@ -383,6 +403,10 @@ func (_c *MockProxyClientManager_InvalidateCollectionMetaCache_Call) RunAndRetur func (_m *MockProxyClientManager) InvalidateCredentialCache(ctx context.Context, request *proxypb.InvalidateCredCacheRequest) error { ret := _m.Called(ctx, request) + if len(ret) == 0 { + panic("no return value specified for InvalidateCredentialCache") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context, *proxypb.InvalidateCredCacheRequest) error); ok { r0 = rf(ctx, request) @@ -426,6 +450,10 @@ func (_c *MockProxyClientManager_InvalidateCredentialCache_Call) RunAndReturn(ru func (_m *MockProxyClientManager) InvalidateShardLeaderCache(ctx context.Context, request *proxypb.InvalidateShardLeaderCacheRequest) error { ret := _m.Called(ctx, request) + if len(ret) == 0 { + panic("no return value specified for InvalidateShardLeaderCache") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context, *proxypb.InvalidateShardLeaderCacheRequest) error); ok { r0 = rf(ctx, request) @@ -469,6 +497,10 @@ func (_c *MockProxyClientManager_InvalidateShardLeaderCache_Call) RunAndReturn(r func (_m *MockProxyClientManager) RefreshPolicyInfoCache(ctx context.Context, req *proxypb.RefreshPolicyInfoCacheRequest) error { ret := _m.Called(ctx, req) + if len(ret) == 0 { + panic("no return value specified for RefreshPolicyInfoCache") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context, *proxypb.RefreshPolicyInfoCacheRequest) error); ok { r0 = rf(ctx, req) @@ -512,6 +544,10 @@ func (_c *MockProxyClientManager_RefreshPolicyInfoCache_Call) RunAndReturn(run f func (_m *MockProxyClientManager) SetRates(ctx context.Context, request *proxypb.SetRatesRequest) error { ret := _m.Called(ctx, request) + if len(ret) == 0 { + panic("no return value specified for SetRates") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context, *proxypb.SetRatesRequest) error); ok { r0 = rf(ctx, request) @@ -555,6 +591,10 @@ func (_c *MockProxyClientManager_SetRates_Call) RunAndReturn(run func(context.Co func (_m *MockProxyClientManager) UpdateCredentialCache(ctx context.Context, request *proxypb.UpdateCredCacheRequest) error { ret := _m.Called(ctx, request) + if len(ret) == 0 { + panic("no return value specified for UpdateCredentialCache") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context, *proxypb.UpdateCredCacheRequest) error); ok { r0 = rf(ctx, request) diff --git a/internal/util/proxyutil/mock_proxy_watcher.go b/internal/util/proxyutil/mock_proxy_watcher.go index 0aed0cb5b6..4f79d65a8c 100644 --- a/internal/util/proxyutil/mock_proxy_watcher.go +++ b/internal/util/proxyutil/mock_proxy_watcher.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.32.4. DO NOT EDIT. +// Code generated by mockery v2.53.3. DO NOT EDIT. package proxyutil @@ -64,7 +64,7 @@ func (_c *MockProxyWatcher_AddSessionFunc_Call) Return() *MockProxyWatcher_AddSe } func (_c *MockProxyWatcher_AddSessionFunc_Call) RunAndReturn(run func(...func(*sessionutil.Session))) *MockProxyWatcher_AddSessionFunc_Call { - _c.Call.Return(run) + _c.Run(run) return _c } @@ -110,11 +110,11 @@ func (_c *MockProxyWatcher_DelSessionFunc_Call) Return() *MockProxyWatcher_DelSe } func (_c *MockProxyWatcher_DelSessionFunc_Call) RunAndReturn(run func(...func(*sessionutil.Session))) *MockProxyWatcher_DelSessionFunc_Call { - _c.Call.Return(run) + _c.Run(run) return _c } -// Stop provides a mock function with given fields: +// Stop provides a mock function with no fields func (_m *MockProxyWatcher) Stop() { _m.Called() } @@ -142,7 +142,7 @@ func (_c *MockProxyWatcher_Stop_Call) Return() *MockProxyWatcher_Stop_Call { } func (_c *MockProxyWatcher_Stop_Call) RunAndReturn(run func()) *MockProxyWatcher_Stop_Call { - _c.Call.Return(run) + _c.Run(run) return _c } @@ -150,6 +150,10 @@ func (_c *MockProxyWatcher_Stop_Call) RunAndReturn(run func()) *MockProxyWatcher func (_m *MockProxyWatcher) WatchProxy(ctx context.Context) error { ret := _m.Called(ctx) + if len(ret) == 0 { + panic("no return value specified for WatchProxy") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context) error); ok { r0 = rf(ctx) diff --git a/internal/util/sessionutil/mock_session.go b/internal/util/sessionutil/mock_session.go index a4d47020ec..d6bedd3402 100644 --- a/internal/util/sessionutil/mock_session.go +++ b/internal/util/sessionutil/mock_session.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.32.4. DO NOT EDIT. +// Code generated by mockery v2.53.3. DO NOT EDIT. package sessionutil @@ -24,10 +24,14 @@ func (_m *MockSession) EXPECT() *MockSession_Expecter { return &MockSession_Expecter{mock: &_m.Mock} } -// Disconnected provides a mock function with given fields: +// Disconnected provides a mock function with no fields func (_m *MockSession) Disconnected() bool { ret := _m.Called() + if len(ret) == 0 { + panic("no return value specified for Disconnected") + } + var r0 bool if rf, ok := ret.Get(0).(func() bool); ok { r0 = rf() @@ -69,6 +73,10 @@ func (_c *MockSession_Disconnected_Call) RunAndReturn(run func() bool) *MockSess func (_m *MockSession) ForceActiveStandby(activateFunc func() error) error { ret := _m.Called(activateFunc) + if len(ret) == 0 { + panic("no return value specified for ForceActiveStandby") + } + var r0 error if rf, ok := ret.Get(0).(func(func() error) error); ok { r0 = rf(activateFunc) @@ -107,10 +115,14 @@ func (_c *MockSession_ForceActiveStandby_Call) RunAndReturn(run func(func() erro return _c } -// GetAddress provides a mock function with given fields: +// GetAddress provides a mock function with no fields func (_m *MockSession) GetAddress() string { ret := _m.Called() + if len(ret) == 0 { + panic("no return value specified for GetAddress") + } + var r0 string if rf, ok := ret.Get(0).(func() string); ok { r0 = rf() @@ -148,10 +160,14 @@ func (_c *MockSession_GetAddress_Call) RunAndReturn(run func() string) *MockSess return _c } -// GetServerID provides a mock function with given fields: +// GetServerID provides a mock function with no fields func (_m *MockSession) GetServerID() int64 { ret := _m.Called() + if len(ret) == 0 { + panic("no return value specified for GetServerID") + } + var r0 int64 if rf, ok := ret.Get(0).(func() int64); ok { r0 = rf() @@ -193,6 +209,10 @@ func (_c *MockSession_GetServerID_Call) RunAndReturn(run func() int64) *MockSess func (_m *MockSession) GetSessions(prefix string) (map[string]*Session, int64, error) { ret := _m.Called(prefix) + if len(ret) == 0 { + panic("no return value specified for GetSessions") + } + var r0 map[string]*Session var r1 int64 var r2 error @@ -254,6 +274,10 @@ func (_c *MockSession_GetSessions_Call) RunAndReturn(run func(string) (map[strin func (_m *MockSession) GetSessionsWithVersionRange(prefix string, r semver.Range) (map[string]*Session, int64, error) { ret := _m.Called(prefix, r) + if len(ret) == 0 { + panic("no return value specified for GetSessionsWithVersionRange") + } + var r0 map[string]*Session var r1 int64 var r2 error @@ -312,10 +336,14 @@ func (_c *MockSession_GetSessionsWithVersionRange_Call) RunAndReturn(run func(st return _c } -// GoingStop provides a mock function with given fields: +// GoingStop provides a mock function with no fields func (_m *MockSession) GoingStop() error { ret := _m.Called() + if len(ret) == 0 { + panic("no return value specified for GoingStop") + } + var r0 error if rf, ok := ret.Get(0).(func() error); ok { r0 = rf() @@ -385,14 +413,18 @@ func (_c *MockSession_Init_Call) Return() *MockSession_Init_Call { } func (_c *MockSession_Init_Call) RunAndReturn(run func(string, string, bool, bool)) *MockSession_Init_Call { - _c.Call.Return(run) + _c.Run(run) return _c } -// IsTriggerKill provides a mock function with given fields: +// IsTriggerKill provides a mock function with no fields func (_m *MockSession) IsTriggerKill() bool { ret := _m.Called() + if len(ret) == 0 { + panic("no return value specified for IsTriggerKill") + } + var r0 bool if rf, ok := ret.Get(0).(func() bool); ok { r0 = rf() @@ -460,14 +492,18 @@ func (_c *MockSession_LivenessCheck_Call) Return() *MockSession_LivenessCheck_Ca } func (_c *MockSession_LivenessCheck_Call) RunAndReturn(run func(context.Context, func())) *MockSession_LivenessCheck_Call { - _c.Call.Return(run) + _c.Run(run) return _c } -// MarshalJSON provides a mock function with given fields: +// MarshalJSON provides a mock function with no fields func (_m *MockSession) MarshalJSON() ([]byte, error) { ret := _m.Called() + if len(ret) == 0 { + panic("no return value specified for MarshalJSON") + } + var r0 []byte var r1 error if rf, ok := ret.Get(0).(func() ([]byte, error)); ok { @@ -521,6 +557,10 @@ func (_c *MockSession_MarshalJSON_Call) RunAndReturn(run func() ([]byte, error)) func (_m *MockSession) ProcessActiveStandBy(activateFunc func() error) error { ret := _m.Called(activateFunc) + if len(ret) == 0 { + panic("no return value specified for ProcessActiveStandBy") + } + var r0 error if rf, ok := ret.Get(0).(func(func() error) error); ok { r0 = rf(activateFunc) @@ -559,7 +599,7 @@ func (_c *MockSession_ProcessActiveStandBy_Call) RunAndReturn(run func(func() er return _c } -// Register provides a mock function with given fields: +// Register provides a mock function with no fields func (_m *MockSession) Register() { _m.Called() } @@ -587,14 +627,18 @@ func (_c *MockSession_Register_Call) Return() *MockSession_Register_Call { } func (_c *MockSession_Register_Call) RunAndReturn(run func()) *MockSession_Register_Call { - _c.Call.Return(run) + _c.Run(run) return _c } -// Registered provides a mock function with given fields: +// Registered provides a mock function with no fields func (_m *MockSession) Registered() bool { ret := _m.Called() + if len(ret) == 0 { + panic("no return value specified for Registered") + } + var r0 bool if rf, ok := ret.Get(0).(func() bool); ok { r0 = rf() @@ -661,7 +705,7 @@ func (_c *MockSession_Revoke_Call) Return() *MockSession_Revoke_Call { } func (_c *MockSession_Revoke_Call) RunAndReturn(run func(time.Duration)) *MockSession_Revoke_Call { - _c.Call.Return(run) + _c.Run(run) return _c } @@ -694,7 +738,7 @@ func (_c *MockSession_SetDisconnected_Call) Return() *MockSession_SetDisconnecte } func (_c *MockSession_SetDisconnected_Call) RunAndReturn(run func(bool)) *MockSession_SetDisconnected_Call { - _c.Call.Return(run) + _c.Run(run) return _c } @@ -727,11 +771,11 @@ func (_c *MockSession_SetEnableActiveStandBy_Call) Return() *MockSession_SetEnab } func (_c *MockSession_SetEnableActiveStandBy_Call) RunAndReturn(run func(bool)) *MockSession_SetEnableActiveStandBy_Call { - _c.Call.Return(run) + _c.Run(run) return _c } -// Stop provides a mock function with given fields: +// Stop provides a mock function with no fields func (_m *MockSession) Stop() { _m.Called() } @@ -759,14 +803,18 @@ func (_c *MockSession_Stop_Call) Return() *MockSession_Stop_Call { } func (_c *MockSession_Stop_Call) RunAndReturn(run func()) *MockSession_Stop_Call { - _c.Call.Return(run) + _c.Run(run) return _c } -// String provides a mock function with given fields: +// String provides a mock function with no fields func (_m *MockSession) String() string { ret := _m.Called() + if len(ret) == 0 { + panic("no return value specified for String") + } + var r0 string if rf, ok := ret.Get(0).(func() string); ok { r0 = rf() @@ -808,6 +856,10 @@ func (_c *MockSession_String_Call) RunAndReturn(run func() string) *MockSession_ func (_m *MockSession) UnmarshalJSON(data []byte) error { ret := _m.Called(data) + if len(ret) == 0 { + panic("no return value specified for UnmarshalJSON") + } + var r0 error if rf, ok := ret.Get(0).(func([]byte) error); ok { r0 = rf(data) @@ -875,20 +927,24 @@ func (_c *MockSession_UpdateRegistered_Call) Return() *MockSession_UpdateRegiste } func (_c *MockSession_UpdateRegistered_Call) RunAndReturn(run func(bool)) *MockSession_UpdateRegistered_Call { - _c.Call.Return(run) + _c.Run(run) return _c } // WatchServices provides a mock function with given fields: prefix, revision, rewatch -func (_m *MockSession) WatchServices(prefix string, revision int64, rewatch Rewatch) <-chan *SessionEvent { +func (_m *MockSession) WatchServices(prefix string, revision int64, rewatch Rewatch) SessionWatcher { ret := _m.Called(prefix, revision, rewatch) - var r0 <-chan *SessionEvent - if rf, ok := ret.Get(0).(func(string, int64, Rewatch) <-chan *SessionEvent); ok { + if len(ret) == 0 { + panic("no return value specified for WatchServices") + } + + var r0 SessionWatcher + if rf, ok := ret.Get(0).(func(string, int64, Rewatch) SessionWatcher); ok { r0 = rf(prefix, revision, rewatch) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(<-chan *SessionEvent) + r0 = ret.Get(0).(SessionWatcher) } } @@ -915,26 +971,30 @@ func (_c *MockSession_WatchServices_Call) Run(run func(prefix string, revision i return _c } -func (_c *MockSession_WatchServices_Call) Return(eventChannel <-chan *SessionEvent) *MockSession_WatchServices_Call { - _c.Call.Return(eventChannel) +func (_c *MockSession_WatchServices_Call) Return(watcher SessionWatcher) *MockSession_WatchServices_Call { + _c.Call.Return(watcher) return _c } -func (_c *MockSession_WatchServices_Call) RunAndReturn(run func(string, int64, Rewatch) <-chan *SessionEvent) *MockSession_WatchServices_Call { +func (_c *MockSession_WatchServices_Call) RunAndReturn(run func(string, int64, Rewatch) SessionWatcher) *MockSession_WatchServices_Call { _c.Call.Return(run) return _c } // WatchServicesWithVersionRange provides a mock function with given fields: prefix, r, revision, rewatch -func (_m *MockSession) WatchServicesWithVersionRange(prefix string, r semver.Range, revision int64, rewatch Rewatch) <-chan *SessionEvent { +func (_m *MockSession) WatchServicesWithVersionRange(prefix string, r semver.Range, revision int64, rewatch Rewatch) SessionWatcher { ret := _m.Called(prefix, r, revision, rewatch) - var r0 <-chan *SessionEvent - if rf, ok := ret.Get(0).(func(string, semver.Range, int64, Rewatch) <-chan *SessionEvent); ok { + if len(ret) == 0 { + panic("no return value specified for WatchServicesWithVersionRange") + } + + var r0 SessionWatcher + if rf, ok := ret.Get(0).(func(string, semver.Range, int64, Rewatch) SessionWatcher); ok { r0 = rf(prefix, r, revision, rewatch) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(<-chan *SessionEvent) + r0 = ret.Get(0).(SessionWatcher) } } @@ -962,12 +1022,12 @@ func (_c *MockSession_WatchServicesWithVersionRange_Call) Run(run func(prefix st return _c } -func (_c *MockSession_WatchServicesWithVersionRange_Call) Return(eventChannel <-chan *SessionEvent) *MockSession_WatchServicesWithVersionRange_Call { - _c.Call.Return(eventChannel) +func (_c *MockSession_WatchServicesWithVersionRange_Call) Return(watcher SessionWatcher) *MockSession_WatchServicesWithVersionRange_Call { + _c.Call.Return(watcher) return _c } -func (_c *MockSession_WatchServicesWithVersionRange_Call) RunAndReturn(run func(string, semver.Range, int64, Rewatch) <-chan *SessionEvent) *MockSession_WatchServicesWithVersionRange_Call { +func (_c *MockSession_WatchServicesWithVersionRange_Call) RunAndReturn(run func(string, semver.Range, int64, Rewatch) SessionWatcher) *MockSession_WatchServicesWithVersionRange_Call { _c.Call.Return(run) return _c } diff --git a/internal/util/sessionutil/mock_session_watcher.go b/internal/util/sessionutil/mock_session_watcher.go new file mode 100644 index 0000000000..668a7bf4a4 --- /dev/null +++ b/internal/util/sessionutil/mock_session_watcher.go @@ -0,0 +1,111 @@ +// Code generated by mockery v2.53.3. DO NOT EDIT. + +package sessionutil + +import mock "github.com/stretchr/testify/mock" + +// MockSessionWatcher is an autogenerated mock type for the SessionWatcher type +type MockSessionWatcher struct { + mock.Mock +} + +type MockSessionWatcher_Expecter struct { + mock *mock.Mock +} + +func (_m *MockSessionWatcher) EXPECT() *MockSessionWatcher_Expecter { + return &MockSessionWatcher_Expecter{mock: &_m.Mock} +} + +// EventChannel provides a mock function with no fields +func (_m *MockSessionWatcher) EventChannel() <-chan *SessionEvent { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for EventChannel") + } + + var r0 <-chan *SessionEvent + if rf, ok := ret.Get(0).(func() <-chan *SessionEvent); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(<-chan *SessionEvent) + } + } + + return r0 +} + +// MockSessionWatcher_EventChannel_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'EventChannel' +type MockSessionWatcher_EventChannel_Call struct { + *mock.Call +} + +// EventChannel is a helper method to define mock.On call +func (_e *MockSessionWatcher_Expecter) EventChannel() *MockSessionWatcher_EventChannel_Call { + return &MockSessionWatcher_EventChannel_Call{Call: _e.mock.On("EventChannel")} +} + +func (_c *MockSessionWatcher_EventChannel_Call) Run(run func()) *MockSessionWatcher_EventChannel_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockSessionWatcher_EventChannel_Call) Return(_a0 <-chan *SessionEvent) *MockSessionWatcher_EventChannel_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockSessionWatcher_EventChannel_Call) RunAndReturn(run func() <-chan *SessionEvent) *MockSessionWatcher_EventChannel_Call { + _c.Call.Return(run) + return _c +} + +// Stop provides a mock function with no fields +func (_m *MockSessionWatcher) Stop() { + _m.Called() +} + +// MockSessionWatcher_Stop_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Stop' +type MockSessionWatcher_Stop_Call struct { + *mock.Call +} + +// Stop is a helper method to define mock.On call +func (_e *MockSessionWatcher_Expecter) Stop() *MockSessionWatcher_Stop_Call { + return &MockSessionWatcher_Stop_Call{Call: _e.mock.On("Stop")} +} + +func (_c *MockSessionWatcher_Stop_Call) Run(run func()) *MockSessionWatcher_Stop_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockSessionWatcher_Stop_Call) Return() *MockSessionWatcher_Stop_Call { + _c.Call.Return() + return _c +} + +func (_c *MockSessionWatcher_Stop_Call) RunAndReturn(run func()) *MockSessionWatcher_Stop_Call { + _c.Run(run) + return _c +} + +// NewMockSessionWatcher creates a new instance of MockSessionWatcher. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockSessionWatcher(t interface { + mock.TestingT + Cleanup(func()) +}) *MockSessionWatcher { + mock := &MockSessionWatcher{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/util/sessionutil/session.go b/internal/util/sessionutil/session.go index eae304d583..83908a2cd2 100644 --- a/internal/util/sessionutil/session.go +++ b/internal/util/sessionutil/session.go @@ -34,8 +34,8 @@ type SessionInterface interface { GetSessionsWithVersionRange(prefix string, r semver.Range) (map[string]*Session, int64, error) GoingStop() error - WatchServices(prefix string, revision int64, rewatch Rewatch) (eventChannel <-chan *SessionEvent) - WatchServicesWithVersionRange(prefix string, r semver.Range, revision int64, rewatch Rewatch) (eventChannel <-chan *SessionEvent) + WatchServices(prefix string, revision int64, rewatch Rewatch) (watcher SessionWatcher) + WatchServicesWithVersionRange(prefix string, r semver.Range, revision int64, rewatch Rewatch) (watcher SessionWatcher) LivenessCheck(ctx context.Context, callback func()) Stop() Revoke(timeout time.Duration) @@ -51,3 +51,8 @@ type SessionInterface interface { GetServerID() int64 IsTriggerKill() bool } + +type SessionWatcher interface { + EventChannel() <-chan *SessionEvent + Stop() +} diff --git a/internal/util/sessionutil/session_util.go b/internal/util/sessionutil/session_util.go index f1b3cef7ca..d6c5a42705 100644 --- a/internal/util/sessionutil/session_util.go +++ b/internal/util/sessionutil/session_util.go @@ -765,11 +765,13 @@ type SessionEvent struct { type sessionWatcher struct { s *Session + cancel context.CancelFunc rch clientv3.WatchChan eventCh chan *SessionEvent prefix string rewatch Rewatch validate func(*Session) bool + wg sync.WaitGroup closeOnce sync.Once } @@ -779,15 +781,17 @@ func (w *sessionWatcher) closeEventCh() { }) } -func (w *sessionWatcher) start() { +func (w *sessionWatcher) start(ctx context.Context) { + w.wg.Add(1) go func() { - defer w.closeEventCh() + defer w.wg.Done() for { select { - case <-w.s.ctx.Done(): + case <-ctx.Done(): return case wresp, ok := <-w.rch: if !ok { + w.closeEventCh() log.Warn("session watch channel closed") return } @@ -797,6 +801,11 @@ func (w *sessionWatcher) start() { }() } +func (w *sessionWatcher) Stop() { + w.cancel() + w.wg.Wait() +} + // WatchServices watches the service's up and down in etcd, and sends event to // eventChannel. // prefix is a parameter to know which service to watch and can be obtained in @@ -805,17 +814,19 @@ func (w *sessionWatcher) start() { // in GetSessions. // If a server up, an event will be add to channel with eventType SessionAddType. // If a server down, an event will be add to channel with eventType SessionDelType. -func (s *Session) WatchServices(prefix string, revision int64, rewatch Rewatch) (eventChannel <-chan *SessionEvent) { +func (s *Session) WatchServices(prefix string, revision int64, rewatch Rewatch) (watcher SessionWatcher) { + ctx, cancel := context.WithCancel(s.ctx) w := &sessionWatcher{ s: s, + cancel: cancel, eventCh: make(chan *SessionEvent, 100), rch: s.etcdCli.Watch(s.ctx, path.Join(s.metaRoot, DefaultServiceRoot, prefix), clientv3.WithPrefix(), clientv3.WithPrevKV(), clientv3.WithRev(revision)), prefix: prefix, rewatch: rewatch, validate: func(s *Session) bool { return true }, } - w.start() - return w.eventCh + w.start(ctx) + return w } // WatchServicesWithVersionRange watches the service's up and down in etcd, and sends event to event Channel. @@ -824,17 +835,19 @@ func (s *Session) WatchServices(prefix string, revision int64, rewatch Rewatch) // revision is a etcd reversion to prevent missing key events and can be obtained in GetSessions. // If a server up, an event will be add to channel with eventType SessionAddType. // If a server down, an event will be add to channel with eventType SessionDelType. -func (s *Session) WatchServicesWithVersionRange(prefix string, r semver.Range, revision int64, rewatch Rewatch) (eventChannel <-chan *SessionEvent) { +func (s *Session) WatchServicesWithVersionRange(prefix string, r semver.Range, revision int64, rewatch Rewatch) (watcher SessionWatcher) { + ctx, cancel := context.WithCancel(s.ctx) w := &sessionWatcher{ s: s, + cancel: cancel, eventCh: make(chan *SessionEvent, 100), rch: s.etcdCli.Watch(s.ctx, path.Join(s.metaRoot, DefaultServiceRoot, prefix), clientv3.WithPrefix(), clientv3.WithPrevKV(), clientv3.WithRev(revision)), prefix: prefix, rewatch: rewatch, validate: func(s *Session) bool { return r(s.Version) }, } - w.start() - return w.eventCh + w.start(ctx) + return w } func (w *sessionWatcher) handleWatchResponse(wresp clientv3.WatchResponse) { @@ -919,6 +932,10 @@ func (w *sessionWatcher) handleWatchErr(err error) error { return nil } +func (w *sessionWatcher) EventChannel() <-chan *SessionEvent { + return w.eventCh +} + // LivenessCheck performs liveness check with provided context and channel // ctx controls the liveness check loop // ch is the liveness signal channel, ch is closed only when the session is expired diff --git a/internal/util/sessionutil/session_util_test.go b/internal/util/sessionutil/session_util_test.go index 95700d5f66..01436e1db9 100644 --- a/internal/util/sessionutil/session_util_test.go +++ b/internal/util/sessionutil/session_util_test.go @@ -158,7 +158,7 @@ func TestUpdateSessions(t *testing.T) { sessions, rev, err := s.GetSessions("test") assert.NoError(t, err) assert.Equal(t, len(sessions), 0) - eventCh := s.WatchServices("test", rev, nil) + watcher := s.WatchServices("test", rev, nil) sList := []*Session{} @@ -203,7 +203,7 @@ LOOP: select { case <-ch: t.FailNow() - case sessionEvent := <-eventCh: + case sessionEvent := <-watcher.EventChannel(): if sessionEvent.EventType == SessionAddEvent { addEventLen++ @@ -616,7 +616,7 @@ func (suite *SessionWithVersionSuite) TestWatchServicesWithVersionRange() { _, rev, err := s.GetSessionsWithVersionRange(suite.serverName, r) suite.Require().NoError(err) - ch := s.WatchServicesWithVersionRange(suite.serverName, r, rev, nil) + watcher := s.WatchServicesWithVersionRange(suite.serverName, r, rev, nil) // remove all sessions go func() { @@ -626,7 +626,7 @@ func (suite *SessionWithVersionSuite) TestWatchServicesWithVersionRange() { }() select { - case evt := <-ch: + case evt := <-watcher.EventChannel(): suite.Equal(suite.sessions[1].ServerID, evt.Session.ServerID) case <-time.After(time.Second): suite.Fail("no event received, failing")