diff --git a/internal/querynode/mock_tsafe_replica_test.go b/internal/querynode/mock_tsafe_replica_test.go new file mode 100644 index 0000000000..64c2e70737 --- /dev/null +++ b/internal/querynode/mock_tsafe_replica_test.go @@ -0,0 +1,248 @@ +// Code generated by mockery v2.14.0. DO NOT EDIT. + +package querynode + +import mock "github.com/stretchr/testify/mock" + +// MockTSafeReplicaInterface is an autogenerated mock type for the TSafeReplicaInterface type +type MockTSafeReplicaInterface struct { + mock.Mock +} + +type MockTSafeReplicaInterface_Expecter struct { + mock *mock.Mock +} + +func (_m *MockTSafeReplicaInterface) EXPECT() *MockTSafeReplicaInterface_Expecter { + return &MockTSafeReplicaInterface_Expecter{mock: &_m.Mock} +} + +// Watch provides a mock function with given fields: +func (_m *MockTSafeReplicaInterface) Watch() Listener { + ret := _m.Called() + + var r0 Listener + if rf, ok := ret.Get(0).(func() Listener); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(Listener) + } + } + + return r0 +} + +// MockTSafeReplicaInterface_Watch_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Watch' +type MockTSafeReplicaInterface_Watch_Call struct { + *mock.Call +} + +// Watch is a helper method to define mock.On call +func (_e *MockTSafeReplicaInterface_Expecter) Watch() *MockTSafeReplicaInterface_Watch_Call { + return &MockTSafeReplicaInterface_Watch_Call{Call: _e.mock.On("Watch")} +} + +func (_c *MockTSafeReplicaInterface_Watch_Call) Run(run func()) *MockTSafeReplicaInterface_Watch_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockTSafeReplicaInterface_Watch_Call) Return(_a0 Listener) *MockTSafeReplicaInterface_Watch_Call { + _c.Call.Return(_a0) + return _c +} + +// WatchChannel provides a mock function with given fields: channel +func (_m *MockTSafeReplicaInterface) WatchChannel(channel string) Listener { + ret := _m.Called(channel) + + var r0 Listener + if rf, ok := ret.Get(0).(func(string) Listener); ok { + r0 = rf(channel) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(Listener) + } + } + + return r0 +} + +// MockTSafeReplicaInterface_WatchChannel_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'WatchChannel' +type MockTSafeReplicaInterface_WatchChannel_Call struct { + *mock.Call +} + +// WatchChannel is a helper method to define mock.On call +// - channel string +func (_e *MockTSafeReplicaInterface_Expecter) WatchChannel(channel interface{}) *MockTSafeReplicaInterface_WatchChannel_Call { + return &MockTSafeReplicaInterface_WatchChannel_Call{Call: _e.mock.On("WatchChannel", channel)} +} + +func (_c *MockTSafeReplicaInterface_WatchChannel_Call) Run(run func(channel string)) *MockTSafeReplicaInterface_WatchChannel_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(string)) + }) + return _c +} + +func (_c *MockTSafeReplicaInterface_WatchChannel_Call) Return(_a0 Listener) *MockTSafeReplicaInterface_WatchChannel_Call { + _c.Call.Return(_a0) + return _c +} + +// addTSafe provides a mock function with given fields: vChannel +func (_m *MockTSafeReplicaInterface) addTSafe(vChannel string) { + _m.Called(vChannel) +} + +// MockTSafeReplicaInterface_addTSafe_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'addTSafe' +type MockTSafeReplicaInterface_addTSafe_Call struct { + *mock.Call +} + +// addTSafe is a helper method to define mock.On call +// - vChannel string +func (_e *MockTSafeReplicaInterface_Expecter) addTSafe(vChannel interface{}) *MockTSafeReplicaInterface_addTSafe_Call { + return &MockTSafeReplicaInterface_addTSafe_Call{Call: _e.mock.On("addTSafe", vChannel)} +} + +func (_c *MockTSafeReplicaInterface_addTSafe_Call) Run(run func(vChannel string)) *MockTSafeReplicaInterface_addTSafe_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(string)) + }) + return _c +} + +func (_c *MockTSafeReplicaInterface_addTSafe_Call) Return() *MockTSafeReplicaInterface_addTSafe_Call { + _c.Call.Return() + return _c +} + +// getTSafe provides a mock function with given fields: vChannel +func (_m *MockTSafeReplicaInterface) getTSafe(vChannel string) (uint64, error) { + ret := _m.Called(vChannel) + + var r0 uint64 + if rf, ok := ret.Get(0).(func(string) uint64); ok { + r0 = rf(vChannel) + } else { + r0 = ret.Get(0).(uint64) + } + + var r1 error + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(vChannel) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockTSafeReplicaInterface_getTSafe_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'getTSafe' +type MockTSafeReplicaInterface_getTSafe_Call struct { + *mock.Call +} + +// getTSafe is a helper method to define mock.On call +// - vChannel string +func (_e *MockTSafeReplicaInterface_Expecter) getTSafe(vChannel interface{}) *MockTSafeReplicaInterface_getTSafe_Call { + return &MockTSafeReplicaInterface_getTSafe_Call{Call: _e.mock.On("getTSafe", vChannel)} +} + +func (_c *MockTSafeReplicaInterface_getTSafe_Call) Run(run func(vChannel string)) *MockTSafeReplicaInterface_getTSafe_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(string)) + }) + return _c +} + +func (_c *MockTSafeReplicaInterface_getTSafe_Call) Return(_a0 uint64, _a1 error) *MockTSafeReplicaInterface_getTSafe_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +// removeTSafe provides a mock function with given fields: vChannel +func (_m *MockTSafeReplicaInterface) removeTSafe(vChannel string) { + _m.Called(vChannel) +} + +// MockTSafeReplicaInterface_removeTSafe_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'removeTSafe' +type MockTSafeReplicaInterface_removeTSafe_Call struct { + *mock.Call +} + +// removeTSafe is a helper method to define mock.On call +// - vChannel string +func (_e *MockTSafeReplicaInterface_Expecter) removeTSafe(vChannel interface{}) *MockTSafeReplicaInterface_removeTSafe_Call { + return &MockTSafeReplicaInterface_removeTSafe_Call{Call: _e.mock.On("removeTSafe", vChannel)} +} + +func (_c *MockTSafeReplicaInterface_removeTSafe_Call) Run(run func(vChannel string)) *MockTSafeReplicaInterface_removeTSafe_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(string)) + }) + return _c +} + +func (_c *MockTSafeReplicaInterface_removeTSafe_Call) Return() *MockTSafeReplicaInterface_removeTSafe_Call { + _c.Call.Return() + return _c +} + +// setTSafe provides a mock function with given fields: vChannel, timestamp +func (_m *MockTSafeReplicaInterface) setTSafe(vChannel string, timestamp uint64) error { + ret := _m.Called(vChannel, timestamp) + + var r0 error + if rf, ok := ret.Get(0).(func(string, uint64) error); ok { + r0 = rf(vChannel, timestamp) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockTSafeReplicaInterface_setTSafe_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'setTSafe' +type MockTSafeReplicaInterface_setTSafe_Call struct { + *mock.Call +} + +// setTSafe is a helper method to define mock.On call +// - vChannel string +// - timestamp uint64 +func (_e *MockTSafeReplicaInterface_Expecter) setTSafe(vChannel interface{}, timestamp interface{}) *MockTSafeReplicaInterface_setTSafe_Call { + return &MockTSafeReplicaInterface_setTSafe_Call{Call: _e.mock.On("setTSafe", vChannel, timestamp)} +} + +func (_c *MockTSafeReplicaInterface_setTSafe_Call) Run(run func(vChannel string, timestamp uint64)) *MockTSafeReplicaInterface_setTSafe_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(string), args[1].(uint64)) + }) + return _c +} + +func (_c *MockTSafeReplicaInterface_setTSafe_Call) Return(_a0 error) *MockTSafeReplicaInterface_setTSafe_Call { + _c.Call.Return(_a0) + return _c +} + +type mockConstructorTestingTNewMockTSafeReplicaInterface interface { + mock.TestingT + Cleanup(func()) +} + +// NewMockTSafeReplicaInterface creates a new instance of MockTSafeReplicaInterface. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +func NewMockTSafeReplicaInterface(t mockConstructorTestingTNewMockTSafeReplicaInterface) *MockTSafeReplicaInterface { + mock := &MockTSafeReplicaInterface{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/querynode/shard_cluster.go b/internal/querynode/shard_cluster.go index 140070121a..c1490a765f 100644 --- a/internal/querynode/shard_cluster.go +++ b/internal/querynode/shard_cluster.go @@ -1049,11 +1049,10 @@ func (sc *ShardCluster) Search(ctx context.Context, req *querypb.SearchRequest, resultMut.Lock() defer resultMut.Unlock() if streamErr != nil { - cancel() - // not set cancel error - if !errors.Is(streamErr, context.Canceled) { + if err == nil { err = fmt.Errorf("stream operation failed: %w", streamErr) } + cancel() } }() @@ -1077,11 +1076,10 @@ func (sc *ShardCluster) Search(ctx context.Context, req *querypb.SearchRequest, resultMut.Lock() defer resultMut.Unlock() if nodeErr != nil || partialResult.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { - cancel() - // not set cancel error - if !errors.Is(nodeErr, context.Canceled) { + if err == nil { err = fmt.Errorf("Search %d failed, reason %s err %w", node.nodeID, partialResult.GetStatus().GetReason(), nodeErr) } + cancel() return } results = append(results, partialResult) @@ -1128,11 +1126,10 @@ func (sc *ShardCluster) Query(ctx context.Context, req *querypb.QueryRequest, wi streamErr := withStreaming(reqCtx) if streamErr != nil { - cancel() - // not set cancel error - if !errors.Is(streamErr, context.Canceled) { + if err == nil { err = fmt.Errorf("stream operation failed: %w", streamErr) } + cancel() } }() @@ -1156,8 +1153,8 @@ func (sc *ShardCluster) Query(ctx context.Context, req *querypb.QueryRequest, wi resultMut.Lock() defer resultMut.Unlock() if nodeErr != nil || partialResult.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { - cancel() err = fmt.Errorf("Query %d failed, reason %s err %w", node.nodeID, partialResult.GetStatus().GetReason(), nodeErr) + cancel() return } results = append(results, partialResult) diff --git a/internal/querynode/task_read.go b/internal/querynode/task_read.go index 23ca034f60..d485ee8e12 100644 --- a/internal/querynode/task_read.go +++ b/internal/querynode/task_read.go @@ -41,6 +41,7 @@ type readTask interface { CanMergeWith(readTask) bool CPUUsage() int32 Timeout() bool + TimeoutError() error SetMaxCPUUsage(int32) SetStep(step TaskStep) @@ -133,12 +134,16 @@ func (b *baseReadTask) Timeout() bool { return !funcutil.CheckCtxValid(b.Ctx()) } +func (b *baseReadTask) TimeoutError() error { + return b.ctx.Err() +} + func (b *baseReadTask) Ready() (bool, error) { if b.waitTSafeTr == nil { b.waitTSafeTr = timerecord.NewTimeRecorder("waitTSafeTimeRecorder") } if b.Timeout() { - return false, fmt.Errorf("deadline exceed") + return false, b.TimeoutError() } var channel Channel if b.DataScope == querypb.DataScope_Streaming { diff --git a/internal/querynode/task_read_test.go b/internal/querynode/task_read_test.go new file mode 100644 index 0000000000..72ade4777b --- /dev/null +++ b/internal/querynode/task_read_test.go @@ -0,0 +1,115 @@ +package querynode + +import ( + "context" + "testing" + "time" + + "github.com/milvus-io/milvus/internal/mocks" + "github.com/milvus-io/milvus/internal/util/timerecord" + "github.com/stretchr/testify/suite" +) + +type baseReadTaskSuite struct { + suite.Suite + + qs *queryShard + tsafe *MockTSafeReplicaInterface + + task *baseReadTask +} + +func (s *baseReadTaskSuite) SetupSuite() { + meta := newMockReplicaInterface() + meta.getCollectionByIDFunc = func(collectionID UniqueID) (*Collection, error) { + return &Collection{ + id: defaultCollectionID, + }, nil + } + rcm := &mocks.ChunkManager{} + lcm := &mocks.ChunkManager{} + + tsafe := &MockTSafeReplicaInterface{} + + qs, err := newQueryShard(context.Background(), defaultCollectionID, defaultDMLChannel, defaultReplicaID, nil, meta, tsafe, lcm, rcm, false) + s.Require().NoError(err) + + s.qs = qs +} + +func (s *baseReadTaskSuite) TearDownSuite() { + s.qs.Close() +} + +func (s *baseReadTaskSuite) SetupTest() { + s.task = &baseReadTask{QS: s.qs, tr: timerecord.NewTimeRecorder("baseReadTaskTest")} +} + +func (s *baseReadTaskSuite) TearDownTest() { + s.task = nil +} + +func (s *baseReadTaskSuite) TestPreExecute() { + ctx := context.Background() + err := s.task.PreExecute(ctx) + s.Assert().NoError(err) + s.Assert().Equal(TaskStepPreExecute, s.task.step) +} + +func (s *baseReadTaskSuite) TestExecute() { + ctx := context.Background() + err := s.task.Execute(ctx) + s.Assert().NoError(err) + s.Assert().Equal(TaskStepExecute, s.task.step) +} + +func (s *baseReadTaskSuite) TestTimeout() { + s.Run("background ctx", func() { + s.task.ctx = context.Background() + s.Assert().False(s.task.Timeout()) + }) + + s.Run("context canceled", func() { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + s.task.ctx = ctx + + s.Assert().True(s.task.Timeout()) + }) + + s.Run("deadline exceeded", func() { + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(-1*time.Minute)) + defer cancel() + s.task.ctx = ctx + + s.Assert().True(s.task.Timeout()) + }) +} + +func (s *baseReadTaskSuite) TestTimeoutError() { + s.Run("background ctx", func() { + s.task.ctx = context.Background() + s.Assert().Nil(s.task.TimeoutError()) + }) + + s.Run("context canceled", func() { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + s.task.ctx = ctx + + s.Assert().ErrorIs(s.task.TimeoutError(), context.Canceled) + }) + + s.Run("deadline exceeded", func() { + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(-1*time.Minute)) + defer cancel() + s.task.ctx = ctx + + s.Assert().ErrorIs(s.task.TimeoutError(), context.DeadlineExceeded) + }) + +} + +func TestBaseReadTask(t *testing.T) { + suite.Run(t, new(baseReadTaskSuite)) +} diff --git a/internal/querynode/task_scheduler.go b/internal/querynode/task_scheduler.go index 7aafe60600..96a2036532 100644 --- a/internal/querynode/task_scheduler.go +++ b/internal/querynode/task_scheduler.go @@ -146,7 +146,6 @@ func (s *taskScheduler) tryEvictUnsolvedReadTask(headCount int) { if diff <= 0 { return } - timeoutErr := fmt.Errorf("deadline exceed") var next *list.Element for e := s.unsolvedReadTasks.Front(); e != nil; e = next { next = e.Next() @@ -160,7 +159,7 @@ func (s *taskScheduler) tryEvictUnsolvedReadTask(headCount int) { if t.Timeout() { s.unsolvedReadTasks.Remove(e) rateCol.rtCounter.sub(t, unsolvedQueueType) - t.Notify(timeoutErr) + t.Notify(t.TimeoutError()) diff-- } } @@ -188,7 +187,7 @@ func (s *taskScheduler) scheduleReadTasks() { for { select { case <-s.ctx.Done(): - log.Warn("QueryNode sop schedulerReadTasks") + log.Warn("QueryNode stop schedulerReadTasks") return case <-s.notifyChan: @@ -273,12 +272,11 @@ func (s *taskScheduler) executeReadTasks() { defer s.wg.Done() var taskWg sync.WaitGroup defer taskWg.Wait() - timeoutErr := fmt.Errorf("deadline exceed") executeFunc := func(t readTask) { defer taskWg.Done() if t.Timeout() { - t.Notify(timeoutErr) + t.Notify(t.TimeoutError()) } else { s.processReadTask(t) } @@ -302,6 +300,7 @@ func (s *taskScheduler) executeReadTasks() { pendingTaskLen := len(s.executeReadTaskChan) taskWg.Add(1) atomic.AddInt32(&s.readConcurrency, int32(pendingTaskLen+1)) + log.Debug("begin to execute task") go executeFunc(t) for i := 0; i < pendingTaskLen; i++ { diff --git a/internal/querynode/task_scheduler_test.go b/internal/querynode/task_scheduler_test.go index 50ae6fdb7d..411219a41d 100644 --- a/internal/querynode/task_scheduler_test.go +++ b/internal/querynode/task_scheduler_test.go @@ -20,6 +20,8 @@ import ( "context" "errors" "testing" + + "github.com/stretchr/testify/assert" ) type mockTask struct { @@ -65,6 +67,7 @@ type mockReadTask struct { ready bool canMerge bool timeout bool + timeoutError error step TaskStep readyError error } @@ -89,6 +92,10 @@ func (m *mockReadTask) Timeout() bool { return m.timeout } +func (m *mockReadTask) TimeoutError() error { + return m.timeoutError +} + func (m *mockReadTask) SetMaxCPUUsage(cpu int32) { m.maxCPU = cpu } @@ -125,3 +132,80 @@ func TestTaskScheduler(t *testing.T) { ts.Close() } + +func TestTaskScheduler_tryEvictUnsolvedReadTask(t *testing.T) { + t.Run("evict canceled task", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + tSafe := newTSafeReplica() + + ts := newTaskScheduler(ctx, tSafe) + + taskCanceled := &mockReadTask{ + mockTask: mockTask{ + baseTask: baseTask{ + ctx: ctx, + done: make(chan error, 1024), + }, + }, + timeout: true, + timeoutError: context.Canceled, + } + taskNormal := &mockReadTask{ + mockTask: mockTask{ + baseTask: baseTask{ + ctx: context.Background(), + done: make(chan error, 1024), + }, + }, + } + + ts.unsolvedReadTasks.PushBack(taskNormal) + ts.unsolvedReadTasks.PushBack(taskCanceled) + + // set max len to 2 + tmp := Params.QueryNodeCfg.MaxUnsolvedQueueSize + Params.QueryNodeCfg.MaxUnsolvedQueueSize = 2 + ts.tryEvictUnsolvedReadTask(1) + Params.QueryNodeCfg.MaxUnsolvedQueueSize = tmp + + err := <-taskCanceled.done + assert.ErrorIs(t, err, context.Canceled) + + select { + case <-taskNormal.done: + t.Fail() + default: + } + + assert.Equal(t, 1, ts.unsolvedReadTasks.Len()) + }) +} + +func TestTaskScheduler_executeReadTasks(t *testing.T) { + t.Run("execute canceled task", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + tSafe := newTSafeReplica() + + ts := newTaskScheduler(ctx, tSafe) + ts.Start() + defer ts.Close() + + taskCanceled := &mockReadTask{ + mockTask: mockTask{ + baseTask: baseTask{ + ctx: ctx, + done: make(chan error, 1024), + }, + }, + timeout: true, + timeoutError: context.Canceled, + } + + ts.executeReadTaskChan <- taskCanceled + + err := <-taskCanceled.done + assert.ErrorIs(t, err, context.Canceled) + }) +}