enhance: Refactor privilege management by extracting privilege cache into separate package (#44762)

Related to #44761

This commit refactors the privilege management system in the proxy
component by:

1. **Separation of Concerns**: Extracts privilege-related functionality
from MetaCache into a dedicated `internal/proxy/privilege` package,
improving code organization and maintainability.

2. **New Package Structure**: Creates `internal/proxy/privilege/` with:
   - `cache.go`: Core privilege cache implementation (PrivilegeCache)
   - `result_cache.go`: Privilege enforcement result caching
   - `model.go`: Casbin model and policy enforcement functions
   - `meta_cache_adapter.go`: Casbin adapter for MetaCache integration
   - Corresponding test files and mock implementations

3. **MetaCache Simplification**: Removes privilege and credential
management methods from MetaCache interface and implementation:
   - Removed: GetCredentialInfo, RemoveCredential, UpdateCredential
- Removed: GetPrivilegeInfo, GetUserRole, RefreshPolicyInfo,
InitPolicyInfo
   - Deleted: meta_cache_adapter.go, privilege_cache.go and their tests

4. **Updated References**: Updates all callsites to use the new
privilegeCache global:
- Authentication interceptor now uses privilegeCache for password
verification
- Credential cache operations (InvalidateCredentialCache,
UpdateCredentialCache, UpdateCredential) now use privilegeCache
- Policy refresh operations (RefreshPolicyInfoCache) now use
privilegeCache
- Privilege interceptor uses new privilege.GetEnforcer() and privilege
result cache

5. **Improved API**: Renames cache functions for clarity:
   - GetPrivilegeCache → GetResultCache
   - SetPrivilegeCache → SetResultCache
   - CleanPrivilegeCache → CleanResultCache

This refactoring makes the codebase more modular, separates privilege
management concerns from general metadata caching, and provides a
clearer API for privilege enforcement operations.

---------

Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
This commit is contained in:
congqixia 2025-10-13 11:15:58 +08:00 committed by GitHub
parent 369c6eb206
commit f5f053f1d2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
21 changed files with 931 additions and 421 deletions

View File

@ -13,6 +13,7 @@ import (
"google.golang.org/grpc/status"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/internal/proxy/privilege"
"github.com/milvus-io/milvus/internal/util/hookutil"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/metrics"
@ -111,7 +112,7 @@ func AuthenticationInterceptor(ctx context.Context) (context.Context, error) {
} else {
// username+password authentication
username, password := parseMD(rawToken)
if !passwordVerify(ctx, username, password, globalMetaCache) {
if !passwordVerify(ctx, username, password, privilege.GetPrivilegeCache()) {
log.Warn("fail to verify password", zap.String("username", username))
// NOTE: don't use the merr, because it will cause the wrong retry behavior in the sdk
return nil, status.Error(codes.Unauthenticated, "auth check failure, please check username and password are correct")

View File

@ -9,6 +9,7 @@ import (
"github.com/stretchr/testify/assert"
"google.golang.org/grpc/metadata"
"github.com/milvus-io/milvus/internal/proxy/privilege"
"github.com/milvus-io/milvus/internal/util/hookutil"
"github.com/milvus-io/milvus/pkg/v2/util"
"github.com/milvus-io/milvus/pkg/v2/util/crypto"
@ -27,7 +28,7 @@ func TestValidAuth(t *testing.T) {
if username == "" || password == "" {
return false
}
return passwordVerify(ctx, username, password, globalMetaCache)
return passwordVerify(ctx, username, password, privilege.GetPrivilegeCache())
}
ctx := context.Background()

View File

@ -43,6 +43,7 @@ import (
"github.com/milvus-io/milvus/internal/distributed/streaming"
"github.com/milvus-io/milvus/internal/http"
"github.com/milvus-io/milvus/internal/proxy/connection"
"github.com/milvus-io/milvus/internal/proxy/privilege"
"github.com/milvus-io/milvus/internal/proxy/replicate"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/analyzer"
@ -4689,8 +4690,9 @@ func (node *Proxy) InvalidateCredentialCache(ctx context.Context, request *proxy
}
username := request.Username
if globalMetaCache != nil {
globalMetaCache.RemoveCredential(username) // no need to return error, though credential may be not cached
priCache := privilege.GetPrivilegeCache()
if priCache != nil {
priCache.RemoveCredential(username) // no need to return error, though credential may be not cached
}
log.Debug("complete to invalidate credential cache")
@ -4715,8 +4717,9 @@ func (node *Proxy) UpdateCredentialCache(ctx context.Context, request *proxypb.U
Username: request.Username,
Sha256Password: request.Password,
}
if globalMetaCache != nil {
globalMetaCache.UpdateCredential(credInfo) // no need to return error, though credential may be not cached
priCache := privilege.GetPrivilegeCache()
if priCache != nil {
priCache.UpdateCredential(credInfo) // no need to return error, though credential may be not cached
}
log.Debug("complete to update credential cache")
@ -4820,7 +4823,7 @@ func (node *Proxy) UpdateCredential(ctx context.Context, req *milvuspb.UpdateCre
}
}
if !skipPasswordVerify && !passwordVerify(ctx, req.Username, rawOldPassword, globalMetaCache) {
if !skipPasswordVerify && !passwordVerify(ctx, req.Username, rawOldPassword, privilege.GetPrivilegeCache()) {
err := merr.WrapErrPrivilegeNotAuthenticated("old password not correct for %s", req.GetUsername())
return merr.Status(err), nil
}
@ -5347,8 +5350,9 @@ func (node *Proxy) RefreshPolicyInfoCache(ctx context.Context, req *proxypb.Refr
return merr.Status(err), nil
}
if globalMetaCache != nil {
err := globalMetaCache.RefreshPolicyInfo(typeutil.CacheOp{
priCache := privilege.GetPrivilegeCache()
if priCache != nil {
err := priCache.RefreshPolicyInfo(typeutil.CacheOp{
OpType: typeutil.CacheOpType(req.OpType),
OpKey: req.OpKey,
})

View File

@ -24,7 +24,6 @@ import (
"strings"
"sync"
"github.com/cockroachdb/errors"
"github.com/samber/lo"
"go.uber.org/atomic"
"go.uber.org/zap"
@ -32,6 +31,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/proxy/privilege"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/pkg/v2/common"
"github.com/milvus-io/milvus/pkg/v2/log"
@ -39,11 +39,9 @@ import (
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
"github.com/milvus-io/milvus/pkg/v2/proto/querypb"
"github.com/milvus-io/milvus/pkg/v2/proto/rootcoordpb"
"github.com/milvus-io/milvus/pkg/v2/util"
"github.com/milvus-io/milvus/pkg/v2/util/commonpbutil"
"github.com/milvus-io/milvus/pkg/v2/util/conc"
"github.com/milvus-io/milvus/pkg/v2/util/expr"
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
"github.com/milvus-io/milvus/pkg/v2/util/timerecord"
@ -79,14 +77,14 @@ type Cache interface {
RemoveCollectionsByID(ctx context.Context, collectionID UniqueID, version uint64, removeVersion bool) []string
// GetCredentialInfo operate credential cache
GetCredentialInfo(ctx context.Context, username string) (*internalpb.CredentialInfo, error)
RemoveCredential(username string)
UpdateCredential(credInfo *internalpb.CredentialInfo)
// GetCredentialInfo(ctx context.Context, username string) (*internalpb.CredentialInfo, error)
// RemoveCredential(username string)
// UpdateCredential(credInfo *internalpb.CredentialInfo)
GetPrivilegeInfo(ctx context.Context) []string
GetUserRole(username string) []string
RefreshPolicyInfo(op typeutil.CacheOp) error
InitPolicyInfo(info []string, userRoles []string)
// GetPrivilegeInfo(ctx context.Context) []string
// GetUserRole(username string) []string
// RefreshPolicyInfo(op typeutil.CacheOp) error
// InitPolicyInfo(info []string, userRoles []string)
RemoveDatabase(ctx context.Context, database string)
HasDatabase(ctx context.Context, database string) bool
@ -382,14 +380,12 @@ func InitMetaCache(ctx context.Context, mixCoord types.MixCoordClient, shardMgr
}
expr.Register("cache", globalMetaCache)
// The privilege info is a little more. And to get this info, the query operation of involving multiple table queries is required.
resp, err := mixCoord.ListPolicy(ctx, &internalpb.ListPolicyRequest{})
if err = merr.CheckRPCCall(resp, err); err != nil {
log.Error("fail to init meta cache", zap.Error(err))
err = privilege.InitPrivilegeCache(ctx, mixCoord)
if err != nil {
log.Error("failed to init privilege cache", zap.Error(err))
return err
}
globalMetaCache.InitPolicyInfo(resp.PolicyInfos, resp.UserRoles)
log.Info("success to init meta cache", zap.Strings("policy_infos", resp.PolicyInfos))
return nil
}
@ -902,55 +898,6 @@ func (m *MetaCache) RemoveCollectionsByID(ctx context.Context, collectionID Uniq
return collNames
}
// GetCredentialInfo returns the credential related to provided username
// If the cache missed, proxy will try to fetch from storage
func (m *MetaCache) GetCredentialInfo(ctx context.Context, username string) (*internalpb.CredentialInfo, error) {
m.credMut.RLock()
var credInfo *internalpb.CredentialInfo
credInfo, ok := m.credMap[username]
m.credMut.RUnlock()
if !ok {
req := &rootcoordpb.GetCredentialRequest{
Base: commonpbutil.NewMsgBase(
commonpbutil.WithMsgType(commonpb.MsgType_GetCredential),
),
Username: username,
}
resp, err := m.mixCoord.GetCredential(ctx, req)
if err != nil {
return &internalpb.CredentialInfo{}, err
}
credInfo = &internalpb.CredentialInfo{
Username: resp.Username,
EncryptedPassword: resp.Password,
}
}
return credInfo, nil
}
func (m *MetaCache) RemoveCredential(username string) {
m.credMut.Lock()
defer m.credMut.Unlock()
// delete pair in credMap
delete(m.credMap, username)
}
func (m *MetaCache) UpdateCredential(credInfo *internalpb.CredentialInfo) {
m.credMut.Lock()
defer m.credMut.Unlock()
username := credInfo.Username
_, ok := m.credMap[username]
if !ok {
m.credMap[username] = &internalpb.CredentialInfo{}
}
// Do not cache encrypted password content
m.credMap[username].Username = username
m.credMap[username].Sha256Password = credInfo.Sha256Password
}
func (m *MetaCache) GetShard(ctx context.Context, withCache bool, database, collectionName string, collectionID int64, channel string) ([]nodeInfo, error) {
method := "GetShard"
// check cache first
@ -1127,131 +1074,6 @@ func (m *MetaCache) InvalidateShardLeaderCache(collections []int64) {
}
}
func (m *MetaCache) InitPolicyInfo(info []string, userRoles []string) {
defer func() {
err := getEnforcer().LoadPolicy()
if err != nil {
log.Error("failed to load policy after RefreshPolicyInfo", zap.Error(err))
}
CleanPrivilegeCache()
}()
m.mu.Lock()
defer m.mu.Unlock()
m.unsafeInitPolicyInfo(info, userRoles)
}
func (m *MetaCache) unsafeInitPolicyInfo(info []string, userRoles []string) {
m.privilegeInfos = util.StringSet(info)
for _, userRole := range userRoles {
user, role, err := funcutil.DecodeUserRoleCache(userRole)
if err != nil {
log.Warn("invalid user-role key", zap.String("user-role", userRole), zap.Error(err))
continue
}
if m.userToRoles[user] == nil {
m.userToRoles[user] = make(map[string]struct{})
}
m.userToRoles[user][role] = struct{}{}
}
}
func (m *MetaCache) GetPrivilegeInfo(ctx context.Context) []string {
m.mu.RLock()
defer m.mu.RUnlock()
return util.StringList(m.privilegeInfos)
}
func (m *MetaCache) GetUserRole(user string) []string {
m.mu.RLock()
defer m.mu.RUnlock()
return util.StringList(m.userToRoles[user])
}
func (m *MetaCache) RefreshPolicyInfo(op typeutil.CacheOp) (err error) {
defer func() {
if err == nil {
le := getEnforcer().LoadPolicy()
if le != nil {
log.Error("failed to load policy after RefreshPolicyInfo", zap.Error(le))
}
CleanPrivilegeCache()
}
}()
if op.OpType != typeutil.CacheRefresh {
m.mu.Lock()
defer m.mu.Unlock()
if op.OpKey == "" {
return errors.New("empty op key")
}
}
switch op.OpType {
case typeutil.CacheGrantPrivilege:
keys := funcutil.PrivilegesForPolicy(op.OpKey)
for _, key := range keys {
m.privilegeInfos[key] = struct{}{}
}
case typeutil.CacheRevokePrivilege:
keys := funcutil.PrivilegesForPolicy(op.OpKey)
for _, key := range keys {
delete(m.privilegeInfos, key)
}
case typeutil.CacheAddUserToRole:
user, role, err := funcutil.DecodeUserRoleCache(op.OpKey)
if err != nil {
return fmt.Errorf("invalid opKey, fail to decode, op_type: %d, op_key: %s", int(op.OpType), op.OpKey)
}
if m.userToRoles[user] == nil {
m.userToRoles[user] = make(map[string]struct{})
}
m.userToRoles[user][role] = struct{}{}
case typeutil.CacheRemoveUserFromRole:
user, role, err := funcutil.DecodeUserRoleCache(op.OpKey)
if err != nil {
return fmt.Errorf("invalid opKey, fail to decode, op_type: %d, op_key: %s", int(op.OpType), op.OpKey)
}
if m.userToRoles[user] != nil {
delete(m.userToRoles[user], role)
}
case typeutil.CacheDeleteUser:
delete(m.userToRoles, op.OpKey)
case typeutil.CacheDropRole:
for user := range m.userToRoles {
delete(m.userToRoles[user], op.OpKey)
}
for policy := range m.privilegeInfos {
if funcutil.PolicyCheckerWithRole(policy, op.OpKey) {
delete(m.privilegeInfos, policy)
}
}
case typeutil.CacheRefresh:
resp, err := m.mixCoord.ListPolicy(context.Background(), &internalpb.ListPolicyRequest{})
if err != nil {
log.Error("fail to init meta cache", zap.Error(err))
return err
}
if !merr.Ok(resp.GetStatus()) {
log.Error("fail to init meta cache",
zap.String("error_code", resp.GetStatus().GetErrorCode().String()),
zap.String("reason", resp.GetStatus().GetReason()))
return merr.Error(resp.Status)
}
m.mu.Lock()
defer m.mu.Unlock()
m.userToRoles = make(map[string]map[string]struct{})
m.privilegeInfos = make(map[string]struct{})
m.unsafeInitPolicyInfo(resp.PolicyInfos, resp.UserRoles)
default:
return fmt.Errorf("invalid opType, op_type: %d, op_key: %s", int(op.OpType), op.OpKey)
}
return nil
}
func (m *MetaCache) RemoveDatabase(ctx context.Context, database string) {
log.Ctx(ctx).Debug("remove database", zap.String("name", database))
m.mu.Lock()

View File

@ -35,6 +35,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/internal/proxy/privilege"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/pkg/v2/common"
"github.com/milvus-io/milvus/pkg/v2/proto/datapb"
@ -1510,9 +1511,9 @@ func TestMetaCache_PolicyInfo(t *testing.T) {
}
err := InitMetaCache(context.Background(), client, mgr)
assert.NoError(t, err)
policyInfos := globalMetaCache.GetPrivilegeInfo(context.Background())
policyInfos := privilege.GetPrivilegeCache().GetPrivilegeInfo(context.Background())
assert.Equal(t, 3, len(policyInfos))
roles := globalMetaCache.GetUserRole("foo")
roles := privilege.GetPrivilegeCache().GetUserRole("foo")
assert.Equal(t, 2, len(roles))
})
@ -1527,29 +1528,29 @@ func TestMetaCache_PolicyInfo(t *testing.T) {
err := InitMetaCache(context.Background(), client, mgr)
assert.NoError(t, err)
err = globalMetaCache.RefreshPolicyInfo(typeutil.CacheOp{OpType: typeutil.CacheGrantPrivilege, OpKey: "policyX"})
err = privilege.GetPrivilegeCache().RefreshPolicyInfo(typeutil.CacheOp{OpType: typeutil.CacheGrantPrivilege, OpKey: "policyX"})
assert.NoError(t, err)
policyInfos := globalMetaCache.GetPrivilegeInfo(context.Background())
policyInfos := privilege.GetPrivilegeCache().GetPrivilegeInfo(context.Background())
assert.Equal(t, 4, len(policyInfos))
err = globalMetaCache.RefreshPolicyInfo(typeutil.CacheOp{OpType: typeutil.CacheRevokePrivilege, OpKey: "policyX"})
err = privilege.GetPrivilegeCache().RefreshPolicyInfo(typeutil.CacheOp{OpType: typeutil.CacheRevokePrivilege, OpKey: "policyX"})
assert.NoError(t, err)
policyInfos = globalMetaCache.GetPrivilegeInfo(context.Background())
policyInfos = privilege.GetPrivilegeCache().GetPrivilegeInfo(context.Background())
assert.Equal(t, 3, len(policyInfos))
err = globalMetaCache.RefreshPolicyInfo(typeutil.CacheOp{OpType: typeutil.CacheAddUserToRole, OpKey: funcutil.EncodeUserRoleCache("foo", "role3")})
err = privilege.GetPrivilegeCache().RefreshPolicyInfo(typeutil.CacheOp{OpType: typeutil.CacheAddUserToRole, OpKey: funcutil.EncodeUserRoleCache("foo", "role3")})
assert.NoError(t, err)
roles := globalMetaCache.GetUserRole("foo")
roles := privilege.GetPrivilegeCache().GetUserRole("foo")
assert.Equal(t, 3, len(roles))
err = globalMetaCache.RefreshPolicyInfo(typeutil.CacheOp{OpType: typeutil.CacheRemoveUserFromRole, OpKey: funcutil.EncodeUserRoleCache("foo", "role3")})
err = privilege.GetPrivilegeCache().RefreshPolicyInfo(typeutil.CacheOp{OpType: typeutil.CacheRemoveUserFromRole, OpKey: funcutil.EncodeUserRoleCache("foo", "role3")})
assert.NoError(t, err)
roles = globalMetaCache.GetUserRole("foo")
roles = privilege.GetPrivilegeCache().GetUserRole("foo")
assert.Equal(t, 2, len(roles))
err = globalMetaCache.RefreshPolicyInfo(typeutil.CacheOp{OpType: typeutil.CacheGrantPrivilege, OpKey: ""})
err = privilege.GetPrivilegeCache().RefreshPolicyInfo(typeutil.CacheOp{OpType: typeutil.CacheGrantPrivilege, OpKey: ""})
assert.Error(t, err)
err = globalMetaCache.RefreshPolicyInfo(typeutil.CacheOp{OpType: 100, OpKey: "policyX"})
err = privilege.GetPrivilegeCache().RefreshPolicyInfo(typeutil.CacheOp{OpType: 100, OpKey: "policyX"})
assert.Error(t, err)
})
@ -1568,18 +1569,18 @@ func TestMetaCache_PolicyInfo(t *testing.T) {
err := InitMetaCache(context.Background(), client, mgr)
assert.NoError(t, err)
err = globalMetaCache.RefreshPolicyInfo(typeutil.CacheOp{OpType: typeutil.CacheDeleteUser, OpKey: "foo"})
err = privilege.GetPrivilegeCache().RefreshPolicyInfo(typeutil.CacheOp{OpType: typeutil.CacheDeleteUser, OpKey: "foo"})
assert.NoError(t, err)
roles := globalMetaCache.GetUserRole("foo")
roles := privilege.GetPrivilegeCache().GetUserRole("foo")
assert.Len(t, roles, 0)
roles = globalMetaCache.GetUserRole("foo2")
roles = privilege.GetPrivilegeCache().GetUserRole("foo2")
assert.Len(t, roles, 2)
err = globalMetaCache.RefreshPolicyInfo(typeutil.CacheOp{OpType: typeutil.CacheDropRole, OpKey: "role2"})
err = privilege.GetPrivilegeCache().RefreshPolicyInfo(typeutil.CacheOp{OpType: typeutil.CacheDropRole, OpKey: "role2"})
assert.NoError(t, err)
roles = globalMetaCache.GetUserRole("foo2")
roles = privilege.GetPrivilegeCache().GetUserRole("foo2")
assert.Len(t, roles, 1)
assert.Equal(t, "role3", roles[0])
@ -1590,9 +1591,9 @@ func TestMetaCache_PolicyInfo(t *testing.T) {
UserRoles: []string{funcutil.EncodeUserRoleCache("foo", "role1"), funcutil.EncodeUserRoleCache("foo", "role2"), funcutil.EncodeUserRoleCache("foo2", "role2"), funcutil.EncodeUserRoleCache("foo2", "role3")},
}, nil
}
err = globalMetaCache.RefreshPolicyInfo(typeutil.CacheOp{OpType: typeutil.CacheRefresh})
err = privilege.GetPrivilegeCache().RefreshPolicyInfo(typeutil.CacheOp{OpType: typeutil.CacheRefresh})
assert.NoError(t, err)
roles = globalMetaCache.GetUserRole("foo")
roles = privilege.GetPrivilegeCache().GetUserRole("foo")
assert.Len(t, roles, 2)
})
}

View File

@ -22,24 +22,29 @@
package proxy
import (
"context"
"github.com/cockroachdb/errors"
"github.com/stretchr/testify/mock"
"github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/internal/proxy/privilege"
"github.com/milvus-io/milvus/pkg/v2/common"
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
func AddRootUserToAdminRole() {
err := globalMetaCache.RefreshPolicyInfo(typeutil.CacheOp{OpType: typeutil.CacheAddUserToRole, OpKey: funcutil.EncodeUserRoleCache("root", "admin")})
err := privilege.GetPrivilegeCache().RefreshPolicyInfo(typeutil.CacheOp{OpType: typeutil.CacheAddUserToRole, OpKey: funcutil.EncodeUserRoleCache("root", "admin")})
if err != nil {
panic(err)
}
}
func RemoveRootUserFromAdminRole() {
err := globalMetaCache.RefreshPolicyInfo(typeutil.CacheOp{OpType: typeutil.CacheRemoveUserFromRole, OpKey: funcutil.EncodeUserRoleCache("root", "admin")})
err := privilege.GetPrivilegeCache().RefreshPolicyInfo(typeutil.CacheOp{OpType: typeutil.CacheRemoveUserFromRole, OpKey: funcutil.EncodeUserRoleCache("root", "admin")})
if err != nil {
panic(err)
}
@ -55,6 +60,8 @@ func InitEmptyGlobalCache() {
if err != nil {
panic(err)
}
mixcoord.EXPECT().ListPolicy(mock.Anything, mock.Anything, mock.Anything).Return(&internalpb.ListPolicyResponse{Status: merr.Success()}, nil)
privilege.InitPrivilegeCache(context.Background(), mixcoord)
}
func SetGlobalMetaCache(metaCache *MetaCache) {

View File

@ -0,0 +1,7 @@
reviewers:
- congqixia
- czs007
- shaoting-huang
approvers:
- maintainers

View File

@ -0,0 +1,3 @@
# Summary
this package contains privilege related components for proxy.

View File

@ -0,0 +1,267 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package privilege
import (
"context"
"fmt"
"sync"
"sync/atomic"
"github.com/cockroachdb/errors"
"go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
"github.com/milvus-io/milvus/pkg/v2/proto/rootcoordpb"
"github.com/milvus-io/milvus/pkg/v2/util"
"github.com/milvus-io/milvus/pkg/v2/util/commonpbutil"
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
var cacheInst atomic.Pointer[privilegeCache]
func GetPrivilegeCache() *privilegeCache {
return cacheInst.Load()
}
type PrivilegeCache interface {
// GetCredentialInfo operate credential cache
GetCredentialInfo(ctx context.Context, username string) (*internalpb.CredentialInfo, error)
RemoveCredential(username string)
UpdateCredential(credInfo *internalpb.CredentialInfo)
GetPrivilegeInfo(ctx context.Context) []string
GetUserRole(username string) []string
RefreshPolicyInfo(op typeutil.CacheOp) error
InitPolicyInfo(info []string, userRoles []string)
}
var _ PrivilegeCache = (*privilegeCache)(nil)
type privilegeCache struct {
mixCoord types.MixCoordClient
mu sync.RWMutex
privilegeInfos map[string]struct{} // privileges cache
userToRoles map[string]map[string]struct{} // user to role cache
credMut sync.RWMutex
credMap map[string]*internalpb.CredentialInfo
}
func InitPrivilegeCache(ctx context.Context, mixCoord types.MixCoordClient) error {
privilegeCache := NewPrivilegeCache(mixCoord)
// The privilege info is a little more. And to get this info, the query operation of involving multiple table queries is required.
cacheInst.Store(privilegeCache)
resp, err := mixCoord.ListPolicy(ctx, &internalpb.ListPolicyRequest{})
if err = merr.CheckRPCCall(resp, err); err != nil {
log.Error("fail to init meta cache", zap.Error(err))
return err
}
privilegeCache.InitPolicyInfo(resp.PolicyInfos, resp.UserRoles)
log.Info("success to init privilege cache", zap.Strings("policy_infos", resp.PolicyInfos))
return nil
}
func NewPrivilegeCache(mixCoord types.MixCoordClient) *privilegeCache {
return &privilegeCache{
mixCoord: mixCoord,
privilegeInfos: make(map[string]struct{}),
userToRoles: make(map[string]map[string]struct{}),
credMap: make(map[string]*internalpb.CredentialInfo),
}
}
// GetCredentialInfo returns the credential related to provided username
// If the cache missed, proxy will try to fetch from storage
func (m *privilegeCache) GetCredentialInfo(ctx context.Context, username string) (*internalpb.CredentialInfo, error) {
m.credMut.RLock()
var credInfo *internalpb.CredentialInfo
credInfo, ok := m.credMap[username]
m.credMut.RUnlock()
if !ok {
req := &rootcoordpb.GetCredentialRequest{
Base: commonpbutil.NewMsgBase(
commonpbutil.WithMsgType(commonpb.MsgType_GetCredential),
),
Username: username,
}
resp, err := m.mixCoord.GetCredential(ctx, req)
if err != nil {
return &internalpb.CredentialInfo{}, err
}
credInfo = &internalpb.CredentialInfo{
Username: resp.Username,
EncryptedPassword: resp.Password,
}
}
return credInfo, nil
}
func (m *privilegeCache) RemoveCredential(username string) {
m.credMut.Lock()
defer m.credMut.Unlock()
// delete pair in credMap
delete(m.credMap, username)
}
func (m *privilegeCache) UpdateCredential(credInfo *internalpb.CredentialInfo) {
m.credMut.Lock()
defer m.credMut.Unlock()
username := credInfo.Username
_, ok := m.credMap[username]
if !ok {
m.credMap[username] = &internalpb.CredentialInfo{}
}
// Do not cache encrypted password content
m.credMap[username].Username = username
m.credMap[username].Sha256Password = credInfo.Sha256Password
}
func (m *privilegeCache) InitPolicyInfo(info []string, userRoles []string) {
defer func() {
err := GetEnforcer().LoadPolicy()
if err != nil {
log.Error("failed to load policy after RefreshPolicyInfo", zap.Error(err))
}
CleanPrivilegeCache()
}()
m.mu.Lock()
defer m.mu.Unlock()
m.unsafeInitPolicyInfo(info, userRoles)
}
func (m *privilegeCache) unsafeInitPolicyInfo(info []string, userRoles []string) {
m.privilegeInfos = util.StringSet(info)
for _, userRole := range userRoles {
user, role, err := funcutil.DecodeUserRoleCache(userRole)
if err != nil {
log.Warn("invalid user-role key", zap.String("user-role", userRole), zap.Error(err))
continue
}
if m.userToRoles[user] == nil {
m.userToRoles[user] = make(map[string]struct{})
}
m.userToRoles[user][role] = struct{}{}
}
}
func (m *privilegeCache) GetPrivilegeInfo(ctx context.Context) []string {
m.mu.RLock()
defer m.mu.RUnlock()
return util.StringList(m.privilegeInfos)
}
func (m *privilegeCache) GetUserRole(user string) []string {
m.mu.RLock()
defer m.mu.RUnlock()
return util.StringList(m.userToRoles[user])
}
func (m *privilegeCache) RefreshPolicyInfo(op typeutil.CacheOp) (err error) {
defer func() {
if err == nil {
le := GetEnforcer().LoadPolicy()
if le != nil {
log.Error("failed to load policy after RefreshPolicyInfo", zap.Error(le))
}
CleanPrivilegeCache()
}
}()
if op.OpType != typeutil.CacheRefresh {
m.mu.Lock()
defer m.mu.Unlock()
if op.OpKey == "" {
return errors.New("empty op key")
}
}
switch op.OpType {
case typeutil.CacheGrantPrivilege:
keys := funcutil.PrivilegesForPolicy(op.OpKey)
for _, key := range keys {
m.privilegeInfos[key] = struct{}{}
}
case typeutil.CacheRevokePrivilege:
keys := funcutil.PrivilegesForPolicy(op.OpKey)
for _, key := range keys {
delete(m.privilegeInfos, key)
}
case typeutil.CacheAddUserToRole:
user, role, err := funcutil.DecodeUserRoleCache(op.OpKey)
if err != nil {
return fmt.Errorf("invalid opKey, fail to decode, op_type: %d, op_key: %s", int(op.OpType), op.OpKey)
}
if m.userToRoles[user] == nil {
m.userToRoles[user] = make(map[string]struct{})
}
m.userToRoles[user][role] = struct{}{}
case typeutil.CacheRemoveUserFromRole:
user, role, err := funcutil.DecodeUserRoleCache(op.OpKey)
if err != nil {
return fmt.Errorf("invalid opKey, fail to decode, op_type: %d, op_key: %s", int(op.OpType), op.OpKey)
}
if m.userToRoles[user] != nil {
delete(m.userToRoles[user], role)
}
case typeutil.CacheDeleteUser:
delete(m.userToRoles, op.OpKey)
case typeutil.CacheDropRole:
for user := range m.userToRoles {
delete(m.userToRoles[user], op.OpKey)
}
for policy := range m.privilegeInfos {
if funcutil.PolicyCheckerWithRole(policy, op.OpKey) {
delete(m.privilegeInfos, policy)
}
}
case typeutil.CacheRefresh:
resp, err := m.mixCoord.ListPolicy(context.Background(), &internalpb.ListPolicyRequest{})
if err != nil {
log.Error("fail to init meta cache", zap.Error(err))
return err
}
if !merr.Ok(resp.GetStatus()) {
log.Error("fail to init meta cache",
zap.String("error_code", resp.GetStatus().GetErrorCode().String()),
zap.String("reason", resp.GetStatus().GetReason()))
return merr.Error(resp.Status)
}
m.mu.Lock()
defer m.mu.Unlock()
m.userToRoles = make(map[string]map[string]struct{})
m.privilegeInfos = make(map[string]struct{})
m.unsafeInitPolicyInfo(resp.PolicyInfos, resp.UserRoles)
default:
return fmt.Errorf("invalid opType, op_type: %d, op_key: %s", int(op.OpType), op.OpKey)
}
return nil
}

View File

@ -0,0 +1,27 @@
//go:build test
// +build test
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package privilege
// This file contains only functions used in tests to manipulate the privilege cache.
// ResetPrivilegeCacheForTest resets the privilege cache for testing purposes.
func ResetPrivilegeCacheForTest() {
cacheInst.Store(nil)
}

View File

@ -14,7 +14,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package proxy
package privilege
import (
"context"
@ -30,10 +30,10 @@ import (
// MetaCacheCasbinAdapter is the implementation of `persist.Adapter` with Cache
// Since the usage shall be read-only, it implements only `LoadPolicy` for now.
type MetaCacheCasbinAdapter struct {
cacheSource func() Cache
cacheSource func() PrivilegeCache
}
func NewMetaCacheCasbinAdapter(cacheSource func() Cache) *MetaCacheCasbinAdapter {
func NewMetaCacheCasbinAdapter(cacheSource func() PrivilegeCache) *MetaCacheCasbinAdapter {
return &MetaCacheCasbinAdapter{
cacheSource: cacheSource,
}

View File

@ -14,7 +14,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package proxy
package privilege
import (
"testing"
@ -26,36 +26,36 @@ import (
type MetaCacheCasbinAdapterSuite struct {
suite.Suite
cache *MockCache
cache *MockPrivilegeCache
adapter *MetaCacheCasbinAdapter
}
func (s *MetaCacheCasbinAdapterSuite) SetupTest() {
s.cache = NewMockCache(s.T())
s.cache = NewMockPrivilegeCache(s.T())
s.adapter = NewMetaCacheCasbinAdapter(func() Cache { return s.cache })
s.adapter = NewMetaCacheCasbinAdapter(func() PrivilegeCache { return s.cache })
}
func (s *MetaCacheCasbinAdapterSuite) TestLoadPolicy() {
s.Run("normal_load", func() {
s.cache.EXPECT().GetPrivilegeInfo(mock.Anything).Return([]string{})
m := getPolicyModel(ModelStr)
m := GetPolicyModel(ModelStr)
err := s.adapter.LoadPolicy(m)
s.NoError(err)
})
s.Run("source_return_nil", func() {
adapter := NewMetaCacheCasbinAdapter(func() Cache { return nil })
adapter := NewMetaCacheCasbinAdapter(func() PrivilegeCache { return nil })
m := getPolicyModel(ModelStr)
m := GetPolicyModel(ModelStr)
err := adapter.LoadPolicy(m)
s.Error(err)
})
}
func (s *MetaCacheCasbinAdapterSuite) TestSavePolicy() {
m := getPolicyModel(ModelStr)
m := GetPolicyModel(ModelStr)
s.Error(s.adapter.SavePolicy(m))
}

View File

@ -0,0 +1,340 @@
// Code generated by mockery v2.53.3. DO NOT EDIT.
package privilege
import (
context "context"
internalpb "github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
mock "github.com/stretchr/testify/mock"
typeutil "github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
// MockPrivilegeCache is an autogenerated mock type for the PrivilegeCache type
type MockPrivilegeCache struct {
mock.Mock
}
type MockPrivilegeCache_Expecter struct {
mock *mock.Mock
}
func (_m *MockPrivilegeCache) EXPECT() *MockPrivilegeCache_Expecter {
return &MockPrivilegeCache_Expecter{mock: &_m.Mock}
}
// GetCredentialInfo provides a mock function with given fields: ctx, username
func (_m *MockPrivilegeCache) GetCredentialInfo(ctx context.Context, username string) (*internalpb.CredentialInfo, error) {
ret := _m.Called(ctx, username)
if len(ret) == 0 {
panic("no return value specified for GetCredentialInfo")
}
var r0 *internalpb.CredentialInfo
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string) (*internalpb.CredentialInfo, error)); ok {
return rf(ctx, username)
}
if rf, ok := ret.Get(0).(func(context.Context, string) *internalpb.CredentialInfo); ok {
r0 = rf(ctx, username)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*internalpb.CredentialInfo)
}
}
if rf, ok := ret.Get(1).(func(context.Context, string) error); ok {
r1 = rf(ctx, username)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockPrivilegeCache_GetCredentialInfo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetCredentialInfo'
type MockPrivilegeCache_GetCredentialInfo_Call struct {
*mock.Call
}
// GetCredentialInfo is a helper method to define mock.On call
// - ctx context.Context
// - username string
func (_e *MockPrivilegeCache_Expecter) GetCredentialInfo(ctx interface{}, username interface{}) *MockPrivilegeCache_GetCredentialInfo_Call {
return &MockPrivilegeCache_GetCredentialInfo_Call{Call: _e.mock.On("GetCredentialInfo", ctx, username)}
}
func (_c *MockPrivilegeCache_GetCredentialInfo_Call) Run(run func(ctx context.Context, username string)) *MockPrivilegeCache_GetCredentialInfo_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(string))
})
return _c
}
func (_c *MockPrivilegeCache_GetCredentialInfo_Call) Return(_a0 *internalpb.CredentialInfo, _a1 error) *MockPrivilegeCache_GetCredentialInfo_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockPrivilegeCache_GetCredentialInfo_Call) RunAndReturn(run func(context.Context, string) (*internalpb.CredentialInfo, error)) *MockPrivilegeCache_GetCredentialInfo_Call {
_c.Call.Return(run)
return _c
}
// GetPrivilegeInfo provides a mock function with given fields: ctx
func (_m *MockPrivilegeCache) GetPrivilegeInfo(ctx context.Context) []string {
ret := _m.Called(ctx)
if len(ret) == 0 {
panic("no return value specified for GetPrivilegeInfo")
}
var r0 []string
if rf, ok := ret.Get(0).(func(context.Context) []string); ok {
r0 = rf(ctx)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]string)
}
}
return r0
}
// MockPrivilegeCache_GetPrivilegeInfo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetPrivilegeInfo'
type MockPrivilegeCache_GetPrivilegeInfo_Call struct {
*mock.Call
}
// GetPrivilegeInfo is a helper method to define mock.On call
// - ctx context.Context
func (_e *MockPrivilegeCache_Expecter) GetPrivilegeInfo(ctx interface{}) *MockPrivilegeCache_GetPrivilegeInfo_Call {
return &MockPrivilegeCache_GetPrivilegeInfo_Call{Call: _e.mock.On("GetPrivilegeInfo", ctx)}
}
func (_c *MockPrivilegeCache_GetPrivilegeInfo_Call) Run(run func(ctx context.Context)) *MockPrivilegeCache_GetPrivilegeInfo_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context))
})
return _c
}
func (_c *MockPrivilegeCache_GetPrivilegeInfo_Call) Return(_a0 []string) *MockPrivilegeCache_GetPrivilegeInfo_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockPrivilegeCache_GetPrivilegeInfo_Call) RunAndReturn(run func(context.Context) []string) *MockPrivilegeCache_GetPrivilegeInfo_Call {
_c.Call.Return(run)
return _c
}
// GetUserRole provides a mock function with given fields: username
func (_m *MockPrivilegeCache) GetUserRole(username string) []string {
ret := _m.Called(username)
if len(ret) == 0 {
panic("no return value specified for GetUserRole")
}
var r0 []string
if rf, ok := ret.Get(0).(func(string) []string); ok {
r0 = rf(username)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]string)
}
}
return r0
}
// MockPrivilegeCache_GetUserRole_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetUserRole'
type MockPrivilegeCache_GetUserRole_Call struct {
*mock.Call
}
// GetUserRole is a helper method to define mock.On call
// - username string
func (_e *MockPrivilegeCache_Expecter) GetUserRole(username interface{}) *MockPrivilegeCache_GetUserRole_Call {
return &MockPrivilegeCache_GetUserRole_Call{Call: _e.mock.On("GetUserRole", username)}
}
func (_c *MockPrivilegeCache_GetUserRole_Call) Run(run func(username string)) *MockPrivilegeCache_GetUserRole_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(string))
})
return _c
}
func (_c *MockPrivilegeCache_GetUserRole_Call) Return(_a0 []string) *MockPrivilegeCache_GetUserRole_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockPrivilegeCache_GetUserRole_Call) RunAndReturn(run func(string) []string) *MockPrivilegeCache_GetUserRole_Call {
_c.Call.Return(run)
return _c
}
// InitPolicyInfo provides a mock function with given fields: info, userRoles
func (_m *MockPrivilegeCache) InitPolicyInfo(info []string, userRoles []string) {
_m.Called(info, userRoles)
}
// MockPrivilegeCache_InitPolicyInfo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'InitPolicyInfo'
type MockPrivilegeCache_InitPolicyInfo_Call struct {
*mock.Call
}
// InitPolicyInfo is a helper method to define mock.On call
// - info []string
// - userRoles []string
func (_e *MockPrivilegeCache_Expecter) InitPolicyInfo(info interface{}, userRoles interface{}) *MockPrivilegeCache_InitPolicyInfo_Call {
return &MockPrivilegeCache_InitPolicyInfo_Call{Call: _e.mock.On("InitPolicyInfo", info, userRoles)}
}
func (_c *MockPrivilegeCache_InitPolicyInfo_Call) Run(run func(info []string, userRoles []string)) *MockPrivilegeCache_InitPolicyInfo_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].([]string), args[1].([]string))
})
return _c
}
func (_c *MockPrivilegeCache_InitPolicyInfo_Call) Return() *MockPrivilegeCache_InitPolicyInfo_Call {
_c.Call.Return()
return _c
}
func (_c *MockPrivilegeCache_InitPolicyInfo_Call) RunAndReturn(run func([]string, []string)) *MockPrivilegeCache_InitPolicyInfo_Call {
_c.Run(run)
return _c
}
// RefreshPolicyInfo provides a mock function with given fields: op
func (_m *MockPrivilegeCache) RefreshPolicyInfo(op typeutil.CacheOp) error {
ret := _m.Called(op)
if len(ret) == 0 {
panic("no return value specified for RefreshPolicyInfo")
}
var r0 error
if rf, ok := ret.Get(0).(func(typeutil.CacheOp) error); ok {
r0 = rf(op)
} else {
r0 = ret.Error(0)
}
return r0
}
// MockPrivilegeCache_RefreshPolicyInfo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RefreshPolicyInfo'
type MockPrivilegeCache_RefreshPolicyInfo_Call struct {
*mock.Call
}
// RefreshPolicyInfo is a helper method to define mock.On call
// - op typeutil.CacheOp
func (_e *MockPrivilegeCache_Expecter) RefreshPolicyInfo(op interface{}) *MockPrivilegeCache_RefreshPolicyInfo_Call {
return &MockPrivilegeCache_RefreshPolicyInfo_Call{Call: _e.mock.On("RefreshPolicyInfo", op)}
}
func (_c *MockPrivilegeCache_RefreshPolicyInfo_Call) Run(run func(op typeutil.CacheOp)) *MockPrivilegeCache_RefreshPolicyInfo_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(typeutil.CacheOp))
})
return _c
}
func (_c *MockPrivilegeCache_RefreshPolicyInfo_Call) Return(_a0 error) *MockPrivilegeCache_RefreshPolicyInfo_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockPrivilegeCache_RefreshPolicyInfo_Call) RunAndReturn(run func(typeutil.CacheOp) error) *MockPrivilegeCache_RefreshPolicyInfo_Call {
_c.Call.Return(run)
return _c
}
// RemoveCredential provides a mock function with given fields: username
func (_m *MockPrivilegeCache) RemoveCredential(username string) {
_m.Called(username)
}
// MockPrivilegeCache_RemoveCredential_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RemoveCredential'
type MockPrivilegeCache_RemoveCredential_Call struct {
*mock.Call
}
// RemoveCredential is a helper method to define mock.On call
// - username string
func (_e *MockPrivilegeCache_Expecter) RemoveCredential(username interface{}) *MockPrivilegeCache_RemoveCredential_Call {
return &MockPrivilegeCache_RemoveCredential_Call{Call: _e.mock.On("RemoveCredential", username)}
}
func (_c *MockPrivilegeCache_RemoveCredential_Call) Run(run func(username string)) *MockPrivilegeCache_RemoveCredential_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(string))
})
return _c
}
func (_c *MockPrivilegeCache_RemoveCredential_Call) Return() *MockPrivilegeCache_RemoveCredential_Call {
_c.Call.Return()
return _c
}
func (_c *MockPrivilegeCache_RemoveCredential_Call) RunAndReturn(run func(string)) *MockPrivilegeCache_RemoveCredential_Call {
_c.Run(run)
return _c
}
// UpdateCredential provides a mock function with given fields: credInfo
func (_m *MockPrivilegeCache) UpdateCredential(credInfo *internalpb.CredentialInfo) {
_m.Called(credInfo)
}
// MockPrivilegeCache_UpdateCredential_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateCredential'
type MockPrivilegeCache_UpdateCredential_Call struct {
*mock.Call
}
// UpdateCredential is a helper method to define mock.On call
// - credInfo *internalpb.CredentialInfo
func (_e *MockPrivilegeCache_Expecter) UpdateCredential(credInfo interface{}) *MockPrivilegeCache_UpdateCredential_Call {
return &MockPrivilegeCache_UpdateCredential_Call{Call: _e.mock.On("UpdateCredential", credInfo)}
}
func (_c *MockPrivilegeCache_UpdateCredential_Call) Run(run func(credInfo *internalpb.CredentialInfo)) *MockPrivilegeCache_UpdateCredential_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(*internalpb.CredentialInfo))
})
return _c
}
func (_c *MockPrivilegeCache_UpdateCredential_Call) Return() *MockPrivilegeCache_UpdateCredential_Call {
_c.Call.Return()
return _c
}
func (_c *MockPrivilegeCache_UpdateCredential_Call) RunAndReturn(run func(*internalpb.CredentialInfo)) *MockPrivilegeCache_UpdateCredential_Call {
_c.Run(run)
return _c
}
// NewMockPrivilegeCache creates a new instance of MockPrivilegeCache. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewMockPrivilegeCache(t interface {
mock.TestingT
Cleanup(func())
}) *MockPrivilegeCache {
mock := &MockPrivilegeCache{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}

View File

@ -0,0 +1,139 @@
package privilege
import (
"log"
"strings"
"sync"
"github.com/casbin/casbin/v2"
"github.com/casbin/casbin/v2/model"
"github.com/samber/lo"
"go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus/pkg/v2/util"
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
)
const (
// sub -> role name, like admin, public
// obj -> contact object with object name, like Global-*, Collection-col1
// act -> privilege, like CreateCollection, DescribeCollection
ModelStr = `
[request_definition]
r = sub, obj, act
[policy_definition]
p = sub, obj, act
[policy_effect]
e = some(where (p.eft == allow))
[matchers]
m = r.sub == p.sub && globMatch(r.obj, p.obj) && globMatch(r.act, p.act) || r.sub == "admin" || (r.sub == p.sub && dbMatch(r.obj, p.obj) && privilegeGroupContains(r.act, p.act, r.obj, p.obj))
`
)
var (
enforcer *casbin.SyncedEnforcer
initOnce sync.Once
initPrivilegeGroupsOnce sync.Once
)
func GetPolicyModel(modelString string) model.Model {
m, err := model.NewModelFromString(modelString)
if err != nil {
log.Panic("NewModelFromString fail", zap.String("model", ModelStr), zap.Error(err))
}
return m
}
func InitPrivilegeGroups() {
initPrivilegeGroupsOnce.Do(func() {
roGroup := paramtable.Get().CommonCfg.ReadOnlyPrivileges.GetAsStrings()
if len(roGroup) == 0 {
roGroup = util.ReadOnlyPrivilegeGroup
}
roPrivileges = lo.SliceToMap(roGroup, func(item string) (string, struct{}) { return item, struct{}{} })
rwGroup := paramtable.Get().CommonCfg.ReadWritePrivileges.GetAsStrings()
if len(rwGroup) == 0 {
rwGroup = util.ReadWritePrivilegeGroup
}
rwPrivileges = lo.SliceToMap(rwGroup, func(item string) (string, struct{}) { return item, struct{}{} })
adminGroup := paramtable.Get().CommonCfg.AdminPrivileges.GetAsStrings()
if len(adminGroup) == 0 {
adminGroup = util.AdminPrivilegeGroup
}
adminPrivileges = lo.SliceToMap(adminGroup, func(item string) (string, struct{}) { return item, struct{}{} })
})
}
func GetEnforcer() *casbin.SyncedEnforcer {
initOnce.Do(func() {
e, err := casbin.NewSyncedEnforcer()
if err != nil {
log.Panic("failed to create casbin enforcer", zap.Error(err))
}
casbinModel := GetPolicyModel(ModelStr)
adapter := NewMetaCacheCasbinAdapter(func() PrivilegeCache { return GetPrivilegeCache() })
e.InitWithModelAndAdapter(casbinModel, adapter)
e.AddFunction("dbMatch", DBMatchFunc)
e.AddFunction("privilegeGroupContains", PrivilegeGroupContains)
enforcer = e
})
return enforcer
}
var roPrivileges, rwPrivileges, adminPrivileges map[string]struct{}
func DBMatchFunc(args ...interface{}) (interface{}, error) {
name1 := args[0].(string)
name2 := args[1].(string)
db1, _ := funcutil.SplitObjectName(name1[strings.Index(name1, "-")+1:])
db2, _ := funcutil.SplitObjectName(name2[strings.Index(name2, "-")+1:])
return db1 == db2, nil
}
func PrivilegeGroupContains(args ...interface{}) (interface{}, error) {
requestPrivilege := args[0].(string)
policyPrivilege := args[1].(string)
requestObj := args[2].(string)
policyObj := args[3].(string)
switch policyPrivilege {
case commonpb.ObjectPrivilege_PrivilegeAll.String():
return true, nil
case commonpb.ObjectPrivilege_PrivilegeGroupReadOnly.String():
// read only belong to collection object
if !collMatch(requestObj, policyObj) {
return false, nil
}
_, ok := roPrivileges[requestPrivilege]
return ok, nil
case commonpb.ObjectPrivilege_PrivilegeGroupReadWrite.String():
// read write belong to collection object
if !collMatch(requestObj, policyObj) {
return false, nil
}
_, ok := rwPrivileges[requestPrivilege]
return ok, nil
case commonpb.ObjectPrivilege_PrivilegeGroupAdmin.String():
// admin belong to global object
_, ok := adminPrivileges[requestPrivilege]
return ok, nil
default:
return false, nil
}
}
func collMatch(requestObj, policyObj string) bool {
_, coll1 := funcutil.SplitObjectName(requestObj[strings.Index(requestObj, "-")+1:])
_, coll2 := funcutil.SplitObjectName(policyObj[strings.Index(policyObj, "-")+1:])
return coll1 == util.AnyWord || coll2 == util.AnyWord || coll1 == coll2
}

View File

@ -14,7 +14,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package proxy
package privilege
import (
"fmt"
@ -28,11 +28,11 @@ import (
var (
priCacheInitOnce sync.Once
priCacheMut sync.RWMutex
priCache *PrivilegeCache
priCache *resultCache
ver atomic.Int64
)
func getPriCache() *PrivilegeCache {
func getPriCache() *resultCache {
priCacheMut.RLock()
c := priCache
priCacheMut.RUnlock()
@ -41,7 +41,7 @@ func getPriCache() *PrivilegeCache {
priCacheInitOnce.Do(func() {
priCacheMut.Lock()
defer priCacheMut.Unlock()
priCache = &PrivilegeCache{
priCache = &resultCache{
version: ver.Inc(),
values: typeutil.ConcurrentMap[string, bool]{},
}
@ -57,20 +57,20 @@ func getPriCache() *PrivilegeCache {
func CleanPrivilegeCache() {
priCacheMut.Lock()
defer priCacheMut.Unlock()
priCache = &PrivilegeCache{
priCache = &resultCache{
version: ver.Inc(),
values: typeutil.ConcurrentMap[string, bool]{},
}
}
func GetPrivilegeCache(roleName, object, objectPrivilege string) (isPermit, cached bool, version int64) {
func GetResultCache(roleName, object, objectPrivilege string) (isPermit, cached bool, version int64) {
key := fmt.Sprintf("%s_%s_%s", roleName, object, objectPrivilege)
c := getPriCache()
isPermit, cached = c.values.Get(key)
return isPermit, cached, c.version
}
func SetPrivilegeCache(roleName, object, objectPrivilege string, isPermit bool, version int64) {
func SetResultCache(roleName, object, objectPrivilege string, isPermit bool, version int64) {
key := fmt.Sprintf("%s_%s_%s", roleName, object, objectPrivilege)
c := getPriCache()
if c.version == version {
@ -80,7 +80,7 @@ func SetPrivilegeCache(roleName, object, objectPrivilege string, isPermit bool,
// PrivilegeCache is a cache for privilege enforce result
// version provides version control when any policy updates
type PrivilegeCache struct {
type resultCache struct {
values typeutil.ConcurrentMap[string, bool]
version int64
}

View File

@ -14,7 +14,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package proxy
package privilege
import (
"testing"
@ -32,9 +32,9 @@ func (s *PrivilegeCacheSuite) TearDownTest() {
func (s *PrivilegeCacheSuite) TestGetPrivilege() {
// get current version
_, _, version := GetPrivilegeCache("", "", "")
SetPrivilegeCache("test-role", "test-object", "read", true, version)
SetPrivilegeCache("test-role", "test-object", "delete", false, version)
_, _, version := GetResultCache("", "", "")
SetResultCache("test-role", "test-object", "read", true, version)
SetResultCache("test-role", "test-object", "delete", false, version)
type testCase struct {
tag string
@ -51,7 +51,7 @@ func (s *PrivilegeCacheSuite) TestGetPrivilege() {
for _, tc := range testCases {
s.Run(tc.tag, func() {
isPermit, exists, _ := GetPrivilegeCache(tc.input[0], tc.input[1], tc.input[2])
isPermit, exists, _ := GetResultCache(tc.input[0], tc.input[1], tc.input[2])
s.Equal(tc.expectIsPermit, isPermit)
s.Equal(tc.expectExists, exists)
})
@ -60,12 +60,12 @@ func (s *PrivilegeCacheSuite) TestGetPrivilege() {
func (s *PrivilegeCacheSuite) TestSetPrivilegeVersion() {
// get current version
_, _, version := GetPrivilegeCache("", "", "")
_, _, version := GetResultCache("", "", "")
CleanPrivilegeCache()
SetPrivilegeCache("test-role", "test-object", "read", true, version)
SetResultCache("test-role", "test-object", "read", true, version)
isPermit, exists, nextVersion := GetPrivilegeCache("test-role", "test-object", "read")
isPermit, exists, nextVersion := GetResultCache("test-role", "test-object", "read")
s.False(isPermit)
s.False(exists)
s.NotEqual(version, nextVersion)

View File

@ -4,12 +4,9 @@ import (
"context"
"fmt"
"reflect"
"strings"
"sync"
"github.com/casbin/casbin/v2"
"github.com/casbin/casbin/v2/model"
"github.com/samber/lo"
"go.uber.org/zap"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
@ -17,36 +14,15 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/internal/proxy/privilege"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/util"
"github.com/milvus-io/milvus/pkg/v2/util/contextutil"
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
)
type PrivilegeFunc func(ctx context.Context, req interface{}) (context.Context, error)
const (
// sub -> role name, like admin, public
// obj -> contact object with object name, like Global-*, Collection-col1
// act -> privilege, like CreateCollection, DescribeCollection
ModelStr = `
[request_definition]
r = sub, obj, act
[policy_definition]
p = sub, obj, act
[policy_effect]
e = some(where (p.eft == allow))
[matchers]
m = r.sub == p.sub && globMatch(r.obj, p.obj) && globMatch(r.act, p.act) || r.sub == "admin" || (r.sub == p.sub && dbMatch(r.obj, p.obj) && privilegeGroupContains(r.act, p.act, r.obj, p.obj))
`
)
var templateModel = getPolicyModel(ModelStr)
var (
enforcer *casbin.SyncedEnforcer
initOnce sync.Once
@ -55,55 +31,9 @@ var (
var roPrivileges, rwPrivileges, adminPrivileges map[string]struct{}
func initPrivilegeGroups() {
initPrivilegeGroupsOnce.Do(func() {
roGroup := paramtable.Get().CommonCfg.ReadOnlyPrivileges.GetAsStrings()
if len(roGroup) == 0 {
roGroup = util.ReadOnlyPrivilegeGroup
}
roPrivileges = lo.SliceToMap(roGroup, func(item string) (string, struct{}) { return item, struct{}{} })
rwGroup := paramtable.Get().CommonCfg.ReadWritePrivileges.GetAsStrings()
if len(rwGroup) == 0 {
rwGroup = util.ReadWritePrivilegeGroup
}
rwPrivileges = lo.SliceToMap(rwGroup, func(item string) (string, struct{}) { return item, struct{}{} })
adminGroup := paramtable.Get().CommonCfg.AdminPrivileges.GetAsStrings()
if len(adminGroup) == 0 {
adminGroup = util.AdminPrivilegeGroup
}
adminPrivileges = lo.SliceToMap(adminGroup, func(item string) (string, struct{}) { return item, struct{}{} })
})
}
func getEnforcer() *casbin.SyncedEnforcer {
initOnce.Do(func() {
e, err := casbin.NewSyncedEnforcer()
if err != nil {
log.Panic("failed to create casbin enforcer", zap.Error(err))
}
casbinModel := getPolicyModel(ModelStr)
adapter := NewMetaCacheCasbinAdapter(func() Cache { return globalMetaCache })
e.InitWithModelAndAdapter(casbinModel, adapter)
e.AddFunction("dbMatch", DBMatchFunc)
e.AddFunction("privilegeGroupContains", PrivilegeGroupContains)
enforcer = e
})
return enforcer
}
func getPolicyModel(modelString string) model.Model {
m, err := model.NewModelFromString(modelString)
if err != nil {
log.Panic("NewModelFromString fail", zap.String("model", ModelStr), zap.Error(err))
}
return m
}
// UnaryServerInterceptor returns a new unary server interceptors that performs per-request privilege access.
func UnaryServerInterceptor(privilegeFunc PrivilegeFunc) grpc.UnaryServerInterceptor {
initPrivilegeGroups()
privilege.InitPrivilegeGroups()
return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
newCtx, err := privilegeFunc(ctx, req)
if err != nil {
@ -160,11 +90,11 @@ func PrivilegeInterceptor(ctx context.Context, req interface{}) (context.Context
zap.Int32("object_index", objectNameIndex), zap.String("object_name", objectName),
zap.Int32("object_indexs", objectNameIndexs), zap.Strings("object_names", objectNames))
e := getEnforcer()
e := privilege.GetEnforcer()
for _, roleName := range roleNames {
permitFunc := func(objectName string) (bool, error) {
object := funcutil.PolicyForResource(dbName, objectType, objectName)
isPermit, cached, version := GetPrivilegeCache(roleName, object, objectPrivilege)
isPermit, cached, version := privilege.GetResultCache(roleName, object, objectPrivilege)
if cached {
return isPermit, nil
}
@ -172,7 +102,7 @@ func PrivilegeInterceptor(ctx context.Context, req interface{}) (context.Context
if err != nil {
return false, err
}
SetPrivilegeCache(roleName, object, objectPrivilege, isPermit, version)
privilege.SetResultCache(roleName, object, objectPrivilege, isPermit, version)
return isPermit, nil
}
@ -237,52 +167,3 @@ func isSelectMyRoleGrants(req interface{}, roleNames []string) bool {
roleName := filterGrantEntity.GetRole().GetName()
return funcutil.SliceContain(roleNames, roleName)
}
func DBMatchFunc(args ...interface{}) (interface{}, error) {
name1 := args[0].(string)
name2 := args[1].(string)
db1, _ := funcutil.SplitObjectName(name1[strings.Index(name1, "-")+1:])
db2, _ := funcutil.SplitObjectName(name2[strings.Index(name2, "-")+1:])
return db1 == db2, nil
}
func collMatch(requestObj, policyObj string) bool {
_, coll1 := funcutil.SplitObjectName(requestObj[strings.Index(requestObj, "-")+1:])
_, coll2 := funcutil.SplitObjectName(policyObj[strings.Index(policyObj, "-")+1:])
return coll1 == util.AnyWord || coll2 == util.AnyWord || coll1 == coll2
}
func PrivilegeGroupContains(args ...interface{}) (interface{}, error) {
requestPrivilege := args[0].(string)
policyPrivilege := args[1].(string)
requestObj := args[2].(string)
policyObj := args[3].(string)
switch policyPrivilege {
case commonpb.ObjectPrivilege_PrivilegeAll.String():
return true, nil
case commonpb.ObjectPrivilege_PrivilegeGroupReadOnly.String():
// read only belong to collection object
if !collMatch(requestObj, policyObj) {
return false, nil
}
_, ok := roPrivileges[requestPrivilege]
return ok, nil
case commonpb.ObjectPrivilege_PrivilegeGroupReadWrite.String():
// read write belong to collection object
if !collMatch(requestObj, policyObj) {
return false, nil
}
_, ok := rwPrivileges[requestPrivilege]
return ok, nil
case commonpb.ObjectPrivilege_PrivilegeGroupAdmin.String():
// admin belong to global object
_, ok := adminPrivileges[requestPrivilege]
return ok, nil
default:
return false, nil
}
}

View File

@ -10,6 +10,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/internal/proxy/privilege"
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
"github.com/milvus-io/milvus/pkg/v2/util"
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
@ -167,7 +168,7 @@ func TestPrivilegeInterceptor(t *testing.T) {
g.Wait()
assert.Panics(t, func() {
getPolicyModel("foo")
privilege.GetPolicyModel("foo")
})
})
}
@ -268,7 +269,7 @@ func TestPrivilegeGroup(t *testing.T) {
t.Run("grant ReadOnly to single collection", func(t *testing.T) {
paramtable.Get().Save(Params.CommonCfg.AuthorizationEnabled.Key, "true")
initPrivilegeGroups()
privilege.InitPrivilegeGroups()
var err error
ctx = GetContext(context.Background(), "fooo:123456")
@ -287,7 +288,7 @@ func TestPrivilegeGroup(t *testing.T) {
}, nil
}
InitMetaCache(ctx, client, mgr)
defer CleanPrivilegeCache()
defer privilege.CleanPrivilegeCache()
_, err = PrivilegeInterceptor(GetContext(context.Background(), "fooo:123456"), &milvuspb.QueryRequest{
CollectionName: "coll1",
@ -325,7 +326,7 @@ func TestPrivilegeGroup(t *testing.T) {
t.Run("grant ReadOnly to all collection", func(t *testing.T) {
paramtable.Get().Save(Params.CommonCfg.AuthorizationEnabled.Key, "true")
initPrivilegeGroups()
privilege.InitPrivilegeGroups()
var err error
ctx = GetContext(context.Background(), "fooo:123456")
@ -344,7 +345,7 @@ func TestPrivilegeGroup(t *testing.T) {
}, nil
}
InitMetaCache(ctx, client, mgr)
defer CleanPrivilegeCache()
defer privilege.CleanPrivilegeCache()
_, err = PrivilegeInterceptor(GetContext(context.Background(), "fooo:123456"), &milvuspb.QueryRequest{
CollectionName: "coll1",
@ -382,7 +383,7 @@ func TestPrivilegeGroup(t *testing.T) {
t.Run("grant ReadWrite to single collection", func(t *testing.T) {
paramtable.Get().Save(Params.CommonCfg.AuthorizationEnabled.Key, "true")
initPrivilegeGroups()
privilege.InitPrivilegeGroups()
var err error
ctx = GetContext(context.Background(), "fooo:123456")
@ -401,7 +402,7 @@ func TestPrivilegeGroup(t *testing.T) {
}, nil
}
InitMetaCache(ctx, client, mgr)
defer CleanPrivilegeCache()
defer privilege.CleanPrivilegeCache()
_, err = PrivilegeInterceptor(GetContext(context.Background(), "fooo:123456"), &milvuspb.QueryRequest{
CollectionName: "coll1",
@ -485,7 +486,7 @@ func TestPrivilegeGroup(t *testing.T) {
t.Run("grant ReadWrite to all collection", func(t *testing.T) {
paramtable.Get().Save(Params.CommonCfg.AuthorizationEnabled.Key, "true")
initPrivilegeGroups()
privilege.InitPrivilegeGroups()
var err error
ctx = GetContext(context.Background(), "fooo:123456")
@ -504,7 +505,7 @@ func TestPrivilegeGroup(t *testing.T) {
}, nil
}
InitMetaCache(ctx, client, mgr)
defer CleanPrivilegeCache()
defer privilege.CleanPrivilegeCache()
_, err = PrivilegeInterceptor(GetContext(context.Background(), "fooo:123456"), &milvuspb.QueryRequest{
CollectionName: "coll1",
@ -552,7 +553,7 @@ func TestPrivilegeGroup(t *testing.T) {
t.Run("Admin", func(t *testing.T) {
paramtable.Get().Save(Params.CommonCfg.AuthorizationEnabled.Key, "true")
initPrivilegeGroups()
privilege.InitPrivilegeGroups()
var err error
ctx = GetContext(context.Background(), "fooo:123456")
@ -571,7 +572,7 @@ func TestPrivilegeGroup(t *testing.T) {
}, nil
}
InitMetaCache(ctx, client, mgr)
defer CleanPrivilegeCache()
defer privilege.CleanPrivilegeCache()
_, err = PrivilegeInterceptor(GetContext(context.Background(), "fooo:123456"), &milvuspb.QueryRequest{})
assert.NoError(t, err)
@ -593,7 +594,7 @@ func TestPrivilegeGroup(t *testing.T) {
func TestBuiltinPrivilegeGroup(t *testing.T) {
t.Run("ClusterAdmin", func(t *testing.T) {
paramtable.Get().Save(Params.CommonCfg.AuthorizationEnabled.Key, "true")
initPrivilegeGroups()
privilege.InitPrivilegeGroups()
var err error
ctx := GetContext(context.Background(), "fooo:123456")
@ -615,7 +616,7 @@ func TestBuiltinPrivilegeGroup(t *testing.T) {
}, nil
}
InitMetaCache(ctx, client, mgr)
defer CleanPrivilegeCache()
defer privilege.CleanPrivilegeCache()
_, err = PrivilegeInterceptor(GetContext(context.Background(), "fooo:123456"), &milvuspb.SelectUserRequest{})
assert.NoError(t, err)

View File

@ -51,6 +51,7 @@ import (
grpcstreamingnode "github.com/milvus-io/milvus/internal/distributed/streamingnode"
"github.com/milvus-io/milvus/internal/json"
"github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/internal/proxy/privilege"
"github.com/milvus-io/milvus/internal/util/componentutil"
"github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/internal/util/sessionutil"
@ -2837,7 +2838,7 @@ func TestProxy(t *testing.T) {
getResp, err := rootCoordClient.GetCredential(ctx, getCredentialReq)
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, getResp.GetStatus().GetErrorCode())
assert.True(t, passwordVerify(ctx, username, newPassword, globalMetaCache))
assert.True(t, passwordVerify(ctx, username, newPassword, privilege.GetPrivilegeCache()))
getCredentialReq.Username = "("
getResp, err = rootCoordClient.GetCredential(ctx, getCredentialReq)

View File

@ -38,6 +38,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/json"
"github.com/milvus-io/milvus/internal/parser/planparserv2"
"github.com/milvus-io/milvus/internal/proxy/privilege"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/analyzer"
"github.com/milvus-io/milvus/internal/util/function/embedding"
@ -1464,14 +1465,15 @@ func AppendUserInfoForRPC(ctx context.Context) context.Context {
}
func GetRole(username string) ([]string, error) {
if globalMetaCache == nil {
privCache := privilege.GetPrivilegeCache()
if privCache == nil {
return []string{}, merr.WrapErrServiceUnavailable("internal: Milvus Proxy is not ready yet. please wait")
}
return globalMetaCache.GetUserRole(username), nil
return privCache.GetUserRole(username), nil
}
func PasswordVerify(ctx context.Context, username, rawPwd string) bool {
return passwordVerify(ctx, username, rawPwd, globalMetaCache)
return passwordVerify(ctx, username, rawPwd, privilege.GetPrivilegeCache())
}
func VerifyAPIKey(rawToken string) (string, error) {
@ -1485,10 +1487,10 @@ func VerifyAPIKey(rawToken string) (string, error) {
}
// PasswordVerify verify password
func passwordVerify(ctx context.Context, username, rawPwd string, globalMetaCache Cache) bool {
func passwordVerify(ctx context.Context, username, rawPwd string, privilegeCache privilege.PrivilegeCache) bool {
// it represents the cache miss if Sha256Password is empty within credInfo, which shall be updated first connection.
// meanwhile, generating Sha256Password depends on raw password and encrypted password will not cache.
credInfo, err := globalMetaCache.GetCredentialInfo(ctx, username)
credInfo, err := privilege.GetPrivilegeCache().GetCredentialInfo(ctx, username)
if err != nil {
log.Ctx(ctx).Error("found no credential", zap.String("username", username), zap.Error(err))
return false
@ -1509,7 +1511,7 @@ func passwordVerify(ctx context.Context, username, rawPwd string, globalMetaCach
// update cache after miss cache
credInfo.Sha256Password = sha256Pwd
log.Ctx(ctx).Debug("get credential miss cache, update cache with", zap.Any("credential", credInfo))
globalMetaCache.UpdateCredential(credInfo)
privilegeCache.UpdateCredential(credInfo)
return true
}

View File

@ -37,6 +37,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/json"
"github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/internal/proxy/privilege"
"github.com/milvus-io/milvus/internal/util/function/embedding"
"github.com/milvus-io/milvus/pkg/v2/common"
"github.com/milvus-io/milvus/pkg/v2/log"
@ -789,19 +790,21 @@ func TestGetCurDBNameFromContext(t *testing.T) {
}
func TestGetRole(t *testing.T) {
globalMetaCache = nil
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
privilege.ResetPrivilegeCacheForTest()
_, err := GetRole("foo")
assert.Error(t, err)
mockCache := NewMockCache(t)
mockCache.On("GetUserRole",
mock.AnythingOfType("string"),
).Return(func(username string) []string {
if username == "root" {
return []string{"role1", "admin", "role2"}
}
return []string{"role1"}
})
globalMetaCache = mockCache
mixcoord := mocks.NewMockMixCoordClient(t)
mixcoord.EXPECT().ListPolicy(mock.Anything, mock.Anything).Return(&internalpb.ListPolicyResponse{
Status: merr.Success(),
UserRoles: []string{"root/role1", "root/admin", "root/role2", "foo/role1"},
}, nil).Times(1)
privilege.InitPrivilegeCache(ctx, mixcoord)
roles, err := GetRole("root")
assert.NoError(t, err)
assert.Equal(t, 3, len(roles))
@ -812,11 +815,12 @@ func TestGetRole(t *testing.T) {
}
func TestPasswordVerify(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
username := "user-test00"
password := "PasswordVerify"
// credential does not exist within cache
credCache := make(map[string]*internalpb.CredentialInfo, 0)
invokedCount := 0
mockedRootCoord := NewMixCoordMock()
@ -825,34 +829,36 @@ func TestPasswordVerify(t *testing.T) {
return nil, errors.New("get cred not found credential")
}
metaCache := &MetaCache{
credMap: credCache,
mixCoord: mockedRootCoord,
}
ret, ok := credCache[username]
assert.False(t, ok)
assert.Nil(t, ret)
assert.False(t, passwordVerify(context.TODO(), username, password, metaCache))
privilege.InitPrivilegeCache(ctx, mockedRootCoord)
privilegeCache := privilege.GetPrivilegeCache()
assert.False(t, passwordVerify(ctx, username, password, privilegeCache))
assert.Equal(t, 1, invokedCount)
// Sha256Password has not been filled into cache during establish connection firstly
encryptedPwd, err := crypto.PasswordEncrypt(password)
assert.NoError(t, err)
credCache[username] = &internalpb.CredentialInfo{
Username: username,
EncryptedPassword: encryptedPwd,
privilegeCache.RemoveCredential(username)
mockedRootCoord.GetGetCredentialFunc = func(ctx context.Context, req *rootcoordpb.GetCredentialRequest, opts ...grpc.CallOption) (*rootcoordpb.GetCredentialResponse, error) {
invokedCount++
return &rootcoordpb.GetCredentialResponse{
Status: merr.Success(),
Username: username,
Password: encryptedPwd,
}, nil
}
assert.True(t, passwordVerify(context.TODO(), username, password, metaCache))
ret, ok = credCache[username]
assert.True(t, ok)
assert.True(t, passwordVerify(ctx, username, password, privilegeCache))
ret, err := privilegeCache.GetCredentialInfo(ctx, username)
assert.NoError(t, err)
assert.NotNil(t, ret)
assert.Equal(t, username, ret.Username)
assert.NotNil(t, ret.Sha256Password)
assert.Equal(t, 1, invokedCount)
assert.Equal(t, 2, invokedCount)
// Sha256Password already exists within cache
assert.True(t, passwordVerify(context.TODO(), username, password, metaCache))
assert.Equal(t, 1, invokedCount)
assert.True(t, passwordVerify(ctx, username, password, privilegeCache))
assert.Equal(t, 2, invokedCount)
}
func Test_isCollectionIsLoaded(t *testing.T) {