mirror of
https://gitee.com/milvus-io/milvus.git
synced 2025-12-06 17:18:35 +08:00
enhance: [2.5] skip check source id (#45383)
pr: https://github.com/milvus-io/milvus/pull/45377 relate:https://github.com/milvus-io/milvus/issues/45381 Signed-off-by: aoiasd <zhicheng.yue@zilliz.com>
This commit is contained in:
parent
6dd7b7c197
commit
7ad68910d9
@ -32,20 +32,6 @@ func parseMD(rawToken string) (username, password string) {
|
||||
return
|
||||
}
|
||||
|
||||
func validSourceID(ctx context.Context, authorization []string) bool {
|
||||
if len(authorization) < 1 {
|
||||
// log.Warn("key not found in header", zap.String("key", util.HeaderSourceID))
|
||||
return false
|
||||
}
|
||||
// token format: base64<sourceID>
|
||||
token := authorization[0]
|
||||
sourceID, err := crypto.Base64Decode(token)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return sourceID == util.MemberCredID
|
||||
}
|
||||
|
||||
func GrpcAuthInterceptor(authFunc grpc_auth.AuthFunc) grpc.UnaryServerInterceptor {
|
||||
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
|
||||
var newCtx context.Context
|
||||
@ -76,48 +62,44 @@ func AuthenticationInterceptor(ctx context.Context) (context.Context, error) {
|
||||
if globalMetaCache == nil {
|
||||
return nil, merr.WrapErrServiceUnavailable("internal: Milvus Proxy is not ready yet. please wait")
|
||||
}
|
||||
// check:
|
||||
// 1. if rpc call from a member (like index/query/data component)
|
||||
// 2. if rpc call from sdk
|
||||
// check if rpc call from sdk
|
||||
if Params.CommonCfg.AuthorizationEnabled.GetAsBool() {
|
||||
if !validSourceID(ctx, md[strings.ToLower(util.HeaderSourceID)]) {
|
||||
authStrArr := md[strings.ToLower(util.HeaderAuthorize)]
|
||||
authStrArr := md[strings.ToLower(util.HeaderAuthorize)]
|
||||
|
||||
if len(authStrArr) < 1 {
|
||||
log.Warn("key not found in header")
|
||||
return nil, status.Error(codes.Unauthenticated, "missing authorization in header")
|
||||
}
|
||||
if len(authStrArr) < 1 {
|
||||
log.Warn("key not found in header")
|
||||
return nil, status.Error(codes.Unauthenticated, "missing authorization in header")
|
||||
}
|
||||
|
||||
// token format: base64<username:password>
|
||||
// token := strings.TrimPrefix(authorization[0], "Bearer ")
|
||||
token := authStrArr[0]
|
||||
rawToken, err := crypto.Base64Decode(token)
|
||||
// token format: base64<username:password>
|
||||
// token := strings.TrimPrefix(authorization[0], "Bearer ")
|
||||
token := authStrArr[0]
|
||||
rawToken, err := crypto.Base64Decode(token)
|
||||
if err != nil {
|
||||
log.Warn("fail to decode the token", zap.Error(err))
|
||||
return nil, status.Error(codes.Unauthenticated, "invalid token format")
|
||||
}
|
||||
|
||||
if !strings.Contains(rawToken, util.CredentialSeperator) {
|
||||
user, err := VerifyAPIKey(rawToken)
|
||||
if err != nil {
|
||||
log.Warn("fail to decode the token", zap.Error(err))
|
||||
return nil, status.Error(codes.Unauthenticated, "invalid token format")
|
||||
log.Warn("fail to verify apikey", zap.Error(err))
|
||||
return nil, status.Error(codes.Unauthenticated, "auth check failure, please check api key is correct")
|
||||
}
|
||||
|
||||
if !strings.Contains(rawToken, util.CredentialSeperator) {
|
||||
user, err := VerifyAPIKey(rawToken)
|
||||
if err != nil {
|
||||
log.Warn("fail to verify apikey", zap.Error(err))
|
||||
return nil, status.Error(codes.Unauthenticated, "auth check failure, please check api key is correct")
|
||||
}
|
||||
metrics.UserRPCCounter.WithLabelValues(user).Inc()
|
||||
userToken := fmt.Sprintf("%s%s%s", user, util.CredentialSeperator, util.PasswordHolder)
|
||||
md[strings.ToLower(util.HeaderAuthorize)] = []string{crypto.Base64Encode(userToken)}
|
||||
md[util.HeaderToken] = []string{rawToken}
|
||||
ctx = metadata.NewIncomingContext(ctx, md)
|
||||
} else {
|
||||
// username+password authentication
|
||||
username, password := parseMD(rawToken)
|
||||
if !passwordVerify(ctx, username, password, globalMetaCache) {
|
||||
log.Warn("fail to verify password", zap.String("username", username))
|
||||
// NOTE: don't use the merr, because it will cause the wrong retry behavior in the sdk
|
||||
return nil, status.Error(codes.Unauthenticated, "auth check failure, please check username and password are correct")
|
||||
}
|
||||
metrics.UserRPCCounter.WithLabelValues(username).Inc()
|
||||
metrics.UserRPCCounter.WithLabelValues(user).Inc()
|
||||
userToken := fmt.Sprintf("%s%s%s", user, util.CredentialSeperator, util.PasswordHolder)
|
||||
md[strings.ToLower(util.HeaderAuthorize)] = []string{crypto.Base64Encode(userToken)}
|
||||
md[util.HeaderToken] = []string{rawToken}
|
||||
ctx = metadata.NewIncomingContext(ctx, md)
|
||||
} else {
|
||||
// username+password authentication
|
||||
username, password := parseMD(rawToken)
|
||||
if !passwordVerify(ctx, username, password, globalMetaCache) {
|
||||
log.Warn("fail to verify password", zap.String("username", username))
|
||||
// NOTE: don't use the merr, because it will cause the wrong retry behavior in the sdk
|
||||
return nil, status.Error(codes.Unauthenticated, "auth check failure, please check username and password are correct")
|
||||
}
|
||||
metrics.UserRPCCounter.WithLabelValues(username).Inc()
|
||||
}
|
||||
}
|
||||
return ctx, nil
|
||||
|
||||
@ -51,19 +51,6 @@ func TestValidAuth(t *testing.T) {
|
||||
assert.False(t, res)
|
||||
}
|
||||
|
||||
func TestValidSourceID(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
// no metadata
|
||||
res := validSourceID(ctx, nil)
|
||||
assert.False(t, res)
|
||||
// illegal metadata
|
||||
res = validSourceID(ctx, []string{"invalid_sourceid"})
|
||||
assert.False(t, res)
|
||||
// normal sourceId
|
||||
res = validSourceID(ctx, []string{crypto.Base64Encode(util.MemberCredID)})
|
||||
assert.True(t, res)
|
||||
}
|
||||
|
||||
func TestAuthenticationInterceptor(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
paramtable.Get().Save(Params.CommonCfg.AuthorizationEnabled.Key, "true") // mock authorization is turned on
|
||||
@ -87,11 +74,6 @@ func TestAuthenticationInterceptor(t *testing.T) {
|
||||
ctx = metadata.NewIncomingContext(ctx, md)
|
||||
_, err = AuthenticationInterceptor(ctx)
|
||||
assert.NoError(t, err)
|
||||
// with valid sourceId
|
||||
md = metadata.Pairs("sourceid", crypto.Base64Encode(util.MemberCredID))
|
||||
ctx = metadata.NewIncomingContext(ctx, md)
|
||||
_, err = AuthenticationInterceptor(ctx)
|
||||
assert.NoError(t, err)
|
||||
|
||||
{
|
||||
// wrong authorization style
|
||||
|
||||
@ -1,19 +0,0 @@
|
||||
package grpcclient
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/milvus-io/milvus/pkg/v2/util"
|
||||
)
|
||||
|
||||
type Token struct {
|
||||
Value string
|
||||
}
|
||||
|
||||
func (t *Token) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
|
||||
return map[string]string{util.HeaderSourceID: t.Value}, nil
|
||||
}
|
||||
|
||||
func (t *Token) RequireTransportSecurity() bool {
|
||||
return false
|
||||
}
|
||||
@ -40,8 +40,6 @@ import (
|
||||
"github.com/milvus-io/milvus/internal/util/sessionutil"
|
||||
"github.com/milvus-io/milvus/pkg/v2/log"
|
||||
"github.com/milvus-io/milvus/pkg/v2/tracer"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/crypto"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/generic"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/interceptor"
|
||||
@ -310,7 +308,6 @@ func (c *ClientBase[T]) connect(ctx context.Context) error {
|
||||
},
|
||||
MinConnectTimeout: c.DialTimeout,
|
||||
}),
|
||||
grpc.WithPerRPCCredentials(&Token{Value: crypto.Base64Encode(util.MemberCredID)}),
|
||||
grpc.FailOnNonTempDialError(true),
|
||||
grpc.WithReturnConnectionError(),
|
||||
grpc.WithDisableRetry(),
|
||||
@ -349,7 +346,6 @@ func (c *ClientBase[T]) connect(ctx context.Context) error {
|
||||
},
|
||||
MinConnectTimeout: c.DialTimeout,
|
||||
}),
|
||||
grpc.WithPerRPCCredentials(&Token{Value: crypto.Base64Encode(util.MemberCredID)}),
|
||||
grpc.FailOnNonTempDialError(true),
|
||||
grpc.WithReturnConnectionError(),
|
||||
grpc.WithDisableRetry(),
|
||||
|
||||
@ -45,12 +45,8 @@ const (
|
||||
SegmentIndexPrefix = "segment-index"
|
||||
FieldIndexPrefix = "field-index"
|
||||
|
||||
HeaderAuthorize = "authorization"
|
||||
HeaderToken = "token"
|
||||
// HeaderSourceID identify requests from Milvus members and client requests
|
||||
HeaderSourceID = "sourceId"
|
||||
// MemberCredID id for Milvus members (data/index/query node/coord component)
|
||||
MemberCredID = "@@milvus-member@@"
|
||||
HeaderAuthorize = "authorization"
|
||||
HeaderToken = "token"
|
||||
CredentialSeperator = ":"
|
||||
UserRoot = "root"
|
||||
PasswordHolder = "___"
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user