From bdd39c0623a153a478e3579a0f10dea204cbbaf5 Mon Sep 17 00:00:00 2001 From: "zhenshan.cao" Date: Mon, 29 Nov 2021 14:31:18 +0800 Subject: [PATCH] Fix bug: check message payload before unmarshaling (#12315) Signed-off-by: zhenshan.cao --- internal/msgstream/mq_msgstream.go | 25 +++++++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/internal/msgstream/mq_msgstream.go b/internal/msgstream/mq_msgstream.go index 2c33d76cf4..5bcb2cc1d6 100644 --- a/internal/msgstream/mq_msgstream.go +++ b/internal/msgstream/mq_msgstream.go @@ -25,7 +25,6 @@ import ( "time" "github.com/golang/protobuf/proto" - "github.com/opentracing/opentracing-go" "go.uber.org/zap" "github.com/milvus-io/milvus/internal/log" @@ -34,6 +33,7 @@ import ( "github.com/milvus-io/milvus/internal/util/mqclient" "github.com/milvus-io/milvus/internal/util/retry" "github.com/milvus-io/milvus/internal/util/trace" + "github.com/opentracing/opentracing-go" ) var _ MsgStream = (*mqMsgStream)(nil) @@ -485,13 +485,19 @@ func (ms *mqMsgStream) Consume() *MsgPack { func (ms *mqMsgStream) getTsMsgFromConsumerMsg(msg mqclient.Message) (TsMsg, error) { header := commonpb.MsgHeader{} + if msg.Payload() == nil { + return nil, fmt.Errorf("failed to unmarshal message header, payload is empty") + } err := proto.Unmarshal(msg.Payload(), &header) if err != nil { - return nil, fmt.Errorf("Failed to unmarshal message header, err %s", err.Error()) + return nil, fmt.Errorf("failed to unmarshal message header, err %s", err.Error()) + } + if header.Base == nil { + return nil, fmt.Errorf("failed to unmarshal message, header is uncomplete") } tsMsg, err := ms.unmarshal.Unmarshal(msg.Payload(), header.Base.MsgType) if err != nil { - return nil, fmt.Errorf("Failed to unmarshal tsMsg, err %s", err.Error()) + return nil, fmt.Errorf("failed to unmarshal tsMsg, err %s", err.Error()) } // set msg info to tsMsg @@ -515,7 +521,10 @@ func (ms *mqMsgStream) receiveMsg(consumer mqclient.Consumer) { return } consumer.Ack(msg) - + if msg.Payload() == nil { + log.Warn("MqMsgStream get msg whose payload is nil") + continue + } tsMsg, err := ms.getTsMsgFromConsumerMsg(msg) if err != nil { log.Error("Failed to getTsMsgFromConsumerMsg", zap.Error(err)) @@ -579,6 +588,10 @@ func (ms *mqMsgStream) Next(ctx context.Context, channelName string) (TsMsg, err if err != nil { return nil, err } + if msg.Payload() == nil { + log.Warn("mqMsgStream reader Next get msg whose payload is nil") + return nil, nil + } tsMsg, err := ms.getTsMsgFromConsumerMsg(msg) if err != nil { log.Error("Failed to getTsMsgFromConsumerMsg", zap.Error(err)) @@ -868,6 +881,10 @@ func (ms *MqTtMsgStream) consumeToTtMsg(consumer mqclient.Consumer) { } consumer.Ack(msg) + if msg.Payload() == nil { + log.Warn("MqTtMsgStream get msg whose payload is nil") + continue + } tsMsg, err := ms.getTsMsgFromConsumerMsg(msg) if err != nil { log.Error("Failed to getTsMsgFromConsumerMsg", zap.Error(err))