go.temporal.io/server@v1.23.0/common/rpc/encryption/local_store_tls_provider.go (about)

     1  // The MIT License
     2  //
     3  // Copyright (c) 2020 Temporal Technologies Inc.  All rights reserved.
     4  //
     5  // Copyright (c) 2020 Uber Technologies, Inc.
     6  //
     7  // Permission is hereby granted, free of charge, to any person obtaining a copy
     8  // of this software and associated documentation files (the "Software"), to deal
     9  // in the Software without restriction, including without limitation the rights
    10  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
    11  // copies of the Software, and to permit persons to whom the Software is
    12  // furnished to do so, subject to the following conditions:
    13  //
    14  // The above copyright notice and this permission notice shall be included in
    15  // all copies or substantial portions of the Software.
    16  //
    17  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    18  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    19  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    20  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    21  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    22  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    23  // THE SOFTWARE.
    24  
    25  package encryption
    26  
    27  import (
    28  	"crypto/tls"
    29  	"crypto/x509"
    30  	"fmt"
    31  	"sync"
    32  	"time"
    33  
    34  	"go.temporal.io/server/common/log/tag"
    35  	"go.temporal.io/server/common/metrics"
    36  
    37  	"go.temporal.io/server/common/auth"
    38  	"go.temporal.io/server/common/config"
    39  	"go.temporal.io/server/common/log"
    40  )
    41  
    42  type CertProviderFactory func(
    43  	tlsSettings *config.GroupTLS,
    44  	workerTlsSettings *config.WorkerTLS,
    45  	legacyWorkerSettings *config.ClientTLS,
    46  	refreshInterval time.Duration,
    47  	logger log.Logger) CertProvider
    48  
    49  type localStoreTlsProvider struct {
    50  	sync.RWMutex
    51  
    52  	settings *config.RootTLS
    53  
    54  	internodeCertProvider           CertProvider
    55  	internodeClientCertProvider     CertProvider
    56  	frontendCertProvider            CertProvider
    57  	workerCertProvider              CertProvider
    58  	remoteClusterClientCertProvider map[string]CertProvider
    59  	frontendPerHostCertProviderMap  *localStorePerHostCertProviderMap
    60  
    61  	cachedInternodeServerConfig     *tls.Config
    62  	cachedInternodeClientConfig     *tls.Config
    63  	cachedFrontendServerConfig      *tls.Config
    64  	cachedFrontendClientConfig      *tls.Config
    65  	cachedRemoteClusterClientConfig map[string]*tls.Config
    66  
    67  	ticker         *time.Ticker
    68  	logger         log.Logger
    69  	stop           chan bool
    70  	metricsHandler metrics.Handler
    71  }
    72  
    73  var _ TLSConfigProvider = (*localStoreTlsProvider)(nil)
    74  var _ CertExpirationChecker = (*localStoreTlsProvider)(nil)
    75  
    76  func NewLocalStoreTlsProvider(tlsConfig *config.RootTLS, metricsHandler metrics.Handler, logger log.Logger, certProviderFactory CertProviderFactory,
    77  ) (TLSConfigProvider, error) {
    78  
    79  	internodeProvider := certProviderFactory(&tlsConfig.Internode, nil, nil, tlsConfig.RefreshInterval, logger)
    80  	var workerProvider CertProvider
    81  	if isSystemWorker(tlsConfig) { // explicit system worker config
    82  		workerProvider = certProviderFactory(nil, &tlsConfig.SystemWorker, nil, tlsConfig.RefreshInterval, logger)
    83  	} else { // legacy implicit system worker config case
    84  		internodeWorkerProvider := certProviderFactory(&tlsConfig.Internode, nil, &tlsConfig.Frontend.Client, tlsConfig.RefreshInterval, logger)
    85  		workerProvider = internodeWorkerProvider
    86  	}
    87  
    88  	remoteClusterClientCertProvider := make(map[string]CertProvider)
    89  	for hostname, groupTLS := range tlsConfig.RemoteClusters {
    90  		remoteClusterClientCertProvider[hostname] = certProviderFactory(&groupTLS, nil, nil, tlsConfig.RefreshInterval, logger)
    91  	}
    92  
    93  	provider := &localStoreTlsProvider{
    94  		internodeCertProvider:       internodeProvider,
    95  		internodeClientCertProvider: internodeProvider,
    96  		frontendCertProvider:        certProviderFactory(&tlsConfig.Frontend, nil, nil, tlsConfig.RefreshInterval, logger),
    97  		workerCertProvider:          workerProvider,
    98  		frontendPerHostCertProviderMap: newLocalStorePerHostCertProviderMap(
    99  			tlsConfig.Frontend.PerHostOverrides, certProviderFactory, tlsConfig.RefreshInterval, logger),
   100  		remoteClusterClientCertProvider: remoteClusterClientCertProvider,
   101  		RWMutex:                         sync.RWMutex{},
   102  		settings:                        tlsConfig,
   103  		metricsHandler:                  metricsHandler,
   104  		logger:                          logger,
   105  		cachedRemoteClusterClientConfig: make(map[string]*tls.Config),
   106  	}
   107  	provider.initialize()
   108  	return provider, nil
   109  }
   110  
   111  func (s *localStoreTlsProvider) initialize() {
   112  	period := s.settings.ExpirationChecks.CheckInterval
   113  	if period != 0 {
   114  		s.stop = make(chan bool)
   115  		s.ticker = time.NewTicker(period)
   116  		s.checkCertExpiration() // perform initial check to emit metrics and logs right away
   117  		go s.timerCallback()
   118  	}
   119  }
   120  
   121  func (s *localStoreTlsProvider) Close() {
   122  
   123  	if s.ticker != nil {
   124  		s.ticker.Stop()
   125  	}
   126  	if s.stop != nil {
   127  		s.stop <- true
   128  		close(s.stop)
   129  	}
   130  }
   131  
   132  func (s *localStoreTlsProvider) GetInternodeClientConfig() (*tls.Config, error) {
   133  
   134  	client := &s.settings.Internode.Client
   135  	return s.getOrCreateConfig(
   136  		&s.cachedInternodeClientConfig,
   137  		func() (*tls.Config, error) {
   138  			return newClientTLSConfig(s.internodeClientCertProvider, client.ServerName,
   139  				s.settings.Internode.Server.RequireClientAuth, false, !client.DisableHostVerification)
   140  		},
   141  		s.settings.Internode.IsClientEnabled(),
   142  	)
   143  }
   144  
   145  func (s *localStoreTlsProvider) GetFrontendClientConfig() (*tls.Config, error) {
   146  
   147  	var client *config.ClientTLS
   148  	var useTLS bool
   149  	if isSystemWorker(s.settings) {
   150  		client = &s.settings.SystemWorker.Client
   151  		useTLS = true
   152  	} else {
   153  		client = &s.settings.Frontend.Client
   154  		useTLS = s.settings.Frontend.IsClientEnabled()
   155  	}
   156  	return s.getOrCreateConfig(
   157  		&s.cachedFrontendClientConfig,
   158  		func() (*tls.Config, error) {
   159  			return newClientTLSConfig(s.workerCertProvider, client.ServerName,
   160  				useTLS, true, !client.DisableHostVerification)
   161  		},
   162  		useTLS,
   163  	)
   164  }
   165  
   166  func (s *localStoreTlsProvider) GetRemoteClusterClientConfig(hostname string) (*tls.Config, error) {
   167  	groupTLS, ok := s.settings.RemoteClusters[hostname]
   168  	if !ok {
   169  		return nil, nil
   170  	}
   171  
   172  	return s.getOrCreateRemoteClusterClientConfig(
   173  		hostname,
   174  		func() (*tls.Config, error) {
   175  			return newClientTLSConfig(
   176  				s.remoteClusterClientCertProvider[hostname],
   177  				groupTLS.Client.ServerName,
   178  				groupTLS.Server.RequireClientAuth,
   179  				false,
   180  				!groupTLS.Client.DisableHostVerification)
   181  		},
   182  		groupTLS.IsClientEnabled(),
   183  	)
   184  }
   185  
   186  func (s *localStoreTlsProvider) GetFrontendServerConfig() (*tls.Config, error) {
   187  	return s.getOrCreateConfig(
   188  		&s.cachedFrontendServerConfig,
   189  		func() (*tls.Config, error) {
   190  			return newServerTLSConfig(s.frontendCertProvider, s.frontendPerHostCertProviderMap, &s.settings.Frontend, s.logger)
   191  		},
   192  		s.settings.Frontend.IsServerEnabled())
   193  }
   194  
   195  func (s *localStoreTlsProvider) GetInternodeServerConfig() (*tls.Config, error) {
   196  	return s.getOrCreateConfig(
   197  		&s.cachedInternodeServerConfig,
   198  		func() (*tls.Config, error) {
   199  			return newServerTLSConfig(s.internodeCertProvider, nil, &s.settings.Internode, s.logger)
   200  		},
   201  		s.settings.Internode.IsServerEnabled())
   202  }
   203  
   204  func (s *localStoreTlsProvider) GetExpiringCerts(timeWindow time.Duration,
   205  ) (expiring CertExpirationMap, expired CertExpirationMap, err error) {
   206  
   207  	expiring = make(CertExpirationMap, 0)
   208  	expired = make(CertExpirationMap, 0)
   209  
   210  	checkError := checkExpiration(s.internodeCertProvider, timeWindow, expiring, expired)
   211  	err = appendError(err, checkError)
   212  	checkError = checkExpiration(s.frontendCertProvider, timeWindow, expiring, expired)
   213  	err = appendError(err, checkError)
   214  	checkError = checkExpiration(s.workerCertProvider, timeWindow, expiring, expired)
   215  	err = appendError(err, checkError)
   216  	checkError = checkExpiration(s.frontendPerHostCertProviderMap, timeWindow, expiring, expired)
   217  	err = appendError(err, checkError)
   218  
   219  	return expiring, expired, err
   220  }
   221  
   222  func checkExpiration(
   223  	provider CertExpirationChecker,
   224  	timeWindow time.Duration,
   225  	expiring CertExpirationMap,
   226  	expired CertExpirationMap,
   227  ) error {
   228  
   229  	providerExpiring, providerExpired, err := provider.GetExpiringCerts(timeWindow)
   230  	mergeMaps(expiring, providerExpiring)
   231  	mergeMaps(expired, providerExpired)
   232  	return err
   233  }
   234  
   235  func (s *localStoreTlsProvider) getOrCreateConfig(
   236  	cachedConfig **tls.Config,
   237  	configConstructor tlsConfigConstructor,
   238  	isEnabled bool,
   239  ) (*tls.Config, error) {
   240  	if !isEnabled {
   241  		return nil, nil
   242  	}
   243  
   244  	// Check if exists under a read lock first
   245  	s.RLock()
   246  	if *cachedConfig != nil {
   247  		defer s.RUnlock()
   248  		return *cachedConfig, nil
   249  	}
   250  	// Not found, promote to write lock to initialize
   251  	s.RUnlock()
   252  	s.Lock()
   253  	defer s.Unlock()
   254  	// Check if someone got here first while waiting for write lock
   255  	if *cachedConfig != nil {
   256  		return *cachedConfig, nil
   257  	}
   258  
   259  	// Load configuration
   260  	localConfig, err := configConstructor()
   261  
   262  	if err != nil {
   263  		return nil, err
   264  	}
   265  
   266  	*cachedConfig = localConfig
   267  	return *cachedConfig, nil
   268  }
   269  
   270  func (s *localStoreTlsProvider) getOrCreateRemoteClusterClientConfig(
   271  	hostname string,
   272  	configConstructor tlsConfigConstructor,
   273  	isEnabled bool,
   274  ) (*tls.Config, error) {
   275  	if !isEnabled {
   276  		return nil, nil
   277  	}
   278  
   279  	// Check if exists under a read lock first
   280  	s.RLock()
   281  	if clientConfig, ok := s.cachedRemoteClusterClientConfig[hostname]; ok {
   282  		defer s.RUnlock()
   283  		return clientConfig, nil
   284  	}
   285  	// Not found, promote to write lock to initialize
   286  	s.RUnlock()
   287  	s.Lock()
   288  	defer s.Unlock()
   289  	// Check if someone got here first while waiting for write lock
   290  	if clientConfig, ok := s.cachedRemoteClusterClientConfig[hostname]; ok {
   291  		return clientConfig, nil
   292  	}
   293  
   294  	// Load configuration
   295  	localConfig, err := configConstructor()
   296  
   297  	if err != nil {
   298  		return nil, err
   299  	}
   300  
   301  	s.cachedRemoteClusterClientConfig[hostname] = localConfig
   302  	return localConfig, nil
   303  }
   304  
   305  func newServerTLSConfig(
   306  	certProvider CertProvider,
   307  	perHostCertProviderMap PerHostCertProviderMap,
   308  	config *config.GroupTLS,
   309  	logger log.Logger,
   310  ) (*tls.Config, error) {
   311  
   312  	clientAuthRequired := config.Server.RequireClientAuth
   313  	tlsConfig, err := getServerTLSConfigFromCertProvider(certProvider, clientAuthRequired, "", "", logger)
   314  	if err != nil {
   315  		return nil, err
   316  	}
   317  
   318  	tlsConfig.GetConfigForClient = func(c *tls.ClientHelloInfo) (*tls.Config, error) {
   319  
   320  		remoteAddress := c.Conn.RemoteAddr().String()
   321  		logger.Debug("attempted incoming TLS connection", tag.Address(remoteAddress), tag.ServerName(c.ServerName))
   322  
   323  		if perHostCertProviderMap != nil && perHostCertProviderMap.NumberOfHosts() > 0 {
   324  			perHostCertProvider, hostClientAuthRequired, err := perHostCertProviderMap.GetCertProvider(c.ServerName)
   325  			if err != nil {
   326  				logger.Error("error while looking up per-host provider for attempted incoming TLS connection",
   327  					tag.ServerName(c.ServerName), tag.Address(remoteAddress), tag.Error(err))
   328  				return nil, err
   329  			}
   330  
   331  			if perHostCertProvider != nil {
   332  				return getServerTLSConfigFromCertProvider(perHostCertProvider, hostClientAuthRequired, remoteAddress, c.ServerName, logger)
   333  			}
   334  			logger.Warn("cannot find a per-host provider for attempted incoming TLS connection. returning default TLS configuration",
   335  				tag.ServerName(c.ServerName), tag.Address(remoteAddress))
   336  			return getServerTLSConfigFromCertProvider(certProvider, clientAuthRequired, remoteAddress, c.ServerName, logger)
   337  		}
   338  		return getServerTLSConfigFromCertProvider(certProvider, clientAuthRequired, remoteAddress, c.ServerName, logger)
   339  	}
   340  
   341  	return tlsConfig, nil
   342  }
   343  
   344  func getServerTLSConfigFromCertProvider(
   345  	certProvider CertProvider,
   346  	requireClientAuth bool,
   347  	remoteAddress string,
   348  	serverName string,
   349  	logger log.Logger) (*tls.Config, error) {
   350  
   351  	// Get serverCert from disk
   352  	serverCert, err := certProvider.FetchServerCertificate()
   353  	if err != nil {
   354  		return nil, fmt.Errorf("loading server tls certificate failed: %v", err)
   355  	}
   356  
   357  	// tls disabled, responsibility of cert provider above to error otherwise
   358  	if serverCert == nil {
   359  		return nil, nil
   360  	}
   361  
   362  	// Default to NoClientAuth
   363  	clientAuthType := tls.NoClientCert
   364  	var clientCaPool *x509.CertPool
   365  
   366  	// If mTLS enabled
   367  	if requireClientAuth {
   368  		clientAuthType = tls.RequireAndVerifyClientCert
   369  
   370  		ca, err := certProvider.FetchClientCAs()
   371  		if err != nil {
   372  			return nil, fmt.Errorf("failed to fetch client CAs: %v", err)
   373  		}
   374  
   375  		clientCaPool = ca
   376  	}
   377  	if remoteAddress != "" { // remoteAddress=="" when we return initial tls.Config object when configuring server
   378  		logger.Debug("returning TLS config for connection", tag.Address(remoteAddress), tag.ServerName(serverName))
   379  	}
   380  	return auth.NewTLSConfigWithCertsAndCAs(
   381  		clientAuthType,
   382  		[]tls.Certificate{*serverCert},
   383  		clientCaPool,
   384  		logger), nil
   385  }
   386  
   387  func newClientTLSConfig(
   388  	clientProvider CertProvider,
   389  	serverName string,
   390  	isAuthRequired bool,
   391  	isWorker bool,
   392  	enableHostVerification bool,
   393  ) (*tls.Config, error) {
   394  	// Optional ServerCA for client if not already trusted by host
   395  	serverCa, err := clientProvider.FetchServerRootCAsForClient(isWorker)
   396  	if err != nil {
   397  		return nil, fmt.Errorf("failed to load client ca: %v", err)
   398  	}
   399  
   400  	var getCert tlsCertFetcher
   401  
   402  	// mTLS enabled, present certificate
   403  	if isAuthRequired {
   404  		getCert = func() (*tls.Certificate, error) {
   405  			cert, err := clientProvider.FetchClientCertificate(isWorker)
   406  			if err != nil {
   407  				return nil, err
   408  			}
   409  
   410  			if cert == nil {
   411  				return nil, fmt.Errorf("client auth required, but no certificate provided")
   412  			}
   413  			return cert, nil
   414  		}
   415  	}
   416  
   417  	return auth.NewDynamicTLSClientConfig(
   418  		getCert,
   419  		serverCa,
   420  		serverName,
   421  		enableHostVerification,
   422  	), nil
   423  }
   424  
   425  func (s *localStoreTlsProvider) timerCallback() {
   426  	for {
   427  		select {
   428  		case <-s.stop:
   429  			return
   430  		case <-s.ticker.C:
   431  		}
   432  
   433  		s.checkCertExpiration()
   434  	}
   435  }
   436  
   437  func (s *localStoreTlsProvider) checkCertExpiration() {
   438  	var retError error
   439  	defer log.CapturePanic(s.logger, &retError)
   440  
   441  	var errorTime time.Time
   442  	if s.settings.ExpirationChecks.ErrorWindow != 0 {
   443  		errorTime = time.Now().UTC().Add(s.settings.ExpirationChecks.ErrorWindow)
   444  	} else {
   445  		errorTime = time.Now().UTC().AddDate(10, 0, 0)
   446  	}
   447  
   448  	window := s.settings.ExpirationChecks.WarningWindow
   449  	// if only ErrorWindow is set, we set WarningWindow to the same value, so that the checks do happen
   450  	if window == 0 && s.settings.ExpirationChecks.ErrorWindow != 0 {
   451  		window = s.settings.ExpirationChecks.ErrorWindow
   452  	}
   453  	if window != 0 {
   454  		expiring, expired, err := s.GetExpiringCerts(window)
   455  		if err != nil {
   456  			s.logger.Error(fmt.Sprintf("error while checking for certificate expiration: %v", err))
   457  			return
   458  		}
   459  		if s.metricsHandler != nil {
   460  			s.metricsHandler.Gauge(metrics.TlsCertsExpired.Name()).Record(float64(len(expired)))
   461  			s.metricsHandler.Gauge(metrics.TlsCertsExpiring.Name()).Record(float64(len(expiring)))
   462  		}
   463  		s.logCerts(expired, true, errorTime)
   464  		s.logCerts(expiring, false, errorTime)
   465  	}
   466  }
   467  
   468  func (s *localStoreTlsProvider) logCerts(certs CertExpirationMap, expired bool, errorTime time.Time) {
   469  
   470  	for _, cert := range certs {
   471  		str := createExpirationLogMessage(cert, expired)
   472  		if expired || cert.Expiration.Before(errorTime) {
   473  			s.logger.Error(str)
   474  		} else {
   475  			s.logger.Warn(str)
   476  		}
   477  	}
   478  }
   479  
   480  func createExpirationLogMessage(cert CertExpirationData, expired bool) string {
   481  
   482  	var verb string
   483  	if expired {
   484  		verb = "has expired"
   485  	} else {
   486  		verb = "will expire"
   487  	}
   488  	return fmt.Sprintf("certificate with thumbprint=%x %s on %v, IsCA=%t, DNS=%v",
   489  		cert.Thumbprint, verb, cert.Expiration, cert.IsCA, cert.DNSNames)
   490  }
   491  
   492  func mergeMaps(to CertExpirationMap, from CertExpirationMap) {
   493  	for k, v := range from {
   494  		to[k] = v
   495  	}
   496  }
   497  
   498  func isSystemWorker(tls *config.RootTLS) bool {
   499  	return tls.SystemWorker.CertData != "" || tls.SystemWorker.CertFile != "" ||
   500  		len(tls.SystemWorker.Client.RootCAData) > 0 || len(tls.SystemWorker.Client.RootCAFiles) > 0 ||
   501  		tls.SystemWorker.Client.ForceTLS
   502  }