mirror of
https://gitee.com/milvus-io/milvus.git
synced 2025-12-06 17:18:35 +08:00
fix: streaming consumer may get stucked when handler is un-consumed (#36818)
issue: #36378 Signed-off-by: chyezh <chyezh@outlook.com>
This commit is contained in:
parent
8905b042f1
commit
f0f5147aef
@ -74,10 +74,11 @@ func (rc *resumableConsumerImpl) resumeLoop() {
|
||||
// consumer need to resume when error occur, so message handler shouldn't close if the internal consumer encounter failure.
|
||||
nopCloseMH := nopCloseHandler{
|
||||
Handler: rc.mh,
|
||||
HandleInterceptor: func(msg message.ImmutableMessage, handle func(message.ImmutableMessage)) {
|
||||
HandleInterceptor: func(ctx context.Context, msg message.ImmutableMessage, handle handleFunc) (bool, error) {
|
||||
g := rc.metrics.StartConsume(msg.EstimateSize())
|
||||
handle(msg)
|
||||
ok, err := handle(ctx, msg)
|
||||
g.Finish()
|
||||
return ok, err
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@ -26,7 +26,7 @@ func TestResumableConsumer(t *testing.T) {
|
||||
rc := NewResumableConsumer(func(ctx context.Context, opts *handler.ConsumerOptions) (consumer.Consumer, error) {
|
||||
if i == 0 {
|
||||
i++
|
||||
opts.MessageHandler.Handle(message.NewImmutableMesasge(
|
||||
ok, err := opts.MessageHandler.Handle(context.Background(), message.NewImmutableMesasge(
|
||||
walimplstest.NewTestMessageID(123),
|
||||
[]byte("payload"),
|
||||
map[string]string{
|
||||
@ -36,6 +36,8 @@ func TestResumableConsumer(t *testing.T) {
|
||||
"_v": "1",
|
||||
"_lc": walimplstest.NewTestMessageID(123).Marshal(),
|
||||
}))
|
||||
assert.True(t, ok)
|
||||
assert.NoError(t, err)
|
||||
return c, nil
|
||||
} else if i == 1 {
|
||||
i++
|
||||
@ -76,7 +78,7 @@ func TestHandler(t *testing.T) {
|
||||
hNop := nopCloseHandler{
|
||||
Handler: message.ChanMessageHandler(ch),
|
||||
}
|
||||
hNop.Handle(nil)
|
||||
hNop.Handle(context.Background(), nil)
|
||||
assert.Nil(t, <-ch)
|
||||
hNop.Close()
|
||||
select {
|
||||
|
||||
@ -1,20 +1,25 @@
|
||||
package consumer
|
||||
|
||||
import "github.com/milvus-io/milvus/pkg/streaming/util/message"
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/milvus-io/milvus/pkg/streaming/util/message"
|
||||
)
|
||||
|
||||
type handleFunc func(ctx context.Context, msg message.ImmutableMessage) (bool, error)
|
||||
|
||||
// nopCloseHandler is a handler that do nothing when close.
|
||||
type nopCloseHandler struct {
|
||||
message.Handler
|
||||
HandleInterceptor func(msg message.ImmutableMessage, handle func(message.ImmutableMessage))
|
||||
HandleInterceptor func(ctx context.Context, msg message.ImmutableMessage, handle handleFunc) (bool, error)
|
||||
}
|
||||
|
||||
// Handle is the callback for handling message.
|
||||
func (nch nopCloseHandler) Handle(msg message.ImmutableMessage) {
|
||||
func (nch nopCloseHandler) Handle(ctx context.Context, msg message.ImmutableMessage) (bool, error) {
|
||||
if nch.HandleInterceptor != nil {
|
||||
nch.HandleInterceptor(msg, nch.Handler.Handle)
|
||||
return
|
||||
return nch.HandleInterceptor(ctx, msg, nch.Handler.Handle)
|
||||
}
|
||||
nch.Handler.Handle(msg)
|
||||
return nch.Handler.Handle(ctx, msg)
|
||||
}
|
||||
|
||||
// Close is called after all messages are handled or handling is interrupted.
|
||||
|
||||
@ -1,24 +1,28 @@
|
||||
package consumer
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/milvus-io/milvus/pkg/streaming/util/message"
|
||||
)
|
||||
|
||||
// timeTickOrderMessageHandler is a message handler that will do metrics and record the last sent message id.
|
||||
// timeTickOrderMessageHandler is a message handler that will record the last sent message id.
|
||||
type timeTickOrderMessageHandler struct {
|
||||
inner message.Handler
|
||||
lastConfirmedMessageID message.MessageID
|
||||
lastTimeTick uint64
|
||||
}
|
||||
|
||||
func (mh *timeTickOrderMessageHandler) Handle(msg message.ImmutableMessage) {
|
||||
func (mh *timeTickOrderMessageHandler) Handle(ctx context.Context, msg message.ImmutableMessage) (bool, error) {
|
||||
lastConfirmedMessageID := msg.LastConfirmedMessageID()
|
||||
timetick := msg.TimeTick()
|
||||
|
||||
mh.inner.Handle(msg)
|
||||
|
||||
mh.lastConfirmedMessageID = lastConfirmedMessageID
|
||||
mh.lastTimeTick = timetick
|
||||
ok, err := mh.inner.Handle(ctx, msg)
|
||||
if ok {
|
||||
mh.lastConfirmedMessageID = lastConfirmedMessageID
|
||||
mh.lastTimeTick = timetick
|
||||
}
|
||||
return ok, err
|
||||
}
|
||||
|
||||
func (mh *timeTickOrderMessageHandler) Close() {
|
||||
|
||||
@ -40,7 +40,7 @@ func CreateConsumer(
|
||||
opts *ConsumerOptions,
|
||||
handlerClient streamingpb.StreamingNodeHandlerServiceClient,
|
||||
) (Consumer, error) {
|
||||
ctx, err := createConsumeRequest(ctx, opts)
|
||||
ctxWithReq, err := createConsumeRequest(ctx, opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -48,7 +48,7 @@ func CreateConsumer(
|
||||
// TODO: configurable or auto adjust grpc.MaxCallRecvMsgSize
|
||||
// The messages are always managed by milvus cluster, so the size of message shouldn't be controlled here
|
||||
// to avoid infinitely blocks.
|
||||
streamClient, err := handlerClient.Consume(ctx, grpc.MaxCallRecvMsgSize(math.MaxInt32))
|
||||
streamClient, err := handlerClient.Consume(ctxWithReq, grpc.MaxCallRecvMsgSize(math.MaxInt32))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -64,6 +64,7 @@ func CreateConsumer(
|
||||
return nil, status.NewInvalidRequestSeq("first message arrive must be create response")
|
||||
}
|
||||
cli := &consumerImpl{
|
||||
ctx: ctx,
|
||||
walName: createResp.GetWalName(),
|
||||
assignment: *opts.Assignment,
|
||||
grpcStreamClient: streamClient,
|
||||
@ -93,6 +94,7 @@ func createConsumeRequest(ctx context.Context, opts *ConsumerOptions) (context.C
|
||||
}
|
||||
|
||||
type consumerImpl struct {
|
||||
ctx context.Context // TODO: the cancel method of consumer should be managed by consumerImpl, fix it in future.
|
||||
walName string
|
||||
assignment types.PChannelInfoAssigned
|
||||
grpcStreamClient streamingpb.StreamingNodeHandlerService_ConsumeClient
|
||||
@ -177,12 +179,17 @@ func (c *consumerImpl) recvLoop() (err error) {
|
||||
resp.Consume.GetMessage().GetProperties(),
|
||||
)
|
||||
if newImmutableMsg.TxnContext() != nil {
|
||||
c.handleTxnMessage(newImmutableMsg)
|
||||
if err := c.handleTxnMessage(newImmutableMsg); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if c.txnBuilder != nil {
|
||||
panic("unreachable code: txn builder should be nil if we receive a non-txn message")
|
||||
}
|
||||
c.msgHandler.Handle(newImmutableMsg)
|
||||
if _, err := c.msgHandler.Handle(c.ctx, newImmutableMsg); err != nil {
|
||||
c.logger.Warn("message handle canceled", zap.Error(err))
|
||||
return errors.Wrapf(err, "At Handler")
|
||||
}
|
||||
}
|
||||
case *streamingpb.ConsumeResponse_Close:
|
||||
// Should receive io.EOF after that.
|
||||
@ -193,7 +200,7 @@ func (c *consumerImpl) recvLoop() (err error) {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *consumerImpl) handleTxnMessage(msg message.ImmutableMessage) {
|
||||
func (c *consumerImpl) handleTxnMessage(msg message.ImmutableMessage) error {
|
||||
switch msg.MessageType() {
|
||||
case message.MessageTypeBeginTxn:
|
||||
if c.txnBuilder != nil {
|
||||
@ -202,7 +209,7 @@ func (c *consumerImpl) handleTxnMessage(msg message.ImmutableMessage) {
|
||||
beginMsg, err := message.AsImmutableBeginTxnMessageV2(msg)
|
||||
if err != nil {
|
||||
c.logger.Warn("failed to convert message to begin txn message", zap.Any("messageID", beginMsg.MessageID()), zap.Error(err))
|
||||
return
|
||||
return nil
|
||||
}
|
||||
c.txnBuilder = message.NewImmutableTxnMessageBuilder(beginMsg)
|
||||
case message.MessageTypeCommitTxn:
|
||||
@ -213,19 +220,23 @@ func (c *consumerImpl) handleTxnMessage(msg message.ImmutableMessage) {
|
||||
if err != nil {
|
||||
c.logger.Warn("failed to convert message to commit txn message", zap.Any("messageID", commitMsg.MessageID()), zap.Error(err))
|
||||
c.txnBuilder = nil
|
||||
return
|
||||
return nil
|
||||
}
|
||||
msg, err := c.txnBuilder.Build(commitMsg)
|
||||
c.txnBuilder = nil
|
||||
if err != nil {
|
||||
c.logger.Warn("failed to build txn message", zap.Any("messageID", commitMsg.MessageID()), zap.Error(err))
|
||||
return
|
||||
return nil
|
||||
}
|
||||
if _, err := c.msgHandler.Handle(c.ctx, msg); err != nil {
|
||||
c.logger.Warn("message handle canceled at txn", zap.Error(err))
|
||||
return errors.Wrap(err, "At Handler Of Txn")
|
||||
}
|
||||
c.msgHandler.Handle(msg)
|
||||
default:
|
||||
if c.txnBuilder == nil {
|
||||
panic("unreachable code: txn builder should not be nil if we receive a non-begin txn message")
|
||||
}
|
||||
c.txnBuilder.Add(msg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -21,6 +21,101 @@ import (
|
||||
)
|
||||
|
||||
func TestConsumer(t *testing.T) {
|
||||
resultCh := make(message.ChanMessageHandler, 1)
|
||||
c := newMockedConsumerImpl(t, context.Background(), resultCh)
|
||||
|
||||
mmsg, _ := message.NewInsertMessageBuilderV1().
|
||||
WithHeader(&message.InsertMessageHeader{}).
|
||||
WithBody(&msgpb.InsertRequest{}).
|
||||
WithVChannel("test-1").
|
||||
BuildMutable()
|
||||
c.recvCh <- newConsumeResponse(walimplstest.NewTestMessageID(1), mmsg)
|
||||
|
||||
msg := <-resultCh
|
||||
assert.True(t, msg.MessageID().EQ(walimplstest.NewTestMessageID(1)))
|
||||
|
||||
txnCtx := message.TxnContext{
|
||||
TxnID: 1,
|
||||
Keepalive: time.Second,
|
||||
}
|
||||
mmsg, _ = message.NewBeginTxnMessageBuilderV2().
|
||||
WithVChannel("test-1").
|
||||
WithHeader(&message.BeginTxnMessageHeader{}).
|
||||
WithBody(&message.BeginTxnMessageBody{}).
|
||||
BuildMutable()
|
||||
c.recvCh <- newConsumeResponse(walimplstest.NewTestMessageID(2), mmsg.WithTxnContext(txnCtx))
|
||||
|
||||
mmsg, _ = message.NewInsertMessageBuilderV1().
|
||||
WithVChannel("test-1").
|
||||
WithHeader(&message.InsertMessageHeader{}).
|
||||
WithBody(&msgpb.InsertRequest{}).
|
||||
BuildMutable()
|
||||
c.recvCh <- newConsumeResponse(walimplstest.NewTestMessageID(3), mmsg.WithTxnContext(txnCtx))
|
||||
|
||||
mmsg, _ = message.NewCommitTxnMessageBuilderV2().
|
||||
WithVChannel("test-1").
|
||||
WithHeader(&message.CommitTxnMessageHeader{}).
|
||||
WithBody(&message.CommitTxnMessageBody{}).
|
||||
BuildMutable()
|
||||
c.recvCh <- newConsumeResponse(walimplstest.NewTestMessageID(4), mmsg.WithTxnContext(txnCtx))
|
||||
|
||||
msg = <-resultCh
|
||||
assert.True(t, msg.MessageID().EQ(walimplstest.NewTestMessageID(4)))
|
||||
assert.Equal(t, msg.TxnContext().TxnID, txnCtx.TxnID)
|
||||
assert.Equal(t, message.MessageTypeTxn, msg.MessageType())
|
||||
|
||||
c.consumer.Close()
|
||||
<-c.consumer.Done()
|
||||
assert.NoError(t, c.consumer.Error())
|
||||
}
|
||||
|
||||
func TestConsumerWithCancellation(t *testing.T) {
|
||||
resultCh := make(message.ChanMessageHandler, 1)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
c := newMockedConsumerImpl(t, ctx, resultCh)
|
||||
|
||||
mmsg, _ := message.NewInsertMessageBuilderV1().
|
||||
WithHeader(&message.InsertMessageHeader{}).
|
||||
WithBody(&msgpb.InsertRequest{}).
|
||||
WithVChannel("test-1").
|
||||
BuildMutable()
|
||||
c.recvCh <- newConsumeResponse(walimplstest.NewTestMessageID(1), mmsg)
|
||||
// The recv goroutinue will be blocked until the context is canceled.
|
||||
mmsg, _ = message.NewInsertMessageBuilderV1().
|
||||
WithHeader(&message.InsertMessageHeader{}).
|
||||
WithBody(&msgpb.InsertRequest{}).
|
||||
WithVChannel("test-1").
|
||||
BuildMutable()
|
||||
c.recvCh <- newConsumeResponse(walimplstest.NewTestMessageID(1), mmsg)
|
||||
|
||||
// The background recv loop should be started.
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
|
||||
go func() {
|
||||
c.consumer.Close()
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-c.consumer.Done():
|
||||
panic("should not reach here")
|
||||
case <-time.After(10 * time.Millisecond):
|
||||
}
|
||||
|
||||
cancel()
|
||||
select {
|
||||
case <-c.consumer.Done():
|
||||
case <-time.After(20 * time.Millisecond):
|
||||
panic("should not reach here")
|
||||
}
|
||||
assert.ErrorIs(t, c.consumer.Error(), context.Canceled)
|
||||
}
|
||||
|
||||
type mockedConsumer struct {
|
||||
consumer Consumer
|
||||
recvCh chan *streamingpb.ConsumeResponse
|
||||
}
|
||||
|
||||
func newMockedConsumerImpl(t *testing.T, ctx context.Context, h message.Handler) *mockedConsumer {
|
||||
c := mock_streamingpb.NewMockStreamingNodeHandlerServiceClient(t)
|
||||
cc := mock_streamingpb.NewMockStreamingNodeHandlerService_ConsumeClient(t)
|
||||
recvCh := make(chan *streamingpb.ConsumeResponse, 10)
|
||||
@ -43,8 +138,6 @@ func TestConsumer(t *testing.T) {
|
||||
return nil
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
resultCh := make(message.ChanMessageHandler, 1)
|
||||
opts := &ConsumerOptions{
|
||||
Assignment: &types.PChannelInfoAssigned{
|
||||
Channel: types.PChannelInfo{Name: "test", Term: 1},
|
||||
@ -55,7 +148,7 @@ func TestConsumer(t *testing.T) {
|
||||
options.DeliverFilterVChannel("test-1"),
|
||||
options.DeliverFilterTimeTickGT(100),
|
||||
},
|
||||
MessageHandler: resultCh,
|
||||
MessageHandler: h,
|
||||
}
|
||||
|
||||
recvCh <- &streamingpb.ConsumeResponse{
|
||||
@ -65,53 +158,15 @@ func TestConsumer(t *testing.T) {
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
mmsg, _ := message.NewInsertMessageBuilderV1().
|
||||
WithHeader(&message.InsertMessageHeader{}).
|
||||
WithBody(&msgpb.InsertRequest{}).
|
||||
WithVChannel("test-1").
|
||||
BuildMutable()
|
||||
recvCh <- newConsumeResponse(walimplstest.NewTestMessageID(1), mmsg)
|
||||
|
||||
consumer, err := CreateConsumer(ctx, opts, c)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, consumer)
|
||||
msg := <-resultCh
|
||||
assert.True(t, msg.MessageID().EQ(walimplstest.NewTestMessageID(1)))
|
||||
|
||||
txnCtx := message.TxnContext{
|
||||
TxnID: 1,
|
||||
Keepalive: time.Second,
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
mmsg, _ = message.NewBeginTxnMessageBuilderV2().
|
||||
WithVChannel("test-1").
|
||||
WithHeader(&message.BeginTxnMessageHeader{}).
|
||||
WithBody(&message.BeginTxnMessageBody{}).
|
||||
BuildMutable()
|
||||
recvCh <- newConsumeResponse(walimplstest.NewTestMessageID(2), mmsg.WithTxnContext(txnCtx))
|
||||
|
||||
mmsg, _ = message.NewInsertMessageBuilderV1().
|
||||
WithVChannel("test-1").
|
||||
WithHeader(&message.InsertMessageHeader{}).
|
||||
WithBody(&msgpb.InsertRequest{}).
|
||||
BuildMutable()
|
||||
recvCh <- newConsumeResponse(walimplstest.NewTestMessageID(3), mmsg.WithTxnContext(txnCtx))
|
||||
|
||||
mmsg, _ = message.NewCommitTxnMessageBuilderV2().
|
||||
WithVChannel("test-1").
|
||||
WithHeader(&message.CommitTxnMessageHeader{}).
|
||||
WithBody(&message.CommitTxnMessageBody{}).
|
||||
BuildMutable()
|
||||
recvCh <- newConsumeResponse(walimplstest.NewTestMessageID(4), mmsg.WithTxnContext(txnCtx))
|
||||
|
||||
msg = <-resultCh
|
||||
assert.True(t, msg.MessageID().EQ(walimplstest.NewTestMessageID(4)))
|
||||
assert.Equal(t, msg.TxnContext().TxnID, txnCtx.TxnID)
|
||||
assert.Equal(t, message.MessageTypeTxn, msg.MessageType())
|
||||
|
||||
consumer.Close()
|
||||
<-consumer.Done()
|
||||
assert.NoError(t, consumer.Error())
|
||||
return &mockedConsumer{
|
||||
consumer: consumer,
|
||||
recvCh: recvCh,
|
||||
}
|
||||
}
|
||||
|
||||
func newConsumeResponse(id message.MessageID, msg message.MutableMessage) *streamingpb.ConsumeResponse {
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
package adaptor
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
@ -27,12 +29,17 @@ func (m *MsgPackAdaptorHandler) Chan() <-chan *msgstream.MsgPack {
|
||||
}
|
||||
|
||||
// Handle is the callback for handling message.
|
||||
func (m *MsgPackAdaptorHandler) Handle(msg message.ImmutableMessage) {
|
||||
func (m *MsgPackAdaptorHandler) Handle(ctx context.Context, msg message.ImmutableMessage) (bool, error) {
|
||||
m.base.GenerateMsgPack(msg)
|
||||
for m.base.PendingMsgPack.Len() > 0 {
|
||||
m.base.Channel <- m.base.PendingMsgPack.Next()
|
||||
m.base.PendingMsgPack.UnsafeAdvance()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return true, ctx.Err()
|
||||
case m.base.Channel <- m.base.PendingMsgPack.Next():
|
||||
m.base.PendingMsgPack.UnsafeAdvance()
|
||||
}
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// Close is the callback for closing message.
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
package adaptor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@ -26,7 +27,9 @@ func TestMsgPackAdaptorHandler(t *testing.T) {
|
||||
}
|
||||
close(ch)
|
||||
}()
|
||||
h.Handle(insertImmutableMessage)
|
||||
ok, err := h.Handle(context.Background(), insertImmutableMessage)
|
||||
assert.True(t, ok)
|
||||
assert.NoError(t, err)
|
||||
msgPack := <-ch
|
||||
|
||||
assert.Equal(t, uint64(10), msgPack.BeginTs)
|
||||
@ -60,7 +63,9 @@ func TestMsgPackAdaptorHandler(t *testing.T) {
|
||||
WithLastConfirmedUseMessageID().
|
||||
IntoImmutableMessage(id)
|
||||
|
||||
h.Handle(deleteImmutableMsg)
|
||||
ok, err = h.Handle(context.Background(), deleteImmutableMsg)
|
||||
assert.True(t, ok)
|
||||
assert.NoError(t, err)
|
||||
msgPack = <-ch
|
||||
assert.Equal(t, uint64(11), msgPack.BeginTs)
|
||||
assert.Equal(t, uint64(11), msgPack.EndTs)
|
||||
@ -114,7 +119,9 @@ func TestMsgPackAdaptorHandler(t *testing.T) {
|
||||
Build(commitImmutableMsg)
|
||||
assert.NoError(t, err)
|
||||
|
||||
h.Handle(txn)
|
||||
ok, err = h.Handle(context.Background(), txn)
|
||||
assert.True(t, ok)
|
||||
assert.NoError(t, err)
|
||||
msgPack = <-ch
|
||||
|
||||
assert.Equal(t, uint64(12), msgPack.BeginTs)
|
||||
@ -133,7 +140,9 @@ func TestMsgPackAdaptorHandler(t *testing.T) {
|
||||
WithLastConfirmedUseMessageID().
|
||||
IntoImmutableMessage(rmq.NewRmqID(4))
|
||||
|
||||
h.Handle(flushMsg)
|
||||
ok, err = h.Handle(context.Background(), flushMsg)
|
||||
assert.True(t, ok)
|
||||
assert.NoError(t, err)
|
||||
|
||||
msgPack = <-ch
|
||||
|
||||
@ -143,3 +152,18 @@ func TestMsgPackAdaptorHandler(t *testing.T) {
|
||||
h.Close()
|
||||
<-ch
|
||||
}
|
||||
|
||||
func TestMsgPackAdaptorHandlerTimeout(t *testing.T) {
|
||||
id := rmq.NewRmqID(1)
|
||||
|
||||
insertMsg := message.CreateTestInsertMessage(t, 1, 100, 10, id)
|
||||
insertImmutableMessage := insertMsg.IntoImmutableMessage(id)
|
||||
|
||||
h := NewMsgPackAdaptorHandler()
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
ok, err := h.Handle(ctx, insertImmutableMessage)
|
||||
assert.True(t, ok)
|
||||
assert.ErrorIs(t, err, ctx.Err())
|
||||
}
|
||||
|
||||
@ -1,9 +1,15 @@
|
||||
package message
|
||||
|
||||
import "context"
|
||||
|
||||
// Handler is used to handle message read from log.
|
||||
type Handler interface {
|
||||
// Handle is the callback for handling message.
|
||||
Handle(msg ImmutableMessage)
|
||||
// Return true if the message is consumed, false if the message is not consumed.
|
||||
// Should return error if and only if ctx is done.
|
||||
// !!! It's a bad implementation for compatibility for msgstream,
|
||||
// should be removed in the future.
|
||||
Handle(ctx context.Context, msg ImmutableMessage) (bool, error)
|
||||
|
||||
// Close is called after all messages are handled or handling is interrupted.
|
||||
Close()
|
||||
@ -15,8 +21,13 @@ var _ Handler = ChanMessageHandler(nil)
|
||||
type ChanMessageHandler chan ImmutableMessage
|
||||
|
||||
// Handle is the callback for handling message.
|
||||
func (cmh ChanMessageHandler) Handle(msg ImmutableMessage) {
|
||||
cmh <- msg
|
||||
func (cmh ChanMessageHandler) Handle(ctx context.Context, msg ImmutableMessage) (bool, error) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return false, ctx.Err()
|
||||
case cmh <- msg:
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Close is called after all messages are handled or handling is interrupted.
|
||||
|
||||
@ -1,17 +1,27 @@
|
||||
package message
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestMessageHandler(t *testing.T) {
|
||||
ch := make(chan ImmutableMessage, 100)
|
||||
ch := make(chan ImmutableMessage, 1)
|
||||
h := ChanMessageHandler(ch)
|
||||
h.Handle(nil)
|
||||
ok, err := h.Handle(context.Background(), nil)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, ok)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
ok, err = h.Handle(ctx, nil)
|
||||
assert.ErrorIs(t, err, ctx.Err())
|
||||
assert.False(t, ok)
|
||||
|
||||
assert.Nil(t, <-ch)
|
||||
h.Close()
|
||||
_, ok := <-ch
|
||||
_, ok = <-ch
|
||||
assert.False(t, ok)
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user