From f9823e039f18fd1baa85cf9ef0eb366c812c865d Mon Sep 17 00:00:00 2001 From: wei liu Date: Thu, 10 Aug 2023 18:43:16 +0800 Subject: [PATCH] fix describe rg with non exist collection (#26227) Signed-off-by: Wei Liu --- internal/proxy/meta_cache.go | 6 +-- internal/proxy/meta_cache_test.go | 6 +-- .../{mock_cache_test.go => mock_cache.go} | 31 +++++++------- internal/proxy/task.go | 40 +++++++++++++------ internal/proxy/task_test.go | 15 +++++++ 5 files changed, 64 insertions(+), 34 deletions(-) rename internal/proxy/{mock_cache_test.go => mock_cache.go} (97%) diff --git a/internal/proxy/meta_cache.go b/internal/proxy/meta_cache.go index 148f60c3cc..c0749d9db0 100644 --- a/internal/proxy/meta_cache.go +++ b/internal/proxy/meta_cache.go @@ -56,7 +56,7 @@ type Cache interface { // GetCollectionID get collection's id by name. GetCollectionID(ctx context.Context, database, collectionName string) (typeutil.UniqueID, error) // GetCollectionName get collection's name and database by id - GetCollectionName(ctx context.Context, collectionID int64) (string, error) + GetCollectionName(ctx context.Context, database string, collectionID int64) (string, error) // GetCollectionInfo get collection's information by name or collection id, such as schema, and etc. GetCollectionInfo(ctx context.Context, database, collectionName string, collectionID int64) (*collectionBasicInfo, error) // GetPartitionID get partition's identifier of specific collection. @@ -277,7 +277,7 @@ func (m *MetaCache) GetCollectionID(ctx context.Context, database, collectionNam } // GetCollectionName returns the corresponding collection name for provided collection id -func (m *MetaCache) GetCollectionName(ctx context.Context, collectionID int64) (string, error) { +func (m *MetaCache) GetCollectionName(ctx context.Context, database string, collectionID int64) (string, error) { m.mu.RLock() var collInfo *collectionInfo for _, db := range m.collInfo { @@ -294,7 +294,7 @@ func (m *MetaCache) GetCollectionName(ctx context.Context, collectionID int64) ( metrics.ProxyCacheStatsCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method, metrics.CacheMissLabel).Inc() tr := timerecord.NewTimeRecorder("UpdateCache") m.mu.RUnlock() - coll, err := m.describeCollection(ctx, "", "", collectionID) + coll, err := m.describeCollection(ctx, database, "", collectionID) if err != nil { return "", err } diff --git a/internal/proxy/meta_cache_test.go b/internal/proxy/meta_cache_test.go index ac658c3d2e..8789242503 100644 --- a/internal/proxy/meta_cache_test.go +++ b/internal/proxy/meta_cache_test.go @@ -303,7 +303,7 @@ func TestMetaCache_GetCollectionName(t *testing.T) { err := InitMetaCache(ctx, rootCoord, queryCoord, mgr) assert.NoError(t, err) - collection, err := globalMetaCache.GetCollectionName(ctx, 1) + collection, err := globalMetaCache.GetCollectionName(ctx, GetCurDBNameFromContextOrDefault(ctx), 1) assert.NoError(t, err) assert.Equal(t, collection, "collection1") assert.Equal(t, rootCoord.GetAccessCount(), 1) @@ -317,7 +317,7 @@ func TestMetaCache_GetCollectionName(t *testing.T) { Fields: []*schemapb.FieldSchema{}, Name: "collection1", }) - collection, err = globalMetaCache.GetCollectionName(ctx, 1) + collection, err = globalMetaCache.GetCollectionName(ctx, GetCurDBNameFromContextOrDefault(ctx), 1) assert.Equal(t, rootCoord.GetAccessCount(), 1) assert.NoError(t, err) assert.Equal(t, collection, "collection1") @@ -331,7 +331,7 @@ func TestMetaCache_GetCollectionName(t *testing.T) { }) // test to get from cache, this should trigger root request - collection, err = globalMetaCache.GetCollectionName(ctx, 1) + collection, err = globalMetaCache.GetCollectionName(ctx, GetCurDBNameFromContextOrDefault(ctx), 1) assert.Equal(t, rootCoord.GetAccessCount(), 2) assert.NoError(t, err) assert.Equal(t, collection, "collection1") diff --git a/internal/proxy/mock_cache_test.go b/internal/proxy/mock_cache.go similarity index 97% rename from internal/proxy/mock_cache_test.go rename to internal/proxy/mock_cache.go index 0e457ed524..04d71f92b8 100644 --- a/internal/proxy/mock_cache_test.go +++ b/internal/proxy/mock_cache.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.23.1. DO NOT EDIT. +// Code generated by mockery v2.21.1. DO NOT EDIT. package proxy @@ -171,23 +171,23 @@ func (_c *MockCache_GetCollectionInfo_Call) RunAndReturn(run func(context.Contex return _c } -// GetCollectionName provides a mock function with given fields: ctx, collectionID -func (_m *MockCache) GetCollectionName(ctx context.Context, collectionID int64) (string, error) { - ret := _m.Called(ctx, collectionID) +// GetCollectionName provides a mock function with given fields: ctx, database, collectionID +func (_m *MockCache) GetCollectionName(ctx context.Context, database string, collectionID int64) (string, error) { + ret := _m.Called(ctx, database, collectionID) var r0 string var r1 error - if rf, ok := ret.Get(0).(func(context.Context, int64) (string, error)); ok { - return rf(ctx, collectionID) + if rf, ok := ret.Get(0).(func(context.Context, string, int64) (string, error)); ok { + return rf(ctx, database, collectionID) } - if rf, ok := ret.Get(0).(func(context.Context, int64) string); ok { - r0 = rf(ctx, collectionID) + if rf, ok := ret.Get(0).(func(context.Context, string, int64) string); ok { + r0 = rf(ctx, database, collectionID) } else { r0 = ret.Get(0).(string) } - if rf, ok := ret.Get(1).(func(context.Context, int64) error); ok { - r1 = rf(ctx, collectionID) + if rf, ok := ret.Get(1).(func(context.Context, string, int64) error); ok { + r1 = rf(ctx, database, collectionID) } else { r1 = ret.Error(1) } @@ -202,14 +202,15 @@ type MockCache_GetCollectionName_Call struct { // GetCollectionName is a helper method to define mock.On call // - ctx context.Context +// - database string // - collectionID int64 -func (_e *MockCache_Expecter) GetCollectionName(ctx interface{}, collectionID interface{}) *MockCache_GetCollectionName_Call { - return &MockCache_GetCollectionName_Call{Call: _e.mock.On("GetCollectionName", ctx, collectionID)} +func (_e *MockCache_Expecter) GetCollectionName(ctx interface{}, database interface{}, collectionID interface{}) *MockCache_GetCollectionName_Call { + return &MockCache_GetCollectionName_Call{Call: _e.mock.On("GetCollectionName", ctx, database, collectionID)} } -func (_c *MockCache_GetCollectionName_Call) Run(run func(ctx context.Context, collectionID int64)) *MockCache_GetCollectionName_Call { +func (_c *MockCache_GetCollectionName_Call) Run(run func(ctx context.Context, database string, collectionID int64)) *MockCache_GetCollectionName_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(int64)) + run(args[0].(context.Context), args[1].(string), args[2].(int64)) }) return _c } @@ -219,7 +220,7 @@ func (_c *MockCache_GetCollectionName_Call) Return(_a0 string, _a1 error) *MockC return _c } -func (_c *MockCache_GetCollectionName_Call) RunAndReturn(run func(context.Context, int64) (string, error)) *MockCache_GetCollectionName_Call { +func (_c *MockCache_GetCollectionName_Call) RunAndReturn(run func(context.Context, string, int64) (string, error)) *MockCache_GetCollectionName_Call { _c.Call.Return(run) return _c } diff --git a/internal/proxy/task.go b/internal/proxy/task.go index 0c9cc5b235..13585a9a11 100644 --- a/internal/proxy/task.go +++ b/internal/proxy/task.go @@ -23,7 +23,6 @@ import ( "github.com/cockroachdb/errors" "github.com/golang/protobuf/proto" - "github.com/samber/lo" "go.uber.org/zap" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" @@ -2209,21 +2208,36 @@ func (t *DescribeResourceGroupTask) Execute(ctx context.Context) error { return err } - getCollectionNameFunc := func(value int32, key int64) string { - name, err := globalMetaCache.GetCollectionName(ctx, key) - if err != nil { - // unreachable logic path - return "unavailable_collection" + getCollectionName := func(collections map[int64]int32) (map[string]int32, error) { + ret := make(map[string]int32) + for key, value := range collections { + name, err := globalMetaCache.GetCollectionName(ctx, GetCurDBNameFromContextOrDefault(ctx), key) + if err != nil { + log.Warn("failed to get collection name", + zap.Int64("collectionID", key), + zap.Error(err)) + return nil, err + } + ret[name] = value } - return name + return ret, nil } if resp.Status.ErrorCode == commonpb.ErrorCode_Success { rgInfo := resp.GetResourceGroup() - loadReplicas := lo.MapKeys(rgInfo.NumLoadedReplica, getCollectionNameFunc) - outgoingNodes := lo.MapKeys(rgInfo.NumOutgoingNode, getCollectionNameFunc) - incomingNodes := lo.MapKeys(rgInfo.NumIncomingNode, getCollectionNameFunc) + numLoadedReplica, err := getCollectionName(rgInfo.NumLoadedReplica) + if err != nil { + return err + } + numOutgoingNode, err := getCollectionName(rgInfo.NumOutgoingNode) + if err != nil { + return err + } + numIncomingNode, err := getCollectionName(rgInfo.NumIncomingNode) + if err != nil { + return err + } t.result = &milvuspb.DescribeResourceGroupResponse{ Status: resp.Status, @@ -2231,9 +2245,9 @@ func (t *DescribeResourceGroupTask) Execute(ctx context.Context) error { Name: rgInfo.GetName(), Capacity: rgInfo.GetCapacity(), NumAvailableNode: rgInfo.NumAvailableNode, - NumLoadedReplica: loadReplicas, - NumOutgoingNode: outgoingNodes, - NumIncomingNode: incomingNodes, + NumLoadedReplica: numLoadedReplica, + NumOutgoingNode: numOutgoingNode, + NumIncomingNode: numIncomingNode, }, } } else { diff --git a/internal/proxy/task_test.go b/internal/proxy/task_test.go index f11196869b..b689eae625 100644 --- a/internal/proxy/task_test.go +++ b/internal/proxy/task_test.go @@ -3101,6 +3101,21 @@ func TestDescribeResourceGroupTaskFailed(t *testing.T) { err := task.Execute(ctx) assert.NoError(t, err) assert.Equal(t, commonpb.ErrorCode_UnexpectedError, task.result.Status.ErrorCode) + + qc.ExpectedCalls = nil + qc.EXPECT().Stop().Return(nil) + qc.EXPECT().DescribeResourceGroup(mock.Anything, mock.Anything).Return(&querypb.DescribeResourceGroupResponse{ + Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, + ResourceGroup: &querypb.ResourceGroupInfo{ + Name: "rg", + Capacity: 2, + NumAvailableNode: 1, + NumOutgoingNode: map[int64]int32{3: 1}, + NumIncomingNode: map[int64]int32{4: 2}, + }, + }, nil) + err = task.Execute(ctx) + assert.Error(t, err) } func TestCreateCollectionTaskWithPartitionKey(t *testing.T) {