diff --git a/internal/mq/msgstream/mq_msgstream.go b/internal/mq/msgstream/mq_msgstream.go index 709e0bc968..9e25338e14 100644 --- a/internal/mq/msgstream/mq_msgstream.go +++ b/internal/mq/msgstream/mq_msgstream.go @@ -50,7 +50,7 @@ type mqMsgStream struct { repackFunc RepackFunc unmarshal UnmarshalDispatcher receiveBuf chan *MsgPack - wait *sync.WaitGroup + closeRWMutex *sync.RWMutex streamCancel func() bufSize int64 producerLock *sync.Mutex @@ -87,7 +87,7 @@ func NewMqMsgStream(ctx context.Context, streamCancel: streamCancel, producerLock: &sync.Mutex{}, consumerLock: &sync.Mutex{}, - wait: &sync.WaitGroup{}, + closeRWMutex: &sync.RWMutex{}, closed: 0, } @@ -184,8 +184,11 @@ func (ms *mqMsgStream) Start() { func (ms *mqMsgStream) Close() { ms.streamCancel() - ms.wait.Wait() - + ms.closeRWMutex.Lock() + defer ms.closeRWMutex.Unlock() + if !atomic.CompareAndSwapInt32(&ms.closed, 0, 1) { + return + } for _, producer := range ms.producers { if producer != nil { producer.Close() @@ -199,10 +202,8 @@ func (ms *mqMsgStream) Close() { ms.client.Close() - if !atomic.CompareAndSwapInt32(&ms.closed, 0, 1) { - return - } close(ms.receiveBuf) + } func (ms *mqMsgStream) ComputeProduceChannelIndexes(tsMsgs []TsMsg) [][]int32 { @@ -472,7 +473,11 @@ func (ms *mqMsgStream) getTsMsgFromConsumerMsg(msg mqwrapper.Message) (TsMsg, er } func (ms *mqMsgStream) receiveMsg(consumer mqwrapper.Consumer) { - defer ms.wait.Done() + ms.closeRWMutex.RLock() + defer ms.closeRWMutex.RUnlock() + if atomic.LoadInt32(&ms.closed) != 0 { + return + } for { select { @@ -524,7 +529,6 @@ func (ms *mqMsgStream) receiveMsg(consumer mqwrapper.Consumer) { func (ms *mqMsgStream) Chan() <-chan *MsgPack { ms.onceChan.Do(func() { for _, c := range ms.consumers { - ms.wait.Add(1) go ms.receiveMsg(c) } }) @@ -665,7 +669,11 @@ func (ms *MqTtMsgStream) Close() { } func (ms *MqTtMsgStream) bufMsgPackToChannel() { - defer ms.wait.Done() + ms.closeRWMutex.RLock() + defer ms.closeRWMutex.RUnlock() + if atomic.LoadInt32(&ms.closed) != 0 { + return + } chanTtMsgSync := make(map[mqwrapper.Consumer]bool) // block here until addConsumer @@ -926,7 +934,6 @@ func (ms *MqTtMsgStream) Seek(msgPositions []*internalpb.MsgPosition) error { func (ms *MqTtMsgStream) Chan() <-chan *MsgPack { ms.onceChan.Do(func() { if ms.consumers != nil { - ms.wait.Add(1) go ms.bufMsgPackToChannel() } })