fix: resolve SessionWatcher goroutine leak and unstable UT in querycoordv2 (#45627)

Related to #44620
Related to unstable ut "internal/querycoordv2 TestServer/TestNodeUp"

Introduce SessionWatcher interface to fix race condition and goroutine
leak that caused unstable unit test TestServer/TestNodeUp.

Changes:
- Add SessionWatcher interface with EventChannel() and Stop() methods
- Refactor WatchServices() to return SessionWatcher instead of raw
channel
- Fix cleanup order in QueryCoordV2: stop watcher before session
- Update DataCoord, ConnectionManager to use SessionWatcher
- Add MockSessionWatcher for testing

Fixes race condition between session context cancellation and internal
loop exit. Eliminates goroutine leak by providing explicit lifecycle
management.

---------

Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
This commit is contained in:
congqixia 2025-11-21 18:33:06 +08:00 committed by GitHub
parent 937fd99354
commit f51fcc09ae
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 433 additions and 102 deletions

View File

@ -533,6 +533,7 @@ generate-mockery-utils: getdeps
# tso.Allocator # tso.Allocator
$(INSTALL_PATH)/mockery --name=Allocator --dir=internal/tso --output=internal/tso/mocks --filename=allocator.go --with-expecter --structname=Allocator --outpkg=mocktso $(INSTALL_PATH)/mockery --name=Allocator --dir=internal/tso --output=internal/tso/mocks --filename=allocator.go --with-expecter --structname=Allocator --outpkg=mocktso
$(INSTALL_PATH)/mockery --name=SessionInterface --dir=$(PWD)/internal/util/sessionutil --output=$(PWD)/internal/util/sessionutil --filename=mock_session.go --with-expecter --structname=MockSession --inpackage $(INSTALL_PATH)/mockery --name=SessionInterface --dir=$(PWD)/internal/util/sessionutil --output=$(PWD)/internal/util/sessionutil --filename=mock_session.go --with-expecter --structname=MockSession --inpackage
$(INSTALL_PATH)/mockery --name=SessionWatcher --dir=$(PWD)/internal/util/sessionutil --output=$(PWD)/internal/util/sessionutil --filename=mock_session_watcher.go --with-expecter --structname=MockSessionWatcher --inpackage
$(INSTALL_PATH)/mockery --name=GrpcClient --dir=$(PWD)/internal/util/grpcclient --output=$(PWD)/internal/mocks --filename=mock_grpc_client.go --with-expecter --structname=MockGrpcClient $(INSTALL_PATH)/mockery --name=GrpcClient --dir=$(PWD)/internal/util/grpcclient --output=$(PWD)/internal/mocks --filename=mock_grpc_client.go --with-expecter --structname=MockGrpcClient
# proxy_client_manager.go # proxy_client_manager.go
$(INSTALL_PATH)/mockery --name=ProxyClientManagerInterface --dir=$(PWD)/internal/util/proxyutil --output=$(PWD)/internal/util/proxyutil --filename=mock_proxy_client_manager.go --with-expecter --structname=MockProxyClientManager --inpackage $(INSTALL_PATH)/mockery --name=ProxyClientManagerInterface --dir=$(PWD)/internal/util/proxyutil --output=$(PWD)/internal/util/proxyutil --filename=mock_proxy_client_manager.go --with-expecter --structname=MockProxyClientManager --inpackage

View File

@ -49,12 +49,12 @@ func (r *Runner) watchByPrefix(prefix string) {
_, revision, err := r.session.GetSessions(prefix) _, revision, err := r.session.GetSessions(prefix)
fn := func() { r.Stop() } fn := func() { r.Stop() }
console.AbnormalExitIf(err, r.backupFinished.Load(), console.AddCallbacks(fn)) console.AbnormalExitIf(err, r.backupFinished.Load(), console.AddCallbacks(fn))
eventCh := r.session.WatchServices(prefix, revision, nil) watcher := r.session.WatchServices(prefix, revision, nil)
for { for {
select { select {
case <-r.ctx.Done(): case <-r.ctx.Done():
return return
case event := <-eventCh: case event := <-watcher.EventChannel():
msg := fmt.Sprintf("session up/down, exit migration, event type: %s, session: %s", event.EventType.String(), event.Session.String()) msg := fmt.Sprintf("session up/down, exit migration, event type: %s, session: %s", event.EventType.String(), event.Session.String())
console.AbnormalExit(r.backupFinished.Load(), msg, console.AddCallbacks(fn)) console.AbnormalExit(r.backupFinished.Load(), msg, console.AddCallbacks(fn))
} }

View File

@ -1,4 +1,4 @@
// Code generated by mockery v2.32.4. DO NOT EDIT. // Code generated by mockery v2.53.3. DO NOT EDIT.
package allocator package allocator
@ -21,6 +21,10 @@ func (_m *MockGlobalIDAllocator) EXPECT() *MockGlobalIDAllocator_Expecter {
func (_m *MockGlobalIDAllocator) Alloc(count uint32) (int64, int64, error) { func (_m *MockGlobalIDAllocator) Alloc(count uint32) (int64, int64, error) {
ret := _m.Called(count) ret := _m.Called(count)
if len(ret) == 0 {
panic("no return value specified for Alloc")
}
var r0 int64 var r0 int64
var r1 int64 var r1 int64
var r2 error var r2 error
@ -76,10 +80,14 @@ func (_c *MockGlobalIDAllocator_Alloc_Call) RunAndReturn(run func(uint32) (int64
return _c return _c
} }
// AllocOne provides a mock function with given fields: // AllocOne provides a mock function with no fields
func (_m *MockGlobalIDAllocator) AllocOne() (int64, error) { func (_m *MockGlobalIDAllocator) AllocOne() (int64, error) {
ret := _m.Called() ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for AllocOne")
}
var r0 int64 var r0 int64
var r1 error var r1 error
if rf, ok := ret.Get(0).(func() (int64, error)); ok { if rf, ok := ret.Get(0).(func() (int64, error)); ok {
@ -127,10 +135,14 @@ func (_c *MockGlobalIDAllocator_AllocOne_Call) RunAndReturn(run func() (int64, e
return _c return _c
} }
// Initialize provides a mock function with given fields: // Initialize provides a mock function with no fields
func (_m *MockGlobalIDAllocator) Initialize() error { func (_m *MockGlobalIDAllocator) Initialize() error {
ret := _m.Called() ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for Initialize")
}
var r0 error var r0 error
if rf, ok := ret.Get(0).(func() error); ok { if rf, ok := ret.Get(0).(func() error); ok {
r0 = rf() r0 = rf()

View File

@ -141,11 +141,10 @@ type Server struct {
notifyIndexChan chan UniqueID notifyIndexChan chan UniqueID
factory dependency.Factory factory dependency.Factory
session sessionutil.SessionInterface session sessionutil.SessionInterface
icSession sessionutil.SessionInterface icSession sessionutil.SessionInterface
dnEventCh <-chan *sessionutil.SessionEvent dnSessionWatcher sessionutil.SessionWatcher
// qcEventCh <-chan *sessionutil.SessionEvent qnSessionWatcher sessionutil.SessionWatcher
qnEventCh <-chan *sessionutil.SessionEvent
enableActiveStandBy bool enableActiveStandBy bool
activateFunc func() error activateFunc func() error
@ -532,7 +531,7 @@ func (s *Server) initServiceDiscovery() error {
} }
log.Info("DataCoord Cluster Manager start up successfully") log.Info("DataCoord Cluster Manager start up successfully")
s.dnEventCh = s.session.WatchServicesWithVersionRange(typeutil.DataNodeRole, r, rev+1, s.rewatchDataNodes) s.dnSessionWatcher = s.session.WatchServicesWithVersionRange(typeutil.DataNodeRole, r, rev+1, s.rewatchDataNodes)
} }
s.indexEngineVersionManager = newIndexEngineVersionManager() s.indexEngineVersionManager = newIndexEngineVersionManager()
@ -542,7 +541,7 @@ func (s *Server) initServiceDiscovery() error {
return err return err
} }
s.rewatchQueryNodes(qnSessions) s.rewatchQueryNodes(qnSessions)
s.qnEventCh = s.session.WatchServicesWithVersionRange(typeutil.QueryNodeRole, r, qnRevision+1, s.rewatchQueryNodes) s.qnSessionWatcher = s.session.WatchServicesWithVersionRange(typeutil.QueryNodeRole, r, qnRevision+1, s.rewatchQueryNodes)
return nil return nil
} }
@ -799,7 +798,7 @@ func (s *Server) watchService(ctx context.Context) {
case <-ctx.Done(): case <-ctx.Done():
log.Info("watch service shutdown") log.Info("watch service shutdown")
return return
case event, ok := <-s.dnEventCh: case event, ok := <-s.dnSessionWatcher.EventChannel():
if !ok { if !ok {
s.stopServiceWatch() s.stopServiceWatch()
return return
@ -812,7 +811,7 @@ func (s *Server) watchService(ctx context.Context) {
}() }()
return return
} }
case event, ok := <-s.qnEventCh: case event, ok := <-s.qnSessionWatcher.EventChannel():
if !ok { if !ok {
s.stopServiceWatch() s.stopServiceWatch()
return return
@ -1054,6 +1053,14 @@ func (s *Server) Stop() error {
s.analyzeInspector.Stop() s.analyzeInspector.Stop()
log.Info("datacoord analyze inspector stopped") log.Info("datacoord analyze inspector stopped")
if s.dnSessionWatcher != nil {
s.dnSessionWatcher.Stop()
}
if s.qnSessionWatcher != nil {
s.qnSessionWatcher.Stop()
}
if s.session != nil { if s.session != nil {
s.session.Stop() s.session.Stop()
} }

View File

@ -735,7 +735,12 @@ func TestService_WatchServices(t *testing.T) {
svr.serverLoopWg.Add(1) svr.serverLoopWg.Add(1)
ech := make(chan *sessionutil.SessionEvent) ech := make(chan *sessionutil.SessionEvent)
svr.dnEventCh = ech mockDnWatcher := sessionutil.NewMockSessionWatcher(t)
mockDnWatcher.EXPECT().EventChannel().Return(ech)
svr.dnSessionWatcher = mockDnWatcher
mockQnWatcher := sessionutil.NewMockSessionWatcher(t)
mockQnWatcher.EXPECT().EventChannel().Return(nil)
svr.qnSessionWatcher = mockQnWatcher
flag := false flag := false
closed := false closed := false
@ -762,7 +767,9 @@ func TestService_WatchServices(t *testing.T) {
ech = make(chan *sessionutil.SessionEvent) ech = make(chan *sessionutil.SessionEvent)
flag = false flag = false
svr.dnEventCh = ech mockDnWatcher = sessionutil.NewMockSessionWatcher(t)
mockDnWatcher.EXPECT().EventChannel().Return(ech)
svr.dnSessionWatcher = mockDnWatcher
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
svr.serverLoopWg.Add(1) svr.serverLoopWg.Add(1)

View File

@ -119,8 +119,8 @@ func (cm *ConnectionManager) AddDependency(roleName string) error {
} }
} }
eventChannel := cm.session.WatchServices(roleName, rev, nil) watcher := cm.session.WatchServices(roleName, rev, nil)
go cm.processEvent(eventChannel) go cm.processEvent(watcher.EventChannel())
return nil return nil
} }

View File

@ -1,4 +1,4 @@
// Code generated by mockery v2.32.4. DO NOT EDIT. // Code generated by mockery v2.53.3. DO NOT EDIT.
package mocks package mocks
@ -33,6 +33,10 @@ func (_m *MockGrpcClient[T]) EXPECT() *MockGrpcClient_Expecter[T] {
func (_m *MockGrpcClient[T]) Call(ctx context.Context, caller func(T) (interface{}, error)) (interface{}, error) { func (_m *MockGrpcClient[T]) Call(ctx context.Context, caller func(T) (interface{}, error)) (interface{}, error) {
ret := _m.Called(ctx, caller) ret := _m.Called(ctx, caller)
if len(ret) == 0 {
panic("no return value specified for Call")
}
var r0 interface{} var r0 interface{}
var r1 error var r1 error
if rf, ok := ret.Get(0).(func(context.Context, func(T) (interface{}, error)) (interface{}, error)); ok { if rf, ok := ret.Get(0).(func(context.Context, func(T) (interface{}, error)) (interface{}, error)); ok {
@ -84,10 +88,14 @@ func (_c *MockGrpcClient_Call_Call[T]) RunAndReturn(run func(context.Context, fu
return _c return _c
} }
// Close provides a mock function with given fields: // Close provides a mock function with no fields
func (_m *MockGrpcClient[T]) Close() error { func (_m *MockGrpcClient[T]) Close() error {
ret := _m.Called() ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for Close")
}
var r0 error var r0 error
if rf, ok := ret.Get(0).(func() error); ok { if rf, ok := ret.Get(0).(func() error); ok {
r0 = rf() r0 = rf()
@ -125,7 +133,7 @@ func (_c *MockGrpcClient_Close_Call[T]) RunAndReturn(run func() error) *MockGrpc
return _c return _c
} }
// EnableEncryption provides a mock function with given fields: // EnableEncryption provides a mock function with no fields
func (_m *MockGrpcClient[T]) EnableEncryption() { func (_m *MockGrpcClient[T]) EnableEncryption() {
_m.Called() _m.Called()
} }
@ -153,14 +161,18 @@ func (_c *MockGrpcClient_EnableEncryption_Call[T]) Return() *MockGrpcClient_Enab
} }
func (_c *MockGrpcClient_EnableEncryption_Call[T]) RunAndReturn(run func()) *MockGrpcClient_EnableEncryption_Call[T] { func (_c *MockGrpcClient_EnableEncryption_Call[T]) RunAndReturn(run func()) *MockGrpcClient_EnableEncryption_Call[T] {
_c.Call.Return(run) _c.Run(run)
return _c return _c
} }
// GetNodeID provides a mock function with given fields: // GetNodeID provides a mock function with no fields
func (_m *MockGrpcClient[T]) GetNodeID() int64 { func (_m *MockGrpcClient[T]) GetNodeID() int64 {
ret := _m.Called() ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for GetNodeID")
}
var r0 int64 var r0 int64
if rf, ok := ret.Get(0).(func() int64); ok { if rf, ok := ret.Get(0).(func() int64); ok {
r0 = rf() r0 = rf()
@ -198,10 +210,14 @@ func (_c *MockGrpcClient_GetNodeID_Call[T]) RunAndReturn(run func() int64) *Mock
return _c return _c
} }
// GetRole provides a mock function with given fields: // GetRole provides a mock function with no fields
func (_m *MockGrpcClient[T]) GetRole() string { func (_m *MockGrpcClient[T]) GetRole() string {
ret := _m.Called() ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for GetRole")
}
var r0 string var r0 string
if rf, ok := ret.Get(0).(func() string); ok { if rf, ok := ret.Get(0).(func() string); ok {
r0 = rf() r0 = rf()
@ -243,6 +259,10 @@ func (_c *MockGrpcClient_GetRole_Call[T]) RunAndReturn(run func() string) *MockG
func (_m *MockGrpcClient[T]) ReCall(ctx context.Context, caller func(T) (interface{}, error)) (interface{}, error) { func (_m *MockGrpcClient[T]) ReCall(ctx context.Context, caller func(T) (interface{}, error)) (interface{}, error) {
ret := _m.Called(ctx, caller) ret := _m.Called(ctx, caller)
if len(ret) == 0 {
panic("no return value specified for ReCall")
}
var r0 interface{} var r0 interface{}
var r1 error var r1 error
if rf, ok := ret.Get(0).(func(context.Context, func(T) (interface{}, error)) (interface{}, error)); ok { if rf, ok := ret.Get(0).(func(context.Context, func(T) (interface{}, error)) (interface{}, error)); ok {
@ -323,7 +343,7 @@ func (_c *MockGrpcClient_SetGetAddrFunc_Call[T]) Return() *MockGrpcClient_SetGet
} }
func (_c *MockGrpcClient_SetGetAddrFunc_Call[T]) RunAndReturn(run func(func() (string, error))) *MockGrpcClient_SetGetAddrFunc_Call[T] { func (_c *MockGrpcClient_SetGetAddrFunc_Call[T]) RunAndReturn(run func(func() (string, error))) *MockGrpcClient_SetGetAddrFunc_Call[T] {
_c.Call.Return(run) _c.Run(run)
return _c return _c
} }
@ -356,7 +376,7 @@ func (_c *MockGrpcClient_SetInternalTLSCertPool_Call[T]) Return() *MockGrpcClien
} }
func (_c *MockGrpcClient_SetInternalTLSCertPool_Call[T]) RunAndReturn(run func(*x509.CertPool)) *MockGrpcClient_SetInternalTLSCertPool_Call[T] { func (_c *MockGrpcClient_SetInternalTLSCertPool_Call[T]) RunAndReturn(run func(*x509.CertPool)) *MockGrpcClient_SetInternalTLSCertPool_Call[T] {
_c.Call.Return(run) _c.Run(run)
return _c return _c
} }
@ -389,7 +409,7 @@ func (_c *MockGrpcClient_SetInternalTLSServerName_Call[T]) Return() *MockGrpcCli
} }
func (_c *MockGrpcClient_SetInternalTLSServerName_Call[T]) RunAndReturn(run func(string)) *MockGrpcClient_SetInternalTLSServerName_Call[T] { func (_c *MockGrpcClient_SetInternalTLSServerName_Call[T]) RunAndReturn(run func(string)) *MockGrpcClient_SetInternalTLSServerName_Call[T] {
_c.Call.Return(run) _c.Run(run)
return _c return _c
} }
@ -422,7 +442,7 @@ func (_c *MockGrpcClient_SetNewGrpcClientFunc_Call[T]) Return() *MockGrpcClient_
} }
func (_c *MockGrpcClient_SetNewGrpcClientFunc_Call[T]) RunAndReturn(run func(func(*grpc.ClientConn) T)) *MockGrpcClient_SetNewGrpcClientFunc_Call[T] { func (_c *MockGrpcClient_SetNewGrpcClientFunc_Call[T]) RunAndReturn(run func(func(*grpc.ClientConn) T)) *MockGrpcClient_SetNewGrpcClientFunc_Call[T] {
_c.Call.Return(run) _c.Run(run)
return _c return _c
} }
@ -455,7 +475,7 @@ func (_c *MockGrpcClient_SetNodeID_Call[T]) Return() *MockGrpcClient_SetNodeID_C
} }
func (_c *MockGrpcClient_SetNodeID_Call[T]) RunAndReturn(run func(int64)) *MockGrpcClient_SetNodeID_Call[T] { func (_c *MockGrpcClient_SetNodeID_Call[T]) RunAndReturn(run func(int64)) *MockGrpcClient_SetNodeID_Call[T] {
_c.Call.Return(run) _c.Run(run)
return _c return _c
} }
@ -488,7 +508,7 @@ func (_c *MockGrpcClient_SetRole_Call[T]) Return() *MockGrpcClient_SetRole_Call[
} }
func (_c *MockGrpcClient_SetRole_Call[T]) RunAndReturn(run func(string)) *MockGrpcClient_SetRole_Call[T] { func (_c *MockGrpcClient_SetRole_Call[T]) RunAndReturn(run func(string)) *MockGrpcClient_SetRole_Call[T] {
_c.Call.Return(run) _c.Run(run)
return _c return _c
} }
@ -521,7 +541,7 @@ func (_c *MockGrpcClient_SetSession_Call[T]) Return() *MockGrpcClient_SetSession
} }
func (_c *MockGrpcClient_SetSession_Call[T]) RunAndReturn(run func(*sessionutil.Session)) *MockGrpcClient_SetSession_Call[T] { func (_c *MockGrpcClient_SetSession_Call[T]) RunAndReturn(run func(*sessionutil.Session)) *MockGrpcClient_SetSession_Call[T] {
_c.Call.Return(run) _c.Run(run)
return _c return _c
} }

View File

@ -79,6 +79,8 @@ type Server struct {
tikvCli *txnkv.Client tikvCli *txnkv.Client
address string address string
session sessionutil.SessionInterface session sessionutil.SessionInterface
sessionWatcher sessionutil.SessionWatcher
sessionWatcherMu sync.Mutex
kv kv.MetaKv kv kv.MetaKv
idAllocator func() (int64, error) idAllocator func() (int64, error)
metricsCacheManager *metricsinfo.MetricsCacheManager metricsCacheManager *metricsinfo.MetricsCacheManager
@ -593,12 +595,19 @@ func (s *Server) Stop() error {
s.cluster.Stop() s.cluster.Stop()
} }
s.sessionWatcherMu.Lock()
if s.sessionWatcher != nil {
s.sessionWatcher.Stop()
}
s.sessionWatcherMu.Unlock()
s.cancel()
s.wg.Wait()
if s.session != nil { if s.session != nil {
s.session.Stop() s.session.Stop()
} }
s.cancel()
s.wg.Wait()
log.Info("QueryCoord stop successfully") log.Info("QueryCoord stop successfully")
return nil return nil
} }
@ -638,14 +647,16 @@ func (s *Server) watchNodes(revision int64) {
log := log.Ctx(s.ctx) log := log.Ctx(s.ctx)
defer s.wg.Done() defer s.wg.Done()
eventChan := s.session.WatchServices(typeutil.QueryNodeRole, revision+1, s.rewatchNodes) s.sessionWatcherMu.Lock()
s.sessionWatcher = s.session.WatchServices(typeutil.QueryNodeRole, revision+1, s.rewatchNodes)
s.sessionWatcherMu.Unlock()
for { for {
select { select {
case <-s.ctx.Done(): case <-s.ctx.Done():
log.Info("stop watching nodes, QueryCoord stopped") log.Info("stop watching nodes, QueryCoord stopped")
return return
case event, ok := <-eventChan: case event, ok := <-s.sessionWatcher.EventChannel():
if !ok { if !ok {
// ErrCompacted is handled inside SessionWatcher // ErrCompacted is handled inside SessionWatcher
log.Warn("Session Watcher channel closed", zap.Int64("serverID", paramtable.GetNodeID())) log.Warn("Session Watcher channel closed", zap.Int64("serverID", paramtable.GetNodeID()))

View File

@ -1,4 +1,4 @@
// Code generated by mockery v2.32.4. DO NOT EDIT. // Code generated by mockery v2.53.3. DO NOT EDIT.
package mocktso package mocktso
@ -25,6 +25,10 @@ func (_m *Allocator) EXPECT() *Allocator_Expecter {
func (_m *Allocator) GenerateTSO(count uint32) (uint64, error) { func (_m *Allocator) GenerateTSO(count uint32) (uint64, error) {
ret := _m.Called(count) ret := _m.Called(count)
if len(ret) == 0 {
panic("no return value specified for GenerateTSO")
}
var r0 uint64 var r0 uint64
var r1 error var r1 error
if rf, ok := ret.Get(0).(func(uint32) (uint64, error)); ok { if rf, ok := ret.Get(0).(func(uint32) (uint64, error)); ok {
@ -73,10 +77,14 @@ func (_c *Allocator_GenerateTSO_Call) RunAndReturn(run func(uint32) (uint64, err
return _c return _c
} }
// GetLastSavedTime provides a mock function with given fields: // GetLastSavedTime provides a mock function with no fields
func (_m *Allocator) GetLastSavedTime() time.Time { func (_m *Allocator) GetLastSavedTime() time.Time {
ret := _m.Called() ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for GetLastSavedTime")
}
var r0 time.Time var r0 time.Time
if rf, ok := ret.Get(0).(func() time.Time); ok { if rf, ok := ret.Get(0).(func() time.Time); ok {
r0 = rf() r0 = rf()
@ -114,10 +122,14 @@ func (_c *Allocator_GetLastSavedTime_Call) RunAndReturn(run func() time.Time) *A
return _c return _c
} }
// Initialize provides a mock function with given fields: // Initialize provides a mock function with no fields
func (_m *Allocator) Initialize() error { func (_m *Allocator) Initialize() error {
ret := _m.Called() ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for Initialize")
}
var r0 error var r0 error
if rf, ok := ret.Get(0).(func() error); ok { if rf, ok := ret.Get(0).(func() error); ok {
r0 = rf() r0 = rf()
@ -155,7 +167,7 @@ func (_c *Allocator_Initialize_Call) RunAndReturn(run func() error) *Allocator_I
return _c return _c
} }
// Reset provides a mock function with given fields: // Reset provides a mock function with no fields
func (_m *Allocator) Reset() { func (_m *Allocator) Reset() {
_m.Called() _m.Called()
} }
@ -183,7 +195,7 @@ func (_c *Allocator_Reset_Call) Return() *Allocator_Reset_Call {
} }
func (_c *Allocator_Reset_Call) RunAndReturn(run func()) *Allocator_Reset_Call { func (_c *Allocator_Reset_Call) RunAndReturn(run func()) *Allocator_Reset_Call {
_c.Call.Return(run) _c.Run(run)
return _c return _c
} }
@ -191,6 +203,10 @@ func (_c *Allocator_Reset_Call) RunAndReturn(run func()) *Allocator_Reset_Call {
func (_m *Allocator) SetTSO(_a0 uint64) error { func (_m *Allocator) SetTSO(_a0 uint64) error {
ret := _m.Called(_a0) ret := _m.Called(_a0)
if len(ret) == 0 {
panic("no return value specified for SetTSO")
}
var r0 error var r0 error
if rf, ok := ret.Get(0).(func(uint64) error); ok { if rf, ok := ret.Get(0).(func(uint64) error); ok {
r0 = rf(_a0) r0 = rf(_a0)
@ -229,10 +245,14 @@ func (_c *Allocator_SetTSO_Call) RunAndReturn(run func(uint64) error) *Allocator
return _c return _c
} }
// UpdateTSO provides a mock function with given fields: // UpdateTSO provides a mock function with no fields
func (_m *Allocator) UpdateTSO() error { func (_m *Allocator) UpdateTSO() error {
ret := _m.Called() ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for UpdateTSO")
}
var r0 error var r0 error
if rf, ok := ret.Get(0).(func() error); ok { if rf, ok := ret.Get(0).(func() error); ok {
r0 = rf() r0 = rf()

View File

@ -1,4 +1,4 @@
// Code generated by mockery v2.32.4. DO NOT EDIT. // Code generated by mockery v2.53.3. DO NOT EDIT.
package dependency package dependency
@ -55,7 +55,7 @@ func (_c *MockFactory_Init_Call) Return() *MockFactory_Init_Call {
} }
func (_c *MockFactory_Init_Call) RunAndReturn(run func(*paramtable.ComponentParam)) *MockFactory_Init_Call { func (_c *MockFactory_Init_Call) RunAndReturn(run func(*paramtable.ComponentParam)) *MockFactory_Init_Call {
_c.Call.Return(run) _c.Run(run)
return _c return _c
} }
@ -63,6 +63,10 @@ func (_c *MockFactory_Init_Call) RunAndReturn(run func(*paramtable.ComponentPara
func (_m *MockFactory) NewMsgStream(ctx context.Context) (msgstream.MsgStream, error) { func (_m *MockFactory) NewMsgStream(ctx context.Context) (msgstream.MsgStream, error) {
ret := _m.Called(ctx) ret := _m.Called(ctx)
if len(ret) == 0 {
panic("no return value specified for NewMsgStream")
}
var r0 msgstream.MsgStream var r0 msgstream.MsgStream
var r1 error var r1 error
if rf, ok := ret.Get(0).(func(context.Context) (msgstream.MsgStream, error)); ok { if rf, ok := ret.Get(0).(func(context.Context) (msgstream.MsgStream, error)); ok {
@ -117,6 +121,10 @@ func (_c *MockFactory_NewMsgStream_Call) RunAndReturn(run func(context.Context)
func (_m *MockFactory) NewMsgStreamDisposer(ctx context.Context) func([]string, string) error { func (_m *MockFactory) NewMsgStreamDisposer(ctx context.Context) func([]string, string) error {
ret := _m.Called(ctx) ret := _m.Called(ctx)
if len(ret) == 0 {
panic("no return value specified for NewMsgStreamDisposer")
}
var r0 func([]string, string) error var r0 func([]string, string) error
if rf, ok := ret.Get(0).(func(context.Context) func([]string, string) error); ok { if rf, ok := ret.Get(0).(func(context.Context) func([]string, string) error); ok {
r0 = rf(ctx) r0 = rf(ctx)
@ -161,6 +169,10 @@ func (_c *MockFactory_NewMsgStreamDisposer_Call) RunAndReturn(run func(context.C
func (_m *MockFactory) NewPersistentStorageChunkManager(ctx context.Context) (storage.ChunkManager, error) { func (_m *MockFactory) NewPersistentStorageChunkManager(ctx context.Context) (storage.ChunkManager, error) {
ret := _m.Called(ctx) ret := _m.Called(ctx)
if len(ret) == 0 {
panic("no return value specified for NewPersistentStorageChunkManager")
}
var r0 storage.ChunkManager var r0 storage.ChunkManager
var r1 error var r1 error
if rf, ok := ret.Get(0).(func(context.Context) (storage.ChunkManager, error)); ok { if rf, ok := ret.Get(0).(func(context.Context) (storage.ChunkManager, error)); ok {
@ -215,6 +227,10 @@ func (_c *MockFactory_NewPersistentStorageChunkManager_Call) RunAndReturn(run fu
func (_m *MockFactory) NewTtMsgStream(ctx context.Context) (msgstream.MsgStream, error) { func (_m *MockFactory) NewTtMsgStream(ctx context.Context) (msgstream.MsgStream, error) {
ret := _m.Called(ctx) ret := _m.Called(ctx)
if len(ret) == 0 {
panic("no return value specified for NewTtMsgStream")
}
var r0 msgstream.MsgStream var r0 msgstream.MsgStream
var r1 error var r1 error
if rf, ok := ret.Get(0).(func(context.Context) (msgstream.MsgStream, error)); ok { if rf, ok := ret.Get(0).(func(context.Context) (msgstream.MsgStream, error)); ok {

View File

@ -1,4 +1,4 @@
// Code generated by mockery v2.32.4. DO NOT EDIT. // Code generated by mockery v2.53.3. DO NOT EDIT.
package proxyutil package proxyutil
@ -59,7 +59,7 @@ func (_c *MockProxyClientManager_AddProxyClient_Call) Return() *MockProxyClientM
} }
func (_c *MockProxyClientManager_AddProxyClient_Call) RunAndReturn(run func(*sessionutil.Session)) *MockProxyClientManager_AddProxyClient_Call { func (_c *MockProxyClientManager_AddProxyClient_Call) RunAndReturn(run func(*sessionutil.Session)) *MockProxyClientManager_AddProxyClient_Call {
_c.Call.Return(run) _c.Run(run)
return _c return _c
} }
@ -92,7 +92,7 @@ func (_c *MockProxyClientManager_AddProxyClients_Call) Return() *MockProxyClient
} }
func (_c *MockProxyClientManager_AddProxyClients_Call) RunAndReturn(run func([]*sessionutil.Session)) *MockProxyClientManager_AddProxyClients_Call { func (_c *MockProxyClientManager_AddProxyClients_Call) RunAndReturn(run func([]*sessionutil.Session)) *MockProxyClientManager_AddProxyClients_Call {
_c.Call.Return(run) _c.Run(run)
return _c return _c
} }
@ -125,7 +125,7 @@ func (_c *MockProxyClientManager_DelProxyClient_Call) Return() *MockProxyClientM
} }
func (_c *MockProxyClientManager_DelProxyClient_Call) RunAndReturn(run func(*sessionutil.Session)) *MockProxyClientManager_DelProxyClient_Call { func (_c *MockProxyClientManager_DelProxyClient_Call) RunAndReturn(run func(*sessionutil.Session)) *MockProxyClientManager_DelProxyClient_Call {
_c.Call.Return(run) _c.Run(run)
return _c return _c
} }
@ -133,6 +133,10 @@ func (_c *MockProxyClientManager_DelProxyClient_Call) RunAndReturn(run func(*ses
func (_m *MockProxyClientManager) GetComponentStates(ctx context.Context) (map[int64]*milvuspb.ComponentStates, error) { func (_m *MockProxyClientManager) GetComponentStates(ctx context.Context) (map[int64]*milvuspb.ComponentStates, error) {
ret := _m.Called(ctx) ret := _m.Called(ctx)
if len(ret) == 0 {
panic("no return value specified for GetComponentStates")
}
var r0 map[int64]*milvuspb.ComponentStates var r0 map[int64]*milvuspb.ComponentStates
var r1 error var r1 error
if rf, ok := ret.Get(0).(func(context.Context) (map[int64]*milvuspb.ComponentStates, error)); ok { if rf, ok := ret.Get(0).(func(context.Context) (map[int64]*milvuspb.ComponentStates, error)); ok {
@ -183,10 +187,14 @@ func (_c *MockProxyClientManager_GetComponentStates_Call) RunAndReturn(run func(
return _c return _c
} }
// GetProxyClients provides a mock function with given fields: // GetProxyClients provides a mock function with no fields
func (_m *MockProxyClientManager) GetProxyClients() *typeutil.ConcurrentMap[int64, types.ProxyClient] { func (_m *MockProxyClientManager) GetProxyClients() *typeutil.ConcurrentMap[int64, types.ProxyClient] {
ret := _m.Called() ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for GetProxyClients")
}
var r0 *typeutil.ConcurrentMap[int64, types.ProxyClient] var r0 *typeutil.ConcurrentMap[int64, types.ProxyClient]
if rf, ok := ret.Get(0).(func() *typeutil.ConcurrentMap[int64, types.ProxyClient]); ok { if rf, ok := ret.Get(0).(func() *typeutil.ConcurrentMap[int64, types.ProxyClient]); ok {
r0 = rf() r0 = rf()
@ -226,10 +234,14 @@ func (_c *MockProxyClientManager_GetProxyClients_Call) RunAndReturn(run func() *
return _c return _c
} }
// GetProxyCount provides a mock function with given fields: // GetProxyCount provides a mock function with no fields
func (_m *MockProxyClientManager) GetProxyCount() int { func (_m *MockProxyClientManager) GetProxyCount() int {
ret := _m.Called() ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for GetProxyCount")
}
var r0 int var r0 int
if rf, ok := ret.Get(0).(func() int); ok { if rf, ok := ret.Get(0).(func() int); ok {
r0 = rf() r0 = rf()
@ -271,6 +283,10 @@ func (_c *MockProxyClientManager_GetProxyCount_Call) RunAndReturn(run func() int
func (_m *MockProxyClientManager) GetProxyMetrics(ctx context.Context) ([]*milvuspb.GetMetricsResponse, error) { func (_m *MockProxyClientManager) GetProxyMetrics(ctx context.Context) ([]*milvuspb.GetMetricsResponse, error) {
ret := _m.Called(ctx) ret := _m.Called(ctx)
if len(ret) == 0 {
panic("no return value specified for GetProxyMetrics")
}
var r0 []*milvuspb.GetMetricsResponse var r0 []*milvuspb.GetMetricsResponse
var r1 error var r1 error
if rf, ok := ret.Get(0).(func(context.Context) ([]*milvuspb.GetMetricsResponse, error)); ok { if rf, ok := ret.Get(0).(func(context.Context) ([]*milvuspb.GetMetricsResponse, error)); ok {
@ -332,6 +348,10 @@ func (_m *MockProxyClientManager) InvalidateCollectionMetaCache(ctx context.Cont
_ca = append(_ca, _va...) _ca = append(_ca, _va...)
ret := _m.Called(_ca...) ret := _m.Called(_ca...)
if len(ret) == 0 {
panic("no return value specified for InvalidateCollectionMetaCache")
}
var r0 error var r0 error
if rf, ok := ret.Get(0).(func(context.Context, *proxypb.InvalidateCollMetaCacheRequest, ...ExpireCacheOpt) error); ok { if rf, ok := ret.Get(0).(func(context.Context, *proxypb.InvalidateCollMetaCacheRequest, ...ExpireCacheOpt) error); ok {
r0 = rf(ctx, request, opts...) r0 = rf(ctx, request, opts...)
@ -383,6 +403,10 @@ func (_c *MockProxyClientManager_InvalidateCollectionMetaCache_Call) RunAndRetur
func (_m *MockProxyClientManager) InvalidateCredentialCache(ctx context.Context, request *proxypb.InvalidateCredCacheRequest) error { func (_m *MockProxyClientManager) InvalidateCredentialCache(ctx context.Context, request *proxypb.InvalidateCredCacheRequest) error {
ret := _m.Called(ctx, request) ret := _m.Called(ctx, request)
if len(ret) == 0 {
panic("no return value specified for InvalidateCredentialCache")
}
var r0 error var r0 error
if rf, ok := ret.Get(0).(func(context.Context, *proxypb.InvalidateCredCacheRequest) error); ok { if rf, ok := ret.Get(0).(func(context.Context, *proxypb.InvalidateCredCacheRequest) error); ok {
r0 = rf(ctx, request) r0 = rf(ctx, request)
@ -426,6 +450,10 @@ func (_c *MockProxyClientManager_InvalidateCredentialCache_Call) RunAndReturn(ru
func (_m *MockProxyClientManager) InvalidateShardLeaderCache(ctx context.Context, request *proxypb.InvalidateShardLeaderCacheRequest) error { func (_m *MockProxyClientManager) InvalidateShardLeaderCache(ctx context.Context, request *proxypb.InvalidateShardLeaderCacheRequest) error {
ret := _m.Called(ctx, request) ret := _m.Called(ctx, request)
if len(ret) == 0 {
panic("no return value specified for InvalidateShardLeaderCache")
}
var r0 error var r0 error
if rf, ok := ret.Get(0).(func(context.Context, *proxypb.InvalidateShardLeaderCacheRequest) error); ok { if rf, ok := ret.Get(0).(func(context.Context, *proxypb.InvalidateShardLeaderCacheRequest) error); ok {
r0 = rf(ctx, request) r0 = rf(ctx, request)
@ -469,6 +497,10 @@ func (_c *MockProxyClientManager_InvalidateShardLeaderCache_Call) RunAndReturn(r
func (_m *MockProxyClientManager) RefreshPolicyInfoCache(ctx context.Context, req *proxypb.RefreshPolicyInfoCacheRequest) error { func (_m *MockProxyClientManager) RefreshPolicyInfoCache(ctx context.Context, req *proxypb.RefreshPolicyInfoCacheRequest) error {
ret := _m.Called(ctx, req) ret := _m.Called(ctx, req)
if len(ret) == 0 {
panic("no return value specified for RefreshPolicyInfoCache")
}
var r0 error var r0 error
if rf, ok := ret.Get(0).(func(context.Context, *proxypb.RefreshPolicyInfoCacheRequest) error); ok { if rf, ok := ret.Get(0).(func(context.Context, *proxypb.RefreshPolicyInfoCacheRequest) error); ok {
r0 = rf(ctx, req) r0 = rf(ctx, req)
@ -512,6 +544,10 @@ func (_c *MockProxyClientManager_RefreshPolicyInfoCache_Call) RunAndReturn(run f
func (_m *MockProxyClientManager) SetRates(ctx context.Context, request *proxypb.SetRatesRequest) error { func (_m *MockProxyClientManager) SetRates(ctx context.Context, request *proxypb.SetRatesRequest) error {
ret := _m.Called(ctx, request) ret := _m.Called(ctx, request)
if len(ret) == 0 {
panic("no return value specified for SetRates")
}
var r0 error var r0 error
if rf, ok := ret.Get(0).(func(context.Context, *proxypb.SetRatesRequest) error); ok { if rf, ok := ret.Get(0).(func(context.Context, *proxypb.SetRatesRequest) error); ok {
r0 = rf(ctx, request) r0 = rf(ctx, request)
@ -555,6 +591,10 @@ func (_c *MockProxyClientManager_SetRates_Call) RunAndReturn(run func(context.Co
func (_m *MockProxyClientManager) UpdateCredentialCache(ctx context.Context, request *proxypb.UpdateCredCacheRequest) error { func (_m *MockProxyClientManager) UpdateCredentialCache(ctx context.Context, request *proxypb.UpdateCredCacheRequest) error {
ret := _m.Called(ctx, request) ret := _m.Called(ctx, request)
if len(ret) == 0 {
panic("no return value specified for UpdateCredentialCache")
}
var r0 error var r0 error
if rf, ok := ret.Get(0).(func(context.Context, *proxypb.UpdateCredCacheRequest) error); ok { if rf, ok := ret.Get(0).(func(context.Context, *proxypb.UpdateCredCacheRequest) error); ok {
r0 = rf(ctx, request) r0 = rf(ctx, request)

View File

@ -1,4 +1,4 @@
// Code generated by mockery v2.32.4. DO NOT EDIT. // Code generated by mockery v2.53.3. DO NOT EDIT.
package proxyutil package proxyutil
@ -64,7 +64,7 @@ func (_c *MockProxyWatcher_AddSessionFunc_Call) Return() *MockProxyWatcher_AddSe
} }
func (_c *MockProxyWatcher_AddSessionFunc_Call) RunAndReturn(run func(...func(*sessionutil.Session))) *MockProxyWatcher_AddSessionFunc_Call { func (_c *MockProxyWatcher_AddSessionFunc_Call) RunAndReturn(run func(...func(*sessionutil.Session))) *MockProxyWatcher_AddSessionFunc_Call {
_c.Call.Return(run) _c.Run(run)
return _c return _c
} }
@ -110,11 +110,11 @@ func (_c *MockProxyWatcher_DelSessionFunc_Call) Return() *MockProxyWatcher_DelSe
} }
func (_c *MockProxyWatcher_DelSessionFunc_Call) RunAndReturn(run func(...func(*sessionutil.Session))) *MockProxyWatcher_DelSessionFunc_Call { func (_c *MockProxyWatcher_DelSessionFunc_Call) RunAndReturn(run func(...func(*sessionutil.Session))) *MockProxyWatcher_DelSessionFunc_Call {
_c.Call.Return(run) _c.Run(run)
return _c return _c
} }
// Stop provides a mock function with given fields: // Stop provides a mock function with no fields
func (_m *MockProxyWatcher) Stop() { func (_m *MockProxyWatcher) Stop() {
_m.Called() _m.Called()
} }
@ -142,7 +142,7 @@ func (_c *MockProxyWatcher_Stop_Call) Return() *MockProxyWatcher_Stop_Call {
} }
func (_c *MockProxyWatcher_Stop_Call) RunAndReturn(run func()) *MockProxyWatcher_Stop_Call { func (_c *MockProxyWatcher_Stop_Call) RunAndReturn(run func()) *MockProxyWatcher_Stop_Call {
_c.Call.Return(run) _c.Run(run)
return _c return _c
} }
@ -150,6 +150,10 @@ func (_c *MockProxyWatcher_Stop_Call) RunAndReturn(run func()) *MockProxyWatcher
func (_m *MockProxyWatcher) WatchProxy(ctx context.Context) error { func (_m *MockProxyWatcher) WatchProxy(ctx context.Context) error {
ret := _m.Called(ctx) ret := _m.Called(ctx)
if len(ret) == 0 {
panic("no return value specified for WatchProxy")
}
var r0 error var r0 error
if rf, ok := ret.Get(0).(func(context.Context) error); ok { if rf, ok := ret.Get(0).(func(context.Context) error); ok {
r0 = rf(ctx) r0 = rf(ctx)

View File

@ -1,4 +1,4 @@
// Code generated by mockery v2.32.4. DO NOT EDIT. // Code generated by mockery v2.53.3. DO NOT EDIT.
package sessionutil package sessionutil
@ -24,10 +24,14 @@ func (_m *MockSession) EXPECT() *MockSession_Expecter {
return &MockSession_Expecter{mock: &_m.Mock} return &MockSession_Expecter{mock: &_m.Mock}
} }
// Disconnected provides a mock function with given fields: // Disconnected provides a mock function with no fields
func (_m *MockSession) Disconnected() bool { func (_m *MockSession) Disconnected() bool {
ret := _m.Called() ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for Disconnected")
}
var r0 bool var r0 bool
if rf, ok := ret.Get(0).(func() bool); ok { if rf, ok := ret.Get(0).(func() bool); ok {
r0 = rf() r0 = rf()
@ -69,6 +73,10 @@ func (_c *MockSession_Disconnected_Call) RunAndReturn(run func() bool) *MockSess
func (_m *MockSession) ForceActiveStandby(activateFunc func() error) error { func (_m *MockSession) ForceActiveStandby(activateFunc func() error) error {
ret := _m.Called(activateFunc) ret := _m.Called(activateFunc)
if len(ret) == 0 {
panic("no return value specified for ForceActiveStandby")
}
var r0 error var r0 error
if rf, ok := ret.Get(0).(func(func() error) error); ok { if rf, ok := ret.Get(0).(func(func() error) error); ok {
r0 = rf(activateFunc) r0 = rf(activateFunc)
@ -107,10 +115,14 @@ func (_c *MockSession_ForceActiveStandby_Call) RunAndReturn(run func(func() erro
return _c return _c
} }
// GetAddress provides a mock function with given fields: // GetAddress provides a mock function with no fields
func (_m *MockSession) GetAddress() string { func (_m *MockSession) GetAddress() string {
ret := _m.Called() ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for GetAddress")
}
var r0 string var r0 string
if rf, ok := ret.Get(0).(func() string); ok { if rf, ok := ret.Get(0).(func() string); ok {
r0 = rf() r0 = rf()
@ -148,10 +160,14 @@ func (_c *MockSession_GetAddress_Call) RunAndReturn(run func() string) *MockSess
return _c return _c
} }
// GetServerID provides a mock function with given fields: // GetServerID provides a mock function with no fields
func (_m *MockSession) GetServerID() int64 { func (_m *MockSession) GetServerID() int64 {
ret := _m.Called() ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for GetServerID")
}
var r0 int64 var r0 int64
if rf, ok := ret.Get(0).(func() int64); ok { if rf, ok := ret.Get(0).(func() int64); ok {
r0 = rf() r0 = rf()
@ -193,6 +209,10 @@ func (_c *MockSession_GetServerID_Call) RunAndReturn(run func() int64) *MockSess
func (_m *MockSession) GetSessions(prefix string) (map[string]*Session, int64, error) { func (_m *MockSession) GetSessions(prefix string) (map[string]*Session, int64, error) {
ret := _m.Called(prefix) ret := _m.Called(prefix)
if len(ret) == 0 {
panic("no return value specified for GetSessions")
}
var r0 map[string]*Session var r0 map[string]*Session
var r1 int64 var r1 int64
var r2 error var r2 error
@ -254,6 +274,10 @@ func (_c *MockSession_GetSessions_Call) RunAndReturn(run func(string) (map[strin
func (_m *MockSession) GetSessionsWithVersionRange(prefix string, r semver.Range) (map[string]*Session, int64, error) { func (_m *MockSession) GetSessionsWithVersionRange(prefix string, r semver.Range) (map[string]*Session, int64, error) {
ret := _m.Called(prefix, r) ret := _m.Called(prefix, r)
if len(ret) == 0 {
panic("no return value specified for GetSessionsWithVersionRange")
}
var r0 map[string]*Session var r0 map[string]*Session
var r1 int64 var r1 int64
var r2 error var r2 error
@ -312,10 +336,14 @@ func (_c *MockSession_GetSessionsWithVersionRange_Call) RunAndReturn(run func(st
return _c return _c
} }
// GoingStop provides a mock function with given fields: // GoingStop provides a mock function with no fields
func (_m *MockSession) GoingStop() error { func (_m *MockSession) GoingStop() error {
ret := _m.Called() ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for GoingStop")
}
var r0 error var r0 error
if rf, ok := ret.Get(0).(func() error); ok { if rf, ok := ret.Get(0).(func() error); ok {
r0 = rf() r0 = rf()
@ -385,14 +413,18 @@ func (_c *MockSession_Init_Call) Return() *MockSession_Init_Call {
} }
func (_c *MockSession_Init_Call) RunAndReturn(run func(string, string, bool, bool)) *MockSession_Init_Call { func (_c *MockSession_Init_Call) RunAndReturn(run func(string, string, bool, bool)) *MockSession_Init_Call {
_c.Call.Return(run) _c.Run(run)
return _c return _c
} }
// IsTriggerKill provides a mock function with given fields: // IsTriggerKill provides a mock function with no fields
func (_m *MockSession) IsTriggerKill() bool { func (_m *MockSession) IsTriggerKill() bool {
ret := _m.Called() ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for IsTriggerKill")
}
var r0 bool var r0 bool
if rf, ok := ret.Get(0).(func() bool); ok { if rf, ok := ret.Get(0).(func() bool); ok {
r0 = rf() r0 = rf()
@ -460,14 +492,18 @@ func (_c *MockSession_LivenessCheck_Call) Return() *MockSession_LivenessCheck_Ca
} }
func (_c *MockSession_LivenessCheck_Call) RunAndReturn(run func(context.Context, func())) *MockSession_LivenessCheck_Call { func (_c *MockSession_LivenessCheck_Call) RunAndReturn(run func(context.Context, func())) *MockSession_LivenessCheck_Call {
_c.Call.Return(run) _c.Run(run)
return _c return _c
} }
// MarshalJSON provides a mock function with given fields: // MarshalJSON provides a mock function with no fields
func (_m *MockSession) MarshalJSON() ([]byte, error) { func (_m *MockSession) MarshalJSON() ([]byte, error) {
ret := _m.Called() ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for MarshalJSON")
}
var r0 []byte var r0 []byte
var r1 error var r1 error
if rf, ok := ret.Get(0).(func() ([]byte, error)); ok { if rf, ok := ret.Get(0).(func() ([]byte, error)); ok {
@ -521,6 +557,10 @@ func (_c *MockSession_MarshalJSON_Call) RunAndReturn(run func() ([]byte, error))
func (_m *MockSession) ProcessActiveStandBy(activateFunc func() error) error { func (_m *MockSession) ProcessActiveStandBy(activateFunc func() error) error {
ret := _m.Called(activateFunc) ret := _m.Called(activateFunc)
if len(ret) == 0 {
panic("no return value specified for ProcessActiveStandBy")
}
var r0 error var r0 error
if rf, ok := ret.Get(0).(func(func() error) error); ok { if rf, ok := ret.Get(0).(func(func() error) error); ok {
r0 = rf(activateFunc) r0 = rf(activateFunc)
@ -559,7 +599,7 @@ func (_c *MockSession_ProcessActiveStandBy_Call) RunAndReturn(run func(func() er
return _c return _c
} }
// Register provides a mock function with given fields: // Register provides a mock function with no fields
func (_m *MockSession) Register() { func (_m *MockSession) Register() {
_m.Called() _m.Called()
} }
@ -587,14 +627,18 @@ func (_c *MockSession_Register_Call) Return() *MockSession_Register_Call {
} }
func (_c *MockSession_Register_Call) RunAndReturn(run func()) *MockSession_Register_Call { func (_c *MockSession_Register_Call) RunAndReturn(run func()) *MockSession_Register_Call {
_c.Call.Return(run) _c.Run(run)
return _c return _c
} }
// Registered provides a mock function with given fields: // Registered provides a mock function with no fields
func (_m *MockSession) Registered() bool { func (_m *MockSession) Registered() bool {
ret := _m.Called() ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for Registered")
}
var r0 bool var r0 bool
if rf, ok := ret.Get(0).(func() bool); ok { if rf, ok := ret.Get(0).(func() bool); ok {
r0 = rf() r0 = rf()
@ -661,7 +705,7 @@ func (_c *MockSession_Revoke_Call) Return() *MockSession_Revoke_Call {
} }
func (_c *MockSession_Revoke_Call) RunAndReturn(run func(time.Duration)) *MockSession_Revoke_Call { func (_c *MockSession_Revoke_Call) RunAndReturn(run func(time.Duration)) *MockSession_Revoke_Call {
_c.Call.Return(run) _c.Run(run)
return _c return _c
} }
@ -694,7 +738,7 @@ func (_c *MockSession_SetDisconnected_Call) Return() *MockSession_SetDisconnecte
} }
func (_c *MockSession_SetDisconnected_Call) RunAndReturn(run func(bool)) *MockSession_SetDisconnected_Call { func (_c *MockSession_SetDisconnected_Call) RunAndReturn(run func(bool)) *MockSession_SetDisconnected_Call {
_c.Call.Return(run) _c.Run(run)
return _c return _c
} }
@ -727,11 +771,11 @@ func (_c *MockSession_SetEnableActiveStandBy_Call) Return() *MockSession_SetEnab
} }
func (_c *MockSession_SetEnableActiveStandBy_Call) RunAndReturn(run func(bool)) *MockSession_SetEnableActiveStandBy_Call { func (_c *MockSession_SetEnableActiveStandBy_Call) RunAndReturn(run func(bool)) *MockSession_SetEnableActiveStandBy_Call {
_c.Call.Return(run) _c.Run(run)
return _c return _c
} }
// Stop provides a mock function with given fields: // Stop provides a mock function with no fields
func (_m *MockSession) Stop() { func (_m *MockSession) Stop() {
_m.Called() _m.Called()
} }
@ -759,14 +803,18 @@ func (_c *MockSession_Stop_Call) Return() *MockSession_Stop_Call {
} }
func (_c *MockSession_Stop_Call) RunAndReturn(run func()) *MockSession_Stop_Call { func (_c *MockSession_Stop_Call) RunAndReturn(run func()) *MockSession_Stop_Call {
_c.Call.Return(run) _c.Run(run)
return _c return _c
} }
// String provides a mock function with given fields: // String provides a mock function with no fields
func (_m *MockSession) String() string { func (_m *MockSession) String() string {
ret := _m.Called() ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for String")
}
var r0 string var r0 string
if rf, ok := ret.Get(0).(func() string); ok { if rf, ok := ret.Get(0).(func() string); ok {
r0 = rf() r0 = rf()
@ -808,6 +856,10 @@ func (_c *MockSession_String_Call) RunAndReturn(run func() string) *MockSession_
func (_m *MockSession) UnmarshalJSON(data []byte) error { func (_m *MockSession) UnmarshalJSON(data []byte) error {
ret := _m.Called(data) ret := _m.Called(data)
if len(ret) == 0 {
panic("no return value specified for UnmarshalJSON")
}
var r0 error var r0 error
if rf, ok := ret.Get(0).(func([]byte) error); ok { if rf, ok := ret.Get(0).(func([]byte) error); ok {
r0 = rf(data) r0 = rf(data)
@ -875,20 +927,24 @@ func (_c *MockSession_UpdateRegistered_Call) Return() *MockSession_UpdateRegiste
} }
func (_c *MockSession_UpdateRegistered_Call) RunAndReturn(run func(bool)) *MockSession_UpdateRegistered_Call { func (_c *MockSession_UpdateRegistered_Call) RunAndReturn(run func(bool)) *MockSession_UpdateRegistered_Call {
_c.Call.Return(run) _c.Run(run)
return _c return _c
} }
// WatchServices provides a mock function with given fields: prefix, revision, rewatch // WatchServices provides a mock function with given fields: prefix, revision, rewatch
func (_m *MockSession) WatchServices(prefix string, revision int64, rewatch Rewatch) <-chan *SessionEvent { func (_m *MockSession) WatchServices(prefix string, revision int64, rewatch Rewatch) SessionWatcher {
ret := _m.Called(prefix, revision, rewatch) ret := _m.Called(prefix, revision, rewatch)
var r0 <-chan *SessionEvent if len(ret) == 0 {
if rf, ok := ret.Get(0).(func(string, int64, Rewatch) <-chan *SessionEvent); ok { panic("no return value specified for WatchServices")
}
var r0 SessionWatcher
if rf, ok := ret.Get(0).(func(string, int64, Rewatch) SessionWatcher); ok {
r0 = rf(prefix, revision, rewatch) r0 = rf(prefix, revision, rewatch)
} else { } else {
if ret.Get(0) != nil { if ret.Get(0) != nil {
r0 = ret.Get(0).(<-chan *SessionEvent) r0 = ret.Get(0).(SessionWatcher)
} }
} }
@ -915,26 +971,30 @@ func (_c *MockSession_WatchServices_Call) Run(run func(prefix string, revision i
return _c return _c
} }
func (_c *MockSession_WatchServices_Call) Return(eventChannel <-chan *SessionEvent) *MockSession_WatchServices_Call { func (_c *MockSession_WatchServices_Call) Return(watcher SessionWatcher) *MockSession_WatchServices_Call {
_c.Call.Return(eventChannel) _c.Call.Return(watcher)
return _c return _c
} }
func (_c *MockSession_WatchServices_Call) RunAndReturn(run func(string, int64, Rewatch) <-chan *SessionEvent) *MockSession_WatchServices_Call { func (_c *MockSession_WatchServices_Call) RunAndReturn(run func(string, int64, Rewatch) SessionWatcher) *MockSession_WatchServices_Call {
_c.Call.Return(run) _c.Call.Return(run)
return _c return _c
} }
// WatchServicesWithVersionRange provides a mock function with given fields: prefix, r, revision, rewatch // WatchServicesWithVersionRange provides a mock function with given fields: prefix, r, revision, rewatch
func (_m *MockSession) WatchServicesWithVersionRange(prefix string, r semver.Range, revision int64, rewatch Rewatch) <-chan *SessionEvent { func (_m *MockSession) WatchServicesWithVersionRange(prefix string, r semver.Range, revision int64, rewatch Rewatch) SessionWatcher {
ret := _m.Called(prefix, r, revision, rewatch) ret := _m.Called(prefix, r, revision, rewatch)
var r0 <-chan *SessionEvent if len(ret) == 0 {
if rf, ok := ret.Get(0).(func(string, semver.Range, int64, Rewatch) <-chan *SessionEvent); ok { panic("no return value specified for WatchServicesWithVersionRange")
}
var r0 SessionWatcher
if rf, ok := ret.Get(0).(func(string, semver.Range, int64, Rewatch) SessionWatcher); ok {
r0 = rf(prefix, r, revision, rewatch) r0 = rf(prefix, r, revision, rewatch)
} else { } else {
if ret.Get(0) != nil { if ret.Get(0) != nil {
r0 = ret.Get(0).(<-chan *SessionEvent) r0 = ret.Get(0).(SessionWatcher)
} }
} }
@ -962,12 +1022,12 @@ func (_c *MockSession_WatchServicesWithVersionRange_Call) Run(run func(prefix st
return _c return _c
} }
func (_c *MockSession_WatchServicesWithVersionRange_Call) Return(eventChannel <-chan *SessionEvent) *MockSession_WatchServicesWithVersionRange_Call { func (_c *MockSession_WatchServicesWithVersionRange_Call) Return(watcher SessionWatcher) *MockSession_WatchServicesWithVersionRange_Call {
_c.Call.Return(eventChannel) _c.Call.Return(watcher)
return _c return _c
} }
func (_c *MockSession_WatchServicesWithVersionRange_Call) RunAndReturn(run func(string, semver.Range, int64, Rewatch) <-chan *SessionEvent) *MockSession_WatchServicesWithVersionRange_Call { func (_c *MockSession_WatchServicesWithVersionRange_Call) RunAndReturn(run func(string, semver.Range, int64, Rewatch) SessionWatcher) *MockSession_WatchServicesWithVersionRange_Call {
_c.Call.Return(run) _c.Call.Return(run)
return _c return _c
} }

View File

@ -0,0 +1,111 @@
// Code generated by mockery v2.53.3. DO NOT EDIT.
package sessionutil
import mock "github.com/stretchr/testify/mock"
// MockSessionWatcher is an autogenerated mock type for the SessionWatcher type
type MockSessionWatcher struct {
mock.Mock
}
type MockSessionWatcher_Expecter struct {
mock *mock.Mock
}
func (_m *MockSessionWatcher) EXPECT() *MockSessionWatcher_Expecter {
return &MockSessionWatcher_Expecter{mock: &_m.Mock}
}
// EventChannel provides a mock function with no fields
func (_m *MockSessionWatcher) EventChannel() <-chan *SessionEvent {
ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for EventChannel")
}
var r0 <-chan *SessionEvent
if rf, ok := ret.Get(0).(func() <-chan *SessionEvent); ok {
r0 = rf()
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(<-chan *SessionEvent)
}
}
return r0
}
// MockSessionWatcher_EventChannel_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'EventChannel'
type MockSessionWatcher_EventChannel_Call struct {
*mock.Call
}
// EventChannel is a helper method to define mock.On call
func (_e *MockSessionWatcher_Expecter) EventChannel() *MockSessionWatcher_EventChannel_Call {
return &MockSessionWatcher_EventChannel_Call{Call: _e.mock.On("EventChannel")}
}
func (_c *MockSessionWatcher_EventChannel_Call) Run(run func()) *MockSessionWatcher_EventChannel_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockSessionWatcher_EventChannel_Call) Return(_a0 <-chan *SessionEvent) *MockSessionWatcher_EventChannel_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockSessionWatcher_EventChannel_Call) RunAndReturn(run func() <-chan *SessionEvent) *MockSessionWatcher_EventChannel_Call {
_c.Call.Return(run)
return _c
}
// Stop provides a mock function with no fields
func (_m *MockSessionWatcher) Stop() {
_m.Called()
}
// MockSessionWatcher_Stop_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Stop'
type MockSessionWatcher_Stop_Call struct {
*mock.Call
}
// Stop is a helper method to define mock.On call
func (_e *MockSessionWatcher_Expecter) Stop() *MockSessionWatcher_Stop_Call {
return &MockSessionWatcher_Stop_Call{Call: _e.mock.On("Stop")}
}
func (_c *MockSessionWatcher_Stop_Call) Run(run func()) *MockSessionWatcher_Stop_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockSessionWatcher_Stop_Call) Return() *MockSessionWatcher_Stop_Call {
_c.Call.Return()
return _c
}
func (_c *MockSessionWatcher_Stop_Call) RunAndReturn(run func()) *MockSessionWatcher_Stop_Call {
_c.Run(run)
return _c
}
// NewMockSessionWatcher creates a new instance of MockSessionWatcher. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewMockSessionWatcher(t interface {
mock.TestingT
Cleanup(func())
}) *MockSessionWatcher {
mock := &MockSessionWatcher{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}

View File

@ -34,8 +34,8 @@ type SessionInterface interface {
GetSessionsWithVersionRange(prefix string, r semver.Range) (map[string]*Session, int64, error) GetSessionsWithVersionRange(prefix string, r semver.Range) (map[string]*Session, int64, error)
GoingStop() error GoingStop() error
WatchServices(prefix string, revision int64, rewatch Rewatch) (eventChannel <-chan *SessionEvent) WatchServices(prefix string, revision int64, rewatch Rewatch) (watcher SessionWatcher)
WatchServicesWithVersionRange(prefix string, r semver.Range, revision int64, rewatch Rewatch) (eventChannel <-chan *SessionEvent) WatchServicesWithVersionRange(prefix string, r semver.Range, revision int64, rewatch Rewatch) (watcher SessionWatcher)
LivenessCheck(ctx context.Context, callback func()) LivenessCheck(ctx context.Context, callback func())
Stop() Stop()
Revoke(timeout time.Duration) Revoke(timeout time.Duration)
@ -51,3 +51,8 @@ type SessionInterface interface {
GetServerID() int64 GetServerID() int64
IsTriggerKill() bool IsTriggerKill() bool
} }
type SessionWatcher interface {
EventChannel() <-chan *SessionEvent
Stop()
}

View File

@ -765,11 +765,13 @@ type SessionEvent struct {
type sessionWatcher struct { type sessionWatcher struct {
s *Session s *Session
cancel context.CancelFunc
rch clientv3.WatchChan rch clientv3.WatchChan
eventCh chan *SessionEvent eventCh chan *SessionEvent
prefix string prefix string
rewatch Rewatch rewatch Rewatch
validate func(*Session) bool validate func(*Session) bool
wg sync.WaitGroup
closeOnce sync.Once closeOnce sync.Once
} }
@ -779,15 +781,17 @@ func (w *sessionWatcher) closeEventCh() {
}) })
} }
func (w *sessionWatcher) start() { func (w *sessionWatcher) start(ctx context.Context) {
w.wg.Add(1)
go func() { go func() {
defer w.closeEventCh() defer w.wg.Done()
for { for {
select { select {
case <-w.s.ctx.Done(): case <-ctx.Done():
return return
case wresp, ok := <-w.rch: case wresp, ok := <-w.rch:
if !ok { if !ok {
w.closeEventCh()
log.Warn("session watch channel closed") log.Warn("session watch channel closed")
return return
} }
@ -797,6 +801,11 @@ func (w *sessionWatcher) start() {
}() }()
} }
func (w *sessionWatcher) Stop() {
w.cancel()
w.wg.Wait()
}
// WatchServices watches the service's up and down in etcd, and sends event to // WatchServices watches the service's up and down in etcd, and sends event to
// eventChannel. // eventChannel.
// prefix is a parameter to know which service to watch and can be obtained in // prefix is a parameter to know which service to watch and can be obtained in
@ -805,17 +814,19 @@ func (w *sessionWatcher) start() {
// in GetSessions. // in GetSessions.
// If a server up, an event will be add to channel with eventType SessionAddType. // If a server up, an event will be add to channel with eventType SessionAddType.
// If a server down, an event will be add to channel with eventType SessionDelType. // If a server down, an event will be add to channel with eventType SessionDelType.
func (s *Session) WatchServices(prefix string, revision int64, rewatch Rewatch) (eventChannel <-chan *SessionEvent) { func (s *Session) WatchServices(prefix string, revision int64, rewatch Rewatch) (watcher SessionWatcher) {
ctx, cancel := context.WithCancel(s.ctx)
w := &sessionWatcher{ w := &sessionWatcher{
s: s, s: s,
cancel: cancel,
eventCh: make(chan *SessionEvent, 100), eventCh: make(chan *SessionEvent, 100),
rch: s.etcdCli.Watch(s.ctx, path.Join(s.metaRoot, DefaultServiceRoot, prefix), clientv3.WithPrefix(), clientv3.WithPrevKV(), clientv3.WithRev(revision)), rch: s.etcdCli.Watch(s.ctx, path.Join(s.metaRoot, DefaultServiceRoot, prefix), clientv3.WithPrefix(), clientv3.WithPrevKV(), clientv3.WithRev(revision)),
prefix: prefix, prefix: prefix,
rewatch: rewatch, rewatch: rewatch,
validate: func(s *Session) bool { return true }, validate: func(s *Session) bool { return true },
} }
w.start() w.start(ctx)
return w.eventCh return w
} }
// WatchServicesWithVersionRange watches the service's up and down in etcd, and sends event to event Channel. // WatchServicesWithVersionRange watches the service's up and down in etcd, and sends event to event Channel.
@ -824,17 +835,19 @@ func (s *Session) WatchServices(prefix string, revision int64, rewatch Rewatch)
// revision is a etcd reversion to prevent missing key events and can be obtained in GetSessions. // revision is a etcd reversion to prevent missing key events and can be obtained in GetSessions.
// If a server up, an event will be add to channel with eventType SessionAddType. // If a server up, an event will be add to channel with eventType SessionAddType.
// If a server down, an event will be add to channel with eventType SessionDelType. // If a server down, an event will be add to channel with eventType SessionDelType.
func (s *Session) WatchServicesWithVersionRange(prefix string, r semver.Range, revision int64, rewatch Rewatch) (eventChannel <-chan *SessionEvent) { func (s *Session) WatchServicesWithVersionRange(prefix string, r semver.Range, revision int64, rewatch Rewatch) (watcher SessionWatcher) {
ctx, cancel := context.WithCancel(s.ctx)
w := &sessionWatcher{ w := &sessionWatcher{
s: s, s: s,
cancel: cancel,
eventCh: make(chan *SessionEvent, 100), eventCh: make(chan *SessionEvent, 100),
rch: s.etcdCli.Watch(s.ctx, path.Join(s.metaRoot, DefaultServiceRoot, prefix), clientv3.WithPrefix(), clientv3.WithPrevKV(), clientv3.WithRev(revision)), rch: s.etcdCli.Watch(s.ctx, path.Join(s.metaRoot, DefaultServiceRoot, prefix), clientv3.WithPrefix(), clientv3.WithPrevKV(), clientv3.WithRev(revision)),
prefix: prefix, prefix: prefix,
rewatch: rewatch, rewatch: rewatch,
validate: func(s *Session) bool { return r(s.Version) }, validate: func(s *Session) bool { return r(s.Version) },
} }
w.start() w.start(ctx)
return w.eventCh return w
} }
func (w *sessionWatcher) handleWatchResponse(wresp clientv3.WatchResponse) { func (w *sessionWatcher) handleWatchResponse(wresp clientv3.WatchResponse) {
@ -919,6 +932,10 @@ func (w *sessionWatcher) handleWatchErr(err error) error {
return nil return nil
} }
func (w *sessionWatcher) EventChannel() <-chan *SessionEvent {
return w.eventCh
}
// LivenessCheck performs liveness check with provided context and channel // LivenessCheck performs liveness check with provided context and channel
// ctx controls the liveness check loop // ctx controls the liveness check loop
// ch is the liveness signal channel, ch is closed only when the session is expired // ch is the liveness signal channel, ch is closed only when the session is expired

View File

@ -158,7 +158,7 @@ func TestUpdateSessions(t *testing.T) {
sessions, rev, err := s.GetSessions("test") sessions, rev, err := s.GetSessions("test")
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, len(sessions), 0) assert.Equal(t, len(sessions), 0)
eventCh := s.WatchServices("test", rev, nil) watcher := s.WatchServices("test", rev, nil)
sList := []*Session{} sList := []*Session{}
@ -203,7 +203,7 @@ LOOP:
select { select {
case <-ch: case <-ch:
t.FailNow() t.FailNow()
case sessionEvent := <-eventCh: case sessionEvent := <-watcher.EventChannel():
if sessionEvent.EventType == SessionAddEvent { if sessionEvent.EventType == SessionAddEvent {
addEventLen++ addEventLen++
@ -616,7 +616,7 @@ func (suite *SessionWithVersionSuite) TestWatchServicesWithVersionRange() {
_, rev, err := s.GetSessionsWithVersionRange(suite.serverName, r) _, rev, err := s.GetSessionsWithVersionRange(suite.serverName, r)
suite.Require().NoError(err) suite.Require().NoError(err)
ch := s.WatchServicesWithVersionRange(suite.serverName, r, rev, nil) watcher := s.WatchServicesWithVersionRange(suite.serverName, r, rev, nil)
// remove all sessions // remove all sessions
go func() { go func() {
@ -626,7 +626,7 @@ func (suite *SessionWithVersionSuite) TestWatchServicesWithVersionRange() {
}() }()
select { select {
case evt := <-ch: case evt := <-watcher.EventChannel():
suite.Equal(suite.sessions[1].ServerID, evt.Session.ServerID) suite.Equal(suite.sessions[1].ServerID, evt.Session.ServerID)
case <-time.After(time.Second): case <-time.After(time.Second):
suite.Fail("no event received, failing") suite.Fail("no event received, failing")