enhance: Refactor privilege management by extracting privilege cache into separate package (#44762)

Related to #44761

This commit refactors the privilege management system in the proxy
component by:

1. **Separation of Concerns**: Extracts privilege-related functionality
from MetaCache into a dedicated `internal/proxy/privilege` package,
improving code organization and maintainability.

2. **New Package Structure**: Creates `internal/proxy/privilege/` with:
   - `cache.go`: Core privilege cache implementation (PrivilegeCache)
   - `result_cache.go`: Privilege enforcement result caching
   - `model.go`: Casbin model and policy enforcement functions
   - `meta_cache_adapter.go`: Casbin adapter for MetaCache integration
   - Corresponding test files and mock implementations

3. **MetaCache Simplification**: Removes privilege and credential
management methods from MetaCache interface and implementation:
   - Removed: GetCredentialInfo, RemoveCredential, UpdateCredential
- Removed: GetPrivilegeInfo, GetUserRole, RefreshPolicyInfo,
InitPolicyInfo
   - Deleted: meta_cache_adapter.go, privilege_cache.go and their tests

4. **Updated References**: Updates all callsites to use the new
privilegeCache global:
- Authentication interceptor now uses privilegeCache for password
verification
- Credential cache operations (InvalidateCredentialCache,
UpdateCredentialCache, UpdateCredential) now use privilegeCache
- Policy refresh operations (RefreshPolicyInfoCache) now use
privilegeCache
- Privilege interceptor uses new privilege.GetEnforcer() and privilege
result cache

5. **Improved API**: Renames cache functions for clarity:
   - GetPrivilegeCache → GetResultCache
   - SetPrivilegeCache → SetResultCache
   - CleanPrivilegeCache → CleanResultCache

This refactoring makes the codebase more modular, separates privilege
management concerns from general metadata caching, and provides a
clearer API for privilege enforcement operations.

---------

Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
This commit is contained in:
congqixia 2025-10-13 11:15:58 +08:00 committed by GitHub
parent 369c6eb206
commit f5f053f1d2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
21 changed files with 931 additions and 421 deletions

View File

@ -13,6 +13,7 @@ import (
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/internal/proxy/privilege"
"github.com/milvus-io/milvus/internal/util/hookutil" "github.com/milvus-io/milvus/internal/util/hookutil"
"github.com/milvus-io/milvus/pkg/v2/log" "github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/metrics" "github.com/milvus-io/milvus/pkg/v2/metrics"
@ -111,7 +112,7 @@ func AuthenticationInterceptor(ctx context.Context) (context.Context, error) {
} else { } else {
// username+password authentication // username+password authentication
username, password := parseMD(rawToken) username, password := parseMD(rawToken)
if !passwordVerify(ctx, username, password, globalMetaCache) { if !passwordVerify(ctx, username, password, privilege.GetPrivilegeCache()) {
log.Warn("fail to verify password", zap.String("username", username)) log.Warn("fail to verify password", zap.String("username", username))
// NOTE: don't use the merr, because it will cause the wrong retry behavior in the sdk // NOTE: don't use the merr, because it will cause the wrong retry behavior in the sdk
return nil, status.Error(codes.Unauthenticated, "auth check failure, please check username and password are correct") return nil, status.Error(codes.Unauthenticated, "auth check failure, please check username and password are correct")

View File

@ -9,6 +9,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
"github.com/milvus-io/milvus/internal/proxy/privilege"
"github.com/milvus-io/milvus/internal/util/hookutil" "github.com/milvus-io/milvus/internal/util/hookutil"
"github.com/milvus-io/milvus/pkg/v2/util" "github.com/milvus-io/milvus/pkg/v2/util"
"github.com/milvus-io/milvus/pkg/v2/util/crypto" "github.com/milvus-io/milvus/pkg/v2/util/crypto"
@ -27,7 +28,7 @@ func TestValidAuth(t *testing.T) {
if username == "" || password == "" { if username == "" || password == "" {
return false return false
} }
return passwordVerify(ctx, username, password, globalMetaCache) return passwordVerify(ctx, username, password, privilege.GetPrivilegeCache())
} }
ctx := context.Background() ctx := context.Background()

View File

@ -43,6 +43,7 @@ import (
"github.com/milvus-io/milvus/internal/distributed/streaming" "github.com/milvus-io/milvus/internal/distributed/streaming"
"github.com/milvus-io/milvus/internal/http" "github.com/milvus-io/milvus/internal/http"
"github.com/milvus-io/milvus/internal/proxy/connection" "github.com/milvus-io/milvus/internal/proxy/connection"
"github.com/milvus-io/milvus/internal/proxy/privilege"
"github.com/milvus-io/milvus/internal/proxy/replicate" "github.com/milvus-io/milvus/internal/proxy/replicate"
"github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/analyzer" "github.com/milvus-io/milvus/internal/util/analyzer"
@ -4689,8 +4690,9 @@ func (node *Proxy) InvalidateCredentialCache(ctx context.Context, request *proxy
} }
username := request.Username username := request.Username
if globalMetaCache != nil { priCache := privilege.GetPrivilegeCache()
globalMetaCache.RemoveCredential(username) // no need to return error, though credential may be not cached if priCache != nil {
priCache.RemoveCredential(username) // no need to return error, though credential may be not cached
} }
log.Debug("complete to invalidate credential cache") log.Debug("complete to invalidate credential cache")
@ -4715,8 +4717,9 @@ func (node *Proxy) UpdateCredentialCache(ctx context.Context, request *proxypb.U
Username: request.Username, Username: request.Username,
Sha256Password: request.Password, Sha256Password: request.Password,
} }
if globalMetaCache != nil { priCache := privilege.GetPrivilegeCache()
globalMetaCache.UpdateCredential(credInfo) // no need to return error, though credential may be not cached if priCache != nil {
priCache.UpdateCredential(credInfo) // no need to return error, though credential may be not cached
} }
log.Debug("complete to update credential cache") log.Debug("complete to update credential cache")
@ -4820,7 +4823,7 @@ func (node *Proxy) UpdateCredential(ctx context.Context, req *milvuspb.UpdateCre
} }
} }
if !skipPasswordVerify && !passwordVerify(ctx, req.Username, rawOldPassword, globalMetaCache) { if !skipPasswordVerify && !passwordVerify(ctx, req.Username, rawOldPassword, privilege.GetPrivilegeCache()) {
err := merr.WrapErrPrivilegeNotAuthenticated("old password not correct for %s", req.GetUsername()) err := merr.WrapErrPrivilegeNotAuthenticated("old password not correct for %s", req.GetUsername())
return merr.Status(err), nil return merr.Status(err), nil
} }
@ -5347,8 +5350,9 @@ func (node *Proxy) RefreshPolicyInfoCache(ctx context.Context, req *proxypb.Refr
return merr.Status(err), nil return merr.Status(err), nil
} }
if globalMetaCache != nil { priCache := privilege.GetPrivilegeCache()
err := globalMetaCache.RefreshPolicyInfo(typeutil.CacheOp{ if priCache != nil {
err := priCache.RefreshPolicyInfo(typeutil.CacheOp{
OpType: typeutil.CacheOpType(req.OpType), OpType: typeutil.CacheOpType(req.OpType),
OpKey: req.OpKey, OpKey: req.OpKey,
}) })

View File

@ -24,7 +24,6 @@ import (
"strings" "strings"
"sync" "sync"
"github.com/cockroachdb/errors"
"github.com/samber/lo" "github.com/samber/lo"
"go.uber.org/atomic" "go.uber.org/atomic"
"go.uber.org/zap" "go.uber.org/zap"
@ -32,6 +31,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/proxy/privilege"
"github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/pkg/v2/common" "github.com/milvus-io/milvus/pkg/v2/common"
"github.com/milvus-io/milvus/pkg/v2/log" "github.com/milvus-io/milvus/pkg/v2/log"
@ -39,11 +39,9 @@ import (
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb" "github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
"github.com/milvus-io/milvus/pkg/v2/proto/querypb" "github.com/milvus-io/milvus/pkg/v2/proto/querypb"
"github.com/milvus-io/milvus/pkg/v2/proto/rootcoordpb" "github.com/milvus-io/milvus/pkg/v2/proto/rootcoordpb"
"github.com/milvus-io/milvus/pkg/v2/util"
"github.com/milvus-io/milvus/pkg/v2/util/commonpbutil" "github.com/milvus-io/milvus/pkg/v2/util/commonpbutil"
"github.com/milvus-io/milvus/pkg/v2/util/conc" "github.com/milvus-io/milvus/pkg/v2/util/conc"
"github.com/milvus-io/milvus/pkg/v2/util/expr" "github.com/milvus-io/milvus/pkg/v2/util/expr"
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
"github.com/milvus-io/milvus/pkg/v2/util/merr" "github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable" "github.com/milvus-io/milvus/pkg/v2/util/paramtable"
"github.com/milvus-io/milvus/pkg/v2/util/timerecord" "github.com/milvus-io/milvus/pkg/v2/util/timerecord"
@ -79,14 +77,14 @@ type Cache interface {
RemoveCollectionsByID(ctx context.Context, collectionID UniqueID, version uint64, removeVersion bool) []string RemoveCollectionsByID(ctx context.Context, collectionID UniqueID, version uint64, removeVersion bool) []string
// GetCredentialInfo operate credential cache // GetCredentialInfo operate credential cache
GetCredentialInfo(ctx context.Context, username string) (*internalpb.CredentialInfo, error) // GetCredentialInfo(ctx context.Context, username string) (*internalpb.CredentialInfo, error)
RemoveCredential(username string) // RemoveCredential(username string)
UpdateCredential(credInfo *internalpb.CredentialInfo) // UpdateCredential(credInfo *internalpb.CredentialInfo)
GetPrivilegeInfo(ctx context.Context) []string // GetPrivilegeInfo(ctx context.Context) []string
GetUserRole(username string) []string // GetUserRole(username string) []string
RefreshPolicyInfo(op typeutil.CacheOp) error // RefreshPolicyInfo(op typeutil.CacheOp) error
InitPolicyInfo(info []string, userRoles []string) // InitPolicyInfo(info []string, userRoles []string)
RemoveDatabase(ctx context.Context, database string) RemoveDatabase(ctx context.Context, database string)
HasDatabase(ctx context.Context, database string) bool HasDatabase(ctx context.Context, database string) bool
@ -382,14 +380,12 @@ func InitMetaCache(ctx context.Context, mixCoord types.MixCoordClient, shardMgr
} }
expr.Register("cache", globalMetaCache) expr.Register("cache", globalMetaCache)
// The privilege info is a little more. And to get this info, the query operation of involving multiple table queries is required. err = privilege.InitPrivilegeCache(ctx, mixCoord)
resp, err := mixCoord.ListPolicy(ctx, &internalpb.ListPolicyRequest{}) if err != nil {
if err = merr.CheckRPCCall(resp, err); err != nil { log.Error("failed to init privilege cache", zap.Error(err))
log.Error("fail to init meta cache", zap.Error(err))
return err return err
} }
globalMetaCache.InitPolicyInfo(resp.PolicyInfos, resp.UserRoles)
log.Info("success to init meta cache", zap.Strings("policy_infos", resp.PolicyInfos))
return nil return nil
} }
@ -902,55 +898,6 @@ func (m *MetaCache) RemoveCollectionsByID(ctx context.Context, collectionID Uniq
return collNames return collNames
} }
// GetCredentialInfo returns the credential related to provided username
// If the cache missed, proxy will try to fetch from storage
func (m *MetaCache) GetCredentialInfo(ctx context.Context, username string) (*internalpb.CredentialInfo, error) {
m.credMut.RLock()
var credInfo *internalpb.CredentialInfo
credInfo, ok := m.credMap[username]
m.credMut.RUnlock()
if !ok {
req := &rootcoordpb.GetCredentialRequest{
Base: commonpbutil.NewMsgBase(
commonpbutil.WithMsgType(commonpb.MsgType_GetCredential),
),
Username: username,
}
resp, err := m.mixCoord.GetCredential(ctx, req)
if err != nil {
return &internalpb.CredentialInfo{}, err
}
credInfo = &internalpb.CredentialInfo{
Username: resp.Username,
EncryptedPassword: resp.Password,
}
}
return credInfo, nil
}
func (m *MetaCache) RemoveCredential(username string) {
m.credMut.Lock()
defer m.credMut.Unlock()
// delete pair in credMap
delete(m.credMap, username)
}
func (m *MetaCache) UpdateCredential(credInfo *internalpb.CredentialInfo) {
m.credMut.Lock()
defer m.credMut.Unlock()
username := credInfo.Username
_, ok := m.credMap[username]
if !ok {
m.credMap[username] = &internalpb.CredentialInfo{}
}
// Do not cache encrypted password content
m.credMap[username].Username = username
m.credMap[username].Sha256Password = credInfo.Sha256Password
}
func (m *MetaCache) GetShard(ctx context.Context, withCache bool, database, collectionName string, collectionID int64, channel string) ([]nodeInfo, error) { func (m *MetaCache) GetShard(ctx context.Context, withCache bool, database, collectionName string, collectionID int64, channel string) ([]nodeInfo, error) {
method := "GetShard" method := "GetShard"
// check cache first // check cache first
@ -1127,131 +1074,6 @@ func (m *MetaCache) InvalidateShardLeaderCache(collections []int64) {
} }
} }
func (m *MetaCache) InitPolicyInfo(info []string, userRoles []string) {
defer func() {
err := getEnforcer().LoadPolicy()
if err != nil {
log.Error("failed to load policy after RefreshPolicyInfo", zap.Error(err))
}
CleanPrivilegeCache()
}()
m.mu.Lock()
defer m.mu.Unlock()
m.unsafeInitPolicyInfo(info, userRoles)
}
func (m *MetaCache) unsafeInitPolicyInfo(info []string, userRoles []string) {
m.privilegeInfos = util.StringSet(info)
for _, userRole := range userRoles {
user, role, err := funcutil.DecodeUserRoleCache(userRole)
if err != nil {
log.Warn("invalid user-role key", zap.String("user-role", userRole), zap.Error(err))
continue
}
if m.userToRoles[user] == nil {
m.userToRoles[user] = make(map[string]struct{})
}
m.userToRoles[user][role] = struct{}{}
}
}
func (m *MetaCache) GetPrivilegeInfo(ctx context.Context) []string {
m.mu.RLock()
defer m.mu.RUnlock()
return util.StringList(m.privilegeInfos)
}
func (m *MetaCache) GetUserRole(user string) []string {
m.mu.RLock()
defer m.mu.RUnlock()
return util.StringList(m.userToRoles[user])
}
func (m *MetaCache) RefreshPolicyInfo(op typeutil.CacheOp) (err error) {
defer func() {
if err == nil {
le := getEnforcer().LoadPolicy()
if le != nil {
log.Error("failed to load policy after RefreshPolicyInfo", zap.Error(le))
}
CleanPrivilegeCache()
}
}()
if op.OpType != typeutil.CacheRefresh {
m.mu.Lock()
defer m.mu.Unlock()
if op.OpKey == "" {
return errors.New("empty op key")
}
}
switch op.OpType {
case typeutil.CacheGrantPrivilege:
keys := funcutil.PrivilegesForPolicy(op.OpKey)
for _, key := range keys {
m.privilegeInfos[key] = struct{}{}
}
case typeutil.CacheRevokePrivilege:
keys := funcutil.PrivilegesForPolicy(op.OpKey)
for _, key := range keys {
delete(m.privilegeInfos, key)
}
case typeutil.CacheAddUserToRole:
user, role, err := funcutil.DecodeUserRoleCache(op.OpKey)
if err != nil {
return fmt.Errorf("invalid opKey, fail to decode, op_type: %d, op_key: %s", int(op.OpType), op.OpKey)
}
if m.userToRoles[user] == nil {
m.userToRoles[user] = make(map[string]struct{})
}
m.userToRoles[user][role] = struct{}{}
case typeutil.CacheRemoveUserFromRole:
user, role, err := funcutil.DecodeUserRoleCache(op.OpKey)
if err != nil {
return fmt.Errorf("invalid opKey, fail to decode, op_type: %d, op_key: %s", int(op.OpType), op.OpKey)
}
if m.userToRoles[user] != nil {
delete(m.userToRoles[user], role)
}
case typeutil.CacheDeleteUser:
delete(m.userToRoles, op.OpKey)
case typeutil.CacheDropRole:
for user := range m.userToRoles {
delete(m.userToRoles[user], op.OpKey)
}
for policy := range m.privilegeInfos {
if funcutil.PolicyCheckerWithRole(policy, op.OpKey) {
delete(m.privilegeInfos, policy)
}
}
case typeutil.CacheRefresh:
resp, err := m.mixCoord.ListPolicy(context.Background(), &internalpb.ListPolicyRequest{})
if err != nil {
log.Error("fail to init meta cache", zap.Error(err))
return err
}
if !merr.Ok(resp.GetStatus()) {
log.Error("fail to init meta cache",
zap.String("error_code", resp.GetStatus().GetErrorCode().String()),
zap.String("reason", resp.GetStatus().GetReason()))
return merr.Error(resp.Status)
}
m.mu.Lock()
defer m.mu.Unlock()
m.userToRoles = make(map[string]map[string]struct{})
m.privilegeInfos = make(map[string]struct{})
m.unsafeInitPolicyInfo(resp.PolicyInfos, resp.UserRoles)
default:
return fmt.Errorf("invalid opType, op_type: %d, op_key: %s", int(op.OpType), op.OpKey)
}
return nil
}
func (m *MetaCache) RemoveDatabase(ctx context.Context, database string) { func (m *MetaCache) RemoveDatabase(ctx context.Context, database string) {
log.Ctx(ctx).Debug("remove database", zap.String("name", database)) log.Ctx(ctx).Debug("remove database", zap.String("name", database))
m.mu.Lock() m.mu.Lock()

View File

@ -35,6 +35,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/mocks" "github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/internal/proxy/privilege"
"github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/pkg/v2/common" "github.com/milvus-io/milvus/pkg/v2/common"
"github.com/milvus-io/milvus/pkg/v2/proto/datapb" "github.com/milvus-io/milvus/pkg/v2/proto/datapb"
@ -1510,9 +1511,9 @@ func TestMetaCache_PolicyInfo(t *testing.T) {
} }
err := InitMetaCache(context.Background(), client, mgr) err := InitMetaCache(context.Background(), client, mgr)
assert.NoError(t, err) assert.NoError(t, err)
policyInfos := globalMetaCache.GetPrivilegeInfo(context.Background()) policyInfos := privilege.GetPrivilegeCache().GetPrivilegeInfo(context.Background())
assert.Equal(t, 3, len(policyInfos)) assert.Equal(t, 3, len(policyInfos))
roles := globalMetaCache.GetUserRole("foo") roles := privilege.GetPrivilegeCache().GetUserRole("foo")
assert.Equal(t, 2, len(roles)) assert.Equal(t, 2, len(roles))
}) })
@ -1527,29 +1528,29 @@ func TestMetaCache_PolicyInfo(t *testing.T) {
err := InitMetaCache(context.Background(), client, mgr) err := InitMetaCache(context.Background(), client, mgr)
assert.NoError(t, err) assert.NoError(t, err)
err = globalMetaCache.RefreshPolicyInfo(typeutil.CacheOp{OpType: typeutil.CacheGrantPrivilege, OpKey: "policyX"}) err = privilege.GetPrivilegeCache().RefreshPolicyInfo(typeutil.CacheOp{OpType: typeutil.CacheGrantPrivilege, OpKey: "policyX"})
assert.NoError(t, err) assert.NoError(t, err)
policyInfos := globalMetaCache.GetPrivilegeInfo(context.Background()) policyInfos := privilege.GetPrivilegeCache().GetPrivilegeInfo(context.Background())
assert.Equal(t, 4, len(policyInfos)) assert.Equal(t, 4, len(policyInfos))
err = globalMetaCache.RefreshPolicyInfo(typeutil.CacheOp{OpType: typeutil.CacheRevokePrivilege, OpKey: "policyX"}) err = privilege.GetPrivilegeCache().RefreshPolicyInfo(typeutil.CacheOp{OpType: typeutil.CacheRevokePrivilege, OpKey: "policyX"})
assert.NoError(t, err) assert.NoError(t, err)
policyInfos = globalMetaCache.GetPrivilegeInfo(context.Background()) policyInfos = privilege.GetPrivilegeCache().GetPrivilegeInfo(context.Background())
assert.Equal(t, 3, len(policyInfos)) assert.Equal(t, 3, len(policyInfos))
err = globalMetaCache.RefreshPolicyInfo(typeutil.CacheOp{OpType: typeutil.CacheAddUserToRole, OpKey: funcutil.EncodeUserRoleCache("foo", "role3")}) err = privilege.GetPrivilegeCache().RefreshPolicyInfo(typeutil.CacheOp{OpType: typeutil.CacheAddUserToRole, OpKey: funcutil.EncodeUserRoleCache("foo", "role3")})
assert.NoError(t, err) assert.NoError(t, err)
roles := globalMetaCache.GetUserRole("foo") roles := privilege.GetPrivilegeCache().GetUserRole("foo")
assert.Equal(t, 3, len(roles)) assert.Equal(t, 3, len(roles))
err = globalMetaCache.RefreshPolicyInfo(typeutil.CacheOp{OpType: typeutil.CacheRemoveUserFromRole, OpKey: funcutil.EncodeUserRoleCache("foo", "role3")}) err = privilege.GetPrivilegeCache().RefreshPolicyInfo(typeutil.CacheOp{OpType: typeutil.CacheRemoveUserFromRole, OpKey: funcutil.EncodeUserRoleCache("foo", "role3")})
assert.NoError(t, err) assert.NoError(t, err)
roles = globalMetaCache.GetUserRole("foo") roles = privilege.GetPrivilegeCache().GetUserRole("foo")
assert.Equal(t, 2, len(roles)) assert.Equal(t, 2, len(roles))
err = globalMetaCache.RefreshPolicyInfo(typeutil.CacheOp{OpType: typeutil.CacheGrantPrivilege, OpKey: ""}) err = privilege.GetPrivilegeCache().RefreshPolicyInfo(typeutil.CacheOp{OpType: typeutil.CacheGrantPrivilege, OpKey: ""})
assert.Error(t, err) assert.Error(t, err)
err = globalMetaCache.RefreshPolicyInfo(typeutil.CacheOp{OpType: 100, OpKey: "policyX"}) err = privilege.GetPrivilegeCache().RefreshPolicyInfo(typeutil.CacheOp{OpType: 100, OpKey: "policyX"})
assert.Error(t, err) assert.Error(t, err)
}) })
@ -1568,18 +1569,18 @@ func TestMetaCache_PolicyInfo(t *testing.T) {
err := InitMetaCache(context.Background(), client, mgr) err := InitMetaCache(context.Background(), client, mgr)
assert.NoError(t, err) assert.NoError(t, err)
err = globalMetaCache.RefreshPolicyInfo(typeutil.CacheOp{OpType: typeutil.CacheDeleteUser, OpKey: "foo"}) err = privilege.GetPrivilegeCache().RefreshPolicyInfo(typeutil.CacheOp{OpType: typeutil.CacheDeleteUser, OpKey: "foo"})
assert.NoError(t, err) assert.NoError(t, err)
roles := globalMetaCache.GetUserRole("foo") roles := privilege.GetPrivilegeCache().GetUserRole("foo")
assert.Len(t, roles, 0) assert.Len(t, roles, 0)
roles = globalMetaCache.GetUserRole("foo2") roles = privilege.GetPrivilegeCache().GetUserRole("foo2")
assert.Len(t, roles, 2) assert.Len(t, roles, 2)
err = globalMetaCache.RefreshPolicyInfo(typeutil.CacheOp{OpType: typeutil.CacheDropRole, OpKey: "role2"}) err = privilege.GetPrivilegeCache().RefreshPolicyInfo(typeutil.CacheOp{OpType: typeutil.CacheDropRole, OpKey: "role2"})
assert.NoError(t, err) assert.NoError(t, err)
roles = globalMetaCache.GetUserRole("foo2") roles = privilege.GetPrivilegeCache().GetUserRole("foo2")
assert.Len(t, roles, 1) assert.Len(t, roles, 1)
assert.Equal(t, "role3", roles[0]) assert.Equal(t, "role3", roles[0])
@ -1590,9 +1591,9 @@ func TestMetaCache_PolicyInfo(t *testing.T) {
UserRoles: []string{funcutil.EncodeUserRoleCache("foo", "role1"), funcutil.EncodeUserRoleCache("foo", "role2"), funcutil.EncodeUserRoleCache("foo2", "role2"), funcutil.EncodeUserRoleCache("foo2", "role3")}, UserRoles: []string{funcutil.EncodeUserRoleCache("foo", "role1"), funcutil.EncodeUserRoleCache("foo", "role2"), funcutil.EncodeUserRoleCache("foo2", "role2"), funcutil.EncodeUserRoleCache("foo2", "role3")},
}, nil }, nil
} }
err = globalMetaCache.RefreshPolicyInfo(typeutil.CacheOp{OpType: typeutil.CacheRefresh}) err = privilege.GetPrivilegeCache().RefreshPolicyInfo(typeutil.CacheOp{OpType: typeutil.CacheRefresh})
assert.NoError(t, err) assert.NoError(t, err)
roles = globalMetaCache.GetUserRole("foo") roles = privilege.GetPrivilegeCache().GetUserRole("foo")
assert.Len(t, roles, 2) assert.Len(t, roles, 2)
}) })
} }

View File

@ -22,24 +22,29 @@
package proxy package proxy
import ( import (
"context"
"github.com/cockroachdb/errors" "github.com/cockroachdb/errors"
"github.com/stretchr/testify/mock" "github.com/stretchr/testify/mock"
"github.com/milvus-io/milvus/internal/mocks" "github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/internal/proxy/privilege"
"github.com/milvus-io/milvus/pkg/v2/common" "github.com/milvus-io/milvus/pkg/v2/common"
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
"github.com/milvus-io/milvus/pkg/v2/util/funcutil" "github.com/milvus-io/milvus/pkg/v2/util/funcutil"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil" "github.com/milvus-io/milvus/pkg/v2/util/typeutil"
) )
func AddRootUserToAdminRole() { func AddRootUserToAdminRole() {
err := globalMetaCache.RefreshPolicyInfo(typeutil.CacheOp{OpType: typeutil.CacheAddUserToRole, OpKey: funcutil.EncodeUserRoleCache("root", "admin")}) err := privilege.GetPrivilegeCache().RefreshPolicyInfo(typeutil.CacheOp{OpType: typeutil.CacheAddUserToRole, OpKey: funcutil.EncodeUserRoleCache("root", "admin")})
if err != nil { if err != nil {
panic(err) panic(err)
} }
} }
func RemoveRootUserFromAdminRole() { func RemoveRootUserFromAdminRole() {
err := globalMetaCache.RefreshPolicyInfo(typeutil.CacheOp{OpType: typeutil.CacheRemoveUserFromRole, OpKey: funcutil.EncodeUserRoleCache("root", "admin")}) err := privilege.GetPrivilegeCache().RefreshPolicyInfo(typeutil.CacheOp{OpType: typeutil.CacheRemoveUserFromRole, OpKey: funcutil.EncodeUserRoleCache("root", "admin")})
if err != nil { if err != nil {
panic(err) panic(err)
} }
@ -55,6 +60,8 @@ func InitEmptyGlobalCache() {
if err != nil { if err != nil {
panic(err) panic(err)
} }
mixcoord.EXPECT().ListPolicy(mock.Anything, mock.Anything, mock.Anything).Return(&internalpb.ListPolicyResponse{Status: merr.Success()}, nil)
privilege.InitPrivilegeCache(context.Background(), mixcoord)
} }
func SetGlobalMetaCache(metaCache *MetaCache) { func SetGlobalMetaCache(metaCache *MetaCache) {

View File

@ -0,0 +1,7 @@
reviewers:
- congqixia
- czs007
- shaoting-huang
approvers:
- maintainers

View File

@ -0,0 +1,3 @@
# Summary
this package contains privilege related components for proxy.

View File

@ -0,0 +1,267 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package privilege
import (
"context"
"fmt"
"sync"
"sync/atomic"
"github.com/cockroachdb/errors"
"go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
"github.com/milvus-io/milvus/pkg/v2/proto/rootcoordpb"
"github.com/milvus-io/milvus/pkg/v2/util"
"github.com/milvus-io/milvus/pkg/v2/util/commonpbutil"
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
var cacheInst atomic.Pointer[privilegeCache]
func GetPrivilegeCache() *privilegeCache {
return cacheInst.Load()
}
type PrivilegeCache interface {
// GetCredentialInfo operate credential cache
GetCredentialInfo(ctx context.Context, username string) (*internalpb.CredentialInfo, error)
RemoveCredential(username string)
UpdateCredential(credInfo *internalpb.CredentialInfo)
GetPrivilegeInfo(ctx context.Context) []string
GetUserRole(username string) []string
RefreshPolicyInfo(op typeutil.CacheOp) error
InitPolicyInfo(info []string, userRoles []string)
}
var _ PrivilegeCache = (*privilegeCache)(nil)
type privilegeCache struct {
mixCoord types.MixCoordClient
mu sync.RWMutex
privilegeInfos map[string]struct{} // privileges cache
userToRoles map[string]map[string]struct{} // user to role cache
credMut sync.RWMutex
credMap map[string]*internalpb.CredentialInfo
}
func InitPrivilegeCache(ctx context.Context, mixCoord types.MixCoordClient) error {
privilegeCache := NewPrivilegeCache(mixCoord)
// The privilege info is a little more. And to get this info, the query operation of involving multiple table queries is required.
cacheInst.Store(privilegeCache)
resp, err := mixCoord.ListPolicy(ctx, &internalpb.ListPolicyRequest{})
if err = merr.CheckRPCCall(resp, err); err != nil {
log.Error("fail to init meta cache", zap.Error(err))
return err
}
privilegeCache.InitPolicyInfo(resp.PolicyInfos, resp.UserRoles)
log.Info("success to init privilege cache", zap.Strings("policy_infos", resp.PolicyInfos))
return nil
}
func NewPrivilegeCache(mixCoord types.MixCoordClient) *privilegeCache {
return &privilegeCache{
mixCoord: mixCoord,
privilegeInfos: make(map[string]struct{}),
userToRoles: make(map[string]map[string]struct{}),
credMap: make(map[string]*internalpb.CredentialInfo),
}
}
// GetCredentialInfo returns the credential related to provided username
// If the cache missed, proxy will try to fetch from storage
func (m *privilegeCache) GetCredentialInfo(ctx context.Context, username string) (*internalpb.CredentialInfo, error) {
m.credMut.RLock()
var credInfo *internalpb.CredentialInfo
credInfo, ok := m.credMap[username]
m.credMut.RUnlock()
if !ok {
req := &rootcoordpb.GetCredentialRequest{
Base: commonpbutil.NewMsgBase(
commonpbutil.WithMsgType(commonpb.MsgType_GetCredential),
),
Username: username,
}
resp, err := m.mixCoord.GetCredential(ctx, req)
if err != nil {
return &internalpb.CredentialInfo{}, err
}
credInfo = &internalpb.CredentialInfo{
Username: resp.Username,
EncryptedPassword: resp.Password,
}
}
return credInfo, nil
}
func (m *privilegeCache) RemoveCredential(username string) {
m.credMut.Lock()
defer m.credMut.Unlock()
// delete pair in credMap
delete(m.credMap, username)
}
func (m *privilegeCache) UpdateCredential(credInfo *internalpb.CredentialInfo) {
m.credMut.Lock()
defer m.credMut.Unlock()
username := credInfo.Username
_, ok := m.credMap[username]
if !ok {
m.credMap[username] = &internalpb.CredentialInfo{}
}
// Do not cache encrypted password content
m.credMap[username].Username = username
m.credMap[username].Sha256Password = credInfo.Sha256Password
}
func (m *privilegeCache) InitPolicyInfo(info []string, userRoles []string) {
defer func() {
err := GetEnforcer().LoadPolicy()
if err != nil {
log.Error("failed to load policy after RefreshPolicyInfo", zap.Error(err))
}
CleanPrivilegeCache()
}()
m.mu.Lock()
defer m.mu.Unlock()
m.unsafeInitPolicyInfo(info, userRoles)
}
func (m *privilegeCache) unsafeInitPolicyInfo(info []string, userRoles []string) {
m.privilegeInfos = util.StringSet(info)
for _, userRole := range userRoles {
user, role, err := funcutil.DecodeUserRoleCache(userRole)
if err != nil {
log.Warn("invalid user-role key", zap.String("user-role", userRole), zap.Error(err))
continue
}
if m.userToRoles[user] == nil {
m.userToRoles[user] = make(map[string]struct{})
}
m.userToRoles[user][role] = struct{}{}
}
}
func (m *privilegeCache) GetPrivilegeInfo(ctx context.Context) []string {
m.mu.RLock()
defer m.mu.RUnlock()
return util.StringList(m.privilegeInfos)
}
func (m *privilegeCache) GetUserRole(user string) []string {
m.mu.RLock()
defer m.mu.RUnlock()
return util.StringList(m.userToRoles[user])
}
func (m *privilegeCache) RefreshPolicyInfo(op typeutil.CacheOp) (err error) {
defer func() {
if err == nil {
le := GetEnforcer().LoadPolicy()
if le != nil {
log.Error("failed to load policy after RefreshPolicyInfo", zap.Error(le))
}
CleanPrivilegeCache()
}
}()
if op.OpType != typeutil.CacheRefresh {
m.mu.Lock()
defer m.mu.Unlock()
if op.OpKey == "" {
return errors.New("empty op key")
}
}
switch op.OpType {
case typeutil.CacheGrantPrivilege:
keys := funcutil.PrivilegesForPolicy(op.OpKey)
for _, key := range keys {
m.privilegeInfos[key] = struct{}{}
}
case typeutil.CacheRevokePrivilege:
keys := funcutil.PrivilegesForPolicy(op.OpKey)
for _, key := range keys {
delete(m.privilegeInfos, key)
}
case typeutil.CacheAddUserToRole:
user, role, err := funcutil.DecodeUserRoleCache(op.OpKey)
if err != nil {
return fmt.Errorf("invalid opKey, fail to decode, op_type: %d, op_key: %s", int(op.OpType), op.OpKey)
}
if m.userToRoles[user] == nil {
m.userToRoles[user] = make(map[string]struct{})
}
m.userToRoles[user][role] = struct{}{}
case typeutil.CacheRemoveUserFromRole:
user, role, err := funcutil.DecodeUserRoleCache(op.OpKey)
if err != nil {
return fmt.Errorf("invalid opKey, fail to decode, op_type: %d, op_key: %s", int(op.OpType), op.OpKey)
}
if m.userToRoles[user] != nil {
delete(m.userToRoles[user], role)
}
case typeutil.CacheDeleteUser:
delete(m.userToRoles, op.OpKey)
case typeutil.CacheDropRole:
for user := range m.userToRoles {
delete(m.userToRoles[user], op.OpKey)
}
for policy := range m.privilegeInfos {
if funcutil.PolicyCheckerWithRole(policy, op.OpKey) {
delete(m.privilegeInfos, policy)
}
}
case typeutil.CacheRefresh:
resp, err := m.mixCoord.ListPolicy(context.Background(), &internalpb.ListPolicyRequest{})
if err != nil {
log.Error("fail to init meta cache", zap.Error(err))
return err
}
if !merr.Ok(resp.GetStatus()) {
log.Error("fail to init meta cache",
zap.String("error_code", resp.GetStatus().GetErrorCode().String()),
zap.String("reason", resp.GetStatus().GetReason()))
return merr.Error(resp.Status)
}
m.mu.Lock()
defer m.mu.Unlock()
m.userToRoles = make(map[string]map[string]struct{})
m.privilegeInfos = make(map[string]struct{})
m.unsafeInitPolicyInfo(resp.PolicyInfos, resp.UserRoles)
default:
return fmt.Errorf("invalid opType, op_type: %d, op_key: %s", int(op.OpType), op.OpKey)
}
return nil
}

View File

@ -0,0 +1,27 @@
//go:build test
// +build test
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package privilege
// This file contains only functions used in tests to manipulate the privilege cache.
// ResetPrivilegeCacheForTest resets the privilege cache for testing purposes.
func ResetPrivilegeCacheForTest() {
cacheInst.Store(nil)
}

View File

@ -14,7 +14,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
package proxy package privilege
import ( import (
"context" "context"
@ -30,10 +30,10 @@ import (
// MetaCacheCasbinAdapter is the implementation of `persist.Adapter` with Cache // MetaCacheCasbinAdapter is the implementation of `persist.Adapter` with Cache
// Since the usage shall be read-only, it implements only `LoadPolicy` for now. // Since the usage shall be read-only, it implements only `LoadPolicy` for now.
type MetaCacheCasbinAdapter struct { type MetaCacheCasbinAdapter struct {
cacheSource func() Cache cacheSource func() PrivilegeCache
} }
func NewMetaCacheCasbinAdapter(cacheSource func() Cache) *MetaCacheCasbinAdapter { func NewMetaCacheCasbinAdapter(cacheSource func() PrivilegeCache) *MetaCacheCasbinAdapter {
return &MetaCacheCasbinAdapter{ return &MetaCacheCasbinAdapter{
cacheSource: cacheSource, cacheSource: cacheSource,
} }

View File

@ -14,7 +14,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
package proxy package privilege
import ( import (
"testing" "testing"
@ -26,36 +26,36 @@ import (
type MetaCacheCasbinAdapterSuite struct { type MetaCacheCasbinAdapterSuite struct {
suite.Suite suite.Suite
cache *MockCache cache *MockPrivilegeCache
adapter *MetaCacheCasbinAdapter adapter *MetaCacheCasbinAdapter
} }
func (s *MetaCacheCasbinAdapterSuite) SetupTest() { func (s *MetaCacheCasbinAdapterSuite) SetupTest() {
s.cache = NewMockCache(s.T()) s.cache = NewMockPrivilegeCache(s.T())
s.adapter = NewMetaCacheCasbinAdapter(func() Cache { return s.cache }) s.adapter = NewMetaCacheCasbinAdapter(func() PrivilegeCache { return s.cache })
} }
func (s *MetaCacheCasbinAdapterSuite) TestLoadPolicy() { func (s *MetaCacheCasbinAdapterSuite) TestLoadPolicy() {
s.Run("normal_load", func() { s.Run("normal_load", func() {
s.cache.EXPECT().GetPrivilegeInfo(mock.Anything).Return([]string{}) s.cache.EXPECT().GetPrivilegeInfo(mock.Anything).Return([]string{})
m := getPolicyModel(ModelStr) m := GetPolicyModel(ModelStr)
err := s.adapter.LoadPolicy(m) err := s.adapter.LoadPolicy(m)
s.NoError(err) s.NoError(err)
}) })
s.Run("source_return_nil", func() { s.Run("source_return_nil", func() {
adapter := NewMetaCacheCasbinAdapter(func() Cache { return nil }) adapter := NewMetaCacheCasbinAdapter(func() PrivilegeCache { return nil })
m := getPolicyModel(ModelStr) m := GetPolicyModel(ModelStr)
err := adapter.LoadPolicy(m) err := adapter.LoadPolicy(m)
s.Error(err) s.Error(err)
}) })
} }
func (s *MetaCacheCasbinAdapterSuite) TestSavePolicy() { func (s *MetaCacheCasbinAdapterSuite) TestSavePolicy() {
m := getPolicyModel(ModelStr) m := GetPolicyModel(ModelStr)
s.Error(s.adapter.SavePolicy(m)) s.Error(s.adapter.SavePolicy(m))
} }

View File

@ -0,0 +1,340 @@
// Code generated by mockery v2.53.3. DO NOT EDIT.
package privilege
import (
context "context"
internalpb "github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
mock "github.com/stretchr/testify/mock"
typeutil "github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
// MockPrivilegeCache is an autogenerated mock type for the PrivilegeCache type
type MockPrivilegeCache struct {
mock.Mock
}
type MockPrivilegeCache_Expecter struct {
mock *mock.Mock
}
func (_m *MockPrivilegeCache) EXPECT() *MockPrivilegeCache_Expecter {
return &MockPrivilegeCache_Expecter{mock: &_m.Mock}
}
// GetCredentialInfo provides a mock function with given fields: ctx, username
func (_m *MockPrivilegeCache) GetCredentialInfo(ctx context.Context, username string) (*internalpb.CredentialInfo, error) {
ret := _m.Called(ctx, username)
if len(ret) == 0 {
panic("no return value specified for GetCredentialInfo")
}
var r0 *internalpb.CredentialInfo
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string) (*internalpb.CredentialInfo, error)); ok {
return rf(ctx, username)
}
if rf, ok := ret.Get(0).(func(context.Context, string) *internalpb.CredentialInfo); ok {
r0 = rf(ctx, username)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*internalpb.CredentialInfo)
}
}
if rf, ok := ret.Get(1).(func(context.Context, string) error); ok {
r1 = rf(ctx, username)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockPrivilegeCache_GetCredentialInfo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetCredentialInfo'
type MockPrivilegeCache_GetCredentialInfo_Call struct {
*mock.Call
}
// GetCredentialInfo is a helper method to define mock.On call
// - ctx context.Context
// - username string
func (_e *MockPrivilegeCache_Expecter) GetCredentialInfo(ctx interface{}, username interface{}) *MockPrivilegeCache_GetCredentialInfo_Call {
return &MockPrivilegeCache_GetCredentialInfo_Call{Call: _e.mock.On("GetCredentialInfo", ctx, username)}
}
func (_c *MockPrivilegeCache_GetCredentialInfo_Call) Run(run func(ctx context.Context, username string)) *MockPrivilegeCache_GetCredentialInfo_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(string))
})
return _c
}
func (_c *MockPrivilegeCache_GetCredentialInfo_Call) Return(_a0 *internalpb.CredentialInfo, _a1 error) *MockPrivilegeCache_GetCredentialInfo_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockPrivilegeCache_GetCredentialInfo_Call) RunAndReturn(run func(context.Context, string) (*internalpb.CredentialInfo, error)) *MockPrivilegeCache_GetCredentialInfo_Call {
_c.Call.Return(run)
return _c
}
// GetPrivilegeInfo provides a mock function with given fields: ctx
func (_m *MockPrivilegeCache) GetPrivilegeInfo(ctx context.Context) []string {
ret := _m.Called(ctx)
if len(ret) == 0 {
panic("no return value specified for GetPrivilegeInfo")
}
var r0 []string
if rf, ok := ret.Get(0).(func(context.Context) []string); ok {
r0 = rf(ctx)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]string)
}
}
return r0
}
// MockPrivilegeCache_GetPrivilegeInfo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetPrivilegeInfo'
type MockPrivilegeCache_GetPrivilegeInfo_Call struct {
*mock.Call
}
// GetPrivilegeInfo is a helper method to define mock.On call
// - ctx context.Context
func (_e *MockPrivilegeCache_Expecter) GetPrivilegeInfo(ctx interface{}) *MockPrivilegeCache_GetPrivilegeInfo_Call {
return &MockPrivilegeCache_GetPrivilegeInfo_Call{Call: _e.mock.On("GetPrivilegeInfo", ctx)}
}
func (_c *MockPrivilegeCache_GetPrivilegeInfo_Call) Run(run func(ctx context.Context)) *MockPrivilegeCache_GetPrivilegeInfo_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context))
})
return _c
}
func (_c *MockPrivilegeCache_GetPrivilegeInfo_Call) Return(_a0 []string) *MockPrivilegeCache_GetPrivilegeInfo_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockPrivilegeCache_GetPrivilegeInfo_Call) RunAndReturn(run func(context.Context) []string) *MockPrivilegeCache_GetPrivilegeInfo_Call {
_c.Call.Return(run)
return _c
}
// GetUserRole provides a mock function with given fields: username
func (_m *MockPrivilegeCache) GetUserRole(username string) []string {
ret := _m.Called(username)
if len(ret) == 0 {
panic("no return value specified for GetUserRole")
}
var r0 []string
if rf, ok := ret.Get(0).(func(string) []string); ok {
r0 = rf(username)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]string)
}
}
return r0
}
// MockPrivilegeCache_GetUserRole_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetUserRole'
type MockPrivilegeCache_GetUserRole_Call struct {
*mock.Call
}
// GetUserRole is a helper method to define mock.On call
// - username string
func (_e *MockPrivilegeCache_Expecter) GetUserRole(username interface{}) *MockPrivilegeCache_GetUserRole_Call {
return &MockPrivilegeCache_GetUserRole_Call{Call: _e.mock.On("GetUserRole", username)}
}
func (_c *MockPrivilegeCache_GetUserRole_Call) Run(run func(username string)) *MockPrivilegeCache_GetUserRole_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(string))
})
return _c
}
func (_c *MockPrivilegeCache_GetUserRole_Call) Return(_a0 []string) *MockPrivilegeCache_GetUserRole_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockPrivilegeCache_GetUserRole_Call) RunAndReturn(run func(string) []string) *MockPrivilegeCache_GetUserRole_Call {
_c.Call.Return(run)
return _c
}
// InitPolicyInfo provides a mock function with given fields: info, userRoles
func (_m *MockPrivilegeCache) InitPolicyInfo(info []string, userRoles []string) {
_m.Called(info, userRoles)
}
// MockPrivilegeCache_InitPolicyInfo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'InitPolicyInfo'
type MockPrivilegeCache_InitPolicyInfo_Call struct {
*mock.Call
}
// InitPolicyInfo is a helper method to define mock.On call
// - info []string
// - userRoles []string
func (_e *MockPrivilegeCache_Expecter) InitPolicyInfo(info interface{}, userRoles interface{}) *MockPrivilegeCache_InitPolicyInfo_Call {
return &MockPrivilegeCache_InitPolicyInfo_Call{Call: _e.mock.On("InitPolicyInfo", info, userRoles)}
}
func (_c *MockPrivilegeCache_InitPolicyInfo_Call) Run(run func(info []string, userRoles []string)) *MockPrivilegeCache_InitPolicyInfo_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].([]string), args[1].([]string))
})
return _c
}
func (_c *MockPrivilegeCache_InitPolicyInfo_Call) Return() *MockPrivilegeCache_InitPolicyInfo_Call {
_c.Call.Return()
return _c
}
func (_c *MockPrivilegeCache_InitPolicyInfo_Call) RunAndReturn(run func([]string, []string)) *MockPrivilegeCache_InitPolicyInfo_Call {
_c.Run(run)
return _c
}
// RefreshPolicyInfo provides a mock function with given fields: op
func (_m *MockPrivilegeCache) RefreshPolicyInfo(op typeutil.CacheOp) error {
ret := _m.Called(op)
if len(ret) == 0 {
panic("no return value specified for RefreshPolicyInfo")
}
var r0 error
if rf, ok := ret.Get(0).(func(typeutil.CacheOp) error); ok {
r0 = rf(op)
} else {
r0 = ret.Error(0)
}
return r0
}
// MockPrivilegeCache_RefreshPolicyInfo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RefreshPolicyInfo'
type MockPrivilegeCache_RefreshPolicyInfo_Call struct {
*mock.Call
}
// RefreshPolicyInfo is a helper method to define mock.On call
// - op typeutil.CacheOp
func (_e *MockPrivilegeCache_Expecter) RefreshPolicyInfo(op interface{}) *MockPrivilegeCache_RefreshPolicyInfo_Call {
return &MockPrivilegeCache_RefreshPolicyInfo_Call{Call: _e.mock.On("RefreshPolicyInfo", op)}
}
func (_c *MockPrivilegeCache_RefreshPolicyInfo_Call) Run(run func(op typeutil.CacheOp)) *MockPrivilegeCache_RefreshPolicyInfo_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(typeutil.CacheOp))
})
return _c
}
func (_c *MockPrivilegeCache_RefreshPolicyInfo_Call) Return(_a0 error) *MockPrivilegeCache_RefreshPolicyInfo_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockPrivilegeCache_RefreshPolicyInfo_Call) RunAndReturn(run func(typeutil.CacheOp) error) *MockPrivilegeCache_RefreshPolicyInfo_Call {
_c.Call.Return(run)
return _c
}
// RemoveCredential provides a mock function with given fields: username
func (_m *MockPrivilegeCache) RemoveCredential(username string) {
_m.Called(username)
}
// MockPrivilegeCache_RemoveCredential_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RemoveCredential'
type MockPrivilegeCache_RemoveCredential_Call struct {
*mock.Call
}
// RemoveCredential is a helper method to define mock.On call
// - username string
func (_e *MockPrivilegeCache_Expecter) RemoveCredential(username interface{}) *MockPrivilegeCache_RemoveCredential_Call {
return &MockPrivilegeCache_RemoveCredential_Call{Call: _e.mock.On("RemoveCredential", username)}
}
func (_c *MockPrivilegeCache_RemoveCredential_Call) Run(run func(username string)) *MockPrivilegeCache_RemoveCredential_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(string))
})
return _c
}
func (_c *MockPrivilegeCache_RemoveCredential_Call) Return() *MockPrivilegeCache_RemoveCredential_Call {
_c.Call.Return()
return _c
}
func (_c *MockPrivilegeCache_RemoveCredential_Call) RunAndReturn(run func(string)) *MockPrivilegeCache_RemoveCredential_Call {
_c.Run(run)
return _c
}
// UpdateCredential provides a mock function with given fields: credInfo
func (_m *MockPrivilegeCache) UpdateCredential(credInfo *internalpb.CredentialInfo) {
_m.Called(credInfo)
}
// MockPrivilegeCache_UpdateCredential_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateCredential'
type MockPrivilegeCache_UpdateCredential_Call struct {
*mock.Call
}
// UpdateCredential is a helper method to define mock.On call
// - credInfo *internalpb.CredentialInfo
func (_e *MockPrivilegeCache_Expecter) UpdateCredential(credInfo interface{}) *MockPrivilegeCache_UpdateCredential_Call {
return &MockPrivilegeCache_UpdateCredential_Call{Call: _e.mock.On("UpdateCredential", credInfo)}
}
func (_c *MockPrivilegeCache_UpdateCredential_Call) Run(run func(credInfo *internalpb.CredentialInfo)) *MockPrivilegeCache_UpdateCredential_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(*internalpb.CredentialInfo))
})
return _c
}
func (_c *MockPrivilegeCache_UpdateCredential_Call) Return() *MockPrivilegeCache_UpdateCredential_Call {
_c.Call.Return()
return _c
}
func (_c *MockPrivilegeCache_UpdateCredential_Call) RunAndReturn(run func(*internalpb.CredentialInfo)) *MockPrivilegeCache_UpdateCredential_Call {
_c.Run(run)
return _c
}
// NewMockPrivilegeCache creates a new instance of MockPrivilegeCache. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewMockPrivilegeCache(t interface {
mock.TestingT
Cleanup(func())
}) *MockPrivilegeCache {
mock := &MockPrivilegeCache{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}

View File

@ -0,0 +1,139 @@
package privilege
import (
"log"
"strings"
"sync"
"github.com/casbin/casbin/v2"
"github.com/casbin/casbin/v2/model"
"github.com/samber/lo"
"go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus/pkg/v2/util"
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
)
const (
// sub -> role name, like admin, public
// obj -> contact object with object name, like Global-*, Collection-col1
// act -> privilege, like CreateCollection, DescribeCollection
ModelStr = `
[request_definition]
r = sub, obj, act
[policy_definition]
p = sub, obj, act
[policy_effect]
e = some(where (p.eft == allow))
[matchers]
m = r.sub == p.sub && globMatch(r.obj, p.obj) && globMatch(r.act, p.act) || r.sub == "admin" || (r.sub == p.sub && dbMatch(r.obj, p.obj) && privilegeGroupContains(r.act, p.act, r.obj, p.obj))
`
)
var (
enforcer *casbin.SyncedEnforcer
initOnce sync.Once
initPrivilegeGroupsOnce sync.Once
)
func GetPolicyModel(modelString string) model.Model {
m, err := model.NewModelFromString(modelString)
if err != nil {
log.Panic("NewModelFromString fail", zap.String("model", ModelStr), zap.Error(err))
}
return m
}
func InitPrivilegeGroups() {
initPrivilegeGroupsOnce.Do(func() {
roGroup := paramtable.Get().CommonCfg.ReadOnlyPrivileges.GetAsStrings()
if len(roGroup) == 0 {
roGroup = util.ReadOnlyPrivilegeGroup
}
roPrivileges = lo.SliceToMap(roGroup, func(item string) (string, struct{}) { return item, struct{}{} })
rwGroup := paramtable.Get().CommonCfg.ReadWritePrivileges.GetAsStrings()
if len(rwGroup) == 0 {
rwGroup = util.ReadWritePrivilegeGroup
}
rwPrivileges = lo.SliceToMap(rwGroup, func(item string) (string, struct{}) { return item, struct{}{} })
adminGroup := paramtable.Get().CommonCfg.AdminPrivileges.GetAsStrings()
if len(adminGroup) == 0 {
adminGroup = util.AdminPrivilegeGroup
}
adminPrivileges = lo.SliceToMap(adminGroup, func(item string) (string, struct{}) { return item, struct{}{} })
})
}
func GetEnforcer() *casbin.SyncedEnforcer {
initOnce.Do(func() {
e, err := casbin.NewSyncedEnforcer()
if err != nil {
log.Panic("failed to create casbin enforcer", zap.Error(err))
}
casbinModel := GetPolicyModel(ModelStr)
adapter := NewMetaCacheCasbinAdapter(func() PrivilegeCache { return GetPrivilegeCache() })
e.InitWithModelAndAdapter(casbinModel, adapter)
e.AddFunction("dbMatch", DBMatchFunc)
e.AddFunction("privilegeGroupContains", PrivilegeGroupContains)
enforcer = e
})
return enforcer
}
var roPrivileges, rwPrivileges, adminPrivileges map[string]struct{}
func DBMatchFunc(args ...interface{}) (interface{}, error) {
name1 := args[0].(string)
name2 := args[1].(string)
db1, _ := funcutil.SplitObjectName(name1[strings.Index(name1, "-")+1:])
db2, _ := funcutil.SplitObjectName(name2[strings.Index(name2, "-")+1:])
return db1 == db2, nil
}
func PrivilegeGroupContains(args ...interface{}) (interface{}, error) {
requestPrivilege := args[0].(string)
policyPrivilege := args[1].(string)
requestObj := args[2].(string)
policyObj := args[3].(string)
switch policyPrivilege {
case commonpb.ObjectPrivilege_PrivilegeAll.String():
return true, nil
case commonpb.ObjectPrivilege_PrivilegeGroupReadOnly.String():
// read only belong to collection object
if !collMatch(requestObj, policyObj) {
return false, nil
}
_, ok := roPrivileges[requestPrivilege]
return ok, nil
case commonpb.ObjectPrivilege_PrivilegeGroupReadWrite.String():
// read write belong to collection object
if !collMatch(requestObj, policyObj) {
return false, nil
}
_, ok := rwPrivileges[requestPrivilege]
return ok, nil
case commonpb.ObjectPrivilege_PrivilegeGroupAdmin.String():
// admin belong to global object
_, ok := adminPrivileges[requestPrivilege]
return ok, nil
default:
return false, nil
}
}
func collMatch(requestObj, policyObj string) bool {
_, coll1 := funcutil.SplitObjectName(requestObj[strings.Index(requestObj, "-")+1:])
_, coll2 := funcutil.SplitObjectName(policyObj[strings.Index(policyObj, "-")+1:])
return coll1 == util.AnyWord || coll2 == util.AnyWord || coll1 == coll2
}

View File

@ -14,7 +14,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
package proxy package privilege
import ( import (
"fmt" "fmt"
@ -28,11 +28,11 @@ import (
var ( var (
priCacheInitOnce sync.Once priCacheInitOnce sync.Once
priCacheMut sync.RWMutex priCacheMut sync.RWMutex
priCache *PrivilegeCache priCache *resultCache
ver atomic.Int64 ver atomic.Int64
) )
func getPriCache() *PrivilegeCache { func getPriCache() *resultCache {
priCacheMut.RLock() priCacheMut.RLock()
c := priCache c := priCache
priCacheMut.RUnlock() priCacheMut.RUnlock()
@ -41,7 +41,7 @@ func getPriCache() *PrivilegeCache {
priCacheInitOnce.Do(func() { priCacheInitOnce.Do(func() {
priCacheMut.Lock() priCacheMut.Lock()
defer priCacheMut.Unlock() defer priCacheMut.Unlock()
priCache = &PrivilegeCache{ priCache = &resultCache{
version: ver.Inc(), version: ver.Inc(),
values: typeutil.ConcurrentMap[string, bool]{}, values: typeutil.ConcurrentMap[string, bool]{},
} }
@ -57,20 +57,20 @@ func getPriCache() *PrivilegeCache {
func CleanPrivilegeCache() { func CleanPrivilegeCache() {
priCacheMut.Lock() priCacheMut.Lock()
defer priCacheMut.Unlock() defer priCacheMut.Unlock()
priCache = &PrivilegeCache{ priCache = &resultCache{
version: ver.Inc(), version: ver.Inc(),
values: typeutil.ConcurrentMap[string, bool]{}, values: typeutil.ConcurrentMap[string, bool]{},
} }
} }
func GetPrivilegeCache(roleName, object, objectPrivilege string) (isPermit, cached bool, version int64) { func GetResultCache(roleName, object, objectPrivilege string) (isPermit, cached bool, version int64) {
key := fmt.Sprintf("%s_%s_%s", roleName, object, objectPrivilege) key := fmt.Sprintf("%s_%s_%s", roleName, object, objectPrivilege)
c := getPriCache() c := getPriCache()
isPermit, cached = c.values.Get(key) isPermit, cached = c.values.Get(key)
return isPermit, cached, c.version return isPermit, cached, c.version
} }
func SetPrivilegeCache(roleName, object, objectPrivilege string, isPermit bool, version int64) { func SetResultCache(roleName, object, objectPrivilege string, isPermit bool, version int64) {
key := fmt.Sprintf("%s_%s_%s", roleName, object, objectPrivilege) key := fmt.Sprintf("%s_%s_%s", roleName, object, objectPrivilege)
c := getPriCache() c := getPriCache()
if c.version == version { if c.version == version {
@ -80,7 +80,7 @@ func SetPrivilegeCache(roleName, object, objectPrivilege string, isPermit bool,
// PrivilegeCache is a cache for privilege enforce result // PrivilegeCache is a cache for privilege enforce result
// version provides version control when any policy updates // version provides version control when any policy updates
type PrivilegeCache struct { type resultCache struct {
values typeutil.ConcurrentMap[string, bool] values typeutil.ConcurrentMap[string, bool]
version int64 version int64
} }

View File

@ -14,7 +14,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
package proxy package privilege
import ( import (
"testing" "testing"
@ -32,9 +32,9 @@ func (s *PrivilegeCacheSuite) TearDownTest() {
func (s *PrivilegeCacheSuite) TestGetPrivilege() { func (s *PrivilegeCacheSuite) TestGetPrivilege() {
// get current version // get current version
_, _, version := GetPrivilegeCache("", "", "") _, _, version := GetResultCache("", "", "")
SetPrivilegeCache("test-role", "test-object", "read", true, version) SetResultCache("test-role", "test-object", "read", true, version)
SetPrivilegeCache("test-role", "test-object", "delete", false, version) SetResultCache("test-role", "test-object", "delete", false, version)
type testCase struct { type testCase struct {
tag string tag string
@ -51,7 +51,7 @@ func (s *PrivilegeCacheSuite) TestGetPrivilege() {
for _, tc := range testCases { for _, tc := range testCases {
s.Run(tc.tag, func() { s.Run(tc.tag, func() {
isPermit, exists, _ := GetPrivilegeCache(tc.input[0], tc.input[1], tc.input[2]) isPermit, exists, _ := GetResultCache(tc.input[0], tc.input[1], tc.input[2])
s.Equal(tc.expectIsPermit, isPermit) s.Equal(tc.expectIsPermit, isPermit)
s.Equal(tc.expectExists, exists) s.Equal(tc.expectExists, exists)
}) })
@ -60,12 +60,12 @@ func (s *PrivilegeCacheSuite) TestGetPrivilege() {
func (s *PrivilegeCacheSuite) TestSetPrivilegeVersion() { func (s *PrivilegeCacheSuite) TestSetPrivilegeVersion() {
// get current version // get current version
_, _, version := GetPrivilegeCache("", "", "") _, _, version := GetResultCache("", "", "")
CleanPrivilegeCache() CleanPrivilegeCache()
SetPrivilegeCache("test-role", "test-object", "read", true, version) SetResultCache("test-role", "test-object", "read", true, version)
isPermit, exists, nextVersion := GetPrivilegeCache("test-role", "test-object", "read") isPermit, exists, nextVersion := GetResultCache("test-role", "test-object", "read")
s.False(isPermit) s.False(isPermit)
s.False(exists) s.False(exists)
s.NotEqual(version, nextVersion) s.NotEqual(version, nextVersion)

View File

@ -4,12 +4,9 @@ import (
"context" "context"
"fmt" "fmt"
"reflect" "reflect"
"strings"
"sync" "sync"
"github.com/casbin/casbin/v2" "github.com/casbin/casbin/v2"
"github.com/casbin/casbin/v2/model"
"github.com/samber/lo"
"go.uber.org/zap" "go.uber.org/zap"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
@ -17,36 +14,15 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/internal/proxy/privilege"
"github.com/milvus-io/milvus/pkg/v2/log" "github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/util" "github.com/milvus-io/milvus/pkg/v2/util"
"github.com/milvus-io/milvus/pkg/v2/util/contextutil" "github.com/milvus-io/milvus/pkg/v2/util/contextutil"
"github.com/milvus-io/milvus/pkg/v2/util/funcutil" "github.com/milvus-io/milvus/pkg/v2/util/funcutil"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
) )
type PrivilegeFunc func(ctx context.Context, req interface{}) (context.Context, error) type PrivilegeFunc func(ctx context.Context, req interface{}) (context.Context, error)
const (
// sub -> role name, like admin, public
// obj -> contact object with object name, like Global-*, Collection-col1
// act -> privilege, like CreateCollection, DescribeCollection
ModelStr = `
[request_definition]
r = sub, obj, act
[policy_definition]
p = sub, obj, act
[policy_effect]
e = some(where (p.eft == allow))
[matchers]
m = r.sub == p.sub && globMatch(r.obj, p.obj) && globMatch(r.act, p.act) || r.sub == "admin" || (r.sub == p.sub && dbMatch(r.obj, p.obj) && privilegeGroupContains(r.act, p.act, r.obj, p.obj))
`
)
var templateModel = getPolicyModel(ModelStr)
var ( var (
enforcer *casbin.SyncedEnforcer enforcer *casbin.SyncedEnforcer
initOnce sync.Once initOnce sync.Once
@ -55,55 +31,9 @@ var (
var roPrivileges, rwPrivileges, adminPrivileges map[string]struct{} var roPrivileges, rwPrivileges, adminPrivileges map[string]struct{}
func initPrivilegeGroups() {
initPrivilegeGroupsOnce.Do(func() {
roGroup := paramtable.Get().CommonCfg.ReadOnlyPrivileges.GetAsStrings()
if len(roGroup) == 0 {
roGroup = util.ReadOnlyPrivilegeGroup
}
roPrivileges = lo.SliceToMap(roGroup, func(item string) (string, struct{}) { return item, struct{}{} })
rwGroup := paramtable.Get().CommonCfg.ReadWritePrivileges.GetAsStrings()
if len(rwGroup) == 0 {
rwGroup = util.ReadWritePrivilegeGroup
}
rwPrivileges = lo.SliceToMap(rwGroup, func(item string) (string, struct{}) { return item, struct{}{} })
adminGroup := paramtable.Get().CommonCfg.AdminPrivileges.GetAsStrings()
if len(adminGroup) == 0 {
adminGroup = util.AdminPrivilegeGroup
}
adminPrivileges = lo.SliceToMap(adminGroup, func(item string) (string, struct{}) { return item, struct{}{} })
})
}
func getEnforcer() *casbin.SyncedEnforcer {
initOnce.Do(func() {
e, err := casbin.NewSyncedEnforcer()
if err != nil {
log.Panic("failed to create casbin enforcer", zap.Error(err))
}
casbinModel := getPolicyModel(ModelStr)
adapter := NewMetaCacheCasbinAdapter(func() Cache { return globalMetaCache })
e.InitWithModelAndAdapter(casbinModel, adapter)
e.AddFunction("dbMatch", DBMatchFunc)
e.AddFunction("privilegeGroupContains", PrivilegeGroupContains)
enforcer = e
})
return enforcer
}
func getPolicyModel(modelString string) model.Model {
m, err := model.NewModelFromString(modelString)
if err != nil {
log.Panic("NewModelFromString fail", zap.String("model", ModelStr), zap.Error(err))
}
return m
}
// UnaryServerInterceptor returns a new unary server interceptors that performs per-request privilege access. // UnaryServerInterceptor returns a new unary server interceptors that performs per-request privilege access.
func UnaryServerInterceptor(privilegeFunc PrivilegeFunc) grpc.UnaryServerInterceptor { func UnaryServerInterceptor(privilegeFunc PrivilegeFunc) grpc.UnaryServerInterceptor {
initPrivilegeGroups() privilege.InitPrivilegeGroups()
return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
newCtx, err := privilegeFunc(ctx, req) newCtx, err := privilegeFunc(ctx, req)
if err != nil { if err != nil {
@ -160,11 +90,11 @@ func PrivilegeInterceptor(ctx context.Context, req interface{}) (context.Context
zap.Int32("object_index", objectNameIndex), zap.String("object_name", objectName), zap.Int32("object_index", objectNameIndex), zap.String("object_name", objectName),
zap.Int32("object_indexs", objectNameIndexs), zap.Strings("object_names", objectNames)) zap.Int32("object_indexs", objectNameIndexs), zap.Strings("object_names", objectNames))
e := getEnforcer() e := privilege.GetEnforcer()
for _, roleName := range roleNames { for _, roleName := range roleNames {
permitFunc := func(objectName string) (bool, error) { permitFunc := func(objectName string) (bool, error) {
object := funcutil.PolicyForResource(dbName, objectType, objectName) object := funcutil.PolicyForResource(dbName, objectType, objectName)
isPermit, cached, version := GetPrivilegeCache(roleName, object, objectPrivilege) isPermit, cached, version := privilege.GetResultCache(roleName, object, objectPrivilege)
if cached { if cached {
return isPermit, nil return isPermit, nil
} }
@ -172,7 +102,7 @@ func PrivilegeInterceptor(ctx context.Context, req interface{}) (context.Context
if err != nil { if err != nil {
return false, err return false, err
} }
SetPrivilegeCache(roleName, object, objectPrivilege, isPermit, version) privilege.SetResultCache(roleName, object, objectPrivilege, isPermit, version)
return isPermit, nil return isPermit, nil
} }
@ -237,52 +167,3 @@ func isSelectMyRoleGrants(req interface{}, roleNames []string) bool {
roleName := filterGrantEntity.GetRole().GetName() roleName := filterGrantEntity.GetRole().GetName()
return funcutil.SliceContain(roleNames, roleName) return funcutil.SliceContain(roleNames, roleName)
} }
func DBMatchFunc(args ...interface{}) (interface{}, error) {
name1 := args[0].(string)
name2 := args[1].(string)
db1, _ := funcutil.SplitObjectName(name1[strings.Index(name1, "-")+1:])
db2, _ := funcutil.SplitObjectName(name2[strings.Index(name2, "-")+1:])
return db1 == db2, nil
}
func collMatch(requestObj, policyObj string) bool {
_, coll1 := funcutil.SplitObjectName(requestObj[strings.Index(requestObj, "-")+1:])
_, coll2 := funcutil.SplitObjectName(policyObj[strings.Index(policyObj, "-")+1:])
return coll1 == util.AnyWord || coll2 == util.AnyWord || coll1 == coll2
}
func PrivilegeGroupContains(args ...interface{}) (interface{}, error) {
requestPrivilege := args[0].(string)
policyPrivilege := args[1].(string)
requestObj := args[2].(string)
policyObj := args[3].(string)
switch policyPrivilege {
case commonpb.ObjectPrivilege_PrivilegeAll.String():
return true, nil
case commonpb.ObjectPrivilege_PrivilegeGroupReadOnly.String():
// read only belong to collection object
if !collMatch(requestObj, policyObj) {
return false, nil
}
_, ok := roPrivileges[requestPrivilege]
return ok, nil
case commonpb.ObjectPrivilege_PrivilegeGroupReadWrite.String():
// read write belong to collection object
if !collMatch(requestObj, policyObj) {
return false, nil
}
_, ok := rwPrivileges[requestPrivilege]
return ok, nil
case commonpb.ObjectPrivilege_PrivilegeGroupAdmin.String():
// admin belong to global object
_, ok := adminPrivileges[requestPrivilege]
return ok, nil
default:
return false, nil
}
}

View File

@ -10,6 +10,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/internal/proxy/privilege"
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb" "github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
"github.com/milvus-io/milvus/pkg/v2/util" "github.com/milvus-io/milvus/pkg/v2/util"
"github.com/milvus-io/milvus/pkg/v2/util/funcutil" "github.com/milvus-io/milvus/pkg/v2/util/funcutil"
@ -167,7 +168,7 @@ func TestPrivilegeInterceptor(t *testing.T) {
g.Wait() g.Wait()
assert.Panics(t, func() { assert.Panics(t, func() {
getPolicyModel("foo") privilege.GetPolicyModel("foo")
}) })
}) })
} }
@ -268,7 +269,7 @@ func TestPrivilegeGroup(t *testing.T) {
t.Run("grant ReadOnly to single collection", func(t *testing.T) { t.Run("grant ReadOnly to single collection", func(t *testing.T) {
paramtable.Get().Save(Params.CommonCfg.AuthorizationEnabled.Key, "true") paramtable.Get().Save(Params.CommonCfg.AuthorizationEnabled.Key, "true")
initPrivilegeGroups() privilege.InitPrivilegeGroups()
var err error var err error
ctx = GetContext(context.Background(), "fooo:123456") ctx = GetContext(context.Background(), "fooo:123456")
@ -287,7 +288,7 @@ func TestPrivilegeGroup(t *testing.T) {
}, nil }, nil
} }
InitMetaCache(ctx, client, mgr) InitMetaCache(ctx, client, mgr)
defer CleanPrivilegeCache() defer privilege.CleanPrivilegeCache()
_, err = PrivilegeInterceptor(GetContext(context.Background(), "fooo:123456"), &milvuspb.QueryRequest{ _, err = PrivilegeInterceptor(GetContext(context.Background(), "fooo:123456"), &milvuspb.QueryRequest{
CollectionName: "coll1", CollectionName: "coll1",
@ -325,7 +326,7 @@ func TestPrivilegeGroup(t *testing.T) {
t.Run("grant ReadOnly to all collection", func(t *testing.T) { t.Run("grant ReadOnly to all collection", func(t *testing.T) {
paramtable.Get().Save(Params.CommonCfg.AuthorizationEnabled.Key, "true") paramtable.Get().Save(Params.CommonCfg.AuthorizationEnabled.Key, "true")
initPrivilegeGroups() privilege.InitPrivilegeGroups()
var err error var err error
ctx = GetContext(context.Background(), "fooo:123456") ctx = GetContext(context.Background(), "fooo:123456")
@ -344,7 +345,7 @@ func TestPrivilegeGroup(t *testing.T) {
}, nil }, nil
} }
InitMetaCache(ctx, client, mgr) InitMetaCache(ctx, client, mgr)
defer CleanPrivilegeCache() defer privilege.CleanPrivilegeCache()
_, err = PrivilegeInterceptor(GetContext(context.Background(), "fooo:123456"), &milvuspb.QueryRequest{ _, err = PrivilegeInterceptor(GetContext(context.Background(), "fooo:123456"), &milvuspb.QueryRequest{
CollectionName: "coll1", CollectionName: "coll1",
@ -382,7 +383,7 @@ func TestPrivilegeGroup(t *testing.T) {
t.Run("grant ReadWrite to single collection", func(t *testing.T) { t.Run("grant ReadWrite to single collection", func(t *testing.T) {
paramtable.Get().Save(Params.CommonCfg.AuthorizationEnabled.Key, "true") paramtable.Get().Save(Params.CommonCfg.AuthorizationEnabled.Key, "true")
initPrivilegeGroups() privilege.InitPrivilegeGroups()
var err error var err error
ctx = GetContext(context.Background(), "fooo:123456") ctx = GetContext(context.Background(), "fooo:123456")
@ -401,7 +402,7 @@ func TestPrivilegeGroup(t *testing.T) {
}, nil }, nil
} }
InitMetaCache(ctx, client, mgr) InitMetaCache(ctx, client, mgr)
defer CleanPrivilegeCache() defer privilege.CleanPrivilegeCache()
_, err = PrivilegeInterceptor(GetContext(context.Background(), "fooo:123456"), &milvuspb.QueryRequest{ _, err = PrivilegeInterceptor(GetContext(context.Background(), "fooo:123456"), &milvuspb.QueryRequest{
CollectionName: "coll1", CollectionName: "coll1",
@ -485,7 +486,7 @@ func TestPrivilegeGroup(t *testing.T) {
t.Run("grant ReadWrite to all collection", func(t *testing.T) { t.Run("grant ReadWrite to all collection", func(t *testing.T) {
paramtable.Get().Save(Params.CommonCfg.AuthorizationEnabled.Key, "true") paramtable.Get().Save(Params.CommonCfg.AuthorizationEnabled.Key, "true")
initPrivilegeGroups() privilege.InitPrivilegeGroups()
var err error var err error
ctx = GetContext(context.Background(), "fooo:123456") ctx = GetContext(context.Background(), "fooo:123456")
@ -504,7 +505,7 @@ func TestPrivilegeGroup(t *testing.T) {
}, nil }, nil
} }
InitMetaCache(ctx, client, mgr) InitMetaCache(ctx, client, mgr)
defer CleanPrivilegeCache() defer privilege.CleanPrivilegeCache()
_, err = PrivilegeInterceptor(GetContext(context.Background(), "fooo:123456"), &milvuspb.QueryRequest{ _, err = PrivilegeInterceptor(GetContext(context.Background(), "fooo:123456"), &milvuspb.QueryRequest{
CollectionName: "coll1", CollectionName: "coll1",
@ -552,7 +553,7 @@ func TestPrivilegeGroup(t *testing.T) {
t.Run("Admin", func(t *testing.T) { t.Run("Admin", func(t *testing.T) {
paramtable.Get().Save(Params.CommonCfg.AuthorizationEnabled.Key, "true") paramtable.Get().Save(Params.CommonCfg.AuthorizationEnabled.Key, "true")
initPrivilegeGroups() privilege.InitPrivilegeGroups()
var err error var err error
ctx = GetContext(context.Background(), "fooo:123456") ctx = GetContext(context.Background(), "fooo:123456")
@ -571,7 +572,7 @@ func TestPrivilegeGroup(t *testing.T) {
}, nil }, nil
} }
InitMetaCache(ctx, client, mgr) InitMetaCache(ctx, client, mgr)
defer CleanPrivilegeCache() defer privilege.CleanPrivilegeCache()
_, err = PrivilegeInterceptor(GetContext(context.Background(), "fooo:123456"), &milvuspb.QueryRequest{}) _, err = PrivilegeInterceptor(GetContext(context.Background(), "fooo:123456"), &milvuspb.QueryRequest{})
assert.NoError(t, err) assert.NoError(t, err)
@ -593,7 +594,7 @@ func TestPrivilegeGroup(t *testing.T) {
func TestBuiltinPrivilegeGroup(t *testing.T) { func TestBuiltinPrivilegeGroup(t *testing.T) {
t.Run("ClusterAdmin", func(t *testing.T) { t.Run("ClusterAdmin", func(t *testing.T) {
paramtable.Get().Save(Params.CommonCfg.AuthorizationEnabled.Key, "true") paramtable.Get().Save(Params.CommonCfg.AuthorizationEnabled.Key, "true")
initPrivilegeGroups() privilege.InitPrivilegeGroups()
var err error var err error
ctx := GetContext(context.Background(), "fooo:123456") ctx := GetContext(context.Background(), "fooo:123456")
@ -615,7 +616,7 @@ func TestBuiltinPrivilegeGroup(t *testing.T) {
}, nil }, nil
} }
InitMetaCache(ctx, client, mgr) InitMetaCache(ctx, client, mgr)
defer CleanPrivilegeCache() defer privilege.CleanPrivilegeCache()
_, err = PrivilegeInterceptor(GetContext(context.Background(), "fooo:123456"), &milvuspb.SelectUserRequest{}) _, err = PrivilegeInterceptor(GetContext(context.Background(), "fooo:123456"), &milvuspb.SelectUserRequest{})
assert.NoError(t, err) assert.NoError(t, err)

View File

@ -51,6 +51,7 @@ import (
grpcstreamingnode "github.com/milvus-io/milvus/internal/distributed/streamingnode" grpcstreamingnode "github.com/milvus-io/milvus/internal/distributed/streamingnode"
"github.com/milvus-io/milvus/internal/json" "github.com/milvus-io/milvus/internal/json"
"github.com/milvus-io/milvus/internal/mocks" "github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/internal/proxy/privilege"
"github.com/milvus-io/milvus/internal/util/componentutil" "github.com/milvus-io/milvus/internal/util/componentutil"
"github.com/milvus-io/milvus/internal/util/dependency" "github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/internal/util/sessionutil" "github.com/milvus-io/milvus/internal/util/sessionutil"
@ -2837,7 +2838,7 @@ func TestProxy(t *testing.T) {
getResp, err := rootCoordClient.GetCredential(ctx, getCredentialReq) getResp, err := rootCoordClient.GetCredential(ctx, getCredentialReq)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, getResp.GetStatus().GetErrorCode()) assert.Equal(t, commonpb.ErrorCode_Success, getResp.GetStatus().GetErrorCode())
assert.True(t, passwordVerify(ctx, username, newPassword, globalMetaCache)) assert.True(t, passwordVerify(ctx, username, newPassword, privilege.GetPrivilegeCache()))
getCredentialReq.Username = "(" getCredentialReq.Username = "("
getResp, err = rootCoordClient.GetCredential(ctx, getCredentialReq) getResp, err = rootCoordClient.GetCredential(ctx, getCredentialReq)

View File

@ -38,6 +38,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/json" "github.com/milvus-io/milvus/internal/json"
"github.com/milvus-io/milvus/internal/parser/planparserv2" "github.com/milvus-io/milvus/internal/parser/planparserv2"
"github.com/milvus-io/milvus/internal/proxy/privilege"
"github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/analyzer" "github.com/milvus-io/milvus/internal/util/analyzer"
"github.com/milvus-io/milvus/internal/util/function/embedding" "github.com/milvus-io/milvus/internal/util/function/embedding"
@ -1464,14 +1465,15 @@ func AppendUserInfoForRPC(ctx context.Context) context.Context {
} }
func GetRole(username string) ([]string, error) { func GetRole(username string) ([]string, error) {
if globalMetaCache == nil { privCache := privilege.GetPrivilegeCache()
if privCache == nil {
return []string{}, merr.WrapErrServiceUnavailable("internal: Milvus Proxy is not ready yet. please wait") return []string{}, merr.WrapErrServiceUnavailable("internal: Milvus Proxy is not ready yet. please wait")
} }
return globalMetaCache.GetUserRole(username), nil return privCache.GetUserRole(username), nil
} }
func PasswordVerify(ctx context.Context, username, rawPwd string) bool { func PasswordVerify(ctx context.Context, username, rawPwd string) bool {
return passwordVerify(ctx, username, rawPwd, globalMetaCache) return passwordVerify(ctx, username, rawPwd, privilege.GetPrivilegeCache())
} }
func VerifyAPIKey(rawToken string) (string, error) { func VerifyAPIKey(rawToken string) (string, error) {
@ -1485,10 +1487,10 @@ func VerifyAPIKey(rawToken string) (string, error) {
} }
// PasswordVerify verify password // PasswordVerify verify password
func passwordVerify(ctx context.Context, username, rawPwd string, globalMetaCache Cache) bool { func passwordVerify(ctx context.Context, username, rawPwd string, privilegeCache privilege.PrivilegeCache) bool {
// it represents the cache miss if Sha256Password is empty within credInfo, which shall be updated first connection. // it represents the cache miss if Sha256Password is empty within credInfo, which shall be updated first connection.
// meanwhile, generating Sha256Password depends on raw password and encrypted password will not cache. // meanwhile, generating Sha256Password depends on raw password and encrypted password will not cache.
credInfo, err := globalMetaCache.GetCredentialInfo(ctx, username) credInfo, err := privilege.GetPrivilegeCache().GetCredentialInfo(ctx, username)
if err != nil { if err != nil {
log.Ctx(ctx).Error("found no credential", zap.String("username", username), zap.Error(err)) log.Ctx(ctx).Error("found no credential", zap.String("username", username), zap.Error(err))
return false return false
@ -1509,7 +1511,7 @@ func passwordVerify(ctx context.Context, username, rawPwd string, globalMetaCach
// update cache after miss cache // update cache after miss cache
credInfo.Sha256Password = sha256Pwd credInfo.Sha256Password = sha256Pwd
log.Ctx(ctx).Debug("get credential miss cache, update cache with", zap.Any("credential", credInfo)) log.Ctx(ctx).Debug("get credential miss cache, update cache with", zap.Any("credential", credInfo))
globalMetaCache.UpdateCredential(credInfo) privilegeCache.UpdateCredential(credInfo)
return true return true
} }

View File

@ -37,6 +37,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/json" "github.com/milvus-io/milvus/internal/json"
"github.com/milvus-io/milvus/internal/mocks" "github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/internal/proxy/privilege"
"github.com/milvus-io/milvus/internal/util/function/embedding" "github.com/milvus-io/milvus/internal/util/function/embedding"
"github.com/milvus-io/milvus/pkg/v2/common" "github.com/milvus-io/milvus/pkg/v2/common"
"github.com/milvus-io/milvus/pkg/v2/log" "github.com/milvus-io/milvus/pkg/v2/log"
@ -789,19 +790,21 @@ func TestGetCurDBNameFromContext(t *testing.T) {
} }
func TestGetRole(t *testing.T) { func TestGetRole(t *testing.T) {
globalMetaCache = nil ctx, cancel := context.WithCancel(context.Background())
defer cancel()
privilege.ResetPrivilegeCacheForTest()
_, err := GetRole("foo") _, err := GetRole("foo")
assert.Error(t, err) assert.Error(t, err)
mockCache := NewMockCache(t)
mockCache.On("GetUserRole", mixcoord := mocks.NewMockMixCoordClient(t)
mock.AnythingOfType("string"), mixcoord.EXPECT().ListPolicy(mock.Anything, mock.Anything).Return(&internalpb.ListPolicyResponse{
).Return(func(username string) []string { Status: merr.Success(),
if username == "root" { UserRoles: []string{"root/role1", "root/admin", "root/role2", "foo/role1"},
return []string{"role1", "admin", "role2"} }, nil).Times(1)
}
return []string{"role1"} privilege.InitPrivilegeCache(ctx, mixcoord)
})
globalMetaCache = mockCache
roles, err := GetRole("root") roles, err := GetRole("root")
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, 3, len(roles)) assert.Equal(t, 3, len(roles))
@ -812,11 +815,12 @@ func TestGetRole(t *testing.T) {
} }
func TestPasswordVerify(t *testing.T) { func TestPasswordVerify(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
username := "user-test00" username := "user-test00"
password := "PasswordVerify" password := "PasswordVerify"
// credential does not exist within cache
credCache := make(map[string]*internalpb.CredentialInfo, 0)
invokedCount := 0 invokedCount := 0
mockedRootCoord := NewMixCoordMock() mockedRootCoord := NewMixCoordMock()
@ -825,34 +829,36 @@ func TestPasswordVerify(t *testing.T) {
return nil, errors.New("get cred not found credential") return nil, errors.New("get cred not found credential")
} }
metaCache := &MetaCache{ privilege.InitPrivilegeCache(ctx, mockedRootCoord)
credMap: credCache, privilegeCache := privilege.GetPrivilegeCache()
mixCoord: mockedRootCoord, assert.False(t, passwordVerify(ctx, username, password, privilegeCache))
}
ret, ok := credCache[username]
assert.False(t, ok)
assert.Nil(t, ret)
assert.False(t, passwordVerify(context.TODO(), username, password, metaCache))
assert.Equal(t, 1, invokedCount) assert.Equal(t, 1, invokedCount)
// Sha256Password has not been filled into cache during establish connection firstly // Sha256Password has not been filled into cache during establish connection firstly
encryptedPwd, err := crypto.PasswordEncrypt(password) encryptedPwd, err := crypto.PasswordEncrypt(password)
assert.NoError(t, err) assert.NoError(t, err)
credCache[username] = &internalpb.CredentialInfo{ privilegeCache.RemoveCredential(username)
Username: username, mockedRootCoord.GetGetCredentialFunc = func(ctx context.Context, req *rootcoordpb.GetCredentialRequest, opts ...grpc.CallOption) (*rootcoordpb.GetCredentialResponse, error) {
EncryptedPassword: encryptedPwd, invokedCount++
return &rootcoordpb.GetCredentialResponse{
Status: merr.Success(),
Username: username,
Password: encryptedPwd,
}, nil
} }
assert.True(t, passwordVerify(context.TODO(), username, password, metaCache))
ret, ok = credCache[username] assert.True(t, passwordVerify(ctx, username, password, privilegeCache))
assert.True(t, ok)
ret, err := privilegeCache.GetCredentialInfo(ctx, username)
assert.NoError(t, err)
assert.NotNil(t, ret) assert.NotNil(t, ret)
assert.Equal(t, username, ret.Username) assert.Equal(t, username, ret.Username)
assert.NotNil(t, ret.Sha256Password) assert.NotNil(t, ret.Sha256Password)
assert.Equal(t, 1, invokedCount) assert.Equal(t, 2, invokedCount)
// Sha256Password already exists within cache // Sha256Password already exists within cache
assert.True(t, passwordVerify(context.TODO(), username, password, metaCache)) assert.True(t, passwordVerify(ctx, username, password, privilegeCache))
assert.Equal(t, 1, invokedCount) assert.Equal(t, 2, invokedCount)
} }
func Test_isCollectionIsLoaded(t *testing.T) { func Test_isCollectionIsLoaded(t *testing.T) {