diff --git a/cmd/tools/migration/migration/runner.go b/cmd/tools/migration/migration/runner.go index 89defd88de..15a86a3319 100644 --- a/cmd/tools/migration/migration/runner.go +++ b/cmd/tools/migration/migration/runner.go @@ -165,7 +165,6 @@ func (r *Runner) CheckSessions() error { func (r *Runner) RegisterSession() error { r.session.Register() - r.session.LivenessCheck(r.ctx, func() {}) return nil } @@ -246,7 +245,7 @@ func (r *Runner) waitUntilSessionExpired() { } func (r *Runner) Stop() { - r.session.Revoke(time.Second) + r.session.Stop() r.waitUntilSessionExpired() r.cancel() r.wg.Wait() diff --git a/internal/coordinator/mix_coord.go b/internal/coordinator/mix_coord.go index 1dbdca2f2d..742a1caac5 100644 --- a/internal/coordinator/mix_coord.go +++ b/internal/coordinator/mix_coord.go @@ -114,10 +114,6 @@ func (s *mixCoordImpl) Register() error { afterRegister := func() { metrics.NumNodes.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), typeutil.MixCoordRole).Inc() log.Info("MixCoord Register Finished") - s.session.LivenessCheck(s.ctx, func() { - log.Error("MixCoord disconnected from etcd, process will exit", zap.Int64("serverID", s.session.GetServerID())) - os.Exit(1) - }) } if s.enableActiveStandBy { go func() { diff --git a/internal/datanode/data_node.go b/internal/datanode/data_node.go index f37a21ea8d..810f20510f 100644 --- a/internal/datanode/data_node.go +++ b/internal/datanode/data_node.go @@ -24,7 +24,6 @@ import ( "fmt" "io" "math/rand" - "os" "sync" "time" @@ -157,12 +156,6 @@ func (node *DataNode) Register() error { metrics.NumNodes.WithLabelValues(fmt.Sprint(node.GetNodeID()), typeutil.DataNodeRole).Inc() log.Info("DataNode Register Finished") - // Start liveness check - node.session.LivenessCheck(node.ctx, func() { - log.Error("Data Node disconnected from etcd, process will exit", zap.Int64("Server Id", node.GetSession().ServerID)) - os.Exit(1) - }) - return nil } diff --git a/internal/distributed/streamingnode/service.go b/internal/distributed/streamingnode/service.go index 22807c193a..b0a5e9718c 100644 --- a/internal/distributed/streamingnode/service.go +++ b/internal/distributed/streamingnode/service.go @@ -18,7 +18,6 @@ package streamingnode import ( "context" - "os" "strconv" "sync" "time" @@ -251,7 +250,7 @@ func (s *Server) start() (err error) { return errors.Wrap(err, "StreamingNode start gRPC server fail") } // Register current server to etcd. - s.registerSessionToETCD() + s.session.Register() s.componentState.OnInitialized(s.session.ServerID) return nil @@ -382,13 +381,3 @@ func (s *Server) startGPRCServer(ctx context.Context) error { funcutil.CheckGrpcReady(ctx, errCh) return <-errCh } - -// registerSessionToETCD registers current server to etcd. 
-func (s *Server) registerSessionToETCD() { - s.session.Register() - // start liveness check - s.session.LivenessCheck(context.Background(), func() { - log.Ctx(s.ctx).Error("StreamingNode disconnected from etcd, process will exit", zap.Int64("Server Id", paramtable.GetNodeID())) - os.Exit(1) - }) -} diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index 09407c531c..e3a00774ea 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -20,7 +20,6 @@ import ( "context" "fmt" "math/rand" - "os" "sync" "time" @@ -157,10 +156,6 @@ func (node *Proxy) Register() error { node.session.Register() metrics.NumNodes.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), typeutil.ProxyRole).Inc() log.Info("Proxy Register Finished") - node.session.LivenessCheck(node.ctx, func() { - log.Error("Proxy disconnected from etcd, process will exit", zap.Int64("Server Id", node.session.ServerID)) - os.Exit(1) - }) // TODO Reset the logger // Params.initLogCfg() return nil diff --git a/internal/querycoordv2/mocks/querynode.go b/internal/querycoordv2/mocks/querynode.go index e38ff0e4fb..0cb339bdb7 100644 --- a/internal/querycoordv2/mocks/querynode.go +++ b/internal/querycoordv2/mocks/querynode.go @@ -21,7 +21,6 @@ import ( "net" "sync" "testing" - "time" "github.com/stretchr/testify/mock" clientv3 "go.etcd.io/etcd/client/v3" @@ -149,7 +148,7 @@ func (node *MockQueryNode) Stopping() { func (node *MockQueryNode) Stop() { node.cancel() node.server.GracefulStop() - node.session.Revoke(time.Second) + node.session.Stop() } func (node *MockQueryNode) getAllChannels() []*querypb.ChannelVersionInfo { diff --git a/internal/querynodev2/server.go b/internal/querynodev2/server.go index f1a986d508..e76fafad81 100644 --- a/internal/querynodev2/server.go +++ b/internal/querynodev2/server.go @@ -32,7 +32,6 @@ import "C" import ( "context" "fmt" - "os" "plugin" "strings" "sync" @@ -182,10 +181,6 @@ func (node *QueryNode) Register() error { node.session.Register() // start liveness check metrics.NumNodes.WithLabelValues(fmt.Sprint(node.GetNodeID()), typeutil.QueryNodeRole).Inc() - node.session.LivenessCheck(node.ctx, func() { - log.Ctx(node.ctx).Error("Query Node disconnected from etcd, process will exit", zap.Int64("Server Id", paramtable.GetNodeID())) - os.Exit(1) - }) return nil } diff --git a/internal/util/sessionutil/mock_session.go b/internal/util/sessionutil/mock_session.go index d060f766ba..3014ca5de8 100644 --- a/internal/util/sessionutil/mock_session.go +++ b/internal/util/sessionutil/mock_session.go @@ -3,8 +3,6 @@ package sessionutil import ( - context "context" - semver "github.com/blang/semver/v4" mock "github.com/stretchr/testify/mock" @@ -416,40 +414,6 @@ func (_c *MockSession_IsTriggerKill_Call) RunAndReturn(run func() bool) *MockSes return _c } -// LivenessCheck provides a mock function with given fields: ctx, callback -func (_m *MockSession) LivenessCheck(ctx context.Context, callback func()) { - _m.Called(ctx, callback) -} - -// MockSession_LivenessCheck_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'LivenessCheck' -type MockSession_LivenessCheck_Call struct { - *mock.Call -} - -// LivenessCheck is a helper method to define mock.On call -// - ctx context.Context -// - callback func() -func (_e *MockSession_Expecter) LivenessCheck(ctx interface{}, callback interface{}) *MockSession_LivenessCheck_Call { - return &MockSession_LivenessCheck_Call{Call: _e.mock.On("LivenessCheck", ctx, callback)} -} - -func (_c *MockSession_LivenessCheck_Call) Run(run 
func(ctx context.Context, callback func())) *MockSession_LivenessCheck_Call {
-	_c.Call.Run(func(args mock.Arguments) {
-		run(args[0].(context.Context), args[1].(func()))
-	})
-	return _c
-}
-
-func (_c *MockSession_LivenessCheck_Call) Return() *MockSession_LivenessCheck_Call {
-	_c.Call.Return()
-	return _c
-}
-
-func (_c *MockSession_LivenessCheck_Call) RunAndReturn(run func(context.Context, func())) *MockSession_LivenessCheck_Call {
-	_c.Run(run)
-	return _c
-}
-
 // MarshalJSON provides a mock function with no fields
 func (_m *MockSession) MarshalJSON() ([]byte, error) {
 	ret := _m.Called()
diff --git a/internal/util/sessionutil/session.go b/internal/util/sessionutil/session.go
index 12c23aab38..eba0590a6c 100644
--- a/internal/util/sessionutil/session.go
+++ b/internal/util/sessionutil/session.go
@@ -16,9 +16,6 @@ package sessionutil
 import (
-	"context"
-	"time"
-
 	"github.com/blang/semver/v4"
 )
@@ -36,9 +33,7 @@ type SessionInterface interface {
 	GoingStop() error
 	WatchServices(prefix string, revision int64, rewatch Rewatch) (watcher SessionWatcher)
 	WatchServicesWithVersionRange(prefix string, r semver.Range, revision int64, rewatch Rewatch) (watcher SessionWatcher)
-	LivenessCheck(ctx context.Context, callback func())
 	Stop()
-	Revoke(timeout time.Duration)
 	UpdateRegistered(b bool)
 	Registered() bool
 	SetDisconnected(b bool)
diff --git a/internal/util/sessionutil/session_util.go b/internal/util/sessionutil/session_util.go
index 21f2c5554a..674c10bc64 100644
--- a/internal/util/sessionutil/session_util.go
+++ b/internal/util/sessionutil/session_util.go
@@ -28,6 +28,7 @@ import (
 	"time"
 
 	"github.com/blang/semver/v4"
+	"github.com/cenkalti/backoff/v4"
 	"github.com/cockroachdb/errors"
 	"go.etcd.io/etcd/api/v3/mvccpb"
 	v3rpc "go.etcd.io/etcd/api/v3/v3rpc/rpctypes"
@@ -55,6 +56,7 @@ const (
 	LabelStreamingNodeEmbeddedQueryNode = "QUERYNODE_STREAMING-EMBEDDED"
 	LabelStandalone                     = "STANDALONE"
 	MilvusNodeIDForTesting              = "MILVUS_NODE_ID_FOR_TESTING"
+	exitCodeSessionLeaseExpired         = 1
 )
 
 // EnableEmbededQueryNodeLabel set server labels for embedded query node.
@@ -141,21 +143,19 @@ func (s *SessionRaw) IsTriggerKill() bool {
 // Session is a struct to store service's session, including ServerID, ServerName,
 // Address.
 // Exclusive indicates that this server can only start one.
+// TODO: it's a bad implementation to mix up the service registration and service discovery in one struct,
+// because the registration is used by the server side, but the discovery is used by the client side.
+// we should split the service registration and the service discovery.
type Session struct { - ctx context.Context - // When outside context done, Session cancels its goroutines first, then uses - // keepAliveCancel to cancel the etcd KeepAlive - keepAliveMu sync.Mutex - keepAliveCancel context.CancelFunc - keepAliveCtx context.Context + log.Binder + + ctx context.Context + cancel context.CancelFunc SessionRaw Version semver.Version `json:"Version,omitempty"` - liveChOnce sync.Once - liveCh chan struct{} - etcdCli *clientv3.Client watchSessionKeyCh clientv3.WatchChan watchCancel atomic.Pointer[context.CancelFunc] @@ -173,8 +173,6 @@ type Session struct { sessionTTL int64 sessionRetryTimes int64 reuseNodeID bool - - isStopped atomic.Bool // set to true if stop method is invoked } type SessionOption func(session *Session) @@ -258,8 +256,11 @@ func NewSessionWithEtcd(ctx context.Context, metaRoot string, client *clientv3.C log.Ctx(ctx).Error("get host name fail", zap.Error(hostNameErr)) } + ctx, cancel := context.WithCancel(ctx) session := &Session{ - ctx: ctx, + ctx: ctx, + cancel: cancel, + metaRoot: metaRoot, Version: common.Version, @@ -271,7 +272,6 @@ func NewSessionWithEtcd(ctx context.Context, metaRoot string, client *clientv3.C sessionTTL: paramtable.Get().CommonCfg.SessionTTL.GetAsInt64(), sessionRetryTimes: paramtable.Get().CommonCfg.SessionRetryTimes.GetAsInt64(), reuseNodeID: true, - isStopped: *atomic.NewBool(false), } // integration test create cluster with different nodeId in one process @@ -283,10 +283,6 @@ func NewSessionWithEtcd(ctx context.Context, metaRoot string, client *clientv3.C session.UpdateRegistered(false) session.etcdCli = client - log.Ctx(ctx).Info("Successfully connected to etcd for session", - zap.String("metaRoot", metaRoot), - zap.String("hostName", hostName), - ) return session } @@ -304,7 +300,13 @@ func (s *Session) Init(serverName, address string, exclusive bool, triggerKill b } s.ServerID = serverID s.ServerLabels = GetServerLabelsFromEnv(serverName) - log.Info("start server", zap.String("name", serverName), zap.String("address", address), zap.Int64("id", s.ServerID), zap.Any("server_labels", s.ServerLabels)) + + s.SetLogger(log.With( + log.FieldComponent("service-registration"), + zap.String("role", serverName), + zap.Int64("serverID", s.ServerID), + zap.String("address", address), + )) } // String makes Session struct able to be logged by zap @@ -314,14 +316,13 @@ func (s *Session) String() string { // Register will process keepAliveResponse to keep alive with etcd. 
func (s *Session) Register() {
-	ch, err := s.registerService()
+	err := s.registerService()
 	if err != nil {
-		log.Error("Register failed", zap.Error(err))
+		s.Logger().Error("register failed", zap.Error(err))
 		panic(err)
 	}
-	s.liveCh = make(chan struct{})
-	s.startKeepAliveLoop(ch)
 	s.UpdateRegistered(true)
+	s.startKeepAliveLoop()
 }
 
 var serverIDMu sync.Mutex
@@ -425,39 +426,6 @@ func (s *Session) getCompleteKey() string {
 	return path.Join(s.metaRoot, DefaultServiceRoot, key)
 }
 
-func (s *Session) getSessionKey() string {
-	key := s.ServerName
-	if !s.Exclusive {
-		key = fmt.Sprintf("%s-%d", key, s.ServerID)
-	}
-	return path.Join(s.metaRoot, DefaultServiceRoot, key)
-}
-
-func (s *Session) initWatchSessionCh(ctx context.Context) error {
-	var (
-		err     error
-		getResp *clientv3.GetResponse
-	)
-
-	ctx, cancel := context.WithCancel(ctx)
-	if old := s.watchCancel.Load(); old != nil {
-		(*old)()
-	}
-	s.watchCancel.Store(&cancel)
-
-	err = retry.Do(ctx, func() error {
-		getResp, err = s.etcdCli.Get(ctx, s.getSessionKey())
-		return err
-	}, retry.Attempts(uint(s.sessionRetryTimes)))
-	if err != nil {
-		log.Warn("fail to get the session key from the etcd", zap.Error(err))
-		cancel()
-		return err
-	}
-	s.watchSessionKeyCh = s.etcdCli.Watch(ctx, s.getSessionKey(), clientv3.WithRev(getResp.Header.Revision))
-	return nil
-}
-
 // registerService registers the service to etcd so that other services
 // can find that the service is online and issue subsequent operations
 // RegisterService will save a key-value in etcd
@@ -473,26 +441,24 @@ func (s *Session) initWatchSessionCh(ctx context.Context) error {
 //
 // Exclusive means whether this service can exist two at the same time, if so,
 // it is false. Otherwise, set it to true.
-func (s *Session) registerService() (<-chan *clientv3.LeaseKeepAliveResponse, error) {
+func (s *Session) registerService() error {
 	if s.enableActiveStandBy {
 		s.updateStandby(true)
 	}
 	completeKey := s.getCompleteKey()
-	var ch <-chan *clientv3.LeaseKeepAliveResponse
-	log := log.Ctx(s.ctx)
-	log.Debug("service begin to register to etcd", zap.String("serverName", s.ServerName), zap.Int64("ServerID", s.ServerID))
+	s.Logger().Info("service begins to register to etcd")
 
	registerFn := func() error {
 		resp, err := s.etcdCli.Grant(s.ctx, s.sessionTTL)
 		if err != nil {
-			log.Error("register service: failed to grant lease from etcd", zap.Error(err))
+			s.Logger().Error("register service: failed to grant lease from etcd", zap.Error(err))
 			return err
 		}
 		s.LeaseID = &resp.ID
 
 		sessionJSON, err := json.Marshal(s)
 		if err != nil {
-			log.Error("register service: failed to marshal session", zap.Error(err))
+			s.Logger().Error("register service: failed to marshal session", zap.Error(err))
 			return err
 		}
 
@@ -503,31 +469,17 @@ func (s *Session) registerService() (<-chan *clientv3.LeaseKeepAliveResponse, er
 			0)).
Then(clientv3.OpPut(completeKey, string(sessionJSON), clientv3.WithLease(resp.ID))).Commit() if err != nil { - log.Warn("register on etcd error, check the availability of etcd", zap.Error(err)) + s.Logger().Warn("register on etcd error, check the availability of etcd", zap.Error(err)) return err } if txnResp != nil && !txnResp.Succeeded { s.handleRestart(completeKey) return fmt.Errorf("function CompareAndSwap error for compare is false for key: %s", s.ServerName) } - log.Info("put session key into etcd", zap.String("key", completeKey), zap.String("value", string(sessionJSON))) - - ctx, cancel := context.WithCancel(s.ctx) - ch, err = s.etcdCli.KeepAlive(ctx, resp.ID) - if err != nil { - log.Warn("failed to keep alive with etcd", zap.Int64("lease ID", int64(resp.ID)), zap.Error(err)) - cancel() - return err - } - s.setKeepAliveContext(ctx, cancel) - log.Info("Service registered successfully", zap.String("ServerName", s.ServerName), zap.Int64("serverID", s.ServerID)) + s.Logger().Info("put session key into etcd, service registered successfully", zap.String("key", completeKey), zap.String("value", string(sessionJSON))) return nil } - err := retry.Do(s.ctx, registerFn, retry.Attempts(uint(s.sessionRetryTimes))) - if err != nil { - return nil, err - } - return ch, nil + return retry.Do(s.ctx, registerFn, retry.Attempts(uint(s.sessionRetryTimes))) } // Handle restart is fast path to handle node restart. @@ -562,108 +514,100 @@ func (s *Session) handleRestart(key string) { // processKeepAliveResponse processes the response of etcd keepAlive interface // If keepAlive fails for unexpected error, it will send a signal to the channel. -func (s *Session) processKeepAliveResponse(ch <-chan *clientv3.LeaseKeepAliveResponse) { - defer s.wg.Done() - for { - select { - case <-s.ctx.Done(): - log.Warn("session context canceled, stop keepalive") - s.cancelKeepAlive(true) - return - - case resp, ok := <-ch: - if !ok || resp == nil { - log.Warn("keepalive channel closed", - zap.Bool("stopped", s.isStopped.Load()), - zap.String("serverName", s.ServerName)) - - if s.isStopped.Load() { - s.safeCloseLiveCh() - return - } - - s.cancelKeepAlive(false) - - // this is just to make sure etcd is alived, and the lease can be keep alived ASAP - _, err := s.etcdCli.KeepAliveOnce(s.ctx, *s.LeaseID) - if err != nil { - log.Info("failed to keep alive", zap.String("serverName", s.ServerName), zap.Error(err)) - s.safeCloseLiveCh() - return - } - - var ( - newCh <-chan *clientv3.LeaseKeepAliveResponse - newCtx context.Context - newCancel context.CancelFunc - ) - - err = retry.Do(s.ctx, func() error { - ctx, cancel := context.WithCancel(s.ctx) - ch, err := s.etcdCli.KeepAlive(ctx, *s.LeaseID) - if err != nil { - cancel() - log.Warn("failed to keep alive with etcd", zap.Error(err)) - return err - } - newCh = ch - newCtx = ctx - newCancel = cancel - return nil - }, retry.Attempts(3)) - if err != nil { - log.Warn("failed to retry keepAlive", - zap.String("serverName", s.ServerName), - zap.Error(err)) - s.safeCloseLiveCh() - return - } - log.Info("retry keep alive success", zap.String("serverName", s.ServerName)) - s.setKeepAliveContext(newCtx, newCancel) - ch = newCh - continue - } - } - } -} - -func (s *Session) startKeepAliveLoop(ch <-chan *clientv3.LeaseKeepAliveResponse) { - s.wg.Add(1) - go s.processKeepAliveResponse(ch) -} - -func (s *Session) setKeepAliveContext(ctx context.Context, cancel context.CancelFunc) { - s.keepAliveMu.Lock() - s.keepAliveCtx = ctx - s.keepAliveCancel = cancel - s.keepAliveMu.Unlock() -} - -// 
cancelKeepAlive cancels the keepAlive context and sets a flag to control whether keepAlive retry is allowed.
-func (s *Session) cancelKeepAlive(isStop bool) {
-	var cancel context.CancelFunc
-	s.keepAliveMu.Lock()
+func (s *Session) processKeepAliveResponse() {
 	defer func() {
-		s.keepAliveMu.Unlock()
-		if cancel != nil {
-			cancel()
+		s.Logger().Info("keep alive loop exited, trying to revoke lease right away...")
+		// here the s.ctx may be already done, so we use context.Background() with a timeout to revoke the lease.
+		ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+		defer cancel()
+		if _, err := s.etcdCli.Revoke(ctx, *s.LeaseID); err != nil {
+			s.Logger().Error("failed to revoke lease", zap.Error(err), zap.Int64("leaseID", int64(*s.LeaseID)))
+		} else {
+			s.Logger().Info("lease revoked successfully", zap.Int64("leaseID", int64(*s.LeaseID)))
 		}
+		s.wg.Done()
 	}()
 
-	// only process the first time
-	if s.isStopped.Load() {
-		return
-	}
+	backoff := backoff.NewExponentialBackOff()
+	backoff.InitialInterval = 10 * time.Millisecond
+	backoff.MaxInterval = 100 * time.Second
+	backoff.MaxElapsedTime = 0
+	backoff.Reset()
 
-	// Add a variable to signal whether keepAlive retry is allowed.
-	// If isDone is true, disable keepAlive retry.
-	if isStop {
-		s.isStopped.Store(true)
-	}
+	var ch <-chan *clientv3.LeaseKeepAliveResponse
+	var lastErr error
+	nextKeepaliveInstant := time.Now().Add(time.Duration(s.sessionTTL) * time.Second)
 
-	cancel = s.keepAliveCancel
-	s.keepAliveCtx = nil
-	s.keepAliveCancel = nil
+	for {
+		if s.ctx.Err() != nil {
+			return
+		}
+		if lastErr != nil {
+			nextBackoffInterval := backoff.NextBackOff()
+			s.Logger().Warn("failed to start keep alive, waiting for retry...", zap.Error(lastErr), zap.Duration("nextBackoffInterval", nextBackoffInterval))
+			select {
+			case <-time.After(nextBackoffInterval):
+			case <-s.ctx.Done():
+				return
+			}
+		}
+
+		if ch == nil {
+			if err := s.checkKeepaliveTTL(nextKeepaliveInstant); err != nil {
+				lastErr = err
+				continue
+			}
+			newCH, err := s.etcdCli.KeepAlive(s.ctx, *s.LeaseID)
+			if err != nil {
+				s.Logger().Error("failed to keep alive with etcd", zap.Error(err))
+				lastErr = errors.Wrap(err, "failed to keep alive")
+				continue
+			}
+			s.Logger().Info("keep alive...", zap.Int64("leaseID", int64(*s.LeaseID)))
+			ch = newCH
+		}
+
+		// Block until the keep alive channel is closed.
+		for range ch {
+		}
+
+		// the keep alive channel may be closed because of a network error,
+		// so reset the state and retry the keep alive.
+		ch = nil
+		nextKeepaliveInstant = time.Now().Add(time.Duration(s.sessionTTL) * time.Second)
+		lastErr = nil
+		backoff.Reset()
+	}
+}
+
+// checkKeepaliveTTL checks the TTL of the lease; it exits the process if the lease is not found or already expired, and returns an error for other failures.
+func (s *Session) checkKeepaliveTTL(nextKeepaliveInstant time.Time) error {
+	errSessionExpiredAtClientSide := errors.New("session expired at client side")
+	ctx, cancel := context.WithDeadlineCause(s.ctx, nextKeepaliveInstant, errSessionExpiredAtClientSide)
+	defer cancel()
+
+	ttlResp, err := s.etcdCli.TimeToLive(ctx, *s.LeaseID)
+	if err != nil {
+		if errors.Is(err, v3rpc.ErrLeaseNotFound) {
+			s.Logger().Error("confirmed the lease is not found, the session expired without actively closing", zap.Error(err))
+			os.Exit(exitCodeSessionLeaseExpired)
+		}
+		if ctx.Err() != nil && errors.Is(context.Cause(ctx), errSessionExpiredAtClientSide) {
+			s.Logger().Error("session expired at client side without actively closing", zap.Error(err))
+			os.Exit(exitCodeSessionLeaseExpired)
+		}
+		return errors.Wrap(err, "failed to check TTL")
+	}
+	if ttlResp.TTL <= 0 {
+		s.Logger().Error("confirmed the lease is expired, the session expired without actively closing", zap.Int64("ttl", ttlResp.TTL))
+		os.Exit(exitCodeSessionLeaseExpired)
+	}
+	s.Logger().Info("check TTL success, trying to keep alive...", zap.Int64("ttl", ttlResp.TTL))
+	return nil
+}
+
+func (s *Session) startKeepAliveLoop() {
+	s.wg.Add(1)
+	go s.processKeepAliveResponse()
 }
 
 // GetSessions will get all sessions registered in etcd.
@@ -734,7 +678,7 @@ func (s *Session) GoingStop() error {
 	completeKey := s.getCompleteKey()
 	resp, err := s.etcdCli.Get(s.ctx, completeKey, clientv3.WithCountOnly())
 	if err != nil {
-		log.Error("fail to get the session", zap.String("key", completeKey), zap.Error(err))
+		s.Logger().Error("fail to get the session", zap.String("key", completeKey), zap.Error(err))
 		return err
 	}
 	if resp.Count == 0 {
@@ -743,12 +687,12 @@ func (s *Session) GoingStop() error {
 	s.Stopping = true
 	sessionJSON, err := json.Marshal(s)
 	if err != nil {
-		log.Error("fail to marshal the session", zap.String("key", completeKey))
+		s.Logger().Error("fail to marshal the session", zap.String("key", completeKey))
 		return err
 	}
 	_, err = s.etcdCli.Put(s.ctx, completeKey, string(sessionJSON), clientv3.WithLease(*s.LeaseID))
 	if err != nil {
-		log.Error("fail to update the session to stopping state", zap.String("key", completeKey))
+		s.Logger().Error("fail to update the session to stopping state", zap.String("key", completeKey))
 		return err
 	}
 	return nil
@@ -950,121 +894,14 @@ func (w *sessionWatcher) EventChannel() <-chan *SessionEvent {
 	return w.eventCh
 }
 
-// LivenessCheck performs liveness check with provided context and channel
-// ctx controls the liveness check loop
-// ch is the liveness signal channel, ch is closed only when the session is expired
-// callback must be called before liveness check exit, to close the session's owner component
-func (s *Session) LivenessCheck(ctx context.Context, callback func()) {
-	err := s.initWatchSessionCh(ctx)
-	if err != nil {
-		log.Error("failed to get session for liveness check", zap.Error(err))
-		s.cancelKeepAlive(true)
-		if callback != nil {
-			go callback()
-		}
-		return
-	}
-
-	s.wg.Add(1)
-	go func() {
-		defer s.wg.Done()
-		if callback != nil {
-			// before exit liveness check, callback to exit the session owner
-			defer func() {
-				// the callback method will not be invoked if session is stopped.
- if ctx.Err() == nil && !s.isStopped.Load() { - go callback() - } - }() - } - defer s.SetDisconnected(true) - for { - select { - case _, ok := <-s.liveCh: - // ok, still alive - if ok { - continue - } - // not ok, connection lost - log.Warn("connection lost detected, shuting down") - return - case <-ctx.Done(): - log.Warn("liveness exits due to context done") - // cancel the etcd keepAlive context - s.cancelKeepAlive(true) - return - case resp, ok := <-s.watchSessionKeyCh: - if !ok { - log.Warn("watch session key channel closed") - s.cancelKeepAlive(true) - return - } - if resp.Err() != nil { - // if not ErrCompacted, just close the channel - if resp.Err() != v3rpc.ErrCompacted { - // close event channel - log.Warn("Watch service found error", zap.Error(resp.Err())) - s.cancelKeepAlive(true) - return - } - log.Warn("Watch service found compacted error", zap.Error(resp.Err())) - err := s.initWatchSessionCh(ctx) - if err != nil { - log.Warn("failed to get session during reconnecting", zap.Error(err)) - s.cancelKeepAlive(true) - } - continue - } - for _, event := range resp.Events { - switch event.Type { - case mvccpb.PUT: - log.Info("register session success", zap.String("role", s.ServerName), zap.String("key", string(event.Kv.Key))) - case mvccpb.DELETE: - log.Info("session key is deleted, exit...", zap.String("role", s.ServerName), zap.String("key", string(event.Kv.Key))) - s.cancelKeepAlive(true) - } - } - } - } - }() -} - func (s *Session) Stop() { log.Info("session stopping", zap.String("serverName", s.ServerName)) - s.Revoke(time.Second) - s.cancelKeepAlive(true) + if s.cancel != nil { + s.cancel() + } s.wg.Wait() } -// Revoke revokes the internal LeaseID for the session key -func (s *Session) Revoke(timeout time.Duration) { - if s == nil { - return - } - log.Info("start to revoke session", zap.String("sessionKey", s.activeKey)) - if s.etcdCli == nil || s.LeaseID == nil { - log.Warn("skip remove session", - zap.String("sessionKey", s.activeKey), - zap.Bool("etcdCliIsNil", s.etcdCli == nil), - zap.Bool("LeaseIDIsNil", s.LeaseID == nil), - ) - return - } - if s.Disconnected() { - log.Warn("skip remove session, connection is disconnected", zap.String("sessionKey", s.activeKey)) - return - } - // can NOT use s.ctx, it may be Done here - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() - // ignores resp & error, just do best effort to revoke - _, err := s.etcdCli.Revoke(ctx, *s.LeaseID) - if err != nil { - log.Warn("failed to revoke session", zap.String("sessionKey", s.activeKey), zap.Error(err)) - } - log.Info("revoke session successfully", zap.String("sessionKey", s.activeKey)) -} - // UpdateRegistered update the state of registered. func (s *Session) UpdateRegistered(b bool) { s.registered.Store(b) @@ -1099,15 +936,6 @@ func (s *Session) updateStandby(b bool) { s.isStandby.Store(b) } -func (s *Session) safeCloseLiveCh() { - s.liveChOnce.Do(func() { - close(s.liveCh) - if s.watchCancel.Load() != nil { - (*s.watchCancel.Load())() - } - }) -} - // ProcessActiveStandBy is used by coordinators to do active-standby mechanism. // coordinator enabled active-standby will first call Register and then call ProcessActiveStandBy. 
// steps: diff --git a/internal/util/sessionutil/session_util_test.go b/internal/util/sessionutil/session_util_test.go index ea0cc9392c..3254be4ef3 100644 --- a/internal/util/sessionutil/session_util_test.go +++ b/internal/util/sessionutil/session_util_test.go @@ -4,8 +4,8 @@ import ( "context" "fmt" "math/rand" - "net/url" "os" + "os/exec" "path" "strconv" "strings" @@ -20,13 +20,11 @@ import ( "github.com/stretchr/testify/suite" "go.etcd.io/etcd/api/v3/mvccpb" clientv3 "go.etcd.io/etcd/client/v3" - "go.etcd.io/etcd/server/v3/embed" - "go.etcd.io/etcd/server/v3/etcdserver/api/v3client" - "go.uber.org/atomic" "go.uber.org/zap" "github.com/milvus-io/milvus/internal/json" etcdkv "github.com/milvus-io/milvus/internal/kv/etcd" + kvfactory "github.com/milvus-io/milvus/internal/util/dependency/kv" "github.com/milvus-io/milvus/pkg/v2/common" "github.com/milvus-io/milvus/pkg/v2/log" "github.com/milvus-io/milvus/pkg/v2/util/etcd" @@ -38,20 +36,14 @@ import ( func TestGetServerIDConcurrently(t *testing.T) { ctx := context.Background() paramtable.Init() - params := paramtable.Get() - endpoints := params.EtcdCfg.Endpoints.GetValue() metaRoot := fmt.Sprintf("%d/%s", rand.Int(), DefaultServiceRoot) - etcdEndpoints := strings.Split(endpoints, ",") - etcdCli, err := etcd.GetRemoteEtcdClient(etcdEndpoints) - require.NoError(t, err) - defer etcdCli.Close() + etcdCli, _ := kvfactory.GetEtcdAndPath() etcdKV := etcdkv.NewEtcdKV(etcdCli, metaRoot) - err = etcdKV.RemoveWithPrefix(ctx, "") + err := etcdKV.RemoveWithPrefix(ctx, "") assert.NoError(t, err) - defer etcdKV.Close() defer etcdKV.RemoveWithPrefix(ctx, "") var wg sync.WaitGroup @@ -81,19 +73,14 @@ func TestGetServerIDConcurrently(t *testing.T) { func TestInit(t *testing.T) { ctx := context.Background() paramtable.Init() - params := paramtable.Get() - endpoints := params.EtcdCfg.Endpoints.GetValue() metaRoot := fmt.Sprintf("%d/%s", rand.Int(), DefaultServiceRoot) - etcdEndpoints := strings.Split(endpoints, ",") - etcdCli, err := etcd.GetRemoteEtcdClient(etcdEndpoints) - require.NoError(t, err) + etcdCli, _ := kvfactory.GetEtcdAndPath() etcdKV := etcdkv.NewEtcdKV(etcdCli, metaRoot) - err = etcdKV.RemoveWithPrefix(ctx, "") + err := etcdKV.RemoveWithPrefix(ctx, "") assert.NoError(t, err) - defer etcdKV.Close() defer etcdKV.RemoveWithPrefix(ctx, "") s := NewSessionWithEtcd(ctx, metaRoot, etcdCli) @@ -109,19 +96,14 @@ func TestInit(t *testing.T) { func TestInitNoArgs(t *testing.T) { ctx := context.Background() paramtable.Init() - params := paramtable.Get() - endpoints := params.EtcdCfg.Endpoints.GetValue() metaRoot := fmt.Sprintf("%d/%s", rand.Int(), DefaultServiceRoot) - etcdEndpoints := strings.Split(endpoints, ",") - etcdCli, err := etcd.GetRemoteEtcdClient(etcdEndpoints) - require.NoError(t, err) + etcdCli, _ := kvfactory.GetEtcdAndPath() etcdKV := etcdkv.NewEtcdKV(etcdCli, metaRoot) - err = etcdKV.RemoveWithPrefix(ctx, "") + err := etcdKV.RemoveWithPrefix(ctx, "") assert.NoError(t, err) - defer etcdKV.Close() defer etcdKV.RemoveWithPrefix(ctx, "") s := NewSession(ctx) @@ -137,17 +119,11 @@ func TestInitNoArgs(t *testing.T) { func TestUpdateSessions(t *testing.T) { ctx := context.Background() paramtable.Init() - params := paramtable.Get() - endpoints := params.EtcdCfg.Endpoints.GetValue() - etcdEndpoints := strings.Split(endpoints, ",") metaRoot := fmt.Sprintf("%d/%s", rand.Int(), DefaultServiceRoot) - etcdCli, err := etcd.GetRemoteEtcdClient(etcdEndpoints) - require.NoError(t, err) - defer etcdCli.Close() + etcdCli, _ := kvfactory.GetEtcdAndPath() etcdKV 
:= etcdkv.NewEtcdKV(etcdCli, "") - defer etcdKV.Close() defer etcdKV.RemoveWithPrefix(ctx, "") var wg sync.WaitGroup @@ -163,8 +139,6 @@ func TestUpdateSessions(t *testing.T) { sList := []*Session{} getIDFunc := func() { - etcdCli, err := etcd.GetRemoteEtcdClient(etcdEndpoints) - require.NoError(t, err) singleS := NewSessionWithEtcd(ctx, metaRoot, etcdCli, WithResueNodeID(false)) singleS.Init("test", "testAddr", false, false) singleS.Register() @@ -221,86 +195,17 @@ LOOP: assert.Equal(t, delEventLen, 10) } -func TestSessionLivenessCheck(t *testing.T) { - paramtable.Init() - params := paramtable.Get() - - endpoints := params.EtcdCfg.Endpoints.GetValue() - metaRoot := fmt.Sprintf("%d/%s", rand.Int(), DefaultServiceRoot) - - etcdEndpoints := strings.Split(endpoints, ",") - etcdCli, err := etcd.GetRemoteEtcdClient(etcdEndpoints) - require.NoError(t, err) - s := NewSessionWithEtcd(context.Background(), metaRoot, etcdCli) - s.Register() - ch := make(chan struct{}) - s.liveCh = ch - signal := make(chan struct{}, 1) - - flag := atomic.NewBool(false) - s.LivenessCheck(context.Background(), func() { - flag.Store(true) - signal <- struct{}{} - }) - assert.False(t, flag.Load()) - - // test liveCh receive event, liveness won't exit, callback won't trigger - ch <- struct{}{} - assert.False(t, flag.Load()) - - // test close liveCh, liveness exit, callback should trigger - close(ch) - <-signal - assert.True(t, flag.Load()) - - // test context done, liveness exit, callback shouldn't trigger - metaRoot = fmt.Sprintf("%d/%s", rand.Int(), DefaultServiceRoot) - s1 := NewSessionWithEtcd(context.Background(), metaRoot, etcdCli) - s1.Register() - ctx, cancel := context.WithCancel(context.Background()) - flag.Store(false) - - s1.LivenessCheck(ctx, func() { - flag.Store(true) - signal <- struct{}{} - }) - cancel() - assert.False(t, flag.Load()) - - // test context done, liveness start failed, callback should trigger - metaRoot = fmt.Sprintf("%d/%s", rand.Int(), DefaultServiceRoot) - s2 := NewSessionWithEtcd(context.Background(), metaRoot, etcdCli) - s2.Register() - ctx, cancel = context.WithCancel(context.Background()) - signal = make(chan struct{}, 1) - flag.Store(false) - cancel() - s2.LivenessCheck(ctx, func() { - flag.Store(true) - signal <- struct{}{} - }) - <-signal - assert.True(t, flag.Load()) -} - func TestWatcherHandleWatchResp(t *testing.T) { ctx := context.Background() paramtable.Init() - params := paramtable.Get() - endpoints := params.EtcdCfg.Endpoints.GetValue() - etcdEndpoints := strings.Split(endpoints, ",") metaRoot := fmt.Sprintf("%d/%s", rand.Int(), DefaultServiceRoot) - - etcdCli, err := etcd.GetRemoteEtcdClient(etcdEndpoints) - require.NoError(t, err) - defer etcdCli.Close() + etcdCli, _ := kvfactory.GetEtcdAndPath() etcdKV := etcdkv.NewEtcdKV(etcdCli, "/by-dev/session-ut") - defer etcdKV.Close() defer etcdKV.RemoveWithPrefix(ctx, "/by-dev/session-ut") s := NewSessionWithEtcd(ctx, metaRoot, etcdCli) - defer s.Revoke(time.Second) + defer s.Stop() getWatcher := func(s *Session, rewatch Rewatch) *sessionWatcher { return &sessionWatcher{ @@ -405,21 +310,6 @@ func TestWatcherHandleWatchResp(t *testing.T) { w.handleWatchResponse(wresp) }) }) - - t.Run("err handled but list failed", func(t *testing.T) { - s := NewSessionWithEtcd(ctx, "/by-dev/session-ut", etcdCli) - s.etcdCli.Close() - w := getWatcher(s, func(sessions map[string]*Session) error { - return nil - }) - wresp := clientv3.WatchResponse{ - CompactRevision: 1, - } - - assert.Panics(t, func() { - w.handleWatchResponse(wresp) - }) - }) } func 
TestSession_Registered(t *testing.T) { @@ -473,8 +363,7 @@ func TestSessionUnmarshal(t *testing.T) { type SessionWithVersionSuite struct { suite.Suite - tmpDir string - etcdServer *embed.Etcd + tmpDir string metaRoot string serverName string @@ -482,63 +371,31 @@ type SessionWithVersionSuite struct { client *clientv3.Client } -// SetupSuite setup suite env func (suite *SessionWithVersionSuite) SetupSuite() { - dir, err := os.MkdirTemp(os.TempDir(), "milvus_ut") - suite.Require().NoError(err) - suite.tmpDir = dir - suite.T().Log("using tmp dir:", dir) - - config := embed.NewConfig() - - config.Dir = os.TempDir() - config.LogLevel = "warn" - config.LogOutputs = []string{"default"} - u, err := url.Parse("http://localhost:0") - suite.Require().NoError(err) - - config.ListenClientUrls = []url.URL{*u} - u, err = url.Parse("http://localhost:0") - suite.Require().NoError(err) - config.ListenPeerUrls = []url.URL{*u} - - etcdServer, err := embed.StartEtcd(config) - suite.Require().NoError(err) - suite.etcdServer = etcdServer -} - -func (suite *SessionWithVersionSuite) TearDownSuite() { - if suite.etcdServer != nil { - suite.etcdServer.Close() - } - if suite.tmpDir != "" { - os.RemoveAll(suite.tmpDir) - } + client, _ := kvfactory.GetEtcdAndPath() + suite.client = client } func (suite *SessionWithVersionSuite) SetupTest() { - client := v3client.New(suite.etcdServer.Server) - suite.client = client - ctx := context.Background() suite.metaRoot = "sessionWithVersion" suite.serverName = "sessionComp" - s1 := NewSessionWithEtcd(ctx, suite.metaRoot, client, WithResueNodeID(false)) + s1 := NewSessionWithEtcd(ctx, suite.metaRoot, suite.client, WithResueNodeID(false)) s1.Version.Major, s1.Version.Minor, s1.Version.Patch = 0, 0, 0 s1.Init(suite.serverName, "s1", false, false) s1.Register() suite.sessions = append(suite.sessions, s1) - s2 := NewSessionWithEtcd(ctx, suite.metaRoot, client, WithResueNodeID(false)) + s2 := NewSessionWithEtcd(ctx, suite.metaRoot, suite.client, WithResueNodeID(false)) s2.Version.Major, s2.Version.Minor, s2.Version.Patch = 2, 1, 0 s2.Init(suite.serverName, "s2", false, false) s2.Register() suite.sessions = append(suite.sessions, s2) - s3 := NewSessionWithEtcd(ctx, suite.metaRoot, client, WithResueNodeID(false)) + s3 := NewSessionWithEtcd(ctx, suite.metaRoot, suite.client, WithResueNodeID(false)) s3.Version.Major, s3.Version.Minor, s3.Version.Patch = 2, 2, 0 s3.Version.Build = []string{"dev"} s3.Init(suite.serverName, "s3", false, false) @@ -549,17 +406,13 @@ func (suite *SessionWithVersionSuite) SetupTest() { func (suite *SessionWithVersionSuite) TearDownTest() { for _, s := range suite.sessions { - s.Revoke(time.Second) + s.Stop() } suite.sessions = nil - _, err := suite.client.Delete(context.Background(), suite.metaRoot, clientv3.WithPrefix()) + client, _ := kvfactory.GetEtcdAndPath() + _, err := client.Delete(context.Background(), suite.metaRoot, clientv3.WithPrefix()) suite.Require().NoError(err) - - if suite.client != nil { - suite.client.Close() - suite.client = nil - } } func (suite *SessionWithVersionSuite) TestGetSessionsWithRangeVersion() { @@ -621,7 +474,7 @@ func (suite *SessionWithVersionSuite) TestWatchServicesWithVersionRange() { // remove all sessions go func() { for _, s := range suite.sessions { - s.Revoke(time.Second) + s.Stop() } }() @@ -642,22 +495,16 @@ func TestSessionProcessActiveStandBy(t *testing.T) { ctx := context.TODO() // initial etcd paramtable.Init() - params := paramtable.Get() - endpoints := params.EtcdCfg.Endpoints.GetValue() metaRoot := 
fmt.Sprintf("%d/%s1", rand.Int(), DefaultServiceRoot) - etcdEndpoints := strings.Split(endpoints, ",") - etcdCli, err := etcd.GetRemoteEtcdClient(etcdEndpoints) - require.NoError(t, err) + etcdCli, _ := kvfactory.GetEtcdAndPath() etcdKV := etcdkv.NewEtcdKV(etcdCli, metaRoot) - err = etcdKV.RemoveWithPrefix(ctx, "") + err := etcdKV.RemoveWithPrefix(ctx, "") assert.NoError(t, err) - defer etcdKV.Close() defer etcdKV.RemoveWithPrefix(ctx, "") var wg sync.WaitGroup - signal := make(chan struct{}) flag := false // register session 1, will be active @@ -673,12 +520,6 @@ func TestSessionProcessActiveStandBy(t *testing.T) { return nil }) wg.Wait() - s1.LivenessCheck(ctx1, func() { - log.Debug("Session 1 livenessCheck callback") - flag = true - close(signal) - s1.cancelKeepAlive(true) - }) assert.False(t, s1.isStandby.Load().(bool)) // register session 2, will be standby @@ -700,21 +541,7 @@ func TestSessionProcessActiveStandBy(t *testing.T) { log.Debug("Stop session 1, session 2 will take over primary service") assert.False(t, flag) - s1.safeCloseLiveCh() - { - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - _, _ = s1.etcdCli.Revoke(ctx, *s1.LeaseID) - } - select { - case <-signal: - log.Debug("receive s1 signal") - case <-time.After(10 * time.Second): - log.Debug("wait to fail Liveness Check timeout") - t.FailNow() - } - assert.True(t, flag) - log.Debug("session s1 stop") + s1.Stop() wg.Wait() log.Debug("session s2 wait done") @@ -811,8 +638,7 @@ func TestIntegrationMode(t *testing.T) { type SessionSuite struct { suite.Suite - tmpDir string - etcdServer *embed.Etcd + tmpDir string metaRoot string serverName string @@ -821,53 +647,19 @@ type SessionSuite struct { func (s *SessionSuite) SetupSuite() { paramtable.Init() - dir, err := os.MkdirTemp(os.TempDir(), "milvus_ut") - s.Require().NoError(err) - s.tmpDir = dir - s.T().Log("using tmp dir:", dir) - - config := embed.NewConfig() - - config.Dir = os.TempDir() - config.LogLevel = "warn" - config.LogOutputs = []string{"default"} - u, err := url.Parse("http://localhost:0") - s.Require().NoError(err) - - config.ListenClientUrls = []url.URL{*u} - u, err = url.Parse("http://localhost:0") - s.Require().NoError(err) - config.ListenPeerUrls = []url.URL{*u} - - etcdServer, err := embed.StartEtcd(config) - s.Require().NoError(err) - s.etcdServer = etcdServer } func (s *SessionSuite) TearDownSuite() { - if s.etcdServer != nil { - s.etcdServer.Close() - } - if s.tmpDir != "" { - os.RemoveAll(s.tmpDir) - } } func (s *SessionSuite) SetupTest() { - client := v3client.New(s.etcdServer.Server) - s.client = client - + s.client, _ = kvfactory.GetEtcdAndPath() s.metaRoot = fmt.Sprintf("milvus-ut/session-%s/", funcutil.GenRandomStr()) } func (s *SessionSuite) TearDownTest() { _, err := s.client.Delete(context.Background(), s.metaRoot, clientv3.WithPrefix()) s.Require().NoError(err) - - if s.client != nil { - s.client.Close() - s.client = nil - } } func (s *SessionSuite) TestDisconnected() { @@ -925,61 +717,19 @@ func (s *SessionSuite) TestGoingStop() { } } -func (s *SessionSuite) TestRevoke() { - ctx := context.Background() - disconnected := NewSessionWithEtcd(ctx, s.metaRoot, s.client, WithResueNodeID(false)) - disconnected.Init("test", "disconnected", false, false) - disconnected.Register() - disconnected.SetDisconnected(true) - - sess := NewSessionWithEtcd(ctx, s.metaRoot, s.client, WithResueNodeID(false)) - sess.Init("test", "normal", false, false) - sess.Register() - - cases := []struct { - tag string - input *Session - 
preExist bool - success bool - }{ - {"not_inited", &Session{}, false, true}, - {"disconnected", disconnected, true, false}, - {"normal", sess, false, true}, - } - - for _, c := range cases { - s.Run(c.tag, func() { - c.input.Revoke(time.Second) - resp, err := s.client.Get(ctx, c.input.getCompleteKey()) - s.Require().NoError(err) - if !c.preExist || c.success { - s.Equal(0, len(resp.Kvs)) - } - if c.preExist && !c.success { - s.Equal(1, len(resp.Kvs)) - } - }) - } -} - func (s *SessionSuite) TestKeepAliveRetryActiveCancel() { ctx := context.Background() session := NewSessionWithEtcd(ctx, s.metaRoot, s.client) session.Init("test", "normal", false, false) // Register - ch, err := session.registerService() + err := session.registerService() s.Require().NoError(err) - session.liveCh = make(chan struct{}) - session.startKeepAliveLoop(ch) - session.LivenessCheck(ctx, nil) - // active cancel, should not retry connect - session.cancelKeepAlive(true) + session.startKeepAliveLoop() + session.Stop() // wait workers exit session.wg.Wait() - // expected Disconnected = true, means session is closed - assert.Equal(s.T(), true, session.Disconnected()) } func (s *SessionSuite) TestKeepAliveRetryChannelClose() { @@ -988,15 +738,12 @@ func (s *SessionSuite) TestKeepAliveRetryChannelClose() { session.Init("test", "normal", false, false) // Register - _, err := session.registerService() + err := session.registerService() if err != nil { panic(err) } - session.liveCh = make(chan struct{}) closeChan := make(chan *clientv3.LeaseKeepAliveResponse) - sendChan := (<-chan *clientv3.LeaseKeepAliveResponse)(closeChan) - session.startKeepAliveLoop(sendChan) - session.LivenessCheck(ctx, nil) + session.startKeepAliveLoop() // close channel, should retry connect close(closeChan) @@ -1009,17 +756,6 @@ func (s *SessionSuite) TestKeepAliveRetryChannelClose() { assert.Equal(s.T(), false, session.Disconnected()) } -func (s *SessionSuite) TestSafeCloseLiveCh() { - ctx := context.Background() - session := NewSessionWithEtcd(ctx, s.metaRoot, s.client) - session.Init("test", "normal", false, false) - session.liveCh = make(chan struct{}) - session.safeCloseLiveCh() - assert.NotPanics(s.T(), func() { - session.safeCloseLiveCh() - }) -} - func (s *SessionSuite) TestGetSessions() { os.Setenv("MILVUS_SERVER_LABEL_key1", "value1") os.Setenv("MILVUS_SERVER_LABEL_key2", "value2") @@ -1035,45 +771,75 @@ func (s *SessionSuite) TestGetSessions() { assert.Equal(s.T(), "value2", ret["key2"]) } +func (s *SessionSuite) TestSessionLifetime() { + ctx := context.Background() + session := NewSessionWithEtcd(ctx, s.metaRoot, s.client) + session.Init("test", "normal", false, false) + session.Register() + + resp, err := s.client.Get(ctx, session.getCompleteKey()) + s.Require().NoError(err) + s.Equal(1, len(resp.Kvs)) + str, err := json.Marshal(session.SessionRaw) + s.Require().NoError(err) + s.Equal(string(resp.Kvs[0].Value), string(str)) + + ttlResp, err := s.client.Lease.TimeToLive(ctx, *session.LeaseID) + s.Require().NoError(err) + s.Greater(ttlResp.TTL, int64(0)) + + session.GoingStop() + resp, err = s.client.Get(ctx, session.getCompleteKey()) + s.Require().True(session.SessionRaw.Stopping) + s.Require().NoError(err) + s.Equal(1, len(resp.Kvs)) + str, err = json.Marshal(session.SessionRaw) + s.Require().NoError(err) + s.Equal(string(resp.Kvs[0].Value), string(str)) + + session.Stop() + session.wg.Wait() + + resp, err = s.client.Get(ctx, session.getCompleteKey()) + s.Require().NoError(err) + s.Equal(0, len(resp.Kvs)) + + ttlResp, err = 
s.client.Lease.TimeToLive(ctx, *session.LeaseID)
+	s.Require().NoError(err)
+	s.Equal(int64(-1), ttlResp.TTL)
+}
+
 func TestSessionSuite(t *testing.T) {
 	suite.Run(t, new(SessionSuite))
 }
 
-func (s *SessionSuite) TestKeepAliveCancelWithoutStop() {
-	ctx, cancel := context.WithCancel(context.Background())
-	defer cancel()
-
-	session := NewSessionWithEtcd(ctx, s.metaRoot, s.client)
-	session.Init("test", "normal", false, false)
-	_, err := session.registerService()
-	assert.NoError(s.T(), err)
-
-	// Override liveCh and LeaseKeepAliveResponse channel for testing
-	session.liveCh = make(chan struct{})
-	kaCh := make(chan *clientv3.LeaseKeepAliveResponse)
-	session.startKeepAliveLoop(kaCh)
-
-	session.keepAliveMu.Lock()
-	cancelOld := session.keepAliveCancel
-	session.keepAliveCancel = func() {
-		// only cancel, not setting isStopped, to simulate not "stop"
-	}
-	session.keepAliveMu.Unlock()
-	if cancelOld != nil {
-		cancelOld()
+func TestForceKill(t *testing.T) {
+	if os.Getenv("TEST_EXIT") == "1" {
+		testForceKill("testForceKill")
+		return
 	}
 
-	// send a nil (simulate closed keepalive channel)
-	go func() {
-		kaCh <- nil
-	}()
+	cmd := exec.Command(os.Args[0], "-test.run=TestForceKill") /* #nosec G204 */
+	cmd.Env = append(os.Environ(), "TEST_EXIT=1")
 
-	// Give time for retry logic to trigger
-	time.Sleep(200 * time.Millisecond)
+	err := cmd.Run()
 
-	// should not be disconnected, session could recover
-	assert.False(s.T(), session.Disconnected())
-
-	// Routine clean up
-	session.Stop()
+	// check the child process's exit code
+	if e, ok := err.(*exec.ExitError); ok {
+		if e.ExitCode() != 1 {
+			t.Fatalf("expected exit 1, got %d", e.ExitCode())
+		}
+	} else {
+		t.Fatalf("unexpected error: %#v", err)
+	}
+}
+
+func testForceKill(serverName string) {
+	etcdCli, _ := kvfactory.GetEtcdAndPath()
+	session := NewSessionWithEtcd(context.Background(), "test", etcdCli)
+	session.Init(serverName, "normal", false, false)
+	session.Register()
+
+	// trigger a force kill
+	etcdCli.Revoke(context.Background(), *session.LeaseID)
 }
diff --git a/pkg/log/zap_async_buffered_write_core.go b/pkg/log/zap_async_buffered_write_core.go
index f725859850..114de4cb7d 100644
--- a/pkg/log/zap_async_buffered_write_core.go
+++ b/pkg/log/zap_async_buffered_write_core.go
@@ -90,7 +90,7 @@ func (s *asyncTextIOCore) With(fields []zapcore.Field) zapcore.Core {
 	return &asyncTextIOCore{
 		LevelEnabler:        s.LevelEnabler,
 		notifier:            s.notifier,
-		enc:                 s.enc.Clone(),
+		enc:                 enc.Clone(),
 		bws:                 s.bws,
 		pending:             s.pending,
 		writeDroppedTimeout: s.writeDroppedTimeout,
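
For reference, below is a minimal, self-contained sketch of the lease-based liveness pattern this patch moves to: register a key under an etcd lease, keep the lease alive with exponential backoff on transient failures, force-exit once the lease is confirmed gone, and revoke the lease on clean shutdown. The endpoint, key, and TTL are placeholder values, and the error handling is condensed: the patched checkKeepaliveTTL only force-exits on ErrLeaseNotFound or a client-side expiry, and retries other errors.

package main

import (
	"context"
	"fmt"
	"os"
	"time"

	"github.com/cenkalti/backoff/v4"
	clientv3 "go.etcd.io/etcd/client/v3"
)

func main() {
	cli, err := clientv3.New(clientv3.Config{Endpoints: []string{"localhost:2379"}})
	if err != nil {
		panic(err)
	}
	defer cli.Close()

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// Grant a lease and attach the session key to it, as registerService does.
	grant, err := cli.Grant(ctx, 10)
	if err != nil {
		panic(err)
	}
	if _, err := cli.Put(ctx, "demo/session", "alive", clientv3.WithLease(grant.ID)); err != nil {
		panic(err)
	}
	// On clean shutdown, revoke the lease so the key disappears immediately,
	// mirroring the deferred Revoke in the new processKeepAliveResponse.
	defer cli.Revoke(context.Background(), grant.ID)

	bo := backoff.NewExponentialBackOff()
	bo.MaxElapsedTime = 0 // retry forever; the loop exits via ctx instead

	for ctx.Err() == nil {
		// Before (re)starting keep alive, confirm the lease still exists.
		// Simplified: exit on any failure here, where the patch distinguishes
		// a confirmed-lost lease (exit) from transient errors (retry).
		ttl, err := cli.TimeToLive(ctx, grant.ID)
		if err != nil || ttl.TTL <= 0 {
			fmt.Println("lease lost, exiting")
			os.Exit(1)
		}

		ch, err := cli.KeepAlive(ctx, grant.ID)
		if err != nil {
			time.Sleep(bo.NextBackOff()) // transient failure: back off and retry
			continue
		}
		bo.Reset()

		// Drain responses; the channel closes on context cancel or connection loss.
		for range ch {
		}
	}
}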