From 2d0f908dbafeb7439a5bbc2e35d9b5726cf33bce Mon Sep 17 00:00:00 2001
From: yah01
Date: Tue, 10 May 2022 15:47:53 +0800
Subject: [PATCH] Fix updating segments' NodeIds correctly after LoadBalance (#16854)

Signed-off-by: yah01
---
 internal/querycoord/meta.go           |  5 +++
 internal/querycoord/segments_info.go  |  4 +--
 internal/querycoord/task.go           | 52 ++++++++++++++++++++-------
 internal/querycoord/task_scheduler.go | 12 +++++--
 internal/querycoord/task_test.go      | 42 ++++++++++++++++++----
 internal/querycoord/util.go           |  8 +++++
 6 files changed, 98 insertions(+), 25 deletions(-)

diff --git a/internal/querycoord/meta.go b/internal/querycoord/meta.go
index d83ade90aa..9c7cafd9f9 100644
--- a/internal/querycoord/meta.go
+++ b/internal/querycoord/meta.go
@@ -76,6 +76,7 @@ type Meta interface {
     getSegmentInfoByID(segmentID UniqueID) (*querypb.SegmentInfo, error)
     getSegmentInfosByNode(nodeID int64) []*querypb.SegmentInfo
     getSegmentInfosByNodeAndCollection(nodeID, collectionID int64) []*querypb.SegmentInfo
+    saveSegmentInfo(segment *querypb.SegmentInfo) error

     getPartitionStatesByID(collectionID UniqueID, partitionID UniqueID) (*querypb.PartitionStates, error)

@@ -880,6 +881,10 @@ func (m *MetaReplica) getSegmentInfosByNodeAndCollection(nodeID, collectionID in
     return res
 }

+func (m *MetaReplica) saveSegmentInfo(segment *querypb.SegmentInfo) error {
+    return m.segmentsInfo.saveSegment(segment)
+}
+
 func (m *MetaReplica) getCollectionInfoByID(collectionID UniqueID) (*querypb.CollectionInfo, error) {
     m.collectionMu.RLock()
     defer m.collectionMu.RUnlock()
diff --git a/internal/querycoord/segments_info.go b/internal/querycoord/segments_info.go
index fc16196dcb..ae41a275f0 100644
--- a/internal/querycoord/segments_info.go
+++ b/internal/querycoord/segments_info.go
@@ -107,7 +107,7 @@ func (s *segmentsInfo) removeSegment(segment *querypb.SegmentInfo) error {
 func (s *segmentsInfo) getSegment(ID int64) *querypb.SegmentInfo {
     s.mu.RLock()
     defer s.mu.RUnlock()
-    return s.segmentIDMap[ID]
+    return proto.Clone(s.segmentIDMap[ID]).(*querypb.SegmentInfo)
 }

 func (s *segmentsInfo) getSegments() []*querypb.SegmentInfo {
@@ -115,7 +115,7 @@
     defer s.mu.RUnlock()
     res := make([]*querypb.SegmentInfo, 0, len(s.segmentIDMap))
     for _, segment := range s.segmentIDMap {
-        res = append(res, segment)
+        res = append(res, proto.Clone(segment).(*querypb.SegmentInfo))
     }
     return res
 }
diff --git a/internal/querycoord/task.go b/internal/querycoord/task.go
index 8676dc0f7e..dadd01c9c3 100644
--- a/internal/querycoord/task.go
+++ b/internal/querycoord/task.go
@@ -2235,12 +2235,12 @@ func (lbt *loadBalanceTask) postExecute(context.Context) error {
     // then the queryCoord will panic, and the nodeInfo should not be removed immediately
     // after queryCoord recovery, the balanceTask will redo
     if lbt.triggerCondition == querypb.TriggerCondition_NodeDown && lbt.getResultInfo().ErrorCode == commonpb.ErrorCode_Success {
-        for _, id := range lbt.SourceNodeIDs {
-            err := lbt.cluster.removeNodeInfo(id)
+        for _, offlineNodeID := range lbt.SourceNodeIDs {
+            err := lbt.cluster.removeNodeInfo(offlineNodeID)
             if err != nil {
                 //TODO:: clear node info after removeNodeInfo failed
                 log.Warn("loadBalanceTask: occur error when removing node info from cluster",
-                    zap.Int64("nodeID", id),
+                    zap.Int64("nodeID", offlineNodeID),
                     zap.Error(err))
                 continue
             }
@@ -2263,21 +2263,27 @@ func (lbt *loadBalanceTask) globalPostExecute(ctx context.Context) error {
             offlineNodes.Insert(nodeID)
         }

         replicas := make(map[UniqueID]*milvuspb.ReplicaInfo)
+        segments := make(map[UniqueID]*querypb.SegmentInfo)
         for _, id := range lbt.SourceNodeIDs {
-            replica, err := lbt.getReplica(id, lbt.CollectionID)
-            if err != nil {
-                log.Warn("failed to get replica for removing offline querynode from it",
-                    zap.Int64("querynodeID", id),
-                    zap.Int64("collectionID", lbt.CollectionID),
-                    zap.Error(err))
-                continue
+            for _, segment := range lbt.meta.getSegmentInfosByNode(id) {
+                segments[segment.SegmentID] = segment
+            }
+
+            nodeReplicas, err := lbt.meta.getReplicasByNodeID(id)
+            if err != nil {
+                log.Warn("failed to get replicas for removing offline querynode from it",
+                    zap.Int64("querynodeID", id),
+                    zap.Error(err))
+            }
+            for _, replica := range nodeReplicas {
+                replicas[replica.ReplicaID] = replica
             }
-            replicas[replica.ReplicaID] = replica
         }

-        log.Debug("removing offline nodes from replicas...",
-            zap.Int("len(replicas)", len(replicas)))
+        log.Debug("removing offline nodes from replicas and segments...",
+            zap.Int("len(replicas)", len(replicas)),
+            zap.Int("len(segments)", len(segments)))
         wg := sync.WaitGroup{}
         for _, replica := range replicas {
             wg.Add(1)
@@ -2300,6 +2306,26 @@ func (lbt *loadBalanceTask) globalPostExecute(ctx context.Context) error {
                 }
             }(replica)
         }
+
+        for _, segment := range segments {
+            wg.Add(1)
+            go func(segment *querypb.SegmentInfo) {
+                defer wg.Done()
+
+                segment.NodeID = -1
+                segment.NodeIds = removeFromSlice(segment.NodeIds, lbt.SourceNodeIDs...)
+                if len(segment.NodeIds) > 0 {
+                    segment.NodeID = segment.NodeIds[0]
+                }
+
+                err := lbt.meta.saveSegmentInfo(segment)
+                if err != nil {
+                    log.Warn("failed to remove offline nodes from segment info",
+                        zap.Int64("segmentID", segment.SegmentID),
+                        zap.Error(err))
+                }
+            }(segment)
+        }
         wg.Wait()
     }

diff --git a/internal/querycoord/task_scheduler.go b/internal/querycoord/task_scheduler.go
index f73b746718..4f3338e1cf 100644
--- a/internal/querycoord/task_scheduler.go
+++ b/internal/querycoord/task_scheduler.go
@@ -903,6 +903,7 @@ func updateSegmentInfoFromTask(ctx context.Context, triggerTask task, meta Meta)
         req := triggerTask.(*releaseCollectionTask).ReleaseCollectionRequest
         collectionID := req.CollectionID
         sealedSegmentChangeInfos, err = meta.removeGlobalSealedSegInfos(collectionID, nil)
+
     case commonpb.MsgType_ReleasePartitions:
         // release all segmentInfo of the partitions when release partitions
         req := triggerTask.(*releasePartitionTask).ReleasePartitionsRequest
@@ -917,6 +918,7 @@ func updateSegmentInfoFromTask(ctx context.Context, triggerTask task, meta Meta)
             }
         }
         sealedSegmentChangeInfos, err = meta.removeGlobalSealedSegInfos(collectionID, req.PartitionIDs)
+
     default:
         // save new segmentInfo when load segment
         segments := make(map[UniqueID]*querypb.SegmentInfo)
@@ -929,8 +931,8 @@ func updateSegmentInfoFromTask(ctx context.Context, triggerTask task, meta Meta)
             collectionID := loadInfo.CollectionID
             segmentID := loadInfo.SegmentID

-            segment, saved := segments[segmentID]
-            if !saved {
+            segment, err := meta.getSegmentInfoByID(segmentID)
+            if err != nil {
                 segment = &querypb.SegmentInfo{
                     SegmentID:    segmentID,
                     CollectionID: loadInfo.CollectionID,
@@ -942,11 +944,15 @@ func updateSegmentInfoFromTask(ctx context.Context, triggerTask task, meta Meta)
                     ReplicaIds:   []UniqueID{req.ReplicaID},
                     NodeIds:      []UniqueID{dstNodeID},
                 }
-                segments[segmentID] = segment
             } else {
                 segment.ReplicaIds = append(segment.ReplicaIds, req.ReplicaID)
+                segment.ReplicaIds = removeFromSlice(segment.GetReplicaIds())
+                segment.NodeIds = append(segment.NodeIds, dstNodeID)
+
                 segment.NodeID = dstNodeID
             }
+            _, saved := segments[segmentID]
+            segments[segmentID] = segment

             if _, ok := segmentInfosToSave[collectionID]; !ok {
                 segmentInfosToSave[collectionID] = make([]*querypb.SegmentInfo, 0)
diff --git a/internal/querycoord/task_test.go b/internal/querycoord/task_test.go
index b987e30d88..afff20030c 100644
--- a/internal/querycoord/task_test.go
+++ b/internal/querycoord/task_test.go
@@ -19,9 +19,12 @@ package querycoord
 import (
     "context"
     "testing"
+    "time"

     "github.com/stretchr/testify/assert"
+    "go.uber.org/zap"

+    "github.com/milvus-io/milvus/internal/log"
     "github.com/milvus-io/milvus/internal/proto/commonpb"
     "github.com/milvus-io/milvus/internal/proto/datapb"
     "github.com/milvus-io/milvus/internal/proto/internalpb"
@@ -785,6 +788,7 @@ func Test_reverseSealedSegmentChangeInfo(t *testing.T) {
     node1, err := startQueryNodeServer(ctx)
     assert.Nil(t, err)
     waitQueryNodeOnline(queryCoord.cluster, node1.queryNodeID)
+    defer node1.stop()

     loadCollectionTask := genLoadCollectionTask(ctx, queryCoord)
     queryCoord.scheduler.Enqueue(loadCollectionTask)
@@ -793,6 +797,7 @@ func Test_reverseSealedSegmentChangeInfo(t *testing.T) {
     node2, err := startQueryNodeServer(ctx)
     assert.Nil(t, err)
     waitQueryNodeOnline(queryCoord.cluster, node2.queryNodeID)
+    defer node2.stop()

     loadSegmentTask := genLoadSegmentTask(ctx, queryCoord, node2.queryNodeID)
     parentTask := loadSegmentTask.parentTask
@@ -857,6 +862,7 @@ func Test_handoffSegmentFail(t *testing.T) {

     waitTaskFinalState(handoffTask, taskFailed)

+    node1.stop()
     queryCoord.Stop()
     err = removeAllSession()
     assert.Nil(t, err)
@@ -1048,32 +1054,54 @@ func TestLoadBalanceIndexedSegmentsAfterNodeDown(t *testing.T) {
     ctx := context.Background()
     queryCoord, err := startQueryCoord(ctx)
     assert.Nil(t, err)
+    defer queryCoord.Stop()

     node1, err := startQueryNodeServer(ctx)
     assert.Nil(t, err)
     waitQueryNodeOnline(queryCoord.cluster, node1.queryNodeID)

     loadCollectionTask := genLoadCollectionTask(ctx, queryCoord)
-
     err = queryCoord.scheduler.Enqueue(loadCollectionTask)
     assert.Nil(t, err)
     waitTaskFinalState(loadCollectionTask, taskExpired)

+    segments := queryCoord.meta.getSegmentInfosByNode(node1.queryNodeID)
+    log.Debug("get segments by node",
+        zap.Int64("nodeID", node1.queryNodeID),
+        zap.Any("segments", segments))
+
+    rootCoord := queryCoord.rootCoordClient.(*rootCoordMock)
+    rootCoord.enableIndex = true
+
+    node1.stop()
     node2, err := startQueryNodeServer(ctx)
     assert.Nil(t, err)
     waitQueryNodeOnline(queryCoord.cluster, node2.queryNodeID)
+    defer node2.stop()

-    rootCoord := queryCoord.rootCoordClient.(*rootCoordMock)
-    rootCoord.enableIndex = true
-
     removeNodeSession(node1.queryNodeID)
     for {
-        if len(queryCoord.meta.getSegmentInfosByNode(node1.queryNodeID)) == 0 {
+        segments := queryCoord.meta.getSegmentInfosByNode(node1.queryNodeID)
+        if len(segments) == 0 {
             break
         }
+        log.Debug("node still has segments",
+            zap.Int64("nodeID", node1.queryNodeID))
+        time.Sleep(200 * time.Millisecond)
+    }
+
+    for {
+        segments := queryCoord.meta.getSegmentInfosByNode(node2.queryNodeID)
+        if len(segments) != 0 {
+            log.Debug("get segments by node",
+                zap.Int64("nodeID", node2.queryNodeID),
+                zap.Any("segments", segments))
+            break
+        }
+        log.Debug("node hasn't segments",
+            zap.Int64("nodeID", node2.queryNodeID))
+        time.Sleep(200 * time.Millisecond)
     }
-    node2.stop()
-    queryCoord.Stop()
     err = removeAllSession()
     assert.Nil(t, err)
 }
diff --git a/internal/querycoord/util.go b/internal/querycoord/util.go
index c90529c215..3c3b5f2679 100644
--- a/internal/querycoord/util.go
+++ b/internal/querycoord/util.go
@@ -199,3 +199,11 @@ func syncReplicaSegments(ctx context.Context, cluster Cluster, childTasks []task

     return nil
 }
+
+func removeFromSlice(origin []UniqueID, del ...UniqueID) []UniqueID {
+    set := make(typeutil.UniqueSet, len(origin))
+    set.Insert(origin...)
+    set.Remove(del...)
+
+    return set.Collect()
+}
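
The core of the fix is the per-segment bookkeeping done in globalPostExecute above: drop every offline node from the segment's NodeIds and re-elect NodeID from the survivors, falling back to -1 when no serving node remains. The following is a minimal standalone sketch of that rule in plain Go; SegmentInfo here is a cut-down stand-in for querypb.SegmentInfo, and pruneOfflineNodes is an illustrative helper name, not part of the patch. The removeFromSlice here keeps the original order for readability, whereas the real helper in util.go round-trips through typeutil.UniqueSet, so the order of the surviving NodeIds is not guaranteed.

package main

import "fmt"

type UniqueID = int64

// SegmentInfo is a simplified stand-in for querypb.SegmentInfo (assumption for the sketch).
type SegmentInfo struct {
    SegmentID UniqueID
    NodeID    UniqueID
    NodeIds   []UniqueID
}

// removeFromSlice returns origin without any of the del values.
// Unlike the set-based helper in util.go, it preserves the original order.
func removeFromSlice(origin []UniqueID, del ...UniqueID) []UniqueID {
    drop := make(map[UniqueID]struct{}, len(del))
    for _, id := range del {
        drop[id] = struct{}{}
    }
    kept := make([]UniqueID, 0, len(origin))
    for _, id := range origin {
        if _, ok := drop[id]; !ok {
            kept = append(kept, id)
        }
    }
    return kept
}

// pruneOfflineNodes applies the same rule as the goroutine in globalPostExecute:
// remove the offline nodes from NodeIds, then point NodeID at the first surviving
// node, or -1 when the segment has no serving node left.
func pruneOfflineNodes(segment *SegmentInfo, offlineNodeIDs ...UniqueID) {
    segment.NodeID = -1
    segment.NodeIds = removeFromSlice(segment.NodeIds, offlineNodeIDs...)
    if len(segment.NodeIds) > 0 {
        segment.NodeID = segment.NodeIds[0]
    }
}

func main() {
    seg := &SegmentInfo{SegmentID: 100, NodeID: 1, NodeIds: []UniqueID{1, 2}}
    pruneOfflineNodes(seg, 1)            // node 1 went down and was balanced away
    fmt.Println(seg.NodeID, seg.NodeIds) // 2 [2]
}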