enhance: [2.5] Introduce batch subscription in msgdispatcher (#40596)

Introduce a batch subscription mechanism in msgdispatcher: the
msgdispatcher now maintains a vchannel watch task queue, and all
vchannels in the queue share a single MQ subscription, pulling
messages from the oldest vchannel checkpoint up to the latest one
(a minimal sketch follows the links below).

issue: https://github.com/milvus-io/milvus/issues/39862

pr: https://github.com/milvus-io/milvus/pull/39863
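
For illustration only, a minimal Go sketch of the batching idea follows; it is not the actual msgdispatcher code, and names such as pendingTarget, buildBatch, and dispatch are hypothetical. One consumer per pchannel seeks to the oldest queued checkpoint, and packs older than a given vchannel's own checkpoint are skipped during pull-back.

package main

import (
	"fmt"
	"sort"
)

// pendingTarget is a queued vchannel watch task (hypothetical type).
type pendingTarget struct {
	vchannel     string
	checkpointTs uint64 // seek position (checkpoint) of this vchannel
}

// msgPack is a simplified message pack carrying only an end timestamp.
type msgPack struct {
	endTs uint64
	data  string
}

// buildBatch sorts the queued targets by checkpoint and returns the single
// seek timestamp (oldest checkpoint) and the pull-back end timestamp
// (newest checkpoint) for one shared subscription.
func buildBatch(targets []pendingTarget) (seekTs, pullbackEndTs uint64) {
	sort.Slice(targets, func(i, j int) bool {
		return targets[i].checkpointTs < targets[j].checkpointTs
	})
	return targets[0].checkpointTs, targets[len(targets)-1].checkpointTs
}

// dispatch delivers a pack only to targets whose checkpoint is older than the
// pack, i.e. messages already covered by a vchannel's checkpoint are skipped.
func dispatch(targets []pendingTarget, pack msgPack) {
	for _, t := range targets {
		if pack.endTs <= t.checkpointTs {
			continue // already consumed before this vchannel's checkpoint
		}
		fmt.Printf("deliver %s (endTs=%d) to %s\n", pack.data, pack.endTs, t.vchannel)
	}
}

func main() {
	queue := []pendingTarget{
		{vchannel: "dml_0_100v0", checkpointTs: 30},
		{vchannel: "dml_0_101v0", checkpointTs: 10},
		{vchannel: "dml_0_102v0", checkpointTs: 20},
	}
	seekTs, endTs := buildBatch(queue)
	fmt.Printf("single consumer seeks to ts=%d, pull-back ends at ts=%d\n", seekTs, endTs)
	for ts := seekTs + 5; ts <= endTs+10; ts += 10 {
		dispatch(queue, msgPack{endTs: ts, data: fmt.Sprintf("pack@%d", ts)})
	}
}

In the real change, this corresponds roughly to sorting lack targets by position in tryBuildDispatcher and to the skip-msg filtering in Dispatcher.work() shown in the diff below.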

---------

Signed-off-by: bigsheeper <yihao.dai@zilliz.com>
yihao.dai 2025-03-24 10:18:17 +08:00 committed by GitHub
parent d703d8dac8
commit b534c9d804
24 changed files with 1228 additions and 815 deletions


@ -176,9 +176,6 @@ mq:
mergeCheckInterval: 1 # the interval time (in seconds) for dispatcher to check whether to merge
targetBufSize: 16 # the length of the channel buffer for the target
maxTolerantLag: 3 # Default value: "3", the timeout (in seconds) for the target to send a msgPack
maxDispatcherNumPerPchannel: 5 # The maximum number of dispatchers per physical channel, primarily to limit the number of consumers and prevent performance issues(e.g., during recovery when a large number of channels are watched).
retrySleep: 3 # register retry sleep time in seconds
retryTimeout: 60 # register retry timeout in seconds
# Related configuration of pulsar, used to manage Milvus logs of recent mutation operations, output streaming log, and provide log publish-subscribe services.
pulsar:


@ -19,6 +19,7 @@ package datacoord
import (
"context"
"fmt"
"sort"
"sync"
"time"
@ -545,7 +546,15 @@ func (m *ChannelManagerImpl) advanceToNotifies(ctx context.Context, toNotifies [
zap.Int("total operation count", len(nodeAssign.Channels)),
zap.Strings("channel names", chNames),
)
for _, ch := range nodeAssign.Channels {
// Sort watch tasks by seek position to minimize lag between
// positions during batch subscription in the dispatcher.
channels := lo.Values(nodeAssign.Channels)
sort.Slice(channels, func(i, j int) bool {
return channels[i].GetWatchInfo().GetVchan().GetSeekPosition().GetTimestamp() <
channels[j].GetWatchInfo().GetVchan().GetSeekPosition().GetTimestamp()
})
for _, ch := range channels {
innerCh := ch
tmpWatchInfo := typeutil.Clone(innerCh.GetWatchInfo())
tmpWatchInfo.Vchan = m.h.GetDataVChanPositions(innerCh, allPartitionID)


@ -336,7 +336,7 @@ func (s *DataSyncServiceSuite) SetupTest() {
s.factory.EXPECT().NewTtMsgStream(mock.Anything).Return(s.ms, nil)
s.ms.EXPECT().AsConsumer(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil)
s.ms.EXPECT().Chan().Return(s.msChan)
s.ms.EXPECT().Close().Return()
s.ms.EXPECT().Close().Return().Maybe()
s.pipelineParams = &util2.PipelineParams{
Ctx: context.TODO(),


@ -21,7 +21,6 @@ import (
"fmt"
"time"
"github.com/cockroachdb/errors"
"go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
@ -34,9 +33,7 @@ import (
"github.com/milvus-io/milvus/pkg/v2/mq/common"
"github.com/milvus-io/milvus/pkg/v2/mq/msgdispatcher"
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
"github.com/milvus-io/milvus/pkg/v2/util/retry"
"github.com/milvus-io/milvus/pkg/v2/util/tsoutil"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
@ -87,22 +84,15 @@ func createNewInputFromDispatcher(initCtx context.Context,
replicateConfig := msgstream.GetReplicateConfig(replicateID, schema.GetDbName(), schema.GetName())
if seekPos != nil && len(seekPos.MsgID) != 0 {
err := retry.Handle(initCtx, func() (bool, error) {
input, err = dispatcherClient.Register(initCtx, &msgdispatcher.StreamConfig{
VChannel: vchannel,
Pos: seekPos,
SubPos: common.SubscriptionPositionUnknown,
ReplicateConfig: replicateConfig,
})
if err != nil {
log.Warn("datanode consume failed", zap.Error(err))
return errors.Is(err, merr.ErrTooManyConsumers), err
}
return false, nil
}, retry.Sleep(paramtable.Get().MQCfg.RetrySleep.GetAsDuration(time.Second)), // 5 seconds
retry.MaxSleepTime(paramtable.Get().MQCfg.RetryTimeout.GetAsDuration(time.Second))) // 5 minutes
input, err = dispatcherClient.Register(initCtx, &msgdispatcher.StreamConfig{
VChannel: vchannel,
Pos: seekPos,
SubPos: common.SubscriptionPositionUnknown,
ReplicateConfig: replicateConfig,
})
if err != nil {
log.Warn("datanode consume failed after retried", zap.Error(err))
dispatcherClient.Deregister(vchannel)
return nil, err
}
@ -114,22 +104,15 @@ func createNewInputFromDispatcher(initCtx context.Context,
return input, err
}
err = retry.Handle(initCtx, func() (bool, error) {
input, err = dispatcherClient.Register(initCtx, &msgdispatcher.StreamConfig{
VChannel: vchannel,
Pos: nil,
SubPos: common.SubscriptionPositionEarliest,
ReplicateConfig: replicateConfig,
})
if err != nil {
log.Warn("datanode consume failed", zap.Error(err))
return errors.Is(err, merr.ErrTooManyConsumers), err
}
return false, nil
}, retry.Sleep(paramtable.Get().MQCfg.RetrySleep.GetAsDuration(time.Second)), // 5 seconds
retry.MaxSleepTime(paramtable.Get().MQCfg.RetryTimeout.GetAsDuration(time.Second))) // 5 minutes
input, err = dispatcherClient.Register(initCtx, &msgdispatcher.StreamConfig{
VChannel: vchannel,
Pos: nil,
SubPos: common.SubscriptionPositionEarliest,
ReplicateConfig: replicateConfig,
})
if err != nil {
log.Warn("datanode consume failed after retried", zap.Error(err))
dispatcherClient.Deregister(vchannel)
return nil, err
}


@ -311,11 +311,11 @@ func (suite *ServiceSuite) TestWatchDmChannelsInt64() {
}
// mocks
suite.factory.EXPECT().NewTtMsgStream(mock.Anything).Return(suite.msgStream, nil)
suite.msgStream.EXPECT().AsConsumer(mock.Anything, []string{suite.pchannel}, mock.Anything, mock.Anything).Return(nil)
suite.msgStream.EXPECT().Seek(mock.Anything, mock.Anything, mock.Anything).Return(nil)
suite.msgStream.EXPECT().Chan().Return(suite.msgChan)
suite.msgStream.EXPECT().Close()
suite.factory.EXPECT().NewTtMsgStream(mock.Anything).Return(suite.msgStream, nil).Maybe()
suite.msgStream.EXPECT().AsConsumer(mock.Anything, []string{suite.pchannel}, mock.Anything, mock.Anything).Return(nil).Maybe()
suite.msgStream.EXPECT().Seek(mock.Anything, mock.Anything, mock.Anything).Return(nil).Maybe()
suite.msgStream.EXPECT().Chan().Return(suite.msgChan).Maybe()
suite.msgStream.EXPECT().Close().Maybe()
// watchDmChannels
status, err := suite.node.WatchDmChannels(ctx, req)
@ -363,11 +363,11 @@ func (suite *ServiceSuite) TestWatchDmChannelsVarchar() {
}
// mocks
suite.factory.EXPECT().NewTtMsgStream(mock.Anything).Return(suite.msgStream, nil)
suite.msgStream.EXPECT().AsConsumer(mock.Anything, []string{suite.pchannel}, mock.Anything, mock.Anything).Return(nil)
suite.msgStream.EXPECT().Seek(mock.Anything, mock.Anything, mock.Anything).Return(nil)
suite.msgStream.EXPECT().Chan().Return(suite.msgChan)
suite.msgStream.EXPECT().Close()
suite.factory.EXPECT().NewTtMsgStream(mock.Anything).Return(suite.msgStream, nil).Maybe()
suite.msgStream.EXPECT().AsConsumer(mock.Anything, []string{suite.pchannel}, mock.Anything, mock.Anything).Return(nil).Maybe()
suite.msgStream.EXPECT().Seek(mock.Anything, mock.Anything, mock.Anything).Return(nil).Maybe()
suite.msgStream.EXPECT().Chan().Return(suite.msgChan).Maybe()
suite.msgStream.EXPECT().Close().Maybe()
// watchDmChannels
status, err := suite.node.WatchDmChannels(ctx, req)
@ -498,16 +498,6 @@ func (suite *ServiceSuite) TestWatchDmChannels_Failed() {
suite.ErrorIs(merr.Error(status), merr.ErrChannelReduplicate)
suite.node.unsubscribingChannels.Remove(suite.vchannel)
// init msgstream failed
suite.factory.EXPECT().NewTtMsgStream(mock.Anything).Return(suite.msgStream, nil)
suite.msgStream.EXPECT().AsConsumer(mock.Anything, []string{suite.pchannel}, mock.Anything, mock.Anything).Return(nil)
suite.msgStream.EXPECT().Close().Return()
suite.msgStream.EXPECT().Seek(mock.Anything, mock.Anything, mock.Anything).Return(errors.New("mock error")).Once()
status, err = suite.node.WatchDmChannels(ctx, req)
suite.NoError(err)
suite.Equal(commonpb.ErrorCode_UnexpectedError, status.GetErrorCode())
// load growing failed
badSegmentReq := typeutil.Clone(req)
for _, info := range badSegmentReq.SegmentInfos {


@ -22,7 +22,6 @@ import (
"sync"
"time"
"github.com/cockroachdb/errors"
"go.uber.org/atomic"
"go.uber.org/zap"
@ -36,9 +35,6 @@ import (
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message/adaptor"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/options"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
"github.com/milvus-io/milvus/pkg/v2/util/retry"
"github.com/milvus-io/milvus/pkg/v2/util/tsoutil"
)
@ -127,22 +123,15 @@ func (p *streamPipeline) ConsumeMsgStream(ctx context.Context, position *msgpb.M
}
start := time.Now()
err = retry.Handle(ctx, func() (bool, error) {
p.input, err = p.dispatcher.Register(ctx, &msgdispatcher.StreamConfig{
VChannel: p.vChannel,
Pos: position,
SubPos: common.SubscriptionPositionUnknown,
ReplicateConfig: p.replicateConfig,
})
if err != nil {
log.Warn("dispatcher register failed", zap.String("channel", position.ChannelName), zap.Error(err))
return errors.Is(err, merr.ErrTooManyConsumers), err
}
return false, nil
}, retry.Sleep(paramtable.Get().MQCfg.RetrySleep.GetAsDuration(time.Second)), // 5 seconds
retry.MaxSleepTime(paramtable.Get().MQCfg.RetryTimeout.GetAsDuration(time.Second))) // 5 minutes
p.input, err = p.dispatcher.Register(ctx, &msgdispatcher.StreamConfig{
VChannel: p.vChannel,
Pos: position,
SubPos: common.SubscriptionPositionUnknown,
ReplicateConfig: p.replicateConfig,
})
if err != nil {
log.Error("dispatcher register failed after retried", zap.String("channel", position.ChannelName), zap.Error(err))
p.dispatcher.Deregister(p.vChannel)
return WrapErrRegDispather(err)
}


@ -18,7 +18,6 @@ package msgdispatcher
import (
"context"
"fmt"
"time"
"go.uber.org/zap"
@ -29,8 +28,6 @@ import (
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
"github.com/milvus-io/milvus/pkg/v2/util/lock"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
@ -82,13 +79,15 @@ func NewClient(factory msgstream.Factory, role string, nodeID int64) Client {
}
func (c *client) Register(ctx context.Context, streamConfig *StreamConfig) (<-chan *MsgPack, error) {
vchannel := streamConfig.VChannel
log := log.With(zap.String("role", c.role),
zap.Int64("nodeID", c.nodeID), zap.String("vchannel", vchannel))
pchannel := funcutil.ToPhysicalChannel(vchannel)
start := time.Now()
vchannel := streamConfig.VChannel
pchannel := funcutil.ToPhysicalChannel(vchannel)
log := log.Ctx(ctx).With(zap.String("role", c.role), zap.Int64("nodeID", c.nodeID), zap.String("vchannel", vchannel))
c.managerMut.Lock(pchannel)
defer c.managerMut.Unlock(pchannel)
var manager DispatcherManager
manager, ok := c.managers.Get(pchannel)
if !ok {
@ -96,18 +95,10 @@ func (c *client) Register(ctx context.Context, streamConfig *StreamConfig) (<-ch
c.managers.Insert(pchannel, manager)
go manager.Run()
}
// Check if the consumer number limit has been reached.
limit := paramtable.Get().MQCfg.MaxDispatcherNumPerPchannel.GetAsInt()
if manager.NumConsumer() >= limit {
return nil, merr.WrapErrTooManyConsumers(vchannel, fmt.Sprintf("limit=%d", limit))
}
// Begin to register
ch, err := manager.Add(ctx, streamConfig)
if err != nil {
if manager.NumTarget() == 0 {
manager.Close()
c.managers.Remove(pchannel)
}
log.Error("register failed", zap.Error(err))
return nil, err
}
@ -116,13 +107,15 @@ func (c *client) Register(ctx context.Context, streamConfig *StreamConfig) (<-ch
}
func (c *client) Deregister(vchannel string) {
pchannel := funcutil.ToPhysicalChannel(vchannel)
start := time.Now()
pchannel := funcutil.ToPhysicalChannel(vchannel)
c.managerMut.Lock(pchannel)
defer c.managerMut.Unlock(pchannel)
if manager, ok := c.managers.Get(pchannel); ok {
manager.Remove(vchannel)
if manager.NumTarget() == 0 {
if manager.NumTarget() == 0 && manager.NumConsumer() == 0 {
manager.Close()
c.managers.Remove(pchannel)
}
@ -132,12 +125,12 @@ func (c *client) Deregister(vchannel string) {
}
func (c *client) Close() {
log := log.With(zap.String("role", c.role),
zap.Int64("nodeID", c.nodeID))
log := log.With(zap.String("role", c.role), zap.Int64("nodeID", c.nodeID))
c.managers.Range(func(pchannel string, manager DispatcherManager) bool {
c.managerMut.Lock(pchannel)
defer c.managerMut.Unlock(pchannel)
log.Info("close manager", zap.String("channel", pchannel))
c.managers.Remove(pchannel)
manager.Close()


@ -25,62 +25,437 @@ import (
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/suite"
"go.uber.org/atomic"
"github.com/milvus-io/milvus/pkg/v2/mq/common"
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
func TestClient(t *testing.T) {
client := NewClient(newMockFactory(), typeutil.ProxyRole, 1)
factory := newMockFactory()
client := NewClient(factory, typeutil.ProxyRole, 1)
assert.NotNil(t, client)
_, err := client.Register(context.Background(), NewStreamConfig("mock_vchannel_0", nil, common.SubscriptionPositionUnknown))
assert.NoError(t, err)
_, err = client.Register(context.Background(), NewStreamConfig("mock_vchannel_1", nil, common.SubscriptionPositionUnknown))
assert.NoError(t, err)
assert.NotPanics(t, func() {
client.Deregister("mock_vchannel_0")
client.Close()
})
defer client.Close()
t.Run("with timeout ctx", func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Millisecond)
defer cancel()
<-time.After(2 * time.Millisecond)
pchannel := fmt.Sprintf("by-dev-rootcoord-dml_%d", rand.Int63())
client := NewClient(newMockFactory(), typeutil.DataNodeRole, 1)
defer client.Close()
assert.NotNil(t, client)
_, err := client.Register(ctx, NewStreamConfig("mock_vchannel_1", nil, common.SubscriptionPositionUnknown))
assert.Error(t, err)
})
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
producer, err := newMockProducer(factory, pchannel)
assert.NoError(t, err)
go produceTimeTick(t, ctx, producer)
_, err = client.Register(ctx, NewStreamConfig(fmt.Sprintf("%s_v1", pchannel), nil, common.SubscriptionPositionUnknown))
assert.NoError(t, err)
_, err = client.Register(ctx, NewStreamConfig(fmt.Sprintf("%s_v2", pchannel), nil, common.SubscriptionPositionUnknown))
assert.NoError(t, err)
client.Deregister(fmt.Sprintf("%s_v1", pchannel))
client.Deregister(fmt.Sprintf("%s_v2", pchannel))
}
func TestClient_Concurrency(t *testing.T) {
client1 := NewClient(newMockFactory(), typeutil.ProxyRole, 1)
factory := newMockFactory()
client1 := NewClient(factory, typeutil.ProxyRole, 1)
assert.NotNil(t, client1)
defer client1.Close()
paramtable.Get().Save(paramtable.Get().MQCfg.TargetBufSize.Key, "65536")
defer paramtable.Get().Reset(paramtable.Get().MQCfg.TargetBufSize.Key)
const (
vchannelNumPerPchannel = 10
pchannelNum = 16
)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
pchannels := make([]string, pchannelNum)
for i := 0; i < pchannelNum; i++ {
pchannel := fmt.Sprintf("by-dev-rootcoord-dml-%d_%d", rand.Int63(), i)
pchannels[i] = pchannel
producer, err := newMockProducer(factory, pchannel)
assert.NoError(t, err)
go produceTimeTick(t, ctx, producer)
t.Logf("start to produce time tick to pchannel %s\n", pchannel)
}
wg := &sync.WaitGroup{}
const total = 100
deregisterCount := atomic.NewInt32(0)
for i := 0; i < total; i++ {
vchannel := fmt.Sprintf("mock-vchannel-%d-%d", i, rand.Int())
wg.Add(1)
go func() {
_, err := client1.Register(context.Background(), NewStreamConfig(vchannel, nil, common.SubscriptionPositionUnknown))
assert.NoError(t, err)
for j := 0; j < rand.Intn(2); j++ {
client1.Deregister(vchannel)
deregisterCount.Inc()
}
wg.Done()
}()
for i := 0; i < vchannelNumPerPchannel; i++ {
for j := 0; j < pchannelNum; j++ {
vchannel := fmt.Sprintf("%s_%dv%d", pchannels[i], rand.Int(), i)
wg.Add(1)
go func() {
defer wg.Done()
_, err := client1.Register(ctx, NewStreamConfig(vchannel, nil, common.SubscriptionPositionUnknown))
assert.NoError(t, err)
for j := 0; j < rand.Intn(2); j++ {
client1.Deregister(vchannel)
deregisterCount.Inc()
}
}()
}
}
wg.Wait()
expected := int(total - deregisterCount.Load())
c := client1.(*client)
n := c.managers.Len()
assert.Equal(t, expected, n)
expected := int(vchannelNumPerPchannel*pchannelNum - deregisterCount.Load())
// Verify registered targets number.
actual := 0
c.managers.Range(func(pchannel string, manager DispatcherManager) bool {
actual += manager.NumTarget()
return true
})
assert.Equal(t, expected, actual)
// Verify active targets number.
assert.Eventually(t, func() bool {
actual = 0
c.managers.Range(func(pchannel string, manager DispatcherManager) bool {
m := manager.(*dispatcherManager)
m.mu.RLock()
defer m.mu.RUnlock()
if m.mainDispatcher != nil {
actual += m.mainDispatcher.targets.Len()
}
for _, d := range m.deputyDispatchers {
actual += d.targets.Len()
}
return true
})
t.Logf("expect = %d, actual = %d\n", expected, actual)
return expected == actual
}, 15*time.Second, 100*time.Millisecond)
}
type SimulationSuite struct {
suite.Suite
ctx context.Context
cancel context.CancelFunc
wg *sync.WaitGroup
client Client
factory msgstream.Factory
pchannel2Producer map[string]msgstream.MsgStream
pchannel2Vchannels map[string]map[string]*vchannelHelper
}
func (suite *SimulationSuite) SetupSuite() {
suite.factory = newMockFactory()
}
func (suite *SimulationSuite) SetupTest() {
const (
pchannelNum = 16
vchannelNumPerPchannel = 10
)
suite.ctx, suite.cancel = context.WithTimeout(context.Background(), time.Minute*3)
suite.wg = &sync.WaitGroup{}
suite.client = NewClient(suite.factory, "test-client", 1)
// Init pchannel and producers.
suite.pchannel2Producer = make(map[string]msgstream.MsgStream)
suite.pchannel2Vchannels = make(map[string]map[string]*vchannelHelper)
for i := 0; i < pchannelNum; i++ {
pchannel := fmt.Sprintf("by-dev-rootcoord-dispatcher-dml-%d_%d", time.Now().UnixNano(), i)
producer, err := newMockProducer(suite.factory, pchannel)
suite.NoError(err)
suite.pchannel2Producer[pchannel] = producer
suite.pchannel2Vchannels[pchannel] = make(map[string]*vchannelHelper)
}
// Init vchannels.
for pchannel := range suite.pchannel2Producer {
for i := 0; i < vchannelNumPerPchannel; i++ {
collectionID := time.Now().UnixNano()
vchannel := fmt.Sprintf("%s_%dv0", pchannel, collectionID)
suite.pchannel2Vchannels[pchannel][vchannel] = &vchannelHelper{}
}
}
}
func (suite *SimulationSuite) TestDispatchToVchannels() {
// Register vchannels.
for _, vchannels := range suite.pchannel2Vchannels {
for vchannel, helper := range vchannels {
output, err := suite.client.Register(suite.ctx, NewStreamConfig(vchannel, nil, common.SubscriptionPositionEarliest))
suite.NoError(err)
helper.output = output
}
}
// Produce and dispatch messages to vchannel targets.
produceCtx, produceCancel := context.WithTimeout(suite.ctx, time.Second*3)
defer produceCancel()
for pchannel, vchannels := range suite.pchannel2Vchannels {
suite.wg.Add(1)
go produceMsgs(suite.T(), produceCtx, suite.wg, suite.pchannel2Producer[pchannel], vchannels)
}
// Mock pipelines consume messages.
consumeCtx, consumeCancel := context.WithTimeout(suite.ctx, 10*time.Second)
defer consumeCancel()
for _, vchannels := range suite.pchannel2Vchannels {
for vchannel, helper := range vchannels {
suite.wg.Add(1)
go consumeMsgsFromTargets(suite.T(), consumeCtx, suite.wg, vchannel, helper)
}
}
suite.wg.Wait()
// Verify pub-sub messages number.
for _, vchannels := range suite.pchannel2Vchannels {
for vchannel, helper := range vchannels {
suite.Equal(helper.pubInsMsgNum.Load(), helper.subInsMsgNum.Load(), vchannel)
suite.Equal(helper.pubDelMsgNum.Load(), helper.subDelMsgNum.Load(), vchannel)
suite.Equal(helper.pubDDLMsgNum.Load(), helper.subDDLMsgNum.Load(), vchannel)
suite.Equal(helper.pubPackNum.Load(), helper.subPackNum.Load(), vchannel)
}
}
}
func (suite *SimulationSuite) TestMerge() {
// Produce msgs.
produceCtx, produceCancel := context.WithCancel(suite.ctx)
for pchannel, producer := range suite.pchannel2Producer {
suite.wg.Add(1)
go produceMsgs(suite.T(), produceCtx, suite.wg, producer, suite.pchannel2Vchannels[pchannel])
}
// Get random msg positions to seek for each vchannel.
for pchannel, vchannels := range suite.pchannel2Vchannels {
getRandomSeekPositions(suite.T(), suite.ctx, suite.factory, pchannel, vchannels)
}
// Register vchannels.
for _, vchannels := range suite.pchannel2Vchannels {
for vchannel, helper := range vchannels {
pos := helper.seekPos
assert.NotNil(suite.T(), pos)
suite.T().Logf("seekTs = %d, vchannel = %s, msgID=%v\n", pos.GetTimestamp(), vchannel, pos.GetMsgID())
output, err := suite.client.Register(suite.ctx, NewStreamConfig(
vchannel, pos,
common.SubscriptionPositionUnknown,
))
suite.NoError(err)
helper.output = output
}
}
// Mock pipelines consume messages.
for _, vchannels := range suite.pchannel2Vchannels {
for vchannel, helper := range vchannels {
suite.wg.Add(1)
go consumeMsgsFromTargets(suite.T(), suite.ctx, suite.wg, vchannel, helper)
}
}
// Verify dispatchers merged.
suite.Eventually(func() bool {
for pchannel := range suite.pchannel2Producer {
manager, ok := suite.client.(*client).managers.Get(pchannel)
suite.T().Logf("dispatcherNum = %d, pchannel = %s\n", manager.NumConsumer(), pchannel)
suite.True(ok)
if manager.NumConsumer() != 1 { // expected all merged, only mainDispatcher exist
return false
}
}
return true
}, 15*time.Second, 100*time.Millisecond)
// Stop produce and verify pub-sub messages number.
produceCancel()
suite.Eventually(func() bool {
for _, vchannels := range suite.pchannel2Vchannels {
for vchannel, helper := range vchannels {
logFn := func(pubNum, skipNum, subNum int32, name string) {
suite.T().Logf("pub%sNum[%d]-skipped%sNum[%d] = %d, sub%sNum = %d, vchannel = %s\n",
name, pubNum, name, skipNum, pubNum-skipNum, name, subNum, vchannel)
}
if helper.pubInsMsgNum.Load()-helper.skippedInsMsgNum != helper.subInsMsgNum.Load() {
logFn(helper.pubInsMsgNum.Load(), helper.skippedInsMsgNum, helper.subInsMsgNum.Load(), "InsMsg")
return false
}
if helper.pubDelMsgNum.Load()-helper.skippedDelMsgNum != helper.subDelMsgNum.Load() {
logFn(helper.pubDelMsgNum.Load(), helper.skippedDelMsgNum, helper.subDelMsgNum.Load(), "DelMsg")
return false
}
if helper.pubDDLMsgNum.Load()-helper.skippedDDLMsgNum != helper.subDDLMsgNum.Load() {
logFn(helper.pubDDLMsgNum.Load(), helper.skippedDDLMsgNum, helper.subDDLMsgNum.Load(), "DDLMsg")
return false
}
if helper.pubPackNum.Load()-helper.skippedPackNum != helper.subPackNum.Load() {
logFn(helper.pubPackNum.Load(), helper.skippedPackNum, helper.subPackNum.Load(), "Pack")
return false
}
}
}
return true
}, 15*time.Second, 100*time.Millisecond)
}
func (suite *SimulationSuite) TestSplit() {
// Modify the parameters to make triggering split easier.
paramtable.Get().Save(paramtable.Get().MQCfg.MaxTolerantLag.Key, "0.5")
defer paramtable.Get().Reset(paramtable.Get().MQCfg.MaxTolerantLag.Key)
paramtable.Get().Save(paramtable.Get().MQCfg.TargetBufSize.Key, "512")
defer paramtable.Get().Reset(paramtable.Get().MQCfg.TargetBufSize.Key)
// Produce msgs.
produceCtx, produceCancel := context.WithCancel(suite.ctx)
for pchannel, producer := range suite.pchannel2Producer {
suite.wg.Add(1)
go produceMsgs(suite.T(), produceCtx, suite.wg, producer, suite.pchannel2Vchannels[pchannel])
}
// Register vchannels.
for _, vchannels := range suite.pchannel2Vchannels {
for vchannel, helper := range vchannels {
output, err := suite.client.Register(suite.ctx, NewStreamConfig(vchannel, nil, common.SubscriptionPositionEarliest))
suite.NoError(err)
helper.output = output
}
}
// Verify dispatchers merged.
suite.Eventually(func() bool {
for pchannel := range suite.pchannel2Producer {
manager, ok := suite.client.(*client).managers.Get(pchannel)
suite.T().Logf("verifing dispatchers merged, dispatcherNum = %d, pchannel = %s\n", manager.NumConsumer(), pchannel)
suite.True(ok)
if manager.NumConsumer() != 1 { // expected all merged, only mainDispatcher exist
return false
}
}
return true
}, 15*time.Second, 100*time.Millisecond)
getTargetChan := func(pchannel, vchannel string) chan *MsgPack {
manager, ok := suite.client.(*client).managers.Get(pchannel)
suite.True(ok)
t, ok := manager.(*dispatcherManager).registeredTargets.Get(vchannel)
suite.True(ok)
return t.ch
}
// Inject additional messages into targets to trigger lag and split.
injectCtx, injectCancel := context.WithCancel(context.Background())
const splitNumPerPchannel = 3
for pchannel, vchannels := range suite.pchannel2Vchannels {
cnt := 0
for vchannel := range vchannels {
suite.wg.Add(1)
targetCh := getTargetChan(pchannel, vchannel)
go func() {
defer suite.wg.Done()
for {
select {
case targetCh <- &MsgPack{}:
case <-injectCtx.Done():
return
}
}
}()
cnt++
if cnt == splitNumPerPchannel {
break
}
}
}
// Verify split.
suite.Eventually(func() bool {
for pchannel := range suite.pchannel2Producer {
manager, ok := suite.client.(*client).managers.Get(pchannel)
suite.True(ok)
suite.T().Logf("verifing split, dispatcherNum = %d, splitNum+1 = %d, pchannel = %s\n",
manager.NumConsumer(), splitNumPerPchannel+1, pchannel)
if manager.NumConsumer() < 1 { // expected 1 mainDispatcher and 1 or more split deputyDispatchers
return false
}
}
return true
}, 20*time.Second, 100*time.Millisecond)
injectCancel()
// Mock pipelines consume messages to trigger merged again.
for _, vchannels := range suite.pchannel2Vchannels {
for vchannel, helper := range vchannels {
suite.wg.Add(1)
go consumeMsgsFromTargets(suite.T(), suite.ctx, suite.wg, vchannel, helper)
}
}
// Verify dispatchers merged.
suite.Eventually(func() bool {
for pchannel := range suite.pchannel2Producer {
manager, ok := suite.client.(*client).managers.Get(pchannel)
suite.T().Logf("verifing dispatchers merged again, dispatcherNum = %d, pchannel = %s\n", manager.NumConsumer(), pchannel)
suite.True(ok)
if manager.NumConsumer() != 1 { // expected all merged, only mainDispatcher exist
return false
}
}
return true
}, 15*time.Second, 100*time.Millisecond)
// Stop produce and verify pub-sub messages number.
produceCancel()
suite.Eventually(func() bool {
for _, vchannels := range suite.pchannel2Vchannels {
for vchannel, helper := range vchannels {
if helper.pubInsMsgNum.Load() != helper.subInsMsgNum.Load() {
suite.T().Logf("pubInsMsgNum = %d, subInsMsgNum = %d, vchannel = %s\n",
helper.pubInsMsgNum.Load(), helper.subInsMsgNum.Load(), vchannel)
return false
}
if helper.pubDelMsgNum.Load() != helper.subDelMsgNum.Load() {
suite.T().Logf("pubDelMsgNum = %d, subDelMsgNum = %d, vchannel = %s\n",
helper.pubDelMsgNum.Load(), helper.subDelMsgNum.Load(), vchannel)
return false
}
if helper.pubDDLMsgNum.Load() != helper.subDDLMsgNum.Load() {
suite.T().Logf("pubDDLMsgNum = %d, subDDLMsgNum = %d, vchannel = %s\n",
helper.pubDDLMsgNum.Load(), helper.subDDLMsgNum.Load(), vchannel)
return false
}
if helper.pubPackNum.Load() != helper.subPackNum.Load() {
suite.T().Logf("pubPackNum = %d, subPackNum = %d, vchannel = %s\n",
helper.pubPackNum.Load(), helper.subPackNum.Load(), vchannel)
return false
}
}
}
return true
}, 15*time.Second, 100*time.Millisecond)
}
func (suite *SimulationSuite) TearDownTest() {
for _, vchannels := range suite.pchannel2Vchannels {
for vchannel := range vchannels {
suite.client.Deregister(vchannel)
}
}
suite.client.Close()
suite.cancel()
suite.wg.Wait()
}
func (suite *SimulationSuite) TearDownSuite() {
}
func TestSimulation(t *testing.T) {
suite.Run(t, new(SimulationSuite))
}
func TestClientMainDispatcherLeak(t *testing.T) {


@ -34,6 +34,7 @@ import (
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
"github.com/milvus-io/milvus/pkg/v2/util/syncutil"
"github.com/milvus-io/milvus/pkg/v2/util/tsoutil"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
@ -62,20 +63,20 @@ type Dispatcher struct {
ctx context.Context
cancel context.CancelFunc
id int64
pullbackEndTs typeutil.Timestamp
pullbackDone bool
pullbackDoneNotifier *syncutil.AsyncTaskNotifier[struct{}]
done chan struct{}
wg sync.WaitGroup
once sync.Once
isMain bool // indicates if it's a main dispatcher
pchannel string
curTs atomic.Uint64
lagNotifyChan chan struct{}
lagTargets *typeutil.ConcurrentMap[string, *target] // vchannel -> *target
// vchannel -> *target, lock free since we guarantee that
// it's modified only after dispatcher paused or terminated
targets map[string]*target
targets *typeutil.ConcurrentMap[string, *target]
stream msgstream.MsgStream
}
@ -83,18 +84,17 @@ type Dispatcher struct {
func NewDispatcher(
ctx context.Context,
factory msgstream.Factory,
isMain bool,
id int64,
pchannel string,
position *Pos,
subName string,
subPos SubPos,
lagNotifyChan chan struct{},
lagTargets *typeutil.ConcurrentMap[string, *target],
includeCurrentMsg bool,
pullbackEndTs typeutil.Timestamp,
) (*Dispatcher, error) {
log := log.With(zap.String("pchannel", pchannel),
zap.String("subName", subName), zap.Bool("isMain", isMain))
log.Info("creating dispatcher...")
subName := fmt.Sprintf("%s-%d-%d", pchannel, id, time.Now().UnixNano())
log := log.Ctx(ctx).With(zap.String("pchannel", pchannel),
zap.Int64("id", id), zap.String("subName", subName))
log.Info("creating dispatcher...", zap.Uint64("pullbackEndTs", pullbackEndTs))
var stream msgstream.MsgStream
var err error
@ -116,8 +116,8 @@ func NewDispatcher(
log.Error("asConsumer failed", zap.Error(err))
return nil, err
}
err = stream.Seek(ctx, []*Pos{position}, includeCurrentMsg)
log.Info("as consumer done", zap.Any("position", position))
err = stream.Seek(ctx, []*Pos{position}, false)
if err != nil {
log.Error("seek failed", zap.Error(err))
return nil, err
@ -135,59 +135,75 @@ func NewDispatcher(
}
d := &Dispatcher{
done: make(chan struct{}, 1),
isMain: isMain,
pchannel: pchannel,
lagNotifyChan: lagNotifyChan,
lagTargets: lagTargets,
targets: make(map[string]*target),
stream: stream,
id: id,
pullbackEndTs: pullbackEndTs,
pullbackDoneNotifier: syncutil.NewAsyncTaskNotifier[struct{}](),
done: make(chan struct{}, 1),
pchannel: pchannel,
targets: typeutil.NewConcurrentMap[string, *target](),
stream: stream,
}
metrics.NumConsumers.WithLabelValues(paramtable.GetRole(), fmt.Sprint(paramtable.GetNodeID())).Inc()
return d, nil
}
func (d *Dispatcher) ID() int64 {
return d.id
}
func (d *Dispatcher) CurTs() typeutil.Timestamp {
return d.curTs.Load()
}
func (d *Dispatcher) AddTarget(t *target) {
log := log.With(zap.String("vchannel", t.vchannel), zap.Bool("isMain", d.isMain))
if _, ok := d.targets[t.vchannel]; ok {
log := log.With(zap.String("vchannel", t.vchannel), zap.Int64("id", d.ID()), zap.Uint64("ts", t.pos.GetTimestamp()))
if _, ok := d.targets.GetOrInsert(t.vchannel, t); ok {
log.Warn("target exists")
return
}
d.targets[t.vchannel] = t
log.Info("add new target")
}
func (d *Dispatcher) GetTarget(vchannel string) (*target, error) {
if t, ok := d.targets[vchannel]; ok {
if t, ok := d.targets.Get(vchannel); ok {
return t, nil
}
return nil, fmt.Errorf("cannot find target, vchannel=%s, isMain=%t", vchannel, d.isMain)
return nil, fmt.Errorf("cannot find target, vchannel=%s", vchannel)
}
func (d *Dispatcher) CloseTarget(vchannel string) {
log := log.With(zap.String("vchannel", vchannel), zap.Bool("isMain", d.isMain))
if t, ok := d.targets[vchannel]; ok {
t.close()
delete(d.targets, vchannel)
log.Info("closed target")
func (d *Dispatcher) GetTargets() []*target {
return d.targets.Values()
}
func (d *Dispatcher) HasTarget(vchannel string) bool {
return d.targets.Contain(vchannel)
}
func (d *Dispatcher) RemoveTarget(vchannel string) {
log := log.With(zap.String("vchannel", vchannel), zap.Int64("id", d.ID()))
if _, ok := d.targets.GetAndRemove(vchannel); ok {
log.Info("target removed")
} else {
log.Warn("target not exist")
}
}
func (d *Dispatcher) TargetNum() int {
return len(d.targets)
return d.targets.Len()
}
func (d *Dispatcher) BlockUtilPullbackDone() {
select {
case <-d.ctx.Done():
case <-d.pullbackDoneNotifier.FinishChan():
}
}
func (d *Dispatcher) Handle(signal signal) {
log := log.With(zap.String("pchannel", d.pchannel),
zap.String("signal", signal.String()), zap.Bool("isMain", d.isMain))
log.Info("get signal")
log := log.With(zap.String("pchannel", d.pchannel), zap.Int64("id", d.ID()),
zap.String("signal", signal.String()))
log.Debug("get signal")
switch signal {
case start:
d.ctx, d.cancel = context.WithCancel(context.Background())
@ -214,7 +230,7 @@ func (d *Dispatcher) Handle(signal signal) {
}
func (d *Dispatcher) work() {
log := log.With(zap.String("pchannel", d.pchannel), zap.Bool("isMain", d.isMain))
log := log.With(zap.String("pchannel", d.pchannel), zap.Int64("id", d.ID()))
log.Info("begin to work")
defer d.wg.Done()
for {
@ -232,12 +248,36 @@ func (d *Dispatcher) work() {
targetPacks := d.groupingMsgs(pack)
for vchannel, p := range targetPacks {
var err error
t := d.targets[vchannel]
if d.isMain {
// for main dispatcher, split target if err occurs
t, _ := d.targets.Get(vchannel)
// The dispatcher seeks from the oldest target,
// so for each target, msg before the target position must be filtered out.
if p.EndTs <= t.pos.GetTimestamp() {
log.Info("skip msg",
zap.String("vchannel", vchannel),
zap.Int("msgCount", len(p.Msgs)),
zap.Uint64("packBeginTs", p.BeginTs),
zap.Uint64("packEndTs", p.EndTs),
zap.Uint64("posTs", t.pos.GetTimestamp()),
)
for _, msg := range p.Msgs {
log.Debug("skip msg info",
zap.String("vchannel", vchannel),
zap.String("msgType", msg.Type().String()),
zap.Int64("msgID", msg.ID()),
zap.Uint64("msgBeginTs", msg.BeginTs()),
zap.Uint64("msgEndTs", msg.EndTs()),
zap.Uint64("packBeginTs", p.BeginTs),
zap.Uint64("packEndTs", p.EndTs),
zap.Uint64("posTs", t.pos.GetTimestamp()),
)
}
continue
}
if d.targets.Len() > 1 {
// for dispatcher with multiple targets, split target if err occurs
err = t.send(p)
} else {
// for solo dispatcher, only 1 target exists, we should
// for dispatcher with only one target,
// keep retrying if err occurs, unless it paused or terminated.
for {
err = t.send(p)
@ -250,12 +290,19 @@ func (d *Dispatcher) work() {
t.pos = typeutil.Clone(pack.StartPositions[0])
// replace the pChannel with vChannel
t.pos.ChannelName = t.vchannel
d.lagTargets.Insert(t.vchannel, t)
d.nonBlockingNotify()
delete(d.targets, vchannel)
log.Warn("lag target notified", zap.Error(err))
d.targets.GetAndRemove(vchannel)
log.Warn("lag target", zap.Error(err))
}
}
if !d.pullbackDone && pack.EndPositions[0].GetTimestamp() >= d.pullbackEndTs {
d.pullbackDoneNotifier.Finish(struct{}{})
log.Info("dispatcher pullback done",
zap.Uint64("pullbackEndTs", d.pullbackEndTs),
zap.Time("pullbackTime", tsoutil.PhysicalTime(d.pullbackEndTs)),
)
d.pullbackDone = true
}
}
}
}
@ -265,7 +312,7 @@ func (d *Dispatcher) groupingMsgs(pack *MsgPack) map[string]*MsgPack {
// but we still need to dispatch time ticks to the targets.
targetPacks := make(map[string]*MsgPack)
replicateConfigs := make(map[string]*msgstream.ReplicateConfig)
for vchannel, t := range d.targets {
d.targets.Range(func(vchannel string, t *target) bool {
targetPacks[vchannel] = &MsgPack{
BeginTs: pack.BeginTs,
EndTs: pack.EndTs,
@ -276,7 +323,8 @@ func (d *Dispatcher) groupingMsgs(pack *MsgPack) map[string]*MsgPack {
if t.replicateConfig != nil {
replicateConfigs[vchannel] = t.replicateConfig
}
}
return true
})
// group messages by vchannel
for _, msg := range pack.Msgs {
var vchannel, collectionID string
@ -348,7 +396,7 @@ func (d *Dispatcher) groupingMsgs(pack *MsgPack) map[string]*MsgPack {
d.resetMsgPackTS(targetPacks[vchannel], beginTs, endTs)
}
for vchannel := range replicateEndChannels {
if t, ok := d.targets[vchannel]; ok {
if t, ok := d.targets.Get(vchannel); ok {
t.replicateConfig = nil
log.Info("replicate end, set replicate config nil", zap.String("vchannel", vchannel))
}
@ -374,10 +422,3 @@ func (d *Dispatcher) resetMsgPackTS(pack *MsgPack, newBeginTs, newEndTs typeutil
pack.StartPositions = startPositions
pack.EndPositions = endPositions
}
func (d *Dispatcher) nonBlockingNotify() {
select {
case d.lagNotifyChan <- struct{}{}:
default:
}
}


@ -17,8 +17,6 @@
package msgdispatcher
import (
"fmt"
"math/rand"
"sync"
"testing"
"time"
@ -37,7 +35,8 @@ import (
func TestDispatcher(t *testing.T) {
ctx := context.Background()
t.Run("test base", func(t *testing.T) {
d, err := NewDispatcher(ctx, newMockFactory(), true, "mock_pchannel_0", nil, "mock_subName_0", common.SubscriptionPositionEarliest, nil, nil, false)
d, err := NewDispatcher(ctx, newMockFactory(), time.Now().UnixNano(), "mock_pchannel_0",
nil, common.SubscriptionPositionEarliest, 0)
assert.NoError(t, err)
assert.NotPanics(t, func() {
d.Handle(start)
@ -65,19 +64,24 @@ func TestDispatcher(t *testing.T) {
return ms, nil
},
}
d, err := NewDispatcher(ctx, factory, true, "mock_pchannel_0", nil, "mock_subName_0", common.SubscriptionPositionEarliest, nil, nil, false)
d, err := NewDispatcher(ctx, factory, time.Now().UnixNano(), "mock_pchannel_0",
nil, common.SubscriptionPositionEarliest, 0)
assert.Error(t, err)
assert.Nil(t, d)
})
t.Run("test target", func(t *testing.T) {
d, err := NewDispatcher(ctx, newMockFactory(), true, "mock_pchannel_0", nil, "mock_subName_0", common.SubscriptionPositionEarliest, nil, nil, false)
d, err := NewDispatcher(ctx, newMockFactory(), time.Now().UnixNano(), "mock_pchannel_0",
nil, common.SubscriptionPositionEarliest, 0)
assert.NoError(t, err)
output := make(chan *msgstream.MsgPack, 1024)
getTarget := func(vchannel string, pos *Pos, ch chan *msgstream.MsgPack) *target {
target := newTarget(vchannel, pos, nil)
target := newTarget(&StreamConfig{
VChannel: vchannel,
Pos: pos,
})
target.ch = ch
return target
}
@ -91,14 +95,7 @@ func TestDispatcher(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, cap(output), cap(target.ch))
d.CloseTarget("mock_vchannel_0")
select {
case <-time.After(1 * time.Second):
assert.Fail(t, "timeout, didn't receive close message")
case _, ok := <-target.ch:
assert.False(t, ok)
}
d.RemoveTarget("mock_vchannel_0")
num = d.TargetNum()
assert.Equal(t, 1, num)
@ -107,7 +104,7 @@ func TestDispatcher(t *testing.T) {
t.Run("test concurrent send and close", func(t *testing.T) {
for i := 0; i < 100; i++ {
output := make(chan *msgstream.MsgPack, 1024)
target := newTarget("mock_vchannel_0", nil, nil)
target := newTarget(&StreamConfig{VChannel: "mock_vchannel_0"})
target.ch = output
assert.Equal(t, cap(output), cap(target.ch))
wg := &sync.WaitGroup{}
@ -130,7 +127,8 @@ func TestDispatcher(t *testing.T) {
}
func BenchmarkDispatcher_handle(b *testing.B) {
d, err := NewDispatcher(context.Background(), newMockFactory(), true, "mock_pchannel_0", nil, "mock_subName_0", common.SubscriptionPositionEarliest, nil, nil, false)
d, err := NewDispatcher(context.Background(), newMockFactory(), time.Now().UnixNano(), "mock_pchannel_0",
nil, common.SubscriptionPositionEarliest, 0)
assert.NoError(b, err)
for i := 0; i < b.N; i++ {
@ -144,10 +142,14 @@ func BenchmarkDispatcher_handle(b *testing.B) {
}
func TestGroupMessage(t *testing.T) {
d, err := NewDispatcher(context.Background(), newMockFactory(), true, "mock_pchannel_0", nil, "mock_subName_0"+fmt.Sprintf("%d", rand.Int()), common.SubscriptionPositionEarliest, nil, nil, false)
d, err := NewDispatcher(context.Background(), newMockFactory(), time.Now().UnixNano(), "mock_pchannel_0",
nil, common.SubscriptionPositionEarliest, 0)
assert.NoError(t, err)
d.AddTarget(newTarget("mock_pchannel_0_1v0", nil, nil))
d.AddTarget(newTarget("mock_pchannel_0_2v0", nil, msgstream.GetReplicateConfig("local-test", "foo", "coo")))
d.AddTarget(newTarget(&StreamConfig{VChannel: "mock_pchannel_0_1v0"}))
d.AddTarget(newTarget(&StreamConfig{
VChannel: "mock_pchannel_0_2v0",
ReplicateConfig: msgstream.GetReplicateConfig("local-test", "foo", "coo"),
}))
{
// no replicate msg
packs := d.groupingMsgs(&MsgPack{
@ -286,7 +288,8 @@ func TestGroupMessage(t *testing.T) {
{
// replicate end
replicateTarget := d.targets["mock_pchannel_0_2v0"]
replicateTarget, ok := d.targets.Get("mock_pchannel_0_2v0")
assert.True(t, ok)
assert.NotNil(t, replicateTarget.replicateConfig)
packs := d.groupingMsgs(&MsgPack{
BeginTs: 1,


@ -19,18 +19,20 @@ package msgdispatcher
import (
"context"
"fmt"
"sort"
"sync"
"time"
"github.com/prometheus/client_golang/prometheus"
"github.com/samber/lo"
"go.uber.org/atomic"
"go.uber.org/zap"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/metrics"
"github.com/milvus-io/milvus/pkg/v2/mq/common"
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
"github.com/milvus-io/milvus/pkg/v2/util/retry"
"github.com/milvus-io/milvus/pkg/v2/util/timerecord"
"github.com/milvus-io/milvus/pkg/v2/util/tsoutil"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
@ -51,121 +53,67 @@ type dispatcherManager struct {
nodeID int64
pchannel string
lagNotifyChan chan struct{}
lagTargets *typeutil.ConcurrentMap[string, *target] // vchannel -> *target
registeredTargets *typeutil.ConcurrentMap[string, *target]
mu sync.RWMutex // guards mainDispatcher and soloDispatchers
mainDispatcher *Dispatcher
soloDispatchers map[string]*Dispatcher // vchannel -> *Dispatcher
mu sync.RWMutex
mainDispatcher *Dispatcher
deputyDispatchers map[int64]*Dispatcher // ID -> *Dispatcher
factory msgstream.Factory
closeChan chan struct{}
closeOnce sync.Once
idAllocator atomic.Int64
factory msgstream.Factory
closeChan chan struct{}
closeOnce sync.Once
}
func NewDispatcherManager(pchannel string, role string, nodeID int64, factory msgstream.Factory) DispatcherManager {
log.Info("create new dispatcherManager", zap.String("role", role),
zap.Int64("nodeID", nodeID), zap.String("pchannel", pchannel))
c := &dispatcherManager{
role: role,
nodeID: nodeID,
pchannel: pchannel,
lagNotifyChan: make(chan struct{}, 1),
lagTargets: typeutil.NewConcurrentMap[string, *target](),
soloDispatchers: make(map[string]*Dispatcher),
factory: factory,
closeChan: make(chan struct{}),
role: role,
nodeID: nodeID,
pchannel: pchannel,
registeredTargets: typeutil.NewConcurrentMap[string, *target](),
deputyDispatchers: make(map[int64]*Dispatcher),
factory: factory,
closeChan: make(chan struct{}),
}
return c
}
func (c *dispatcherManager) constructSubName(vchannel string, isMain bool) string {
return fmt.Sprintf("%s-%d-%s-%t", c.role, c.nodeID, vchannel, isMain)
}
func (c *dispatcherManager) Add(ctx context.Context, streamConfig *StreamConfig) (<-chan *MsgPack, error) {
vchannel := streamConfig.VChannel
log := log.With(zap.String("role", c.role),
zap.Int64("nodeID", c.nodeID), zap.String("vchannel", vchannel))
c.mu.Lock()
defer c.mu.Unlock()
if _, ok := c.soloDispatchers[vchannel]; ok {
// current dispatcher didn't allow multiple subscriptions on same vchannel at same time
log.Warn("unreachable: solo vchannel dispatcher already exists")
return nil, fmt.Errorf("solo vchannel dispatcher already exists")
t := newTarget(streamConfig)
if _, ok := c.registeredTargets.GetOrInsert(t.vchannel, t); ok {
return nil, fmt.Errorf("vchannel %s already exists in the dispatcher", t.vchannel)
}
if c.mainDispatcher != nil {
if _, err := c.mainDispatcher.GetTarget(vchannel); err == nil {
// current dispatcher didn't allow multiple subscriptions on same vchannel at same time
log.Warn("unreachable: vchannel has been registered in main dispatcher, ")
return nil, fmt.Errorf("vchannel has been registered in main dispatcher")
}
}
isMain := c.mainDispatcher == nil
d, err := NewDispatcher(ctx, c.factory, isMain, c.pchannel, streamConfig.Pos, c.constructSubName(vchannel, isMain), streamConfig.SubPos, c.lagNotifyChan, c.lagTargets, false)
if err != nil {
return nil, err
}
t := newTarget(vchannel, streamConfig.Pos, streamConfig.ReplicateConfig)
d.AddTarget(t)
if isMain {
c.mainDispatcher = d
log.Info("add main dispatcher")
} else {
c.soloDispatchers[vchannel] = d
log.Info("add solo dispatcher")
}
d.Handle(start)
log.Ctx(ctx).Info("target register done", zap.String("vchannel", t.vchannel))
return t.ch, nil
}
func (c *dispatcherManager) Remove(vchannel string) {
log := log.With(zap.String("role", c.role),
zap.Int64("nodeID", c.nodeID), zap.String("vchannel", vchannel))
c.mu.Lock()
defer c.mu.Unlock()
if _, ok := c.soloDispatchers[vchannel]; ok {
c.soloDispatchers[vchannel].Handle(terminate)
c.soloDispatchers[vchannel].CloseTarget(vchannel)
delete(c.soloDispatchers, vchannel)
c.deleteMetric(vchannel)
log.Info("remove soloDispatcher done")
t, ok := c.registeredTargets.GetAndRemove(vchannel)
if !ok {
log.Info("the target was not registered before", zap.String("role", c.role),
zap.Int64("nodeID", c.nodeID), zap.String("vchannel", vchannel))
return
}
if c.mainDispatcher != nil {
c.mainDispatcher.Handle(pause)
c.mainDispatcher.CloseTarget(vchannel)
if c.mainDispatcher.TargetNum() == 0 && len(c.soloDispatchers) == 0 {
c.mainDispatcher.Handle(terminate)
c.mainDispatcher = nil
log.Info("remove mainDispatcher done")
} else {
c.mainDispatcher.Handle(resume)
}
}
c.lagTargets.GetAndRemove(vchannel)
c.removeTargetFromDispatcher(t)
t.close()
}
func (c *dispatcherManager) NumTarget() int {
c.mu.RLock()
defer c.mu.RUnlock()
var res int
if c.mainDispatcher != nil {
res += c.mainDispatcher.TargetNum()
}
return res + len(c.soloDispatchers) + c.lagTargets.Len()
return c.registeredTargets.Len()
}
func (c *dispatcherManager) NumConsumer() int {
c.mu.RLock()
defer c.mu.RUnlock()
var res int
numConsumer := 0
if c.mainDispatcher != nil {
res++
numConsumer++
}
return res + len(c.soloDispatchers)
numConsumer += len(c.deputyDispatchers)
return numConsumer
}
func (c *dispatcherManager) Close() {
@ -175,8 +123,7 @@ func (c *dispatcherManager) Close() {
}
func (c *dispatcherManager) Run() {
log := log.With(zap.String("role", c.role),
zap.Int64("nodeID", c.nodeID), zap.String("pchannel", c.pchannel))
log := log.With(zap.String("role", c.role), zap.Int64("nodeID", c.nodeID), zap.String("pchannel", c.pchannel))
log.Info("dispatcherManager is running...")
ticker1 := time.NewTicker(10 * time.Second)
ticker2 := time.NewTicker(paramtable.Get().MQCfg.MergeCheckInterval.GetAsDuration(time.Second))
@ -190,87 +137,232 @@ func (c *dispatcherManager) Run() {
case <-ticker1.C:
c.uploadMetric()
case <-ticker2.C:
c.tryRemoveUnregisteredTargets()
c.tryBuildDispatcher()
c.tryMerge()
case <-c.lagNotifyChan:
c.mu.Lock()
c.lagTargets.Range(func(vchannel string, t *target) bool {
c.split(t)
c.lagTargets.GetAndRemove(vchannel)
return true
})
c.mu.Unlock()
}
}
}
func (c *dispatcherManager) tryMerge() {
log := log.With(zap.String("role", c.role), zap.Int64("nodeID", c.nodeID))
func (c *dispatcherManager) removeTargetFromDispatcher(t *target) {
log := log.With(zap.String("role", c.role), zap.Int64("nodeID", c.nodeID), zap.String("pchannel", c.pchannel))
c.mu.Lock()
defer c.mu.Unlock()
for _, dispatcher := range c.deputyDispatchers {
if dispatcher.HasTarget(t.vchannel) {
dispatcher.Handle(pause)
dispatcher.RemoveTarget(t.vchannel)
if dispatcher.TargetNum() == 0 {
dispatcher.Handle(terminate)
delete(c.deputyDispatchers, dispatcher.ID())
log.Info("remove deputy dispatcher done", zap.Int64("id", dispatcher.ID()))
} else {
dispatcher.Handle(resume)
}
t.close()
}
}
if c.mainDispatcher != nil {
if c.mainDispatcher.HasTarget(t.vchannel) {
c.mainDispatcher.Handle(pause)
c.mainDispatcher.RemoveTarget(t.vchannel)
if c.mainDispatcher.TargetNum() == 0 && len(c.deputyDispatchers) == 0 {
c.mainDispatcher.Handle(terminate)
c.mainDispatcher = nil
} else {
c.mainDispatcher.Handle(resume)
}
t.close()
}
}
}
func (c *dispatcherManager) tryRemoveUnregisteredTargets() {
unregisteredTargets := make([]*target, 0)
c.mu.RLock()
for _, dispatcher := range c.deputyDispatchers {
for _, t := range dispatcher.GetTargets() {
if !c.registeredTargets.Contain(t.vchannel) {
unregisteredTargets = append(unregisteredTargets, t)
}
}
}
if c.mainDispatcher != nil {
for _, t := range c.mainDispatcher.GetTargets() {
if !c.registeredTargets.Contain(t.vchannel) {
unregisteredTargets = append(unregisteredTargets, t)
}
}
}
c.mu.RUnlock()
for _, t := range unregisteredTargets {
c.removeTargetFromDispatcher(t)
}
}
func (c *dispatcherManager) tryBuildDispatcher() {
tr := timerecord.NewTimeRecorder("")
log := log.With(zap.String("role", c.role), zap.Int64("nodeID", c.nodeID), zap.String("pchannel", c.pchannel))
allTargets := c.registeredTargets.Values()
// get lack targets to perform subscription
lackTargets := make([]*target, 0, len(allTargets))
c.mu.RLock()
OUTER:
for _, t := range allTargets {
if c.mainDispatcher != nil && c.mainDispatcher.HasTarget(t.vchannel) {
continue
}
for _, dispatcher := range c.deputyDispatchers {
if dispatcher.HasTarget(t.vchannel) {
continue OUTER
}
}
lackTargets = append(lackTargets, t)
}
c.mu.RUnlock()
if len(lackTargets) == 0 {
return
}
sort.Slice(lackTargets, func(i, j int) bool {
return lackTargets[i].pos.GetTimestamp() < lackTargets[j].pos.GetTimestamp()
})
// To prevent the position gap between targets from becoming too large and causing excessive pull-back time,
// limit the position difference between targets to no more than 60 minutes.
earliestTarget := lackTargets[0]
candidateTargets := make([]*target, 0, len(lackTargets))
for _, t := range lackTargets {
if tsoutil.PhysicalTime(t.pos.GetTimestamp()).Sub(
tsoutil.PhysicalTime(earliestTarget.pos.GetTimestamp())) <=
paramtable.Get().MQCfg.MaxPositionTsGap.GetAsDuration(time.Minute) {
candidateTargets = append(candidateTargets, t)
}
}
vchannels := lo.Map(candidateTargets, func(t *target, _ int) string {
return t.vchannel
})
log.Info("start to build dispatchers", zap.Int("numTargets", len(vchannels)),
zap.Strings("vchannels", vchannels))
// dispatcher will pull back from the earliest position
// to the latest position in lack targets.
latestTarget := candidateTargets[len(candidateTargets)-1]
// TODO: add newDispatcher timeout param and init context
id := c.idAllocator.Inc()
d, err := NewDispatcher(context.Background(), c.factory, id, c.pchannel, earliestTarget.pos, earliestTarget.subPos, latestTarget.pos.GetTimestamp())
if err != nil {
panic(err)
}
for _, t := range candidateTargets {
d.AddTarget(t)
}
d.Handle(start)
buildDur := tr.RecordSpan()
// block until pullback reaches the latest target position
if len(candidateTargets) > 1 {
d.BlockUtilPullbackDone()
}
var (
pullbackBeginTs = earliestTarget.pos.GetTimestamp()
pullbackEndTs = latestTarget.pos.GetTimestamp()
pullbackBeginTime = tsoutil.PhysicalTime(pullbackBeginTs)
pullbackEndTime = tsoutil.PhysicalTime(pullbackEndTs)
)
log.Info("build dispatcher done",
zap.Int64("id", d.ID()),
zap.Int("numVchannels", len(vchannels)),
zap.Uint64("pullbackBeginTs", pullbackBeginTs),
zap.Uint64("pullbackEndTs", pullbackEndTs),
zap.Duration("lag", pullbackEndTime.Sub(pullbackBeginTime)),
zap.Time("pullbackBeginTime", pullbackBeginTime),
zap.Time("pullbackEndTime", pullbackEndTime),
zap.Duration("buildDur", buildDur),
zap.Duration("pullbackDur", tr.RecordSpan()),
zap.Strings("vchannels", vchannels),
)
c.mu.Lock()
defer c.mu.Unlock()
d.Handle(pause)
for _, candidate := range candidateTargets {
vchannel := candidate.vchannel
t, ok := c.registeredTargets.Get(vchannel)
// During the build process, the target may undergo repeated deregister and register,
// causing the channel object to change. Here, validate whether the channel is the
// same as before the build. If inconsistent, remove the target.
if !ok || t.ch != candidate.ch {
d.RemoveTarget(vchannel)
}
}
d.Handle(resume)
if c.mainDispatcher == nil {
c.mainDispatcher = d
log.Info("add main dispatcher", zap.Int64("id", d.ID()))
} else {
c.deputyDispatchers[d.ID()] = d
log.Info("add deputy dispatcher", zap.Int64("id", d.ID()))
}
}
func (c *dispatcherManager) tryMerge() {
c.mu.Lock()
defer c.mu.Unlock()
start := time.Now()
log := log.With(zap.String("role", c.role), zap.Int64("nodeID", c.nodeID), zap.String("pchannel", c.pchannel))
if c.mainDispatcher == nil || c.mainDispatcher.CurTs() == 0 {
return
}
candidates := make(map[string]struct{})
for vchannel, sd := range c.soloDispatchers {
candidates := make([]*Dispatcher, 0, len(c.deputyDispatchers))
for _, sd := range c.deputyDispatchers {
if sd.CurTs() == c.mainDispatcher.CurTs() {
candidates[vchannel] = struct{}{}
candidates = append(candidates, sd)
}
}
if len(candidates) == 0 {
return
}
log.Info("start merging...", zap.Any("vchannel", candidates))
dispatcherIDs := lo.Map(candidates, func(d *Dispatcher, _ int) int64 {
return d.ID()
})
log.Info("start merging...", zap.Int64s("dispatchers", dispatcherIDs))
mergeCandidates := make([]*Dispatcher, 0, len(candidates))
c.mainDispatcher.Handle(pause)
for vchannel := range candidates {
c.soloDispatchers[vchannel].Handle(pause)
for _, dispatcher := range candidates {
dispatcher.Handle(pause)
// after pause, check alignment again, if not, evict it and try to merge next time
if c.mainDispatcher.CurTs() != c.soloDispatchers[vchannel].CurTs() {
c.soloDispatchers[vchannel].Handle(resume)
delete(candidates, vchannel)
if c.mainDispatcher.CurTs() != dispatcher.CurTs() {
dispatcher.Handle(resume)
continue
}
mergeCandidates = append(mergeCandidates, dispatcher)
}
mergeTs := c.mainDispatcher.CurTs()
for vchannel := range candidates {
t, err := c.soloDispatchers[vchannel].GetTarget(vchannel)
if err == nil {
for _, dispatcher := range mergeCandidates {
targets := dispatcher.GetTargets()
for _, t := range targets {
c.mainDispatcher.AddTarget(t)
}
c.soloDispatchers[vchannel].Handle(terminate)
delete(c.soloDispatchers, vchannel)
c.deleteMetric(vchannel)
dispatcher.Handle(terminate)
delete(c.deputyDispatchers, dispatcher.ID())
}
c.mainDispatcher.Handle(resume)
log.Info("merge done", zap.Any("vchannel", candidates), zap.Uint64("mergeTs", mergeTs))
}
func (c *dispatcherManager) split(t *target) {
log := log.With(zap.String("role", c.role),
zap.Int64("nodeID", c.nodeID), zap.String("vchannel", t.vchannel))
log.Info("start splitting...")
// remove stale soloDispatcher if it existed
if _, ok := c.soloDispatchers[t.vchannel]; ok {
c.soloDispatchers[t.vchannel].Handle(terminate)
delete(c.soloDispatchers, t.vchannel)
c.deleteMetric(t.vchannel)
}
var newSolo *Dispatcher
err := retry.Do(context.Background(), func() error {
var err error
newSolo, err = NewDispatcher(context.Background(), c.factory, false, c.pchannel, t.pos, c.constructSubName(t.vchannel, false), common.SubscriptionPositionUnknown, c.lagNotifyChan, c.lagTargets, true)
return err
}, retry.Attempts(10))
if err != nil {
log.Error("split failed", zap.Error(err))
panic(err)
}
newSolo.AddTarget(t)
c.soloDispatchers[t.vchannel] = newSolo
newSolo.Handle(start)
log.Info("split done")
log.Info("merge done", zap.Int64s("dispatchers", dispatcherIDs),
zap.Uint64("mergeTs", mergeTs),
zap.Duration("dur", time.Since(start)))
}
// deleteMetric remove specific prometheus metric,
@ -289,18 +381,21 @@ func (c *dispatcherManager) deleteMetric(channel string) {
func (c *dispatcherManager) uploadMetric() {
c.mu.RLock()
defer c.mu.RUnlock()
nodeIDStr := fmt.Sprintf("%d", c.nodeID)
fn := func(gauge *prometheus.GaugeVec) {
if c.mainDispatcher == nil {
return
}
// for main dispatcher, use pchannel as channel label
gauge.WithLabelValues(nodeIDStr, c.pchannel).Set(
float64(time.Since(tsoutil.PhysicalTime(c.mainDispatcher.CurTs())).Milliseconds()))
// for solo dispatchers, use vchannel as channel label
for vchannel, dispatcher := range c.soloDispatchers {
gauge.WithLabelValues(nodeIDStr, vchannel).Set(
float64(time.Since(tsoutil.PhysicalTime(dispatcher.CurTs())).Milliseconds()))
for _, t := range c.mainDispatcher.GetTargets() {
gauge.WithLabelValues(nodeIDStr, t.vchannel).Set(
float64(time.Since(tsoutil.PhysicalTime(c.mainDispatcher.CurTs())).Milliseconds()))
}
for _, dispatcher := range c.deputyDispatchers {
for _, t := range dispatcher.GetTargets() {
gauge.WithLabelValues(nodeIDStr, t.vchannel).Set(
float64(time.Since(tsoutil.PhysicalTime(dispatcher.CurTs())).Milliseconds()))
}
}
}
if c.role == typeutil.DataNodeRole {


@ -20,422 +20,207 @@ import (
"context"
"fmt"
"math/rand"
"reflect"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/suite"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus/pkg/v2/mq/common"
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
func TestManager(t *testing.T) {
paramtable.Get().Save(paramtable.Get().MQCfg.TargetBufSize.Key, "65536")
defer paramtable.Get().Reset(paramtable.Get().MQCfg.TargetBufSize.Key)
t.Run("test add and remove dispatcher", func(t *testing.T) {
c := NewDispatcherManager("mock_pchannel_0", typeutil.ProxyRole, 1, newMockFactory())
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
pchannel := fmt.Sprintf("by-dev-rootcoord-dml_%d", rand.Int63())
factory := newMockFactory()
producer, err := newMockProducer(factory, pchannel)
assert.NoError(t, err)
go produceTimeTick(t, ctx, producer)
c := NewDispatcherManager(pchannel, typeutil.ProxyRole, 1, factory)
assert.NotNil(t, c)
go c.Run()
defer c.Close()
assert.Equal(t, 0, c.NumConsumer())
assert.Equal(t, 0, c.NumTarget())
var offset int
for i := 0; i < 100; i++ {
r := rand.Intn(10) + 1
for i := 0; i < 30; i++ {
r := rand.Intn(5) + 1
for j := 0; j < r; j++ {
offset++
vchannel := fmt.Sprintf("mock-pchannel-dml_0_vchannelv%d", offset)
vchannel := fmt.Sprintf("%s_vchannelv%d", pchannel, offset)
t.Logf("add vchannel, %s", vchannel)
_, err := c.Add(context.Background(), NewStreamConfig(vchannel, nil, common.SubscriptionPositionUnknown))
_, err := c.Add(ctx, NewStreamConfig(vchannel, nil, common.SubscriptionPositionUnknown))
assert.NoError(t, err)
assert.Equal(t, offset, c.NumConsumer())
assert.Equal(t, offset, c.NumTarget())
}
assert.Eventually(t, func() bool {
t.Logf("offset=%d, numConsumer=%d, numTarget=%d", offset, c.NumConsumer(), c.NumTarget())
return c.NumTarget() == offset
}, 3*time.Second, 10*time.Millisecond)
for j := 0; j < rand.Intn(r); j++ {
vchannel := fmt.Sprintf("mock-pchannel-dml_0_vchannelv%d", offset)
vchannel := fmt.Sprintf("%s_vchannelv%d", pchannel, offset)
t.Logf("remove vchannel, %s", vchannel)
c.Remove(vchannel)
offset--
assert.Equal(t, offset, c.NumConsumer())
assert.Equal(t, offset, c.NumTarget())
}
assert.Eventually(t, func() bool {
t.Logf("offset=%d, numConsumer=%d, numTarget=%d", offset, c.NumConsumer(), c.NumTarget())
return c.NumTarget() == offset
}, 3*time.Second, 10*time.Millisecond)
}
})
t.Run("test merge and split", func(t *testing.T) {
prefix := fmt.Sprintf("mock%d", time.Now().UnixNano())
ctx := context.Background()
c := NewDispatcherManager(prefix+"_pchannel_0", typeutil.ProxyRole, 1, newMockFactory())
paramtable.Get().Save(paramtable.Get().MQCfg.TargetBufSize.Key, "16")
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
pchannel := fmt.Sprintf("by-dev-rootcoord-dml_%d", rand.Int63())
factory := newMockFactory()
producer, err := newMockProducer(factory, pchannel)
assert.NoError(t, err)
go produceTimeTick(t, ctx, producer)
c := NewDispatcherManager(pchannel, typeutil.ProxyRole, 1, factory)
assert.NotNil(t, c)
_, err := c.Add(ctx, NewStreamConfig("mock_vchannel_0", nil, common.SubscriptionPositionUnknown))
assert.NoError(t, err)
_, err = c.Add(ctx, NewStreamConfig("mock_vchannel_1", nil, common.SubscriptionPositionUnknown))
assert.NoError(t, err)
_, err = c.Add(ctx, NewStreamConfig("mock_vchannel_2", nil, common.SubscriptionPositionUnknown))
assert.NoError(t, err)
assert.Equal(t, 3, c.NumConsumer())
assert.Equal(t, 3, c.NumTarget())
c.(*dispatcherManager).mainDispatcher.curTs.Store(1000)
c.(*dispatcherManager).mu.RLock()
for _, d := range c.(*dispatcherManager).soloDispatchers {
d.curTs.Store(1000)
}
c.(*dispatcherManager).mu.RUnlock()
c.(*dispatcherManager).tryMerge()
assert.Equal(t, 1, c.NumConsumer())
go c.Run()
defer c.Close()
paramtable.Get().Save(paramtable.Get().MQCfg.MaxTolerantLag.Key, "0.5")
defer paramtable.Get().Reset(paramtable.Get().MQCfg.MaxTolerantLag.Key)
o0, err := c.Add(ctx, NewStreamConfig(fmt.Sprintf("%s_vchannel-0", pchannel), nil, common.SubscriptionPositionUnknown))
assert.NoError(t, err)
o1, err := c.Add(ctx, NewStreamConfig(fmt.Sprintf("%s_vchannel-1", pchannel), nil, common.SubscriptionPositionUnknown))
assert.NoError(t, err)
o2, err := c.Add(ctx, NewStreamConfig(fmt.Sprintf("%s_vchannel-2", pchannel), nil, common.SubscriptionPositionUnknown))
assert.NoError(t, err)
assert.Equal(t, 3, c.NumTarget())
info := &target{
vchannel: "mock_vchannel_2",
pos: nil,
ch: nil,
consumeFn := func(output <-chan *MsgPack, done <-chan struct{}, wg *sync.WaitGroup) {
defer wg.Done()
for {
select {
case <-done:
return
case <-output:
}
}
}
c.(*dispatcherManager).split(info)
assert.Equal(t, 2, c.NumConsumer())
wg := &sync.WaitGroup{}
wg.Add(3)
d0 := make(chan struct{}, 1)
d1 := make(chan struct{}, 1)
d2 := make(chan struct{}, 1)
go consumeFn(o0, d0, wg)
go consumeFn(o1, d1, wg)
go consumeFn(o2, d2, wg)
assert.Eventually(t, func() bool {
return c.NumConsumer() == 1 // expected merge
}, 20*time.Second, 10*time.Millisecond)
// stop consuming vchannel_2 to trigger split
d2 <- struct{}{}
assert.Eventually(t, func() bool {
t.Logf("c.NumConsumer=%d", c.NumConsumer())
return c.NumConsumer() == 2 // expected split
}, 20*time.Second, 10*time.Millisecond)
// stop all
d0 <- struct{}{}
d1 <- struct{}{}
wg.Wait()
})
t.Run("test run and close", func(t *testing.T) {
prefix := fmt.Sprintf("mock%d", time.Now().UnixNano())
ctx := context.Background()
c := NewDispatcherManager(prefix+"_pchannel_0", typeutil.ProxyRole, 1, newMockFactory())
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
pchannel := fmt.Sprintf("by-dev-rootcoord-dml_%d", rand.Int63())
factory := newMockFactory()
producer, err := newMockProducer(factory, pchannel)
assert.NoError(t, err)
go produceTimeTick(t, ctx, producer)
c := NewDispatcherManager(pchannel, typeutil.ProxyRole, 1, factory)
assert.NotNil(t, c)
_, err := c.Add(ctx, NewStreamConfig("mock_vchannel_0", nil, common.SubscriptionPositionUnknown))
go c.Run()
defer c.Close()
_, err = c.Add(ctx, NewStreamConfig(fmt.Sprintf("%s_vchannel-0", pchannel), nil, common.SubscriptionPositionUnknown))
assert.NoError(t, err)
_, err = c.Add(ctx, NewStreamConfig("mock_vchannel_1", nil, common.SubscriptionPositionUnknown))
_, err = c.Add(ctx, NewStreamConfig(fmt.Sprintf("%s_vchannel-1", pchannel), nil, common.SubscriptionPositionUnknown))
assert.NoError(t, err)
_, err = c.Add(ctx, NewStreamConfig("mock_vchannel_2", nil, common.SubscriptionPositionUnknown))
_, err = c.Add(ctx, NewStreamConfig(fmt.Sprintf("%s_vchannel-2", pchannel), nil, common.SubscriptionPositionUnknown))
assert.NoError(t, err)
assert.Equal(t, 3, c.NumConsumer())
assert.Equal(t, 3, c.NumTarget())
assert.Eventually(t, func() bool {
return c.NumConsumer() >= 1
}, 3*time.Second, 10*time.Millisecond)
c.(*dispatcherManager).mainDispatcher.curTs.Store(1000)
c.(*dispatcherManager).mu.RLock()
for _, d := range c.(*dispatcherManager).soloDispatchers {
for _, d := range c.(*dispatcherManager).deputyDispatchers {
d.curTs.Store(1000)
}
c.(*dispatcherManager).mu.RUnlock()
checkIntervalK := paramtable.Get().MQCfg.MergeCheckInterval.Key
paramtable.Get().Save(checkIntervalK, "0.01")
defer paramtable.Get().Reset(checkIntervalK)
go c.Run()
assert.Eventually(t, func() bool {
return c.NumConsumer() == 1 // expected merged
}, 3*time.Second, 10*time.Millisecond)
assert.Equal(t, 3, c.NumTarget())
assert.NotPanics(t, func() {
c.Close()
})
})
t.Run("test add timeout", func(t *testing.T) {
prefix := fmt.Sprintf("mock%d", time.Now().UnixNano())
ctx := context.Background()
ctx, cancel := context.WithTimeout(ctx, time.Millisecond*2)
defer cancel()
time.Sleep(time.Millisecond * 2)
c := NewDispatcherManager(prefix+"_pchannel_0", typeutil.ProxyRole, 1, newMockFactory())
go c.Run()
assert.NotNil(t, c)
_, err := c.Add(ctx, NewStreamConfig("mock_vchannel_0", nil, common.SubscriptionPositionUnknown))
assert.Error(t, err)
_, err = c.Add(ctx, NewStreamConfig("mock_vchannel_1", nil, common.SubscriptionPositionUnknown))
assert.Error(t, err)
_, err = c.Add(ctx, NewStreamConfig("mock_vchannel_2", nil, common.SubscriptionPositionUnknown))
assert.Error(t, err)
assert.Equal(t, 0, c.NumConsumer())
assert.Equal(t, 0, c.NumTarget())
assert.NotPanics(t, func() {
c.Close()
})
})
t.Run("test_repeated_vchannel", func(t *testing.T) {
prefix := fmt.Sprintf("mock%d", time.Now().UnixNano())
c := NewDispatcherManager(prefix+"_pchannel_0", typeutil.ProxyRole, 1, newMockFactory())
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
pchannel := fmt.Sprintf("by-dev-rootcoord-dml_%d", rand.Int63())
factory := newMockFactory()
producer, err := newMockProducer(factory, pchannel)
assert.NoError(t, err)
go produceTimeTick(t, ctx, producer)
c := NewDispatcherManager(pchannel, typeutil.ProxyRole, 1, factory)
go c.Run()
defer c.Close()
assert.NotNil(t, c)
ctx := context.Background()
_, err := c.Add(ctx, NewStreamConfig("mock_vchannel_0", nil, common.SubscriptionPositionUnknown))
_, err = c.Add(ctx, NewStreamConfig(fmt.Sprintf("%s_vchannel-0", pchannel), nil, common.SubscriptionPositionUnknown))
assert.NoError(t, err)
_, err = c.Add(ctx, NewStreamConfig("mock_vchannel_1", nil, common.SubscriptionPositionUnknown))
_, err = c.Add(ctx, NewStreamConfig(fmt.Sprintf("%s_vchannel-1", pchannel), nil, common.SubscriptionPositionUnknown))
assert.NoError(t, err)
_, err = c.Add(ctx, NewStreamConfig("mock_vchannel_2", nil, common.SubscriptionPositionUnknown))
_, err = c.Add(ctx, NewStreamConfig(fmt.Sprintf("%s_vchannel-2", pchannel), nil, common.SubscriptionPositionUnknown))
assert.NoError(t, err)
_, err = c.Add(ctx, NewStreamConfig("mock_vchannel_0", nil, common.SubscriptionPositionUnknown))
_, err = c.Add(ctx, NewStreamConfig(fmt.Sprintf("%s_vchannel-0", pchannel), nil, common.SubscriptionPositionUnknown))
assert.Error(t, err)
_, err = c.Add(ctx, NewStreamConfig("mock_vchannel_1", nil, common.SubscriptionPositionUnknown))
_, err = c.Add(ctx, NewStreamConfig(fmt.Sprintf("%s_vchannel-1", pchannel), nil, common.SubscriptionPositionUnknown))
assert.Error(t, err)
_, err = c.Add(ctx, NewStreamConfig("mock_vchannel_2", nil, common.SubscriptionPositionUnknown))
_, err = c.Add(ctx, NewStreamConfig(fmt.Sprintf("%s_vchannel-2", pchannel), nil, common.SubscriptionPositionUnknown))
assert.Error(t, err)
assert.NotPanics(t, func() {
c.Close()
})
assert.Eventually(t, func() bool {
return c.NumConsumer() >= 1
}, 3*time.Second, 10*time.Millisecond)
})
}
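The tests above walk through the client-side flow of the manager. For reference, a condensed usage fragment (editorial; it assumes it runs inside the msgdispatcher package with a msgstream.Factory named factory in scope, and mirrors only calls that appear in the tests):
// Editorial fragment, not a test from this diff.
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
m := NewDispatcherManager("by-dev-rootcoord-dml_0", typeutil.DataNodeRole, 1, factory)
go m.Run()
defer m.Close()
// Each Add registers one vchannel target; targets on the same pchannel are
// served by the manager's shared (batched) MQ subscription.
out, err := m.Add(ctx, NewStreamConfig("by-dev-rootcoord-dml_0_123v0", nil, common.SubscriptionPositionEarliest))
if err != nil {
	panic(err)
}
go func() {
	for pack := range out {
		_ = pack // handle the MsgPack for this vchannel
	}
}()
// Drop the vchannel when it is no longer needed.
m.Remove("by-dev-rootcoord-dml_0_123v0")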
type vchannelHelper struct {
output <-chan *msgstream.MsgPack
pubInsMsgNum int
pubDelMsgNum int
pubDDLMsgNum int
pubPackNum int
subInsMsgNum int
subDelMsgNum int
subDDLMsgNum int
subPackNum int
}
type SimulationSuite struct {
suite.Suite
testVchannelNum int
manager DispatcherManager
pchannel string
vchannels map[string]*vchannelHelper
producer msgstream.MsgStream
factory msgstream.Factory
}
func (suite *SimulationSuite) SetupSuite() {
suite.factory = newMockFactory()
}
func (suite *SimulationSuite) SetupTest() {
suite.pchannel = fmt.Sprintf("by-dev-rootcoord-dispatcher-simulation-dml_%d", time.Now().UnixNano())
producer, err := newMockProducer(suite.factory, suite.pchannel)
assert.NoError(suite.T(), err)
suite.producer = producer
suite.manager = NewDispatcherManager(suite.pchannel, typeutil.DataNodeRole, 0, suite.factory)
go suite.manager.Run()
}
func (suite *SimulationSuite) produceMsg(wg *sync.WaitGroup, collectionID int64) {
defer wg.Done()
const timeTickCount = 100
var uniqueMsgID int64
vchannelKeys := reflect.ValueOf(suite.vchannels).MapKeys()
for i := 1; i <= timeTickCount; i++ {
// produce random insert
insNum := rand.Intn(10)
for j := 0; j < insNum; j++ {
vchannel := vchannelKeys[rand.Intn(len(vchannelKeys))].Interface().(string)
err := suite.producer.Produce(context.TODO(), &msgstream.MsgPack{
Msgs: []msgstream.TsMsg{genInsertMsg(rand.Intn(20)+1, vchannel, uniqueMsgID)},
})
assert.NoError(suite.T(), err)
uniqueMsgID++
suite.vchannels[vchannel].pubInsMsgNum++
}
// produce random delete
delNum := rand.Intn(2)
for j := 0; j < delNum; j++ {
vchannel := vchannelKeys[rand.Intn(len(vchannelKeys))].Interface().(string)
err := suite.producer.Produce(context.TODO(), &msgstream.MsgPack{
Msgs: []msgstream.TsMsg{genDeleteMsg(rand.Intn(20)+1, vchannel, uniqueMsgID)},
})
assert.NoError(suite.T(), err)
uniqueMsgID++
suite.vchannels[vchannel].pubDelMsgNum++
}
// produce random ddl
ddlNum := rand.Intn(2)
for j := 0; j < ddlNum; j++ {
err := suite.producer.Produce(context.TODO(), &msgstream.MsgPack{
Msgs: []msgstream.TsMsg{genDDLMsg(commonpb.MsgType_DropCollection, collectionID)},
})
assert.NoError(suite.T(), err)
for k := range suite.vchannels {
suite.vchannels[k].pubDDLMsgNum++
}
}
// produce time tick
ts := uint64(i * 100)
err := suite.producer.Produce(context.TODO(), &msgstream.MsgPack{
Msgs: []msgstream.TsMsg{genTimeTickMsg(ts)},
})
assert.NoError(suite.T(), err)
for k := range suite.vchannels {
suite.vchannels[k].pubPackNum++
}
}
suite.T().Logf("[%s] produce %d msgPack for %s done", time.Now(), timeTickCount, suite.pchannel)
}
func (suite *SimulationSuite) consumeMsg(ctx context.Context, wg *sync.WaitGroup, vchannel string) {
defer wg.Done()
var lastTs typeutil.Timestamp
for {
select {
case <-ctx.Done():
return
case pack := <-suite.vchannels[vchannel].output:
assert.Greater(suite.T(), pack.EndTs, lastTs)
lastTs = pack.EndTs
helper := suite.vchannels[vchannel]
helper.subPackNum++
for _, msg := range pack.Msgs {
switch msg.Type() {
case commonpb.MsgType_Insert:
helper.subInsMsgNum++
case commonpb.MsgType_Delete:
helper.subDelMsgNum++
case commonpb.MsgType_CreateCollection, commonpb.MsgType_DropCollection,
commonpb.MsgType_CreatePartition, commonpb.MsgType_DropPartition:
helper.subDDLMsgNum++
}
}
}
}
}
func (suite *SimulationSuite) produceTimeTickOnly(ctx context.Context) {
tt := 1
ticker := time.NewTicker(100 * time.Millisecond)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
ts := uint64(tt * 1000)
err := suite.producer.Produce(ctx, &msgstream.MsgPack{
Msgs: []msgstream.TsMsg{genTimeTickMsg(ts)},
})
assert.NoError(suite.T(), err)
tt++
}
}
}
func (suite *SimulationSuite) TestDispatchToVchannels() {
ctx, cancel := context.WithTimeout(context.Background(), 5000*time.Millisecond)
defer cancel()
const (
vchannelNum = 10
collectionID int64 = 1234
)
suite.vchannels = make(map[string]*vchannelHelper, vchannelNum)
for i := 0; i < vchannelNum; i++ {
vchannel := fmt.Sprintf("%s_%dv%d", suite.pchannel, collectionID, i)
output, err := suite.manager.Add(context.Background(), NewStreamConfig(vchannel, nil, common.SubscriptionPositionEarliest))
assert.NoError(suite.T(), err)
suite.vchannels[vchannel] = &vchannelHelper{output: output}
}
wg := &sync.WaitGroup{}
wg.Add(1)
go suite.produceMsg(wg, collectionID)
wg.Wait()
for vchannel := range suite.vchannels {
wg.Add(1)
go suite.consumeMsg(ctx, wg, vchannel)
}
wg.Wait()
for vchannel, helper := range suite.vchannels {
msg := fmt.Sprintf("vchannel=%s", vchannel)
assert.Equal(suite.T(), helper.pubInsMsgNum, helper.subInsMsgNum, msg)
assert.Equal(suite.T(), helper.pubDelMsgNum, helper.subDelMsgNum, msg)
assert.Equal(suite.T(), helper.pubDDLMsgNum, helper.subDDLMsgNum, msg)
assert.Equal(suite.T(), helper.pubPackNum, helper.subPackNum, msg)
}
}
func (suite *SimulationSuite) TestMerge() {
ctx, cancel := context.WithCancel(context.Background())
go suite.produceTimeTickOnly(ctx)
const vchannelNum = 10
suite.vchannels = make(map[string]*vchannelHelper, vchannelNum)
positions, err := getSeekPositions(suite.factory, suite.pchannel, 100)
assert.NoError(suite.T(), err)
assert.NotEqual(suite.T(), 0, len(positions))
for i := 0; i < vchannelNum; i++ {
vchannel := fmt.Sprintf("%s_vchannelv%d", suite.pchannel, i)
output, err := suite.manager.Add(context.Background(), NewStreamConfig(
vchannel, positions[rand.Intn(len(positions))],
common.SubscriptionPositionUnknown,
)) // seek from random position
assert.NoError(suite.T(), err)
suite.vchannels[vchannel] = &vchannelHelper{output: output}
}
wg := &sync.WaitGroup{}
for vchannel := range suite.vchannels {
wg.Add(1)
go suite.consumeMsg(ctx, wg, vchannel)
}
suite.Eventually(func() bool {
suite.T().Logf("dispatcherManager.dispatcherNum = %d", suite.manager.NumConsumer())
return suite.manager.NumConsumer() == 1 // expected all merged, only mainDispatcher exist
}, 15*time.Second, 100*time.Millisecond)
assert.Equal(suite.T(), vchannelNum, suite.manager.NumTarget())
cancel()
wg.Wait()
}
func (suite *SimulationSuite) TestSplit() {
ctx, cancel := context.WithCancel(context.Background())
go suite.produceTimeTickOnly(ctx)
const (
vchannelNum = 10
splitNum = 3
)
suite.vchannels = make(map[string]*vchannelHelper, vchannelNum)
maxTolerantLagK := paramtable.Get().MQCfg.MaxTolerantLag.Key
paramtable.Get().Save(maxTolerantLagK, "0.5")
defer paramtable.Get().Reset(maxTolerantLagK)
targetBufSizeK := paramtable.Get().MQCfg.TargetBufSize.Key
defer paramtable.Get().Reset(targetBufSizeK)
for i := 0; i < vchannelNum; i++ {
paramtable.Get().Save(targetBufSizeK, "65536")
if i >= vchannelNum-splitNum {
paramtable.Get().Save(targetBufSizeK, "10")
}
vchannel := fmt.Sprintf("%s_vchannelv%d", suite.pchannel, i)
_, err := suite.manager.Add(context.Background(), NewStreamConfig(vchannel, nil, common.SubscriptionPositionEarliest))
assert.NoError(suite.T(), err)
}
suite.Eventually(func() bool {
suite.T().Logf("dispatcherManager.dispatcherNum = %d, splitNum+1 = %d", suite.manager.NumConsumer(), splitNum+1)
return suite.manager.NumConsumer() == splitNum+1 // expected 1 mainDispatcher and `splitNum` soloDispatchers
}, 10*time.Second, 100*time.Millisecond)
assert.Equal(suite.T(), vchannelNum, suite.manager.NumTarget())
cancel()
}
func (suite *SimulationSuite) TearDownTest() {
for vchannel := range suite.vchannels {
suite.manager.Remove(vchannel)
}
suite.manager.Close()
}
func (suite *SimulationSuite) TearDownSuite() {
}
func TestSimulation(t *testing.T) {
suite.Run(t, new(SimulationSuite))
}

View File

@ -21,14 +21,20 @@ import (
"fmt"
"math/rand"
"os"
"sync"
"testing"
"time"
"github.com/samber/lo"
"github.com/stretchr/testify/assert"
"go.uber.org/atomic"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/pkg/v2/mq/common"
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
@ -55,34 +61,11 @@ func newMockProducer(factory msgstream.Factory, pchannel string) (msgstream.MsgS
if err != nil {
return nil, err
}
stream.AsProducer(context.TODO(), []string{pchannel})
stream.AsProducer(context.Background(), []string{pchannel})
stream.SetRepackFunc(defaultInsertRepackFunc)
return stream, nil
}
func getSeekPositions(factory msgstream.Factory, pchannel string, maxNum int) ([]*msgstream.MsgPosition, error) {
stream, err := factory.NewTtMsgStream(context.Background())
if err != nil {
return nil, err
}
defer stream.Close()
stream.AsConsumer(context.TODO(), []string{pchannel}, fmt.Sprintf("%d", rand.Int()), common.SubscriptionPositionEarliest)
positions := make([]*msgstream.MsgPosition, 0)
timeoutCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
for {
select {
case <-timeoutCtx.Done(): // no message to consume
return positions, nil
case pack := <-stream.Chan():
positions = append(positions, pack.EndPositions[0])
if len(positions) >= maxNum {
return positions, nil
}
}
}
}
func genPKs(numRows int) []typeutil.IntPrimaryKey {
ids := make([]typeutil.IntPrimaryKey, numRows)
for i := 0; i < numRows; i++ {
@ -91,15 +74,15 @@ func genPKs(numRows int) []typeutil.IntPrimaryKey {
return ids
}
func genTimestamps(numRows int) []typeutil.Timestamp {
ts := make([]typeutil.Timestamp, numRows)
func genTimestamps(numRows int, ts typeutil.Timestamp) []typeutil.Timestamp {
tss := make([]typeutil.Timestamp, numRows)
for i := 0; i < numRows; i++ {
ts[i] = typeutil.Timestamp(i + 1)
tss[i] = ts
}
return ts
return tss
}
func genInsertMsg(numRows int, vchannel string, msgID typeutil.UniqueID) *msgstream.InsertMsg {
func genInsertMsg(numRows int, vchannel string, msgID typeutil.UniqueID, ts typeutil.Timestamp) *msgstream.InsertMsg {
floatVec := make([]float32, numRows*dim)
for i := 0; i < numRows*dim; i++ {
floatVec[i] = rand.Float32()
@ -111,9 +94,9 @@ func genInsertMsg(numRows int, vchannel string, msgID typeutil.UniqueID) *msgstr
return &msgstream.InsertMsg{
BaseMsg: msgstream.BaseMsg{HashValues: hashValues},
InsertRequest: &msgpb.InsertRequest{
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_Insert, MsgID: msgID},
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_Insert, MsgID: msgID, Timestamp: ts},
ShardName: vchannel,
Timestamps: genTimestamps(numRows),
Timestamps: genTimestamps(numRows, ts),
RowIDs: genPKs(numRows),
FieldsData: []*schemapb.FieldData{{
Field: &schemapb.FieldData_Vectors{
@ -129,11 +112,11 @@ func genInsertMsg(numRows int, vchannel string, msgID typeutil.UniqueID) *msgstr
}
}
func genDeleteMsg(numRows int, vchannel string, msgID typeutil.UniqueID) *msgstream.DeleteMsg {
func genDeleteMsg(numRows int, vchannel string, msgID typeutil.UniqueID, ts typeutil.Timestamp) *msgstream.DeleteMsg {
return &msgstream.DeleteMsg{
BaseMsg: msgstream.BaseMsg{HashValues: make([]uint32, numRows)},
DeleteRequest: &msgpb.DeleteRequest{
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_Delete, MsgID: msgID},
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_Delete, MsgID: msgID, Timestamp: ts},
ShardName: vchannel,
PrimaryKeys: &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
@ -142,19 +125,19 @@ func genDeleteMsg(numRows int, vchannel string, msgID typeutil.UniqueID) *msgstr
},
},
},
Timestamps: genTimestamps(numRows),
Timestamps: genTimestamps(numRows, ts),
NumRows: int64(numRows),
},
}
}
func genDDLMsg(msgType commonpb.MsgType, collectionID int64) msgstream.TsMsg {
func genDDLMsg(msgType commonpb.MsgType, collectionID int64, ts typeutil.Timestamp) msgstream.TsMsg {
switch msgType {
case commonpb.MsgType_CreateCollection:
return &msgstream.CreateCollectionMsg{
BaseMsg: msgstream.BaseMsg{HashValues: []uint32{0}},
CreateCollectionRequest: &msgpb.CreateCollectionRequest{
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_CreateCollection},
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_CreateCollection, Timestamp: ts},
CollectionID: collectionID,
},
}
@ -162,7 +145,7 @@ func genDDLMsg(msgType commonpb.MsgType, collectionID int64) msgstream.TsMsg {
return &msgstream.DropCollectionMsg{
BaseMsg: msgstream.BaseMsg{HashValues: []uint32{0}},
DropCollectionRequest: &msgpb.DropCollectionRequest{
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_DropCollection},
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_DropCollection, Timestamp: ts},
CollectionID: collectionID,
},
}
@ -170,7 +153,7 @@ func genDDLMsg(msgType commonpb.MsgType, collectionID int64) msgstream.TsMsg {
return &msgstream.CreatePartitionMsg{
BaseMsg: msgstream.BaseMsg{HashValues: []uint32{0}},
CreatePartitionRequest: &msgpb.CreatePartitionRequest{
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_CreatePartition},
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_CreatePartition, Timestamp: ts},
CollectionID: collectionID,
},
}
@ -178,7 +161,7 @@ func genDDLMsg(msgType commonpb.MsgType, collectionID int64) msgstream.TsMsg {
return &msgstream.DropPartitionMsg{
BaseMsg: msgstream.BaseMsg{HashValues: []uint32{0}},
DropPartitionRequest: &msgpb.DropPartitionRequest{
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_DropPartition},
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_DropPartition, Timestamp: ts},
CollectionID: collectionID,
},
}
@ -226,3 +209,196 @@ func defaultInsertRepackFunc(
}
return pack, nil
}
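Since the generators now carry an explicit timestamp, a producer can emit a data pack and then a time tick at the same ts. A small editorial example (producer, factory, and pchannel assumed from newMockProducer above; not part of this diff):
ts := uint64(100)
vchannel := pchannel + "_123v0"
// data messages stamped with ts
err := producer.Produce(context.Background(), &msgstream.MsgPack{
	Msgs: []msgstream.TsMsg{
		genInsertMsg(8, vchannel, 1, ts),
		genDeleteMsg(2, vchannel, 2, ts),
	},
})
// a trailing time tick at ts lets the tt msgstream cut a MsgPack up to ts
if err == nil {
	err = producer.Produce(context.Background(), &msgstream.MsgPack{
		Msgs: []msgstream.TsMsg{genTimeTickMsg(ts)},
	})
}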
type vchannelHelper struct {
output <-chan *msgstream.MsgPack
pubInsMsgNum atomic.Int32
pubDelMsgNum atomic.Int32
pubDDLMsgNum atomic.Int32
pubPackNum atomic.Int32
subInsMsgNum atomic.Int32
subDelMsgNum atomic.Int32
subDDLMsgNum atomic.Int32
subPackNum atomic.Int32
seekPos *Pos
skippedInsMsgNum int32
skippedDelMsgNum int32
skippedDDLMsgNum int32
skippedPackNum int32
}
func produceMsgs(t *testing.T, ctx context.Context, wg *sync.WaitGroup, producer msgstream.MsgStream, vchannels map[string]*vchannelHelper) {
defer wg.Done()
uniqueMsgID := int64(0)
vchannelNames := lo.Keys(vchannels)
ticker := time.NewTicker(10 * time.Millisecond)
defer ticker.Stop()
i := 1
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
ts := uint64(i * 100)
// produce random insert
insNum := rand.Intn(10)
for j := 0; j < insNum; j++ {
vchannel := vchannelNames[rand.Intn(len(vchannels))]
err := producer.Produce(context.Background(), &msgstream.MsgPack{
Msgs: []msgstream.TsMsg{genInsertMsg(rand.Intn(20)+1, vchannel, uniqueMsgID, ts)},
})
assert.NoError(t, err)
uniqueMsgID++
vchannels[vchannel].pubInsMsgNum.Inc()
}
// produce random delete
delNum := rand.Intn(2)
for j := 0; j < delNum; j++ {
vchannel := vchannelNames[rand.Intn(len(vchannels))]
err := producer.Produce(context.Background(), &msgstream.MsgPack{
Msgs: []msgstream.TsMsg{genDeleteMsg(rand.Intn(10)+1, vchannel, uniqueMsgID, ts)},
})
assert.NoError(t, err)
uniqueMsgID++
vchannels[vchannel].pubDelMsgNum.Inc()
}
// produce random ddl
ddlNum := rand.Intn(2)
for j := 0; j < ddlNum; j++ {
vchannel := vchannelNames[rand.Intn(len(vchannels))]
collectionID := funcutil.GetCollectionIDFromVChannel(vchannel)
err := producer.Produce(context.Background(), &msgstream.MsgPack{
Msgs: []msgstream.TsMsg{genDDLMsg(commonpb.MsgType_DropCollection, collectionID, ts)},
})
assert.NoError(t, err)
uniqueMsgID++
vchannels[vchannel].pubDDLMsgNum.Inc()
}
// produce time tick
err := producer.Produce(context.Background(), &msgstream.MsgPack{
Msgs: []msgstream.TsMsg{genTimeTickMsg(ts)},
})
assert.NoError(t, err)
for k := range vchannels {
vchannels[k].pubPackNum.Inc()
}
i++
}
}
}
func consumeMsgsFromTargets(t *testing.T, ctx context.Context, wg *sync.WaitGroup, vchannel string, helper *vchannelHelper) {
defer wg.Done()
var lastTs typeutil.Timestamp
for {
select {
case <-ctx.Done():
return
case pack := <-helper.output:
if pack == nil || pack.EndTs == 0 {
continue
}
assert.Greater(t, pack.EndTs, lastTs, fmt.Sprintf("vchannel=%s", vchannel))
lastTs = pack.EndTs
helper.subPackNum.Inc()
for _, msg := range pack.Msgs {
switch msg.Type() {
case commonpb.MsgType_Insert:
helper.subInsMsgNum.Inc()
case commonpb.MsgType_Delete:
helper.subDelMsgNum.Inc()
case commonpb.MsgType_CreateCollection, commonpb.MsgType_DropCollection,
commonpb.MsgType_CreatePartition, commonpb.MsgType_DropPartition:
helper.subDDLMsgNum.Inc()
}
}
}
}
}
func produceTimeTick(t *testing.T, ctx context.Context, producer msgstream.MsgStream) {
tt := 1
ticker := time.NewTicker(10 * time.Millisecond)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
ts := uint64(tt * 1000)
err := producer.Produce(ctx, &msgstream.MsgPack{
Msgs: []msgstream.TsMsg{genTimeTickMsg(ts)},
})
assert.NoError(t, err)
tt++
}
}
}
func getRandomSeekPositions(t *testing.T, ctx context.Context, factory msgstream.Factory, pchannel string, vchannels map[string]*vchannelHelper) {
stream, err := factory.NewTtMsgStream(context.Background())
assert.NoError(t, err)
defer stream.Close()
err = stream.AsConsumer(context.Background(), []string{pchannel}, fmt.Sprintf("%d", rand.Int()), common.SubscriptionPositionEarliest)
assert.NoError(t, err)
for {
select {
case <-ctx.Done():
return
case pack := <-stream.Chan():
for _, msg := range pack.Msgs {
switch msg.Type() {
case commonpb.MsgType_Insert:
vchannel := msg.(*msgstream.InsertMsg).GetShardName()
if vchannels[vchannel].seekPos == nil {
vchannels[vchannel].skippedInsMsgNum++
}
case commonpb.MsgType_Delete:
vchannel := msg.(*msgstream.DeleteMsg).GetShardName()
if vchannels[vchannel].seekPos == nil {
vchannels[vchannel].skippedDelMsgNum++
}
case commonpb.MsgType_DropCollection:
collectionID := msg.(*msgstream.DropCollectionMsg).GetCollectionID()
for vchannel := range vchannels {
if vchannels[vchannel].seekPos == nil &&
funcutil.GetCollectionIDFromVChannel(vchannel) == collectionID {
vchannels[vchannel].skippedDDLMsgNum++
}
}
}
}
for _, helper := range vchannels {
if helper.seekPos == nil {
helper.skippedPackNum++
}
}
if rand.Intn(5) == 0 { // assign random seek position
for _, helper := range vchannels {
if helper.seekPos == nil {
helper.seekPos = pack.EndPositions[0]
break
}
}
}
allAssigned := true
for _, helper := range vchannels {
if helper.seekPos == nil {
allAssigned = false
break
}
}
if allAssigned {
return // all seek positions have been assigned
}
}
}
}
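The concrete test that wires these helpers together is not shown in this excerpt; the following is an editorial sketch of a plausible composition (manager, producer, and factory assumed in scope), where each vchannel seeks from a random position and the pub/sub/skipped counters reconcile at the end:
// Editorial sketch, not a test from this diff.
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
vchannels := map[string]*vchannelHelper{
	fmt.Sprintf("%s_100v0", pchannel): {},
	fmt.Sprintf("%s_101v1", pchannel): {},
}
wg := &sync.WaitGroup{}
wg.Add(1)
go produceMsgs(t, ctx, wg, producer, vchannels)
// assign a random seek position per vchannel, counting what was skipped before it
getRandomSeekPositions(t, ctx, factory, pchannel, vchannels)
for vchannel, helper := range vchannels {
	out, err := manager.Add(ctx, NewStreamConfig(vchannel, helper.seekPos, common.SubscriptionPositionUnknown))
	assert.NoError(t, err)
	helper.output = out
	wg.Add(1)
	go consumeMsgsFromTargets(t, ctx, wg, vchannel, helper)
}
// e.g. eventually: pubInsMsgNum.Load() == subInsMsgNum.Load() + skippedInsMsgNum
cancel()
wg.Wait()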

View File

@ -32,6 +32,7 @@ import (
type target struct {
vchannel string
ch chan *MsgPack
subPos SubPos
pos *Pos
closeMu sync.Mutex
@ -44,12 +45,14 @@ type target struct {
cancelCh lifetime.SafeChan
}
func newTarget(vchannel string, pos *Pos, replicateConfig *msgstream.ReplicateConfig) *target {
func newTarget(streamConfig *StreamConfig) *target {
replicateConfig := streamConfig.ReplicateConfig
maxTolerantLag := paramtable.Get().MQCfg.MaxTolerantLag.GetAsDuration(time.Second)
t := &target{
vchannel: vchannel,
vchannel: streamConfig.VChannel,
ch: make(chan *MsgPack, paramtable.Get().MQCfg.TargetBufSize.GetAsInt()),
pos: pos,
subPos: streamConfig.SubPos,
pos: streamConfig.Pos,
cancelCh: lifetime.NewSafeChan(),
maxLag: maxTolerantLag,
timer: time.NewTimer(maxTolerantLag),
@ -58,7 +61,7 @@ func newTarget(vchannel string, pos *Pos, replicateConfig *msgstream.ReplicateCo
t.closed = false
if replicateConfig != nil {
log.Info("have replicate config",
zap.String("vchannel", vchannel),
zap.String("vchannel", streamConfig.VChannel),
zap.String("replicateID", replicateConfig.ReplicateID))
}
return t
@ -72,6 +75,7 @@ func (t *target) close() {
t.closed = true
t.timer.Stop()
close(t.ch)
log.Info("close target chan", zap.String("vchannel", t.vchannel))
})
}
@ -94,7 +98,7 @@ func (t *target) send(pack *MsgPack) error {
log.Info("target closed", zap.String("vchannel", t.vchannel))
return nil
case <-t.timer.C:
return fmt.Errorf("send target timeout, vchannel=%s, timeout=%s", t.vchannel, t.maxLag)
return fmt.Errorf("send target timeout, vchannel=%s, timeout=%s, beginTs=%d, endTs=%d", t.vchannel, t.maxLag, pack.BeginTs, pack.EndTs)
case t.ch <- pack:
return nil
}

View File

@ -14,7 +14,10 @@ import (
)
func TestSendTimeout(t *testing.T) {
target := newTarget("test1", &msgpb.MsgPosition{}, nil)
target := newTarget(&StreamConfig{
VChannel: "test1",
Pos: &msgpb.MsgPosition{},
})
time.Sleep(paramtable.Get().MQCfg.MaxTolerantLag.GetAsDuration(time.Second))

View File

@ -637,7 +637,7 @@ func (ms *MqTtMsgStream) AsConsumer(ctx context.Context, channels []string, subN
return errors.Wrapf(err, errMsg)
}
panic(fmt.Sprintf("%s, errors = %s", errMsg, err.Error()))
panic(fmt.Sprintf("%s, subName = %s, errors = %s", errMsg, subName, err.Error()))
}
}

View File

@ -140,11 +140,10 @@ var (
ErrMetricNotFound = newMilvusError("metric not found", 1200, false)
// Message queue related
ErrMqTopicNotFound = newMilvusError("topic not found", 1300, false)
ErrMqTopicNotEmpty = newMilvusError("topic not empty", 1301, false)
ErrMqInternal = newMilvusError("message queue internal error", 1302, false)
ErrDenyProduceMsg = newMilvusError("deny to write the message to mq", 1303, false)
ErrTooManyConsumers = newMilvusError("consumer number limit exceeded", 1304, false)
ErrMqTopicNotFound = newMilvusError("topic not found", 1300, false)
ErrMqTopicNotEmpty = newMilvusError("topic not empty", 1301, false)
ErrMqInternal = newMilvusError("message queue internal error", 1302, false)
ErrDenyProduceMsg = newMilvusError("deny to write the message to mq", 1303, false)
// Privilege related
// this operation is denied because the user not authorized, user need to login in first

View File

@ -147,7 +147,6 @@ func (s *ErrSuite) TestWrap() {
s.ErrorIs(WrapErrMqTopicNotFound("unknown", "failed to get topic"), ErrMqTopicNotFound)
s.ErrorIs(WrapErrMqTopicNotEmpty("unknown", "topic is not empty"), ErrMqTopicNotEmpty)
s.ErrorIs(WrapErrMqInternal(errors.New("unknown"), "failed to consume"), ErrMqInternal)
s.ErrorIs(WrapErrTooManyConsumers("unknown", "too many consumers"), ErrTooManyConsumers)
// field related
s.ErrorIs(WrapErrFieldNotFound("meta", "failed to get field"), ErrFieldNotFound)

View File

@ -1000,14 +1000,6 @@ func WrapErrMqInternal(err error, msg ...string) error {
return err
}
func WrapErrTooManyConsumers(vchannel string, msg ...string) error {
err := wrapFields(ErrTooManyConsumers, value("vchannel", vchannel))
if len(msg) > 0 {
err = errors.Wrap(err, strings.Join(msg, "->"))
}
return err
}
func WrapErrPrivilegeNotAuthenticated(fmt string, args ...any) error {
err := errors.Wrapf(ErrPrivilegeNotAuthenticated, fmt, args...)
return err

View File

@ -529,12 +529,10 @@ type MQConfig struct {
IgnoreBadPosition ParamItem `refreshable:"true"`
// msgdispatcher
MergeCheckInterval ParamItem `refreshable:"false"`
TargetBufSize ParamItem `refreshable:"false"`
MaxTolerantLag ParamItem `refreshable:"true"`
MaxDispatcherNumPerPchannel ParamItem `refreshable:"true"`
RetrySleep ParamItem `refreshable:"true"`
RetryTimeout ParamItem `refreshable:"true"`
MergeCheckInterval ParamItem `refreshable:"false"`
TargetBufSize ParamItem `refreshable:"false"`
MaxTolerantLag ParamItem `refreshable:"true"`
MaxPositionTsGap ParamItem `refreshable:"true"`
}
// Init initializes the MQConfig object with a BaseTable.
@ -558,33 +556,6 @@ Valid values: [default, pulsar, kafka, rocksmq, natsmq]`,
}
p.MaxTolerantLag.Init(base.mgr)
p.MaxDispatcherNumPerPchannel = ParamItem{
Key: "mq.dispatcher.maxDispatcherNumPerPchannel",
Version: "2.4.19",
DefaultValue: "5",
Doc: `The maximum number of dispatchers per physical channel, primarily to limit the number of consumers and prevent performance issues(e.g., during recovery when a large number of channels are watched).`,
Export: true,
}
p.MaxDispatcherNumPerPchannel.Init(base.mgr)
p.RetrySleep = ParamItem{
Key: "mq.dispatcher.retrySleep",
Version: "2.4.19",
DefaultValue: "3",
Doc: `register retry sleep time in seconds`,
Export: true,
}
p.RetrySleep.Init(base.mgr)
p.RetryTimeout = ParamItem{
Key: "mq.dispatcher.retryTimeout",
Version: "2.4.19",
DefaultValue: "60",
Doc: `register retry timeout in seconds`,
Export: true,
}
p.RetryTimeout.Init(base.mgr)
p.TargetBufSize = ParamItem{
Key: "mq.dispatcher.targetBufSize",
Version: "2.4.4",
@ -603,6 +574,14 @@ Valid values: [default, pulsar, kafka, rocksmq, natsmq]`,
}
p.MergeCheckInterval.Init(base.mgr)
p.MaxPositionTsGap = ParamItem{
Key: "mq.dispatcher.maxPositionGapInMinutes",
Version: "2.5",
DefaultValue: "60",
Doc: `The max position timestamp gap in minutes.`,
}
p.MaxPositionTsGap.Init(base.mgr)
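MaxPositionTsGap is only declared and wired here. As an editorial, assumed illustration (not code from this change), such a knob could be consulted on the dispatcher side to decide whether a vchannel's checkpoint is close enough to join a batch subscription; pos and mainTs below are hypothetical:
maxGap := paramtable.Get().MQCfg.MaxPositionTsGap.GetAsDuration(time.Minute)
candidate := tsoutil.PhysicalTime(pos.GetTimestamp()) // hypothetical vchannel checkpoint
current := tsoutil.PhysicalTime(mainTs)               // hypothetical main dispatcher ts
if current.Sub(candidate) > maxGap {
	// gap too large: handle this vchannel outside the shared subscription
}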
p.EnablePursuitMode = ParamItem{
Key: "mq.enablePursuitMode",
Version: "2.3.0",

View File

@ -37,9 +37,7 @@ func TestServiceParam(t *testing.T) {
assert.Equal(t, 1*time.Second, Params.MergeCheckInterval.GetAsDuration(time.Second))
assert.Equal(t, 16, Params.TargetBufSize.GetAsInt())
assert.Equal(t, 3*time.Second, Params.MaxTolerantLag.GetAsDuration(time.Second))
assert.Equal(t, 5, Params.MaxDispatcherNumPerPchannel.GetAsInt())
assert.Equal(t, 3*time.Second, Params.RetrySleep.GetAsDuration(time.Second))
assert.Equal(t, 60*time.Second, Params.RetryTimeout.GetAsDuration(time.Second))
assert.Equal(t, 60*time.Minute, Params.MaxPositionTsGap.GetAsDuration(time.Minute))
})
t.Run("test etcdConfig", func(t *testing.T) {

View File

@ -474,7 +474,7 @@ func TestSearchGroupByUnsupportedDataType(t *testing.T) {
common.DefaultFloatFieldName, common.DefaultDoubleFieldName,
common.DefaultJSONFieldName, common.DefaultFloatVecFieldName, common.DefaultInt8ArrayField, common.DefaultFloatArrayField,
} {
_, err := mc.Search(ctx, client.NewSearchOption(collName, common.DefaultLimit, queryVec).WithGroupByField(unsupportedField).WithANNSField(common.DefaultFloatVecFieldName))
_, err := mc.Search(ctx, client.NewSearchOption(collName, common.DefaultLimit, queryVec).WithGroupByField(unsupportedField).WithANNSField(common.DefaultFloatVecFieldName).WithConsistencyLevel(entity.ClStrong))
common.CheckErr(t, err, false, "unsupported data type")
}
}
@ -495,7 +495,7 @@ func TestSearchGroupByRangeSearch(t *testing.T) {
// range search
_, err := mc.Search(ctx, client.NewSearchOption(collName, common.DefaultLimit, queryVec).WithGroupByField(common.DefaultVarcharFieldName).
WithANNSField(common.DefaultFloatVecFieldName).WithSearchParam("radius", "0").WithSearchParam("range_filter", "0.8"))
WithANNSField(common.DefaultFloatVecFieldName).WithSearchParam("radius", "0").WithSearchParam("range_filter", "0.8").WithConsistencyLevel(entity.ClStrong))
common.CheckErr(t, err, false, "Not allowed to do range-search when doing search-group-by")
}

View File

@ -268,7 +268,7 @@ func TestHybridSearchMultiVectorsPagination(t *testing.T) {
// offset 0, -1 -> 0
for _, offset := range []int{0, -1} {
searchRes, err := mc.HybridSearch(ctx, client.NewHybridSearchOption(schema.CollectionName, common.DefaultLimit, annReqDef).WithOffset(offset))
searchRes, err := mc.HybridSearch(ctx, client.NewHybridSearchOption(schema.CollectionName, common.DefaultLimit, annReqDef).WithOffset(offset).WithConsistencyLevel(entity.ClStrong))
common.CheckErr(t, err, true)
common.CheckSearchResult(t, searchRes, common.DefaultNq, common.DefaultLimit)
}

View File

@ -65,14 +65,14 @@ func TestQueryVarcharPkDefault(t *testing.T) {
// query
expr := fmt.Sprintf("%s in ['0', '1', '2', '3', '4']", common.DefaultVarcharFieldName)
queryRes, err := mc.Query(ctx, client.NewQueryOption(schema.CollectionName).WithFilter(expr))
queryRes, err := mc.Query(ctx, client.NewQueryOption(schema.CollectionName).WithFilter(expr).WithConsistencyLevel(entity.ClStrong))
common.CheckErr(t, err, true)
common.CheckQueryResult(t, queryRes.Fields, []column.Column{insertRes.IDs.Slice(0, 5)})
// get by ids -> same result as query
varcharValues := []string{"0", "1", "2", "3", "4"}
ids := column.NewColumnVarChar(common.DefaultVarcharFieldName, varcharValues)
getRes, errGet := mc.Get(ctx, client.NewQueryOption(schema.CollectionName).WithIDs(ids))
getRes, errGet := mc.Get(ctx, client.NewQueryOption(schema.CollectionName).WithIDs(ids).WithConsistencyLevel(entity.ClStrong))
common.CheckErr(t, errGet, true)
common.CheckQueryResult(t, getRes.Fields, []column.Column{insertRes.IDs.Slice(0, 5)})
}
@ -1094,12 +1094,12 @@ func TestQueryWithTemplateParam(t *testing.T) {
}
// default
queryRes, err := mc.Query(ctx, client.NewQueryOption(schema.CollectionName).
WithFilter(fmt.Sprintf("%s in {int64Values}", common.DefaultInt64FieldName)).WithTemplateParam("int64Values", int64Values))
WithFilter(fmt.Sprintf("%s in {int64Values}", common.DefaultInt64FieldName)).WithTemplateParam("int64Values", int64Values).WithConsistencyLevel(entity.ClStrong))
common.CheckErr(t, err, true)
common.CheckQueryResult(t, queryRes.Fields, []column.Column{column.NewColumnInt64(common.DefaultInt64FieldName, int64Values)})
// duplicate template keys: the later value overrides
res, err := mc.Query(ctx, client.NewQueryOption(schema.CollectionName).WithFilter("int64 < {k2}").WithTemplateParam("k2", 10).WithTemplateParam("k2", 5))
res, err := mc.Query(ctx, client.NewQueryOption(schema.CollectionName).WithFilter("int64 < {k2}").WithTemplateParam("k2", 10).WithTemplateParam("k2", 5).WithConsistencyLevel(entity.ClStrong))
common.CheckErr(t, err, true)
require.Equal(t, 5, res.ResultCount)
@ -1107,14 +1107,14 @@ func TestQueryWithTemplateParam(t *testing.T) {
anyValues := []int64{0.0, 100.0, 10000.0}
countRes, err := mc.Query(ctx, client.NewQueryOption(schema.CollectionName).
WithFilter(fmt.Sprintf("json_contains_any (%s, {any_values})", common.DefaultFloatArrayField)).WithTemplateParam("any_values", anyValues).
WithOutputFields(common.QueryCountFieldName))
WithOutputFields(common.QueryCountFieldName).WithConsistencyLevel(entity.ClStrong))
common.CheckErr(t, err, true)
count, _ := countRes.Fields[0].GetAsInt64(0)
require.EqualValues(t, 101, count)
// dynamic
countRes, err = mc.Query(ctx, client.NewQueryOption(schema.CollectionName).
WithFilter("dynamicNumber % 2 == {v}").WithTemplateParam("v", 0).WithOutputFields(common.QueryCountFieldName))
WithFilter("dynamicNumber % 2 == {v}").WithTemplateParam("v", 0).WithOutputFields(common.QueryCountFieldName).WithConsistencyLevel(entity.ClStrong))
common.CheckErr(t, err, true)
count, _ = countRes.Fields[0].GetAsInt64(0)
require.EqualValues(t, 1500, count)
@ -1123,7 +1123,8 @@ func TestQueryWithTemplateParam(t *testing.T) {
countRes, err = mc.Query(ctx, client.NewQueryOption(schema.CollectionName).
WithFilter(fmt.Sprintf("%s['bool'] == {v}", common.DefaultJSONFieldName)).
WithTemplateParam("v", false).
WithOutputFields(common.QueryCountFieldName))
WithOutputFields(common.QueryCountFieldName).
WithConsistencyLevel(entity.ClStrong))
common.CheckErr(t, err, true)
count, _ = countRes.Fields[0].GetAsInt64(0)
require.EqualValues(t, 1500/2, count)
@ -1132,7 +1133,8 @@ func TestQueryWithTemplateParam(t *testing.T) {
countRes, err = mc.Query(ctx, client.NewQueryOption(schema.CollectionName).
WithFilter(fmt.Sprintf("%s == {v}", common.DefaultBoolFieldName)).
WithTemplateParam("v", true).
WithOutputFields(common.QueryCountFieldName))
WithOutputFields(common.QueryCountFieldName).
WithConsistencyLevel(entity.ClStrong))
common.CheckErr(t, err, true)
count, _ = countRes.Fields[0].GetAsInt64(0)
require.EqualValues(t, common.DefaultNb/2, count)
@ -1141,7 +1143,8 @@ func TestQueryWithTemplateParam(t *testing.T) {
res, err = mc.Query(ctx, client.NewQueryOption(schema.CollectionName).
WithFilter(fmt.Sprintf("%s >= {k1} && %s < {k2}", common.DefaultInt64FieldName, common.DefaultInt64FieldName)).
WithTemplateParam("v", 0).WithTemplateParam("k1", 1000).
WithTemplateParam("k2", 2000))
WithTemplateParam("k2", 2000).
WithConsistencyLevel(entity.ClStrong))
common.CheckErr(t, err, true)
require.EqualValues(t, 1000, res.ResultCount)
}