Add cluster validation interceptor to resolve the Cross-Cluster routing issue (#25001) (#25030)

Signed-off-by: bigsheeper <yihao.dai@zilliz.com>
This commit is contained in:
yihao.dai 2023-06-21 11:36:46 +08:00 committed by GitHub
parent 6d139d97d6
commit febeb371d3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 667 additions and 24 deletions

View File

@ -43,6 +43,7 @@ import (
"github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/internal/util/etcd"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/util/interceptor"
"github.com/milvus-io/milvus/internal/util/logutil"
"github.com/milvus-io/milvus/internal/util/paramtable"
"github.com/milvus-io/milvus/internal/util/trace"
@ -180,10 +181,14 @@ func (s *Server) startGrpcLoop(grpcPort int) {
grpc.MaxSendMsgSize(Params.ServerMaxSendSize),
grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer(
ot.UnaryServerInterceptor(opts...),
logutil.UnaryTraceLoggerInterceptor)),
logutil.UnaryTraceLoggerInterceptor,
interceptor.ClusterValidationUnaryServerInterceptor(),
)),
grpc.StreamInterceptor(grpc_middleware.ChainStreamServer(
ot.StreamServerInterceptor(opts...),
logutil.StreamTraceLoggerInterceptor)))
logutil.StreamTraceLoggerInterceptor,
interceptor.ClusterValidationStreamServerInterceptor(),
)))
datapb.RegisterDataCoordServer(s.grpcServer, s)
go funcutil.CheckGrpcReady(ctx, s.grpcErrChan)
if err := s.grpcServer.Serve(lis); err != nil {

View File

@ -45,6 +45,7 @@ import (
"github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/internal/util/etcd"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/util/interceptor"
"github.com/milvus-io/milvus/internal/util/logutil"
"github.com/milvus-io/milvus/internal/util/paramtable"
"github.com/milvus-io/milvus/internal/util/retry"
@ -138,10 +139,14 @@ func (s *Server) startGrpcLoop(grpcPort int) {
grpc.MaxSendMsgSize(Params.ServerMaxSendSize),
grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer(
ot.UnaryServerInterceptor(opts...),
logutil.UnaryTraceLoggerInterceptor)),
logutil.UnaryTraceLoggerInterceptor,
interceptor.ClusterValidationUnaryServerInterceptor(),
)),
grpc.StreamInterceptor(grpc_middleware.ChainStreamServer(
ot.StreamServerInterceptor(opts...),
logutil.StreamTraceLoggerInterceptor)))
logutil.StreamTraceLoggerInterceptor,
interceptor.ClusterValidationStreamServerInterceptor(),
)))
datapb.RegisterDataNodeServer(s.grpcServer, s)
ctx, cancel := context.WithCancel(s.ctx)

View File

@ -44,6 +44,7 @@ import (
"github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/internal/util/etcd"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/util/interceptor"
"github.com/milvus-io/milvus/internal/util/logutil"
"github.com/milvus-io/milvus/internal/util/paramtable"
"github.com/milvus-io/milvus/internal/util/trace"
@ -340,10 +341,14 @@ func (s *Server) startGrpcLoop(grpcPort int) {
grpc.MaxSendMsgSize(Params.ServerMaxSendSize),
grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer(
ot.UnaryServerInterceptor(opts...),
logutil.UnaryTraceLoggerInterceptor)),
logutil.UnaryTraceLoggerInterceptor,
interceptor.ClusterValidationUnaryServerInterceptor(),
)),
grpc.StreamInterceptor(grpc_middleware.ChainStreamServer(
ot.StreamServerInterceptor(opts...),
logutil.StreamTraceLoggerInterceptor)))
logutil.StreamTraceLoggerInterceptor,
interceptor.ClusterValidationStreamServerInterceptor(),
)))
indexpb.RegisterIndexCoordServer(s.grpcServer, s)
go funcutil.CheckGrpcReady(ctx, s.grpcErrChan)

View File

@ -42,6 +42,7 @@ import (
"github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/internal/util/etcd"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/util/interceptor"
"github.com/milvus-io/milvus/internal/util/logutil"
"github.com/milvus-io/milvus/internal/util/paramtable"
"github.com/milvus-io/milvus/internal/util/trace"
@ -111,10 +112,14 @@ func (s *Server) startGrpcLoop(grpcPort int) {
grpc.MaxSendMsgSize(Params.ServerMaxSendSize),
grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer(
ot.UnaryServerInterceptor(opts...),
logutil.UnaryTraceLoggerInterceptor)),
logutil.UnaryTraceLoggerInterceptor,
interceptor.ClusterValidationUnaryServerInterceptor(),
)),
grpc.StreamInterceptor(grpc_middleware.ChainStreamServer(
ot.StreamServerInterceptor(opts...),
logutil.StreamTraceLoggerInterceptor)))
logutil.StreamTraceLoggerInterceptor,
interceptor.ClusterValidationStreamServerInterceptor(),
)))
indexpb.RegisterIndexNodeServer(s.grpcServer, s)
go funcutil.CheckGrpcReady(ctx, s.grpcErrChan)
if err := s.grpcServer.Serve(lis); err != nil {

View File

@ -59,6 +59,7 @@ import (
"github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/internal/util/etcd"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/util/interceptor"
"github.com/milvus-io/milvus/internal/util/logutil"
"github.com/milvus-io/milvus/internal/util/metricsinfo"
"github.com/milvus-io/milvus/internal/util/paramtable"
@ -271,7 +272,9 @@ func (s *Server) startInternalGrpc(grpcPort int, errChan chan error) {
grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer(
ot.UnaryServerInterceptor(opts...),
logutil.UnaryTraceLoggerInterceptor,
interceptor.ClusterValidationUnaryServerInterceptor(),
)),
grpc.StreamInterceptor(interceptor.ClusterValidationStreamServerInterceptor()),
)
proxypb.RegisterProxyServer(s.grpcInternalServer, s)
milvuspb.RegisterMilvusServiceServer(s.grpcInternalServer, s)

View File

@ -44,6 +44,7 @@ import (
"github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/internal/util/etcd"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/util/interceptor"
"github.com/milvus-io/milvus/internal/util/logutil"
"github.com/milvus-io/milvus/internal/util/paramtable"
"github.com/milvus-io/milvus/internal/util/trace"
@ -267,10 +268,14 @@ func (s *Server) startGrpcLoop(grpcPort int) {
grpc.MaxSendMsgSize(Params.ServerMaxSendSize),
grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer(
ot.UnaryServerInterceptor(opts...),
logutil.UnaryTraceLoggerInterceptor)),
logutil.UnaryTraceLoggerInterceptor,
interceptor.ClusterValidationUnaryServerInterceptor(),
)),
grpc.StreamInterceptor(grpc_middleware.ChainStreamServer(
ot.StreamServerInterceptor(opts...),
logutil.StreamTraceLoggerInterceptor)))
logutil.StreamTraceLoggerInterceptor,
interceptor.ClusterValidationStreamServerInterceptor(),
)))
querypb.RegisterQueryCoordServer(s.grpcServer, s)
go funcutil.CheckGrpcReady(ctx, s.grpcErrChan)

View File

@ -42,6 +42,7 @@ import (
"github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/internal/util/etcd"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/util/interceptor"
"github.com/milvus-io/milvus/internal/util/logutil"
"github.com/milvus-io/milvus/internal/util/paramtable"
"github.com/milvus-io/milvus/internal/util/retry"
@ -190,10 +191,14 @@ func (s *Server) startGrpcLoop(grpcPort int) {
grpc.MaxSendMsgSize(Params.ServerMaxSendSize),
grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer(
ot.UnaryServerInterceptor(opts...),
logutil.UnaryTraceLoggerInterceptor)),
logutil.UnaryTraceLoggerInterceptor,
interceptor.ClusterValidationUnaryServerInterceptor(),
)),
grpc.StreamInterceptor(grpc_middleware.ChainStreamServer(
ot.StreamServerInterceptor(opts...),
logutil.StreamTraceLoggerInterceptor)))
logutil.StreamTraceLoggerInterceptor,
interceptor.ClusterValidationStreamServerInterceptor(),
)))
querypb.RegisterQueryNodeServer(s.grpcServer, s)
ctx, cancel := context.WithCancel(s.ctx)

View File

@ -42,6 +42,7 @@ import (
"github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/internal/util/etcd"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/util/interceptor"
"github.com/milvus-io/milvus/internal/util/logutil"
"github.com/milvus-io/milvus/internal/util/paramtable"
"github.com/milvus-io/milvus/internal/util/trace"
@ -262,10 +263,14 @@ func (s *Server) startGrpcLoop(port int) {
grpc.MaxSendMsgSize(Params.ServerMaxSendSize),
grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer(
ot.UnaryServerInterceptor(opts...),
logutil.UnaryTraceLoggerInterceptor)),
logutil.UnaryTraceLoggerInterceptor,
interceptor.ClusterValidationUnaryServerInterceptor(),
)),
grpc.StreamInterceptor(grpc_middleware.ChainStreamServer(
ot.StreamServerInterceptor(opts...),
logutil.StreamTraceLoggerInterceptor)))
logutil.StreamTraceLoggerInterceptor,
interceptor.ClusterValidationStreamServerInterceptor(),
)))
rootcoordpb.RegisterRootCoordServer(s.grpcServer, s)
go funcutil.CheckGrpcReady(ctx, s.grpcErrChan)

View File

@ -293,7 +293,7 @@ func (ib *indexBuilder) process(buildID UniqueID) bool {
}
}
req := &indexpb.CreateJobRequest{
ClusterID: Params.CommonCfg.ClusterPrefix,
ClusterID: Params.CommonCfg.GetClusterPrefix(),
IndexFilePrefix: path.Join(ib.ic.chunkManager.RootPath(), common.SegmentIndexPath),
BuildID: buildID,
DataPaths: binLogs,
@ -390,7 +390,7 @@ func (ib *indexBuilder) getTaskState(buildID, nodeID UniqueID) indexTaskState {
ctx1, cancel := context.WithTimeout(ib.ctx, reqTimeoutInterval)
defer cancel()
response, err := client.QueryJobs(ctx1, &indexpb.QueryJobsRequest{
ClusterID: Params.CommonCfg.ClusterPrefix,
ClusterID: Params.CommonCfg.GetClusterPrefix(),
BuildIDs: []int64{buildID},
})
if err != nil {
@ -439,7 +439,7 @@ func (ib *indexBuilder) dropIndexTask(buildID, nodeID UniqueID) bool {
ctx1, cancel := context.WithTimeout(ib.ctx, reqTimeoutInterval)
defer cancel()
status, err := client.DropJobs(ctx1, &indexpb.DropJobsRequest{
ClusterID: Params.CommonCfg.ClusterPrefix,
ClusterID: Params.CommonCfg.GetClusterPrefix(),
BuildIDs: []UniqueID{buildID},
})
if err != nil {

View File

@ -20,9 +20,11 @@ import (
"context"
"crypto/tls"
"fmt"
"strings"
"sync"
"time"
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
grpcopentracing "github.com/grpc-ecosystem/go-grpc-middleware/tracing/opentracing"
"go.uber.org/zap"
"golang.org/x/sync/singleflight"
@ -37,6 +39,7 @@ import (
"github.com/milvus-io/milvus/internal/util/crypto"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/util/generic"
"github.com/milvus-io/milvus/internal/util/interceptor"
"github.com/milvus-io/milvus/internal/util/retry"
"github.com/milvus-io/milvus/internal/util/trace"
)
@ -188,8 +191,14 @@ func (c *ClientBase[T]) connect(ctx context.Context) error {
grpc.MaxCallRecvMsgSize(c.ClientMaxRecvSize),
grpc.MaxCallSendMsgSize(c.ClientMaxSendSize),
),
grpc.WithUnaryInterceptor(grpcopentracing.UnaryClientInterceptor(opts...)),
grpc.WithStreamInterceptor(grpcopentracing.StreamClientInterceptor(opts...)),
grpc.WithUnaryInterceptor(grpc_middleware.ChainUnaryClient(
grpcopentracing.UnaryClientInterceptor(opts...),
interceptor.ClusterInjectionUnaryClientInterceptor(),
)),
grpc.WithStreamInterceptor(grpc_middleware.ChainStreamClient(
grpcopentracing.StreamClientInterceptor(opts...),
interceptor.ClusterInjectionStreamClientInterceptor(),
)),
grpc.WithDefaultServiceConfig(retryPolicy),
grpc.WithKeepaliveParams(keepalive.ClientParameters{
Time: c.KeepAliveTime,
@ -218,8 +227,14 @@ func (c *ClientBase[T]) connect(ctx context.Context) error {
grpc.MaxCallRecvMsgSize(c.ClientMaxRecvSize),
grpc.MaxCallSendMsgSize(c.ClientMaxSendSize),
),
grpc.WithUnaryInterceptor(grpcopentracing.UnaryClientInterceptor(opts...)),
grpc.WithStreamInterceptor(grpcopentracing.StreamClientInterceptor(opts...)),
grpc.WithUnaryInterceptor(grpc_middleware.ChainUnaryClient(
grpcopentracing.UnaryClientInterceptor(opts...),
interceptor.ClusterInjectionUnaryClientInterceptor(),
)),
grpc.WithStreamInterceptor(grpc_middleware.ChainStreamClient(
grpcopentracing.StreamClientInterceptor(opts...),
interceptor.ClusterInjectionStreamClientInterceptor(),
)),
grpc.WithDefaultServiceConfig(retryPolicy),
grpc.WithKeepaliveParams(keepalive.ClientParameters{
Time: c.KeepAliveTime,
@ -273,6 +288,14 @@ func (c *ClientBase[T]) callOnce(ctx context.Context, caller func(client T) (any
go c.bgHealthCheck(client)
return generic.Zero[T](), err
}
if IsCrossClusterRoutingErr(err) {
log.Warn("CrossClusterRoutingErr, start to reset connection",
zap.String("role", c.GetRole()),
zap.Error(err),
)
c.resetConnection(client)
return ret, interceptor.ErrServiceUnavailable // For concealing ErrCrossClusterRouting from the client
}
if !funcutil.IsGrpcErr(err) {
log.Warn("ClientBase:isNotGrpcErr", zap.Error(err))
return generic.Zero[T](), err
@ -367,3 +390,9 @@ func (c *ClientBase[T]) SetNodeID(nodeID int64) {
// GetNodeID returns the value of the client's NodeID field.
func (c *ClientBase[T]) GetNodeID() int64 {
	return c.NodeID
}
// IsCrossClusterRoutingErr reports whether err's message contains the
// interceptor.ErrCrossClusterRouting message. A nil err is never treated as a
// cross-cluster routing error.
func IsCrossClusterRoutingErr(err error) bool {
	if err == nil {
		// Guard against a nil dereference in err.Error() below.
		return false
	}
	// GRPC utilizes `status.Status` to encapsulate errors,
	// hence it is not viable to employ the `errors.Is` for assessment.
	return strings.Contains(err.Error(), interceptor.ErrCrossClusterRouting.Error())
}

View File

@ -0,0 +1,108 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package interceptor
import (
"context"
"fmt"
"strings"
"github.com/cockroachdb/errors"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
"github.com/milvus-io/milvus/internal/util/paramtable"
)
// ClusterKey is the gRPC metadata key under which the client interceptors in
// this package send the cluster prefix and the server interceptors read it.
const ClusterKey = "Cluster"

// Params is the package-level parameter table. It is initialized in init and
// supplies the local cluster prefix via Params.CommonCfg.GetClusterPrefix().
var Params paramtable.ComponentParam

var (
	// ErrCrossClusterRouting is the base error wrapped by
	// WrapErrCrossClusterRouting when the caller's cluster does not match.
	ErrCrossClusterRouting = fmt.Errorf("cross cluster routing")
	// ErrServiceUnavailable is the generic error handed back instead of
	// ErrCrossClusterRouting.
	ErrServiceUnavailable = fmt.Errorf("service unavailable") // For concealing ErrCrossClusterRouting from the client
)

// init initializes Params when the package is loaded.
func init() {
	Params.Init()
}
// WrapErrCrossClusterRouting builds an error that wraps ErrCrossClusterRouting
// and records both the expected and the actual cluster. When extra messages
// are supplied they are joined with "; " and wrapped around the result.
func WrapErrCrossClusterRouting(expectedCluster, actualCluster string, msg ...string) error {
	wrapped := errors.Wrapf(ErrCrossClusterRouting, "expectedCluster=%s, actualCluster=%s", expectedCluster, actualCluster)
	if len(msg) == 0 {
		return wrapped
	}
	return errors.Wrap(wrapped, strings.Join(msg, "; "))
}
// ClusterValidationUnaryServerInterceptor returns a new unary server interceptor that
// rejects the request if the client's cluster differs from that of the server.
// It is chiefly employed to tackle the `Cross-Cluster Routing` issue.
//
// Requests that carry no incoming metadata, no ClusterKey entry, or an empty
// cluster value are passed to the handler unchanged. Only the first ClusterKey
// value is inspected. On mismatch the handler is not invoked and an error
// wrapping ErrCrossClusterRouting is returned.
func ClusterValidationUnaryServerInterceptor() grpc.UnaryServerInterceptor {
	return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
		md, ok := metadata.FromIncomingContext(ctx)
		if !ok {
			return handler(ctx, req)
		}
		clusters := md.Get(ClusterKey)
		if len(clusters) == 0 {
			return handler(ctx, req)
		}
		// Read the local prefix exactly once: it can be changed at runtime via
		// SetClusterPrefix, and the error must report the value that was
		// actually compared against.
		expected := Params.CommonCfg.GetClusterPrefix()
		if actual := clusters[0]; actual != "" && actual != expected {
			return nil, WrapErrCrossClusterRouting(expected, actual)
		}
		return handler(ctx, req)
	}
}
// ClusterValidationStreamServerInterceptor returns a new streaming server interceptor that
// rejects the request if the client's cluster differs from that of the server.
// It is chiefly employed to tackle the `Cross-Cluster Routing` issue.
//
// Streams whose context carries no incoming metadata, no ClusterKey entry, or
// an empty cluster value are passed to the handler unchanged. Only the first
// ClusterKey value is inspected. On mismatch the handler is not invoked and an
// error wrapping ErrCrossClusterRouting is returned.
func ClusterValidationStreamServerInterceptor() grpc.StreamServerInterceptor {
	return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
		md, ok := metadata.FromIncomingContext(ss.Context())
		if !ok {
			return handler(srv, ss)
		}
		clusters := md.Get(ClusterKey)
		if len(clusters) == 0 {
			return handler(srv, ss)
		}
		// Read the local prefix exactly once: it can be changed at runtime via
		// SetClusterPrefix, and the error must report the value that was
		// actually compared against.
		expected := Params.CommonCfg.GetClusterPrefix()
		if actual := clusters[0]; actual != "" && actual != expected {
			return WrapErrCrossClusterRouting(expected, actual)
		}
		return handler(srv, ss)
	}
}
// ClusterInjectionUnaryClientInterceptor builds a unary client interceptor that
// appends the local cluster prefix, under ClusterKey, to the outgoing metadata
// of every call before handing the call to the invoker.
func ClusterInjectionUnaryClientInterceptor() grpc.UnaryClientInterceptor {
	inject := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
		outgoing := metadata.AppendToOutgoingContext(ctx, ClusterKey, Params.CommonCfg.GetClusterPrefix())
		return invoker(outgoing, method, req, reply, cc, opts...)
	}
	return inject
}
// ClusterInjectionStreamClientInterceptor builds a streaming client interceptor
// that appends the local cluster prefix, under ClusterKey, to the outgoing
// metadata before the stream is created by the streamer.
func ClusterInjectionStreamClientInterceptor() grpc.StreamClientInterceptor {
	inject := func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
		outgoing := metadata.AppendToOutgoingContext(ctx, ClusterKey, Params.CommonCfg.GetClusterPrefix())
		return streamer(outgoing, desc, cc, method, opts...)
	}
	return inject
}

View File

@ -0,0 +1,142 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package interceptor
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
)
// mockSS is a test double for grpc.ServerStream. It embeds the interface to
// satisfy it and overrides only Context, returning the stored ctx.
type mockSS struct {
	grpc.ServerStream
	// ctx is the context handed back by Context.
	ctx context.Context
}
// newMockSS returns a mock server stream whose Context method yields ctx.
// The embedded grpc.ServerStream is left unset.
func newMockSS(ctx context.Context) grpc.ServerStream {
	stream := new(mockSS)
	stream.ctx = ctx
	return stream
}
// Context returns the context the mock stream was created with.
func (m *mockSS) Context() context.Context {
	return m.ctx
}
// TestClusterInterceptor exercises the four cluster interceptors:
//   - the client interceptors must add the local cluster prefix to the
//     outgoing metadata under ClusterKey;
//   - the server interceptors must pass requests with no metadata, no cluster
//     entry, or a matching cluster, and reject a mismatching cluster with an
//     error wrapping ErrCrossClusterRouting.
func TestClusterInterceptor(t *testing.T) {
	t.Run("test ClusterInjectionUnaryClientInterceptor", func(t *testing.T) {
		method := "MockMethod"
		req := &milvuspb.InsertRequest{}

		// Capture the context the interceptor hands to the invoker.
		var incomingContext context.Context
		invoker := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error {
			incomingContext = ctx
			return nil
		}
		interceptor := ClusterInjectionUnaryClientInterceptor()
		ctx := metadata.NewOutgoingContext(context.Background(), metadata.New(make(map[string]string)))
		err := interceptor(ctx, method, req, nil, nil, invoker)
		assert.NoError(t, err)

		md, ok := metadata.FromOutgoingContext(incomingContext)
		assert.True(t, ok)
		assert.Equal(t, Params.CommonCfg.GetClusterPrefix(), md.Get(ClusterKey)[0])
	})

	t.Run("test ClusterInjectionStreamClientInterceptor", func(t *testing.T) {
		method := "MockMethod"

		// Capture the context the interceptor hands to the streamer.
		var incomingContext context.Context
		streamer := func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) {
			incomingContext = ctx
			return nil, nil
		}
		interceptor := ClusterInjectionStreamClientInterceptor()
		ctx := metadata.NewOutgoingContext(context.Background(), metadata.New(make(map[string]string)))
		_, err := interceptor(ctx, nil, nil, method, streamer)
		assert.NoError(t, err)

		md, ok := metadata.FromOutgoingContext(incomingContext)
		assert.True(t, ok)
		assert.Equal(t, Params.CommonCfg.GetClusterPrefix(), md.Get(ClusterKey)[0])
	})

	t.Run("test ClusterValidationUnaryServerInterceptor", func(t *testing.T) {
		method := "MockMethod"
		req := &milvuspb.InsertRequest{}

		handler := func(ctx context.Context, req interface{}) (interface{}, error) {
			return nil, nil
		}
		serverInfo := &grpc.UnaryServerInfo{FullMethod: method}
		interceptor := ClusterValidationUnaryServerInterceptor()

		// no md in context
		_, err := interceptor(context.Background(), req, serverInfo, handler)
		assert.NoError(t, err)

		// no cluster in md
		ctx := metadata.NewIncomingContext(context.Background(), metadata.New(make(map[string]string)))
		_, err = interceptor(ctx, req, serverInfo, handler)
		assert.NoError(t, err)

		// with cross-cluster
		md := metadata.Pairs(ClusterKey, "ins-1")
		ctx = metadata.NewIncomingContext(context.Background(), md)
		_, err = interceptor(ctx, req, serverInfo, handler)
		assert.ErrorIs(t, err, ErrCrossClusterRouting)

		// with same cluster
		md = metadata.Pairs(ClusterKey, Params.CommonCfg.GetClusterPrefix())
		ctx = metadata.NewIncomingContext(context.Background(), md)
		_, err = interceptor(ctx, req, serverInfo, handler)
		assert.NoError(t, err)
	})

	// The subtest name previously duplicated the unary one; it covers the
	// stream server interceptor and is named accordingly.
	t.Run("test ClusterValidationStreamServerInterceptor", func(t *testing.T) {
		handler := func(srv interface{}, stream grpc.ServerStream) error {
			return nil
		}
		interceptor := ClusterValidationStreamServerInterceptor()

		// no md in context
		err := interceptor(nil, newMockSS(context.Background()), nil, handler)
		assert.NoError(t, err)

		// no cluster in md
		ctx := metadata.NewIncomingContext(context.Background(), metadata.New(make(map[string]string)))
		err = interceptor(nil, newMockSS(ctx), nil, handler)
		assert.NoError(t, err)

		// with cross-cluster
		md := metadata.Pairs(ClusterKey, "ins-1")
		ctx = metadata.NewIncomingContext(context.Background(), md)
		err = interceptor(nil, newMockSS(ctx), nil, handler)
		assert.ErrorIs(t, err, ErrCrossClusterRouting)

		// with same cluster
		md = metadata.Pairs(ClusterKey, Params.CommonCfg.GetClusterPrefix())
		ctx = metadata.NewIncomingContext(context.Background(), md)
		err = interceptor(nil, newMockSS(ctx), nil, handler)
		assert.NoError(t, err)
	})
}

View File

@ -128,7 +128,7 @@ func (p *ComponentParam) KafkaEnable() bool {
type commonConfig struct {
Base *BaseTable
ClusterPrefix string
clusterPrefix atomic.Value
ProxySubName string
@ -247,7 +247,19 @@ func (p *commonConfig) initClusterPrefix() {
if err != nil {
panic(err)
}
p.ClusterPrefix = str
p.clusterPrefix.Store(str)
}
// SetClusterPrefix atomically stores cluster as the current cluster prefix.
func (p *commonConfig) SetClusterPrefix(cluster string) {
	p.clusterPrefix.Store(cluster)
}
// GetClusterPrefix atomically loads the current cluster prefix. It returns the
// empty string when no prefix has been stored yet.
func (p *commonConfig) GetClusterPrefix() string {
	if prefix, ok := p.clusterPrefix.Load().(string); ok {
		return prefix
	}
	return ""
}
func (p *commonConfig) initChanNamePrefix(keys []string) string {
@ -255,7 +267,7 @@ func (p *commonConfig) initChanNamePrefix(keys []string) string {
if err != nil {
panic(err)
}
s := []string{p.ClusterPrefix, value}
s := []string{p.GetClusterPrefix(), value}
return strings.Join(s, "-")
}

View File

@ -0,0 +1,314 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package crossclusterrouting
import (
"context"
"fmt"
"math/rand"
"strings"
"testing"
"time"
"github.com/stretchr/testify/suite"
clientv3 "go.etcd.io/etcd/client/v3"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/proto/indexpb"
"github.com/milvus-io/milvus/internal/proto/proxypb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/internal/util/etcd"
"github.com/milvus-io/milvus/internal/util/interceptor"
"github.com/milvus-io/milvus/internal/util/paramtable"
"github.com/milvus-io/milvus/internal/util/typeutil"
grpcdatacoord "github.com/milvus-io/milvus/internal/distributed/datacoord"
grpcdatacoordclient "github.com/milvus-io/milvus/internal/distributed/datacoord/client"
grpcdatanode "github.com/milvus-io/milvus/internal/distributed/datanode"
grpcdatanodeclient "github.com/milvus-io/milvus/internal/distributed/datanode/client"
grpcindexcoord "github.com/milvus-io/milvus/internal/distributed/indexcoord"
grpcindexcoordclient "github.com/milvus-io/milvus/internal/distributed/indexcoord/client"
grpcindexnode "github.com/milvus-io/milvus/internal/distributed/indexnode"
grpcindexnodeclient "github.com/milvus-io/milvus/internal/distributed/indexnode/client"
grpcproxy "github.com/milvus-io/milvus/internal/distributed/proxy"
grpcproxyclient "github.com/milvus-io/milvus/internal/distributed/proxy/client"
grpcquerycoord "github.com/milvus-io/milvus/internal/distributed/querycoord"
grpcquerycoordclient "github.com/milvus-io/milvus/internal/distributed/querycoord/client"
grpcquerynode "github.com/milvus-io/milvus/internal/distributed/querynode"
grpcquerynodeclient "github.com/milvus-io/milvus/internal/distributed/querynode/client"
grpcrootcoord "github.com/milvus-io/milvus/internal/distributed/rootcoord"
grpcrootcoordclient "github.com/milvus-io/milvus/internal/distributed/rootcoord/client"
)
// CrossClusterRoutingSuite is a testify suite that starts one gRPC server per
// component and one client per server, so that calls between them can be
// checked for cross-cluster routing rejection.
type CrossClusterRoutingSuite struct {
	suite.Suite

	// ctx and cancel bound the lifetime of servers, clients and the
	// background goroutine started by the test.
	ctx    context.Context
	cancel context.CancelFunc

	factory dependency.Factory
	// client is the etcd client passed to the coordinator clients.
	client *clientv3.Client

	// clients
	rootCoordClient  *grpcrootcoordclient.Client
	proxyClient      *grpcproxyclient.Client
	dataCoordClient  *grpcdatacoordclient.Client
	indexCoordClient *grpcindexcoordclient.Client
	queryCoordClient *grpcquerycoordclient.Client
	dataNodeClient   *grpcdatanodeclient.Client
	queryNodeClient  *grpcquerynodeclient.Client
	indexNodeClient  *grpcindexnodeclient.Client

	// servers
	rootCoord  *grpcrootcoord.Server
	proxy      *grpcproxy.Server
	dataCoord  *grpcdatacoord.Server
	indexCoord *grpcindexcoord.Server
	queryCoord *grpcquerycoord.Server
	dataNode   *grpcdatanode.Server
	queryNode  *grpcquerynode.Server
	indexNode  *grpcindexnode.Server
}
// SetupSuite seeds math/rand, creates the suite context with a 180 second
// timeout, and builds the default dependency factory.
func (s *CrossClusterRoutingSuite) SetupSuite() {
	rand.Seed(time.Now().UnixNano())

	ctx, cancel := context.WithTimeout(context.Background(), time.Second*180)
	s.ctx = ctx
	s.cancel = cancel

	s.factory = dependency.NewDefaultFactory(true)
}
// TearDownSuite performs no work; its body is empty.
func (s *CrossClusterRoutingSuite) TearDownSuite() {
}
// SetupTest creates the etcd client, then constructs and runs one gRPC server
// per component, and finally creates a client for each of those servers.
// Every construction/run error is asserted with s.NoError.
func (s *CrossClusterRoutingSuite) SetupTest() {
	s.T().Logf("Setup test...")
	var err error

	// setup etcd client
	etcdConfig := interceptor.Params.EtcdCfg
	s.client, err = etcd.GetEtcdClient(
		etcdConfig.UseEmbedEtcd,
		etcdConfig.EtcdUseSSL,
		etcdConfig.Endpoints,
		etcdConfig.EtcdTLSCert,
		etcdConfig.EtcdTLSKey,
		etcdConfig.EtcdTLSCACert,
		etcdConfig.EtcdTLSMinVersion)
	s.NoError(err)

	// setup servers
	s.rootCoord, err = grpcrootcoord.NewServer(s.ctx, s.factory)
	s.NoError(err)
	err = s.rootCoord.Run()
	s.NoError(err)
	s.T().Logf("rootCoord server successfully started")

	// NOTE: unlike the other constructors used here, this one returns only
	// the server (no error), so the result is checked with NotNil.
	s.dataCoord = grpcdatacoord.NewServer(s.ctx, s.factory)
	s.NotNil(s.dataCoord)
	err = s.dataCoord.Run()
	s.NoError(err)
	s.T().Logf("dataCoord server successfully started")

	s.indexCoord, err = grpcindexcoord.NewServer(s.ctx, s.factory)
	s.NoError(err)
	err = s.indexCoord.Run()
	s.NoError(err)
	s.T().Logf("indexCoord server successfully started")

	s.queryCoord, err = grpcquerycoord.NewServer(s.ctx, s.factory)
	s.NoError(err)
	err = s.queryCoord.Run()
	s.NoError(err)
	s.T().Logf("queryCoord server successfully started")

	s.proxy, err = grpcproxy.NewServer(s.ctx, s.factory)
	s.NoError(err)
	err = s.proxy.Run()
	s.NoError(err)
	s.T().Logf("proxy server successfully started")

	s.dataNode, err = grpcdatanode.NewServer(s.ctx, s.factory)
	s.NoError(err)
	err = s.dataNode.Run()
	s.NoError(err)
	s.T().Logf("dataNode server successfully started")

	s.queryNode, err = grpcquerynode.NewServer(s.ctx, s.factory)
	s.NoError(err)
	err = s.queryNode.Run()
	s.NoError(err)
	s.T().Logf("queryNode server successfully started")

	s.indexNode, err = grpcindexnode.NewServer(s.ctx, s.factory)
	s.NoError(err)
	err = s.indexNode.Run()
	s.NoError(err)
	s.T().Logf("indexNode server successfully started")

	metaRoot := interceptor.Params.EtcdCfg.MetaRootPath

	// setup clients
	// Coordinator clients are created from the etcd meta root and etcd client.
	s.rootCoordClient, err = grpcrootcoordclient.NewClient(s.ctx, metaRoot, s.client)
	s.NoError(err)
	s.dataCoordClient, err = grpcdatacoordclient.NewClient(s.ctx, metaRoot, s.client)
	s.NoError(err)
	s.indexCoordClient, err = grpcindexcoordclient.NewClient(s.ctx, metaRoot, s.client)
	s.NoError(err)
	s.queryCoordClient, err = grpcquerycoordclient.NewClient(s.ctx, metaRoot, s.client)
	s.NoError(err)

	// Proxy and node clients are created from an address taken from a
	// per-role GrpcServerConfig.
	// NOTE(review): proxy and queryNode use InitOnce while dataNode and
	// indexNode use Init — confirm the difference is intentional.
	var proxyGrpcServerParam paramtable.GrpcServerConfig
	proxyGrpcServerParam.InitOnce(typeutil.ProxyRole)
	s.proxyClient, err = grpcproxyclient.NewClient(s.ctx, proxyGrpcServerParam.GetInternalAddress())
	s.NoError(err)

	var dataNodeGrpcServerParam paramtable.GrpcServerConfig
	dataNodeGrpcServerParam.Init(typeutil.DataNodeRole)
	s.dataNodeClient, err = grpcdatanodeclient.NewClient(s.ctx, dataNodeGrpcServerParam.GetAddress())
	s.NoError(err)

	var queryNodeServerParam paramtable.GrpcServerConfig
	queryNodeServerParam.InitOnce(typeutil.QueryNodeRole)
	s.queryNodeClient, err = grpcquerynodeclient.NewClient(s.ctx, queryNodeServerParam.GetAddress())
	s.NoError(err)

	var indexNodeGrpcServerParam paramtable.GrpcServerConfig
	indexNodeGrpcServerParam.Init(typeutil.IndexNodeRole)
	s.indexNodeClient, err = grpcindexnodeclient.NewClient(s.ctx, indexNodeGrpcServerParam.GetAddress(), false)
	s.NoError(err)
}
// TearDownTest stops every server in a fixed order, asserting that each Stop
// succeeds, and then cancels the suite context.
func (s *CrossClusterRoutingSuite) TearDownTest() {
	// Order matches the sequence in which the servers are shut down.
	stoppers := []func() error{
		s.rootCoord.Stop,
		s.proxy.Stop,
		s.dataCoord.Stop,
		s.indexCoord.Stop,
		s.queryCoord.Stop,
		s.dataNode.Stop,
		s.queryNode.Stop,
		s.indexNode.Stop,
	}
	for _, stop := range stoppers {
		s.NoError(stop())
	}
	s.cancel()
}
// TestCrossClusterRoutingSuite checks that every component client eventually
// observes interceptor.ErrServiceUnavailable while the process-wide cluster
// prefix is being changed concurrently.
//
// A background goroutine keeps overwriting the cluster prefix with random
// values until s.ctx is done. Each component is then polled (every 10ms, for
// up to 10s) until one call returns an error whose message contains the
// ErrServiceUnavailable text.
func (s *CrossClusterRoutingSuite) TestCrossClusterRoutingSuite() {
	const (
		waitFor  = time.Second * 10
		duration = time.Millisecond * 10
	)

	// Continuously randomize the cluster prefix; exits when s.ctx is done.
	go func() {
		for {
			select {
			case <-s.ctx.Done():
				return
			default:
				interceptor.Params.CommonCfg.SetClusterPrefix(fmt.Sprintf("%d", rand.Int()))
			}
		}
	}()

	// expectUnavailable polls call until it fails with an error containing
	// the ErrServiceUnavailable message. %v is used for logging because err
	// is nil on successful attempts.
	expectUnavailable := func(call func() (interface{}, error)) {
		s.Eventually(func() bool {
			resp, err := call()
			s.T().Logf("resp: %v, err: %v", resp, err)
			return err != nil && strings.Contains(err.Error(), interceptor.ErrServiceUnavailable.Error())
		}, waitFor, duration)
	}

	// test rootCoord
	expectUnavailable(func() (interface{}, error) {
		return s.rootCoordClient.ShowCollections(s.ctx, &milvuspb.ShowCollectionsRequest{})
	})

	// test dataCoord
	expectUnavailable(func() (interface{}, error) {
		return s.dataCoordClient.GetRecoveryInfoV2(s.ctx, &datapb.GetRecoveryInfoRequestV2{})
	})

	// test indexCoord
	expectUnavailable(func() (interface{}, error) {
		return s.indexCoordClient.CreateIndex(s.ctx, &indexpb.CreateIndexRequest{})
	})

	// test queryCoord
	expectUnavailable(func() (interface{}, error) {
		return s.queryCoordClient.LoadCollection(s.ctx, &querypb.LoadCollectionRequest{})
	})

	// test proxy
	expectUnavailable(func() (interface{}, error) {
		return s.proxyClient.InvalidateCollectionMetaCache(s.ctx, &proxypb.InvalidateCollMetaCacheRequest{})
	})

	// test dataNode
	expectUnavailable(func() (interface{}, error) {
		return s.dataNodeClient.FlushSegments(s.ctx, &datapb.FlushSegmentsRequest{})
	})

	// test queryNode
	expectUnavailable(func() (interface{}, error) {
		return s.queryNodeClient.Search(s.ctx, &querypb.SearchRequest{})
	})

	// test indexNode
	expectUnavailable(func() (interface{}, error) {
		return s.indexNodeClient.CreateJob(s.ctx, &indexpb.CreateJobRequest{})
	})
}
// TestCrossClusterRoutingSuite is the go test entry point that runs
// CrossClusterRoutingSuite through testify's suite runner.
func TestCrossClusterRoutingSuite(t *testing.T) {
	suite.Run(t, new(CrossClusterRoutingSuite))
}