diff --git a/internal/proxy/authentication_interceptor.go b/internal/proxy/authentication_interceptor.go index e2a0a24291..f5369dce4d 100644 --- a/internal/proxy/authentication_interceptor.go +++ b/internal/proxy/authentication_interceptor.go @@ -81,7 +81,7 @@ func AuthenticationInterceptor(ctx context.Context) (context.Context, error) { return nil, status.Error(codes.Unauthenticated, "auth check failure, please check api key is correct") } metrics.UserRPCCounter.WithLabelValues(user).Inc() - userToken := fmt.Sprintf("%s%s%s", user, util.CredentialSeperator, "___") + userToken := fmt.Sprintf("%s%s%s", user, util.CredentialSeperator, util.PasswordHolder) md[strings.ToLower(util.HeaderAuthorize)] = []string{crypto.Base64Encode(userToken)} ctx = metadata.NewIncomingContext(ctx, md) } else { diff --git a/internal/proxy/privilege_interceptor.go b/internal/proxy/privilege_interceptor.go index e8319249b8..ad0496fd8a 100644 --- a/internal/proxy/privilege_interceptor.go +++ b/internal/proxy/privilege_interceptor.go @@ -18,6 +18,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util" + "github.com/milvus-io/milvus/pkg/util/contextutil" "github.com/milvus-io/milvus/pkg/util/funcutil" ) @@ -94,7 +95,7 @@ func PrivilegeInterceptor(ctx context.Context, req interface{}) (context.Context log.RatedInfo(60, "GetPrivilegeExtObj err", zap.Error(err)) return ctx, nil } - username, err := GetCurUserFromContext(ctx) + username, password, err := contextutil.GetAuthInfoFromContext(ctx) if err != nil { log.Warn("GetCurUserFromContext fail", zap.Error(err)) return ctx, err @@ -174,7 +175,13 @@ func PrivilegeInterceptor(ctx context.Context, req interface{}) (context.Context } log.Info("permission deny", zap.Strings("roles", roleNames)) - return ctx, status.Error(codes.PermissionDenied, fmt.Sprintf("%s: permission deny to %s", objectPrivilege, username)) + + if password == util.PasswordHolder { + username = "apikey user" + } + + return ctx, status.Error(codes.PermissionDenied, + fmt.Sprintf("%s: permission deny to %s in the `%s` database", objectPrivilege, username, dbName)) } // isCurUserObject Determine whether it is an Object of type User that operates on its own user information, diff --git a/internal/proxy/privilege_interceptor_test.go b/internal/proxy/privilege_interceptor_test.go index e42c4df78b..5a6c554457 100644 --- a/internal/proxy/privilege_interceptor_test.go +++ b/internal/proxy/privilege_interceptor_test.go @@ -2,6 +2,7 @@ package proxy import ( "context" + "strings" "sync" "testing" @@ -11,6 +12,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus/internal/mocks" "github.com/milvus-io/milvus/internal/proto/internalpb" + "github.com/milvus-io/milvus/pkg/util" "github.com/milvus-io/milvus/pkg/util/funcutil" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" @@ -106,6 +108,14 @@ func TestPrivilegeInterceptor(t *testing.T) { CollectionName: "col1", }) assert.Error(t, err) + { + _, err = PrivilegeInterceptor(GetContext(context.Background(), "foo:"+util.PasswordHolder), &milvuspb.LoadCollectionRequest{ + DbName: "db_test", + CollectionName: "col1", + }) + assert.Error(t, err) + assert.True(t, strings.Contains(err.Error(), "apikey user")) + } _, err = PrivilegeInterceptor(ctx, &milvuspb.InsertRequest{ DbName: "db_test", CollectionName: "col1", diff --git a/pkg/util/constant.go b/pkg/util/constant.go index 1bc2883289..7bdad8a371 100644 --- a/pkg/util/constant.go +++ b/pkg/util/constant.go @@ -49,6 +49,7 @@ const ( CredentialSeperator = ":" UserRoot = "root" DefaultRootPassword = "Milvus" + PasswordHolder = "___" DefaultTenant = "" RoleAdmin = "admin" RolePublic = "public" diff --git a/pkg/util/contextutil/context_util.go b/pkg/util/contextutil/context_util.go index e5aa293b35..03d123f160 100644 --- a/pkg/util/contextutil/context_util.go +++ b/pkg/util/contextutil/context_util.go @@ -64,23 +64,29 @@ func AppendToIncomingContext(ctx context.Context, kv ...string) context.Context } func GetCurUserFromContext(ctx context.Context) (string, error) { + username, _, err := GetAuthInfoFromContext(ctx) + return username, err +} + +func GetAuthInfoFromContext(ctx context.Context) (string, string, error) { md, ok := metadata.FromIncomingContext(ctx) if !ok { - return "", fmt.Errorf("fail to get md from the context") + return "", "", fmt.Errorf("fail to get md from the context") } authorization, ok := md[strings.ToLower(util.HeaderAuthorize)] if !ok || len(authorization) < 1 { - return "", fmt.Errorf("fail to get authorization from the md, %s:[token]", strings.ToLower(util.HeaderAuthorize)) + return "", "", fmt.Errorf("fail to get authorization from the md, %s:[token]", strings.ToLower(util.HeaderAuthorize)) } token := authorization[0] rawToken, err := crypto.Base64Decode(token) if err != nil { - return "", fmt.Errorf("fail to decode the token, token: %s", token) + return "", "", fmt.Errorf("fail to decode the token, token: %s", token) } secrets := strings.SplitN(rawToken, util.CredentialSeperator, 2) if len(secrets) < 2 { - return "", fmt.Errorf("fail to get user info from the raw token, raw token: %s", rawToken) + return "", "", fmt.Errorf("fail to get user info from the raw token, raw token: %s", rawToken) } - username := secrets[0] - return username, nil + // username: secrets[0] + // password: secrets[1] + return secrets[0], secrets[1], nil } diff --git a/pkg/util/contextutil/context_util_test.go b/pkg/util/contextutil/context_util_test.go index 1114f71954..38442e6e39 100644 --- a/pkg/util/contextutil/context_util_test.go +++ b/pkg/util/contextutil/context_util_test.go @@ -62,7 +62,14 @@ func TestGetCurUserFromContext(t *testing.T) { password := "123456" username, err := GetCurUserFromContext(GetContext(context.Background(), fmt.Sprintf("%s%s%s", root, util.CredentialSeperator, password))) assert.NoError(t, err) - assert.Equal(t, "root", username) + assert.Equal(t, root, username) + + { + u, p, e := GetAuthInfoFromContext(GetContext(context.Background(), fmt.Sprintf("%s%s%s", root, util.CredentialSeperator, password))) + assert.NoError(t, e) + assert.Equal(t, "root", u) + assert.Equal(t, password, p) + } } func GetContext(ctx context.Context, originValue string) context.Context {