enhance: refine proxy meta cache partition logic (#29315)

See also #29113

- Unify partition info refresh logic
- Avoid re-parsing partition names for every partition-key search request (a minimal sketch of the naming pattern follows below)
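
For illustration only (not part of this commit): a minimal, self-contained Go sketch of the partition-key ordering that the cache now precomputes once per refresh instead of per request. It assumes partition names follow the "<prefix>_<index>" pattern; the helper name indexPartitionNames is hypothetical and the bounds checks are a simplification of the committed logic.

package main

import (
    "fmt"
    "strconv"
    "strings"
)

// indexPartitionNames orders names by their trailing numeric suffix and
// returns nil if any name does not match the partition-key naming pattern.
func indexPartitionNames(names []string) []string {
    ordered := make([]string, len(names))
    for _, name := range names {
        parts := strings.Split(name, "_")
        if len(parts) < 2 {
            return nil // not in partition-key naming pattern
        }
        idx, err := strconv.ParseInt(parts[len(parts)-1], 10, 64)
        if err != nil || idx < 0 || int(idx) >= len(names) {
            return nil
        }
        ordered[idx] = name
    }
    return ordered
}

func main() {
    // Partition-key partitions are typically created as "<collection>_<index>".
    fmt.Println(indexPartitionNames([]string{"demo_2", "demo_0", "demo_1"}))
    // Output: [demo_0 demo_1 demo_2]
}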

---------

Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
congqixia 2023-12-20 10:02:43 +08:00 committed by GitHub
parent 89b208d27a
commit bcf8f27aa7
6 changed files with 232 additions and 117 deletions


@@ -20,6 +20,8 @@ import (
"context"
"fmt"
"math/rand"
"strconv"
"strings"
"sync"
"time"
@@ -64,6 +66,8 @@ type Cache interface {
GetPartitions(ctx context.Context, database, collectionName string) (map[string]typeutil.UniqueID, error)
// GetPartitionInfo get partition's info.
GetPartitionInfo(ctx context.Context, database, collectionName string, partitionName string) (*partitionInfo, error)
// GetPartitionsIndex returns partition names in partition key indexed order.
GetPartitionsIndex(ctx context.Context, database, collectionName string) ([]string, error)
// GetCollectionSchema get collection's schema.
GetCollectionSchema(ctx context.Context, database, collectionName string) (*schemapb.CollectionSchema, error)
GetShards(ctx context.Context, withCache bool, database, collectionName string, collectionID int64) (map[string][]nodeInfo, error)
@@ -92,13 +96,13 @@ type collectionBasicInfo struct {
createdTimestamp uint64
createdUtcTimestamp uint64
consistencyLevel commonpb.ConsistencyLevel
partInfo map[string]*partitionInfo
}
type collectionInfo struct {
collID typeutil.UniqueID
schema *schemapb.CollectionSchema
partInfo map[string]*partitionInfo
collID typeutil.UniqueID
schema *schemapb.CollectionSchema
// partInfo map[string]*partitionInfo
partInfo *partitionInfos
leaderMutex sync.RWMutex
shardLeaders *shardLeaders
createdTimestamp uint64
@@ -106,6 +110,22 @@ type collectionInfo struct {
consistencyLevel commonpb.ConsistencyLevel
}
// partitionInfos contains the cached partition information of a collection.
type partitionInfos struct {
partitionInfos []*partitionInfo
name2Info map[string]*partitionInfo // partition name -> partition info
name2ID map[string]int64 // partition name -> partition ID
indexedPartitionNames []string
}
// partitionInfo is the model for a single partition's information.
type partitionInfo struct {
name string
partitionID typeutil.UniqueID
createdTimestamp uint64
createdUtcTimestamp uint64
}
// getBasicInfo get a basic info by deep copy.
func (info *collectionInfo) getBasicInfo() *collectionBasicInfo {
// Do a deep copy for all fields.
@@ -114,12 +134,8 @@ func (info *collectionInfo) getBasicInfo() *collectionBasicInfo {
createdTimestamp: info.createdTimestamp,
createdUtcTimestamp: info.createdUtcTimestamp,
consistencyLevel: info.consistencyLevel,
partInfo: make(map[string]*partitionInfo, len(info.partInfo)),
}
for s, info := range info.partInfo {
info2 := *info
basicInfo.partInfo[s] = &info2
}
return basicInfo
}
@@ -180,12 +196,6 @@ func (sl *shardLeaders) GetReader() shardLeadersReader {
}
}
type partitionInfo struct {
partitionID typeutil.UniqueID
createdTimestamp uint64
createdUtcTimestamp uint64
}
// make sure MetaCache implements Cache.
var _ Cache = (*MetaCache)(nil)
@@ -451,66 +461,41 @@ func (m *MetaCache) GetPartitionID(ctx context.Context, database, collectionName
}
func (m *MetaCache) GetPartitions(ctx context.Context, database, collectionName string) (map[string]typeutil.UniqueID, error) {
_, err := m.GetCollectionID(ctx, database, collectionName)
partitions, err := m.getPartitionInfos(ctx, database, collectionName)
if err != nil {
return nil, err
}
method := "GetPartitions"
m.mu.RLock()
var collInfo *collectionInfo
var ok bool
db, dbOk := m.collInfo[database]
if dbOk {
collInfo, ok = db[collectionName]
}
if !ok {
m.mu.RUnlock()
return nil, fmt.Errorf("can't find collection name %s:%s", database, collectionName)
}
if collInfo.partInfo == nil || len(collInfo.partInfo) == 0 {
tr := timerecord.NewTimeRecorder("UpdateCache")
metrics.ProxyCacheStatsCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method, metrics.CacheMissLabel).Inc()
m.mu.RUnlock()
partitions, err := m.showPartitions(ctx, database, collectionName)
if err != nil {
return nil, err
}
m.mu.Lock()
defer m.mu.Unlock()
err = m.updatePartitions(partitions, database, collectionName)
if err != nil {
return nil, err
}
metrics.ProxyUpdateCacheLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method).Observe(float64(tr.ElapseSpan().Milliseconds()))
log.Debug("proxy", zap.Any("GetPartitions:partitions after update", partitions), zap.String("collectionName", collectionName))
ret := make(map[string]typeutil.UniqueID)
partInfo := m.collInfo[database][collectionName].partInfo
for k, v := range partInfo {
ret[k] = v.partitionID
}
return ret, nil
}
defer m.mu.RUnlock()
metrics.ProxyCacheStatsCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method, metrics.CacheHitLabel).Inc()
ret := make(map[string]typeutil.UniqueID)
partInfo := collInfo.partInfo
for k, v := range partInfo {
ret[k] = v.partitionID
}
return ret, nil
return partitions.name2ID, nil
}
func (m *MetaCache) GetPartitionInfo(ctx context.Context, database, collectionName string, partitionName string) (*partitionInfo, error) {
partitions, err := m.getPartitionInfos(ctx, database, collectionName)
if err != nil {
return nil, err
}
info, ok := partitions.name2Info[partitionName]
if !ok {
return nil, merr.WrapErrPartitionNotFound(partitionName)
}
return info, nil
}
func (m *MetaCache) GetPartitionsIndex(ctx context.Context, database, collectionName string) ([]string, error) {
partitions, err := m.getPartitionInfos(ctx, database, collectionName)
if err != nil {
return nil, err
}
if partitions.indexedPartitionNames == nil {
return nil, merr.WrapErrServiceInternal("partitions not in partition key naming pattern")
}
return partitions.indexedPartitionNames, nil
}
func (m *MetaCache) getPartitionInfos(ctx context.Context, database, collectionName string) (*partitionInfos, error) {
_, err := m.GetCollectionID(ctx, database, collectionName)
if err != nil {
return nil, err
@@ -529,12 +514,11 @@ func (m *MetaCache) GetPartitionInfo(ctx context.Context, database, collectionNa
return nil, fmt.Errorf("can't find collection name %s:%s", database, collectionName)
}
var partInfo *partitionInfo
partInfo, ok = collInfo.partInfo[partitionName]
partitionInfos := collInfo.partInfo
m.mu.RUnlock()
method := "GetPartitionInfo"
if !ok {
if partitionInfos == nil {
tr := timerecord.NewTimeRecorder("UpdateCache")
metrics.ProxyCacheStatsCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method, metrics.CacheMissLabel).Inc()
partitions, err := m.showPartitions(ctx, database, collectionName)
@@ -549,18 +533,11 @@ func (m *MetaCache) GetPartitionInfo(ctx context.Context, database, collectionNa
return nil, err
}
metrics.ProxyUpdateCacheLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method).Observe(float64(tr.ElapseSpan().Milliseconds()))
log.Debug("proxy", zap.Any("GetPartitionID:partitions after update", partitions), zap.String("collectionName", collectionName))
partInfo, ok = m.collInfo[database][collectionName].partInfo[partitionName]
if !ok {
return nil, merr.WrapErrPartitionNotFound(partitionName)
}
partitionInfos = m.collInfo[database][collectionName].partInfo
return partitionInfos, nil
}
metrics.ProxyCacheStatsCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method, metrics.CacheHitLabel).Inc()
return &partitionInfo{
partitionID: partInfo.partitionID,
createdTimestamp: partInfo.createdTimestamp,
createdUtcTimestamp: partInfo.createdUtcTimestamp,
}, nil
return partitionInfos, nil
}
// Get the collection information from rootcoord.
@@ -631,7 +608,50 @@ func (m *MetaCache) showPartitions(ctx context.Context, dbName string, collectio
return partitions, nil
}
// parsePartitionsInfo parses a partitionInfo list into a partitionInfos struct.
// It prepares the name-to-ID and name-to-info maps,
// and tries to parse the partition names into partitionKey index order.
func parsePartitionsInfo(infos []*partitionInfo) *partitionInfos {
name2ID := lo.SliceToMap(infos, func(info *partitionInfo) (string, int64) {
return info.name, info.partitionID
})
name2Info := lo.SliceToMap(infos, func(info *partitionInfo) (string, *partitionInfo) {
return info.name, info
})
result := &partitionInfos{
partitionInfos: infos,
name2ID: name2ID,
name2Info: name2Info,
}
// Make sure the partition names are returned in the same order every time
partitionNames := make([]string, len(infos))
for _, info := range infos {
partitionName := info.name
splits := strings.Split(partitionName, "_")
if len(splits) < 2 {
log.Info("partition group not in partitionKey pattern", zap.String("partitionName", partitionName))
return result
}
index, err := strconv.ParseInt(splits[len(splits)-1], 10, 64)
if err != nil {
log.Info("partition group not in partitionKey pattern", zap.String("parititonName", partitionName), zap.Error(err))
return result
}
partitionNames[index] = partitionName
}
result.indexedPartitionNames = partitionNames
return result
}
func (m *MetaCache) updatePartitions(partitions *milvuspb.ShowPartitionsResponse, database, collectionName string) error {
// check that partition names, createdTimestamps and createdUtcTimestamps have the same number of elements
if len(partitions.PartitionNames) != len(partitions.CreatedTimestamps) || len(partitions.PartitionNames) != len(partitions.CreatedUtcTimestamps) {
return merr.WrapErrParameterInvalidMsg("partition names and timestamps number is not aligned, response: %s", partitions.String())
}
_, dbOk := m.collInfo[database]
if !dbOk {
m.collInfo[database] = make(map[string]*collectionInfo)
@@ -639,30 +659,19 @@ func (m *MetaCache) updatePartitions(partitions *milvuspb.ShowPartitionsResponse
_, ok := m.collInfo[database][collectionName]
if !ok {
m.collInfo[database][collectionName] = &collectionInfo{
partInfo: map[string]*partitionInfo{},
}
}
partInfo := m.collInfo[database][collectionName].partInfo
if partInfo == nil {
partInfo = map[string]*partitionInfo{}
m.collInfo[database][collectionName] = &collectionInfo{}
}
// check partitionID, createdTimestamp and utcstamp has sam element numbers
if len(partitions.PartitionNames) != len(partitions.CreatedTimestamps) || len(partitions.PartitionNames) != len(partitions.CreatedUtcTimestamps) {
return errors.New("partition names and timestamps number is not aligned, response " + partitions.String())
}
for i := 0; i < len(partitions.PartitionIDs); i++ {
if _, ok := partInfo[partitions.PartitionNames[i]]; !ok {
partInfo[partitions.PartitionNames[i]] = &partitionInfo{
partitionID: partitions.PartitionIDs[i],
createdTimestamp: partitions.CreatedTimestamps[i],
createdUtcTimestamp: partitions.CreatedUtcTimestamps[i],
}
infos := lo.Map(partitions.GetPartitionIDs(), func(partitionID int64, idx int) *partitionInfo {
return &partitionInfo{
name: partitions.PartitionNames[idx],
partitionID: partitions.PartitionIDs[idx],
createdTimestamp: partitions.CreatedTimestamps[idx],
createdUtcTimestamp: partitions.CreatedUtcTimestamps[idx],
}
}
m.collInfo[database][collectionName].partInfo = partInfo
})
m.collInfo[database][collectionName].partInfo = parsePartitionsInfo(infos)
return nil
}
@@ -709,7 +718,11 @@ func (m *MetaCache) RemovePartition(ctx context.Context, database, collectionNam
if partInfo == nil {
return
}
delete(partInfo, partitionName)
filteredInfos := lo.Filter(partInfo.partitionInfos, func(info *partitionInfo, idx int) bool {
return info.name != partitionName
})
m.collInfo[database][collectionName].partInfo = parsePartitionsInfo(filteredInfos)
}
// GetCredentialInfo returns the credential related to provided username


@@ -260,7 +260,6 @@ func TestMetaCache_GetBasicCollectionInfo(t *testing.T) {
_ = info.consistencyLevel
_ = info.createdTimestamp
_ = info.createdUtcTimestamp
_ = info.partInfo
}()
go func() {
defer wg.Done()
@@ -270,7 +269,6 @@ func TestMetaCache_GetBasicCollectionInfo(t *testing.T) {
_ = info.consistencyLevel
_ = info.createdTimestamp
_ = info.createdUtcTimestamp
_ = info.partInfo
}()
wg.Wait()
}


@@ -504,6 +504,62 @@ func (_c *MockCache_GetPartitions_Call) RunAndReturn(run func(context.Context, s
return _c
}
// GetPartitionsIndex provides a mock function with given fields: ctx, database, collectionName
func (_m *MockCache) GetPartitionsIndex(ctx context.Context, database string, collectionName string) ([]string, error) {
ret := _m.Called(ctx, database, collectionName)
var r0 []string
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string, string) ([]string, error)); ok {
return rf(ctx, database, collectionName)
}
if rf, ok := ret.Get(0).(func(context.Context, string, string) []string); ok {
r0 = rf(ctx, database, collectionName)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]string)
}
}
if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok {
r1 = rf(ctx, database, collectionName)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockCache_GetPartitionsIndex_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetPartitionsIndex'
type MockCache_GetPartitionsIndex_Call struct {
*mock.Call
}
// GetPartitionsIndex is a helper method to define mock.On call
// - ctx context.Context
// - database string
// - collectionName string
func (_e *MockCache_Expecter) GetPartitionsIndex(ctx interface{}, database interface{}, collectionName interface{}) *MockCache_GetPartitionsIndex_Call {
return &MockCache_GetPartitionsIndex_Call{Call: _e.mock.On("GetPartitionsIndex", ctx, database, collectionName)}
}
func (_c *MockCache_GetPartitionsIndex_Call) Run(run func(ctx context.Context, database string, collectionName string)) *MockCache_GetPartitionsIndex_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(string), args[2].(string))
})
return _c
}
func (_c *MockCache_GetPartitionsIndex_Call) Return(_a0 []string, _a1 error) *MockCache_GetPartitionsIndex_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockCache_GetPartitionsIndex_Call) RunAndReturn(run func(context.Context, string, string) ([]string, error)) *MockCache_GetPartitionsIndex_Call {
_c.Call.Return(run)
return _c
}
// GetPrivilegeInfo provides a mock function with given fields: ctx
func (_m *MockCache) GetPrivilegeInfo(ctx context.Context) []string {
ret := _m.Called(ctx)
@@ -650,6 +706,49 @@ func (_c *MockCache_GetUserRole_Call) RunAndReturn(run func(string) []string) *M
return _c
}
// HasDatabase provides a mock function with given fields: ctx, database
func (_m *MockCache) HasDatabase(ctx context.Context, database string) bool {
ret := _m.Called(ctx, database)
var r0 bool
if rf, ok := ret.Get(0).(func(context.Context, string) bool); ok {
r0 = rf(ctx, database)
} else {
r0 = ret.Get(0).(bool)
}
return r0
}
// MockCache_HasDatabase_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'HasDatabase'
type MockCache_HasDatabase_Call struct {
*mock.Call
}
// HasDatabase is a helper method to define mock.On call
// - ctx context.Context
// - database string
func (_e *MockCache_Expecter) HasDatabase(ctx interface{}, database interface{}) *MockCache_HasDatabase_Call {
return &MockCache_HasDatabase_Call{Call: _e.mock.On("HasDatabase", ctx, database)}
}
func (_c *MockCache_HasDatabase_Call) Run(run func(ctx context.Context, database string)) *MockCache_HasDatabase_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(string))
})
return _c
}
func (_c *MockCache_HasDatabase_Call) Return(_a0 bool) *MockCache_HasDatabase_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockCache_HasDatabase_Call) RunAndReturn(run func(context.Context, string) bool) *MockCache_HasDatabase_Call {
_c.Call.Return(run)
return _c
}
// InitPolicyInfo provides a mock function with given fields: info, userRoles
func (_m *MockCache) InitPolicyInfo(info []string, userRoles []string) {
_m.Called(info, userRoles)
@@ -844,10 +943,6 @@ func (_m *MockCache) RemoveDatabase(ctx context.Context, database string) {
_m.Called(ctx, database)
}
func (_m *MockCache) HasDatabase(ctx context.Context, database string) bool {
return true
}
// MockCache_RemoveDatabase_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RemoveDatabase'
type MockCache_RemoveDatabase_Call struct {
*mock.Call


@@ -616,6 +616,7 @@ func TestDeleteTask_Execute(t *testing.T) {
partitionMaps["test_0"] = 1
partitionMaps["test_1"] = 2
partitionMaps["test_2"] = 3
indexedPartitions := []string{"test_0", "test_1", "test_2"}
t.Run("complex delete with partitionKey mode success", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
@@ -631,6 +632,8 @@ func TestDeleteTask_Execute(t *testing.T) {
partitionMaps, nil)
mockCache.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything, mock.Anything).Return(
schema, nil)
mockCache.EXPECT().GetPartitionsIndex(mock.Anything, mock.Anything, mock.Anything).
Return(indexedPartitions, nil)
globalMetaCache = mockCache
defer func() { globalMetaCache = nil }()
@@ -729,6 +732,7 @@ func TestDeleteTask_StreamingQueryAndDelteFunc(t *testing.T) {
partitionMaps["test_0"] = 1
partitionMaps["test_1"] = 2
partitionMaps["test_2"] = 3
indexedPartitions := []string{"test_0", "test_1", "test_2"}
t.Run("partitionKey mode parse plan failed", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
@@ -785,8 +789,8 @@ func TestDeleteTask_StreamingQueryAndDelteFunc(t *testing.T) {
qn := mocks.NewMockQueryNodeClient(t)
mockCache := NewMockCache(t)
mockCache.EXPECT().GetPartitions(mock.Anything, mock.Anything, mock.Anything).Return(
nil, fmt.Errorf("mock error"))
mockCache.EXPECT().GetPartitionsIndex(mock.Anything, mock.Anything, mock.Anything).
Return(nil, fmt.Errorf("mock error"))
globalMetaCache = mockCache
defer func() { globalMetaCache = nil }()
@@ -823,8 +827,8 @@ func TestDeleteTask_StreamingQueryAndDelteFunc(t *testing.T) {
qn := mocks.NewMockQueryNodeClient(t)
mockCache := NewMockCache(t)
mockCache.EXPECT().GetPartitions(mock.Anything, mock.Anything, mock.Anything).Return(
partitionMaps, nil).Once()
mockCache.EXPECT().GetPartitionsIndex(mock.Anything, mock.Anything, mock.Anything).
Return(indexedPartitions, nil)
mockCache.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything, mock.Anything).Return(
schema, nil)
mockCache.EXPECT().GetPartitions(mock.Anything, mock.Anything, mock.Anything).Return(


@@ -320,6 +320,10 @@ func checkFullLoaded(ctx context.Context, qc types.QueryCoordClient, dbName stri
if err != nil {
return nil, nil, fmt.Errorf("GetCollectionInfo failed, dbName = %s, collectionName = %s,collectionID = %d, err = %s", dbName, collectionName, collectionID, err)
}
partitionInfos, err := globalMetaCache.GetPartitions(ctx, dbName, collectionName)
if err != nil {
return nil, nil, fmt.Errorf("GetPartitions failed, dbName = %s, collectionName = %s,collectionID = %d, err = %s", dbName, collectionName, collectionID, err)
}
// If request to search partitions
if len(searchPartitionIDs) > 0 {
@@ -372,11 +376,12 @@ func checkFullLoaded(ctx context.Context, qc types.QueryCoordClient, dbName stri
}
}
for _, partInfo := range info.partInfo {
if _, ok := loadedMap[partInfo.partitionID]; !ok {
unloadPartitionIDs = append(unloadPartitionIDs, partInfo.partitionID)
for _, partitionID := range partitionInfos {
if _, ok := loadedMap[partitionID]; !ok {
unloadPartitionIDs = append(unloadPartitionIDs, partitionID)
}
}
return loadedPartitionIDs, unloadPartitionIDs, nil
}


@@ -1427,7 +1427,7 @@ func assignChannelsByPK(pks *schemapb.IDs, channelNames []string, insertMsg *msg
}
func assignPartitionKeys(ctx context.Context, dbName string, collName string, keys []*planpb.GenericValue) ([]string, error) {
partitionNames, err := getDefaultPartitionNames(ctx, dbName, collName)
partitionNames, err := globalMetaCache.GetPartitionsIndex(ctx, dbName, collName)
if err != nil {
return nil, err
}