vitess.io/vitess@v0.16.2/go/vt/tlstest/tlstest_test.go (about) 1 /* 2 Copyright 2019 The Vitess Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package tlstest 18 19 import ( 20 "crypto/tls" 21 "crypto/x509" 22 "fmt" 23 "io" 24 "net" 25 "strings" 26 "sync" 27 "testing" 28 "time" 29 30 "github.com/stretchr/testify/assert" 31 32 "vitess.io/vitess/go/vt/vttls" 33 ) 34 35 func TestClientServerWithoutCombineCerts(t *testing.T) { 36 testClientServer(t, false) 37 } 38 39 func TestClientServerWithCombineCerts(t *testing.T) { 40 testClientServer(t, true) 41 } 42 43 // testClientServer generates: 44 // - a root CA 45 // - a server intermediate CA, with a server. 46 // - a client intermediate CA, with a client. 47 // And then performs a few tests on them. 48 func testClientServer(t *testing.T, combineCerts bool) { 49 // Our test root. 50 root := t.TempDir() 51 52 clientServerKeyPairs := CreateClientServerCertPairs(root) 53 serverCA := "" 54 55 if combineCerts { 56 serverCA = clientServerKeyPairs.ServerCA 57 } 58 59 serverConfig, err := vttls.ServerConfig( 60 clientServerKeyPairs.ServerCert, 61 clientServerKeyPairs.ServerKey, 62 clientServerKeyPairs.ClientCA, 63 clientServerKeyPairs.ClientCRL, 64 serverCA, 65 tls.VersionTLS12) 66 if err != nil { 67 t.Fatalf("TLSServerConfig failed: %v", err) 68 } 69 clientConfig, err := vttls.ClientConfig( 70 vttls.VerifyIdentity, 71 clientServerKeyPairs.ClientCert, 72 clientServerKeyPairs.ClientKey, 73 clientServerKeyPairs.ServerCA, 74 clientServerKeyPairs.ServerCRL, 75 clientServerKeyPairs.ServerName, 76 tls.VersionTLS12) 77 if err != nil { 78 t.Fatalf("TLSClientConfig failed: %v", err) 79 } 80 81 // Create a TLS server listener. 82 listener, err := tls.Listen("tcp", "127.0.0.1:0", serverConfig) 83 if err != nil { 84 t.Fatalf("Listen failed: %v", err) 85 } 86 addr := listener.Addr().String() 87 defer listener.Close() 88 // create a dialer with timeout 89 dialer := new(net.Dialer) 90 dialer.Timeout = 10 * time.Second 91 92 wg := sync.WaitGroup{} 93 94 // 95 // Positive case: accept on server side, connect a client, send data. 96 // 97 var clientErr error 98 wg.Add(1) 99 go func() { 100 defer wg.Done() 101 clientConn, clientErr := tls.DialWithDialer(dialer, "tcp", addr, clientConfig) 102 if clientErr == nil { 103 _, _ = clientConn.Write([]byte{42}) 104 clientConn.Close() 105 } 106 }() 107 108 serverConn, err := listener.Accept() 109 if err != nil { 110 t.Fatalf("Accept failed: %v", err) 111 } 112 113 result := make([]byte, 1) 114 if n, err := serverConn.Read(result); (err != nil && err != io.EOF) || n != 1 { 115 t.Fatalf("Read failed: %v %v", n, err) 116 } 117 if result[0] != 42 { 118 t.Fatalf("Read returned wrong result: %v", result) 119 } 120 serverConn.Close() 121 122 wg.Wait() 123 124 if clientErr != nil { 125 t.Fatalf("Dial failed: %v", clientErr) 126 } 127 128 // 129 // Negative case: connect a client with wrong cert (using the 130 // server cert on the client side). 131 // 132 133 badClientConfig, err := vttls.ClientConfig( 134 vttls.VerifyIdentity, 135 clientServerKeyPairs.ServerCert, 136 clientServerKeyPairs.ServerKey, 137 clientServerKeyPairs.ServerCA, 138 clientServerKeyPairs.ServerCRL, 139 clientServerKeyPairs.ServerName, 140 tls.VersionTLS12) 141 if err != nil { 142 t.Fatalf("TLSClientConfig failed: %v", err) 143 } 144 145 var serverErr error 146 wg.Add(1) 147 go func() { 148 // We expect the Accept to work, but the first read to fail. 149 defer wg.Done() 150 serverConn, serverErr := listener.Accept() 151 // This will fail. 152 if serverErr == nil { 153 result := make([]byte, 1) 154 if n, err := serverConn.Read(result); err == nil { 155 fmt.Printf("Was able to read from server: %v\n", n) 156 } 157 serverConn.Close() 158 } 159 }() 160 161 // When using TLS 1.2, the Dial will fail. 162 // With TLS 1.3, the Dial will succeed and the first Read will fail. 163 clientConn, err := tls.DialWithDialer(dialer, "tcp", addr, badClientConfig) 164 if err != nil { 165 if !strings.Contains(err.Error(), "bad certificate") { 166 t.Errorf("Wrong error returned: %v", err) 167 } 168 return 169 } 170 wg.Wait() 171 if serverErr != nil { 172 t.Fatalf("Connection failed: %v", serverErr) 173 } 174 175 data := make([]byte, 1) 176 _, err = clientConn.Read(data) 177 if err == nil { 178 t.Fatalf("Dial or first Read was expected to fail") 179 } 180 if !strings.Contains(err.Error(), "bad certificate") { 181 t.Errorf("Wrong error returned: %v", err) 182 } 183 } 184 185 func getServerConfigWithoutCombinedCerts(keypairs ClientServerKeyPairs) (*tls.Config, error) { 186 return vttls.ServerConfig( 187 keypairs.ServerCert, 188 keypairs.ServerKey, 189 keypairs.ClientCA, 190 keypairs.ClientCRL, 191 "", 192 tls.VersionTLS12) 193 } 194 195 func getServerConfigWithCombinedCerts(keypairs ClientServerKeyPairs) (*tls.Config, error) { 196 return vttls.ServerConfig( 197 keypairs.ServerCert, 198 keypairs.ServerKey, 199 keypairs.ClientCA, 200 keypairs.ClientCRL, 201 keypairs.ServerCA, 202 tls.VersionTLS12) 203 } 204 205 func getClientConfig(keypairs ClientServerKeyPairs) (*tls.Config, error) { 206 return vttls.ClientConfig( 207 vttls.VerifyIdentity, 208 keypairs.ClientCert, 209 keypairs.ClientKey, 210 keypairs.ServerCA, 211 keypairs.ServerCRL, 212 keypairs.ServerName, 213 tls.VersionTLS12) 214 } 215 216 func testServerTLSConfigCaching(t *testing.T, getServerConfig func(ClientServerKeyPairs) (*tls.Config, error)) { 217 testConfigGeneration(t, "servertlstest", getServerConfig, func(config *tls.Config) *x509.CertPool { 218 return config.ClientCAs 219 }) 220 } 221 222 func TestServerTLSConfigCachingWithoutCombinedCerts(t *testing.T) { 223 testServerTLSConfigCaching(t, getServerConfigWithoutCombinedCerts) 224 } 225 226 func TestServerTLSConfigCachingWithCombinedCerts(t *testing.T) { 227 testServerTLSConfigCaching(t, getServerConfigWithCombinedCerts) 228 } 229 230 func TestClientTLSConfigCaching(t *testing.T) { 231 testConfigGeneration(t, "clienttlstest", getClientConfig, func(config *tls.Config) *x509.CertPool { 232 return config.RootCAs 233 }) 234 } 235 236 func testConfigGeneration(t *testing.T, rootPrefix string, generateConfig func(ClientServerKeyPairs) (*tls.Config, error), getCertPool func(tlsConfig *tls.Config) *x509.CertPool) { 237 // Our test root. 238 root := t.TempDir() 239 240 const configsToGenerate = 1 241 242 firstClientServerKeyPairs := CreateClientServerCertPairs(root) 243 secondClientServerKeyPairs := CreateClientServerCertPairs(root) 244 245 firstExpectedConfig, _ := generateConfig(firstClientServerKeyPairs) 246 secondExpectedConfig, _ := generateConfig(secondClientServerKeyPairs) 247 firstConfigChannel := make(chan *tls.Config, configsToGenerate) 248 secondConfigChannel := make(chan *tls.Config, configsToGenerate) 249 250 var configCounter = 0 251 252 for i := 1; i <= configsToGenerate; i++ { 253 go func() { 254 firstConfig, _ := generateConfig(firstClientServerKeyPairs) 255 firstConfigChannel <- firstConfig 256 secondConfig, _ := generateConfig(secondClientServerKeyPairs) 257 secondConfigChannel <- secondConfig 258 }() 259 } 260 261 for { 262 select { 263 case firstConfig := <-firstConfigChannel: 264 assert.Equal(t, &firstExpectedConfig.Certificates, &firstConfig.Certificates) 265 assert.Equal(t, getCertPool(firstExpectedConfig), getCertPool(firstConfig)) 266 case secondConfig := <-secondConfigChannel: 267 assert.Equal(t, &secondExpectedConfig.Certificates, &secondConfig.Certificates) 268 assert.Equal(t, getCertPool(secondExpectedConfig), getCertPool(secondConfig)) 269 } 270 configCounter = configCounter + 1 271 272 if configCounter >= 2*configsToGenerate { 273 break 274 } 275 } 276 277 } 278 279 func testNumberOfCertsWithOrWithoutCombining(t *testing.T, numCertsExpected int, combine bool) { 280 // Our test root. 281 root := t.TempDir() 282 283 clientServerKeyPairs := CreateClientServerCertPairs(root) 284 serverCA := "" 285 if combine { 286 serverCA = clientServerKeyPairs.ServerCA 287 } 288 289 serverConfig, err := vttls.ServerConfig( 290 clientServerKeyPairs.ServerCert, 291 clientServerKeyPairs.ServerKey, 292 clientServerKeyPairs.ClientCA, 293 clientServerKeyPairs.ClientCRL, 294 serverCA, 295 tls.VersionTLS12) 296 297 if err != nil { 298 t.Fatalf("TLSServerConfig failed: %v", err) 299 } 300 assert.Equal(t, numCertsExpected, len(serverConfig.Certificates[0].Certificate)) 301 } 302 303 func TestNumberOfCertsWithoutCombining(t *testing.T) { 304 testNumberOfCertsWithOrWithoutCombining(t, 1, false) 305 } 306 307 func TestNumberOfCertsWithCombining(t *testing.T) { 308 testNumberOfCertsWithOrWithoutCombining(t, 2, true) 309 } 310 311 func assertTLSHandshakeFails(t *testing.T, serverConfig, clientConfig *tls.Config) { 312 // Create a TLS server listener. 313 listener, err := tls.Listen("tcp", "127.0.0.1:0", serverConfig) 314 if err != nil { 315 t.Fatalf("Listen failed: %v", err) 316 } 317 addr := listener.Addr().String() 318 defer listener.Close() 319 // create a dialer with timeout 320 dialer := new(net.Dialer) 321 dialer.Timeout = 10 * time.Second 322 323 wg := sync.WaitGroup{} 324 325 var clientErr error 326 wg.Add(1) 327 go func() { 328 defer wg.Done() 329 var clientConn *tls.Conn 330 clientConn, clientErr = tls.DialWithDialer(dialer, "tcp", addr, clientConfig) 331 if clientErr == nil { 332 clientConn.Close() 333 } 334 }() 335 336 serverConn, err := listener.Accept() 337 if err != nil { 338 // We should always be able to accept on the socket 339 t.Fatalf("Accept failed: %v", err) 340 } 341 342 err = serverConn.(*tls.Conn).Handshake() 343 if err != nil { 344 if !(strings.Contains(err.Error(), "Certificate revoked: CommonName=") || 345 strings.Contains(err.Error(), "remote error: tls: bad certificate")) { 346 t.Fatalf("Wrong error returned: %v", err) 347 } 348 } else { 349 t.Fatal("Server should have failed the TLS handshake but it did not") 350 } 351 serverConn.Close() 352 wg.Wait() 353 } 354 355 func TestClientServerWithRevokedServerCert(t *testing.T) { 356 root := t.TempDir() 357 358 clientServerKeyPairs := CreateClientServerCertPairs(root) 359 360 serverConfig, err := vttls.ServerConfig( 361 clientServerKeyPairs.RevokedServerCert, 362 clientServerKeyPairs.RevokedServerKey, 363 clientServerKeyPairs.ClientCA, 364 clientServerKeyPairs.ClientCRL, 365 "", 366 tls.VersionTLS12) 367 if err != nil { 368 t.Fatalf("TLSServerConfig failed: %v", err) 369 } 370 371 clientConfig, err := vttls.ClientConfig( 372 vttls.VerifyIdentity, 373 clientServerKeyPairs.ClientCert, 374 clientServerKeyPairs.ClientKey, 375 clientServerKeyPairs.ServerCA, 376 clientServerKeyPairs.ServerCRL, 377 clientServerKeyPairs.RevokedServerName, 378 tls.VersionTLS12) 379 if err != nil { 380 t.Fatalf("TLSClientConfig failed: %v", err) 381 } 382 383 assertTLSHandshakeFails(t, serverConfig, clientConfig) 384 385 serverConfig, err = vttls.ServerConfig( 386 clientServerKeyPairs.RevokedServerCert, 387 clientServerKeyPairs.RevokedServerKey, 388 clientServerKeyPairs.ClientCA, 389 clientServerKeyPairs.CombinedCRL, 390 "", 391 tls.VersionTLS12) 392 if err != nil { 393 t.Fatalf("TLSServerConfig failed: %v", err) 394 } 395 396 clientConfig, err = vttls.ClientConfig( 397 vttls.VerifyIdentity, 398 clientServerKeyPairs.ClientCert, 399 clientServerKeyPairs.ClientKey, 400 clientServerKeyPairs.ServerCA, 401 clientServerKeyPairs.CombinedCRL, 402 clientServerKeyPairs.RevokedServerName, 403 tls.VersionTLS12) 404 if err != nil { 405 t.Fatalf("TLSClientConfig failed: %v", err) 406 } 407 408 assertTLSHandshakeFails(t, serverConfig, clientConfig) 409 } 410 411 func TestClientServerWithRevokedClientCert(t *testing.T) { 412 root := t.TempDir() 413 414 clientServerKeyPairs := CreateClientServerCertPairs(root) 415 416 // Single CRL 417 418 serverConfig, err := vttls.ServerConfig( 419 clientServerKeyPairs.ServerCert, 420 clientServerKeyPairs.ServerKey, 421 clientServerKeyPairs.ClientCA, 422 clientServerKeyPairs.ClientCRL, 423 "", 424 tls.VersionTLS12) 425 if err != nil { 426 t.Fatalf("TLSServerConfig failed: %v", err) 427 } 428 429 clientConfig, err := vttls.ClientConfig( 430 vttls.VerifyIdentity, 431 clientServerKeyPairs.RevokedClientCert, 432 clientServerKeyPairs.RevokedClientKey, 433 clientServerKeyPairs.ServerCA, 434 clientServerKeyPairs.ServerCRL, 435 clientServerKeyPairs.ServerName, 436 tls.VersionTLS12) 437 if err != nil { 438 t.Fatalf("TLSClientConfig failed: %v", err) 439 } 440 441 assertTLSHandshakeFails(t, serverConfig, clientConfig) 442 443 // CombinedCRL 444 445 serverConfig, err = vttls.ServerConfig( 446 clientServerKeyPairs.ServerCert, 447 clientServerKeyPairs.ServerKey, 448 clientServerKeyPairs.ClientCA, 449 clientServerKeyPairs.CombinedCRL, 450 "", 451 tls.VersionTLS12) 452 if err != nil { 453 t.Fatalf("TLSServerConfig failed: %v", err) 454 } 455 456 clientConfig, err = vttls.ClientConfig( 457 vttls.VerifyIdentity, 458 clientServerKeyPairs.RevokedClientCert, 459 clientServerKeyPairs.RevokedClientKey, 460 clientServerKeyPairs.ServerCA, 461 clientServerKeyPairs.CombinedCRL, 462 clientServerKeyPairs.ServerName, 463 tls.VersionTLS12) 464 if err != nil { 465 t.Fatalf("TLSClientConfig failed: %v", err) 466 } 467 468 assertTLSHandshakeFails(t, serverConfig, clientConfig) 469 }