go.temporal.io/server@v1.23.0/common/rpc/encryption/local_store_tls_provider.go (about) 1 // The MIT License 2 // 3 // Copyright (c) 2020 Temporal Technologies Inc. All rights reserved. 4 // 5 // Copyright (c) 2020 Uber Technologies, Inc. 6 // 7 // Permission is hereby granted, free of charge, to any person obtaining a copy 8 // of this software and associated documentation files (the "Software"), to deal 9 // in the Software without restriction, including without limitation the rights 10 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 // copies of the Software, and to permit persons to whom the Software is 12 // furnished to do so, subject to the following conditions: 13 // 14 // The above copyright notice and this permission notice shall be included in 15 // all copies or substantial portions of the Software. 16 // 17 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 23 // THE SOFTWARE. 24 25 package encryption 26 27 import ( 28 "crypto/tls" 29 "crypto/x509" 30 "fmt" 31 "sync" 32 "time" 33 34 "go.temporal.io/server/common/log/tag" 35 "go.temporal.io/server/common/metrics" 36 37 "go.temporal.io/server/common/auth" 38 "go.temporal.io/server/common/config" 39 "go.temporal.io/server/common/log" 40 ) 41 42 type CertProviderFactory func( 43 tlsSettings *config.GroupTLS, 44 workerTlsSettings *config.WorkerTLS, 45 legacyWorkerSettings *config.ClientTLS, 46 refreshInterval time.Duration, 47 logger log.Logger) CertProvider 48 49 type localStoreTlsProvider struct { 50 sync.RWMutex 51 52 settings *config.RootTLS 53 54 internodeCertProvider CertProvider 55 internodeClientCertProvider CertProvider 56 frontendCertProvider CertProvider 57 workerCertProvider CertProvider 58 remoteClusterClientCertProvider map[string]CertProvider 59 frontendPerHostCertProviderMap *localStorePerHostCertProviderMap 60 61 cachedInternodeServerConfig *tls.Config 62 cachedInternodeClientConfig *tls.Config 63 cachedFrontendServerConfig *tls.Config 64 cachedFrontendClientConfig *tls.Config 65 cachedRemoteClusterClientConfig map[string]*tls.Config 66 67 ticker *time.Ticker 68 logger log.Logger 69 stop chan bool 70 metricsHandler metrics.Handler 71 } 72 73 var _ TLSConfigProvider = (*localStoreTlsProvider)(nil) 74 var _ CertExpirationChecker = (*localStoreTlsProvider)(nil) 75 76 func NewLocalStoreTlsProvider(tlsConfig *config.RootTLS, metricsHandler metrics.Handler, logger log.Logger, certProviderFactory CertProviderFactory, 77 ) (TLSConfigProvider, error) { 78 79 internodeProvider := certProviderFactory(&tlsConfig.Internode, nil, nil, tlsConfig.RefreshInterval, logger) 80 var workerProvider CertProvider 81 if isSystemWorker(tlsConfig) { // explicit system worker config 82 workerProvider = certProviderFactory(nil, &tlsConfig.SystemWorker, nil, tlsConfig.RefreshInterval, logger) 83 } else { // legacy implicit system worker config case 84 internodeWorkerProvider := certProviderFactory(&tlsConfig.Internode, nil, &tlsConfig.Frontend.Client, tlsConfig.RefreshInterval, logger) 85 workerProvider = internodeWorkerProvider 86 } 87 88 remoteClusterClientCertProvider := make(map[string]CertProvider) 89 for hostname, groupTLS := range tlsConfig.RemoteClusters { 90 remoteClusterClientCertProvider[hostname] = certProviderFactory(&groupTLS, nil, nil, tlsConfig.RefreshInterval, logger) 91 } 92 93 provider := &localStoreTlsProvider{ 94 internodeCertProvider: internodeProvider, 95 internodeClientCertProvider: internodeProvider, 96 frontendCertProvider: certProviderFactory(&tlsConfig.Frontend, nil, nil, tlsConfig.RefreshInterval, logger), 97 workerCertProvider: workerProvider, 98 frontendPerHostCertProviderMap: newLocalStorePerHostCertProviderMap( 99 tlsConfig.Frontend.PerHostOverrides, certProviderFactory, tlsConfig.RefreshInterval, logger), 100 remoteClusterClientCertProvider: remoteClusterClientCertProvider, 101 RWMutex: sync.RWMutex{}, 102 settings: tlsConfig, 103 metricsHandler: metricsHandler, 104 logger: logger, 105 cachedRemoteClusterClientConfig: make(map[string]*tls.Config), 106 } 107 provider.initialize() 108 return provider, nil 109 } 110 111 func (s *localStoreTlsProvider) initialize() { 112 period := s.settings.ExpirationChecks.CheckInterval 113 if period != 0 { 114 s.stop = make(chan bool) 115 s.ticker = time.NewTicker(period) 116 s.checkCertExpiration() // perform initial check to emit metrics and logs right away 117 go s.timerCallback() 118 } 119 } 120 121 func (s *localStoreTlsProvider) Close() { 122 123 if s.ticker != nil { 124 s.ticker.Stop() 125 } 126 if s.stop != nil { 127 s.stop <- true 128 close(s.stop) 129 } 130 } 131 132 func (s *localStoreTlsProvider) GetInternodeClientConfig() (*tls.Config, error) { 133 134 client := &s.settings.Internode.Client 135 return s.getOrCreateConfig( 136 &s.cachedInternodeClientConfig, 137 func() (*tls.Config, error) { 138 return newClientTLSConfig(s.internodeClientCertProvider, client.ServerName, 139 s.settings.Internode.Server.RequireClientAuth, false, !client.DisableHostVerification) 140 }, 141 s.settings.Internode.IsClientEnabled(), 142 ) 143 } 144 145 func (s *localStoreTlsProvider) GetFrontendClientConfig() (*tls.Config, error) { 146 147 var client *config.ClientTLS 148 var useTLS bool 149 if isSystemWorker(s.settings) { 150 client = &s.settings.SystemWorker.Client 151 useTLS = true 152 } else { 153 client = &s.settings.Frontend.Client 154 useTLS = s.settings.Frontend.IsClientEnabled() 155 } 156 return s.getOrCreateConfig( 157 &s.cachedFrontendClientConfig, 158 func() (*tls.Config, error) { 159 return newClientTLSConfig(s.workerCertProvider, client.ServerName, 160 useTLS, true, !client.DisableHostVerification) 161 }, 162 useTLS, 163 ) 164 } 165 166 func (s *localStoreTlsProvider) GetRemoteClusterClientConfig(hostname string) (*tls.Config, error) { 167 groupTLS, ok := s.settings.RemoteClusters[hostname] 168 if !ok { 169 return nil, nil 170 } 171 172 return s.getOrCreateRemoteClusterClientConfig( 173 hostname, 174 func() (*tls.Config, error) { 175 return newClientTLSConfig( 176 s.remoteClusterClientCertProvider[hostname], 177 groupTLS.Client.ServerName, 178 groupTLS.Server.RequireClientAuth, 179 false, 180 !groupTLS.Client.DisableHostVerification) 181 }, 182 groupTLS.IsClientEnabled(), 183 ) 184 } 185 186 func (s *localStoreTlsProvider) GetFrontendServerConfig() (*tls.Config, error) { 187 return s.getOrCreateConfig( 188 &s.cachedFrontendServerConfig, 189 func() (*tls.Config, error) { 190 return newServerTLSConfig(s.frontendCertProvider, s.frontendPerHostCertProviderMap, &s.settings.Frontend, s.logger) 191 }, 192 s.settings.Frontend.IsServerEnabled()) 193 } 194 195 func (s *localStoreTlsProvider) GetInternodeServerConfig() (*tls.Config, error) { 196 return s.getOrCreateConfig( 197 &s.cachedInternodeServerConfig, 198 func() (*tls.Config, error) { 199 return newServerTLSConfig(s.internodeCertProvider, nil, &s.settings.Internode, s.logger) 200 }, 201 s.settings.Internode.IsServerEnabled()) 202 } 203 204 func (s *localStoreTlsProvider) GetExpiringCerts(timeWindow time.Duration, 205 ) (expiring CertExpirationMap, expired CertExpirationMap, err error) { 206 207 expiring = make(CertExpirationMap, 0) 208 expired = make(CertExpirationMap, 0) 209 210 checkError := checkExpiration(s.internodeCertProvider, timeWindow, expiring, expired) 211 err = appendError(err, checkError) 212 checkError = checkExpiration(s.frontendCertProvider, timeWindow, expiring, expired) 213 err = appendError(err, checkError) 214 checkError = checkExpiration(s.workerCertProvider, timeWindow, expiring, expired) 215 err = appendError(err, checkError) 216 checkError = checkExpiration(s.frontendPerHostCertProviderMap, timeWindow, expiring, expired) 217 err = appendError(err, checkError) 218 219 return expiring, expired, err 220 } 221 222 func checkExpiration( 223 provider CertExpirationChecker, 224 timeWindow time.Duration, 225 expiring CertExpirationMap, 226 expired CertExpirationMap, 227 ) error { 228 229 providerExpiring, providerExpired, err := provider.GetExpiringCerts(timeWindow) 230 mergeMaps(expiring, providerExpiring) 231 mergeMaps(expired, providerExpired) 232 return err 233 } 234 235 func (s *localStoreTlsProvider) getOrCreateConfig( 236 cachedConfig **tls.Config, 237 configConstructor tlsConfigConstructor, 238 isEnabled bool, 239 ) (*tls.Config, error) { 240 if !isEnabled { 241 return nil, nil 242 } 243 244 // Check if exists under a read lock first 245 s.RLock() 246 if *cachedConfig != nil { 247 defer s.RUnlock() 248 return *cachedConfig, nil 249 } 250 // Not found, promote to write lock to initialize 251 s.RUnlock() 252 s.Lock() 253 defer s.Unlock() 254 // Check if someone got here first while waiting for write lock 255 if *cachedConfig != nil { 256 return *cachedConfig, nil 257 } 258 259 // Load configuration 260 localConfig, err := configConstructor() 261 262 if err != nil { 263 return nil, err 264 } 265 266 *cachedConfig = localConfig 267 return *cachedConfig, nil 268 } 269 270 func (s *localStoreTlsProvider) getOrCreateRemoteClusterClientConfig( 271 hostname string, 272 configConstructor tlsConfigConstructor, 273 isEnabled bool, 274 ) (*tls.Config, error) { 275 if !isEnabled { 276 return nil, nil 277 } 278 279 // Check if exists under a read lock first 280 s.RLock() 281 if clientConfig, ok := s.cachedRemoteClusterClientConfig[hostname]; ok { 282 defer s.RUnlock() 283 return clientConfig, nil 284 } 285 // Not found, promote to write lock to initialize 286 s.RUnlock() 287 s.Lock() 288 defer s.Unlock() 289 // Check if someone got here first while waiting for write lock 290 if clientConfig, ok := s.cachedRemoteClusterClientConfig[hostname]; ok { 291 return clientConfig, nil 292 } 293 294 // Load configuration 295 localConfig, err := configConstructor() 296 297 if err != nil { 298 return nil, err 299 } 300 301 s.cachedRemoteClusterClientConfig[hostname] = localConfig 302 return localConfig, nil 303 } 304 305 func newServerTLSConfig( 306 certProvider CertProvider, 307 perHostCertProviderMap PerHostCertProviderMap, 308 config *config.GroupTLS, 309 logger log.Logger, 310 ) (*tls.Config, error) { 311 312 clientAuthRequired := config.Server.RequireClientAuth 313 tlsConfig, err := getServerTLSConfigFromCertProvider(certProvider, clientAuthRequired, "", "", logger) 314 if err != nil { 315 return nil, err 316 } 317 318 tlsConfig.GetConfigForClient = func(c *tls.ClientHelloInfo) (*tls.Config, error) { 319 320 remoteAddress := c.Conn.RemoteAddr().String() 321 logger.Debug("attempted incoming TLS connection", tag.Address(remoteAddress), tag.ServerName(c.ServerName)) 322 323 if perHostCertProviderMap != nil && perHostCertProviderMap.NumberOfHosts() > 0 { 324 perHostCertProvider, hostClientAuthRequired, err := perHostCertProviderMap.GetCertProvider(c.ServerName) 325 if err != nil { 326 logger.Error("error while looking up per-host provider for attempted incoming TLS connection", 327 tag.ServerName(c.ServerName), tag.Address(remoteAddress), tag.Error(err)) 328 return nil, err 329 } 330 331 if perHostCertProvider != nil { 332 return getServerTLSConfigFromCertProvider(perHostCertProvider, hostClientAuthRequired, remoteAddress, c.ServerName, logger) 333 } 334 logger.Warn("cannot find a per-host provider for attempted incoming TLS connection. returning default TLS configuration", 335 tag.ServerName(c.ServerName), tag.Address(remoteAddress)) 336 return getServerTLSConfigFromCertProvider(certProvider, clientAuthRequired, remoteAddress, c.ServerName, logger) 337 } 338 return getServerTLSConfigFromCertProvider(certProvider, clientAuthRequired, remoteAddress, c.ServerName, logger) 339 } 340 341 return tlsConfig, nil 342 } 343 344 func getServerTLSConfigFromCertProvider( 345 certProvider CertProvider, 346 requireClientAuth bool, 347 remoteAddress string, 348 serverName string, 349 logger log.Logger) (*tls.Config, error) { 350 351 // Get serverCert from disk 352 serverCert, err := certProvider.FetchServerCertificate() 353 if err != nil { 354 return nil, fmt.Errorf("loading server tls certificate failed: %v", err) 355 } 356 357 // tls disabled, responsibility of cert provider above to error otherwise 358 if serverCert == nil { 359 return nil, nil 360 } 361 362 // Default to NoClientAuth 363 clientAuthType := tls.NoClientCert 364 var clientCaPool *x509.CertPool 365 366 // If mTLS enabled 367 if requireClientAuth { 368 clientAuthType = tls.RequireAndVerifyClientCert 369 370 ca, err := certProvider.FetchClientCAs() 371 if err != nil { 372 return nil, fmt.Errorf("failed to fetch client CAs: %v", err) 373 } 374 375 clientCaPool = ca 376 } 377 if remoteAddress != "" { // remoteAddress=="" when we return initial tls.Config object when configuring server 378 logger.Debug("returning TLS config for connection", tag.Address(remoteAddress), tag.ServerName(serverName)) 379 } 380 return auth.NewTLSConfigWithCertsAndCAs( 381 clientAuthType, 382 []tls.Certificate{*serverCert}, 383 clientCaPool, 384 logger), nil 385 } 386 387 func newClientTLSConfig( 388 clientProvider CertProvider, 389 serverName string, 390 isAuthRequired bool, 391 isWorker bool, 392 enableHostVerification bool, 393 ) (*tls.Config, error) { 394 // Optional ServerCA for client if not already trusted by host 395 serverCa, err := clientProvider.FetchServerRootCAsForClient(isWorker) 396 if err != nil { 397 return nil, fmt.Errorf("failed to load client ca: %v", err) 398 } 399 400 var getCert tlsCertFetcher 401 402 // mTLS enabled, present certificate 403 if isAuthRequired { 404 getCert = func() (*tls.Certificate, error) { 405 cert, err := clientProvider.FetchClientCertificate(isWorker) 406 if err != nil { 407 return nil, err 408 } 409 410 if cert == nil { 411 return nil, fmt.Errorf("client auth required, but no certificate provided") 412 } 413 return cert, nil 414 } 415 } 416 417 return auth.NewDynamicTLSClientConfig( 418 getCert, 419 serverCa, 420 serverName, 421 enableHostVerification, 422 ), nil 423 } 424 425 func (s *localStoreTlsProvider) timerCallback() { 426 for { 427 select { 428 case <-s.stop: 429 return 430 case <-s.ticker.C: 431 } 432 433 s.checkCertExpiration() 434 } 435 } 436 437 func (s *localStoreTlsProvider) checkCertExpiration() { 438 var retError error 439 defer log.CapturePanic(s.logger, &retError) 440 441 var errorTime time.Time 442 if s.settings.ExpirationChecks.ErrorWindow != 0 { 443 errorTime = time.Now().UTC().Add(s.settings.ExpirationChecks.ErrorWindow) 444 } else { 445 errorTime = time.Now().UTC().AddDate(10, 0, 0) 446 } 447 448 window := s.settings.ExpirationChecks.WarningWindow 449 // if only ErrorWindow is set, we set WarningWindow to the same value, so that the checks do happen 450 if window == 0 && s.settings.ExpirationChecks.ErrorWindow != 0 { 451 window = s.settings.ExpirationChecks.ErrorWindow 452 } 453 if window != 0 { 454 expiring, expired, err := s.GetExpiringCerts(window) 455 if err != nil { 456 s.logger.Error(fmt.Sprintf("error while checking for certificate expiration: %v", err)) 457 return 458 } 459 if s.metricsHandler != nil { 460 s.metricsHandler.Gauge(metrics.TlsCertsExpired.Name()).Record(float64(len(expired))) 461 s.metricsHandler.Gauge(metrics.TlsCertsExpiring.Name()).Record(float64(len(expiring))) 462 } 463 s.logCerts(expired, true, errorTime) 464 s.logCerts(expiring, false, errorTime) 465 } 466 } 467 468 func (s *localStoreTlsProvider) logCerts(certs CertExpirationMap, expired bool, errorTime time.Time) { 469 470 for _, cert := range certs { 471 str := createExpirationLogMessage(cert, expired) 472 if expired || cert.Expiration.Before(errorTime) { 473 s.logger.Error(str) 474 } else { 475 s.logger.Warn(str) 476 } 477 } 478 } 479 480 func createExpirationLogMessage(cert CertExpirationData, expired bool) string { 481 482 var verb string 483 if expired { 484 verb = "has expired" 485 } else { 486 verb = "will expire" 487 } 488 return fmt.Sprintf("certificate with thumbprint=%x %s on %v, IsCA=%t, DNS=%v", 489 cert.Thumbprint, verb, cert.Expiration, cert.IsCA, cert.DNSNames) 490 } 491 492 func mergeMaps(to CertExpirationMap, from CertExpirationMap) { 493 for k, v := range from { 494 to[k] = v 495 } 496 } 497 498 func isSystemWorker(tls *config.RootTLS) bool { 499 return tls.SystemWorker.CertData != "" || tls.SystemWorker.CertFile != "" || 500 len(tls.SystemWorker.Client.RootCAData) > 0 || len(tls.SystemWorker.Client.RootCAFiles) > 0 || 501 tls.SystemWorker.Client.ForceTLS 502 }