go.temporal.io/server@v1.23.0/common/auth/tls_config_helper.go (about)

     1  // The MIT License
     2  //
     3  // Copyright (c) 2020 Temporal Technologies Inc.  All rights reserved.
     4  //
     5  // Copyright (c) 2020 Uber Technologies, Inc.
     6  //
     7  // Permission is hereby granted, free of charge, to any person obtaining a copy
     8  // of this software and associated documentation files (the "Software"), to deal
     9  // in the Software without restriction, including without limitation the rights
    10  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
    11  // copies of the Software, and to permit persons to whom the Software is
    12  // furnished to do so, subject to the following conditions:
    13  //
    14  // The above copyright notice and this permission notice shall be included in
    15  // all copies or substantial portions of the Software.
    16  //
    17  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    18  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    19  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    20  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    21  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    22  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    23  // THE SOFTWARE.
    24  
    25  package auth
    26  
    27  import (
    28  	"crypto/tls"
    29  	"crypto/x509"
    30  	"encoding/base64"
    31  	"encoding/pem"
    32  	"errors"
    33  	"fmt"
    34  	"os"
    35  
    36  	"go.temporal.io/server/common/log"
    37  	"go.temporal.io/server/common/log/tag"
    38  )
    39  
    40  var ErrTLSConfig = errors.New("unable to config TLS")
    41  
// Helper functions for creating tls.Config structs that enforce a minimum protocol version of TLS 1.2.
    43  
    44  func NewEmptyTLSConfig() *tls.Config {
    45  	return &tls.Config{
    46  		MinVersion: tls.VersionTLS12,
    47  		NextProtos: []string{
    48  			"h2",
    49  		},
    50  	}
    51  }
    52  
    53  func NewTLSConfigForServer(
    54  	serverName string,
    55  	enableHostVerification bool,
    56  ) *tls.Config {
    57  	c := NewEmptyTLSConfig()
    58  	c.ServerName = serverName
    59  	c.InsecureSkipVerify = !enableHostVerification
    60  	return c
    61  }
    62  
    63  func NewDynamicTLSClientConfig(
    64  	getCert func() (*tls.Certificate, error),
    65  	rootCAs *x509.CertPool,
    66  	serverName string,
    67  	enableHostVerification bool,
    68  ) *tls.Config {
    69  	c := NewTLSConfigForServer(serverName, enableHostVerification)
    70  
    71  	if getCert != nil {
    72  		c.GetClientCertificate = func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) {
    73  			return getCert()
    74  		}
    75  	}
    76  	c.RootCAs = rootCAs
    77  
    78  	return c
    79  }
    80  
    81  func NewTLSConfigWithCertsAndCAs(
    82  	clientAuth tls.ClientAuthType,
    83  	certificates []tls.Certificate,
    84  	clientCAs *x509.CertPool,
    85  	logger log.Logger,
    86  ) *tls.Config {
    87  	c := NewEmptyTLSConfig()
    88  	c.ClientAuth = clientAuth
    89  	c.Certificates = certificates
    90  	c.ClientCAs = clientCAs
    91  	c.VerifyConnection = func(state tls.ConnectionState) error {
    92  		logger.Debug("successfully established incoming TLS connection", tag.ServerName(state.ServerName), tag.Name(tlsCN(state)))
    93  		return nil
    94  	}
    95  	return c
    96  }
    97  
    98  func tlsCN(state tls.ConnectionState) string {
    99  
   100  	if len(state.PeerCertificates) == 0 {
   101  		return ""
   102  	}
   103  	return state.PeerCertificates[0].Subject.CommonName
   104  }
   105  
   106  func NewTLSConfig(temporalTls *TLS) (*tls.Config, error) {
   107  	if temporalTls == nil || !temporalTls.Enabled {
   108  		return nil, nil
   109  	}
   110  	err := validateTemporalTls(temporalTls)
   111  	if err != nil {
   112  		return nil, err
   113  	}
   114  
   115  	tlsConfig := &tls.Config{
   116  		InsecureSkipVerify: !temporalTls.EnableHostVerification,
   117  	}
   118  	if temporalTls.ServerName != "" {
   119  		tlsConfig.ServerName = temporalTls.ServerName
   120  	}
   121  
   122  	// Load CA cert
   123  	caCertPool, err := parseCAs(temporalTls)
   124  	if err != nil {
   125  		return nil, err
   126  	}
   127  	if caCertPool != nil {
   128  		tlsConfig.RootCAs = caCertPool
   129  	}
   130  
   131  	// Load client cert
   132  	clientCert, err := parseClientCert(temporalTls)
   133  	if err != nil {
   134  		return nil, err
   135  	}
   136  	if clientCert != nil {
   137  		tlsConfig.Certificates = []tls.Certificate{*clientCert}
   138  	}
   139  
   140  	return tlsConfig, nil
   141  }
   142  
   143  func validateTemporalTls(temporalTls *TLS) error {
   144  	if temporalTls.CertData != "" && temporalTls.CertFile != "" {
   145  		return fmt.Errorf("%w: %s", ErrTLSConfig, "only one of certData or certFile properties should be specified")
   146  	}
   147  
   148  	if temporalTls.KeyData != "" && temporalTls.KeyFile != "" {
   149  		return fmt.Errorf("%w: %s", ErrTLSConfig, "only one of keyData or keyFile properties should be specified")
   150  	}
   151  
   152  	certProvided := temporalTls.CertData != "" || temporalTls.CertFile != ""
   153  	keyProvided := temporalTls.KeyData != "" || temporalTls.KeyFile != ""
   154  	if certProvided != keyProvided {
   155  		return fmt.Errorf("%w: %s", ErrTLSConfig, "cert or key is missing")
   156  	}
   157  
   158  	if temporalTls.CaData != "" && temporalTls.CaFile != "" {
   159  		return fmt.Errorf("%w: %s", ErrTLSConfig, "only one of caData or caFile properties should be specified")
   160  	}
   161  	return nil
   162  }
   163  
   164  func parseCAs(temporalTls *TLS) (*x509.CertPool, error) {
   165  	var caBytes []byte
   166  	var err error
   167  	if temporalTls.CaFile != "" {
   168  		caBytes, err = os.ReadFile(temporalTls.CaFile)
   169  		if err != nil {
   170  			return nil, fmt.Errorf("%w: %s (%w)", ErrTLSConfig, "unable to read client ca file", err)
   171  		}
   172  	} else if temporalTls.CaData != "" {
   173  		caBytes, err = base64.StdEncoding.DecodeString(temporalTls.CaData)
   174  		if err != nil {
   175  			return nil, fmt.Errorf("%w: %s (%w)", ErrTLSConfig, "unable to decode client ca data", err)
   176  		}
   177  	}
   178  	if len(caBytes) > 0 {
   179  		caCertPool := x509.NewCertPool()
   180  		caCerts, err := parseCertsFromPEM(caBytes)
   181  		if len(caCerts) == 0 {
   182  			return nil, fmt.Errorf("%w: %s (%w)", ErrTLSConfig, "unable to parse certs as PEM", err)
   183  		}
   184  		for _, cert := range caCerts {
   185  			caCertPool.AddCert(cert)
   186  		}
   187  		if err != nil {
   188  			return nil, fmt.Errorf("%w: %s (%w)", ErrTLSConfig, "unable to load decoded CA Cert as PEM", err)
   189  		}
   190  		return caCertPool, nil
   191  	}
   192  	return nil, nil
   193  }
   194  
   195  func parseCertsFromPEM(pemCerts []byte) ([]*x509.Certificate, error) {
   196  	for len(pemCerts) > 0 {
   197  		var block *pem.Block
   198  		block, pemCerts = pem.Decode(pemCerts)
   199  		if block == nil {
   200  			break
   201  		}
   202  		if block.Type != "CERTIFICATE" || len(block.Headers) != 0 {
   203  			continue
   204  		}
   205  
   206  		certBytes := block.Bytes
   207  		return x509.ParseCertificates(certBytes)
   208  	}
   209  	return nil, nil
   210  }
   211  
   212  func parseClientCert(temporalTls *TLS) (*tls.Certificate, error) {
   213  	var certBytes []byte
   214  	var keyBytes []byte
   215  	var err error
   216  	if temporalTls.CertFile != "" {
   217  		certBytes, err = os.ReadFile(temporalTls.CertFile)
   218  		if err != nil {
   219  			return nil, fmt.Errorf("%w: %s (%w)", ErrTLSConfig, "unable to read client certificate file", err)
   220  		}
   221  	} else if temporalTls.CertData != "" {
   222  		certBytes, err = base64.StdEncoding.DecodeString(temporalTls.CertData)
   223  		if err != nil {
   224  			return nil, fmt.Errorf("%w: %s (%w)", ErrTLSConfig, "unable to decode client certificate", err)
   225  		}
   226  	}
   227  
   228  	if temporalTls.KeyFile != "" {
   229  		keyBytes, err = os.ReadFile(temporalTls.KeyFile)
   230  		if err != nil {
   231  			return nil, fmt.Errorf("%w: %s (%w)", ErrTLSConfig, "unable to read client certificate private key file", err)
   232  		}
   233  	} else if temporalTls.KeyData != "" {
   234  		keyBytes, err = base64.StdEncoding.DecodeString(temporalTls.KeyData)
   235  		if err != nil {
   236  			return nil, fmt.Errorf("%w: %s (%w)", ErrTLSConfig, "unable to decode client certificate private key", err)
   237  		}
   238  	}
   239  
   240  	if len(certBytes) > 0 {
   241  		clientCert, err := tls.X509KeyPair(certBytes, keyBytes)
   242  		if err != nil {
   243  			return nil, fmt.Errorf("%w: %s (%w)", ErrTLSConfig, "unable to generate x509 key pair", err)
   244  		}
   245  
   246  		return &clientCert, nil
   247  	}
   248  	return nil, nil
   249  }