From 2ca81620eaf0090d177bb30aa824d8563f789a41 Mon Sep 17 00:00:00 2001 From: Jiquan Long Date: Thu, 9 Jun 2022 17:34:09 +0800 Subject: [PATCH] Reduce lock operations when get dml stream (#17468) Signed-off-by: longjiquan --- internal/proxy/channels_mgr.go | 74 +++++++++++------------------ internal/proxy/channels_mgr_test.go | 43 ++++------------- internal/proxy/task.go | 4 +- internal/proxy/task_test.go | 4 +- 4 files changed, 43 insertions(+), 82 deletions(-) diff --git a/internal/proxy/channels_mgr.go b/internal/proxy/channels_mgr.go index 6a3a4e6ffa..1434981dd6 100644 --- a/internal/proxy/channels_mgr.go +++ b/internal/proxy/channels_mgr.go @@ -38,8 +38,7 @@ import ( type channelsMgr interface { getChannels(collectionID UniqueID) ([]pChan, error) getVChannels(collectionID UniqueID) ([]vChan, error) - createDMLMsgStream(collectionID UniqueID) error - getDMLStream(collectionID UniqueID) (msgstream.MsgStream, error) + getOrCreateDmlStream(collectionID UniqueID) (msgstream.MsgStream, error) removeDMLStream(collectionID UniqueID) error removeAllDMLStream() error } @@ -182,12 +181,6 @@ func (mgr *singleTypeChannelsMgr) streamExistPrivate(collectionID UniqueID) bool return ok && streamInfos.stream != nil } -func (mgr *singleTypeChannelsMgr) streamExist(collectionID UniqueID) bool { - mgr.mu.RLock() - defer mgr.mu.RUnlock() - return mgr.streamExistPrivate(collectionID) -} - func createStream(factory msgstream.Factory, streamType streamType, pchans []pChan, repack repackFuncType) (msgstream.MsgStream, error) { var stream msgstream.MsgStream var err error @@ -213,14 +206,6 @@ func createStream(factory msgstream.Factory, streamType streamType, pchans []pCh return stream, nil } -func (mgr *singleTypeChannelsMgr) updateCollection(collectionID UniqueID, channelInfos channelInfos, stream msgstream.MsgStream) { - mgr.mu.Lock() - defer mgr.mu.Unlock() - if !mgr.streamExistPrivate(collectionID) { - mgr.infos[collectionID] = streamInfos{channelInfos: channelInfos, stream: stream} - } -} - func incPChansMetrics(pchans []pChan) { for _, pc := range pchans { metrics.ProxyMsgStreamObjectsForPChan.WithLabelValues(strconv.FormatInt(Params.ProxyCfg.GetNodeID(), 10), pc).Inc() @@ -234,35 +219,42 @@ func decPChanMetrics(pchans []pChan) { } // createMsgStream create message stream for specified collection. Idempotent. -// If stream already exists, directly return nil. -func (mgr *singleTypeChannelsMgr) createMsgStream(collectionID UniqueID) error { - if mgr.streamExist(collectionID) { - log.Info("stream already exist, no need to re-create", zap.Int64("collection_id", collectionID)) - return nil +// If stream already exists, directly return it and no error will be returned. +func (mgr *singleTypeChannelsMgr) createMsgStream(collectionID UniqueID) (msgstream.MsgStream, error) { + mgr.mu.RLock() + infos, ok := mgr.infos[collectionID] + if ok && infos.stream != nil { + // already exist. + mgr.mu.RUnlock() + return infos.stream, nil } + mgr.mu.RUnlock() channelInfos, err := mgr.getChannelsFunc(collectionID) if err != nil { + // What if stream created by other goroutines? log.Error("failed to get channels", zap.Error(err), zap.Int64("collection", collectionID)) - return err + return nil, err } stream, err := createStream(mgr.msgStreamFactory, mgr.singleStreamType, channelInfos.pchans, mgr.repackFunc) if err != nil { + // What if stream created by other goroutines? log.Error("failed to create message stream", zap.Error(err), zap.Int64("collection", collectionID)) - return err + return nil, err } - mgr.updateCollection(collectionID, channelInfos, stream) + mgr.mu.Lock() + defer mgr.mu.Unlock() + if !mgr.streamExistPrivate(collectionID) { + log.Info("create message stream", zap.Int64("collection", collectionID), + zap.Strings("virtual_channels", channelInfos.vchans), + zap.Strings("physical_channels", channelInfos.pchans)) + mgr.infos[collectionID] = streamInfos{channelInfos: channelInfos, stream: stream} + incPChansMetrics(channelInfos.pchans) + } - log.Info("create message stream", - zap.Int64("collection_id", collectionID), - zap.Strings("virtual_channels", channelInfos.vchans), - zap.Strings("physical_channels", channelInfos.pchans)) - - incPChansMetrics(channelInfos.pchans) - - return nil + return mgr.infos[collectionID].stream, nil } func (mgr *singleTypeChannelsMgr) lockGetStream(collectionID UniqueID) (msgstream.MsgStream, error) { @@ -275,18 +267,14 @@ func (mgr *singleTypeChannelsMgr) lockGetStream(collectionID UniqueID) (msgstrea return nil, fmt.Errorf("collection not found: %d", collectionID) } -// getStream get message stream of specified collection. +// getOrCreateStream get message stream of specified collection. // If stream don't exists, call createMsgStream to create for it. -func (mgr *singleTypeChannelsMgr) getStream(collectionID UniqueID) (msgstream.MsgStream, error) { +func (mgr *singleTypeChannelsMgr) getOrCreateStream(collectionID UniqueID) (msgstream.MsgStream, error) { if stream, err := mgr.lockGetStream(collectionID); err == nil { return stream, nil } - if err := mgr.createMsgStream(collectionID); err != nil { - return nil, err - } - - return mgr.lockGetStream(collectionID) + return mgr.createMsgStream(collectionID) } // removeStream remove the corresponding stream of the specified collection. Idempotent. @@ -343,12 +331,8 @@ func (mgr *channelsMgrImpl) getVChannels(collectionID UniqueID) ([]vChan, error) return mgr.dmlChannelsMgr.getVChannels(collectionID) } -func (mgr *channelsMgrImpl) createDMLMsgStream(collectionID UniqueID) error { - return mgr.dmlChannelsMgr.createMsgStream(collectionID) -} - -func (mgr *channelsMgrImpl) getDMLStream(collectionID UniqueID) (msgstream.MsgStream, error) { - return mgr.dmlChannelsMgr.getStream(collectionID) +func (mgr *channelsMgrImpl) getOrCreateDmlStream(collectionID UniqueID) (msgstream.MsgStream, error) { + return mgr.dmlChannelsMgr.getOrCreateStream(collectionID) } func (mgr *channelsMgrImpl) removeDMLStream(collectionID UniqueID) error { diff --git a/internal/proxy/channels_mgr_test.go b/internal/proxy/channels_mgr_test.go index 33fcae90b1..180eb509d8 100644 --- a/internal/proxy/channels_mgr_test.go +++ b/internal/proxy/channels_mgr_test.go @@ -205,31 +205,6 @@ func Test_singleTypeChannelsMgr_getVChannels(t *testing.T) { }) } -func Test_singleTypeChannelsMgr_streamExist(t *testing.T) { - t.Run("exist", func(t *testing.T) { - m := &singleTypeChannelsMgr{ - infos: map[UniqueID]streamInfos{ - 100: {stream: newSimpleMockMsgStream()}, - }, - } - exist := m.streamExist(100) - assert.True(t, exist) - }) - - t.Run("not exist", func(t *testing.T) { - m := &singleTypeChannelsMgr{ - infos: map[UniqueID]streamInfos{ - 100: {stream: nil}, - }, - } - exist := m.streamExist(100) - assert.False(t, exist) - m.infos = make(map[UniqueID]streamInfos) - exist = m.streamExist(100) - assert.False(t, exist) - }) -} - func Test_createStream(t *testing.T) { t.Run("failed to create msgstream", func(t *testing.T) { factory := newMockMsgStreamFactory() @@ -268,8 +243,9 @@ func Test_singleTypeChannelsMgr_createMsgStream(t *testing.T) { 100: {stream: newMockMsgStream()}, }, } - err := m.createMsgStream(100) + stream, err := m.createMsgStream(100) assert.NoError(t, err) + assert.NotNil(t, stream) }) t.Run("failed to get channels", func(t *testing.T) { @@ -278,7 +254,7 @@ func Test_singleTypeChannelsMgr_createMsgStream(t *testing.T) { return channelInfos{}, errors.New("mock") }, } - err := m.createMsgStream(100) + _, err := m.createMsgStream(100) assert.Error(t, err) }) @@ -295,7 +271,7 @@ func Test_singleTypeChannelsMgr_createMsgStream(t *testing.T) { singleStreamType: dmlStreamType, repackFunc: nil, } - err := m.createMsgStream(100) + _, err := m.createMsgStream(100) assert.Error(t, err) }) @@ -313,9 +289,10 @@ func Test_singleTypeChannelsMgr_createMsgStream(t *testing.T) { singleStreamType: dmlStreamType, repackFunc: nil, } - err := m.createMsgStream(100) + stream, err := m.createMsgStream(100) assert.NoError(t, err) - stream, err := m.getStream(100) + assert.NotNil(t, stream) + stream, err = m.getOrCreateStream(100) assert.NoError(t, err) assert.NotNil(t, stream) }) @@ -349,7 +326,7 @@ func Test_singleTypeChannelsMgr_getStream(t *testing.T) { 100: {stream: newMockMsgStream()}, }, } - stream, err := m.getStream(100) + stream, err := m.getOrCreateStream(100) assert.NoError(t, err) assert.NotNil(t, stream) }) @@ -361,7 +338,7 @@ func Test_singleTypeChannelsMgr_getStream(t *testing.T) { return channelInfos{}, errors.New("mock") }, } - _, err := m.getStream(100) + _, err := m.getOrCreateStream(100) assert.Error(t, err) }) @@ -379,7 +356,7 @@ func Test_singleTypeChannelsMgr_getStream(t *testing.T) { singleStreamType: dmlStreamType, repackFunc: nil, } - stream, err := m.getStream(100) + stream, err := m.getOrCreateStream(100) assert.NoError(t, err) assert.NotNil(t, stream) }) diff --git a/internal/proxy/task.go b/internal/proxy/task.go index f4eb316a51..5800f740f3 100644 --- a/internal/proxy/task.go +++ b/internal/proxy/task.go @@ -497,7 +497,7 @@ func (it *insertTask) Execute(ctx context.Context) error { it.PartitionID = partitionID tr.Record("get collection id & partition id from cache") - stream, err := it.chMgr.getDMLStream(collID) + stream, err := it.chMgr.getOrCreateDmlStream(collID) if err != nil { return err } @@ -3260,7 +3260,7 @@ func (dt *deleteTask) Execute(ctx context.Context) (err error) { tr := timerecord.NewTimeRecorder(fmt.Sprintf("proxy execute delete %d", dt.ID())) collID := dt.DeleteRequest.CollectionID - stream, err := dt.chMgr.getDMLStream(collID) + stream, err := dt.chMgr.getOrCreateDmlStream(collID) if err != nil { return err } diff --git a/internal/proxy/task_test.go b/internal/proxy/task_test.go index 17bf8cb2c9..ba71848d06 100644 --- a/internal/proxy/task_test.go +++ b/internal/proxy/task_test.go @@ -1716,7 +1716,7 @@ func TestTask_Int64PrimaryKey(t *testing.T) { chMgr := newChannelsMgrImpl(dmlChannelsFunc, nil, factory) defer chMgr.removeAllDMLStream() - err = chMgr.createDMLMsgStream(collectionID) + _, err = chMgr.getOrCreateDmlStream(collectionID) assert.NoError(t, err) pchans, err := chMgr.getChannels(collectionID) assert.NoError(t, err) @@ -1971,7 +1971,7 @@ func TestTask_VarCharPrimaryKey(t *testing.T) { chMgr := newChannelsMgrImpl(dmlChannelsFunc, nil, factory) defer chMgr.removeAllDMLStream() - err = chMgr.createDMLMsgStream(collectionID) + _, err = chMgr.getOrCreateDmlStream(collectionID) assert.NoError(t, err) pchans, err := chMgr.getChannels(collectionID) assert.NoError(t, err)