diff --git a/configs/milvus.yaml b/configs/milvus.yaml index 1e9253d223..fdd7078edc 100644 --- a/configs/milvus.yaml +++ b/configs/milvus.yaml @@ -176,9 +176,6 @@ mq: mergeCheckInterval: 1 # the interval time(in seconds) for dispatcher to check whether to merge targetBufSize: 16 # the lenth of channel buffer for targe maxTolerantLag: 3 # Default value: "3", the timeout(in seconds) that target sends msgPack - maxDispatcherNumPerPchannel: 5 # The maximum number of dispatchers per physical channel, primarily to limit the number of consumers and prevent performance issues(e.g., during recovery when a large number of channels are watched). - retrySleep: 3 # register retry sleep time in seconds - retryTimeout: 60 # register retry timeout in seconds # Related configuration of pulsar, used to manage Milvus logs of recent mutation operations, output streaming log, and provide log publish-subscribe services. pulsar: diff --git a/internal/datacoord/channel_manager.go b/internal/datacoord/channel_manager.go index 9d205715ae..206e3592a7 100644 --- a/internal/datacoord/channel_manager.go +++ b/internal/datacoord/channel_manager.go @@ -19,6 +19,7 @@ package datacoord import ( "context" "fmt" + "sort" "sync" "time" @@ -545,7 +546,15 @@ func (m *ChannelManagerImpl) advanceToNotifies(ctx context.Context, toNotifies [ zap.Int("total operation count", len(nodeAssign.Channels)), zap.Strings("channel names", chNames), ) - for _, ch := range nodeAssign.Channels { + + // Sort watch tasks by seek position to minimize lag between + // positions during batch subscription in the dispatcher. + channels := lo.Values(nodeAssign.Channels) + sort.Slice(channels, func(i, j int) bool { + return channels[i].GetWatchInfo().GetVchan().GetSeekPosition().GetTimestamp() < + channels[j].GetWatchInfo().GetVchan().GetSeekPosition().GetTimestamp() + }) + for _, ch := range channels { innerCh := ch tmpWatchInfo := typeutil.Clone(innerCh.GetWatchInfo()) tmpWatchInfo.Vchan = m.h.GetDataVChanPositions(innerCh, allPartitionID) diff --git a/internal/flushcommon/pipeline/data_sync_service_test.go b/internal/flushcommon/pipeline/data_sync_service_test.go index 8fcc499fcc..967faec4df 100644 --- a/internal/flushcommon/pipeline/data_sync_service_test.go +++ b/internal/flushcommon/pipeline/data_sync_service_test.go @@ -336,7 +336,7 @@ func (s *DataSyncServiceSuite) SetupTest() { s.factory.EXPECT().NewTtMsgStream(mock.Anything).Return(s.ms, nil) s.ms.EXPECT().AsConsumer(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil) s.ms.EXPECT().Chan().Return(s.msChan) - s.ms.EXPECT().Close().Return() + s.ms.EXPECT().Close().Return().Maybe() s.pipelineParams = &util2.PipelineParams{ Ctx: context.TODO(), diff --git a/internal/flushcommon/pipeline/flow_graph_dmstream_input_node.go b/internal/flushcommon/pipeline/flow_graph_dmstream_input_node.go index f0c8d55a86..cea202ce75 100644 --- a/internal/flushcommon/pipeline/flow_graph_dmstream_input_node.go +++ b/internal/flushcommon/pipeline/flow_graph_dmstream_input_node.go @@ -21,7 +21,6 @@ import ( "fmt" "time" - "github.com/cockroachdb/errors" "go.uber.org/zap" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" @@ -34,9 +33,7 @@ import ( "github.com/milvus-io/milvus/pkg/v2/mq/common" "github.com/milvus-io/milvus/pkg/v2/mq/msgdispatcher" "github.com/milvus-io/milvus/pkg/v2/mq/msgstream" - "github.com/milvus-io/milvus/pkg/v2/util/merr" "github.com/milvus-io/milvus/pkg/v2/util/paramtable" - "github.com/milvus-io/milvus/pkg/v2/util/retry" "github.com/milvus-io/milvus/pkg/v2/util/tsoutil" 
"github.com/milvus-io/milvus/pkg/v2/util/typeutil" ) @@ -87,22 +84,15 @@ func createNewInputFromDispatcher(initCtx context.Context, replicateConfig := msgstream.GetReplicateConfig(replicateID, schema.GetDbName(), schema.GetName()) if seekPos != nil && len(seekPos.MsgID) != 0 { - err := retry.Handle(initCtx, func() (bool, error) { - input, err = dispatcherClient.Register(initCtx, &msgdispatcher.StreamConfig{ - VChannel: vchannel, - Pos: seekPos, - SubPos: common.SubscriptionPositionUnknown, - ReplicateConfig: replicateConfig, - }) - if err != nil { - log.Warn("datanode consume failed", zap.Error(err)) - return errors.Is(err, merr.ErrTooManyConsumers), err - } - return false, nil - }, retry.Sleep(paramtable.Get().MQCfg.RetrySleep.GetAsDuration(time.Second)), // 5 seconds - retry.MaxSleepTime(paramtable.Get().MQCfg.RetryTimeout.GetAsDuration(time.Second))) // 5 minutes + input, err = dispatcherClient.Register(initCtx, &msgdispatcher.StreamConfig{ + VChannel: vchannel, + Pos: seekPos, + SubPos: common.SubscriptionPositionUnknown, + ReplicateConfig: replicateConfig, + }) if err != nil { log.Warn("datanode consume failed after retried", zap.Error(err)) + dispatcherClient.Deregister(vchannel) return nil, err } @@ -114,22 +104,15 @@ func createNewInputFromDispatcher(initCtx context.Context, return input, err } - err = retry.Handle(initCtx, func() (bool, error) { - input, err = dispatcherClient.Register(initCtx, &msgdispatcher.StreamConfig{ - VChannel: vchannel, - Pos: nil, - SubPos: common.SubscriptionPositionEarliest, - ReplicateConfig: replicateConfig, - }) - if err != nil { - log.Warn("datanode consume failed", zap.Error(err)) - return errors.Is(err, merr.ErrTooManyConsumers), err - } - return false, nil - }, retry.Sleep(paramtable.Get().MQCfg.RetrySleep.GetAsDuration(time.Second)), // 5 seconds - retry.MaxSleepTime(paramtable.Get().MQCfg.RetryTimeout.GetAsDuration(time.Second))) // 5 minutes + input, err = dispatcherClient.Register(initCtx, &msgdispatcher.StreamConfig{ + VChannel: vchannel, + Pos: nil, + SubPos: common.SubscriptionPositionEarliest, + ReplicateConfig: replicateConfig, + }) if err != nil { log.Warn("datanode consume failed after retried", zap.Error(err)) + dispatcherClient.Deregister(vchannel) return nil, err } diff --git a/internal/querynodev2/services_test.go b/internal/querynodev2/services_test.go index 864b3e0e63..4191829005 100644 --- a/internal/querynodev2/services_test.go +++ b/internal/querynodev2/services_test.go @@ -311,11 +311,11 @@ func (suite *ServiceSuite) TestWatchDmChannelsInt64() { } // mocks - suite.factory.EXPECT().NewTtMsgStream(mock.Anything).Return(suite.msgStream, nil) - suite.msgStream.EXPECT().AsConsumer(mock.Anything, []string{suite.pchannel}, mock.Anything, mock.Anything).Return(nil) - suite.msgStream.EXPECT().Seek(mock.Anything, mock.Anything, mock.Anything).Return(nil) - suite.msgStream.EXPECT().Chan().Return(suite.msgChan) - suite.msgStream.EXPECT().Close() + suite.factory.EXPECT().NewTtMsgStream(mock.Anything).Return(suite.msgStream, nil).Maybe() + suite.msgStream.EXPECT().AsConsumer(mock.Anything, []string{suite.pchannel}, mock.Anything, mock.Anything).Return(nil).Maybe() + suite.msgStream.EXPECT().Seek(mock.Anything, mock.Anything, mock.Anything).Return(nil).Maybe() + suite.msgStream.EXPECT().Chan().Return(suite.msgChan).Maybe() + suite.msgStream.EXPECT().Close().Maybe() // watchDmChannels status, err := suite.node.WatchDmChannels(ctx, req) @@ -363,11 +363,11 @@ func (suite *ServiceSuite) TestWatchDmChannelsVarchar() { } // mocks - 
suite.factory.EXPECT().NewTtMsgStream(mock.Anything).Return(suite.msgStream, nil) - suite.msgStream.EXPECT().AsConsumer(mock.Anything, []string{suite.pchannel}, mock.Anything, mock.Anything).Return(nil) - suite.msgStream.EXPECT().Seek(mock.Anything, mock.Anything, mock.Anything).Return(nil) - suite.msgStream.EXPECT().Chan().Return(suite.msgChan) - suite.msgStream.EXPECT().Close() + suite.factory.EXPECT().NewTtMsgStream(mock.Anything).Return(suite.msgStream, nil).Maybe() + suite.msgStream.EXPECT().AsConsumer(mock.Anything, []string{suite.pchannel}, mock.Anything, mock.Anything).Return(nil).Maybe() + suite.msgStream.EXPECT().Seek(mock.Anything, mock.Anything, mock.Anything).Return(nil).Maybe() + suite.msgStream.EXPECT().Chan().Return(suite.msgChan).Maybe() + suite.msgStream.EXPECT().Close().Maybe() // watchDmChannels status, err := suite.node.WatchDmChannels(ctx, req) @@ -498,16 +498,6 @@ func (suite *ServiceSuite) TestWatchDmChannels_Failed() { suite.ErrorIs(merr.Error(status), merr.ErrChannelReduplicate) suite.node.unsubscribingChannels.Remove(suite.vchannel) - // init msgstream failed - suite.factory.EXPECT().NewTtMsgStream(mock.Anything).Return(suite.msgStream, nil) - suite.msgStream.EXPECT().AsConsumer(mock.Anything, []string{suite.pchannel}, mock.Anything, mock.Anything).Return(nil) - suite.msgStream.EXPECT().Close().Return() - suite.msgStream.EXPECT().Seek(mock.Anything, mock.Anything, mock.Anything).Return(errors.New("mock error")).Once() - - status, err = suite.node.WatchDmChannels(ctx, req) - suite.NoError(err) - suite.Equal(commonpb.ErrorCode_UnexpectedError, status.GetErrorCode()) - // load growing failed badSegmentReq := typeutil.Clone(req) for _, info := range badSegmentReq.SegmentInfos { diff --git a/internal/util/pipeline/stream_pipeline.go b/internal/util/pipeline/stream_pipeline.go index 053b71fc52..1cce952a65 100644 --- a/internal/util/pipeline/stream_pipeline.go +++ b/internal/util/pipeline/stream_pipeline.go @@ -22,7 +22,6 @@ import ( "sync" "time" - "github.com/cockroachdb/errors" "go.uber.org/atomic" "go.uber.org/zap" @@ -36,9 +35,6 @@ import ( "github.com/milvus-io/milvus/pkg/v2/streaming/util/message" "github.com/milvus-io/milvus/pkg/v2/streaming/util/message/adaptor" "github.com/milvus-io/milvus/pkg/v2/streaming/util/options" - "github.com/milvus-io/milvus/pkg/v2/util/merr" - "github.com/milvus-io/milvus/pkg/v2/util/paramtable" - "github.com/milvus-io/milvus/pkg/v2/util/retry" "github.com/milvus-io/milvus/pkg/v2/util/tsoutil" ) @@ -127,22 +123,15 @@ func (p *streamPipeline) ConsumeMsgStream(ctx context.Context, position *msgpb.M } start := time.Now() - err = retry.Handle(ctx, func() (bool, error) { - p.input, err = p.dispatcher.Register(ctx, &msgdispatcher.StreamConfig{ - VChannel: p.vChannel, - Pos: position, - SubPos: common.SubscriptionPositionUnknown, - ReplicateConfig: p.replicateConfig, - }) - if err != nil { - log.Warn("dispatcher register failed", zap.String("channel", position.ChannelName), zap.Error(err)) - return errors.Is(err, merr.ErrTooManyConsumers), err - } - return false, nil - }, retry.Sleep(paramtable.Get().MQCfg.RetrySleep.GetAsDuration(time.Second)), // 5 seconds - retry.MaxSleepTime(paramtable.Get().MQCfg.RetryTimeout.GetAsDuration(time.Second))) // 5 minutes + p.input, err = p.dispatcher.Register(ctx, &msgdispatcher.StreamConfig{ + VChannel: p.vChannel, + Pos: position, + SubPos: common.SubscriptionPositionUnknown, + ReplicateConfig: p.replicateConfig, + }) if err != nil { log.Error("dispatcher register failed after retried", 
zap.String("channel", position.ChannelName), zap.Error(err)) + p.dispatcher.Deregister(p.vChannel) return WrapErrRegDispather(err) } diff --git a/pkg/mq/msgdispatcher/client.go b/pkg/mq/msgdispatcher/client.go index 931e358b5e..47977c0236 100644 --- a/pkg/mq/msgdispatcher/client.go +++ b/pkg/mq/msgdispatcher/client.go @@ -18,7 +18,6 @@ package msgdispatcher import ( "context" - "fmt" "time" "go.uber.org/zap" @@ -29,8 +28,6 @@ import ( "github.com/milvus-io/milvus/pkg/v2/mq/msgstream" "github.com/milvus-io/milvus/pkg/v2/util/funcutil" "github.com/milvus-io/milvus/pkg/v2/util/lock" - "github.com/milvus-io/milvus/pkg/v2/util/merr" - "github.com/milvus-io/milvus/pkg/v2/util/paramtable" "github.com/milvus-io/milvus/pkg/v2/util/typeutil" ) @@ -82,13 +79,15 @@ func NewClient(factory msgstream.Factory, role string, nodeID int64) Client { } func (c *client) Register(ctx context.Context, streamConfig *StreamConfig) (<-chan *MsgPack, error) { - vchannel := streamConfig.VChannel - log := log.With(zap.String("role", c.role), - zap.Int64("nodeID", c.nodeID), zap.String("vchannel", vchannel)) - pchannel := funcutil.ToPhysicalChannel(vchannel) start := time.Now() + vchannel := streamConfig.VChannel + pchannel := funcutil.ToPhysicalChannel(vchannel) + + log := log.Ctx(ctx).With(zap.String("role", c.role), zap.Int64("nodeID", c.nodeID), zap.String("vchannel", vchannel)) + c.managerMut.Lock(pchannel) defer c.managerMut.Unlock(pchannel) + var manager DispatcherManager manager, ok := c.managers.Get(pchannel) if !ok { @@ -96,18 +95,10 @@ func (c *client) Register(ctx context.Context, streamConfig *StreamConfig) (<-ch c.managers.Insert(pchannel, manager) go manager.Run() } - // Check if the consumer number limit has been reached. - limit := paramtable.Get().MQCfg.MaxDispatcherNumPerPchannel.GetAsInt() - if manager.NumConsumer() >= limit { - return nil, merr.WrapErrTooManyConsumers(vchannel, fmt.Sprintf("limit=%d", limit)) - } + // Begin to register ch, err := manager.Add(ctx, streamConfig) if err != nil { - if manager.NumTarget() == 0 { - manager.Close() - c.managers.Remove(pchannel) - } log.Error("register failed", zap.Error(err)) return nil, err } @@ -116,13 +107,15 @@ func (c *client) Register(ctx context.Context, streamConfig *StreamConfig) (<-ch } func (c *client) Deregister(vchannel string) { - pchannel := funcutil.ToPhysicalChannel(vchannel) start := time.Now() + pchannel := funcutil.ToPhysicalChannel(vchannel) + c.managerMut.Lock(pchannel) defer c.managerMut.Unlock(pchannel) + if manager, ok := c.managers.Get(pchannel); ok { manager.Remove(vchannel) - if manager.NumTarget() == 0 { + if manager.NumTarget() == 0 && manager.NumConsumer() == 0 { manager.Close() c.managers.Remove(pchannel) } @@ -132,12 +125,12 @@ func (c *client) Deregister(vchannel string) { } func (c *client) Close() { - log := log.With(zap.String("role", c.role), - zap.Int64("nodeID", c.nodeID)) + log := log.With(zap.String("role", c.role), zap.Int64("nodeID", c.nodeID)) c.managers.Range(func(pchannel string, manager DispatcherManager) bool { c.managerMut.Lock(pchannel) defer c.managerMut.Unlock(pchannel) + log.Info("close manager", zap.String("channel", pchannel)) c.managers.Remove(pchannel) manager.Close() diff --git a/pkg/mq/msgdispatcher/client_test.go b/pkg/mq/msgdispatcher/client_test.go index 8eafe90c59..2ec9bff3b0 100644 --- a/pkg/mq/msgdispatcher/client_test.go +++ b/pkg/mq/msgdispatcher/client_test.go @@ -25,62 +25,437 @@ import ( "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" 
"go.uber.org/atomic" "github.com/milvus-io/milvus/pkg/v2/mq/common" + "github.com/milvus-io/milvus/pkg/v2/mq/msgstream" + "github.com/milvus-io/milvus/pkg/v2/util/paramtable" "github.com/milvus-io/milvus/pkg/v2/util/typeutil" ) func TestClient(t *testing.T) { - client := NewClient(newMockFactory(), typeutil.ProxyRole, 1) + factory := newMockFactory() + client := NewClient(factory, typeutil.ProxyRole, 1) assert.NotNil(t, client) - _, err := client.Register(context.Background(), NewStreamConfig("mock_vchannel_0", nil, common.SubscriptionPositionUnknown)) - assert.NoError(t, err) - _, err = client.Register(context.Background(), NewStreamConfig("mock_vchannel_1", nil, common.SubscriptionPositionUnknown)) - assert.NoError(t, err) - assert.NotPanics(t, func() { - client.Deregister("mock_vchannel_0") - client.Close() - }) + defer client.Close() - t.Run("with timeout ctx", func(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Millisecond) - defer cancel() - <-time.After(2 * time.Millisecond) + pchannel := fmt.Sprintf("by-dev-rootcoord-dml_%d", rand.Int63()) - client := NewClient(newMockFactory(), typeutil.DataNodeRole, 1) - defer client.Close() - assert.NotNil(t, client) - _, err := client.Register(ctx, NewStreamConfig("mock_vchannel_1", nil, common.SubscriptionPositionUnknown)) - assert.Error(t, err) - }) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + producer, err := newMockProducer(factory, pchannel) + assert.NoError(t, err) + go produceTimeTick(t, ctx, producer) + + _, err = client.Register(ctx, NewStreamConfig(fmt.Sprintf("%s_v1", pchannel), nil, common.SubscriptionPositionUnknown)) + assert.NoError(t, err) + + _, err = client.Register(ctx, NewStreamConfig(fmt.Sprintf("%s_v2", pchannel), nil, common.SubscriptionPositionUnknown)) + assert.NoError(t, err) + + client.Deregister(fmt.Sprintf("%s_v1", pchannel)) + client.Deregister(fmt.Sprintf("%s_v2", pchannel)) } func TestClient_Concurrency(t *testing.T) { - client1 := NewClient(newMockFactory(), typeutil.ProxyRole, 1) + factory := newMockFactory() + client1 := NewClient(factory, typeutil.ProxyRole, 1) assert.NotNil(t, client1) + defer client1.Close() + + paramtable.Get().Save(paramtable.Get().MQCfg.TargetBufSize.Key, "65536") + defer paramtable.Get().Reset(paramtable.Get().MQCfg.TargetBufSize.Key) + + const ( + vchannelNumPerPchannel = 10 + pchannelNum = 16 + ) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + pchannels := make([]string, pchannelNum) + for i := 0; i < pchannelNum; i++ { + pchannel := fmt.Sprintf("by-dev-rootcoord-dml-%d_%d", rand.Int63(), i) + pchannels[i] = pchannel + producer, err := newMockProducer(factory, pchannel) + assert.NoError(t, err) + go produceTimeTick(t, ctx, producer) + t.Logf("start to produce time tick to pchannel %s\n", pchannel) + } + wg := &sync.WaitGroup{} - const total = 100 deregisterCount := atomic.NewInt32(0) - for i := 0; i < total; i++ { - vchannel := fmt.Sprintf("mock-vchannel-%d-%d", i, rand.Int()) - wg.Add(1) - go func() { - _, err := client1.Register(context.Background(), NewStreamConfig(vchannel, nil, common.SubscriptionPositionUnknown)) - assert.NoError(t, err) - for j := 0; j < rand.Intn(2); j++ { - client1.Deregister(vchannel) - deregisterCount.Inc() - } - wg.Done() - }() + for i := 0; i < vchannelNumPerPchannel; i++ { + for j := 0; j < pchannelNum; j++ { + vchannel := fmt.Sprintf("%s_%dv%d", pchannels[i], rand.Int(), i) + wg.Add(1) + go func() { + defer wg.Done() + _, err := client1.Register(ctx, 
NewStreamConfig(vchannel, nil, common.SubscriptionPositionUnknown)) + assert.NoError(t, err) + for j := 0; j < rand.Intn(2); j++ { + client1.Deregister(vchannel) + deregisterCount.Inc() + } + }() + } } wg.Wait() - expected := int(total - deregisterCount.Load()) c := client1.(*client) - n := c.managers.Len() - assert.Equal(t, expected, n) + expected := int(vchannelNumPerPchannel*pchannelNum - deregisterCount.Load()) + + // Verify registered targets number. + actual := 0 + c.managers.Range(func(pchannel string, manager DispatcherManager) bool { + actual += manager.NumTarget() + return true + }) + assert.Equal(t, expected, actual) + + // Verify active targets number. + assert.Eventually(t, func() bool { + actual = 0 + c.managers.Range(func(pchannel string, manager DispatcherManager) bool { + m := manager.(*dispatcherManager) + m.mu.RLock() + defer m.mu.RUnlock() + if m.mainDispatcher != nil { + actual += m.mainDispatcher.targets.Len() + } + for _, d := range m.deputyDispatchers { + actual += d.targets.Len() + } + return true + }) + t.Logf("expect = %d, actual = %d\n", expected, actual) + return expected == actual + }, 15*time.Second, 100*time.Millisecond) +} + +type SimulationSuite struct { + suite.Suite + + ctx context.Context + cancel context.CancelFunc + wg *sync.WaitGroup + + client Client + factory msgstream.Factory + + pchannel2Producer map[string]msgstream.MsgStream + pchannel2Vchannels map[string]map[string]*vchannelHelper +} + +func (suite *SimulationSuite) SetupSuite() { + suite.factory = newMockFactory() +} + +func (suite *SimulationSuite) SetupTest() { + const ( + pchannelNum = 16 + vchannelNumPerPchannel = 10 + ) + + suite.ctx, suite.cancel = context.WithTimeout(context.Background(), time.Minute*3) + suite.wg = &sync.WaitGroup{} + suite.client = NewClient(suite.factory, "test-client", 1) + + // Init pchannel and producers. + suite.pchannel2Producer = make(map[string]msgstream.MsgStream) + suite.pchannel2Vchannels = make(map[string]map[string]*vchannelHelper) + for i := 0; i < pchannelNum; i++ { + pchannel := fmt.Sprintf("by-dev-rootcoord-dispatcher-dml-%d_%d", time.Now().UnixNano(), i) + producer, err := newMockProducer(suite.factory, pchannel) + suite.NoError(err) + suite.pchannel2Producer[pchannel] = producer + suite.pchannel2Vchannels[pchannel] = make(map[string]*vchannelHelper) + } + + // Init vchannels. + for pchannel := range suite.pchannel2Producer { + for i := 0; i < vchannelNumPerPchannel; i++ { + collectionID := time.Now().UnixNano() + vchannel := fmt.Sprintf("%s_%dv0", pchannel, collectionID) + suite.pchannel2Vchannels[pchannel][vchannel] = &vchannelHelper{} + } + } +} + +func (suite *SimulationSuite) TestDispatchToVchannels() { + // Register vchannels. + for _, vchannels := range suite.pchannel2Vchannels { + for vchannel, helper := range vchannels { + output, err := suite.client.Register(suite.ctx, NewStreamConfig(vchannel, nil, common.SubscriptionPositionEarliest)) + suite.NoError(err) + helper.output = output + } + } + + // Produce and dispatch messages to vchannel targets. + produceCtx, produceCancel := context.WithTimeout(suite.ctx, time.Second*3) + defer produceCancel() + for pchannel, vchannels := range suite.pchannel2Vchannels { + suite.wg.Add(1) + go produceMsgs(suite.T(), produceCtx, suite.wg, suite.pchannel2Producer[pchannel], vchannels) + } + // Mock pipelines consume messages. 
+ consumeCtx, consumeCancel := context.WithTimeout(suite.ctx, 10*time.Second) + defer consumeCancel() + for _, vchannels := range suite.pchannel2Vchannels { + for vchannel, helper := range vchannels { + suite.wg.Add(1) + go consumeMsgsFromTargets(suite.T(), consumeCtx, suite.wg, vchannel, helper) + } + } + suite.wg.Wait() + + // Verify pub-sub messages number. + for _, vchannels := range suite.pchannel2Vchannels { + for vchannel, helper := range vchannels { + suite.Equal(helper.pubInsMsgNum.Load(), helper.subInsMsgNum.Load(), vchannel) + suite.Equal(helper.pubDelMsgNum.Load(), helper.subDelMsgNum.Load(), vchannel) + suite.Equal(helper.pubDDLMsgNum.Load(), helper.subDDLMsgNum.Load(), vchannel) + suite.Equal(helper.pubPackNum.Load(), helper.subPackNum.Load(), vchannel) + } + } +} + +func (suite *SimulationSuite) TestMerge() { + // Produce msgs. + produceCtx, produceCancel := context.WithCancel(suite.ctx) + for pchannel, producer := range suite.pchannel2Producer { + suite.wg.Add(1) + go produceMsgs(suite.T(), produceCtx, suite.wg, producer, suite.pchannel2Vchannels[pchannel]) + } + + // Get random msg positions to seek for each vchannel. + for pchannel, vchannels := range suite.pchannel2Vchannels { + getRandomSeekPositions(suite.T(), suite.ctx, suite.factory, pchannel, vchannels) + } + + // Register vchannels. + for _, vchannels := range suite.pchannel2Vchannels { + for vchannel, helper := range vchannels { + pos := helper.seekPos + assert.NotNil(suite.T(), pos) + suite.T().Logf("seekTs = %d, vchannel = %s, msgID=%v\n", pos.GetTimestamp(), vchannel, pos.GetMsgID()) + output, err := suite.client.Register(suite.ctx, NewStreamConfig( + vchannel, pos, + common.SubscriptionPositionUnknown, + )) + suite.NoError(err) + helper.output = output + } + } + + // Mock pipelines consume messages. + for _, vchannels := range suite.pchannel2Vchannels { + for vchannel, helper := range vchannels { + suite.wg.Add(1) + go consumeMsgsFromTargets(suite.T(), suite.ctx, suite.wg, vchannel, helper) + } + } + + // Verify dispatchers merged. + suite.Eventually(func() bool { + for pchannel := range suite.pchannel2Producer { + manager, ok := suite.client.(*client).managers.Get(pchannel) + suite.T().Logf("dispatcherNum = %d, pchannel = %s\n", manager.NumConsumer(), pchannel) + suite.True(ok) + if manager.NumConsumer() != 1 { // expected all merged, only mainDispatcher exist + return false + } + } + return true + }, 15*time.Second, 100*time.Millisecond) + + // Stop produce and verify pub-sub messages number. 
+ produceCancel() + suite.Eventually(func() bool { + for _, vchannels := range suite.pchannel2Vchannels { + for vchannel, helper := range vchannels { + logFn := func(pubNum, skipNum, subNum int32, name string) { + suite.T().Logf("pub%sNum[%d]-skipped%sNum[%d] = %d, sub%sNum = %d, vchannel = %s\n", + name, pubNum, name, skipNum, pubNum-skipNum, name, subNum, vchannel) + } + if helper.pubInsMsgNum.Load()-helper.skippedInsMsgNum != helper.subInsMsgNum.Load() { + logFn(helper.pubInsMsgNum.Load(), helper.skippedInsMsgNum, helper.subInsMsgNum.Load(), "InsMsg") + return false + } + if helper.pubDelMsgNum.Load()-helper.skippedDelMsgNum != helper.subDelMsgNum.Load() { + logFn(helper.pubDelMsgNum.Load(), helper.skippedDelMsgNum, helper.subDelMsgNum.Load(), "DelMsg") + return false + } + if helper.pubDDLMsgNum.Load()-helper.skippedDDLMsgNum != helper.subDDLMsgNum.Load() { + logFn(helper.pubDDLMsgNum.Load(), helper.skippedDDLMsgNum, helper.subDDLMsgNum.Load(), "DDLMsg") + return false + } + if helper.pubPackNum.Load()-helper.skippedPackNum != helper.subPackNum.Load() { + logFn(helper.pubPackNum.Load(), helper.skippedPackNum, helper.subPackNum.Load(), "Pack") + return false + } + } + } + return true + }, 15*time.Second, 100*time.Millisecond) +} + +func (suite *SimulationSuite) TestSplit() { + // Modify the parameters to make triggering split easier. + paramtable.Get().Save(paramtable.Get().MQCfg.MaxTolerantLag.Key, "0.5") + defer paramtable.Get().Reset(paramtable.Get().MQCfg.MaxTolerantLag.Key) + paramtable.Get().Save(paramtable.Get().MQCfg.TargetBufSize.Key, "512") + defer paramtable.Get().Reset(paramtable.Get().MQCfg.TargetBufSize.Key) + + // Produce msgs. + produceCtx, produceCancel := context.WithCancel(suite.ctx) + for pchannel, producer := range suite.pchannel2Producer { + suite.wg.Add(1) + go produceMsgs(suite.T(), produceCtx, suite.wg, producer, suite.pchannel2Vchannels[pchannel]) + } + + // Register vchannels. + for _, vchannels := range suite.pchannel2Vchannels { + for vchannel, helper := range vchannels { + output, err := suite.client.Register(suite.ctx, NewStreamConfig(vchannel, nil, common.SubscriptionPositionEarliest)) + suite.NoError(err) + helper.output = output + } + } + + // Verify dispatchers merged. + suite.Eventually(func() bool { + for pchannel := range suite.pchannel2Producer { + manager, ok := suite.client.(*client).managers.Get(pchannel) + suite.T().Logf("verifing dispatchers merged, dispatcherNum = %d, pchannel = %s\n", manager.NumConsumer(), pchannel) + suite.True(ok) + if manager.NumConsumer() != 1 { // expected all merged, only mainDispatcher exist + return false + } + } + return true + }, 15*time.Second, 100*time.Millisecond) + + getTargetChan := func(pchannel, vchannel string) chan *MsgPack { + manager, ok := suite.client.(*client).managers.Get(pchannel) + suite.True(ok) + t, ok := manager.(*dispatcherManager).registeredTargets.Get(vchannel) + suite.True(ok) + return t.ch + } + + // Inject additional messages into targets to trigger lag and split. + injectCtx, injectCancel := context.WithCancel(context.Background()) + const splitNumPerPchannel = 3 + for pchannel, vchannels := range suite.pchannel2Vchannels { + cnt := 0 + for vchannel := range vchannels { + suite.wg.Add(1) + targetCh := getTargetChan(pchannel, vchannel) + go func() { + defer suite.wg.Done() + for { + select { + case targetCh <- &MsgPack{}: + case <-injectCtx.Done(): + return + } + } + }() + cnt++ + if cnt == splitNumPerPchannel { + break + } + } + } + + // Verify split. 
+ suite.Eventually(func() bool { + for pchannel := range suite.pchannel2Producer { + manager, ok := suite.client.(*client).managers.Get(pchannel) + suite.True(ok) + suite.T().Logf("verifing split, dispatcherNum = %d, splitNum+1 = %d, pchannel = %s\n", + manager.NumConsumer(), splitNumPerPchannel+1, pchannel) + if manager.NumConsumer() < 1 { // expected 1 mainDispatcher and 1 or more split deputyDispatchers + return false + } + } + return true + }, 20*time.Second, 100*time.Millisecond) + + injectCancel() + + // Mock pipelines consume messages to trigger merged again. + for _, vchannels := range suite.pchannel2Vchannels { + for vchannel, helper := range vchannels { + suite.wg.Add(1) + go consumeMsgsFromTargets(suite.T(), suite.ctx, suite.wg, vchannel, helper) + } + } + + // Verify dispatchers merged. + suite.Eventually(func() bool { + for pchannel := range suite.pchannel2Producer { + manager, ok := suite.client.(*client).managers.Get(pchannel) + suite.T().Logf("verifing dispatchers merged again, dispatcherNum = %d, pchannel = %s\n", manager.NumConsumer(), pchannel) + suite.True(ok) + if manager.NumConsumer() != 1 { // expected all merged, only mainDispatcher exist + return false + } + } + return true + }, 15*time.Second, 100*time.Millisecond) + + // Stop produce and verify pub-sub messages number. + produceCancel() + suite.Eventually(func() bool { + for _, vchannels := range suite.pchannel2Vchannels { + for vchannel, helper := range vchannels { + if helper.pubInsMsgNum.Load() != helper.subInsMsgNum.Load() { + suite.T().Logf("pubInsMsgNum = %d, subInsMsgNum = %d, vchannel = %s\n", + helper.pubInsMsgNum.Load(), helper.subInsMsgNum.Load(), vchannel) + return false + } + if helper.pubDelMsgNum.Load() != helper.subDelMsgNum.Load() { + suite.T().Logf("pubDelMsgNum = %d, subDelMsgNum = %d, vchannel = %s\n", + helper.pubDelMsgNum.Load(), helper.subDelMsgNum.Load(), vchannel) + return false + } + if helper.pubDDLMsgNum.Load() != helper.subDDLMsgNum.Load() { + suite.T().Logf("pubDDLMsgNum = %d, subDDLMsgNum = %d, vchannel = %s\n", + helper.pubDDLMsgNum.Load(), helper.subDDLMsgNum.Load(), vchannel) + return false + } + if helper.pubPackNum.Load() != helper.subPackNum.Load() { + suite.T().Logf("pubPackNum = %d, subPackNum = %d, vchannel = %s\n", + helper.pubPackNum.Load(), helper.subPackNum.Load(), vchannel) + return false + } + } + } + return true + }, 15*time.Second, 100*time.Millisecond) +} + +func (suite *SimulationSuite) TearDownTest() { + for _, vchannels := range suite.pchannel2Vchannels { + for vchannel := range vchannels { + suite.client.Deregister(vchannel) + } + } + suite.client.Close() + suite.cancel() + suite.wg.Wait() +} + +func (suite *SimulationSuite) TearDownSuite() { +} + +func TestSimulation(t *testing.T) { + suite.Run(t, new(SimulationSuite)) } func TestClientMainDispatcherLeak(t *testing.T) { diff --git a/pkg/mq/msgdispatcher/dispatcher.go b/pkg/mq/msgdispatcher/dispatcher.go index 929f285f63..aeb23d6880 100644 --- a/pkg/mq/msgdispatcher/dispatcher.go +++ b/pkg/mq/msgdispatcher/dispatcher.go @@ -34,6 +34,7 @@ import ( "github.com/milvus-io/milvus/pkg/v2/mq/msgstream" "github.com/milvus-io/milvus/pkg/v2/util/funcutil" "github.com/milvus-io/milvus/pkg/v2/util/paramtable" + "github.com/milvus-io/milvus/pkg/v2/util/syncutil" "github.com/milvus-io/milvus/pkg/v2/util/tsoutil" "github.com/milvus-io/milvus/pkg/v2/util/typeutil" ) @@ -62,20 +63,20 @@ type Dispatcher struct { ctx context.Context cancel context.CancelFunc + id int64 + + pullbackEndTs typeutil.Timestamp + pullbackDone bool + 
pullbackDoneNotifier *syncutil.AsyncTaskNotifier[struct{}] + done chan struct{} wg sync.WaitGroup once sync.Once - isMain bool // indicates if it's a main dispatcher pchannel string curTs atomic.Uint64 - lagNotifyChan chan struct{} - lagTargets *typeutil.ConcurrentMap[string, *target] // vchannel -> *target - - // vchannel -> *target, lock free since we guarantee that - // it's modified only after dispatcher paused or terminated - targets map[string]*target + targets *typeutil.ConcurrentMap[string, *target] stream msgstream.MsgStream } @@ -83,18 +84,17 @@ type Dispatcher struct { func NewDispatcher( ctx context.Context, factory msgstream.Factory, - isMain bool, + id int64, pchannel string, position *Pos, - subName string, subPos SubPos, - lagNotifyChan chan struct{}, - lagTargets *typeutil.ConcurrentMap[string, *target], - includeCurrentMsg bool, + pullbackEndTs typeutil.Timestamp, ) (*Dispatcher, error) { - log := log.With(zap.String("pchannel", pchannel), - zap.String("subName", subName), zap.Bool("isMain", isMain)) - log.Info("creating dispatcher...") + subName := fmt.Sprintf("%s-%d-%d", pchannel, id, time.Now().UnixNano()) + + log := log.Ctx(ctx).With(zap.String("pchannel", pchannel), + zap.Int64("id", id), zap.String("subName", subName)) + log.Info("creating dispatcher...", zap.Uint64("pullbackEndTs", pullbackEndTs)) var stream msgstream.MsgStream var err error @@ -116,8 +116,8 @@ func NewDispatcher( log.Error("asConsumer failed", zap.Error(err)) return nil, err } - - err = stream.Seek(ctx, []*Pos{position}, includeCurrentMsg) + log.Info("as consumer done", zap.Any("position", position)) + err = stream.Seek(ctx, []*Pos{position}, false) if err != nil { log.Error("seek failed", zap.Error(err)) return nil, err @@ -135,59 +135,75 @@ func NewDispatcher( } d := &Dispatcher{ - done: make(chan struct{}, 1), - isMain: isMain, - pchannel: pchannel, - lagNotifyChan: lagNotifyChan, - lagTargets: lagTargets, - targets: make(map[string]*target), - stream: stream, + id: id, + pullbackEndTs: pullbackEndTs, + pullbackDoneNotifier: syncutil.NewAsyncTaskNotifier[struct{}](), + done: make(chan struct{}, 1), + pchannel: pchannel, + targets: typeutil.NewConcurrentMap[string, *target](), + stream: stream, } metrics.NumConsumers.WithLabelValues(paramtable.GetRole(), fmt.Sprint(paramtable.GetNodeID())).Inc() return d, nil } +func (d *Dispatcher) ID() int64 { + return d.id +} + func (d *Dispatcher) CurTs() typeutil.Timestamp { return d.curTs.Load() } func (d *Dispatcher) AddTarget(t *target) { - log := log.With(zap.String("vchannel", t.vchannel), zap.Bool("isMain", d.isMain)) - if _, ok := d.targets[t.vchannel]; ok { + log := log.With(zap.String("vchannel", t.vchannel), zap.Int64("id", d.ID()), zap.Uint64("ts", t.pos.GetTimestamp())) + if _, ok := d.targets.GetOrInsert(t.vchannel, t); ok { log.Warn("target exists") return } - d.targets[t.vchannel] = t log.Info("add new target") } func (d *Dispatcher) GetTarget(vchannel string) (*target, error) { - if t, ok := d.targets[vchannel]; ok { + if t, ok := d.targets.Get(vchannel); ok { return t, nil } - return nil, fmt.Errorf("cannot find target, vchannel=%s, isMain=%t", vchannel, d.isMain) + return nil, fmt.Errorf("cannot find target, vchannel=%s", vchannel) } -func (d *Dispatcher) CloseTarget(vchannel string) { - log := log.With(zap.String("vchannel", vchannel), zap.Bool("isMain", d.isMain)) - if t, ok := d.targets[vchannel]; ok { - t.close() - delete(d.targets, vchannel) - log.Info("closed target") +func (d *Dispatcher) GetTargets() []*target { + return 
d.targets.Values() +} + +func (d *Dispatcher) HasTarget(vchannel string) bool { + return d.targets.Contain(vchannel) +} + +func (d *Dispatcher) RemoveTarget(vchannel string) { + log := log.With(zap.String("vchannel", vchannel), zap.Int64("id", d.ID())) + if _, ok := d.targets.GetAndRemove(vchannel); ok { + log.Info("target removed") } else { log.Warn("target not exist") } } func (d *Dispatcher) TargetNum() int { - return len(d.targets) + return d.targets.Len() +} + +func (d *Dispatcher) BlockUtilPullbackDone() { + select { + case <-d.ctx.Done(): + case <-d.pullbackDoneNotifier.FinishChan(): + } } func (d *Dispatcher) Handle(signal signal) { - log := log.With(zap.String("pchannel", d.pchannel), - zap.String("signal", signal.String()), zap.Bool("isMain", d.isMain)) - log.Info("get signal") + log := log.With(zap.String("pchannel", d.pchannel), zap.Int64("id", d.ID()), + zap.String("signal", signal.String())) + log.Debug("get signal") switch signal { case start: d.ctx, d.cancel = context.WithCancel(context.Background()) @@ -214,7 +230,7 @@ func (d *Dispatcher) Handle(signal signal) { } func (d *Dispatcher) work() { - log := log.With(zap.String("pchannel", d.pchannel), zap.Bool("isMain", d.isMain)) + log := log.With(zap.String("pchannel", d.pchannel), zap.Int64("id", d.ID())) log.Info("begin to work") defer d.wg.Done() for { @@ -232,12 +248,36 @@ func (d *Dispatcher) work() { targetPacks := d.groupingMsgs(pack) for vchannel, p := range targetPacks { var err error - t := d.targets[vchannel] - if d.isMain { - // for main dispatcher, split target if err occurs + t, _ := d.targets.Get(vchannel) + // The dispatcher seeks from the oldest target, + // so for each target, msg before the target position must be filtered out. + if p.EndTs <= t.pos.GetTimestamp() { + log.Info("skip msg", + zap.String("vchannel", vchannel), + zap.Int("msgCount", len(p.Msgs)), + zap.Uint64("packBeginTs", p.BeginTs), + zap.Uint64("packEndTs", p.EndTs), + zap.Uint64("posTs", t.pos.GetTimestamp()), + ) + for _, msg := range p.Msgs { + log.Debug("skip msg info", + zap.String("vchannel", vchannel), + zap.String("msgType", msg.Type().String()), + zap.Int64("msgID", msg.ID()), + zap.Uint64("msgBeginTs", msg.BeginTs()), + zap.Uint64("msgEndTs", msg.EndTs()), + zap.Uint64("packBeginTs", p.BeginTs), + zap.Uint64("packEndTs", p.EndTs), + zap.Uint64("posTs", t.pos.GetTimestamp()), + ) + } + continue + } + if d.targets.Len() > 1 { + // for dispatcher with multiple targets, split target if err occurs err = t.send(p) } else { - // for solo dispatcher, only 1 target exists, we should + // for dispatcher with only one target, // keep retrying if err occurs, unless it paused or terminated. 
for { err = t.send(p) @@ -250,12 +290,19 @@ func (d *Dispatcher) work() { t.pos = typeutil.Clone(pack.StartPositions[0]) // replace the pChannel with vChannel t.pos.ChannelName = t.vchannel - d.lagTargets.Insert(t.vchannel, t) - d.nonBlockingNotify() - delete(d.targets, vchannel) - log.Warn("lag target notified", zap.Error(err)) + d.targets.GetAndRemove(vchannel) + log.Warn("lag target", zap.Error(err)) } } + + if !d.pullbackDone && pack.EndPositions[0].GetTimestamp() >= d.pullbackEndTs { + d.pullbackDoneNotifier.Finish(struct{}{}) + log.Info("dispatcher pullback done", + zap.Uint64("pullbackEndTs", d.pullbackEndTs), + zap.Time("pullbackTime", tsoutil.PhysicalTime(d.pullbackEndTs)), + ) + d.pullbackDone = true + } } } } @@ -265,7 +312,7 @@ func (d *Dispatcher) groupingMsgs(pack *MsgPack) map[string]*MsgPack { // but we still need to dispatch time ticks to the targets. targetPacks := make(map[string]*MsgPack) replicateConfigs := make(map[string]*msgstream.ReplicateConfig) - for vchannel, t := range d.targets { + d.targets.Range(func(vchannel string, t *target) bool { targetPacks[vchannel] = &MsgPack{ BeginTs: pack.BeginTs, EndTs: pack.EndTs, @@ -276,7 +323,8 @@ func (d *Dispatcher) groupingMsgs(pack *MsgPack) map[string]*MsgPack { if t.replicateConfig != nil { replicateConfigs[vchannel] = t.replicateConfig } - } + return true + }) // group messages by vchannel for _, msg := range pack.Msgs { var vchannel, collectionID string @@ -348,7 +396,7 @@ func (d *Dispatcher) groupingMsgs(pack *MsgPack) map[string]*MsgPack { d.resetMsgPackTS(targetPacks[vchannel], beginTs, endTs) } for vchannel := range replicateEndChannels { - if t, ok := d.targets[vchannel]; ok { + if t, ok := d.targets.Get(vchannel); ok { t.replicateConfig = nil log.Info("replicate end, set replicate config nil", zap.String("vchannel", vchannel)) } @@ -374,10 +422,3 @@ func (d *Dispatcher) resetMsgPackTS(pack *MsgPack, newBeginTs, newEndTs typeutil pack.StartPositions = startPositions pack.EndPositions = endPositions } - -func (d *Dispatcher) nonBlockingNotify() { - select { - case d.lagNotifyChan <- struct{}{}: - default: - } -} diff --git a/pkg/mq/msgdispatcher/dispatcher_test.go b/pkg/mq/msgdispatcher/dispatcher_test.go index 9874ae97df..89f011a2ac 100644 --- a/pkg/mq/msgdispatcher/dispatcher_test.go +++ b/pkg/mq/msgdispatcher/dispatcher_test.go @@ -17,8 +17,6 @@ package msgdispatcher import ( - "fmt" - "math/rand" "sync" "testing" "time" @@ -37,7 +35,8 @@ import ( func TestDispatcher(t *testing.T) { ctx := context.Background() t.Run("test base", func(t *testing.T) { - d, err := NewDispatcher(ctx, newMockFactory(), true, "mock_pchannel_0", nil, "mock_subName_0", common.SubscriptionPositionEarliest, nil, nil, false) + d, err := NewDispatcher(ctx, newMockFactory(), time.Now().UnixNano(), "mock_pchannel_0", + nil, common.SubscriptionPositionEarliest, 0) assert.NoError(t, err) assert.NotPanics(t, func() { d.Handle(start) @@ -65,19 +64,24 @@ func TestDispatcher(t *testing.T) { return ms, nil }, } - d, err := NewDispatcher(ctx, factory, true, "mock_pchannel_0", nil, "mock_subName_0", common.SubscriptionPositionEarliest, nil, nil, false) + d, err := NewDispatcher(ctx, factory, time.Now().UnixNano(), "mock_pchannel_0", + nil, common.SubscriptionPositionEarliest, 0) assert.Error(t, err) assert.Nil(t, d) }) t.Run("test target", func(t *testing.T) { - d, err := NewDispatcher(ctx, newMockFactory(), true, "mock_pchannel_0", nil, "mock_subName_0", common.SubscriptionPositionEarliest, nil, nil, false) + d, err := NewDispatcher(ctx, 
newMockFactory(), time.Now().UnixNano(), "mock_pchannel_0", + nil, common.SubscriptionPositionEarliest, 0) assert.NoError(t, err) output := make(chan *msgstream.MsgPack, 1024) getTarget := func(vchannel string, pos *Pos, ch chan *msgstream.MsgPack) *target { - target := newTarget(vchannel, pos, nil) + target := newTarget(&StreamConfig{ + VChannel: vchannel, + Pos: pos, + }) target.ch = ch return target } @@ -91,14 +95,7 @@ func TestDispatcher(t *testing.T) { assert.NoError(t, err) assert.Equal(t, cap(output), cap(target.ch)) - d.CloseTarget("mock_vchannel_0") - - select { - case <-time.After(1 * time.Second): - assert.Fail(t, "timeout, didn't receive close message") - case _, ok := <-target.ch: - assert.False(t, ok) - } + d.RemoveTarget("mock_vchannel_0") num = d.TargetNum() assert.Equal(t, 1, num) @@ -107,7 +104,7 @@ func TestDispatcher(t *testing.T) { t.Run("test concurrent send and close", func(t *testing.T) { for i := 0; i < 100; i++ { output := make(chan *msgstream.MsgPack, 1024) - target := newTarget("mock_vchannel_0", nil, nil) + target := newTarget(&StreamConfig{VChannel: "mock_vchannel_0"}) target.ch = output assert.Equal(t, cap(output), cap(target.ch)) wg := &sync.WaitGroup{} @@ -130,7 +127,8 @@ func TestDispatcher(t *testing.T) { } func BenchmarkDispatcher_handle(b *testing.B) { - d, err := NewDispatcher(context.Background(), newMockFactory(), true, "mock_pchannel_0", nil, "mock_subName_0", common.SubscriptionPositionEarliest, nil, nil, false) + d, err := NewDispatcher(context.Background(), newMockFactory(), time.Now().UnixNano(), "mock_pchannel_0", + nil, common.SubscriptionPositionEarliest, 0) assert.NoError(b, err) for i := 0; i < b.N; i++ { @@ -144,10 +142,14 @@ func BenchmarkDispatcher_handle(b *testing.B) { } func TestGroupMessage(t *testing.T) { - d, err := NewDispatcher(context.Background(), newMockFactory(), true, "mock_pchannel_0", nil, "mock_subName_0"+fmt.Sprintf("%d", rand.Int()), common.SubscriptionPositionEarliest, nil, nil, false) + d, err := NewDispatcher(context.Background(), newMockFactory(), time.Now().UnixNano(), "mock_pchannel_0", + nil, common.SubscriptionPositionEarliest, 0) assert.NoError(t, err) - d.AddTarget(newTarget("mock_pchannel_0_1v0", nil, nil)) - d.AddTarget(newTarget("mock_pchannel_0_2v0", nil, msgstream.GetReplicateConfig("local-test", "foo", "coo"))) + d.AddTarget(newTarget(&StreamConfig{VChannel: "mock_pchannel_0_1v0"})) + d.AddTarget(newTarget(&StreamConfig{ + VChannel: "mock_pchannel_0_2v0", + ReplicateConfig: msgstream.GetReplicateConfig("local-test", "foo", "coo"), + })) { // no replicate msg packs := d.groupingMsgs(&MsgPack{ @@ -286,7 +288,8 @@ func TestGroupMessage(t *testing.T) { { // replicate end - replicateTarget := d.targets["mock_pchannel_0_2v0"] + replicateTarget, ok := d.targets.Get("mock_pchannel_0_2v0") + assert.True(t, ok) assert.NotNil(t, replicateTarget.replicateConfig) packs := d.groupingMsgs(&MsgPack{ BeginTs: 1, diff --git a/pkg/mq/msgdispatcher/manager.go b/pkg/mq/msgdispatcher/manager.go index 13fbe06e65..9f8d3378b6 100644 --- a/pkg/mq/msgdispatcher/manager.go +++ b/pkg/mq/msgdispatcher/manager.go @@ -19,18 +19,20 @@ package msgdispatcher import ( "context" "fmt" + "sort" "sync" "time" "github.com/prometheus/client_golang/prometheus" + "github.com/samber/lo" + "go.uber.org/atomic" "go.uber.org/zap" "github.com/milvus-io/milvus/pkg/v2/log" "github.com/milvus-io/milvus/pkg/v2/metrics" - "github.com/milvus-io/milvus/pkg/v2/mq/common" "github.com/milvus-io/milvus/pkg/v2/mq/msgstream" 
"github.com/milvus-io/milvus/pkg/v2/util/paramtable" - "github.com/milvus-io/milvus/pkg/v2/util/retry" + "github.com/milvus-io/milvus/pkg/v2/util/timerecord" "github.com/milvus-io/milvus/pkg/v2/util/tsoutil" "github.com/milvus-io/milvus/pkg/v2/util/typeutil" ) @@ -51,121 +53,67 @@ type dispatcherManager struct { nodeID int64 pchannel string - lagNotifyChan chan struct{} - lagTargets *typeutil.ConcurrentMap[string, *target] // vchannel -> *target + registeredTargets *typeutil.ConcurrentMap[string, *target] - mu sync.RWMutex // guards mainDispatcher and soloDispatchers - mainDispatcher *Dispatcher - soloDispatchers map[string]*Dispatcher // vchannel -> *Dispatcher + mu sync.RWMutex + mainDispatcher *Dispatcher + deputyDispatchers map[int64]*Dispatcher // ID -> *Dispatcher - factory msgstream.Factory - closeChan chan struct{} - closeOnce sync.Once + idAllocator atomic.Int64 + factory msgstream.Factory + closeChan chan struct{} + closeOnce sync.Once } func NewDispatcherManager(pchannel string, role string, nodeID int64, factory msgstream.Factory) DispatcherManager { log.Info("create new dispatcherManager", zap.String("role", role), zap.Int64("nodeID", nodeID), zap.String("pchannel", pchannel)) c := &dispatcherManager{ - role: role, - nodeID: nodeID, - pchannel: pchannel, - lagNotifyChan: make(chan struct{}, 1), - lagTargets: typeutil.NewConcurrentMap[string, *target](), - soloDispatchers: make(map[string]*Dispatcher), - factory: factory, - closeChan: make(chan struct{}), + role: role, + nodeID: nodeID, + pchannel: pchannel, + registeredTargets: typeutil.NewConcurrentMap[string, *target](), + deputyDispatchers: make(map[int64]*Dispatcher), + factory: factory, + closeChan: make(chan struct{}), } return c } -func (c *dispatcherManager) constructSubName(vchannel string, isMain bool) string { - return fmt.Sprintf("%s-%d-%s-%t", c.role, c.nodeID, vchannel, isMain) -} - func (c *dispatcherManager) Add(ctx context.Context, streamConfig *StreamConfig) (<-chan *MsgPack, error) { - vchannel := streamConfig.VChannel - log := log.With(zap.String("role", c.role), - zap.Int64("nodeID", c.nodeID), zap.String("vchannel", vchannel)) - - c.mu.Lock() - defer c.mu.Unlock() - if _, ok := c.soloDispatchers[vchannel]; ok { - // current dispatcher didn't allow multiple subscriptions on same vchannel at same time - log.Warn("unreachable: solo vchannel dispatcher already exists") - return nil, fmt.Errorf("solo vchannel dispatcher already exists") + t := newTarget(streamConfig) + if _, ok := c.registeredTargets.GetOrInsert(t.vchannel, t); ok { + return nil, fmt.Errorf("vchannel %s already exists in the dispatcher", t.vchannel) } - if c.mainDispatcher != nil { - if _, err := c.mainDispatcher.GetTarget(vchannel); err == nil { - // current dispatcher didn't allow multiple subscriptions on same vchannel at same time - log.Warn("unreachable: vchannel has been registered in main dispatcher, ") - return nil, fmt.Errorf("vchannel has been registered in main dispatcher") - } - } - - isMain := c.mainDispatcher == nil - d, err := NewDispatcher(ctx, c.factory, isMain, c.pchannel, streamConfig.Pos, c.constructSubName(vchannel, isMain), streamConfig.SubPos, c.lagNotifyChan, c.lagTargets, false) - if err != nil { - return nil, err - } - t := newTarget(vchannel, streamConfig.Pos, streamConfig.ReplicateConfig) - d.AddTarget(t) - if isMain { - c.mainDispatcher = d - log.Info("add main dispatcher") - } else { - c.soloDispatchers[vchannel] = d - log.Info("add solo dispatcher") - } - d.Handle(start) + log.Ctx(ctx).Info("target register 
done", zap.String("vchannel", t.vchannel)) return t.ch, nil } func (c *dispatcherManager) Remove(vchannel string) { - log := log.With(zap.String("role", c.role), - zap.Int64("nodeID", c.nodeID), zap.String("vchannel", vchannel)) - c.mu.Lock() - defer c.mu.Unlock() - if _, ok := c.soloDispatchers[vchannel]; ok { - c.soloDispatchers[vchannel].Handle(terminate) - c.soloDispatchers[vchannel].CloseTarget(vchannel) - delete(c.soloDispatchers, vchannel) - c.deleteMetric(vchannel) - log.Info("remove soloDispatcher done") + t, ok := c.registeredTargets.GetAndRemove(vchannel) + if !ok { + log.Info("the target was not registered before", zap.String("role", c.role), + zap.Int64("nodeID", c.nodeID), zap.String("vchannel", vchannel)) + return } - - if c.mainDispatcher != nil { - c.mainDispatcher.Handle(pause) - c.mainDispatcher.CloseTarget(vchannel) - if c.mainDispatcher.TargetNum() == 0 && len(c.soloDispatchers) == 0 { - c.mainDispatcher.Handle(terminate) - c.mainDispatcher = nil - log.Info("remove mainDispatcher done") - } else { - c.mainDispatcher.Handle(resume) - } - } - c.lagTargets.GetAndRemove(vchannel) + c.removeTargetFromDispatcher(t) + t.close() } func (c *dispatcherManager) NumTarget() int { - c.mu.RLock() - defer c.mu.RUnlock() - var res int - if c.mainDispatcher != nil { - res += c.mainDispatcher.TargetNum() - } - return res + len(c.soloDispatchers) + c.lagTargets.Len() + return c.registeredTargets.Len() } func (c *dispatcherManager) NumConsumer() int { c.mu.RLock() defer c.mu.RUnlock() - var res int + + numConsumer := 0 if c.mainDispatcher != nil { - res++ + numConsumer++ } - return res + len(c.soloDispatchers) + numConsumer += len(c.deputyDispatchers) + return numConsumer } func (c *dispatcherManager) Close() { @@ -175,8 +123,7 @@ func (c *dispatcherManager) Close() { } func (c *dispatcherManager) Run() { - log := log.With(zap.String("role", c.role), - zap.Int64("nodeID", c.nodeID), zap.String("pchannel", c.pchannel)) + log := log.With(zap.String("role", c.role), zap.Int64("nodeID", c.nodeID), zap.String("pchannel", c.pchannel)) log.Info("dispatcherManager is running...") ticker1 := time.NewTicker(10 * time.Second) ticker2 := time.NewTicker(paramtable.Get().MQCfg.MergeCheckInterval.GetAsDuration(time.Second)) @@ -190,87 +137,232 @@ func (c *dispatcherManager) Run() { case <-ticker1.C: c.uploadMetric() case <-ticker2.C: + c.tryRemoveUnregisteredTargets() + c.tryBuildDispatcher() c.tryMerge() - case <-c.lagNotifyChan: - c.mu.Lock() - c.lagTargets.Range(func(vchannel string, t *target) bool { - c.split(t) - c.lagTargets.GetAndRemove(vchannel) - return true - }) - c.mu.Unlock() } } } -func (c *dispatcherManager) tryMerge() { - log := log.With(zap.String("role", c.role), zap.Int64("nodeID", c.nodeID)) +func (c *dispatcherManager) removeTargetFromDispatcher(t *target) { + log := log.With(zap.String("role", c.role), zap.Int64("nodeID", c.nodeID), zap.String("pchannel", c.pchannel)) + c.mu.Lock() + defer c.mu.Unlock() + for _, dispatcher := range c.deputyDispatchers { + if dispatcher.HasTarget(t.vchannel) { + dispatcher.Handle(pause) + dispatcher.RemoveTarget(t.vchannel) + if dispatcher.TargetNum() == 0 { + dispatcher.Handle(terminate) + delete(c.deputyDispatchers, dispatcher.ID()) + log.Info("remove deputy dispatcher done", zap.Int64("id", dispatcher.ID())) + } else { + dispatcher.Handle(resume) + } + t.close() + } + } + if c.mainDispatcher != nil { + if c.mainDispatcher.HasTarget(t.vchannel) { + c.mainDispatcher.Handle(pause) + c.mainDispatcher.RemoveTarget(t.vchannel) + if 
c.mainDispatcher.TargetNum() == 0 && len(c.deputyDispatchers) == 0 { + c.mainDispatcher.Handle(terminate) + c.mainDispatcher = nil + } else { + c.mainDispatcher.Handle(resume) + } + t.close() + } + } +} + +func (c *dispatcherManager) tryRemoveUnregisteredTargets() { + unregisteredTargets := make([]*target, 0) + c.mu.RLock() + for _, dispatcher := range c.deputyDispatchers { + for _, t := range dispatcher.GetTargets() { + if !c.registeredTargets.Contain(t.vchannel) { + unregisteredTargets = append(unregisteredTargets, t) + } + } + } + if c.mainDispatcher != nil { + for _, t := range c.mainDispatcher.GetTargets() { + if !c.registeredTargets.Contain(t.vchannel) { + unregisteredTargets = append(unregisteredTargets, t) + } + } + } + c.mu.RUnlock() + + for _, t := range unregisteredTargets { + c.removeTargetFromDispatcher(t) + } +} + +func (c *dispatcherManager) tryBuildDispatcher() { + tr := timerecord.NewTimeRecorder("") + log := log.With(zap.String("role", c.role), zap.Int64("nodeID", c.nodeID), zap.String("pchannel", c.pchannel)) + + allTargets := c.registeredTargets.Values() + // get lack targets to perform subscription + lackTargets := make([]*target, 0, len(allTargets)) + + c.mu.RLock() +OUTER: + for _, t := range allTargets { + if c.mainDispatcher != nil && c.mainDispatcher.HasTarget(t.vchannel) { + continue + } + for _, dispatcher := range c.deputyDispatchers { + if dispatcher.HasTarget(t.vchannel) { + continue OUTER + } + } + lackTargets = append(lackTargets, t) + } + c.mu.RUnlock() + + if len(lackTargets) == 0 { + return + } + + sort.Slice(lackTargets, func(i, j int) bool { + return lackTargets[i].pos.GetTimestamp() < lackTargets[j].pos.GetTimestamp() + }) + + // To prevent the position gap between targets from becoming too large and causing excessive pull-back time, + // limit the position difference between targets to no more than 60 minutes. + earliestTarget := lackTargets[0] + candidateTargets := make([]*target, 0, len(lackTargets)) + for _, t := range lackTargets { + if tsoutil.PhysicalTime(t.pos.GetTimestamp()).Sub( + tsoutil.PhysicalTime(earliestTarget.pos.GetTimestamp())) <= + paramtable.Get().MQCfg.MaxPositionTsGap.GetAsDuration(time.Minute) { + candidateTargets = append(candidateTargets, t) + } + } + + vchannels := lo.Map(candidateTargets, func(t *target, _ int) string { + return t.vchannel + }) + log.Info("start to build dispatchers", zap.Int("numTargets", len(vchannels)), + zap.Strings("vchannels", vchannels)) + + // dispatcher will pull back from the earliest position + // to the latest position in lack targets. 
+ latestTarget := candidateTargets[len(candidateTargets)-1] + + // TODO: add newDispatcher timeout param and init context + id := c.idAllocator.Inc() + d, err := NewDispatcher(context.Background(), c.factory, id, c.pchannel, earliestTarget.pos, earliestTarget.subPos, latestTarget.pos.GetTimestamp()) + if err != nil { + panic(err) + } + for _, t := range candidateTargets { + d.AddTarget(t) + } + d.Handle(start) + buildDur := tr.RecordSpan() + + // block util pullback to the latest target position + if len(candidateTargets) > 1 { + d.BlockUtilPullbackDone() + } + + var ( + pullbackBeginTs = earliestTarget.pos.GetTimestamp() + pullbackEndTs = latestTarget.pos.GetTimestamp() + pullbackBeginTime = tsoutil.PhysicalTime(pullbackBeginTs) + pullbackEndTime = tsoutil.PhysicalTime(pullbackEndTs) + ) + log.Info("build dispatcher done", + zap.Int64("id", d.ID()), + zap.Int("numVchannels", len(vchannels)), + zap.Uint64("pullbackBeginTs", pullbackBeginTs), + zap.Uint64("pullbackEndTs", pullbackEndTs), + zap.Duration("lag", pullbackEndTime.Sub(pullbackBeginTime)), + zap.Time("pullbackBeginTime", pullbackBeginTime), + zap.Time("pullbackEndTime", pullbackEndTime), + zap.Duration("buildDur", buildDur), + zap.Duration("pullbackDur", tr.RecordSpan()), + zap.Strings("vchannels", vchannels), + ) c.mu.Lock() defer c.mu.Unlock() + + d.Handle(pause) + for _, candidate := range candidateTargets { + vchannel := candidate.vchannel + t, ok := c.registeredTargets.Get(vchannel) + // During the build process, the target may undergo repeated deregister and register, + // causing the channel object to change. Here, validate whether the channel is the + // same as before the build. If inconsistent, remove the target. + if !ok || t.ch != candidate.ch { + d.RemoveTarget(vchannel) + } + } + d.Handle(resume) + if c.mainDispatcher == nil { + c.mainDispatcher = d + log.Info("add main dispatcher", zap.Int64("id", d.ID())) + } else { + c.deputyDispatchers[d.ID()] = d + log.Info("add deputy dispatcher", zap.Int64("id", d.ID())) + } +} + +func (c *dispatcherManager) tryMerge() { + c.mu.Lock() + defer c.mu.Unlock() + + start := time.Now() + log := log.With(zap.String("role", c.role), zap.Int64("nodeID", c.nodeID), zap.String("pchannel", c.pchannel)) + if c.mainDispatcher == nil || c.mainDispatcher.CurTs() == 0 { return } - candidates := make(map[string]struct{}) - for vchannel, sd := range c.soloDispatchers { + candidates := make([]*Dispatcher, 0, len(c.deputyDispatchers)) + for _, sd := range c.deputyDispatchers { if sd.CurTs() == c.mainDispatcher.CurTs() { - candidates[vchannel] = struct{}{} + candidates = append(candidates, sd) } } if len(candidates) == 0 { return } - log.Info("start merging...", zap.Any("vchannel", candidates)) + dispatcherIDs := lo.Map(candidates, func(d *Dispatcher, _ int) int64 { + return d.ID() + }) + + log.Info("start merging...", zap.Int64s("dispatchers", dispatcherIDs)) + mergeCandidates := make([]*Dispatcher, 0, len(candidates)) c.mainDispatcher.Handle(pause) - for vchannel := range candidates { - c.soloDispatchers[vchannel].Handle(pause) + for _, dispatcher := range candidates { + dispatcher.Handle(pause) // after pause, check alignment again, if not, evict it and try to merge next time - if c.mainDispatcher.CurTs() != c.soloDispatchers[vchannel].CurTs() { - c.soloDispatchers[vchannel].Handle(resume) - delete(candidates, vchannel) + if c.mainDispatcher.CurTs() != dispatcher.CurTs() { + dispatcher.Handle(resume) + continue } + mergeCandidates = append(mergeCandidates, dispatcher) } mergeTs := 
c.mainDispatcher.CurTs() - for vchannel := range candidates { - t, err := c.soloDispatchers[vchannel].GetTarget(vchannel) - if err == nil { + for _, dispatcher := range mergeCandidates { + targets := dispatcher.GetTargets() + for _, t := range targets { c.mainDispatcher.AddTarget(t) } - c.soloDispatchers[vchannel].Handle(terminate) - delete(c.soloDispatchers, vchannel) - c.deleteMetric(vchannel) + dispatcher.Handle(terminate) + delete(c.deputyDispatchers, dispatcher.ID()) } c.mainDispatcher.Handle(resume) - log.Info("merge done", zap.Any("vchannel", candidates), zap.Uint64("mergeTs", mergeTs)) -} - -func (c *dispatcherManager) split(t *target) { - log := log.With(zap.String("role", c.role), - zap.Int64("nodeID", c.nodeID), zap.String("vchannel", t.vchannel)) - log.Info("start splitting...") - - // remove stale soloDispatcher if it existed - if _, ok := c.soloDispatchers[t.vchannel]; ok { - c.soloDispatchers[t.vchannel].Handle(terminate) - delete(c.soloDispatchers, t.vchannel) - c.deleteMetric(t.vchannel) - } - - var newSolo *Dispatcher - err := retry.Do(context.Background(), func() error { - var err error - newSolo, err = NewDispatcher(context.Background(), c.factory, false, c.pchannel, t.pos, c.constructSubName(t.vchannel, false), common.SubscriptionPositionUnknown, c.lagNotifyChan, c.lagTargets, true) - return err - }, retry.Attempts(10)) - if err != nil { - log.Error("split failed", zap.Error(err)) - panic(err) - } - newSolo.AddTarget(t) - c.soloDispatchers[t.vchannel] = newSolo - newSolo.Handle(start) - log.Info("split done") + log.Info("merge done", zap.Int64s("dispatchers", dispatcherIDs), + zap.Uint64("mergeTs", mergeTs), + zap.Duration("dur", time.Since(start))) } // deleteMetric remove specific prometheus metric, @@ -289,18 +381,21 @@ func (c *dispatcherManager) deleteMetric(channel string) { func (c *dispatcherManager) uploadMetric() { c.mu.RLock() defer c.mu.RUnlock() + nodeIDStr := fmt.Sprintf("%d", c.nodeID) fn := func(gauge *prometheus.GaugeVec) { if c.mainDispatcher == nil { return } - // for main dispatcher, use pchannel as channel label - gauge.WithLabelValues(nodeIDStr, c.pchannel).Set( - float64(time.Since(tsoutil.PhysicalTime(c.mainDispatcher.CurTs())).Milliseconds())) - // for solo dispatchers, use vchannel as channel label - for vchannel, dispatcher := range c.soloDispatchers { - gauge.WithLabelValues(nodeIDStr, vchannel).Set( - float64(time.Since(tsoutil.PhysicalTime(dispatcher.CurTs())).Milliseconds())) + for _, t := range c.mainDispatcher.GetTargets() { + gauge.WithLabelValues(nodeIDStr, t.vchannel).Set( + float64(time.Since(tsoutil.PhysicalTime(c.mainDispatcher.CurTs())).Milliseconds())) + } + for _, dispatcher := range c.deputyDispatchers { + for _, t := range dispatcher.GetTargets() { + gauge.WithLabelValues(nodeIDStr, t.vchannel).Set( + float64(time.Since(tsoutil.PhysicalTime(dispatcher.CurTs())).Milliseconds())) + } } } if c.role == typeutil.DataNodeRole { diff --git a/pkg/mq/msgdispatcher/manager_test.go b/pkg/mq/msgdispatcher/manager_test.go index 83c7c0b856..c5e660b90c 100644 --- a/pkg/mq/msgdispatcher/manager_test.go +++ b/pkg/mq/msgdispatcher/manager_test.go @@ -20,422 +20,207 @@ import ( "context" "fmt" "math/rand" - "reflect" "sync" "testing" "time" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/suite" - "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus/pkg/v2/mq/common" - "github.com/milvus-io/milvus/pkg/v2/mq/msgstream" "github.com/milvus-io/milvus/pkg/v2/util/paramtable" 
"github.com/milvus-io/milvus/pkg/v2/util/typeutil" ) func TestManager(t *testing.T) { + paramtable.Get().Save(paramtable.Get().MQCfg.TargetBufSize.Key, "65536") + defer paramtable.Get().Reset(paramtable.Get().MQCfg.TargetBufSize.Key) + t.Run("test add and remove dispatcher", func(t *testing.T) { - c := NewDispatcherManager("mock_pchannel_0", typeutil.ProxyRole, 1, newMockFactory()) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + pchannel := fmt.Sprintf("by-dev-rootcoord-dml_%d", rand.Int63()) + + factory := newMockFactory() + producer, err := newMockProducer(factory, pchannel) + assert.NoError(t, err) + go produceTimeTick(t, ctx, producer) + + c := NewDispatcherManager(pchannel, typeutil.ProxyRole, 1, factory) assert.NotNil(t, c) + go c.Run() + defer c.Close() assert.Equal(t, 0, c.NumConsumer()) assert.Equal(t, 0, c.NumTarget()) var offset int - for i := 0; i < 100; i++ { - r := rand.Intn(10) + 1 + for i := 0; i < 30; i++ { + r := rand.Intn(5) + 1 for j := 0; j < r; j++ { offset++ - vchannel := fmt.Sprintf("mock-pchannel-dml_0_vchannelv%d", offset) + vchannel := fmt.Sprintf("%s_vchannelv%d", pchannel, offset) t.Logf("add vchannel, %s", vchannel) - _, err := c.Add(context.Background(), NewStreamConfig(vchannel, nil, common.SubscriptionPositionUnknown)) + _, err := c.Add(ctx, NewStreamConfig(vchannel, nil, common.SubscriptionPositionUnknown)) assert.NoError(t, err) - assert.Equal(t, offset, c.NumConsumer()) - assert.Equal(t, offset, c.NumTarget()) } + assert.Eventually(t, func() bool { + t.Logf("offset=%d, numConsumer=%d, numTarget=%d", offset, c.NumConsumer(), c.NumTarget()) + return c.NumTarget() == offset + }, 3*time.Second, 10*time.Millisecond) for j := 0; j < rand.Intn(r); j++ { - vchannel := fmt.Sprintf("mock-pchannel-dml_0_vchannelv%d", offset) + vchannel := fmt.Sprintf("%s_vchannelv%d", pchannel, offset) t.Logf("remove vchannel, %s", vchannel) c.Remove(vchannel) offset-- - assert.Equal(t, offset, c.NumConsumer()) - assert.Equal(t, offset, c.NumTarget()) } + assert.Eventually(t, func() bool { + t.Logf("offset=%d, numConsumer=%d, numTarget=%d", offset, c.NumConsumer(), c.NumTarget()) + return c.NumTarget() == offset + }, 3*time.Second, 10*time.Millisecond) } }) t.Run("test merge and split", func(t *testing.T) { - prefix := fmt.Sprintf("mock%d", time.Now().UnixNano()) - ctx := context.Background() - c := NewDispatcherManager(prefix+"_pchannel_0", typeutil.ProxyRole, 1, newMockFactory()) + paramtable.Get().Save(paramtable.Get().MQCfg.TargetBufSize.Key, "16") + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + pchannel := fmt.Sprintf("by-dev-rootcoord-dml_%d", rand.Int63()) + + factory := newMockFactory() + producer, err := newMockProducer(factory, pchannel) + assert.NoError(t, err) + go produceTimeTick(t, ctx, producer) + + c := NewDispatcherManager(pchannel, typeutil.ProxyRole, 1, factory) assert.NotNil(t, c) - _, err := c.Add(ctx, NewStreamConfig("mock_vchannel_0", nil, common.SubscriptionPositionUnknown)) - assert.NoError(t, err) - _, err = c.Add(ctx, NewStreamConfig("mock_vchannel_1", nil, common.SubscriptionPositionUnknown)) - assert.NoError(t, err) - _, err = c.Add(ctx, NewStreamConfig("mock_vchannel_2", nil, common.SubscriptionPositionUnknown)) - assert.NoError(t, err) - assert.Equal(t, 3, c.NumConsumer()) - assert.Equal(t, 3, c.NumTarget()) - c.(*dispatcherManager).mainDispatcher.curTs.Store(1000) - c.(*dispatcherManager).mu.RLock() - for _, d := range c.(*dispatcherManager).soloDispatchers { - d.curTs.Store(1000) - } - 
c.(*dispatcherManager).mu.RUnlock() - c.(*dispatcherManager).tryMerge() - assert.Equal(t, 1, c.NumConsumer()) + go c.Run() + defer c.Close() + + paramtable.Get().Save(paramtable.Get().MQCfg.MaxTolerantLag.Key, "0.5") + defer paramtable.Get().Reset(paramtable.Get().MQCfg.MaxTolerantLag.Key) + + o0, err := c.Add(ctx, NewStreamConfig(fmt.Sprintf("%s_vchannel-0", pchannel), nil, common.SubscriptionPositionUnknown)) + assert.NoError(t, err) + o1, err := c.Add(ctx, NewStreamConfig(fmt.Sprintf("%s_vchannel-1", pchannel), nil, common.SubscriptionPositionUnknown)) + assert.NoError(t, err) + o2, err := c.Add(ctx, NewStreamConfig(fmt.Sprintf("%s_vchannel-2", pchannel), nil, common.SubscriptionPositionUnknown)) + assert.NoError(t, err) assert.Equal(t, 3, c.NumTarget()) - info := &target{ - vchannel: "mock_vchannel_2", - pos: nil, - ch: nil, + consumeFn := func(output <-chan *MsgPack, done <-chan struct{}, wg *sync.WaitGroup) { + defer wg.Done() + for { + select { + case <-done: + return + case <-output: + } + } } - c.(*dispatcherManager).split(info) - assert.Equal(t, 2, c.NumConsumer()) + wg := &sync.WaitGroup{} + wg.Add(3) + d0 := make(chan struct{}, 1) + d1 := make(chan struct{}, 1) + d2 := make(chan struct{}, 1) + go consumeFn(o0, d0, wg) + go consumeFn(o1, d1, wg) + go consumeFn(o2, d2, wg) + + assert.Eventually(t, func() bool { + return c.NumConsumer() == 1 // expected merge + }, 20*time.Second, 10*time.Millisecond) + + // stop consume vchannel_2 to trigger split + d2 <- struct{}{} + assert.Eventually(t, func() bool { + t.Logf("c.NumConsumer=%d", c.NumConsumer()) + return c.NumConsumer() == 2 // expected split + }, 20*time.Second, 10*time.Millisecond) + + // stop all + d0 <- struct{}{} + d1 <- struct{}{} + wg.Wait() }) t.Run("test run and close", func(t *testing.T) { - prefix := fmt.Sprintf("mock%d", time.Now().UnixNano()) - ctx := context.Background() - c := NewDispatcherManager(prefix+"_pchannel_0", typeutil.ProxyRole, 1, newMockFactory()) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + pchannel := fmt.Sprintf("by-dev-rootcoord-dml_%d", rand.Int63()) + + factory := newMockFactory() + producer, err := newMockProducer(factory, pchannel) + assert.NoError(t, err) + go produceTimeTick(t, ctx, producer) + + c := NewDispatcherManager(pchannel, typeutil.ProxyRole, 1, factory) assert.NotNil(t, c) - _, err := c.Add(ctx, NewStreamConfig("mock_vchannel_0", nil, common.SubscriptionPositionUnknown)) + + go c.Run() + defer c.Close() + + _, err = c.Add(ctx, NewStreamConfig(fmt.Sprintf("%s_vchannel-0", pchannel), nil, common.SubscriptionPositionUnknown)) assert.NoError(t, err) - _, err = c.Add(ctx, NewStreamConfig("mock_vchannel_1", nil, common.SubscriptionPositionUnknown)) + _, err = c.Add(ctx, NewStreamConfig(fmt.Sprintf("%s_vchannel-1", pchannel), nil, common.SubscriptionPositionUnknown)) assert.NoError(t, err) - _, err = c.Add(ctx, NewStreamConfig("mock_vchannel_2", nil, common.SubscriptionPositionUnknown)) + _, err = c.Add(ctx, NewStreamConfig(fmt.Sprintf("%s_vchannel-2", pchannel), nil, common.SubscriptionPositionUnknown)) assert.NoError(t, err) - assert.Equal(t, 3, c.NumConsumer()) assert.Equal(t, 3, c.NumTarget()) + assert.Eventually(t, func() bool { + return c.NumConsumer() >= 1 + }, 3*time.Second, 10*time.Millisecond) c.(*dispatcherManager).mainDispatcher.curTs.Store(1000) - c.(*dispatcherManager).mu.RLock() - for _, d := range c.(*dispatcherManager).soloDispatchers { + for _, d := range c.(*dispatcherManager).deputyDispatchers { d.curTs.Store(1000) } - 
c.(*dispatcherManager).mu.RUnlock() checkIntervalK := paramtable.Get().MQCfg.MergeCheckInterval.Key paramtable.Get().Save(checkIntervalK, "0.01") defer paramtable.Get().Reset(checkIntervalK) - go c.Run() + assert.Eventually(t, func() bool { return c.NumConsumer() == 1 // expected merged }, 3*time.Second, 10*time.Millisecond) assert.Equal(t, 3, c.NumTarget()) - - assert.NotPanics(t, func() { - c.Close() - }) - }) - - t.Run("test add timeout", func(t *testing.T) { - prefix := fmt.Sprintf("mock%d", time.Now().UnixNano()) - ctx := context.Background() - ctx, cancel := context.WithTimeout(ctx, time.Millisecond*2) - defer cancel() - time.Sleep(time.Millisecond * 2) - c := NewDispatcherManager(prefix+"_pchannel_0", typeutil.ProxyRole, 1, newMockFactory()) - go c.Run() - assert.NotNil(t, c) - _, err := c.Add(ctx, NewStreamConfig("mock_vchannel_0", nil, common.SubscriptionPositionUnknown)) - assert.Error(t, err) - _, err = c.Add(ctx, NewStreamConfig("mock_vchannel_1", nil, common.SubscriptionPositionUnknown)) - assert.Error(t, err) - _, err = c.Add(ctx, NewStreamConfig("mock_vchannel_2", nil, common.SubscriptionPositionUnknown)) - assert.Error(t, err) - assert.Equal(t, 0, c.NumConsumer()) - assert.Equal(t, 0, c.NumTarget()) - - assert.NotPanics(t, func() { - c.Close() - }) }) t.Run("test_repeated_vchannel", func(t *testing.T) { - prefix := fmt.Sprintf("mock%d", time.Now().UnixNano()) - c := NewDispatcherManager(prefix+"_pchannel_0", typeutil.ProxyRole, 1, newMockFactory()) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + pchannel := fmt.Sprintf("by-dev-rootcoord-dml_%d", rand.Int63()) + + factory := newMockFactory() + producer, err := newMockProducer(factory, pchannel) + assert.NoError(t, err) + go produceTimeTick(t, ctx, producer) + + c := NewDispatcherManager(pchannel, typeutil.ProxyRole, 1, factory) + go c.Run() + defer c.Close() + assert.NotNil(t, c) - ctx := context.Background() - _, err := c.Add(ctx, NewStreamConfig("mock_vchannel_0", nil, common.SubscriptionPositionUnknown)) + _, err = c.Add(ctx, NewStreamConfig(fmt.Sprintf("%s_vchannel-0", pchannel), nil, common.SubscriptionPositionUnknown)) assert.NoError(t, err) - _, err = c.Add(ctx, NewStreamConfig("mock_vchannel_1", nil, common.SubscriptionPositionUnknown)) + _, err = c.Add(ctx, NewStreamConfig(fmt.Sprintf("%s_vchannel-1", pchannel), nil, common.SubscriptionPositionUnknown)) assert.NoError(t, err) - _, err = c.Add(ctx, NewStreamConfig("mock_vchannel_2", nil, common.SubscriptionPositionUnknown)) + _, err = c.Add(ctx, NewStreamConfig(fmt.Sprintf("%s_vchannel-2", pchannel), nil, common.SubscriptionPositionUnknown)) assert.NoError(t, err) - _, err = c.Add(ctx, NewStreamConfig("mock_vchannel_0", nil, common.SubscriptionPositionUnknown)) + _, err = c.Add(ctx, NewStreamConfig(fmt.Sprintf("%s_vchannel-0", pchannel), nil, common.SubscriptionPositionUnknown)) assert.Error(t, err) - _, err = c.Add(ctx, NewStreamConfig("mock_vchannel_1", nil, common.SubscriptionPositionUnknown)) + _, err = c.Add(ctx, NewStreamConfig(fmt.Sprintf("%s_vchannel-1", pchannel), nil, common.SubscriptionPositionUnknown)) assert.Error(t, err) - _, err = c.Add(ctx, NewStreamConfig("mock_vchannel_2", nil, common.SubscriptionPositionUnknown)) + _, err = c.Add(ctx, NewStreamConfig(fmt.Sprintf("%s_vchannel-2", pchannel), nil, common.SubscriptionPositionUnknown)) assert.Error(t, err) - assert.NotPanics(t, func() { - c.Close() - }) + assert.Eventually(t, func() bool { + return c.NumConsumer() >= 1 + }, 3*time.Second, 10*time.Millisecond) }) } - -type 
vchannelHelper struct { - output <-chan *msgstream.MsgPack - - pubInsMsgNum int - pubDelMsgNum int - pubDDLMsgNum int - pubPackNum int - - subInsMsgNum int - subDelMsgNum int - subDDLMsgNum int - subPackNum int -} - -type SimulationSuite struct { - suite.Suite - - testVchannelNum int - - manager DispatcherManager - pchannel string - vchannels map[string]*vchannelHelper - - producer msgstream.MsgStream - factory msgstream.Factory -} - -func (suite *SimulationSuite) SetupSuite() { - suite.factory = newMockFactory() -} - -func (suite *SimulationSuite) SetupTest() { - suite.pchannel = fmt.Sprintf("by-dev-rootcoord-dispatcher-simulation-dml_%d", time.Now().UnixNano()) - producer, err := newMockProducer(suite.factory, suite.pchannel) - assert.NoError(suite.T(), err) - suite.producer = producer - - suite.manager = NewDispatcherManager(suite.pchannel, typeutil.DataNodeRole, 0, suite.factory) - go suite.manager.Run() -} - -func (suite *SimulationSuite) produceMsg(wg *sync.WaitGroup, collectionID int64) { - defer wg.Done() - - const timeTickCount = 100 - var uniqueMsgID int64 - vchannelKeys := reflect.ValueOf(suite.vchannels).MapKeys() - - for i := 1; i <= timeTickCount; i++ { - // produce random insert - insNum := rand.Intn(10) - for j := 0; j < insNum; j++ { - vchannel := vchannelKeys[rand.Intn(len(vchannelKeys))].Interface().(string) - err := suite.producer.Produce(context.TODO(), &msgstream.MsgPack{ - Msgs: []msgstream.TsMsg{genInsertMsg(rand.Intn(20)+1, vchannel, uniqueMsgID)}, - }) - assert.NoError(suite.T(), err) - uniqueMsgID++ - suite.vchannels[vchannel].pubInsMsgNum++ - } - // produce random delete - delNum := rand.Intn(2) - for j := 0; j < delNum; j++ { - vchannel := vchannelKeys[rand.Intn(len(vchannelKeys))].Interface().(string) - err := suite.producer.Produce(context.TODO(), &msgstream.MsgPack{ - Msgs: []msgstream.TsMsg{genDeleteMsg(rand.Intn(20)+1, vchannel, uniqueMsgID)}, - }) - assert.NoError(suite.T(), err) - uniqueMsgID++ - suite.vchannels[vchannel].pubDelMsgNum++ - } - // produce random ddl - ddlNum := rand.Intn(2) - for j := 0; j < ddlNum; j++ { - err := suite.producer.Produce(context.TODO(), &msgstream.MsgPack{ - Msgs: []msgstream.TsMsg{genDDLMsg(commonpb.MsgType_DropCollection, collectionID)}, - }) - assert.NoError(suite.T(), err) - for k := range suite.vchannels { - suite.vchannels[k].pubDDLMsgNum++ - } - } - // produce time tick - ts := uint64(i * 100) - err := suite.producer.Produce(context.TODO(), &msgstream.MsgPack{ - Msgs: []msgstream.TsMsg{genTimeTickMsg(ts)}, - }) - assert.NoError(suite.T(), err) - for k := range suite.vchannels { - suite.vchannels[k].pubPackNum++ - } - } - suite.T().Logf("[%s] produce %d msgPack for %s done", time.Now(), timeTickCount, suite.pchannel) -} - -func (suite *SimulationSuite) consumeMsg(ctx context.Context, wg *sync.WaitGroup, vchannel string) { - defer wg.Done() - var lastTs typeutil.Timestamp - for { - select { - case <-ctx.Done(): - return - case pack := <-suite.vchannels[vchannel].output: - assert.Greater(suite.T(), pack.EndTs, lastTs) - lastTs = pack.EndTs - helper := suite.vchannels[vchannel] - helper.subPackNum++ - for _, msg := range pack.Msgs { - switch msg.Type() { - case commonpb.MsgType_Insert: - helper.subInsMsgNum++ - case commonpb.MsgType_Delete: - helper.subDelMsgNum++ - case commonpb.MsgType_CreateCollection, commonpb.MsgType_DropCollection, - commonpb.MsgType_CreatePartition, commonpb.MsgType_DropPartition: - helper.subDDLMsgNum++ - } - } - } - } -} - -func (suite *SimulationSuite) produceTimeTickOnly(ctx context.Context) { 
- tt := 1 - ticker := time.NewTicker(100 * time.Millisecond) - defer ticker.Stop() - for { - select { - case <-ctx.Done(): - return - case <-ticker.C: - ts := uint64(tt * 1000) - err := suite.producer.Produce(ctx, &msgstream.MsgPack{ - Msgs: []msgstream.TsMsg{genTimeTickMsg(ts)}, - }) - assert.NoError(suite.T(), err) - tt++ - } - } -} - -func (suite *SimulationSuite) TestDispatchToVchannels() { - ctx, cancel := context.WithTimeout(context.Background(), 5000*time.Millisecond) - defer cancel() - - const ( - vchannelNum = 10 - collectionID int64 = 1234 - ) - suite.vchannels = make(map[string]*vchannelHelper, vchannelNum) - for i := 0; i < vchannelNum; i++ { - vchannel := fmt.Sprintf("%s_%dv%d", suite.pchannel, collectionID, i) - output, err := suite.manager.Add(context.Background(), NewStreamConfig(vchannel, nil, common.SubscriptionPositionEarliest)) - assert.NoError(suite.T(), err) - suite.vchannels[vchannel] = &vchannelHelper{output: output} - } - - wg := &sync.WaitGroup{} - wg.Add(1) - go suite.produceMsg(wg, collectionID) - wg.Wait() - for vchannel := range suite.vchannels { - wg.Add(1) - go suite.consumeMsg(ctx, wg, vchannel) - } - wg.Wait() - for vchannel, helper := range suite.vchannels { - msg := fmt.Sprintf("vchannel=%s", vchannel) - assert.Equal(suite.T(), helper.pubInsMsgNum, helper.subInsMsgNum, msg) - assert.Equal(suite.T(), helper.pubDelMsgNum, helper.subDelMsgNum, msg) - assert.Equal(suite.T(), helper.pubDDLMsgNum, helper.subDDLMsgNum, msg) - assert.Equal(suite.T(), helper.pubPackNum, helper.subPackNum, msg) - } -} - -func (suite *SimulationSuite) TestMerge() { - ctx, cancel := context.WithCancel(context.Background()) - go suite.produceTimeTickOnly(ctx) - - const vchannelNum = 10 - suite.vchannels = make(map[string]*vchannelHelper, vchannelNum) - positions, err := getSeekPositions(suite.factory, suite.pchannel, 100) - assert.NoError(suite.T(), err) - assert.NotEqual(suite.T(), 0, len(positions)) - - for i := 0; i < vchannelNum; i++ { - vchannel := fmt.Sprintf("%s_vchannelv%d", suite.pchannel, i) - output, err := suite.manager.Add(context.Background(), NewStreamConfig( - vchannel, positions[rand.Intn(len(positions))], - common.SubscriptionPositionUnknown, - )) // seek from random position - assert.NoError(suite.T(), err) - suite.vchannels[vchannel] = &vchannelHelper{output: output} - } - wg := &sync.WaitGroup{} - for vchannel := range suite.vchannels { - wg.Add(1) - go suite.consumeMsg(ctx, wg, vchannel) - } - - suite.Eventually(func() bool { - suite.T().Logf("dispatcherManager.dispatcherNum = %d", suite.manager.NumConsumer()) - return suite.manager.NumConsumer() == 1 // expected all merged, only mainDispatcher exist - }, 15*time.Second, 100*time.Millisecond) - assert.Equal(suite.T(), vchannelNum, suite.manager.NumTarget()) - - cancel() - wg.Wait() -} - -func (suite *SimulationSuite) TestSplit() { - ctx, cancel := context.WithCancel(context.Background()) - go suite.produceTimeTickOnly(ctx) - - const ( - vchannelNum = 10 - splitNum = 3 - ) - suite.vchannels = make(map[string]*vchannelHelper, vchannelNum) - maxTolerantLagK := paramtable.Get().MQCfg.MaxTolerantLag.Key - paramtable.Get().Save(maxTolerantLagK, "0.5") - defer paramtable.Get().Reset(maxTolerantLagK) - - targetBufSizeK := paramtable.Get().MQCfg.TargetBufSize.Key - defer paramtable.Get().Reset(targetBufSizeK) - - for i := 0; i < vchannelNum; i++ { - paramtable.Get().Save(targetBufSizeK, "65536") - if i >= vchannelNum-splitNum { - paramtable.Get().Save(targetBufSizeK, "10") - } - vchannel := fmt.Sprintf("%s_vchannelv%d", 
suite.pchannel, i) - _, err := suite.manager.Add(context.Background(), NewStreamConfig(vchannel, nil, common.SubscriptionPositionEarliest)) - assert.NoError(suite.T(), err) - } - - suite.Eventually(func() bool { - suite.T().Logf("dispatcherManager.dispatcherNum = %d, splitNum+1 = %d", suite.manager.NumConsumer(), splitNum+1) - return suite.manager.NumConsumer() == splitNum+1 // expected 1 mainDispatcher and `splitNum` soloDispatchers - }, 10*time.Second, 100*time.Millisecond) - assert.Equal(suite.T(), vchannelNum, suite.manager.NumTarget()) - - cancel() -} - -func (suite *SimulationSuite) TearDownTest() { - for vchannel := range suite.vchannels { - suite.manager.Remove(vchannel) - } - suite.manager.Close() -} - -func (suite *SimulationSuite) TearDownSuite() { -} - -func TestSimulation(t *testing.T) { - suite.Run(t, new(SimulationSuite)) -} diff --git a/pkg/mq/msgdispatcher/mock_test.go b/pkg/mq/msgdispatcher/mock_test.go index 3b6892a2ce..e9bcb3cca0 100644 --- a/pkg/mq/msgdispatcher/mock_test.go +++ b/pkg/mq/msgdispatcher/mock_test.go @@ -21,14 +21,20 @@ import ( "fmt" "math/rand" "os" + "sync" "testing" "time" + "github.com/samber/lo" + "github.com/stretchr/testify/assert" + "go.uber.org/atomic" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/pkg/v2/mq/common" "github.com/milvus-io/milvus/pkg/v2/mq/msgstream" + "github.com/milvus-io/milvus/pkg/v2/util/funcutil" "github.com/milvus-io/milvus/pkg/v2/util/paramtable" "github.com/milvus-io/milvus/pkg/v2/util/typeutil" ) @@ -55,34 +61,11 @@ func newMockProducer(factory msgstream.Factory, pchannel string) (msgstream.MsgS if err != nil { return nil, err } - stream.AsProducer(context.TODO(), []string{pchannel}) + stream.AsProducer(context.Background(), []string{pchannel}) stream.SetRepackFunc(defaultInsertRepackFunc) return stream, nil } -func getSeekPositions(factory msgstream.Factory, pchannel string, maxNum int) ([]*msgstream.MsgPosition, error) { - stream, err := factory.NewTtMsgStream(context.Background()) - if err != nil { - return nil, err - } - defer stream.Close() - stream.AsConsumer(context.TODO(), []string{pchannel}, fmt.Sprintf("%d", rand.Int()), common.SubscriptionPositionEarliest) - positions := make([]*msgstream.MsgPosition, 0) - timeoutCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - for { - select { - case <-timeoutCtx.Done(): // no message to consume - return positions, nil - case pack := <-stream.Chan(): - positions = append(positions, pack.EndPositions[0]) - if len(positions) >= maxNum { - return positions, nil - } - } - } -} - func genPKs(numRows int) []typeutil.IntPrimaryKey { ids := make([]typeutil.IntPrimaryKey, numRows) for i := 0; i < numRows; i++ { @@ -91,15 +74,15 @@ func genPKs(numRows int) []typeutil.IntPrimaryKey { return ids } -func genTimestamps(numRows int) []typeutil.Timestamp { - ts := make([]typeutil.Timestamp, numRows) +func genTimestamps(numRows int, ts typeutil.Timestamp) []typeutil.Timestamp { + tss := make([]typeutil.Timestamp, numRows) for i := 0; i < numRows; i++ { - ts[i] = typeutil.Timestamp(i + 1) + tss[i] = ts } - return ts + return tss } -func genInsertMsg(numRows int, vchannel string, msgID typeutil.UniqueID) *msgstream.InsertMsg { +func genInsertMsg(numRows int, vchannel string, msgID typeutil.UniqueID, ts typeutil.Timestamp) *msgstream.InsertMsg { floatVec := make([]float32, numRows*dim) for i := 0; i < 
numRows*dim; i++ { floatVec[i] = rand.Float32() @@ -111,9 +94,9 @@ func genInsertMsg(numRows int, vchannel string, msgID typeutil.UniqueID) *msgstr return &msgstream.InsertMsg{ BaseMsg: msgstream.BaseMsg{HashValues: hashValues}, InsertRequest: &msgpb.InsertRequest{ - Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_Insert, MsgID: msgID}, + Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_Insert, MsgID: msgID, Timestamp: ts}, ShardName: vchannel, - Timestamps: genTimestamps(numRows), + Timestamps: genTimestamps(numRows, ts), RowIDs: genPKs(numRows), FieldsData: []*schemapb.FieldData{{ Field: &schemapb.FieldData_Vectors{ @@ -129,11 +112,11 @@ func genInsertMsg(numRows int, vchannel string, msgID typeutil.UniqueID) *msgstr } } -func genDeleteMsg(numRows int, vchannel string, msgID typeutil.UniqueID) *msgstream.DeleteMsg { +func genDeleteMsg(numRows int, vchannel string, msgID typeutil.UniqueID, ts typeutil.Timestamp) *msgstream.DeleteMsg { return &msgstream.DeleteMsg{ BaseMsg: msgstream.BaseMsg{HashValues: make([]uint32, numRows)}, DeleteRequest: &msgpb.DeleteRequest{ - Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_Delete, MsgID: msgID}, + Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_Delete, MsgID: msgID, Timestamp: ts}, ShardName: vchannel, PrimaryKeys: &schemapb.IDs{ IdField: &schemapb.IDs_IntId{ @@ -142,19 +125,19 @@ func genDeleteMsg(numRows int, vchannel string, msgID typeutil.UniqueID) *msgstr }, }, }, - Timestamps: genTimestamps(numRows), + Timestamps: genTimestamps(numRows, ts), NumRows: int64(numRows), }, } } -func genDDLMsg(msgType commonpb.MsgType, collectionID int64) msgstream.TsMsg { +func genDDLMsg(msgType commonpb.MsgType, collectionID int64, ts typeutil.Timestamp) msgstream.TsMsg { switch msgType { case commonpb.MsgType_CreateCollection: return &msgstream.CreateCollectionMsg{ BaseMsg: msgstream.BaseMsg{HashValues: []uint32{0}}, CreateCollectionRequest: &msgpb.CreateCollectionRequest{ - Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_CreateCollection}, + Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_CreateCollection, Timestamp: ts}, CollectionID: collectionID, }, } @@ -162,7 +145,7 @@ func genDDLMsg(msgType commonpb.MsgType, collectionID int64) msgstream.TsMsg { return &msgstream.DropCollectionMsg{ BaseMsg: msgstream.BaseMsg{HashValues: []uint32{0}}, DropCollectionRequest: &msgpb.DropCollectionRequest{ - Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_DropCollection}, + Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_DropCollection, Timestamp: ts}, CollectionID: collectionID, }, } @@ -170,7 +153,7 @@ func genDDLMsg(msgType commonpb.MsgType, collectionID int64) msgstream.TsMsg { return &msgstream.CreatePartitionMsg{ BaseMsg: msgstream.BaseMsg{HashValues: []uint32{0}}, CreatePartitionRequest: &msgpb.CreatePartitionRequest{ - Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_CreatePartition}, + Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_CreatePartition, Timestamp: ts}, CollectionID: collectionID, }, } @@ -178,7 +161,7 @@ func genDDLMsg(msgType commonpb.MsgType, collectionID int64) msgstream.TsMsg { return &msgstream.DropPartitionMsg{ BaseMsg: msgstream.BaseMsg{HashValues: []uint32{0}}, DropPartitionRequest: &msgpb.DropPartitionRequest{ - Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_DropPartition}, + Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_DropPartition, Timestamp: ts}, CollectionID: collectionID, }, } @@ -226,3 +209,196 @@ func defaultInsertRepackFunc( } return pack, nil } + +type vchannelHelper struct { + output <-chan *msgstream.MsgPack 
+ + pubInsMsgNum atomic.Int32 + pubDelMsgNum atomic.Int32 + pubDDLMsgNum atomic.Int32 + pubPackNum atomic.Int32 + + subInsMsgNum atomic.Int32 + subDelMsgNum atomic.Int32 + subDDLMsgNum atomic.Int32 + subPackNum atomic.Int32 + + seekPos *Pos + skippedInsMsgNum int32 + skippedDelMsgNum int32 + skippedDDLMsgNum int32 + skippedPackNum int32 +} + +func produceMsgs(t *testing.T, ctx context.Context, wg *sync.WaitGroup, producer msgstream.MsgStream, vchannels map[string]*vchannelHelper) { + defer wg.Done() + + uniqueMsgID := int64(0) + vchannelNames := lo.Keys(vchannels) + ticker := time.NewTicker(10 * time.Millisecond) + defer ticker.Stop() + + i := 1 + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + ts := uint64(i * 100) + // produce random insert + insNum := rand.Intn(10) + for j := 0; j < insNum; j++ { + vchannel := vchannelNames[rand.Intn(len(vchannels))] + err := producer.Produce(context.Background(), &msgstream.MsgPack{ + Msgs: []msgstream.TsMsg{genInsertMsg(rand.Intn(20)+1, vchannel, uniqueMsgID, ts)}, + }) + assert.NoError(t, err) + uniqueMsgID++ + vchannels[vchannel].pubInsMsgNum.Inc() + } + // produce random delete + delNum := rand.Intn(2) + for j := 0; j < delNum; j++ { + vchannel := vchannelNames[rand.Intn(len(vchannels))] + err := producer.Produce(context.Background(), &msgstream.MsgPack{ + Msgs: []msgstream.TsMsg{genDeleteMsg(rand.Intn(10)+1, vchannel, uniqueMsgID, ts)}, + }) + assert.NoError(t, err) + uniqueMsgID++ + vchannels[vchannel].pubDelMsgNum.Inc() + } + // produce random ddl + ddlNum := rand.Intn(2) + for j := 0; j < ddlNum; j++ { + vchannel := vchannelNames[rand.Intn(len(vchannels))] + collectionID := funcutil.GetCollectionIDFromVChannel(vchannel) + err := producer.Produce(context.Background(), &msgstream.MsgPack{ + Msgs: []msgstream.TsMsg{genDDLMsg(commonpb.MsgType_DropCollection, collectionID, ts)}, + }) + assert.NoError(t, err) + uniqueMsgID++ + vchannels[vchannel].pubDDLMsgNum.Inc() + } + // produce time tick + err := producer.Produce(context.Background(), &msgstream.MsgPack{ + Msgs: []msgstream.TsMsg{genTimeTickMsg(ts)}, + }) + assert.NoError(t, err) + for k := range vchannels { + vchannels[k].pubPackNum.Inc() + } + i++ + } + } +} + +func consumeMsgsFromTargets(t *testing.T, ctx context.Context, wg *sync.WaitGroup, vchannel string, helper *vchannelHelper) { + defer wg.Done() + + var lastTs typeutil.Timestamp + for { + select { + case <-ctx.Done(): + return + case pack := <-helper.output: + if pack == nil || pack.EndTs == 0 { + continue + } + assert.Greater(t, pack.EndTs, lastTs, fmt.Sprintf("vchannel=%s", vchannel)) + lastTs = pack.EndTs + helper.subPackNum.Inc() + for _, msg := range pack.Msgs { + switch msg.Type() { + case commonpb.MsgType_Insert: + helper.subInsMsgNum.Inc() + case commonpb.MsgType_Delete: + helper.subDelMsgNum.Inc() + case commonpb.MsgType_CreateCollection, commonpb.MsgType_DropCollection, + commonpb.MsgType_CreatePartition, commonpb.MsgType_DropPartition: + helper.subDDLMsgNum.Inc() + } + } + } + } +} + +func produceTimeTick(t *testing.T, ctx context.Context, producer msgstream.MsgStream) { + tt := 1 + ticker := time.NewTicker(10 * time.Millisecond) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + ts := uint64(tt * 1000) + err := producer.Produce(ctx, &msgstream.MsgPack{ + Msgs: []msgstream.TsMsg{genTimeTickMsg(ts)}, + }) + assert.NoError(t, err) + tt++ + } + } +} + +func getRandomSeekPositions(t *testing.T, ctx context.Context, factory msgstream.Factory, pchannel string, vchannels 
map[string]*vchannelHelper) { + stream, err := factory.NewTtMsgStream(context.Background()) + assert.NoError(t, err) + defer stream.Close() + + err = stream.AsConsumer(context.Background(), []string{pchannel}, fmt.Sprintf("%d", rand.Int()), common.SubscriptionPositionEarliest) + assert.NoError(t, err) + + for { + select { + case <-ctx.Done(): + return + case pack := <-stream.Chan(): + for _, msg := range pack.Msgs { + switch msg.Type() { + case commonpb.MsgType_Insert: + vchannel := msg.(*msgstream.InsertMsg).GetShardName() + if vchannels[vchannel].seekPos == nil { + vchannels[vchannel].skippedInsMsgNum++ + } + case commonpb.MsgType_Delete: + vchannel := msg.(*msgstream.DeleteMsg).GetShardName() + if vchannels[vchannel].seekPos == nil { + vchannels[vchannel].skippedDelMsgNum++ + } + case commonpb.MsgType_DropCollection: + collectionID := msg.(*msgstream.DropCollectionMsg).GetCollectionID() + for vchannel := range vchannels { + if vchannels[vchannel].seekPos == nil && + funcutil.GetCollectionIDFromVChannel(vchannel) == collectionID { + vchannels[vchannel].skippedDDLMsgNum++ + } + } + } + } + for _, helper := range vchannels { + if helper.seekPos == nil { + helper.skippedPackNum++ + } + } + if rand.Intn(5) == 0 { // assign random seek position + for _, helper := range vchannels { + if helper.seekPos == nil { + helper.seekPos = pack.EndPositions[0] + break + } + } + } + allAssigned := true + for _, helper := range vchannels { + if helper.seekPos == nil { + allAssigned = false + break + } + } + if allAssigned { + return // all seek positions have been assigned + } + } + } +} diff --git a/pkg/mq/msgdispatcher/target.go b/pkg/mq/msgdispatcher/target.go index d5f84f5a7a..9455df8f02 100644 --- a/pkg/mq/msgdispatcher/target.go +++ b/pkg/mq/msgdispatcher/target.go @@ -32,6 +32,7 @@ import ( type target struct { vchannel string ch chan *MsgPack + subPos SubPos pos *Pos closeMu sync.Mutex @@ -44,12 +45,14 @@ type target struct { cancelCh lifetime.SafeChan } -func newTarget(vchannel string, pos *Pos, replicateConfig *msgstream.ReplicateConfig) *target { +func newTarget(streamConfig *StreamConfig) *target { + replicateConfig := streamConfig.ReplicateConfig maxTolerantLag := paramtable.Get().MQCfg.MaxTolerantLag.GetAsDuration(time.Second) t := &target{ - vchannel: vchannel, + vchannel: streamConfig.VChannel, ch: make(chan *MsgPack, paramtable.Get().MQCfg.TargetBufSize.GetAsInt()), - pos: pos, + subPos: streamConfig.SubPos, + pos: streamConfig.Pos, cancelCh: lifetime.NewSafeChan(), maxLag: maxTolerantLag, timer: time.NewTimer(maxTolerantLag), @@ -58,7 +61,7 @@ func newTarget(vchannel string, pos *Pos, replicateConfig *msgstream.ReplicateCo t.closed = false if replicateConfig != nil { log.Info("have replicate config", - zap.String("vchannel", vchannel), + zap.String("vchannel", streamConfig.VChannel), zap.String("replicateID", replicateConfig.ReplicateID)) } return t @@ -72,6 +75,7 @@ func (t *target) close() { t.closed = true t.timer.Stop() close(t.ch) + log.Info("close target chan", zap.String("vchannel", t.vchannel)) }) } @@ -94,7 +98,7 @@ func (t *target) send(pack *MsgPack) error { log.Info("target closed", zap.String("vchannel", t.vchannel)) return nil case <-t.timer.C: - return fmt.Errorf("send target timeout, vchannel=%s, timeout=%s", t.vchannel, t.maxLag) + return fmt.Errorf("send target timeout, vchannel=%s, timeout=%s, beginTs=%d, endTs=%d", t.vchannel, t.maxLag, pack.BeginTs, pack.EndTs) case t.ch <- pack: return nil } diff --git a/pkg/mq/msgdispatcher/target_test.go 
b/pkg/mq/msgdispatcher/target_test.go index 444f1e5dfa..4a40116897 100644 --- a/pkg/mq/msgdispatcher/target_test.go +++ b/pkg/mq/msgdispatcher/target_test.go @@ -14,7 +14,10 @@ import ( ) func TestSendTimeout(t *testing.T) { - target := newTarget("test1", &msgpb.MsgPosition{}, nil) + target := newTarget(&StreamConfig{ + VChannel: "test1", + Pos: &msgpb.MsgPosition{}, + }) time.Sleep(paramtable.Get().MQCfg.MaxTolerantLag.GetAsDuration(time.Second)) diff --git a/pkg/mq/msgstream/mq_msgstream.go b/pkg/mq/msgstream/mq_msgstream.go index 439df5905e..ddf906ab59 100644 --- a/pkg/mq/msgstream/mq_msgstream.go +++ b/pkg/mq/msgstream/mq_msgstream.go @@ -637,7 +637,7 @@ func (ms *MqTtMsgStream) AsConsumer(ctx context.Context, channels []string, subN return errors.Wrapf(err, errMsg) } - panic(fmt.Sprintf("%s, errors = %s", errMsg, err.Error())) + panic(fmt.Sprintf("%s, subName = %s, errors = %s", errMsg, subName, err.Error())) } } diff --git a/pkg/util/merr/errors.go b/pkg/util/merr/errors.go index ed98c92ffb..26e9854a64 100644 --- a/pkg/util/merr/errors.go +++ b/pkg/util/merr/errors.go @@ -140,11 +140,10 @@ var ( ErrMetricNotFound = newMilvusError("metric not found", 1200, false) // Message queue related - ErrMqTopicNotFound = newMilvusError("topic not found", 1300, false) - ErrMqTopicNotEmpty = newMilvusError("topic not empty", 1301, false) - ErrMqInternal = newMilvusError("message queue internal error", 1302, false) - ErrDenyProduceMsg = newMilvusError("deny to write the message to mq", 1303, false) - ErrTooManyConsumers = newMilvusError("consumer number limit exceeded", 1304, false) + ErrMqTopicNotFound = newMilvusError("topic not found", 1300, false) + ErrMqTopicNotEmpty = newMilvusError("topic not empty", 1301, false) + ErrMqInternal = newMilvusError("message queue internal error", 1302, false) + ErrDenyProduceMsg = newMilvusError("deny to write the message to mq", 1303, false) // Privilege related // this operation is denied because the user not authorized, user need to login in first diff --git a/pkg/util/merr/errors_test.go b/pkg/util/merr/errors_test.go index 8eb2fc2174..643f0f040b 100644 --- a/pkg/util/merr/errors_test.go +++ b/pkg/util/merr/errors_test.go @@ -147,7 +147,6 @@ func (s *ErrSuite) TestWrap() { s.ErrorIs(WrapErrMqTopicNotFound("unknown", "failed to get topic"), ErrMqTopicNotFound) s.ErrorIs(WrapErrMqTopicNotEmpty("unknown", "topic is not empty"), ErrMqTopicNotEmpty) s.ErrorIs(WrapErrMqInternal(errors.New("unknown"), "failed to consume"), ErrMqInternal) - s.ErrorIs(WrapErrTooManyConsumers("unknown", "too many consumers"), ErrTooManyConsumers) // field related s.ErrorIs(WrapErrFieldNotFound("meta", "failed to get field"), ErrFieldNotFound) diff --git a/pkg/util/merr/utils.go b/pkg/util/merr/utils.go index eb4692050b..cefcc9be38 100644 --- a/pkg/util/merr/utils.go +++ b/pkg/util/merr/utils.go @@ -1000,14 +1000,6 @@ func WrapErrMqInternal(err error, msg ...string) error { return err } -func WrapErrTooManyConsumers(vchannel string, msg ...string) error { - err := wrapFields(ErrTooManyConsumers, value("vchannel", vchannel)) - if len(msg) > 0 { - err = errors.Wrap(err, strings.Join(msg, "->")) - } - return err -} - func WrapErrPrivilegeNotAuthenticated(fmt string, args ...any) error { err := errors.Wrapf(ErrPrivilegeNotAuthenticated, fmt, args...) 
return err diff --git a/pkg/util/paramtable/service_param.go b/pkg/util/paramtable/service_param.go index 6de262e33d..29eaf2fbdf 100644 --- a/pkg/util/paramtable/service_param.go +++ b/pkg/util/paramtable/service_param.go @@ -529,12 +529,10 @@ type MQConfig struct { IgnoreBadPosition ParamItem `refreshable:"true"` // msgdispatcher - MergeCheckInterval ParamItem `refreshable:"false"` - TargetBufSize ParamItem `refreshable:"false"` - MaxTolerantLag ParamItem `refreshable:"true"` - MaxDispatcherNumPerPchannel ParamItem `refreshable:"true"` - RetrySleep ParamItem `refreshable:"true"` - RetryTimeout ParamItem `refreshable:"true"` + MergeCheckInterval ParamItem `refreshable:"false"` + TargetBufSize ParamItem `refreshable:"false"` + MaxTolerantLag ParamItem `refreshable:"true"` + MaxPositionTsGap ParamItem `refreshable:"true"` } // Init initializes the MQConfig object with a BaseTable. @@ -558,33 +556,6 @@ Valid values: [default, pulsar, kafka, rocksmq, natsmq]`, } p.MaxTolerantLag.Init(base.mgr) - p.MaxDispatcherNumPerPchannel = ParamItem{ - Key: "mq.dispatcher.maxDispatcherNumPerPchannel", - Version: "2.4.19", - DefaultValue: "5", - Doc: `The maximum number of dispatchers per physical channel, primarily to limit the number of consumers and prevent performance issues(e.g., during recovery when a large number of channels are watched).`, - Export: true, - } - p.MaxDispatcherNumPerPchannel.Init(base.mgr) - - p.RetrySleep = ParamItem{ - Key: "mq.dispatcher.retrySleep", - Version: "2.4.19", - DefaultValue: "3", - Doc: `register retry sleep time in seconds`, - Export: true, - } - p.RetrySleep.Init(base.mgr) - - p.RetryTimeout = ParamItem{ - Key: "mq.dispatcher.retryTimeout", - Version: "2.4.19", - DefaultValue: "60", - Doc: `register retry timeout in seconds`, - Export: true, - } - p.RetryTimeout.Init(base.mgr) - p.TargetBufSize = ParamItem{ Key: "mq.dispatcher.targetBufSize", Version: "2.4.4", @@ -603,6 +574,14 @@ Valid values: [default, pulsar, kafka, rocksmq, natsmq]`, } p.MergeCheckInterval.Init(base.mgr) + p.MaxPositionTsGap = ParamItem{ + Key: "mq.dispatcher.maxPositionGapInMinutes", + Version: "2.5", + DefaultValue: "60", + Doc: `The max position timestamp gap in minutes.`, + } + p.MaxPositionTsGap.Init(base.mgr) + p.EnablePursuitMode = ParamItem{ Key: "mq.enablePursuitMode", Version: "2.3.0", diff --git a/pkg/util/paramtable/service_param_test.go b/pkg/util/paramtable/service_param_test.go index 7c68382018..c371c6dcac 100644 --- a/pkg/util/paramtable/service_param_test.go +++ b/pkg/util/paramtable/service_param_test.go @@ -37,9 +37,7 @@ func TestServiceParam(t *testing.T) { assert.Equal(t, 1*time.Second, Params.MergeCheckInterval.GetAsDuration(time.Second)) assert.Equal(t, 16, Params.TargetBufSize.GetAsInt()) assert.Equal(t, 3*time.Second, Params.MaxTolerantLag.GetAsDuration(time.Second)) - assert.Equal(t, 5, Params.MaxDispatcherNumPerPchannel.GetAsInt()) - assert.Equal(t, 3*time.Second, Params.RetrySleep.GetAsDuration(time.Second)) - assert.Equal(t, 60*time.Second, Params.RetryTimeout.GetAsDuration(time.Second)) + assert.Equal(t, 60*time.Minute, Params.MaxPositionTsGap.GetAsDuration(time.Minute)) }) t.Run("test etcdConfig", func(t *testing.T) { diff --git a/tests/go_client/testcases/groupby_search_test.go b/tests/go_client/testcases/groupby_search_test.go index 4bdea0bd77..886d8d0512 100644 --- a/tests/go_client/testcases/groupby_search_test.go +++ b/tests/go_client/testcases/groupby_search_test.go @@ -474,7 +474,7 @@ func TestSearchGroupByUnsupportedDataType(t *testing.T) { 
common.DefaultFloatFieldName, common.DefaultDoubleFieldName, common.DefaultJSONFieldName, common.DefaultFloatVecFieldName, common.DefaultInt8ArrayField, common.DefaultFloatArrayField, } { - _, err := mc.Search(ctx, client.NewSearchOption(collName, common.DefaultLimit, queryVec).WithGroupByField(unsupportedField).WithANNSField(common.DefaultFloatVecFieldName)) + _, err := mc.Search(ctx, client.NewSearchOption(collName, common.DefaultLimit, queryVec).WithGroupByField(unsupportedField).WithANNSField(common.DefaultFloatVecFieldName).WithConsistencyLevel(entity.ClStrong)) common.CheckErr(t, err, false, "unsupported data type") } } @@ -495,7 +495,7 @@ func TestSearchGroupByRangeSearch(t *testing.T) { // range search _, err := mc.Search(ctx, client.NewSearchOption(collName, common.DefaultLimit, queryVec).WithGroupByField(common.DefaultVarcharFieldName). - WithANNSField(common.DefaultFloatVecFieldName).WithSearchParam("radius", "0").WithSearchParam("range_filter", "0.8")) + WithANNSField(common.DefaultFloatVecFieldName).WithSearchParam("radius", "0").WithSearchParam("range_filter", "0.8").WithConsistencyLevel(entity.ClStrong)) common.CheckErr(t, err, false, "Not allowed to do range-search when doing search-group-by") } diff --git a/tests/go_client/testcases/hybrid_search_test.go b/tests/go_client/testcases/hybrid_search_test.go index 5438c8cd14..0cb9679562 100644 --- a/tests/go_client/testcases/hybrid_search_test.go +++ b/tests/go_client/testcases/hybrid_search_test.go @@ -268,7 +268,7 @@ func TestHybridSearchMultiVectorsPagination(t *testing.T) { // offset 0, -1 -> 0 for _, offset := range []int{0, -1} { - searchRes, err := mc.HybridSearch(ctx, client.NewHybridSearchOption(schema.CollectionName, common.DefaultLimit, annReqDef).WithOffset(offset)) + searchRes, err := mc.HybridSearch(ctx, client.NewHybridSearchOption(schema.CollectionName, common.DefaultLimit, annReqDef).WithOffset(offset).WithConsistencyLevel(entity.ClStrong)) common.CheckErr(t, err, true) common.CheckSearchResult(t, searchRes, common.DefaultNq, common.DefaultLimit) } diff --git a/tests/go_client/testcases/query_test.go b/tests/go_client/testcases/query_test.go index 8f14c4ecc9..384345c015 100644 --- a/tests/go_client/testcases/query_test.go +++ b/tests/go_client/testcases/query_test.go @@ -65,14 +65,14 @@ func TestQueryVarcharPkDefault(t *testing.T) { // query expr := fmt.Sprintf("%s in ['0', '1', '2', '3', '4']", common.DefaultVarcharFieldName) - queryRes, err := mc.Query(ctx, client.NewQueryOption(schema.CollectionName).WithFilter(expr)) + queryRes, err := mc.Query(ctx, client.NewQueryOption(schema.CollectionName).WithFilter(expr).WithConsistencyLevel(entity.ClStrong)) common.CheckErr(t, err, true) common.CheckQueryResult(t, queryRes.Fields, []column.Column{insertRes.IDs.Slice(0, 5)}) // get ids -> same result with query varcharValues := []string{"0", "1", "2", "3", "4"} ids := column.NewColumnVarChar(common.DefaultVarcharFieldName, varcharValues) - getRes, errGet := mc.Get(ctx, client.NewQueryOption(schema.CollectionName).WithIDs(ids)) + getRes, errGet := mc.Get(ctx, client.NewQueryOption(schema.CollectionName).WithIDs(ids).WithConsistencyLevel(entity.ClStrong)) common.CheckErr(t, errGet, true) common.CheckQueryResult(t, getRes.Fields, []column.Column{insertRes.IDs.Slice(0, 5)}) } @@ -1094,12 +1094,12 @@ func TestQueryWithTemplateParam(t *testing.T) { } // default queryRes, err := mc.Query(ctx, client.NewQueryOption(schema.CollectionName). 
- WithFilter(fmt.Sprintf("%s in {int64Values}", common.DefaultInt64FieldName)).WithTemplateParam("int64Values", int64Values)) + WithFilter(fmt.Sprintf("%s in {int64Values}", common.DefaultInt64FieldName)).WithTemplateParam("int64Values", int64Values).WithConsistencyLevel(entity.ClStrong)) common.CheckErr(t, err, true) common.CheckQueryResult(t, queryRes.Fields, []column.Column{column.NewColumnInt64(common.DefaultInt64FieldName, int64Values)}) // cover keys - res, err := mc.Query(ctx, client.NewQueryOption(schema.CollectionName).WithFilter("int64 < {k2}").WithTemplateParam("k2", 10).WithTemplateParam("k2", 5)) + res, err := mc.Query(ctx, client.NewQueryOption(schema.CollectionName).WithFilter("int64 < {k2}").WithTemplateParam("k2", 10).WithTemplateParam("k2", 5).WithConsistencyLevel(entity.ClStrong)) common.CheckErr(t, err, true) require.Equal(t, 5, res.ResultCount) @@ -1107,14 +1107,14 @@ func TestQueryWithTemplateParam(t *testing.T) { anyValues := []int64{0.0, 100.0, 10000.0} countRes, err := mc.Query(ctx, client.NewQueryOption(schema.CollectionName). WithFilter(fmt.Sprintf("json_contains_any (%s, {any_values})", common.DefaultFloatArrayField)).WithTemplateParam("any_values", anyValues). - WithOutputFields(common.QueryCountFieldName)) + WithOutputFields(common.QueryCountFieldName).WithConsistencyLevel(entity.ClStrong)) common.CheckErr(t, err, true) count, _ := countRes.Fields[0].GetAsInt64(0) require.EqualValues(t, 101, count) // dynamic countRes, err = mc.Query(ctx, client.NewQueryOption(schema.CollectionName). - WithFilter("dynamicNumber % 2 == {v}").WithTemplateParam("v", 0).WithOutputFields(common.QueryCountFieldName)) + WithFilter("dynamicNumber % 2 == {v}").WithTemplateParam("v", 0).WithOutputFields(common.QueryCountFieldName).WithConsistencyLevel(entity.ClStrong)) common.CheckErr(t, err, true) count, _ = countRes.Fields[0].GetAsInt64(0) require.EqualValues(t, 1500, count) @@ -1123,7 +1123,8 @@ func TestQueryWithTemplateParam(t *testing.T) { countRes, err = mc.Query(ctx, client.NewQueryOption(schema.CollectionName). WithFilter(fmt.Sprintf("%s['bool'] == {v}", common.DefaultJSONFieldName)). WithTemplateParam("v", false). - WithOutputFields(common.QueryCountFieldName)) + WithOutputFields(common.QueryCountFieldName). + WithConsistencyLevel(entity.ClStrong)) common.CheckErr(t, err, true) count, _ = countRes.Fields[0].GetAsInt64(0) require.EqualValues(t, 1500/2, count) @@ -1132,7 +1133,8 @@ func TestQueryWithTemplateParam(t *testing.T) { countRes, err = mc.Query(ctx, client.NewQueryOption(schema.CollectionName). WithFilter(fmt.Sprintf("%s == {v}", common.DefaultBoolFieldName)). WithTemplateParam("v", true). - WithOutputFields(common.QueryCountFieldName)) + WithOutputFields(common.QueryCountFieldName). + WithConsistencyLevel(entity.ClStrong)) common.CheckErr(t, err, true) count, _ = countRes.Fields[0].GetAsInt64(0) require.EqualValues(t, common.DefaultNb/2, count) @@ -1141,7 +1143,8 @@ func TestQueryWithTemplateParam(t *testing.T) { res, err = mc.Query(ctx, client.NewQueryOption(schema.CollectionName). WithFilter(fmt.Sprintf("%s >= {k1} && %s < {k2}", common.DefaultInt64FieldName, common.DefaultInt64FieldName)). WithTemplateParam("v", 0).WithTemplateParam("k1", 1000). - WithTemplateParam("k2", 2000)) + WithTemplateParam("k2", 2000). + WithConsistencyLevel(entity.ClStrong)) common.CheckErr(t, err, true) require.EqualValues(t, 1000, res.ResultCount) }
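The batching rule introduced in tryBuildDispatcher above can be summarized as: sort the targets that still lack a dispatcher by the physical time of their seek positions, then build one dispatcher for the earliest target plus every target within mq.dispatcher.maxPositionGapInMinutes of it; the remaining targets are picked up by a later build round. The following standalone sketch restates only that selection rule; the names (fakeTarget, pickBatch) are hypothetical and not part of this patch.

package main

import (
	"fmt"
	"sort"
	"time"
)

// fakeTarget stands in for msgdispatcher's target; only the fields needed
// for the gap filter are modeled here.
type fakeTarget struct {
	vchannel string
	seekTime time.Time // physical time of the target's seek position
}

// pickBatch sorts the lagging targets by seek time (in place) and keeps only
// those whose gap from the earliest one is within maxGap, mirroring the
// candidate selection that tryBuildDispatcher performs in this patch.
func pickBatch(lacking []fakeTarget, maxGap time.Duration) []fakeTarget {
	if len(lacking) == 0 {
		return nil
	}
	sort.Slice(lacking, func(i, j int) bool {
		return lacking[i].seekTime.Before(lacking[j].seekTime)
	})
	earliest := lacking[0].seekTime
	batch := make([]fakeTarget, 0, len(lacking))
	for _, t := range lacking {
		if t.seekTime.Sub(earliest) <= maxGap {
			batch = append(batch, t)
		}
	}
	return batch
}

func main() {
	base := time.Now()
	targets := []fakeTarget{
		{"vchannel-2", base.Add(50 * time.Minute)},
		{"vchannel-0", base},
		{"vchannel-3", base.Add(90 * time.Minute)},
		{"vchannel-1", base.Add(20 * time.Minute)},
	}
	// With the default gap of 60 minutes, vchannel-3 is deferred to a
	// later build round; the other three are subscribed together.
	for _, t := range pickBatch(targets, 60*time.Minute) {
		fmt.Println("build in this round:", t.vchannel)
	}
}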
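Similarly, the send-timeout error message touched in target.go reflects the per-target backpressure rule: a MsgPack that cannot be buffered within MaxTolerantLag counts as a timeout, which is the lagging condition exercised by the split case in the merge-and-split test above. Below is a minimal sketch of that timer-bounded send pattern only; the real target.send in this patch also reuses the target's stored *time.Timer and returns early when the target has been closed.

package main

import (
	"fmt"
	"time"
)

// msgPack is a tiny stand-in for the dispatcher's MsgPack type.
type msgPack struct {
	endTs uint64
}

// sendWithLag tries to hand pack to a bounded target channel and gives up
// after maxLag, echoing the timer branch of target.send in this patch.
func sendWithLag(ch chan<- *msgPack, pack *msgPack, maxLag time.Duration) error {
	timer := time.NewTimer(maxLag)
	defer timer.Stop()
	select {
	case ch <- pack:
		return nil
	case <-timer.C:
		return fmt.Errorf("send target timeout, timeout=%s, endTs=%d", maxLag, pack.endTs)
	}
}

func main() {
	// Buffer of 1 and no consumer: the first send succeeds, the second
	// times out, which is the condition that marks a target as lagging.
	ch := make(chan *msgPack, 1)
	fmt.Println(sendWithLag(ch, &msgPack{endTs: 100}, 200*time.Millisecond)) // <nil>
	fmt.Println(sendWithLag(ch, &msgPack{endTs: 200}, 200*time.Millisecond)) // timeout error
}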