vitess.io/vitess@v0.16.2/go/vt/vttls/vttls.go (about)

     1  /*
     2  Copyright 2019 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package vttls
    18  
    19  import (
    20  	"crypto/tls"
    21  	"crypto/x509"
    22  	"os"
    23  	"strings"
    24  	"sync"
    25  
    26  	"vitess.io/vitess/go/vt/proto/vtrpc"
    27  	"vitess.io/vitess/go/vt/vterrors"
    28  )
    29  
    30  // SslMode indicates the type of SSL mode to use. This matches
    31  // the MySQL SSL modes as mentioned at:
    32  // https://dev.mysql.com/doc/refman/8.0/en/connection-options.html#option_general_ssl-mode
    33  type SslMode string
    34  
    35  // Disabled disables SSL and connects over plain text
    36  const Disabled SslMode = "disabled"
    37  
    38  // Preferred establishes an SSL connection if the server supports it.
    39  // It does not validate the certificate provided by the server.
    40  const Preferred SslMode = "preferred"
    41  
    42  // Required requires an SSL connection to the server.
    43  // It does not validate the certificate provided by the server.
    44  const Required SslMode = "required"
    45  
    46  // VerifyCA requires an SSL connection to the server.
    47  // It validates the CA against the configured CA certificate(s).
    48  const VerifyCA SslMode = "verify_ca"
    49  
    50  // VerifyIdentity requires an SSL connection to the server.
    51  // It validates the CA against the configured CA certificate(s) and
    52  // also validates the certificate based on the hostname.
    53  // This is the setting you want when you want to connect safely to
    54  // a MySQL server and want to be protected against man-in-the-middle
    55  // attacks.
    56  const VerifyIdentity SslMode = "verify_identity"
    57  
    58  // String returns the string representation, part of the Value interface
    59  // for allowing this to be retrieved for a flag.
    60  func (mode *SslMode) String() string {
    61  	return string(*mode)
    62  }
    63  
    64  // Type returns the value type, part of the pflag Value interface
    65  // for allowing this to be used as a generic flag.
    66  func (mode *SslMode) Type() string {
    67  	return "SslMode"
    68  }
    69  
    70  // Set updates the value of the SslMode pointer, part of the Value interface
    71  // for allowing to update a flag.
    72  func (mode *SslMode) Set(value string) error {
    73  	parsedMode := SslMode(strings.ToLower(value))
    74  	switch parsedMode {
    75  	case "":
    76  		*mode = Preferred
    77  		return nil
    78  	case Disabled, Preferred, Required, VerifyCA, VerifyIdentity:
    79  		*mode = parsedMode
    80  		return nil
    81  	}
    82  	return vterrors.Errorf(vtrpc.Code_INVALID_ARGUMENT, "Invalid SSL mode specified: %s. Allowed options are disabled, preferred, required, verify_ca, verify_identity", value)
    83  }
    84  
    85  // TLSVersionToNumber converts a text description of the TLS protocol
    86  // to the internal Go number representation.
    87  func TLSVersionToNumber(tlsVersion string) (uint16, error) {
    88  	switch strings.ToLower(tlsVersion) {
    89  	case "tlsv1.3":
    90  		return tls.VersionTLS13, nil
    91  	case "", "tlsv1.2":
    92  		return tls.VersionTLS12, nil
    93  	case "tlsv1.1":
    94  		return tls.VersionTLS11, nil
    95  	case "tlsv1.0":
    96  		return tls.VersionTLS10, nil
    97  	default:
    98  		return tls.VersionTLS12, vterrors.Errorf(vtrpc.Code_INVALID_ARGUMENT, "Invalid TLS version specified: %s. Allowed options are TLSv1.0, TLSv1.1, TLSv1.2 & TLSv1.3", tlsVersion)
    99  	}
   100  }
   101  
   102  var onceByKeys = sync.Map{}
   103  
   104  // ClientConfig returns the TLS config to use for a client to
   105  // connect to a server with the provided parameters.
   106  func ClientConfig(mode SslMode, cert, key, ca, crl, name string, minTLSVersion uint16) (*tls.Config, error) {
   107  	config := &tls.Config{
   108  		MinVersion: minTLSVersion,
   109  	}
   110  
   111  	// Load the client-side cert & key if any.
   112  	if cert != "" && key != "" {
   113  		certificates, err := loadTLSCertificate(cert, key)
   114  
   115  		if err != nil {
   116  			return nil, err
   117  		}
   118  
   119  		config.Certificates = *certificates
   120  	}
   121  
   122  	// Load the server CA if any.
   123  	if ca != "" {
   124  		certificatePool, err := loadx509CertPool(ca)
   125  
   126  		if err != nil {
   127  			return nil, err
   128  		}
   129  
   130  		config.RootCAs = certificatePool
   131  	}
   132  
   133  	// Set the server name if any.
   134  	if name != "" {
   135  		config.ServerName = name
   136  	}
   137  
   138  	switch mode {
   139  	case Disabled:
   140  		return nil, vterrors.Errorf(vtrpc.Code_INVALID_ARGUMENT, "can't create config for disabled mode")
   141  	case Preferred, Required:
   142  		config.InsecureSkipVerify = true
   143  	case VerifyCA:
   144  		config.InsecureSkipVerify = true
   145  		config.VerifyConnection = func(cs tls.ConnectionState) error {
   146  			caRoots := config.RootCAs
   147  			if caRoots == nil {
   148  				var err error
   149  				caRoots, err = x509.SystemCertPool()
   150  				if err != nil {
   151  					return err
   152  				}
   153  			}
   154  			opts := x509.VerifyOptions{
   155  				Roots:         caRoots,
   156  				Intermediates: x509.NewCertPool(),
   157  			}
   158  			for _, cert := range cs.PeerCertificates[1:] {
   159  				opts.Intermediates.AddCert(cert)
   160  			}
   161  			_, err := cs.PeerCertificates[0].Verify(opts)
   162  			return err
   163  		}
   164  	case VerifyIdentity:
   165  		// Nothing to do here, default config is the strictest and correct.
   166  	default:
   167  		return nil, vterrors.Errorf(vtrpc.Code_INVALID_ARGUMENT, "invalid mode: %s", mode)
   168  	}
   169  
   170  	if crl != "" {
   171  		crlFunc, err := verifyPeerCertificateAgainstCRL(crl)
   172  		if err != nil {
   173  			return nil, err
   174  		}
   175  		config.VerifyPeerCertificate = crlFunc
   176  	}
   177  
   178  	return config, nil
   179  }
   180  
   181  // ServerConfig returns the TLS config to use for a server to
   182  // accept client connections.
   183  func ServerConfig(cert, key, ca, crl, serverCA string, minTLSVersion uint16) (*tls.Config, error) {
   184  	config := &tls.Config{
   185  		MinVersion: minTLSVersion,
   186  	}
   187  
   188  	var certificates *[]tls.Certificate
   189  	var err error
   190  
   191  	if serverCA != "" {
   192  		certificates, err = combineAndLoadTLSCertificates(serverCA, cert, key)
   193  	} else {
   194  		certificates, err = loadTLSCertificate(cert, key)
   195  	}
   196  
   197  	if err != nil {
   198  		return nil, err
   199  	}
   200  	config.Certificates = *certificates
   201  
   202  	// if specified, load ca to validate client,
   203  	// and enforce clients present valid certs.
   204  	if ca != "" {
   205  		certificatePool, err := loadx509CertPool(ca)
   206  
   207  		if err != nil {
   208  			return nil, err
   209  		}
   210  
   211  		config.ClientCAs = certificatePool
   212  		config.ClientAuth = tls.RequireAndVerifyClientCert
   213  	}
   214  
   215  	if crl != "" {
   216  		crlFunc, err := verifyPeerCertificateAgainstCRL(crl)
   217  		if err != nil {
   218  			return nil, err
   219  		}
   220  		config.VerifyPeerCertificate = crlFunc
   221  	}
   222  
   223  	return config, nil
   224  }
   225  
   226  var certPools = sync.Map{}
   227  
   228  func loadx509CertPool(ca string) (*x509.CertPool, error) {
   229  	once, _ := onceByKeys.LoadOrStore(ca, &sync.Once{})
   230  
   231  	var err error
   232  	once.(*sync.Once).Do(func() {
   233  		err = doLoadx509CertPool(ca)
   234  	})
   235  	if err != nil {
   236  		return nil, err
   237  	}
   238  
   239  	result, ok := certPools.Load(ca)
   240  
   241  	if !ok {
   242  		return nil, vterrors.Errorf(vtrpc.Code_NOT_FOUND, "Cannot find loaded x509 cert pool for ca: %s", ca)
   243  	}
   244  
   245  	return result.(*x509.CertPool), nil
   246  }
   247  
   248  func doLoadx509CertPool(ca string) error {
   249  	b, err := os.ReadFile(ca)
   250  	if err != nil {
   251  		return vterrors.Errorf(vtrpc.Code_NOT_FOUND, "failed to read ca file: %s", ca)
   252  	}
   253  
   254  	cp := x509.NewCertPool()
   255  	if !cp.AppendCertsFromPEM(b) {
   256  		return vterrors.Errorf(vtrpc.Code_UNKNOWN, "failed to append certificates")
   257  	}
   258  
   259  	certPools.Store(ca, cp)
   260  
   261  	return nil
   262  }
   263  
   264  var tlsCertificates = sync.Map{}
   265  
   266  func tlsCertificatesIdentifier(tokens ...string) string {
   267  	return strings.Join(tokens, ";")
   268  }
   269  
   270  func loadTLSCertificate(cert, key string) (*[]tls.Certificate, error) {
   271  	tlsIdentifier := tlsCertificatesIdentifier(cert, key)
   272  	once, _ := onceByKeys.LoadOrStore(tlsIdentifier, &sync.Once{})
   273  
   274  	var err error
   275  	once.(*sync.Once).Do(func() {
   276  		err = doLoadTLSCertificate(cert, key)
   277  	})
   278  
   279  	if err != nil {
   280  		return nil, err
   281  	}
   282  
   283  	result, ok := tlsCertificates.Load(tlsIdentifier)
   284  
   285  	if !ok {
   286  		return nil, vterrors.Errorf(vtrpc.Code_NOT_FOUND, "Cannot find loaded tls certificate with cert: %s, key%s", cert, key)
   287  	}
   288  
   289  	return result.(*[]tls.Certificate), nil
   290  }
   291  
   292  func doLoadTLSCertificate(cert, key string) error {
   293  	tlsIdentifier := tlsCertificatesIdentifier(cert, key)
   294  
   295  	var certificate []tls.Certificate
   296  	// Load the server cert and key.
   297  	crt, err := tls.LoadX509KeyPair(cert, key)
   298  	if err != nil {
   299  		return vterrors.Errorf(vtrpc.Code_NOT_FOUND, "failed to load tls certificate, cert %s, key: %s", cert, key)
   300  	}
   301  
   302  	certificate = []tls.Certificate{crt}
   303  
   304  	tlsCertificates.Store(tlsIdentifier, &certificate)
   305  
   306  	return nil
   307  }
   308  
   309  var combinedTLSCertificates = sync.Map{}
   310  
   311  func combineAndLoadTLSCertificates(ca, cert, key string) (*[]tls.Certificate, error) {
   312  	combinedTLSIdentifier := tlsCertificatesIdentifier(ca, cert, key)
   313  	once, _ := onceByKeys.LoadOrStore(combinedTLSIdentifier, &sync.Once{})
   314  
   315  	var err error
   316  	once.(*sync.Once).Do(func() {
   317  		err = doLoadAndCombineTLSCertificates(ca, cert, key)
   318  	})
   319  
   320  	if err != nil {
   321  		return nil, err
   322  	}
   323  
   324  	result, ok := combinedTLSCertificates.Load(combinedTLSIdentifier)
   325  
   326  	if !ok {
   327  		return nil, vterrors.Errorf(vtrpc.Code_NOT_FOUND, "Cannot find loaded tls certificate chain with ca: %s, cert: %s, key: %s", ca, cert, key)
   328  	}
   329  
   330  	return result.(*[]tls.Certificate), nil
   331  }
   332  
   333  func doLoadAndCombineTLSCertificates(ca, cert, key string) error {
   334  	combinedTLSIdentifier := tlsCertificatesIdentifier(ca, cert, key)
   335  
   336  	// Read CA certificates chain
   337  	caB, err := os.ReadFile(ca)
   338  	if err != nil {
   339  		return vterrors.Errorf(vtrpc.Code_NOT_FOUND, "failed to read ca file: %s", ca)
   340  	}
   341  
   342  	// Read server certificate
   343  	certB, err := os.ReadFile(cert)
   344  	if err != nil {
   345  		return vterrors.Errorf(vtrpc.Code_NOT_FOUND, "failed to read server cert file: %s", cert)
   346  	}
   347  
   348  	// Read server key file
   349  	keyB, err := os.ReadFile(key)
   350  	if err != nil {
   351  		return vterrors.Errorf(vtrpc.Code_NOT_FOUND, "failed to read key file: %s", key)
   352  	}
   353  
   354  	// Load CA, server cert and key.
   355  	var certificate []tls.Certificate
   356  	crt, err := tls.X509KeyPair(append(certB, caB...), keyB)
   357  	if err != nil {
   358  		return vterrors.Errorf(vtrpc.Code_NOT_FOUND, "failed to load and merge tls certificate with CA, ca %s, cert %s, key: %s", ca, cert, key)
   359  	}
   360  
   361  	certificate = []tls.Certificate{crt}
   362  
   363  	combinedTLSCertificates.Store(combinedTLSIdentifier, &certificate)
   364  
   365  	return nil
   366  }