github.com/pachyderm/pachyderm@v1.13.4/src/client/pkg/tls/cert_loader.go (about)

     1  package tls
     2  
     3  import (
     4  	"crypto/tls"
     5  	"fmt"
     6  	"sync/atomic"
     7  	"time"
     8  	"unsafe"
     9  
    10  	"github.com/pachyderm/pachyderm/src/client/pkg/errors"
    11  
    12  	log "github.com/sirupsen/logrus"
    13  )
    14  
    15  // CertLoader provides simple hot TLS certificate reloading by checking for a renewed certificate at a configurable interval
    16  type CertLoader struct {
    17  	certPath        string
    18  	keyPath         string
    19  	refreshInterval time.Duration
    20  
    21  	// cert is the current cached *tls.Certificate. It should only be accessed with atomic methods because it may be updated by the cert reloading routine.
    22  	cert     unsafe.Pointer
    23  	stopChan chan interface{}
    24  	stopped  bool
    25  }
    26  
    27  // NewCertLoader creates a new CertLoader to refresh the specified TLS key at a fixed interval
    28  func NewCertLoader(certPath, keyPath string, refreshInterval time.Duration) *CertLoader {
    29  	return &CertLoader{
    30  		certPath:        certPath,
    31  		keyPath:         keyPath,
    32  		refreshInterval: refreshInterval,
    33  	}
    34  }
    35  
    36  // LoadAndStart ensures the current TLS certificate is loaded and starts the reload routine to poll for renewed certificates
    37  func (l *CertLoader) LoadAndStart() error {
    38  	if err := l.loadCertificate(); err != nil {
    39  		return err
    40  	}
    41  	go l.reloadRoutine()
    42  	return nil
    43  }
    44  
    45  // Stop signals the reloading routine to stop
    46  func (l *CertLoader) Stop() {
    47  	if l.stopped {
    48  		return
    49  	}
    50  	l.stopped = true
    51  	close(l.stopChan)
    52  }
    53  
    54  // GetCertificate gets the currently cached certificate and fulfills
    55  func (l *CertLoader) GetCertificate(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
    56  	certPtr := atomic.LoadPointer(&l.cert)
    57  	cert := (*tls.Certificate)(certPtr)
    58  	if cert == nil {
    59  		return nil, fmt.Errorf("no cached TLS certificate available")
    60  	}
    61  	return cert, nil
    62  }
    63  
    64  func (l *CertLoader) reloadRoutine() {
    65  	t := time.NewTicker(l.refreshInterval)
    66  	for {
    67  		select {
    68  		case <-t.C:
    69  			err := l.loadCertificate()
    70  			if err != nil {
    71  				log.Error("Unable to load TLS certificate", err)
    72  			}
    73  		case <-l.stopChan:
    74  			return
    75  		}
    76  	}
    77  }
    78  
    79  func (l *CertLoader) loadCertificate() error {
    80  	log.Debugf("Reloading TLS keypair - %q %q", l.certPath, l.keyPath)
    81  	cert, err := tls.LoadX509KeyPair(l.certPath, l.keyPath)
    82  	if err != nil {
    83  		return errors.Wrapf(err, "unable to load keypair")
    84  	}
    85  	atomic.StorePointer(&l.cert, unsafe.Pointer(&cert))
    86  	return nil
    87  }