milvus/pkg/mq/msgdispatcher/client_test.go
yihao.dai 004a1875dc
enhance: Introduce batch subscription in msgdispatcher (#39863)
Introduce a batch subscription mechanism in msgdispatcher: the
msgdispatcher now maintains a vchannel watch task queue, and the
vchannels in that queue share a single MQ subscription, pulling
messages from the oldest vchannel checkpoint up to the latest.

issue: https://github.com/milvus-io/milvus/issues/39862

---------

Signed-off-by: bigsheeper <yihao.dai@zilliz.com>
2025-03-05 14:38:02 +08:00
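
A minimal usage sketch of the client API exercised by the tests below. The helper name clientUsageSketch and the channel names are illustrative only; newMockFactory, newMockProducer, and produceTimeTick are mock helpers defined in this package's test code.

func clientUsageSketch(t *testing.T) {
	// Build a client on top of a (mock) msgstream factory.
	factory := newMockFactory()
	client := NewClient(factory, typeutil.ProxyRole, 1)
	defer client.Close()

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	pchannel := "by-dev-rootcoord-dml_0" // illustrative pchannel name
	producer, err := newMockProducer(factory, pchannel)
	assert.NoError(t, err)
	go produceTimeTick(t, ctx, producer)

	// Two vchannels of the same pchannel: with batch subscription they are
	// served by a single underlying MQ consumer inside the dispatcher.
	out1, err := client.Register(ctx, NewStreamConfig(pchannel+"_100v0", nil, common.SubscriptionPositionEarliest))
	assert.NoError(t, err)
	out2, err := client.Register(ctx, NewStreamConfig(pchannel+"_101v0", nil, common.SubscriptionPositionEarliest))
	assert.NoError(t, err)

	// Downstream pipelines would consume the dispatched MsgPacks from out1/out2.
	_, _ = out1, out2

	client.Deregister(pchannel + "_100v0")
	client.Deregister(pchannel + "_101v0")
}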

// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package msgdispatcher

import (
	"context"
	"fmt"
	"math/rand"
	"sync"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/suite"
	"go.uber.org/atomic"

	"github.com/milvus-io/milvus/pkg/v2/mq/common"
	"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
	"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
	"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
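
// TestClient covers the basic lifecycle of the dispatcher client: register two
// vchannels of one pchannel, then deregister them and close the client.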
func TestClient(t *testing.T) {
	factory := newMockFactory()
	client := NewClient(factory, typeutil.ProxyRole, 1)
	assert.NotNil(t, client)
	defer client.Close()

	pchannel := fmt.Sprintf("by-dev-rootcoord-dml_%d", rand.Int63())
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	producer, err := newMockProducer(factory, pchannel)
	assert.NoError(t, err)
	go produceTimeTick(t, ctx, producer)

	_, err = client.Register(ctx, NewStreamConfig(fmt.Sprintf("%s_v1", pchannel), nil, common.SubscriptionPositionUnknown))
	assert.NoError(t, err)
	_, err = client.Register(ctx, NewStreamConfig(fmt.Sprintf("%s_v2", pchannel), nil, common.SubscriptionPositionUnknown))
	assert.NoError(t, err)
	client.Deregister(fmt.Sprintf("%s_v1", pchannel))
	client.Deregister(fmt.Sprintf("%s_v2", pchannel))
}
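
// TestClient_Concurrency registers (and randomly deregisters) many vchannels
// concurrently and verifies that the registered and active target counts match
// the expectation.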
func TestClient_Concurrency(t *testing.T) {
	factory := newMockFactory()
	client1 := NewClient(factory, typeutil.ProxyRole, 1)
	assert.NotNil(t, client1)
	defer client1.Close()

	paramtable.Get().Save(paramtable.Get().MQCfg.TargetBufSize.Key, "65536")
	defer paramtable.Get().Reset(paramtable.Get().MQCfg.TargetBufSize.Key)

	const (
		vchannelNumPerPchannel = 10
		pchannelNum            = 16
	)

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	pchannels := make([]string, pchannelNum)
	for i := 0; i < pchannelNum; i++ {
		pchannel := fmt.Sprintf("by-dev-rootcoord-dml-%d_%d", rand.Int63(), i)
		pchannels[i] = pchannel
		producer, err := newMockProducer(factory, pchannel)
		assert.NoError(t, err)
		go produceTimeTick(t, ctx, producer)
		t.Logf("start producing time tick to pchannel %s", pchannel)
	}

	wg := &sync.WaitGroup{}
	deregisterCount := atomic.NewInt32(0)
	for i := 0; i < vchannelNumPerPchannel; i++ {
		for j := 0; j < pchannelNum; j++ {
			vchannel := fmt.Sprintf("%s_%dv%d", pchannels[j], rand.Int(), i)
			wg.Add(1)
			go func() {
				defer wg.Done()
				_, err := client1.Register(ctx, NewStreamConfig(vchannel, nil, common.SubscriptionPositionUnknown))
				assert.NoError(t, err)
				for k := 0; k < rand.Intn(2); k++ {
					client1.Deregister(vchannel)
					deregisterCount.Inc()
				}
			}()
		}
	}
	wg.Wait()

	c := client1.(*client)
	expected := int(vchannelNumPerPchannel*pchannelNum - deregisterCount.Load())

	// Verify registered targets number.
	actual := 0
	c.managers.Range(func(pchannel string, manager DispatcherManager) bool {
		actual += manager.NumTarget()
		return true
	})
	assert.Equal(t, expected, actual)

	// Verify active targets number.
	assert.Eventually(t, func() bool {
		actual = 0
		c.managers.Range(func(pchannel string, manager DispatcherManager) bool {
			m := manager.(*dispatcherManager)
			m.mu.RLock()
			defer m.mu.RUnlock()
			if m.mainDispatcher != nil {
				actual += m.mainDispatcher.targets.Len()
			}
			for _, d := range m.deputyDispatchers {
				actual += d.targets.Len()
			}
			return true
		})
		t.Logf("expected = %d, actual = %d", expected, actual)
		return expected == actual
	}, 15*time.Second, 100*time.Millisecond)
}
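
// SimulationSuite wires mock producers and consumers over multiple pchannels
// and vchannels to exercise dispatching, merging, and splitting end to end.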
type SimulationSuite struct {
	suite.Suite

	ctx    context.Context
	cancel context.CancelFunc
	wg     *sync.WaitGroup

	client  Client
	factory msgstream.Factory

	pchannel2Producer  map[string]msgstream.MsgStream
	pchannel2Vchannels map[string]map[string]*vchannelHelper
}

func (suite *SimulationSuite) SetupSuite() {
	suite.factory = newMockFactory()
}

func (suite *SimulationSuite) SetupTest() {
	const (
		pchannelNum            = 16
		vchannelNumPerPchannel = 10
	)

	suite.ctx, suite.cancel = context.WithTimeout(context.Background(), time.Minute*3)
	suite.wg = &sync.WaitGroup{}
	suite.client = NewClient(suite.factory, "test-client", 1)

	// Init pchannel and producers.
	suite.pchannel2Producer = make(map[string]msgstream.MsgStream)
	suite.pchannel2Vchannels = make(map[string]map[string]*vchannelHelper)
	for i := 0; i < pchannelNum; i++ {
		pchannel := fmt.Sprintf("by-dev-rootcoord-dispatcher-dml-%d_%d", time.Now().UnixNano(), i)
		producer, err := newMockProducer(suite.factory, pchannel)
		suite.NoError(err)
		suite.pchannel2Producer[pchannel] = producer
		suite.pchannel2Vchannels[pchannel] = make(map[string]*vchannelHelper)
	}

	// Init vchannels.
	for pchannel := range suite.pchannel2Producer {
		for i := 0; i < vchannelNumPerPchannel; i++ {
			collectionID := time.Now().UnixNano()
			vchannel := fmt.Sprintf("%s_%dv0", pchannel, collectionID)
			suite.pchannel2Vchannels[pchannel][vchannel] = &vchannelHelper{}
		}
	}
}
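
// TestDispatchToVchannels checks that every message produced to a pchannel is
// dispatched to the matching vchannel target without loss.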
func (suite *SimulationSuite) TestDispatchToVchannels() {
	// Register vchannels.
	for _, vchannels := range suite.pchannel2Vchannels {
		for vchannel, helper := range vchannels {
			output, err := suite.client.Register(suite.ctx, NewStreamConfig(vchannel, nil, common.SubscriptionPositionEarliest))
			suite.NoError(err)
			helper.output = output
		}
	}

	// Produce and dispatch messages to vchannel targets.
	produceCtx, produceCancel := context.WithTimeout(suite.ctx, time.Second*3)
	defer produceCancel()
	for pchannel, vchannels := range suite.pchannel2Vchannels {
		suite.wg.Add(1)
		go produceMsgs(suite.T(), produceCtx, suite.wg, suite.pchannel2Producer[pchannel], vchannels)
	}

	// Mock pipelines consume messages.
	consumeCtx, consumeCancel := context.WithTimeout(suite.ctx, 10*time.Second)
	defer consumeCancel()
	for _, vchannels := range suite.pchannel2Vchannels {
		for vchannel, helper := range vchannels {
			suite.wg.Add(1)
			go consumeMsgsFromTargets(suite.T(), consumeCtx, suite.wg, vchannel, helper)
		}
	}
	suite.wg.Wait()

	// Verify pub-sub message counts.
	for _, vchannels := range suite.pchannel2Vchannels {
		for vchannel, helper := range vchannels {
			suite.Equal(helper.pubInsMsgNum.Load(), helper.subInsMsgNum.Load(), vchannel)
			suite.Equal(helper.pubDelMsgNum.Load(), helper.subDelMsgNum.Load(), vchannel)
			suite.Equal(helper.pubDDLMsgNum.Load(), helper.subDDLMsgNum.Load(), vchannel)
			suite.Equal(helper.pubPackNum.Load(), helper.subPackNum.Load(), vchannel)
		}
	}
}
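
// TestMerge registers vchannels at random seek positions and expects their
// dispatchers to merge into a single main dispatcher with no message loss.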
func (suite *SimulationSuite) TestMerge() {
	// Produce msgs.
	produceCtx, produceCancel := context.WithCancel(suite.ctx)
	for pchannel, producer := range suite.pchannel2Producer {
		suite.wg.Add(1)
		go produceMsgs(suite.T(), produceCtx, suite.wg, producer, suite.pchannel2Vchannels[pchannel])
	}

	// Get random msg positions to seek for each vchannel.
	for pchannel, vchannels := range suite.pchannel2Vchannels {
		getRandomSeekPositions(suite.T(), suite.ctx, suite.factory, pchannel, vchannels)
	}

	// Register vchannels.
	for _, vchannels := range suite.pchannel2Vchannels {
		for vchannel, helper := range vchannels {
			pos := helper.seekPos
			assert.NotNil(suite.T(), pos)
			suite.T().Logf("seekTs = %d, vchannel = %s, msgID = %v", pos.GetTimestamp(), vchannel, pos.GetMsgID())
			output, err := suite.client.Register(suite.ctx, NewStreamConfig(
				vchannel, pos,
				common.SubscriptionPositionUnknown,
			))
			suite.NoError(err)
			helper.output = output
		}
	}

	// Mock pipelines consume messages.
	for _, vchannels := range suite.pchannel2Vchannels {
		for vchannel, helper := range vchannels {
			suite.wg.Add(1)
			go consumeMsgsFromTargets(suite.T(), suite.ctx, suite.wg, vchannel, helper)
		}
	}

	// Verify dispatchers merged.
	suite.Eventually(func() bool {
		for pchannel := range suite.pchannel2Producer {
			manager, ok := suite.client.(*client).managers.Get(pchannel)
			suite.True(ok)
			suite.T().Logf("dispatcherNum = %d, pchannel = %s", manager.NumConsumer(), pchannel)
			if manager.NumConsumer() != 1 { // expected all merged, only mainDispatcher exists
				return false
			}
		}
		return true
	}, 15*time.Second, 100*time.Millisecond)

	// Stop producing and verify pub-sub message counts.
	produceCancel()
	suite.Eventually(func() bool {
		for _, vchannels := range suite.pchannel2Vchannels {
			for vchannel, helper := range vchannels {
				logFn := func(pubNum, skipNum, subNum int32, name string) {
					suite.T().Logf("pub%sNum[%d]-skipped%sNum[%d] = %d, sub%sNum = %d, vchannel = %s",
						name, pubNum, name, skipNum, pubNum-skipNum, name, subNum, vchannel)
				}
				if helper.pubInsMsgNum.Load()-helper.skippedInsMsgNum != helper.subInsMsgNum.Load() {
					logFn(helper.pubInsMsgNum.Load(), helper.skippedInsMsgNum, helper.subInsMsgNum.Load(), "InsMsg")
					return false
				}
				if helper.pubDelMsgNum.Load()-helper.skippedDelMsgNum != helper.subDelMsgNum.Load() {
					logFn(helper.pubDelMsgNum.Load(), helper.skippedDelMsgNum, helper.subDelMsgNum.Load(), "DelMsg")
					return false
				}
				if helper.pubDDLMsgNum.Load()-helper.skippedDDLMsgNum != helper.subDDLMsgNum.Load() {
					logFn(helper.pubDDLMsgNum.Load(), helper.skippedDDLMsgNum, helper.subDDLMsgNum.Load(), "DDLMsg")
					return false
				}
				if helper.pubPackNum.Load()-helper.skippedPackNum != helper.subPackNum.Load() {
					logFn(helper.pubPackNum.Load(), helper.skippedPackNum, helper.subPackNum.Load(), "Pack")
					return false
				}
			}
		}
		return true
	}, 15*time.Second, 100*time.Millisecond)
}
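
// TestSplit inflates target lag to force deputy dispatchers to be split out,
// then lets consumers drain the targets and expects the dispatchers to merge
// back, again with no message loss.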
func (suite *SimulationSuite) TestSplit() {
	// Modify the parameters to make triggering split easier.
	paramtable.Get().Save(paramtable.Get().MQCfg.MaxTolerantLag.Key, "0.5")
	defer paramtable.Get().Reset(paramtable.Get().MQCfg.MaxTolerantLag.Key)
	paramtable.Get().Save(paramtable.Get().MQCfg.TargetBufSize.Key, "512")
	defer paramtable.Get().Reset(paramtable.Get().MQCfg.TargetBufSize.Key)

	// Produce msgs.
	produceCtx, produceCancel := context.WithCancel(suite.ctx)
	for pchannel, producer := range suite.pchannel2Producer {
		suite.wg.Add(1)
		go produceMsgs(suite.T(), produceCtx, suite.wg, producer, suite.pchannel2Vchannels[pchannel])
	}

	// Register vchannels.
	for _, vchannels := range suite.pchannel2Vchannels {
		for vchannel, helper := range vchannels {
			output, err := suite.client.Register(suite.ctx, NewStreamConfig(vchannel, nil, common.SubscriptionPositionEarliest))
			suite.NoError(err)
			helper.output = output
		}
	}

	// Verify dispatchers merged.
	suite.Eventually(func() bool {
		for pchannel := range suite.pchannel2Producer {
			manager, ok := suite.client.(*client).managers.Get(pchannel)
			suite.True(ok)
			suite.T().Logf("verifying dispatchers merged, dispatcherNum = %d, pchannel = %s", manager.NumConsumer(), pchannel)
			if manager.NumConsumer() != 1 { // expected all merged, only mainDispatcher exists
				return false
			}
		}
		return true
	}, 15*time.Second, 100*time.Millisecond)

	getTargetChan := func(pchannel, vchannel string) chan *MsgPack {
		manager, ok := suite.client.(*client).managers.Get(pchannel)
		suite.True(ok)
		t, ok := manager.(*dispatcherManager).registeredTargets.Get(vchannel)
		suite.True(ok)
		return t.ch
	}

	// Inject additional messages into targets to trigger lag and split.
	injectCtx, injectCancel := context.WithCancel(context.Background())
	const splitNumPerPchannel = 3
	for pchannel, vchannels := range suite.pchannel2Vchannels {
		cnt := 0
		for vchannel := range vchannels {
			suite.wg.Add(1)
			targetCh := getTargetChan(pchannel, vchannel)
			go func() {
				defer suite.wg.Done()
				for {
					select {
					case targetCh <- &MsgPack{}:
					case <-injectCtx.Done():
						return
					}
				}
			}()
			cnt++
			if cnt == splitNumPerPchannel {
				break
			}
		}
	}

	// Verify split.
	suite.Eventually(func() bool {
		for pchannel := range suite.pchannel2Producer {
			manager, ok := suite.client.(*client).managers.Get(pchannel)
			suite.True(ok)
			suite.T().Logf("verifying split, dispatcherNum = %d, splitNum+1 = %d, pchannel = %s",
				manager.NumConsumer(), splitNumPerPchannel+1, pchannel)
			if manager.NumConsumer() < 2 { // expected 1 mainDispatcher and 1 or more split deputyDispatchers
				return false
			}
		}
		return true
	}, 20*time.Second, 100*time.Millisecond)
	injectCancel()

	// Mock pipelines consume messages to trigger merging again.
	for _, vchannels := range suite.pchannel2Vchannels {
		for vchannel, helper := range vchannels {
			suite.wg.Add(1)
			go consumeMsgsFromTargets(suite.T(), suite.ctx, suite.wg, vchannel, helper)
		}
	}

	// Verify dispatchers merged again.
	suite.Eventually(func() bool {
		for pchannel := range suite.pchannel2Producer {
			manager, ok := suite.client.(*client).managers.Get(pchannel)
			suite.True(ok)
			suite.T().Logf("verifying dispatchers merged again, dispatcherNum = %d, pchannel = %s", manager.NumConsumer(), pchannel)
			if manager.NumConsumer() != 1 { // expected all merged, only mainDispatcher exists
				return false
			}
		}
		return true
	}, 15*time.Second, 100*time.Millisecond)

	// Stop producing and verify pub-sub message counts.
	produceCancel()
	suite.Eventually(func() bool {
		for _, vchannels := range suite.pchannel2Vchannels {
			for vchannel, helper := range vchannels {
				if helper.pubInsMsgNum.Load() != helper.subInsMsgNum.Load() {
					suite.T().Logf("pubInsMsgNum = %d, subInsMsgNum = %d, vchannel = %s",
						helper.pubInsMsgNum.Load(), helper.subInsMsgNum.Load(), vchannel)
					return false
				}
				if helper.pubDelMsgNum.Load() != helper.subDelMsgNum.Load() {
					suite.T().Logf("pubDelMsgNum = %d, subDelMsgNum = %d, vchannel = %s",
						helper.pubDelMsgNum.Load(), helper.subDelMsgNum.Load(), vchannel)
					return false
				}
				if helper.pubDDLMsgNum.Load() != helper.subDDLMsgNum.Load() {
					suite.T().Logf("pubDDLMsgNum = %d, subDDLMsgNum = %d, vchannel = %s",
						helper.pubDDLMsgNum.Load(), helper.subDDLMsgNum.Load(), vchannel)
					return false
				}
				if helper.pubPackNum.Load() != helper.subPackNum.Load() {
					suite.T().Logf("pubPackNum = %d, subPackNum = %d, vchannel = %s",
						helper.pubPackNum.Load(), helper.subPackNum.Load(), vchannel)
					return false
				}
			}
		}
		return true
	}, 15*time.Second, 100*time.Millisecond)
}

func (suite *SimulationSuite) TearDownTest() {
	for _, vchannels := range suite.pchannel2Vchannels {
		for vchannel := range vchannels {
			suite.client.Deregister(vchannel)
		}
	}
	suite.client.Close()
	suite.cancel()
	suite.wg.Wait()
}

func (suite *SimulationSuite) TearDownSuite() {
}

func TestSimulation(t *testing.T) {
	suite.Run(t, new(SimulationSuite))
}
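
// TestClientMainDispatcherLeak ensures that registering vchannels again after
// all of them were deregistered does not panic, i.e. the main dispatcher is
// not leaked or left in a broken state.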
func TestClientMainDispatcherLeak(t *testing.T) {
	client := NewClient(newMockFactory(), typeutil.ProxyRole, 1)
	assert.NotNil(t, client)

	pchannel := "mock_vchannel_0"
	vchannel1 := fmt.Sprintf("%s_abc_v0", pchannel) // "mock_vchannel_0_abc_v0"
	vchannel2 := fmt.Sprintf("%s_abc_v1", pchannel) // "mock_vchannel_0_abc_v1"

	_, err := client.Register(context.Background(), NewStreamConfig(vchannel1, nil, common.SubscriptionPositionUnknown))
	assert.NoError(t, err)
	_, err = client.Register(context.Background(), NewStreamConfig(vchannel2, nil, common.SubscriptionPositionUnknown))
	assert.NoError(t, err)

	client.Deregister(vchannel2)
	client.Deregister(vchannel1)
	assert.NotPanics(
		t, func() {
			_, err = client.Register(context.Background(), NewStreamConfig(vchannel1, nil, common.SubscriptionPositionUnknown))
			assert.NoError(t, err)
			_, err = client.Register(context.Background(), NewStreamConfig(vchannel2, nil, common.SubscriptionPositionUnknown))
			assert.NoError(t, err)
		},
	)

	client.Deregister(vchannel1)
	client.Deregister(vchannel2)
	assert.NotPanics(
		t, func() {
			_, err = client.Register(context.Background(), NewStreamConfig(vchannel1, nil, common.SubscriptionPositionUnknown))
			assert.NoError(t, err)
			_, err = client.Register(context.Background(), NewStreamConfig(vchannel2, nil, common.SubscriptionPositionUnknown))
			assert.NoError(t, err)
		},
	)
}