diff --git a/internal/dataservice/server.go b/internal/dataservice/server.go index 3b1132cff1..ae5b530387 100644 --- a/internal/dataservice/server.go +++ b/internal/dataservice/server.go @@ -347,7 +347,7 @@ func (s *Server) startStatsChannel(ctx context.Context) { // try to restore last processed pos pos, err := s.loadStreamLastPos(streamTypeStats) if err == nil { - err = statsStream.Seek(pos) + err = statsStream.Seek([]*internalpb.MsgPosition{pos}) if err != nil { log.Error("Failed to seek to last pos for statsStream", zap.String("StatisChanName", Params.StatisticsChannelName), @@ -403,7 +403,7 @@ func (s *Server) startSegmentFlushChannel(ctx context.Context) { // try to restore last processed pos pos, err := s.loadStreamLastPos(streamTypeFlush) if err == nil { - err = flushStream.Seek(pos) + err = flushStream.Seek([]*internalpb.MsgPosition{pos}) if err != nil { log.Error("Failed to seek to last pos for segment flush Stream", zap.String("SegInfoChannelName", Params.SegmentInfoChannelName), diff --git a/internal/msgstream/mem_msgstream.go b/internal/msgstream/mem_msgstream.go index 2257852d8c..fdb3bbe4c7 100644 --- a/internal/msgstream/mem_msgstream.go +++ b/internal/msgstream/mem_msgstream.go @@ -203,6 +203,6 @@ func (mms *MemMsgStream) Chan() <-chan *MsgPack { return mms.receiveBuf } -func (mms *MemMsgStream) Seek(offset *MsgPosition) error { +func (mms *MemMsgStream) Seek(offset []*MsgPosition) error { return errors.New("MemMsgStream seek not implemented") } diff --git a/internal/msgstream/mq_msgstream.go b/internal/msgstream/mq_msgstream.go index 83cec0afec..c276f65bfb 100644 --- a/internal/msgstream/mq_msgstream.go +++ b/internal/msgstream/mq_msgstream.go @@ -14,6 +14,7 @@ package msgstream import ( "context" "errors" + "fmt" "path/filepath" "sync" "time" @@ -351,9 +352,12 @@ func (ms *mqMsgStream) Chan() <-chan *MsgPack { return ms.receiveBuf } -func (ms *mqMsgStream) Seek(mp *internalpb.MsgPosition) error { - if _, ok := ms.consumers[mp.ChannelName]; ok { - consumer := ms.consumers[mp.ChannelName] +func (ms *mqMsgStream) Seek(msgPositions []*internalpb.MsgPosition) error { + for _, mp := range msgPositions { + consumer, ok := ms.consumers[mp.ChannelName] + if !ok { + return fmt.Errorf("channel %s not subscribed", mp.ChannelName) + } messageID, err := ms.client.BytesToMsgID(mp.MsgID) if err != nil { return err @@ -362,10 +366,8 @@ func (ms *mqMsgStream) Seek(mp *internalpb.MsgPosition) error { if err != nil { return err } - return nil } - - return errors.New("msgStream seek fail") + return nil } type MqTtMsgStream struct { @@ -661,28 +663,20 @@ func checkTimeTickMsg(msg map[mqclient.Consumer]Timestamp, return 0, false } -func (ms *MqTtMsgStream) Seek(mp *internalpb.MsgPosition) error { - if len(mp.MsgID) == 0 { - return errors.New("when msgID's length equal to 0, please use AsConsumer interface") - } +// Seek to the specified position +func (ms *MqTtMsgStream) Seek(msgPositions []*internalpb.MsgPosition) error { var consumer mqclient.Consumer + var mp *MsgPosition var err error - var hasWatched bool - seekChannel := mp.ChannelName - subName := mp.MsgGroup - ms.consumerLock.Lock() - defer ms.consumerLock.Unlock() - consumer, hasWatched = ms.consumers[seekChannel] - - if hasWatched { - return errors.New("the channel should has not been subscribed") - } - fn := func() error { + if _, ok := ms.consumers[mp.ChannelName]; ok { + return fmt.Errorf("the channel should not been subscribed") + } + receiveChannel := make(chan mqclient.ConsumerMessage, ms.bufSize) consumer, err = ms.client.Subscribe(mqclient.ConsumerOptions{ - Topic: seekChannel, - SubscriptionName: subName, + Topic: mp.ChannelName, + SubscriptionName: mp.MsgGroup, SubscriptionInitialPosition: mqclient.SubscriptionPositionEarliest, Type: mqclient.KeyShared, MessageChannel: receiveChannel, @@ -691,70 +685,74 @@ func (ms *MqTtMsgStream) Seek(mp *internalpb.MsgPosition) error { return err } if consumer == nil { - err = errors.New("consumer is nil") - log.Debug("subscribe error", zap.String("error = ", err.Error())) - return err + return fmt.Errorf("consumer is nil") } seekMsgID, err := ms.client.BytesToMsgID(mp.MsgID) if err != nil { - log.Debug("convert messageID error", zap.String("error = ", err.Error())) return err } err = consumer.Seek(seekMsgID) if err != nil { - log.Debug("seek error ", zap.String("error = ", err.Error())) return err } return nil } - err = Retry(20, time.Millisecond*200, fn) - if err != nil { - errMsg := "Failed to seek, error = " + err.Error() - panic(errMsg) - } - ms.addConsumer(consumer, seekChannel) - //TODO: May cause problem - //if len(consumer.Chan()) == 0 { - // return nil - //} + ms.consumerLock.Lock() + defer ms.consumerLock.Unlock() - for { - select { - case <-ms.ctx.Done(): - return nil - case msg, ok := <-consumer.Chan(): - if !ok { - return errors.New("consumer closed") - } - consumer.Ack(msg) + for idx := range msgPositions { + mp = msgPositions[idx] + if len(mp.MsgID) == 0 { + return fmt.Errorf("when msgID's length equal to 0, please use AsConsumer interface") + } - headerMsg := commonpb.MsgHeader{} - err := proto.Unmarshal(msg.Payload(), &headerMsg) - if err != nil { - log.Error("Failed to unmarshal message header", zap.Error(err)) - } - tsMsg, err := ms.unmarshal.Unmarshal(msg.Payload(), headerMsg.Base.MsgType) - if err != nil { - log.Error("Failed to unmarshal tsMsg", zap.Error(err)) - } - if tsMsg.Type() == commonpb.MsgType_TimeTick { - if tsMsg.BeginTs() >= mp.Timestamp { - return nil + if err = Retry(20, time.Millisecond*200, fn); err != nil { + return fmt.Errorf("Failed to seek, error %s", err.Error()) + } + ms.addConsumer(consumer, mp.ChannelName) + + //TODO: May cause problem + //if len(consumer.Chan()) == 0 { + // return nil + //} + + runLoop := true + for runLoop { + select { + case <-ms.ctx.Done(): + return nil + case msg, ok := <-consumer.Chan(): + if !ok { + return fmt.Errorf("consumer closed") + } + consumer.Ack(msg) + + headerMsg := commonpb.MsgHeader{} + err := proto.Unmarshal(msg.Payload(), &headerMsg) + if err != nil { + return fmt.Errorf("Failed to unmarshal message header, err %s", err.Error()) + } + tsMsg, err := ms.unmarshal.Unmarshal(msg.Payload(), headerMsg.Base.MsgType) + if err != nil { + return fmt.Errorf("Failed to unmarshal tsMsg, err %s", err.Error()) + } + if tsMsg.Type() == commonpb.MsgType_TimeTick && tsMsg.BeginTs() >= mp.Timestamp { + runLoop = false + break + } else if tsMsg.BeginTs() > mp.Timestamp { + tsMsg.SetPosition(&MsgPosition{ + ChannelName: filepath.Base(msg.Topic()), + MsgID: msg.ID().Serialize(), + }) + ms.unsolvedBuf[consumer] = append(ms.unsolvedBuf[consumer], tsMsg) } - continue - } - if tsMsg.BeginTs() > mp.Timestamp { - tsMsg.SetPosition(&MsgPosition{ - ChannelName: filepath.Base(msg.Topic()), - MsgID: msg.ID().Serialize(), - }) - ms.unsolvedBuf[consumer] = append(ms.unsolvedBuf[consumer], tsMsg) } } } + return nil } //TODO test InMemMsgStream diff --git a/internal/msgstream/mq_msgstream_test.go b/internal/msgstream/mq_msgstream_test.go index b1d52b3fb9..c720c3be3f 100644 --- a/internal/msgstream/mq_msgstream_test.go +++ b/internal/msgstream/mq_msgstream_test.go @@ -246,13 +246,8 @@ func getPulsarTtOutputStreamAndSeek(pulsarAddress string, positions []*MsgPositi factory := ProtoUDFactory{} pulsarClient, _ := mqclient.NewPulsarClient(pulsar.ClientOptions{URL: pulsarAddress}) outputStream, _ := NewMqTtMsgStream(context.Background(), 100, 100, pulsarClient, factory.NewUnmarshalDispatcher()) - //outputStream.AsConsumer(consumerChannels, consumerSubName) - for _, pos := range positions { - pos.MsgGroup = funcutil.RandomString(4) - outputStream.Seek(pos) - } + outputStream.Seek(positions) outputStream.Start() - //outputStream.Start() return outputStream } diff --git a/internal/msgstream/msgstream.go b/internal/msgstream/msgstream.go index e4715cc2fb..035ba68f46 100644 --- a/internal/msgstream/msgstream.go +++ b/internal/msgstream/msgstream.go @@ -45,7 +45,7 @@ type MsgStream interface { Produce(*MsgPack) error Broadcast(*MsgPack) error Consume() *MsgPack - Seek(offset *MsgPosition) error + Seek(offset []*MsgPosition) error } type Factory interface { diff --git a/internal/msgstream/msgstream_mock.go b/internal/msgstream/msgstream_mock.go index 5f0d7b1bbb..288fd55c18 100644 --- a/internal/msgstream/msgstream_mock.go +++ b/internal/msgstream/msgstream_mock.go @@ -90,7 +90,7 @@ func (ms *SimpleMsgStream) Consume() *MsgPack { return <-ms.msgChan } -func (ms *SimpleMsgStream) Seek(offset *MsgPosition) error { +func (ms *SimpleMsgStream) Seek(offset []*MsgPosition) error { return nil } diff --git a/internal/querynode/data_sync_service.go b/internal/querynode/data_sync_service.go index b49f181f37..30322ae2c3 100644 --- a/internal/querynode/data_sync_service.go +++ b/internal/querynode/data_sync_service.go @@ -131,7 +131,7 @@ func (dsService *dataSyncService) initNodes() { } func (dsService *dataSyncService) seekSegment(position *internalpb.MsgPosition) error { - err := dsService.dmStream.Seek(position) + err := dsService.dmStream.Seek([]*internalpb.MsgPosition{position}) if err != nil { return err } diff --git a/internal/querynode/task.go b/internal/querynode/task.go index f6fb7063ad..b1896811e0 100644 --- a/internal/querynode/task.go +++ b/internal/querynode/task.go @@ -157,7 +157,7 @@ func (w *watchDmChannelsTask) Execute(ctx context.Context) error { ds.dmStream.AsConsumer(toDirSubChannels, consumeSubName) for _, pos := range toSeekInfo { - err := ds.dmStream.Seek(pos) + err := ds.dmStream.Seek([]*internalpb.MsgPosition{pos}) if err != nil { errMsg := "msgStream seek error :" + err.Error() log.Error(errMsg)