diff --git a/internal/proxy/meta_cache.go b/internal/proxy/meta_cache.go index 1aa1098f2c..fac890f890 100644 --- a/internal/proxy/meta_cache.go +++ b/internal/proxy/meta_cache.go @@ -20,6 +20,8 @@ import ( "context" "fmt" "math/rand" + "strconv" + "strings" "sync" "time" @@ -64,6 +66,8 @@ type Cache interface { GetPartitions(ctx context.Context, database, collectionName string) (map[string]typeutil.UniqueID, error) // GetPartitionInfo get partition's info. GetPartitionInfo(ctx context.Context, database, collectionName string, partitionName string) (*partitionInfo, error) + // GetPartitionsIndex returns the partition names in partition key index order. + GetPartitionsIndex(ctx context.Context, database, collectionName string) ([]string, error) // GetCollectionSchema get collection's schema. GetCollectionSchema(ctx context.Context, database, collectionName string) (*schemapb.CollectionSchema, error) GetShards(ctx context.Context, withCache bool, database, collectionName string, collectionID int64) (map[string][]nodeInfo, error) @@ -92,13 +96,13 @@ type collectionBasicInfo struct { createdTimestamp uint64 createdUtcTimestamp uint64 consistencyLevel commonpb.ConsistencyLevel - partInfo map[string]*partitionInfo } type collectionInfo struct { - collID typeutil.UniqueID - schema *schemapb.CollectionSchema - partInfo map[string]*partitionInfo + collID typeutil.UniqueID + schema *schemapb.CollectionSchema + // partInfo is the cached partition information; nil is treated as a cache miss by getPartitionInfos. + partInfo *partitionInfos leaderMutex sync.RWMutex shardLeaders *shardLeaders createdTimestamp uint64 @@ -106,6 +110,22 @@ type collectionInfo struct { consistencyLevel commonpb.ConsistencyLevel } +// partitionInfos contains the cached partition information of a collection. 
+type partitionInfos struct { + partitionInfos []*partitionInfo // all partitions, in the order passed to parsePartitionsInfo + name2Info map[string]*partitionInfo // partition name -> partition info + name2ID map[string]int64 // partition name -> partition ID + indexedPartitionNames []string // names placed at the index parsed from their numeric suffix; nil when the names do not follow that pattern +} + +// partitionInfo is the cached metadata of a single partition. +type partitionInfo struct { + name string + partitionID typeutil.UniqueID + createdTimestamp uint64 + createdUtcTimestamp uint64 +} + // getBasicInfo get a basic info by deep copy. func (info *collectionInfo) getBasicInfo() *collectionBasicInfo { // Do a deep copy for all fields. @@ -114,12 +134,8 @@ func (info *collectionInfo) getBasicInfo() *collectionBasicInfo { createdTimestamp: info.createdTimestamp, createdUtcTimestamp: info.createdUtcTimestamp, consistencyLevel: info.consistencyLevel, - partInfo: make(map[string]*partitionInfo, len(info.partInfo)), - } - for s, info := range info.partInfo { - info2 := *info - basicInfo.partInfo[s] = &info2 } + return basicInfo } @@ -180,12 +196,6 @@ func (sl *shardLeaders) GetReader() shardLeadersReader { } } -type partitionInfo struct { - partitionID typeutil.UniqueID - createdTimestamp uint64 - createdUtcTimestamp uint64 -} - // make sure MetaCache implements Cache. 
var _ Cache = (*MetaCache)(nil) @@ -451,66 +461,41 @@ func (m *MetaCache) GetPartitionID(ctx context.Context, database, collectionName } func (m *MetaCache) GetPartitions(ctx context.Context, database, collectionName string) (map[string]typeutil.UniqueID, error) { - _, err := m.GetCollectionID(ctx, database, collectionName) + partitions, err := m.getPartitionInfos(ctx, database, collectionName) if err != nil { return nil, err } - method := "GetPartitions" - m.mu.RLock() - - var collInfo *collectionInfo - var ok bool - db, dbOk := m.collInfo[database] - if dbOk { - collInfo, ok = db[collectionName] - } - - if !ok { - m.mu.RUnlock() - return nil, fmt.Errorf("can't find collection name %s:%s", database, collectionName) - } - - if collInfo.partInfo == nil || len(collInfo.partInfo) == 0 { - tr := timerecord.NewTimeRecorder("UpdateCache") - metrics.ProxyCacheStatsCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method, metrics.CacheMissLabel).Inc() - m.mu.RUnlock() - - partitions, err := m.showPartitions(ctx, database, collectionName) - if err != nil { - return nil, err - } - - m.mu.Lock() - defer m.mu.Unlock() - - err = m.updatePartitions(partitions, database, collectionName) - if err != nil { - return nil, err - } - metrics.ProxyUpdateCacheLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method).Observe(float64(tr.ElapseSpan().Milliseconds())) - log.Debug("proxy", zap.Any("GetPartitions:partitions after update", partitions), zap.String("collectionName", collectionName)) - ret := make(map[string]typeutil.UniqueID) - partInfo := m.collInfo[database][collectionName].partInfo - for k, v := range partInfo { - ret[k] = v.partitionID - } - return ret, nil - } - - defer m.mu.RUnlock() - metrics.ProxyCacheStatsCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method, metrics.CacheHitLabel).Inc() - - ret := make(map[string]typeutil.UniqueID) - partInfo := collInfo.partInfo - for k, v := range partInfo { - ret[k] = v.partitionID - } - - return ret, 
nil + return partitions.name2ID, nil } func (m *MetaCache) GetPartitionInfo(ctx context.Context, database, collectionName string, partitionName string) (*partitionInfo, error) { + partitions, err := m.getPartitionInfos(ctx, database, collectionName) + if err != nil { + return nil, err + } + + info, ok := partitions.name2Info[partitionName] + if !ok { + return nil, merr.WrapErrPartitionNotFound(partitionName) + } + return info, nil +} + +func (m *MetaCache) GetPartitionsIndex(ctx context.Context, database, collectionName string) ([]string, error) { + partitions, err := m.getPartitionInfos(ctx, database, collectionName) + if err != nil { + return nil, err + } + + if partitions.indexedPartitionNames == nil { + return nil, merr.WrapErrServiceInternal("partitions not in partition key naming pattern") + } + + return partitions.indexedPartitionNames, nil +} + +func (m *MetaCache) getPartitionInfos(ctx context.Context, database, collectionName string) (*partitionInfos, error) { _, err := m.GetCollectionID(ctx, database, collectionName) if err != nil { return nil, err @@ -529,12 +514,11 @@ func (m *MetaCache) GetPartitionInfo(ctx context.Context, database, collectionNa return nil, fmt.Errorf("can't find collection name %s:%s", database, collectionName) } - var partInfo *partitionInfo - partInfo, ok = collInfo.partInfo[partitionName] + partitionInfos := collInfo.partInfo m.mu.RUnlock() method := "GetPartitionInfo" - if !ok { + if partitionInfos == nil { tr := timerecord.NewTimeRecorder("UpdateCache") metrics.ProxyCacheStatsCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method, metrics.CacheMissLabel).Inc() partitions, err := m.showPartitions(ctx, database, collectionName) @@ -549,18 +533,11 @@ func (m *MetaCache) GetPartitionInfo(ctx context.Context, database, collectionNa return nil, err } metrics.ProxyUpdateCacheLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method).Observe(float64(tr.ElapseSpan().Milliseconds())) - log.Debug("proxy", 
zap.Any("GetPartitionID:partitions after update", partitions), zap.String("collectionName", collectionName)) - partInfo, ok = m.collInfo[database][collectionName].partInfo[partitionName] - if !ok { - return nil, merr.WrapErrPartitionNotFound(partitionName) - } + partitionInfos = m.collInfo[database][collectionName].partInfo + return partitionInfos, nil } metrics.ProxyCacheStatsCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method, metrics.CacheHitLabel).Inc() - return &partitionInfo{ - partitionID: partInfo.partitionID, - createdTimestamp: partInfo.createdTimestamp, - createdUtcTimestamp: partInfo.createdUtcTimestamp, - }, nil + return partitionInfos, nil } // Get the collection information from rootcoord. @@ -631,7 +608,50 @@ func (m *MetaCache) showPartitions(ctx context.Context, dbName string, collectio return partitions, nil } +// parsePartitionsInfo builds the cached partitionInfos from a list of partitionInfo. +// It always fills the name -> ID and name -> info lookup maps. +// indexedPartitionNames is set only when every name ends in "_<n>" and the suffixes are exactly 0..len(infos)-1 without duplicates; otherwise it stays nil. 
+func parsePartitionsInfo(infos []*partitionInfo) *partitionInfos { + name2ID := lo.SliceToMap(infos, func(info *partitionInfo) (string, int64) { + return info.name, info.partitionID + }) + name2Info := lo.SliceToMap(infos, func(info *partitionInfo) (string, *partitionInfo) { + return info.name, info + }) + + result := &partitionInfos{ + partitionInfos: infos, + name2ID: name2ID, + name2Info: name2Info, + } + + // Place every name at the index given by its numeric "_<n>" suffix so the returned order is the same on every call. + partitionNames := make([]string, len(infos)) + for _, info := range infos { + partitionName := info.name + splits := strings.Split(partitionName, "_") + if len(splits) < 2 { + log.Info("partition group not in partitionKey pattern", zap.String("partitionName", partitionName)) + return result + } + index, err := strconv.ParseInt(splits[len(splits)-1], 10, 64) + if err != nil || index < 0 || index >= int64(len(partitionNames)) || partitionNames[index] != "" { // a negative, out-of-range or duplicate index would panic or leave a hole, so treat it as "not in pattern" + log.Info("partition group not in partitionKey pattern", zap.String("partitionName", partitionName), zap.Error(err)) + return result + } + partitionNames[index] = partitionName + } + + result.indexedPartitionNames = partitionNames + return result +} + func (m *MetaCache) updatePartitions(partitions *milvuspb.ShowPartitionsResponse, database, collectionName string) error { + // check partition IDs, names, created timestamps and UTC timestamps all have the same number of elements + if len(partitions.PartitionNames) != len(partitions.PartitionIDs) || len(partitions.PartitionNames) != len(partitions.CreatedTimestamps) || len(partitions.PartitionNames) != len(partitions.CreatedUtcTimestamps) { + return merr.WrapErrParameterInvalidMsg("partition names and timestamps number is not aligned, response: %s", partitions.String()) + } + _, dbOk := m.collInfo[database] if !dbOk { m.collInfo[database] = make(map[string]*collectionInfo) @@ -639,30 +659,19 @@ func (m *MetaCache) updatePartitions(partitions *milvuspb.ShowPartitionsResponse _, ok := m.collInfo[database][collectionName] if !ok { - m.collInfo[database][collectionName] = &collectionInfo{ - partInfo: map[string]*partitionInfo{}, - } - } - partInfo := 
m.collInfo[database][collectionName].partInfo - if partInfo == nil { - partInfo = map[string]*partitionInfo{} + m.collInfo[database][collectionName] = &collectionInfo{} } - // check partitionID, createdTimestamp and utcstamp has sam element numbers - if len(partitions.PartitionNames) != len(partitions.CreatedTimestamps) || len(partitions.PartitionNames) != len(partitions.CreatedUtcTimestamps) { - return errors.New("partition names and timestamps number is not aligned, response " + partitions.String()) - } - - for i := 0; i < len(partitions.PartitionIDs); i++ { - if _, ok := partInfo[partitions.PartitionNames[i]]; !ok { - partInfo[partitions.PartitionNames[i]] = &partitionInfo{ - partitionID: partitions.PartitionIDs[i], - createdTimestamp: partitions.CreatedTimestamps[i], - createdUtcTimestamp: partitions.CreatedUtcTimestamps[i], - } + infos := lo.Map(partitions.GetPartitionIDs(), func(partitionID int64, idx int) *partitionInfo { + return &partitionInfo{ + name: partitions.PartitionNames[idx], + partitionID: partitions.PartitionIDs[idx], + createdTimestamp: partitions.CreatedTimestamps[idx], + createdUtcTimestamp: partitions.CreatedUtcTimestamps[idx], } - } - m.collInfo[database][collectionName].partInfo = partInfo + }) + + m.collInfo[database][collectionName].partInfo = parsePartitionsInfo(infos) return nil } @@ -709,7 +718,11 @@ func (m *MetaCache) RemovePartition(ctx context.Context, database, collectionNam if partInfo == nil { return } - delete(partInfo, partitionName) + filteredInfos := lo.Filter(partInfo.partitionInfos, func(info *partitionInfo, idx int) bool { + return info.name != partitionName + }) + + m.collInfo[database][collectionName].partInfo = parsePartitionsInfo(filteredInfos) } // GetCredentialInfo returns the credential related to provided username diff --git a/internal/proxy/meta_cache_test.go b/internal/proxy/meta_cache_test.go index 46efae8037..c662935617 100644 --- a/internal/proxy/meta_cache_test.go +++ b/internal/proxy/meta_cache_test.go @@ 
-260,7 +260,6 @@ func TestMetaCache_GetBasicCollectionInfo(t *testing.T) { _ = info.consistencyLevel _ = info.createdTimestamp _ = info.createdUtcTimestamp - _ = info.partInfo }() go func() { defer wg.Done() @@ -270,7 +269,6 @@ func TestMetaCache_GetBasicCollectionInfo(t *testing.T) { _ = info.consistencyLevel _ = info.createdTimestamp _ = info.createdUtcTimestamp - _ = info.partInfo }() wg.Wait() } diff --git a/internal/proxy/mock_cache.go b/internal/proxy/mock_cache.go index 89cc24b7e3..b2bbaddd01 100644 --- a/internal/proxy/mock_cache.go +++ b/internal/proxy/mock_cache.go @@ -504,6 +504,62 @@ func (_c *MockCache_GetPartitions_Call) RunAndReturn(run func(context.Context, s return _c } +// GetPartitionsIndex provides a mock function with given fields: ctx, database, collectionName +func (_m *MockCache) GetPartitionsIndex(ctx context.Context, database string, collectionName string) ([]string, error) { + ret := _m.Called(ctx, database, collectionName) + + var r0 []string + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string) ([]string, error)); ok { + return rf(ctx, database, collectionName) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string) []string); ok { + r0 = rf(ctx, database, collectionName) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]string) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok { + r1 = rf(ctx, database, collectionName) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockCache_GetPartitionsIndex_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetPartitionsIndex' +type MockCache_GetPartitionsIndex_Call struct { + *mock.Call +} + +// GetPartitionsIndex is a helper method to define mock.On call +// - ctx context.Context +// - database string +// - collectionName string +func (_e *MockCache_Expecter) GetPartitionsIndex(ctx interface{}, database interface{}, collectionName interface{}) 
*MockCache_GetPartitionsIndex_Call { + return &MockCache_GetPartitionsIndex_Call{Call: _e.mock.On("GetPartitionsIndex", ctx, database, collectionName)} +} + +func (_c *MockCache_GetPartitionsIndex_Call) Run(run func(ctx context.Context, database string, collectionName string)) *MockCache_GetPartitionsIndex_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(string)) + }) + return _c +} + +func (_c *MockCache_GetPartitionsIndex_Call) Return(_a0 []string, _a1 error) *MockCache_GetPartitionsIndex_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockCache_GetPartitionsIndex_Call) RunAndReturn(run func(context.Context, string, string) ([]string, error)) *MockCache_GetPartitionsIndex_Call { + _c.Call.Return(run) + return _c +} + // GetPrivilegeInfo provides a mock function with given fields: ctx func (_m *MockCache) GetPrivilegeInfo(ctx context.Context) []string { ret := _m.Called(ctx) @@ -650,6 +706,49 @@ func (_c *MockCache_GetUserRole_Call) RunAndReturn(run func(string) []string) *M return _c } +// HasDatabase provides a mock function with given fields: ctx, database +func (_m *MockCache) HasDatabase(ctx context.Context, database string) bool { + ret := _m.Called(ctx, database) + + var r0 bool + if rf, ok := ret.Get(0).(func(context.Context, string) bool); ok { + r0 = rf(ctx, database) + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + +// MockCache_HasDatabase_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'HasDatabase' +type MockCache_HasDatabase_Call struct { + *mock.Call +} + +// HasDatabase is a helper method to define mock.On call +// - ctx context.Context +// - database string +func (_e *MockCache_Expecter) HasDatabase(ctx interface{}, database interface{}) *MockCache_HasDatabase_Call { + return &MockCache_HasDatabase_Call{Call: _e.mock.On("HasDatabase", ctx, database)} +} + +func (_c *MockCache_HasDatabase_Call) Run(run func(ctx 
context.Context, database string)) *MockCache_HasDatabase_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string)) + }) + return _c +} + +func (_c *MockCache_HasDatabase_Call) Return(_a0 bool) *MockCache_HasDatabase_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockCache_HasDatabase_Call) RunAndReturn(run func(context.Context, string) bool) *MockCache_HasDatabase_Call { + _c.Call.Return(run) + return _c +} + // InitPolicyInfo provides a mock function with given fields: info, userRoles func (_m *MockCache) InitPolicyInfo(info []string, userRoles []string) { _m.Called(info, userRoles) @@ -844,10 +943,6 @@ func (_m *MockCache) RemoveDatabase(ctx context.Context, database string) { _m.Called(ctx, database) } -func (_m *MockCache) HasDatabase(ctx context.Context, database string) bool { - return true -} - // MockCache_RemoveDatabase_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RemoveDatabase' type MockCache_RemoveDatabase_Call struct { *mock.Call diff --git a/internal/proxy/task_delete_test.go b/internal/proxy/task_delete_test.go index bf8438a5d5..eddead02fc 100644 --- a/internal/proxy/task_delete_test.go +++ b/internal/proxy/task_delete_test.go @@ -616,6 +616,7 @@ func TestDeleteTask_Execute(t *testing.T) { partitionMaps["test_0"] = 1 partitionMaps["test_1"] = 2 partitionMaps["test_2"] = 3 + indexedPartitions := []string{"test_0", "test_1", "test_2"} t.Run("complex delete with partitionKey mode success", func(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) @@ -631,6 +632,8 @@ func TestDeleteTask_Execute(t *testing.T) { partitionMaps, nil) mockCache.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything, mock.Anything).Return( schema, nil) + mockCache.EXPECT().GetPartitionsIndex(mock.Anything, mock.Anything, mock.Anything). 
+ Return(indexedPartitions, nil) globalMetaCache = mockCache defer func() { globalMetaCache = nil }() @@ -729,6 +732,7 @@ func TestDeleteTask_StreamingQueryAndDelteFunc(t *testing.T) { partitionMaps["test_0"] = 1 partitionMaps["test_1"] = 2 partitionMaps["test_2"] = 3 + indexedPartitions := []string{"test_0", "test_1", "test_2"} t.Run("partitionKey mode parse plan failed", func(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -785,8 +789,8 @@ func TestDeleteTask_StreamingQueryAndDelteFunc(t *testing.T) { qn := mocks.NewMockQueryNodeClient(t) mockCache := NewMockCache(t) - mockCache.EXPECT().GetPartitions(mock.Anything, mock.Anything, mock.Anything).Return( - nil, fmt.Errorf("mock error")) + mockCache.EXPECT().GetPartitionsIndex(mock.Anything, mock.Anything, mock.Anything). + Return(nil, fmt.Errorf("mock error")) globalMetaCache = mockCache defer func() { globalMetaCache = nil }() @@ -823,8 +827,8 @@ func TestDeleteTask_StreamingQueryAndDelteFunc(t *testing.T) { qn := mocks.NewMockQueryNodeClient(t) mockCache := NewMockCache(t) - mockCache.EXPECT().GetPartitions(mock.Anything, mock.Anything, mock.Anything).Return( - partitionMaps, nil).Once() + mockCache.EXPECT().GetPartitionsIndex(mock.Anything, mock.Anything, mock.Anything). 
+ Return(indexedPartitions, nil) mockCache.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything, mock.Anything).Return( schema, nil) mockCache.EXPECT().GetPartitions(mock.Anything, mock.Anything, mock.Anything).Return( diff --git a/internal/proxy/task_statistic.go b/internal/proxy/task_statistic.go index 2b4104c082..e423068829 100644 --- a/internal/proxy/task_statistic.go +++ b/internal/proxy/task_statistic.go @@ -320,6 +320,10 @@ func checkFullLoaded(ctx context.Context, qc types.QueryCoordClient, dbName stri if err != nil { return nil, nil, fmt.Errorf("GetCollectionInfo failed, dbName = %s, collectionName = %s,collectionID = %d, err = %s", dbName, collectionName, collectionID, err) } + partitionInfos, err := globalMetaCache.GetPartitions(ctx, dbName, collectionName) + if err != nil { + return nil, nil, fmt.Errorf("GetPartitions failed, dbName = %s, collectionName = %s,collectionID = %d, err = %s", dbName, collectionName, collectionID, err) + } // If request to search partitions if len(searchPartitionIDs) > 0 { @@ -372,11 +376,12 @@ func checkFullLoaded(ctx context.Context, qc types.QueryCoordClient, dbName stri } } - for _, partInfo := range info.partInfo { - if _, ok := loadedMap[partInfo.partitionID]; !ok { - unloadPartitionIDs = append(unloadPartitionIDs, partInfo.partitionID) + for _, partitionID := range partitionInfos { + if _, ok := loadedMap[partitionID]; !ok { + unloadPartitionIDs = append(unloadPartitionIDs, partitionID) } } + return loadedPartitionIDs, unloadPartitionIDs, nil } diff --git a/internal/proxy/util.go b/internal/proxy/util.go index aea982bda2..8ffef9c16f 100644 --- a/internal/proxy/util.go +++ b/internal/proxy/util.go @@ -1427,7 +1427,7 @@ func assignChannelsByPK(pks *schemapb.IDs, channelNames []string, insertMsg *msg } func assignPartitionKeys(ctx context.Context, dbName string, collName string, keys []*planpb.GenericValue) ([]string, error) { - partitionNames, err := getDefaultPartitionNames(ctx, dbName, collName) + partitionNames, 
err := globalMetaCache.GetPartitionsIndex(ctx, dbName, collName) if err != nil { return nil, err }