From 38746dfc2b87b022413bccba90865ab67f00cc6a Mon Sep 17 00:00:00 2001
From: aoiasd <45024769+aoiasd@users.noreply.github.com>
Date: Tue, 23 Jan 2024 23:56:55 +0800
Subject: [PATCH] fix: Remove useless lock which causes proxy meta cache
 recursive lock (#30203)

relate: https://github.com/milvus-io/milvus/issues/30193

---------

Signed-off-by: aoiasd
---
 internal/proxy/meta_cache.go      | 141 ++++++++++++------------
 internal/proxy/meta_cache_test.go |   6 +-
 2 files changed, 57 insertions(+), 90 deletions(-)

diff --git a/internal/proxy/meta_cache.go b/internal/proxy/meta_cache.go
index 6ed61cd5b9..b1cf19e319 100644
--- a/internal/proxy/meta_cache.go
+++ b/internal/proxy/meta_cache.go
@@ -76,6 +76,7 @@ type Cache interface {
 	expireShardLeaderCache(ctx context.Context)
 	RemoveCollection(ctx context.Context, database, collectionName string)
 	RemoveCollectionsByID(ctx context.Context, collectionID UniqueID) []string
+	RemovePartition(ctx context.Context, database, collectionName string, partitionName string)
 
 	// GetCredentialInfo operate credential cache
 	GetCredentialInfo(ctx context.Context, username string) (*internalpb.CredentialInfo, error)
@@ -102,8 +103,6 @@ type collectionInfo struct {
 	collID              typeutil.UniqueID
 	schema              *schemaInfo
 	partInfo            *partitionInfos
-	leaderMutex         sync.RWMutex
-	shardLeaders        *shardLeaders
 	createdTimestamp    uint64
 	createdUtcTimestamp uint64
 	consistencyLevel    commonpb.ConsistencyLevel
@@ -187,14 +186,6 @@ func (info *collectionInfo) isCollectionCached() bool {
 	return info != nil && info.collID != UniqueID(0) && info.schema != nil
 }
 
-func (info *collectionInfo) deprecateLeaderCache() {
-	info.leaderMutex.RLock()
-	defer info.leaderMutex.RUnlock()
-	if info.shardLeaders != nil {
-		info.shardLeaders.deprecated.Store(true)
-	}
-}
-
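Note on the removal above: Go's sync.RWMutex is not reentrant. deprecateLeaderCache took info.leaderMutex.RLock itself, so any caller that already held a read lock could hang the moment a writer queued up between the outer and the nested acquisition, which is the recursive-lock failure the commit title refers to. A minimal, self-contained sketch of that hazard, using hypothetical names (cache, get) rather than the patched Milvus types:

package main

import (
	"fmt"
	"sync"
	"time"
)

type cache struct {
	mu   sync.RWMutex
	data map[string]int
}

// get acquires the read lock itself, so callers must NOT already hold mu.
func (c *cache) get(k string) (int, bool) {
	c.mu.RLock()
	defer c.mu.RUnlock()
	v, ok := c.data[k]
	return v, ok
}

func main() {
	c := &cache{data: map[string]int{"a": 1}}

	// A writer that queues up while a reader holds the lock.
	go func() {
		time.Sleep(10 * time.Millisecond)
		c.mu.Lock()
		c.data["b"] = 2
		c.mu.Unlock()
	}()

	c.mu.RLock() // outer read lock, like the callers before this patch
	time.Sleep(30 * time.Millisecond)
	// Calling c.get("a") HERE would re-enter RLock while the writer above is
	// parked: the nested RLock waits behind the writer, the writer waits for
	// our RUnlock, and both goroutines deadlock.
	c.mu.RUnlock()

	v, _ := c.get("a") // fine once the outer lock has been released
	fmt.Println(v)
}

Run as written this prints 1; moving the c.get("a") call before c.mu.RUnlock() reproduces the hang.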
 // shardLeaders wraps shard leader mapping for iteration.
 type shardLeaders struct {
 	idx *atomic.Int64
@@ -248,13 +239,14 @@ type MetaCache struct {
 	rootCoord  types.RootCoordClient
 	queryCoord types.QueryCoordClient
 
-	collInfo       map[string]map[string]*collectionInfo // database -> collection -> collection_info
+	collInfo       map[string]map[string]*collectionInfo // database -> collectionName -> collection_info
+	collLeader     map[string]map[string]*shardLeaders   // database -> collectionName -> collection_leaders
 	credMap        map[string]*internalpb.CredentialInfo // cache for credential, lazy load
 	privilegeInfos map[string]struct{}                   // privileges cache
 	userToRoles    map[string]map[string]struct{}        // user to role cache
 	mu             sync.RWMutex
 	credMut        sync.RWMutex
-	privilegeMut   sync.RWMutex
+	leaderMut      sync.RWMutex
 	shardMgr       shardClientMgr
 	sfGlobal       conc.Singleflight[*collectionInfo]
 }
@@ -288,6 +280,7 @@ func NewMetaCache(rootCoord types.RootCoordClient, queryCoord types.QueryCoordCl
 		rootCoord:      rootCoord,
 		queryCoord:     queryCoord,
 		collInfo:       map[string]map[string]*collectionInfo{},
+		collLeader:     map[string]map[string]*shardLeaders{},
 		credMap:        map[string]*internalpb.CredentialInfo{},
 		shardMgr:       shardMgr,
 		privilegeInfos: map[string]struct{}{},
@@ -318,6 +311,21 @@ func (m *MetaCache) getCollection(database, collectionName string, collectionID
 	return nil, false
 }
 
+func (m *MetaCache) getCollectionShardLeader(database, collectionName string) (*shardLeaders, bool) {
+	m.leaderMut.RLock()
+	defer m.leaderMut.RUnlock()
+
+	db, ok := m.collLeader[database]
+	if !ok {
+		return nil, false
+	}
+
+	if leaders, ok := db[collectionName]; ok {
+		return leaders, !leaders.deprecated.Load()
+	}
+	return nil, false
+}
+
 func (m *MetaCache) update(ctx context.Context, database, collectionName string, collectionID UniqueID) (*collectionInfo, error) {
 	if collInfo, ok := m.getCollection(database, collectionName, collectionID); ok {
 		return collInfo, nil
@@ -355,20 +363,16 @@ func (m *MetaCache) update(ctx context.Context, database, collectionName string,
 		m.collInfo[database] = make(map[string]*collectionInfo)
 	}
 
-	_, ok := m.collInfo[database][collectionName]
-	if !ok {
-		m.collInfo[database][collectionName] = &collectionInfo{}
+	m.collInfo[database][collectionName] = &collectionInfo{
+		collID:              collection.CollectionID,
+		schema:              newSchemaInfo(collection.Schema),
+		partInfo:            parsePartitionsInfo(infos),
+		createdTimestamp:    collection.CreatedTimestamp,
+		createdUtcTimestamp: collection.CreatedUtcTimestamp,
+		consistencyLevel:    collection.ConsistencyLevel,
 	}
-	collInfo := m.collInfo[database][collectionName]
-	collInfo.schema = newSchemaInfo(collection.Schema)
-	collInfo.collID = collection.CollectionID
-	collInfo.createdTimestamp = collection.CreatedTimestamp
-	collInfo.createdUtcTimestamp = collection.CreatedUtcTimestamp
-	collInfo.consistencyLevel = collection.ConsistencyLevel
-	collInfo.partInfo = parsePartitionsInfo(infos)
 
-	log.Info("meta update success", zap.String("database", database), zap.String("collectionName", collectionName), zap.Int64("collectionID", collInfo.collID))
-
+	log.Info("meta update success", zap.String("database", database), zap.String("collectionName", collectionName), zap.Int64("collectionID", collection.CollectionID))
 	return m.collInfo[database][collectionName], nil
 }
 
@@ -397,12 +401,10 @@ func (m *MetaCache) UpdateByID(ctx context.Context, database string, collectionI
 // GetCollectionID returns the corresponding collection id for provided collection name
 func (m *MetaCache) GetCollectionID(ctx context.Context, database, collectionName string) (UniqueID, error) {
 	method := "GetCollectionID"
-	m.mu.RLock()
 	collInfo, ok := m.getCollection(database, collectionName, 0)
 
 	if !ok {
 		metrics.ProxyCacheStatsCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method, metrics.CacheMissLabel).Inc()
 		tr := timerecord.NewTimeRecorder("UpdateCache")
-		m.mu.RUnlock()
 
 		collInfo, err := m.UpdateByName(ctx, database, collectionName)
 		if err != nil {
@@ -412,7 +414,6 @@ func (m *MetaCache) GetCollectionID(ctx context.Context, database, collectionNam
 		metrics.ProxyUpdateCacheLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method).Observe(float64(tr.ElapseSpan().Milliseconds()))
 		return collInfo.collID, nil
 	}
-	defer m.mu.RUnlock()
 
 	metrics.ProxyCacheStatsCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method, metrics.CacheHitLabel).Inc()
 	return collInfo.collID, nil
@@ -421,13 +422,11 @@
 // GetCollectionName returns the corresponding collection name for provided collection id
 func (m *MetaCache) GetCollectionName(ctx context.Context, database string, collectionID int64) (string, error) {
 	method := "GetCollectionName"
-	m.mu.RLock()
 	collInfo, ok := m.getCollection(database, "", collectionID)
 
 	if !ok {
 		metrics.ProxyCacheStatsCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method, metrics.CacheMissLabel).Inc()
 		tr := timerecord.NewTimeRecorder("UpdateCache")
-		m.mu.RUnlock()
 
 		collInfo, err := m.UpdateByID(ctx, database, collectionID)
 		if err != nil {
@@ -437,21 +436,18 @@ func (m *MetaCache) GetCollectionName(ctx context.Context, database string, coll
 		metrics.ProxyUpdateCacheLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method).Observe(float64(tr.ElapseSpan().Milliseconds()))
 		return collInfo.schema.Name, nil
 	}
-	defer m.mu.RUnlock()
 
 	metrics.ProxyCacheStatsCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method, metrics.CacheHitLabel).Inc()
 	return collInfo.schema.Name, nil
 }
 
 func (m *MetaCache) GetCollectionInfo(ctx context.Context, database string, collectionName string, collectionID int64) (*collectionBasicInfo, error) {
-	m.mu.RLock()
 	collInfo, ok := m.getCollection(database, collectionName, 0)
 
 	method := "GetCollectionInfo"
 	// if collInfo.collID != collectionID, means that the cache is not trustable
 	// try to get collection according to collectionID
 	if !ok || collInfo.collID != collectionID {
-		m.mu.RUnlock()
 		tr := timerecord.NewTimeRecorder("UpdateCache")
 		metrics.ProxyCacheStatsCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method, metrics.CacheMissLabel).Inc()
 
@@ -462,7 +458,6 @@ func (m *MetaCache) GetCollectionInfo(ctx context.Context, database string, coll
 		metrics.ProxyUpdateCacheLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method).Observe(float64(tr.ElapseSpan().Milliseconds()))
 		return collInfo.getBasicInfo(), nil
 	}
-	defer m.mu.RUnlock()
 
 	metrics.ProxyCacheStatsCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method, metrics.CacheHitLabel).Inc()
 	return collInfo.getBasicInfo(), nil
@@ -472,14 +467,12 @@ func (m *MetaCache) GetCollectionInfo(ctx context.Context, database string, coll
 // GetCollectionInfo returns the collection information related to provided collection name
 // If the information is not found, proxy will try to fetch information for other source (RootCoord for now)
 // TODO: may cause data race of this implementation, should be refactored in future.
 func (m *MetaCache) getFullCollectionInfo(ctx context.Context, database, collectionName string, collectionID int64) (*collectionInfo, error) {
-	m.mu.RLock()
 	collInfo, ok := m.getCollection(database, collectionName, collectionID)
 
 	method := "GetCollectionInfo"
 	// if collInfo.collID != collectionID, means that the cache is not trustable
 	// try to get collection according to collectionID
 	if !ok || collInfo.collID != collectionID {
-		m.mu.RUnlock()
 		tr := timerecord.NewTimeRecorder("UpdateCache")
 		metrics.ProxyCacheStatsCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method, metrics.CacheMissLabel).Inc()
 
@@ -491,18 +484,15 @@ func (m *MetaCache) getFullCollectionInfo(ctx context.Context, database, collect
 		return collInfo, nil
 	}
 
-	m.mu.RUnlock()
 	metrics.ProxyCacheStatsCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method, metrics.CacheHitLabel).Inc()
 	return collInfo, nil
 }
 
 func (m *MetaCache) GetCollectionSchema(ctx context.Context, database, collectionName string) (*schemaInfo, error) {
-	m.mu.RLock()
 	collInfo, ok := m.getCollection(database, collectionName, 0)
 
 	method := "GetCollectionSchema"
 	if !ok {
-		m.mu.RUnlock()
 		tr := timerecord.NewTimeRecorder("UpdateCache")
 		metrics.ProxyCacheStatsCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method, metrics.CacheMissLabel).Inc()
 
@@ -516,7 +506,6 @@ func (m *MetaCache) GetCollectionSchema(ctx context.Context, database, collectio
 			zap.Int64("time (milliseconds) take ", tr.ElapseSpan().Milliseconds()))
 		return collInfo.schema, nil
 	}
-	defer m.mu.RUnlock()
 
 	metrics.ProxyCacheStatsCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method, metrics.CacheHitLabel).Inc()
 	return collInfo.schema, nil
@@ -566,12 +555,10 @@ func (m *MetaCache) GetPartitionsIndex(ctx context.Context, database, collection
 }
 
 func (m *MetaCache) GetPartitionInfos(ctx context.Context, database, collectionName string) (*partitionInfos, error) {
-	m.mu.RLock()
 	method := "GetPartitionInfo"
 	collInfo, ok := m.getCollection(database, collectionName, 0)
 
 	if !ok {
-		m.mu.RUnlock()
 		tr := timerecord.NewTimeRecorder("UpdateCache")
 		metrics.ProxyCacheStatsCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method, metrics.CacheMissLabel).Inc()
 
@@ -583,7 +570,6 @@ func (m *MetaCache) GetPartitionInfos(ctx context.Context, database, collectionN
 		metrics.ProxyUpdateCacheLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method).Observe(float64(tr.ElapseSpan().Milliseconds()))
 		return collInfo.partInfo, nil
 	}
-	defer m.mu.RUnlock()
 
 	return collInfo.partInfo, nil
 }
@@ -796,6 +782,7 @@ func (m *MetaCache) UpdateCredential(credInfo *internalpb.CredentialInfo) {
 
 // GetShards update cache if withCache == false
 func (m *MetaCache) GetShards(ctx context.Context, withCache bool, database, collectionName string, collectionID int64) (map[string][]nodeInfo, error) {
+	method := "GetShards"
 	log := log.Ctx(ctx).With(
 		zap.String("collectionName", collectionName),
 		zap.Int64("collectionID", collectionID))
@@ -805,16 +792,11 @@ func (m *MetaCache) GetShards(ctx context.Context, withCache bool, database, col
 		return nil, err
 	}
 
-	method := "GetShards"
+	cacheShardLeaders, ok := m.getCollectionShardLeader(database, collectionName)
 	if withCache {
-		var shardLeaders *shardLeaders
-		info.leaderMutex.RLock()
-		shardLeaders = info.shardLeaders
-		info.leaderMutex.RUnlock()
-
-		if shardLeaders != nil && !shardLeaders.deprecated.Load() {
+		if ok {
 			metrics.ProxyCacheStatsCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method, metrics.CacheHitLabel).Inc()
-			iterator := shardLeaders.GetReader()
+			iterator := cacheShardLeaders.GetReader()
 			return iterator.Shuffle(), nil
 		}
 
@@ -830,36 +812,36 @@ func (m *MetaCache) GetShards(ctx context.Context, withCache bool, database, col
 	}
 
 	tr := timerecord.NewTimeRecorder("UpdateShardCache")
-	var resp *querypb.GetShardLeadersResponse
-	resp, err = m.queryCoord.GetShardLeaders(ctx, req)
+	resp, err := m.queryCoord.GetShardLeaders(ctx, req)
 	if err != nil {
 		return nil, err
 	}
-	if resp.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
-		return nil, merr.Error(resp.GetStatus())
+	if err = merr.Error(resp.GetStatus()); err != nil {
+		return nil, err
 	}
 
 	shards := parseShardLeaderList2QueryNode(resp.GetShards())
-
-	info, err = m.getFullCollectionInfo(ctx, database, collectionName, collectionID)
-	if err != nil {
-		return nil, err
-	}
-	// lock leader
-	info.leaderMutex.Lock()
-	oldShards := info.shardLeaders
-	info.shardLeaders = &shardLeaders{
+	newShardLeaders := &shardLeaders{
 		shardLeaders: shards,
 		deprecated:   atomic.NewBool(false),
 		idx:          atomic.NewInt64(0),
 	}
-	iterator := info.shardLeaders.GetReader()
-	info.leaderMutex.Unlock()
+	// lock leader
+	m.leaderMut.Lock()
+	if _, ok := m.collLeader[database]; !ok {
+		m.collLeader[database] = make(map[string]*shardLeaders)
+	}
+
+	m.collLeader[database][collectionName] = newShardLeaders
+	m.leaderMut.Unlock()
+
+	iterator := newShardLeaders.GetReader()
 	ret := iterator.Shuffle()
+
 	oldLeaders := make(map[string][]nodeInfo)
-	if oldShards != nil {
-		oldLeaders = oldShards.shardLeaders
+	if cacheShardLeaders != nil {
+		oldLeaders = cacheShardLeaders.shardLeaders
 	}
 	// update refcnt in shardClientMgr
 	// and create new client for new leaders
@@ -888,19 +870,8 @@ func parseShardLeaderList2QueryNode(shardsLeaders []*querypb.ShardLeadersList) m
 
 // DeprecateShardCache clear the shard leader cache of a collection
 func (m *MetaCache) DeprecateShardCache(database, collectionName string) {
 	log.Info("clearing shard cache for collection", zap.String("collectionName", collectionName))
-	m.mu.RLock()
-	var info *collectionInfo
-	var ok bool
-	db, dbOk := m.collInfo[database]
-	if !dbOk {
-		m.mu.RUnlock()
-		log.Warn("not found database", zap.String("dbName", database))
-		return
-	}
-	info, ok = db[collectionName]
-	m.mu.RUnlock()
-	if ok {
-		info.deprecateLeaderCache()
+	if shards, ok := m.getCollectionShardLeader(database, collectionName); ok {
+		shards.deprecated.Store(true)
 	}
 }
 
@@ -916,16 +887,16 @@ func (m *MetaCache) expireShardLeaderCache(ctx context.Context) {
 				log.Info("stop periodically update meta cache")
 				return
 			case <-ticker.C:
-				m.mu.RLock()
-				for database, db := range m.collInfo {
+				m.leaderMut.RLock()
+				for database, db := range m.collLeader {
 					log.RatedInfo(10, "expire all shard leader cache",
 						zap.String("database", database),
 						zap.Strings("collections", lo.Keys(db)))
-					for _, info := range db {
-						info.deprecateLeaderCache()
+					for _, shards := range db {
+						shards.deprecated.Store(true)
 					}
 				}
-				m.mu.RUnlock()
+				m.leaderMut.RUnlock()
 			}
 		}
 	}()
diff --git a/internal/proxy/meta_cache_test.go b/internal/proxy/meta_cache_test.go
index ce21d4f155..f427b01546 100644
--- a/internal/proxy/meta_cache_test.go
+++ b/internal/proxy/meta_cache_test.go
@@ -773,11 +773,6 @@ func TestMetaCache_ExpireShardLeaderCache(t *testing.T) {
 	err := InitMetaCache(ctx, rootCoord, queryCoord, shardMgr)
 	assert.NoError(t, err)
 
-	queryCoord.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{
-		Status:              merr.Success(),
-		CollectionIDs:       []UniqueID{1},
-		InMemoryPercentages: []int64{100},
-	}, nil)
 	queryCoord.EXPECT().GetShardLeaders(mock.Anything, mock.Anything).Return(&querypb.GetShardLeadersResponse{
 		Status: merr.Success(),
 		Shards: []*querypb.ShardLeadersList{
 			{
 				ChannelName: "channel-1",
 				NodeIds:     []int64{1, 2, 3},
 				NodeAddrs:   []string{"localhost:9000", "localhost:9001", "localhost:9002"},
 			},
 		},
 	}, nil)
+
 	nodeInfos, err := globalMetaCache.GetShards(ctx, true, dbName, "collection1", 1)
 	assert.NoError(t, err)
 	assert.Len(t, nodeInfos["channel-1"], 3)
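Taken together, the meta_cache.go hunks split shard leaders out of collectionInfo into the dedicated collLeader map guarded by leaderMut, and expiry becomes a flag flip: expireShardLeaderCache and DeprecateShardCache only mark an entry's deprecated bool, and the next GetShards call falls through to QueryCoord and swaps in a fresh entry. A compact sketch of that deprecate-then-refresh scheme; it uses the standard library's atomic.Bool (Go 1.19+) where the patch uses go.uber.org/atomic, and leaderCache/leaders are illustrative names only:

package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

type leaders struct {
	deprecated atomic.Bool
	addrs      []string
}

type leaderCache struct {
	mu     sync.RWMutex
	byName map[string]*leaders
}

// get mirrors getCollectionShardLeader: a hit is only reported while the
// entry has not been deprecated.
func (c *leaderCache) get(name string) (*leaders, bool) {
	c.mu.RLock()
	defer c.mu.RUnlock()
	l, ok := c.byName[name]
	if !ok {
		return nil, false
	}
	return l, !l.deprecated.Load()
}

// put mirrors the refresh path in GetShards: replace the entry wholesale
// under the write lock.
func (c *leaderCache) put(name string, addrs []string) {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.byName[name] = &leaders{addrs: addrs}
}

// expireAll mirrors expireShardLeaderCache: entries are never freed here,
// only flagged, so readers fall through to a refresh on their next call.
func (c *leaderCache) expireAll() {
	c.mu.RLock()
	defer c.mu.RUnlock()
	for _, l := range c.byName {
		l.deprecated.Store(true)
	}
}

func main() {
	c := &leaderCache{byName: map[string]*leaders{}}
	c.put("collection1", []string{"localhost:9000"})

	if l, ok := c.get("collection1"); ok {
		fmt.Println("hit:", l.addrs)
	}
	c.expireAll()
	_, ok := c.get("collection1")
	fmt.Println("hit after expire:", ok) // false: next GetShards would refresh
}

The test hunk matches this model: the ShowCollections expectation is dropped, apparently because the reworked leader path no longer consults it, and the assertions only require that GetShards sees the three leaders returned by the GetShardLeaders mock.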