diff --git a/internal/rootcoord/show_collection_task.go b/internal/rootcoord/show_collection_task.go index 959703aa20..36dfd16036 100644 --- a/internal/rootcoord/show_collection_task.go +++ b/internal/rootcoord/show_collection_task.go @@ -97,10 +97,15 @@ func (t *showCollectionTask) Execute(ctx context.Context) error { privilegeColls.Insert(util.AnyWord) return privilegeColls, nil } - // should list collection level built-in privilege group objects - if objectType != commonpb.ObjectType_Collection.String() && - !Params.RbacConfig.IsCollectionPrivilegeGroup(priv) { - continue + // should list collection level built-in privilege group or custom privilege group objects + if objectType != commonpb.ObjectType_Collection.String() { + customGroup, err := t.core.meta.IsCustomPrivilegeGroup(ctx, priv) + if err != nil { + return nil, err + } + if !customGroup && !Params.RbacConfig.IsCollectionPrivilegeGroup(priv) { + continue + } } collectionName := entity.GetObjectName() privilegeColls.Insert(collectionName) diff --git a/internal/rootcoord/show_collection_task_test.go b/internal/rootcoord/show_collection_task_test.go index 0695860915..e133c94ecf 100644 --- a/internal/rootcoord/show_collection_task_test.go +++ b/internal/rootcoord/show_collection_task_test.go @@ -425,6 +425,7 @@ func TestShowCollectionsAuth(t *testing.T) { CreateTime: tsoutil.GetCurrentTime(), }, }, nil).Once() + meta.EXPECT().IsCustomPrivilegeGroup(mock.Anything, util.PrivilegeNameForAPI(commonpb.ObjectPrivilege_PrivilegeGroupCollectionReadOnly.String())).Return(false, nil).Once() task := &showCollectionTask{ baseTask: newBaseTask(context.Background(), core), @@ -529,6 +530,7 @@ func TestShowCollectionsAuth(t *testing.T) { CreateTime: tsoutil.GetCurrentTime(), }, }, nil).Once() + meta.EXPECT().IsCustomPrivilegeGroup(mock.Anything, mock.Anything).Return(false, nil).Once() task := &showCollectionTask{ baseTask: newBaseTask(context.Background(), core), @@ -541,4 +543,56 @@ func TestShowCollectionsAuth(t *testing.T) { assert.Equal(t, 1, len(task.Rsp.GetCollectionNames())) assert.Equal(t, "a", task.Rsp.GetCollectionNames()[0]) }) + + t.Run("custom privilege group", func(t *testing.T) { + Params.Save(Params.CommonCfg.AuthorizationEnabled.Key, "true") + defer Params.Reset(Params.CommonCfg.AuthorizationEnabled.Key) + meta := mockrootcoord.NewIMetaTable(t) + core := newTestCore(withMeta(meta)) + + meta.EXPECT().SelectUser(mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Return([]*milvuspb.UserResult{ + { + User: &milvuspb.UserEntity{ + Name: "foo", + }, + Roles: []*milvuspb.RoleEntity{ + { + Name: "hoooo", + }, + }, + }, + }, nil).Once() + meta.EXPECT().SelectGrant(mock.Anything, mock.Anything, mock.Anything).Return([]*milvuspb.GrantEntity{ + { + Object: &milvuspb.ObjectEntity{Name: "custom_type"}, + Grantor: &milvuspb.GrantorEntity{ + Privilege: &milvuspb.PrivilegeEntity{ + Name: "privilege_group", + }, + }, + ObjectName: "test_collection", + }, + }, nil).Once() + meta.EXPECT().IsCustomPrivilegeGroup(mock.Anything, "privilege_group").Return(true, nil).Once() + meta.EXPECT().ListCollections(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return([]*model.Collection{ + { + DBID: 1, + CollectionID: 100, + Name: "test_collection", + CreateTime: tsoutil.GetCurrentTime(), + }, + }, nil).Once() + + task := &showCollectionTask{ + baseTask: newBaseTask(context.Background(), core), + Req: &milvuspb.ShowCollectionsRequest{DbName: "default"}, + Rsp: &milvuspb.ShowCollectionsResponse{}, + } + ctx := GetContext(context.Background(), "foo:root") + err := task.Execute(ctx) + assert.NoError(t, err) + assert.Equal(t, 1, len(task.Rsp.GetCollectionNames())) + assert.Equal(t, "test_collection", task.Rsp.GetCollectionNames()[0]) + }) }