Fix Wait Group Race in Msgstream (#17931)

Signed-off-by: xiaofan-luan <xiaofan.luan@zilliz.com>
This commit is contained in:
Xiaofan 2022-06-30 10:38:18 +08:00 committed by GitHub
parent 4612e0c7ea
commit a803f9e4b0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -50,7 +50,7 @@ type mqMsgStream struct {
repackFunc RepackFunc
unmarshal UnmarshalDispatcher
receiveBuf chan *MsgPack
wait *sync.WaitGroup
closeRWMutex *sync.RWMutex
streamCancel func()
bufSize int64
producerLock *sync.Mutex
@ -87,7 +87,7 @@ func NewMqMsgStream(ctx context.Context,
streamCancel: streamCancel,
producerLock: &sync.Mutex{},
consumerLock: &sync.Mutex{},
wait: &sync.WaitGroup{},
closeRWMutex: &sync.RWMutex{},
closed: 0,
}
@ -184,8 +184,11 @@ func (ms *mqMsgStream) Start() {
func (ms *mqMsgStream) Close() {
ms.streamCancel()
ms.wait.Wait()
ms.closeRWMutex.Lock()
defer ms.closeRWMutex.Unlock()
if !atomic.CompareAndSwapInt32(&ms.closed, 0, 1) {
return
}
for _, producer := range ms.producers {
if producer != nil {
producer.Close()
@ -199,10 +202,8 @@ func (ms *mqMsgStream) Close() {
ms.client.Close()
if !atomic.CompareAndSwapInt32(&ms.closed, 0, 1) {
return
}
close(ms.receiveBuf)
}
func (ms *mqMsgStream) ComputeProduceChannelIndexes(tsMsgs []TsMsg) [][]int32 {
@ -472,7 +473,11 @@ func (ms *mqMsgStream) getTsMsgFromConsumerMsg(msg mqwrapper.Message) (TsMsg, er
}
func (ms *mqMsgStream) receiveMsg(consumer mqwrapper.Consumer) {
defer ms.wait.Done()
ms.closeRWMutex.RLock()
defer ms.closeRWMutex.RUnlock()
if atomic.LoadInt32(&ms.closed) != 0 {
return
}
for {
select {
@ -524,7 +529,6 @@ func (ms *mqMsgStream) receiveMsg(consumer mqwrapper.Consumer) {
func (ms *mqMsgStream) Chan() <-chan *MsgPack {
ms.onceChan.Do(func() {
for _, c := range ms.consumers {
ms.wait.Add(1)
go ms.receiveMsg(c)
}
})
@ -665,7 +669,11 @@ func (ms *MqTtMsgStream) Close() {
}
func (ms *MqTtMsgStream) bufMsgPackToChannel() {
defer ms.wait.Done()
ms.closeRWMutex.RLock()
defer ms.closeRWMutex.RUnlock()
if atomic.LoadInt32(&ms.closed) != 0 {
return
}
chanTtMsgSync := make(map[mqwrapper.Consumer]bool)
// block here until addConsumer
@ -926,7 +934,6 @@ func (ms *MqTtMsgStream) Seek(msgPositions []*internalpb.MsgPosition) error {
func (ms *MqTtMsgStream) Chan() <-chan *MsgPack {
ms.onceChan.Do(func() {
if ms.consumers != nil {
ms.wait.Add(1)
go ms.bufMsgPackToChannel()
}
})