From 2880a75dfc5caa2256647b3639da37236bb52e81 Mon Sep 17 00:00:00 2001
From: congqixia
Date: Tue, 15 Nov 2022 20:03:06 +0800
Subject: [PATCH] Fix querynode panics when watch/unsub runs concurrently
 (#20606) (#20619)

Signed-off-by: Congqi Xia
Signed-off-by: Congqi Xia
---
 internal/querynode/impl.go                   | 24 +++++++-----
 internal/querynode/impl_test.go              | 38 +++++++++++++++++++
 internal/querynode/watch_dm_channels_task.go | 40 ++++++++++++++++----
 3 files changed, 85 insertions(+), 17 deletions(-)

diff --git a/internal/querynode/impl.go b/internal/querynode/impl.go
index 566510101e..eda02c4fb0 100644
--- a/internal/querynode/impl.go
+++ b/internal/querynode/impl.go
@@ -26,6 +26,7 @@ import (
 	"time"
 
 	"github.com/golang/protobuf/proto"
+	"github.com/samber/lo"
 	"go.uber.org/zap"
 	"golang.org/x/sync/errgroup"
 
@@ -34,6 +35,7 @@ import (
 	"github.com/milvus-io/milvus/internal/common"
 	"github.com/milvus-io/milvus/internal/log"
 	"github.com/milvus-io/milvus/internal/metrics"
+	"github.com/milvus-io/milvus/internal/proto/datapb"
 	"github.com/milvus-io/milvus/internal/proto/internalpb"
 	"github.com/milvus-io/milvus/internal/proto/querypb"
 	"github.com/milvus-io/milvus/internal/util/metricsinfo"
@@ -301,6 +303,14 @@ func (node *QueryNode) WatchDmChannels(ctx context.Context, in *querypb.WatchDmC
 		return status, nil
 	}
 
+	log := log.With(
+		zap.Int64("collectionID", in.GetCollectionID()),
+		zap.Int64("nodeID", node.session.ServerID),
+		zap.Strings("channels", lo.Map(in.GetInfos(), func(info *datapb.VchannelInfo, _ int) string {
+			return info.GetChannelName()
+		})),
+	)
+
 	task := &watchDmChannelsTask{
 		baseTask: baseTask{
 			ctx:  ctx,
@@ -311,13 +321,10 @@ func (node *QueryNode) WatchDmChannels(ctx context.Context, in *querypb.WatchDmC
 	}
 
 	startTs := time.Now()
-	log.Info("watchDmChannels init", zap.Int64("collectionID", in.CollectionID),
-		zap.String("channelName", in.Infos[0].GetChannelName()),
-		zap.Int64("nodeID", Params.QueryNodeCfg.GetNodeID()))
+	log.Info("watchDmChannels init")
 	// currently we only support load one channel as a time
 	future := node.taskPool.Submit(func() (interface{}, error) {
-		log.Info("watchDmChannels start ", zap.Int64("collectionID", in.CollectionID),
-			zap.String("channelName", in.Infos[0].GetChannelName()),
+		log.Info("watchDmChannels start ",
 			zap.Duration("timeInQueue", time.Since(startTs)))
 		err := task.PreExecute(ctx)
 		if err != nil {
@@ -335,7 +342,7 @@ func (node *QueryNode) WatchDmChannels(ctx context.Context, in *querypb.WatchDmC
 				ErrorCode: commonpb.ErrorCode_UnexpectedError,
 				Reason:    err.Error(),
 			}
-			log.Warn("failed to subscribe channel ", zap.Error(err))
+			log.Warn("failed to subscribe channel", zap.Error(err))
 			return status, nil
 		}
 
@@ -349,10 +356,7 @@ func (node *QueryNode) WatchDmChannels(ctx context.Context, in *querypb.WatchDmC
 		return status, nil
 	}
 
-	sc, _ := node.ShardClusterService.getShardCluster(in.Infos[0].GetChannelName())
-	sc.SetupFirstVersion()
-	log.Info("successfully watchDmChannelsTask", zap.Int64("collectionID", in.CollectionID),
-		zap.String("channelName", in.Infos[0].GetChannelName()), zap.Int64("nodeID", Params.QueryNodeCfg.GetNodeID()))
+	log.Info("successfully watchDmChannelsTask")
 	return &commonpb.Status{
 		ErrorCode: commonpb.ErrorCode_Success,
 	}, nil
diff --git a/internal/querynode/impl_test.go b/internal/querynode/impl_test.go
index fc68bae2b2..8f5b7cae17 100644
--- a/internal/querynode/impl_test.go
+++ b/internal/querynode/impl_test.go
@@ -25,6 +25,7 @@ import (
 	"testing"
 
 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require" "github.com/milvus-io/milvus-proto/go-api/commonpb" @@ -137,10 +138,47 @@ func TestImpl_WatchDmChannels(t *testing.T) { }, } node.UpdateStateCode(commonpb.StateCode_Abnormal) + defer node.UpdateStateCode(commonpb.StateCode_Healthy) status, err := node.WatchDmChannels(ctx, req) assert.NoError(t, err) assert.Equal(t, commonpb.ErrorCode_UnexpectedError, status.ErrorCode) }) + + t.Run("mock release after loaded", func(t *testing.T) { + + mockTSReplica := &MockTSafeReplicaInterface{} + + oldTSReplica := node.tSafeReplica + defer func() { + node.tSafeReplica = oldTSReplica + }() + node.tSafeReplica = mockTSReplica + mockTSReplica.On("addTSafe", mock.Anything).Run(func(_ mock.Arguments) { + node.ShardClusterService.releaseShardCluster("1001-dmc0") + }) + schema := genTestCollectionSchema() + req := &queryPb.WatchDmChannelsRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_WatchDmChannels, + MsgID: rand.Int63(), + TargetID: node.session.ServerID, + }, + CollectionID: defaultCollectionID, + PartitionIDs: []UniqueID{defaultPartitionID}, + Schema: schema, + Infos: []*datapb.VchannelInfo{ + { + CollectionID: 1001, + ChannelName: "1001-dmc0", + }, + }, + } + + status, err := node.WatchDmChannels(ctx, req) + assert.NoError(t, err) + assert.Equal(t, commonpb.ErrorCode_UnexpectedError, status.ErrorCode) + }) + } func TestImpl_UnsubDmChannel(t *testing.T) { diff --git a/internal/querynode/watch_dm_channels_task.go b/internal/querynode/watch_dm_channels_task.go index 48b4239c0e..a9c0c1d33d 100644 --- a/internal/querynode/watch_dm_channels_task.go +++ b/internal/querynode/watch_dm_channels_task.go @@ -64,16 +64,19 @@ func (w *watchDmChannelsTask) Execute(ctx context.Context) (err error) { VPChannels[v] = p } + log := log.With( + zap.Int64("collectionID", w.req.GetCollectionID()), + zap.Strings("vChannels", vChannels), + zap.Int64("replicaID", w.req.GetReplicaID()), + ) + if len(VPChannels) != len(vChannels) { return errors.New("get physical channels failed, illegal channel length, collectionID = " + fmt.Sprintln(collectionID)) } log.Info("Starting WatchDmChannels ...", - zap.String("collectionName", w.req.Schema.Name), - zap.Int64("collectionID", collectionID), - zap.Int64("replicaID", w.req.GetReplicaID()), - zap.String("load type", lType.String()), - zap.Strings("vChannels", vChannels), + zap.String("loadType", lType.String()), + zap.String("collectionName", w.req.GetSchema().GetName()), ) // init collection meta @@ -125,7 +128,7 @@ func (w *watchDmChannelsTask) Execute(ctx context.Context) (err error) { coll.setLoadType(lType) - log.Info("watchDMChannel, init replica done", zap.Int64("collectionID", collectionID), zap.Strings("vChannels", vChannels)) + log.Info("watchDMChannel, init replica done") // create tSafe for _, channel := range vChannels { @@ -142,7 +145,30 @@ func (w *watchDmChannelsTask) Execute(ctx context.Context) (err error) { fg.flowGraph.Start() } - log.Info("WatchDmChannels done", zap.Int64("collectionID", collectionID), zap.Strings("vChannels", vChannels)) + log.Info("WatchDmChannels done") + return nil +} + +// PostExecute setup ShardCluster first version and without do gc if failed. 
+func (w *watchDmChannelsTask) PostExecute(ctx context.Context) error {
+	// setup shard cluster version
+	var releasedChannels []string
+	for _, info := range w.req.GetInfos() {
+		sc, ok := w.node.ShardClusterService.getShardCluster(info.GetChannelName())
+		// shard cluster may be released by a release task
+		if !ok {
+			releasedChannels = append(releasedChannels, info.GetChannelName())
+			continue
+		}
+		sc.SetupFirstVersion()
+	}
+	if len(releasedChannels) > 0 {
+		// no clean up needed, release shall do the job
+		log.Warn("WatchDmChannels failed, shard cluster may be released",
+			zap.Strings("releasedChannels", releasedChannels),
+		)
+		return fmt.Errorf("failed to watch %v, shard cluster may be released", releasedChannels)
+	}
 	return nil
 }
 
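
Reviewer note (not part of the patch): the new test "mock release after loaded" drives the race this commit guards against, where a concurrent release removes the shard cluster before the watch task finishes, so the old "sc, _ := getShardCluster(...); sc.SetupFirstVersion()" dereferenced a nil pointer and panicked. The standalone Go sketch below only illustrates that comma-ok guard; the toy* names are hypothetical stand-ins for the querynode types, not the real implementation.

// illustration_only.go: a toy model of the watch/release race; not part of the patch.
package main

import (
	"fmt"
	"sync"
)

type toyShardCluster struct{ channel string }

func (sc *toyShardCluster) SetupFirstVersion() {
	fmt.Println("setup first version for", sc.channel)
}

type toyShardClusterService struct {
	mu       sync.Mutex
	clusters map[string]*toyShardCluster
}

// getShardCluster mirrors the comma-ok contract the fix relies on:
// once the channel has been released, the second return value is false.
func (s *toyShardClusterService) getShardCluster(channel string) (*toyShardCluster, bool) {
	s.mu.Lock()
	defer s.mu.Unlock()
	sc, ok := s.clusters[channel]
	return sc, ok
}

func (s *toyShardClusterService) releaseShardCluster(channel string) {
	s.mu.Lock()
	defer s.mu.Unlock()
	delete(s.clusters, channel)
}

func main() {
	svc := &toyShardClusterService{clusters: map[string]*toyShardCluster{
		"1001-dmc0": {channel: "1001-dmc0"},
	}}

	// A concurrent unsubscribe/release removes the cluster while the watch
	// task is still completing, which is what the new unit test simulates
	// through the mocked addTSafe callback.
	svc.releaseShardCluster("1001-dmc0")

	// Old behavior: ignoring the ok flag and calling SetupFirstVersion on a
	// nil pointer panics. New behavior: check ok and fail gracefully.
	if sc, ok := svc.getShardCluster("1001-dmc0"); ok {
		sc.SetupFirstVersion()
	} else {
		fmt.Println("shard cluster already released; report an error instead of panicking")
	}
}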