enhance: refine pular related mq interfaces (#38007)

issue: #35917 
Refines the pulsar-related mq APIs to allow the ctx to be passed down

Signed-off-by: tinswzy <zhenyuan.wei@zilliz.com>
This commit is contained in:
tinswzy 2024-12-04 20:50:39 +08:00 committed by GitHub
parent 73aa95f596
commit 5768dbbb5d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
50 changed files with 380 additions and 367 deletions

View File

@ -406,7 +406,7 @@ func (t *compactionTrigger) handleSignal(signal *compactionSignal) {
return return
} }
segment := t.meta.GetHealthySegment(t.meta.ctx, signal.segmentID) segment := t.meta.GetHealthySegment(context.TODO(), signal.segmentID)
if segment == nil { if segment == nil {
log.Warn("segment in compaction signal not found in meta", zap.Int64("segmentID", signal.segmentID)) log.Warn("segment in compaction signal not found in meta", zap.Int64("segmentID", signal.segmentID))
return return

View File

@ -68,7 +68,7 @@ func (mtm *mockTtMsgStream) Chan() <-chan *msgstream.MsgPack {
return make(chan *msgstream.MsgPack, 100) return make(chan *msgstream.MsgPack, 100)
} }
func (mtm *mockTtMsgStream) AsProducer(channels []string) {} func (mtm *mockTtMsgStream) AsProducer(ctx context.Context, channels []string) {}
func (mtm *mockTtMsgStream) AsConsumer(ctx context.Context, channels []string, subName string, position common.SubscriptionInitialPosition) error { func (mtm *mockTtMsgStream) AsConsumer(ctx context.Context, channels []string, subName string, position common.SubscriptionInitialPosition) error {
return nil return nil
@ -80,11 +80,11 @@ func (mtm *mockTtMsgStream) GetProduceChannels() []string {
return make([]string, 0) return make([]string, 0)
} }
func (mtm *mockTtMsgStream) Produce(*msgstream.MsgPack) error { func (mtm *mockTtMsgStream) Produce(context.Context, *msgstream.MsgPack) error {
return nil return nil
} }
func (mtm *mockTtMsgStream) Broadcast(*msgstream.MsgPack) (map[string][]msgstream.MessageID, error) { func (mtm *mockTtMsgStream) Broadcast(context.Context, *msgstream.MsgPack) (map[string][]msgstream.MessageID, error) {
return nil, nil return nil, nil
} }

View File

@ -39,7 +39,7 @@ import (
type channelsMgr interface { type channelsMgr interface {
getChannels(collectionID UniqueID) ([]pChan, error) getChannels(collectionID UniqueID) ([]pChan, error)
getVChannels(collectionID UniqueID) ([]vChan, error) getVChannels(collectionID UniqueID) ([]vChan, error)
getOrCreateDmlStream(collectionID UniqueID) (msgstream.MsgStream, error) getOrCreateDmlStream(ctx context.Context, collectionID UniqueID) (msgstream.MsgStream, error)
removeDMLStream(collectionID UniqueID) removeDMLStream(collectionID UniqueID)
removeAllDMLStream() removeAllDMLStream()
} }
@ -172,7 +172,7 @@ func (mgr *singleTypeChannelsMgr) streamExistPrivate(collectionID UniqueID) bool
return ok && streamInfos.stream != nil return ok && streamInfos.stream != nil
} }
func createStream(factory msgstream.Factory, pchans []pChan, repack repackFuncType) (msgstream.MsgStream, error) { func createStream(ctx context.Context, factory msgstream.Factory, pchans []pChan, repack repackFuncType) (msgstream.MsgStream, error) {
var stream msgstream.MsgStream var stream msgstream.MsgStream
var err error var err error
@ -181,7 +181,7 @@ func createStream(factory msgstream.Factory, pchans []pChan, repack repackFuncTy
return nil, err return nil, err
} }
stream.AsProducer(pchans) stream.AsProducer(ctx, pchans)
if repack != nil { if repack != nil {
stream.SetRepackFunc(repack) stream.SetRepackFunc(repack)
} }
@ -202,7 +202,7 @@ func decPChanMetrics(pchans []pChan) {
// createMsgStream create message stream for specified collection. Idempotent. // createMsgStream create message stream for specified collection. Idempotent.
// If stream already exists, directly return it and no error will be returned. // If stream already exists, directly return it and no error will be returned.
func (mgr *singleTypeChannelsMgr) createMsgStream(collectionID UniqueID) (msgstream.MsgStream, error) { func (mgr *singleTypeChannelsMgr) createMsgStream(ctx context.Context, collectionID UniqueID) (msgstream.MsgStream, error) {
mgr.mu.RLock() mgr.mu.RLock()
infos, ok := mgr.infos[collectionID] infos, ok := mgr.infos[collectionID]
if ok && infos.stream != nil { if ok && infos.stream != nil {
@ -219,7 +219,7 @@ func (mgr *singleTypeChannelsMgr) createMsgStream(collectionID UniqueID) (msgstr
return nil, err return nil, err
} }
stream, err := createStream(mgr.msgStreamFactory, channelInfos.pchans, mgr.repackFunc) stream, err := createStream(ctx, mgr.msgStreamFactory, channelInfos.pchans, mgr.repackFunc)
if err != nil { if err != nil {
// What if stream created by other goroutines? // What if stream created by other goroutines?
log.Error("failed to create message stream", zap.Error(err), zap.Int64("collection", collectionID)) log.Error("failed to create message stream", zap.Error(err), zap.Int64("collection", collectionID))
@ -253,12 +253,12 @@ func (mgr *singleTypeChannelsMgr) lockGetStream(collectionID UniqueID) (msgstrea
// getOrCreateStream get message stream of specified collection. // getOrCreateStream get message stream of specified collection.
// If stream doesn't exist, call createMsgStream to create for it. // If stream doesn't exist, call createMsgStream to create for it.
func (mgr *singleTypeChannelsMgr) getOrCreateStream(collectionID UniqueID) (msgstream.MsgStream, error) { func (mgr *singleTypeChannelsMgr) getOrCreateStream(ctx context.Context, collectionID UniqueID) (msgstream.MsgStream, error) {
if stream, err := mgr.lockGetStream(collectionID); err == nil { if stream, err := mgr.lockGetStream(collectionID); err == nil {
return stream, nil return stream, nil
} }
return mgr.createMsgStream(collectionID) return mgr.createMsgStream(ctx, collectionID)
} }
// removeStream remove the corresponding stream of the specified collection. Idempotent. // removeStream remove the corresponding stream of the specified collection. Idempotent.
@ -315,8 +315,8 @@ func (mgr *channelsMgrImpl) getVChannels(collectionID UniqueID) ([]vChan, error)
return mgr.dmlChannelsMgr.getVChannels(collectionID) return mgr.dmlChannelsMgr.getVChannels(collectionID)
} }
func (mgr *channelsMgrImpl) getOrCreateDmlStream(collectionID UniqueID) (msgstream.MsgStream, error) { func (mgr *channelsMgrImpl) getOrCreateDmlStream(ctx context.Context, collectionID UniqueID) (msgstream.MsgStream, error) {
return mgr.dmlChannelsMgr.getOrCreateStream(collectionID) return mgr.dmlChannelsMgr.getOrCreateStream(ctx, collectionID)
} }
func (mgr *channelsMgrImpl) removeDMLStream(collectionID UniqueID) { func (mgr *channelsMgrImpl) removeDMLStream(collectionID UniqueID) {

View File

@ -214,7 +214,7 @@ func Test_createStream(t *testing.T) {
factory.fQStream = func(ctx context.Context) (msgstream.MsgStream, error) { factory.fQStream = func(ctx context.Context) (msgstream.MsgStream, error) {
return nil, errors.New("mock") return nil, errors.New("mock")
} }
_, err := createStream(factory, nil, nil) _, err := createStream(context.TODO(), factory, nil, nil)
assert.Error(t, err) assert.Error(t, err)
}) })
@ -223,7 +223,7 @@ func Test_createStream(t *testing.T) {
factory.f = func(ctx context.Context) (msgstream.MsgStream, error) { factory.f = func(ctx context.Context) (msgstream.MsgStream, error) {
return nil, errors.New("mock") return nil, errors.New("mock")
} }
_, err := createStream(factory, nil, nil) _, err := createStream(context.TODO(), factory, nil, nil)
assert.Error(t, err) assert.Error(t, err)
}) })
@ -232,7 +232,7 @@ func Test_createStream(t *testing.T) {
factory.f = func(ctx context.Context) (msgstream.MsgStream, error) { factory.f = func(ctx context.Context) (msgstream.MsgStream, error) {
return newMockMsgStream(), nil return newMockMsgStream(), nil
} }
_, err := createStream(factory, []string{"111"}, func(tsMsgs []msgstream.TsMsg, hashKeys [][]int32) (map[int32]*msgstream.MsgPack, error) { _, err := createStream(context.TODO(), factory, []string{"111"}, func(tsMsgs []msgstream.TsMsg, hashKeys [][]int32) (map[int32]*msgstream.MsgPack, error) {
return nil, nil return nil, nil
}) })
assert.NoError(t, err) assert.NoError(t, err)
@ -247,7 +247,7 @@ func Test_singleTypeChannelsMgr_createMsgStream(t *testing.T) {
100: {stream: newMockMsgStream()}, 100: {stream: newMockMsgStream()},
}, },
} }
stream, err := m.createMsgStream(100) stream, err := m.createMsgStream(context.TODO(), 100)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, stream) assert.NotNil(t, stream)
}) })
@ -275,7 +275,7 @@ func Test_singleTypeChannelsMgr_createMsgStream(t *testing.T) {
wg.Add(1) wg.Add(1)
go func() { go func() {
defer wg.Done() defer wg.Done()
stream, err := m.createMsgStream(100) stream, err := m.createMsgStream(context.TODO(), 100)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, stream) assert.NotNil(t, stream)
}() }()
@ -295,7 +295,7 @@ func Test_singleTypeChannelsMgr_createMsgStream(t *testing.T) {
return channelInfos{}, errors.New("mock") return channelInfos{}, errors.New("mock")
}, },
} }
_, err := m.createMsgStream(100) _, err := m.createMsgStream(context.TODO(), 100)
assert.Error(t, err) assert.Error(t, err)
}) })
@ -311,7 +311,7 @@ func Test_singleTypeChannelsMgr_createMsgStream(t *testing.T) {
msgStreamFactory: factory, msgStreamFactory: factory,
repackFunc: nil, repackFunc: nil,
} }
_, err := m.createMsgStream(100) _, err := m.createMsgStream(context.TODO(), 100)
assert.Error(t, err) assert.Error(t, err)
}) })
@ -328,10 +328,10 @@ func Test_singleTypeChannelsMgr_createMsgStream(t *testing.T) {
msgStreamFactory: factory, msgStreamFactory: factory,
repackFunc: nil, repackFunc: nil,
} }
stream, err := m.createMsgStream(100) stream, err := m.createMsgStream(context.TODO(), 100)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, stream) assert.NotNil(t, stream)
stream, err = m.getOrCreateStream(100) stream, err = m.getOrCreateStream(context.TODO(), 100)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, stream) assert.NotNil(t, stream)
}) })
@ -365,7 +365,7 @@ func Test_singleTypeChannelsMgr_getStream(t *testing.T) {
100: {stream: newMockMsgStream()}, 100: {stream: newMockMsgStream()},
}, },
} }
stream, err := m.getOrCreateStream(100) stream, err := m.getOrCreateStream(context.TODO(), 100)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, stream) assert.NotNil(t, stream)
}) })
@ -377,7 +377,7 @@ func Test_singleTypeChannelsMgr_getStream(t *testing.T) {
return channelInfos{}, errors.New("mock") return channelInfos{}, errors.New("mock")
}, },
} }
_, err := m.getOrCreateStream(100) _, err := m.getOrCreateStream(context.TODO(), 100)
assert.Error(t, err) assert.Error(t, err)
}) })
@ -394,7 +394,7 @@ func Test_singleTypeChannelsMgr_getStream(t *testing.T) {
msgStreamFactory: factory, msgStreamFactory: factory,
repackFunc: nil, repackFunc: nil,
} }
stream, err := m.getOrCreateStream(100) stream, err := m.getOrCreateStream(context.TODO(), 100)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, stream) assert.NotNil(t, stream)
}) })

View File

@ -6323,7 +6323,7 @@ func (node *Proxy) ReplicateMessage(ctx context.Context, req *milvuspb.Replicate
Status: merr.Status(err), Status: merr.Status(err),
}, nil }, nil
} }
messageIDsMap, err := msgStream.Broadcast(msgPack) messageIDsMap, err := msgStream.Broadcast(ctx, msgPack)
if err != nil { if err != nil {
log.Ctx(ctx).Warn("failed to produce msg", zap.Error(err)) log.Ctx(ctx).Warn("failed to produce msg", zap.Error(err))
return &milvuspb.ReplicateMessageResponse{Status: merr.Status(err)}, nil return &milvuspb.ReplicateMessageResponse{Status: merr.Status(err)}, nil

View File

@ -440,7 +440,7 @@ func TestProxy_FlushAll_DbCollection(t *testing.T) {
rpcRequestChannel := Params.CommonCfg.ReplicateMsgChannel.GetValue() rpcRequestChannel := Params.CommonCfg.ReplicateMsgChannel.GetValue()
node.replicateMsgStream, err = node.factory.NewMsgStream(node.ctx) node.replicateMsgStream, err = node.factory.NewMsgStream(node.ctx)
assert.NoError(t, err) assert.NoError(t, err)
node.replicateMsgStream.AsProducer([]string{rpcRequestChannel}) node.replicateMsgStream.AsProducer(ctx, []string{rpcRequestChannel})
Params.Save(Params.ProxyCfg.MaxTaskNum.Key, "1000") Params.Save(Params.ProxyCfg.MaxTaskNum.Key, "1000")
node.sched, err = newTaskScheduler(ctx, node.tsoAllocator, node.factory) node.sched, err = newTaskScheduler(ctx, node.tsoAllocator, node.factory)
@ -483,7 +483,7 @@ func TestProxy_FlushAll(t *testing.T) {
rpcRequestChannel := Params.CommonCfg.ReplicateMsgChannel.GetValue() rpcRequestChannel := Params.CommonCfg.ReplicateMsgChannel.GetValue()
node.replicateMsgStream, err = node.factory.NewMsgStream(node.ctx) node.replicateMsgStream, err = node.factory.NewMsgStream(node.ctx)
assert.NoError(t, err) assert.NoError(t, err)
node.replicateMsgStream.AsProducer([]string{rpcRequestChannel}) node.replicateMsgStream.AsProducer(ctx, []string{rpcRequestChannel})
Params.Save(Params.ProxyCfg.MaxTaskNum.Key, "1000") Params.Save(Params.ProxyCfg.MaxTaskNum.Key, "1000")
node.sched, err = newTaskScheduler(ctx, node.tsoAllocator, node.factory) node.sched, err = newTaskScheduler(ctx, node.tsoAllocator, node.factory)
@ -955,7 +955,7 @@ func TestProxyCreateDatabase(t *testing.T) {
rpcRequestChannel := Params.CommonCfg.ReplicateMsgChannel.GetValue() rpcRequestChannel := Params.CommonCfg.ReplicateMsgChannel.GetValue()
node.replicateMsgStream, err = node.factory.NewMsgStream(node.ctx) node.replicateMsgStream, err = node.factory.NewMsgStream(node.ctx)
assert.NoError(t, err) assert.NoError(t, err)
node.replicateMsgStream.AsProducer([]string{rpcRequestChannel}) node.replicateMsgStream.AsProducer(ctx, []string{rpcRequestChannel})
t.Run("create database fail", func(t *testing.T) { t.Run("create database fail", func(t *testing.T) {
rc := mocks.NewMockRootCoordClient(t) rc := mocks.NewMockRootCoordClient(t)
@ -1015,7 +1015,7 @@ func TestProxyDropDatabase(t *testing.T) {
rpcRequestChannel := Params.CommonCfg.ReplicateMsgChannel.GetValue() rpcRequestChannel := Params.CommonCfg.ReplicateMsgChannel.GetValue()
node.replicateMsgStream, err = node.factory.NewMsgStream(node.ctx) node.replicateMsgStream, err = node.factory.NewMsgStream(node.ctx)
assert.NoError(t, err) assert.NoError(t, err)
node.replicateMsgStream.AsProducer([]string{rpcRequestChannel}) node.replicateMsgStream.AsProducer(ctx, []string{rpcRequestChannel})
t.Run("drop database fail", func(t *testing.T) { t.Run("drop database fail", func(t *testing.T) {
rc := mocks.NewMockRootCoordClient(t) rc := mocks.NewMockRootCoordClient(t)
@ -1496,13 +1496,13 @@ func TestProxy_ReplicateMessage(t *testing.T) {
factory := newMockMsgStreamFactory() factory := newMockMsgStreamFactory()
msgStreamObj := msgstream.NewMockMsgStream(t) msgStreamObj := msgstream.NewMockMsgStream(t)
msgStreamObj.EXPECT().SetRepackFunc(mock.Anything).Return() msgStreamObj.EXPECT().SetRepackFunc(mock.Anything).Return()
msgStreamObj.EXPECT().AsProducer(mock.Anything).Return() msgStreamObj.EXPECT().AsProducer(mock.Anything, mock.Anything).Return()
msgStreamObj.EXPECT().EnableProduce(mock.Anything).Return() msgStreamObj.EXPECT().EnableProduce(mock.Anything).Return()
msgStreamObj.EXPECT().Close().Return() msgStreamObj.EXPECT().Close().Return()
mockMsgID1 := mqcommon.NewMockMessageID(t) mockMsgID1 := mqcommon.NewMockMessageID(t)
mockMsgID2 := mqcommon.NewMockMessageID(t) mockMsgID2 := mqcommon.NewMockMessageID(t)
mockMsgID2.EXPECT().Serialize().Return([]byte("mock message id 2")) mockMsgID2.EXPECT().Serialize().Return([]byte("mock message id 2"))
broadcastMock := msgStreamObj.EXPECT().Broadcast(mock.Anything).Return(map[string][]mqcommon.MessageID{ broadcastMock := msgStreamObj.EXPECT().Broadcast(mock.Anything, mock.Anything).Return(map[string][]mqcommon.MessageID{
"unit_test_replicate_message": {mockMsgID1, mockMsgID2}, "unit_test_replicate_message": {mockMsgID1, mockMsgID2},
}, nil) }, nil)
@ -1581,7 +1581,7 @@ func TestProxy_ReplicateMessage(t *testing.T) {
{ {
broadcastMock.Unset() broadcastMock.Unset()
broadcastMock = msgStreamObj.EXPECT().Broadcast(mock.Anything).Return(nil, errors.New("mock error: broadcast")) broadcastMock = msgStreamObj.EXPECT().Broadcast(mock.Anything, mock.Anything).Return(nil, errors.New("mock error: broadcast"))
resp, err := node.ReplicateMessage(context.TODO(), replicateRequest) resp, err := node.ReplicateMessage(context.TODO(), replicateRequest)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotEqualValues(t, 0, resp.GetStatus().GetCode()) assert.NotEqualValues(t, 0, resp.GetStatus().GetCode())
@ -1590,7 +1590,7 @@ func TestProxy_ReplicateMessage(t *testing.T) {
} }
{ {
broadcastMock.Unset() broadcastMock.Unset()
broadcastMock = msgStreamObj.EXPECT().Broadcast(mock.Anything).Return(map[string][]mqcommon.MessageID{ broadcastMock = msgStreamObj.EXPECT().Broadcast(mock.Anything, mock.Anything).Return(map[string][]mqcommon.MessageID{
"unit_test_replicate_message": {}, "unit_test_replicate_message": {},
}, nil) }, nil)
resp, err := node.ReplicateMessage(context.TODO(), replicateRequest) resp, err := node.ReplicateMessage(context.TODO(), replicateRequest)

View File

@ -3,6 +3,8 @@
package proxy package proxy
import ( import (
context "context"
msgstream "github.com/milvus-io/milvus/pkg/mq/msgstream" msgstream "github.com/milvus-io/milvus/pkg/mq/msgstream"
mock "github.com/stretchr/testify/mock" mock "github.com/stretchr/testify/mock"
) )
@ -78,9 +80,9 @@ func (_c *MockChannelsMgr_getChannels_Call) RunAndReturn(run func(int64) ([]stri
return _c return _c
} }
// getOrCreateDmlStream provides a mock function with given fields: collectionID // getOrCreateDmlStream provides a mock function with given fields: ctx, collectionID
func (_m *MockChannelsMgr) getOrCreateDmlStream(collectionID int64) (msgstream.MsgStream, error) { func (_m *MockChannelsMgr) getOrCreateDmlStream(ctx context.Context, collectionID int64) (msgstream.MsgStream, error) {
ret := _m.Called(collectionID) ret := _m.Called(ctx, collectionID)
if len(ret) == 0 { if len(ret) == 0 {
panic("no return value specified for getOrCreateDmlStream") panic("no return value specified for getOrCreateDmlStream")
@ -88,19 +90,19 @@ func (_m *MockChannelsMgr) getOrCreateDmlStream(collectionID int64) (msgstream.M
var r0 msgstream.MsgStream var r0 msgstream.MsgStream
var r1 error var r1 error
if rf, ok := ret.Get(0).(func(int64) (msgstream.MsgStream, error)); ok { if rf, ok := ret.Get(0).(func(context.Context, int64) (msgstream.MsgStream, error)); ok {
return rf(collectionID) return rf(ctx, collectionID)
} }
if rf, ok := ret.Get(0).(func(int64) msgstream.MsgStream); ok { if rf, ok := ret.Get(0).(func(context.Context, int64) msgstream.MsgStream); ok {
r0 = rf(collectionID) r0 = rf(ctx, collectionID)
} else { } else {
if ret.Get(0) != nil { if ret.Get(0) != nil {
r0 = ret.Get(0).(msgstream.MsgStream) r0 = ret.Get(0).(msgstream.MsgStream)
} }
} }
if rf, ok := ret.Get(1).(func(int64) error); ok { if rf, ok := ret.Get(1).(func(context.Context, int64) error); ok {
r1 = rf(collectionID) r1 = rf(ctx, collectionID)
} else { } else {
r1 = ret.Error(1) r1 = ret.Error(1)
} }
@ -114,14 +116,15 @@ type MockChannelsMgr_getOrCreateDmlStream_Call struct {
} }
// getOrCreateDmlStream is a helper method to define mock.On call // getOrCreateDmlStream is a helper method to define mock.On call
// - ctx context.Context
// - collectionID int64 // - collectionID int64
func (_e *MockChannelsMgr_Expecter) getOrCreateDmlStream(collectionID interface{}) *MockChannelsMgr_getOrCreateDmlStream_Call { func (_e *MockChannelsMgr_Expecter) getOrCreateDmlStream(ctx interface{}, collectionID interface{}) *MockChannelsMgr_getOrCreateDmlStream_Call {
return &MockChannelsMgr_getOrCreateDmlStream_Call{Call: _e.mock.On("getOrCreateDmlStream", collectionID)} return &MockChannelsMgr_getOrCreateDmlStream_Call{Call: _e.mock.On("getOrCreateDmlStream", ctx, collectionID)}
} }
func (_c *MockChannelsMgr_getOrCreateDmlStream_Call) Run(run func(collectionID int64)) *MockChannelsMgr_getOrCreateDmlStream_Call { func (_c *MockChannelsMgr_getOrCreateDmlStream_Call) Run(run func(ctx context.Context, collectionID int64)) *MockChannelsMgr_getOrCreateDmlStream_Call {
_c.Call.Run(func(args mock.Arguments) { _c.Call.Run(func(args mock.Arguments) {
run(args[0].(int64)) run(args[0].(context.Context), args[1].(int64))
}) })
return _c return _c
} }
@ -131,7 +134,7 @@ func (_c *MockChannelsMgr_getOrCreateDmlStream_Call) Return(_a0 msgstream.MsgStr
return _c return _c
} }
func (_c *MockChannelsMgr_getOrCreateDmlStream_Call) RunAndReturn(run func(int64) (msgstream.MsgStream, error)) *MockChannelsMgr_getOrCreateDmlStream_Call { func (_c *MockChannelsMgr_getOrCreateDmlStream_Call) RunAndReturn(run func(context.Context, int64) (msgstream.MsgStream, error)) *MockChannelsMgr_getOrCreateDmlStream_Call {
_c.Call.Return(run) _c.Call.Return(run)
return _c return _c
} }

View File

@ -16,7 +16,7 @@ type mockMsgStream struct {
enableProduce func(bool) enableProduce func(bool)
} }
func (m *mockMsgStream) AsProducer(producers []string) { func (m *mockMsgStream) AsProducer(ctx context.Context, producers []string) {
if m.asProducer != nil { if m.asProducer != nil {
m.asProducer(producers) m.asProducer(producers)
} }

View File

@ -255,7 +255,7 @@ func (ms *simpleMockMsgStream) Chan() <-chan *msgstream.MsgPack {
return ms.msgChan return ms.msgChan
} }
func (ms *simpleMockMsgStream) AsProducer(channels []string) { func (ms *simpleMockMsgStream) AsProducer(ctx context.Context, channels []string) {
} }
func (ms *simpleMockMsgStream) AsConsumer(ctx context.Context, channels []string, subName string, position common.SubscriptionInitialPosition) error { func (ms *simpleMockMsgStream) AsConsumer(ctx context.Context, channels []string, subName string, position common.SubscriptionInitialPosition) error {
@ -283,7 +283,7 @@ func (ms *simpleMockMsgStream) decreaseMsgCount(delta int) {
ms.increaseMsgCount(-delta) ms.increaseMsgCount(-delta)
} }
func (ms *simpleMockMsgStream) Produce(pack *msgstream.MsgPack) error { func (ms *simpleMockMsgStream) Produce(ctx context.Context, pack *msgstream.MsgPack) error {
defer ms.increaseMsgCount(1) defer ms.increaseMsgCount(1)
ms.msgChan <- pack ms.msgChan <- pack
@ -291,7 +291,7 @@ func (ms *simpleMockMsgStream) Produce(pack *msgstream.MsgPack) error {
return nil return nil
} }
func (ms *simpleMockMsgStream) Broadcast(pack *msgstream.MsgPack) (map[string][]msgstream.MessageID, error) { func (ms *simpleMockMsgStream) Broadcast(ctx context.Context, pack *msgstream.MsgPack) (map[string][]msgstream.MessageID, error) {
return map[string][]msgstream.MessageID{}, nil return map[string][]msgstream.MessageID{}, nil
} }

View File

@ -278,7 +278,7 @@ func (node *Proxy) Init() error {
return err return err
} }
node.replicateMsgStream.EnableProduce(true) node.replicateMsgStream.EnableProduce(true)
node.replicateMsgStream.AsProducer([]string{replicateMsgChannel}) node.replicateMsgStream.AsProducer(node.ctx, []string{replicateMsgChannel})
node.sched, err = newTaskScheduler(node.ctx, node.tsoAllocator, node.factory) node.sched, err = newTaskScheduler(node.ctx, node.tsoAllocator, node.factory)
if err != nil { if err != nil {

View File

@ -34,15 +34,15 @@ func NewReplicateStreamManager(ctx context.Context, factory msgstream.Factory, r
return manager return manager
} }
func (m *ReplicateStreamManager) newMsgStreamResource(channel string) resource.NewResourceFunc { func (m *ReplicateStreamManager) newMsgStreamResource(ctx context.Context, channel string) resource.NewResourceFunc {
return func() (resource.Resource, error) { return func() (resource.Resource, error) {
msgStream, err := m.factory.NewMsgStream(m.ctx) msgStream, err := m.factory.NewMsgStream(ctx)
if err != nil { if err != nil {
log.Ctx(m.ctx).Warn("failed to create msg stream", zap.String("channel", channel), zap.Error(err)) log.Ctx(m.ctx).Warn("failed to create msg stream", zap.String("channel", channel), zap.Error(err))
return nil, err return nil, err
} }
msgStream.SetRepackFunc(replicatePackFunc) msgStream.SetRepackFunc(replicatePackFunc)
msgStream.AsProducer([]string{channel}) msgStream.AsProducer(ctx, []string{channel})
msgStream.EnableProduce(true) msgStream.EnableProduce(true)
res := resource.NewSimpleResource(msgStream, ReplicateMsgStreamTyp, channel, ReplicateMsgStreamExpireTime, func() { res := resource.NewSimpleResource(msgStream, ReplicateMsgStreamTyp, channel, ReplicateMsgStreamExpireTime, func() {
@ -55,7 +55,7 @@ func (m *ReplicateStreamManager) newMsgStreamResource(channel string) resource.N
func (m *ReplicateStreamManager) GetReplicateMsgStream(ctx context.Context, channel string) (msgstream.MsgStream, error) { func (m *ReplicateStreamManager) GetReplicateMsgStream(ctx context.Context, channel string) (msgstream.MsgStream, error) {
ctxLog := log.Ctx(ctx).With(zap.String("proxy_channel", channel)) ctxLog := log.Ctx(ctx).With(zap.String("proxy_channel", channel))
res, err := m.resourceManager.Get(ReplicateMsgStreamTyp, channel, m.newMsgStreamResource(channel)) res, err := m.resourceManager.Get(ReplicateMsgStreamTyp, channel, m.newMsgStreamResource(ctx, channel))
if err != nil { if err != nil {
ctxLog.Warn("failed to get replicate msg stream", zap.String("channel", channel), zap.Error(err)) ctxLog.Warn("failed to get replicate msg stream", zap.String("channel", channel), zap.Error(err))
return nil, err return nil, err

View File

@ -142,7 +142,7 @@ func (dt *deleteTask) Execute(ctx context.Context) (err error) {
} }
dt.tr = timerecord.NewTimeRecorder(fmt.Sprintf("proxy execute delete %d", dt.ID())) dt.tr = timerecord.NewTimeRecorder(fmt.Sprintf("proxy execute delete %d", dt.ID()))
stream, err := dt.chMgr.getOrCreateDmlStream(dt.collectionID) stream, err := dt.chMgr.getOrCreateDmlStream(ctx, dt.collectionID)
if err != nil { if err != nil {
return err return err
} }
@ -178,7 +178,7 @@ func (dt *deleteTask) Execute(ctx context.Context) (err error) {
zap.Int64("taskID", dt.ID()), zap.Int64("taskID", dt.ID()),
zap.Duration("prepare duration", dt.tr.RecordSpan())) zap.Duration("prepare duration", dt.tr.RecordSpan()))
err = stream.Produce(msgPack) err = stream.Produce(ctx, msgPack)
if err != nil { if err != nil {
return err return err
} }

View File

@ -161,7 +161,7 @@ func TestDeleteTask_Execute(t *testing.T) {
}, },
} }
mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything).Return(nil, errors.New("mock error")) mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything, mock.Anything).Return(nil, errors.New("mock error"))
assert.Error(t, dt.Execute(context.Background())) assert.Error(t, dt.Execute(context.Background()))
}) })
@ -190,7 +190,7 @@ func TestDeleteTask_Execute(t *testing.T) {
primaryKeys: pk, primaryKeys: pk,
} }
stream := msgstream.NewMockMsgStream(t) stream := msgstream.NewMockMsgStream(t)
mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything).Return(stream, nil) mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything, mock.Anything).Return(stream, nil)
assert.Error(t, dt.Execute(context.Background())) assert.Error(t, dt.Execute(context.Background()))
}) })
@ -226,8 +226,8 @@ func TestDeleteTask_Execute(t *testing.T) {
primaryKeys: pk, primaryKeys: pk,
} }
stream := msgstream.NewMockMsgStream(t) stream := msgstream.NewMockMsgStream(t)
mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything).Return(stream, nil) mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything, mock.Anything).Return(stream, nil)
stream.EXPECT().Produce(mock.Anything).Return(errors.New("mock error")) stream.EXPECT().Produce(mock.Anything, mock.Anything).Return(errors.New("mock error"))
assert.Error(t, dt.Execute(context.Background())) assert.Error(t, dt.Execute(context.Background()))
}) })
} }
@ -535,9 +535,9 @@ func TestDeleteRunner_Run(t *testing.T) {
}, },
} }
stream := msgstream.NewMockMsgStream(t) stream := msgstream.NewMockMsgStream(t)
mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything).Return(stream, nil) mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything, mock.Anything).Return(stream, nil)
mockMgr.EXPECT().getChannels(collectionID).Return(channels, nil) mockMgr.EXPECT().getChannels(collectionID).Return(channels, nil)
stream.EXPECT().Produce(mock.Anything).Return(fmt.Errorf("mock error")) stream.EXPECT().Produce(mock.Anything, mock.Anything).Return(fmt.Errorf("mock error"))
assert.Error(t, dr.Run(context.Background())) assert.Error(t, dr.Run(context.Background()))
assert.Equal(t, int64(0), dr.result.DeleteCnt) assert.Equal(t, int64(0), dr.result.DeleteCnt)
@ -644,9 +644,9 @@ func TestDeleteRunner_Run(t *testing.T) {
}, },
} }
stream := msgstream.NewMockMsgStream(t) stream := msgstream.NewMockMsgStream(t)
mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything).Return(stream, nil) mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything, mock.Anything).Return(stream, nil)
mockMgr.EXPECT().getChannels(collectionID).Return(channels, nil) mockMgr.EXPECT().getChannels(collectionID).Return(channels, nil)
stream.EXPECT().Produce(mock.Anything).Return(nil) stream.EXPECT().Produce(mock.Anything, mock.Anything).Return(nil)
lb.EXPECT().Execute(mock.Anything, mock.Anything).Call.Return(func(ctx context.Context, workload CollectionWorkLoad) error { lb.EXPECT().Execute(mock.Anything, mock.Anything).Call.Return(func(ctx context.Context, workload CollectionWorkLoad) error {
return workload.exec(ctx, 1, qn, "") return workload.exec(ctx, 1, qn, "")
@ -768,7 +768,7 @@ func TestDeleteRunner_Run(t *testing.T) {
}, },
} }
stream := msgstream.NewMockMsgStream(t) stream := msgstream.NewMockMsgStream(t)
mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything).Return(stream, nil) mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything, mock.Anything).Return(stream, nil)
mockMgr.EXPECT().getChannels(collectionID).Return(channels, nil) mockMgr.EXPECT().getChannels(collectionID).Return(channels, nil)
lb.EXPECT().Execute(mock.Anything, mock.Anything).Call.Return(func(ctx context.Context, workload CollectionWorkLoad) error { lb.EXPECT().Execute(mock.Anything, mock.Anything).Call.Return(func(ctx context.Context, workload CollectionWorkLoad) error {
return workload.exec(ctx, 1, qn, "") return workload.exec(ctx, 1, qn, "")
@ -792,7 +792,7 @@ func TestDeleteRunner_Run(t *testing.T) {
server.FinishSend(nil) server.FinishSend(nil)
return client return client
}, nil) }, nil)
stream.EXPECT().Produce(mock.Anything).Return(errors.New("mock error")) stream.EXPECT().Produce(mock.Anything, mock.Anything).Return(errors.New("mock error"))
assert.Error(t, dr.Run(ctx)) assert.Error(t, dr.Run(ctx))
assert.Equal(t, int64(0), dr.result.DeleteCnt) assert.Equal(t, int64(0), dr.result.DeleteCnt)
@ -830,7 +830,7 @@ func TestDeleteRunner_Run(t *testing.T) {
}, },
} }
stream := msgstream.NewMockMsgStream(t) stream := msgstream.NewMockMsgStream(t)
mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything).Return(stream, nil) mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything, mock.Anything).Return(stream, nil)
mockMgr.EXPECT().getChannels(collectionID).Return(channels, nil) mockMgr.EXPECT().getChannels(collectionID).Return(channels, nil)
lb.EXPECT().Execute(mock.Anything, mock.Anything).Call.Return(func(ctx context.Context, workload CollectionWorkLoad) error { lb.EXPECT().Execute(mock.Anything, mock.Anything).Call.Return(func(ctx context.Context, workload CollectionWorkLoad) error {
return workload.exec(ctx, 1, qn, "") return workload.exec(ctx, 1, qn, "")
@ -854,7 +854,7 @@ func TestDeleteRunner_Run(t *testing.T) {
server.FinishSend(nil) server.FinishSend(nil)
return client return client
}, nil) }, nil)
stream.EXPECT().Produce(mock.Anything).Return(nil) stream.EXPECT().Produce(mock.Anything, mock.Anything).Return(nil)
assert.NoError(t, dr.Run(ctx)) assert.NoError(t, dr.Run(ctx))
assert.Equal(t, int64(3), dr.result.DeleteCnt) assert.Equal(t, int64(3), dr.result.DeleteCnt)
@ -911,7 +911,7 @@ func TestDeleteRunner_Run(t *testing.T) {
}, },
} }
stream := msgstream.NewMockMsgStream(t) stream := msgstream.NewMockMsgStream(t)
mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything).Return(stream, nil) mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything, mock.Anything).Return(stream, nil)
mockMgr.EXPECT().getChannels(collectionID).Return(channels, nil) mockMgr.EXPECT().getChannels(collectionID).Return(channels, nil)
lb.EXPECT().Execute(mock.Anything, mock.Anything).Call.Return(func(ctx context.Context, workload CollectionWorkLoad) error { lb.EXPECT().Execute(mock.Anything, mock.Anything).Call.Return(func(ctx context.Context, workload CollectionWorkLoad) error {
return workload.exec(ctx, 1, qn, "") return workload.exec(ctx, 1, qn, "")
@ -936,7 +936,7 @@ func TestDeleteRunner_Run(t *testing.T) {
return client return client
}, nil) }, nil)
stream.EXPECT().Produce(mock.Anything).Return(nil) stream.EXPECT().Produce(mock.Anything, mock.Anything).Return(nil)
assert.NoError(t, dr.Run(ctx)) assert.NoError(t, dr.Run(ctx))
assert.Equal(t, int64(3), dr.result.DeleteCnt) assert.Equal(t, int64(3), dr.result.DeleteCnt)
}) })

View File

@ -243,7 +243,7 @@ func (it *insertTask) Execute(ctx context.Context) error {
it.insertMsg.CollectionID = collID it.insertMsg.CollectionID = collID
getCacheDur := tr.RecordSpan() getCacheDur := tr.RecordSpan()
stream, err := it.chMgr.getOrCreateDmlStream(collID) stream, err := it.chMgr.getOrCreateDmlStream(ctx, collID)
if err != nil { if err != nil {
return err return err
} }
@ -280,7 +280,7 @@ func (it *insertTask) Execute(ctx context.Context) error {
log.Debug("assign segmentID for insert data success", log.Debug("assign segmentID for insert data success",
zap.Duration("assign segmentID duration", assignSegmentIDDur)) zap.Duration("assign segmentID duration", assignSegmentIDDur))
err = stream.Produce(msgPack) err = stream.Produce(ctx, msgPack)
if err != nil { if err != nil {
log.Warn("fail to produce insert msg", zap.Error(err)) log.Warn("fail to produce insert msg", zap.Error(err))
it.result.Status = merr.Status(err) it.result.Status = merr.Status(err)

View File

@ -1755,7 +1755,7 @@ func TestTask_Int64PrimaryKey(t *testing.T) {
chMgr := newChannelsMgrImpl(dmlChannelsFunc, nil, factory) chMgr := newChannelsMgrImpl(dmlChannelsFunc, nil, factory)
defer chMgr.removeAllDMLStream() defer chMgr.removeAllDMLStream()
_, err = chMgr.getOrCreateDmlStream(collectionID) _, err = chMgr.getOrCreateDmlStream(ctx, collectionID)
assert.NoError(t, err) assert.NoError(t, err)
pchans, err := chMgr.getChannels(collectionID) pchans, err := chMgr.getChannels(collectionID)
assert.NoError(t, err) assert.NoError(t, err)
@ -2004,7 +2004,7 @@ func TestTask_VarCharPrimaryKey(t *testing.T) {
chMgr := newChannelsMgrImpl(dmlChannelsFunc, nil, factory) chMgr := newChannelsMgrImpl(dmlChannelsFunc, nil, factory)
defer chMgr.removeAllDMLStream() defer chMgr.removeAllDMLStream()
_, err = chMgr.getOrCreateDmlStream(collectionID) _, err = chMgr.getOrCreateDmlStream(ctx, collectionID)
assert.NoError(t, err) assert.NoError(t, err)
pchans, err := chMgr.getChannels(collectionID) pchans, err := chMgr.getChannels(collectionID)
assert.NoError(t, err) assert.NoError(t, err)
@ -3460,7 +3460,7 @@ func TestPartitionKey(t *testing.T) {
chMgr := newChannelsMgrImpl(dmlChannelsFunc, nil, factory) chMgr := newChannelsMgrImpl(dmlChannelsFunc, nil, factory)
defer chMgr.removeAllDMLStream() defer chMgr.removeAllDMLStream()
_, err = chMgr.getOrCreateDmlStream(collectionID) _, err = chMgr.getOrCreateDmlStream(ctx, collectionID)
assert.NoError(t, err) assert.NoError(t, err)
pchans, err := chMgr.getChannels(collectionID) pchans, err := chMgr.getChannels(collectionID)
assert.NoError(t, err) assert.NoError(t, err)

View File

@ -393,7 +393,7 @@ func (it *upsertTask) insertExecute(ctx context.Context, msgPack *msgstream.MsgP
zap.Int64("collectionID", collID)) zap.Int64("collectionID", collID))
getCacheDur := tr.RecordSpan() getCacheDur := tr.RecordSpan()
_, err = it.chMgr.getOrCreateDmlStream(collID) _, err = it.chMgr.getOrCreateDmlStream(ctx, collID)
if err != nil { if err != nil {
return err return err
} }
@ -526,7 +526,7 @@ func (it *upsertTask) Execute(ctx context.Context) (err error) {
log := log.Ctx(ctx).With(zap.String("collectionName", it.req.CollectionName)) log := log.Ctx(ctx).With(zap.String("collectionName", it.req.CollectionName))
tr := timerecord.NewTimeRecorder(fmt.Sprintf("proxy execute upsert %d", it.ID())) tr := timerecord.NewTimeRecorder(fmt.Sprintf("proxy execute upsert %d", it.ID()))
stream, err := it.chMgr.getOrCreateDmlStream(it.collectionID) stream, err := it.chMgr.getOrCreateDmlStream(ctx, it.collectionID)
if err != nil { if err != nil {
return err return err
} }
@ -547,7 +547,7 @@ func (it *upsertTask) Execute(ctx context.Context) (err error) {
} }
tr.RecordSpan() tr.RecordSpan()
err = stream.Produce(msgPack) err = stream.Produce(ctx, msgPack)
if err != nil { if err != nil {
it.result.Status = merr.Status(err) it.result.Status = merr.Status(err)
return err return err

View File

@ -1985,7 +1985,7 @@ func SendReplicateMessagePack(ctx context.Context, replicateMsgStream msgstream.
EndTs: ts, EndTs: ts,
Msgs: []msgstream.TsMsg{tsMsg}, Msgs: []msgstream.TsMsg{tsMsg},
} }
msgErr := replicateMsgStream.Produce(msgPack) msgErr := replicateMsgStream.Produce(ctx, msgPack)
// ignore the error if the msg stream failed to produce the msg, // ignore the error if the msg stream failed to produce the msg,
// because it can be manually fixed in this error // because it can be manually fixed in this error
if msgErr != nil { if msgErr != nil {

View File

@ -2430,7 +2430,7 @@ func TestSendReplicateMessagePack(t *testing.T) {
}) })
t.Run("produce fail", func(t *testing.T) { t.Run("produce fail", func(t *testing.T) {
mockStream.EXPECT().Produce(mock.Anything).Return(errors.New("produce error")).Once() mockStream.EXPECT().Produce(mock.Anything, mock.Anything).Return(errors.New("produce error")).Once()
SendReplicateMessagePack(ctx, mockStream, &milvuspb.CreateDatabaseRequest{ SendReplicateMessagePack(ctx, mockStream, &milvuspb.CreateDatabaseRequest{
Base: &commonpb.MsgBase{ReplicateInfo: &commonpb.ReplicateInfo{ Base: &commonpb.MsgBase{ReplicateInfo: &commonpb.ReplicateInfo{
IsReplicate: true, IsReplicate: true,
@ -2444,7 +2444,7 @@ func TestSendReplicateMessagePack(t *testing.T) {
}) })
t.Run("normal case", func(t *testing.T) { t.Run("normal case", func(t *testing.T) {
mockStream.EXPECT().Produce(mock.Anything).Return(nil) mockStream.EXPECT().Produce(mock.Anything, mock.Anything).Return(nil)
SendReplicateMessagePack(ctx, mockStream, &milvuspb.CreateDatabaseRequest{}) SendReplicateMessagePack(ctx, mockStream, &milvuspb.CreateDatabaseRequest{})
SendReplicateMessagePack(ctx, mockStream, &milvuspb.DropDatabaseRequest{}) SendReplicateMessagePack(ctx, mockStream, &milvuspb.DropDatabaseRequest{})

View File

@ -188,7 +188,7 @@ func newDmlChannels(ctx context.Context, factory msgstream.Factory, chanNamePref
d.checkPreCreatedTopic(ctx, factory, name) d.checkPreCreatedTopic(ctx, factory, name)
} }
ms.AsProducer([]string{name}) ms.AsProducer(ctx, []string{name})
dms := &dmlMsgStream{ dms := &dmlMsgStream{
ms: ms, ms: ms,
refcnt: 0, refcnt: 0,
@ -291,7 +291,7 @@ func (d *dmlChannels) broadcast(chanNames []string, pack *msgstream.MsgPack) err
dms.mutex.RLock() dms.mutex.RLock()
if dms.refcnt > 0 { if dms.refcnt > 0 {
if _, err := dms.ms.Broadcast(pack); err != nil { if _, err := dms.ms.Broadcast(d.ctx, pack); err != nil {
log.Error("Broadcast failed", zap.Error(err), zap.String("chanName", chanName)) log.Error("Broadcast failed", zap.Error(err), zap.String("chanName", chanName))
dms.mutex.RUnlock() dms.mutex.RUnlock()
return err return err
@ -312,7 +312,7 @@ func (d *dmlChannels) broadcastMark(chanNames []string, pack *msgstream.MsgPack)
dms.mutex.RLock() dms.mutex.RLock()
if dms.refcnt > 0 { if dms.refcnt > 0 {
ids, err := dms.ms.Broadcast(pack) ids, err := dms.ms.Broadcast(d.ctx, pack)
if err != nil { if err != nil {
log.Error("BroadcastMark failed", zap.Error(err), zap.String("chanName", chanName)) log.Error("BroadcastMark failed", zap.Error(err), zap.String("chanName", chanName))
dms.mutex.RUnlock() dms.mutex.RUnlock()

View File

@ -277,17 +277,17 @@ type FailMsgStream struct {
errBroadcast bool errBroadcast bool
} }
func (ms *FailMsgStream) Close() {} func (ms *FailMsgStream) Close() {}
func (ms *FailMsgStream) Chan() <-chan *msgstream.MsgPack { return nil } func (ms *FailMsgStream) Chan() <-chan *msgstream.MsgPack { return nil }
func (ms *FailMsgStream) AsProducer(channels []string) {} func (ms *FailMsgStream) AsProducer(ctx context.Context, channels []string) {}
func (ms *FailMsgStream) AsReader(channels []string, subName string) {} func (ms *FailMsgStream) AsReader(channels []string, subName string) {}
func (ms *FailMsgStream) AsConsumer(ctx context.Context, channels []string, subName string, position common.SubscriptionInitialPosition) error { func (ms *FailMsgStream) AsConsumer(ctx context.Context, channels []string, subName string, position common.SubscriptionInitialPosition) error {
return nil return nil
} }
func (ms *FailMsgStream) SetRepackFunc(repackFunc msgstream.RepackFunc) {} func (ms *FailMsgStream) SetRepackFunc(repackFunc msgstream.RepackFunc) {}
func (ms *FailMsgStream) GetProduceChannels() []string { return nil } func (ms *FailMsgStream) GetProduceChannels() []string { return nil }
func (ms *FailMsgStream) Produce(*msgstream.MsgPack) error { return nil } func (ms *FailMsgStream) Produce(context.Context, *msgstream.MsgPack) error { return nil }
func (ms *FailMsgStream) Broadcast(*msgstream.MsgPack) (map[string][]msgstream.MessageID, error) { func (ms *FailMsgStream) Broadcast(context.Context, *msgstream.MsgPack) (map[string][]msgstream.MessageID, error) {
if ms.errBroadcast { if ms.errBroadcast {
return nil, errors.New("broadcast error") return nil, errors.New("broadcast error")
} }

View File

@ -42,8 +42,8 @@ func TestInputNode(t *testing.T) {
msgPack := generateMsgPack() msgPack := generateMsgPack()
produceStream, _ := factory.NewMsgStream(context.TODO()) produceStream, _ := factory.NewMsgStream(context.TODO())
produceStream.AsProducer(channels) produceStream.AsProducer(context.TODO(), channels)
produceStream.Produce(&msgPack) produceStream.Produce(context.TODO(), &msgPack)
nodeName := "input_node" nodeName := "input_node"
inputNode := NewInputNode(msgStream.Chan(), nodeName, 100, 100, "", 0, 0, "") inputNode := NewInputNode(msgStream.Chan(), nodeName, 100, 100, "", 0, 0, "")
@ -84,7 +84,7 @@ func Test_InputNodeSkipMode(t *testing.T) {
msgStream.AsConsumer(context.Background(), channels, "sub", common.SubscriptionPositionEarliest) msgStream.AsConsumer(context.Background(), channels, "sub", common.SubscriptionPositionEarliest)
produceStream, _ := factory.NewMsgStream(context.TODO()) produceStream, _ := factory.NewMsgStream(context.TODO())
produceStream.AsProducer(channels) produceStream.AsProducer(context.TODO(), channels)
closeCh := make(chan struct{}) closeCh := make(chan struct{})
outputCh := make(chan bool) outputCh := make(chan bool)
@ -110,7 +110,7 @@ func Test_InputNodeSkipMode(t *testing.T) {
defer close(closeCh) defer close(closeCh)
msgPack := generateMsgPack() msgPack := generateMsgPack()
produceStream.Produce(&msgPack) produceStream.Produce(context.TODO(), &msgPack)
log.Info("produce empty ttmsg") log.Info("produce empty ttmsg")
<-outputCh <-outputCh
assert.Equal(t, 1, outputCount) assert.Equal(t, 1, outputCount)
@ -118,7 +118,7 @@ func Test_InputNodeSkipMode(t *testing.T) {
time.Sleep(3 * time.Second) time.Sleep(3 * time.Second)
assert.Equal(t, false, inputNode.skipMode) assert.Equal(t, false, inputNode.skipMode)
produceStream.Produce(&msgPack) produceStream.Produce(context.TODO(), &msgPack)
log.Info("after 3 seconds with no active msg receive, input node will turn on skip mode") log.Info("after 3 seconds with no active msg receive, input node will turn on skip mode")
<-outputCh <-outputCh
assert.Equal(t, 2, outputCount) assert.Equal(t, 2, outputCount)
@ -126,13 +126,13 @@ func Test_InputNodeSkipMode(t *testing.T) {
log.Info("some ttmsg will be skipped in skip mode") log.Info("some ttmsg will be skipped in skip mode")
// this msg will be skipped // this msg will be skipped
produceStream.Produce(&msgPack) produceStream.Produce(context.TODO(), &msgPack)
<-outputCh <-outputCh
assert.Equal(t, 2, outputCount) assert.Equal(t, 2, outputCount)
assert.Equal(t, true, inputNode.skipMode) assert.Equal(t, true, inputNode.skipMode)
// this msg will be consumed // this msg will be consumed
produceStream.Produce(&msgPack) produceStream.Produce(context.TODO(), &msgPack)
<-outputCh <-outputCh
assert.Equal(t, 3, outputCount) assert.Equal(t, 3, outputCount)
assert.Equal(t, true, inputNode.skipMode) assert.Equal(t, true, inputNode.skipMode)

View File

@ -80,13 +80,13 @@ func TestNodeManager_Start(t *testing.T) {
msgStream.AsConsumer(context.TODO(), channels, "sub", common.SubscriptionPositionEarliest) msgStream.AsConsumer(context.TODO(), channels, "sub", common.SubscriptionPositionEarliest)
produceStream, _ := factory.NewMsgStream(context.TODO()) produceStream, _ := factory.NewMsgStream(context.TODO())
produceStream.AsProducer(channels) produceStream.AsProducer(context.TODO(), channels)
msgPack := generateMsgPack() msgPack := generateMsgPack()
produceStream.Produce(&msgPack) produceStream.Produce(context.TODO(), &msgPack)
time.Sleep(time.Millisecond * 2) time.Sleep(time.Millisecond * 2)
msgPack = generateMsgPack() msgPack = generateMsgPack()
produceStream.Produce(&msgPack) produceStream.Produce(context.TODO(), &msgPack)
nodeName := "input_node" nodeName := "input_node"
inputNode := NewInputNode(msgStream.Chan(), nodeName, 100, 100, "", 0, 0, "") inputNode := NewInputNode(msgStream.Chan(), nodeName, 100, 100, "", 0, 0, "")

View File

@ -226,7 +226,7 @@ func (suite *SimulationSuite) produceMsg(wg *sync.WaitGroup, collectionID int64)
insNum := rand.Intn(10) insNum := rand.Intn(10)
for j := 0; j < insNum; j++ { for j := 0; j < insNum; j++ {
vchannel := vchannelKeys[rand.Intn(len(vchannelKeys))].Interface().(string) vchannel := vchannelKeys[rand.Intn(len(vchannelKeys))].Interface().(string)
err := suite.producer.Produce(&msgstream.MsgPack{ err := suite.producer.Produce(context.TODO(), &msgstream.MsgPack{
Msgs: []msgstream.TsMsg{genInsertMsg(rand.Intn(20)+1, vchannel, uniqueMsgID)}, Msgs: []msgstream.TsMsg{genInsertMsg(rand.Intn(20)+1, vchannel, uniqueMsgID)},
}) })
assert.NoError(suite.T(), err) assert.NoError(suite.T(), err)
@ -237,7 +237,7 @@ func (suite *SimulationSuite) produceMsg(wg *sync.WaitGroup, collectionID int64)
delNum := rand.Intn(2) delNum := rand.Intn(2)
for j := 0; j < delNum; j++ { for j := 0; j < delNum; j++ {
vchannel := vchannelKeys[rand.Intn(len(vchannelKeys))].Interface().(string) vchannel := vchannelKeys[rand.Intn(len(vchannelKeys))].Interface().(string)
err := suite.producer.Produce(&msgstream.MsgPack{ err := suite.producer.Produce(context.TODO(), &msgstream.MsgPack{
Msgs: []msgstream.TsMsg{genDeleteMsg(rand.Intn(20)+1, vchannel, uniqueMsgID)}, Msgs: []msgstream.TsMsg{genDeleteMsg(rand.Intn(20)+1, vchannel, uniqueMsgID)},
}) })
assert.NoError(suite.T(), err) assert.NoError(suite.T(), err)
@ -247,7 +247,7 @@ func (suite *SimulationSuite) produceMsg(wg *sync.WaitGroup, collectionID int64)
// produce random ddl // produce random ddl
ddlNum := rand.Intn(2) ddlNum := rand.Intn(2)
for j := 0; j < ddlNum; j++ { for j := 0; j < ddlNum; j++ {
err := suite.producer.Produce(&msgstream.MsgPack{ err := suite.producer.Produce(context.TODO(), &msgstream.MsgPack{
Msgs: []msgstream.TsMsg{genDDLMsg(commonpb.MsgType_DropCollection, collectionID)}, Msgs: []msgstream.TsMsg{genDDLMsg(commonpb.MsgType_DropCollection, collectionID)},
}) })
assert.NoError(suite.T(), err) assert.NoError(suite.T(), err)
@ -257,7 +257,7 @@ func (suite *SimulationSuite) produceMsg(wg *sync.WaitGroup, collectionID int64)
} }
// produce time tick // produce time tick
ts := uint64(i * 100) ts := uint64(i * 100)
err := suite.producer.Produce(&msgstream.MsgPack{ err := suite.producer.Produce(context.TODO(), &msgstream.MsgPack{
Msgs: []msgstream.TsMsg{genTimeTickMsg(ts)}, Msgs: []msgstream.TsMsg{genTimeTickMsg(ts)},
}) })
assert.NoError(suite.T(), err) assert.NoError(suite.T(), err)
@ -305,7 +305,7 @@ func (suite *SimulationSuite) produceTimeTickOnly(ctx context.Context) {
return return
case <-ticker.C: case <-ticker.C:
ts := uint64(tt * 1000) ts := uint64(tt * 1000)
err := suite.producer.Produce(&msgstream.MsgPack{ err := suite.producer.Produce(ctx, &msgstream.MsgPack{
Msgs: []msgstream.TsMsg{genTimeTickMsg(ts)}, Msgs: []msgstream.TsMsg{genTimeTickMsg(ts)},
}) })
assert.NoError(suite.T(), err) assert.NoError(suite.T(), err)

View File

@ -55,7 +55,7 @@ func newMockProducer(factory msgstream.Factory, pchannel string) (msgstream.MsgS
if err != nil { if err != nil {
return nil, err return nil, err
} }
stream.AsProducer([]string{pchannel}) stream.AsProducer(context.TODO(), []string{pchannel})
stream.SetRepackFunc(defaultInsertRepackFunc) stream.SetRepackFunc(defaultInsertRepackFunc)
return stream, nil return stream, nil
} }

View File

@ -173,11 +173,11 @@ func testTimeTickerAndInsert(t *testing.T, f []Factory) {
defer consumer.Close() defer consumer.Close()
var err error var err error
_, err = producer.Broadcast(&msgPack0) _, err = producer.Broadcast(ctx, &msgPack0)
assert.NoError(t, err) assert.NoError(t, err)
err = producer.Produce(&msgPack1) err = producer.Produce(ctx, &msgPack1)
assert.NoError(t, err) assert.NoError(t, err)
_, err = producer.Broadcast(&msgPack2) _, err = producer.Broadcast(ctx, &msgPack2)
assert.NoError(t, err) assert.NoError(t, err)
receiveAndValidateMsg(ctx, consumer, len(msgPack1.Msgs)) receiveAndValidateMsg(ctx, consumer, len(msgPack1.Msgs))
@ -210,17 +210,17 @@ func testTimeTickerNoSeek(t *testing.T, f []Factory) {
defer producer.Close() defer producer.Close()
var err error var err error
_, err = producer.Broadcast(&msgPack0) _, err = producer.Broadcast(ctx, &msgPack0)
assert.NoError(t, err) assert.NoError(t, err)
err = producer.Produce(&msgPack1) err = producer.Produce(ctx, &msgPack1)
assert.NoError(t, err) assert.NoError(t, err)
_, err = producer.Broadcast(&msgPack2) _, err = producer.Broadcast(ctx, &msgPack2)
assert.NoError(t, err) assert.NoError(t, err)
err = producer.Produce(&msgPack3) err = producer.Produce(ctx, &msgPack3)
assert.NoError(t, err) assert.NoError(t, err)
_, err = producer.Broadcast(&msgPack4) _, err = producer.Broadcast(ctx, &msgPack4)
assert.NoError(t, err) assert.NoError(t, err)
_, err = producer.Broadcast(&msgPack5) _, err = producer.Broadcast(ctx, &msgPack5)
assert.NoError(t, err) assert.NoError(t, err)
o1 := consume(ctx, consumer) o1 := consume(ctx, consumer)
@ -259,7 +259,7 @@ func testSeekToLast(t *testing.T, f []Factory) {
} }
// produce test data // produce test data
err := producer.Produce(msgPack) err := producer.Produce(ctx, msgPack)
assert.NoError(t, err) assert.NoError(t, err)
// pick a seekPosition // pick a seekPosition
@ -346,21 +346,21 @@ func testTimeTickerSeek(t *testing.T, f []Factory) {
defer producer.Close() defer producer.Close()
// Send message // Send message
_, err := producer.Broadcast(&msgPack0) _, err := producer.Broadcast(ctx, &msgPack0)
assert.NoError(t, err) assert.NoError(t, err)
err = producer.Produce(&msgPack1) err = producer.Produce(ctx, &msgPack1)
assert.NoError(t, err) assert.NoError(t, err)
_, err = producer.Broadcast(&msgPack2) _, err = producer.Broadcast(ctx, &msgPack2)
assert.NoError(t, err) assert.NoError(t, err)
err = producer.Produce(&msgPack3) err = producer.Produce(ctx, &msgPack3)
assert.NoError(t, err) assert.NoError(t, err)
_, err = producer.Broadcast(&msgPack4) _, err = producer.Broadcast(ctx, &msgPack4)
assert.NoError(t, err) assert.NoError(t, err)
err = producer.Produce(&msgPack5) err = producer.Produce(ctx, &msgPack5)
assert.NoError(t, err) assert.NoError(t, err)
_, err = producer.Broadcast(&msgPack6) _, err = producer.Broadcast(ctx, &msgPack6)
assert.NoError(t, err) assert.NoError(t, err)
_, err = producer.Broadcast(&msgPack7) _, err = producer.Broadcast(ctx, &msgPack7)
assert.NoError(t, err) assert.NoError(t, err)
// Test received message // Test received message
@ -434,13 +434,13 @@ func testTimeTickUnmarshalHeader(t *testing.T, f []Factory) {
defer producer.Close() defer producer.Close()
defer consumer.Close() defer consumer.Close()
_, err := producer.Broadcast(&msgPack0) _, err := producer.Broadcast(ctx, &msgPack0)
require.NoErrorf(t, err, fmt.Sprintf("broadcast error = %v", err)) require.NoErrorf(t, err, fmt.Sprintf("broadcast error = %v", err))
err = producer.Produce(&msgPack1) err = producer.Produce(ctx, &msgPack1)
require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err)) require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err))
_, err = producer.Broadcast(&msgPack2) _, err = producer.Broadcast(ctx, &msgPack2)
require.NoErrorf(t, err, fmt.Sprintf("broadcast error = %v", err)) require.NoErrorf(t, err, fmt.Sprintf("broadcast error = %v", err))
receiveAndValidateMsg(ctx, consumer, len(msgPack1.Msgs)) receiveAndValidateMsg(ctx, consumer, len(msgPack1.Msgs))
@ -571,7 +571,7 @@ func testMqMsgStreamSeek(t *testing.T, f []Factory) {
msgPack.Msgs = append(msgPack.Msgs, insertMsg) msgPack.Msgs = append(msgPack.Msgs, insertMsg)
} }
err := producer.Produce(msgPack) err := producer.Produce(ctx, msgPack)
assert.NoError(t, err) assert.NoError(t, err)
var seekPosition *msgpb.MsgPosition var seekPosition *msgpb.MsgPosition
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
@ -605,7 +605,7 @@ func testMqMsgStreamSeekInvalidMessage(t *testing.T, f []Factory, pg positionGen
msgPack.Msgs = append(msgPack.Msgs, insertMsg) msgPack.Msgs = append(msgPack.Msgs, insertMsg)
} }
err := producer.Produce(msgPack) err := producer.Produce(ctx, msgPack)
assert.NoError(t, err) assert.NoError(t, err)
var seekPosition *msgpb.MsgPosition var seekPosition *msgpb.MsgPosition
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
@ -622,7 +622,7 @@ func testMqMsgStreamSeekInvalidMessage(t *testing.T, f []Factory, pg positionGen
insertMsg := getTsMsg(commonpb.MsgType_Insert, int64(i)) insertMsg := getTsMsg(commonpb.MsgType_Insert, int64(i))
msgPack.Msgs = append(msgPack.Msgs, insertMsg) msgPack.Msgs = append(msgPack.Msgs, insertMsg)
} }
err = producer.Produce(msgPack) err = producer.Produce(ctx, msgPack)
assert.NoError(t, err) assert.NoError(t, err)
result := consume(ctx, consumer2) result := consume(ctx, consumer2)
assert.Equal(t, result.Msgs[0].ID(), int64(1)) assert.Equal(t, result.Msgs[0].ID(), int64(1))
@ -642,7 +642,7 @@ func testMqMsgStreamSeekLatest(t *testing.T, f []Factory) {
msgPack.Msgs = append(msgPack.Msgs, insertMsg) msgPack.Msgs = append(msgPack.Msgs, insertMsg)
} }
err := producer.Produce(msgPack) err := producer.Produce(ctx, msgPack)
assert.NoError(t, err) assert.NoError(t, err)
consumer2 := createLatestConsumer(ctx, t, f[1].NewMsgStream, channels) consumer2 := createLatestConsumer(ctx, t, f[1].NewMsgStream, channels)
defer consumer2.Close() defer consumer2.Close()
@ -653,7 +653,7 @@ func testMqMsgStreamSeekLatest(t *testing.T, f []Factory) {
insertMsg := getTsMsg(commonpb.MsgType_Insert, int64(i)) insertMsg := getTsMsg(commonpb.MsgType_Insert, int64(i))
msgPack.Msgs = append(msgPack.Msgs, insertMsg) msgPack.Msgs = append(msgPack.Msgs, insertMsg)
} }
err = producer.Produce(msgPack) err = producer.Produce(ctx, msgPack)
assert.NoError(t, err) assert.NoError(t, err)
for i := 10; i < 20; i++ { for i := 10; i < 20; i++ {
@ -673,7 +673,7 @@ func testBroadcastMark(t *testing.T, f []Factory) {
msgPack0 := MsgPack{} msgPack0 := MsgPack{}
msgPack0.Msgs = append(msgPack0.Msgs, getTimeTickMsg(0)) msgPack0.Msgs = append(msgPack0.Msgs, getTimeTickMsg(0))
ids, err := producer.Broadcast(&msgPack0) ids, err := producer.Broadcast(ctx, &msgPack0)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, ids) assert.NotNil(t, ids)
assert.Equal(t, len(channels), len(ids)) assert.Equal(t, len(channels), len(ids))
@ -687,7 +687,7 @@ func testBroadcastMark(t *testing.T, f []Factory) {
msgPack1.Msgs = append(msgPack1.Msgs, getTsMsg(commonpb.MsgType_Insert, 1)) msgPack1.Msgs = append(msgPack1.Msgs, getTsMsg(commonpb.MsgType_Insert, 1))
msgPack1.Msgs = append(msgPack1.Msgs, getTsMsg(commonpb.MsgType_Insert, 3)) msgPack1.Msgs = append(msgPack1.Msgs, getTsMsg(commonpb.MsgType_Insert, 3))
ids, err = producer.Broadcast(&msgPack1) ids, err = producer.Broadcast(ctx, &msgPack1)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, ids) assert.NotNil(t, ids)
assert.Equal(t, len(channels), len(ids)) assert.Equal(t, len(channels), len(ids))
@ -698,12 +698,12 @@ func testBroadcastMark(t *testing.T, f []Factory) {
} }
// edge cases // edge cases
_, err = producer.Broadcast(nil) _, err = producer.Broadcast(ctx, nil)
assert.Error(t, err) assert.Error(t, err)
msgPack2 := MsgPack{} msgPack2 := MsgPack{}
msgPack2.Msgs = append(msgPack2.Msgs, &MarshalFailTsMsg{}) msgPack2.Msgs = append(msgPack2.Msgs, &MarshalFailTsMsg{})
_, err = producer.Broadcast(&msgPack2) _, err = producer.Broadcast(ctx, &msgPack2)
assert.Error(t, err) assert.Error(t, err)
} }
@ -712,7 +712,7 @@ func applyBroadCastAndConsume(t *testing.T, msgPack *MsgPack, newer []streamNewe
defer producer.Close() defer producer.Close()
defer consumer.Close() defer consumer.Close()
_, err := producer.Broadcast(msgPack) _, err := producer.Broadcast(context.TODO(), msgPack)
assert.NoError(t, err) assert.NoError(t, err)
receiveAndValidateMsg(context.Background(), consumer, len(msgPack.Msgs)*channelNum) receiveAndValidateMsg(context.Background(), consumer, len(msgPack.Msgs)*channelNum)
} }
@ -728,7 +728,7 @@ func applyProduceAndConsumeWithRepack(
defer producer.Close() defer producer.Close()
defer consumer.Close() defer consumer.Close()
err := producer.Produce(msgPack) err := producer.Produce(context.TODO(), msgPack)
assert.NoError(t, err) assert.NoError(t, err)
receiveAndValidateMsg(context.Background(), consumer, len(msgPack.Msgs)) receiveAndValidateMsg(context.Background(), consumer, len(msgPack.Msgs))
} }
@ -743,7 +743,7 @@ func applyProduceAndConsume(
defer producer.Close() defer producer.Close()
defer consumer.Close() defer consumer.Close()
err := producer.Produce(msgPack) err := producer.Produce(context.TODO(), msgPack)
assert.NoError(t, err) assert.NoError(t, err)
receiveAndValidateMsg(context.Background(), consumer, len(msgPack.Msgs)) receiveAndValidateMsg(context.Background(), consumer, len(msgPack.Msgs))
} }
@ -774,7 +774,7 @@ func createAndSeekConsumer(ctx context.Context, t *testing.T, newer streamNewer,
func createProducer(ctx context.Context, t *testing.T, newer streamNewer, channels []string) MsgStream { func createProducer(ctx context.Context, t *testing.T, newer streamNewer, channels []string) MsgStream {
producer, err := newer(ctx) producer, err := newer(ctx)
assert.NoError(t, err) assert.NoError(t, err)
producer.AsProducer(channels) producer.AsProducer(ctx, channels)
return producer return producer
} }
@ -798,7 +798,7 @@ func createStream(ctx context.Context, t *testing.T, newer []streamNewer, channe
assert.NotEmpty(t, channels) assert.NotEmpty(t, channels)
producer, err := newer[0](ctx) producer, err := newer[0](ctx)
assert.NoError(t, err) assert.NoError(t, err)
producer.AsProducer(channels) producer.AsProducer(ctx, channels)
consumer, err := newer[1](ctx) consumer, err := newer[1](ctx)
assert.NoError(t, err) assert.NoError(t, err)

View File

@ -74,9 +74,9 @@ func (_c *MockMsgStream_AsConsumer_Call) RunAndReturn(run func(context.Context,
return _c return _c
} }
// AsProducer provides a mock function with given fields: channels // AsProducer provides a mock function with given fields: ctx, channels
func (_m *MockMsgStream) AsProducer(channels []string) { func (_m *MockMsgStream) AsProducer(ctx context.Context, channels []string) {
_m.Called(channels) _m.Called(ctx, channels)
} }
// MockMsgStream_AsProducer_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AsProducer' // MockMsgStream_AsProducer_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AsProducer'
@ -85,14 +85,15 @@ type MockMsgStream_AsProducer_Call struct {
} }
// AsProducer is a helper method to define mock.On call // AsProducer is a helper method to define mock.On call
// - ctx context.Context
// - channels []string // - channels []string
func (_e *MockMsgStream_Expecter) AsProducer(channels interface{}) *MockMsgStream_AsProducer_Call { func (_e *MockMsgStream_Expecter) AsProducer(ctx interface{}, channels interface{}) *MockMsgStream_AsProducer_Call {
return &MockMsgStream_AsProducer_Call{Call: _e.mock.On("AsProducer", channels)} return &MockMsgStream_AsProducer_Call{Call: _e.mock.On("AsProducer", ctx, channels)}
} }
func (_c *MockMsgStream_AsProducer_Call) Run(run func(channels []string)) *MockMsgStream_AsProducer_Call { func (_c *MockMsgStream_AsProducer_Call) Run(run func(ctx context.Context, channels []string)) *MockMsgStream_AsProducer_Call {
_c.Call.Run(func(args mock.Arguments) { _c.Call.Run(func(args mock.Arguments) {
run(args[0].([]string)) run(args[0].(context.Context), args[1].([]string))
}) })
return _c return _c
} }
@ -102,14 +103,14 @@ func (_c *MockMsgStream_AsProducer_Call) Return() *MockMsgStream_AsProducer_Call
return _c return _c
} }
func (_c *MockMsgStream_AsProducer_Call) RunAndReturn(run func([]string)) *MockMsgStream_AsProducer_Call { func (_c *MockMsgStream_AsProducer_Call) RunAndReturn(run func(context.Context, []string)) *MockMsgStream_AsProducer_Call {
_c.Call.Return(run) _c.Call.Return(run)
return _c return _c
} }
// Broadcast provides a mock function with given fields: _a0 // Broadcast provides a mock function with given fields: _a0, _a1
func (_m *MockMsgStream) Broadcast(_a0 *MsgPack) (map[string][]common.MessageID, error) { func (_m *MockMsgStream) Broadcast(_a0 context.Context, _a1 *MsgPack) (map[string][]common.MessageID, error) {
ret := _m.Called(_a0) ret := _m.Called(_a0, _a1)
if len(ret) == 0 { if len(ret) == 0 {
panic("no return value specified for Broadcast") panic("no return value specified for Broadcast")
@ -117,19 +118,19 @@ func (_m *MockMsgStream) Broadcast(_a0 *MsgPack) (map[string][]common.MessageID,
var r0 map[string][]common.MessageID var r0 map[string][]common.MessageID
var r1 error var r1 error
if rf, ok := ret.Get(0).(func(*MsgPack) (map[string][]common.MessageID, error)); ok { if rf, ok := ret.Get(0).(func(context.Context, *MsgPack) (map[string][]common.MessageID, error)); ok {
return rf(_a0) return rf(_a0, _a1)
} }
if rf, ok := ret.Get(0).(func(*MsgPack) map[string][]common.MessageID); ok { if rf, ok := ret.Get(0).(func(context.Context, *MsgPack) map[string][]common.MessageID); ok {
r0 = rf(_a0) r0 = rf(_a0, _a1)
} else { } else {
if ret.Get(0) != nil { if ret.Get(0) != nil {
r0 = ret.Get(0).(map[string][]common.MessageID) r0 = ret.Get(0).(map[string][]common.MessageID)
} }
} }
if rf, ok := ret.Get(1).(func(*MsgPack) error); ok { if rf, ok := ret.Get(1).(func(context.Context, *MsgPack) error); ok {
r1 = rf(_a0) r1 = rf(_a0, _a1)
} else { } else {
r1 = ret.Error(1) r1 = ret.Error(1)
} }
@ -143,14 +144,15 @@ type MockMsgStream_Broadcast_Call struct {
} }
// Broadcast is a helper method to define mock.On call // Broadcast is a helper method to define mock.On call
// - _a0 *MsgPack // - _a0 context.Context
func (_e *MockMsgStream_Expecter) Broadcast(_a0 interface{}) *MockMsgStream_Broadcast_Call { // - _a1 *MsgPack
return &MockMsgStream_Broadcast_Call{Call: _e.mock.On("Broadcast", _a0)} func (_e *MockMsgStream_Expecter) Broadcast(_a0 interface{}, _a1 interface{}) *MockMsgStream_Broadcast_Call {
return &MockMsgStream_Broadcast_Call{Call: _e.mock.On("Broadcast", _a0, _a1)}
} }
func (_c *MockMsgStream_Broadcast_Call) Run(run func(_a0 *MsgPack)) *MockMsgStream_Broadcast_Call { func (_c *MockMsgStream_Broadcast_Call) Run(run func(_a0 context.Context, _a1 *MsgPack)) *MockMsgStream_Broadcast_Call {
_c.Call.Run(func(args mock.Arguments) { _c.Call.Run(func(args mock.Arguments) {
run(args[0].(*MsgPack)) run(args[0].(context.Context), args[1].(*MsgPack))
}) })
return _c return _c
} }
@ -160,7 +162,7 @@ func (_c *MockMsgStream_Broadcast_Call) Return(_a0 map[string][]common.MessageID
return _c return _c
} }
func (_c *MockMsgStream_Broadcast_Call) RunAndReturn(run func(*MsgPack) (map[string][]common.MessageID, error)) *MockMsgStream_Broadcast_Call { func (_c *MockMsgStream_Broadcast_Call) RunAndReturn(run func(context.Context, *MsgPack) (map[string][]common.MessageID, error)) *MockMsgStream_Broadcast_Call {
_c.Call.Return(run) _c.Call.Return(run)
return _c return _c
} }
@ -428,17 +430,17 @@ func (_c *MockMsgStream_GetProduceChannels_Call) RunAndReturn(run func() []strin
return _c return _c
} }
// Produce provides a mock function with given fields: _a0 // Produce provides a mock function with given fields: _a0, _a1
func (_m *MockMsgStream) Produce(_a0 *MsgPack) error { func (_m *MockMsgStream) Produce(_a0 context.Context, _a1 *MsgPack) error {
ret := _m.Called(_a0) ret := _m.Called(_a0, _a1)
if len(ret) == 0 { if len(ret) == 0 {
panic("no return value specified for Produce") panic("no return value specified for Produce")
} }
var r0 error var r0 error
if rf, ok := ret.Get(0).(func(*MsgPack) error); ok { if rf, ok := ret.Get(0).(func(context.Context, *MsgPack) error); ok {
r0 = rf(_a0) r0 = rf(_a0, _a1)
} else { } else {
r0 = ret.Error(0) r0 = ret.Error(0)
} }
@ -452,14 +454,15 @@ type MockMsgStream_Produce_Call struct {
} }
// Produce is a helper method to define mock.On call // Produce is a helper method to define mock.On call
// - _a0 *MsgPack // - _a0 context.Context
func (_e *MockMsgStream_Expecter) Produce(_a0 interface{}) *MockMsgStream_Produce_Call { // - _a1 *MsgPack
return &MockMsgStream_Produce_Call{Call: _e.mock.On("Produce", _a0)} func (_e *MockMsgStream_Expecter) Produce(_a0 interface{}, _a1 interface{}) *MockMsgStream_Produce_Call {
return &MockMsgStream_Produce_Call{Call: _e.mock.On("Produce", _a0, _a1)}
} }
func (_c *MockMsgStream_Produce_Call) Run(run func(_a0 *MsgPack)) *MockMsgStream_Produce_Call { func (_c *MockMsgStream_Produce_Call) Run(run func(_a0 context.Context, _a1 *MsgPack)) *MockMsgStream_Produce_Call {
_c.Call.Run(func(args mock.Arguments) { _c.Call.Run(func(args mock.Arguments) {
run(args[0].(*MsgPack)) run(args[0].(context.Context), args[1].(*MsgPack))
}) })
return _c return _c
} }
@ -469,7 +472,7 @@ func (_c *MockMsgStream_Produce_Call) Return(_a0 error) *MockMsgStream_Produce_C
return _c return _c
} }
func (_c *MockMsgStream_Produce_Call) RunAndReturn(run func(*MsgPack) error) *MockMsgStream_Produce_Call { func (_c *MockMsgStream_Produce_Call) RunAndReturn(run func(context.Context, *MsgPack) error) *MockMsgStream_Produce_Call {
_c.Call.Return(run) _c.Call.Return(run)
return _c return _c
} }

View File

@ -123,7 +123,7 @@ func TestStream_KafkaMsgStream_SeekToLast(t *testing.T) {
} }
// produce test data // produce test data
err := inputStream.Produce(msgPack) err := inputStream.Produce(ctx, msgPack)
assert.NoError(t, err) assert.NoError(t, err)
// pick a seekPosition // pick a seekPosition
@ -219,21 +219,21 @@ func TestStream_KafkaTtMsgStream_Seek(t *testing.T) {
inputStream := getKafkaInputStream(ctx, kafkaAddress, producerChannels) inputStream := getKafkaInputStream(ctx, kafkaAddress, producerChannels)
outputStream := getKafkaTtOutputStream(ctx, kafkaAddress, consumerChannels, consumerSubName) outputStream := getKafkaTtOutputStream(ctx, kafkaAddress, consumerChannels, consumerSubName)
_, err := inputStream.Broadcast(&msgPack0) _, err := inputStream.Broadcast(ctx, &msgPack0)
assert.NoError(t, err) assert.NoError(t, err)
err = inputStream.Produce(&msgPack1) err = inputStream.Produce(ctx, &msgPack1)
assert.NoError(t, err) assert.NoError(t, err)
_, err = inputStream.Broadcast(&msgPack2) _, err = inputStream.Broadcast(ctx, &msgPack2)
assert.NoError(t, err) assert.NoError(t, err)
err = inputStream.Produce(&msgPack3) err = inputStream.Produce(ctx, &msgPack3)
assert.NoError(t, err) assert.NoError(t, err)
_, err = inputStream.Broadcast(&msgPack4) _, err = inputStream.Broadcast(ctx, &msgPack4)
assert.NoError(t, err) assert.NoError(t, err)
err = inputStream.Produce(&msgPack5) err = inputStream.Produce(ctx, &msgPack5)
assert.NoError(t, err) assert.NoError(t, err)
_, err = inputStream.Broadcast(&msgPack6) _, err = inputStream.Broadcast(ctx, &msgPack6)
assert.NoError(t, err) assert.NoError(t, err)
_, err = inputStream.Broadcast(&msgPack7) _, err = inputStream.Broadcast(ctx, &msgPack7)
assert.NoError(t, err) assert.NoError(t, err)
receivedMsg := consumer(ctx, outputStream) receivedMsg := consumer(ctx, outputStream)
@ -450,7 +450,7 @@ func getKafkaInputStream(ctx context.Context, kafkaAddress string, producerChann
} }
kafkaClient := kafkawrapper.NewKafkaClientInstanceWithConfigMap(config, nil, nil) kafkaClient := kafkawrapper.NewKafkaClientInstanceWithConfigMap(config, nil, nil)
inputStream, _ := NewMqMsgStream(ctx, 100, 100, kafkaClient, factory.NewUnmarshalDispatcher()) inputStream, _ := NewMqMsgStream(ctx, 100, 100, kafkaClient, factory.NewUnmarshalDispatcher())
inputStream.AsProducer(producerChannels) inputStream.AsProducer(ctx, producerChannels)
for _, opt := range opts { for _, opt := range opts {
inputStream.SetRepackFunc(opt) inputStream.SetRepackFunc(opt)
} }

View File

@ -121,7 +121,7 @@ func NewMqMsgStream(ctx context.Context,
} }
// AsProducer create producer to send message to channels // AsProducer create producer to send message to channels
func (ms *mqMsgStream) AsProducer(channels []string) { func (ms *mqMsgStream) AsProducer(ctx context.Context, channels []string) {
for _, channel := range channels { for _, channel := range channels {
if len(channel) == 0 { if len(channel) == 0 {
log.Error("MsgStream asProducer's channel is an empty string") log.Error("MsgStream asProducer's channel is an empty string")
@ -129,7 +129,7 @@ func (ms *mqMsgStream) AsProducer(channels []string) {
} }
fn := func() error { fn := func() error {
pp, err := ms.client.CreateProducer(common.ProducerOptions{Topic: channel, EnableCompression: true}) pp, err := ms.client.CreateProducer(ctx, common.ProducerOptions{Topic: channel, EnableCompression: true})
if err != nil { if err != nil {
return err return err
} }
@ -176,7 +176,7 @@ func (ms *mqMsgStream) AsConsumer(ctx context.Context, channels []string, subNam
continue continue
} }
fn := func() error { fn := func() error {
pc, err := ms.client.Subscribe(mqwrapper.ConsumerOptions{ pc, err := ms.client.Subscribe(ctx, mqwrapper.ConsumerOptions{
Topic: channel, Topic: channel,
SubscriptionName: subName, SubscriptionName: subName,
SubscriptionInitialPosition: position, SubscriptionInitialPosition: position,
@ -273,7 +273,7 @@ func (ms *mqMsgStream) isEnabledProduce() bool {
return ms.enableProduce.Load().(bool) return ms.enableProduce.Load().(bool)
} }
func (ms *mqMsgStream) Produce(msgPack *MsgPack) error { func (ms *mqMsgStream) Produce(ctx context.Context, msgPack *MsgPack) error {
if !ms.isEnabledProduce() { if !ms.isEnabledProduce() {
log.Warn("can't produce the msg in the backup instance", zap.Stack("stack")) log.Warn("can't produce the msg in the backup instance", zap.Stack("stack"))
return merr.ErrDenyProduceMsg return merr.ErrDenyProduceMsg
@ -346,7 +346,7 @@ func (ms *mqMsgStream) Produce(msgPack *MsgPack) error {
// BroadcastMark broadcast msg pack to all producers and returns corresponding msg id // BroadcastMark broadcast msg pack to all producers and returns corresponding msg id
// the returned message id serves as marking // the returned message id serves as marking
func (ms *mqMsgStream) Broadcast(msgPack *MsgPack) (map[string][]MessageID, error) { func (ms *mqMsgStream) Broadcast(ctx context.Context, msgPack *MsgPack) (map[string][]MessageID, error) {
ids := make(map[string][]MessageID) ids := make(map[string][]MessageID)
if msgPack == nil || len(msgPack.Msgs) <= 0 { if msgPack == nil || len(msgPack.Msgs) <= 0 {
return ids, errors.New("empty msgs") return ids, errors.New("empty msgs")
@ -581,7 +581,7 @@ func (ms *MqTtMsgStream) AsConsumer(ctx context.Context, channels []string, subN
continue continue
} }
fn := func() error { fn := func() error {
pc, err := ms.client.Subscribe(mqwrapper.ConsumerOptions{ pc, err := ms.client.Subscribe(ctx, mqwrapper.ConsumerOptions{
Topic: channel, Topic: channel,
SubscriptionName: subName, SubscriptionName: subName,
SubscriptionInitialPosition: position, SubscriptionInitialPosition: position,

View File

@ -130,12 +130,12 @@ func TestStream_PulsarMsgStream_Insert(t *testing.T) {
{ {
inputStream.EnableProduce(false) inputStream.EnableProduce(false)
err := inputStream.Produce(&msgPack) err := inputStream.Produce(ctx, &msgPack)
require.Error(t, err) require.Error(t, err)
} }
inputStream.EnableProduce(true) inputStream.EnableProduce(true)
err := inputStream.Produce(&msgPack) err := inputStream.Produce(ctx, &msgPack)
require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err)) require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err))
receiveMsg(ctx, outputStream, len(msgPack.Msgs)) receiveMsg(ctx, outputStream, len(msgPack.Msgs))
@ -156,7 +156,7 @@ func TestStream_PulsarMsgStream_Delete(t *testing.T) {
inputStream := getPulsarInputStream(ctx, pulsarAddress, producerChannels) inputStream := getPulsarInputStream(ctx, pulsarAddress, producerChannels)
outputStream := getPulsarOutputStream(ctx, pulsarAddress, consumerChannels, consumerSubName) outputStream := getPulsarOutputStream(ctx, pulsarAddress, consumerChannels, consumerSubName)
err := inputStream.Produce(&msgPack) err := inputStream.Produce(ctx, &msgPack)
require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err)) require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err))
receiveMsg(ctx, outputStream, len(msgPack.Msgs)) receiveMsg(ctx, outputStream, len(msgPack.Msgs))
@ -178,7 +178,7 @@ func TestStream_PulsarMsgStream_TimeTick(t *testing.T) {
inputStream := getPulsarInputStream(ctx, pulsarAddress, producerChannels) inputStream := getPulsarInputStream(ctx, pulsarAddress, producerChannels)
outputStream := getPulsarOutputStream(ctx, pulsarAddress, consumerChannels, consumerSubName) outputStream := getPulsarOutputStream(ctx, pulsarAddress, consumerChannels, consumerSubName)
err := inputStream.Produce(&msgPack) err := inputStream.Produce(ctx, &msgPack)
require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err)) require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err))
receiveMsg(ctx, outputStream, len(msgPack.Msgs)) receiveMsg(ctx, outputStream, len(msgPack.Msgs))
@ -203,12 +203,12 @@ func TestStream_PulsarMsgStream_BroadCast(t *testing.T) {
{ {
inputStream.EnableProduce(false) inputStream.EnableProduce(false)
_, err := inputStream.Broadcast(&msgPack) _, err := inputStream.Broadcast(ctx, &msgPack)
require.Error(t, err) require.Error(t, err)
} }
inputStream.EnableProduce(true) inputStream.EnableProduce(true)
_, err := inputStream.Broadcast(&msgPack) _, err := inputStream.Broadcast(ctx, &msgPack)
require.NoErrorf(t, err, fmt.Sprintf("broadcast error = %v", err)) require.NoErrorf(t, err, fmt.Sprintf("broadcast error = %v", err))
receiveMsg(ctx, outputStream, len(consumerChannels)*len(msgPack.Msgs)) receiveMsg(ctx, outputStream, len(consumerChannels)*len(msgPack.Msgs))
@ -230,7 +230,7 @@ func TestStream_PulsarMsgStream_RepackFunc(t *testing.T) {
ctx := context.Background() ctx := context.Background()
inputStream := getPulsarInputStream(ctx, pulsarAddress, producerChannels, repackFunc) inputStream := getPulsarInputStream(ctx, pulsarAddress, producerChannels, repackFunc)
outputStream := getPulsarOutputStream(ctx, pulsarAddress, consumerChannels, consumerSubName) outputStream := getPulsarOutputStream(ctx, pulsarAddress, consumerChannels, consumerSubName)
err := inputStream.Produce(&msgPack) err := inputStream.Produce(ctx, &msgPack)
require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err)) require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err))
receiveMsg(ctx, outputStream, len(msgPack.Msgs)) receiveMsg(ctx, outputStream, len(msgPack.Msgs))
@ -277,14 +277,14 @@ func TestStream_PulsarMsgStream_InsertRepackFunc(t *testing.T) {
ctx := context.Background() ctx := context.Background()
pulsarClient, _ := pulsarwrapper.NewClient(DefaultPulsarTenant, DefaultPulsarNamespace, pulsar.ClientOptions{URL: pulsarAddress}) pulsarClient, _ := pulsarwrapper.NewClient(DefaultPulsarTenant, DefaultPulsarNamespace, pulsar.ClientOptions{URL: pulsarAddress})
inputStream, _ := NewMqMsgStream(ctx, 100, 100, pulsarClient, factory.NewUnmarshalDispatcher()) inputStream, _ := NewMqMsgStream(ctx, 100, 100, pulsarClient, factory.NewUnmarshalDispatcher())
inputStream.AsProducer(producerChannels) inputStream.AsProducer(ctx, producerChannels)
pulsarClient2, _ := pulsarwrapper.NewClient(DefaultPulsarTenant, DefaultPulsarNamespace, pulsar.ClientOptions{URL: pulsarAddress}) pulsarClient2, _ := pulsarwrapper.NewClient(DefaultPulsarTenant, DefaultPulsarNamespace, pulsar.ClientOptions{URL: pulsarAddress})
outputStream, _ := NewMqMsgStream(ctx, 100, 100, pulsarClient2, factory.NewUnmarshalDispatcher()) outputStream, _ := NewMqMsgStream(ctx, 100, 100, pulsarClient2, factory.NewUnmarshalDispatcher())
outputStream.AsConsumer(ctx, consumerChannels, consumerSubName, common.SubscriptionPositionEarliest) outputStream.AsConsumer(ctx, consumerChannels, consumerSubName, common.SubscriptionPositionEarliest)
var output MsgStream = outputStream var output MsgStream = outputStream
err := (*inputStream).Produce(&msgPack) err := (*inputStream).Produce(ctx, &msgPack)
require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err)) require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err))
receiveMsg(ctx, output, len(msgPack.Msgs)*2) receiveMsg(ctx, output, len(msgPack.Msgs)*2)
@ -328,14 +328,14 @@ func TestStream_PulsarMsgStream_DeleteRepackFunc(t *testing.T) {
ctx := context.Background() ctx := context.Background()
pulsarClient, _ := pulsarwrapper.NewClient(DefaultPulsarTenant, DefaultPulsarNamespace, pulsar.ClientOptions{URL: pulsarAddress}) pulsarClient, _ := pulsarwrapper.NewClient(DefaultPulsarTenant, DefaultPulsarNamespace, pulsar.ClientOptions{URL: pulsarAddress})
inputStream, _ := NewMqMsgStream(ctx, 100, 100, pulsarClient, factory.NewUnmarshalDispatcher()) inputStream, _ := NewMqMsgStream(ctx, 100, 100, pulsarClient, factory.NewUnmarshalDispatcher())
inputStream.AsProducer(producerChannels) inputStream.AsProducer(ctx, producerChannels)
pulsarClient2, _ := pulsarwrapper.NewClient(DefaultPulsarTenant, DefaultPulsarNamespace, pulsar.ClientOptions{URL: pulsarAddress}) pulsarClient2, _ := pulsarwrapper.NewClient(DefaultPulsarTenant, DefaultPulsarNamespace, pulsar.ClientOptions{URL: pulsarAddress})
outputStream, _ := NewMqMsgStream(ctx, 100, 100, pulsarClient2, factory.NewUnmarshalDispatcher()) outputStream, _ := NewMqMsgStream(ctx, 100, 100, pulsarClient2, factory.NewUnmarshalDispatcher())
outputStream.AsConsumer(ctx, consumerChannels, consumerSubName, common.SubscriptionPositionEarliest) outputStream.AsConsumer(ctx, consumerChannels, consumerSubName, common.SubscriptionPositionEarliest)
var output MsgStream = outputStream var output MsgStream = outputStream
err := (*inputStream).Produce(&msgPack) err := (*inputStream).Produce(ctx, &msgPack)
require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err)) require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err))
receiveMsg(ctx, output, len(msgPack.Msgs)*1) receiveMsg(ctx, output, len(msgPack.Msgs)*1)
@ -360,14 +360,14 @@ func TestStream_PulsarMsgStream_DefaultRepackFunc(t *testing.T) {
ctx := context.Background() ctx := context.Background()
pulsarClient, _ := pulsarwrapper.NewClient(DefaultPulsarTenant, DefaultPulsarNamespace, pulsar.ClientOptions{URL: pulsarAddress}) pulsarClient, _ := pulsarwrapper.NewClient(DefaultPulsarTenant, DefaultPulsarNamespace, pulsar.ClientOptions{URL: pulsarAddress})
inputStream, _ := NewMqMsgStream(ctx, 100, 100, pulsarClient, factory.NewUnmarshalDispatcher()) inputStream, _ := NewMqMsgStream(ctx, 100, 100, pulsarClient, factory.NewUnmarshalDispatcher())
inputStream.AsProducer(producerChannels) inputStream.AsProducer(ctx, producerChannels)
pulsarClient2, _ := pulsarwrapper.NewClient(DefaultPulsarTenant, DefaultPulsarNamespace, pulsar.ClientOptions{URL: pulsarAddress}) pulsarClient2, _ := pulsarwrapper.NewClient(DefaultPulsarTenant, DefaultPulsarNamespace, pulsar.ClientOptions{URL: pulsarAddress})
outputStream, _ := NewMqMsgStream(ctx, 100, 100, pulsarClient2, factory.NewUnmarshalDispatcher()) outputStream, _ := NewMqMsgStream(ctx, 100, 100, pulsarClient2, factory.NewUnmarshalDispatcher())
outputStream.AsConsumer(ctx, consumerChannels, consumerSubName, common.SubscriptionPositionEarliest) outputStream.AsConsumer(ctx, consumerChannels, consumerSubName, common.SubscriptionPositionEarliest)
var output MsgStream = outputStream var output MsgStream = outputStream
err := (*inputStream).Produce(&msgPack) err := (*inputStream).Produce(ctx, &msgPack)
require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err)) require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err))
receiveMsg(ctx, output, len(msgPack.Msgs)) receiveMsg(ctx, output, len(msgPack.Msgs))
@ -395,13 +395,13 @@ func TestStream_PulsarTtMsgStream_Insert(t *testing.T) {
inputStream := getPulsarInputStream(ctx, pulsarAddress, producerChannels) inputStream := getPulsarInputStream(ctx, pulsarAddress, producerChannels)
outputStream := getPulsarTtOutputStream(ctx, pulsarAddress, consumerChannels, consumerSubName) outputStream := getPulsarTtOutputStream(ctx, pulsarAddress, consumerChannels, consumerSubName)
_, err := inputStream.Broadcast(&msgPack0) _, err := inputStream.Broadcast(ctx, &msgPack0)
require.NoErrorf(t, err, fmt.Sprintf("broadcast error = %v", err)) require.NoErrorf(t, err, fmt.Sprintf("broadcast error = %v", err))
err = inputStream.Produce(&msgPack1) err = inputStream.Produce(ctx, &msgPack1)
require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err)) require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err))
_, err = inputStream.Broadcast(&msgPack2) _, err = inputStream.Broadcast(ctx, &msgPack2)
require.NoErrorf(t, err, fmt.Sprintf("broadcast error = %v", err)) require.NoErrorf(t, err, fmt.Sprintf("broadcast error = %v", err))
receiveMsg(ctx, outputStream, len(msgPack1.Msgs)) receiveMsg(ctx, outputStream, len(msgPack1.Msgs))
@ -440,17 +440,17 @@ func TestStream_PulsarTtMsgStream_NoSeek(t *testing.T) {
inputStream := getPulsarInputStream(ctx, pulsarAddress, producerChannels) inputStream := getPulsarInputStream(ctx, pulsarAddress, producerChannels)
outputStream := getPulsarTtOutputStream(ctx, pulsarAddress, consumerChannels, consumerSubName) outputStream := getPulsarTtOutputStream(ctx, pulsarAddress, consumerChannels, consumerSubName)
_, err := inputStream.Broadcast(&msgPack0) _, err := inputStream.Broadcast(ctx, &msgPack0)
assert.NoError(t, err) assert.NoError(t, err)
err = inputStream.Produce(&msgPack1) err = inputStream.Produce(ctx, &msgPack1)
assert.NoError(t, err) assert.NoError(t, err)
_, err = inputStream.Broadcast(&msgPack2) _, err = inputStream.Broadcast(ctx, &msgPack2)
assert.NoError(t, err) assert.NoError(t, err)
err = inputStream.Produce(&msgPack3) err = inputStream.Produce(ctx, &msgPack3)
assert.NoError(t, err) assert.NoError(t, err)
_, err = inputStream.Broadcast(&msgPack4) _, err = inputStream.Broadcast(ctx, &msgPack4)
assert.NoError(t, err) assert.NoError(t, err)
_, err = inputStream.Broadcast(&msgPack5) _, err = inputStream.Broadcast(ctx, &msgPack5)
assert.NoError(t, err) assert.NoError(t, err)
o1 := consumer(ctx, outputStream) o1 := consumer(ctx, outputStream)
@ -495,7 +495,7 @@ func TestStream_PulsarMsgStream_SeekToLast(t *testing.T) {
} }
// produce test data // produce test data
err := inputStream.Produce(msgPack) err := inputStream.Produce(ctx, msgPack)
assert.NoError(t, err) assert.NoError(t, err)
// pick a seekPosition // pick a seekPosition
@ -617,21 +617,21 @@ func TestStream_PulsarTtMsgStream_Seek(t *testing.T) {
inputStream := getPulsarInputStream(ctx, pulsarAddress, producerChannels) inputStream := getPulsarInputStream(ctx, pulsarAddress, producerChannels)
outputStream := getPulsarTtOutputStream(ctx, pulsarAddress, consumerChannels, consumerSubName) outputStream := getPulsarTtOutputStream(ctx, pulsarAddress, consumerChannels, consumerSubName)
_, err := inputStream.Broadcast(&msgPack0) _, err := inputStream.Broadcast(ctx, &msgPack0)
assert.NoError(t, err) assert.NoError(t, err)
err = inputStream.Produce(&msgPack1) err = inputStream.Produce(ctx, &msgPack1)
assert.NoError(t, err) assert.NoError(t, err)
_, err = inputStream.Broadcast(&msgPack2) _, err = inputStream.Broadcast(ctx, &msgPack2)
assert.NoError(t, err) assert.NoError(t, err)
err = inputStream.Produce(&msgPack3) err = inputStream.Produce(ctx, &msgPack3)
assert.NoError(t, err) assert.NoError(t, err)
_, err = inputStream.Broadcast(&msgPack4) _, err = inputStream.Broadcast(ctx, &msgPack4)
assert.NoError(t, err) assert.NoError(t, err)
err = inputStream.Produce(&msgPack5) err = inputStream.Produce(ctx, &msgPack5)
assert.NoError(t, err) assert.NoError(t, err)
_, err = inputStream.Broadcast(&msgPack6) _, err = inputStream.Broadcast(ctx, &msgPack6)
assert.NoError(t, err) assert.NoError(t, err)
_, err = inputStream.Broadcast(&msgPack7) _, err = inputStream.Broadcast(ctx, &msgPack7)
assert.NoError(t, err) assert.NoError(t, err)
receivedMsg := consumer(ctx, outputStream) receivedMsg := consumer(ctx, outputStream)
@ -711,13 +711,13 @@ func TestStream_PulsarTtMsgStream_UnMarshalHeader(t *testing.T) {
inputStream := getPulsarInputStream(ctx, pulsarAddress, producerChannels) inputStream := getPulsarInputStream(ctx, pulsarAddress, producerChannels)
outputStream := getPulsarTtOutputStream(ctx, pulsarAddress, consumerChannels, consumerSubName) outputStream := getPulsarTtOutputStream(ctx, pulsarAddress, consumerChannels, consumerSubName)
_, err := inputStream.Broadcast(&msgPack0) _, err := inputStream.Broadcast(ctx, &msgPack0)
require.NoErrorf(t, err, fmt.Sprintf("broadcast error = %v", err)) require.NoErrorf(t, err, fmt.Sprintf("broadcast error = %v", err))
err = inputStream.Produce(&msgPack1) err = inputStream.Produce(ctx, &msgPack1)
require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err)) require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err))
_, err = inputStream.Broadcast(&msgPack2) _, err = inputStream.Broadcast(ctx, &msgPack2)
require.NoErrorf(t, err, fmt.Sprintf("broadcast error = %v", err)) require.NoErrorf(t, err, fmt.Sprintf("broadcast error = %v", err))
receiveMsg(ctx, outputStream, len(msgPack1.Msgs)) receiveMsg(ctx, outputStream, len(msgPack1.Msgs))
@ -748,16 +748,16 @@ func TestStream_PulsarTtMsgStream_DropCollection(t *testing.T) {
inputStream := getPulsarInputStream(ctx, pulsarAddress, producerChannels) inputStream := getPulsarInputStream(ctx, pulsarAddress, producerChannels)
outputStream := getPulsarTtOutputStream(ctx, pulsarAddress, consumerChannels, consumerSubName) outputStream := getPulsarTtOutputStream(ctx, pulsarAddress, consumerChannels, consumerSubName)
_, err := inputStream.Broadcast(&msgPack0) _, err := inputStream.Broadcast(ctx, &msgPack0)
require.NoErrorf(t, err, fmt.Sprintf("broadcast error = %v", err)) require.NoErrorf(t, err, fmt.Sprintf("broadcast error = %v", err))
err = inputStream.Produce(&msgPack1) err = inputStream.Produce(ctx, &msgPack1)
require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err)) require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err))
_, err = inputStream.Broadcast(&msgPack2) _, err = inputStream.Broadcast(ctx, &msgPack2)
require.NoErrorf(t, err, fmt.Sprintf("broadcast error = %v", err)) require.NoErrorf(t, err, fmt.Sprintf("broadcast error = %v", err))
_, err = inputStream.Broadcast(&msgPack3) _, err = inputStream.Broadcast(ctx, &msgPack3)
require.NoErrorf(t, err, fmt.Sprintf("broadcast error = %v", err)) require.NoErrorf(t, err, fmt.Sprintf("broadcast error = %v", err))
receiveMsg(ctx, outputStream, 2) receiveMsg(ctx, outputStream, 2)
@ -803,12 +803,12 @@ func sendMsgPacks(ms MsgStream, msgPacks []*MsgPack) error {
printMsgPack(msgPacks[i]) printMsgPack(msgPacks[i])
if i%2 == 0 { if i%2 == 0 {
// insert msg use Produce // insert msg use Produce
if err := ms.Produce(msgPacks[i]); err != nil { if err := ms.Produce(context.TODO(), msgPacks[i]); err != nil {
return err return err
} }
} else { } else {
// tt msg use Broadcast // tt msg use Broadcast
if _, err := ms.Broadcast(msgPacks[i]); err != nil { if _, err := ms.Broadcast(context.TODO(), msgPacks[i]); err != nil {
return err return err
} }
} }
@ -971,7 +971,7 @@ func TestStream_MqMsgStream_Seek(t *testing.T) {
msgPack.Msgs = append(msgPack.Msgs, insertMsg) msgPack.Msgs = append(msgPack.Msgs, insertMsg)
} }
err := inputStream.Produce(msgPack) err := inputStream.Produce(ctx, msgPack)
assert.NoError(t, err) assert.NoError(t, err)
var seekPosition *msgpb.MsgPosition var seekPosition *msgpb.MsgPosition
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
@ -1015,7 +1015,7 @@ func TestStream_MqMsgStream_SeekInvalidMessage(t *testing.T) {
msgPack.Msgs = append(msgPack.Msgs, insertMsg) msgPack.Msgs = append(msgPack.Msgs, insertMsg)
} }
err := inputStream.Produce(msgPack) err := inputStream.Produce(ctx, msgPack)
assert.NoError(t, err) assert.NoError(t, err)
var seekPosition *msgpb.MsgPosition var seekPosition *msgpb.MsgPosition
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
@ -1049,7 +1049,7 @@ func TestStream_MqMsgStream_SeekInvalidMessage(t *testing.T) {
insertMsg := getTsMsg(commonpb.MsgType_Insert, int64(i)) insertMsg := getTsMsg(commonpb.MsgType_Insert, int64(i))
msgPack.Msgs = append(msgPack.Msgs, insertMsg) msgPack.Msgs = append(msgPack.Msgs, insertMsg)
} }
err = inputStream.Produce(msgPack) err = inputStream.Produce(ctx, msgPack)
assert.NoError(t, err) assert.NoError(t, err)
result := consumer(ctx, outputStream2) result := consumer(ctx, outputStream2)
assert.Equal(t, result.Msgs[0].ID(), int64(1)) assert.Equal(t, result.Msgs[0].ID(), int64(1))
@ -1074,7 +1074,7 @@ func TestSTream_MqMsgStream_SeekBadMessageID(t *testing.T) {
msgPack.Msgs = append(msgPack.Msgs, insertMsg) msgPack.Msgs = append(msgPack.Msgs, insertMsg)
} }
err := inputStream.Produce(msgPack) err := inputStream.Produce(ctx, msgPack)
assert.NoError(t, err) assert.NoError(t, err)
var seekPosition *msgpb.MsgPosition var seekPosition *msgpb.MsgPosition
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
@ -1086,7 +1086,7 @@ func TestSTream_MqMsgStream_SeekBadMessageID(t *testing.T) {
// produce timetick for mqtt msgstream seek // produce timetick for mqtt msgstream seek
msgPack = &MsgPack{} msgPack = &MsgPack{}
msgPack.Msgs = append(msgPack.Msgs, getTimeTickMsg(1000)) msgPack.Msgs = append(msgPack.Msgs, getTimeTickMsg(1000))
err = inputStream.Produce(msgPack) err = inputStream.Produce(ctx, msgPack)
assert.NoError(t, err) assert.NoError(t, err)
factory := ProtoUDFactory{} factory := ProtoUDFactory{}
@ -1139,7 +1139,7 @@ func TestStream_MqMsgStream_SeekLatest(t *testing.T) {
msgPack.Msgs = append(msgPack.Msgs, insertMsg) msgPack.Msgs = append(msgPack.Msgs, insertMsg)
} }
err := inputStream.Produce(msgPack) err := inputStream.Produce(ctx, msgPack)
assert.NoError(t, err) assert.NoError(t, err)
factory := ProtoUDFactory{} factory := ProtoUDFactory{}
pulsarClient, _ := pulsarwrapper.NewClient(DefaultPulsarTenant, DefaultPulsarNamespace, pulsar.ClientOptions{URL: pulsarAddress}) pulsarClient, _ := pulsarwrapper.NewClient(DefaultPulsarTenant, DefaultPulsarNamespace, pulsar.ClientOptions{URL: pulsarAddress})
@ -1152,7 +1152,7 @@ func TestStream_MqMsgStream_SeekLatest(t *testing.T) {
insertMsg := getTsMsg(commonpb.MsgType_Insert, int64(i)) insertMsg := getTsMsg(commonpb.MsgType_Insert, int64(i))
msgPack.Msgs = append(msgPack.Msgs, insertMsg) msgPack.Msgs = append(msgPack.Msgs, insertMsg)
} }
err = inputStream.Produce(msgPack) err = inputStream.Produce(ctx, msgPack)
assert.NoError(t, err) assert.NoError(t, err)
for i := 10; i < 20; i++ { for i := 10; i < 20; i++ {
@ -1169,6 +1169,7 @@ func TestStream_BroadcastMark(t *testing.T) {
c1 := funcutil.RandomString(8) c1 := funcutil.RandomString(8)
c2 := funcutil.RandomString(8) c2 := funcutil.RandomString(8)
producerChannels := []string{c1, c2} producerChannels := []string{c1, c2}
ctx := context.Background()
factory := ProtoUDFactory{} factory := ProtoUDFactory{}
pulsarClient, err := pulsarwrapper.NewClient(DefaultPulsarTenant, DefaultPulsarNamespace, pulsar.ClientOptions{URL: pulsarAddress}) pulsarClient, err := pulsarwrapper.NewClient(DefaultPulsarTenant, DefaultPulsarNamespace, pulsar.ClientOptions{URL: pulsarAddress})
@ -1177,12 +1178,12 @@ func TestStream_BroadcastMark(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
// add producer channels // add producer channels
outputStream.AsProducer(producerChannels) outputStream.AsProducer(ctx, producerChannels)
msgPack0 := MsgPack{} msgPack0 := MsgPack{}
msgPack0.Msgs = append(msgPack0.Msgs, getTimeTickMsg(0)) msgPack0.Msgs = append(msgPack0.Msgs, getTimeTickMsg(0))
ids, err := outputStream.Broadcast(&msgPack0) ids, err := outputStream.Broadcast(ctx, &msgPack0)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, ids) assert.NotNil(t, ids)
assert.Equal(t, len(producerChannels), len(ids)) assert.Equal(t, len(producerChannels), len(ids))
@ -1196,7 +1197,7 @@ func TestStream_BroadcastMark(t *testing.T) {
msgPack1.Msgs = append(msgPack1.Msgs, getTsMsg(commonpb.MsgType_Insert, 1)) msgPack1.Msgs = append(msgPack1.Msgs, getTsMsg(commonpb.MsgType_Insert, 1))
msgPack1.Msgs = append(msgPack1.Msgs, getTsMsg(commonpb.MsgType_Insert, 3)) msgPack1.Msgs = append(msgPack1.Msgs, getTsMsg(commonpb.MsgType_Insert, 3))
ids, err = outputStream.Broadcast(&msgPack1) ids, err = outputStream.Broadcast(ctx, &msgPack1)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, ids) assert.NotNil(t, ids)
assert.Equal(t, len(producerChannels), len(ids)) assert.Equal(t, len(producerChannels), len(ids))
@ -1207,19 +1208,19 @@ func TestStream_BroadcastMark(t *testing.T) {
} }
// edge cases // edge cases
_, err = outputStream.Broadcast(nil) _, err = outputStream.Broadcast(ctx, nil)
assert.Error(t, err) assert.Error(t, err)
msgPack2 := MsgPack{} msgPack2 := MsgPack{}
msgPack2.Msgs = append(msgPack2.Msgs, &MarshalFailTsMsg{}) msgPack2.Msgs = append(msgPack2.Msgs, &MarshalFailTsMsg{})
_, err = outputStream.Broadcast(&msgPack2) _, err = outputStream.Broadcast(ctx, &msgPack2)
assert.Error(t, err) assert.Error(t, err)
// mock send fail // mock send fail
for k, p := range outputStream.producers { for k, p := range outputStream.producers {
outputStream.producers[k] = &mockSendFailProducer{Producer: p} outputStream.producers[k] = &mockSendFailProducer{Producer: p}
} }
_, err = outputStream.Broadcast(&msgPack1) _, err = outputStream.Broadcast(ctx, &msgPack1)
assert.Error(t, err) assert.Error(t, err)
outputStream.Close() outputStream.Close()
@ -1497,7 +1498,7 @@ func getPulsarInputStream(ctx context.Context, pulsarAddress string, producerCha
factory := ProtoUDFactory{} factory := ProtoUDFactory{}
pulsarClient, _ := pulsarwrapper.NewClient(DefaultPulsarTenant, DefaultPulsarNamespace, pulsar.ClientOptions{URL: pulsarAddress}) pulsarClient, _ := pulsarwrapper.NewClient(DefaultPulsarTenant, DefaultPulsarNamespace, pulsar.ClientOptions{URL: pulsarAddress})
inputStream, _ := NewMqMsgStream(ctx, 100, 100, pulsarClient, factory.NewUnmarshalDispatcher()) inputStream, _ := NewMqMsgStream(ctx, 100, 100, pulsarClient, factory.NewUnmarshalDispatcher())
inputStream.AsProducer(producerChannels) inputStream.AsProducer(ctx, producerChannels)
for _, opt := range opts { for _, opt := range opts {
inputStream.SetRepackFunc(opt) inputStream.SetRepackFunc(opt)
} }

View File

@ -52,7 +52,7 @@ func TestMqMsgStream_AsProducer(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
// empty channel name // empty channel name
m.AsProducer([]string{""}) m.AsProducer(context.TODO(), []string{""})
} }
// TODO(wxyu): add a mock implement of mqwrapper.Client, then inject errors to improve coverage // TODO(wxyu): add a mock implement of mqwrapper.Client, then inject errors to improve coverage
@ -121,7 +121,7 @@ func TestMqMsgStream_GetProduceChannels(t *testing.T) {
assert.Equal(t, 0, len(chs)) assert.Equal(t, 0, len(chs))
// not empty after AsProducer // not empty after AsProducer
m.AsProducer([]string{"a"}) m.AsProducer(context.TODO(), []string{"a"})
chs = m.GetProduceChannels() chs = m.GetProduceChannels()
assert.Equal(t, 1, len(chs)) assert.Equal(t, 1, len(chs))
} }
@ -160,7 +160,7 @@ func TestMqMsgStream_Produce(t *testing.T) {
msgPack := &MsgPack{ msgPack := &MsgPack{
Msgs: []TsMsg{insertMsg}, Msgs: []TsMsg{insertMsg},
} }
err = m.Produce(msgPack) err = m.Produce(context.TODO(), msgPack)
assert.Error(t, err) assert.Error(t, err)
} }
@ -173,7 +173,7 @@ func TestMqMsgStream_Broadcast(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
// Broadcast nil pointer // Broadcast nil pointer
_, err = m.Broadcast(nil) _, err = m.Broadcast(context.TODO(), nil)
assert.Error(t, err) assert.Error(t, err)
} }
@ -241,7 +241,7 @@ func initRmqStream(ctx context.Context,
rmqClient, _ := rmq.NewClientWithDefaultOptions(ctx) rmqClient, _ := rmq.NewClientWithDefaultOptions(ctx)
inputStream, _ := NewMqMsgStream(ctx, 100, 100, rmqClient, factory.NewUnmarshalDispatcher()) inputStream, _ := NewMqMsgStream(ctx, 100, 100, rmqClient, factory.NewUnmarshalDispatcher())
inputStream.AsProducer(producerChannels) inputStream.AsProducer(ctx, producerChannels)
for _, opt := range opts { for _, opt := range opts {
inputStream.SetRepackFunc(opt) inputStream.SetRepackFunc(opt)
} }
@ -265,7 +265,7 @@ func initRmqTtStream(ctx context.Context,
rmqClient, _ := rmq.NewClientWithDefaultOptions(ctx) rmqClient, _ := rmq.NewClientWithDefaultOptions(ctx)
inputStream, _ := NewMqMsgStream(ctx, 100, 100, rmqClient, factory.NewUnmarshalDispatcher()) inputStream, _ := NewMqMsgStream(ctx, 100, 100, rmqClient, factory.NewUnmarshalDispatcher())
inputStream.AsProducer(producerChannels) inputStream.AsProducer(ctx, producerChannels)
for _, opt := range opts { for _, opt := range opts {
inputStream.SetRepackFunc(opt) inputStream.SetRepackFunc(opt)
} }
@ -290,7 +290,7 @@ func TestStream_RmqMsgStream_Insert(t *testing.T) {
ctx := context.Background() ctx := context.Background()
inputStream, outputStream := initRmqStream(ctx, producerChannels, consumerChannels, consumerGroupName) inputStream, outputStream := initRmqStream(ctx, producerChannels, consumerChannels, consumerGroupName)
err := inputStream.Produce(&msgPack) err := inputStream.Produce(ctx, &msgPack)
require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err)) require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err))
receiveMsg(ctx, outputStream, len(msgPack.Msgs)) receiveMsg(ctx, outputStream, len(msgPack.Msgs))
@ -316,13 +316,13 @@ func TestStream_RmqTtMsgStream_Insert(t *testing.T) {
ctx := context.Background() ctx := context.Background()
inputStream, outputStream := initRmqTtStream(ctx, producerChannels, consumerChannels, consumerSubName) inputStream, outputStream := initRmqTtStream(ctx, producerChannels, consumerChannels, consumerSubName)
_, err := inputStream.Broadcast(&msgPack0) _, err := inputStream.Broadcast(ctx, &msgPack0)
require.NoErrorf(t, err, fmt.Sprintf("broadcast error = %v", err)) require.NoErrorf(t, err, fmt.Sprintf("broadcast error = %v", err))
err = inputStream.Produce(&msgPack1) err = inputStream.Produce(ctx, &msgPack1)
require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err)) require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err))
_, err = inputStream.Broadcast(&msgPack2) _, err = inputStream.Broadcast(ctx, &msgPack2)
require.NoErrorf(t, err, fmt.Sprintf("broadcast error = %v", err)) require.NoErrorf(t, err, fmt.Sprintf("broadcast error = %v", err))
receiveMsg(ctx, outputStream, len(msgPack1.Msgs)) receiveMsg(ctx, outputStream, len(msgPack1.Msgs))
@ -355,13 +355,13 @@ func TestStream_RmqTtMsgStream_DuplicatedIDs(t *testing.T) {
ctx := context.Background() ctx := context.Background()
inputStream, outputStream := initRmqTtStream(ctx, producerChannels, consumerChannels, consumerSubName) inputStream, outputStream := initRmqTtStream(ctx, producerChannels, consumerChannels, consumerSubName)
_, err := inputStream.Broadcast(&msgPack0) _, err := inputStream.Broadcast(ctx, &msgPack0)
assert.NoError(t, err) assert.NoError(t, err)
err = inputStream.Produce(&msgPack1) err = inputStream.Produce(ctx, &msgPack1)
assert.NoError(t, err) assert.NoError(t, err)
err = inputStream.Produce(&msgPack2) err = inputStream.Produce(ctx, &msgPack2)
assert.NoError(t, err) assert.NoError(t, err)
_, err = inputStream.Broadcast(&msgPack3) _, err = inputStream.Broadcast(ctx, &msgPack3)
assert.NoError(t, err) assert.NoError(t, err)
receivedMsg := consumer(ctx, outputStream) receivedMsg := consumer(ctx, outputStream)
@ -425,21 +425,21 @@ func TestStream_RmqTtMsgStream_Seek(t *testing.T) {
ctx := context.Background() ctx := context.Background()
inputStream, outputStream := initRmqTtStream(ctx, producerChannels, consumerChannels, consumerSubName) inputStream, outputStream := initRmqTtStream(ctx, producerChannels, consumerChannels, consumerSubName)
_, err := inputStream.Broadcast(&msgPack0) _, err := inputStream.Broadcast(ctx, &msgPack0)
assert.NoError(t, err) assert.NoError(t, err)
err = inputStream.Produce(&msgPack1) err = inputStream.Produce(ctx, &msgPack1)
assert.NoError(t, err) assert.NoError(t, err)
_, err = inputStream.Broadcast(&msgPack2) _, err = inputStream.Broadcast(ctx, &msgPack2)
assert.NoError(t, err) assert.NoError(t, err)
err = inputStream.Produce(&msgPack3) err = inputStream.Produce(ctx, &msgPack3)
assert.NoError(t, err) assert.NoError(t, err)
_, err = inputStream.Broadcast(&msgPack4) _, err = inputStream.Broadcast(ctx, &msgPack4)
assert.NoError(t, err) assert.NoError(t, err)
err = inputStream.Produce(&msgPack5) err = inputStream.Produce(ctx, &msgPack5)
assert.NoError(t, err) assert.NoError(t, err)
_, err = inputStream.Broadcast(&msgPack6) _, err = inputStream.Broadcast(ctx, &msgPack6)
assert.NoError(t, err) assert.NoError(t, err)
_, err = inputStream.Broadcast(&msgPack7) _, err = inputStream.Broadcast(ctx, &msgPack7)
assert.NoError(t, err) assert.NoError(t, err)
receivedMsg := consumer(ctx, outputStream) receivedMsg := consumer(ctx, outputStream)
@ -512,7 +512,7 @@ func TestStream_RMqMsgStream_SeekInvalidMessage(t *testing.T) {
msgPack.Msgs = append(msgPack.Msgs, insertMsg) msgPack.Msgs = append(msgPack.Msgs, insertMsg)
} }
err := inputStream.Produce(msgPack) err := inputStream.Produce(ctx, msgPack)
assert.NoError(t, err) assert.NoError(t, err)
var seekPosition *msgpb.MsgPosition var seekPosition *msgpb.MsgPosition
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
@ -546,7 +546,7 @@ func TestStream_RMqMsgStream_SeekInvalidMessage(t *testing.T) {
insertMsg := getTsMsg(commonpb.MsgType_Insert, int64(i)) insertMsg := getTsMsg(commonpb.MsgType_Insert, int64(i))
msgPack.Msgs = append(msgPack.Msgs, insertMsg) msgPack.Msgs = append(msgPack.Msgs, insertMsg)
} }
err = inputStream.Produce(msgPack) err = inputStream.Produce(ctx, msgPack)
assert.NoError(t, err) assert.NoError(t, err)
result := consumer(ctx, outputStream2) result := consumer(ctx, outputStream2)
@ -560,27 +560,28 @@ func TestStream_RmqTtMsgStream_AsConsumerWithPosition(t *testing.T) {
producerChannels := []string{"insert1"} producerChannels := []string{"insert1"}
consumerChannels := []string{"insert1"} consumerChannels := []string{"insert1"}
consumerSubName := "subInsert" consumerSubName := "subInsert"
ctx := context.Background()
factory := ProtoUDFactory{} factory := ProtoUDFactory{}
rmqClient, _ := rmq.NewClientWithDefaultOptions(context.Background()) rmqClient, _ := rmq.NewClientWithDefaultOptions(context.Background())
otherInputStream, _ := NewMqMsgStream(context.Background(), 100, 100, rmqClient, factory.NewUnmarshalDispatcher()) otherInputStream, _ := NewMqMsgStream(context.Background(), 100, 100, rmqClient, factory.NewUnmarshalDispatcher())
otherInputStream.AsProducer([]string{"root_timetick"}) otherInputStream.AsProducer(context.TODO(), []string{"root_timetick"})
otherInputStream.Produce(getTimeTickMsgPack(999)) otherInputStream.Produce(ctx, getTimeTickMsgPack(999))
inputStream, _ := NewMqMsgStream(context.Background(), 100, 100, rmqClient, factory.NewUnmarshalDispatcher()) inputStream, _ := NewMqMsgStream(context.Background(), 100, 100, rmqClient, factory.NewUnmarshalDispatcher())
inputStream.AsProducer(producerChannels) inputStream.AsProducer(context.TODO(), producerChannels)
for i := 0; i < 100; i++ { for i := 0; i < 100; i++ {
inputStream.Produce(getTimeTickMsgPack(int64(i))) inputStream.Produce(ctx, getTimeTickMsgPack(int64(i)))
} }
rmqClient2, _ := rmq.NewClientWithDefaultOptions(context.Background()) rmqClient2, _ := rmq.NewClientWithDefaultOptions(context.Background())
outputStream, _ := NewMqMsgStream(context.Background(), 100, 100, rmqClient2, factory.NewUnmarshalDispatcher()) outputStream, _ := NewMqMsgStream(context.Background(), 100, 100, rmqClient2, factory.NewUnmarshalDispatcher())
outputStream.AsConsumer(context.Background(), consumerChannels, consumerSubName, mqcommon.SubscriptionPositionLatest) outputStream.AsConsumer(context.Background(), consumerChannels, consumerSubName, mqcommon.SubscriptionPositionLatest)
inputStream.Produce(getTimeTickMsgPack(1000)) inputStream.Produce(ctx, getTimeTickMsgPack(1000))
pack := <-outputStream.Chan() pack := <-outputStream.Chan()
assert.NotNil(t, pack) assert.NotNil(t, pack)
assert.Equal(t, 1, len(pack.Msgs)) assert.Equal(t, 1, len(pack.Msgs))

View File

@ -17,16 +17,18 @@
package mqwrapper package mqwrapper
import ( import (
"context"
"github.com/milvus-io/milvus/pkg/mq/common" "github.com/milvus-io/milvus/pkg/mq/common"
) )
// Client is the interface that provides operations of message queues // Client is the interface that provides operations of message queues
type Client interface { type Client interface {
// CreateProducer creates a producer instance // CreateProducer creates a producer instance
CreateProducer(options common.ProducerOptions) (Producer, error) CreateProducer(ctx context.Context, options common.ProducerOptions) (Producer, error)
// Subscribe creates a consumer instance and subscribe a topic // Subscribe creates a consumer instance and subscribe a topic
Subscribe(options ConsumerOptions) (Consumer, error) Subscribe(ctx context.Context, options ConsumerOptions) (Consumer, error)
// Get the earliest MessageID // Get the earliest MessageID
EarliestMessageID() common.MessageID EarliestMessageID() common.MessageID

View File

@ -205,7 +205,7 @@ func (kc *kafkaClient) newConsumerConfig(group string, offset common.Subscriptio
return newConf return newConf
} }
func (kc *kafkaClient) CreateProducer(options common.ProducerOptions) (mqwrapper.Producer, error) { func (kc *kafkaClient) CreateProducer(ctx context.Context, options common.ProducerOptions) (mqwrapper.Producer, error) {
start := timerecord.NewTimeRecorder("create producer") start := timerecord.NewTimeRecorder("create producer")
metrics.MsgStreamOpCounter.WithLabelValues(metrics.CreateProducerLabel, metrics.TotalLabel).Inc() metrics.MsgStreamOpCounter.WithLabelValues(metrics.CreateProducerLabel, metrics.TotalLabel).Inc()
@ -224,7 +224,7 @@ func (kc *kafkaClient) CreateProducer(options common.ProducerOptions) (mqwrapper
return producer, nil return producer, nil
} }
func (kc *kafkaClient) Subscribe(options mqwrapper.ConsumerOptions) (mqwrapper.Consumer, error) { func (kc *kafkaClient) Subscribe(ctx context.Context, options mqwrapper.ConsumerOptions) (mqwrapper.Consumer, error) {
start := timerecord.NewTimeRecorder("create consumer") start := timerecord.NewTimeRecorder("create consumer")
metrics.MsgStreamOpCounter.WithLabelValues(metrics.CreateConsumerLabel, metrics.TotalLabel).Inc() metrics.MsgStreamOpCounter.WithLabelValues(metrics.CreateConsumerLabel, metrics.TotalLabel).Inc()

View File

@ -64,7 +64,7 @@ func BytesToInt(b []byte) int {
// Consume1 will consume random messages and record the last MessageID it received // Consume1 will consume random messages and record the last MessageID it received
func Consume1(ctx context.Context, t *testing.T, kc *kafkaClient, topic string, subName string, c chan mqcommon.MessageID, total *int) { func Consume1(ctx context.Context, t *testing.T, kc *kafkaClient, topic string, subName string, c chan mqcommon.MessageID, total *int) {
consumer, err := kc.Subscribe(mqwrapper.ConsumerOptions{ consumer, err := kc.Subscribe(ctx, mqwrapper.ConsumerOptions{
Topic: topic, Topic: topic,
SubscriptionName: subName, SubscriptionName: subName,
BufSize: 1024, BufSize: 1024,
@ -103,7 +103,7 @@ func Consume1(ctx context.Context, t *testing.T, kc *kafkaClient, topic string,
// Consume2 will consume messages from specified MessageID // Consume2 will consume messages from specified MessageID
func Consume2(ctx context.Context, t *testing.T, kc *kafkaClient, topic string, subName string, msgID mqcommon.MessageID, total *int) { func Consume2(ctx context.Context, t *testing.T, kc *kafkaClient, topic string, subName string, msgID mqcommon.MessageID, total *int) {
consumer, err := kc.Subscribe(mqwrapper.ConsumerOptions{ consumer, err := kc.Subscribe(ctx, mqwrapper.ConsumerOptions{
Topic: topic, Topic: topic,
SubscriptionName: subName, SubscriptionName: subName,
BufSize: 1024, BufSize: 1024,
@ -139,7 +139,7 @@ func Consume2(ctx context.Context, t *testing.T, kc *kafkaClient, topic string,
} }
func Consume3(ctx context.Context, t *testing.T, kc *kafkaClient, topic string, subName string, total *int) { func Consume3(ctx context.Context, t *testing.T, kc *kafkaClient, topic string, subName string, total *int) {
consumer, err := kc.Subscribe(mqwrapper.ConsumerOptions{ consumer, err := kc.Subscribe(ctx, mqwrapper.ConsumerOptions{
Topic: topic, Topic: topic,
SubscriptionName: subName, SubscriptionName: subName,
BufSize: 1024, BufSize: 1024,
@ -418,7 +418,7 @@ func createConsumer(t *testing.T,
groupID string, groupID string,
initPosition mqcommon.SubscriptionInitialPosition, initPosition mqcommon.SubscriptionInitialPosition,
) mqwrapper.Consumer { ) mqwrapper.Consumer {
consumer, err := kc.Subscribe(mqwrapper.ConsumerOptions{ consumer, err := kc.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{
Topic: topic, Topic: topic,
SubscriptionName: groupID, SubscriptionName: groupID,
BufSize: 1024, BufSize: 1024,
@ -429,7 +429,7 @@ func createConsumer(t *testing.T,
} }
func createProducer(t *testing.T, kc *kafkaClient, topic string) mqwrapper.Producer { func createProducer(t *testing.T, kc *kafkaClient, topic string) mqwrapper.Producer {
producer, err := kc.CreateProducer(mqcommon.ProducerOptions{Topic: topic}) producer, err := kc.CreateProducer(context.TODO(), mqcommon.ProducerOptions{Topic: topic})
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, producer) assert.NotNil(t, producer)
return producer return producer

View File

@ -23,7 +23,7 @@ func TestKafkaProducer_SendSuccess(t *testing.T) {
rand.Seed(time.Now().UnixNano()) rand.Seed(time.Now().UnixNano())
topic := fmt.Sprintf("test-topic-%d", rand.Int()) topic := fmt.Sprintf("test-topic-%d", rand.Int())
producer, err := kc.CreateProducer(common.ProducerOptions{Topic: topic}) producer, err := kc.CreateProducer(context.TODO(), common.ProducerOptions{Topic: topic})
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, producer) assert.NotNil(t, producer)
@ -76,7 +76,7 @@ func TestKafkaProducer_SendFailAfterClose(t *testing.T) {
rand.Seed(time.Now().UnixNano()) rand.Seed(time.Now().UnixNano())
topic := fmt.Sprintf("test-topic-%d", rand.Int()) topic := fmt.Sprintf("test-topic-%d", rand.Int())
producer, err := kc.CreateProducer(common.ProducerOptions{Topic: topic}) producer, err := kc.CreateProducer(context.TODO(), common.ProducerOptions{Topic: topic})
assert.Nil(t, err) assert.Nil(t, err)
assert.NotNil(t, producer) assert.NotNil(t, producer)

View File

@ -80,7 +80,7 @@ func NewClient(url string, options ...nats.Option) (*nmqClient, error) {
} }
// CreateProducer creates a producer for natsmq client // CreateProducer creates a producer for natsmq client
func (nc *nmqClient) CreateProducer(options common.ProducerOptions) (mqwrapper.Producer, error) { func (nc *nmqClient) CreateProducer(ctx context.Context, options common.ProducerOptions) (mqwrapper.Producer, error) {
start := timerecord.NewTimeRecorder("create producer") start := timerecord.NewTimeRecorder("create producer")
metrics.MsgStreamOpCounter.WithLabelValues(metrics.CreateProducerLabel, metrics.TotalLabel).Inc() metrics.MsgStreamOpCounter.WithLabelValues(metrics.CreateProducerLabel, metrics.TotalLabel).Inc()
@ -112,7 +112,7 @@ func (nc *nmqClient) CreateProducer(options common.ProducerOptions) (mqwrapper.P
return &rp, nil return &rp, nil
} }
func (nc *nmqClient) Subscribe(options mqwrapper.ConsumerOptions) (mqwrapper.Consumer, error) { func (nc *nmqClient) Subscribe(ctx context.Context, options mqwrapper.ConsumerOptions) (mqwrapper.Consumer, error) {
start := timerecord.NewTimeRecorder("create consumer") start := timerecord.NewTimeRecorder("create consumer")
metrics.MsgStreamOpCounter.WithLabelValues(metrics.CreateConsumerLabel, metrics.TotalLabel).Inc() metrics.MsgStreamOpCounter.WithLabelValues(metrics.CreateConsumerLabel, metrics.TotalLabel).Inc()

View File

@ -86,7 +86,7 @@ func TestNmqClient_CreateProducer(t *testing.T) {
topic := "TestNmqClient_CreateProducer" topic := "TestNmqClient_CreateProducer"
proOpts := common.ProducerOptions{Topic: topic} proOpts := common.ProducerOptions{Topic: topic}
producer, err := client.CreateProducer(proOpts) producer, err := client.CreateProducer(context.TODO(), proOpts)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, producer) assert.NotNil(t, producer)
defer producer.Close() defer producer.Close()
@ -102,7 +102,7 @@ func TestNmqClient_CreateProducer(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
invalidOpts := common.ProducerOptions{Topic: ""} invalidOpts := common.ProducerOptions{Topic: ""}
producer, e := client.CreateProducer(invalidOpts) producer, e := client.CreateProducer(context.TODO(), invalidOpts)
assert.Nil(t, producer) assert.Nil(t, producer)
assert.Error(t, e) assert.Error(t, e)
} }
@ -114,7 +114,7 @@ func TestNmqClient_GetLatestMsg(t *testing.T) {
topic := fmt.Sprintf("t2GetLatestMsg-%d", rand.Int()) topic := fmt.Sprintf("t2GetLatestMsg-%d", rand.Int())
proOpts := common.ProducerOptions{Topic: topic} proOpts := common.ProducerOptions{Topic: topic}
producer, err := client.CreateProducer(proOpts) producer, err := client.CreateProducer(context.TODO(), proOpts)
assert.NoError(t, err) assert.NoError(t, err)
defer producer.Close() defer producer.Close()
@ -135,7 +135,7 @@ func TestNmqClient_GetLatestMsg(t *testing.T) {
BufSize: 1024, BufSize: 1024,
} }
consumer, err := client.Subscribe(consumerOpts) consumer, err := client.Subscribe(context.TODO(), consumerOpts)
assert.NoError(t, err) assert.NoError(t, err)
expectLastMsg, err := consumer.GetLatestMsgID() expectLastMsg, err := consumer.GetLatestMsgID()
@ -166,13 +166,13 @@ func TestNmqClient_IllegalSubscribe(t *testing.T) {
assert.NotNil(t, client) assert.NotNil(t, client)
defer client.Close() defer client.Close()
sub, err := client.Subscribe(mqwrapper.ConsumerOptions{ sub, err := client.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{
Topic: "", Topic: "",
}) })
assert.Nil(t, sub) assert.Nil(t, sub)
assert.Error(t, err) assert.Error(t, err)
sub, err = client.Subscribe(mqwrapper.ConsumerOptions{ sub, err = client.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{
Topic: "123", Topic: "123",
SubscriptionName: "", SubscriptionName: "",
}) })
@ -188,7 +188,7 @@ func TestNmqClient_Subscribe(t *testing.T) {
topic := "TestNmqClient_Subscribe" topic := "TestNmqClient_Subscribe"
proOpts := common.ProducerOptions{Topic: topic} proOpts := common.ProducerOptions{Topic: topic}
producer, err := client.CreateProducer(proOpts) producer, err := client.CreateProducer(context.TODO(), proOpts)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, producer) assert.NotNil(t, producer)
defer producer.Close() defer producer.Close()
@ -201,12 +201,12 @@ func TestNmqClient_Subscribe(t *testing.T) {
BufSize: 1024, BufSize: 1024,
} }
consumer, err := client.Subscribe(consumerOpts) consumer, err := client.Subscribe(context.TODO(), consumerOpts)
assert.Error(t, err) assert.Error(t, err)
assert.Nil(t, consumer) assert.Nil(t, consumer)
consumerOpts.Topic = topic consumerOpts.Topic = topic
consumer, err = client.Subscribe(consumerOpts) consumer, err = client.Subscribe(context.TODO(), consumerOpts)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, consumer) assert.NotNil(t, consumer)
defer consumer.Close() defer consumer.Close()

View File

@ -36,10 +36,10 @@ func TestNatsConsumer_Subscription(t *testing.T) {
topic := t.Name() topic := t.Name()
proOpts := common.ProducerOptions{Topic: topic} proOpts := common.ProducerOptions{Topic: topic}
_, err = client.CreateProducer(proOpts) _, err = client.CreateProducer(context.TODO(), proOpts)
assert.NoError(t, err) assert.NoError(t, err)
consumer, err := client.Subscribe(mqwrapper.ConsumerOptions{ consumer, err := client.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{
Topic: topic, Topic: topic,
SubscriptionName: topic, SubscriptionName: topic,
SubscriptionInitialPosition: common.SubscriptionPositionEarliest, SubscriptionInitialPosition: common.SubscriptionPositionEarliest,
@ -69,7 +69,7 @@ func Test_BadLatestMessageID(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
defer client.Close() defer client.Close()
consumer, err := client.Subscribe(mqwrapper.ConsumerOptions{ consumer, err := client.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{
Topic: topic, Topic: topic,
SubscriptionName: topic, SubscriptionName: topic,
SubscriptionInitialPosition: common.SubscriptionPositionEarliest, SubscriptionInitialPosition: common.SubscriptionPositionEarliest,
@ -88,10 +88,10 @@ func TestComsumeMessage(t *testing.T) {
defer client.Close() defer client.Close()
topic := t.Name() topic := t.Name()
p, err := client.CreateProducer(common.ProducerOptions{Topic: topic}) p, err := client.CreateProducer(context.TODO(), common.ProducerOptions{Topic: topic})
assert.NoError(t, err) assert.NoError(t, err)
c, err := client.Subscribe(mqwrapper.ConsumerOptions{ c, err := client.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{
Topic: topic, Topic: topic,
SubscriptionName: topic, SubscriptionName: topic,
SubscriptionInitialPosition: common.SubscriptionPositionEarliest, SubscriptionInitialPosition: common.SubscriptionPositionEarliest,
@ -149,7 +149,7 @@ func TestNatsConsumer_Close(t *testing.T) {
defer client.Close() defer client.Close()
topic := t.Name() topic := t.Name()
c, err := client.Subscribe(mqwrapper.ConsumerOptions{ c, err := client.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{
Topic: topic, Topic: topic,
SubscriptionName: topic, SubscriptionName: topic,
SubscriptionInitialPosition: common.SubscriptionPositionEarliest, SubscriptionInitialPosition: common.SubscriptionPositionEarliest,
@ -177,7 +177,7 @@ func TestNatsClientErrorOnUnsubscribeTwice(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
defer client.Close() defer client.Close()
consumer, err := client.Subscribe(mqwrapper.ConsumerOptions{ consumer, err := client.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{
Topic: topic, Topic: topic,
SubscriptionName: topic, SubscriptionName: topic,
SubscriptionInitialPosition: common.SubscriptionPositionEarliest, SubscriptionInitialPosition: common.SubscriptionPositionEarliest,
@ -199,7 +199,7 @@ func TestCheckTopicValid(t *testing.T) {
defer client.Close() defer client.Close()
topic := t.Name() topic := t.Name()
consumer, err := client.Subscribe(mqwrapper.ConsumerOptions{ consumer, err := client.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{
Topic: topic, Topic: topic,
SubscriptionName: topic, SubscriptionName: topic,
SubscriptionInitialPosition: common.SubscriptionPositionEarliest, SubscriptionInitialPosition: common.SubscriptionPositionEarliest,
@ -220,7 +220,7 @@ func TestCheckTopicValid(t *testing.T) {
assert.Error(t, err) assert.Error(t, err)
// not empty topic can pass // not empty topic can pass
pub, err := client.CreateProducer(common.ProducerOptions{ pub, err := client.CreateProducer(context.TODO(), common.ProducerOptions{
Topic: topic, Topic: topic,
}) })
assert.NoError(t, err) assert.NoError(t, err)
@ -240,7 +240,7 @@ func TestCheckTopicValid(t *testing.T) {
func newTestConsumer(t *testing.T, topic string, position common.SubscriptionInitialPosition) (mqwrapper.Consumer, error) { func newTestConsumer(t *testing.T, topic string, position common.SubscriptionInitialPosition) (mqwrapper.Consumer, error) {
client, err := createNmqClient() client, err := createNmqClient()
assert.NoError(t, err) assert.NoError(t, err)
return client.Subscribe(mqwrapper.ConsumerOptions{ return client.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{
Topic: topic, Topic: topic,
SubscriptionName: topic, SubscriptionName: topic,
SubscriptionInitialPosition: position, SubscriptionInitialPosition: position,
@ -251,7 +251,7 @@ func newTestConsumer(t *testing.T, topic string, position common.SubscriptionIni
func newProducer(t *testing.T, topic string) (*nmqClient, mqwrapper.Producer) { func newProducer(t *testing.T, topic string) (*nmqClient, mqwrapper.Producer) {
client, err := createNmqClient() client, err := createNmqClient()
assert.NoError(t, err) assert.NoError(t, err)
producer, err := client.CreateProducer(common.ProducerOptions{Topic: topic}) producer, err := client.CreateProducer(context.TODO(), common.ProducerOptions{Topic: topic})
assert.NoError(t, err) assert.NoError(t, err)
return client, producer return client, producer
} }
@ -272,10 +272,10 @@ func TestNmqConsumer_GetLatestMsgID(t *testing.T) {
defer client.Close() defer client.Close()
topic := t.Name() topic := t.Name()
p, err := client.CreateProducer(common.ProducerOptions{Topic: topic}) p, err := client.CreateProducer(context.TODO(), common.ProducerOptions{Topic: topic})
assert.NoError(t, err) assert.NoError(t, err)
c, err := client.Subscribe(mqwrapper.ConsumerOptions{ c, err := client.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{
Topic: topic, Topic: topic,
SubscriptionName: topic, SubscriptionName: topic,
SubscriptionInitialPosition: common.SubscriptionPositionEarliest, SubscriptionInitialPosition: common.SubscriptionPositionEarliest,
@ -301,13 +301,13 @@ func TestNmqConsumer_ConsumeFromLatest(t *testing.T) {
defer client.Close() defer client.Close()
topic := t.Name() topic := t.Name()
p, err := client.CreateProducer(common.ProducerOptions{Topic: topic}) p, err := client.CreateProducer(context.TODO(), common.ProducerOptions{Topic: topic})
assert.NoError(t, err) assert.NoError(t, err)
msgs := []string{"111", "222", "333"} msgs := []string{"111", "222", "333"}
process(t, msgs, p) process(t, msgs, p)
c, err := client.Subscribe(mqwrapper.ConsumerOptions{ c, err := client.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{
Topic: topic, Topic: topic,
SubscriptionName: topic, SubscriptionName: topic,
SubscriptionInitialPosition: common.SubscriptionPositionLatest, SubscriptionInitialPosition: common.SubscriptionPositionLatest,
@ -331,13 +331,13 @@ func TestNmqConsumer_ConsumeFromEarliest(t *testing.T) {
defer client.Close() defer client.Close()
topic := t.Name() topic := t.Name()
p, err := client.CreateProducer(common.ProducerOptions{Topic: topic}) p, err := client.CreateProducer(context.TODO(), common.ProducerOptions{Topic: topic})
assert.NoError(t, err) assert.NoError(t, err)
msgs := []string{"111", "222"} msgs := []string{"111", "222"}
process(t, msgs, p) process(t, msgs, p)
c, err := client.Subscribe(mqwrapper.ConsumerOptions{ c, err := client.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{
Topic: topic, Topic: topic,
SubscriptionName: topic, SubscriptionName: topic,
SubscriptionInitialPosition: common.SubscriptionPositionEarliest, SubscriptionInitialPosition: common.SubscriptionPositionEarliest,
@ -354,7 +354,7 @@ func TestNmqConsumer_ConsumeFromEarliest(t *testing.T) {
msg = <-c.Chan() msg = <-c.Chan()
assert.Equal(t, "222", string(msg.Payload())) assert.Equal(t, "222", string(msg.Payload()))
c2, err := client.Subscribe(mqwrapper.ConsumerOptions{ c2, err := client.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{
Topic: topic, Topic: topic,
SubscriptionName: topic, SubscriptionName: topic,
SubscriptionInitialPosition: common.SubscriptionPositionEarliest, SubscriptionInitialPosition: common.SubscriptionPositionEarliest,

View File

@ -3,7 +3,7 @@
// distributed with this work for additional information // distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file // regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the // to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance // "License"); you may not use this file exceapt in compliance
// with the License. You may obtain a copy of the License at // with the License. You may obtain a copy of the License at
// //
// http://www.apache.org/licenses/LICENSE-2.0 // http://www.apache.org/licenses/LICENSE-2.0

View File

@ -33,7 +33,7 @@ func TestNatsMQProducer(t *testing.T) {
pOpts := common.ProducerOptions{Topic: topic} pOpts := common.ProducerOptions{Topic: topic}
// Check Topic() // Check Topic()
p, err := c.CreateProducer(pOpts) p, err := c.CreateProducer(context.TODO(), pOpts)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, p.(*nmqProducer).Topic(), topic) assert.Equal(t, p.(*nmqProducer).Topic(), topic)

View File

@ -17,6 +17,7 @@
package pulsar package pulsar
import ( import (
"context"
"fmt" "fmt"
"sync" "sync"
"time" "time"
@ -66,7 +67,7 @@ func NewClient(tenant string, namespace string, opts pulsar.ClientOptions) (*pul
} }
// CreateProducer create a pulsar producer from options // CreateProducer create a pulsar producer from options
func (pc *pulsarClient) CreateProducer(options mqcommon.ProducerOptions) (mqwrapper.Producer, error) { func (pc *pulsarClient) CreateProducer(ctx context.Context, options mqcommon.ProducerOptions) (mqwrapper.Producer, error) {
start := timerecord.NewTimeRecorder("create producer") start := timerecord.NewTimeRecorder("create producer")
metrics.MsgStreamOpCounter.WithLabelValues(metrics.CreateProducerLabel, metrics.TotalLabel).Inc() metrics.MsgStreamOpCounter.WithLabelValues(metrics.CreateProducerLabel, metrics.TotalLabel).Inc()
@ -102,7 +103,7 @@ func (pc *pulsarClient) CreateProducer(options mqcommon.ProducerOptions) (mqwrap
} }
// Subscribe creates a pulsar consumer instance and subscribe a topic // Subscribe creates a pulsar consumer instance and subscribe a topic
func (pc *pulsarClient) Subscribe(options mqwrapper.ConsumerOptions) (mqwrapper.Consumer, error) { func (pc *pulsarClient) Subscribe(ctx context.Context, options mqwrapper.ConsumerOptions) (mqwrapper.Consumer, error) {
start := timerecord.NewTimeRecorder("create consumer") start := timerecord.NewTimeRecorder("create consumer")
metrics.MsgStreamOpCounter.WithLabelValues(metrics.CreateConsumerLabel, metrics.TotalLabel).Inc() metrics.MsgStreamOpCounter.WithLabelValues(metrics.CreateConsumerLabel, metrics.TotalLabel).Inc()

View File

@ -78,7 +78,7 @@ func BytesToInt(b []byte) int {
} }
func Produce(ctx context.Context, t *testing.T, pc *pulsarClient, topic string, arr []int) { func Produce(ctx context.Context, t *testing.T, pc *pulsarClient, topic string, arr []int) {
producer, err := pc.CreateProducer(mqcommon.ProducerOptions{Topic: topic}) producer, err := pc.CreateProducer(ctx, mqcommon.ProducerOptions{Topic: topic})
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, producer) assert.NotNil(t, producer)
@ -110,7 +110,7 @@ func VerifyMessage(t *testing.T, msg mqcommon.Message) {
// Consume1 will consume random messages and record the last MessageID it received // Consume1 will consume random messages and record the last MessageID it received
func Consume1(ctx context.Context, t *testing.T, pc *pulsarClient, topic string, subName string, c chan mqcommon.MessageID, total *int) { func Consume1(ctx context.Context, t *testing.T, pc *pulsarClient, topic string, subName string, c chan mqcommon.MessageID, total *int) {
consumer, err := pc.Subscribe(mqwrapper.ConsumerOptions{ consumer, err := pc.Subscribe(ctx, mqwrapper.ConsumerOptions{
Topic: topic, Topic: topic,
SubscriptionName: subName, SubscriptionName: subName,
BufSize: 1024, BufSize: 1024,
@ -147,7 +147,7 @@ func Consume1(ctx context.Context, t *testing.T, pc *pulsarClient, topic string,
// Consume2 will consume messages from specified MessageID // Consume2 will consume messages from specified MessageID
func Consume2(ctx context.Context, t *testing.T, pc *pulsarClient, topic string, subName string, msgID mqcommon.MessageID, total *int) { func Consume2(ctx context.Context, t *testing.T, pc *pulsarClient, topic string, subName string, msgID mqcommon.MessageID, total *int) {
consumer, err := pc.Subscribe(mqwrapper.ConsumerOptions{ consumer, err := pc.Subscribe(ctx, mqwrapper.ConsumerOptions{
Topic: topic, Topic: topic,
SubscriptionName: subName, SubscriptionName: subName,
BufSize: 1024, BufSize: 1024,
@ -181,7 +181,7 @@ func Consume2(ctx context.Context, t *testing.T, pc *pulsarClient, topic string,
} }
func Consume3(ctx context.Context, t *testing.T, pc *pulsarClient, topic string, subName string, total *int) { func Consume3(ctx context.Context, t *testing.T, pc *pulsarClient, topic string, subName string, total *int) {
consumer, err := pc.Subscribe(mqwrapper.ConsumerOptions{ consumer, err := pc.Subscribe(ctx, mqwrapper.ConsumerOptions{
Topic: topic, Topic: topic,
SubscriptionName: subName, SubscriptionName: subName,
BufSize: 1024, BufSize: 1024,
@ -420,7 +420,7 @@ func TestPulsarClient_SeekPosition(t *testing.T) {
topic := fmt.Sprintf("test-topic-%d", rand.Int()) topic := fmt.Sprintf("test-topic-%d", rand.Int())
subName := fmt.Sprintf("test-subname-%d", rand.Int()) subName := fmt.Sprintf("test-subname-%d", rand.Int())
producer, err := pc.CreateProducer(mqcommon.ProducerOptions{Topic: topic}) producer, err := pc.CreateProducer(ctx, mqcommon.ProducerOptions{Topic: topic})
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, producer) assert.NotNil(t, producer)
@ -498,7 +498,7 @@ func TestPulsarClient_SeekLatest(t *testing.T) {
topic := fmt.Sprintf("test-topic-%d", rand.Int()) topic := fmt.Sprintf("test-topic-%d", rand.Int())
subName := fmt.Sprintf("test-subname-%d", rand.Int()) subName := fmt.Sprintf("test-subname-%d", rand.Int())
producer, err := pc.CreateProducer(mqcommon.ProducerOptions{Topic: topic}) producer, err := pc.CreateProducer(ctx, mqcommon.ProducerOptions{Topic: topic})
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, producer) assert.NotNil(t, producer)
@ -671,7 +671,7 @@ func TestPulsarClient_SubscribeExclusiveFail(t *testing.T) {
client: &mockPulsarClient{}, client: &mockPulsarClient{},
} }
_, err := pc.Subscribe(mqwrapper.ConsumerOptions{Topic: "test_topic_name"}) _, err := pc.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{Topic: "test_topic_name"})
assert.Error(t, err) assert.Error(t, err)
assert.True(t, retry.IsRecoverable(err)) assert.True(t, retry.IsRecoverable(err))
}) })
@ -686,7 +686,7 @@ func TestPulsarClient_WithTenantAndNamespace(t *testing.T) {
pulsarAddress := getPulsarAddress() pulsarAddress := getPulsarAddress()
pc, err := NewClient(tenant, namespace, pulsar.ClientOptions{URL: pulsarAddress}) pc, err := NewClient(tenant, namespace, pulsar.ClientOptions{URL: pulsarAddress})
assert.NoError(t, err) assert.NoError(t, err)
producer, err := pc.CreateProducer(mqcommon.ProducerOptions{Topic: topic}) producer, err := pc.CreateProducer(context.TODO(), mqcommon.ProducerOptions{Topic: topic})
defer producer.Close() defer producer.Close()
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, producer) assert.NotNil(t, producer)
@ -695,7 +695,7 @@ func TestPulsarClient_WithTenantAndNamespace(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, fullTopicName, producer.(*pulsarProducer).Topic()) assert.Equal(t, fullTopicName, producer.(*pulsarProducer).Topic())
consumer, err := pc.Subscribe(mqwrapper.ConsumerOptions{ consumer, err := pc.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{
Topic: topic, Topic: topic,
SubscriptionName: subName, SubscriptionName: subName,
BufSize: 1024, BufSize: 1024,
@ -713,7 +713,7 @@ func TestPulsarCtl(t *testing.T) {
pulsarAddress := getPulsarAddress() pulsarAddress := getPulsarAddress()
pc, err := NewClient(DefaultPulsarTenant, DefaultPulsarNamespace, pulsar.ClientOptions{URL: pulsarAddress}) pc, err := NewClient(DefaultPulsarTenant, DefaultPulsarNamespace, pulsar.ClientOptions{URL: pulsarAddress})
assert.NoError(t, err) assert.NoError(t, err)
consumer, err := pc.Subscribe(mqwrapper.ConsumerOptions{ consumer, err := pc.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{
Topic: topic, Topic: topic,
SubscriptionName: subName, SubscriptionName: subName,
BufSize: 1024, BufSize: 1024,
@ -723,7 +723,7 @@ func TestPulsarCtl(t *testing.T) {
assert.NotNil(t, consumer) assert.NotNil(t, consumer)
defer consumer.Close() defer consumer.Close()
_, err = pc.Subscribe(mqwrapper.ConsumerOptions{ _, err = pc.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{
Topic: topic, Topic: topic,
SubscriptionName: subName, SubscriptionName: subName,
BufSize: 1024, BufSize: 1024,
@ -732,7 +732,7 @@ func TestPulsarCtl(t *testing.T) {
assert.Error(t, err) assert.Error(t, err)
_, err = pc.Subscribe(mqwrapper.ConsumerOptions{ _, err = pc.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{
Topic: topic, Topic: topic,
SubscriptionName: subName, SubscriptionName: subName,
BufSize: 1024, BufSize: 1024,
@ -762,7 +762,7 @@ func TestPulsarCtl(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
} }
consumer2, err := pc.Subscribe(mqwrapper.ConsumerOptions{ consumer2, err := pc.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{
Topic: topic, Topic: topic,
SubscriptionName: subName, SubscriptionName: subName,
BufSize: 1024, BufSize: 1024,

View File

@ -80,9 +80,9 @@ func TestComsumeCompressedMessage(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
defer consumer.Close() defer consumer.Close()
producer, err := pc.CreateProducer(mqcommon.ProducerOptions{Topic: "TestTopics"}) producer, err := pc.CreateProducer(context.TODO(), mqcommon.ProducerOptions{Topic: "TestTopics"})
assert.NoError(t, err) assert.NoError(t, err)
compressProducer, err := pc.CreateProducer(mqcommon.ProducerOptions{Topic: "TestTopics", EnableCompression: true}) compressProducer, err := pc.CreateProducer(context.TODO(), mqcommon.ProducerOptions{Topic: "TestTopics", EnableCompression: true})
assert.NoError(t, err) assert.NoError(t, err)
msg := []byte("test message") msg := []byte("test message")

View File

@ -34,7 +34,7 @@ func TestPulsarProducer(t *testing.T) {
assert.NotNil(t, pc) assert.NotNil(t, pc)
topic := "TEST" topic := "TEST"
producer, err := pc.CreateProducer(common.ProducerOptions{Topic: topic}) producer, err := pc.CreateProducer(context.TODO(), common.ProducerOptions{Topic: topic})
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, producer) assert.NotNil(t, producer)

View File

@ -58,7 +58,7 @@ func NewClient(opts client.Options) (*rmqClient, error) {
} }
// CreateProducer creates a producer for rocksmq client // CreateProducer creates a producer for rocksmq client
func (rc *rmqClient) CreateProducer(options common.ProducerOptions) (mqwrapper.Producer, error) { func (rc *rmqClient) CreateProducer(ctx context.Context, options common.ProducerOptions) (mqwrapper.Producer, error) {
start := timerecord.NewTimeRecorder("create producer") start := timerecord.NewTimeRecorder("create producer")
metrics.MsgStreamOpCounter.WithLabelValues(metrics.CreateProducerLabel, metrics.TotalLabel).Inc() metrics.MsgStreamOpCounter.WithLabelValues(metrics.CreateProducerLabel, metrics.TotalLabel).Inc()
@ -77,7 +77,7 @@ func (rc *rmqClient) CreateProducer(options common.ProducerOptions) (mqwrapper.P
} }
// Subscribe subscribes a consumer in rmq client // Subscribe subscribes a consumer in rmq client
func (rc *rmqClient) Subscribe(options mqwrapper.ConsumerOptions) (mqwrapper.Consumer, error) { func (rc *rmqClient) Subscribe(ctx context.Context, options mqwrapper.ConsumerOptions) (mqwrapper.Consumer, error) {
start := timerecord.NewTimeRecorder("create consumer") start := timerecord.NewTimeRecorder("create consumer")
metrics.MsgStreamOpCounter.WithLabelValues(metrics.CreateConsumerLabel, metrics.TotalLabel).Inc() metrics.MsgStreamOpCounter.WithLabelValues(metrics.CreateConsumerLabel, metrics.TotalLabel).Inc()

View File

@ -65,7 +65,7 @@ func TestRmqClient_CreateProducer(t *testing.T) {
topic := "TestRmqClient_CreateProducer" topic := "TestRmqClient_CreateProducer"
proOpts := common.ProducerOptions{Topic: topic} proOpts := common.ProducerOptions{Topic: topic}
producer, err := client.CreateProducer(proOpts) producer, err := client.CreateProducer(context.TODO(), proOpts)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, producer) assert.NotNil(t, producer)
@ -83,7 +83,7 @@ func TestRmqClient_CreateProducer(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
invalidOpts := common.ProducerOptions{Topic: ""} invalidOpts := common.ProducerOptions{Topic: ""}
producer, e := client.CreateProducer(invalidOpts) producer, e := client.CreateProducer(context.TODO(), invalidOpts)
assert.Nil(t, producer) assert.Nil(t, producer)
assert.Error(t, e) assert.Error(t, e)
} }
@ -95,7 +95,7 @@ func TestRmqClient_GetLatestMsg(t *testing.T) {
topic := fmt.Sprintf("t2GetLatestMsg-%d", rand.Int()) topic := fmt.Sprintf("t2GetLatestMsg-%d", rand.Int())
proOpts := common.ProducerOptions{Topic: topic} proOpts := common.ProducerOptions{Topic: topic}
producer, err := client.CreateProducer(proOpts) producer, err := client.CreateProducer(context.TODO(), proOpts)
assert.NoError(t, err) assert.NoError(t, err)
defer producer.Close() defer producer.Close()
@ -116,7 +116,7 @@ func TestRmqClient_GetLatestMsg(t *testing.T) {
BufSize: 1024, BufSize: 1024,
} }
consumer, err := client.Subscribe(consumerOpts) consumer, err := client.Subscribe(context.TODO(), consumerOpts)
assert.NoError(t, err) assert.NoError(t, err)
expectLastMsg, err := consumer.GetLatestMsgID() expectLastMsg, err := consumer.GetLatestMsgID()
@ -149,7 +149,7 @@ func TestRmqClient_Subscribe(t *testing.T) {
topic := "TestRmqClient_Subscribe" topic := "TestRmqClient_Subscribe"
proOpts := common.ProducerOptions{Topic: topic} proOpts := common.ProducerOptions{Topic: topic}
producer, err := client.CreateProducer(proOpts) producer, err := client.CreateProducer(context.TODO(), proOpts)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, producer) assert.NotNil(t, producer)
defer producer.Close() defer producer.Close()
@ -161,7 +161,7 @@ func TestRmqClient_Subscribe(t *testing.T) {
SubscriptionInitialPosition: common.SubscriptionPositionEarliest, SubscriptionInitialPosition: common.SubscriptionPositionEarliest,
BufSize: 0, BufSize: 0,
} }
consumer, err := client.Subscribe(consumerOpts) consumer, err := client.Subscribe(context.TODO(), consumerOpts)
assert.Error(t, err) assert.Error(t, err)
assert.Nil(t, consumer) assert.Nil(t, consumer)
@ -172,12 +172,12 @@ func TestRmqClient_Subscribe(t *testing.T) {
BufSize: 1024, BufSize: 1024,
} }
consumer, err = client.Subscribe(consumerOpts) consumer, err = client.Subscribe(context.TODO(), consumerOpts)
assert.Error(t, err) assert.Error(t, err)
assert.Nil(t, consumer) assert.Nil(t, consumer)
consumerOpts.Topic = topic consumerOpts.Topic = topic
consumer, err = client.Subscribe(consumerOpts) consumer, err = client.Subscribe(context.TODO(), consumerOpts)
defer consumer.Close() defer consumer.Close()
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, consumer) assert.NotNil(t, consumer)

View File

@ -55,11 +55,11 @@ type RepackFunc func(msgs []TsMsg, hashKeys [][]int32) (map[int32]*MsgPack, erro
type MsgStream interface { type MsgStream interface {
Close() Close()
AsProducer(channels []string) AsProducer(ctx context.Context, channels []string)
Produce(*MsgPack) error Produce(context.Context, *MsgPack) error
SetRepackFunc(repackFunc RepackFunc) SetRepackFunc(repackFunc RepackFunc)
GetProduceChannels() []string GetProduceChannels() []string
Broadcast(*MsgPack) (map[string][]MessageID, error) Broadcast(context.Context, *MsgPack) (map[string][]MessageID, error)
AsConsumer(ctx context.Context, channels []string, subName string, position common.SubscriptionInitialPosition) error AsConsumer(ctx context.Context, channels []string, subName string, position common.SubscriptionInitialPosition) error
Chan() <-chan *MsgPack Chan() <-chan *MsgPack

View File

@ -36,7 +36,7 @@ func TestPulsarMsgUtil(t *testing.T) {
defer msgStream.Close() defer msgStream.Close()
// create a topic // create a topic
msgStream.AsProducer([]string{"test"}) msgStream.AsProducer(ctx, []string{"test"})
UnsubscribeChannels(ctx, pmsFactory, "sub", []string{"test"}) UnsubscribeChannels(ctx, pmsFactory, "sub", []string{"test"})
} }

View File

@ -46,7 +46,7 @@ func benchmarkProduceAndConsume(b *testing.B, mqClient mqwrapper.Client, cases [
go func() { go func() {
defer wg.Done() defer wg.Done()
p, err := mqClient.CreateProducer(common.ProducerOptions{ p, err := mqClient.CreateProducer(context.TODO(), common.ProducerOptions{
Topic: topic, Topic: topic,
}) })
assert.NoError(b, err) assert.NoError(b, err)
@ -55,7 +55,7 @@ func benchmarkProduceAndConsume(b *testing.B, mqClient mqwrapper.Client, cases [
}() }()
go func() { go func() {
defer wg.Done() defer wg.Done()
c, _ := mqClient.Subscribe(mqwrapper.ConsumerOptions{ c, _ := mqClient.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{
Topic: topic, Topic: topic,
SubscriptionName: topic, SubscriptionName: topic,
SubscriptionInitialPosition: common.SubscriptionPositionEarliest, SubscriptionInitialPosition: common.SubscriptionPositionEarliest,

View File

@ -40,13 +40,13 @@ func testStreamOperation(t *testing.T, mqClient mqwrapper.Client) {
func testConcurrentStream(t *testing.T, mqClient mqwrapper.Client) { func testConcurrentStream(t *testing.T, mqClient mqwrapper.Client) {
topics := getChannel(2) topics := getChannel(2)
producer, err := mqClient.CreateProducer(common.ProducerOptions{ producer, err := mqClient.CreateProducer(context.TODO(), common.ProducerOptions{
Topic: topics[0], Topic: topics[0],
}) })
defer producer.Close() defer producer.Close()
assert.NoError(t, err) assert.NoError(t, err)
consumer, err := mqClient.Subscribe(mqwrapper.ConsumerOptions{ consumer, err := mqClient.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{
Topic: topics[0], Topic: topics[0],
SubscriptionName: funcutil.RandomString(8), SubscriptionName: funcutil.RandomString(8),
SubscriptionInitialPosition: common.SubscriptionPositionEarliest, SubscriptionInitialPosition: common.SubscriptionPositionEarliest,
@ -61,7 +61,7 @@ func testConcurrentStream(t *testing.T, mqClient mqwrapper.Client) {
func testConcurrentStreamAndSubscribeLast(t *testing.T, mqClient mqwrapper.Client) { func testConcurrentStreamAndSubscribeLast(t *testing.T, mqClient mqwrapper.Client) {
topics := getChannel(2) topics := getChannel(2)
producer, err := mqClient.CreateProducer(common.ProducerOptions{ producer, err := mqClient.CreateProducer(context.TODO(), common.ProducerOptions{
Topic: topics[0], Topic: topics[0],
}) })
defer producer.Close() defer producer.Close()
@ -69,7 +69,7 @@ func testConcurrentStreamAndSubscribeLast(t *testing.T, mqClient mqwrapper.Clien
ids := sendMessages(context.Background(), t, producer, generateRandMessage(1024, 1000)) ids := sendMessages(context.Background(), t, producer, generateRandMessage(1024, 1000))
consumer, err := mqClient.Subscribe(mqwrapper.ConsumerOptions{ consumer, err := mqClient.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{
Topic: topics[0], Topic: topics[0],
SubscriptionName: funcutil.RandomString(8), SubscriptionName: funcutil.RandomString(8),
SubscriptionInitialPosition: common.SubscriptionPositionLatest, SubscriptionInitialPosition: common.SubscriptionPositionLatest,
@ -90,7 +90,7 @@ func testConcurrentStreamAndSubscribeLast(t *testing.T, mqClient mqwrapper.Clien
func testConcurrentStreamAndSeekInclusive(t *testing.T, mqClient mqwrapper.Client) { func testConcurrentStreamAndSeekInclusive(t *testing.T, mqClient mqwrapper.Client) {
topics := getChannel(2) topics := getChannel(2)
producer, err := mqClient.CreateProducer(common.ProducerOptions{ producer, err := mqClient.CreateProducer(context.TODO(), common.ProducerOptions{
Topic: topics[0], Topic: topics[0],
}) })
defer producer.Close() defer producer.Close()
@ -99,7 +99,7 @@ func testConcurrentStreamAndSeekInclusive(t *testing.T, mqClient mqwrapper.Clien
cases := generateRandMessage(1024, 1000) cases := generateRandMessage(1024, 1000)
ids := sendMessages(context.Background(), t, producer, cases) ids := sendMessages(context.Background(), t, producer, cases)
consumer, err := mqClient.Subscribe(mqwrapper.ConsumerOptions{ consumer, err := mqClient.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{
Topic: topics[0], Topic: topics[0],
SubscriptionName: funcutil.RandomString(8), SubscriptionName: funcutil.RandomString(8),
SubscriptionInitialPosition: common.SubscriptionPositionUnknown, SubscriptionInitialPosition: common.SubscriptionPositionUnknown,
@ -124,7 +124,7 @@ func testConcurrentStreamAndSeekInclusive(t *testing.T, mqClient mqwrapper.Clien
func testConcurrentStreamAndSeekNoInclusive(t *testing.T, mqClient mqwrapper.Client) { func testConcurrentStreamAndSeekNoInclusive(t *testing.T, mqClient mqwrapper.Client) {
topics := getChannel(2) topics := getChannel(2)
producer, err := mqClient.CreateProducer(common.ProducerOptions{ producer, err := mqClient.CreateProducer(context.TODO(), common.ProducerOptions{
Topic: topics[0], Topic: topics[0],
}) })
defer producer.Close() defer producer.Close()
@ -133,7 +133,7 @@ func testConcurrentStreamAndSeekNoInclusive(t *testing.T, mqClient mqwrapper.Cli
cases := generateRandMessage(1024, 1000) cases := generateRandMessage(1024, 1000)
ids := sendMessages(context.Background(), t, producer, cases) ids := sendMessages(context.Background(), t, producer, cases)
consumer, err := mqClient.Subscribe(mqwrapper.ConsumerOptions{ consumer, err := mqClient.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{
Topic: topics[0], Topic: topics[0],
SubscriptionName: funcutil.RandomString(8), SubscriptionName: funcutil.RandomString(8),
SubscriptionInitialPosition: common.SubscriptionPositionUnknown, SubscriptionInitialPosition: common.SubscriptionPositionUnknown,
@ -158,7 +158,7 @@ func testConcurrentStreamAndSeekNoInclusive(t *testing.T, mqClient mqwrapper.Cli
func testConcurrentStreamAndSeekToLast(t *testing.T, mqClient mqwrapper.Client) { func testConcurrentStreamAndSeekToLast(t *testing.T, mqClient mqwrapper.Client) {
topics := getChannel(2) topics := getChannel(2)
producer, err := mqClient.CreateProducer(common.ProducerOptions{ producer, err := mqClient.CreateProducer(context.TODO(), common.ProducerOptions{
Topic: topics[0], Topic: topics[0],
}) })
defer producer.Close() defer producer.Close()
@ -167,7 +167,7 @@ func testConcurrentStreamAndSeekToLast(t *testing.T, mqClient mqwrapper.Client)
cases := generateRandMessage(1024, 1000) cases := generateRandMessage(1024, 1000)
sendMessages(context.Background(), t, producer, cases) sendMessages(context.Background(), t, producer, cases)
consumer, err := mqClient.Subscribe(mqwrapper.ConsumerOptions{ consumer, err := mqClient.Subscribe(context.TODO(), mqwrapper.ConsumerOptions{
Topic: topics[0], Topic: topics[0],
SubscriptionName: funcutil.RandomString(8), SubscriptionName: funcutil.RandomString(8),
SubscriptionInitialPosition: common.SubscriptionPositionUnknown, SubscriptionInitialPosition: common.SubscriptionPositionUnknown,

View File

@ -1,5 +1,7 @@
package msgstream package msgstream
import "context"
type WastedMockMsgStream struct { type WastedMockMsgStream struct {
MsgStream MsgStream
AsProducerFunc func(channels []string) AsProducerFunc func(channels []string)
@ -12,11 +14,11 @@ func NewWastedMockMsgStream() *WastedMockMsgStream {
return &WastedMockMsgStream{} return &WastedMockMsgStream{}
} }
func (m WastedMockMsgStream) AsProducer(channels []string) { func (m WastedMockMsgStream) AsProducer(ctx context.Context, channels []string) {
m.AsProducerFunc(channels) m.AsProducerFunc(channels)
} }
func (m WastedMockMsgStream) Broadcast(pack *MsgPack) (map[string][]MessageID, error) { func (m WastedMockMsgStream) Broadcast(ctx context.Context, pack *MsgPack) (map[string][]MessageID, error) {
return m.BroadcastMarkFunc(pack) return m.BroadcastMarkFunc(pack)
} }