mirror of
https://gitee.com/milvus-io/milvus.git
synced 2025-12-07 01:28:27 +08:00
enhance: [2.6] skip check source id (#45379)
pr: https://github.com/milvus-io/milvus/pull/45377 relate:https://github.com/milvus-io/milvus/issues/45381 Signed-off-by: aoiasd <zhicheng.yue@zilliz.com>
This commit is contained in:
parent
81c2fd46a5
commit
e938bacf20
@ -32,20 +32,6 @@ func parseMD(rawToken string) (username, password string) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func validSourceID(ctx context.Context, authorization []string) bool {
|
|
||||||
if len(authorization) < 1 {
|
|
||||||
// log.Warn("key not found in header", zap.String("key", util.HeaderSourceID))
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
// token format: base64<sourceID>
|
|
||||||
token := authorization[0]
|
|
||||||
sourceID, err := crypto.Base64Decode(token)
|
|
||||||
if err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return sourceID == util.MemberCredID
|
|
||||||
}
|
|
||||||
|
|
||||||
func GrpcAuthInterceptor(authFunc grpc_auth.AuthFunc) grpc.UnaryServerInterceptor {
|
func GrpcAuthInterceptor(authFunc grpc_auth.AuthFunc) grpc.UnaryServerInterceptor {
|
||||||
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
|
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
|
||||||
var newCtx context.Context
|
var newCtx context.Context
|
||||||
@ -76,11 +62,8 @@ func AuthenticationInterceptor(ctx context.Context) (context.Context, error) {
|
|||||||
if globalMetaCache == nil {
|
if globalMetaCache == nil {
|
||||||
return nil, merr.WrapErrServiceUnavailable("internal: Milvus Proxy is not ready yet. please wait")
|
return nil, merr.WrapErrServiceUnavailable("internal: Milvus Proxy is not ready yet. please wait")
|
||||||
}
|
}
|
||||||
// check:
|
// check rpc call from sdk
|
||||||
// 1. if rpc call from a member (like index/query/data component)
|
|
||||||
// 2. if rpc call from sdk
|
|
||||||
if Params.CommonCfg.AuthorizationEnabled.GetAsBool() {
|
if Params.CommonCfg.AuthorizationEnabled.GetAsBool() {
|
||||||
if !validSourceID(ctx, md[strings.ToLower(util.HeaderSourceID)]) {
|
|
||||||
authStrArr := md[strings.ToLower(util.HeaderAuthorize)]
|
authStrArr := md[strings.ToLower(util.HeaderAuthorize)]
|
||||||
|
|
||||||
if len(authStrArr) < 1 {
|
if len(authStrArr) < 1 {
|
||||||
@ -119,6 +102,5 @@ func AuthenticationInterceptor(ctx context.Context) (context.Context, error) {
|
|||||||
metrics.UserRPCCounter.WithLabelValues(username).Inc()
|
metrics.UserRPCCounter.WithLabelValues(username).Inc()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
return ctx, nil
|
return ctx, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@ -49,19 +49,6 @@ func TestValidAuth(t *testing.T) {
|
|||||||
assert.False(t, res)
|
assert.False(t, res)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestValidSourceID(t *testing.T) {
|
|
||||||
ctx := context.Background()
|
|
||||||
// no metadata
|
|
||||||
res := validSourceID(ctx, nil)
|
|
||||||
assert.False(t, res)
|
|
||||||
// illegal metadata
|
|
||||||
res = validSourceID(ctx, []string{"invalid_sourceid"})
|
|
||||||
assert.False(t, res)
|
|
||||||
// normal sourceId
|
|
||||||
res = validSourceID(ctx, []string{crypto.Base64Encode(util.MemberCredID)})
|
|
||||||
assert.True(t, res)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAuthenticationInterceptor(t *testing.T) {
|
func TestAuthenticationInterceptor(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
paramtable.Get().Save(Params.CommonCfg.AuthorizationEnabled.Key, "true") // mock authorization is turned on
|
paramtable.Get().Save(Params.CommonCfg.AuthorizationEnabled.Key, "true") // mock authorization is turned on
|
||||||
@ -84,11 +71,6 @@ func TestAuthenticationInterceptor(t *testing.T) {
|
|||||||
ctx = metadata.NewIncomingContext(ctx, md)
|
ctx = metadata.NewIncomingContext(ctx, md)
|
||||||
_, err = AuthenticationInterceptor(ctx)
|
_, err = AuthenticationInterceptor(ctx)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
// with valid sourceId
|
|
||||||
md = metadata.Pairs("sourceid", crypto.Base64Encode(util.MemberCredID))
|
|
||||||
ctx = metadata.NewIncomingContext(ctx, md)
|
|
||||||
_, err = AuthenticationInterceptor(ctx)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
{
|
{
|
||||||
// wrong authorization style
|
// wrong authorization style
|
||||||
|
|||||||
@ -1,19 +0,0 @@
|
|||||||
package grpcclient
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
|
|
||||||
"github.com/milvus-io/milvus/pkg/v2/util"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Token struct {
|
|
||||||
Value string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *Token) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
|
|
||||||
return map[string]string{util.HeaderSourceID: t.Value}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *Token) RequireTransportSecurity() bool {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
@ -40,8 +40,6 @@ import (
|
|||||||
"github.com/milvus-io/milvus/internal/util/sessionutil"
|
"github.com/milvus-io/milvus/internal/util/sessionutil"
|
||||||
"github.com/milvus-io/milvus/pkg/v2/log"
|
"github.com/milvus-io/milvus/pkg/v2/log"
|
||||||
"github.com/milvus-io/milvus/pkg/v2/tracer"
|
"github.com/milvus-io/milvus/pkg/v2/tracer"
|
||||||
"github.com/milvus-io/milvus/pkg/v2/util"
|
|
||||||
"github.com/milvus-io/milvus/pkg/v2/util/crypto"
|
|
||||||
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
|
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
|
||||||
"github.com/milvus-io/milvus/pkg/v2/util/generic"
|
"github.com/milvus-io/milvus/pkg/v2/util/generic"
|
||||||
"github.com/milvus-io/milvus/pkg/v2/util/interceptor"
|
"github.com/milvus-io/milvus/pkg/v2/util/interceptor"
|
||||||
@ -310,7 +308,6 @@ func (c *ClientBase[T]) connect(ctx context.Context) error {
|
|||||||
},
|
},
|
||||||
MinConnectTimeout: c.DialTimeout,
|
MinConnectTimeout: c.DialTimeout,
|
||||||
}),
|
}),
|
||||||
grpc.WithPerRPCCredentials(&Token{Value: crypto.Base64Encode(util.MemberCredID)}),
|
|
||||||
grpc.FailOnNonTempDialError(true),
|
grpc.FailOnNonTempDialError(true),
|
||||||
grpc.WithReturnConnectionError(),
|
grpc.WithReturnConnectionError(),
|
||||||
grpc.WithDisableRetry(),
|
grpc.WithDisableRetry(),
|
||||||
@ -349,7 +346,6 @@ func (c *ClientBase[T]) connect(ctx context.Context) error {
|
|||||||
},
|
},
|
||||||
MinConnectTimeout: c.DialTimeout,
|
MinConnectTimeout: c.DialTimeout,
|
||||||
}),
|
}),
|
||||||
grpc.WithPerRPCCredentials(&Token{Value: crypto.Base64Encode(util.MemberCredID)}),
|
|
||||||
grpc.FailOnNonTempDialError(true),
|
grpc.FailOnNonTempDialError(true),
|
||||||
grpc.WithReturnConnectionError(),
|
grpc.WithReturnConnectionError(),
|
||||||
grpc.WithDisableRetry(),
|
grpc.WithDisableRetry(),
|
||||||
|
|||||||
@ -46,10 +46,6 @@ const (
|
|||||||
|
|
||||||
HeaderAuthorize = "authorization"
|
HeaderAuthorize = "authorization"
|
||||||
HeaderToken = "token"
|
HeaderToken = "token"
|
||||||
// HeaderSourceID identify requests from Milvus members and client requests
|
|
||||||
HeaderSourceID = "sourceId"
|
|
||||||
// MemberCredID id for Milvus members (data/index/query node/coord component)
|
|
||||||
MemberCredID = "@@milvus-member@@"
|
|
||||||
CredentialSeperator = ":"
|
CredentialSeperator = ":"
|
||||||
UserRoot = "root"
|
UserRoot = "root"
|
||||||
PasswordHolder = "___"
|
PasswordHolder = "___"
|
||||||
|
|||||||
@ -45,12 +45,9 @@ import (
|
|||||||
|
|
||||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||||
"github.com/milvus-io/milvus/internal/util/grpcclient"
|
|
||||||
"github.com/milvus-io/milvus/internal/util/sessionutil"
|
"github.com/milvus-io/milvus/internal/util/sessionutil"
|
||||||
"github.com/milvus-io/milvus/pkg/v2/log"
|
"github.com/milvus-io/milvus/pkg/v2/log"
|
||||||
"github.com/milvus-io/milvus/pkg/v2/util"
|
|
||||||
"github.com/milvus-io/milvus/pkg/v2/util/contextutil"
|
"github.com/milvus-io/milvus/pkg/v2/util/contextutil"
|
||||||
"github.com/milvus-io/milvus/pkg/v2/util/crypto"
|
|
||||||
"github.com/milvus-io/milvus/pkg/v2/util/interceptor"
|
"github.com/milvus-io/milvus/pkg/v2/util/interceptor"
|
||||||
"github.com/milvus-io/milvus/pkg/v2/util/lifetime"
|
"github.com/milvus-io/milvus/pkg/v2/util/lifetime"
|
||||||
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
|
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
|
||||||
@ -499,7 +496,6 @@ func DailGRPClient(ctx context.Context, addr string, rootPath string, nodeID int
|
|||||||
},
|
},
|
||||||
MinConnectTimeout: 5 * time.Second,
|
MinConnectTimeout: 5 * time.Second,
|
||||||
}),
|
}),
|
||||||
grpc.WithPerRPCCredentials(&grpcclient.Token{Value: crypto.Base64Encode(util.MemberCredID)}),
|
|
||||||
grpc.FailOnNonTempDialError(true),
|
grpc.FailOnNonTempDialError(true),
|
||||||
grpc.WithReturnConnectionError(),
|
grpc.WithReturnConnectionError(),
|
||||||
grpc.WithDisableRetry(),
|
grpc.WithDisableRetry(),
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user