github.com/blend/go-sdk@v1.20220411.3/certutil/cert_manager.go (about) 1 /* 2 3 Copyright (c) 2022 - Present. Blend Labs, Inc. All rights reserved 4 Use of this source code is governed by a MIT license that can be found in the LICENSE file. 5 6 */ 7 8 package certutil 9 10 import ( 11 "crypto/tls" 12 "crypto/x509" 13 "sync" 14 15 "github.com/blend/go-sdk/ex" 16 ) 17 18 // NewCertManagerWithKeyPairs returns a new cert pool from key pairs. 19 func NewCertManagerWithKeyPairs(server KeyPair, certificateAuthorities []KeyPair, clients ...KeyPair) (*CertManager, error) { 20 serverCert, err := server.CertBytes() 21 if err != nil { 22 return nil, err 23 } 24 serverKey, err := server.KeyBytes() 25 if err != nil { 26 return nil, err 27 } 28 29 serverCertificate, err := tls.X509KeyPair(serverCert, serverKey) 30 if err != nil { 31 return nil, err 32 } 33 caCertPool, err := ExtendSystemCertPool(certificateAuthorities...) 34 if err != nil { 35 return nil, err 36 } 37 38 clientCerts := map[string][]byte{} 39 for _, client := range clients { 40 certPEM, err := client.CertBytes() 41 if err != nil { 42 return nil, err 43 } 44 commonNames, err := CommonNamesForCertPEM(certPEM) 45 if err != nil { 46 return nil, err 47 } 48 if len(commonNames) == 0 { 49 return nil, ex.New(ErrInvalidCertPEM) 50 } 51 clientCerts[commonNames[0]] = certPEM 52 } 53 54 cm := NewCertManager(OptCertManagerServerCerts(serverCertificate), OptCertManagerRootCAs(caCertPool)) 55 return cm, cm.UpdateClientCerts(clientCerts) 56 } 57 58 // NewCertManager returns a new cert manager. 59 func NewCertManager(options ...CertManagerOption) *CertManager { 60 certManager := &CertManager{ 61 TLSConfig: &tls.Config{ 62 ClientAuth: tls.RequireAndVerifyClientCert, 63 }, 64 ClientCerts: map[string][]byte{}, 65 } 66 certManager.TLSConfig.GetConfigForClient = certManager.GetConfigForClient 67 68 for _, option := range options { 69 option(certManager) 70 } 71 return certManager 72 } 73 74 // CertManagerOption is an option for a cert manager. 75 type CertManagerOption func(*CertManager) 76 77 // OptCertManagerRootCAs sets a field on the cert manager. 78 func OptCertManagerRootCAs(pool *x509.CertPool) CertManagerOption { 79 return func(cm *CertManager) { cm.TLSConfig.RootCAs = pool } 80 } 81 82 // OptCertManagerServerCerts sets a field on the cert manager. 83 func OptCertManagerServerCerts(server ...tls.Certificate) CertManagerOption { 84 return func(cm *CertManager) { cm.TLSConfig.Certificates = server } 85 } 86 87 // OptCertManagerClientCerts sets a field on the cert manager. 88 func OptCertManagerClientCerts(client *x509.CertPool) CertManagerOption { 89 return func(cm *CertManager) { cm.TLSConfig.ClientCAs = client } 90 } 91 92 // CertManager is a pool of client certs. 93 type CertManager struct { 94 sync.RWMutex 95 TLSConfig *tls.Config 96 ClientCerts map[string][]byte 97 } 98 99 // ClientCertUIDs returns all the client cert uids. 100 func (cm *CertManager) ClientCertUIDs() (output []string) { 101 for uid := range cm.ClientCerts { 102 output = append(output, uid) 103 } 104 return 105 } 106 107 // HasClientCert returns if the manager has a client cert. 108 func (cm *CertManager) HasClientCert(uid string) (has bool) { 109 cm.RLock() 110 _, has = cm.ClientCerts[uid] 111 cm.RUnlock() 112 return 113 } 114 115 // AddClientCert adds a client cert to the bunde and refreshes the bundle. 116 func (cm *CertManager) AddClientCert(clientCert []byte) error { 117 cm.Lock() 118 defer cm.Unlock() 119 120 commonNames, err := ParseCertPEM(clientCert) 121 if err != nil { 122 return err 123 } 124 if len(commonNames) == 0 { 125 return ex.New(ErrInvalidCertPEM) 126 } 127 cm.ClientCerts[commonNames[0].Subject.CommonName] = clientCert 128 return cm.RefreshClientCerts() 129 } 130 131 // RemoveClientCert removes a client cert by uid. 132 func (cm *CertManager) RemoveClientCert(uid string) error { 133 cm.Lock() 134 defer cm.Unlock() 135 delete(cm.ClientCerts, uid) 136 return cm.RefreshClientCerts() 137 } 138 139 // UpdateClientCerts sets the client cert bundle fully. 140 func (cm *CertManager) UpdateClientCerts(clientCerts map[string][]byte) error { 141 cm.Lock() 142 defer cm.Unlock() 143 cm.ClientCerts = clientCerts 144 return cm.RefreshClientCerts() 145 } 146 147 // RefreshClientCerts reloads the client cert bundle. 148 func (cm *CertManager) RefreshClientCerts() error { 149 pool := x509.NewCertPool() 150 for uid, cert := range cm.ClientCerts { 151 if ok := pool.AppendCertsFromPEM(cert); !ok { 152 return ex.New("invalid ca cert for client cert pool", ex.OptMessagef("cert uid: %s", uid)) 153 } 154 } 155 cm.TLSConfig.ClientCAs = pool 156 // cm.TLSConfig.BuildNameToCertificate() 157 return nil 158 } 159 160 // GetConfigForClient gets a tls config for a given client hello. 161 func (cm *CertManager) GetConfigForClient(sni *tls.ClientHelloInfo) (config *tls.Config, _ error) { 162 cm.RLock() 163 config = cm.TLSConfig.Clone() 164 cm.RUnlock() 165 return 166 }