From 44d915a43b6c8f7e99fffac310be09ead332bb2b Mon Sep 17 00:00:00 2001 From: "yihao.dai" Date: Tue, 23 Dec 2025 21:11:19 +0800 Subject: [PATCH] fix: [2.5] Remove stale proxy clients on rewatch etcd (#46491) ### **User description** AddProxyClients now removes clients not in the new snapshot before adding new ones. This ensures proper cleanup when ProxyWatcher re-watche etcd. issue: https://github.com/milvus-io/milvus/issues/46397 pr: https://github.com/milvus-io/milvus/pull/46398 ___ ### **PR Type** Bug fix ___ ### **Description** - Rename `AddProxyClients` to `SetProxyClients` for clearer semantics - Implement stale client cleanup before adding new proxy clients - Remove proxy clients not present in new etcd snapshot - Update all callers in querycoord and rootcoord servers - Regenerate mock files with mockery v2.53.3 ___ ### Diagram Walkthrough ```mermaid flowchart LR A["ProxyWatcher detects
etcd change"] -->|calls| B["SetProxyClients
with new snapshot"] B -->|removes| C["Stale clients
not in snapshot"] C -->|closes| D["Cleanup resources"] B -->|adds| E["New proxy clients
from snapshot"] ```

File Walkthrough

Relevant files
Bug fix
3 files
proxy_client_manager.go
Rename AddProxyClients to SetProxyClients with cleanup     
+22/-2   
server.go
Update ProxyWatcher to use SetProxyClients                             
+1/-1     
root_coord.go
Update ProxyWatcher initialization to SetProxyClients       
+2/-2     
Tests
1 files
proxy_client_manager_test.go
Update test for SetProxyClients stale removal                       
+26/-10 
Miscellaneous
7 files
mock_proxy_client_manager.go
Regenerate mock with SetProxyClients method                           
+78/-38 
mock_proxy_watcher.go
Regenerate mock with mockery v2.53.3                                         
+9/-5     
mock_global_id_allocator.go
Regenerate mock with mockery v2.53.3                                         
+15/-3   
mock_grpc_client.go
Regenerate mock with mockery v2.53.3                                         
+33/-13 
allocator.go
Regenerate mock with mockery v2.53.3                                         
+26/-6   
mock_factory.go
Regenerate mock with mockery v2.53.3                                         
+18/-2   
mock_session.go
Regenerate mock with mockery v2.53.3                                         
+79/-19 
___ Signed-off-by: bigsheeper --- .../allocator/mock_global_id_allocator.go | 18 ++- internal/mocks/mock_grpc_client.go | 46 +++++-- internal/querycoordv2/server.go | 2 +- internal/rootcoord/root_coord.go | 4 +- internal/tso/mocks/allocator.go | 32 ++++- internal/util/dependency/mock_factory.go | 20 ++- .../proxyutil/mock_proxy_client_manager.go | 116 ++++++++++++------ internal/util/proxyutil/mock_proxy_watcher.go | 14 ++- .../util/proxyutil/proxy_client_manager.go | 24 +++- .../proxyutil/proxy_client_manager_test.go | 40 ++++-- internal/util/sessionutil/mock_session.go | 98 ++++++++++++--- 11 files changed, 311 insertions(+), 103 deletions(-) diff --git a/internal/allocator/mock_global_id_allocator.go b/internal/allocator/mock_global_id_allocator.go index b1f0b4d246..aedaee9710 100644 --- a/internal/allocator/mock_global_id_allocator.go +++ b/internal/allocator/mock_global_id_allocator.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.32.4. DO NOT EDIT. +// Code generated by mockery v2.53.3. DO NOT EDIT. package allocator @@ -21,6 +21,10 @@ func (_m *MockGlobalIDAllocator) EXPECT() *MockGlobalIDAllocator_Expecter { func (_m *MockGlobalIDAllocator) Alloc(count uint32) (int64, int64, error) { ret := _m.Called(count) + if len(ret) == 0 { + panic("no return value specified for Alloc") + } + var r0 int64 var r1 int64 var r2 error @@ -76,10 +80,14 @@ func (_c *MockGlobalIDAllocator_Alloc_Call) RunAndReturn(run func(uint32) (int64 return _c } -// AllocOne provides a mock function with given fields: +// AllocOne provides a mock function with no fields func (_m *MockGlobalIDAllocator) AllocOne() (int64, error) { ret := _m.Called() + if len(ret) == 0 { + panic("no return value specified for AllocOne") + } + var r0 int64 var r1 error if rf, ok := ret.Get(0).(func() (int64, error)); ok { @@ -127,10 +135,14 @@ func (_c *MockGlobalIDAllocator_AllocOne_Call) RunAndReturn(run func() (int64, e return _c } -// Initialize provides a mock function with given fields: +// Initialize provides a mock function with no fields func (_m *MockGlobalIDAllocator) Initialize() error { ret := _m.Called() + if len(ret) == 0 { + panic("no return value specified for Initialize") + } + var r0 error if rf, ok := ret.Get(0).(func() error); ok { r0 = rf() diff --git a/internal/mocks/mock_grpc_client.go b/internal/mocks/mock_grpc_client.go index 383f8ff330..bcf4589f0e 100644 --- a/internal/mocks/mock_grpc_client.go +++ b/internal/mocks/mock_grpc_client.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.32.4. DO NOT EDIT. +// Code generated by mockery v2.53.3. DO NOT EDIT. package mocks @@ -33,6 +33,10 @@ func (_m *MockGrpcClient[T]) EXPECT() *MockGrpcClient_Expecter[T] { func (_m *MockGrpcClient[T]) Call(ctx context.Context, caller func(T) (interface{}, error)) (interface{}, error) { ret := _m.Called(ctx, caller) + if len(ret) == 0 { + panic("no return value specified for Call") + } + var r0 interface{} var r1 error if rf, ok := ret.Get(0).(func(context.Context, func(T) (interface{}, error)) (interface{}, error)); ok { @@ -84,10 +88,14 @@ func (_c *MockGrpcClient_Call_Call[T]) RunAndReturn(run func(context.Context, fu return _c } -// Close provides a mock function with given fields: +// Close provides a mock function with no fields func (_m *MockGrpcClient[T]) Close() error { ret := _m.Called() + if len(ret) == 0 { + panic("no return value specified for Close") + } + var r0 error if rf, ok := ret.Get(0).(func() error); ok { r0 = rf() @@ -125,7 +133,7 @@ func (_c *MockGrpcClient_Close_Call[T]) RunAndReturn(run func() error) *MockGrpc return _c } -// EnableEncryption provides a mock function with given fields: +// EnableEncryption provides a mock function with no fields func (_m *MockGrpcClient[T]) EnableEncryption() { _m.Called() } @@ -153,14 +161,18 @@ func (_c *MockGrpcClient_EnableEncryption_Call[T]) Return() *MockGrpcClient_Enab } func (_c *MockGrpcClient_EnableEncryption_Call[T]) RunAndReturn(run func()) *MockGrpcClient_EnableEncryption_Call[T] { - _c.Call.Return(run) + _c.Run(run) return _c } -// GetNodeID provides a mock function with given fields: +// GetNodeID provides a mock function with no fields func (_m *MockGrpcClient[T]) GetNodeID() int64 { ret := _m.Called() + if len(ret) == 0 { + panic("no return value specified for GetNodeID") + } + var r0 int64 if rf, ok := ret.Get(0).(func() int64); ok { r0 = rf() @@ -198,10 +210,14 @@ func (_c *MockGrpcClient_GetNodeID_Call[T]) RunAndReturn(run func() int64) *Mock return _c } -// GetRole provides a mock function with given fields: +// GetRole provides a mock function with no fields func (_m *MockGrpcClient[T]) GetRole() string { ret := _m.Called() + if len(ret) == 0 { + panic("no return value specified for GetRole") + } + var r0 string if rf, ok := ret.Get(0).(func() string); ok { r0 = rf() @@ -243,6 +259,10 @@ func (_c *MockGrpcClient_GetRole_Call[T]) RunAndReturn(run func() string) *MockG func (_m *MockGrpcClient[T]) ReCall(ctx context.Context, caller func(T) (interface{}, error)) (interface{}, error) { ret := _m.Called(ctx, caller) + if len(ret) == 0 { + panic("no return value specified for ReCall") + } + var r0 interface{} var r1 error if rf, ok := ret.Get(0).(func(context.Context, func(T) (interface{}, error)) (interface{}, error)); ok { @@ -323,7 +343,7 @@ func (_c *MockGrpcClient_SetGetAddrFunc_Call[T]) Return() *MockGrpcClient_SetGet } func (_c *MockGrpcClient_SetGetAddrFunc_Call[T]) RunAndReturn(run func(func() (string, error))) *MockGrpcClient_SetGetAddrFunc_Call[T] { - _c.Call.Return(run) + _c.Run(run) return _c } @@ -356,7 +376,7 @@ func (_c *MockGrpcClient_SetInternalTLSCertPool_Call[T]) Return() *MockGrpcClien } func (_c *MockGrpcClient_SetInternalTLSCertPool_Call[T]) RunAndReturn(run func(*x509.CertPool)) *MockGrpcClient_SetInternalTLSCertPool_Call[T] { - _c.Call.Return(run) + _c.Run(run) return _c } @@ -389,7 +409,7 @@ func (_c *MockGrpcClient_SetInternalTLSServerName_Call[T]) Return() *MockGrpcCli } func (_c *MockGrpcClient_SetInternalTLSServerName_Call[T]) RunAndReturn(run func(string)) *MockGrpcClient_SetInternalTLSServerName_Call[T] { - _c.Call.Return(run) + _c.Run(run) return _c } @@ -422,7 +442,7 @@ func (_c *MockGrpcClient_SetNewGrpcClientFunc_Call[T]) Return() *MockGrpcClient_ } func (_c *MockGrpcClient_SetNewGrpcClientFunc_Call[T]) RunAndReturn(run func(func(*grpc.ClientConn) T)) *MockGrpcClient_SetNewGrpcClientFunc_Call[T] { - _c.Call.Return(run) + _c.Run(run) return _c } @@ -455,7 +475,7 @@ func (_c *MockGrpcClient_SetNodeID_Call[T]) Return() *MockGrpcClient_SetNodeID_C } func (_c *MockGrpcClient_SetNodeID_Call[T]) RunAndReturn(run func(int64)) *MockGrpcClient_SetNodeID_Call[T] { - _c.Call.Return(run) + _c.Run(run) return _c } @@ -488,7 +508,7 @@ func (_c *MockGrpcClient_SetRole_Call[T]) Return() *MockGrpcClient_SetRole_Call[ } func (_c *MockGrpcClient_SetRole_Call[T]) RunAndReturn(run func(string)) *MockGrpcClient_SetRole_Call[T] { - _c.Call.Return(run) + _c.Run(run) return _c } @@ -521,7 +541,7 @@ func (_c *MockGrpcClient_SetSession_Call[T]) Return() *MockGrpcClient_SetSession } func (_c *MockGrpcClient_SetSession_Call[T]) RunAndReturn(run func(*sessionutil.Session)) *MockGrpcClient_SetSession_Call[T] { - _c.Call.Return(run) + _c.Run(run) return _c } diff --git a/internal/querycoordv2/server.go b/internal/querycoordv2/server.go index 756460916e..606a129a25 100644 --- a/internal/querycoordv2/server.go +++ b/internal/querycoordv2/server.go @@ -353,7 +353,7 @@ func (s *Server) initQueryCoord() error { s.proxyClientManager = proxyutil.NewProxyClientManager(proxyutil.DefaultProxyCreator) s.proxyWatcher = proxyutil.NewProxyWatcher( s.etcdCli, - s.proxyClientManager.AddProxyClients, + s.proxyClientManager.SetProxyClients, ) s.proxyWatcher.AddSessionFunc(s.proxyClientManager.AddProxyClient) s.proxyWatcher.DelSessionFunc(s.proxyClientManager.DelProxyClient) diff --git a/internal/rootcoord/root_coord.go b/internal/rootcoord/root_coord.go index 0de5ca1884..2b806805ac 100644 --- a/internal/rootcoord/root_coord.go +++ b/internal/rootcoord/root_coord.go @@ -485,14 +485,14 @@ func (c *Core) initInternal() error { c.proxyWatcher = proxyutil.NewProxyWatcher( c.etcdCli, c.chanTimeTick.initSessions, - c.proxyClientManager.AddProxyClients, + c.proxyClientManager.SetProxyClients, ) c.proxyWatcher.AddSessionFunc(c.chanTimeTick.addSession, c.proxyClientManager.AddProxyClient) c.proxyWatcher.DelSessionFunc(c.chanTimeTick.delSession, c.proxyClientManager.DelProxyClient) } else { c.proxyWatcher = proxyutil.NewProxyWatcher( c.etcdCli, - c.proxyClientManager.AddProxyClients, + c.proxyClientManager.SetProxyClients, ) c.proxyWatcher.AddSessionFunc(c.proxyClientManager.AddProxyClient) c.proxyWatcher.DelSessionFunc(c.proxyClientManager.DelProxyClient) diff --git a/internal/tso/mocks/allocator.go b/internal/tso/mocks/allocator.go index 2d5d7308d9..f0a8a6244f 100644 --- a/internal/tso/mocks/allocator.go +++ b/internal/tso/mocks/allocator.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.32.4. DO NOT EDIT. +// Code generated by mockery v2.53.3. DO NOT EDIT. package mocktso @@ -25,6 +25,10 @@ func (_m *Allocator) EXPECT() *Allocator_Expecter { func (_m *Allocator) GenerateTSO(count uint32) (uint64, error) { ret := _m.Called(count) + if len(ret) == 0 { + panic("no return value specified for GenerateTSO") + } + var r0 uint64 var r1 error if rf, ok := ret.Get(0).(func(uint32) (uint64, error)); ok { @@ -73,10 +77,14 @@ func (_c *Allocator_GenerateTSO_Call) RunAndReturn(run func(uint32) (uint64, err return _c } -// GetLastSavedTime provides a mock function with given fields: +// GetLastSavedTime provides a mock function with no fields func (_m *Allocator) GetLastSavedTime() time.Time { ret := _m.Called() + if len(ret) == 0 { + panic("no return value specified for GetLastSavedTime") + } + var r0 time.Time if rf, ok := ret.Get(0).(func() time.Time); ok { r0 = rf() @@ -114,10 +122,14 @@ func (_c *Allocator_GetLastSavedTime_Call) RunAndReturn(run func() time.Time) *A return _c } -// Initialize provides a mock function with given fields: +// Initialize provides a mock function with no fields func (_m *Allocator) Initialize() error { ret := _m.Called() + if len(ret) == 0 { + panic("no return value specified for Initialize") + } + var r0 error if rf, ok := ret.Get(0).(func() error); ok { r0 = rf() @@ -155,7 +167,7 @@ func (_c *Allocator_Initialize_Call) RunAndReturn(run func() error) *Allocator_I return _c } -// Reset provides a mock function with given fields: +// Reset provides a mock function with no fields func (_m *Allocator) Reset() { _m.Called() } @@ -183,7 +195,7 @@ func (_c *Allocator_Reset_Call) Return() *Allocator_Reset_Call { } func (_c *Allocator_Reset_Call) RunAndReturn(run func()) *Allocator_Reset_Call { - _c.Call.Return(run) + _c.Run(run) return _c } @@ -191,6 +203,10 @@ func (_c *Allocator_Reset_Call) RunAndReturn(run func()) *Allocator_Reset_Call { func (_m *Allocator) SetTSO(_a0 uint64) error { ret := _m.Called(_a0) + if len(ret) == 0 { + panic("no return value specified for SetTSO") + } + var r0 error if rf, ok := ret.Get(0).(func(uint64) error); ok { r0 = rf(_a0) @@ -229,10 +245,14 @@ func (_c *Allocator_SetTSO_Call) RunAndReturn(run func(uint64) error) *Allocator return _c } -// UpdateTSO provides a mock function with given fields: +// UpdateTSO provides a mock function with no fields func (_m *Allocator) UpdateTSO() error { ret := _m.Called() + if len(ret) == 0 { + panic("no return value specified for UpdateTSO") + } + var r0 error if rf, ok := ret.Get(0).(func() error); ok { r0 = rf() diff --git a/internal/util/dependency/mock_factory.go b/internal/util/dependency/mock_factory.go index 444020a203..bcb43b5708 100644 --- a/internal/util/dependency/mock_factory.go +++ b/internal/util/dependency/mock_factory.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.32.4. DO NOT EDIT. +// Code generated by mockery v2.53.3. DO NOT EDIT. package dependency @@ -55,7 +55,7 @@ func (_c *MockFactory_Init_Call) Return() *MockFactory_Init_Call { } func (_c *MockFactory_Init_Call) RunAndReturn(run func(*paramtable.ComponentParam)) *MockFactory_Init_Call { - _c.Call.Return(run) + _c.Run(run) return _c } @@ -63,6 +63,10 @@ func (_c *MockFactory_Init_Call) RunAndReturn(run func(*paramtable.ComponentPara func (_m *MockFactory) NewMsgStream(ctx context.Context) (msgstream.MsgStream, error) { ret := _m.Called(ctx) + if len(ret) == 0 { + panic("no return value specified for NewMsgStream") + } + var r0 msgstream.MsgStream var r1 error if rf, ok := ret.Get(0).(func(context.Context) (msgstream.MsgStream, error)); ok { @@ -117,6 +121,10 @@ func (_c *MockFactory_NewMsgStream_Call) RunAndReturn(run func(context.Context) func (_m *MockFactory) NewMsgStreamDisposer(ctx context.Context) func([]string, string) error { ret := _m.Called(ctx) + if len(ret) == 0 { + panic("no return value specified for NewMsgStreamDisposer") + } + var r0 func([]string, string) error if rf, ok := ret.Get(0).(func(context.Context) func([]string, string) error); ok { r0 = rf(ctx) @@ -161,6 +169,10 @@ func (_c *MockFactory_NewMsgStreamDisposer_Call) RunAndReturn(run func(context.C func (_m *MockFactory) NewPersistentStorageChunkManager(ctx context.Context) (storage.ChunkManager, error) { ret := _m.Called(ctx) + if len(ret) == 0 { + panic("no return value specified for NewPersistentStorageChunkManager") + } + var r0 storage.ChunkManager var r1 error if rf, ok := ret.Get(0).(func(context.Context) (storage.ChunkManager, error)); ok { @@ -215,6 +227,10 @@ func (_c *MockFactory_NewPersistentStorageChunkManager_Call) RunAndReturn(run fu func (_m *MockFactory) NewTtMsgStream(ctx context.Context) (msgstream.MsgStream, error) { ret := _m.Called(ctx) + if len(ret) == 0 { + panic("no return value specified for NewTtMsgStream") + } + var r0 msgstream.MsgStream var r1 error if rf, ok := ret.Get(0).(func(context.Context) (msgstream.MsgStream, error)); ok { diff --git a/internal/util/proxyutil/mock_proxy_client_manager.go b/internal/util/proxyutil/mock_proxy_client_manager.go index e2244e6b76..81f368d81c 100644 --- a/internal/util/proxyutil/mock_proxy_client_manager.go +++ b/internal/util/proxyutil/mock_proxy_client_manager.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.32.4. DO NOT EDIT. +// Code generated by mockery v2.53.3. DO NOT EDIT. package proxyutil @@ -59,40 +59,7 @@ func (_c *MockProxyClientManager_AddProxyClient_Call) Return() *MockProxyClientM } func (_c *MockProxyClientManager_AddProxyClient_Call) RunAndReturn(run func(*sessionutil.Session)) *MockProxyClientManager_AddProxyClient_Call { - _c.Call.Return(run) - return _c -} - -// AddProxyClients provides a mock function with given fields: session -func (_m *MockProxyClientManager) AddProxyClients(session []*sessionutil.Session) { - _m.Called(session) -} - -// MockProxyClientManager_AddProxyClients_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AddProxyClients' -type MockProxyClientManager_AddProxyClients_Call struct { - *mock.Call -} - -// AddProxyClients is a helper method to define mock.On call -// - session []*sessionutil.Session -func (_e *MockProxyClientManager_Expecter) AddProxyClients(session interface{}) *MockProxyClientManager_AddProxyClients_Call { - return &MockProxyClientManager_AddProxyClients_Call{Call: _e.mock.On("AddProxyClients", session)} -} - -func (_c *MockProxyClientManager_AddProxyClients_Call) Run(run func(session []*sessionutil.Session)) *MockProxyClientManager_AddProxyClients_Call { - _c.Call.Run(func(args mock.Arguments) { - run(args[0].([]*sessionutil.Session)) - }) - return _c -} - -func (_c *MockProxyClientManager_AddProxyClients_Call) Return() *MockProxyClientManager_AddProxyClients_Call { - _c.Call.Return() - return _c -} - -func (_c *MockProxyClientManager_AddProxyClients_Call) RunAndReturn(run func([]*sessionutil.Session)) *MockProxyClientManager_AddProxyClients_Call { - _c.Call.Return(run) + _c.Run(run) return _c } @@ -125,7 +92,7 @@ func (_c *MockProxyClientManager_DelProxyClient_Call) Return() *MockProxyClientM } func (_c *MockProxyClientManager_DelProxyClient_Call) RunAndReturn(run func(*sessionutil.Session)) *MockProxyClientManager_DelProxyClient_Call { - _c.Call.Return(run) + _c.Run(run) return _c } @@ -133,6 +100,10 @@ func (_c *MockProxyClientManager_DelProxyClient_Call) RunAndReturn(run func(*ses func (_m *MockProxyClientManager) GetComponentStates(ctx context.Context) (map[int64]*milvuspb.ComponentStates, error) { ret := _m.Called(ctx) + if len(ret) == 0 { + panic("no return value specified for GetComponentStates") + } + var r0 map[int64]*milvuspb.ComponentStates var r1 error if rf, ok := ret.Get(0).(func(context.Context) (map[int64]*milvuspb.ComponentStates, error)); ok { @@ -183,10 +154,14 @@ func (_c *MockProxyClientManager_GetComponentStates_Call) RunAndReturn(run func( return _c } -// GetProxyClients provides a mock function with given fields: +// GetProxyClients provides a mock function with no fields func (_m *MockProxyClientManager) GetProxyClients() *typeutil.ConcurrentMap[int64, types.ProxyClient] { ret := _m.Called() + if len(ret) == 0 { + panic("no return value specified for GetProxyClients") + } + var r0 *typeutil.ConcurrentMap[int64, types.ProxyClient] if rf, ok := ret.Get(0).(func() *typeutil.ConcurrentMap[int64, types.ProxyClient]); ok { r0 = rf() @@ -226,10 +201,14 @@ func (_c *MockProxyClientManager_GetProxyClients_Call) RunAndReturn(run func() * return _c } -// GetProxyCount provides a mock function with given fields: +// GetProxyCount provides a mock function with no fields func (_m *MockProxyClientManager) GetProxyCount() int { ret := _m.Called() + if len(ret) == 0 { + panic("no return value specified for GetProxyCount") + } + var r0 int if rf, ok := ret.Get(0).(func() int); ok { r0 = rf() @@ -271,6 +250,10 @@ func (_c *MockProxyClientManager_GetProxyCount_Call) RunAndReturn(run func() int func (_m *MockProxyClientManager) GetProxyMetrics(ctx context.Context) ([]*milvuspb.GetMetricsResponse, error) { ret := _m.Called(ctx) + if len(ret) == 0 { + panic("no return value specified for GetProxyMetrics") + } + var r0 []*milvuspb.GetMetricsResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context) ([]*milvuspb.GetMetricsResponse, error)); ok { @@ -332,6 +315,10 @@ func (_m *MockProxyClientManager) InvalidateCollectionMetaCache(ctx context.Cont _ca = append(_ca, _va...) ret := _m.Called(_ca...) + if len(ret) == 0 { + panic("no return value specified for InvalidateCollectionMetaCache") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context, *proxypb.InvalidateCollMetaCacheRequest, ...ExpireCacheOpt) error); ok { r0 = rf(ctx, request, opts...) @@ -383,6 +370,10 @@ func (_c *MockProxyClientManager_InvalidateCollectionMetaCache_Call) RunAndRetur func (_m *MockProxyClientManager) InvalidateCredentialCache(ctx context.Context, request *proxypb.InvalidateCredCacheRequest) error { ret := _m.Called(ctx, request) + if len(ret) == 0 { + panic("no return value specified for InvalidateCredentialCache") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context, *proxypb.InvalidateCredCacheRequest) error); ok { r0 = rf(ctx, request) @@ -426,6 +417,10 @@ func (_c *MockProxyClientManager_InvalidateCredentialCache_Call) RunAndReturn(ru func (_m *MockProxyClientManager) InvalidateShardLeaderCache(ctx context.Context, request *proxypb.InvalidateShardLeaderCacheRequest) error { ret := _m.Called(ctx, request) + if len(ret) == 0 { + panic("no return value specified for InvalidateShardLeaderCache") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context, *proxypb.InvalidateShardLeaderCacheRequest) error); ok { r0 = rf(ctx, request) @@ -469,6 +464,10 @@ func (_c *MockProxyClientManager_InvalidateShardLeaderCache_Call) RunAndReturn(r func (_m *MockProxyClientManager) RefreshPolicyInfoCache(ctx context.Context, req *proxypb.RefreshPolicyInfoCacheRequest) error { ret := _m.Called(ctx, req) + if len(ret) == 0 { + panic("no return value specified for RefreshPolicyInfoCache") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context, *proxypb.RefreshPolicyInfoCacheRequest) error); ok { r0 = rf(ctx, req) @@ -508,10 +507,47 @@ func (_c *MockProxyClientManager_RefreshPolicyInfoCache_Call) RunAndReturn(run f return _c } +// SetProxyClients provides a mock function with given fields: session +func (_m *MockProxyClientManager) SetProxyClients(session []*sessionutil.Session) { + _m.Called(session) +} + +// MockProxyClientManager_SetProxyClients_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetProxyClients' +type MockProxyClientManager_SetProxyClients_Call struct { + *mock.Call +} + +// SetProxyClients is a helper method to define mock.On call +// - session []*sessionutil.Session +func (_e *MockProxyClientManager_Expecter) SetProxyClients(session interface{}) *MockProxyClientManager_SetProxyClients_Call { + return &MockProxyClientManager_SetProxyClients_Call{Call: _e.mock.On("SetProxyClients", session)} +} + +func (_c *MockProxyClientManager_SetProxyClients_Call) Run(run func(session []*sessionutil.Session)) *MockProxyClientManager_SetProxyClients_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].([]*sessionutil.Session)) + }) + return _c +} + +func (_c *MockProxyClientManager_SetProxyClients_Call) Return() *MockProxyClientManager_SetProxyClients_Call { + _c.Call.Return() + return _c +} + +func (_c *MockProxyClientManager_SetProxyClients_Call) RunAndReturn(run func([]*sessionutil.Session)) *MockProxyClientManager_SetProxyClients_Call { + _c.Run(run) + return _c +} + // SetRates provides a mock function with given fields: ctx, request func (_m *MockProxyClientManager) SetRates(ctx context.Context, request *proxypb.SetRatesRequest) error { ret := _m.Called(ctx, request) + if len(ret) == 0 { + panic("no return value specified for SetRates") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context, *proxypb.SetRatesRequest) error); ok { r0 = rf(ctx, request) @@ -555,6 +591,10 @@ func (_c *MockProxyClientManager_SetRates_Call) RunAndReturn(run func(context.Co func (_m *MockProxyClientManager) UpdateCredentialCache(ctx context.Context, request *proxypb.UpdateCredCacheRequest) error { ret := _m.Called(ctx, request) + if len(ret) == 0 { + panic("no return value specified for UpdateCredentialCache") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context, *proxypb.UpdateCredCacheRequest) error); ok { r0 = rf(ctx, request) diff --git a/internal/util/proxyutil/mock_proxy_watcher.go b/internal/util/proxyutil/mock_proxy_watcher.go index 0aed0cb5b6..4f79d65a8c 100644 --- a/internal/util/proxyutil/mock_proxy_watcher.go +++ b/internal/util/proxyutil/mock_proxy_watcher.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.32.4. DO NOT EDIT. +// Code generated by mockery v2.53.3. DO NOT EDIT. package proxyutil @@ -64,7 +64,7 @@ func (_c *MockProxyWatcher_AddSessionFunc_Call) Return() *MockProxyWatcher_AddSe } func (_c *MockProxyWatcher_AddSessionFunc_Call) RunAndReturn(run func(...func(*sessionutil.Session))) *MockProxyWatcher_AddSessionFunc_Call { - _c.Call.Return(run) + _c.Run(run) return _c } @@ -110,11 +110,11 @@ func (_c *MockProxyWatcher_DelSessionFunc_Call) Return() *MockProxyWatcher_DelSe } func (_c *MockProxyWatcher_DelSessionFunc_Call) RunAndReturn(run func(...func(*sessionutil.Session))) *MockProxyWatcher_DelSessionFunc_Call { - _c.Call.Return(run) + _c.Run(run) return _c } -// Stop provides a mock function with given fields: +// Stop provides a mock function with no fields func (_m *MockProxyWatcher) Stop() { _m.Called() } @@ -142,7 +142,7 @@ func (_c *MockProxyWatcher_Stop_Call) Return() *MockProxyWatcher_Stop_Call { } func (_c *MockProxyWatcher_Stop_Call) RunAndReturn(run func()) *MockProxyWatcher_Stop_Call { - _c.Call.Return(run) + _c.Run(run) return _c } @@ -150,6 +150,10 @@ func (_c *MockProxyWatcher_Stop_Call) RunAndReturn(run func()) *MockProxyWatcher func (_m *MockProxyWatcher) WatchProxy(ctx context.Context) error { ret := _m.Called(ctx) + if len(ret) == 0 { + panic("no return value specified for WatchProxy") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context) error); ok { r0 = rf(ctx) diff --git a/internal/util/proxyutil/proxy_client_manager.go b/internal/util/proxyutil/proxy_client_manager.go index 7693683055..2f05dee2b2 100644 --- a/internal/util/proxyutil/proxy_client_manager.go +++ b/internal/util/proxyutil/proxy_client_manager.go @@ -22,6 +22,7 @@ import ( "sync" "github.com/cockroachdb/errors" + "github.com/samber/lo" "go.uber.org/zap" "golang.org/x/sync/errgroup" @@ -82,7 +83,7 @@ var defaultClientManagerHelper = ProxyClientManagerHelper{ type ProxyClientManagerInterface interface { AddProxyClient(session *sessionutil.Session) - AddProxyClients(session []*sessionutil.Session) + SetProxyClients(session []*sessionutil.Session) GetProxyClients() *typeutil.ConcurrentMap[int64, types.ProxyClient] DelProxyClient(s *sessionutil.Session) GetProxyCount() int @@ -111,7 +112,26 @@ func NewProxyClientManager(creator ProxyCreator) *ProxyClientManager { } } -func (p *ProxyClientManager) AddProxyClients(sessions []*sessionutil.Session) { +// SetProxyClients sets proxy clients from a full snapshot of sessions. +// It removes stale clients not in the new snapshot and adds new ones. +// This is called during initial setup or when re-watching after etcd error. +func (p *ProxyClientManager) SetProxyClients(sessions []*sessionutil.Session) { + aliveSessions := lo.KeyBy(sessions, func(session *sessionutil.Session) int64 { + return session.ServerID + }) + + // Remove stale clients not in the alive sessions + p.proxyClient.Range(func(key int64, value types.ProxyClient) bool { + if _, ok := aliveSessions[key]; !ok { + if cli, loaded := p.proxyClient.GetAndRemove(key); loaded { + cli.Close() + log.Info("remove stale proxy client", zap.Int64("serverID", key)) + } + } + return true + }) + + // Add new clients for _, session := range sessions { p.AddProxyClient(session) } diff --git a/internal/util/proxyutil/proxy_client_manager_test.go b/internal/util/proxyutil/proxy_client_manager_test.go index 60033dfaf5..382b6c5d91 100644 --- a/internal/util/proxyutil/proxy_client_manager_test.go +++ b/internal/util/proxyutil/proxy_client_manager_test.go @@ -103,22 +103,38 @@ func (p *proxyMock) RefreshPolicyInfoCache(ctx context.Context, req *proxypb.Ref return merr.Success(), nil } -func TestProxyClientManager_AddProxyClients(t *testing.T) { - proxyCreator := func(ctx context.Context, addr string, nodeID int64) (types.ProxyClient, error) { - return nil, errors.New("failed") - } +func TestProxyClientManager_SetProxyClients(t *testing.T) { + p1 := mocks.NewMockProxyClient(t) + p1.EXPECT().Close().Return(nil).Once() + p2 := mocks.NewMockProxyClient(t) + p3 := mocks.NewMockProxyClient(t) + proxyCreator := func(ctx context.Context, addr string, nodeID int64) (types.ProxyClient, error) { + return p3, nil + } pcm := NewProxyClientManager(proxyCreator) - session := &sessionutil.Session{ - SessionRaw: sessionutil.SessionRaw{ - ServerID: 100, - Address: "localhost", - }, - } + // Initial state: proxy 1, 2 + pcm.proxyClient.Insert(1, p1) + pcm.proxyClient.Insert(2, p2) + assert.Equal(t, 2, pcm.GetProxyCount()) - sessions := []*sessionutil.Session{session} - pcm.AddProxyClients(sessions) + // New snapshot: proxy 2, 3 + sessions := []*sessionutil.Session{ + {SessionRaw: sessionutil.SessionRaw{ServerID: 2, Address: "addr2"}}, + {SessionRaw: sessionutil.SessionRaw{ServerID: 4, Address: "addr4"}}, + } + pcm.SetProxyClients(sessions) + + // Verify: proxy 1 removed, proxy 2 kept, proxy 3 added + _, ok := pcm.proxyClient.Get(1) + assert.False(t, ok, "stale proxy 1 should be removed") + _, ok = pcm.proxyClient.Get(2) + assert.True(t, ok, "proxy 2 should still exist") + _, ok = pcm.proxyClient.Get(4) + assert.True(t, ok, "proxy 4 should be added") + + assert.Equal(t, 2, pcm.GetProxyCount()) } func TestProxyClientManager_AddProxyClient(t *testing.T) { diff --git a/internal/util/sessionutil/mock_session.go b/internal/util/sessionutil/mock_session.go index a4d47020ec..5af9c2c6be 100644 --- a/internal/util/sessionutil/mock_session.go +++ b/internal/util/sessionutil/mock_session.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.32.4. DO NOT EDIT. +// Code generated by mockery v2.53.3. DO NOT EDIT. package sessionutil @@ -24,10 +24,14 @@ func (_m *MockSession) EXPECT() *MockSession_Expecter { return &MockSession_Expecter{mock: &_m.Mock} } -// Disconnected provides a mock function with given fields: +// Disconnected provides a mock function with no fields func (_m *MockSession) Disconnected() bool { ret := _m.Called() + if len(ret) == 0 { + panic("no return value specified for Disconnected") + } + var r0 bool if rf, ok := ret.Get(0).(func() bool); ok { r0 = rf() @@ -69,6 +73,10 @@ func (_c *MockSession_Disconnected_Call) RunAndReturn(run func() bool) *MockSess func (_m *MockSession) ForceActiveStandby(activateFunc func() error) error { ret := _m.Called(activateFunc) + if len(ret) == 0 { + panic("no return value specified for ForceActiveStandby") + } + var r0 error if rf, ok := ret.Get(0).(func(func() error) error); ok { r0 = rf(activateFunc) @@ -107,10 +115,14 @@ func (_c *MockSession_ForceActiveStandby_Call) RunAndReturn(run func(func() erro return _c } -// GetAddress provides a mock function with given fields: +// GetAddress provides a mock function with no fields func (_m *MockSession) GetAddress() string { ret := _m.Called() + if len(ret) == 0 { + panic("no return value specified for GetAddress") + } + var r0 string if rf, ok := ret.Get(0).(func() string); ok { r0 = rf() @@ -148,10 +160,14 @@ func (_c *MockSession_GetAddress_Call) RunAndReturn(run func() string) *MockSess return _c } -// GetServerID provides a mock function with given fields: +// GetServerID provides a mock function with no fields func (_m *MockSession) GetServerID() int64 { ret := _m.Called() + if len(ret) == 0 { + panic("no return value specified for GetServerID") + } + var r0 int64 if rf, ok := ret.Get(0).(func() int64); ok { r0 = rf() @@ -193,6 +209,10 @@ func (_c *MockSession_GetServerID_Call) RunAndReturn(run func() int64) *MockSess func (_m *MockSession) GetSessions(prefix string) (map[string]*Session, int64, error) { ret := _m.Called(prefix) + if len(ret) == 0 { + panic("no return value specified for GetSessions") + } + var r0 map[string]*Session var r1 int64 var r2 error @@ -254,6 +274,10 @@ func (_c *MockSession_GetSessions_Call) RunAndReturn(run func(string) (map[strin func (_m *MockSession) GetSessionsWithVersionRange(prefix string, r semver.Range) (map[string]*Session, int64, error) { ret := _m.Called(prefix, r) + if len(ret) == 0 { + panic("no return value specified for GetSessionsWithVersionRange") + } + var r0 map[string]*Session var r1 int64 var r2 error @@ -312,10 +336,14 @@ func (_c *MockSession_GetSessionsWithVersionRange_Call) RunAndReturn(run func(st return _c } -// GoingStop provides a mock function with given fields: +// GoingStop provides a mock function with no fields func (_m *MockSession) GoingStop() error { ret := _m.Called() + if len(ret) == 0 { + panic("no return value specified for GoingStop") + } + var r0 error if rf, ok := ret.Get(0).(func() error); ok { r0 = rf() @@ -385,14 +413,18 @@ func (_c *MockSession_Init_Call) Return() *MockSession_Init_Call { } func (_c *MockSession_Init_Call) RunAndReturn(run func(string, string, bool, bool)) *MockSession_Init_Call { - _c.Call.Return(run) + _c.Run(run) return _c } -// IsTriggerKill provides a mock function with given fields: +// IsTriggerKill provides a mock function with no fields func (_m *MockSession) IsTriggerKill() bool { ret := _m.Called() + if len(ret) == 0 { + panic("no return value specified for IsTriggerKill") + } + var r0 bool if rf, ok := ret.Get(0).(func() bool); ok { r0 = rf() @@ -460,14 +492,18 @@ func (_c *MockSession_LivenessCheck_Call) Return() *MockSession_LivenessCheck_Ca } func (_c *MockSession_LivenessCheck_Call) RunAndReturn(run func(context.Context, func())) *MockSession_LivenessCheck_Call { - _c.Call.Return(run) + _c.Run(run) return _c } -// MarshalJSON provides a mock function with given fields: +// MarshalJSON provides a mock function with no fields func (_m *MockSession) MarshalJSON() ([]byte, error) { ret := _m.Called() + if len(ret) == 0 { + panic("no return value specified for MarshalJSON") + } + var r0 []byte var r1 error if rf, ok := ret.Get(0).(func() ([]byte, error)); ok { @@ -521,6 +557,10 @@ func (_c *MockSession_MarshalJSON_Call) RunAndReturn(run func() ([]byte, error)) func (_m *MockSession) ProcessActiveStandBy(activateFunc func() error) error { ret := _m.Called(activateFunc) + if len(ret) == 0 { + panic("no return value specified for ProcessActiveStandBy") + } + var r0 error if rf, ok := ret.Get(0).(func(func() error) error); ok { r0 = rf(activateFunc) @@ -559,7 +599,7 @@ func (_c *MockSession_ProcessActiveStandBy_Call) RunAndReturn(run func(func() er return _c } -// Register provides a mock function with given fields: +// Register provides a mock function with no fields func (_m *MockSession) Register() { _m.Called() } @@ -587,14 +627,18 @@ func (_c *MockSession_Register_Call) Return() *MockSession_Register_Call { } func (_c *MockSession_Register_Call) RunAndReturn(run func()) *MockSession_Register_Call { - _c.Call.Return(run) + _c.Run(run) return _c } -// Registered provides a mock function with given fields: +// Registered provides a mock function with no fields func (_m *MockSession) Registered() bool { ret := _m.Called() + if len(ret) == 0 { + panic("no return value specified for Registered") + } + var r0 bool if rf, ok := ret.Get(0).(func() bool); ok { r0 = rf() @@ -661,7 +705,7 @@ func (_c *MockSession_Revoke_Call) Return() *MockSession_Revoke_Call { } func (_c *MockSession_Revoke_Call) RunAndReturn(run func(time.Duration)) *MockSession_Revoke_Call { - _c.Call.Return(run) + _c.Run(run) return _c } @@ -694,7 +738,7 @@ func (_c *MockSession_SetDisconnected_Call) Return() *MockSession_SetDisconnecte } func (_c *MockSession_SetDisconnected_Call) RunAndReturn(run func(bool)) *MockSession_SetDisconnected_Call { - _c.Call.Return(run) + _c.Run(run) return _c } @@ -727,11 +771,11 @@ func (_c *MockSession_SetEnableActiveStandBy_Call) Return() *MockSession_SetEnab } func (_c *MockSession_SetEnableActiveStandBy_Call) RunAndReturn(run func(bool)) *MockSession_SetEnableActiveStandBy_Call { - _c.Call.Return(run) + _c.Run(run) return _c } -// Stop provides a mock function with given fields: +// Stop provides a mock function with no fields func (_m *MockSession) Stop() { _m.Called() } @@ -759,14 +803,18 @@ func (_c *MockSession_Stop_Call) Return() *MockSession_Stop_Call { } func (_c *MockSession_Stop_Call) RunAndReturn(run func()) *MockSession_Stop_Call { - _c.Call.Return(run) + _c.Run(run) return _c } -// String provides a mock function with given fields: +// String provides a mock function with no fields func (_m *MockSession) String() string { ret := _m.Called() + if len(ret) == 0 { + panic("no return value specified for String") + } + var r0 string if rf, ok := ret.Get(0).(func() string); ok { r0 = rf() @@ -808,6 +856,10 @@ func (_c *MockSession_String_Call) RunAndReturn(run func() string) *MockSession_ func (_m *MockSession) UnmarshalJSON(data []byte) error { ret := _m.Called(data) + if len(ret) == 0 { + panic("no return value specified for UnmarshalJSON") + } + var r0 error if rf, ok := ret.Get(0).(func([]byte) error); ok { r0 = rf(data) @@ -875,7 +927,7 @@ func (_c *MockSession_UpdateRegistered_Call) Return() *MockSession_UpdateRegiste } func (_c *MockSession_UpdateRegistered_Call) RunAndReturn(run func(bool)) *MockSession_UpdateRegistered_Call { - _c.Call.Return(run) + _c.Run(run) return _c } @@ -883,6 +935,10 @@ func (_c *MockSession_UpdateRegistered_Call) RunAndReturn(run func(bool)) *MockS func (_m *MockSession) WatchServices(prefix string, revision int64, rewatch Rewatch) <-chan *SessionEvent { ret := _m.Called(prefix, revision, rewatch) + if len(ret) == 0 { + panic("no return value specified for WatchServices") + } + var r0 <-chan *SessionEvent if rf, ok := ret.Get(0).(func(string, int64, Rewatch) <-chan *SessionEvent); ok { r0 = rf(prefix, revision, rewatch) @@ -929,6 +985,10 @@ func (_c *MockSession_WatchServices_Call) RunAndReturn(run func(string, int64, R func (_m *MockSession) WatchServicesWithVersionRange(prefix string, r semver.Range, revision int64, rewatch Rewatch) <-chan *SessionEvent { ret := _m.Called(prefix, r, revision, rewatch) + if len(ret) == 0 { + panic("no return value specified for WatchServicesWithVersionRange") + } + var r0 <-chan *SessionEvent if rf, ok := ret.Get(0).(func(string, semver.Range, int64, Rewatch) <-chan *SessionEvent); ok { r0 = rf(prefix, r, revision, rewatch)