go.temporal.io/server@v1.23.0/common/rpc/encryption/local_store_cert_provider.go (about)

     1  // The MIT License
     2  //
     3  // Copyright (c) 2020 Temporal Technologies Inc.  All rights reserved.
     4  //
     5  // Copyright (c) 2020 Uber Technologies, Inc.
     6  //
     7  // Permission is hereby granted, free of charge, to any person obtaining a copy
     8  // of this software and associated documentation files (the "Software"), to deal
     9  // in the Software without restriction, including without limitation the rights
    10  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
    11  // copies of the Software, and to permit persons to whom the Software is
    12  // furnished to do so, subject to the following conditions:
    13  //
    14  // The above copyright notice and this permission notice shall be included in
    15  // all copies or substantial portions of the Software.
    16  //
    17  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    18  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    19  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    20  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    21  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    22  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    23  // THE SOFTWARE.
    24  
    25  package encryption
    26  
    27  import (
    28  	"bytes"
    29  	"crypto/md5"
    30  	"crypto/tls"
    31  	"crypto/x509"
    32  	"encoding/base64"
    33  	"encoding/pem"
    34  	"errors"
    35  	"fmt"
    36  	"os"
    37  	"sync"
    38  	"time"
    39  
    40  	"go.temporal.io/server/common/config"
    41  	"go.temporal.io/server/common/log"
    42  	"go.temporal.io/server/common/log/tag"
    43  )
    44  
    45  var _ CertProvider = (*localStoreCertProvider)(nil)
    46  var _ CertExpirationChecker = (*localStoreCertProvider)(nil)
    47  
    48  type certCache struct {
    49  	serverCert          *tls.Certificate
    50  	workerCert          *tls.Certificate
    51  	clientCAPool        *x509.CertPool
    52  	serverCAPool        *x509.CertPool
    53  	serverCAsWorkerPool *x509.CertPool
    54  	clientCACerts       []*x509.Certificate // copies of certs in the clientCAPool CertPool for expiration checks
    55  	serverCACerts       []*x509.Certificate // copies of certs in the serverCAPool CertPool for expiration checks
    56  	serverCACertsWorker []*x509.Certificate // copies of certs in the serverCAsWorkerPool CertPool for expiration checks
    57  }
    58  
    59  type localStoreCertProvider struct {
    60  	sync.RWMutex
    61  
    62  	tlsSettings          *config.GroupTLS
    63  	workerTLSSettings    *config.WorkerTLS
    64  	isLegacyWorkerConfig bool
    65  	legacyWorkerSettings *config.ClientTLS
    66  
    67  	certs           *certCache
    68  	refreshInterval time.Duration
    69  
    70  	ticker *time.Ticker
    71  	stop   chan bool
    72  	logger log.Logger
    73  }
    74  
    75  type loadOrDecodeDataFunc func(item string) ([]byte, error)
    76  
    77  type tlsCertFetcher func() (*tls.Certificate, error)
    78  
    79  func (s *localStoreCertProvider) initialize() {
    80  
    81  	if s.refreshInterval != 0 {
    82  		s.stop = make(chan bool)
    83  		s.ticker = time.NewTicker(s.refreshInterval)
    84  		go s.refreshCerts()
    85  	}
    86  }
    87  
    88  func NewLocalStoreCertProvider(
    89  	tlsSettings *config.GroupTLS,
    90  	workerTlsSettings *config.WorkerTLS,
    91  	legacyWorkerSettings *config.ClientTLS,
    92  	refreshInterval time.Duration,
    93  	logger log.Logger) CertProvider {
    94  
    95  	provider := &localStoreCertProvider{
    96  		tlsSettings:          tlsSettings,
    97  		workerTLSSettings:    workerTlsSettings,
    98  		legacyWorkerSettings: legacyWorkerSettings,
    99  		isLegacyWorkerConfig: legacyWorkerSettings != nil,
   100  		logger:               logger,
   101  		refreshInterval:      refreshInterval,
   102  	}
   103  	provider.initialize()
   104  	return provider
   105  }
   106  
   107  func (s *localStoreCertProvider) Close() {
   108  
   109  	if s.ticker != nil {
   110  		s.ticker.Stop()
   111  	}
   112  	if s.stop != nil {
   113  		s.stop <- true
   114  		close(s.stop)
   115  	}
   116  }
   117  
   118  func (s *localStoreCertProvider) FetchServerCertificate() (*tls.Certificate, error) {
   119  
   120  	if s.tlsSettings == nil {
   121  		return nil, nil
   122  	}
   123  	certs, err := s.getCerts()
   124  	if err != nil {
   125  		return nil, err
   126  	}
   127  	return certs.serverCert, nil
   128  }
   129  
   130  func (s *localStoreCertProvider) FetchClientCAs() (*x509.CertPool, error) {
   131  
   132  	if s.tlsSettings == nil {
   133  		return nil, nil
   134  	}
   135  	certs, err := s.getCerts()
   136  	if err != nil {
   137  		return nil, err
   138  	}
   139  	return certs.clientCAPool, nil
   140  }
   141  
   142  func (s *localStoreCertProvider) FetchServerRootCAsForClient(isWorker bool) (*x509.CertPool, error) {
   143  
   144  	clientSettings := s.getClientTLSSettings(isWorker)
   145  	if clientSettings == nil {
   146  		return nil, nil
   147  	}
   148  	certs, err := s.getCerts()
   149  	if err != nil {
   150  		return nil, err
   151  	}
   152  
   153  	if isWorker {
   154  		return certs.serverCAsWorkerPool, nil
   155  	}
   156  
   157  	return certs.serverCAPool, nil
   158  }
   159  
   160  func (s *localStoreCertProvider) FetchClientCertificate(isWorker bool) (*tls.Certificate, error) {
   161  
   162  	if !s.isTLSEnabled() {
   163  		return nil, nil
   164  	}
   165  	certs, err := s.getCerts()
   166  	if err != nil {
   167  		return nil, err
   168  	}
   169  	if isWorker {
   170  		return certs.workerCert, nil
   171  	}
   172  	return certs.serverCert, nil
   173  }
   174  
   175  func (s *localStoreCertProvider) GetExpiringCerts(timeWindow time.Duration,
   176  ) (CertExpirationMap, CertExpirationMap, error) {
   177  
   178  	expiring := make(CertExpirationMap)
   179  	expired := make(CertExpirationMap)
   180  	when := time.Now().UTC().Add(timeWindow)
   181  
   182  	certs, err := s.getCerts()
   183  	if err != nil {
   184  		return nil, nil, err
   185  	}
   186  
   187  	checkError := checkTLSCertForExpiration(certs.serverCert, when, expiring, expired)
   188  	err = appendError(err, checkError)
   189  	checkError = checkTLSCertForExpiration(certs.workerCert, when, expiring, expired)
   190  	err = appendError(err, checkError)
   191  
   192  	checkCertsForExpiration(certs.clientCACerts, when, expiring, expired)
   193  	checkCertsForExpiration(certs.serverCACerts, when, expiring, expired)
   194  	checkCertsForExpiration(certs.serverCACertsWorker, when, expiring, expired)
   195  
   196  	return expiring, expired, err
   197  }
   198  
   199  func (s *localStoreCertProvider) getCerts() (*certCache, error) {
   200  
   201  	s.RLock()
   202  	if s.certs != nil {
   203  		defer s.RUnlock()
   204  		return s.certs, nil
   205  	}
   206  	s.RUnlock()
   207  	s.Lock()
   208  	defer s.Unlock()
   209  
   210  	if s.certs != nil {
   211  		return s.certs, nil
   212  	}
   213  
   214  	newCerts, err := s.loadCerts()
   215  	if err != nil {
   216  		return nil, err
   217  	}
   218  
   219  	if newCerts == nil {
   220  		s.certs = &certCache{}
   221  	} else {
   222  		s.certs = newCerts
   223  	}
   224  	return s.certs, nil
   225  }
   226  
   227  func (s *localStoreCertProvider) loadCerts() (*certCache, error) {
   228  
   229  	if !s.isTLSEnabled() {
   230  		return nil, nil
   231  	}
   232  
   233  	newCerts := certCache{}
   234  	var err error
   235  
   236  	if s.tlsSettings != nil {
   237  		newCerts.serverCert, err = s.fetchCertificate(s.tlsSettings.Server.CertFile, s.tlsSettings.Server.CertData,
   238  			s.tlsSettings.Server.KeyFile, s.tlsSettings.Server.KeyData)
   239  		if err != nil {
   240  			return nil, err
   241  		}
   242  
   243  		certPool, certs, err := s.fetchCAs(s.tlsSettings.Server.ClientCAFiles, s.tlsSettings.Server.ClientCAData,
   244  			"cannot specify both clientCAFiles and clientCAData properties")
   245  		if err != nil {
   246  			return nil, err
   247  		}
   248  		newCerts.clientCAPool = certPool
   249  		newCerts.clientCACerts = certs
   250  	}
   251  
   252  	if s.isLegacyWorkerConfig {
   253  		newCerts.workerCert = newCerts.serverCert
   254  	} else {
   255  		if s.workerTLSSettings != nil {
   256  			newCerts.workerCert, err = s.fetchCertificate(s.workerTLSSettings.CertFile, s.workerTLSSettings.CertData,
   257  				s.workerTLSSettings.KeyFile, s.workerTLSSettings.KeyData)
   258  			if err != nil {
   259  				return nil, err
   260  			}
   261  		}
   262  	}
   263  
   264  	nonWorkerPool, nonWorkerCerts, err := s.loadServerCACerts(false)
   265  	if err != nil {
   266  		return nil, err
   267  	}
   268  	newCerts.serverCAPool = nonWorkerPool
   269  	newCerts.serverCACerts = nonWorkerCerts
   270  
   271  	workerPool, workerCerts, err := s.loadServerCACerts(true)
   272  	if err != nil {
   273  		return nil, err
   274  	}
   275  	newCerts.serverCAsWorkerPool = workerPool
   276  	newCerts.serverCACertsWorker = workerCerts
   277  
   278  	return &newCerts, nil
   279  }
   280  
   281  func (s *localStoreCertProvider) fetchCertificate(
   282  	certFile string, certData string,
   283  	keyFile string, keyData string) (*tls.Certificate, error) {
   284  	if certFile == "" && certData == "" {
   285  		return nil, nil
   286  	}
   287  
   288  	if certFile != "" && certData != "" {
   289  		return nil, errors.New("only one of certFile or certData properties should be spcified")
   290  	}
   291  
   292  	var certBytes []byte
   293  	var keyBytes []byte
   294  	var err error
   295  
   296  	if certFile != "" {
   297  		s.logger.Info("loading certificate from file", tag.TLSCertFile(certFile))
   298  		certBytes, err = os.ReadFile(certFile)
   299  		if err != nil {
   300  			return nil, err
   301  		}
   302  	} else if certData != "" {
   303  		certBytes, err = base64.StdEncoding.DecodeString(certData)
   304  		if err != nil {
   305  			return nil, fmt.Errorf("TLS public certificate could not be decoded: %w", err)
   306  		}
   307  	}
   308  
   309  	if keyFile != "" {
   310  		s.logger.Info("loading private key from file", tag.TLSKeyFile(keyFile))
   311  		keyBytes, err = os.ReadFile(keyFile)
   312  		if err != nil {
   313  			return nil, err
   314  		}
   315  	} else if keyData != "" {
   316  		keyBytes, err = base64.StdEncoding.DecodeString(keyData)
   317  		if err != nil {
   318  			return nil, fmt.Errorf("TLS private key could not be decoded: %w", err)
   319  		}
   320  	}
   321  
   322  	cert, err := tls.X509KeyPair(certBytes, keyBytes)
   323  	if err != nil {
   324  		return nil, fmt.Errorf("loading tls certificate failed: %v", err)
   325  	}
   326  
   327  	return &cert, nil
   328  }
   329  
   330  func (s *localStoreCertProvider) getClientTLSSettings(isWorker bool) *config.ClientTLS {
   331  	if isWorker && s.workerTLSSettings != nil {
   332  		return &s.workerTLSSettings.Client // explicit system worker case
   333  	} else if isWorker {
   334  		return s.legacyWorkerSettings // legacy config case when we use Frontend.Client settings
   335  	} else {
   336  		if s.tlsSettings == nil {
   337  			return nil
   338  		}
   339  		return &s.tlsSettings.Client // internode client case
   340  	}
   341  }
   342  
   343  func (s *localStoreCertProvider) loadServerCACerts(isWorker bool) (*x509.CertPool, []*x509.Certificate, error) {
   344  
   345  	clientSettings := s.getClientTLSSettings(isWorker)
   346  	if clientSettings == nil {
   347  		return nil, nil, nil
   348  	}
   349  
   350  	return s.fetchCAs(clientSettings.RootCAFiles, clientSettings.RootCAData,
   351  		"cannot specify both rootCAFiles and rootCAData properties")
   352  }
   353  
   354  func (s *localStoreCertProvider) fetchCAs(
   355  	files []string,
   356  	data []string,
   357  	duplicateErrorMessage string) (*x509.CertPool, []*x509.Certificate, error) {
   358  	if len(files) == 0 && len(data) == 0 {
   359  		return nil, nil, nil
   360  	}
   361  
   362  	caPoolFromFiles, caCertsFromFiles, err := s.buildCAPoolFromFiles(files)
   363  	if err != nil {
   364  		return nil, nil, err
   365  	}
   366  
   367  	caPoolFromData, caCertsFromData, err := buildCAPoolFromData(data)
   368  	if err != nil {
   369  		return nil, nil, err
   370  	}
   371  
   372  	if caPoolFromFiles != nil && caPoolFromData != nil {
   373  		return nil, nil, errors.New(duplicateErrorMessage)
   374  	}
   375  
   376  	var certPool *x509.CertPool
   377  	var certs []*x509.Certificate
   378  
   379  	if caPoolFromData != nil {
   380  		certPool = caPoolFromData
   381  		certs = caCertsFromData
   382  	} else {
   383  		certPool = caPoolFromFiles
   384  		certs = caCertsFromFiles
   385  	}
   386  
   387  	return certPool, certs, nil
   388  }
   389  
   390  func checkTLSCertForExpiration(
   391  	cert *tls.Certificate,
   392  	when time.Time,
   393  	expiring CertExpirationMap,
   394  	expired CertExpirationMap,
   395  ) error {
   396  
   397  	if cert == nil {
   398  		return nil
   399  	}
   400  
   401  	x509cert, err := x509.ParseCertificate(cert.Certificate[0])
   402  	if err != nil {
   403  		return err
   404  	}
   405  	checkCertForExpiration(x509cert, when, expiring, expired)
   406  	return nil
   407  }
   408  
   409  func checkCertsForExpiration(
   410  	certs []*x509.Certificate,
   411  	time time.Time,
   412  	expiring CertExpirationMap,
   413  	expired CertExpirationMap,
   414  ) {
   415  
   416  	for _, cert := range certs {
   417  		checkCertForExpiration(cert, time, expiring, expired)
   418  	}
   419  }
   420  
   421  func checkCertForExpiration(
   422  	cert *x509.Certificate,
   423  	pointInTime time.Time,
   424  	expiring CertExpirationMap,
   425  	expired CertExpirationMap,
   426  ) {
   427  
   428  	if cert != nil && expiresBefore(cert, pointInTime) {
   429  		record := CertExpirationData{
   430  			Thumbprint: md5.Sum(cert.Raw),
   431  			IsCA:       cert.IsCA,
   432  			DNSNames:   cert.DNSNames,
   433  			Expiration: cert.NotAfter,
   434  		}
   435  		if record.Expiration.Before(time.Now().UTC()) {
   436  			expired[record.Thumbprint] = record
   437  		} else {
   438  			expiring[record.Thumbprint] = record
   439  		}
   440  	}
   441  }
   442  
   443  func expiresBefore(cert *x509.Certificate, pointInTime time.Time) bool {
   444  	return cert.NotAfter.Before(pointInTime)
   445  }
   446  
   447  func buildCAPoolFromData(caData []string) (*x509.CertPool, []*x509.Certificate, error) {
   448  
   449  	return buildCAPool(caData, base64.StdEncoding.DecodeString)
   450  }
   451  
   452  func (s *localStoreCertProvider) buildCAPoolFromFiles(caFiles []string) (*x509.CertPool, []*x509.Certificate, error) {
   453  	if len(caFiles) == 0 {
   454  		return nil, nil, nil
   455  	}
   456  
   457  	s.logger.Info("loading CA certs from", tag.TLSCertFiles(caFiles))
   458  	return buildCAPool(caFiles, os.ReadFile)
   459  }
   460  
   461  func buildCAPool(cas []string, getBytes loadOrDecodeDataFunc) (*x509.CertPool, []*x509.Certificate, error) {
   462  
   463  	var caPool *x509.CertPool
   464  	var certs []*x509.Certificate
   465  
   466  	for _, ca := range cas {
   467  		if ca == "" {
   468  			continue
   469  		}
   470  
   471  		caBytes, err := getBytes(ca)
   472  		if err != nil {
   473  			return nil, nil, fmt.Errorf("failed to decode ca cert: %w", err)
   474  		}
   475  
   476  		if caPool == nil {
   477  			caPool = x509.NewCertPool()
   478  		}
   479  		if !caPool.AppendCertsFromPEM(caBytes) {
   480  			return nil, nil, errors.New("unknown failure constructing cert pool for ca")
   481  		}
   482  
   483  		cert, err := parseCert(caBytes)
   484  		if err != nil {
   485  			return nil, nil, fmt.Errorf("failed to parse x509 certificate: %w", err)
   486  		}
   487  		certs = append(certs, cert)
   488  	}
   489  	return caPool, certs, nil
   490  }
   491  
   492  // logic borrowed from tls.X509KeyPair()
   493  func parseCert(bytes []byte) (*x509.Certificate, error) {
   494  
   495  	var certBytes [][]byte
   496  	for {
   497  		var certDERBlock *pem.Block
   498  		certDERBlock, bytes = pem.Decode(bytes)
   499  		if certDERBlock == nil {
   500  			break
   501  		}
   502  		if certDERBlock.Type == "CERTIFICATE" {
   503  			certBytes = append(certBytes, certDERBlock.Bytes)
   504  		}
   505  	}
   506  
   507  	if len(certBytes) == 0 || len(certBytes[0]) == 0 {
   508  		return nil, fmt.Errorf("failed to decode PEM certificate data")
   509  	}
   510  	return x509.ParseCertificate(certBytes[0])
   511  }
   512  
   513  func appendError(aggregatedErr error, err error) error {
   514  	if aggregatedErr == nil {
   515  		return err
   516  	}
   517  	if err == nil {
   518  		return aggregatedErr
   519  	}
   520  	return fmt.Errorf("%v, %w", aggregatedErr, err)
   521  }
   522  
   523  func (s *localStoreCertProvider) refreshCerts() {
   524  
   525  	for {
   526  		select {
   527  		case <-s.stop:
   528  			return
   529  		case <-s.ticker.C:
   530  		}
   531  
   532  		newCerts, err := s.loadCerts()
   533  		if err != nil {
   534  			s.logger.Error("failed to load certificates", tag.Error(err))
   535  			continue
   536  		}
   537  
   538  		s.RLock()
   539  		currentCerts := s.certs
   540  		s.RUnlock()
   541  		if currentCerts.isEqual(newCerts) {
   542  			continue
   543  		}
   544  
   545  		s.logger.Info("loaded new TLS certificates")
   546  		s.Lock()
   547  		s.certs = newCerts
   548  		s.Unlock()
   549  	}
   550  }
   551  
   552  func (s *localStoreCertProvider) isTLSEnabled() bool {
   553  	return s.tlsSettings != nil || s.workerTLSSettings != nil
   554  }
   555  
   556  func (c *certCache) isEqual(other *certCache) bool {
   557  
   558  	if c == other {
   559  		return true
   560  	}
   561  	if c == nil || other == nil {
   562  		return false
   563  	}
   564  
   565  	if !equalTLSCerts(c.serverCert, other.serverCert) ||
   566  		!equalTLSCerts(c.workerCert, other.workerCert) ||
   567  		!equalX509(c.clientCACerts, other.clientCACerts) ||
   568  		!equalX509(c.serverCACerts, other.serverCACerts) ||
   569  		!equalX509(c.serverCACertsWorker, other.serverCACertsWorker) {
   570  		return false
   571  	}
   572  	return true
   573  }
   574  
   575  func equal(a, b [][]byte) bool {
   576  	if len(a) != len(b) {
   577  		return false
   578  	}
   579  	for i := range a {
   580  		if !bytes.Equal(a[i], b[i]) {
   581  			return false
   582  		}
   583  	}
   584  	return true
   585  }
   586  
   587  func equalX509(a, b []*x509.Certificate) bool {
   588  	if len(a) != len(b) {
   589  		return false
   590  	}
   591  	for i := range a {
   592  		if !a[i].Equal(b[i]) {
   593  			return false
   594  		}
   595  	}
   596  	return true
   597  }
   598  
   599  func equalTLSCerts(a, b *tls.Certificate) bool {
   600  	if a != nil {
   601  		if b == nil || !equal(a.Certificate, b.Certificate) {
   602  			return false
   603  		}
   604  	} else {
   605  		if b != nil {
   606  			return false
   607  		}
   608  	}
   609  	return true
   610  }