enhance: [2.6] skip check source id (#45379)

pr: https://github.com/milvus-io/milvus/pull/45377
relate:https://github.com/milvus-io/milvus/issues/45381

Signed-off-by: aoiasd <zhicheng.yue@zilliz.com>
This commit is contained in:
aoiasd 2025-11-07 15:47:34 +08:00 committed by GitHub
parent 81c2fd46a5
commit e938bacf20
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 34 additions and 101 deletions

View File

@ -32,20 +32,6 @@ func parseMD(rawToken string) (username, password string) {
return
}
func validSourceID(ctx context.Context, authorization []string) bool {
if len(authorization) < 1 {
// log.Warn("key not found in header", zap.String("key", util.HeaderSourceID))
return false
}
// token format: base64<sourceID>
token := authorization[0]
sourceID, err := crypto.Base64Decode(token)
if err != nil {
return false
}
return sourceID == util.MemberCredID
}
func GrpcAuthInterceptor(authFunc grpc_auth.AuthFunc) grpc.UnaryServerInterceptor {
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
var newCtx context.Context
@ -76,48 +62,44 @@ func AuthenticationInterceptor(ctx context.Context) (context.Context, error) {
if globalMetaCache == nil {
return nil, merr.WrapErrServiceUnavailable("internal: Milvus Proxy is not ready yet. please wait")
}
// check:
// 1. if rpc call from a member (like index/query/data component)
// 2. if rpc call from sdk
// check rpc call from sdk
if Params.CommonCfg.AuthorizationEnabled.GetAsBool() {
if !validSourceID(ctx, md[strings.ToLower(util.HeaderSourceID)]) {
authStrArr := md[strings.ToLower(util.HeaderAuthorize)]
authStrArr := md[strings.ToLower(util.HeaderAuthorize)]
if len(authStrArr) < 1 {
log.Warn("key not found in header")
return nil, status.Error(codes.Unauthenticated, "missing authorization in header")
}
if len(authStrArr) < 1 {
log.Warn("key not found in header")
return nil, status.Error(codes.Unauthenticated, "missing authorization in header")
}
// token format: base64<username:password>
// token := strings.TrimPrefix(authorization[0], "Bearer ")
token := authStrArr[0]
rawToken, err := crypto.Base64Decode(token)
// token format: base64<username:password>
// token := strings.TrimPrefix(authorization[0], "Bearer ")
token := authStrArr[0]
rawToken, err := crypto.Base64Decode(token)
if err != nil {
log.Warn("fail to decode the token", zap.Error(err))
return nil, status.Error(codes.Unauthenticated, "invalid token format")
}
if !strings.Contains(rawToken, util.CredentialSeperator) {
user, err := VerifyAPIKey(rawToken)
if err != nil {
log.Warn("fail to decode the token", zap.Error(err))
return nil, status.Error(codes.Unauthenticated, "invalid token format")
log.Warn("fail to verify apikey", zap.Error(err))
return nil, status.Error(codes.Unauthenticated, "auth check failure, please check api key is correct")
}
if !strings.Contains(rawToken, util.CredentialSeperator) {
user, err := VerifyAPIKey(rawToken)
if err != nil {
log.Warn("fail to verify apikey", zap.Error(err))
return nil, status.Error(codes.Unauthenticated, "auth check failure, please check api key is correct")
}
metrics.UserRPCCounter.WithLabelValues(user).Inc()
userToken := fmt.Sprintf("%s%s%s", user, util.CredentialSeperator, util.PasswordHolder)
md[strings.ToLower(util.HeaderAuthorize)] = []string{crypto.Base64Encode(userToken)}
md[util.HeaderToken] = []string{rawToken}
ctx = metadata.NewIncomingContext(ctx, md)
} else {
// username+password authentication
username, password := parseMD(rawToken)
if !passwordVerify(ctx, username, password, globalMetaCache) {
log.Warn("fail to verify password", zap.String("username", username))
// NOTE: don't use the merr, because it will cause the wrong retry behavior in the sdk
return nil, status.Error(codes.Unauthenticated, "auth check failure, please check username and password are correct")
}
metrics.UserRPCCounter.WithLabelValues(username).Inc()
metrics.UserRPCCounter.WithLabelValues(user).Inc()
userToken := fmt.Sprintf("%s%s%s", user, util.CredentialSeperator, util.PasswordHolder)
md[strings.ToLower(util.HeaderAuthorize)] = []string{crypto.Base64Encode(userToken)}
md[util.HeaderToken] = []string{rawToken}
ctx = metadata.NewIncomingContext(ctx, md)
} else {
// username+password authentication
username, password := parseMD(rawToken)
if !passwordVerify(ctx, username, password, globalMetaCache) {
log.Warn("fail to verify password", zap.String("username", username))
// NOTE: don't use the merr, because it will cause the wrong retry behavior in the sdk
return nil, status.Error(codes.Unauthenticated, "auth check failure, please check username and password are correct")
}
metrics.UserRPCCounter.WithLabelValues(username).Inc()
}
}
return ctx, nil

View File

@ -49,19 +49,6 @@ func TestValidAuth(t *testing.T) {
assert.False(t, res)
}
func TestValidSourceID(t *testing.T) {
ctx := context.Background()
// no metadata
res := validSourceID(ctx, nil)
assert.False(t, res)
// illegal metadata
res = validSourceID(ctx, []string{"invalid_sourceid"})
assert.False(t, res)
// normal sourceId
res = validSourceID(ctx, []string{crypto.Base64Encode(util.MemberCredID)})
assert.True(t, res)
}
func TestAuthenticationInterceptor(t *testing.T) {
ctx := context.Background()
paramtable.Get().Save(Params.CommonCfg.AuthorizationEnabled.Key, "true") // mock authorization is turned on
@ -84,11 +71,6 @@ func TestAuthenticationInterceptor(t *testing.T) {
ctx = metadata.NewIncomingContext(ctx, md)
_, err = AuthenticationInterceptor(ctx)
assert.NoError(t, err)
// with valid sourceId
md = metadata.Pairs("sourceid", crypto.Base64Encode(util.MemberCredID))
ctx = metadata.NewIncomingContext(ctx, md)
_, err = AuthenticationInterceptor(ctx)
assert.NoError(t, err)
{
// wrong authorization style

View File

@ -1,19 +0,0 @@
package grpcclient
import (
"context"
"github.com/milvus-io/milvus/pkg/v2/util"
)
type Token struct {
Value string
}
func (t *Token) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
return map[string]string{util.HeaderSourceID: t.Value}, nil
}
func (t *Token) RequireTransportSecurity() bool {
return false
}

View File

@ -40,8 +40,6 @@ import (
"github.com/milvus-io/milvus/internal/util/sessionutil"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/tracer"
"github.com/milvus-io/milvus/pkg/v2/util"
"github.com/milvus-io/milvus/pkg/v2/util/crypto"
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
"github.com/milvus-io/milvus/pkg/v2/util/generic"
"github.com/milvus-io/milvus/pkg/v2/util/interceptor"
@ -310,7 +308,6 @@ func (c *ClientBase[T]) connect(ctx context.Context) error {
},
MinConnectTimeout: c.DialTimeout,
}),
grpc.WithPerRPCCredentials(&Token{Value: crypto.Base64Encode(util.MemberCredID)}),
grpc.FailOnNonTempDialError(true),
grpc.WithReturnConnectionError(),
grpc.WithDisableRetry(),
@ -349,7 +346,6 @@ func (c *ClientBase[T]) connect(ctx context.Context) error {
},
MinConnectTimeout: c.DialTimeout,
}),
grpc.WithPerRPCCredentials(&Token{Value: crypto.Base64Encode(util.MemberCredID)}),
grpc.FailOnNonTempDialError(true),
grpc.WithReturnConnectionError(),
grpc.WithDisableRetry(),

View File

@ -44,12 +44,8 @@ const (
SegmentIndexPrefix = "segment-index"
FieldIndexPrefix = "field-index"
HeaderAuthorize = "authorization"
HeaderToken = "token"
// HeaderSourceID identify requests from Milvus members and client requests
HeaderSourceID = "sourceId"
// MemberCredID id for Milvus members (data/index/query node/coord component)
MemberCredID = "@@milvus-member@@"
HeaderAuthorize = "authorization"
HeaderToken = "token"
CredentialSeperator = ":"
UserRoot = "root"
PasswordHolder = "___"

View File

@ -45,12 +45,9 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/internal/util/grpcclient"
"github.com/milvus-io/milvus/internal/util/sessionutil"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/util"
"github.com/milvus-io/milvus/pkg/v2/util/contextutil"
"github.com/milvus-io/milvus/pkg/v2/util/crypto"
"github.com/milvus-io/milvus/pkg/v2/util/interceptor"
"github.com/milvus-io/milvus/pkg/v2/util/lifetime"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
@ -499,7 +496,6 @@ func DailGRPClient(ctx context.Context, addr string, rootPath string, nodeID int
},
MinConnectTimeout: 5 * time.Second,
}),
grpc.WithPerRPCCredentials(&grpcclient.Token{Value: crypto.Base64Encode(util.MemberCredID)}),
grpc.FailOnNonTempDialError(true),
grpc.WithReturnConnectionError(),
grpc.WithDisableRetry(),