diff --git a/internal/querycoordv2/assign/assign_policy_rowcount.go b/internal/querycoordv2/assign/assign_policy_rowcount.go index 9e69ff01dd..a018f7c9bf 100644 --- a/internal/querycoordv2/assign/assign_policy_rowcount.go +++ b/internal/querycoordv2/assign/assign_policy_rowcount.go @@ -19,6 +19,7 @@ package assign import ( "context" "sort" + "sync" "github.com/milvus-io/milvus/internal/coordinator/snmanager" "github.com/milvus-io/milvus/internal/querycoordv2/meta" @@ -34,6 +35,50 @@ type RowCountBasedAssignPolicy struct { nodeManager *session.NodeManager scheduler task.Scheduler dist *meta.DistributionManager + + mu sync.Mutex + status *rowcountWorkloadStatus + version int64 +} + +type rowcountWorkloadStatus struct { + nodeGlobalRowCount map[int64]int + nodeGlobalChannelRowCount map[int64]int + nodeGlobalChannelCount map[int64]int +} + +// getWorkloadStatus refreshes and returns the workload status if the underlying distribution version has changed. +func (p *RowCountBasedAssignPolicy) getWorkloadStatus() *rowcountWorkloadStatus { + p.mu.Lock() + defer p.mu.Unlock() + + currVer := p.dist.SegmentDistManager.GetVersion() + p.dist.ChannelDistManager.GetVersion() + if currVer == p.version && p.status != nil { + return p.status + } + + status := &rowcountWorkloadStatus{ + nodeGlobalRowCount: make(map[int64]int), + nodeGlobalChannelRowCount: make(map[int64]int), + nodeGlobalChannelCount: make(map[int64]int), + } + + allSegments := p.dist.SegmentDistManager.GetByFilter() + for _, s := range allSegments { + status.nodeGlobalRowCount[s.Node] += int(s.GetNumOfRows()) + } + + allChannels := p.dist.ChannelDistManager.GetByFilter() + for _, ch := range allChannels { + status.nodeGlobalChannelCount[ch.Node]++ + if ch.View != nil { + status.nodeGlobalChannelRowCount[ch.Node] += int(ch.View.NumOfGrowingRows) + } + } + + p.status = status + p.version = currVer + return p.status } // newRowCountBasedAssignPolicy creates a new RowCountBasedAssignPolicy @@ -47,6 +92,7 @@ func newRowCountBasedAssignPolicy( nodeManager: nodeManager, scheduler: scheduler, dist: dist, + version: -1, } } @@ -169,20 +215,12 @@ func (p *RowCountBasedAssignPolicy) AssignChannel( // convertToNodeItemsBySegment creates node items with row count scores func (p *RowCountBasedAssignPolicy) convertToNodeItemsBySegment(nodeIDs []int64) map[int64]*NodeItem { + status := p.getWorkloadStatus() + ret := make(map[int64]*NodeItem, len(nodeIDs)) for _, node := range nodeIDs { - // Calculate sealed segment row count on node - segments := p.dist.SegmentDistManager.GetByFilter(meta.WithNodeID(node)) - rowcnt := 0 - for _, s := range segments { - rowcnt += int(s.GetNumOfRows()) - } - - // Calculate growing segment row count on node - channels := p.dist.ChannelDistManager.GetByFilter(meta.WithNodeID2Channel(node)) - for _, channel := range channels { - rowcnt += int(channel.View.NumOfGrowingRows) - } + // Get pre-aggregated global row counts from status + rowcnt := status.nodeGlobalRowCount[node] + status.nodeGlobalChannelRowCount[node] // Calculate executing task cost in scheduler rowcnt += p.scheduler.GetSegmentTaskDelta(node, -1) @@ -196,11 +234,13 @@ func (p *RowCountBasedAssignPolicy) convertToNodeItemsBySegment(nodeIDs []int64) // convertToNodeItemsByChannel creates node items with channel count scores func (p *RowCountBasedAssignPolicy) convertToNodeItemsByChannel(nodeIDs []int64) map[int64]*NodeItem { + status := p.getWorkloadStatus() + ret := make(map[int64]*NodeItem, len(nodeIDs)) for _, node := range nodeIDs { - channels := 
p.dist.ChannelDistManager.GetByFilter(meta.WithNodeID2Channel(node)) + // Get pre-aggregated channel count from status + channelCount := status.nodeGlobalChannelCount[node] - channelCount := len(channels) // Calculate executing task cost in scheduler channelCount += p.scheduler.GetChannelTaskDelta(node, -1) diff --git a/internal/querycoordv2/assign/assign_policy_rowcount_test.go b/internal/querycoordv2/assign/assign_policy_rowcount_test.go index 7af7d23411..202c771732 100644 --- a/internal/querycoordv2/assign/assign_policy_rowcount_test.go +++ b/internal/querycoordv2/assign/assign_policy_rowcount_test.go @@ -410,3 +410,45 @@ func TestRowCountBasedAssignPolicy_AssignChannel_EmptyChannels(t *testing.T) { assert.NotNil(t, plans) assert.Equal(t, 0, len(plans)) } + +// TestRowCountBasedAssignPolicy_WorkloadStatusOnDemandUpdate tests the on-demand workload status update mechanism +func TestRowCountBasedAssignPolicy_WorkloadStatusOnDemandUpdate(t *testing.T) { + nodeManager := session.NewNodeManager() + mockScheduler := task.NewMockScheduler(t) + mockScheduler.EXPECT().GetSegmentTaskDelta(mock.Anything, mock.Anything).Return(0).Maybe() + mockScheduler.EXPECT().GetChannelTaskDelta(mock.Anything, mock.Anything).Return(0).Maybe() + dist := meta.NewDistributionManager(nodeManager) + + policy := newRowCountBasedAssignPolicy(nodeManager, mockScheduler, dist) + + // 1. Init status + firstStatus := policy.getWorkloadStatus() + firstVersion := policy.version + assert.NotNil(t, firstStatus) + + // 2. Update without meta change + secondStatus := policy.getWorkloadStatus() + // Should be identical pointer + assert.Equal(t, firstStatus, secondStatus, "Status pointer should be identical when version hasn't changed") + assert.Equal(t, firstVersion, policy.version) + + // 3. Update with segment meta change + dist.SegmentDistManager.Update(1, &meta.Segment{SegmentInfo: &datapb.SegmentInfo{ID: 100}, Node: 1}) + + // 4. Update again + thirdStatus := policy.getWorkloadStatus() + // Should be new pointer + assert.NotEqual(t, firstStatus, thirdStatus, "Status should be refreshed when segment version changed") + assert.Greater(t, policy.version, firstVersion) + + secondVersion := policy.version + + // 5. Update with channel meta change + dist.ChannelDistManager.Update(1, &meta.DmChannel{VchannelInfo: &datapb.VchannelInfo{ChannelName: "v1"}, Node: 1, View: &meta.LeaderView{ID: 1}}) + + // 6. 
Update again + fourthStatus := policy.getWorkloadStatus() + // Should be new pointer + assert.NotEqual(t, thirdStatus, fourthStatus, "Status should be refreshed when channel version changed") + assert.Greater(t, policy.version, secondVersion) +} diff --git a/internal/querycoordv2/assign/assign_policy_score.go b/internal/querycoordv2/assign/assign_policy_score.go index 9dc3e98951..b5b1776d79 100644 --- a/internal/querycoordv2/assign/assign_policy_score.go +++ b/internal/querycoordv2/assign/assign_policy_score.go @@ -20,6 +20,7 @@ import ( "context" "math" "sort" + "sync" "github.com/milvus-io/milvus/internal/coordinator/snmanager" "github.com/milvus-io/milvus/internal/querycoordv2/meta" @@ -38,6 +39,50 @@ type ScoreBasedAssignPolicy struct { scheduler task.Scheduler dist *meta.DistributionManager meta *meta.Meta + + mu sync.Mutex + status *workloadStatus + version int64 +} + +type workloadStatus struct { + nodeGlobalRowCount map[int64]int + nodeGlobalChannelRowCount map[int64]int + nodeGlobalChannels map[int64][]*meta.DmChannel +} + +// getWorkloadStatus refreshes and returns the workload status if the underlying distribution version has changed. +func (p *ScoreBasedAssignPolicy) getWorkloadStatus() *workloadStatus { + p.mu.Lock() + defer p.mu.Unlock() + + currVer := p.dist.SegmentDistManager.GetVersion() + p.dist.ChannelDistManager.GetVersion() + if currVer == p.version && p.status != nil { + return p.status + } + + status := &workloadStatus{ + nodeGlobalRowCount: make(map[int64]int), + nodeGlobalChannelRowCount: make(map[int64]int), + nodeGlobalChannels: make(map[int64][]*meta.DmChannel), + } + + allSegments := p.dist.SegmentDistManager.GetByFilter() + for _, s := range allSegments { + status.nodeGlobalRowCount[s.Node] += int(s.GetNumOfRows()) + } + + allChannels := p.dist.ChannelDistManager.GetByFilter() + for _, ch := range allChannels { + status.nodeGlobalChannels[ch.Node] = append(status.nodeGlobalChannels[ch.Node], ch) + if ch.View != nil { + status.nodeGlobalChannelRowCount[ch.Node] += int(ch.View.NumOfGrowingRows) + } + } + + p.status = status + p.version = currVer + return p.status } // newScoreBasedAssignPolicy creates a new ScoreBasedAssignPolicy @@ -53,6 +98,7 @@ func newScoreBasedAssignPolicy( scheduler: scheduler, dist: dist, meta: meta, + version: -1, } } @@ -156,6 +202,8 @@ func (p *ScoreBasedAssignPolicy) AssignSegment( // ConvertToNodeItemsBySegment creates node items with comprehensive scores func (p *ScoreBasedAssignPolicy) ConvertToNodeItemsBySegment(collectionID int64, nodeIDs []int64) map[int64]*NodeItem { + status := p.getWorkloadStatus() + totalScore := 0 nodeScoreMap := make(map[int64]*NodeItem) nodeMemMap := make(map[int64]float64) @@ -163,7 +211,7 @@ func (p *ScoreBasedAssignPolicy) ConvertToNodeItemsBySegment(collectionID int64, allNodeHasMemInfo := true for _, node := range nodeIDs { - score := p.calculateScoreBySegment(collectionID, node) + score := p.calculateScoreBySegment(collectionID, node, status) NodeItem := NewNodeItem(score, node) nodeScoreMap[node] = &NodeItem totalScore += score @@ -215,19 +263,8 @@ func (p *ScoreBasedAssignPolicy) ConvertToNodeItemsBySegment(collectionID int64, } // calculateScoreBySegment calculates comprehensive score for a node -func (p *ScoreBasedAssignPolicy) calculateScoreBySegment(collectionID, nodeID int64) int { - // Calculate global sealed segment row count - globalSegments := p.dist.SegmentDistManager.GetByFilter(meta.WithNodeID(nodeID)) - nodeRowCount := 0 - for _, s := range globalSegments { - nodeRowCount += 
int(s.GetNumOfRows()) - } - - // Calculate global growing segment row count - delegatorList := p.dist.ChannelDistManager.GetByFilter(meta.WithNodeID2Channel(nodeID)) - for _, d := range delegatorList { - nodeRowCount += int(float64(d.View.NumOfGrowingRows)) - } +func (p *ScoreBasedAssignPolicy) calculateScoreBySegment(collectionID, nodeID int64, status *workloadStatus) int { + nodeRowCount := status.nodeGlobalRowCount[nodeID] + status.nodeGlobalChannelRowCount[nodeID] // Calculate executing task cost in scheduler nodeRowCount += p.scheduler.GetSegmentTaskDelta(nodeID, -1) @@ -248,7 +285,7 @@ func (p *ScoreBasedAssignPolicy) calculateScoreBySegment(collectionID, nodeID in meta.WithNodeID2Channel(nodeID), ) for _, d := range collDelegatorList { - collectionRowCount += int(float64(d.View.NumOfGrowingRows)) + collectionRowCount += int(d.View.NumOfGrowingRows) } // Calculate executing task cost for collection @@ -365,6 +402,8 @@ func (p *ScoreBasedAssignPolicy) AssignChannel( // ConvertToNodeItemsByChannel creates node items with channel scores func (p *ScoreBasedAssignPolicy) ConvertToNodeItemsByChannel(collectionID int64, nodeIDs []int64) map[int64]*NodeItem { + status := p.getWorkloadStatus() + totalScore := 0 nodeScoreMap := make(map[int64]*NodeItem) nodeMemMap := make(map[int64]float64) @@ -372,7 +411,7 @@ func (p *ScoreBasedAssignPolicy) ConvertToNodeItemsByChannel(collectionID int64, allNodeHasMemInfo := true for _, node := range nodeIDs { - score := p.calculateScoreByChannel(collectionID, node) + score := p.calculateScoreByChannel(collectionID, node, status) NodeItem := NewNodeItem(score, node) nodeScoreMap[node] = &NodeItem totalScore += score @@ -408,8 +447,8 @@ func (p *ScoreBasedAssignPolicy) ConvertToNodeItemsByChannel(collectionID int64, } // calculateScoreByChannel calculates score based on channel count -func (p *ScoreBasedAssignPolicy) calculateScoreByChannel(collectionID, nodeID int64) int { - channels := p.dist.ChannelDistManager.GetByFilter(meta.WithNodeID2Channel(nodeID)) +func (p *ScoreBasedAssignPolicy) calculateScoreByChannel(collectionID, nodeID int64, status *workloadStatus) int { + channels := status.nodeGlobalChannels[nodeID] totalScore := 0.0 for _, ch := range channels { @@ -430,5 +469,5 @@ func (p *ScoreBasedAssignPolicy) CalculateChannelScore(ch *meta.DmChannel, curre channelWeight := paramtable.Get().QueryCoordCfg.CollectionChannelCountFactor.GetAsFloat() return math.Max(1.0, channelWeight) } - return 1 + return 1.0 } diff --git a/internal/querycoordv2/assign/assign_policy_score_test.go b/internal/querycoordv2/assign/assign_policy_score_test.go index dceafd4d3f..eefd62f733 100644 --- a/internal/querycoordv2/assign/assign_policy_score_test.go +++ b/internal/querycoordv2/assign/assign_policy_score_test.go @@ -532,3 +532,46 @@ func TestScoreBasedAssignPolicy_AssignChannel_EmptyNodes(t *testing.T) { assert.Nil(t, plans) } + +// TestScoreBasedAssignPolicy_WorkloadStatusOnDemandUpdate tests the on-demand workload status update mechanism +func TestScoreBasedAssignPolicy_WorkloadStatusOnDemandUpdate(t *testing.T) { + nodeManager := session.NewNodeManager() + mockScheduler := task.NewMockScheduler(t) + mockScheduler.EXPECT().GetSegmentTaskDelta(mock.Anything, mock.Anything).Return(0).Maybe() + mockScheduler.EXPECT().GetChannelTaskDelta(mock.Anything, mock.Anything).Return(0).Maybe() + dist := meta.NewDistributionManager(nodeManager) + metaMgr := meta.NewMeta(nil, nil, nodeManager) + + policy := newScoreBasedAssignPolicy(nodeManager, mockScheduler, dist, metaMgr) + + // 
1. Init status + firstStatus := policy.getWorkloadStatus() + firstVersion := policy.version + assert.NotNil(t, firstStatus) + + // 2. Update without meta change + secondStatus := policy.getWorkloadStatus() + // Should be identical pointer + assert.Equal(t, firstStatus, secondStatus, "Status pointer should be identical when version hasn't changed") + assert.Equal(t, firstVersion, policy.version) + + // 3. Update with segment meta change + dist.SegmentDistManager.Update(1, &meta.Segment{SegmentInfo: &datapb.SegmentInfo{ID: 100}, Node: 1}) + + // 4. Update again + thirdStatus := policy.getWorkloadStatus() + // Should be new pointer + assert.NotEqual(t, firstStatus, thirdStatus, "Status should be refreshed when segment version changed") + assert.Greater(t, policy.version, firstVersion) + + secondVersion := policy.version + + // 5. Update with channel meta change + dist.ChannelDistManager.Update(1, &meta.DmChannel{VchannelInfo: &datapb.VchannelInfo{ChannelName: "v1"}, Node: 1, View: &meta.LeaderView{ID: 1}}) + + // 6. Update again + fourthStatus := policy.getWorkloadStatus() + // Should be new pointer + assert.NotEqual(t, thirdStatus, fourthStatus, "Status should be refreshed when channel version changed") + assert.Greater(t, policy.version, secondVersion) +} diff --git a/internal/querycoordv2/checkers/segment_checker.go b/internal/querycoordv2/checkers/segment_checker.go index cab8279eae..078088538c 100644 --- a/internal/querycoordv2/checkers/segment_checker.go +++ b/internal/querycoordv2/checkers/segment_checker.go @@ -134,8 +134,14 @@ func (c *SegmentChecker) Check(ctx context.Context) []task.Task { func (c *SegmentChecker) checkReplica(ctx context.Context, replica *meta.Replica) []task.Task { ret := make([]task.Task, 0) + replicaSegmentDist := c.dist.SegmentDistManager.GetByFilter(meta.WithCollectionID(replica.GetCollectionID()), meta.WithReplica(replica)) + delegatorList := c.dist.ChannelDistManager.GetByCollectionAndFilter(replica.GetCollectionID(), meta.WithReplica2Channel(replica)) + ch2DelegatorList := lo.GroupBy(delegatorList, func(d *meta.DmChannel) string { + return d.View.Channel + }) + // compare with targets to find the lack and redundancy of segments - lacks, loadPriorities, redundancies, toUpdate := c.getSealedSegmentDiff(ctx, replica.GetCollectionID(), replica.GetID()) + lacks, loadPriorities, redundancies, toUpdate := c.getSealedSegmentDiff(ctx, replica.GetCollectionID(), replica, replicaSegmentDist) tasks := c.createSegmentLoadTasks(c.getTraceCtx(ctx, replica.GetCollectionID()), lacks, loadPriorities, replica) task.SetReason("lacks of segment", tasks...) task.SetPriority(task.TaskPriorityNormal, tasks...) @@ -146,15 +152,15 @@ func (c *SegmentChecker) checkReplica(ctx context.Context, replica *meta.Replica task.SetPriority(task.TaskPriorityNormal, tasks...) ret = append(ret, tasks...) - redundancies = c.filterOutSegmentInUse(ctx, replica, redundancies) + redundancies = c.filterOutSegmentInUse(ctx, replica, redundancies, ch2DelegatorList) tasks = c.createSegmentReduceTasks(c.getTraceCtx(ctx, replica.GetCollectionID()), redundancies, replica, querypb.DataScope_Historical) task.SetReason("segment not exists in target", tasks...) task.SetPriority(task.TaskPriorityNormal, tasks...) ret = append(ret, tasks...) 
// compare inner dists to find repeated loaded segments - redundancies = c.findRepeatedSealedSegments(ctx, replica.GetID()) - redundancies = c.filterOutExistedOnLeader(replica, redundancies) + redundancies = c.findRepeatedSealedSegments(ctx, replica, replicaSegmentDist) + redundancies = c.filterOutExistedOnLeader(replica, redundancies, ch2DelegatorList) tasks = c.createSegmentReduceTasks(c.getTraceCtx(ctx, replica.GetCollectionID()), redundancies, replica, querypb.DataScope_Historical) task.SetReason("redundancies of segment", tasks...) // set deduplicate task priority to low, to avoid deduplicate task cancel balance task @@ -162,7 +168,7 @@ func (c *SegmentChecker) checkReplica(ctx context.Context, replica *meta.Replica ret = append(ret, tasks...) // compare with target to find the lack and redundancy of segments - _, redundancies = c.getGrowingSegmentDiff(ctx, replica.GetCollectionID(), replica.GetID()) + _, redundancies = c.getGrowingSegmentDiff(ctx, replica.GetCollectionID(), replica, delegatorList) tasks = c.createSegmentReduceTasks(c.getTraceCtx(ctx, replica.GetCollectionID()), redundancies, replica, querypb.DataScope_Streaming) task.SetReason("streaming segment not exists in target", tasks...) task.SetPriority(task.TaskPriorityNormal, tasks...) @@ -173,19 +179,13 @@ func (c *SegmentChecker) checkReplica(ctx context.Context, replica *meta.Replica // GetGrowingSegmentDiff get streaming segment diff between leader view and target func (c *SegmentChecker) getGrowingSegmentDiff(ctx context.Context, collectionID int64, - replicaID int64, + replica *meta.Replica, + delegatorList []*meta.DmChannel, ) (toLoad []*datapb.SegmentInfo, toRelease []*meta.Segment) { - replica := c.meta.Get(ctx, replicaID) - if replica == nil { - log.Info("replica does not exist, skip it") - return - } - log := log.Ctx(context.TODO()).WithRateGroup("qcv2.SegmentChecker", 1, 60).With( zap.Int64("collectionID", collectionID), zap.Int64("replicaID", replica.GetID())) - delegatorList := c.dist.ChannelDistManager.GetByFilter(meta.WithReplica2Channel(replica)) for _, d := range delegatorList { view := d.View targetVersion := c.targetMgr.GetCollectionTargetVersion(ctx, collectionID, meta.CurrentTarget) @@ -236,14 +236,9 @@ func (c *SegmentChecker) getGrowingSegmentDiff(ctx context.Context, collectionID func (c *SegmentChecker) getSealedSegmentDiff( ctx context.Context, collectionID int64, - replicaID int64, + replica *meta.Replica, + dist []*meta.Segment, ) (toLoad []*datapb.SegmentInfo, loadPriorities []commonpb.LoadPriority, toRelease []*meta.Segment, toUpdate []*meta.Segment) { - replica := c.meta.Get(ctx, replicaID) - if replica == nil { - log.Info("replica does not exist, skip it") - return - } - dist := c.dist.SegmentDistManager.GetByFilter(meta.WithCollectionID(replica.GetCollectionID()), meta.WithReplica(replica)) sort.Slice(dist, func(i, j int) bool { return dist[i].Version < dist[j].Version }) @@ -294,14 +289,8 @@ func (c *SegmentChecker) getSealedSegmentDiff( return } -func (c *SegmentChecker) findRepeatedSealedSegments(ctx context.Context, replicaID int64) []*meta.Segment { +func (c *SegmentChecker) findRepeatedSealedSegments(ctx context.Context, replica *meta.Replica, dist []*meta.Segment) []*meta.Segment { segments := make([]*meta.Segment, 0) - replica := c.meta.Get(ctx, replicaID) - if replica == nil { - log.Info("replica does not exist, skip it") - return segments - } - dist := c.dist.SegmentDistManager.GetByFilter(meta.WithCollectionID(replica.GetCollectionID()), meta.WithReplica(replica)) versions := 
make(map[int64]*meta.Segment) for _, s := range dist { maxVer, ok := versions[s.GetID()] @@ -321,12 +310,8 @@ func (c *SegmentChecker) findRepeatedSealedSegments(ctx context.Context, replica } // for duplicated segment, we should release the one which is not serving on leader -func (c *SegmentChecker) filterOutExistedOnLeader(replica *meta.Replica, segments []*meta.Segment) []*meta.Segment { +func (c *SegmentChecker) filterOutExistedOnLeader(replica *meta.Replica, segments []*meta.Segment, ch2DelegatorList map[string][]*meta.DmChannel) []*meta.Segment { notServing := make([]*meta.Segment, 0, len(segments)) - delegatorList := c.dist.ChannelDistManager.GetByFilter(meta.WithReplica2Channel(replica)) - ch2DelegatorList := lo.GroupBy(delegatorList, func(d *meta.DmChannel) string { - return d.View.Channel - }) for _, s := range segments { delegatorList := ch2DelegatorList[s.GetInsertChannel()] if len(delegatorList) == 0 { @@ -350,12 +335,8 @@ func (c *SegmentChecker) filterOutExistedOnLeader(replica *meta.Replica, segment } // for sealed segment which doesn't exist in target, we should release it after delegator has updated to latest readable version -func (c *SegmentChecker) filterOutSegmentInUse(ctx context.Context, replica *meta.Replica, segments []*meta.Segment) []*meta.Segment { +func (c *SegmentChecker) filterOutSegmentInUse(ctx context.Context, replica *meta.Replica, segments []*meta.Segment, ch2DelegatorList map[string][]*meta.DmChannel) []*meta.Segment { notUsed := make([]*meta.Segment, 0, len(segments)) - delegatorList := c.dist.ChannelDistManager.GetByFilter(meta.WithReplica2Channel(replica)) - ch2DelegatorList := lo.GroupBy(delegatorList, func(d *meta.DmChannel) string { - return d.View.Channel - }) for _, s := range segments { currentTargetVersion := c.targetMgr.GetCollectionTargetVersion(ctx, s.CollectionID, meta.CurrentTarget) partition := c.meta.CollectionManager.GetPartition(ctx, s.PartitionID) diff --git a/internal/querycoordv2/checkers/segment_checker_test.go b/internal/querycoordv2/checkers/segment_checker_test.go index db9d6c9927..0e61d63662 100644 --- a/internal/querycoordv2/checkers/segment_checker_test.go +++ b/internal/querycoordv2/checkers/segment_checker_test.go @@ -21,6 +21,7 @@ import ( "sort" "testing" + "github.com/samber/lo" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/suite" @@ -684,7 +685,9 @@ func (suite *SegmentCheckerTestSuite) TestLoadPriority() { suite.checker.targetMgr.UpdateCollectionNextTarget(ctx, collectionID) // test getSealedSegmentDiff - toLoad, loadPriorities, toRelease, toUpdate := suite.checker.getSealedSegmentDiff(ctx, collectionID, replicaID) + // Pre-fetch segment distribution for the test + dist := suite.checker.dist.SegmentDistManager.GetByFilter(meta.WithCollectionID(replica.GetCollectionID()), meta.WithReplica(replica)) + toLoad, loadPriorities, toRelease, toUpdate := suite.checker.getSealedSegmentDiff(ctx, collectionID, replica, dist) // verify results suite.Equal(2, len(toLoad)) @@ -706,7 +709,8 @@ func (suite *SegmentCheckerTestSuite) TestLoadPriority() { // update current target to include segment2 suite.checker.targetMgr.UpdateCollectionCurrentTarget(ctx, collectionID) // test again - toLoad, loadPriorities, toRelease, toUpdate = suite.checker.getSealedSegmentDiff(ctx, collectionID, replicaID) + dist = suite.checker.dist.SegmentDistManager.GetByFilter(meta.WithCollectionID(replica.GetCollectionID()), meta.WithReplica(replica)) + toLoad, loadPriorities, toRelease, toUpdate = suite.checker.getSealedSegmentDiff(ctx, 
collectionID, replica, dist) // verify results suite.Equal(0, len(toLoad)) suite.Equal(0, len(loadPriorities)) @@ -737,8 +741,17 @@ func (suite *SegmentCheckerTestSuite) TestFilterOutExistedOnLeader() { utils.CreateTestSegment(collectionID, partitionID, segmentID3, nodeID1, 1, channel), } + // Helper to get ch2DelegatorList + getCh2DelegatorList := func() map[string][]*meta.DmChannel { + delegatorList := checker.dist.ChannelDistManager.GetByCollectionAndFilter(collectionID, meta.WithReplica2Channel(replica)) + return lo.GroupBy(delegatorList, func(d *meta.DmChannel) string { + return d.View.Channel + }) + } + // Test case 1: No leader views - should skip releasing segments - result := checker.filterOutExistedOnLeader(replica, segments) + ch2DelegatorList := getCh2DelegatorList() + result := checker.filterOutExistedOnLeader(replica, segments, ch2DelegatorList) suite.Equal(0, len(result), "Should return all segments when no leader views") // Test case 2: Segment serving on leader - should be filtered out @@ -753,7 +766,8 @@ func (suite *SegmentCheckerTestSuite) TestFilterOutExistedOnLeader() { View: leaderView1, }) - result = checker.filterOutExistedOnLeader(replica, segments) + ch2DelegatorList = getCh2DelegatorList() + result = checker.filterOutExistedOnLeader(replica, segments, ch2DelegatorList) suite.Len(result, 2, "Should filter out segment serving on leader") // Check that segmentID1 is filtered out @@ -773,7 +787,8 @@ func (suite *SegmentCheckerTestSuite) TestFilterOutExistedOnLeader() { View: leaderView2, }) - result = checker.filterOutExistedOnLeader(replica, segments) + ch2DelegatorList = getCh2DelegatorList() + result = checker.filterOutExistedOnLeader(replica, segments, ch2DelegatorList) suite.Len(result, 1, "Should filter out segments serving on their respective leaders") suite.Equal(segmentID3, result[0].GetID(), "Only non-serving segment should remain") @@ -789,7 +804,8 @@ func (suite *SegmentCheckerTestSuite) TestFilterOutExistedOnLeader() { View: leaderView3, }) - result = checker.filterOutExistedOnLeader(replica, []*meta.Segment{segments[2]}) // Only test segmentID3 + ch2DelegatorList = getCh2DelegatorList() + result = checker.filterOutExistedOnLeader(replica, []*meta.Segment{segments[2]}, ch2DelegatorList) // Only test segmentID3 suite.Len(result, 1, "Segment not serving on its actual node should not be filtered") } @@ -833,8 +849,17 @@ func (suite *SegmentCheckerTestSuite) TestFilterOutSegmentInUse() { checker.targetMgr.UpdateCollectionCurrentTarget(ctx, collectionID) currentTargetVersion := checker.targetMgr.GetCollectionTargetVersion(ctx, collectionID, meta.CurrentTarget) + // Helper to get ch2DelegatorList + getCh2DelegatorList := func() map[string][]*meta.DmChannel { + delegatorList := checker.dist.ChannelDistManager.GetByCollectionAndFilter(collectionID, meta.WithReplica2Channel(replica)) + return lo.GroupBy(delegatorList, func(d *meta.DmChannel) string { + return d.View.Channel + }) + } + // Test case 1: No leader views - should skip releasing segments - result := checker.filterOutSegmentInUse(ctx, replica, segments) + ch2DelegatorList := getCh2DelegatorList() + result := checker.filterOutSegmentInUse(ctx, replica, segments, ch2DelegatorList) suite.Equal(0, len(result), "Should return all segments when no leader views") // Test case 2: Leader view with outdated target version - segment should be filtered (still in use) @@ -850,7 +875,8 @@ func (suite *SegmentCheckerTestSuite) TestFilterOutSegmentInUse() { View: leaderView1, }) - result = 
checker.filterOutSegmentInUse(ctx, replica, []*meta.Segment{segments[0]}) + ch2DelegatorList = getCh2DelegatorList() + result = checker.filterOutSegmentInUse(ctx, replica, []*meta.Segment{segments[0]}, ch2DelegatorList) suite.Len(result, 0, "Segment should be filtered out when delegator hasn't updated to latest version") // Test case 3: Leader view with current target version - segment should not be filtered @@ -866,7 +892,8 @@ func (suite *SegmentCheckerTestSuite) TestFilterOutSegmentInUse() { View: leaderView2, }) - result = checker.filterOutSegmentInUse(ctx, replica, []*meta.Segment{segments[0]}) + ch2DelegatorList = getCh2DelegatorList() + result = checker.filterOutSegmentInUse(ctx, replica, []*meta.Segment{segments[0]}, ch2DelegatorList) suite.Len(result, 1, "Segment should not be filtered when delegator has updated to latest version") // Test case 4: Leader view with initial target version - segment should not be filtered @@ -882,7 +909,8 @@ func (suite *SegmentCheckerTestSuite) TestFilterOutSegmentInUse() { View: leaderView3, }) - result = checker.filterOutSegmentInUse(ctx, replica, []*meta.Segment{segments[1]}) + ch2DelegatorList = getCh2DelegatorList() + result = checker.filterOutSegmentInUse(ctx, replica, []*meta.Segment{segments[1]}, ch2DelegatorList) suite.Len(result, 1, "Segment should not be filtered when leader has initial target version") // Test case 5: Multiple leader views with mixed versions - segment should be filtered (still in use) @@ -915,12 +943,14 @@ func (suite *SegmentCheckerTestSuite) TestFilterOutSegmentInUse() { utils.CreateTestSegment(collectionID, partitionID, segmentID2, nodeID2, 1, channel), } - result = checker.filterOutSegmentInUse(ctx, replica, testSegments) + ch2DelegatorList = getCh2DelegatorList() + result = checker.filterOutSegmentInUse(ctx, replica, testSegments, ch2DelegatorList) suite.Len(result, 0, "Should release all segments when any delegator hasn't updated") // Test case 6: Partition is nil - should release all segments (no partition info) checker.meta.CollectionManager.RemovePartition(ctx, partitionID) - result = checker.filterOutSegmentInUse(ctx, replica, []*meta.Segment{segments[0]}) + ch2DelegatorList = getCh2DelegatorList() + result = checker.filterOutSegmentInUse(ctx, replica, []*meta.Segment{segments[0]}, ch2DelegatorList) suite.Len(result, 0, "Should release all segments when partition is nil") } diff --git a/internal/querycoordv2/meta/channel_dist_manager.go b/internal/querycoordv2/meta/channel_dist_manager.go index 67fd5df080..71d0d7f397 100644 --- a/internal/querycoordv2/meta/channel_dist_manager.go +++ b/internal/querycoordv2/meta/channel_dist_manager.go @@ -234,6 +234,7 @@ type ChannelDistManagerInterface interface { GetShardLeader(channelName string, replica *Replica) *DmChannel GetChannelDist(collectionID int64) []*metricsinfo.DmChannel GetLeaderView(collectionID int64) []*metricsinfo.LeaderView + GetVersion() int64 } type ChannelDistManager struct { @@ -246,6 +247,13 @@ type ChannelDistManager struct { collectionIndex map[int64][]*DmChannel nodeManager *session.NodeManager + version int64 +} + +func (m *ChannelDistManager) GetVersion() int64 { + m.rwmutex.RLock() + defer m.rwmutex.RUnlock() + return m.version } func NewChannelDistManager(nodeManager *session.NodeManager) *ChannelDistManager { @@ -323,6 +331,7 @@ func (m *ChannelDistManager) Update(nodeID typeutil.UniqueID, channels ...*DmCha m.channels[nodeID] = composeNodeChannels(channels...) 
m.updateCollectionIndex() + m.version++ return newServiceableChannels } diff --git a/internal/querycoordv2/meta/channel_dist_manager_test.go b/internal/querycoordv2/meta/channel_dist_manager_test.go index 0381589da9..471c749e28 100644 --- a/internal/querycoordv2/meta/channel_dist_manager_test.go +++ b/internal/querycoordv2/meta/channel_dist_manager_test.go @@ -93,6 +93,18 @@ func (suite *ChannelDistManagerSuite) SetupTest() { suite.dist.Update(suite.nodes[2], suite.channels["dmc1"].Clone()) } +func (suite *ChannelDistManagerSuite) TestVersion() { + dist := suite.dist + v1 := dist.GetVersion() + + // Update with some new data + newChannel := suite.channels["dmc0"].Clone() + newChannel.Version = 2 + dist.Update(suite.nodes[0], newChannel) + v2 := dist.GetVersion() + suite.Greater(v2, v1) +} + func (suite *ChannelDistManagerSuite) TestGetBy() { dist := suite.dist diff --git a/internal/querycoordv2/meta/segment_dist_manager.go b/internal/querycoordv2/meta/segment_dist_manager.go index 5d92d2843f..88f56ad126 100644 --- a/internal/querycoordv2/meta/segment_dist_manager.go +++ b/internal/querycoordv2/meta/segment_dist_manager.go @@ -163,6 +163,7 @@ type SegmentDistManagerInterface interface { Update(nodeID typeutil.UniqueID, segments ...*Segment) GetByFilter(filters ...SegmentDistFilter) []*Segment GetSegmentDist(collectionID int64) []*metricsinfo.Segment + GetVersion() int64 } type SegmentDistManager struct { @@ -170,6 +171,13 @@ type SegmentDistManager struct { // nodeID -> []*Segment segments map[typeutil.UniqueID]nodeSegments + version int64 +} + +func (m *SegmentDistManager) GetVersion() int64 { + m.rwmutex.RLock() + defer m.rwmutex.RUnlock() + return m.version } type nodeSegments struct { @@ -218,6 +226,7 @@ func (m *SegmentDistManager) Update(nodeID typeutil.UniqueID, segments ...*Segme segment.Node = nodeID } m.segments[nodeID] = composeNodeSegments(segments) + m.version++ } // GetByFilter return segment list which match all given filters diff --git a/internal/querycoordv2/meta/segment_dist_manager_test.go b/internal/querycoordv2/meta/segment_dist_manager_test.go index f53c649745..7d7931247e 100644 --- a/internal/querycoordv2/meta/segment_dist_manager_test.go +++ b/internal/querycoordv2/meta/segment_dist_manager_test.go @@ -89,6 +89,16 @@ func (suite *SegmentDistManagerSuite) SetupTest() { suite.dist.Update(suite.nodes[2], suite.segments[3].Clone(), suite.segments[4].Clone()) } +func (suite *SegmentDistManagerSuite) TestVersion() { + dist := suite.dist + v1 := dist.GetVersion() + + // Update with some new data + dist.Update(suite.nodes[0], suite.segments[1].Clone(), suite.segments[2].Clone(), suite.segments[3].Clone()) + v2 := dist.GetVersion() + suite.Greater(v2, v1) +} + func (suite *SegmentDistManagerSuite) TestGetBy() { dist := suite.dist // Test GetByNode diff --git a/internal/querycoordv2/task/scheduler.go b/internal/querycoordv2/task/scheduler.go index ebe80f9683..c0d7614418 100644 --- a/internal/querycoordv2/task/scheduler.go +++ b/internal/querycoordv2/task/scheduler.go @@ -159,16 +159,18 @@ func (queue *taskQueue) Range(fn func(task Task) bool) { } type ExecutingTaskDelta struct { - data map[int64]map[int64]int // nodeID -> collectionID -> taskDelta - mu sync.RWMutex // Mutex to protect the map + data map[int64]map[int64]int // nodeID -> collectionID -> taskDelta + nodeTotalDelta map[int64]int // nodeID -> totalTaskDelta + mu sync.RWMutex // Mutex to protect the map taskIDRecords UniqueSet } func NewExecutingTaskDelta() *ExecutingTaskDelta { return &ExecutingTaskDelta{ - data: 
make(map[int64]map[int64]int), - taskIDRecords: NewUniqueSet(), + data: make(map[int64]map[int64]int), + nodeTotalDelta: make(map[int64]int), + taskIDRecords: NewUniqueSet(), } } @@ -194,6 +196,7 @@ func (etd *ExecutingTaskDelta) Add(task Task) { etd.data[nodeID] = make(map[int64]int) } etd.data[nodeID][collectionID] += delta + etd.nodeTotalDelta[nodeID] += delta } } @@ -220,6 +223,7 @@ func (etd *ExecutingTaskDelta) Sub(task Task) { } etd.data[nodeID][collectionID] -= delta + etd.nodeTotalDelta[nodeID] -= delta } } @@ -229,22 +233,29 @@ func (etd *ExecutingTaskDelta) Get(nodeID, collectionID int64) int { etd.mu.RLock() defer etd.mu.RUnlock() - var sum int - - for nID, collections := range etd.data { - if nodeID != -1 && nID != nodeID { - continue - } - - for cID, delta := range collections { - if collectionID != -1 && cID != collectionID { - continue - } - - sum += delta + if nodeID != -1 && collectionID != -1 { + nodeData, ok := etd.data[nodeID] + if !ok { + return 0 } + return nodeData[collectionID] } + if nodeID != -1 { + return etd.nodeTotalDelta[nodeID] + } + + var sum int + if collectionID != -1 { + for _, collections := range etd.data { + sum += collections[collectionID] + } + return sum + } + + for _, delta := range etd.nodeTotalDelta { + sum += delta + } return sum } @@ -253,7 +264,11 @@ func (etd *ExecutingTaskDelta) printDetailInfos() { defer etd.mu.RUnlock() if etd.taskIDRecords.Len() > 0 { - log.Info("task delta cache info", zap.Any("taskIDRecords", etd.taskIDRecords.Collect()), zap.Any("data", etd.data)) + log.Info("task delta cache info", + zap.Any("taskIDRecords", etd.taskIDRecords.Collect()), + zap.Any("data", etd.data), + zap.Any("nodeTotalDelta", etd.nodeTotalDelta), + ) } } @@ -261,6 +276,7 @@ func (etd *ExecutingTaskDelta) Clear() { etd.mu.Lock() defer etd.mu.Unlock() etd.data = make(map[int64]map[int64]int) + etd.nodeTotalDelta = make(map[int64]int) etd.taskIDRecords.Clear() } diff --git a/internal/querycoordv2/task/task_test.go b/internal/querycoordv2/task/task_test.go index 4a93bd876e..86bd7c663a 100644 --- a/internal/querycoordv2/task/task_test.go +++ b/internal/querycoordv2/task/task_test.go @@ -1791,7 +1791,7 @@ func (suite *TaskSuite) TestTaskDeltaCache() { context.TODO(), 10*time.Second, WrapIDSource(0), - 1, + collectionID, suite.replica, NewChannelAction(nodeID, ActionTypeGrow, "channel"), ) @@ -1809,6 +1809,8 @@ func (suite *TaskSuite) TestTaskDeltaCache() { etd.Sub(tasks[i]) } suite.Equal(0, etd.Get(nodeID, collectionID)) + suite.Equal(0, etd.Get(nodeID, -1)) + suite.Equal(0, etd.Get(-1, -1)) } func (suite *TaskSuite) TestRemoveTaskWithError() {
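Note on the pattern above: SegmentDistManager and ChannelDistManager now bump a monotonic version counter on every Update, and the assign policies rebuild their aggregated workload status only when the combined version changes, instead of re-scanning the full distribution on every scoring call. The sketch below illustrates that version-gated cache in isolation; it is a minimal stand-alone example, and the types source and cachedStatus are illustrative stand-ins, not Milvus APIs.

package main

import (
	"fmt"
	"sync"
)

// source is a stand-in for SegmentDistManager / ChannelDistManager: every
// Update bumps a monotonically increasing version counter.
type source struct {
	mu      sync.RWMutex
	version int64
	rows    map[int64]int // nodeID -> row count
}

func (s *source) Update(node int64, rows int) {
	s.mu.Lock()
	defer s.mu.Unlock()
	if s.rows == nil {
		s.rows = make(map[int64]int)
	}
	s.rows[node] = rows
	s.version++
}

func (s *source) Version() int64 {
	s.mu.RLock()
	defer s.mu.RUnlock()
	return s.version
}

// snapshot copies the per-node aggregation; this is the "full scan" the cache
// is meant to avoid on repeated calls.
func (s *source) snapshot() map[int64]int {
	s.mu.RLock()
	defer s.mu.RUnlock()
	out := make(map[int64]int, len(s.rows))
	for k, v := range s.rows {
		out[k] = v
	}
	return out
}

// cachedStatus mirrors the getWorkloadStatus pattern: rebuild the aggregate
// only when the observed version differs from the one it was built against.
type cachedStatus struct {
	mu       sync.Mutex
	src      *source
	status   map[int64]int
	version  int64 // -1 means "never built", as in the policies above
	rebuilds int
}

func (c *cachedStatus) get() map[int64]int {
	c.mu.Lock()
	defer c.mu.Unlock()
	cur := c.src.Version()
	if cur == c.version && c.status != nil {
		return c.status // cache hit: no scan of the distribution
	}
	c.status = c.src.snapshot() // cache miss: aggregate once
	c.version = cur
	c.rebuilds++
	return c.status
}

func main() {
	src := &source{}
	cache := &cachedStatus{src: src, version: -1}

	src.Update(1, 100)
	cache.get()
	cache.get()       // same version -> cached status reused
	src.Update(2, 50) // bumps version
	cache.get()       // version changed -> rebuilt
	fmt.Println("rebuilds:", cache.rebuilds) // prints: rebuilds: 2
}

The same idea explains why the patch sums the two managers' versions in getWorkloadStatus: both counters only ever increase, so their sum changes whenever either distribution is updated.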