From c5f455dc6b69ae902acf85bb3678cbdcd335f3a9 Mon Sep 17 00:00:00 2001 From: jaime Date: Mon, 27 Nov 2023 16:28:34 +0800 Subject: [PATCH] fix: cmux graceful shutdown on proxy service (#28383) issue #28305 Signed-off-by: jaime --- internal/distributed/proxy/service.go | 52 +++++++++++---- internal/distributed/proxy/service_test.go | 73 ++++++++++++++++++++++ 2 files changed, 112 insertions(+), 13 deletions(-) diff --git a/internal/distributed/proxy/service.go b/internal/distributed/proxy/service.go index bea362fe41..05bbb277ec 100644 --- a/internal/distributed/proxy/service.go +++ b/internal/distributed/proxy/service.go @@ -30,6 +30,7 @@ import ( "sync" "time" + "github.com/cockroachdb/errors" "github.com/gin-gonic/gin" grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" grpc_auth "github.com/grpc-ecosystem/go-grpc-middleware/auth" @@ -80,6 +81,10 @@ var ( errInvalidToken = status.Errorf(codes.Unauthenticated, "invalid token") // registerHTTPHandlerOnce avoid register http handler multiple times registerHTTPHandlerOnce sync.Once + // only for test + enableCustomInterceptor = true + // only for test, register internal interface to external service + enableRegisterProxyServer = false ) const apiPathPrefix = "/api/v1" @@ -232,12 +237,10 @@ func (s *Server) startExternalGrpc(grpcPort int, errChan chan error) { log.Debug("Get proxy rate limiter done", zap.Int("port", grpcPort)) opts := tracer.GetInterceptorOpts() - grpcOpts := []grpc.ServerOption{ - grpc.KeepaliveEnforcementPolicy(kaep), - grpc.KeepaliveParams(kasp), - grpc.MaxRecvMsgSize(Params.ServerMaxRecvSize.GetAsInt()), - grpc.MaxSendMsgSize(Params.ServerMaxSendSize.GetAsInt()), - grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer( + + var unaryServerOption grpc.ServerOption + if enableCustomInterceptor { + unaryServerOption = grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer( otelgrpc.UnaryServerInterceptor(opts...), grpc_auth.UnaryServerInterceptor(proxy.AuthenticationInterceptor), 
proxy.DatabaseInterceptor(), @@ -248,7 +251,17 @@ func (s *Server) startExternalGrpc(grpcPort int, errChan chan error) { accesslog.UnaryAccessLoggerInterceptor, proxy.TraceLogInterceptor, proxy.KeepActiveInterceptor, - )), + )) + } else { + unaryServerOption = grpc.EmptyServerOption{} + } + + grpcOpts := []grpc.ServerOption{ + grpc.KeepaliveEnforcementPolicy(kaep), + grpc.KeepaliveParams(kasp), + grpc.MaxRecvMsgSize(Params.ServerMaxRecvSize.GetAsInt()), + grpc.MaxSendMsgSize(Params.ServerMaxSendSize.GetAsInt()), + unaryServerOption, } if Params.TLSMode.GetAsInt() == 1 { @@ -290,6 +303,11 @@ func (s *Server) startExternalGrpc(grpcPort int, errChan chan error) { grpcOpts = append(grpcOpts, grpc.Creds(credentials.NewTLS(tlsConf))) } s.grpcExternalServer = grpc.NewServer(grpcOpts...) + + if enableRegisterProxyServer { + proxypb.RegisterProxyServer(s.grpcExternalServer, s) + } + milvuspb.RegisterMilvusServiceServer(s.grpcExternalServer, s) grpc_health_v1.RegisterHealthServer(s.grpcExternalServer, s) errChan <- nil @@ -390,7 +408,7 @@ func (s *Server) Run() error { s.wg.Add(1) go func() { defer s.wg.Done() - if err := s.tcpServer.Serve(); err != nil && err != cmux.ErrServerClosed { + if err := s.tcpServer.Serve(); err != nil && !errors.Is(err, net.ErrClosed) { log.Warn("Proxy server for tcp port failed", zap.Error(err)) return } @@ -651,11 +669,8 @@ func (s *Server) Stop() error { go func() { defer gracefulWg.Done() - if s.tcpServer != nil { - log.Info("Proxy stop tcp server...") - s.tcpServer.Close() - } - + // try to close grpc server firstly, it has the same root listener with cmux server and + // http listener that tls has not been enabled. if s.grpcExternalServer != nil { log.Info("Proxy stop external grpc server") utils.GracefulStopGRPCServer(s.grpcExternalServer) @@ -666,6 +681,17 @@ func (s *Server) Stop() error { s.httpServer.Close() } + // close cmux server, it isn't a synchronized operation. + // Note that: + // 1. 
all listeners can be closed after closing cmux server that has the root listener, it will automatically + // propagate the closure to all the listeners derived from it, but it doesn't provide a graceful shutdown + // grpc server ideally. + // 2. avoid resource leak also need to close cmux after grpc and http listener closed. + if s.tcpServer != nil { + log.Info("Proxy stop tcp server...") + s.tcpServer.Close() + } + if s.grpcInternalServer != nil { log.Info("Proxy stop internal grpc server") utils.GracefulStopGRPCServer(s.grpcInternalServer) diff --git a/internal/distributed/proxy/service_test.go b/internal/distributed/proxy/service_test.go index 34d506d348..993340fa58 100644 --- a/internal/distributed/proxy/service_test.go +++ b/internal/distributed/proxy/service_test.go @@ -25,6 +25,7 @@ import ( "net/http/httptest" "os" "strconv" + "sync/atomic" "testing" "time" @@ -34,6 +35,7 @@ import ( "github.com/stretchr/testify/mock" clientv3 "go.etcd.io/etcd/client/v3" "go.uber.org/zap" + "golang.org/x/sync/errgroup" "google.golang.org/grpc" "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" @@ -42,6 +44,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/federpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + grpcproxyclient "github.com/milvus-io/milvus/internal/distributed/proxy/client" "github.com/milvus-io/milvus/internal/distributed/proxy/httpserver" "github.com/milvus-io/milvus/internal/mocks" "github.com/milvus-io/milvus/internal/proto/internalpb" @@ -50,6 +53,7 @@ import ( "github.com/milvus-io/milvus/internal/types" milvusmock "github.com/milvus-io/milvus/internal/util/mock" "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/util/funcutil" "github.com/milvus-io/milvus/pkg/util/metricsinfo" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/uniquegenerator" @@ -1380,3 +1384,72 @@ func 
TestHttpAuthenticate(t *testing.T) {
 		assert.Equal(t, "foo", ctxName)
 	}
 }
+
+// Test_Service_GracefulStop verifies that Stop lets in-flight RPCs finish: three
+// GetComponentStates calls, each held for 10s by the mocked proxy, are launched
+// shortly before Stop is invoked and must all return without error.
+func Test_Service_GracefulStop(t *testing.T) {
+	mockedProxy := mocks.NewMockProxy(t)
+	var count int32 // number of handler invocations that ran to completion
+	mockedProxy.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Run(func(_a0 context.Context, _a1 *milvuspb.GetComponentStatesRequest) {
+		fmt.Println("rpc start")
+		time.Sleep(10 * time.Second)
+		atomic.AddInt32(&count, 1)
+		fmt.Println("rpc done")
+	}).Return(&milvuspb.ComponentStates{Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}}, nil)
+
+	mockedProxy.EXPECT().Init().Return(nil)
+	mockedProxy.EXPECT().Start().Return(nil)
+	mockedProxy.EXPECT().Stop().Return(nil)
+	mockedProxy.EXPECT().Register().Return(nil)
+	mockedProxy.EXPECT().SetEtcdClient(mock.Anything).Return()
+	mockedProxy.EXPECT().GetRateLimiter().Return(nil, nil)
+	mockedProxy.EXPECT().SetDataCoordClient(mock.Anything).Return()
+	mockedProxy.EXPECT().SetRootCoordClient(mock.Anything).Return()
+	mockedProxy.EXPECT().SetQueryCoordClient(mock.Anything).Return()
+	mockedProxy.EXPECT().UpdateStateCode(mock.Anything).Return()
+	mockedProxy.EXPECT().SetAddress(mock.Anything).Return()
+
+	Params := &paramtable.Get().ProxyGrpcServerCfg
+	paramtable.Get().Save(Params.TLSMode.Key, "0")
+	paramtable.Get().Save(Params.Port.Key, fmt.Sprintf("%d", funcutil.GetAvailablePort()))
+	paramtable.Get().Save(Params.InternalPort.Key, fmt.Sprintf("%d", funcutil.GetAvailablePort()))
+	paramtable.Get().Save(Params.ServerPemPath.Key, "../../../configs/cert/server.pem")
+	paramtable.Get().Save(Params.ServerKeyPath.Key, "../../../configs/cert/server.key")
+	paramtable.Get().Save(proxy.Params.HTTPCfg.Enabled.Key, "true")
+	paramtable.Get().Save(proxy.Params.HTTPCfg.Port.Key, "")
+
+	// Test-only switches: skip the custom unary interceptors and register the proxy service on the external server; restored on exit.
+	enableCustomInterceptor = false
+	enableRegisterProxyServer = true
+	defer func() {
+		enableCustomInterceptor = true
+		enableRegisterProxyServer = false
+	}()
+
+	server := getServer(t)
+	assert.NotNil(t, server)
+	server.proxy = mockedProxy
+
+	assert.Nil(t, server.Run())
+	ctx := context.Background()
+	proxyClient, err := grpcproxyclient.NewClient(ctx, fmt.Sprintf("localhost:%s", Params.Port.GetValue()), 0)
+	if !assert.Nil(t, err) {
+		return // without a usable client the goroutines below cannot issue requests
+	}
+
+	group := &errgroup.Group{}
+	for i := 0; i < 3; i++ {
+		group.Go(func() error {
+			_, err := proxyClient.GetComponentStates(ctx, &milvuspb.GetComponentStatesRequest{})
+			return err
+		})
+	}
+
+	// waiting for all requests have been launched
+	time.Sleep(1 * time.Second)
+	assert.Nil(t, server.Stop())
+	assert.Nil(t, group.Wait())
+	// count is written with atomic.AddInt32 from handler goroutines, so read it atomically; expected value goes first.
+	assert.Equal(t, int32(3), atomic.LoadInt32(&count))
+}