mirror of
https://gitee.com/milvus-io/milvus.git
synced 2025-12-07 01:28:27 +08:00
enhance: unmashall ts msg in dispatcher instead in msgstream (#38656)
relate: https://github.com/milvus-io/milvus/issues/38655 --------- Signed-off-by: aoiasd <zhicheng.yue@zilliz.com>
This commit is contained in:
parent
58045a3396
commit
24d2bbc441
@ -58,7 +58,7 @@ Analyze(CAnalyze* res_analyze,
|
||||
auto analyze_info =
|
||||
std::make_unique<milvus::proto::clustering::AnalyzeInfo>();
|
||||
auto res = analyze_info->ParseFromArray(serialized_analyze_info, len);
|
||||
AssertInfo(res, "Unmarshall analyze info failed");
|
||||
AssertInfo(res, "Unmarshal analyze info failed");
|
||||
auto field_type =
|
||||
static_cast<DataType>(analyze_info->field_schema().data_type());
|
||||
auto field_id = analyze_info->field_schema().fieldid();
|
||||
|
||||
@ -161,7 +161,7 @@ CreateIndex(CIndex* res_index,
|
||||
std::make_unique<milvus::proto::indexcgo::BuildIndexInfo>();
|
||||
auto res =
|
||||
build_index_info->ParseFromArray(serialized_build_index_info, len);
|
||||
AssertInfo(res, "Unmarshall build index info failed");
|
||||
AssertInfo(res, "Unmarshal build index info failed");
|
||||
|
||||
auto field_type =
|
||||
static_cast<DataType>(build_index_info->field_schema().data_type());
|
||||
@ -233,7 +233,7 @@ BuildTextIndex(ProtoLayoutInterface result,
|
||||
std::make_unique<milvus::proto::indexcgo::BuildIndexInfo>();
|
||||
auto res =
|
||||
build_index_info->ParseFromArray(serialized_build_index_info, len);
|
||||
AssertInfo(res, "Unmarshall build index info failed");
|
||||
AssertInfo(res, "Unmarshal build index info failed");
|
||||
|
||||
auto field_type =
|
||||
static_cast<DataType>(build_index_info->field_schema().data_type());
|
||||
@ -606,7 +606,7 @@ AppendBuildIndexParam(CBuildIndexInfo c_build_index_info,
|
||||
auto index_params =
|
||||
std::make_unique<milvus::proto::indexcgo::IndexParams>();
|
||||
auto res = index_params->ParseFromArray(serialized_index_params, len);
|
||||
AssertInfo(res, "Unmarshall index params failed");
|
||||
AssertInfo(res, "Unmarshal index params failed");
|
||||
for (auto i = 0; i < index_params->params_size(); ++i) {
|
||||
const auto& param = index_params->params(i);
|
||||
build_index_info->config[param.key()] = param.value();
|
||||
@ -633,7 +633,7 @@ AppendBuildTypeParam(CBuildIndexInfo c_build_index_info,
|
||||
auto type_params =
|
||||
std::make_unique<milvus::proto::indexcgo::TypeParams>();
|
||||
auto res = type_params->ParseFromArray(serialized_type_params, len);
|
||||
AssertInfo(res, "Unmarshall index build type params failed");
|
||||
AssertInfo(res, "Unmarshal index build type params failed");
|
||||
for (auto i = 0; i < type_params->params_size(); ++i) {
|
||||
const auto& param = type_params->params(i);
|
||||
build_index_info->config[param.key()] = param.value();
|
||||
|
||||
@ -30,7 +30,7 @@ ValidateIndexParams(const char* index_type,
|
||||
std::make_unique<milvus::proto::indexcgo::IndexParams>();
|
||||
auto res =
|
||||
index_params->ParseFromArray(serialized_index_params, length);
|
||||
AssertInfo(res, "Unmarshall index params failed");
|
||||
AssertInfo(res, "Unmarshal index params failed");
|
||||
|
||||
knowhere::Json json;
|
||||
|
||||
|
||||
@ -308,7 +308,7 @@ type DataSyncServiceSuite struct {
|
||||
channelCheckpointUpdater *util2.ChannelCheckpointUpdater
|
||||
factory *dependency.MockFactory
|
||||
ms *msgstream.MockMsgStream
|
||||
msChan chan *msgstream.MsgPack
|
||||
msChan chan *msgstream.ConsumeMsgPack
|
||||
}
|
||||
|
||||
func (s *DataSyncServiceSuite) SetupSuite() {
|
||||
@ -330,7 +330,7 @@ func (s *DataSyncServiceSuite) SetupTest() {
|
||||
s.channelCheckpointUpdater = util2.NewChannelCheckpointUpdater(s.broker)
|
||||
|
||||
go s.channelCheckpointUpdater.Start()
|
||||
s.msChan = make(chan *msgstream.MsgPack, 1)
|
||||
s.msChan = make(chan *msgstream.ConsumeMsgPack, 1)
|
||||
|
||||
s.factory = dependency.NewMockFactory(s.T())
|
||||
s.ms = msgstream.NewMockMsgStream(s.T())
|
||||
@ -338,6 +338,7 @@ func (s *DataSyncServiceSuite) SetupTest() {
|
||||
s.ms.EXPECT().AsConsumer(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil)
|
||||
s.ms.EXPECT().Chan().Return(s.msChan)
|
||||
s.ms.EXPECT().Close().Return()
|
||||
s.ms.EXPECT().GetUnmarshalDispatcher().Return(nil)
|
||||
|
||||
s.pipelineParams = &util2.PipelineParams{
|
||||
Ctx: context.TODO(),
|
||||
@ -487,8 +488,8 @@ func (s *DataSyncServiceSuite) TestStartStop() {
|
||||
close(ch)
|
||||
return nil
|
||||
})
|
||||
s.msChan <- &msgPack
|
||||
s.msChan <- &timeTickMsgPack
|
||||
s.msChan <- msgstream.BuildConsumeMsgPack(&msgPack)
|
||||
s.msChan <- msgstream.BuildConsumeMsgPack(&msgPack)
|
||||
<-ch
|
||||
}
|
||||
|
||||
|
||||
@ -67,8 +67,8 @@ func (mtm *mockTtMsgStream) SetReplicate(config *msgstream.ReplicateConfig) {
|
||||
|
||||
func (mtm *mockTtMsgStream) Close() {}
|
||||
|
||||
func (mtm *mockTtMsgStream) Chan() <-chan *msgstream.MsgPack {
|
||||
return make(chan *msgstream.MsgPack, 100)
|
||||
func (mtm *mockTtMsgStream) Chan() <-chan *msgstream.ConsumeMsgPack {
|
||||
return make(chan *msgstream.ConsumeMsgPack, 100)
|
||||
}
|
||||
|
||||
func (mtm *mockTtMsgStream) AsProducer(ctx context.Context, channels []string) {}
|
||||
@ -77,6 +77,10 @@ func (mtm *mockTtMsgStream) AsConsumer(ctx context.Context, channels []string, s
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mtm *mockTtMsgStream) GetUnmarshalDispatcher() msgstream.UnmarshalDispatcher {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mtm *mockTtMsgStream) SetRepackFunc(repackFunc msgstream.RepackFunc) {}
|
||||
|
||||
func (mtm *mockTtMsgStream) GetProduceChannels() []string {
|
||||
|
||||
@ -235,7 +235,7 @@ func newDefaultMockDqlTask() *mockDqlTask {
|
||||
}
|
||||
|
||||
type simpleMockMsgStream struct {
|
||||
msgChan chan *msgstream.MsgPack
|
||||
msgChan chan *msgstream.ConsumeMsgPack
|
||||
|
||||
msgCount int
|
||||
msgCountMtx sync.RWMutex
|
||||
@ -244,7 +244,7 @@ type simpleMockMsgStream struct {
|
||||
func (ms *simpleMockMsgStream) Close() {
|
||||
}
|
||||
|
||||
func (ms *simpleMockMsgStream) Chan() <-chan *msgstream.MsgPack {
|
||||
func (ms *simpleMockMsgStream) Chan() <-chan *msgstream.ConsumeMsgPack {
|
||||
if ms.getMsgCount() <= 0 {
|
||||
ms.msgChan <- nil
|
||||
return ms.msgChan
|
||||
@ -255,6 +255,10 @@ func (ms *simpleMockMsgStream) Chan() <-chan *msgstream.MsgPack {
|
||||
return ms.msgChan
|
||||
}
|
||||
|
||||
func (ms *simpleMockMsgStream) GetUnmarshalDispatcher() msgstream.UnmarshalDispatcher {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ms *simpleMockMsgStream) AsProducer(ctx context.Context, channels []string) {
|
||||
}
|
||||
|
||||
@ -286,8 +290,7 @@ func (ms *simpleMockMsgStream) decreaseMsgCount(delta int) {
|
||||
func (ms *simpleMockMsgStream) Produce(ctx context.Context, pack *msgstream.MsgPack) error {
|
||||
defer ms.increaseMsgCount(1)
|
||||
|
||||
ms.msgChan <- pack
|
||||
|
||||
ms.msgChan <- msgstream.BuildConsumeMsgPack(pack)
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -319,7 +322,7 @@ func (ms *simpleMockMsgStream) SetReplicate(config *msgstream.ReplicateConfig) {
|
||||
|
||||
func newSimpleMockMsgStream() *simpleMockMsgStream {
|
||||
return &simpleMockMsgStream{
|
||||
msgChan: make(chan *msgstream.MsgPack, 1024),
|
||||
msgChan: make(chan *msgstream.ConsumeMsgPack, 1024),
|
||||
msgCount: 0,
|
||||
}
|
||||
}
|
||||
|
||||
@ -728,7 +728,15 @@ func (sd *shardDelegator) createStreamFromMsgStream(ctx context.Context, positio
|
||||
if err != nil {
|
||||
return nil, stream.Close, err
|
||||
}
|
||||
return stream.Chan(), stream.Close, nil
|
||||
|
||||
dispatcher := msgstream.NewSimpleMsgDispatcher(stream, func(pm msgstream.ConsumeMsg) bool {
|
||||
if pm.GetType() != commonpb.MsgType_Delete || pm.GetVChannel() != vchannelName {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
return dispatcher.Chan(), dispatcher.Close, nil
|
||||
}
|
||||
|
||||
func (sd *shardDelegator) createDeleteStreamFromStreamingService(ctx context.Context, position *msgpb.MsgPosition) (ch <-chan *msgstream.MsgPack, closer func(), err error) {
|
||||
|
||||
@ -207,6 +207,8 @@ func (s *DelegatorDataSuite) SetupTest() {
|
||||
// init schema
|
||||
s.genNormalCollection()
|
||||
s.mq = &msgstream.MockMsgStream{}
|
||||
s.mq.EXPECT().GetUnmarshalDispatcher().Return(nil)
|
||||
|
||||
s.rootPath = s.Suite.T().Name()
|
||||
chunkManagerFactory := storage.NewTestChunkManagerFactory(paramtable.Get(), s.rootPath)
|
||||
s.chunkManager, _ = chunkManagerFactory.NewPersistentStorageChunkManager(context.Background())
|
||||
@ -916,8 +918,9 @@ func (s *DelegatorDataSuite) TestLoadSegments() {
|
||||
|
||||
s.mq.EXPECT().AsConsumer(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil)
|
||||
s.mq.EXPECT().Seek(mock.Anything, mock.Anything, mock.Anything).Return(nil)
|
||||
s.mq.EXPECT().GetUnmarshalDispatcher().Return(nil)
|
||||
s.mq.EXPECT().Close()
|
||||
ch := make(chan *msgstream.MsgPack, 10)
|
||||
ch := make(chan *msgstream.ConsumeMsgPack, 10)
|
||||
close(ch)
|
||||
|
||||
s.mq.EXPECT().Chan().Return(ch)
|
||||
@ -1585,7 +1588,7 @@ func (s *DelegatorDataSuite) TestReadDeleteFromMsgstream() {
|
||||
s.mq.EXPECT().AsConsumer(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil)
|
||||
s.mq.EXPECT().Seek(mock.Anything, mock.Anything, mock.Anything).Return(nil)
|
||||
s.mq.EXPECT().Close()
|
||||
ch := make(chan *msgstream.MsgPack, 10)
|
||||
ch := make(chan *msgstream.ConsumeMsgPack, 10)
|
||||
s.mq.EXPECT().Chan().Return(ch)
|
||||
|
||||
oracle := pkoracle.NewBloomFilterSet(1, 1, commonpb.SegmentState_Sealed)
|
||||
@ -1603,7 +1606,7 @@ func (s *DelegatorDataSuite) TestReadDeleteFromMsgstream() {
|
||||
}
|
||||
|
||||
for _, data := range datas {
|
||||
ch <- data
|
||||
ch <- msgstream.BuildConsumeMsgPack(data)
|
||||
}
|
||||
|
||||
result, err := s.delegator.readDeleteFromMsgstream(ctx, &msgpb.MsgPosition{Timestamp: 0}, 10, oracle)
|
||||
|
||||
@ -64,7 +64,7 @@ import (
|
||||
type ServiceSuite struct {
|
||||
suite.Suite
|
||||
// Data
|
||||
msgChan chan *msgstream.MsgPack
|
||||
msgChan chan *msgstream.ConsumeMsgPack
|
||||
collectionID int64
|
||||
collectionName string
|
||||
schema *schemapb.CollectionSchema
|
||||
|
||||
@ -24,6 +24,7 @@ import (
|
||||
"github.com/cockroachdb/errors"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
@ -250,7 +251,7 @@ func Test_alterCollectionTask_Execute(t *testing.T) {
|
||||
broker.BroadcastAlteredCollectionFunc = func(ctx context.Context, req *milvuspb.AlterCollectionRequest) error {
|
||||
return nil
|
||||
}
|
||||
packChan := make(chan *msgstream.MsgPack, 10)
|
||||
packChan := make(chan *msgstream.ConsumeMsgPack, 10)
|
||||
ticker := newChanTimeTickSync(packChan)
|
||||
ticker.addDmlChannels("by-dev-rootcoord-dml_1")
|
||||
|
||||
@ -268,13 +269,18 @@ func Test_alterCollectionTask_Execute(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
unmarshalFactory := &msgstream.ProtoUDFactory{}
|
||||
unmarshalDispatcher := unmarshalFactory.NewUnmarshalDispatcher()
|
||||
|
||||
err := task.Execute(context.Background())
|
||||
assert.NoError(t, err)
|
||||
time.Sleep(time.Second)
|
||||
select {
|
||||
case pack := <-packChan:
|
||||
assert.Equal(t, commonpb.MsgType_Replicate, pack.Msgs[0].Type())
|
||||
replicateMsg := pack.Msgs[0].(*msgstream.ReplicateMsg)
|
||||
assert.Equal(t, commonpb.MsgType_Replicate, pack.Msgs[0].GetType())
|
||||
tsMsg, err := pack.Msgs[0].Unmarshal(unmarshalDispatcher)
|
||||
require.NoError(t, err)
|
||||
replicateMsg := tsMsg.(*msgstream.ReplicateMsg)
|
||||
assert.Equal(t, "foo", replicateMsg.ReplicateMsg.GetDatabase())
|
||||
assert.Equal(t, "cn", replicateMsg.ReplicateMsg.GetCollection())
|
||||
assert.True(t, replicateMsg.ReplicateMsg.GetIsEnd())
|
||||
|
||||
@ -24,6 +24,7 @@ import (
|
||||
"github.com/cockroachdb/errors"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus/internal/metastore/model"
|
||||
@ -236,7 +237,7 @@ func Test_alterDatabaseTask_Execute(t *testing.T) {
|
||||
mock.Anything,
|
||||
).Return(nil)
|
||||
// the chan length should larger than 4, because newChanTimeTickSync will send 4 ts messages when execute the `broadcast` step
|
||||
packChan := make(chan *msgstream.MsgPack, 10)
|
||||
packChan := make(chan *msgstream.ConsumeMsgPack, 10)
|
||||
ticker := newChanTimeTickSync(packChan)
|
||||
ticker.addDmlChannels("by-dev-rootcoord-dml_1")
|
||||
|
||||
@ -252,13 +253,19 @@ func Test_alterDatabaseTask_Execute(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
unmarshalFactory := &msgstream.ProtoUDFactory{}
|
||||
unmarshalDispatcher := unmarshalFactory.NewUnmarshalDispatcher()
|
||||
|
||||
err := task.Execute(context.Background())
|
||||
assert.NoError(t, err)
|
||||
time.Sleep(time.Second)
|
||||
select {
|
||||
case pack := <-packChan:
|
||||
assert.Equal(t, commonpb.MsgType_Replicate, pack.Msgs[0].Type())
|
||||
replicateMsg := pack.Msgs[0].(*msgstream.ReplicateMsg)
|
||||
assert.Equal(t, commonpb.MsgType_Replicate, pack.Msgs[0].GetType())
|
||||
|
||||
tsMsg, err := pack.Msgs[0].Unmarshal(unmarshalDispatcher)
|
||||
require.NoError(t, err)
|
||||
replicateMsg := tsMsg.(*msgstream.ReplicateMsg)
|
||||
assert.Equal(t, "cn", replicateMsg.ReplicateMsg.GetDatabase())
|
||||
assert.True(t, replicateMsg.ReplicateMsg.GetIsEnd())
|
||||
default:
|
||||
|
||||
@ -278,7 +278,8 @@ type FailMsgStream struct {
|
||||
}
|
||||
|
||||
func (ms *FailMsgStream) Close() {}
|
||||
func (ms *FailMsgStream) Chan() <-chan *msgstream.MsgPack { return nil }
|
||||
func (ms *FailMsgStream) Chan() <-chan *msgstream.ConsumeMsgPack { return nil }
|
||||
func (ms *FailMsgStream) GetUnmarshalDispatcher() msgstream.UnmarshalDispatcher { return nil }
|
||||
func (ms *FailMsgStream) AsProducer(ctx context.Context, channels []string) {}
|
||||
func (ms *FailMsgStream) AsReader(channels []string, subName string) {}
|
||||
func (ms *FailMsgStream) AsConsumer(ctx context.Context, channels []string, subName string, position common.SubscriptionInitialPosition) error {
|
||||
|
||||
@ -1058,23 +1058,23 @@ func newTickerWithFactory(factory msgstream.Factory) *timetickSync {
|
||||
return ticker
|
||||
}
|
||||
|
||||
func newChanTimeTickSync(packChan chan *msgstream.MsgPack) *timetickSync {
|
||||
func newChanTimeTickSync(packChan chan *msgstream.ConsumeMsgPack) *timetickSync {
|
||||
f := msgstream.NewMockMqFactory()
|
||||
f.NewMsgStreamFunc = func(ctx context.Context) (msgstream.MsgStream, error) {
|
||||
stream := msgstream.NewWastedMockMsgStream()
|
||||
stream.BroadcastFunc = func(pack *msgstream.MsgPack) error {
|
||||
log.Info("mock Broadcast")
|
||||
packChan <- pack
|
||||
packChan <- msgstream.BuildConsumeMsgPack(pack)
|
||||
return nil
|
||||
}
|
||||
stream.BroadcastMarkFunc = func(pack *msgstream.MsgPack) (map[string][]msgstream.MessageID, error) {
|
||||
log.Info("mock BroadcastMark")
|
||||
packChan <- pack
|
||||
packChan <- msgstream.BuildConsumeMsgPack(pack)
|
||||
return map[string][]msgstream.MessageID{}, nil
|
||||
}
|
||||
stream.AsProducerFunc = func(channels []string) {
|
||||
}
|
||||
stream.ChanFunc = func() <-chan *msgstream.MsgPack {
|
||||
stream.ChanFunc = func() <-chan *msgstream.ConsumeMsgPack {
|
||||
return packChan
|
||||
}
|
||||
return stream, nil
|
||||
|
||||
@ -28,6 +28,7 @@ import (
|
||||
"github.com/milvus-io/milvus/internal/util/dependency"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/mq/common"
|
||||
"github.com/milvus-io/milvus/pkg/mq/msgstream"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||
)
|
||||
@ -45,8 +46,9 @@ func TestInputNode(t *testing.T) {
|
||||
produceStream.AsProducer(context.TODO(), channels)
|
||||
produceStream.Produce(context.TODO(), &msgPack)
|
||||
|
||||
dispatcher := msgstream.NewSimpleMsgDispatcher(msgStream, func(pm msgstream.ConsumeMsg) bool { return true })
|
||||
nodeName := "input_node"
|
||||
inputNode := NewInputNode(msgStream.Chan(), nodeName, 100, 100, "", 0, 0, "")
|
||||
inputNode := NewInputNode(dispatcher.Chan(), nodeName, 100, 100, "", 0, 0, "")
|
||||
defer inputNode.Close()
|
||||
|
||||
isInputNode := inputNode.IsInputNode()
|
||||
@ -89,7 +91,8 @@ func Test_InputNodeSkipMode(t *testing.T) {
|
||||
outputCh := make(chan bool)
|
||||
|
||||
nodeName := "input_node"
|
||||
inputNode := NewInputNode(msgStream.Chan(), nodeName, 100, 100, typeutil.DataNodeRole, 0, 0, "")
|
||||
dispatcher := msgstream.NewSimpleMsgDispatcher(msgStream, func(pm msgstream.ConsumeMsg) bool { return true })
|
||||
inputNode := NewInputNode(dispatcher.Chan(), nodeName, 100, 100, typeutil.DataNodeRole, 0, 0, "")
|
||||
defer inputNode.Close()
|
||||
|
||||
outputCount := 0
|
||||
|
||||
@ -26,6 +26,7 @@ import (
|
||||
)
|
||||
|
||||
type MockMsg struct {
|
||||
*msgstream.BaseMsg
|
||||
Ctx context.Context
|
||||
}
|
||||
|
||||
|
||||
@ -89,7 +89,8 @@ func TestNodeManager_Start(t *testing.T) {
|
||||
produceStream.Produce(context.TODO(), &msgPack)
|
||||
|
||||
nodeName := "input_node"
|
||||
inputNode := NewInputNode(msgStream.Chan(), nodeName, 100, 100, "", 0, 0, "")
|
||||
dispatcher := msgstream.NewSimpleMsgDispatcher(msgStream, func(pm msgstream.ConsumeMsg) bool { return true })
|
||||
inputNode := NewInputNode(dispatcher.Chan(), nodeName, 100, 100, "", 0, 0, "")
|
||||
|
||||
ddNode := BaseNode{}
|
||||
|
||||
|
||||
@ -74,7 +74,14 @@ const (
|
||||
SubscriptionPositionUnknown
|
||||
)
|
||||
|
||||
const MsgTypeKey = "msg_type"
|
||||
const (
|
||||
MsgTypeKey = "msg_type"
|
||||
MsgIdTypeKey = "msg_id"
|
||||
TimestampTypeKey = "timestamp"
|
||||
ChannelTypeKey = "vchannel"
|
||||
CollectionIDTypeKey = "collection_id"
|
||||
ReplicateIDTypeKey = "replicate_id"
|
||||
)
|
||||
|
||||
func GetMsgType(msg Message) (commonpb.MsgType, error) {
|
||||
msgType := commonpb.MsgType_Undefined
|
||||
|
||||
@ -19,7 +19,6 @@ package msgdispatcher
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@ -229,7 +228,7 @@ func (d *Dispatcher) work() {
|
||||
}
|
||||
d.curTs.Store(pack.EndPositions[0].GetTimestamp())
|
||||
|
||||
targetPacks := d.groupingMsgs(pack)
|
||||
targetPacks := d.groupAndParseMsgs(pack, d.stream.GetUnmarshalDispatcher())
|
||||
for vchannel, p := range targetPacks {
|
||||
var err error
|
||||
t := d.targets[vchannel]
|
||||
@ -260,7 +259,7 @@ func (d *Dispatcher) work() {
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Dispatcher) groupingMsgs(pack *MsgPack) map[string]*MsgPack {
|
||||
func (d *Dispatcher) groupAndParseMsgs(pack *msgstream.ConsumeMsgPack, unmarshalDispatcher msgstream.UnmarshalDispatcher) map[string]*MsgPack {
|
||||
// init packs for all targets, even though there's no msg in pack,
|
||||
// but we still need to dispatch time ticks to the targets.
|
||||
targetPacks := make(map[string]*MsgPack)
|
||||
@ -280,27 +279,24 @@ func (d *Dispatcher) groupingMsgs(pack *MsgPack) map[string]*MsgPack {
|
||||
// group messages by vchannel
|
||||
for _, msg := range pack.Msgs {
|
||||
var vchannel, collectionID string
|
||||
switch msg.Type() {
|
||||
case commonpb.MsgType_Insert:
|
||||
vchannel = msg.(*msgstream.InsertMsg).GetShardName()
|
||||
case commonpb.MsgType_Delete:
|
||||
vchannel = msg.(*msgstream.DeleteMsg).GetShardName()
|
||||
case commonpb.MsgType_CreateCollection:
|
||||
collectionID = strconv.FormatInt(msg.(*msgstream.CreateCollectionMsg).GetCollectionID(), 10)
|
||||
case commonpb.MsgType_DropCollection:
|
||||
collectionID = strconv.FormatInt(msg.(*msgstream.DropCollectionMsg).GetCollectionID(), 10)
|
||||
case commonpb.MsgType_CreatePartition:
|
||||
collectionID = strconv.FormatInt(msg.(*msgstream.CreatePartitionMsg).GetCollectionID(), 10)
|
||||
case commonpb.MsgType_DropPartition:
|
||||
collectionID = strconv.FormatInt(msg.(*msgstream.DropPartitionMsg).GetCollectionID(), 10)
|
||||
|
||||
if msg.GetType() == commonpb.MsgType_Insert || msg.GetType() == commonpb.MsgType_Delete {
|
||||
vchannel = msg.GetVChannel()
|
||||
} else if msg.GetType() == commonpb.MsgType_CreateCollection ||
|
||||
msg.GetType() == commonpb.MsgType_DropCollection ||
|
||||
msg.GetType() == commonpb.MsgType_CreatePartition ||
|
||||
msg.GetType() == commonpb.MsgType_DropPartition {
|
||||
collectionID = msg.GetCollectionID()
|
||||
}
|
||||
|
||||
if vchannel == "" {
|
||||
// we need to dispatch it to the vchannel of this collection
|
||||
targets := []string{}
|
||||
for k := range targetPacks {
|
||||
if msg.Type() == commonpb.MsgType_Replicate {
|
||||
if msg.GetType() == commonpb.MsgType_Replicate {
|
||||
config := replicateConfigs[k]
|
||||
if config != nil && msgstream.MatchReplicateID(msg, config.ReplicateID) {
|
||||
targetPacks[k].Msgs = append(targetPacks[k].Msgs, msg)
|
||||
if config != nil && msg.GetReplicateID() == config.ReplicateID {
|
||||
targets = append(targets, k)
|
||||
}
|
||||
continue
|
||||
}
|
||||
@ -308,14 +304,29 @@ func (d *Dispatcher) groupingMsgs(pack *MsgPack) map[string]*MsgPack {
|
||||
if !strings.Contains(k, collectionID) {
|
||||
continue
|
||||
}
|
||||
targets = append(targets, k)
|
||||
}
|
||||
if len(targets) > 0 {
|
||||
tsMsg, err := msg.Unmarshal(unmarshalDispatcher)
|
||||
if err != nil {
|
||||
log.Warn("unmarshl message failed", zap.Error(err))
|
||||
continue
|
||||
}
|
||||
// TODO: There's data race when non-dml msg is sent to different flow graph.
|
||||
// Wrong open-trancing information is generated, Fix in future.
|
||||
targetPacks[k].Msgs = append(targetPacks[k].Msgs, msg)
|
||||
for _, target := range targets {
|
||||
targetPacks[target].Msgs = append(targetPacks[target].Msgs, tsMsg)
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
if _, ok := targetPacks[vchannel]; ok {
|
||||
targetPacks[vchannel].Msgs = append(targetPacks[vchannel].Msgs, msg)
|
||||
tsMsg, err := msg.Unmarshal(unmarshalDispatcher)
|
||||
if err != nil {
|
||||
log.Warn("unmarshl message failed", zap.Error(err))
|
||||
continue
|
||||
}
|
||||
targetPacks[vchannel].Msgs = append(targetPacks[vchannel].Msgs, tsMsg)
|
||||
}
|
||||
}
|
||||
replicateEndChannels := make(map[string]struct{})
|
||||
|
||||
@ -150,7 +150,7 @@ func TestGroupMessage(t *testing.T) {
|
||||
d.AddTarget(newTarget("mock_pchannel_0_2v0", nil, msgstream.GetReplicateConfig("local-test", "foo", "coo")))
|
||||
{
|
||||
// no replicate msg
|
||||
packs := d.groupingMsgs(&MsgPack{
|
||||
packs := d.groupAndParseMsgs(msgstream.BuildConsumeMsgPack(&MsgPack{
|
||||
BeginTs: 1,
|
||||
EndTs: 10,
|
||||
StartPositions: []*msgstream.MsgPosition{
|
||||
@ -182,13 +182,13 @@ func TestGroupMessage(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
}), nil)
|
||||
assert.Len(t, packs, 1)
|
||||
}
|
||||
|
||||
{
|
||||
// equal to replicateID
|
||||
packs := d.groupingMsgs(&MsgPack{
|
||||
packs := d.groupAndParseMsgs(msgstream.BuildConsumeMsgPack(&MsgPack{
|
||||
BeginTs: 1,
|
||||
EndTs: 10,
|
||||
StartPositions: []*msgstream.MsgPosition{
|
||||
@ -222,7 +222,7 @@ func TestGroupMessage(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
}), nil)
|
||||
assert.Len(t, packs, 2)
|
||||
{
|
||||
replicatePack := packs["mock_pchannel_0_2v0"]
|
||||
@ -244,7 +244,7 @@ func TestGroupMessage(t *testing.T) {
|
||||
|
||||
{
|
||||
// not equal to replicateID
|
||||
packs := d.groupingMsgs(&MsgPack{
|
||||
packs := d.groupAndParseMsgs(msgstream.BuildConsumeMsgPack(&MsgPack{
|
||||
BeginTs: 1,
|
||||
EndTs: 10,
|
||||
StartPositions: []*msgstream.MsgPosition{
|
||||
@ -278,7 +278,7 @@ func TestGroupMessage(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
}), nil)
|
||||
assert.Len(t, packs, 1)
|
||||
replicatePack := packs["mock_pchannel_0_2v0"]
|
||||
assert.Nil(t, replicatePack)
|
||||
@ -288,7 +288,7 @@ func TestGroupMessage(t *testing.T) {
|
||||
// replicate end
|
||||
replicateTarget := d.targets["mock_pchannel_0_2v0"]
|
||||
assert.NotNil(t, replicateTarget.replicateConfig)
|
||||
packs := d.groupingMsgs(&MsgPack{
|
||||
packs := d.groupAndParseMsgs(msgstream.BuildConsumeMsgPack(&MsgPack{
|
||||
BeginTs: 1,
|
||||
EndTs: 10,
|
||||
StartPositions: []*msgstream.MsgPosition{
|
||||
@ -324,7 +324,7 @@ func TestGroupMessage(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
}), nil)
|
||||
assert.Len(t, packs, 2)
|
||||
replicatePack := packs["mock_pchannel_0_2v0"]
|
||||
assert.EqualValues(t, 100, replicatePack.BeginTs)
|
||||
|
||||
@ -26,6 +26,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
@ -372,7 +373,7 @@ func (suite *SimulationSuite) TestMerge() {
|
||||
vchannel, positions[rand.Intn(len(positions))],
|
||||
common.SubscriptionPositionUnknown,
|
||||
)) // seek from random position
|
||||
assert.NoError(suite.T(), err)
|
||||
require.NoError(suite.T(), err)
|
||||
suite.vchannels[vchannel] = &vchannelHelper{output: output}
|
||||
}
|
||||
wg := &sync.WaitGroup{}
|
||||
|
||||
@ -266,7 +266,7 @@ func testSeekToLast(t *testing.T, f []Factory) {
|
||||
var seekPosition *msgpb.MsgPosition
|
||||
for i := 0; i < 10; i++ {
|
||||
result := consume(ctx, consumer)
|
||||
assert.Equal(t, result.Msgs[0].ID(), int64(i))
|
||||
assert.Equal(t, result.Msgs[0].GetID(), int64(i))
|
||||
if i == 5 {
|
||||
seekPosition = result.EndPositions[0]
|
||||
}
|
||||
@ -295,11 +295,11 @@ func testSeekToLast(t *testing.T, f []Factory) {
|
||||
|
||||
assert.Equal(t, 1, len(msgPack.Msgs))
|
||||
for _, tsMsg := range msgPack.Msgs {
|
||||
assert.Equal(t, value, tsMsg.ID())
|
||||
assert.Equal(t, value, tsMsg.GetID())
|
||||
value++
|
||||
cnt++
|
||||
|
||||
ret, err := lastMsgID.LessOrEqualThan(tsMsg.Position().MsgID)
|
||||
ret, err := lastMsgID.LessOrEqualThan(tsMsg.GetPosition().MsgID)
|
||||
assert.NoError(t, err)
|
||||
if ret {
|
||||
hasMore = false
|
||||
@ -398,13 +398,17 @@ func testTimeTickerSeek(t *testing.T, f []Factory) {
|
||||
assert.Equal(t, len(seekMsg.Msgs), 3)
|
||||
result := []uint64{14, 12, 13}
|
||||
for i, msg := range seekMsg.Msgs {
|
||||
assert.Equal(t, msg.BeginTs(), result[i])
|
||||
tsMsg, err := msg.Unmarshal(consumer.GetUnmarshalDispatcher())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tsMsg.BeginTs(), result[i])
|
||||
}
|
||||
|
||||
seekMsg2 := consume(ctx, consumer)
|
||||
assert.Equal(t, len(seekMsg2.Msgs), 1)
|
||||
for _, msg := range seekMsg2.Msgs {
|
||||
assert.Equal(t, msg.BeginTs(), uint64(19))
|
||||
tsMsg, err := msg.Unmarshal(consumer.GetUnmarshalDispatcher())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tsMsg.BeginTs(), uint64(19))
|
||||
}
|
||||
consumer.Close()
|
||||
|
||||
@ -412,7 +416,9 @@ func testTimeTickerSeek(t *testing.T, f []Factory) {
|
||||
seekMsg = consume(ctx, consumer)
|
||||
assert.Equal(t, len(seekMsg.Msgs), 1)
|
||||
for _, msg := range seekMsg.Msgs {
|
||||
assert.Equal(t, msg.BeginTs(), uint64(19))
|
||||
tsMsg, err := msg.Unmarshal(consumer.GetUnmarshalDispatcher())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tsMsg.BeginTs(), uint64(19))
|
||||
}
|
||||
consumer.Close()
|
||||
}
|
||||
@ -473,9 +479,11 @@ func testTimeTickerStream1(t *testing.T, f []Factory) {
|
||||
rcvMsg += len(msgPack.Msgs)
|
||||
if len(msgPack.Msgs) > 0 {
|
||||
for _, msg := range msgPack.Msgs {
|
||||
log.Println("msg type: ", msg.Type(), ", msg value: ", msg)
|
||||
assert.Greater(t, msg.BeginTs(), msgPack.BeginTs)
|
||||
assert.LessOrEqual(t, msg.BeginTs(), msgPack.EndTs)
|
||||
tsMsg, err := msg.Unmarshal(consumer.GetUnmarshalDispatcher())
|
||||
require.NoError(t, err)
|
||||
log.Println("msg type: ", tsMsg.Type(), ", msg value: ", msg)
|
||||
assert.Greater(t, tsMsg.BeginTs(), msgPack.BeginTs)
|
||||
assert.LessOrEqual(t, tsMsg.BeginTs(), msgPack.EndTs)
|
||||
}
|
||||
log.Println("================")
|
||||
}
|
||||
@ -525,7 +533,7 @@ func testTimeTickerStream2(t *testing.T, f []Factory) {
|
||||
|
||||
// consume msg
|
||||
log.Println("=============receive msg===================")
|
||||
rcvMsgPacks := make([]*MsgPack, 0)
|
||||
rcvMsgPacks := make([]*ConsumeMsgPack, 0)
|
||||
|
||||
resumeMsgPack := func(t *testing.T) int {
|
||||
var consumer MsgStream
|
||||
@ -539,9 +547,11 @@ func testTimeTickerStream2(t *testing.T, f []Factory) {
|
||||
rcvMsgPacks = append(rcvMsgPacks, msgPack)
|
||||
if len(msgPack.Msgs) > 0 {
|
||||
for _, msg := range msgPack.Msgs {
|
||||
log.Println("msg type: ", msg.Type(), ", msg value: ", msg)
|
||||
assert.Greater(t, msg.BeginTs(), msgPack.BeginTs)
|
||||
assert.LessOrEqual(t, msg.BeginTs(), msgPack.EndTs)
|
||||
tsMsg, err := msg.Unmarshal(consumer.GetUnmarshalDispatcher())
|
||||
require.NoError(t, err)
|
||||
log.Println("msg type: ", tsMsg.Type(), ", msg value: ", msg)
|
||||
assert.Greater(t, tsMsg.BeginTs(), msgPack.BeginTs)
|
||||
assert.LessOrEqual(t, tsMsg.BeginTs(), msgPack.EndTs)
|
||||
}
|
||||
log.Println("================")
|
||||
}
|
||||
@ -576,7 +586,7 @@ func testMqMsgStreamSeek(t *testing.T, f []Factory) {
|
||||
var seekPosition *msgpb.MsgPosition
|
||||
for i := 0; i < 10; i++ {
|
||||
result := consume(ctx, consumer)
|
||||
assert.Equal(t, result.Msgs[0].ID(), int64(i))
|
||||
assert.Equal(t, result.Msgs[0].GetID(), int64(i))
|
||||
if i == 5 {
|
||||
seekPosition = result.EndPositions[0]
|
||||
}
|
||||
@ -586,7 +596,7 @@ func testMqMsgStreamSeek(t *testing.T, f []Factory) {
|
||||
consumer = createAndSeekConsumer(ctx, t, f[0].NewMsgStream, channels, []*msgpb.MsgPosition{seekPosition})
|
||||
for i := 6; i < 10; i++ {
|
||||
result := consume(ctx, consumer)
|
||||
assert.Equal(t, result.Msgs[0].ID(), int64(i))
|
||||
assert.Equal(t, result.Msgs[0].GetID(), int64(i))
|
||||
}
|
||||
consumer.Close()
|
||||
}
|
||||
@ -610,7 +620,7 @@ func testMqMsgStreamSeekInvalidMessage(t *testing.T, f []Factory, pg positionGen
|
||||
var seekPosition *msgpb.MsgPosition
|
||||
for i := 0; i < 10; i++ {
|
||||
result := consume(ctx, consumer)
|
||||
assert.Equal(t, result.Msgs[0].ID(), int64(i))
|
||||
assert.Equal(t, result.Msgs[0].GetID(), int64(i))
|
||||
seekPosition = result.EndPositions[0]
|
||||
}
|
||||
|
||||
@ -625,7 +635,7 @@ func testMqMsgStreamSeekInvalidMessage(t *testing.T, f []Factory, pg positionGen
|
||||
err = producer.Produce(ctx, msgPack)
|
||||
assert.NoError(t, err)
|
||||
result := consume(ctx, consumer2)
|
||||
assert.Equal(t, result.Msgs[0].ID(), int64(1))
|
||||
assert.Equal(t, result.Msgs[0].GetID(), int64(1))
|
||||
}
|
||||
|
||||
func testMqMsgStreamSeekLatest(t *testing.T, f []Factory) {
|
||||
@ -658,7 +668,7 @@ func testMqMsgStreamSeekLatest(t *testing.T, f []Factory) {
|
||||
|
||||
for i := 10; i < 20; i++ {
|
||||
result := consume(ctx, consumer2)
|
||||
assert.Equal(t, result.Msgs[0].ID(), int64(i))
|
||||
assert.Equal(t, result.Msgs[0].GetID(), int64(i))
|
||||
}
|
||||
}
|
||||
|
||||
@ -748,7 +758,7 @@ func applyProduceAndConsume(
|
||||
receiveAndValidateMsg(context.Background(), consumer, len(msgPack.Msgs))
|
||||
}
|
||||
|
||||
func consume(ctx context.Context, mq MsgStream) *MsgPack {
|
||||
func consume(ctx context.Context, mq MsgStream) *ConsumeMsgPack {
|
||||
for {
|
||||
select {
|
||||
case msgPack, ok := <-mq.Chan():
|
||||
@ -829,7 +839,7 @@ func receiveAndValidateMsg(ctx context.Context, outputStream MsgStream, msgCount
|
||||
msgs := result.Msgs
|
||||
for _, v := range msgs {
|
||||
receiveCount++
|
||||
log.Println("msg type: ", v.Type(), ", msg value: ", v)
|
||||
log.Println("msg type: ", v.GetType(), ", msg value: ", v)
|
||||
}
|
||||
log.Println("================")
|
||||
}
|
||||
|
||||
@ -168,19 +168,19 @@ func (_c *MockMsgStream_Broadcast_Call) RunAndReturn(run func(context.Context, *
|
||||
}
|
||||
|
||||
// Chan provides a mock function with given fields:
|
||||
func (_m *MockMsgStream) Chan() <-chan *MsgPack {
|
||||
func (_m *MockMsgStream) Chan() <-chan *ConsumeMsgPack {
|
||||
ret := _m.Called()
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Chan")
|
||||
}
|
||||
|
||||
var r0 <-chan *MsgPack
|
||||
if rf, ok := ret.Get(0).(func() <-chan *MsgPack); ok {
|
||||
var r0 <-chan *ConsumeMsgPack
|
||||
if rf, ok := ret.Get(0).(func() <-chan *ConsumeMsgPack); ok {
|
||||
r0 = rf()
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(<-chan *MsgPack)
|
||||
r0 = ret.Get(0).(<-chan *ConsumeMsgPack)
|
||||
}
|
||||
}
|
||||
|
||||
@ -204,12 +204,12 @@ func (_c *MockMsgStream_Chan_Call) Run(run func()) *MockMsgStream_Chan_Call {
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockMsgStream_Chan_Call) Return(_a0 <-chan *MsgPack) *MockMsgStream_Chan_Call {
|
||||
func (_c *MockMsgStream_Chan_Call) Return(_a0 <-chan *ConsumeMsgPack) *MockMsgStream_Chan_Call {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockMsgStream_Chan_Call) RunAndReturn(run func() <-chan *MsgPack) *MockMsgStream_Chan_Call {
|
||||
func (_c *MockMsgStream_Chan_Call) RunAndReturn(run func() <-chan *ConsumeMsgPack) *MockMsgStream_Chan_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
@ -430,6 +430,53 @@ func (_c *MockMsgStream_GetProduceChannels_Call) RunAndReturn(run func() []strin
|
||||
return _c
|
||||
}
|
||||
|
||||
// GetUnmarshalDispatcher provides a mock function with given fields:
|
||||
func (_m *MockMsgStream) GetUnmarshalDispatcher() UnmarshalDispatcher {
|
||||
ret := _m.Called()
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for GetUnmarshalDispatcher")
|
||||
}
|
||||
|
||||
var r0 UnmarshalDispatcher
|
||||
if rf, ok := ret.Get(0).(func() UnmarshalDispatcher); ok {
|
||||
r0 = rf()
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(UnmarshalDispatcher)
|
||||
}
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// MockMsgStream_GetUnmarshalDispatcher_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetUnmarshalDispatcher'
|
||||
type MockMsgStream_GetUnmarshalDispatcher_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// GetUnmarshalDispatcher is a helper method to define mock.On call
|
||||
func (_e *MockMsgStream_Expecter) GetUnmarshalDispatcher() *MockMsgStream_GetUnmarshalDispatcher_Call {
|
||||
return &MockMsgStream_GetUnmarshalDispatcher_Call{Call: _e.mock.On("GetUnmarshalDispatcher")}
|
||||
}
|
||||
|
||||
func (_c *MockMsgStream_GetUnmarshalDispatcher_Call) Run(run func()) *MockMsgStream_GetUnmarshalDispatcher_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run()
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockMsgStream_GetUnmarshalDispatcher_Call) Return(_a0 UnmarshalDispatcher) *MockMsgStream_GetUnmarshalDispatcher_Call {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockMsgStream_GetUnmarshalDispatcher_Call) RunAndReturn(run func() UnmarshalDispatcher) *MockMsgStream_GetUnmarshalDispatcher_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// Produce provides a mock function with given fields: _a0, _a1
|
||||
func (_m *MockMsgStream) Produce(_a0 context.Context, _a1 *MsgPack) error {
|
||||
ret := _m.Called(_a0, _a1)
|
||||
|
||||
@ -24,6 +24,7 @@ import (
|
||||
|
||||
"github.com/confluentinc/confluent-kafka-go/kafka"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
|
||||
@ -131,7 +132,7 @@ func TestStream_KafkaMsgStream_SeekToLast(t *testing.T) {
|
||||
outputStream := getKafkaOutputStream(ctx, kafkaAddress, consumerChannels, consumerSubName, common.SubscriptionPositionEarliest)
|
||||
for i := 0; i < 10; i++ {
|
||||
result := consumer(ctx, outputStream)
|
||||
assert.Equal(t, result.Msgs[0].ID(), int64(i))
|
||||
assert.Equal(t, result.Msgs[0].GetID(), int64(i))
|
||||
if i == 5 {
|
||||
seekPosition = result.EndPositions[0]
|
||||
break
|
||||
@ -162,11 +163,11 @@ func TestStream_KafkaMsgStream_SeekToLast(t *testing.T) {
|
||||
|
||||
assert.Equal(t, 1, len(msgPack.Msgs))
|
||||
for _, tsMsg := range msgPack.Msgs {
|
||||
assert.Equal(t, value, tsMsg.ID())
|
||||
assert.Equal(t, value, tsMsg.GetID())
|
||||
value++
|
||||
cnt++
|
||||
|
||||
ret, err := lastMsgID.LessOrEqualThan(tsMsg.Position().MsgID)
|
||||
ret, err := lastMsgID.LessOrEqualThan(tsMsg.GetPosition().MsgID)
|
||||
assert.NoError(t, err)
|
||||
if ret {
|
||||
hasMore = false
|
||||
@ -272,20 +273,26 @@ func TestStream_KafkaTtMsgStream_Seek(t *testing.T) {
|
||||
assert.Equal(t, len(seekMsg.Msgs), 3)
|
||||
result := []uint64{14, 12, 13}
|
||||
for i, msg := range seekMsg.Msgs {
|
||||
assert.Equal(t, msg.BeginTs(), result[i])
|
||||
tsMsg, err := msg.Unmarshal(outputStream.GetUnmarshalDispatcher())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tsMsg.BeginTs(), result[i])
|
||||
}
|
||||
|
||||
seekMsg2 := consumer(ctx, outputStream)
|
||||
assert.Equal(t, len(seekMsg2.Msgs), 1)
|
||||
for _, msg := range seekMsg2.Msgs {
|
||||
assert.Equal(t, msg.BeginTs(), uint64(19))
|
||||
tsMsg, err := msg.Unmarshal(outputStream.GetUnmarshalDispatcher())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tsMsg.BeginTs(), uint64(19))
|
||||
}
|
||||
|
||||
outputStream2 := getKafkaTtOutputStreamAndSeek(ctx, kafkaAddress, receivedMsg3.EndPositions)
|
||||
seekMsg = consumer(ctx, outputStream2)
|
||||
assert.Equal(t, len(seekMsg.Msgs), 1)
|
||||
for _, msg := range seekMsg.Msgs {
|
||||
assert.Equal(t, msg.BeginTs(), uint64(19))
|
||||
tsMsg, err := msg.Unmarshal(outputStream.GetUnmarshalDispatcher())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tsMsg.BeginTs(), uint64(19))
|
||||
}
|
||||
|
||||
inputStream.Close()
|
||||
@ -320,9 +327,11 @@ func TestStream_KafkaTtMsgStream_1(t *testing.T) {
|
||||
rcvMsg += len(msgPack.Msgs)
|
||||
if len(msgPack.Msgs) > 0 {
|
||||
for _, msg := range msgPack.Msgs {
|
||||
log.Println("msg type: ", msg.Type(), ", msg value: ", msg)
|
||||
assert.Greater(t, msg.BeginTs(), msgPack.BeginTs)
|
||||
assert.LessOrEqual(t, msg.BeginTs(), msgPack.EndTs)
|
||||
tsMsg, err := msg.Unmarshal(outputStream.GetUnmarshalDispatcher())
|
||||
require.NoError(t, err)
|
||||
log.Println("msg type: ", tsMsg.Type(), ", msg value: ", msg)
|
||||
assert.Greater(t, tsMsg.BeginTs(), msgPack.BeginTs)
|
||||
assert.LessOrEqual(t, tsMsg.BeginTs(), msgPack.EndTs)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -361,7 +370,7 @@ func TestStream_KafkaTtMsgStream_2(t *testing.T) {
|
||||
|
||||
// consume msg
|
||||
log.Println("=============receive msg===================")
|
||||
rcvMsgPacks := make([]*MsgPack, 0)
|
||||
rcvMsgPacks := make([]*ConsumeMsgPack, 0)
|
||||
|
||||
resumeMsgPack := func(t *testing.T) int {
|
||||
var outputStream MsgStream
|
||||
@ -376,9 +385,11 @@ func TestStream_KafkaTtMsgStream_2(t *testing.T) {
|
||||
rcvMsgPacks = append(rcvMsgPacks, msgPack)
|
||||
if len(msgPack.Msgs) > 0 {
|
||||
for _, msg := range msgPack.Msgs {
|
||||
log.Println("TestStream_KafkaTtMsgStream_2 msg type: ", msg.Type(), ", msg value: ", msg)
|
||||
assert.Greater(t, msg.BeginTs(), msgPack.BeginTs)
|
||||
assert.LessOrEqual(t, msg.BeginTs(), msgPack.EndTs)
|
||||
tsMsg, err := msg.Unmarshal(outputStream.GetUnmarshalDispatcher())
|
||||
require.NoError(t, err)
|
||||
log.Println("TestStream_KafkaTtMsgStream_2 msg type: ", tsMsg.Type(), ", msg value: ", msg)
|
||||
assert.Greater(t, tsMsg.BeginTs(), msgPack.BeginTs)
|
||||
assert.LessOrEqual(t, tsMsg.BeginTs(), msgPack.EndTs)
|
||||
}
|
||||
log.Println("================")
|
||||
}
|
||||
|
||||
@ -61,7 +61,7 @@ type mqMsgStream struct {
|
||||
|
||||
repackFunc RepackFunc
|
||||
unmarshal UnmarshalDispatcher
|
||||
receiveBuf chan *MsgPack
|
||||
receiveBuf chan *ConsumeMsgPack
|
||||
closeRWMutex *sync.RWMutex
|
||||
streamCancel func()
|
||||
bufSize int64
|
||||
@ -89,7 +89,7 @@ func NewMqMsgStream(initCtx context.Context,
|
||||
consumers := make(map[string]mqwrapper.Consumer)
|
||||
producerChannels := make([]string, 0)
|
||||
consumerChannels := make([]string, 0)
|
||||
receiveBuf := make(chan *MsgPack, receiveBufSize)
|
||||
receiveBuf := make(chan *ConsumeMsgPack, receiveBufSize)
|
||||
|
||||
stream := &mqMsgStream{
|
||||
ctx: streamCtx,
|
||||
@ -355,9 +355,7 @@ func (ms *mqMsgStream) Produce(ctx context.Context, msgPack *MsgPack) error {
|
||||
return err
|
||||
}
|
||||
|
||||
msg := &common.ProducerMessage{Payload: m, Properties: map[string]string{
|
||||
common.MsgTypeKey: v.Msgs[i].Type().String(),
|
||||
}}
|
||||
msg := &common.ProducerMessage{Payload: m, Properties: GetPorperties(v.Msgs[i])}
|
||||
InjectCtx(spanCtx, msg.Properties)
|
||||
|
||||
if _, err := producer.Send(spanCtx, msg); err != nil {
|
||||
@ -399,7 +397,7 @@ func (ms *mqMsgStream) Broadcast(ctx context.Context, msgPack *MsgPack) (map[str
|
||||
return ids, err
|
||||
}
|
||||
|
||||
msg := &common.ProducerMessage{Payload: m, Properties: map[string]string{}}
|
||||
msg := &common.ProducerMessage{Payload: m, Properties: GetPorperties(v)}
|
||||
InjectCtx(spanCtx, msg.Properties)
|
||||
|
||||
ms.producerLock.RLock()
|
||||
@ -421,10 +419,6 @@ func (ms *mqMsgStream) Broadcast(ctx context.Context, msgPack *MsgPack) (map[str
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
func (ms *mqMsgStream) getTsMsgFromConsumerMsg(msg common.Message) (TsMsg, error) {
|
||||
return GetTsMsgFromConsumerMsg(ms.unmarshal, msg)
|
||||
}
|
||||
|
||||
// GetTsMsgFromConsumerMsg get TsMsg from consumer message
|
||||
func GetTsMsgFromConsumerMsg(unmarshalDispatcher UnmarshalDispatcher, msg common.Message) (TsMsg, error) {
|
||||
msgType, err := common.GetMsgType(msg)
|
||||
@ -464,31 +458,35 @@ func (ms *mqMsgStream) receiveMsg(consumer mqwrapper.Consumer) {
|
||||
log.Ctx(ms.ctx).Warn("MqMsgStream get msg whose payload is nil")
|
||||
continue
|
||||
}
|
||||
// not need to check the preCreatedTopic is empty, related issue: https://github.com/milvus-io/milvus/issues/27295
|
||||
// if the message not belong to the topic, will skip it
|
||||
tsMsg, err := ms.getTsMsgFromConsumerMsg(msg)
|
||||
|
||||
var err error
|
||||
var packMsg ConsumeMsg
|
||||
|
||||
packMsg, err = NewMarshaledMsg(msg, consumer.Subscription())
|
||||
if err != nil {
|
||||
packMsg, err = UnmarshalMsg(msg, ms.unmarshal)
|
||||
if err != nil {
|
||||
log.Ctx(ms.ctx).Warn("Failed to getTsMsgFromConsumerMsg", zap.Error(err))
|
||||
continue
|
||||
}
|
||||
pos := tsMsg.Position()
|
||||
tsMsg.SetPosition(&MsgPosition{
|
||||
ChannelName: pos.ChannelName,
|
||||
MsgID: pos.MsgID,
|
||||
MsgGroup: consumer.Subscription(),
|
||||
Timestamp: tsMsg.BeginTs(),
|
||||
})
|
||||
|
||||
ctx, _ := ExtractCtx(tsMsg, msg.Properties())
|
||||
tsMsg.SetTraceCtx(ctx)
|
||||
|
||||
msgPack := MsgPack{
|
||||
Msgs: []TsMsg{tsMsg},
|
||||
StartPositions: []*msgpb.MsgPosition{tsMsg.Position()},
|
||||
EndPositions: []*msgpb.MsgPosition{tsMsg.Position()},
|
||||
BeginTs: tsMsg.BeginTs(),
|
||||
EndTs: tsMsg.EndTs(),
|
||||
}
|
||||
|
||||
pos := &msgpb.MsgPosition{
|
||||
ChannelName: filepath.Base(msg.Topic()),
|
||||
MsgID: packMsg.GetMessageID(),
|
||||
MsgGroup: consumer.Subscription(),
|
||||
Timestamp: packMsg.GetTimestamp(),
|
||||
}
|
||||
|
||||
packMsg.SetPosition(pos)
|
||||
msgPack := ConsumeMsgPack{
|
||||
Msgs: []ConsumeMsg{packMsg},
|
||||
StartPositions: []*msgpb.MsgPosition{pos},
|
||||
EndPositions: []*msgpb.MsgPosition{pos},
|
||||
BeginTs: packMsg.GetTimestamp(),
|
||||
EndTs: packMsg.GetTimestamp(),
|
||||
}
|
||||
|
||||
select {
|
||||
case ms.receiveBuf <- &msgPack:
|
||||
case <-ms.ctx.Done():
|
||||
@ -498,7 +496,11 @@ func (ms *mqMsgStream) receiveMsg(consumer mqwrapper.Consumer) {
|
||||
}
|
||||
}
|
||||
|
||||
func (ms *mqMsgStream) Chan() <-chan *MsgPack {
|
||||
func (ms *mqMsgStream) GetUnmarshalDispatcher() UnmarshalDispatcher {
|
||||
return ms.unmarshal
|
||||
}
|
||||
|
||||
func (ms *mqMsgStream) Chan() <-chan *ConsumeMsgPack {
|
||||
ms.onceChan.Do(func() {
|
||||
for _, c := range ms.consumers {
|
||||
go ms.receiveMsg(c)
|
||||
@ -546,7 +548,7 @@ var _ MsgStream = (*MqTtMsgStream)(nil)
|
||||
// MqTtMsgStream is a msgstream that contains timeticks
|
||||
type MqTtMsgStream struct {
|
||||
*mqMsgStream
|
||||
chanMsgBuf map[mqwrapper.Consumer][]TsMsg
|
||||
chanMsgBuf map[mqwrapper.Consumer][]ConsumeMsg
|
||||
chanMsgPos map[mqwrapper.Consumer]*msgpb.MsgPosition
|
||||
chanStopChan map[mqwrapper.Consumer]chan bool
|
||||
chanTtMsgTime map[mqwrapper.Consumer]Timestamp
|
||||
@ -568,7 +570,7 @@ func NewMqTtMsgStream(ctx context.Context,
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
chanMsgBuf := make(map[mqwrapper.Consumer][]TsMsg)
|
||||
chanMsgBuf := make(map[mqwrapper.Consumer][]ConsumeMsg)
|
||||
chanMsgPos := make(map[mqwrapper.Consumer]*msgpb.MsgPosition)
|
||||
chanStopChan := make(map[mqwrapper.Consumer]chan bool)
|
||||
chanTtMsgTime := make(map[mqwrapper.Consumer]Timestamp)
|
||||
@ -593,7 +595,7 @@ func (ms *MqTtMsgStream) addConsumer(consumer mqwrapper.Consumer, channel string
|
||||
}
|
||||
ms.consumers[channel] = consumer
|
||||
ms.consumerChannels = append(ms.consumerChannels, channel)
|
||||
ms.chanMsgBuf[consumer] = make([]TsMsg, 0)
|
||||
ms.chanMsgBuf[consumer] = make([]ConsumeMsg, 0)
|
||||
ms.chanMsgPos[consumer] = &msgpb.MsgPosition{
|
||||
ChannelName: channel,
|
||||
MsgID: make([]byte, 0),
|
||||
@ -649,8 +651,8 @@ func (ms *MqTtMsgStream) Close() {
|
||||
ms.mqMsgStream.Close()
|
||||
}
|
||||
|
||||
func isDMLMsg(msg TsMsg) bool {
|
||||
return msg.Type() == commonpb.MsgType_Insert || msg.Type() == commonpb.MsgType_Delete
|
||||
func isDMLMsg(msg ConsumeMsg) bool {
|
||||
return msg.GetType() == commonpb.MsgType_Insert || msg.GetType() == commonpb.MsgType_Delete
|
||||
}
|
||||
|
||||
func (ms *MqTtMsgStream) continueBuffering(endTs, size uint64, startTime time.Time) bool {
|
||||
@ -700,7 +702,7 @@ func (ms *MqTtMsgStream) bufMsgPackToChannel() {
|
||||
case <-ms.ctx.Done():
|
||||
return
|
||||
default:
|
||||
timeTickBuf := make([]TsMsg, 0)
|
||||
timeTickBuf := make([]ConsumeMsg, 0)
|
||||
// startMsgPosition := make([]*msgpb.MsgPosition, 0)
|
||||
// endMsgPositions := make([]*msgpb.MsgPosition, 0)
|
||||
startPositions := make(map[string]*msgpb.MsgPosition)
|
||||
@ -739,22 +741,22 @@ func (ms *MqTtMsgStream) bufMsgPackToChannel() {
|
||||
if _, ok := startPositions[channelName]; !ok {
|
||||
startPositions[channelName] = startPos
|
||||
}
|
||||
tempBuffer := make([]TsMsg, 0)
|
||||
var timeTickMsg TsMsg
|
||||
tempBuffer := make([]ConsumeMsg, 0)
|
||||
var timeTickMsg ConsumeMsg
|
||||
for _, v := range msgs {
|
||||
if v.Type() == commonpb.MsgType_TimeTick {
|
||||
if v.GetType() == commonpb.MsgType_TimeTick {
|
||||
timeTickMsg = v
|
||||
continue
|
||||
}
|
||||
if v.EndTs() <= currTs ||
|
||||
GetReplicateID(v) != "" {
|
||||
size += uint64(v.Size())
|
||||
if v.GetTimestamp() <= currTs ||
|
||||
v.GetReplicateID() != "" {
|
||||
size += uint64(v.GetSize())
|
||||
timeTickBuf = append(timeTickBuf, v)
|
||||
} else {
|
||||
tempBuffer = append(tempBuffer, v)
|
||||
}
|
||||
// when drop collection, force to exit the buffer loop
|
||||
if v.Type() == commonpb.MsgType_DropCollection || v.Type() == commonpb.MsgType_Replicate {
|
||||
if v.GetType() == commonpb.MsgType_DropCollection || v.GetType() == commonpb.MsgType_Replicate {
|
||||
containsEndBufferMsg = true
|
||||
}
|
||||
}
|
||||
@ -765,8 +767,8 @@ func (ms *MqTtMsgStream) bufMsgPackToChannel() {
|
||||
if len(tempBuffer) > 0 {
|
||||
// if tempBuffer is not empty, use tempBuffer[0] to seek
|
||||
newPos = &msgpb.MsgPosition{
|
||||
ChannelName: tempBuffer[0].Position().ChannelName,
|
||||
MsgID: tempBuffer[0].Position().MsgID,
|
||||
ChannelName: tempBuffer[0].GetPChannel(),
|
||||
MsgID: tempBuffer[0].GetMessageID(),
|
||||
Timestamp: currTs,
|
||||
MsgGroup: consumer.Subscription(),
|
||||
}
|
||||
@ -774,8 +776,8 @@ func (ms *MqTtMsgStream) bufMsgPackToChannel() {
|
||||
} else if timeTickMsg != nil {
|
||||
// if tempBuffer is empty, use timeTickMsg to seek
|
||||
newPos = &msgpb.MsgPosition{
|
||||
ChannelName: timeTickMsg.Position().ChannelName,
|
||||
MsgID: timeTickMsg.Position().MsgID,
|
||||
ChannelName: timeTickMsg.GetPChannel(),
|
||||
MsgID: timeTickMsg.GetMessageID(),
|
||||
Timestamp: currTs,
|
||||
MsgGroup: consumer.Subscription(),
|
||||
}
|
||||
@ -787,20 +789,20 @@ func (ms *MqTtMsgStream) bufMsgPackToChannel() {
|
||||
ms.consumerLock.Unlock()
|
||||
}
|
||||
|
||||
idset := make(typeutil.UniqueSet)
|
||||
uniqueMsgs := make([]TsMsg, 0, len(timeTickBuf))
|
||||
idset := make(typeutil.Set[int64])
|
||||
uniqueMsgs := make([]ConsumeMsg, 0, len(timeTickBuf))
|
||||
for _, msg := range timeTickBuf {
|
||||
if isDMLMsg(msg) && idset.Contain(msg.ID()) {
|
||||
log.Ctx(ms.ctx).Warn("mqTtMsgStream, found duplicated msg", zap.Int64("msgID", msg.ID()))
|
||||
if isDMLMsg(msg) && idset.Contain(msg.GetID()) {
|
||||
log.Ctx(ms.ctx).Warn("mqTtMsgStream, found duplicated msg", zap.Int64("msgID", msg.GetID()))
|
||||
continue
|
||||
}
|
||||
idset.Insert(msg.ID())
|
||||
idset.Insert(msg.GetID())
|
||||
uniqueMsgs = append(uniqueMsgs, msg)
|
||||
}
|
||||
|
||||
// skip endTs = 0 (no run for ctx error)
|
||||
if endTs > 0 {
|
||||
msgPack := MsgPack{
|
||||
msgPack := ConsumeMsgPack{
|
||||
BeginTs: ms.lastTimeStamp,
|
||||
EndTs: endTs,
|
||||
Msgs: uniqueMsgs,
|
||||
@ -840,21 +842,26 @@ func (ms *MqTtMsgStream) consumeToTtMsg(consumer mqwrapper.Consumer) {
|
||||
log.Warn("MqTtMsgStream get msg whose payload is nil")
|
||||
continue
|
||||
}
|
||||
// not need to check the preCreatedTopic is empty, related issue: https://github.com/milvus-io/milvus/issues/27295
|
||||
// if the message not belong to the topic, will skip it
|
||||
tsMsg, err := ms.getTsMsgFromConsumerMsg(msg)
|
||||
|
||||
var err error
|
||||
var packMsg ConsumeMsg
|
||||
|
||||
packMsg, err = NewMarshaledMsg(msg, consumer.Subscription())
|
||||
if err != nil {
|
||||
packMsg, err = UnmarshalMsg(msg, ms.unmarshal)
|
||||
if err != nil {
|
||||
log.Warn("Failed to getTsMsgFromConsumerMsg", zap.Error(err))
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
ms.chanMsgBufMutex.Lock()
|
||||
ms.chanMsgBuf[consumer] = append(ms.chanMsgBuf[consumer], tsMsg)
|
||||
ms.chanMsgBuf[consumer] = append(ms.chanMsgBuf[consumer], packMsg)
|
||||
ms.chanMsgBufMutex.Unlock()
|
||||
|
||||
if tsMsg.Type() == commonpb.MsgType_TimeTick {
|
||||
if packMsg.GetType() == commonpb.MsgType_TimeTick {
|
||||
ms.chanTtMsgTimeMutex.Lock()
|
||||
ms.chanTtMsgTime[consumer] = tsMsg.(*TimeTickMsg).Base.Timestamp
|
||||
ms.chanTtMsgTime[consumer] = packMsg.GetTimestamp()
|
||||
ms.chanTtMsgTimeMutex.Unlock()
|
||||
return
|
||||
}
|
||||
@ -972,20 +979,23 @@ func (ms *MqTtMsgStream) Seek(ctx context.Context, msgPositions []*MsgPosition,
|
||||
loopMsgCnt++
|
||||
consumer.Ack(msg)
|
||||
|
||||
headerMsg := commonpb.MsgHeader{}
|
||||
err := proto.Unmarshal(msg.Payload(), &headerMsg)
|
||||
var err error
|
||||
var packMsg ConsumeMsg
|
||||
|
||||
packMsg, err = NewMarshaledMsg(msg, consumer.Subscription())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to unmarshal message header, err %s", err.Error())
|
||||
}
|
||||
tsMsg, err := ms.unmarshal.Unmarshal(msg.Payload(), headerMsg.Base.MsgType)
|
||||
packMsg, err = UnmarshalMsg(msg, ms.unmarshal)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to unmarshal tsMsg, err %s", err.Error())
|
||||
}
|
||||
// skip the replicate msg because it must have been consumed
|
||||
if GetReplicateID(tsMsg) != "" {
|
||||
log.Warn("Failed to getTsMsgFromConsumerMsg", zap.Error(err))
|
||||
continue
|
||||
}
|
||||
if tsMsg.Type() == commonpb.MsgType_TimeTick && tsMsg.BeginTs() >= mp.Timestamp {
|
||||
}
|
||||
|
||||
// skip the replicate msg because it must have been consumed
|
||||
if packMsg.GetReplicateID() != "" {
|
||||
continue
|
||||
}
|
||||
if packMsg.GetType() == commonpb.MsgType_TimeTick && packMsg.GetTimestamp() >= mp.Timestamp {
|
||||
runLoop = false
|
||||
if time.Since(loopStarTime) > 30*time.Second {
|
||||
log.Info("seek loop finished long time",
|
||||
@ -993,21 +1003,21 @@ func (ms *MqTtMsgStream) Seek(ctx context.Context, msgPositions []*MsgPosition,
|
||||
zap.String("channel", mp.ChannelName),
|
||||
zap.Duration("cost", time.Since(loopStarTime)))
|
||||
}
|
||||
} else if tsMsg.BeginTs() > mp.Timestamp {
|
||||
ctx, _ := ExtractCtx(tsMsg, msg.Properties())
|
||||
tsMsg.SetTraceCtx(ctx)
|
||||
} else if packMsg.GetTimestamp() > mp.Timestamp {
|
||||
ctx, _ := ExtractCtx(packMsg, msg.Properties())
|
||||
packMsg.SetTraceCtx(ctx)
|
||||
|
||||
tsMsg.SetPosition(&MsgPosition{
|
||||
packMsg.SetPosition(&MsgPosition{
|
||||
ChannelName: filepath.Base(msg.Topic()),
|
||||
MsgID: msg.ID().Serialize(),
|
||||
})
|
||||
ms.chanMsgBuf[consumer] = append(ms.chanMsgBuf[consumer], tsMsg)
|
||||
ms.chanMsgBuf[consumer] = append(ms.chanMsgBuf[consumer], packMsg)
|
||||
} else {
|
||||
log.Info("skip msg",
|
||||
zap.Int64("source", tsMsg.SourceID()),
|
||||
zap.String("type", tsMsg.Type().String()),
|
||||
zap.Int("size", tsMsg.Size()),
|
||||
zap.Uint64("msgTs", tsMsg.BeginTs()),
|
||||
// zap.Int64("source", tsMsg.SourceID()), // TODO SOURCE ID ?
|
||||
zap.String("type", packMsg.GetType().String()),
|
||||
zap.Int("size", packMsg.GetSize()),
|
||||
zap.Uint64("msgTs", packMsg.GetTimestamp()),
|
||||
zap.Uint64("posTs", mp.GetTimestamp()),
|
||||
)
|
||||
}
|
||||
@ -1017,7 +1027,7 @@ func (ms *MqTtMsgStream) Seek(ctx context.Context, msgPositions []*MsgPosition,
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ms *MqTtMsgStream) Chan() <-chan *MsgPack {
|
||||
func (ms *MqTtMsgStream) Chan() <-chan *ConsumeMsgPack {
|
||||
ms.onceChan.Do(func() {
|
||||
if ms.consumers != nil {
|
||||
go ms.bufMsgPackToChannel()
|
||||
|
||||
@ -85,7 +85,7 @@ func getKafkaBrokerList() string {
|
||||
return brokerList
|
||||
}
|
||||
|
||||
func consumer(ctx context.Context, mq MsgStream) *MsgPack {
|
||||
func consumer(ctx context.Context, mq MsgStream) *ConsumeMsgPack {
|
||||
for {
|
||||
select {
|
||||
case msgPack, ok := <-mq.Chan():
|
||||
@ -506,7 +506,7 @@ func TestStream_PulsarMsgStream_SeekToLast(t *testing.T) {
|
||||
var seekPosition *msgpb.MsgPosition
|
||||
for i := 0; i < 10; i++ {
|
||||
result := consumer(ctx, outputStream)
|
||||
assert.Equal(t, result.Msgs[0].ID(), int64(i))
|
||||
assert.Equal(t, result.Msgs[0].GetID(), int64(i))
|
||||
if i == 5 {
|
||||
seekPosition = result.EndPositions[0]
|
||||
}
|
||||
@ -539,11 +539,11 @@ func TestStream_PulsarMsgStream_SeekToLast(t *testing.T) {
|
||||
|
||||
assert.Equal(t, 1, len(msgPack.Msgs))
|
||||
for _, tsMsg := range msgPack.Msgs {
|
||||
assert.Equal(t, value, tsMsg.ID())
|
||||
assert.Equal(t, value, tsMsg.GetID())
|
||||
value++
|
||||
cnt++
|
||||
|
||||
ret, err := lastMsgID.LessOrEqualThan(tsMsg.Position().MsgID)
|
||||
ret, err := lastMsgID.LessOrEqualThan(tsMsg.GetPosition().MsgID)
|
||||
assert.NoError(t, err)
|
||||
if ret {
|
||||
hasMore = false
|
||||
@ -674,20 +674,26 @@ func TestStream_PulsarTtMsgStream_Seek(t *testing.T) {
|
||||
assert.Equal(t, len(seekMsg.Msgs), 3)
|
||||
result := []uint64{14, 12, 13}
|
||||
for i, msg := range seekMsg.Msgs {
|
||||
assert.Equal(t, msg.BeginTs(), result[i])
|
||||
tsMsg, err := msg.Unmarshal(outputStream.GetUnmarshalDispatcher())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tsMsg.BeginTs(), result[i])
|
||||
}
|
||||
|
||||
seekMsg2 := consumer(ctx, outputStream)
|
||||
assert.Equal(t, len(seekMsg2.Msgs), 1)
|
||||
for _, msg := range seekMsg2.Msgs {
|
||||
assert.Equal(t, msg.BeginTs(), uint64(19))
|
||||
tsMsg, err := msg.Unmarshal(outputStream.GetUnmarshalDispatcher())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tsMsg.BeginTs(), uint64(19))
|
||||
}
|
||||
|
||||
outputStream2 := getPulsarTtOutputStreamAndSeek(ctx, pulsarAddress, receivedMsg3.EndPositions)
|
||||
seekMsg = consumer(ctx, outputStream2)
|
||||
assert.Equal(t, len(seekMsg.Msgs), 1)
|
||||
for _, msg := range seekMsg.Msgs {
|
||||
assert.Equal(t, msg.BeginTs(), uint64(19))
|
||||
tsMsg, err := msg.Unmarshal(outputStream.GetUnmarshalDispatcher())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tsMsg.BeginTs(), uint64(19))
|
||||
}
|
||||
|
||||
inputStream.Close()
|
||||
@ -882,9 +888,11 @@ func TestStream_PulsarTtMsgStream_1(t *testing.T) {
|
||||
rcvMsg += len(msgPack.Msgs)
|
||||
if len(msgPack.Msgs) > 0 {
|
||||
for _, msg := range msgPack.Msgs {
|
||||
log.Println("msg type: ", msg.Type(), ", msg value: ", msg)
|
||||
assert.Greater(t, msg.BeginTs(), msgPack.BeginTs)
|
||||
assert.LessOrEqual(t, msg.BeginTs(), msgPack.EndTs)
|
||||
tsMsg, err := msg.Unmarshal(outputStream.GetUnmarshalDispatcher())
|
||||
require.NoError(t, err)
|
||||
log.Println("msg type: ", tsMsg.Type(), ", msg value: ", msg)
|
||||
assert.Greater(t, tsMsg.BeginTs(), msgPack.BeginTs)
|
||||
assert.LessOrEqual(t, tsMsg.BeginTs(), msgPack.EndTs)
|
||||
}
|
||||
log.Println("================")
|
||||
}
|
||||
@ -940,7 +948,7 @@ func TestStream_PulsarTtMsgStream_2(t *testing.T) {
|
||||
|
||||
// consume msg
|
||||
log.Println("=============receive msg===================")
|
||||
rcvMsgPacks := make([]*MsgPack, 0)
|
||||
rcvMsgPacks := make([]*ConsumeMsgPack, 0)
|
||||
|
||||
resumeMsgPack := func(t *testing.T) int {
|
||||
var outputStream MsgStream
|
||||
@ -954,9 +962,11 @@ func TestStream_PulsarTtMsgStream_2(t *testing.T) {
|
||||
rcvMsgPacks = append(rcvMsgPacks, msgPack)
|
||||
if len(msgPack.Msgs) > 0 {
|
||||
for _, msg := range msgPack.Msgs {
|
||||
log.Println("msg type: ", msg.Type(), ", msg value: ", msg)
|
||||
assert.Greater(t, msg.BeginTs(), msgPack.BeginTs)
|
||||
assert.LessOrEqual(t, msg.BeginTs(), msgPack.EndTs)
|
||||
tsMsg, err := msg.Unmarshal(outputStream.GetUnmarshalDispatcher())
|
||||
require.NoError(t, err)
|
||||
log.Println("msg type: ", tsMsg.Type(), ", msg value: ", msg)
|
||||
assert.Greater(t, tsMsg.BeginTs(), msgPack.BeginTs)
|
||||
assert.LessOrEqual(t, tsMsg.BeginTs(), msgPack.EndTs)
|
||||
}
|
||||
log.Println("================")
|
||||
}
|
||||
@ -998,7 +1008,7 @@ func TestStream_MqMsgStream_Seek(t *testing.T) {
|
||||
var seekPosition *msgpb.MsgPosition
|
||||
for i := 0; i < 10; i++ {
|
||||
result := consumer(ctx, outputStream)
|
||||
assert.Equal(t, result.Msgs[0].ID(), int64(i))
|
||||
assert.Equal(t, result.Msgs[0].GetID(), int64(i))
|
||||
if i == 5 {
|
||||
seekPosition = result.EndPositions[0]
|
||||
}
|
||||
@ -1013,7 +1023,7 @@ func TestStream_MqMsgStream_Seek(t *testing.T) {
|
||||
|
||||
for i := 6; i < 10; i++ {
|
||||
result := consumer(ctx, outputStream2)
|
||||
assert.Equal(t, result.Msgs[0].ID(), int64(i))
|
||||
assert.Equal(t, result.Msgs[0].GetID(), int64(i))
|
||||
}
|
||||
outputStream2.Close()
|
||||
}
|
||||
@ -1042,7 +1052,7 @@ func TestStream_MqMsgStream_SeekInvalidMessage(t *testing.T) {
|
||||
var seekPosition *msgpb.MsgPosition
|
||||
for i := 0; i < 10; i++ {
|
||||
result := consumer(ctx, outputStream)
|
||||
assert.Equal(t, result.Msgs[0].ID(), int64(i))
|
||||
assert.Equal(t, result.Msgs[0].GetID(), int64(i))
|
||||
seekPosition = result.EndPositions[0]
|
||||
}
|
||||
|
||||
@ -1074,7 +1084,7 @@ func TestStream_MqMsgStream_SeekInvalidMessage(t *testing.T) {
|
||||
err = inputStream.Produce(ctx, msgPack)
|
||||
assert.NoError(t, err)
|
||||
result := consumer(ctx, outputStream2)
|
||||
assert.Equal(t, result.Msgs[0].ID(), int64(1))
|
||||
assert.Equal(t, result.Msgs[0].GetID(), int64(1))
|
||||
}
|
||||
|
||||
func TestSTream_MqMsgStream_SeekBadMessageID(t *testing.T) {
|
||||
@ -1101,7 +1111,7 @@ func TestSTream_MqMsgStream_SeekBadMessageID(t *testing.T) {
|
||||
var seekPosition *msgpb.MsgPosition
|
||||
for i := 0; i < 10; i++ {
|
||||
result := consumer(ctx, outputStream)
|
||||
assert.Equal(t, result.Msgs[0].ID(), int64(i))
|
||||
assert.Equal(t, result.Msgs[0].GetID(), int64(i))
|
||||
seekPosition = result.EndPositions[0]
|
||||
}
|
||||
|
||||
@ -1179,7 +1189,7 @@ func TestStream_MqMsgStream_SeekLatest(t *testing.T) {
|
||||
|
||||
for i := 10; i < 20; i++ {
|
||||
result := consumer(ctx, outputStream2)
|
||||
assert.Equal(t, result.Msgs[0].ID(), int64(i))
|
||||
assert.Equal(t, result.Msgs[0].GetID(), int64(i))
|
||||
}
|
||||
|
||||
inputStream.Close()
|
||||
@ -1570,7 +1580,7 @@ func receiveMsg(ctx context.Context, outputStream MsgStream, msgCount int) {
|
||||
msgs := result.Msgs
|
||||
for _, v := range msgs {
|
||||
receiveCount++
|
||||
log.Println("msg type: ", v.Type(), ", msg value: ", v)
|
||||
log.Println("msg type: ", v.GetType(), ", msg value: ", v)
|
||||
}
|
||||
log.Println("================")
|
||||
}
|
||||
|
||||
@ -380,9 +380,9 @@ func TestStream_RmqTtMsgStream_DuplicatedIDs(t *testing.T) {
|
||||
outputStream.Seek(ctx, receivedMsg.StartPositions, false)
|
||||
seekMsg := consumer(ctx, outputStream)
|
||||
assert.Equal(t, len(seekMsg.Msgs), 1+2)
|
||||
assert.EqualValues(t, seekMsg.Msgs[0].BeginTs(), 1)
|
||||
assert.Equal(t, commonpb.MsgType_CreateCollection, seekMsg.Msgs[1].Type())
|
||||
assert.Equal(t, commonpb.MsgType_CreateCollection, seekMsg.Msgs[2].Type())
|
||||
assert.EqualValues(t, seekMsg.Msgs[0].GetTimestamp(), 1)
|
||||
assert.Equal(t, commonpb.MsgType_CreateCollection, seekMsg.Msgs[1].GetType())
|
||||
assert.Equal(t, commonpb.MsgType_CreateCollection, seekMsg.Msgs[2].GetType())
|
||||
|
||||
inputStream.Close()
|
||||
outputStream.Close()
|
||||
@ -485,13 +485,17 @@ func TestStream_RmqTtMsgStream_Seek(t *testing.T) {
|
||||
assert.Equal(t, len(seekMsg.Msgs), 3)
|
||||
result := []uint64{14, 12, 13}
|
||||
for i, msg := range seekMsg.Msgs {
|
||||
assert.Equal(t, msg.BeginTs(), result[i])
|
||||
tsMsg, err := msg.Unmarshal(outputStream.GetUnmarshalDispatcher())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tsMsg.BeginTs(), result[i])
|
||||
}
|
||||
|
||||
seekMsg2 := consumer(ctx, outputStream)
|
||||
assert.Equal(t, len(seekMsg2.Msgs), 1)
|
||||
for _, msg := range seekMsg2.Msgs {
|
||||
assert.Equal(t, msg.BeginTs(), uint64(19))
|
||||
tsMsg, err := msg.Unmarshal(outputStream.GetUnmarshalDispatcher())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tsMsg.BeginTs(), uint64(19))
|
||||
}
|
||||
|
||||
inputStream.Close()
|
||||
@ -517,7 +521,7 @@ func TestStream_RMqMsgStream_SeekInvalidMessage(t *testing.T) {
|
||||
var seekPosition *msgpb.MsgPosition
|
||||
for i := 0; i < 10; i++ {
|
||||
result := consumer(ctx, outputStream)
|
||||
assert.Equal(t, result.Msgs[0].ID(), int64(i))
|
||||
assert.Equal(t, result.Msgs[0].GetID(), int64(i))
|
||||
seekPosition = result.EndPositions[0]
|
||||
}
|
||||
outputStream.Close()
|
||||
@ -550,7 +554,7 @@ func TestStream_RMqMsgStream_SeekInvalidMessage(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
|
||||
result := consumer(ctx, outputStream2)
|
||||
assert.Equal(t, result.Msgs[0].ID(), int64(1))
|
||||
assert.Equal(t, result.Msgs[0].GetID(), int64(1))
|
||||
|
||||
inputStream.Close()
|
||||
outputStream2.Close()
|
||||
@ -585,7 +589,7 @@ func TestStream_RmqTtMsgStream_AsConsumerWithPosition(t *testing.T) {
|
||||
pack := <-outputStream.Chan()
|
||||
assert.NotNil(t, pack)
|
||||
assert.Equal(t, 1, len(pack.Msgs))
|
||||
assert.EqualValues(t, 1000, pack.Msgs[0].BeginTs())
|
||||
assert.EqualValues(t, 1000, pack.Msgs[0].GetTimestamp())
|
||||
|
||||
inputStream.Close()
|
||||
outputStream.Close()
|
||||
|
||||
@ -47,6 +47,10 @@ type TsMsg interface {
|
||||
BeginTs() Timestamp
|
||||
EndTs() Timestamp
|
||||
Type() MsgType
|
||||
VChannel() string
|
||||
// CollID return msg collection id
|
||||
// return 0 if not exist
|
||||
CollID() int64
|
||||
SourceID() int64
|
||||
HashKeys() []uint32
|
||||
Marshal(TsMsg) (MarshalType, error)
|
||||
@ -117,6 +121,14 @@ func (bm *BaseMsg) SetTs(ts uint64) {
|
||||
bm.EndTimestamp = ts
|
||||
}
|
||||
|
||||
func (it *BaseMsg) VChannel() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (it *BaseMsg) CollID() int64 {
|
||||
return 0
|
||||
}
|
||||
|
||||
func convertToByteArray(input interface{}) ([]byte, error) {
|
||||
switch output := input.(type) {
|
||||
case []byte:
|
||||
@ -157,6 +169,14 @@ func (it *InsertMsg) SourceID() int64 {
|
||||
return it.Base.SourceID
|
||||
}
|
||||
|
||||
func (it *InsertMsg) VChannel() string {
|
||||
return it.ShardName
|
||||
}
|
||||
|
||||
func (it *InsertMsg) CollID() int64 {
|
||||
return it.GetCollectionID()
|
||||
}
|
||||
|
||||
// Marshal is used to serialize a message pack to byte array
|
||||
func (it *InsertMsg) Marshal(input TsMsg) (MarshalType, error) {
|
||||
insertMsg := input.(*InsertMsg)
|
||||
@ -343,6 +363,14 @@ func (dt *DeleteMsg) SourceID() int64 {
|
||||
return dt.Base.SourceID
|
||||
}
|
||||
|
||||
func (it *DeleteMsg) VChannel() string {
|
||||
return it.ShardName
|
||||
}
|
||||
|
||||
func (it *DeleteMsg) CollID() int64 {
|
||||
return it.GetCollectionID()
|
||||
}
|
||||
|
||||
// Marshal is used to serializing a message pack to byte array
|
||||
func (dt *DeleteMsg) Marshal(input TsMsg) (MarshalType, error) {
|
||||
deleteMsg := input.(*DeleteMsg)
|
||||
@ -516,6 +544,10 @@ func (cc *CreateCollectionMsg) SourceID() int64 {
|
||||
return cc.Base.SourceID
|
||||
}
|
||||
|
||||
func (it *CreateCollectionMsg) CollID() int64 {
|
||||
return it.GetCollectionID()
|
||||
}
|
||||
|
||||
// Marshal is used to serializing a message pack to byte array
|
||||
func (cc *CreateCollectionMsg) Marshal(input TsMsg) (MarshalType, error) {
|
||||
createCollectionMsg := input.(*CreateCollectionMsg)
|
||||
@ -580,6 +612,10 @@ func (dc *DropCollectionMsg) SourceID() int64 {
|
||||
return dc.Base.SourceID
|
||||
}
|
||||
|
||||
func (it *DropCollectionMsg) CollID() int64 {
|
||||
return it.GetCollectionID()
|
||||
}
|
||||
|
||||
// Marshal is used to serializing a message pack to byte array
|
||||
func (dc *DropCollectionMsg) Marshal(input TsMsg) (MarshalType, error) {
|
||||
dropCollectionMsg := input.(*DropCollectionMsg)
|
||||
@ -644,6 +680,10 @@ func (cp *CreatePartitionMsg) SourceID() int64 {
|
||||
return cp.Base.SourceID
|
||||
}
|
||||
|
||||
func (it *CreatePartitionMsg) CollID() int64 {
|
||||
return it.GetCollectionID()
|
||||
}
|
||||
|
||||
// Marshal is used to serializing a message pack to byte array
|
||||
func (cp *CreatePartitionMsg) Marshal(input TsMsg) (MarshalType, error) {
|
||||
createPartitionMsg := input.(*CreatePartitionMsg)
|
||||
@ -708,6 +748,10 @@ func (dp *DropPartitionMsg) SourceID() int64 {
|
||||
return dp.Base.SourceID
|
||||
}
|
||||
|
||||
func (it *DropPartitionMsg) CollID() int64 {
|
||||
return it.GetCollectionID()
|
||||
}
|
||||
|
||||
// Marshal is used to serializing a message pack to byte array
|
||||
func (dp *DropPartitionMsg) Marshal(input TsMsg) (MarshalType, error) {
|
||||
dropPartitionMsg := input.(*DropPartitionMsg)
|
||||
|
||||
@ -18,6 +18,10 @@ package msgstream
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"sync"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
@ -25,6 +29,7 @@ import (
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/mq/common"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||
)
|
||||
|
||||
@ -52,6 +57,238 @@ type MsgPack struct {
|
||||
EndPositions []*MsgPosition
|
||||
}
|
||||
|
||||
// ConsumeMsgPack represents a batch of msg in consumer
|
||||
type ConsumeMsgPack struct {
|
||||
BeginTs Timestamp
|
||||
EndTs Timestamp
|
||||
Msgs []ConsumeMsg
|
||||
StartPositions []*MsgPosition
|
||||
EndPositions []*MsgPosition
|
||||
}
|
||||
|
||||
// ConsumeMsg used for ConumserMsgPack
|
||||
// support fetch some properties metric
|
||||
type ConsumeMsg interface {
|
||||
GetPosition() *msgpb.MsgPosition
|
||||
SetPosition(*msgpb.MsgPosition)
|
||||
|
||||
GetSize() int
|
||||
GetTimestamp() uint64
|
||||
GetVChannel() string
|
||||
GetPChannel() string
|
||||
GetMessageID() []byte
|
||||
GetID() int64
|
||||
GetCollectionID() string
|
||||
GetType() commonpb.MsgType
|
||||
GetReplicateID() string
|
||||
|
||||
SetTraceCtx(ctx context.Context)
|
||||
Unmarshal(unmarshalDispatcher UnmarshalDispatcher) (TsMsg, error)
|
||||
}
|
||||
|
||||
// UnmarshaledMsg pack unmarshalled tsMsg as ConsumeMsg
|
||||
// For Compatibility or Test
|
||||
type UnmarshaledMsg struct {
|
||||
msg TsMsg
|
||||
}
|
||||
|
||||
func (m *UnmarshaledMsg) GetTimestamp() uint64 {
|
||||
return m.msg.BeginTs()
|
||||
}
|
||||
|
||||
func (m *UnmarshaledMsg) GetVChannel() string {
|
||||
return m.msg.VChannel()
|
||||
}
|
||||
|
||||
func (m *UnmarshaledMsg) GetPChannel() string {
|
||||
return m.msg.Position().GetChannelName()
|
||||
}
|
||||
|
||||
func (m *UnmarshaledMsg) GetMessageID() []byte {
|
||||
return m.msg.Position().GetMsgID()
|
||||
}
|
||||
|
||||
func (m *UnmarshaledMsg) GetID() int64 {
|
||||
return m.msg.ID()
|
||||
}
|
||||
|
||||
func (m *UnmarshaledMsg) GetCollectionID() string {
|
||||
return strconv.FormatInt(m.msg.CollID(), 10)
|
||||
}
|
||||
|
||||
func (m *UnmarshaledMsg) GetType() commonpb.MsgType {
|
||||
return m.msg.Type()
|
||||
}
|
||||
|
||||
func (m *UnmarshaledMsg) GetSize() int {
|
||||
return m.msg.Size()
|
||||
}
|
||||
|
||||
func (m *UnmarshaledMsg) GetReplicateID() string {
|
||||
msgBase, ok := m.msg.(interface{ GetBase() *commonpb.MsgBase })
|
||||
if !ok {
|
||||
log.Warn("fail to get msg base, please check it", zap.Any("type", m.msg.Type()))
|
||||
return ""
|
||||
}
|
||||
return msgBase.GetBase().GetReplicateInfo().GetReplicateID()
|
||||
}
|
||||
|
||||
func (m *UnmarshaledMsg) SetPosition(pos *msgpb.MsgPosition) {
|
||||
m.msg.SetPosition(pos)
|
||||
}
|
||||
|
||||
func (m *UnmarshaledMsg) GetPosition() *msgpb.MsgPosition {
|
||||
return m.msg.Position()
|
||||
}
|
||||
|
||||
func (m *UnmarshaledMsg) SetTraceCtx(ctx context.Context) {
|
||||
m.msg.SetTraceCtx(ctx)
|
||||
}
|
||||
|
||||
func (m *UnmarshaledMsg) Unmarshal(unmarshalDispatcher UnmarshalDispatcher) (TsMsg, error) {
|
||||
return m.msg, nil
|
||||
}
|
||||
|
||||
// MarshaledMsg pack marshaled tsMsg
|
||||
// and parse properties
|
||||
type MarshaledMsg struct {
|
||||
msg common.Message
|
||||
pos *MsgPosition
|
||||
msgType MsgType
|
||||
msgID int64
|
||||
timestamp uint64
|
||||
vchannel string
|
||||
collectionID string
|
||||
replicateID string
|
||||
traceCtx context.Context
|
||||
}
|
||||
|
||||
func (m *MarshaledMsg) GetTimestamp() uint64 {
|
||||
return m.timestamp
|
||||
}
|
||||
|
||||
func (m *MarshaledMsg) GetVChannel() string {
|
||||
return m.vchannel
|
||||
}
|
||||
|
||||
func (m *MarshaledMsg) GetPChannel() string {
|
||||
return filepath.Base(m.msg.Topic())
|
||||
}
|
||||
|
||||
func (m *MarshaledMsg) GetMessageID() []byte {
|
||||
return m.msg.ID().Serialize()
|
||||
}
|
||||
|
||||
func (m *MarshaledMsg) GetID() int64 {
|
||||
return m.msgID
|
||||
}
|
||||
|
||||
func (m *MarshaledMsg) GetCollectionID() string {
|
||||
return m.collectionID
|
||||
}
|
||||
|
||||
func (m *MarshaledMsg) GetType() commonpb.MsgType {
|
||||
return m.msgType
|
||||
}
|
||||
|
||||
func (m *MarshaledMsg) GetSize() int {
|
||||
return len(m.msg.Payload())
|
||||
}
|
||||
|
||||
func (m *MarshaledMsg) GetReplicateID() string {
|
||||
return m.replicateID
|
||||
}
|
||||
|
||||
func (m *MarshaledMsg) SetPosition(pos *msgpb.MsgPosition) {
|
||||
m.pos = pos
|
||||
}
|
||||
|
||||
func (m *MarshaledMsg) GetPosition() *msgpb.MsgPosition {
|
||||
return m.pos
|
||||
}
|
||||
|
||||
func (m *MarshaledMsg) SetTraceCtx(ctx context.Context) {
|
||||
m.traceCtx = ctx
|
||||
}
|
||||
|
||||
func (m *MarshaledMsg) Unmarshal(unmarshalDispatcher UnmarshalDispatcher) (TsMsg, error) {
|
||||
tsMsg, err := GetTsMsgFromConsumerMsg(unmarshalDispatcher, m.msg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tsMsg.SetTraceCtx(m.traceCtx)
|
||||
tsMsg.SetPosition(m.pos)
|
||||
return tsMsg, nil
|
||||
}
|
||||
|
||||
func NewMarshaledMsg(msg common.Message, group string) (ConsumeMsg, error) {
|
||||
properties := msg.Properties()
|
||||
vchannel, ok := properties[common.ChannelTypeKey]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("get channel name from msg properties failed")
|
||||
}
|
||||
|
||||
collID, ok := properties[common.CollectionIDTypeKey]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("get collection ID from msg properties failed")
|
||||
}
|
||||
|
||||
tsStr, ok := properties[common.TimestampTypeKey]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("get minTs from msg properties failed")
|
||||
}
|
||||
|
||||
timestamp, err := strconv.ParseUint(tsStr, 10, 64)
|
||||
if err != nil {
|
||||
log.Warn("parse message properties minTs failed, unknown message", zap.Error(err))
|
||||
return nil, fmt.Errorf("parse minTs from msg properties failed")
|
||||
}
|
||||
|
||||
idStr, ok := properties[common.MsgIdTypeKey]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("get msgType from msg properties failed")
|
||||
}
|
||||
|
||||
id, err := strconv.ParseInt(idStr, 10, 64)
|
||||
if err != nil {
|
||||
log.Warn("parse message properties minTs failed, unknown message", zap.Error(err))
|
||||
return nil, fmt.Errorf("parse minTs from msg properties failed")
|
||||
}
|
||||
|
||||
val, ok := properties[common.MsgTypeKey]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("get msgType from msg properties failed")
|
||||
}
|
||||
msgType := commonpb.MsgType(commonpb.MsgType_value[val])
|
||||
|
||||
result := &MarshaledMsg{
|
||||
msg: msg,
|
||||
msgID: id,
|
||||
collectionID: collID,
|
||||
timestamp: timestamp,
|
||||
msgType: msgType,
|
||||
vchannel: vchannel,
|
||||
}
|
||||
|
||||
replicateID, ok := properties[common.ReplicateIDTypeKey]
|
||||
if ok {
|
||||
result.replicateID = replicateID
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// unmarshal common message to UnmarshaledMsg
|
||||
func UnmarshalMsg(msg common.Message, unmarshalDispatcher UnmarshalDispatcher) (ConsumeMsg, error) {
|
||||
tsMsg, err := GetTsMsgFromConsumerMsg(unmarshalDispatcher, msg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &UnmarshaledMsg{
|
||||
msg: tsMsg,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// RepackFunc is a function type which used to repack message after hash by primary key
|
||||
type RepackFunc func(msgs []TsMsg, hashKeys [][]int32) (map[int32]*MsgPack, error)
|
||||
|
||||
@ -66,7 +303,8 @@ type MsgStream interface {
|
||||
Broadcast(context.Context, *MsgPack) (map[string][]MessageID, error)
|
||||
|
||||
AsConsumer(ctx context.Context, channels []string, subName string, position common.SubscriptionInitialPosition) error
|
||||
Chan() <-chan *MsgPack
|
||||
Chan() <-chan *ConsumeMsgPack
|
||||
GetUnmarshalDispatcher() UnmarshalDispatcher
|
||||
// Seek consume message from the specified position
|
||||
// includeCurrentMsg indicates whether to consume the current message, and in the milvus system, it should be always false
|
||||
Seek(ctx context.Context, msgPositions []*MsgPosition, includeCurrentMsg bool) error
|
||||
@ -117,6 +355,15 @@ func GetReplicateID(msg TsMsg) string {
|
||||
return msgBase.GetBase().GetReplicateInfo().GetReplicateID()
|
||||
}
|
||||
|
||||
func GetTimestamp(msg TsMsg) uint64 {
|
||||
msgBase, ok := msg.(interface{ GetBase() *commonpb.MsgBase })
|
||||
if !ok {
|
||||
log.Warn("fail to get msg base, please check it", zap.Any("type", msg.Type()))
|
||||
return 0
|
||||
}
|
||||
return msgBase.GetBase().GetTimestamp()
|
||||
}
|
||||
|
||||
func MatchReplicateID(msg TsMsg, replicateID string) bool {
|
||||
return GetReplicateID(msg) == replicateID
|
||||
}
|
||||
@ -126,3 +373,81 @@ type Factory interface {
|
||||
NewTtMsgStream(ctx context.Context) (MsgStream, error)
|
||||
NewMsgStreamDisposer(ctx context.Context) func([]string, string) error
|
||||
}
|
||||
|
||||
// Filter and parse ts message for temporary stream
|
||||
type SimpleMsgDispatcher struct {
|
||||
stream MsgStream
|
||||
unmarshalDispatcher UnmarshalDispatcher
|
||||
filter func(ConsumeMsg) bool
|
||||
ch chan *MsgPack
|
||||
chOnce sync.Once
|
||||
|
||||
closeCh chan struct{}
|
||||
closeOnce sync.Once
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
func NewSimpleMsgDispatcher(stream MsgStream, filter func(ConsumeMsg) bool) *SimpleMsgDispatcher {
|
||||
return &SimpleMsgDispatcher{
|
||||
stream: stream,
|
||||
filter: filter,
|
||||
unmarshalDispatcher: stream.GetUnmarshalDispatcher(),
|
||||
closeCh: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (p *SimpleMsgDispatcher) filterAndParase() {
|
||||
defer func() {
|
||||
close(p.ch)
|
||||
p.wg.Done()
|
||||
}()
|
||||
for {
|
||||
select {
|
||||
case <-p.closeCh:
|
||||
return
|
||||
case marshalPack, ok := <-p.stream.Chan():
|
||||
if !ok {
|
||||
log.Warn("dispatcher fail to read delta msg")
|
||||
return
|
||||
}
|
||||
|
||||
msgPack := &MsgPack{
|
||||
BeginTs: marshalPack.BeginTs,
|
||||
EndTs: marshalPack.EndTs,
|
||||
Msgs: make([]TsMsg, 0),
|
||||
StartPositions: marshalPack.StartPositions,
|
||||
EndPositions: marshalPack.EndPositions,
|
||||
}
|
||||
for _, marshalMsg := range marshalPack.Msgs {
|
||||
if !p.filter(marshalMsg) {
|
||||
continue
|
||||
}
|
||||
// unmarshal message
|
||||
msg, err := marshalMsg.Unmarshal(p.unmarshalDispatcher)
|
||||
if err != nil {
|
||||
log.Warn("unmarshal message failed, invalid message", zap.Error(err))
|
||||
continue
|
||||
}
|
||||
msgPack.Msgs = append(msgPack.Msgs, msg)
|
||||
}
|
||||
p.ch <- msgPack
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *SimpleMsgDispatcher) Chan() chan *MsgPack {
|
||||
p.chOnce.Do(func() {
|
||||
p.ch = make(chan *MsgPack, paramtable.Get().MQCfg.ReceiveBufSize.GetAsInt64())
|
||||
p.wg.Add(1)
|
||||
go p.filterAndParase()
|
||||
})
|
||||
return p.ch
|
||||
}
|
||||
|
||||
func (p *SimpleMsgDispatcher) Close() {
|
||||
p.closeOnce.Do(func() {
|
||||
p.stream.Close()
|
||||
close(p.closeCh)
|
||||
p.wg.Wait()
|
||||
})
|
||||
}
|
||||
|
||||
@ -20,10 +20,13 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"strconv"
|
||||
|
||||
"github.com/confluentinc/confluent-kafka-go/kafka"
|
||||
"github.com/samber/lo"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
pcommon "github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/mq/common"
|
||||
@ -141,3 +144,32 @@ func KafkaHealthCheck(clusterStatus *pcommon.MQClusterStatus) {
|
||||
clusterStatus.Health = true
|
||||
clusterStatus.Members = healthList
|
||||
}
|
||||
|
||||
func GetPorperties(msg TsMsg) map[string]string {
|
||||
properties := map[string]string{}
|
||||
properties[common.ChannelTypeKey] = msg.VChannel()
|
||||
properties[common.MsgTypeKey] = msg.Type().String()
|
||||
properties[common.MsgIdTypeKey] = strconv.FormatInt(msg.ID(), 10)
|
||||
properties[common.CollectionIDTypeKey] = strconv.FormatInt(msg.CollID(), 10)
|
||||
msgBase, ok := msg.(interface{ GetBase() *commonpb.MsgBase })
|
||||
if ok {
|
||||
properties[common.TimestampTypeKey] = strconv.FormatUint(msgBase.GetBase().GetTimestamp(), 10)
|
||||
if msgBase.GetBase().GetReplicateInfo() != nil {
|
||||
properties[common.ReplicateIDTypeKey] = msgBase.GetBase().GetReplicateInfo().GetReplicateID()
|
||||
}
|
||||
}
|
||||
|
||||
return properties
|
||||
}
|
||||
|
||||
func BuildConsumeMsgPack(pack *MsgPack) *ConsumeMsgPack {
|
||||
return &ConsumeMsgPack{
|
||||
BeginTs: pack.BeginTs,
|
||||
EndTs: pack.EndTs,
|
||||
Msgs: lo.Map(pack.Msgs, func(msg TsMsg, _ int) ConsumeMsg {
|
||||
return &UnmarshaledMsg{msg: msg}
|
||||
}),
|
||||
StartPositions: pack.StartPositions,
|
||||
EndPositions: pack.EndPositions,
|
||||
}
|
||||
}
|
||||
|
||||
@ -29,21 +29,17 @@ import (
|
||||
|
||||
// ExtractCtx extracts trace span from msg.properties.
|
||||
// And it will attach some default tags to the span.
|
||||
func ExtractCtx(msg TsMsg, properties map[string]string) (context.Context, trace.Span) {
|
||||
ctx := msg.TraceCtx()
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
func ExtractCtx(msg ConsumeMsg, properties map[string]string) (context.Context, trace.Span) {
|
||||
ctx := context.Background()
|
||||
if !allowTrace(msg) {
|
||||
return ctx, trace.SpanFromContext(ctx)
|
||||
}
|
||||
ctx = otel.GetTextMapPropagator().Extract(ctx, propagation.MapCarrier(properties))
|
||||
name := "ReceieveMsg"
|
||||
return otel.Tracer(name).Start(ctx, name, trace.WithAttributes(
|
||||
attribute.Int64("ID", msg.ID()),
|
||||
attribute.String("Type", msg.Type().String()),
|
||||
// attribute.Int64Value("HashKeys", msg.HashKeys()),
|
||||
attribute.String("Position", msg.Position().String()),
|
||||
attribute.Int64("ID", msg.GetID()),
|
||||
attribute.String("Type", msg.GetType().String()),
|
||||
attribute.String("Position", msg.GetPosition().String()),
|
||||
))
|
||||
}
|
||||
|
||||
|
||||
@ -7,7 +7,7 @@ type WastedMockMsgStream struct {
|
||||
AsProducerFunc func(channels []string)
|
||||
BroadcastMarkFunc func(*MsgPack) (map[string][]MessageID, error)
|
||||
BroadcastFunc func(*MsgPack) error
|
||||
ChanFunc func() <-chan *MsgPack
|
||||
ChanFunc func() <-chan *ConsumeMsgPack
|
||||
}
|
||||
|
||||
func NewWastedMockMsgStream() *WastedMockMsgStream {
|
||||
@ -22,6 +22,6 @@ func (m WastedMockMsgStream) Broadcast(ctx context.Context, pack *MsgPack) (map[
|
||||
return m.BroadcastMarkFunc(pack)
|
||||
}
|
||||
|
||||
func (m WastedMockMsgStream) Chan() <-chan *MsgPack {
|
||||
func (m WastedMockMsgStream) Chan() <-chan *ConsumeMsgPack {
|
||||
return m.ChanFunc()
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user