From 821750d902faa7a28189fc13377168d5c621e8f0 Mon Sep 17 00:00:00 2001
From: SimFG
Date: Thu, 17 Oct 2024 12:53:29 +0800
Subject: [PATCH] enhance: force to stop buffer message when receiving the
 drop collection message (#36916)

/kind improvement

Signed-off-by: SimFG
---
 pkg/mq/msgstream/mq_msgstream.go      |  7 ++-
 pkg/mq/msgstream/mq_msgstream_test.go | 62 +++++++++++++++++++++++++++
 2 files changed, 68 insertions(+), 1 deletion(-)

diff --git a/pkg/mq/msgstream/mq_msgstream.go b/pkg/mq/msgstream/mq_msgstream.go
index bf10bb3b92..09d7121985 100644
--- a/pkg/mq/msgstream/mq_msgstream.go
+++ b/pkg/mq/msgstream/mq_msgstream.go
@@ -679,8 +679,9 @@ func (ms *MqTtMsgStream) bufMsgPackToChannel() {
 			startBufTime := time.Now()
 			var endTs uint64
 			var size uint64
+			var containsDropCollectionMsg bool
 
-			for ms.continueBuffering(endTs, size, startBufTime) {
+			for ms.continueBuffering(endTs, size, startBufTime) && !containsDropCollectionMsg {
 				ms.consumerLock.Lock()
 				// wait all channels get ttMsg
 				for _, consumer := range ms.consumers {
@@ -722,6 +723,10 @@
 					} else {
 						tempBuffer = append(tempBuffer, v)
 					}
+					// when drop collection, force to exit the buffer loop
+					if v.Type() == commonpb.MsgType_DropCollection {
+						containsDropCollectionMsg = true
+					}
 				}
 				ms.chanMsgBuf[consumer] = tempBuffer
 
diff --git a/pkg/mq/msgstream/mq_msgstream_test.go b/pkg/mq/msgstream/mq_msgstream_test.go
index 968d455b77..9870a14422 100644
--- a/pkg/mq/msgstream/mq_msgstream_test.go
+++ b/pkg/mq/msgstream/mq_msgstream_test.go
@@ -725,6 +725,46 @@ func TestStream_PulsarTtMsgStream_UnMarshalHeader(t *testing.T) {
 	outputStream.Close()
 }
 
+func TestStream_PulsarTtMsgStream_DropCollection(t *testing.T) {
+	pulsarAddress := getPulsarAddress()
+	c1, c2 := funcutil.RandomString(8), funcutil.RandomString(8)
+	producerChannels := []string{c1, c2}
+	consumerChannels := []string{c1, c2}
+	consumerSubName := funcutil.RandomString(8)
+
+	msgPack0 := MsgPack{}
+	msgPack0.Msgs = append(msgPack0.Msgs, getTimeTickMsg(0))
+
+	msgPack1 := MsgPack{}
+	msgPack1.Msgs = append(msgPack1.Msgs, getTsMsg(commonpb.MsgType_Insert, 1))
+
+	msgPack2 := MsgPack{}
+	msgPack2.Msgs = append(msgPack2.Msgs, getTsMsg(commonpb.MsgType_DropCollection, 3))
+
+	msgPack3 := MsgPack{}
+	msgPack3.Msgs = append(msgPack3.Msgs, getTimeTickMsg(5))
+
+	ctx := context.Background()
+	inputStream := getPulsarInputStream(ctx, pulsarAddress, producerChannels)
+	outputStream := getPulsarTtOutputStream(ctx, pulsarAddress, consumerChannels, consumerSubName)
+
+	_, err := inputStream.Broadcast(&msgPack0)
+	require.NoErrorf(t, err, fmt.Sprintf("broadcast error = %v", err))
+
+	err = inputStream.Produce(&msgPack1)
+	require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err))
+
+	_, err = inputStream.Broadcast(&msgPack2)
+	require.NoErrorf(t, err, fmt.Sprintf("broadcast error = %v", err))
+
+	_, err = inputStream.Broadcast(&msgPack3)
+	require.NoErrorf(t, err, fmt.Sprintf("broadcast error = %v", err))
+
+	receiveMsg(ctx, outputStream, 2)
+	inputStream.Close()
+	outputStream.Close()
+}
+
 func createRandMsgPacks(msgsInPack int, numOfMsgPack int, deltaTs int) []*MsgPack {
 	msgPacks := make([]*MsgPack, numOfMsgPack)
 
@@ -1325,6 +1365,28 @@ func getTsMsg(msgType MsgType, reqID UniqueID) TsMsg {
 			CreateCollectionRequest: createCollectionRequest,
 		}
 		return createCollectionMsg
+	case commonpb.MsgType_DropCollection:
+		dropCollectionRequest := &msgpb.DropCollectionRequest{
+			Base: &commonpb.MsgBase{
+				MsgType:   commonpb.MsgType_DropCollection,
+				MsgID:     reqID,
+				Timestamp: time,
+				SourceID:  reqID,
+			},
+			DbName:         "test_db",
+			CollectionName: "test_collection",
+			DbID:           4,
+			CollectionID:   5,
+		}
+		dropCollectionMsg := &DropCollectionMsg{
+			BaseMsg: BaseMsg{
+				BeginTimestamp: 0,
+				EndTimestamp:   0,
+				HashValues:     []uint32{hashValue},
+			},
+			DropCollectionRequest: dropCollectionRequest,
+		}
+		return dropCollectionMsg
 	case commonpb.MsgType_TimeTick:
 		timeTickResult := &msgpb.TimeTickMsg{
 			Base: &commonpb.MsgBase{
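
For context on the mechanics: the TT buffering loop in bufMsgPackToChannel
normally runs until continueBuffering's time/size budget is exhausted. This
patch threads a containsDropCollectionMsg flag into the loop condition, so
buffering stops as soon as a DropCollection message is seen and the drop
reaches downstream consumers without waiting out the remaining buffer window.
Below is a minimal standalone sketch of that early-exit pattern; the names
and budgets (Msg, drain, the 64-message and 50ms limits) are invented for
illustration and are not the Milvus msgstream API.

// sketch.go — illustrative only; not Milvus code.
package main

import (
	"fmt"
	"time"
)

type MsgType int

const (
	Insert MsgType = iota
	DropCollection
)

type Msg struct {
	Type MsgType
	Ts   uint64
}

// continueBuffering mirrors the shape of the patched loop condition:
// keep buffering while the size and elapsed-time budgets hold.
func continueBuffering(size uint64, start time.Time) bool {
	return size < 64 && time.Since(start) < 50*time.Millisecond
}

// drain buffers messages until the budget runs out OR a DropCollection
// message arrives; the flag in the loop condition forces an early exit.
func drain(in <-chan Msg) []Msg {
	var (
		buf          []Msg
		size         uint64
		containsDrop bool
	)
	start := time.Now()
	for continueBuffering(size, start) && !containsDrop {
		select {
		case m := <-in:
			buf = append(buf, m)
			size++
			if m.Type == DropCollection {
				// Stop at the next condition check instead of
				// waiting out the remaining buffer window.
				containsDrop = true
			}
		default:
			time.Sleep(time.Millisecond) // nothing pending yet
		}
	}
	return buf
}

func main() {
	in := make(chan Msg, 4)
	in <- Msg{Type: Insert, Ts: 1}
	in <- Msg{Type: DropCollection, Ts: 3}
	fmt.Printf("buffered %d msgs before the early exit\n", len(drain(in)))
}

Raising a flag that the loop condition checks, rather than breaking directly,
matches the patch's structure: the flag is set deep inside the per-consumer
message scan, so the current batch still finishes flushing before the outer
buffering loop terminates.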