diff --git a/go.mod b/go.mod index 19cd99ee57..2184dcc24f 100644 --- a/go.mod +++ b/go.mod @@ -263,7 +263,7 @@ replace ( github.com/apache/arrow/go/v12 => github.com/milvus-io/arrow/go/v12 v12.0.1 github.com/apache/pulsar-client-go => github.com/milvus-io/pulsar-client-go v0.12.1 github.com/bketelsen/crypt => github.com/bketelsen/crypt v0.0.4 // Fix security alert for core-os/etcd - github.com/expr-lang/expr => github.com/SimFG/expr v0.0.0-20250415035630-0728e795e4e9 + github.com/expr-lang/expr => github.com/SimFG/expr v0.0.0-20250513112851-9b981e8400b9 github.com/go-kit/kit => github.com/go-kit/kit v0.1.0 github.com/golang-jwt/jwt => github.com/golang-jwt/jwt/v4 v4.5.2 // indirect github.com/greatroar/blobloom => github.com/milvus-io/blobloom v0.0.0-20240603110411-471ae49f3b93 diff --git a/go.sum b/go.sum index ebdca6a5c0..d378f8290b 100644 --- a/go.sum +++ b/go.sum @@ -86,8 +86,8 @@ github.com/Knetic/govaluate v3.0.1-0.20171022003610-9aa49832a739+incompatible h1 github.com/Knetic/govaluate v3.0.1-0.20171022003610-9aa49832a739+incompatible/go.mod h1:r7JcOSlj0wfOMncg0iLm8Leh48TZaKVeNIfJntJ2wa0= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= github.com/Shopify/goreferrer v0.0.0-20181106222321-ec9c9a553398/go.mod h1:a1uqRtAwp2Xwc6WNPJEufxJ7fx3npB4UV/JOLmbu5I0= -github.com/SimFG/expr v0.0.0-20250415035630-0728e795e4e9 h1:p/1Prokv2YkGbcyLV/gOD28Gr3VgMXIa0c9ulg5KjOY= -github.com/SimFG/expr v0.0.0-20250415035630-0728e795e4e9/go.mod h1:8/vRC7+7HBzESEqt5kKpYXxrxkr31SaO8r40VO/1IT4= +github.com/SimFG/expr v0.0.0-20250513112851-9b981e8400b9 h1:eXnmJhsHt8m6NU3IJ19UthXJ8JK6e3tmfN07nym3BXs= +github.com/SimFG/expr v0.0.0-20250513112851-9b981e8400b9/go.mod h1:8/vRC7+7HBzESEqt5kKpYXxrxkr31SaO8r40VO/1IT4= github.com/actgardner/gogen-avro/v10 v10.1.0/go.mod h1:o+ybmVjEa27AAr35FRqU98DJu1fXES56uXniYFv4yDA= github.com/actgardner/gogen-avro/v10 v10.2.1/go.mod h1:QUhjeHPchheYmMDni/Nx7VB0RsT/ee8YIgGY/xpEQgQ= github.com/actgardner/gogen-avro/v9 v9.1.0/go.mod h1:nyTj6wPqDJoxM3qdnjcLv+EnMDSDFqE0qDpva2QRmKc= diff --git a/internal/proxy/task.go b/internal/proxy/task.go index 3872d004aa..a80b9e437d 100644 --- a/internal/proxy/task.go +++ b/internal/proxy/task.go @@ -638,7 +638,7 @@ func (t *describeCollectionTask) Execute(ctx context.Context) error { CollectionName: t.GetCollectionName(), DbName: t.GetDbName(), } - + ctx = AppendUserInfoForRPC(ctx) result, err := t.rootCoord.DescribeCollection(ctx, t.DescribeCollectionRequest) if err != nil { return err diff --git a/internal/proxy/task_database.go b/internal/proxy/task_database.go index 6d9a4a19df..8e81d63249 100644 --- a/internal/proxy/task_database.go +++ b/internal/proxy/task_database.go @@ -392,6 +392,7 @@ func (t *describeDatabaseTask) Execute(ctx context.Context) error { Base: t.DescribeDatabaseRequest.GetBase(), DbName: t.DescribeDatabaseRequest.GetDbName(), } + ctx = AppendUserInfoForRPC(ctx) ret, err := t.rootCoord.DescribeDatabase(ctx, req) if err != nil { log.Ctx(ctx).Warn("DescribeDatabase failed", zap.Error(err)) diff --git a/internal/rootcoord/describe_collection_task.go b/internal/rootcoord/describe_collection_task.go index b911d64ba0..d5c6e6d22b 100644 --- a/internal/rootcoord/describe_collection_task.go +++ b/internal/rootcoord/describe_collection_task.go @@ -21,6 +21,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus/pkg/v2/util/merr" ) // describeCollectionTask describe collection request task @@ -40,11 +41,25 @@ func (t *describeCollectionTask) Prepare(ctx context.Context) error { // Execute task execution func (t *describeCollectionTask) Execute(ctx context.Context) (err error) { + // if collecction name is not empty, check if the collection is visible to the current user coll, err := t.core.describeCollection(ctx, t.Req, t.allowUnavailable) if err != nil { return err } + if t.Req.GetCollectionName() != "" { + visibleCollections, err := t.core.getCurrentUserVisibleCollections(ctx, t.Req.GetDbName()) + if err != nil { + t.Rsp.Status = merr.Status(err) + return err + } + if !isVisibleCollectionForCurUser(coll.Name, visibleCollections) { + err = merr.WrapErrPrivilegeNotPermitted("not allowed to access collection, collection name: %s", t.Req.GetCollectionName()) + t.Rsp.Status = merr.Status(err) + return err + } + } + aliases := t.core.meta.ListAliasesByID(ctx, coll.CollectionID) db, err := t.core.meta.GetDatabaseByID(ctx, coll.DBID, t.GetTs()) if err != nil { diff --git a/internal/rootcoord/describe_collection_task_test.go b/internal/rootcoord/describe_collection_task_test.go index 9787762f74..38b0687543 100644 --- a/internal/rootcoord/describe_collection_task_test.go +++ b/internal/rootcoord/describe_collection_task_test.go @@ -20,6 +20,7 @@ import ( "context" "testing" + "github.com/cockroachdb/errors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" @@ -27,7 +28,9 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus/internal/metastore/model" mockrootcoord "github.com/milvus-io/milvus/internal/rootcoord/mocks" + "github.com/milvus-io/milvus/pkg/v2/util" "github.com/milvus-io/milvus/pkg/v2/util/funcutil" + "github.com/milvus-io/milvus/pkg/v2/util/paramtable" ) func Test_describeCollectionTask_Prepare(t *testing.T) { @@ -130,3 +133,513 @@ func Test_describeCollectionTask_Execute(t *testing.T) { assert.ElementsMatch(t, []string{alias1, alias2}, task.Rsp.GetAliases()) }) } + +func TestDescribeCollectionsAuth(t *testing.T) { + paramtable.Init() + + getTask := func(core *Core) *describeCollectionTask { + return &describeCollectionTask{ + baseTask: newBaseTask(context.Background(), core), + Req: &milvuspb.DescribeCollectionRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_DescribeCollection, + }, + DbName: "default", + CollectionName: "test coll", + }, + Rsp: &milvuspb.DescribeCollectionResponse{}, + } + } + + t.Run("no auth", func(t *testing.T) { + Params.Save(Params.CommonCfg.AuthorizationEnabled.Key, "false") + defer Params.Reset(Params.CommonCfg.AuthorizationEnabled.Key) + meta := mockrootcoord.NewIMetaTable(t) + core := newTestCore(withMeta(meta)) + + meta.EXPECT().GetCollectionByName(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&model.Collection{ + CollectionID: 1, + Name: "test coll", + DBID: 1, + }, nil).Once() + meta.EXPECT().ListAliasesByID(mock.Anything, mock.Anything).Return([]string{}).Once() + meta.EXPECT().GetDatabaseByID(mock.Anything, mock.Anything, mock.Anything).Return(&model.Database{ + ID: 1, + Name: "test db", + }, nil).Once() + + task := getTask(core) + + err := task.Execute(context.Background()) + assert.NoError(t, err) + assert.Equal(t, int32(0), task.Rsp.GetStatus().GetCode()) + assert.Equal(t, "test db", task.Rsp.GetDbName()) + assert.Equal(t, int64(1), task.Rsp.GetDbId()) + assert.Equal(t, "test coll", task.Rsp.GetCollectionName()) + assert.Equal(t, int64(1), task.Rsp.GetCollectionID()) + }) + + t.Run("empty ctx", func(t *testing.T) { + Params.Save(Params.CommonCfg.AuthorizationEnabled.Key, "true") + defer Params.Reset(Params.CommonCfg.AuthorizationEnabled.Key) + meta := mockrootcoord.NewIMetaTable(t) + core := newTestCore(withMeta(meta)) + + meta.EXPECT().GetCollectionByName(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&model.Collection{ + CollectionID: 1, + Name: "test coll", + DBID: 1, + }, nil).Once() + meta.EXPECT().ListAliasesByID(mock.Anything, mock.Anything).Return([]string{}).Once() + meta.EXPECT().GetDatabaseByID(mock.Anything, mock.Anything, mock.Anything).Return(&model.Database{ + ID: 1, + Name: "test db", + }, nil).Once() + + task := getTask(core) + + err := task.Execute(context.Background()) + assert.NoError(t, err) + assert.Equal(t, int32(0), task.Rsp.GetStatus().GetCode()) + assert.Equal(t, "test db", task.Rsp.GetDbName()) + assert.Equal(t, int64(1), task.Rsp.GetDbId()) + assert.Equal(t, "test coll", task.Rsp.GetCollectionName()) + assert.Equal(t, int64(1), task.Rsp.GetCollectionID()) + }) + + t.Run("root user", func(t *testing.T) { + Params.Save(Params.CommonCfg.AuthorizationEnabled.Key, "true") + defer Params.Reset(Params.CommonCfg.AuthorizationEnabled.Key) + meta := mockrootcoord.NewIMetaTable(t) + core := newTestCore(withMeta(meta)) + + meta.EXPECT().GetCollectionByName(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&model.Collection{ + CollectionID: 1, + Name: "test coll", + DBID: 1, + }, nil).Once() + meta.EXPECT().ListAliasesByID(mock.Anything, mock.Anything).Return([]string{}).Once() + meta.EXPECT().GetDatabaseByID(mock.Anything, mock.Anything, mock.Anything).Return(&model.Database{ + ID: 1, + Name: "test db", + }, nil).Once() + + task := getTask(core) + + ctx := GetContext(context.Background(), "root:root") + err := task.Execute(ctx) + assert.NoError(t, err) + assert.Equal(t, int32(0), task.Rsp.GetStatus().GetCode()) + assert.Equal(t, "test db", task.Rsp.GetDbName()) + assert.Equal(t, int64(1), task.Rsp.GetDbId()) + assert.Equal(t, "test coll", task.Rsp.GetCollectionName()) + assert.Equal(t, int64(1), task.Rsp.GetCollectionID()) + }) + + t.Run("root user, should bind role", func(t *testing.T) { + Params.Save(Params.CommonCfg.AuthorizationEnabled.Key, "true") + Params.Save(Params.CommonCfg.RootShouldBindRole.Key, "true") + defer Params.Reset(Params.CommonCfg.AuthorizationEnabled.Key) + defer Params.Reset(Params.CommonCfg.RootShouldBindRole.Key) + meta := mockrootcoord.NewIMetaTable(t) + core := newTestCore(withMeta(meta)) + + meta.EXPECT().GetCollectionByName(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&model.Collection{ + CollectionID: 1, + Name: "test coll", + DBID: 1, + }, nil).Once() + meta.EXPECT().SelectUser(mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Return([]*milvuspb.UserResult{ + { + User: &milvuspb.UserEntity{ + Name: "foo", + }, + Roles: []*milvuspb.RoleEntity{ + { + Name: "public", + }, + }, + }, + }, nil).Once() + + task := getTask(core) + + ctx := GetContext(context.Background(), "root:root") + err := task.Execute(ctx) + assert.Error(t, err) + }) + + t.Run("fail to select user", func(t *testing.T) { + Params.Save(Params.CommonCfg.AuthorizationEnabled.Key, "true") + defer Params.Reset(Params.CommonCfg.AuthorizationEnabled.Key) + meta := mockrootcoord.NewIMetaTable(t) + core := newTestCore(withMeta(meta)) + + meta.EXPECT().GetCollectionByName(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&model.Collection{ + CollectionID: 1, + Name: "test coll", + DBID: 1, + }, nil).Once() + meta.EXPECT().SelectUser(mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Return(nil, errors.New("mock error: select user")).Once() + + task := getTask(core) + + ctx := GetContext(context.Background(), "foo:root") + err := task.Execute(ctx) + assert.Error(t, err) + }) + + t.Run("no user", func(t *testing.T) { + Params.Save(Params.CommonCfg.AuthorizationEnabled.Key, "true") + defer Params.Reset(Params.CommonCfg.AuthorizationEnabled.Key) + meta := mockrootcoord.NewIMetaTable(t) + core := newTestCore(withMeta(meta)) + + meta.EXPECT().GetCollectionByName(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&model.Collection{ + CollectionID: 1, + Name: "test coll", + DBID: 1, + }, nil).Once() + meta.EXPECT().SelectUser(mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Return([]*milvuspb.UserResult{}, nil).Once() + + task := getTask(core) + + ctx := GetContext(context.Background(), "foo:root") + err := task.Execute(ctx) + assert.Error(t, err) + }) + + t.Run("admin role", func(t *testing.T) { + Params.Save(Params.CommonCfg.AuthorizationEnabled.Key, "true") + defer Params.Reset(Params.CommonCfg.AuthorizationEnabled.Key) + meta := mockrootcoord.NewIMetaTable(t) + core := newTestCore(withMeta(meta)) + + meta.EXPECT().SelectUser(mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Return([]*milvuspb.UserResult{ + { + User: &milvuspb.UserEntity{ + Name: "foo", + }, + Roles: []*milvuspb.RoleEntity{ + { + Name: "admin", + }, + }, + }, + }, nil).Once() + meta.EXPECT().GetCollectionByName(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&model.Collection{ + CollectionID: 1, + Name: "test coll", + DBID: 1, + }, nil).Once() + meta.EXPECT().ListAliasesByID(mock.Anything, mock.Anything).Return([]string{}).Once() + meta.EXPECT().GetDatabaseByID(mock.Anything, mock.Anything, mock.Anything).Return(&model.Database{ + ID: 1, + Name: "test db", + }, nil).Once() + + task := getTask(core) + ctx := GetContext(context.Background(), "foo:root") + err := task.Execute(ctx) + assert.NoError(t, err) + assert.Equal(t, int32(0), task.Rsp.GetStatus().GetCode()) + assert.Equal(t, "test db", task.Rsp.GetDbName()) + assert.Equal(t, int64(1), task.Rsp.GetDbId()) + assert.Equal(t, "test coll", task.Rsp.GetCollectionName()) + assert.Equal(t, int64(1), task.Rsp.GetCollectionID()) + }) + + t.Run("select grant error", func(t *testing.T) { + Params.Save(Params.CommonCfg.AuthorizationEnabled.Key, "true") + defer Params.Reset(Params.CommonCfg.AuthorizationEnabled.Key) + meta := mockrootcoord.NewIMetaTable(t) + core := newTestCore(withMeta(meta)) + + meta.EXPECT().GetCollectionByName(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&model.Collection{ + CollectionID: 1, + Name: "test coll", + DBID: 1, + }, nil).Once() + meta.EXPECT().SelectUser(mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Return([]*milvuspb.UserResult{ + { + User: &milvuspb.UserEntity{ + Name: "foo", + }, + Roles: []*milvuspb.RoleEntity{ + { + Name: "hoooo", + }, + }, + }, + }, nil).Once() + meta.EXPECT().SelectGrant(mock.Anything, mock.Anything, mock.Anything).Return(nil, errors.New("mock error: select grant")).Once() + + task := getTask(core) + ctx := GetContext(context.Background(), "foo:root") + err := task.Execute(ctx) + assert.Error(t, err) + }) + + t.Run("global all privilege", func(t *testing.T) { + Params.Save(Params.CommonCfg.AuthorizationEnabled.Key, "true") + defer Params.Reset(Params.CommonCfg.AuthorizationEnabled.Key) + meta := mockrootcoord.NewIMetaTable(t) + core := newTestCore(withMeta(meta)) + + meta.EXPECT().SelectUser(mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Return([]*milvuspb.UserResult{ + { + User: &milvuspb.UserEntity{ + Name: "foo", + }, + Roles: []*milvuspb.RoleEntity{ + { + Name: "hoooo", + }, + }, + }, + }, nil).Once() + meta.EXPECT().SelectGrant(mock.Anything, mock.Anything, mock.Anything).Return([]*milvuspb.GrantEntity{ + { + Object: &milvuspb.ObjectEntity{Name: commonpb.ObjectType_Global.String()}, + Grantor: &milvuspb.GrantorEntity{ + Privilege: &milvuspb.PrivilegeEntity{ + Name: util.PrivilegeNameForAPI(commonpb.ObjectPrivilege_PrivilegeAll.String()), + }, + }, + }, + }, nil).Once() + meta.EXPECT().GetCollectionByName(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&model.Collection{ + CollectionID: 1, + Name: "test coll", + DBID: 1, + }, nil).Once() + meta.EXPECT().ListAliasesByID(mock.Anything, mock.Anything).Return([]string{}).Once() + meta.EXPECT().GetDatabaseByID(mock.Anything, mock.Anything, mock.Anything).Return(&model.Database{ + ID: 1, + Name: "test db", + }, nil).Once() + + task := getTask(core) + ctx := GetContext(context.Background(), "foo:root") + err := task.Execute(ctx) + assert.NoError(t, err) + assert.Equal(t, int32(0), task.Rsp.GetStatus().GetCode()) + assert.Equal(t, "test db", task.Rsp.GetDbName()) + assert.Equal(t, int64(1), task.Rsp.GetDbId()) + assert.Equal(t, "test coll", task.Rsp.GetCollectionName()) + assert.Equal(t, int64(1), task.Rsp.GetCollectionID()) + }) + + t.Run("collection level privilege group", func(t *testing.T) { + Params.Save(Params.CommonCfg.AuthorizationEnabled.Key, "true") + defer Params.Reset(Params.CommonCfg.AuthorizationEnabled.Key) + meta := mockrootcoord.NewIMetaTable(t) + core := newTestCore(withMeta(meta)) + + meta.EXPECT().SelectUser(mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Return([]*milvuspb.UserResult{ + { + User: &milvuspb.UserEntity{ + Name: "foo", + }, + Roles: []*milvuspb.RoleEntity{ + { + Name: "hoooo", + }, + }, + }, + }, nil).Once() + meta.EXPECT().SelectGrant(mock.Anything, mock.Anything, mock.Anything).Return([]*milvuspb.GrantEntity{ + { + Object: &milvuspb.ObjectEntity{Name: commonpb.ObjectType_Global.String()}, + Grantor: &milvuspb.GrantorEntity{ + Privilege: &milvuspb.PrivilegeEntity{ + Name: util.PrivilegeNameForAPI(commonpb.ObjectPrivilege_PrivilegeGroupCollectionReadOnly.String()), + }, + }, + ObjectName: util.AnyWord, + }, + }, nil).Once() + meta.EXPECT().GetCollectionByName(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&model.Collection{ + CollectionID: 1, + Name: "test coll", + DBID: 1, + }, nil).Once() + meta.EXPECT().ListAliasesByID(mock.Anything, mock.Anything).Return([]string{}).Once() + meta.EXPECT().GetDatabaseByID(mock.Anything, mock.Anything, mock.Anything).Return(&model.Database{ + ID: 1, + Name: "test db", + }, nil).Once() + meta.EXPECT().IsCustomPrivilegeGroup(mock.Anything, util.PrivilegeNameForAPI(commonpb.ObjectPrivilege_PrivilegeGroupCollectionReadOnly.String())).Return(false, nil).Once() + + task := getTask(core) + ctx := GetContext(context.Background(), "foo:root") + err := task.Execute(ctx) + assert.NoError(t, err) + assert.Equal(t, int32(0), task.Rsp.GetStatus().GetCode()) + assert.Equal(t, "test db", task.Rsp.GetDbName()) + assert.Equal(t, int64(1), task.Rsp.GetDbId()) + assert.Equal(t, "test coll", task.Rsp.GetCollectionName()) + assert.Equal(t, int64(1), task.Rsp.GetCollectionID()) + }) + + t.Run("all collection", func(t *testing.T) { + Params.Save(Params.CommonCfg.AuthorizationEnabled.Key, "true") + defer Params.Reset(Params.CommonCfg.AuthorizationEnabled.Key) + meta := mockrootcoord.NewIMetaTable(t) + core := newTestCore(withMeta(meta)) + + meta.EXPECT().SelectUser(mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Return([]*milvuspb.UserResult{ + { + User: &milvuspb.UserEntity{ + Name: "foo", + }, + Roles: []*milvuspb.RoleEntity{ + { + Name: "hoooo", + }, + }, + }, + }, nil).Once() + meta.EXPECT().SelectGrant(mock.Anything, mock.Anything, mock.Anything).Return([]*milvuspb.GrantEntity{ + { + Object: &milvuspb.ObjectEntity{Name: commonpb.ObjectType_Collection.String()}, + ObjectName: util.AnyWord, + }, + }, nil).Once() + meta.EXPECT().GetCollectionByName(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&model.Collection{ + CollectionID: 1, + Name: "test coll", + DBID: 1, + }, nil).Once() + meta.EXPECT().ListAliasesByID(mock.Anything, mock.Anything).Return([]string{}).Once() + meta.EXPECT().GetDatabaseByID(mock.Anything, mock.Anything, mock.Anything).Return(&model.Database{ + ID: 1, + Name: "test db", + }, nil).Once() + + task := getTask(core) + ctx := GetContext(context.Background(), "foo:root") + err := task.Execute(ctx) + assert.NoError(t, err) + assert.Equal(t, int32(0), task.Rsp.GetStatus().GetCode()) + assert.Equal(t, "test db", task.Rsp.GetDbName()) + assert.Equal(t, int64(1), task.Rsp.GetDbId()) + assert.Equal(t, "test coll", task.Rsp.GetCollectionName()) + assert.Equal(t, int64(1), task.Rsp.GetCollectionID()) + }) + t.Run("normal", func(t *testing.T) { + Params.Save(Params.CommonCfg.AuthorizationEnabled.Key, "true") + defer Params.Reset(Params.CommonCfg.AuthorizationEnabled.Key) + meta := mockrootcoord.NewIMetaTable(t) + core := newTestCore(withMeta(meta)) + + meta.EXPECT().SelectUser(mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Return([]*milvuspb.UserResult{ + { + User: &milvuspb.UserEntity{ + Name: "foo", + }, + Roles: []*milvuspb.RoleEntity{ + { + Name: "hoooo", + }, + }, + }, + }, nil).Once() + meta.EXPECT().SelectGrant(mock.Anything, mock.Anything, mock.Anything).Return([]*milvuspb.GrantEntity{ + { + Object: &milvuspb.ObjectEntity{Name: commonpb.ObjectType_Collection.String()}, + ObjectName: "test coll", + }, + { + Object: &milvuspb.ObjectEntity{Name: commonpb.ObjectType_Global.String()}, + }, + { + Object: &milvuspb.ObjectEntity{Name: commonpb.ObjectType_Collection.String()}, + ObjectName: "b", + }, + }, nil).Once() + meta.EXPECT().GetCollectionByName(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&model.Collection{ + CollectionID: 1, + Name: "test coll", + DBID: 1, + }, nil).Once() + meta.EXPECT().ListAliasesByID(mock.Anything, mock.Anything).Return([]string{}).Once() + meta.EXPECT().GetDatabaseByID(mock.Anything, mock.Anything, mock.Anything).Return(&model.Database{ + ID: 1, + Name: "test db", + }, nil).Once() + meta.EXPECT().IsCustomPrivilegeGroup(mock.Anything, mock.Anything).Return(false, nil).Once() + + task := getTask(core) + ctx := GetContext(context.Background(), "foo:root") + err := task.Execute(ctx) + assert.NoError(t, err) + assert.Equal(t, int32(0), task.Rsp.GetStatus().GetCode()) + assert.Equal(t, "test db", task.Rsp.GetDbName()) + assert.Equal(t, int64(1), task.Rsp.GetDbId()) + assert.Equal(t, "test coll", task.Rsp.GetCollectionName()) + assert.Equal(t, int64(1), task.Rsp.GetCollectionID()) + }) + + t.Run("custom privilege group", func(t *testing.T) { + Params.Save(Params.CommonCfg.AuthorizationEnabled.Key, "true") + defer Params.Reset(Params.CommonCfg.AuthorizationEnabled.Key) + meta := mockrootcoord.NewIMetaTable(t) + core := newTestCore(withMeta(meta)) + + meta.EXPECT().SelectUser(mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Return([]*milvuspb.UserResult{ + { + User: &milvuspb.UserEntity{ + Name: "foo", + }, + Roles: []*milvuspb.RoleEntity{ + { + Name: "hoooo", + }, + }, + }, + }, nil).Once() + meta.EXPECT().SelectGrant(mock.Anything, mock.Anything, mock.Anything).Return([]*milvuspb.GrantEntity{ + { + Object: &milvuspb.ObjectEntity{Name: "custom_type"}, + Grantor: &milvuspb.GrantorEntity{ + Privilege: &milvuspb.PrivilegeEntity{ + Name: "privilege_group", + }, + }, + ObjectName: "test coll", + }, + }, nil).Once() + meta.EXPECT().IsCustomPrivilegeGroup(mock.Anything, "privilege_group").Return(true, nil).Once() + meta.EXPECT().GetCollectionByName(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&model.Collection{ + CollectionID: 1, + Name: "test coll", + DBID: 1, + }, nil).Once() + meta.EXPECT().ListAliasesByID(mock.Anything, mock.Anything).Return([]string{}).Once() + meta.EXPECT().GetDatabaseByID(mock.Anything, mock.Anything, mock.Anything).Return(&model.Database{ + ID: 1, + Name: "test db", + }, nil).Once() + + task := getTask(core) + ctx := GetContext(context.Background(), "foo:root") + err := task.Execute(ctx) + assert.NoError(t, err) + assert.Equal(t, int32(0), task.Rsp.GetStatus().GetCode()) + assert.Equal(t, "test db", task.Rsp.GetDbName()) + assert.Equal(t, int64(1), task.Rsp.GetDbId()) + assert.Equal(t, "test coll", task.Rsp.GetCollectionName()) + assert.Equal(t, int64(1), task.Rsp.GetCollectionID()) + }) +} diff --git a/internal/rootcoord/describe_db_task.go b/internal/rootcoord/describe_db_task.go index 8dff48104c..71cccd2db6 100644 --- a/internal/rootcoord/describe_db_task.go +++ b/internal/rootcoord/describe_db_task.go @@ -38,6 +38,21 @@ func (t *describeDBTask) Prepare(ctx context.Context) error { // Execute task execution func (t *describeDBTask) Execute(ctx context.Context) (err error) { + visibleDatabases, err := t.core.getCurrentUserVisibleDatabases(ctx) + if err != nil { + t.Rsp = &rootcoordpb.DescribeDatabaseResponse{ + Status: merr.Status(err), + } + return err + } + if !isVisibleDatabaseForCurUser(t.Req.GetDbName(), visibleDatabases) { + err = merr.WrapErrPrivilegeNotPermitted("not allowed to access database, db name: %s", t.Req.GetDbName()) + t.Rsp = &rootcoordpb.DescribeDatabaseResponse{ + Status: merr.Status(err), + } + return err + } + db, err := t.core.meta.GetDatabaseByName(ctx, t.Req.GetDbName(), typeutil.MaxTimestamp) if err != nil { t.Rsp = &rootcoordpb.DescribeDatabaseResponse{ diff --git a/internal/rootcoord/describe_db_task_test.go b/internal/rootcoord/describe_db_task_test.go index 69ecb6890a..a35482429e 100644 --- a/internal/rootcoord/describe_db_task_test.go +++ b/internal/rootcoord/describe_db_task_test.go @@ -20,14 +20,17 @@ import ( "context" "testing" + "github.com/cockroachdb/errors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus/internal/metastore/model" mockrootcoord "github.com/milvus-io/milvus/internal/rootcoord/mocks" "github.com/milvus-io/milvus/pkg/v2/proto/rootcoordpb" "github.com/milvus-io/milvus/pkg/v2/util" + "github.com/milvus-io/milvus/pkg/v2/util/paramtable" ) func Test_describeDatabaseTask_Execute(t *testing.T) { @@ -86,3 +89,249 @@ func Test_describeDatabaseTask_Execute(t *testing.T) { assert.Equal(t, uint64(1), task.Rsp.GetCreatedTimestamp()) }) } + +func Test_describeDBTask_WithAuth(t *testing.T) { + paramtable.Init() + Params.Save(Params.CommonCfg.AuthorizationEnabled.Key, "true") + defer Params.Reset(Params.CommonCfg.AuthorizationEnabled.Key) + meta := mockrootcoord.NewIMetaTable(t) + + core := newTestCore(withMeta(meta)) + getTask := func() *describeDBTask { + return &describeDBTask{ + baseTask: newBaseTask(context.Background(), core), + Req: &rootcoordpb.DescribeDatabaseRequest{DbName: "db1"}, + } + } + + { + // inner node + meta.EXPECT().GetDatabaseByName(mock.Anything, mock.Anything, mock.Anything). + Return(&model.Database{ + Name: "db1", + ID: 100, + CreatedTime: 1, + }, nil).Once() + + task := getTask() + err := task.Execute(context.Background()) + assert.NoError(t, err) + assert.NotNil(t, task.Rsp) + assert.Equal(t, task.Rsp.GetStatus().GetCode(), int32(commonpb.ErrorCode_Success)) + assert.Equal(t, "db1", task.Rsp.GetDbName()) + assert.Equal(t, int64(100), task.Rsp.GetDbID()) + assert.Equal(t, uint64(1), task.Rsp.GetCreatedTimestamp()) + } + + { + // proxy node with root user + meta.EXPECT().GetDatabaseByName(mock.Anything, mock.Anything, mock.Anything). + Return(&model.Database{ + Name: "db1", + ID: 100, + CreatedTime: 1, + }, nil).Once() + + ctx := GetContext(context.Background(), "root:root") + task := getTask() + err := task.Execute(ctx) + assert.NoError(t, err) + assert.NotNil(t, task.Rsp) + assert.Equal(t, task.Rsp.GetStatus().GetCode(), int32(commonpb.ErrorCode_Success)) + assert.Equal(t, "db1", task.Rsp.GetDbName()) + assert.Equal(t, int64(100), task.Rsp.GetDbID()) + assert.Equal(t, uint64(1), task.Rsp.GetCreatedTimestamp()) + } + + { + // proxy node with root user, root user should bind role + Params.Save(Params.CommonCfg.RootShouldBindRole.Key, "true") + meta.EXPECT().SelectUser(mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Return([]*milvuspb.UserResult{ + { + User: &milvuspb.UserEntity{ + Name: "root", + }, + Roles: []*milvuspb.RoleEntity{}, + }, + }, nil).Once() + + ctx := GetContext(context.Background(), "root:root") + task := getTask() + err := task.Execute(ctx) + assert.Error(t, err) + Params.Reset(Params.CommonCfg.RootShouldBindRole.Key) + } + + { + // select role fail + meta.EXPECT().SelectUser(mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Return(nil, errors.New("mock select user error")).Once() + ctx := GetContext(context.Background(), "foo:root") + task := getTask() + err := task.Execute(ctx) + assert.Error(t, err) + } + + { + // select role, empty result + meta.EXPECT().SelectUser(mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Return([]*milvuspb.UserResult{}, nil).Once() + ctx := GetContext(context.Background(), "foo:root") + task := getTask() + err := task.Execute(ctx) + assert.Error(t, err) + } + + { + // select role, the user is added to admin role + meta.EXPECT().SelectUser(mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Return([]*milvuspb.UserResult{ + { + User: &milvuspb.UserEntity{ + Name: "foo", + }, + Roles: []*milvuspb.RoleEntity{ + { + Name: "admin", + }, + }, + }, + }, nil).Once() + meta.EXPECT().GetDatabaseByName(mock.Anything, mock.Anything, mock.Anything). + Return(&model.Database{ + Name: "db1", + ID: 100, + CreatedTime: 1, + }, nil).Once() + ctx := GetContext(context.Background(), "foo:root") + task := getTask() + err := task.Execute(ctx) + assert.NoError(t, err) + assert.NotNil(t, task.Rsp) + assert.Equal(t, task.Rsp.GetStatus().GetCode(), int32(commonpb.ErrorCode_Success)) + assert.Equal(t, "db1", task.Rsp.GetDbName()) + assert.Equal(t, int64(100), task.Rsp.GetDbID()) + assert.Equal(t, uint64(1), task.Rsp.GetCreatedTimestamp()) + } + + { + // select grant fail + meta.EXPECT().SelectUser(mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Return([]*milvuspb.UserResult{ + { + User: &milvuspb.UserEntity{ + Name: "foo", + }, + Roles: []*milvuspb.RoleEntity{ + { + Name: "hoo", + }, + }, + }, + }, nil).Once() + meta.EXPECT().SelectGrant(mock.Anything, mock.Anything, mock.Anything). + Return(nil, errors.New("mock select grant error")).Once() + ctx := GetContext(context.Background(), "foo:root") + task := getTask() + err := task.Execute(ctx) + assert.Error(t, err) + } + + { + // normal user + meta.EXPECT().SelectUser(mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Return([]*milvuspb.UserResult{ + { + User: &milvuspb.UserEntity{ + Name: "foo", + }, + Roles: []*milvuspb.RoleEntity{ + { + Name: "hoo", + }, + }, + }, + }, nil).Once() + meta.EXPECT().SelectGrant(mock.Anything, mock.Anything, mock.Anything). + Return([]*milvuspb.GrantEntity{ + { + DbName: "db1", + }, + }, nil).Once() + meta.EXPECT().GetDatabaseByName(mock.Anything, mock.Anything, mock.Anything). + Return(&model.Database{ + Name: "db1", + ID: 100, + CreatedTime: 1, + }, nil).Once() + ctx := GetContext(context.Background(), "foo:root") + task := getTask() + err := task.Execute(ctx) + assert.NoError(t, err) + assert.NotNil(t, task.Rsp) + assert.Equal(t, task.Rsp.GetStatus().GetCode(), int32(commonpb.ErrorCode_Success)) + assert.Equal(t, "db1", task.Rsp.GetDbName()) + assert.Equal(t, int64(100), task.Rsp.GetDbID()) + assert.Equal(t, uint64(1), task.Rsp.GetCreatedTimestamp()) + } + + { + // normal user and public role + meta.EXPECT().SelectUser(mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Return([]*milvuspb.UserResult{ + { + User: &milvuspb.UserEntity{ + Name: "foo", + }, + Roles: []*milvuspb.RoleEntity{ + { + Name: "public", + }, + }, + }, + }, nil).Once() + ctx := GetContext(context.Background(), "foo:root") + task := getTask() + err := task.Execute(ctx) + assert.Error(t, err) + } + + { + // normal user with any db privilege + meta.EXPECT().SelectUser(mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Return([]*milvuspb.UserResult{ + { + User: &milvuspb.UserEntity{ + Name: "foo", + }, + Roles: []*milvuspb.RoleEntity{ + { + Name: "hoo", + }, + }, + }, + }, nil).Once() + meta.EXPECT().SelectGrant(mock.Anything, mock.Anything, mock.Anything). + Return([]*milvuspb.GrantEntity{ + { + DbName: "*", + }, + }, nil).Once() + meta.EXPECT().GetDatabaseByName(mock.Anything, mock.Anything, mock.Anything). + Return(&model.Database{ + Name: "db1", + ID: 100, + CreatedTime: 1, + }, nil).Once() + ctx := GetContext(context.Background(), "foo:root") + task := getTask() + err := task.Execute(ctx) + assert.NoError(t, err) + assert.NotNil(t, task.Rsp) + assert.Equal(t, task.Rsp.GetStatus().GetCode(), int32(commonpb.ErrorCode_Success)) + assert.Equal(t, "db1", task.Rsp.GetDbName()) + assert.Equal(t, int64(100), task.Rsp.GetDbID()) + assert.Equal(t, uint64(1), task.Rsp.GetCreatedTimestamp()) + } +} diff --git a/internal/rootcoord/list_db_task.go b/internal/rootcoord/list_db_task.go index d74c80227a..a598048fbf 100644 --- a/internal/rootcoord/list_db_task.go +++ b/internal/rootcoord/list_db_task.go @@ -19,14 +19,8 @@ package rootcoord import ( "context" - "go.uber.org/zap" - "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" - "github.com/milvus-io/milvus/pkg/v2/log" - "github.com/milvus-io/milvus/pkg/v2/util" - "github.com/milvus-io/milvus/pkg/v2/util/contextutil" "github.com/milvus-io/milvus/pkg/v2/util/merr" - "github.com/milvus-io/milvus/pkg/v2/util/typeutil" ) type listDatabaseTask struct { @@ -42,64 +36,7 @@ func (t *listDatabaseTask) Prepare(ctx context.Context) error { func (t *listDatabaseTask) Execute(ctx context.Context) error { t.Resp.Status = merr.Success() - getVisibleDBs := func() (typeutil.Set[string], error) { - enableAuth := Params.CommonCfg.AuthorizationEnabled.GetAsBool() - privilegeDBs := typeutil.NewSet[string]() - if !enableAuth { - privilegeDBs.Insert(util.AnyWord) - return privilegeDBs, nil - } - curUser, err := contextutil.GetCurUserFromContext(ctx) - // it will fail if the inner node server use the list database API - if err != nil || (curUser == util.UserRoot && !Params.CommonCfg.RootShouldBindRole.GetAsBool()) { - if err != nil { - log.Ctx(ctx).Warn("get current user from context failed", zap.Error(err)) - } - privilegeDBs.Insert(util.AnyWord) - return privilegeDBs, nil - } - userRoles, err := t.core.meta.SelectUser(ctx, "", &milvuspb.UserEntity{ - Name: curUser, - }, true) - if err != nil { - return nil, err - } - if len(userRoles) == 0 { - return privilegeDBs, nil - } - for _, role := range userRoles[0].Roles { - if role.GetName() == util.RoleAdmin { - privilegeDBs.Insert(util.AnyWord) - return privilegeDBs, nil - } - if role.GetName() == util.RolePublic { - continue - } - entities, err := t.core.meta.SelectGrant(ctx, "", &milvuspb.GrantEntity{ - Role: role, - DbName: util.AnyWord, - }) - if err != nil { - return nil, err - } - for _, entity := range entities { - privilegeDBs.Insert(entity.GetDbName()) - if entity.GetDbName() == util.AnyWord { - return privilegeDBs, nil - } - } - } - return privilegeDBs, nil - } - - isVisibleDBForCurUser := func(dbName string, visibleDBs typeutil.Set[string]) bool { - if visibleDBs.Contain(util.AnyWord) { - return true - } - return visibleDBs.Contain(dbName) - } - - visibleDBs, err := getVisibleDBs() + visibleDBs, err := t.core.getCurrentUserVisibleDatabases(ctx) if err != nil { t.Resp.Status = merr.Status(err) return err @@ -118,7 +55,7 @@ func (t *listDatabaseTask) Execute(ctx context.Context) error { dbIDs := make([]int64, 0, len(ret)) createdTimes := make([]uint64, 0, len(ret)) for _, db := range ret { - if !isVisibleDBForCurUser(db.Name, visibleDBs) { + if !isVisibleDatabaseForCurUser(db.Name, visibleDBs) { continue } dbNames = append(dbNames, db.Name) diff --git a/internal/rootcoord/root_coord.go b/internal/rootcoord/root_coord.go index 4ecb844882..e0b894614d 100644 --- a/internal/rootcoord/root_coord.go +++ b/internal/rootcoord/root_coord.go @@ -64,6 +64,7 @@ import ( "github.com/milvus-io/milvus/pkg/v2/proto/rootcoordpb" "github.com/milvus-io/milvus/pkg/v2/util" "github.com/milvus-io/milvus/pkg/v2/util/commonpbutil" + "github.com/milvus-io/milvus/pkg/v2/util/contextutil" "github.com/milvus-io/milvus/pkg/v2/util/crypto" "github.com/milvus-io/milvus/pkg/v2/util/expr" "github.com/milvus-io/milvus/pkg/v2/util/funcutil" @@ -3261,6 +3262,137 @@ func (c *Core) getDefaultAndCustomPrivilegeGroups(ctx context.Context) ([]*milvu return allGroups, nil } +func (c *Core) getCurrentUserVisibleDatabases(ctx context.Context) (typeutil.Set[string], error) { + enableAuth := Params.CommonCfg.AuthorizationEnabled.GetAsBool() + privilegeDatabases := typeutil.NewSet[string]() + if !enableAuth { + privilegeDatabases.Insert(util.AnyWord) + return privilegeDatabases, nil + } + curUser, err := contextutil.GetCurUserFromContext(ctx) + // it will fail if the inner node server use the list database API + if err != nil || (curUser == util.UserRoot && !Params.CommonCfg.RootShouldBindRole.GetAsBool()) { + if err != nil { + log.Ctx(ctx).Warn("get current user from context failed", zap.Error(err)) + } + privilegeDatabases.Insert(util.AnyWord) + return privilegeDatabases, nil + } + userRoles, err := c.meta.SelectUser(ctx, "", &milvuspb.UserEntity{ + Name: curUser, + }, true) + if err != nil { + return nil, err + } + if len(userRoles) == 0 { + return privilegeDatabases, nil + } + for _, role := range userRoles[0].Roles { + if role.GetName() == util.RoleAdmin { + privilegeDatabases.Insert(util.AnyWord) + return privilegeDatabases, nil + } + if role.GetName() == util.RolePublic { + continue + } + entities, err := c.meta.SelectGrant(ctx, "", &milvuspb.GrantEntity{ + Role: role, + DbName: util.AnyWord, + }) + if err != nil { + return nil, err + } + for _, entity := range entities { + privilegeDatabases.Insert(entity.GetDbName()) + if entity.GetDbName() == util.AnyWord { + return privilegeDatabases, nil + } + } + } + return privilegeDatabases, nil +} + +func isVisibleDatabaseForCurUser(currentDatabase string, visibleDatabases typeutil.Set[string]) bool { + if visibleDatabases.Contain(util.AnyWord) { + return true + } + return visibleDatabases.Contain(currentDatabase) +} + +func (c *Core) getCurrentUserVisibleCollections(ctx context.Context, databaseName string) (typeutil.Set[string], error) { + enableAuth := Params.CommonCfg.AuthorizationEnabled.GetAsBool() + privilegeColls := typeutil.NewSet[string]() + if !enableAuth { + privilegeColls.Insert(util.AnyWord) + return privilegeColls, nil + } + curUser, err := contextutil.GetCurUserFromContext(ctx) + if err != nil || (curUser == util.UserRoot && !Params.CommonCfg.RootShouldBindRole.GetAsBool()) { + if err != nil { + log.Ctx(ctx).Warn("get current user from context failed", zap.Error(err)) + } + privilegeColls.Insert(util.AnyWord) + return privilegeColls, nil + } + userRoles, err := c.meta.SelectUser(ctx, "", &milvuspb.UserEntity{ + Name: curUser, + }, true) + if err != nil { + return nil, err + } + if len(userRoles) == 0 { + return privilegeColls, nil + } + for _, role := range userRoles[0].Roles { + if role.GetName() == util.RoleAdmin { + privilegeColls.Insert(util.AnyWord) + return privilegeColls, nil + } + if role.GetName() == util.RolePublic { + continue + } + entities, err := c.meta.SelectGrant(ctx, "", &milvuspb.GrantEntity{ + Role: role, + DbName: databaseName, + }) + if err != nil { + return nil, err + } + for _, entity := range entities { + objectType := entity.GetObject().GetName() + priv := entity.GetGrantor().GetPrivilege().GetName() + if objectType == commonpb.ObjectType_Global.String() && + priv == util.PrivilegeNameForAPI(commonpb.ObjectPrivilege_PrivilegeAll.String()) { + privilegeColls.Insert(util.AnyWord) + return privilegeColls, nil + } + // should list collection level built-in privilege group or custom privilege group objects + if objectType != commonpb.ObjectType_Collection.String() { + customGroup, err := c.meta.IsCustomPrivilegeGroup(ctx, priv) + if err != nil { + return nil, err + } + if !customGroup && !Params.RbacConfig.IsCollectionPrivilegeGroup(priv) { + continue + } + } + collectionName := entity.GetObjectName() + privilegeColls.Insert(collectionName) + if collectionName == util.AnyWord { + return privilegeColls, nil + } + } + } + return privilegeColls, nil +} + +func isVisibleCollectionForCurUser(collectionName string, visibleCollections typeutil.Set[string]) bool { + if visibleCollections.Contain(util.AnyWord) { + return true + } + return visibleCollections.Contain(collectionName) +} + // RegisterStreamingCoordGRPCService registers the grpc service of streaming coordinator. func (s *Core) RegisterStreamingCoordGRPCService(server *grpc.Server) { s.streamingCoord.RegisterGRPCService(server) diff --git a/internal/rootcoord/show_collection_task.go b/internal/rootcoord/show_collection_task.go index 36dfd16036..9a8a19b009 100644 --- a/internal/rootcoord/show_collection_task.go +++ b/internal/rootcoord/show_collection_task.go @@ -20,13 +20,9 @@ import ( "context" "github.com/samber/lo" - "go.uber.org/zap" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" - "github.com/milvus-io/milvus/pkg/v2/log" - "github.com/milvus-io/milvus/pkg/v2/util" - "github.com/milvus-io/milvus/pkg/v2/util/contextutil" "github.com/milvus-io/milvus/pkg/v2/util/merr" "github.com/milvus-io/milvus/pkg/v2/util/tsoutil" "github.com/milvus-io/milvus/pkg/v2/util/typeutil" @@ -50,81 +46,7 @@ func (t *showCollectionTask) Prepare(ctx context.Context) error { func (t *showCollectionTask) Execute(ctx context.Context) error { t.Rsp.Status = merr.Success() - getVisibleCollections := func() (typeutil.Set[string], error) { - enableAuth := Params.CommonCfg.AuthorizationEnabled.GetAsBool() - privilegeColls := typeutil.NewSet[string]() - if !enableAuth { - privilegeColls.Insert(util.AnyWord) - return privilegeColls, nil - } - curUser, err := contextutil.GetCurUserFromContext(ctx) - if err != nil || (curUser == util.UserRoot && !Params.CommonCfg.RootShouldBindRole.GetAsBool()) { - if err != nil { - log.Ctx(ctx).Warn("get current user from context failed", zap.Error(err)) - } - privilegeColls.Insert(util.AnyWord) - return privilegeColls, nil - } - userRoles, err := t.core.meta.SelectUser(ctx, "", &milvuspb.UserEntity{ - Name: curUser, - }, true) - if err != nil { - return nil, err - } - if len(userRoles) == 0 { - return privilegeColls, nil - } - for _, role := range userRoles[0].Roles { - if role.GetName() == util.RoleAdmin { - privilegeColls.Insert(util.AnyWord) - return privilegeColls, nil - } - if role.GetName() == util.RolePublic { - continue - } - entities, err := t.core.meta.SelectGrant(ctx, "", &milvuspb.GrantEntity{ - Role: role, - DbName: t.Req.GetDbName(), - }) - if err != nil { - return nil, err - } - for _, entity := range entities { - objectType := entity.GetObject().GetName() - priv := entity.GetGrantor().GetPrivilege().GetName() - if objectType == commonpb.ObjectType_Global.String() && - priv == util.PrivilegeNameForAPI(commonpb.ObjectPrivilege_PrivilegeAll.String()) { - privilegeColls.Insert(util.AnyWord) - return privilegeColls, nil - } - // should list collection level built-in privilege group or custom privilege group objects - if objectType != commonpb.ObjectType_Collection.String() { - customGroup, err := t.core.meta.IsCustomPrivilegeGroup(ctx, priv) - if err != nil { - return nil, err - } - if !customGroup && !Params.RbacConfig.IsCollectionPrivilegeGroup(priv) { - continue - } - } - collectionName := entity.GetObjectName() - privilegeColls.Insert(collectionName) - if collectionName == util.AnyWord { - return privilegeColls, nil - } - } - } - return privilegeColls, nil - } - - isVisibleCollectionForCurUser := func(collectionName string, visibleCollections typeutil.Set[string]) bool { - if visibleCollections.Contain(util.AnyWord) { - return true - } - return visibleCollections.Contain(collectionName) - } - - visibleCollections, err := getVisibleCollections() + visibleCollections, err := t.core.getCurrentUserVisibleCollections(ctx, t.Req.GetDbName()) if err != nil { t.Rsp.Status = merr.Status(err) return err diff --git a/pkg/go.mod b/pkg/go.mod index 9f3b7b3b67..9731a51e83 100644 --- a/pkg/go.mod +++ b/pkg/go.mod @@ -175,7 +175,7 @@ require ( replace ( github.com/apache/pulsar-client-go => github.com/milvus-io/pulsar-client-go v0.12.1 github.com/bketelsen/crypt => github.com/bketelsen/crypt v0.0.4 // Fix security alert for core-os/etcd - github.com/expr-lang/expr => github.com/SimFG/expr v0.0.0-20250415035630-0728e795e4e9 + github.com/expr-lang/expr => github.com/SimFG/expr v0.0.0-20250513112851-9b981e8400b9 github.com/go-kit/kit => github.com/go-kit/kit v0.1.0 github.com/ianlancetaylor/cgosymbolizer => github.com/milvus-io/cgosymbolizer v0.0.0-20240722103217-b7dee0e50119 github.com/streamnative/pulsarctl => github.com/xiaofan-luan/pulsarctl v0.5.1 diff --git a/pkg/go.sum b/pkg/go.sum index 3d7d3b57ab..9840dd9c57 100644 --- a/pkg/go.sum +++ b/pkg/go.sum @@ -55,8 +55,8 @@ github.com/DataDog/zstd v1.5.0/go.mod h1:g4AWEaM3yOg3HYfnJ3YIawPnVdXJh9QME85blwS github.com/Joker/hpp v1.0.0/go.mod h1:8x5n+M1Hp5hC0g8okX3sR3vFQwynaX/UgSOM9MeBKzY= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= github.com/Shopify/goreferrer v0.0.0-20181106222321-ec9c9a553398/go.mod h1:a1uqRtAwp2Xwc6WNPJEufxJ7fx3npB4UV/JOLmbu5I0= -github.com/SimFG/expr v0.0.0-20250415035630-0728e795e4e9 h1:p/1Prokv2YkGbcyLV/gOD28Gr3VgMXIa0c9ulg5KjOY= -github.com/SimFG/expr v0.0.0-20250415035630-0728e795e4e9/go.mod h1:8/vRC7+7HBzESEqt5kKpYXxrxkr31SaO8r40VO/1IT4= +github.com/SimFG/expr v0.0.0-20250513112851-9b981e8400b9 h1:eXnmJhsHt8m6NU3IJ19UthXJ8JK6e3tmfN07nym3BXs= +github.com/SimFG/expr v0.0.0-20250513112851-9b981e8400b9/go.mod h1:8/vRC7+7HBzESEqt5kKpYXxrxkr31SaO8r40VO/1IT4= github.com/actgardner/gogen-avro/v10 v10.1.0/go.mod h1:o+ybmVjEa27AAr35FRqU98DJu1fXES56uXniYFv4yDA= github.com/actgardner/gogen-avro/v10 v10.2.1/go.mod h1:QUhjeHPchheYmMDni/Nx7VB0RsT/ee8YIgGY/xpEQgQ= github.com/actgardner/gogen-avro/v9 v9.1.0/go.mod h1:nyTj6wPqDJoxM3qdnjcLv+EnMDSDFqE0qDpva2QRmKc=