diff --git a/internal/datanode/channel_manager.go b/internal/datanode/channel_manager.go index 97ae15e714..1fb3e4d4a0 100644 --- a/internal/datanode/channel_manager.go +++ b/internal/datanode/channel_manager.go @@ -32,7 +32,10 @@ import ( "github.com/milvus-io/milvus/pkg/util/typeutil" ) -type releaseFunc func(channel string) +type ( + releaseFunc func(channel string) + watchFunc func(ctx context.Context, dn *DataNode, info *datapb.ChannelWatchInfo, tickler *tickler) (*dataSyncService, error) +) type ChannelManager interface { Submit(info *datapb.ChannelWatchInfo) error @@ -206,7 +209,7 @@ func (m *ChannelManagerImpl) handleOpState(opState *opState) { } func (m *ChannelManagerImpl) getOrCreateRunner(channel string) *opRunner { - runner, loaded := m.opRunners.GetOrInsert(channel, NewOpRunner(channel, m.dn, m.releaseFunc, m.communicateCh)) + runner, loaded := m.opRunners.GetOrInsert(channel, NewOpRunner(channel, m.dn, m.releaseFunc, executeWatch, m.communicateCh)) if !loaded { runner.Start() } @@ -228,6 +231,7 @@ type opRunner struct { channel string dn *DataNode releaseFunc releaseFunc + watchFunc watchFunc guard sync.RWMutex allOps map[UniqueID]*opInfo // opID -> tickler @@ -238,11 +242,12 @@ type opRunner struct { closeWg sync.WaitGroup } -func NewOpRunner(channel string, dn *DataNode, f releaseFunc, resultCh chan *opState) *opRunner { +func NewOpRunner(channel string, dn *DataNode, releaseF releaseFunc, watchF watchFunc, resultCh chan *opState) *opRunner { return &opRunner{ channel: channel, dn: dn, - releaseFunc: f, + releaseFunc: releaseF, + watchFunc: watchF, opsInQueue: make(chan *datapb.ChannelWatchInfo, 10), allOps: make(map[UniqueID]*opInfo), resultCh: resultCh, @@ -333,16 +338,16 @@ func (r *opRunner) watchWithTimer(info *datapb.ChannelWatchInfo) *opState { opInfo.tickler = tickler var ( - successSig = make(chan struct{}, 1) - waiter sync.WaitGroup + successSig = make(chan struct{}, 1) + finishWaiter sync.WaitGroup ) watchTimeout := Params.DataCoordCfg.WatchTimeoutInterval.GetAsDuration(time.Second) ctx, cancel := context.WithTimeout(context.Background(), watchTimeout) defer cancel() - startTimer := func(wg *sync.WaitGroup) { - defer wg.Done() + startTimer := func(finishWg *sync.WaitGroup) { + defer finishWg.Done() timer := time.NewTimer(watchTimeout) defer timer.Stop() @@ -377,11 +382,12 @@ func (r *opRunner) watchWithTimer(info *datapb.ChannelWatchInfo) *opState { } } - waiter.Add(2) - go startTimer(&waiter) + finishWaiter.Add(2) + go startTimer(&finishWaiter) + go func() { - defer waiter.Done() - fg, err := executeWatch(ctx, r.dn, info, tickler) + defer finishWaiter.Done() + fg, err := r.watchFunc(ctx, r.dn, info, tickler) if err != nil { opState.state = datapb.ChannelWatchState_WatchFailure } else { @@ -391,7 +397,7 @@ func (r *opRunner) watchWithTimer(info *datapb.ChannelWatchInfo) *opState { } }() - waiter.Wait() + finishWaiter.Wait() return opState } @@ -402,13 +408,14 @@ func (r *opRunner) releaseWithTimer(releaseFunc releaseFunc, channel string, opI opID: opID, } var ( - successSig = make(chan struct{}, 1) - waiter sync.WaitGroup + successSig = make(chan struct{}, 1) + finishWaiter sync.WaitGroup ) log := log.With(zap.Int64("opID", opID), zap.String("channel", channel)) - startTimer := func(wg *sync.WaitGroup) { - defer wg.Done() + startTimer := func(finishWaiter *sync.WaitGroup) { + defer finishWaiter.Done() + releaseTimeout := Params.DataCoordCfg.WatchTimeoutInterval.GetAsDuration(time.Second) timer := time.NewTimer(releaseTimeout) defer timer.Stop() @@ -435,8 +442,8 @@ func (r *opRunner) releaseWithTimer(releaseFunc releaseFunc, channel string, opI } } - waiter.Add(1) - go startTimer(&waiter) + finishWaiter.Add(1) + go startTimer(&finishWaiter) go func() { // TODO: failure should panic this DN, but we're not sure how // to recover when releaseFunc stuck. @@ -450,7 +457,7 @@ func (r *opRunner) releaseWithTimer(releaseFunc releaseFunc, channel string, opI successSig <- struct{}{} }() - waiter.Wait() + finishWaiter.Wait() return opState } diff --git a/internal/datanode/channel_manager_test.go b/internal/datanode/channel_manager_test.go index 85c13d7fe9..0dad91c14c 100644 --- a/internal/datanode/channel_manager_test.go +++ b/internal/datanode/channel_manager_test.go @@ -20,6 +20,7 @@ import ( "context" "testing" + "github.com/cockroachdb/errors" "github.com/stretchr/testify/suite" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" @@ -56,7 +57,7 @@ func (s *OpRunnerSuite) TestWatchWithTimer() { mockReleaseFunc := func(channel string) { log.Info("mock release func") } - runner := NewOpRunner(channel, s.node, mockReleaseFunc, commuCh) + runner := NewOpRunner(channel, s.node, mockReleaseFunc, executeWatch, commuCh) err := runner.Enqueue(info) s.Require().NoError(err) @@ -67,6 +68,35 @@ func (s *OpRunnerSuite) TestWatchWithTimer() { runner.FinishOp(100) } +func (s *OpRunnerSuite) TestWatchTimeout() { + channel := "by-dev-rootcoord-dml-1000" + paramtable.Get().Save(Params.DataCoordCfg.WatchTimeoutInterval.Key, "0.000001") + defer paramtable.Get().Reset(Params.DataCoordCfg.WatchTimeoutInterval.Key) + info := getWatchInfoByOpID(100, channel, datapb.ChannelWatchState_ToWatch) + + sig := make(chan struct{}) + commuCh := make(chan *opState) + + mockReleaseFunc := func(channel string) { log.Info("mock release func") } + mockWatchFunc := func(ctx context.Context, dn *DataNode, info *datapb.ChannelWatchInfo, tickler *tickler) (*dataSyncService, error) { + <-ctx.Done() + sig <- struct{}{} + return nil, errors.New("timeout") + } + + runner := NewOpRunner(channel, s.node, mockReleaseFunc, mockWatchFunc, commuCh) + runner.Start() + defer runner.Close() + err := runner.Enqueue(info) + s.Require().NoError(err) + + <-sig + opState := <-commuCh + s.Require().NotNil(opState) + s.Equal(info.GetOpID(), opState.opID) + s.Equal(datapb.ChannelWatchState_WatchFailure, opState.state) +} + type OpRunnerSuite struct { suite.Suite node *DataNode @@ -126,26 +156,6 @@ func (s *ChannelManagerSuite) TearDownTest() { } } -func (s *ChannelManagerSuite) TestWatchFail() { - channel := "by-dev-rootcoord-dml-2" - paramtable.Get().Save(Params.DataCoordCfg.WatchTimeoutInterval.Key, "0.000001") - defer paramtable.Get().Reset(Params.DataCoordCfg.WatchTimeoutInterval.Key) - info := getWatchInfoByOpID(100, channel, datapb.ChannelWatchState_ToWatch) - s.Require().Equal(0, s.manager.opRunners.Len()) - err := s.manager.Submit(info) - s.Require().NoError(err) - - opState := <-s.manager.communicateCh - s.Require().NotNil(opState) - s.Equal(info.GetOpID(), opState.opID) - s.Equal(datapb.ChannelWatchState_WatchFailure, opState.state) - - s.manager.handleOpState(opState) - - resp := s.manager.GetProgress(info) - s.Equal(datapb.ChannelWatchState_WatchFailure, resp.GetState()) -} - func (s *ChannelManagerSuite) TestReleaseStuck() { var ( channel = "by-dev-rootcoord-dml-2"