vitess.io/vitess@v0.16.2/go/vt/vttls/vttls.go (about) 1 /* 2 Copyright 2019 The Vitess Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package vttls 18 19 import ( 20 "crypto/tls" 21 "crypto/x509" 22 "os" 23 "strings" 24 "sync" 25 26 "vitess.io/vitess/go/vt/proto/vtrpc" 27 "vitess.io/vitess/go/vt/vterrors" 28 ) 29 30 // SslMode indicates the type of SSL mode to use. This matches 31 // the MySQL SSL modes as mentioned at: 32 // https://dev.mysql.com/doc/refman/8.0/en/connection-options.html#option_general_ssl-mode 33 type SslMode string 34 35 // Disabled disables SSL and connects over plain text 36 const Disabled SslMode = "disabled" 37 38 // Preferred establishes an SSL connection if the server supports it. 39 // It does not validate the certificate provided by the server. 40 const Preferred SslMode = "preferred" 41 42 // Required requires an SSL connection to the server. 43 // It does not validate the certificate provided by the server. 44 const Required SslMode = "required" 45 46 // VerifyCA requires an SSL connection to the server. 47 // It validates the CA against the configured CA certificate(s). 48 const VerifyCA SslMode = "verify_ca" 49 50 // VerifyIdentity requires an SSL connection to the server. 51 // It validates the CA against the configured CA certificate(s) and 52 // also validates the certificate based on the hostname. 53 // This is the setting you want when you want to connect safely to 54 // a MySQL server and want to be protected against man-in-the-middle 55 // attacks. 56 const VerifyIdentity SslMode = "verify_identity" 57 58 // String returns the string representation, part of the Value interface 59 // for allowing this to be retrieved for a flag. 60 func (mode *SslMode) String() string { 61 return string(*mode) 62 } 63 64 // Type returns the value type, part of the pflag Value interface 65 // for allowing this to be used as a generic flag. 66 func (mode *SslMode) Type() string { 67 return "SslMode" 68 } 69 70 // Set updates the value of the SslMode pointer, part of the Value interface 71 // for allowing to update a flag. 72 func (mode *SslMode) Set(value string) error { 73 parsedMode := SslMode(strings.ToLower(value)) 74 switch parsedMode { 75 case "": 76 *mode = Preferred 77 return nil 78 case Disabled, Preferred, Required, VerifyCA, VerifyIdentity: 79 *mode = parsedMode 80 return nil 81 } 82 return vterrors.Errorf(vtrpc.Code_INVALID_ARGUMENT, "Invalid SSL mode specified: %s. Allowed options are disabled, preferred, required, verify_ca, verify_identity", value) 83 } 84 85 // TLSVersionToNumber converts a text description of the TLS protocol 86 // to the internal Go number representation. 87 func TLSVersionToNumber(tlsVersion string) (uint16, error) { 88 switch strings.ToLower(tlsVersion) { 89 case "tlsv1.3": 90 return tls.VersionTLS13, nil 91 case "", "tlsv1.2": 92 return tls.VersionTLS12, nil 93 case "tlsv1.1": 94 return tls.VersionTLS11, nil 95 case "tlsv1.0": 96 return tls.VersionTLS10, nil 97 default: 98 return tls.VersionTLS12, vterrors.Errorf(vtrpc.Code_INVALID_ARGUMENT, "Invalid TLS version specified: %s. Allowed options are TLSv1.0, TLSv1.1, TLSv1.2 & TLSv1.3", tlsVersion) 99 } 100 } 101 102 var onceByKeys = sync.Map{} 103 104 // ClientConfig returns the TLS config to use for a client to 105 // connect to a server with the provided parameters. 106 func ClientConfig(mode SslMode, cert, key, ca, crl, name string, minTLSVersion uint16) (*tls.Config, error) { 107 config := &tls.Config{ 108 MinVersion: minTLSVersion, 109 } 110 111 // Load the client-side cert & key if any. 112 if cert != "" && key != "" { 113 certificates, err := loadTLSCertificate(cert, key) 114 115 if err != nil { 116 return nil, err 117 } 118 119 config.Certificates = *certificates 120 } 121 122 // Load the server CA if any. 123 if ca != "" { 124 certificatePool, err := loadx509CertPool(ca) 125 126 if err != nil { 127 return nil, err 128 } 129 130 config.RootCAs = certificatePool 131 } 132 133 // Set the server name if any. 134 if name != "" { 135 config.ServerName = name 136 } 137 138 switch mode { 139 case Disabled: 140 return nil, vterrors.Errorf(vtrpc.Code_INVALID_ARGUMENT, "can't create config for disabled mode") 141 case Preferred, Required: 142 config.InsecureSkipVerify = true 143 case VerifyCA: 144 config.InsecureSkipVerify = true 145 config.VerifyConnection = func(cs tls.ConnectionState) error { 146 caRoots := config.RootCAs 147 if caRoots == nil { 148 var err error 149 caRoots, err = x509.SystemCertPool() 150 if err != nil { 151 return err 152 } 153 } 154 opts := x509.VerifyOptions{ 155 Roots: caRoots, 156 Intermediates: x509.NewCertPool(), 157 } 158 for _, cert := range cs.PeerCertificates[1:] { 159 opts.Intermediates.AddCert(cert) 160 } 161 _, err := cs.PeerCertificates[0].Verify(opts) 162 return err 163 } 164 case VerifyIdentity: 165 // Nothing to do here, default config is the strictest and correct. 166 default: 167 return nil, vterrors.Errorf(vtrpc.Code_INVALID_ARGUMENT, "invalid mode: %s", mode) 168 } 169 170 if crl != "" { 171 crlFunc, err := verifyPeerCertificateAgainstCRL(crl) 172 if err != nil { 173 return nil, err 174 } 175 config.VerifyPeerCertificate = crlFunc 176 } 177 178 return config, nil 179 } 180 181 // ServerConfig returns the TLS config to use for a server to 182 // accept client connections. 183 func ServerConfig(cert, key, ca, crl, serverCA string, minTLSVersion uint16) (*tls.Config, error) { 184 config := &tls.Config{ 185 MinVersion: minTLSVersion, 186 } 187 188 var certificates *[]tls.Certificate 189 var err error 190 191 if serverCA != "" { 192 certificates, err = combineAndLoadTLSCertificates(serverCA, cert, key) 193 } else { 194 certificates, err = loadTLSCertificate(cert, key) 195 } 196 197 if err != nil { 198 return nil, err 199 } 200 config.Certificates = *certificates 201 202 // if specified, load ca to validate client, 203 // and enforce clients present valid certs. 204 if ca != "" { 205 certificatePool, err := loadx509CertPool(ca) 206 207 if err != nil { 208 return nil, err 209 } 210 211 config.ClientCAs = certificatePool 212 config.ClientAuth = tls.RequireAndVerifyClientCert 213 } 214 215 if crl != "" { 216 crlFunc, err := verifyPeerCertificateAgainstCRL(crl) 217 if err != nil { 218 return nil, err 219 } 220 config.VerifyPeerCertificate = crlFunc 221 } 222 223 return config, nil 224 } 225 226 var certPools = sync.Map{} 227 228 func loadx509CertPool(ca string) (*x509.CertPool, error) { 229 once, _ := onceByKeys.LoadOrStore(ca, &sync.Once{}) 230 231 var err error 232 once.(*sync.Once).Do(func() { 233 err = doLoadx509CertPool(ca) 234 }) 235 if err != nil { 236 return nil, err 237 } 238 239 result, ok := certPools.Load(ca) 240 241 if !ok { 242 return nil, vterrors.Errorf(vtrpc.Code_NOT_FOUND, "Cannot find loaded x509 cert pool for ca: %s", ca) 243 } 244 245 return result.(*x509.CertPool), nil 246 } 247 248 func doLoadx509CertPool(ca string) error { 249 b, err := os.ReadFile(ca) 250 if err != nil { 251 return vterrors.Errorf(vtrpc.Code_NOT_FOUND, "failed to read ca file: %s", ca) 252 } 253 254 cp := x509.NewCertPool() 255 if !cp.AppendCertsFromPEM(b) { 256 return vterrors.Errorf(vtrpc.Code_UNKNOWN, "failed to append certificates") 257 } 258 259 certPools.Store(ca, cp) 260 261 return nil 262 } 263 264 var tlsCertificates = sync.Map{} 265 266 func tlsCertificatesIdentifier(tokens ...string) string { 267 return strings.Join(tokens, ";") 268 } 269 270 func loadTLSCertificate(cert, key string) (*[]tls.Certificate, error) { 271 tlsIdentifier := tlsCertificatesIdentifier(cert, key) 272 once, _ := onceByKeys.LoadOrStore(tlsIdentifier, &sync.Once{}) 273 274 var err error 275 once.(*sync.Once).Do(func() { 276 err = doLoadTLSCertificate(cert, key) 277 }) 278 279 if err != nil { 280 return nil, err 281 } 282 283 result, ok := tlsCertificates.Load(tlsIdentifier) 284 285 if !ok { 286 return nil, vterrors.Errorf(vtrpc.Code_NOT_FOUND, "Cannot find loaded tls certificate with cert: %s, key%s", cert, key) 287 } 288 289 return result.(*[]tls.Certificate), nil 290 } 291 292 func doLoadTLSCertificate(cert, key string) error { 293 tlsIdentifier := tlsCertificatesIdentifier(cert, key) 294 295 var certificate []tls.Certificate 296 // Load the server cert and key. 297 crt, err := tls.LoadX509KeyPair(cert, key) 298 if err != nil { 299 return vterrors.Errorf(vtrpc.Code_NOT_FOUND, "failed to load tls certificate, cert %s, key: %s", cert, key) 300 } 301 302 certificate = []tls.Certificate{crt} 303 304 tlsCertificates.Store(tlsIdentifier, &certificate) 305 306 return nil 307 } 308 309 var combinedTLSCertificates = sync.Map{} 310 311 func combineAndLoadTLSCertificates(ca, cert, key string) (*[]tls.Certificate, error) { 312 combinedTLSIdentifier := tlsCertificatesIdentifier(ca, cert, key) 313 once, _ := onceByKeys.LoadOrStore(combinedTLSIdentifier, &sync.Once{}) 314 315 var err error 316 once.(*sync.Once).Do(func() { 317 err = doLoadAndCombineTLSCertificates(ca, cert, key) 318 }) 319 320 if err != nil { 321 return nil, err 322 } 323 324 result, ok := combinedTLSCertificates.Load(combinedTLSIdentifier) 325 326 if !ok { 327 return nil, vterrors.Errorf(vtrpc.Code_NOT_FOUND, "Cannot find loaded tls certificate chain with ca: %s, cert: %s, key: %s", ca, cert, key) 328 } 329 330 return result.(*[]tls.Certificate), nil 331 } 332 333 func doLoadAndCombineTLSCertificates(ca, cert, key string) error { 334 combinedTLSIdentifier := tlsCertificatesIdentifier(ca, cert, key) 335 336 // Read CA certificates chain 337 caB, err := os.ReadFile(ca) 338 if err != nil { 339 return vterrors.Errorf(vtrpc.Code_NOT_FOUND, "failed to read ca file: %s", ca) 340 } 341 342 // Read server certificate 343 certB, err := os.ReadFile(cert) 344 if err != nil { 345 return vterrors.Errorf(vtrpc.Code_NOT_FOUND, "failed to read server cert file: %s", cert) 346 } 347 348 // Read server key file 349 keyB, err := os.ReadFile(key) 350 if err != nil { 351 return vterrors.Errorf(vtrpc.Code_NOT_FOUND, "failed to read key file: %s", key) 352 } 353 354 // Load CA, server cert and key. 355 var certificate []tls.Certificate 356 crt, err := tls.X509KeyPair(append(certB, caB...), keyB) 357 if err != nil { 358 return vterrors.Errorf(vtrpc.Code_NOT_FOUND, "failed to load and merge tls certificate with CA, ca %s, cert %s, key: %s", ca, cert, key) 359 } 360 361 certificate = []tls.Certificate{crt} 362 363 combinedTLSCertificates.Store(combinedTLSIdentifier, &certificate) 364 365 return nil 366 }