diff --git a/internal/datacoord/cluster.go b/internal/datacoord/cluster.go index 554f197992..8c632feadc 100644 --- a/internal/datacoord/cluster.go +++ b/internal/datacoord/cluster.go @@ -71,6 +71,22 @@ func NewClusterImpl(sessionManager session.DataNodeManager, channelManager Chann // Startup inits the cluster with the given data nodes. func (c *ClusterImpl) Startup(ctx context.Context, nodes []*session.NodeInfo) error { + oldNodes := c.sessionManager.GetSessions() + newNodesMap := lo.SliceToMap(nodes, func(info *session.NodeInfo) (int64, *session.NodeInfo) { + return info.NodeID, info + }) + + // clean offline nodes + for _, node := range oldNodes { + if _, ok := newNodesMap[node.NodeID()]; !ok { + c.sessionManager.DeleteSession(&session.NodeInfo{ + NodeID: node.NodeID(), + Address: node.Address(), + }) + } + } + + // add new nodes for _, node := range nodes { c.sessionManager.AddSession(node) } diff --git a/internal/datacoord/cluster_test.go b/internal/datacoord/cluster_test.go index a91e5d21c8..8b04c41308 100644 --- a/internal/datacoord/cluster_test.go +++ b/internal/datacoord/cluster_test.go @@ -20,8 +20,9 @@ import ( "context" "testing" + "github.com/bytedance/mockey" "github.com/cockroachdb/errors" - "github.com/samber/lo" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" @@ -62,23 +63,67 @@ func (suite *ClusterSuite) SetupTest() { func (suite *ClusterSuite) TearDownTest() {} -func (suite *ClusterSuite) TestStartup() { +func TestClusterImpl_Startup_NewNodes(t *testing.T) { nodes := []*session.NodeInfo{ {NodeID: 1, Address: "addr1"}, {NodeID: 2, Address: "addr2"}, {NodeID: 3, Address: "addr3"}, {NodeID: 4, Address: "addr4"}, } - suite.mockSession.EXPECT().AddSession(mock.Anything).Return().Times(len(nodes)) - suite.mockChManager.EXPECT().Startup(mock.Anything, mock.Anything, mock.Anything). - RunAndReturn(func(ctx context.Context, legacys []int64, nodeIDs []int64) error { - suite.ElementsMatch(lo.Map(nodes, func(info *session.NodeInfo, _ int) int64 { return info.NodeID }), nodeIDs) - return nil - }).Once() - cluster := NewClusterImpl(suite.mockSession, suite.mockChManager) + // Mock the static functions called by ClusterImpl.Startup + mockGetSessions := mockey.Mock((*session.DataNodeManagerImpl).GetSessions).Return([]*session.Session{}).Build() + defer mockGetSessions.UnPatch() + + newAddedNodes := make([]int64, 0, len(nodes)) + mockAddSession := mockey.Mock((*session.DataNodeManagerImpl).AddSession).To(func(node *session.NodeInfo) { + newAddedNodes = append(newAddedNodes, node.NodeID) + }).Build() + defer mockAddSession.UnPatch() + + mockChannelStartup := mockey.Mock((*ChannelManagerImpl).Startup).Return(nil).Build() + defer mockChannelStartup.UnPatch() + + cluster := NewClusterImpl(&session.DataNodeManagerImpl{}, &ChannelManagerImpl{}) + err := cluster.Startup(context.Background(), nodes) - suite.NoError(err) + assert.NoError(t, err) + assert.ElementsMatch(t, newAddedNodes, []int64{1, 2, 3, 4}) +} + +func TestClusterImpl_Startup_RemoveOldNodes(t *testing.T) { + // Create real session objects for testing + existingSession1 := session.NewSession(&session.NodeInfo{NodeID: 1, Address: "old-addr1"}, nil) + existingSession2 := session.NewSession(&session.NodeInfo{NodeID: 2, Address: "addr2"}, nil) + existingSessions := []*session.Session{existingSession1, existingSession2} + + // New nodes to be added + newNodes := []*session.NodeInfo{ + {NodeID: 2, Address: "addr2"}, // existing node (should not be removed) + {NodeID: 3, Address: "addr3"}, // new node + } + + // Mock expectations + mockGetSessions := mockey.Mock((*session.DataNodeManagerImpl).GetSessions).Return(existingSessions).Build() + defer mockGetSessions.UnPatch() + + removeNodes := make([]int64, 0, len(existingSessions)) + mockDeleteSession := mockey.Mock((*session.DataNodeManagerImpl).DeleteSession).To(func(node *session.NodeInfo) { + removeNodes = append(removeNodes, node.NodeID) + }).Build() + defer mockDeleteSession.UnPatch() + + mockAddSession := mockey.Mock((*session.DataNodeManagerImpl).AddSession).Return().Build() + defer mockAddSession.UnPatch() + + mockChannelStartup := mockey.Mock((*ChannelManagerImpl).Startup).Return(nil).Build() + defer mockChannelStartup.UnPatch() + + cluster := NewClusterImpl(&session.DataNodeManagerImpl{}, &ChannelManagerImpl{}) + + err := cluster.Startup(context.Background(), newNodes) + assert.NoError(t, err) + assert.ElementsMatch(t, removeNodes, []int64{1}) } func (suite *ClusterSuite) TestRegister() { diff --git a/internal/datacoord/index_engine_version_manager.go b/internal/datacoord/index_engine_version_manager.go index bfcd378574..57bb89471c 100644 --- a/internal/datacoord/index_engine_version_manager.go +++ b/internal/datacoord/index_engine_version_manager.go @@ -3,6 +3,7 @@ package datacoord import ( "math" + "github.com/samber/lo" "go.uber.org/zap" "github.com/milvus-io/milvus/internal/util/sessionutil" @@ -43,6 +44,18 @@ func (m *versionManagerImpl) Startup(sessions map[string]*sessionutil.Session) { m.mu.Lock() defer m.mu.Unlock() + sessionMap := lo.MapKeys(sessions, func(session *sessionutil.Session, _ string) int64 { + return session.ServerID + }) + + // clean offline nodes + for sessionID := range m.versions { + if _, ok := sessionMap[sessionID]; !ok { + m.removeNodeByID(sessionID) + } + } + + // deal with new online nodes for _, session := range sessions { m.addOrUpdate(session) } @@ -59,9 +72,13 @@ func (m *versionManagerImpl) RemoveNode(session *sessionutil.Session) { m.mu.Lock() defer m.mu.Unlock() - delete(m.versions, session.ServerID) - delete(m.scalarIndexVersions, session.ServerID) - delete(m.indexNonEncoding, session.ServerID) + m.removeNodeByID(session.ServerID) +} + +func (m *versionManagerImpl) removeNodeByID(sessionID int64) { + delete(m.versions, sessionID) + delete(m.scalarIndexVersions, sessionID) + delete(m.indexNonEncoding, sessionID) } func (m *versionManagerImpl) Update(session *sessionutil.Session) { diff --git a/internal/datacoord/index_engine_version_manager_test.go b/internal/datacoord/index_engine_version_manager_test.go index 82c290a045..9b67c095ec 100644 --- a/internal/datacoord/index_engine_version_manager_test.go +++ b/internal/datacoord/index_engine_version_manager_test.go @@ -155,3 +155,164 @@ func Test_IndexEngineVersionManager_GetIndexNoneEncoding(t *testing.T) { // after removing server1, then global none encoding should be true assert.True(t, m.GetIndexNonEncoding()) } + +func Test_IndexEngineVersionManager_StartupWithOfflineNodeCleanup(t *testing.T) { + m := newIndexEngineVersionManager() + + // First startup with initial nodes + m.Startup(map[string]*sessionutil.Session{ + "1": { + SessionRaw: sessionutil.SessionRaw{ + ServerID: 1, + IndexEngineVersion: sessionutil.IndexEngineVersion{CurrentIndexVersion: 20, MinimalIndexVersion: 10}, + }, + }, + "2": { + SessionRaw: sessionutil.SessionRaw{ + ServerID: 2, + IndexEngineVersion: sessionutil.IndexEngineVersion{CurrentIndexVersion: 15, MinimalIndexVersion: 5}, + }, + }, + }) + + // Verify both nodes are present + assert.Equal(t, int32(15), m.GetCurrentIndexEngineVersion()) // min of 20 and 15 + assert.Equal(t, int32(10), m.GetMinimalIndexEngineVersion()) // max of 10 and 5 + + // Second startup with only one node online (node 2 is offline) + m.Startup(map[string]*sessionutil.Session{ + "1": { + SessionRaw: sessionutil.SessionRaw{ + ServerID: 1, + IndexEngineVersion: sessionutil.IndexEngineVersion{CurrentIndexVersion: 25, MinimalIndexVersion: 12}, + }, + }, + }) + + // Verify offline node 2 is cleaned up and only node 1 remains + assert.Equal(t, int32(25), m.GetCurrentIndexEngineVersion()) + assert.Equal(t, int32(12), m.GetMinimalIndexEngineVersion()) + + // Verify that node 2's data is actually removed from internal maps + vm := m.(*versionManagerImpl) + _, exists := vm.versions[2] + assert.False(t, exists, "offline node should be removed from versions map") + _, exists = vm.scalarIndexVersions[2] + assert.False(t, exists, "offline node should be removed from scalarIndexVersions map") + _, exists = vm.indexNonEncoding[2] + assert.False(t, exists, "offline node should be removed from indexNonEncoding map") +} + +func Test_IndexEngineVersionManager_StartupWithNewAndOfflineNodes(t *testing.T) { + m := newIndexEngineVersionManager() + + // First startup + m.Startup(map[string]*sessionutil.Session{ + "1": { + SessionRaw: sessionutil.SessionRaw{ + ServerID: 1, + IndexEngineVersion: sessionutil.IndexEngineVersion{CurrentIndexVersion: 20, MinimalIndexVersion: 10}, + }, + }, + "2": { + SessionRaw: sessionutil.SessionRaw{ + ServerID: 2, + IndexEngineVersion: sessionutil.IndexEngineVersion{CurrentIndexVersion: 15, MinimalIndexVersion: 5}, + }, + }, + }) + + // Second startup: node 2 offline, node 3 comes online, node 1 still online + m.Startup(map[string]*sessionutil.Session{ + "1": { + SessionRaw: sessionutil.SessionRaw{ + ServerID: 1, + IndexEngineVersion: sessionutil.IndexEngineVersion{CurrentIndexVersion: 22, MinimalIndexVersion: 11}, + }, + }, + "3": { + SessionRaw: sessionutil.SessionRaw{ + ServerID: 3, + IndexEngineVersion: sessionutil.IndexEngineVersion{CurrentIndexVersion: 18, MinimalIndexVersion: 8}, + }, + }, + }) + + // Verify node 2 is cleaned up and node 3 is added + assert.Equal(t, int32(18), m.GetCurrentIndexEngineVersion()) // min of 22 and 18 + assert.Equal(t, int32(11), m.GetMinimalIndexEngineVersion()) // max of 11 and 8 + + vm := m.(*versionManagerImpl) + // Node 1 should still exist + _, exists := vm.versions[1] + assert.True(t, exists, "online node 1 should remain") + // Node 2 should be removed + _, exists = vm.versions[2] + assert.False(t, exists, "offline node 2 should be removed") + // Node 3 should be added + _, exists = vm.versions[3] + assert.True(t, exists, "new online node 3 should be added") +} + +func Test_IndexEngineVersionManager_StartupWithEmptySession(t *testing.T) { + m := newIndexEngineVersionManager() + + // First startup with nodes + m.Startup(map[string]*sessionutil.Session{ + "1": { + SessionRaw: sessionutil.SessionRaw{ + ServerID: 1, + IndexEngineVersion: sessionutil.IndexEngineVersion{CurrentIndexVersion: 20, MinimalIndexVersion: 10}, + }, + }, + }) + + assert.Equal(t, int32(20), m.GetCurrentIndexEngineVersion()) + + // Second startup with no nodes (all offline) + m.Startup(map[string]*sessionutil.Session{}) + + // Should return default values when no nodes are online + assert.Equal(t, int32(0), m.GetCurrentIndexEngineVersion()) + assert.Equal(t, int32(0), m.GetMinimalIndexEngineVersion()) + + vm := m.(*versionManagerImpl) + assert.Empty(t, vm.versions, "all nodes should be cleaned up") + assert.Empty(t, vm.scalarIndexVersions, "all nodes should be cleaned up") + assert.Empty(t, vm.indexNonEncoding, "all nodes should be cleaned up") +} + +func Test_IndexEngineVersionManager_removeNodeByID(t *testing.T) { + m := newIndexEngineVersionManager() + + // Add some nodes first + m.AddNode(&sessionutil.Session{ + SessionRaw: sessionutil.SessionRaw{ + ServerID: 1, + IndexEngineVersion: sessionutil.IndexEngineVersion{CurrentIndexVersion: 20, MinimalIndexVersion: 10}, + ScalarIndexEngineVersion: sessionutil.IndexEngineVersion{CurrentIndexVersion: 15, MinimalIndexVersion: 5}, + IndexNonEncoding: true, + }, + }) + + vm := m.(*versionManagerImpl) + + // Verify node is added + _, exists := vm.versions[1] + assert.True(t, exists) + _, exists = vm.scalarIndexVersions[1] + assert.True(t, exists) + _, exists = vm.indexNonEncoding[1] + assert.True(t, exists) + + // Remove node by ID + vm.removeNodeByID(1) + + // Verify node is completely removed + _, exists = vm.versions[1] + assert.False(t, exists, "node should be removed from versions map") + _, exists = vm.scalarIndexVersions[1] + assert.False(t, exists, "node should be removed from scalarIndexVersions map") + _, exists = vm.indexNonEncoding[1] + assert.False(t, exists, "node should be removed from indexNonEncoding map") +} diff --git a/internal/datacoord/server.go b/internal/datacoord/server.go index 0966f0b8ed..1093525b42 100644 --- a/internal/datacoord/server.go +++ b/internal/datacoord/server.go @@ -561,57 +561,24 @@ func (s *Server) initServiceDiscovery() error { return err } log.Info("DataCoord success to get DataNode sessions", zap.Any("sessions", sessions)) - - datanodes := make([]*session.NodeInfo, 0, len(sessions)) - legacyVersion, err := semver.Parse(paramtable.Get().DataCoordCfg.LegacyVersionWithoutRPCWatch.GetValue()) + err = s.rewatchDataNodes(sessions) if err != nil { - log.Warn("DataCoord failed to init service discovery", zap.Error(err)) - } - - for _, s := range sessions { - info := &session.NodeInfo{ - NodeID: s.ServerID, - Address: s.Address, - } - - if s.Version.LTE(legacyVersion) { - info.IsLegacy = true - } - - datanodes = append(datanodes, info) - } - - log.Info("DataCoord Cluster Manager start up") - if err := s.cluster.Startup(s.ctx, datanodes); err != nil { - log.Warn("DataCoord Cluster Manager failed to start up", zap.Error(err)) + log.Warn("DataCoord failed to rewatch datanode", zap.Error(err)) return err } - log.Info("DataCoord Cluster Manager start up successfully") - - // TODO implement rewatch logic - s.dnEventCh = s.session.WatchServicesWithVersionRange(typeutil.DataNodeRole, r, rev+1, nil) + s.dnEventCh = s.session.WatchServicesWithVersionRange(typeutil.DataNodeRole, r, rev+1, s.rewatchDataNodes) inSessions, inRevision, err := s.session.GetSessions(typeutil.IndexNodeRole) if err != nil { log.Warn("DataCoord get QueryCoord session failed", zap.Error(err)) return err } - if Params.DataCoordCfg.BindIndexNodeMode.GetAsBool() { - if err = s.indexNodeManager.AddNode(Params.DataCoordCfg.IndexNodeID.GetAsInt64(), Params.DataCoordCfg.IndexNodeAddress.GetValue()); err != nil { - log.Error("add indexNode fail", zap.Int64("ServerID", Params.DataCoordCfg.IndexNodeID.GetAsInt64()), - zap.String("address", Params.DataCoordCfg.IndexNodeAddress.GetValue()), zap.Error(err)) - return err - } - log.Info("add indexNode success", zap.String("IndexNode address", Params.DataCoordCfg.IndexNodeAddress.GetValue()), - zap.Int64("nodeID", Params.DataCoordCfg.IndexNodeID.GetAsInt64())) - } else { - for _, session := range inSessions { - if err := s.indexNodeManager.AddNode(session.ServerID, session.Address); err != nil { - return err - } - } + err = s.rewatchIndexNodes(inSessions) + if err != nil { + log.Warn("DataCoord failed to rewatch indexnode", zap.Error(err)) + return err } - s.inEventCh = s.session.WatchServices(typeutil.IndexNodeRole, inRevision+1, nil) + s.inEventCh = s.session.WatchServices(typeutil.IndexNodeRole, inRevision+1, s.rewatchIndexNodes) s.indexEngineVersionManager = newIndexEngineVersionManager() qnSessions, qnRevision, err := s.session.GetSessions(typeutil.QueryNodeRole) @@ -619,12 +586,83 @@ func (s *Server) initServiceDiscovery() error { log.Warn("DataCoord get QueryNode sessions failed", zap.Error(err)) return err } - s.indexEngineVersionManager.Startup(qnSessions) - s.qnEventCh = s.session.WatchServicesWithVersionRange(typeutil.QueryNodeRole, r, qnRevision+1, nil) + s.rewatchQueryNodes(qnSessions) + s.qnEventCh = s.session.WatchServicesWithVersionRange(typeutil.QueryNodeRole, r, qnRevision+1, s.rewatchQueryNodes) return nil } +// rewatchQueryNodes is used to rewatch query nodes when datacoord is started or reconnected to etcd +// Note: may apply same node multiple times, so rewatchQueryNodes must be idempotent +func (s *Server) rewatchQueryNodes(sessions map[string]*sessionutil.Session) error { + s.indexEngineVersionManager.Startup(sessions) + return nil +} + +func (s *Server) rewatchIndexNodes(sessions map[string]*sessionutil.Session) error { + if Params.DataCoordCfg.BindIndexNodeMode.GetAsBool() { + nodes := make([]*session.NodeInfo, 0, 1) + nodes = append(nodes, &session.NodeInfo{ + NodeID: Params.DataCoordCfg.IndexNodeID.GetAsInt64(), + Address: Params.DataCoordCfg.IndexNodeAddress.GetValue(), + }) + if err := s.indexNodeManager.Startup(nodes); err != nil { + log.Error("add indexNode fail", zap.Int64("ServerID", Params.DataCoordCfg.IndexNodeID.GetAsInt64()), + zap.String("address", Params.DataCoordCfg.IndexNodeAddress.GetValue()), zap.Error(err)) + return err + } + log.Info("add indexNode success", zap.String("IndexNode address", Params.DataCoordCfg.IndexNodeAddress.GetValue()), + zap.Int64("nodeID", Params.DataCoordCfg.IndexNodeID.GetAsInt64())) + } else { + nodes := make([]*session.NodeInfo, 0, len(sessions)) + for _, s := range sessions { + nodes = append(nodes, &session.NodeInfo{ + NodeID: s.ServerID, + Address: s.Address, + }) + } + s.indexNodeManager.Startup(nodes) + } + return nil +} + +// rewatchDataNodes is used to rewatch data nodes when datacoord is started or reconnected to etcd +// Note: may apply same node multiple times, so rewatchDataNodes must be idempotent +func (s *Server) rewatchDataNodes(sessions map[string]*sessionutil.Session) error { + legacyVersion, err := semver.Parse(paramtable.Get().DataCoordCfg.LegacyVersionWithoutRPCWatch.GetValue()) + if err != nil { + log.Warn("DataCoord failed to init service discovery", zap.Error(err)) + return err + } + + datanodes := make([]*session.NodeInfo, 0, len(sessions)) + for _, ss := range sessions { + info := &session.NodeInfo{ + NodeID: ss.ServerID, + Address: ss.Address, + } + + if ss.Version.LTE(legacyVersion) { + info.IsLegacy = true + } + + datanodes = append(datanodes, info) + } + + // if err := s.nodeManager.Startup(s.ctx, datanodes); err != nil { + // log.Warn("DataCoord failed to add datanode", zap.Error(err)) + // return err + // } + + log.Info("DataCoord Cluster Manager start up") + if err := s.cluster.Startup(s.ctx, datanodes); err != nil { + log.Warn("DataCoord Cluster Manager failed to start up", zap.Error(err)) + return err + } + log.Info("DataCoord Cluster Manager start up successfully") + return nil +} + func (s *Server) initSegmentManager() error { if s.segmentManager == nil { manager, err := newSegmentManager(s.meta, s.allocator) diff --git a/internal/datacoord/server_test.go b/internal/datacoord/server_test.go index 3ac40b68e3..9f5a8ec497 100644 --- a/internal/datacoord/server_test.go +++ b/internal/datacoord/server_test.go @@ -24,12 +24,13 @@ import ( "os/signal" "path" "strconv" - "strings" "sync" "syscall" "testing" "time" + "github.com/blang/semver/v4" + "github.com/bytedance/mockey" "github.com/cockroachdb/errors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" @@ -68,18 +69,7 @@ import ( ) func TestMain(m *testing.M) { - // init embed etcd - embedetcdServer, tempDir, err := etcd.StartTestEmbedEtcdServer() - if err != nil { - log.Fatal("failed to start embed etcd server", zap.Error(err)) - } - defer os.RemoveAll(tempDir) - defer embedetcdServer.Close() - - addrs := etcd.GetEmbedEtcdEndpoints(embedetcdServer) - paramtable.Init() - paramtable.Get().Save(Params.EtcdCfg.Endpoints.Key, strings.Join(addrs, ",")) rand.Seed(time.Now().UnixNano()) parameters := []string{"tikv", "etcd"} @@ -2434,6 +2424,141 @@ func closeTestServer(t *testing.T, svr *Server) { paramtable.Get().Reset(Params.CommonCfg.DataCoordTimeTick.Key) } +func TestServer_rewatchQueryNodes(t *testing.T) { + server := &Server{ + indexEngineVersionManager: newIndexEngineVersionManager(), + } + + // Test with empty sessions + err := server.rewatchQueryNodes(map[string]*sessionutil.Session{}) + assert.NoError(t, err) + + // Test with valid sessions + sessions := map[string]*sessionutil.Session{ + "session1": { + SessionRaw: sessionutil.SessionRaw{ + ServerID: 1, + IndexEngineVersion: sessionutil.IndexEngineVersion{CurrentIndexVersion: 20, MinimalIndexVersion: 10}, + }, + }, + "session2": { + SessionRaw: sessionutil.SessionRaw{ + ServerID: 2, + IndexEngineVersion: sessionutil.IndexEngineVersion{CurrentIndexVersion: 15, MinimalIndexVersion: 5}, + }, + }, + } + + err = server.rewatchQueryNodes(sessions) + assert.NoError(t, err) + + // Verify the IndexEngineVersionManager received the sessions + assert.Equal(t, int32(15), server.indexEngineVersionManager.GetCurrentIndexEngineVersion()) + assert.Equal(t, int32(10), server.indexEngineVersionManager.GetMinimalIndexEngineVersion()) + + // Test idempotent behavior - calling again with same sessions should not cause issues + err = server.rewatchQueryNodes(sessions) + assert.NoError(t, err) + + // Verify values remain the same + assert.Equal(t, int32(15), server.indexEngineVersionManager.GetCurrentIndexEngineVersion()) + assert.Equal(t, int32(10), server.indexEngineVersionManager.GetMinimalIndexEngineVersion()) +} + +func TestServer_rewatchDataNodes_Success(t *testing.T) { + // Mock semver.Parse to avoid dependency on paramtable + mockSemverParse := mockey.Mock(semver.Parse).Return(semver.Version{}, nil).Build() + defer mockSemverParse.UnPatch() + + sessions := map[string]*sessionutil.Session{ + "session1": { + SessionRaw: sessionutil.SessionRaw{ + ServerID: 1, + Address: "localhost:9001", + Version: "2.3.0", + }, + }, + "session2": { + SessionRaw: sessionutil.SessionRaw{ + ServerID: 2, + Address: "localhost:9002", + Version: "2.2.0", // legacy version + }, + }, + } + + server := &Server{ + ctx: context.Background(), + } + + // Create actual implementations + cluster := NewClusterImpl(nil, nil) + + server.cluster = cluster + + // Mock Cluster.Startup to succeed + mockClusterStartup := mockey.Mock((*ClusterImpl).Startup).Return(nil).Build() + defer mockClusterStartup.UnPatch() + + err := server.rewatchDataNodes(sessions) + assert.NoError(t, err) +} + +func TestServer_rewatchDataNodes_EmptySession(t *testing.T) { + // Mock semver.Parse to avoid dependency on paramtable + mockSemverParse := mockey.Mock(semver.Parse).Return(semver.Version{}, nil).Build() + defer mockSemverParse.UnPatch() + + server := &Server{ + ctx: context.Background(), + } + + // Create actual implementations + cluster := NewClusterImpl(nil, nil) + + server.cluster = cluster + + // Mock Cluster.Startup for empty nodes + mockStartup := mockey.Mock((*ClusterImpl).Startup).Return(nil).Build() + defer mockStartup.UnPatch() + + err := server.rewatchDataNodes(map[string]*sessionutil.Session{}) + assert.NoError(t, err) +} + +func TestServer_rewatchDataNodes_ClusterStartupFails(t *testing.T) { + // Mock semver.Parse to avoid dependency on paramtable + mockSemverParse := mockey.Mock(semver.Parse).Return(semver.Version{}, nil).Build() + defer mockSemverParse.UnPatch() + + sessions := map[string]*sessionutil.Session{ + "session1": { + SessionRaw: sessionutil.SessionRaw{ + ServerID: 1, + Address: "localhost:9001", + Version: "2.3.0", + }, + }, + } + + server := &Server{ + ctx: context.Background(), + } + + // Create actual implementations + cluster := NewClusterImpl(nil, nil) + + server.cluster = cluster + + // Mock Cluster.Startup to fail + mockStartup := mockey.Mock((*ClusterImpl).Startup).Return(errors.New("cluster startup failed")).Build() + defer mockStartup.UnPatch() + + err := server.rewatchDataNodes(sessions) + assert.Error(t, err) + assert.Contains(t, err.Error(), "cluster startup failed") +} + func Test_CheckHealth(t *testing.T) { getSessionManager := func(isHealthy bool) *session.DataNodeManagerImpl { var client *mockDataNodeClient diff --git a/internal/datacoord/session/indexnode_manager.go b/internal/datacoord/session/indexnode_manager.go index cd6bb3fe00..865b6f1a78 100644 --- a/internal/datacoord/session/indexnode_manager.go +++ b/internal/datacoord/session/indexnode_manager.go @@ -32,6 +32,7 @@ import ( "github.com/milvus-io/milvus/pkg/v2/util/lock" "github.com/milvus-io/milvus/pkg/v2/util/paramtable" typeutil "github.com/milvus-io/milvus/pkg/v2/util/typeutil" + "github.com/samber/lo" ) func defaultIndexNodeCreatorFunc(ctx context.Context, addr string, nodeID int64) (types.IndexNodeClient, error) { @@ -46,6 +47,8 @@ type WorkerManager interface { QuerySlots() map[typeutil.UniqueID]*WorkerSlots GetAllClients() map[typeutil.UniqueID]types.IndexNodeClient GetClientByID(nodeID typeutil.UniqueID) (types.IndexNodeClient, bool) + + Startup(nodes []*NodeInfo) error } type WorkerSlots struct { @@ -265,3 +268,25 @@ func (nm *IndexNodeManager) getMetrics(ctx context.Context, req *milvuspb.GetMet } return ret } + +func (nm *IndexNodeManager) Startup(nodes []*NodeInfo) error { + // remove node which not exist in sessions + sessionMap := lo.SliceToMap(nodes, func(node *NodeInfo) (int64, *NodeInfo) { + return node.NodeID, node + }) + + // remove old nodes + for nodeID := range nm.nodeClients { + if _, ok := sessionMap[nodeID]; !ok { + nm.RemoveNode(nodeID) + } + } + + // add new nodes + for _, node := range nodes { + if err := nm.AddNode(node.NodeID, node.Address); err != nil { + return err + } + } + return nil +} diff --git a/internal/datacoord/session/indexnode_manager_test.go b/internal/datacoord/session/indexnode_manager_test.go index 9d9152632b..09a93060ef 100644 --- a/internal/datacoord/session/indexnode_manager_test.go +++ b/internal/datacoord/session/indexnode_manager_test.go @@ -24,6 +24,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + indexnodeclient "github.com/milvus-io/milvus/internal/distributed/indexnode/client" "github.com/milvus-io/milvus/internal/mocks" "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/pkg/v2/proto/workerpb" @@ -114,3 +115,118 @@ func TestNodeManager_StoppingNode(t *testing.T) { assert.Equal(t, 0, len(nm.GetAllClients())) assert.Equal(t, 0, len(nm.stoppingNodes)) } + +func TestNodeManager_Startup_NewNodes(t *testing.T) { + nodeCreator := func(ctx context.Context, addr string, nodeID int64) (types.IndexNodeClient, error) { + return indexnodeclient.NewClient(ctx, addr, nodeID, paramtable.Get().DataCoordCfg.WithCredential.GetAsBool()) + } + + ctx := context.Background() + nm := NewNodeManager(ctx, nodeCreator) + + // Define test nodes + nodes := []*NodeInfo{ + {NodeID: 1, Address: "localhost:8080"}, + {NodeID: 2, Address: "localhost:8081"}, + } + + err := nm.Startup(nodes) + assert.NoError(t, err) + + // Verify nodes were added + ids := nm.GetAllClients() + assert.Len(t, ids, 2) + assert.Contains(t, ids, int64(1)) + assert.Contains(t, ids, int64(2)) + + // Verify clients are accessible + _, ok := nm.GetClientByID(1) + assert.True(t, ok) + + _, ok = nm.GetClientByID(2) + assert.True(t, ok) +} + +func TestNodeManager_Startup_RemoveOldNodes(t *testing.T) { + nodeCreator := func(ctx context.Context, addr string, nodeID int64) (types.IndexNodeClient, error) { + return indexnodeclient.NewClient(ctx, addr, nodeID, paramtable.Get().DataCoordCfg.WithCredential.GetAsBool()) + } + + ctx := context.Background() + nm := NewNodeManager(ctx, nodeCreator) + + // Add initial nodes + err := nm.AddNode(1, "localhost:8080") + assert.NoError(t, err) + err = nm.AddNode(2, "localhost:8081") + assert.NoError(t, err) + + // Startup with new set of nodes (removes node 1, keeps node 2, adds node 3) + newNodes := []*NodeInfo{ + {NodeID: 2, Address: "localhost:8081"}, // existing node + {NodeID: 3, Address: "localhost:8082"}, // new node + } + + err = nm.Startup(newNodes) + assert.NoError(t, err) + + // Verify final state + ids := nm.GetAllClients() + assert.Len(t, ids, 2) + assert.Contains(t, ids, int64(2)) + assert.Contains(t, ids, int64(3)) + assert.NotContains(t, ids, int64(1)) + + // Verify node 1 is removed + _, ok := nm.GetClientByID(1) + assert.False(t, ok) + + // Verify nodes 2 and 3 are accessible + _, ok = nm.GetClientByID(2) + assert.True(t, ok) + + _, ok = nm.GetClientByID(3) + assert.True(t, ok) +} + +func TestNodeManager_Startup_EmptyNodes(t *testing.T) { + nodeCreator := func(ctx context.Context, addr string, nodeID int64) (types.IndexNodeClient, error) { + return indexnodeclient.NewClient(ctx, addr, nodeID, paramtable.Get().DataCoordCfg.WithCredential.GetAsBool()) + } + + ctx := context.Background() + nm := NewNodeManager(ctx, nodeCreator) + + // Add initial node + err := nm.AddNode(1, "localhost:8080") + assert.NoError(t, err) + + // Startup with empty nodes (should remove all existing nodes) + err = nm.Startup(nil) + assert.NoError(t, err) + + // Verify all nodes are removed + ids := nm.GetAllClients() + assert.Empty(t, ids) +} + +func TestNodeManager_Startup_AddNodeError(t *testing.T) { + nodeCreator := func(ctx context.Context, addr string, nodeID int64) (types.IndexNodeClient, error) { + if nodeID == 1 { + return nil, assert.AnError + } + return indexnodeclient.NewClient(ctx, addr, nodeID, paramtable.Get().DataCoordCfg.WithCredential.GetAsBool()) + } + + ctx := context.Background() + nm := NewNodeManager(ctx, nodeCreator) + + nodes := []*NodeInfo{ + {NodeID: 1, Address: "localhost:8080"}, // This will fail + {NodeID: 2, Address: "localhost:8081"}, + } + + err := nm.Startup(nodes) + assert.Error(t, err) + assert.Contains(t, err.Error(), "assert.AnError") +} diff --git a/internal/datacoord/session/mock_worker_manager.go b/internal/datacoord/session/mock_worker_manager.go index 4190219e92..b003a69d0a 100644 --- a/internal/datacoord/session/mock_worker_manager.go +++ b/internal/datacoord/session/mock_worker_manager.go @@ -309,6 +309,52 @@ func (_c *MockWorkerManager_RemoveNode_Call) RunAndReturn(run func(int64)) *Mock return _c } +// Startup provides a mock function with given fields: nodes +func (_m *MockWorkerManager) Startup(nodes []*NodeInfo) error { + ret := _m.Called(nodes) + + if len(ret) == 0 { + panic("no return value specified for Startup") + } + + var r0 error + if rf, ok := ret.Get(0).(func([]*NodeInfo) error); ok { + r0 = rf(nodes) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockWorkerManager_Startup_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Startup' +type MockWorkerManager_Startup_Call struct { + *mock.Call +} + +// Startup is a helper method to define mock.On call +// - nodes []*NodeInfo +func (_e *MockWorkerManager_Expecter) Startup(nodes interface{}) *MockWorkerManager_Startup_Call { + return &MockWorkerManager_Startup_Call{Call: _e.mock.On("Startup", nodes)} +} + +func (_c *MockWorkerManager_Startup_Call) Run(run func(nodes []*NodeInfo)) *MockWorkerManager_Startup_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].([]*NodeInfo)) + }) + return _c +} + +func (_c *MockWorkerManager_Startup_Call) Return(_a0 error) *MockWorkerManager_Startup_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockWorkerManager_Startup_Call) RunAndReturn(run func([]*NodeInfo) error) *MockWorkerManager_Startup_Call { + _c.Call.Return(run) + return _c +} + // StoppingNode provides a mock function with given fields: nodeID func (_m *MockWorkerManager) StoppingNode(nodeID int64) { _m.Called(nodeID) diff --git a/internal/querycoordv2/dist/dist_handler.go b/internal/querycoordv2/dist/dist_handler.go index 95514056d4..0cef153428 100644 --- a/internal/querycoordv2/dist/dist_handler.go +++ b/internal/querycoordv2/dist/dist_handler.go @@ -33,6 +33,7 @@ import ( "github.com/milvus-io/milvus/internal/querycoordv2/task" "github.com/milvus-io/milvus/internal/querycoordv2/utils" "github.com/milvus-io/milvus/pkg/v2/log" + "github.com/milvus-io/milvus/pkg/v2/metrics" "github.com/milvus-io/milvus/pkg/v2/proto/datapb" "github.com/milvus-io/milvus/pkg/v2/proto/querypb" "github.com/milvus-io/milvus/pkg/v2/util/commonpbutil" @@ -125,7 +126,9 @@ func (dh *distHandler) handleDistResp(ctx context.Context, resp *querypb.GetData log.Warn("node last heart beat time lag too behind", zap.Time("now", time.Now()), zap.Time("lastHeartBeatTime", node.LastHeartbeat()), zap.Int64("nodeID", node.ID())) } - node.SetLastHeartbeat(time.Now()) + now := time.Now() + node.SetLastHeartbeat(now) + metrics.QueryCoordLastHeartbeatTimeStamp.WithLabelValues(fmt.Sprint(resp.GetNodeID())).Set(float64(now.UnixNano())) // skip update dist if no distribution change happens in query node if resp.GetLastModifyTs() != 0 && resp.GetLastModifyTs() <= dh.lastUpdateTs { diff --git a/internal/querycoordv2/dist/dist_handler_test.go b/internal/querycoordv2/dist/dist_handler_test.go index c34da497ff..1966ec8ab1 100644 --- a/internal/querycoordv2/dist/dist_handler_test.go +++ b/internal/querycoordv2/dist/dist_handler_test.go @@ -18,16 +18,21 @@ package dist import ( "context" + "fmt" "testing" "time" + "github.com/bytedance/mockey" "github.com/cockroachdb/errors" + "github.com/prometheus/client_golang/prometheus" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/suite" "github.com/milvus-io/milvus/internal/querycoordv2/meta" "github.com/milvus-io/milvus/internal/querycoordv2/session" "github.com/milvus-io/milvus/internal/querycoordv2/task" + "github.com/milvus-io/milvus/pkg/v2/metrics" "github.com/milvus-io/milvus/pkg/v2/proto/datapb" "github.com/milvus-io/milvus/pkg/v2/proto/querypb" "github.com/milvus-io/milvus/pkg/v2/util/merr" @@ -191,6 +196,82 @@ func (suite *DistHandlerSuite) TestForcePullDist() { time.Sleep(300 * time.Millisecond) } +// TestHeartbeatMetricsRecording tests that heartbeat metrics are properly recorded +func TestHeartbeatMetricsRecording(t *testing.T) { + // Arrange: Create test response with a unique nodeID to avoid test interference + nodeID := time.Now().UnixNano() % 1000000 // Use timestamp-based unique ID + resp := &querypb.GetDataDistributionResponse{ + Status: merr.Success(), + NodeID: nodeID, + LastModifyTs: 1, + } + + // Create mock node + nodeManager := session.NewNodeManager() + nodeInfo := session.NewNodeInfo(session.ImmutableNodeInfo{ + NodeID: nodeID, + Address: "localhost:19530", + Hostname: "localhost", + }) + nodeManager.Add(nodeInfo) + + // Mock time.Now() to get predictable timestamp + expectedTimestamp := time.Unix(1640995200, 0) // 2022-01-01 00:00:00 UTC + mockTimeNow := mockey.Mock(time.Now).Return(expectedTimestamp).Build() + defer mockTimeNow.UnPatch() + + // Record the initial state of the metric for our specific nodeID + initialMetricValue := getMetricValueForNode(fmt.Sprint(nodeID)) + + // Create dist handler + ctx := context.Background() + handler := &distHandler{ + nodeID: nodeID, + nodeManager: nodeManager, + dist: meta.NewDistributionManager(), + target: meta.NewTargetManager(nil, nil), + scheduler: task.NewScheduler(ctx, nil, nil, nil, nil, nil, nil), + } + + // Act: Handle distribution response + handler.handleDistResp(ctx, resp, false) + + // Assert: Verify our specific metric was recorded with the expected value + finalMetricValue := getMetricValueForNode(fmt.Sprint(nodeID)) + + // Check that the metric value changed and matches our expected timestamp + assert.NotEqual(t, initialMetricValue, finalMetricValue, "Metric value should have changed") + assert.Equal(t, float64(expectedTimestamp.UnixNano()), finalMetricValue, "Metric should record the expected timestamp") + + // Clean up: Remove the test metric to avoid affecting other tests + metrics.QueryCoordLastHeartbeatTimeStamp.DeleteLabelValues(fmt.Sprint(nodeID)) +} + +// Helper function to get the current metric value for a specific nodeID +func getMetricValueForNode(nodeID string) float64 { + // Create a temporary registry to capture the current state + registry := prometheus.NewRegistry() + registry.MustRegister(metrics.QueryCoordLastHeartbeatTimeStamp) + + metricFamilies, err := registry.Gather() + if err != nil { + return -1 // Return -1 if we can't gather metrics + } + + for _, mf := range metricFamilies { + if mf.GetName() == "milvus_querycoord_last_heartbeat_timestamp" { + for _, metric := range mf.GetMetric() { + for _, label := range metric.GetLabel() { + if label.GetName() == "node_id" && label.GetValue() == nodeID { + return metric.GetGauge().GetValue() + } + } + } + } + } + return 0 // Return 0 if metric not found (default value) +} + func TestDistHandlerSuite(t *testing.T) { suite.Run(t, new(DistHandlerSuite)) } diff --git a/internal/querycoordv2/meta/resource_manager.go b/internal/querycoordv2/meta/resource_manager.go index aaa30d0bd6..268599c258 100644 --- a/internal/querycoordv2/meta/resource_manager.go +++ b/internal/querycoordv2/meta/resource_manager.go @@ -465,6 +465,13 @@ func (rm *ResourceManager) HandleNodeUp(ctx context.Context, node int64) { rm.rwmutex.Lock() defer rm.rwmutex.Unlock() + rm.handleNodeUp(ctx, node) +} + +func (rm *ResourceManager) handleNodeUp(ctx context.Context, node int64) { + if nodeInfo := rm.nodeMgr.Get(node); nodeInfo == nil || nodeInfo.IsEmbeddedQueryNodeInStreamingNode() { + return + } rm.incomingNode.Insert(node) // Trigger assign incoming node right away. // error can be ignored here, because `AssignPendingIncomingNode`` will retry assign node. @@ -480,7 +487,10 @@ func (rm *ResourceManager) HandleNodeUp(ctx context.Context, node int64) { func (rm *ResourceManager) HandleNodeDown(ctx context.Context, node int64) { rm.rwmutex.Lock() defer rm.rwmutex.Unlock() + rm.handleNodeDown(ctx, node) +} +func (rm *ResourceManager) handleNodeDown(ctx context.Context, node int64) { rm.incomingNode.Remove(node) // for stopping query node becomes offline, node change won't be triggered, @@ -500,7 +510,10 @@ func (rm *ResourceManager) HandleNodeDown(ctx context.Context, node int64) { func (rm *ResourceManager) HandleNodeStopping(ctx context.Context, node int64) { rm.rwmutex.Lock() defer rm.rwmutex.Unlock() + rm.handleNodeStopping(ctx, node) +} +func (rm *ResourceManager) handleNodeStopping(ctx context.Context, node int64) { rm.incomingNode.Remove(node) rgName, err := rm.unassignNode(ctx, node) log.Info("HandleNodeStopping: remove node from resource group", @@ -994,3 +1007,33 @@ func (rm *ResourceManager) GetResourceGroupsJSON(ctx context.Context) string { return string(ret) } + +func (rm *ResourceManager) CheckNodesInResourceGroup(ctx context.Context) { + rm.rwmutex.RLock() + defer rm.rwmutex.RUnlock() + + // clean stopping/offline nodes + assignedNodes := typeutil.NewUniqueSet() + for _, rg := range rm.groups { + for _, node := range rg.GetNodes() { + assignedNodes.Insert(node) + info := rm.nodeMgr.Get(node) + if info == nil { + rm.handleNodeDown(ctx, node) + } else if info.GetState() == session.NodeStateStopping { + log.Warn("node is stopping", zap.Int64("node", node)) + rm.handleNodeStopping(ctx, node) + } else if info.IsEmbeddedQueryNodeInStreamingNode() { + log.Warn("unreachable code, but just for dirty meta clean up", zap.Int64("node", node)) + rm.handleNodeStopping(ctx, node) + } + } + } + + // add new nodes + for _, node := range rm.nodeMgr.GetAll() { + if !assignedNodes.Contain(node.ID()) { + rm.handleNodeUp(context.Background(), node.ID()) + } + } +} diff --git a/internal/querycoordv2/meta/resource_manager_test.go b/internal/querycoordv2/meta/resource_manager_test.go index 507cb441dc..131d8ec187 100644 --- a/internal/querycoordv2/meta/resource_manager_test.go +++ b/internal/querycoordv2/meta/resource_manager_test.go @@ -31,6 +31,7 @@ import ( "github.com/milvus-io/milvus/internal/metastore/kv/querycoord" "github.com/milvus-io/milvus/internal/querycoordv2/params" "github.com/milvus-io/milvus/internal/querycoordv2/session" + "github.com/milvus-io/milvus/internal/util/sessionutil" "github.com/milvus-io/milvus/pkg/v2/kv" "github.com/milvus-io/milvus/pkg/v2/log" "github.com/milvus-io/milvus/pkg/v2/util/etcd" @@ -221,6 +222,16 @@ func (suite *ResourceManagerSuite) TestManipulateResourceGroup() { // RemoveResourceGroup will remove all nodes from the resource group. err = suite.manager.RemoveResourceGroup(ctx, "rg2") suite.NoError(err) + + suite.manager.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ + NodeID: 10, + Address: "localhost", + Hostname: "localhost", + Labels: map[string]string{ + sessionutil.LabelStreamingNodeEmbeddedQueryNode: "1", + }, + })) + suite.manager.HandleNodeUp(ctx, 10) } func (suite *ResourceManagerSuite) TestNodeUpAndDown() { @@ -1005,3 +1016,208 @@ func (suite *ResourceManagerSuite) TestNodeLabels_NodeDown() { suite.Equal("label2", suite.manager.nodeMgr.Get(node).Labels()["dc_name"]) } } + +// createTestResourceManager creates a ResourceManager for testing +func createTestResourceManager(t *testing.T) *ResourceManager { + // Create a mock catalog + mockCatalog := &mocks.MetaKv{} + mockCatalog.On("MultiSave", mock.Anything, mock.Anything).Return(nil) + + // Create a mock node manager + nodeMgr := session.NewNodeManager() + + // Create resource manager + store := querycoord.NewCatalog(mockCatalog) + manager := NewResourceManager(store, nodeMgr) + + return manager +} + +// TestResourceManager_handleNodeUp tests the private handleNodeUp method +func TestResourceManager_handleNodeUp(t *testing.T) { + // Arrange + manager := createTestResourceManager(t) + ctx := context.Background() + nodeID := int64(1001) + + // Add node to node manager + nodeInfo := session.NewNodeInfo(session.ImmutableNodeInfo{ + NodeID: nodeID, + Address: "localhost", + Hostname: "localhost", + }) + manager.nodeMgr.Add(nodeInfo) + + // Act + manager.handleNodeUp(ctx, nodeID) + + // Assert + // After successful assignment, node should be removed from incomingNode + assert.False(t, manager.incomingNode.Contain(nodeID)) + + // Verify node was assigned to default resource group + nodes, err := manager.GetNodes(ctx, DefaultResourceGroupName) + assert.NoError(t, err) + assert.Contains(t, nodes, nodeID) +} + +// TestResourceManager_handleNodeDown tests the private handleNodeDown method +func TestResourceManager_handleNodeDown(t *testing.T) { + // Arrange + manager := createTestResourceManager(t) + ctx := context.Background() + nodeID := int64(1002) + + // Add node to node manager + nodeInfo := session.NewNodeInfo(session.ImmutableNodeInfo{ + NodeID: nodeID, + Address: "localhost", + Hostname: "localhost", + }) + manager.nodeMgr.Add(nodeInfo) + + // Add node to incoming set and assign it to a resource group first + manager.handleNodeUp(ctx, nodeID) + nodes, err := manager.GetNodes(ctx, DefaultResourceGroupName) + assert.NoError(t, err) + assert.Contains(t, nodes, nodeID) + + // Act + manager.handleNodeDown(ctx, nodeID) + + // Assert + assert.False(t, manager.incomingNode.Contain(nodeID)) + + // Verify node was removed from resource group + nodes, err = manager.GetNodes(ctx, DefaultResourceGroupName) + assert.NoError(t, err) + assert.NotContains(t, nodes, nodeID) + + // Verify node is no longer in nodeIDMap + _, exists := manager.nodeIDMap[nodeID] + assert.False(t, exists) +} + +// TestResourceManager_handleNodeStopping tests the private handleNodeStopping method +func TestResourceManager_handleNodeStopping(t *testing.T) { + // Arrange + manager := createTestResourceManager(t) + ctx := context.Background() + nodeID := int64(1003) + + // Add node to node manager + nodeInfo := session.NewNodeInfo(session.ImmutableNodeInfo{ + NodeID: nodeID, + Address: "localhost", + Hostname: "localhost", + }) + manager.nodeMgr.Add(nodeInfo) + + // Add node to incoming set and assign it to a resource group first + manager.handleNodeUp(ctx, nodeID) + nodes, err := manager.GetNodes(ctx, DefaultResourceGroupName) + assert.NoError(t, err) + assert.Contains(t, nodes, nodeID) + + // Act + manager.handleNodeStopping(ctx, nodeID) + + // Assert + assert.False(t, manager.incomingNode.Contain(nodeID)) + + // Verify node was removed from resource group + nodes, err = manager.GetNodes(ctx, DefaultResourceGroupName) + assert.NoError(t, err) + assert.NotContains(t, nodes, nodeID) + + // Verify node is no longer in nodeIDMap + _, exists := manager.nodeIDMap[nodeID] + assert.False(t, exists) +} + +// TestResourceManager_CheckNodesInResourceGroup tests the CheckNodesInResourceGroup method +func TestResourceManager_CheckNodesInResourceGroup(t *testing.T) { + // Arrange + manager := createTestResourceManager(t) + + // Add some nodes to node manager + nodeInfo1 := session.NewNodeInfo(session.ImmutableNodeInfo{ + NodeID: 1001, + Address: "localhost:1001", + Hostname: "localhost", + }) + nodeInfo2 := session.NewNodeInfo(session.ImmutableNodeInfo{ + NodeID: 1002, + Address: "localhost:1002", + Hostname: "localhost", + }) + nodeInfo3 := session.NewNodeInfo(session.ImmutableNodeInfo{ + NodeID: 1003, + Address: "localhost:1003", + Hostname: "localhost", + }) + manager.nodeMgr.Add(nodeInfo1) + manager.nodeMgr.Add(nodeInfo2) + manager.nodeMgr.Add(nodeInfo3) + + // Set node 1002 as stopping + nodeInfo2.SetState(session.NodeStateStopping) + + // Add nodes to default resource group + ctx := context.Background() + manager.handleNodeUp(ctx, 1001) + manager.handleNodeUp(ctx, 1002) + manager.handleNodeUp(ctx, 1004) + + // Act + manager.CheckNodesInResourceGroup(ctx) + + // Verify final state: offline node (1004) should be removed + finalNodes, err := manager.GetNodes(context.Background(), DefaultResourceGroupName) + assert.NoError(t, err) + assert.NotContains(t, finalNodes, int64(1004), "Offline node should be removed") + + // Verify stopping node (1002) should be removed + assert.NotContains(t, finalNodes, int64(1002), "Stopping node should be removed") + + // Verify healthy node (1001) should remain + assert.Contains(t, finalNodes, int64(1001), "Healthy node should remain") + + // Verify new node (1003) should be added + assert.Contains(t, finalNodes, int64(1003), "New node should be added") +} + +// TestResourceManager_CheckNodesInResourceGroup_AllNodesHealthy tests CheckNodesInResourceGroup with all healthy nodes +func TestResourceManager_CheckNodesInResourceGroup_AllNodesHealthy(t *testing.T) { + // Arrange + manager := createTestResourceManager(t) + + // Add some healthy nodes to node manager + nodeInfo1 := session.NewNodeInfo(session.ImmutableNodeInfo{ + NodeID: 1001, + Address: "localhost:1001", + Hostname: "localhost", + }) + nodeInfo2 := session.NewNodeInfo(session.ImmutableNodeInfo{ + NodeID: 1002, + Address: "localhost:1002", + Hostname: "localhost", + }) + manager.nodeMgr.Add(nodeInfo1) + manager.nodeMgr.Add(nodeInfo2) + + // Add nodes to default resource group + ctx := context.Background() + manager.handleNodeUp(ctx, 1001) + manager.handleNodeUp(ctx, 1002) + + // Act + manager.CheckNodesInResourceGroup(ctx) + + // Verify that healthy nodes remain unchanged + finalNodes, err := manager.GetNodes(ctx, DefaultResourceGroupName) + assert.NoError(t, err) + assert.Contains(t, finalNodes, int64(1001), "Healthy node should remain") + assert.Contains(t, finalNodes, int64(1002), "Healthy node should remain") + assert.Equal(t, 2, len(finalNodes), "Should have exactly 2 nodes") +} diff --git a/internal/querycoordv2/server.go b/internal/querycoordv2/server.go index 195f861faa..4fb0492146 100644 --- a/internal/querycoordv2/server.go +++ b/internal/querycoordv2/server.go @@ -128,9 +128,6 @@ type Server struct { enableActiveStandBy bool activateFunc func() error - nodeUpEventChan chan int64 - notifyNodeUp chan struct{} - // proxy client manager proxyCreator proxyutil.ProxyCreator proxyWatcher proxyutil.ProxyWatcherInterface @@ -142,12 +139,10 @@ type Server struct { func NewQueryCoord(ctx context.Context) (*Server, error) { ctx, cancel := context.WithCancel(ctx) server := &Server{ - ctx: ctx, - cancel: cancel, - nodeUpEventChan: make(chan int64, 10240), - notifyNodeUp: make(chan struct{}), - balancerMap: make(map[string]balance.Balance), - metricsRequest: metricsinfo.NewMetricsRequest(), + ctx: ctx, + cancel: cancel, + balancerMap: make(map[string]balance.Balance), + metricsRequest: metricsinfo.NewMetricsRequest(), } server.UpdateStateCode(commonpb.StateCode_Abnormal) server.queryNodeCreator = session.DefaultQueryNodeCreator @@ -534,27 +529,14 @@ func (s *Server) startQueryCoord() error { if err != nil { return err } - for _, node := range sessions { - s.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ - NodeID: node.ServerID, - Address: node.Address, - Hostname: node.HostName, - Version: node.Version, - Labels: node.GetServerLabel(), - })) - s.taskScheduler.AddExecutor(node.ServerID) - if node.Stopping { - s.nodeMgr.Stopping(node.ServerID) - } - } - s.checkNodeStateInRG() - for _, node := range sessions { - s.handleNodeUp(node.ServerID) + log.Info("rewatch nodes", zap.Any("sessions", sessions)) + err = s.rewatchNodes(sessions) + if err != nil { + return err } - s.wg.Add(2) - go s.handleNodeUpLoop() + s.wg.Add(1) go s.watchNodes(revision) // check whether old node exist, if yes suspend auto balance until all old nodes down @@ -751,7 +733,7 @@ func (s *Server) watchNodes(revision int64) { log := log.Ctx(s.ctx) defer s.wg.Done() - eventChan := s.session.WatchServices(typeutil.QueryNodeRole, revision+1, nil) + eventChan := s.session.WatchServices(typeutil.QueryNodeRole, revision+1, s.rewatchNodes) for { select { case <-s.ctx.Done(): @@ -771,14 +753,15 @@ func (s *Server) watchNodes(revision int64) { return } + nodeID := event.Session.ServerID + addr := event.Session.Address + log := log.With( + zap.Int64("nodeID", nodeID), + zap.String("nodeAddr", addr), + ) + switch event.EventType { case sessionutil.SessionAddEvent: - nodeID := event.Session.ServerID - addr := event.Session.Address - log.Info("add node to NodeManager", - zap.Int64("nodeID", nodeID), - zap.String("nodeAddr", addr), - ) s.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ NodeID: nodeID, Address: addr, @@ -786,91 +769,90 @@ func (s *Server) watchNodes(revision int64) { Version: event.Session.Version, Labels: event.Session.GetServerLabel(), })) - s.nodeUpEventChan <- nodeID - select { - case s.notifyNodeUp <- struct{}{}: - default: - } + s.handleNodeUp(nodeID) case sessionutil.SessionUpdateEvent: - nodeID := event.Session.ServerID - addr := event.Session.Address - log.Info("stopping the node", - zap.Int64("nodeID", nodeID), - zap.String("nodeAddr", addr), - ) + log.Info("stopping the node") s.nodeMgr.Stopping(nodeID) - s.checkerController.Check() - s.meta.ResourceManager.HandleNodeStopping(context.Background(), nodeID) + s.handleNodeStopping(nodeID) case sessionutil.SessionDelEvent: - nodeID := event.Session.ServerID - log.Info("a node down, remove it", zap.Int64("nodeID", nodeID)) + log.Info("a node down, remove it") s.nodeMgr.Remove(nodeID) s.handleNodeDown(nodeID) - s.metricsCacheManager.InvalidateSystemInfoMetrics() } } } } -func (s *Server) handleNodeUpLoop() { - log := log.Ctx(s.ctx) - defer s.wg.Done() - ticker := time.NewTicker(Params.QueryCoordCfg.CheckHealthInterval.GetAsDuration(time.Millisecond)) - defer ticker.Stop() - for { - select { - case <-s.ctx.Done(): - log.Info("handle node up loop exit due to context done") - return - case <-s.notifyNodeUp: - s.tryHandleNodeUp() - case <-ticker.C: - s.tryHandleNodeUp() - } - } -} +// rewatchNodes is used to re-watch nodes when querycoord restart or reconnect to etcd +// Note: may apply same node multiple times, so rewatchNodes must be idempotent +func (s *Server) rewatchNodes(sessions map[string]*sessionutil.Session) error { + sessionMap := lo.MapKeys(sessions, func(s *sessionutil.Session, _ string) int64 { + return s.ServerID + }) -func (s *Server) tryHandleNodeUp() { - log := log.Ctx(s.ctx).WithRateGroup("qcv2.Server", 1, 60) - ctx, cancel := context.WithTimeout(s.ctx, Params.QueryCoordCfg.CheckHealthRPCTimeout.GetAsDuration(time.Millisecond)) - defer cancel() - reasons, err := s.checkNodeHealth(ctx) - if err != nil { - log.RatedWarn(10, "unhealthy node exist, node up will be delayed", - zap.Int("delayedNodeUpEvents", len(s.nodeUpEventChan)), - zap.Int("unhealthyNodeNum", len(reasons)), - zap.Strings("unhealthyReason", reasons)) - return - } - for len(s.nodeUpEventChan) > 0 { - nodeID := <-s.nodeUpEventChan - if s.nodeMgr.Get(nodeID) != nil { - // only if all nodes are healthy, node up event will be handled - s.handleNodeUp(nodeID) - s.metricsCacheManager.InvalidateSystemInfoMetrics() - s.checkerController.Check() - } else { - log.Warn("node already down", - zap.Int64("nodeID", nodeID)) + // first remove all offline nodes + for _, node := range s.nodeMgr.GetAll() { + nodeSession, ok := sessionMap[node.ID()] + if !ok { + // node in node manager but session not exist, means it's offline + s.nodeMgr.Remove(node.ID()) + s.handleNodeDown(node.ID()) + } else if nodeSession.Stopping && !node.IsStoppingState() { + // node in node manager but session is stopping, means it's stopping + s.nodeMgr.Stopping(node.ID()) + s.handleNodeStopping(node.ID()) } } + + // then add all on new online nodes + for _, nodeSession := range sessionMap { + nodeInfo := s.nodeMgr.Get(nodeSession.ServerID) + if nodeInfo == nil { + s.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ + NodeID: nodeSession.GetServerID(), + Address: nodeSession.GetAddress(), + Hostname: nodeSession.HostName, + Version: nodeSession.Version, + Labels: nodeSession.GetServerLabel(), + })) + + if nodeSession.Stopping { + s.nodeMgr.Stopping(nodeSession.ServerID) + s.handleNodeStopping(nodeSession.ServerID) + } else { + s.handleNodeUp(nodeSession.GetServerID()) + } + } + } + + // Note: Node manager doesn't persist node list, so after query coord restart, we cannot + // update all node statuses in resource manager based on session and node manager's node list. + // Therefore, manual status checking of all nodes in resource manager is needed. + s.meta.ResourceManager.CheckNodesInResourceGroup(s.ctx) + + return nil } func (s *Server) handleNodeUp(node int64) { nodeInfo := s.nodeMgr.Get(node) if nodeInfo == nil { + log.Ctx(s.ctx).Warn("node already down", zap.Int64("nodeID", node)) return } + + // add executor to task scheduler s.taskScheduler.AddExecutor(node) + + // start dist handler s.distController.StartDistInstance(s.ctx, node) - if nodeInfo.IsEmbeddedQueryNodeInStreamingNode() { - // The querynode embedded in the streaming node can not work with streaming node. - return - } + // need assign to new rg and replica s.meta.ResourceManager.HandleNodeUp(s.ctx, node) + + s.metricsCacheManager.InvalidateSystemInfoMetrics() + s.checkerController.Check() } func (s *Server) handleNodeDown(node int64) { @@ -886,20 +868,21 @@ func (s *Server) handleNodeDown(node int64) { s.taskScheduler.RemoveByNode(node) s.meta.ResourceManager.HandleNodeDown(context.Background(), node) + + // clean node's metrics + metrics.QueryCoordLastHeartbeatTimeStamp.DeleteLabelValues(fmt.Sprint(node)) + s.metricsCacheManager.InvalidateSystemInfoMetrics() } -func (s *Server) checkNodeStateInRG() { - for _, rgName := range s.meta.ListResourceGroups(s.ctx) { - rg := s.meta.ResourceManager.GetResourceGroup(s.ctx, rgName) - for _, node := range rg.GetNodes() { - info := s.nodeMgr.Get(node) - if info == nil { - s.meta.ResourceManager.HandleNodeDown(context.Background(), node) - } else if info.IsStoppingState() { - s.meta.ResourceManager.HandleNodeStopping(context.Background(), node) - } - } - } +func (s *Server) handleNodeStopping(node int64) { + // mark node as stopping in node manager + s.nodeMgr.Stopping(node) + + // mark node as stopping in resource manager + s.meta.ResourceManager.HandleNodeStopping(context.Background(), node) + + // trigger checker to check stopping node + s.checkerController.Check() } func (s *Server) updateBalanceConfigLoop(ctx context.Context) { diff --git a/internal/querycoordv2/server_test.go b/internal/querycoordv2/server_test.go index bd3c843a40..e78f7f2fc0 100644 --- a/internal/querycoordv2/server_test.go +++ b/internal/querycoordv2/server_test.go @@ -18,6 +18,7 @@ package querycoordv2 import ( "context" + "fmt" "math/rand" "os" "sync" @@ -25,6 +26,7 @@ import ( "time" "github.com/bytedance/mockey" + "github.com/prometheus/client_golang/prometheus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/suite" @@ -45,12 +47,14 @@ import ( "github.com/milvus-io/milvus/internal/util/sessionutil" "github.com/milvus-io/milvus/pkg/v2/common" "github.com/milvus-io/milvus/pkg/v2/log" + "github.com/milvus-io/milvus/pkg/v2/metrics" "github.com/milvus-io/milvus/pkg/v2/proto/datapb" "github.com/milvus-io/milvus/pkg/v2/proto/querypb" "github.com/milvus-io/milvus/pkg/v2/proto/rootcoordpb" "github.com/milvus-io/milvus/pkg/v2/util/commonpbutil" "github.com/milvus-io/milvus/pkg/v2/util/etcd" "github.com/milvus-io/milvus/pkg/v2/util/merr" + "github.com/milvus-io/milvus/pkg/v2/util/metricsinfo" "github.com/milvus-io/milvus/pkg/v2/util/paramtable" "github.com/milvus-io/milvus/pkg/v2/util/tikv" ) @@ -186,7 +190,6 @@ func (suite *ServerSuite) TestNodeUp() { suite.NoError(err) defer node1.Stop() - suite.server.notifyNodeUp <- struct{}{} suite.Eventually(func() bool { node := suite.server.nodeMgr.Get(node1.ID) if node == nil { @@ -200,54 +203,6 @@ func (suite *ServerSuite) TestNodeUp() { } return true }, 5*time.Second, time.Second) - - // mock unhealthy node - suite.server.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ - NodeID: 1001, - Address: "localhost", - Hostname: "localhost", - })) - - node2 := mocks.NewMockQueryNode(suite.T(), suite.server.etcdCli, 101) - node2.EXPECT().GetDataDistribution(mock.Anything, mock.Anything).Return(&querypb.GetDataDistributionResponse{Status: merr.Success()}, nil).Maybe() - err = node2.Start() - suite.NoError(err) - defer node2.Stop() - - // expect node2 won't be add to qc, due to unhealthy nodes exist - suite.server.notifyNodeUp <- struct{}{} - suite.Eventually(func() bool { - node := suite.server.nodeMgr.Get(node2.ID) - if node == nil { - return false - } - for _, collection := range suite.collections { - replica := suite.server.meta.ReplicaManager.GetByCollectionAndNode(suite.ctx, collection, node2.ID) - if replica == nil { - return true - } - } - return false - }, 5*time.Second, time.Second) - - // mock unhealthy node down, so no unhealthy nodes exist - suite.server.nodeMgr.Remove(1001) - suite.server.notifyNodeUp <- struct{}{} - - // expect node2 will be add to qc - suite.Eventually(func() bool { - node := suite.server.nodeMgr.Get(node2.ID) - if node == nil { - return false - } - for _, collection := range suite.collections { - replica := suite.server.meta.ReplicaManager.GetByCollectionAndNode(suite.ctx, collection, node2.ID) - if replica == nil { - return false - } - } - return true - }, 5*time.Second, time.Second) } func (suite *ServerSuite) TestNodeUpdate() { @@ -749,6 +704,244 @@ func (suite *ServerSuite) newQueryCoord() (*Server, error) { return server, err } +// TestRewatchNodes tests the rewatchNodes function behavior +func TestRewatchNodes(t *testing.T) { + // Arrange: Create simple server instance + server := createSimpleTestServer() + + // Create test sessions + sessions := map[string]*sessionutil.Session{ + "querynode-1001": createTestSession(1001, "localhost:19530", false), + "querynode-1002": createTestSession(1002, "localhost:19531", false), + "querynode-1003": createTestSession(1003, "localhost:19532", true), // stopping + } + + // Pre-add some nodes to node manager to test removal logic + server.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ + NodeID: 1001, + Address: "localhost:19530", + Hostname: "localhost", + })) + server.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ + NodeID: 1004, // This node will be removed as it's not in sessions + Address: "localhost:19533", + Hostname: "localhost", + })) + + // Mock external calls + mockHandleNodeUp := mockey.Mock((*Server).handleNodeUp).Return().Build() + defer mockHandleNodeUp.UnPatch() + + mockHandleNodeDown := mockey.Mock((*Server).handleNodeDown).Return().Build() + defer mockHandleNodeDown.UnPatch() + + mockHandleNodeStopping := mockey.Mock((*Server).handleNodeStopping).Return().Build() + defer mockHandleNodeStopping.UnPatch() + + server.meta = &meta.Meta{ + ResourceManager: meta.NewResourceManager(nil, nil), + } + mockCheckNodesInResourceGroup := mockey.Mock((*meta.ResourceManager).CheckNodesInResourceGroup).Return().Build() + defer mockCheckNodesInResourceGroup.UnPatch() + + // Act: Call rewatchNodes + err := server.rewatchNodes(sessions) + + // Assert: Verify no error occurred + assert.NoError(t, err) + + // Verify node 1004 was removed + assert.Nil(t, server.nodeMgr.Get(1004), "Offline node should be removed") + + // Verify nodes 1001, 1002 exist + assert.NotNil(t, server.nodeMgr.Get(1001), "Online node should exist") + assert.NotNil(t, server.nodeMgr.Get(1002), "Online node should exist") + assert.NotNil(t, server.nodeMgr.Get(1003), "Stopping node should exist") +} + +// TestRewatchNodesWithEmptySessions tests rewatchNodes with empty sessions +func TestRewatchNodesWithEmptySessions(t *testing.T) { + // Arrange: Create server with existing nodes + server := createSimpleTestServer() + + // Add some existing nodes + server.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ + NodeID: 1001, + Address: "localhost:19530", + Hostname: "localhost", + })) + server.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ + NodeID: 1002, + Address: "localhost:19531", + Hostname: "localhost", + })) + + // Mock external calls + mockHandleNodeDown := mockey.Mock((*Server).handleNodeDown).Return().Build() + defer mockHandleNodeDown.UnPatch() + + server.meta = &meta.Meta{ + ResourceManager: meta.NewResourceManager(nil, nil), + } + mockCheckNodesInResourceGroup := mockey.Mock((*meta.ResourceManager).CheckNodesInResourceGroup).Return().Build() + defer mockCheckNodesInResourceGroup.UnPatch() + + // Act: Call rewatchNodes with empty sessions + err := server.rewatchNodes(nil) + + // Assert: All nodes should be removed + assert.NoError(t, err) + assert.Nil(t, server.nodeMgr.Get(1001), "All nodes should be removed when no sessions exist") + assert.Nil(t, server.nodeMgr.Get(1002), "All nodes should be removed when no sessions exist") +} + +// TestHandleNodeUpWithMissingNode tests handleNodeUp when node doesn't exist +func TestHandleNodeUpWithMissingNode(t *testing.T) { + // Arrange: Create server without adding the node + server := createSimpleTestServer() + + nodeID := int64(1001) + + // Act: Handle node up for non-existent node + server.handleNodeUp(nodeID) + + // Assert: Should handle gracefully (no panic, early return) + // The function should return early when node is not found +} + +// TestHandleNodeDownMetricsCleanup tests that handleNodeDown cleans up metrics properly +func TestHandleNodeDownMetricsCleanup(t *testing.T) { + // Arrange: Set up metrics with test value + nodeID := int64(1001) + + // Setup metrics with test value + registry := prometheus.NewRegistry() + metrics.RegisterQueryCoord(registry) + + // Set a test metric value + metrics.QueryCoordLastHeartbeatTimeStamp.WithLabelValues(fmt.Sprint(nodeID)).Set(1640995200.0) + + // Verify metric exists before deletion + metricFamilies, err := registry.Gather() + assert.NoError(t, err) + + found := false + for _, mf := range metricFamilies { + if mf.GetName() == "milvus_querycoord_last_heartbeat_timestamp" { + for _, metric := range mf.GetMetric() { + for _, label := range metric.GetLabel() { + if label.GetName() == "node_id" && label.GetValue() == fmt.Sprint(nodeID) { + found = true + break + } + } + } + } + } + assert.True(t, found, "Metric should exist before cleanup") + + // Create a minimal server + ctx := context.Background() + server := &Server{ + ctx: ctx, + taskScheduler: task.NewScheduler(ctx, nil, nil, nil, nil, nil, nil), + dist: meta.NewDistributionManager(), + distController: dist.NewDistController(nil, nil, nil, nil, nil, nil), + metricsCacheManager: metricsinfo.NewMetricsCacheManager(), + meta: &meta.Meta{ + ResourceManager: meta.NewResourceManager(nil, nil), + }, + } + + mockRemoveExecutor := mockey.Mock((task.Scheduler).RemoveExecutor).Return().Build() + defer mockRemoveExecutor.UnPatch() + mockRemoveByNode := mockey.Mock((task.Scheduler).RemoveByNode).Return().Build() + defer mockRemoveByNode.UnPatch() + mockDistControllerRemove := mockey.Mock((*dist.ControllerImpl).Remove).Return().Build() + defer mockDistControllerRemove.UnPatch() + mockRemoveFromManager := mockey.Mock(server.dist.ChannelDistManager.Update).Return().Build() + defer mockRemoveFromManager.UnPatch() + mockRemoveFromManager = mockey.Mock(server.dist.SegmentDistManager.Update).Return().Build() + defer mockRemoveFromManager.UnPatch() + mockInvalidateSystemInfoMetrics := mockey.Mock((*metricsinfo.MetricsCacheManager).InvalidateSystemInfoMetrics).Return().Build() + defer mockInvalidateSystemInfoMetrics.UnPatch() + + mockResourceManagerHandleNodeDown := mockey.Mock((*meta.ResourceManager).HandleNodeDown).Return().Build() + defer mockResourceManagerHandleNodeDown.UnPatch() + + // Act: Call handleNodeDown which should clean up metrics + server.handleNodeDown(nodeID) + + metricFamilies, err = registry.Gather() + assert.NoError(t, err) + + // Check that the heartbeat metric for this node was deleted + found = false + for _, mf := range metricFamilies { + if mf.GetName() == "milvus_querycoord_last_heartbeat_timestamp" { + for _, metric := range mf.GetMetric() { + for _, label := range metric.GetLabel() { + if label.GetName() == "node_id" && label.GetValue() == fmt.Sprint(nodeID) { + found = true + break + } + } + } + } + } + assert.False(t, found, "Metric should be cleaned up after handleNodeDown") +} + +// TestNodeManagerStopping tests the node manager stopping functionality +func TestNodeManagerStopping(t *testing.T) { + // Arrange: Create node manager and add a node + nodeID := int64(1001) + nodeMgr := session.NewNodeManager() + + nodeInfo := session.NewNodeInfo(session.ImmutableNodeInfo{ + NodeID: nodeID, + Address: "localhost:19530", + Hostname: "localhost", + }) + nodeMgr.Add(nodeInfo) + + // Verify node exists and is not stopping initially + node := nodeMgr.Get(nodeID) + assert.NotNil(t, node) + assert.False(t, node.IsStoppingState(), "Node should not be stopping initially") + + // Act: Mark node as stopping + nodeMgr.Stopping(nodeID) + + // Assert: Node should be in stopping state + node = nodeMgr.Get(nodeID) + assert.NotNil(t, node) + assert.True(t, node.IsStoppingState(), "Node should be in stopping state after calling Stopping()") +} + +// Helper function to create a simple test server +func createSimpleTestServer() *Server { + ctx := context.Background() + server := &Server{ + ctx: ctx, + nodeMgr: session.NewNodeManager(), + } + return server +} + +// Helper function to create a test session +func createTestSession(nodeID int64, address string, stopping bool) *sessionutil.Session { + session := &sessionutil.Session{ + SessionRaw: sessionutil.SessionRaw{ + ServerID: nodeID, + Address: address, + Stopping: stopping, + HostName: "localhost", + }, + } + return session +} + func TestServer(t *testing.T) { parameters := []string{"tikv", "etcd"} for _, v := range parameters { diff --git a/internal/querycoordv2/services_test.go b/internal/querycoordv2/services_test.go index 592e297237..021c634328 100644 --- a/internal/querycoordv2/services_test.go +++ b/internal/querycoordv2/services_test.go @@ -221,6 +221,17 @@ func (suite *ServiceSuite) SetupTest() { proxyClientManager: suite.proxyManager, } + // Initialize checkerController to prevent nil pointer dereference in handleNodeUp + suite.server.checkerController = checkers.NewCheckerController( + suite.meta, + suite.dist, + suite.targetMgr, + suite.nodeMgr, + suite.taskScheduler, + suite.broker, + suite.server.getBalancerFunc, + ) + suite.server.registerMetricsRequest() suite.server.UpdateStateCode(commonpb.StateCode_Healthy) diff --git a/pkg/metrics/querycoord_metrics.go b/pkg/metrics/querycoord_metrics.go index c15bebb3b2..cc9b652e95 100644 --- a/pkg/metrics/querycoord_metrics.go +++ b/pkg/metrics/querycoord_metrics.go @@ -157,6 +157,14 @@ var ( Name: "replica_ro_node_total", Help: "total read only node number of replica", }) + + QueryCoordLastHeartbeatTimeStamp = prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Namespace: milvusNamespace, + Subsystem: typeutil.QueryCoordRole, + Name: "last_heartbeat_timestamp", + Help: "heartbeat timestamp of query node", + }, []string{nodeIDLabelName}) ) // RegisterQueryCoord registers QueryCoord metrics @@ -174,6 +182,7 @@ func RegisterQueryCoord(registry *prometheus.Registry) { registry.MustRegister(QueryCoordResourceGroupInfo) registry.MustRegister(QueryCoordResourceGroupReplicaTotal) registry.MustRegister(QueryCoordReplicaRONodeTotal) + registry.MustRegister(QueryCoordLastHeartbeatTimeStamp) } func CleanQueryCoordMetricsWithCollectionID(collectionID int64) {