// The MIT License
//
// Copyright (c) 2020 Temporal Technologies Inc. All rights reserved.
//
// Copyright (c) 2020 Uber Technologies, Inc.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
24 25 package auth 26 27 import ( 28 "crypto/tls" 29 "crypto/x509" 30 "encoding/base64" 31 "fmt" 32 "io" 33 "net/http" 34 "net/http/httptest" 35 "os" 36 "testing" 37 38 "github.com/golang/mock/gomock" 39 "github.com/stretchr/testify/assert" 40 ) 41 42 var validBase64CaData, invalidBase64CaData, validBase64Certificate, invalidBase64Certificate, validBase64Key, invalidBase64Key string 43 44 func readFile(path string) string { 45 file, err := os.Open("testdata/" + path) 46 if err != nil { 47 panic(err) 48 } 49 defer func() { 50 if err := file.Close(); err != nil { 51 panic(err) 52 } 53 }() 54 data, err := io.ReadAll(file) 55 if err != nil { 56 panic(err) 57 } 58 return base64.StdEncoding.EncodeToString(data) 59 } 60 61 func init() { 62 validBase64CaData = readFile("ca.crt") 63 invalidBase64CaData = readFile("invalid_ca.crt") 64 validBase64Certificate = readFile("localhost.crt") 65 invalidBase64Certificate = readFile("invalid_localhost.crt") 66 validBase64Key = readFile("localhost.key") 67 invalidBase64Key = readFile("invalid_localhost.key") 68 } 69 70 // test if the input is valid 71 func Test_NewTLSConfig(t *testing.T) { 72 tests := map[string]struct { 73 cfg *TLS 74 cfgErr string 75 }{ 76 "emptyConfig": { 77 cfg: &TLS{}, 78 }, 79 "caData_good": { 80 cfg: &TLS{ 81 Enabled: true, 82 CaData: validBase64CaData, 83 }, 84 }, 85 "caData_badBase64": { 86 cfg: &TLS{Enabled: true, CaData: "this isn't base64"}, 87 cfgErr: "illegal base64 data at input byte", 88 }, 89 "caData_badPEM": { 90 cfg: &TLS{Enabled: true, CaData: "dGhpcyBpc24ndCBhIFBFTSBjZXJ0"}, 91 cfgErr: "unable to parse certs as PEM", 92 }, 93 "clientCert_badbase64cert": { 94 cfg: &TLS{ 95 Enabled: true, 96 CertData: "this ain't base64", 97 KeyData: validBase64Key, 98 }, 99 cfgErr: "illegal base64 data at input byte", 100 }, 101 "clientCert_badbase64key": { 102 cfg: &TLS{ 103 Enabled: true, 104 CertData: validBase64Certificate, 105 KeyData: "this ain't base64", 106 }, 107 cfgErr: "illegal base64 data at input byte", 
108 }, 109 "clientCert_missingprivatekey": { 110 cfg: &TLS{ 111 Enabled: true, 112 CertData: validBase64Certificate, 113 KeyData: "", 114 }, 115 cfgErr: "unable to config TLS: cert or key is missing", 116 }, 117 "clientCert_duplicate_cert": { 118 cfg: &TLS{ 119 Enabled: true, 120 CertData: validBase64Certificate, 121 CertFile: "/a/b/c", 122 }, 123 cfgErr: "only one of certData or certFile properties should be specified", 124 }, 125 "clientCert_duplicate_key": { 126 cfg: &TLS{ 127 Enabled: true, 128 KeyData: validBase64Key, 129 KeyFile: "/a/b/c", 130 }, 131 cfgErr: "only one of keyData or keyFile properties should be specified", 132 }, 133 "clientCert_duplicate_ca": { 134 cfg: &TLS{ 135 Enabled: true, 136 CaData: validBase64CaData, 137 CaFile: "/a/b/c", 138 }, 139 cfgErr: "only one of caData or caFile properties should be specified", 140 }, 141 } 142 143 for name, tc := range tests { 144 t.Run(name, func(t *testing.T) { 145 ctrl := gomock.NewController(t) 146 _, err := NewTLSConfig(tc.cfg) 147 if tc.cfgErr != "" { 148 assert.ErrorContains(t, err, tc.cfgErr) 149 } else { 150 assert.NoError(t, err) 151 } 152 153 ctrl.Finish() 154 }) 155 } 156 } 157 158 func Test_ConnectToTLSServerWithCA(t *testing.T) { 159 // setup server 160 h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 161 fmt.Fprintln(w, "Hello World") 162 }) 163 ts := httptest.NewUnstartedServer(h) 164 certBytes, err := os.ReadFile("./testdata/localhost.crt") 165 if err != nil { 166 panic(fmt.Errorf("unable to decode certificate %w", err)) 167 } 168 keyBytes, err := os.ReadFile("./testdata/localhost.key") 169 if err != nil { 170 panic(fmt.Errorf("unable to decode key %w", err)) 171 } 172 cert, err := tls.X509KeyPair(certBytes, keyBytes) 173 if err != nil { 174 panic(fmt.Errorf("unable to load certificate %w", err)) 175 } 176 ts.TLS = &tls.Config{ 177 Certificates: []tls.Certificate{cert}, 178 } 179 ts.StartTLS() 180 181 tests := map[string]struct { 182 cfg *TLS 183 connectionErr string 184 
}{ 185 "caData_good": { 186 cfg: &TLS{ 187 Enabled: true, 188 CaData: validBase64CaData, 189 }, 190 }, 191 "caData_signedByWrongCA": { 192 cfg: &TLS{ 193 Enabled: true, 194 EnableHostVerification: true, 195 CaData: invalidBase64CaData, 196 }, 197 connectionErr: "x509: certificate signed by unknown authority", 198 }, 199 "caData_signedByWrongCAButNotEnableHostVerification": { 200 cfg: &TLS{ 201 Enabled: true, 202 EnableHostVerification: false, 203 CaData: invalidBase64CaData, 204 }, 205 }, 206 "caFile_good": { 207 cfg: &TLS{ 208 Enabled: true, 209 EnableHostVerification: true, 210 CaFile: "testdata/ca.crt", 211 }, 212 }, 213 "caFile_signedByWrongCA": { 214 cfg: &TLS{ 215 Enabled: true, 216 EnableHostVerification: true, 217 CaFile: "testdata/invalid_ca.crt", 218 }, 219 connectionErr: "x509: certificate signed by unknown authority", 220 }, 221 "caFile_signedByWrongCANotEnableHostVerification": { 222 cfg: &TLS{ 223 Enabled: true, 224 EnableHostVerification: false, 225 CaFile: "testdata/invalid_ca.crt", 226 }, 227 }, 228 "certData_good": { 229 cfg: &TLS{ 230 Enabled: true, 231 EnableHostVerification: true, 232 CaData: validBase64Certificate, 233 }, 234 }, 235 } 236 237 for name, tc := range tests { 238 t.Run(name, func(t *testing.T) { 239 ctrl := gomock.NewController(t) 240 tlsConfig, err := NewTLSConfig(tc.cfg) 241 if err != nil { 242 panic(err) 243 } 244 cl := &http.Client{Transport: &http.Transport{TLSClientConfig: tlsConfig}} 245 resp, err := cl.Get(ts.URL) 246 if tc.connectionErr != "" { 247 assert.ErrorContains(t, err, tc.connectionErr) 248 } else { 249 assert.NoError(t, err) 250 assert.Equal(t, 200, resp.StatusCode) 251 } 252 253 ctrl.Finish() 254 }) 255 } 256 } 257 258 func Test_ConnectToTLSServerWithClientCertificate(t *testing.T) { 259 // setup server 260 h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 261 fmt.Fprintln(w, "Hello World") 262 }) 263 ts := httptest.NewUnstartedServer(h) 264 certBytes, err := 
os.ReadFile("./testdata/localhost.crt") 265 if err != nil { 266 panic(fmt.Errorf("unable to decode certificate %w", err)) 267 } 268 keyBytes, err := os.ReadFile("./testdata/localhost.key") 269 if err != nil { 270 panic(fmt.Errorf("unable to decode key %w", err)) 271 } 272 cert, err := tls.X509KeyPair(certBytes, keyBytes) 273 if err != nil { 274 panic(fmt.Errorf("unable to load certificate %w", err)) 275 } 276 caBytes, _ := os.ReadFile("testdata/ca.crt") 277 caCertPool := x509.NewCertPool() 278 caCertPool.AppendCertsFromPEM(caBytes) 279 ts.TLS = &tls.Config{ 280 ClientCAs: caCertPool, 281 Certificates: []tls.Certificate{cert}, 282 ClientAuth: tls.RequireAndVerifyClientCert, 283 } 284 ts.StartTLS() 285 286 tests := map[string]struct { 287 cfg *TLS 288 connectionErr string 289 }{ 290 "clientData_good": { 291 cfg: &TLS{ 292 Enabled: true, 293 EnableHostVerification: true, 294 CaData: validBase64CaData, 295 CertData: validBase64Certificate, 296 KeyData: validBase64Key, 297 }, 298 }, 299 "clientData_certNotProvided": { 300 cfg: &TLS{ 301 Enabled: true, 302 EnableHostVerification: true, 303 CaData: validBase64CaData, 304 }, 305 connectionErr: "certificate required", 306 }, 307 "clientData_certInvalid": { 308 cfg: &TLS{ 309 Enabled: true, 310 EnableHostVerification: true, 311 CaData: validBase64CaData, 312 CertData: invalidBase64Certificate, 313 KeyData: invalidBase64Key, 314 }, 315 connectionErr: "certificate required", 316 }, 317 "certFile_good": { 318 cfg: &TLS{ 319 Enabled: true, 320 EnableHostVerification: true, 321 CaData: validBase64CaData, 322 CertFile: "testdata/localhost.crt", 323 KeyFile: "testdata/localhost.key", 324 }, 325 }, 326 "clientFile_certInvalid": { 327 cfg: &TLS{ 328 Enabled: true, 329 EnableHostVerification: true, 330 CaData: validBase64CaData, 331 CertFile: "testdata/invalid_localhost.crt", 332 KeyFile: "testdata/invalid_localhost.key", 333 }, 334 connectionErr: "certificate required", 335 }, 336 } 337 338 for name, tc := range tests { 339 
t.Run(name, func(t *testing.T) { 340 ctrl := gomock.NewController(t) 341 tlsConfig, err := NewTLSConfig(tc.cfg) 342 if err != nil { 343 panic(err) 344 } 345 cl := &http.Client{Transport: &http.Transport{TLSClientConfig: tlsConfig}} 346 resp, err := cl.Get(ts.URL) 347 if tc.connectionErr != "" { 348 assert.ErrorContains(t, err, tc.connectionErr) 349 } else { 350 assert.NoError(t, err) 351 assert.Equal(t, 200, resp.StatusCode) 352 } 353 354 ctrl.Finish() 355 }) 356 } 357 }