github.com/unigraph-dev/dgraph@v1.1.1-0.20200923154953-8b52b426f765/x/tls_helper.go (about) 1 /* 2 * Copyright 2017-2018 Dgraph Labs, Inc. and Contributors 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package x 18 19 import ( 20 "crypto/tls" 21 "crypto/x509" 22 "io/ioutil" 23 "path" 24 "strings" 25 26 "github.com/pkg/errors" 27 "github.com/spf13/pflag" 28 "github.com/spf13/viper" 29 ) 30 31 const ( 32 tlsRootCert = "ca.crt" 33 ) 34 35 // TLSHelperConfig define params used to create a tls.Config 36 type TLSHelperConfig struct { 37 CertDir string 38 CertRequired bool 39 Cert string 40 Key string 41 ServerName string 42 RootCACert string 43 ClientAuth string 44 UseSystemCACerts bool 45 } 46 47 // RegisterClientTLSFlags registers the required flags to set up a TLS client. 48 func RegisterClientTLSFlags(flag *pflag.FlagSet) { 49 flag.String("tls_cacert", "", 50 "The CA Cert file used to verify server certificates. Required for enabling TLS.") 51 flag.Bool("tls_use_system_ca", true, "Include System CA into CA Certs.") 52 flag.String("tls_server_name", "", "Used to verify the server hostname.") 53 flag.String("tls_cert", "", "(optional) The Cert file provided by the client to the server.") 54 flag.String("tls_key", "", "(optional) The private key file "+ 55 "provided by the client to the server.") 56 } 57 58 // LoadServerTLSConfig loads the TLS config into the server with the given parameters. 59 func LoadServerTLSConfig(v *viper.Viper, tlsCertFile string, tlsKeyFile string) (*tls.Config, 60 error) { 61 conf := TLSHelperConfig{} 62 conf.CertDir = v.GetString("tls_dir") 63 if conf.CertDir != "" { 64 conf.CertRequired = true 65 conf.RootCACert = path.Join(conf.CertDir, tlsRootCert) 66 conf.Cert = path.Join(conf.CertDir, tlsCertFile) 67 conf.Key = path.Join(conf.CertDir, tlsKeyFile) 68 conf.ClientAuth = v.GetString("tls_client_auth") 69 } 70 conf.UseSystemCACerts = v.GetBool("tls_use_system_ca") 71 72 return GenerateServerTLSConfig(&conf) 73 } 74 75 // LoadClientTLSConfig loads the TLS config into the client with the given parameters. 76 func LoadClientTLSConfig(v *viper.Viper) (*tls.Config, error) { 77 // When the --tls_cacert option is pecified, the connection will be set up using TLS instead of 78 // plaintext. However the client cert files are optional, depending on whether the server 79 // requires a client certificate. 80 caCert := v.GetString("tls_cacert") 81 if caCert != "" { 82 tlsCfg := tls.Config{} 83 84 // 1. set up the root CA 85 pool, err := generateCertPool(caCert, v.GetBool("tls_use_system_ca")) 86 if err != nil { 87 return nil, err 88 } 89 tlsCfg.RootCAs = pool 90 91 // 2. set up the server name for verification 92 tlsCfg.ServerName = v.GetString("tls_server_name") 93 94 // 3. optionally load the client cert files 95 certFile := v.GetString("tls_cert") 96 keyFile := v.GetString("tls_key") 97 if certFile != "" && keyFile != "" { 98 cert, err := tls.LoadX509KeyPair(certFile, keyFile) 99 if err != nil { 100 return nil, err 101 } 102 tlsCfg.Certificates = []tls.Certificate{cert} 103 } 104 105 return &tlsCfg, nil 106 } else 107 // Attempt to determine if user specified *any* TLS option. Unfortunately and contrary to 108 // Viper's own documentation, there's no way to tell whether an option value came from a 109 // command-line option or a built-it default. 110 if v.GetString("tls_server_name") != "" || 111 v.GetString("tls_cert") != "" || 112 v.GetString("tls_key") != "" { 113 return nil, errors.Errorf("--tls_cacert is required for enabling TLS") 114 } 115 return nil, nil 116 } 117 118 func generateCertPool(certPath string, useSystemCA bool) (*x509.CertPool, error) { 119 var pool *x509.CertPool 120 if useSystemCA { 121 var err error 122 if pool, err = x509.SystemCertPool(); err != nil { 123 return nil, err 124 } 125 } else { 126 pool = x509.NewCertPool() 127 } 128 129 if len(certPath) > 0 { 130 caFile, err := ioutil.ReadFile(certPath) 131 if err != nil { 132 return nil, err 133 } 134 if !pool.AppendCertsFromPEM(caFile) { 135 return nil, errors.Errorf("error reading CA file %q", certPath) 136 } 137 } 138 139 return pool, nil 140 } 141 142 func setupClientAuth(authType string) (tls.ClientAuthType, error) { 143 auth := map[string]tls.ClientAuthType{ 144 "REQUEST": tls.RequestClientCert, 145 "REQUIREANY": tls.RequireAnyClientCert, 146 "VERIFYIFGIVEN": tls.VerifyClientCertIfGiven, 147 "REQUIREANDVERIFY": tls.RequireAndVerifyClientCert, 148 } 149 150 if len(authType) > 0 { 151 if v, has := auth[strings.ToUpper(authType)]; has { 152 return v, nil 153 } 154 return tls.NoClientCert, errors.Errorf("Invalid client auth. Valid values " + 155 "[REQUEST, REQUIREANY, VERIFYIFGIVEN, REQUIREANDVERIFY]") 156 } 157 158 return tls.NoClientCert, nil 159 } 160 161 // GenerateServerTLSConfig creates and returns a new *tls.Config with the 162 // configuration provided. 163 func GenerateServerTLSConfig(config *TLSHelperConfig) (tlsCfg *tls.Config, err error) { 164 if config.CertRequired { 165 tlsCfg = new(tls.Config) 166 cert, err := tls.LoadX509KeyPair(config.Cert, config.Key) 167 if err != nil { 168 return nil, err 169 } 170 tlsCfg.Certificates = []tls.Certificate{cert} 171 172 pool, err := generateCertPool(config.RootCACert, config.UseSystemCACerts) 173 if err != nil { 174 return nil, err 175 } 176 tlsCfg.ClientCAs = pool 177 178 auth, err := setupClientAuth(config.ClientAuth) 179 if err != nil { 180 return nil, err 181 } 182 tlsCfg.ClientAuth = auth 183 184 tlsCfg.MinVersion = tls.VersionTLS11 185 tlsCfg.MaxVersion = tls.VersionTLS12 186 187 return tlsCfg, nil 188 } 189 return nil, nil 190 } 191 192 // GenerateClientTLSConfig creates and returns a new client side *tls.Config with the 193 // configuration provided. 194 func GenerateClientTLSConfig(config *TLSHelperConfig) (tlsCfg *tls.Config, err error) { 195 pool, err := generateCertPool(config.RootCACert, config.UseSystemCACerts) 196 if err != nil { 197 return nil, err 198 } 199 200 return &tls.Config{RootCAs: pool, ServerName: config.ServerName}, nil 201 }