github.com/blend/go-sdk@v1.20220411.3/certutil/cert_file_watcher.go (about)

     1  /*
     2  
     3  Copyright (c) 2022 - Present. Blend Labs, Inc. All rights reserved
     4  Use of this source code is governed by a MIT license that can be found in the LICENSE file.
     5  
     6  */
     7  
     8  package certutil
     9  
    10  import (
    11  	"crypto/tls"
    12  	"os"
    13  	"sync"
    14  	"time"
    15  
    16  	"github.com/blend/go-sdk/async"
    17  	"github.com/blend/go-sdk/ex"
    18  )
    19  
    20  // Error constants.
    21  const (
    22  	ErrTLSPathsUnset ex.Class = "tls cert or key path unset; cannot continue"
    23  )
    24  
    25  const (
    26  	// DefaultCertficicateFileWatcherPollInterval is the default poll interval when re-reading certs
    27  	DefaultCertficicateFileWatcherPollInterval = 500 * time.Millisecond
    28  )
    29  
    30  // NewCertFileWatcher creates a new CertReloader object with a reload delay
    31  func NewCertFileWatcher(keyPair KeyPair, opts ...CertFileWatcherOption) (*CertFileWatcher, error) {
    32  	if keyPair.CertPath == "" || keyPair.KeyPath == "" {
    33  		return nil, ex.New(ErrTLSPathsUnset)
    34  	}
    35  	cw := &CertFileWatcher{
    36  		latch:   async.NewLatch(),
    37  		keyPair: keyPair,
    38  	}
    39  	for _, opt := range opts {
    40  		if err := opt(cw); err != nil {
    41  			return nil, err
    42  		}
    43  	}
    44  	cert, err := tls.LoadX509KeyPair(cw.keyPair.CertPath, cw.keyPair.KeyPath)
    45  	if err != nil {
    46  		return nil, err
    47  	}
    48  	cw.certificate = &cert
    49  	return cw, nil
    50  }
    51  
    52  // CertFileWatcherOption is an option for a cert watcher.
    53  type CertFileWatcherOption func(*CertFileWatcher) error
    54  
    55  // CertFileWatcherOnReloadAction is the on reload action for a cert file watcher.
    56  type CertFileWatcherOnReloadAction func(*CertFileWatcher) error
    57  
    58  // OptCertFileWatcherOnReload sets the on reload handler.
    59  // If you need to capture *every* reload of the cert, including the initial one in the constructor
    60  // you must use this option.
    61  func OptCertFileWatcherOnReload(handler CertFileWatcherOnReloadAction) CertFileWatcherOption {
    62  	return func(cfw *CertFileWatcher) error {
    63  		cfw.onReload = handler
    64  		return nil
    65  	}
    66  }
    67  
    68  // OptCertFileWatcherNotifyReload sets the notify reload channel.
    69  func OptCertFileWatcherNotifyReload(notifyReload chan struct{}) CertFileWatcherOption {
    70  	return func(cfw *CertFileWatcher) error {
    71  		cfw.notifyReload = notifyReload
    72  		return nil
    73  	}
    74  }
    75  
    76  // OptCertFileWatcherPollInterval sets the poll interval .
    77  func OptCertFileWatcherPollInterval(d time.Duration) CertFileWatcherOption {
    78  	return func(cfw *CertFileWatcher) error {
    79  		cfw.pollInterval = d
    80  		return nil
    81  	}
    82  }
    83  
    84  // CertFileWatcher reloads a cert key pair when there is a change, e.g. cert renewal
    85  type CertFileWatcher struct {
    86  	latch         *async.Latch
    87  	certificateMu sync.RWMutex
    88  	certificate   *tls.Certificate
    89  	keyPair       KeyPair
    90  	pollInterval  time.Duration
    91  	notifyReload  chan struct{}
    92  	onReload      CertFileWatcherOnReloadAction
    93  }
    94  
    95  // CertPath returns the cert path.
    96  func (cw *CertFileWatcher) CertPath() string { return cw.keyPair.CertPath }
    97  
    98  // KeyPath returns the cert path.
    99  func (cw *CertFileWatcher) KeyPath() string { return cw.keyPair.KeyPath }
   100  
   101  // PollIntervalOrDefault returns the polling interval or a default.
   102  func (cw *CertFileWatcher) PollIntervalOrDefault() time.Duration {
   103  	if cw.pollInterval > 0 {
   104  		return cw.pollInterval
   105  	}
   106  	return DefaultCertficicateFileWatcherPollInterval
   107  }
   108  
   109  // Reload forces the reload of the underlying certificate.
   110  func (cw *CertFileWatcher) Reload() (err error) {
   111  	defer func() {
   112  		if cw.notifyReload != nil {
   113  			cw.notifyReload <- struct{}{}
   114  		}
   115  		if cw.onReload != nil && err == nil {
   116  			err = cw.onReload(cw)
   117  		}
   118  	}()
   119  
   120  	cert, loadErr := tls.LoadX509KeyPair(cw.keyPair.CertPath, cw.keyPair.KeyPath)
   121  	if loadErr != nil {
   122  		err = ex.New(loadErr)
   123  		return
   124  	}
   125  	cw.certificateMu.Lock()
   126  	cw.certificate = &cert
   127  	cw.certificateMu.Unlock()
   128  	return
   129  }
   130  
   131  // Certificate gets the underlying certificate, it blocks when the `cert` field is being updated
   132  func (cw *CertFileWatcher) Certificate() *tls.Certificate {
   133  	cw.certificateMu.RLock()
   134  	defer cw.certificateMu.RUnlock()
   135  	return cw.certificate
   136  }
   137  
   138  // GetCertificate gets the underlying certificate in the form that tls config expects.
   139  func (cw *CertFileWatcher) GetCertificate(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
   140  	cw.certificateMu.RLock()
   141  	defer cw.certificateMu.RUnlock()
   142  	return cw.certificate, nil
   143  }
   144  
   145  // IsStarted returns if the underlying latch is started.
   146  func (cw *CertFileWatcher) IsStarted() bool { return cw.latch.IsStarted() }
   147  
   148  // IsStopped returns if the underlying latch is stopped.
   149  func (cw *CertFileWatcher) IsStopped() bool { return cw.latch.IsStopped() }
   150  
   151  // NotifyStarted returns the notify started channel.
   152  func (cw *CertFileWatcher) NotifyStarted() <-chan struct{} {
   153  	return cw.latch.NotifyStarted()
   154  }
   155  
   156  // NotifyStopped returns the notify stopped channel.
   157  func (cw *CertFileWatcher) NotifyStopped() <-chan struct{} {
   158  	return cw.latch.NotifyStopped()
   159  }
   160  
   161  // NotifyReload the notify reload channel.
   162  //
   163  // You must supply this channel as an option in the constructor.
   164  func (cw *CertFileWatcher) NotifyReload() <-chan struct{} {
   165  	return cw.notifyReload
   166  }
   167  
   168  // Start watches the cert and triggers a reload on change
   169  func (cw *CertFileWatcher) Start() error {
   170  	cw.latch.Starting()
   171  
   172  	certLastMod, keyLastMod, err := cw.keyPairLastModified()
   173  	if err != nil {
   174  		cw.latch.Stopped()
   175  		return err
   176  	}
   177  
   178  	ticker := time.NewTicker(cw.PollIntervalOrDefault())
   179  	defer ticker.Stop()
   180  
   181  	cw.latch.Started()
   182  	var certMod, keyMod time.Time
   183  	for {
   184  		select {
   185  		case <-ticker.C:
   186  			certMod, keyMod, err = cw.keyPairLastModified()
   187  			if err != nil {
   188  				return err
   189  			}
   190  			// wait for both to update
   191  			if keyMod.After(keyLastMod) && certMod.After(certLastMod) {
   192  				if err = cw.Reload(); err != nil {
   193  					return err
   194  				}
   195  				keyLastMod = keyMod
   196  				certLastMod = certMod
   197  			}
   198  		case <-cw.latch.NotifyStopping():
   199  			cw.latch.Stopped()
   200  			return nil
   201  		}
   202  	}
   203  }
   204  
   205  // Stop stops the watcher.
   206  func (cw *CertFileWatcher) Stop() error {
   207  	if !cw.latch.CanStop() {
   208  		return async.ErrCannotStop
   209  	}
   210  	cw.latch.WaitStopped()
   211  	cw.latch.Reset()
   212  	return nil
   213  }
   214  
   215  func (cw *CertFileWatcher) keyPairLastModified() (cert time.Time, key time.Time, err error) {
   216  	var certStat, keyStat os.FileInfo
   217  	certStat, err = os.Stat(cw.keyPair.CertPath)
   218  	if err != nil {
   219  		return
   220  	}
   221  	keyStat, err = os.Stat(cw.keyPair.KeyPath)
   222  	if err != nil {
   223  		return
   224  	}
   225  	cert = certStat.ModTime()
   226  	key = keyStat.ModTime()
   227  	return
   228  }