From 152d55dbfa9b546e598723b26b4efd8eaacab3ab Mon Sep 17 00:00:00 2001 From: bigsheeper Date: Wed, 17 Nov 2021 23:37:11 +0800 Subject: [PATCH] Fix the deadlock of the querycoord cluster (#11796) Signed-off-by: xige-16 Co-authored-by: xige-16 --- internal/querycoord/cluster.go | 118 +++++++++++------- .../querycoord/mock_querynode_server_test.go | 6 + internal/querynode/collection.go | 9 ++ internal/querynode/task_queue.go | 12 +- 4 files changed, 99 insertions(+), 46 deletions(-) diff --git a/internal/querycoord/cluster.go b/internal/querycoord/cluster.go index 9e46cdc107..9e2bfe76ae 100644 --- a/internal/querycoord/cluster.go +++ b/internal/querycoord/cluster.go @@ -249,11 +249,15 @@ func (c *queryNodeCluster) getComponentInfos(ctx context.Context) ([]*internalpb } func (c *queryNodeCluster) loadSegments(ctx context.Context, nodeID int64, in *querypb.LoadSegmentsRequest) error { - c.Lock() - defer c.Unlock() - + c.RLock() + var targetNode Node if node, ok := c.nodes[nodeID]; ok { - err := node.loadSegments(ctx, in) + targetNode = node + } + c.RUnlock() + + if targetNode != nil { + err := targetNode.loadSegments(ctx, in) if err != nil { log.Debug("loadSegments: queryNode load segments error", zap.Int64("nodeID", nodeID), zap.String("error info", err.Error())) return err @@ -265,15 +269,19 @@ func (c *queryNodeCluster) loadSegments(ctx context.Context, nodeID int64, in *q } func (c *queryNodeCluster) releaseSegments(ctx context.Context, nodeID int64, in *querypb.ReleaseSegmentsRequest) error { - c.Lock() - defer c.Unlock() - + c.RLock() + var targetNode Node if node, ok := c.nodes[nodeID]; ok { - if !node.isOnline() { + targetNode = node + } + c.RUnlock() + + if targetNode != nil { + if !targetNode.isOnline() { return errors.New("node offline") } - err := node.releaseSegments(ctx, in) + err := targetNode.releaseSegments(ctx, in) if err != nil { log.Debug("releaseSegments: queryNode release segments error", zap.Int64("nodeID", nodeID), zap.String("error info", err.Error())) return err @@ -286,11 +294,15 @@ func (c *queryNodeCluster) releaseSegments(ctx context.Context, nodeID int64, in } func (c *queryNodeCluster) watchDmChannels(ctx context.Context, nodeID int64, in *querypb.WatchDmChannelsRequest) error { - c.Lock() - defer c.Unlock() - + c.RLock() + var targetNode Node if node, ok := c.nodes[nodeID]; ok { - err := node.watchDmChannels(ctx, in) + targetNode = node + } + c.RUnlock() + + if targetNode != nil { + err := targetNode.watchDmChannels(ctx, in) if err != nil { log.Debug("watchDmChannels: queryNode watch dm channel error", zap.String("error", err.Error())) return err @@ -313,13 +325,17 @@ func (c *queryNodeCluster) watchDmChannels(ctx context.Context, nodeID int64, in } func (c *queryNodeCluster) watchDeltaChannels(ctx context.Context, nodeID int64, in *querypb.WatchDeltaChannelsRequest) error { - c.Lock() - defer c.Unlock() - + c.RLock() + var targetNode Node if node, ok := c.nodes[nodeID]; ok { - err := node.watchDeltaChannels(ctx, in) + targetNode = node + } + c.RUnlock() + + if targetNode != nil { + err := targetNode.watchDeltaChannels(ctx, in) if err != nil { - log.Debug("watchDeltaChannels: queryNode watch dm channel error", zap.String("error", err.Error())) + log.Debug("watchDeltaChannels: queryNode watch delta channel error", zap.String("error", err.Error())) return err } err = c.clusterMeta.setDeltaChannel(in.CollectionID, in.Infos) @@ -330,28 +346,34 @@ func (c *queryNodeCluster) watchDeltaChannels(ctx context.Context, nodeID int64, return nil } - return 
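The cluster.go hunks above all apply the same fix: stop holding the queryNodeCluster mutex across a blocking QueryNode RPC. Under the old code, c.Lock() stayed held for the full duration of loadSegments, watchDmChannels, and the rest, so every other goroutine that needed the cluster lock (including ones the in-flight call could be waiting on) wedged behind a slow or stuck RPC. Below is a minimal, self-contained sketch of the copy-then-call pattern the patch adopts; fakeNode and its Sleep are hypothetical stand-ins for the real gRPC client, and the method signatures are trimmed.

package main

import (
	"context"
	"fmt"
	"sync"
	"time"
)

type Node interface {
	loadSegments(ctx context.Context) error
}

// fakeNode is a hypothetical stand-in for a connected QueryNode client.
type fakeNode struct{ id int64 }

// loadSegments simulates a slow RPC; under the old code the cluster
// mutex stayed held for this entire call.
func (n *fakeNode) loadSegments(ctx context.Context) error {
	time.Sleep(100 * time.Millisecond)
	return nil
}

type queryNodeCluster struct {
	sync.RWMutex
	nodes map[int64]Node
}

func (c *queryNodeCluster) loadSegments(ctx context.Context, nodeID int64) error {
	// Take the read lock only long enough to look up the node, then release
	// it before the blocking call, so writers are never queued behind an RPC.
	c.RLock()
	var targetNode Node
	if node, ok := c.nodes[nodeID]; ok {
		targetNode = node
	}
	c.RUnlock()

	if targetNode != nil {
		return targetNode.loadSegments(ctx)
	}
	return fmt.Errorf("loadSegments: can't find query node by nodeID, nodeID = %d", nodeID)
}

func main() {
	c := &queryNodeCluster{nodes: map[int64]Node{1: &fakeNode{id: 1}}}
	if err := c.loadSegments(context.Background(), 1); err != nil {
		fmt.Println(err)
	}
	fmt.Println("RPC finished without holding the cluster lock")
}

The same shape repeats in releaseSegments, watchDmChannels, watchDeltaChannels, addQueryChannel, removeQueryChannel, releaseCollection, and releasePartitions: look up under RLock, unlock, then call.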
errors.New("watchDeltaChannels: Can't find query node by nodeID ") + + return fmt.Errorf("watchDeltaChannels: Can't find query node by nodeID, nodeID = %d", nodeID) } func (c *queryNodeCluster) hasWatchedDeltaChannel(ctx context.Context, nodeID int64, collectionID UniqueID) bool { - c.Lock() - defer c.Unlock() + c.RLock() + defer c.RUnlock() return c.nodes[nodeID].hasWatchedDeltaChannel(collectionID) } func (c *queryNodeCluster) hasWatchedQueryChannel(ctx context.Context, nodeID int64, collectionID UniqueID) bool { - c.Lock() - defer c.Unlock() + c.RLock() + defer c.RUnlock() return c.nodes[nodeID].hasWatchedQueryChannel(collectionID) } func (c *queryNodeCluster) addQueryChannel(ctx context.Context, nodeID int64, in *querypb.AddQueryChannelRequest) error { - c.Lock() - defer c.Unlock() + c.RLock() + var targetNode Node if node, ok := c.nodes[nodeID]; ok { - err := node.addQueryChannel(ctx, in) + targetNode = node + } + c.RUnlock() + + if targetNode != nil { + err := targetNode.addQueryChannel(ctx, in) if err != nil { log.Debug("addQueryChannel: queryNode add query channel error", zap.String("error", err.Error())) return err @@ -362,11 +384,15 @@ func (c *queryNodeCluster) addQueryChannel(ctx context.Context, nodeID int64, in return fmt.Errorf("addQueryChannel: can't find query node by nodeID, nodeID = %d", nodeID) } func (c *queryNodeCluster) removeQueryChannel(ctx context.Context, nodeID int64, in *querypb.RemoveQueryChannelRequest) error { - c.Lock() - defer c.Unlock() - + c.RLock() + var targetNode Node if node, ok := c.nodes[nodeID]; ok { - err := node.removeQueryChannel(ctx, in) + targetNode = node + } + c.RUnlock() + + if targetNode != nil { + err := targetNode.removeQueryChannel(ctx, in) if err != nil { log.Debug("removeQueryChannel: queryNode remove query channel error", zap.String("error", err.Error())) return err @@ -379,11 +405,15 @@ func (c *queryNodeCluster) removeQueryChannel(ctx context.Context, nodeID int64, } func (c *queryNodeCluster) releaseCollection(ctx context.Context, nodeID int64, in *querypb.ReleaseCollectionRequest) error { - c.Lock() - defer c.Unlock() - + c.RLock() + var targetNode Node if node, ok := c.nodes[nodeID]; ok { - err := node.releaseCollection(ctx, in) + targetNode = node + } + c.RUnlock() + + if targetNode != nil { + err := targetNode.releaseCollection(ctx, in) if err != nil { log.Debug("ReleaseCollection: queryNode release collection error", zap.String("error", err.Error())) return err @@ -400,11 +430,15 @@ func (c *queryNodeCluster) releaseCollection(ctx context.Context, nodeID int64, } func (c *queryNodeCluster) releasePartitions(ctx context.Context, nodeID int64, in *querypb.ReleasePartitionsRequest) error { - c.Lock() - defer c.Unlock() - + c.RLock() + var targetNode Node if node, ok := c.nodes[nodeID]; ok { - err := node.releasePartitions(ctx, in) + targetNode = node + } + c.RUnlock() + + if targetNode != nil { + err := targetNode.releasePartitions(ctx, in) if err != nil { log.Debug("ReleasePartitions: queryNode release partitions error", zap.String("error", err.Error())) return err @@ -621,8 +655,8 @@ func (c *queryNodeCluster) removeNodeInfo(nodeID int64) error { } func (c *queryNodeCluster) stopNode(nodeID int64) { - c.Lock() - defer c.Unlock() + c.RLock() + defer c.RUnlock() if node, ok := c.nodes[nodeID]; ok { node.stop() @@ -684,8 +718,8 @@ func (c *queryNodeCluster) getOfflineNodes() (map[int64]Node, error) { } func (c *queryNodeCluster) isOnline(nodeID int64) (bool, error) { - c.Lock() - defer c.Unlock() + c.RLock() + defer c.RUnlock() 
if node, ok := c.nodes[nodeID]; ok { return node.isOnline(), nil diff --git a/internal/querycoord/mock_querynode_server_test.go b/internal/querycoord/mock_querynode_server_test.go index e595e7197a..cc30ea4add 100644 --- a/internal/querycoord/mock_querynode_server_test.go +++ b/internal/querycoord/mock_querynode_server_test.go @@ -21,6 +21,7 @@ import ( "errors" "net" "strconv" + "sync" "go.uber.org/zap" "google.golang.org/grpc" @@ -67,6 +68,7 @@ type queryNodeServerMock struct { getMetrics func() (*milvuspb.GetMetricsResponse, error) segmentInfos map[UniqueID]*querypb.SegmentInfo + segmentMu sync.RWMutex totalMem uint64 } @@ -212,7 +214,9 @@ func (qs *queryNodeServerMock) LoadSegments(ctx context.Context, req *querypb.Lo MemSize: info.NumOfRows * int64(sizePerRecord), NumRows: info.NumOfRows, } + qs.segmentMu.Lock() qs.segmentInfos[info.SegmentID] = segmentInfo + qs.segmentMu.Unlock() } return qs.loadSegment() @@ -232,11 +236,13 @@ func (qs *queryNodeServerMock) ReleaseSegments(ctx context.Context, req *querypb func (qs *queryNodeServerMock) GetSegmentInfo(ctx context.Context, req *querypb.GetSegmentInfoRequest) (*querypb.GetSegmentInfoResponse, error) { segmentInfos := make([]*querypb.SegmentInfo, 0) + qs.segmentMu.RLock() for _, info := range qs.segmentInfos { if info.CollectionID == req.CollectionID && info.NodeID == qs.queryNodeID { segmentInfos = append(segmentInfos, info) } } + qs.segmentMu.RUnlock() res, err := qs.getSegmentInfos() if err == nil { diff --git a/internal/querynode/collection.go b/internal/querynode/collection.go index e29b74c8b8..c0b3c81120 100644 --- a/internal/querynode/collection.go +++ b/internal/querynode/collection.go @@ -42,6 +42,7 @@ type Collection struct { id UniqueID partitionIDs []UniqueID schema *schemapb.CollectionSchema + channelMu sync.RWMutex vChannels []Channel pChannels []Channel @@ -88,6 +89,8 @@ func (c *Collection) removePartitionID(partitionID UniqueID) { // addVChannels add virtual channels to collection func (c *Collection) addVChannels(channels []Channel) { + c.channelMu.Lock() + defer c.channelMu.Unlock() OUTER: for _, dstChan := range channels { for _, srcChan := range c.vChannels { @@ -109,11 +112,15 @@ OUTER: // getVChannels get virtual channels of collection func (c *Collection) getVChannels() []Channel { + c.channelMu.RLock() + defer c.channelMu.RUnlock() return c.vChannels } // addPChannels add physical channels to physical channels of collection func (c *Collection) addPChannels(channels []Channel) { + c.channelMu.Lock() + defer c.channelMu.Unlock() OUTER: for _, dstChan := range channels { for _, srcChan := range c.pChannels { @@ -135,6 +142,8 @@ OUTER: // getPChannels get physical channels of collection func (c *Collection) getPChannels() []Channel { + c.channelMu.RLock() + defer c.channelMu.RUnlock() return c.pChannels } diff --git a/internal/querynode/task_queue.go b/internal/querynode/task_queue.go index 1b0827fa54..a7d046babf 100644 --- a/internal/querynode/task_queue.go +++ b/internal/querynode/task_queue.go @@ -35,7 +35,7 @@ type taskQueue interface { } type baseTaskQueue struct { - utMu sync.Mutex // guards unissuedTasks + utMu sync.RWMutex // guards unissuedTasks unissuedTasks *list.List atMu sync.Mutex // guards activeTasks @@ -58,21 +58,25 @@ func (queue *baseTaskQueue) utChan() <-chan int { } func (queue *baseTaskQueue) utEmpty() bool { + queue.utMu.RLock() + defer queue.utMu.RUnlock() return queue.unissuedTasks.Len() == 0 } func (queue *baseTaskQueue) utFull() bool { + queue.utMu.RLock() + defer queue.utMu.RUnlock() 
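The remaining files close ordinary data races rather than deadlocks: the mock QueryNode guards its segmentInfos map with the new segmentMu, and Collection guards its vChannels/pChannels slices with channelMu, since a Go map or slice written by one goroutine while another reads it is a race. A stripped-down sketch of that write-under-Lock, read-under-RLock discipline; segmentInfo is reduced to an ID field here, the real type is querypb.SegmentInfo.

package main

import (
	"fmt"
	"sync"
)

// segmentInfo is trimmed for illustration; the mock stores *querypb.SegmentInfo.
type segmentInfo struct{ SegmentID int64 }

type queryNodeServerMock struct {
	segmentInfos map[int64]*segmentInfo
	segmentMu    sync.RWMutex
}

// loadSegment is the write side: exclusive lock while mutating the map.
func (qs *queryNodeServerMock) loadSegment(id int64) {
	qs.segmentMu.Lock()
	qs.segmentInfos[id] = &segmentInfo{SegmentID: id}
	qs.segmentMu.Unlock()
}

// segmentList is the read side: shared lock, and it copies matching entries
// into a fresh slice so callers never touch the map after the lock drops.
func (qs *queryNodeServerMock) segmentList() []*segmentInfo {
	qs.segmentMu.RLock()
	defer qs.segmentMu.RUnlock()
	infos := make([]*segmentInfo, 0, len(qs.segmentInfos))
	for _, info := range qs.segmentInfos {
		infos = append(infos, info)
	}
	return infos
}

func main() {
	qs := &queryNodeServerMock{segmentInfos: make(map[int64]*segmentInfo)}
	var wg sync.WaitGroup
	for i := int64(0); i < 100; i++ {
		wg.Add(1)
		go func(id int64) { defer wg.Done(); qs.loadSegment(id) }(i)
	}
	wg.Wait()
	fmt.Printf("stored %d segments without tripping the race detector\n", len(qs.segmentList()))
}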
return int64(queue.unissuedTasks.Len()) >= queue.maxTaskNum } func (queue *baseTaskQueue) addUnissuedTask(t task) error { - queue.utMu.Lock() - defer queue.utMu.Unlock() - if queue.utFull() { return errors.New("task queue is full") } + queue.utMu.Lock() + defer queue.utMu.Unlock() + if queue.unissuedTasks.Len() <= 0 { queue.unissuedTasks.PushBack(t) queue.utBufChan <- 1
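The task_queue.go change is subtler. Before the patch, utEmpty and utFull took no lock at all, so addUnissuedTask could safely call utFull while holding utMu. Once utFull acquires utMu.RLock itself, that call order would self-deadlock, because Go's sync.RWMutex is not reentrant: an RLock blocks while the same goroutine already holds Lock. Hence the capacity check now runs before the write lock is taken. A sketch of the fixed ordering; maxTaskNum's value and the task type are simplified here.

package main

import (
	"container/list"
	"errors"
	"fmt"
	"sync"
)

type baseTaskQueue struct {
	utMu          sync.RWMutex // guards unissuedTasks
	unissuedTasks *list.List
	maxTaskNum    int64
}

func (q *baseTaskQueue) utFull() bool {
	q.utMu.RLock()
	defer q.utMu.RUnlock()
	return int64(q.unissuedTasks.Len()) >= q.maxTaskNum
}

func (q *baseTaskQueue) addUnissuedTask(t interface{}) error {
	// Fixed ordering: the RLock-taking check runs before Lock().
	if q.utFull() {
		return errors.New("task queue is full")
	}
	q.utMu.Lock()
	defer q.utMu.Unlock()
	// Calling q.utFull() here instead would block forever: RLock cannot be
	// acquired while this goroutine holds the write lock.
	q.unissuedTasks.PushBack(t)
	return nil
}

func main() {
	q := &baseTaskQueue{unissuedTasks: list.New(), maxTaskNum: 2}
	for i := 0; i < 3; i++ {
		if err := q.addUnissuedTask(i); err != nil {
			fmt.Println("task", i, "rejected:", err)
			continue
		}
		fmt.Println("task", i, "queued")
	}
}

One consequence worth noting: with the check outside the write lock, two goroutines can both pass utFull and briefly push the queue one past maxTaskNum. The patch appears to accept that small window as the price of avoiding the reentrant-lock deadlock.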