diff --git a/internal/distributed/streaming/internal/consumer/consumer_impl.go b/internal/distributed/streaming/internal/consumer/consumer_impl.go index d8488d0387..918c1255bc 100644 --- a/internal/distributed/streaming/internal/consumer/consumer_impl.go +++ b/internal/distributed/streaming/internal/consumer/consumer_impl.go @@ -74,10 +74,11 @@ func (rc *resumableConsumerImpl) resumeLoop() { // consumer need to resume when error occur, so message handler shouldn't close if the internal consumer encounter failure. nopCloseMH := nopCloseHandler{ Handler: rc.mh, - HandleInterceptor: func(msg message.ImmutableMessage, handle func(message.ImmutableMessage)) { + HandleInterceptor: func(ctx context.Context, msg message.ImmutableMessage, handle handleFunc) (bool, error) { g := rc.metrics.StartConsume(msg.EstimateSize()) - handle(msg) + ok, err := handle(ctx, msg) g.Finish() + return ok, err }, } diff --git a/internal/distributed/streaming/internal/consumer/consumer_test.go b/internal/distributed/streaming/internal/consumer/consumer_test.go index 799f404ec4..c6a85792b3 100644 --- a/internal/distributed/streaming/internal/consumer/consumer_test.go +++ b/internal/distributed/streaming/internal/consumer/consumer_test.go @@ -26,7 +26,7 @@ func TestResumableConsumer(t *testing.T) { rc := NewResumableConsumer(func(ctx context.Context, opts *handler.ConsumerOptions) (consumer.Consumer, error) { if i == 0 { i++ - opts.MessageHandler.Handle(message.NewImmutableMesasge( + ok, err := opts.MessageHandler.Handle(context.Background(), message.NewImmutableMesasge( walimplstest.NewTestMessageID(123), []byte("payload"), map[string]string{ @@ -36,6 +36,8 @@ func TestResumableConsumer(t *testing.T) { "_v": "1", "_lc": walimplstest.NewTestMessageID(123).Marshal(), })) + assert.True(t, ok) + assert.NoError(t, err) return c, nil } else if i == 1 { i++ @@ -76,7 +78,7 @@ func TestHandler(t *testing.T) { hNop := nopCloseHandler{ Handler: message.ChanMessageHandler(ch), } - hNop.Handle(nil) + hNop.Handle(context.Background(), nil) assert.Nil(t, <-ch) hNop.Close() select { diff --git a/internal/distributed/streaming/internal/consumer/handler.go b/internal/distributed/streaming/internal/consumer/handler.go index c52e2d33f1..d106b9da4d 100644 --- a/internal/distributed/streaming/internal/consumer/handler.go +++ b/internal/distributed/streaming/internal/consumer/handler.go @@ -1,20 +1,25 @@ package consumer -import "github.com/milvus-io/milvus/pkg/streaming/util/message" +import ( + "context" + + "github.com/milvus-io/milvus/pkg/streaming/util/message" +) + +type handleFunc func(ctx context.Context, msg message.ImmutableMessage) (bool, error) // nopCloseHandler is a handler that do nothing when close. type nopCloseHandler struct { message.Handler - HandleInterceptor func(msg message.ImmutableMessage, handle func(message.ImmutableMessage)) + HandleInterceptor func(ctx context.Context, msg message.ImmutableMessage, handle handleFunc) (bool, error) } // Handle is the callback for handling message. -func (nch nopCloseHandler) Handle(msg message.ImmutableMessage) { +func (nch nopCloseHandler) Handle(ctx context.Context, msg message.ImmutableMessage) (bool, error) { if nch.HandleInterceptor != nil { - nch.HandleInterceptor(msg, nch.Handler.Handle) - return + return nch.HandleInterceptor(ctx, msg, nch.Handler.Handle) } - nch.Handler.Handle(msg) + return nch.Handler.Handle(ctx, msg) } // Close is called after all messages are handled or handling is interrupted. diff --git a/internal/distributed/streaming/internal/consumer/message_handler.go b/internal/distributed/streaming/internal/consumer/message_handler.go index 811a537a4e..538052ee17 100644 --- a/internal/distributed/streaming/internal/consumer/message_handler.go +++ b/internal/distributed/streaming/internal/consumer/message_handler.go @@ -1,24 +1,28 @@ package consumer import ( + "context" + "github.com/milvus-io/milvus/pkg/streaming/util/message" ) -// timeTickOrderMessageHandler is a message handler that will do metrics and record the last sent message id. +// timeTickOrderMessageHandler is a message handler that will record the last sent message id. type timeTickOrderMessageHandler struct { inner message.Handler lastConfirmedMessageID message.MessageID lastTimeTick uint64 } -func (mh *timeTickOrderMessageHandler) Handle(msg message.ImmutableMessage) { +func (mh *timeTickOrderMessageHandler) Handle(ctx context.Context, msg message.ImmutableMessage) (bool, error) { lastConfirmedMessageID := msg.LastConfirmedMessageID() timetick := msg.TimeTick() - mh.inner.Handle(msg) - - mh.lastConfirmedMessageID = lastConfirmedMessageID - mh.lastTimeTick = timetick + ok, err := mh.inner.Handle(ctx, msg) + if ok { + mh.lastConfirmedMessageID = lastConfirmedMessageID + mh.lastTimeTick = timetick + } + return ok, err } func (mh *timeTickOrderMessageHandler) Close() { diff --git a/internal/streamingnode/client/handler/consumer/consumer_impl.go b/internal/streamingnode/client/handler/consumer/consumer_impl.go index 2ba136c659..dd8cd92ac8 100644 --- a/internal/streamingnode/client/handler/consumer/consumer_impl.go +++ b/internal/streamingnode/client/handler/consumer/consumer_impl.go @@ -40,7 +40,7 @@ func CreateConsumer( opts *ConsumerOptions, handlerClient streamingpb.StreamingNodeHandlerServiceClient, ) (Consumer, error) { - ctx, err := createConsumeRequest(ctx, opts) + ctxWithReq, err := createConsumeRequest(ctx, opts) if err != nil { return nil, err } @@ -48,7 +48,7 @@ func CreateConsumer( // TODO: configurable or auto adjust grpc.MaxCallRecvMsgSize // The messages are always managed by milvus cluster, so the size of message shouldn't be controlled here // to avoid infinitely blocks. - streamClient, err := handlerClient.Consume(ctx, grpc.MaxCallRecvMsgSize(math.MaxInt32)) + streamClient, err := handlerClient.Consume(ctxWithReq, grpc.MaxCallRecvMsgSize(math.MaxInt32)) if err != nil { return nil, err } @@ -64,6 +64,7 @@ func CreateConsumer( return nil, status.NewInvalidRequestSeq("first message arrive must be create response") } cli := &consumerImpl{ + ctx: ctx, walName: createResp.GetWalName(), assignment: *opts.Assignment, grpcStreamClient: streamClient, @@ -93,6 +94,7 @@ func createConsumeRequest(ctx context.Context, opts *ConsumerOptions) (context.C } type consumerImpl struct { + ctx context.Context // TODO: the cancel method of consumer should be managed by consumerImpl, fix it in future. walName string assignment types.PChannelInfoAssigned grpcStreamClient streamingpb.StreamingNodeHandlerService_ConsumeClient @@ -177,12 +179,17 @@ func (c *consumerImpl) recvLoop() (err error) { resp.Consume.GetMessage().GetProperties(), ) if newImmutableMsg.TxnContext() != nil { - c.handleTxnMessage(newImmutableMsg) + if err := c.handleTxnMessage(newImmutableMsg); err != nil { + return err + } } else { if c.txnBuilder != nil { panic("unreachable code: txn builder should be nil if we receive a non-txn message") } - c.msgHandler.Handle(newImmutableMsg) + if _, err := c.msgHandler.Handle(c.ctx, newImmutableMsg); err != nil { + c.logger.Warn("message handle canceled", zap.Error(err)) + return errors.Wrapf(err, "At Handler") + } } case *streamingpb.ConsumeResponse_Close: // Should receive io.EOF after that. @@ -193,7 +200,7 @@ func (c *consumerImpl) recvLoop() (err error) { } } -func (c *consumerImpl) handleTxnMessage(msg message.ImmutableMessage) { +func (c *consumerImpl) handleTxnMessage(msg message.ImmutableMessage) error { switch msg.MessageType() { case message.MessageTypeBeginTxn: if c.txnBuilder != nil { @@ -202,7 +209,7 @@ func (c *consumerImpl) handleTxnMessage(msg message.ImmutableMessage) { beginMsg, err := message.AsImmutableBeginTxnMessageV2(msg) if err != nil { c.logger.Warn("failed to convert message to begin txn message", zap.Any("messageID", beginMsg.MessageID()), zap.Error(err)) - return + return nil } c.txnBuilder = message.NewImmutableTxnMessageBuilder(beginMsg) case message.MessageTypeCommitTxn: @@ -213,19 +220,23 @@ func (c *consumerImpl) handleTxnMessage(msg message.ImmutableMessage) { if err != nil { c.logger.Warn("failed to convert message to commit txn message", zap.Any("messageID", commitMsg.MessageID()), zap.Error(err)) c.txnBuilder = nil - return + return nil } msg, err := c.txnBuilder.Build(commitMsg) c.txnBuilder = nil if err != nil { c.logger.Warn("failed to build txn message", zap.Any("messageID", commitMsg.MessageID()), zap.Error(err)) - return + return nil + } + if _, err := c.msgHandler.Handle(c.ctx, msg); err != nil { + c.logger.Warn("message handle canceled at txn", zap.Error(err)) + return errors.Wrap(err, "At Handler Of Txn") } - c.msgHandler.Handle(msg) default: if c.txnBuilder == nil { panic("unreachable code: txn builder should not be nil if we receive a non-begin txn message") } c.txnBuilder.Add(msg) } + return nil } diff --git a/internal/streamingnode/client/handler/consumer/consumer_test.go b/internal/streamingnode/client/handler/consumer/consumer_test.go index 656e92e234..6a634223bb 100644 --- a/internal/streamingnode/client/handler/consumer/consumer_test.go +++ b/internal/streamingnode/client/handler/consumer/consumer_test.go @@ -21,6 +21,101 @@ import ( ) func TestConsumer(t *testing.T) { + resultCh := make(message.ChanMessageHandler, 1) + c := newMockedConsumerImpl(t, context.Background(), resultCh) + + mmsg, _ := message.NewInsertMessageBuilderV1(). + WithHeader(&message.InsertMessageHeader{}). + WithBody(&msgpb.InsertRequest{}). + WithVChannel("test-1"). + BuildMutable() + c.recvCh <- newConsumeResponse(walimplstest.NewTestMessageID(1), mmsg) + + msg := <-resultCh + assert.True(t, msg.MessageID().EQ(walimplstest.NewTestMessageID(1))) + + txnCtx := message.TxnContext{ + TxnID: 1, + Keepalive: time.Second, + } + mmsg, _ = message.NewBeginTxnMessageBuilderV2(). + WithVChannel("test-1"). + WithHeader(&message.BeginTxnMessageHeader{}). + WithBody(&message.BeginTxnMessageBody{}). + BuildMutable() + c.recvCh <- newConsumeResponse(walimplstest.NewTestMessageID(2), mmsg.WithTxnContext(txnCtx)) + + mmsg, _ = message.NewInsertMessageBuilderV1(). + WithVChannel("test-1"). + WithHeader(&message.InsertMessageHeader{}). + WithBody(&msgpb.InsertRequest{}). + BuildMutable() + c.recvCh <- newConsumeResponse(walimplstest.NewTestMessageID(3), mmsg.WithTxnContext(txnCtx)) + + mmsg, _ = message.NewCommitTxnMessageBuilderV2(). + WithVChannel("test-1"). + WithHeader(&message.CommitTxnMessageHeader{}). + WithBody(&message.CommitTxnMessageBody{}). + BuildMutable() + c.recvCh <- newConsumeResponse(walimplstest.NewTestMessageID(4), mmsg.WithTxnContext(txnCtx)) + + msg = <-resultCh + assert.True(t, msg.MessageID().EQ(walimplstest.NewTestMessageID(4))) + assert.Equal(t, msg.TxnContext().TxnID, txnCtx.TxnID) + assert.Equal(t, message.MessageTypeTxn, msg.MessageType()) + + c.consumer.Close() + <-c.consumer.Done() + assert.NoError(t, c.consumer.Error()) +} + +func TestConsumerWithCancellation(t *testing.T) { + resultCh := make(message.ChanMessageHandler, 1) + ctx, cancel := context.WithCancel(context.Background()) + c := newMockedConsumerImpl(t, ctx, resultCh) + + mmsg, _ := message.NewInsertMessageBuilderV1(). + WithHeader(&message.InsertMessageHeader{}). + WithBody(&msgpb.InsertRequest{}). + WithVChannel("test-1"). + BuildMutable() + c.recvCh <- newConsumeResponse(walimplstest.NewTestMessageID(1), mmsg) + // The recv goroutinue will be blocked until the context is canceled. + mmsg, _ = message.NewInsertMessageBuilderV1(). + WithHeader(&message.InsertMessageHeader{}). + WithBody(&msgpb.InsertRequest{}). + WithVChannel("test-1"). + BuildMutable() + c.recvCh <- newConsumeResponse(walimplstest.NewTestMessageID(1), mmsg) + + // The background recv loop should be started. + time.Sleep(20 * time.Millisecond) + + go func() { + c.consumer.Close() + }() + + select { + case <-c.consumer.Done(): + panic("should not reach here") + case <-time.After(10 * time.Millisecond): + } + + cancel() + select { + case <-c.consumer.Done(): + case <-time.After(20 * time.Millisecond): + panic("should not reach here") + } + assert.ErrorIs(t, c.consumer.Error(), context.Canceled) +} + +type mockedConsumer struct { + consumer Consumer + recvCh chan *streamingpb.ConsumeResponse +} + +func newMockedConsumerImpl(t *testing.T, ctx context.Context, h message.Handler) *mockedConsumer { c := mock_streamingpb.NewMockStreamingNodeHandlerServiceClient(t) cc := mock_streamingpb.NewMockStreamingNodeHandlerService_ConsumeClient(t) recvCh := make(chan *streamingpb.ConsumeResponse, 10) @@ -43,8 +138,6 @@ func TestConsumer(t *testing.T) { return nil }) - ctx := context.Background() - resultCh := make(message.ChanMessageHandler, 1) opts := &ConsumerOptions{ Assignment: &types.PChannelInfoAssigned{ Channel: types.PChannelInfo{Name: "test", Term: 1}, @@ -55,7 +148,7 @@ func TestConsumer(t *testing.T) { options.DeliverFilterVChannel("test-1"), options.DeliverFilterTimeTickGT(100), }, - MessageHandler: resultCh, + MessageHandler: h, } recvCh <- &streamingpb.ConsumeResponse{ @@ -65,53 +158,15 @@ func TestConsumer(t *testing.T) { }, }, } - - mmsg, _ := message.NewInsertMessageBuilderV1(). - WithHeader(&message.InsertMessageHeader{}). - WithBody(&msgpb.InsertRequest{}). - WithVChannel("test-1"). - BuildMutable() - recvCh <- newConsumeResponse(walimplstest.NewTestMessageID(1), mmsg) - consumer, err := CreateConsumer(ctx, opts, c) - assert.NoError(t, err) - assert.NotNil(t, consumer) - msg := <-resultCh - assert.True(t, msg.MessageID().EQ(walimplstest.NewTestMessageID(1))) - - txnCtx := message.TxnContext{ - TxnID: 1, - Keepalive: time.Second, + if err != nil { + panic(err) } - mmsg, _ = message.NewBeginTxnMessageBuilderV2(). - WithVChannel("test-1"). - WithHeader(&message.BeginTxnMessageHeader{}). - WithBody(&message.BeginTxnMessageBody{}). - BuildMutable() - recvCh <- newConsumeResponse(walimplstest.NewTestMessageID(2), mmsg.WithTxnContext(txnCtx)) - mmsg, _ = message.NewInsertMessageBuilderV1(). - WithVChannel("test-1"). - WithHeader(&message.InsertMessageHeader{}). - WithBody(&msgpb.InsertRequest{}). - BuildMutable() - recvCh <- newConsumeResponse(walimplstest.NewTestMessageID(3), mmsg.WithTxnContext(txnCtx)) - - mmsg, _ = message.NewCommitTxnMessageBuilderV2(). - WithVChannel("test-1"). - WithHeader(&message.CommitTxnMessageHeader{}). - WithBody(&message.CommitTxnMessageBody{}). - BuildMutable() - recvCh <- newConsumeResponse(walimplstest.NewTestMessageID(4), mmsg.WithTxnContext(txnCtx)) - - msg = <-resultCh - assert.True(t, msg.MessageID().EQ(walimplstest.NewTestMessageID(4))) - assert.Equal(t, msg.TxnContext().TxnID, txnCtx.TxnID) - assert.Equal(t, message.MessageTypeTxn, msg.MessageType()) - - consumer.Close() - <-consumer.Done() - assert.NoError(t, consumer.Error()) + return &mockedConsumer{ + consumer: consumer, + recvCh: recvCh, + } } func newConsumeResponse(id message.MessageID, msg message.MutableMessage) *streamingpb.ConsumeResponse { diff --git a/pkg/streaming/util/message/adaptor/handler.go b/pkg/streaming/util/message/adaptor/handler.go index d7dc1c97d0..80fd72be07 100644 --- a/pkg/streaming/util/message/adaptor/handler.go +++ b/pkg/streaming/util/message/adaptor/handler.go @@ -1,6 +1,8 @@ package adaptor import ( + "context" + "go.uber.org/zap" "github.com/milvus-io/milvus/pkg/log" @@ -27,12 +29,17 @@ func (m *MsgPackAdaptorHandler) Chan() <-chan *msgstream.MsgPack { } // Handle is the callback for handling message. -func (m *MsgPackAdaptorHandler) Handle(msg message.ImmutableMessage) { +func (m *MsgPackAdaptorHandler) Handle(ctx context.Context, msg message.ImmutableMessage) (bool, error) { m.base.GenerateMsgPack(msg) for m.base.PendingMsgPack.Len() > 0 { - m.base.Channel <- m.base.PendingMsgPack.Next() - m.base.PendingMsgPack.UnsafeAdvance() + select { + case <-ctx.Done(): + return true, ctx.Err() + case m.base.Channel <- m.base.PendingMsgPack.Next(): + m.base.PendingMsgPack.UnsafeAdvance() + } } + return true, nil } // Close is the callback for closing message. diff --git a/pkg/streaming/util/message/adaptor/handler_test.go b/pkg/streaming/util/message/adaptor/handler_test.go index 84194d274f..1c5909a079 100644 --- a/pkg/streaming/util/message/adaptor/handler_test.go +++ b/pkg/streaming/util/message/adaptor/handler_test.go @@ -1,6 +1,7 @@ package adaptor import ( + "context" "testing" "time" @@ -26,7 +27,9 @@ func TestMsgPackAdaptorHandler(t *testing.T) { } close(ch) }() - h.Handle(insertImmutableMessage) + ok, err := h.Handle(context.Background(), insertImmutableMessage) + assert.True(t, ok) + assert.NoError(t, err) msgPack := <-ch assert.Equal(t, uint64(10), msgPack.BeginTs) @@ -60,7 +63,9 @@ func TestMsgPackAdaptorHandler(t *testing.T) { WithLastConfirmedUseMessageID(). IntoImmutableMessage(id) - h.Handle(deleteImmutableMsg) + ok, err = h.Handle(context.Background(), deleteImmutableMsg) + assert.True(t, ok) + assert.NoError(t, err) msgPack = <-ch assert.Equal(t, uint64(11), msgPack.BeginTs) assert.Equal(t, uint64(11), msgPack.EndTs) @@ -114,7 +119,9 @@ func TestMsgPackAdaptorHandler(t *testing.T) { Build(commitImmutableMsg) assert.NoError(t, err) - h.Handle(txn) + ok, err = h.Handle(context.Background(), txn) + assert.True(t, ok) + assert.NoError(t, err) msgPack = <-ch assert.Equal(t, uint64(12), msgPack.BeginTs) @@ -133,7 +140,9 @@ func TestMsgPackAdaptorHandler(t *testing.T) { WithLastConfirmedUseMessageID(). IntoImmutableMessage(rmq.NewRmqID(4)) - h.Handle(flushMsg) + ok, err = h.Handle(context.Background(), flushMsg) + assert.True(t, ok) + assert.NoError(t, err) msgPack = <-ch @@ -143,3 +152,18 @@ func TestMsgPackAdaptorHandler(t *testing.T) { h.Close() <-ch } + +func TestMsgPackAdaptorHandlerTimeout(t *testing.T) { + id := rmq.NewRmqID(1) + + insertMsg := message.CreateTestInsertMessage(t, 1, 100, 10, id) + insertImmutableMessage := insertMsg.IntoImmutableMessage(id) + + h := NewMsgPackAdaptorHandler() + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + ok, err := h.Handle(ctx, insertImmutableMessage) + assert.True(t, ok) + assert.ErrorIs(t, err, ctx.Err()) +} diff --git a/pkg/streaming/util/message/message_handler.go b/pkg/streaming/util/message/message_handler.go index 93fbab0368..c6b6355c6a 100644 --- a/pkg/streaming/util/message/message_handler.go +++ b/pkg/streaming/util/message/message_handler.go @@ -1,9 +1,15 @@ package message +import "context" + // Handler is used to handle message read from log. type Handler interface { // Handle is the callback for handling message. - Handle(msg ImmutableMessage) + // Return true if the message is consumed, false if the message is not consumed. + // Should return error if and only if ctx is done. + // !!! It's a bad implementation for compatibility for msgstream, + // should be removed in the future. + Handle(ctx context.Context, msg ImmutableMessage) (bool, error) // Close is called after all messages are handled or handling is interrupted. Close() @@ -15,8 +21,13 @@ var _ Handler = ChanMessageHandler(nil) type ChanMessageHandler chan ImmutableMessage // Handle is the callback for handling message. -func (cmh ChanMessageHandler) Handle(msg ImmutableMessage) { - cmh <- msg +func (cmh ChanMessageHandler) Handle(ctx context.Context, msg ImmutableMessage) (bool, error) { + select { + case <-ctx.Done(): + return false, ctx.Err() + case cmh <- msg: + return true, nil + } } // Close is called after all messages are handled or handling is interrupted. diff --git a/pkg/streaming/util/message/message_handler_test.go b/pkg/streaming/util/message/message_handler_test.go index 25757a9597..12b0281022 100644 --- a/pkg/streaming/util/message/message_handler_test.go +++ b/pkg/streaming/util/message/message_handler_test.go @@ -1,17 +1,27 @@ package message import ( + "context" "testing" "github.com/stretchr/testify/assert" ) func TestMessageHandler(t *testing.T) { - ch := make(chan ImmutableMessage, 100) + ch := make(chan ImmutableMessage, 1) h := ChanMessageHandler(ch) - h.Handle(nil) + ok, err := h.Handle(context.Background(), nil) + assert.NoError(t, err) + assert.True(t, ok) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + ok, err = h.Handle(ctx, nil) + assert.ErrorIs(t, err, ctx.Err()) + assert.False(t, ok) + assert.Nil(t, <-ch) h.Close() - _, ok := <-ch + _, ok = <-ch assert.False(t, ok) }