vitess.io/vitess@v0.16.2/go/vt/vtgr/ssl/ssl.go (about)

     1  package ssl
     2  
     3  import (
     4  	"crypto/tls"
     5  	"crypto/x509"
     6  	"encoding/pem"
     7  	"errors"
     8  	"fmt"
     9  	nethttp "net/http"
    10  	"os"
    11  	"strings"
    12  
    13  	"vitess.io/vitess/go/vt/log"
    14  
    15  	"github.com/go-martini/martini"
    16  	"github.com/howeyc/gopass"
    17  
    18  	"vitess.io/vitess/go/vt/vtgr/config"
    19  )
    20  
/*
	This file has been copied over from the VTOrc package.
*/
    24  
    25  // Determine if a string element is in a string array
    26  func HasString(elem string, arr []string) bool {
    27  	for _, s := range arr {
    28  		if s == elem {
    29  			return true
    30  		}
    31  	}
    32  	return false
    33  }
    34  
    35  // NewTLSConfig returns an initialized TLS configuration suitable for client
    36  // authentication. If caFile is non-empty, it will be loaded.
    37  func NewTLSConfig(caFile string, verifyCert bool) (*tls.Config, error) {
    38  	var c tls.Config
    39  
    40  	// Set to TLS 1.2 as a minimum.  This is overridden for mysql communication
    41  	c.MinVersion = tls.VersionTLS12
    42  
    43  	if verifyCert {
    44  		log.Info("verifyCert requested, client certificates will be verified")
    45  		c.ClientAuth = tls.VerifyClientCertIfGiven
    46  	}
    47  	caPool, err := ReadCAFile(caFile)
    48  	if err != nil {
    49  		return &c, err
    50  	}
    51  	c.ClientCAs = caPool
    52  	return &c, nil
    53  }
    54  
    55  // Returns CA certificate. If caFile is non-empty, it will be loaded.
    56  func ReadCAFile(caFile string) (*x509.CertPool, error) {
    57  	var caCertPool *x509.CertPool
    58  	if caFile != "" {
    59  		data, err := os.ReadFile(caFile)
    60  		if err != nil {
    61  			return nil, err
    62  		}
    63  		caCertPool = x509.NewCertPool()
    64  		if !caCertPool.AppendCertsFromPEM(data) {
    65  			return nil, errors.New("No certificates parsed")
    66  		}
    67  		log.Infof("Read in CA file: %v", caFile)
    68  	}
    69  	return caCertPool, nil
    70  }
    71  
    72  // Verify that the OU of the presented client certificate matches the list
    73  // of Valid OUs
    74  func Verify(r *nethttp.Request, validOUs []string) error {
    75  	if strings.Contains(r.URL.String(), config.Config.StatusEndpoint) && !config.Config.StatusOUVerify {
    76  		return nil
    77  	}
    78  	if r.TLS == nil {
    79  		return errors.New("No TLS")
    80  	}
    81  	for _, chain := range r.TLS.VerifiedChains {
    82  		s := chain[0].Subject.OrganizationalUnit
    83  		log.Infof("All OUs:", strings.Join(s, " "))
    84  		for _, ou := range s {
    85  			log.Infof("Client presented OU:", ou)
    86  			if HasString(ou, validOUs) {
    87  				log.Infof("Found valid OU:", ou)
    88  				return nil
    89  			}
    90  		}
    91  	}
    92  	log.Error("No valid OUs found")
    93  	return errors.New("Invalid OU")
    94  }
    95  
    96  // TODO: make this testable?
    97  func VerifyOUs(validOUs []string) martini.Handler {
    98  	return func(res nethttp.ResponseWriter, req *nethttp.Request, c martini.Context) {
    99  		log.Infof("Verifying client OU")
   100  		if err := Verify(req, validOUs); err != nil {
   101  			nethttp.Error(res, err.Error(), nethttp.StatusUnauthorized)
   102  		}
   103  	}
   104  }
   105  
   106  // AppendKeyPair loads the given TLS key pair and appends it to
   107  // tlsConfig.Certificates.
   108  func AppendKeyPair(tlsConfig *tls.Config, certFile string, keyFile string) error {
   109  	cert, err := tls.LoadX509KeyPair(certFile, keyFile)
   110  	if err != nil {
   111  		return err
   112  	}
   113  	tlsConfig.Certificates = append(tlsConfig.Certificates, cert)
   114  	return nil
   115  }
   116  
   117  // Read in a keypair where the key is password protected
   118  func AppendKeyPairWithPassword(tlsConfig *tls.Config, certFile string, keyFile string, pemPass []byte) error {
   119  
   120  	// Certificates aren't usually password protected, but we're kicking the password
   121  	// along just in case.  It won't be used if the file isn't encrypted
   122  	certData, err := ReadPEMData(certFile, pemPass)
   123  	if err != nil {
   124  		return err
   125  	}
   126  	keyData, err := ReadPEMData(keyFile, pemPass)
   127  	if err != nil {
   128  		return err
   129  	}
   130  	cert, err := tls.X509KeyPair(certData, keyData)
   131  	if err != nil {
   132  		return err
   133  	}
   134  	tlsConfig.Certificates = append(tlsConfig.Certificates, cert)
   135  	return nil
   136  }
   137  
   138  // Read a PEM file and ask for a password to decrypt it if needed
   139  func ReadPEMData(pemFile string, pemPass []byte) ([]byte, error) {
   140  	pemData, err := os.ReadFile(pemFile)
   141  	if err != nil {
   142  		return pemData, err
   143  	}
   144  
   145  	// We should really just get the pem.Block back here, if there's other
   146  	// junk on the end, warn about it.
   147  	pemBlock, rest := pem.Decode(pemData)
   148  	if len(rest) > 0 {
   149  		log.Warning("Didn't parse all of", pemFile)
   150  	}
   151  
   152  	if x509.IsEncryptedPEMBlock(pemBlock) { //nolint SA1019
   153  		// Decrypt and get the ASN.1 DER bytes here
   154  		pemData, err = x509.DecryptPEMBlock(pemBlock, pemPass) //nolint SA1019
   155  		if err != nil {
   156  			return pemData, err
   157  		}
   158  		log.Infof("Decrypted %v successfully", pemFile)
   159  		// Shove the decrypted DER bytes into a new pem Block with blank headers
   160  		var newBlock pem.Block
   161  		newBlock.Type = pemBlock.Type
   162  		newBlock.Bytes = pemData
   163  		// This is now like reading in an uncrypted key from a file and stuffing it
   164  		// into a byte stream
   165  		pemData = pem.EncodeToMemory(&newBlock)
   166  	}
   167  	return pemData, nil
   168  }
   169  
   170  // Print a password prompt on the terminal and collect a password
   171  func GetPEMPassword(pemFile string) []byte {
   172  	fmt.Printf("Password for %s: ", pemFile)
   173  	pass, err := gopass.GetPasswd()
   174  	if err != nil {
   175  		// We'll error with an incorrect password at DecryptPEMBlock
   176  		return []byte("")
   177  	}
   178  	return pass
   179  }
   180  
   181  // Determine if PEM file is encrypted
   182  func IsEncryptedPEM(pemFile string) bool {
   183  	pemData, err := os.ReadFile(pemFile)
   184  	if err != nil {
   185  		return false
   186  	}
   187  	pemBlock, _ := pem.Decode(pemData)
   188  	if len(pemBlock.Bytes) == 0 {
   189  		return false
   190  	}
   191  	return x509.IsEncryptedPEMBlock(pemBlock) //nolint SA1019
   192  }
   193  
   194  // ListenAndServeTLS acts identically to http.ListenAndServeTLS, except that it
   195  // expects TLS configuration.
   196  // TODO: refactor so this is testable?
   197  func ListenAndServeTLS(addr string, handler nethttp.Handler, tlsConfig *tls.Config) error {
   198  	if addr == "" {
   199  		// On unix Listen calls getaddrinfo to parse the port, so named ports are fine as long
   200  		// as they exist in /etc/services
   201  		addr = ":https"
   202  	}
   203  	l, err := tls.Listen("tcp", addr, tlsConfig)
   204  	if err != nil {
   205  		return err
   206  	}
   207  	return nethttp.Serve(l, handler)
   208  }