From 56c94cdfa76c4cdf9fd5832e5e07e8c65ded21ef Mon Sep 17 00:00:00 2001
From: XuanYang-cn
Date: Sun, 8 Oct 2023 21:37:33 +0800
Subject: [PATCH] Add channel manager in DataNode (#27308)

Signed-off-by: yangxuan
---
 internal/datanode/channel_manager.go          | 505 ++++++++++++++++++
 internal/datanode/channel_manager_test.go     | 192 +++++++
 internal/datanode/data_node.go                |   3 +-
 internal/datanode/data_sync_service.go        | 135 ++++-
 internal/datanode/data_sync_service_test.go   |  50 +-
 .../datanode/flow_graph_insert_buffer_node.go |   2 +-
 internal/datanode/flow_graph_manager.go       |  13 +-
 pkg/util/paramtable/component_param.go        |  10 +
 8 files changed, 890 insertions(+), 20 deletions(-)
 create mode 100644 internal/datanode/channel_manager.go
 create mode 100644 internal/datanode/channel_manager_test.go

diff --git a/internal/datanode/channel_manager.go b/internal/datanode/channel_manager.go
new file mode 100644
index 0000000000..9b5a0fac4c
--- /dev/null
+++ b/internal/datanode/channel_manager.go
@@ -0,0 +1,505 @@
+// Licensed to the LF AI & Data foundation under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
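+
+// channel_manager.go introduces a DataNode-side ChannelManager: DataCoord
+// submits ChannelWatchInfo operations (ToWatch / ToRelease), each channel is
+// served by its own opRunner goroutine, and every finished operation reports
+// an opState back through a shared channel so the manager can register or
+// tear down the corresponding flowgraph. A minimal usage sketch (illustrative
+// only; `dn` is assumed to be an initialized *DataNode):
+//
+//	cm := NewChannelManager(dn)
+//	cm.Start()
+//	defer cm.Close()
+//	err := cm.Submit(watchInfo)       // async; the op runs on the channel's opRunner
+//	resp := cm.GetProgress(watchInfo) // poll until a terminal state is reached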
+
+package datanode
+
+import (
+	"context"
+	"sync"
+	"time"
+
+	"github.com/cockroachdb/errors"
+	"go.uber.org/atomic"
+	"go.uber.org/zap"
+
+	"github.com/milvus-io/milvus/internal/proto/datapb"
+	"github.com/milvus-io/milvus/pkg/log"
+	"github.com/milvus-io/milvus/pkg/util/merr"
+	"github.com/milvus-io/milvus/pkg/util/typeutil"
+)
+
+type releaseFunc func(channel string)
+
+type ChannelManager struct {
+	mu sync.RWMutex
+	dn *DataNode
+
+	communicateCh     chan *opState
+	runningFlowgraphs *flowgraphManager
+	opRunners         *typeutil.ConcurrentMap[string, *opRunner] // channel -> runner
+	abnormals         *typeutil.ConcurrentMap[int64, string]     // OpID -> Channel
+
+	releaseFunc releaseFunc
+
+	closeCh     chan struct{}
+	closeOnce   sync.Once
+	closeWaiter sync.WaitGroup
+}
+
+func NewChannelManager(dn *DataNode) *ChannelManager {
+	fm := newFlowgraphManager()
+	cm := ChannelManager{
+		dn: dn,
+
+		communicateCh:     make(chan *opState, 100),
+		runningFlowgraphs: fm,
+		opRunners:         typeutil.NewConcurrentMap[string, *opRunner](),
+		abnormals:         typeutil.NewConcurrentMap[int64, string](),
+
+		releaseFunc: fm.release,
+
+		closeCh: make(chan struct{}),
+	}
+
+	return &cm
+}
+
+func (m *ChannelManager) Submit(info *datapb.ChannelWatchInfo) error {
+	channel := info.GetVchan().GetChannelName()
+	runner := m.getOrCreateRunner(channel)
+	return runner.Enqueue(info)
+}
+
+func (m *ChannelManager) GetProgress(info *datapb.ChannelWatchInfo) *datapb.ChannelOperationProgressResponse {
+	m.mu.RLock()
+	defer m.mu.RUnlock()
+	resp := &datapb.ChannelOperationProgressResponse{
+		Status: merr.Status(nil),
+		OpID:   info.GetOpID(),
+	}
+
+	channel := info.GetVchan().GetChannelName()
+	switch info.GetState() {
+	case datapb.ChannelWatchState_ToWatch:
+		if m.runningFlowgraphs.existWithOpID(channel, info.GetOpID()) {
+			resp.State = datapb.ChannelWatchState_WatchSuccess
+			return resp
+		}
+
+		if runner, ok := m.opRunners.Get(channel); ok {
+			if runner.Exist(info.GetOpID()) {
+				resp.State = datapb.ChannelWatchState_ToWatch
+			} else {
+				resp.State = datapb.ChannelWatchState_WatchFailure
+			}
+			return resp
+		}
+		resp.State = datapb.ChannelWatchState_WatchFailure
+		return resp
+
+	case datapb.ChannelWatchState_ToRelease:
+		if !m.runningFlowgraphs.exist(channel) {
+			resp.State = datapb.ChannelWatchState_ReleaseSuccess
+			return resp
+		}
+		if runner, ok := m.opRunners.Get(channel); ok && runner.Exist(info.GetOpID()) {
+			resp.State = datapb.ChannelWatchState_ToRelease
+			return resp
+		}
+
+		resp.State = datapb.ChannelWatchState_ReleaseFailure
+		return resp
+	default:
+		err := merr.WrapErrParameterInvalid("ToWatch or ToRelease", info.GetState().String())
+		log.Warn("fail to get progress", zap.Error(err))
+		resp.Status = merr.Status(err)
+		return resp
+	}
+}
+
+func (m *ChannelManager) Close() {
+	m.closeOnce.Do(func() {
+		m.opRunners.Range(func(channel string, runner *opRunner) bool {
+			runner.Close()
+			return true
+		})
+		m.runningFlowgraphs.close()
+		close(m.closeCh)
+		m.closeWaiter.Wait()
+	})
+}
+
+func (m *ChannelManager) Start() {
+	m.closeWaiter.Add(2)
+
+	go m.runningFlowgraphs.start(&m.closeWaiter)
+	go func() {
+		defer m.closeWaiter.Done()
+		log.Info("DataNode ChannelManager start")
+		for {
+			select {
+			case opState := <-m.communicateCh:
+				m.handleOpState(opState)
+			case <-m.closeCh:
+				log.Info("DataNode ChannelManager exit")
+				return
+			}
+		}
+	}()
+}
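+
+// handleOpState below is the single consumer of communicateCh, so flowgraph
+// registration and runner teardown are serialized. What DataCoord observes
+// through GetProgress is a simple state machine: a ToWatch op stays ToWatch
+// while its opID is still tracked by the runner, turns WatchSuccess once a
+// flowgraph with the same opID is running, and WatchFailure otherwise. An
+// illustrative caller-side poll (hypothetical loop, not part of this patch;
+// pollInterval is an assumed knob):
+//
+//	for {
+//		resp := cm.GetProgress(info)
+//		if resp.GetState() != datapb.ChannelWatchState_ToWatch {
+//			break // terminal: WatchSuccess or WatchFailure
+//		}
+//		time.Sleep(pollInterval)
+//	}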
zap.String("State", opState.state.String()), + ) + switch opState.state { + case datapb.ChannelWatchState_WatchSuccess: + log.Info("Success to watch") + m.runningFlowgraphs.Add(opState.fg) + m.finishOp(opState.opID, opState.channel) + + case datapb.ChannelWatchState_WatchFailure: + log.Info("Fail to watch") + m.finishOp(opState.opID, opState.channel) + + case datapb.ChannelWatchState_ReleaseSuccess: + log.Info("Success to release") + m.finishOp(opState.opID, opState.channel) + m.destoryRunner(opState.channel) + + case datapb.ChannelWatchState_ReleaseFailure: + log.Info("Fail to release, add channel to abnormal lists") + m.abnormals.Insert(opState.opID, opState.channel) + m.finishOp(opState.opID, opState.channel) + m.destoryRunner(opState.channel) + } +} + +func (m *ChannelManager) getOrCreateRunner(channel string) *opRunner { + runner, loaded := m.opRunners.GetOrInsert(channel, NewOpRunner(channel, m.dn, m.releaseFunc, m.communicateCh)) + if !loaded { + runner.Start() + } + return runner +} + +func (m *ChannelManager) destoryRunner(channel string) { + if runner, loaded := m.opRunners.GetAndRemove(channel); loaded { + runner.Close() + } +} + +func (m *ChannelManager) finishOp(opID int64, channel string) { + if runner, loaded := m.opRunners.Get(channel); loaded { + runner.FinishOp(opID) + } +} + +type opInfo struct { + tickler *tickler +} + +type opRunner struct { + channel string + dn *DataNode + releaseFunc releaseFunc + + guard sync.RWMutex + allOps map[UniqueID]*opInfo // opID -> tickler + opsInQueue chan *datapb.ChannelWatchInfo + resultCh chan *opState + + closeWg sync.WaitGroup + closeOnce sync.Once + closeCh chan struct{} +} + +func NewOpRunner(channel string, dn *DataNode, f releaseFunc, resultCh chan *opState) *opRunner { + return &opRunner{ + channel: channel, + dn: dn, + releaseFunc: f, + opsInQueue: make(chan *datapb.ChannelWatchInfo, 10), + allOps: make(map[UniqueID]*opInfo), + resultCh: resultCh, + closeCh: make(chan struct{}), + } +} + +func (r *opRunner) Start() { + r.closeWg.Add(1) + go func() { + defer r.closeWg.Done() + for { + select { + case info := <-r.opsInQueue: + r.NotifyState(r.Execute(info)) + case <-r.closeCh: + return + } + } + }() +} + +func (r *opRunner) FinishOp(opID UniqueID) { + r.guard.Lock() + defer r.guard.Unlock() + delete(r.allOps, opID) +} + +func (r *opRunner) Exist(opID UniqueID) bool { + r.guard.RLock() + defer r.guard.RUnlock() + _, ok := r.allOps[opID] + return ok +} + +func (r *opRunner) Enqueue(info *datapb.ChannelWatchInfo) error { + if info.GetState() != datapb.ChannelWatchState_ToWatch && + info.GetState() != datapb.ChannelWatchState_ToRelease { + return errors.New("Invalid channel watch state") + } + + r.guard.Lock() + defer r.guard.Unlock() + if _, ok := r.allOps[info.GetOpID()]; !ok { + r.opsInQueue <- info + r.allOps[info.GetOpID()] = &opInfo{} + } + return nil +} + +func (r *opRunner) UnfinishedOpSize() int { + r.guard.RLock() + defer r.guard.RUnlock() + return len(r.allOps) +} + +// Execute excutes channel operations, channel state is validated during enqueue +func (r *opRunner) Execute(info *datapb.ChannelWatchInfo) *opState { + log.Info("Start to execute channel operation", + zap.String("channel", info.GetVchan().GetChannelName()), + zap.Int64("opID", info.GetOpID()), + zap.String("state", info.GetState().String()), + ) + if info.GetState() == datapb.ChannelWatchState_ToWatch { + return r.watchWithTimer(info) + } + + // ToRelease state + return releaseWithTimer(r.releaseFunc, info.GetVchan().GetChannelName(), info.GetOpID()) +} + +// 
+// watchWithTimer returns WatchFailure if no progress is made within WatchTimeoutInterval.
+func (r *opRunner) watchWithTimer(info *datapb.ChannelWatchInfo) *opState {
+	opState := &opState{
+		channel: info.GetVchan().GetChannelName(),
+		opID:    info.GetOpID(),
+	}
+	log := log.With(zap.String("channel", opState.channel), zap.Int64("opID", opState.opID))
+
+	r.guard.Lock()
+	opInfo, ok := r.allOps[info.GetOpID()]
+	if !ok {
+		// release the lock before the early return to avoid deadlocking the runner
+		r.guard.Unlock()
+		opState.state = datapb.ChannelWatchState_WatchFailure
+		return opState
+	}
+	tickler := newTickler()
+	opInfo.tickler = tickler
+	r.guard.Unlock()
+
+	var (
+		successSig = make(chan struct{}, 1)
+		waiter     sync.WaitGroup
+	)
+
+	watchTimeout := Params.DataCoordCfg.WatchTimeoutInterval.GetAsDuration(time.Second)
+	ctx, cancel := context.WithTimeout(context.Background(), watchTimeout)
+	defer cancel()
+
+	startTimer := func(wg *sync.WaitGroup) {
+		defer wg.Done()
+
+		timer := time.NewTimer(watchTimeout)
+		defer timer.Stop()
+
+		log.Info("Start timer for ToWatch operation", zap.Duration("timeout", watchTimeout))
+		for {
+			select {
+			case <-timer.C:
+				// watch timed out
+				tickler.close()
+				cancel()
+				log.Info("Stop timer for ToWatch operation timeout", zap.Duration("timeout", watchTimeout))
+				return
+
+			case <-tickler.progressSig:
+				timer.Reset(watchTimeout)
+
+			case <-successSig:
+				// watch succeeded
+				log.Info("Stop timer for ToWatch operation succeeded", zap.Duration("timeout", watchTimeout))
+				return
+			}
+		}
+	}
+
+	waiter.Add(2)
+	go startTimer(&waiter)
+	go func() {
+		defer waiter.Done()
+		fg, err := executeWatch(ctx, r.dn, info, tickler)
+		if err != nil {
+			opState.state = datapb.ChannelWatchState_WatchFailure
+		} else {
+			opState.state = datapb.ChannelWatchState_WatchSuccess
+			opState.fg = fg
+			successSig <- struct{}{}
+		}
+	}()
+
+	waiter.Wait()
+	return opState
+}
+
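+// Note on the watch timeout: the timer above is reset on every tickler
+// progress signal, so WatchTimeoutInterval bounds the gap between two
+// progress events (e.g. two recovered segments), not the total watch
+// duration. A watch recovering many segments may legitimately run much
+// longer than one interval, as long as each segment lands within one
+// interval of the previous one.
+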
+// releaseWithTimer returns ReleaseFailure if the release does not finish within WatchTimeoutInterval.
+func releaseWithTimer(releaseFunc releaseFunc, channel string, opID UniqueID) *opState {
+	opState := &opState{
+		channel: channel,
+		opID:    opID,
+	}
+	var (
+		successSig = make(chan struct{}, 1)
+		waiter     sync.WaitGroup
+	)
+
+	log := log.With(zap.String("channel", channel))
+	startTimer := func(wg *sync.WaitGroup) {
+		defer wg.Done()
+		releaseTimeout := Params.DataCoordCfg.WatchTimeoutInterval.GetAsDuration(time.Second)
+		timer := time.NewTimer(releaseTimeout)
+		defer timer.Stop()
+
+		log.Info("Start timer for ToRelease operation", zap.Duration("timeout", releaseTimeout))
+		for {
+			select {
+			case <-timer.C:
+				log.Info("Stop timer for ToRelease operation timeout", zap.Duration("timeout", releaseTimeout))
+				opState.state = datapb.ChannelWatchState_ReleaseFailure
+				return
+
+			case <-successSig:
+				log.Info("Stop timer for ToRelease operation succeeded", zap.Duration("timeout", releaseTimeout))
+				opState.state = datapb.ChannelWatchState_ReleaseSuccess
+				return
+			}
+		}
+	}
+
+	waiter.Add(1)
+	go startTimer(&waiter)
+	go func() {
+		// TODO: a release failure should arguably panic this DataNode, but we
+		// don't yet know how to recover when releaseFunc gets stuck. Any stuck
+		// release is a bug that needs to be fixed. Since the behavior after a
+		// stuck release is unknown, we mark the channel abnormal on this
+		// DataNode instead; this goroutine might never return.
+		//
+		// The channel can still be balanced onto other DataNodes, just not this
+		// one. An ExclusiveConsumer error occurs when the same DataNode
+		// subscribes to the same pchannel twice.
+		releaseFunc(opState.channel)
+		successSig <- struct{}{}
+	}()
+
+	waiter.Wait()
+	return opState
+}
+
+func (r *opRunner) NotifyState(state *opState) {
+	r.resultCh <- state
+}
+
+func (r *opRunner) Close() {
+	r.guard.Lock()
+	for _, info := range r.allOps {
+		if info.tickler != nil {
+			info.tickler.close()
+		}
+	}
+	r.guard.Unlock()
+
+	r.closeOnce.Do(func() {
+		close(r.closeCh)
+		r.closeWg.Wait()
+	})
+}
+
+type opState struct {
+	channel string
+	opID    int64
+	state   datapb.ChannelWatchState
+	fg      *dataSyncService
+}
+
+// executeWatch always returns, either success or failure; it never blocks indefinitely.
+func executeWatch(ctx context.Context, dn *DataNode, info *datapb.ChannelWatchInfo, tickler *tickler) (*dataSyncService, error) {
+	dataSyncService, err := newDataSyncService(ctx, dn, info, tickler)
+	if err != nil {
+		return nil, err
+	}
+
+	dataSyncService.start()
+
+	return dataSyncService, nil
+}
+
+// tickler tracks recovery progress: inc() is called once per recovered segment.
+type tickler struct {
+	count     *atomic.Int32
+	total     *atomic.Int32
+	closedSig *atomic.Bool
+
+	progressSig chan struct{}
+}
+
+func (t *tickler) inc() {
+	t.count.Inc()
+	t.progressSig <- struct{}{}
+}
+
+func (t *tickler) setTotal(total int32) {
+	t.total.Store(total)
+}
+
+// progress returns the percentage of count over total if total is set,
+// otherwise it returns the raw count.
+func (t *tickler) progress() int32 {
+	if t.total.Load() == 0 {
+		return t.count.Load()
+	}
+	// multiply before dividing so integer division doesn't truncate to 0
+	return t.count.Load() * 100 / t.total.Load()
+}
+
+func (t *tickler) close() {
+	t.closedSig.CompareAndSwap(false, true)
+}
+
+func (t *tickler) closed() bool {
+	return t.closedSig.Load()
+}
+
+func newTickler() *tickler {
+	return &tickler{
+		count:       atomic.NewInt32(0),
+		total:       atomic.NewInt32(0),
+		closedSig:   atomic.NewBool(false),
+		progressSig: make(chan struct{}, 200),
+	}
+}
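+
+// A minimal tickler walkthrough (illustrative): the recovery path calls
+// setTotal once, then inc() per recovered segment; watchWithTimer consumes
+// progressSig to keep its timer alive, and close() tells producers the op
+// was abandoned.
+//
+//	t := newTickler()
+//	t.setTotal(int32(len(unflushed) + len(flushed)))
+//	t.inc()          // one segment recovered
+//	_ = t.progress() // percentage once total is set, raw count otherwise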
diff --git a/internal/datanode/channel_manager_test.go b/internal/datanode/channel_manager_test.go
new file mode 100644
index 0000000000..e7a8eb4a99
--- /dev/null
+++ b/internal/datanode/channel_manager_test.go
@@ -0,0 +1,192 @@
+// Licensed to the LF AI & Data foundation under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package datanode
+
+import (
+	"context"
+	"testing"
+
+	"github.com/stretchr/testify/suite"
+
+	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
+	"github.com/milvus-io/milvus/internal/proto/datapb"
+	"github.com/milvus-io/milvus/pkg/util/paramtable"
+)
+
+func TestChannelManagerSuite(t *testing.T) {
+	suite.Run(t, new(ChannelManagerSuite))
+}
+
+type ChannelManagerSuite struct {
+	suite.Suite
+
+	node    *DataNode
+	manager *ChannelManager
+}
+
+func (s *ChannelManagerSuite) SetupTest() {
+	ctx := context.Background()
+	s.node = newIDLEDataNodeMock(ctx, schemapb.DataType_Int64)
+	s.manager = NewChannelManager(s.node)
+}
+
+func getWatchInfoByOpID(opID UniqueID, channel string, state datapb.ChannelWatchState) *datapb.ChannelWatchInfo {
+	return &datapb.ChannelWatchInfo{
+		OpID:  opID,
+		State: state,
+		Vchan: &datapb.VchannelInfo{
+			CollectionID: 1,
+			ChannelName:  channel,
+		},
+	}
+}
+
+func (s *ChannelManagerSuite) TearDownTest() {
+	s.manager.Close()
+}
+
+func (s *ChannelManagerSuite) TestWatchFail() {
+	channel := "by-dev-rootcoord-dml-2"
+	paramtable.Get().Save(Params.DataCoordCfg.WatchTimeoutInterval.Key, "0.000001")
+	defer paramtable.Get().Reset(Params.DataCoordCfg.WatchTimeoutInterval.Key)
+	info := getWatchInfoByOpID(100, channel, datapb.ChannelWatchState_ToWatch)
+	s.Require().Equal(0, s.manager.opRunners.Len())
+	err := s.manager.Submit(info)
+	s.Require().NoError(err)
+
+	opState := <-s.manager.communicateCh
+	s.Require().NotNil(opState)
+	s.Equal(info.GetOpID(), opState.opID)
+	s.Equal(datapb.ChannelWatchState_WatchFailure, opState.state)
+
+	s.manager.handleOpState(opState)
+
+	resp := s.manager.GetProgress(info)
+	s.Equal(datapb.ChannelWatchState_WatchFailure, resp.GetState())
+}
+
+func (s *ChannelManagerSuite) TestReleaseStuck() {
+	var (
+		channel  = "by-dev-rootcoord-dml-2"
+		stuckSig = make(chan struct{})
+	)
+	s.manager.releaseFunc = func(channel string) {
+		stuckSig <- struct{}{}
+	}
+
+	info := getWatchInfoByOpID(100, channel, datapb.ChannelWatchState_ToWatch)
+	s.Require().Equal(0, s.manager.opRunners.Len())
+	err := s.manager.Submit(info)
+	s.Require().NoError(err)
+
+	opState := <-s.manager.communicateCh
+	s.Require().NotNil(opState)
+
+	s.manager.handleOpState(opState)
+
+	releaseInfo := getWatchInfoByOpID(101, channel, datapb.ChannelWatchState_ToRelease)
+	paramtable.Get().Save(Params.DataCoordCfg.WatchTimeoutInterval.Key, "0.1")
+	defer paramtable.Get().Reset(Params.DataCoordCfg.WatchTimeoutInterval.Key)
+
+	err = s.manager.Submit(releaseInfo)
+	s.NoError(err)
+
+	opState = <-s.manager.communicateCh
+	s.Require().NotNil(opState)
+	s.Equal(datapb.ChannelWatchState_ReleaseFailure, opState.state)
+	s.manager.handleOpState(opState)
+
+	s.Equal(1, s.manager.abnormals.Len())
+	abchannel, ok := s.manager.abnormals.Get(releaseInfo.GetOpID())
+	s.True(ok)
+	s.Equal(channel, abchannel)
+
+	<-stuckSig
+
+	resp := s.manager.GetProgress(releaseInfo)
+	s.Equal(datapb.ChannelWatchState_ReleaseFailure, resp.GetState())
+}
+
+func (s *ChannelManagerSuite) TestSubmitIdempotent() {
+	channel := "by-dev-rootcoord-dml-1"
+
+	info := getWatchInfoByOpID(100, channel, datapb.ChannelWatchState_ToWatch)
+	s.Require().Equal(0, s.manager.opRunners.Len())
+
+	for i := 0; i < 10; i++ {
+		err := s.manager.Submit(info)
+		s.NoError(err)
+	}
+
+	s.Equal(1, s.manager.opRunners.Len())
+	s.True(s.manager.opRunners.Contain(channel))
+
+	runner, ok := s.manager.opRunners.Get(channel)
+	s.True(ok)
+	s.Equal(1, runner.UnfinishedOpSize())
+}
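+
+// The tests above shrink WatchTimeoutInterval via paramtable to hit timeout
+// paths deterministically; the Save/Reset pair keeps the override test-local:
+//
+//	paramtable.Get().Save(Params.DataCoordCfg.WatchTimeoutInterval.Key, "0.000001")
+//	defer paramtable.Get().Reset(Params.DataCoordCfg.WatchTimeoutInterval.Key)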
+
+func (s *ChannelManagerSuite) TestSubmitWatchAndRelease() {
+	channel := "by-dev-rootcoord-dml-0"
+
+	info := getWatchInfoByOpID(100, channel, datapb.ChannelWatchState_ToWatch)
+
+	err := s.manager.Submit(info)
+	s.NoError(err)
+
+	opState := <-s.manager.communicateCh
+	s.NotNil(opState)
+	s.Equal(datapb.ChannelWatchState_WatchSuccess, opState.state)
+	s.NotNil(opState.fg)
+	s.Equal(info.GetOpID(), opState.fg.opID)
+
+	resp := s.manager.GetProgress(info)
+	s.Equal(info.GetOpID(), resp.GetOpID())
+	s.Equal(datapb.ChannelWatchState_ToWatch, resp.GetState())
+
+	s.manager.handleOpState(opState)
+	s.Equal(1, s.manager.runningFlowgraphs.getFlowGraphNum())
+	s.True(s.manager.opRunners.Contain(info.GetVchan().GetChannelName()))
+	s.Equal(1, s.manager.opRunners.Len())
+
+	resp = s.manager.GetProgress(info)
+	s.Equal(info.GetOpID(), resp.GetOpID())
+	s.Equal(datapb.ChannelWatchState_WatchSuccess, resp.GetState())
+
+	// release
+	info = getWatchInfoByOpID(101, channel, datapb.ChannelWatchState_ToRelease)
+
+	err = s.manager.Submit(info)
+	s.NoError(err)
+
+	resp = s.manager.GetProgress(info)
+	s.Equal(info.GetOpID(), resp.GetOpID())
+	s.Equal(datapb.ChannelWatchState_ToRelease, resp.GetState())
+
+	opState = <-s.manager.communicateCh
+	s.NotNil(opState)
+	s.Equal(datapb.ChannelWatchState_ReleaseSuccess, opState.state)
+	s.manager.handleOpState(opState)
+
+	resp = s.manager.GetProgress(info)
+	s.Equal(info.GetOpID(), resp.GetOpID())
+	s.Equal(datapb.ChannelWatchState_ReleaseSuccess, resp.GetState())
+
+	s.Equal(0, s.manager.runningFlowgraphs.getFlowGraphNum())
+	s.False(s.manager.opRunners.Contain(info.GetVchan().GetChannelName()))
+	s.Equal(0, s.manager.opRunners.Len())
+}
diff --git a/internal/datanode/data_node.go b/internal/datanode/data_node.go
index 80e3e26320..9156b6cc41 100644
--- a/internal/datanode/data_node.go
+++ b/internal/datanode/data_node.go
@@ -359,7 +359,8 @@ func (node *DataNode) Start() error {
 
 		// Start node watch node
 		go node.StartWatchChannels(node.ctx)
-		go node.flowgraphManager.start()
+		node.stopWaiter.Add(1)
+		go node.flowgraphManager.start(&node.stopWaiter)
 
 		node.UpdateStateCode(commonpb.StateCode_Healthy)
 	})
diff --git a/internal/datanode/data_sync_service.go b/internal/datanode/data_sync_service.go
index d36827bd03..ae0c649cf7 100644
--- a/internal/datanode/data_sync_service.go
+++ b/internal/datanode/data_sync_service.go
@@ -43,7 +43,8 @@ import (
 type dataSyncService struct {
 	ctx      context.Context
 	cancelFn context.CancelFunc
-	channel  Channel // channel stores meta of channel
+	channel  Channel // channel stores meta of channel
+	opID     int64
 
 	collectionID UniqueID // collection id of vchan for which this data sync service serves
 	vchannelName string
@@ -137,24 +138,81 @@ func (dsService *dataSyncService) clearGlobalFlushingCache() {
 	dsService.flushingSegCache.Remove(segments...)
 }
 
-// getSegmentInfos return the SegmentInfo details according to the given ids through RPC to datacoord
-// TODO: add a broker for the rpc
-func getSegmentInfos(ctx context.Context, datacoord types.DataCoordClient, segmentIDs []int64) ([]*datapb.SegmentInfo, error) {
-	infoResp, err := datacoord.GetSegmentInfo(ctx, &datapb.GetSegmentInfoRequest{
-		Base: commonpbutil.NewMsgBase(
-			commonpbutil.WithMsgType(commonpb.MsgType_SegmentInfo),
-			commonpbutil.WithMsgID(0),
-			commonpbutil.WithSourceID(paramtable.GetNodeID()),
-		),
-		SegmentIDs:       segmentIDs,
-		IncludeUnHealthy: true,
-	})
-	if err := merr.CheckRPCCall(infoResp, err); err != nil {
-		log.Error("Fail to get SegmentInfo by ids from datacoord", zap.Error(err))
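+// getChannelWithTickler rebuilds channel meta from the recovered segment
+// infos: each addSegment call is submitted to the shared IO pool, the tickler
+// is bumped once per finished segment (which resets the watch timer), and
+// conc.AwaitAll fails the whole recovery if any task errors.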
+func getChannelWithTickler(initCtx context.Context, node *DataNode, info *datapb.ChannelWatchInfo, tickler *tickler, unflushed, flushed []*datapb.SegmentInfo) (Channel, error) {
+	var (
+		channelName  = info.GetVchan().GetChannelName()
+		collectionID = info.GetVchan().GetCollectionID()
+		recoverTs    = info.GetVchan().GetSeekPosition().GetTimestamp()
+	)
+
+	// init channel meta
+	channel := newChannel(channelName, collectionID, info.GetSchema(), node.rootCoord, node.chunkManager)
+
+	// tickler will update addSegment progress to watchInfo
+	futures := make([]*conc.Future[any], 0, len(unflushed)+len(flushed))
+	tickler.setTotal(int32(len(unflushed) + len(flushed)))
+
+	for _, us := range unflushed {
+		log.Info("recover growing segments from checkpoints",
+			zap.String("vChannelName", us.GetInsertChannel()),
+			zap.Int64("segmentID", us.GetID()),
+			zap.Int64("numRows", us.GetNumOfRows()),
+		)
+
+		// avoid the closure capturing the iteration variable
+		segment := us
+		future := getOrCreateIOPool().Submit(func() (interface{}, error) {
+			if err := channel.addSegment(initCtx, addSegmentReq{
+				segType:      datapb.SegmentType_Normal,
+				segID:        segment.GetID(),
+				collID:       segment.CollectionID,
+				partitionID:  segment.PartitionID,
+				numOfRows:    segment.GetNumOfRows(),
+				statsBinLogs: segment.Statslogs,
+				binLogs:      segment.GetBinlogs(),
+				endPos:       segment.GetDmlPosition(),
+				recoverTs:    recoverTs,
+			}); err != nil {
+				return nil, err
+			}
+			tickler.inc()
+			return nil, nil
+		})
+		futures = append(futures, future)
+	}
+
+	for _, fs := range flushed {
+		log.Info("recover sealed segments from checkpoints",
+			zap.String("vChannelName", fs.GetInsertChannel()),
+			zap.Int64("segmentID", fs.GetID()),
+			zap.Int64("numRows", fs.GetNumOfRows()),
+		)
+		// avoid the closure capturing the iteration variable
+		segment := fs
+		future := getOrCreateIOPool().Submit(func() (interface{}, error) {
+			if err := channel.addSegment(initCtx, addSegmentReq{
+				segType:      datapb.SegmentType_Flushed,
+				segID:        segment.GetID(),
+				collID:       segment.GetCollectionID(),
+				partitionID:  segment.GetPartitionID(),
+				numOfRows:    segment.GetNumOfRows(),
+				statsBinLogs: segment.GetStatslogs(),
+				binLogs:      segment.GetBinlogs(),
+				recoverTs:    recoverTs,
+			}); err != nil {
+				return nil, err
+			}
+			tickler.inc()
+			return nil, nil
+		})
+		futures = append(futures, future)
+	}
+
+	if err := conc.AwaitAll(futures...); err != nil {
 		return nil, err
 	}
-	return infoResp.Infos, nil
+
+	return channel, nil
 }
 
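+// Unlike the etcd-based path below, the tickler-based recovery above reports
+// progress purely in memory to its opRunner; no per-segment etcd writes occur.
+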
 // getChannelWithEtcdTickler updates progress into etcd when a new segment is added into channel.
@@ -271,6 +329,7 @@ func getServiceWithChannel(initCtx context.Context, node *DataNode, info *datapb
 		flushCh:          flushCh,
 		resendTTCh:       resendTTCh,
 		delBufferManager: delBufferManager,
+		opID:             info.GetOpID(),
 
 		dispClient: node.dispClient,
 		msFactory:  node.factory,
@@ -375,3 +434,47 @@ func newServiceWithEtcdTickler(initCtx context.Context, node *DataNode, info *da
 	return getServiceWithChannel(initCtx, node, info, channel, unflushedSegmentInfos, flushedSegmentInfos)
 }
+
+// newDataSyncService returns a dataSyncService whose flowgraphs are not yet running.
+// initCtx is used only to initialize the dataSyncService: if initCtx is canceled or
+// times out, newDataSyncService stops and returns initCtx.Err().
+// NOTE: compatible with the event manager
+func newDataSyncService(initCtx context.Context, node *DataNode, info *datapb.ChannelWatchInfo, tickler *tickler) (*dataSyncService, error) {
+	// recover segment checkpoints
+	unflushedSegmentInfos, err := getSegmentInfos(initCtx, node.dataCoord, info.GetVchan().GetUnflushedSegmentIds())
+	if err != nil {
+		return nil, err
+	}
+	flushedSegmentInfos, err := getSegmentInfos(initCtx, node.dataCoord, info.GetVchan().GetFlushedSegmentIds())
+	if err != nil {
+		return nil, err
+	}
+
+	// init channel meta
+	channel, err := getChannelWithTickler(initCtx, node, info, tickler, unflushedSegmentInfos, flushedSegmentInfos)
+	if err != nil {
+		return nil, err
+	}
+
+	return getServiceWithChannel(initCtx, node, info, channel, unflushedSegmentInfos, flushedSegmentInfos)
+}
+
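+// An illustrative call shape, as used by executeWatch in channel_manager.go
+// (ctx carries the watch deadline from the opRunner's timer):
+//
+//	ds, err := newDataSyncService(ctx, node, info, tickler)
+//	if err == nil {
+//		ds.start() // flowgraphs only run after an explicit start()
+//	}
+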
+// getSegmentInfos returns the SegmentInfo details for the given ids via RPC to datacoord
+// TODO: add a broker for the rpc
+func getSegmentInfos(ctx context.Context, datacoord types.DataCoordClient, segmentIDs []int64) ([]*datapb.SegmentInfo, error) {
+	infoResp, err := datacoord.GetSegmentInfo(ctx, &datapb.GetSegmentInfoRequest{
+		Base: commonpbutil.NewMsgBase(
+			commonpbutil.WithMsgType(commonpb.MsgType_SegmentInfo),
+			commonpbutil.WithMsgID(0),
+			commonpbutil.WithSourceID(paramtable.GetNodeID()),
+		),
+		SegmentIDs:       segmentIDs,
+		IncludeUnHealthy: true,
+	})
+	if err := merr.CheckRPCCall(infoResp, err); err != nil {
+		log.Error("Fail to get SegmentInfo by ids from datacoord", zap.Error(err))
+		return nil, err
+	}
+
+	return infoResp.Infos, nil
+}
diff --git a/internal/datanode/data_sync_service_test.go b/internal/datanode/data_sync_service_test.go
index 4ba62077e0..4c8e3c6056 100644
--- a/internal/datanode/data_sync_service_test.go
+++ b/internal/datanode/data_sync_service_test.go
@@ -121,7 +121,7 @@ type testInfo struct {
 	description string
 }
 
-func TestDataSyncService_getDataSyncService(t *testing.T) {
+func TestDataSyncService_newDataSyncService(t *testing.T) {
 	ctx := context.Background()
 
 	tests := []*testInfo{
@@ -715,3 +715,51 @@ func TestGetChannelLatestMsgID(t *testing.T) {
 	assert.NoError(t, err)
 	assert.NotNil(t, id)
 }
+
+func TestGetChannelWithTickler(t *testing.T) {
+	channelName := "by-dev-rootcoord-dml-0"
+	info := getWatchInfoByOpID(100, channelName, datapb.ChannelWatchState_ToWatch)
+	node := newIDLEDataNodeMock(context.Background(), schemapb.DataType_Int64)
+	node.chunkManager = storage.NewLocalChunkManager(storage.RootPath(dataSyncServiceTestDir))
+	defer node.chunkManager.RemoveWithPrefix(context.Background(), node.chunkManager.RootPath())
+
+	unflushed := []*datapb.SegmentInfo{
+		{
+			ID:           100,
+			CollectionID: 1,
+			PartitionID:  10,
+			NumOfRows:    20,
+		},
+		{
+			ID:           101,
+			CollectionID: 1,
+			PartitionID:  10,
+			NumOfRows:    20,
+		},
+	}
+
+	flushed := []*datapb.SegmentInfo{
+		{
+			ID:           200,
+			CollectionID: 1,
+			PartitionID:  10,
+			NumOfRows:    20,
+		},
+		{
+			ID:           201,
+			CollectionID: 1,
+			PartitionID:  10,
+			NumOfRows:    20,
+		},
+	}
+
+	channel, err := getChannelWithTickler(context.TODO(), node, info, newTickler(), unflushed, flushed)
+	assert.NoError(t, err)
+	assert.NotNil(t, channel)
+	assert.Equal(t, channelName, channel.getChannelName(100))
+	assert.Equal(t, int64(1), channel.getCollectionID())
+	assert.True(t, channel.hasSegment(100, true))
+	assert.True(t, channel.hasSegment(101, true))
+	assert.True(t, channel.hasSegment(200, true))
+	assert.True(t, channel.hasSegment(201, true))
+}
diff --git a/internal/datanode/flow_graph_insert_buffer_node.go b/internal/datanode/flow_graph_insert_buffer_node.go
index 69b655dd17..b541d7cb06 100644
--- a/internal/datanode/flow_graph_insert_buffer_node.go
+++ b/internal/datanode/flow_graph_insert_buffer_node.go
@@ -763,7 +763,7 @@ func newInsertBufferNode(
 			commonpbutil.WithMsgType(commonpb.MsgType_DataNodeTt),
 			commonpbutil.WithMsgID(0),
 			commonpbutil.WithTimeStamp(ts),
-			commonpbutil.WithSourceID(config.serverID),
+			commonpbutil.WithSourceID(paramtable.GetNodeID()),
 		),
 		ChannelName: config.vChannelName,
 		Timestamp:   ts,
diff --git a/internal/datanode/flow_graph_manager.go b/internal/datanode/flow_graph_manager.go
index 5d175135b2..ac14cbb11c 100644
--- a/internal/datanode/flow_graph_manager.go
+++ b/internal/datanode/flow_graph_manager.go
@@ -49,7 +49,8 @@ func newFlowgraphManager() *flowgraphManager {
 	}
 }
 
-func (fm *flowgraphManager) start() {
+func (fm *flowgraphManager) start(waiter *sync.WaitGroup) {
+	defer waiter.Done()
 	ticker := time.NewTicker(3 * time.Second)
 	defer ticker.Stop()
 	for {
@@ -115,6 +116,11 @@ func (fm *flowgraphManager) execute(totalMemory uint64) {
 	}
 }
 
+func (fm *flowgraphManager) Add(ds *dataSyncService) {
+	fm.flowgraphs.Insert(ds.vchannelName, ds)
+	metrics.DataNodeNumFlowGraphs.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Inc()
+}
+
 func (fm *flowgraphManager) addAndStartWithEtcdTickler(dn *DataNode, vchan *datapb.VchannelInfo, schema *schemapb.CollectionSchema, tickler *etcdTickler) error {
 	log := log.With(zap.String("channel", vchan.GetChannelName()))
 	if fm.flowgraphs.Contain(vchan.GetChannelName()) {
@@ -215,6 +221,11 @@ func (fm *flowgraphManager) exist(vchan string) bool {
 	return exist
 }
 
+func (fm *flowgraphManager) existWithOpID(vchan string, opID UniqueID) bool {
+	ds, exist := fm.getFlowgraphService(vchan)
+	return exist && ds.opID == opID
+}
+
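+// existWithOpID lets GetProgress report WatchSuccess only when the running
+// flowgraph was created by this very operation: a leftover flowgraph from an
+// older op on the same channel (different opID) does not count.
+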
 // getFlowGraphNum returns number of flow graphs.
 func (fm *flowgraphManager) getFlowGraphNum() int {
 	return fm.flowgraphs.Len()
diff --git a/pkg/util/paramtable/component_param.go b/pkg/util/paramtable/component_param.go
index 475b5f1929..8bd4c81bdd 100644
--- a/pkg/util/paramtable/component_param.go
+++ b/pkg/util/paramtable/component_param.go
@@ -1950,6 +1950,7 @@ type dataCoordConfig struct {
 	WatchTimeoutInterval         ParamItem `refreshable:"false"`
 	ChannelBalanceSilentDuration ParamItem `refreshable:"true"`
 	ChannelBalanceInterval       ParamItem `refreshable:"true"`
+	ChannelOperationRPCTimeout   ParamItem `refreshable:"true"`
 
 	// --- SEGMENTS ---
 	SegmentMaxSize ParamItem `refreshable:"false"`
@@ -2027,6 +2028,15 @@ func (p *dataCoordConfig) init(base *BaseTable) {
 	}
 	p.ChannelBalanceInterval.Init(base.mgr)
 
+	p.ChannelOperationRPCTimeout = ParamItem{
+		Key:          "dataCoord.channel.notifyChannelOperationTimeout",
+		Version:      "2.2.3",
+		DefaultValue: "5",
+		Doc:          "Timeout for notifying channel operations (in seconds).",
+		Export:       true,
+	}
+	p.ChannelOperationRPCTimeout.Init(base.mgr)
+
 	p.SegmentMaxSize = ParamItem{
 		Key:     "dataCoord.segment.maxSize",
 		Version: "2.0.0",