diff --git a/internal/coordinator/mix_coord.go b/internal/coordinator/mix_coord.go index 4b3472bd41..6cf9499fa5 100644 --- a/internal/coordinator/mix_coord.go +++ b/internal/coordinator/mix_coord.go @@ -75,9 +75,6 @@ type mixCoordImpl struct { factory dependency.Factory - enableActiveStandBy bool - activateFunc func() error - metricsRequest *metricsinfo.MetricsRequest metaKVCreator func() kv.MetaKv @@ -97,13 +94,12 @@ func NewMixCoordServer(c context.Context, factory dependency.Factory) (*mixCoord dataCoordServer := datacoord.CreateServer(c, factory) return &mixCoordImpl{ - ctx: ctx, - cancel: cancel, - rootcoordServer: rootCoordServer, - queryCoordServer: queryCoordServer, - datacoordServer: dataCoordServer, - enableActiveStandBy: Params.MixCoordCfg.EnableActiveStandby.GetAsBool(), - factory: factory, + ctx: ctx, + cancel: cancel, + rootcoordServer: rootCoordServer, + queryCoordServer: queryCoordServer, + datacoordServer: dataCoordServer, + factory: factory, }, nil } @@ -115,21 +111,17 @@ func (s *mixCoordImpl) Register() error { metrics.NumNodes.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), typeutil.MixCoordRole).Inc() log.Info("MixCoord Register Finished") } - if s.enableActiveStandBy { - go func() { - if err := s.session.ProcessActiveStandBy(s.activateFunc); err != nil { - if s.ctx.Err() == context.Canceled { - log.Info("standby process canceled due to server shutdown") - return - } - log.Error("failed to activate standby server", zap.Error(err)) - panic(err) + go func() { + if err := s.session.ProcessActiveStandBy(s.activateFunc); err != nil { + if s.ctx.Err() == context.Canceled { + log.Info("standby process canceled due to server shutdown") + return } - afterRegister() - }() - } else { + log.Error("failed to activate standby server", zap.Error(err)) + panic(err) + } afterRegister() - } + }() return nil } @@ -142,33 +134,25 @@ func (s *mixCoordImpl) Init() error { s.factory.Init(Params) s.initKVCreator() s.initStreamingCoord() - if s.enableActiveStandBy 
{ - s.activateFunc = func() error { - log.Info("mixCoord switch from standby to active, activating") + s.UpdateStateCode(commonpb.StateCode_StandBy) + log.Info("MixCoord enter standby mode successfully") + return nil +} - var err error - s.initOnce.Do(func() { - if err = s.initInternal(); err != nil { - log.Error("mixCoord init failed", zap.Error(err)) - } - }) - if err != nil { - return err - } - log.Info("mixCoord startup success", zap.String("address", s.session.GetAddress())) - s.startAndUpdateHealthy() - return err +func (s *mixCoordImpl) activateFunc() error { + log.Info("mixCoord switch from standby to active, activating") + var err error + s.initOnce.Do(func() { + if err = s.initInternal(); err != nil { + log.Error("mixCoord init failed", zap.Error(err)) } - s.UpdateStateCode(commonpb.StateCode_StandBy) - log.Info("MixCoord enter standby mode successfully") - } else { - s.initOnce.Do(func() { - if initErr = s.initInternal(); initErr != nil { - log.Error("mixCoord init failed", zap.Error(initErr)) - } - }) + }) + if err != nil { + return err } - return initErr + log.Info("mixCoord startup success", zap.String("address", s.session.GetAddress())) + s.startAndUpdateHealthy() + return err } func (s *mixCoordImpl) initInternal() error { @@ -235,9 +219,6 @@ func (s *mixCoordImpl) initKVCreator() { } func (s *mixCoordImpl) Start() error { - if !s.enableActiveStandBy { - s.startAndUpdateHealthy() - } return nil } @@ -380,7 +361,7 @@ func (s *mixCoordImpl) initStreamingCoord() { func (s *mixCoordImpl) initSession() error { s.session = sessionutil.NewSession(s.ctx) s.session.Init(typeutil.MixCoordRole, s.address, true, true) - s.session.SetEnableActiveStandBy(s.enableActiveStandBy) + s.session.SetEnableActiveStandBy(true) s.rootcoordServer.SetSession(s.session) s.datacoordServer.SetSession(s.session) s.queryCoordServer.SetSession(s.session) @@ -388,9 +369,6 @@ func (s *mixCoordImpl) initSession() error { return nil } -func (s *mixCoordImpl) startHealthCheck() { -} - 
func (s *mixCoordImpl) SetAddress(address string) { s.address = address s.rootcoordServer.SetAddress(address) diff --git a/internal/coordinator/mix_coord_test.go b/internal/coordinator/mix_coord_test.go index bfc3722074..4bf203f096 100644 --- a/internal/coordinator/mix_coord_test.go +++ b/internal/coordinator/mix_coord_test.go @@ -103,61 +103,6 @@ func TestMixcoord_EnableActiveStandby(t *testing.T) { assert.NoError(t, err) } -// make sure the main functions work well when EnableActiveStandby=false -func TestMixcoord_DisableActiveStandby(t *testing.T) { - randVal := rand.Int() - paramtable.Init() - testutil.ResetEnvironment() - Params.Save("etcd.rootPath", fmt.Sprintf("/%d", randVal)) - // Need to reset global etcd to follow new path - kvfactory.CloseEtcdClient() - - paramtable.Get().Save(Params.MixCoordCfg.EnableActiveStandby.Key, "false") - paramtable.Get().Save(Params.CommonCfg.RootCoordTimeTick.Key, fmt.Sprintf("rootcoord-time-tick-%d", randVal)) - paramtable.Get().Save(Params.CommonCfg.RootCoordStatistics.Key, fmt.Sprintf("rootcoord-statistics-%d", randVal)) - paramtable.Get().Save(Params.CommonCfg.RootCoordDml.Key, fmt.Sprintf("rootcoord-dml-test-%d", randVal)) - - ctx := context.Background() - coreFactory := dependency.NewDefaultFactory(true) - etcdCli, err := etcd.GetEtcdClient( - Params.EtcdCfg.UseEmbedEtcd.GetAsBool(), - Params.EtcdCfg.EtcdUseSSL.GetAsBool(), - Params.EtcdCfg.Endpoints.GetAsStrings(), - Params.EtcdCfg.EtcdTLSCert.GetValue(), - Params.EtcdCfg.EtcdTLSKey.GetValue(), - Params.EtcdCfg.EtcdTLSCACert.GetValue(), - Params.EtcdCfg.EtcdTLSMinVersion.GetValue()) - assert.NoError(t, err) - defer etcdCli.Close() - core, err := NewMixCoordServer(ctx, coreFactory) - core.SetEtcdClient(etcdCli) - assert.NoError(t, err) - core.SetTiKVClient(tikv.SetupLocalTxn()) - - err = core.Init() - assert.NoError(t, err) - assert.Equal(t, commonpb.StateCode_Initializing, core.GetStateCode()) - err = core.Start() - assert.NoError(t, err) - core.session.TriggerKill = 
false - err = core.Register() - assert.NoError(t, err) - assert.Equal(t, commonpb.StateCode_Healthy, core.GetStateCode()) - resp, err := core.DescribeCollection(ctx, &milvuspb.DescribeCollectionRequest{ - Base: &commonpb.MsgBase{ - MsgType: commonpb.MsgType_DescribeCollection, - MsgID: 0, - Timestamp: 0, - SourceID: paramtable.GetNodeID(), - }, - CollectionName: "unexist", - }) - assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) - err = core.Stop() - assert.NoError(t, err) -} - func TestMixCoord_FlushAll(t *testing.T) { t.Run("success", func(t *testing.T) { mockey.PatchConvey("test flush all success", t, func() { diff --git a/internal/distributed/connection_manager.go b/internal/distributed/connection_manager.go deleted file mode 100644 index 05e5973c38..0000000000 --- a/internal/distributed/connection_manager.go +++ /dev/null @@ -1,455 +0,0 @@ -// Licensed to the LF AI & Data foundation under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package distributed - -import ( - "context" - "os" - "sync" - "syscall" - "time" - - "github.com/cockroachdb/errors" - grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" - grpc_retry "github.com/grpc-ecosystem/go-grpc-middleware/retry" - "github.com/samber/lo" - "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" - "go.uber.org/zap" - "google.golang.org/grpc" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/credentials/insecure" - - "github.com/milvus-io/milvus/internal/util/sessionutil" - "github.com/milvus-io/milvus/pkg/v2/log" - "github.com/milvus-io/milvus/pkg/v2/proto/datapb" - "github.com/milvus-io/milvus/pkg/v2/proto/querypb" - "github.com/milvus-io/milvus/pkg/v2/proto/rootcoordpb" - "github.com/milvus-io/milvus/pkg/v2/proto/workerpb" - "github.com/milvus-io/milvus/pkg/v2/tracer" - "github.com/milvus-io/milvus/pkg/v2/util/retry" - "github.com/milvus-io/milvus/pkg/v2/util/typeutil" -) - -// ConnectionManager handles connection to other components of the system -type ConnectionManager struct { - session *sessionutil.Session - - dependencies map[string]struct{} - - rootCoord rootcoordpb.RootCoordClient - rootCoordMu sync.RWMutex - queryCoord querypb.QueryCoordClient - queryCoordMu sync.RWMutex - dataCoord datapb.DataCoordClient - dataCoordMu sync.RWMutex - queryNodes map[int64]querypb.QueryNodeClient - queryNodesMu sync.RWMutex - dataNodes map[int64]datapb.DataNodeClient - dataNodesMu sync.RWMutex - indexNodes map[int64]workerpb.IndexNodeClient - indexNodesMu sync.RWMutex - - taskMu sync.RWMutex - buildTasks map[int64]*buildClientTask - notify chan int64 - - connMu sync.RWMutex - connections map[int64]*grpc.ClientConn - - closeCh chan struct{} -} - -// NewConnectionManager creates a new connection manager. 
-func NewConnectionManager(session *sessionutil.Session) *ConnectionManager { - return &ConnectionManager{ - session: session, - - dependencies: make(map[string]struct{}), - - queryNodes: make(map[int64]querypb.QueryNodeClient), - dataNodes: make(map[int64]datapb.DataNodeClient), - indexNodes: make(map[int64]workerpb.IndexNodeClient), - - buildTasks: make(map[int64]*buildClientTask), - notify: make(chan int64), - - connections: make(map[int64]*grpc.ClientConn), - } -} - -// AddDependency add a dependency by role name. -func (cm *ConnectionManager) AddDependency(roleName string) error { - if !cm.checkroleName(roleName) { - return errors.New("roleName is illegal") - } - - log := log.Ctx(context.TODO()) - _, ok := cm.dependencies[roleName] - if ok { - log.Warn("Dependency is already added", zap.String("roleName", roleName)) - return nil - } - cm.dependencies[roleName] = struct{}{} - - msess, rev, err := cm.session.GetSessions(context.TODO(), roleName) - if err != nil { - log.Debug("ClientManager GetSessions failed", zap.String("roleName", roleName)) - return err - } - - if len(msess) == 0 { - log.Debug("No nodes are currently alive", zap.String("roleName", roleName)) - } else { - for _, value := range msess { - cm.buildConnections(value) - } - } - - watcher := cm.session.WatchServices(roleName, rev, nil) - go cm.processEvent(watcher.EventChannel()) - - return nil -} - -func (cm *ConnectionManager) Start() { - go cm.receiveFinishTask() -} - -func (cm *ConnectionManager) GetRootCoordClient() (rootcoordpb.RootCoordClient, bool) { - cm.rootCoordMu.RLock() - defer cm.rootCoordMu.RUnlock() - _, ok := cm.dependencies[typeutil.RootCoordRole] - if !ok { - log.Ctx(context.TODO()).Error("RootCoord dependency has not been added yet") - return nil, false - } - - return cm.rootCoord, true -} - -func (cm *ConnectionManager) GetQueryCoordClient() (querypb.QueryCoordClient, bool) { - cm.queryCoordMu.RLock() - defer cm.queryCoordMu.RUnlock() - _, ok := 
cm.dependencies[typeutil.QueryCoordRole] - if !ok { - log.Ctx(context.TODO()).Error("QueryCoord dependency has not been added yet") - return nil, false - } - - return cm.queryCoord, true -} - -func (cm *ConnectionManager) GetDataCoordClient() (datapb.DataCoordClient, bool) { - cm.dataCoordMu.RLock() - defer cm.dataCoordMu.RUnlock() - _, ok := cm.dependencies[typeutil.DataCoordRole] - if !ok { - log.Ctx(context.TODO()).Error("DataCoord dependency has not been added yet") - return nil, false - } - - return cm.dataCoord, true -} - -func (cm *ConnectionManager) GetQueryNodeClients() ([]lo.Tuple2[int64, querypb.QueryNodeClient], bool) { - cm.queryNodesMu.RLock() - defer cm.queryNodesMu.RUnlock() - _, ok := cm.dependencies[typeutil.QueryNodeRole] - if !ok { - log.Ctx(context.TODO()).Error("QueryNode dependency has not been added yet") - return nil, false - } - - nodes := lo.MapToSlice(cm.queryNodes, func(id int64, client querypb.QueryNodeClient) lo.Tuple2[int64, querypb.QueryNodeClient] { - return lo.Tuple2[int64, querypb.QueryNodeClient]{A: id, B: client} - }) - - return nodes, true -} - -func (cm *ConnectionManager) GetDataNodeClients() ([]lo.Tuple2[int64, datapb.DataNodeClient], bool) { - cm.dataNodesMu.RLock() - defer cm.dataNodesMu.RUnlock() - _, ok := cm.dependencies[typeutil.DataNodeRole] - if !ok { - log.Ctx(context.TODO()).Error("DataNode dependency has not been added yet") - return nil, false - } - - return lo.MapToSlice(cm.dataNodes, func(id int64, client datapb.DataNodeClient) lo.Tuple2[int64, datapb.DataNodeClient] { - return lo.Tuple2[int64, datapb.DataNodeClient]{A: id, B: client} - }), true -} - -func (cm *ConnectionManager) GetIndexNodeClients() ([]lo.Tuple2[int64, workerpb.IndexNodeClient], bool) { - cm.indexNodesMu.RLock() - defer cm.indexNodesMu.RUnlock() - _, ok := cm.dependencies[typeutil.IndexNodeRole] - if !ok { - log.Ctx(context.TODO()).Error("IndexNode dependency has not been added yet") - return nil, false - } - - return 
lo.MapToSlice(cm.indexNodes, func(id int64, client workerpb.IndexNodeClient) lo.Tuple2[int64, workerpb.IndexNodeClient] { - return lo.Tuple2[int64, workerpb.IndexNodeClient]{A: id, B: client} - }), true -} - -func (cm *ConnectionManager) Stop() { - for _, task := range cm.buildTasks { - task.Stop() - } - close(cm.closeCh) - for _, conn := range cm.connections { - conn.Close() - } -} - -// fix datarace in unittest -// startWatchService will only be invoked at start procedure -// otherwise, remove the annotation and add atomic protection -// -//go:norace -func (cm *ConnectionManager) processEvent(channel <-chan *sessionutil.SessionEvent) { - for { - select { - case _, ok := <-cm.closeCh: - if !ok { - return - } - case ev, ok := <-channel: - if !ok { - log.Ctx(context.TODO()).Error("watch service channel closed", zap.Int64("serverID", cm.session.ServerID)) - go cm.Stop() - if cm.session.TriggerKill { - if p, err := os.FindProcess(os.Getpid()); err == nil { - p.Signal(syscall.SIGINT) - } - } - return - } - switch ev.EventType { - case sessionutil.SessionAddEvent: - log.Ctx(context.TODO()).Debug("ConnectionManager", zap.Any("add event", ev.Session)) - cm.buildConnections(ev.Session) - case sessionutil.SessionDelEvent: - cm.removeTask(ev.Session.ServerID) - cm.removeConnection(ev.Session.ServerID) - } - } - } -} - -func (cm *ConnectionManager) receiveFinishTask() { - log := log.Ctx(context.TODO()) - for { - select { - case _, ok := <-cm.closeCh: - if !ok { - return - } - case serverID := <-cm.notify: - cm.taskMu.Lock() - task, ok := cm.buildTasks[serverID] - log.Debug("ConnectionManager", zap.Int64("receive finish", serverID)) - if ok { - log.Debug("ConnectionManager", zap.Int64("get task ok", serverID)) - log.Debug("ConnectionManager", zap.Any("task state", task.state)) - if task.state == buildClientSuccess { - log.Debug("ConnectionManager", zap.Int64("build success", serverID)) - cm.addConnection(task.sess.ServerID, task.result) - cm.buildClients(task.sess, 
task.result) - } - delete(cm.buildTasks, serverID) - } - cm.taskMu.Unlock() - } - } -} - -func (cm *ConnectionManager) buildClients(session *sessionutil.Session, connection *grpc.ClientConn) { - switch session.ServerName { - case typeutil.RootCoordRole: - cm.rootCoordMu.Lock() - defer cm.rootCoordMu.Unlock() - cm.rootCoord = rootcoordpb.NewRootCoordClient(connection) - case typeutil.DataCoordRole: - cm.dataCoordMu.Lock() - defer cm.dataCoordMu.Unlock() - cm.dataCoord = datapb.NewDataCoordClient(connection) - case typeutil.QueryCoordRole: - cm.queryCoordMu.Lock() - defer cm.queryCoordMu.Unlock() - cm.queryCoord = querypb.NewQueryCoordClient(connection) - case typeutil.QueryNodeRole: - cm.queryNodesMu.Lock() - defer cm.queryNodesMu.Unlock() - cm.queryNodes[session.ServerID] = querypb.NewQueryNodeClient(connection) - case typeutil.DataNodeRole: - cm.dataNodesMu.Lock() - defer cm.dataNodesMu.Unlock() - cm.dataNodes[session.ServerID] = datapb.NewDataNodeClient(connection) - case typeutil.IndexNodeRole: - cm.indexNodesMu.Lock() - defer cm.indexNodesMu.Unlock() - cm.indexNodes[session.ServerID] = workerpb.NewIndexNodeClient(connection) - } -} - -func (cm *ConnectionManager) buildConnections(session *sessionutil.Session) { - task := newBuildClientTask(session, cm.notify) - cm.addTask(session.ServerID, task) - task.Run() -} - -func (cm *ConnectionManager) addConnection(id int64, conn *grpc.ClientConn) { - cm.connMu.Lock() - cm.connections[id] = conn - cm.connMu.Unlock() -} - -func (cm *ConnectionManager) removeConnection(id int64) { - cm.connMu.Lock() - conn, ok := cm.connections[id] - if ok { - conn.Close() - delete(cm.connections, id) - } - cm.connMu.Unlock() -} - -func (cm *ConnectionManager) addTask(id int64, task *buildClientTask) { - cm.taskMu.Lock() - cm.buildTasks[id] = task - cm.taskMu.Unlock() -} - -func (cm *ConnectionManager) removeTask(id int64) { - cm.taskMu.Lock() - task, ok := cm.buildTasks[id] - if ok { - task.Stop() - delete(cm.buildTasks, id) - } - 
cm.taskMu.Unlock() -} - -type buildConnectionstate int - -const ( - buildConnectionstart buildConnectionstate = iota - buildClientRunning - buildClientSuccess - buildClientFailed -) - -type buildClientTask struct { - ctx context.Context - cancel context.CancelFunc - - sess *sessionutil.Session - state buildConnectionstate - retryOptions []retry.Option - - result *grpc.ClientConn - notify chan int64 -} - -func newBuildClientTask(session *sessionutil.Session, notify chan int64, retryOptions ...retry.Option) *buildClientTask { - ctx, cancel := context.WithCancel(context.Background()) - return &buildClientTask{ - ctx: ctx, - cancel: cancel, - - sess: session, - retryOptions: retryOptions, - - notify: notify, - } -} - -func (bct *buildClientTask) Run() { - bct.state = buildClientRunning - go func() { - defer bct.finish() - connectGrpcFunc := func() error { - opts := tracer.GetInterceptorOpts() - log.Ctx(bct.ctx).Debug("Grpc connect", zap.String("Address", bct.sess.Address)) - ctx, cancel := context.WithTimeout(bct.ctx, 30*time.Second) - defer cancel() - conn, err := grpc.DialContext(ctx, bct.sess.Address, - grpc.WithTransportCredentials(insecure.NewCredentials()), - grpc.WithBlock(), - grpc.WithDisableRetry(), - grpc.WithUnaryInterceptor( - grpc_middleware.ChainUnaryClient( - grpc_retry.UnaryClientInterceptor( - grpc_retry.WithMax(3), - grpc_retry.WithCodes(codes.Aborted, codes.Unavailable), - ), - otelgrpc.UnaryClientInterceptor(opts...), - )), - grpc.WithStreamInterceptor( - grpc_middleware.ChainStreamClient( - grpc_retry.StreamClientInterceptor( - grpc_retry.WithMax(3), - grpc_retry.WithCodes(codes.Aborted, codes.Unavailable), - ), - otelgrpc.StreamClientInterceptor(opts...), - )), - ) - if err != nil { - return err - } - bct.result = conn - bct.state = buildClientSuccess - return nil - } - - err := retry.Do(bct.ctx, connectGrpcFunc, bct.retryOptions...) 
- log.Ctx(bct.ctx).Debug("ConnectionManager", zap.Int64("build connection finish", bct.sess.ServerID)) - if err != nil { - log.Ctx(bct.ctx).Debug("BuildClientTask try connect failed", - zap.String("roleName", bct.sess.ServerName), zap.Error(err)) - bct.state = buildClientFailed - return - } - }() -} - -func (bct *buildClientTask) Stop() { - bct.cancel() -} - -func (bct *buildClientTask) finish() { - log.Ctx(bct.ctx).Debug("ConnectionManager", zap.Int64("notify connection finish", bct.sess.ServerID)) - bct.notify <- bct.sess.ServerID -} - -var roles = map[string]struct{}{ - typeutil.RootCoordRole: {}, - typeutil.QueryCoordRole: {}, - typeutil.DataCoordRole: {}, - typeutil.QueryNodeRole: {}, - typeutil.DataNodeRole: {}, - typeutil.IndexNodeRole: {}, -} - -func (cm *ConnectionManager) checkroleName(roleName string) bool { - _, ok := roles[roleName] - return ok -} diff --git a/internal/distributed/connection_manager_test.go b/internal/distributed/connection_manager_test.go deleted file mode 100644 index 24e8572dc5..0000000000 --- a/internal/distributed/connection_manager_test.go +++ /dev/null @@ -1,296 +0,0 @@ -// Licensed to the LF AI & Data foundation under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package distributed - -import ( - "context" - "fmt" - "net" - "os" - "os/signal" - "strings" - "syscall" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "go.uber.org/zap" - "google.golang.org/grpc" - - "github.com/milvus-io/milvus/internal/util/sessionutil" - "github.com/milvus-io/milvus/pkg/v2/log" - "github.com/milvus-io/milvus/pkg/v2/proto/datapb" - "github.com/milvus-io/milvus/pkg/v2/proto/querypb" - "github.com/milvus-io/milvus/pkg/v2/proto/rootcoordpb" - "github.com/milvus-io/milvus/pkg/v2/proto/workerpb" - "github.com/milvus-io/milvus/pkg/v2/util/etcd" - "github.com/milvus-io/milvus/pkg/v2/util/paramtable" - "github.com/milvus-io/milvus/pkg/v2/util/typeutil" -) - -func TestMain(t *testing.M) { - // init embed etcd - embedetcdServer, tempDir, err := etcd.StartTestEmbedEtcdServer() - if err != nil { - log.Fatal("failed to start embed etcd server for unittest", zap.Error(err)) - } - - defer os.RemoveAll(tempDir) - defer embedetcdServer.Server.Stop() - - addrs := etcd.GetEmbedEtcdEndpoints(embedetcdServer) - - paramtable.Init() - paramtable.Get().Save(paramtable.Get().EtcdCfg.Endpoints.Key, strings.Join(addrs, ",")) - - os.Exit(t.Run()) -} - -func TestConnectionManager(t *testing.T) { - ctx := context.Background() - - testPath := fmt.Sprintf("TestConnectionManager-%d", time.Now().Unix()) - paramtable.Get().Save(paramtable.Get().EtcdCfg.RootPath.Key, testPath) - - session := initSession(ctx) - cm := NewConnectionManager(session) - cm.AddDependency(typeutil.RootCoordRole) - cm.AddDependency(typeutil.QueryCoordRole) - cm.AddDependency(typeutil.DataCoordRole) - cm.AddDependency(typeutil.QueryNodeRole) - cm.AddDependency(typeutil.DataNodeRole) - cm.AddDependency(typeutil.IndexNodeRole) - cm.Start() - - t.Run("rootCoord", func(t *testing.T) { - lis, err := net.Listen("tcp", "127.0.0.1:") - assert.NoError(t, err) - defer lis.Close() - rootCoord := &testRootCoord{} - grpcServer := grpc.NewServer() - defer grpcServer.Stop() - 
rootcoordpb.RegisterRootCoordServer(grpcServer, rootCoord) - go grpcServer.Serve(lis) - session.Init(typeutil.RootCoordRole, lis.Addr().String(), true, false) - session.Register() - assert.Eventually(t, func() bool { - rootCoord, ok := cm.GetRootCoordClient() - return rootCoord != nil && ok - }, 10*time.Second, 100*time.Millisecond) - }) - - t.Run("queryCoord", func(t *testing.T) { - lis, err := net.Listen("tcp", "127.0.0.1:") - assert.NoError(t, err) - defer lis.Close() - queryCoord := &testQueryCoord{} - grpcServer := grpc.NewServer() - defer grpcServer.Stop() - querypb.RegisterQueryCoordServer(grpcServer, queryCoord) - go grpcServer.Serve(lis) - session.Init(typeutil.QueryCoordRole, lis.Addr().String(), true, false) - session.Register() - assert.Eventually(t, func() bool { - queryCoord, ok := cm.GetQueryCoordClient() - return queryCoord != nil && ok - }, 10*time.Second, 100*time.Millisecond) - }) - - t.Run("dataCoord", func(t *testing.T) { - lis, err := net.Listen("tcp", "127.0.0.1:") - assert.NoError(t, err) - defer lis.Close() - dataCoord := &testDataCoord{} - grpcServer := grpc.NewServer() - defer grpcServer.Stop() - datapb.RegisterDataCoordServer(grpcServer, dataCoord) - go grpcServer.Serve(lis) - session.Init(typeutil.DataCoordRole, lis.Addr().String(), true, false) - session.Register() - assert.Eventually(t, func() bool { - dataCoord, ok := cm.GetDataCoordClient() - return dataCoord != nil && ok - }, 10*time.Second, 100*time.Millisecond) - }) - - t.Run("queryNode", func(t *testing.T) { - lis, err := net.Listen("tcp", "127.0.0.1:") - assert.NoError(t, err) - defer lis.Close() - queryNode := &testQueryNode{} - grpcServer := grpc.NewServer() - defer grpcServer.Stop() - querypb.RegisterQueryNodeServer(grpcServer, queryNode) - go grpcServer.Serve(lis) - session.Init(typeutil.QueryNodeRole, lis.Addr().String(), true, false) - session.Register() - assert.Eventually(t, func() bool { - queryNodes, ok := cm.GetQueryNodeClients() - return len(queryNodes) == 1 && ok - 
}, 10*time.Second, 100*time.Millisecond) - }) - - t.Run("dataNode", func(t *testing.T) { - lis, err := net.Listen("tcp", "127.0.0.1:") - assert.NoError(t, err) - defer lis.Close() - dataNode := &testDataNode{} - grpcServer := grpc.NewServer() - defer grpcServer.Stop() - datapb.RegisterDataNodeServer(grpcServer, dataNode) - go grpcServer.Serve(lis) - session.Init(typeutil.DataNodeRole, lis.Addr().String(), true, false) - session.Register() - assert.Eventually(t, func() bool { - dataNodes, ok := cm.GetDataNodeClients() - return len(dataNodes) == 1 && ok - }, 10*time.Second, 100*time.Millisecond) - }) - - t.Run("indexNode", func(t *testing.T) { - lis, err := net.Listen("tcp", "127.0.0.1:") - assert.NoError(t, err) - defer lis.Close() - indexNode := &testIndexNode{} - grpcServer := grpc.NewServer() - defer grpcServer.Stop() - workerpb.RegisterIndexNodeServer(grpcServer, indexNode) - go grpcServer.Serve(lis) - session.Init(typeutil.IndexNodeRole, lis.Addr().String(), true, false) - session.Register() - assert.Eventually(t, func() bool { - indexNodes, ok := cm.GetIndexNodeClients() - return len(indexNodes) == 1 && ok - }, 10*time.Second, 100*time.Millisecond) - }) -} - -func TestConnectionManager_processEvent(t *testing.T) { - t.Run("close closeCh", func(t *testing.T) { - cm := &ConnectionManager{ - closeCh: make(chan struct{}), - } - - ech := make(chan *sessionutil.SessionEvent) - flag := false - signal := make(chan struct{}, 1) - go func() { - assert.Panics(t, func() { - cm.processEvent(ech) - }) - - flag = true - signal <- struct{}{} - }() - - close(ech) - <-signal - assert.True(t, flag) - - ech = make(chan *sessionutil.SessionEvent) - flag = false - go func() { - cm.processEvent(ech) - flag = true - signal <- struct{}{} - }() - close(cm.closeCh) - <-signal - assert.True(t, flag) - }) - - t.Run("close watch chan", func(t *testing.T) { - sc := make(chan os.Signal, 1) - signal.Notify(sc, syscall.SIGINT) - defer signal.Reset(syscall.SIGINT) - sigQuit := make(chan 
struct{}, 1) - - cm := &ConnectionManager{ - closeCh: make(chan struct{}), - session: &sessionutil.Session{ - SessionRaw: sessionutil.SessionRaw{ - ServerID: 1, - TriggerKill: true, - }, - }, - } - - ech := make(chan *sessionutil.SessionEvent) - - go func() { - <-sc - sigQuit <- struct{}{} - }() - - go func() { - cm.processEvent(ech) - }() - - close(ech) - - <-sigQuit - }) -} - -type testRootCoord struct { - rootcoordpb.RootCoordServer -} - -type testQueryCoord struct { - querypb.QueryCoordServer -} -type testDataCoord struct { - datapb.DataCoordServer -} - -type testQueryNode struct { - querypb.QueryNodeServer -} - -type testDataNode struct { - datapb.DataNodeServer -} - -type testIndexNode struct { - workerpb.IndexNodeServer -} - -func initSession(ctx context.Context) *sessionutil.Session { - baseTable := paramtable.GetBaseTable() - rootPath, err := baseTable.Load("etcd.rootPath") - if err != nil { - panic(err) - } - subPath, err := baseTable.Load("etcd.metaSubPath") - if err != nil { - panic(err) - } - metaRootPath := rootPath + "/" + subPath - - endpoints := baseTable.GetWithDefault("etcd.endpoints", paramtable.DefaultEtcdEndpoints) - etcdEndpoints := strings.Split(endpoints, ",") - - log.Ctx(context.TODO()).Debug("metaRootPath", zap.Any("metaRootPath", metaRootPath)) - log.Ctx(context.TODO()).Debug("etcdPoints", zap.Any("etcdPoints", etcdEndpoints)) - - etcdCli, err := etcd.GetRemoteEtcdClient(etcdEndpoints) - if err != nil { - panic(err) - } - session := sessionutil.NewSessionWithEtcd(ctx, metaRootPath, etcdCli) - return session -} diff --git a/internal/util/sessionutil/session_util.go b/internal/util/sessionutil/session_util.go index 48970f2999..f1f18486ab 100644 --- a/internal/util/sessionutil/session_util.go +++ b/internal/util/sessionutil/session_util.go @@ -57,8 +57,17 @@ const ( LabelStandalone = "STANDALONE" MilvusNodeIDForTesting = "MILVUS_NODE_ID_FOR_TESTING" exitCodeSessionLeaseExpired = 1 + + serverVersionKey = "version" ) +var 
errSessionVersionCheckFailure = errors.New("session version check failure") + +// isNotSessionVersionCheckFailure checks if the error is not a session version check failure. +func isNotSessionVersionCheckFailure(err error) bool { + return !errors.Is(err, errSessionVersionCheckFailure) +} + // EnableEmbededQueryNodeLabel set server labels for embedded query node. func EnableEmbededQueryNodeLabel() { os.Setenv(SupportedLabelPrefix+LabelStreamingNodeEmbeddedQueryNode, "1") @@ -169,6 +178,7 @@ type Session struct { isStandby atomic.Value enableActiveStandBy bool activeKey string + versionKey string sessionTTL int64 sessionRetryTimes int64 @@ -300,6 +310,7 @@ func (s *Session) Init(serverName, address string, exclusive bool, triggerKill b } s.ServerID = serverID s.ServerLabels = GetServerLabelsFromEnv(serverName) + s.versionKey = path.Join(s.metaRoot, DefaultServiceRoot, serverVersionKey) s.SetLogger(log.With( log.FieldComponent("service-registration"), @@ -325,6 +336,35 @@ func (s *Session) Register() { s.startKeepAliveLoop() } +// isCoordinator reports whether the session's ServerName is one of the coordinator roles (mix, query, data, root or index coord). +func (s *Session) isCoordinator() bool { + return s.ServerName == typeutil.MixCoordRole || + s.ServerName == typeutil.QueryCoordRole || + s.ServerName == typeutil.DataCoordRole || + s.ServerName == typeutil.RootCoordRole || + s.ServerName == typeutil.IndexCoordRole +} + +// checkVersionForCoordinator reads the version key from etcd: it returns (nil, nil) if the key is absent, an error wrapping errSessionVersionCheckFailure if the running major.minor version is lower than the stored one, and the stored key-value otherwise; etcd and parse errors are returned as-is. +func (s *Session) checkVersionForCoordinator() (*mvccpb.KeyValue, error) { + resp, err := s.etcdCli.Get(s.ctx, s.versionKey) + if err != nil { + return nil, err + } + if resp.Count <= 0 { + // no version key found. 
+ return nil, nil + } + version, err := semver.Parse(string(resp.Kvs[0].Value)) + if err != nil { + return nil, err + } + if common.Version.Major < version.Major || (common.Version.Major == version.Major && common.Version.Minor < version.Minor) { + return nil, errors.Wrapf(errSessionVersionCheckFailure, "current version(%s), session version(%s)", common.Version.String(), version.String()) + } + return resp.Kvs[0], nil +} + var serverIDMu sync.Mutex func (s *Session) getServerID() (int64, error) { @@ -462,54 +502,69 @@ func (s *Session) registerService() error { return err } - txnResp, err := s.etcdCli.Txn(s.ctx).If( - clientv3.Compare( - clientv3.Version(completeKey), - "=", - 0)). - Then(clientv3.OpPut(completeKey, string(sessionJSON), clientv3.WithLease(resp.ID))).Commit() + compareOps := []clientv3.Cmp{ + clientv3.Compare(clientv3.Version(completeKey), "=", 0), + } + ops := []clientv3.Op{ + clientv3.OpPut(completeKey, string(sessionJSON), clientv3.WithLease(resp.ID)), + } + + // if enable active-standby, we don't need to check the version now, + // only check the version when the standby is activated. 
+ if s.isCoordinator() && !s.enableActiveStandBy { + if ops, compareOps, err = s.getOpsForCoordinator(ops, compareOps, sessionJSON); err != nil { + return err + } + } + + txnResp, err := s.etcdCli.Txn(s.ctx).If(compareOps...).Then(ops...).Commit() if err != nil { s.Logger().Warn("register on etcd error, check the availability of etcd", zap.Error(err)) return err } if txnResp != nil && !txnResp.Succeeded { - s.handleRestart(completeKey) return fmt.Errorf("function CompareAndSwap error for compare is false for key: %s", s.ServerName) } s.Logger().Info("put session key into etcd, service registered successfully", zap.String("key", completeKey), zap.String("value", string(sessionJSON))) return nil } - return retry.Do(s.ctx, registerFn, retry.Attempts(uint(s.sessionRetryTimes))) + return retry.Do(s.ctx, registerFn, retry.Attempts(uint(s.sessionRetryTimes)), retry.RetryErr(isNotSessionVersionCheckFailure)) } -// Handle restart is fast path to handle node restart. -// This should be only a fast path for coordinator -// If we find previous session have same address as current , simply purge the old one so the recovery can be much faster -func (s *Session) handleRestart(key string) { - resp, err := s.etcdCli.Get(s.ctx, key) - log := log.With(zap.String("key", key)) +// getOpsForCoordinator appends to ops and compareOps the puts and guards for the legacy coordinator session keys and for the version key, after running the version check. 
+func (s *Session) getOpsForCoordinator(ops []clientv3.Op, compareOps []clientv3.Cmp, sessionJSON []byte) ([]clientv3.Op, []clientv3.Cmp, error) { + previousVersion, err := s.checkVersionForCoordinator() if err != nil { - log.Warn("failed to read old session from etcd, ignore", zap.Error(err)) - return + return nil, nil, err } - for _, kv := range resp.Kvs { - session := &Session{} - err = json.Unmarshal(kv.Value, session) + expectedVersion := int64(0) + if previousVersion != nil { + expectedVersion = previousVersion.Version + } + legacyCoord := []string{ + typeutil.QueryCoordRole, + typeutil.DataCoordRole, + typeutil.RootCoordRole, + } + for _, role := range legacyCoord { + key := path.Join(s.metaRoot, DefaultServiceRoot, role) + var newSession SessionRaw + if err := json.Unmarshal(sessionJSON, &newSession); err != nil { + return nil, nil, err + } + newSession.ServerName = role + newSessionJSON, err := json.Marshal(newSession) if err != nil { - log.Warn("failed to unmarshal old session from etcd, ignore", zap.Error(err)) - return - } - - if session.Address == s.Address && session.ServerID < s.ServerID { - log.Warn("find old session is same as current node, assume it as restart, purge old session", zap.String("key", key), - zap.String("address", session.Address)) - _, err := s.etcdCli.Delete(s.ctx, key) - if err != nil { - log.Warn("failed to unmarshal old session from etcd, ignore", zap.Error(err)) - return - } + return nil, nil, err } + ops = append(ops, clientv3.OpPut(key, string(newSessionJSON), clientv3.WithLease(*s.LeaseID))) + compareOps = append(compareOps, clientv3.Compare(clientv3.Version(key), "=", 0)) } + // ensure the version key is unchanged since checkVersionForCoordinator read it (expected version 0 means the key was absent). + compareOps = append(compareOps, clientv3.Compare(clientv3.Version(s.versionKey), "=", expectedVersion)) + // record the current version under the version key. 
+ ops = append(ops, clientv3.OpPut(s.versionKey, common.Version.String())) + return ops, compareOps, nil } // processKeepAliveResponse processes the response of etcd keepAlive interface @@ -980,12 +1035,21 @@ func (s *Session) ProcessActiveStandBy(activateFunc func() error) error { log.Error("json marshal error", zap.Error(err)) return false, -1, err } - txnResp, err := s.etcdCli.Txn(s.ctx).If( - clientv3.Compare( - clientv3.Version(s.activeKey), - "=", - 0)). - Then(clientv3.OpPut(s.activeKey, string(sessionJSON), clientv3.WithLease(*s.LeaseID))).Commit() + + compareOps := []clientv3.Cmp{ + clientv3.Compare(clientv3.Version(s.activeKey), "=", 0), + } + ops := []clientv3.Op{ + clientv3.OpPut(s.activeKey, string(sessionJSON), clientv3.WithLease(*s.LeaseID)), + } + + if s.isCoordinator() { + if ops, compareOps, err = s.getOpsForCoordinator(ops, compareOps, sessionJSON); err != nil { + return false, -1, err + } + } + + txnResp, err := s.etcdCli.Txn(s.ctx).If(compareOps...).Then(ops...).Commit() if err != nil { log.Error("register active key to etcd failed", zap.Error(err)) return false, -1, err diff --git a/internal/util/sessionutil/session_util_test.go b/internal/util/sessionutil/session_util_test.go index 27a0856548..1965559f74 100644 --- a/internal/util/sessionutil/session_util_test.go +++ b/internal/util/sessionutil/session_util_test.go @@ -771,6 +771,67 @@ func (s *SessionSuite) TestGetSessions() { assert.Equal(s.T(), "value2", ret["key2"]) } +func (s *SessionSuite) TestVersionKey() { + ctx := context.Background() + session := NewSessionWithEtcd(ctx, s.metaRoot, s.client) + session.Init(typeutil.MixCoordRole, "normal", false, false) + + session.Register() + + resp, err := s.client.Get(ctx, session.versionKey) + s.Require().NoError(err) + s.Equal(1, len(resp.Kvs)) + s.Equal(common.Version.String(), string(resp.Kvs[0].Value)) + + common.Version = semver.MustParse("2.5.6") + + s.Panics(func() { + session2 := NewSessionWithEtcd(ctx, s.metaRoot, s.client) + 
session2.Init(typeutil.MixCoordRole, "normal", false, false) + session2.Register() + + resp, err = s.client.Get(ctx, session2.versionKey) + s.Require().NoError(err) + s.Equal(1, len(resp.Kvs)) + s.Equal(common.Version.String(), string(resp.Kvs[0].Value)) + }) + + session.Stop() + + common.Version = semver.MustParse("2.6.4") + session = NewSessionWithEtcd(ctx, s.metaRoot, s.client) + session.Init(typeutil.MixCoordRole, "normal", false, false) + session.Register() + + resp, err = s.client.Get(ctx, session.versionKey) + s.Require().NoError(err) + s.Equal(1, len(resp.Kvs)) + s.Equal(common.Version.String(), string(resp.Kvs[0].Value)) + + session.Stop() + + common.Version = semver.MustParse("2.6.7") + session = NewSessionWithEtcd(ctx, s.metaRoot, s.client) + session.Init(typeutil.MixCoordRole, "normal", false, false) + session.Register() + + resp, err = s.client.Get(ctx, session.versionKey) + s.Require().NoError(err) + s.Equal(1, len(resp.Kvs)) + s.Equal(common.Version.String(), string(resp.Kvs[0].Value)) + session.Stop() + + common.Version = semver.MustParse("3.0.0") + session = NewSessionWithEtcd(ctx, s.metaRoot, s.client) + session.Init(typeutil.MixCoordRole, "normal", false, false) + session.Register() + + resp, err = s.client.Get(ctx, session.versionKey) + s.Require().NoError(err) + s.Equal(1, len(resp.Kvs)) + s.Equal(common.Version.String(), string(resp.Kvs[0].Value)) +} + func (s *SessionSuite) TestSessionLifetime() { ctx := context.Background() session := NewSessionWithEtcd(ctx, s.metaRoot, s.client) diff --git a/tests/go_client/testcases/insert_test.go b/tests/go_client/testcases/insert_test.go index 0afb093437..fe2ae02c34 100644 --- a/tests/go_client/testcases/insert_test.go +++ b/tests/go_client/testcases/insert_test.go @@ -911,9 +911,12 @@ func TestFlushRate(t *testing.T) { } wg.Wait() + errCnt := 0 for _, err := range errs { if err != nil { common.CheckErr(t, err, false, "request is rejected by grpc RateLimiter middleware, please retry later: rate limit 
exceeded") + errCnt++ } } + require.NotZero(t, errCnt) }