// Source: go.temporal.io/server@v1.23.0/common/auth/tls_config_helper.go

// The MIT License
//
// Copyright (c) 2020 Temporal Technologies Inc. All rights reserved.
//
// Copyright (c) 2020 Uber Technologies, Inc.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
24 25 package auth 26 27 import ( 28 "crypto/tls" 29 "crypto/x509" 30 "encoding/base64" 31 "encoding/pem" 32 "errors" 33 "fmt" 34 "os" 35 36 "go.temporal.io/server/common/log" 37 "go.temporal.io/server/common/log/tag" 38 ) 39 40 var ErrTLSConfig = errors.New("unable to config TLS") 41 42 // Helper methods for creating tls.Config structs to ensure MinVersion is 1.3 43 44 func NewEmptyTLSConfig() *tls.Config { 45 return &tls.Config{ 46 MinVersion: tls.VersionTLS12, 47 NextProtos: []string{ 48 "h2", 49 }, 50 } 51 } 52 53 func NewTLSConfigForServer( 54 serverName string, 55 enableHostVerification bool, 56 ) *tls.Config { 57 c := NewEmptyTLSConfig() 58 c.ServerName = serverName 59 c.InsecureSkipVerify = !enableHostVerification 60 return c 61 } 62 63 func NewDynamicTLSClientConfig( 64 getCert func() (*tls.Certificate, error), 65 rootCAs *x509.CertPool, 66 serverName string, 67 enableHostVerification bool, 68 ) *tls.Config { 69 c := NewTLSConfigForServer(serverName, enableHostVerification) 70 71 if getCert != nil { 72 c.GetClientCertificate = func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) { 73 return getCert() 74 } 75 } 76 c.RootCAs = rootCAs 77 78 return c 79 } 80 81 func NewTLSConfigWithCertsAndCAs( 82 clientAuth tls.ClientAuthType, 83 certificates []tls.Certificate, 84 clientCAs *x509.CertPool, 85 logger log.Logger, 86 ) *tls.Config { 87 c := NewEmptyTLSConfig() 88 c.ClientAuth = clientAuth 89 c.Certificates = certificates 90 c.ClientCAs = clientCAs 91 c.VerifyConnection = func(state tls.ConnectionState) error { 92 logger.Debug("successfully established incoming TLS connection", tag.ServerName(state.ServerName), tag.Name(tlsCN(state))) 93 return nil 94 } 95 return c 96 } 97 98 func tlsCN(state tls.ConnectionState) string { 99 100 if len(state.PeerCertificates) == 0 { 101 return "" 102 } 103 return state.PeerCertificates[0].Subject.CommonName 104 } 105 106 func NewTLSConfig(temporalTls *TLS) (*tls.Config, error) { 107 if temporalTls == nil || 
!temporalTls.Enabled { 108 return nil, nil 109 } 110 err := validateTemporalTls(temporalTls) 111 if err != nil { 112 return nil, err 113 } 114 115 tlsConfig := &tls.Config{ 116 InsecureSkipVerify: !temporalTls.EnableHostVerification, 117 } 118 if temporalTls.ServerName != "" { 119 tlsConfig.ServerName = temporalTls.ServerName 120 } 121 122 // Load CA cert 123 caCertPool, err := parseCAs(temporalTls) 124 if err != nil { 125 return nil, err 126 } 127 if caCertPool != nil { 128 tlsConfig.RootCAs = caCertPool 129 } 130 131 // Load client cert 132 clientCert, err := parseClientCert(temporalTls) 133 if err != nil { 134 return nil, err 135 } 136 if clientCert != nil { 137 tlsConfig.Certificates = []tls.Certificate{*clientCert} 138 } 139 140 return tlsConfig, nil 141 } 142 143 func validateTemporalTls(temporalTls *TLS) error { 144 if temporalTls.CertData != "" && temporalTls.CertFile != "" { 145 return fmt.Errorf("%w: %s", ErrTLSConfig, "only one of certData or certFile properties should be specified") 146 } 147 148 if temporalTls.KeyData != "" && temporalTls.KeyFile != "" { 149 return fmt.Errorf("%w: %s", ErrTLSConfig, "only one of keyData or keyFile properties should be specified") 150 } 151 152 certProvided := temporalTls.CertData != "" || temporalTls.CertFile != "" 153 keyProvided := temporalTls.KeyData != "" || temporalTls.KeyFile != "" 154 if certProvided != keyProvided { 155 return fmt.Errorf("%w: %s", ErrTLSConfig, "cert or key is missing") 156 } 157 158 if temporalTls.CaData != "" && temporalTls.CaFile != "" { 159 return fmt.Errorf("%w: %s", ErrTLSConfig, "only one of caData or caFile properties should be specified") 160 } 161 return nil 162 } 163 164 func parseCAs(temporalTls *TLS) (*x509.CertPool, error) { 165 var caBytes []byte 166 var err error 167 if temporalTls.CaFile != "" { 168 caBytes, err = os.ReadFile(temporalTls.CaFile) 169 if err != nil { 170 return nil, fmt.Errorf("%w: %s (%w)", ErrTLSConfig, "unable to read client ca file", err) 171 } 172 } else if 
temporalTls.CaData != "" { 173 caBytes, err = base64.StdEncoding.DecodeString(temporalTls.CaData) 174 if err != nil { 175 return nil, fmt.Errorf("%w: %s (%w)", ErrTLSConfig, "unable to decode client ca data", err) 176 } 177 } 178 if len(caBytes) > 0 { 179 caCertPool := x509.NewCertPool() 180 caCerts, err := parseCertsFromPEM(caBytes) 181 if len(caCerts) == 0 { 182 return nil, fmt.Errorf("%w: %s (%w)", ErrTLSConfig, "unable to parse certs as PEM", err) 183 } 184 for _, cert := range caCerts { 185 caCertPool.AddCert(cert) 186 } 187 if err != nil { 188 return nil, fmt.Errorf("%w: %s (%w)", ErrTLSConfig, "unable to load decoded CA Cert as PEM", err) 189 } 190 return caCertPool, nil 191 } 192 return nil, nil 193 } 194 195 func parseCertsFromPEM(pemCerts []byte) ([]*x509.Certificate, error) { 196 for len(pemCerts) > 0 { 197 var block *pem.Block 198 block, pemCerts = pem.Decode(pemCerts) 199 if block == nil { 200 break 201 } 202 if block.Type != "CERTIFICATE" || len(block.Headers) != 0 { 203 continue 204 } 205 206 certBytes := block.Bytes 207 return x509.ParseCertificates(certBytes) 208 } 209 return nil, nil 210 } 211 212 func parseClientCert(temporalTls *TLS) (*tls.Certificate, error) { 213 var certBytes []byte 214 var keyBytes []byte 215 var err error 216 if temporalTls.CertFile != "" { 217 certBytes, err = os.ReadFile(temporalTls.CertFile) 218 if err != nil { 219 return nil, fmt.Errorf("%w: %s (%w)", ErrTLSConfig, "unable to read client certificate file", err) 220 } 221 } else if temporalTls.CertData != "" { 222 certBytes, err = base64.StdEncoding.DecodeString(temporalTls.CertData) 223 if err != nil { 224 return nil, fmt.Errorf("%w: %s (%w)", ErrTLSConfig, "unable to decode client certificate", err) 225 } 226 } 227 228 if temporalTls.KeyFile != "" { 229 keyBytes, err = os.ReadFile(temporalTls.KeyFile) 230 if err != nil { 231 return nil, fmt.Errorf("%w: %s (%w)", ErrTLSConfig, "unable to read client certificate private key file", err) 232 } 233 } else if 
temporalTls.KeyData != "" { 234 keyBytes, err = base64.StdEncoding.DecodeString(temporalTls.KeyData) 235 if err != nil { 236 return nil, fmt.Errorf("%w: %s (%w)", ErrTLSConfig, "unable to decode client certificate private key", err) 237 } 238 } 239 240 if len(certBytes) > 0 { 241 clientCert, err := tls.X509KeyPair(certBytes, keyBytes) 242 if err != nil { 243 return nil, fmt.Errorf("%w: %s (%w)", ErrTLSConfig, "unable to generate x509 key pair", err) 244 } 245 246 return &clientCert, nil 247 } 248 return nil, nil 249 }