diff --git a/internal/proxy/channels_mgr.go b/internal/proxy/channels_mgr.go index 28e56ebabe..69d3fbc9ef 100644 --- a/internal/proxy/channels_mgr.go +++ b/internal/proxy/channels_mgr.go @@ -235,6 +235,8 @@ func (mgr *singleTypeChannelsMgr) createMsgStream(collectionID UniqueID) (msgstr zap.Strings("physical_channels", channelInfos.pchans)) mgr.infos[collectionID] = streamInfos{channelInfos: channelInfos, stream: stream} incPChansMetrics(channelInfos.pchans) + } else { + stream.Close() } return mgr.infos[collectionID].stream, nil diff --git a/internal/proxy/channels_mgr_test.go b/internal/proxy/channels_mgr_test.go index a35c4a3e45..555fd18a95 100644 --- a/internal/proxy/channels_mgr_test.go +++ b/internal/proxy/channels_mgr_test.go @@ -18,6 +18,7 @@ package proxy import ( "context" + "sync" "testing" "github.com/cockroachdb/errors" @@ -251,6 +252,43 @@ func Test_singleTypeChannelsMgr_createMsgStream(t *testing.T) { assert.NotNil(t, stream) }) + t.Run("concurrent create", func(t *testing.T) { + factory := newMockMsgStreamFactory() + factory.f = func(ctx context.Context) (msgstream.MsgStream, error) { + return newMockMsgStream(), nil + } + stopCh := make(chan struct{}) + readyCh := make(chan struct{}) + m := &singleTypeChannelsMgr{ + infos: make(map[UniqueID]streamInfos), + getChannelsFunc: func(collectionID UniqueID) (channelInfos, error) { + close(readyCh) + <-stopCh + return channelInfos{vchans: []string{"111", "222"}, pchans: []string{"111"}}, nil + }, + msgStreamFactory: factory, + repackFunc: nil, + } + + firstStream := streamInfos{stream: newMockMsgStream()} + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + defer wg.Done() + stream, err := m.createMsgStream(100) + assert.NoError(t, err) + assert.NotNil(t, stream) + }() + // make sure create msg stream has run at getchannels + <-readyCh + // mock create stream for same collection in same time. + m.mu.Lock() + m.infos[100] = firstStream + m.mu.Unlock() + + close(stopCh) + wg.Wait() + }) t.Run("failed to get channels", func(t *testing.T) { m := &singleTypeChannelsMgr{ getChannelsFunc: func(collectionID UniqueID) (channelInfos, error) {