diff --git a/internal/datacoord/channel_manager_v2.go b/internal/datacoord/channel_manager_v2.go index 402b09ebb8..6243761ce7 100644 --- a/internal/datacoord/channel_manager_v2.go +++ b/internal/datacoord/channel_manager_v2.go @@ -19,6 +19,7 @@ package datacoord import ( "context" "fmt" + "sync" "time" "github.com/cockroachdb/errors" @@ -59,9 +60,9 @@ type SubCluster interface { } type ChannelManagerImplV2 struct { - ctx context.Context cancel context.CancelFunc mu lock.RWMutex + wg sync.WaitGroup h Handler store RWChannelStore @@ -101,7 +102,6 @@ func NewChannelManagerV2( ) (*ChannelManagerImplV2, error) { m := &ChannelManagerImplV2{ h: h, - ctx: context.TODO(), // TODO factory: NewChannelPolicyFactoryV1(), store: NewChannelStoreV2(kv), subCluster: subCluster, @@ -122,7 +122,7 @@ func NewChannelManagerV2( } func (m *ChannelManagerImplV2) Startup(ctx context.Context, legacyNodes, allNodes []int64) error { - m.ctx, m.cancel = context.WithCancel(ctx) + ctx, m.cancel = context.WithCancel(ctx) m.legacyNodes = typeutil.NewUniqueSet(legacyNodes...) @@ -156,7 +156,11 @@ func (m *ChannelManagerImplV2) Startup(ctx context.Context, legacyNodes, allNode if m.balanceCheckLoop != nil { log.Info("starting channel balance loop") - go m.balanceCheckLoop(m.ctx) + m.wg.Add(1) + go func() { + defer m.wg.Done() + m.balanceCheckLoop(ctx) + }() } log.Info("cluster start up", @@ -171,6 +175,7 @@ func (m *ChannelManagerImplV2) Startup(ctx context.Context, legacyNodes, allNode func (m *ChannelManagerImplV2) Close() { if m.cancel != nil { m.cancel() + m.wg.Wait() } } @@ -439,12 +444,12 @@ func (m *ChannelManagerImplV2) CheckLoop(ctx context.Context) { m.Balance() } case <-checkTicker.C: - m.AdvanceChannelState() + m.AdvanceChannelState(ctx) } } } -func (m *ChannelManagerImplV2) AdvanceChannelState() { +func (m *ChannelManagerImplV2) AdvanceChannelState(ctx context.Context) { m.mu.RLock() standbys := m.store.GetNodeChannelsBy(WithAllNodes(), WithChannelStates(Standby)) toNotifies := m.store.GetNodeChannelsBy(WithoutBufferNode(), WithChannelStates(ToWatch, ToRelease)) @@ -452,9 +457,9 @@ func (m *ChannelManagerImplV2) AdvanceChannelState() { m.mu.RUnlock() // Processing standby channels - updatedStandbys := m.advanceStandbys(standbys) - updatedToCheckes := m.advanceToChecks(toChecks) - updatedToNotifies := m.advanceToNotifies(toNotifies) + updatedStandbys := m.advanceStandbys(ctx, standbys) + updatedToCheckes := m.advanceToChecks(ctx, toChecks) + updatedToNotifies := m.advanceToNotifies(ctx, toNotifies) if updatedStandbys || updatedToCheckes || updatedToNotifies { m.lastActiveTimestamp = time.Now() @@ -477,7 +482,7 @@ func (m *ChannelManagerImplV2) finishRemoveChannel(nodeID int64, channels ...RWC } } -func (m *ChannelManagerImplV2) advanceStandbys(standbys []*NodeChannelInfo) bool { +func (m *ChannelManagerImplV2) advanceStandbys(_ context.Context, standbys []*NodeChannelInfo) bool { var advanced bool = false for _, nodeAssign := range standbys { validChannels := make(map[string]RWChannel) @@ -513,7 +518,7 @@ func (m *ChannelManagerImplV2) advanceStandbys(standbys []*NodeChannelInfo) bool return advanced } -func (m *ChannelManagerImplV2) advanceToNotifies(toNotifies []*NodeChannelInfo) bool { +func (m *ChannelManagerImplV2) advanceToNotifies(ctx context.Context, toNotifies []*NodeChannelInfo) bool { var advanced bool = false for _, nodeAssign := range toNotifies { channelCount := len(nodeAssign.Channels) @@ -537,7 +542,7 @@ func (m *ChannelManagerImplV2) advanceToNotifies(toNotifies []*NodeChannelInfo) innerCh := ch future := getOrCreateIOPool().Submit(func() (any, error) { - err := m.Notify(nodeAssign.NodeID, innerCh.GetWatchInfo()) + err := m.Notify(ctx, nodeAssign.NodeID, innerCh.GetWatchInfo()) return innerCh, err }) futures = append(futures, future) @@ -573,7 +578,7 @@ type poolResult struct { ch RWChannel } -func (m *ChannelManagerImplV2) advanceToChecks(toChecks []*NodeChannelInfo) bool { +func (m *ChannelManagerImplV2) advanceToChecks(ctx context.Context, toChecks []*NodeChannelInfo) bool { var advanced bool = false for _, nodeAssign := range toChecks { if len(nodeAssign.Channels) == 0 { @@ -592,7 +597,7 @@ func (m *ChannelManagerImplV2) advanceToChecks(toChecks []*NodeChannelInfo) bool innerCh := ch future := getOrCreateIOPool().Submit(func() (any, error) { - successful, got := m.Check(nodeAssign.NodeID, innerCh.GetWatchInfo()) + successful, got := m.Check(ctx, nodeAssign.NodeID, innerCh.GetWatchInfo()) if got { return poolResult{ successful: successful, @@ -624,14 +629,14 @@ func (m *ChannelManagerImplV2) advanceToChecks(toChecks []*NodeChannelInfo) bool return advanced } -func (m *ChannelManagerImplV2) Notify(nodeID int64, info *datapb.ChannelWatchInfo) error { +func (m *ChannelManagerImplV2) Notify(ctx context.Context, nodeID int64, info *datapb.ChannelWatchInfo) error { log := log.With( zap.String("channel", info.GetVchan().GetChannelName()), zap.Int64("assignment", nodeID), zap.String("operation", info.GetState().String()), ) log.Info("Notify channel operation") - err := m.subCluster.NotifyChannelOperation(m.ctx, nodeID, &datapb.ChannelOperationsRequest{Infos: []*datapb.ChannelWatchInfo{info}}) + err := m.subCluster.NotifyChannelOperation(ctx, nodeID, &datapb.ChannelOperationsRequest{Infos: []*datapb.ChannelWatchInfo{info}}) if err != nil { log.Warn("Fail to notify channel operations", zap.Error(err)) return err @@ -640,14 +645,14 @@ func (m *ChannelManagerImplV2) Notify(nodeID int64, info *datapb.ChannelWatchInf return nil } -func (m *ChannelManagerImplV2) Check(nodeID int64, info *datapb.ChannelWatchInfo) (successful bool, got bool) { +func (m *ChannelManagerImplV2) Check(ctx context.Context, nodeID int64, info *datapb.ChannelWatchInfo) (successful bool, got bool) { log := log.With( zap.Int64("opID", info.GetOpID()), zap.Int64("nodeID", nodeID), zap.String("check operation", info.GetState().String()), zap.String("channel", info.GetVchan().GetChannelName()), ) - resp, err := m.subCluster.CheckChannelOperationProgress(m.ctx, nodeID, info) + resp, err := m.subCluster.CheckChannelOperationProgress(ctx, nodeID, info) if err != nil { log.Warn("Fail to check channel operation progress") return false, false diff --git a/internal/datacoord/channel_manager_v2_test.go b/internal/datacoord/channel_manager_v2_test.go index 87ac92dcb3..4bacd11399 100644 --- a/internal/datacoord/channel_manager_v2_test.go +++ b/internal/datacoord/channel_manager_v2_test.go @@ -371,6 +371,8 @@ func (s *ChannelManagerSuite) TestFindWatcher() { } func (s *ChannelManagerSuite) TestAdvanceChannelState() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() s.Run("advance statndby with no available nodes", func() { chNodes := map[string]int64{ "ch1": bufferID, @@ -383,7 +385,7 @@ func (s *ChannelManagerSuite) TestAdvanceChannelState() { s.checkAssignment(m, bufferID, "ch1", Standby) s.checkAssignment(m, bufferID, "ch2", Standby) - m.AdvanceChannelState() + m.AdvanceChannelState(ctx) s.checkAssignment(m, bufferID, "ch1", Standby) s.checkAssignment(m, bufferID, "ch2", Standby) }) @@ -402,7 +404,7 @@ func (s *ChannelManagerSuite) TestAdvanceChannelState() { s.checkAssignment(m, bufferID, "ch2", Standby) s.checkAssignment(m, 1, "ch3", Watched) - m.AdvanceChannelState() + m.AdvanceChannelState(ctx) s.checkAssignment(m, 1, "ch1", ToWatch) s.checkAssignment(m, 1, "ch2", ToWatch) }) @@ -418,7 +420,7 @@ func (s *ChannelManagerSuite) TestAdvanceChannelState() { s.checkAssignment(m, 1, "ch1", ToWatch) s.checkAssignment(m, 1, "ch2", ToWatch) - m.AdvanceChannelState() + m.AdvanceChannelState(ctx) s.checkAssignment(m, 1, "ch1", Watching) s.checkAssignment(m, 1, "ch2", Watching) }) @@ -434,13 +436,13 @@ func (s *ChannelManagerSuite) TestAdvanceChannelState() { s.checkAssignment(m, 1, "ch1", ToWatch) s.checkAssignment(m, 1, "ch2", ToWatch) - m.AdvanceChannelState() + m.AdvanceChannelState(ctx) s.checkAssignment(m, 1, "ch1", Watching) s.checkAssignment(m, 1, "ch2", Watching) s.mockCluster.EXPECT().CheckChannelOperationProgress(mock.Anything, mock.Anything, mock.Anything). Return(&datapb.ChannelOperationProgressResponse{State: datapb.ChannelWatchState_ToWatch}, nil).Twice() - m.AdvanceChannelState() + m.AdvanceChannelState(ctx) s.checkAssignment(m, 1, "ch1", Watching) s.checkAssignment(m, 1, "ch2", Watching) }) @@ -456,13 +458,13 @@ func (s *ChannelManagerSuite) TestAdvanceChannelState() { s.checkAssignment(m, 1, "ch1", ToWatch) s.checkAssignment(m, 1, "ch2", ToWatch) - m.AdvanceChannelState() + m.AdvanceChannelState(ctx) s.checkAssignment(m, 1, "ch1", Watching) s.checkAssignment(m, 1, "ch2", Watching) s.mockCluster.EXPECT().CheckChannelOperationProgress(mock.Anything, mock.Anything, mock.Anything). Return(&datapb.ChannelOperationProgressResponse{State: datapb.ChannelWatchState_WatchSuccess}, nil).Twice() - m.AdvanceChannelState() + m.AdvanceChannelState(ctx) s.checkAssignment(m, 1, "ch1", Watched) s.checkAssignment(m, 1, "ch2", Watched) }) @@ -478,18 +480,18 @@ func (s *ChannelManagerSuite) TestAdvanceChannelState() { s.checkAssignment(m, 1, "ch1", ToWatch) s.checkAssignment(m, 1, "ch2", ToWatch) - m.AdvanceChannelState() + m.AdvanceChannelState(ctx) s.checkAssignment(m, 1, "ch1", Watching) s.checkAssignment(m, 1, "ch2", Watching) s.mockCluster.EXPECT().CheckChannelOperationProgress(mock.Anything, mock.Anything, mock.Anything). Return(&datapb.ChannelOperationProgressResponse{State: datapb.ChannelWatchState_WatchFailure}, nil).Twice() - m.AdvanceChannelState() + m.AdvanceChannelState(ctx) s.checkAssignment(m, 1, "ch1", Standby) s.checkAssignment(m, 1, "ch2", Standby) s.mockHandler.EXPECT().CheckShouldDropChannel(mock.Anything).Return(false) - m.AdvanceChannelState() + m.AdvanceChannelState(ctx) s.checkAssignment(m, 1, "ch1", ToWatch) s.checkAssignment(m, 1, "ch2", ToWatch) }) @@ -505,13 +507,13 @@ func (s *ChannelManagerSuite) TestAdvanceChannelState() { s.checkAssignment(m, 1, "ch1", ToRelease) s.checkAssignment(m, 1, "ch2", ToRelease) - m.AdvanceChannelState() + m.AdvanceChannelState(ctx) s.checkAssignment(m, 1, "ch1", Releasing) s.checkAssignment(m, 1, "ch2", Releasing) s.mockCluster.EXPECT().CheckChannelOperationProgress(mock.Anything, mock.Anything, mock.Anything). Return(&datapb.ChannelOperationProgressResponse{State: datapb.ChannelWatchState_ToRelease}, nil).Twice() - m.AdvanceChannelState() + m.AdvanceChannelState(ctx) s.checkAssignment(m, 1, "ch1", Releasing) s.checkAssignment(m, 1, "ch2", Releasing) }) @@ -527,18 +529,18 @@ func (s *ChannelManagerSuite) TestAdvanceChannelState() { s.checkAssignment(m, 1, "ch1", ToRelease) s.checkAssignment(m, 1, "ch2", ToRelease) - m.AdvanceChannelState() + m.AdvanceChannelState(ctx) s.checkAssignment(m, 1, "ch1", Releasing) s.checkAssignment(m, 1, "ch2", Releasing) s.mockCluster.EXPECT().CheckChannelOperationProgress(mock.Anything, mock.Anything, mock.Anything). Return(&datapb.ChannelOperationProgressResponse{State: datapb.ChannelWatchState_ReleaseSuccess}, nil).Twice() - m.AdvanceChannelState() + m.AdvanceChannelState(ctx) s.checkAssignment(m, 1, "ch1", Standby) s.checkAssignment(m, 1, "ch2", Standby) s.mockHandler.EXPECT().CheckShouldDropChannel(mock.Anything).Return(false) - m.AdvanceChannelState() + m.AdvanceChannelState(ctx) s.checkAssignment(m, 1, "ch1", ToWatch) s.checkAssignment(m, 1, "ch2", ToWatch) }) @@ -554,18 +556,18 @@ func (s *ChannelManagerSuite) TestAdvanceChannelState() { s.checkAssignment(m, 1, "ch1", ToRelease) s.checkAssignment(m, 1, "ch2", ToRelease) - m.AdvanceChannelState() + m.AdvanceChannelState(ctx) s.checkAssignment(m, 1, "ch1", Releasing) s.checkAssignment(m, 1, "ch2", Releasing) s.mockCluster.EXPECT().CheckChannelOperationProgress(mock.Anything, mock.Anything, mock.Anything). Return(&datapb.ChannelOperationProgressResponse{State: datapb.ChannelWatchState_ReleaseFailure}, nil).Twice() - m.AdvanceChannelState() + m.AdvanceChannelState(ctx) s.checkAssignment(m, 1, "ch1", Standby) s.checkAssignment(m, 1, "ch2", Standby) s.mockHandler.EXPECT().CheckShouldDropChannel(mock.Anything).Return(false) - m.AdvanceChannelState() + m.AdvanceChannelState(ctx) // TODO, donot assign to abnormal nodes s.checkAssignment(m, 1, "ch1", ToWatch) s.checkAssignment(m, 1, "ch2", ToWatch) @@ -583,7 +585,7 @@ func (s *ChannelManagerSuite) TestAdvanceChannelState() { s.checkAssignment(m, 1, "ch1", ToWatch) s.checkAssignment(m, 1, "ch2", ToWatch) - m.AdvanceChannelState() + m.AdvanceChannelState(ctx) s.checkAssignment(m, 1, "ch1", ToWatch) s.checkAssignment(m, 1, "ch2", ToWatch) }) @@ -599,7 +601,7 @@ func (s *ChannelManagerSuite) TestAdvanceChannelState() { s.checkAssignment(m, 1, "ch1", ToRelease) s.checkAssignment(m, 1, "ch2", ToRelease) - m.AdvanceChannelState() + m.AdvanceChannelState(ctx) s.checkAssignment(m, 1, "ch1", Releasing) s.checkAssignment(m, 1, "ch2", Releasing) }) @@ -616,7 +618,7 @@ func (s *ChannelManagerSuite) TestAdvanceChannelState() { s.checkAssignment(m, 1, "ch1", ToRelease) s.checkAssignment(m, 1, "ch2", ToRelease) - m.AdvanceChannelState() + m.AdvanceChannelState(ctx) s.checkAssignment(m, 1, "ch1", ToRelease) s.checkAssignment(m, 1, "ch2", ToRelease) })