From 41d9ab3d78bc46d031fedaf85119090d79f4ae23 Mon Sep 17 00:00:00 2001
From: Xiaofan <83447078+xiaofan-luan@users.noreply.github.com>
Date: Mon, 17 Apr 2023 01:32:33 -0700
Subject: [PATCH] Fix Dead lock in shard manager (#23446)

Signed-off-by: xiaofan-luan
---
 internal/querynode/distribution.go       | 81 +++++++++---------------
 internal/querynode/distribution_test.go  | 70 ++++++++++++--------
 internal/querynode/impl_utils_test.go    | 16 ++---
 internal/querynode/shard_cluster.go      | 76 +++++++++++-----------
 internal/querynode/shard_cluster_test.go |  2 +-
 internal/querynode/snapshot.go           |  3 +-
 6 files changed, 122 insertions(+), 126 deletions(-)

diff --git a/internal/querynode/distribution.go b/internal/querynode/distribution.go
index 431d214094..6bfd8e18b6 100644
--- a/internal/querynode/distribution.go
+++ b/internal/querynode/distribution.go
@@ -19,6 +19,7 @@ package querynode
 import (
     "context"
     "sync"
+    "time"

     "github.com/milvus-io/milvus/internal/log"
     "github.com/milvus-io/milvus/internal/util/typeutil"
@@ -42,9 +43,6 @@ type distribution struct {
     // version indicator
     version int64

-    // offline is the quick healthy check indicator for offline segments
-    offlines *atomic.Int32
-
     snapshots *typeutil.ConcurrentMap[int64, *snapshot]
     // current is the snapshot for quick usage for search/query
     // generated for each change of distribution
@@ -77,7 +75,6 @@ func NewDistribution(replicaID int64) *distribution {
         replicaID:      replicaID,
         sealedSegments: make(map[UniqueID]SegmentEntry),
         snapshots:      typeutil.NewConcurrentMap[int64, *snapshot](),
-        offlines:       atomic.NewInt32(0),
         current:        atomic.NewPointer[snapshot](nil),
     }

@@ -89,16 +86,29 @@ func (d *distribution) getLogger() *log.MLogger {
     return log.Ctx(context.Background()).With(zap.Int64("replicaID", d.replicaID))
 }

-// Serviceable returns whether all segment recorded is in loaded state.
 func (d *distribution) Serviceable() bool {
-    return d.offlines.Load() == 0
+    d.mut.RLock()
+    defer d.mut.RUnlock()
+    return d.serviceableImpl()
+}
+
+// Serviceable returns whether all segment recorded is in loaded state, hold d.mut before call it
+func (d *distribution) serviceableImpl() bool {
+    for _, entry := range d.sealedSegments {
+        if entry.State != segmentStateLoaded {
+            return false
+        }
+    }
+    return true
 }

 // GetCurrent returns current snapshot.
 func (d *distribution) GetCurrent(partitions ...int64) (sealed []SnapshotItem, version int64) {
     d.mut.RLock()
     defer d.mut.RUnlock()
-
+    if !d.serviceableImpl() {
+        return nil, -1
+    }
     current := d.current.Load()
     sealed = current.Get(partitions...)
     version = current.version
@@ -142,14 +152,15 @@ func (d *distribution) UpdateDistribution(entries ...SegmentEntry) {
     for _, entry := range entries {
         old, ok := d.sealedSegments[entry.SegmentID]
+        d.getLogger().Info("Update distribution", zap.Int64("segmentID", entry.SegmentID),
+            zap.Int64("node", entry.NodeID),
+            zap.Bool("segment exist", ok))
         if !ok {
             d.sealedSegments[entry.SegmentID] = entry
-            if entry.State == segmentStateOffline {
-                d.offlines.Add(1)
-            }
             continue
         }
-        d.updateSegment(old, entry)
+        old.Update(entry)
+        d.sealedSegments[old.SegmentID] = old
     }

     d.genSnapshot()
@@ -160,43 +171,14 @@ func (d *distribution) NodeDown(nodeID int64) {
     d.mut.Lock()
     defer d.mut.Unlock()

-    var delta int32
-
+    d.getLogger().Info("handle node down", zap.Int64("node", nodeID))
     for _, entry := range d.sealedSegments {
         if entry.NodeID == nodeID && entry.State != segmentStateOffline {
             entry.State = segmentStateOffline
             d.sealedSegments[entry.SegmentID] = entry
-            delta++
+            d.getLogger().Info("update the segment to offline since nodeDown", zap.Int64("nodeID", nodeID), zap.Int64("segmentID", entry.SegmentID))
         }
     }
-
-    if delta != 0 {
-        d.offlines.Add(delta)
-        d.getLogger().Info("distribution updated since nodeDown", zap.Int32("delta", delta), zap.Int32("offlines", d.offlines.Load()), zap.Int64("nodeID", nodeID))
-    }
-}
-
-// updateSegment update segment entry value and offline segment number based on old/new state.
-func (d *distribution) updateSegment(old, new SegmentEntry) {
-    delta := int32(0)
-    switch {
-    case old.State != segmentStateLoaded && new.State == segmentStateLoaded:
-        delta = -1
-    case old.State == segmentStateLoaded && new.State != segmentStateLoaded:
-        delta = 1
-    }
-
-    old.Update(new)
-    d.sealedSegments[old.SegmentID] = old
-    if delta != 0 {
-        d.offlines.Add(delta)
-        d.getLogger().Info("distribution updated since segment update",
-            zap.Int32("delta", delta),
-            zap.Int32("offlines", d.offlines.Load()),
-            zap.Int64("segmentID", new.SegmentID),
-            zap.Int32("state", int32(new.State)),
-        )
-    }
 }

 // RemoveDistributions remove segments distributions and returns the clear signal channel,
@@ -204,32 +186,28 @@ func (d *distribution) updateSegment(old, new SegmentEntry) {
 func (d *distribution) RemoveDistributions(releaseFn func(), sealedSegments ...SegmentEntry) {
     d.mut.Lock()
     defer d.mut.Unlock()
-
-    var delta int32
     for _, sealed := range sealedSegments {
         entry, ok := d.sealedSegments[sealed.SegmentID]
+        d.getLogger().Info("Remove distribution", zap.Int64("segmentID", sealed.SegmentID),
+            zap.Int64("node", sealed.NodeID),
+            zap.Bool("segment exist", ok))
         if !ok {
             continue
         }
         if entry.NodeID == sealed.NodeID || sealed.NodeID == wildcardNodeID {
-            if entry.State == segmentStateOffline {
-                delta--
-            }
             delete(d.sealedSegments, sealed.SegmentID)
         }
     }
-
-    d.offlines.Add(delta)
-
+    ts := time.Now()
     <-d.genSnapshot()
     releaseFn()
+    d.getLogger().Info("successfully remove distribution", zap.Any("segments", sealedSegments), zap.Duration("time", time.Since(ts)))
 }

 // getSnapshot converts current distribution to snapshot format.
 // in which, user could juse found nodeID=>segmentID list.
 // mutex RLock is required before calling this method.
 func (d *distribution) genSnapshot() chan struct{} {
-
     nodeSegments := make(map[int64][]SegmentEntry)
     for _, entry := range d.sealedSegments {
         nodeSegments[entry.NodeID] = append(nodeSegments[entry.NodeID], entry)
@@ -260,6 +238,7 @@ func (d *distribution) genSnapshot() chan struct{} {
         return ch
     }

+    d.getLogger().Info("gen snapshot for version", zap.Any("version", d.version), zap.Any("is serviceable", d.serviceableImpl()))
     last.Expire(d.getCleanup(last.version))

     return last.cleared
diff --git a/internal/querynode/distribution_test.go b/internal/querynode/distribution_test.go
index 3a274ad248..b9a5d678fb 100644
--- a/internal/querynode/distribution_test.go
+++ b/internal/querynode/distribution_test.go
@@ -50,10 +50,12 @@ func (s *DistributionSuite) TestAddDistribution() {
                 {
                     NodeID:    1,
                     SegmentID: 1,
+                    State:     segmentStateLoaded,
                 },
                 {
                     NodeID:    1,
                     SegmentID: 2,
+                    State:     segmentStateLoaded,
                 },
             },
             expected: []SnapshotItem{
@@ -63,10 +65,12 @@ func (s *DistributionSuite) TestAddDistribution() {
                         {
                             NodeID:    1,
                             SegmentID: 1,
+                            State:     segmentStateLoaded,
                         },
                         {
                             NodeID:    1,
                             SegmentID: 2,
+                            State:     segmentStateLoaded,
                         },
                     },
                 },
@@ -78,14 +82,17 @@ func (s *DistributionSuite) TestAddDistribution() {
                 {
                     NodeID:    1,
                     SegmentID: 1,
+                    State:     segmentStateLoaded,
                 },
                 {
                     NodeID:    2,
                     SegmentID: 2,
+                    State:     segmentStateLoaded,
                 },
                 {
                     NodeID:    1,
                     SegmentID: 3,
+                    State:     segmentStateLoaded,
                 },
             },
             expected: []SnapshotItem{
@@ -95,11 +102,13 @@ func (s *DistributionSuite) TestAddDistribution() {
                         {
                             NodeID:    1,
                             SegmentID: 1,
+                            State:     segmentStateLoaded,
                         },
                         {
                             NodeID:    1,
                             SegmentID: 3,
+                            State:     segmentStateLoaded,
                         },
                     },
                 },
@@ -109,6 +118,7 @@ func (s *DistributionSuite) TestAddDistribution() {
                         {
                             NodeID:    2,
                             SegmentID: 2,
+                            State:     segmentStateLoaded,
                         },
                     },
                 },
@@ -161,13 +171,13 @@ func (s *DistributionSuite) TestRemoveDistribution() {
         {
             tag: "remove with no read",
             presetSealed: []SegmentEntry{
-                {NodeID: 1, SegmentID: 1},
-                {NodeID: 2, SegmentID: 2},
-                {NodeID: 1, SegmentID: 3},
+                {NodeID: 1, SegmentID: 1, State: segmentStateLoaded},
+                {NodeID: 2, SegmentID: 2, State: segmentStateLoaded},
+                {NodeID: 1, SegmentID: 3, State: segmentStateLoaded},
             },
             removalSealed: []SegmentEntry{
-                {NodeID: 1, SegmentID: 1},
+                {NodeID: 1, SegmentID: 1, State: segmentStateLoaded},
             },
             withMockRead: false,
             expected: []SnapshotItem{
                 {
                     NodeID: 1,
                     Segments: []SegmentEntry{
-                        {NodeID: 1, SegmentID: 3},
+                        {NodeID: 1, SegmentID: 3, State: segmentStateLoaded},
                     },
                 },
                 {
                     NodeID: 2,
                     Segments: []SegmentEntry{
-                        {NodeID: 2, SegmentID: 2},
+                        {NodeID: 2, SegmentID: 2, State: segmentStateLoaded},
                     },
                 },
             },
         },
         {
             tag: "remove with wrong nodeID",
             presetSealed: []SegmentEntry{
-                {NodeID: 1, SegmentID: 1},
-                {NodeID: 2, SegmentID: 2},
-                {NodeID: 1, SegmentID: 3},
+                {NodeID: 1, SegmentID: 1, State: segmentStateLoaded},
+                {NodeID: 2, SegmentID: 2, State: segmentStateLoaded},
+                {NodeID: 1, SegmentID: 3, State: segmentStateLoaded},
             },
             removalSealed: []SegmentEntry{
-                {NodeID: 2, SegmentID: 1},
+                {NodeID: 2, SegmentID: 1, State: segmentStateLoaded},
             },
             withMockRead: false,
             expected: []SnapshotItem{
                 {
                     NodeID: 1,
                     Segments: []SegmentEntry{
-                        {NodeID: 1, SegmentID: 1},
-                        {NodeID: 1, SegmentID: 3},
+                        {NodeID: 1, SegmentID: 1, State: segmentStateLoaded},
+                        {NodeID: 1, SegmentID: 3, State: segmentStateLoaded},
                     },
                 },
                 {
                     NodeID: 2,
                     Segments: []SegmentEntry{
-                        {NodeID: 2, SegmentID: 2},
+                        {NodeID: 2, SegmentID: 2, State: segmentStateLoaded},
                     },
                 },
             },
@@ -220,13 +230,13 @@
         {
             tag: "remove with wildcardNodeID",
             presetSealed: []SegmentEntry{
-                {NodeID: 1, SegmentID: 1},
-                {NodeID: 2, SegmentID: 2},
-                {NodeID: 1, SegmentID: 3},
+                {NodeID: 1, SegmentID: 1, State: segmentStateLoaded},
+                {NodeID: 2, SegmentID: 2, State: segmentStateLoaded},
+                {NodeID: 1, SegmentID: 3, State: segmentStateLoaded},
             },
             removalSealed: []SegmentEntry{
-                {NodeID: wildcardNodeID, SegmentID: 1},
+                {NodeID: wildcardNodeID, SegmentID: 1, State: segmentStateLoaded},
             },
             withMockRead: false,
             expected: []SnapshotItem{
                 {
                     NodeID: 1,
                     Segments: []SegmentEntry{
-                        {NodeID: 1, SegmentID: 3},
+                        {NodeID: 1, SegmentID: 3, State: segmentStateLoaded},
                     },
                 },
                 {
                     NodeID: 2,
                     Segments: []SegmentEntry{
-                        {NodeID: 2, SegmentID: 2},
+                        {NodeID: 2, SegmentID: 2, State: segmentStateLoaded},
                     },
                 },
             },
@@ -249,13 +259,13 @@
         {
             tag: "remove with read",
             presetSealed: []SegmentEntry{
-                {NodeID: 1, SegmentID: 1},
-                {NodeID: 2, SegmentID: 2},
-                {NodeID: 1, SegmentID: 3},
+                {NodeID: 1, SegmentID: 1, State: segmentStateLoaded},
+                {NodeID: 2, SegmentID: 2, State: segmentStateLoaded},
+                {NodeID: 1, SegmentID: 3, State: segmentStateLoaded},
            },
             removalSealed: []SegmentEntry{
-                {NodeID: 1, SegmentID: 1},
+                {NodeID: 1, SegmentID: 1, State: segmentStateLoaded},
             },
             withMockRead: true,
             expected: []SnapshotItem{
                 {
                     NodeID: 1,
                     Segments: []SegmentEntry{
-                        {NodeID: 1, SegmentID: 3},
+                        {NodeID: 1, SegmentID: 3, State: segmentStateLoaded},
                     },
                 },
                 {
                     NodeID: 2,
                     Segments: []SegmentEntry{
-                        {NodeID: 2, SegmentID: 2},
+                        {NodeID: 2, SegmentID: 2, State: segmentStateLoaded},
                     },
                 },
             },
@@ -407,10 +417,12 @@ func (s *DistributionSuite) TestPeek() {
                 {
                     NodeID:    1,
                     SegmentID: 1,
+                    State:     segmentStateLoaded,
                 },
                 {
                     NodeID:    1,
                     SegmentID: 2,
+                    State:     segmentStateLoaded,
                 },
             },
             expected: []SnapshotItem{
@@ -420,10 +432,12 @@ func (s *DistributionSuite) TestPeek() {
                         {
                             NodeID:    1,
                             SegmentID: 1,
+                            State:     segmentStateLoaded,
                         },
                         {
                             NodeID:    1,
                             SegmentID: 2,
+                            State:     segmentStateLoaded,
                         },
                     },
                 },
@@ -435,14 +449,17 @@ func (s *DistributionSuite) TestPeek() {
                 {
                     NodeID:    1,
                     SegmentID: 1,
+                    State:     segmentStateLoaded,
                 },
                 {
                     NodeID:    2,
                     SegmentID: 2,
+                    State:     segmentStateLoaded,
                 },
                 {
                     NodeID:    1,
                     SegmentID: 3,
+                    State:     segmentStateLoaded,
                 },
             },
             expected: []SnapshotItem{
@@ -452,11 +469,13 @@ func (s *DistributionSuite) TestPeek() {
                         {
                             NodeID:    1,
                             SegmentID: 1,
+                            State:     segmentStateLoaded,
                         },
                         {
                             NodeID:    1,
                             SegmentID: 3,
+                            State:     segmentStateLoaded,
                         },
                     },
                 },
@@ -466,6 +485,7 @@ func (s *DistributionSuite) TestPeek() {
                         {
                             NodeID:    2,
                             SegmentID: 2,
+                            State:     segmentStateLoaded,
                         },
                     },
                 },
diff --git a/internal/querynode/impl_utils_test.go b/internal/querynode/impl_utils_test.go
index 57c0593dc5..c493bb623f 100644
--- a/internal/querynode/impl_utils_test.go
+++ b/internal/querynode/impl_utils_test.go
@@ -128,7 +128,7 @@ func (s *ImplUtilsSuite) TestTransferLoad() {
     s.Run("transfer load fail", func() {
         cs, ok := s.querynode.ShardClusterService.getShardCluster(defaultChannelName)
         s.Require().True(ok)
-        cs.nodes[100] = &shardNode{
+        cs.nodes.InsertIfNotPresent(100, &shardNode{
             nodeID:   100,
             nodeAddr: "test",
             client: &mockShardQueryNode{
                 loadSegmentsResults: &commonpb.Status{
                     ErrorCode: commonpb.ErrorCode_UnexpectedError,
                     Reason:    "error",
                 },
             },
-        }
+        })

         status, err := s.querynode.TransferLoad(ctx, &querypb.LoadSegmentsRequest{
             Base: &commonpb.MsgBase{
                 TargetID: s.querynode.session.ServerID,
             },
@@ -161,8 +161,8 @@ func (s *ImplUtilsSuite) TestTransferLoad() {
     s.Run("insufficient memory", func() {
         cs, ok := s.querynode.ShardClusterService.getShardCluster(defaultChannelName)
         s.Require().True(ok)
-        cs.nodes[100] = &shardNode{
-            nodeID:   100,
+        cs.nodes.InsertIfNotPresent(101, &shardNode{
+            nodeID:   101,
             nodeAddr: "test",
             client: &mockShardQueryNode{
                 loadSegmentsResults: &commonpb.Status{
@@ -170,13 +170,13 @@ func (s *ImplUtilsSuite) TestTransferLoad() {
                     Reason:    "mock InsufficientMemoryToLoad",
                 },
             },
-        }
+        })

         status, err := s.querynode.TransferLoad(ctx, &querypb.LoadSegmentsRequest{
             Base: &commonpb.MsgBase{
                 TargetID: s.querynode.session.ServerID,
             },
-            DstNodeID: 100,
+            DstNodeID: 101,
             Infos: []*querypb.SegmentLoadInfo{
                 {
                     SegmentID: defaultSegmentID,
@@ -227,7 +227,7 @@ func (s *ImplUtilsSuite) TestTransferRelease() {
     s.Run("transfer release fail", func() {
         cs, ok := s.querynode.ShardClusterService.getShardCluster(defaultChannelName)
         s.Require().True(ok)
-        cs.nodes[100] = &shardNode{
+        cs.nodes.InsertIfNotPresent(100, &shardNode{
             nodeID:   100,
             nodeAddr: "test",
             client: &mockShardQueryNode{
@@ -235,7 +235,7 @@ func (s *ImplUtilsSuite) TestTransferRelease() {
                     ErrorCode: commonpb.ErrorCode_UnexpectedError,
                 },
             },
-        }
+        })

         status, err := s.querynode.TransferRelease(ctx, &querypb.ReleaseSegmentsRequest{
             Base: &commonpb.MsgBase{
diff --git a/internal/querynode/shard_cluster.go b/internal/querynode/shard_cluster.go
index e5be2529ea..87804fa35b 100644
--- a/internal/querynode/shard_cluster.go
+++ b/internal/querynode/shard_cluster.go
@@ -62,7 +62,6 @@ const (
 type segmentState int32

 const (
-    segmentStateNone    segmentState = 0
     segmentStateOffline segmentState = 1
     segmentStateLoading segmentState = 2
     segmentStateLoaded  segmentState = 3
@@ -145,10 +144,9 @@ type ShardCluster struct {
     segmentDetector ShardSegmentDetector
     nodeBuilder     ShardNodeBuilder

-    mut    sync.RWMutex
-    leader *shardNode           // shard leader node instance
-    nodes  map[int64]*shardNode // online nodes
-
+    mut    sync.RWMutex
+    leader *shardNode                                 // shard leader node instance
+    nodes  *typeutil.ConcurrentMap[int64, *shardNode] // online nodes
     mutVersion   sync.RWMutex
     distribution *distribution
@@ -170,8 +168,7 @@ func NewShardCluster(collectionID int64, replicaID int64, vchannelName string, v
         segmentDetector: segmentDetector,
         nodeDetector:    nodeDetector,
         nodeBuilder:     nodeBuilder,
-
-        nodes: make(map[int64]*shardNode),
+        nodes: typeutil.NewConcurrentMap[int64, *shardNode](),

         closeCh: make(chan struct{}),
     }

@@ -203,11 +200,10 @@ func (sc *ShardCluster) getLogger() *log.MLogger {
     )
 }

-// serviceable returns whether shard cluster could provide query service.
+// serviceable returns whether shard cluster could provide query service, used only for test
 func (sc *ShardCluster) serviceable() bool {
     sc.mutVersion.RLock()
     defer sc.mutVersion.RUnlock()
-
     return sc.distribution != nil && sc.distribution.Serviceable()
 }

@@ -218,12 +214,13 @@ func (sc *ShardCluster) addNode(evt nodeEvent) {
     sc.mut.Lock()
     defer sc.mut.Unlock()

-    oldNode, ok := sc.nodes[evt.nodeID]
+    oldNode, ok := sc.nodes.Get(evt.nodeID)
     if ok {
         if oldNode.nodeAddr == evt.nodeAddr {
             log.Warn("ShardCluster add same node, skip", zap.Int64("nodeID", evt.nodeID), zap.String("addr", evt.nodeAddr))
             return
         }
+        sc.nodes.GetAndRemove(evt.nodeID)
         defer oldNode.client.Stop()
     }

@@ -232,7 +229,7 @@ func (sc *ShardCluster) addNode(evt nodeEvent) {
         nodeAddr: evt.nodeAddr,
         client:   sc.nodeBuilder(evt.nodeID, evt.nodeAddr),
     }
-    sc.nodes[evt.nodeID] = node
+    sc.nodes.InsertIfNotPresent(evt.nodeID, node)
     if evt.isLeader {
         sc.leader = node
     }
@@ -245,15 +242,13 @@ func (sc *ShardCluster) removeNode(evt nodeEvent) {
     sc.mut.Lock()
     defer sc.mut.Unlock()

-    old, ok := sc.nodes[evt.nodeID]
+    old, ok := sc.nodes.GetAndRemove(evt.nodeID)
     if !ok {
         log.Warn("ShardCluster removeNode does not belong to it", zap.Int64("nodeID", evt.nodeID), zap.String("addr", evt.nodeAddr))
         return
     }

     defer old.client.Stop()
-    delete(sc.nodes, evt.nodeID)
-
     sc.distribution.NodeDown(evt.nodeID)
 }

@@ -451,13 +446,7 @@ func (sc *ShardCluster) watchSegments(evtCh <-chan segmentEvent) {
 // getNode returns shallow copy of shardNode
 func (sc *ShardCluster) getNode(nodeID int64) (*shardNode, bool) {
-    sc.mut.RLock()
-    defer sc.mut.RUnlock()
-    return sc.getNodeImpl(nodeID)
-}
-
-func (sc *ShardCluster) getNodeImpl(nodeID int64) (*shardNode, bool) {
-    node, ok := sc.nodes[nodeID]
+    node, ok := sc.nodes.Get(nodeID)
     if !ok {
         return nil, false
     }
@@ -485,8 +474,10 @@ func (sc *ShardCluster) getSegment(segmentID int64) (shardSegmentInfo, bool) {
 // segmentAllocations returns node to segments mappings.
 // calling this function also increases the reference count of related segments.
 func (sc *ShardCluster) segmentAllocations(partitionIDs []int64) (map[int64][]int64, int64) {
-    if !sc.serviceable() {
-        return nil, 0
+    sc.mutVersion.RLock()
+    defer sc.mutVersion.RUnlock()
+    if sc.distribution == nil {
+        return nil, -1
     }
     items, version := sc.distribution.GetCurrent(partitionIDs...)
     return lo.SliceToMap(items, func(item SnapshotItem) (int64, []int64) {
@@ -496,6 +487,9 @@ func (sc *ShardCluster) segmentAllocations(partitionIDs []int64) (map[int64][]in

 // finishUsage decreases the inUse count of provided segments
 func (sc *ShardCluster) finishUsage(versionID int64) {
+    if versionID == -1 {
+        return
+    }
     sc.distribution.FinishUsage(versionID)
 }

@@ -578,7 +572,7 @@ func (sc *ShardCluster) ReleaseSegments(ctx context.Context, req *querypb.Releas
     // requires sc.mut read lock held
     releaseFn := func() {
         // try to release segments from nodes
-        node, ok := sc.getNodeImpl(req.GetNodeID())
+        node, ok := sc.getNode(req.GetNodeID())
         if !ok {
             log.Warn("node not in cluster", zap.Int64("nodeID", req.NodeID))
             err = fmt.Errorf("node %d not in cluster ", req.NodeID)
@@ -608,9 +602,6 @@ func (sc *ShardCluster) ReleaseSegments(ctx context.Context, req *querypb.Releas
 // GetStatistics returns the statistics on the shard cluster.
 func (sc *ShardCluster) GetStatistics(ctx context.Context, req *querypb.GetStatisticsRequest, withStreaming getStatisticsWithStreaming) ([]*internalpb.GetStatisticsResponse, error) {
-    if !sc.serviceable() {
-        return nil, fmt.Errorf("ShardCluster for %s replicaID %d is not available", sc.vchannelName, sc.replicaID)
-    }
     if !funcutil.SliceContain(req.GetDmlChannels(), sc.vchannelName) {
         return nil, fmt.Errorf("ShardCluster for %s does not match request channels :%v", sc.vchannelName, req.GetDmlChannels())
     }
@@ -619,6 +610,10 @@ func (sc *ShardCluster) GetStatistics(ctx context.Context, req *querypb.GetStati
     segAllocs, versionID := sc.segmentAllocations(req.GetReq().GetPartitionIDs())
     defer sc.finishUsage(versionID)

+    if versionID == -1 {
+        return nil, fmt.Errorf("ShardCluster for %s replicaID %d is not available", sc.vchannelName, sc.replicaID)
+    }
+
     log.Debug("cluster segment distribution", zap.Int("len", len(segAllocs)))
     for nodeID, segmentIDs := range segAllocs {
         log.Debug("segments distribution", zap.Int64("nodeID", nodeID), zap.Int64s("segments", segmentIDs))
@@ -698,7 +693,16 @@ func (sc *ShardCluster) GetStatistics(ctx context.Context, req *querypb.GetStati

 // Search preforms search operation on shard cluster.
 func (sc *ShardCluster) Search(ctx context.Context, req *querypb.SearchRequest, withStreaming searchWithStreaming) ([]*internalpb.SearchResults, error) {
-    if !sc.serviceable() {
+    if !funcutil.SliceContain(req.GetDmlChannels(), sc.vchannelName) {
+        return nil, fmt.Errorf("ShardCluster for %s does not match request channels :%v", sc.vchannelName, req.GetDmlChannels())
+    }
+
+    // get node allocation and maintains the inUse reference count
+    segAllocs, versionID := sc.segmentAllocations(req.GetReq().GetPartitionIDs())
+    defer sc.finishUsage(versionID)
+
+    // not serviceable
+    if versionID == -1 {
         err := WrapErrShardNotAvailable(sc.replicaID, sc.vchannelName)
         log.Warn("failed to search on shard",
             zap.Int64("replicaID", sc.replicaID),
         )
         return nil, err
     }
-    if !funcutil.SliceContain(req.GetDmlChannels(), sc.vchannelName) {
-        return nil, fmt.Errorf("ShardCluster for %s does not match request channels :%v", sc.vchannelName, req.GetDmlChannels())
-    }
-
-    // get node allocation and maintains the inUse reference count
-    segAllocs, versionID := sc.segmentAllocations(req.GetReq().GetPartitionIDs())
-    defer sc.finishUsage(versionID)

     log.Debug("cluster segment distribution", zap.Int("len", len(segAllocs)), zap.Int64s("partitionIDs", req.GetReq().GetPartitionIDs()))
     for nodeID, segmentIDs := range segAllocs {
@@ -807,10 +804,6 @@ func (sc *ShardCluster) Search(ctx context.Context, req *querypb.SearchRequest,

 // Query performs query operation on shard cluster.
 func (sc *ShardCluster) Query(ctx context.Context, req *querypb.QueryRequest, withStreaming queryWithStreaming) ([]*internalpb.RetrieveResults, error) {
-    if !sc.serviceable() {
-        return nil, WrapErrShardNotAvailable(sc.replicaID, sc.vchannelName)
-    }
-
     // handles only the dml channel part, segment ids is dispatch by cluster itself
     if !funcutil.SliceContain(req.GetDmlChannels(), sc.vchannelName) {
         return nil, fmt.Errorf("ShardCluster for %s does not match to request channels :%v", sc.vchannelName, req.GetDmlChannels())
     }
@@ -820,6 +813,11 @@ func (sc *ShardCluster) Query(ctx context.Context, req *querypb.QueryRequest, wi
     segAllocs, versionID := sc.segmentAllocations(req.GetReq().GetPartitionIDs())
     defer sc.finishUsage(versionID)

+    // not serviceable
+    if versionID == -1 {
+        return nil, WrapErrShardNotAvailable(sc.replicaID, sc.vchannelName)
+    }
+
     // concurrent visiting nodes
     var wg sync.WaitGroup
     reqCtx, cancel := context.WithCancel(ctx)
diff --git a/internal/querynode/shard_cluster_test.go b/internal/querynode/shard_cluster_test.go
index 514445ba0c..91522d1857 100644
--- a/internal/querynode/shard_cluster_test.go
+++ b/internal/querynode/shard_cluster_test.go
@@ -2040,7 +2040,7 @@ func TestShardCluster_Version(t *testing.T) {
         defer sc.Close()

         _, v := sc.segmentAllocations(nil)
-        assert.Equal(t, int64(0), v)
+        assert.Equal(t, int64(-1), v)
     })

     t.Run("normal alloc & finish", func(t *testing.T) {
diff --git a/internal/querynode/snapshot.go b/internal/querynode/snapshot.go
index d1879bf5e6..d16d77a5b4 100644
--- a/internal/querynode/snapshot.go
+++ b/internal/querynode/snapshot.go
@@ -73,7 +73,6 @@ func (s *snapshot) Expire(cleanup snapshotCleanup) {
 // Get returns segment distributions with provided partition ids.
 func (s *snapshot) Get(partitions ...int64) []SnapshotItem {
     s.inUse.Inc()
-
     return s.filter(partitions...)
 }

@@ -120,8 +119,8 @@ func (s *snapshot) checkCleared(cleanup snapshotCleanup) {
         go func() {
             <-s.last.cleared
             s.last = nil
-            cleanup()
             close(s.cleared)
+            cleanup()
         }()
     })
 }
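
Note on the pattern (not part of the patch): instead of pre-checking a separately maintained atomic "offlines" counter, the change lets GetCurrent/segmentAllocations decide serviceability under the distribution's own lock and return a version of -1 when the shard cannot serve; finishUsage treats -1 as a no-op, and node lookups go through a ConcurrentMap so the release callback no longer re-enters the cluster mutex. The sketch below is a minimal, self-contained Go illustration of that version-sentinel idea only; the type, field, and function names (distribution, segments, inUse, GetCurrent, FinishUsage) are simplified stand-ins and not the real querynode APIs.

package main

import (
    "fmt"
    "sync"
    "sync/atomic"
)

type segmentState int32

const (
    segmentStateOffline segmentState = 1
    segmentStateLoaded  segmentState = 3
)

// distribution is a trimmed-down stand-in for the querynode type of the same
// name; it is for illustration only.
type distribution struct {
    mut      sync.RWMutex
    version  int64
    segments map[int64]segmentState // segmentID -> state
    inUse    atomic.Int64           // readers currently pinning the version
}

func newDistribution() *distribution {
    return &distribution{segments: make(map[int64]segmentState)}
}

// serviceableImpl reports whether every segment is loaded; callers hold d.mut.
func (d *distribution) serviceableImpl() bool {
    for _, state := range d.segments {
        if state != segmentStateLoaded {
            return false
        }
    }
    return true
}

// UpdateSegment registers a segment state and bumps the version (write path).
func (d *distribution) UpdateSegment(segmentID int64, state segmentState) {
    d.mut.Lock()
    defer d.mut.Unlock()
    d.segments[segmentID] = state
    d.version++
}

// GetCurrent decides serviceability and pins the snapshot under one read lock.
// A version of -1 tells the caller the shard cannot serve right now.
func (d *distribution) GetCurrent() ([]int64, int64) {
    d.mut.RLock()
    defer d.mut.RUnlock()
    if !d.serviceableImpl() {
        return nil, -1
    }
    ids := make([]int64, 0, len(d.segments))
    for id := range d.segments {
        ids = append(ids, id)
    }
    d.inUse.Add(1)
    return ids, d.version
}

// FinishUsage releases the pin taken by GetCurrent; -1 means nothing was pinned.
func (d *distribution) FinishUsage(version int64) {
    if version == -1 {
        return
    }
    d.inUse.Add(-1)
}

func main() {
    d := newDistribution()
    d.UpdateSegment(1, segmentStateLoaded)
    d.UpdateSegment(2, segmentStateOffline) // one offline segment -> not serviceable

    ids, version := d.GetCurrent()
    defer d.FinishUsage(version) // safe to defer unconditionally
    if version == -1 {
        fmt.Println("shard not serviceable, rejecting request")
        return
    }
    fmt.Println("searching segments:", ids)
}

Because the serviceability check and the snapshot pin happen in a single lock acquisition, a rejected request returns without ever taking a second lock, and the unconditional deferred FinishUsage never touches state it did not acquire; combined with the lock-free node map used by the release path in the patch, this removes the nested lock ordering that the old offline-counter pre-check required.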