diff --git a/internal/datacoord/task/global_scheduler_test.go b/internal/datacoord/task/global_scheduler_test.go index f1eabe63ec..0e3c24a2c5 100644 --- a/internal/datacoord/task/global_scheduler_test.go +++ b/internal/datacoord/task/global_scheduler_test.go @@ -18,10 +18,10 @@ package task import ( "context" + "sync/atomic" "testing" "time" - "github.com/samber/lo" "github.com/stretchr/testify/assert" mock "github.com/stretchr/testify/mock" @@ -137,16 +137,14 @@ func TestGlobalScheduler_TestSchedule(t *testing.T) { NodeID: 2, AvailableSlots: 100, }, - }) + }).Maybe() newTask := func() *MockTask { task := NewMockTask(t) - task.EXPECT().GetTaskID().Return(1) - task.EXPECT().GetTaskType().Return(taskcommon.Compaction) - task.EXPECT().GetTaskState().Return(taskcommon.Init) - task.EXPECT().GetTaskType().Return(taskcommon.Compaction) - task.EXPECT().SetTaskTime(mock.Anything, mock.Anything).Return() - task.EXPECT().GetTaskSlot().Return(1) + task.EXPECT().GetTaskID().Return(1).Maybe() + task.EXPECT().GetTaskType().Return(taskcommon.Compaction).Maybe() + task.EXPECT().SetTaskTime(mock.Anything, mock.Anything).Return().Maybe() + task.EXPECT().GetTaskSlot().Return(1).Maybe() return task } @@ -156,12 +154,20 @@ func TestGlobalScheduler_TestSchedule(t *testing.T) { defer scheduler.Stop() task := newTask() + var stateCounter atomic.Int32 + + // Set initial state + task.EXPECT().GetTaskState().RunAndReturn(func() taskcommon.State { + counter := stateCounter.Load() + if counter == 0 { + return taskcommon.Init + } + return taskcommon.Retry + }).Maybe() + task.EXPECT().CreateTaskOnWorker(mock.Anything, mock.Anything).Run(func(nodeID int64, cluster session.Cluster) { - task.ExpectedCalls = lo.Filter(task.ExpectedCalls, func(call *mock.Call, _ int) bool { - return call.Method != "GetTaskState" - }) - task.EXPECT().GetTaskState().Return(taskcommon.Retry) - }) + stateCounter.Store(1) // Mark that CreateTaskOnWorker was called + }).Maybe() scheduler.Enqueue(task) assert.Eventually(t, func() bool { @@ -179,18 +185,27 @@ func TestGlobalScheduler_TestSchedule(t *testing.T) { defer scheduler.Stop() task := newTask() + var stateCounter atomic.Int32 + + task.EXPECT().GetTaskState().RunAndReturn(func() taskcommon.State { + counter := stateCounter.Load() + switch counter { + case 0: + return taskcommon.Init + case 1: + return taskcommon.InProgress + default: + return taskcommon.Retry + } + }).Maybe() + task.EXPECT().CreateTaskOnWorker(mock.Anything, mock.Anything).Run(func(nodeID int64, cluster session.Cluster) { - task.ExpectedCalls = lo.Filter(task.ExpectedCalls, func(call *mock.Call, _ int) bool { - return call.Method != "GetTaskState" - }) - task.EXPECT().GetTaskState().Return(taskcommon.InProgress) - }) + stateCounter.Store(1) // CreateTaskOnWorker called + }).Maybe() + task.EXPECT().QueryTaskOnWorker(mock.Anything).Run(func(cluster session.Cluster) { - task.ExpectedCalls = lo.Filter(task.ExpectedCalls, func(call *mock.Call, _ int) bool { - return call.Method != "GetTaskState" - }) - task.EXPECT().GetTaskState().Return(taskcommon.Retry) - }) + stateCounter.Store(2) // QueryTaskOnWorker called + }).Maybe() scheduler.Enqueue(task) assert.Eventually(t, func() bool { @@ -208,19 +223,31 @@ func TestGlobalScheduler_TestSchedule(t *testing.T) { defer scheduler.Stop() task := newTask() + var stateCounter atomic.Int32 + + task.EXPECT().GetTaskState().RunAndReturn(func() taskcommon.State { + counter := stateCounter.Load() + switch counter { + case 0: + return taskcommon.Init + case 1: + return taskcommon.InProgress + default: + return taskcommon.Finished + } + }).Maybe() + task.EXPECT().CreateTaskOnWorker(mock.Anything, mock.Anything).Run(func(nodeID int64, cluster session.Cluster) { - task.ExpectedCalls = lo.Filter(task.ExpectedCalls, func(call *mock.Call, _ int) bool { - return call.Method != "GetTaskState" - }) - task.EXPECT().GetTaskState().Return(taskcommon.InProgress) - }) + stateCounter.Store(1) // CreateTaskOnWorker called + }).Maybe() + task.EXPECT().QueryTaskOnWorker(mock.Anything).Run(func(cluster session.Cluster) { - task.ExpectedCalls = lo.Filter(task.ExpectedCalls, func(call *mock.Call, _ int) bool { - return call.Method != "GetTaskState" - }) - task.EXPECT().GetTaskState().Return(taskcommon.Finished) - }) - task.EXPECT().DropTaskOnWorker(mock.Anything).Return() + stateCounter.Store(2) // QueryTaskOnWorker called + }).Maybe() + + task.EXPECT().DropTaskOnWorker(mock.Anything).Run(func(cluster session.Cluster) { + stateCounter.Store(3) // DropTaskOnWorker called + }).Maybe() scheduler.Enqueue(task) assert.Eventually(t, func() bool {