mirror of
https://gitee.com/milvus-io/milvus.git
synced 2025-12-07 01:28:27 +08:00
avoid concurrent sub/unsub on same channel (#26454)
Signed-off-by: Wei Liu <wei.liu@zilliz.com>
This commit is contained in:
parent
0bb68cac36
commit
7af0f7d90c
@ -150,6 +150,10 @@ func (sd *shardDelegator) Serviceable() bool {
|
|||||||
return sd.lifetime.GetState() == working
|
return sd.lifetime.GetState() == working
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (sd *shardDelegator) Stopped() bool {
|
||||||
|
return sd.lifetime.GetState() == stopped
|
||||||
|
}
|
||||||
|
|
||||||
// Start sets delegator to working state.
|
// Start sets delegator to working state.
|
||||||
func (sd *shardDelegator) Start() {
|
func (sd *shardDelegator) Start() {
|
||||||
sd.lifetime.SetState(working)
|
sd.lifetime.SetState(working)
|
||||||
|
|||||||
@ -250,6 +250,10 @@ func (sd *shardDelegator) applyDelete(ctx context.Context, nodeID int64, worker
|
|||||||
if ok {
|
if ok {
|
||||||
log.Debug("delegator plan to applyDelete via worker")
|
log.Debug("delegator plan to applyDelete via worker")
|
||||||
err := retry.Do(ctx, func() error {
|
err := retry.Do(ctx, func() error {
|
||||||
|
if sd.Stopped() {
|
||||||
|
return retry.Unrecoverable(merr.WrapErrChannelUnsubscribing(sd.vchannelName))
|
||||||
|
}
|
||||||
|
|
||||||
err := worker.Delete(ctx, &querypb.DeleteRequest{
|
err := worker.Delete(ctx, &querypb.DeleteRequest{
|
||||||
Base: commonpbutil.NewMsgBase(commonpbutil.WithTargetID(nodeID)),
|
Base: commonpbutil.NewMsgBase(commonpbutil.WithTargetID(nodeID)),
|
||||||
CollectionId: sd.collectionID,
|
CollectionId: sd.collectionID,
|
||||||
|
|||||||
@ -930,4 +930,8 @@ func TestDelegatorTSafeListenerClosed(t *testing.T) {
|
|||||||
case <-time.After(time.Second):
|
case <-time.After(time.Second):
|
||||||
assert.FailNow(t, "watchTsafe still working after listener closed")
|
assert.FailNow(t, "watchTsafe still working after listener closed")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
sd.Close()
|
||||||
|
assert.Equal(t, sd.Serviceable(), false)
|
||||||
|
assert.Equal(t, sd.Stopped(), true)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -100,6 +100,7 @@ type QueryNode struct {
|
|||||||
tSafeManager tsafe.Manager
|
tSafeManager tsafe.Manager
|
||||||
pipelineManager pipeline.Manager
|
pipelineManager pipeline.Manager
|
||||||
subscribingChannels *typeutil.ConcurrentSet[string]
|
subscribingChannels *typeutil.ConcurrentSet[string]
|
||||||
|
unsubscribingChannels *typeutil.ConcurrentSet[string]
|
||||||
delegators *typeutil.ConcurrentMap[string, delegator.ShardDelegator]
|
delegators *typeutil.ConcurrentMap[string, delegator.ShardDelegator]
|
||||||
|
|
||||||
// segment loader
|
// segment loader
|
||||||
@ -324,6 +325,7 @@ func (node *QueryNode) Init() error {
|
|||||||
})
|
})
|
||||||
node.delegators = typeutil.NewConcurrentMap[string, delegator.ShardDelegator]()
|
node.delegators = typeutil.NewConcurrentMap[string, delegator.ShardDelegator]()
|
||||||
node.subscribingChannels = typeutil.NewConcurrentSet[string]()
|
node.subscribingChannels = typeutil.NewConcurrentSet[string]()
|
||||||
|
node.unsubscribingChannels = typeutil.NewConcurrentSet[string]()
|
||||||
node.manager = segments.NewManager()
|
node.manager = segments.NewManager()
|
||||||
node.loader = segments.NewLoader(node.manager, node.vectorStorage)
|
node.loader = segments.NewLoader(node.manager, node.vectorStorage)
|
||||||
node.dispClient = msgdispatcher.NewClient(node.factory, typeutil.QueryNodeRole, paramtable.GetNodeID())
|
node.dispClient = msgdispatcher.NewClient(node.factory, typeutil.QueryNodeRole, paramtable.GetNodeID())
|
||||||
|
|||||||
@ -253,6 +253,13 @@ func (node *QueryNode) WatchDmChannels(ctx context.Context, req *querypb.WatchDm
|
|||||||
}
|
}
|
||||||
defer node.subscribingChannels.Remove(channel.GetChannelName())
|
defer node.subscribingChannels.Remove(channel.GetChannelName())
|
||||||
|
|
||||||
|
// to avoid concurrent watch/unwatch
|
||||||
|
if node.unsubscribingChannels.Contain(channel.GetChannelName()) {
|
||||||
|
err := merr.WrapErrChannelUnsubscribing(channel.GetChannelName())
|
||||||
|
log.Warn(err.Error())
|
||||||
|
return merr.Status(err), nil
|
||||||
|
}
|
||||||
|
|
||||||
_, exist := node.delegators.Get(channel.GetChannelName())
|
_, exist := node.delegators.Get(channel.GetChannelName())
|
||||||
if exist {
|
if exist {
|
||||||
log.Info("channel already subscribed")
|
log.Info("channel already subscribed")
|
||||||
@ -375,6 +382,8 @@ func (node *QueryNode) UnsubDmChannel(ctx context.Context, req *querypb.UnsubDmC
|
|||||||
return status, nil
|
return status, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
node.unsubscribingChannels.Insert(req.GetChannelName())
|
||||||
|
defer node.unsubscribingChannels.Remove(req.GetChannelName())
|
||||||
delegator, ok := node.delegators.GetAndRemove(req.GetChannelName())
|
delegator, ok := node.delegators.GetAndRemove(req.GetChannelName())
|
||||||
if ok {
|
if ok {
|
||||||
// close the delegator first to block all coming query/search requests
|
// close the delegator first to block all coming query/search requests
|
||||||
@ -386,7 +395,6 @@ func (node *QueryNode) UnsubDmChannel(ctx context.Context, req *querypb.UnsubDmC
|
|||||||
|
|
||||||
node.manager.Collection.Unref(req.GetCollectionID(), 1)
|
node.manager.Collection.Unref(req.GetCollectionID(), 1)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Info("unsubscribed channel")
|
log.Info("unsubscribed channel")
|
||||||
|
|
||||||
return util.SuccessStatus(), nil
|
return util.SuccessStatus(), nil
|
||||||
|
|||||||
@ -373,13 +373,20 @@ func (suite *ServiceSuite) TestWatchDmChannels_Failed() {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// test channel is unsubscribing
|
||||||
|
suite.node.unsubscribingChannels.Insert(suite.vchannel)
|
||||||
|
status, err := suite.node.WatchDmChannels(ctx, req)
|
||||||
|
suite.NoError(err)
|
||||||
|
suite.Equal(status.GetReason(), merr.WrapErrChannelUnsubscribing(suite.vchannel).Error())
|
||||||
|
suite.node.unsubscribingChannels.Remove(suite.vchannel)
|
||||||
|
|
||||||
// init msgstream failed
|
// init msgstream failed
|
||||||
suite.factory.EXPECT().NewTtMsgStream(mock.Anything).Return(suite.msgStream, nil)
|
suite.factory.EXPECT().NewTtMsgStream(mock.Anything).Return(suite.msgStream, nil)
|
||||||
suite.msgStream.EXPECT().AsConsumer([]string{suite.pchannel}, mock.Anything, mock.Anything).Return()
|
suite.msgStream.EXPECT().AsConsumer([]string{suite.pchannel}, mock.Anything, mock.Anything).Return()
|
||||||
suite.msgStream.EXPECT().Close().Return()
|
suite.msgStream.EXPECT().Close().Return()
|
||||||
suite.msgStream.EXPECT().Seek(mock.Anything).Return(errors.New("mock error"))
|
suite.msgStream.EXPECT().Seek(mock.Anything).Return(errors.New("mock error"))
|
||||||
|
|
||||||
status, err := suite.node.WatchDmChannels(ctx, req)
|
status, err = suite.node.WatchDmChannels(ctx, req)
|
||||||
suite.NoError(err)
|
suite.NoError(err)
|
||||||
suite.Equal(commonpb.ErrorCode_UnexpectedError, status.GetErrorCode())
|
suite.Equal(commonpb.ErrorCode_UnexpectedError, status.GetErrorCode())
|
||||||
|
|
||||||
|
|||||||
@ -80,6 +80,7 @@ var (
|
|||||||
ErrChannelLack = newMilvusError("channel lacks", 501, false)
|
ErrChannelLack = newMilvusError("channel lacks", 501, false)
|
||||||
ErrChannelReduplicate = newMilvusError("channel reduplicates", 502, false)
|
ErrChannelReduplicate = newMilvusError("channel reduplicates", 502, false)
|
||||||
ErrChannelNotAvailable = newMilvusError("channel not available", 503, false)
|
ErrChannelNotAvailable = newMilvusError("channel not available", 503, false)
|
||||||
|
ErrChannelUnsubscribing = newMilvusError("chanel is unsubscribing", 504, true)
|
||||||
|
|
||||||
// Segment related
|
// Segment related
|
||||||
ErrSegmentNotFound = newMilvusError("segment not found", 600, false)
|
ErrSegmentNotFound = newMilvusError("segment not found", 600, false)
|
||||||
|
|||||||
@ -100,6 +100,7 @@ func (s *ErrSuite) TestWrap() {
|
|||||||
s.ErrorIs(WrapErrChannelNotFound("test_Channel", "failed to get Channel"), ErrChannelNotFound)
|
s.ErrorIs(WrapErrChannelNotFound("test_Channel", "failed to get Channel"), ErrChannelNotFound)
|
||||||
s.ErrorIs(WrapErrChannelLack("test_Channel", "failed to get Channel"), ErrChannelLack)
|
s.ErrorIs(WrapErrChannelLack("test_Channel", "failed to get Channel"), ErrChannelLack)
|
||||||
s.ErrorIs(WrapErrChannelReduplicate("test_Channel", "failed to get Channel"), ErrChannelReduplicate)
|
s.ErrorIs(WrapErrChannelReduplicate("test_Channel", "failed to get Channel"), ErrChannelReduplicate)
|
||||||
|
s.ErrorIs(WrapErrChannelUnsubscribing("test_channel"), ErrChannelUnsubscribing)
|
||||||
|
|
||||||
// Segment related
|
// Segment related
|
||||||
s.ErrorIs(WrapErrSegmentNotFound(1, "failed to get Segment"), ErrSegmentNotFound)
|
s.ErrorIs(WrapErrSegmentNotFound(1, "failed to get Segment"), ErrSegmentNotFound)
|
||||||
|
|||||||
@ -345,6 +345,14 @@ func WrapErrChannelNotAvailable(name string, msg ...string) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func WrapErrChannelUnsubscribing(name string, msg ...string) error {
|
||||||
|
err := wrapWithField(ErrChannelUnsubscribing, "channel", name)
|
||||||
|
if len(msg) > 0 {
|
||||||
|
err = errors.Wrap(err, strings.Join(msg, "; "))
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
// Segment related
|
// Segment related
|
||||||
func WrapErrSegmentNotFound(id int64, msg ...string) error {
|
func WrapErrSegmentNotFound(id int64, msg ...string) error {
|
||||||
err := wrapWithField(ErrSegmentNotFound, "segment", id)
|
err := wrapWithField(ErrSegmentNotFound, "segment", id)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user