diff --git a/configs/milvus.yaml b/configs/milvus.yaml index 8b518ee52b..ee1ede7535 100644 --- a/configs/milvus.yaml +++ b/configs/milvus.yaml @@ -37,6 +37,15 @@ etcd: # Embedded Etcd only. # please adjust in embedded Milvus: /tmp/milvus/etcdData/ dir: default.etcd + ssl: + enabled: false # Whether to support ETCD secure connection mode + tlsCert: /path/to/etcd-client.pem # path to your cert file + tlsKey: /path/to/etcd-client-key.pem # path to your key file + tlsCACert: /path/to/ca.pem # path to your CACert file + # TLS min version + # Optional values: 1.0, 1.1, 1.2, 1.3。 + # We recommend using version 1.2 and above + tlsMinVersion: 1.3 # please adjust in embedded Milvus: /tmp/milvus/data/ localStorage: diff --git a/internal/util/etcd/etcd_util.go b/internal/util/etcd/etcd_util.go index cda73c1b3e..c6d029a266 100644 --- a/internal/util/etcd/etcd_util.go +++ b/internal/util/etcd/etcd_util.go @@ -17,8 +17,13 @@ package etcd import ( + "crypto/tls" + "crypto/x509" + "io/ioutil" "time" + "github.com/pkg/errors" + "github.com/milvus-io/milvus/internal/util/paramtable" clientv3 "go.etcd.io/etcd/client/v3" ) @@ -28,6 +33,9 @@ func GetEtcdClient(cfg *paramtable.EtcdConfig) (*clientv3.Client, error) { if cfg.UseEmbedEtcd { return GetEmbedEtcdClient() } + if cfg.EtcdUseSSL { + return GetRemoteEtcdSSLClient(cfg.Endpoints, cfg.EtcdTLSCert, cfg.EtcdTLSKey, cfg.EtcdTLSCACert, cfg.EtcdTLSMinVersion) + } return GetRemoteEtcdClient(cfg.Endpoints) } @@ -38,3 +46,45 @@ func GetRemoteEtcdClient(endpoints []string) (*clientv3.Client, error) { DialTimeout: 5 * time.Second, }) } + +func GetRemoteEtcdSSLClient(endpoints []string, certFile string, keyFile string, caCertFile string, minVersion string) (*clientv3.Client, error) { + var cfg clientv3.Config + cfg.Endpoints = endpoints + cfg.DialTimeout = 5 * time.Second + cert, err := tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + return nil, errors.Wrap(err, "load etcd cert key pair error") + } + caCert, err := ioutil.ReadFile(caCertFile) + if err != nil { + return nil, errors.Wrapf(err, "load etcd CACert file error, filename = %s", caCertFile) + } + + caCertPool := x509.NewCertPool() + caCertPool.AppendCertsFromPEM(caCert) + cfg.TLS = &tls.Config{ + MinVersion: tls.VersionTLS13, + Certificates: []tls.Certificate{ + cert, + }, + RootCAs: caCertPool, + } + switch minVersion { + case "1.0": + cfg.TLS.MinVersion = tls.VersionTLS10 + case "1.1": + cfg.TLS.MinVersion = tls.VersionTLS11 + case "1.2": + cfg.TLS.MinVersion = tls.VersionTLS12 + case "1.3": + cfg.TLS.MinVersion = tls.VersionTLS13 + default: + cfg.TLS.MinVersion = 0 + } + + if cfg.TLS.MinVersion == 0 { + return nil, errors.Errorf("unknown TLS version,%s", minVersion) + } + + return clientv3.New(cfg) +} diff --git a/internal/util/etcd/etcd_util_test.go b/internal/util/etcd/etcd_util_test.go index e350d771ab..6b72c144dd 100644 --- a/internal/util/etcd/etcd_util_test.go +++ b/internal/util/etcd/etcd_util_test.go @@ -48,4 +48,27 @@ func TestEtcd(t *testing.T) { assert.NoError(t, err) assert.False(t, resp.Count < 1) assert.Equal(t, string(resp.Kvs[0].Value), "value") + + Params.EtcdCfg.UseEmbedEtcd = false + Params.EtcdCfg.EtcdUseSSL = true + Params.EtcdCfg.EtcdTLSMinVersion = "1.3" + Params.EtcdCfg.EtcdTLSCACert = "../../../configs/cert/ca.pem" + Params.EtcdCfg.EtcdTLSCert = "../../../configs/cert/client.pem" + Params.EtcdCfg.EtcdTLSKey = "../../../configs/cert/client.key" + etcdCli, err = GetEtcdClient(&Params.EtcdCfg) + assert.NoError(t, err) + + Params.EtcdCfg.EtcdTLSMinVersion = "some not right word" + etcdCli, err = GetEtcdClient(&Params.EtcdCfg) + assert.NotNil(t, err) + + Params.EtcdCfg.EtcdTLSMinVersion = "1.2" + Params.EtcdCfg.EtcdTLSCACert = "wrong/file" + etcdCli, err = GetEtcdClient(&Params.EtcdCfg) + assert.NotNil(t, err) + + Params.EtcdCfg.EtcdTLSCACert = "../../../configs/cert/ca.pem" + Params.EtcdCfg.EtcdTLSCert = "wrong/file" + assert.NotNil(t, err) + } diff --git a/internal/util/paramtable/service_param.go b/internal/util/paramtable/service_param.go index de63b06b00..804c798361 100644 --- a/internal/util/paramtable/service_param.go +++ b/internal/util/paramtable/service_param.go @@ -61,11 +61,16 @@ type EtcdConfig struct { Base *BaseTable // --- ETCD --- - Endpoints []string - MetaRootPath string - KvRootPath string - EtcdLogLevel string - EtcdLogPath string + Endpoints []string + MetaRootPath string + KvRootPath string + EtcdLogLevel string + EtcdLogPath string + EtcdUseSSL bool + EtcdTLSCert string + EtcdTLSKey string + EtcdTLSCACert string + EtcdTLSMinVersion string // --- Embed ETCD --- UseEmbedEtcd bool @@ -90,6 +95,11 @@ func (p *EtcdConfig) LoadCfgToMemory() { p.initKvRootPath() p.initEtcdLogLevel() p.initEtcdLogPath() + p.initEtcdUseSSL() + p.initEtcdTLSCert() + p.initEtcdTLSKey() + p.initEtcdTLSCACert() + p.initEtcdTLSMinVersion() } func (p *EtcdConfig) initUseEmbedEtcd() { @@ -149,6 +159,26 @@ func (p *EtcdConfig) initEtcdLogPath() { p.EtcdLogPath = p.Base.LoadWithDefault("etcd.log.path", defaultEtcdLogPath) } +func (p *EtcdConfig) initEtcdUseSSL() { + p.EtcdUseSSL = p.Base.ParseBool("etcd.ssl.enabled", false) +} + +func (p *EtcdConfig) initEtcdTLSCert() { + p.EtcdTLSCert = p.Base.LoadWithDefault("etcd.ssl.tlsCert", "") +} + +func (p *EtcdConfig) initEtcdTLSKey() { + p.EtcdTLSKey = p.Base.LoadWithDefault("etcd.ssl.tlsKey", "") +} + +func (p *EtcdConfig) initEtcdTLSCACert() { + p.EtcdTLSCACert = p.Base.LoadWithDefault("etcd.ssl.tlsCACert", "") +} + +func (p *EtcdConfig) initEtcdTLSMinVersion() { + p.EtcdTLSMinVersion = p.Base.LoadWithDefault("etcd.ssl.tlsMinVersion", "1.3") +} + type LocalStorageConfig struct { Base *BaseTable diff --git a/internal/util/paramtable/service_param_test.go b/internal/util/paramtable/service_param_test.go index a4162c4894..3b5c16a8ce 100644 --- a/internal/util/paramtable/service_param_test.go +++ b/internal/util/paramtable/service_param_test.go @@ -35,6 +35,21 @@ func TestServiceParam(t *testing.T) { assert.NotEqual(t, Params.KvRootPath, "") t.Logf("kv root path = %s", Params.KvRootPath) + assert.NotNil(t, Params.EtcdUseSSL) + t.Logf("use ssl = %t", Params.EtcdUseSSL) + + assert.NotEmpty(t, Params.EtcdTLSKey) + t.Logf("tls key = %s", Params.EtcdTLSKey) + + assert.NotEmpty(t, Params.EtcdTLSCACert) + t.Logf("tls CACert = %s", Params.EtcdTLSCACert) + + assert.NotEmpty(t, Params.EtcdTLSCert) + t.Logf("tls cert = %s", Params.EtcdTLSCert) + + assert.NotEmpty(t, Params.EtcdTLSMinVersion) + t.Logf("tls minVersion = %s", Params.EtcdTLSMinVersion) + // test UseEmbedEtcd Params.Base.Save("etcd.use.embed", "true") assert.Nil(t, os.Setenv(metricsinfo.DeployModeEnvKey, metricsinfo.ClusterDeployMode))