Add msgDispatcher to support sharing msgs for different vChannel (#21917)

Signed-off-by: bigsheeper <yihao.dai@zilliz.com>
bigsheeper 2023-02-13 16:38:33 +08:00 committed by GitHub
parent a2435cfc4f
commit d2667064bb
35 changed files with 1601 additions and 351 deletions

View File

@@ -39,13 +39,13 @@ import (
clientv3 "go.etcd.io/etcd/client/v3"
"go.uber.org/zap"
- "github.com/milvus-io/milvus/internal/proto/datapb"
- "github.com/milvus-io/milvus/internal/proto/rootcoordpb"
allocator2 "github.com/milvus-io/milvus/internal/allocator"
"github.com/milvus-io/milvus/internal/kv"
etcdkv "github.com/milvus-io/milvus/internal/kv/etcd"
"github.com/milvus-io/milvus/internal/log"
+ "github.com/milvus-io/milvus/internal/mq/msgdispatcher"
+ "github.com/milvus-io/milvus/internal/proto/datapb"
+ "github.com/milvus-io/milvus/internal/proto/rootcoordpb"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/commonpbutil"
@@ -127,6 +127,7 @@ type DataNode struct {
closer io.Closer
+ dispClient msgdispatcher.Client
factory dependency.Factory
}
@@ -249,6 +250,9 @@ func (node *DataNode) Init() error {
}
log.Info("DataNode server init rateCollector done", zap.Int64("node ID", paramtable.GetNodeID()))
+ node.dispClient = msgdispatcher.NewClient(node.factory, typeutil.DataNodeRole, paramtable.GetNodeID())
+ log.Info("DataNode server init dispatcher client done", zap.Int64("node ID", paramtable.GetNodeID()))
idAllocator, err := allocator2.NewIDAllocator(node.ctx, node.rootCoord, paramtable.GetNodeID())
if err != nil {
log.Error("failed to create id allocator",

View File

@@ -27,6 +27,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/metrics"
+ "github.com/milvus-io/milvus/internal/mq/msgdispatcher"
"github.com/milvus-io/milvus/internal/mq/msgstream"
"github.com/milvus-io/milvus/internal/mq/msgstream/mqwrapper"
"github.com/milvus-io/milvus/internal/proto/datapb"
@@ -49,6 +50,7 @@ type dataSyncService struct {
resendTTCh chan resendTTMsg // chan to ask for resending DataNode time tick message.
channel Channel // channel stores meta of channel
idAllocator allocatorInterface // id/timestamp allocator
+ dispClient msgdispatcher.Client
msFactory msgstream.Factory
collectionID UniqueID // collection id of vchan for which this data sync service serves
vchannelName string
@@ -71,6 +73,7 @@ func newDataSyncService(ctx context.Context,
resendTTCh chan resendTTMsg,
channel Channel,
alloc allocatorInterface,
+ dispClient msgdispatcher.Client,
factory msgstream.Factory,
vchan *datapb.VchannelInfo,
clearSignal chan<- string,
@@ -101,6 +104,7 @@ func newDataSyncService(ctx context.Context,
resendTTCh: resendTTCh,
channel: channel,
idAllocator: alloc,
+ dispClient: dispClient,
msFactory: factory,
collectionID: vchan.GetCollectionID(),
vchannelName: vchan.GetChannelName(),
@@ -156,6 +160,7 @@ func (dsService *dataSyncService) close() {
if dsService.fg != nil {
log.Info("dataSyncService closing flowgraph", zap.Int64("collectionID", dsService.collectionID),
zap.String("vChanName", dsService.vchannelName))
+ dsService.dispClient.Deregister(dsService.vchannelName)
dsService.fg.Close()
metrics.DataNodeNumConsumers.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Dec()
metrics.DataNodeNumProducers.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Sub(2) // timeTickChannel + deltaChannel
@@ -287,7 +292,7 @@ func (dsService *dataSyncService) initNodes(vchanInfo *datapb.VchannelInfo) erro
}
var dmStreamNode Node
- dmStreamNode, err = newDmInputNode(dsService.ctx, vchanInfo.GetSeekPosition(), c)
+ dmStreamNode, err = newDmInputNode(dsService.dispClient, vchanInfo.GetSeekPosition(), c)
if err != nil {
return err
}

View File

@@ -33,6 +33,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/milvus-io/milvus/internal/common"
"github.com/milvus-io/milvus/internal/log"
+ "github.com/milvus-io/milvus/internal/mq/msgdispatcher"
"github.com/milvus-io/milvus/internal/mq/msgstream"
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
@@ -40,10 +41,15 @@ import (
"github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/internal/util/paramtable"
"github.com/milvus-io/milvus/internal/util/tsoutil"
+ "github.com/milvus-io/milvus/internal/util/typeutil"
)
var dataSyncServiceTestDir = "/tmp/milvus_test/data_sync_service"
+ func init() {
+ Params.Init()
+ }
func getVchanInfo(info *testInfo) *datapb.VchannelInfo {
var ufs []*datapb.SegmentInfo
var fs []*datapb.SegmentInfo
@@ -160,12 +166,14 @@ func TestDataSyncService_newDataSyncService(te *testing.T) {
if test.channelNil {
channel = nil
}
+ dispClient := msgdispatcher.NewClient(test.inMsgFactory, typeutil.DataNodeRole, paramtable.GetNodeID())
ds, err := newDataSyncService(ctx,
make(chan flushMsg),
make(chan resendTTMsg),
channel,
NewAllocatorFactory(),
+ dispClient,
test.inMsgFactory,
getVchanInfo(test),
make(chan string),
@@ -217,6 +225,7 @@ func TestDataSyncService_Start(t *testing.T) {
allocFactory := NewAllocatorFactory(1)
factory := dependency.NewDefaultFactory(true)
+ dispClient := msgdispatcher.NewClient(factory, typeutil.DataNodeRole, paramtable.GetNodeID())
defer os.RemoveAll("/tmp/milvus")
paramtable.Get().Save(Params.DataNodeCfg.FlushInsertBufferSize.Key, "1")
@@ -270,7 +279,7 @@ func TestDataSyncService_Start(t *testing.T) {
},
}
- sync, err := newDataSyncService(ctx, flushChan, resendTTChan, channel, allocFactory, factory, vchan, signalCh, dataCoord, newCache(), cm, newCompactionExecutor(), 0)
+ sync, err := newDataSyncService(ctx, flushChan, resendTTChan, channel, allocFactory, dispClient, factory, vchan, signalCh, dataCoord, newCache(), cm, newCompactionExecutor(), 0)
assert.Nil(t, err)
sync.flushListener = make(chan *segmentFlushPack)
@@ -399,6 +408,7 @@ func TestDataSyncService_Close(t *testing.T) {
allocFactory = NewAllocatorFactory(1)
factory = dependency.NewDefaultFactory(true)
+ dispClient = msgdispatcher.NewClient(factory, typeutil.DataNodeRole, paramtable.GetNodeID())
mockDataCoord = &DataCoordFactory{}
)
mockDataCoord.UserSegmentInfo = map[int64]*datapb.SegmentInfo{
@@ -421,7 +431,7 @@ func TestDataSyncService_Close(t *testing.T) {
paramtable.Get().Reset(Params.DataNodeCfg.FlushInsertBufferSize.Key)
channel := newChannel(insertChannelName, collMeta.ID, collMeta.GetSchema(), mockRootCoord, cm)
- sync, err := newDataSyncService(ctx, flushChan, resendTTChan, channel, allocFactory, factory, vchan, signalCh, mockDataCoord, newCache(), cm, newCompactionExecutor(), 0)
+ sync, err := newDataSyncService(ctx, flushChan, resendTTChan, channel, allocFactory, dispClient, factory, vchan, signalCh, mockDataCoord, newCache(), cm, newCompactionExecutor(), 0)
assert.Nil(t, err)
sync.flushListener = make(chan *segmentFlushPack, 10)

View File

@@ -220,6 +220,12 @@ func (ddn *ddNode) Operate(in []Msg) []Msg {
for i := int64(0); i < dmsg.NumRows; i++ {
dmsg.HashValues = append(dmsg.HashValues, uint32(0))
}
+ deltaVChannel, err := funcutil.ConvertChannelName(dmsg.ShardName, Params.CommonCfg.RootCoordDml.GetValue(), Params.CommonCfg.RootCoordDelta.GetValue())
+ if err != nil {
+ log.Error("convert dmlVChannel to deltaVChannel failed", zap.String("vchannel", ddn.vChannelName), zap.Error(err))
+ panic(err)
+ }
+ dmsg.ShardName = deltaVChannel
forwardMsgs = append(forwardMsgs, dmsg)
if dmsg.CollectionID != ddn.collectionID {
log.Warn("filter invalid DeleteMsg, collection mis-match",

View File

@@ -278,6 +278,7 @@ func TestFlowGraph_DDNode_Operate(t *testing.T) {
},
DeleteRequest: internalpb.DeleteRequest{
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_Delete},
+ ShardName: "by-dev-rootcoord-dml-mock-0",
CollectionID: test.inMsgCollID,
},
}

View File

@@ -17,7 +17,6 @@
package datanode
import (
- "context"
"fmt"
"time"
@@ -25,10 +24,11 @@ import (
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/metrics"
+ "github.com/milvus-io/milvus/internal/mq/msgdispatcher"
+ "github.com/milvus-io/milvus/internal/mq/msgstream"
"github.com/milvus-io/milvus/internal/mq/msgstream/mqwrapper"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/util/flowgraph"
- "github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/util/paramtable"
"github.com/milvus-io/milvus/internal/util/tsoutil"
"github.com/milvus-io/milvus/internal/util/typeutil"
@@ -37,50 +37,32 @@ import (
// DmInputNode receives messages from message streams, packs messages between two timeticks, and passes all
// messages between two timeticks to the following flowgraph node. In DataNode, the following flow graph node is
// flowgraph ddNode.
- func newDmInputNode(ctx context.Context, seekPos *internalpb.MsgPosition, dmNodeConfig *nodeConfig) (*flowgraph.InputNode, error) {
+ func newDmInputNode(dispatcherClient msgdispatcher.Client, seekPos *internalpb.MsgPosition, dmNodeConfig *nodeConfig) (*flowgraph.InputNode, error) {
- // subName should be unique, since pchannelName is shared among several collections
- // use vchannel in case of reuse pchannel for same collection
- consumeSubName := fmt.Sprintf("%s-%d-%s", Params.CommonCfg.DataNodeSubName.GetValue(), paramtable.GetNodeID(), dmNodeConfig.vChannelName)
- insertStream, err := dmNodeConfig.msFactory.NewTtMsgStream(ctx)
- if err != nil {
- return nil, err
- }
- // MsgStream needs a physical channel name, but the channel name in seek position from DataCoord
- // is virtual channel name, so we need to convert vchannel name into pchannel neme here.
- pchannelName := funcutil.ToPhysicalChannel(dmNodeConfig.vChannelName)
- if seekPos != nil {
- insertStream.AsConsumer([]string{pchannelName}, consumeSubName, mqwrapper.SubscriptionPositionUnknown)
- seekPos.ChannelName = pchannelName
- cpTs, _ := tsoutil.ParseTS(seekPos.Timestamp)
- start := time.Now()
- log.Info("datanode begin to seek",
- zap.ByteString("seek msgID", seekPos.GetMsgID()),
- zap.String("pchannel", seekPos.GetChannelName()),
- zap.String("vchannel", dmNodeConfig.vChannelName),
- zap.Time("position", cpTs),
- zap.Duration("tsLag", time.Since(cpTs)),
- zap.Int64("collection ID", dmNodeConfig.collectionID))
- err = insertStream.Seek([]*internalpb.MsgPosition{seekPos})
- if err != nil {
- return nil, err
- }
- log.Info("datanode seek successfully",
- zap.ByteString("seek msgID", seekPos.GetMsgID()),
- zap.String("pchannel", seekPos.GetChannelName()),
- zap.String("vchannel", dmNodeConfig.vChannelName),
- zap.Time("position", cpTs),
- zap.Duration("tsLag", time.Since(cpTs)),
+ log := log.With(zap.Int64("nodeID", paramtable.GetNodeID()),
zap.Int64("collection ID", dmNodeConfig.collectionID),
- zap.Duration("elapse", time.Since(start)))
+ zap.String("vchannel", dmNodeConfig.vChannelName))
+ var err error
+ var input <-chan *msgstream.MsgPack
+ if seekPos != nil && len(seekPos.MsgID) != 0 {
+ input, err = dispatcherClient.Register(dmNodeConfig.vChannelName, seekPos, mqwrapper.SubscriptionPositionUnknown)
+ if err != nil {
+ return nil, err
+ }
+ log.Info("datanode seek successfully when register to msgDispatcher",
+ zap.ByteString("msgID", seekPos.GetMsgID()),
+ zap.Time("tsTime", tsoutil.PhysicalTime(seekPos.GetTimestamp())),
+ zap.Duration("tsLag", time.Since(tsoutil.PhysicalTime(seekPos.GetTimestamp()))))
} else {
- insertStream.AsConsumer([]string{pchannelName}, consumeSubName, mqwrapper.SubscriptionPositionEarliest)
+ input, err = dispatcherClient.Register(dmNodeConfig.vChannelName, nil, mqwrapper.SubscriptionPositionEarliest)
+ if err != nil {
+ return nil, err
+ }
+ log.Info("datanode consume successfully when register to msgDispatcher")
}
metrics.DataNodeNumConsumers.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Inc()
- log.Info("datanode AsConsumer", zap.String("physical channel", pchannelName), zap.String("subName", consumeSubName), zap.Int64("collection ID", dmNodeConfig.collectionID))
name := fmt.Sprintf("dmInputNode-data-%d-%s", dmNodeConfig.collectionID, dmNodeConfig.vChannelName)
- node := flowgraph.NewInputNode(insertStream, name, dmNodeConfig.maxQueueLength, dmNodeConfig.maxParallelism,
+ node := flowgraph.NewInputNode(input, name, dmNodeConfig.maxQueueLength, dmNodeConfig.maxParallelism,
typeutil.DataNodeRole, paramtable.GetNodeID(), dmNodeConfig.collectionID, metrics.AllLabel)
return node, nil
}

View File

@@ -21,12 +21,14 @@ import (
"errors"
"testing"
- "github.com/milvus-io/milvus/internal/util/paramtable"
+ "github.com/stretchr/testify/assert"
+ "github.com/milvus-io/milvus/internal/mq/msgdispatcher"
"github.com/milvus-io/milvus/internal/mq/msgstream"
"github.com/milvus-io/milvus/internal/mq/msgstream/mqwrapper"
"github.com/milvus-io/milvus/internal/proto/internalpb"
- "github.com/stretchr/testify/assert"
+ "github.com/milvus-io/milvus/internal/util/paramtable"
+ "github.com/milvus-io/milvus/internal/util/typeutil"
)
type mockMsgStreamFactory struct {
@@ -93,7 +95,10 @@ func (mtm *mockTtMsgStream) GetLatestMsgID(channel string) (msgstream.MessageID,
}
func TestNewDmInputNode(t *testing.T) {
- ctx := context.Background()
+ client := msgdispatcher.NewClient(&mockMsgStreamFactory{}, typeutil.DataNodeRole, paramtable.GetNodeID())
- _, err := newDmInputNode(ctx, new(internalpb.MsgPosition), &nodeConfig{msFactory: &mockMsgStreamFactory{}})
+ _, err := newDmInputNode(client, new(internalpb.MsgPosition), &nodeConfig{
+ msFactory: &mockMsgStreamFactory{},
+ vChannelName: "mock_vchannel_0",
+ })
assert.Nil(t, err)
}

View File

@@ -48,7 +48,7 @@ func (fm *flowgraphManager) addAndStart(dn *DataNode, vchan *datapb.VchannelInfo
var alloc allocatorInterface = newAllocator(dn.rootCoord)
dataSyncService, err := newDataSyncService(dn.ctx, make(chan flushMsg, 100), make(chan resendTTMsg, 100), channel,
- alloc, dn.factory, vchan, dn.clearSignal, dn.dataCoord, dn.segmentCache, dn.chunkManager, dn.compactionExecutor, dn.GetSession().ServerID)
+ alloc, dn.dispClient, dn.factory, vchan, dn.clearSignal, dn.dataCoord, dn.segmentCache, dn.chunkManager, dn.compactionExecutor, dn.GetSession().ServerID)
if err != nil {
log.Warn("new data sync service fail", zap.String("vChannelName", vchan.GetChannelName()), zap.Error(err))
return err

View File

@@ -27,29 +27,29 @@ import (
"sync"
"time"
- "github.com/milvus-io/milvus/internal/util/metautil"
"go.uber.org/zap"
- etcdkv "github.com/milvus-io/milvus/internal/kv/etcd"
- "github.com/milvus-io/milvus/internal/log"
- "github.com/milvus-io/milvus/internal/mq/msgstream"
- s "github.com/milvus-io/milvus/internal/storage"
- "github.com/milvus-io/milvus/internal/types"
- "github.com/milvus-io/milvus/internal/util/dependency"
- "github.com/milvus-io/milvus/internal/util/sessionutil"
- "github.com/milvus-io/milvus/internal/util/tsoutil"
- "github.com/milvus-io/milvus/internal/util/typeutil"
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus-proto/go-api/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/milvus-io/milvus/internal/common"
+ etcdkv "github.com/milvus-io/milvus/internal/kv/etcd"
+ "github.com/milvus-io/milvus/internal/log"
+ "github.com/milvus-io/milvus/internal/mq/msgdispatcher"
+ "github.com/milvus-io/milvus/internal/mq/msgstream"
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/proto/etcdpb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/rootcoordpb"
+ s "github.com/milvus-io/milvus/internal/storage"
+ "github.com/milvus-io/milvus/internal/types"
+ "github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/internal/util/etcd"
+ "github.com/milvus-io/milvus/internal/util/metautil"
+ "github.com/milvus-io/milvus/internal/util/paramtable"
+ "github.com/milvus-io/milvus/internal/util/sessionutil"
+ "github.com/milvus-io/milvus/internal/util/tsoutil"
+ "github.com/milvus-io/milvus/internal/util/typeutil"
)
const ctxTimeInMillisecond = 5000
@@ -81,6 +81,7 @@ func newIDLEDataNodeMock(ctx context.Context, pkType schemapb.DataType) *DataNod
factory := dependency.NewDefaultFactory(true)
node := NewDataNode(ctx, factory)
node.SetSession(&sessionutil.Session{ServerID: 1})
+ node.dispClient = msgdispatcher.NewClient(factory, typeutil.DataNodeRole, paramtable.GetNodeID())
rc := &RootCoordFactory{
ID: 0,

View File

@@ -0,0 +1,93 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package msgdispatcher
import (
"go.uber.org/zap"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/mq/msgstream"
"github.com/milvus-io/milvus/internal/mq/msgstream/mqwrapper"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/util/typeutil"
)
type (
Pos = internalpb.MsgPosition
MsgPack = msgstream.MsgPack
SubPos = mqwrapper.SubscriptionInitialPosition
)
type Client interface {
Register(vchannel string, pos *Pos, subPos SubPos) (<-chan *MsgPack, error)
Deregister(vchannel string)
}
var _ Client = (*client)(nil)
type client struct {
role string
nodeID int64
managers *typeutil.ConcurrentMap[string, DispatcherManager] // pchannel -> DispatcherManager
factory msgstream.Factory
}
func NewClient(factory msgstream.Factory, role string, nodeID int64) Client {
return &client{
role: role,
nodeID: nodeID,
managers: typeutil.NewConcurrentMap[string, DispatcherManager](),
factory: factory,
}
}
func (c *client) Register(vchannel string, pos *Pos, subPos SubPos) (<-chan *MsgPack, error) {
log := log.With(zap.String("role", c.role),
zap.Int64("nodeID", c.nodeID), zap.String("vchannel", vchannel))
pchannel := funcutil.ToPhysicalChannel(vchannel)
managers, ok := c.managers.Get(pchannel)
if !ok {
managers = NewDispatcherManager(pchannel, c.role, c.nodeID, c.factory)
go managers.Run()
old, exist := c.managers.GetOrInsert(pchannel, managers)
if exist {
managers.Close()
managers = old
}
}
ch, err := managers.Add(vchannel, pos, subPos)
if err != nil {
log.Error("register failed", zap.Error(err))
return nil, err
}
log.Info("register done")
return ch, nil
}
func (c *client) Deregister(vchannel string) {
pchannel := funcutil.ToPhysicalChannel(vchannel)
if managers, ok := c.managers.Get(pchannel); ok {
managers.Remove(vchannel)
if managers.Num() == 0 {
managers.Close()
c.managers.GetAndRemove(pchannel)
}
log.Info("deregister done", zap.String("role", c.role),
zap.Int64("nodeID", c.nodeID), zap.String("vchannel", vchannel))
}
}
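
A minimal usage sketch of the Client API defined above, assuming a msgstream.Factory is already available; the vchannel name and node ID are made up for illustration, and error handling is trimmed:

package example

import (
	"github.com/milvus-io/milvus/internal/mq/msgdispatcher"
	"github.com/milvus-io/milvus/internal/mq/msgstream"
	"github.com/milvus-io/milvus/internal/mq/msgstream/mqwrapper"
	"github.com/milvus-io/milvus/internal/util/typeutil"
)

// consumeVChannel shows the intended call sequence: one Register per vchannel,
// then read MsgPacks from the returned channel, and Deregister when done.
func consumeVChannel(factory msgstream.Factory) error {
	client := msgdispatcher.NewClient(factory, typeutil.DataNodeRole, 1)
	vchannel := "by-dev-rootcoord-dml_0_123456v0" // hypothetical vchannel
	ch, err := client.Register(vchannel, nil, mqwrapper.SubscriptionPositionEarliest)
	if err != nil {
		return err
	}
	defer client.Deregister(vchannel)
	for pack := range ch {
		_ = pack // msgs of this vchannel, plus broadcast DDL msgs and time ticks
	}
	return nil
}

Registering several vchannels that map to the same pchannel reuses one DispatcherManager (and hence one underlying consumer once dispatchers merge), which is the message sharing the commit title refers to.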

View File

@@ -0,0 +1,58 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package msgdispatcher
import (
"fmt"
"math/rand"
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus/internal/mq/msgstream/mqwrapper"
"github.com/milvus-io/milvus/internal/util/typeutil"
)
func TestClient(t *testing.T) {
client := NewClient(newMockFactory(), typeutil.ProxyRole, 1)
assert.NotNil(t, client)
_, err := client.Register("mock_vchannel_0", nil, mqwrapper.SubscriptionPositionUnknown)
assert.NoError(t, err)
assert.NotPanics(t, func() {
client.Deregister("mock_vchannel_0")
})
}
func TestClient_Concurrency(t *testing.T) {
client := NewClient(newMockFactory(), typeutil.ProxyRole, 1)
assert.NotNil(t, client)
wg := &sync.WaitGroup{}
for i := 0; i < 100; i++ {
vchannel := fmt.Sprintf("mock-vchannel-%d-%d", i, rand.Int())
wg.Add(1)
go func() {
for j := 0; j < 10; j++ {
_, err := client.Register(vchannel, nil, mqwrapper.SubscriptionPositionUnknown)
assert.NoError(t, err)
client.Deregister(vchannel)
}
wg.Done()
}()
}
wg.Wait()
}

View File

@@ -0,0 +1,244 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package msgdispatcher
import (
"context"
"fmt"
"sync"
"time"
"go.uber.org/atomic"
"go.uber.org/zap"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/mq/msgstream"
"github.com/milvus-io/milvus/internal/mq/msgstream/mqwrapper"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/util/tsoutil"
"github.com/milvus-io/milvus/internal/util/typeutil"
)
type signal int32
const (
start signal = 0
pause signal = 1
resume signal = 2
terminate signal = 3
)
var signalString = map[int32]string{
0: "start",
1: "pause",
2: "resume",
3: "terminate",
}
func (s signal) String() string {
return signalString[int32(s)]
}
type Dispatcher struct {
done chan struct{}
wg sync.WaitGroup
once sync.Once
isMain bool // indicates if it's a main dispatcher
pchannel string
curTs atomic.Uint64
lagNotifyChan chan struct{}
lagTargets *sync.Map // vchannel -> *target
// vchannel -> *target, lock-free since we guarantee that
// it's modified only when the dispatcher is paused or terminated
targets map[string]*target
stream msgstream.MsgStream
}
func NewDispatcher(factory msgstream.Factory,
isMain bool,
pchannel string,
position *Pos,
subName string,
subPos SubPos,
lagNotifyChan chan struct{},
lagTargets *sync.Map,
) (*Dispatcher, error) {
log := log.With(zap.String("pchannel", pchannel),
zap.String("subName", subName), zap.Bool("isMain", isMain))
log.Info("creating dispatcher...")
stream, err := factory.NewTtMsgStream(context.Background())
if err != nil {
return nil, err
}
if position != nil && len(position.MsgID) != 0 {
position.ChannelName = funcutil.ToPhysicalChannel(position.ChannelName)
stream.AsConsumer([]string{pchannel}, subName, mqwrapper.SubscriptionPositionUnknown)
err = stream.Seek([]*Pos{position})
if err != nil {
log.Error("seek failed", zap.Error(err))
return nil, err
}
posTime := tsoutil.PhysicalTime(position.GetTimestamp())
log.Info("seek successfully", zap.Time("posTime", posTime),
zap.Duration("tsLag", time.Since(posTime)))
} else {
stream.AsConsumer([]string{pchannel}, subName, subPos)
log.Info("asConsumer successfully")
}
d := &Dispatcher{
done: make(chan struct{}, 1),
isMain: isMain,
pchannel: pchannel,
lagNotifyChan: lagNotifyChan,
lagTargets: lagTargets,
targets: make(map[string]*target),
stream: stream,
}
return d, nil
}
func (d *Dispatcher) CurTs() typeutil.Timestamp {
return d.curTs.Load()
}
func (d *Dispatcher) AddTarget(t *target) {
log := log.With(zap.String("vchannel", t.vchannel), zap.Bool("isMain", d.isMain))
if _, ok := d.targets[t.vchannel]; ok {
log.Warn("target exists")
return
}
d.targets[t.vchannel] = t
log.Info("add new target")
}
func (d *Dispatcher) GetTarget(vchannel string) (*target, error) {
if t, ok := d.targets[vchannel]; ok {
return t, nil
}
return nil, fmt.Errorf("cannot find target, vchannel=%s, isMain=%t", vchannel, d.isMain)
}
func (d *Dispatcher) CloseTarget(vchannel string) {
log := log.With(zap.String("vchannel", vchannel), zap.Bool("isMain", d.isMain))
if t, ok := d.targets[vchannel]; ok {
t.close()
delete(d.targets, vchannel)
log.Info("closed target")
} else {
log.Warn("target not exist")
}
}
func (d *Dispatcher) TargetNum() int {
return len(d.targets)
}
func (d *Dispatcher) Handle(signal signal) {
log := log.With(zap.String("pchannel", d.pchannel),
zap.String("signal", signal.String()), zap.Bool("isMain", d.isMain))
log.Info("get signal")
switch signal {
case start:
d.wg.Add(1)
go d.work()
case pause:
d.done <- struct{}{}
d.wg.Wait()
case resume:
d.wg.Add(1)
go d.work()
case terminate:
d.done <- struct{}{}
d.wg.Wait()
d.once.Do(func() {
d.stream.Close()
})
}
log.Info("handle signal done")
}
func (d *Dispatcher) work() {
log := log.With(zap.String("pchannel", d.pchannel), zap.Bool("isMain", d.isMain))
log.Info("begin to work")
defer d.wg.Done()
for {
select {
case <-d.done:
log.Info("stop working")
return
case pack := <-d.stream.Chan():
if pack == nil || len(pack.EndPositions) != 1 {
log.Error("consumed invalid msgPack")
continue
}
d.curTs.Store(pack.EndPositions[0].GetTimestamp())
// init packs for all targets; even if the pack contains no msgs,
// we still need to dispatch time ticks to the targets.
targetPacks := make(map[string]*MsgPack)
for vchannel := range d.targets {
targetPacks[vchannel] = &MsgPack{
BeginTs: pack.BeginTs,
EndTs: pack.EndTs,
Msgs: make([]msgstream.TsMsg, 0),
StartPositions: pack.StartPositions,
EndPositions: pack.EndPositions,
}
}
// group messages by vchannel
for _, msg := range pack.Msgs {
if msg.VChannel() == "" {
// for non-dml msg, such as CreateCollection, DropCollection, ...
// we need to dispatch it to all the vchannels.
for k := range targetPacks {
targetPacks[k].Msgs = append(targetPacks[k].Msgs, msg)
}
continue
}
if _, ok := targetPacks[msg.VChannel()]; !ok {
continue
}
targetPacks[msg.VChannel()].Msgs = append(targetPacks[msg.VChannel()].Msgs, msg)
}
// dispatch messages, split target if block
for vchannel, p := range targetPacks {
t := d.targets[vchannel]
if err := t.send(p); err != nil {
t.pos = pack.StartPositions[0]
d.lagTargets.LoadOrStore(t.vchannel, t)
d.nonBlockingNotify()
delete(d.targets, vchannel)
log.Warn("lag target notified", zap.Error(err))
}
}
}
}
}
func (d *Dispatcher) nonBlockingNotify() {
select {
case d.lagNotifyChan <- struct{}{}:
default:
}
}
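
The routing rule implemented in work() above can be summarized in a simplified, standalone sketch: every registered target gets a pack for each consumed time-tick interval, messages with an empty VChannel (non-DML msgs such as DDL) are broadcast to all targets, and the rest go only to their own vchannel. The types below are trimmed stand-ins, not the real msgstream types:

package main

import "fmt"

type msg struct {
	vchannel string
	body     string
}

func route(msgs []msg, targets []string) map[string][]msg {
	packs := make(map[string][]msg, len(targets))
	for _, vchannel := range targets {
		packs[vchannel] = nil // every target gets a pack, even if it stays empty (time tick only)
	}
	for _, m := range msgs {
		if m.vchannel == "" { // broadcast non-dml msgs to all vchannels
			for vchannel := range packs {
				packs[vchannel] = append(packs[vchannel], m)
			}
			continue
		}
		if _, ok := packs[m.vchannel]; ok {
			packs[m.vchannel] = append(packs[m.vchannel], m)
		}
	}
	return packs
}

func main() {
	packs := route(
		[]msg{{"vch_0", "insert"}, {"", "drop collection"}, {"vch_1", "delete"}},
		[]string{"vch_0", "vch_1"},
	)
	fmt.Println(len(packs["vch_0"]), len(packs["vch_1"])) // 2 2
}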

View File

@@ -0,0 +1,128 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package msgdispatcher
import (
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus/internal/mq/msgstream"
"github.com/milvus-io/milvus/internal/mq/msgstream/mqwrapper"
)
func TestDispatcher(t *testing.T) {
t.Run("test base", func(t *testing.T) {
d, err := NewDispatcher(newMockFactory(), true, "mock_pchannel_0", nil,
"mock_subName_0", mqwrapper.SubscriptionPositionEarliest, nil, nil)
assert.NoError(t, err)
assert.NotPanics(t, func() {
d.Handle(start)
d.Handle(pause)
d.Handle(resume)
d.Handle(terminate)
})
pos := &msgstream.MsgPosition{
ChannelName: "mock_vchannel_0",
MsgGroup: "mock_msg_group",
Timestamp: 100,
}
d.curTs.Store(pos.GetTimestamp())
curTs := d.CurTs()
assert.Equal(t, pos.Timestamp, curTs)
})
t.Run("test target", func(t *testing.T) {
d, err := NewDispatcher(newMockFactory(), true, "mock_pchannel_0", nil,
"mock_subName_0", mqwrapper.SubscriptionPositionEarliest, nil, nil)
assert.NoError(t, err)
output := make(chan *msgstream.MsgPack, 1024)
d.AddTarget(&target{
vchannel: "mock_vchannel_0",
pos: nil,
ch: output,
})
d.AddTarget(&target{
vchannel: "mock_vchannel_1",
pos: nil,
ch: nil,
})
num := d.TargetNum()
assert.Equal(t, 2, num)
target, err := d.GetTarget("mock_vchannel_0")
assert.NoError(t, err)
assert.Equal(t, cap(output), cap(target.ch))
d.CloseTarget("mock_vchannel_0")
select {
case <-time.After(1 * time.Second):
assert.Fail(t, "timeout, didn't receive close message")
case _, ok := <-target.ch:
assert.False(t, ok)
}
num = d.TargetNum()
assert.Equal(t, 1, num)
})
t.Run("test concurrent send and close", func(t *testing.T) {
for i := 0; i < 100; i++ {
output := make(chan *msgstream.MsgPack, 1024)
target := &target{
vchannel: "mock_vchannel_0",
pos: nil,
ch: output,
}
assert.Equal(t, cap(output), cap(target.ch))
wg := &sync.WaitGroup{}
for j := 0; j < 100; j++ {
wg.Add(1)
go func() {
err := target.send(&MsgPack{})
assert.NoError(t, err)
wg.Done()
}()
wg.Add(1)
go func() {
target.close()
wg.Done()
}()
}
wg.Wait()
}
})
}
func BenchmarkDispatcher_handle(b *testing.B) {
d, err := NewDispatcher(newMockFactory(), true, "mock_pchannel_0", nil,
"mock_subName_0", mqwrapper.SubscriptionPositionEarliest, nil, nil)
assert.NoError(b, err)
for i := 0; i < b.N; i++ {
d.Handle(start)
d.Handle(pause)
d.Handle(resume)
d.Handle(terminate)
}
// BenchmarkDispatcher_handle-12 9568 122123 ns/op
// PASS
}

View File

@@ -0,0 +1,240 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package msgdispatcher
import (
"context"
"fmt"
"sync"
"time"
"go.uber.org/zap"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/mq/msgstream"
"github.com/milvus-io/milvus/internal/mq/msgstream/mqwrapper"
"github.com/milvus-io/milvus/internal/util/retry"
)
var (
CheckPeriod = 1 * time.Second // TODO: dyh, move to config
)
type DispatcherManager interface {
Add(vchannel string, pos *Pos, subPos SubPos) (<-chan *MsgPack, error)
Remove(vchannel string)
Num() int
Run()
Close()
}
var _ DispatcherManager = (*dispatcherManager)(nil)
type dispatcherManager struct {
role string
nodeID int64
pchannel string
lagNotifyChan chan struct{}
lagTargets *sync.Map // vchannel -> *target
mu sync.RWMutex // guards mainDispatcher and soloDispatchers
mainDispatcher *Dispatcher
soloDispatchers map[string]*Dispatcher // vchannel -> *Dispatcher
factory msgstream.Factory
closeChan chan struct{}
closeOnce sync.Once
}
func NewDispatcherManager(pchannel string, role string, nodeID int64, factory msgstream.Factory) DispatcherManager {
log.Info("create new dispatcherManager", zap.String("role", role),
zap.Int64("nodeID", nodeID), zap.String("pchannel", pchannel))
c := &dispatcherManager{
role: role,
nodeID: nodeID,
pchannel: pchannel,
lagNotifyChan: make(chan struct{}, 1),
lagTargets: &sync.Map{},
soloDispatchers: make(map[string]*Dispatcher),
factory: factory,
closeChan: make(chan struct{}),
}
return c
}
func (c *dispatcherManager) constructSubName(vchannel string, isMain bool) string {
return fmt.Sprintf("%s-%d-%s-%t", c.role, c.nodeID, vchannel, isMain)
}
func (c *dispatcherManager) Add(vchannel string, pos *Pos, subPos SubPos) (<-chan *MsgPack, error) {
log := log.With(zap.String("role", c.role),
zap.Int64("nodeID", c.nodeID), zap.String("vchannel", vchannel))
c.mu.Lock()
defer c.mu.Unlock()
isMain := c.mainDispatcher == nil
d, err := NewDispatcher(c.factory, isMain, c.pchannel, pos,
c.constructSubName(vchannel, isMain), subPos, c.lagNotifyChan, c.lagTargets)
if err != nil {
return nil, err
}
t := newTarget(vchannel, pos)
d.AddTarget(t)
if isMain {
c.mainDispatcher = d
log.Info("add main dispatcher")
} else {
c.soloDispatchers[vchannel] = d
log.Info("add solo dispatcher")
}
d.Handle(start)
return t.ch, nil
}
func (c *dispatcherManager) Remove(vchannel string) {
log := log.With(zap.String("role", c.role),
zap.Int64("nodeID", c.nodeID), zap.String("vchannel", vchannel))
c.mu.Lock()
defer c.mu.Unlock()
if c.mainDispatcher != nil {
c.mainDispatcher.Handle(pause)
c.mainDispatcher.CloseTarget(vchannel)
if c.mainDispatcher.TargetNum() == 0 && len(c.soloDispatchers) == 0 {
c.mainDispatcher.Handle(terminate)
c.mainDispatcher = nil
} else {
c.mainDispatcher.Handle(resume)
}
}
if _, ok := c.soloDispatchers[vchannel]; ok {
c.soloDispatchers[vchannel].Handle(terminate)
c.soloDispatchers[vchannel].CloseTarget(vchannel)
delete(c.soloDispatchers, vchannel)
log.Info("remove soloDispatcher done")
}
c.lagTargets.Delete(vchannel)
}
func (c *dispatcherManager) Num() int {
c.mu.RLock()
defer c.mu.RUnlock()
var res int
if c.mainDispatcher != nil {
res++
}
return res + len(c.soloDispatchers)
}
func (c *dispatcherManager) Close() {
c.closeOnce.Do(func() {
c.closeChan <- struct{}{}
})
}
func (c *dispatcherManager) Run() {
log := log.With(zap.String("role", c.role),
zap.Int64("nodeID", c.nodeID), zap.String("pchannel", c.pchannel))
log.Info("dispatcherManager is running...")
ticker := time.NewTicker(CheckPeriod)
defer ticker.Stop()
for {
select {
case <-c.closeChan:
log.Info("dispatcherManager exited")
return
case <-ticker.C:
c.tryMerge()
case <-c.lagNotifyChan:
c.mu.Lock()
c.lagTargets.Range(func(vchannel, t any) bool {
c.split(t.(*target))
c.lagTargets.Delete(vchannel)
return true
})
c.mu.Unlock()
}
}
}
func (c *dispatcherManager) tryMerge() {
log := log.With(zap.String("role", c.role), zap.Int64("nodeID", c.nodeID))
c.mu.Lock()
defer c.mu.Unlock()
if c.mainDispatcher == nil {
return
}
candidates := make(map[string]struct{})
for vchannel, sd := range c.soloDispatchers {
if sd.CurTs() == c.mainDispatcher.CurTs() {
candidates[vchannel] = struct{}{}
}
}
if len(candidates) == 0 {
return
}
log.Info("start merging...", zap.Any("vchannel", candidates))
c.mainDispatcher.Handle(pause)
for vchannel := range candidates {
c.soloDispatchers[vchannel].Handle(pause)
// after pausing, check the alignment again; if the timestamps no longer match,
// resume the solo dispatcher and retry the merge next time
if c.mainDispatcher.CurTs() != c.soloDispatchers[vchannel].CurTs() {
c.soloDispatchers[vchannel].Handle(resume)
delete(candidates, vchannel)
}
}
for vchannel := range candidates {
t, err := c.soloDispatchers[vchannel].GetTarget(vchannel)
if err == nil {
c.mainDispatcher.AddTarget(t)
}
c.soloDispatchers[vchannel].Handle(terminate)
delete(c.soloDispatchers, vchannel)
}
c.mainDispatcher.Handle(resume)
log.Info("merge done", zap.Any("vchannel", candidates))
}
func (c *dispatcherManager) split(t *target) {
log := log.With(zap.String("role", c.role),
zap.Int64("nodeID", c.nodeID), zap.String("vchannel", t.vchannel))
log.Info("start splitting...")
// remove the stale soloDispatcher if it exists
if _, ok := c.soloDispatchers[t.vchannel]; ok {
c.soloDispatchers[t.vchannel].Handle(terminate)
delete(c.soloDispatchers, t.vchannel)
}
var newSolo *Dispatcher
err := retry.Do(context.Background(), func() error {
var err error
newSolo, err = NewDispatcher(c.factory, false, c.pchannel, t.pos,
c.constructSubName(t.vchannel, false), mqwrapper.SubscriptionPositionUnknown, c.lagNotifyChan, c.lagTargets)
return err
}, retry.Attempts(10))
if err != nil {
log.Error("split failed", zap.Error(err))
panic(err)
}
newSolo.AddTarget(t)
c.soloDispatchers[t.vchannel] = newSolo
newSolo.Handle(start)
log.Info("split done")
}
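
To summarize the balancing logic above in a trimmed form: a solo dispatcher is merged back into the main one once their consume timestamps align (tryMerge re-checks the condition after pausing both sides), and a lagging target is split out into a new solo dispatcher that seeks from the target's last recorded position. The sketch below restates the two conditions with stand-in parameters rather than the real fields:

package example

import "time"

// shouldMerge mirrors the tryMerge precondition: a solo dispatcher is only
// merged into the main one when both have consumed up to the same timestamp.
func shouldMerge(mainTs, soloTs uint64) bool {
	return mainTs == soloTs
}

// shouldSplit mirrors why a target ends up in lagTargets: its output channel
// stayed blocked for longer than the tolerated lag, so it gets its own dispatcher.
func shouldSplit(blockedFor, maxTolerantLag time.Duration) bool {
	return blockedFor >= maxTolerantLag
}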

View File

@@ -0,0 +1,354 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package msgdispatcher
import (
"context"
"fmt"
"math/rand"
"reflect"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/suite"
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus/internal/mq/msgstream"
"github.com/milvus-io/milvus/internal/mq/msgstream/mqwrapper"
"github.com/milvus-io/milvus/internal/util/typeutil"
)
func TestManager(t *testing.T) {
t.Run("test add and remove dispatcher", func(t *testing.T) {
c := NewDispatcherManager("mock_pchannel_0", typeutil.ProxyRole, 1, newMockFactory())
assert.NotNil(t, c)
assert.Equal(t, 0, c.Num())
var offset int
for i := 0; i < 100; i++ {
r := rand.Intn(10) + 1
for j := 0; j < r; j++ {
offset++
_, err := c.Add(fmt.Sprintf("mock_vchannel_%d", offset), nil, mqwrapper.SubscriptionPositionUnknown)
assert.NoError(t, err)
assert.Equal(t, offset, c.Num())
}
for j := 0; j < rand.Intn(r); j++ {
c.Remove(fmt.Sprintf("mock_vchannel_%d", offset))
offset--
assert.Equal(t, offset, c.Num())
}
}
})
t.Run("test merge and split", func(t *testing.T) {
c := NewDispatcherManager("mock_pchannel_0", typeutil.ProxyRole, 1, newMockFactory())
assert.NotNil(t, c)
_, err := c.Add("mock_vchannel_0", nil, mqwrapper.SubscriptionPositionUnknown)
assert.NoError(t, err)
_, err = c.Add("mock_vchannel_1", nil, mqwrapper.SubscriptionPositionUnknown)
assert.NoError(t, err)
_, err = c.Add("mock_vchannel_2", nil, mqwrapper.SubscriptionPositionUnknown)
assert.NoError(t, err)
assert.Equal(t, 3, c.Num())
c.(*dispatcherManager).tryMerge()
assert.Equal(t, 1, c.Num())
info := &target{
vchannel: "mock_vchannel_2",
pos: nil,
ch: nil,
}
c.(*dispatcherManager).split(info)
assert.Equal(t, 2, c.Num())
})
t.Run("test run and close", func(t *testing.T) {
c := NewDispatcherManager("mock_pchannel_0", typeutil.ProxyRole, 1, newMockFactory())
assert.NotNil(t, c)
_, err := c.Add("mock_vchannel_0", nil, mqwrapper.SubscriptionPositionUnknown)
assert.NoError(t, err)
_, err = c.Add("mock_vchannel_1", nil, mqwrapper.SubscriptionPositionUnknown)
assert.NoError(t, err)
_, err = c.Add("mock_vchannel_2", nil, mqwrapper.SubscriptionPositionUnknown)
assert.NoError(t, err)
assert.Equal(t, 3, c.Num())
CheckPeriod = 10 * time.Millisecond
go c.Run()
time.Sleep(15 * time.Millisecond)
assert.Equal(t, 1, c.Num()) // expected merged
assert.NotPanics(t, func() {
c.Close()
})
})
}
type vchannelHelper struct {
output <-chan *msgstream.MsgPack
pubInsMsgNum int
pubDelMsgNum int
pubDDLMsgNum int
pubPackNum int
subInsMsgNum int
subDelMsgNum int
subDDLMsgNum int
subPackNum int
}
type SimulationSuite struct {
suite.Suite
testVchannelNum int
manager DispatcherManager
pchannel string
vchannels map[string]*vchannelHelper
producer msgstream.MsgStream
factory msgstream.Factory
}
func (suite *SimulationSuite) SetupSuite() {
suite.factory = newMockFactory()
}
func (suite *SimulationSuite) SetupTest() {
suite.pchannel = fmt.Sprintf("by-dev-rootcoord-dispatcher-simulation-dml-%d-%d", rand.Int(), time.Now().UnixNano())
producer, err := newMockProducer(suite.factory, suite.pchannel)
assert.NoError(suite.T(), err)
suite.producer = producer
suite.manager = NewDispatcherManager(suite.pchannel, typeutil.DataNodeRole, 0, suite.factory)
CheckPeriod = 10 * time.Millisecond
go suite.manager.Run()
}
func (suite *SimulationSuite) produceMsg(wg *sync.WaitGroup) {
defer wg.Done()
const timeTickCount = 200
var uniqueMsgID int64
vchannelKeys := reflect.ValueOf(suite.vchannels).MapKeys()
for i := 1; i <= timeTickCount; i++ {
// produce random insert
insNum := rand.Intn(10)
for j := 0; j < insNum; j++ {
vchannel := vchannelKeys[rand.Intn(len(vchannelKeys))].Interface().(string)
err := suite.producer.Produce(&msgstream.MsgPack{
Msgs: []msgstream.TsMsg{genInsertMsg(rand.Intn(20)+1, vchannel, uniqueMsgID)},
})
assert.NoError(suite.T(), err)
uniqueMsgID++
suite.vchannels[vchannel].pubInsMsgNum++
}
// produce random delete
delNum := rand.Intn(2)
for j := 0; j < delNum; j++ {
vchannel := vchannelKeys[rand.Intn(len(vchannelKeys))].Interface().(string)
err := suite.producer.Produce(&msgstream.MsgPack{
Msgs: []msgstream.TsMsg{genDeleteMsg(rand.Intn(20)+1, vchannel, uniqueMsgID)},
})
assert.NoError(suite.T(), err)
uniqueMsgID++
suite.vchannels[vchannel].pubDelMsgNum++
}
// produce random ddl
ddlNum := rand.Intn(2)
for j := 0; j < ddlNum; j++ {
err := suite.producer.Produce(&msgstream.MsgPack{
Msgs: []msgstream.TsMsg{genDDLMsg(commonpb.MsgType_DropCollection)},
})
assert.NoError(suite.T(), err)
for k := range suite.vchannels {
suite.vchannels[k].pubDDLMsgNum++
}
}
// produce time tick
ts := uint64(i * 100)
err := suite.producer.Produce(&msgstream.MsgPack{
Msgs: []msgstream.TsMsg{genTimeTickMsg(ts)},
})
assert.NoError(suite.T(), err)
for k := range suite.vchannels {
suite.vchannels[k].pubPackNum++
}
}
suite.T().Logf("[%s] produce %d msgPack for %s done", time.Now(), timeTickCount, suite.pchannel)
}
func (suite *SimulationSuite) consumeMsg(ctx context.Context, wg *sync.WaitGroup, vchannel string) {
defer wg.Done()
var lastTs typeutil.Timestamp
for {
select {
case <-ctx.Done():
return
case <-time.After(2000 * time.Millisecond): // no message to consume
return
case pack := <-suite.vchannels[vchannel].output:
assert.Greater(suite.T(), pack.EndTs, lastTs)
lastTs = pack.EndTs
helper := suite.vchannels[vchannel]
helper.subPackNum++
for _, msg := range pack.Msgs {
switch msg.Type() {
case commonpb.MsgType_Insert:
helper.subInsMsgNum++
case commonpb.MsgType_Delete:
helper.subDelMsgNum++
case commonpb.MsgType_CreateCollection, commonpb.MsgType_DropCollection,
commonpb.MsgType_CreatePartition, commonpb.MsgType_DropPartition:
helper.subDDLMsgNum++
}
}
}
}
}
func (suite *SimulationSuite) produceTimeTickOnly(ctx context.Context) {
var tt = 1
for {
select {
case <-ctx.Done():
return
case <-time.After(10 * time.Millisecond):
ts := uint64(tt * 1000)
err := suite.producer.Produce(&msgstream.MsgPack{
Msgs: []msgstream.TsMsg{genTimeTickMsg(ts)},
})
assert.NoError(suite.T(), err)
tt++
}
}
}
func (suite *SimulationSuite) TestDispatchToVchannels() {
const vchannelNum = 20
suite.vchannels = make(map[string]*vchannelHelper, vchannelNum)
for i := 0; i < vchannelNum; i++ {
vchannel := fmt.Sprintf("%s_vchannelv%d", suite.pchannel, i)
output, err := suite.manager.Add(vchannel, nil, mqwrapper.SubscriptionPositionEarliest)
assert.NoError(suite.T(), err)
suite.vchannels[vchannel] = &vchannelHelper{output: output}
}
wg := &sync.WaitGroup{}
wg.Add(1)
go suite.produceMsg(wg)
wg.Wait()
for vchannel := range suite.vchannels {
wg.Add(1)
go suite.consumeMsg(context.Background(), wg, vchannel)
}
wg.Wait()
for _, helper := range suite.vchannels {
assert.Equal(suite.T(), helper.pubInsMsgNum, helper.subInsMsgNum)
assert.Equal(suite.T(), helper.pubDelMsgNum, helper.subDelMsgNum)
assert.Equal(suite.T(), helper.pubDDLMsgNum, helper.subDDLMsgNum)
assert.Equal(suite.T(), helper.pubPackNum, helper.subPackNum)
}
}
func (suite *SimulationSuite) TestMerge() {
ctx, cancel := context.WithCancel(context.Background())
go suite.produceTimeTickOnly(ctx)
const vchannelNum = 20
suite.vchannels = make(map[string]*vchannelHelper, vchannelNum)
positions, err := getSeekPositions(suite.factory, suite.pchannel, 200)
assert.NoError(suite.T(), err)
for i := 0; i < vchannelNum; i++ {
vchannel := fmt.Sprintf("%s_vchannelv%d", suite.pchannel, i)
output, err := suite.manager.Add(vchannel, positions[rand.Intn(len(positions))],
mqwrapper.SubscriptionPositionUnknown) // seek from random position
assert.NoError(suite.T(), err)
suite.vchannels[vchannel] = &vchannelHelper{output: output}
}
wg := &sync.WaitGroup{}
for vchannel := range suite.vchannels {
wg.Add(1)
go suite.consumeMsg(ctx, wg, vchannel)
}
suite.Eventually(func() bool {
suite.T().Logf("dispatcherManager.dispatcherNum = %d", suite.manager.Num())
return suite.manager.Num() == 1 // expected all merged, only mainDispatcher exist
}, 10*time.Second, 100*time.Millisecond)
cancel()
wg.Wait()
}
func (suite *SimulationSuite) TestSplit() {
ctx, cancel := context.WithCancel(context.Background())
go suite.produceTimeTickOnly(ctx)
const vchannelNum = 10
suite.vchannels = make(map[string]*vchannelHelper, vchannelNum)
DefaultTargetChanSize = 10
MaxTolerantLag = 500 * time.Millisecond
for i := 0; i < vchannelNum; i++ {
vchannel := fmt.Sprintf("%s_vchannelv%d", suite.pchannel, i)
output, err := suite.manager.Add(vchannel, nil, mqwrapper.SubscriptionPositionEarliest)
assert.NoError(suite.T(), err)
suite.vchannels[vchannel] = &vchannelHelper{output: output}
}
const splitNum = 3
wg := &sync.WaitGroup{}
counter := 0
for vchannel := range suite.vchannels {
wg.Add(1)
go suite.consumeMsg(ctx, wg, vchannel)
counter++
if counter >= len(suite.vchannels)-splitNum {
break
}
}
suite.Eventually(func() bool {
suite.T().Logf("dispatcherManager.dispatcherNum = %d, splitNum+1 = %d", suite.manager.Num(), splitNum+1)
return suite.manager.Num() == splitNum+1 // expected 1 mainDispatcher and `splitNum` soloDispatchers
}, 10*time.Second, 100*time.Millisecond)
cancel()
wg.Wait()
}
func (suite *SimulationSuite) TearDownTest() {
for vchannel := range suite.vchannels {
suite.manager.Remove(vchannel)
}
suite.manager.Close()
}
func (suite *SimulationSuite) TearDownSuite() {
}
func TestSimulation(t *testing.T) {
suite.Run(t, new(SimulationSuite))
}

View File

@@ -0,0 +1,213 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package msgdispatcher
import (
"context"
"fmt"
"math/rand"
"time"
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/milvus-io/milvus/internal/mq/msgstream"
"github.com/milvus-io/milvus/internal/mq/msgstream/mqwrapper"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/util/paramtable"
"github.com/milvus-io/milvus/internal/util/typeutil"
)
const (
dim = 128
)
func newMockFactory() msgstream.Factory {
paramtable.Init()
return msgstream.NewRmsFactory("/tmp/milvus/rocksmq/")
}
func newMockProducer(factory msgstream.Factory, pchannel string) (msgstream.MsgStream, error) {
stream, err := factory.NewMsgStream(context.Background())
if err != nil {
return nil, err
}
stream.AsProducer([]string{pchannel})
stream.SetRepackFunc(defaultInsertRepackFunc)
return stream, nil
}
func getSeekPositions(factory msgstream.Factory, pchannel string, maxNum int) ([]*msgstream.MsgPosition, error) {
stream, err := factory.NewTtMsgStream(context.Background())
if err != nil {
return nil, err
}
defer stream.Close()
stream.AsConsumer([]string{pchannel}, fmt.Sprintf("%d", rand.Int()), mqwrapper.SubscriptionPositionEarliest)
positions := make([]*msgstream.MsgPosition, 0)
for {
select {
case <-time.After(100 * time.Millisecond): // no message to consume
return positions, nil
case pack := <-stream.Chan():
positions = append(positions, pack.EndPositions[0])
if len(positions) >= maxNum {
return positions, nil
}
}
}
}
func genPKs(numRows int) []typeutil.IntPrimaryKey {
ids := make([]typeutil.IntPrimaryKey, numRows)
for i := 0; i < numRows; i++ {
ids[i] = typeutil.IntPrimaryKey(i)
}
return ids
}
func genTimestamps(numRows int) []typeutil.Timestamp {
ts := make([]typeutil.Timestamp, numRows)
for i := 0; i < numRows; i++ {
ts[i] = typeutil.Timestamp(i + 1)
}
return ts
}
func genInsertMsg(numRows int, vchannel string, msgID typeutil.UniqueID) *msgstream.InsertMsg {
floatVec := make([]float32, numRows*dim)
for i := 0; i < numRows*dim; i++ {
floatVec[i] = rand.Float32()
}
hashValues := make([]uint32, numRows)
for i := 0; i < numRows; i++ {
hashValues[i] = uint32(1)
}
return &msgstream.InsertMsg{
BaseMsg: msgstream.BaseMsg{HashValues: hashValues},
InsertRequest: internalpb.InsertRequest{
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_Insert, MsgID: msgID},
ShardName: vchannel,
Timestamps: genTimestamps(numRows),
RowIDs: genPKs(numRows),
FieldsData: []*schemapb.FieldData{{
Field: &schemapb.FieldData_Vectors{
Vectors: &schemapb.VectorField{
Dim: dim,
Data: &schemapb.VectorField_FloatVector{FloatVector: &schemapb.FloatArray{Data: floatVec}},
},
},
}},
NumRows: uint64(numRows),
Version: internalpb.InsertDataVersion_ColumnBased,
},
}
}
func genDeleteMsg(numRows int, vchannel string, msgID typeutil.UniqueID) *msgstream.DeleteMsg {
return &msgstream.DeleteMsg{
BaseMsg: msgstream.BaseMsg{HashValues: make([]uint32, numRows)},
DeleteRequest: internalpb.DeleteRequest{
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_Delete, MsgID: msgID},
ShardName: vchannel,
PrimaryKeys: &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: genPKs(numRows),
},
},
},
Timestamps: genTimestamps(numRows),
NumRows: int64(numRows),
},
}
}
func genDDLMsg(msgType commonpb.MsgType) msgstream.TsMsg {
switch msgType {
case commonpb.MsgType_CreateCollection:
return &msgstream.CreateCollectionMsg{
BaseMsg: msgstream.BaseMsg{HashValues: []uint32{0}},
CreateCollectionRequest: internalpb.CreateCollectionRequest{
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_CreateCollection},
},
}
case commonpb.MsgType_DropCollection:
return &msgstream.DropCollectionMsg{
BaseMsg: msgstream.BaseMsg{HashValues: []uint32{0}},
DropCollectionRequest: internalpb.DropCollectionRequest{
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_DropCollection},
},
}
case commonpb.MsgType_CreatePartition:
return &msgstream.CreatePartitionMsg{
BaseMsg: msgstream.BaseMsg{HashValues: []uint32{0}},
CreatePartitionRequest: internalpb.CreatePartitionRequest{
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_CreatePartition},
},
}
case commonpb.MsgType_DropPartition:
return &msgstream.DropPartitionMsg{
BaseMsg: msgstream.BaseMsg{HashValues: []uint32{0}},
DropPartitionRequest: internalpb.DropPartitionRequest{
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_DropPartition},
},
}
}
return nil
}
func genTimeTickMsg(ts typeutil.Timestamp) *msgstream.TimeTickMsg {
return &msgstream.TimeTickMsg{
BaseMsg: msgstream.BaseMsg{HashValues: []uint32{0}},
TimeTickMsg: internalpb.TimeTickMsg{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_TimeTick,
Timestamp: ts,
},
},
}
}
// defaultInsertRepackFunc repacks the dml messages.
func defaultInsertRepackFunc(
tsMsgs []msgstream.TsMsg,
hashKeys [][]int32,
) (map[int32]*msgstream.MsgPack, error) {
if len(hashKeys) < len(tsMsgs) {
return nil, fmt.Errorf(
"the length of hash keys (%d) is less than the length of messages (%d)",
len(hashKeys),
len(tsMsgs),
)
}
// after assigning segment IDs to the messages, tsMsgs have already been re-bucketed
pack := make(map[int32]*msgstream.MsgPack)
for idx, msg := range tsMsgs {
if len(hashKeys[idx]) <= 0 {
return nil, fmt.Errorf("no hash key for %dth message", idx)
}
key := hashKeys[idx][0]
_, ok := pack[key]
if !ok {
pack[key] = &msgstream.MsgPack{}
}
pack[key].Msgs = append(pack[key].Msgs, msg)
}
return pack, nil
}
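// A minimal, illustrative sketch (not part of this change) of how the helpers
// above could be combined in a dispatcher test: build the rocksmq-backed
// factory, attach a producer to a hypothetical pchannel, publish one insert
// pack tagged with a hypothetical vchannel, then read back up to ten end
// positions. Real tests also produce time tick messages so the TtMsgStream
// used by getSeekPositions can advance and emit positions.
func produceAndCollectPositions() ([]*msgstream.MsgPosition, error) {
    const pchannel = "by-dev-rootcoord-dml_0" // hypothetical physical channel
    factory := newMockFactory()
    producer, err := newMockProducer(factory, pchannel)
    if err != nil {
        return nil, err
    }
    defer producer.Close()
    vchannel := pchannel + "_123v0" // hypothetical virtual channel on the pchannel
    pack := &msgstream.MsgPack{Msgs: []msgstream.TsMsg{genInsertMsg(16, vchannel, 1)}}
    if err := producer.Produce(pack); err != nil {
        return nil, err
    }
    return getSeekPositions(factory, pchannel, 10)
}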

View File

@ -0,0 +1,72 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package msgdispatcher
import (
"fmt"
"sync"
"time"
)
// TODO: dyh, move to config
var (
MaxTolerantLag = 3 * time.Second
DefaultTargetChanSize = 1024
)
type target struct {
vchannel string
ch chan *MsgPack
pos *Pos
closeMu sync.Mutex
closeOnce sync.Once
closed bool
}
func newTarget(vchannel string, pos *Pos) *target {
t := &target{
vchannel: vchannel,
ch: make(chan *MsgPack, DefaultTargetChanSize),
pos: pos,
}
t.closed = false
return t
}
func (t *target) close() {
t.closeMu.Lock()
defer t.closeMu.Unlock()
t.closeOnce.Do(func() {
t.closed = true
close(t.ch)
})
}
func (t *target) send(pack *MsgPack) error {
t.closeMu.Lock()
defer t.closeMu.Unlock()
if t.closed {
return nil
}
select {
case <-time.After(MaxTolerantLag):
return fmt.Errorf("send target timeout, vchannel=%s, timeout=%s", t.vchannel, MaxTolerantLag)
case t.ch <- pack:
return nil
}
}
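// Illustrative only, not code from this change: a sketch of how the dispatcher
// side might hand consumed packs to a per-vchannel target. forwardPacks and the
// packs channel are hypothetical; the behavior it relies on is that send()
// fails once the downstream consumer lags for more than MaxTolerantLag, and
// close() turns later sends into no-ops instead of panics on a closed channel.
func forwardPacks(vchannel string, pos *Pos, packs <-chan *MsgPack) error {
    t := newTarget(vchannel, pos)
    defer t.close()
    for pack := range packs {
        if err := t.send(pack); err != nil {
            // the consumer of t.ch is too slow; surface the lag to the caller
            return err
        }
    }
    return nil
}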

View File

@ -53,6 +53,7 @@ type TsMsg interface {
Unmarshal(MarshalType) (TsMsg, error) Unmarshal(MarshalType) (TsMsg, error)
Position() *MsgPosition Position() *MsgPosition
SetPosition(*MsgPosition) SetPosition(*MsgPosition)
VChannel() string
} }
// BaseMsg is a basic structure that contains begin timestamp, end timestamp and the position of msgstream // BaseMsg is a basic structure that contains begin timestamp, end timestamp and the position of msgstream
@ -62,6 +63,7 @@ type BaseMsg struct {
EndTimestamp Timestamp EndTimestamp Timestamp
HashValues []uint32 HashValues []uint32
MsgPosition *MsgPosition MsgPosition *MsgPosition
Vchannel string
} }
// TraceCtx returns the context of opentracing // TraceCtx returns the context of opentracing
@ -99,6 +101,10 @@ func (bm *BaseMsg) SetPosition(position *MsgPosition) {
bm.MsgPosition = position bm.MsgPosition = position
} }
func (bm *BaseMsg) VChannel() string {
return bm.Vchannel
}
func convertToByteArray(input interface{}) ([]byte, error) { func convertToByteArray(input interface{}) ([]byte, error) {
switch output := input.(type) { switch output := input.(type) {
case []byte: case []byte:
@ -170,6 +176,7 @@ func (it *InsertMsg) Unmarshal(input MarshalType) (TsMsg, error) {
insertMsg.BeginTimestamp = timestamp insertMsg.BeginTimestamp = timestamp
} }
} }
insertMsg.Vchannel = insertMsg.ShardName
return insertMsg, nil return insertMsg, nil
} }
@ -278,6 +285,7 @@ func (it *InsertMsg) IndexMsg(index int) *InsertMsg {
Ctx: it.TraceCtx(), Ctx: it.TraceCtx(),
BeginTimestamp: it.BeginTimestamp, BeginTimestamp: it.BeginTimestamp,
EndTimestamp: it.EndTimestamp, EndTimestamp: it.EndTimestamp,
Vchannel: it.Vchannel,
HashValues: it.HashValues, HashValues: it.HashValues,
MsgPosition: it.MsgPosition, MsgPosition: it.MsgPosition,
}, },
@ -361,7 +369,7 @@ func (dt *DeleteMsg) Unmarshal(input MarshalType) (TsMsg, error) {
deleteMsg.BeginTimestamp = timestamp deleteMsg.BeginTimestamp = timestamp
} }
} }
deleteMsg.Vchannel = deleteRequest.ShardName
return deleteMsg, nil return deleteMsg, nil
} }
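// A hedged, standalone sketch (not the dispatcher's actual implementation) of
// why TsMsg gains VChannel(): a pack consumed from one pchannel can be grouped
// into per-vchannel packs before being fanned out to the registered targets.
// groupByVChannel is a hypothetical helper name.
func groupByVChannel(pack *msgstream.MsgPack) map[string]*msgstream.MsgPack {
    result := make(map[string]*msgstream.MsgPack)
    for _, msg := range pack.Msgs {
        vchannel := msg.VChannel()
        if _, ok := result[vchannel]; !ok {
            result[vchannel] = &msgstream.MsgPack{
                BeginTs:        pack.BeginTs,
                EndTs:          pack.EndTs,
                StartPositions: pack.StartPositions,
                EndPositions:   pack.EndPositions,
            }
        }
        result[vchannel].Msgs = append(result[vchannel].Msgs, msg)
    }
    return result
}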

View File

@ -148,6 +148,7 @@ func assignSegmentID(ctx context.Context, insertMsg *msgstream.InsertMsg, result
msg.HashValues = append(msg.HashValues, insertMsg.HashValues[offset]) msg.HashValues = append(msg.HashValues, insertMsg.HashValues[offset])
msg.Timestamps = append(msg.Timestamps, insertMsg.Timestamps[offset]) msg.Timestamps = append(msg.Timestamps, insertMsg.Timestamps[offset])
msg.RowIDs = append(msg.RowIDs, insertMsg.RowIDs[offset]) msg.RowIDs = append(msg.RowIDs, insertMsg.RowIDs[offset])
msg.BaseMsg.Vchannel = channelName
msg.NumRows++ msg.NumRows++
requestSize += curRowMessageSize requestSize += curRowMessageSize
} }

View File

@ -268,6 +268,7 @@ func (dt *deleteTask) Execute(ctx context.Context) (err error) {
partitionName := dt.deleteMsg.PartitionName partitionName := dt.deleteMsg.PartitionName
proxyID := dt.deleteMsg.Base.SourceID proxyID := dt.deleteMsg.Base.SourceID
for index, key := range dt.deleteMsg.HashValues { for index, key := range dt.deleteMsg.HashValues {
vchannel := channelNames[key]
ts := dt.deleteMsg.Timestamps[index] ts := dt.deleteMsg.Timestamps[index]
_, ok := result[key] _, ok := result[key]
if !ok { if !ok {
@ -297,6 +298,8 @@ func (dt *deleteTask) Execute(ctx context.Context) (err error) {
curMsg.Timestamps = append(curMsg.Timestamps, dt.deleteMsg.Timestamps[index]) curMsg.Timestamps = append(curMsg.Timestamps, dt.deleteMsg.Timestamps[index])
typeutil.AppendIDs(curMsg.PrimaryKeys, dt.deleteMsg.PrimaryKeys, index) typeutil.AppendIDs(curMsg.PrimaryKeys, dt.deleteMsg.PrimaryKeys, index)
curMsg.NumRows++ curMsg.NumRows++
curMsg.ShardName = vchannel
curMsg.Vchannel = vchannel
} }
// send delete request to log broker // send delete request to log broker

View File

@ -439,6 +439,7 @@ func (it *upsertTask) deleteExecute(ctx context.Context, msgPack *msgstream.MsgP
curMsg.Timestamps = append(curMsg.Timestamps, it.upsertMsg.DeleteMsg.Timestamps[index]) curMsg.Timestamps = append(curMsg.Timestamps, it.upsertMsg.DeleteMsg.Timestamps[index])
typeutil.AppendIDs(curMsg.PrimaryKeys, it.upsertMsg.DeleteMsg.PrimaryKeys, index) typeutil.AppendIDs(curMsg.PrimaryKeys, it.upsertMsg.DeleteMsg.PrimaryKeys, index)
curMsg.NumRows++ curMsg.NumRows++
curMsg.ShardName = channelNames[key]
} }
// send delete request to log broker // send delete request to log broker

View File

@ -25,6 +25,7 @@ import (
"github.com/milvus-io/milvus/internal/log" "github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/metrics" "github.com/milvus-io/milvus/internal/metrics"
"github.com/milvus-io/milvus/internal/mq/msgdispatcher"
"github.com/milvus-io/milvus/internal/mq/msgstream" "github.com/milvus-io/milvus/internal/mq/msgstream"
"github.com/milvus-io/milvus/internal/util/funcutil" "github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/util/paramtable" "github.com/milvus-io/milvus/internal/util/paramtable"
@ -40,6 +41,7 @@ type dataSyncService struct {
metaReplica ReplicaInterface metaReplica ReplicaInterface
tSafeReplica TSafeReplicaInterface tSafeReplica TSafeReplicaInterface
dispClient msgdispatcher.Client
msFactory msgstream.Factory msFactory msgstream.Factory
} }
@ -51,7 +53,7 @@ func (dsService *dataSyncService) getFlowGraphNum() int {
} }
// addFlowGraphsForDMLChannels add flowGraphs to dmlChannel2FlowGraph // addFlowGraphsForDMLChannels add flowGraphs to dmlChannel2FlowGraph
func (dsService *dataSyncService) addFlowGraphsForDMLChannels(collectionID UniqueID, dmlChannels []string) (map[string]*queryNodeFlowGraph, error) { func (dsService *dataSyncService) addFlowGraphsForDMLChannels(collectionID UniqueID, dmlChannels map[string]*msgstream.MsgPosition) (map[string]*queryNodeFlowGraph, error) {
dsService.mu.Lock() dsService.mu.Lock()
defer dsService.mu.Unlock() defer dsService.mu.Unlock()
@ -61,7 +63,7 @@ func (dsService *dataSyncService) addFlowGraphsForDMLChannels(collectionID Uniqu
} }
results := make(map[string]*queryNodeFlowGraph) results := make(map[string]*queryNodeFlowGraph)
for _, channel := range dmlChannels { for channel, position := range dmlChannels {
if _, ok := dsService.dmlChannel2FlowGraph[channel]; ok { if _, ok := dsService.dmlChannel2FlowGraph[channel]; ok {
log.Warn("dml flow graph has been existed", log.Warn("dml flow graph has been existed",
zap.Any("collectionID", collectionID), zap.Any("collectionID", collectionID),
@ -74,7 +76,8 @@ func (dsService *dataSyncService) addFlowGraphsForDMLChannels(collectionID Uniqu
dsService.metaReplica, dsService.metaReplica,
dsService.tSafeReplica, dsService.tSafeReplica,
channel, channel,
dsService.msFactory) position,
dsService.dispClient)
if err != nil { if err != nil {
for _, fg := range results { for _, fg := range results {
fg.flowGraph.Close() fg.flowGraph.Close()
@ -128,7 +131,7 @@ func (dsService *dataSyncService) addFlowGraphsForDeltaChannels(collectionID Uni
dsService.metaReplica, dsService.metaReplica,
dsService.tSafeReplica, dsService.tSafeReplica,
channel, channel,
dsService.msFactory) dsService.dispClient)
if err != nil { if err != nil {
for channel, fg := range results { for channel, fg := range results {
fg.flowGraph.Close() fg.flowGraph.Close()
@ -291,6 +294,7 @@ func (dsService *dataSyncService) removeEmptyFlowGraphByChannel(collectionID int
func newDataSyncService(ctx context.Context, func newDataSyncService(ctx context.Context,
metaReplica ReplicaInterface, metaReplica ReplicaInterface,
tSafeReplica TSafeReplicaInterface, tSafeReplica TSafeReplicaInterface,
dispClient msgdispatcher.Client,
factory msgstream.Factory) *dataSyncService { factory msgstream.Factory) *dataSyncService {
return &dataSyncService{ return &dataSyncService{
@ -299,6 +303,7 @@ func newDataSyncService(ctx context.Context,
deltaChannel2FlowGraph: make(map[Channel]*queryNodeFlowGraph), deltaChannel2FlowGraph: make(map[Channel]*queryNodeFlowGraph),
metaReplica: metaReplica, metaReplica: metaReplica,
tSafeReplica: tSafeReplica, tSafeReplica: tSafeReplica,
dispClient: dispClient,
msFactory: factory, msFactory: factory,
} }
} }
@ -308,6 +313,7 @@ func (dsService *dataSyncService) close() {
// close DML flow graphs // close DML flow graphs
for channel, nodeFG := range dsService.dmlChannel2FlowGraph { for channel, nodeFG := range dsService.dmlChannel2FlowGraph {
if nodeFG != nil { if nodeFG != nil {
dsService.dispClient.Deregister(channel)
nodeFG.flowGraph.Close() nodeFG.flowGraph.Close()
} }
delete(dsService.dmlChannel2FlowGraph, channel) delete(dsService.dmlChannel2FlowGraph, channel)
@ -315,6 +321,7 @@ func (dsService *dataSyncService) close() {
// close delta flow graphs // close delta flow graphs
for channel, nodeFG := range dsService.deltaChannel2FlowGraph { for channel, nodeFG := range dsService.deltaChannel2FlowGraph {
if nodeFG != nil { if nodeFG != nil {
dsService.dispClient.Deregister(channel)
nodeFG.flowGraph.Close() nodeFG.flowGraph.Close()
} }
delete(dsService.deltaChannel2FlowGraph, channel) delete(dsService.deltaChannel2FlowGraph, channel)

View File

@ -21,14 +21,20 @@ import (
"fmt" "fmt"
"testing" "testing"
"github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"github.com/milvus-io/milvus/internal/mq/msgdispatcher"
"github.com/milvus-io/milvus/internal/mq/msgstream"
"github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/util/paramtable"
"github.com/milvus-io/milvus/internal/util/typeutil"
) )
func init() { func init() {
rateCol, _ = newRateCollector() rateCol, _ = newRateCollector()
Params.Init()
} }
func TestDataSyncService_DMLFlowGraphs(t *testing.T) { func TestDataSyncService_DMLFlowGraphs(t *testing.T) {
@ -40,17 +46,18 @@ func TestDataSyncService_DMLFlowGraphs(t *testing.T) {
fac := genFactory() fac := genFactory()
assert.NoError(t, err) assert.NoError(t, err)
dispClient := msgdispatcher.NewClient(fac, typeutil.QueryNodeRole, paramtable.GetNodeID())
tSafe := newTSafeReplica() tSafe := newTSafeReplica()
dataSyncService := newDataSyncService(ctx, replica, tSafe, fac) dataSyncService := newDataSyncService(ctx, replica, tSafe, dispClient, fac)
assert.NotNil(t, dataSyncService) assert.NotNil(t, dataSyncService)
t.Run("test DMLFlowGraphs", func(t *testing.T) { t.Run("test DMLFlowGraphs", func(t *testing.T) {
_, err = dataSyncService.addFlowGraphsForDMLChannels(defaultCollectionID, []Channel{defaultDMLChannel}) _, err = dataSyncService.addFlowGraphsForDMLChannels(defaultCollectionID, map[Channel]*msgstream.MsgPosition{defaultDMLChannel: nil})
assert.NoError(t, err) assert.NoError(t, err)
assert.Len(t, dataSyncService.dmlChannel2FlowGraph, 1) assert.Len(t, dataSyncService.dmlChannel2FlowGraph, 1)
_, err = dataSyncService.addFlowGraphsForDMLChannels(defaultCollectionID, []Channel{defaultDMLChannel}) _, err = dataSyncService.addFlowGraphsForDMLChannels(defaultCollectionID, map[Channel]*msgstream.MsgPosition{defaultDMLChannel: nil})
assert.NoError(t, err) assert.NoError(t, err)
assert.Len(t, dataSyncService.dmlChannel2FlowGraph, 1) assert.Len(t, dataSyncService.dmlChannel2FlowGraph, 1)
@ -68,7 +75,7 @@ func TestDataSyncService_DMLFlowGraphs(t *testing.T) {
assert.Nil(t, fg) assert.Nil(t, fg)
assert.Error(t, err) assert.Error(t, err)
_, err = dataSyncService.addFlowGraphsForDMLChannels(defaultCollectionID, []Channel{defaultDMLChannel}) _, err = dataSyncService.addFlowGraphsForDMLChannels(defaultCollectionID, map[Channel]*msgstream.MsgPosition{defaultDMLChannel: nil})
assert.NoError(t, err) assert.NoError(t, err)
assert.Len(t, dataSyncService.dmlChannel2FlowGraph, 1) assert.Len(t, dataSyncService.dmlChannel2FlowGraph, 1)
@ -88,7 +95,7 @@ func TestDataSyncService_DMLFlowGraphs(t *testing.T) {
t.Run("test addFlowGraphsForDMLChannels checkReplica Failed", func(t *testing.T) { t.Run("test addFlowGraphsForDMLChannels checkReplica Failed", func(t *testing.T) {
err = dataSyncService.metaReplica.removeCollection(defaultCollectionID) err = dataSyncService.metaReplica.removeCollection(defaultCollectionID)
assert.NoError(t, err) assert.NoError(t, err)
_, err = dataSyncService.addFlowGraphsForDMLChannels(defaultCollectionID, []Channel{defaultDMLChannel}) _, err = dataSyncService.addFlowGraphsForDMLChannels(defaultCollectionID, map[Channel]*msgstream.MsgPosition{defaultDMLChannel: nil})
assert.Error(t, err) assert.Error(t, err)
dataSyncService.metaReplica.addCollection(defaultCollectionID, genTestCollectionSchema()) dataSyncService.metaReplica.addCollection(defaultCollectionID, genTestCollectionSchema())
}) })
@ -103,9 +110,10 @@ func TestDataSyncService_DeltaFlowGraphs(t *testing.T) {
fac := genFactory() fac := genFactory()
assert.NoError(t, err) assert.NoError(t, err)
dispClient := msgdispatcher.NewClient(fac, typeutil.QueryNodeRole, paramtable.GetNodeID())
tSafe := newTSafeReplica() tSafe := newTSafeReplica()
dataSyncService := newDataSyncService(ctx, replica, tSafe, fac) dataSyncService := newDataSyncService(ctx, replica, tSafe, dispClient, fac)
assert.NotNil(t, dataSyncService) assert.NotNil(t, dataSyncService)
t.Run("test DeltaFlowGraphs", func(t *testing.T) { t.Run("test DeltaFlowGraphs", func(t *testing.T) {
@ -160,12 +168,14 @@ func TestDataSyncService_DeltaFlowGraphs(t *testing.T) {
type DataSyncServiceSuite struct { type DataSyncServiceSuite struct {
suite.Suite suite.Suite
dispClient msgdispatcher.Client
factory dependency.Factory factory dependency.Factory
dsService *dataSyncService dsService *dataSyncService
} }
func (s *DataSyncServiceSuite) SetupSuite() { func (s *DataSyncServiceSuite) SetupSuite() {
s.factory = genFactory() s.factory = genFactory()
s.dispClient = msgdispatcher.NewClient(s.factory, typeutil.QueryNodeRole, paramtable.GetNodeID())
} }
func (s *DataSyncServiceSuite) SetupTest() { func (s *DataSyncServiceSuite) SetupTest() {
@ -176,7 +186,7 @@ func (s *DataSyncServiceSuite) SetupTest() {
s.Require().NoError(err) s.Require().NoError(err)
tSafe := newTSafeReplica() tSafe := newTSafeReplica()
s.dsService = newDataSyncService(ctx, replica, tSafe, s.factory) s.dsService = newDataSyncService(ctx, replica, tSafe, s.dispClient, s.factory)
s.Require().NoError(err) s.Require().NoError(err)
} }

View File

@ -18,22 +18,20 @@ package querynode
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"time" "time"
"github.com/milvus-io/milvus/internal/util/typeutil"
"go.uber.org/zap" "go.uber.org/zap"
"github.com/milvus-io/milvus/internal/log" "github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/metrics" "github.com/milvus-io/milvus/internal/metrics"
"github.com/milvus-io/milvus/internal/mq/msgdispatcher"
"github.com/milvus-io/milvus/internal/mq/msgstream" "github.com/milvus-io/milvus/internal/mq/msgstream"
"github.com/milvus-io/milvus/internal/mq/msgstream/mqwrapper" "github.com/milvus-io/milvus/internal/mq/msgstream/mqwrapper"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/util/flowgraph" "github.com/milvus-io/milvus/internal/util/flowgraph"
"github.com/milvus-io/milvus/internal/util/paramtable" "github.com/milvus-io/milvus/internal/util/paramtable"
"github.com/milvus-io/milvus/internal/util/tsoutil" "github.com/milvus-io/milvus/internal/util/tsoutil"
"github.com/milvus-io/milvus/internal/util/typeutil"
) )
type ( type (
@ -49,9 +47,9 @@ type queryNodeFlowGraph struct {
collectionID UniqueID collectionID UniqueID
vchannel Channel vchannel Channel
flowGraph *flowgraph.TimeTickedFlowGraph flowGraph *flowgraph.TimeTickedFlowGraph
dmlStream msgstream.MsgStream
tSafeReplica TSafeReplicaInterface tSafeReplica TSafeReplicaInterface
consumerCnt int consumerCnt int
dispClient msgdispatcher.Client
} }
// newQueryNodeFlowGraph returns a new queryNodeFlowGraph // newQueryNodeFlowGraph returns a new queryNodeFlowGraph
@ -60,16 +58,18 @@ func newQueryNodeFlowGraph(ctx context.Context,
metaReplica ReplicaInterface, metaReplica ReplicaInterface,
tSafeReplica TSafeReplicaInterface, tSafeReplica TSafeReplicaInterface,
vchannel Channel, vchannel Channel,
factory msgstream.Factory) (*queryNodeFlowGraph, error) { pos *msgstream.MsgPosition,
dispClient msgdispatcher.Client) (*queryNodeFlowGraph, error) {
q := &queryNodeFlowGraph{ q := &queryNodeFlowGraph{
collectionID: collectionID, collectionID: collectionID,
vchannel: vchannel, vchannel: vchannel,
tSafeReplica: tSafeReplica, tSafeReplica: tSafeReplica,
flowGraph: flowgraph.NewTimeTickedFlowGraph(ctx), flowGraph: flowgraph.NewTimeTickedFlowGraph(ctx),
dispClient: dispClient,
} }
dmStreamNode, err := q.newDmInputNode(ctx, factory, collectionID, vchannel, metrics.InsertLabel) dmStreamNode, err := q.newDmInputNode(collectionID, vchannel, pos, metrics.InsertLabel)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -123,16 +123,18 @@ func newQueryNodeDeltaFlowGraph(ctx context.Context,
metaReplica ReplicaInterface, metaReplica ReplicaInterface,
tSafeReplica TSafeReplicaInterface, tSafeReplica TSafeReplicaInterface,
vchannel Channel, vchannel Channel,
factory msgstream.Factory) (*queryNodeFlowGraph, error) { dispClient msgdispatcher.Client) (*queryNodeFlowGraph, error) {
q := &queryNodeFlowGraph{ q := &queryNodeFlowGraph{
collectionID: collectionID, collectionID: collectionID,
vchannel: vchannel, vchannel: vchannel,
tSafeReplica: tSafeReplica, tSafeReplica: tSafeReplica,
flowGraph: flowgraph.NewTimeTickedFlowGraph(ctx), flowGraph: flowgraph.NewTimeTickedFlowGraph(ctx),
dispClient: dispClient,
} }
dmStreamNode, err := q.newDmInputNode(ctx, factory, collectionID, vchannel, metrics.DeleteLabel) // use nil position, let deltaFlowGraph consume from latest.
dmStreamNode, err := q.newDmInputNode(collectionID, vchannel, nil, metrics.DeleteLabel)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -184,84 +186,45 @@ func newQueryNodeDeltaFlowGraph(ctx context.Context,
} }
// newDmInputNode returns a new inputNode // newDmInputNode returns a new inputNode
func (q *queryNodeFlowGraph) newDmInputNode(collectionID UniqueID, vchannel Channel, pos *msgstream.MsgPosition, dataType string) (*flowgraph.InputNode, error) {
func (q *queryNodeFlowGraph) newDmInputNode(ctx context.Context, factory msgstream.Factory, collectionID UniqueID, vchannel Channel, dataType string) (*flowgraph.InputNode, error) { log := log.With(zap.Int64("nodeID", paramtable.GetNodeID()),
insertStream, err := factory.NewTtMsgStream(ctx) zap.Int64("collection ID", collectionID),
zap.String("vchannel", vchannel))
var err error
var input <-chan *msgstream.MsgPack
tsBegin := time.Now()
if pos != nil && len(pos.MsgID) != 0 {
input, err = q.dispClient.Register(vchannel, pos, mqwrapper.SubscriptionPositionUnknown)
if err != nil { if err != nil {
return nil, err return nil, err
} }
log.Info("QueryNode seek successfully when register to msgDispatcher",
q.dmlStream = insertStream zap.ByteString("msgID", pos.GetMsgID()),
zap.Time("tsTime", tsoutil.PhysicalTime(pos.GetTimestamp())),
zap.Duration("tsLag", time.Since(tsoutil.PhysicalTime(pos.GetTimestamp()))),
zap.Duration("timeTaken", time.Since(tsBegin)))
} else {
input, err = q.dispClient.Register(vchannel, nil, mqwrapper.SubscriptionPositionLatest)
if err != nil {
return nil, err
}
log.Info("QueryNode consume successfully when register to msgDispatcher",
zap.Duration("timeTaken", time.Since(tsBegin)))
}
maxQueueLength := Params.QueryNodeCfg.FlowGraphMaxQueueLength.GetAsInt32() maxQueueLength := Params.QueryNodeCfg.FlowGraphMaxQueueLength.GetAsInt32()
maxParallelism := Params.QueryNodeCfg.FlowGraphMaxParallelism.GetAsInt32() maxParallelism := Params.QueryNodeCfg.FlowGraphMaxParallelism.GetAsInt32()
name := fmt.Sprintf("dmInputNode-query-%d-%s", collectionID, vchannel) name := fmt.Sprintf("dmInputNode-query-%d-%s", collectionID, vchannel)
node := flowgraph.NewInputNode(insertStream, name, maxQueueLength, maxParallelism, typeutil.QueryNodeRole, node := flowgraph.NewInputNode(input, name, maxQueueLength, maxParallelism, typeutil.QueryNodeRole,
paramtable.GetNodeID(), collectionID, dataType) paramtable.GetNodeID(), collectionID, dataType)
return node, nil return node, nil
} }
// consumeFlowGraph would consume by channel and subName
func (q *queryNodeFlowGraph) consumeFlowGraph(channel Channel, subName ConsumeSubName) error {
if q.dmlStream == nil {
return errors.New("null dml message stream in flow graph")
}
q.dmlStream.AsConsumer([]string{channel}, subName, mqwrapper.SubscriptionPositionUnknown)
log.Info("query node flow graph consumes from PositionUnknown",
zap.Int64("collectionID", q.collectionID),
zap.String("pchannel", channel),
zap.String("vchannel", q.vchannel),
zap.String("subName", subName),
)
q.consumerCnt++
metrics.QueryNodeNumConsumers.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Inc()
return nil
}
// consumeFlowGraphFromLatest would consume from latest by channel and subName
func (q *queryNodeFlowGraph) consumeFlowGraphFromLatest(channel Channel, subName ConsumeSubName) error {
if q.dmlStream == nil {
return errors.New("null dml message stream in flow graph")
}
q.dmlStream.AsConsumer([]string{channel}, subName, mqwrapper.SubscriptionPositionLatest)
log.Info("query node flow graph consumes from latest",
zap.Int64("collectionID", q.collectionID),
zap.String("pchannel", channel),
zap.String("vchannel", q.vchannel),
zap.String("subName", subName),
)
q.consumerCnt++
metrics.QueryNodeNumConsumers.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Inc()
return nil
}
// seekQueryNodeFlowGraph would seek by position
func (q *queryNodeFlowGraph) consumeFlowGraphFromPosition(position *internalpb.MsgPosition) error {
q.dmlStream.AsConsumer([]string{position.ChannelName}, position.MsgGroup, mqwrapper.SubscriptionPositionUnknown)
start := time.Now()
err := q.dmlStream.Seek([]*internalpb.MsgPosition{position})
// setup first ts
q.tSafeReplica.setTSafe(q.vchannel, position.GetTimestamp())
ts, _ := tsoutil.ParseTS(position.GetTimestamp())
log.Info("query node flow graph seeks from position",
zap.Int64("collectionID", q.collectionID),
zap.String("pchannel", position.ChannelName),
zap.String("vchannel", q.vchannel),
zap.Time("checkpointTs", ts),
zap.Duration("tsLag", time.Since(ts)),
zap.Duration("elapse", time.Since(start)),
)
q.consumerCnt++
metrics.QueryNodeNumConsumers.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Inc()
return err
}
// close would close queryNodeFlowGraph // close would close queryNodeFlowGraph
func (q *queryNodeFlowGraph) close() { func (q *queryNodeFlowGraph) close() {
q.dispClient.Deregister(q.vchannel)
q.flowGraph.Close() q.flowGraph.Close()
if q.dmlStream != nil && q.consumerCnt > 0 { if q.consumerCnt > 0 {
metrics.QueryNodeNumConsumers.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Sub(float64(q.consumerCnt)) metrics.QueryNodeNumConsumers.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Sub(float64(q.consumerCnt))
} }
log.Info("stop query node flow graph", log.Info("stop query node flow graph",

View File

@ -1,82 +0,0 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package querynode
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus/internal/proto/internalpb"
)
func TestQueryNodeFlowGraph_consumerFlowGraph(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
tSafe := newTSafeReplica()
streamingReplica, err := genSimpleReplica()
assert.NoError(t, err)
fac := genFactory()
fg, err := newQueryNodeFlowGraph(ctx,
defaultCollectionID,
streamingReplica,
tSafe,
defaultDMLChannel,
fac)
assert.NoError(t, err)
err = fg.consumeFlowGraph(defaultDMLChannel, defaultSubName)
assert.NoError(t, err)
fg.close()
}
func TestQueryNodeFlowGraph_seekQueryNodeFlowGraph(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
streamingReplica, err := genSimpleReplica()
assert.NoError(t, err)
fac := genFactory()
tSafe := newTSafeReplica()
fg, err := newQueryNodeFlowGraph(ctx,
defaultCollectionID,
streamingReplica,
tSafe,
defaultDMLChannel,
fac)
assert.NoError(t, err)
position := &internalpb.MsgPosition{
ChannelName: defaultDMLChannel,
MsgID: []byte{},
MsgGroup: defaultSubName,
Timestamp: 0,
}
err = fg.consumeFlowGraphFromPosition(position)
assert.Error(t, err)
fg.close()
}

View File

@ -26,7 +26,6 @@ import (
"github.com/milvus-io/milvus/internal/log" "github.com/milvus-io/milvus/internal/log"
queryPb "github.com/milvus-io/milvus/internal/proto/querypb" queryPb "github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/util/funcutil" "github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/util/paramtable"
"github.com/samber/lo" "github.com/samber/lo"
) )
@ -189,31 +188,6 @@ func (l *loadSegmentsTask) watchDeltaChannel(deltaChannels []string) error {
} }
} }
}() }()
consumeSubName := funcutil.GenChannelSubName(Params.CommonCfg.QueryNodeSubName.GetValue(), collectionID, paramtable.GetNodeID())
// channels as consumer
for channel, fg := range channel2FlowGraph {
pchannel := VPDeltaChannels[channel]
// use pChannel to consume
err = fg.consumeFlowGraphFromLatest(pchannel, consumeSubName)
if err != nil {
log.Error("msgStream as consumer failed for deltaChannels", zap.Int64("collectionID", collectionID), zap.Strings("vDeltaChannels", vDeltaChannels))
break
}
}
if err != nil {
log.Warn("watchDeltaChannel, add flowGraph for deltaChannel failed", zap.Int64("collectionID", collectionID), zap.Strings("vDeltaChannels", vDeltaChannels), zap.Error(err))
for _, fg := range channel2FlowGraph {
fg.flowGraph.Close()
}
gcChannels := make([]Channel, 0)
for channel := range channel2FlowGraph {
gcChannels = append(gcChannels, channel)
}
l.node.dataSyncService.removeFlowGraphsByDeltaChannels(gcChannels)
return err
}
log.Info("watchDeltaChannel, add flowGraph for deltaChannel success", zap.Int64("collectionID", collectionID), zap.Strings("vDeltaChannels", vDeltaChannels)) log.Info("watchDeltaChannel, add flowGraph for deltaChannel success", zap.Int64("collectionID", collectionID), zap.Strings("vDeltaChannels", vDeltaChannels))

View File

@ -35,6 +35,7 @@ import (
"github.com/milvus-io/milvus/internal/common" "github.com/milvus-io/milvus/internal/common"
etcdkv "github.com/milvus-io/milvus/internal/kv/etcd" etcdkv "github.com/milvus-io/milvus/internal/kv/etcd"
"github.com/milvus-io/milvus/internal/log" "github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/mq/msgdispatcher"
"github.com/milvus-io/milvus/internal/mq/msgstream" "github.com/milvus-io/milvus/internal/mq/msgstream"
"github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/proto/etcdpb" "github.com/milvus-io/milvus/internal/proto/etcdpb"
@ -1702,6 +1703,7 @@ func genSimpleQueryNodeWithMQFactory(ctx context.Context, fac dependency.Factory
etcdKV := etcdkv.NewEtcdKV(etcdCli, Params.EtcdCfg.MetaRootPath.GetValue()) etcdKV := etcdkv.NewEtcdKV(etcdCli, Params.EtcdCfg.MetaRootPath.GetValue())
node.etcdKV = etcdKV node.etcdKV = etcdKV
node.dispClient = msgdispatcher.NewClient(fac, typeutil.QueryNodeRole, paramtable.GetNodeID())
node.tSafeReplica = newTSafeReplica() node.tSafeReplica = newTSafeReplica()
replica, err := genSimpleReplicaWithSealSegment(ctx) replica, err := genSimpleReplicaWithSealSegment(ctx)
@ -1711,7 +1713,7 @@ func genSimpleQueryNodeWithMQFactory(ctx context.Context, fac dependency.Factory
node.tSafeReplica.addTSafe(defaultDMLChannel) node.tSafeReplica.addTSafe(defaultDMLChannel)
node.tSafeReplica.addTSafe(defaultDeltaChannel) node.tSafeReplica.addTSafe(defaultDeltaChannel)
node.dataSyncService = newDataSyncService(node.queryNodeLoopCtx, replica, node.tSafeReplica, node.factory) node.dataSyncService = newDataSyncService(node.queryNodeLoopCtx, replica, node.tSafeReplica, node.dispClient, node.factory)
node.metaReplica = replica node.metaReplica = replica

View File

@ -30,6 +30,7 @@ import "C"
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/milvus-io/milvus/internal/mq/msgdispatcher"
"os" "os"
"path" "path"
"runtime" "runtime"
@ -106,6 +107,7 @@ type QueryNode struct {
etcdCli *clientv3.Client etcdCli *clientv3.Client
address string address string
dispClient msgdispatcher.Client
factory dependency.Factory factory dependency.Factory
scheduler *taskScheduler scheduler *taskScheduler
@ -256,6 +258,9 @@ func (node *QueryNode) Init() error {
} }
log.Info("QueryNode init rateCollector done", zap.Int64("nodeID", paramtable.GetNodeID())) log.Info("QueryNode init rateCollector done", zap.Int64("nodeID", paramtable.GetNodeID()))
node.dispClient = msgdispatcher.NewClient(node.factory, typeutil.QueryNodeRole, paramtable.GetNodeID())
log.Info("QueryNode init dispatcher client done", zap.Int64("nodeID", paramtable.GetNodeID()))
node.vectorStorage, err = node.factory.NewPersistentStorageChunkManager(node.queryNodeLoopCtx) node.vectorStorage, err = node.factory.NewPersistentStorageChunkManager(node.queryNodeLoopCtx)
if err != nil { if err != nil {
log.Error("QueryNode init vector storage failed", zap.Error(err)) log.Error("QueryNode init vector storage failed", zap.Error(err))
@ -283,7 +288,7 @@ func (node *QueryNode) Init() error {
node.vectorStorage, node.vectorStorage,
node.factory) node.factory)
node.dataSyncService = newDataSyncService(node.queryNodeLoopCtx, node.metaReplica, node.tSafeReplica, node.factory) node.dataSyncService = newDataSyncService(node.queryNodeLoopCtx, node.metaReplica, node.tSafeReplica, node.dispClient, node.factory)
node.InitSegcore() node.InitSegcore()

View File

@ -27,13 +27,14 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.etcd.io/etcd/server/v3/embed" "go.etcd.io/etcd/server/v3/embed"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/internal/util/paramtable"
etcdkv "github.com/milvus-io/milvus/internal/kv/etcd" etcdkv "github.com/milvus-io/milvus/internal/kv/etcd"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/mq/msgdispatcher"
"github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/internal/util/etcd" "github.com/milvus-io/milvus/internal/util/etcd"
"github.com/milvus-io/milvus/internal/util/paramtable"
"github.com/milvus-io/milvus/internal/util/typeutil"
) )
var embedetcdServer *embed.Etcd var embedetcdServer *embed.Etcd
@ -98,10 +99,9 @@ func newQueryNodeMock() *QueryNode {
factory := newMessageStreamFactory() factory := newMessageStreamFactory()
svr := NewQueryNode(ctx, factory) svr := NewQueryNode(ctx, factory)
tsReplica := newTSafeReplica() tsReplica := newTSafeReplica()
svr.dispClient = msgdispatcher.NewClient(factory, typeutil.QueryNodeRole, paramtable.GetNodeID())
replica := newCollectionReplica() svr.metaReplica = newCollectionReplica()
svr.metaReplica = replica svr.dataSyncService = newDataSyncService(ctx, svr.metaReplica, tsReplica, svr.dispClient, factory)
svr.dataSyncService = newDataSyncService(ctx, svr.metaReplica, tsReplica, factory)
svr.vectorStorage, err = factory.NewPersistentStorageChunkManager(ctx) svr.vectorStorage, err = factory.NewPersistentStorageChunkManager(ctx)
if err != nil { if err != nil {
panic(err) panic(err)

View File

@ -458,7 +458,7 @@ func TestTask_releasePartitionTask(t *testing.T) {
req: genReleasePartitionsRequest(), req: genReleasePartitionsRequest(),
node: node, node: node,
} }
_, err = task.node.dataSyncService.addFlowGraphsForDMLChannels(defaultCollectionID, []Channel{defaultDMLChannel}) _, err = task.node.dataSyncService.addFlowGraphsForDMLChannels(defaultCollectionID, map[Channel]*msgstream.MsgPosition{defaultDMLChannel: nil})
assert.NoError(t, err) assert.NoError(t, err)
err = task.Execute(ctx) err = task.Execute(ctx)
assert.NoError(t, err) assert.NoError(t, err)
@ -534,7 +534,7 @@ func TestTask_releasePartitionTask(t *testing.T) {
req: genReleasePartitionsRequest(), req: genReleasePartitionsRequest(),
node: node, node: node,
} }
_, err = task.node.dataSyncService.addFlowGraphsForDMLChannels(defaultCollectionID, []Channel{defaultDMLChannel}) _, err = task.node.dataSyncService.addFlowGraphsForDMLChannels(defaultCollectionID, map[Channel]*msgstream.MsgPosition{defaultDMLChannel: nil})
assert.NoError(t, err) assert.NoError(t, err)
err = task.Execute(ctx) err = task.Execute(ctx)
assert.NoError(t, err) assert.NoError(t, err)

View File

@ -122,7 +122,7 @@ func (w *watchDmChannelsTask) Execute(ctx context.Context) (err error) {
} }
}() }()
channel2FlowGraph, err := w.initFlowGraph(ctx, collectionID, vChannels, VPChannels) channel2FlowGraph, err := w.initFlowGraph(collectionID, vChannels)
if err != nil { if err != nil {
return err return err
} }
@ -243,21 +243,16 @@ func (w *watchDmChannelsTask) LoadGrowingSegments(ctx context.Context, collectio
return unFlushedSegmentIDs, nil return unFlushedSegmentIDs, nil
} }
func (w *watchDmChannelsTask) initFlowGraph(ctx context.Context, collectionID UniqueID, vChannels []Channel, VPChannels map[string]string) (map[string]*queryNodeFlowGraph, error) { func (w *watchDmChannelsTask) initFlowGraph(collectionID UniqueID, vChannels []Channel) (map[string]*queryNodeFlowGraph, error) {
// So far, we don't support to enable each node with two different channel // So far, we don't support to enable each node with two different channel
consumeSubName := funcutil.GenChannelSubName(Params.CommonCfg.QueryNodeSubName.GetValue(), collectionID, paramtable.GetNodeID()) consumeSubName := funcutil.GenChannelSubName(Params.CommonCfg.QueryNodeSubName.GetValue(), collectionID, paramtable.GetNodeID())
// group channels by to seeking or consuming // group channels for seeking
channel2SeekPosition := make(map[string]*internalpb.MsgPosition) channel2SeekPosition := make(map[string]*internalpb.MsgPosition)
// for channel with no position
channel2AsConsumerPosition := make(map[string]*internalpb.MsgPosition)
for _, info := range w.req.Infos { for _, info := range w.req.Infos {
if info.SeekPosition == nil || len(info.SeekPosition.MsgID) == 0 { if info.SeekPosition != nil && len(info.SeekPosition.MsgID) != 0 {
channel2AsConsumerPosition[info.ChannelName] = info.SeekPosition
continue
}
info.SeekPosition.MsgGroup = consumeSubName info.SeekPosition.MsgGroup = consumeSubName
}
channel2SeekPosition[info.ChannelName] = info.SeekPosition channel2SeekPosition[info.ChannelName] = info.SeekPosition
} }
log.Info("watchDMChannel, group channels done", zap.Int64("collectionID", collectionID)) log.Info("watchDMChannel, group channels done", zap.Int64("collectionID", collectionID))
@ -333,49 +328,11 @@ func (w *watchDmChannelsTask) initFlowGraph(ctx context.Context, collectionID Un
) )
// add flow graph // add flow graph
channel2FlowGraph, err := w.node.dataSyncService.addFlowGraphsForDMLChannels(collectionID, vChannels) channel2FlowGraph, err := w.node.dataSyncService.addFlowGraphsForDMLChannels(collectionID, channel2SeekPosition)
if err != nil { if err != nil {
log.Warn("watchDMChannel, add flowGraph for dmChannels failed", zap.Int64("collectionID", collectionID), zap.Strings("vChannels", vChannels), zap.Error(err)) log.Warn("watchDMChannel, add flowGraph for dmChannels failed", zap.Int64("collectionID", collectionID), zap.Strings("vChannels", vChannels), zap.Error(err))
return nil, err return nil, err
} }
log.Info("Query node add DML flow graphs", zap.Int64("collectionID", collectionID), zap.Any("channels", vChannels))
// channels as consumer
for channel, fg := range channel2FlowGraph {
if _, ok := channel2AsConsumerPosition[channel]; ok {
// use pChannel to consume
err = fg.consumeFlowGraph(VPChannels[channel], consumeSubName)
if err != nil {
log.Error("msgStream as consumer failed for dmChannels", zap.Int64("collectionID", collectionID), zap.String("vChannel", channel))
break
}
}
if pos, ok := channel2SeekPosition[channel]; ok {
pos.MsgGroup = consumeSubName
// use pChannel to seek
pos.ChannelName = VPChannels[channel]
err = fg.consumeFlowGraphFromPosition(pos)
if err != nil {
log.Error("msgStream seek failed for dmChannels", zap.Int64("collectionID", collectionID), zap.String("vChannel", channel))
break
}
}
}
if err != nil {
log.Warn("watchDMChannel, add flowGraph for dmChannels failed", zap.Int64("collectionID", collectionID), zap.Strings("vChannels", vChannels), zap.Error(err))
for _, fg := range channel2FlowGraph {
fg.flowGraph.Close()
}
gcChannels := make([]Channel, 0)
for channel := range channel2FlowGraph {
gcChannels = append(gcChannels, channel)
}
w.node.dataSyncService.removeFlowGraphsByDMLChannels(gcChannels)
return nil, err
}
log.Info("watchDMChannel, add flowGraph for dmChannels success", zap.Int64("collectionID", collectionID), zap.Strings("vChannels", vChannels)) log.Info("watchDMChannel, add flowGraph for dmChannels success", zap.Int64("collectionID", collectionID), zap.Strings("vChannels", vChannels))
return channel2FlowGraph, nil return channel2FlowGraph, nil
} }

View File

@ -36,7 +36,7 @@ import (
// InputNode is the entry point of flowgraph // InputNode is the entry point of flowgraph
type InputNode struct { type InputNode struct {
BaseNode BaseNode
inStream msgstream.MsgStream input <-chan *msgstream.MsgPack
lastMsg *msgstream.MsgPack lastMsg *msgstream.MsgPack
name string name string
role string role string
@ -51,17 +51,6 @@ func (inNode *InputNode) IsInputNode() bool {
return true return true
} }
// Start is used to start input msgstream
func (inNode *InputNode) Start() {
}
// Close implements node
func (inNode *InputNode) Close() {
inNode.closeOnce.Do(func() {
inNode.inStream.Close()
})
}
func (inNode *InputNode) IsValidInMsg(in []Msg) bool { func (inNode *InputNode) IsValidInMsg(in []Msg) bool {
return true return true
} }
@ -71,16 +60,11 @@ func (inNode *InputNode) Name() string {
return inNode.name return inNode.name
} }
// InStream returns the internal MsgStream
func (inNode *InputNode) InStream() msgstream.MsgStream {
return inNode.inStream
}
// Operate consume a message pack from msgstream and return // Operate consume a message pack from msgstream and return
func (inNode *InputNode) Operate(in []Msg) []Msg { func (inNode *InputNode) Operate(in []Msg) []Msg {
msgPack, ok := <-inNode.inStream.Chan() msgPack, ok := <-inNode.input
if !ok { if !ok {
log.Warn("MsgStream closed", zap.Any("input node", inNode.Name())) log.Warn("input closed", zap.Any("input node", inNode.Name()))
if inNode.lastMsg != nil { if inNode.lastMsg != nil {
log.Info("trigger force sync", zap.Int64("collection", inNode.collectionID), zap.Any("position", inNode.lastMsg)) log.Info("trigger force sync", zap.Int64("collection", inNode.collectionID), zap.Any("position", inNode.lastMsg))
return []Msg{&MsgStreamMsg{ return []Msg{&MsgStreamMsg{
@ -151,15 +135,15 @@ func (inNode *InputNode) Operate(in []Msg) []Msg {
return []Msg{msgStreamMsg} return []Msg{msgStreamMsg}
} }
// NewInputNode composes an InputNode with provided MsgStream, name and parameters // NewInputNode composes an InputNode with provided input channel, name and parameters
func NewInputNode(inStream msgstream.MsgStream, nodeName string, maxQueueLength int32, maxParallelism int32, role string, nodeID int64, collectionID int64, dataType string) *InputNode { func NewInputNode(input <-chan *msgstream.MsgPack, nodeName string, maxQueueLength int32, maxParallelism int32, role string, nodeID int64, collectionID int64, dataType string) *InputNode {
baseNode := BaseNode{} baseNode := BaseNode{}
baseNode.SetMaxQueueLength(maxQueueLength) baseNode.SetMaxQueueLength(maxQueueLength)
baseNode.SetMaxParallelism(maxParallelism) baseNode.SetMaxParallelism(maxParallelism)
return &InputNode{ return &InputNode{
BaseNode: baseNode, BaseNode: baseNode,
inStream: inStream, input: input,
name: nodeName, name: nodeName,
role: role, role: role,
nodeID: nodeID, nodeID: nodeID,

View File

@ -40,7 +40,7 @@ func TestInputNode(t *testing.T) {
produceStream.Produce(&msgPack) produceStream.Produce(&msgPack)
nodeName := "input_node" nodeName := "input_node"
inputNode := NewInputNode(msgStream, nodeName, 100, 100, "", 0, 0, "") inputNode := NewInputNode(msgStream.Chan(), nodeName, 100, 100, "", 0, 0, "")
defer inputNode.Close() defer inputNode.Close()
isInputNode := inputNode.IsInputNode() isInputNode := inputNode.IsInputNode()
@ -49,9 +49,6 @@ func TestInputNode(t *testing.T) {
name := inputNode.Name() name := inputNode.Name()
assert.Equal(t, name, nodeName) assert.Equal(t, name, nodeName)
stream := inputNode.InStream()
assert.NotNil(t, stream)
output := inputNode.Operate(nil) output := inputNode.Operate(nil)
assert.NotNil(t, output) assert.NotNil(t, output)
msg, ok := output[0].(*MsgStreamMsg) msg, ok := output[0].(*MsgStreamMsg)
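// Illustrative sketch only (newTestInputNode is a hypothetical helper): with
// InputNode now reading from a plain <-chan *msgstream.MsgPack instead of a
// MsgStream, a unit test can drive it without any message queue by pushing
// packs into a buffered channel and closing it when done.
func newTestInputNode(packs ...*msgstream.MsgPack) *InputNode {
    ch := make(chan *msgstream.MsgPack, len(packs))
    for _, p := range packs {
        ch <- p
    }
    close(ch)
    return NewInputNode(ch, "test_input_node", 100, 100, "", 0, 0, "")
}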

View File

@ -76,6 +76,10 @@ func (bm *MockMsg) SetPosition(position *MsgPosition) {
} }
func (bm *MockMsg) VChannel() string {
return ""
}
func Test_GenerateMsgStreamMsg(t *testing.T) { func Test_GenerateMsgStreamMsg(t *testing.T) {
messages := make([]msgstream.TsMsg, 1) messages := make([]msgstream.TsMsg, 1)
messages[0] = &MockMsg{ messages[0] = &MockMsg{

View File

@ -74,7 +74,7 @@ func TestNodeCtx_Start(t *testing.T) {
produceStream.Produce(&msgPack) produceStream.Produce(&msgPack)
nodeName := "input_node" nodeName := "input_node"
inputNode := NewInputNode(msgStream, nodeName, 100, 100, "", 0, 0, "") inputNode := NewInputNode(msgStream.Chan(), nodeName, 100, 100, "", 0, 0, "")
node := &nodeCtx{ node := &nodeCtx{
node: inputNode, node: inputNode,