diff --git a/internal/distributed/datacoord/client/client.go b/internal/distributed/datacoord/client/client.go index 5973b4543b..b70e7590b0 100644 --- a/internal/distributed/datacoord/client/client.go +++ b/internal/distributed/datacoord/client/client.go @@ -639,14 +639,7 @@ func (c *Client) DropIndex(ctx context.Context, req *indexpb.DropIndexRequest) ( } func (c *Client) ReportDataNodeTtMsgs(ctx context.Context, req *datapb.ReportDataNodeTtMsgsRequest) (*commonpb.Status, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client datapb.DataCoordClient) (any, error) { - if !funcutil.CheckCtxValid(ctx) { - return nil, ctx.Err() - } + return wrapGrpcCall(ctx, c, func(client datapb.DataCoordClient) (*commonpb.Status, error) { return client.ReportDataNodeTtMsgs(ctx, req) }) - if err != nil || ret == nil { - return nil, err - } - return ret.(*commonpb.Status), err } diff --git a/internal/distributed/datacoord/service.go b/internal/distributed/datacoord/service.go index eb5e5bd31f..ab71af526e 100644 --- a/internal/distributed/datacoord/service.go +++ b/internal/distributed/datacoord/service.go @@ -27,6 +27,7 @@ import ( "github.com/milvus-io/milvus/internal/proto/indexpb" "github.com/milvus-io/milvus/internal/util/dependency" "github.com/milvus-io/milvus/pkg/tracer" + "github.com/milvus-io/milvus/pkg/util/interceptor" grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" clientv3 "go.etcd.io/etcd/client/v3" @@ -151,10 +152,14 @@ func (s *Server) startGrpcLoop(grpcPort int) { grpc.MaxSendMsgSize(Params.ServerMaxSendSize.GetAsInt()), grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer( otelgrpc.UnaryServerInterceptor(opts...), - logutil.UnaryTraceLoggerInterceptor)), + logutil.UnaryTraceLoggerInterceptor, + interceptor.ClusterValidationUnaryServerInterceptor(), + )), grpc.StreamInterceptor(grpc_middleware.ChainStreamServer( otelgrpc.StreamServerInterceptor(opts...), - logutil.StreamTraceLoggerInterceptor))) + logutil.StreamTraceLoggerInterceptor, + 
interceptor.ClusterValidationStreamServerInterceptor(), + ))) indexpb.RegisterIndexCoordServer(s.grpcServer, s) datapb.RegisterDataCoordServer(s.grpcServer, s) go funcutil.CheckGrpcReady(ctx, s.grpcErrChan) diff --git a/internal/distributed/datanode/service.go b/internal/distributed/datanode/service.go index 400751ff16..3cf75e83b1 100644 --- a/internal/distributed/datanode/service.go +++ b/internal/distributed/datanode/service.go @@ -28,6 +28,7 @@ import ( "github.com/milvus-io/milvus/internal/util/componentutil" "github.com/milvus-io/milvus/internal/util/dependency" "github.com/milvus-io/milvus/pkg/tracer" + "github.com/milvus-io/milvus/pkg/util/interceptor" grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" clientv3 "go.etcd.io/etcd/client/v3" @@ -135,10 +136,14 @@ func (s *Server) startGrpcLoop(grpcPort int) { grpc.MaxSendMsgSize(Params.ServerMaxSendSize.GetAsInt()), grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer( otelgrpc.UnaryServerInterceptor(opts...), - logutil.UnaryTraceLoggerInterceptor)), + logutil.UnaryTraceLoggerInterceptor, + interceptor.ClusterValidationUnaryServerInterceptor(), + )), grpc.StreamInterceptor(grpc_middleware.ChainStreamServer( otelgrpc.StreamServerInterceptor(opts...), - logutil.StreamTraceLoggerInterceptor))) + logutil.StreamTraceLoggerInterceptor, + interceptor.ClusterValidationStreamServerInterceptor(), + ))) datapb.RegisterDataNodeServer(s.grpcServer, s) ctx, cancel := context.WithCancel(s.ctx) diff --git a/internal/distributed/indexnode/service.go b/internal/distributed/indexnode/service.go index 9cdce98ddf..50f6620b0b 100644 --- a/internal/distributed/indexnode/service.go +++ b/internal/distributed/indexnode/service.go @@ -42,6 +42,7 @@ import ( "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/etcd" "github.com/milvus-io/milvus/pkg/util/funcutil" + "github.com/milvus-io/milvus/pkg/util/interceptor" "github.com/milvus-io/milvus/pkg/util/logutil" 
"github.com/milvus-io/milvus/pkg/util/paramtable" ) @@ -107,10 +108,14 @@ func (s *Server) startGrpcLoop(grpcPort int) { grpc.MaxSendMsgSize(Params.ServerMaxSendSize.GetAsInt()), grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer( otelgrpc.UnaryServerInterceptor(opts...), - logutil.UnaryTraceLoggerInterceptor)), + logutil.UnaryTraceLoggerInterceptor, + interceptor.ClusterValidationUnaryServerInterceptor(), + )), grpc.StreamInterceptor(grpc_middleware.ChainStreamServer( otelgrpc.StreamServerInterceptor(opts...), - logutil.StreamTraceLoggerInterceptor))) + logutil.StreamTraceLoggerInterceptor, + interceptor.ClusterValidationStreamServerInterceptor(), + ))) indexpb.RegisterIndexNodeServer(s.grpcServer, s) go funcutil.CheckGrpcReady(ctx, s.grpcErrChan) if err := s.grpcServer.Serve(lis); err != nil { diff --git a/internal/distributed/proxy/service.go b/internal/distributed/proxy/service.go index cba4a7a150..e30161cd3b 100644 --- a/internal/distributed/proxy/service.go +++ b/internal/distributed/proxy/service.go @@ -34,6 +34,7 @@ import ( "github.com/milvus-io/milvus/internal/util/componentutil" "github.com/milvus-io/milvus/internal/util/dependency" "github.com/milvus-io/milvus/pkg/tracer" + "github.com/milvus-io/milvus/pkg/util/interceptor" "github.com/milvus-io/milvus/pkg/util/metricsinfo" "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" @@ -264,7 +265,9 @@ func (s *Server) startInternalGrpc(grpcPort int, errChan chan error) { grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer( otelgrpc.UnaryServerInterceptor(opts...), logutil.UnaryTraceLoggerInterceptor, + interceptor.ClusterValidationUnaryServerInterceptor(), )), + grpc.StreamInterceptor(interceptor.ClusterValidationStreamServerInterceptor()), ) proxypb.RegisterProxyServer(s.grpcInternalServer, s) grpc_health_v1.RegisterHealthServer(s.grpcInternalServer, s) diff --git a/internal/distributed/querycoord/service.go b/internal/distributed/querycoord/service.go index 
bb2945bbde..e0a199bea6 100644 --- a/internal/distributed/querycoord/service.go +++ b/internal/distributed/querycoord/service.go @@ -27,6 +27,7 @@ import ( "github.com/milvus-io/milvus/internal/util/componentutil" "github.com/milvus-io/milvus/internal/util/dependency" "github.com/milvus-io/milvus/pkg/tracer" + "github.com/milvus-io/milvus/pkg/util/interceptor" clientv3 "go.etcd.io/etcd/client/v3" "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" "go.uber.org/zap" @@ -225,10 +226,14 @@ func (s *Server) startGrpcLoop(grpcPort int) { grpc.MaxSendMsgSize(Params.ServerMaxSendSize.GetAsInt()), grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer( otelgrpc.UnaryServerInterceptor(opts...), - logutil.UnaryTraceLoggerInterceptor)), + logutil.UnaryTraceLoggerInterceptor, + interceptor.ClusterValidationUnaryServerInterceptor(), + )), grpc.StreamInterceptor(grpc_middleware.ChainStreamServer( otelgrpc.StreamServerInterceptor(opts...), - logutil.StreamTraceLoggerInterceptor))) + logutil.StreamTraceLoggerInterceptor, + interceptor.ClusterValidationStreamServerInterceptor(), + ))) querypb.RegisterQueryCoordServer(s.grpcServer, s) go funcutil.CheckGrpcReady(ctx, s.grpcErrChan) diff --git a/internal/distributed/querynode/service.go b/internal/distributed/querynode/service.go index 44f73717b2..933d5890ed 100644 --- a/internal/distributed/querynode/service.go +++ b/internal/distributed/querynode/service.go @@ -27,6 +27,7 @@ import ( grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" "github.com/milvus-io/milvus/internal/util/dependency" "github.com/milvus-io/milvus/pkg/tracer" + "github.com/milvus-io/milvus/pkg/util/interceptor" clientv3 "go.etcd.io/etcd/client/v3" "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" "go.uber.org/zap" @@ -180,10 +181,14 @@ func (s *Server) startGrpcLoop(grpcPort int) { grpc.MaxSendMsgSize(Params.ServerMaxSendSize.GetAsInt()), grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer( 
otelgrpc.UnaryServerInterceptor(opts...), - logutil.UnaryTraceLoggerInterceptor)), + logutil.UnaryTraceLoggerInterceptor, + interceptor.ClusterValidationUnaryServerInterceptor(), + )), grpc.StreamInterceptor(grpc_middleware.ChainStreamServer( otelgrpc.StreamServerInterceptor(opts...), - logutil.StreamTraceLoggerInterceptor))) + logutil.StreamTraceLoggerInterceptor, + interceptor.ClusterValidationStreamServerInterceptor(), + ))) querypb.RegisterQueryNodeServer(s.grpcServer, s) ctx, cancel := context.WithCancel(s.ctx) diff --git a/internal/distributed/rootcoord/service.go b/internal/distributed/rootcoord/service.go index 946a7f6419..b125ca330d 100644 --- a/internal/distributed/rootcoord/service.go +++ b/internal/distributed/rootcoord/service.go @@ -26,6 +26,7 @@ import ( grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" "github.com/milvus-io/milvus/internal/util/dependency" "github.com/milvus-io/milvus/pkg/tracer" + "github.com/milvus-io/milvus/pkg/util/interceptor" clientv3 "go.etcd.io/etcd/client/v3" "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" "go.uber.org/zap" @@ -240,10 +241,14 @@ func (s *Server) startGrpcLoop(port int) { grpc.MaxSendMsgSize(Params.ServerMaxSendSize.GetAsInt()), grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer( otelgrpc.UnaryServerInterceptor(opts...), - logutil.UnaryTraceLoggerInterceptor)), + logutil.UnaryTraceLoggerInterceptor, + interceptor.ClusterValidationUnaryServerInterceptor(), + )), grpc.StreamInterceptor(grpc_middleware.ChainStreamServer( otelgrpc.StreamServerInterceptor(opts...), - logutil.StreamTraceLoggerInterceptor))) + logutil.StreamTraceLoggerInterceptor, + interceptor.ClusterValidationStreamServerInterceptor(), + ))) rootcoordpb.RegisterRootCoordServer(s.grpcServer, s) go funcutil.CheckGrpcReady(ctx, s.grpcErrChan) diff --git a/internal/util/grpcclient/client.go b/internal/util/grpcclient/client.go index 3450334ad4..8642b8ca1f 100644 --- a/internal/util/grpcclient/client.go 
+++ b/internal/util/grpcclient/client.go @@ -20,9 +20,11 @@ import ( "context" "crypto/tls" "fmt" + "strings" "sync" "time" + grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" "go.uber.org/atomic" "go.uber.org/zap" @@ -39,6 +41,8 @@ import ( "github.com/milvus-io/milvus/pkg/util/crypto" "github.com/milvus-io/milvus/pkg/util/funcutil" "github.com/milvus-io/milvus/pkg/util/generic" + "github.com/milvus-io/milvus/pkg/util/interceptor" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" ) @@ -194,8 +198,14 @@ func (c *ClientBase[T]) connect(ctx context.Context) error { grpc.MaxCallSendMsgSize(c.ClientMaxSendSize), grpc.UseCompressor(compress), ), - grpc.WithUnaryInterceptor(otelgrpc.UnaryClientInterceptor(opts...)), - grpc.WithStreamInterceptor(otelgrpc.StreamClientInterceptor(opts...)), + grpc.WithUnaryInterceptor(grpc_middleware.ChainUnaryClient( + otelgrpc.UnaryClientInterceptor(opts...), + interceptor.ClusterInjectionUnaryClientInterceptor(), + )), + grpc.WithStreamInterceptor(grpc_middleware.ChainStreamClient( + otelgrpc.StreamClientInterceptor(opts...), + interceptor.ClusterInjectionStreamClientInterceptor(), + )), grpc.WithDefaultServiceConfig(retryPolicy), grpc.WithKeepaliveParams(keepalive.ClientParameters{ Time: c.KeepAliveTime, @@ -225,8 +235,14 @@ func (c *ClientBase[T]) connect(ctx context.Context) error { grpc.MaxCallSendMsgSize(c.ClientMaxSendSize), grpc.UseCompressor(compress), ), - grpc.WithUnaryInterceptor(otelgrpc.UnaryClientInterceptor(opts...)), - grpc.WithStreamInterceptor(otelgrpc.StreamClientInterceptor(opts...)), + grpc.WithUnaryInterceptor(grpc_middleware.ChainUnaryClient( + otelgrpc.UnaryClientInterceptor(opts...), + interceptor.ClusterInjectionUnaryClientInterceptor(), + )), + grpc.WithStreamInterceptor(grpc_middleware.ChainStreamClient( + otelgrpc.StreamClientInterceptor(opts...), + 
interceptor.ClusterInjectionStreamClientInterceptor(), + )), grpc.WithDefaultServiceConfig(retryPolicy), grpc.WithKeepaliveParams(keepalive.ClientParameters{ Time: c.KeepAliveTime, @@ -275,6 +291,14 @@ func (c *ClientBase[T]) callOnce(ctx context.Context, caller func(client T) (any go c.bgHealthCheck(client) return generic.Zero[T](), err } + if IsCrossClusterRoutingErr(err) { + log.Warn("CrossClusterRoutingErr, start to reset connection", + zap.String("role", c.GetRole()), + zap.Error(err), + ) + c.resetConnection(client) + return ret, merr.ErrServiceUnavailable // For concealing ErrCrossClusterRouting from the client + } if !funcutil.IsGrpcErr(err) { log.Ctx(ctx).Warn("ClientBase:isNotGrpcErr", zap.Error(err)) return generic.Zero[T](), err @@ -365,3 +389,9 @@ func (c *ClientBase[T]) SetNodeID(nodeID int64) { func (c *ClientBase[T]) GetNodeID() int64 { return c.NodeID.Load() } + +func IsCrossClusterRoutingErr(err error) bool { + // GRPC utilizes `status.Status` to encapsulate errors, + // hence it is not viable to employ the `errors.Is` for assessment. + return strings.Contains(err.Error(), merr.ErrCrossClusterRouting.Error()) +} diff --git a/pkg/util/interceptor/cluster_interceptor.go b/pkg/util/interceptor/cluster_interceptor.go new file mode 100644 index 0000000000..e27f90f531 --- /dev/null +++ b/pkg/util/interceptor/cluster_interceptor.go @@ -0,0 +1,87 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package interceptor + +import ( + "context" + + "google.golang.org/grpc" + "google.golang.org/grpc/metadata" + + "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/paramtable" +) + +const ClusterKey = "Cluster" + +// ClusterValidationUnaryServerInterceptor returns a new unary server interceptor that +// rejects the request if the client's cluster differs from that of the server. +// It is chiefly employed to tackle the `Cross-Cluster Routing` issue. +func ClusterValidationUnaryServerInterceptor() grpc.UnaryServerInterceptor { + return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return handler(ctx, req) + } + clusters := md.Get(ClusterKey) + if len(clusters) == 0 { + return handler(ctx, req) + } + cluster := clusters[0] + if cluster != "" && cluster != paramtable.Get().CommonCfg.ClusterPrefix.GetValue() { + return nil, merr.WrapErrCrossClusterRouting(paramtable.Get().CommonCfg.ClusterPrefix.GetValue(), cluster) + } + return handler(ctx, req) + } +} + +// ClusterValidationStreamServerInterceptor returns a new streaming server interceptor that +// rejects the request if the client's cluster differs from that of the server. +// It is chiefly employed to tackle the `Cross-Cluster Routing` issue. 
+func ClusterValidationStreamServerInterceptor() grpc.StreamServerInterceptor { + return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + md, ok := metadata.FromIncomingContext(ss.Context()) + if !ok { + return handler(srv, ss) + } + clusters := md.Get(ClusterKey) + if len(clusters) == 0 { + return handler(srv, ss) + } + cluster := clusters[0] + if cluster != "" && cluster != paramtable.Get().CommonCfg.ClusterPrefix.GetValue() { + return merr.WrapErrCrossClusterRouting(paramtable.Get().CommonCfg.ClusterPrefix.GetValue(), cluster) + } + return handler(srv, ss) + } +} + +// ClusterInjectionUnaryClientInterceptor returns a new unary client interceptor that injects `cluster` into outgoing context. +func ClusterInjectionUnaryClientInterceptor() grpc.UnaryClientInterceptor { + return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { + ctx = metadata.AppendToOutgoingContext(ctx, ClusterKey, paramtable.Get().CommonCfg.ClusterPrefix.GetValue()) + return invoker(ctx, method, req, reply, cc, opts...) + } +} + +// ClusterInjectionStreamClientInterceptor returns a new streaming client interceptor that injects `cluster` into outgoing context. +func ClusterInjectionStreamClientInterceptor() grpc.StreamClientInterceptor { + return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { + ctx = metadata.AppendToOutgoingContext(ctx, ClusterKey, paramtable.Get().CommonCfg.ClusterPrefix.GetValue()) + return streamer(ctx, desc, cc, method, opts...) 
+ } +} diff --git a/pkg/util/interceptor/cluster_interceptor_test.go b/pkg/util/interceptor/cluster_interceptor_test.go new file mode 100644 index 0000000000..3bd6f71c52 --- /dev/null +++ b/pkg/util/interceptor/cluster_interceptor_test.go @@ -0,0 +1,148 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package interceptor + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "google.golang.org/grpc" + "google.golang.org/grpc/metadata" + + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/paramtable" +) + +type mockSS struct { + grpc.ServerStream + ctx context.Context +} + +func newMockSS(ctx context.Context) grpc.ServerStream { + return &mockSS{ + ctx: ctx, + } +} + +func (m *mockSS) Context() context.Context { + return m.ctx +} + +func init() { + paramtable.Get().Init() +} + +func TestClusterInterceptor(t *testing.T) { + t.Run("test ClusterInjectionUnaryClientInterceptor", func(t *testing.T) { + method := "MockMethod" + req := &milvuspb.InsertRequest{} + + var incomingContext context.Context + invoker := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error { + incomingContext = ctx + return nil + } + interceptor := ClusterInjectionUnaryClientInterceptor() + ctx := metadata.NewOutgoingContext(context.Background(), metadata.New(make(map[string]string))) + err := interceptor(ctx, method, req, nil, nil, invoker) + assert.NoError(t, err) + + md, ok := metadata.FromOutgoingContext(incomingContext) + assert.True(t, ok) + assert.Equal(t, paramtable.Get().CommonCfg.ClusterPrefix.GetValue(), md.Get(ClusterKey)[0]) + }) + + t.Run("test ClusterInjectionStreamClientInterceptor", func(t *testing.T) { + method := "MockMethod" + + var incomingContext context.Context + streamer := func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + incomingContext = ctx + return nil, nil + } + interceptor := ClusterInjectionStreamClientInterceptor() + ctx := metadata.NewOutgoingContext(context.Background(), metadata.New(make(map[string]string))) + _, err := interceptor(ctx, nil, nil, method, streamer) + assert.NoError(t, err) + 
+ md, ok := metadata.FromOutgoingContext(incomingContext) + assert.True(t, ok) + assert.Equal(t, paramtable.Get().CommonCfg.ClusterPrefix.GetValue(), md.Get(ClusterKey)[0]) + }) + + t.Run("test ClusterValidationUnaryServerInterceptor", func(t *testing.T) { + method := "MockMethod" + req := &milvuspb.InsertRequest{} + + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return nil, nil + } + serverInfo := &grpc.UnaryServerInfo{FullMethod: method} + interceptor := ClusterValidationUnaryServerInterceptor() + + // no md in context + _, err := interceptor(context.Background(), req, serverInfo, handler) + assert.NoError(t, err) + + // no cluster in md + ctx := metadata.NewIncomingContext(context.Background(), metadata.New(make(map[string]string))) + _, err = interceptor(ctx, req, serverInfo, handler) + assert.NoError(t, err) + + // with cross-cluster + md := metadata.Pairs(ClusterKey, "ins-1") + ctx = metadata.NewIncomingContext(context.Background(), md) + _, err = interceptor(ctx, req, serverInfo, handler) + assert.ErrorIs(t, err, merr.ErrCrossClusterRouting) + + // with same cluster + md = metadata.Pairs(ClusterKey, paramtable.Get().CommonCfg.ClusterPrefix.GetValue()) + ctx = metadata.NewIncomingContext(context.Background(), md) + _, err = interceptor(ctx, req, serverInfo, handler) + assert.NoError(t, err) + }) + + t.Run("test ClusterValidationStreamServerInterceptor", func(t *testing.T) { + handler := func(srv interface{}, stream grpc.ServerStream) error { + return nil + } + interceptor := ClusterValidationStreamServerInterceptor() + + // no md in context + err := interceptor(nil, newMockSS(context.Background()), nil, handler) + assert.NoError(t, err) + + // no cluster in md + ctx := metadata.NewIncomingContext(context.Background(), metadata.New(make(map[string]string))) + err = interceptor(nil, newMockSS(ctx), nil, handler) + assert.NoError(t, err) + + // with cross-cluster + md := metadata.Pairs(ClusterKey, "ins-1") + ctx = 
metadata.NewIncomingContext(context.Background(), md) + err = interceptor(nil, newMockSS(ctx), nil, handler) + assert.ErrorIs(t, err, merr.ErrCrossClusterRouting) + + // with same cluster + md = metadata.Pairs(ClusterKey, paramtable.Get().CommonCfg.ClusterPrefix.GetValue()) + ctx = metadata.NewIncomingContext(context.Background(), md) + err = interceptor(nil, newMockSS(ctx), nil, handler) + assert.NoError(t, err) + }) +} diff --git a/pkg/util/merr/errors.go b/pkg/util/merr/errors.go index 1250958b83..0b6fea62a7 100644 --- a/pkg/util/merr/errors.go +++ b/pkg/util/merr/errors.go @@ -53,6 +53,7 @@ var ( ErrServiceMemoryLimitExceeded = newMilvusError("memory limit exceeded", 3, false) ErrServiceRequestLimitExceeded = newMilvusError("request limit exceeded", 4, true) ErrServiceInternal = newMilvusError("service internal error", 5, false) // Never return this error out of Milvus + ErrCrossClusterRouting = newMilvusError("cross cluster routing", 6, false) // Collection related ErrCollectionNotFound = newMilvusError("collection not found", 100, false) diff --git a/pkg/util/merr/errors_test.go b/pkg/util/merr/errors_test.go index 7b8027284b..3367a0278a 100644 --- a/pkg/util/merr/errors_test.go +++ b/pkg/util/merr/errors_test.go @@ -65,6 +65,7 @@ func (s *ErrSuite) TestWrap() { s.ErrorIs(WrapErrServiceMemoryLimitExceeded(110, 100, "MLE"), ErrServiceMemoryLimitExceeded) s.ErrorIs(WrapErrServiceRequestLimitExceeded(100, "too many requests"), ErrServiceRequestLimitExceeded) s.ErrorIs(WrapErrServiceInternal("never throw out"), ErrServiceInternal) + s.ErrorIs(WrapErrCrossClusterRouting("ins-0", "ins-1"), ErrCrossClusterRouting) // Collection related s.ErrorIs(WrapErrCollectionNotFound("test_collection", "failed to get collection"), ErrCollectionNotFound) diff --git a/pkg/util/merr/utils.go b/pkg/util/merr/utils.go index b4d33b1b00..9adf6a3f43 100644 --- a/pkg/util/merr/utils.go +++ b/pkg/util/merr/utils.go @@ -160,6 +160,14 @@ func WrapErrServiceInternal(msg string, others 
...string) error { return err } +func WrapErrCrossClusterRouting(expectedCluster, actualCluster string, msg ...string) error { + err := errors.Wrapf(ErrCrossClusterRouting, "expectedCluster=%s, actualCluster=%s", expectedCluster, actualCluster) + if len(msg) > 0 { + err = errors.Wrap(err, strings.Join(msg, "; ")) + } + return err +} + // Collection related func WrapErrCollectionNotFound(collection any, msg ...string) error { err := wrapWithField(ErrCollectionNotFound, "collection", collection) diff --git a/tests/integration/crossclusterrouting/cross_cluster_routing_test.go b/tests/integration/crossclusterrouting/cross_cluster_routing_test.go new file mode 100644 index 0000000000..9f3a796328 --- /dev/null +++ b/tests/integration/crossclusterrouting/cross_cluster_routing_test.go @@ -0,0 +1,283 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package crossclusterrouting + +import ( + "context" + "fmt" + "math/rand" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/suite" + clientv3 "go.etcd.io/etcd/client/v3" + + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus/internal/proto/datapb" + "github.com/milvus-io/milvus/internal/proto/indexpb" + "github.com/milvus-io/milvus/internal/proto/proxypb" + "github.com/milvus-io/milvus/internal/proto/querypb" + "github.com/milvus-io/milvus/internal/util/dependency" + "github.com/milvus-io/milvus/pkg/util/etcd" + "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/paramtable" + + grpcdatacoord "github.com/milvus-io/milvus/internal/distributed/datacoord" + grpcdatacoordclient "github.com/milvus-io/milvus/internal/distributed/datacoord/client" + grpcdatanode "github.com/milvus-io/milvus/internal/distributed/datanode" + grpcdatanodeclient "github.com/milvus-io/milvus/internal/distributed/datanode/client" + grpcindexnode "github.com/milvus-io/milvus/internal/distributed/indexnode" + grpcindexnodeclient "github.com/milvus-io/milvus/internal/distributed/indexnode/client" + grpcproxy "github.com/milvus-io/milvus/internal/distributed/proxy" + grpcproxyclient "github.com/milvus-io/milvus/internal/distributed/proxy/client" + grpcquerycoord "github.com/milvus-io/milvus/internal/distributed/querycoord" + grpcquerycoordclient "github.com/milvus-io/milvus/internal/distributed/querycoord/client" + grpcquerynode "github.com/milvus-io/milvus/internal/distributed/querynode" + grpcquerynodeclient "github.com/milvus-io/milvus/internal/distributed/querynode/client" + grpcrootcoord "github.com/milvus-io/milvus/internal/distributed/rootcoord" + grpcrootcoordclient "github.com/milvus-io/milvus/internal/distributed/rootcoord/client" +) + +type CrossClusterRoutingSuite struct { + suite.Suite + + ctx context.Context + cancel context.CancelFunc + + factory dependency.Factory + client *clientv3.Client + + // 
clients + rootCoordClient *grpcrootcoordclient.Client + proxyClient *grpcproxyclient.Client + dataCoordClient *grpcdatacoordclient.Client + queryCoordClient *grpcquerycoordclient.Client + dataNodeClient *grpcdatanodeclient.Client + queryNodeClient *grpcquerynodeclient.Client + indexNodeClient *grpcindexnodeclient.Client + + // servers + rootCoord *grpcrootcoord.Server + proxy *grpcproxy.Server + dataCoord *grpcdatacoord.Server + queryCoord *grpcquerycoord.Server + dataNode *grpcdatanode.Server + queryNode *grpcquerynode.Server + indexNode *grpcindexnode.Server +} + +func (s *CrossClusterRoutingSuite) SetupSuite() { + s.ctx, s.cancel = context.WithTimeout(context.Background(), time.Second*180) + rand.Seed(time.Now().UnixNano()) + + paramtable.Get().Init() + s.factory = dependency.NewDefaultFactory(true) +} + +func (s *CrossClusterRoutingSuite) TearDownSuite() { +} + +func (s *CrossClusterRoutingSuite) SetupTest() { + s.T().Logf("Setup test...") + var err error + + // setup etcd client + etcdConfig := &paramtable.Get().EtcdCfg + s.client, err = etcd.GetEtcdClient( + etcdConfig.UseEmbedEtcd.GetAsBool(), + etcdConfig.EtcdUseSSL.GetAsBool(), + etcdConfig.Endpoints.GetAsStrings(), + etcdConfig.EtcdTLSCert.GetValue(), + etcdConfig.EtcdTLSKey.GetValue(), + etcdConfig.EtcdTLSCACert.GetValue(), + etcdConfig.EtcdTLSMinVersion.GetValue()) + s.NoError(err) + metaRoot := paramtable.Get().EtcdCfg.MetaRootPath.GetValue() + + // setup clients + s.rootCoordClient, err = grpcrootcoordclient.NewClient(s.ctx, metaRoot, s.client) + s.NoError(err) + s.dataCoordClient, err = grpcdatacoordclient.NewClient(s.ctx, metaRoot, s.client) + s.NoError(err) + s.queryCoordClient, err = grpcquerycoordclient.NewClient(s.ctx, metaRoot, s.client) + s.NoError(err) + s.proxyClient, err = grpcproxyclient.NewClient(s.ctx, paramtable.Get().ProxyGrpcClientCfg.GetInternalAddress()) + s.NoError(err) + s.dataNodeClient, err = grpcdatanodeclient.NewClient(s.ctx, paramtable.Get().DataNodeGrpcClientCfg.GetAddress()) +
s.NoError(err) + s.queryNodeClient, err = grpcquerynodeclient.NewClient(s.ctx, paramtable.Get().QueryNodeGrpcClientCfg.GetAddress()) + s.NoError(err) + s.indexNodeClient, err = grpcindexnodeclient.NewClient(s.ctx, paramtable.Get().IndexNodeGrpcClientCfg.GetAddress(), false) + s.NoError(err) + + // setup servers + s.rootCoord, err = grpcrootcoord.NewServer(s.ctx, s.factory) + s.NoError(err) + err = s.rootCoord.Run() + s.NoError(err) + s.T().Logf("rootCoord server successfully started") + + s.dataCoord = grpcdatacoord.NewServer(s.ctx, s.factory) + s.NotNil(s.dataCoord) + err = s.dataCoord.Run() + s.NoError(err) + s.T().Logf("dataCoord server successfully started") + + s.queryCoord, err = grpcquerycoord.NewServer(s.ctx, s.factory) + s.NoError(err) + err = s.queryCoord.Run() + s.NoError(err) + s.T().Logf("queryCoord server successfully started") + + s.proxy, err = grpcproxy.NewServer(s.ctx, s.factory) + s.NoError(err) + err = s.proxy.Run() + s.NoError(err) + s.T().Logf("proxy server successfully started") + + s.dataNode, err = grpcdatanode.NewServer(s.ctx, s.factory) + s.NoError(err) + err = s.dataNode.Run() + s.NoError(err) + s.T().Logf("dataNode server successfully started") + + s.queryNode, err = grpcquerynode.NewServer(s.ctx, s.factory) + s.NoError(err) + err = s.queryNode.Run() + s.NoError(err) + s.T().Logf("queryNode server successfully started") + + s.indexNode, err = grpcindexnode.NewServer(s.ctx, s.factory) + s.NoError(err) + err = s.indexNode.Run() + s.NoError(err) + s.T().Logf("indexNode server successfully started") +} + +func (s *CrossClusterRoutingSuite) TearDownTest() { + err := s.rootCoord.Stop() + s.NoError(err) + err = s.proxy.Stop() + s.NoError(err) + err = s.dataCoord.Stop() + s.NoError(err) + err = s.queryCoord.Stop() + s.NoError(err) + err = s.dataNode.Stop() + s.NoError(err) + err = s.queryNode.Stop() + s.NoError(err) + err = s.indexNode.Stop() + s.NoError(err) + s.cancel() +} + +func (s *CrossClusterRoutingSuite) TestCrossClusterRoutingSuite() { 
+ const ( + waitFor = time.Second * 10 + duration = time.Millisecond * 10 + ) + + go func() { + for { + select { + case <-s.ctx.Done(): + return + default: + err := paramtable.Get().Save(paramtable.Get().CommonCfg.ClusterPrefix.Key, fmt.Sprintf("%d", rand.Int())) + if err != nil { + panic(err) + } + } + } + }() + + // test rootCoord + s.Eventually(func() bool { + resp, err := s.rootCoordClient.ShowCollections(s.ctx, &milvuspb.ShowCollectionsRequest{}) + s.Suite.T().Logf("resp: %s, err: %s", resp, err) + if err != nil { + return strings.Contains(err.Error(), merr.ErrServiceUnavailable.Error()) + } + return false + }, waitFor, duration) + + // test dataCoord + s.Eventually(func() bool { + resp, err := s.dataCoordClient.GetRecoveryInfoV2(s.ctx, &datapb.GetRecoveryInfoRequestV2{}) + s.Suite.T().Logf("resp: %s, err: %s", resp, err) + if err != nil { + return strings.Contains(err.Error(), merr.ErrServiceUnavailable.Error()) + } + return false + }, waitFor, duration) + + // test queryCoord + s.Eventually(func() bool { + resp, err := s.queryCoordClient.LoadCollection(s.ctx, &querypb.LoadCollectionRequest{}) + s.Suite.T().Logf("resp: %s, err: %s", resp, err) + if err != nil { + return strings.Contains(err.Error(), merr.ErrServiceUnavailable.Error()) + } + return false + }, waitFor, duration) + + // test proxy + s.Eventually(func() bool { + resp, err := s.proxyClient.InvalidateCollectionMetaCache(s.ctx, &proxypb.InvalidateCollMetaCacheRequest{}) + s.Suite.T().Logf("resp: %s, err: %s", resp, err) + if err != nil { + return strings.Contains(err.Error(), merr.ErrServiceUnavailable.Error()) + } + return false + }, waitFor, duration) + + // test dataNode + s.Eventually(func() bool { + resp, err := s.dataNodeClient.FlushSegments(s.ctx, &datapb.FlushSegmentsRequest{}) + s.Suite.T().Logf("resp: %s, err: %s", resp, err) + if err != nil { + return strings.Contains(err.Error(), merr.ErrServiceUnavailable.Error()) + } + return false + }, waitFor, duration) + + // test queryNode + 
s.Eventually(func() bool { + resp, err := s.queryNodeClient.Search(s.ctx, &querypb.SearchRequest{}) + s.Suite.T().Logf("resp: %s, err: %s", resp, err) + if err != nil { + return strings.Contains(err.Error(), merr.ErrServiceUnavailable.Error()) + } + return false + }, waitFor, duration) + + // test indexNode + s.Eventually(func() bool { + resp, err := s.indexNodeClient.CreateJob(s.ctx, &indexpb.CreateJobRequest{}) + s.Suite.T().Logf("resp: %s, err: %s", resp, err) + if err != nil { + return strings.Contains(err.Error(), merr.ErrServiceUnavailable.Error()) + } + return false + }, waitFor, duration) +} + +func TestCrossClusterRoutingSuite(t *testing.T) { + suite.Run(t, new(CrossClusterRoutingSuite)) +}