From 7b91fa3db8dbbfd76866f4d697653ce0e45203df Mon Sep 17 00:00:00 2001 From: congqixia Date: Wed, 21 Feb 2024 11:08:51 +0800 Subject: [PATCH] fix: Make leader checker generate leader task instead of segment task (#30258) See also #30150 For a leader view distribution that references offline nodes, a release task can never be sent to the querynode because of the targetNode online-check logic. Even if the request is dispatched, a normal release task does not set the "force" flag when calling `delegator.ReleaseSegment`. This PR adds a new type of querycoord task, LeaderTask, whose responsibility is to rectify the leader view distribution. --------- Signed-off-by: Congqi Xia --- .../querycoordv2/checkers/leader_checker.go | 27 +- .../checkers/leader_checker_test.go | 14 +- internal/querycoordv2/task/action.go | 56 +++- internal/querycoordv2/task/executor.go | 279 ++++++++++++++---- internal/querycoordv2/task/scheduler.go | 89 ++++++ internal/querycoordv2/task/task.go | 46 ++- internal/querycoordv2/task/task_test.go | 181 ++++++++++++ 7 files changed, 592 insertions(+), 100 deletions(-) diff --git a/internal/querycoordv2/checkers/leader_checker.go b/internal/querycoordv2/checkers/leader_checker.go index eacee155b9..002cd2df49 100644 --- a/internal/querycoordv2/checkers/leader_checker.go +++ b/internal/querycoordv2/checkers/leader_checker.go @@ -22,7 +22,6 @@ import ( "go.uber.org/zap" - "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/querycoordv2/meta" "github.com/milvus-io/milvus/internal/querycoordv2/params" "github.com/milvus-io/milvus/internal/querycoordv2/session" @@ -139,23 +138,16 @@ func (c *LeaderChecker) findNeedLoadedSegments(ctx context.Context, replica int6 log.RatedDebug(10, "leader checker append a segment to set", zap.Int64("segmentID", s.GetID()), zap.Int64("nodeID", s.Node)) - action := task.NewSegmentActionWithScope(s.Node, task.ActionTypeGrow, s.GetInsertChannel(), s.GetID(), querypb.DataScope_Historical) - t, err := task.NewSegmentTask( + action := task.NewLeaderAction(leaderView.ID, s.Node, task.ActionTypeGrow, s.GetInsertChannel(), s.GetID()) + t := task.NewLeaderTask( ctx, params.Params.QueryCoordCfg.SegmentTaskTimeout.GetAsDuration(time.Millisecond), c.ID(), s.GetCollectionID(), replica, + leaderView.ID, action, ) - if err != nil { - log.Warn("create segment update task failed", - zap.Int64("segmentID", s.GetID()), - zap.Int64("node", s.Node), - zap.Error(err), - ) - continue - } // index task shall have lower or equal priority than balance task t.SetPriority(task.TaskPriorityHigh) t.SetReason("add segment to leader view") @@ -189,23 +181,16 @@ func (c *LeaderChecker) findNeedRemovedSegments(ctx context.Context, replica int log.Debug("leader checker append a segment to remove", zap.Int64("segmentID", sid), zap.Int64("nodeID", s.NodeID)) - action := task.NewSegmentActionWithScope(s.NodeID, task.ActionTypeReduce, leaderView.Channel, sid, querypb.DataScope_Historical) - t, err := task.NewSegmentTask( + action := task.NewLeaderAction(leaderView.ID, s.NodeID, task.ActionTypeReduce, leaderView.Channel, sid) + t := task.NewLeaderTask( ctx, paramtable.Get().QueryCoordCfg.SegmentTaskTimeout.GetAsDuration(time.Millisecond), c.ID(), leaderView.CollectionID, replica, + leaderView.ID, action, ) - if err != nil { - log.Warn("create segment reduce task failed", - zap.Int64("segmentID", sid), - zap.Int64("nodeID", s.NodeID), - zap.Error(err)) - continue - } t.SetPriority(task.TaskPriorityHigh) t.SetReason("remove segment from leader view")
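To make the dispatch problem described in the commit message concrete: the old segment-granularity release is addressed to the segment's own worker, so the scheduler's online check drops it once that worker is gone, while the new leader task is addressed to the shard delegator, which is still reachable. A self-contained toy sketch of that difference (hypothetical names, not milvus APIs):

package main

import "fmt"

// cluster is a hypothetical stand-in for querycoord's node-session state;
// it is not a milvus type.
type cluster struct{ online map[int64]bool }

// dispatchSegmentRelease models the old behavior: the release targets the
// segment's own worker, so an offline worker means it is never sent.
func dispatchSegmentRelease(c cluster, segmentNode int64) error {
	if !c.online[segmentNode] {
		return fmt.Errorf("node %d offline, release never sent", segmentNode)
	}
	return nil // the release RPC would target segmentNode here
}

// dispatchLeaderSync models the new LeaderTask path: the RPC targets the
// shard delegator, and the offline worker appears only as request payload.
func dispatchLeaderSync(c cluster, leaderNode int64) error {
	if !c.online[leaderNode] {
		return fmt.Errorf("delegator %d offline", leaderNode)
	}
	return nil // the SyncDistribution RPC would target leaderNode here
}

func main() {
	c := cluster{online: map[int64]bool{1: true, 2: false}} // 1: delegator, 2: worker
	fmt.Println(dispatchSegmentRelease(c, 2))               // error: worker offline
	fmt.Println(dispatchLeaderSync(c, 1))                   // <nil>: delegator reachable
}

diff --git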
a/internal/querycoordv2/checkers/leader_checker_test.go b/internal/querycoordv2/checkers/leader_checker_test.go index d28c0ccafa..fca0bf03ec 100644 --- a/internal/querycoordv2/checkers/leader_checker_test.go +++ b/internal/querycoordv2/checkers/leader_checker_test.go @@ -119,7 +119,7 @@ func (suite *LeaderCheckerTestSuite) TestSyncLoadedSegments() { suite.Len(tasks[0].Actions(), 1) suite.Equal(tasks[0].Actions()[0].Type(), task.ActionTypeGrow) suite.Equal(tasks[0].Actions()[0].Node(), int64(1)) - suite.Equal(tasks[0].Actions()[0].(*task.SegmentAction).SegmentID(), int64(1)) + suite.Equal(tasks[0].Actions()[0].(*task.LeaderAction).SegmentID(), int64(1)) suite.Equal(tasks[0].Priority(), task.TaskPriorityHigh) } @@ -161,7 +161,7 @@ func (suite *LeaderCheckerTestSuite) TestActivation() { suite.Len(tasks[0].Actions(), 1) suite.Equal(tasks[0].Actions()[0].Type(), task.ActionTypeGrow) suite.Equal(tasks[0].Actions()[0].Node(), int64(1)) - suite.Equal(tasks[0].Actions()[0].(*task.SegmentAction).SegmentID(), int64(1)) + suite.Equal(tasks[0].Actions()[0].(*task.LeaderAction).SegmentID(), int64(1)) suite.Equal(tasks[0].Priority(), task.TaskPriorityHigh) } @@ -236,7 +236,7 @@ func (suite *LeaderCheckerTestSuite) TestIgnoreSyncLoadedSegments() { suite.Len(tasks[0].Actions(), 1) suite.Equal(tasks[0].Actions()[0].Type(), task.ActionTypeGrow) suite.Equal(tasks[0].Actions()[0].Node(), int64(1)) - suite.Equal(tasks[0].Actions()[0].(*task.SegmentAction).SegmentID(), int64(1)) + suite.Equal(tasks[0].Actions()[0].(*task.LeaderAction).SegmentID(), int64(1)) suite.Equal(tasks[0].Priority(), task.TaskPriorityHigh) } @@ -289,7 +289,7 @@ func (suite *LeaderCheckerTestSuite) TestIgnoreBalancedSegment() { suite.Len(tasks[0].Actions(), 1) suite.Equal(tasks[0].Actions()[0].Type(), task.ActionTypeGrow) suite.Equal(tasks[0].Actions()[0].Node(), int64(1)) - suite.Equal(tasks[0].Actions()[0].(*task.SegmentAction).SegmentID(), int64(1)) + suite.Equal(tasks[0].Actions()[0].(*task.LeaderAction).SegmentID(), int64(1)) suite.Equal(tasks[0].Priority(), task.TaskPriorityHigh) } @@ -334,7 +334,7 @@ func (suite *LeaderCheckerTestSuite) TestSyncLoadedSegmentsWithReplicas() { suite.Len(tasks[0].Actions(), 1) suite.Equal(tasks[0].Actions()[0].Type(), task.ActionTypeGrow) suite.Equal(tasks[0].Actions()[0].Node(), int64(1)) - suite.Equal(tasks[0].Actions()[0].(*task.SegmentAction).SegmentID(), int64(1)) + suite.Equal(tasks[0].Actions()[0].(*task.LeaderAction).SegmentID(), int64(1)) suite.Equal(tasks[0].Priority(), task.TaskPriorityHigh) } @@ -368,7 +368,7 @@ func (suite *LeaderCheckerTestSuite) TestSyncRemovedSegments() { suite.Len(tasks[0].Actions(), 1) suite.Equal(tasks[0].Actions()[0].Type(), task.ActionTypeReduce) suite.Equal(tasks[0].Actions()[0].Node(), int64(1)) - suite.Equal(tasks[0].Actions()[0].(*task.SegmentAction).SegmentID(), int64(3)) + suite.Equal(tasks[0].Actions()[0].(*task.LeaderAction).SegmentID(), int64(3)) suite.Equal(tasks[0].Priority(), task.TaskPriorityHigh) } @@ -405,7 +405,7 @@ func (suite *LeaderCheckerTestSuite) TestIgnoreSyncRemovedSegments() { suite.Len(tasks[0].Actions(), 1) suite.Equal(tasks[0].Actions()[0].Type(), task.ActionTypeReduce) suite.Equal(tasks[0].Actions()[0].Node(), int64(2)) - suite.Equal(tasks[0].Actions()[0].(*task.SegmentAction).SegmentID(), int64(3)) + suite.Equal(tasks[0].Actions()[0].(*task.LeaderAction).SegmentID(), int64(3)) suite.Equal(tasks[0].Priority(), task.TaskPriorityHigh) } diff --git a/internal/querycoordv2/task/action.go b/internal/querycoordv2/task/action.go index 
897bb0237e..2e72fb3a8a 100644 --- a/internal/querycoordv2/task/action.go +++ b/internal/querycoordv2/task/action.go @@ -23,7 +23,7 @@ import ( "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/querycoordv2/meta" "github.com/milvus-io/milvus/pkg/util/funcutil" - . "github.com/milvus-io/milvus/pkg/util/typeutil" + "github.com/milvus-io/milvus/pkg/util/typeutil" ) type ActionType int32 @@ -51,12 +51,12 @@ type Action interface { } type BaseAction struct { - nodeID UniqueID + nodeID typeutil.UniqueID typ ActionType shard string } -func NewBaseAction(nodeID UniqueID, typ ActionType, shard string) *BaseAction { +func NewBaseAction(nodeID typeutil.UniqueID, typ ActionType, shard string) *BaseAction { return &BaseAction{ nodeID: nodeID, typ: typ, @@ -79,17 +79,17 @@ func (action *BaseAction) Shard() string { type SegmentAction struct { *BaseAction - segmentID UniqueID + segmentID typeutil.UniqueID scope querypb.DataScope rpcReturned atomic.Bool } -func NewSegmentAction(nodeID UniqueID, typ ActionType, shard string, segmentID UniqueID) *SegmentAction { +func NewSegmentAction(nodeID typeutil.UniqueID, typ ActionType, shard string, segmentID typeutil.UniqueID) *SegmentAction { return NewSegmentActionWithScope(nodeID, typ, shard, segmentID, querypb.DataScope_All) } -func NewSegmentActionWithScope(nodeID UniqueID, typ ActionType, shard string, segmentID UniqueID, scope querypb.DataScope) *SegmentAction { +func NewSegmentActionWithScope(nodeID typeutil.UniqueID, typ ActionType, shard string, segmentID typeutil.UniqueID, scope querypb.DataScope) *SegmentAction { base := NewBaseAction(nodeID, typ, shard) return &SegmentAction{ BaseAction: base, @@ -99,7 +99,7 @@ func NewSegmentActionWithScope(nodeID UniqueID, typ ActionType, shard string, se } } -func (action *SegmentAction) SegmentID() UniqueID { +func (action *SegmentAction) SegmentID() typeutil.UniqueID { return action.segmentID } @@ -143,7 +143,7 @@ type ChannelAction struct { *BaseAction } -func NewChannelAction(nodeID UniqueID, typ ActionType, channelName string) *ChannelAction { +func NewChannelAction(nodeID typeutil.UniqueID, typ ActionType, channelName string) *ChannelAction { return &ChannelAction{ BaseAction: NewBaseAction(nodeID, typ, channelName), } @@ -160,3 +160,43 @@ func (action *ChannelAction) IsFinished(distMgr *meta.DistributionManager) bool return hasNode == isGrow } + +type LeaderAction struct { + *BaseAction + + leaderID typeutil.UniqueID + segmentID typeutil.UniqueID + + rpcReturned atomic.Bool +} + +func NewLeaderAction(leaderID, workerID typeutil.UniqueID, typ ActionType, shard string, segmentID typeutil.UniqueID) *LeaderAction { + action := &LeaderAction{ + BaseAction: NewBaseAction(workerID, typ, shard), + + leaderID: leaderID, + segmentID: segmentID, + } + action.rpcReturned.Store(false) + return action +} + +func (action *LeaderAction) SegmentID() typeutil.UniqueID { + return action.segmentID +} + +func (action *LeaderAction) IsFinished(distMgr *meta.DistributionManager) bool { + views := distMgr.LeaderViewManager.GetLeaderView(action.leaderID) + view := views[action.Shard()] + if view == nil { + return false + } + dist := view.Segments[action.SegmentID()] + switch action.Type() { + case ActionTypeGrow: + return action.rpcReturned.Load() && dist != nil && dist.NodeID == action.Node() + case ActionTypeReduce: + return action.rpcReturned.Load() && (dist == nil || dist.NodeID != action.Node()) + } + return false +}
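LeaderAction.IsFinished above is what lets the checker converge instead of resending forever: a grow action counts as done only after the RPC returned and the leader view actually maps the segment to the worker, and a reduce action only once that mapping is gone or points at another node. A self-contained toy model of that predicate, with simplified local types standing in for meta.LeaderView and querypb.SegmentDist:

package main

import "fmt"

// segDist and leaderView are simplified stand-ins for querypb.SegmentDist
// and meta.LeaderView; they model only the fields the predicate reads.
type segDist struct{ nodeID int64 }

type leaderView struct{ segments map[int64]*segDist }

// finished mirrors the switch in LeaderAction.IsFinished: grow completes once
// the RPC returned and the view maps the segment to the worker; reduce
// completes once that mapping is gone or belongs to another node.
func finished(rpcReturned, grow bool, view *leaderView, segmentID, workerID int64) bool {
	if view == nil || !rpcReturned {
		return false
	}
	dist := view.segments[segmentID]
	if grow {
		return dist != nil && dist.nodeID == workerID
	}
	return dist == nil || dist.nodeID != workerID
}

func main() {
	v := &leaderView{segments: map[int64]*segDist{3: {nodeID: 2}}}
	fmt.Println(finished(true, true, v, 3, 2))  // true: grow visible in the view
	fmt.Println(finished(true, false, v, 3, 2)) // false: reduce not applied yet
}

diff --git a/internal/querycoordv2/task/executor.go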
b/internal/querycoordv2/task/executor.go index cef2527bf8..993e8120c1 100644 --- a/internal/querycoordv2/task/executor.go +++ b/internal/querycoordv2/task/executor.go @@ -26,6 +26,8 @@ import ( "go.uber.org/atomic" "go.uber.org/zap" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus/internal/proto/indexpb" "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/querycoordv2/meta" @@ -33,6 +35,7 @@ import ( "github.com/milvus-io/milvus/internal/querycoordv2/session" "github.com/milvus-io/milvus/internal/querycoordv2/utils" "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/util/commonpbutil" "github.com/milvus-io/milvus/pkg/util/funcutil" "github.com/milvus-io/milvus/pkg/util/indexparams" "github.com/milvus-io/milvus/pkg/util/merr" @@ -111,6 +114,9 @@ func (ex *Executor) Execute(task Task, step int) bool { case *ChannelAction: ex.executeDmChannelAction(task.(*ChannelTask), step) + + case *LeaderAction: + ex.executeLeaderAction(task.(*LeaderTask), step) } }() @@ -162,70 +168,15 @@ func (ex *Executor) loadSegment(task *SegmentTask, step int) error { ex.removeTask(task, step) }() - collectionInfo, err := ex.broker.DescribeCollection(ctx, task.CollectionID()) + collectionInfo, loadMeta, channel, err := ex.getMetaInfo(ctx, task) if err != nil { - log.Warn("failed to get collection info", zap.Error(err)) - return err - } - partitions, err := utils.GetPartitions(ex.meta.CollectionManager, task.CollectionID()) - if err != nil { - log.Warn("failed to get partitions of collection", zap.Error(err)) return err } - loadMeta := packLoadMeta( - ex.meta.GetLoadType(task.CollectionID()), - task.CollectionID(), - partitions..., - ) - // get channel first, in case of target updated after segment info fetched - channel := ex.targetMgr.GetDmChannel(task.CollectionID(), task.shard, meta.NextTargetFirst) - if channel == nil { - return merr.WrapErrChannelNotAvailable(task.shard) - } - - resp, err := ex.broker.GetSegmentInfo(ctx, task.SegmentID()) - if err != nil || len(resp.GetInfos()) == 0 { - log.Warn("failed to get segment info from DataCoord", zap.Error(err)) + loadInfo, indexInfos, err := ex.getLoadInfo(ctx, task.CollectionID(), action.SegmentID(), channel) + if err != nil { return err } - segment := resp.GetInfos()[0] - log = log.With(zap.String("level", segment.GetLevel().String())) - - indexes, err := ex.broker.GetIndexInfo(ctx, task.CollectionID(), segment.GetID()) - if err != nil { - if !errors.Is(err, merr.ErrIndexNotFound) { - log.Warn("failed to get index of segment", zap.Error(err)) - return err - } - indexes = nil - } - - // Get collection index info - indexInfos, err := ex.broker.DescribeIndex(ctx, task.CollectionID()) - if err != nil { - log.Warn("fail to get index meta of collection") - return err - } - // update the field index params - for _, segmentIndex := range indexes { - index, found := lo.Find(indexInfos, func(indexInfo *indexpb.IndexInfo) bool { - return indexInfo.IndexID == segmentIndex.IndexID - }) - if !found { - log.Warn("no collection index info for the given segment index", zap.String("indexName", segmentIndex.GetIndexName())) - } - - params := funcutil.KeyValuePair2Map(segmentIndex.GetIndexParams()) - for _, kv := range index.GetUserIndexParams() { - if indexparams.IsConfigableIndexParam(kv.GetKey()) { - params[kv.GetKey()] = kv.GetValue() - } - } - segmentIndex.IndexParams = funcutil.Map2KeyValuePair(params) - } - - loadInfo := 
utils.PackSegmentLoadInfo(segment, channel.GetSeekPosition(), indexes) req := packLoadSegmentRequest( task, @@ -238,10 +189,10 @@ func (ex *Executor) loadSegment(task *SegmentTask, step int) error { ) // Get shard leader for the given replica and segment - leaderID, ok := getShardLeader(ex.meta.ReplicaManager, ex.dist, task.CollectionID(), action.Node(), segment.GetInsertChannel()) + leaderID, ok := getShardLeader(ex.meta.ReplicaManager, ex.dist, task.CollectionID(), action.Node(), task.Shard()) if !ok { msg := "no shard leader for the segment to execute loading" - err = merr.WrapErrChannelNotFound(segment.GetInsertChannel(), "shard delegator not found") + err = merr.WrapErrChannelNotFound(task.Shard(), "shard delegator not found") log.Warn(msg, zap.Error(err)) return err } @@ -444,3 +395,211 @@ func (ex *Executor) unsubscribeChannel(task *ChannelTask, step int) error { log.Info("unsubscribe channel done", zap.Int64("taskID", task.ID()), zap.Duration("time taken", elapsed)) return nil } + +func (ex *Executor) executeLeaderAction(task *LeaderTask, step int) { + switch task.Actions()[step].Type() { + case ActionTypeGrow, ActionTypeUpdate: + ex.setDistribution(task, step) + + case ActionTypeReduce: + ex.removeDistribution(task, step) + } +} + +func (ex *Executor) setDistribution(task *LeaderTask, step int) error { + action := task.Actions()[step].(*LeaderAction) + defer action.rpcReturned.Store(true) + ctx := task.Context() + log := log.Ctx(ctx).With( + zap.Int64("taskID", task.ID()), + zap.Int64("collectionID", task.CollectionID()), + zap.Int64("replicaID", task.ReplicaID()), + zap.Int64("segmentID", task.segmentID), + zap.Int64("leader", action.leaderID), + zap.Int64("node", action.Node()), + zap.String("source", task.Source().String()), + ) + + var err error + defer func() { + if err != nil { + task.Fail(err) + } + ex.removeTask(task, step) + }() + + collectionInfo, loadMeta, channel, err := ex.getMetaInfo(ctx, task) + if err != nil { + return err + } + + loadInfo, _, err := ex.getLoadInfo(ctx, task.CollectionID(), action.SegmentID(), channel) + if err != nil { + return err + } + + req := &querypb.SyncDistributionRequest{ + Base: commonpbutil.NewMsgBase( + commonpbutil.WithMsgType(commonpb.MsgType_LoadSegments), + commonpbutil.WithMsgID(task.ID()), + ), + CollectionID: task.collectionID, + Channel: task.Shard(), + Schema: collectionInfo.GetSchema(), + LoadMeta: loadMeta, + ReplicaID: task.ReplicaID(), + Actions: []*querypb.SyncAction{ + { + Type: querypb.SyncType_Set, + PartitionID: loadInfo.GetPartitionID(), + SegmentID: action.SegmentID(), + NodeID: action.Node(), + Info: loadInfo, + }, + }, + } + + startTs := time.Now() + log.Info("Sync Distribution...") + status, err := ex.cluster.SyncDistribution(task.Context(), task.leaderID, req) + err = merr.CheckRPCCall(status, err) + if err != nil { + log.Warn("failed to sync distribution", zap.Error(err)) + return err + } + + elapsed := time.Since(startTs) + log.Info("sync distribution done", zap.Duration("elapsed", elapsed)) + + return nil +} + +func (ex *Executor) removeDistribution(task *LeaderTask, step int) error { + action := task.Actions()[step].(*LeaderAction) + defer action.rpcReturned.Store(true) + ctx := task.Context() + log := log.Ctx(ctx).With( + zap.Int64("taskID", task.ID()), + zap.Int64("collectionID", task.CollectionID()), + zap.Int64("replicaID", task.ReplicaID()), + zap.Int64("segmentID", task.segmentID), + zap.Int64("leader", action.leaderID), + zap.Int64("node", action.Node()), + zap.String("source", task.Source().String()), 
+ ) + + var err error + defer func() { + if err != nil { + task.Fail(err) + } + ex.removeTask(task, step) + }() + + req := &querypb.SyncDistributionRequest{ + Base: commonpbutil.NewMsgBase( + commonpbutil.WithMsgType(commonpb.MsgType_LoadSegments), + commonpbutil.WithMsgID(task.ID()), + ), + CollectionID: task.collectionID, + Channel: task.Shard(), + ReplicaID: task.ReplicaID(), + Actions: []*querypb.SyncAction{ + { + Type: querypb.SyncType_Remove, + SegmentID: action.SegmentID(), + NodeID: action.Node(), + }, + }, + } + + startTs := time.Now() + log.Info("Sync Distribution...") + status, err := ex.cluster.SyncDistribution(task.Context(), task.leaderID, req) + err = merr.CheckRPCCall(status, err) + if err != nil { + log.Warn("failed to sync distribution", zap.Error(err)) + return err + } + + elapsed := time.Since(startTs) + log.Info("sync distribution done", zap.Duration("elapsed", elapsed)) + + return nil +} + +func (ex *Executor) getMetaInfo(ctx context.Context, task Task) (*milvuspb.DescribeCollectionResponse, *querypb.LoadMetaInfo, *meta.DmChannel, error) { + collectionID := task.CollectionID() + shard := task.Shard() + log := log.Ctx(ctx) + collectionInfo, err := ex.broker.DescribeCollection(ctx, collectionID) + if err != nil { + log.Warn("failed to get collection info", zap.Error(err)) + return nil, nil, nil, err + } + partitions, err := utils.GetPartitions(ex.meta.CollectionManager, collectionID) + if err != nil { + log.Warn("failed to get partitions of collection", zap.Error(err)) + return nil, nil, nil, err + } + + loadMeta := packLoadMeta( + ex.meta.GetLoadType(collectionID), + collectionID, + partitions..., + ) + // get channel first, in case of target updated after segment info fetched + channel := ex.targetMgr.GetDmChannel(collectionID, shard, meta.NextTargetFirst) + if channel == nil { + return nil, nil, nil, merr.WrapErrChannelNotAvailable(shard) + } + + return collectionInfo, loadMeta, channel, nil +} + +func (ex *Executor) getLoadInfo(ctx context.Context, collectionID, segmentID int64, channel *meta.DmChannel) (*querypb.SegmentLoadInfo, []*indexpb.IndexInfo, error) { + log := log.Ctx(ctx) + resp, err := ex.broker.GetSegmentInfo(ctx, segmentID) + if err != nil || len(resp.GetInfos()) == 0 { + log.Warn("failed to get segment info from DataCoord", zap.Error(err)) + return nil, nil, err + } + segment := resp.GetInfos()[0] + log = log.With(zap.String("level", segment.GetLevel().String())) + + indexes, err := ex.broker.GetIndexInfo(ctx, collectionID, segment.GetID()) + if err != nil { + if !errors.Is(err, merr.ErrIndexNotFound) { + log.Warn("failed to get index of segment", zap.Error(err)) + return nil, nil, err + } + indexes = nil + } + + // Get collection index info + indexInfos, err := ex.broker.DescribeIndex(ctx, collectionID) + if err != nil { + log.Warn("fail to get index meta of collection", zap.Error(err)) + return nil, nil, err + } + // update the field index params + for _, segmentIndex := range indexes { + index, found := lo.Find(indexInfos, func(indexInfo *indexpb.IndexInfo) bool { + return indexInfo.IndexID == segmentIndex.IndexID + }) + if !found { + log.Warn("no collection index info for the given segment index", zap.String("indexName", segmentIndex.GetIndexName())) + } + + params := funcutil.KeyValuePair2Map(segmentIndex.GetIndexParams()) + for _, kv := range index.GetUserIndexParams() { + if indexparams.IsConfigableIndexParam(kv.GetKey()) { + params[kv.GetKey()] = kv.GetValue() + } + } + segmentIndex.IndexParams =
funcutil.Map2KeyValuePair(params) + } + + loadInfo := utils.PackSegmentLoadInfo(segment, channel.GetSeekPosition(), indexes) + return loadInfo, indexInfos, nil +} diff --git a/internal/querycoordv2/task/scheduler.go b/internal/querycoordv2/task/scheduler.go index a35be5ad3e..994d7efc17 100644 --- a/internal/querycoordv2/task/scheduler.go +++ b/internal/querycoordv2/task/scheduler.go @@ -75,6 +75,14 @@ func NewReplicaSegmentIndex(task *SegmentTask) replicaSegmentIndex { } } +func NewReplicaLeaderIndex(task *LeaderTask) replicaSegmentIndex { + return replicaSegmentIndex{ + ReplicaID: task.ReplicaID(), + SegmentID: task.SegmentID(), + IsGrowing: false, + } +} + type replicaChannelIndex struct { ReplicaID int64 Channel string @@ -263,6 +271,10 @@ func (scheduler *taskScheduler) Add(task Task) error { case *ChannelTask: index := replicaChannelIndex{task.ReplicaID(), task.Channel()} scheduler.channelTasks[index] = task + + case *LeaderTask: + index := NewReplicaLeaderIndex(task) + scheduler.segmentTasks[index] = task } scheduler.updateTaskMetrics() @@ -369,6 +381,23 @@ func (scheduler *taskScheduler) preAdd(task Task) error { return merr.WrapErrServiceInternal("source channel unsubscribed, stop balancing") } } + case *LeaderTask: + index := NewReplicaLeaderIndex(task) + if old, ok := scheduler.segmentTasks[index]; ok { + if task.Priority() > old.Priority() { + log.Info("replace old task, the new one with higher priority", + zap.Int64("oldID", old.ID()), + zap.String("oldPriority", old.Priority().String()), + zap.Int64("newID", task.ID()), + zap.String("newPriority", task.Priority().String()), + ) + old.Cancel(merr.WrapErrServiceInternal("replaced with the other one with higher priority")) + scheduler.remove(old) + return nil + } + + return merr.WrapErrServiceInternal("task with the same segment exists") + } default: panic(fmt.Sprintf("preAdd: forget to process task type: %+v", task)) } @@ -755,6 +784,11 @@ func (scheduler *taskScheduler) remove(task Task) { index := replicaChannelIndex{task.ReplicaID(), task.Channel()} delete(scheduler.channelTasks, index) log = log.With(zap.String("channel", task.Channel())) + + case *LeaderTask: + index := NewReplicaLeaderIndex(task) + delete(scheduler.segmentTasks, index) + log = log.With(zap.Int64("segmentID", task.SegmentID())) } scheduler.updateTaskMetrics() @@ -780,6 +814,11 @@ func (scheduler *taskScheduler) checkStale(task Task) error { return err } + case *LeaderTask: + if err := scheduler.checkLeaderTaskStale(task); err != nil { + return err + } + default: panic(fmt.Sprintf("checkStale: forget to check task type: %+v", task)) } @@ -865,3 +904,53 @@ func (scheduler *taskScheduler) checkChannelTaskStale(task *ChannelTask) error { } return nil } + +func (scheduler *taskScheduler) checkLeaderTaskStale(task *LeaderTask) error { + log := log.With( + zap.Int64("taskID", task.ID()), + zap.Int64("collectionID", task.CollectionID()), + zap.Int64("replicaID", task.ReplicaID()), + zap.String("source", task.Source().String()), + zap.Int64("leaderID", task.leaderID), + ) + + for _, action := range task.Actions() { + switch action.Type() { + case ActionTypeGrow: + taskType := GetTaskType(task) + var segment *datapb.SegmentInfo + if taskType == TaskTypeMove || taskType == TaskTypeUpdate { + segment = scheduler.targetMgr.GetSealedSegment(task.CollectionID(), task.SegmentID(), meta.CurrentTarget) + } else { + segment = scheduler.targetMgr.GetSealedSegment(task.CollectionID(), task.SegmentID(), meta.NextTarget) + } + if segment == nil { + log.Warn("task stale due to the 
segment to load not exists in targets", + zap.Int64("segment", task.segmentID), + zap.String("taskType", taskType.String()), + ) + return merr.WrapErrSegmentReduplicate(task.SegmentID(), "target doesn't contain this segment") + } + + replica := scheduler.meta.ReplicaManager.GetByCollectionAndNode(task.CollectionID(), action.Node()) + if replica == nil { + log.Warn("task stale due to replica not found") + return merr.WrapErrReplicaNotFound(task.CollectionID(), "by collectionID") + } + + view := scheduler.distMgr.GetLeaderShardView(task.leaderID, task.Shard()) + if view == nil { + log.Warn("task stale due to leader not found") + return merr.WrapErrChannelNotFound(task.Shard(), "failed to get shard delegator") + } + + case ActionTypeReduce: + view := scheduler.distMgr.GetLeaderShardView(task.leaderID, task.Shard()) + if view == nil { + log.Warn("task stale due to leader not found") + return merr.WrapErrChannelNotFound(task.Shard(), "failed to get shard delegator") + } + } + } + return nil +} diff --git a/internal/querycoordv2/task/task.go b/internal/querycoordv2/task/task.go index 7e2fe2422e..956fe2f272 100644 --- a/internal/querycoordv2/task/task.go +++ b/internal/querycoordv2/task/task.go @@ -72,6 +72,7 @@ type Task interface { ID() typeutil.UniqueID CollectionID() typeutil.UniqueID ReplicaID() typeutil.UniqueID + Shard() string SetID(id typeutil.UniqueID) Status() Status SetStatus(status Status) @@ -162,6 +163,10 @@ func (task *baseTask) ReplicaID() typeutil.UniqueID { return task.replicaID } +func (task *baseTask) Shard() string { + return task.shard +} + func (task *baseTask) LoadType() querypb.LoadType { return task.loadType } @@ -318,10 +323,6 @@ func NewSegmentTask(ctx context.Context, }, nil } -func (task *SegmentTask) Shard() string { - return task.shard -} - func (task *SegmentTask) SegmentID() typeutil.UniqueID { return task.segmentID } @@ -383,3 +384,40 @@ func (task *ChannelTask) Index() string { func (task *ChannelTask) String() string { return fmt.Sprintf("%s [channel=%s]", task.baseTask.String(), task.Channel()) } + +type LeaderTask struct { + *baseTask + + segmentID typeutil.UniqueID + leaderID int64 +} + +func NewLeaderTask(ctx context.Context, + timeout time.Duration, + source Source, + collectionID, + replicaID typeutil.UniqueID, + leaderID int64, + action *LeaderAction, +) *LeaderTask { + segmentID := action.SegmentID() + base := newBaseTask(ctx, source, collectionID, replicaID, action.Shard(), fmt.Sprintf("LeaderTask-%s-%d", action.Type().String(), segmentID)) + base.actions = []Action{action} + return &LeaderTask{ + baseTask: base, + segmentID: segmentID, + leaderID: leaderID, + } +} + +func (task *LeaderTask) SegmentID() typeutil.UniqueID { + return task.segmentID +} + +func (task *LeaderTask) Index() string { + return fmt.Sprintf("%s[segment=%d][growing=false]", task.baseTask.Index(), task.segmentID) +} + +func (task *LeaderTask) String() string { + return fmt.Sprintf("%s [segmentID=%d][leader=%d]", task.baseTask.String(), task.segmentID, task.leaderID) +} diff --git a/internal/querycoordv2/task/task_test.go b/internal/querycoordv2/task/task_test.go index ea3ad01bbe..a013dd2a04 100644 --- a/internal/querycoordv2/task/task_test.go +++ b/internal/querycoordv2/task/task_test.go @@ -171,6 +171,8 @@ func (suite *TaskSuite) BeforeTest(suiteName, testName string) { "TestMoveSegmentTaskStale", "TestSubmitDuplicateLoadSegmentTask", "TestSubmitDuplicateSubscribeChannelTask", + "TestLeaderTaskSet", + "TestLeaderTaskRemove", "TestNoExecutor": 
suite.meta.PutCollection(&meta.Collection{ CollectionLoadInfo: &querypb.CollectionLoadInfo{ @@ -1213,6 +1215,113 @@ func (suite *TaskSuite) TestChannelTaskReplace() { suite.AssertTaskNum(0, channelNum, channelNum, 0) } +func (suite *TaskSuite) TestLeaderTaskSet() { + ctx := context.Background() + timeout := 10 * time.Second + targetNode := int64(3) + partition := int64(100) + channel := &datapb.VchannelInfo{ + CollectionID: suite.collection, + ChannelName: Params.CommonCfg.RootCoordDml.GetValue() + "-test", + } + + // Expect + suite.broker.EXPECT().DescribeCollection(mock.Anything, suite.collection).Return(&milvuspb.DescribeCollectionResponse{ + Schema: &schemapb.CollectionSchema{ + Name: "TestLoadSegmentTask", + Fields: []*schemapb.FieldSchema{ + {FieldID: 100, Name: "vec", DataType: schemapb.DataType_FloatVector}, + }, + }, + }, nil) + suite.broker.EXPECT().DescribeIndex(mock.Anything, suite.collection).Return([]*indexpb.IndexInfo{ + { + CollectionID: suite.collection, + }, + }, nil) + for _, segment := range suite.loadSegments { + suite.broker.EXPECT().GetSegmentInfo(mock.Anything, segment).Return(&datapb.GetSegmentInfoResponse{ + Infos: []*datapb.SegmentInfo{ + { + ID: segment, + CollectionID: suite.collection, + PartitionID: partition, + InsertChannel: channel.ChannelName, + }, + }, + }, nil) + suite.broker.EXPECT().GetIndexInfo(mock.Anything, suite.collection, segment).Return(nil, nil) + } + suite.cluster.EXPECT().SyncDistribution(mock.Anything, targetNode, mock.Anything).Return(merr.Success(), nil) + + // Test load segment task + suite.dist.ChannelDistManager.Update(targetNode, meta.DmChannelFromVChannel(&datapb.VchannelInfo{ + CollectionID: suite.collection, + ChannelName: channel.ChannelName, + })) + tasks := []Task{} + segments := make([]*datapb.SegmentInfo, 0) + for _, segment := range suite.loadSegments { + segments = append(segments, &datapb.SegmentInfo{ + ID: segment, + InsertChannel: channel.ChannelName, + PartitionID: 1, + }) + task := NewLeaderTask( + ctx, + timeout, + WrapIDSource(0), + suite.collection, + suite.replica, + targetNode, + NewLeaderAction(targetNode, targetNode, ActionTypeGrow, channel.GetChannelName(), segment), + ) + tasks = append(tasks, task) + err := suite.scheduler.Add(task) + suite.NoError(err) + } + suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, suite.collection).Return([]*datapb.VchannelInfo{channel}, segments, nil) + suite.target.UpdateCollectionNextTarget(suite.collection) + segmentsNum := len(suite.loadSegments) + suite.AssertTaskNum(0, segmentsNum, 0, segmentsNum) + + view := &meta.LeaderView{ + ID: targetNode, + CollectionID: suite.collection, + Channel: channel.GetChannelName(), + Segments: map[int64]*querypb.SegmentDist{}, + } + suite.dist.LeaderViewManager.Update(targetNode, view) + + // Process tasks + suite.dispatchAndWait(targetNode) + suite.AssertTaskNum(segmentsNum, 0, 0, segmentsNum) + + // Process tasks done + // Dist contains channels + view = &meta.LeaderView{ + ID: targetNode, + CollectionID: suite.collection, + Channel: channel.GetChannelName(), + Segments: map[int64]*querypb.SegmentDist{}, + } + for _, segment := range suite.loadSegments { + view.Segments[segment] = &querypb.SegmentDist{NodeID: targetNode, Version: 0} + } + distSegments := lo.Map(segments, func(info *datapb.SegmentInfo, _ int) *meta.Segment { + return meta.SegmentFromInfo(info) + }) + suite.dist.LeaderViewManager.Update(targetNode, view) + suite.dist.SegmentDistManager.Update(targetNode, distSegments...) 
+ suite.dispatchAndWait(targetNode) + suite.AssertTaskNum(0, 0, 0, 0) + + for _, task := range tasks { + suite.Equal(TaskStatusSucceeded, task.Status()) + suite.NoError(task.Err()) + } +} + func (suite *TaskSuite) TestCreateTaskBehavior() { chanelTask, err := NewChannelTask(context.TODO(), 5*time.Second, WrapIDSource(0), 0, 0) suite.ErrorIs(err, merr.ErrParameterInvalid) @@ -1244,6 +1353,10 @@ func (suite *TaskSuite) TestCreateTaskBehavior() { segmentTask, err = NewSegmentTask(context.TODO(), 5*time.Second, WrapIDSource(0), 0, 0, segmentAction1, segmentAction2) suite.ErrorIs(err, merr.ErrParameterInvalid) suite.Nil(segmentTask) + + leaderAction := NewLeaderAction(1, 2, ActionTypeGrow, "fake-channel1", 100) + leaderTask := NewLeaderTask(context.TODO(), 5*time.Second, WrapIDSource(0), 0, 0, 1, leaderAction) + suite.NotNil(leaderTask) } func (suite *TaskSuite) TestSegmentTaskReplace() { @@ -1387,6 +1500,74 @@ func (suite *TaskSuite) dispatchAndWait(node int64) { suite.FailNow("executor hangs in executing tasks", "count=%d keys=%+v", count, keys) } +func (suite *TaskSuite) TestLeaderTaskRemove() { + ctx := context.Background() + timeout := 10 * time.Second + targetNode := int64(3) + partition := int64(100) + channel := &datapb.VchannelInfo{ + CollectionID: suite.collection, + ChannelName: Params.CommonCfg.RootCoordDml.GetValue() + "-test", + } + + // Expect + suite.cluster.EXPECT().SyncDistribution(mock.Anything, targetNode, mock.Anything).Return(merr.Success(), nil) + + // Test remove segment task + view := &meta.LeaderView{ + ID: targetNode, + CollectionID: suite.collection, + Channel: channel.ChannelName, + Segments: make(map[int64]*querypb.SegmentDist), + } + segments := make([]*meta.Segment, 0) + tasks := []Task{} + for _, segment := range suite.releaseSegments { + segments = append(segments, &meta.Segment{ + SegmentInfo: &datapb.SegmentInfo{ + ID: segment, + CollectionID: suite.collection, + PartitionID: partition, + InsertChannel: channel.ChannelName, + }, + }) + view.Segments[segment] = &querypb.SegmentDist{NodeID: targetNode, Version: 0} + task := NewLeaderTask( + ctx, + timeout, + WrapIDSource(0), + suite.collection, + suite.replica, + targetNode, + NewLeaderAction(targetNode, targetNode, ActionTypeReduce, channel.GetChannelName(), segment), + ) + tasks = append(tasks, task) + err := suite.scheduler.Add(task) + suite.NoError(err) + } + suite.dist.SegmentDistManager.Update(targetNode, segments...) + suite.dist.LeaderViewManager.Update(targetNode, view) + + segmentsNum := len(suite.releaseSegments) + suite.AssertTaskNum(0, segmentsNum, 0, segmentsNum) + + // Process tasks + suite.dispatchAndWait(targetNode) + suite.AssertTaskNum(segmentsNum, 0, 0, segmentsNum) + + view.Segments = make(map[int64]*querypb.SegmentDist) + suite.dist.LeaderViewManager.Update(targetNode, view) + // Process tasks done + // suite.dist.LeaderViewManager.Update(targetNode) + suite.dispatchAndWait(targetNode) + suite.AssertTaskNum(0, 0, 0, 0) + + for _, task := range tasks { + suite.Equal(TaskStatusSucceeded, task.Status()) + suite.NoError(task.Err()) + } +} + func (suite *TaskSuite) newScheduler() *taskScheduler { return NewScheduler( context.Background(),