diff --git a/pkg/mq/msgdispatcher/client.go b/pkg/mq/msgdispatcher/client.go index 79dcef0732..daa279928b 100644 --- a/pkg/mq/msgdispatcher/client.go +++ b/pkg/mq/msgdispatcher/client.go @@ -98,13 +98,13 @@ func (c *client) Register(ctx context.Context, streamConfig *StreamConfig) (<-ch } // Check if the consumer number limit has been reached. limit := paramtable.Get().MQCfg.MaxDispatcherNumPerPchannel.GetAsInt() - if manager.Num() >= limit { + if manager.NumConsumer() >= limit { return nil, merr.WrapErrTooManyConsumers(vchannel, fmt.Sprintf("limit=%d", limit)) } // Begin to register ch, err := manager.Add(ctx, streamConfig) if err != nil { - if manager.Num() == 0 { + if manager.NumTarget() == 0 { manager.Close() c.managers.Remove(pchannel) } @@ -122,7 +122,7 @@ func (c *client) Deregister(vchannel string) { defer c.managerMut.Unlock(pchannel) if manager, ok := c.managers.Get(pchannel); ok { manager.Remove(vchannel) - if manager.Num() == 0 { + if manager.NumTarget() == 0 { manager.Close() c.managers.Remove(pchannel) } diff --git a/pkg/mq/msgdispatcher/manager.go b/pkg/mq/msgdispatcher/manager.go index d046953b64..2a3c60824c 100644 --- a/pkg/mq/msgdispatcher/manager.go +++ b/pkg/mq/msgdispatcher/manager.go @@ -38,7 +38,8 @@ import ( type DispatcherManager interface { Add(ctx context.Context, streamConfig *StreamConfig) (<-chan *MsgPack, error) Remove(vchannel string) - Num() int + NumTarget() int + NumConsumer() int Run() Close() } @@ -145,7 +146,17 @@ func (c *dispatcherManager) Remove(vchannel string) { c.lagTargets.GetAndRemove(vchannel) } -func (c *dispatcherManager) Num() int { +func (c *dispatcherManager) NumTarget() int { + c.mu.RLock() + defer c.mu.RUnlock() + var res int + if c.mainDispatcher != nil { + res += c.mainDispatcher.TargetNum() + } + return res + len(c.soloDispatchers) + c.lagTargets.Len() +} + +func (c *dispatcherManager) NumConsumer() int { c.mu.RLock() defer c.mu.RUnlock() var res int diff --git a/pkg/mq/msgdispatcher/manager_test.go b/pkg/mq/msgdispatcher/manager_test.go index b02ba95621..4edcfaf963 100644 --- a/pkg/mq/msgdispatcher/manager_test.go +++ b/pkg/mq/msgdispatcher/manager_test.go @@ -39,7 +39,8 @@ func TestManager(t *testing.T) { t.Run("test add and remove dispatcher", func(t *testing.T) { c := NewDispatcherManager("mock_pchannel_0", typeutil.ProxyRole, 1, newMockFactory()) assert.NotNil(t, c) - assert.Equal(t, 0, c.Num()) + assert.Equal(t, 0, c.NumConsumer()) + assert.Equal(t, 0, c.NumTarget()) var offset int for i := 0; i < 100; i++ { @@ -50,14 +51,16 @@ func TestManager(t *testing.T) { t.Logf("add vchannel, %s", vchannel) _, err := c.Add(context.Background(), NewStreamConfig(vchannel, nil, common.SubscriptionPositionUnknown)) assert.NoError(t, err) - assert.Equal(t, offset, c.Num()) + assert.Equal(t, offset, c.NumConsumer()) + assert.Equal(t, offset, c.NumTarget()) } for j := 0; j < rand.Intn(r); j++ { vchannel := fmt.Sprintf("mock-pchannel-dml_0_vchannelv%d", offset) t.Logf("remove vchannel, %s", vchannel) c.Remove(vchannel) offset-- - assert.Equal(t, offset, c.Num()) + assert.Equal(t, offset, c.NumConsumer()) + assert.Equal(t, offset, c.NumTarget()) } } }) @@ -73,7 +76,8 @@ func TestManager(t *testing.T) { assert.NoError(t, err) _, err = c.Add(ctx, NewStreamConfig("mock_vchannel_2", nil, common.SubscriptionPositionUnknown)) assert.NoError(t, err) - assert.Equal(t, 3, c.Num()) + assert.Equal(t, 3, c.NumConsumer()) + assert.Equal(t, 3, c.NumTarget()) c.(*dispatcherManager).mainDispatcher.curTs.Store(1000) c.(*dispatcherManager).mu.RLock() for _, d := range c.(*dispatcherManager).soloDispatchers { @@ -82,7 +86,8 @@ func TestManager(t *testing.T) { c.(*dispatcherManager).mu.RUnlock() c.(*dispatcherManager).tryMerge() - assert.Equal(t, 1, c.Num()) + assert.Equal(t, 1, c.NumConsumer()) + assert.Equal(t, 3, c.NumTarget()) info := &target{ vchannel: "mock_vchannel_2", @@ -90,7 +95,7 @@ func TestManager(t *testing.T) { ch: nil, } c.(*dispatcherManager).split(info) - assert.Equal(t, 2, c.Num()) + assert.Equal(t, 2, c.NumConsumer()) }) t.Run("test run and close", func(t *testing.T) { @@ -104,7 +109,8 @@ func TestManager(t *testing.T) { assert.NoError(t, err) _, err = c.Add(ctx, NewStreamConfig("mock_vchannel_2", nil, common.SubscriptionPositionUnknown)) assert.NoError(t, err) - assert.Equal(t, 3, c.Num()) + assert.Equal(t, 3, c.NumConsumer()) + assert.Equal(t, 3, c.NumTarget()) c.(*dispatcherManager).mainDispatcher.curTs.Store(1000) c.(*dispatcherManager).mu.RLock() for _, d := range c.(*dispatcherManager).soloDispatchers { @@ -117,8 +123,9 @@ func TestManager(t *testing.T) { defer paramtable.Get().Reset(checkIntervalK) go c.Run() assert.Eventually(t, func() bool { - return c.Num() == 1 // expected merged + return c.NumConsumer() == 1 // expected merged }, 3*time.Second, 10*time.Millisecond) + assert.Equal(t, 3, c.NumTarget()) assert.NotPanics(t, func() { c.Close() @@ -140,7 +147,8 @@ func TestManager(t *testing.T) { assert.Error(t, err) _, err = c.Add(ctx, NewStreamConfig("mock_vchannel_2", nil, common.SubscriptionPositionUnknown)) assert.Error(t, err) - assert.Equal(t, 0, c.Num()) + assert.Equal(t, 0, c.NumConsumer()) + assert.Equal(t, 0, c.NumTarget()) assert.NotPanics(t, func() { c.Close() @@ -374,9 +382,10 @@ func (suite *SimulationSuite) TestMerge() { } suite.Eventually(func() bool { - suite.T().Logf("dispatcherManager.dispatcherNum = %d", suite.manager.Num()) - return suite.manager.Num() == 1 // expected all merged, only mainDispatcher exist + suite.T().Logf("dispatcherManager.dispatcherNum = %d", suite.manager.NumConsumer()) + return suite.manager.NumConsumer() == 1 // expected all merged, only mainDispatcher exist }, 15*time.Second, 100*time.Millisecond) + assert.Equal(suite.T(), vchannelNum, suite.manager.NumTarget()) cancel() wg.Wait() @@ -409,9 +418,10 @@ func (suite *SimulationSuite) TestSplit() { } suite.Eventually(func() bool { - suite.T().Logf("dispatcherManager.dispatcherNum = %d, splitNum+1 = %d", suite.manager.Num(), splitNum+1) - return suite.manager.Num() == splitNum+1 // expected 1 mainDispatcher and `splitNum` soloDispatchers + suite.T().Logf("dispatcherManager.dispatcherNum = %d, splitNum+1 = %d", suite.manager.NumConsumer(), splitNum+1) + return suite.manager.NumConsumer() == splitNum+1 // expected 1 mainDispatcher and `splitNum` soloDispatchers }, 10*time.Second, 100*time.Millisecond) + assert.Equal(suite.T(), vchannelNum, suite.manager.NumTarget()) cancel() }