enhance: Implement rewatch mechanism for etcd failure scenarios (#43829) (#43920)

issue: #43828
pr: #43829 #43909
Implement a robust rewatch mechanism to handle etcd connection failures
and node reconnection scenarios in DataCoord and QueryCoord, along with
heartbeat lag monitoring.

Changes include:
- Implement rewatchDataNodes/rewatchQueryNodes callbacks for etcd
reconnection scenarios (see the sketch after this list)
- Add an idempotent rewatchNodes method to handle etcd session recovery
gracefully
- Add QueryCoordLastHeartbeatTimeStamp metric for monitoring node
heartbeat lag
- Clean up heartbeat metrics when nodes go down to prevent metric leaks
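
The rewatch callbacks may be invoked repeatedly with the same session snapshot (on startup and again on every etcd reconnection), so they reconcile the tracked node set against the current sessions instead of applying one-shot deltas. Below is a minimal sketch of that reconciliation pattern; the types and names are illustrative only and are not the actual Milvus managers shown in the diff below.

package main

import "fmt"

type nodeInfo struct {
	ID      int64
	Address string
}

// rewatch reconciles the locally tracked nodes against the sessions currently
// present in etcd. Calling it twice with the same snapshot leaves the tracked
// set unchanged, which is what makes the callback idempotent.
func rewatch(tracked map[int64]nodeInfo, sessions map[int64]nodeInfo) {
	// drop tracked nodes whose session is gone (went offline while the watch was broken)
	for id := range tracked {
		if _, ok := sessions[id]; !ok {
			delete(tracked, id)
			fmt.Println("node down:", id)
		}
	}
	// add or refresh every node that currently has a session; re-adding an
	// already tracked node is a no-op
	for id, info := range sessions {
		if _, ok := tracked[id]; !ok {
			fmt.Println("node up:", id)
		}
		tracked[id] = info
	}
}

func main() {
	tracked := map[int64]nodeInfo{1: {ID: 1, Address: "addr1"}, 4: {ID: 4, Address: "addr4"}}
	sessions := map[int64]nodeInfo{1: {ID: 1, Address: "addr1"}, 2: {ID: 2, Address: "addr2"}}
	rewatch(tracked, sessions) // removes node 4, adds node 2
	rewatch(tracked, sessions) // idempotent: no further changes
}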

---------

Signed-off-by: Wei Liu <wei.liu@zilliz.com>
Co-authored-by: Zhen Ye <chyezh@outlook.com>
wei liu 2025-10-15 14:12:01 +08:00 committed by GitHub
parent 1f94b1f5f6
commit 47949fd883
17 changed files with 1351 additions and 223 deletions


@ -71,6 +71,22 @@ func NewClusterImpl(sessionManager session.DataNodeManager, channelManager Chann
// Startup inits the cluster with the given data nodes.
func (c *ClusterImpl) Startup(ctx context.Context, nodes []*session.NodeInfo) error {
oldNodes := c.sessionManager.GetSessions()
newNodesMap := lo.SliceToMap(nodes, func(info *session.NodeInfo) (int64, *session.NodeInfo) {
return info.NodeID, info
})
// clean offline nodes
for _, node := range oldNodes {
if _, ok := newNodesMap[node.NodeID()]; !ok {
c.sessionManager.DeleteSession(&session.NodeInfo{
NodeID: node.NodeID(),
Address: node.Address(),
})
}
}
// add new nodes
for _, node := range nodes {
c.sessionManager.AddSession(node)
}


@ -20,8 +20,9 @@ import (
"context"
"testing"
"github.com/bytedance/mockey"
"github.com/cockroachdb/errors"
"github.com/samber/lo"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
@ -62,23 +63,67 @@ func (suite *ClusterSuite) SetupTest() {
func (suite *ClusterSuite) TearDownTest() {}
func (suite *ClusterSuite) TestStartup() {
func TestClusterImpl_Startup_NewNodes(t *testing.T) {
nodes := []*session.NodeInfo{
{NodeID: 1, Address: "addr1"},
{NodeID: 2, Address: "addr2"},
{NodeID: 3, Address: "addr3"},
{NodeID: 4, Address: "addr4"},
}
suite.mockSession.EXPECT().AddSession(mock.Anything).Return().Times(len(nodes))
suite.mockChManager.EXPECT().Startup(mock.Anything, mock.Anything, mock.Anything).
RunAndReturn(func(ctx context.Context, legacys []int64, nodeIDs []int64) error {
suite.ElementsMatch(lo.Map(nodes, func(info *session.NodeInfo, _ int) int64 { return info.NodeID }), nodeIDs)
return nil
}).Once()
cluster := NewClusterImpl(suite.mockSession, suite.mockChManager)
// Mock the static functions called by ClusterImpl.Startup
mockGetSessions := mockey.Mock((*session.DataNodeManagerImpl).GetSessions).Return([]*session.Session{}).Build()
defer mockGetSessions.UnPatch()
newAddedNodes := make([]int64, 0, len(nodes))
mockAddSession := mockey.Mock((*session.DataNodeManagerImpl).AddSession).To(func(node *session.NodeInfo) {
newAddedNodes = append(newAddedNodes, node.NodeID)
}).Build()
defer mockAddSession.UnPatch()
mockChannelStartup := mockey.Mock((*ChannelManagerImpl).Startup).Return(nil).Build()
defer mockChannelStartup.UnPatch()
cluster := NewClusterImpl(&session.DataNodeManagerImpl{}, &ChannelManagerImpl{})
err := cluster.Startup(context.Background(), nodes)
suite.NoError(err)
assert.NoError(t, err)
assert.ElementsMatch(t, newAddedNodes, []int64{1, 2, 3, 4})
}
func TestClusterImpl_Startup_RemoveOldNodes(t *testing.T) {
// Create real session objects for testing
existingSession1 := session.NewSession(&session.NodeInfo{NodeID: 1, Address: "old-addr1"}, nil)
existingSession2 := session.NewSession(&session.NodeInfo{NodeID: 2, Address: "addr2"}, nil)
existingSessions := []*session.Session{existingSession1, existingSession2}
// New nodes to be added
newNodes := []*session.NodeInfo{
{NodeID: 2, Address: "addr2"}, // existing node (should not be removed)
{NodeID: 3, Address: "addr3"}, // new node
}
// Mock expectations
mockGetSessions := mockey.Mock((*session.DataNodeManagerImpl).GetSessions).Return(existingSessions).Build()
defer mockGetSessions.UnPatch()
removeNodes := make([]int64, 0, len(existingSessions))
mockDeleteSession := mockey.Mock((*session.DataNodeManagerImpl).DeleteSession).To(func(node *session.NodeInfo) {
removeNodes = append(removeNodes, node.NodeID)
}).Build()
defer mockDeleteSession.UnPatch()
mockAddSession := mockey.Mock((*session.DataNodeManagerImpl).AddSession).Return().Build()
defer mockAddSession.UnPatch()
mockChannelStartup := mockey.Mock((*ChannelManagerImpl).Startup).Return(nil).Build()
defer mockChannelStartup.UnPatch()
cluster := NewClusterImpl(&session.DataNodeManagerImpl{}, &ChannelManagerImpl{})
err := cluster.Startup(context.Background(), newNodes)
assert.NoError(t, err)
assert.ElementsMatch(t, removeNodes, []int64{1})
}
func (suite *ClusterSuite) TestRegister() {


@ -3,6 +3,7 @@ package datacoord
import (
"math"
"github.com/samber/lo"
"go.uber.org/zap"
"github.com/milvus-io/milvus/internal/util/sessionutil"
@ -43,6 +44,18 @@ func (m *versionManagerImpl) Startup(sessions map[string]*sessionutil.Session) {
m.mu.Lock()
defer m.mu.Unlock()
sessionMap := lo.MapKeys(sessions, func(session *sessionutil.Session, _ string) int64 {
return session.ServerID
})
// clean offline nodes
for sessionID := range m.versions {
if _, ok := sessionMap[sessionID]; !ok {
m.removeNodeByID(sessionID)
}
}
// deal with new online nodes
for _, session := range sessions {
m.addOrUpdate(session)
}
@ -59,9 +72,13 @@ func (m *versionManagerImpl) RemoveNode(session *sessionutil.Session) {
m.mu.Lock()
defer m.mu.Unlock()
delete(m.versions, session.ServerID)
delete(m.scalarIndexVersions, session.ServerID)
delete(m.indexNonEncoding, session.ServerID)
m.removeNodeByID(session.ServerID)
}
func (m *versionManagerImpl) removeNodeByID(sessionID int64) {
delete(m.versions, sessionID)
delete(m.scalarIndexVersions, sessionID)
delete(m.indexNonEncoding, sessionID)
}
func (m *versionManagerImpl) Update(session *sessionutil.Session) {


@ -155,3 +155,164 @@ func Test_IndexEngineVersionManager_GetIndexNoneEncoding(t *testing.T) {
// after removing server1, then global none encoding should be true
assert.True(t, m.GetIndexNonEncoding())
}
func Test_IndexEngineVersionManager_StartupWithOfflineNodeCleanup(t *testing.T) {
m := newIndexEngineVersionManager()
// First startup with initial nodes
m.Startup(map[string]*sessionutil.Session{
"1": {
SessionRaw: sessionutil.SessionRaw{
ServerID: 1,
IndexEngineVersion: sessionutil.IndexEngineVersion{CurrentIndexVersion: 20, MinimalIndexVersion: 10},
},
},
"2": {
SessionRaw: sessionutil.SessionRaw{
ServerID: 2,
IndexEngineVersion: sessionutil.IndexEngineVersion{CurrentIndexVersion: 15, MinimalIndexVersion: 5},
},
},
})
// Verify both nodes are present
assert.Equal(t, int32(15), m.GetCurrentIndexEngineVersion()) // min of 20 and 15
assert.Equal(t, int32(10), m.GetMinimalIndexEngineVersion()) // max of 10 and 5
// Second startup with only one node online (node 2 is offline)
m.Startup(map[string]*sessionutil.Session{
"1": {
SessionRaw: sessionutil.SessionRaw{
ServerID: 1,
IndexEngineVersion: sessionutil.IndexEngineVersion{CurrentIndexVersion: 25, MinimalIndexVersion: 12},
},
},
})
// Verify offline node 2 is cleaned up and only node 1 remains
assert.Equal(t, int32(25), m.GetCurrentIndexEngineVersion())
assert.Equal(t, int32(12), m.GetMinimalIndexEngineVersion())
// Verify that node 2's data is actually removed from internal maps
vm := m.(*versionManagerImpl)
_, exists := vm.versions[2]
assert.False(t, exists, "offline node should be removed from versions map")
_, exists = vm.scalarIndexVersions[2]
assert.False(t, exists, "offline node should be removed from scalarIndexVersions map")
_, exists = vm.indexNonEncoding[2]
assert.False(t, exists, "offline node should be removed from indexNonEncoding map")
}
func Test_IndexEngineVersionManager_StartupWithNewAndOfflineNodes(t *testing.T) {
m := newIndexEngineVersionManager()
// First startup
m.Startup(map[string]*sessionutil.Session{
"1": {
SessionRaw: sessionutil.SessionRaw{
ServerID: 1,
IndexEngineVersion: sessionutil.IndexEngineVersion{CurrentIndexVersion: 20, MinimalIndexVersion: 10},
},
},
"2": {
SessionRaw: sessionutil.SessionRaw{
ServerID: 2,
IndexEngineVersion: sessionutil.IndexEngineVersion{CurrentIndexVersion: 15, MinimalIndexVersion: 5},
},
},
})
// Second startup: node 2 offline, node 3 comes online, node 1 still online
m.Startup(map[string]*sessionutil.Session{
"1": {
SessionRaw: sessionutil.SessionRaw{
ServerID: 1,
IndexEngineVersion: sessionutil.IndexEngineVersion{CurrentIndexVersion: 22, MinimalIndexVersion: 11},
},
},
"3": {
SessionRaw: sessionutil.SessionRaw{
ServerID: 3,
IndexEngineVersion: sessionutil.IndexEngineVersion{CurrentIndexVersion: 18, MinimalIndexVersion: 8},
},
},
})
// Verify node 2 is cleaned up and node 3 is added
assert.Equal(t, int32(18), m.GetCurrentIndexEngineVersion()) // min of 22 and 18
assert.Equal(t, int32(11), m.GetMinimalIndexEngineVersion()) // max of 11 and 8
vm := m.(*versionManagerImpl)
// Node 1 should still exist
_, exists := vm.versions[1]
assert.True(t, exists, "online node 1 should remain")
// Node 2 should be removed
_, exists = vm.versions[2]
assert.False(t, exists, "offline node 2 should be removed")
// Node 3 should be added
_, exists = vm.versions[3]
assert.True(t, exists, "new online node 3 should be added")
}
func Test_IndexEngineVersionManager_StartupWithEmptySession(t *testing.T) {
m := newIndexEngineVersionManager()
// First startup with nodes
m.Startup(map[string]*sessionutil.Session{
"1": {
SessionRaw: sessionutil.SessionRaw{
ServerID: 1,
IndexEngineVersion: sessionutil.IndexEngineVersion{CurrentIndexVersion: 20, MinimalIndexVersion: 10},
},
},
})
assert.Equal(t, int32(20), m.GetCurrentIndexEngineVersion())
// Second startup with no nodes (all offline)
m.Startup(map[string]*sessionutil.Session{})
// Should return default values when no nodes are online
assert.Equal(t, int32(0), m.GetCurrentIndexEngineVersion())
assert.Equal(t, int32(0), m.GetMinimalIndexEngineVersion())
vm := m.(*versionManagerImpl)
assert.Empty(t, vm.versions, "all nodes should be cleaned up")
assert.Empty(t, vm.scalarIndexVersions, "all nodes should be cleaned up")
assert.Empty(t, vm.indexNonEncoding, "all nodes should be cleaned up")
}
func Test_IndexEngineVersionManager_removeNodeByID(t *testing.T) {
m := newIndexEngineVersionManager()
// Add some nodes first
m.AddNode(&sessionutil.Session{
SessionRaw: sessionutil.SessionRaw{
ServerID: 1,
IndexEngineVersion: sessionutil.IndexEngineVersion{CurrentIndexVersion: 20, MinimalIndexVersion: 10},
ScalarIndexEngineVersion: sessionutil.IndexEngineVersion{CurrentIndexVersion: 15, MinimalIndexVersion: 5},
IndexNonEncoding: true,
},
})
vm := m.(*versionManagerImpl)
// Verify node is added
_, exists := vm.versions[1]
assert.True(t, exists)
_, exists = vm.scalarIndexVersions[1]
assert.True(t, exists)
_, exists = vm.indexNonEncoding[1]
assert.True(t, exists)
// Remove node by ID
vm.removeNodeByID(1)
// Verify node is completely removed
_, exists = vm.versions[1]
assert.False(t, exists, "node should be removed from versions map")
_, exists = vm.scalarIndexVersions[1]
assert.False(t, exists, "node should be removed from scalarIndexVersions map")
_, exists = vm.indexNonEncoding[1]
assert.False(t, exists, "node should be removed from indexNonEncoding map")
}


@ -561,57 +561,24 @@ func (s *Server) initServiceDiscovery() error {
return err
}
log.Info("DataCoord success to get DataNode sessions", zap.Any("sessions", sessions))
datanodes := make([]*session.NodeInfo, 0, len(sessions))
legacyVersion, err := semver.Parse(paramtable.Get().DataCoordCfg.LegacyVersionWithoutRPCWatch.GetValue())
err = s.rewatchDataNodes(sessions)
if err != nil {
log.Warn("DataCoord failed to init service discovery", zap.Error(err))
}
for _, s := range sessions {
info := &session.NodeInfo{
NodeID: s.ServerID,
Address: s.Address,
}
if s.Version.LTE(legacyVersion) {
info.IsLegacy = true
}
datanodes = append(datanodes, info)
}
log.Info("DataCoord Cluster Manager start up")
if err := s.cluster.Startup(s.ctx, datanodes); err != nil {
log.Warn("DataCoord Cluster Manager failed to start up", zap.Error(err))
log.Warn("DataCoord failed to rewatch datanode", zap.Error(err))
return err
}
log.Info("DataCoord Cluster Manager start up successfully")
// TODO implement rewatch logic
s.dnEventCh = s.session.WatchServicesWithVersionRange(typeutil.DataNodeRole, r, rev+1, nil)
s.dnEventCh = s.session.WatchServicesWithVersionRange(typeutil.DataNodeRole, r, rev+1, s.rewatchDataNodes)
inSessions, inRevision, err := s.session.GetSessions(typeutil.IndexNodeRole)
if err != nil {
log.Warn("DataCoord get QueryCoord session failed", zap.Error(err))
return err
}
if Params.DataCoordCfg.BindIndexNodeMode.GetAsBool() {
if err = s.indexNodeManager.AddNode(Params.DataCoordCfg.IndexNodeID.GetAsInt64(), Params.DataCoordCfg.IndexNodeAddress.GetValue()); err != nil {
log.Error("add indexNode fail", zap.Int64("ServerID", Params.DataCoordCfg.IndexNodeID.GetAsInt64()),
zap.String("address", Params.DataCoordCfg.IndexNodeAddress.GetValue()), zap.Error(err))
return err
}
log.Info("add indexNode success", zap.String("IndexNode address", Params.DataCoordCfg.IndexNodeAddress.GetValue()),
zap.Int64("nodeID", Params.DataCoordCfg.IndexNodeID.GetAsInt64()))
} else {
for _, session := range inSessions {
if err := s.indexNodeManager.AddNode(session.ServerID, session.Address); err != nil {
return err
}
}
err = s.rewatchIndexNodes(inSessions)
if err != nil {
log.Warn("DataCoord failed to rewatch indexnode", zap.Error(err))
return err
}
s.inEventCh = s.session.WatchServices(typeutil.IndexNodeRole, inRevision+1, nil)
s.inEventCh = s.session.WatchServices(typeutil.IndexNodeRole, inRevision+1, s.rewatchIndexNodes)
s.indexEngineVersionManager = newIndexEngineVersionManager()
qnSessions, qnRevision, err := s.session.GetSessions(typeutil.QueryNodeRole)
@ -619,12 +586,83 @@ func (s *Server) initServiceDiscovery() error {
log.Warn("DataCoord get QueryNode sessions failed", zap.Error(err))
return err
}
s.indexEngineVersionManager.Startup(qnSessions)
s.qnEventCh = s.session.WatchServicesWithVersionRange(typeutil.QueryNodeRole, r, qnRevision+1, nil)
s.rewatchQueryNodes(qnSessions)
s.qnEventCh = s.session.WatchServicesWithVersionRange(typeutil.QueryNodeRole, r, qnRevision+1, s.rewatchQueryNodes)
return nil
}
// rewatchQueryNodes is used to rewatch query nodes when datacoord is started or reconnected to etcd
// Note: it may apply the same node multiple times, so rewatchQueryNodes must be idempotent
func (s *Server) rewatchQueryNodes(sessions map[string]*sessionutil.Session) error {
s.indexEngineVersionManager.Startup(sessions)
return nil
}
func (s *Server) rewatchIndexNodes(sessions map[string]*sessionutil.Session) error {
if Params.DataCoordCfg.BindIndexNodeMode.GetAsBool() {
nodes := make([]*session.NodeInfo, 0, 1)
nodes = append(nodes, &session.NodeInfo{
NodeID: Params.DataCoordCfg.IndexNodeID.GetAsInt64(),
Address: Params.DataCoordCfg.IndexNodeAddress.GetValue(),
})
if err := s.indexNodeManager.Startup(nodes); err != nil {
log.Error("add indexNode fail", zap.Int64("ServerID", Params.DataCoordCfg.IndexNodeID.GetAsInt64()),
zap.String("address", Params.DataCoordCfg.IndexNodeAddress.GetValue()), zap.Error(err))
return err
}
log.Info("add indexNode success", zap.String("IndexNode address", Params.DataCoordCfg.IndexNodeAddress.GetValue()),
zap.Int64("nodeID", Params.DataCoordCfg.IndexNodeID.GetAsInt64()))
} else {
nodes := make([]*session.NodeInfo, 0, len(sessions))
for _, s := range sessions {
nodes = append(nodes, &session.NodeInfo{
NodeID: s.ServerID,
Address: s.Address,
})
}
s.indexNodeManager.Startup(nodes)
}
return nil
}
// rewatchDataNodes is used to rewatch data nodes when datacoord is started or reconnected to etcd
// Note: it may apply the same node multiple times, so rewatchDataNodes must be idempotent
func (s *Server) rewatchDataNodes(sessions map[string]*sessionutil.Session) error {
legacyVersion, err := semver.Parse(paramtable.Get().DataCoordCfg.LegacyVersionWithoutRPCWatch.GetValue())
if err != nil {
log.Warn("DataCoord failed to init service discovery", zap.Error(err))
return err
}
datanodes := make([]*session.NodeInfo, 0, len(sessions))
for _, ss := range sessions {
info := &session.NodeInfo{
NodeID: ss.ServerID,
Address: ss.Address,
}
if ss.Version.LTE(legacyVersion) {
info.IsLegacy = true
}
datanodes = append(datanodes, info)
}
// if err := s.nodeManager.Startup(s.ctx, datanodes); err != nil {
// log.Warn("DataCoord failed to add datanode", zap.Error(err))
// return err
// }
log.Info("DataCoord Cluster Manager start up")
if err := s.cluster.Startup(s.ctx, datanodes); err != nil {
log.Warn("DataCoord Cluster Manager failed to start up", zap.Error(err))
return err
}
log.Info("DataCoord Cluster Manager start up successfully")
return nil
}
func (s *Server) initSegmentManager() error {
if s.segmentManager == nil {
manager, err := newSegmentManager(s.meta, s.allocator)


@ -24,12 +24,13 @@ import (
"os/signal"
"path"
"strconv"
"strings"
"sync"
"syscall"
"testing"
"time"
"github.com/blang/semver/v4"
"github.com/bytedance/mockey"
"github.com/cockroachdb/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
@ -68,18 +69,7 @@ import (
)
func TestMain(m *testing.M) {
// init embed etcd
embedetcdServer, tempDir, err := etcd.StartTestEmbedEtcdServer()
if err != nil {
log.Fatal("failed to start embed etcd server", zap.Error(err))
}
defer os.RemoveAll(tempDir)
defer embedetcdServer.Close()
addrs := etcd.GetEmbedEtcdEndpoints(embedetcdServer)
paramtable.Init()
paramtable.Get().Save(Params.EtcdCfg.Endpoints.Key, strings.Join(addrs, ","))
rand.Seed(time.Now().UnixNano())
parameters := []string{"tikv", "etcd"}
@ -2434,6 +2424,141 @@ func closeTestServer(t *testing.T, svr *Server) {
paramtable.Get().Reset(Params.CommonCfg.DataCoordTimeTick.Key)
}
func TestServer_rewatchQueryNodes(t *testing.T) {
server := &Server{
indexEngineVersionManager: newIndexEngineVersionManager(),
}
// Test with empty sessions
err := server.rewatchQueryNodes(map[string]*sessionutil.Session{})
assert.NoError(t, err)
// Test with valid sessions
sessions := map[string]*sessionutil.Session{
"session1": {
SessionRaw: sessionutil.SessionRaw{
ServerID: 1,
IndexEngineVersion: sessionutil.IndexEngineVersion{CurrentIndexVersion: 20, MinimalIndexVersion: 10},
},
},
"session2": {
SessionRaw: sessionutil.SessionRaw{
ServerID: 2,
IndexEngineVersion: sessionutil.IndexEngineVersion{CurrentIndexVersion: 15, MinimalIndexVersion: 5},
},
},
}
err = server.rewatchQueryNodes(sessions)
assert.NoError(t, err)
// Verify the IndexEngineVersionManager received the sessions
assert.Equal(t, int32(15), server.indexEngineVersionManager.GetCurrentIndexEngineVersion())
assert.Equal(t, int32(10), server.indexEngineVersionManager.GetMinimalIndexEngineVersion())
// Test idempotent behavior - calling again with same sessions should not cause issues
err = server.rewatchQueryNodes(sessions)
assert.NoError(t, err)
// Verify values remain the same
assert.Equal(t, int32(15), server.indexEngineVersionManager.GetCurrentIndexEngineVersion())
assert.Equal(t, int32(10), server.indexEngineVersionManager.GetMinimalIndexEngineVersion())
}
func TestServer_rewatchDataNodes_Success(t *testing.T) {
// Mock semver.Parse to avoid dependency on paramtable
mockSemverParse := mockey.Mock(semver.Parse).Return(semver.Version{}, nil).Build()
defer mockSemverParse.UnPatch()
sessions := map[string]*sessionutil.Session{
"session1": {
SessionRaw: sessionutil.SessionRaw{
ServerID: 1,
Address: "localhost:9001",
Version: "2.3.0",
},
},
"session2": {
SessionRaw: sessionutil.SessionRaw{
ServerID: 2,
Address: "localhost:9002",
Version: "2.2.0", // legacy version
},
},
}
server := &Server{
ctx: context.Background(),
}
// Create actual implementations
cluster := NewClusterImpl(nil, nil)
server.cluster = cluster
// Mock Cluster.Startup to succeed
mockClusterStartup := mockey.Mock((*ClusterImpl).Startup).Return(nil).Build()
defer mockClusterStartup.UnPatch()
err := server.rewatchDataNodes(sessions)
assert.NoError(t, err)
}
func TestServer_rewatchDataNodes_EmptySession(t *testing.T) {
// Mock semver.Parse to avoid dependency on paramtable
mockSemverParse := mockey.Mock(semver.Parse).Return(semver.Version{}, nil).Build()
defer mockSemverParse.UnPatch()
server := &Server{
ctx: context.Background(),
}
// Create actual implementations
cluster := NewClusterImpl(nil, nil)
server.cluster = cluster
// Mock Cluster.Startup for empty nodes
mockStartup := mockey.Mock((*ClusterImpl).Startup).Return(nil).Build()
defer mockStartup.UnPatch()
err := server.rewatchDataNodes(map[string]*sessionutil.Session{})
assert.NoError(t, err)
}
func TestServer_rewatchDataNodes_ClusterStartupFails(t *testing.T) {
// Mock semver.Parse to avoid dependency on paramtable
mockSemverParse := mockey.Mock(semver.Parse).Return(semver.Version{}, nil).Build()
defer mockSemverParse.UnPatch()
sessions := map[string]*sessionutil.Session{
"session1": {
SessionRaw: sessionutil.SessionRaw{
ServerID: 1,
Address: "localhost:9001",
Version: "2.3.0",
},
},
}
server := &Server{
ctx: context.Background(),
}
// Create actual implementations
cluster := NewClusterImpl(nil, nil)
server.cluster = cluster
// Mock Cluster.Startup to fail
mockStartup := mockey.Mock((*ClusterImpl).Startup).Return(errors.New("cluster startup failed")).Build()
defer mockStartup.UnPatch()
err := server.rewatchDataNodes(sessions)
assert.Error(t, err)
assert.Contains(t, err.Error(), "cluster startup failed")
}
func Test_CheckHealth(t *testing.T) {
getSessionManager := func(isHealthy bool) *session.DataNodeManagerImpl {
var client *mockDataNodeClient


@ -32,6 +32,7 @@ import (
"github.com/milvus-io/milvus/pkg/v2/util/lock"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
typeutil "github.com/milvus-io/milvus/pkg/v2/util/typeutil"
"github.com/samber/lo"
)
func defaultIndexNodeCreatorFunc(ctx context.Context, addr string, nodeID int64) (types.IndexNodeClient, error) {
@ -46,6 +47,8 @@ type WorkerManager interface {
QuerySlots() map[typeutil.UniqueID]*WorkerSlots
GetAllClients() map[typeutil.UniqueID]types.IndexNodeClient
GetClientByID(nodeID typeutil.UniqueID) (types.IndexNodeClient, bool)
Startup(nodes []*NodeInfo) error
}
type WorkerSlots struct {
@ -265,3 +268,25 @@ func (nm *IndexNodeManager) getMetrics(ctx context.Context, req *milvuspb.GetMet
}
return ret
}
func (nm *IndexNodeManager) Startup(nodes []*NodeInfo) error {
// remove nodes which no longer exist in sessions
sessionMap := lo.SliceToMap(nodes, func(node *NodeInfo) (int64, *NodeInfo) {
return node.NodeID, node
})
// remove old nodes
for nodeID := range nm.nodeClients {
if _, ok := sessionMap[nodeID]; !ok {
nm.RemoveNode(nodeID)
}
}
// add new nodes
for _, node := range nodes {
if err := nm.AddNode(node.NodeID, node.Address); err != nil {
return err
}
}
return nil
}


@ -24,6 +24,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
indexnodeclient "github.com/milvus-io/milvus/internal/distributed/indexnode/client"
"github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/pkg/v2/proto/workerpb"
@ -114,3 +115,118 @@ func TestNodeManager_StoppingNode(t *testing.T) {
assert.Equal(t, 0, len(nm.GetAllClients()))
assert.Equal(t, 0, len(nm.stoppingNodes))
}
func TestNodeManager_Startup_NewNodes(t *testing.T) {
nodeCreator := func(ctx context.Context, addr string, nodeID int64) (types.IndexNodeClient, error) {
return indexnodeclient.NewClient(ctx, addr, nodeID, paramtable.Get().DataCoordCfg.WithCredential.GetAsBool())
}
ctx := context.Background()
nm := NewNodeManager(ctx, nodeCreator)
// Define test nodes
nodes := []*NodeInfo{
{NodeID: 1, Address: "localhost:8080"},
{NodeID: 2, Address: "localhost:8081"},
}
err := nm.Startup(nodes)
assert.NoError(t, err)
// Verify nodes were added
ids := nm.GetAllClients()
assert.Len(t, ids, 2)
assert.Contains(t, ids, int64(1))
assert.Contains(t, ids, int64(2))
// Verify clients are accessible
_, ok := nm.GetClientByID(1)
assert.True(t, ok)
_, ok = nm.GetClientByID(2)
assert.True(t, ok)
}
func TestNodeManager_Startup_RemoveOldNodes(t *testing.T) {
nodeCreator := func(ctx context.Context, addr string, nodeID int64) (types.IndexNodeClient, error) {
return indexnodeclient.NewClient(ctx, addr, nodeID, paramtable.Get().DataCoordCfg.WithCredential.GetAsBool())
}
ctx := context.Background()
nm := NewNodeManager(ctx, nodeCreator)
// Add initial nodes
err := nm.AddNode(1, "localhost:8080")
assert.NoError(t, err)
err = nm.AddNode(2, "localhost:8081")
assert.NoError(t, err)
// Startup with new set of nodes (removes node 1, keeps node 2, adds node 3)
newNodes := []*NodeInfo{
{NodeID: 2, Address: "localhost:8081"}, // existing node
{NodeID: 3, Address: "localhost:8082"}, // new node
}
err = nm.Startup(newNodes)
assert.NoError(t, err)
// Verify final state
ids := nm.GetAllClients()
assert.Len(t, ids, 2)
assert.Contains(t, ids, int64(2))
assert.Contains(t, ids, int64(3))
assert.NotContains(t, ids, int64(1))
// Verify node 1 is removed
_, ok := nm.GetClientByID(1)
assert.False(t, ok)
// Verify nodes 2 and 3 are accessible
_, ok = nm.GetClientByID(2)
assert.True(t, ok)
_, ok = nm.GetClientByID(3)
assert.True(t, ok)
}
func TestNodeManager_Startup_EmptyNodes(t *testing.T) {
nodeCreator := func(ctx context.Context, addr string, nodeID int64) (types.IndexNodeClient, error) {
return indexnodeclient.NewClient(ctx, addr, nodeID, paramtable.Get().DataCoordCfg.WithCredential.GetAsBool())
}
ctx := context.Background()
nm := NewNodeManager(ctx, nodeCreator)
// Add initial node
err := nm.AddNode(1, "localhost:8080")
assert.NoError(t, err)
// Startup with empty nodes (should remove all existing nodes)
err = nm.Startup(nil)
assert.NoError(t, err)
// Verify all nodes are removed
ids := nm.GetAllClients()
assert.Empty(t, ids)
}
func TestNodeManager_Startup_AddNodeError(t *testing.T) {
nodeCreator := func(ctx context.Context, addr string, nodeID int64) (types.IndexNodeClient, error) {
if nodeID == 1 {
return nil, assert.AnError
}
return indexnodeclient.NewClient(ctx, addr, nodeID, paramtable.Get().DataCoordCfg.WithCredential.GetAsBool())
}
ctx := context.Background()
nm := NewNodeManager(ctx, nodeCreator)
nodes := []*NodeInfo{
{NodeID: 1, Address: "localhost:8080"}, // This will fail
{NodeID: 2, Address: "localhost:8081"},
}
err := nm.Startup(nodes)
assert.Error(t, err)
assert.Contains(t, err.Error(), "assert.AnError")
}


@ -309,6 +309,52 @@ func (_c *MockWorkerManager_RemoveNode_Call) RunAndReturn(run func(int64)) *Mock
return _c
}
// Startup provides a mock function with given fields: nodes
func (_m *MockWorkerManager) Startup(nodes []*NodeInfo) error {
ret := _m.Called(nodes)
if len(ret) == 0 {
panic("no return value specified for Startup")
}
var r0 error
if rf, ok := ret.Get(0).(func([]*NodeInfo) error); ok {
r0 = rf(nodes)
} else {
r0 = ret.Error(0)
}
return r0
}
// MockWorkerManager_Startup_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Startup'
type MockWorkerManager_Startup_Call struct {
*mock.Call
}
// Startup is a helper method to define mock.On call
// - nodes []*NodeInfo
func (_e *MockWorkerManager_Expecter) Startup(nodes interface{}) *MockWorkerManager_Startup_Call {
return &MockWorkerManager_Startup_Call{Call: _e.mock.On("Startup", nodes)}
}
func (_c *MockWorkerManager_Startup_Call) Run(run func(nodes []*NodeInfo)) *MockWorkerManager_Startup_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].([]*NodeInfo))
})
return _c
}
func (_c *MockWorkerManager_Startup_Call) Return(_a0 error) *MockWorkerManager_Startup_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockWorkerManager_Startup_Call) RunAndReturn(run func([]*NodeInfo) error) *MockWorkerManager_Startup_Call {
_c.Call.Return(run)
return _c
}
// StoppingNode provides a mock function with given fields: nodeID
func (_m *MockWorkerManager) StoppingNode(nodeID int64) {
_m.Called(nodeID)


@ -33,6 +33,7 @@ import (
"github.com/milvus-io/milvus/internal/querycoordv2/task"
"github.com/milvus-io/milvus/internal/querycoordv2/utils"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/metrics"
"github.com/milvus-io/milvus/pkg/v2/proto/datapb"
"github.com/milvus-io/milvus/pkg/v2/proto/querypb"
"github.com/milvus-io/milvus/pkg/v2/util/commonpbutil"
@ -125,7 +126,9 @@ func (dh *distHandler) handleDistResp(ctx context.Context, resp *querypb.GetData
log.Warn("node last heart beat time lag too behind", zap.Time("now", time.Now()),
zap.Time("lastHeartBeatTime", node.LastHeartbeat()), zap.Int64("nodeID", node.ID()))
}
node.SetLastHeartbeat(time.Now())
now := time.Now()
node.SetLastHeartbeat(now)
metrics.QueryCoordLastHeartbeatTimeStamp.WithLabelValues(fmt.Sprint(resp.GetNodeID())).Set(float64(now.UnixNano()))
// skip update dist if no distribution change happens in query node
if resp.GetLastModifyTs() != 0 && resp.GetLastModifyTs() <= dh.lastUpdateTs {


@ -18,16 +18,21 @@ package dist
import (
"context"
"fmt"
"testing"
"time"
"github.com/bytedance/mockey"
"github.com/cockroachdb/errors"
"github.com/prometheus/client_golang/prometheus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/suite"
"github.com/milvus-io/milvus/internal/querycoordv2/meta"
"github.com/milvus-io/milvus/internal/querycoordv2/session"
"github.com/milvus-io/milvus/internal/querycoordv2/task"
"github.com/milvus-io/milvus/pkg/v2/metrics"
"github.com/milvus-io/milvus/pkg/v2/proto/datapb"
"github.com/milvus-io/milvus/pkg/v2/proto/querypb"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
@ -191,6 +196,82 @@ func (suite *DistHandlerSuite) TestForcePullDist() {
time.Sleep(300 * time.Millisecond)
}
// TestHeartbeatMetricsRecording tests that heartbeat metrics are properly recorded
func TestHeartbeatMetricsRecording(t *testing.T) {
// Arrange: Create test response with a unique nodeID to avoid test interference
nodeID := time.Now().UnixNano() % 1000000 // Use timestamp-based unique ID
resp := &querypb.GetDataDistributionResponse{
Status: merr.Success(),
NodeID: nodeID,
LastModifyTs: 1,
}
// Create mock node
nodeManager := session.NewNodeManager()
nodeInfo := session.NewNodeInfo(session.ImmutableNodeInfo{
NodeID: nodeID,
Address: "localhost:19530",
Hostname: "localhost",
})
nodeManager.Add(nodeInfo)
// Mock time.Now() to get predictable timestamp
expectedTimestamp := time.Unix(1640995200, 0) // 2022-01-01 00:00:00 UTC
mockTimeNow := mockey.Mock(time.Now).Return(expectedTimestamp).Build()
defer mockTimeNow.UnPatch()
// Record the initial state of the metric for our specific nodeID
initialMetricValue := getMetricValueForNode(fmt.Sprint(nodeID))
// Create dist handler
ctx := context.Background()
handler := &distHandler{
nodeID: nodeID,
nodeManager: nodeManager,
dist: meta.NewDistributionManager(),
target: meta.NewTargetManager(nil, nil),
scheduler: task.NewScheduler(ctx, nil, nil, nil, nil, nil, nil),
}
// Act: Handle distribution response
handler.handleDistResp(ctx, resp, false)
// Assert: Verify our specific metric was recorded with the expected value
finalMetricValue := getMetricValueForNode(fmt.Sprint(nodeID))
// Check that the metric value changed and matches our expected timestamp
assert.NotEqual(t, initialMetricValue, finalMetricValue, "Metric value should have changed")
assert.Equal(t, float64(expectedTimestamp.UnixNano()), finalMetricValue, "Metric should record the expected timestamp")
// Clean up: Remove the test metric to avoid affecting other tests
metrics.QueryCoordLastHeartbeatTimeStamp.DeleteLabelValues(fmt.Sprint(nodeID))
}
// Helper function to get the current metric value for a specific nodeID
func getMetricValueForNode(nodeID string) float64 {
// Create a temporary registry to capture the current state
registry := prometheus.NewRegistry()
registry.MustRegister(metrics.QueryCoordLastHeartbeatTimeStamp)
metricFamilies, err := registry.Gather()
if err != nil {
return -1 // Return -1 if we can't gather metrics
}
for _, mf := range metricFamilies {
if mf.GetName() == "milvus_querycoord_last_heartbeat_timestamp" {
for _, metric := range mf.GetMetric() {
for _, label := range metric.GetLabel() {
if label.GetName() == "node_id" && label.GetValue() == nodeID {
return metric.GetGauge().GetValue()
}
}
}
}
}
return 0 // Return 0 if metric not found (default value)
}
func TestDistHandlerSuite(t *testing.T) {
suite.Run(t, new(DistHandlerSuite))
}


@ -465,6 +465,13 @@ func (rm *ResourceManager) HandleNodeUp(ctx context.Context, node int64) {
rm.rwmutex.Lock()
defer rm.rwmutex.Unlock()
rm.handleNodeUp(ctx, node)
}
func (rm *ResourceManager) handleNodeUp(ctx context.Context, node int64) {
if nodeInfo := rm.nodeMgr.Get(node); nodeInfo == nil || nodeInfo.IsEmbeddedQueryNodeInStreamingNode() {
return
}
rm.incomingNode.Insert(node)
// Trigger assign incoming node right away.
// error can be ignored here, because `AssignPendingIncomingNode`` will retry assign node.
@ -480,7 +487,10 @@ func (rm *ResourceManager) HandleNodeUp(ctx context.Context, node int64) {
func (rm *ResourceManager) HandleNodeDown(ctx context.Context, node int64) {
rm.rwmutex.Lock()
defer rm.rwmutex.Unlock()
rm.handleNodeDown(ctx, node)
}
func (rm *ResourceManager) handleNodeDown(ctx context.Context, node int64) {
rm.incomingNode.Remove(node)
// for stopping query node becomes offline, node change won't be triggered,
@ -500,7 +510,10 @@ func (rm *ResourceManager) HandleNodeDown(ctx context.Context, node int64) {
func (rm *ResourceManager) HandleNodeStopping(ctx context.Context, node int64) {
rm.rwmutex.Lock()
defer rm.rwmutex.Unlock()
rm.handleNodeStopping(ctx, node)
}
func (rm *ResourceManager) handleNodeStopping(ctx context.Context, node int64) {
rm.incomingNode.Remove(node)
rgName, err := rm.unassignNode(ctx, node)
log.Info("HandleNodeStopping: remove node from resource group",
@ -994,3 +1007,33 @@ func (rm *ResourceManager) GetResourceGroupsJSON(ctx context.Context) string {
return string(ret)
}
func (rm *ResourceManager) CheckNodesInResourceGroup(ctx context.Context) {
rm.rwmutex.RLock()
defer rm.rwmutex.RUnlock()
// clean stopping/offline nodes
assignedNodes := typeutil.NewUniqueSet()
for _, rg := range rm.groups {
for _, node := range rg.GetNodes() {
assignedNodes.Insert(node)
info := rm.nodeMgr.Get(node)
if info == nil {
rm.handleNodeDown(ctx, node)
} else if info.GetState() == session.NodeStateStopping {
log.Warn("node is stopping", zap.Int64("node", node))
rm.handleNodeStopping(ctx, node)
} else if info.IsEmbeddedQueryNodeInStreamingNode() {
log.Warn("unreachable code, but just for dirty meta clean up", zap.Int64("node", node))
rm.handleNodeStopping(ctx, node)
}
}
}
// add new nodes
for _, node := range rm.nodeMgr.GetAll() {
if !assignedNodes.Contain(node.ID()) {
rm.handleNodeUp(context.Background(), node.ID())
}
}
}


@ -31,6 +31,7 @@ import (
"github.com/milvus-io/milvus/internal/metastore/kv/querycoord"
"github.com/milvus-io/milvus/internal/querycoordv2/params"
"github.com/milvus-io/milvus/internal/querycoordv2/session"
"github.com/milvus-io/milvus/internal/util/sessionutil"
"github.com/milvus-io/milvus/pkg/v2/kv"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/util/etcd"
@ -221,6 +222,16 @@ func (suite *ResourceManagerSuite) TestManipulateResourceGroup() {
// RemoveResourceGroup will remove all nodes from the resource group.
err = suite.manager.RemoveResourceGroup(ctx, "rg2")
suite.NoError(err)
suite.manager.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{
NodeID: 10,
Address: "localhost",
Hostname: "localhost",
Labels: map[string]string{
sessionutil.LabelStreamingNodeEmbeddedQueryNode: "1",
},
}))
suite.manager.HandleNodeUp(ctx, 10)
}
func (suite *ResourceManagerSuite) TestNodeUpAndDown() {
@ -1005,3 +1016,208 @@ func (suite *ResourceManagerSuite) TestNodeLabels_NodeDown() {
suite.Equal("label2", suite.manager.nodeMgr.Get(node).Labels()["dc_name"])
}
}
// createTestResourceManager creates a ResourceManager for testing
func createTestResourceManager(t *testing.T) *ResourceManager {
// Create a mock catalog
mockCatalog := &mocks.MetaKv{}
mockCatalog.On("MultiSave", mock.Anything, mock.Anything).Return(nil)
// Create a mock node manager
nodeMgr := session.NewNodeManager()
// Create resource manager
store := querycoord.NewCatalog(mockCatalog)
manager := NewResourceManager(store, nodeMgr)
return manager
}
// TestResourceManager_handleNodeUp tests the private handleNodeUp method
func TestResourceManager_handleNodeUp(t *testing.T) {
// Arrange
manager := createTestResourceManager(t)
ctx := context.Background()
nodeID := int64(1001)
// Add node to node manager
nodeInfo := session.NewNodeInfo(session.ImmutableNodeInfo{
NodeID: nodeID,
Address: "localhost",
Hostname: "localhost",
})
manager.nodeMgr.Add(nodeInfo)
// Act
manager.handleNodeUp(ctx, nodeID)
// Assert
// After successful assignment, node should be removed from incomingNode
assert.False(t, manager.incomingNode.Contain(nodeID))
// Verify node was assigned to default resource group
nodes, err := manager.GetNodes(ctx, DefaultResourceGroupName)
assert.NoError(t, err)
assert.Contains(t, nodes, nodeID)
}
// TestResourceManager_handleNodeDown tests the private handleNodeDown method
func TestResourceManager_handleNodeDown(t *testing.T) {
// Arrange
manager := createTestResourceManager(t)
ctx := context.Background()
nodeID := int64(1002)
// Add node to node manager
nodeInfo := session.NewNodeInfo(session.ImmutableNodeInfo{
NodeID: nodeID,
Address: "localhost",
Hostname: "localhost",
})
manager.nodeMgr.Add(nodeInfo)
// Add node to incoming set and assign it to a resource group first
manager.handleNodeUp(ctx, nodeID)
nodes, err := manager.GetNodes(ctx, DefaultResourceGroupName)
assert.NoError(t, err)
assert.Contains(t, nodes, nodeID)
// Act
manager.handleNodeDown(ctx, nodeID)
// Assert
assert.False(t, manager.incomingNode.Contain(nodeID))
// Verify node was removed from resource group
nodes, err = manager.GetNodes(ctx, DefaultResourceGroupName)
assert.NoError(t, err)
assert.NotContains(t, nodes, nodeID)
// Verify node is no longer in nodeIDMap
_, exists := manager.nodeIDMap[nodeID]
assert.False(t, exists)
}
// TestResourceManager_handleNodeStopping tests the private handleNodeStopping method
func TestResourceManager_handleNodeStopping(t *testing.T) {
// Arrange
manager := createTestResourceManager(t)
ctx := context.Background()
nodeID := int64(1003)
// Add node to node manager
nodeInfo := session.NewNodeInfo(session.ImmutableNodeInfo{
NodeID: nodeID,
Address: "localhost",
Hostname: "localhost",
})
manager.nodeMgr.Add(nodeInfo)
// Add node to incoming set and assign it to a resource group first
manager.handleNodeUp(ctx, nodeID)
nodes, err := manager.GetNodes(ctx, DefaultResourceGroupName)
assert.NoError(t, err)
assert.Contains(t, nodes, nodeID)
// Act
manager.handleNodeStopping(ctx, nodeID)
// Assert
assert.False(t, manager.incomingNode.Contain(nodeID))
// Verify node was removed from resource group
nodes, err = manager.GetNodes(ctx, DefaultResourceGroupName)
assert.NoError(t, err)
assert.NotContains(t, nodes, nodeID)
// Verify node is no longer in nodeIDMap
_, exists := manager.nodeIDMap[nodeID]
assert.False(t, exists)
}
// TestResourceManager_CheckNodesInResourceGroup tests the CheckNodesInResourceGroup method
func TestResourceManager_CheckNodesInResourceGroup(t *testing.T) {
// Arrange
manager := createTestResourceManager(t)
// Add some nodes to node manager
nodeInfo1 := session.NewNodeInfo(session.ImmutableNodeInfo{
NodeID: 1001,
Address: "localhost:1001",
Hostname: "localhost",
})
nodeInfo2 := session.NewNodeInfo(session.ImmutableNodeInfo{
NodeID: 1002,
Address: "localhost:1002",
Hostname: "localhost",
})
nodeInfo3 := session.NewNodeInfo(session.ImmutableNodeInfo{
NodeID: 1003,
Address: "localhost:1003",
Hostname: "localhost",
})
manager.nodeMgr.Add(nodeInfo1)
manager.nodeMgr.Add(nodeInfo2)
manager.nodeMgr.Add(nodeInfo3)
// Set node 1002 as stopping
nodeInfo2.SetState(session.NodeStateStopping)
// Add nodes to default resource group
ctx := context.Background()
manager.handleNodeUp(ctx, 1001)
manager.handleNodeUp(ctx, 1002)
manager.handleNodeUp(ctx, 1004)
// Act
manager.CheckNodesInResourceGroup(ctx)
// Verify final state: offline node (1004) should be removed
finalNodes, err := manager.GetNodes(context.Background(), DefaultResourceGroupName)
assert.NoError(t, err)
assert.NotContains(t, finalNodes, int64(1004), "Offline node should be removed")
// Verify stopping node (1002) should be removed
assert.NotContains(t, finalNodes, int64(1002), "Stopping node should be removed")
// Verify healthy node (1001) should remain
assert.Contains(t, finalNodes, int64(1001), "Healthy node should remain")
// Verify new node (1003) should be added
assert.Contains(t, finalNodes, int64(1003), "New node should be added")
}
// TestResourceManager_CheckNodesInResourceGroup_AllNodesHealthy tests CheckNodesInResourceGroup with all healthy nodes
func TestResourceManager_CheckNodesInResourceGroup_AllNodesHealthy(t *testing.T) {
// Arrange
manager := createTestResourceManager(t)
// Add some healthy nodes to node manager
nodeInfo1 := session.NewNodeInfo(session.ImmutableNodeInfo{
NodeID: 1001,
Address: "localhost:1001",
Hostname: "localhost",
})
nodeInfo2 := session.NewNodeInfo(session.ImmutableNodeInfo{
NodeID: 1002,
Address: "localhost:1002",
Hostname: "localhost",
})
manager.nodeMgr.Add(nodeInfo1)
manager.nodeMgr.Add(nodeInfo2)
// Add nodes to default resource group
ctx := context.Background()
manager.handleNodeUp(ctx, 1001)
manager.handleNodeUp(ctx, 1002)
// Act
manager.CheckNodesInResourceGroup(ctx)
// Verify that healthy nodes remain unchanged
finalNodes, err := manager.GetNodes(ctx, DefaultResourceGroupName)
assert.NoError(t, err)
assert.Contains(t, finalNodes, int64(1001), "Healthy node should remain")
assert.Contains(t, finalNodes, int64(1002), "Healthy node should remain")
assert.Equal(t, 2, len(finalNodes), "Should have exactly 2 nodes")
}


@ -128,9 +128,6 @@ type Server struct {
enableActiveStandBy bool
activateFunc func() error
nodeUpEventChan chan int64
notifyNodeUp chan struct{}
// proxy client manager
proxyCreator proxyutil.ProxyCreator
proxyWatcher proxyutil.ProxyWatcherInterface
@ -142,12 +139,10 @@ type Server struct {
func NewQueryCoord(ctx context.Context) (*Server, error) {
ctx, cancel := context.WithCancel(ctx)
server := &Server{
ctx: ctx,
cancel: cancel,
nodeUpEventChan: make(chan int64, 10240),
notifyNodeUp: make(chan struct{}),
balancerMap: make(map[string]balance.Balance),
metricsRequest: metricsinfo.NewMetricsRequest(),
ctx: ctx,
cancel: cancel,
balancerMap: make(map[string]balance.Balance),
metricsRequest: metricsinfo.NewMetricsRequest(),
}
server.UpdateStateCode(commonpb.StateCode_Abnormal)
server.queryNodeCreator = session.DefaultQueryNodeCreator
@ -534,27 +529,14 @@ func (s *Server) startQueryCoord() error {
if err != nil {
return err
}
for _, node := range sessions {
s.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{
NodeID: node.ServerID,
Address: node.Address,
Hostname: node.HostName,
Version: node.Version,
Labels: node.GetServerLabel(),
}))
s.taskScheduler.AddExecutor(node.ServerID)
if node.Stopping {
s.nodeMgr.Stopping(node.ServerID)
}
}
s.checkNodeStateInRG()
for _, node := range sessions {
s.handleNodeUp(node.ServerID)
log.Info("rewatch nodes", zap.Any("sessions", sessions))
err = s.rewatchNodes(sessions)
if err != nil {
return err
}
s.wg.Add(2)
go s.handleNodeUpLoop()
s.wg.Add(1)
go s.watchNodes(revision)
// check whether old node exist, if yes suspend auto balance until all old nodes down
@ -751,7 +733,7 @@ func (s *Server) watchNodes(revision int64) {
log := log.Ctx(s.ctx)
defer s.wg.Done()
eventChan := s.session.WatchServices(typeutil.QueryNodeRole, revision+1, nil)
eventChan := s.session.WatchServices(typeutil.QueryNodeRole, revision+1, s.rewatchNodes)
for {
select {
case <-s.ctx.Done():
@ -771,14 +753,15 @@ func (s *Server) watchNodes(revision int64) {
return
}
nodeID := event.Session.ServerID
addr := event.Session.Address
log := log.With(
zap.Int64("nodeID", nodeID),
zap.String("nodeAddr", addr),
)
switch event.EventType {
case sessionutil.SessionAddEvent:
nodeID := event.Session.ServerID
addr := event.Session.Address
log.Info("add node to NodeManager",
zap.Int64("nodeID", nodeID),
zap.String("nodeAddr", addr),
)
s.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{
NodeID: nodeID,
Address: addr,
@ -786,91 +769,90 @@ func (s *Server) watchNodes(revision int64) {
Version: event.Session.Version,
Labels: event.Session.GetServerLabel(),
}))
s.nodeUpEventChan <- nodeID
select {
case s.notifyNodeUp <- struct{}{}:
default:
}
s.handleNodeUp(nodeID)
case sessionutil.SessionUpdateEvent:
nodeID := event.Session.ServerID
addr := event.Session.Address
log.Info("stopping the node",
zap.Int64("nodeID", nodeID),
zap.String("nodeAddr", addr),
)
log.Info("stopping the node")
s.nodeMgr.Stopping(nodeID)
s.checkerController.Check()
s.meta.ResourceManager.HandleNodeStopping(context.Background(), nodeID)
s.handleNodeStopping(nodeID)
case sessionutil.SessionDelEvent:
nodeID := event.Session.ServerID
log.Info("a node down, remove it", zap.Int64("nodeID", nodeID))
log.Info("a node down, remove it")
s.nodeMgr.Remove(nodeID)
s.handleNodeDown(nodeID)
s.metricsCacheManager.InvalidateSystemInfoMetrics()
}
}
}
}
func (s *Server) handleNodeUpLoop() {
log := log.Ctx(s.ctx)
defer s.wg.Done()
ticker := time.NewTicker(Params.QueryCoordCfg.CheckHealthInterval.GetAsDuration(time.Millisecond))
defer ticker.Stop()
for {
select {
case <-s.ctx.Done():
log.Info("handle node up loop exit due to context done")
return
case <-s.notifyNodeUp:
s.tryHandleNodeUp()
case <-ticker.C:
s.tryHandleNodeUp()
}
}
}
// rewatchNodes is used to re-watch nodes when querycoord restarts or reconnects to etcd
// Note: it may apply the same node multiple times, so rewatchNodes must be idempotent
func (s *Server) rewatchNodes(sessions map[string]*sessionutil.Session) error {
sessionMap := lo.MapKeys(sessions, func(s *sessionutil.Session, _ string) int64 {
return s.ServerID
})
func (s *Server) tryHandleNodeUp() {
log := log.Ctx(s.ctx).WithRateGroup("qcv2.Server", 1, 60)
ctx, cancel := context.WithTimeout(s.ctx, Params.QueryCoordCfg.CheckHealthRPCTimeout.GetAsDuration(time.Millisecond))
defer cancel()
reasons, err := s.checkNodeHealth(ctx)
if err != nil {
log.RatedWarn(10, "unhealthy node exist, node up will be delayed",
zap.Int("delayedNodeUpEvents", len(s.nodeUpEventChan)),
zap.Int("unhealthyNodeNum", len(reasons)),
zap.Strings("unhealthyReason", reasons))
return
}
for len(s.nodeUpEventChan) > 0 {
nodeID := <-s.nodeUpEventChan
if s.nodeMgr.Get(nodeID) != nil {
// only if all nodes are healthy, node up event will be handled
s.handleNodeUp(nodeID)
s.metricsCacheManager.InvalidateSystemInfoMetrics()
s.checkerController.Check()
} else {
log.Warn("node already down",
zap.Int64("nodeID", nodeID))
// first remove all offline nodes
for _, node := range s.nodeMgr.GetAll() {
nodeSession, ok := sessionMap[node.ID()]
if !ok {
// node is in node manager but its session no longer exists, which means it's offline
s.nodeMgr.Remove(node.ID())
s.handleNodeDown(node.ID())
} else if nodeSession.Stopping && !node.IsStoppingState() {
// node is in node manager and its session is stopping, so mark it as stopping
s.nodeMgr.Stopping(node.ID())
s.handleNodeStopping(node.ID())
}
}
// then add all newly online nodes
for _, nodeSession := range sessionMap {
nodeInfo := s.nodeMgr.Get(nodeSession.ServerID)
if nodeInfo == nil {
s.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{
NodeID: nodeSession.GetServerID(),
Address: nodeSession.GetAddress(),
Hostname: nodeSession.HostName,
Version: nodeSession.Version,
Labels: nodeSession.GetServerLabel(),
}))
if nodeSession.Stopping {
s.nodeMgr.Stopping(nodeSession.ServerID)
s.handleNodeStopping(nodeSession.ServerID)
} else {
s.handleNodeUp(nodeSession.GetServerID())
}
}
}
// Note: Node manager doesn't persist node list, so after query coord restart, we cannot
// update all node statuses in resource manager based on session and node manager's node list.
// Therefore, manual status checking of all nodes in resource manager is needed.
s.meta.ResourceManager.CheckNodesInResourceGroup(s.ctx)
return nil
}
func (s *Server) handleNodeUp(node int64) {
nodeInfo := s.nodeMgr.Get(node)
if nodeInfo == nil {
log.Ctx(s.ctx).Warn("node already down", zap.Int64("nodeID", node))
return
}
// add executor to task scheduler
s.taskScheduler.AddExecutor(node)
// start dist handler
s.distController.StartDistInstance(s.ctx, node)
if nodeInfo.IsEmbeddedQueryNodeInStreamingNode() {
// The querynode embedded in the streaming node can not work with streaming node.
return
}
// need assign to new rg and replica
s.meta.ResourceManager.HandleNodeUp(s.ctx, node)
s.metricsCacheManager.InvalidateSystemInfoMetrics()
s.checkerController.Check()
}
func (s *Server) handleNodeDown(node int64) {
@ -886,20 +868,21 @@ func (s *Server) handleNodeDown(node int64) {
s.taskScheduler.RemoveByNode(node)
s.meta.ResourceManager.HandleNodeDown(context.Background(), node)
// clean node's metrics
metrics.QueryCoordLastHeartbeatTimeStamp.DeleteLabelValues(fmt.Sprint(node))
s.metricsCacheManager.InvalidateSystemInfoMetrics()
}
func (s *Server) checkNodeStateInRG() {
for _, rgName := range s.meta.ListResourceGroups(s.ctx) {
rg := s.meta.ResourceManager.GetResourceGroup(s.ctx, rgName)
for _, node := range rg.GetNodes() {
info := s.nodeMgr.Get(node)
if info == nil {
s.meta.ResourceManager.HandleNodeDown(context.Background(), node)
} else if info.IsStoppingState() {
s.meta.ResourceManager.HandleNodeStopping(context.Background(), node)
}
}
}
func (s *Server) handleNodeStopping(node int64) {
// mark node as stopping in node manager
s.nodeMgr.Stopping(node)
// mark node as stopping in resource manager
s.meta.ResourceManager.HandleNodeStopping(context.Background(), node)
// trigger checker to check stopping node
s.checkerController.Check()
}
func (s *Server) updateBalanceConfigLoop(ctx context.Context) {


@ -18,6 +18,7 @@ package querycoordv2
import (
"context"
"fmt"
"math/rand"
"os"
"sync"
@ -25,6 +26,7 @@ import (
"time"
"github.com/bytedance/mockey"
"github.com/prometheus/client_golang/prometheus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/suite"
@ -45,12 +47,14 @@ import (
"github.com/milvus-io/milvus/internal/util/sessionutil"
"github.com/milvus-io/milvus/pkg/v2/common"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/metrics"
"github.com/milvus-io/milvus/pkg/v2/proto/datapb"
"github.com/milvus-io/milvus/pkg/v2/proto/querypb"
"github.com/milvus-io/milvus/pkg/v2/proto/rootcoordpb"
"github.com/milvus-io/milvus/pkg/v2/util/commonpbutil"
"github.com/milvus-io/milvus/pkg/v2/util/etcd"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/metricsinfo"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
"github.com/milvus-io/milvus/pkg/v2/util/tikv"
)
@ -186,7 +190,6 @@ func (suite *ServerSuite) TestNodeUp() {
suite.NoError(err)
defer node1.Stop()
suite.server.notifyNodeUp <- struct{}{}
suite.Eventually(func() bool {
node := suite.server.nodeMgr.Get(node1.ID)
if node == nil {
@ -200,54 +203,6 @@ func (suite *ServerSuite) TestNodeUp() {
}
return true
}, 5*time.Second, time.Second)
// mock unhealthy node
suite.server.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{
NodeID: 1001,
Address: "localhost",
Hostname: "localhost",
}))
node2 := mocks.NewMockQueryNode(suite.T(), suite.server.etcdCli, 101)
node2.EXPECT().GetDataDistribution(mock.Anything, mock.Anything).Return(&querypb.GetDataDistributionResponse{Status: merr.Success()}, nil).Maybe()
err = node2.Start()
suite.NoError(err)
defer node2.Stop()
// expect node2 won't be add to qc, due to unhealthy nodes exist
suite.server.notifyNodeUp <- struct{}{}
suite.Eventually(func() bool {
node := suite.server.nodeMgr.Get(node2.ID)
if node == nil {
return false
}
for _, collection := range suite.collections {
replica := suite.server.meta.ReplicaManager.GetByCollectionAndNode(suite.ctx, collection, node2.ID)
if replica == nil {
return true
}
}
return false
}, 5*time.Second, time.Second)
// mock unhealthy node down, so no unhealthy nodes exist
suite.server.nodeMgr.Remove(1001)
suite.server.notifyNodeUp <- struct{}{}
// expect node2 will be add to qc
suite.Eventually(func() bool {
node := suite.server.nodeMgr.Get(node2.ID)
if node == nil {
return false
}
for _, collection := range suite.collections {
replica := suite.server.meta.ReplicaManager.GetByCollectionAndNode(suite.ctx, collection, node2.ID)
if replica == nil {
return false
}
}
return true
}, 5*time.Second, time.Second)
}
func (suite *ServerSuite) TestNodeUpdate() {
@ -749,6 +704,244 @@ func (suite *ServerSuite) newQueryCoord() (*Server, error) {
return server, err
}
// TestRewatchNodes tests the rewatchNodes function behavior
func TestRewatchNodes(t *testing.T) {
// Arrange: Create simple server instance
server := createSimpleTestServer()
// Create test sessions
sessions := map[string]*sessionutil.Session{
"querynode-1001": createTestSession(1001, "localhost:19530", false),
"querynode-1002": createTestSession(1002, "localhost:19531", false),
"querynode-1003": createTestSession(1003, "localhost:19532", true), // stopping
}
// Pre-add some nodes to node manager to test removal logic
server.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{
NodeID: 1001,
Address: "localhost:19530",
Hostname: "localhost",
}))
server.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{
NodeID: 1004, // This node will be removed as it's not in sessions
Address: "localhost:19533",
Hostname: "localhost",
}))
// Mock external calls
mockHandleNodeUp := mockey.Mock((*Server).handleNodeUp).Return().Build()
defer mockHandleNodeUp.UnPatch()
mockHandleNodeDown := mockey.Mock((*Server).handleNodeDown).Return().Build()
defer mockHandleNodeDown.UnPatch()
mockHandleNodeStopping := mockey.Mock((*Server).handleNodeStopping).Return().Build()
defer mockHandleNodeStopping.UnPatch()
server.meta = &meta.Meta{
ResourceManager: meta.NewResourceManager(nil, nil),
}
mockCheckNodesInResourceGroup := mockey.Mock((*meta.ResourceManager).CheckNodesInResourceGroup).Return().Build()
defer mockCheckNodesInResourceGroup.UnPatch()
// Act: Call rewatchNodes
err := server.rewatchNodes(sessions)
// Assert: Verify no error occurred
assert.NoError(t, err)
// Verify node 1004 was removed
assert.Nil(t, server.nodeMgr.Get(1004), "Offline node should be removed")
// Verify nodes 1001, 1002 exist
assert.NotNil(t, server.nodeMgr.Get(1001), "Online node should exist")
assert.NotNil(t, server.nodeMgr.Get(1002), "Online node should exist")
assert.NotNil(t, server.nodeMgr.Get(1003), "Stopping node should exist")
}
// TestRewatchNodesWithEmptySessions tests rewatchNodes with empty sessions
func TestRewatchNodesWithEmptySessions(t *testing.T) {
// Arrange: Create server with existing nodes
server := createSimpleTestServer()
// Add some existing nodes
server.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{
NodeID: 1001,
Address: "localhost:19530",
Hostname: "localhost",
}))
server.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{
NodeID: 1002,
Address: "localhost:19531",
Hostname: "localhost",
}))
// Mock external calls
mockHandleNodeDown := mockey.Mock((*Server).handleNodeDown).Return().Build()
defer mockHandleNodeDown.UnPatch()
server.meta = &meta.Meta{
ResourceManager: meta.NewResourceManager(nil, nil),
}
mockCheckNodesInResourceGroup := mockey.Mock((*meta.ResourceManager).CheckNodesInResourceGroup).Return().Build()
defer mockCheckNodesInResourceGroup.UnPatch()
// Act: Call rewatchNodes with empty sessions
err := server.rewatchNodes(nil)
// Assert: All nodes should be removed
assert.NoError(t, err)
assert.Nil(t, server.nodeMgr.Get(1001), "All nodes should be removed when no sessions exist")
assert.Nil(t, server.nodeMgr.Get(1002), "All nodes should be removed when no sessions exist")
}
// TestHandleNodeUpWithMissingNode tests handleNodeUp when node doesn't exist
func TestHandleNodeUpWithMissingNode(t *testing.T) {
// Arrange: Create server without adding the node
server := createSimpleTestServer()
nodeID := int64(1001)
// Act: Handle node up for non-existent node
server.handleNodeUp(nodeID)
// Assert: Should handle gracefully (no panic, early return)
// The function should return early when node is not found
}
// TestHandleNodeDownMetricsCleanup tests that handleNodeDown cleans up metrics properly
func TestHandleNodeDownMetricsCleanup(t *testing.T) {
// Arrange: Set up metrics with test value
nodeID := int64(1001)
// Setup metrics with test value
registry := prometheus.NewRegistry()
metrics.RegisterQueryCoord(registry)
// Set a test metric value
metrics.QueryCoordLastHeartbeatTimeStamp.WithLabelValues(fmt.Sprint(nodeID)).Set(1640995200.0)
// Verify metric exists before deletion
metricFamilies, err := registry.Gather()
assert.NoError(t, err)
found := false
for _, mf := range metricFamilies {
if mf.GetName() == "milvus_querycoord_last_heartbeat_timestamp" {
for _, metric := range mf.GetMetric() {
for _, label := range metric.GetLabel() {
if label.GetName() == "node_id" && label.GetValue() == fmt.Sprint(nodeID) {
found = true
break
}
}
}
}
}
assert.True(t, found, "Metric should exist before cleanup")
// Create a minimal server
ctx := context.Background()
server := &Server{
ctx: ctx,
taskScheduler: task.NewScheduler(ctx, nil, nil, nil, nil, nil, nil),
dist: meta.NewDistributionManager(),
distController: dist.NewDistController(nil, nil, nil, nil, nil, nil),
metricsCacheManager: metricsinfo.NewMetricsCacheManager(),
meta: &meta.Meta{
ResourceManager: meta.NewResourceManager(nil, nil),
},
}
mockRemoveExecutor := mockey.Mock((task.Scheduler).RemoveExecutor).Return().Build()
defer mockRemoveExecutor.UnPatch()
mockRemoveByNode := mockey.Mock((task.Scheduler).RemoveByNode).Return().Build()
defer mockRemoveByNode.UnPatch()
mockDistControllerRemove := mockey.Mock((*dist.ControllerImpl).Remove).Return().Build()
defer mockDistControllerRemove.UnPatch()
mockRemoveFromManager := mockey.Mock(server.dist.ChannelDistManager.Update).Return().Build()
defer mockRemoveFromManager.UnPatch()
mockRemoveFromManager = mockey.Mock(server.dist.SegmentDistManager.Update).Return().Build()
defer mockRemoveFromManager.UnPatch()
mockInvalidateSystemInfoMetrics := mockey.Mock((*metricsinfo.MetricsCacheManager).InvalidateSystemInfoMetrics).Return().Build()
defer mockInvalidateSystemInfoMetrics.UnPatch()
mockResourceManagerHandleNodeDown := mockey.Mock((*meta.ResourceManager).HandleNodeDown).Return().Build()
defer mockResourceManagerHandleNodeDown.UnPatch()
// Act: Call handleNodeDown which should clean up metrics
server.handleNodeDown(nodeID)
metricFamilies, err = registry.Gather()
assert.NoError(t, err)
// Check that the heartbeat metric for this node was deleted
found = false
for _, mf := range metricFamilies {
if mf.GetName() == "milvus_querycoord_last_heartbeat_timestamp" {
for _, metric := range mf.GetMetric() {
for _, label := range metric.GetLabel() {
if label.GetName() == "node_id" && label.GetValue() == fmt.Sprint(nodeID) {
found = true
break
}
}
}
}
}
assert.False(t, found, "Metric should be cleaned up after handleNodeDown")
}
// TestNodeManagerStopping tests the node manager stopping functionality
func TestNodeManagerStopping(t *testing.T) {
// Arrange: Create node manager and add a node
nodeID := int64(1001)
nodeMgr := session.NewNodeManager()
nodeInfo := session.NewNodeInfo(session.ImmutableNodeInfo{
NodeID: nodeID,
Address: "localhost:19530",
Hostname: "localhost",
})
nodeMgr.Add(nodeInfo)
// Verify node exists and is not stopping initially
node := nodeMgr.Get(nodeID)
assert.NotNil(t, node)
assert.False(t, node.IsStoppingState(), "Node should not be stopping initially")
// Act: Mark node as stopping
nodeMgr.Stopping(nodeID)
// Assert: Node should be in stopping state
node = nodeMgr.Get(nodeID)
assert.NotNil(t, node)
assert.True(t, node.IsStoppingState(), "Node should be in stopping state after calling Stopping()")
}
// Helper function to create a simple test server
func createSimpleTestServer() *Server {
ctx := context.Background()
server := &Server{
ctx: ctx,
nodeMgr: session.NewNodeManager(),
}
return server
}
// Helper function to create a test session
func createTestSession(nodeID int64, address string, stopping bool) *sessionutil.Session {
session := &sessionutil.Session{
SessionRaw: sessionutil.SessionRaw{
ServerID: nodeID,
Address: address,
Stopping: stopping,
HostName: "localhost",
},
}
return session
}
func TestServer(t *testing.T) {
parameters := []string{"tikv", "etcd"}
for _, v := range parameters {


@ -221,6 +221,17 @@ func (suite *ServiceSuite) SetupTest() {
proxyClientManager: suite.proxyManager,
}
// Initialize checkerController to prevent nil pointer dereference in handleNodeUp
suite.server.checkerController = checkers.NewCheckerController(
suite.meta,
suite.dist,
suite.targetMgr,
suite.nodeMgr,
suite.taskScheduler,
suite.broker,
suite.server.getBalancerFunc,
)
suite.server.registerMetricsRequest()
suite.server.UpdateStateCode(commonpb.StateCode_Healthy)


@ -157,6 +157,14 @@ var (
Name: "replica_ro_node_total",
Help: "total read only node number of replica",
})
QueryCoordLastHeartbeatTimeStamp = prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Namespace: milvusNamespace,
Subsystem: typeutil.QueryCoordRole,
Name: "last_heartbeat_timestamp",
Help: "heartbeat timestamp of query node",
}, []string{nodeIDLabelName})
)
// RegisterQueryCoord registers QueryCoord metrics
@ -174,6 +182,7 @@ func RegisterQueryCoord(registry *prometheus.Registry) {
registry.MustRegister(QueryCoordResourceGroupInfo)
registry.MustRegister(QueryCoordResourceGroupReplicaTotal)
registry.MustRegister(QueryCoordReplicaRONodeTotal)
registry.MustRegister(QueryCoordLastHeartbeatTimeStamp)
}
func CleanQueryCoordMetricsWithCollectionID(collectionID int64) {
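
For reference, the new QueryCoordLastHeartbeatTimeStamp gauge stores time.Now().UnixNano() as a float64 (set in distHandler.handleDistResp and deleted again in handleNodeDown earlier in this diff). The following is a minimal, illustrative sketch of how a monitor could turn that stored value into a heartbeat lag; the helper below is hypothetical and not part of this change.

package main

import (
	"fmt"
	"time"
)

// heartbeatLag converts a gauge value (unix nanoseconds stored as float64, as
// written by distHandler.handleDistResp) into a lag duration. A float64 cannot
// hold full nanosecond precision, which is fine for second-level lag alerts.
func heartbeatLag(storedUnixNano float64) time.Duration {
	return time.Since(time.Unix(0, int64(storedUnixNano)))
}

func main() {
	stored := float64(time.Now().Add(-3 * time.Second).UnixNano())
	fmt.Printf("heartbeat lag ~ %v\n", heartbeatLag(stored).Round(time.Second))
}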