go.charczuk.com@v0.0.0-20240327042549-bc490516bd1a/sdk/certutil/cert_bundle.go (about)

     1  /*
     2  
     3  Copyright (c) 2023 - Present. Will Charczuk. All rights reserved.
     4  Use of this source code is governed by a MIT license that can be found in the LICENSE file at the root of the repository.
     5  
     6  */
     7  
     8  package certutil
     9  
    10  import (
    11  	"bytes"
    12  	"crypto/rsa"
    13  	"crypto/tls"
    14  	"crypto/x509"
    15  	"encoding/pem"
    16  	"io"
    17  	"os"
    18  
    19  	"go.charczuk.com/sdk/errutil"
    20  )
    21  
    22  // NewCertBundle returns a new cert bundle from a given key pair, which can denote the raw PEM encoded
    23  // contents of the public and private key portions of the cert, or paths to files.
    24  // The CertBundle itself is the parsed public key, private key, and individual certificates for the pair.
    25  func NewCertBundle(keyPair KeyPair) (*CertBundle, error) {
    26  	certPEM, err := keyPair.CertBytes()
    27  	if err != nil {
    28  		return nil, errutil.New(err)
    29  	}
    30  	if len(certPEM) == 0 {
    31  		return nil, errutil.New("empty cert contents")
    32  	}
    33  
    34  	keyPEM, err := keyPair.KeyBytes()
    35  	if err != nil {
    36  		return nil, errutil.New(err)
    37  	}
    38  	if len(keyPEM) == 0 {
    39  		return nil, errutil.New("empty key contents")
    40  	}
    41  
    42  	certData, err := tls.X509KeyPair(certPEM, keyPEM)
    43  	if err != nil {
    44  		return nil, errutil.New(err)
    45  	}
    46  	if len(certData.Certificate) == 0 {
    47  		return nil, errutil.New("no certificates")
    48  	}
    49  
    50  	var certs []x509.Certificate
    51  	var ders [][]byte
    52  	for _, certDataPortion := range certData.Certificate {
    53  		cert, err := x509.ParseCertificate(certDataPortion)
    54  		if err != nil {
    55  			return nil, errutil.New(err)
    56  		}
    57  
    58  		certs = append(certs, *cert)
    59  		ders = append(ders, cert.Raw)
    60  	}
    61  
    62  	var privateKey *rsa.PrivateKey
    63  	if typed, ok := certData.PrivateKey.(*rsa.PrivateKey); ok {
    64  		privateKey = typed
    65  	} else {
    66  		return nil, errutil.New("invalid private key type", errutil.OptMessagef("%T", certData.PrivateKey))
    67  	}
    68  
    69  	return &CertBundle{
    70  		PrivateKey:      privateKey,
    71  		PublicKey:       &privateKey.PublicKey,
    72  		Certificates:    certs,
    73  		CertificateDERs: ders,
    74  	}, nil
    75  }
    76  
// CertBundle is the packet of information for a certificate: an RSA key pair
// together with an ordered certificate chain and the DER encoding of each
// certificate in that chain.
type CertBundle struct {
	// PrivateKey is the RSA private key. When the bundle is built by
	// NewCertBundle it is the key that tls.X509KeyPair matched against the
	// first certificate.
	PrivateKey      *rsa.PrivateKey
	// PublicKey is the public half of the key pair; NewCertBundle sets it to
	// &PrivateKey.PublicKey.
	PublicKey       *rsa.PublicKey
	// Certificates holds the parsed certificates in chain order (the order
	// they appeared in the PEM input); WithParent appends a parent's chain
	// to the end.
	Certificates    []x509.Certificate
	// CertificateDERs holds the raw DER bytes of each certificate. As
	// maintained by NewCertBundle and WithParent it is index-aligned with
	// Certificates.
	CertificateDERs [][]byte
}
    84  
    85  // MustGenerateKeyPair returns a serialized version of the bundle as a key pair
    86  // and panics if there is an error.
    87  func (cb *CertBundle) MustGenerateKeyPair() KeyPair {
    88  	pair, err := cb.GenerateKeyPair()
    89  	if err != nil {
    90  		panic(err)
    91  	}
    92  	return pair
    93  }
    94  
    95  // GenerateKeyPair returns a serialized key pair for the cert bundle.
    96  func (cb *CertBundle) GenerateKeyPair() (output KeyPair, err error) {
    97  	private := bytes.NewBuffer(nil)
    98  	if err = cb.WriteKeyPem(private); err != nil {
    99  		return
   100  	}
   101  	public := bytes.NewBuffer(nil)
   102  	if err = cb.WriteCertPem(public); err != nil {
   103  		return
   104  	}
   105  	output = KeyPair{
   106  		Cert: public.String(),
   107  		Key:  private.String(),
   108  	}
   109  	return
   110  }
   111  
   112  // WithParent adds a parent certificate to the certificate chain.
   113  // It is used typically to add the certificate authority.
   114  func (cb *CertBundle) WithParent(parent *CertBundle) {
   115  	cb.Certificates = append(cb.Certificates, parent.Certificates...)
   116  	cb.CertificateDERs = append(cb.CertificateDERs, parent.CertificateDERs...)
   117  }
   118  
   119  // WriteCertPem writes the public key portion of the cert to a given writer.
   120  func (cb CertBundle) WriteCertPem(w io.Writer) error {
   121  	for _, der := range cb.CertificateDERs {
   122  		if err := pem.Encode(w, &pem.Block{Type: BlockTypeCertificate, Bytes: der}); err != nil {
   123  			return errutil.New(err)
   124  		}
   125  	}
   126  	return nil
   127  }
   128  
   129  // WriteCertPemPath writes the public key portion of the cert to a given path.
   130  func (cb CertBundle) WriteCertPemPath(path string) error {
   131  	w, err := os.Create(path)
   132  	if err != nil {
   133  		return err
   134  	}
   135  	return cb.WriteCertPem(w)
   136  }
   137  
   138  // WriteCertChainPem writes the public key portion of the cert to a given writer.
   139  func (cb CertBundle) WriteCertChainPem(w io.Writer) error {
   140  	if len(cb.CertificateDERs) < 2 {
   141  		return nil
   142  	}
   143  	for _, der := range cb.CertificateDERs[1:] {
   144  		if err := pem.Encode(w, &pem.Block{Type: BlockTypeCertificate, Bytes: der}); err != nil {
   145  			return errutil.New(err)
   146  		}
   147  	}
   148  	return nil
   149  }
   150  
   151  // WriteCertPartialPem writes the public key portion of the cert to a given writer.
   152  func (cb CertBundle) WriteCertPartialPem(w io.Writer) error {
   153  	if len(cb.CertificateDERs) == 0 {
   154  		return nil
   155  	}
   156  	return pem.Encode(w, &pem.Block{Type: BlockTypeCertificate, Bytes: cb.CertificateDERs[0]})
   157  }
   158  
   159  // CertPEM returns the cert portion of the certificate DERs as a byte array.
   160  func (cb CertBundle) CertPEM() ([]byte, error) {
   161  	buffer := new(bytes.Buffer)
   162  	if err := cb.WriteCertPem(buffer); err != nil {
   163  		return nil, err
   164  	}
   165  	return buffer.Bytes(), nil
   166  }
   167  
   168  // WriteKeyPem writes the certificate key as a pem to a given writer.
   169  func (cb CertBundle) WriteKeyPem(w io.Writer) error {
   170  	return pem.Encode(w, &pem.Block{Type: BlockTypeRSAPrivateKey, Bytes: x509.MarshalPKCS1PrivateKey(cb.PrivateKey)})
   171  }
   172  
   173  // WriteKeyPemPath writes the certificate key as a pem to a given path.
   174  func (cb CertBundle) WriteKeyPemPath(path string) error {
   175  	w, err := os.Create(path)
   176  	if err != nil {
   177  		return err
   178  	}
   179  	return cb.WriteKeyPem(w)
   180  }
   181  
   182  // KeyPEM returns the cert portion of the certificate DERs as a byte array.
   183  func (cb CertBundle) KeyPEM() ([]byte, error) {
   184  	buffer := new(bytes.Buffer)
   185  	if err := cb.WriteKeyPem(buffer); err != nil {
   186  		return nil, err
   187  	}
   188  	return buffer.Bytes(), nil
   189  }
   190  
   191  // CommonNames returns the cert bundle common name(s).
   192  func (cb CertBundle) CommonNames() ([]string, error) {
   193  	if len(cb.Certificates) == 0 {
   194  		return nil, errutil.New("no certificates returned")
   195  	}
   196  	var output []string
   197  	for _, cert := range cb.Certificates {
   198  		if cert.Subject.CommonName != "" {
   199  			output = append(output, cert.Subject.CommonName)
   200  		}
   201  	}
   202  	return output, nil
   203  }
   204  
   205  // CertPool returns the bundle as a cert pool.
   206  func (cb CertBundle) CertPool() (*x509.CertPool, error) {
   207  	systemPool, err := x509.SystemCertPool()
   208  	if err != nil {
   209  		return nil, errutil.New(err)
   210  	}
   211  	for index := range cb.Certificates {
   212  		systemPool.AddCert(&cb.Certificates[index])
   213  	}
   214  	return systemPool, nil
   215  }
   216  
   217  // TLSConfig returns a tls.Config for this bundle as a server certificate.
   218  func (cb CertBundle) TLSConfig() (*tls.Config, error) {
   219  	keyPair, err := cb.GenerateKeyPair()
   220  	if err != nil {
   221  		return nil, err
   222  	}
   223  
   224  	serverCert, err := keyPair.CertBytes()
   225  	if err != nil {
   226  		return nil, err
   227  	}
   228  	serverKey, err := keyPair.KeyBytes()
   229  	if err != nil {
   230  		return nil, err
   231  	}
   232  
   233  	serverCertificate, err := tls.X509KeyPair(serverCert, serverKey)
   234  	if err != nil {
   235  		return nil, err
   236  	}
   237  
   238  	certPool, err := cb.CertPool()
   239  	if err != nil {
   240  		return nil, err
   241  	}
   242  
   243  	config := new(tls.Config)
   244  	config.Certificates = []tls.Certificate{serverCertificate}
   245  	config.RootCAs = certPool
   246  	return config, nil
   247  }