mirror of
https://gitee.com/milvus-io/milvus.git
synced 2025-12-06 17:18:35 +08:00
enhance: Refactor privilege management by extracting privilege cache into separate package (#44762)
Related to #44761 This commit refactors the privilege management system in the proxy component by: 1. **Separation of Concerns**: Extracts privilege-related functionality from MetaCache into a dedicated `internal/proxy/privilege` package, improving code organization and maintainability. 2. **New Package Structure**: Creates `internal/proxy/privilege/` with: - `cache.go`: Core privilege cache implementation (PrivilegeCache) - `result_cache.go`: Privilege enforcement result caching - `model.go`: Casbin model and policy enforcement functions - `meta_cache_adapter.go`: Casbin adapter for MetaCache integration - Corresponding test files and mock implementations 3. **MetaCache Simplification**: Removes privilege and credential management methods from MetaCache interface and implementation: - Removed: GetCredentialInfo, RemoveCredential, UpdateCredential - Removed: GetPrivilegeInfo, GetUserRole, RefreshPolicyInfo, InitPolicyInfo - Deleted: meta_cache_adapter.go, privilege_cache.go and their tests 4. **Updated References**: Updates all callsites to use the new privilegeCache global: - Authentication interceptor now uses privilegeCache for password verification - Credential cache operations (InvalidateCredentialCache, UpdateCredentialCache, UpdateCredential) now use privilegeCache - Policy refresh operations (RefreshPolicyInfoCache) now use privilegeCache - Privilege interceptor uses new privilege.GetEnforcer() and privilege result cache 5. **Improved API**: Renames cache functions for clarity: - GetPrivilegeCache → GetResultCache - SetPrivilegeCache → SetResultCache - CleanPrivilegeCache → CleanResultCache This refactoring makes the codebase more modular, separates privilege management concerns from general metadata caching, and provides a clearer API for privilege enforcement operations. --------- Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
This commit is contained in:
parent
369c6eb206
commit
f5f053f1d2
@ -13,6 +13,7 @@ import (
|
|||||||
"google.golang.org/grpc/status"
|
"google.golang.org/grpc/status"
|
||||||
|
|
||||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||||
|
"github.com/milvus-io/milvus/internal/proxy/privilege"
|
||||||
"github.com/milvus-io/milvus/internal/util/hookutil"
|
"github.com/milvus-io/milvus/internal/util/hookutil"
|
||||||
"github.com/milvus-io/milvus/pkg/v2/log"
|
"github.com/milvus-io/milvus/pkg/v2/log"
|
||||||
"github.com/milvus-io/milvus/pkg/v2/metrics"
|
"github.com/milvus-io/milvus/pkg/v2/metrics"
|
||||||
@ -111,7 +112,7 @@ func AuthenticationInterceptor(ctx context.Context) (context.Context, error) {
|
|||||||
} else {
|
} else {
|
||||||
// username+password authentication
|
// username+password authentication
|
||||||
username, password := parseMD(rawToken)
|
username, password := parseMD(rawToken)
|
||||||
if !passwordVerify(ctx, username, password, globalMetaCache) {
|
if !passwordVerify(ctx, username, password, privilege.GetPrivilegeCache()) {
|
||||||
log.Warn("fail to verify password", zap.String("username", username))
|
log.Warn("fail to verify password", zap.String("username", username))
|
||||||
// NOTE: don't use the merr, because it will cause the wrong retry behavior in the sdk
|
// NOTE: don't use the merr, because it will cause the wrong retry behavior in the sdk
|
||||||
return nil, status.Error(codes.Unauthenticated, "auth check failure, please check username and password are correct")
|
return nil, status.Error(codes.Unauthenticated, "auth check failure, please check username and password are correct")
|
||||||
|
|||||||
@ -9,6 +9,7 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"google.golang.org/grpc/metadata"
|
"google.golang.org/grpc/metadata"
|
||||||
|
|
||||||
|
"github.com/milvus-io/milvus/internal/proxy/privilege"
|
||||||
"github.com/milvus-io/milvus/internal/util/hookutil"
|
"github.com/milvus-io/milvus/internal/util/hookutil"
|
||||||
"github.com/milvus-io/milvus/pkg/v2/util"
|
"github.com/milvus-io/milvus/pkg/v2/util"
|
||||||
"github.com/milvus-io/milvus/pkg/v2/util/crypto"
|
"github.com/milvus-io/milvus/pkg/v2/util/crypto"
|
||||||
@ -27,7 +28,7 @@ func TestValidAuth(t *testing.T) {
|
|||||||
if username == "" || password == "" {
|
if username == "" || password == "" {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
return passwordVerify(ctx, username, password, globalMetaCache)
|
return passwordVerify(ctx, username, password, privilege.GetPrivilegeCache())
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|||||||
@ -43,6 +43,7 @@ import (
|
|||||||
"github.com/milvus-io/milvus/internal/distributed/streaming"
|
"github.com/milvus-io/milvus/internal/distributed/streaming"
|
||||||
"github.com/milvus-io/milvus/internal/http"
|
"github.com/milvus-io/milvus/internal/http"
|
||||||
"github.com/milvus-io/milvus/internal/proxy/connection"
|
"github.com/milvus-io/milvus/internal/proxy/connection"
|
||||||
|
"github.com/milvus-io/milvus/internal/proxy/privilege"
|
||||||
"github.com/milvus-io/milvus/internal/proxy/replicate"
|
"github.com/milvus-io/milvus/internal/proxy/replicate"
|
||||||
"github.com/milvus-io/milvus/internal/types"
|
"github.com/milvus-io/milvus/internal/types"
|
||||||
"github.com/milvus-io/milvus/internal/util/analyzer"
|
"github.com/milvus-io/milvus/internal/util/analyzer"
|
||||||
@ -4689,8 +4690,9 @@ func (node *Proxy) InvalidateCredentialCache(ctx context.Context, request *proxy
|
|||||||
}
|
}
|
||||||
|
|
||||||
username := request.Username
|
username := request.Username
|
||||||
if globalMetaCache != nil {
|
priCache := privilege.GetPrivilegeCache()
|
||||||
globalMetaCache.RemoveCredential(username) // no need to return error, though credential may be not cached
|
if priCache != nil {
|
||||||
|
priCache.RemoveCredential(username) // no need to return error, though credential may be not cached
|
||||||
}
|
}
|
||||||
log.Debug("complete to invalidate credential cache")
|
log.Debug("complete to invalidate credential cache")
|
||||||
|
|
||||||
@ -4715,8 +4717,9 @@ func (node *Proxy) UpdateCredentialCache(ctx context.Context, request *proxypb.U
|
|||||||
Username: request.Username,
|
Username: request.Username,
|
||||||
Sha256Password: request.Password,
|
Sha256Password: request.Password,
|
||||||
}
|
}
|
||||||
if globalMetaCache != nil {
|
priCache := privilege.GetPrivilegeCache()
|
||||||
globalMetaCache.UpdateCredential(credInfo) // no need to return error, though credential may be not cached
|
if priCache != nil {
|
||||||
|
priCache.UpdateCredential(credInfo) // no need to return error, though credential may be not cached
|
||||||
}
|
}
|
||||||
log.Debug("complete to update credential cache")
|
log.Debug("complete to update credential cache")
|
||||||
|
|
||||||
@ -4820,7 +4823,7 @@ func (node *Proxy) UpdateCredential(ctx context.Context, req *milvuspb.UpdateCre
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if !skipPasswordVerify && !passwordVerify(ctx, req.Username, rawOldPassword, globalMetaCache) {
|
if !skipPasswordVerify && !passwordVerify(ctx, req.Username, rawOldPassword, privilege.GetPrivilegeCache()) {
|
||||||
err := merr.WrapErrPrivilegeNotAuthenticated("old password not correct for %s", req.GetUsername())
|
err := merr.WrapErrPrivilegeNotAuthenticated("old password not correct for %s", req.GetUsername())
|
||||||
return merr.Status(err), nil
|
return merr.Status(err), nil
|
||||||
}
|
}
|
||||||
@ -5347,8 +5350,9 @@ func (node *Proxy) RefreshPolicyInfoCache(ctx context.Context, req *proxypb.Refr
|
|||||||
return merr.Status(err), nil
|
return merr.Status(err), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if globalMetaCache != nil {
|
priCache := privilege.GetPrivilegeCache()
|
||||||
err := globalMetaCache.RefreshPolicyInfo(typeutil.CacheOp{
|
if priCache != nil {
|
||||||
|
err := priCache.RefreshPolicyInfo(typeutil.CacheOp{
|
||||||
OpType: typeutil.CacheOpType(req.OpType),
|
OpType: typeutil.CacheOpType(req.OpType),
|
||||||
OpKey: req.OpKey,
|
OpKey: req.OpKey,
|
||||||
})
|
})
|
||||||
|
|||||||
@ -24,7 +24,6 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/cockroachdb/errors"
|
|
||||||
"github.com/samber/lo"
|
"github.com/samber/lo"
|
||||||
"go.uber.org/atomic"
|
"go.uber.org/atomic"
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
@ -32,6 +31,7 @@ import (
|
|||||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||||
|
"github.com/milvus-io/milvus/internal/proxy/privilege"
|
||||||
"github.com/milvus-io/milvus/internal/types"
|
"github.com/milvus-io/milvus/internal/types"
|
||||||
"github.com/milvus-io/milvus/pkg/v2/common"
|
"github.com/milvus-io/milvus/pkg/v2/common"
|
||||||
"github.com/milvus-io/milvus/pkg/v2/log"
|
"github.com/milvus-io/milvus/pkg/v2/log"
|
||||||
@ -39,11 +39,9 @@ import (
|
|||||||
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
|
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
|
||||||
"github.com/milvus-io/milvus/pkg/v2/proto/querypb"
|
"github.com/milvus-io/milvus/pkg/v2/proto/querypb"
|
||||||
"github.com/milvus-io/milvus/pkg/v2/proto/rootcoordpb"
|
"github.com/milvus-io/milvus/pkg/v2/proto/rootcoordpb"
|
||||||
"github.com/milvus-io/milvus/pkg/v2/util"
|
|
||||||
"github.com/milvus-io/milvus/pkg/v2/util/commonpbutil"
|
"github.com/milvus-io/milvus/pkg/v2/util/commonpbutil"
|
||||||
"github.com/milvus-io/milvus/pkg/v2/util/conc"
|
"github.com/milvus-io/milvus/pkg/v2/util/conc"
|
||||||
"github.com/milvus-io/milvus/pkg/v2/util/expr"
|
"github.com/milvus-io/milvus/pkg/v2/util/expr"
|
||||||
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
|
|
||||||
"github.com/milvus-io/milvus/pkg/v2/util/merr"
|
"github.com/milvus-io/milvus/pkg/v2/util/merr"
|
||||||
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
|
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
|
||||||
"github.com/milvus-io/milvus/pkg/v2/util/timerecord"
|
"github.com/milvus-io/milvus/pkg/v2/util/timerecord"
|
||||||
@ -79,14 +77,14 @@ type Cache interface {
|
|||||||
RemoveCollectionsByID(ctx context.Context, collectionID UniqueID, version uint64, removeVersion bool) []string
|
RemoveCollectionsByID(ctx context.Context, collectionID UniqueID, version uint64, removeVersion bool) []string
|
||||||
|
|
||||||
// GetCredentialInfo operate credential cache
|
// GetCredentialInfo operate credential cache
|
||||||
GetCredentialInfo(ctx context.Context, username string) (*internalpb.CredentialInfo, error)
|
// GetCredentialInfo(ctx context.Context, username string) (*internalpb.CredentialInfo, error)
|
||||||
RemoveCredential(username string)
|
// RemoveCredential(username string)
|
||||||
UpdateCredential(credInfo *internalpb.CredentialInfo)
|
// UpdateCredential(credInfo *internalpb.CredentialInfo)
|
||||||
|
|
||||||
GetPrivilegeInfo(ctx context.Context) []string
|
// GetPrivilegeInfo(ctx context.Context) []string
|
||||||
GetUserRole(username string) []string
|
// GetUserRole(username string) []string
|
||||||
RefreshPolicyInfo(op typeutil.CacheOp) error
|
// RefreshPolicyInfo(op typeutil.CacheOp) error
|
||||||
InitPolicyInfo(info []string, userRoles []string)
|
// InitPolicyInfo(info []string, userRoles []string)
|
||||||
|
|
||||||
RemoveDatabase(ctx context.Context, database string)
|
RemoveDatabase(ctx context.Context, database string)
|
||||||
HasDatabase(ctx context.Context, database string) bool
|
HasDatabase(ctx context.Context, database string) bool
|
||||||
@ -382,14 +380,12 @@ func InitMetaCache(ctx context.Context, mixCoord types.MixCoordClient, shardMgr
|
|||||||
}
|
}
|
||||||
expr.Register("cache", globalMetaCache)
|
expr.Register("cache", globalMetaCache)
|
||||||
|
|
||||||
// The privilege info is a little more. And to get this info, the query operation of involving multiple table queries is required.
|
err = privilege.InitPrivilegeCache(ctx, mixCoord)
|
||||||
resp, err := mixCoord.ListPolicy(ctx, &internalpb.ListPolicyRequest{})
|
if err != nil {
|
||||||
if err = merr.CheckRPCCall(resp, err); err != nil {
|
log.Error("failed to init privilege cache", zap.Error(err))
|
||||||
log.Error("fail to init meta cache", zap.Error(err))
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
globalMetaCache.InitPolicyInfo(resp.PolicyInfos, resp.UserRoles)
|
|
||||||
log.Info("success to init meta cache", zap.Strings("policy_infos", resp.PolicyInfos))
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -902,55 +898,6 @@ func (m *MetaCache) RemoveCollectionsByID(ctx context.Context, collectionID Uniq
|
|||||||
return collNames
|
return collNames
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetCredentialInfo returns the credential related to provided username
|
|
||||||
// If the cache missed, proxy will try to fetch from storage
|
|
||||||
func (m *MetaCache) GetCredentialInfo(ctx context.Context, username string) (*internalpb.CredentialInfo, error) {
|
|
||||||
m.credMut.RLock()
|
|
||||||
var credInfo *internalpb.CredentialInfo
|
|
||||||
credInfo, ok := m.credMap[username]
|
|
||||||
m.credMut.RUnlock()
|
|
||||||
|
|
||||||
if !ok {
|
|
||||||
req := &rootcoordpb.GetCredentialRequest{
|
|
||||||
Base: commonpbutil.NewMsgBase(
|
|
||||||
commonpbutil.WithMsgType(commonpb.MsgType_GetCredential),
|
|
||||||
),
|
|
||||||
Username: username,
|
|
||||||
}
|
|
||||||
resp, err := m.mixCoord.GetCredential(ctx, req)
|
|
||||||
if err != nil {
|
|
||||||
return &internalpb.CredentialInfo{}, err
|
|
||||||
}
|
|
||||||
credInfo = &internalpb.CredentialInfo{
|
|
||||||
Username: resp.Username,
|
|
||||||
EncryptedPassword: resp.Password,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return credInfo, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MetaCache) RemoveCredential(username string) {
|
|
||||||
m.credMut.Lock()
|
|
||||||
defer m.credMut.Unlock()
|
|
||||||
// delete pair in credMap
|
|
||||||
delete(m.credMap, username)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MetaCache) UpdateCredential(credInfo *internalpb.CredentialInfo) {
|
|
||||||
m.credMut.Lock()
|
|
||||||
defer m.credMut.Unlock()
|
|
||||||
username := credInfo.Username
|
|
||||||
_, ok := m.credMap[username]
|
|
||||||
if !ok {
|
|
||||||
m.credMap[username] = &internalpb.CredentialInfo{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Do not cache encrypted password content
|
|
||||||
m.credMap[username].Username = username
|
|
||||||
m.credMap[username].Sha256Password = credInfo.Sha256Password
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MetaCache) GetShard(ctx context.Context, withCache bool, database, collectionName string, collectionID int64, channel string) ([]nodeInfo, error) {
|
func (m *MetaCache) GetShard(ctx context.Context, withCache bool, database, collectionName string, collectionID int64, channel string) ([]nodeInfo, error) {
|
||||||
method := "GetShard"
|
method := "GetShard"
|
||||||
// check cache first
|
// check cache first
|
||||||
@ -1127,131 +1074,6 @@ func (m *MetaCache) InvalidateShardLeaderCache(collections []int64) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MetaCache) InitPolicyInfo(info []string, userRoles []string) {
|
|
||||||
defer func() {
|
|
||||||
err := getEnforcer().LoadPolicy()
|
|
||||||
if err != nil {
|
|
||||||
log.Error("failed to load policy after RefreshPolicyInfo", zap.Error(err))
|
|
||||||
}
|
|
||||||
CleanPrivilegeCache()
|
|
||||||
}()
|
|
||||||
m.mu.Lock()
|
|
||||||
defer m.mu.Unlock()
|
|
||||||
m.unsafeInitPolicyInfo(info, userRoles)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MetaCache) unsafeInitPolicyInfo(info []string, userRoles []string) {
|
|
||||||
m.privilegeInfos = util.StringSet(info)
|
|
||||||
for _, userRole := range userRoles {
|
|
||||||
user, role, err := funcutil.DecodeUserRoleCache(userRole)
|
|
||||||
if err != nil {
|
|
||||||
log.Warn("invalid user-role key", zap.String("user-role", userRole), zap.Error(err))
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if m.userToRoles[user] == nil {
|
|
||||||
m.userToRoles[user] = make(map[string]struct{})
|
|
||||||
}
|
|
||||||
m.userToRoles[user][role] = struct{}{}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MetaCache) GetPrivilegeInfo(ctx context.Context) []string {
|
|
||||||
m.mu.RLock()
|
|
||||||
defer m.mu.RUnlock()
|
|
||||||
|
|
||||||
return util.StringList(m.privilegeInfos)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MetaCache) GetUserRole(user string) []string {
|
|
||||||
m.mu.RLock()
|
|
||||||
defer m.mu.RUnlock()
|
|
||||||
|
|
||||||
return util.StringList(m.userToRoles[user])
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MetaCache) RefreshPolicyInfo(op typeutil.CacheOp) (err error) {
|
|
||||||
defer func() {
|
|
||||||
if err == nil {
|
|
||||||
le := getEnforcer().LoadPolicy()
|
|
||||||
if le != nil {
|
|
||||||
log.Error("failed to load policy after RefreshPolicyInfo", zap.Error(le))
|
|
||||||
}
|
|
||||||
CleanPrivilegeCache()
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
if op.OpType != typeutil.CacheRefresh {
|
|
||||||
m.mu.Lock()
|
|
||||||
defer m.mu.Unlock()
|
|
||||||
if op.OpKey == "" {
|
|
||||||
return errors.New("empty op key")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
switch op.OpType {
|
|
||||||
case typeutil.CacheGrantPrivilege:
|
|
||||||
keys := funcutil.PrivilegesForPolicy(op.OpKey)
|
|
||||||
for _, key := range keys {
|
|
||||||
m.privilegeInfos[key] = struct{}{}
|
|
||||||
}
|
|
||||||
case typeutil.CacheRevokePrivilege:
|
|
||||||
keys := funcutil.PrivilegesForPolicy(op.OpKey)
|
|
||||||
for _, key := range keys {
|
|
||||||
delete(m.privilegeInfos, key)
|
|
||||||
}
|
|
||||||
case typeutil.CacheAddUserToRole:
|
|
||||||
user, role, err := funcutil.DecodeUserRoleCache(op.OpKey)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("invalid opKey, fail to decode, op_type: %d, op_key: %s", int(op.OpType), op.OpKey)
|
|
||||||
}
|
|
||||||
if m.userToRoles[user] == nil {
|
|
||||||
m.userToRoles[user] = make(map[string]struct{})
|
|
||||||
}
|
|
||||||
m.userToRoles[user][role] = struct{}{}
|
|
||||||
case typeutil.CacheRemoveUserFromRole:
|
|
||||||
user, role, err := funcutil.DecodeUserRoleCache(op.OpKey)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("invalid opKey, fail to decode, op_type: %d, op_key: %s", int(op.OpType), op.OpKey)
|
|
||||||
}
|
|
||||||
if m.userToRoles[user] != nil {
|
|
||||||
delete(m.userToRoles[user], role)
|
|
||||||
}
|
|
||||||
case typeutil.CacheDeleteUser:
|
|
||||||
delete(m.userToRoles, op.OpKey)
|
|
||||||
case typeutil.CacheDropRole:
|
|
||||||
for user := range m.userToRoles {
|
|
||||||
delete(m.userToRoles[user], op.OpKey)
|
|
||||||
}
|
|
||||||
|
|
||||||
for policy := range m.privilegeInfos {
|
|
||||||
if funcutil.PolicyCheckerWithRole(policy, op.OpKey) {
|
|
||||||
delete(m.privilegeInfos, policy)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
case typeutil.CacheRefresh:
|
|
||||||
resp, err := m.mixCoord.ListPolicy(context.Background(), &internalpb.ListPolicyRequest{})
|
|
||||||
if err != nil {
|
|
||||||
log.Error("fail to init meta cache", zap.Error(err))
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if !merr.Ok(resp.GetStatus()) {
|
|
||||||
log.Error("fail to init meta cache",
|
|
||||||
zap.String("error_code", resp.GetStatus().GetErrorCode().String()),
|
|
||||||
zap.String("reason", resp.GetStatus().GetReason()))
|
|
||||||
return merr.Error(resp.Status)
|
|
||||||
}
|
|
||||||
|
|
||||||
m.mu.Lock()
|
|
||||||
defer m.mu.Unlock()
|
|
||||||
m.userToRoles = make(map[string]map[string]struct{})
|
|
||||||
m.privilegeInfos = make(map[string]struct{})
|
|
||||||
m.unsafeInitPolicyInfo(resp.PolicyInfos, resp.UserRoles)
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("invalid opType, op_type: %d, op_key: %s", int(op.OpType), op.OpKey)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MetaCache) RemoveDatabase(ctx context.Context, database string) {
|
func (m *MetaCache) RemoveDatabase(ctx context.Context, database string) {
|
||||||
log.Ctx(ctx).Debug("remove database", zap.String("name", database))
|
log.Ctx(ctx).Debug("remove database", zap.String("name", database))
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
|
|||||||
@ -35,6 +35,7 @@ import (
|
|||||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||||
"github.com/milvus-io/milvus/internal/mocks"
|
"github.com/milvus-io/milvus/internal/mocks"
|
||||||
|
"github.com/milvus-io/milvus/internal/proxy/privilege"
|
||||||
"github.com/milvus-io/milvus/internal/types"
|
"github.com/milvus-io/milvus/internal/types"
|
||||||
"github.com/milvus-io/milvus/pkg/v2/common"
|
"github.com/milvus-io/milvus/pkg/v2/common"
|
||||||
"github.com/milvus-io/milvus/pkg/v2/proto/datapb"
|
"github.com/milvus-io/milvus/pkg/v2/proto/datapb"
|
||||||
@ -1510,9 +1511,9 @@ func TestMetaCache_PolicyInfo(t *testing.T) {
|
|||||||
}
|
}
|
||||||
err := InitMetaCache(context.Background(), client, mgr)
|
err := InitMetaCache(context.Background(), client, mgr)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
policyInfos := globalMetaCache.GetPrivilegeInfo(context.Background())
|
policyInfos := privilege.GetPrivilegeCache().GetPrivilegeInfo(context.Background())
|
||||||
assert.Equal(t, 3, len(policyInfos))
|
assert.Equal(t, 3, len(policyInfos))
|
||||||
roles := globalMetaCache.GetUserRole("foo")
|
roles := privilege.GetPrivilegeCache().GetUserRole("foo")
|
||||||
assert.Equal(t, 2, len(roles))
|
assert.Equal(t, 2, len(roles))
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -1527,29 +1528,29 @@ func TestMetaCache_PolicyInfo(t *testing.T) {
|
|||||||
err := InitMetaCache(context.Background(), client, mgr)
|
err := InitMetaCache(context.Background(), client, mgr)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
err = globalMetaCache.RefreshPolicyInfo(typeutil.CacheOp{OpType: typeutil.CacheGrantPrivilege, OpKey: "policyX"})
|
err = privilege.GetPrivilegeCache().RefreshPolicyInfo(typeutil.CacheOp{OpType: typeutil.CacheGrantPrivilege, OpKey: "policyX"})
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
policyInfos := globalMetaCache.GetPrivilegeInfo(context.Background())
|
policyInfos := privilege.GetPrivilegeCache().GetPrivilegeInfo(context.Background())
|
||||||
assert.Equal(t, 4, len(policyInfos))
|
assert.Equal(t, 4, len(policyInfos))
|
||||||
|
|
||||||
err = globalMetaCache.RefreshPolicyInfo(typeutil.CacheOp{OpType: typeutil.CacheRevokePrivilege, OpKey: "policyX"})
|
err = privilege.GetPrivilegeCache().RefreshPolicyInfo(typeutil.CacheOp{OpType: typeutil.CacheRevokePrivilege, OpKey: "policyX"})
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
policyInfos = globalMetaCache.GetPrivilegeInfo(context.Background())
|
policyInfos = privilege.GetPrivilegeCache().GetPrivilegeInfo(context.Background())
|
||||||
assert.Equal(t, 3, len(policyInfos))
|
assert.Equal(t, 3, len(policyInfos))
|
||||||
|
|
||||||
err = globalMetaCache.RefreshPolicyInfo(typeutil.CacheOp{OpType: typeutil.CacheAddUserToRole, OpKey: funcutil.EncodeUserRoleCache("foo", "role3")})
|
err = privilege.GetPrivilegeCache().RefreshPolicyInfo(typeutil.CacheOp{OpType: typeutil.CacheAddUserToRole, OpKey: funcutil.EncodeUserRoleCache("foo", "role3")})
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
roles := globalMetaCache.GetUserRole("foo")
|
roles := privilege.GetPrivilegeCache().GetUserRole("foo")
|
||||||
assert.Equal(t, 3, len(roles))
|
assert.Equal(t, 3, len(roles))
|
||||||
|
|
||||||
err = globalMetaCache.RefreshPolicyInfo(typeutil.CacheOp{OpType: typeutil.CacheRemoveUserFromRole, OpKey: funcutil.EncodeUserRoleCache("foo", "role3")})
|
err = privilege.GetPrivilegeCache().RefreshPolicyInfo(typeutil.CacheOp{OpType: typeutil.CacheRemoveUserFromRole, OpKey: funcutil.EncodeUserRoleCache("foo", "role3")})
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
roles = globalMetaCache.GetUserRole("foo")
|
roles = privilege.GetPrivilegeCache().GetUserRole("foo")
|
||||||
assert.Equal(t, 2, len(roles))
|
assert.Equal(t, 2, len(roles))
|
||||||
|
|
||||||
err = globalMetaCache.RefreshPolicyInfo(typeutil.CacheOp{OpType: typeutil.CacheGrantPrivilege, OpKey: ""})
|
err = privilege.GetPrivilegeCache().RefreshPolicyInfo(typeutil.CacheOp{OpType: typeutil.CacheGrantPrivilege, OpKey: ""})
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
err = globalMetaCache.RefreshPolicyInfo(typeutil.CacheOp{OpType: 100, OpKey: "policyX"})
|
err = privilege.GetPrivilegeCache().RefreshPolicyInfo(typeutil.CacheOp{OpType: 100, OpKey: "policyX"})
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -1568,18 +1569,18 @@ func TestMetaCache_PolicyInfo(t *testing.T) {
|
|||||||
err := InitMetaCache(context.Background(), client, mgr)
|
err := InitMetaCache(context.Background(), client, mgr)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
err = globalMetaCache.RefreshPolicyInfo(typeutil.CacheOp{OpType: typeutil.CacheDeleteUser, OpKey: "foo"})
|
err = privilege.GetPrivilegeCache().RefreshPolicyInfo(typeutil.CacheOp{OpType: typeutil.CacheDeleteUser, OpKey: "foo"})
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
roles := globalMetaCache.GetUserRole("foo")
|
roles := privilege.GetPrivilegeCache().GetUserRole("foo")
|
||||||
assert.Len(t, roles, 0)
|
assert.Len(t, roles, 0)
|
||||||
|
|
||||||
roles = globalMetaCache.GetUserRole("foo2")
|
roles = privilege.GetPrivilegeCache().GetUserRole("foo2")
|
||||||
assert.Len(t, roles, 2)
|
assert.Len(t, roles, 2)
|
||||||
|
|
||||||
err = globalMetaCache.RefreshPolicyInfo(typeutil.CacheOp{OpType: typeutil.CacheDropRole, OpKey: "role2"})
|
err = privilege.GetPrivilegeCache().RefreshPolicyInfo(typeutil.CacheOp{OpType: typeutil.CacheDropRole, OpKey: "role2"})
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
roles = globalMetaCache.GetUserRole("foo2")
|
roles = privilege.GetPrivilegeCache().GetUserRole("foo2")
|
||||||
assert.Len(t, roles, 1)
|
assert.Len(t, roles, 1)
|
||||||
assert.Equal(t, "role3", roles[0])
|
assert.Equal(t, "role3", roles[0])
|
||||||
|
|
||||||
@ -1590,9 +1591,9 @@ func TestMetaCache_PolicyInfo(t *testing.T) {
|
|||||||
UserRoles: []string{funcutil.EncodeUserRoleCache("foo", "role1"), funcutil.EncodeUserRoleCache("foo", "role2"), funcutil.EncodeUserRoleCache("foo2", "role2"), funcutil.EncodeUserRoleCache("foo2", "role3")},
|
UserRoles: []string{funcutil.EncodeUserRoleCache("foo", "role1"), funcutil.EncodeUserRoleCache("foo", "role2"), funcutil.EncodeUserRoleCache("foo2", "role2"), funcutil.EncodeUserRoleCache("foo2", "role3")},
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
err = globalMetaCache.RefreshPolicyInfo(typeutil.CacheOp{OpType: typeutil.CacheRefresh})
|
err = privilege.GetPrivilegeCache().RefreshPolicyInfo(typeutil.CacheOp{OpType: typeutil.CacheRefresh})
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
roles = globalMetaCache.GetUserRole("foo")
|
roles = privilege.GetPrivilegeCache().GetUserRole("foo")
|
||||||
assert.Len(t, roles, 2)
|
assert.Len(t, roles, 2)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@ -22,24 +22,29 @@
|
|||||||
package proxy
|
package proxy
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
"github.com/cockroachdb/errors"
|
"github.com/cockroachdb/errors"
|
||||||
"github.com/stretchr/testify/mock"
|
"github.com/stretchr/testify/mock"
|
||||||
|
|
||||||
"github.com/milvus-io/milvus/internal/mocks"
|
"github.com/milvus-io/milvus/internal/mocks"
|
||||||
|
"github.com/milvus-io/milvus/internal/proxy/privilege"
|
||||||
"github.com/milvus-io/milvus/pkg/v2/common"
|
"github.com/milvus-io/milvus/pkg/v2/common"
|
||||||
|
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
|
||||||
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
|
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
|
||||||
|
"github.com/milvus-io/milvus/pkg/v2/util/merr"
|
||||||
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
|
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
func AddRootUserToAdminRole() {
|
func AddRootUserToAdminRole() {
|
||||||
err := globalMetaCache.RefreshPolicyInfo(typeutil.CacheOp{OpType: typeutil.CacheAddUserToRole, OpKey: funcutil.EncodeUserRoleCache("root", "admin")})
|
err := privilege.GetPrivilegeCache().RefreshPolicyInfo(typeutil.CacheOp{OpType: typeutil.CacheAddUserToRole, OpKey: funcutil.EncodeUserRoleCache("root", "admin")})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func RemoveRootUserFromAdminRole() {
|
func RemoveRootUserFromAdminRole() {
|
||||||
err := globalMetaCache.RefreshPolicyInfo(typeutil.CacheOp{OpType: typeutil.CacheRemoveUserFromRole, OpKey: funcutil.EncodeUserRoleCache("root", "admin")})
|
err := privilege.GetPrivilegeCache().RefreshPolicyInfo(typeutil.CacheOp{OpType: typeutil.CacheRemoveUserFromRole, OpKey: funcutil.EncodeUserRoleCache("root", "admin")})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
@ -55,6 +60,8 @@ func InitEmptyGlobalCache() {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
mixcoord.EXPECT().ListPolicy(mock.Anything, mock.Anything, mock.Anything).Return(&internalpb.ListPolicyResponse{Status: merr.Success()}, nil)
|
||||||
|
privilege.InitPrivilegeCache(context.Background(), mixcoord)
|
||||||
}
|
}
|
||||||
|
|
||||||
func SetGlobalMetaCache(metaCache *MetaCache) {
|
func SetGlobalMetaCache(metaCache *MetaCache) {
|
||||||
|
|||||||
7
internal/proxy/privilege/OWNERS
Normal file
7
internal/proxy/privilege/OWNERS
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
reviewers:
|
||||||
|
- congqixia
|
||||||
|
- czs007
|
||||||
|
- shaoting-huang
|
||||||
|
|
||||||
|
approvers:
|
||||||
|
- maintainers
|
||||||
3
internal/proxy/privilege/README.md
Normal file
3
internal/proxy/privilege/README.md
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
# Summary
|
||||||
|
|
||||||
|
this package contains privilege related components for proxy.
|
||||||
267
internal/proxy/privilege/cache.go
Normal file
267
internal/proxy/privilege/cache.go
Normal file
@ -0,0 +1,267 @@
|
|||||||
|
// Licensed to the LF AI & Data foundation under one
|
||||||
|
// or more contributor license agreements. See the NOTICE file
|
||||||
|
// distributed with this work for additional information
|
||||||
|
// regarding copyright ownership. The ASF licenses this file
|
||||||
|
// to you under the Apache License, Version 2.0 (the
|
||||||
|
// "License"); you may not use this file except in compliance
|
||||||
|
// with the License. You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package privilege
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
|
||||||
|
"github.com/cockroachdb/errors"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
|
||||||
|
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||||
|
"github.com/milvus-io/milvus/internal/types"
|
||||||
|
"github.com/milvus-io/milvus/pkg/v2/log"
|
||||||
|
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
|
||||||
|
"github.com/milvus-io/milvus/pkg/v2/proto/rootcoordpb"
|
||||||
|
"github.com/milvus-io/milvus/pkg/v2/util"
|
||||||
|
"github.com/milvus-io/milvus/pkg/v2/util/commonpbutil"
|
||||||
|
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
|
||||||
|
"github.com/milvus-io/milvus/pkg/v2/util/merr"
|
||||||
|
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
var cacheInst atomic.Pointer[privilegeCache]
|
||||||
|
|
||||||
|
func GetPrivilegeCache() *privilegeCache {
|
||||||
|
return cacheInst.Load()
|
||||||
|
}
|
||||||
|
|
||||||
|
type PrivilegeCache interface {
|
||||||
|
// GetCredentialInfo operate credential cache
|
||||||
|
GetCredentialInfo(ctx context.Context, username string) (*internalpb.CredentialInfo, error)
|
||||||
|
RemoveCredential(username string)
|
||||||
|
UpdateCredential(credInfo *internalpb.CredentialInfo)
|
||||||
|
|
||||||
|
GetPrivilegeInfo(ctx context.Context) []string
|
||||||
|
GetUserRole(username string) []string
|
||||||
|
RefreshPolicyInfo(op typeutil.CacheOp) error
|
||||||
|
InitPolicyInfo(info []string, userRoles []string)
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ PrivilegeCache = (*privilegeCache)(nil)
|
||||||
|
|
||||||
|
type privilegeCache struct {
|
||||||
|
mixCoord types.MixCoordClient
|
||||||
|
|
||||||
|
mu sync.RWMutex
|
||||||
|
privilegeInfos map[string]struct{} // privileges cache
|
||||||
|
userToRoles map[string]map[string]struct{} // user to role cache
|
||||||
|
|
||||||
|
credMut sync.RWMutex
|
||||||
|
credMap map[string]*internalpb.CredentialInfo
|
||||||
|
}
|
||||||
|
|
||||||
|
func InitPrivilegeCache(ctx context.Context, mixCoord types.MixCoordClient) error {
|
||||||
|
privilegeCache := NewPrivilegeCache(mixCoord)
|
||||||
|
// The privilege info is a little more. And to get this info, the query operation of involving multiple table queries is required.
|
||||||
|
cacheInst.Store(privilegeCache)
|
||||||
|
resp, err := mixCoord.ListPolicy(ctx, &internalpb.ListPolicyRequest{})
|
||||||
|
if err = merr.CheckRPCCall(resp, err); err != nil {
|
||||||
|
log.Error("fail to init meta cache", zap.Error(err))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
privilegeCache.InitPolicyInfo(resp.PolicyInfos, resp.UserRoles)
|
||||||
|
log.Info("success to init privilege cache", zap.Strings("policy_infos", resp.PolicyInfos))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPrivilegeCache(mixCoord types.MixCoordClient) *privilegeCache {
|
||||||
|
return &privilegeCache{
|
||||||
|
mixCoord: mixCoord,
|
||||||
|
privilegeInfos: make(map[string]struct{}),
|
||||||
|
userToRoles: make(map[string]map[string]struct{}),
|
||||||
|
|
||||||
|
credMap: make(map[string]*internalpb.CredentialInfo),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCredentialInfo returns the credential related to provided username
|
||||||
|
// If the cache missed, proxy will try to fetch from storage
|
||||||
|
func (m *privilegeCache) GetCredentialInfo(ctx context.Context, username string) (*internalpb.CredentialInfo, error) {
|
||||||
|
m.credMut.RLock()
|
||||||
|
var credInfo *internalpb.CredentialInfo
|
||||||
|
credInfo, ok := m.credMap[username]
|
||||||
|
m.credMut.RUnlock()
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
req := &rootcoordpb.GetCredentialRequest{
|
||||||
|
Base: commonpbutil.NewMsgBase(
|
||||||
|
commonpbutil.WithMsgType(commonpb.MsgType_GetCredential),
|
||||||
|
),
|
||||||
|
Username: username,
|
||||||
|
}
|
||||||
|
resp, err := m.mixCoord.GetCredential(ctx, req)
|
||||||
|
if err != nil {
|
||||||
|
return &internalpb.CredentialInfo{}, err
|
||||||
|
}
|
||||||
|
credInfo = &internalpb.CredentialInfo{
|
||||||
|
Username: resp.Username,
|
||||||
|
EncryptedPassword: resp.Password,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return credInfo, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *privilegeCache) RemoveCredential(username string) {
|
||||||
|
m.credMut.Lock()
|
||||||
|
defer m.credMut.Unlock()
|
||||||
|
// delete pair in credMap
|
||||||
|
delete(m.credMap, username)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *privilegeCache) UpdateCredential(credInfo *internalpb.CredentialInfo) {
|
||||||
|
m.credMut.Lock()
|
||||||
|
defer m.credMut.Unlock()
|
||||||
|
username := credInfo.Username
|
||||||
|
_, ok := m.credMap[username]
|
||||||
|
if !ok {
|
||||||
|
m.credMap[username] = &internalpb.CredentialInfo{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Do not cache encrypted password content
|
||||||
|
m.credMap[username].Username = username
|
||||||
|
m.credMap[username].Sha256Password = credInfo.Sha256Password
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *privilegeCache) InitPolicyInfo(info []string, userRoles []string) {
|
||||||
|
defer func() {
|
||||||
|
err := GetEnforcer().LoadPolicy()
|
||||||
|
if err != nil {
|
||||||
|
log.Error("failed to load policy after RefreshPolicyInfo", zap.Error(err))
|
||||||
|
}
|
||||||
|
CleanPrivilegeCache()
|
||||||
|
}()
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
m.unsafeInitPolicyInfo(info, userRoles)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *privilegeCache) unsafeInitPolicyInfo(info []string, userRoles []string) {
|
||||||
|
m.privilegeInfos = util.StringSet(info)
|
||||||
|
for _, userRole := range userRoles {
|
||||||
|
user, role, err := funcutil.DecodeUserRoleCache(userRole)
|
||||||
|
if err != nil {
|
||||||
|
log.Warn("invalid user-role key", zap.String("user-role", userRole), zap.Error(err))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if m.userToRoles[user] == nil {
|
||||||
|
m.userToRoles[user] = make(map[string]struct{})
|
||||||
|
}
|
||||||
|
m.userToRoles[user][role] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *privilegeCache) GetPrivilegeInfo(ctx context.Context) []string {
|
||||||
|
m.mu.RLock()
|
||||||
|
defer m.mu.RUnlock()
|
||||||
|
|
||||||
|
return util.StringList(m.privilegeInfos)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *privilegeCache) GetUserRole(user string) []string {
|
||||||
|
m.mu.RLock()
|
||||||
|
defer m.mu.RUnlock()
|
||||||
|
|
||||||
|
return util.StringList(m.userToRoles[user])
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *privilegeCache) RefreshPolicyInfo(op typeutil.CacheOp) (err error) {
|
||||||
|
defer func() {
|
||||||
|
if err == nil {
|
||||||
|
le := GetEnforcer().LoadPolicy()
|
||||||
|
if le != nil {
|
||||||
|
log.Error("failed to load policy after RefreshPolicyInfo", zap.Error(le))
|
||||||
|
}
|
||||||
|
CleanPrivilegeCache()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
if op.OpType != typeutil.CacheRefresh {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
if op.OpKey == "" {
|
||||||
|
return errors.New("empty op key")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
switch op.OpType {
|
||||||
|
case typeutil.CacheGrantPrivilege:
|
||||||
|
keys := funcutil.PrivilegesForPolicy(op.OpKey)
|
||||||
|
for _, key := range keys {
|
||||||
|
m.privilegeInfos[key] = struct{}{}
|
||||||
|
}
|
||||||
|
case typeutil.CacheRevokePrivilege:
|
||||||
|
keys := funcutil.PrivilegesForPolicy(op.OpKey)
|
||||||
|
for _, key := range keys {
|
||||||
|
delete(m.privilegeInfos, key)
|
||||||
|
}
|
||||||
|
case typeutil.CacheAddUserToRole:
|
||||||
|
user, role, err := funcutil.DecodeUserRoleCache(op.OpKey)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid opKey, fail to decode, op_type: %d, op_key: %s", int(op.OpType), op.OpKey)
|
||||||
|
}
|
||||||
|
if m.userToRoles[user] == nil {
|
||||||
|
m.userToRoles[user] = make(map[string]struct{})
|
||||||
|
}
|
||||||
|
m.userToRoles[user][role] = struct{}{}
|
||||||
|
case typeutil.CacheRemoveUserFromRole:
|
||||||
|
user, role, err := funcutil.DecodeUserRoleCache(op.OpKey)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid opKey, fail to decode, op_type: %d, op_key: %s", int(op.OpType), op.OpKey)
|
||||||
|
}
|
||||||
|
if m.userToRoles[user] != nil {
|
||||||
|
delete(m.userToRoles[user], role)
|
||||||
|
}
|
||||||
|
case typeutil.CacheDeleteUser:
|
||||||
|
delete(m.userToRoles, op.OpKey)
|
||||||
|
case typeutil.CacheDropRole:
|
||||||
|
for user := range m.userToRoles {
|
||||||
|
delete(m.userToRoles[user], op.OpKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
for policy := range m.privilegeInfos {
|
||||||
|
if funcutil.PolicyCheckerWithRole(policy, op.OpKey) {
|
||||||
|
delete(m.privilegeInfos, policy)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case typeutil.CacheRefresh:
|
||||||
|
resp, err := m.mixCoord.ListPolicy(context.Background(), &internalpb.ListPolicyRequest{})
|
||||||
|
if err != nil {
|
||||||
|
log.Error("fail to init meta cache", zap.Error(err))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if !merr.Ok(resp.GetStatus()) {
|
||||||
|
log.Error("fail to init meta cache",
|
||||||
|
zap.String("error_code", resp.GetStatus().GetErrorCode().String()),
|
||||||
|
zap.String("reason", resp.GetStatus().GetReason()))
|
||||||
|
return merr.Error(resp.Status)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
m.userToRoles = make(map[string]map[string]struct{})
|
||||||
|
m.privilegeInfos = make(map[string]struct{})
|
||||||
|
m.unsafeInitPolicyInfo(resp.PolicyInfos, resp.UserRoles)
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("invalid opType, op_type: %d, op_key: %s", int(op.OpType), op.OpKey)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
27
internal/proxy/privilege/cache_testonly.go
Normal file
27
internal/proxy/privilege/cache_testonly.go
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
//go:build test
|
||||||
|
// +build test
|
||||||
|
|
||||||
|
// Licensed to the LF AI & Data foundation under one
|
||||||
|
// or more contributor license agreements. See the NOTICE file
|
||||||
|
// distributed with this work for additional information
|
||||||
|
// regarding copyright ownership. The ASF licenses this file
|
||||||
|
// to you under the Apache License, Version 2.0 (the
|
||||||
|
// "License"); you may not use this file except in compliance
|
||||||
|
// with the License. You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package privilege
|
||||||
|
|
||||||
|
// This file contains only functions used in tests to manipulate the privilege cache.
|
||||||
|
|
||||||
|
// ResetPrivilegeCacheForTest resets the privilege cache for testing purposes.
|
||||||
|
func ResetPrivilegeCacheForTest() {
|
||||||
|
cacheInst.Store(nil)
|
||||||
|
}
|
||||||
@ -14,7 +14,7 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
package proxy
|
package privilege
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@ -30,10 +30,10 @@ import (
|
|||||||
// MetaCacheCasbinAdapter is the implementation of `persist.Adapter` with Cache
|
// MetaCacheCasbinAdapter is the implementation of `persist.Adapter` with Cache
|
||||||
// Since the usage shall be read-only, it implements only `LoadPolicy` for now.
|
// Since the usage shall be read-only, it implements only `LoadPolicy` for now.
|
||||||
type MetaCacheCasbinAdapter struct {
|
type MetaCacheCasbinAdapter struct {
|
||||||
cacheSource func() Cache
|
cacheSource func() PrivilegeCache
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewMetaCacheCasbinAdapter(cacheSource func() Cache) *MetaCacheCasbinAdapter {
|
func NewMetaCacheCasbinAdapter(cacheSource func() PrivilegeCache) *MetaCacheCasbinAdapter {
|
||||||
return &MetaCacheCasbinAdapter{
|
return &MetaCacheCasbinAdapter{
|
||||||
cacheSource: cacheSource,
|
cacheSource: cacheSource,
|
||||||
}
|
}
|
||||||
@ -14,7 +14,7 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
package proxy
|
package privilege
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
@ -26,36 +26,36 @@ import (
|
|||||||
type MetaCacheCasbinAdapterSuite struct {
|
type MetaCacheCasbinAdapterSuite struct {
|
||||||
suite.Suite
|
suite.Suite
|
||||||
|
|
||||||
cache *MockCache
|
cache *MockPrivilegeCache
|
||||||
adapter *MetaCacheCasbinAdapter
|
adapter *MetaCacheCasbinAdapter
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *MetaCacheCasbinAdapterSuite) SetupTest() {
|
func (s *MetaCacheCasbinAdapterSuite) SetupTest() {
|
||||||
s.cache = NewMockCache(s.T())
|
s.cache = NewMockPrivilegeCache(s.T())
|
||||||
|
|
||||||
s.adapter = NewMetaCacheCasbinAdapter(func() Cache { return s.cache })
|
s.adapter = NewMetaCacheCasbinAdapter(func() PrivilegeCache { return s.cache })
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *MetaCacheCasbinAdapterSuite) TestLoadPolicy() {
|
func (s *MetaCacheCasbinAdapterSuite) TestLoadPolicy() {
|
||||||
s.Run("normal_load", func() {
|
s.Run("normal_load", func() {
|
||||||
s.cache.EXPECT().GetPrivilegeInfo(mock.Anything).Return([]string{})
|
s.cache.EXPECT().GetPrivilegeInfo(mock.Anything).Return([]string{})
|
||||||
|
|
||||||
m := getPolicyModel(ModelStr)
|
m := GetPolicyModel(ModelStr)
|
||||||
err := s.adapter.LoadPolicy(m)
|
err := s.adapter.LoadPolicy(m)
|
||||||
s.NoError(err)
|
s.NoError(err)
|
||||||
})
|
})
|
||||||
|
|
||||||
s.Run("source_return_nil", func() {
|
s.Run("source_return_nil", func() {
|
||||||
adapter := NewMetaCacheCasbinAdapter(func() Cache { return nil })
|
adapter := NewMetaCacheCasbinAdapter(func() PrivilegeCache { return nil })
|
||||||
|
|
||||||
m := getPolicyModel(ModelStr)
|
m := GetPolicyModel(ModelStr)
|
||||||
err := adapter.LoadPolicy(m)
|
err := adapter.LoadPolicy(m)
|
||||||
s.Error(err)
|
s.Error(err)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *MetaCacheCasbinAdapterSuite) TestSavePolicy() {
|
func (s *MetaCacheCasbinAdapterSuite) TestSavePolicy() {
|
||||||
m := getPolicyModel(ModelStr)
|
m := GetPolicyModel(ModelStr)
|
||||||
s.Error(s.adapter.SavePolicy(m))
|
s.Error(s.adapter.SavePolicy(m))
|
||||||
}
|
}
|
||||||
|
|
||||||
340
internal/proxy/privilege/mock_cache.go
Normal file
340
internal/proxy/privilege/mock_cache.go
Normal file
@ -0,0 +1,340 @@
|
|||||||
|
// Code generated by mockery v2.53.3. DO NOT EDIT.
|
||||||
|
|
||||||
|
package privilege
|
||||||
|
|
||||||
|
import (
|
||||||
|
context "context"
|
||||||
|
|
||||||
|
internalpb "github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
|
||||||
|
mock "github.com/stretchr/testify/mock"
|
||||||
|
|
||||||
|
typeutil "github.com/milvus-io/milvus/pkg/v2/util/typeutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MockPrivilegeCache is an autogenerated mock type for the PrivilegeCache type
|
||||||
|
type MockPrivilegeCache struct {
|
||||||
|
mock.Mock
|
||||||
|
}
|
||||||
|
|
||||||
|
type MockPrivilegeCache_Expecter struct {
|
||||||
|
mock *mock.Mock
|
||||||
|
}
|
||||||
|
|
||||||
|
func (_m *MockPrivilegeCache) EXPECT() *MockPrivilegeCache_Expecter {
|
||||||
|
return &MockPrivilegeCache_Expecter{mock: &_m.Mock}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCredentialInfo provides a mock function with given fields: ctx, username
|
||||||
|
func (_m *MockPrivilegeCache) GetCredentialInfo(ctx context.Context, username string) (*internalpb.CredentialInfo, error) {
|
||||||
|
ret := _m.Called(ctx, username)
|
||||||
|
|
||||||
|
if len(ret) == 0 {
|
||||||
|
panic("no return value specified for GetCredentialInfo")
|
||||||
|
}
|
||||||
|
|
||||||
|
var r0 *internalpb.CredentialInfo
|
||||||
|
var r1 error
|
||||||
|
if rf, ok := ret.Get(0).(func(context.Context, string) (*internalpb.CredentialInfo, error)); ok {
|
||||||
|
return rf(ctx, username)
|
||||||
|
}
|
||||||
|
if rf, ok := ret.Get(0).(func(context.Context, string) *internalpb.CredentialInfo); ok {
|
||||||
|
r0 = rf(ctx, username)
|
||||||
|
} else {
|
||||||
|
if ret.Get(0) != nil {
|
||||||
|
r0 = ret.Get(0).(*internalpb.CredentialInfo)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if rf, ok := ret.Get(1).(func(context.Context, string) error); ok {
|
||||||
|
r1 = rf(ctx, username)
|
||||||
|
} else {
|
||||||
|
r1 = ret.Error(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
return r0, r1
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockPrivilegeCache_GetCredentialInfo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetCredentialInfo'
|
||||||
|
type MockPrivilegeCache_GetCredentialInfo_Call struct {
|
||||||
|
*mock.Call
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCredentialInfo is a helper method to define mock.On call
|
||||||
|
// - ctx context.Context
|
||||||
|
// - username string
|
||||||
|
func (_e *MockPrivilegeCache_Expecter) GetCredentialInfo(ctx interface{}, username interface{}) *MockPrivilegeCache_GetCredentialInfo_Call {
|
||||||
|
return &MockPrivilegeCache_GetCredentialInfo_Call{Call: _e.mock.On("GetCredentialInfo", ctx, username)}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (_c *MockPrivilegeCache_GetCredentialInfo_Call) Run(run func(ctx context.Context, username string)) *MockPrivilegeCache_GetCredentialInfo_Call {
|
||||||
|
_c.Call.Run(func(args mock.Arguments) {
|
||||||
|
run(args[0].(context.Context), args[1].(string))
|
||||||
|
})
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
|
func (_c *MockPrivilegeCache_GetCredentialInfo_Call) Return(_a0 *internalpb.CredentialInfo, _a1 error) *MockPrivilegeCache_GetCredentialInfo_Call {
|
||||||
|
_c.Call.Return(_a0, _a1)
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
|
func (_c *MockPrivilegeCache_GetCredentialInfo_Call) RunAndReturn(run func(context.Context, string) (*internalpb.CredentialInfo, error)) *MockPrivilegeCache_GetCredentialInfo_Call {
|
||||||
|
_c.Call.Return(run)
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPrivilegeInfo provides a mock function with given fields: ctx
|
||||||
|
func (_m *MockPrivilegeCache) GetPrivilegeInfo(ctx context.Context) []string {
|
||||||
|
ret := _m.Called(ctx)
|
||||||
|
|
||||||
|
if len(ret) == 0 {
|
||||||
|
panic("no return value specified for GetPrivilegeInfo")
|
||||||
|
}
|
||||||
|
|
||||||
|
var r0 []string
|
||||||
|
if rf, ok := ret.Get(0).(func(context.Context) []string); ok {
|
||||||
|
r0 = rf(ctx)
|
||||||
|
} else {
|
||||||
|
if ret.Get(0) != nil {
|
||||||
|
r0 = ret.Get(0).([]string)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return r0
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockPrivilegeCache_GetPrivilegeInfo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetPrivilegeInfo'
|
||||||
|
type MockPrivilegeCache_GetPrivilegeInfo_Call struct {
|
||||||
|
*mock.Call
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPrivilegeInfo is a helper method to define mock.On call
|
||||||
|
// - ctx context.Context
|
||||||
|
func (_e *MockPrivilegeCache_Expecter) GetPrivilegeInfo(ctx interface{}) *MockPrivilegeCache_GetPrivilegeInfo_Call {
|
||||||
|
return &MockPrivilegeCache_GetPrivilegeInfo_Call{Call: _e.mock.On("GetPrivilegeInfo", ctx)}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (_c *MockPrivilegeCache_GetPrivilegeInfo_Call) Run(run func(ctx context.Context)) *MockPrivilegeCache_GetPrivilegeInfo_Call {
|
||||||
|
_c.Call.Run(func(args mock.Arguments) {
|
||||||
|
run(args[0].(context.Context))
|
||||||
|
})
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
|
func (_c *MockPrivilegeCache_GetPrivilegeInfo_Call) Return(_a0 []string) *MockPrivilegeCache_GetPrivilegeInfo_Call {
|
||||||
|
_c.Call.Return(_a0)
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
|
func (_c *MockPrivilegeCache_GetPrivilegeInfo_Call) RunAndReturn(run func(context.Context) []string) *MockPrivilegeCache_GetPrivilegeInfo_Call {
|
||||||
|
_c.Call.Return(run)
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUserRole provides a mock function with given fields: username
|
||||||
|
func (_m *MockPrivilegeCache) GetUserRole(username string) []string {
|
||||||
|
ret := _m.Called(username)
|
||||||
|
|
||||||
|
if len(ret) == 0 {
|
||||||
|
panic("no return value specified for GetUserRole")
|
||||||
|
}
|
||||||
|
|
||||||
|
var r0 []string
|
||||||
|
if rf, ok := ret.Get(0).(func(string) []string); ok {
|
||||||
|
r0 = rf(username)
|
||||||
|
} else {
|
||||||
|
if ret.Get(0) != nil {
|
||||||
|
r0 = ret.Get(0).([]string)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return r0
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockPrivilegeCache_GetUserRole_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetUserRole'
|
||||||
|
type MockPrivilegeCache_GetUserRole_Call struct {
|
||||||
|
*mock.Call
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUserRole is a helper method to define mock.On call
|
||||||
|
// - username string
|
||||||
|
func (_e *MockPrivilegeCache_Expecter) GetUserRole(username interface{}) *MockPrivilegeCache_GetUserRole_Call {
|
||||||
|
return &MockPrivilegeCache_GetUserRole_Call{Call: _e.mock.On("GetUserRole", username)}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (_c *MockPrivilegeCache_GetUserRole_Call) Run(run func(username string)) *MockPrivilegeCache_GetUserRole_Call {
|
||||||
|
_c.Call.Run(func(args mock.Arguments) {
|
||||||
|
run(args[0].(string))
|
||||||
|
})
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
|
func (_c *MockPrivilegeCache_GetUserRole_Call) Return(_a0 []string) *MockPrivilegeCache_GetUserRole_Call {
|
||||||
|
_c.Call.Return(_a0)
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
|
func (_c *MockPrivilegeCache_GetUserRole_Call) RunAndReturn(run func(string) []string) *MockPrivilegeCache_GetUserRole_Call {
|
||||||
|
_c.Call.Return(run)
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
|
// InitPolicyInfo provides a mock function with given fields: info, userRoles
|
||||||
|
func (_m *MockPrivilegeCache) InitPolicyInfo(info []string, userRoles []string) {
|
||||||
|
_m.Called(info, userRoles)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockPrivilegeCache_InitPolicyInfo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'InitPolicyInfo'
|
||||||
|
type MockPrivilegeCache_InitPolicyInfo_Call struct {
|
||||||
|
*mock.Call
|
||||||
|
}
|
||||||
|
|
||||||
|
// InitPolicyInfo is a helper method to define mock.On call
|
||||||
|
// - info []string
|
||||||
|
// - userRoles []string
|
||||||
|
func (_e *MockPrivilegeCache_Expecter) InitPolicyInfo(info interface{}, userRoles interface{}) *MockPrivilegeCache_InitPolicyInfo_Call {
|
||||||
|
return &MockPrivilegeCache_InitPolicyInfo_Call{Call: _e.mock.On("InitPolicyInfo", info, userRoles)}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (_c *MockPrivilegeCache_InitPolicyInfo_Call) Run(run func(info []string, userRoles []string)) *MockPrivilegeCache_InitPolicyInfo_Call {
|
||||||
|
_c.Call.Run(func(args mock.Arguments) {
|
||||||
|
run(args[0].([]string), args[1].([]string))
|
||||||
|
})
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
|
func (_c *MockPrivilegeCache_InitPolicyInfo_Call) Return() *MockPrivilegeCache_InitPolicyInfo_Call {
|
||||||
|
_c.Call.Return()
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
|
func (_c *MockPrivilegeCache_InitPolicyInfo_Call) RunAndReturn(run func([]string, []string)) *MockPrivilegeCache_InitPolicyInfo_Call {
|
||||||
|
_c.Run(run)
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
|
// RefreshPolicyInfo provides a mock function with given fields: op
|
||||||
|
func (_m *MockPrivilegeCache) RefreshPolicyInfo(op typeutil.CacheOp) error {
|
||||||
|
ret := _m.Called(op)
|
||||||
|
|
||||||
|
if len(ret) == 0 {
|
||||||
|
panic("no return value specified for RefreshPolicyInfo")
|
||||||
|
}
|
||||||
|
|
||||||
|
var r0 error
|
||||||
|
if rf, ok := ret.Get(0).(func(typeutil.CacheOp) error); ok {
|
||||||
|
r0 = rf(op)
|
||||||
|
} else {
|
||||||
|
r0 = ret.Error(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
return r0
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockPrivilegeCache_RefreshPolicyInfo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RefreshPolicyInfo'
|
||||||
|
type MockPrivilegeCache_RefreshPolicyInfo_Call struct {
|
||||||
|
*mock.Call
|
||||||
|
}
|
||||||
|
|
||||||
|
// RefreshPolicyInfo is a helper method to define mock.On call
|
||||||
|
// - op typeutil.CacheOp
|
||||||
|
func (_e *MockPrivilegeCache_Expecter) RefreshPolicyInfo(op interface{}) *MockPrivilegeCache_RefreshPolicyInfo_Call {
|
||||||
|
return &MockPrivilegeCache_RefreshPolicyInfo_Call{Call: _e.mock.On("RefreshPolicyInfo", op)}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (_c *MockPrivilegeCache_RefreshPolicyInfo_Call) Run(run func(op typeutil.CacheOp)) *MockPrivilegeCache_RefreshPolicyInfo_Call {
|
||||||
|
_c.Call.Run(func(args mock.Arguments) {
|
||||||
|
run(args[0].(typeutil.CacheOp))
|
||||||
|
})
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
|
func (_c *MockPrivilegeCache_RefreshPolicyInfo_Call) Return(_a0 error) *MockPrivilegeCache_RefreshPolicyInfo_Call {
|
||||||
|
_c.Call.Return(_a0)
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
|
func (_c *MockPrivilegeCache_RefreshPolicyInfo_Call) RunAndReturn(run func(typeutil.CacheOp) error) *MockPrivilegeCache_RefreshPolicyInfo_Call {
|
||||||
|
_c.Call.Return(run)
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveCredential provides a mock function with given fields: username
|
||||||
|
func (_m *MockPrivilegeCache) RemoveCredential(username string) {
|
||||||
|
_m.Called(username)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockPrivilegeCache_RemoveCredential_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RemoveCredential'
|
||||||
|
type MockPrivilegeCache_RemoveCredential_Call struct {
|
||||||
|
*mock.Call
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveCredential is a helper method to define mock.On call
|
||||||
|
// - username string
|
||||||
|
func (_e *MockPrivilegeCache_Expecter) RemoveCredential(username interface{}) *MockPrivilegeCache_RemoveCredential_Call {
|
||||||
|
return &MockPrivilegeCache_RemoveCredential_Call{Call: _e.mock.On("RemoveCredential", username)}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (_c *MockPrivilegeCache_RemoveCredential_Call) Run(run func(username string)) *MockPrivilegeCache_RemoveCredential_Call {
|
||||||
|
_c.Call.Run(func(args mock.Arguments) {
|
||||||
|
run(args[0].(string))
|
||||||
|
})
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
|
func (_c *MockPrivilegeCache_RemoveCredential_Call) Return() *MockPrivilegeCache_RemoveCredential_Call {
|
||||||
|
_c.Call.Return()
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
|
func (_c *MockPrivilegeCache_RemoveCredential_Call) RunAndReturn(run func(string)) *MockPrivilegeCache_RemoveCredential_Call {
|
||||||
|
_c.Run(run)
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateCredential provides a mock function with given fields: credInfo
|
||||||
|
func (_m *MockPrivilegeCache) UpdateCredential(credInfo *internalpb.CredentialInfo) {
|
||||||
|
_m.Called(credInfo)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockPrivilegeCache_UpdateCredential_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateCredential'
|
||||||
|
type MockPrivilegeCache_UpdateCredential_Call struct {
|
||||||
|
*mock.Call
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateCredential is a helper method to define mock.On call
|
||||||
|
// - credInfo *internalpb.CredentialInfo
|
||||||
|
func (_e *MockPrivilegeCache_Expecter) UpdateCredential(credInfo interface{}) *MockPrivilegeCache_UpdateCredential_Call {
|
||||||
|
return &MockPrivilegeCache_UpdateCredential_Call{Call: _e.mock.On("UpdateCredential", credInfo)}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (_c *MockPrivilegeCache_UpdateCredential_Call) Run(run func(credInfo *internalpb.CredentialInfo)) *MockPrivilegeCache_UpdateCredential_Call {
|
||||||
|
_c.Call.Run(func(args mock.Arguments) {
|
||||||
|
run(args[0].(*internalpb.CredentialInfo))
|
||||||
|
})
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
|
func (_c *MockPrivilegeCache_UpdateCredential_Call) Return() *MockPrivilegeCache_UpdateCredential_Call {
|
||||||
|
_c.Call.Return()
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
|
func (_c *MockPrivilegeCache_UpdateCredential_Call) RunAndReturn(run func(*internalpb.CredentialInfo)) *MockPrivilegeCache_UpdateCredential_Call {
|
||||||
|
_c.Run(run)
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMockPrivilegeCache creates a new instance of MockPrivilegeCache. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||||
|
// The first argument is typically a *testing.T value.
|
||||||
|
func NewMockPrivilegeCache(t interface {
|
||||||
|
mock.TestingT
|
||||||
|
Cleanup(func())
|
||||||
|
}) *MockPrivilegeCache {
|
||||||
|
mock := &MockPrivilegeCache{}
|
||||||
|
mock.Mock.Test(t)
|
||||||
|
|
||||||
|
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||||
|
|
||||||
|
return mock
|
||||||
|
}
|
||||||
139
internal/proxy/privilege/model.go
Normal file
139
internal/proxy/privilege/model.go
Normal file
@ -0,0 +1,139 @@
|
|||||||
|
package privilege
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/casbin/casbin/v2"
|
||||||
|
"github.com/casbin/casbin/v2/model"
|
||||||
|
"github.com/samber/lo"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
|
||||||
|
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||||
|
"github.com/milvus-io/milvus/pkg/v2/util"
|
||||||
|
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
|
||||||
|
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// sub -> role name, like admin, public
|
||||||
|
// obj -> contact object with object name, like Global-*, Collection-col1
|
||||||
|
// act -> privilege, like CreateCollection, DescribeCollection
|
||||||
|
ModelStr = `
|
||||||
|
[request_definition]
|
||||||
|
r = sub, obj, act
|
||||||
|
|
||||||
|
[policy_definition]
|
||||||
|
p = sub, obj, act
|
||||||
|
|
||||||
|
[policy_effect]
|
||||||
|
e = some(where (p.eft == allow))
|
||||||
|
|
||||||
|
[matchers]
|
||||||
|
m = r.sub == p.sub && globMatch(r.obj, p.obj) && globMatch(r.act, p.act) || r.sub == "admin" || (r.sub == p.sub && dbMatch(r.obj, p.obj) && privilegeGroupContains(r.act, p.act, r.obj, p.obj))
|
||||||
|
`
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
enforcer *casbin.SyncedEnforcer
|
||||||
|
initOnce sync.Once
|
||||||
|
initPrivilegeGroupsOnce sync.Once
|
||||||
|
)
|
||||||
|
|
||||||
|
func GetPolicyModel(modelString string) model.Model {
|
||||||
|
m, err := model.NewModelFromString(modelString)
|
||||||
|
if err != nil {
|
||||||
|
log.Panic("NewModelFromString fail", zap.String("model", ModelStr), zap.Error(err))
|
||||||
|
}
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
func InitPrivilegeGroups() {
|
||||||
|
initPrivilegeGroupsOnce.Do(func() {
|
||||||
|
roGroup := paramtable.Get().CommonCfg.ReadOnlyPrivileges.GetAsStrings()
|
||||||
|
if len(roGroup) == 0 {
|
||||||
|
roGroup = util.ReadOnlyPrivilegeGroup
|
||||||
|
}
|
||||||
|
roPrivileges = lo.SliceToMap(roGroup, func(item string) (string, struct{}) { return item, struct{}{} })
|
||||||
|
|
||||||
|
rwGroup := paramtable.Get().CommonCfg.ReadWritePrivileges.GetAsStrings()
|
||||||
|
if len(rwGroup) == 0 {
|
||||||
|
rwGroup = util.ReadWritePrivilegeGroup
|
||||||
|
}
|
||||||
|
rwPrivileges = lo.SliceToMap(rwGroup, func(item string) (string, struct{}) { return item, struct{}{} })
|
||||||
|
|
||||||
|
adminGroup := paramtable.Get().CommonCfg.AdminPrivileges.GetAsStrings()
|
||||||
|
if len(adminGroup) == 0 {
|
||||||
|
adminGroup = util.AdminPrivilegeGroup
|
||||||
|
}
|
||||||
|
adminPrivileges = lo.SliceToMap(adminGroup, func(item string) (string, struct{}) { return item, struct{}{} })
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetEnforcer() *casbin.SyncedEnforcer {
|
||||||
|
initOnce.Do(func() {
|
||||||
|
e, err := casbin.NewSyncedEnforcer()
|
||||||
|
if err != nil {
|
||||||
|
log.Panic("failed to create casbin enforcer", zap.Error(err))
|
||||||
|
}
|
||||||
|
casbinModel := GetPolicyModel(ModelStr)
|
||||||
|
adapter := NewMetaCacheCasbinAdapter(func() PrivilegeCache { return GetPrivilegeCache() })
|
||||||
|
e.InitWithModelAndAdapter(casbinModel, adapter)
|
||||||
|
e.AddFunction("dbMatch", DBMatchFunc)
|
||||||
|
e.AddFunction("privilegeGroupContains", PrivilegeGroupContains)
|
||||||
|
enforcer = e
|
||||||
|
})
|
||||||
|
return enforcer
|
||||||
|
}
|
||||||
|
|
||||||
|
var roPrivileges, rwPrivileges, adminPrivileges map[string]struct{}
|
||||||
|
|
||||||
|
func DBMatchFunc(args ...interface{}) (interface{}, error) {
|
||||||
|
name1 := args[0].(string)
|
||||||
|
name2 := args[1].(string)
|
||||||
|
|
||||||
|
db1, _ := funcutil.SplitObjectName(name1[strings.Index(name1, "-")+1:])
|
||||||
|
db2, _ := funcutil.SplitObjectName(name2[strings.Index(name2, "-")+1:])
|
||||||
|
|
||||||
|
return db1 == db2, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func PrivilegeGroupContains(args ...interface{}) (interface{}, error) {
|
||||||
|
requestPrivilege := args[0].(string)
|
||||||
|
policyPrivilege := args[1].(string)
|
||||||
|
requestObj := args[2].(string)
|
||||||
|
policyObj := args[3].(string)
|
||||||
|
|
||||||
|
switch policyPrivilege {
|
||||||
|
case commonpb.ObjectPrivilege_PrivilegeAll.String():
|
||||||
|
return true, nil
|
||||||
|
case commonpb.ObjectPrivilege_PrivilegeGroupReadOnly.String():
|
||||||
|
// read only belong to collection object
|
||||||
|
if !collMatch(requestObj, policyObj) {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
_, ok := roPrivileges[requestPrivilege]
|
||||||
|
return ok, nil
|
||||||
|
case commonpb.ObjectPrivilege_PrivilegeGroupReadWrite.String():
|
||||||
|
// read write belong to collection object
|
||||||
|
if !collMatch(requestObj, policyObj) {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
_, ok := rwPrivileges[requestPrivilege]
|
||||||
|
return ok, nil
|
||||||
|
case commonpb.ObjectPrivilege_PrivilegeGroupAdmin.String():
|
||||||
|
// admin belong to global object
|
||||||
|
_, ok := adminPrivileges[requestPrivilege]
|
||||||
|
return ok, nil
|
||||||
|
default:
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func collMatch(requestObj, policyObj string) bool {
|
||||||
|
_, coll1 := funcutil.SplitObjectName(requestObj[strings.Index(requestObj, "-")+1:])
|
||||||
|
_, coll2 := funcutil.SplitObjectName(policyObj[strings.Index(policyObj, "-")+1:])
|
||||||
|
|
||||||
|
return coll1 == util.AnyWord || coll2 == util.AnyWord || coll1 == coll2
|
||||||
|
}
|
||||||
@ -14,7 +14,7 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
package proxy
|
package privilege
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
@ -28,11 +28,11 @@ import (
|
|||||||
var (
|
var (
|
||||||
priCacheInitOnce sync.Once
|
priCacheInitOnce sync.Once
|
||||||
priCacheMut sync.RWMutex
|
priCacheMut sync.RWMutex
|
||||||
priCache *PrivilegeCache
|
priCache *resultCache
|
||||||
ver atomic.Int64
|
ver atomic.Int64
|
||||||
)
|
)
|
||||||
|
|
||||||
func getPriCache() *PrivilegeCache {
|
func getPriCache() *resultCache {
|
||||||
priCacheMut.RLock()
|
priCacheMut.RLock()
|
||||||
c := priCache
|
c := priCache
|
||||||
priCacheMut.RUnlock()
|
priCacheMut.RUnlock()
|
||||||
@ -41,7 +41,7 @@ func getPriCache() *PrivilegeCache {
|
|||||||
priCacheInitOnce.Do(func() {
|
priCacheInitOnce.Do(func() {
|
||||||
priCacheMut.Lock()
|
priCacheMut.Lock()
|
||||||
defer priCacheMut.Unlock()
|
defer priCacheMut.Unlock()
|
||||||
priCache = &PrivilegeCache{
|
priCache = &resultCache{
|
||||||
version: ver.Inc(),
|
version: ver.Inc(),
|
||||||
values: typeutil.ConcurrentMap[string, bool]{},
|
values: typeutil.ConcurrentMap[string, bool]{},
|
||||||
}
|
}
|
||||||
@ -57,20 +57,20 @@ func getPriCache() *PrivilegeCache {
|
|||||||
func CleanPrivilegeCache() {
|
func CleanPrivilegeCache() {
|
||||||
priCacheMut.Lock()
|
priCacheMut.Lock()
|
||||||
defer priCacheMut.Unlock()
|
defer priCacheMut.Unlock()
|
||||||
priCache = &PrivilegeCache{
|
priCache = &resultCache{
|
||||||
version: ver.Inc(),
|
version: ver.Inc(),
|
||||||
values: typeutil.ConcurrentMap[string, bool]{},
|
values: typeutil.ConcurrentMap[string, bool]{},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetPrivilegeCache(roleName, object, objectPrivilege string) (isPermit, cached bool, version int64) {
|
func GetResultCache(roleName, object, objectPrivilege string) (isPermit, cached bool, version int64) {
|
||||||
key := fmt.Sprintf("%s_%s_%s", roleName, object, objectPrivilege)
|
key := fmt.Sprintf("%s_%s_%s", roleName, object, objectPrivilege)
|
||||||
c := getPriCache()
|
c := getPriCache()
|
||||||
isPermit, cached = c.values.Get(key)
|
isPermit, cached = c.values.Get(key)
|
||||||
return isPermit, cached, c.version
|
return isPermit, cached, c.version
|
||||||
}
|
}
|
||||||
|
|
||||||
func SetPrivilegeCache(roleName, object, objectPrivilege string, isPermit bool, version int64) {
|
func SetResultCache(roleName, object, objectPrivilege string, isPermit bool, version int64) {
|
||||||
key := fmt.Sprintf("%s_%s_%s", roleName, object, objectPrivilege)
|
key := fmt.Sprintf("%s_%s_%s", roleName, object, objectPrivilege)
|
||||||
c := getPriCache()
|
c := getPriCache()
|
||||||
if c.version == version {
|
if c.version == version {
|
||||||
@ -80,7 +80,7 @@ func SetPrivilegeCache(roleName, object, objectPrivilege string, isPermit bool,
|
|||||||
|
|
||||||
// PrivilegeCache is a cache for privilege enforce result
|
// PrivilegeCache is a cache for privilege enforce result
|
||||||
// version provides version control when any policy updates
|
// version provides version control when any policy updates
|
||||||
type PrivilegeCache struct {
|
type resultCache struct {
|
||||||
values typeutil.ConcurrentMap[string, bool]
|
values typeutil.ConcurrentMap[string, bool]
|
||||||
version int64
|
version int64
|
||||||
}
|
}
|
||||||
@ -14,7 +14,7 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
package proxy
|
package privilege
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
@ -32,9 +32,9 @@ func (s *PrivilegeCacheSuite) TearDownTest() {
|
|||||||
|
|
||||||
func (s *PrivilegeCacheSuite) TestGetPrivilege() {
|
func (s *PrivilegeCacheSuite) TestGetPrivilege() {
|
||||||
// get current version
|
// get current version
|
||||||
_, _, version := GetPrivilegeCache("", "", "")
|
_, _, version := GetResultCache("", "", "")
|
||||||
SetPrivilegeCache("test-role", "test-object", "read", true, version)
|
SetResultCache("test-role", "test-object", "read", true, version)
|
||||||
SetPrivilegeCache("test-role", "test-object", "delete", false, version)
|
SetResultCache("test-role", "test-object", "delete", false, version)
|
||||||
|
|
||||||
type testCase struct {
|
type testCase struct {
|
||||||
tag string
|
tag string
|
||||||
@ -51,7 +51,7 @@ func (s *PrivilegeCacheSuite) TestGetPrivilege() {
|
|||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
s.Run(tc.tag, func() {
|
s.Run(tc.tag, func() {
|
||||||
isPermit, exists, _ := GetPrivilegeCache(tc.input[0], tc.input[1], tc.input[2])
|
isPermit, exists, _ := GetResultCache(tc.input[0], tc.input[1], tc.input[2])
|
||||||
s.Equal(tc.expectIsPermit, isPermit)
|
s.Equal(tc.expectIsPermit, isPermit)
|
||||||
s.Equal(tc.expectExists, exists)
|
s.Equal(tc.expectExists, exists)
|
||||||
})
|
})
|
||||||
@ -60,12 +60,12 @@ func (s *PrivilegeCacheSuite) TestGetPrivilege() {
|
|||||||
|
|
||||||
func (s *PrivilegeCacheSuite) TestSetPrivilegeVersion() {
|
func (s *PrivilegeCacheSuite) TestSetPrivilegeVersion() {
|
||||||
// get current version
|
// get current version
|
||||||
_, _, version := GetPrivilegeCache("", "", "")
|
_, _, version := GetResultCache("", "", "")
|
||||||
CleanPrivilegeCache()
|
CleanPrivilegeCache()
|
||||||
|
|
||||||
SetPrivilegeCache("test-role", "test-object", "read", true, version)
|
SetResultCache("test-role", "test-object", "read", true, version)
|
||||||
|
|
||||||
isPermit, exists, nextVersion := GetPrivilegeCache("test-role", "test-object", "read")
|
isPermit, exists, nextVersion := GetResultCache("test-role", "test-object", "read")
|
||||||
s.False(isPermit)
|
s.False(isPermit)
|
||||||
s.False(exists)
|
s.False(exists)
|
||||||
s.NotEqual(version, nextVersion)
|
s.NotEqual(version, nextVersion)
|
||||||
@ -4,12 +4,9 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/casbin/casbin/v2"
|
"github.com/casbin/casbin/v2"
|
||||||
"github.com/casbin/casbin/v2/model"
|
|
||||||
"github.com/samber/lo"
|
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
@ -17,36 +14,15 @@ import (
|
|||||||
|
|
||||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||||
|
"github.com/milvus-io/milvus/internal/proxy/privilege"
|
||||||
"github.com/milvus-io/milvus/pkg/v2/log"
|
"github.com/milvus-io/milvus/pkg/v2/log"
|
||||||
"github.com/milvus-io/milvus/pkg/v2/util"
|
"github.com/milvus-io/milvus/pkg/v2/util"
|
||||||
"github.com/milvus-io/milvus/pkg/v2/util/contextutil"
|
"github.com/milvus-io/milvus/pkg/v2/util/contextutil"
|
||||||
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
|
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
|
||||||
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type PrivilegeFunc func(ctx context.Context, req interface{}) (context.Context, error)
|
type PrivilegeFunc func(ctx context.Context, req interface{}) (context.Context, error)
|
||||||
|
|
||||||
const (
|
|
||||||
// sub -> role name, like admin, public
|
|
||||||
// obj -> contact object with object name, like Global-*, Collection-col1
|
|
||||||
// act -> privilege, like CreateCollection, DescribeCollection
|
|
||||||
ModelStr = `
|
|
||||||
[request_definition]
|
|
||||||
r = sub, obj, act
|
|
||||||
|
|
||||||
[policy_definition]
|
|
||||||
p = sub, obj, act
|
|
||||||
|
|
||||||
[policy_effect]
|
|
||||||
e = some(where (p.eft == allow))
|
|
||||||
|
|
||||||
[matchers]
|
|
||||||
m = r.sub == p.sub && globMatch(r.obj, p.obj) && globMatch(r.act, p.act) || r.sub == "admin" || (r.sub == p.sub && dbMatch(r.obj, p.obj) && privilegeGroupContains(r.act, p.act, r.obj, p.obj))
|
|
||||||
`
|
|
||||||
)
|
|
||||||
|
|
||||||
var templateModel = getPolicyModel(ModelStr)
|
|
||||||
|
|
||||||
var (
|
var (
|
||||||
enforcer *casbin.SyncedEnforcer
|
enforcer *casbin.SyncedEnforcer
|
||||||
initOnce sync.Once
|
initOnce sync.Once
|
||||||
@ -55,55 +31,9 @@ var (
|
|||||||
|
|
||||||
var roPrivileges, rwPrivileges, adminPrivileges map[string]struct{}
|
var roPrivileges, rwPrivileges, adminPrivileges map[string]struct{}
|
||||||
|
|
||||||
func initPrivilegeGroups() {
|
|
||||||
initPrivilegeGroupsOnce.Do(func() {
|
|
||||||
roGroup := paramtable.Get().CommonCfg.ReadOnlyPrivileges.GetAsStrings()
|
|
||||||
if len(roGroup) == 0 {
|
|
||||||
roGroup = util.ReadOnlyPrivilegeGroup
|
|
||||||
}
|
|
||||||
roPrivileges = lo.SliceToMap(roGroup, func(item string) (string, struct{}) { return item, struct{}{} })
|
|
||||||
|
|
||||||
rwGroup := paramtable.Get().CommonCfg.ReadWritePrivileges.GetAsStrings()
|
|
||||||
if len(rwGroup) == 0 {
|
|
||||||
rwGroup = util.ReadWritePrivilegeGroup
|
|
||||||
}
|
|
||||||
rwPrivileges = lo.SliceToMap(rwGroup, func(item string) (string, struct{}) { return item, struct{}{} })
|
|
||||||
|
|
||||||
adminGroup := paramtable.Get().CommonCfg.AdminPrivileges.GetAsStrings()
|
|
||||||
if len(adminGroup) == 0 {
|
|
||||||
adminGroup = util.AdminPrivilegeGroup
|
|
||||||
}
|
|
||||||
adminPrivileges = lo.SliceToMap(adminGroup, func(item string) (string, struct{}) { return item, struct{}{} })
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func getEnforcer() *casbin.SyncedEnforcer {
|
|
||||||
initOnce.Do(func() {
|
|
||||||
e, err := casbin.NewSyncedEnforcer()
|
|
||||||
if err != nil {
|
|
||||||
log.Panic("failed to create casbin enforcer", zap.Error(err))
|
|
||||||
}
|
|
||||||
casbinModel := getPolicyModel(ModelStr)
|
|
||||||
adapter := NewMetaCacheCasbinAdapter(func() Cache { return globalMetaCache })
|
|
||||||
e.InitWithModelAndAdapter(casbinModel, adapter)
|
|
||||||
e.AddFunction("dbMatch", DBMatchFunc)
|
|
||||||
e.AddFunction("privilegeGroupContains", PrivilegeGroupContains)
|
|
||||||
enforcer = e
|
|
||||||
})
|
|
||||||
return enforcer
|
|
||||||
}
|
|
||||||
|
|
||||||
func getPolicyModel(modelString string) model.Model {
|
|
||||||
m, err := model.NewModelFromString(modelString)
|
|
||||||
if err != nil {
|
|
||||||
log.Panic("NewModelFromString fail", zap.String("model", ModelStr), zap.Error(err))
|
|
||||||
}
|
|
||||||
return m
|
|
||||||
}
|
|
||||||
|
|
||||||
// UnaryServerInterceptor returns a new unary server interceptors that performs per-request privilege access.
|
// UnaryServerInterceptor returns a new unary server interceptors that performs per-request privilege access.
|
||||||
func UnaryServerInterceptor(privilegeFunc PrivilegeFunc) grpc.UnaryServerInterceptor {
|
func UnaryServerInterceptor(privilegeFunc PrivilegeFunc) grpc.UnaryServerInterceptor {
|
||||||
initPrivilegeGroups()
|
privilege.InitPrivilegeGroups()
|
||||||
return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
|
return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
|
||||||
newCtx, err := privilegeFunc(ctx, req)
|
newCtx, err := privilegeFunc(ctx, req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -160,11 +90,11 @@ func PrivilegeInterceptor(ctx context.Context, req interface{}) (context.Context
|
|||||||
zap.Int32("object_index", objectNameIndex), zap.String("object_name", objectName),
|
zap.Int32("object_index", objectNameIndex), zap.String("object_name", objectName),
|
||||||
zap.Int32("object_indexs", objectNameIndexs), zap.Strings("object_names", objectNames))
|
zap.Int32("object_indexs", objectNameIndexs), zap.Strings("object_names", objectNames))
|
||||||
|
|
||||||
e := getEnforcer()
|
e := privilege.GetEnforcer()
|
||||||
for _, roleName := range roleNames {
|
for _, roleName := range roleNames {
|
||||||
permitFunc := func(objectName string) (bool, error) {
|
permitFunc := func(objectName string) (bool, error) {
|
||||||
object := funcutil.PolicyForResource(dbName, objectType, objectName)
|
object := funcutil.PolicyForResource(dbName, objectType, objectName)
|
||||||
isPermit, cached, version := GetPrivilegeCache(roleName, object, objectPrivilege)
|
isPermit, cached, version := privilege.GetResultCache(roleName, object, objectPrivilege)
|
||||||
if cached {
|
if cached {
|
||||||
return isPermit, nil
|
return isPermit, nil
|
||||||
}
|
}
|
||||||
@ -172,7 +102,7 @@ func PrivilegeInterceptor(ctx context.Context, req interface{}) (context.Context
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
SetPrivilegeCache(roleName, object, objectPrivilege, isPermit, version)
|
privilege.SetResultCache(roleName, object, objectPrivilege, isPermit, version)
|
||||||
return isPermit, nil
|
return isPermit, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -237,52 +167,3 @@ func isSelectMyRoleGrants(req interface{}, roleNames []string) bool {
|
|||||||
roleName := filterGrantEntity.GetRole().GetName()
|
roleName := filterGrantEntity.GetRole().GetName()
|
||||||
return funcutil.SliceContain(roleNames, roleName)
|
return funcutil.SliceContain(roleNames, roleName)
|
||||||
}
|
}
|
||||||
|
|
||||||
func DBMatchFunc(args ...interface{}) (interface{}, error) {
|
|
||||||
name1 := args[0].(string)
|
|
||||||
name2 := args[1].(string)
|
|
||||||
|
|
||||||
db1, _ := funcutil.SplitObjectName(name1[strings.Index(name1, "-")+1:])
|
|
||||||
db2, _ := funcutil.SplitObjectName(name2[strings.Index(name2, "-")+1:])
|
|
||||||
|
|
||||||
return db1 == db2, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func collMatch(requestObj, policyObj string) bool {
|
|
||||||
_, coll1 := funcutil.SplitObjectName(requestObj[strings.Index(requestObj, "-")+1:])
|
|
||||||
_, coll2 := funcutil.SplitObjectName(policyObj[strings.Index(policyObj, "-")+1:])
|
|
||||||
|
|
||||||
return coll1 == util.AnyWord || coll2 == util.AnyWord || coll1 == coll2
|
|
||||||
}
|
|
||||||
|
|
||||||
func PrivilegeGroupContains(args ...interface{}) (interface{}, error) {
|
|
||||||
requestPrivilege := args[0].(string)
|
|
||||||
policyPrivilege := args[1].(string)
|
|
||||||
requestObj := args[2].(string)
|
|
||||||
policyObj := args[3].(string)
|
|
||||||
|
|
||||||
switch policyPrivilege {
|
|
||||||
case commonpb.ObjectPrivilege_PrivilegeAll.String():
|
|
||||||
return true, nil
|
|
||||||
case commonpb.ObjectPrivilege_PrivilegeGroupReadOnly.String():
|
|
||||||
// read only belong to collection object
|
|
||||||
if !collMatch(requestObj, policyObj) {
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
_, ok := roPrivileges[requestPrivilege]
|
|
||||||
return ok, nil
|
|
||||||
case commonpb.ObjectPrivilege_PrivilegeGroupReadWrite.String():
|
|
||||||
// read write belong to collection object
|
|
||||||
if !collMatch(requestObj, policyObj) {
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
_, ok := rwPrivileges[requestPrivilege]
|
|
||||||
return ok, nil
|
|
||||||
case commonpb.ObjectPrivilege_PrivilegeGroupAdmin.String():
|
|
||||||
// admin belong to global object
|
|
||||||
_, ok := adminPrivileges[requestPrivilege]
|
|
||||||
return ok, nil
|
|
||||||
default:
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@ -10,6 +10,7 @@ import (
|
|||||||
|
|
||||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||||
|
"github.com/milvus-io/milvus/internal/proxy/privilege"
|
||||||
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
|
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
|
||||||
"github.com/milvus-io/milvus/pkg/v2/util"
|
"github.com/milvus-io/milvus/pkg/v2/util"
|
||||||
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
|
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
|
||||||
@ -167,7 +168,7 @@ func TestPrivilegeInterceptor(t *testing.T) {
|
|||||||
g.Wait()
|
g.Wait()
|
||||||
|
|
||||||
assert.Panics(t, func() {
|
assert.Panics(t, func() {
|
||||||
getPolicyModel("foo")
|
privilege.GetPolicyModel("foo")
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -268,7 +269,7 @@ func TestPrivilegeGroup(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("grant ReadOnly to single collection", func(t *testing.T) {
|
t.Run("grant ReadOnly to single collection", func(t *testing.T) {
|
||||||
paramtable.Get().Save(Params.CommonCfg.AuthorizationEnabled.Key, "true")
|
paramtable.Get().Save(Params.CommonCfg.AuthorizationEnabled.Key, "true")
|
||||||
initPrivilegeGroups()
|
privilege.InitPrivilegeGroups()
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
ctx = GetContext(context.Background(), "fooo:123456")
|
ctx = GetContext(context.Background(), "fooo:123456")
|
||||||
@ -287,7 +288,7 @@ func TestPrivilegeGroup(t *testing.T) {
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
InitMetaCache(ctx, client, mgr)
|
InitMetaCache(ctx, client, mgr)
|
||||||
defer CleanPrivilegeCache()
|
defer privilege.CleanPrivilegeCache()
|
||||||
|
|
||||||
_, err = PrivilegeInterceptor(GetContext(context.Background(), "fooo:123456"), &milvuspb.QueryRequest{
|
_, err = PrivilegeInterceptor(GetContext(context.Background(), "fooo:123456"), &milvuspb.QueryRequest{
|
||||||
CollectionName: "coll1",
|
CollectionName: "coll1",
|
||||||
@ -325,7 +326,7 @@ func TestPrivilegeGroup(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("grant ReadOnly to all collection", func(t *testing.T) {
|
t.Run("grant ReadOnly to all collection", func(t *testing.T) {
|
||||||
paramtable.Get().Save(Params.CommonCfg.AuthorizationEnabled.Key, "true")
|
paramtable.Get().Save(Params.CommonCfg.AuthorizationEnabled.Key, "true")
|
||||||
initPrivilegeGroups()
|
privilege.InitPrivilegeGroups()
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
ctx = GetContext(context.Background(), "fooo:123456")
|
ctx = GetContext(context.Background(), "fooo:123456")
|
||||||
@ -344,7 +345,7 @@ func TestPrivilegeGroup(t *testing.T) {
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
InitMetaCache(ctx, client, mgr)
|
InitMetaCache(ctx, client, mgr)
|
||||||
defer CleanPrivilegeCache()
|
defer privilege.CleanPrivilegeCache()
|
||||||
|
|
||||||
_, err = PrivilegeInterceptor(GetContext(context.Background(), "fooo:123456"), &milvuspb.QueryRequest{
|
_, err = PrivilegeInterceptor(GetContext(context.Background(), "fooo:123456"), &milvuspb.QueryRequest{
|
||||||
CollectionName: "coll1",
|
CollectionName: "coll1",
|
||||||
@ -382,7 +383,7 @@ func TestPrivilegeGroup(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("grant ReadWrite to single collection", func(t *testing.T) {
|
t.Run("grant ReadWrite to single collection", func(t *testing.T) {
|
||||||
paramtable.Get().Save(Params.CommonCfg.AuthorizationEnabled.Key, "true")
|
paramtable.Get().Save(Params.CommonCfg.AuthorizationEnabled.Key, "true")
|
||||||
initPrivilegeGroups()
|
privilege.InitPrivilegeGroups()
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
ctx = GetContext(context.Background(), "fooo:123456")
|
ctx = GetContext(context.Background(), "fooo:123456")
|
||||||
@ -401,7 +402,7 @@ func TestPrivilegeGroup(t *testing.T) {
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
InitMetaCache(ctx, client, mgr)
|
InitMetaCache(ctx, client, mgr)
|
||||||
defer CleanPrivilegeCache()
|
defer privilege.CleanPrivilegeCache()
|
||||||
|
|
||||||
_, err = PrivilegeInterceptor(GetContext(context.Background(), "fooo:123456"), &milvuspb.QueryRequest{
|
_, err = PrivilegeInterceptor(GetContext(context.Background(), "fooo:123456"), &milvuspb.QueryRequest{
|
||||||
CollectionName: "coll1",
|
CollectionName: "coll1",
|
||||||
@ -485,7 +486,7 @@ func TestPrivilegeGroup(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("grant ReadWrite to all collection", func(t *testing.T) {
|
t.Run("grant ReadWrite to all collection", func(t *testing.T) {
|
||||||
paramtable.Get().Save(Params.CommonCfg.AuthorizationEnabled.Key, "true")
|
paramtable.Get().Save(Params.CommonCfg.AuthorizationEnabled.Key, "true")
|
||||||
initPrivilegeGroups()
|
privilege.InitPrivilegeGroups()
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
ctx = GetContext(context.Background(), "fooo:123456")
|
ctx = GetContext(context.Background(), "fooo:123456")
|
||||||
@ -504,7 +505,7 @@ func TestPrivilegeGroup(t *testing.T) {
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
InitMetaCache(ctx, client, mgr)
|
InitMetaCache(ctx, client, mgr)
|
||||||
defer CleanPrivilegeCache()
|
defer privilege.CleanPrivilegeCache()
|
||||||
|
|
||||||
_, err = PrivilegeInterceptor(GetContext(context.Background(), "fooo:123456"), &milvuspb.QueryRequest{
|
_, err = PrivilegeInterceptor(GetContext(context.Background(), "fooo:123456"), &milvuspb.QueryRequest{
|
||||||
CollectionName: "coll1",
|
CollectionName: "coll1",
|
||||||
@ -552,7 +553,7 @@ func TestPrivilegeGroup(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("Admin", func(t *testing.T) {
|
t.Run("Admin", func(t *testing.T) {
|
||||||
paramtable.Get().Save(Params.CommonCfg.AuthorizationEnabled.Key, "true")
|
paramtable.Get().Save(Params.CommonCfg.AuthorizationEnabled.Key, "true")
|
||||||
initPrivilegeGroups()
|
privilege.InitPrivilegeGroups()
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
ctx = GetContext(context.Background(), "fooo:123456")
|
ctx = GetContext(context.Background(), "fooo:123456")
|
||||||
@ -571,7 +572,7 @@ func TestPrivilegeGroup(t *testing.T) {
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
InitMetaCache(ctx, client, mgr)
|
InitMetaCache(ctx, client, mgr)
|
||||||
defer CleanPrivilegeCache()
|
defer privilege.CleanPrivilegeCache()
|
||||||
|
|
||||||
_, err = PrivilegeInterceptor(GetContext(context.Background(), "fooo:123456"), &milvuspb.QueryRequest{})
|
_, err = PrivilegeInterceptor(GetContext(context.Background(), "fooo:123456"), &milvuspb.QueryRequest{})
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
@ -593,7 +594,7 @@ func TestPrivilegeGroup(t *testing.T) {
|
|||||||
func TestBuiltinPrivilegeGroup(t *testing.T) {
|
func TestBuiltinPrivilegeGroup(t *testing.T) {
|
||||||
t.Run("ClusterAdmin", func(t *testing.T) {
|
t.Run("ClusterAdmin", func(t *testing.T) {
|
||||||
paramtable.Get().Save(Params.CommonCfg.AuthorizationEnabled.Key, "true")
|
paramtable.Get().Save(Params.CommonCfg.AuthorizationEnabled.Key, "true")
|
||||||
initPrivilegeGroups()
|
privilege.InitPrivilegeGroups()
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
ctx := GetContext(context.Background(), "fooo:123456")
|
ctx := GetContext(context.Background(), "fooo:123456")
|
||||||
@ -615,7 +616,7 @@ func TestBuiltinPrivilegeGroup(t *testing.T) {
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
InitMetaCache(ctx, client, mgr)
|
InitMetaCache(ctx, client, mgr)
|
||||||
defer CleanPrivilegeCache()
|
defer privilege.CleanPrivilegeCache()
|
||||||
|
|
||||||
_, err = PrivilegeInterceptor(GetContext(context.Background(), "fooo:123456"), &milvuspb.SelectUserRequest{})
|
_, err = PrivilegeInterceptor(GetContext(context.Background(), "fooo:123456"), &milvuspb.SelectUserRequest{})
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|||||||
@ -51,6 +51,7 @@ import (
|
|||||||
grpcstreamingnode "github.com/milvus-io/milvus/internal/distributed/streamingnode"
|
grpcstreamingnode "github.com/milvus-io/milvus/internal/distributed/streamingnode"
|
||||||
"github.com/milvus-io/milvus/internal/json"
|
"github.com/milvus-io/milvus/internal/json"
|
||||||
"github.com/milvus-io/milvus/internal/mocks"
|
"github.com/milvus-io/milvus/internal/mocks"
|
||||||
|
"github.com/milvus-io/milvus/internal/proxy/privilege"
|
||||||
"github.com/milvus-io/milvus/internal/util/componentutil"
|
"github.com/milvus-io/milvus/internal/util/componentutil"
|
||||||
"github.com/milvus-io/milvus/internal/util/dependency"
|
"github.com/milvus-io/milvus/internal/util/dependency"
|
||||||
"github.com/milvus-io/milvus/internal/util/sessionutil"
|
"github.com/milvus-io/milvus/internal/util/sessionutil"
|
||||||
@ -2837,7 +2838,7 @@ func TestProxy(t *testing.T) {
|
|||||||
getResp, err := rootCoordClient.GetCredential(ctx, getCredentialReq)
|
getResp, err := rootCoordClient.GetCredential(ctx, getCredentialReq)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, commonpb.ErrorCode_Success, getResp.GetStatus().GetErrorCode())
|
assert.Equal(t, commonpb.ErrorCode_Success, getResp.GetStatus().GetErrorCode())
|
||||||
assert.True(t, passwordVerify(ctx, username, newPassword, globalMetaCache))
|
assert.True(t, passwordVerify(ctx, username, newPassword, privilege.GetPrivilegeCache()))
|
||||||
|
|
||||||
getCredentialReq.Username = "("
|
getCredentialReq.Username = "("
|
||||||
getResp, err = rootCoordClient.GetCredential(ctx, getCredentialReq)
|
getResp, err = rootCoordClient.GetCredential(ctx, getCredentialReq)
|
||||||
|
|||||||
@ -38,6 +38,7 @@ import (
|
|||||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||||
"github.com/milvus-io/milvus/internal/json"
|
"github.com/milvus-io/milvus/internal/json"
|
||||||
"github.com/milvus-io/milvus/internal/parser/planparserv2"
|
"github.com/milvus-io/milvus/internal/parser/planparserv2"
|
||||||
|
"github.com/milvus-io/milvus/internal/proxy/privilege"
|
||||||
"github.com/milvus-io/milvus/internal/types"
|
"github.com/milvus-io/milvus/internal/types"
|
||||||
"github.com/milvus-io/milvus/internal/util/analyzer"
|
"github.com/milvus-io/milvus/internal/util/analyzer"
|
||||||
"github.com/milvus-io/milvus/internal/util/function/embedding"
|
"github.com/milvus-io/milvus/internal/util/function/embedding"
|
||||||
@ -1464,14 +1465,15 @@ func AppendUserInfoForRPC(ctx context.Context) context.Context {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func GetRole(username string) ([]string, error) {
|
func GetRole(username string) ([]string, error) {
|
||||||
if globalMetaCache == nil {
|
privCache := privilege.GetPrivilegeCache()
|
||||||
|
if privCache == nil {
|
||||||
return []string{}, merr.WrapErrServiceUnavailable("internal: Milvus Proxy is not ready yet. please wait")
|
return []string{}, merr.WrapErrServiceUnavailable("internal: Milvus Proxy is not ready yet. please wait")
|
||||||
}
|
}
|
||||||
return globalMetaCache.GetUserRole(username), nil
|
return privCache.GetUserRole(username), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func PasswordVerify(ctx context.Context, username, rawPwd string) bool {
|
func PasswordVerify(ctx context.Context, username, rawPwd string) bool {
|
||||||
return passwordVerify(ctx, username, rawPwd, globalMetaCache)
|
return passwordVerify(ctx, username, rawPwd, privilege.GetPrivilegeCache())
|
||||||
}
|
}
|
||||||
|
|
||||||
func VerifyAPIKey(rawToken string) (string, error) {
|
func VerifyAPIKey(rawToken string) (string, error) {
|
||||||
@ -1485,10 +1487,10 @@ func VerifyAPIKey(rawToken string) (string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// PasswordVerify verify password
|
// PasswordVerify verify password
|
||||||
func passwordVerify(ctx context.Context, username, rawPwd string, globalMetaCache Cache) bool {
|
func passwordVerify(ctx context.Context, username, rawPwd string, privilegeCache privilege.PrivilegeCache) bool {
|
||||||
// it represents the cache miss if Sha256Password is empty within credInfo, which shall be updated first connection.
|
// it represents the cache miss if Sha256Password is empty within credInfo, which shall be updated first connection.
|
||||||
// meanwhile, generating Sha256Password depends on raw password and encrypted password will not cache.
|
// meanwhile, generating Sha256Password depends on raw password and encrypted password will not cache.
|
||||||
credInfo, err := globalMetaCache.GetCredentialInfo(ctx, username)
|
credInfo, err := privilege.GetPrivilegeCache().GetCredentialInfo(ctx, username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Ctx(ctx).Error("found no credential", zap.String("username", username), zap.Error(err))
|
log.Ctx(ctx).Error("found no credential", zap.String("username", username), zap.Error(err))
|
||||||
return false
|
return false
|
||||||
@ -1509,7 +1511,7 @@ func passwordVerify(ctx context.Context, username, rawPwd string, globalMetaCach
|
|||||||
// update cache after miss cache
|
// update cache after miss cache
|
||||||
credInfo.Sha256Password = sha256Pwd
|
credInfo.Sha256Password = sha256Pwd
|
||||||
log.Ctx(ctx).Debug("get credential miss cache, update cache with", zap.Any("credential", credInfo))
|
log.Ctx(ctx).Debug("get credential miss cache, update cache with", zap.Any("credential", credInfo))
|
||||||
globalMetaCache.UpdateCredential(credInfo)
|
privilegeCache.UpdateCredential(credInfo)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -37,6 +37,7 @@ import (
|
|||||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||||
"github.com/milvus-io/milvus/internal/json"
|
"github.com/milvus-io/milvus/internal/json"
|
||||||
"github.com/milvus-io/milvus/internal/mocks"
|
"github.com/milvus-io/milvus/internal/mocks"
|
||||||
|
"github.com/milvus-io/milvus/internal/proxy/privilege"
|
||||||
"github.com/milvus-io/milvus/internal/util/function/embedding"
|
"github.com/milvus-io/milvus/internal/util/function/embedding"
|
||||||
"github.com/milvus-io/milvus/pkg/v2/common"
|
"github.com/milvus-io/milvus/pkg/v2/common"
|
||||||
"github.com/milvus-io/milvus/pkg/v2/log"
|
"github.com/milvus-io/milvus/pkg/v2/log"
|
||||||
@ -789,19 +790,21 @@ func TestGetCurDBNameFromContext(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestGetRole(t *testing.T) {
|
func TestGetRole(t *testing.T) {
|
||||||
globalMetaCache = nil
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
privilege.ResetPrivilegeCacheForTest()
|
||||||
_, err := GetRole("foo")
|
_, err := GetRole("foo")
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
mockCache := NewMockCache(t)
|
|
||||||
mockCache.On("GetUserRole",
|
mixcoord := mocks.NewMockMixCoordClient(t)
|
||||||
mock.AnythingOfType("string"),
|
mixcoord.EXPECT().ListPolicy(mock.Anything, mock.Anything).Return(&internalpb.ListPolicyResponse{
|
||||||
).Return(func(username string) []string {
|
Status: merr.Success(),
|
||||||
if username == "root" {
|
UserRoles: []string{"root/role1", "root/admin", "root/role2", "foo/role1"},
|
||||||
return []string{"role1", "admin", "role2"}
|
}, nil).Times(1)
|
||||||
}
|
|
||||||
return []string{"role1"}
|
privilege.InitPrivilegeCache(ctx, mixcoord)
|
||||||
})
|
|
||||||
globalMetaCache = mockCache
|
|
||||||
roles, err := GetRole("root")
|
roles, err := GetRole("root")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, 3, len(roles))
|
assert.Equal(t, 3, len(roles))
|
||||||
@ -812,11 +815,12 @@ func TestGetRole(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestPasswordVerify(t *testing.T) {
|
func TestPasswordVerify(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
username := "user-test00"
|
username := "user-test00"
|
||||||
password := "PasswordVerify"
|
password := "PasswordVerify"
|
||||||
|
|
||||||
// credential does not exist within cache
|
|
||||||
credCache := make(map[string]*internalpb.CredentialInfo, 0)
|
|
||||||
invokedCount := 0
|
invokedCount := 0
|
||||||
|
|
||||||
mockedRootCoord := NewMixCoordMock()
|
mockedRootCoord := NewMixCoordMock()
|
||||||
@ -825,34 +829,36 @@ func TestPasswordVerify(t *testing.T) {
|
|||||||
return nil, errors.New("get cred not found credential")
|
return nil, errors.New("get cred not found credential")
|
||||||
}
|
}
|
||||||
|
|
||||||
metaCache := &MetaCache{
|
privilege.InitPrivilegeCache(ctx, mockedRootCoord)
|
||||||
credMap: credCache,
|
privilegeCache := privilege.GetPrivilegeCache()
|
||||||
mixCoord: mockedRootCoord,
|
assert.False(t, passwordVerify(ctx, username, password, privilegeCache))
|
||||||
}
|
|
||||||
ret, ok := credCache[username]
|
|
||||||
assert.False(t, ok)
|
|
||||||
assert.Nil(t, ret)
|
|
||||||
assert.False(t, passwordVerify(context.TODO(), username, password, metaCache))
|
|
||||||
assert.Equal(t, 1, invokedCount)
|
assert.Equal(t, 1, invokedCount)
|
||||||
|
|
||||||
// Sha256Password has not been filled into cache during establish connection firstly
|
// Sha256Password has not been filled into cache during establish connection firstly
|
||||||
encryptedPwd, err := crypto.PasswordEncrypt(password)
|
encryptedPwd, err := crypto.PasswordEncrypt(password)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
credCache[username] = &internalpb.CredentialInfo{
|
privilegeCache.RemoveCredential(username)
|
||||||
Username: username,
|
mockedRootCoord.GetGetCredentialFunc = func(ctx context.Context, req *rootcoordpb.GetCredentialRequest, opts ...grpc.CallOption) (*rootcoordpb.GetCredentialResponse, error) {
|
||||||
EncryptedPassword: encryptedPwd,
|
invokedCount++
|
||||||
|
return &rootcoordpb.GetCredentialResponse{
|
||||||
|
Status: merr.Success(),
|
||||||
|
Username: username,
|
||||||
|
Password: encryptedPwd,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
assert.True(t, passwordVerify(context.TODO(), username, password, metaCache))
|
|
||||||
ret, ok = credCache[username]
|
assert.True(t, passwordVerify(ctx, username, password, privilegeCache))
|
||||||
assert.True(t, ok)
|
|
||||||
|
ret, err := privilegeCache.GetCredentialInfo(ctx, username)
|
||||||
|
assert.NoError(t, err)
|
||||||
assert.NotNil(t, ret)
|
assert.NotNil(t, ret)
|
||||||
assert.Equal(t, username, ret.Username)
|
assert.Equal(t, username, ret.Username)
|
||||||
assert.NotNil(t, ret.Sha256Password)
|
assert.NotNil(t, ret.Sha256Password)
|
||||||
assert.Equal(t, 1, invokedCount)
|
assert.Equal(t, 2, invokedCount)
|
||||||
|
|
||||||
// Sha256Password already exists within cache
|
// Sha256Password already exists within cache
|
||||||
assert.True(t, passwordVerify(context.TODO(), username, password, metaCache))
|
assert.True(t, passwordVerify(ctx, username, password, privilegeCache))
|
||||||
assert.Equal(t, 1, invokedCount)
|
assert.Equal(t, 2, invokedCount)
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_isCollectionIsLoaded(t *testing.T) {
|
func Test_isCollectionIsLoaded(t *testing.T) {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user