enhance: [2.6] skip check source id (#45379)

pr: https://github.com/milvus-io/milvus/pull/45377
relate:https://github.com/milvus-io/milvus/issues/45381

Signed-off-by: aoiasd <zhicheng.yue@zilliz.com>
This commit is contained in:
aoiasd 2025-11-07 15:47:34 +08:00 committed by GitHub
parent 81c2fd46a5
commit e938bacf20
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 34 additions and 101 deletions

View File

@ -32,20 +32,6 @@ func parseMD(rawToken string) (username, password string) {
return return
} }
func validSourceID(ctx context.Context, authorization []string) bool {
if len(authorization) < 1 {
// log.Warn("key not found in header", zap.String("key", util.HeaderSourceID))
return false
}
// token format: base64<sourceID>
token := authorization[0]
sourceID, err := crypto.Base64Decode(token)
if err != nil {
return false
}
return sourceID == util.MemberCredID
}
func GrpcAuthInterceptor(authFunc grpc_auth.AuthFunc) grpc.UnaryServerInterceptor { func GrpcAuthInterceptor(authFunc grpc_auth.AuthFunc) grpc.UnaryServerInterceptor {
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
var newCtx context.Context var newCtx context.Context
@ -76,48 +62,44 @@ func AuthenticationInterceptor(ctx context.Context) (context.Context, error) {
if globalMetaCache == nil { if globalMetaCache == nil {
return nil, merr.WrapErrServiceUnavailable("internal: Milvus Proxy is not ready yet. please wait") return nil, merr.WrapErrServiceUnavailable("internal: Milvus Proxy is not ready yet. please wait")
} }
// check: // check rpc call from sdk
// 1. if rpc call from a member (like index/query/data component)
// 2. if rpc call from sdk
if Params.CommonCfg.AuthorizationEnabled.GetAsBool() { if Params.CommonCfg.AuthorizationEnabled.GetAsBool() {
if !validSourceID(ctx, md[strings.ToLower(util.HeaderSourceID)]) { authStrArr := md[strings.ToLower(util.HeaderAuthorize)]
authStrArr := md[strings.ToLower(util.HeaderAuthorize)]
if len(authStrArr) < 1 { if len(authStrArr) < 1 {
log.Warn("key not found in header") log.Warn("key not found in header")
return nil, status.Error(codes.Unauthenticated, "missing authorization in header") return nil, status.Error(codes.Unauthenticated, "missing authorization in header")
} }
// token format: base64<username:password> // token format: base64<username:password>
// token := strings.TrimPrefix(authorization[0], "Bearer ") // token := strings.TrimPrefix(authorization[0], "Bearer ")
token := authStrArr[0] token := authStrArr[0]
rawToken, err := crypto.Base64Decode(token) rawToken, err := crypto.Base64Decode(token)
if err != nil {
log.Warn("fail to decode the token", zap.Error(err))
return nil, status.Error(codes.Unauthenticated, "invalid token format")
}
if !strings.Contains(rawToken, util.CredentialSeperator) {
user, err := VerifyAPIKey(rawToken)
if err != nil { if err != nil {
log.Warn("fail to decode the token", zap.Error(err)) log.Warn("fail to verify apikey", zap.Error(err))
return nil, status.Error(codes.Unauthenticated, "invalid token format") return nil, status.Error(codes.Unauthenticated, "auth check failure, please check api key is correct")
} }
metrics.UserRPCCounter.WithLabelValues(user).Inc()
if !strings.Contains(rawToken, util.CredentialSeperator) { userToken := fmt.Sprintf("%s%s%s", user, util.CredentialSeperator, util.PasswordHolder)
user, err := VerifyAPIKey(rawToken) md[strings.ToLower(util.HeaderAuthorize)] = []string{crypto.Base64Encode(userToken)}
if err != nil { md[util.HeaderToken] = []string{rawToken}
log.Warn("fail to verify apikey", zap.Error(err)) ctx = metadata.NewIncomingContext(ctx, md)
return nil, status.Error(codes.Unauthenticated, "auth check failure, please check api key is correct") } else {
} // username+password authentication
metrics.UserRPCCounter.WithLabelValues(user).Inc() username, password := parseMD(rawToken)
userToken := fmt.Sprintf("%s%s%s", user, util.CredentialSeperator, util.PasswordHolder) if !passwordVerify(ctx, username, password, globalMetaCache) {
md[strings.ToLower(util.HeaderAuthorize)] = []string{crypto.Base64Encode(userToken)} log.Warn("fail to verify password", zap.String("username", username))
md[util.HeaderToken] = []string{rawToken} // NOTE: don't use the merr, because it will cause the wrong retry behavior in the sdk
ctx = metadata.NewIncomingContext(ctx, md) return nil, status.Error(codes.Unauthenticated, "auth check failure, please check username and password are correct")
} else {
// username+password authentication
username, password := parseMD(rawToken)
if !passwordVerify(ctx, username, password, globalMetaCache) {
log.Warn("fail to verify password", zap.String("username", username))
// NOTE: don't use the merr, because it will cause the wrong retry behavior in the sdk
return nil, status.Error(codes.Unauthenticated, "auth check failure, please check username and password are correct")
}
metrics.UserRPCCounter.WithLabelValues(username).Inc()
} }
metrics.UserRPCCounter.WithLabelValues(username).Inc()
} }
} }
return ctx, nil return ctx, nil

View File

@ -49,19 +49,6 @@ func TestValidAuth(t *testing.T) {
assert.False(t, res) assert.False(t, res)
} }
func TestValidSourceID(t *testing.T) {
ctx := context.Background()
// no metadata
res := validSourceID(ctx, nil)
assert.False(t, res)
// illegal metadata
res = validSourceID(ctx, []string{"invalid_sourceid"})
assert.False(t, res)
// normal sourceId
res = validSourceID(ctx, []string{crypto.Base64Encode(util.MemberCredID)})
assert.True(t, res)
}
func TestAuthenticationInterceptor(t *testing.T) { func TestAuthenticationInterceptor(t *testing.T) {
ctx := context.Background() ctx := context.Background()
paramtable.Get().Save(Params.CommonCfg.AuthorizationEnabled.Key, "true") // mock authorization is turned on paramtable.Get().Save(Params.CommonCfg.AuthorizationEnabled.Key, "true") // mock authorization is turned on
@ -84,11 +71,6 @@ func TestAuthenticationInterceptor(t *testing.T) {
ctx = metadata.NewIncomingContext(ctx, md) ctx = metadata.NewIncomingContext(ctx, md)
_, err = AuthenticationInterceptor(ctx) _, err = AuthenticationInterceptor(ctx)
assert.NoError(t, err) assert.NoError(t, err)
// with valid sourceId
md = metadata.Pairs("sourceid", crypto.Base64Encode(util.MemberCredID))
ctx = metadata.NewIncomingContext(ctx, md)
_, err = AuthenticationInterceptor(ctx)
assert.NoError(t, err)
{ {
// wrong authorization style // wrong authorization style

View File

@ -1,19 +0,0 @@
package grpcclient
import (
"context"
"github.com/milvus-io/milvus/pkg/v2/util"
)
type Token struct {
Value string
}
func (t *Token) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
return map[string]string{util.HeaderSourceID: t.Value}, nil
}
func (t *Token) RequireTransportSecurity() bool {
return false
}

View File

@ -40,8 +40,6 @@ import (
"github.com/milvus-io/milvus/internal/util/sessionutil" "github.com/milvus-io/milvus/internal/util/sessionutil"
"github.com/milvus-io/milvus/pkg/v2/log" "github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/tracer" "github.com/milvus-io/milvus/pkg/v2/tracer"
"github.com/milvus-io/milvus/pkg/v2/util"
"github.com/milvus-io/milvus/pkg/v2/util/crypto"
"github.com/milvus-io/milvus/pkg/v2/util/funcutil" "github.com/milvus-io/milvus/pkg/v2/util/funcutil"
"github.com/milvus-io/milvus/pkg/v2/util/generic" "github.com/milvus-io/milvus/pkg/v2/util/generic"
"github.com/milvus-io/milvus/pkg/v2/util/interceptor" "github.com/milvus-io/milvus/pkg/v2/util/interceptor"
@ -310,7 +308,6 @@ func (c *ClientBase[T]) connect(ctx context.Context) error {
}, },
MinConnectTimeout: c.DialTimeout, MinConnectTimeout: c.DialTimeout,
}), }),
grpc.WithPerRPCCredentials(&Token{Value: crypto.Base64Encode(util.MemberCredID)}),
grpc.FailOnNonTempDialError(true), grpc.FailOnNonTempDialError(true),
grpc.WithReturnConnectionError(), grpc.WithReturnConnectionError(),
grpc.WithDisableRetry(), grpc.WithDisableRetry(),
@ -349,7 +346,6 @@ func (c *ClientBase[T]) connect(ctx context.Context) error {
}, },
MinConnectTimeout: c.DialTimeout, MinConnectTimeout: c.DialTimeout,
}), }),
grpc.WithPerRPCCredentials(&Token{Value: crypto.Base64Encode(util.MemberCredID)}),
grpc.FailOnNonTempDialError(true), grpc.FailOnNonTempDialError(true),
grpc.WithReturnConnectionError(), grpc.WithReturnConnectionError(),
grpc.WithDisableRetry(), grpc.WithDisableRetry(),

View File

@ -44,12 +44,8 @@ const (
SegmentIndexPrefix = "segment-index" SegmentIndexPrefix = "segment-index"
FieldIndexPrefix = "field-index" FieldIndexPrefix = "field-index"
HeaderAuthorize = "authorization" HeaderAuthorize = "authorization"
HeaderToken = "token" HeaderToken = "token"
// HeaderSourceID identify requests from Milvus members and client requests
HeaderSourceID = "sourceId"
// MemberCredID id for Milvus members (data/index/query node/coord component)
MemberCredID = "@@milvus-member@@"
CredentialSeperator = ":" CredentialSeperator = ":"
UserRoot = "root" UserRoot = "root"
PasswordHolder = "___" PasswordHolder = "___"

View File

@ -45,12 +45,9 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/internal/util/grpcclient"
"github.com/milvus-io/milvus/internal/util/sessionutil" "github.com/milvus-io/milvus/internal/util/sessionutil"
"github.com/milvus-io/milvus/pkg/v2/log" "github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/util"
"github.com/milvus-io/milvus/pkg/v2/util/contextutil" "github.com/milvus-io/milvus/pkg/v2/util/contextutil"
"github.com/milvus-io/milvus/pkg/v2/util/crypto"
"github.com/milvus-io/milvus/pkg/v2/util/interceptor" "github.com/milvus-io/milvus/pkg/v2/util/interceptor"
"github.com/milvus-io/milvus/pkg/v2/util/lifetime" "github.com/milvus-io/milvus/pkg/v2/util/lifetime"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable" "github.com/milvus-io/milvus/pkg/v2/util/paramtable"
@ -499,7 +496,6 @@ func DailGRPClient(ctx context.Context, addr string, rootPath string, nodeID int
}, },
MinConnectTimeout: 5 * time.Second, MinConnectTimeout: 5 * time.Second,
}), }),
grpc.WithPerRPCCredentials(&grpcclient.Token{Value: crypto.Base64Encode(util.MemberCredID)}),
grpc.FailOnNonTempDialError(true), grpc.FailOnNonTempDialError(true),
grpc.WithReturnConnectionError(), grpc.WithReturnConnectionError(),
grpc.WithDisableRetry(), grpc.WithDisableRetry(),