From f5f053f1d2f26ef3ac0411ea78e3aa4a9fa956e9 Mon Sep 17 00:00:00 2001 From: congqixia Date: Mon, 13 Oct 2025 11:15:58 +0800 Subject: [PATCH] enhance: Refactor privilege management by extracting privilege cache into separate package (#44762) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Related to #44761 This commit refactors the privilege management system in the proxy component by: 1. **Separation of Concerns**: Extracts privilege-related functionality from MetaCache into a dedicated `internal/proxy/privilege` package, improving code organization and maintainability. 2. **New Package Structure**: Creates `internal/proxy/privilege/` with: - `cache.go`: Core privilege cache implementation (PrivilegeCache) - `result_cache.go`: Privilege enforcement result caching - `model.go`: Casbin model and policy enforcement functions - `meta_cache_adapter.go`: Casbin adapter for MetaCache integration - Corresponding test files and mock implementations 3. **MetaCache Simplification**: Removes privilege and credential management methods from MetaCache interface and implementation: - Removed: GetCredentialInfo, RemoveCredential, UpdateCredential - Removed: GetPrivilegeInfo, GetUserRole, RefreshPolicyInfo, InitPolicyInfo - Deleted: meta_cache_adapter.go, privilege_cache.go and their tests 4. **Updated References**: Updates all callsites to use the new privilegeCache global: - Authentication interceptor now uses privilegeCache for password verification - Credential cache operations (InvalidateCredentialCache, UpdateCredentialCache, UpdateCredential) now use privilegeCache - Policy refresh operations (RefreshPolicyInfoCache) now use privilegeCache - Privilege interceptor uses new privilege.GetEnforcer() and privilege result cache 5. **Improved API**: Renames cache functions for clarity: - GetPrivilegeCache → GetResultCache - SetPrivilegeCache → SetResultCache - CleanPrivilegeCache → CleanResultCache This refactoring makes the codebase more modular, separates privilege management concerns from general metadata caching, and provides a clearer API for privilege enforcement operations. --------- Signed-off-by: Congqi Xia --- internal/proxy/authentication_interceptor.go | 3 +- .../proxy/authentication_interceptor_test.go | 3 +- internal/proxy/impl.go | 18 +- internal/proxy/meta_cache.go | 202 +---------- internal/proxy/meta_cache_test.go | 39 +- internal/proxy/meta_cache_testonly.go | 11 +- internal/proxy/privilege/OWNERS | 7 + internal/proxy/privilege/README.md | 3 + internal/proxy/privilege/cache.go | 267 ++++++++++++++ internal/proxy/privilege/cache_testonly.go | 27 ++ .../{ => privilege}/meta_cache_adapter.go | 6 +- .../meta_cache_adapter_test.go | 16 +- internal/proxy/privilege/mock_cache.go | 340 ++++++++++++++++++ internal/proxy/privilege/model.go | 139 +++++++ .../result_cache.go} | 16 +- .../result_cache_test.go} | 16 +- internal/proxy/privilege_interceptor.go | 129 +------ internal/proxy/privilege_interceptor_test.go | 27 +- internal/proxy/proxy_test.go | 3 +- internal/proxy/util.go | 14 +- internal/proxy/util_test.go | 66 ++-- 21 files changed, 931 insertions(+), 421 deletions(-) create mode 100644 internal/proxy/privilege/OWNERS create mode 100644 internal/proxy/privilege/README.md create mode 100644 internal/proxy/privilege/cache.go create mode 100644 internal/proxy/privilege/cache_testonly.go rename internal/proxy/{ => privilege}/meta_cache_adapter.go (95%) rename internal/proxy/{ => privilege}/meta_cache_adapter_test.go (84%) create mode 100644 internal/proxy/privilege/mock_cache.go create mode 100644 internal/proxy/privilege/model.go rename internal/proxy/{privilege_cache.go => privilege/result_cache.go} (83%) rename internal/proxy/{privilege_cache_test.go => privilege/result_cache_test.go} (80%) diff --git a/internal/proxy/authentication_interceptor.go b/internal/proxy/authentication_interceptor.go index 210a70e3a9..97b767b8fe 100644 --- a/internal/proxy/authentication_interceptor.go +++ b/internal/proxy/authentication_interceptor.go @@ -13,6 +13,7 @@ import ( "google.golang.org/grpc/status" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus/internal/proxy/privilege" "github.com/milvus-io/milvus/internal/util/hookutil" "github.com/milvus-io/milvus/pkg/v2/log" "github.com/milvus-io/milvus/pkg/v2/metrics" @@ -111,7 +112,7 @@ func AuthenticationInterceptor(ctx context.Context) (context.Context, error) { } else { // username+password authentication username, password := parseMD(rawToken) - if !passwordVerify(ctx, username, password, globalMetaCache) { + if !passwordVerify(ctx, username, password, privilege.GetPrivilegeCache()) { log.Warn("fail to verify password", zap.String("username", username)) // NOTE: don't use the merr, because it will cause the wrong retry behavior in the sdk return nil, status.Error(codes.Unauthenticated, "auth check failure, please check username and password are correct") diff --git a/internal/proxy/authentication_interceptor_test.go b/internal/proxy/authentication_interceptor_test.go index bd0ba45dc6..c204dde911 100644 --- a/internal/proxy/authentication_interceptor_test.go +++ b/internal/proxy/authentication_interceptor_test.go @@ -9,6 +9,7 @@ import ( "github.com/stretchr/testify/assert" "google.golang.org/grpc/metadata" + "github.com/milvus-io/milvus/internal/proxy/privilege" "github.com/milvus-io/milvus/internal/util/hookutil" "github.com/milvus-io/milvus/pkg/v2/util" "github.com/milvus-io/milvus/pkg/v2/util/crypto" @@ -27,7 +28,7 @@ func TestValidAuth(t *testing.T) { if username == "" || password == "" { return false } - return passwordVerify(ctx, username, password, globalMetaCache) + return passwordVerify(ctx, username, password, privilege.GetPrivilegeCache()) } ctx := context.Background() diff --git a/internal/proxy/impl.go b/internal/proxy/impl.go index 6f60eaee8e..feb62da5fa 100644 --- a/internal/proxy/impl.go +++ b/internal/proxy/impl.go @@ -43,6 +43,7 @@ import ( "github.com/milvus-io/milvus/internal/distributed/streaming" "github.com/milvus-io/milvus/internal/http" "github.com/milvus-io/milvus/internal/proxy/connection" + "github.com/milvus-io/milvus/internal/proxy/privilege" "github.com/milvus-io/milvus/internal/proxy/replicate" "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/util/analyzer" @@ -4689,8 +4690,9 @@ func (node *Proxy) InvalidateCredentialCache(ctx context.Context, request *proxy } username := request.Username - if globalMetaCache != nil { - globalMetaCache.RemoveCredential(username) // no need to return error, though credential may be not cached + priCache := privilege.GetPrivilegeCache() + if priCache != nil { + priCache.RemoveCredential(username) // no need to return error, though credential may be not cached } log.Debug("complete to invalidate credential cache") @@ -4715,8 +4717,9 @@ func (node *Proxy) UpdateCredentialCache(ctx context.Context, request *proxypb.U Username: request.Username, Sha256Password: request.Password, } - if globalMetaCache != nil { - globalMetaCache.UpdateCredential(credInfo) // no need to return error, though credential may be not cached + priCache := privilege.GetPrivilegeCache() + if priCache != nil { + priCache.UpdateCredential(credInfo) // no need to return error, though credential may be not cached } log.Debug("complete to update credential cache") @@ -4820,7 +4823,7 @@ func (node *Proxy) UpdateCredential(ctx context.Context, req *milvuspb.UpdateCre } } - if !skipPasswordVerify && !passwordVerify(ctx, req.Username, rawOldPassword, globalMetaCache) { + if !skipPasswordVerify && !passwordVerify(ctx, req.Username, rawOldPassword, privilege.GetPrivilegeCache()) { err := merr.WrapErrPrivilegeNotAuthenticated("old password not correct for %s", req.GetUsername()) return merr.Status(err), nil } @@ -5347,8 +5350,9 @@ func (node *Proxy) RefreshPolicyInfoCache(ctx context.Context, req *proxypb.Refr return merr.Status(err), nil } - if globalMetaCache != nil { - err := globalMetaCache.RefreshPolicyInfo(typeutil.CacheOp{ + priCache := privilege.GetPrivilegeCache() + if priCache != nil { + err := priCache.RefreshPolicyInfo(typeutil.CacheOp{ OpType: typeutil.CacheOpType(req.OpType), OpKey: req.OpKey, }) diff --git a/internal/proxy/meta_cache.go b/internal/proxy/meta_cache.go index 195bba4ee8..7045b4f06a 100644 --- a/internal/proxy/meta_cache.go +++ b/internal/proxy/meta_cache.go @@ -24,7 +24,6 @@ import ( "strings" "sync" - "github.com/cockroachdb/errors" "github.com/samber/lo" "go.uber.org/atomic" "go.uber.org/zap" @@ -32,6 +31,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/proxy/privilege" "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/pkg/v2/common" "github.com/milvus-io/milvus/pkg/v2/log" @@ -39,11 +39,9 @@ import ( "github.com/milvus-io/milvus/pkg/v2/proto/internalpb" "github.com/milvus-io/milvus/pkg/v2/proto/querypb" "github.com/milvus-io/milvus/pkg/v2/proto/rootcoordpb" - "github.com/milvus-io/milvus/pkg/v2/util" "github.com/milvus-io/milvus/pkg/v2/util/commonpbutil" "github.com/milvus-io/milvus/pkg/v2/util/conc" "github.com/milvus-io/milvus/pkg/v2/util/expr" - "github.com/milvus-io/milvus/pkg/v2/util/funcutil" "github.com/milvus-io/milvus/pkg/v2/util/merr" "github.com/milvus-io/milvus/pkg/v2/util/paramtable" "github.com/milvus-io/milvus/pkg/v2/util/timerecord" @@ -79,14 +77,14 @@ type Cache interface { RemoveCollectionsByID(ctx context.Context, collectionID UniqueID, version uint64, removeVersion bool) []string // GetCredentialInfo operate credential cache - GetCredentialInfo(ctx context.Context, username string) (*internalpb.CredentialInfo, error) - RemoveCredential(username string) - UpdateCredential(credInfo *internalpb.CredentialInfo) + // GetCredentialInfo(ctx context.Context, username string) (*internalpb.CredentialInfo, error) + // RemoveCredential(username string) + // UpdateCredential(credInfo *internalpb.CredentialInfo) - GetPrivilegeInfo(ctx context.Context) []string - GetUserRole(username string) []string - RefreshPolicyInfo(op typeutil.CacheOp) error - InitPolicyInfo(info []string, userRoles []string) + // GetPrivilegeInfo(ctx context.Context) []string + // GetUserRole(username string) []string + // RefreshPolicyInfo(op typeutil.CacheOp) error + // InitPolicyInfo(info []string, userRoles []string) RemoveDatabase(ctx context.Context, database string) HasDatabase(ctx context.Context, database string) bool @@ -382,14 +380,12 @@ func InitMetaCache(ctx context.Context, mixCoord types.MixCoordClient, shardMgr } expr.Register("cache", globalMetaCache) - // The privilege info is a little more. And to get this info, the query operation of involving multiple table queries is required. - resp, err := mixCoord.ListPolicy(ctx, &internalpb.ListPolicyRequest{}) - if err = merr.CheckRPCCall(resp, err); err != nil { - log.Error("fail to init meta cache", zap.Error(err)) + err = privilege.InitPrivilegeCache(ctx, mixCoord) + if err != nil { + log.Error("failed to init privilege cache", zap.Error(err)) return err } - globalMetaCache.InitPolicyInfo(resp.PolicyInfos, resp.UserRoles) - log.Info("success to init meta cache", zap.Strings("policy_infos", resp.PolicyInfos)) + return nil } @@ -902,55 +898,6 @@ func (m *MetaCache) RemoveCollectionsByID(ctx context.Context, collectionID Uniq return collNames } -// GetCredentialInfo returns the credential related to provided username -// If the cache missed, proxy will try to fetch from storage -func (m *MetaCache) GetCredentialInfo(ctx context.Context, username string) (*internalpb.CredentialInfo, error) { - m.credMut.RLock() - var credInfo *internalpb.CredentialInfo - credInfo, ok := m.credMap[username] - m.credMut.RUnlock() - - if !ok { - req := &rootcoordpb.GetCredentialRequest{ - Base: commonpbutil.NewMsgBase( - commonpbutil.WithMsgType(commonpb.MsgType_GetCredential), - ), - Username: username, - } - resp, err := m.mixCoord.GetCredential(ctx, req) - if err != nil { - return &internalpb.CredentialInfo{}, err - } - credInfo = &internalpb.CredentialInfo{ - Username: resp.Username, - EncryptedPassword: resp.Password, - } - } - - return credInfo, nil -} - -func (m *MetaCache) RemoveCredential(username string) { - m.credMut.Lock() - defer m.credMut.Unlock() - // delete pair in credMap - delete(m.credMap, username) -} - -func (m *MetaCache) UpdateCredential(credInfo *internalpb.CredentialInfo) { - m.credMut.Lock() - defer m.credMut.Unlock() - username := credInfo.Username - _, ok := m.credMap[username] - if !ok { - m.credMap[username] = &internalpb.CredentialInfo{} - } - - // Do not cache encrypted password content - m.credMap[username].Username = username - m.credMap[username].Sha256Password = credInfo.Sha256Password -} - func (m *MetaCache) GetShard(ctx context.Context, withCache bool, database, collectionName string, collectionID int64, channel string) ([]nodeInfo, error) { method := "GetShard" // check cache first @@ -1127,131 +1074,6 @@ func (m *MetaCache) InvalidateShardLeaderCache(collections []int64) { } } -func (m *MetaCache) InitPolicyInfo(info []string, userRoles []string) { - defer func() { - err := getEnforcer().LoadPolicy() - if err != nil { - log.Error("failed to load policy after RefreshPolicyInfo", zap.Error(err)) - } - CleanPrivilegeCache() - }() - m.mu.Lock() - defer m.mu.Unlock() - m.unsafeInitPolicyInfo(info, userRoles) -} - -func (m *MetaCache) unsafeInitPolicyInfo(info []string, userRoles []string) { - m.privilegeInfos = util.StringSet(info) - for _, userRole := range userRoles { - user, role, err := funcutil.DecodeUserRoleCache(userRole) - if err != nil { - log.Warn("invalid user-role key", zap.String("user-role", userRole), zap.Error(err)) - continue - } - if m.userToRoles[user] == nil { - m.userToRoles[user] = make(map[string]struct{}) - } - m.userToRoles[user][role] = struct{}{} - } -} - -func (m *MetaCache) GetPrivilegeInfo(ctx context.Context) []string { - m.mu.RLock() - defer m.mu.RUnlock() - - return util.StringList(m.privilegeInfos) -} - -func (m *MetaCache) GetUserRole(user string) []string { - m.mu.RLock() - defer m.mu.RUnlock() - - return util.StringList(m.userToRoles[user]) -} - -func (m *MetaCache) RefreshPolicyInfo(op typeutil.CacheOp) (err error) { - defer func() { - if err == nil { - le := getEnforcer().LoadPolicy() - if le != nil { - log.Error("failed to load policy after RefreshPolicyInfo", zap.Error(le)) - } - CleanPrivilegeCache() - } - }() - if op.OpType != typeutil.CacheRefresh { - m.mu.Lock() - defer m.mu.Unlock() - if op.OpKey == "" { - return errors.New("empty op key") - } - } - - switch op.OpType { - case typeutil.CacheGrantPrivilege: - keys := funcutil.PrivilegesForPolicy(op.OpKey) - for _, key := range keys { - m.privilegeInfos[key] = struct{}{} - } - case typeutil.CacheRevokePrivilege: - keys := funcutil.PrivilegesForPolicy(op.OpKey) - for _, key := range keys { - delete(m.privilegeInfos, key) - } - case typeutil.CacheAddUserToRole: - user, role, err := funcutil.DecodeUserRoleCache(op.OpKey) - if err != nil { - return fmt.Errorf("invalid opKey, fail to decode, op_type: %d, op_key: %s", int(op.OpType), op.OpKey) - } - if m.userToRoles[user] == nil { - m.userToRoles[user] = make(map[string]struct{}) - } - m.userToRoles[user][role] = struct{}{} - case typeutil.CacheRemoveUserFromRole: - user, role, err := funcutil.DecodeUserRoleCache(op.OpKey) - if err != nil { - return fmt.Errorf("invalid opKey, fail to decode, op_type: %d, op_key: %s", int(op.OpType), op.OpKey) - } - if m.userToRoles[user] != nil { - delete(m.userToRoles[user], role) - } - case typeutil.CacheDeleteUser: - delete(m.userToRoles, op.OpKey) - case typeutil.CacheDropRole: - for user := range m.userToRoles { - delete(m.userToRoles[user], op.OpKey) - } - - for policy := range m.privilegeInfos { - if funcutil.PolicyCheckerWithRole(policy, op.OpKey) { - delete(m.privilegeInfos, policy) - } - } - case typeutil.CacheRefresh: - resp, err := m.mixCoord.ListPolicy(context.Background(), &internalpb.ListPolicyRequest{}) - if err != nil { - log.Error("fail to init meta cache", zap.Error(err)) - return err - } - - if !merr.Ok(resp.GetStatus()) { - log.Error("fail to init meta cache", - zap.String("error_code", resp.GetStatus().GetErrorCode().String()), - zap.String("reason", resp.GetStatus().GetReason())) - return merr.Error(resp.Status) - } - - m.mu.Lock() - defer m.mu.Unlock() - m.userToRoles = make(map[string]map[string]struct{}) - m.privilegeInfos = make(map[string]struct{}) - m.unsafeInitPolicyInfo(resp.PolicyInfos, resp.UserRoles) - default: - return fmt.Errorf("invalid opType, op_type: %d, op_key: %s", int(op.OpType), op.OpKey) - } - return nil -} - func (m *MetaCache) RemoveDatabase(ctx context.Context, database string) { log.Ctx(ctx).Debug("remove database", zap.String("name", database)) m.mu.Lock() diff --git a/internal/proxy/meta_cache_test.go b/internal/proxy/meta_cache_test.go index cd06fecb9f..66e254c5f8 100644 --- a/internal/proxy/meta_cache_test.go +++ b/internal/proxy/meta_cache_test.go @@ -35,6 +35,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/mocks" + "github.com/milvus-io/milvus/internal/proxy/privilege" "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/pkg/v2/common" "github.com/milvus-io/milvus/pkg/v2/proto/datapb" @@ -1510,9 +1511,9 @@ func TestMetaCache_PolicyInfo(t *testing.T) { } err := InitMetaCache(context.Background(), client, mgr) assert.NoError(t, err) - policyInfos := globalMetaCache.GetPrivilegeInfo(context.Background()) + policyInfos := privilege.GetPrivilegeCache().GetPrivilegeInfo(context.Background()) assert.Equal(t, 3, len(policyInfos)) - roles := globalMetaCache.GetUserRole("foo") + roles := privilege.GetPrivilegeCache().GetUserRole("foo") assert.Equal(t, 2, len(roles)) }) @@ -1527,29 +1528,29 @@ func TestMetaCache_PolicyInfo(t *testing.T) { err := InitMetaCache(context.Background(), client, mgr) assert.NoError(t, err) - err = globalMetaCache.RefreshPolicyInfo(typeutil.CacheOp{OpType: typeutil.CacheGrantPrivilege, OpKey: "policyX"}) + err = privilege.GetPrivilegeCache().RefreshPolicyInfo(typeutil.CacheOp{OpType: typeutil.CacheGrantPrivilege, OpKey: "policyX"}) assert.NoError(t, err) - policyInfos := globalMetaCache.GetPrivilegeInfo(context.Background()) + policyInfos := privilege.GetPrivilegeCache().GetPrivilegeInfo(context.Background()) assert.Equal(t, 4, len(policyInfos)) - err = globalMetaCache.RefreshPolicyInfo(typeutil.CacheOp{OpType: typeutil.CacheRevokePrivilege, OpKey: "policyX"}) + err = privilege.GetPrivilegeCache().RefreshPolicyInfo(typeutil.CacheOp{OpType: typeutil.CacheRevokePrivilege, OpKey: "policyX"}) assert.NoError(t, err) - policyInfos = globalMetaCache.GetPrivilegeInfo(context.Background()) + policyInfos = privilege.GetPrivilegeCache().GetPrivilegeInfo(context.Background()) assert.Equal(t, 3, len(policyInfos)) - err = globalMetaCache.RefreshPolicyInfo(typeutil.CacheOp{OpType: typeutil.CacheAddUserToRole, OpKey: funcutil.EncodeUserRoleCache("foo", "role3")}) + err = privilege.GetPrivilegeCache().RefreshPolicyInfo(typeutil.CacheOp{OpType: typeutil.CacheAddUserToRole, OpKey: funcutil.EncodeUserRoleCache("foo", "role3")}) assert.NoError(t, err) - roles := globalMetaCache.GetUserRole("foo") + roles := privilege.GetPrivilegeCache().GetUserRole("foo") assert.Equal(t, 3, len(roles)) - err = globalMetaCache.RefreshPolicyInfo(typeutil.CacheOp{OpType: typeutil.CacheRemoveUserFromRole, OpKey: funcutil.EncodeUserRoleCache("foo", "role3")}) + err = privilege.GetPrivilegeCache().RefreshPolicyInfo(typeutil.CacheOp{OpType: typeutil.CacheRemoveUserFromRole, OpKey: funcutil.EncodeUserRoleCache("foo", "role3")}) assert.NoError(t, err) - roles = globalMetaCache.GetUserRole("foo") + roles = privilege.GetPrivilegeCache().GetUserRole("foo") assert.Equal(t, 2, len(roles)) - err = globalMetaCache.RefreshPolicyInfo(typeutil.CacheOp{OpType: typeutil.CacheGrantPrivilege, OpKey: ""}) + err = privilege.GetPrivilegeCache().RefreshPolicyInfo(typeutil.CacheOp{OpType: typeutil.CacheGrantPrivilege, OpKey: ""}) assert.Error(t, err) - err = globalMetaCache.RefreshPolicyInfo(typeutil.CacheOp{OpType: 100, OpKey: "policyX"}) + err = privilege.GetPrivilegeCache().RefreshPolicyInfo(typeutil.CacheOp{OpType: 100, OpKey: "policyX"}) assert.Error(t, err) }) @@ -1568,18 +1569,18 @@ func TestMetaCache_PolicyInfo(t *testing.T) { err := InitMetaCache(context.Background(), client, mgr) assert.NoError(t, err) - err = globalMetaCache.RefreshPolicyInfo(typeutil.CacheOp{OpType: typeutil.CacheDeleteUser, OpKey: "foo"}) + err = privilege.GetPrivilegeCache().RefreshPolicyInfo(typeutil.CacheOp{OpType: typeutil.CacheDeleteUser, OpKey: "foo"}) assert.NoError(t, err) - roles := globalMetaCache.GetUserRole("foo") + roles := privilege.GetPrivilegeCache().GetUserRole("foo") assert.Len(t, roles, 0) - roles = globalMetaCache.GetUserRole("foo2") + roles = privilege.GetPrivilegeCache().GetUserRole("foo2") assert.Len(t, roles, 2) - err = globalMetaCache.RefreshPolicyInfo(typeutil.CacheOp{OpType: typeutil.CacheDropRole, OpKey: "role2"}) + err = privilege.GetPrivilegeCache().RefreshPolicyInfo(typeutil.CacheOp{OpType: typeutil.CacheDropRole, OpKey: "role2"}) assert.NoError(t, err) - roles = globalMetaCache.GetUserRole("foo2") + roles = privilege.GetPrivilegeCache().GetUserRole("foo2") assert.Len(t, roles, 1) assert.Equal(t, "role3", roles[0]) @@ -1590,9 +1591,9 @@ func TestMetaCache_PolicyInfo(t *testing.T) { UserRoles: []string{funcutil.EncodeUserRoleCache("foo", "role1"), funcutil.EncodeUserRoleCache("foo", "role2"), funcutil.EncodeUserRoleCache("foo2", "role2"), funcutil.EncodeUserRoleCache("foo2", "role3")}, }, nil } - err = globalMetaCache.RefreshPolicyInfo(typeutil.CacheOp{OpType: typeutil.CacheRefresh}) + err = privilege.GetPrivilegeCache().RefreshPolicyInfo(typeutil.CacheOp{OpType: typeutil.CacheRefresh}) assert.NoError(t, err) - roles = globalMetaCache.GetUserRole("foo") + roles = privilege.GetPrivilegeCache().GetUserRole("foo") assert.Len(t, roles, 2) }) } diff --git a/internal/proxy/meta_cache_testonly.go b/internal/proxy/meta_cache_testonly.go index c12a5ae09e..faf1ebb096 100644 --- a/internal/proxy/meta_cache_testonly.go +++ b/internal/proxy/meta_cache_testonly.go @@ -22,24 +22,29 @@ package proxy import ( + "context" + "github.com/cockroachdb/errors" "github.com/stretchr/testify/mock" "github.com/milvus-io/milvus/internal/mocks" + "github.com/milvus-io/milvus/internal/proxy/privilege" "github.com/milvus-io/milvus/pkg/v2/common" + "github.com/milvus-io/milvus/pkg/v2/proto/internalpb" "github.com/milvus-io/milvus/pkg/v2/util/funcutil" + "github.com/milvus-io/milvus/pkg/v2/util/merr" "github.com/milvus-io/milvus/pkg/v2/util/typeutil" ) func AddRootUserToAdminRole() { - err := globalMetaCache.RefreshPolicyInfo(typeutil.CacheOp{OpType: typeutil.CacheAddUserToRole, OpKey: funcutil.EncodeUserRoleCache("root", "admin")}) + err := privilege.GetPrivilegeCache().RefreshPolicyInfo(typeutil.CacheOp{OpType: typeutil.CacheAddUserToRole, OpKey: funcutil.EncodeUserRoleCache("root", "admin")}) if err != nil { panic(err) } } func RemoveRootUserFromAdminRole() { - err := globalMetaCache.RefreshPolicyInfo(typeutil.CacheOp{OpType: typeutil.CacheRemoveUserFromRole, OpKey: funcutil.EncodeUserRoleCache("root", "admin")}) + err := privilege.GetPrivilegeCache().RefreshPolicyInfo(typeutil.CacheOp{OpType: typeutil.CacheRemoveUserFromRole, OpKey: funcutil.EncodeUserRoleCache("root", "admin")}) if err != nil { panic(err) } @@ -55,6 +60,8 @@ func InitEmptyGlobalCache() { if err != nil { panic(err) } + mixcoord.EXPECT().ListPolicy(mock.Anything, mock.Anything, mock.Anything).Return(&internalpb.ListPolicyResponse{Status: merr.Success()}, nil) + privilege.InitPrivilegeCache(context.Background(), mixcoord) } func SetGlobalMetaCache(metaCache *MetaCache) { diff --git a/internal/proxy/privilege/OWNERS b/internal/proxy/privilege/OWNERS new file mode 100644 index 0000000000..06572603fd --- /dev/null +++ b/internal/proxy/privilege/OWNERS @@ -0,0 +1,7 @@ +reviewers: + - congqixia + - czs007 + - shaoting-huang + +approvers: + - maintainers diff --git a/internal/proxy/privilege/README.md b/internal/proxy/privilege/README.md new file mode 100644 index 0000000000..7b9266e684 --- /dev/null +++ b/internal/proxy/privilege/README.md @@ -0,0 +1,3 @@ +# Summary + +this package contains privilege related components for proxy. \ No newline at end of file diff --git a/internal/proxy/privilege/cache.go b/internal/proxy/privilege/cache.go new file mode 100644 index 0000000000..66a326425c --- /dev/null +++ b/internal/proxy/privilege/cache.go @@ -0,0 +1,267 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package privilege + +import ( + "context" + "fmt" + "sync" + "sync/atomic" + + "github.com/cockroachdb/errors" + "go.uber.org/zap" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/pkg/v2/log" + "github.com/milvus-io/milvus/pkg/v2/proto/internalpb" + "github.com/milvus-io/milvus/pkg/v2/proto/rootcoordpb" + "github.com/milvus-io/milvus/pkg/v2/util" + "github.com/milvus-io/milvus/pkg/v2/util/commonpbutil" + "github.com/milvus-io/milvus/pkg/v2/util/funcutil" + "github.com/milvus-io/milvus/pkg/v2/util/merr" + "github.com/milvus-io/milvus/pkg/v2/util/typeutil" +) + +var cacheInst atomic.Pointer[privilegeCache] + +func GetPrivilegeCache() *privilegeCache { + return cacheInst.Load() +} + +type PrivilegeCache interface { + // GetCredentialInfo operate credential cache + GetCredentialInfo(ctx context.Context, username string) (*internalpb.CredentialInfo, error) + RemoveCredential(username string) + UpdateCredential(credInfo *internalpb.CredentialInfo) + + GetPrivilegeInfo(ctx context.Context) []string + GetUserRole(username string) []string + RefreshPolicyInfo(op typeutil.CacheOp) error + InitPolicyInfo(info []string, userRoles []string) +} + +var _ PrivilegeCache = (*privilegeCache)(nil) + +type privilegeCache struct { + mixCoord types.MixCoordClient + + mu sync.RWMutex + privilegeInfos map[string]struct{} // privileges cache + userToRoles map[string]map[string]struct{} // user to role cache + + credMut sync.RWMutex + credMap map[string]*internalpb.CredentialInfo +} + +func InitPrivilegeCache(ctx context.Context, mixCoord types.MixCoordClient) error { + privilegeCache := NewPrivilegeCache(mixCoord) + // The privilege info is a little more. And to get this info, the query operation of involving multiple table queries is required. + cacheInst.Store(privilegeCache) + resp, err := mixCoord.ListPolicy(ctx, &internalpb.ListPolicyRequest{}) + if err = merr.CheckRPCCall(resp, err); err != nil { + log.Error("fail to init meta cache", zap.Error(err)) + return err + } + privilegeCache.InitPolicyInfo(resp.PolicyInfos, resp.UserRoles) + log.Info("success to init privilege cache", zap.Strings("policy_infos", resp.PolicyInfos)) + return nil +} + +func NewPrivilegeCache(mixCoord types.MixCoordClient) *privilegeCache { + return &privilegeCache{ + mixCoord: mixCoord, + privilegeInfos: make(map[string]struct{}), + userToRoles: make(map[string]map[string]struct{}), + + credMap: make(map[string]*internalpb.CredentialInfo), + } +} + +// GetCredentialInfo returns the credential related to provided username +// If the cache missed, proxy will try to fetch from storage +func (m *privilegeCache) GetCredentialInfo(ctx context.Context, username string) (*internalpb.CredentialInfo, error) { + m.credMut.RLock() + var credInfo *internalpb.CredentialInfo + credInfo, ok := m.credMap[username] + m.credMut.RUnlock() + + if !ok { + req := &rootcoordpb.GetCredentialRequest{ + Base: commonpbutil.NewMsgBase( + commonpbutil.WithMsgType(commonpb.MsgType_GetCredential), + ), + Username: username, + } + resp, err := m.mixCoord.GetCredential(ctx, req) + if err != nil { + return &internalpb.CredentialInfo{}, err + } + credInfo = &internalpb.CredentialInfo{ + Username: resp.Username, + EncryptedPassword: resp.Password, + } + } + + return credInfo, nil +} + +func (m *privilegeCache) RemoveCredential(username string) { + m.credMut.Lock() + defer m.credMut.Unlock() + // delete pair in credMap + delete(m.credMap, username) +} + +func (m *privilegeCache) UpdateCredential(credInfo *internalpb.CredentialInfo) { + m.credMut.Lock() + defer m.credMut.Unlock() + username := credInfo.Username + _, ok := m.credMap[username] + if !ok { + m.credMap[username] = &internalpb.CredentialInfo{} + } + + // Do not cache encrypted password content + m.credMap[username].Username = username + m.credMap[username].Sha256Password = credInfo.Sha256Password +} + +func (m *privilegeCache) InitPolicyInfo(info []string, userRoles []string) { + defer func() { + err := GetEnforcer().LoadPolicy() + if err != nil { + log.Error("failed to load policy after RefreshPolicyInfo", zap.Error(err)) + } + CleanPrivilegeCache() + }() + m.mu.Lock() + defer m.mu.Unlock() + m.unsafeInitPolicyInfo(info, userRoles) +} + +func (m *privilegeCache) unsafeInitPolicyInfo(info []string, userRoles []string) { + m.privilegeInfos = util.StringSet(info) + for _, userRole := range userRoles { + user, role, err := funcutil.DecodeUserRoleCache(userRole) + if err != nil { + log.Warn("invalid user-role key", zap.String("user-role", userRole), zap.Error(err)) + continue + } + if m.userToRoles[user] == nil { + m.userToRoles[user] = make(map[string]struct{}) + } + m.userToRoles[user][role] = struct{}{} + } +} + +func (m *privilegeCache) GetPrivilegeInfo(ctx context.Context) []string { + m.mu.RLock() + defer m.mu.RUnlock() + + return util.StringList(m.privilegeInfos) +} + +func (m *privilegeCache) GetUserRole(user string) []string { + m.mu.RLock() + defer m.mu.RUnlock() + + return util.StringList(m.userToRoles[user]) +} + +func (m *privilegeCache) RefreshPolicyInfo(op typeutil.CacheOp) (err error) { + defer func() { + if err == nil { + le := GetEnforcer().LoadPolicy() + if le != nil { + log.Error("failed to load policy after RefreshPolicyInfo", zap.Error(le)) + } + CleanPrivilegeCache() + } + }() + if op.OpType != typeutil.CacheRefresh { + m.mu.Lock() + defer m.mu.Unlock() + if op.OpKey == "" { + return errors.New("empty op key") + } + } + + switch op.OpType { + case typeutil.CacheGrantPrivilege: + keys := funcutil.PrivilegesForPolicy(op.OpKey) + for _, key := range keys { + m.privilegeInfos[key] = struct{}{} + } + case typeutil.CacheRevokePrivilege: + keys := funcutil.PrivilegesForPolicy(op.OpKey) + for _, key := range keys { + delete(m.privilegeInfos, key) + } + case typeutil.CacheAddUserToRole: + user, role, err := funcutil.DecodeUserRoleCache(op.OpKey) + if err != nil { + return fmt.Errorf("invalid opKey, fail to decode, op_type: %d, op_key: %s", int(op.OpType), op.OpKey) + } + if m.userToRoles[user] == nil { + m.userToRoles[user] = make(map[string]struct{}) + } + m.userToRoles[user][role] = struct{}{} + case typeutil.CacheRemoveUserFromRole: + user, role, err := funcutil.DecodeUserRoleCache(op.OpKey) + if err != nil { + return fmt.Errorf("invalid opKey, fail to decode, op_type: %d, op_key: %s", int(op.OpType), op.OpKey) + } + if m.userToRoles[user] != nil { + delete(m.userToRoles[user], role) + } + case typeutil.CacheDeleteUser: + delete(m.userToRoles, op.OpKey) + case typeutil.CacheDropRole: + for user := range m.userToRoles { + delete(m.userToRoles[user], op.OpKey) + } + + for policy := range m.privilegeInfos { + if funcutil.PolicyCheckerWithRole(policy, op.OpKey) { + delete(m.privilegeInfos, policy) + } + } + case typeutil.CacheRefresh: + resp, err := m.mixCoord.ListPolicy(context.Background(), &internalpb.ListPolicyRequest{}) + if err != nil { + log.Error("fail to init meta cache", zap.Error(err)) + return err + } + + if !merr.Ok(resp.GetStatus()) { + log.Error("fail to init meta cache", + zap.String("error_code", resp.GetStatus().GetErrorCode().String()), + zap.String("reason", resp.GetStatus().GetReason())) + return merr.Error(resp.Status) + } + + m.mu.Lock() + defer m.mu.Unlock() + m.userToRoles = make(map[string]map[string]struct{}) + m.privilegeInfos = make(map[string]struct{}) + m.unsafeInitPolicyInfo(resp.PolicyInfos, resp.UserRoles) + default: + return fmt.Errorf("invalid opType, op_type: %d, op_key: %s", int(op.OpType), op.OpKey) + } + return nil +} diff --git a/internal/proxy/privilege/cache_testonly.go b/internal/proxy/privilege/cache_testonly.go new file mode 100644 index 0000000000..f1c7b836c6 --- /dev/null +++ b/internal/proxy/privilege/cache_testonly.go @@ -0,0 +1,27 @@ +//go:build test +// +build test + +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package privilege + +// This file contains only functions used in tests to manipulate the privilege cache. + +// ResetPrivilegeCacheForTest resets the privilege cache for testing purposes. +func ResetPrivilegeCacheForTest() { + cacheInst.Store(nil) +} diff --git a/internal/proxy/meta_cache_adapter.go b/internal/proxy/privilege/meta_cache_adapter.go similarity index 95% rename from internal/proxy/meta_cache_adapter.go rename to internal/proxy/privilege/meta_cache_adapter.go index dd1799e42c..3745802d9a 100644 --- a/internal/proxy/meta_cache_adapter.go +++ b/internal/proxy/privilege/meta_cache_adapter.go @@ -14,7 +14,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package proxy +package privilege import ( "context" @@ -30,10 +30,10 @@ import ( // MetaCacheCasbinAdapter is the implementation of `persist.Adapter` with Cache // Since the usage shall be read-only, it implements only `LoadPolicy` for now. type MetaCacheCasbinAdapter struct { - cacheSource func() Cache + cacheSource func() PrivilegeCache } -func NewMetaCacheCasbinAdapter(cacheSource func() Cache) *MetaCacheCasbinAdapter { +func NewMetaCacheCasbinAdapter(cacheSource func() PrivilegeCache) *MetaCacheCasbinAdapter { return &MetaCacheCasbinAdapter{ cacheSource: cacheSource, } diff --git a/internal/proxy/meta_cache_adapter_test.go b/internal/proxy/privilege/meta_cache_adapter_test.go similarity index 84% rename from internal/proxy/meta_cache_adapter_test.go rename to internal/proxy/privilege/meta_cache_adapter_test.go index 63c48351b3..a11757cce2 100644 --- a/internal/proxy/meta_cache_adapter_test.go +++ b/internal/proxy/privilege/meta_cache_adapter_test.go @@ -14,7 +14,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package proxy +package privilege import ( "testing" @@ -26,36 +26,36 @@ import ( type MetaCacheCasbinAdapterSuite struct { suite.Suite - cache *MockCache + cache *MockPrivilegeCache adapter *MetaCacheCasbinAdapter } func (s *MetaCacheCasbinAdapterSuite) SetupTest() { - s.cache = NewMockCache(s.T()) + s.cache = NewMockPrivilegeCache(s.T()) - s.adapter = NewMetaCacheCasbinAdapter(func() Cache { return s.cache }) + s.adapter = NewMetaCacheCasbinAdapter(func() PrivilegeCache { return s.cache }) } func (s *MetaCacheCasbinAdapterSuite) TestLoadPolicy() { s.Run("normal_load", func() { s.cache.EXPECT().GetPrivilegeInfo(mock.Anything).Return([]string{}) - m := getPolicyModel(ModelStr) + m := GetPolicyModel(ModelStr) err := s.adapter.LoadPolicy(m) s.NoError(err) }) s.Run("source_return_nil", func() { - adapter := NewMetaCacheCasbinAdapter(func() Cache { return nil }) + adapter := NewMetaCacheCasbinAdapter(func() PrivilegeCache { return nil }) - m := getPolicyModel(ModelStr) + m := GetPolicyModel(ModelStr) err := adapter.LoadPolicy(m) s.Error(err) }) } func (s *MetaCacheCasbinAdapterSuite) TestSavePolicy() { - m := getPolicyModel(ModelStr) + m := GetPolicyModel(ModelStr) s.Error(s.adapter.SavePolicy(m)) } diff --git a/internal/proxy/privilege/mock_cache.go b/internal/proxy/privilege/mock_cache.go new file mode 100644 index 0000000000..4596a3cda4 --- /dev/null +++ b/internal/proxy/privilege/mock_cache.go @@ -0,0 +1,340 @@ +// Code generated by mockery v2.53.3. DO NOT EDIT. + +package privilege + +import ( + context "context" + + internalpb "github.com/milvus-io/milvus/pkg/v2/proto/internalpb" + mock "github.com/stretchr/testify/mock" + + typeutil "github.com/milvus-io/milvus/pkg/v2/util/typeutil" +) + +// MockPrivilegeCache is an autogenerated mock type for the PrivilegeCache type +type MockPrivilegeCache struct { + mock.Mock +} + +type MockPrivilegeCache_Expecter struct { + mock *mock.Mock +} + +func (_m *MockPrivilegeCache) EXPECT() *MockPrivilegeCache_Expecter { + return &MockPrivilegeCache_Expecter{mock: &_m.Mock} +} + +// GetCredentialInfo provides a mock function with given fields: ctx, username +func (_m *MockPrivilegeCache) GetCredentialInfo(ctx context.Context, username string) (*internalpb.CredentialInfo, error) { + ret := _m.Called(ctx, username) + + if len(ret) == 0 { + panic("no return value specified for GetCredentialInfo") + } + + var r0 *internalpb.CredentialInfo + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (*internalpb.CredentialInfo, error)); ok { + return rf(ctx, username) + } + if rf, ok := ret.Get(0).(func(context.Context, string) *internalpb.CredentialInfo); ok { + r0 = rf(ctx, username) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*internalpb.CredentialInfo) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, username) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockPrivilegeCache_GetCredentialInfo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetCredentialInfo' +type MockPrivilegeCache_GetCredentialInfo_Call struct { + *mock.Call +} + +// GetCredentialInfo is a helper method to define mock.On call +// - ctx context.Context +// - username string +func (_e *MockPrivilegeCache_Expecter) GetCredentialInfo(ctx interface{}, username interface{}) *MockPrivilegeCache_GetCredentialInfo_Call { + return &MockPrivilegeCache_GetCredentialInfo_Call{Call: _e.mock.On("GetCredentialInfo", ctx, username)} +} + +func (_c *MockPrivilegeCache_GetCredentialInfo_Call) Run(run func(ctx context.Context, username string)) *MockPrivilegeCache_GetCredentialInfo_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string)) + }) + return _c +} + +func (_c *MockPrivilegeCache_GetCredentialInfo_Call) Return(_a0 *internalpb.CredentialInfo, _a1 error) *MockPrivilegeCache_GetCredentialInfo_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockPrivilegeCache_GetCredentialInfo_Call) RunAndReturn(run func(context.Context, string) (*internalpb.CredentialInfo, error)) *MockPrivilegeCache_GetCredentialInfo_Call { + _c.Call.Return(run) + return _c +} + +// GetPrivilegeInfo provides a mock function with given fields: ctx +func (_m *MockPrivilegeCache) GetPrivilegeInfo(ctx context.Context) []string { + ret := _m.Called(ctx) + + if len(ret) == 0 { + panic("no return value specified for GetPrivilegeInfo") + } + + var r0 []string + if rf, ok := ret.Get(0).(func(context.Context) []string); ok { + r0 = rf(ctx) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]string) + } + } + + return r0 +} + +// MockPrivilegeCache_GetPrivilegeInfo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetPrivilegeInfo' +type MockPrivilegeCache_GetPrivilegeInfo_Call struct { + *mock.Call +} + +// GetPrivilegeInfo is a helper method to define mock.On call +// - ctx context.Context +func (_e *MockPrivilegeCache_Expecter) GetPrivilegeInfo(ctx interface{}) *MockPrivilegeCache_GetPrivilegeInfo_Call { + return &MockPrivilegeCache_GetPrivilegeInfo_Call{Call: _e.mock.On("GetPrivilegeInfo", ctx)} +} + +func (_c *MockPrivilegeCache_GetPrivilegeInfo_Call) Run(run func(ctx context.Context)) *MockPrivilegeCache_GetPrivilegeInfo_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context)) + }) + return _c +} + +func (_c *MockPrivilegeCache_GetPrivilegeInfo_Call) Return(_a0 []string) *MockPrivilegeCache_GetPrivilegeInfo_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockPrivilegeCache_GetPrivilegeInfo_Call) RunAndReturn(run func(context.Context) []string) *MockPrivilegeCache_GetPrivilegeInfo_Call { + _c.Call.Return(run) + return _c +} + +// GetUserRole provides a mock function with given fields: username +func (_m *MockPrivilegeCache) GetUserRole(username string) []string { + ret := _m.Called(username) + + if len(ret) == 0 { + panic("no return value specified for GetUserRole") + } + + var r0 []string + if rf, ok := ret.Get(0).(func(string) []string); ok { + r0 = rf(username) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]string) + } + } + + return r0 +} + +// MockPrivilegeCache_GetUserRole_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetUserRole' +type MockPrivilegeCache_GetUserRole_Call struct { + *mock.Call +} + +// GetUserRole is a helper method to define mock.On call +// - username string +func (_e *MockPrivilegeCache_Expecter) GetUserRole(username interface{}) *MockPrivilegeCache_GetUserRole_Call { + return &MockPrivilegeCache_GetUserRole_Call{Call: _e.mock.On("GetUserRole", username)} +} + +func (_c *MockPrivilegeCache_GetUserRole_Call) Run(run func(username string)) *MockPrivilegeCache_GetUserRole_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(string)) + }) + return _c +} + +func (_c *MockPrivilegeCache_GetUserRole_Call) Return(_a0 []string) *MockPrivilegeCache_GetUserRole_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockPrivilegeCache_GetUserRole_Call) RunAndReturn(run func(string) []string) *MockPrivilegeCache_GetUserRole_Call { + _c.Call.Return(run) + return _c +} + +// InitPolicyInfo provides a mock function with given fields: info, userRoles +func (_m *MockPrivilegeCache) InitPolicyInfo(info []string, userRoles []string) { + _m.Called(info, userRoles) +} + +// MockPrivilegeCache_InitPolicyInfo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'InitPolicyInfo' +type MockPrivilegeCache_InitPolicyInfo_Call struct { + *mock.Call +} + +// InitPolicyInfo is a helper method to define mock.On call +// - info []string +// - userRoles []string +func (_e *MockPrivilegeCache_Expecter) InitPolicyInfo(info interface{}, userRoles interface{}) *MockPrivilegeCache_InitPolicyInfo_Call { + return &MockPrivilegeCache_InitPolicyInfo_Call{Call: _e.mock.On("InitPolicyInfo", info, userRoles)} +} + +func (_c *MockPrivilegeCache_InitPolicyInfo_Call) Run(run func(info []string, userRoles []string)) *MockPrivilegeCache_InitPolicyInfo_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].([]string), args[1].([]string)) + }) + return _c +} + +func (_c *MockPrivilegeCache_InitPolicyInfo_Call) Return() *MockPrivilegeCache_InitPolicyInfo_Call { + _c.Call.Return() + return _c +} + +func (_c *MockPrivilegeCache_InitPolicyInfo_Call) RunAndReturn(run func([]string, []string)) *MockPrivilegeCache_InitPolicyInfo_Call { + _c.Run(run) + return _c +} + +// RefreshPolicyInfo provides a mock function with given fields: op +func (_m *MockPrivilegeCache) RefreshPolicyInfo(op typeutil.CacheOp) error { + ret := _m.Called(op) + + if len(ret) == 0 { + panic("no return value specified for RefreshPolicyInfo") + } + + var r0 error + if rf, ok := ret.Get(0).(func(typeutil.CacheOp) error); ok { + r0 = rf(op) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockPrivilegeCache_RefreshPolicyInfo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RefreshPolicyInfo' +type MockPrivilegeCache_RefreshPolicyInfo_Call struct { + *mock.Call +} + +// RefreshPolicyInfo is a helper method to define mock.On call +// - op typeutil.CacheOp +func (_e *MockPrivilegeCache_Expecter) RefreshPolicyInfo(op interface{}) *MockPrivilegeCache_RefreshPolicyInfo_Call { + return &MockPrivilegeCache_RefreshPolicyInfo_Call{Call: _e.mock.On("RefreshPolicyInfo", op)} +} + +func (_c *MockPrivilegeCache_RefreshPolicyInfo_Call) Run(run func(op typeutil.CacheOp)) *MockPrivilegeCache_RefreshPolicyInfo_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(typeutil.CacheOp)) + }) + return _c +} + +func (_c *MockPrivilegeCache_RefreshPolicyInfo_Call) Return(_a0 error) *MockPrivilegeCache_RefreshPolicyInfo_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockPrivilegeCache_RefreshPolicyInfo_Call) RunAndReturn(run func(typeutil.CacheOp) error) *MockPrivilegeCache_RefreshPolicyInfo_Call { + _c.Call.Return(run) + return _c +} + +// RemoveCredential provides a mock function with given fields: username +func (_m *MockPrivilegeCache) RemoveCredential(username string) { + _m.Called(username) +} + +// MockPrivilegeCache_RemoveCredential_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RemoveCredential' +type MockPrivilegeCache_RemoveCredential_Call struct { + *mock.Call +} + +// RemoveCredential is a helper method to define mock.On call +// - username string +func (_e *MockPrivilegeCache_Expecter) RemoveCredential(username interface{}) *MockPrivilegeCache_RemoveCredential_Call { + return &MockPrivilegeCache_RemoveCredential_Call{Call: _e.mock.On("RemoveCredential", username)} +} + +func (_c *MockPrivilegeCache_RemoveCredential_Call) Run(run func(username string)) *MockPrivilegeCache_RemoveCredential_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(string)) + }) + return _c +} + +func (_c *MockPrivilegeCache_RemoveCredential_Call) Return() *MockPrivilegeCache_RemoveCredential_Call { + _c.Call.Return() + return _c +} + +func (_c *MockPrivilegeCache_RemoveCredential_Call) RunAndReturn(run func(string)) *MockPrivilegeCache_RemoveCredential_Call { + _c.Run(run) + return _c +} + +// UpdateCredential provides a mock function with given fields: credInfo +func (_m *MockPrivilegeCache) UpdateCredential(credInfo *internalpb.CredentialInfo) { + _m.Called(credInfo) +} + +// MockPrivilegeCache_UpdateCredential_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateCredential' +type MockPrivilegeCache_UpdateCredential_Call struct { + *mock.Call +} + +// UpdateCredential is a helper method to define mock.On call +// - credInfo *internalpb.CredentialInfo +func (_e *MockPrivilegeCache_Expecter) UpdateCredential(credInfo interface{}) *MockPrivilegeCache_UpdateCredential_Call { + return &MockPrivilegeCache_UpdateCredential_Call{Call: _e.mock.On("UpdateCredential", credInfo)} +} + +func (_c *MockPrivilegeCache_UpdateCredential_Call) Run(run func(credInfo *internalpb.CredentialInfo)) *MockPrivilegeCache_UpdateCredential_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(*internalpb.CredentialInfo)) + }) + return _c +} + +func (_c *MockPrivilegeCache_UpdateCredential_Call) Return() *MockPrivilegeCache_UpdateCredential_Call { + _c.Call.Return() + return _c +} + +func (_c *MockPrivilegeCache_UpdateCredential_Call) RunAndReturn(run func(*internalpb.CredentialInfo)) *MockPrivilegeCache_UpdateCredential_Call { + _c.Run(run) + return _c +} + +// NewMockPrivilegeCache creates a new instance of MockPrivilegeCache. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockPrivilegeCache(t interface { + mock.TestingT + Cleanup(func()) +}) *MockPrivilegeCache { + mock := &MockPrivilegeCache{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/proxy/privilege/model.go b/internal/proxy/privilege/model.go new file mode 100644 index 0000000000..3074e0b1d1 --- /dev/null +++ b/internal/proxy/privilege/model.go @@ -0,0 +1,139 @@ +package privilege + +import ( + "log" + "strings" + "sync" + + "github.com/casbin/casbin/v2" + "github.com/casbin/casbin/v2/model" + "github.com/samber/lo" + "go.uber.org/zap" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus/pkg/v2/util" + "github.com/milvus-io/milvus/pkg/v2/util/funcutil" + "github.com/milvus-io/milvus/pkg/v2/util/paramtable" +) + +const ( + // sub -> role name, like admin, public + // obj -> contact object with object name, like Global-*, Collection-col1 + // act -> privilege, like CreateCollection, DescribeCollection + ModelStr = ` +[request_definition] +r = sub, obj, act + +[policy_definition] +p = sub, obj, act + +[policy_effect] +e = some(where (p.eft == allow)) + +[matchers] +m = r.sub == p.sub && globMatch(r.obj, p.obj) && globMatch(r.act, p.act) || r.sub == "admin" || (r.sub == p.sub && dbMatch(r.obj, p.obj) && privilegeGroupContains(r.act, p.act, r.obj, p.obj)) +` +) + +var ( + enforcer *casbin.SyncedEnforcer + initOnce sync.Once + initPrivilegeGroupsOnce sync.Once +) + +func GetPolicyModel(modelString string) model.Model { + m, err := model.NewModelFromString(modelString) + if err != nil { + log.Panic("NewModelFromString fail", zap.String("model", ModelStr), zap.Error(err)) + } + return m +} + +func InitPrivilegeGroups() { + initPrivilegeGroupsOnce.Do(func() { + roGroup := paramtable.Get().CommonCfg.ReadOnlyPrivileges.GetAsStrings() + if len(roGroup) == 0 { + roGroup = util.ReadOnlyPrivilegeGroup + } + roPrivileges = lo.SliceToMap(roGroup, func(item string) (string, struct{}) { return item, struct{}{} }) + + rwGroup := paramtable.Get().CommonCfg.ReadWritePrivileges.GetAsStrings() + if len(rwGroup) == 0 { + rwGroup = util.ReadWritePrivilegeGroup + } + rwPrivileges = lo.SliceToMap(rwGroup, func(item string) (string, struct{}) { return item, struct{}{} }) + + adminGroup := paramtable.Get().CommonCfg.AdminPrivileges.GetAsStrings() + if len(adminGroup) == 0 { + adminGroup = util.AdminPrivilegeGroup + } + adminPrivileges = lo.SliceToMap(adminGroup, func(item string) (string, struct{}) { return item, struct{}{} }) + }) +} + +func GetEnforcer() *casbin.SyncedEnforcer { + initOnce.Do(func() { + e, err := casbin.NewSyncedEnforcer() + if err != nil { + log.Panic("failed to create casbin enforcer", zap.Error(err)) + } + casbinModel := GetPolicyModel(ModelStr) + adapter := NewMetaCacheCasbinAdapter(func() PrivilegeCache { return GetPrivilegeCache() }) + e.InitWithModelAndAdapter(casbinModel, adapter) + e.AddFunction("dbMatch", DBMatchFunc) + e.AddFunction("privilegeGroupContains", PrivilegeGroupContains) + enforcer = e + }) + return enforcer +} + +var roPrivileges, rwPrivileges, adminPrivileges map[string]struct{} + +func DBMatchFunc(args ...interface{}) (interface{}, error) { + name1 := args[0].(string) + name2 := args[1].(string) + + db1, _ := funcutil.SplitObjectName(name1[strings.Index(name1, "-")+1:]) + db2, _ := funcutil.SplitObjectName(name2[strings.Index(name2, "-")+1:]) + + return db1 == db2, nil +} + +func PrivilegeGroupContains(args ...interface{}) (interface{}, error) { + requestPrivilege := args[0].(string) + policyPrivilege := args[1].(string) + requestObj := args[2].(string) + policyObj := args[3].(string) + + switch policyPrivilege { + case commonpb.ObjectPrivilege_PrivilegeAll.String(): + return true, nil + case commonpb.ObjectPrivilege_PrivilegeGroupReadOnly.String(): + // read only belong to collection object + if !collMatch(requestObj, policyObj) { + return false, nil + } + _, ok := roPrivileges[requestPrivilege] + return ok, nil + case commonpb.ObjectPrivilege_PrivilegeGroupReadWrite.String(): + // read write belong to collection object + if !collMatch(requestObj, policyObj) { + return false, nil + } + _, ok := rwPrivileges[requestPrivilege] + return ok, nil + case commonpb.ObjectPrivilege_PrivilegeGroupAdmin.String(): + // admin belong to global object + _, ok := adminPrivileges[requestPrivilege] + return ok, nil + default: + return false, nil + } +} + +func collMatch(requestObj, policyObj string) bool { + _, coll1 := funcutil.SplitObjectName(requestObj[strings.Index(requestObj, "-")+1:]) + _, coll2 := funcutil.SplitObjectName(policyObj[strings.Index(policyObj, "-")+1:]) + + return coll1 == util.AnyWord || coll2 == util.AnyWord || coll1 == coll2 +} diff --git a/internal/proxy/privilege_cache.go b/internal/proxy/privilege/result_cache.go similarity index 83% rename from internal/proxy/privilege_cache.go rename to internal/proxy/privilege/result_cache.go index a67c775eee..d50b495feb 100644 --- a/internal/proxy/privilege_cache.go +++ b/internal/proxy/privilege/result_cache.go @@ -14,7 +14,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package proxy +package privilege import ( "fmt" @@ -28,11 +28,11 @@ import ( var ( priCacheInitOnce sync.Once priCacheMut sync.RWMutex - priCache *PrivilegeCache + priCache *resultCache ver atomic.Int64 ) -func getPriCache() *PrivilegeCache { +func getPriCache() *resultCache { priCacheMut.RLock() c := priCache priCacheMut.RUnlock() @@ -41,7 +41,7 @@ func getPriCache() *PrivilegeCache { priCacheInitOnce.Do(func() { priCacheMut.Lock() defer priCacheMut.Unlock() - priCache = &PrivilegeCache{ + priCache = &resultCache{ version: ver.Inc(), values: typeutil.ConcurrentMap[string, bool]{}, } @@ -57,20 +57,20 @@ func getPriCache() *PrivilegeCache { func CleanPrivilegeCache() { priCacheMut.Lock() defer priCacheMut.Unlock() - priCache = &PrivilegeCache{ + priCache = &resultCache{ version: ver.Inc(), values: typeutil.ConcurrentMap[string, bool]{}, } } -func GetPrivilegeCache(roleName, object, objectPrivilege string) (isPermit, cached bool, version int64) { +func GetResultCache(roleName, object, objectPrivilege string) (isPermit, cached bool, version int64) { key := fmt.Sprintf("%s_%s_%s", roleName, object, objectPrivilege) c := getPriCache() isPermit, cached = c.values.Get(key) return isPermit, cached, c.version } -func SetPrivilegeCache(roleName, object, objectPrivilege string, isPermit bool, version int64) { +func SetResultCache(roleName, object, objectPrivilege string, isPermit bool, version int64) { key := fmt.Sprintf("%s_%s_%s", roleName, object, objectPrivilege) c := getPriCache() if c.version == version { @@ -80,7 +80,7 @@ func SetPrivilegeCache(roleName, object, objectPrivilege string, isPermit bool, // PrivilegeCache is a cache for privilege enforce result // version provides version control when any policy updates -type PrivilegeCache struct { +type resultCache struct { values typeutil.ConcurrentMap[string, bool] version int64 } diff --git a/internal/proxy/privilege_cache_test.go b/internal/proxy/privilege/result_cache_test.go similarity index 80% rename from internal/proxy/privilege_cache_test.go rename to internal/proxy/privilege/result_cache_test.go index cf80b8c3e8..52baab166d 100644 --- a/internal/proxy/privilege_cache_test.go +++ b/internal/proxy/privilege/result_cache_test.go @@ -14,7 +14,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package proxy +package privilege import ( "testing" @@ -32,9 +32,9 @@ func (s *PrivilegeCacheSuite) TearDownTest() { func (s *PrivilegeCacheSuite) TestGetPrivilege() { // get current version - _, _, version := GetPrivilegeCache("", "", "") - SetPrivilegeCache("test-role", "test-object", "read", true, version) - SetPrivilegeCache("test-role", "test-object", "delete", false, version) + _, _, version := GetResultCache("", "", "") + SetResultCache("test-role", "test-object", "read", true, version) + SetResultCache("test-role", "test-object", "delete", false, version) type testCase struct { tag string @@ -51,7 +51,7 @@ func (s *PrivilegeCacheSuite) TestGetPrivilege() { for _, tc := range testCases { s.Run(tc.tag, func() { - isPermit, exists, _ := GetPrivilegeCache(tc.input[0], tc.input[1], tc.input[2]) + isPermit, exists, _ := GetResultCache(tc.input[0], tc.input[1], tc.input[2]) s.Equal(tc.expectIsPermit, isPermit) s.Equal(tc.expectExists, exists) }) @@ -60,12 +60,12 @@ func (s *PrivilegeCacheSuite) TestGetPrivilege() { func (s *PrivilegeCacheSuite) TestSetPrivilegeVersion() { // get current version - _, _, version := GetPrivilegeCache("", "", "") + _, _, version := GetResultCache("", "", "") CleanPrivilegeCache() - SetPrivilegeCache("test-role", "test-object", "read", true, version) + SetResultCache("test-role", "test-object", "read", true, version) - isPermit, exists, nextVersion := GetPrivilegeCache("test-role", "test-object", "read") + isPermit, exists, nextVersion := GetResultCache("test-role", "test-object", "read") s.False(isPermit) s.False(exists) s.NotEqual(version, nextVersion) diff --git a/internal/proxy/privilege_interceptor.go b/internal/proxy/privilege_interceptor.go index 7dc27fbb44..e2e6e08be0 100644 --- a/internal/proxy/privilege_interceptor.go +++ b/internal/proxy/privilege_interceptor.go @@ -4,12 +4,9 @@ import ( "context" "fmt" "reflect" - "strings" "sync" "github.com/casbin/casbin/v2" - "github.com/casbin/casbin/v2/model" - "github.com/samber/lo" "go.uber.org/zap" "google.golang.org/grpc" "google.golang.org/grpc/codes" @@ -17,36 +14,15 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus/internal/proxy/privilege" "github.com/milvus-io/milvus/pkg/v2/log" "github.com/milvus-io/milvus/pkg/v2/util" "github.com/milvus-io/milvus/pkg/v2/util/contextutil" "github.com/milvus-io/milvus/pkg/v2/util/funcutil" - "github.com/milvus-io/milvus/pkg/v2/util/paramtable" ) type PrivilegeFunc func(ctx context.Context, req interface{}) (context.Context, error) -const ( - // sub -> role name, like admin, public - // obj -> contact object with object name, like Global-*, Collection-col1 - // act -> privilege, like CreateCollection, DescribeCollection - ModelStr = ` -[request_definition] -r = sub, obj, act - -[policy_definition] -p = sub, obj, act - -[policy_effect] -e = some(where (p.eft == allow)) - -[matchers] -m = r.sub == p.sub && globMatch(r.obj, p.obj) && globMatch(r.act, p.act) || r.sub == "admin" || (r.sub == p.sub && dbMatch(r.obj, p.obj) && privilegeGroupContains(r.act, p.act, r.obj, p.obj)) -` -) - -var templateModel = getPolicyModel(ModelStr) - var ( enforcer *casbin.SyncedEnforcer initOnce sync.Once @@ -55,55 +31,9 @@ var ( var roPrivileges, rwPrivileges, adminPrivileges map[string]struct{} -func initPrivilegeGroups() { - initPrivilegeGroupsOnce.Do(func() { - roGroup := paramtable.Get().CommonCfg.ReadOnlyPrivileges.GetAsStrings() - if len(roGroup) == 0 { - roGroup = util.ReadOnlyPrivilegeGroup - } - roPrivileges = lo.SliceToMap(roGroup, func(item string) (string, struct{}) { return item, struct{}{} }) - - rwGroup := paramtable.Get().CommonCfg.ReadWritePrivileges.GetAsStrings() - if len(rwGroup) == 0 { - rwGroup = util.ReadWritePrivilegeGroup - } - rwPrivileges = lo.SliceToMap(rwGroup, func(item string) (string, struct{}) { return item, struct{}{} }) - - adminGroup := paramtable.Get().CommonCfg.AdminPrivileges.GetAsStrings() - if len(adminGroup) == 0 { - adminGroup = util.AdminPrivilegeGroup - } - adminPrivileges = lo.SliceToMap(adminGroup, func(item string) (string, struct{}) { return item, struct{}{} }) - }) -} - -func getEnforcer() *casbin.SyncedEnforcer { - initOnce.Do(func() { - e, err := casbin.NewSyncedEnforcer() - if err != nil { - log.Panic("failed to create casbin enforcer", zap.Error(err)) - } - casbinModel := getPolicyModel(ModelStr) - adapter := NewMetaCacheCasbinAdapter(func() Cache { return globalMetaCache }) - e.InitWithModelAndAdapter(casbinModel, adapter) - e.AddFunction("dbMatch", DBMatchFunc) - e.AddFunction("privilegeGroupContains", PrivilegeGroupContains) - enforcer = e - }) - return enforcer -} - -func getPolicyModel(modelString string) model.Model { - m, err := model.NewModelFromString(modelString) - if err != nil { - log.Panic("NewModelFromString fail", zap.String("model", ModelStr), zap.Error(err)) - } - return m -} - // UnaryServerInterceptor returns a new unary server interceptors that performs per-request privilege access. func UnaryServerInterceptor(privilegeFunc PrivilegeFunc) grpc.UnaryServerInterceptor { - initPrivilegeGroups() + privilege.InitPrivilegeGroups() return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { newCtx, err := privilegeFunc(ctx, req) if err != nil { @@ -160,11 +90,11 @@ func PrivilegeInterceptor(ctx context.Context, req interface{}) (context.Context zap.Int32("object_index", objectNameIndex), zap.String("object_name", objectName), zap.Int32("object_indexs", objectNameIndexs), zap.Strings("object_names", objectNames)) - e := getEnforcer() + e := privilege.GetEnforcer() for _, roleName := range roleNames { permitFunc := func(objectName string) (bool, error) { object := funcutil.PolicyForResource(dbName, objectType, objectName) - isPermit, cached, version := GetPrivilegeCache(roleName, object, objectPrivilege) + isPermit, cached, version := privilege.GetResultCache(roleName, object, objectPrivilege) if cached { return isPermit, nil } @@ -172,7 +102,7 @@ func PrivilegeInterceptor(ctx context.Context, req interface{}) (context.Context if err != nil { return false, err } - SetPrivilegeCache(roleName, object, objectPrivilege, isPermit, version) + privilege.SetResultCache(roleName, object, objectPrivilege, isPermit, version) return isPermit, nil } @@ -237,52 +167,3 @@ func isSelectMyRoleGrants(req interface{}, roleNames []string) bool { roleName := filterGrantEntity.GetRole().GetName() return funcutil.SliceContain(roleNames, roleName) } - -func DBMatchFunc(args ...interface{}) (interface{}, error) { - name1 := args[0].(string) - name2 := args[1].(string) - - db1, _ := funcutil.SplitObjectName(name1[strings.Index(name1, "-")+1:]) - db2, _ := funcutil.SplitObjectName(name2[strings.Index(name2, "-")+1:]) - - return db1 == db2, nil -} - -func collMatch(requestObj, policyObj string) bool { - _, coll1 := funcutil.SplitObjectName(requestObj[strings.Index(requestObj, "-")+1:]) - _, coll2 := funcutil.SplitObjectName(policyObj[strings.Index(policyObj, "-")+1:]) - - return coll1 == util.AnyWord || coll2 == util.AnyWord || coll1 == coll2 -} - -func PrivilegeGroupContains(args ...interface{}) (interface{}, error) { - requestPrivilege := args[0].(string) - policyPrivilege := args[1].(string) - requestObj := args[2].(string) - policyObj := args[3].(string) - - switch policyPrivilege { - case commonpb.ObjectPrivilege_PrivilegeAll.String(): - return true, nil - case commonpb.ObjectPrivilege_PrivilegeGroupReadOnly.String(): - // read only belong to collection object - if !collMatch(requestObj, policyObj) { - return false, nil - } - _, ok := roPrivileges[requestPrivilege] - return ok, nil - case commonpb.ObjectPrivilege_PrivilegeGroupReadWrite.String(): - // read write belong to collection object - if !collMatch(requestObj, policyObj) { - return false, nil - } - _, ok := rwPrivileges[requestPrivilege] - return ok, nil - case commonpb.ObjectPrivilege_PrivilegeGroupAdmin.String(): - // admin belong to global object - _, ok := adminPrivileges[requestPrivilege] - return ok, nil - default: - return false, nil - } -} diff --git a/internal/proxy/privilege_interceptor_test.go b/internal/proxy/privilege_interceptor_test.go index 7a272bf82f..88895eee33 100644 --- a/internal/proxy/privilege_interceptor_test.go +++ b/internal/proxy/privilege_interceptor_test.go @@ -10,6 +10,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus/internal/proxy/privilege" "github.com/milvus-io/milvus/pkg/v2/proto/internalpb" "github.com/milvus-io/milvus/pkg/v2/util" "github.com/milvus-io/milvus/pkg/v2/util/funcutil" @@ -167,7 +168,7 @@ func TestPrivilegeInterceptor(t *testing.T) { g.Wait() assert.Panics(t, func() { - getPolicyModel("foo") + privilege.GetPolicyModel("foo") }) }) } @@ -268,7 +269,7 @@ func TestPrivilegeGroup(t *testing.T) { t.Run("grant ReadOnly to single collection", func(t *testing.T) { paramtable.Get().Save(Params.CommonCfg.AuthorizationEnabled.Key, "true") - initPrivilegeGroups() + privilege.InitPrivilegeGroups() var err error ctx = GetContext(context.Background(), "fooo:123456") @@ -287,7 +288,7 @@ func TestPrivilegeGroup(t *testing.T) { }, nil } InitMetaCache(ctx, client, mgr) - defer CleanPrivilegeCache() + defer privilege.CleanPrivilegeCache() _, err = PrivilegeInterceptor(GetContext(context.Background(), "fooo:123456"), &milvuspb.QueryRequest{ CollectionName: "coll1", @@ -325,7 +326,7 @@ func TestPrivilegeGroup(t *testing.T) { t.Run("grant ReadOnly to all collection", func(t *testing.T) { paramtable.Get().Save(Params.CommonCfg.AuthorizationEnabled.Key, "true") - initPrivilegeGroups() + privilege.InitPrivilegeGroups() var err error ctx = GetContext(context.Background(), "fooo:123456") @@ -344,7 +345,7 @@ func TestPrivilegeGroup(t *testing.T) { }, nil } InitMetaCache(ctx, client, mgr) - defer CleanPrivilegeCache() + defer privilege.CleanPrivilegeCache() _, err = PrivilegeInterceptor(GetContext(context.Background(), "fooo:123456"), &milvuspb.QueryRequest{ CollectionName: "coll1", @@ -382,7 +383,7 @@ func TestPrivilegeGroup(t *testing.T) { t.Run("grant ReadWrite to single collection", func(t *testing.T) { paramtable.Get().Save(Params.CommonCfg.AuthorizationEnabled.Key, "true") - initPrivilegeGroups() + privilege.InitPrivilegeGroups() var err error ctx = GetContext(context.Background(), "fooo:123456") @@ -401,7 +402,7 @@ func TestPrivilegeGroup(t *testing.T) { }, nil } InitMetaCache(ctx, client, mgr) - defer CleanPrivilegeCache() + defer privilege.CleanPrivilegeCache() _, err = PrivilegeInterceptor(GetContext(context.Background(), "fooo:123456"), &milvuspb.QueryRequest{ CollectionName: "coll1", @@ -485,7 +486,7 @@ func TestPrivilegeGroup(t *testing.T) { t.Run("grant ReadWrite to all collection", func(t *testing.T) { paramtable.Get().Save(Params.CommonCfg.AuthorizationEnabled.Key, "true") - initPrivilegeGroups() + privilege.InitPrivilegeGroups() var err error ctx = GetContext(context.Background(), "fooo:123456") @@ -504,7 +505,7 @@ func TestPrivilegeGroup(t *testing.T) { }, nil } InitMetaCache(ctx, client, mgr) - defer CleanPrivilegeCache() + defer privilege.CleanPrivilegeCache() _, err = PrivilegeInterceptor(GetContext(context.Background(), "fooo:123456"), &milvuspb.QueryRequest{ CollectionName: "coll1", @@ -552,7 +553,7 @@ func TestPrivilegeGroup(t *testing.T) { t.Run("Admin", func(t *testing.T) { paramtable.Get().Save(Params.CommonCfg.AuthorizationEnabled.Key, "true") - initPrivilegeGroups() + privilege.InitPrivilegeGroups() var err error ctx = GetContext(context.Background(), "fooo:123456") @@ -571,7 +572,7 @@ func TestPrivilegeGroup(t *testing.T) { }, nil } InitMetaCache(ctx, client, mgr) - defer CleanPrivilegeCache() + defer privilege.CleanPrivilegeCache() _, err = PrivilegeInterceptor(GetContext(context.Background(), "fooo:123456"), &milvuspb.QueryRequest{}) assert.NoError(t, err) @@ -593,7 +594,7 @@ func TestPrivilegeGroup(t *testing.T) { func TestBuiltinPrivilegeGroup(t *testing.T) { t.Run("ClusterAdmin", func(t *testing.T) { paramtable.Get().Save(Params.CommonCfg.AuthorizationEnabled.Key, "true") - initPrivilegeGroups() + privilege.InitPrivilegeGroups() var err error ctx := GetContext(context.Background(), "fooo:123456") @@ -615,7 +616,7 @@ func TestBuiltinPrivilegeGroup(t *testing.T) { }, nil } InitMetaCache(ctx, client, mgr) - defer CleanPrivilegeCache() + defer privilege.CleanPrivilegeCache() _, err = PrivilegeInterceptor(GetContext(context.Background(), "fooo:123456"), &milvuspb.SelectUserRequest{}) assert.NoError(t, err) diff --git a/internal/proxy/proxy_test.go b/internal/proxy/proxy_test.go index 64461dd922..2232dcdbc7 100644 --- a/internal/proxy/proxy_test.go +++ b/internal/proxy/proxy_test.go @@ -51,6 +51,7 @@ import ( grpcstreamingnode "github.com/milvus-io/milvus/internal/distributed/streamingnode" "github.com/milvus-io/milvus/internal/json" "github.com/milvus-io/milvus/internal/mocks" + "github.com/milvus-io/milvus/internal/proxy/privilege" "github.com/milvus-io/milvus/internal/util/componentutil" "github.com/milvus-io/milvus/internal/util/dependency" "github.com/milvus-io/milvus/internal/util/sessionutil" @@ -2837,7 +2838,7 @@ func TestProxy(t *testing.T) { getResp, err := rootCoordClient.GetCredential(ctx, getCredentialReq) assert.NoError(t, err) assert.Equal(t, commonpb.ErrorCode_Success, getResp.GetStatus().GetErrorCode()) - assert.True(t, passwordVerify(ctx, username, newPassword, globalMetaCache)) + assert.True(t, passwordVerify(ctx, username, newPassword, privilege.GetPrivilegeCache())) getCredentialReq.Username = "(" getResp, err = rootCoordClient.GetCredential(ctx, getCredentialReq) diff --git a/internal/proxy/util.go b/internal/proxy/util.go index 26a0ce3968..3916924341 100644 --- a/internal/proxy/util.go +++ b/internal/proxy/util.go @@ -38,6 +38,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/json" "github.com/milvus-io/milvus/internal/parser/planparserv2" + "github.com/milvus-io/milvus/internal/proxy/privilege" "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/util/analyzer" "github.com/milvus-io/milvus/internal/util/function/embedding" @@ -1464,14 +1465,15 @@ func AppendUserInfoForRPC(ctx context.Context) context.Context { } func GetRole(username string) ([]string, error) { - if globalMetaCache == nil { + privCache := privilege.GetPrivilegeCache() + if privCache == nil { return []string{}, merr.WrapErrServiceUnavailable("internal: Milvus Proxy is not ready yet. please wait") } - return globalMetaCache.GetUserRole(username), nil + return privCache.GetUserRole(username), nil } func PasswordVerify(ctx context.Context, username, rawPwd string) bool { - return passwordVerify(ctx, username, rawPwd, globalMetaCache) + return passwordVerify(ctx, username, rawPwd, privilege.GetPrivilegeCache()) } func VerifyAPIKey(rawToken string) (string, error) { @@ -1485,10 +1487,10 @@ func VerifyAPIKey(rawToken string) (string, error) { } // PasswordVerify verify password -func passwordVerify(ctx context.Context, username, rawPwd string, globalMetaCache Cache) bool { +func passwordVerify(ctx context.Context, username, rawPwd string, privilegeCache privilege.PrivilegeCache) bool { // it represents the cache miss if Sha256Password is empty within credInfo, which shall be updated first connection. // meanwhile, generating Sha256Password depends on raw password and encrypted password will not cache. - credInfo, err := globalMetaCache.GetCredentialInfo(ctx, username) + credInfo, err := privilege.GetPrivilegeCache().GetCredentialInfo(ctx, username) if err != nil { log.Ctx(ctx).Error("found no credential", zap.String("username", username), zap.Error(err)) return false @@ -1509,7 +1511,7 @@ func passwordVerify(ctx context.Context, username, rawPwd string, globalMetaCach // update cache after miss cache credInfo.Sha256Password = sha256Pwd log.Ctx(ctx).Debug("get credential miss cache, update cache with", zap.Any("credential", credInfo)) - globalMetaCache.UpdateCredential(credInfo) + privilegeCache.UpdateCredential(credInfo) return true } diff --git a/internal/proxy/util_test.go b/internal/proxy/util_test.go index fc01696abb..28ab2e2c4a 100644 --- a/internal/proxy/util_test.go +++ b/internal/proxy/util_test.go @@ -37,6 +37,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/json" "github.com/milvus-io/milvus/internal/mocks" + "github.com/milvus-io/milvus/internal/proxy/privilege" "github.com/milvus-io/milvus/internal/util/function/embedding" "github.com/milvus-io/milvus/pkg/v2/common" "github.com/milvus-io/milvus/pkg/v2/log" @@ -789,19 +790,21 @@ func TestGetCurDBNameFromContext(t *testing.T) { } func TestGetRole(t *testing.T) { - globalMetaCache = nil + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + privilege.ResetPrivilegeCacheForTest() _, err := GetRole("foo") assert.Error(t, err) - mockCache := NewMockCache(t) - mockCache.On("GetUserRole", - mock.AnythingOfType("string"), - ).Return(func(username string) []string { - if username == "root" { - return []string{"role1", "admin", "role2"} - } - return []string{"role1"} - }) - globalMetaCache = mockCache + + mixcoord := mocks.NewMockMixCoordClient(t) + mixcoord.EXPECT().ListPolicy(mock.Anything, mock.Anything).Return(&internalpb.ListPolicyResponse{ + Status: merr.Success(), + UserRoles: []string{"root/role1", "root/admin", "root/role2", "foo/role1"}, + }, nil).Times(1) + + privilege.InitPrivilegeCache(ctx, mixcoord) + roles, err := GetRole("root") assert.NoError(t, err) assert.Equal(t, 3, len(roles)) @@ -812,11 +815,12 @@ func TestGetRole(t *testing.T) { } func TestPasswordVerify(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + username := "user-test00" password := "PasswordVerify" - // credential does not exist within cache - credCache := make(map[string]*internalpb.CredentialInfo, 0) invokedCount := 0 mockedRootCoord := NewMixCoordMock() @@ -825,34 +829,36 @@ func TestPasswordVerify(t *testing.T) { return nil, errors.New("get cred not found credential") } - metaCache := &MetaCache{ - credMap: credCache, - mixCoord: mockedRootCoord, - } - ret, ok := credCache[username] - assert.False(t, ok) - assert.Nil(t, ret) - assert.False(t, passwordVerify(context.TODO(), username, password, metaCache)) + privilege.InitPrivilegeCache(ctx, mockedRootCoord) + privilegeCache := privilege.GetPrivilegeCache() + assert.False(t, passwordVerify(ctx, username, password, privilegeCache)) assert.Equal(t, 1, invokedCount) // Sha256Password has not been filled into cache during establish connection firstly encryptedPwd, err := crypto.PasswordEncrypt(password) assert.NoError(t, err) - credCache[username] = &internalpb.CredentialInfo{ - Username: username, - EncryptedPassword: encryptedPwd, + privilegeCache.RemoveCredential(username) + mockedRootCoord.GetGetCredentialFunc = func(ctx context.Context, req *rootcoordpb.GetCredentialRequest, opts ...grpc.CallOption) (*rootcoordpb.GetCredentialResponse, error) { + invokedCount++ + return &rootcoordpb.GetCredentialResponse{ + Status: merr.Success(), + Username: username, + Password: encryptedPwd, + }, nil } - assert.True(t, passwordVerify(context.TODO(), username, password, metaCache)) - ret, ok = credCache[username] - assert.True(t, ok) + + assert.True(t, passwordVerify(ctx, username, password, privilegeCache)) + + ret, err := privilegeCache.GetCredentialInfo(ctx, username) + assert.NoError(t, err) assert.NotNil(t, ret) assert.Equal(t, username, ret.Username) assert.NotNil(t, ret.Sha256Password) - assert.Equal(t, 1, invokedCount) + assert.Equal(t, 2, invokedCount) // Sha256Password already exists within cache - assert.True(t, passwordVerify(context.TODO(), username, password, metaCache)) - assert.Equal(t, 1, invokedCount) + assert.True(t, passwordVerify(ctx, username, password, privilegeCache)) + assert.Equal(t, 2, invokedCount) } func Test_isCollectionIsLoaded(t *testing.T) {