diff --git a/internal/datanode/flow_graph_manager.go b/internal/datanode/flow_graph_manager.go index 22d95d7bc0..43abf6d7ab 100644 --- a/internal/datanode/flow_graph_manager.go +++ b/internal/datanode/flow_graph_manager.go @@ -31,10 +31,11 @@ import ( "github.com/milvus-io/milvus/pkg/util/hardware" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/typeutil" ) type flowgraphManager struct { - flowgraphs sync.Map // vChannelName -> dataSyncService + flowgraphs *typeutil.ConcurrentMap[string, *dataSyncService] closeCh chan struct{} closeOnce sync.Once @@ -42,7 +43,8 @@ type flowgraphManager struct { func newFlowgraphManager() *flowgraphManager { return &flowgraphManager{ - closeCh: make(chan struct{}), + flowgraphs: typeutil.NewConcurrentMap[string, *dataSyncService](), + closeCh: make(chan struct{}), } } @@ -75,12 +77,12 @@ func (fm *flowgraphManager) execute(totalMemory uint64) { channel string bufferSize int64 }, 0) - fm.flowgraphs.Range(func(key, value interface{}) bool { - size := value.(*dataSyncService).channel.getTotalMemorySize() + fm.flowgraphs.Range(func(key string, value *dataSyncService) bool { + size := value.channel.getTotalMemorySize() channels = append(channels, struct { channel string bufferSize int64 - }{key.(string), size}) + }{key, size}) total += size return true }) @@ -95,8 +97,8 @@ func (fm *flowgraphManager) execute(totalMemory uint64) { sort.Slice(channels, func(i, j int) bool { return channels[i].bufferSize > channels[j].bufferSize }) - if fg, ok := fm.flowgraphs.Load(channels[0].channel); ok { // sync the first channel with the largest memory usage - fg.(*dataSyncService).channel.forceToSync() + if fg, ok := fm.flowgraphs.Get(channels[0].channel); ok { // sync the first channel with the largest memory usage + fg.channel.forceToSync() log.Info("notify flowgraph to sync", zap.String("channel", channels[0].channel), zap.Int64("bufferSize", channels[0].bufferSize)) } @@ -104,7 +106,7 @@ func (fm *flowgraphManager) execute(totalMemory uint64) { func (fm *flowgraphManager) addAndStart(dn *DataNode, vchan *datapb.VchannelInfo, schema *schemapb.CollectionSchema, tickler *tickler) error { log := log.With(zap.String("channel", vchan.GetChannelName())) - if _, ok := fm.flowgraphs.Load(vchan.GetChannelName()); ok { + if fm.flowgraphs.Contain(vchan.GetChannelName()) { log.Warn("try to add an existed DataSyncService") return nil } @@ -118,15 +120,15 @@ func (fm *flowgraphManager) addAndStart(dn *DataNode, vchan *datapb.VchannelInfo return err } dataSyncService.start() - fm.flowgraphs.Store(vchan.GetChannelName(), dataSyncService) + fm.flowgraphs.Insert(vchan.GetChannelName(), dataSyncService) metrics.DataNodeNumFlowGraphs.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Inc() return nil } func (fm *flowgraphManager) release(vchanName string) { - if fg, loaded := fm.flowgraphs.LoadAndDelete(vchanName); loaded { - fg.(*dataSyncService).close() + if fg, loaded := fm.flowgraphs.GetAndRemove(vchanName); loaded { + fg.close() metrics.DataNodeNumFlowGraphs.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Dec() } rateCol.removeFlowGraphChannel(vchanName) @@ -135,8 +137,7 @@ func (fm *flowgraphManager) release(vchanName string) { func (fm *flowgraphManager) getFlushCh(segID UniqueID) (chan<- flushMsg, error) { var flushCh chan flushMsg - fm.flowgraphs.Range(func(key, value interface{}) bool { - fg := value.(*dataSyncService) + fm.flowgraphs.Range(func(key string, fg *dataSyncService) bool { if fg.channel.hasSegment(segID, true) { flushCh = fg.flushCh return false @@ -156,8 +157,7 @@ func (fm *flowgraphManager) getChannel(segID UniqueID) (Channel, error) { rep Channel exists = false ) - fm.flowgraphs.Range(func(key, value interface{}) bool { - fg := value.(*dataSyncService) + fm.flowgraphs.Range(func(key string, fg *dataSyncService) bool { if fg.channel.hasSegment(segID, true) { exists = true rep = fg.channel @@ -178,8 +178,7 @@ func (fm *flowgraphManager) getChannel(segID UniqueID) (Channel, error) { // these segments will be resent. func (fm *flowgraphManager) resendTT() []UniqueID { var unFlushedSegments []UniqueID - fm.flowgraphs.Range(func(key, value interface{}) bool { - fg := value.(*dataSyncService) + fm.flowgraphs.Range(func(key string, fg *dataSyncService) bool { segIDs := fg.channel.listNotFlushedSegmentIDs() if len(segIDs) > 0 { log.Info("un-flushed segments found, stats will be resend", @@ -195,12 +194,7 @@ func (fm *flowgraphManager) resendTT() []UniqueID { } func (fm *flowgraphManager) getFlowgraphService(vchan string) (*dataSyncService, bool) { - fg, ok := fm.flowgraphs.Load(vchan) - if ok { - return fg.(*dataSyncService), ok - } - - return nil, ok + return fm.flowgraphs.Get(vchan) } func (fm *flowgraphManager) exist(vchan string) bool { @@ -210,21 +204,16 @@ func (fm *flowgraphManager) exist(vchan string) bool { // getFlowGraphNum returns number of flow graphs. func (fm *flowgraphManager) getFlowGraphNum() int { - length := 0 - fm.flowgraphs.Range(func(_, _ interface{}) bool { - length++ - return true - }) - return length + return fm.flowgraphs.Len() } func (fm *flowgraphManager) dropAll() { log.Info("start drop all flowgraph resources in DataNode") - fm.flowgraphs.Range(func(key, value interface{}) bool { - value.(*dataSyncService).close() - fm.flowgraphs.Delete(key.(string)) + fm.flowgraphs.Range(func(key string, value *dataSyncService) bool { + value.close() + fm.flowgraphs.GetAndRemove(key) - log.Info("successfully dropped flowgraph", zap.String("vChannelName", key.(string))) + log.Info("successfully dropped flowgraph", zap.String("vChannelName", key)) return true }) } diff --git a/internal/datanode/flow_graph_manager_test.go b/internal/datanode/flow_graph_manager_test.go index 8f4a6a662b..b485aa1bec 100644 --- a/internal/datanode/flow_graph_manager_test.go +++ b/internal/datanode/flow_graph_manager_test.go @@ -218,19 +218,19 @@ func TestFlowGraphManager(t *testing.T) { } err = fm.addAndStart(node, vchan, nil, genTestTickler()) assert.NoError(t, err) - fg, ok := fm.flowgraphs.Load(vchannel) + fg, ok := fm.flowgraphs.Get(vchannel) assert.True(t, ok) - err = fg.(*dataSyncService).channel.addSegment(addSegmentReq{segID: 0}) + err = fg.channel.addSegment(addSegmentReq{segID: 0}) assert.NoError(t, err) - fg.(*dataSyncService).channel.updateSegmentMemorySize(0, memorySize) - fg.(*dataSyncService).channel.(*ChannelMeta).needToSync.Store(false) + fg.channel.updateSegmentMemorySize(0, memorySize) + fg.channel.(*ChannelMeta).needToSync.Store(false) } fm.execute(test.totalMemory) for i, needToSync := range test.expectNeedToSync { vchannel := fmt.Sprintf("%s%d", channelPrefix, i) - fg, ok := fm.flowgraphs.Load(vchannel) + fg, ok := fm.flowgraphs.Get(vchannel) assert.True(t, ok) - assert.Equal(t, needToSync, fg.(*dataSyncService).channel.(*ChannelMeta).needToSync.Load()) + assert.Equal(t, needToSync, fg.channel.(*ChannelMeta).needToSync.Load()) } } }) diff --git a/internal/datanode/metrics_info.go b/internal/datanode/metrics_info.go index c619953329..9726b899c9 100644 --- a/internal/datanode/metrics_info.go +++ b/internal/datanode/metrics_info.go @@ -51,8 +51,7 @@ func (node *DataNode) getQuotaMetrics() (*metricsinfo.DataNodeQuotaMetrics, erro getAllCollections := func() []int64 { collectionSet := typeutil.UniqueSet{} - node.flowgraphManager.flowgraphs.Range(func(key, value any) bool { - fg := value.(*dataSyncService) + node.flowgraphManager.flowgraphs.Range(func(key string, fg *dataSyncService) bool { collectionSet.Insert(fg.channel.getCollectionID()) return true })