diff --git a/internal/datacoord/compaction_trigger.go b/internal/datacoord/compaction_trigger.go index bb2aa10d94..103afc9570 100644 --- a/internal/datacoord/compaction_trigger.go +++ b/internal/datacoord/compaction_trigger.go @@ -406,7 +406,7 @@ func (t *compactionTrigger) handleSignal(signal *compactionSignal) { return } - segment := t.meta.GetHealthySegment(t.meta.ctx, signal.segmentID) + segment := t.meta.GetHealthySegment(context.TODO(), signal.segmentID) if segment == nil { log.Warn("segment in compaction signal not found in meta", zap.Int64("segmentID", signal.segmentID)) return diff --git a/internal/flushcommon/pipeline/flow_graph_dmstream_input_node_test.go b/internal/flushcommon/pipeline/flow_graph_dmstream_input_node_test.go index 2b213bdd3d..4a6fcda1a4 100644 --- a/internal/flushcommon/pipeline/flow_graph_dmstream_input_node_test.go +++ b/internal/flushcommon/pipeline/flow_graph_dmstream_input_node_test.go @@ -68,7 +68,7 @@ func (mtm *mockTtMsgStream) Chan() <-chan *msgstream.MsgPack { return make(chan *msgstream.MsgPack, 100) } -func (mtm *mockTtMsgStream) AsProducer(channels []string) {} +func (mtm *mockTtMsgStream) AsProducer(ctx context.Context, channels []string) {} func (mtm *mockTtMsgStream) AsConsumer(ctx context.Context, channels []string, subName string, position common.SubscriptionInitialPosition) error { return nil @@ -80,11 +80,11 @@ func (mtm *mockTtMsgStream) GetProduceChannels() []string { return make([]string, 0) } -func (mtm *mockTtMsgStream) Produce(*msgstream.MsgPack) error { +func (mtm *mockTtMsgStream) Produce(context.Context, *msgstream.MsgPack) error { return nil } -func (mtm *mockTtMsgStream) Broadcast(*msgstream.MsgPack) (map[string][]msgstream.MessageID, error) { +func (mtm *mockTtMsgStream) Broadcast(context.Context, *msgstream.MsgPack) (map[string][]msgstream.MessageID, error) { return nil, nil } diff --git a/internal/proxy/channels_mgr.go b/internal/proxy/channels_mgr.go index 641a23b726..375649e5b6 100644 --- a/internal/proxy/channels_mgr.go +++ b/internal/proxy/channels_mgr.go @@ -39,7 +39,7 @@ import ( type channelsMgr interface { getChannels(collectionID UniqueID) ([]pChan, error) getVChannels(collectionID UniqueID) ([]vChan, error) - getOrCreateDmlStream(collectionID UniqueID) (msgstream.MsgStream, error) + getOrCreateDmlStream(ctx context.Context, collectionID UniqueID) (msgstream.MsgStream, error) removeDMLStream(collectionID UniqueID) removeAllDMLStream() } @@ -172,7 +172,7 @@ func (mgr *singleTypeChannelsMgr) streamExistPrivate(collectionID UniqueID) bool return ok && streamInfos.stream != nil } -func createStream(factory msgstream.Factory, pchans []pChan, repack repackFuncType) (msgstream.MsgStream, error) { +func createStream(ctx context.Context, factory msgstream.Factory, pchans []pChan, repack repackFuncType) (msgstream.MsgStream, error) { var stream msgstream.MsgStream var err error @@ -181,7 +181,7 @@ func createStream(factory msgstream.Factory, pchans []pChan, repack repackFuncTy return nil, err } - stream.AsProducer(pchans) + stream.AsProducer(ctx, pchans) if repack != nil { stream.SetRepackFunc(repack) } @@ -202,7 +202,7 @@ func decPChanMetrics(pchans []pChan) { // createMsgStream create message stream for specified collection. Idempotent. // If stream already exists, directly return it and no error will be returned. -func (mgr *singleTypeChannelsMgr) createMsgStream(collectionID UniqueID) (msgstream.MsgStream, error) { +func (mgr *singleTypeChannelsMgr) createMsgStream(ctx context.Context, collectionID UniqueID) (msgstream.MsgStream, error) { mgr.mu.RLock() infos, ok := mgr.infos[collectionID] if ok && infos.stream != nil { @@ -219,7 +219,7 @@ func (mgr *singleTypeChannelsMgr) createMsgStream(collectionID UniqueID) (msgstr return nil, err } - stream, err := createStream(mgr.msgStreamFactory, channelInfos.pchans, mgr.repackFunc) + stream, err := createStream(ctx, mgr.msgStreamFactory, channelInfos.pchans, mgr.repackFunc) if err != nil { // What if stream created by other goroutines? log.Error("failed to create message stream", zap.Error(err), zap.Int64("collection", collectionID)) @@ -253,12 +253,12 @@ func (mgr *singleTypeChannelsMgr) lockGetStream(collectionID UniqueID) (msgstrea // getOrCreateStream get message stream of specified collection. // If stream doesn't exist, call createMsgStream to create for it. -func (mgr *singleTypeChannelsMgr) getOrCreateStream(collectionID UniqueID) (msgstream.MsgStream, error) { +func (mgr *singleTypeChannelsMgr) getOrCreateStream(ctx context.Context, collectionID UniqueID) (msgstream.MsgStream, error) { if stream, err := mgr.lockGetStream(collectionID); err == nil { return stream, nil } - return mgr.createMsgStream(collectionID) + return mgr.createMsgStream(ctx, collectionID) } // removeStream remove the corresponding stream of the specified collection. Idempotent. @@ -315,8 +315,8 @@ func (mgr *channelsMgrImpl) getVChannels(collectionID UniqueID) ([]vChan, error) return mgr.dmlChannelsMgr.getVChannels(collectionID) } -func (mgr *channelsMgrImpl) getOrCreateDmlStream(collectionID UniqueID) (msgstream.MsgStream, error) { - return mgr.dmlChannelsMgr.getOrCreateStream(collectionID) +func (mgr *channelsMgrImpl) getOrCreateDmlStream(ctx context.Context, collectionID UniqueID) (msgstream.MsgStream, error) { + return mgr.dmlChannelsMgr.getOrCreateStream(ctx, collectionID) } func (mgr *channelsMgrImpl) removeDMLStream(collectionID UniqueID) { diff --git a/internal/proxy/channels_mgr_test.go b/internal/proxy/channels_mgr_test.go index 555fd18a95..db5e89f323 100644 --- a/internal/proxy/channels_mgr_test.go +++ b/internal/proxy/channels_mgr_test.go @@ -214,7 +214,7 @@ func Test_createStream(t *testing.T) { factory.fQStream = func(ctx context.Context) (msgstream.MsgStream, error) { return nil, errors.New("mock") } - _, err := createStream(factory, nil, nil) + _, err := createStream(context.TODO(), factory, nil, nil) assert.Error(t, err) }) @@ -223,7 +223,7 @@ func Test_createStream(t *testing.T) { factory.f = func(ctx context.Context) (msgstream.MsgStream, error) { return nil, errors.New("mock") } - _, err := createStream(factory, nil, nil) + _, err := createStream(context.TODO(), factory, nil, nil) assert.Error(t, err) }) @@ -232,7 +232,7 @@ func Test_createStream(t *testing.T) { factory.f = func(ctx context.Context) (msgstream.MsgStream, error) { return newMockMsgStream(), nil } - _, err := createStream(factory, []string{"111"}, func(tsMsgs []msgstream.TsMsg, hashKeys [][]int32) (map[int32]*msgstream.MsgPack, error) { + _, err := createStream(context.TODO(), factory, []string{"111"}, func(tsMsgs []msgstream.TsMsg, hashKeys [][]int32) (map[int32]*msgstream.MsgPack, error) { return nil, nil }) assert.NoError(t, err) @@ -247,7 +247,7 @@ func Test_singleTypeChannelsMgr_createMsgStream(t *testing.T) { 100: {stream: newMockMsgStream()}, }, } - stream, err := m.createMsgStream(100) + stream, err := m.createMsgStream(context.TODO(), 100) assert.NoError(t, err) assert.NotNil(t, stream) }) @@ -275,7 +275,7 @@ func Test_singleTypeChannelsMgr_createMsgStream(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - stream, err := m.createMsgStream(100) + stream, err := m.createMsgStream(context.TODO(), 100) assert.NoError(t, err) assert.NotNil(t, stream) }() @@ -295,7 +295,7 @@ func Test_singleTypeChannelsMgr_createMsgStream(t *testing.T) { return channelInfos{}, errors.New("mock") }, } - _, err := m.createMsgStream(100) + _, err := m.createMsgStream(context.TODO(), 100) assert.Error(t, err) }) @@ -311,7 +311,7 @@ func Test_singleTypeChannelsMgr_createMsgStream(t *testing.T) { msgStreamFactory: factory, repackFunc: nil, } - _, err := m.createMsgStream(100) + _, err := m.createMsgStream(context.TODO(), 100) assert.Error(t, err) }) @@ -328,10 +328,10 @@ func Test_singleTypeChannelsMgr_createMsgStream(t *testing.T) { msgStreamFactory: factory, repackFunc: nil, } - stream, err := m.createMsgStream(100) + stream, err := m.createMsgStream(context.TODO(), 100) assert.NoError(t, err) assert.NotNil(t, stream) - stream, err = m.getOrCreateStream(100) + stream, err = m.getOrCreateStream(context.TODO(), 100) assert.NoError(t, err) assert.NotNil(t, stream) }) @@ -365,7 +365,7 @@ func Test_singleTypeChannelsMgr_getStream(t *testing.T) { 100: {stream: newMockMsgStream()}, }, } - stream, err := m.getOrCreateStream(100) + stream, err := m.getOrCreateStream(context.TODO(), 100) assert.NoError(t, err) assert.NotNil(t, stream) }) @@ -377,7 +377,7 @@ func Test_singleTypeChannelsMgr_getStream(t *testing.T) { return channelInfos{}, errors.New("mock") }, } - _, err := m.getOrCreateStream(100) + _, err := m.getOrCreateStream(context.TODO(), 100) assert.Error(t, err) }) @@ -394,7 +394,7 @@ func Test_singleTypeChannelsMgr_getStream(t *testing.T) { msgStreamFactory: factory, repackFunc: nil, } - stream, err := m.getOrCreateStream(100) + stream, err := m.getOrCreateStream(context.TODO(), 100) assert.NoError(t, err) assert.NotNil(t, stream) }) diff --git a/internal/proxy/impl.go b/internal/proxy/impl.go index f3d797e165..afe83c873e 100644 --- a/internal/proxy/impl.go +++ b/internal/proxy/impl.go @@ -6323,7 +6323,7 @@ func (node *Proxy) ReplicateMessage(ctx context.Context, req *milvuspb.Replicate Status: merr.Status(err), }, nil } - messageIDsMap, err := msgStream.Broadcast(msgPack) + messageIDsMap, err := msgStream.Broadcast(ctx, msgPack) if err != nil { log.Ctx(ctx).Warn("failed to produce msg", zap.Error(err)) return &milvuspb.ReplicateMessageResponse{Status: merr.Status(err)}, nil diff --git a/internal/proxy/impl_test.go b/internal/proxy/impl_test.go index 6aa84e48e8..8e995d9537 100644 --- a/internal/proxy/impl_test.go +++ b/internal/proxy/impl_test.go @@ -440,7 +440,7 @@ func TestProxy_FlushAll_DbCollection(t *testing.T) { rpcRequestChannel := Params.CommonCfg.ReplicateMsgChannel.GetValue() node.replicateMsgStream, err = node.factory.NewMsgStream(node.ctx) assert.NoError(t, err) - node.replicateMsgStream.AsProducer([]string{rpcRequestChannel}) + node.replicateMsgStream.AsProducer(ctx, []string{rpcRequestChannel}) Params.Save(Params.ProxyCfg.MaxTaskNum.Key, "1000") node.sched, err = newTaskScheduler(ctx, node.tsoAllocator, node.factory) @@ -483,7 +483,7 @@ func TestProxy_FlushAll(t *testing.T) { rpcRequestChannel := Params.CommonCfg.ReplicateMsgChannel.GetValue() node.replicateMsgStream, err = node.factory.NewMsgStream(node.ctx) assert.NoError(t, err) - node.replicateMsgStream.AsProducer([]string{rpcRequestChannel}) + node.replicateMsgStream.AsProducer(ctx, []string{rpcRequestChannel}) Params.Save(Params.ProxyCfg.MaxTaskNum.Key, "1000") node.sched, err = newTaskScheduler(ctx, node.tsoAllocator, node.factory) @@ -955,7 +955,7 @@ func TestProxyCreateDatabase(t *testing.T) { rpcRequestChannel := Params.CommonCfg.ReplicateMsgChannel.GetValue() node.replicateMsgStream, err = node.factory.NewMsgStream(node.ctx) assert.NoError(t, err) - node.replicateMsgStream.AsProducer([]string{rpcRequestChannel}) + node.replicateMsgStream.AsProducer(ctx, []string{rpcRequestChannel}) t.Run("create database fail", func(t *testing.T) { rc := mocks.NewMockRootCoordClient(t) @@ -1015,7 +1015,7 @@ func TestProxyDropDatabase(t *testing.T) { rpcRequestChannel := Params.CommonCfg.ReplicateMsgChannel.GetValue() node.replicateMsgStream, err = node.factory.NewMsgStream(node.ctx) assert.NoError(t, err) - node.replicateMsgStream.AsProducer([]string{rpcRequestChannel}) + node.replicateMsgStream.AsProducer(ctx, []string{rpcRequestChannel}) t.Run("drop database fail", func(t *testing.T) { rc := mocks.NewMockRootCoordClient(t) @@ -1496,13 +1496,13 @@ func TestProxy_ReplicateMessage(t *testing.T) { factory := newMockMsgStreamFactory() msgStreamObj := msgstream.NewMockMsgStream(t) msgStreamObj.EXPECT().SetRepackFunc(mock.Anything).Return() - msgStreamObj.EXPECT().AsProducer(mock.Anything).Return() + msgStreamObj.EXPECT().AsProducer(mock.Anything, mock.Anything).Return() msgStreamObj.EXPECT().EnableProduce(mock.Anything).Return() msgStreamObj.EXPECT().Close().Return() mockMsgID1 := mqcommon.NewMockMessageID(t) mockMsgID2 := mqcommon.NewMockMessageID(t) mockMsgID2.EXPECT().Serialize().Return([]byte("mock message id 2")) - broadcastMock := msgStreamObj.EXPECT().Broadcast(mock.Anything).Return(map[string][]mqcommon.MessageID{ + broadcastMock := msgStreamObj.EXPECT().Broadcast(mock.Anything, mock.Anything).Return(map[string][]mqcommon.MessageID{ "unit_test_replicate_message": {mockMsgID1, mockMsgID2}, }, nil) @@ -1581,7 +1581,7 @@ func TestProxy_ReplicateMessage(t *testing.T) { { broadcastMock.Unset() - broadcastMock = msgStreamObj.EXPECT().Broadcast(mock.Anything).Return(nil, errors.New("mock error: broadcast")) + broadcastMock = msgStreamObj.EXPECT().Broadcast(mock.Anything, mock.Anything).Return(nil, errors.New("mock error: broadcast")) resp, err := node.ReplicateMessage(context.TODO(), replicateRequest) assert.NoError(t, err) assert.NotEqualValues(t, 0, resp.GetStatus().GetCode()) @@ -1590,7 +1590,7 @@ func TestProxy_ReplicateMessage(t *testing.T) { } { broadcastMock.Unset() - broadcastMock = msgStreamObj.EXPECT().Broadcast(mock.Anything).Return(map[string][]mqcommon.MessageID{ + broadcastMock = msgStreamObj.EXPECT().Broadcast(mock.Anything, mock.Anything).Return(map[string][]mqcommon.MessageID{ "unit_test_replicate_message": {}, }, nil) resp, err := node.ReplicateMessage(context.TODO(), replicateRequest) diff --git a/internal/proxy/mock_channels_manager.go b/internal/proxy/mock_channels_manager.go index 7342014623..441272ac9c 100644 --- a/internal/proxy/mock_channels_manager.go +++ b/internal/proxy/mock_channels_manager.go @@ -3,6 +3,8 @@ package proxy import ( + context "context" + msgstream "github.com/milvus-io/milvus/pkg/mq/msgstream" mock "github.com/stretchr/testify/mock" ) @@ -78,9 +80,9 @@ func (_c *MockChannelsMgr_getChannels_Call) RunAndReturn(run func(int64) ([]stri return _c } -// getOrCreateDmlStream provides a mock function with given fields: collectionID -func (_m *MockChannelsMgr) getOrCreateDmlStream(collectionID int64) (msgstream.MsgStream, error) { - ret := _m.Called(collectionID) +// getOrCreateDmlStream provides a mock function with given fields: ctx, collectionID +func (_m *MockChannelsMgr) getOrCreateDmlStream(ctx context.Context, collectionID int64) (msgstream.MsgStream, error) { + ret := _m.Called(ctx, collectionID) if len(ret) == 0 { panic("no return value specified for getOrCreateDmlStream") @@ -88,19 +90,19 @@ func (_m *MockChannelsMgr) getOrCreateDmlStream(collectionID int64) (msgstream.M var r0 msgstream.MsgStream var r1 error - if rf, ok := ret.Get(0).(func(int64) (msgstream.MsgStream, error)); ok { - return rf(collectionID) + if rf, ok := ret.Get(0).(func(context.Context, int64) (msgstream.MsgStream, error)); ok { + return rf(ctx, collectionID) } - if rf, ok := ret.Get(0).(func(int64) msgstream.MsgStream); ok { - r0 = rf(collectionID) + if rf, ok := ret.Get(0).(func(context.Context, int64) msgstream.MsgStream); ok { + r0 = rf(ctx, collectionID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(msgstream.MsgStream) } } - if rf, ok := ret.Get(1).(func(int64) error); ok { - r1 = rf(collectionID) + if rf, ok := ret.Get(1).(func(context.Context, int64) error); ok { + r1 = rf(ctx, collectionID) } else { r1 = ret.Error(1) } @@ -114,14 +116,15 @@ type MockChannelsMgr_getOrCreateDmlStream_Call struct { } // getOrCreateDmlStream is a helper method to define mock.On call +// - ctx context.Context // - collectionID int64 -func (_e *MockChannelsMgr_Expecter) getOrCreateDmlStream(collectionID interface{}) *MockChannelsMgr_getOrCreateDmlStream_Call { - return &MockChannelsMgr_getOrCreateDmlStream_Call{Call: _e.mock.On("getOrCreateDmlStream", collectionID)} +func (_e *MockChannelsMgr_Expecter) getOrCreateDmlStream(ctx interface{}, collectionID interface{}) *MockChannelsMgr_getOrCreateDmlStream_Call { + return &MockChannelsMgr_getOrCreateDmlStream_Call{Call: _e.mock.On("getOrCreateDmlStream", ctx, collectionID)} } -func (_c *MockChannelsMgr_getOrCreateDmlStream_Call) Run(run func(collectionID int64)) *MockChannelsMgr_getOrCreateDmlStream_Call { +func (_c *MockChannelsMgr_getOrCreateDmlStream_Call) Run(run func(ctx context.Context, collectionID int64)) *MockChannelsMgr_getOrCreateDmlStream_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(int64)) + run(args[0].(context.Context), args[1].(int64)) }) return _c } @@ -131,7 +134,7 @@ func (_c *MockChannelsMgr_getOrCreateDmlStream_Call) Return(_a0 msgstream.MsgStr return _c } -func (_c *MockChannelsMgr_getOrCreateDmlStream_Call) RunAndReturn(run func(int64) (msgstream.MsgStream, error)) *MockChannelsMgr_getOrCreateDmlStream_Call { +func (_c *MockChannelsMgr_getOrCreateDmlStream_Call) RunAndReturn(run func(context.Context, int64) (msgstream.MsgStream, error)) *MockChannelsMgr_getOrCreateDmlStream_Call { _c.Call.Return(run) return _c } diff --git a/internal/proxy/mock_msgstream_test.go b/internal/proxy/mock_msgstream_test.go index 613dd97b94..9b99888d88 100644 --- a/internal/proxy/mock_msgstream_test.go +++ b/internal/proxy/mock_msgstream_test.go @@ -16,7 +16,7 @@ type mockMsgStream struct { enableProduce func(bool) } -func (m *mockMsgStream) AsProducer(producers []string) { +func (m *mockMsgStream) AsProducer(ctx context.Context, producers []string) { if m.asProducer != nil { m.asProducer(producers) } diff --git a/internal/proxy/mock_test.go b/internal/proxy/mock_test.go index 8132e5bb3a..08b1a27232 100644 --- a/internal/proxy/mock_test.go +++ b/internal/proxy/mock_test.go @@ -255,7 +255,7 @@ func (ms *simpleMockMsgStream) Chan() <-chan *msgstream.MsgPack { return ms.msgChan } -func (ms *simpleMockMsgStream) AsProducer(channels []string) { +func (ms *simpleMockMsgStream) AsProducer(ctx context.Context, channels []string) { } func (ms *simpleMockMsgStream) AsConsumer(ctx context.Context, channels []string, subName string, position common.SubscriptionInitialPosition) error { @@ -283,7 +283,7 @@ func (ms *simpleMockMsgStream) decreaseMsgCount(delta int) { ms.increaseMsgCount(-delta) } -func (ms *simpleMockMsgStream) Produce(pack *msgstream.MsgPack) error { +func (ms *simpleMockMsgStream) Produce(ctx context.Context, pack *msgstream.MsgPack) error { defer ms.increaseMsgCount(1) ms.msgChan <- pack @@ -291,7 +291,7 @@ func (ms *simpleMockMsgStream) Produce(pack *msgstream.MsgPack) error { return nil } -func (ms *simpleMockMsgStream) Broadcast(pack *msgstream.MsgPack) (map[string][]msgstream.MessageID, error) { +func (ms *simpleMockMsgStream) Broadcast(ctx context.Context, pack *msgstream.MsgPack) (map[string][]msgstream.MessageID, error) { return map[string][]msgstream.MessageID{}, nil } diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index 8486526dd6..1e0915f68f 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -278,7 +278,7 @@ func (node *Proxy) Init() error { return err } node.replicateMsgStream.EnableProduce(true) - node.replicateMsgStream.AsProducer([]string{replicateMsgChannel}) + node.replicateMsgStream.AsProducer(node.ctx, []string{replicateMsgChannel}) node.sched, err = newTaskScheduler(node.ctx, node.tsoAllocator, node.factory) if err != nil { diff --git a/internal/proxy/replicate_stream_manager.go b/internal/proxy/replicate_stream_manager.go index 5bf01d1f6e..651e09b51f 100644 --- a/internal/proxy/replicate_stream_manager.go +++ b/internal/proxy/replicate_stream_manager.go @@ -34,15 +34,15 @@ func NewReplicateStreamManager(ctx context.Context, factory msgstream.Factory, r return manager } -func (m *ReplicateStreamManager) newMsgStreamResource(channel string) resource.NewResourceFunc { +func (m *ReplicateStreamManager) newMsgStreamResource(ctx context.Context, channel string) resource.NewResourceFunc { return func() (resource.Resource, error) { - msgStream, err := m.factory.NewMsgStream(m.ctx) + msgStream, err := m.factory.NewMsgStream(ctx) if err != nil { log.Ctx(m.ctx).Warn("failed to create msg stream", zap.String("channel", channel), zap.Error(err)) return nil, err } msgStream.SetRepackFunc(replicatePackFunc) - msgStream.AsProducer([]string{channel}) + msgStream.AsProducer(ctx, []string{channel}) msgStream.EnableProduce(true) res := resource.NewSimpleResource(msgStream, ReplicateMsgStreamTyp, channel, ReplicateMsgStreamExpireTime, func() { @@ -55,7 +55,7 @@ func (m *ReplicateStreamManager) newMsgStreamResource(channel string) resource.N func (m *ReplicateStreamManager) GetReplicateMsgStream(ctx context.Context, channel string) (msgstream.MsgStream, error) { ctxLog := log.Ctx(ctx).With(zap.String("proxy_channel", channel)) - res, err := m.resourceManager.Get(ReplicateMsgStreamTyp, channel, m.newMsgStreamResource(channel)) + res, err := m.resourceManager.Get(ReplicateMsgStreamTyp, channel, m.newMsgStreamResource(ctx, channel)) if err != nil { ctxLog.Warn("failed to get replicate msg stream", zap.String("channel", channel), zap.Error(err)) return nil, err diff --git a/internal/proxy/task_delete.go b/internal/proxy/task_delete.go index 1d9df1c59f..8ffdf84860 100644 --- a/internal/proxy/task_delete.go +++ b/internal/proxy/task_delete.go @@ -142,7 +142,7 @@ func (dt *deleteTask) Execute(ctx context.Context) (err error) { } dt.tr = timerecord.NewTimeRecorder(fmt.Sprintf("proxy execute delete %d", dt.ID())) - stream, err := dt.chMgr.getOrCreateDmlStream(dt.collectionID) + stream, err := dt.chMgr.getOrCreateDmlStream(ctx, dt.collectionID) if err != nil { return err } @@ -178,7 +178,7 @@ func (dt *deleteTask) Execute(ctx context.Context) (err error) { zap.Int64("taskID", dt.ID()), zap.Duration("prepare duration", dt.tr.RecordSpan())) - err = stream.Produce(msgPack) + err = stream.Produce(ctx, msgPack) if err != nil { return err } diff --git a/internal/proxy/task_delete_test.go b/internal/proxy/task_delete_test.go index 99cd2e03e6..881fa4ee04 100644 --- a/internal/proxy/task_delete_test.go +++ b/internal/proxy/task_delete_test.go @@ -161,7 +161,7 @@ func TestDeleteTask_Execute(t *testing.T) { }, } - mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything).Return(nil, errors.New("mock error")) + mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything, mock.Anything).Return(nil, errors.New("mock error")) assert.Error(t, dt.Execute(context.Background())) }) @@ -190,7 +190,7 @@ func TestDeleteTask_Execute(t *testing.T) { primaryKeys: pk, } stream := msgstream.NewMockMsgStream(t) - mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything).Return(stream, nil) + mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything, mock.Anything).Return(stream, nil) assert.Error(t, dt.Execute(context.Background())) }) @@ -226,8 +226,8 @@ func TestDeleteTask_Execute(t *testing.T) { primaryKeys: pk, } stream := msgstream.NewMockMsgStream(t) - mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything).Return(stream, nil) - stream.EXPECT().Produce(mock.Anything).Return(errors.New("mock error")) + mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything, mock.Anything).Return(stream, nil) + stream.EXPECT().Produce(mock.Anything, mock.Anything).Return(errors.New("mock error")) assert.Error(t, dt.Execute(context.Background())) }) } @@ -535,9 +535,9 @@ func TestDeleteRunner_Run(t *testing.T) { }, } stream := msgstream.NewMockMsgStream(t) - mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything).Return(stream, nil) + mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything, mock.Anything).Return(stream, nil) mockMgr.EXPECT().getChannels(collectionID).Return(channels, nil) - stream.EXPECT().Produce(mock.Anything).Return(fmt.Errorf("mock error")) + stream.EXPECT().Produce(mock.Anything, mock.Anything).Return(fmt.Errorf("mock error")) assert.Error(t, dr.Run(context.Background())) assert.Equal(t, int64(0), dr.result.DeleteCnt) @@ -644,9 +644,9 @@ func TestDeleteRunner_Run(t *testing.T) { }, } stream := msgstream.NewMockMsgStream(t) - mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything).Return(stream, nil) + mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything, mock.Anything).Return(stream, nil) mockMgr.EXPECT().getChannels(collectionID).Return(channels, nil) - stream.EXPECT().Produce(mock.Anything).Return(nil) + stream.EXPECT().Produce(mock.Anything, mock.Anything).Return(nil) lb.EXPECT().Execute(mock.Anything, mock.Anything).Call.Return(func(ctx context.Context, workload CollectionWorkLoad) error { return workload.exec(ctx, 1, qn, "") @@ -768,7 +768,7 @@ func TestDeleteRunner_Run(t *testing.T) { }, } stream := msgstream.NewMockMsgStream(t) - mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything).Return(stream, nil) + mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything, mock.Anything).Return(stream, nil) mockMgr.EXPECT().getChannels(collectionID).Return(channels, nil) lb.EXPECT().Execute(mock.Anything, mock.Anything).Call.Return(func(ctx context.Context, workload CollectionWorkLoad) error { return workload.exec(ctx, 1, qn, "") @@ -792,7 +792,7 @@ func TestDeleteRunner_Run(t *testing.T) { server.FinishSend(nil) return client }, nil) - stream.EXPECT().Produce(mock.Anything).Return(errors.New("mock error")) + stream.EXPECT().Produce(mock.Anything, mock.Anything).Return(errors.New("mock error")) assert.Error(t, dr.Run(ctx)) assert.Equal(t, int64(0), dr.result.DeleteCnt) @@ -830,7 +830,7 @@ func TestDeleteRunner_Run(t *testing.T) { }, } stream := msgstream.NewMockMsgStream(t) - mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything).Return(stream, nil) + mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything, mock.Anything).Return(stream, nil) mockMgr.EXPECT().getChannels(collectionID).Return(channels, nil) lb.EXPECT().Execute(mock.Anything, mock.Anything).Call.Return(func(ctx context.Context, workload CollectionWorkLoad) error { return workload.exec(ctx, 1, qn, "") @@ -854,7 +854,7 @@ func TestDeleteRunner_Run(t *testing.T) { server.FinishSend(nil) return client }, nil) - stream.EXPECT().Produce(mock.Anything).Return(nil) + stream.EXPECT().Produce(mock.Anything, mock.Anything).Return(nil) assert.NoError(t, dr.Run(ctx)) assert.Equal(t, int64(3), dr.result.DeleteCnt) @@ -911,7 +911,7 @@ func TestDeleteRunner_Run(t *testing.T) { }, } stream := msgstream.NewMockMsgStream(t) - mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything).Return(stream, nil) + mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything, mock.Anything).Return(stream, nil) mockMgr.EXPECT().getChannels(collectionID).Return(channels, nil) lb.EXPECT().Execute(mock.Anything, mock.Anything).Call.Return(func(ctx context.Context, workload CollectionWorkLoad) error { return workload.exec(ctx, 1, qn, "") @@ -936,7 +936,7 @@ func TestDeleteRunner_Run(t *testing.T) { return client }, nil) - stream.EXPECT().Produce(mock.Anything).Return(nil) + stream.EXPECT().Produce(mock.Anything, mock.Anything).Return(nil) assert.NoError(t, dr.Run(ctx)) assert.Equal(t, int64(3), dr.result.DeleteCnt) }) diff --git a/internal/proxy/task_insert.go b/internal/proxy/task_insert.go index 620b469d2f..444d317c79 100644 --- a/internal/proxy/task_insert.go +++ b/internal/proxy/task_insert.go @@ -243,7 +243,7 @@ func (it *insertTask) Execute(ctx context.Context) error { it.insertMsg.CollectionID = collID getCacheDur := tr.RecordSpan() - stream, err := it.chMgr.getOrCreateDmlStream(collID) + stream, err := it.chMgr.getOrCreateDmlStream(ctx, collID) if err != nil { return err } @@ -280,7 +280,7 @@ func (it *insertTask) Execute(ctx context.Context) error { log.Debug("assign segmentID for insert data success", zap.Duration("assign segmentID duration", assignSegmentIDDur)) - err = stream.Produce(msgPack) + err = stream.Produce(ctx, msgPack) if err != nil { log.Warn("fail to produce insert msg", zap.Error(err)) it.result.Status = merr.Status(err) diff --git a/internal/proxy/task_test.go b/internal/proxy/task_test.go index 1f74590c16..b1371ffe66 100644 --- a/internal/proxy/task_test.go +++ b/internal/proxy/task_test.go @@ -1755,7 +1755,7 @@ func TestTask_Int64PrimaryKey(t *testing.T) { chMgr := newChannelsMgrImpl(dmlChannelsFunc, nil, factory) defer chMgr.removeAllDMLStream() - _, err = chMgr.getOrCreateDmlStream(collectionID) + _, err = chMgr.getOrCreateDmlStream(ctx, collectionID) assert.NoError(t, err) pchans, err := chMgr.getChannels(collectionID) assert.NoError(t, err) @@ -2004,7 +2004,7 @@ func TestTask_VarCharPrimaryKey(t *testing.T) { chMgr := newChannelsMgrImpl(dmlChannelsFunc, nil, factory) defer chMgr.removeAllDMLStream() - _, err = chMgr.getOrCreateDmlStream(collectionID) + _, err = chMgr.getOrCreateDmlStream(ctx, collectionID) assert.NoError(t, err) pchans, err := chMgr.getChannels(collectionID) assert.NoError(t, err) @@ -3460,7 +3460,7 @@ func TestPartitionKey(t *testing.T) { chMgr := newChannelsMgrImpl(dmlChannelsFunc, nil, factory) defer chMgr.removeAllDMLStream() - _, err = chMgr.getOrCreateDmlStream(collectionID) + _, err = chMgr.getOrCreateDmlStream(ctx, collectionID) assert.NoError(t, err) pchans, err := chMgr.getChannels(collectionID) assert.NoError(t, err) diff --git a/internal/proxy/task_upsert.go b/internal/proxy/task_upsert.go index 3ca4853fa9..5755231e13 100644 --- a/internal/proxy/task_upsert.go +++ b/internal/proxy/task_upsert.go @@ -393,7 +393,7 @@ func (it *upsertTask) insertExecute(ctx context.Context, msgPack *msgstream.MsgP zap.Int64("collectionID", collID)) getCacheDur := tr.RecordSpan() - _, err = it.chMgr.getOrCreateDmlStream(collID) + _, err = it.chMgr.getOrCreateDmlStream(ctx, collID) if err != nil { return err } @@ -526,7 +526,7 @@ func (it *upsertTask) Execute(ctx context.Context) (err error) { log := log.Ctx(ctx).With(zap.String("collectionName", it.req.CollectionName)) tr := timerecord.NewTimeRecorder(fmt.Sprintf("proxy execute upsert %d", it.ID())) - stream, err := it.chMgr.getOrCreateDmlStream(it.collectionID) + stream, err := it.chMgr.getOrCreateDmlStream(ctx, it.collectionID) if err != nil { return err } @@ -547,7 +547,7 @@ func (it *upsertTask) Execute(ctx context.Context) (err error) { } tr.RecordSpan() - err = stream.Produce(msgPack) + err = stream.Produce(ctx, msgPack) if err != nil { it.result.Status = merr.Status(err) return err diff --git a/internal/proxy/util.go b/internal/proxy/util.go index 5a14b97f5d..15f4f36490 100644 --- a/internal/proxy/util.go +++ b/internal/proxy/util.go @@ -1985,7 +1985,7 @@ func SendReplicateMessagePack(ctx context.Context, replicateMsgStream msgstream. EndTs: ts, Msgs: []msgstream.TsMsg{tsMsg}, } - msgErr := replicateMsgStream.Produce(msgPack) + msgErr := replicateMsgStream.Produce(ctx, msgPack) // ignore the error if the msg stream failed to produce the msg, // because it can be manually fixed in this error if msgErr != nil { diff --git a/internal/proxy/util_test.go b/internal/proxy/util_test.go index 86b1cd8188..5065f65bd1 100644 --- a/internal/proxy/util_test.go +++ b/internal/proxy/util_test.go @@ -2430,7 +2430,7 @@ func TestSendReplicateMessagePack(t *testing.T) { }) t.Run("produce fail", func(t *testing.T) { - mockStream.EXPECT().Produce(mock.Anything).Return(errors.New("produce error")).Once() + mockStream.EXPECT().Produce(mock.Anything, mock.Anything).Return(errors.New("produce error")).Once() SendReplicateMessagePack(ctx, mockStream, &milvuspb.CreateDatabaseRequest{ Base: &commonpb.MsgBase{ReplicateInfo: &commonpb.ReplicateInfo{ IsReplicate: true, @@ -2444,7 +2444,7 @@ func TestSendReplicateMessagePack(t *testing.T) { }) t.Run("normal case", func(t *testing.T) { - mockStream.EXPECT().Produce(mock.Anything).Return(nil) + mockStream.EXPECT().Produce(mock.Anything, mock.Anything).Return(nil) SendReplicateMessagePack(ctx, mockStream, &milvuspb.CreateDatabaseRequest{}) SendReplicateMessagePack(ctx, mockStream, &milvuspb.DropDatabaseRequest{}) diff --git a/internal/rootcoord/dml_channels.go b/internal/rootcoord/dml_channels.go index 8e9719614f..8c40f99807 100644 --- a/internal/rootcoord/dml_channels.go +++ b/internal/rootcoord/dml_channels.go @@ -188,7 +188,7 @@ func newDmlChannels(ctx context.Context, factory msgstream.Factory, chanNamePref d.checkPreCreatedTopic(ctx, factory, name) } - ms.AsProducer([]string{name}) + ms.AsProducer(ctx, []string{name}) dms := &dmlMsgStream{ ms: ms, refcnt: 0, @@ -291,7 +291,7 @@ func (d *dmlChannels) broadcast(chanNames []string, pack *msgstream.MsgPack) err dms.mutex.RLock() if dms.refcnt > 0 { - if _, err := dms.ms.Broadcast(pack); err != nil { + if _, err := dms.ms.Broadcast(d.ctx, pack); err != nil { log.Error("Broadcast failed", zap.Error(err), zap.String("chanName", chanName)) dms.mutex.RUnlock() return err @@ -312,7 +312,7 @@ func (d *dmlChannels) broadcastMark(chanNames []string, pack *msgstream.MsgPack) dms.mutex.RLock() if dms.refcnt > 0 { - ids, err := dms.ms.Broadcast(pack) + ids, err := dms.ms.Broadcast(d.ctx, pack) if err != nil { log.Error("BroadcastMark failed", zap.Error(err), zap.String("chanName", chanName)) dms.mutex.RUnlock() diff --git a/internal/rootcoord/dml_channels_test.go b/internal/rootcoord/dml_channels_test.go index c7ce78b6ce..7d7f328483 100644 --- a/internal/rootcoord/dml_channels_test.go +++ b/internal/rootcoord/dml_channels_test.go @@ -277,17 +277,17 @@ type FailMsgStream struct { errBroadcast bool } -func (ms *FailMsgStream) Close() {} -func (ms *FailMsgStream) Chan() <-chan *msgstream.MsgPack { return nil } -func (ms *FailMsgStream) AsProducer(channels []string) {} -func (ms *FailMsgStream) AsReader(channels []string, subName string) {} +func (ms *FailMsgStream) Close() {} +func (ms *FailMsgStream) Chan() <-chan *msgstream.MsgPack { return nil } +func (ms *FailMsgStream) AsProducer(ctx context.Context, channels []string) {} +func (ms *FailMsgStream) AsReader(channels []string, subName string) {} func (ms *FailMsgStream) AsConsumer(ctx context.Context, channels []string, subName string, position common.SubscriptionInitialPosition) error { return nil } -func (ms *FailMsgStream) SetRepackFunc(repackFunc msgstream.RepackFunc) {} -func (ms *FailMsgStream) GetProduceChannels() []string { return nil } -func (ms *FailMsgStream) Produce(*msgstream.MsgPack) error { return nil } -func (ms *FailMsgStream) Broadcast(*msgstream.MsgPack) (map[string][]msgstream.MessageID, error) { +func (ms *FailMsgStream) SetRepackFunc(repackFunc msgstream.RepackFunc) {} +func (ms *FailMsgStream) GetProduceChannels() []string { return nil } +func (ms *FailMsgStream) Produce(context.Context, *msgstream.MsgPack) error { return nil } +func (ms *FailMsgStream) Broadcast(context.Context, *msgstream.MsgPack) (map[string][]msgstream.MessageID, error) { if ms.errBroadcast { return nil, errors.New("broadcast error") } diff --git a/internal/util/flowgraph/input_node_test.go b/internal/util/flowgraph/input_node_test.go index 03c8d38d90..bd7087b44c 100644 --- a/internal/util/flowgraph/input_node_test.go +++ b/internal/util/flowgraph/input_node_test.go @@ -42,8 +42,8 @@ func TestInputNode(t *testing.T) { msgPack := generateMsgPack() produceStream, _ := factory.NewMsgStream(context.TODO()) - produceStream.AsProducer(channels) - produceStream.Produce(&msgPack) + produceStream.AsProducer(context.TODO(), channels) + produceStream.Produce(context.TODO(), &msgPack) nodeName := "input_node" inputNode := NewInputNode(msgStream.Chan(), nodeName, 100, 100, "", 0, 0, "") @@ -84,7 +84,7 @@ func Test_InputNodeSkipMode(t *testing.T) { msgStream.AsConsumer(context.Background(), channels, "sub", common.SubscriptionPositionEarliest) produceStream, _ := factory.NewMsgStream(context.TODO()) - produceStream.AsProducer(channels) + produceStream.AsProducer(context.TODO(), channels) closeCh := make(chan struct{}) outputCh := make(chan bool) @@ -110,7 +110,7 @@ func Test_InputNodeSkipMode(t *testing.T) { defer close(closeCh) msgPack := generateMsgPack() - produceStream.Produce(&msgPack) + produceStream.Produce(context.TODO(), &msgPack) log.Info("produce empty ttmsg") <-outputCh assert.Equal(t, 1, outputCount) @@ -118,7 +118,7 @@ func Test_InputNodeSkipMode(t *testing.T) { time.Sleep(3 * time.Second) assert.Equal(t, false, inputNode.skipMode) - produceStream.Produce(&msgPack) + produceStream.Produce(context.TODO(), &msgPack) log.Info("after 3 seconds with no active msg receive, input node will turn on skip mode") <-outputCh assert.Equal(t, 2, outputCount) @@ -126,13 +126,13 @@ func Test_InputNodeSkipMode(t *testing.T) { log.Info("some ttmsg will be skipped in skip mode") // this msg will be skipped - produceStream.Produce(&msgPack) + produceStream.Produce(context.TODO(), &msgPack) <-outputCh assert.Equal(t, 2, outputCount) assert.Equal(t, true, inputNode.skipMode) // this msg will be consumed - produceStream.Produce(&msgPack) + produceStream.Produce(context.TODO(), &msgPack) <-outputCh assert.Equal(t, 3, outputCount) assert.Equal(t, true, inputNode.skipMode) diff --git a/internal/util/flowgraph/node_test.go b/internal/util/flowgraph/node_test.go index 4b5020be18..08752ca435 100644 --- a/internal/util/flowgraph/node_test.go +++ b/internal/util/flowgraph/node_test.go @@ -80,13 +80,13 @@ func TestNodeManager_Start(t *testing.T) { msgStream.AsConsumer(context.TODO(), channels, "sub", common.SubscriptionPositionEarliest) produceStream, _ := factory.NewMsgStream(context.TODO()) - produceStream.AsProducer(channels) + produceStream.AsProducer(context.TODO(), channels) msgPack := generateMsgPack() - produceStream.Produce(&msgPack) + produceStream.Produce(context.TODO(), &msgPack) time.Sleep(time.Millisecond * 2) msgPack = generateMsgPack() - produceStream.Produce(&msgPack) + produceStream.Produce(context.TODO(), &msgPack) nodeName := "input_node" inputNode := NewInputNode(msgStream.Chan(), nodeName, 100, 100, "", 0, 0, "") diff --git a/pkg/mq/msgdispatcher/manager_test.go b/pkg/mq/msgdispatcher/manager_test.go index 4f42392cb5..feb5579965 100644 --- a/pkg/mq/msgdispatcher/manager_test.go +++ b/pkg/mq/msgdispatcher/manager_test.go @@ -226,7 +226,7 @@ func (suite *SimulationSuite) produceMsg(wg *sync.WaitGroup, collectionID int64) insNum := rand.Intn(10) for j := 0; j < insNum; j++ { vchannel := vchannelKeys[rand.Intn(len(vchannelKeys))].Interface().(string) - err := suite.producer.Produce(&msgstream.MsgPack{ + err := suite.producer.Produce(context.TODO(), &msgstream.MsgPack{ Msgs: []msgstream.TsMsg{genInsertMsg(rand.Intn(20)+1, vchannel, uniqueMsgID)}, }) assert.NoError(suite.T(), err) @@ -237,7 +237,7 @@ func (suite *SimulationSuite) produceMsg(wg *sync.WaitGroup, collectionID int64) delNum := rand.Intn(2) for j := 0; j < delNum; j++ { vchannel := vchannelKeys[rand.Intn(len(vchannelKeys))].Interface().(string) - err := suite.producer.Produce(&msgstream.MsgPack{ + err := suite.producer.Produce(context.TODO(), &msgstream.MsgPack{ Msgs: []msgstream.TsMsg{genDeleteMsg(rand.Intn(20)+1, vchannel, uniqueMsgID)}, }) assert.NoError(suite.T(), err) @@ -247,7 +247,7 @@ func (suite *SimulationSuite) produceMsg(wg *sync.WaitGroup, collectionID int64) // produce random ddl ddlNum := rand.Intn(2) for j := 0; j < ddlNum; j++ { - err := suite.producer.Produce(&msgstream.MsgPack{ + err := suite.producer.Produce(context.TODO(), &msgstream.MsgPack{ Msgs: []msgstream.TsMsg{genDDLMsg(commonpb.MsgType_DropCollection, collectionID)}, }) assert.NoError(suite.T(), err) @@ -257,7 +257,7 @@ func (suite *SimulationSuite) produceMsg(wg *sync.WaitGroup, collectionID int64) } // produce time tick ts := uint64(i * 100) - err := suite.producer.Produce(&msgstream.MsgPack{ + err := suite.producer.Produce(context.TODO(), &msgstream.MsgPack{ Msgs: []msgstream.TsMsg{genTimeTickMsg(ts)}, }) assert.NoError(suite.T(), err) @@ -305,7 +305,7 @@ func (suite *SimulationSuite) produceTimeTickOnly(ctx context.Context) { return case <-ticker.C: ts := uint64(tt * 1000) - err := suite.producer.Produce(&msgstream.MsgPack{ + err := suite.producer.Produce(ctx, &msgstream.MsgPack{ Msgs: []msgstream.TsMsg{genTimeTickMsg(ts)}, }) assert.NoError(suite.T(), err) diff --git a/pkg/mq/msgdispatcher/mock_test.go b/pkg/mq/msgdispatcher/mock_test.go index dbd4a0f08d..a2d93bfaf8 100644 --- a/pkg/mq/msgdispatcher/mock_test.go +++ b/pkg/mq/msgdispatcher/mock_test.go @@ -55,7 +55,7 @@ func newMockProducer(factory msgstream.Factory, pchannel string) (msgstream.MsgS if err != nil { return nil, err } - stream.AsProducer([]string{pchannel}) + stream.AsProducer(context.TODO(), []string{pchannel}) stream.SetRepackFunc(defaultInsertRepackFunc) return stream, nil } diff --git a/pkg/mq/msgstream/factory_stream_test.go b/pkg/mq/msgstream/factory_stream_test.go index 38803e9723..be0e6fb503 100644 --- a/pkg/mq/msgstream/factory_stream_test.go +++ b/pkg/mq/msgstream/factory_stream_test.go @@ -173,11 +173,11 @@ func testTimeTickerAndInsert(t *testing.T, f []Factory) { defer consumer.Close() var err error - _, err = producer.Broadcast(&msgPack0) + _, err = producer.Broadcast(ctx, &msgPack0) assert.NoError(t, err) - err = producer.Produce(&msgPack1) + err = producer.Produce(ctx, &msgPack1) assert.NoError(t, err) - _, err = producer.Broadcast(&msgPack2) + _, err = producer.Broadcast(ctx, &msgPack2) assert.NoError(t, err) receiveAndValidateMsg(ctx, consumer, len(msgPack1.Msgs)) @@ -210,17 +210,17 @@ func testTimeTickerNoSeek(t *testing.T, f []Factory) { defer producer.Close() var err error - _, err = producer.Broadcast(&msgPack0) + _, err = producer.Broadcast(ctx, &msgPack0) assert.NoError(t, err) - err = producer.Produce(&msgPack1) + err = producer.Produce(ctx, &msgPack1) assert.NoError(t, err) - _, err = producer.Broadcast(&msgPack2) + _, err = producer.Broadcast(ctx, &msgPack2) assert.NoError(t, err) - err = producer.Produce(&msgPack3) + err = producer.Produce(ctx, &msgPack3) assert.NoError(t, err) - _, err = producer.Broadcast(&msgPack4) + _, err = producer.Broadcast(ctx, &msgPack4) assert.NoError(t, err) - _, err = producer.Broadcast(&msgPack5) + _, err = producer.Broadcast(ctx, &msgPack5) assert.NoError(t, err) o1 := consume(ctx, consumer) @@ -259,7 +259,7 @@ func testSeekToLast(t *testing.T, f []Factory) { } // produce test data - err := producer.Produce(msgPack) + err := producer.Produce(ctx, msgPack) assert.NoError(t, err) // pick a seekPosition @@ -346,21 +346,21 @@ func testTimeTickerSeek(t *testing.T, f []Factory) { defer producer.Close() // Send message - _, err := producer.Broadcast(&msgPack0) + _, err := producer.Broadcast(ctx, &msgPack0) assert.NoError(t, err) - err = producer.Produce(&msgPack1) + err = producer.Produce(ctx, &msgPack1) assert.NoError(t, err) - _, err = producer.Broadcast(&msgPack2) + _, err = producer.Broadcast(ctx, &msgPack2) assert.NoError(t, err) - err = producer.Produce(&msgPack3) + err = producer.Produce(ctx, &msgPack3) assert.NoError(t, err) - _, err = producer.Broadcast(&msgPack4) + _, err = producer.Broadcast(ctx, &msgPack4) assert.NoError(t, err) - err = producer.Produce(&msgPack5) + err = producer.Produce(ctx, &msgPack5) assert.NoError(t, err) - _, err = producer.Broadcast(&msgPack6) + _, err = producer.Broadcast(ctx, &msgPack6) assert.NoError(t, err) - _, err = producer.Broadcast(&msgPack7) + _, err = producer.Broadcast(ctx, &msgPack7) assert.NoError(t, err) // Test received message @@ -434,13 +434,13 @@ func testTimeTickUnmarshalHeader(t *testing.T, f []Factory) { defer producer.Close() defer consumer.Close() - _, err := producer.Broadcast(&msgPack0) + _, err := producer.Broadcast(ctx, &msgPack0) require.NoErrorf(t, err, fmt.Sprintf("broadcast error = %v", err)) - err = producer.Produce(&msgPack1) + err = producer.Produce(ctx, &msgPack1) require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err)) - _, err = producer.Broadcast(&msgPack2) + _, err = producer.Broadcast(ctx, &msgPack2) require.NoErrorf(t, err, fmt.Sprintf("broadcast error = %v", err)) receiveAndValidateMsg(ctx, consumer, len(msgPack1.Msgs)) @@ -571,7 +571,7 @@ func testMqMsgStreamSeek(t *testing.T, f []Factory) { msgPack.Msgs = append(msgPack.Msgs, insertMsg) } - err := producer.Produce(msgPack) + err := producer.Produce(ctx, msgPack) assert.NoError(t, err) var seekPosition *msgpb.MsgPosition for i := 0; i < 10; i++ { @@ -605,7 +605,7 @@ func testMqMsgStreamSeekInvalidMessage(t *testing.T, f []Factory, pg positionGen msgPack.Msgs = append(msgPack.Msgs, insertMsg) } - err := producer.Produce(msgPack) + err := producer.Produce(ctx, msgPack) assert.NoError(t, err) var seekPosition *msgpb.MsgPosition for i := 0; i < 10; i++ { @@ -622,7 +622,7 @@ func testMqMsgStreamSeekInvalidMessage(t *testing.T, f []Factory, pg positionGen insertMsg := getTsMsg(commonpb.MsgType_Insert, int64(i)) msgPack.Msgs = append(msgPack.Msgs, insertMsg) } - err = producer.Produce(msgPack) + err = producer.Produce(ctx, msgPack) assert.NoError(t, err) result := consume(ctx, consumer2) assert.Equal(t, result.Msgs[0].ID(), int64(1)) @@ -642,7 +642,7 @@ func testMqMsgStreamSeekLatest(t *testing.T, f []Factory) { msgPack.Msgs = append(msgPack.Msgs, insertMsg) } - err := producer.Produce(msgPack) + err := producer.Produce(ctx, msgPack) assert.NoError(t, err) consumer2 := createLatestConsumer(ctx, t, f[1].NewMsgStream, channels) defer consumer2.Close() @@ -653,7 +653,7 @@ func testMqMsgStreamSeekLatest(t *testing.T, f []Factory) { insertMsg := getTsMsg(commonpb.MsgType_Insert, int64(i)) msgPack.Msgs = append(msgPack.Msgs, insertMsg) } - err = producer.Produce(msgPack) + err = producer.Produce(ctx, msgPack) assert.NoError(t, err) for i := 10; i < 20; i++ { @@ -673,7 +673,7 @@ func testBroadcastMark(t *testing.T, f []Factory) { msgPack0 := MsgPack{} msgPack0.Msgs = append(msgPack0.Msgs, getTimeTickMsg(0)) - ids, err := producer.Broadcast(&msgPack0) + ids, err := producer.Broadcast(ctx, &msgPack0) assert.NoError(t, err) assert.NotNil(t, ids) assert.Equal(t, len(channels), len(ids)) @@ -687,7 +687,7 @@ func testBroadcastMark(t *testing.T, f []Factory) { msgPack1.Msgs = append(msgPack1.Msgs, getTsMsg(commonpb.MsgType_Insert, 1)) msgPack1.Msgs = append(msgPack1.Msgs, getTsMsg(commonpb.MsgType_Insert, 3)) - ids, err = producer.Broadcast(&msgPack1) + ids, err = producer.Broadcast(ctx, &msgPack1) assert.NoError(t, err) assert.NotNil(t, ids) assert.Equal(t, len(channels), len(ids)) @@ -698,12 +698,12 @@ func testBroadcastMark(t *testing.T, f []Factory) { } // edge cases - _, err = producer.Broadcast(nil) + _, err = producer.Broadcast(ctx, nil) assert.Error(t, err) msgPack2 := MsgPack{} msgPack2.Msgs = append(msgPack2.Msgs, &MarshalFailTsMsg{}) - _, err = producer.Broadcast(&msgPack2) + _, err = producer.Broadcast(ctx, &msgPack2) assert.Error(t, err) } @@ -712,7 +712,7 @@ func applyBroadCastAndConsume(t *testing.T, msgPack *MsgPack, newer []streamNewe defer producer.Close() defer consumer.Close() - _, err := producer.Broadcast(msgPack) + _, err := producer.Broadcast(context.TODO(), msgPack) assert.NoError(t, err) receiveAndValidateMsg(context.Background(), consumer, len(msgPack.Msgs)*channelNum) } @@ -728,7 +728,7 @@ func applyProduceAndConsumeWithRepack( defer producer.Close() defer consumer.Close() - err := producer.Produce(msgPack) + err := producer.Produce(context.TODO(), msgPack) assert.NoError(t, err) receiveAndValidateMsg(context.Background(), consumer, len(msgPack.Msgs)) } @@ -743,7 +743,7 @@ func applyProduceAndConsume( defer producer.Close() defer consumer.Close() - err := producer.Produce(msgPack) + err := producer.Produce(context.TODO(), msgPack) assert.NoError(t, err) receiveAndValidateMsg(context.Background(), consumer, len(msgPack.Msgs)) } @@ -774,7 +774,7 @@ func createAndSeekConsumer(ctx context.Context, t *testing.T, newer streamNewer, func createProducer(ctx context.Context, t *testing.T, newer streamNewer, channels []string) MsgStream { producer, err := newer(ctx) assert.NoError(t, err) - producer.AsProducer(channels) + producer.AsProducer(ctx, channels) return producer } @@ -798,7 +798,7 @@ func createStream(ctx context.Context, t *testing.T, newer []streamNewer, channe assert.NotEmpty(t, channels) producer, err := newer[0](ctx) assert.NoError(t, err) - producer.AsProducer(channels) + producer.AsProducer(ctx, channels) consumer, err := newer[1](ctx) assert.NoError(t, err) diff --git a/pkg/mq/msgstream/mock_msgstream.go b/pkg/mq/msgstream/mock_msgstream.go index 50b403d939..47a1b9cb6d 100644 --- a/pkg/mq/msgstream/mock_msgstream.go +++ b/pkg/mq/msgstream/mock_msgstream.go @@ -74,9 +74,9 @@ func (_c *MockMsgStream_AsConsumer_Call) RunAndReturn(run func(context.Context, return _c } -// AsProducer provides a mock function with given fields: channels -func (_m *MockMsgStream) AsProducer(channels []string) { - _m.Called(channels) +// AsProducer provides a mock function with given fields: ctx, channels +func (_m *MockMsgStream) AsProducer(ctx context.Context, channels []string) { + _m.Called(ctx, channels) } // MockMsgStream_AsProducer_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AsProducer' @@ -85,14 +85,15 @@ type MockMsgStream_AsProducer_Call struct { } // AsProducer is a helper method to define mock.On call +// - ctx context.Context // - channels []string -func (_e *MockMsgStream_Expecter) AsProducer(channels interface{}) *MockMsgStream_AsProducer_Call { - return &MockMsgStream_AsProducer_Call{Call: _e.mock.On("AsProducer", channels)} +func (_e *MockMsgStream_Expecter) AsProducer(ctx interface{}, channels interface{}) *MockMsgStream_AsProducer_Call { + return &MockMsgStream_AsProducer_Call{Call: _e.mock.On("AsProducer", ctx, channels)} } -func (_c *MockMsgStream_AsProducer_Call) Run(run func(channels []string)) *MockMsgStream_AsProducer_Call { +func (_c *MockMsgStream_AsProducer_Call) Run(run func(ctx context.Context, channels []string)) *MockMsgStream_AsProducer_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].([]string)) + run(args[0].(context.Context), args[1].([]string)) }) return _c } @@ -102,14 +103,14 @@ func (_c *MockMsgStream_AsProducer_Call) Return() *MockMsgStream_AsProducer_Call return _c } -func (_c *MockMsgStream_AsProducer_Call) RunAndReturn(run func([]string)) *MockMsgStream_AsProducer_Call { +func (_c *MockMsgStream_AsProducer_Call) RunAndReturn(run func(context.Context, []string)) *MockMsgStream_AsProducer_Call { _c.Call.Return(run) return _c } -// Broadcast provides a mock function with given fields: _a0 -func (_m *MockMsgStream) Broadcast(_a0 *MsgPack) (map[string][]common.MessageID, error) { - ret := _m.Called(_a0) +// Broadcast provides a mock function with given fields: _a0, _a1 +func (_m *MockMsgStream) Broadcast(_a0 context.Context, _a1 *MsgPack) (map[string][]common.MessageID, error) { + ret := _m.Called(_a0, _a1) if len(ret) == 0 { panic("no return value specified for Broadcast") @@ -117,19 +118,19 @@ func (_m *MockMsgStream) Broadcast(_a0 *MsgPack) (map[string][]common.MessageID, var r0 map[string][]common.MessageID var r1 error - if rf, ok := ret.Get(0).(func(*MsgPack) (map[string][]common.MessageID, error)); ok { - return rf(_a0) + if rf, ok := ret.Get(0).(func(context.Context, *MsgPack) (map[string][]common.MessageID, error)); ok { + return rf(_a0, _a1) } - if rf, ok := ret.Get(0).(func(*MsgPack) map[string][]common.MessageID); ok { - r0 = rf(_a0) + if rf, ok := ret.Get(0).(func(context.Context, *MsgPack) map[string][]common.MessageID); ok { + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(map[string][]common.MessageID) } } - if rf, ok := ret.Get(1).(func(*MsgPack) error); ok { - r1 = rf(_a0) + if rf, ok := ret.Get(1).(func(context.Context, *MsgPack) error); ok { + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -143,14 +144,15 @@ type MockMsgStream_Broadcast_Call struct { } // Broadcast is a helper method to define mock.On call -// - _a0 *MsgPack -func (_e *MockMsgStream_Expecter) Broadcast(_a0 interface{}) *MockMsgStream_Broadcast_Call { - return &MockMsgStream_Broadcast_Call{Call: _e.mock.On("Broadcast", _a0)} +// - _a0 context.Context +// - _a1 *MsgPack +func (_e *MockMsgStream_Expecter) Broadcast(_a0 interface{}, _a1 interface{}) *MockMsgStream_Broadcast_Call { + return &MockMsgStream_Broadcast_Call{Call: _e.mock.On("Broadcast", _a0, _a1)} } -func (_c *MockMsgStream_Broadcast_Call) Run(run func(_a0 *MsgPack)) *MockMsgStream_Broadcast_Call { +func (_c *MockMsgStream_Broadcast_Call) Run(run func(_a0 context.Context, _a1 *MsgPack)) *MockMsgStream_Broadcast_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(*MsgPack)) + run(args[0].(context.Context), args[1].(*MsgPack)) }) return _c } @@ -160,7 +162,7 @@ func (_c *MockMsgStream_Broadcast_Call) Return(_a0 map[string][]common.MessageID return _c } -func (_c *MockMsgStream_Broadcast_Call) RunAndReturn(run func(*MsgPack) (map[string][]common.MessageID, error)) *MockMsgStream_Broadcast_Call { +func (_c *MockMsgStream_Broadcast_Call) RunAndReturn(run func(context.Context, *MsgPack) (map[string][]common.MessageID, error)) *MockMsgStream_Broadcast_Call { _c.Call.Return(run) return _c } @@ -428,17 +430,17 @@ func (_c *MockMsgStream_GetProduceChannels_Call) RunAndReturn(run func() []strin return _c } -// Produce provides a mock function with given fields: _a0 -func (_m *MockMsgStream) Produce(_a0 *MsgPack) error { - ret := _m.Called(_a0) +// Produce provides a mock function with given fields: _a0, _a1 +func (_m *MockMsgStream) Produce(_a0 context.Context, _a1 *MsgPack) error { + ret := _m.Called(_a0, _a1) if len(ret) == 0 { panic("no return value specified for Produce") } var r0 error - if rf, ok := ret.Get(0).(func(*MsgPack) error); ok { - r0 = rf(_a0) + if rf, ok := ret.Get(0).(func(context.Context, *MsgPack) error); ok { + r0 = rf(_a0, _a1) } else { r0 = ret.Error(0) } @@ -452,14 +454,15 @@ type MockMsgStream_Produce_Call struct { } // Produce is a helper method to define mock.On call -// - _a0 *MsgPack -func (_e *MockMsgStream_Expecter) Produce(_a0 interface{}) *MockMsgStream_Produce_Call { - return &MockMsgStream_Produce_Call{Call: _e.mock.On("Produce", _a0)} +// - _a0 context.Context +// - _a1 *MsgPack +func (_e *MockMsgStream_Expecter) Produce(_a0 interface{}, _a1 interface{}) *MockMsgStream_Produce_Call { + return &MockMsgStream_Produce_Call{Call: _e.mock.On("Produce", _a0, _a1)} } -func (_c *MockMsgStream_Produce_Call) Run(run func(_a0 *MsgPack)) *MockMsgStream_Produce_Call { +func (_c *MockMsgStream_Produce_Call) Run(run func(_a0 context.Context, _a1 *MsgPack)) *MockMsgStream_Produce_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(*MsgPack)) + run(args[0].(context.Context), args[1].(*MsgPack)) }) return _c } @@ -469,7 +472,7 @@ func (_c *MockMsgStream_Produce_Call) Return(_a0 error) *MockMsgStream_Produce_C return _c } -func (_c *MockMsgStream_Produce_Call) RunAndReturn(run func(*MsgPack) error) *MockMsgStream_Produce_Call { +func (_c *MockMsgStream_Produce_Call) RunAndReturn(run func(context.Context, *MsgPack) error) *MockMsgStream_Produce_Call { _c.Call.Return(run) return _c } diff --git a/pkg/mq/msgstream/mq_kafka_msgstream_test.go b/pkg/mq/msgstream/mq_kafka_msgstream_test.go index 03ab985f79..18317e4501 100644 --- a/pkg/mq/msgstream/mq_kafka_msgstream_test.go +++ b/pkg/mq/msgstream/mq_kafka_msgstream_test.go @@ -123,7 +123,7 @@ func TestStream_KafkaMsgStream_SeekToLast(t *testing.T) { } // produce test data - err := inputStream.Produce(msgPack) + err := inputStream.Produce(ctx, msgPack) assert.NoError(t, err) // pick a seekPosition @@ -219,21 +219,21 @@ func TestStream_KafkaTtMsgStream_Seek(t *testing.T) { inputStream := getKafkaInputStream(ctx, kafkaAddress, producerChannels) outputStream := getKafkaTtOutputStream(ctx, kafkaAddress, consumerChannels, consumerSubName) - _, err := inputStream.Broadcast(&msgPack0) + _, err := inputStream.Broadcast(ctx, &msgPack0) assert.NoError(t, err) - err = inputStream.Produce(&msgPack1) + err = inputStream.Produce(ctx, &msgPack1) assert.NoError(t, err) - _, err = inputStream.Broadcast(&msgPack2) + _, err = inputStream.Broadcast(ctx, &msgPack2) assert.NoError(t, err) - err = inputStream.Produce(&msgPack3) + err = inputStream.Produce(ctx, &msgPack3) assert.NoError(t, err) - _, err = inputStream.Broadcast(&msgPack4) + _, err = inputStream.Broadcast(ctx, &msgPack4) assert.NoError(t, err) - err = inputStream.Produce(&msgPack5) + err = inputStream.Produce(ctx, &msgPack5) assert.NoError(t, err) - _, err = inputStream.Broadcast(&msgPack6) + _, err = inputStream.Broadcast(ctx, &msgPack6) assert.NoError(t, err) - _, err = inputStream.Broadcast(&msgPack7) + _, err = inputStream.Broadcast(ctx, &msgPack7) assert.NoError(t, err) receivedMsg := consumer(ctx, outputStream) @@ -450,7 +450,7 @@ func getKafkaInputStream(ctx context.Context, kafkaAddress string, producerChann } kafkaClient := kafkawrapper.NewKafkaClientInstanceWithConfigMap(config, nil, nil) inputStream, _ := NewMqMsgStream(ctx, 100, 100, kafkaClient, factory.NewUnmarshalDispatcher()) - inputStream.AsProducer(producerChannels) + inputStream.AsProducer(ctx, producerChannels) for _, opt := range opts { inputStream.SetRepackFunc(opt) } diff --git a/pkg/mq/msgstream/mq_msgstream.go b/pkg/mq/msgstream/mq_msgstream.go index 09d7121985..5127a841c9 100644 --- a/pkg/mq/msgstream/mq_msgstream.go +++ b/pkg/mq/msgstream/mq_msgstream.go @@ -121,7 +121,7 @@ func NewMqMsgStream(ctx context.Context, } // AsProducer create producer to send message to channels -func (ms *mqMsgStream) AsProducer(channels []string) { +func (ms *mqMsgStream) AsProducer(ctx context.Context, channels []string) { for _, channel := range channels { if len(channel) == 0 { log.Error("MsgStream asProducer's channel is an empty string") @@ -129,7 +129,7 @@ func (ms *mqMsgStream) AsProducer(channels []string) { } fn := func() error { - pp, err := ms.client.CreateProducer(common.ProducerOptions{Topic: channel, EnableCompression: true}) + pp, err := ms.client.CreateProducer(ctx, common.ProducerOptions{Topic: channel, EnableCompression: true}) if err != nil { return err } @@ -176,7 +176,7 @@ func (ms *mqMsgStream) AsConsumer(ctx context.Context, channels []string, subNam continue } fn := func() error { - pc, err := ms.client.Subscribe(mqwrapper.ConsumerOptions{ + pc, err := ms.client.Subscribe(ctx, mqwrapper.ConsumerOptions{ Topic: channel, SubscriptionName: subName, SubscriptionInitialPosition: position, @@ -273,7 +273,7 @@ func (ms *mqMsgStream) isEnabledProduce() bool { return ms.enableProduce.Load().(bool) } -func (ms *mqMsgStream) Produce(msgPack *MsgPack) error { +func (ms *mqMsgStream) Produce(ctx context.Context, msgPack *MsgPack) error { if !ms.isEnabledProduce() { log.Warn("can't produce the msg in the backup instance", zap.Stack("stack")) return merr.ErrDenyProduceMsg @@ -346,7 +346,7 @@ func (ms *mqMsgStream) Produce(msgPack *MsgPack) error { // BroadcastMark broadcast msg pack to all producers and returns corresponding msg id // the returned message id serves as marking -func (ms *mqMsgStream) Broadcast(msgPack *MsgPack) (map[string][]MessageID, error) { +func (ms *mqMsgStream) Broadcast(ctx context.Context, msgPack *MsgPack) (map[string][]MessageID, error) { ids := make(map[string][]MessageID) if msgPack == nil || len(msgPack.Msgs) <= 0 { return ids, errors.New("empty msgs") @@ -581,7 +581,7 @@ func (ms *MqTtMsgStream) AsConsumer(ctx context.Context, channels []string, subN continue } fn := func() error { - pc, err := ms.client.Subscribe(mqwrapper.ConsumerOptions{ + pc, err := ms.client.Subscribe(ctx, mqwrapper.ConsumerOptions{ Topic: channel, SubscriptionName: subName, SubscriptionInitialPosition: position, diff --git a/pkg/mq/msgstream/mq_msgstream_test.go b/pkg/mq/msgstream/mq_msgstream_test.go index 9870a14422..b46ca44533 100644 --- a/pkg/mq/msgstream/mq_msgstream_test.go +++ b/pkg/mq/msgstream/mq_msgstream_test.go @@ -130,12 +130,12 @@ func TestStream_PulsarMsgStream_Insert(t *testing.T) { { inputStream.EnableProduce(false) - err := inputStream.Produce(&msgPack) + err := inputStream.Produce(ctx, &msgPack) require.Error(t, err) } inputStream.EnableProduce(true) - err := inputStream.Produce(&msgPack) + err := inputStream.Produce(ctx, &msgPack) require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err)) receiveMsg(ctx, outputStream, len(msgPack.Msgs)) @@ -156,7 +156,7 @@ func TestStream_PulsarMsgStream_Delete(t *testing.T) { inputStream := getPulsarInputStream(ctx, pulsarAddress, producerChannels) outputStream := getPulsarOutputStream(ctx, pulsarAddress, consumerChannels, consumerSubName) - err := inputStream.Produce(&msgPack) + err := inputStream.Produce(ctx, &msgPack) require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err)) receiveMsg(ctx, outputStream, len(msgPack.Msgs)) @@ -178,7 +178,7 @@ func TestStream_PulsarMsgStream_TimeTick(t *testing.T) { inputStream := getPulsarInputStream(ctx, pulsarAddress, producerChannels) outputStream := getPulsarOutputStream(ctx, pulsarAddress, consumerChannels, consumerSubName) - err := inputStream.Produce(&msgPack) + err := inputStream.Produce(ctx, &msgPack) require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err)) receiveMsg(ctx, outputStream, len(msgPack.Msgs)) @@ -203,12 +203,12 @@ func TestStream_PulsarMsgStream_BroadCast(t *testing.T) { { inputStream.EnableProduce(false) - _, err := inputStream.Broadcast(&msgPack) + _, err := inputStream.Broadcast(ctx, &msgPack) require.Error(t, err) } inputStream.EnableProduce(true) - _, err := inputStream.Broadcast(&msgPack) + _, err := inputStream.Broadcast(ctx, &msgPack) require.NoErrorf(t, err, fmt.Sprintf("broadcast error = %v", err)) receiveMsg(ctx, outputStream, len(consumerChannels)*len(msgPack.Msgs)) @@ -230,7 +230,7 @@ func TestStream_PulsarMsgStream_RepackFunc(t *testing.T) { ctx := context.Background() inputStream := getPulsarInputStream(ctx, pulsarAddress, producerChannels, repackFunc) outputStream := getPulsarOutputStream(ctx, pulsarAddress, consumerChannels, consumerSubName) - err := inputStream.Produce(&msgPack) + err := inputStream.Produce(ctx, &msgPack) require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err)) receiveMsg(ctx, outputStream, len(msgPack.Msgs)) @@ -277,14 +277,14 @@ func TestStream_PulsarMsgStream_InsertRepackFunc(t *testing.T) { ctx := context.Background() pulsarClient, _ := pulsarwrapper.NewClient(DefaultPulsarTenant, DefaultPulsarNamespace, pulsar.ClientOptions{URL: pulsarAddress}) inputStream, _ := NewMqMsgStream(ctx, 100, 100, pulsarClient, factory.NewUnmarshalDispatcher()) - inputStream.AsProducer(producerChannels) + inputStream.AsProducer(ctx, producerChannels) pulsarClient2, _ := pulsarwrapper.NewClient(DefaultPulsarTenant, DefaultPulsarNamespace, pulsar.ClientOptions{URL: pulsarAddress}) outputStream, _ := NewMqMsgStream(ctx, 100, 100, pulsarClient2, factory.NewUnmarshalDispatcher()) outputStream.AsConsumer(ctx, consumerChannels, consumerSubName, common.SubscriptionPositionEarliest) var output MsgStream = outputStream - err := (*inputStream).Produce(&msgPack) + err := (*inputStream).Produce(ctx, &msgPack) require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err)) receiveMsg(ctx, output, len(msgPack.Msgs)*2) @@ -328,14 +328,14 @@ func TestStream_PulsarMsgStream_DeleteRepackFunc(t *testing.T) { ctx := context.Background() pulsarClient, _ := pulsarwrapper.NewClient(DefaultPulsarTenant, DefaultPulsarNamespace, pulsar.ClientOptions{URL: pulsarAddress}) inputStream, _ := NewMqMsgStream(ctx, 100, 100, pulsarClient, factory.NewUnmarshalDispatcher()) - inputStream.AsProducer(producerChannels) + inputStream.AsProducer(ctx, producerChannels) pulsarClient2, _ := pulsarwrapper.NewClient(DefaultPulsarTenant, DefaultPulsarNamespace, pulsar.ClientOptions{URL: pulsarAddress}) outputStream, _ := NewMqMsgStream(ctx, 100, 100, pulsarClient2, factory.NewUnmarshalDispatcher()) outputStream.AsConsumer(ctx, consumerChannels, consumerSubName, common.SubscriptionPositionEarliest) var output MsgStream = outputStream - err := (*inputStream).Produce(&msgPack) + err := (*inputStream).Produce(ctx, &msgPack) require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err)) receiveMsg(ctx, output, len(msgPack.Msgs)*1) @@ -360,14 +360,14 @@ func TestStream_PulsarMsgStream_DefaultRepackFunc(t *testing.T) { ctx := context.Background() pulsarClient, _ := pulsarwrapper.NewClient(DefaultPulsarTenant, DefaultPulsarNamespace, pulsar.ClientOptions{URL: pulsarAddress}) inputStream, _ := NewMqMsgStream(ctx, 100, 100, pulsarClient, factory.NewUnmarshalDispatcher()) - inputStream.AsProducer(producerChannels) + inputStream.AsProducer(ctx, producerChannels) pulsarClient2, _ := pulsarwrapper.NewClient(DefaultPulsarTenant, DefaultPulsarNamespace, pulsar.ClientOptions{URL: pulsarAddress}) outputStream, _ := NewMqMsgStream(ctx, 100, 100, pulsarClient2, factory.NewUnmarshalDispatcher()) outputStream.AsConsumer(ctx, consumerChannels, consumerSubName, common.SubscriptionPositionEarliest) var output MsgStream = outputStream - err := (*inputStream).Produce(&msgPack) + err := (*inputStream).Produce(ctx, &msgPack) require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err)) receiveMsg(ctx, output, len(msgPack.Msgs)) @@ -395,13 +395,13 @@ func TestStream_PulsarTtMsgStream_Insert(t *testing.T) { inputStream := getPulsarInputStream(ctx, pulsarAddress, producerChannels) outputStream := getPulsarTtOutputStream(ctx, pulsarAddress, consumerChannels, consumerSubName) - _, err := inputStream.Broadcast(&msgPack0) + _, err := inputStream.Broadcast(ctx, &msgPack0) require.NoErrorf(t, err, fmt.Sprintf("broadcast error = %v", err)) - err = inputStream.Produce(&msgPack1) + err = inputStream.Produce(ctx, &msgPack1) require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err)) - _, err = inputStream.Broadcast(&msgPack2) + _, err = inputStream.Broadcast(ctx, &msgPack2) require.NoErrorf(t, err, fmt.Sprintf("broadcast error = %v", err)) receiveMsg(ctx, outputStream, len(msgPack1.Msgs)) @@ -440,17 +440,17 @@ func TestStream_PulsarTtMsgStream_NoSeek(t *testing.T) { inputStream := getPulsarInputStream(ctx, pulsarAddress, producerChannels) outputStream := getPulsarTtOutputStream(ctx, pulsarAddress, consumerChannels, consumerSubName) - _, err := inputStream.Broadcast(&msgPack0) + _, err := inputStream.Broadcast(ctx, &msgPack0) assert.NoError(t, err) - err = inputStream.Produce(&msgPack1) + err = inputStream.Produce(ctx, &msgPack1) assert.NoError(t, err) - _, err = inputStream.Broadcast(&msgPack2) + _, err = inputStream.Broadcast(ctx, &msgPack2) assert.NoError(t, err) - err = inputStream.Produce(&msgPack3) + err = inputStream.Produce(ctx, &msgPack3) assert.NoError(t, err) - _, err = inputStream.Broadcast(&msgPack4) + _, err = inputStream.Broadcast(ctx, &msgPack4) assert.NoError(t, err) - _, err = inputStream.Broadcast(&msgPack5) + _, err = inputStream.Broadcast(ctx, &msgPack5) assert.NoError(t, err) o1 := consumer(ctx, outputStream) @@ -495,7 +495,7 @@ func TestStream_PulsarMsgStream_SeekToLast(t *testing.T) { } // produce test data - err := inputStream.Produce(msgPack) + err := inputStream.Produce(ctx, msgPack) assert.NoError(t, err) // pick a seekPosition @@ -617,21 +617,21 @@ func TestStream_PulsarTtMsgStream_Seek(t *testing.T) { inputStream := getPulsarInputStream(ctx, pulsarAddress, producerChannels) outputStream := getPulsarTtOutputStream(ctx, pulsarAddress, consumerChannels, consumerSubName) - _, err := inputStream.Broadcast(&msgPack0) + _, err := inputStream.Broadcast(ctx, &msgPack0) assert.NoError(t, err) - err = inputStream.Produce(&msgPack1) + err = inputStream.Produce(ctx, &msgPack1) assert.NoError(t, err) - _, err = inputStream.Broadcast(&msgPack2) + _, err = inputStream.Broadcast(ctx, &msgPack2) assert.NoError(t, err) - err = inputStream.Produce(&msgPack3) + err = inputStream.Produce(ctx, &msgPack3) assert.NoError(t, err) - _, err = inputStream.Broadcast(&msgPack4) + _, err = inputStream.Broadcast(ctx, &msgPack4) assert.NoError(t, err) - err = inputStream.Produce(&msgPack5) + err = inputStream.Produce(ctx, &msgPack5) assert.NoError(t, err) - _, err = inputStream.Broadcast(&msgPack6) + _, err = inputStream.Broadcast(ctx, &msgPack6) assert.NoError(t, err) - _, err = inputStream.Broadcast(&msgPack7) + _, err = inputStream.Broadcast(ctx, &msgPack7) assert.NoError(t, err) receivedMsg := consumer(ctx, outputStream) @@ -711,13 +711,13 @@ func TestStream_PulsarTtMsgStream_UnMarshalHeader(t *testing.T) { inputStream := getPulsarInputStream(ctx, pulsarAddress, producerChannels) outputStream := getPulsarTtOutputStream(ctx, pulsarAddress, consumerChannels, consumerSubName) - _, err := inputStream.Broadcast(&msgPack0) + _, err := inputStream.Broadcast(ctx, &msgPack0) require.NoErrorf(t, err, fmt.Sprintf("broadcast error = %v", err)) - err = inputStream.Produce(&msgPack1) + err = inputStream.Produce(ctx, &msgPack1) require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err)) - _, err = inputStream.Broadcast(&msgPack2) + _, err = inputStream.Broadcast(ctx, &msgPack2) require.NoErrorf(t, err, fmt.Sprintf("broadcast error = %v", err)) receiveMsg(ctx, outputStream, len(msgPack1.Msgs)) @@ -748,16 +748,16 @@ func TestStream_PulsarTtMsgStream_DropCollection(t *testing.T) { inputStream := getPulsarInputStream(ctx, pulsarAddress, producerChannels) outputStream := getPulsarTtOutputStream(ctx, pulsarAddress, consumerChannels, consumerSubName) - _, err := inputStream.Broadcast(&msgPack0) + _, err := inputStream.Broadcast(ctx, &msgPack0) require.NoErrorf(t, err, fmt.Sprintf("broadcast error = %v", err)) - err = inputStream.Produce(&msgPack1) + err = inputStream.Produce(ctx, &msgPack1) require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err)) - _, err = inputStream.Broadcast(&msgPack2) + _, err = inputStream.Broadcast(ctx, &msgPack2) require.NoErrorf(t, err, fmt.Sprintf("broadcast error = %v", err)) - _, err = inputStream.Broadcast(&msgPack3) + _, err = inputStream.Broadcast(ctx, &msgPack3) require.NoErrorf(t, err, fmt.Sprintf("broadcast error = %v", err)) receiveMsg(ctx, outputStream, 2) @@ -803,12 +803,12 @@ func sendMsgPacks(ms MsgStream, msgPacks []*MsgPack) error { printMsgPack(msgPacks[i]) if i%2 == 0 { // insert msg use Produce - if err := ms.Produce(msgPacks[i]); err != nil { + if err := ms.Produce(context.TODO(), msgPacks[i]); err != nil { return err } } else { // tt msg use Broadcast - if _, err := ms.Broadcast(msgPacks[i]); err != nil { + if _, err := ms.Broadcast(context.TODO(), msgPacks[i]); err != nil { return err } } @@ -971,7 +971,7 @@ func TestStream_MqMsgStream_Seek(t *testing.T) { msgPack.Msgs = append(msgPack.Msgs, insertMsg) } - err := inputStream.Produce(msgPack) + err := inputStream.Produce(ctx, msgPack) assert.NoError(t, err) var seekPosition *msgpb.MsgPosition for i := 0; i < 10; i++ { @@ -1015,7 +1015,7 @@ func TestStream_MqMsgStream_SeekInvalidMessage(t *testing.T) { msgPack.Msgs = append(msgPack.Msgs, insertMsg) } - err := inputStream.Produce(msgPack) + err := inputStream.Produce(ctx, msgPack) assert.NoError(t, err) var seekPosition *msgpb.MsgPosition for i := 0; i < 10; i++ { @@ -1049,7 +1049,7 @@ func TestStream_MqMsgStream_SeekInvalidMessage(t *testing.T) { insertMsg := getTsMsg(commonpb.MsgType_Insert, int64(i)) msgPack.Msgs = append(msgPack.Msgs, insertMsg) } - err = inputStream.Produce(msgPack) + err = inputStream.Produce(ctx, msgPack) assert.NoError(t, err) result := consumer(ctx, outputStream2) assert.Equal(t, result.Msgs[0].ID(), int64(1)) @@ -1074,7 +1074,7 @@ func TestSTream_MqMsgStream_SeekBadMessageID(t *testing.T) { msgPack.Msgs = append(msgPack.Msgs, insertMsg) } - err := inputStream.Produce(msgPack) + err := inputStream.Produce(ctx, msgPack) assert.NoError(t, err) var seekPosition *msgpb.MsgPosition for i := 0; i < 10; i++ { @@ -1086,7 +1086,7 @@ func TestSTream_MqMsgStream_SeekBadMessageID(t *testing.T) { // produce timetick for mqtt msgstream seek msgPack = &MsgPack{} msgPack.Msgs = append(msgPack.Msgs, getTimeTickMsg(1000)) - err = inputStream.Produce(msgPack) + err = inputStream.Produce(ctx, msgPack) assert.NoError(t, err) factory := ProtoUDFactory{} @@ -1139,7 +1139,7 @@ func TestStream_MqMsgStream_SeekLatest(t *testing.T) { msgPack.Msgs = append(msgPack.Msgs, insertMsg) } - err := inputStream.Produce(msgPack) + err := inputStream.Produce(ctx, msgPack) assert.NoError(t, err) factory := ProtoUDFactory{} pulsarClient, _ := pulsarwrapper.NewClient(DefaultPulsarTenant, DefaultPulsarNamespace, pulsar.ClientOptions{URL: pulsarAddress}) @@ -1152,7 +1152,7 @@ func TestStream_MqMsgStream_SeekLatest(t *testing.T) { insertMsg := getTsMsg(commonpb.MsgType_Insert, int64(i)) msgPack.Msgs = append(msgPack.Msgs, insertMsg) } - err = inputStream.Produce(msgPack) + err = inputStream.Produce(ctx, msgPack) assert.NoError(t, err) for i := 10; i < 20; i++ { @@ -1169,6 +1169,7 @@ func TestStream_BroadcastMark(t *testing.T) { c1 := funcutil.RandomString(8) c2 := funcutil.RandomString(8) producerChannels := []string{c1, c2} + ctx := context.Background() factory := ProtoUDFactory{} pulsarClient, err := pulsarwrapper.NewClient(DefaultPulsarTenant, DefaultPulsarNamespace, pulsar.ClientOptions{URL: pulsarAddress}) @@ -1177,12 +1178,12 @@ func TestStream_BroadcastMark(t *testing.T) { assert.NoError(t, err) // add producer channels - outputStream.AsProducer(producerChannels) + outputStream.AsProducer(ctx, producerChannels) msgPack0 := MsgPack{} msgPack0.Msgs = append(msgPack0.Msgs, getTimeTickMsg(0)) - ids, err := outputStream.Broadcast(&msgPack0) + ids, err := outputStream.Broadcast(ctx, &msgPack0) assert.NoError(t, err) assert.NotNil(t, ids) assert.Equal(t, len(producerChannels), len(ids)) @@ -1196,7 +1197,7 @@ func TestStream_BroadcastMark(t *testing.T) { msgPack1.Msgs = append(msgPack1.Msgs, getTsMsg(commonpb.MsgType_Insert, 1)) msgPack1.Msgs = append(msgPack1.Msgs, getTsMsg(commonpb.MsgType_Insert, 3)) - ids, err = outputStream.Broadcast(&msgPack1) + ids, err = outputStream.Broadcast(ctx, &msgPack1) assert.NoError(t, err) assert.NotNil(t, ids) assert.Equal(t, len(producerChannels), len(ids)) @@ -1207,19 +1208,19 @@ func TestStream_BroadcastMark(t *testing.T) { } // edge cases - _, err = outputStream.Broadcast(nil) + _, err = outputStream.Broadcast(ctx, nil) assert.Error(t, err) msgPack2 := MsgPack{} msgPack2.Msgs = append(msgPack2.Msgs, &MarshalFailTsMsg{}) - _, err = outputStream.Broadcast(&msgPack2) + _, err = outputStream.Broadcast(ctx, &msgPack2) assert.Error(t, err) // mock send fail for k, p := range outputStream.producers { outputStream.producers[k] = &mockSendFailProducer{Producer: p} } - _, err = outputStream.Broadcast(&msgPack1) + _, err = outputStream.Broadcast(ctx, &msgPack1) assert.Error(t, err) outputStream.Close() @@ -1497,7 +1498,7 @@ func getPulsarInputStream(ctx context.Context, pulsarAddress string, producerCha factory := ProtoUDFactory{} pulsarClient, _ := pulsarwrapper.NewClient(DefaultPulsarTenant, DefaultPulsarNamespace, pulsar.ClientOptions{URL: pulsarAddress}) inputStream, _ := NewMqMsgStream(ctx, 100, 100, pulsarClient, factory.NewUnmarshalDispatcher()) - inputStream.AsProducer(producerChannels) + inputStream.AsProducer(ctx, producerChannels) for _, opt := range opts { inputStream.SetRepackFunc(opt) } diff --git a/pkg/mq/msgstream/mq_rocksmq_msgstream_test.go b/pkg/mq/msgstream/mq_rocksmq_msgstream_test.go index 73f3f069b4..c982d401b4 100644 --- a/pkg/mq/msgstream/mq_rocksmq_msgstream_test.go +++ b/pkg/mq/msgstream/mq_rocksmq_msgstream_test.go @@ -52,7 +52,7 @@ func TestMqMsgStream_AsProducer(t *testing.T) { assert.NoError(t, err) // empty channel name - m.AsProducer([]string{""}) + m.AsProducer(context.TODO(), []string{""}) } // TODO(wxyu): add a mock implement of mqwrapper.Client, then inject errors to improve coverage @@ -121,7 +121,7 @@ func TestMqMsgStream_GetProduceChannels(t *testing.T) { assert.Equal(t, 0, len(chs)) // not empty after AsProducer - m.AsProducer([]string{"a"}) + m.AsProducer(context.TODO(), []string{"a"}) chs = m.GetProduceChannels() assert.Equal(t, 1, len(chs)) } @@ -160,7 +160,7 @@ func TestMqMsgStream_Produce(t *testing.T) { msgPack := &MsgPack{ Msgs: []TsMsg{insertMsg}, } - err = m.Produce(msgPack) + err = m.Produce(context.TODO(), msgPack) assert.Error(t, err) } @@ -173,7 +173,7 @@ func TestMqMsgStream_Broadcast(t *testing.T) { assert.NoError(t, err) // Broadcast nil pointer - _, err = m.Broadcast(nil) + _, err = m.Broadcast(context.TODO(), nil) assert.Error(t, err) } @@ -241,7 +241,7 @@ func initRmqStream(ctx context.Context, rmqClient, _ := rmq.NewClientWithDefaultOptions(ctx) inputStream, _ := NewMqMsgStream(ctx, 100, 100, rmqClient, factory.NewUnmarshalDispatcher()) - inputStream.AsProducer(producerChannels) + inputStream.AsProducer(ctx, producerChannels) for _, opt := range opts { inputStream.SetRepackFunc(opt) } @@ -265,7 +265,7 @@ func initRmqTtStream(ctx context.Context, rmqClient, _ := rmq.NewClientWithDefaultOptions(ctx) inputStream, _ := NewMqMsgStream(ctx, 100, 100, rmqClient, factory.NewUnmarshalDispatcher()) - inputStream.AsProducer(producerChannels) + inputStream.AsProducer(ctx, producerChannels) for _, opt := range opts { inputStream.SetRepackFunc(opt) } @@ -290,7 +290,7 @@ func TestStream_RmqMsgStream_Insert(t *testing.T) { ctx := context.Background() inputStream, outputStream := initRmqStream(ctx, producerChannels, consumerChannels, consumerGroupName) - err := inputStream.Produce(&msgPack) + err := inputStream.Produce(ctx, &msgPack) require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err)) receiveMsg(ctx, outputStream, len(msgPack.Msgs)) @@ -316,13 +316,13 @@ func TestStream_RmqTtMsgStream_Insert(t *testing.T) { ctx := context.Background() inputStream, outputStream := initRmqTtStream(ctx, producerChannels, consumerChannels, consumerSubName) - _, err := inputStream.Broadcast(&msgPack0) + _, err := inputStream.Broadcast(ctx, &msgPack0) require.NoErrorf(t, err, fmt.Sprintf("broadcast error = %v", err)) - err = inputStream.Produce(&msgPack1) + err = inputStream.Produce(ctx, &msgPack1) require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err)) - _, err = inputStream.Broadcast(&msgPack2) + _, err = inputStream.Broadcast(ctx, &msgPack2) require.NoErrorf(t, err, fmt.Sprintf("broadcast error = %v", err)) receiveMsg(ctx, outputStream, len(msgPack1.Msgs)) @@ -355,13 +355,13 @@ func TestStream_RmqTtMsgStream_DuplicatedIDs(t *testing.T) { ctx := context.Background() inputStream, outputStream := initRmqTtStream(ctx, producerChannels, consumerChannels, consumerSubName) - _, err := inputStream.Broadcast(&msgPack0) + _, err := inputStream.Broadcast(ctx, &msgPack0) assert.NoError(t, err) - err = inputStream.Produce(&msgPack1) + err = inputStream.Produce(ctx, &msgPack1) assert.NoError(t, err) - err = inputStream.Produce(&msgPack2) + err = inputStream.Produce(ctx, &msgPack2) assert.NoError(t, err) - _, err = inputStream.Broadcast(&msgPack3) + _, err = inputStream.Broadcast(ctx, &msgPack3) assert.NoError(t, err) receivedMsg := consumer(ctx, outputStream) @@ -425,21 +425,21 @@ func TestStream_RmqTtMsgStream_Seek(t *testing.T) { ctx := context.Background() inputStream, outputStream := initRmqTtStream(ctx, producerChannels, consumerChannels, consumerSubName) - _, err := inputStream.Broadcast(&msgPack0) + _, err := inputStream.Broadcast(ctx, &msgPack0) assert.NoError(t, err) - err = inputStream.Produce(&msgPack1) + err = inputStream.Produce(ctx, &msgPack1) assert.NoError(t, err) - _, err = inputStream.Broadcast(&msgPack2) + _, err = inputStream.Broadcast(ctx, &msgPack2) assert.NoError(t, err) - err = inputStream.Produce(&msgPack3) + err = inputStream.Produce(ctx, &msgPack3) assert.NoError(t, err) - _, err = inputStream.Broadcast(&msgPack4) + _, err = inputStream.Broadcast(ctx, &msgPack4) assert.NoError(t, err) - err = inputStream.Produce(&msgPack5) + err = inputStream.Produce(ctx, &msgPack5) assert.NoError(t, err) - _, err = inputStream.Broadcast(&msgPack6) + _, err = inputStream.Broadcast(ctx, &msgPack6) assert.NoError(t, err) - _, err = inputStream.Broadcast(&msgPack7) + _, err = inputStream.Broadcast(ctx, &msgPack7) assert.NoError(t, err) receivedMsg := consumer(ctx, outputStream) @@ -512,7 +512,7 @@ func TestStream_RMqMsgStream_SeekInvalidMessage(t *testing.T) { msgPack.Msgs = append(msgPack.Msgs, insertMsg) } - err := inputStream.Produce(msgPack) + err := inputStream.Produce(ctx, msgPack) assert.NoError(t, err) var seekPosition *msgpb.MsgPosition for i := 0; i < 10; i++ { @@ -546,7 +546,7 @@ func TestStream_RMqMsgStream_SeekInvalidMessage(t *testing.T) { insertMsg := getTsMsg(commonpb.MsgType_Insert, int64(i)) msgPack.Msgs = append(msgPack.Msgs, insertMsg) } - err = inputStream.Produce(msgPack) + err = inputStream.Produce(ctx, msgPack) assert.NoError(t, err) result := consumer(ctx, outputStream2) @@ -560,27 +560,28 @@ func TestStream_RmqTtMsgStream_AsConsumerWithPosition(t *testing.T) { producerChannels := []string{"insert1"} consumerChannels := []string{"insert1"} consumerSubName := "subInsert" + ctx := context.Background() factory := ProtoUDFactory{} rmqClient, _ := rmq.NewClientWithDefaultOptions(context.Background()) otherInputStream, _ := NewMqMsgStream(context.Background(), 100, 100, rmqClient, factory.NewUnmarshalDispatcher()) - otherInputStream.AsProducer([]string{"root_timetick"}) - otherInputStream.Produce(getTimeTickMsgPack(999)) + otherInputStream.AsProducer(context.TODO(), []string{"root_timetick"}) + otherInputStream.Produce(ctx, getTimeTickMsgPack(999)) inputStream, _ := NewMqMsgStream(context.Background(), 100, 100, rmqClient, factory.NewUnmarshalDispatcher()) - inputStream.AsProducer(producerChannels) + inputStream.AsProducer(context.TODO(), producerChannels) for i := 0; i < 100; i++ { - inputStream.Produce(getTimeTickMsgPack(int64(i))) + inputStream.Produce(ctx, getTimeTickMsgPack(int64(i))) } rmqClient2, _ := rmq.NewClientWithDefaultOptions(context.Background()) outputStream, _ := NewMqMsgStream(context.Background(), 100, 100, rmqClient2, factory.NewUnmarshalDispatcher()) outputStream.AsConsumer(context.Background(), consumerChannels, consumerSubName, mqcommon.SubscriptionPositionLatest) - inputStream.Produce(getTimeTickMsgPack(1000)) + inputStream.Produce(ctx, getTimeTickMsgPack(1000)) pack := <-outputStream.Chan() assert.NotNil(t, pack) assert.Equal(t, 1, len(pack.Msgs)) diff --git a/pkg/mq/msgstream/mqwrapper/client.go b/pkg/mq/msgstream/mqwrapper/client.go index 3ec394a4db..e0ac2128b5 100644 --- a/pkg/mq/msgstream/mqwrapper/client.go +++ b/pkg/mq/msgstream/mqwrapper/client.go @@ -17,16 +17,18 @@ package mqwrapper import ( + "context" + "github.com/milvus-io/milvus/pkg/mq/common" ) // Client is the interface that provides operations of message queues type Client interface { // CreateProducer creates a producer instance - CreateProducer(options common.ProducerOptions) (Producer, error) + CreateProducer(ctx context.Context, options common.ProducerOptions) (Producer, error) // Subscribe creates a consumer instance and subscribe a topic - Subscribe(options ConsumerOptions) (Consumer, error) + Subscribe(ctx context.Context, options ConsumerOptions) (Consumer, error) // Get the earliest MessageID EarliestMessageID() common.MessageID diff --git a/pkg/mq/msgstream/mqwrapper/kafka/kafka_client.go b/pkg/mq/msgstream/mqwrapper/kafka/kafka_client.go index 9950d6e164..91b9753ce0 100644 --- a/pkg/mq/msgstream/mqwrapper/kafka/kafka_client.go +++ b/pkg/mq/msgstream/mqwrapper/kafka/kafka_client.go @@ -205,7 +205,7 @@ func (kc *kafkaClient) newConsumerConfig(group string, offset common.Subscriptio return newConf } -func (kc *kafkaClient) CreateProducer(options common.ProducerOptions) (mqwrapper.Producer, error) { +func (kc *kafkaClient) CreateProducer(ctx context.Context, options common.ProducerOptions) (mqwrapper.Producer, error) { start := timerecord.NewTimeRecorder("create producer") metrics.MsgStreamOpCounter.WithLabelValues(metrics.CreateProducerLabel, metrics.TotalLabel).Inc() @@ -224,7 +224,7 @@ func (kc *kafkaClient) CreateProducer(options common.ProducerOptions) (mqwrapper return producer, nil } -func (kc *kafkaClient) Subscribe(options mqwrapper.ConsumerOptions) (mqwrapper.Consumer, error) { +func (kc *kafkaClient) Subscribe(ctx context.Context, options mqwrapper.ConsumerOptions) (mqwrapper.Consumer, error) { start := timerecord.NewTimeRecorder("create consumer") metrics.MsgStreamOpCounter.WithLabelValues(metrics.CreateConsumerLabel, metrics.TotalLabel).Inc() diff --git a/pkg/mq/msgstream/mqwrapper/kafka/kafka_client_test.go b/pkg/mq/msgstream/mqwrapper/kafka/kafka_client_test.go index 63559ef71a..565fc67cad 100644 --- a/pkg/mq/msgstream/mqwrapper/kafka/kafka_client_test.go +++ b/pkg/mq/msgstream/mqwrapper/kafka/kafka_client_test.go @@ -64,7 +64,7 @@ func BytesToInt(b []byte) int { // Consume1 will consume random messages and record the last MessageID it received func Consume1(ctx context.Context, t *testing.T, kc *kafkaClient, topic string, subName string, c chan mqcommon.MessageID, total *int) { - consumer, err := kc.Subscribe(mqwrapper.ConsumerOptions{ + consumer, err := kc.Subscribe(ctx, mqwrapper.ConsumerOptions{ Topic: topic, SubscriptionName: subName, BufSize: 1024, @@ -103,7 +103,7 @@ func Consume1(ctx context.Context, t *testing.T, kc *kafkaClient, topic string, // Consume2 will consume messages from specified MessageID func Consume2(ctx context.Context, t *testing.T, kc *kafkaClient, topic string, subName string, msgID mqcommon.MessageID, total *int) { - consumer, err := kc.Subscribe(mqwrapper.ConsumerOptions{ + consumer, err := kc.Subscribe(ctx, mqwrapper.ConsumerOptions{ Topic: topic, SubscriptionName: subName, BufSize: 1024, @@ -139,7 +139,7 @@ func Consume2(ctx context.Context, t *testing.T, kc *kafkaClient, topic string, } func Consume3(ctx context.Context, t *testing.T, kc *kafkaClient, topic string, subName string, total *int) { - consumer, err := kc.Subscribe(mqwrapper.ConsumerOptions{ + consumer, err := kc.Subscribe(ctx, mqwrapper.ConsumerOptions{ Topic: topic, SubscriptionName: subName, BufSize: 1024, @@ -418,7 +418,7 @@ func createConsumer(t *testing.T, groupID string, initPosition mqcommon.SubscriptionInitialPosition, ) mqwrapper.Consumer { - consumer, err := kc.Subscribe(mqwrapper.ConsumerOptions{ + consumer, err := kc.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{ Topic: topic, SubscriptionName: groupID, BufSize: 1024, @@ -429,7 +429,7 @@ func createConsumer(t *testing.T, } func createProducer(t *testing.T, kc *kafkaClient, topic string) mqwrapper.Producer { - producer, err := kc.CreateProducer(mqcommon.ProducerOptions{Topic: topic}) + producer, err := kc.CreateProducer(context.TODO(), mqcommon.ProducerOptions{Topic: topic}) assert.NoError(t, err) assert.NotNil(t, producer) return producer diff --git a/pkg/mq/msgstream/mqwrapper/kafka/kafka_producer_test.go b/pkg/mq/msgstream/mqwrapper/kafka/kafka_producer_test.go index c2f2b771f5..ab8dbad91a 100644 --- a/pkg/mq/msgstream/mqwrapper/kafka/kafka_producer_test.go +++ b/pkg/mq/msgstream/mqwrapper/kafka/kafka_producer_test.go @@ -23,7 +23,7 @@ func TestKafkaProducer_SendSuccess(t *testing.T) { rand.Seed(time.Now().UnixNano()) topic := fmt.Sprintf("test-topic-%d", rand.Int()) - producer, err := kc.CreateProducer(common.ProducerOptions{Topic: topic}) + producer, err := kc.CreateProducer(context.TODO(), common.ProducerOptions{Topic: topic}) assert.NoError(t, err) assert.NotNil(t, producer) @@ -76,7 +76,7 @@ func TestKafkaProducer_SendFailAfterClose(t *testing.T) { rand.Seed(time.Now().UnixNano()) topic := fmt.Sprintf("test-topic-%d", rand.Int()) - producer, err := kc.CreateProducer(common.ProducerOptions{Topic: topic}) + producer, err := kc.CreateProducer(context.TODO(), common.ProducerOptions{Topic: topic}) assert.Nil(t, err) assert.NotNil(t, producer) diff --git a/pkg/mq/msgstream/mqwrapper/nmq/nmq_client.go b/pkg/mq/msgstream/mqwrapper/nmq/nmq_client.go index 1a6fb8493c..6306025940 100644 --- a/pkg/mq/msgstream/mqwrapper/nmq/nmq_client.go +++ b/pkg/mq/msgstream/mqwrapper/nmq/nmq_client.go @@ -80,7 +80,7 @@ func NewClient(url string, options ...nats.Option) (*nmqClient, error) { } // CreateProducer creates a producer for natsmq client -func (nc *nmqClient) CreateProducer(options common.ProducerOptions) (mqwrapper.Producer, error) { +func (nc *nmqClient) CreateProducer(ctx context.Context, options common.ProducerOptions) (mqwrapper.Producer, error) { start := timerecord.NewTimeRecorder("create producer") metrics.MsgStreamOpCounter.WithLabelValues(metrics.CreateProducerLabel, metrics.TotalLabel).Inc() @@ -112,7 +112,7 @@ func (nc *nmqClient) CreateProducer(options common.ProducerOptions) (mqwrapper.P return &rp, nil } -func (nc *nmqClient) Subscribe(options mqwrapper.ConsumerOptions) (mqwrapper.Consumer, error) { +func (nc *nmqClient) Subscribe(ctx context.Context, options mqwrapper.ConsumerOptions) (mqwrapper.Consumer, error) { start := timerecord.NewTimeRecorder("create consumer") metrics.MsgStreamOpCounter.WithLabelValues(metrics.CreateConsumerLabel, metrics.TotalLabel).Inc() diff --git a/pkg/mq/msgstream/mqwrapper/nmq/nmq_client_test.go b/pkg/mq/msgstream/mqwrapper/nmq/nmq_client_test.go index f2e35b2350..38fcd781e1 100644 --- a/pkg/mq/msgstream/mqwrapper/nmq/nmq_client_test.go +++ b/pkg/mq/msgstream/mqwrapper/nmq/nmq_client_test.go @@ -86,7 +86,7 @@ func TestNmqClient_CreateProducer(t *testing.T) { topic := "TestNmqClient_CreateProducer" proOpts := common.ProducerOptions{Topic: topic} - producer, err := client.CreateProducer(proOpts) + producer, err := client.CreateProducer(context.TODO(), proOpts) assert.NoError(t, err) assert.NotNil(t, producer) defer producer.Close() @@ -102,7 +102,7 @@ func TestNmqClient_CreateProducer(t *testing.T) { assert.NoError(t, err) invalidOpts := common.ProducerOptions{Topic: ""} - producer, e := client.CreateProducer(invalidOpts) + producer, e := client.CreateProducer(context.TODO(), invalidOpts) assert.Nil(t, producer) assert.Error(t, e) } @@ -114,7 +114,7 @@ func TestNmqClient_GetLatestMsg(t *testing.T) { topic := fmt.Sprintf("t2GetLatestMsg-%d", rand.Int()) proOpts := common.ProducerOptions{Topic: topic} - producer, err := client.CreateProducer(proOpts) + producer, err := client.CreateProducer(context.TODO(), proOpts) assert.NoError(t, err) defer producer.Close() @@ -135,7 +135,7 @@ func TestNmqClient_GetLatestMsg(t *testing.T) { BufSize: 1024, } - consumer, err := client.Subscribe(consumerOpts) + consumer, err := client.Subscribe(context.TODO(), consumerOpts) assert.NoError(t, err) expectLastMsg, err := consumer.GetLatestMsgID() @@ -166,13 +166,13 @@ func TestNmqClient_IllegalSubscribe(t *testing.T) { assert.NotNil(t, client) defer client.Close() - sub, err := client.Subscribe(mqwrapper.ConsumerOptions{ + sub, err := client.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{ Topic: "", }) assert.Nil(t, sub) assert.Error(t, err) - sub, err = client.Subscribe(mqwrapper.ConsumerOptions{ + sub, err = client.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{ Topic: "123", SubscriptionName: "", }) @@ -188,7 +188,7 @@ func TestNmqClient_Subscribe(t *testing.T) { topic := "TestNmqClient_Subscribe" proOpts := common.ProducerOptions{Topic: topic} - producer, err := client.CreateProducer(proOpts) + producer, err := client.CreateProducer(context.TODO(), proOpts) assert.NoError(t, err) assert.NotNil(t, producer) defer producer.Close() @@ -201,12 +201,12 @@ func TestNmqClient_Subscribe(t *testing.T) { BufSize: 1024, } - consumer, err := client.Subscribe(consumerOpts) + consumer, err := client.Subscribe(context.TODO(), consumerOpts) assert.Error(t, err) assert.Nil(t, consumer) consumerOpts.Topic = topic - consumer, err = client.Subscribe(consumerOpts) + consumer, err = client.Subscribe(context.TODO(), consumerOpts) assert.NoError(t, err) assert.NotNil(t, consumer) defer consumer.Close() diff --git a/pkg/mq/msgstream/mqwrapper/nmq/nmq_consumer_test.go b/pkg/mq/msgstream/mqwrapper/nmq/nmq_consumer_test.go index bc3652ff71..70c2607da7 100644 --- a/pkg/mq/msgstream/mqwrapper/nmq/nmq_consumer_test.go +++ b/pkg/mq/msgstream/mqwrapper/nmq/nmq_consumer_test.go @@ -36,10 +36,10 @@ func TestNatsConsumer_Subscription(t *testing.T) { topic := t.Name() proOpts := common.ProducerOptions{Topic: topic} - _, err = client.CreateProducer(proOpts) + _, err = client.CreateProducer(context.TODO(), proOpts) assert.NoError(t, err) - consumer, err := client.Subscribe(mqwrapper.ConsumerOptions{ + consumer, err := client.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{ Topic: topic, SubscriptionName: topic, SubscriptionInitialPosition: common.SubscriptionPositionEarliest, @@ -69,7 +69,7 @@ func Test_BadLatestMessageID(t *testing.T) { assert.NoError(t, err) defer client.Close() - consumer, err := client.Subscribe(mqwrapper.ConsumerOptions{ + consumer, err := client.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{ Topic: topic, SubscriptionName: topic, SubscriptionInitialPosition: common.SubscriptionPositionEarliest, @@ -88,10 +88,10 @@ func TestComsumeMessage(t *testing.T) { defer client.Close() topic := t.Name() - p, err := client.CreateProducer(common.ProducerOptions{Topic: topic}) + p, err := client.CreateProducer(context.TODO(), common.ProducerOptions{Topic: topic}) assert.NoError(t, err) - c, err := client.Subscribe(mqwrapper.ConsumerOptions{ + c, err := client.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{ Topic: topic, SubscriptionName: topic, SubscriptionInitialPosition: common.SubscriptionPositionEarliest, @@ -149,7 +149,7 @@ func TestNatsConsumer_Close(t *testing.T) { defer client.Close() topic := t.Name() - c, err := client.Subscribe(mqwrapper.ConsumerOptions{ + c, err := client.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{ Topic: topic, SubscriptionName: topic, SubscriptionInitialPosition: common.SubscriptionPositionEarliest, @@ -177,7 +177,7 @@ func TestNatsClientErrorOnUnsubscribeTwice(t *testing.T) { assert.NoError(t, err) defer client.Close() - consumer, err := client.Subscribe(mqwrapper.ConsumerOptions{ + consumer, err := client.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{ Topic: topic, SubscriptionName: topic, SubscriptionInitialPosition: common.SubscriptionPositionEarliest, @@ -199,7 +199,7 @@ func TestCheckTopicValid(t *testing.T) { defer client.Close() topic := t.Name() - consumer, err := client.Subscribe(mqwrapper.ConsumerOptions{ + consumer, err := client.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{ Topic: topic, SubscriptionName: topic, SubscriptionInitialPosition: common.SubscriptionPositionEarliest, @@ -220,7 +220,7 @@ func TestCheckTopicValid(t *testing.T) { assert.Error(t, err) // not empty topic can pass - pub, err := client.CreateProducer(common.ProducerOptions{ + pub, err := client.CreateProducer(context.TODO(), common.ProducerOptions{ Topic: topic, }) assert.NoError(t, err) @@ -240,7 +240,7 @@ func TestCheckTopicValid(t *testing.T) { func newTestConsumer(t *testing.T, topic string, position common.SubscriptionInitialPosition) (mqwrapper.Consumer, error) { client, err := createNmqClient() assert.NoError(t, err) - return client.Subscribe(mqwrapper.ConsumerOptions{ + return client.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{ Topic: topic, SubscriptionName: topic, SubscriptionInitialPosition: position, @@ -251,7 +251,7 @@ func newTestConsumer(t *testing.T, topic string, position common.SubscriptionIni func newProducer(t *testing.T, topic string) (*nmqClient, mqwrapper.Producer) { client, err := createNmqClient() assert.NoError(t, err) - producer, err := client.CreateProducer(common.ProducerOptions{Topic: topic}) + producer, err := client.CreateProducer(context.TODO(), common.ProducerOptions{Topic: topic}) assert.NoError(t, err) return client, producer } @@ -272,10 +272,10 @@ func TestNmqConsumer_GetLatestMsgID(t *testing.T) { defer client.Close() topic := t.Name() - p, err := client.CreateProducer(common.ProducerOptions{Topic: topic}) + p, err := client.CreateProducer(context.TODO(), common.ProducerOptions{Topic: topic}) assert.NoError(t, err) - c, err := client.Subscribe(mqwrapper.ConsumerOptions{ + c, err := client.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{ Topic: topic, SubscriptionName: topic, SubscriptionInitialPosition: common.SubscriptionPositionEarliest, @@ -301,13 +301,13 @@ func TestNmqConsumer_ConsumeFromLatest(t *testing.T) { defer client.Close() topic := t.Name() - p, err := client.CreateProducer(common.ProducerOptions{Topic: topic}) + p, err := client.CreateProducer(context.TODO(), common.ProducerOptions{Topic: topic}) assert.NoError(t, err) msgs := []string{"111", "222", "333"} process(t, msgs, p) - c, err := client.Subscribe(mqwrapper.ConsumerOptions{ + c, err := client.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{ Topic: topic, SubscriptionName: topic, SubscriptionInitialPosition: common.SubscriptionPositionLatest, @@ -331,13 +331,13 @@ func TestNmqConsumer_ConsumeFromEarliest(t *testing.T) { defer client.Close() topic := t.Name() - p, err := client.CreateProducer(common.ProducerOptions{Topic: topic}) + p, err := client.CreateProducer(context.TODO(), common.ProducerOptions{Topic: topic}) assert.NoError(t, err) msgs := []string{"111", "222"} process(t, msgs, p) - c, err := client.Subscribe(mqwrapper.ConsumerOptions{ + c, err := client.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{ Topic: topic, SubscriptionName: topic, SubscriptionInitialPosition: common.SubscriptionPositionEarliest, @@ -354,7 +354,7 @@ func TestNmqConsumer_ConsumeFromEarliest(t *testing.T) { msg = <-c.Chan() assert.Equal(t, "222", string(msg.Payload())) - c2, err := client.Subscribe(mqwrapper.ConsumerOptions{ + c2, err := client.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{ Topic: topic, SubscriptionName: topic, SubscriptionInitialPosition: common.SubscriptionPositionEarliest, diff --git a/pkg/mq/msgstream/mqwrapper/nmq/nmq_producer.go b/pkg/mq/msgstream/mqwrapper/nmq/nmq_producer.go index 69a9953824..3bc3907166 100644 --- a/pkg/mq/msgstream/mqwrapper/nmq/nmq_producer.go +++ b/pkg/mq/msgstream/mqwrapper/nmq/nmq_producer.go @@ -3,7 +3,7 @@ // distributed with this work for additional information // regarding copyright ownership. The ASF licenses this file // to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance +// "License"); you may not use this file exceapt in compliance // with the License. You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 diff --git a/pkg/mq/msgstream/mqwrapper/nmq/nmq_producer_test.go b/pkg/mq/msgstream/mqwrapper/nmq/nmq_producer_test.go index 119e1ef44e..dc974808f0 100644 --- a/pkg/mq/msgstream/mqwrapper/nmq/nmq_producer_test.go +++ b/pkg/mq/msgstream/mqwrapper/nmq/nmq_producer_test.go @@ -33,7 +33,7 @@ func TestNatsMQProducer(t *testing.T) { pOpts := common.ProducerOptions{Topic: topic} // Check Topic() - p, err := c.CreateProducer(pOpts) + p, err := c.CreateProducer(context.TODO(), pOpts) assert.NoError(t, err) assert.Equal(t, p.(*nmqProducer).Topic(), topic) diff --git a/pkg/mq/msgstream/mqwrapper/pulsar/pulsar_client.go b/pkg/mq/msgstream/mqwrapper/pulsar/pulsar_client.go index f5918870b8..c33f0df938 100644 --- a/pkg/mq/msgstream/mqwrapper/pulsar/pulsar_client.go +++ b/pkg/mq/msgstream/mqwrapper/pulsar/pulsar_client.go @@ -17,6 +17,7 @@ package pulsar import ( + "context" "fmt" "sync" "time" @@ -66,7 +67,7 @@ func NewClient(tenant string, namespace string, opts pulsar.ClientOptions) (*pul } // CreateProducer create a pulsar producer from options -func (pc *pulsarClient) CreateProducer(options mqcommon.ProducerOptions) (mqwrapper.Producer, error) { +func (pc *pulsarClient) CreateProducer(ctx context.Context, options mqcommon.ProducerOptions) (mqwrapper.Producer, error) { start := timerecord.NewTimeRecorder("create producer") metrics.MsgStreamOpCounter.WithLabelValues(metrics.CreateProducerLabel, metrics.TotalLabel).Inc() @@ -102,7 +103,7 @@ func (pc *pulsarClient) CreateProducer(options mqcommon.ProducerOptions) (mqwrap } // Subscribe creates a pulsar consumer instance and subscribe a topic -func (pc *pulsarClient) Subscribe(options mqwrapper.ConsumerOptions) (mqwrapper.Consumer, error) { +func (pc *pulsarClient) Subscribe(ctx context.Context, options mqwrapper.ConsumerOptions) (mqwrapper.Consumer, error) { start := timerecord.NewTimeRecorder("create consumer") metrics.MsgStreamOpCounter.WithLabelValues(metrics.CreateConsumerLabel, metrics.TotalLabel).Inc() diff --git a/pkg/mq/msgstream/mqwrapper/pulsar/pulsar_client_test.go b/pkg/mq/msgstream/mqwrapper/pulsar/pulsar_client_test.go index aa8ce1e450..bded0c3227 100644 --- a/pkg/mq/msgstream/mqwrapper/pulsar/pulsar_client_test.go +++ b/pkg/mq/msgstream/mqwrapper/pulsar/pulsar_client_test.go @@ -78,7 +78,7 @@ func BytesToInt(b []byte) int { } func Produce(ctx context.Context, t *testing.T, pc *pulsarClient, topic string, arr []int) { - producer, err := pc.CreateProducer(mqcommon.ProducerOptions{Topic: topic}) + producer, err := pc.CreateProducer(ctx, mqcommon.ProducerOptions{Topic: topic}) assert.NoError(t, err) assert.NotNil(t, producer) @@ -110,7 +110,7 @@ func VerifyMessage(t *testing.T, msg mqcommon.Message) { // Consume1 will consume random messages and record the last MessageID it received func Consume1(ctx context.Context, t *testing.T, pc *pulsarClient, topic string, subName string, c chan mqcommon.MessageID, total *int) { - consumer, err := pc.Subscribe(mqwrapper.ConsumerOptions{ + consumer, err := pc.Subscribe(ctx, mqwrapper.ConsumerOptions{ Topic: topic, SubscriptionName: subName, BufSize: 1024, @@ -147,7 +147,7 @@ func Consume1(ctx context.Context, t *testing.T, pc *pulsarClient, topic string, // Consume2 will consume messages from specified MessageID func Consume2(ctx context.Context, t *testing.T, pc *pulsarClient, topic string, subName string, msgID mqcommon.MessageID, total *int) { - consumer, err := pc.Subscribe(mqwrapper.ConsumerOptions{ + consumer, err := pc.Subscribe(ctx, mqwrapper.ConsumerOptions{ Topic: topic, SubscriptionName: subName, BufSize: 1024, @@ -181,7 +181,7 @@ func Consume2(ctx context.Context, t *testing.T, pc *pulsarClient, topic string, } func Consume3(ctx context.Context, t *testing.T, pc *pulsarClient, topic string, subName string, total *int) { - consumer, err := pc.Subscribe(mqwrapper.ConsumerOptions{ + consumer, err := pc.Subscribe(ctx, mqwrapper.ConsumerOptions{ Topic: topic, SubscriptionName: subName, BufSize: 1024, @@ -420,7 +420,7 @@ func TestPulsarClient_SeekPosition(t *testing.T) { topic := fmt.Sprintf("test-topic-%d", rand.Int()) subName := fmt.Sprintf("test-subname-%d", rand.Int()) - producer, err := pc.CreateProducer(mqcommon.ProducerOptions{Topic: topic}) + producer, err := pc.CreateProducer(ctx, mqcommon.ProducerOptions{Topic: topic}) assert.NoError(t, err) assert.NotNil(t, producer) @@ -498,7 +498,7 @@ func TestPulsarClient_SeekLatest(t *testing.T) { topic := fmt.Sprintf("test-topic-%d", rand.Int()) subName := fmt.Sprintf("test-subname-%d", rand.Int()) - producer, err := pc.CreateProducer(mqcommon.ProducerOptions{Topic: topic}) + producer, err := pc.CreateProducer(ctx, mqcommon.ProducerOptions{Topic: topic}) assert.NoError(t, err) assert.NotNil(t, producer) @@ -671,7 +671,7 @@ func TestPulsarClient_SubscribeExclusiveFail(t *testing.T) { client: &mockPulsarClient{}, } - _, err := pc.Subscribe(mqwrapper.ConsumerOptions{Topic: "test_topic_name"}) + _, err := pc.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{Topic: "test_topic_name"}) assert.Error(t, err) assert.True(t, retry.IsRecoverable(err)) }) @@ -686,7 +686,7 @@ func TestPulsarClient_WithTenantAndNamespace(t *testing.T) { pulsarAddress := getPulsarAddress() pc, err := NewClient(tenant, namespace, pulsar.ClientOptions{URL: pulsarAddress}) assert.NoError(t, err) - producer, err := pc.CreateProducer(mqcommon.ProducerOptions{Topic: topic}) + producer, err := pc.CreateProducer(context.TODO(), mqcommon.ProducerOptions{Topic: topic}) defer producer.Close() assert.NoError(t, err) assert.NotNil(t, producer) @@ -695,7 +695,7 @@ func TestPulsarClient_WithTenantAndNamespace(t *testing.T) { assert.NoError(t, err) assert.Equal(t, fullTopicName, producer.(*pulsarProducer).Topic()) - consumer, err := pc.Subscribe(mqwrapper.ConsumerOptions{ + consumer, err := pc.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{ Topic: topic, SubscriptionName: subName, BufSize: 1024, @@ -713,7 +713,7 @@ func TestPulsarCtl(t *testing.T) { pulsarAddress := getPulsarAddress() pc, err := NewClient(DefaultPulsarTenant, DefaultPulsarNamespace, pulsar.ClientOptions{URL: pulsarAddress}) assert.NoError(t, err) - consumer, err := pc.Subscribe(mqwrapper.ConsumerOptions{ + consumer, err := pc.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{ Topic: topic, SubscriptionName: subName, BufSize: 1024, @@ -723,7 +723,7 @@ func TestPulsarCtl(t *testing.T) { assert.NotNil(t, consumer) defer consumer.Close() - _, err = pc.Subscribe(mqwrapper.ConsumerOptions{ + _, err = pc.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{ Topic: topic, SubscriptionName: subName, BufSize: 1024, @@ -732,7 +732,7 @@ func TestPulsarCtl(t *testing.T) { assert.Error(t, err) - _, err = pc.Subscribe(mqwrapper.ConsumerOptions{ + _, err = pc.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{ Topic: topic, SubscriptionName: subName, BufSize: 1024, @@ -762,7 +762,7 @@ func TestPulsarCtl(t *testing.T) { assert.NoError(t, err) } - consumer2, err := pc.Subscribe(mqwrapper.ConsumerOptions{ + consumer2, err := pc.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{ Topic: topic, SubscriptionName: subName, BufSize: 1024, diff --git a/pkg/mq/msgstream/mqwrapper/pulsar/pulsar_consumer_test.go b/pkg/mq/msgstream/mqwrapper/pulsar/pulsar_consumer_test.go index 6f541a137f..b829c2428a 100644 --- a/pkg/mq/msgstream/mqwrapper/pulsar/pulsar_consumer_test.go +++ b/pkg/mq/msgstream/mqwrapper/pulsar/pulsar_consumer_test.go @@ -80,9 +80,9 @@ func TestComsumeCompressedMessage(t *testing.T) { assert.NoError(t, err) defer consumer.Close() - producer, err := pc.CreateProducer(mqcommon.ProducerOptions{Topic: "TestTopics"}) + producer, err := pc.CreateProducer(context.TODO(), mqcommon.ProducerOptions{Topic: "TestTopics"}) assert.NoError(t, err) - compressProducer, err := pc.CreateProducer(mqcommon.ProducerOptions{Topic: "TestTopics", EnableCompression: true}) + compressProducer, err := pc.CreateProducer(context.TODO(), mqcommon.ProducerOptions{Topic: "TestTopics", EnableCompression: true}) assert.NoError(t, err) msg := []byte("test message") diff --git a/pkg/mq/msgstream/mqwrapper/pulsar/pulsar_producer_test.go b/pkg/mq/msgstream/mqwrapper/pulsar/pulsar_producer_test.go index ebace99df1..5c4fdf5f2f 100644 --- a/pkg/mq/msgstream/mqwrapper/pulsar/pulsar_producer_test.go +++ b/pkg/mq/msgstream/mqwrapper/pulsar/pulsar_producer_test.go @@ -34,7 +34,7 @@ func TestPulsarProducer(t *testing.T) { assert.NotNil(t, pc) topic := "TEST" - producer, err := pc.CreateProducer(common.ProducerOptions{Topic: topic}) + producer, err := pc.CreateProducer(context.TODO(), common.ProducerOptions{Topic: topic}) assert.NoError(t, err) assert.NotNil(t, producer) diff --git a/pkg/mq/msgstream/mqwrapper/rmq/rmq_client.go b/pkg/mq/msgstream/mqwrapper/rmq/rmq_client.go index ce7f1d7cc9..b40192fe85 100644 --- a/pkg/mq/msgstream/mqwrapper/rmq/rmq_client.go +++ b/pkg/mq/msgstream/mqwrapper/rmq/rmq_client.go @@ -58,7 +58,7 @@ func NewClient(opts client.Options) (*rmqClient, error) { } // CreateProducer creates a producer for rocksmq client -func (rc *rmqClient) CreateProducer(options common.ProducerOptions) (mqwrapper.Producer, error) { +func (rc *rmqClient) CreateProducer(ctx context.Context, options common.ProducerOptions) (mqwrapper.Producer, error) { start := timerecord.NewTimeRecorder("create producer") metrics.MsgStreamOpCounter.WithLabelValues(metrics.CreateProducerLabel, metrics.TotalLabel).Inc() @@ -77,7 +77,7 @@ func (rc *rmqClient) CreateProducer(options common.ProducerOptions) (mqwrapper.P } // Subscribe subscribes a consumer in rmq client -func (rc *rmqClient) Subscribe(options mqwrapper.ConsumerOptions) (mqwrapper.Consumer, error) { +func (rc *rmqClient) Subscribe(ctx context.Context, options mqwrapper.ConsumerOptions) (mqwrapper.Consumer, error) { start := timerecord.NewTimeRecorder("create consumer") metrics.MsgStreamOpCounter.WithLabelValues(metrics.CreateConsumerLabel, metrics.TotalLabel).Inc() diff --git a/pkg/mq/msgstream/mqwrapper/rmq/rmq_client_test.go b/pkg/mq/msgstream/mqwrapper/rmq/rmq_client_test.go index 87c007e25b..4c42275b6d 100644 --- a/pkg/mq/msgstream/mqwrapper/rmq/rmq_client_test.go +++ b/pkg/mq/msgstream/mqwrapper/rmq/rmq_client_test.go @@ -65,7 +65,7 @@ func TestRmqClient_CreateProducer(t *testing.T) { topic := "TestRmqClient_CreateProducer" proOpts := common.ProducerOptions{Topic: topic} - producer, err := client.CreateProducer(proOpts) + producer, err := client.CreateProducer(context.TODO(), proOpts) assert.NoError(t, err) assert.NotNil(t, producer) @@ -83,7 +83,7 @@ func TestRmqClient_CreateProducer(t *testing.T) { assert.NoError(t, err) invalidOpts := common.ProducerOptions{Topic: ""} - producer, e := client.CreateProducer(invalidOpts) + producer, e := client.CreateProducer(context.TODO(), invalidOpts) assert.Nil(t, producer) assert.Error(t, e) } @@ -95,7 +95,7 @@ func TestRmqClient_GetLatestMsg(t *testing.T) { topic := fmt.Sprintf("t2GetLatestMsg-%d", rand.Int()) proOpts := common.ProducerOptions{Topic: topic} - producer, err := client.CreateProducer(proOpts) + producer, err := client.CreateProducer(context.TODO(), proOpts) assert.NoError(t, err) defer producer.Close() @@ -116,7 +116,7 @@ func TestRmqClient_GetLatestMsg(t *testing.T) { BufSize: 1024, } - consumer, err := client.Subscribe(consumerOpts) + consumer, err := client.Subscribe(context.TODO(), consumerOpts) assert.NoError(t, err) expectLastMsg, err := consumer.GetLatestMsgID() @@ -149,7 +149,7 @@ func TestRmqClient_Subscribe(t *testing.T) { topic := "TestRmqClient_Subscribe" proOpts := common.ProducerOptions{Topic: topic} - producer, err := client.CreateProducer(proOpts) + producer, err := client.CreateProducer(context.TODO(), proOpts) assert.NoError(t, err) assert.NotNil(t, producer) defer producer.Close() @@ -161,7 +161,7 @@ func TestRmqClient_Subscribe(t *testing.T) { SubscriptionInitialPosition: common.SubscriptionPositionEarliest, BufSize: 0, } - consumer, err := client.Subscribe(consumerOpts) + consumer, err := client.Subscribe(context.TODO(), consumerOpts) assert.Error(t, err) assert.Nil(t, consumer) @@ -172,12 +172,12 @@ func TestRmqClient_Subscribe(t *testing.T) { BufSize: 1024, } - consumer, err = client.Subscribe(consumerOpts) + consumer, err = client.Subscribe(context.TODO(), consumerOpts) assert.Error(t, err) assert.Nil(t, consumer) consumerOpts.Topic = topic - consumer, err = client.Subscribe(consumerOpts) + consumer, err = client.Subscribe(context.TODO(), consumerOpts) defer consumer.Close() assert.NoError(t, err) assert.NotNil(t, consumer) diff --git a/pkg/mq/msgstream/msgstream.go b/pkg/mq/msgstream/msgstream.go index 4d6e3b0a9c..6dcc8271e2 100644 --- a/pkg/mq/msgstream/msgstream.go +++ b/pkg/mq/msgstream/msgstream.go @@ -55,11 +55,11 @@ type RepackFunc func(msgs []TsMsg, hashKeys [][]int32) (map[int32]*MsgPack, erro type MsgStream interface { Close() - AsProducer(channels []string) - Produce(*MsgPack) error + AsProducer(ctx context.Context, channels []string) + Produce(context.Context, *MsgPack) error SetRepackFunc(repackFunc RepackFunc) GetProduceChannels() []string - Broadcast(*MsgPack) (map[string][]MessageID, error) + Broadcast(context.Context, *MsgPack) (map[string][]MessageID, error) AsConsumer(ctx context.Context, channels []string, subName string, position common.SubscriptionInitialPosition) error Chan() <-chan *MsgPack diff --git a/pkg/mq/msgstream/msgstream_util_test.go b/pkg/mq/msgstream/msgstream_util_test.go index 66f4bbce45..f8d1754eac 100644 --- a/pkg/mq/msgstream/msgstream_util_test.go +++ b/pkg/mq/msgstream/msgstream_util_test.go @@ -36,7 +36,7 @@ func TestPulsarMsgUtil(t *testing.T) { defer msgStream.Close() // create a topic - msgStream.AsProducer([]string{"test"}) + msgStream.AsProducer(ctx, []string{"test"}) UnsubscribeChannels(ctx, pmsFactory, "sub", []string{"test"}) } diff --git a/pkg/mq/msgstream/stream_bench_test.go b/pkg/mq/msgstream/stream_bench_test.go index ca69642244..ca7676ae5f 100644 --- a/pkg/mq/msgstream/stream_bench_test.go +++ b/pkg/mq/msgstream/stream_bench_test.go @@ -46,7 +46,7 @@ func benchmarkProduceAndConsume(b *testing.B, mqClient mqwrapper.Client, cases [ go func() { defer wg.Done() - p, err := mqClient.CreateProducer(common.ProducerOptions{ + p, err := mqClient.CreateProducer(context.TODO(), common.ProducerOptions{ Topic: topic, }) assert.NoError(b, err) @@ -55,7 +55,7 @@ func benchmarkProduceAndConsume(b *testing.B, mqClient mqwrapper.Client, cases [ }() go func() { defer wg.Done() - c, _ := mqClient.Subscribe(mqwrapper.ConsumerOptions{ + c, _ := mqClient.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{ Topic: topic, SubscriptionName: topic, SubscriptionInitialPosition: common.SubscriptionPositionEarliest, diff --git a/pkg/mq/msgstream/stream_test.go b/pkg/mq/msgstream/stream_test.go index bcdc400498..d752d34529 100644 --- a/pkg/mq/msgstream/stream_test.go +++ b/pkg/mq/msgstream/stream_test.go @@ -40,13 +40,13 @@ func testStreamOperation(t *testing.T, mqClient mqwrapper.Client) { func testConcurrentStream(t *testing.T, mqClient mqwrapper.Client) { topics := getChannel(2) - producer, err := mqClient.CreateProducer(common.ProducerOptions{ + producer, err := mqClient.CreateProducer(context.TODO(), common.ProducerOptions{ Topic: topics[0], }) defer producer.Close() assert.NoError(t, err) - consumer, err := mqClient.Subscribe(mqwrapper.ConsumerOptions{ + consumer, err := mqClient.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{ Topic: topics[0], SubscriptionName: funcutil.RandomString(8), SubscriptionInitialPosition: common.SubscriptionPositionEarliest, @@ -61,7 +61,7 @@ func testConcurrentStream(t *testing.T, mqClient mqwrapper.Client) { func testConcurrentStreamAndSubscribeLast(t *testing.T, mqClient mqwrapper.Client) { topics := getChannel(2) - producer, err := mqClient.CreateProducer(common.ProducerOptions{ + producer, err := mqClient.CreateProducer(context.TODO(), common.ProducerOptions{ Topic: topics[0], }) defer producer.Close() @@ -69,7 +69,7 @@ func testConcurrentStreamAndSubscribeLast(t *testing.T, mqClient mqwrapper.Clien ids := sendMessages(context.Background(), t, producer, generateRandMessage(1024, 1000)) - consumer, err := mqClient.Subscribe(mqwrapper.ConsumerOptions{ + consumer, err := mqClient.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{ Topic: topics[0], SubscriptionName: funcutil.RandomString(8), SubscriptionInitialPosition: common.SubscriptionPositionLatest, @@ -90,7 +90,7 @@ func testConcurrentStreamAndSubscribeLast(t *testing.T, mqClient mqwrapper.Clien func testConcurrentStreamAndSeekInclusive(t *testing.T, mqClient mqwrapper.Client) { topics := getChannel(2) - producer, err := mqClient.CreateProducer(common.ProducerOptions{ + producer, err := mqClient.CreateProducer(context.TODO(), common.ProducerOptions{ Topic: topics[0], }) defer producer.Close() @@ -99,7 +99,7 @@ func testConcurrentStreamAndSeekInclusive(t *testing.T, mqClient mqwrapper.Clien cases := generateRandMessage(1024, 1000) ids := sendMessages(context.Background(), t, producer, cases) - consumer, err := mqClient.Subscribe(mqwrapper.ConsumerOptions{ + consumer, err := mqClient.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{ Topic: topics[0], SubscriptionName: funcutil.RandomString(8), SubscriptionInitialPosition: common.SubscriptionPositionUnknown, @@ -124,7 +124,7 @@ func testConcurrentStreamAndSeekInclusive(t *testing.T, mqClient mqwrapper.Clien func testConcurrentStreamAndSeekNoInclusive(t *testing.T, mqClient mqwrapper.Client) { topics := getChannel(2) - producer, err := mqClient.CreateProducer(common.ProducerOptions{ + producer, err := mqClient.CreateProducer(context.TODO(), common.ProducerOptions{ Topic: topics[0], }) defer producer.Close() @@ -133,7 +133,7 @@ func testConcurrentStreamAndSeekNoInclusive(t *testing.T, mqClient mqwrapper.Cli cases := generateRandMessage(1024, 1000) ids := sendMessages(context.Background(), t, producer, cases) - consumer, err := mqClient.Subscribe(mqwrapper.ConsumerOptions{ + consumer, err := mqClient.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{ Topic: topics[0], SubscriptionName: funcutil.RandomString(8), SubscriptionInitialPosition: common.SubscriptionPositionUnknown, @@ -158,7 +158,7 @@ func testConcurrentStreamAndSeekNoInclusive(t *testing.T, mqClient mqwrapper.Cli func testConcurrentStreamAndSeekToLast(t *testing.T, mqClient mqwrapper.Client) { topics := getChannel(2) - producer, err := mqClient.CreateProducer(common.ProducerOptions{ + producer, err := mqClient.CreateProducer(context.TODO(), common.ProducerOptions{ Topic: topics[0], }) defer producer.Close() @@ -167,7 +167,7 @@ func testConcurrentStreamAndSeekToLast(t *testing.T, mqClient mqwrapper.Client) cases := generateRandMessage(1024, 1000) sendMessages(context.Background(), t, producer, cases) - consumer, err := mqClient.Subscribe(mqwrapper.ConsumerOptions{ + consumer, err := mqClient.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{ Topic: topics[0], SubscriptionName: funcutil.RandomString(8), SubscriptionInitialPosition: common.SubscriptionPositionUnknown, diff --git a/pkg/mq/msgstream/wasted_mock_msgstream.go b/pkg/mq/msgstream/wasted_mock_msgstream.go index 73b4cc0c08..2efc0ff0e5 100644 --- a/pkg/mq/msgstream/wasted_mock_msgstream.go +++ b/pkg/mq/msgstream/wasted_mock_msgstream.go @@ -1,5 +1,7 @@ package msgstream +import "context" + type WastedMockMsgStream struct { MsgStream AsProducerFunc func(channels []string) @@ -12,11 +14,11 @@ func NewWastedMockMsgStream() *WastedMockMsgStream { return &WastedMockMsgStream{} } -func (m WastedMockMsgStream) AsProducer(channels []string) { +func (m WastedMockMsgStream) AsProducer(ctx context.Context, channels []string) { m.AsProducerFunc(channels) } -func (m WastedMockMsgStream) Broadcast(pack *MsgPack) (map[string][]MessageID, error) { +func (m WastedMockMsgStream) Broadcast(ctx context.Context, pack *MsgPack) (map[string][]MessageID, error) { return m.BroadcastMarkFunc(pack) }