// go.temporal.io/server@v1.23.0/common/rpc/encryption/local_store_cert_provider.go
//
// The MIT License
//
// Copyright (c) 2020 Temporal Technologies Inc.  All rights reserved.
//
// Copyright (c) 2020 Uber Technologies, Inc.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
24 25 package encryption 26 27 import ( 28 "bytes" 29 "crypto/md5" 30 "crypto/tls" 31 "crypto/x509" 32 "encoding/base64" 33 "encoding/pem" 34 "errors" 35 "fmt" 36 "os" 37 "sync" 38 "time" 39 40 "go.temporal.io/server/common/config" 41 "go.temporal.io/server/common/log" 42 "go.temporal.io/server/common/log/tag" 43 ) 44 45 var _ CertProvider = (*localStoreCertProvider)(nil) 46 var _ CertExpirationChecker = (*localStoreCertProvider)(nil) 47 48 type certCache struct { 49 serverCert *tls.Certificate 50 workerCert *tls.Certificate 51 clientCAPool *x509.CertPool 52 serverCAPool *x509.CertPool 53 serverCAsWorkerPool *x509.CertPool 54 clientCACerts []*x509.Certificate // copies of certs in the clientCAPool CertPool for expiration checks 55 serverCACerts []*x509.Certificate // copies of certs in the serverCAPool CertPool for expiration checks 56 serverCACertsWorker []*x509.Certificate // copies of certs in the serverCAsWorkerPool CertPool for expiration checks 57 } 58 59 type localStoreCertProvider struct { 60 sync.RWMutex 61 62 tlsSettings *config.GroupTLS 63 workerTLSSettings *config.WorkerTLS 64 isLegacyWorkerConfig bool 65 legacyWorkerSettings *config.ClientTLS 66 67 certs *certCache 68 refreshInterval time.Duration 69 70 ticker *time.Ticker 71 stop chan bool 72 logger log.Logger 73 } 74 75 type loadOrDecodeDataFunc func(item string) ([]byte, error) 76 77 type tlsCertFetcher func() (*tls.Certificate, error) 78 79 func (s *localStoreCertProvider) initialize() { 80 81 if s.refreshInterval != 0 { 82 s.stop = make(chan bool) 83 s.ticker = time.NewTicker(s.refreshInterval) 84 go s.refreshCerts() 85 } 86 } 87 88 func NewLocalStoreCertProvider( 89 tlsSettings *config.GroupTLS, 90 workerTlsSettings *config.WorkerTLS, 91 legacyWorkerSettings *config.ClientTLS, 92 refreshInterval time.Duration, 93 logger log.Logger) CertProvider { 94 95 provider := &localStoreCertProvider{ 96 tlsSettings: tlsSettings, 97 workerTLSSettings: workerTlsSettings, 98 legacyWorkerSettings: 
legacyWorkerSettings, 99 isLegacyWorkerConfig: legacyWorkerSettings != nil, 100 logger: logger, 101 refreshInterval: refreshInterval, 102 } 103 provider.initialize() 104 return provider 105 } 106 107 func (s *localStoreCertProvider) Close() { 108 109 if s.ticker != nil { 110 s.ticker.Stop() 111 } 112 if s.stop != nil { 113 s.stop <- true 114 close(s.stop) 115 } 116 } 117 118 func (s *localStoreCertProvider) FetchServerCertificate() (*tls.Certificate, error) { 119 120 if s.tlsSettings == nil { 121 return nil, nil 122 } 123 certs, err := s.getCerts() 124 if err != nil { 125 return nil, err 126 } 127 return certs.serverCert, nil 128 } 129 130 func (s *localStoreCertProvider) FetchClientCAs() (*x509.CertPool, error) { 131 132 if s.tlsSettings == nil { 133 return nil, nil 134 } 135 certs, err := s.getCerts() 136 if err != nil { 137 return nil, err 138 } 139 return certs.clientCAPool, nil 140 } 141 142 func (s *localStoreCertProvider) FetchServerRootCAsForClient(isWorker bool) (*x509.CertPool, error) { 143 144 clientSettings := s.getClientTLSSettings(isWorker) 145 if clientSettings == nil { 146 return nil, nil 147 } 148 certs, err := s.getCerts() 149 if err != nil { 150 return nil, err 151 } 152 153 if isWorker { 154 return certs.serverCAsWorkerPool, nil 155 } 156 157 return certs.serverCAPool, nil 158 } 159 160 func (s *localStoreCertProvider) FetchClientCertificate(isWorker bool) (*tls.Certificate, error) { 161 162 if !s.isTLSEnabled() { 163 return nil, nil 164 } 165 certs, err := s.getCerts() 166 if err != nil { 167 return nil, err 168 } 169 if isWorker { 170 return certs.workerCert, nil 171 } 172 return certs.serverCert, nil 173 } 174 175 func (s *localStoreCertProvider) GetExpiringCerts(timeWindow time.Duration, 176 ) (CertExpirationMap, CertExpirationMap, error) { 177 178 expiring := make(CertExpirationMap) 179 expired := make(CertExpirationMap) 180 when := time.Now().UTC().Add(timeWindow) 181 182 certs, err := s.getCerts() 183 if err != nil { 184 return nil, nil, 
err 185 } 186 187 checkError := checkTLSCertForExpiration(certs.serverCert, when, expiring, expired) 188 err = appendError(err, checkError) 189 checkError = checkTLSCertForExpiration(certs.workerCert, when, expiring, expired) 190 err = appendError(err, checkError) 191 192 checkCertsForExpiration(certs.clientCACerts, when, expiring, expired) 193 checkCertsForExpiration(certs.serverCACerts, when, expiring, expired) 194 checkCertsForExpiration(certs.serverCACertsWorker, when, expiring, expired) 195 196 return expiring, expired, err 197 } 198 199 func (s *localStoreCertProvider) getCerts() (*certCache, error) { 200 201 s.RLock() 202 if s.certs != nil { 203 defer s.RUnlock() 204 return s.certs, nil 205 } 206 s.RUnlock() 207 s.Lock() 208 defer s.Unlock() 209 210 if s.certs != nil { 211 return s.certs, nil 212 } 213 214 newCerts, err := s.loadCerts() 215 if err != nil { 216 return nil, err 217 } 218 219 if newCerts == nil { 220 s.certs = &certCache{} 221 } else { 222 s.certs = newCerts 223 } 224 return s.certs, nil 225 } 226 227 func (s *localStoreCertProvider) loadCerts() (*certCache, error) { 228 229 if !s.isTLSEnabled() { 230 return nil, nil 231 } 232 233 newCerts := certCache{} 234 var err error 235 236 if s.tlsSettings != nil { 237 newCerts.serverCert, err = s.fetchCertificate(s.tlsSettings.Server.CertFile, s.tlsSettings.Server.CertData, 238 s.tlsSettings.Server.KeyFile, s.tlsSettings.Server.KeyData) 239 if err != nil { 240 return nil, err 241 } 242 243 certPool, certs, err := s.fetchCAs(s.tlsSettings.Server.ClientCAFiles, s.tlsSettings.Server.ClientCAData, 244 "cannot specify both clientCAFiles and clientCAData properties") 245 if err != nil { 246 return nil, err 247 } 248 newCerts.clientCAPool = certPool 249 newCerts.clientCACerts = certs 250 } 251 252 if s.isLegacyWorkerConfig { 253 newCerts.workerCert = newCerts.serverCert 254 } else { 255 if s.workerTLSSettings != nil { 256 newCerts.workerCert, err = s.fetchCertificate(s.workerTLSSettings.CertFile, 
s.workerTLSSettings.CertData, 257 s.workerTLSSettings.KeyFile, s.workerTLSSettings.KeyData) 258 if err != nil { 259 return nil, err 260 } 261 } 262 } 263 264 nonWorkerPool, nonWorkerCerts, err := s.loadServerCACerts(false) 265 if err != nil { 266 return nil, err 267 } 268 newCerts.serverCAPool = nonWorkerPool 269 newCerts.serverCACerts = nonWorkerCerts 270 271 workerPool, workerCerts, err := s.loadServerCACerts(true) 272 if err != nil { 273 return nil, err 274 } 275 newCerts.serverCAsWorkerPool = workerPool 276 newCerts.serverCACertsWorker = workerCerts 277 278 return &newCerts, nil 279 } 280 281 func (s *localStoreCertProvider) fetchCertificate( 282 certFile string, certData string, 283 keyFile string, keyData string) (*tls.Certificate, error) { 284 if certFile == "" && certData == "" { 285 return nil, nil 286 } 287 288 if certFile != "" && certData != "" { 289 return nil, errors.New("only one of certFile or certData properties should be spcified") 290 } 291 292 var certBytes []byte 293 var keyBytes []byte 294 var err error 295 296 if certFile != "" { 297 s.logger.Info("loading certificate from file", tag.TLSCertFile(certFile)) 298 certBytes, err = os.ReadFile(certFile) 299 if err != nil { 300 return nil, err 301 } 302 } else if certData != "" { 303 certBytes, err = base64.StdEncoding.DecodeString(certData) 304 if err != nil { 305 return nil, fmt.Errorf("TLS public certificate could not be decoded: %w", err) 306 } 307 } 308 309 if keyFile != "" { 310 s.logger.Info("loading private key from file", tag.TLSKeyFile(keyFile)) 311 keyBytes, err = os.ReadFile(keyFile) 312 if err != nil { 313 return nil, err 314 } 315 } else if keyData != "" { 316 keyBytes, err = base64.StdEncoding.DecodeString(keyData) 317 if err != nil { 318 return nil, fmt.Errorf("TLS private key could not be decoded: %w", err) 319 } 320 } 321 322 cert, err := tls.X509KeyPair(certBytes, keyBytes) 323 if err != nil { 324 return nil, fmt.Errorf("loading tls certificate failed: %v", err) 325 } 326 327 
return &cert, nil 328 } 329 330 func (s *localStoreCertProvider) getClientTLSSettings(isWorker bool) *config.ClientTLS { 331 if isWorker && s.workerTLSSettings != nil { 332 return &s.workerTLSSettings.Client // explicit system worker case 333 } else if isWorker { 334 return s.legacyWorkerSettings // legacy config case when we use Frontend.Client settings 335 } else { 336 if s.tlsSettings == nil { 337 return nil 338 } 339 return &s.tlsSettings.Client // internode client case 340 } 341 } 342 343 func (s *localStoreCertProvider) loadServerCACerts(isWorker bool) (*x509.CertPool, []*x509.Certificate, error) { 344 345 clientSettings := s.getClientTLSSettings(isWorker) 346 if clientSettings == nil { 347 return nil, nil, nil 348 } 349 350 return s.fetchCAs(clientSettings.RootCAFiles, clientSettings.RootCAData, 351 "cannot specify both rootCAFiles and rootCAData properties") 352 } 353 354 func (s *localStoreCertProvider) fetchCAs( 355 files []string, 356 data []string, 357 duplicateErrorMessage string) (*x509.CertPool, []*x509.Certificate, error) { 358 if len(files) == 0 && len(data) == 0 { 359 return nil, nil, nil 360 } 361 362 caPoolFromFiles, caCertsFromFiles, err := s.buildCAPoolFromFiles(files) 363 if err != nil { 364 return nil, nil, err 365 } 366 367 caPoolFromData, caCertsFromData, err := buildCAPoolFromData(data) 368 if err != nil { 369 return nil, nil, err 370 } 371 372 if caPoolFromFiles != nil && caPoolFromData != nil { 373 return nil, nil, errors.New(duplicateErrorMessage) 374 } 375 376 var certPool *x509.CertPool 377 var certs []*x509.Certificate 378 379 if caPoolFromData != nil { 380 certPool = caPoolFromData 381 certs = caCertsFromData 382 } else { 383 certPool = caPoolFromFiles 384 certs = caCertsFromFiles 385 } 386 387 return certPool, certs, nil 388 } 389 390 func checkTLSCertForExpiration( 391 cert *tls.Certificate, 392 when time.Time, 393 expiring CertExpirationMap, 394 expired CertExpirationMap, 395 ) error { 396 397 if cert == nil { 398 return nil 399 
} 400 401 x509cert, err := x509.ParseCertificate(cert.Certificate[0]) 402 if err != nil { 403 return err 404 } 405 checkCertForExpiration(x509cert, when, expiring, expired) 406 return nil 407 } 408 409 func checkCertsForExpiration( 410 certs []*x509.Certificate, 411 time time.Time, 412 expiring CertExpirationMap, 413 expired CertExpirationMap, 414 ) { 415 416 for _, cert := range certs { 417 checkCertForExpiration(cert, time, expiring, expired) 418 } 419 } 420 421 func checkCertForExpiration( 422 cert *x509.Certificate, 423 pointInTime time.Time, 424 expiring CertExpirationMap, 425 expired CertExpirationMap, 426 ) { 427 428 if cert != nil && expiresBefore(cert, pointInTime) { 429 record := CertExpirationData{ 430 Thumbprint: md5.Sum(cert.Raw), 431 IsCA: cert.IsCA, 432 DNSNames: cert.DNSNames, 433 Expiration: cert.NotAfter, 434 } 435 if record.Expiration.Before(time.Now().UTC()) { 436 expired[record.Thumbprint] = record 437 } else { 438 expiring[record.Thumbprint] = record 439 } 440 } 441 } 442 443 func expiresBefore(cert *x509.Certificate, pointInTime time.Time) bool { 444 return cert.NotAfter.Before(pointInTime) 445 } 446 447 func buildCAPoolFromData(caData []string) (*x509.CertPool, []*x509.Certificate, error) { 448 449 return buildCAPool(caData, base64.StdEncoding.DecodeString) 450 } 451 452 func (s *localStoreCertProvider) buildCAPoolFromFiles(caFiles []string) (*x509.CertPool, []*x509.Certificate, error) { 453 if len(caFiles) == 0 { 454 return nil, nil, nil 455 } 456 457 s.logger.Info("loading CA certs from", tag.TLSCertFiles(caFiles)) 458 return buildCAPool(caFiles, os.ReadFile) 459 } 460 461 func buildCAPool(cas []string, getBytes loadOrDecodeDataFunc) (*x509.CertPool, []*x509.Certificate, error) { 462 463 var caPool *x509.CertPool 464 var certs []*x509.Certificate 465 466 for _, ca := range cas { 467 if ca == "" { 468 continue 469 } 470 471 caBytes, err := getBytes(ca) 472 if err != nil { 473 return nil, nil, fmt.Errorf("failed to decode ca cert: %w", err) 
474 } 475 476 if caPool == nil { 477 caPool = x509.NewCertPool() 478 } 479 if !caPool.AppendCertsFromPEM(caBytes) { 480 return nil, nil, errors.New("unknown failure constructing cert pool for ca") 481 } 482 483 cert, err := parseCert(caBytes) 484 if err != nil { 485 return nil, nil, fmt.Errorf("failed to parse x509 certificate: %w", err) 486 } 487 certs = append(certs, cert) 488 } 489 return caPool, certs, nil 490 } 491 492 // logic borrowed from tls.X509KeyPair() 493 func parseCert(bytes []byte) (*x509.Certificate, error) { 494 495 var certBytes [][]byte 496 for { 497 var certDERBlock *pem.Block 498 certDERBlock, bytes = pem.Decode(bytes) 499 if certDERBlock == nil { 500 break 501 } 502 if certDERBlock.Type == "CERTIFICATE" { 503 certBytes = append(certBytes, certDERBlock.Bytes) 504 } 505 } 506 507 if len(certBytes) == 0 || len(certBytes[0]) == 0 { 508 return nil, fmt.Errorf("failed to decode PEM certificate data") 509 } 510 return x509.ParseCertificate(certBytes[0]) 511 } 512 513 func appendError(aggregatedErr error, err error) error { 514 if aggregatedErr == nil { 515 return err 516 } 517 if err == nil { 518 return aggregatedErr 519 } 520 return fmt.Errorf("%v, %w", aggregatedErr, err) 521 } 522 523 func (s *localStoreCertProvider) refreshCerts() { 524 525 for { 526 select { 527 case <-s.stop: 528 return 529 case <-s.ticker.C: 530 } 531 532 newCerts, err := s.loadCerts() 533 if err != nil { 534 s.logger.Error("failed to load certificates", tag.Error(err)) 535 continue 536 } 537 538 s.RLock() 539 currentCerts := s.certs 540 s.RUnlock() 541 if currentCerts.isEqual(newCerts) { 542 continue 543 } 544 545 s.logger.Info("loaded new TLS certificates") 546 s.Lock() 547 s.certs = newCerts 548 s.Unlock() 549 } 550 } 551 552 func (s *localStoreCertProvider) isTLSEnabled() bool { 553 return s.tlsSettings != nil || s.workerTLSSettings != nil 554 } 555 556 func (c *certCache) isEqual(other *certCache) bool { 557 558 if c == other { 559 return true 560 } 561 if c == nil || 
other == nil { 562 return false 563 } 564 565 if !equalTLSCerts(c.serverCert, other.serverCert) || 566 !equalTLSCerts(c.workerCert, other.workerCert) || 567 !equalX509(c.clientCACerts, other.clientCACerts) || 568 !equalX509(c.serverCACerts, other.serverCACerts) || 569 !equalX509(c.serverCACertsWorker, other.serverCACertsWorker) { 570 return false 571 } 572 return true 573 } 574 575 func equal(a, b [][]byte) bool { 576 if len(a) != len(b) { 577 return false 578 } 579 for i := range a { 580 if !bytes.Equal(a[i], b[i]) { 581 return false 582 } 583 } 584 return true 585 } 586 587 func equalX509(a, b []*x509.Certificate) bool { 588 if len(a) != len(b) { 589 return false 590 } 591 for i := range a { 592 if !a[i].Equal(b[i]) { 593 return false 594 } 595 } 596 return true 597 } 598 599 func equalTLSCerts(a, b *tls.Certificate) bool { 600 if a != nil { 601 if b == nil || !equal(a.Certificate, b.Certificate) { 602 return false 603 } 604 } else { 605 if b != nil { 606 return false 607 } 608 } 609 return true 610 }