diff --git a/internal/datacoord/channel_manager.go b/internal/datacoord/channel_manager.go
index 206e3592a7..bac3effc2e 100644
--- a/internal/datacoord/channel_manager.go
+++ b/internal/datacoord/channel_manager.go
@@ -197,18 +197,7 @@ func (m *ChannelManagerImpl) AddNode(nodeID UniqueID) error {
 	log.Info("register node", zap.Int64("registered node", nodeID))
 
 	m.store.AddNode(nodeID)
-	updates := m.assignPolicy(m.store.GetNodesChannels(), m.store.GetBufferChannelInfo(), m.legacyNodes.Collect())
-
-	if updates == nil {
-		log.Info("register node with no reassignment", zap.Int64("registered node", nodeID))
-		return nil
-	}
-
-	err := m.execute(updates)
-	if err != nil {
-		log.Warn("fail to update channel operation updates into meta", zap.Error(err))
-	}
-	return err
+	return nil
 }
 
 // Release writes ToRelease channel watch states for a channel
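This change reduces AddNode to pure registration: the node is recorded in the store, and the old inline reassignment (assign policy plus execute) is gone, so a newly registered node holds no channels until a later pass hands it some. Below is a minimal sketch of that register-now, assign-later contract, using hypothetical stand-in types (Store, checkLoopTick) rather than the actual ChannelManager types:

```go
package main

import "fmt"

// Store is a hypothetical stand-in for the channel store; the real
// ChannelManager keeps this state in m.store.
type Store struct {
	nodes    map[int64][]string // nodeID -> assigned channels
	buffered []string           // channels waiting to be assigned
}

// AddNode mirrors the new behavior: register only, assign nothing.
func (s *Store) AddNode(nodeID int64) {
	if _, ok := s.nodes[nodeID]; !ok {
		s.nodes[nodeID] = nil
	}
}

// checkLoopTick stands in for the later background pass that actually
// moves buffered channels onto the least-loaded registered node.
func (s *Store) checkLoopTick() {
	for _, ch := range s.buffered {
		bestID, bestLoad := int64(0), int(^uint(0)>>1)
		for id, chs := range s.nodes {
			if len(chs) < bestLoad {
				bestID, bestLoad = id, len(chs)
			}
		}
		s.nodes[bestID] = append(s.nodes[bestID], ch)
	}
	s.buffered = nil
}

func main() {
	s := &Store{nodes: map[int64][]string{}, buffered: []string{"ch1", "ch2", "ch3"}}
	s.AddNode(1)
	fmt.Println(len(s.nodes[1])) // 0: registration assigns nothing
	s.checkLoopTick()
	fmt.Println(len(s.nodes[1])) // 3: the later pass does the assignment
}
```

The updated tests below encode exactly this: right after AddNode, GetNode(nodeID) exists but its Channels set is empty.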
diff --git a/internal/datacoord/channel_manager_test.go b/internal/datacoord/channel_manager_test.go
index 566682daec..71cf1332e5 100644
--- a/internal/datacoord/channel_manager_test.go
+++ b/internal/datacoord/channel_manager_test.go
@@ -150,9 +150,10 @@ func (s *ChannelManagerSuite) TestAddNode() {
 		err = m.AddNode(testNodeID)
 		s.NoError(err)
-		lo.ForEach(testChannels, func(ch string, _ int) {
-			s.checkAssignment(m, testNodeID, ch, ToWatch)
-		})
+		info := m.store.GetNode(testNodeID)
+		s.NotNil(info)
+		s.Empty(info.Channels)
+		s.Equal(info.NodeID, testNodeID)
 	})
 
 	s.Run("AddNode with channels evenly in other node", func() {
 		var (
@@ -747,9 +748,10 @@ func (s *ChannelManagerSuite) TestStartup() {
 	err = m.AddNode(2)
 	s.NoError(err)
-	s.checkAssignment(m, 2, "ch1", ToWatch)
-	s.checkAssignment(m, 2, "ch2", ToWatch)
-	s.checkAssignment(m, 2, "ch3", ToWatch)
+	info := m.store.GetNode(2)
+	s.NotNil(info)
+	s.Empty(info.Channels)
+	s.Equal(info.NodeID, int64(2))
 }
 
 func (s *ChannelManagerSuite) TestStartupNilSchema() {
@@ -807,23 +809,6 @@ func (s *ChannelManagerSuite) TestStartupNilSchema() {
 		s.NotNil(channel.GetWatchInfo().Schema)
 		log.Info("Recovered non-nil schema channel", zap.Any("channel", channel))
 	}
-
-	err = m.AddNode(7)
-	s.Require().NoError(err)
-	s.checkAssignment(m, 7, "ch1", ToWatch)
-	s.checkAssignment(m, 7, "ch2", ToWatch)
-	s.checkAssignment(m, 7, "ch3", ToWatch)
-
-	for ch := range chNodes {
-		channel, got := m.GetChannel(7, ch)
-		s.Require().True(got)
-		s.NotNil(channel.GetSchema())
-		s.Equal(ch, channel.GetName())
-
-		s.NotNil(channel.GetWatchInfo())
-		s.NotNil(channel.GetWatchInfo().Schema)
-		log.Info("non-nil schema channel", zap.Any("channel", channel))
-	}
 }
 
 func (s *ChannelManagerSuite) TestStartupRootCoordFailed() {
@@ -842,9 +827,6 @@ func (s *ChannelManagerSuite) TestStartupRootCoordFailed() {
 	err = m.Startup(context.TODO(), nil, []int64{2})
 	s.Error(err)
-
-	err = m.Startup(context.TODO(), nil, []int64{1, 2})
-	s.Error(err)
 }
 
 func (s *ChannelManagerSuite) TestCheckLoop() {}
diff --git a/internal/datacoord/policy.go b/internal/datacoord/policy.go
index c458097b86..b29b9299c4 100644
--- a/internal/datacoord/policy.go
+++ b/internal/datacoord/policy.go
@@ -149,130 +149,211 @@ func EmptyAssignPolicy(currentCluster Assignments, toAssign *NodeChannelInfo, execlusiveNodes []int64) *ChannelOpSet {
 	return nil
 }
 
-func AvgAssignByCountPolicy(currentCluster Assignments, toAssign *NodeChannelInfo, execlusiveNodes []int64) *ChannelOpSet {
+// AvgAssignByCountPolicy balances channel distribution across nodes based on count
+func AvgAssignByCountPolicy(currentCluster Assignments, toAssign *NodeChannelInfo, exclusiveNodes []int64) *ChannelOpSet {
 	var (
-		toCluster   Assignments
-		fromCluster Assignments
-		channelNum  int = 0
+		availableNodes    Assignments // Nodes that can receive channels
+		sourceNodes       Assignments // Nodes that can provide channels
+		totalChannelCount int         // Total number of channels in the cluster
 	)
-	nodeToAvg := typeutil.NewUniqueSet()
-
-	lo.ForEach(currentCluster, func(info *NodeChannelInfo, _ int) {
-		// Get fromCluster
-		if toAssign == nil && len(info.Channels) > 0 {
-			fromCluster = append(fromCluster, info)
-			channelNum += len(info.Channels)
-			nodeToAvg.Insert(info.NodeID)
+
+	// Create a set to track unique node IDs for average calculation
+	uniqueNodeIDs := typeutil.NewUniqueSet()
+
+	// Iterate through each node in the current cluster
+	lo.ForEach(currentCluster, func(nodeInfo *NodeChannelInfo, _ int) {
+		// If we're balancing existing channels (not assigning new ones) and this node has channels
+		if toAssign == nil && len(nodeInfo.Channels) > 0 {
+			sourceNodes = append(sourceNodes, nodeInfo) // Add to source nodes
+			totalChannelCount += len(nodeInfo.Channels) // Count its channels
+			uniqueNodeIDs.Insert(nodeInfo.NodeID)       // Track this node for average calculation
 			return
 		}
 
-		// Get toCluster by filtering out execlusive nodes
-		if lo.Contains(execlusiveNodes, info.NodeID) || (toAssign != nil && info.NodeID == toAssign.NodeID) {
+		// Skip nodes that are in the exclusive list or the node we're reassigning from
+		if lo.Contains(exclusiveNodes, nodeInfo.NodeID) || (toAssign != nil && nodeInfo.NodeID == toAssign.NodeID) {
 			return
 		}
 
-		toCluster = append(toCluster, info)
-		channelNum += len(info.Channels)
-		nodeToAvg.Insert(info.NodeID)
+		// This node can receive channels
+		availableNodes = append(availableNodes, nodeInfo) // Add to target nodes
+		totalChannelCount += len(nodeInfo.Channels)       // Count its channels
+		uniqueNodeIDs.Insert(nodeInfo.NodeID)             // Track this node for average calculation
 	})
 
-	// If no datanode alive, do nothing
-	if len(toCluster) == 0 {
+	// If no nodes are available to receive channels, do nothing
+	if len(availableNodes) == 0 {
+		log.Info("No available nodes to receive channels")
 		return nil
 	}
 
-	// 1. assign unassigned channels first
+	// CASE 1: Assign unassigned channels to nodes
 	if toAssign != nil && len(toAssign.Channels) > 0 {
-		chPerNode := (len(toAssign.Channels) + channelNum) / nodeToAvg.Len()
+		return assignNewChannels(availableNodes, toAssign, uniqueNodeIDs.Len(), totalChannelCount, exclusiveNodes)
+	}
 
-		// sort by assigned channels count ascsending
-		sort.Slice(toCluster, func(i, j int) bool {
-			return len(toCluster[i].Channels) <= len(toCluster[j].Channels)
+	// Check if auto-balancing is enabled
+	if !Params.DataCoordCfg.AutoBalance.GetAsBool() {
+		log.Info("Auto balance disabled")
+		return nil
+	}
+
+	// CASE 2: Balance existing channels across nodes
+	if len(sourceNodes) == 0 {
+		log.Info("No source nodes to rebalance from")
+		return nil
+	}
+
+	return balanceExistingChannels(currentCluster, sourceNodes, uniqueNodeIDs.Len(), totalChannelCount, exclusiveNodes)
+}
+
+// assignNewChannels handles assigning new channels to available nodes
+func assignNewChannels(availableNodes Assignments, toAssign *NodeChannelInfo, nodeCount int, totalChannelCount int, exclusiveNodes []int64) *ChannelOpSet {
+	// Calculate total channels after assignment
+	totalChannelsAfterAssignment := totalChannelCount + len(toAssign.Channels)
+
+	// Calculate ideal distribution (channels per node)
+	baseChannelsPerNode := totalChannelsAfterAssignment / nodeCount
+	extraChannels := totalChannelsAfterAssignment % nodeCount
+
+	// Create a map to track target channel count for each node
+	targetChannelCounts := make(map[int64]int)
+	for _, nodeInfo := range availableNodes {
+		targetChannelCounts[nodeInfo.NodeID] = baseChannelsPerNode
+		if extraChannels > 0 {
+			targetChannelCounts[nodeInfo.NodeID]++ // Distribute remainder one by one
+			extraChannels--
+		}
+	}
+
+	// Track which channels will be assigned to which nodes
+	nodeAssignments := make(map[int64][]RWChannel)
+
+	// Create a working copy of available nodes that we can sort
+	sortedNodes := make([]*NodeChannelInfo, len(availableNodes))
+	copy(sortedNodes, availableNodes)
+
+	// Assign channels to nodes, prioritizing nodes with fewer channels
+	for _, channel := range toAssign.GetChannels() {
+		// Sort nodes by their current load (existing + newly assigned channels)
+		sort.Slice(sortedNodes, func(i, j int) bool {
+			// Compare total channels (existing + newly assigned)
+			iTotal := len(sortedNodes[i].Channels) + len(nodeAssignments[sortedNodes[i].NodeID])
+			jTotal := len(sortedNodes[j].Channels) + len(nodeAssignments[sortedNodes[j].NodeID])
+			return iTotal < jTotal
 		})
 
-		nodesLackOfChannels := Assignments(lo.Filter(toCluster, func(info *NodeChannelInfo, _ int) bool {
-			return len(info.Channels) < chPerNode
-		}))
+		// Find the best node to assign to (the one with fewest channels)
+		bestNode := sortedNodes[0]
 
-		if len(nodesLackOfChannels) == 0 {
-			nodesLackOfChannels = toCluster
+		// Try to find a node that's below its target count
+		for _, node := range sortedNodes {
+			currentTotal := len(node.Channels) + len(nodeAssignments[node.NodeID])
+			if currentTotal < targetChannelCounts[node.NodeID] {
+				bestNode = node
+				break
+			}
 		}
 
-		updates := make(map[int64][]RWChannel)
-		for i, newChannel := range toAssign.GetChannels() {
-			n := nodesLackOfChannels[i%len(nodesLackOfChannels)].NodeID
-			updates[n] = append(updates[n], newChannel)
-		}
-
-		opSet := NewChannelOpSet()
-		for id, chs := range updates {
-			opSet.Append(id, Watch, chs...)
-			opSet.Delete(toAssign.NodeID, chs...)
-		}
-
-		log.Info("Assign channels to nodes by channel count",
-			zap.Int("toAssign channel count", len(toAssign.Channels)),
-			zap.Any("original nodeID", toAssign.NodeID),
-			zap.Int64s("exclusive nodes", execlusiveNodes),
-			zap.Any("operations", opSet),
-			zap.Int64s("nodesLackOfChannels", lo.Map(nodesLackOfChannels, func(info *NodeChannelInfo, _ int) int64 {
-				return info.NodeID
-			})),
-		)
-		return opSet
+		// Assign the channel to the selected node
+		nodeAssignments[bestNode.NodeID] = append(nodeAssignments[bestNode.NodeID], channel)
 	}
 
-	if !Params.DataCoordCfg.AutoBalance.GetAsBool() {
-		log.Info("auto balance disabled")
+	// Create operations to watch channels on new nodes and delete from original node
+	operations := NewChannelOpSet()
+	for nodeID, channels := range nodeAssignments {
+		operations.Append(nodeID, Watch, channels...)   // New node watches channels
+		operations.Delete(toAssign.NodeID, channels...) // Remove channels from original node
+	}
+
+	// Log the assignment operations
+	log.Info("Assign channels to nodes by channel count",
+		zap.Int("toAssign channel count", len(toAssign.Channels)),
+		zap.Any("original nodeID", toAssign.NodeID),
+		zap.Int64s("exclusive nodes", exclusiveNodes),
+		zap.Any("operations", operations),
+		zap.Any("target distribution", targetChannelCounts),
+	)
+
+	return operations
+}
+
+// balanceExistingChannels handles rebalancing existing channels across nodes
+func balanceExistingChannels(currentCluster Assignments, sourceNodes Assignments, nodeCount int, totalChannelCount int, exclusiveNodes []int64) *ChannelOpSet {
+	// Calculate ideal distribution
+	baseChannelsPerNode := totalChannelCount / nodeCount
+	extraChannels := totalChannelCount % nodeCount
+
+	// If there are too few channels to distribute, do nothing
+	if baseChannelsPerNode == 0 {
+		log.Info("Too few channels to distribute meaningfully")
 		return nil
 	}
 
-	// 2. balance fromCluster to toCluster if no unassignedChannels
-	if len(fromCluster) == 0 {
-		return nil
-	}
-	chPerNode := channelNum / nodeToAvg.Len()
-	if chPerNode == 0 {
-		return nil
-	}
-
-	// sort in descending order and reallocate
-	sort.Slice(fromCluster, func(i, j int) bool {
-		return len(fromCluster[i].Channels) > len(fromCluster[j].Channels)
-	})
-
-	releases := make(map[int64][]RWChannel)
-	for _, info := range fromCluster {
-		if len(info.Channels) > chPerNode {
-			cnt := 0
-			for _, ch := range info.Channels {
-				cnt++
-				if cnt > chPerNode {
-					releases[info.NodeID] = append(releases[info.NodeID], ch)
-				}
+	// Create a map to track target channel count for each node
+	targetChannelCounts := make(map[int64]int)
+	for _, nodeInfo := range currentCluster {
+		if !lo.Contains(exclusiveNodes, nodeInfo.NodeID) {
+			targetChannelCounts[nodeInfo.NodeID] = baseChannelsPerNode
+			if extraChannels > 0 {
+				targetChannelCounts[nodeInfo.NodeID]++ // Distribute remainder one by one
+				extraChannels--
			}
 		}
 	}
 
-	// Channels in `releases` are reassigned eventually by channel manager.
-	opSet := NewChannelOpSet()
-	for k, v := range releases {
-		if lo.Contains(execlusiveNodes, k) {
-			opSet.Append(k, Delete, v...)
-			opSet.Append(bufferID, Watch, v...)
-		} else {
-			opSet.Append(k, Release, v...)
+	// Sort nodes by channel count (descending) to take from nodes with most channels
+	sort.Slice(sourceNodes, func(i, j int) bool {
+		return len(sourceNodes[i].Channels) > len(sourceNodes[j].Channels)
+	})
+
+	// Track which channels will be released from which nodes
+	channelsToRelease := make(map[int64][]RWChannel)
+
+	// First handle exclusive nodes - we need to remove all channels from them
+	for _, nodeInfo := range sourceNodes {
+		if lo.Contains(exclusiveNodes, nodeInfo.NodeID) {
+			channelsToRelease[nodeInfo.NodeID] = lo.Values(nodeInfo.Channels)
+			continue
+		}
+
+		// For regular nodes, only release if they have more than their target
+		targetCount := targetChannelCounts[nodeInfo.NodeID]
+		currentCount := len(nodeInfo.Channels)
+
+		if currentCount > targetCount {
+			// Calculate how many channels to release
+			excessCount := currentCount - targetCount
+
+			// Get the channels to release (we'll take the last ones)
+			channels := lo.Values(nodeInfo.Channels)
+			channelsToRelease[nodeInfo.NodeID] = channels[len(channels)-excessCount:]
 		}
 	}
 
-	log.Info("Assign channels to nodes by channel count",
-		zap.Int64s("exclusive nodes", execlusiveNodes),
-		zap.Int("channel count", channelNum),
-		zap.Int("channel per node", chPerNode),
-		zap.Any("operations", opSet),
-		zap.Array("fromCluster", fromCluster),
-		zap.Array("toCluster", toCluster),
+	// Create operations to release channels from overloaded nodes
+	operations := NewChannelOpSet()
+	for nodeID, channels := range channelsToRelease {
+		if len(channels) == 0 {
+			continue
+		}
+
+		if lo.Contains(exclusiveNodes, nodeID) {
+			operations.Append(nodeID, Delete, channels...)  // Delete channels from exclusive nodes
+			operations.Append(bufferID, Watch, channels...) // Move to buffer temporarily
+		} else {
+			operations.Append(nodeID, Release, channels...) // Release channels from regular nodes
+		}
+	}
+
+	// Log the balancing operations
+	log.Info("Balance channels across nodes",
+		zap.Int64s("exclusive nodes", exclusiveNodes),
+		zap.Int("total channel count", totalChannelCount),
+		zap.Int("target channels per node", baseChannelsPerNode),
+		zap.Any("target distribution", targetChannelCounts),
+		zap.Any("operations", operations),
 	)
-	return opSet
+
+	return operations
 }
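The rewritten policy splits the old monolithic function into assignNewChannels and balanceExistingChannels, both built on the same arithmetic: base = total / nodeCount and extra = total % nodeCount, with the remainder handed out one node at a time; each incoming channel then goes to the least-loaded node still below its target. A standalone sketch of that computation under simplified types (plain maps instead of NodeChannelInfo/RWChannel):

```go
package main

import (
	"fmt"
	"sort"
)

// assign mirrors assignNewChannels: compute per-node targets, then greedily
// place each new channel on the least-loaded node below its target.
func assign(existing map[int64]int, newChannels []string) map[int64][]string {
	total := len(newChannels)
	for _, n := range existing {
		total += n
	}
	nodeIDs := make([]int64, 0, len(existing))
	for id := range existing {
		nodeIDs = append(nodeIDs, id)
	}
	sort.Slice(nodeIDs, func(i, j int) bool { return nodeIDs[i] < nodeIDs[j] })

	// base/extra: e.g. 7 channels over 3 nodes yields targets 3, 2, 2.
	base, extra := total/len(nodeIDs), total%len(nodeIDs)
	target := make(map[int64]int, len(nodeIDs))
	for _, id := range nodeIDs {
		target[id] = base
		if extra > 0 {
			target[id]++
			extra--
		}
	}

	out := make(map[int64][]string)
	for _, ch := range newChannels {
		// Re-sort by current load (existing + newly assigned).
		sort.Slice(nodeIDs, func(i, j int) bool {
			li := existing[nodeIDs[i]] + len(out[nodeIDs[i]])
			lj := existing[nodeIDs[j]] + len(out[nodeIDs[j]])
			return li < lj
		})
		best := nodeIDs[0]
		for _, id := range nodeIDs {
			if existing[id]+len(out[id]) < target[id] {
				best = id
				break
			}
		}
		out[best] = append(out[best], ch)
	}
	return out
}

func main() {
	got := assign(map[int64]int{1: 4, 2: 0, 3: 0}, []string{"a", "b", "c"})
	fmt.Println(got) // new channels land on the empty nodes, e.g. map[2:[a c] 3:[b]]
}
```

With existing loads {1: 4, 2: 0, 3: 0} and three new channels, total is 7, so targets are 3/2/2 and all new channels land on the empty nodes; the loaded node receives nothing, which is the behavior the reworked tests assert.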
diff --git a/internal/datacoord/policy_test.go b/internal/datacoord/policy_test.go
index 87b8777284..225b53ddd3 100644
--- a/internal/datacoord/policy_test.go
+++ b/internal/datacoord/policy_test.go
@@ -139,7 +139,10 @@ func (s *AssignByCountPolicySuite) TestWithoutUnassignedChannels() {
 		opSet := AvgAssignByCountPolicy(s.curCluster, nil, execlusiveNodes)
 		s.NotNil(opSet)
 
-		s.Equal(2, opSet.GetChannelNumber())
+		for _, op := range opSet.Collect() {
+			s.T().Logf("opType=%s, opNodeID=%d, numOpChannel=%d", ChannelOpTypeNames[op.Type], op.NodeID, len(op.Channels))
+		}
+		s.Equal(6, opSet.GetChannelNumber())
 		for _, op := range opSet.Collect() {
 			if op.NodeID == bufferID {
 				s.Equal(Watch, op.Type)
@@ -253,16 +256,17 @@
 		s.Equal(67, opSet.GetChannelNumber())
 
 		for _, op := range opSet.Collect() {
+			s.T().Logf("opType=%s, opNodeID=%d, numOpChannel=%d", ChannelOpTypeNames[op.Type], op.NodeID, len(op.Channels))
 			if op.NodeID == bufferID {
 				s.Equal(Delete, op.Type)
 			}
 		}
-		s.Equal(4, opSet.Len())
+		s.Equal(6, opSet.Len())
 
 		nodeIDs := lo.FilterMap(opSet.Collect(), func(op *ChannelOp, _ int) (int64, bool) {
 			return op.NodeID, op.NodeID != bufferID
 		})
-		s.ElementsMatch([]int64{3, 2}, nodeIDs)
+		s.ElementsMatch([]int64{3, 2, 1}, nodeIDs)
 	})
 
 	s.Run("toAssign from nodeID = 1", func() {
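The updated expectations here (6 channels moved, ops touching nodes 3, 2, and 1 rather than just 3 and 2) are consistent with the release side of the new policy: every regular node keeps at most its per-node target and releases the tail of its channel list, while exclusive nodes are drained completely. A simplified sketch of that computation (plain maps for the store types, plus a deterministic iteration order that the real code does not impose):

```go
package main

import (
	"fmt"
	"sort"
)

// channelsToRelease mirrors the release computation in balanceExistingChannels:
// regular nodes keep their target count and give up the tail; exclusive nodes
// are drained entirely.
func channelsToRelease(current map[int64][]string, exclusive map[int64]bool) map[int64][]string {
	ids := make([]int64, 0, len(current))
	total := 0
	for id, chs := range current {
		ids = append(ids, id)
		total += len(chs)
	}
	sort.Slice(ids, func(i, j int) bool { return ids[i] < ids[j] })

	regular := 0
	for _, id := range ids {
		if !exclusive[id] {
			regular++
		}
	}
	base, extra := total/regular, total%regular

	release := make(map[int64][]string)
	for _, id := range ids {
		chs := current[id]
		if exclusive[id] {
			release[id] = chs // drain exclusive nodes entirely
			continue
		}
		target := base
		if extra > 0 {
			target++
			extra--
		}
		if excess := len(chs) - target; excess > 0 {
			release[id] = chs[len(chs)-excess:] // release the tail
		}
	}
	return release
}

func main() {
	current := map[int64][]string{1: {"a", "b", "c", "d"}, 2: {}, 3: {}}
	fmt.Println(channelsToRelease(current, map[int64]bool{}))
	// map[1:[c d]] -- node 1 keeps its target of 2 and releases the rest
}
```

Note this sketch simplifies the node-count bookkeeping; the real policy derives nodeCount from the uniqueNodeIDs set built in AvgAssignByCountPolicy.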
diff --git a/pkg/mq/msgdispatcher/client.go b/pkg/mq/msgdispatcher/client.go
index 47977c0236..5d2d6df58e 100644
--- a/pkg/mq/msgdispatcher/client.go
+++ b/pkg/mq/msgdispatcher/client.go
@@ -115,10 +115,6 @@ func (c *client) Deregister(vchannel string) {
 
 	if manager, ok := c.managers.Get(pchannel); ok {
 		manager.Remove(vchannel)
-		if manager.NumTarget() == 0 && manager.NumConsumer() == 0 {
-			manager.Close()
-			c.managers.Remove(pchannel)
-		}
 		log.Info("deregister done", zap.String("role", c.role), zap.Int64("nodeID", c.nodeID),
 			zap.String("vchannel", vchannel), zap.Duration("dur", time.Since(start)))
 	}
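Deregister now only detaches the vchannel from its manager; the empty-manager Close/Remove fast path is gone, so an idle DispatcherManager stays alive, with its lifecycle presumably left to its own Run loop and Close. A minimal sketch of the changed ownership, using hypothetical simplified types:

```go
package main

import "fmt"

type manager struct{ targets map[string]bool }

func (m *manager) Remove(vchannel string) { delete(m.targets, vchannel) }

type client struct{ managers map[string]*manager }

// Deregister after the change: detach only, never Close or evict the manager.
func (c *client) Deregister(pchannel, vchannel string) {
	if m, ok := c.managers[pchannel]; ok {
		m.Remove(vchannel)
	}
}

func main() {
	c := &client{managers: map[string]*manager{
		"p1": {targets: map[string]bool{"p1_v0": true}},
	}}
	c.Deregister("p1", "p1_v0")
	_, alive := c.managers["p1"]
	fmt.Println(alive, len(c.managers["p1"].targets)) // true 0: manager kept for reuse
}
```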
diff --git a/pkg/mq/msgdispatcher/client_test.go b/pkg/mq/msgdispatcher/client_test.go
index 2ec9bff3b0..8b1ad693ff 100644
--- a/pkg/mq/msgdispatcher/client_test.go
+++ b/pkg/mq/msgdispatcher/client_test.go
@@ -110,7 +110,7 @@ func TestClient_Concurrency(t *testing.T) {
 	// Verify registered targets number.
 	actual := 0
 	c.managers.Range(func(pchannel string, manager DispatcherManager) bool {
-		actual += manager.NumTarget()
+		actual += manager.(*dispatcherManager).registeredTargets.Len()
 		return true
 	})
 	assert.Equal(t, expected, actual)
@@ -120,14 +120,7 @@ func TestClient_Concurrency(t *testing.T) {
 	actual = 0
 	c.managers.Range(func(pchannel string, manager DispatcherManager) bool {
 		m := manager.(*dispatcherManager)
-		m.mu.RLock()
-		defer m.mu.RUnlock()
-		if m.mainDispatcher != nil {
-			actual += m.mainDispatcher.targets.Len()
-		}
-		for _, d := range m.deputyDispatchers {
-			actual += d.targets.Len()
-		}
+		actual += int(m.numActiveTarget.Load())
 		return true
 	})
 	t.Logf("expect = %d, actual = %d\n", expected, actual)
@@ -263,9 +256,9 @@ func (suite *SimulationSuite) TestMerge() {
 	suite.Eventually(func() bool {
 		for pchannel := range suite.pchannel2Producer {
 			manager, ok := suite.client.(*client).managers.Get(pchannel)
-			suite.T().Logf("dispatcherNum = %d, pchannel = %s\n", manager.NumConsumer(), pchannel)
+			suite.T().Logf("dispatcherNum = %d, pchannel = %s\n", manager.(*dispatcherManager).numConsumer.Load(), pchannel)
 			suite.True(ok)
-			if manager.NumConsumer() != 1 { // expected all merged, only mainDispatcher exist
+			if manager.(*dispatcherManager).numConsumer.Load() != 1 { // expected all merged, only mainDispatcher exist
 				return false
 			}
 		}
@@ -330,9 +323,9 @@ func (suite *SimulationSuite) TestSplit() {
 	suite.Eventually(func() bool {
 		for pchannel := range suite.pchannel2Producer {
 			manager, ok := suite.client.(*client).managers.Get(pchannel)
-			suite.T().Logf("verifing dispatchers merged, dispatcherNum = %d, pchannel = %s\n", manager.NumConsumer(), pchannel)
+			suite.T().Logf("verifing dispatchers merged, dispatcherNum = %d, pchannel = %s\n", manager.(*dispatcherManager).numConsumer.Load(), pchannel)
 			suite.True(ok)
-			if manager.NumConsumer() != 1 { // expected all merged, only mainDispatcher exist
+			if manager.(*dispatcherManager).numConsumer.Load() != 1 { // expected all merged, only mainDispatcher exist
 				return false
 			}
 		}
@@ -378,8 +371,8 @@ func (suite *SimulationSuite) TestSplit() {
 			manager, ok := suite.client.(*client).managers.Get(pchannel)
 			suite.True(ok)
 			suite.T().Logf("verifing split, dispatcherNum = %d, splitNum+1 = %d, pchannel = %s\n",
-				manager.NumConsumer(), splitNumPerPchannel+1, pchannel)
-			if manager.NumConsumer() < 1 { // expected 1 mainDispatcher and 1 or more split deputyDispatchers
+				manager.(*dispatcherManager).numConsumer.Load(), splitNumPerPchannel+1, pchannel)
+			if manager.(*dispatcherManager).numConsumer.Load() < 1 { // expected 1 mainDispatcher and 1 or more split deputyDispatchers
 				return false
 			}
 		}
@@ -400,9 +393,9 @@ func (suite *SimulationSuite) TestSplit() {
 	suite.Eventually(func() bool {
 		for pchannel := range suite.pchannel2Producer {
 			manager, ok := suite.client.(*client).managers.Get(pchannel)
-			suite.T().Logf("verifing dispatchers merged again, dispatcherNum = %d, pchannel = %s\n", manager.NumConsumer(), pchannel)
+			suite.T().Logf("verifing dispatchers merged again, dispatcherNum = %d, pchannel = %s\n", manager.(*dispatcherManager).numConsumer.Load(), pchannel)
 			suite.True(ok)
-			if manager.NumConsumer() != 1 { // expected all merged, only mainDispatcher exist
+			if manager.(*dispatcherManager).numConsumer.Load() != 1 { // expected all merged, only mainDispatcher exist
 				return false
 			}
 		}
diff --git a/pkg/mq/msgdispatcher/manager.go b/pkg/mq/msgdispatcher/manager.go
index d9db84fc06..10dacde853 100644
--- a/pkg/mq/msgdispatcher/manager.go
+++ b/pkg/mq/msgdispatcher/manager.go
@@ -40,8 +40,6 @@ import (
 type DispatcherManager interface {
 	Add(ctx context.Context, streamConfig *StreamConfig) (<-chan *MsgPack, error)
 	Remove(vchannel string)
-	NumTarget() int
-	NumConsumer() int
 	Run()
 	Close()
 }
@@ -55,7 +53,9 @@ type dispatcherManager struct {
 
 	registeredTargets *typeutil.ConcurrentMap[string, *target]
 
-	mu sync.RWMutex
+	numConsumer     atomic.Int64
+	numActiveTarget atomic.Int64
+
 	mainDispatcher    *Dispatcher
 	deputyDispatchers map[int64]*Dispatcher // ID -> *Dispatcher
@@ -99,22 +99,6 @@ func (c *dispatcherManager) Remove(vchannel string) {
 	t.close()
 }
 
-func (c *dispatcherManager) NumTarget() int {
-	return c.registeredTargets.Len()
-}
-
-func (c *dispatcherManager) NumConsumer() int {
-	c.mu.RLock()
-	defer c.mu.RUnlock()
-
-	numConsumer := 0
-	if c.mainDispatcher != nil {
-		numConsumer++
-	}
-	numConsumer += len(c.deputyDispatchers)
-	return numConsumer
-}
-
 func (c *dispatcherManager) Close() {
 	c.closeOnce.Do(func() {
 		c.closeChan <- struct{}{}
@@ -139,14 +123,30 @@ func (c *dispatcherManager) Run() {
 			c.tryRemoveUnregisteredTargets()
 			c.tryBuildDispatcher()
 			c.tryMerge()
+			c.updateNumInfo()
 		}
 	}
 }
 
+func (c *dispatcherManager) updateNumInfo() {
+	numConsumer := 0
+	numActiveTarget := 0
+	if c.mainDispatcher != nil {
+		numConsumer++
+		numActiveTarget += c.mainDispatcher.TargetNum()
+	}
+	numConsumer += len(c.deputyDispatchers)
+	c.numConsumer.Store(int64(numConsumer))
+
+	for _, d := range c.deputyDispatchers {
+		numActiveTarget += d.TargetNum()
+	}
+	c.numActiveTarget.Store(int64(numActiveTarget))
+}
+
 func (c *dispatcherManager) tryRemoveUnregisteredTargets() {
 	log := log.With(zap.String("role", c.role), zap.Int64("nodeID", c.nodeID), zap.String("pchannel", c.pchannel))
 	unregisteredTargets := make([]*target, 0)
-	c.mu.RLock()
 	for _, dispatcher := range c.deputyDispatchers {
 		for _, t := range dispatcher.GetTargets() {
 			if !c.registeredTargets.Contain(t.vchannel) {
@@ -161,10 +161,6 @@ func (c *dispatcherManager) tryRemoveUnregisteredTargets() {
 			}
 		}
 	}
-	c.mu.RUnlock()
-
-	c.mu.Lock()
-	defer c.mu.Unlock()
 	for _, dispatcher := range c.deputyDispatchers {
 		for _, t := range unregisteredTargets {
 			if dispatcher.HasTarget(t.vchannel) {
@@ -206,7 +202,6 @@ func (c *dispatcherManager) tryBuildDispatcher() {
 	// get lack targets to perform subscription
 	lackTargets := make([]*target, 0, len(allTargets))
-	c.mu.RLock()
 OUTER:
 	for _, t := range allTargets {
 		if c.mainDispatcher != nil && c.mainDispatcher.HasTarget(t.vchannel) {
@@ -219,7 +214,6 @@ OUTER:
 		}
 		lackTargets = append(lackTargets, t)
 	}
-	c.mu.RUnlock()
 
 	if len(lackTargets) == 0 {
 		return
@@ -287,8 +281,6 @@ OUTER:
 		zap.Strings("vchannels", vchannels),
 	)
 
-	c.mu.Lock()
-	defer c.mu.Unlock()
 	if c.mainDispatcher == nil {
 		c.mainDispatcher = d
 		log.Info("add main dispatcher", zap.Int64("id", d.ID()))
@@ -299,9 +291,6 @@ OUTER:
 }
 
 func (c *dispatcherManager) tryMerge() {
-	c.mu.Lock()
-	defer c.mu.Unlock()
-
 	start := time.Now()
 	log := log.With(zap.String("role", c.role), zap.Int64("nodeID", c.nodeID), zap.String("pchannel", c.pchannel))
@@ -363,9 +352,6 @@ func (c *dispatcherManager) deleteMetric(channel string) {
 }
 
 func (c *dispatcherManager) uploadMetric() {
-	c.mu.RLock()
-	defer c.mu.RUnlock()
-
 	nodeIDStr := fmt.Sprintf("%d", c.nodeID)
 	fn := func(gauge *prometheus.GaugeVec) {
 		if c.mainDispatcher == nil {
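manager.go drops the RWMutex entirely: dispatcher state is now owned by the single Run goroutine, and updateNumInfo publishes derived counts (numConsumer, numActiveTarget) through atomics at the end of every loop iteration, so readers such as tests and metrics only ever Load. A self-contained sketch of this single-writer, atomic-snapshot pattern with stand-in types:

```go
package main

import (
	"fmt"
	"sync/atomic"
	"time"
)

type manager struct {
	dispatchers map[int64][]string // owned exclusively by run(): no lock needed
	numConsumer atomic.Int64
	numTargets  atomic.Int64
	closeCh     chan struct{}
}

func (m *manager) run() {
	ticker := time.NewTicker(10 * time.Millisecond)
	defer ticker.Stop()
	for {
		select {
		case <-m.closeCh:
			return
		case <-ticker.C:
			// ... mutate m.dispatchers freely here: single writer ...
			targets := 0
			for _, ts := range m.dispatchers {
				targets += len(ts)
			}
			// Publish a consistent snapshot for concurrent readers.
			m.numConsumer.Store(int64(len(m.dispatchers)))
			m.numTargets.Store(int64(targets))
		}
	}
}

func main() {
	m := &manager{
		dispatchers: map[int64][]string{1: {"v1", "v2"}},
		closeCh:     make(chan struct{}),
	}
	go m.run()
	time.Sleep(50 * time.Millisecond)
	fmt.Println(m.numConsumer.Load(), m.numTargets.Load()) // 1 2
	close(m.closeCh)
}
```

The published counts are eventually consistent, refreshed once per loop tick, which is why the test changes in this PR poll them with assert.Eventually instead of asserting immediately.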
diff --git a/pkg/mq/msgdispatcher/manager_test.go b/pkg/mq/msgdispatcher/manager_test.go
index c5e660b90c..7541e92f9e 100644
--- a/pkg/mq/msgdispatcher/manager_test.go
+++ b/pkg/mq/msgdispatcher/manager_test.go
@@ -50,8 +50,8 @@ func TestManager(t *testing.T) {
 		assert.NotNil(t, c)
 		go c.Run()
 		defer c.Close()
-		assert.Equal(t, 0, c.NumConsumer())
-		assert.Equal(t, 0, c.NumTarget())
+		assert.Equal(t, int64(0), c.(*dispatcherManager).numConsumer.Load())
+		assert.Equal(t, 0, c.(*dispatcherManager).registeredTargets.Len())
 
 		var offset int
 		for i := 0; i < 30; i++ {
@@ -64,8 +64,8 @@ func TestManager(t *testing.T) {
 			assert.NoError(t, err)
 		}
 		assert.Eventually(t, func() bool {
-			t.Logf("offset=%d, numConsumer=%d, numTarget=%d", offset, c.NumConsumer(), c.NumTarget())
-			return c.NumTarget() == offset
+			t.Logf("offset=%d, numConsumer=%d, numTarget=%d", offset, c.(*dispatcherManager).numConsumer.Load(), c.(*dispatcherManager).registeredTargets.Len())
+			return c.(*dispatcherManager).registeredTargets.Len() == offset
 		}, 3*time.Second, 10*time.Millisecond)
 		for j := 0; j < rand.Intn(r); j++ {
 			vchannel := fmt.Sprintf("%s_vchannelv%d", pchannel, offset)
@@ -74,8 +74,8 @@ func TestManager(t *testing.T) {
 			offset--
 		}
 		assert.Eventually(t, func() bool {
-			t.Logf("offset=%d, numConsumer=%d, numTarget=%d", offset, c.NumConsumer(), c.NumTarget())
-			return c.NumTarget() == offset
+			t.Logf("offset=%d, numConsumer=%d, numTarget=%d", offset, c.(*dispatcherManager).numConsumer.Load(), c.(*dispatcherManager).registeredTargets.Len())
+			return c.(*dispatcherManager).registeredTargets.Len() == offset
 		}, 3*time.Second, 10*time.Millisecond)
 	}
 })
@@ -108,7 +108,7 @@ func TestManager(t *testing.T) {
 		assert.NoError(t, err)
 		o2, err := c.Add(ctx, NewStreamConfig(fmt.Sprintf("%s_vchannel-2", pchannel), nil, common.SubscriptionPositionUnknown))
 		assert.NoError(t, err)
-		assert.Equal(t, 3, c.NumTarget())
+		assert.Equal(t, 3, c.(*dispatcherManager).registeredTargets.Len())
 
 		consumeFn := func(output <-chan *MsgPack, done <-chan struct{}, wg *sync.WaitGroup) {
 			defer wg.Done()
@@ -130,14 +130,14 @@ func TestManager(t *testing.T) {
 		go consumeFn(o1, d1, wg)
 		go consumeFn(o2, d2, wg)
 
 		assert.Eventually(t, func() bool {
-			return c.NumConsumer() == 1 // expected merge
+			return c.(*dispatcherManager).numConsumer.Load() == 1 // expected merge
 		}, 20*time.Second, 10*time.Millisecond)
 
 		// stop consume vchannel_2 to trigger split
 		d2 <- struct{}{}
 		assert.Eventually(t, func() bool {
-			t.Logf("c.NumConsumer=%d", c.NumConsumer())
-			return c.NumConsumer() == 2 // expected split
+			t.Logf("c.NumConsumer=%d", c.(*dispatcherManager).numConsumer.Load())
+			return c.(*dispatcherManager).numConsumer.Load() == 2 // expected split
 		}, 20*time.Second, 10*time.Millisecond)
 
 		// stop all
@@ -169,9 +169,9 @@ func TestManager(t *testing.T) {
 		assert.NoError(t, err)
 		_, err = c.Add(ctx, NewStreamConfig(fmt.Sprintf("%s_vchannel-2", pchannel), nil, common.SubscriptionPositionUnknown))
 		assert.NoError(t, err)
-		assert.Equal(t, 3, c.NumTarget())
+		assert.Equal(t, 3, c.(*dispatcherManager).registeredTargets.Len())
 		assert.Eventually(t, func() bool {
-			return c.NumConsumer() >= 1
+			return c.(*dispatcherManager).numConsumer.Load() >= 1
 		}, 3*time.Second, 10*time.Millisecond)
 		c.(*dispatcherManager).mainDispatcher.curTs.Store(1000)
 		for _, d := range c.(*dispatcherManager).deputyDispatchers {
@@ -183,9 +183,9 @@ func TestManager(t *testing.T) {
 		defer paramtable.Get().Reset(checkIntervalK)
 
 		assert.Eventually(t, func() bool {
-			return c.NumConsumer() == 1 // expected merged
+			return c.(*dispatcherManager).numConsumer.Load() == 1 // expected merged
 		}, 3*time.Second, 10*time.Millisecond)
-		assert.Equal(t, 3, c.NumTarget())
+		assert.Equal(t, 3, c.(*dispatcherManager).registeredTargets.Len())
 	})
 
 	t.Run("test_repeated_vchannel", func(t *testing.T) {
@@ -220,7 +220,7 @@ func TestManager(t *testing.T) {
 		assert.Error(t, err)
 
 		assert.Eventually(t, func() bool {
-			return c.NumConsumer() >= 1
+			return c.(*dispatcherManager).numConsumer.Load() >= 1
 		}, 3*time.Second, 10*time.Millisecond)
 	})
 }
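Because the counters above are refreshed asynchronously by the Run loop, the rewritten assertions poll with assert.Eventually rather than reading a synchronized getter. A minimal stand-alone version of that polling pattern (eventually here is a hand-rolled stand-in for testify's assert.Eventually):

```go
package main

import (
	"fmt"
	"sync/atomic"
	"time"
)

// eventually polls cond until it returns true or the timeout expires,
// mirroring how the tests wait for the Run loop to publish fresh counts.
func eventually(cond func() bool, timeout, tick time.Duration) bool {
	deadline := time.Now().Add(timeout)
	for time.Now().Before(deadline) {
		if cond() {
			return true
		}
		time.Sleep(tick)
	}
	return false
}

func main() {
	var numConsumer atomic.Int64
	go func() {
		time.Sleep(20 * time.Millisecond) // simulate a Run-loop refresh
		numConsumer.Store(1)
	}()
	ok := eventually(func() bool { return numConsumer.Load() == 1 },
		3*time.Second, 10*time.Millisecond)
	fmt.Println(ok) // true
}
```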