fix: Fix data race in global scheduler test using atomic counters (#42454)

issue: #42457

Replace unsafe ExpectedCalls modification with atomic.Int32 state
tracking to avoid race conditions in concurrent test execution. Changes
include:
- Use atomic counters instead of direct mock ExpectedCalls manipulation
- Add RunAndReturn with atomic state transitions for thread safety
- Remove github.com/samber/lo dependency

This prevents data race when mock framework and test goroutines access
ExpectedCalls concurrently.

Signed-off-by: Wei Liu <wei.liu@zilliz.com>
This commit is contained in:
wei liu 2025-06-03 14:18:30 +08:00 committed by GitHub
parent c827f4b948
commit 5a355d1e57
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -18,10 +18,10 @@ package task
import (
"context"
"sync/atomic"
"testing"
"time"
"github.com/samber/lo"
"github.com/stretchr/testify/assert"
mock "github.com/stretchr/testify/mock"
@ -137,16 +137,14 @@ func TestGlobalScheduler_TestSchedule(t *testing.T) {
NodeID: 2,
AvailableSlots: 100,
},
})
}).Maybe()
newTask := func() *MockTask {
task := NewMockTask(t)
task.EXPECT().GetTaskID().Return(1)
task.EXPECT().GetTaskType().Return(taskcommon.Compaction)
task.EXPECT().GetTaskState().Return(taskcommon.Init)
task.EXPECT().GetTaskType().Return(taskcommon.Compaction)
task.EXPECT().SetTaskTime(mock.Anything, mock.Anything).Return()
task.EXPECT().GetTaskSlot().Return(1)
task.EXPECT().GetTaskID().Return(1).Maybe()
task.EXPECT().GetTaskType().Return(taskcommon.Compaction).Maybe()
task.EXPECT().SetTaskTime(mock.Anything, mock.Anything).Return().Maybe()
task.EXPECT().GetTaskSlot().Return(1).Maybe()
return task
}
@ -156,12 +154,20 @@ func TestGlobalScheduler_TestSchedule(t *testing.T) {
defer scheduler.Stop()
task := newTask()
var stateCounter atomic.Int32
// Set initial state
task.EXPECT().GetTaskState().RunAndReturn(func() taskcommon.State {
counter := stateCounter.Load()
if counter == 0 {
return taskcommon.Init
}
return taskcommon.Retry
}).Maybe()
task.EXPECT().CreateTaskOnWorker(mock.Anything, mock.Anything).Run(func(nodeID int64, cluster session.Cluster) {
task.ExpectedCalls = lo.Filter(task.ExpectedCalls, func(call *mock.Call, _ int) bool {
return call.Method != "GetTaskState"
})
task.EXPECT().GetTaskState().Return(taskcommon.Retry)
})
stateCounter.Store(1) // Mark that CreateTaskOnWorker was called
}).Maybe()
scheduler.Enqueue(task)
assert.Eventually(t, func() bool {
@ -179,18 +185,27 @@ func TestGlobalScheduler_TestSchedule(t *testing.T) {
defer scheduler.Stop()
task := newTask()
var stateCounter atomic.Int32
task.EXPECT().GetTaskState().RunAndReturn(func() taskcommon.State {
counter := stateCounter.Load()
switch counter {
case 0:
return taskcommon.Init
case 1:
return taskcommon.InProgress
default:
return taskcommon.Retry
}
}).Maybe()
task.EXPECT().CreateTaskOnWorker(mock.Anything, mock.Anything).Run(func(nodeID int64, cluster session.Cluster) {
task.ExpectedCalls = lo.Filter(task.ExpectedCalls, func(call *mock.Call, _ int) bool {
return call.Method != "GetTaskState"
})
task.EXPECT().GetTaskState().Return(taskcommon.InProgress)
})
stateCounter.Store(1) // CreateTaskOnWorker called
}).Maybe()
task.EXPECT().QueryTaskOnWorker(mock.Anything).Run(func(cluster session.Cluster) {
task.ExpectedCalls = lo.Filter(task.ExpectedCalls, func(call *mock.Call, _ int) bool {
return call.Method != "GetTaskState"
})
task.EXPECT().GetTaskState().Return(taskcommon.Retry)
})
stateCounter.Store(2) // QueryTaskOnWorker called
}).Maybe()
scheduler.Enqueue(task)
assert.Eventually(t, func() bool {
@ -208,19 +223,31 @@ func TestGlobalScheduler_TestSchedule(t *testing.T) {
defer scheduler.Stop()
task := newTask()
var stateCounter atomic.Int32
task.EXPECT().GetTaskState().RunAndReturn(func() taskcommon.State {
counter := stateCounter.Load()
switch counter {
case 0:
return taskcommon.Init
case 1:
return taskcommon.InProgress
default:
return taskcommon.Finished
}
}).Maybe()
task.EXPECT().CreateTaskOnWorker(mock.Anything, mock.Anything).Run(func(nodeID int64, cluster session.Cluster) {
task.ExpectedCalls = lo.Filter(task.ExpectedCalls, func(call *mock.Call, _ int) bool {
return call.Method != "GetTaskState"
})
task.EXPECT().GetTaskState().Return(taskcommon.InProgress)
})
stateCounter.Store(1) // CreateTaskOnWorker called
}).Maybe()
task.EXPECT().QueryTaskOnWorker(mock.Anything).Run(func(cluster session.Cluster) {
task.ExpectedCalls = lo.Filter(task.ExpectedCalls, func(call *mock.Call, _ int) bool {
return call.Method != "GetTaskState"
})
task.EXPECT().GetTaskState().Return(taskcommon.Finished)
})
task.EXPECT().DropTaskOnWorker(mock.Anything).Return()
stateCounter.Store(2) // QueryTaskOnWorker called
}).Maybe()
task.EXPECT().DropTaskOnWorker(mock.Anything).Run(func(cluster session.Cluster) {
stateCounter.Store(3) // DropTaskOnWorker called
}).Maybe()
scheduler.Enqueue(task)
assert.Eventually(t, func() bool {