diff --git a/cmd/tools/migration/migration/runner.go b/cmd/tools/migration/migration/runner.go index 15a86a3319..08b82cb543 100644 --- a/cmd/tools/migration/migration/runner.go +++ b/cmd/tools/migration/migration/runner.go @@ -46,7 +46,7 @@ func NewRunner(ctx context.Context, cfg *configs.Config) *Runner { func (r *Runner) watchByPrefix(prefix string) { defer r.wg.Done() - _, revision, err := r.session.GetSessions(prefix) + _, revision, err := r.session.GetSessions(r.ctx, prefix) fn := func() { r.Stop() } console.AbnormalExitIf(err, r.backupFinished.Load(), console.AddCallbacks(fn)) watcher := r.session.WatchServices(prefix, revision, nil) @@ -128,7 +128,7 @@ func (r *Runner) CheckCompatible() bool { } func (r *Runner) checkSessionsWithPrefix(prefix string) error { - sessions, _, err := r.session.GetSessions(prefix) + sessions, _, err := r.session.GetSessions(r.ctx, prefix) if err != nil { return err } @@ -139,7 +139,7 @@ func (r *Runner) checkSessionsWithPrefix(prefix string) error { } func (r *Runner) checkMySelf() error { - sessions, _, err := r.session.GetSessions(Role) + sessions, _, err := r.session.GetSessions(r.ctx, Role) if err != nil { return err } diff --git a/internal/datacoord/server.go b/internal/datacoord/server.go index 535c62f3a7..ce60ba4411 100644 --- a/internal/datacoord/server.go +++ b/internal/datacoord/server.go @@ -235,7 +235,7 @@ func (s *Server) Register() error { } func (s *Server) ServerExist(serverID int64) bool { - sessions, _, err := s.session.GetSessions(typeutil.DataNodeRole) + sessions, _, err := s.session.GetSessions(s.ctx, typeutil.DataNodeRole) if err != nil { log.Ctx(s.ctx).Warn("failed to get sessions", zap.Error(err)) return false @@ -547,7 +547,7 @@ func (s *Server) initServiceDiscovery() error { } s.indexEngineVersionManager = newIndexEngineVersionManager() - qnSessions, qnRevision, err := s.session.GetSessions(typeutil.QueryNodeRole) + qnSessions, qnRevision, err := s.session.GetSessions(s.ctx, typeutil.QueryNodeRole) if err != nil { log.Warn("DataCoord get QueryNode sessions failed", zap.Error(err)) return err diff --git a/internal/distributed/connection_manager.go b/internal/distributed/connection_manager.go index 0730b498e7..05e5973c38 100644 --- a/internal/distributed/connection_manager.go +++ b/internal/distributed/connection_manager.go @@ -105,7 +105,7 @@ func (cm *ConnectionManager) AddDependency(roleName string) error { } cm.dependencies[roleName] = struct{}{} - msess, rev, err := cm.session.GetSessions(roleName) + msess, rev, err := cm.session.GetSessions(context.TODO(), roleName) if err != nil { log.Debug("ClientManager GetSessions failed", zap.String("roleName", roleName)) return err diff --git a/internal/distributed/datanode/client/client.go b/internal/distributed/datanode/client/client.go index bf667f2d69..bd7a80b0f6 100644 --- a/internal/distributed/datanode/client/client.go +++ b/internal/distributed/datanode/client/client.go @@ -50,10 +50,8 @@ type DataNodeClient struct { // Client is the grpc client for DataNode type Client struct { grpcClient grpcclient.GrpcClient[DataNodeClient] - sess *sessionutil.Session addr string serverID int64 - ctx context.Context } // NewClient creates a client for DataNode. @@ -61,7 +59,7 @@ func NewClient(ctx context.Context, addr string, serverID int64, encryption bool if addr == "" { return nil, errors.New("address is empty") } - sess := sessionutil.NewSession(ctx) + sess := sessionutil.NewSession(context.Background()) if sess == nil { err := errors.New("new session error, maybe can not connect to etcd") log.Ctx(ctx).Debug("DataNodeClient New Etcd Session failed", zap.Error(err)) @@ -72,9 +70,7 @@ func NewClient(ctx context.Context, addr string, serverID int64, encryption bool client := &Client{ addr: addr, grpcClient: grpcclient.NewClientBase[DataNodeClient](config, "milvus.proto.data.DataNode"), - sess: sess, serverID: serverID, - ctx: ctx, } // node shall specify node id client.grpcClient.SetRole(fmt.Sprintf("%s-%d", typeutil.DataNodeRole, serverID)) diff --git a/internal/distributed/mixcoord/client/client.go b/internal/distributed/mixcoord/client/client.go index a97baa2d3f..f713291b83 100644 --- a/internal/distributed/mixcoord/client/client.go +++ b/internal/distributed/mixcoord/client/client.go @@ -68,7 +68,7 @@ type Client struct { // etcdEndpoints are the address list for etcd end points // timeout is default setting for each grpc call func NewClient(ctx context.Context) (types.MixCoordClient, error) { - sess := sessionutil.NewSession(ctx) + sess := sessionutil.NewSession(context.Background()) if sess == nil { err := errors.New("new session error, maybe can not connect to etcd") log.Ctx(ctx).Debug("New MixCoord Client failed", zap.Error(err)) @@ -110,7 +110,7 @@ func (c *Client) newGrpcClient(cc *grpc.ClientConn) MixCoordClient { func (c *Client) getMixCoordAddr() (string, error) { log := log.Ctx(c.ctx) key := c.grpcClient.GetRole() - msess, _, err := c.sess.GetSessions(key) + msess, _, err := c.sess.GetSessions(c.ctx, key) if err != nil { log.Debug("MixCoordClient GetSessions failed", zap.Any("key", key)) return "", err @@ -135,7 +135,7 @@ func (c *Client) getMixCoordAddr() (string, error) { // compatible with standalone mode upgrade from 2.5, shoule be removed in 3.0 func (c *Client) getCompatibleMixCoordAddr() (string, error) { log := log.Ctx(c.ctx) - msess, _, err := c.sess.GetSessions(typeutil.RootCoordRole) + msess, _, err := c.sess.GetSessions(c.ctx, typeutil.RootCoordRole) if err != nil { log.Debug("mixCoordClient getSessions failed", zap.Any("key", typeutil.RootCoordRole), zap.Error(err)) return "", errors.New("find no available mixcoord, check mixcoord state") diff --git a/internal/distributed/proxy/client/client.go b/internal/distributed/proxy/client/client.go index 5f8a4f064b..85aeeb0ffb 100644 --- a/internal/distributed/proxy/client/client.go +++ b/internal/distributed/proxy/client/client.go @@ -45,7 +45,6 @@ var Params *paramtable.ComponentParam = paramtable.Get() type Client struct { grpcClient grpcclient.GrpcClient[proxypb.ProxyClient] addr string - sess *sessionutil.Session } // NewClient creates a new client instance @@ -53,7 +52,7 @@ func NewClient(ctx context.Context, addr string, nodeID int64) (types.ProxyClien if addr == "" { return nil, errors.New("address is empty") } - sess := sessionutil.NewSession(ctx) + sess := sessionutil.NewSession(context.Background()) if sess == nil { err := errors.New("new session error, maybe can not connect to etcd") log.Ctx(ctx).Debug("Proxy client new session failed", zap.Error(err)) @@ -63,7 +62,6 @@ func NewClient(ctx context.Context, addr string, nodeID int64) (types.ProxyClien client := &Client{ addr: addr, grpcClient: grpcclient.NewClientBase[proxypb.ProxyClient](config, "milvus.proto.proxy.Proxy"), - sess: sess, } // node shall specify node id client.grpcClient.SetRole(fmt.Sprintf("%s-%d", typeutil.ProxyRole, nodeID)) diff --git a/internal/distributed/querynode/client/client.go b/internal/distributed/querynode/client/client.go index f05314b00c..2f1e43d99f 100644 --- a/internal/distributed/querynode/client/client.go +++ b/internal/distributed/querynode/client/client.go @@ -45,9 +45,7 @@ var Params *paramtable.ComponentParam = paramtable.Get() type Client struct { grpcClient grpcclient.GrpcClient[querypb.QueryNodeClient] addr string - sess *sessionutil.Session nodeID int64 - ctx context.Context } // NewClient creates a new QueryNode client. @@ -55,7 +53,7 @@ func NewClient(ctx context.Context, addr string, nodeID int64) (types.QueryNodeC if addr == "" { return nil, errors.New("addr is empty") } - sess := sessionutil.NewSession(ctx) + sess := sessionutil.NewSession(context.Background()) if sess == nil { err := errors.New("new session error, maybe can not connect to etcd") log.Ctx(ctx).Debug("QueryNodeClient NewClient failed", zap.Error(err)) @@ -65,9 +63,7 @@ func NewClient(ctx context.Context, addr string, nodeID int64) (types.QueryNodeC client := &Client{ addr: addr, grpcClient: grpcclient.NewClientBase[querypb.QueryNodeClient](config, "milvus.proto.query.QueryNode"), - sess: sess, nodeID: nodeID, - ctx: ctx, } // node shall specify node id client.grpcClient.SetRole(fmt.Sprintf("%s-%d", typeutil.QueryNodeRole, nodeID)) diff --git a/internal/querycoordv2/server.go b/internal/querycoordv2/server.go index 37da8dc0fb..37ba453cbd 100644 --- a/internal/querycoordv2/server.go +++ b/internal/querycoordv2/server.go @@ -166,7 +166,7 @@ func (s *Server) SetSession(session sessionutil.SessionInterface) error { } func (s *Server) ServerExist(serverID int64) bool { - sessions, _, err := s.session.GetSessions(typeutil.QueryNodeRole) + sessions, _, err := s.session.GetSessions(s.ctx, typeutil.QueryNodeRole) if err != nil { log.Ctx(s.ctx).Warn("failed to get sessions", zap.Error(err)) return false @@ -486,7 +486,7 @@ func (s *Server) Start() error { func (s *Server) startQueryCoord() error { log.Ctx(s.ctx).Info("start watcher...") - sessions, revision, err := s.session.GetSessions(typeutil.QueryNodeRole) + sessions, revision, err := s.session.GetSessions(s.ctx, typeutil.QueryNodeRole) if err != nil { return err } diff --git a/internal/querynodev2/server.go b/internal/querynodev2/server.go index 47d8f4eb17..052ab3a20b 100644 --- a/internal/querynodev2/server.go +++ b/internal/querynodev2/server.go @@ -327,7 +327,7 @@ func (node *QueryNode) Init() error { return NewLocalWorker(node), nil } - sessions, _, err := node.session.GetSessions(typeutil.QueryNodeRole) + sessions, _, err := node.session.GetSessions(node.ctx, typeutil.QueryNodeRole) if err != nil { return nil, err } diff --git a/internal/util/grpcclient/client.go b/internal/util/grpcclient/client.go index 61b28317b2..b4dafadfdf 100644 --- a/internal/util/grpcclient/client.go +++ b/internal/util/grpcclient/client.go @@ -379,7 +379,7 @@ func (c *ClientBase[T]) verifySession(ctx context.Context) error { } c.lastSessionCheck.Store(time.Now()) if c.sess != nil { - sessions, _, getSessionErr := c.sess.GetSessions(c.GetRole()) + sessions, _, getSessionErr := c.sess.GetSessions(ctx, c.GetRole()) if getSessionErr != nil { // Only log but not handle this error as it is an auxiliary logic log.Warn("fail to get session", zap.Error(getSessionErr)) diff --git a/internal/util/grpcclient/client_test.go b/internal/util/grpcclient/client_test.go index 1b51c09217..6c08f98eed 100644 --- a/internal/util/grpcclient/client_test.go +++ b/internal/util/grpcclient/client_test.go @@ -107,7 +107,7 @@ func TestClientBase_NodeSessionNotExist(t *testing.T) { }) base.role = typeutil.QueryNodeRole mockSession := sessionutil.NewMockSession(t) - mockSession.EXPECT().GetSessions(mock.Anything).Return(nil, 0, nil) + mockSession.EXPECT().GetSessions(mock.Anything, mock.Anything).Return(nil, 0, nil) base.sess = mockSession base.grpcClientMtx.Lock() base.grpcClient = nil @@ -551,7 +551,7 @@ func TestVerifySession(t *testing.T) { base := ClientBase[*mockClient]{} mockSession := sessionutil.NewMockSession(t) expectedErr := errors.New("mocked") - mockSession.EXPECT().GetSessions(mock.Anything).Return(nil, 0, expectedErr) + mockSession.EXPECT().GetSessions(mock.Anything, mock.Anything).Return(nil, 0, expectedErr) base.sess = mockSession ctx := context.Background() @@ -562,7 +562,7 @@ func TestVerifySession(t *testing.T) { base.NodeID = *atomic.NewInt64(1) base.role = typeutil.RootCoordRole mockSession2 := sessionutil.NewMockSession(t) - mockSession2.EXPECT().GetSessions(mock.Anything).Return( + mockSession2.EXPECT().GetSessions(mock.Anything, mock.Anything).Return( map[string]*sessionutil.Session{ typeutil.RootCoordRole: { SessionRaw: sessionutil.SessionRaw{ diff --git a/internal/util/sessionutil/mock_session.go b/internal/util/sessionutil/mock_session.go index 3014ca5de8..adbe11a2d3 100644 --- a/internal/util/sessionutil/mock_session.go +++ b/internal/util/sessionutil/mock_session.go @@ -3,10 +3,10 @@ package sessionutil import ( + context "context" + semver "github.com/blang/semver/v4" mock "github.com/stretchr/testify/mock" - - time "time" ) // MockSession is an autogenerated mock type for the SessionInterface type @@ -157,9 +157,9 @@ func (_c *MockSession_GetServerID_Call) RunAndReturn(run func() int64) *MockSess return _c } -// GetSessions provides a mock function with given fields: prefix -func (_m *MockSession) GetSessions(prefix string) (map[string]*Session, int64, error) { - ret := _m.Called(prefix) +// GetSessions provides a mock function with given fields: ctx, prefix +func (_m *MockSession) GetSessions(ctx context.Context, prefix string) (map[string]*Session, int64, error) { + ret := _m.Called(ctx, prefix) if len(ret) == 0 { panic("no return value specified for GetSessions") @@ -168,25 +168,25 @@ func (_m *MockSession) GetSessions(prefix string) (map[string]*Session, int64, e var r0 map[string]*Session var r1 int64 var r2 error - if rf, ok := ret.Get(0).(func(string) (map[string]*Session, int64, error)); ok { - return rf(prefix) + if rf, ok := ret.Get(0).(func(context.Context, string) (map[string]*Session, int64, error)); ok { + return rf(ctx, prefix) } - if rf, ok := ret.Get(0).(func(string) map[string]*Session); ok { - r0 = rf(prefix) + if rf, ok := ret.Get(0).(func(context.Context, string) map[string]*Session); ok { + r0 = rf(ctx, prefix) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(map[string]*Session) } } - if rf, ok := ret.Get(1).(func(string) int64); ok { - r1 = rf(prefix) + if rf, ok := ret.Get(1).(func(context.Context, string) int64); ok { + r1 = rf(ctx, prefix) } else { r1 = ret.Get(1).(int64) } - if rf, ok := ret.Get(2).(func(string) error); ok { - r2 = rf(prefix) + if rf, ok := ret.Get(2).(func(context.Context, string) error); ok { + r2 = rf(ctx, prefix) } else { r2 = ret.Error(2) } @@ -200,14 +200,15 @@ type MockSession_GetSessions_Call struct { } // GetSessions is a helper method to define mock.On call +// - ctx context.Context // - prefix string -func (_e *MockSession_Expecter) GetSessions(prefix interface{}) *MockSession_GetSessions_Call { - return &MockSession_GetSessions_Call{Call: _e.mock.On("GetSessions", prefix)} +func (_e *MockSession_Expecter) GetSessions(ctx interface{}, prefix interface{}) *MockSession_GetSessions_Call { + return &MockSession_GetSessions_Call{Call: _e.mock.On("GetSessions", ctx, prefix)} } -func (_c *MockSession_GetSessions_Call) Run(run func(prefix string)) *MockSession_GetSessions_Call { +func (_c *MockSession_GetSessions_Call) Run(run func(ctx context.Context, prefix string)) *MockSession_GetSessions_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(string)) + run(args[0].(context.Context), args[1].(string)) }) return _c } @@ -217,7 +218,7 @@ func (_c *MockSession_GetSessions_Call) Return(_a0 map[string]*Session, _a1 int6 return _c } -func (_c *MockSession_GetSessions_Call) RunAndReturn(run func(string) (map[string]*Session, int64, error)) *MockSession_GetSessions_Call { +func (_c *MockSession_GetSessions_Call) RunAndReturn(run func(context.Context, string) (map[string]*Session, int64, error)) *MockSession_GetSessions_Call { _c.Call.Return(run) return _c } @@ -594,39 +595,6 @@ func (_c *MockSession_Registered_Call) RunAndReturn(run func() bool) *MockSessio return _c } -// Revoke provides a mock function with given fields: timeout -func (_m *MockSession) Revoke(timeout time.Duration) { - _m.Called(timeout) -} - -// MockSession_Revoke_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Revoke' -type MockSession_Revoke_Call struct { - *mock.Call -} - -// Revoke is a helper method to define mock.On call -// - timeout time.Duration -func (_e *MockSession_Expecter) Revoke(timeout interface{}) *MockSession_Revoke_Call { - return &MockSession_Revoke_Call{Call: _e.mock.On("Revoke", timeout)} -} - -func (_c *MockSession_Revoke_Call) Run(run func(timeout time.Duration)) *MockSession_Revoke_Call { - _c.Call.Run(func(args mock.Arguments) { - run(args[0].(time.Duration)) - }) - return _c -} - -func (_c *MockSession_Revoke_Call) Return() *MockSession_Revoke_Call { - _c.Call.Return() - return _c -} - -func (_c *MockSession_Revoke_Call) RunAndReturn(run func(time.Duration)) *MockSession_Revoke_Call { - _c.Run(run) - return _c -} - // SetDisconnected provides a mock function with given fields: b func (_m *MockSession) SetDisconnected(b bool) { _m.Called(b) diff --git a/internal/util/sessionutil/session.go b/internal/util/sessionutil/session.go index eba0590a6c..cdd0fd1865 100644 --- a/internal/util/sessionutil/session.go +++ b/internal/util/sessionutil/session.go @@ -16,6 +16,8 @@ package sessionutil import ( + "context" + "github.com/blang/semver/v4" ) @@ -27,7 +29,7 @@ type SessionInterface interface { String() string Register() - GetSessions(prefix string) (map[string]*Session, int64, error) + GetSessions(ctx context.Context, prefix string) (map[string]*Session, int64, error) GetSessionsWithVersionRange(prefix string, r semver.Range) (map[string]*Session, int64, error) GoingStop() error diff --git a/internal/util/sessionutil/session_util.go b/internal/util/sessionutil/session_util.go index 674c10bc64..48970f2999 100644 --- a/internal/util/sessionutil/session_util.go +++ b/internal/util/sessionutil/session_util.go @@ -612,10 +612,10 @@ func (s *Session) startKeepAliveLoop() { // GetSessions will get all sessions registered in etcd. // Revision is returned for WatchServices to prevent key events from being missed. -func (s *Session) GetSessions(prefix string) (map[string]*Session, int64, error) { +func (s *Session) GetSessions(ctx context.Context, prefix string) (map[string]*Session, int64, error) { res := make(map[string]*Session) key := path.Join(s.metaRoot, DefaultServiceRoot, prefix) - resp, err := s.etcdCli.Get(s.ctx, key, clientv3.WithPrefix(), + resp, err := s.etcdCli.Get(ctx, key, clientv3.WithPrefix(), clientv3.WithSort(clientv3.SortByKey, clientv3.SortAscend)) if err != nil { return nil, 0, err @@ -868,7 +868,7 @@ func (w *sessionWatcher) handleWatchErr(err error) error { return err } - sessions, revision, err := w.s.GetSessions(w.prefix) + sessions, revision, err := w.s.GetSessions(w.s.ctx, w.prefix) if err != nil { log.Warn("GetSession before rewatch failed", zap.String("prefix", w.prefix), zap.Error(err)) w.closeEventCh() @@ -963,7 +963,7 @@ func (s *Session) ProcessActiveStandBy(activateFunc func() error) error { registerActiveFn := func() (bool, int64, error) { for _, role := range oldRoles { - sessions, _, err := s.GetSessions(role) + sessions, _, err := s.GetSessions(s.ctx, role) if err != nil { log.Debug("failed to get old sessions", zap.String("role", role), zap.Error(err)) continue diff --git a/internal/util/sessionutil/session_util_test.go b/internal/util/sessionutil/session_util_test.go index 3254be4ef3..27a0856548 100644 --- a/internal/util/sessionutil/session_util_test.go +++ b/internal/util/sessionutil/session_util_test.go @@ -88,7 +88,7 @@ func TestInit(t *testing.T) { assert.NotEqual(t, int64(0), s.LeaseID) assert.NotEqual(t, int64(0), s.ServerID) s.Register() - sessions, _, err := s.GetSessions("inittest") + sessions, _, err := s.GetSessions(ctx, "inittest") assert.NoError(t, err) assert.Contains(t, sessions, "inittest-"+strconv.FormatInt(s.ServerID, 10)) } @@ -111,7 +111,7 @@ func TestInitNoArgs(t *testing.T) { assert.NotEqual(t, int64(0), s.LeaseID) assert.NotEqual(t, int64(0), s.ServerID) s.Register() - sessions, _, err := s.GetSessions("inittest") + sessions, _, err := s.GetSessions(ctx, "inittest") assert.NoError(t, err) assert.Contains(t, sessions, "inittest-"+strconv.FormatInt(s.ServerID, 10)) } @@ -131,7 +131,7 @@ func TestUpdateSessions(t *testing.T) { s := NewSessionWithEtcd(ctx, metaRoot, etcdCli, WithResueNodeID(false)) - sessions, rev, err := s.GetSessions("test") + sessions, rev, err := s.GetSessions(ctx, "test") assert.NoError(t, err) assert.Equal(t, len(sessions), 0) watcher := s.WatchServices("test", rev, nil) @@ -155,15 +155,15 @@ func TestUpdateSessions(t *testing.T) { wg.Wait() assert.Eventually(t, func() bool { - sessions, _, _ := s.GetSessions("test") + sessions, _, _ := s.GetSessions(ctx, "test") return len(sessions) == 10 }, 10*time.Second, 100*time.Millisecond) - notExistSessions, _, _ := s.GetSessions("testt") + notExistSessions, _, _ := s.GetSessions(ctx, "testt") assert.Equal(t, len(notExistSessions), 0) etcdKV.RemoveWithPrefix(ctx, metaRoot) assert.Eventually(t, func() bool { - sessions, _, _ := s.GetSessions("test") + sessions, _, _ := s.GetSessions(ctx, "test") return len(sessions) == 0 }, 10*time.Second, 100*time.Millisecond)