diff --git a/internal/proxy/meta_cache.go b/internal/proxy/meta_cache.go index 0b43357b0c..c8230f5de3 100644 --- a/internal/proxy/meta_cache.go +++ b/internal/proxy/meta_cache.go @@ -96,9 +96,18 @@ func (info *collectionInfo) isCollectionCached() bool { return info != nil && info.collID != UniqueID(0) && info.schema != nil } +func (info *collectionInfo) deprecateLeaderCache() { + info.leaderMutex.RLock() + defer info.leaderMutex.RUnlock() + if info.shardLeaders != nil { + info.shardLeaders.deprecated.Store(true) + } +} + // shardLeaders wraps shard leader mapping for iteration. type shardLeaders struct { - idx *atomic.Int64 + idx *atomic.Int64 + deprecated *atomic.Bool shardLeaders map[string][]nodeInfo } @@ -652,7 +661,7 @@ func (m *MetaCache) GetShards(ctx context.Context, withCache bool, collectionNam shardLeaders = info.shardLeaders info.leaderMutex.RUnlock() - if shardLeaders != nil { + if shardLeaders != nil && !shardLeaders.deprecated.Load() { metrics.ProxyCacheStatsCounter.WithLabelValues(strconv.FormatInt(Params.ProxyCfg.GetNodeID(), 10), method, metrics.CacheHitLabel).Inc() iterator := shardLeaders.GetReader() return iterator.Shuffle(), nil @@ -707,6 +716,7 @@ func (m *MetaCache) GetShards(ctx context.Context, withCache bool, collectionNam oldShards := info.shardLeaders info.shardLeaders = &shardLeaders{ shardLeaders: shards, + deprecated: atomic.NewBool(false), idx: atomic.NewInt64(0), } iterator := info.shardLeaders.GetReader() @@ -744,20 +754,13 @@ func parseShardLeaderList2QueryNode(shardsLeaders []*querypb.ShardLeadersList) m // ClearShards clear the shard leader cache of a collection func (m *MetaCache) ClearShards(collectionName string) { log.Info("clearing shard cache for collection", zap.String("collectionName", collectionName)) - m.mu.Lock() + m.mu.RLock() info, ok := m.collInfo[collectionName] - m.mu.Unlock() - var shardLeaders *shardLeaders + m.mu.RUnlock() if ok { - info.leaderMutex.Lock() - m.collInfo[collectionName].shardLeaders = nil - shardLeaders = info.shardLeaders - info.leaderMutex.Unlock() - } - // delete refcnt in shardClientMgr - if ok && shardLeaders != nil { - _ = m.shardMgr.UpdateShardLeaders(shardLeaders.shardLeaders, nil) + info.deprecateLeaderCache() } + } func (m *MetaCache) expireShardLeaderCache(ctx context.Context) { @@ -776,23 +779,13 @@ func (m *MetaCache) expireShardLeaderCache(ctx context.Context) { log.Info("stop periodically update meta cache") return case <-ticker.C: - m.mu.Lock() + m.mu.RLock() log.Info("expire all shard leader cache", zap.Strings("collections", lo.Keys(m.collInfo))) for _, info := range m.collInfo { - info.leaderMutex.Lock() - shardLeaders := info.shardLeaders - info.shardLeaders = nil - info.leaderMutex.Unlock() - if shardLeaders != nil { - err := m.shardMgr.UpdateShardLeaders(shardLeaders.shardLeaders, nil) - if err != nil { - // unreachable logic path - log.Warn("failed to update shard leaders reference", zap.Error(err)) - } - } + info.deprecateLeaderCache() } - m.mu.Unlock() + m.mu.RUnlock() } } }() diff --git a/internal/proxy/meta_cache_test.go b/internal/proxy/meta_cache_test.go index 29ca7f25dc..bc1886e17d 100644 --- a/internal/proxy/meta_cache_test.go +++ b/internal/proxy/meta_cache_test.go @@ -796,7 +796,7 @@ func TestMetaCache_ExpireShardLeaderCache(t *testing.T) { assert.Eventually(t, func() bool { nodeInfos, err := globalMetaCache.GetShards(ctx, true, "collection1") assert.NoError(t, err) - return assert.Len(t, nodeInfos["channel-1"], 2) + return len(nodeInfos["channel-1"]) == 2 }, 3*time.Second, 1*time.Second) queryCoord.EXPECT().GetShardLeaders(mock.Anything, mock.Anything).Return(&querypb.GetShardLeadersResponse{ @@ -815,7 +815,7 @@ func TestMetaCache_ExpireShardLeaderCache(t *testing.T) { assert.Eventually(t, func() bool { nodeInfos, err := globalMetaCache.GetShards(ctx, true, "collection1") assert.NoError(t, err) - return assert.Len(t, nodeInfos["channel-1"], 3) + return len(nodeInfos["channel-1"]) == 3 }, 3*time.Second, 1*time.Second) queryCoord.EXPECT().GetShardLeaders(mock.Anything, mock.Anything).Return(&querypb.GetShardLeadersResponse{ @@ -839,6 +839,6 @@ func TestMetaCache_ExpireShardLeaderCache(t *testing.T) { assert.Eventually(t, func() bool { nodeInfos, err := globalMetaCache.GetShards(ctx, true, "collection1") assert.NoError(t, err) - return assert.Len(t, nodeInfos["channel-1"], 3) && assert.Len(t, nodeInfos["channel-2"], 3) + return len(nodeInfos["channel-1"]) == 3 && len(nodeInfos["channel-2"]) == 3 }, 3*time.Second, 1*time.Second) }