From 7af0f7d90cee4a15eac1947bfac1f23c0cb0295e Mon Sep 17 00:00:00 2001 From: wei liu Date: Wed, 23 Aug 2023 10:10:22 +0800 Subject: [PATCH] avoid concurrent sub/unsub on same channel (#26454) Signed-off-by: Wei Liu --- internal/querynodev2/delegator/delegator.go | 4 ++++ internal/querynodev2/delegator/delegator_data.go | 4 ++++ internal/querynodev2/delegator/delegator_test.go | 4 ++++ internal/querynodev2/server.go | 14 ++++++++------ internal/querynodev2/services.go | 10 +++++++++- internal/querynodev2/services_test.go | 9 ++++++++- pkg/util/merr/errors.go | 9 +++++---- pkg/util/merr/errors_test.go | 1 + pkg/util/merr/utils.go | 8 ++++++++ 9 files changed, 51 insertions(+), 12 deletions(-) diff --git a/internal/querynodev2/delegator/delegator.go b/internal/querynodev2/delegator/delegator.go index 8d68309c6a..dc7cd1d6aa 100644 --- a/internal/querynodev2/delegator/delegator.go +++ b/internal/querynodev2/delegator/delegator.go @@ -150,6 +150,10 @@ func (sd *shardDelegator) Serviceable() bool { return sd.lifetime.GetState() == working } +func (sd *shardDelegator) Stopped() bool { + return sd.lifetime.GetState() == stopped +} + // Start sets delegator to working state. 
func (sd *shardDelegator) Start() { sd.lifetime.SetState(working) diff --git a/internal/querynodev2/delegator/delegator_data.go b/internal/querynodev2/delegator/delegator_data.go index a439017e4b..28282c8b40 100644 --- a/internal/querynodev2/delegator/delegator_data.go +++ b/internal/querynodev2/delegator/delegator_data.go @@ -250,6 +250,10 @@ func (sd *shardDelegator) applyDelete(ctx context.Context, nodeID int64, worker if ok { log.Debug("delegator plan to applyDelete via worker") err := retry.Do(ctx, func() error { + if sd.Stopped() { + return retry.Unrecoverable(merr.WrapErrChannelUnsubscribing(sd.vchannelName)) + } + err := worker.Delete(ctx, &querypb.DeleteRequest{ Base: commonpbutil.NewMsgBase(commonpbutil.WithTargetID(nodeID)), CollectionId: sd.collectionID, diff --git a/internal/querynodev2/delegator/delegator_test.go b/internal/querynodev2/delegator/delegator_test.go index b69b3aab96..40af0b3314 100644 --- a/internal/querynodev2/delegator/delegator_test.go +++ b/internal/querynodev2/delegator/delegator_test.go @@ -930,4 +930,8 @@ func TestDelegatorTSafeListenerClosed(t *testing.T) { case <-time.After(time.Second): assert.FailNow(t, "watchTsafe still working after listener closed") } + + sd.Close() + assert.False(t, sd.Serviceable()) + assert.True(t, sd.Stopped()) } diff --git a/internal/querynodev2/server.go b/internal/querynodev2/server.go index 3acf272bc6..ffca576277 100644 --- a/internal/querynodev2/server.go +++ b/internal/querynodev2/server.go @@ -95,12 +95,13 @@ type QueryNode struct { stopOnce sync.Once // internal components - manager *segments.Manager - clusterManager cluster.Manager - tSafeManager tsafe.Manager - pipelineManager pipeline.Manager - subscribingChannels *typeutil.ConcurrentSet[string] - delegators *typeutil.ConcurrentMap[string, delegator.ShardDelegator] + manager *segments.Manager + clusterManager cluster.Manager + tSafeManager tsafe.Manager + pipelineManager pipeline.Manager + subscribingChannels 
*typeutil.ConcurrentSet[string] + unsubscribingChannels *typeutil.ConcurrentSet[string] + delegators *typeutil.ConcurrentMap[string, delegator.ShardDelegator] // segment loader loader segments.Loader @@ -324,6 +325,7 @@ func (node *QueryNode) Init() error { }) node.delegators = typeutil.NewConcurrentMap[string, delegator.ShardDelegator]() node.subscribingChannels = typeutil.NewConcurrentSet[string]() + node.unsubscribingChannels = typeutil.NewConcurrentSet[string]() node.manager = segments.NewManager() node.loader = segments.NewLoader(node.manager, node.vectorStorage) node.dispClient = msgdispatcher.NewClient(node.factory, typeutil.QueryNodeRole, paramtable.GetNodeID()) diff --git a/internal/querynodev2/services.go b/internal/querynodev2/services.go index bdf2b9eeb4..6abad18375 100644 --- a/internal/querynodev2/services.go +++ b/internal/querynodev2/services.go @@ -253,6 +253,13 @@ func (node *QueryNode) WatchDmChannels(ctx context.Context, req *querypb.WatchDm } defer node.subscribingChannels.Remove(channel.GetChannelName()) + // to avoid concurrent watch/unwatch + if node.unsubscribingChannels.Contain(channel.GetChannelName()) { + err := merr.WrapErrChannelUnsubscribing(channel.GetChannelName()) + log.Warn(err.Error()) + return merr.Status(err), nil + } + _, exist := node.delegators.Get(channel.GetChannelName()) if exist { log.Info("channel already subscribed") @@ -375,6 +382,8 @@ func (node *QueryNode) UnsubDmChannel(ctx context.Context, req *querypb.UnsubDmC return status, nil } + node.unsubscribingChannels.Insert(req.GetChannelName()) + defer node.unsubscribingChannels.Remove(req.GetChannelName()) delegator, ok := node.delegators.GetAndRemove(req.GetChannelName()) if ok { // close the delegator first to block all coming query/search requests @@ -386,7 +395,6 @@ func (node *QueryNode) UnsubDmChannel(ctx context.Context, req *querypb.UnsubDmC node.manager.Collection.Unref(req.GetCollectionID(), 1) } - log.Info("unsubscribed channel") return util.SuccessStatus(), 
nil diff --git a/internal/querynodev2/services_test.go b/internal/querynodev2/services_test.go index 4dcb92593a..5763cc6aba 100644 --- a/internal/querynodev2/services_test.go +++ b/internal/querynodev2/services_test.go @@ -373,13 +373,20 @@ func (suite *ServiceSuite) TestWatchDmChannels_Failed() { }, } + // test channel is unsubscribing + suite.node.unsubscribingChannels.Insert(suite.vchannel) + status, err := suite.node.WatchDmChannels(ctx, req) + suite.NoError(err) + suite.Equal(status.GetReason(), merr.WrapErrChannelUnsubscribing(suite.vchannel).Error()) + suite.node.unsubscribingChannels.Remove(suite.vchannel) + // init msgstream failed suite.factory.EXPECT().NewTtMsgStream(mock.Anything).Return(suite.msgStream, nil) suite.msgStream.EXPECT().AsConsumer([]string{suite.pchannel}, mock.Anything, mock.Anything).Return() suite.msgStream.EXPECT().Close().Return() suite.msgStream.EXPECT().Seek(mock.Anything).Return(errors.New("mock error")) - status, err := suite.node.WatchDmChannels(ctx, req) + status, err = suite.node.WatchDmChannels(ctx, req) suite.NoError(err) suite.Equal(commonpb.ErrorCode_UnexpectedError, status.GetErrorCode()) diff --git a/pkg/util/merr/errors.go b/pkg/util/merr/errors.go index c0f15825fa..3d9bfbb66d 100644 --- a/pkg/util/merr/errors.go +++ b/pkg/util/merr/errors.go @@ -76,10 +76,11 @@ var ( ErrReplicaNotAvailable = newMilvusError("replica not available", 401, false) // Channel related - ErrChannelNotFound = newMilvusError("channel not found", 500, false) - ErrChannelLack = newMilvusError("channel lacks", 501, false) - ErrChannelReduplicate = newMilvusError("channel reduplicates", 502, false) - ErrChannelNotAvailable = newMilvusError("channel not available", 503, false) + ErrChannelNotFound = newMilvusError("channel not found", 500, false) + ErrChannelLack = newMilvusError("channel lacks", 501, false) + ErrChannelReduplicate = newMilvusError("channel reduplicates", 502, false) + ErrChannelNotAvailable = newMilvusError("channel not available", 
503, false) + ErrChannelUnsubscribing = newMilvusError("channel is unsubscribing", 504, true) // Segment related ErrSegmentNotFound = newMilvusError("segment not found", 600, false) diff --git a/pkg/util/merr/errors_test.go b/pkg/util/merr/errors_test.go index e4befaaed7..f10354045c 100644 --- a/pkg/util/merr/errors_test.go +++ b/pkg/util/merr/errors_test.go @@ -100,6 +100,7 @@ func (s *ErrSuite) TestWrap() { s.ErrorIs(WrapErrChannelNotFound("test_Channel", "failed to get Channel"), ErrChannelNotFound) s.ErrorIs(WrapErrChannelLack("test_Channel", "failed to get Channel"), ErrChannelLack) s.ErrorIs(WrapErrChannelReduplicate("test_Channel", "failed to get Channel"), ErrChannelReduplicate) + s.ErrorIs(WrapErrChannelUnsubscribing("test_channel"), ErrChannelUnsubscribing) // Segment related s.ErrorIs(WrapErrSegmentNotFound(1, "failed to get Segment"), ErrSegmentNotFound) diff --git a/pkg/util/merr/utils.go b/pkg/util/merr/utils.go index 2168dbf852..748a157f49 100644 --- a/pkg/util/merr/utils.go +++ b/pkg/util/merr/utils.go @@ -345,6 +345,14 @@ func WrapErrChannelNotAvailable(name string, msg ...string) error { return err } +func WrapErrChannelUnsubscribing(name string, msg ...string) error { + err := wrapWithField(ErrChannelUnsubscribing, "channel", name) + if len(msg) > 0 { + err = errors.Wrap(err, strings.Join(msg, "; ")) + } + return err +} + // Segment related func WrapErrSegmentNotFound(id int64, msg ...string) error { err := wrapWithField(ErrSegmentNotFound, "segment", id)