Add timeout for reader next (#12308)

Signed-off-by: godchen <qingxiang.chen@zilliz.com>
godchen 2021-11-26 22:45:24 +08:00 committed by GitHub
parent e9be4a81ba
commit f31ed089b5
6 changed files with 57 additions and 26 deletions
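
Taken together, the change makes end-of-stream handling explicit: Next no longer returns nil, nil when a reader has no more data; callers first ask HasNext and bound each Next call with a context timeout, as the segment loader diff below does. A minimal sketch of that calling pattern, assuming the msgstream package from this repository (the function name drainChannel is hypothetical; the 10-second timeout mirrors the loader's timeoutForEachRead):

package example

import (
	"context"
	"time"

	"github.com/milvus-io/milvus/internal/msgstream"
)

// drainChannel sketches the read loop enabled by this change: check HasNext
// before each read and bound every Next call with a per-read timeout, so an
// exhausted or stalled reader cannot block the caller indefinitely.
func drainChannel(ctx context.Context, stream msgstream.MsgStream, channelName string) error {
	for stream.HasNext(channelName) {
		readCtx, cancel := context.WithTimeout(ctx, 10*time.Second) // timeout value is illustrative
		tsMsg, err := stream.Next(readCtx, channelName)
		cancel()
		if err != nil {
			// With this change, Next surfaces errors (e.g. a context deadline)
			// instead of returning nil, nil when nothing is readable.
			return err
		}
		if tsMsg == nil {
			continue
		}
		_ = tsMsg // handle the message here
	}
	return nil
}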

View File

@@ -104,6 +104,9 @@ func (mtm *mockTtMsgStream) SeekReaders(msgPositions []*internalpb.MsgPosition)
 func (mtm *mockTtMsgStream) Next(ctx context.Context, channelName string) (msgstream.TsMsg, error) {
 	return nil, nil
 }
+func (mtm *mockTtMsgStream) HasNext(channelName string) bool {
+	return true
+}
 
 func TestNewDmInputNode(t *testing.T) {
 	ctx := context.Background()

View File

@@ -575,27 +575,29 @@ func (ms *mqMsgStream) Next(ctx context.Context, channelName string) (TsMsg, error) {
 	if !ok {
 		return nil, fmt.Errorf("reader for channel %s is not exist", channelName)
 	}
-	if reader.HasNext() {
-		msg, err := reader.Next(ctx)
-		if err != nil {
-			return nil, err
-		}
-		tsMsg, err := ms.getTsMsgFromConsumerMsg(msg)
-		if err != nil {
-			log.Error("Failed to getTsMsgFromConsumerMsg", zap.Error(err))
-			return nil, errors.New("Failed to getTsMsgFromConsumerMsg")
-		}
-		pos := tsMsg.Position()
-		tsMsg.SetPosition(&MsgPosition{
-			ChannelName: pos.ChannelName,
-			MsgID:       pos.MsgID,
-			Timestamp:   tsMsg.BeginTs(),
-		})
-		return tsMsg, nil
-	}
-	log.Debug("All data has been read, there is no more data", zap.String("channel", channelName))
-	return nil, nil
+	msg, err := reader.Next(ctx)
+	if err != nil {
+		return nil, err
+	}
+	tsMsg, err := ms.getTsMsgFromConsumerMsg(msg)
+	if err != nil {
+		log.Error("Failed to getTsMsgFromConsumerMsg", zap.Error(err))
+		return nil, errors.New("Failed to getTsMsgFromConsumerMsg")
+	}
+	pos := tsMsg.Position()
+	tsMsg.SetPosition(&MsgPosition{
+		ChannelName: pos.ChannelName,
+		MsgID:       pos.MsgID,
+		Timestamp:   tsMsg.BeginTs(),
+	})
+	return tsMsg, nil
+}
+
+func (ms *mqMsgStream) HasNext(channelName string) bool {
+	reader, ok := ms.readers[channelName]
+	if !ok {
+		return false
+	}
+	return reader.HasNext()
 }
 
 // Seek reset the subscription associated with this consumer to a specific position, the seek position is exclusive

View File

@@ -1296,6 +1296,8 @@ func TestStream_MqMsgStream_Reader(t *testing.T) {
 	defer readStream.Close()
 	var seekPosition *internalpb.MsgPosition
 	for i := 0; i < n; i++ {
+		hasNext := readStream.HasNext(c)
+		assert.True(t, hasNext)
 		result, err := readStream.Next(ctx, c)
 		assert.Nil(t, err)
 		assert.Equal(t, result.ID(), int64(i))
@@ -1303,8 +1305,12 @@ func TestStream_MqMsgStream_Reader(t *testing.T) {
 			seekPosition = result.Position()
 		}
 	}
-	result, err := readStream.Next(ctx, c)
-	assert.Nil(t, err)
+	hasNext := readStream.HasNext(c)
+	assert.False(t, hasNext)
+	timeoutCtx1, cancel := context.WithTimeout(ctx, 3*time.Second)
+	defer cancel()
+	result, err := readStream.Next(timeoutCtx1, c)
+	assert.NotNil(t, err)
 	assert.Nil(t, result)
 
 	readStream2 := getPulsarReader(pulsarAddress, readerChannels)
@@ -1312,12 +1318,18 @@ func TestStream_MqMsgStream_Reader(t *testing.T) {
 	readStream2.SeekReaders([]*internalpb.MsgPosition{seekPosition})
 
 	for i := p; i < 10; i++ {
+		hasNext := readStream2.HasNext(c)
+		assert.True(t, hasNext)
 		result, err := readStream2.Next(ctx, c)
 		assert.Nil(t, err)
 		assert.Equal(t, result.ID(), int64(i))
 	}
-	result2, err := readStream2.Next(ctx, c)
-	assert.Nil(t, err)
+	hasNext = readStream2.HasNext(c)
+	assert.False(t, hasNext)
+	timeoutCtx2, cancel := context.WithTimeout(ctx, 3*time.Second)
+	defer cancel()
+	result2, err := readStream2.Next(timeoutCtx2, c)
+	assert.NotNil(t, err)
 	assert.Nil(t, result2)
 }

View File

@@ -69,6 +69,7 @@ type MsgStream interface {
 	BroadcastMark(*MsgPack) (map[string][]MessageID, error)
 	Consume() *MsgPack
 	Next(ctx context.Context, channelName string) (TsMsg, error)
+	HasNext(channelName string) bool
 	Seek(offset []*MsgPosition) error
 	SeekReaders(msgPositions []*internalpb.MsgPosition) error
 }

View File

@@ -288,6 +288,10 @@ func (ms *simpleMockMsgStream) Next(ctx context.Context, channelName string) (msgstream.TsMsg, error) {
 	return nil, nil
 }
 
+func (ms *simpleMockMsgStream) HasNext(channelName string) bool {
+	return true
+}
+
 func (ms *simpleMockMsgStream) AsConsumerWithPosition(channels []string, subName string, position mqclient.SubscriptionInitialPosition) {
 }

View File

@@ -18,6 +18,7 @@ import (
 	"path"
 	"strconv"
 	"sync"
+	"time"
 
 	"go.uber.org/zap"
@@ -37,6 +38,8 @@ import (
 	"github.com/milvus-io/milvus/internal/util/funcutil"
 )
 
+const timeoutForEachRead = 10 * time.Second
+
 // segmentLoader is only responsible for loading the field data from binlog
 type segmentLoader struct {
 	historicalReplica ReplicaInterface
@@ -458,24 +461,30 @@ func (loader *segmentLoader) FromDmlCPLoadDelete(ctx context.Context, collection
 		deleteOffset: make(map[UniqueID]int64),
 	}
 	log.Debug("start read msg from stream reader")
-	for {
+	for stream.HasNext(pChannelName) {
+		ctx, cancel := context.WithTimeout(ctx, timeoutForEachRead)
 		tsMsg, err := stream.Next(ctx, pChannelName)
 		if err != nil {
+			cancel()
 			return err
 		}
 		if tsMsg == nil {
-			break
+			cancel()
+			continue
 		}
 		if tsMsg.Type() == commonpb.MsgType_Delete {
 			dmsg := tsMsg.(*msgstream.DeleteMsg)
 			if dmsg.CollectionID != collectionID {
+				cancel()
 				continue
 			}
 			log.Debug("delete pk", zap.Any("pk", dmsg.PrimaryKeys))
 			processDeleteMessages(loader.historicalReplica, dmsg, delData)
 		}
+		cancel()
 	}
+	log.Debug("All data has been read, there is no more data", zap.String("channel", pChannelName))
 
 	for segmentID, pks := range delData.deleteIDs {
 		segment, err := loader.historicalReplica.getSegmentByID(segmentID)
 		if err != nil {