diff --git a/internal/querycoordv2/task/scheduler.go b/internal/querycoordv2/task/scheduler.go
index 404a0897bd..6f5c9d1d6e 100644
--- a/internal/querycoordv2/task/scheduler.go
+++ b/internal/querycoordv2/task/scheduler.go
@@ -146,6 +146,76 @@ func (queue *taskQueue) Range(fn func(task Task) bool) {
 	}
 }
 
+// ExecutingTaskDelta keeps a running sum of task deltas per node and
+// collection, so delta queries no longer have to scan every task. All methods
+// lock mu; a stored sum can be negative when Reduce actions dominate.
+type ExecutingTaskDelta struct {
+	data map[int64]map[int64]int // nodeID -> collectionID -> taskDelta
+	mu   sync.RWMutex            // Mutex to protect the map
+}
+
+// NewExecutingTaskDelta returns an empty ExecutingTaskDelta ready for use.
+func NewExecutingTaskDelta() *ExecutingTaskDelta {
+	return &ExecutingTaskDelta{
+		data: make(map[int64]map[int64]int),
+	}
+}
+
+// Add updates the taskDelta for the given nodeID and collectionID
+func (etd *ExecutingTaskDelta) Add(nodeID int64, collectionID int64, delta int) {
+	etd.mu.Lock()
+	defer etd.mu.Unlock()
+
+	if _, exists := etd.data[nodeID]; !exists {
+		etd.data[nodeID] = make(map[int64]int)
+	}
+	etd.data[nodeID][collectionID] += delta
+}
+
+// Sub updates the taskDelta for the given nodeID and collectionID by subtracting delta.
+// The entry is dropped only when its sum returns to exactly zero: a negative
+// sum is live state (Reduce actions contribute negative deltas), so it must be
+// kept until the tasks that produced it have been subtracted as well.
+func (etd *ExecutingTaskDelta) Sub(nodeID int64, collectionID int64, delta int) {
+	etd.mu.Lock()
+	defer etd.mu.Unlock()
+
+	if _, exists := etd.data[nodeID]; exists {
+		etd.data[nodeID][collectionID] -= delta
+		if etd.data[nodeID][collectionID] == 0 {
+			delete(etd.data[nodeID], collectionID)
+		}
+		if len(etd.data[nodeID]) == 0 {
+			delete(etd.data, nodeID)
+		}
+	}
+}
+
+// Get retrieves the sum of taskDelta for the given nodeID and collectionID
+// If nodeID or collectionID is -1, it matches all
+func (etd *ExecutingTaskDelta) Get(nodeID, collectionID int64) int {
+	etd.mu.RLock()
+	defer etd.mu.RUnlock()
+
+	var sum int
+
+	for nID, collections := range etd.data {
+		if nodeID != -1 && nID != nodeID {
+			continue
+		}
+
+		for cID, delta := range collections {
+			if collectionID != -1 && cID != collectionID {
+				continue
+			}
+
+			sum += delta
+		}
+	}
+
+	return sum
+}
+
 type Scheduler interface {
 	Start()
 	Stop()
@@ -183,6 +253,10 @@ type taskScheduler struct {
 	processQueue *taskQueue
 	waitQueue    *taskQueue
 	taskStats    *expirable.LRU[UniqueID, Task]
+
+	// nodeID -> collectionID -> taskDelta
+	segmentTaskDelta *ExecutingTaskDelta
+	channelTaskDelta *ExecutingTaskDelta
 }
 
 func NewScheduler(ctx context.Context,
@@ -209,13 +283,15 @@ func NewScheduler(ctx context.Context,
 		cluster:   cluster,
 		nodeMgr:   nodeMgr,
 
-		collKeyLock:  lock.NewKeyLock[int64](),
-		tasks:        NewConcurrentMap[UniqueID, struct{}](),
-		segmentTasks: NewConcurrentMap[replicaSegmentIndex, Task](),
-		channelTasks: NewConcurrentMap[replicaChannelIndex, Task](),
-		processQueue: newTaskQueue(),
-		waitQueue:    newTaskQueue(),
-		taskStats:    expirable.NewLRU[UniqueID, Task](64, nil, time.Minute*15),
+		collKeyLock:      lock.NewKeyLock[int64](),
+		tasks:            NewConcurrentMap[UniqueID, struct{}](),
+		segmentTasks:     NewConcurrentMap[replicaSegmentIndex, Task](),
+		channelTasks:     NewConcurrentMap[replicaChannelIndex, Task](),
+		processQueue:     newTaskQueue(),
+		waitQueue:        newTaskQueue(),
+		taskStats:        expirable.NewLRU[UniqueID, Task](64, nil, time.Minute*15),
+		segmentTaskDelta: NewExecutingTaskDelta(),
+		channelTaskDelta: NewExecutingTaskDelta(),
 	}
 }
 
@@ -272,6 +348,7 @@ func (scheduler *taskScheduler) Add(task Task) error {
 	task.SetID(scheduler.idAllocator())
 	scheduler.waitQueue.Add(task)
 	scheduler.tasks.Insert(task.ID(), struct{}{})
+	scheduler.incExecutingTaskDelta(task)
 	switch task := task.(type) {
 	case *SegmentTask:
 		index := NewReplicaSegmentIndex(task)
@@ -511,76 +588,68 @@ func (scheduler *taskScheduler) Dispatch(node int64) {
 }
 
 func (scheduler *taskScheduler) GetSegmentTaskDelta(nodeID, collectionID int64) int {
-	targetActions := make(map[int64][]Action)
-	scheduler.segmentTasks.Range(func(_ replicaSegmentIndex, task Task) bool {
-		taskCollID := task.CollectionID()
-		if collectionID != -1 && collectionID != taskCollID {
-			return true
-		}
-		actions := filterActions(task.Actions(), nodeID)
-		if len(actions) > 0 {
-			targetActions[taskCollID] = append(targetActions[taskCollID], actions...)
-		}
-		return true
-	})
-
-	return scheduler.calculateTaskDelta(targetActions)
+	return scheduler.segmentTaskDelta.Get(nodeID, collectionID)
 }
 
 func (scheduler *taskScheduler) GetChannelTaskDelta(nodeID, collectionID int64) int {
-	targetActions := make(map[int64][]Action)
-	scheduler.channelTasks.Range(func(_ replicaChannelIndex, task Task) bool {
-		taskCollID := task.CollectionID()
-		if collectionID != -1 && collectionID != taskCollID {
-			return true
-		}
-		actions := filterActions(task.Actions(), nodeID)
-		if len(actions) > 0 {
-			targetActions[taskCollID] = append(targetActions[taskCollID], actions...)
-		}
-		return true
-	})
-
-	return scheduler.calculateTaskDelta(targetActions)
+	return scheduler.channelTaskDelta.Get(nodeID, collectionID)
 }
 
-// filter actions by nodeID
-func filterActions(actions []Action, nodeID int64) []Action {
-	filtered := make([]Action, 0, len(actions))
-	for _, action := range actions {
-		if nodeID == -1 || action.Node() == nodeID {
-			filtered = append(filtered, action)
+// incExecutingTaskDelta adds the delta of every action of task to the
+// per-node, per-collection segment or channel counters.
+func (scheduler *taskScheduler) incExecutingTaskDelta(task Task) {
+	for _, action := range task.Actions() {
+		delta := scheduler.computeActionDelta(task.CollectionID(), action)
+		switch action.(type) {
+		case *SegmentAction:
+			scheduler.segmentTaskDelta.Add(action.Node(), task.CollectionID(), delta)
+		case *ChannelAction:
+			scheduler.channelTaskDelta.Add(action.Node(), task.CollectionID(), delta)
 		}
 	}
-	return filtered
 }
 
-func (scheduler *taskScheduler) calculateTaskDelta(targetActions map[int64][]Action) int {
-	sum := 0
-	for collectionID, actions := range targetActions {
-		for _, action := range actions {
-			delta := 0
-			if action.Type() == ActionTypeGrow {
-				delta = 1
-			} else if action.Type() == ActionTypeReduce {
-				delta = -1
-			}
-
-			switch action := action.(type) {
-			case *SegmentAction:
-				// skip growing segment's count, cause doesn't know realtime row number of growing segment
-				if action.Scope == querypb.DataScope_Historical {
-					segment := scheduler.targetMgr.GetSealedSegment(scheduler.ctx, collectionID, action.SegmentID, meta.NextTargetFirst)
-					if segment != nil {
-						sum += int(segment.GetNumOfRows()) * delta
-					}
-				}
-			case *ChannelAction:
-				sum += delta
-			}
+// decExecutingTaskDelta subtracts the delta of every action of task again.
+// NOTE(review): the delta is recomputed here, not remembered from the add; if
+// the target's row count for a segment changes in between, the counters will
+// not balance - confirm this cannot happen.
+func (scheduler *taskScheduler) decExecutingTaskDelta(task Task) {
+	for _, action := range task.Actions() {
+		delta := scheduler.computeActionDelta(task.CollectionID(), action)
+		switch action.(type) {
+		case *SegmentAction:
+			scheduler.segmentTaskDelta.Sub(action.Node(), task.CollectionID(), delta)
+		case *ChannelAction:
+			scheduler.channelTaskDelta.Sub(action.Node(), task.CollectionID(), delta)
 		}
 	}
-	return sum
+}
+
+// computeActionDelta returns the signed contribution of one action: Grow is
+// positive, Reduce negative. Channel actions count 1; historical segment
+// actions count the sealed segment's rows; everything else contributes 0.
+func (scheduler *taskScheduler) computeActionDelta(collectionID int64, action Action) int {
+	delta := 0
+	if action.Type() == ActionTypeGrow {
+		delta = 1
+	} else if action.Type() == ActionTypeReduce {
+		delta = -1
+	}
+
+	switch action := action.(type) {
+	case *SegmentAction:
+		// skip growing segment's count, cause doesn't know realtime row number of growing segment
+		if action.Scope == querypb.DataScope_Historical {
+			segment := scheduler.targetMgr.GetSealedSegment(scheduler.ctx, collectionID, action.SegmentID, meta.NextTargetFirst)
+			if segment != nil {
+				return int(segment.GetNumOfRows()) * delta
+			}
+		}
+	case *ChannelAction:
+		return delta
+	}
+
+	return 0
 }
 
 func (scheduler *taskScheduler) GetExecutedFlag(nodeID int64) <-chan struct{} {
@@ -891,9 +960,12 @@ func (scheduler *taskScheduler) remove(task Task) {
 	}
 
 	task.Cancel(nil)
-	scheduler.tasks.Remove(task.ID())
+	_, ok := scheduler.tasks.GetAndRemove(task.ID())
 	scheduler.waitQueue.Remove(task)
 	scheduler.processQueue.Remove(task)
+	if ok {
+		scheduler.decExecutingTaskDelta(task)
+	}
 
 	switch task := task.(type) {
 	case *SegmentTask:
diff --git a/internal/querycoordv2/task/task_test.go b/internal/querycoordv2/task/task_test.go
index c6c933dfb3..be5fc021bc 100644
--- a/internal/querycoordv2/task/task_test.go
+++ b/internal/querycoordv2/task/task_test.go
@@ -1903,6 +1903,15 @@ func (suite *TaskSuite) TestCalculateTaskDelta() {
 	suite.Equal(2, scheduler.GetChannelTaskDelta(-1, -1))
 	suite.Equal(200, scheduler.GetSegmentTaskDelta(-1, -1))
 	suite.Equal(2, scheduler.GetChannelTaskDelta(-1, -1))
+
+	scheduler.remove(task1)
+	scheduler.remove(task2)
+	scheduler.remove(task3)
+	scheduler.remove(task4)
+	suite.Equal(0, scheduler.GetSegmentTaskDelta(nodeID, coll))
+	suite.Equal(0, scheduler.GetChannelTaskDelta(nodeID, coll))
+	suite.Equal(0, scheduler.GetSegmentTaskDelta(nodeID2, coll2))
+	suite.Equal(0, scheduler.GetChannelTaskDelta(nodeID2, coll2))
 }
 
 func TestTask(t *testing.T) {