enhance: unmarshal ts msg in dispatcher instead of in msgstream (#38656)

relate: https://github.com/milvus-io/milvus/issues/38655

---------

Signed-off-by: aoiasd <zhicheng.yue@zilliz.com>
aoiasd 2025-02-14 12:04:13 +08:00 committed by GitHub
parent 58045a3396
commit 24d2bbc441
31 changed files with 778 additions and 232 deletions
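
At a high level, this PR moves proto unmarshalling out of the msgstream receive loop and into the msgdispatcher: MsgStream.Chan() now yields *ConsumeMsgPack whose ConsumeMsg entries stay marshaled, and consumers decode only the messages that actually match a target, via the stream's UnmarshalDispatcher. A minimal consumer-side sketch of the new flow (hedged; drainInserts is an illustrative helper, not part of this PR):

import (
	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
	"github.com/milvus-io/milvus/pkg/mq/msgstream"
)

// drainInserts reads one pack and decodes only the insert messages.
func drainInserts(stream msgstream.MsgStream) ([]msgstream.TsMsg, error) {
	pack := <-stream.Chan() // *msgstream.ConsumeMsgPack: payloads are still marshaled
	out := make([]msgstream.TsMsg, 0, len(pack.Msgs))
	for _, msg := range pack.Msgs {
		if msg.GetType() != commonpb.MsgType_Insert { // cheap check from mq properties
			continue
		}
		tsMsg, err := msg.Unmarshal(stream.GetUnmarshalDispatcher()) // proto decode happens only here
		if err != nil {
			return nil, err
		}
		out = append(out, tsMsg)
	}
	return out, nil
}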

View File

@ -58,7 +58,7 @@ Analyze(CAnalyze* res_analyze,
auto analyze_info =
std::make_unique<milvus::proto::clustering::AnalyzeInfo>();
auto res = analyze_info->ParseFromArray(serialized_analyze_info, len);
AssertInfo(res, "Unmarshall analyze info failed");
AssertInfo(res, "Unmarshal analyze info failed");
auto field_type =
static_cast<DataType>(analyze_info->field_schema().data_type());
auto field_id = analyze_info->field_schema().fieldid();

View File

@ -161,7 +161,7 @@ CreateIndex(CIndex* res_index,
std::make_unique<milvus::proto::indexcgo::BuildIndexInfo>();
auto res =
build_index_info->ParseFromArray(serialized_build_index_info, len);
AssertInfo(res, "Unmarshall build index info failed");
AssertInfo(res, "Unmarshal build index info failed");
auto field_type =
static_cast<DataType>(build_index_info->field_schema().data_type());
@ -233,7 +233,7 @@ BuildTextIndex(ProtoLayoutInterface result,
std::make_unique<milvus::proto::indexcgo::BuildIndexInfo>();
auto res =
build_index_info->ParseFromArray(serialized_build_index_info, len);
AssertInfo(res, "Unmarshall build index info failed");
AssertInfo(res, "Unmarshal build index info failed");
auto field_type =
static_cast<DataType>(build_index_info->field_schema().data_type());
@ -606,7 +606,7 @@ AppendBuildIndexParam(CBuildIndexInfo c_build_index_info,
auto index_params =
std::make_unique<milvus::proto::indexcgo::IndexParams>();
auto res = index_params->ParseFromArray(serialized_index_params, len);
AssertInfo(res, "Unmarshall index params failed");
AssertInfo(res, "Unmarshal index params failed");
for (auto i = 0; i < index_params->params_size(); ++i) {
const auto& param = index_params->params(i);
build_index_info->config[param.key()] = param.value();
@ -633,7 +633,7 @@ AppendBuildTypeParam(CBuildIndexInfo c_build_index_info,
auto type_params =
std::make_unique<milvus::proto::indexcgo::TypeParams>();
auto res = type_params->ParseFromArray(serialized_type_params, len);
AssertInfo(res, "Unmarshall index build type params failed");
AssertInfo(res, "Unmarshal index build type params failed");
for (auto i = 0; i < type_params->params_size(); ++i) {
const auto& param = type_params->params(i);
build_index_info->config[param.key()] = param.value();

View File

@ -30,7 +30,7 @@ ValidateIndexParams(const char* index_type,
std::make_unique<milvus::proto::indexcgo::IndexParams>();
auto res =
index_params->ParseFromArray(serialized_index_params, length);
AssertInfo(res, "Unmarshall index params failed");
AssertInfo(res, "Unmarshal index params failed");
knowhere::Json json;

View File

@ -308,7 +308,7 @@ type DataSyncServiceSuite struct {
channelCheckpointUpdater *util2.ChannelCheckpointUpdater
factory *dependency.MockFactory
ms *msgstream.MockMsgStream
msChan chan *msgstream.MsgPack
msChan chan *msgstream.ConsumeMsgPack
}
func (s *DataSyncServiceSuite) SetupSuite() {
@ -330,7 +330,7 @@ func (s *DataSyncServiceSuite) SetupTest() {
s.channelCheckpointUpdater = util2.NewChannelCheckpointUpdater(s.broker)
go s.channelCheckpointUpdater.Start()
s.msChan = make(chan *msgstream.MsgPack, 1)
s.msChan = make(chan *msgstream.ConsumeMsgPack, 1)
s.factory = dependency.NewMockFactory(s.T())
s.ms = msgstream.NewMockMsgStream(s.T())
@ -338,6 +338,7 @@ func (s *DataSyncServiceSuite) SetupTest() {
s.ms.EXPECT().AsConsumer(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil)
s.ms.EXPECT().Chan().Return(s.msChan)
s.ms.EXPECT().Close().Return()
s.ms.EXPECT().GetUnmarshalDispatcher().Return(nil)
s.pipelineParams = &util2.PipelineParams{
Ctx: context.TODO(),
@ -487,8 +488,8 @@ func (s *DataSyncServiceSuite) TestStartStop() {
close(ch)
return nil
})
s.msChan <- &msgPack
s.msChan <- &timeTickMsgPack
s.msChan <- msgstream.BuildConsumeMsgPack(&msgPack)
s.msChan <- msgstream.BuildConsumeMsgPack(&msgPack)
<-ch
}

View File

@ -67,8 +67,8 @@ func (mtm *mockTtMsgStream) SetReplicate(config *msgstream.ReplicateConfig) {
func (mtm *mockTtMsgStream) Close() {}
func (mtm *mockTtMsgStream) Chan() <-chan *msgstream.MsgPack {
return make(chan *msgstream.MsgPack, 100)
func (mtm *mockTtMsgStream) Chan() <-chan *msgstream.ConsumeMsgPack {
return make(chan *msgstream.ConsumeMsgPack, 100)
}
func (mtm *mockTtMsgStream) AsProducer(ctx context.Context, channels []string) {}
@ -77,6 +77,10 @@ func (mtm *mockTtMsgStream) AsConsumer(ctx context.Context, channels []string, s
return nil
}
func (mtm *mockTtMsgStream) GetUnmarshalDispatcher() msgstream.UnmarshalDispatcher {
return nil
}
func (mtm *mockTtMsgStream) SetRepackFunc(repackFunc msgstream.RepackFunc) {}
func (mtm *mockTtMsgStream) GetProduceChannels() []string {

View File

@ -235,7 +235,7 @@ func newDefaultMockDqlTask() *mockDqlTask {
}
type simpleMockMsgStream struct {
msgChan chan *msgstream.MsgPack
msgChan chan *msgstream.ConsumeMsgPack
msgCount int
msgCountMtx sync.RWMutex
@ -244,7 +244,7 @@ type simpleMockMsgStream struct {
func (ms *simpleMockMsgStream) Close() {
}
func (ms *simpleMockMsgStream) Chan() <-chan *msgstream.MsgPack {
func (ms *simpleMockMsgStream) Chan() <-chan *msgstream.ConsumeMsgPack {
if ms.getMsgCount() <= 0 {
ms.msgChan <- nil
return ms.msgChan
@ -255,6 +255,10 @@ func (ms *simpleMockMsgStream) Chan() <-chan *msgstream.MsgPack {
return ms.msgChan
}
func (ms *simpleMockMsgStream) GetUnmarshalDispatcher() msgstream.UnmarshalDispatcher {
return nil
}
func (ms *simpleMockMsgStream) AsProducer(ctx context.Context, channels []string) {
}
@ -286,8 +290,7 @@ func (ms *simpleMockMsgStream) decreaseMsgCount(delta int) {
func (ms *simpleMockMsgStream) Produce(ctx context.Context, pack *msgstream.MsgPack) error {
defer ms.increaseMsgCount(1)
ms.msgChan <- pack
ms.msgChan <- msgstream.BuildConsumeMsgPack(pack)
return nil
}
@ -319,7 +322,7 @@ func (ms *simpleMockMsgStream) SetReplicate(config *msgstream.ReplicateConfig) {
func newSimpleMockMsgStream() *simpleMockMsgStream {
return &simpleMockMsgStream{
msgChan: make(chan *msgstream.MsgPack, 1024),
msgChan: make(chan *msgstream.ConsumeMsgPack, 1024),
msgCount: 0,
}
}
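
The same mechanical update runs through all the test doubles above: channels of *MsgPack become channels of *ConsumeMsgPack, packs are wrapped with msgstream.BuildConsumeMsgPack before being pushed, and mocks must now stub GetUnmarshalDispatcher. A condensed sketch of the mockery-style setup these suites use (nil is safe here because packs built by BuildConsumeMsgPack carry already-decoded messages):

ch := make(chan *msgstream.ConsumeMsgPack, 1)
ms := msgstream.NewMockMsgStream(t)
ms.EXPECT().Chan().Return(ch)
ms.EXPECT().GetUnmarshalDispatcher().Return(nil) // UnmarshaledMsg.Unmarshal ignores it
ch <- msgstream.BuildConsumeMsgPack(&msgstream.MsgPack{})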

View File

@ -728,7 +728,15 @@ func (sd *shardDelegator) createStreamFromMsgStream(ctx context.Context, positio
if err != nil {
return nil, stream.Close, err
}
return stream.Chan(), stream.Close, nil
dispatcher := msgstream.NewSimpleMsgDispatcher(stream, func(pm msgstream.ConsumeMsg) bool {
if pm.GetType() != commonpb.MsgType_Delete || pm.GetVChannel() != vchannelName {
return false
}
return true
})
return dispatcher.Chan(), dispatcher.Close, nil
}
func (sd *shardDelegator) createDeleteStreamFromStreamingService(ctx context.Context, position *msgpb.MsgPosition) (ch <-chan *msgstream.MsgPack, closer func(), err error) {
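
With this change the delegator no longer drains every pack itself: NewSimpleMsgDispatcher (introduced by this PR) wraps the raw stream with a predicate over ConsumeMsg, and only messages passing the filter appear to be unmarshalled and forwarded on dispatcher.Chan() as regular *MsgPack. The same pattern in isolation (a sketch; myVChannel is a stand-in variable):

dispatcher := msgstream.NewSimpleMsgDispatcher(stream, func(pm msgstream.ConsumeMsg) bool {
	// cheap property checks; rejected messages are never proto-decoded
	return pm.GetType() == commonpb.MsgType_Delete && pm.GetVChannel() == myVChannel
})
defer dispatcher.Close()
for pack := range dispatcher.Chan() { // packs now hold decoded TsMsgs
	_ = pack
}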

View File

@ -207,6 +207,8 @@ func (s *DelegatorDataSuite) SetupTest() {
// init schema
s.genNormalCollection()
s.mq = &msgstream.MockMsgStream{}
s.mq.EXPECT().GetUnmarshalDispatcher().Return(nil)
s.rootPath = s.Suite.T().Name()
chunkManagerFactory := storage.NewTestChunkManagerFactory(paramtable.Get(), s.rootPath)
s.chunkManager, _ = chunkManagerFactory.NewPersistentStorageChunkManager(context.Background())
@ -916,8 +918,9 @@ func (s *DelegatorDataSuite) TestLoadSegments() {
s.mq.EXPECT().AsConsumer(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil)
s.mq.EXPECT().Seek(mock.Anything, mock.Anything, mock.Anything).Return(nil)
s.mq.EXPECT().GetUnmarshalDispatcher().Return(nil)
s.mq.EXPECT().Close()
ch := make(chan *msgstream.MsgPack, 10)
ch := make(chan *msgstream.ConsumeMsgPack, 10)
close(ch)
s.mq.EXPECT().Chan().Return(ch)
@ -1585,7 +1588,7 @@ func (s *DelegatorDataSuite) TestReadDeleteFromMsgstream() {
s.mq.EXPECT().AsConsumer(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil)
s.mq.EXPECT().Seek(mock.Anything, mock.Anything, mock.Anything).Return(nil)
s.mq.EXPECT().Close()
ch := make(chan *msgstream.MsgPack, 10)
ch := make(chan *msgstream.ConsumeMsgPack, 10)
s.mq.EXPECT().Chan().Return(ch)
oracle := pkoracle.NewBloomFilterSet(1, 1, commonpb.SegmentState_Sealed)
@ -1603,7 +1606,7 @@ func (s *DelegatorDataSuite) TestReadDeleteFromMsgstream() {
}
for _, data := range datas {
ch <- data
ch <- msgstream.BuildConsumeMsgPack(data)
}
result, err := s.delegator.readDeleteFromMsgstream(ctx, &msgpb.MsgPosition{Timestamp: 0}, 10, oracle)

View File

@ -64,7 +64,7 @@ import (
type ServiceSuite struct {
suite.Suite
// Data
msgChan chan *msgstream.MsgPack
msgChan chan *msgstream.ConsumeMsgPack
collectionID int64
collectionName string
schema *schemapb.CollectionSchema

View File

@ -24,6 +24,7 @@ import (
"github.com/cockroachdb/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
@ -250,7 +251,7 @@ func Test_alterCollectionTask_Execute(t *testing.T) {
broker.BroadcastAlteredCollectionFunc = func(ctx context.Context, req *milvuspb.AlterCollectionRequest) error {
return nil
}
packChan := make(chan *msgstream.MsgPack, 10)
packChan := make(chan *msgstream.ConsumeMsgPack, 10)
ticker := newChanTimeTickSync(packChan)
ticker.addDmlChannels("by-dev-rootcoord-dml_1")
@ -268,13 +269,18 @@ func Test_alterCollectionTask_Execute(t *testing.T) {
},
}
unmarshalFactory := &msgstream.ProtoUDFactory{}
unmarshalDispatcher := unmarshalFactory.NewUnmarshalDispatcher()
err := task.Execute(context.Background())
assert.NoError(t, err)
time.Sleep(time.Second)
select {
case pack := <-packChan:
assert.Equal(t, commonpb.MsgType_Replicate, pack.Msgs[0].Type())
replicateMsg := pack.Msgs[0].(*msgstream.ReplicateMsg)
assert.Equal(t, commonpb.MsgType_Replicate, pack.Msgs[0].GetType())
tsMsg, err := pack.Msgs[0].Unmarshal(unmarshalDispatcher)
require.NoError(t, err)
replicateMsg := tsMsg.(*msgstream.ReplicateMsg)
assert.Equal(t, "foo", replicateMsg.ReplicateMsg.GetDatabase())
assert.Equal(t, "cn", replicateMsg.ReplicateMsg.GetCollection())
assert.True(t, replicateMsg.ReplicateMsg.GetIsEnd())

View File

@ -24,6 +24,7 @@ import (
"github.com/cockroachdb/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus/internal/metastore/model"
@ -236,7 +237,7 @@ func Test_alterDatabaseTask_Execute(t *testing.T) {
mock.Anything,
).Return(nil)
// the chan length should larger than 4, because newChanTimeTickSync will send 4 ts messages when execute the `broadcast` step
packChan := make(chan *msgstream.MsgPack, 10)
packChan := make(chan *msgstream.ConsumeMsgPack, 10)
ticker := newChanTimeTickSync(packChan)
ticker.addDmlChannels("by-dev-rootcoord-dml_1")
@ -252,13 +253,19 @@ func Test_alterDatabaseTask_Execute(t *testing.T) {
},
}
unmarshalFactory := &msgstream.ProtoUDFactory{}
unmarshalDispatcher := unmarshalFactory.NewUnmarshalDispatcher()
err := task.Execute(context.Background())
assert.NoError(t, err)
time.Sleep(time.Second)
select {
case pack := <-packChan:
assert.Equal(t, commonpb.MsgType_Replicate, pack.Msgs[0].Type())
replicateMsg := pack.Msgs[0].(*msgstream.ReplicateMsg)
assert.Equal(t, commonpb.MsgType_Replicate, pack.Msgs[0].GetType())
tsMsg, err := pack.Msgs[0].Unmarshal(unmarshalDispatcher)
require.NoError(t, err)
replicateMsg := tsMsg.(*msgstream.ReplicateMsg)
assert.Equal(t, "cn", replicateMsg.ReplicateMsg.GetDatabase())
assert.True(t, replicateMsg.ReplicateMsg.GetIsEnd())
default:

View File

@ -277,10 +277,11 @@ type FailMsgStream struct {
errBroadcast bool
}
func (ms *FailMsgStream) Close() {}
func (ms *FailMsgStream) Chan() <-chan *msgstream.MsgPack { return nil }
func (ms *FailMsgStream) AsProducer(ctx context.Context, channels []string) {}
func (ms *FailMsgStream) AsReader(channels []string, subName string) {}
func (ms *FailMsgStream) Close() {}
func (ms *FailMsgStream) Chan() <-chan *msgstream.ConsumeMsgPack { return nil }
func (ms *FailMsgStream) GetUnmarshalDispatcher() msgstream.UnmarshalDispatcher { return nil }
func (ms *FailMsgStream) AsProducer(ctx context.Context, channels []string) {}
func (ms *FailMsgStream) AsReader(channels []string, subName string) {}
func (ms *FailMsgStream) AsConsumer(ctx context.Context, channels []string, subName string, position common.SubscriptionInitialPosition) error {
return nil
}

View File

@ -1058,23 +1058,23 @@ func newTickerWithFactory(factory msgstream.Factory) *timetickSync {
return ticker
}
func newChanTimeTickSync(packChan chan *msgstream.MsgPack) *timetickSync {
func newChanTimeTickSync(packChan chan *msgstream.ConsumeMsgPack) *timetickSync {
f := msgstream.NewMockMqFactory()
f.NewMsgStreamFunc = func(ctx context.Context) (msgstream.MsgStream, error) {
stream := msgstream.NewWastedMockMsgStream()
stream.BroadcastFunc = func(pack *msgstream.MsgPack) error {
log.Info("mock Broadcast")
packChan <- pack
packChan <- msgstream.BuildConsumeMsgPack(pack)
return nil
}
stream.BroadcastMarkFunc = func(pack *msgstream.MsgPack) (map[string][]msgstream.MessageID, error) {
log.Info("mock BroadcastMark")
packChan <- pack
packChan <- msgstream.BuildConsumeMsgPack(pack)
return map[string][]msgstream.MessageID{}, nil
}
stream.AsProducerFunc = func(channels []string) {
}
stream.ChanFunc = func() <-chan *msgstream.MsgPack {
stream.ChanFunc = func() <-chan *msgstream.ConsumeMsgPack {
return packChan
}
return stream, nil

View File

@ -28,6 +28,7 @@ import (
"github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/mq/common"
"github.com/milvus-io/milvus/pkg/mq/msgstream"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
@ -45,8 +46,9 @@ func TestInputNode(t *testing.T) {
produceStream.AsProducer(context.TODO(), channels)
produceStream.Produce(context.TODO(), &msgPack)
dispatcher := msgstream.NewSimpleMsgDispatcher(msgStream, func(pm msgstream.ConsumeMsg) bool { return true })
nodeName := "input_node"
inputNode := NewInputNode(msgStream.Chan(), nodeName, 100, 100, "", 0, 0, "")
inputNode := NewInputNode(dispatcher.Chan(), nodeName, 100, 100, "", 0, 0, "")
defer inputNode.Close()
isInputNode := inputNode.IsInputNode()
@ -89,7 +91,8 @@ func Test_InputNodeSkipMode(t *testing.T) {
outputCh := make(chan bool)
nodeName := "input_node"
inputNode := NewInputNode(msgStream.Chan(), nodeName, 100, 100, typeutil.DataNodeRole, 0, 0, "")
dispatcher := msgstream.NewSimpleMsgDispatcher(msgStream, func(pm msgstream.ConsumeMsg) bool { return true })
inputNode := NewInputNode(dispatcher.Chan(), nodeName, 100, 100, typeutil.DataNodeRole, 0, 0, "")
defer inputNode.Close()
outputCount := 0

View File

@ -26,6 +26,7 @@ import (
)
type MockMsg struct {
*msgstream.BaseMsg
Ctx context.Context
}

View File

@ -89,7 +89,8 @@ func TestNodeManager_Start(t *testing.T) {
produceStream.Produce(context.TODO(), &msgPack)
nodeName := "input_node"
inputNode := NewInputNode(msgStream.Chan(), nodeName, 100, 100, "", 0, 0, "")
dispatcher := msgstream.NewSimpleMsgDispatcher(msgStream, func(pm msgstream.ConsumeMsg) bool { return true })
inputNode := NewInputNode(dispatcher.Chan(), nodeName, 100, 100, "", 0, 0, "")
ddNode := BaseNode{}

View File

@ -74,7 +74,14 @@ const (
SubscriptionPositionUnknown
)
const MsgTypeKey = "msg_type"
const (
MsgTypeKey = "msg_type"
MsgIdTypeKey = "msg_id"
TimestampTypeKey = "timestamp"
ChannelTypeKey = "vchannel"
CollectionIDTypeKey = "collection_id"
ReplicateIDTypeKey = "replicate_id"
)
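
These property keys are attached to every produced message (see GetPorperties in mq_msgstream.go further down) so a consumer can classify and route a message from its headers alone, deferring the proto decode. A hedged sketch of what the consumer side can read without unmarshalling (the real parsing presumably lives in NewMarshaledMsg):

// msg is a common.Message from the mq wrapper
props := msg.Properties()
msgType := props[MsgTypeKey]      // e.g. "Insert"
vchannel := props[ChannelTypeKey] // routing target
ts := props[TimestampTypeKey]     // begin timestamp, as a string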
func GetMsgType(msg Message) (commonpb.MsgType, error) {
msgType := commonpb.MsgType_Undefined

View File

@ -19,7 +19,6 @@ package msgdispatcher
import (
"context"
"fmt"
"strconv"
"strings"
"sync"
"time"
@ -229,7 +228,7 @@ func (d *Dispatcher) work() {
}
d.curTs.Store(pack.EndPositions[0].GetTimestamp())
targetPacks := d.groupingMsgs(pack)
targetPacks := d.groupAndParseMsgs(pack, d.stream.GetUnmarshalDispatcher())
for vchannel, p := range targetPacks {
var err error
t := d.targets[vchannel]
@ -260,7 +259,7 @@ func (d *Dispatcher) work() {
}
}
func (d *Dispatcher) groupingMsgs(pack *MsgPack) map[string]*MsgPack {
func (d *Dispatcher) groupAndParseMsgs(pack *msgstream.ConsumeMsgPack, unmarshalDispatcher msgstream.UnmarshalDispatcher) map[string]*MsgPack {
// init packs for all targets, even though there's no msg in pack,
// but we still need to dispatch time ticks to the targets.
targetPacks := make(map[string]*MsgPack)
@ -280,27 +279,24 @@ func (d *Dispatcher) groupingMsgs(pack *MsgPack) map[string]*MsgPack {
// group messages by vchannel
for _, msg := range pack.Msgs {
var vchannel, collectionID string
switch msg.Type() {
case commonpb.MsgType_Insert:
vchannel = msg.(*msgstream.InsertMsg).GetShardName()
case commonpb.MsgType_Delete:
vchannel = msg.(*msgstream.DeleteMsg).GetShardName()
case commonpb.MsgType_CreateCollection:
collectionID = strconv.FormatInt(msg.(*msgstream.CreateCollectionMsg).GetCollectionID(), 10)
case commonpb.MsgType_DropCollection:
collectionID = strconv.FormatInt(msg.(*msgstream.DropCollectionMsg).GetCollectionID(), 10)
case commonpb.MsgType_CreatePartition:
collectionID = strconv.FormatInt(msg.(*msgstream.CreatePartitionMsg).GetCollectionID(), 10)
case commonpb.MsgType_DropPartition:
collectionID = strconv.FormatInt(msg.(*msgstream.DropPartitionMsg).GetCollectionID(), 10)
if msg.GetType() == commonpb.MsgType_Insert || msg.GetType() == commonpb.MsgType_Delete {
vchannel = msg.GetVChannel()
} else if msg.GetType() == commonpb.MsgType_CreateCollection ||
msg.GetType() == commonpb.MsgType_DropCollection ||
msg.GetType() == commonpb.MsgType_CreatePartition ||
msg.GetType() == commonpb.MsgType_DropPartition {
collectionID = msg.GetCollectionID()
}
if vchannel == "" {
// we need to dispatch it to the vchannel of this collection
targets := []string{}
for k := range targetPacks {
if msg.Type() == commonpb.MsgType_Replicate {
if msg.GetType() == commonpb.MsgType_Replicate {
config := replicateConfigs[k]
if config != nil && msgstream.MatchReplicateID(msg, config.ReplicateID) {
targetPacks[k].Msgs = append(targetPacks[k].Msgs, msg)
if config != nil && msg.GetReplicateID() == config.ReplicateID {
targets = append(targets, k)
}
continue
}
@ -308,14 +304,29 @@ func (d *Dispatcher) groupingMsgs(pack *MsgPack) map[string]*MsgPack {
if !strings.Contains(k, collectionID) {
continue
}
targets = append(targets, k)
}
if len(targets) > 0 {
tsMsg, err := msg.Unmarshal(unmarshalDispatcher)
if err != nil {
log.Warn("unmarshl message failed", zap.Error(err))
continue
}
// TODO: There's a data race when a non-DML msg is sent to different flow graphs.
// Wrong open-tracing information is generated; fix in the future.
targetPacks[k].Msgs = append(targetPacks[k].Msgs, msg)
for _, target := range targets {
targetPacks[target].Msgs = append(targetPacks[target].Msgs, tsMsg)
}
}
continue
}
if _, ok := targetPacks[vchannel]; ok {
targetPacks[vchannel].Msgs = append(targetPacks[vchannel].Msgs, msg)
tsMsg, err := msg.Unmarshal(unmarshalDispatcher)
if err != nil {
log.Warn("unmarshl message failed", zap.Error(err))
continue
}
targetPacks[vchannel].Msgs = append(targetPacks[vchannel].Msgs, tsMsg)
}
}
replicateEndChannels := make(map[string]struct{})
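
Note that the tests in the next file pass nil as the unmarshalDispatcher: packs built with msgstream.BuildConsumeMsgPack wrap already-decoded messages (presumably as UnmarshaledMsg), and UnmarshaledMsg.Unmarshal simply returns the wrapped TsMsg while ignoring the dispatcher, as its definition near the end of this diff shows:

pack := msgstream.BuildConsumeMsgPack(&msgstream.MsgPack{Msgs: []msgstream.TsMsg{someTsMsg}}) // someTsMsg: any decoded msg
tsMsg, err := pack.Msgs[0].Unmarshal(nil) // returns someTsMsg, err == nil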

View File

@ -150,7 +150,7 @@ func TestGroupMessage(t *testing.T) {
d.AddTarget(newTarget("mock_pchannel_0_2v0", nil, msgstream.GetReplicateConfig("local-test", "foo", "coo")))
{
// no replicate msg
packs := d.groupingMsgs(&MsgPack{
packs := d.groupAndParseMsgs(msgstream.BuildConsumeMsgPack(&MsgPack{
BeginTs: 1,
EndTs: 10,
StartPositions: []*msgstream.MsgPosition{
@ -182,13 +182,13 @@ func TestGroupMessage(t *testing.T) {
},
},
},
})
}), nil)
assert.Len(t, packs, 1)
}
{
// equal to replicateID
packs := d.groupingMsgs(&MsgPack{
packs := d.groupAndParseMsgs(msgstream.BuildConsumeMsgPack(&MsgPack{
BeginTs: 1,
EndTs: 10,
StartPositions: []*msgstream.MsgPosition{
@ -222,7 +222,7 @@ func TestGroupMessage(t *testing.T) {
},
},
},
})
}), nil)
assert.Len(t, packs, 2)
{
replicatePack := packs["mock_pchannel_0_2v0"]
@ -244,7 +244,7 @@ func TestGroupMessage(t *testing.T) {
{
// not equal to replicateID
packs := d.groupingMsgs(&MsgPack{
packs := d.groupAndParseMsgs(msgstream.BuildConsumeMsgPack(&MsgPack{
BeginTs: 1,
EndTs: 10,
StartPositions: []*msgstream.MsgPosition{
@ -278,7 +278,7 @@ func TestGroupMessage(t *testing.T) {
},
},
},
})
}), nil)
assert.Len(t, packs, 1)
replicatePack := packs["mock_pchannel_0_2v0"]
assert.Nil(t, replicatePack)
@ -288,7 +288,7 @@ func TestGroupMessage(t *testing.T) {
// replicate end
replicateTarget := d.targets["mock_pchannel_0_2v0"]
assert.NotNil(t, replicateTarget.replicateConfig)
packs := d.groupingMsgs(&MsgPack{
packs := d.groupAndParseMsgs(msgstream.BuildConsumeMsgPack(&MsgPack{
BeginTs: 1,
EndTs: 10,
StartPositions: []*msgstream.MsgPosition{
@ -324,7 +324,7 @@ func TestGroupMessage(t *testing.T) {
},
},
},
})
}), nil)
assert.Len(t, packs, 2)
replicatePack := packs["mock_pchannel_0_2v0"]
assert.EqualValues(t, 100, replicatePack.BeginTs)

View File

@ -26,6 +26,7 @@ import (
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
@ -372,7 +373,7 @@ func (suite *SimulationSuite) TestMerge() {
vchannel, positions[rand.Intn(len(positions))],
common.SubscriptionPositionUnknown,
)) // seek from random position
assert.NoError(suite.T(), err)
require.NoError(suite.T(), err)
suite.vchannels[vchannel] = &vchannelHelper{output: output}
}
wg := &sync.WaitGroup{}

View File

@ -266,7 +266,7 @@ func testSeekToLast(t *testing.T, f []Factory) {
var seekPosition *msgpb.MsgPosition
for i := 0; i < 10; i++ {
result := consume(ctx, consumer)
assert.Equal(t, result.Msgs[0].ID(), int64(i))
assert.Equal(t, result.Msgs[0].GetID(), int64(i))
if i == 5 {
seekPosition = result.EndPositions[0]
}
@ -295,11 +295,11 @@ func testSeekToLast(t *testing.T, f []Factory) {
assert.Equal(t, 1, len(msgPack.Msgs))
for _, tsMsg := range msgPack.Msgs {
assert.Equal(t, value, tsMsg.ID())
assert.Equal(t, value, tsMsg.GetID())
value++
cnt++
ret, err := lastMsgID.LessOrEqualThan(tsMsg.Position().MsgID)
ret, err := lastMsgID.LessOrEqualThan(tsMsg.GetPosition().MsgID)
assert.NoError(t, err)
if ret {
hasMore = false
@ -398,13 +398,17 @@ func testTimeTickerSeek(t *testing.T, f []Factory) {
assert.Equal(t, len(seekMsg.Msgs), 3)
result := []uint64{14, 12, 13}
for i, msg := range seekMsg.Msgs {
assert.Equal(t, msg.BeginTs(), result[i])
tsMsg, err := msg.Unmarshal(consumer.GetUnmarshalDispatcher())
require.NoError(t, err)
assert.Equal(t, tsMsg.BeginTs(), result[i])
}
seekMsg2 := consume(ctx, consumer)
assert.Equal(t, len(seekMsg2.Msgs), 1)
for _, msg := range seekMsg2.Msgs {
assert.Equal(t, msg.BeginTs(), uint64(19))
tsMsg, err := msg.Unmarshal(consumer.GetUnmarshalDispatcher())
require.NoError(t, err)
assert.Equal(t, tsMsg.BeginTs(), uint64(19))
}
consumer.Close()
@ -412,7 +416,9 @@ func testTimeTickerSeek(t *testing.T, f []Factory) {
seekMsg = consume(ctx, consumer)
assert.Equal(t, len(seekMsg.Msgs), 1)
for _, msg := range seekMsg.Msgs {
assert.Equal(t, msg.BeginTs(), uint64(19))
tsMsg, err := msg.Unmarshal(consumer.GetUnmarshalDispatcher())
require.NoError(t, err)
assert.Equal(t, tsMsg.BeginTs(), uint64(19))
}
consumer.Close()
}
@ -473,9 +479,11 @@ func testTimeTickerStream1(t *testing.T, f []Factory) {
rcvMsg += len(msgPack.Msgs)
if len(msgPack.Msgs) > 0 {
for _, msg := range msgPack.Msgs {
log.Println("msg type: ", msg.Type(), ", msg value: ", msg)
assert.Greater(t, msg.BeginTs(), msgPack.BeginTs)
assert.LessOrEqual(t, msg.BeginTs(), msgPack.EndTs)
tsMsg, err := msg.Unmarshal(consumer.GetUnmarshalDispatcher())
require.NoError(t, err)
log.Println("msg type: ", tsMsg.Type(), ", msg value: ", msg)
assert.Greater(t, tsMsg.BeginTs(), msgPack.BeginTs)
assert.LessOrEqual(t, tsMsg.BeginTs(), msgPack.EndTs)
}
log.Println("================")
}
@ -525,7 +533,7 @@ func testTimeTickerStream2(t *testing.T, f []Factory) {
// consume msg
log.Println("=============receive msg===================")
rcvMsgPacks := make([]*MsgPack, 0)
rcvMsgPacks := make([]*ConsumeMsgPack, 0)
resumeMsgPack := func(t *testing.T) int {
var consumer MsgStream
@ -539,9 +547,11 @@ func testTimeTickerStream2(t *testing.T, f []Factory) {
rcvMsgPacks = append(rcvMsgPacks, msgPack)
if len(msgPack.Msgs) > 0 {
for _, msg := range msgPack.Msgs {
log.Println("msg type: ", msg.Type(), ", msg value: ", msg)
assert.Greater(t, msg.BeginTs(), msgPack.BeginTs)
assert.LessOrEqual(t, msg.BeginTs(), msgPack.EndTs)
tsMsg, err := msg.Unmarshal(consumer.GetUnmarshalDispatcher())
require.NoError(t, err)
log.Println("msg type: ", tsMsg.Type(), ", msg value: ", msg)
assert.Greater(t, tsMsg.BeginTs(), msgPack.BeginTs)
assert.LessOrEqual(t, tsMsg.BeginTs(), msgPack.EndTs)
}
log.Println("================")
}
@ -576,7 +586,7 @@ func testMqMsgStreamSeek(t *testing.T, f []Factory) {
var seekPosition *msgpb.MsgPosition
for i := 0; i < 10; i++ {
result := consume(ctx, consumer)
assert.Equal(t, result.Msgs[0].ID(), int64(i))
assert.Equal(t, result.Msgs[0].GetID(), int64(i))
if i == 5 {
seekPosition = result.EndPositions[0]
}
@ -586,7 +596,7 @@ func testMqMsgStreamSeek(t *testing.T, f []Factory) {
consumer = createAndSeekConsumer(ctx, t, f[0].NewMsgStream, channels, []*msgpb.MsgPosition{seekPosition})
for i := 6; i < 10; i++ {
result := consume(ctx, consumer)
assert.Equal(t, result.Msgs[0].ID(), int64(i))
assert.Equal(t, result.Msgs[0].GetID(), int64(i))
}
consumer.Close()
}
@ -610,7 +620,7 @@ func testMqMsgStreamSeekInvalidMessage(t *testing.T, f []Factory, pg positionGen
var seekPosition *msgpb.MsgPosition
for i := 0; i < 10; i++ {
result := consume(ctx, consumer)
assert.Equal(t, result.Msgs[0].ID(), int64(i))
assert.Equal(t, result.Msgs[0].GetID(), int64(i))
seekPosition = result.EndPositions[0]
}
@ -625,7 +635,7 @@ func testMqMsgStreamSeekInvalidMessage(t *testing.T, f []Factory, pg positionGen
err = producer.Produce(ctx, msgPack)
assert.NoError(t, err)
result := consume(ctx, consumer2)
assert.Equal(t, result.Msgs[0].ID(), int64(1))
assert.Equal(t, result.Msgs[0].GetID(), int64(1))
}
func testMqMsgStreamSeekLatest(t *testing.T, f []Factory) {
@ -658,7 +668,7 @@ func testMqMsgStreamSeekLatest(t *testing.T, f []Factory) {
for i := 10; i < 20; i++ {
result := consume(ctx, consumer2)
assert.Equal(t, result.Msgs[0].ID(), int64(i))
assert.Equal(t, result.Msgs[0].GetID(), int64(i))
}
}
@ -748,7 +758,7 @@ func applyProduceAndConsume(
receiveAndValidateMsg(context.Background(), consumer, len(msgPack.Msgs))
}
func consume(ctx context.Context, mq MsgStream) *MsgPack {
func consume(ctx context.Context, mq MsgStream) *ConsumeMsgPack {
for {
select {
case msgPack, ok := <-mq.Chan():
@ -829,7 +839,7 @@ func receiveAndValidateMsg(ctx context.Context, outputStream MsgStream, msgCount
msgs := result.Msgs
for _, v := range msgs {
receiveCount++
log.Println("msg type: ", v.Type(), ", msg value: ", v)
log.Println("msg type: ", v.GetType(), ", msg value: ", v)
}
log.Println("================")
}

View File

@ -168,19 +168,19 @@ func (_c *MockMsgStream_Broadcast_Call) RunAndReturn(run func(context.Context, *
}
// Chan provides a mock function with given fields:
func (_m *MockMsgStream) Chan() <-chan *MsgPack {
func (_m *MockMsgStream) Chan() <-chan *ConsumeMsgPack {
ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for Chan")
}
var r0 <-chan *MsgPack
if rf, ok := ret.Get(0).(func() <-chan *MsgPack); ok {
var r0 <-chan *ConsumeMsgPack
if rf, ok := ret.Get(0).(func() <-chan *ConsumeMsgPack); ok {
r0 = rf()
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(<-chan *MsgPack)
r0 = ret.Get(0).(<-chan *ConsumeMsgPack)
}
}
@ -204,12 +204,12 @@ func (_c *MockMsgStream_Chan_Call) Run(run func()) *MockMsgStream_Chan_Call {
return _c
}
func (_c *MockMsgStream_Chan_Call) Return(_a0 <-chan *MsgPack) *MockMsgStream_Chan_Call {
func (_c *MockMsgStream_Chan_Call) Return(_a0 <-chan *ConsumeMsgPack) *MockMsgStream_Chan_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockMsgStream_Chan_Call) RunAndReturn(run func() <-chan *MsgPack) *MockMsgStream_Chan_Call {
func (_c *MockMsgStream_Chan_Call) RunAndReturn(run func() <-chan *ConsumeMsgPack) *MockMsgStream_Chan_Call {
_c.Call.Return(run)
return _c
}
@ -430,6 +430,53 @@ func (_c *MockMsgStream_GetProduceChannels_Call) RunAndReturn(run func() []strin
return _c
}
// GetUnmarshalDispatcher provides a mock function with given fields:
func (_m *MockMsgStream) GetUnmarshalDispatcher() UnmarshalDispatcher {
ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for GetUnmarshalDispatcher")
}
var r0 UnmarshalDispatcher
if rf, ok := ret.Get(0).(func() UnmarshalDispatcher); ok {
r0 = rf()
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(UnmarshalDispatcher)
}
}
return r0
}
// MockMsgStream_GetUnmarshalDispatcher_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetUnmarshalDispatcher'
type MockMsgStream_GetUnmarshalDispatcher_Call struct {
*mock.Call
}
// GetUnmarshalDispatcher is a helper method to define mock.On call
func (_e *MockMsgStream_Expecter) GetUnmarshalDispatcher() *MockMsgStream_GetUnmarshalDispatcher_Call {
return &MockMsgStream_GetUnmarshalDispatcher_Call{Call: _e.mock.On("GetUnmarshalDispatcher")}
}
func (_c *MockMsgStream_GetUnmarshalDispatcher_Call) Run(run func()) *MockMsgStream_GetUnmarshalDispatcher_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockMsgStream_GetUnmarshalDispatcher_Call) Return(_a0 UnmarshalDispatcher) *MockMsgStream_GetUnmarshalDispatcher_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockMsgStream_GetUnmarshalDispatcher_Call) RunAndReturn(run func() UnmarshalDispatcher) *MockMsgStream_GetUnmarshalDispatcher_Call {
_c.Call.Return(run)
return _c
}
// Produce provides a mock function with given fields: _a0, _a1
func (_m *MockMsgStream) Produce(_a0 context.Context, _a1 *MsgPack) error {
ret := _m.Called(_a0, _a1)

View File

@ -24,6 +24,7 @@ import (
"github.com/confluentinc/confluent-kafka-go/kafka"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
@ -131,7 +132,7 @@ func TestStream_KafkaMsgStream_SeekToLast(t *testing.T) {
outputStream := getKafkaOutputStream(ctx, kafkaAddress, consumerChannels, consumerSubName, common.SubscriptionPositionEarliest)
for i := 0; i < 10; i++ {
result := consumer(ctx, outputStream)
assert.Equal(t, result.Msgs[0].ID(), int64(i))
assert.Equal(t, result.Msgs[0].GetID(), int64(i))
if i == 5 {
seekPosition = result.EndPositions[0]
break
@ -162,11 +163,11 @@ func TestStream_KafkaMsgStream_SeekToLast(t *testing.T) {
assert.Equal(t, 1, len(msgPack.Msgs))
for _, tsMsg := range msgPack.Msgs {
assert.Equal(t, value, tsMsg.ID())
assert.Equal(t, value, tsMsg.GetID())
value++
cnt++
ret, err := lastMsgID.LessOrEqualThan(tsMsg.Position().MsgID)
ret, err := lastMsgID.LessOrEqualThan(tsMsg.GetPosition().MsgID)
assert.NoError(t, err)
if ret {
hasMore = false
@ -272,20 +273,26 @@ func TestStream_KafkaTtMsgStream_Seek(t *testing.T) {
assert.Equal(t, len(seekMsg.Msgs), 3)
result := []uint64{14, 12, 13}
for i, msg := range seekMsg.Msgs {
assert.Equal(t, msg.BeginTs(), result[i])
tsMsg, err := msg.Unmarshal(outputStream.GetUnmarshalDispatcher())
require.NoError(t, err)
assert.Equal(t, tsMsg.BeginTs(), result[i])
}
seekMsg2 := consumer(ctx, outputStream)
assert.Equal(t, len(seekMsg2.Msgs), 1)
for _, msg := range seekMsg2.Msgs {
assert.Equal(t, msg.BeginTs(), uint64(19))
tsMsg, err := msg.Unmarshal(outputStream.GetUnmarshalDispatcher())
require.NoError(t, err)
assert.Equal(t, tsMsg.BeginTs(), uint64(19))
}
outputStream2 := getKafkaTtOutputStreamAndSeek(ctx, kafkaAddress, receivedMsg3.EndPositions)
seekMsg = consumer(ctx, outputStream2)
assert.Equal(t, len(seekMsg.Msgs), 1)
for _, msg := range seekMsg.Msgs {
assert.Equal(t, msg.BeginTs(), uint64(19))
tsMsg, err := msg.Unmarshal(outputStream.GetUnmarshalDispatcher())
require.NoError(t, err)
assert.Equal(t, tsMsg.BeginTs(), uint64(19))
}
inputStream.Close()
@ -320,9 +327,11 @@ func TestStream_KafkaTtMsgStream_1(t *testing.T) {
rcvMsg += len(msgPack.Msgs)
if len(msgPack.Msgs) > 0 {
for _, msg := range msgPack.Msgs {
log.Println("msg type: ", msg.Type(), ", msg value: ", msg)
assert.Greater(t, msg.BeginTs(), msgPack.BeginTs)
assert.LessOrEqual(t, msg.BeginTs(), msgPack.EndTs)
tsMsg, err := msg.Unmarshal(outputStream.GetUnmarshalDispatcher())
require.NoError(t, err)
log.Println("msg type: ", tsMsg.Type(), ", msg value: ", msg)
assert.Greater(t, tsMsg.BeginTs(), msgPack.BeginTs)
assert.LessOrEqual(t, tsMsg.BeginTs(), msgPack.EndTs)
}
}
}
@ -361,7 +370,7 @@ func TestStream_KafkaTtMsgStream_2(t *testing.T) {
// consume msg
log.Println("=============receive msg===================")
rcvMsgPacks := make([]*MsgPack, 0)
rcvMsgPacks := make([]*ConsumeMsgPack, 0)
resumeMsgPack := func(t *testing.T) int {
var outputStream MsgStream
@ -376,9 +385,11 @@ func TestStream_KafkaTtMsgStream_2(t *testing.T) {
rcvMsgPacks = append(rcvMsgPacks, msgPack)
if len(msgPack.Msgs) > 0 {
for _, msg := range msgPack.Msgs {
log.Println("TestStream_KafkaTtMsgStream_2 msg type: ", msg.Type(), ", msg value: ", msg)
assert.Greater(t, msg.BeginTs(), msgPack.BeginTs)
assert.LessOrEqual(t, msg.BeginTs(), msgPack.EndTs)
tsMsg, err := msg.Unmarshal(outputStream.GetUnmarshalDispatcher())
require.NoError(t, err)
log.Println("TestStream_KafkaTtMsgStream_2 msg type: ", tsMsg.Type(), ", msg value: ", msg)
assert.Greater(t, tsMsg.BeginTs(), msgPack.BeginTs)
assert.LessOrEqual(t, tsMsg.BeginTs(), msgPack.EndTs)
}
log.Println("================")
}

View File

@ -61,7 +61,7 @@ type mqMsgStream struct {
repackFunc RepackFunc
unmarshal UnmarshalDispatcher
receiveBuf chan *MsgPack
receiveBuf chan *ConsumeMsgPack
closeRWMutex *sync.RWMutex
streamCancel func()
bufSize int64
@ -89,7 +89,7 @@ func NewMqMsgStream(initCtx context.Context,
consumers := make(map[string]mqwrapper.Consumer)
producerChannels := make([]string, 0)
consumerChannels := make([]string, 0)
receiveBuf := make(chan *MsgPack, receiveBufSize)
receiveBuf := make(chan *ConsumeMsgPack, receiveBufSize)
stream := &mqMsgStream{
ctx: streamCtx,
@ -355,9 +355,7 @@ func (ms *mqMsgStream) Produce(ctx context.Context, msgPack *MsgPack) error {
return err
}
msg := &common.ProducerMessage{Payload: m, Properties: map[string]string{
common.MsgTypeKey: v.Msgs[i].Type().String(),
}}
msg := &common.ProducerMessage{Payload: m, Properties: GetPorperties(v.Msgs[i])}
InjectCtx(spanCtx, msg.Properties)
if _, err := producer.Send(spanCtx, msg); err != nil {
@ -399,7 +397,7 @@ func (ms *mqMsgStream) Broadcast(ctx context.Context, msgPack *MsgPack) (map[str
return ids, err
}
msg := &common.ProducerMessage{Payload: m, Properties: map[string]string{}}
msg := &common.ProducerMessage{Payload: m, Properties: GetPorperties(v)}
InjectCtx(spanCtx, msg.Properties)
ms.producerLock.RLock()
@ -421,10 +419,6 @@ func (ms *mqMsgStream) Broadcast(ctx context.Context, msgPack *MsgPack) (map[str
return ids, nil
}
func (ms *mqMsgStream) getTsMsgFromConsumerMsg(msg common.Message) (TsMsg, error) {
return GetTsMsgFromConsumerMsg(ms.unmarshal, msg)
}
// GetTsMsgFromConsumerMsg get TsMsg from consumer message
func GetTsMsgFromConsumerMsg(unmarshalDispatcher UnmarshalDispatcher, msg common.Message) (TsMsg, error) {
msgType, err := common.GetMsgType(msg)
@ -464,31 +458,35 @@ func (ms *mqMsgStream) receiveMsg(consumer mqwrapper.Consumer) {
log.Ctx(ms.ctx).Warn("MqMsgStream get msg whose payload is nil")
continue
}
// not need to check the preCreatedTopic is empty, related issue: https://github.com/milvus-io/milvus/issues/27295
// if the message not belong to the topic, will skip it
tsMsg, err := ms.getTsMsgFromConsumerMsg(msg)
var err error
var packMsg ConsumeMsg
packMsg, err = NewMarshaledMsg(msg, consumer.Subscription())
if err != nil {
log.Ctx(ms.ctx).Warn("Failed to getTsMsgFromConsumerMsg", zap.Error(err))
continue
packMsg, err = UnmarshalMsg(msg, ms.unmarshal)
if err != nil {
log.Ctx(ms.ctx).Warn("Failed to getTsMsgFromConsumerMsg", zap.Error(err))
continue
}
}
pos := tsMsg.Position()
tsMsg.SetPosition(&MsgPosition{
ChannelName: pos.ChannelName,
MsgID: pos.MsgID,
pos := &msgpb.MsgPosition{
ChannelName: filepath.Base(msg.Topic()),
MsgID: packMsg.GetMessageID(),
MsgGroup: consumer.Subscription(),
Timestamp: tsMsg.BeginTs(),
})
ctx, _ := ExtractCtx(tsMsg, msg.Properties())
tsMsg.SetTraceCtx(ctx)
msgPack := MsgPack{
Msgs: []TsMsg{tsMsg},
StartPositions: []*msgpb.MsgPosition{tsMsg.Position()},
EndPositions: []*msgpb.MsgPosition{tsMsg.Position()},
BeginTs: tsMsg.BeginTs(),
EndTs: tsMsg.EndTs(),
Timestamp: packMsg.GetTimestamp(),
}
packMsg.SetPosition(pos)
msgPack := ConsumeMsgPack{
Msgs: []ConsumeMsg{packMsg},
StartPositions: []*msgpb.MsgPosition{pos},
EndPositions: []*msgpb.MsgPosition{pos},
BeginTs: packMsg.GetTimestamp(),
EndTs: packMsg.GetTimestamp(),
}
select {
case ms.receiveBuf <- &msgPack:
case <-ms.ctx.Done():
@ -498,7 +496,11 @@ func (ms *mqMsgStream) receiveMsg(consumer mqwrapper.Consumer) {
}
}
func (ms *mqMsgStream) Chan() <-chan *MsgPack {
func (ms *mqMsgStream) GetUnmarshalDispatcher() UnmarshalDispatcher {
return ms.unmarshal
}
func (ms *mqMsgStream) Chan() <-chan *ConsumeMsgPack {
ms.onceChan.Do(func() {
for _, c := range ms.consumers {
go ms.receiveMsg(c)
@ -546,7 +548,7 @@ var _ MsgStream = (*MqTtMsgStream)(nil)
// MqTtMsgStream is a msgstream that contains timeticks
type MqTtMsgStream struct {
*mqMsgStream
chanMsgBuf map[mqwrapper.Consumer][]TsMsg
chanMsgBuf map[mqwrapper.Consumer][]ConsumeMsg
chanMsgPos map[mqwrapper.Consumer]*msgpb.MsgPosition
chanStopChan map[mqwrapper.Consumer]chan bool
chanTtMsgTime map[mqwrapper.Consumer]Timestamp
@ -568,7 +570,7 @@ func NewMqTtMsgStream(ctx context.Context,
if err != nil {
return nil, err
}
chanMsgBuf := make(map[mqwrapper.Consumer][]TsMsg)
chanMsgBuf := make(map[mqwrapper.Consumer][]ConsumeMsg)
chanMsgPos := make(map[mqwrapper.Consumer]*msgpb.MsgPosition)
chanStopChan := make(map[mqwrapper.Consumer]chan bool)
chanTtMsgTime := make(map[mqwrapper.Consumer]Timestamp)
@ -593,7 +595,7 @@ func (ms *MqTtMsgStream) addConsumer(consumer mqwrapper.Consumer, channel string
}
ms.consumers[channel] = consumer
ms.consumerChannels = append(ms.consumerChannels, channel)
ms.chanMsgBuf[consumer] = make([]TsMsg, 0)
ms.chanMsgBuf[consumer] = make([]ConsumeMsg, 0)
ms.chanMsgPos[consumer] = &msgpb.MsgPosition{
ChannelName: channel,
MsgID: make([]byte, 0),
@ -649,8 +651,8 @@ func (ms *MqTtMsgStream) Close() {
ms.mqMsgStream.Close()
}
func isDMLMsg(msg TsMsg) bool {
return msg.Type() == commonpb.MsgType_Insert || msg.Type() == commonpb.MsgType_Delete
func isDMLMsg(msg ConsumeMsg) bool {
return msg.GetType() == commonpb.MsgType_Insert || msg.GetType() == commonpb.MsgType_Delete
}
func (ms *MqTtMsgStream) continueBuffering(endTs, size uint64, startTime time.Time) bool {
@ -700,7 +702,7 @@ func (ms *MqTtMsgStream) bufMsgPackToChannel() {
case <-ms.ctx.Done():
return
default:
timeTickBuf := make([]TsMsg, 0)
timeTickBuf := make([]ConsumeMsg, 0)
// startMsgPosition := make([]*msgpb.MsgPosition, 0)
// endMsgPositions := make([]*msgpb.MsgPosition, 0)
startPositions := make(map[string]*msgpb.MsgPosition)
@ -739,22 +741,22 @@ func (ms *MqTtMsgStream) bufMsgPackToChannel() {
if _, ok := startPositions[channelName]; !ok {
startPositions[channelName] = startPos
}
tempBuffer := make([]TsMsg, 0)
var timeTickMsg TsMsg
tempBuffer := make([]ConsumeMsg, 0)
var timeTickMsg ConsumeMsg
for _, v := range msgs {
if v.Type() == commonpb.MsgType_TimeTick {
if v.GetType() == commonpb.MsgType_TimeTick {
timeTickMsg = v
continue
}
if v.EndTs() <= currTs ||
GetReplicateID(v) != "" {
size += uint64(v.Size())
if v.GetTimestamp() <= currTs ||
v.GetReplicateID() != "" {
size += uint64(v.GetSize())
timeTickBuf = append(timeTickBuf, v)
} else {
tempBuffer = append(tempBuffer, v)
}
// when drop collection, force to exit the buffer loop
if v.Type() == commonpb.MsgType_DropCollection || v.Type() == commonpb.MsgType_Replicate {
if v.GetType() == commonpb.MsgType_DropCollection || v.GetType() == commonpb.MsgType_Replicate {
containsEndBufferMsg = true
}
}
@ -765,8 +767,8 @@ func (ms *MqTtMsgStream) bufMsgPackToChannel() {
if len(tempBuffer) > 0 {
// if tempBuffer is not empty, use tempBuffer[0] to seek
newPos = &msgpb.MsgPosition{
ChannelName: tempBuffer[0].Position().ChannelName,
MsgID: tempBuffer[0].Position().MsgID,
ChannelName: tempBuffer[0].GetPChannel(),
MsgID: tempBuffer[0].GetMessageID(),
Timestamp: currTs,
MsgGroup: consumer.Subscription(),
}
@ -774,8 +776,8 @@ func (ms *MqTtMsgStream) bufMsgPackToChannel() {
} else if timeTickMsg != nil {
// if tempBuffer is empty, use timeTickMsg to seek
newPos = &msgpb.MsgPosition{
ChannelName: timeTickMsg.Position().ChannelName,
MsgID: timeTickMsg.Position().MsgID,
ChannelName: timeTickMsg.GetPChannel(),
MsgID: timeTickMsg.GetMessageID(),
Timestamp: currTs,
MsgGroup: consumer.Subscription(),
}
@ -787,20 +789,20 @@ func (ms *MqTtMsgStream) bufMsgPackToChannel() {
ms.consumerLock.Unlock()
}
idset := make(typeutil.UniqueSet)
uniqueMsgs := make([]TsMsg, 0, len(timeTickBuf))
idset := make(typeutil.Set[int64])
uniqueMsgs := make([]ConsumeMsg, 0, len(timeTickBuf))
for _, msg := range timeTickBuf {
if isDMLMsg(msg) && idset.Contain(msg.ID()) {
log.Ctx(ms.ctx).Warn("mqTtMsgStream, found duplicated msg", zap.Int64("msgID", msg.ID()))
if isDMLMsg(msg) && idset.Contain(msg.GetID()) {
log.Ctx(ms.ctx).Warn("mqTtMsgStream, found duplicated msg", zap.Int64("msgID", msg.GetID()))
continue
}
idset.Insert(msg.ID())
idset.Insert(msg.GetID())
uniqueMsgs = append(uniqueMsgs, msg)
}
// skip endTs = 0 (no run for ctx error)
if endTs > 0 {
msgPack := MsgPack{
msgPack := ConsumeMsgPack{
BeginTs: ms.lastTimeStamp,
EndTs: endTs,
Msgs: uniqueMsgs,
@ -840,21 +842,26 @@ func (ms *MqTtMsgStream) consumeToTtMsg(consumer mqwrapper.Consumer) {
log.Warn("MqTtMsgStream get msg whose payload is nil")
continue
}
// not need to check the preCreatedTopic is empty, related issue: https://github.com/milvus-io/milvus/issues/27295
// if the message not belong to the topic, will skip it
tsMsg, err := ms.getTsMsgFromConsumerMsg(msg)
var err error
var packMsg ConsumeMsg
packMsg, err = NewMarshaledMsg(msg, consumer.Subscription())
if err != nil {
log.Warn("Failed to getTsMsgFromConsumerMsg", zap.Error(err))
continue
packMsg, err = UnmarshalMsg(msg, ms.unmarshal)
if err != nil {
log.Warn("Failed to getTsMsgFromConsumerMsg", zap.Error(err))
continue
}
}
ms.chanMsgBufMutex.Lock()
ms.chanMsgBuf[consumer] = append(ms.chanMsgBuf[consumer], tsMsg)
ms.chanMsgBuf[consumer] = append(ms.chanMsgBuf[consumer], packMsg)
ms.chanMsgBufMutex.Unlock()
if tsMsg.Type() == commonpb.MsgType_TimeTick {
if packMsg.GetType() == commonpb.MsgType_TimeTick {
ms.chanTtMsgTimeMutex.Lock()
ms.chanTtMsgTime[consumer] = tsMsg.(*TimeTickMsg).Base.Timestamp
ms.chanTtMsgTime[consumer] = packMsg.GetTimestamp()
ms.chanTtMsgTimeMutex.Unlock()
return
}
@ -972,20 +979,23 @@ func (ms *MqTtMsgStream) Seek(ctx context.Context, msgPositions []*MsgPosition,
loopMsgCnt++
consumer.Ack(msg)
headerMsg := commonpb.MsgHeader{}
err := proto.Unmarshal(msg.Payload(), &headerMsg)
var err error
var packMsg ConsumeMsg
packMsg, err = NewMarshaledMsg(msg, consumer.Subscription())
if err != nil {
return fmt.Errorf("failed to unmarshal message header, err %s", err.Error())
}
tsMsg, err := ms.unmarshal.Unmarshal(msg.Payload(), headerMsg.Base.MsgType)
if err != nil {
return fmt.Errorf("failed to unmarshal tsMsg, err %s", err.Error())
packMsg, err = UnmarshalMsg(msg, ms.unmarshal)
if err != nil {
log.Warn("Failed to getTsMsgFromConsumerMsg", zap.Error(err))
continue
}
}
// skip the replicate msg because it must have been consumed
if GetReplicateID(tsMsg) != "" {
if packMsg.GetReplicateID() != "" {
continue
}
if tsMsg.Type() == commonpb.MsgType_TimeTick && tsMsg.BeginTs() >= mp.Timestamp {
if packMsg.GetType() == commonpb.MsgType_TimeTick && packMsg.GetTimestamp() >= mp.Timestamp {
runLoop = false
if time.Since(loopStarTime) > 30*time.Second {
log.Info("seek loop finished long time",
@ -993,21 +1003,21 @@ func (ms *MqTtMsgStream) Seek(ctx context.Context, msgPositions []*MsgPosition,
zap.String("channel", mp.ChannelName),
zap.Duration("cost", time.Since(loopStarTime)))
}
} else if tsMsg.BeginTs() > mp.Timestamp {
ctx, _ := ExtractCtx(tsMsg, msg.Properties())
tsMsg.SetTraceCtx(ctx)
} else if packMsg.GetTimestamp() > mp.Timestamp {
ctx, _ := ExtractCtx(packMsg, msg.Properties())
packMsg.SetTraceCtx(ctx)
tsMsg.SetPosition(&MsgPosition{
packMsg.SetPosition(&MsgPosition{
ChannelName: filepath.Base(msg.Topic()),
MsgID: msg.ID().Serialize(),
})
ms.chanMsgBuf[consumer] = append(ms.chanMsgBuf[consumer], tsMsg)
ms.chanMsgBuf[consumer] = append(ms.chanMsgBuf[consumer], packMsg)
} else {
log.Info("skip msg",
zap.Int64("source", tsMsg.SourceID()),
zap.String("type", tsMsg.Type().String()),
zap.Int("size", tsMsg.Size()),
zap.Uint64("msgTs", tsMsg.BeginTs()),
// zap.Int64("source", tsMsg.SourceID()), // TODO SOURCE ID ?
zap.String("type", packMsg.GetType().String()),
zap.Int("size", packMsg.GetSize()),
zap.Uint64("msgTs", packMsg.GetTimestamp()),
zap.Uint64("posTs", mp.GetTimestamp()),
)
}
@ -1017,7 +1027,7 @@ func (ms *MqTtMsgStream) Seek(ctx context.Context, msgPositions []*MsgPosition,
return nil
}
func (ms *MqTtMsgStream) Chan() <-chan *MsgPack {
func (ms *MqTtMsgStream) Chan() <-chan *ConsumeMsgPack {
ms.onceChan.Do(func() {
if ms.consumers != nil {
go ms.bufMsgPackToChannel()

View File

@ -85,7 +85,7 @@ func getKafkaBrokerList() string {
return brokerList
}
func consumer(ctx context.Context, mq MsgStream) *MsgPack {
func consumer(ctx context.Context, mq MsgStream) *ConsumeMsgPack {
for {
select {
case msgPack, ok := <-mq.Chan():
@ -506,7 +506,7 @@ func TestStream_PulsarMsgStream_SeekToLast(t *testing.T) {
var seekPosition *msgpb.MsgPosition
for i := 0; i < 10; i++ {
result := consumer(ctx, outputStream)
assert.Equal(t, result.Msgs[0].ID(), int64(i))
assert.Equal(t, result.Msgs[0].GetID(), int64(i))
if i == 5 {
seekPosition = result.EndPositions[0]
}
@ -539,11 +539,11 @@ func TestStream_PulsarMsgStream_SeekToLast(t *testing.T) {
assert.Equal(t, 1, len(msgPack.Msgs))
for _, tsMsg := range msgPack.Msgs {
assert.Equal(t, value, tsMsg.ID())
assert.Equal(t, value, tsMsg.GetID())
value++
cnt++
ret, err := lastMsgID.LessOrEqualThan(tsMsg.Position().MsgID)
ret, err := lastMsgID.LessOrEqualThan(tsMsg.GetPosition().MsgID)
assert.NoError(t, err)
if ret {
hasMore = false
@ -674,20 +674,26 @@ func TestStream_PulsarTtMsgStream_Seek(t *testing.T) {
assert.Equal(t, len(seekMsg.Msgs), 3)
result := []uint64{14, 12, 13}
for i, msg := range seekMsg.Msgs {
assert.Equal(t, msg.BeginTs(), result[i])
tsMsg, err := msg.Unmarshal(outputStream.GetUnmarshalDispatcher())
require.NoError(t, err)
assert.Equal(t, tsMsg.BeginTs(), result[i])
}
seekMsg2 := consumer(ctx, outputStream)
assert.Equal(t, len(seekMsg2.Msgs), 1)
for _, msg := range seekMsg2.Msgs {
assert.Equal(t, msg.BeginTs(), uint64(19))
tsMsg, err := msg.Unmarshal(outputStream.GetUnmarshalDispatcher())
require.NoError(t, err)
assert.Equal(t, tsMsg.BeginTs(), uint64(19))
}
outputStream2 := getPulsarTtOutputStreamAndSeek(ctx, pulsarAddress, receivedMsg3.EndPositions)
seekMsg = consumer(ctx, outputStream2)
assert.Equal(t, len(seekMsg.Msgs), 1)
for _, msg := range seekMsg.Msgs {
assert.Equal(t, msg.BeginTs(), uint64(19))
tsMsg, err := msg.Unmarshal(outputStream.GetUnmarshalDispatcher())
require.NoError(t, err)
assert.Equal(t, tsMsg.BeginTs(), uint64(19))
}
inputStream.Close()
@ -882,9 +888,11 @@ func TestStream_PulsarTtMsgStream_1(t *testing.T) {
rcvMsg += len(msgPack.Msgs)
if len(msgPack.Msgs) > 0 {
for _, msg := range msgPack.Msgs {
log.Println("msg type: ", msg.Type(), ", msg value: ", msg)
assert.Greater(t, msg.BeginTs(), msgPack.BeginTs)
assert.LessOrEqual(t, msg.BeginTs(), msgPack.EndTs)
tsMsg, err := msg.Unmarshal(outputStream.GetUnmarshalDispatcher())
require.NoError(t, err)
log.Println("msg type: ", tsMsg.Type(), ", msg value: ", msg)
assert.Greater(t, tsMsg.BeginTs(), msgPack.BeginTs)
assert.LessOrEqual(t, tsMsg.BeginTs(), msgPack.EndTs)
}
log.Println("================")
}
@ -940,7 +948,7 @@ func TestStream_PulsarTtMsgStream_2(t *testing.T) {
// consume msg
log.Println("=============receive msg===================")
rcvMsgPacks := make([]*MsgPack, 0)
rcvMsgPacks := make([]*ConsumeMsgPack, 0)
resumeMsgPack := func(t *testing.T) int {
var outputStream MsgStream
@ -954,9 +962,11 @@ func TestStream_PulsarTtMsgStream_2(t *testing.T) {
rcvMsgPacks = append(rcvMsgPacks, msgPack)
if len(msgPack.Msgs) > 0 {
for _, msg := range msgPack.Msgs {
log.Println("msg type: ", msg.Type(), ", msg value: ", msg)
assert.Greater(t, msg.BeginTs(), msgPack.BeginTs)
assert.LessOrEqual(t, msg.BeginTs(), msgPack.EndTs)
tsMsg, err := msg.Unmarshal(outputStream.GetUnmarshalDispatcher())
require.NoError(t, err)
log.Println("msg type: ", tsMsg.Type(), ", msg value: ", msg)
assert.Greater(t, tsMsg.BeginTs(), msgPack.BeginTs)
assert.LessOrEqual(t, tsMsg.BeginTs(), msgPack.EndTs)
}
log.Println("================")
}
@ -998,7 +1008,7 @@ func TestStream_MqMsgStream_Seek(t *testing.T) {
var seekPosition *msgpb.MsgPosition
for i := 0; i < 10; i++ {
result := consumer(ctx, outputStream)
assert.Equal(t, result.Msgs[0].ID(), int64(i))
assert.Equal(t, result.Msgs[0].GetID(), int64(i))
if i == 5 {
seekPosition = result.EndPositions[0]
}
@ -1013,7 +1023,7 @@ func TestStream_MqMsgStream_Seek(t *testing.T) {
for i := 6; i < 10; i++ {
result := consumer(ctx, outputStream2)
assert.Equal(t, result.Msgs[0].ID(), int64(i))
assert.Equal(t, result.Msgs[0].GetID(), int64(i))
}
outputStream2.Close()
}
@ -1042,7 +1052,7 @@ func TestStream_MqMsgStream_SeekInvalidMessage(t *testing.T) {
var seekPosition *msgpb.MsgPosition
for i := 0; i < 10; i++ {
result := consumer(ctx, outputStream)
assert.Equal(t, result.Msgs[0].ID(), int64(i))
assert.Equal(t, result.Msgs[0].GetID(), int64(i))
seekPosition = result.EndPositions[0]
}
@ -1074,7 +1084,7 @@ func TestStream_MqMsgStream_SeekInvalidMessage(t *testing.T) {
err = inputStream.Produce(ctx, msgPack)
assert.NoError(t, err)
result := consumer(ctx, outputStream2)
assert.Equal(t, result.Msgs[0].ID(), int64(1))
assert.Equal(t, result.Msgs[0].GetID(), int64(1))
}
func TestSTream_MqMsgStream_SeekBadMessageID(t *testing.T) {
@ -1101,7 +1111,7 @@ func TestSTream_MqMsgStream_SeekBadMessageID(t *testing.T) {
var seekPosition *msgpb.MsgPosition
for i := 0; i < 10; i++ {
result := consumer(ctx, outputStream)
assert.Equal(t, result.Msgs[0].ID(), int64(i))
assert.Equal(t, result.Msgs[0].GetID(), int64(i))
seekPosition = result.EndPositions[0]
}
@ -1179,7 +1189,7 @@ func TestStream_MqMsgStream_SeekLatest(t *testing.T) {
for i := 10; i < 20; i++ {
result := consumer(ctx, outputStream2)
assert.Equal(t, result.Msgs[0].ID(), int64(i))
assert.Equal(t, result.Msgs[0].GetID(), int64(i))
}
inputStream.Close()
@ -1570,7 +1580,7 @@ func receiveMsg(ctx context.Context, outputStream MsgStream, msgCount int) {
msgs := result.Msgs
for _, v := range msgs {
receiveCount++
log.Println("msg type: ", v.Type(), ", msg value: ", v)
log.Println("msg type: ", v.GetType(), ", msg value: ", v)
}
log.Println("================")
}

View File

@ -380,9 +380,9 @@ func TestStream_RmqTtMsgStream_DuplicatedIDs(t *testing.T) {
outputStream.Seek(ctx, receivedMsg.StartPositions, false)
seekMsg := consumer(ctx, outputStream)
assert.Equal(t, len(seekMsg.Msgs), 1+2)
assert.EqualValues(t, seekMsg.Msgs[0].BeginTs(), 1)
assert.Equal(t, commonpb.MsgType_CreateCollection, seekMsg.Msgs[1].Type())
assert.Equal(t, commonpb.MsgType_CreateCollection, seekMsg.Msgs[2].Type())
assert.EqualValues(t, seekMsg.Msgs[0].GetTimestamp(), 1)
assert.Equal(t, commonpb.MsgType_CreateCollection, seekMsg.Msgs[1].GetType())
assert.Equal(t, commonpb.MsgType_CreateCollection, seekMsg.Msgs[2].GetType())
inputStream.Close()
outputStream.Close()
@ -485,13 +485,17 @@ func TestStream_RmqTtMsgStream_Seek(t *testing.T) {
assert.Equal(t, len(seekMsg.Msgs), 3)
result := []uint64{14, 12, 13}
for i, msg := range seekMsg.Msgs {
assert.Equal(t, msg.BeginTs(), result[i])
tsMsg, err := msg.Unmarshal(outputStream.GetUnmarshalDispatcher())
require.NoError(t, err)
assert.Equal(t, tsMsg.BeginTs(), result[i])
}
seekMsg2 := consumer(ctx, outputStream)
assert.Equal(t, len(seekMsg2.Msgs), 1)
for _, msg := range seekMsg2.Msgs {
assert.Equal(t, msg.BeginTs(), uint64(19))
tsMsg, err := msg.Unmarshal(outputStream.GetUnmarshalDispatcher())
require.NoError(t, err)
assert.Equal(t, tsMsg.BeginTs(), uint64(19))
}
inputStream.Close()
@ -517,7 +521,7 @@ func TestStream_RMqMsgStream_SeekInvalidMessage(t *testing.T) {
var seekPosition *msgpb.MsgPosition
for i := 0; i < 10; i++ {
result := consumer(ctx, outputStream)
assert.Equal(t, result.Msgs[0].ID(), int64(i))
assert.Equal(t, result.Msgs[0].GetID(), int64(i))
seekPosition = result.EndPositions[0]
}
outputStream.Close()
@ -550,7 +554,7 @@ func TestStream_RMqMsgStream_SeekInvalidMessage(t *testing.T) {
assert.NoError(t, err)
result := consumer(ctx, outputStream2)
assert.Equal(t, result.Msgs[0].ID(), int64(1))
assert.Equal(t, result.Msgs[0].GetID(), int64(1))
inputStream.Close()
outputStream2.Close()
@ -585,7 +589,7 @@ func TestStream_RmqTtMsgStream_AsConsumerWithPosition(t *testing.T) {
pack := <-outputStream.Chan()
assert.NotNil(t, pack)
assert.Equal(t, 1, len(pack.Msgs))
assert.EqualValues(t, 1000, pack.Msgs[0].BeginTs())
assert.EqualValues(t, 1000, pack.Msgs[0].GetTimestamp())
inputStream.Close()
outputStream.Close()

View File

@ -47,6 +47,10 @@ type TsMsg interface {
BeginTs() Timestamp
EndTs() Timestamp
Type() MsgType
VChannel() string
// CollID returns the msg collection id,
// or 0 if it does not exist.
CollID() int64
SourceID() int64
HashKeys() []uint32
Marshal(TsMsg) (MarshalType, error)
@ -117,6 +121,14 @@ func (bm *BaseMsg) SetTs(ts uint64) {
bm.EndTimestamp = ts
}
func (it *BaseMsg) VChannel() string {
return ""
}
func (it *BaseMsg) CollID() int64 {
return 0
}
func convertToByteArray(input interface{}) ([]byte, error) {
switch output := input.(type) {
case []byte:
@ -157,6 +169,14 @@ func (it *InsertMsg) SourceID() int64 {
return it.Base.SourceID
}
func (it *InsertMsg) VChannel() string {
return it.ShardName
}
func (it *InsertMsg) CollID() int64 {
return it.GetCollectionID()
}
// Marshal is used to serialize a message pack to byte array
func (it *InsertMsg) Marshal(input TsMsg) (MarshalType, error) {
insertMsg := input.(*InsertMsg)
@ -343,6 +363,14 @@ func (dt *DeleteMsg) SourceID() int64 {
return dt.Base.SourceID
}
func (it *DeleteMsg) VChannel() string {
return it.ShardName
}
func (it *DeleteMsg) CollID() int64 {
return it.GetCollectionID()
}
// Marshal is used to serializing a message pack to byte array
func (dt *DeleteMsg) Marshal(input TsMsg) (MarshalType, error) {
deleteMsg := input.(*DeleteMsg)
@ -516,6 +544,10 @@ func (cc *CreateCollectionMsg) SourceID() int64 {
return cc.Base.SourceID
}
func (it *CreateCollectionMsg) CollID() int64 {
return it.GetCollectionID()
}
// Marshal is used to serializing a message pack to byte array
func (cc *CreateCollectionMsg) Marshal(input TsMsg) (MarshalType, error) {
createCollectionMsg := input.(*CreateCollectionMsg)
@ -580,6 +612,10 @@ func (dc *DropCollectionMsg) SourceID() int64 {
return dc.Base.SourceID
}
func (it *DropCollectionMsg) CollID() int64 {
return it.GetCollectionID()
}
// Marshal is used to serializing a message pack to byte array
func (dc *DropCollectionMsg) Marshal(input TsMsg) (MarshalType, error) {
dropCollectionMsg := input.(*DropCollectionMsg)
@ -644,6 +680,10 @@ func (cp *CreatePartitionMsg) SourceID() int64 {
return cp.Base.SourceID
}
func (it *CreatePartitionMsg) CollID() int64 {
return it.GetCollectionID()
}
// Marshal is used to serializing a message pack to byte array
func (cp *CreatePartitionMsg) Marshal(input TsMsg) (MarshalType, error) {
createPartitionMsg := input.(*CreatePartitionMsg)
@ -708,6 +748,10 @@ func (dp *DropPartitionMsg) SourceID() int64 {
return dp.Base.SourceID
}
func (it *DropPartitionMsg) CollID() int64 {
return it.GetCollectionID()
}
// Marshal is used to serializing a message pack to byte array
func (dp *DropPartitionMsg) Marshal(input TsMsg) (MarshalType, error) {
dropPartitionMsg := input.(*DropPartitionMsg)
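
The VChannel/CollID accessors added above give every TsMsg a uniform way to expose routing metadata, with BaseMsg's defaults ("" and 0) meaning "not applicable"; UnmarshaledMsg builds its GetVChannel/GetCollectionID on top of them. For example (assuming a decoded insert message im):

// im is a decoded *msgstream.InsertMsg
vchannel := im.VChannel() // returns im.ShardName, per this diff
collID := im.CollID()     // returns im.GetCollectionID()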

View File

@ -18,6 +18,10 @@ package msgstream
import (
"context"
"fmt"
"path/filepath"
"strconv"
"sync"
"go.uber.org/zap"
@ -25,6 +29,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/mq/common"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
@ -52,6 +57,238 @@ type MsgPack struct {
EndPositions []*MsgPosition
}
// ConsumeMsgPack represents a batch of messages on the consumer side
type ConsumeMsgPack struct {
BeginTs Timestamp
EndTs Timestamp
Msgs []ConsumeMsg
StartPositions []*MsgPosition
EndPositions []*MsgPosition
}
// ConsumeMsg is the message type carried by a ConsumeMsgPack.
// It exposes frequently needed properties without unmarshalling the payload.
type ConsumeMsg interface {
GetPosition() *msgpb.MsgPosition
SetPosition(*msgpb.MsgPosition)
GetSize() int
GetTimestamp() uint64
GetVChannel() string
GetPChannel() string
GetMessageID() []byte
GetID() int64
GetCollectionID() string
GetType() commonpb.MsgType
GetReplicateID() string
SetTraceCtx(ctx context.Context)
Unmarshal(unmarshalDispatcher UnmarshalDispatcher) (TsMsg, error)
}
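With these accessors a consumer can filter on message metadata and defer protobuf unmarshalling to the survivors. A minimal sketch, assuming it sits inside the msgstream package (filterInserts and the vchannel argument are hypothetical, not part of this patch):

// filterInserts keeps only insert messages for one vchannel and
// unmarshals just those, skipping the payload cost for the rest.
func filterInserts(pack *ConsumeMsgPack, dispatcher UnmarshalDispatcher, vchannel string) ([]TsMsg, error) {
	kept := make([]TsMsg, 0, len(pack.Msgs))
	for _, m := range pack.Msgs {
		// Property reads only; the payload stays marshaled here.
		if m.GetType() != commonpb.MsgType_Insert || m.GetVChannel() != vchannel {
			continue
		}
		tsMsg, err := m.Unmarshal(dispatcher)
		if err != nil {
			return nil, err
		}
		kept = append(kept, tsMsg)
	}
	return kept, nil
}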
// UnmarshaledMsg wraps an already unmarshalled TsMsg as a ConsumeMsg,
// kept for compatibility and for tests
type UnmarshaledMsg struct {
msg TsMsg
}
func (m *UnmarshaledMsg) GetTimestamp() uint64 {
return m.msg.BeginTs()
}
func (m *UnmarshaledMsg) GetVChannel() string {
return m.msg.VChannel()
}
func (m *UnmarshaledMsg) GetPChannel() string {
return m.msg.Position().GetChannelName()
}
func (m *UnmarshaledMsg) GetMessageID() []byte {
return m.msg.Position().GetMsgID()
}
func (m *UnmarshaledMsg) GetID() int64 {
return m.msg.ID()
}
func (m *UnmarshaledMsg) GetCollectionID() string {
return strconv.FormatInt(m.msg.CollID(), 10)
}
func (m *UnmarshaledMsg) GetType() commonpb.MsgType {
return m.msg.Type()
}
func (m *UnmarshaledMsg) GetSize() int {
return m.msg.Size()
}
func (m *UnmarshaledMsg) GetReplicateID() string {
msgBase, ok := m.msg.(interface{ GetBase() *commonpb.MsgBase })
if !ok {
log.Warn("fail to get msg base, please check it", zap.Any("type", m.msg.Type()))
return ""
}
return msgBase.GetBase().GetReplicateInfo().GetReplicateID()
}
func (m *UnmarshaledMsg) SetPosition(pos *msgpb.MsgPosition) {
m.msg.SetPosition(pos)
}
func (m *UnmarshaledMsg) GetPosition() *msgpb.MsgPosition {
return m.msg.Position()
}
func (m *UnmarshaledMsg) SetTraceCtx(ctx context.Context) {
m.msg.SetTraceCtx(ctx)
}
func (m *UnmarshaledMsg) Unmarshal(unmarshalDispatcher UnmarshalDispatcher) (TsMsg, error) {
return m.msg, nil
}
// MarshaledMsg wraps a still-marshaled message and serves its
// metadata from the properties parsed at construction time
type MarshaledMsg struct {
msg common.Message
pos *MsgPosition
msgType MsgType
msgID int64
timestamp uint64
vchannel string
collectionID string
replicateID string
traceCtx context.Context
}
func (m *MarshaledMsg) GetTimestamp() uint64 {
return m.timestamp
}
func (m *MarshaledMsg) GetVChannel() string {
return m.vchannel
}
func (m *MarshaledMsg) GetPChannel() string {
return filepath.Base(m.msg.Topic())
}
func (m *MarshaledMsg) GetMessageID() []byte {
return m.msg.ID().Serialize()
}
func (m *MarshaledMsg) GetID() int64 {
return m.msgID
}
func (m *MarshaledMsg) GetCollectionID() string {
return m.collectionID
}
func (m *MarshaledMsg) GetType() commonpb.MsgType {
return m.msgType
}
func (m *MarshaledMsg) GetSize() int {
return len(m.msg.Payload())
}
func (m *MarshaledMsg) GetReplicateID() string {
return m.replicateID
}
func (m *MarshaledMsg) SetPosition(pos *msgpb.MsgPosition) {
m.pos = pos
}
func (m *MarshaledMsg) GetPosition() *msgpb.MsgPosition {
return m.pos
}
func (m *MarshaledMsg) SetTraceCtx(ctx context.Context) {
m.traceCtx = ctx
}
func (m *MarshaledMsg) Unmarshal(unmarshalDispatcher UnmarshalDispatcher) (TsMsg, error) {
tsMsg, err := GetTsMsgFromConsumerMsg(unmarshalDispatcher, m.msg)
if err != nil {
return nil, err
}
tsMsg.SetTraceCtx(m.traceCtx)
tsMsg.SetPosition(m.pos)
return tsMsg, nil
}
func NewMarshaledMsg(msg common.Message, group string) (ConsumeMsg, error) {
properties := msg.Properties()
vchannel, ok := properties[common.ChannelTypeKey]
if !ok {
return nil, fmt.Errorf("get channel name from msg properties failed")
}
collID, ok := properties[common.CollectionIDTypeKey]
if !ok {
return nil, fmt.Errorf("get collection ID from msg properties failed")
}
tsStr, ok := properties[common.TimestampTypeKey]
if !ok {
return nil, fmt.Errorf("get minTs from msg properties failed")
}
timestamp, err := strconv.ParseUint(tsStr, 10, 64)
if err != nil {
log.Warn("parse message properties minTs failed, unknown message", zap.Error(err))
return nil, fmt.Errorf("parse minTs from msg properties failed")
}
idStr, ok := properties[common.MsgIdTypeKey]
if !ok {
return nil, fmt.Errorf("get msgType from msg properties failed")
}
id, err := strconv.ParseInt(idStr, 10, 64)
if err != nil {
log.Warn("parse message properties minTs failed, unknown message", zap.Error(err))
return nil, fmt.Errorf("parse minTs from msg properties failed")
}
val, ok := properties[common.MsgTypeKey]
if !ok {
return nil, fmt.Errorf("get msgType from msg properties failed")
}
msgType := commonpb.MsgType(commonpb.MsgType_value[val])
result := &MarshaledMsg{
msg: msg,
msgID: id,
collectionID: collID,
timestamp: timestamp,
msgType: msgType,
vchannel: vchannel,
}
replicateID, ok := properties[common.ReplicateIDTypeKey]
if ok {
result.replicateID = replicateID
}
return result, nil
}
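NewMarshaledMsg only succeeds when the producer attached these properties; messages written before this change carry none. A caller would typically fall back to eager unmarshalling, along these lines (a sketch, not part of the patch; wrapConsumed is hypothetical):

// wrapConsumed prefers the cheap property-backed wrapper and falls
// back to a full unmarshal for legacy messages without properties.
func wrapConsumed(msg common.Message, group string, dispatcher UnmarshalDispatcher) (ConsumeMsg, error) {
	if wrapped, err := NewMarshaledMsg(msg, group); err == nil {
		return wrapped, nil
	}
	return UnmarshalMsg(msg, dispatcher)
}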
// UnmarshalMsg unmarshals a common.Message and wraps it as an UnmarshaledMsg
func UnmarshalMsg(msg common.Message, unmarshalDispatcher UnmarshalDispatcher) (ConsumeMsg, error) {
tsMsg, err := GetTsMsgFromConsumerMsg(unmarshalDispatcher, msg)
if err != nil {
return nil, err
}
return &UnmarshaledMsg{
msg: tsMsg,
}, nil
}
// RepackFunc is a function type which used to repack message after hash by primary key
type RepackFunc func(msgs []TsMsg, hashKeys [][]int32) (map[int32]*MsgPack, error)
@ -66,7 +303,8 @@ type MsgStream interface {
Broadcast(context.Context, *MsgPack) (map[string][]MessageID, error)
AsConsumer(ctx context.Context, channels []string, subName string, position common.SubscriptionInitialPosition) error
Chan() <-chan *MsgPack
Chan() <-chan *ConsumeMsgPack
GetUnmarshalDispatcher() UnmarshalDispatcher
// Seek consume message from the specified position
// includeCurrentMsg indicates whether to consume the current message, and in the milvus system, it should be always false
Seek(ctx context.Context, msgPositions []*MsgPosition, includeCurrentMsg bool) error
@ -117,6 +355,15 @@ func GetReplicateID(msg TsMsg) string {
return msgBase.GetBase().GetReplicateInfo().GetReplicateID()
}
func GetTimestamp(msg TsMsg) uint64 {
msgBase, ok := msg.(interface{ GetBase() *commonpb.MsgBase })
if !ok {
log.Warn("fail to get msg base, please check it", zap.Any("type", msg.Type()))
return 0
}
return msgBase.GetBase().GetTimestamp()
}
func MatchReplicateID(msg TsMsg, replicateID string) bool {
return GetReplicateID(msg) == replicateID
}
@ -126,3 +373,81 @@ type Factory interface {
NewTtMsgStream(ctx context.Context) (MsgStream, error)
NewMsgStreamDisposer(ctx context.Context) func([]string, string) error
}
// SimpleMsgDispatcher filters and unmarshals ts messages from a temporary stream
type SimpleMsgDispatcher struct {
stream MsgStream
unmarshalDispatcher UnmarshalDispatcher
filter func(ConsumeMsg) bool
ch chan *MsgPack
chOnce sync.Once
closeCh chan struct{}
closeOnce sync.Once
wg sync.WaitGroup
}
func NewSimpleMsgDispatcher(stream MsgStream, filter func(ConsumeMsg) bool) *SimpleMsgDispatcher {
return &SimpleMsgDispatcher{
stream: stream,
filter: filter,
unmarshalDispatcher: stream.GetUnmarshalDispatcher(),
closeCh: make(chan struct{}),
}
}
func (p *SimpleMsgDispatcher) filterAndParse() {
defer func() {
close(p.ch)
p.wg.Done()
}()
for {
select {
case <-p.closeCh:
return
case marshalPack, ok := <-p.stream.Chan():
if !ok {
log.Warn("dispatcher fail to read delta msg")
return
}
msgPack := &MsgPack{
BeginTs: marshalPack.BeginTs,
EndTs: marshalPack.EndTs,
Msgs: make([]TsMsg, 0),
StartPositions: marshalPack.StartPositions,
EndPositions: marshalPack.EndPositions,
}
for _, marshalMsg := range marshalPack.Msgs {
if !p.filter(marshalMsg) {
continue
}
// unmarshal message
msg, err := marshalMsg.Unmarshal(p.unmarshalDispatcher)
if err != nil {
log.Warn("unmarshal message failed, invalid message", zap.Error(err))
continue
}
msgPack.Msgs = append(msgPack.Msgs, msg)
}
p.ch <- msgPack
}
}
}
func (p *SimpleMsgDispatcher) Chan() chan *MsgPack {
p.chOnce.Do(func() {
p.ch = make(chan *MsgPack, paramtable.Get().MQCfg.ReceiveBufSize.GetAsInt64())
p.wg.Add(1)
go p.filterAndParse()
})
return p.ch
}
func (p *SimpleMsgDispatcher) Close() {
p.closeOnce.Do(func() {
p.stream.Close()
close(p.closeCh)
p.wg.Wait()
})
}
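Putting the pieces together, a temporary consumer that only wants delete messages from one vchannel could look like the sketch below (stream, vchannel, and handleDelete are placeholder names; stream construction is elided):

// Consume through SimpleMsgDispatcher: the filter runs on metadata,
// so non-matching messages are never unmarshalled.
dispatcher := NewSimpleMsgDispatcher(stream, func(m ConsumeMsg) bool {
	return m.GetType() == commonpb.MsgType_Delete && m.GetVChannel() == vchannel
})
defer dispatcher.Close()

for pack := range dispatcher.Chan() {
	for _, msg := range pack.Msgs { // filtered, already-unmarshalled TsMsg
		handleDelete(msg.(*DeleteMsg))
	}
}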

View File

@ -20,10 +20,13 @@ import (
"context"
"fmt"
"math/rand"
"strconv"
"github.com/confluentinc/confluent-kafka-go/kafka"
"github.com/samber/lo"
"go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
pcommon "github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/mq/common"
@ -141,3 +144,32 @@ func KafkaHealthCheck(clusterStatus *pcommon.MQClusterStatus) {
clusterStatus.Health = true
clusterStatus.Members = healthList
}
func GetPorperties(msg TsMsg) map[string]string {
properties := map[string]string{}
properties[common.ChannelTypeKey] = msg.VChannel()
properties[common.MsgTypeKey] = msg.Type().String()
properties[common.MsgIdTypeKey] = strconv.FormatInt(msg.ID(), 10)
properties[common.CollectionIDTypeKey] = strconv.FormatInt(msg.CollID(), 10)
msgBase, ok := msg.(interface{ GetBase() *commonpb.MsgBase })
if ok {
properties[common.TimestampTypeKey] = strconv.FormatUint(msgBase.GetBase().GetTimestamp(), 10)
if msgBase.GetBase().GetReplicateInfo() != nil {
properties[common.ReplicateIDTypeKey] = msgBase.GetBase().GetReplicateInfo().GetReplicateID()
}
}
return properties
}
func BuildConsumeMsgPack(pack *MsgPack) *ConsumeMsgPack {
return &ConsumeMsgPack{
BeginTs: pack.BeginTs,
EndTs: pack.EndTs,
Msgs: lo.Map(pack.Msgs, func(msg TsMsg, _ int) ConsumeMsg {
return &UnmarshaledMsg{msg: msg}
}),
StartPositions: pack.StartPositions,
EndPositions: pack.EndPositions,
}
}
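These two helpers are the halves of a round trip: GetPorperties stamps the metadata at produce time and NewMarshaledMsg reads it back on the consumer side. A rough sketch of the producer half, reusing convertToByteArray from above and assuming the common.ProducerMessage shape used by the mq producers (the send call itself is elided):

// Attach the metadata properties when producing so consumers can
// filter by type/vchannel/collection without touching the payload.
marshaled, err := msg.Marshal(msg)
if err != nil {
	return err
}
payload, err := convertToByteArray(marshaled)
if err != nil {
	return err
}
producerMsg := &common.ProducerMessage{
	Payload:    payload,
	Properties: GetPorperties(msg),
}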

View File

@ -29,21 +29,17 @@ import (
// ExtractCtx extracts trace span from msg.properties.
// And it will attach some default tags to the span.
func ExtractCtx(msg TsMsg, properties map[string]string) (context.Context, trace.Span) {
ctx := msg.TraceCtx()
if ctx == nil {
ctx = context.Background()
}
func ExtractCtx(msg ConsumeMsg, properties map[string]string) (context.Context, trace.Span) {
ctx := context.Background()
if !allowTrace(msg) {
return ctx, trace.SpanFromContext(ctx)
}
ctx = otel.GetTextMapPropagator().Extract(ctx, propagation.MapCarrier(properties))
name := "ReceieveMsg"
return otel.Tracer(name).Start(ctx, name, trace.WithAttributes(
attribute.Int64("ID", msg.ID()),
attribute.String("Type", msg.Type().String()),
// attribute.Int64Value("HashKeys", msg.HashKeys()),
attribute.String("Position", msg.Position().String()),
attribute.Int64("ID", msg.GetID()),
attribute.String("Type", msg.GetType().String()),
attribute.String("Position", msg.GetPosition().String()),
))
}
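Since the receive span now starts from a ConsumeMsg, tracing no longer forces an unmarshal. A consumer loop can extract the context straight from the raw message properties and hand it to the wrapper, roughly as below (consumeMsg and rawProperties are placeholders for the wrapped message and the MQ message's property map):

// Start the receive span from raw properties, then store the context
// on the wrapper so a later Unmarshal propagates it to the TsMsg.
ctx, span := ExtractCtx(consumeMsg, rawProperties)
defer span.End()
consumeMsg.SetTraceCtx(ctx)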

View File

@ -7,7 +7,7 @@ type WastedMockMsgStream struct {
AsProducerFunc func(channels []string)
BroadcastMarkFunc func(*MsgPack) (map[string][]MessageID, error)
BroadcastFunc func(*MsgPack) error
ChanFunc func() <-chan *MsgPack
ChanFunc func() <-chan *ConsumeMsgPack
}
func NewWastedMockMsgStream() *WastedMockMsgStream {
@ -22,6 +22,6 @@ func (m WastedMockMsgStream) Broadcast(ctx context.Context, pack *MsgPack) (map[
return m.BroadcastMarkFunc(pack)
}
func (m WastedMockMsgStream) Chan() <-chan *MsgPack {
func (m WastedMockMsgStream) Chan() <-chan *ConsumeMsgPack {
return m.ChanFunc()
}