diff --git a/internal/proxy/task.go b/internal/proxy/task.go index 7524558565..71b9330b37 100644 --- a/internal/proxy/task.go +++ b/internal/proxy/task.go @@ -722,6 +722,7 @@ func (t *showCollectionsTask) PreExecute(ctx context.Context) error { } func (t *showCollectionsTask) Execute(ctx context.Context) error { + ctx = AppendUserInfoForRPC(ctx) respFromRootCoord, err := t.rootCoord.ShowCollections(ctx, t.ShowCollectionsRequest) if err != nil { return err diff --git a/internal/proxy/task_database.go b/internal/proxy/task_database.go index 0708373e7c..0254741f48 100644 --- a/internal/proxy/task_database.go +++ b/internal/proxy/task_database.go @@ -2,18 +2,12 @@ package proxy import ( "context" - "fmt" - "strings" - - "google.golang.org/grpc/metadata" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/pkg/mq/msgstream" - "github.com/milvus-io/milvus/pkg/util" "github.com/milvus-io/milvus/pkg/util/commonpbutil" - "github.com/milvus-io/milvus/pkg/util/crypto" "github.com/milvus-io/milvus/pkg/util/paramtable" ) @@ -211,14 +205,8 @@ func (ldt *listDatabaseTask) PreExecute(ctx context.Context) error { func (ldt *listDatabaseTask) Execute(ctx context.Context) error { var err error - curUser, _ := GetCurUserFromContext(ldt.ctx) - if curUser != "" { - originValue := fmt.Sprintf("%s%s%s", curUser, util.CredentialSeperator, curUser) - authKey := strings.ToLower(util.HeaderAuthorize) - authValue := crypto.Base64Encode(originValue) - ldt.ctx = metadata.AppendToOutgoingContext(ldt.ctx, authKey, authValue) - } - ldt.result, err = ldt.rootCoord.ListDatabases(ldt.ctx, ldt.ListDatabasesRequest) + ctx = AppendUserInfoForRPC(ctx) + ldt.result, err = ldt.rootCoord.ListDatabases(ctx, ldt.ListDatabasesRequest) return err } diff --git a/internal/proxy/task_database_test.go b/internal/proxy/task_database_test.go index 7e8952fba8..0e07128dc3 100644 --- a/internal/proxy/task_database_test.go +++ b/internal/proxy/task_database_test.go @@ -154,7 +154,8 @@ func TestListDatabaseTask(t *testing.T) { assert.Equal(t, paramtable.GetNodeID(), task.GetBase().GetSourceID()) assert.Equal(t, UniqueID(0), task.ID()) - md, ok := metadata.FromOutgoingContext(task.ctx) + taskCtx := AppendUserInfoForRPC(ctx) + md, ok := metadata.FromOutgoingContext(taskCtx) assert.True(t, ok) authorization, ok := md[strings.ToLower(util.HeaderAuthorize)] assert.True(t, ok) diff --git a/internal/proxy/util.go b/internal/proxy/util.go index 586f759c35..3731054f88 100644 --- a/internal/proxy/util.go +++ b/internal/proxy/util.go @@ -901,6 +901,17 @@ func NewContextWithMetadata(ctx context.Context, username string, dbName string) return contextutil.AppendToIncomingContext(ctx, authKey, authValue, dbKey, dbName) } +func AppendUserInfoForRPC(ctx context.Context) context.Context { + curUser, _ := GetCurUserFromContext(ctx) + if curUser != "" { + originValue := fmt.Sprintf("%s%s%s", curUser, util.CredentialSeperator, curUser) + authKey := strings.ToLower(util.HeaderAuthorize) + authValue := crypto.Base64Encode(originValue) + ctx = metadata.AppendToOutgoingContext(ctx, authKey, authValue) + } + return ctx +} + func GetRole(username string) ([]string, error) { if globalMetaCache == nil { return []string{}, merr.WrapErrServiceUnavailable("internal: Milvus Proxy is not ready yet. please wait") diff --git a/internal/proxy/util_test.go b/internal/proxy/util_test.go index 66967c8922..7be3db075b 100644 --- a/internal/proxy/util_test.go +++ b/internal/proxy/util_test.go @@ -2107,3 +2107,15 @@ func TestSendReplicateMessagePack(t *testing.T) { SendReplicateMessagePack(ctx, mockStream, &milvuspb.ReleasePartitionsRequest{}) }) } + +func TestAppendUserInfoForRPC(t *testing.T) { + ctx := GetContext(context.Background(), "root:123456") + ctx = AppendUserInfoForRPC(ctx) + + md, ok := metadata.FromOutgoingContext(ctx) + assert.True(t, ok) + authorization, ok := md[strings.ToLower(util.HeaderAuthorize)] + assert.True(t, ok) + expectAuth := crypto.Base64Encode("root:root") + assert.Equal(t, expectAuth, authorization[0]) +} diff --git a/internal/rootcoord/list_db_task.go b/internal/rootcoord/list_db_task.go index 655793c928..1b4e81a795 100644 --- a/internal/rootcoord/list_db_task.go +++ b/internal/rootcoord/list_db_task.go @@ -19,7 +19,10 @@ package rootcoord import ( "context" + "go.uber.org/zap" + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util" "github.com/milvus-io/milvus/pkg/util/contextutil" "github.com/milvus-io/milvus/pkg/util/merr" @@ -49,6 +52,9 @@ func (t *listDatabaseTask) Execute(ctx context.Context) error { curUser, err := contextutil.GetCurUserFromContext(ctx) // it will fail if the inner node server use the list database API if err != nil || curUser == util.UserRoot { + if err != nil { + log.Warn("get current user from context failed", zap.Error(err)) + } privilegeDBs.Insert(util.AnyWord) return privilegeDBs, nil } diff --git a/internal/rootcoord/show_collection_task.go b/internal/rootcoord/show_collection_task.go index 563057eada..247a171af3 100644 --- a/internal/rootcoord/show_collection_task.go +++ b/internal/rootcoord/show_collection_task.go @@ -20,9 +20,13 @@ import ( "context" "github.com/samber/lo" + "go.uber.org/zap" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/util" + "github.com/milvus-io/milvus/pkg/util/contextutil" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/tsoutil" "github.com/milvus-io/milvus/pkg/util/typeutil" @@ -45,6 +49,79 @@ func (t *showCollectionTask) Prepare(ctx context.Context) error { // Execute task execution func (t *showCollectionTask) Execute(ctx context.Context) error { t.Rsp.Status = merr.Success() + + getVisibleCollections := func() (typeutil.Set[string], error) { + enableAuth := Params.CommonCfg.AuthorizationEnabled.GetAsBool() + privilegeColls := typeutil.NewSet[string]() + if !enableAuth { + privilegeColls.Insert(util.AnyWord) + return privilegeColls, nil + } + curUser, err := contextutil.GetCurUserFromContext(ctx) + if err != nil || curUser == util.UserRoot { + if err != nil { + log.Warn("get current user from context failed", zap.Error(err)) + } + privilegeColls.Insert(util.AnyWord) + return privilegeColls, nil + } + userRoles, err := t.core.meta.SelectUser("", &milvuspb.UserEntity{ + Name: curUser, + }, true) + if err != nil { + return nil, err + } + if len(userRoles) == 0 { + return privilegeColls, nil + } + for _, role := range userRoles[0].Roles { + if role.GetName() == util.RoleAdmin { + privilegeColls.Insert(util.AnyWord) + return privilegeColls, nil + } + entities, err := t.core.meta.SelectGrant("", &milvuspb.GrantEntity{ + Role: role, + DbName: t.Req.GetDbName(), + }) + if err != nil { + return nil, err + } + for _, entity := range entities { + objectType := entity.GetObject().GetName() + if objectType == commonpb.ObjectType_Global.String() && + entity.GetGrantor().GetPrivilege().GetName() == commonpb.ObjectPrivilege_PrivilegeAll.String() { + privilegeColls.Insert(util.AnyWord) + return privilegeColls, nil + } + if objectType != commonpb.ObjectType_Collection.String() { + continue + } + collectionName := entity.GetObjectName() + privilegeColls.Insert(collectionName) + if collectionName == util.AnyWord { + return privilegeColls, nil + } + } + } + return privilegeColls, nil + } + + isVisibleCollectionForCurUser := func(collectionName string, visibleCollections typeutil.Set[string]) bool { + if visibleCollections.Contain(util.AnyWord) { + return true + } + return visibleCollections.Contain(collectionName) + } + + visibleCollections, err := getVisibleCollections() + if err != nil { + t.Rsp.Status = merr.Status(err) + return err + } + if len(visibleCollections) == 0 { + return nil + } + ts := t.Req.GetTimeStamp() if ts == 0 { ts = typeutil.MaxTimestamp @@ -58,6 +135,9 @@ func (t *showCollectionTask) Execute(ctx context.Context) error { if len(t.Req.GetCollectionNames()) > 0 && !lo.Contains(t.Req.GetCollectionNames(), coll.Name) { continue } + if !isVisibleCollectionForCurUser(coll.Name, visibleCollections) { + continue + } t.Rsp.CollectionNames = append(t.Rsp.CollectionNames, coll.Name) t.Rsp.CollectionIds = append(t.Rsp.CollectionIds, coll.CollectionID) diff --git a/internal/rootcoord/show_collection_task_test.go b/internal/rootcoord/show_collection_task_test.go index 3929b86d2b..8d82e4aa20 100644 --- a/internal/rootcoord/show_collection_task_test.go +++ b/internal/rootcoord/show_collection_task_test.go @@ -20,14 +20,21 @@ import ( "context" "testing" + "github.com/cockroachdb/errors" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus/internal/metastore/model" + mockrootcoord "github.com/milvus-io/milvus/internal/rootcoord/mocks" + "github.com/milvus-io/milvus/pkg/util" + "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/tsoutil" ) func Test_showCollectionTask_Prepare(t *testing.T) { + paramtable.Init() t.Run("invalid msg type", func(t *testing.T) { task := &showCollectionTask{ Req: &milvuspb.ShowCollectionsRequest{ @@ -54,6 +61,7 @@ func Test_showCollectionTask_Prepare(t *testing.T) { } func Test_showCollectionTask_Execute(t *testing.T) { + paramtable.Init() t.Run("failed to list collections", func(t *testing.T) { core := newTestCore(withInvalidMeta()) task := &showCollectionTask{ @@ -97,3 +105,325 @@ func Test_showCollectionTask_Execute(t *testing.T) { assert.Equal(t, 2, len(task.Rsp.GetCollectionNames())) }) } + +func TestShowCollectionsAuth(t *testing.T) { + paramtable.Init() + + t.Run("no auth", func(t *testing.T) { + Params.Save(Params.CommonCfg.AuthorizationEnabled.Key, "false") + defer Params.Reset(Params.CommonCfg.AuthorizationEnabled.Key) + meta := mockrootcoord.NewIMetaTable(t) + core := newTestCore(withMeta(meta)) + + meta.EXPECT().ListCollections(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return([]*model.Collection{ + { + DBID: 1, + CollectionID: 100, + Name: "foo", + CreateTime: tsoutil.GetCurrentTime(), + }, + }, nil).Once() + + task := &showCollectionTask{ + baseTask: newBaseTask(context.Background(), core), + Req: &milvuspb.ShowCollectionsRequest{DbName: "default"}, + Rsp: &milvuspb.ShowCollectionsResponse{}, + } + + err := task.Execute(context.Background()) + assert.NoError(t, err) + assert.Equal(t, 1, len(task.Rsp.GetCollectionNames())) + assert.Equal(t, "foo", task.Rsp.GetCollectionNames()[0]) + }) + + t.Run("empty ctx", func(t *testing.T) { + Params.Save(Params.CommonCfg.AuthorizationEnabled.Key, "true") + defer Params.Reset(Params.CommonCfg.AuthorizationEnabled.Key) + meta := mockrootcoord.NewIMetaTable(t) + core := newTestCore(withMeta(meta)) + + meta.EXPECT().ListCollections(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return([]*model.Collection{ + { + DBID: 1, + CollectionID: 100, + Name: "foo", + CreateTime: tsoutil.GetCurrentTime(), + }, + }, nil).Once() + + task := &showCollectionTask{ + baseTask: newBaseTask(context.Background(), core), + Req: &milvuspb.ShowCollectionsRequest{DbName: "default"}, + Rsp: &milvuspb.ShowCollectionsResponse{}, + } + + err := task.Execute(context.Background()) + assert.NoError(t, err) + assert.Equal(t, 1, len(task.Rsp.GetCollectionNames())) + assert.Equal(t, "foo", task.Rsp.GetCollectionNames()[0]) + }) + + t.Run("fail to select user", func(t *testing.T) { + Params.Save(Params.CommonCfg.AuthorizationEnabled.Key, "true") + defer Params.Reset(Params.CommonCfg.AuthorizationEnabled.Key) + meta := mockrootcoord.NewIMetaTable(t) + core := newTestCore(withMeta(meta)) + + meta.EXPECT().SelectUser(mock.Anything, mock.Anything, mock.Anything). + Return(nil, errors.New("mock error: select user")).Once() + + task := &showCollectionTask{ + baseTask: newBaseTask(context.Background(), core), + Req: &milvuspb.ShowCollectionsRequest{DbName: "default"}, + Rsp: &milvuspb.ShowCollectionsResponse{}, + } + + ctx := GetContext(context.Background(), "foo:root") + err := task.Execute(ctx) + assert.Error(t, err) + }) + + t.Run("no user", func(t *testing.T) { + Params.Save(Params.CommonCfg.AuthorizationEnabled.Key, "true") + defer Params.Reset(Params.CommonCfg.AuthorizationEnabled.Key) + meta := mockrootcoord.NewIMetaTable(t) + core := newTestCore(withMeta(meta)) + + meta.EXPECT().SelectUser(mock.Anything, mock.Anything, mock.Anything). + Return([]*milvuspb.UserResult{}, nil).Once() + + task := &showCollectionTask{ + baseTask: newBaseTask(context.Background(), core), + Req: &milvuspb.ShowCollectionsRequest{DbName: "default"}, + Rsp: &milvuspb.ShowCollectionsResponse{}, + } + + ctx := GetContext(context.Background(), "foo:root") + err := task.Execute(ctx) + assert.NoError(t, err) + assert.Equal(t, 0, len(task.Rsp.GetCollectionNames())) + }) + + t.Run("admin role", func(t *testing.T) { + Params.Save(Params.CommonCfg.AuthorizationEnabled.Key, "true") + defer Params.Reset(Params.CommonCfg.AuthorizationEnabled.Key) + meta := mockrootcoord.NewIMetaTable(t) + core := newTestCore(withMeta(meta)) + + meta.EXPECT().SelectUser(mock.Anything, mock.Anything, mock.Anything). + Return([]*milvuspb.UserResult{ + { + User: &milvuspb.UserEntity{ + Name: "foo", + }, + Roles: []*milvuspb.RoleEntity{ + { + Name: "admin", + }, + }, + }, + }, nil).Once() + meta.EXPECT().ListCollections(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return([]*model.Collection{ + { + DBID: 1, + CollectionID: 100, + Name: "foo", + CreateTime: tsoutil.GetCurrentTime(), + }, + }, nil).Once() + + task := &showCollectionTask{ + baseTask: newBaseTask(context.Background(), core), + Req: &milvuspb.ShowCollectionsRequest{DbName: "default"}, + Rsp: &milvuspb.ShowCollectionsResponse{}, + } + ctx := GetContext(context.Background(), "foo:root") + err := task.Execute(ctx) + assert.NoError(t, err) + assert.Equal(t, 1, len(task.Rsp.GetCollectionNames())) + assert.Equal(t, "foo", task.Rsp.GetCollectionNames()[0]) + }) + + t.Run("select grant error", func(t *testing.T) { + Params.Save(Params.CommonCfg.AuthorizationEnabled.Key, "true") + defer Params.Reset(Params.CommonCfg.AuthorizationEnabled.Key) + meta := mockrootcoord.NewIMetaTable(t) + core := newTestCore(withMeta(meta)) + + meta.EXPECT().SelectUser(mock.Anything, mock.Anything, mock.Anything). + Return([]*milvuspb.UserResult{ + { + User: &milvuspb.UserEntity{ + Name: "foo", + }, + Roles: []*milvuspb.RoleEntity{ + { + Name: "hoooo", + }, + }, + }, + }, nil).Once() + meta.EXPECT().SelectGrant(mock.Anything, mock.Anything).Return(nil, errors.New("mock error: select grant")).Once() + + task := &showCollectionTask{ + baseTask: newBaseTask(context.Background(), core), + Req: &milvuspb.ShowCollectionsRequest{DbName: "default"}, + Rsp: &milvuspb.ShowCollectionsResponse{}, + } + ctx := GetContext(context.Background(), "foo:root") + err := task.Execute(ctx) + assert.Error(t, err) + }) + + t.Run("global all privilege", func(t *testing.T) { + Params.Save(Params.CommonCfg.AuthorizationEnabled.Key, "true") + defer Params.Reset(Params.CommonCfg.AuthorizationEnabled.Key) + meta := mockrootcoord.NewIMetaTable(t) + core := newTestCore(withMeta(meta)) + + meta.EXPECT().SelectUser(mock.Anything, mock.Anything, mock.Anything). + Return([]*milvuspb.UserResult{ + { + User: &milvuspb.UserEntity{ + Name: "foo", + }, + Roles: []*milvuspb.RoleEntity{ + { + Name: "hoooo", + }, + }, + }, + }, nil).Once() + meta.EXPECT().SelectGrant(mock.Anything, mock.Anything).Return([]*milvuspb.GrantEntity{ + { + Object: &milvuspb.ObjectEntity{Name: commonpb.ObjectType_Global.String()}, + Grantor: &milvuspb.GrantorEntity{ + Privilege: &milvuspb.PrivilegeEntity{Name: commonpb.ObjectPrivilege_PrivilegeAll.String()}, + }, + }, + }, nil).Once() + meta.EXPECT().ListCollections(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return([]*model.Collection{ + { + DBID: 1, + CollectionID: 100, + Name: "foo", + CreateTime: tsoutil.GetCurrentTime(), + }, + }, nil).Once() + + task := &showCollectionTask{ + baseTask: newBaseTask(context.Background(), core), + Req: &milvuspb.ShowCollectionsRequest{DbName: "default"}, + Rsp: &milvuspb.ShowCollectionsResponse{}, + } + ctx := GetContext(context.Background(), "foo:root") + err := task.Execute(ctx) + assert.NoError(t, err) + assert.Equal(t, 1, len(task.Rsp.GetCollectionNames())) + assert.Equal(t, "foo", task.Rsp.GetCollectionNames()[0]) + }) + + t.Run("all collection", func(t *testing.T) { + Params.Save(Params.CommonCfg.AuthorizationEnabled.Key, "true") + defer Params.Reset(Params.CommonCfg.AuthorizationEnabled.Key) + meta := mockrootcoord.NewIMetaTable(t) + core := newTestCore(withMeta(meta)) + + meta.EXPECT().SelectUser(mock.Anything, mock.Anything, mock.Anything). + Return([]*milvuspb.UserResult{ + { + User: &milvuspb.UserEntity{ + Name: "foo", + }, + Roles: []*milvuspb.RoleEntity{ + { + Name: "hoooo", + }, + }, + }, + }, nil).Once() + meta.EXPECT().SelectGrant(mock.Anything, mock.Anything).Return([]*milvuspb.GrantEntity{ + { + Object: &milvuspb.ObjectEntity{Name: commonpb.ObjectType_Collection.String()}, + ObjectName: util.AnyWord, + }, + }, nil).Once() + meta.EXPECT().ListCollections(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return([]*model.Collection{ + { + DBID: 1, + CollectionID: 100, + Name: "foo", + CreateTime: tsoutil.GetCurrentTime(), + }, + }, nil).Once() + + task := &showCollectionTask{ + baseTask: newBaseTask(context.Background(), core), + Req: &milvuspb.ShowCollectionsRequest{DbName: "default"}, + Rsp: &milvuspb.ShowCollectionsResponse{}, + } + ctx := GetContext(context.Background(), "foo:root") + err := task.Execute(ctx) + assert.NoError(t, err) + assert.Equal(t, 1, len(task.Rsp.GetCollectionNames())) + assert.Equal(t, "foo", task.Rsp.GetCollectionNames()[0]) + }) + t.Run("normal", func(t *testing.T) { + Params.Save(Params.CommonCfg.AuthorizationEnabled.Key, "true") + defer Params.Reset(Params.CommonCfg.AuthorizationEnabled.Key) + meta := mockrootcoord.NewIMetaTable(t) + core := newTestCore(withMeta(meta)) + + meta.EXPECT().SelectUser(mock.Anything, mock.Anything, mock.Anything). + Return([]*milvuspb.UserResult{ + { + User: &milvuspb.UserEntity{ + Name: "foo", + }, + Roles: []*milvuspb.RoleEntity{ + { + Name: "hoooo", + }, + }, + }, + }, nil).Once() + meta.EXPECT().SelectGrant(mock.Anything, mock.Anything).Return([]*milvuspb.GrantEntity{ + { + Object: &milvuspb.ObjectEntity{Name: commonpb.ObjectType_Collection.String()}, + ObjectName: "a", + }, + { + Object: &milvuspb.ObjectEntity{Name: commonpb.ObjectType_Global.String()}, + }, + { + Object: &milvuspb.ObjectEntity{Name: commonpb.ObjectType_Collection.String()}, + ObjectName: "b", + }, + }, nil).Once() + meta.EXPECT().ListCollections(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return([]*model.Collection{ + { + DBID: 1, + CollectionID: 100, + Name: "foo", + CreateTime: tsoutil.GetCurrentTime(), + }, + { + DBID: 1, + CollectionID: 200, + Name: "a", + CreateTime: tsoutil.GetCurrentTime(), + }, + }, nil).Once() + + task := &showCollectionTask{ + baseTask: newBaseTask(context.Background(), core), + Req: &milvuspb.ShowCollectionsRequest{DbName: "default"}, + Rsp: &milvuspb.ShowCollectionsResponse{}, + } + ctx := GetContext(context.Background(), "foo:root") + err := task.Execute(ctx) + assert.NoError(t, err) + assert.Equal(t, 1, len(task.Rsp.GetCollectionNames())) + assert.Equal(t, "a", task.Rsp.GetCollectionNames()[0]) + }) +}