diff --git a/internal/querycoord/segment_allocator.go b/internal/querycoord/segment_allocator.go index ec93b4f814..3c3f9e776d 100644 --- a/internal/querycoord/segment_allocator.go +++ b/internal/querycoord/segment_allocator.go @@ -123,7 +123,12 @@ func shuffleSegmentsToQueryNodeV2(ctx context.Context, reqs []*querypb.LoadSegme if err != nil { return err } - onlineNodeIDs = replica.GetNodeIds() + replicaNodes := replica.GetNodeIds() + for _, nodeID := range replicaNodes { + if ok, err := cluster.isOnline(nodeID); err == nil && ok { + onlineNodeIDs = append(onlineNodeIDs, nodeID) + } + } } if len(onlineNodeIDs) == 0 && !wait { err := errors.New("no online queryNode to allocate") diff --git a/internal/querycoord/task.go b/internal/querycoord/task.go index e1be829110..1f473fd868 100644 --- a/internal/querycoord/task.go +++ b/internal/querycoord/task.go @@ -2206,13 +2206,58 @@ func (lbt *loadBalanceTask) postExecute(context.Context) error { // then the queryCoord will panic, and the nodeInfo should not be removed immediately // after queryCoord recovery, the balanceTask will redo if lbt.triggerCondition == querypb.TriggerCondition_NodeDown && lbt.getResultInfo().ErrorCode == commonpb.ErrorCode_Success { + offlineNodes := make(map[UniqueID]struct{}, len(lbt.SourceNodeIDs)) + for _, nodeID := range lbt.SourceNodeIDs { + offlineNodes[nodeID] = struct{}{} + } + replicas := make(map[UniqueID]*milvuspb.ReplicaInfo) + for _, id := range lbt.SourceNodeIDs { err := lbt.cluster.removeNodeInfo(id) if err != nil { //TODO:: clear node info after removeNodeInfo failed - log.Error("loadBalanceTask: occur error when removing node info from cluster", zap.Int64("nodeID", id)) + log.Warn("loadBalanceTask: occur error when removing node info from cluster", + zap.Int64("nodeID", id), + zap.Error(err)) + continue } + + replica, err := lbt.getReplica(id, lbt.CollectionID) + if err != nil { + log.Warn("failed to get replica for removing offline querynode from it", + zap.Int64("querynodeID", id), + zap.Int64("collectionID", lbt.CollectionID), + zap.Error(err)) + continue + } + replicas[replica.ReplicaID] = replica } + + log.Debug("removing offline nodes from replicas...", + zap.Int("len(replicas)", len(replicas))) + wg := sync.WaitGroup{} + for _, replica := range replicas { + wg.Add(1) + go func(replica *milvuspb.ReplicaInfo) { + defer wg.Done() + + onlineNodes := make([]UniqueID, 0, len(replica.NodeIds)) + for _, nodeID := range replica.NodeIds { + if _, ok := offlineNodes[nodeID]; !ok { + onlineNodes = append(onlineNodes, nodeID) + } + } + replica.NodeIds = onlineNodes + + err := lbt.meta.setReplicaInfo(replica) + if err != nil { + log.Warn("failed to remove offline nodes from replica info", + zap.Int64("replicaID", replica.ReplicaID), + zap.Error(err)) + } + }(replica) + } + wg.Wait() } log.Info("loadBalanceTask postExecute done", diff --git a/internal/querycoord/task_test.go b/internal/querycoord/task_test.go index 7503df82be..75a337f08a 100644 --- a/internal/querycoord/task_test.go +++ b/internal/querycoord/task_test.go @@ -925,7 +925,7 @@ func TestLoadBalanceSegmentsTask(t *testing.T) { }) t.Run("Test LoadBalanceByNode", func(t *testing.T) { - baseTask := newBaseTask(ctx, querypb.TriggerCondition_LoadBalance) + baseTask := newBaseTask(ctx, querypb.TriggerCondition_NodeDown) loadBalanceTask := &loadBalanceTask{ baseTask: baseTask, LoadBalanceRequest: &querypb.LoadBalanceRequest{ @@ -934,6 +934,7 @@ func TestLoadBalanceSegmentsTask(t *testing.T) { }, SourceNodeIDs: []int64{node1.queryNodeID}, CollectionID: defaultCollectionID, + BalanceReason: querypb.TriggerCondition_NodeDown, }, broker: queryCoord.broker, cluster: queryCoord.cluster, @@ -942,6 +943,7 @@ func TestLoadBalanceSegmentsTask(t *testing.T) { err = queryCoord.scheduler.Enqueue(loadBalanceTask) assert.Nil(t, err) waitTaskFinalState(loadBalanceTask, taskExpired) + assert.Equal(t, commonpb.ErrorCode_Success, loadBalanceTask.result.ErrorCode) }) t.Run("Test LoadBalanceWithEmptySourceNode", func(t *testing.T) {