github.com/blend/go-sdk@v1.20220411.3/certutil/cert_file_watcher.go (about) 1 /* 2 3 Copyright (c) 2022 - Present. Blend Labs, Inc. All rights reserved 4 Use of this source code is governed by a MIT license that can be found in the LICENSE file. 5 6 */ 7 8 package certutil 9 10 import ( 11 "crypto/tls" 12 "os" 13 "sync" 14 "time" 15 16 "github.com/blend/go-sdk/async" 17 "github.com/blend/go-sdk/ex" 18 ) 19 20 // Error constants. 21 const ( 22 ErrTLSPathsUnset ex.Class = "tls cert or key path unset; cannot continue" 23 ) 24 25 const ( 26 // DefaultCertficicateFileWatcherPollInterval is the default poll interval when re-reading certs 27 DefaultCertficicateFileWatcherPollInterval = 500 * time.Millisecond 28 ) 29 30 // NewCertFileWatcher creates a new CertReloader object with a reload delay 31 func NewCertFileWatcher(keyPair KeyPair, opts ...CertFileWatcherOption) (*CertFileWatcher, error) { 32 if keyPair.CertPath == "" || keyPair.KeyPath == "" { 33 return nil, ex.New(ErrTLSPathsUnset) 34 } 35 cw := &CertFileWatcher{ 36 latch: async.NewLatch(), 37 keyPair: keyPair, 38 } 39 for _, opt := range opts { 40 if err := opt(cw); err != nil { 41 return nil, err 42 } 43 } 44 cert, err := tls.LoadX509KeyPair(cw.keyPair.CertPath, cw.keyPair.KeyPath) 45 if err != nil { 46 return nil, err 47 } 48 cw.certificate = &cert 49 return cw, nil 50 } 51 52 // CertFileWatcherOption is an option for a cert watcher. 53 type CertFileWatcherOption func(*CertFileWatcher) error 54 55 // CertFileWatcherOnReloadAction is the on reload action for a cert file watcher. 56 type CertFileWatcherOnReloadAction func(*CertFileWatcher) error 57 58 // OptCertFileWatcherOnReload sets the on reload handler. 59 // If you need to capture *every* reload of the cert, including the initial one in the constructor 60 // you must use this option. 
func OptCertFileWatcherOnReload(handler CertFileWatcherOnReloadAction) CertFileWatcherOption {
	return func(w *CertFileWatcher) error {
		w.onReload = handler
		return nil
	}
}

// OptCertFileWatcherNotifyReload sets the notify reload channel.
//
// Reload sends on this channel after every reload attempt. The send blocks
// until a receiver takes the value or buffer space is available, so callers
// supplying a channel must drain it.
func OptCertFileWatcherNotifyReload(notifyReload chan struct{}) CertFileWatcherOption {
	return func(w *CertFileWatcher) error {
		w.notifyReload = notifyReload
		return nil
	}
}

// OptCertFileWatcherPollInterval sets the poll interval.
//
// A non-positive duration makes PollIntervalOrDefault fall back to
// DefaultCertficicateFileWatcherPollInterval.
func OptCertFileWatcherPollInterval(d time.Duration) CertFileWatcherOption {
	return func(w *CertFileWatcher) error {
		w.pollInterval = d
		return nil
	}
}

// CertFileWatcher reloads a cert key pair when there is a change, e.g. cert renewal.
type CertFileWatcher struct {
	latch         *async.Latch
	certificateMu sync.RWMutex // guards certificate
	certificate   *tls.Certificate
	keyPair       KeyPair
	pollInterval  time.Duration
	notifyReload  chan struct{}
	onReload      CertFileWatcherOnReloadAction
}

// CertPath returns the path of the certificate file being watched.
func (cw *CertFileWatcher) CertPath() string {
	return cw.keyPair.CertPath
}

// KeyPath returns the path of the private key file being watched.
func (cw *CertFileWatcher) KeyPath() string {
	return cw.keyPair.KeyPath
}

// PollIntervalOrDefault returns the configured poll interval, or
// DefaultCertficicateFileWatcherPollInterval when none (or a non-positive
// one) was set.
func (cw *CertFileWatcher) PollIntervalOrDefault() time.Duration {
	if cw.pollInterval <= 0 {
		return DefaultCertficicateFileWatcherPollInterval
	}
	return cw.pollInterval
}

// Reload forces the reload of the underlying certificate from disk.
//
// If a notify reload channel was supplied, Reload sends on it after every
// attempt, successful or not (the send blocks until received or buffered).
// If the load succeeded and an on reload handler is set, the handler is then
// called and its error is returned.
func (cw *CertFileWatcher) Reload() error {
	err := cw.swapInCertificate()

	// Signal after every attempt, whether or not the load succeeded.
	if cw.notifyReload != nil {
		cw.notifyReload <- struct{}{}
	}

	// The handler only runs after a successful load; its error, if any,
	// becomes the result.
	if err == nil && cw.onReload != nil {
		err = cw.onReload(cw)
	}
	return err
}

// swapInCertificate loads the key pair from disk and, only on success,
// replaces the current certificate while holding the write lock.
func (cw *CertFileWatcher) swapInCertificate() error {
	loaded, loadErr := tls.LoadX509KeyPair(cw.keyPair.CertPath, cw.keyPair.KeyPath)
	if loadErr != nil {
		return ex.New(loadErr)
	}
	cw.certificateMu.Lock()
	cw.certificate = &loaded
	cw.certificateMu.Unlock()
	return nil
}

// Certificate returns the current certificate. It takes the read lock, so it
// waits while Reload is swapping in a new certificate.
func (cw *CertFileWatcher) Certificate() *tls.Certificate {
	cw.certificateMu.RLock()
	current := cw.certificate
	cw.certificateMu.RUnlock()
	return current
}

// GetCertificate returns the current certificate in the shape expected by
// tls.Config.GetCertificate. The client hello is ignored and the error is
// always nil.
func (cw *CertFileWatcher) GetCertificate(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
	return cw.Certificate(), nil
}

// IsStarted reports whether the underlying latch is started.
func (cw *CertFileWatcher) IsStarted() bool {
	return cw.latch.IsStarted()
}

// IsStopped reports whether the underlying latch is stopped.
func (cw *CertFileWatcher) IsStopped() bool {
	return cw.latch.IsStopped()
}

// NotifyStarted returns the underlying latch's notify started channel.
func (cw *CertFileWatcher) NotifyStarted() <-chan struct{} {
	return cw.latch.NotifyStarted()
}

// NotifyStopped returns the underlying latch's notify stopped channel.
func (cw *CertFileWatcher) NotifyStopped() <-chan struct{} {
	return cw.latch.NotifyStopped()
}

// NotifyReload returns the notify reload channel, which receives a value
// after every Reload attempt.
//
// You must supply this channel as an option in the constructor
// (OptCertFileWatcherNotifyReload); otherwise it is nil, and receiving from
// a nil channel blocks forever.
164 func (cw *CertFileWatcher) NotifyReload() <-chan struct{} { 165 return cw.notifyReload 166 } 167 168 // Start watches the cert and triggers a reload on change 169 func (cw *CertFileWatcher) Start() error { 170 cw.latch.Starting() 171 172 certLastMod, keyLastMod, err := cw.keyPairLastModified() 173 if err != nil { 174 cw.latch.Stopped() 175 return err 176 } 177 178 ticker := time.NewTicker(cw.PollIntervalOrDefault()) 179 defer ticker.Stop() 180 181 cw.latch.Started() 182 var certMod, keyMod time.Time 183 for { 184 select { 185 case <-ticker.C: 186 certMod, keyMod, err = cw.keyPairLastModified() 187 if err != nil { 188 return err 189 } 190 // wait for both to update 191 if keyMod.After(keyLastMod) && certMod.After(certLastMod) { 192 if err = cw.Reload(); err != nil { 193 return err 194 } 195 keyLastMod = keyMod 196 certLastMod = certMod 197 } 198 case <-cw.latch.NotifyStopping(): 199 cw.latch.Stopped() 200 return nil 201 } 202 } 203 } 204 205 // Stop stops the watcher. 206 func (cw *CertFileWatcher) Stop() error { 207 if !cw.latch.CanStop() { 208 return async.ErrCannotStop 209 } 210 cw.latch.WaitStopped() 211 cw.latch.Reset() 212 return nil 213 } 214 215 func (cw *CertFileWatcher) keyPairLastModified() (cert time.Time, key time.Time, err error) { 216 var certStat, keyStat os.FileInfo 217 certStat, err = os.Stat(cw.keyPair.CertPath) 218 if err != nil { 219 return 220 } 221 keyStat, err = os.Stat(cw.keyPair.KeyPath) 222 if err != nil { 223 return 224 } 225 cert = certStat.ModTime() 226 key = keyStat.ModTime() 227 return 228 }