From dcfe472586b8e66057f8ed59a59a024c36a0a7fc Mon Sep 17 00:00:00 2001
From: yah01
Date: Thu, 19 May 2022 16:51:57 +0800
Subject: [PATCH] Fix LoadBalance doesn't save the modification to replicas'
 shards (#17064)

Signed-off-by: yah01
---
 internal/querycoord/group_balance.go |  12 +--
 internal/querycoord/querynode.go     |   4 +-
 internal/querycoord/task.go          | 123 ++++++++++++++++++---------
 3 files changed, 86 insertions(+), 53 deletions(-)

diff --git a/internal/querycoord/group_balance.go b/internal/querycoord/group_balance.go
index 52a91aa1cf..32e6e0c3f5 100644
--- a/internal/querycoord/group_balance.go
+++ b/internal/querycoord/group_balance.go
@@ -36,25 +36,15 @@ func (b *replicaBalancer) addNode(nodeID int64) ([]*balancePlan, error) {
 			continue
 		}
 
-		offlineNodesCnt := make(map[UniqueID]int, len(replicas))
 		replicaAvailableMemory := make(map[UniqueID]uint64, len(replicas))
 		for _, replica := range replicas {
-			for _, nodeID := range replica.NodeIds {
-				if isOnline, err := b.cluster.isOnline(nodeID); err != nil || !isOnline {
-					offlineNodesCnt[replica.ReplicaID]++
-				}
-			}
-
 			replicaAvailableMemory[replica.ReplicaID] = getReplicaAvailableMemory(b.cluster, replica)
 		}
 
 		sort.Slice(replicas, func(i, j int) bool {
 			replicai := replicas[i].ReplicaID
 			replicaj := replicas[j].ReplicaID
-			cnti := offlineNodesCnt[replicai]
-			cntj := offlineNodesCnt[replicaj]
 
-			return cnti > cntj ||
-				cnti == cntj && replicaAvailableMemory[replicai] < replicaAvailableMemory[replicaj]
+			return replicaAvailableMemory[replicai] < replicaAvailableMemory[replicaj]
 		})
 
 		ret = append(ret, &balancePlan{
diff --git a/internal/querycoord/querynode.go b/internal/querycoord/querynode.go
index 0ad8fc080c..49cda5a0c6 100644
--- a/internal/querycoord/querynode.go
+++ b/internal/querycoord/querynode.go
@@ -408,8 +408,8 @@ func (qn *queryNode) releaseSegments(ctx context.Context, in *querypb.ReleaseSeg
 }
 
 func (qn *queryNode) getNodeInfo() (Node, error) {
-	qn.RLock()
-	defer qn.RUnlock()
+	qn.Lock()
+	defer qn.Unlock()
 
 	if !qn.isOnline() {
 		return nil, errors.New("getNodeInfo: queryNode is offline")
diff --git a/internal/querycoord/task.go b/internal/querycoord/task.go
index 6b486ea528..c4ecb08ffa 100644
--- a/internal/querycoord/task.go
+++ b/internal/querycoord/task.go
@@ -26,6 +26,7 @@ import (
 
 	"github.com/golang/protobuf/proto"
 	"go.uber.org/zap"
+	"golang.org/x/sync/errgroup"
 
 	"github.com/milvus-io/milvus/internal/log"
 	"github.com/milvus-io/milvus/internal/metrics"
@@ -2038,7 +2039,7 @@ func (lbt *loadBalanceTask) execute(ctx context.Context) error {
 			mergedDmChannel := mergeDmChannelInfo(dmChannelInfos)
 
 			for channelName, vChannelInfo := range mergedDmChannel {
-				if info, ok := dmChannel2WatchInfo[channelName]; ok {
+				if _, ok := dmChannel2WatchInfo[channelName]; ok {
 					msgBase := proto.Clone(lbt.Base).(*commonpb.MsgBase)
 					msgBase.MsgType = commonpb.MsgType_WatchDmChannels
 					watchRequest := &querypb.WatchDmChannelsRequest{
@@ -2051,7 +2052,7 @@ func (lbt *loadBalanceTask) execute(ctx context.Context) error {
 							CollectionID: collectionID,
 							PartitionIDs: toRecoverPartitionIDs,
 						},
-						ReplicaID: info.ReplicaID,
+						ReplicaID: replica.ReplicaID,
 					}
 
 					if collectionInfo.LoadType == querypb.LoadType_LoadPartition {
@@ -2286,19 +2287,24 @@ func (lbt *loadBalanceTask) globalPostExecute(ctx context.Context) error {
 			zap.Int("len(segments)", len(segments)),
 			zap.Int64("trigger task ID", lbt.getTaskID()),
 		)
-		wg := sync.WaitGroup{}
+
+		// Remove offline nodes from replica
+		wg := errgroup.Group{}
 		if lbt.triggerCondition == querypb.TriggerCondition_NodeDown {
 			offlineNodes := make(typeutil.UniqueSet, len(lbt.SourceNodeIDs))
 			for _, nodeID := range lbt.SourceNodeIDs {
 				offlineNodes.Insert(nodeID)
 			}
 
-			for _, replica := range replicas {
-				wg.Add(1)
-				go func(replica *milvuspb.ReplicaInfo) {
-					defer wg.Done()
+			log.Debug("removing offline nodes from replicas and segments...",
+				zap.Int("len(replicas)", len(replicas)),
+				zap.Int("len(segments)", len(segments)),
+				zap.Int64("trigger task ID", lbt.getTaskID()),
+			)
+
+			for _, replica := range replicas {
+				replica := replica
+				wg.Go(func() error {
 					onlineNodes := make([]UniqueID, 0, len(replica.NodeIds))
 					for _, nodeID := range replica.NodeIds {
 						if !offlineNodes.Contain(nodeID) {
@@ -2309,21 +2315,20 @@ func (lbt *loadBalanceTask) globalPostExecute(ctx context.Context) error {
 
 					err := lbt.meta.setReplicaInfo(replica)
 					if err != nil {
-						log.Warn("failed to remove offline nodes from replica info",
+						log.Error("failed to remove offline nodes from replica info",
 							zap.Int64("replicaID", replica.ReplicaID),
 							zap.Error(err))
+						return err
 					}
-				}(replica)
+
+					return nil
+				})
 			}
 		}
 
-		// Update the nodes list of segment, only remove the source nodes,
-		// adding destination nodes will be executed by updateSegmentInfoFromTask()
 		for _, segment := range segments {
-			wg.Add(1)
-			go func(segment *querypb.SegmentInfo) {
-				defer wg.Done()
-
+			segment := segment
+			wg.Go(func() error {
 				segment.NodeID = -1
 				segment.NodeIds = removeFromSlice(segment.NodeIds, lbt.SourceNodeIDs...)
 				if len(segment.NodeIds) > 0 {
@@ -2332,38 +2337,76 @@ func (lbt *loadBalanceTask) globalPostExecute(ctx context.Context) error {
 
 				err := lbt.meta.saveSegmentInfo(segment)
 				if err != nil {
-					log.Warn("failed to remove offline nodes from segment info",
+					log.Error("failed to remove offline nodes from segment info",
 						zap.Int64("segmentID", segment.SegmentID),
 						zap.Error(err))
-				}
-			}(segment)
-		}
-		wg.Wait()
-		err := syncReplicaSegments(ctx, lbt.cluster, lbt.getChildTask())
+					return err
+				}
+
+				return nil
+			})
+		}
+		for _, childTask := range lbt.getChildTask() {
+			if task, ok := childTask.(*watchDmChannelTask); ok {
+				wg.Go(func() error {
+					nodeInfo, err := lbt.cluster.getNodeInfoByID(task.NodeID)
+					if err != nil {
+						log.Error("failed to get node info to update shard leader info",
+							zap.Int64("triggerTaskID", lbt.getTaskID()),
+							zap.Int64("taskID", task.getTaskID()),
+							zap.Int64("nodeID", task.NodeID),
+							zap.String("dmChannel", task.Infos[0].ChannelName),
+							zap.Error(err))
+						return err
+					}
+
+					replica, err := lbt.meta.getReplicaByID(task.ReplicaID)
+					if err != nil {
+						log.Error("failed to get replica to update shard leader info",
+							zap.Int64("triggerTaskID", lbt.getTaskID()),
+							zap.Int64("taskID", task.getTaskID()),
+							zap.Int64("replicaID", task.ReplicaID),
+							zap.String("dmChannel", task.Infos[0].ChannelName),
+							zap.Error(err))
+						return err
+					}
+
+					for _, shard := range replica.ShardReplicas {
+						if shard.DmChannelName == task.Infos[0].ChannelName {
+							log.Debug("LoadBalance: update shard leader",
+								zap.Int64("triggerTaskID", lbt.getTaskID()),
+								zap.Int64("taskID", task.getTaskID()),
+								zap.Int64("oldLeader", shard.LeaderID),
+								zap.Int64("newLeader", task.NodeID))
+							shard.LeaderID = task.NodeID
+							shard.LeaderAddr = nodeInfo.(*queryNode).address
+							break
+						}
+					}
+
+					err = lbt.meta.setReplicaInfo(replica)
+					if err != nil {
+						log.Error("failed to remove offline nodes from replica info",
+							zap.Int64("triggerTaskID", lbt.getTaskID()),
+							zap.Int64("taskID", task.getTaskID()),
+							zap.Int64("replicaID", replica.ReplicaID),
+							zap.Error(err))
+						return err
+					}
+
+					return nil
+				})
+			}
+		}
+		err := wg.Wait()
 		if err != nil {
 			return err
 		}
-
-		for _, childTask := range lbt.getChildTask() {
-			if task, ok := childTask.(*watchDmChannelTask); ok {
-				nodeInfo, err := lbt.cluster.getNodeInfoByID(task.NodeID)
-				if err != nil {
-					return err
-				}
-
-				replica, err := lbt.meta.getReplicaByID(task.ReplicaID)
-				if err != nil {
-					return err
-				}
-
-				for _, shard := range replica.ShardReplicas {
-					if shard.DmChannelName == task.Infos[0].ChannelName {
-						shard.LeaderID = task.NodeID
-						shard.LeaderAddr = nodeInfo.(*queryNode).address
-					}
-				}
-			}
+		err = syncReplicaSegments(ctx, lbt.cluster, lbt.getChildTask())
+		if err != nil {
+			return err
 		}
 	}
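The heart of the task.go change is replacing the fire-and-forget sync.WaitGroup goroutines, which only logged failures, with errgroup.Group tasks that return errors, so a failed setReplicaInfo or saveSegmentInfo (including the new shard-leader update) now aborts globalPostExecute before syncReplicaSegments runs. The program below is a minimal, self-contained sketch of that pattern only; it is not part of the patch, and saveInfo plus the ids slice are hypothetical stand-ins for the meta-store calls and replica/segment IDs.

package main

import (
	"errors"
	"fmt"

	"golang.org/x/sync/errgroup"
)

// saveInfo stands in for persistence calls such as meta.setReplicaInfo or meta.saveSegmentInfo.
func saveInfo(id int64) error {
	if id == 3 {
		return errors.New("save failed")
	}
	return nil
}

func main() {
	ids := []int64{1, 2, 3, 4}

	// The zero value of errgroup.Group is ready to use, matching the patch's `wg := errgroup.Group{}`.
	wg := errgroup.Group{}
	for _, id := range ids {
		id := id // capture the loop variable per iteration, as the patch does with `replica := replica`
		wg.Go(func() error {
			return saveInfo(id)
		})
	}

	// Wait blocks until all tasks finish and returns the first non-nil error,
	// so a persistence failure is no longer dropped after a log line.
	if err := wg.Wait(); err != nil {
		fmt.Println("balance post-execute failed:", err)
	}
}

Because errgroup.Group.Wait reports the first error returned by any function started with Go, the trigger task sees the failure and can retry, instead of proceeding to syncReplicaSegments with shard-leader changes that were never persisted.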