diff --git a/internal/proxy/meta_cache.go b/internal/proxy/meta_cache.go
index 7d230be4e3..6ed61cd5b9 100644
--- a/internal/proxy/meta_cache.go
+++ b/internal/proxy/meta_cache.go
@@ -43,6 +43,7 @@ import (
 	"github.com/milvus-io/milvus/pkg/metrics"
 	"github.com/milvus-io/milvus/pkg/util"
 	"github.com/milvus-io/milvus/pkg/util/commonpbutil"
+	"github.com/milvus-io/milvus/pkg/util/conc"
 	"github.com/milvus-io/milvus/pkg/util/funcutil"
 	"github.com/milvus-io/milvus/pkg/util/merr"
 	"github.com/milvus-io/milvus/pkg/util/paramtable"
@@ -75,7 +76,6 @@ type Cache interface {
 	expireShardLeaderCache(ctx context.Context)
 	RemoveCollection(ctx context.Context, database, collectionName string)
 	RemoveCollectionsByID(ctx context.Context, collectionID UniqueID) []string
-	RemovePartition(ctx context.Context, database, collectionName string, partitionName string)
 
 	// GetCredentialInfo operate credential cache
 	GetCredentialInfo(ctx context.Context, username string) (*internalpb.CredentialInfo, error)
@@ -256,6 +256,7 @@ type MetaCache struct {
 	credMut      sync.RWMutex
 	privilegeMut sync.RWMutex
 	shardMgr     shardClientMgr
+	sfGlobal     conc.Singleflight[*collectionInfo]
 }
 
 // globalMetaCache is singleton instance of Cache
@@ -294,33 +295,121 @@ func NewMetaCache(rootCoord types.RootCoordClient, queryCoord types.QueryCoordCl
 	}, nil
 }
 
-// GetCollectionID returns the corresponding collection id for provided collection name
-func (m *MetaCache) GetCollectionID(ctx context.Context, database, collectionName string) (typeutil.UniqueID, error) {
+func (m *MetaCache) getCollection(database, collectionName string, collectionID UniqueID) (*collectionInfo, bool) {
 	m.mu.RLock()
+	defer m.mu.RUnlock()
 
-	var ok bool
-	var collInfo *collectionInfo
-
-	db, dbOk := m.collInfo[database]
-	if dbOk && db != nil {
-		collInfo, ok = db[collectionName]
+	db, ok := m.collInfo[database]
+	if !ok {
+		return nil, false
+	}
+	if collectionName == "" {
+		for _, collection := range db {
+			if collection.collID == collectionID {
+				return collection, collection.isCollectionCached()
+			}
+		}
+	} else {
+		if collection, ok := db[collectionName]; ok {
+			return collection, collection.isCollectionCached()
+		}
 	}
 
+	return nil, false
+}
+
+func (m *MetaCache) update(ctx context.Context, database, collectionName string, collectionID UniqueID) (*collectionInfo, error) {
+	if collInfo, ok := m.getCollection(database, collectionName, collectionID); ok {
+		return collInfo, nil
+	}
+
+	collection, err := m.describeCollection(ctx, database, collectionName, collectionID)
+	if err != nil {
+		return nil, err
+	}
+
+	partitions, err := m.showPartitions(ctx, database, collectionName, collectionID)
+	if err != nil {
+		return nil, err
+	}
+
+	// check that partition names, createdTimestamps and createdUtcTimestamps have the same number of elements
+	if len(partitions.PartitionNames) != len(partitions.CreatedTimestamps) || len(partitions.PartitionNames) != len(partitions.CreatedUtcTimestamps) {
+		return nil, merr.WrapErrParameterInvalidMsg("partition names and timestamps number is not aligned, response: %s", partitions.String())
+	}
+
+	infos := lo.Map(partitions.GetPartitionIDs(), func(partitionID int64, idx int) *partitionInfo {
+		return &partitionInfo{
+			name:                partitions.PartitionNames[idx],
+			partitionID:         partitions.PartitionIDs[idx],
+			createdTimestamp:    partitions.CreatedTimestamps[idx],
+			createdUtcTimestamp: partitions.CreatedUtcTimestamps[idx],
+		}
+	})
+
+	collectionName = collection.Schema.GetName()
+	m.mu.Lock()
+	defer m.mu.Unlock()
+	_, dbOk := m.collInfo[database]
+	if !dbOk {
+		m.collInfo[database] = make(map[string]*collectionInfo)
+	}
+
+	_, ok := m.collInfo[database][collectionName]
+	if !ok {
+		m.collInfo[database][collectionName] = &collectionInfo{}
+	}
+
+	collInfo := m.collInfo[database][collectionName]
+	collInfo.schema = newSchemaInfo(collection.Schema)
+	collInfo.collID = collection.CollectionID
+	collInfo.createdTimestamp = collection.CreatedTimestamp
+	collInfo.createdUtcTimestamp = collection.CreatedUtcTimestamp
+	collInfo.consistencyLevel = collection.ConsistencyLevel
+	collInfo.partInfo = parsePartitionsInfo(infos)
+	log.Info("meta update success", zap.String("database", database), zap.String("collectionName", collectionName), zap.Int64("collectionID", collInfo.collID))
+
+	return m.collInfo[database][collectionName], nil
+}
+
+func buildSfKeyByName(database, collectionName string) string {
+	return database + "-" + collectionName
+}
+
+func buildSfKeyById(database string, collectionID UniqueID) string {
+	return database + "--" + fmt.Sprint(collectionID)
+}
+
+func (m *MetaCache) UpdateByName(ctx context.Context, database, collectionName string) (*collectionInfo, error) {
+	collection, err, _ := m.sfGlobal.Do(buildSfKeyByName(database, collectionName), func() (*collectionInfo, error) {
+		return m.update(ctx, database, collectionName, 0)
+	})
+	return collection, err
+}
+
+func (m *MetaCache) UpdateByID(ctx context.Context, database string, collectionID UniqueID) (*collectionInfo, error) {
+	collection, err, _ := m.sfGlobal.Do(buildSfKeyById(database, collectionID), func() (*collectionInfo, error) {
+		return m.update(ctx, database, "", collectionID)
+	})
+	return collection, err
+}
+
+// GetCollectionID returns the corresponding collection id for provided collection name
+func (m *MetaCache) GetCollectionID(ctx context.Context, database, collectionName string) (UniqueID, error) {
 	method := "GetCollectionID"
-	if !ok || !collInfo.isCollectionCached() {
+	m.mu.RLock()
+	collInfo, ok := m.getCollection(database, collectionName, 0)
+	if !ok {
 		metrics.ProxyCacheStatsCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method, metrics.CacheMissLabel).Inc()
 		tr := timerecord.NewTimeRecorder("UpdateCache")
 		m.mu.RUnlock()
-		coll, err := m.describeCollection(ctx, database, collectionName, 0)
-		if err != nil {
-			return 0, err
-		}
-		m.mu.Lock()
-		defer m.mu.Unlock()
-		m.updateCollection(coll, database, collectionName)
+		collInfo, err := m.UpdateByName(ctx, database, collectionName)
+		if err != nil {
+			return UniqueID(0), err
+		}
+
 		metrics.ProxyUpdateCacheLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method).Observe(float64(tr.ElapseSpan().Milliseconds()))
-		collInfo = m.collInfo[database][collectionName]
 		return collInfo.collID, nil
 	}
 	defer m.mu.RUnlock()
@@ -331,32 +420,22 @@ func (m *MetaCache) GetCollectionID(ctx context.Context, database, collectionNam
 
 // GetCollectionName returns the corresponding collection name for provided collection id
 func (m *MetaCache) GetCollectionName(ctx context.Context, database string, collectionID int64) (string, error) {
-	m.mu.RLock()
-	var collInfo *collectionInfo
-	for _, db := range m.collInfo {
-		for _, coll := range db {
-			if coll.collID == collectionID {
-				collInfo = coll
-				break
-			}
-		}
-	}
-
-	method := "GetCollectionName"
-	if collInfo == nil || !collInfo.isCollectionCached() {
+	method := "GetCollectionName"
+	m.mu.RLock()
+	collInfo, ok := m.getCollection(database, "", collectionID)
+
+	if !ok {
 		metrics.ProxyCacheStatsCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method, metrics.CacheMissLabel).Inc()
 		tr := timerecord.NewTimeRecorder("UpdateCache")
 		m.mu.RUnlock()
-		coll, err := m.describeCollection(ctx, database, "", collectionID)
+
+		collInfo, err := m.UpdateByID(ctx, database, collectionID)
 		if err != nil {
 			return "", err
 		}
-		m.mu.Lock()
-		defer m.mu.Unlock()
-		m.updateCollection(coll, coll.GetDbName(), coll.Schema.Name)
 		metrics.ProxyUpdateCacheLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method).Observe(float64(tr.ElapseSpan().Milliseconds()))
-		return coll.Schema.Name, nil
+		return collInfo.schema.Name, nil
 	}
 	defer m.mu.RUnlock()
 	metrics.ProxyCacheStatsCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method, metrics.CacheHitLabel).Inc()
@@ -366,29 +445,20 @@ func (m *MetaCache) GetCollectionName(ctx context.Context, database string, coll
 
 func (m *MetaCache) GetCollectionInfo(ctx context.Context, database string, collectionName string, collectionID int64) (*collectionBasicInfo, error) {
 	m.mu.RLock()
-	var collInfo *collectionInfo
-	var ok bool
-
-	db, dbOk := m.collInfo[database]
-	if dbOk {
-		collInfo, ok = db[collectionName]
-	}
+	collInfo, ok := m.getCollection(database, collectionName, 0)
 
 	method := "GetCollectionInfo"
 	// if collInfo.collID != collectionID, means that the cache is not trustable
 	// try to get collection according to collectionID
-	if !ok || !collInfo.isCollectionCached() || collInfo.collID != collectionID {
+	if !ok || collInfo.collID != collectionID {
 		m.mu.RUnlock()
 		tr := timerecord.NewTimeRecorder("UpdateCache")
 		metrics.ProxyCacheStatsCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method, metrics.CacheMissLabel).Inc()
-		coll, err := m.describeCollection(ctx, database, "", collectionID)
+
+		collInfo, err := m.UpdateByID(ctx, database, collectionID)
 		if err != nil {
 			return nil, err
 		}
-		m.mu.Lock()
-		defer m.mu.Unlock()
-		m.updateCollection(coll, database, collectionName)
-		collInfo = m.collInfo[database][collectionName]
 		metrics.ProxyUpdateCacheLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method).Observe(float64(tr.ElapseSpan().Milliseconds()))
 		return collInfo.getBasicInfo(), nil
 	}
@@ -403,34 +473,20 @@ func (m *MetaCache) GetCollectionInfo(ctx context.Context, database string, coll
 
 // TODO: may cause data race of this implementation, should be refactored in future.
 func (m *MetaCache) getFullCollectionInfo(ctx context.Context, database, collectionName string, collectionID int64) (*collectionInfo, error) {
 	m.mu.RLock()
-	var collInfo *collectionInfo
-	var ok bool
-
-	db, dbOk := m.collInfo[database]
-	if dbOk {
-		collInfo, ok = db[collectionName]
-	}
+	collInfo, ok := m.getCollection(database, collectionName, collectionID)
 
 	method := "GetCollectionInfo"
 	// if collInfo.collID != collectionID, means that the cache is not trustable
 	// try to get collection according to collectionID
-	if !ok || !collInfo.isCollectionCached() || collInfo.collID != collectionID {
+	if !ok || collInfo.collID != collectionID {
 		m.mu.RUnlock()
 		tr := timerecord.NewTimeRecorder("UpdateCache")
 		metrics.ProxyCacheStatsCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method, metrics.CacheMissLabel).Inc()
-		var coll *milvuspb.DescribeCollectionResponse
-		var err error
-
-		// collectionName maybe not trustable, get collection according to id
-		coll, err = m.describeCollection(ctx, database, "", collectionID)
+		collInfo, err := m.UpdateByID(ctx, database, collectionID)
 		if err != nil {
 			return nil, err
 		}
-		m.mu.Lock()
-		m.updateCollection(coll, database, collectionName)
-		collInfo = m.collInfo[database][collectionName]
-		m.mu.Unlock()
 		metrics.ProxyUpdateCacheLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method).Observe(float64(tr.ElapseSpan().Milliseconds()))
 		return collInfo, nil
 	}
@@ -442,31 +498,18 @@ func (m *MetaCache) getFullCollectionInfo(ctx context.Context, database, collect
 
 func (m *MetaCache) GetCollectionSchema(ctx context.Context, database, collectionName string) (*schemaInfo, error) {
 	m.mu.RLock()
-	var collInfo *collectionInfo
-	var ok bool
-
-	db, dbOk := m.collInfo[database]
-	if dbOk {
-		collInfo, ok = db[collectionName]
-	}
+	collInfo, ok := m.getCollection(database, collectionName, 0)
 
 	method := "GetCollectionSchema"
-	if !ok || !collInfo.isCollectionCached() {
-		metrics.ProxyCacheStatsCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method, metrics.CacheMissLabel).Inc()
-		tr := timerecord.NewTimeRecorder("UpdateCache")
+	if !ok {
 		m.mu.RUnlock()
-		coll, err := m.describeCollection(ctx, database, collectionName, 0)
+		tr := timerecord.NewTimeRecorder("UpdateCache")
+		metrics.ProxyCacheStatsCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method, metrics.CacheMissLabel).Inc()
+
+		collInfo, err := m.UpdateByName(ctx, database, collectionName)
 		if err != nil {
-			log.Warn("Failed to load collection from rootcoord ",
-				zap.String("collection name ", collectionName),
-				zap.Error(err))
 			return nil, err
 		}
-		m.mu.Lock()
-		defer m.mu.Unlock()
-
-		m.updateCollection(coll, database, collectionName)
-		collInfo = m.collInfo[database][collectionName]
 		metrics.ProxyUpdateCacheLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method).Observe(float64(tr.ElapseSpan().Milliseconds()))
 		log.Debug("Reload collection from root coordinator ",
 			zap.String("collectionName", collectionName),
@@ -479,23 +522,6 @@ func (m *MetaCache) GetCollectionSchema(ctx context.Context, database, collectio
 	return collInfo.schema, nil
 }
 
-func (m *MetaCache) updateCollection(coll *milvuspb.DescribeCollectionResponse, database, collectionName string) {
-	_, dbOk := m.collInfo[database]
-	if !dbOk {
-		m.collInfo[database] = make(map[string]*collectionInfo)
-	}
-
-	_, ok := m.collInfo[database][collectionName]
-	if !ok {
-		m.collInfo[database][collectionName] = &collectionInfo{}
-	}
-	m.collInfo[database][collectionName].schema = newSchemaInfo(coll.Schema)
-	m.collInfo[database][collectionName].collID = coll.CollectionID
-	m.collInfo[database][collectionName].createdTimestamp = coll.CreatedTimestamp
-	m.collInfo[database][collectionName].createdUtcTimestamp = coll.CreatedUtcTimestamp
-	m.collInfo[database][collectionName].consistencyLevel = coll.ConsistencyLevel
-}
-
 func (m *MetaCache) GetPartitionID(ctx context.Context, database, collectionName string, partitionName string) (typeutil.UniqueID, error) {
 	partInfo, err := m.GetPartitionInfo(ctx, database, collectionName, partitionName)
 	if err != nil {
@@ -505,7 +531,7 @@ func (m *MetaCache) GetPartitionID(ctx context.Context, database, collectionName
 }
 
 func (m *MetaCache) GetPartitions(ctx context.Context, database, collectionName string) (map[string]typeutil.UniqueID, error) {
-	partitions, err := m.getPartitionInfos(ctx, database, collectionName)
+	partitions, err := m.GetPartitionInfos(ctx, database, collectionName)
 	if err != nil {
 		return nil, err
 	}
@@ -514,7 +540,7 @@ func (m *MetaCache) GetPartitions(ctx context.Context, database, collectionName
 }
 
 func (m *MetaCache) GetPartitionInfo(ctx context.Context, database, collectionName string, partitionName string) (*partitionInfo, error) {
-	partitions, err := m.getPartitionInfos(ctx, database, collectionName)
+	partitions, err := m.GetPartitionInfos(ctx, database, collectionName)
 	if err != nil {
 		return nil, err
 	}
@@ -527,7 +553,7 @@ func (m *MetaCache) GetPartitionInfo(ctx context.Context, database, collectionNa
 }
 
 func (m *MetaCache) GetPartitionsIndex(ctx context.Context, database, collectionName string) ([]string, error) {
-	partitions, err := m.getPartitionInfos(ctx, database, collectionName)
+	partitions, err := m.GetPartitionInfos(ctx, database, collectionName)
 	if err != nil {
 		return nil, err
 	}
@@ -539,49 +565,26 @@ func (m *MetaCache) GetPartitionsIndex(ctx context.Context, database, collection
 	return partitions.indexedPartitionNames, nil
 }
 
-func (m *MetaCache) getPartitionInfos(ctx context.Context, database, collectionName string) (*partitionInfos, error) {
-	_, err := m.GetCollectionID(ctx, database, collectionName)
-	if err != nil {
-		return nil, err
-	}
+func (m *MetaCache) GetPartitionInfos(ctx context.Context, database, collectionName string) (*partitionInfos, error) {
 	m.mu.RLock()
-
-	var collInfo *collectionInfo
-	var ok bool
-	db, dbOk := m.collInfo[database]
-	if dbOk {
-		collInfo, ok = db[collectionName]
-	}
+	method := "GetPartitionInfo"
+	collInfo, ok := m.getCollection(database, collectionName, 0)
 
 	if !ok {
 		m.mu.RUnlock()
-		return nil, fmt.Errorf("can't find collection name %s:%s", database, collectionName)
-	}
-
-	partitionInfos := collInfo.partInfo
-	m.mu.RUnlock()
-
-	method := "GetPartitionInfo"
-	if partitionInfos == nil {
 		tr := timerecord.NewTimeRecorder("UpdateCache")
 		metrics.ProxyCacheStatsCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method, metrics.CacheMissLabel).Inc()
-		partitions, err := m.showPartitions(ctx, database, collectionName)
+
+		collInfo, err := m.UpdateByName(ctx, database, collectionName)
 		if err != nil {
 			return nil, err
 		}
-		m.mu.Lock()
-		defer m.mu.Unlock()
-		err = m.updatePartitions(partitions, database, collectionName)
-		if err != nil {
-			return nil, err
-		}
 		metrics.ProxyUpdateCacheLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method).Observe(float64(tr.ElapseSpan().Milliseconds()))
-		partitionInfos = m.collInfo[database][collectionName].partInfo
-		return partitionInfos, nil
+		return collInfo.partInfo, nil
 	}
-
-	metrics.ProxyCacheStatsCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method, metrics.CacheHitLabel).Inc()
-	return partitionInfos, nil
+	defer m.mu.RUnlock()
+	return collInfo.partInfo, nil
 }
 
 // Get the collection information from rootcoord.
@@ -627,21 +630,23 @@ func (m *MetaCache) describeCollection(ctx context.Context, database, collection
 	return resp, nil
 }
 
-func (m *MetaCache) showPartitions(ctx context.Context, dbName string, collectionName string) (*milvuspb.ShowPartitionsResponse, error) {
+func (m *MetaCache) showPartitions(ctx context.Context, dbName string, collectionName string, collectionID UniqueID) (*milvuspb.ShowPartitionsResponse, error) {
 	req := &milvuspb.ShowPartitionsRequest{
 		Base: commonpbutil.NewMsgBase(
 			commonpbutil.WithMsgType(commonpb.MsgType_ShowPartitions),
 		),
 		DbName:         dbName,
 		CollectionName: collectionName,
+		CollectionID:   collectionID,
 	}
 
 	partitions, err := m.rootCoord.ShowPartitions(ctx, req)
 	if err != nil {
 		return nil, err
 	}
-	if partitions.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
-		return nil, fmt.Errorf("%s", partitions.GetStatus().GetReason())
+
+	if err := merr.Error(partitions.GetStatus()); err != nil {
+		return nil, err
 	}
 
 	if len(partitions.PartitionIDs) != len(partitions.PartitionNames) {
@@ -690,35 +695,6 @@ func parsePartitionsInfo(infos []*partitionInfo) *partitionInfos {
 	return result
 }
 
-func (m *MetaCache) updatePartitions(partitions *milvuspb.ShowPartitionsResponse, database, collectionName string) error {
-	// check partitionID, createdTimestamp and utcstamp has sam element numbers
-	if len(partitions.PartitionNames) != len(partitions.CreatedTimestamps) || len(partitions.PartitionNames) != len(partitions.CreatedUtcTimestamps) {
-		return merr.WrapErrParameterInvalidMsg("partition names and timestamps number is not aligned, response: %s", partitions.String())
-	}
-
-	_, dbOk := m.collInfo[database]
-	if !dbOk {
-		m.collInfo[database] = make(map[string]*collectionInfo)
-	}
-
-	_, ok := m.collInfo[database][collectionName]
-	if !ok {
-		m.collInfo[database][collectionName] = &collectionInfo{}
-	}
-
-	infos := lo.Map(partitions.GetPartitionIDs(), func(partitionID int64, idx int) *partitionInfo {
-		return &partitionInfo{
-			name:                partitions.PartitionNames[idx],
-			partitionID:         partitions.PartitionIDs[idx],
-			createdTimestamp:    partitions.CreatedTimestamps[idx],
-			createdUtcTimestamp: partitions.CreatedUtcTimestamps[idx],
-		}
-	})
-
-	m.collInfo[database][collectionName].partInfo = parsePartitionsInfo(infos)
-	return nil
-}
-
 func (m *MetaCache) RemoveCollection(ctx context.Context, database, collectionName string) {
 	m.mu.Lock()
 	defer m.mu.Unlock()
diff --git a/internal/proxy/meta_cache_test.go b/internal/proxy/meta_cache_test.go
index f087431841..ce21d4f155 100644
--- a/internal/proxy/meta_cache_test.go
+++ b/internal/proxy/meta_cache_test.go
@@ -69,7 +69,7 @@ func (m *MockRootCoordClientInterface) ShowPartitions(ctx context.Context, in *m
 	if m.Error {
 		return nil, errors.New("mocked error")
 	}
-	if in.CollectionName == "collection1" {
+	if in.CollectionName == "collection1" || in.CollectionID == 1 {
 		return &milvuspb.ShowPartitionsResponse{
 			Status:       merr.Success(),
 			PartitionIDs: []typeutil.UniqueID{1, 2},
@@ -78,7 +78,7 @@ func (m *MockRootCoordClientInterface) ShowPartitions(ctx context.Context, in *m
 			PartitionNames: []string{"par1", "par2"},
 		}, nil
 	}
-	if in.CollectionName == "collection2" {
+	if in.CollectionName == "collection2" || in.CollectionID == 2 {
 		return &milvuspb.ShowPartitionsResponse{
 			Status:       merr.Success(),
 			PartitionIDs: []typeutil.UniqueID{3, 4},
@@ -900,12 +900,6 @@ func TestMetaCache_Database(t *testing.T) {
 	assert.NoError(t, err)
 	assert.Equal(t, globalMetaCache.HasDatabase(ctx, dbName), false)
 
-	queryCoord.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{
-		Status:              merr.Success(),
-		CollectionIDs:       []UniqueID{1, 2},
-		InMemoryPercentages: []int64{100, 50},
-	}, nil)
-
 	_, err = globalMetaCache.GetCollectionInfo(ctx, dbName, "collection1", 1)
 	assert.NoError(t, err)
 	_, err = GetCachedCollectionSchema(ctx, dbName, "collection1")
diff --git a/internal/proxy/task_index_test.go b/internal/proxy/task_index_test.go
index d0626786c0..d7989fe274 100644
--- a/internal/proxy/task_index_test.go
+++ b/internal/proxy/task_index_test.go
@@ -57,10 +57,6 @@ func TestGetIndexStateTask_Execute(t *testing.T) {
 
 	rootCoord := newMockRootCoord()
 	queryCoord := getMockQueryCoord()
-	queryCoord.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{
-		Status:        merr.Success(),
-		CollectionIDs: []int64{},
-	}, nil)
 	datacoord := NewDataCoordMock()
 
 	gist := &getIndexStateTask{
@@ -75,7 +71,7 @@ func TestGetIndexStateTask_Execute(t *testing.T) {
 		rootCoord: rootCoord,
 		dataCoord: datacoord,
 		result: &milvuspb.GetIndexStateResponse{
-			Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError, Reason: "mock"},
+			Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError, Reason: "mock-1"},
 			State:  commonpb.IndexState_Unissued,
 		},
 		collectionID: collectionID,
@@ -83,7 +79,8 @@ func TestGetIndexStateTask_Execute(t *testing.T) {
 	shardMgr := newShardClientMgr()
 
 	// failed to get collection id.
-	_ = InitMetaCache(ctx, rootCoord, queryCoord, shardMgr)
+	err := InitMetaCache(ctx, rootCoord, queryCoord, shardMgr)
+	assert.NoError(t, err)
 	assert.Error(t, gist.Execute(ctx))
 
 	rootCoord.DescribeCollectionFunc = func(ctx context.Context, request *milvuspb.DescribeCollectionRequest, opts ...grpc.CallOption) (*milvuspb.DescribeCollectionResponse, error) {
@@ -95,6 +92,12 @@ func TestGetIndexStateTask_Execute(t *testing.T) {
 		}, nil
 	}
 
+	rootCoord.ShowPartitionsFunc = func(ctx context.Context, request *milvuspb.ShowPartitionsRequest, opts ...grpc.CallOption) (*milvuspb.ShowPartitionsResponse, error) {
+		return &milvuspb.ShowPartitionsResponse{
+			Status: merr.Success(),
+		}, nil
+	}
+
 	datacoord.GetIndexStateFunc = func(ctx context.Context, request *indexpb.GetIndexStateRequest, opts ...grpc.CallOption) (*indexpb.GetIndexStateResponse, error) {
 		return &indexpb.GetIndexStateResponse{
 			Status: merr.Success(),