From 5763fbd5abe421ae0347ce9e1a397ca5dae818bf Mon Sep 17 00:00:00 2001 From: SimFG Date: Fri, 9 Dec 2022 21:37:37 +0800 Subject: [PATCH] Fix the unsafe casbin `Model` (#21115) Signed-off-by: SimFG Signed-off-by: SimFG --- internal/distributed/proxy/service.go | 1 - internal/proxy/privilege_interceptor.go | 28 +++++--------------- internal/proxy/privilege_interceptor_test.go | 19 ++++++++++--- 3 files changed, 23 insertions(+), 25 deletions(-) diff --git a/internal/distributed/proxy/service.go b/internal/distributed/proxy/service.go index c86684fd8a..5636491940 100644 --- a/internal/distributed/proxy/service.go +++ b/internal/distributed/proxy/service.go @@ -331,7 +331,6 @@ func (s *Server) init() error { } s.etcdCli = etcdCli s.proxy.SetEtcdClient(s.etcdCli) - proxy.InitPolicyModel() errChan := make(chan error, 1) { diff --git a/internal/proxy/privilege_interceptor.go b/internal/proxy/privilege_interceptor.go index 8deeaba3d0..eb528b9aee 100644 --- a/internal/proxy/privilege_interceptor.go +++ b/internal/proxy/privilege_interceptor.go @@ -2,11 +2,9 @@ package proxy import ( "context" - "errors" "fmt" "reflect" "strings" - "sync" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -46,19 +44,12 @@ m = r.sub == p.sub && globMatch(r.obj, p.obj) && globMatch(r.act, p.act) || r.su ModelKey = "casbin" ) -var ( - casbinModel model.Model - initOnce sync.Once -) - -func InitPolicyModel() { - initOnce.Do(func() { - var err error - casbinModel, err = model.NewModelFromString(ModelStr) - if err != nil { - log.Panic("NewModelFromString fail", zap.String("model", ModelStr), zap.Error(err)) - } - }) +func GetPolicyModel() model.Model { + model, err := model.NewModelFromString(ModelStr) + if err != nil { + log.Panic("NewModelFromString fail", zap.String("model", ModelStr), zap.Error(err)) + } + return model } // UnaryServerInterceptor returns a new unary server interceptors that performs per-request privilege access. @@ -116,12 +107,7 @@ func PrivilegeInterceptor(ctx context.Context, req interface{}) (context.Context policy := fmt.Sprintf("[%s]", policyInfo) b := []byte(policy) a := jsonadapter.NewAdapter(&b) - if casbinModel == nil { - errStr := "fail to get policy model" - err = errors.New(errStr) - log.Panic(errStr, zap.Error(err)) - return ctx, err - } + casbinModel := GetPolicyModel() e, err := casbin.NewEnforcer(casbinModel, a) if err != nil { log.Error("NewEnforcer fail", zap.String("policy", policy), zap.Error(err)) diff --git a/internal/proxy/privilege_interceptor_test.go b/internal/proxy/privilege_interceptor_test.go index 4453174746..2799581432 100644 --- a/internal/proxy/privilege_interceptor_test.go +++ b/internal/proxy/privilege_interceptor_test.go @@ -2,13 +2,13 @@ package proxy import ( "context" + "sync" "testing" - "github.com/milvus-io/milvus/internal/util/funcutil" - "github.com/milvus-io/milvus-proto/go-api/commonpb" "github.com/milvus-io/milvus-proto/go-api/milvuspb" "github.com/milvus-io/milvus/internal/proto/internalpb" + "github.com/milvus-io/milvus/internal/util/funcutil" "github.com/stretchr/testify/assert" ) @@ -19,7 +19,7 @@ func TestUnaryServerInterceptor(t *testing.T) { func TestPrivilegeInterceptor(t *testing.T) { ctx := context.Background() - InitPolicyModel() + t.Run("Authorization Disabled", func(t *testing.T) { Params.CommonCfg.AuthorizationEnabled = false _, err := PrivilegeInterceptor(ctx, &milvuspb.LoadCollectionRequest{ @@ -110,6 +110,19 @@ func TestPrivilegeInterceptor(t *testing.T) { CollectionName: "col1", }) assert.Nil(t, err) + + g := sync.WaitGroup{} + for i := 0; i < 20; i++ { + g.Add(1) + go func() { + defer g.Done() + PrivilegeInterceptor(GetContext(context.Background(), "fooo:123456"), &milvuspb.LoadCollectionRequest{ + DbName: "db_test", + CollectionName: "col1", + }) + }() + } + g.Wait() }) }