github.com/blend/go-sdk@v1.20220411.3/certutil/cert_manager.go (about)

     1  /*
     2  
     3  Copyright (c) 2022 - Present. Blend Labs, Inc. All rights reserved
     4  Use of this source code is governed by a MIT license that can be found in the LICENSE file.
     5  
     6  */
     7  
     8  package certutil
     9  
    10  import (
    11  	"crypto/tls"
    12  	"crypto/x509"
    13  	"sync"
    14  
    15  	"github.com/blend/go-sdk/ex"
    16  )
    17  
    18  // NewCertManagerWithKeyPairs returns a new cert pool from key pairs.
    19  func NewCertManagerWithKeyPairs(server KeyPair, certificateAuthorities []KeyPair, clients ...KeyPair) (*CertManager, error) {
    20  	serverCert, err := server.CertBytes()
    21  	if err != nil {
    22  		return nil, err
    23  	}
    24  	serverKey, err := server.KeyBytes()
    25  	if err != nil {
    26  		return nil, err
    27  	}
    28  
    29  	serverCertificate, err := tls.X509KeyPair(serverCert, serverKey)
    30  	if err != nil {
    31  		return nil, err
    32  	}
    33  	caCertPool, err := ExtendSystemCertPool(certificateAuthorities...)
    34  	if err != nil {
    35  		return nil, err
    36  	}
    37  
    38  	clientCerts := map[string][]byte{}
    39  	for _, client := range clients {
    40  		certPEM, err := client.CertBytes()
    41  		if err != nil {
    42  			return nil, err
    43  		}
    44  		commonNames, err := CommonNamesForCertPEM(certPEM)
    45  		if err != nil {
    46  			return nil, err
    47  		}
    48  		if len(commonNames) == 0 {
    49  			return nil, ex.New(ErrInvalidCertPEM)
    50  		}
    51  		clientCerts[commonNames[0]] = certPEM
    52  	}
    53  
    54  	cm := NewCertManager(OptCertManagerServerCerts(serverCertificate), OptCertManagerRootCAs(caCertPool))
    55  	return cm, cm.UpdateClientCerts(clientCerts)
    56  }
    57  
    58  // NewCertManager returns a new cert manager.
    59  func NewCertManager(options ...CertManagerOption) *CertManager {
    60  	certManager := &CertManager{
    61  		TLSConfig: &tls.Config{
    62  			ClientAuth: tls.RequireAndVerifyClientCert,
    63  		},
    64  		ClientCerts: map[string][]byte{},
    65  	}
    66  	certManager.TLSConfig.GetConfigForClient = certManager.GetConfigForClient
    67  
    68  	for _, option := range options {
    69  		option(certManager)
    70  	}
    71  	return certManager
    72  }
    73  
    74  // CertManagerOption is an option for a cert manager.
    75  type CertManagerOption func(*CertManager)
    76  
    77  // OptCertManagerRootCAs sets a field on the cert manager.
    78  func OptCertManagerRootCAs(pool *x509.CertPool) CertManagerOption {
    79  	return func(cm *CertManager) { cm.TLSConfig.RootCAs = pool }
    80  }
    81  
    82  // OptCertManagerServerCerts sets a field on the cert manager.
    83  func OptCertManagerServerCerts(server ...tls.Certificate) CertManagerOption {
    84  	return func(cm *CertManager) { cm.TLSConfig.Certificates = server }
    85  }
    86  
    87  // OptCertManagerClientCerts sets a field on the cert manager.
    88  func OptCertManagerClientCerts(client *x509.CertPool) CertManagerOption {
    89  	return func(cm *CertManager) { cm.TLSConfig.ClientCAs = client }
    90  }
    91  
    // CertManager holds a server TLS config together with the set of client
// certificates whose pool it installs as ClientCAs.
type CertManager struct {
	// RWMutex is embedded, so Lock/Unlock/RLock/RUnlock are promoted onto
	// CertManager. The methods in this file use it to guard TLSConfig and
	// ClientCerts.
	sync.RWMutex
	// TLSConfig is the base TLS configuration for the manager.
	TLSConfig *tls.Config
	// ClientCerts maps a client uid to its PEM-encoded certificate.
	ClientCerts map[string][]byte
}
    98  
    99  // ClientCertUIDs returns all the client cert uids.
   100  func (cm *CertManager) ClientCertUIDs() (output []string) {
   101  	for uid := range cm.ClientCerts {
   102  		output = append(output, uid)
   103  	}
   104  	return
   105  }
   106  
   107  // HasClientCert returns if the manager has a client cert.
   108  func (cm *CertManager) HasClientCert(uid string) (has bool) {
   109  	cm.RLock()
   110  	_, has = cm.ClientCerts[uid]
   111  	cm.RUnlock()
   112  	return
   113  }
   114  
   115  // AddClientCert adds a client cert to the bunde and refreshes the bundle.
   116  func (cm *CertManager) AddClientCert(clientCert []byte) error {
   117  	cm.Lock()
   118  	defer cm.Unlock()
   119  
   120  	commonNames, err := ParseCertPEM(clientCert)
   121  	if err != nil {
   122  		return err
   123  	}
   124  	if len(commonNames) == 0 {
   125  		return ex.New(ErrInvalidCertPEM)
   126  	}
   127  	cm.ClientCerts[commonNames[0].Subject.CommonName] = clientCert
   128  	return cm.RefreshClientCerts()
   129  }
   130  
   131  // RemoveClientCert removes a client cert by uid.
   132  func (cm *CertManager) RemoveClientCert(uid string) error {
   133  	cm.Lock()
   134  	defer cm.Unlock()
   135  	delete(cm.ClientCerts, uid)
   136  	return cm.RefreshClientCerts()
   137  }
   138  
   139  // UpdateClientCerts sets the client cert bundle fully.
   140  func (cm *CertManager) UpdateClientCerts(clientCerts map[string][]byte) error {
   141  	cm.Lock()
   142  	defer cm.Unlock()
   143  	cm.ClientCerts = clientCerts
   144  	return cm.RefreshClientCerts()
   145  }
   146  
   147  // RefreshClientCerts reloads the client cert bundle.
   148  func (cm *CertManager) RefreshClientCerts() error {
   149  	pool := x509.NewCertPool()
   150  	for uid, cert := range cm.ClientCerts {
   151  		if ok := pool.AppendCertsFromPEM(cert); !ok {
   152  			return ex.New("invalid ca cert for client cert pool", ex.OptMessagef("cert uid: %s", uid))
   153  		}
   154  	}
   155  	cm.TLSConfig.ClientCAs = pool
   156  	// cm.TLSConfig.BuildNameToCertificate()
   157  	return nil
   158  }
   159  
   160  // GetConfigForClient gets a tls config for a given client hello.
   161  func (cm *CertManager) GetConfigForClient(sni *tls.ClientHelloInfo) (config *tls.Config, _ error) {
   162  	cm.RLock()
   163  	config = cm.TLSConfig.Clone()
   164  	cm.RUnlock()
   165  	return
   166  }