github.com/quickfeed/quickfeed@v0.0.0-20240507093252-ed8ca812a09c/internal/cert/gen_certs.go (about)

     1  package cert
     2  
     3  import (
     4  	"crypto/ecdsa"
     5  	"crypto/elliptic"
     6  	"crypto/rand"
     7  	"crypto/rsa"
     8  	"crypto/x509"
     9  	"crypto/x509/pkix"
    10  	"encoding/pem"
    11  	"errors"
    12  	"fmt"
    13  	"math/big"
    14  	"net"
    15  	"os"
    16  	"path/filepath"
    17  	"strings"
    18  	"time"
    19  )
    20  
// Options configures GenerateSelfSignedCert.
//
// Hosts is mandatory. ValidFrom, ValidFor and KeyType fall back to the
// defaults noted on each field when left at their zero value.
type Options struct {
	KeyFile   string        // path the PEM-encoded server private key is written to
	CertFile  string        // path the PEM fullchain (server certificate, then CA certificate) is written to
	Hosts     string        // comma-separated hostnames and IPs to generate a certificate for; must not be empty
	ValidFrom time.Time     // start of the validity period; the zero value means time.Now()
	ValidFor  time.Duration // length of the validity period; zero means 365 days, a negative value is rejected
	KeyType   string        // "rsa" selects RSA 2048 bits; any other value (including "") selects ECDSA curve P256
}
    30  
    31  // GenerateSelfSignedCert generates a self-signed X.509 certificate for testing purposes.
    32  // It supports ECDSA curve P256 or RSA 2048 bits to generate the key.
    33  // based on: https://golang.org/src/crypto/tls/generate_cert.go
    34  func GenerateSelfSignedCert(opts Options) error {
    35  	if opts.Hosts == "" {
    36  		return errors.New("at least one hostname must be specified")
    37  	}
    38  	path := filepath.Dir(opts.KeyFile)
    39  	if err := os.MkdirAll(path, 0o700); err != nil {
    40  		return err
    41  	}
    42  	caKey, serverKey, err := generateKeys(opts)
    43  	if err != nil {
    44  		return err
    45  	}
    46  	notBefore, notAfter, err := certPeriod(opts)
    47  	if err != nil {
    48  		return err
    49  	}
    50  
    51  	caTemplate, err := caCertificateTemplate(opts.Hosts, notBefore, notAfter)
    52  	if err != nil {
    53  		return err
    54  	}
    55  	caCert, caCertBytes, err := makeCertificate(caTemplate, caTemplate, publicKey(caKey), caKey)
    56  	if err != nil {
    57  		return err
    58  	}
    59  
    60  	serverTemplate, err := serverCertificateTemplate(serverKey, opts.Hosts, notBefore, notAfter)
    61  	if err != nil {
    62  		return err
    63  	}
    64  	_, serverCertBytes, err := makeCertificate(serverTemplate, caCert, publicKey(serverKey), caKey)
    65  	if err != nil {
    66  		return err
    67  	}
    68  
    69  	serverKeyBytes, err := x509.MarshalPKCS8PrivateKey(serverKey)
    70  	if err != nil {
    71  		return fmt.Errorf("unable to marshal server private key: %w", err)
    72  	}
    73  
    74  	// save server private key
    75  	if err = savePEM(opts.KeyFile, []*pem.Block{
    76  		{Type: "PRIVATE KEY", Bytes: serverKeyBytes},
    77  	}); err != nil {
    78  		return err
    79  	}
    80  
    81  	// save fullchain (server certificate and CA certificate)
    82  	return savePEM(opts.CertFile, []*pem.Block{
    83  		{Type: "CERTIFICATE", Bytes: serverCertBytes},
    84  		{Type: "CERTIFICATE", Bytes: caCertBytes},
    85  	})
    86  }
    87  
    88  func generateKeys(opts Options) (caKey, serverKey any, err error) {
    89  	switch opts.KeyType {
    90  	case "rsa":
    91  		caKey, err = rsa.GenerateKey(rand.Reader, 2048)
    92  		if err != nil {
    93  			return
    94  		}
    95  		serverKey, err = rsa.GenerateKey(rand.Reader, 2048)
    96  	default:
    97  		caKey, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
    98  		if err != nil {
    99  			return
   100  		}
   101  		serverKey, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
   102  	}
   103  	return
   104  }
   105  
   106  func certPeriod(opts Options) (notBefore time.Time, notAfter time.Time, err error) {
   107  	if opts.ValidFrom.IsZero() {
   108  		notBefore = time.Now()
   109  	} else {
   110  		notBefore = opts.ValidFrom
   111  	}
   112  
   113  	if opts.ValidFor == 0 {
   114  		notAfter = notBefore.Add(365 * 24 * time.Hour)
   115  	} else {
   116  		notAfter = notBefore.Add(opts.ValidFor)
   117  	}
   118  
   119  	if notBefore.After(notAfter) {
   120  		return notBefore, notAfter, errors.New("wrong certificate validity")
   121  	}
   122  	return notBefore, notAfter, nil
   123  }
   124  
   125  func serverCertificateTemplate(privKey any, hostList string, notBefore time.Time, notAfter time.Time) (*x509.Certificate, error) {
   126  	serialNumber, err := serialNumber()
   127  	if err != nil {
   128  		return nil, err
   129  	}
   130  	// https://go-review.googlesource.com/c/go/+/214337/
   131  	// If is RSA set KeyEncipherment KeyUsage bits.
   132  	keyUsage := x509.KeyUsageDigitalSignature
   133  	if _, isRSA := privKey.(*rsa.PrivateKey); isRSA {
   134  		keyUsage |= x509.KeyUsageKeyEncipherment
   135  	}
   136  	template := &x509.Certificate{
   137  		SerialNumber:          serialNumber,
   138  		NotBefore:             notBefore,
   139  		NotAfter:              notAfter,
   140  		KeyUsage:              keyUsage,
   141  		IsCA:                  false,
   142  		BasicConstraintsValid: true,
   143  		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
   144  	}
   145  	setHosts(template, hostList)
   146  	return template, err
   147  }
   148  
   149  func caCertificateTemplate(hostList string, notBefore time.Time, notAfter time.Time) (*x509.Certificate, error) {
   150  	serialNumber, err := serialNumber()
   151  	if err != nil {
   152  		return nil, err
   153  	}
   154  	caSubject := &pkix.Name{
   155  		Country:      []string{"NO"},
   156  		Organization: []string{"QuickFeed Corp."},
   157  		CommonName:   "127.0.0.1",
   158  	}
   159  	template := &x509.Certificate{
   160  		SerialNumber:          serialNumber,
   161  		Subject:               *caSubject,
   162  		NotBefore:             notBefore,
   163  		NotAfter:              notAfter,
   164  		KeyUsage:              x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
   165  		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
   166  		BasicConstraintsValid: true,
   167  		IsCA:                  true,
   168  		MaxPathLenZero:        true,
   169  	}
   170  	setHosts(template, hostList)
   171  	return template, nil
   172  }
   173  
   174  func serialNumber() (*big.Int, error) {
   175  	serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
   176  	serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
   177  	if err != nil {
   178  		return nil, fmt.Errorf("serial number generation failed: %w", err)
   179  	}
   180  	return serialNumber, nil
   181  }
   182  
   183  func setHosts(template *x509.Certificate, hostList string) {
   184  	for _, host := range strings.Split(hostList, ",") {
   185  		if ip := net.ParseIP(host); ip != nil {
   186  			template.IPAddresses = append(template.IPAddresses, ip)
   187  		} else {
   188  			template.DNSNames = append(template.DNSNames, host)
   189  		}
   190  	}
   191  }
   192  
   193  func makeCertificate(template, parent *x509.Certificate, publicKey any, privateKey any) (*x509.Certificate, []byte, error) {
   194  	derCertBytes, err := x509.CreateCertificate(rand.Reader, template, parent, publicKey, privateKey)
   195  	if err != nil {
   196  		return nil, nil, fmt.Errorf("failed to create certificate: %w", err)
   197  	}
   198  	cert, err := x509.ParseCertificate(derCertBytes)
   199  	if err != nil {
   200  		return nil, nil, fmt.Errorf("failed to parse certificate: %w", err)
   201  	}
   202  	return cert, derCertBytes, nil
   203  }
   204  
   205  const defaultFileFlags = os.O_WRONLY | os.O_CREATE | os.O_TRUNC
   206  
   207  func savePEM(filename string, block []*pem.Block) error {
   208  	out, err := os.OpenFile(filename, defaultFileFlags, 0o600)
   209  	if err != nil {
   210  		return fmt.Errorf("failed to open %s for writing: %w", filename, err)
   211  	}
   212  
   213  	for _, b := range block {
   214  		if err := pem.Encode(out, b); err != nil {
   215  			return fmt.Errorf("failed to write data to %s: %w", filename, err)
   216  		}
   217  	}
   218  
   219  	if err := out.Close(); err != nil {
   220  		return fmt.Errorf("error closing %s: %w", filename, err)
   221  	}
   222  	return nil
   223  }
   224  
   225  func publicKey(priv any) any {
   226  	switch k := priv.(type) {
   227  	case *rsa.PrivateKey:
   228  		return &k.PublicKey
   229  	case *ecdsa.PrivateKey:
   230  		return &k.PublicKey
   231  	default:
   232  		return nil
   233  	}
   234  }