enhance: [2.4] add the includeCurrentMsg param for the Seek method (#33743)

/kind improvement

- issue: #33325
- pr: #33326

Signed-off-by: SimFG <bang.fu@zilliz.com>
This commit is contained in:
SimFG 2024-06-11 15:01:55 +08:00 committed by GitHub
parent ee22750104
commit c331aa4ad3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 59 additions and 54 deletions

View File

@ -91,7 +91,7 @@ func (mtm *mockTtMsgStream) Broadcast(*msgstream.MsgPack) (map[string][]msgstrea
return nil, nil
}
func (mtm *mockTtMsgStream) Seek(ctx context.Context, offset []*msgpb.MsgPosition) error {
func (mtm *mockTtMsgStream) Seek(ctx context.Context, msgPositions []*msgstream.MsgPosition, includeCurrentMsg bool) error {
return nil
}

View File

@ -240,7 +240,7 @@ func TestMqMsgStream_SeekNotSubscribed(t *testing.T) {
ChannelName: "b",
},
}
err = m.Seek(context.Background(), p)
err = m.Seek(context.Background(), p, false)
assert.Error(t, err)
}
@ -403,7 +403,7 @@ func TestStream_RmqTtMsgStream_DuplicatedIDs(t *testing.T) {
outputStream, _ = msgstream.NewMqTtMsgStream(context.Background(), 100, 100, rmqClient, factory.NewUnmarshalDispatcher())
consumerSubName = funcutil.RandomString(8)
outputStream.AsConsumer(ctx, consumerChannels, consumerSubName, mqwrapper.SubscriptionPositionUnknown)
outputStream.Seek(ctx, receivedMsg.StartPositions)
outputStream.Seek(ctx, receivedMsg.StartPositions, false)
seekMsg := consumer(ctx, outputStream)
assert.Equal(t, len(seekMsg.Msgs), 1+2)
assert.EqualValues(t, seekMsg.Msgs[0].BeginTs(), 1)
@ -506,7 +506,7 @@ func TestStream_RmqTtMsgStream_Seek(t *testing.T) {
consumerSubName = funcutil.RandomString(8)
outputStream.AsConsumer(ctx, consumerChannels, consumerSubName, mqwrapper.SubscriptionPositionUnknown)
outputStream.Seek(ctx, receivedMsg3.StartPositions)
outputStream.Seek(ctx, receivedMsg3.StartPositions, false)
seekMsg := consumer(ctx, outputStream)
assert.Equal(t, len(seekMsg.Msgs), 3)
result := []uint64{14, 12, 13}
@ -565,7 +565,7 @@ func TestStream_RMqMsgStream_SeekInvalidMessage(t *testing.T) {
},
}
err = outputStream2.Seek(ctx, p)
err = outputStream2.Seek(ctx, p, false)
assert.NoError(t, err)
for i := 10; i < 20; i++ {

View File

@ -298,7 +298,7 @@ func (ms *simpleMockMsgStream) GetProduceChannels() []string {
return nil
}
func (ms *simpleMockMsgStream) Seek(ctx context.Context, offset []*msgstream.MsgPosition) error {
func (ms *simpleMockMsgStream) Seek(ctx context.Context, msgPositions []*msgstream.MsgPosition, includeCurrentMsg bool) error {
return nil
}

View File

@ -698,7 +698,7 @@ func (sd *shardDelegator) readDeleteFromMsgstream(ctx context.Context, position
}
ts = time.Now()
err = stream.Seek(context.TODO(), []*msgpb.MsgPosition{position})
err = stream.Seek(context.TODO(), []*msgpb.MsgPosition{position}, false)
if err != nil {
return nil, err
}

View File

@ -725,7 +725,7 @@ func (s *DelegatorDataSuite) TestLoadSegments() {
}, 10)
s.mq.EXPECT().AsConsumer(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil)
s.mq.EXPECT().Seek(mock.Anything, mock.Anything).Return(nil)
s.mq.EXPECT().Seek(mock.Anything, mock.Anything, mock.Anything).Return(nil)
s.mq.EXPECT().Close()
ch := make(chan *msgstream.MsgPack, 10)
close(ch)
@ -1173,7 +1173,7 @@ func (s *DelegatorDataSuite) TestReadDeleteFromMsgstream() {
defer cancel()
s.mq.EXPECT().AsConsumer(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil)
s.mq.EXPECT().Seek(mock.Anything, mock.Anything).Return(nil)
s.mq.EXPECT().Seek(mock.Anything, mock.Anything, mock.Anything).Return(nil)
s.mq.EXPECT().Close()
ch := make(chan *msgstream.MsgPack, 10)
s.mq.EXPECT().Chan().Return(ch)

View File

@ -306,7 +306,7 @@ func (suite *ServiceSuite) TestWatchDmChannelsInt64() {
// mocks
suite.factory.EXPECT().NewTtMsgStream(mock.Anything).Return(suite.msgStream, nil)
suite.msgStream.EXPECT().AsConsumer(mock.Anything, []string{suite.pchannel}, mock.Anything, mock.Anything).Return(nil)
suite.msgStream.EXPECT().Seek(mock.Anything, mock.Anything).Return(nil)
suite.msgStream.EXPECT().Seek(mock.Anything, mock.Anything, mock.Anything).Return(nil)
suite.msgStream.EXPECT().Chan().Return(suite.msgChan)
suite.msgStream.EXPECT().Close()
@ -358,7 +358,7 @@ func (suite *ServiceSuite) TestWatchDmChannelsVarchar() {
// mocks
suite.factory.EXPECT().NewTtMsgStream(mock.Anything).Return(suite.msgStream, nil)
suite.msgStream.EXPECT().AsConsumer(mock.Anything, []string{suite.pchannel}, mock.Anything, mock.Anything).Return(nil)
suite.msgStream.EXPECT().Seek(mock.Anything, mock.Anything).Return(nil)
suite.msgStream.EXPECT().Seek(mock.Anything, mock.Anything, mock.Anything).Return(nil)
suite.msgStream.EXPECT().Chan().Return(suite.msgChan)
suite.msgStream.EXPECT().Close()
@ -432,7 +432,7 @@ func (suite *ServiceSuite) TestWatchDmChannels_Failed() {
suite.factory.EXPECT().NewTtMsgStream(mock.Anything).Return(suite.msgStream, nil)
suite.msgStream.EXPECT().AsConsumer(mock.Anything, []string{suite.pchannel}, mock.Anything, mock.Anything).Return(nil)
suite.msgStream.EXPECT().Close().Return()
suite.msgStream.EXPECT().Seek(mock.Anything, mock.Anything).Return(errors.New("mock error")).Once()
suite.msgStream.EXPECT().Seek(mock.Anything, mock.Anything, mock.Anything).Return(errors.New("mock error")).Once()
status, err = suite.node.WatchDmChannels(ctx, req)
suite.NoError(err)

View File

@ -293,8 +293,10 @@ func (ms *FailMsgStream) Broadcast(*msgstream.MsgPack) (map[string][]msgstream.M
}
return nil, nil
}
func (ms *FailMsgStream) Consume() *msgstream.MsgPack { return nil }
func (ms *FailMsgStream) Seek(ctx context.Context, offset []*msgstream.MsgPosition) error { return nil }
func (ms *FailMsgStream) Consume() *msgstream.MsgPack { return nil }
func (ms *FailMsgStream) Seek(ctx context.Context, msgPositions []*msgstream.MsgPosition, includeCurrentMsg bool) error {
return nil
}
func (ms *FailMsgStream) GetLatestMsgID(channel string) (msgstream.MessageID, error) {
return nil, nil

View File

@ -103,7 +103,7 @@ func NewDispatcher(ctx context.Context,
return nil, err
}
err = stream.Seek(ctx, []*Pos{position})
err = stream.Seek(ctx, []*Pos{position}, false)
if err != nil {
stream.Close()
log.Error("seek failed", zap.Error(err))

View File

@ -766,7 +766,7 @@ func createAndSeekConsumer(ctx context.Context, t *testing.T, newer streamNewer,
consumer, err := newer(ctx)
assert.NoError(t, err)
consumer.AsConsumer(context.Background(), channels, funcutil.RandomString(8), mqwrapper.SubscriptionPositionUnknown)
err = consumer.Seek(context.Background(), seekPositions)
err = consumer.Seek(context.Background(), seekPositions, false)
assert.NoError(t, err)
return consumer
}

View File

@ -44,10 +44,10 @@ type MockMsgStream_AsConsumer_Call struct {
}
// AsConsumer is a helper method to define mock.On call
// - ctx context.Context
// - channels []string
// - subName string
// - position mqwrapper.SubscriptionInitialPosition
// - ctx context.Context
// - channels []string
// - subName string
// - position mqwrapper.SubscriptionInitialPosition
func (_e *MockMsgStream_Expecter) AsConsumer(ctx interface{}, channels interface{}, subName interface{}, position interface{}) *MockMsgStream_AsConsumer_Call {
return &MockMsgStream_AsConsumer_Call{Call: _e.mock.On("AsConsumer", ctx, channels, subName, position)}
}
@ -80,7 +80,7 @@ type MockMsgStream_AsProducer_Call struct {
}
// AsProducer is a helper method to define mock.On call
// - channels []string
// - channels []string
func (_e *MockMsgStream_Expecter) AsProducer(channels interface{}) *MockMsgStream_AsProducer_Call {
return &MockMsgStream_AsProducer_Call{Call: _e.mock.On("AsProducer", channels)}
}
@ -134,7 +134,7 @@ type MockMsgStream_Broadcast_Call struct {
}
// Broadcast is a helper method to define mock.On call
// - _a0 *MsgPack
// - _a0 *MsgPack
func (_e *MockMsgStream_Expecter) Broadcast(_a0 interface{}) *MockMsgStream_Broadcast_Call {
return &MockMsgStream_Broadcast_Call{Call: _e.mock.On("Broadcast", _a0)}
}
@ -219,7 +219,7 @@ type MockMsgStream_CheckTopicValid_Call struct {
}
// CheckTopicValid is a helper method to define mock.On call
// - channel string
// - channel string
func (_e *MockMsgStream_Expecter) CheckTopicValid(channel interface{}) *MockMsgStream_CheckTopicValid_Call {
return &MockMsgStream_CheckTopicValid_Call{Call: _e.mock.On("CheckTopicValid", channel)}
}
@ -284,7 +284,7 @@ type MockMsgStream_EnableProduce_Call struct {
}
// EnableProduce is a helper method to define mock.On call
// - can bool
// - can bool
func (_e *MockMsgStream_Expecter) EnableProduce(can interface{}) *MockMsgStream_EnableProduce_Call {
return &MockMsgStream_EnableProduce_Call{Call: _e.mock.On("EnableProduce", can)}
}
@ -338,7 +338,7 @@ type MockMsgStream_GetLatestMsgID_Call struct {
}
// GetLatestMsgID is a helper method to define mock.On call
// - channel string
// - channel string
func (_e *MockMsgStream_Expecter) GetLatestMsgID(channel interface{}) *MockMsgStream_GetLatestMsgID_Call {
return &MockMsgStream_GetLatestMsgID_Call{Call: _e.mock.On("GetLatestMsgID", channel)}
}
@ -423,7 +423,7 @@ type MockMsgStream_Produce_Call struct {
}
// Produce is a helper method to define mock.On call
// - _a0 *MsgPack
// - _a0 *MsgPack
func (_e *MockMsgStream_Expecter) Produce(_a0 interface{}) *MockMsgStream_Produce_Call {
return &MockMsgStream_Produce_Call{Call: _e.mock.On("Produce", _a0)}
}
@ -445,13 +445,13 @@ func (_c *MockMsgStream_Produce_Call) RunAndReturn(run func(*MsgPack) error) *Mo
return _c
}
// Seek provides a mock function with given fields: ctx, offset
func (_m *MockMsgStream) Seek(ctx context.Context, offset []*msgpb.MsgPosition) error {
ret := _m.Called(ctx, offset)
// Seek provides a mock function with given fields: ctx, msgPositions, includeCurrentMsg
func (_m *MockMsgStream) Seek(ctx context.Context, msgPositions []*msgpb.MsgPosition, includeCurrentMsg bool) error {
ret := _m.Called(ctx, msgPositions, includeCurrentMsg)
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, []*msgpb.MsgPosition) error); ok {
r0 = rf(ctx, offset)
if rf, ok := ret.Get(0).(func(context.Context, []*msgpb.MsgPosition, bool) error); ok {
r0 = rf(ctx, msgPositions, includeCurrentMsg)
} else {
r0 = ret.Error(0)
}
@ -465,15 +465,16 @@ type MockMsgStream_Seek_Call struct {
}
// Seek is a helper method to define mock.On call
// - ctx context.Context
// - offset []*msgpb.MsgPosition
func (_e *MockMsgStream_Expecter) Seek(ctx interface{}, offset interface{}) *MockMsgStream_Seek_Call {
return &MockMsgStream_Seek_Call{Call: _e.mock.On("Seek", ctx, offset)}
// - ctx context.Context
// - msgPositions []*msgpb.MsgPosition
// - includeCurrentMsg bool
func (_e *MockMsgStream_Expecter) Seek(ctx interface{}, msgPositions interface{}, includeCurrentMsg interface{}) *MockMsgStream_Seek_Call {
return &MockMsgStream_Seek_Call{Call: _e.mock.On("Seek", ctx, msgPositions, includeCurrentMsg)}
}
func (_c *MockMsgStream_Seek_Call) Run(run func(ctx context.Context, offset []*msgpb.MsgPosition)) *MockMsgStream_Seek_Call {
func (_c *MockMsgStream_Seek_Call) Run(run func(ctx context.Context, msgPositions []*msgpb.MsgPosition, includeCurrentMsg bool)) *MockMsgStream_Seek_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].([]*msgpb.MsgPosition))
run(args[0].(context.Context), args[1].([]*msgpb.MsgPosition), args[2].(bool))
})
return _c
}
@ -483,7 +484,7 @@ func (_c *MockMsgStream_Seek_Call) Return(_a0 error) *MockMsgStream_Seek_Call {
return _c
}
func (_c *MockMsgStream_Seek_Call) RunAndReturn(run func(context.Context, []*msgpb.MsgPosition) error) *MockMsgStream_Seek_Call {
func (_c *MockMsgStream_Seek_Call) RunAndReturn(run func(context.Context, []*msgpb.MsgPosition, bool) error) *MockMsgStream_Seek_Call {
_c.Call.Return(run)
return _c
}
@ -499,7 +500,7 @@ type MockMsgStream_SetRepackFunc_Call struct {
}
// SetRepackFunc is a helper method to define mock.On call
// - repackFunc RepackFunc
// - repackFunc RepackFunc
func (_e *MockMsgStream_Expecter) SetRepackFunc(repackFunc interface{}) *MockMsgStream_SetRepackFunc_Call {
return &MockMsgStream_SetRepackFunc_Call{Call: _e.mock.On("SetRepackFunc", repackFunc)}
}

View File

@ -145,7 +145,7 @@ func TestStream_KafkaMsgStream_SeekToLast(t *testing.T) {
defer outputStream2.Close()
assert.NoError(t, err)
err = outputStream2.Seek(ctx, []*msgpb.MsgPosition{seekPosition})
err = outputStream2.Seek(ctx, []*msgpb.MsgPosition{seekPosition}, false)
assert.NoError(t, err)
cnt := 0
@ -482,6 +482,6 @@ func getKafkaTtOutputStreamAndSeek(ctx context.Context, kafkaAddress string, pos
consumerName = append(consumerName, c.ChannelName)
}
outputStream.AsConsumer(context.Background(), consumerName, funcutil.RandomString(8), mqwrapper.SubscriptionPositionUnknown)
outputStream.Seek(context.Background(), positions)
outputStream.Seek(context.Background(), positions, false)
return outputStream
}

View File

@ -473,7 +473,7 @@ func (ms *mqMsgStream) Chan() <-chan *MsgPack {
// Seek reset the subscription associated with this consumer to a specific position, the seek position is exclusive
// User has to ensure mq_msgstream is not closed before seek, and the seek position is already written.
func (ms *mqMsgStream) Seek(ctx context.Context, msgPositions []*msgpb.MsgPosition) error {
func (ms *mqMsgStream) Seek(ctx context.Context, msgPositions []*MsgPosition, includeCurrentMsg bool) error {
for _, mp := range msgPositions {
consumer, ok := ms.consumers[mp.ChannelName]
if !ok {
@ -493,8 +493,8 @@ func (ms *mqMsgStream) Seek(ctx context.Context, msgPositions []*msgpb.MsgPositi
}
}
log.Info("MsgStream seek begin", zap.String("channel", mp.ChannelName), zap.Any("MessageID", mp.MsgID))
err = consumer.Seek(messageID, false)
log.Info("MsgStream seek begin", zap.String("channel", mp.ChannelName), zap.Any("MessageID", mp.MsgID), zap.Bool("includeCurrentMsg", includeCurrentMsg))
err = consumer.Seek(messageID, includeCurrentMsg)
if err != nil {
log.Warn("Failed to seek", zap.String("channel", mp.ChannelName), zap.Error(err))
return err
@ -840,7 +840,7 @@ func (ms *MqTtMsgStream) allChanReachSameTtMsg(chanTtMsgSync map[mqwrapper.Consu
}
// Seek to the specified position
func (ms *MqTtMsgStream) Seek(ctx context.Context, msgPositions []*msgpb.MsgPosition) error {
func (ms *MqTtMsgStream) Seek(ctx context.Context, msgPositions []*MsgPosition, includeCurrentMsg bool) error {
var consumer mqwrapper.Consumer
var mp *MsgPosition
var err error

View File

@ -517,7 +517,7 @@ func TestStream_PulsarMsgStream_SeekToLast(t *testing.T) {
defer outputStream2.Close()
assert.NoError(t, err)
err = outputStream2.Seek(ctx, []*msgpb.MsgPosition{seekPosition})
err = outputStream2.Seek(ctx, []*msgpb.MsgPosition{seekPosition}, false)
assert.NoError(t, err)
cnt := 0
@ -946,7 +946,7 @@ func TestStream_MqMsgStream_Seek(t *testing.T) {
pulsarClient, _ := pulsarwrapper.NewClient(DefaultPulsarTenant, DefaultPulsarNamespace, pulsar.ClientOptions{URL: pulsarAddress})
outputStream2, _ := NewMqMsgStream(ctx, 100, 100, pulsarClient, factory.NewUnmarshalDispatcher())
outputStream2.AsConsumer(ctx, consumerChannels, consumerSubName, mqwrapper.SubscriptionPositionEarliest)
outputStream2.Seek(ctx, []*msgpb.MsgPosition{seekPosition})
outputStream2.Seek(ctx, []*msgpb.MsgPosition{seekPosition}, false)
for i := 6; i < 10; i++ {
result := consumer(ctx, outputStream2)
@ -1001,7 +1001,7 @@ func TestStream_MqMsgStream_SeekInvalidMessage(t *testing.T) {
},
}
err = outputStream2.Seek(ctx, p)
err = outputStream2.Seek(ctx, p, false)
assert.NoError(t, err)
for i := 10; i < 20; i++ {
@ -1070,15 +1070,15 @@ func TestSTream_MqMsgStream_SeekBadMessageID(t *testing.T) {
}
paramtable.Get().Save(paramtable.Get().MQCfg.IgnoreBadPosition.Key, "false")
err = outputStream2.Seek(ctx, p)
err = outputStream2.Seek(ctx, p, false)
assert.Error(t, err)
err = outputStream3.Seek(ctx, p)
err = outputStream3.Seek(ctx, p, false)
assert.Error(t, err)
paramtable.Get().Save(paramtable.Get().MQCfg.IgnoreBadPosition.Key, "true")
err = outputStream2.Seek(ctx, p)
err = outputStream2.Seek(ctx, p, false)
assert.NoError(t, err)
err = outputStream3.Seek(ctx, p)
err = outputStream3.Seek(ctx, p, false)
assert.NoError(t, err)
}
@ -1466,7 +1466,7 @@ func getPulsarTtOutputStreamAndSeek(ctx context.Context, pulsarAddress string, p
consumerName = append(consumerName, c.ChannelName)
}
outputStream.AsConsumer(context.Background(), consumerName, funcutil.RandomString(8), mqwrapper.SubscriptionPositionUnknown)
outputStream.Seek(context.Background(), positions)
outputStream.Seek(context.Background(), positions, false)
return outputStream
}

View File

@ -63,7 +63,9 @@ type MsgStream interface {
AsConsumer(ctx context.Context, channels []string, subName string, position mqwrapper.SubscriptionInitialPosition) error
Chan() <-chan *MsgPack
Seek(ctx context.Context, offset []*MsgPosition) error
// Seek consume message from the specified position
// includeCurrentMsg indicates whether to consume the current message, and in the milvus system, it should be always false
Seek(ctx context.Context, msgPositions []*MsgPosition, includeCurrentMsg bool) error
GetLatestMsgID(channel string) (MessageID, error)
CheckTopicValid(channel string) error