diff --git a/configs/milvus.yaml b/configs/milvus.yaml
index 6103b48166..fa43ae0963 100644
--- a/configs/milvus.yaml
+++ b/configs/milvus.yaml
@@ -436,6 +436,11 @@ queryCoord:
   collectionObserverInterval: 200 # the interval of collection observer
   checkExecutedFlagInterval: 100 # the interval of check executed flag to force to pull dist
   updateCollectionLoadStatusInterval: 5 # 5m, max interval of updating collection loaded status for check health
+  # Duration (in seconds) that a query node remains marked as resource exhausted after reaching resource limits.
+  # During this period, the node won't receive new loading tasks.
+  # Set to 0 to disable the penalty period.
+  resourceExhaustionPenaltyDuration: 30
+  resourceExhaustionCleanupInterval: 10 # Interval (in seconds) for cleaning up expired resource exhaustion marks on query nodes.
   cleanExcludeSegmentInterval: 60 # the time duration of clean pipeline exclude segment which used for filter invalid data, in seconds
   ip: # TCP/IP address of queryCoord. If not specified, use the first unicastable address
   port: 19531 # TCP port of queryCoord
diff --git a/internal/querycoordv2/balance/balance.go b/internal/querycoordv2/balance/balance.go
index 015f5a3533..3cb5ebe8b0 100644
--- a/internal/querycoordv2/balance/balance.go
+++ b/internal/querycoordv2/balance/balance.go
@@ -83,6 +83,13 @@ func (b *RoundRobinBalancer) AssignSegment(ctx context.Context, collectionID int
 		})
 	}
 
+	// Filter out query nodes that are currently marked as resource exhausted.
+	// These nodes have recently reported OOM or disk full errors and are under
+	// a penalty period during which they won't receive new loading tasks.
+	nodes = lo.Filter(nodes, func(node int64, _ int) bool {
+		return !b.nodeManager.IsResourceExhausted(node)
+	})
+
 	nodesInfo := b.getNodes(nodes)
 	if len(nodesInfo) == 0 {
 		return nil
@@ -103,7 +110,7 @@ func (b *RoundRobinBalancer) AssignSegment(ctx context.Context, collectionID int
 			To:      nodesInfo[i%len(nodesInfo)].ID(),
 		}
 		ret = append(ret, plan)
-		if len(ret) > balanceBatchSize {
+		if len(ret) >= balanceBatchSize {
 			break
 		}
 	}
@@ -123,6 +130,14 @@ func (b *RoundRobinBalancer) AssignChannel(ctx context.Context, collectionID int
 			return info != nil && info.GetState() == session.NodeStateNormal && versionRangeFilter(info.Version())
 		})
 	}
+
+	// Filter out query nodes that are currently marked as resource exhausted.
+	// These nodes have recently reported OOM or disk full errors and are under
+	// a penalty period during which they won't receive new loading tasks.
+	nodes = lo.Filter(nodes, func(node int64, _ int) bool {
+		return !b.nodeManager.IsResourceExhausted(node)
+	})
+
 	nodesInfo := b.getNodes(nodes)
 	if len(nodesInfo) == 0 {
 		return nil
diff --git a/internal/querycoordv2/balance/rowcount_based_balancer.go b/internal/querycoordv2/balance/rowcount_based_balancer.go
index bd4b686292..6d3b007ec6 100644
--- a/internal/querycoordv2/balance/rowcount_based_balancer.go
+++ b/internal/querycoordv2/balance/rowcount_based_balancer.go
@@ -53,6 +53,13 @@ func (b *RowCountBasedBalancer) AssignSegment(ctx context.Context, collectionID
 		})
 	}
 
+	// Filter out query nodes that are currently marked as resource exhausted.
+	// These nodes have recently reported OOM or disk full errors and are under
+	// a penalty period during which they won't receive new loading tasks.
+	nodes = lo.Filter(nodes, func(node int64, _ int) bool {
+		return !b.nodeManager.IsResourceExhausted(node)
+	})
+
 	nodeItems := b.convertToNodeItemsBySegment(nodes)
 	if len(nodeItems) == 0 {
 		return nil
@@ -77,7 +84,7 @@ func (b *RowCountBasedBalancer) AssignSegment(ctx context.Context, collectionID
 			Segment: s,
 		}
 		plans = append(plans, plan)
-		if len(plans) > balanceBatchSize {
+		if len(plans) >= balanceBatchSize {
 			break
 		}
 		// change node's score and push back
@@ -103,6 +110,13 @@ func (b *RowCountBasedBalancer) AssignChannel(ctx context.Context, collectionID
 		})
 	}
 
+	// Filter out query nodes that are currently marked as resource exhausted.
+	// These nodes have recently reported OOM or disk full errors and are under
+	// a penalty period during which they won't receive new loading tasks.
+	nodes = lo.Filter(nodes, func(node int64, _ int) bool {
+		return !b.nodeManager.IsResourceExhausted(node)
+	})
+
 	nodeItems := b.convertToNodeItemsByChannel(nodes)
 	if len(nodeItems) == 0 {
 		return nil
diff --git a/internal/querycoordv2/balance/score_based_balancer.go b/internal/querycoordv2/balance/score_based_balancer.go
index 9dcf10bf70..a35b23d545 100644
--- a/internal/querycoordv2/balance/score_based_balancer.go
+++ b/internal/querycoordv2/balance/score_based_balancer.go
@@ -73,6 +73,13 @@ func (b *ScoreBasedBalancer) assignSegment(br *balanceReport, collectionID int64
 		balanceBatchSize = paramtable.Get().QueryCoordCfg.BalanceSegmentBatchSize.GetAsInt()
 	}
 
+	// Filter out query nodes that are currently marked as resource exhausted.
+	// These nodes have recently reported OOM or disk full errors and are under
+	// a penalty period during which they won't receive new loading tasks.
+	nodes = lo.Filter(nodes, func(node int64, _ int) bool {
+		return !b.nodeManager.IsResourceExhausted(node)
+	})
+
 	// calculate each node's score
 	nodeItemsMap := b.convertToNodeItemsBySegment(br, collectionID, nodes)
 	if len(nodeItemsMap) == 0 {
@@ -139,7 +146,7 @@ func (b *ScoreBasedBalancer) assignSegment(br *balanceReport, collectionID int64
 			targetNode.AddCurrentScoreDelta(scoreChanges)
 		}(s)
 
-		if len(plans) > balanceBatchSize {
+		if len(plans) >= balanceBatchSize {
 			break
 		}
 	}
@@ -167,6 +174,13 @@ func (b *ScoreBasedBalancer) assignChannel(br *balanceReport, collectionID int64
 		balanceBatchSize = paramtable.Get().QueryCoordCfg.BalanceChannelBatchSize.GetAsInt()
 	}
 
+	// Filter out query nodes that are currently marked as resource exhausted.
+	// These nodes have recently reported OOM or disk full errors and are under
+	// a penalty period during which they won't receive new loading tasks.
+	nodes = lo.Filter(nodes, func(node int64, _ int) bool {
+		return !b.nodeManager.IsResourceExhausted(node)
+	})
+
 	// calculate each node's score
 	nodeItemsMap := b.convertToNodeItemsByChannel(br, collectionID, nodes)
 	if len(nodeItemsMap) == 0 {
@@ -238,7 +252,7 @@ func (b *ScoreBasedBalancer) assignChannel(br *balanceReport, collectionID int64
 			targetNode.AddCurrentScoreDelta(scoreChanges)
 		}(ch)
 
-		if len(plans) > balanceBatchSize {
+		if len(plans) >= balanceBatchSize {
 			break
 		}
 	}
@@ -640,7 +654,7 @@ func (b *ScoreBasedBalancer) genChannelPlan(ctx context.Context, br *balanceRepo
 		return nil
 	}
 
-	log.Ctx(ctx).WithRateGroup(fmt.Sprintf("genSegmentPlan-%d-%d", replica.GetCollectionID(), replica.GetID()), 1, 60).
+	log.Ctx(ctx).WithRateGroup(fmt.Sprintf("genChannelPlan-%d-%d", replica.GetCollectionID(), replica.GetID()), 1, 60).
RatedInfo(30, "node channel workload status", zap.Int64("collectionID", replica.GetCollectionID()), zap.Int64("replicaID", replica.GetID()), @@ -651,7 +665,6 @@ func (b *ScoreBasedBalancer) genChannelPlan(ctx context.Context, br *balanceRepo channelDist[node] = b.dist.ChannelDistManager.GetByFilter(meta.WithCollectionID2Channel(replica.GetCollectionID()), meta.WithNodeID2Channel(node)) } - balanceBatchSize := paramtable.Get().QueryCoordCfg.BalanceSegmentBatchSize.GetAsInt() // find the segment from the node which has more score than the average channelsToMove := make([]*meta.DmChannel, 0) for node, channels := range channelDist { @@ -667,10 +680,6 @@ func (b *ScoreBasedBalancer) genChannelPlan(ctx context.Context, br *balanceRepo channelScore := b.calculateChannelScore(ch, replica.GetCollectionID()) br.AddRecord(StrRecordf("pick channel %s with score %f from node %d", ch.GetChannelName(), channelScore, node)) channelsToMove = append(channelsToMove, ch) - if len(channelsToMove) >= balanceBatchSize { - br.AddRecord(StrRecordf("stop add channel candidate since current plan is equal to batch max(%d)", balanceBatchSize)) - break - } currentScore -= channelScore if currentScore <= assignedScore { diff --git a/internal/querycoordv2/balance/score_based_balancer_test.go b/internal/querycoordv2/balance/score_based_balancer_test.go index d96e82a2f8..470603d67b 100644 --- a/internal/querycoordv2/balance/score_based_balancer_test.go +++ b/internal/querycoordv2/balance/score_based_balancer_test.go @@ -19,6 +19,7 @@ import ( "context" "fmt" "testing" + "time" "github.com/samber/lo" "github.com/stretchr/testify/mock" @@ -82,6 +83,8 @@ func (suite *ScoreBasedBalancerTestSuite) SetupTest() { suite.mockScheduler.EXPECT().GetChannelTaskDelta(mock.Anything, mock.Anything).Return(0).Maybe() suite.mockScheduler.EXPECT().GetSegmentTaskNum(mock.Anything, mock.Anything).Return(0).Maybe() suite.mockScheduler.EXPECT().GetChannelTaskNum(mock.Anything, mock.Anything).Return(0).Maybe() + + paramtable.Get().Save(paramtable.Get().QueryCoordCfg.CollectionBalanceChannelBatchSize.Key, "5") } func (suite *ScoreBasedBalancerTestSuite) TearDownTest() { @@ -120,8 +123,6 @@ func (suite *ScoreBasedBalancerTestSuite) TestAssignSegment() { unstableAssignment: true, expectPlans: [][]SegmentAssignPlan{ { - // as assign segments is used while loading collection, - // all assignPlan should have weight equal to 1(HIGH PRIORITY) {Segment: &meta.Segment{SegmentInfo: &datapb.SegmentInfo{ ID: 3, NumOfRows: 15, CollectionID: 1, @@ -137,6 +138,34 @@ func (suite *ScoreBasedBalancerTestSuite) TestAssignSegment() { }, }, }, + { + name: "test assigning segments with resource exhausted nodes", + comment: "this case verifies that segments won't be assigned to resource exhausted nodes", + distributions: map[int64][]*meta.Segment{}, + assignments: [][]*meta.Segment{ + { + {SegmentInfo: &datapb.SegmentInfo{ID: 1, NumOfRows: 5, CollectionID: 1}}, + {SegmentInfo: &datapb.SegmentInfo{ID: 2, NumOfRows: 10, CollectionID: 1}}, + }, + }, + nodes: []int64{1, 2}, + collectionIDs: []int64{0}, + states: []session.State{session.NodeStateNormal, session.NodeStateNormal, session.NodeStateNormal}, + segmentCnts: []int{0, 0, 0}, + unstableAssignment: false, + expectPlans: [][]SegmentAssignPlan{ + { + {Segment: &meta.Segment{SegmentInfo: &datapb.SegmentInfo{ + ID: 2, NumOfRows: 10, + CollectionID: 1, + }}, From: -1, To: 2}, + {Segment: &meta.Segment{SegmentInfo: &datapb.SegmentInfo{ + ID: 1, NumOfRows: 5, + CollectionID: 1, + }}, From: -1, To: 2}, + }, + }, + }, { 
name: "test non-empty cluster assigning one collection", comment: "this case will verify the effect of global row for loading segments process, although node1" + @@ -244,6 +273,12 @@ func (suite *ScoreBasedBalancerTestSuite) TestAssignSegment() { nodeInfo.SetState(c.states[i]) suite.balancer.nodeManager.Add(nodeInfo) } + + // Mock resource exhausted node for the specific test case + if c.name == "test assigning segments with resource exhausted nodes" { + suite.balancer.nodeManager.MarkResourceExhaustion(c.nodes[0], time.Hour) + } + for i := range c.collectionIDs { plans := balancer.AssignSegment(ctx, c.collectionIDs[i], c.assignments[i], c.nodes, false) if c.unstableAssignment { @@ -1622,3 +1657,124 @@ func (suite *ScoreBasedBalancerTestSuite) TestBalanceChannelOnStoppingNode() { suite.Equal(node2Counter.Load(), int32(5)) suite.Equal(node3Counter.Load(), int32(5)) } + +func (suite *ScoreBasedBalancerTestSuite) TestAssignChannel() { + ctx := context.Background() + cases := []struct { + name string + nodes []int64 + collectionID int64 + replicaID int64 + channels []*datapb.VchannelInfo + states []session.State + distributions map[int64][]*meta.DmChannel + expectPlans []ChannelAssignPlan + unstableAssignment bool + }{ + { + name: "test empty cluster assigning channels", + nodes: []int64{1, 2, 3}, + collectionID: 1, + replicaID: 1, + channels: []*datapb.VchannelInfo{ + {CollectionID: 1, ChannelName: "channel1"}, + {CollectionID: 1, ChannelName: "channel2"}, + {CollectionID: 1, ChannelName: "channel3"}, + }, + states: []session.State{session.NodeStateNormal, session.NodeStateNormal, session.NodeStateNormal}, + distributions: map[int64][]*meta.DmChannel{}, + unstableAssignment: true, + expectPlans: []ChannelAssignPlan{ + {Channel: &meta.DmChannel{VchannelInfo: &datapb.VchannelInfo{CollectionID: 1, ChannelName: "channel1"}}, From: -1, To: 1}, + {Channel: &meta.DmChannel{VchannelInfo: &datapb.VchannelInfo{CollectionID: 1, ChannelName: "channel2"}}, From: -1, To: 2}, + {Channel: &meta.DmChannel{VchannelInfo: &datapb.VchannelInfo{CollectionID: 1, ChannelName: "channel3"}}, From: -1, To: 3}, + }, + }, + { + name: "test assigning channels with resource exhausted nodes", + nodes: []int64{1, 2, 3}, + collectionID: 1, + replicaID: 1, + channels: []*datapb.VchannelInfo{ + {CollectionID: 1, ChannelName: "channel1"}, + {CollectionID: 1, ChannelName: "channel2"}, + }, + states: []session.State{session.NodeStateNormal, session.NodeStateNormal, session.NodeStateNormal}, + distributions: map[int64][]*meta.DmChannel{}, + expectPlans: []ChannelAssignPlan{ + {Channel: &meta.DmChannel{VchannelInfo: &datapb.VchannelInfo{CollectionID: 1, ChannelName: "channel1"}}, From: -1, To: 2}, + {Channel: &meta.DmChannel{VchannelInfo: &datapb.VchannelInfo{CollectionID: 1, ChannelName: "channel2"}}, From: -1, To: 3}, + }, + }, + { + name: "test non-empty cluster assigning channels", + nodes: []int64{1, 2, 3}, + collectionID: 1, + replicaID: 1, + channels: []*datapb.VchannelInfo{ + {CollectionID: 1, ChannelName: "channel4"}, + {CollectionID: 1, ChannelName: "channel5"}, + }, + states: []session.State{session.NodeStateNormal, session.NodeStateNormal, session.NodeStateNormal}, + distributions: map[int64][]*meta.DmChannel{ + 1: { + {VchannelInfo: &datapb.VchannelInfo{CollectionID: 1, ChannelName: "channel1"}, Node: 1}, + {VchannelInfo: &datapb.VchannelInfo{CollectionID: 1, ChannelName: "channel2"}, Node: 1}, + }, + 2: { + {VchannelInfo: &datapb.VchannelInfo{CollectionID: 1, ChannelName: "channel3"}, Node: 2}, + }, + }, + expectPlans: 
+			expectPlans: []ChannelAssignPlan{
+				{Channel: &meta.DmChannel{VchannelInfo: &datapb.VchannelInfo{CollectionID: 1, ChannelName: "channel4"}}, From: -1, To: 3},
+				{Channel: &meta.DmChannel{VchannelInfo: &datapb.VchannelInfo{CollectionID: 1, ChannelName: "channel5"}}, From: -1, To: 2},
+			},
+		},
+	}
+
+	for _, c := range cases {
+		suite.Run(c.name, func() {
+			suite.SetupSuite()
+			defer suite.TearDownTest()
+			balancer := suite.balancer
+
+			// Set up nodes
+			for i := range c.nodes {
+				nodeInfo := session.NewNodeInfo(session.ImmutableNodeInfo{
+					NodeID:   c.nodes[i],
+					Address:  "127.0.0.1:0",
+					Hostname: "localhost",
+				})
+				nodeInfo.UpdateStats(session.WithChannelCnt(len(c.distributions[c.nodes[i]])))
+				nodeInfo.SetState(c.states[i])
+				suite.balancer.nodeManager.Add(nodeInfo)
+			}
+
+			// Mark node 1 as resource exhausted for the specific test case
+			if c.name == "test assigning channels with resource exhausted nodes" {
+				suite.balancer.nodeManager.MarkResourceExhaustion(c.nodes[0], time.Hour)
+			}
+
+			// Set up channel distributions
+			for node, channels := range c.distributions {
+				balancer.dist.ChannelDistManager.Update(node, channels...)
+			}
+
+			// Convert VchannelInfo to DmChannel
+			dmChannels := make([]*meta.DmChannel, 0, len(c.channels))
+			for _, ch := range c.channels {
+				dmChannels = append(dmChannels, &meta.DmChannel{
+					VchannelInfo: ch,
+				})
+			}
+
+			// Test channel assignment
+			plans := balancer.AssignChannel(ctx, c.collectionID, dmChannels, c.nodes, true)
+			if c.unstableAssignment {
+				suite.Len(plans, len(c.expectPlans))
+			} else {
+				assertChannelAssignPlanElementMatch(&suite.Suite, c.expectPlans, plans)
+			}
+		})
+	}
+}
diff --git a/internal/querycoordv2/server.go b/internal/querycoordv2/server.go
index e6aea2bec5..288382d26f 100644
--- a/internal/querycoordv2/server.go
+++ b/internal/querycoordv2/server.go
@@ -293,6 +293,7 @@ func (s *Server) initQueryCoord() error {
 
 	// Init meta
 	s.nodeMgr = session.NewNodeManager()
+	s.nodeMgr.Start(s.ctx)
 	err = s.initMeta()
 	if err != nil {
 		return err
diff --git a/internal/querycoordv2/session/node_manager.go b/internal/querycoordv2/session/node_manager.go
index ff8ca753a7..07e6b45725 100644
--- a/internal/querycoordv2/session/node_manager.go
+++ b/internal/querycoordv2/session/node_manager.go
@@ -17,6 +17,7 @@ package session
 
 import (
+	"context"
 	"fmt"
 	"sync"
 	"time"
 
@@ -25,7 +26,9 @@ import (
 	"go.uber.org/atomic"
 
 	"github.com/milvus-io/milvus/internal/util/sessionutil"
+	"github.com/milvus-io/milvus/pkg/v2/log"
 	"github.com/milvus-io/milvus/pkg/v2/metrics"
+	"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
 )
 
 type Manager interface {
@@ -35,13 +38,16 @@ type Manager interface {
 	Get(nodeID int64) *NodeInfo
 	GetAll() []*NodeInfo
 
-	Suspend(nodeID int64) error
-	Resume(nodeID int64) error
+	MarkResourceExhaustion(nodeID int64, duration time.Duration)
+	IsResourceExhausted(nodeID int64) bool
+	ClearExpiredResourceExhaustion()
+	Start(ctx context.Context)
 }
 
 type NodeManager struct {
-	mu    sync.RWMutex
-	nodes map[int64]*NodeInfo
+	mu        sync.RWMutex
+	nodes     map[int64]*NodeInfo
+	startOnce sync.Once
 }
 
 func (m *NodeManager) Add(node *NodeInfo) {
@@ -135,6 +141,12 @@ type NodeInfo struct {
 	immutableInfo ImmutableNodeInfo
 	state         State
 	lastHeartbeat *atomic.Int64
+
+	// resourceExhaustionExpireAt is the timestamp when the resource exhaustion penalty expires.
+	// When a query node reports resource exhaustion (OOM, disk full, etc.), it gets marked
+	// with a penalty duration during which it won't receive new loading tasks.
+	// Zero value means no active penalty.
+	resourceExhaustionExpireAt time.Time
 }
 
 func (n *NodeInfo) ID() int64 {
@@ -258,3 +270,97 @@ func WithCPUNum(num int64) StatsOption {
 		n.setCPUNum(num)
 	}
 }
+
+// MarkResourceExhaustion marks a query node as resource exhausted for the specified duration.
+// During this period, the node won't receive new segment/channel loading tasks.
+// If duration is 0 or negative, the resource exhaustion mark is cleared immediately.
+// This is typically called when a query node reports resource exhaustion errors (OOM, disk full, etc.).
+func (m *NodeManager) MarkResourceExhaustion(nodeID int64, duration time.Duration) {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+
+	if node, ok := m.nodes[nodeID]; ok {
+		node.mu.Lock()
+		if duration > 0 {
+			node.resourceExhaustionExpireAt = time.Now().Add(duration)
+		} else {
+			node.resourceExhaustionExpireAt = time.Time{}
+		}
+		node.mu.Unlock()
+	}
+}
+
+// IsResourceExhausted checks if a query node is currently marked as resource exhausted.
+// Returns true if the node has an active (non-expired) resource exhaustion mark.
+// This is a pure read-only operation with no side effects - expired marks are not
+// automatically cleared here. Use ClearExpiredResourceExhaustion for cleanup.
+func (m *NodeManager) IsResourceExhausted(nodeID int64) bool {
+	m.mu.RLock()
+	node := m.nodes[nodeID]
+	m.mu.RUnlock()
+
+	if node == nil {
+		return false
+	}
+
+	node.mu.RLock()
+	defer node.mu.RUnlock()
+
+	return !node.resourceExhaustionExpireAt.IsZero() &&
+		time.Now().Before(node.resourceExhaustionExpireAt)
+}
+
+// ClearExpiredResourceExhaustion iterates through all nodes and clears any expired
+// resource exhaustion marks. This is called periodically by the cleanup loop started
+// via Start(). It only clears marks that have already expired; active marks are preserved.
+func (m *NodeManager) ClearExpiredResourceExhaustion() {
+	m.mu.RLock()
+	nodes := make([]*NodeInfo, 0, len(m.nodes))
+	for _, node := range m.nodes {
+		nodes = append(nodes, node)
+	}
+	m.mu.RUnlock()
+
+	now := time.Now()
+	for _, node := range nodes {
+		node.mu.Lock()
+		if !node.resourceExhaustionExpireAt.IsZero() && !now.Before(node.resourceExhaustionExpireAt) {
+			node.resourceExhaustionExpireAt = time.Time{}
+		}
+		node.mu.Unlock()
+	}
+}
+
+// Start begins the background cleanup loop for expired resource exhaustion marks.
+// The cleanup interval is controlled by queryCoord.resourceExhaustionCleanupInterval config.
+// The loop will stop when the provided context is canceled.
+// This method is idempotent - multiple calls will only start one cleanup loop.
+func (m *NodeManager) Start(ctx context.Context) {
+	m.startOnce.Do(func() {
+		go m.cleanupLoop(ctx)
+	})
+}
+
+// cleanupLoop is the internal goroutine that periodically clears expired resource
+// exhaustion marks from all nodes. It supports dynamic interval refresh.
+func (m *NodeManager) cleanupLoop(ctx context.Context) {
+	interval := paramtable.Get().QueryCoordCfg.ResourceExhaustionCleanupInterval.GetAsDuration(time.Second)
+	ticker := time.NewTicker(interval)
+	defer ticker.Stop()
+
+	for {
+		select {
+		case <-ctx.Done():
+			log.Info("cleanupLoop stopped")
+			return
+		case <-ticker.C:
+			m.ClearExpiredResourceExhaustion()
+			// Support dynamic interval refresh
+			newInterval := paramtable.Get().QueryCoordCfg.ResourceExhaustionCleanupInterval.GetAsDuration(time.Second)
+			if newInterval != interval {
+				interval = newInterval
+				ticker.Reset(interval)
+			}
+		}
+	}
+}
diff --git a/internal/querycoordv2/session/node_manager_test.go b/internal/querycoordv2/session/node_manager_test.go
index 4a4be94ce5..29b83ce212 100644
--- a/internal/querycoordv2/session/node_manager_test.go
+++ b/internal/querycoordv2/session/node_manager_test.go
@@ -66,6 +66,96 @@ func (s *NodeManagerSuite) TestNodeOperation() {
 	s.False(s.nodeManager.IsStoppingNode(2))
 }
 
+func (s *NodeManagerSuite) TestResourceExhaustion() {
+	nodeID := int64(1)
+	s.nodeManager.Add(NewNodeInfo(ImmutableNodeInfo{NodeID: nodeID}))
+
+	s.Run("mark_exhausted", func() {
+		s.nodeManager.MarkResourceExhaustion(nodeID, 10*time.Minute)
+		s.True(s.nodeManager.IsResourceExhausted(nodeID))
+	})
+
+	s.Run("expired_without_cleanup", func() {
+		// IsResourceExhausted is pure read-only, does not auto-clear
+		s.nodeManager.MarkResourceExhaustion(nodeID, 1*time.Millisecond)
+		time.Sleep(2 * time.Millisecond)
+		// After expiry, IsResourceExhausted returns false (pure check)
+		s.False(s.nodeManager.IsResourceExhausted(nodeID))
+	})
+
+	s.Run("clear_expired", func() {
+		// Set expired mark
+		s.nodeManager.MarkResourceExhaustion(nodeID, 1*time.Millisecond)
+		time.Sleep(2 * time.Millisecond)
+		// ClearExpiredResourceExhaustion should clear expired marks
+		s.nodeManager.ClearExpiredResourceExhaustion()
+		s.False(s.nodeManager.IsResourceExhausted(nodeID))
+	})
+
+	s.Run("clear_does_not_affect_active", func() {
+		// Set active mark
+		s.nodeManager.MarkResourceExhaustion(nodeID, 10*time.Minute)
+		// ClearExpiredResourceExhaustion should not clear active marks
+		s.nodeManager.ClearExpiredResourceExhaustion()
+		s.True(s.nodeManager.IsResourceExhausted(nodeID))
+	})
+
+	s.Run("invalid_node", func() {
+		s.False(s.nodeManager.IsResourceExhausted(999))
+	})
+
+	s.Run("mark_non_existent_node", func() {
+		// MarkResourceExhaustion on non-existent node should not panic
+		s.nodeManager.MarkResourceExhaustion(999, 10*time.Minute)
+		s.False(s.nodeManager.IsResourceExhausted(999))
+	})
+
+	s.Run("clear_mark_with_zero_duration", func() {
+		// Mark the node as exhausted
+		s.nodeManager.MarkResourceExhaustion(nodeID, 10*time.Minute)
+		s.True(s.nodeManager.IsResourceExhausted(nodeID))
+		// Clear the mark by setting duration to 0
+		s.nodeManager.MarkResourceExhaustion(nodeID, 0)
+		s.False(s.nodeManager.IsResourceExhausted(nodeID))
+	})
+
+	s.Run("clear_mark_with_negative_duration", func() {
+		// Mark the node as exhausted
+		s.nodeManager.MarkResourceExhaustion(nodeID, 10*time.Minute)
+		s.True(s.nodeManager.IsResourceExhausted(nodeID))
+		// Clear the mark by setting negative duration
+		s.nodeManager.MarkResourceExhaustion(nodeID, -1*time.Second)
+		s.False(s.nodeManager.IsResourceExhausted(nodeID))
+	})
+
+	s.Run("multiple_nodes_cleanup", func() {
+		// Add more nodes
+		nodeID2 := int64(2)
+		nodeID3 := int64(3)
+		s.nodeManager.Add(NewNodeInfo(ImmutableNodeInfo{NodeID: nodeID2}))
+		s.nodeManager.Add(NewNodeInfo(ImmutableNodeInfo{NodeID: nodeID3}))
+
+		// Mark all nodes as exhausted with different durations
+		s.nodeManager.MarkResourceExhaustion(nodeID, 1*time.Millisecond)  // will expire
+		s.nodeManager.MarkResourceExhaustion(nodeID2, 10*time.Minute)     // won't expire
+		s.nodeManager.MarkResourceExhaustion(nodeID3, 1*time.Millisecond) // will expire
+
+		time.Sleep(2 * time.Millisecond)
+
+		// Before cleanup, check status
+		s.False(s.nodeManager.IsResourceExhausted(nodeID))  // expired
+		s.True(s.nodeManager.IsResourceExhausted(nodeID2))  // still active
+		s.False(s.nodeManager.IsResourceExhausted(nodeID3)) // expired
+
+		// Cleanup should clear expired marks only
+		s.nodeManager.ClearExpiredResourceExhaustion()
+
+		s.False(s.nodeManager.IsResourceExhausted(nodeID))
+		s.True(s.nodeManager.IsResourceExhausted(nodeID2)) // still active
+		s.False(s.nodeManager.IsResourceExhausted(nodeID3))
+	})
+}
+
 func (s *NodeManagerSuite) TestNodeInfo() {
 	node := NewNodeInfo(ImmutableNodeInfo{
 		NodeID: 1,
diff --git a/internal/querycoordv2/task/scheduler.go b/internal/querycoordv2/task/scheduler.go
index 2c2a890d3e..fcff28fef2 100644
--- a/internal/querycoordv2/task/scheduler.go
+++ b/internal/querycoordv2/task/scheduler.go
@@ -40,6 +40,7 @@ import (
 	"github.com/milvus-io/milvus/pkg/v2/util/hardware"
 	"github.com/milvus-io/milvus/pkg/v2/util/lock"
 	"github.com/milvus-io/milvus/pkg/v2/util/merr"
+	"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
 	"github.com/milvus-io/milvus/pkg/v2/util/timerecord"
 	. "github.com/milvus-io/milvus/pkg/v2/util/typeutil"
 )
@@ -1000,6 +1001,21 @@ func (scheduler *taskScheduler) remove(task Task) {
 		scheduler.targetMgr.UpdateCollectionNextTarget(scheduler.ctx, task.CollectionID())
 	}
 
+	// If task failed due to resource exhaustion (OOM, disk full, GPU OOM, etc.),
+	// mark the node as resource exhausted for a penalty period.
+	// During this period, the balancer will skip this node when assigning new segments/channels.
+	// This prevents continuous failures on the same node and allows it time to recover.
+	if errors.Is(task.Err(), merr.ErrSegmentRequestResourceFailed) {
+		for _, action := range task.Actions() {
+			if action.Type() == ActionTypeGrow {
+				nodeID := action.Node()
+				duration := paramtable.Get().QueryCoordCfg.ResourceExhaustionPenaltyDuration.GetAsDuration(time.Second)
+				scheduler.nodeMgr.MarkResourceExhaustion(nodeID, duration)
+				log.Info("mark resource exhaustion for node", zap.Int64("nodeID", nodeID), zap.Duration("duration", duration), zap.Error(task.Err()))
+			}
+		}
+	}
+
 	task.Cancel(nil)
 	_, ok := scheduler.tasks.GetAndRemove(task.ID())
 	scheduler.waitQueue.Remove(task)
diff --git a/internal/querycoordv2/task/task_test.go b/internal/querycoordv2/task/task_test.go
index 48df05c18e..b7b1b7abd5 100644
--- a/internal/querycoordv2/task/task_test.go
+++ b/internal/querycoordv2/task/task_test.go
@@ -2014,6 +2014,28 @@ func (suite *TaskSuite) TestRemoveTaskWithError() {
 	// when try to remove task with ErrSegmentNotFound, should trigger UpdateNextTarget
 	scheduler.remove(task1)
 	mockTarget.AssertExpectations(suite.T())
+
+	// test remove task with ErrSegmentRequestResourceFailed
+	task2, err := NewSegmentTask(
+		ctx,
+		10*time.Second,
+		WrapIDSource(0),
+		coll,
+		suite.replica,
+		commonpb.LoadPriority_LOW,
+		NewSegmentActionWithScope(nodeID, ActionTypeGrow, "", 1, querypb.DataScope_Historical, 100),
+	)
+	suite.NoError(err)
+	err = scheduler.Add(task2)
+	suite.NoError(err)
+
+	task2.Fail(merr.ErrSegmentRequestResourceFailed)
+	paramtable.Get().Save(paramtable.Get().QueryCoordCfg.ResourceExhaustionPenaltyDuration.Key, "3")
+	scheduler.remove(task2)
+	suite.True(suite.nodeMgr.IsResourceExhausted(nodeID))
+	// expect the penalty duration is expired
+	time.Sleep(3 * time.Second)
+	suite.False(suite.nodeMgr.IsResourceExhausted(nodeID))
 }
 
 func TestTask(t *testing.T) {
diff --git a/internal/querynodev2/segments/segment_loader.go b/internal/querynodev2/segments/segment_loader.go
index 0c3b64d25f..487f01c7aa 100644
--- a/internal/querynodev2/segments/segment_loader.go
+++ b/internal/querynodev2/segments/segment_loader.go
@@ -1573,19 +1573,22 @@ func (loader *segmentLoader) checkLogicalSegmentSize(ctx context.Context, segmen
 	logicalDiskUsageLimit := uint64(float64(paramtable.Get().QueryNodeCfg.DiskCapacityLimit.GetAsInt64()) * paramtable.Get().QueryNodeCfg.MaxDiskUsagePercentage.GetAsFloat())
 
 	if predictLogicalMemUsage > logicalMemUsageLimit {
-		return 0, 0, fmt.Errorf("Logical memory usage checking for segment loading failed, predictLogicalMemUsage = %v MB, LogicalMemUsageLimit = %v MB, decrease the evictableMemoryCacheRatio (current: %v) if you want to load more segments",
-			logutil.ToMB(float64(predictLogicalMemUsage)),
-			logutil.ToMB(float64(logicalMemUsageLimit)),
-			paramtable.Get().QueryNodeCfg.TieredEvictableMemoryCacheRatio.GetAsFloat(),
+		log.Warn("logical memory usage checking for segment loading failed",
+			zap.String("resourceType", "Memory"),
+			zap.Float64("predictLogicalMemUsageMB", logutil.ToMB(float64(predictLogicalMemUsage))),
+			zap.Float64("logicalMemUsageLimitMB", logutil.ToMB(float64(logicalMemUsageLimit))),
+			zap.Float64("evictableMemoryCacheRatio", paramtable.Get().QueryNodeCfg.TieredEvictableMemoryCacheRatio.GetAsFloat()),
 		)
+		return 0, 0, merr.WrapErrSegmentRequestResourceFailed("Memory")
 	}
 
 	if predictLogicalDiskUsage > logicalDiskUsageLimit {
-		return 0, 0, fmt.Errorf("Logical disk usage checking for segment loading failed, predictLogicalDiskUsage = %v MB, LogicalDiskUsageLimit = %v MB, decrease the evictableDiskCacheRatio (current: %v) if you want to load more segments",
log.Warn(fmt.Sprintf("Logical disk usage checking for segment loading failed, predictLogicalDiskUsage = %v MB, LogicalDiskUsageLimit = %v MB, decrease the evictableDiskCacheRatio (current: %v) if you want to load more segments", logutil.ToMB(float64(predictLogicalDiskUsage)), logutil.ToMB(float64(logicalDiskUsageLimit)), paramtable.Get().QueryNodeCfg.TieredEvictableDiskCacheRatio.GetAsFloat(), - ) + )) + return 0, 0, merr.WrapErrSegmentRequestResourceFailed("Disk") } return predictLogicalMemUsage - logicalMemUsage, predictLogicalDiskUsage - logicalDiskUsage, nil @@ -1681,20 +1684,26 @@ func (loader *segmentLoader) checkSegmentSize(ctx context.Context, segmentLoadIn } else { // fallback to original segment loading logic if predictMemUsage > uint64(float64(totalMem)*paramtable.Get().QueryNodeCfg.OverloadedMemoryThresholdPercentage.GetAsFloat()) { - return 0, 0, fmt.Errorf("load segment failed, OOM if load, maxSegmentSize = %v MB, memUsage = %v MB, predictMemUsage = %v MB, totalMem = %v MB thresholdFactor = %f", - logutil.ToMB(float64(maxSegmentSize)), - logutil.ToMB(float64(memUsage)), - logutil.ToMB(float64(predictMemUsage)), - logutil.ToMB(float64(totalMem)), - paramtable.Get().QueryNodeCfg.OverloadedMemoryThresholdPercentage.GetAsFloat()) + log.Warn("load segment failed, OOM if load", + zap.String("resourceType", "Memory"), + zap.Float64("maxSegmentSizeMB", logutil.ToMB(float64(maxSegmentSize))), + zap.Float64("memUsageMB", logutil.ToMB(float64(memUsage))), + zap.Float64("predictMemUsageMB", logutil.ToMB(float64(predictMemUsage))), + zap.Float64("totalMemMB", logutil.ToMB(float64(totalMem))), + zap.Float64("thresholdFactor", paramtable.Get().QueryNodeCfg.OverloadedMemoryThresholdPercentage.GetAsFloat()), + ) + return 0, 0, merr.WrapErrSegmentRequestResourceFailed("Memory") } if predictDiskUsage > uint64(float64(paramtable.Get().QueryNodeCfg.DiskCapacityLimit.GetAsInt64())*paramtable.Get().QueryNodeCfg.MaxDiskUsagePercentage.GetAsFloat()) { - return 0, 0, merr.WrapErrServiceDiskLimitExceeded(float32(predictDiskUsage), float32(paramtable.Get().QueryNodeCfg.DiskCapacityLimit.GetAsInt64()), fmt.Sprintf("load segment failed, disk space is not enough, diskUsage = %v MB, predictDiskUsage = %v MB, totalDisk = %v MB, thresholdFactor = %f", - logutil.ToMB(float64(diskUsage)), - logutil.ToMB(float64(predictDiskUsage)), - logutil.ToMB(float64(uint64(paramtable.Get().QueryNodeCfg.DiskCapacityLimit.GetAsInt64()))), - paramtable.Get().QueryNodeCfg.MaxDiskUsagePercentage.GetAsFloat())) + log.Warn("load segment failed, disk space is not enough", + zap.String("resourceType", "Disk"), + zap.Float64("diskUsageMB", logutil.ToMB(float64(diskUsage))), + zap.Float64("predictDiskUsageMB", logutil.ToMB(float64(predictDiskUsage))), + zap.Float64("totalDiskMB", logutil.ToMB(float64(uint64(paramtable.Get().QueryNodeCfg.DiskCapacityLimit.GetAsInt64())))), + zap.Float64("thresholdFactor", paramtable.Get().QueryNodeCfg.MaxDiskUsagePercentage.GetAsFloat()), + ) + return 0, 0, merr.WrapErrSegmentRequestResourceFailed("Disk") } } @@ -2304,10 +2313,13 @@ func checkSegmentGpuMemSize(fieldGpuMemSizeList []uint64, OverloadedMemoryThresh } } if minId == -1 { - return fmt.Errorf("load segment failed, GPU OOM if loaded, GpuMemUsage(bytes) = %v, usedGpuMem(bytes) = %v, maxGPUMem(bytes) = %v", - fieldGpuMem, - usedGpuMem, - maxGpuMemSize) + log.Warn("load segment failed, GPU OOM if loaded", + zap.String("resourceType", "GPU"), + zap.Uint64("gpuMemUsageBytes", fieldGpuMem), + zap.Any("usedGpuMemBytes", usedGpuMem), + 
zap.Any("maxGpuMemBytes", maxGpuMemSize), + ) + return merr.WrapErrSegmentRequestResourceFailed("GPU") } currentGpuMem[minId] += minGpuMem } diff --git a/internal/querynodev2/segments/segment_loader_test.go b/internal/querynodev2/segments/segment_loader_test.go index eada18dd7d..f512944bf2 100644 --- a/internal/querynodev2/segments/segment_loader_test.go +++ b/internal/querynodev2/segments/segment_loader_test.go @@ -917,6 +917,100 @@ func (suite *SegmentLoaderDetailSuite) TestRequestResource() { }) } +func (suite *SegmentLoaderDetailSuite) TestCheckSegmentSizeWithDiskLimit() { + ctx := context.Background() + + // Save original value and restore after test + paramtable.Get().Save(paramtable.Get().QueryNodeCfg.DiskCapacityLimit.Key, "1") // 1MB + defer paramtable.Get().Reset(paramtable.Get().QueryNodeCfg.DiskCapacityLimit.Key) + + // Set disk usage threshold + paramtable.Get().Save(paramtable.Get().QueryNodeCfg.MaxDiskUsagePercentage.Key, "0.8") // 80% threshold + defer paramtable.Get().Reset(paramtable.Get().QueryNodeCfg.MaxDiskUsagePercentage.Key) + + // set mmap, trigger dist cost + paramtable.Get().Save(paramtable.Get().QueryNodeCfg.MmapScalarField.Key, "true") + defer paramtable.Get().Reset(paramtable.Get().QueryNodeCfg.MmapScalarField.Key) + + // Create a test segment that would exceed the disk limit + loadInfo := &querypb.SegmentLoadInfo{ + SegmentID: suite.segmentID, + PartitionID: suite.partitionID, + CollectionID: suite.collectionID, + NumOfRows: 1000, + BinlogPaths: []*datapb.FieldBinlog{ + { + FieldID: 1, + Binlogs: []*datapb.Binlog{ + { + LogPath: "test_path", + LogSize: 1024 * 1024 * 1024 * 2, // 2GB + MemorySize: 1024 * 1024 * 1024 * 4, + }, + }, + }, + { + FieldID: 105, + Binlogs: []*datapb.Binlog{ + { + LogPath: "test_path", + LogSize: 1024 * 1024 * 1024 * 2, // 2GB + MemorySize: 1024 * 1024 * 1024 * 4, + }, + }, + }, + }, + } + + // Mock collection manager to return a valid collection + collection, err := NewCollection(suite.collectionID, suite.schema, nil, nil) + suite.NoError(err) + suite.collectionManager.EXPECT().Get(suite.collectionID).Return(collection) + + memUsage := uint64(100 * 1024 * 1024) // 100MB + totalMem := uint64(1024 * 1024 * 1024) // 1GB + localDiskUsage := int64(100 * 1024) // 100KB + + _, _, err = suite.loader.checkSegmentSize(ctx, []*querypb.SegmentLoadInfo{loadInfo}, memUsage, totalMem, localDiskUsage) + suite.Error(err) + suite.True(errors.Is(err, merr.ErrSegmentRequestResourceFailed)) +} + +func (suite *SegmentLoaderDetailSuite) TestCheckSegmentSizeWithMemoryLimit() { + ctx := context.Background() + + // Create a test segment that would exceed the memory limit + loadInfo := &querypb.SegmentLoadInfo{ + SegmentID: suite.segmentID, + PartitionID: suite.partitionID, + CollectionID: suite.collectionID, + NumOfRows: 1000, + BinlogPaths: []*datapb.FieldBinlog{ + { + FieldID: 1, + Binlogs: []*datapb.Binlog{ + { + LogPath: "test_path", + LogSize: 1024 * 1024, // 1MB + MemorySize: 1024 * 1024 * 900, // 900MB + }, + }, + }, + }, + } + + memUsage := uint64(100 * 1024 * 1024) // 100MB + totalMem := uint64(1024 * 1024 * 1024) // 1GB + localDiskUsage := int64(100 * 1024) // 100KB + + // Set memory threshold to 80% + paramtable.Get().Save("queryNode.overloadedMemoryThresholdPercentage", "0.8") + + _, _, err := suite.loader.checkSegmentSize(ctx, []*querypb.SegmentLoadInfo{loadInfo}, memUsage, totalMem, localDiskUsage) + suite.Error(err) + suite.True(errors.Is(err, merr.ErrSegmentRequestResourceFailed)) +} + func TestSegmentLoader(t *testing.T) { suite.Run(t, 
 	suite.Run(t, &SegmentLoaderSuite{})
 	suite.Run(t, &SegmentLoaderDetailSuite{})
diff --git a/pkg/util/merr/errors.go b/pkg/util/merr/errors.go
index d5e9569663..1c1b5a4d30 100644
--- a/pkg/util/merr/errors.go
+++ b/pkg/util/merr/errors.go
@@ -106,6 +106,11 @@ var (
 	ErrSegmentLack        = newMilvusError("segment lacks", 602, false)
 	ErrSegmentReduplicate = newMilvusError("segment reduplicates", 603, false)
 	ErrSegmentLoadFailed  = newMilvusError("segment load failed", 604, false)
+	// ErrSegmentRequestResourceFailed indicates the query node cannot load the segment
+	// due to resource exhaustion (Memory, Disk, or GPU). When this error is returned,
+	// the query coordinator will mark the node as resource exhausted and apply a
+	// penalty period during which the node won't receive new loading tasks.
+	ErrSegmentRequestResourceFailed = newMilvusError("segment request resource failed", 605, false)
 
 	// Index related
 	ErrIndexNotFound = newMilvusError("index not found", 700, false)
diff --git a/pkg/util/merr/errors_test.go b/pkg/util/merr/errors_test.go
index 643f0f040b..2836a4446f 100644
--- a/pkg/util/merr/errors_test.go
+++ b/pkg/util/merr/errors_test.go
@@ -116,6 +116,7 @@ func (s *ErrSuite) TestWrap() {
 	s.ErrorIs(WrapErrSegmentNotLoaded(1, "failed to query"), ErrSegmentNotLoaded)
 	s.ErrorIs(WrapErrSegmentLack(1, "lack of segment"), ErrSegmentLack)
 	s.ErrorIs(WrapErrSegmentReduplicate(1, "redundancy of segment"), ErrSegmentReduplicate)
+	s.ErrorIs(WrapErrSegmentRequestResourceFailed("Memory"), ErrSegmentRequestResourceFailed)
 
 	// Index related
 	s.ErrorIs(WrapErrIndexNotFound("failed to get Index"), ErrIndexNotFound)
diff --git a/pkg/util/merr/utils.go b/pkg/util/merr/utils.go
index 9cbfad5915..ccc8825053 100644
--- a/pkg/util/merr/utils.go
+++ b/pkg/util/merr/utils.go
@@ -759,6 +759,24 @@ func WrapErrSegmentLoadFailed(id int64, msg ...string) error {
 	return err
 }
 
+// WrapErrSegmentRequestResourceFailed creates a resource exhaustion error for segment loading.
+// resourceType should be one of: "Memory", "Disk", "GPU".
+// This error triggers the query coordinator to mark the node as resource exhausted,
+// applying a penalty period controlled by queryCoord.resourceExhaustionPenaltyDuration.
+func WrapErrSegmentRequestResourceFailed(
+	resourceType string,
+	msg ...string,
+) error {
+	err := wrapFields(ErrSegmentRequestResourceFailed,
+		value("resourceType", resourceType),
+	)
+
+	if len(msg) > 0 {
+		err = errors.Wrap(err, strings.Join(msg, "->"))
+	}
+	return err
+}
+
 func WrapErrSegmentNotLoaded(id int64, msg ...string) error {
 	err := wrapFields(ErrSegmentNotLoaded, value("segment", id))
 	if len(msg) > 0 {
diff --git a/pkg/util/paramtable/component_param.go b/pkg/util/paramtable/component_param.go
index 29fe4aa88f..b0fad0da66 100644
--- a/pkg/util/paramtable/component_param.go
+++ b/pkg/util/paramtable/component_param.go
@@ -2491,7 +2491,9 @@ type queryCoordConfig struct {
 	// query node task parallelism factor
 	QueryNodeTaskParallelismFactor ParamItem `refreshable:"true"`
 
-	BalanceCheckCollectionMaxCount ParamItem `refreshable:"true"`
+	BalanceCheckCollectionMaxCount    ParamItem `refreshable:"true"`
+	ResourceExhaustionPenaltyDuration ParamItem `refreshable:"true"`
+	ResourceExhaustionCleanupInterval ParamItem `refreshable:"true"`
 }
 
 func (p *queryCoordConfig) init(base *BaseTable) {
@@ -3133,6 +3135,25 @@ If this parameter is set false, Milvus simply searches the growing segments with
 		Export:       false,
 	}
 	p.BalanceCheckCollectionMaxCount.Init(base.mgr)
+	p.ResourceExhaustionPenaltyDuration = ParamItem{
+		Key:          "queryCoord.resourceExhaustionPenaltyDuration",
+		Version:      "2.6.7",
+		DefaultValue: "30",
+		Doc: `Duration (in seconds) that a query node remains marked as resource exhausted after reaching resource limits.
+During this period, the node won't receive new loading tasks.
+Set to 0 to disable the penalty period.`,
+		Export: true,
+	}
+	p.ResourceExhaustionPenaltyDuration.Init(base.mgr)
+
+	p.ResourceExhaustionCleanupInterval = ParamItem{
+		Key:          "queryCoord.resourceExhaustionCleanupInterval",
+		Version:      "2.6.7",
+		DefaultValue: "10",
+		Doc:          "Interval (in seconds) for cleaning up expired resource exhaustion marks on query nodes.",
+		Export:       true,
+	}
+	p.ResourceExhaustionCleanupInterval.Init(base.mgr)
 }
 
 // /////////////////////////////////////////////////////////////////////////////
diff --git a/pkg/util/paramtable/component_param_test.go b/pkg/util/paramtable/component_param_test.go
index c63de9ef66..53bbc5964f 100644
--- a/pkg/util/paramtable/component_param_test.go
+++ b/pkg/util/paramtable/component_param_test.go
@@ -414,6 +414,8 @@ func TestComponentParam(t *testing.T) {
 
 		assert.Equal(t, 2, Params.QueryNodeTaskParallelismFactor.GetAsInt())
 		assert.Equal(t, 100, Params.BalanceCheckCollectionMaxCount.GetAsInt())
+		assert.Equal(t, 30, Params.ResourceExhaustionPenaltyDuration.GetAsInt())
+		assert.Equal(t, 10, Params.ResourceExhaustionCleanupInterval.GetAsInt())
 	})
 
 	t.Run("test queryNodeConfig", func(t *testing.T) {
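For readers skimming the patch, the standalone Go sketch below models the mechanism it introduces: a node is marked when a grow task fails with ErrSegmentRequestResourceFailed, assignment skips nodes with an active mark, and a periodic pass clears expired marks. The penaltyTracker type and its methods are simplified stand-ins written for this note, not the NodeManager API from the diff.

// penaltysketch.go - illustrative model of the resource-exhaustion penalty (stdlib only).
package main

import (
	"fmt"
	"sync"
	"time"
)

type penaltyTracker struct {
	mu       sync.Mutex
	expireAt map[int64]time.Time // nodeID -> penalty expiry; absent means no penalty
}

func newPenaltyTracker() *penaltyTracker {
	return &penaltyTracker{expireAt: make(map[int64]time.Time)}
}

// Mark applies a penalty of duration d; d <= 0 clears any existing mark.
func (p *penaltyTracker) Mark(nodeID int64, d time.Duration) {
	p.mu.Lock()
	defer p.mu.Unlock()
	if d <= 0 {
		delete(p.expireAt, nodeID)
		return
	}
	p.expireAt[nodeID] = time.Now().Add(d)
}

// Exhausted reports whether the node still has an active (unexpired) penalty.
// Like the patch, it is read-only: expired marks are left for a cleanup pass.
func (p *penaltyTracker) Exhausted(nodeID int64) bool {
	p.mu.Lock()
	defer p.mu.Unlock()
	exp, ok := p.expireAt[nodeID]
	return ok && time.Now().Before(exp)
}

// ClearExpired drops marks whose penalty window has already passed.
func (p *penaltyTracker) ClearExpired() {
	p.mu.Lock()
	defer p.mu.Unlock()
	now := time.Now()
	for id, exp := range p.expireAt {
		if !now.Before(exp) {
			delete(p.expireAt, id)
		}
	}
}

func main() {
	tracker := newPenaltyTracker()
	candidates := []int64{1, 2, 3}

	// Node 1 reported an OOM-style load failure: penalize it for 30s
	// (the analogue of queryCoord.resourceExhaustionPenaltyDuration).
	tracker.Mark(1, 30*time.Second)

	// A balancer-style filter keeps only nodes without an active penalty.
	usable := make([]int64, 0, len(candidates))
	for _, n := range candidates {
		if !tracker.Exhausted(n) {
			usable = append(usable, n)
		}
	}
	fmt.Println("assignable nodes:", usable) // assignable nodes: [2 3]
}

In the actual change, the Mark step happens in taskScheduler.remove when a grow task fails with the new error, the filter step happens in each balancer's AssignSegment/AssignChannel, and the ClearExpired step runs in NodeManager's cleanup loop at the configured interval.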