vitess.io/vitess@v0.16.2/go/vt/tlstest/tlstest.go (about)

     1  /*
     2  Copyright 2019 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  // Package tlstest contains utility methods to create test certificates.
    18  // It is not meant to be used in production.
    19  package tlstest
    20  
    21  import (
    22  	"bytes"
    23  	"crypto"
    24  	"crypto/ecdsa"
    25  	"crypto/elliptic"
    26  	"crypto/rand"
    27  	"crypto/x509"
    28  	"crypto/x509/pkix"
    29  	"encoding/pem"
    30  	"errors"
    31  	"fmt"
    32  	"math/big"
    33  	"net"
    34  	"os"
    35  	"path"
    36  	"strconv"
    37  	"time"
    38  
    39  	"vitess.io/vitess/go/vt/log"
    40  )
    41  
    42  const (
    43  	// CA is the name of the CA toplevel cert.
    44  	CA          = "ca"
    45  	permissions = 0700
    46  )
    47  
    48  func loadCert(certPath string) (*x509.Certificate, error) {
    49  	certData, err := os.ReadFile(certPath)
    50  	if err != nil {
    51  		return nil, err
    52  	}
    53  
    54  	block, _ := pem.Decode(certData)
    55  	if block == nil {
    56  		return nil, errors.New("failed to parse certificate PEM")
    57  	}
    58  	return x509.ParseCertificate(block.Bytes)
    59  }
    60  
    61  func saveCert(certificate *x509.Certificate, certPath string) error {
    62  	out := &bytes.Buffer{}
    63  	err := pem.Encode(out, &pem.Block{Type: "CERTIFICATE", Bytes: certificate.Raw})
    64  	if err != nil {
    65  		return err
    66  	}
    67  	return os.WriteFile(certPath, out.Bytes(), permissions)
    68  }
    69  
    70  func generateKey() (crypto.PrivateKey, error) {
    71  	return ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
    72  }
    73  
    74  func loadKey(keyPath string) (crypto.PrivateKey, error) {
    75  	keyData, err := os.ReadFile(keyPath)
    76  	if err != nil {
    77  		return nil, err
    78  	}
    79  
    80  	block, _ := pem.Decode(keyData)
    81  	if block == nil {
    82  		return nil, errors.New("failed to parse key PEM")
    83  	}
    84  
    85  	switch block.Type {
    86  	case "PRIVATE KEY":
    87  		return x509.ParsePKCS8PrivateKey(block.Bytes)
    88  	case "RSA PRIVATE KEY":
    89  		return x509.ParsePKCS1PrivateKey(block.Bytes)
    90  	case "EC PRIVATE KEY":
    91  		return x509.ParseECPrivateKey(block.Bytes)
    92  	default:
    93  		return nil, fmt.Errorf("unknown private key format: %+v", block.Type)
    94  	}
    95  }
    96  
    97  func saveKey(key crypto.PrivateKey, keyPath string) error {
    98  	keyData, err := x509.MarshalPKCS8PrivateKey(key)
    99  	if err != nil {
   100  		return err
   101  	}
   102  	out := &bytes.Buffer{}
   103  	err = pem.Encode(out, &pem.Block{Type: "PRIVATE KEY", Bytes: keyData})
   104  	if err != nil {
   105  		return err
   106  	}
   107  	return os.WriteFile(keyPath, out.Bytes(), permissions)
   108  }
   109  
   110  // pubkey is an interface to get a public key from a private key
   111  // The Go specification for a private key defines that this always
   112  // exists, although there's no interface for it since it would break
   113  // backwards compatibility. See https://pkg.go.dev/crypto#PrivateKey
   114  type pubKey interface {
   115  	Public() crypto.PublicKey
   116  }
   117  
   118  func publicKey(priv crypto.PrivateKey) crypto.PublicKey {
   119  	return priv.(pubKey).Public()
   120  }
   121  
   122  func signCert(parent *x509.Certificate, parentPriv crypto.PrivateKey, certPub crypto.PublicKey, commonName string, serial int64, ca bool) (*x509.Certificate, error) {
   123  	keyUsage := x509.KeyUsageDigitalSignature
   124  	var extKeyUsage []x509.ExtKeyUsage
   125  	var dnsNames []string
   126  	var ipAddresses []net.IP
   127  
   128  	if ca {
   129  		keyUsage = keyUsage | x509.KeyUsageCRLSign | x509.KeyUsageCertSign
   130  	} else {
   131  		extKeyUsage = []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}
   132  		dnsNames = []string{"localhost", commonName}
   133  		ipAddresses = []net.IP{net.ParseIP("127.0.0.1"), net.ParseIP("::1")}
   134  	}
   135  
   136  	template := x509.Certificate{
   137  		SerialNumber: big.NewInt(serial),
   138  		Subject: pkix.Name{
   139  			CommonName: commonName,
   140  		},
   141  		NotBefore:             time.Now().Add(-30 * time.Second),
   142  		NotAfter:              time.Now().Add(24 * time.Hour),
   143  		KeyUsage:              keyUsage,
   144  		ExtKeyUsage:           extKeyUsage,
   145  		BasicConstraintsValid: true,
   146  		IsCA:                  ca,
   147  		DNSNames:              dnsNames,
   148  		IPAddresses:           ipAddresses,
   149  	}
   150  
   151  	// No parent defined means we create a self signed one.
   152  	if parent == nil {
   153  		parent = &template
   154  	}
   155  
   156  	certificate, err := x509.CreateCertificate(rand.Reader, &template, parent, certPub, parentPriv)
   157  	if err != nil {
   158  		return nil, err
   159  	}
   160  	return x509.ParseCertificate(certificate)
   161  }
   162  
   163  // CreateCA creates the toplevel 'ca' certificate and key, and places it
   164  // in the provided directory. Temporary files are also created in that
   165  // directory.
   166  func CreateCA(root string) {
   167  	log.Infof("Creating test root CA in %v", root)
   168  	keyPath := path.Join(root, "ca-key.pem")
   169  	certPath := path.Join(root, "ca-cert.pem")
   170  
   171  	priv, err := generateKey()
   172  	if err != nil {
   173  		log.Fatal(err)
   174  	}
   175  
   176  	err = saveKey(priv, keyPath)
   177  	if err != nil {
   178  		log.Fatal(err)
   179  	}
   180  
   181  	ca, err := signCert(nil, priv, publicKey(priv), CA, 1, true)
   182  	if err != nil {
   183  		log.Fatal(err)
   184  	}
   185  
   186  	err = saveCert(ca, certPath)
   187  	if err != nil {
   188  		log.Fatal(err)
   189  	}
   190  }
   191  
   192  func CreateIntermediateCA(root, parent, serial, name, commonName string) {
   193  	caKeyPath := path.Join(root, parent+"-key.pem")
   194  	caCertPath := path.Join(root, parent+"-cert.pem")
   195  	keyPath := path.Join(root, name+"-key.pem")
   196  	certPath := path.Join(root, name+"-cert.pem")
   197  
   198  	caKey, err := loadKey(caKeyPath)
   199  	if err != nil {
   200  		log.Fatal(err)
   201  	}
   202  	caCert, err := loadCert(caCertPath)
   203  	if err != nil {
   204  		log.Fatal(err)
   205  	}
   206  
   207  	priv, err := generateKey()
   208  	if err != nil {
   209  		log.Fatal(err)
   210  	}
   211  
   212  	err = saveKey(priv, keyPath)
   213  	if err != nil {
   214  		log.Fatal(err)
   215  	}
   216  
   217  	serialNr, err := strconv.ParseInt(serial, 10, 64)
   218  	if err != nil {
   219  		log.Fatal(err)
   220  	}
   221  
   222  	intermediate, err := signCert(caCert, caKey, publicKey(priv), commonName, serialNr, true)
   223  	if err != nil {
   224  		log.Fatal(err)
   225  	}
   226  	err = saveCert(intermediate, certPath)
   227  	if err != nil {
   228  		log.Fatal(err)
   229  	}
   230  }
   231  
   232  // CreateSignedCert creates a new certificate signed by the provided parent,
   233  // with the provided serial number, name and common name.
   234  // name is the file name to use. Common Name is the certificate common name.
   235  func CreateSignedCert(root, parent, serial, name, commonName string) {
   236  	log.Infof("Creating signed cert and key %v", commonName)
   237  
   238  	caKeyPath := path.Join(root, parent+"-key.pem")
   239  	caCertPath := path.Join(root, parent+"-cert.pem")
   240  	keyPath := path.Join(root, name+"-key.pem")
   241  	certPath := path.Join(root, name+"-cert.pem")
   242  
   243  	caKey, err := loadKey(caKeyPath)
   244  	if err != nil {
   245  		log.Fatal(err)
   246  	}
   247  	caCert, err := loadCert(caCertPath)
   248  	if err != nil {
   249  		log.Fatal(err)
   250  	}
   251  
   252  	priv, err := generateKey()
   253  	if err != nil {
   254  		log.Fatal(err)
   255  	}
   256  
   257  	err = saveKey(priv, keyPath)
   258  	if err != nil {
   259  		log.Fatal(err)
   260  	}
   261  
   262  	serialNr, err := strconv.ParseInt(serial, 10, 64)
   263  	if err != nil {
   264  		log.Fatal(err)
   265  	}
   266  
   267  	leaf, err := signCert(caCert, caKey, publicKey(priv), commonName, serialNr, false)
   268  	if err != nil {
   269  		log.Fatal(err)
   270  	}
   271  
   272  	err = saveCert(leaf, certPath)
   273  	if err != nil {
   274  		log.Fatal(err)
   275  	}
   276  }
   277  
   278  // CreateCRL creates a new empty certificate revocation list
   279  // for the provided parent
   280  func CreateCRL(root, parent string) {
   281  	log.Infof("Creating CRL for root CA in %v", root)
   282  	caKeyPath := path.Join(root, parent+"-key.pem")
   283  	caCertPath := path.Join(root, parent+"-cert.pem")
   284  	crlPath := path.Join(root, parent+"-crl.pem")
   285  
   286  	caKey, err := loadKey(caKeyPath)
   287  	if err != nil {
   288  		log.Fatal(err)
   289  	}
   290  	caCert, err := loadCert(caCertPath)
   291  	if err != nil {
   292  		log.Fatal(err)
   293  	}
   294  
   295  	crlList, err := x509.CreateRevocationList(rand.Reader, &x509.RevocationList{
   296  		RevokedCertificates: nil,
   297  		Number:              big.NewInt(1),
   298  	}, caCert, caKey.(crypto.Signer))
   299  	if err != nil {
   300  		log.Fatal(err)
   301  	}
   302  
   303  	out := &bytes.Buffer{}
   304  	err = pem.Encode(out, &pem.Block{Type: "X509 CRL", Bytes: crlList})
   305  	if err != nil {
   306  		log.Fatal(err)
   307  	}
   308  
   309  	err = os.WriteFile(crlPath, out.Bytes(), permissions)
   310  	if err != nil {
   311  		log.Fatal(err)
   312  	}
   313  }
   314  
   315  // RevokeCertAndRegenerateCRL revokes a provided certificate under the
   316  // provided parent CA and regenerates the CRL file for that parent
   317  func RevokeCertAndRegenerateCRL(root, parent, name string) {
   318  	log.Infof("Revoking certificate %s", name)
   319  	caKeyPath := path.Join(root, parent+"-key.pem")
   320  	caCertPath := path.Join(root, parent+"-cert.pem")
   321  	crlPath := path.Join(root, parent+"-crl.pem")
   322  	certPath := path.Join(root, name+"-cert.pem")
   323  
   324  	certificate, err := loadCert(certPath)
   325  	if err != nil {
   326  		log.Fatal(err)
   327  	}
   328  
   329  	// Check if CRL already exists. If it doesn't,
   330  	// create an empty CRL to start with.
   331  	_, err = os.Stat(crlPath)
   332  	if errors.Is(err, os.ErrNotExist) {
   333  		CreateCRL(root, parent)
   334  	}
   335  
   336  	data, err := os.ReadFile(crlPath)
   337  	if err != nil {
   338  		log.Fatal(err)
   339  	}
   340  
   341  	block, _ := pem.Decode(data)
   342  	if block == nil || block.Type != "X509 CRL" {
   343  		log.Fatal("failed to parse CRL PEM")
   344  	}
   345  
   346  	crlList, err := x509.ParseRevocationList(block.Bytes)
   347  	if err != nil {
   348  		log.Fatal(err)
   349  	}
   350  
   351  	revoked := crlList.RevokedCertificates
   352  	revoked = append(revoked, pkix.RevokedCertificate{
   353  		SerialNumber:   certificate.SerialNumber,
   354  		RevocationTime: time.Now(),
   355  	})
   356  
   357  	caKey, err := loadKey(caKeyPath)
   358  	if err != nil {
   359  		log.Fatal(err)
   360  	}
   361  	caCert, err := loadCert(caCertPath)
   362  	if err != nil {
   363  		log.Fatal(err)
   364  	}
   365  
   366  	var crlNumber big.Int
   367  	newCrl, err := x509.CreateRevocationList(rand.Reader, &x509.RevocationList{
   368  		RevokedCertificates: revoked,
   369  		Number:              crlNumber.Add(crlList.Number, big.NewInt(1)),
   370  	}, caCert, caKey.(crypto.Signer))
   371  	if err != nil {
   372  		log.Fatal(err)
   373  	}
   374  
   375  	out := &bytes.Buffer{}
   376  	err = pem.Encode(out, &pem.Block{Type: "X509 CRL", Bytes: newCrl})
   377  	if err != nil {
   378  		log.Fatal(err)
   379  	}
   380  
   381  	err = os.WriteFile(crlPath, out.Bytes(), permissions)
   382  	if err != nil {
   383  		log.Fatal(err)
   384  	}
   385  }
   386  
// ClientServerKeyPairs is used in tests. It describes the PKI produced by
// CreateClientServerCertPairs. Fields ending in Cert, Key, CA or CRL hold
// file paths to PEM files. Fields ending in Name hold certificate common
// names.
type ClientServerKeyPairs struct {
	ServerCert        string // valid server certificate, issued by the server CA
	ServerKey         string // private key of ServerCert
	ServerCA          string // server intermediate CA certificate (signed by the root CA)
	ServerName        string // common name of ServerCert
	ServerCRL         string // server CA's CRL; lists RevokedServerCert
	RevokedServerCert string // server certificate that has been revoked in ServerCRL
	RevokedServerKey  string // private key of RevokedServerCert
	RevokedServerName string // common name of RevokedServerCert
	ClientCert        string // valid client certificate, issued by the client CA
	ClientKey         string // private key of ClientCert
	ClientCA          string // client intermediate CA certificate (signed by the root CA)
	ClientCRL         string // client CA's CRL; lists RevokedClientCert
	RevokedClientCert string // client certificate that has been revoked in ClientCRL
	RevokedClientKey  string // private key of RevokedClientCert
	RevokedClientName string // common name of RevokedClientCert
	CombinedCRL       string // file holding ServerCRL's contents followed by ClientCRL's
}
   406  
   407  var serialCounter = 0
   408  
   409  // CreateClientServerCertPairs creates certificate pairs for use in test
   410  func CreateClientServerCertPairs(root string) ClientServerKeyPairs {
   411  	// Create the certs and configs.
   412  	CreateCA(root)
   413  
   414  	serverCASerial := fmt.Sprintf("%03d", serialCounter*2+1)
   415  	serverSerial := fmt.Sprintf("%03d", serialCounter*2+3)
   416  	revokedServerSerial := fmt.Sprintf("%03d", serialCounter*2+5)
   417  	clientCASerial := fmt.Sprintf("%03d", serialCounter*2+2)
   418  	clientCertSerial := fmt.Sprintf("%03d", serialCounter*2+4)
   419  	revokedClientSerial := fmt.Sprintf("%03d", serialCounter*2+6)
   420  
   421  	serialCounter = serialCounter + 3
   422  
   423  	serverCAName := fmt.Sprintf("servers-ca-%s", serverCASerial)
   424  	serverCACommonName := fmt.Sprintf("Servers %s CA", serverCASerial)
   425  	serverCertName := fmt.Sprintf("server-instance-%s", serverSerial)
   426  	serverCertCommonName := fmt.Sprintf("server%s.example.com", serverSerial)
   427  	revokedServerCertName := fmt.Sprintf("server-instance-%s", revokedServerSerial)
   428  	revokedServerCertCommonName := fmt.Sprintf("server%s.example.com", revokedServerSerial)
   429  
   430  	clientCAName := fmt.Sprintf("clients-ca-%s", clientCASerial)
   431  	clientCACommonName := fmt.Sprintf("Clients %s CA", clientCASerial)
   432  	clientCertName := fmt.Sprintf("client-instance-%s", clientCertSerial)
   433  	clientCertCommonName := fmt.Sprintf("client%s.example.com", clientCertSerial)
   434  	revokedClientCertName := fmt.Sprintf("client-instance-%s", revokedClientSerial)
   435  	revokedClientCertCommonName := fmt.Sprintf("client%s.example.com", revokedClientSerial)
   436  
   437  	CreateIntermediateCA(root, CA, serverCASerial, serverCAName, serverCACommonName)
   438  	CreateSignedCert(root, serverCAName, serverSerial, serverCertName, serverCertCommonName)
   439  	CreateSignedCert(root, serverCAName, revokedServerSerial, revokedServerCertName, revokedServerCertCommonName)
   440  	RevokeCertAndRegenerateCRL(root, serverCAName, revokedServerCertName)
   441  
   442  	CreateIntermediateCA(root, CA, clientCASerial, clientCAName, clientCACommonName)
   443  	CreateSignedCert(root, clientCAName, clientCertSerial, clientCertName, clientCertCommonName)
   444  	CreateSignedCert(root, clientCAName, revokedClientSerial, revokedClientCertName, revokedClientCertCommonName)
   445  	RevokeCertAndRegenerateCRL(root, clientCAName, revokedClientCertName)
   446  
   447  	serverCRLPath := path.Join(root, fmt.Sprintf("%s-crl.pem", serverCAName))
   448  	clientCRLPath := path.Join(root, fmt.Sprintf("%s-crl.pem", clientCAName))
   449  	combinedCRLPath := path.Join(root, fmt.Sprintf("%s-%s-combined-crl.pem", serverCAName, clientCAName))
   450  
   451  	serverCRLBytes, err := os.ReadFile(serverCRLPath)
   452  	if err != nil {
   453  		log.Fatalf("Could not read server CRL file")
   454  	}
   455  
   456  	clientCRLBytes, err := os.ReadFile(clientCRLPath)
   457  	if err != nil {
   458  		log.Fatalf("Could not read client CRL file")
   459  	}
   460  
   461  	err = os.WriteFile(combinedCRLPath, append(serverCRLBytes, clientCRLBytes...), permissions)
   462  	if err != nil {
   463  		log.Fatalf("Could not write combined CRL file")
   464  	}
   465  
   466  	return ClientServerKeyPairs{
   467  		ServerCert:        path.Join(root, fmt.Sprintf("%s-cert.pem", serverCertName)),
   468  		ServerKey:         path.Join(root, fmt.Sprintf("%s-key.pem", serverCertName)),
   469  		ServerCA:          path.Join(root, fmt.Sprintf("%s-cert.pem", serverCAName)),
   470  		ServerCRL:         serverCRLPath,
   471  		RevokedServerCert: path.Join(root, fmt.Sprintf("%s-cert.pem", revokedServerCertName)),
   472  		RevokedServerKey:  path.Join(root, fmt.Sprintf("%s-key.pem", revokedServerCertName)),
   473  		ClientCert:        path.Join(root, fmt.Sprintf("%s-cert.pem", clientCertName)),
   474  		ClientKey:         path.Join(root, fmt.Sprintf("%s-key.pem", clientCertName)),
   475  		ClientCA:          path.Join(root, fmt.Sprintf("%s-cert.pem", clientCAName)),
   476  		ClientCRL:         clientCRLPath,
   477  		RevokedClientCert: path.Join(root, fmt.Sprintf("%s-cert.pem", revokedClientCertName)),
   478  		RevokedClientKey:  path.Join(root, fmt.Sprintf("%s-key.pem", revokedClientCertName)),
   479  		CombinedCRL:       combinedCRLPath,
   480  		ServerName:        serverCertCommonName,
   481  		RevokedServerName: revokedServerCertCommonName,
   482  		RevokedClientName: revokedClientCertCommonName,
   483  	}
   484  }