diff --git a/internal/querynode/impl.go b/internal/querynode/impl.go index 931a490b6e..b554958b8d 100644 --- a/internal/querynode/impl.go +++ b/internal/querynode/impl.go @@ -407,20 +407,17 @@ func (node *QueryNode) UnsubDmChannel(ctx context.Context, req *querypb.UnsubDmC return status, nil } - dct := &releaseCollectionTask{ + unsubTask := &unsubDmChannelTask{ baseTask: baseTask{ ctx: ctx, done: make(chan error), }, - req: &querypb.ReleaseCollectionRequest{ - Base: req.GetBase(), - CollectionID: req.GetCollectionID(), - NodeID: req.GetNodeID(), - }, - node: node, + node: node, + collectionID: req.GetCollectionID(), + channel: req.GetChannelName(), } - err := node.scheduler.queue.Enqueue(dct) + err := node.scheduler.queue.Enqueue(unsubTask) if err != nil { status := &commonpb.Status{ ErrorCode: commonpb.ErrorCode_UnexpectedError, @@ -429,21 +426,21 @@ func (node *QueryNode) UnsubDmChannel(ctx context.Context, req *querypb.UnsubDmC log.Warn("failed to enqueue subscribe channel task", zap.Error(err)) return status, nil } - log.Info("unsubDmChannel(ReleaseCollection) enqueue done", zap.Int64("collectionID", req.GetCollectionID())) + log.Info("unsubDmChannelTask enqueue done", zap.Int64("collectionID", req.GetCollectionID())) - func() { - err = dct.WaitToFinish() - if err != nil { - log.Warn("failed to do subscribe channel task successfully", zap.Error(err)) - return - } - log.Info("unsubDmChannel(ReleaseCollection) WaitToFinish done", zap.Int64("collectionID", req.GetCollectionID())) - }() - - status := &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, + err = unsubTask.WaitToFinish() + if err != nil { + log.Warn("failed to do subscribe channel task successfully", zap.Error(err)) + return &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + Reason: err.Error(), + }, nil } - return status, nil + + log.Info("unsubDmChannelTask WaitToFinish done", zap.Int64("collectionID", req.GetCollectionID())) + return &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_Success, + }, nil } // LoadSegments load historical data into query node, historical data can be vector data or index diff --git a/internal/querynode/impl_test.go b/internal/querynode/impl_test.go index 69f20ac89e..4fd199f2d0 100644 --- a/internal/querynode/impl_test.go +++ b/internal/querynode/impl_test.go @@ -212,6 +212,54 @@ func TestImpl_UnsubDmChannel(t *testing.T) { node, err := genSimpleQueryNode(ctx) assert.NoError(t, err) + t.Run("normal run", func(t *testing.T) { + schema := genTestCollectionSchema() + req := &queryPb.WatchDmChannelsRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_WatchDmChannels, + MsgID: rand.Int63(), + TargetID: node.session.ServerID, + }, + NodeID: 0, + CollectionID: defaultCollectionID, + PartitionIDs: []UniqueID{defaultPartitionID}, + Schema: schema, + Infos: []*datapb.VchannelInfo{ + { + CollectionID: 1000, + ChannelName: Params.CommonCfg.RootCoordDml + "-dmc0", + }, + }, + } + + status, err := node.WatchDmChannels(ctx, req) + assert.NoError(t, err) + assert.Equal(t, commonpb.ErrorCode_Success, status.ErrorCode) + + { + req := &queryPb.UnsubDmChannelRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_UnsubDmChannel, + MsgID: rand.Int63(), + TargetID: node.session.ServerID, + }, + NodeID: 0, + CollectionID: defaultCollectionID, + ChannelName: Params.CommonCfg.RootCoordDml + "-dmc0", + } + originMetaReplica := node.metaReplica + node.metaReplica = newMockReplicaInterface() + status, err := node.UnsubDmChannel(ctx, req) + assert.NoError(t, err) + assert.Equal(t, commonpb.ErrorCode_UnexpectedError, status.ErrorCode) + + node.metaReplica = originMetaReplica + status, err = node.UnsubDmChannel(ctx, req) + assert.NoError(t, err) + assert.Equal(t, commonpb.ErrorCode_Success, status.ErrorCode) + } + }) + t.Run("target not match", func(t *testing.T) { req := &queryPb.UnsubDmChannelRequest{ Base: &commonpb.MsgBase{ diff --git a/internal/querynode/query_shard_service.go b/internal/querynode/query_shard_service.go index dd06f388fb..c6d6450db6 100644 --- a/internal/querynode/query_shard_service.go +++ b/internal/querynode/query_shard_service.go @@ -148,3 +148,15 @@ func (q *queryShardService) releaseCollection(collectionID int64) { q.queryShardsMu.Unlock() log.Info("release collection in query shard service", zap.Int64("collectionId", collectionID)) } + +func (q *queryShardService) releaseQueryShard(channel string) { + q.queryShardsMu.Lock() + defer q.queryShardsMu.Unlock() + for ch, queryShard := range q.queryShards { + if ch == channel { + queryShard.Close() + delete(q.queryShards, ch) + break + } + } +} diff --git a/internal/querynode/task.go b/internal/querynode/task.go index 1aad4eba87..afd2e4ab41 100644 --- a/internal/querynode/task.go +++ b/internal/querynode/task.go @@ -28,6 +28,8 @@ import ( "github.com/milvus-io/milvus/internal/util/typeutil" ) +var ErrChannelNotFound = errors.New("channel not found") + type task interface { ID() UniqueID // return ReqID Ctx() context.Context @@ -226,3 +228,71 @@ func (r *releasePartitionsTask) isAllPartitionsReleased(coll *Collection) bool { return parts.Contain(coll.partitionIDs...) } + +type unsubDmChannelTask struct { + baseTask + node *QueryNode + collectionID int64 + channel string +} + +func (t *unsubDmChannelTask) Execute(ctx context.Context) error { + log.Info("start to execute unsubscribe dmchannel task", zap.Int64("collectionID", t.collectionID), zap.String("channel", t.channel)) + collection, err := t.node.metaReplica.getCollectionByID(t.collectionID) + if err != nil { + if errors.Is(err, ErrCollectionNotFound) { + log.Info("collection has been released", + zap.Int64("collectionID", t.collectionID), + zap.Error(err), + ) + return nil + } + return err + } + + channels := collection.getVChannels() + var find bool + for _, c := range channels { + if c == t.channel { + find = true + break + } + } + + if !find { + return ErrChannelNotFound + } + + if err := t.releaseChannelResources(collection); err != nil { + return err + } + debug.FreeOSMemory() + return nil +} + +func (t *unsubDmChannelTask) releaseChannelResources(collection *Collection) error { + log := log.With(zap.Int64("collectionID", t.collectionID), zap.String("channel", t.channel)) + log.Info("start to release channel resources") + + collection.removeVChannel(t.channel) + // release flowgraph resources + t.node.dataSyncService.removeFlowGraphsByDMLChannels([]string{t.channel}) + t.node.queryShardService.releaseQueryShard(t.channel) + t.node.ShardClusterService.releaseShardCluster(t.channel) + + t.node.tSafeReplica.removeTSafe(t.channel) + log.Info("release channel related resources successfully") + + // release segment resources + segmentIDs, err := t.node.metaReplica.getSegmentIDsByVChannel(nil, t.channel, segmentTypeGrowing) + if err != nil { + return err + } + for _, segmentID := range segmentIDs { + t.node.metaReplica.removeSegment(segmentID, segmentTypeGrowing) + } + + t.node.dataSyncService.removeEmptyFlowGraphByChannel(t.collectionID, t.channel) + log.Info("release segment resources successfully") + return nil +}