github.com/3andne/restls-client-go@v0.1.6/tls_test.go (about) 1 // Copyright 2012 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package tls 6 7 import ( 8 "bytes" 9 "context" 10 "crypto" 11 "crypto/x509" 12 "encoding/json" 13 "errors" 14 "fmt" 15 "io" 16 "math" 17 "net" 18 "os" 19 "reflect" 20 "sort" 21 "strings" 22 "sync/atomic" 23 "testing" 24 "time" 25 26 "github.com/3andne/restls-client-go/testenv" 27 ) 28 29 var rsaCertPEM = `-----BEGIN CERTIFICATE----- 30 MIIB0zCCAX2gAwIBAgIJAI/M7BYjwB+uMA0GCSqGSIb3DQEBBQUAMEUxCzAJBgNV 31 BAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYDVQQKDBhJbnRlcm5ldCBX 32 aWRnaXRzIFB0eSBMdGQwHhcNMTIwOTEyMjE1MjAyWhcNMTUwOTEyMjE1MjAyWjBF 33 MQswCQYDVQQGEwJBVTETMBEGA1UECAwKU29tZS1TdGF0ZTEhMB8GA1UECgwYSW50 34 ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMFwwDQYJKoZIhvcNAQEBBQADSwAwSAJBANLJ 35 hPHhITqQbPklG3ibCVxwGMRfp/v4XqhfdQHdcVfHap6NQ5Wok/4xIA+ui35/MmNa 36 rtNuC+BdZ1tMuVCPFZcCAwEAAaNQME4wHQYDVR0OBBYEFJvKs8RfJaXTH08W+SGv 37 zQyKn0H8MB8GA1UdIwQYMBaAFJvKs8RfJaXTH08W+SGvzQyKn0H8MAwGA1UdEwQF 38 MAMBAf8wDQYJKoZIhvcNAQEFBQADQQBJlffJHybjDGxRMqaRmDhX0+6v02TUKZsW 39 r5QuVbpQhH6u+0UgcW0jp9QwpxoPTLTWGXEWBBBurxFwiCBhkQ+V 40 -----END CERTIFICATE----- 41 ` 42 43 var rsaKeyPEM = testingKey(`-----BEGIN RSA TESTING KEY----- 44 MIIBOwIBAAJBANLJhPHhITqQbPklG3ibCVxwGMRfp/v4XqhfdQHdcVfHap6NQ5Wo 45 k/4xIA+ui35/MmNartNuC+BdZ1tMuVCPFZcCAwEAAQJAEJ2N+zsR0Xn8/Q6twa4G 46 6OB1M1WO+k+ztnX/1SvNeWu8D6GImtupLTYgjZcHufykj09jiHmjHx8u8ZZB/o1N 47 MQIhAPW+eyZo7ay3lMz1V01WVjNKK9QSn1MJlb06h/LuYv9FAiEA25WPedKgVyCW 48 SmUwbPw8fnTcpqDWE3yTO3vKcebqMSsCIBF3UmVue8YU3jybC3NxuXq3wNm34R8T 49 xVLHwDXh/6NJAiEAl2oHGGLz64BuAfjKrqwz7qMYr9HCLIe/YsoWq/olzScCIQDi 50 D2lWusoe2/nEqfDVVWGWlyJ7yOmqaVm/iNUN9B2N2g== 51 -----END RSA TESTING KEY----- 52 `) 53 54 // keyPEM is the same as rsaKeyPEM, but declares itself as just 55 // "PRIVATE KEY", not "RSA PRIVATE KEY". 
https://golang.org/issue/4477 56 var keyPEM = testingKey(`-----BEGIN TESTING KEY----- 57 MIIBOwIBAAJBANLJhPHhITqQbPklG3ibCVxwGMRfp/v4XqhfdQHdcVfHap6NQ5Wo 58 k/4xIA+ui35/MmNartNuC+BdZ1tMuVCPFZcCAwEAAQJAEJ2N+zsR0Xn8/Q6twa4G 59 6OB1M1WO+k+ztnX/1SvNeWu8D6GImtupLTYgjZcHufykj09jiHmjHx8u8ZZB/o1N 60 MQIhAPW+eyZo7ay3lMz1V01WVjNKK9QSn1MJlb06h/LuYv9FAiEA25WPedKgVyCW 61 SmUwbPw8fnTcpqDWE3yTO3vKcebqMSsCIBF3UmVue8YU3jybC3NxuXq3wNm34R8T 62 xVLHwDXh/6NJAiEAl2oHGGLz64BuAfjKrqwz7qMYr9HCLIe/YsoWq/olzScCIQDi 63 D2lWusoe2/nEqfDVVWGWlyJ7yOmqaVm/iNUN9B2N2g== 64 -----END TESTING KEY----- 65 `) 66 67 var ecdsaCertPEM = `-----BEGIN CERTIFICATE----- 68 MIIB/jCCAWICCQDscdUxw16XFDAJBgcqhkjOPQQBMEUxCzAJBgNVBAYTAkFVMRMw 69 EQYDVQQIEwpTb21lLVN0YXRlMSEwHwYDVQQKExhJbnRlcm5ldCBXaWRnaXRzIFB0 70 eSBMdGQwHhcNMTIxMTE0MTI0MDQ4WhcNMTUxMTE0MTI0MDQ4WjBFMQswCQYDVQQG 71 EwJBVTETMBEGA1UECBMKU29tZS1TdGF0ZTEhMB8GA1UEChMYSW50ZXJuZXQgV2lk 72 Z2l0cyBQdHkgTHRkMIGbMBAGByqGSM49AgEGBSuBBAAjA4GGAAQBY9+my9OoeSUR 73 lDQdV/x8LsOuLilthhiS1Tz4aGDHIPwC1mlvnf7fg5lecYpMCrLLhauAc1UJXcgl 74 01xoLuzgtAEAgv2P/jgytzRSpUYvgLBt1UA0leLYBy6mQQbrNEuqT3INapKIcUv8 75 XxYP0xMEUksLPq6Ca+CRSqTtrd/23uTnapkwCQYHKoZIzj0EAQOBigAwgYYCQXJo 76 A7Sl2nLVf+4Iu/tAX/IF4MavARKC4PPHK3zfuGfPR3oCCcsAoz3kAzOeijvd0iXb 77 H5jBImIxPL4WxQNiBTexAkF8D1EtpYuWdlVQ80/h/f4pBcGiXPqX5h2PQSQY7hP1 78 +jwM1FGS4fREIOvlBYr/SzzQRtwrvrzGYxDEDbsC0ZGRnA== 79 -----END CERTIFICATE----- 80 ` 81 82 var ecdsaKeyPEM = testingKey(`-----BEGIN EC PARAMETERS----- 83 BgUrgQQAIw== 84 -----END EC PARAMETERS----- 85 -----BEGIN EC TESTING KEY----- 86 MIHcAgEBBEIBrsoKp0oqcv6/JovJJDoDVSGWdirrkgCWxrprGlzB9o0X8fV675X0 87 NwuBenXFfeZvVcwluO7/Q9wkYoPd/t3jGImgBwYFK4EEACOhgYkDgYYABAFj36bL 88 06h5JRGUNB1X/Hwuw64uKW2GGJLVPPhoYMcg/ALWaW+d/t+DmV5xikwKssuFq4Bz 89 VQldyCXTXGgu7OC0AQCC/Y/+ODK3NFKlRi+AsG3VQDSV4tgHLqZBBus0S6pPcg1q 90 kohxS/xfFg/TEwRSSws+roJr4JFKpO2t3/be5OdqmQ== 91 -----END EC TESTING KEY----- 92 `) 93 94 var keyPairTests = []struct { 95 algo string 96 cert string 97 key string 98 }{ 99 
{"ECDSA", ecdsaCertPEM, ecdsaKeyPEM},
	{"RSA", rsaCertPEM, rsaKeyPEM},
	{"RSA-untyped", rsaCertPEM, keyPEM}, // golang.org/issue/4477
}

// TestX509KeyPair checks that X509KeyPair loads every keyPairTests entry when
// the certificate and key are concatenated in either order.
func TestX509KeyPair(t *testing.T) {
	t.Parallel()
	for _, test := range keyPairTests {
		pem := []byte(test.cert + test.key)
		if _, err := X509KeyPair(pem, pem); err != nil {
			t.Errorf("Failed to load %s cert followed by %s key: %s", test.algo, test.algo, err)
		}
		pem = []byte(test.key + test.cert)
		if _, err := X509KeyPair(pem, pem); err != nil {
			t.Errorf("Failed to load %s key followed by %s cert: %s", test.algo, test.algo, err)
		}
	}
}

// TestX509KeyPairErrors checks that X509KeyPair's error text points at the
// likely mistake: swapped arguments, two certificates, or an unknown PEM type.
func TestX509KeyPairErrors(t *testing.T) {
	_, err := X509KeyPair([]byte(rsaKeyPEM), []byte(rsaCertPEM))
	if err == nil {
		t.Fatalf("X509KeyPair didn't return an error when arguments were switched")
	}
	if subStr := "been switched"; !strings.Contains(err.Error(), subStr) {
		t.Fatalf("Expected %q in the error when switching arguments to X509KeyPair, but the error was %q", subStr, err)
	}

	_, err = X509KeyPair([]byte(rsaCertPEM), []byte(rsaCertPEM))
	if err == nil {
		t.Fatalf("X509KeyPair didn't return an error when both arguments were certificates")
	}
	if subStr := "certificate"; !strings.Contains(err.Error(), subStr) {
		t.Fatalf("Expected %q in the error when both arguments to X509KeyPair were certificates, but the error was %q", subStr, err)
	}

	const nonsensePEM = `
-----BEGIN NONSENSE-----
Zm9vZm9vZm9v
-----END NONSENSE-----
`

	_, err = X509KeyPair([]byte(nonsensePEM), []byte(nonsensePEM))
	if err == nil {
		t.Fatalf("X509KeyPair didn't return an error when both arguments were nonsense")
	}
	if subStr := "NONSENSE"; !strings.Contains(err.Error(), subStr) {
		t.Fatalf("Expected %q in the error when both arguments to X509KeyPair were nonsense, but the error was %q", subStr, err)
	}
}

// TestX509MixedKeyPair checks that a certificate is rejected when paired with
// a private key of a different algorithm.
func TestX509MixedKeyPair(t *testing.T) {
	if _, err := X509KeyPair([]byte(rsaCertPEM), []byte(ecdsaKeyPEM)); err == nil {
		t.Error("Load of RSA certificate succeeded with ECDSA private key")
	}
	if _, err := X509KeyPair([]byte(ecdsaCertPEM), []byte(rsaKeyPEM)); err == nil {
		t.Error("Load of ECDSA certificate succeeded with RSA private key")
	}
}

// newLocalListener returns a TCP listener on an ephemeral loopback port,
// trying IPv4 first and falling back to IPv6; it fails t if neither works.
func newLocalListener(t testing.TB) net.Listener {
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		ln, err = net.Listen("tcp6", "[::1]:0")
	}
	if err != nil {
		t.Fatal(err)
	}
	return ln
}

// TestDialTimeout checks that DialWithDialer reports a timeout error when the
// dialer's Timeout expires during the handshake. The listener only accepts and
// never writes, so the handshake cannot complete. The timeout doubles on each
// retry until the listener has accepted the connection before the deadline.
func TestDialTimeout(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping in short mode")
	}

	timeout := 100 * time.Microsecond
	for !t.Failed() {
		acceptc := make(chan net.Conn)
		listener := newLocalListener(t)
		go func() {
			for {
				conn, err := listener.Accept()
				if err != nil {
					close(acceptc)
					return
				}
				acceptc <- conn
			}
		}()

		addr := listener.Addr().String()
		dialer := &net.Dialer{
			Timeout: timeout,
		}
		if conn, err := DialWithDialer(dialer, "tcp", addr, nil); err == nil {
			conn.Close()
			t.Errorf("DialWithDialer unexpectedly completed successfully")
		} else if !isTimeoutError(err) {
			t.Errorf("resulting error not a timeout: %v\nType %T: %#v", err, err, err)
		}

		listener.Close()

		// We're looking for a timeout during the handshake, so check that the
		// Listener actually accepted the connection to initiate it. (If the server
		// takes too long to accept the connection, we might cancel before the
		// underlying net.Conn is ever dialed — without ever attempting a
		// handshake.)
		lconn, ok := <-acceptc
		if ok {
			// The Listener accepted a connection, so assume that it was from our
			// Dial: we triggered the timeout at the point where we wanted it!
			t.Logf("Listener accepted a connection from %s", lconn.RemoteAddr())
			lconn.Close()
		}
		// Close any spurious extra connections from the listener. (This is
		// possible if there are, for example, stray Dial calls from other tests.)
		for extraConn := range acceptc {
			t.Logf("spurious extra connection from %s", extraConn.RemoteAddr())
			extraConn.Close()
		}
		if ok {
			break
		}

		t.Logf("with timeout %v, DialWithDialer returned before listener accepted any connections; retrying", timeout)
		timeout *= 2
	}
}

// TestDeadlineOnWrite checks that an expired deadline makes Write fail with a
// non-temporary timeout error, and that clearing the deadline afterwards does
// not make the connection usable again.
func TestDeadlineOnWrite(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping in short mode")
	}

	ln := newLocalListener(t)
	defer ln.Close()

	srvCh := make(chan *Conn, 1)

	go func() {
		sconn, err := ln.Accept()
		if err != nil {
			srvCh <- nil
			return
		}
		srv := Server(sconn, testConfig.Clone())
		if err := srv.Handshake(); err != nil {
			sconn.Close()
			srvCh <- nil
			return
		}
		srvCh <- srv
	}()

	clientConfig := testConfig.Clone()
	clientConfig.MaxVersion = VersionTLS12
	conn, err := Dial("tcp", ln.Addr().String(), clientConfig)
	if err != nil {
		t.Fatal(err)
	}
	defer conn.Close()

	srv := <-srvCh
	if srv == nil {
		// Every step below dereferences srv, so stop here instead of panicking.
		t.Fatal("server failed to accept or complete the handshake")
	}
	defer srv.Close()

	// Make sure the client/server is setup correctly and is able to do a typical Write/Read
	buf := make([]byte, 6)
	if _, err := srv.Write([]byte("foobar")); err != nil {
		t.Errorf("Write err: %v", err)
	}
	if n, err := conn.Read(buf); n != 6 || err != nil || string(buf) != "foobar" {
		t.Errorf("Read = %d, %v, data %q; want 6, nil, foobar", n, err, buf)
	}

	// Set a deadline which should cause Write to timeout
	if err = srv.SetDeadline(time.Now()); err != nil {
		t.Fatalf("SetDeadline(time.Now()) err: %v", err)
	}
	if _, err = srv.Write([]byte("should fail")); err == nil {
		t.Fatal("Write should have timed out")
	}

	// Clear deadline and make sure it still times out
	if err = srv.SetDeadline(time.Time{}); err != nil {
		t.Fatalf("SetDeadline(time.Time{}) err: %v", err)
	}
	if _, err = srv.Write([]byte("This connection is permanently broken")); err == nil {
		t.Fatal("Write which previously failed should still time out")
	}

	// Verify the error. A checked extraction avoids a panic if the error is
	// not (or does not wrap) a net.Error.
	var ne net.Error
	if !errors.As(err, &ne) {
		t.Fatalf("Write error %v (%T) is not a net.Error", err, err)
	}
	//lint:ignore SA1019 Temporary is deprecated, but its value is part of what this test pins down.
	if ne.Temporary() {
		t.Error("Write timed out but incorrectly classified the error as Temporary")
	}
	if !isTimeoutError(err) {
		t.Error("Write timed out but did not classify the error as a Timeout")
	}
}

// readerFunc adapts a plain function to the io.Reader interface.
type readerFunc func([]byte) (int, error)

func (f readerFunc) Read(b []byte) (int, error) { return f(b) }

// TestDialer tests that tls.Dialer.DialContext can abort in the middle of a handshake.
// (The other cases are all handled by the existing dial tests in this package,
// which all flow through the same shared code paths.)
func TestDialer(t *testing.T) {
	ln := newLocalListener(t)
	defer ln.Close()

	unblockServer := make(chan struct{}) // close-only
	defer close(unblockServer)
	go func() {
		conn, err := ln.Accept()
		if err != nil {
			return
		}
		defer conn.Close()
		<-unblockServer
	}()

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel() // also released on paths where Rand is never consulted
	d := Dialer{Config: &Config{
		Rand: readerFunc(func(b []byte) (n int, err error) {
			// By the time crypto/tls wants randomness, that means it has a TCP
			// connection, so we're past the Dialer's dial and now blocked
			// in a handshake. Cancel our context and see if we get unstuck.
			// (Our TCP listener above never reads or writes, so the Handshake
			// would otherwise be stuck forever)
			cancel()
			return len(b), nil
		}),
		ServerName: "foo",
	}}
	_, err := d.DialContext(ctx, "tcp", ln.Addr().String())
	if !errors.Is(err, context.Canceled) {
		t.Errorf("err = %v; want context.Canceled", err)
	}
}

// isTimeoutError reports whether err, or any error it wraps, is a net.Error
// whose Timeout method returns true.
func isTimeoutError(err error) bool {
	var ne net.Error
	return errors.As(err, &ne) && ne.Timeout()
}

// TestConnReadNonzeroAndEOF tests that Conn.Read returns (non-zero, io.EOF)
// instead of (non-zero, nil) when a Close (alertCloseNotify) is sitting right
// behind the application data in the buffer.
func TestConnReadNonzeroAndEOF(t *testing.T) {
	// This test is racy: it assumes that after a write to a
	// localhost TCP connection, the peer TCP connection can
	// immediately read it. Because it's racy, we skip this test
	// in short mode, and then retry it several times with an
	// increasing sleep in between our final write (via srv.Close
	// below) and the following read.
	if testing.Short() {
		t.Skip("skipping in short mode")
	}
	var err error
	for delay := time.Millisecond; delay <= 64*time.Millisecond; delay *= 2 {
		if err = testConnReadNonzeroAndEOF(t, delay); err == nil {
			return
		}
	}
	t.Error(err)
}

// testConnReadNonzeroAndEOF runs one attempt of TestConnReadNonzeroAndEOF,
// sleeping for delay between the server's Close and the client's second Read.
// Problems are returned as errors (except a failed client Dial, which is
// fatal) so that the caller can retry.
func testConnReadNonzeroAndEOF(t *testing.T, delay time.Duration) error {
	ln := newLocalListener(t)
	defer ln.Close()

	srvCh := make(chan *Conn, 1)
	var serr error
	go func() {
		sconn, err := ln.Accept()
		if err != nil {
			serr = err
			srvCh <- nil
			return
		}
		serverConfig := testConfig.Clone()
		srv := Server(sconn, serverConfig)
		if err := srv.Handshake(); err != nil {
			serr = fmt.Errorf("handshake: %v", err)
			srvCh <- nil
			return
		}
		srvCh <- srv
	}()

	clientConfig := testConfig.Clone()
	// In TLS 1.3, alerts are encrypted and disguised as application data, so
	// the opportunistic peek won't work.
	clientConfig.MaxVersion = VersionTLS12
	conn, err := Dial("tcp", ln.Addr().String(), clientConfig)
	if err != nil {
		t.Fatal(err)
	}
	defer conn.Close()

	srv := <-srvCh
	if srv == nil {
		return serr
	}
	// Releases srv on early-return paths; the explicit srv.Close below is the
	// one under test, and a second Close is harmless.
	defer srv.Close()

	buf := make([]byte, 6)

	if _, err := srv.Write([]byte("foobar")); err != nil {
		return fmt.Errorf("server Write: %v", err)
	}
	n, err := conn.Read(buf)
	if n != 6 || err != nil || string(buf) != "foobar" {
		return fmt.Errorf("Read = %d, %v, data %q; want 6, nil, foobar", n, err, buf)
	}

	if _, err := srv.Write([]byte("abcdef")); err != nil {
		return fmt.Errorf("server Write: %v", err)
	}
	srv.Close()
	time.Sleep(delay)
	n, err = conn.Read(buf)
	if n != 6 || string(buf) != "abcdef" {
		return fmt.Errorf("Read = %d, buf= %q; want 6, abcdef", n, buf)
	}
	if err != io.EOF {
		return fmt.Errorf("Second Read error = %v; want io.EOF", err)
	}
	return nil
}

// TestTLSUniqueMatches checks that client and server agree on a non-zero
// tls-unique channel binding, both on a full handshake and on a resumed one.
func TestTLSUniqueMatches(t *testing.T) {
	ln := newLocalListener(t)
	defer ln.Close()

	serverTLSUniques := make(chan []byte)
	parentDone := make(chan struct{})
	childDone := make(chan struct{})
	defer close(parentDone)
	go func() {
		defer close(childDone)
		for i := 0; i < 2; i++ {
			sconn, err := ln.Accept()
			if err != nil {
				t.Error(err)
				return
			}
			serverConfig := testConfig.Clone()
			serverConfig.MaxVersion = VersionTLS12 // TLSUnique is not defined in TLS 1.3
			srv := Server(sconn, serverConfig)
			if err := srv.Handshake(); err != nil {
				t.Error(err)
				return
			}
			select {
			case <-parentDone:
				return
			case serverTLSUniques <- srv.ConnectionState().TLSUnique:
			}
		}
	}()

	clientConfig := testConfig.Clone()
	clientConfig.ClientSessionCache = NewLRUClientSessionCache(1)
	conn, err := Dial("tcp", ln.Addr().String(), clientConfig)
	if err != nil {
		t.Fatal(err)
	}

	var serverTLSUniquesValue []byte
	select {
	case <-childDone:
		// The server goroutine already reported why it stopped; don't leak conn.
		conn.Close()
		return
	case serverTLSUniquesValue = <-serverTLSUniques:
	}

	if !bytes.Equal(conn.ConnectionState().TLSUnique, serverTLSUniquesValue) {
		t.Error("client and server channel bindings differ")
	}
	if serverTLSUniquesValue == nil || bytes.Equal(serverTLSUniquesValue, make([]byte, 12)) {
		t.Error("tls-unique is empty or zero")
	}
	conn.Close()

	conn, err = Dial("tcp", ln.Addr().String(), clientConfig)
	if err != nil {
		t.Fatal(err)
	}
	defer conn.Close()
	if !conn.ConnectionState().DidResume {
		t.Error("second session did not use resumption")
	}

	select {
	case <-childDone:
		return
	case serverTLSUniquesValue = <-serverTLSUniques:
	}

	if !bytes.Equal(conn.ConnectionState().TLSUnique, serverTLSUniquesValue) {
		t.Error("client and server channel bindings differ when session resumption is used")
	}
	if serverTLSUniquesValue == nil || bytes.Equal(serverTLSUniquesValue, make([]byte, 12)) {
		t.Error("resumption tls-unique is empty or zero")
	}
}

// TestVerifyHostname dials a public HTTPS host (external network required) and
// checks VerifyHostname with normal verification and with InsecureSkipVerify.
func TestVerifyHostname(t *testing.T) {
	testenv.MustHaveExternalNetwork(t)

	c,
err := Dial("tcp", "www.google.com:https", nil)
	if err != nil {
		t.Fatal(err)
	}
	defer c.Close()
	if err := c.VerifyHostname("www.google.com"); err != nil {
		t.Fatalf("verify www.google.com: %v", err)
	}
	if err := c.VerifyHostname("www.yahoo.com"); err == nil {
		t.Fatalf("verify www.yahoo.com succeeded")
	}

	// With InsecureSkipVerify, VerifyHostname is expected to fail.
	insecure, err := Dial("tcp", "www.google.com:https", &Config{InsecureSkipVerify: true})
	if err != nil {
		t.Fatal(err)
	}
	defer insecure.Close()
	if err := insecure.VerifyHostname("www.google.com"); err == nil {
		t.Fatalf("verify www.google.com succeeded with InsecureSkipVerify=true")
	}
}

// TestConnCloseBreakingWrite checks that Conn.Close does not block forever
// while a Write is stuck inside the underlying connection, and that a second
// Close reports net.ErrClosed.
func TestConnCloseBreakingWrite(t *testing.T) {
	ln := newLocalListener(t)
	defer ln.Close()

	srvCh := make(chan *Conn, 1)
	var serr error
	var sconn net.Conn
	go func() {
		var err error
		sconn, err = ln.Accept()
		if err != nil {
			serr = err
			srvCh <- nil
			return
		}
		serverConfig := testConfig.Clone()
		srv := Server(sconn, serverConfig)
		if err := srv.Handshake(); err != nil {
			serr = fmt.Errorf("handshake: %v", err)
			srvCh <- nil
			return
		}
		srvCh <- srv
	}()

	cconn, err := net.Dial("tcp", ln.Addr().String())
	if err != nil {
		t.Fatal(err)
	}
	defer cconn.Close()

	conn := &changeImplConn{
		Conn: cconn,
	}

	clientConfig := testConfig.Clone()
	tconn := Client(conn, clientConfig)
	if err := tconn.Handshake(); err != nil {
		t.Fatal(err)
	}

	srv := <-srvCh
	if srv == nil {
		t.Fatal(serr)
	}
	defer sconn.Close()

	connClosed := make(chan struct{})
	conn.closeFunc = func() error {
		close(connClosed)
		return nil
	}

	inWrite := make(chan bool, 1)
	var errConnClosed = errors.New("conn closed for test")
	// The write blocks until Close is called, simulating a stuck peer.
	conn.writeFunc = func(p []byte) (n int, err error) {
		inWrite <- true
		<-connClosed
		return 0, errConnClosed
	}

	closeReturned := make(chan bool, 1)
	go func() {
		<-inWrite
		tconn.Close() // test that this doesn't block forever.
		closeReturned <- true
	}()

	_, err = tconn.Write([]byte("foo"))
	if !errors.Is(err, errConnClosed) {
		t.Errorf("Write error = %v; want errConnClosed", err)
	}

	<-closeReturned
	if err := tconn.Close(); !errors.Is(err, net.ErrClosed) {
		t.Errorf("Close error = %v; want net.ErrClosed", err)
	}
}

// TestConnCloseWrite checks that CloseWrite on either side lets the peer read
// to EOF, that Write after CloseWrite fails with errShutdown, and that
// CloseWrite before the handshake fails with errEarlyCloseWrite.
func TestConnCloseWrite(t *testing.T) {
	ln := newLocalListener(t)
	defer ln.Close()

	clientDoneChan := make(chan struct{})

	serverCloseWrite := func() error {
		sconn, err := ln.Accept()
		if err != nil {
			return fmt.Errorf("accept: %v", err)
		}
		defer sconn.Close()

		serverConfig := testConfig.Clone()
		srv := Server(sconn, serverConfig)
		if err := srv.Handshake(); err != nil {
			return fmt.Errorf("handshake: %v", err)
		}
		defer srv.Close()

		data, err := io.ReadAll(srv)
		if err != nil {
			return err
		}
		if len(data) > 0 {
			return fmt.Errorf("Read data = %q; want nothing", data)
		}

		if err := srv.CloseWrite(); err != nil {
			return fmt.Errorf("server CloseWrite: %v", err)
		}

		// Wait for clientCloseWrite to finish, so we know we
		// tested the CloseWrite before we defer the
		// sconn.Close above, which would also cause the
		// client to unblock like CloseWrite.
		<-clientDoneChan
		return nil
	}

	clientCloseWrite := func() error {
		defer close(clientDoneChan)

		clientConfig := testConfig.Clone()
		conn, err := Dial("tcp", ln.Addr().String(), clientConfig)
		if err != nil {
			return err
		}
		// Registered before Handshake so the conn is closed on that error path too.
		defer conn.Close()
		if err := conn.Handshake(); err != nil {
			return err
		}

		if err := conn.CloseWrite(); err != nil {
			return fmt.Errorf("client CloseWrite: %v", err)
		}

		if _, err := conn.Write([]byte{0}); !errors.Is(err, errShutdown) {
			return fmt.Errorf("CloseWrite error = %v; want errShutdown", err)
		}

		data, err := io.ReadAll(conn)
		if err != nil {
			return err
		}
		if len(data) > 0 {
			return fmt.Errorf("Read data = %q; want nothing", data)
		}
		return nil
	}

	errChan := make(chan error, 2)

	go func() { errChan <- serverCloseWrite() }()
	go func() { errChan <- clientCloseWrite() }()

	for i := 0; i < 2; i++ {
		select {
		case err := <-errChan:
			if err != nil {
				t.Fatal(err)
			}
		case <-time.After(10 * time.Second):
			t.Fatal("deadlock")
		}
	}

	// Also test CloseWrite being called before the handshake is
	// finished:
	{
		ln2 := newLocalListener(t)
		defer ln2.Close()

		netConn, err := net.Dial("tcp", ln2.Addr().String())
		if err != nil {
			t.Fatal(err)
		}
		defer netConn.Close()
		conn := Client(netConn, testConfig.Clone())

		if err := conn.CloseWrite(); !errors.Is(err, errEarlyCloseWrite) {
			t.Errorf("CloseWrite error = %v; want errEarlyCloseWrite", err)
		}
	}
}

// TestWarningAlertFlood checks that a server stops with a "too many ignored"
// error when the client sends more than maxUselessRecords warning alerts.
func TestWarningAlertFlood(t *testing.T) {
	ln := newLocalListener(t)
	defer ln.Close()

	server := func() error {
		sconn, err := ln.Accept()
		if err != nil {
			return fmt.Errorf("accept: %v", err)
		}
		defer sconn.Close()

		serverConfig := testConfig.Clone()
		srv := Server(sconn, serverConfig)
		if err := srv.Handshake(); err != nil {
			return fmt.Errorf("handshake: %v", err)
		}
		defer srv.Close()

		_, err = io.ReadAll(srv)
		if err == nil {
			return errors.New("unexpected lack of error from server")
		}
		const expected = "too many ignored"
		if str := err.Error(); !strings.Contains(str, expected) {
			return fmt.Errorf("expected error containing %q, but saw: %s", expected, str)
		}

		return nil
	}

	errChan := make(chan error, 1)
	go func() { errChan <- server() }()

	clientConfig := testConfig.Clone()
	clientConfig.MaxVersion = VersionTLS12 // there are no warning alerts in TLS 1.3
	conn, err := Dial("tcp", ln.Addr().String(), clientConfig)
	if err != nil {
		t.Fatal(err)
	}
	defer conn.Close()
	if err := conn.Handshake(); err != nil {
		t.Fatal(err)
	}

	for i := 0; i < maxUselessRecords+1; i++ {
		conn.sendAlert(alertNoRenegotiation)
	}

	if err := <-errChan; err != nil {
		t.Fatal(err)
	}
}

// TestCloneFuncFields checks that Clone carries over every function-valued
// Config field: each function sets its own bit in called, and all
// expectedCount bits must be set after calling the clone's functions.
func TestCloneFuncFields(t *testing.T) {
	const expectedCount = 8
	called := 0

	c1 := Config{
		Time: func() time.Time {
			called |= 1 << 0
			return time.Time{}
		},
		GetCertificate: func(*ClientHelloInfo) (*Certificate, error) {
			called |= 1 << 1
			return nil, nil
		},
		GetClientCertificate: func(*CertificateRequestInfo) (*Certificate, error) {
			called |= 1 << 2
			return nil, nil
		},
		GetConfigForClient: func(*ClientHelloInfo) (*Config, error) {
			called |= 1 << 3
			return nil, nil
		},
		VerifyPeerCertificate: func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
			called |= 1 << 4
			return nil
		},
		VerifyConnection: func(ConnectionState) error {
			called |= 1 << 5
			return nil
		},
		UnwrapSession: func(identity []byte, cs ConnectionState) (*SessionState, error) {
			called |= 1 << 6
			return nil, nil
		},
		WrapSession: func(cs ConnectionState, ss *SessionState) ([]byte, error) {
			called |= 1 << 7
return nil, nil
		},
	}

	c2 := c1.Clone()

	c2.Time()
	c2.GetCertificate(nil)
	c2.GetClientCertificate(nil)
	c2.GetConfigForClient(nil)
	c2.VerifyPeerCertificate(nil, nil)
	c2.VerifyConnection(ConnectionState{})
	c2.UnwrapSession(nil, ConnectionState{})
	c2.WrapSession(ConnectionState{}, nil)

	if called != (1<<expectedCount)-1 {
		t.Fatalf("expected %d calls but saw calls %b", expectedCount, called)
	}
}

// TestCloneNonFuncFields sets every non-function Config field to a non-zero
// value via reflection and checks that Clone reproduces the Config exactly.
// Any field missing from the switch fails the test, so new Config fields must
// be listed here.
func TestCloneNonFuncFields(t *testing.T) {
	var c1 Config
	v := reflect.ValueOf(&c1).Elem()

	typ := v.Type()
	for i := 0; i < typ.NumField(); i++ {
		f := v.Field(i)
		// testing/quick can't handle functions or interfaces and so
		// isn't used here.
		switch fn := typ.Field(i).Name; fn {
		case "Rand":
			f.Set(reflect.ValueOf(io.Reader(os.Stdin)))
		case "Time", "GetCertificate", "GetConfigForClient", "VerifyPeerCertificate", "VerifyConnection", "GetClientCertificate", "WrapSession", "UnwrapSession":
			// DeepEqual can't compare functions. If you add a
			// function field to this list, you must also change
			// TestCloneFuncFields to ensure that the func field is
			// cloned.
		case "Certificates":
			f.Set(reflect.ValueOf([]Certificate{
				{Certificate: [][]byte{{'b'}}},
			}))
		case "NameToCertificate":
			f.Set(reflect.ValueOf(map[string]*Certificate{"a": nil}))
		case "RootCAs", "ClientCAs":
			f.Set(reflect.ValueOf(x509.NewCertPool()))
		case "ClientSessionCache":
			f.Set(reflect.ValueOf(NewLRUClientSessionCache(10)))
		case "KeyLogWriter":
			f.Set(reflect.ValueOf(io.Writer(os.Stdout)))
		case "NextProtos":
			f.Set(reflect.ValueOf([]string{"a", "b"}))
		case "ServerName":
			f.Set(reflect.ValueOf("b"))
		case "ClientAuth":
			f.Set(reflect.ValueOf(VerifyClientCertIfGiven))
		case "InsecureSkipVerify", "InsecureSkipTimeVerify", "SessionTicketsDisabled", "DynamicRecordSizingDisabled", "PreferServerCipherSuites":
			f.Set(reflect.ValueOf(true))
		case "InsecureServerNameToVerify":
			f.Set(reflect.ValueOf("c"))
		case "MinVersion", "MaxVersion":
			f.Set(reflect.ValueOf(uint16(VersionTLS12)))
		case "SessionTicketKey":
			f.Set(reflect.ValueOf([32]byte{}))
		case "CipherSuites":
			f.Set(reflect.ValueOf([]uint16{1, 2}))
		case "CurvePreferences":
			f.Set(reflect.ValueOf([]CurveID{CurveP256}))
		case "Renegotiation":
			f.Set(reflect.ValueOf(RenegotiateOnceAsClient))
		case "mutex", "autoSessionTicketKeys", "sessionTicketKeys":
			continue // these are unexported fields that are handled separately
		case "ApplicationSettings": // [UTLS] ALPS (Application Settings)
			f.Set(reflect.ValueOf(map[string][]byte{"a": {1}}))
		// #Restls# Begin
		case "RestlsSecret":
			f.Set(reflect.ValueOf([]byte{'1', '2', '3', '4', '5'}))
		case "VersionHint":
			f.Set(reflect.ValueOf(TLS12Hint))
		case "CurveIDHint":
			// atomic.Uint32 carries a noCopy marker, so passing a copy to
			// reflect.ValueOf trips go vet's copylocks check. Store into the
			// field in place instead.
			c1.CurveIDHint.Store(uint32(CurveP256))
		case "RestlsScript":
			f.Set(reflect.ValueOf([]Line{{TargetLength{100, 0}, ActNoop{}}}))
		case "ClientID":
			f.Set(reflect.ValueOf(&HelloChrome_Auto))
		// #Restls# End
		default:
			t.Errorf("all fields must be accounted for, but saw unknown field %q", fn)
		}
	}
	// Set the unexported fields related to session ticket keys, which are copied with Clone().
	c1.autoSessionTicketKeys = []ticketKey{c1.ticketKeyFromBytes(c1.SessionTicketKey)}
	c1.sessionTicketKeys = []ticketKey{c1.ticketKeyFromBytes(c1.SessionTicketKey)}

	c2 := c1.Clone()
	if !reflect.DeepEqual(&c1, c2) {
		t.Errorf("clone failed to copy a field")
	}
}

// TestCloneNilConfig checks that cloning a nil *Config returns nil.
func TestCloneNilConfig(t *testing.T) {
	var config *Config
	if cc := config.Clone(); cc != nil {
		t.Fatalf("Clone with nil should return nil, got: %+v", cc)
	}
}

// changeImplConn is a net.Conn which can change its Write and Close
// methods. A nil writeFunc/closeFunc falls through to the embedded Conn.
type changeImplConn struct {
	net.Conn
	writeFunc func([]byte) (int, error)
	closeFunc func() error
}

func (w *changeImplConn) Write(p []byte) (n int, err error) {
	if w.writeFunc != nil {
		return w.writeFunc(p)
	}
	return w.Conn.Write(p)
}

func (w *changeImplConn) Close() error {
	if w.closeFunc != nil {
		return w.closeFunc()
	}
	return w.Conn.Close()
}

// throughput benchmarks echoing totalBytes through a fresh TLS connection of
// the given version, b.N times. The server side runs in a goroutine and panics
// on error because b.Fatal cannot be called from a goroutine.
func throughput(b *testing.B, version uint16, totalBytes int64, dynamicRecordSizingDisabled bool) {
	ln := newLocalListener(b)
	defer ln.Close()

	N := b.N

	// Less than 64KB because Windows appears to use a TCP rwin < 64KB.
	// See Issue #15899.
	const bufsize = 32 << 10

	go func() {
		buf := make([]byte, bufsize)
		for i := 0; i < N; i++ {
			sconn, err := ln.Accept()
			if err != nil {
				// panic rather than synchronize to avoid benchmark overhead
				// (cannot call b.Fatal in goroutine)
				panic(fmt.Errorf("accept: %v", err))
			}
			serverConfig := testConfig.Clone()
			serverConfig.CipherSuites = nil // the defaults may prefer faster ciphers
			serverConfig.DynamicRecordSizingDisabled = dynamicRecordSizingDisabled
			srv := Server(sconn, serverConfig)
			if err := srv.Handshake(); err != nil {
				panic(fmt.Errorf("handshake: %v", err))
			}
			if _, err := io.CopyBuffer(srv, srv, buf); err != nil {
				panic(fmt.Errorf("copy buffer: %v", err))
			}
			// Release the socket; otherwise each iteration leaks a descriptor.
			// The raw conn is closed (not srv) to avoid an extra alert write.
			sconn.Close()
		}
	}()

	b.SetBytes(totalBytes)
	clientConfig := testConfig.Clone()
	clientConfig.CipherSuites = nil // the defaults may prefer faster ciphers
	clientConfig.DynamicRecordSizingDisabled = dynamicRecordSizingDisabled
	clientConfig.MaxVersion = version

	buf := make([]byte, bufsize)
	chunks := int(math.Ceil(float64(totalBytes) / float64(len(buf))))
	for i := 0; i < N; i++ {
		conn, err := Dial("tcp", ln.Addr().String(), clientConfig)
		if err != nil {
			b.Fatal(err)
		}
		for j := 0; j < chunks; j++ {
			_, err := conn.Write(buf)
			if err != nil {
				b.Fatal(err)
			}
			_, err = io.ReadFull(conn, buf)
			if err != nil {
				b.Fatal(err)
			}
		}
		conn.Close()
	}
}

// BenchmarkThroughput runs throughput for 1–64 MB payloads over TLS 1.2 and
// 1.3, with dynamic record sizing disabled ("Max") and enabled ("Dynamic").
func BenchmarkThroughput(b *testing.B) {
	for _, mode := range []string{"Max", "Dynamic"} {
		for size := 1; size <= 64; size <<= 1 {
			name := fmt.Sprintf("%sPacket/%dMB", mode, size)
			b.Run(name, func(b *testing.B) {
				b.Run("TLSv12", func(b *testing.B) {
					throughput(b, VersionTLS12, int64(size<<20), mode == "Max")
				})
				b.Run("TLSv13", func(b *testing.B) {
					throughput(b, VersionTLS13, int64(size<<20), mode == "Max")
				})
			})
		}
	}
}

// slowConn is a net.Conn whose Write paces output to roughly bps bits per
// second, checking the allowance every 100µs.
type slowConn struct {
	net.Conn
	bps int
}

func (c *slowConn) Write(p []byte) (int, error) {
	if c.bps == 0 {
		panic("too slow")
	}
	t0 := time.Now()
	wrote := 0
	for wrote < len(p) {
		time.Sleep(100 * time.Microsecond)
		allowed := int(time.Since(t0).Seconds()*float64(c.bps)) / 8
		if allowed > len(p) {
			allowed = len(p)
		}
		if wrote < allowed {
			n, err := c.Conn.Write(p[wrote:allowed])
			wrote += n
			if err != nil {
				return wrote, err
			}
		}
	}
	return len(p), nil
}

// latency benchmarks the time for a 16 KiB write to be echoed back by a server
// whose outgoing bandwidth is throttled to bps via slowConn. The server side
// panics on error because b.Fatal cannot be called from a goroutine.
func latency(b *testing.B, version uint16, bps int, dynamicRecordSizingDisabled bool) {
	ln := newLocalListener(b)
	defer ln.Close()

	N := b.N

	go func() {
		for i := 0; i < N; i++ {
			sconn, err := ln.Accept()
			if err != nil {
				// panic rather than synchronize to avoid benchmark overhead
				// (cannot call b.Fatal in goroutine)
				panic(fmt.Errorf("accept: %v", err))
			}
			serverConfig := testConfig.Clone()
			serverConfig.DynamicRecordSizingDisabled = dynamicRecordSizingDisabled
			srv := Server(&slowConn{sconn, bps}, serverConfig)
			if err := srv.Handshake(); err != nil {
				panic(fmt.Errorf("handshake: %v", err))
			}
			io.Copy(srv, srv)
			// Release the socket directly; closing srv would push a close_notify
			// through the throttled slowConn and skew the next measurement.
			sconn.Close()
		}
	}()

	clientConfig := testConfig.Clone()
	clientConfig.DynamicRecordSizingDisabled = dynamicRecordSizingDisabled
	clientConfig.MaxVersion = version

	buf := make([]byte, 16384)
	peek := make([]byte, 1)

	for i := 0; i < N; i++ {
		conn, err := Dial("tcp", ln.Addr().String(), clientConfig)
		if err != nil {
			b.Fatal(err)
		}
		// make sure we're connected and previous connection has stopped
		if _, err := conn.Write(buf[:1]); err != nil {
			b.Fatal(err)
		}
		if _, err := io.ReadFull(conn, peek); err != nil {
			b.Fatal(err)
		}
		if _, err := conn.Write(buf); err != nil {
			b.Fatal(err)
		}
		if _, err = io.ReadFull(conn, peek); err != nil {
			b.Fatal(err)
		}
		conn.Close()
	}
}

// BenchmarkLatency runs latency at several throttled bandwidths over TLS 1.2
// and 1.3, with dynamic record sizing disabled ("Max") and enabled ("Dynamic").
func BenchmarkLatency(b *testing.B) {
	for _, mode := range []string{"Max", "Dynamic"} {
		for _, kbps := range []int{200, 500, 1000, 2000, 5000} {
			name := fmt.Sprintf("%sPacket/%dkbps", mode, kbps)
			b.Run(name, func(b *testing.B) {
				b.Run("TLSv12", func(b *testing.B) {
					latency(b, VersionTLS12, kbps*1000, mode == "Max")
				})
				b.Run("TLSv13", func(b *testing.B) {
					latency(b, VersionTLS13, kbps*1000, mode == "Max")
				})
			})
		}
	}
}

// TestConnectionStateMarshal checks that a zero ConnectionState can be
// encoded with encoding/json.
func TestConnectionStateMarshal(t *testing.T) {
	cs := &ConnectionState{}
	_, err := json.Marshal(cs)
	if err != nil {
		t.Errorf("json.Marshal failed on ConnectionState: %v", err)
	}
}

func TestConnectionState(t *testing.T) {
	issuer, err := x509.ParseCertificate(testRSACertificateIssuer)
	if err != nil {
		panic(err)
	}
	rootCAs := x509.NewCertPool()
	rootCAs.AddCert(issuer)

	now := func() time.Time { return time.Unix(1476984729, 0) }

	const alpnProtocol = "golang"
	const serverName = "example.golang"
	var scts = [][]byte{[]byte("dummy sct 1"), []byte("dummy sct 2")}
	var ocsp = []byte("dummy ocsp")

	for _, v := range []uint16{VersionTLS12, VersionTLS13} {
		var name string
		switch v {
		case VersionTLS12:
			name = "TLSv12"
		case VersionTLS13:
			name = "TLSv13"
		}
		t.Run(name, func(t *testing.T) {
			config := &Config{
				Time:         now,
				Rand:         zeroSource{},
				Certificates: make([]Certificate, 1),
				MaxVersion:   v,
				RootCAs:      rootCAs,
				ClientCAs:    rootCAs,
				ClientAuth:   RequireAndVerifyClientCert,
				NextProtos:   []string{alpnProtocol},
				ServerName:   serverName,
			}
			config.Certificates[0].Certificate = [][]byte{testRSACertificate}
			config.Certificates[0].PrivateKey = testRSAPrivateKey
config.Certificates[0].SignedCertificateTimestamps = scts 1151 config.Certificates[0].OCSPStaple = ocsp 1152 1153 ss, cs, err := testHandshake(t, config, config) 1154 if err != nil { 1155 t.Fatalf("Handshake failed: %v", err) 1156 } 1157 1158 if ss.Version != v || cs.Version != v { 1159 t.Errorf("Got versions %x (server) and %x (client), expected %x", ss.Version, cs.Version, v) 1160 } 1161 1162 if !ss.HandshakeComplete || !cs.HandshakeComplete { 1163 t.Errorf("Got HandshakeComplete %v (server) and %v (client), expected true", ss.HandshakeComplete, cs.HandshakeComplete) 1164 } 1165 1166 if ss.DidResume || cs.DidResume { 1167 t.Errorf("Got DidResume %v (server) and %v (client), expected false", ss.DidResume, cs.DidResume) 1168 } 1169 1170 if ss.CipherSuite == 0 || cs.CipherSuite == 0 { 1171 t.Errorf("Got invalid cipher suite: %v (server) and %v (client)", ss.CipherSuite, cs.CipherSuite) 1172 } 1173 1174 if ss.NegotiatedProtocol != alpnProtocol || cs.NegotiatedProtocol != alpnProtocol { 1175 t.Errorf("Got negotiated protocol %q (server) and %q (client), expected %q", ss.NegotiatedProtocol, cs.NegotiatedProtocol, alpnProtocol) 1176 } 1177 1178 if !cs.NegotiatedProtocolIsMutual { 1179 t.Errorf("Got false NegotiatedProtocolIsMutual on the client side") 1180 } 1181 // NegotiatedProtocolIsMutual on the server side is unspecified. 
1182 1183 if ss.ServerName != serverName { 1184 t.Errorf("Got server name %q, expected %q", ss.ServerName, serverName) 1185 } 1186 if cs.ServerName != serverName { 1187 t.Errorf("Got server name on client connection %q, expected %q", cs.ServerName, serverName) 1188 } 1189 1190 if len(ss.PeerCertificates) != 1 || len(cs.PeerCertificates) != 1 { 1191 t.Errorf("Got %d (server) and %d (client) peer certificates, expected %d", len(ss.PeerCertificates), len(cs.PeerCertificates), 1) 1192 } 1193 1194 if len(ss.VerifiedChains) != 1 || len(cs.VerifiedChains) != 1 { 1195 t.Errorf("Got %d (server) and %d (client) verified chains, expected %d", len(ss.VerifiedChains), len(cs.VerifiedChains), 1) 1196 } else if len(ss.VerifiedChains[0]) != 2 || len(cs.VerifiedChains[0]) != 2 { 1197 t.Errorf("Got %d (server) and %d (client) long verified chain, expected %d", len(ss.VerifiedChains[0]), len(cs.VerifiedChains[0]), 2) 1198 } 1199 1200 if len(cs.SignedCertificateTimestamps) != 2 { 1201 t.Errorf("Got %d SCTs, expected %d", len(cs.SignedCertificateTimestamps), 2) 1202 } 1203 if !bytes.Equal(cs.OCSPResponse, ocsp) { 1204 t.Errorf("Got OCSPs %x, expected %x", cs.OCSPResponse, ocsp) 1205 } 1206 // Only TLS 1.3 supports OCSP and SCTs on client certs. 
1207 if v == VersionTLS13 { 1208 if len(ss.SignedCertificateTimestamps) != 2 { 1209 t.Errorf("Got %d client SCTs, expected %d", len(ss.SignedCertificateTimestamps), 2) 1210 } 1211 if !bytes.Equal(ss.OCSPResponse, ocsp) { 1212 t.Errorf("Got client OCSPs %x, expected %x", ss.OCSPResponse, ocsp) 1213 } 1214 } 1215 1216 if v == VersionTLS13 { 1217 if ss.TLSUnique != nil || cs.TLSUnique != nil { 1218 t.Errorf("Got TLSUnique %x (server) and %x (client), expected nil in TLS 1.3", ss.TLSUnique, cs.TLSUnique) 1219 } 1220 } else { 1221 if ss.TLSUnique == nil || cs.TLSUnique == nil { 1222 t.Errorf("Got TLSUnique %x (server) and %x (client), expected non-nil", ss.TLSUnique, cs.TLSUnique) 1223 } 1224 } 1225 }) 1226 } 1227 } 1228 1229 // Issue 28744: Ensure that we don't modify memory 1230 // that Config doesn't own such as Certificates. 1231 func TestBuildNameToCertificate_doesntModifyCertificates(t *testing.T) { 1232 c0 := Certificate{ 1233 Certificate: [][]byte{testRSACertificate}, 1234 PrivateKey: testRSAPrivateKey, 1235 } 1236 c1 := Certificate{ 1237 Certificate: [][]byte{testSNICertificate}, 1238 PrivateKey: testRSAPrivateKey, 1239 } 1240 config := testConfig.Clone() 1241 config.Certificates = []Certificate{c0, c1} 1242 1243 config.BuildNameToCertificate() 1244 got := config.Certificates 1245 want := []Certificate{c0, c1} 1246 if !reflect.DeepEqual(got, want) { 1247 t.Fatalf("Certificates were mutated by BuildNameToCertificate\nGot: %#v\nWant: %#v\n", got, want) 1248 } 1249 } 1250 1251 func testingKey(s string) string { return strings.ReplaceAll(s, "TESTING KEY", "PRIVATE KEY") } 1252 1253 func TestClientHelloInfo_SupportsCertificate(t *testing.T) { 1254 rsaCert := &Certificate{ 1255 Certificate: [][]byte{testRSACertificate}, 1256 PrivateKey: testRSAPrivateKey, 1257 } 1258 pkcs1Cert := &Certificate{ 1259 Certificate: [][]byte{testRSACertificate}, 1260 PrivateKey: testRSAPrivateKey, 1261 SupportedSignatureAlgorithms: []SignatureScheme{PKCS1WithSHA1, PKCS1WithSHA256}, 1262 } 
1263 ecdsaCert := &Certificate{ 1264 // ECDSA P-256 certificate 1265 Certificate: [][]byte{testP256Certificate}, 1266 PrivateKey: testP256PrivateKey, 1267 } 1268 ed25519Cert := &Certificate{ 1269 Certificate: [][]byte{testEd25519Certificate}, 1270 PrivateKey: testEd25519PrivateKey, 1271 } 1272 1273 tests := []struct { 1274 c *Certificate 1275 chi *ClientHelloInfo 1276 wantErr string 1277 }{ 1278 {rsaCert, &ClientHelloInfo{ 1279 ServerName: "example.golang", 1280 SignatureSchemes: []SignatureScheme{PSSWithSHA256}, 1281 SupportedVersions: []uint16{VersionTLS13}, 1282 }, ""}, 1283 {ecdsaCert, &ClientHelloInfo{ 1284 SignatureSchemes: []SignatureScheme{PSSWithSHA256, ECDSAWithP256AndSHA256}, 1285 SupportedVersions: []uint16{VersionTLS13, VersionTLS12}, 1286 }, ""}, 1287 {rsaCert, &ClientHelloInfo{ 1288 ServerName: "example.com", 1289 SignatureSchemes: []SignatureScheme{PSSWithSHA256}, 1290 SupportedVersions: []uint16{VersionTLS13}, 1291 }, "not valid for requested server name"}, 1292 {ecdsaCert, &ClientHelloInfo{ 1293 SignatureSchemes: []SignatureScheme{ECDSAWithP384AndSHA384}, 1294 SupportedVersions: []uint16{VersionTLS13}, 1295 }, "signature algorithms"}, 1296 {pkcs1Cert, &ClientHelloInfo{ 1297 SignatureSchemes: []SignatureScheme{PSSWithSHA256, ECDSAWithP256AndSHA256}, 1298 SupportedVersions: []uint16{VersionTLS13}, 1299 }, "signature algorithms"}, 1300 1301 {rsaCert, &ClientHelloInfo{ 1302 CipherSuites: []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256}, 1303 SignatureSchemes: []SignatureScheme{PKCS1WithSHA1}, 1304 SupportedVersions: []uint16{VersionTLS13, VersionTLS12}, 1305 }, "signature algorithms"}, 1306 {rsaCert, &ClientHelloInfo{ 1307 CipherSuites: []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256}, 1308 SignatureSchemes: []SignatureScheme{PKCS1WithSHA1}, 1309 SupportedVersions: []uint16{VersionTLS13, VersionTLS12}, 1310 config: &Config{ 1311 MaxVersion: VersionTLS12, 1312 }, 1313 }, ""}, // Check that mutual version selection works. 
1314 1315 {ecdsaCert, &ClientHelloInfo{ 1316 CipherSuites: []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, 1317 SupportedCurves: []CurveID{CurveP256}, 1318 SupportedPoints: []uint8{pointFormatUncompressed}, 1319 SignatureSchemes: []SignatureScheme{ECDSAWithP256AndSHA256}, 1320 SupportedVersions: []uint16{VersionTLS12}, 1321 }, ""}, 1322 {ecdsaCert, &ClientHelloInfo{ 1323 CipherSuites: []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, 1324 SupportedCurves: []CurveID{CurveP256}, 1325 SupportedPoints: []uint8{pointFormatUncompressed}, 1326 SignatureSchemes: []SignatureScheme{ECDSAWithP384AndSHA384}, 1327 SupportedVersions: []uint16{VersionTLS12}, 1328 }, ""}, // TLS 1.2 does not restrict curves based on the SignatureScheme. 1329 {ecdsaCert, &ClientHelloInfo{ 1330 CipherSuites: []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, 1331 SupportedCurves: []CurveID{CurveP256}, 1332 SupportedPoints: []uint8{pointFormatUncompressed}, 1333 SignatureSchemes: nil, 1334 SupportedVersions: []uint16{VersionTLS12}, 1335 }, ""}, // TLS 1.2 comes with default signature schemes. 
1336 {ecdsaCert, &ClientHelloInfo{ 1337 CipherSuites: []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256}, 1338 SupportedCurves: []CurveID{CurveP256}, 1339 SupportedPoints: []uint8{pointFormatUncompressed}, 1340 SignatureSchemes: []SignatureScheme{ECDSAWithP256AndSHA256}, 1341 SupportedVersions: []uint16{VersionTLS12}, 1342 }, "cipher suite"}, 1343 {ecdsaCert, &ClientHelloInfo{ 1344 CipherSuites: []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, 1345 SupportedCurves: []CurveID{CurveP256}, 1346 SupportedPoints: []uint8{pointFormatUncompressed}, 1347 SignatureSchemes: []SignatureScheme{ECDSAWithP256AndSHA256}, 1348 SupportedVersions: []uint16{VersionTLS12}, 1349 config: &Config{ 1350 CipherSuites: []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256}, 1351 }, 1352 }, "cipher suite"}, 1353 {ecdsaCert, &ClientHelloInfo{ 1354 CipherSuites: []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, 1355 SupportedCurves: []CurveID{CurveP384}, 1356 SupportedPoints: []uint8{pointFormatUncompressed}, 1357 SignatureSchemes: []SignatureScheme{ECDSAWithP256AndSHA256}, 1358 SupportedVersions: []uint16{VersionTLS12}, 1359 }, "certificate curve"}, 1360 {ecdsaCert, &ClientHelloInfo{ 1361 CipherSuites: []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, 1362 SupportedCurves: []CurveID{CurveP256}, 1363 SupportedPoints: []uint8{1}, 1364 SignatureSchemes: []SignatureScheme{ECDSAWithP256AndSHA256}, 1365 SupportedVersions: []uint16{VersionTLS12}, 1366 }, "doesn't support ECDHE"}, 1367 {ecdsaCert, &ClientHelloInfo{ 1368 CipherSuites: []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, 1369 SupportedCurves: []CurveID{CurveP256}, 1370 SupportedPoints: []uint8{pointFormatUncompressed}, 1371 SignatureSchemes: []SignatureScheme{PSSWithSHA256}, 1372 SupportedVersions: []uint16{VersionTLS12}, 1373 }, "signature algorithms"}, 1374 1375 {ed25519Cert, &ClientHelloInfo{ 1376 CipherSuites: []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, 1377 SupportedCurves: []CurveID{CurveP256}, // only relevant for ECDHE support 1378 
SupportedPoints: []uint8{pointFormatUncompressed}, 1379 SignatureSchemes: []SignatureScheme{Ed25519}, 1380 SupportedVersions: []uint16{VersionTLS12}, 1381 }, ""}, 1382 {ed25519Cert, &ClientHelloInfo{ 1383 CipherSuites: []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, 1384 SupportedCurves: []CurveID{CurveP256}, // only relevant for ECDHE support 1385 SupportedPoints: []uint8{pointFormatUncompressed}, 1386 SignatureSchemes: []SignatureScheme{Ed25519}, 1387 SupportedVersions: []uint16{VersionTLS10}, 1388 }, "doesn't support Ed25519"}, 1389 {ed25519Cert, &ClientHelloInfo{ 1390 CipherSuites: []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, 1391 SupportedCurves: []CurveID{}, 1392 SupportedPoints: []uint8{pointFormatUncompressed}, 1393 SignatureSchemes: []SignatureScheme{Ed25519}, 1394 SupportedVersions: []uint16{VersionTLS12}, 1395 }, "doesn't support ECDHE"}, 1396 1397 {rsaCert, &ClientHelloInfo{ 1398 CipherSuites: []uint16{TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA}, 1399 SupportedCurves: []CurveID{CurveP256}, // only relevant for ECDHE support 1400 SupportedPoints: []uint8{pointFormatUncompressed}, 1401 SupportedVersions: []uint16{VersionTLS10}, 1402 }, ""}, 1403 {rsaCert, &ClientHelloInfo{ 1404 CipherSuites: []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256}, 1405 SupportedVersions: []uint16{VersionTLS12}, 1406 }, ""}, // static RSA fallback 1407 } 1408 for i, tt := range tests { 1409 err := tt.chi.SupportsCertificate(tt.c) 1410 switch { 1411 case tt.wantErr == "" && err != nil: 1412 t.Errorf("%d: unexpected error: %v", i, err) 1413 case tt.wantErr != "" && err == nil: 1414 t.Errorf("%d: unexpected success", i) 1415 case tt.wantErr != "" && !strings.Contains(err.Error(), tt.wantErr): 1416 t.Errorf("%d: got error %q, expected %q", i, err, tt.wantErr) 1417 } 1418 } 1419 } 1420 1421 func TestCipherSuites(t *testing.T) { 1422 var lastID uint16 1423 for _, c := range CipherSuites() { 1424 if lastID > c.ID { 1425 t.Errorf("CipherSuites are not ordered by ID: got %#04x after %#04x", 
c.ID, lastID) 1426 } else { 1427 lastID = c.ID 1428 } 1429 1430 if c.Insecure { 1431 t.Errorf("%#04x: Insecure CipherSuite returned by CipherSuites()", c.ID) 1432 } 1433 } 1434 lastID = 0 1435 for _, c := range InsecureCipherSuites() { 1436 if lastID > c.ID { 1437 t.Errorf("InsecureCipherSuites are not ordered by ID: got %#04x after %#04x", c.ID, lastID) 1438 } else { 1439 lastID = c.ID 1440 } 1441 1442 if !c.Insecure { 1443 t.Errorf("%#04x: not Insecure CipherSuite returned by InsecureCipherSuites()", c.ID) 1444 } 1445 } 1446 1447 CipherSuiteByID := func(id uint16) *CipherSuite { 1448 for _, c := range CipherSuites() { 1449 if c.ID == id { 1450 return c 1451 } 1452 } 1453 for _, c := range InsecureCipherSuites() { 1454 if c.ID == id { 1455 return c 1456 } 1457 } 1458 return nil 1459 } 1460 1461 for _, c := range cipherSuites { 1462 cc := CipherSuiteByID(c.id) 1463 if cc == nil { 1464 t.Errorf("%#04x: no CipherSuite entry", c.id) 1465 continue 1466 } 1467 1468 if tls12Only := c.flags&suiteTLS12 != 0; tls12Only && len(cc.SupportedVersions) != 1 { 1469 t.Errorf("%#04x: suite is TLS 1.2 only, but SupportedVersions is %v", c.id, cc.SupportedVersions) 1470 } else if !tls12Only && len(cc.SupportedVersions) != 3 { 1471 t.Errorf("%#04x: suite TLS 1.0-1.2, but SupportedVersions is %v", c.id, cc.SupportedVersions) 1472 } 1473 1474 if got := CipherSuiteName(c.id); got != cc.Name { 1475 t.Errorf("%#04x: unexpected CipherSuiteName: got %q, expected %q", c.id, got, cc.Name) 1476 } 1477 } 1478 for _, c := range cipherSuitesTLS13 { 1479 cc := CipherSuiteByID(c.id) 1480 if cc == nil { 1481 t.Errorf("%#04x: no CipherSuite entry", c.id) 1482 continue 1483 } 1484 1485 if cc.Insecure { 1486 t.Errorf("%#04x: Insecure %v, expected false", c.id, cc.Insecure) 1487 } 1488 if len(cc.SupportedVersions) != 1 || cc.SupportedVersions[0] != VersionTLS13 { 1489 t.Errorf("%#04x: suite is TLS 1.3 only, but SupportedVersions is %v", c.id, cc.SupportedVersions) 1490 } 1491 1492 if got := 
CipherSuiteName(c.id); got != cc.Name { 1493 t.Errorf("%#04x: unexpected CipherSuiteName: got %q, expected %q", c.id, got, cc.Name) 1494 } 1495 } 1496 1497 if got := CipherSuiteName(0xabc); got != "0x0ABC" { 1498 t.Errorf("unexpected fallback CipherSuiteName: got %q, expected 0x0ABC", got) 1499 } 1500 1501 if len(cipherSuitesPreferenceOrder) != len(cipherSuites) { 1502 t.Errorf("cipherSuitesPreferenceOrder is not the same size as cipherSuites") 1503 } 1504 if len(cipherSuitesPreferenceOrderNoAES) != len(cipherSuitesPreferenceOrder) { 1505 t.Errorf("cipherSuitesPreferenceOrderNoAES is not the same size as cipherSuitesPreferenceOrder") 1506 } 1507 1508 // Check that disabled suites are at the end of the preference lists, and 1509 // that they are marked insecure. 1510 for i, id := range disabledCipherSuites { 1511 offset := len(cipherSuitesPreferenceOrder) - len(disabledCipherSuites) 1512 if cipherSuitesPreferenceOrder[offset+i] != id { 1513 t.Errorf("disabledCipherSuites[%d]: not at the end of cipherSuitesPreferenceOrder", i) 1514 } 1515 if cipherSuitesPreferenceOrderNoAES[offset+i] != id { 1516 t.Errorf("disabledCipherSuites[%d]: not at the end of cipherSuitesPreferenceOrderNoAES", i) 1517 } 1518 c := CipherSuiteByID(id) 1519 if c == nil { 1520 t.Errorf("%#04x: no CipherSuite entry", id) 1521 continue 1522 } 1523 if !c.Insecure { 1524 t.Errorf("%#04x: disabled by default but not marked insecure", id) 1525 } 1526 } 1527 1528 for i, prefOrder := range [][]uint16{cipherSuitesPreferenceOrder, cipherSuitesPreferenceOrderNoAES} { 1529 // Check that insecure and HTTP/2 bad cipher suites are at the end of 1530 // the preference lists. 
1531 var sawInsecure, sawBad bool 1532 for _, id := range prefOrder { 1533 c := CipherSuiteByID(id) 1534 if c == nil { 1535 t.Errorf("%#04x: no CipherSuite entry", id) 1536 continue 1537 } 1538 1539 if c.Insecure { 1540 sawInsecure = true 1541 } else if sawInsecure { 1542 t.Errorf("%#04x: secure suite after insecure one(s)", id) 1543 } 1544 1545 if http2isBadCipher(id) { 1546 sawBad = true 1547 } else if sawBad { 1548 t.Errorf("%#04x: non-bad suite after bad HTTP/2 one(s)", id) 1549 } 1550 } 1551 1552 // Check that the list is sorted according to the documented criteria. 1553 isBetter := func(a, b int) bool { 1554 aSuite, bSuite := cipherSuiteByID(prefOrder[a]), cipherSuiteByID(prefOrder[b]) 1555 aName, bName := CipherSuiteName(prefOrder[a]), CipherSuiteName(prefOrder[b]) 1556 // * < RC4 1557 if !strings.Contains(aName, "RC4") && strings.Contains(bName, "RC4") { 1558 return true 1559 } else if strings.Contains(aName, "RC4") && !strings.Contains(bName, "RC4") { 1560 return false 1561 } 1562 // * < CBC_SHA256 1563 if !strings.Contains(aName, "CBC_SHA256") && strings.Contains(bName, "CBC_SHA256") { 1564 return true 1565 } else if strings.Contains(aName, "CBC_SHA256") && !strings.Contains(bName, "CBC_SHA256") { 1566 return false 1567 } 1568 // * < 3DES 1569 if !strings.Contains(aName, "3DES") && strings.Contains(bName, "3DES") { 1570 return true 1571 } else if strings.Contains(aName, "3DES") && !strings.Contains(bName, "3DES") { 1572 return false 1573 } 1574 // ECDHE < * 1575 if aSuite.flags&suiteECDHE != 0 && bSuite.flags&suiteECDHE == 0 { 1576 return true 1577 } else if aSuite.flags&suiteECDHE == 0 && bSuite.flags&suiteECDHE != 0 { 1578 return false 1579 } 1580 // AEAD < CBC 1581 if aSuite.aead != nil && bSuite.aead == nil { 1582 return true 1583 } else if aSuite.aead == nil && bSuite.aead != nil { 1584 return false 1585 } 1586 // AES < ChaCha20 1587 if strings.Contains(aName, "AES") && strings.Contains(bName, "CHACHA20") { 1588 return i == 0 // true for 
cipherSuitesPreferenceOrder 1589 } else if strings.Contains(aName, "CHACHA20") && strings.Contains(bName, "AES") { 1590 return i != 0 // true for cipherSuitesPreferenceOrderNoAES 1591 } 1592 // AES-128 < AES-256 1593 if strings.Contains(aName, "AES_128") && strings.Contains(bName, "AES_256") { 1594 return true 1595 } else if strings.Contains(aName, "AES_256") && strings.Contains(bName, "AES_128") { 1596 return false 1597 } 1598 // ECDSA < RSA 1599 if aSuite.flags&suiteECSign != 0 && bSuite.flags&suiteECSign == 0 { 1600 return true 1601 } else if aSuite.flags&suiteECSign == 0 && bSuite.flags&suiteECSign != 0 { 1602 return false 1603 } 1604 t.Fatalf("two ciphersuites are equal by all criteria: %v and %v", aName, bName) 1605 panic("unreachable") 1606 } 1607 if !sort.SliceIsSorted(prefOrder, isBetter) { 1608 t.Error("preference order is not sorted according to the rules") 1609 } 1610 } 1611 } 1612 1613 func TestVersionName(t *testing.T) { 1614 if got, exp := VersionName(VersionTLS13), "TLS 1.3"; got != exp { 1615 t.Errorf("unexpected VersionName: got %q, expected %q", got, exp) 1616 } 1617 if got, exp := VersionName(0x12a), "0x012A"; got != exp { 1618 t.Errorf("unexpected fallback VersionName: got %q, expected %q", got, exp) 1619 } 1620 } 1621 1622 // http2isBadCipher is copied from net/http. 1623 // TODO: if it ends up exposed somewhere, use that instead. 
1624 func http2isBadCipher(cipher uint16) bool { 1625 switch cipher { 1626 case TLS_RSA_WITH_RC4_128_SHA, 1627 TLS_RSA_WITH_3DES_EDE_CBC_SHA, 1628 TLS_RSA_WITH_AES_128_CBC_SHA, 1629 TLS_RSA_WITH_AES_256_CBC_SHA, 1630 TLS_RSA_WITH_AES_128_CBC_SHA256, 1631 TLS_RSA_WITH_AES_128_GCM_SHA256, 1632 TLS_RSA_WITH_AES_256_GCM_SHA384, 1633 TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, 1634 TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, 1635 TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, 1636 TLS_ECDHE_RSA_WITH_RC4_128_SHA, 1637 TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, 1638 TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, 1639 TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, 1640 TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, 1641 TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256: 1642 return true 1643 default: 1644 return false 1645 } 1646 } 1647 1648 type brokenSigner struct{ crypto.Signer } 1649 1650 func (s brokenSigner) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) (signature []byte, err error) { 1651 // Replace opts with opts.HashFunc(), so rsa.PSSOptions are discarded. 1652 return s.Signer.Sign(rand, digest, opts.HashFunc()) 1653 } 1654 1655 // TestPKCS1OnlyCert uses a client certificate with a broken crypto.Signer that 1656 // always makes PKCS #1 v1.5 signatures, so can't be used with RSA-PSS. 1657 func TestPKCS1OnlyCert(t *testing.T) { 1658 clientConfig := testConfig.Clone() 1659 clientConfig.Certificates = []Certificate{{ 1660 Certificate: [][]byte{testRSACertificate}, 1661 PrivateKey: brokenSigner{testRSAPrivateKey}, 1662 }} 1663 serverConfig := testConfig.Clone() 1664 serverConfig.MaxVersion = VersionTLS12 // TLS 1.3 doesn't support PKCS #1 v1.5 1665 serverConfig.ClientAuth = RequireAnyClientCert 1666 1667 // If RSA-PSS is selected, the handshake should fail. 
1668 if _, _, err := testHandshake(t, clientConfig, serverConfig); err == nil { 1669 t.Fatal("expected broken certificate to cause connection to fail") 1670 } 1671 1672 clientConfig.Certificates[0].SupportedSignatureAlgorithms = 1673 []SignatureScheme{PKCS1WithSHA1, PKCS1WithSHA256} 1674 1675 // But if the certificate restricts supported algorithms, RSA-PSS should not 1676 // be selected, and the handshake should succeed. 1677 if _, _, err := testHandshake(t, clientConfig, serverConfig); err != nil { 1678 t.Error(err) 1679 } 1680 } 1681 1682 func TestVerifyCertificates(t *testing.T) { 1683 // See https://go.dev/issue/31641. 1684 t.Run("TLSv12", func(t *testing.T) { testVerifyCertificates(t, VersionTLS12) }) 1685 t.Run("TLSv13", func(t *testing.T) { testVerifyCertificates(t, VersionTLS13) }) 1686 } 1687 1688 func testVerifyCertificates(t *testing.T, version uint16) { 1689 tests := []struct { 1690 name string 1691 1692 InsecureSkipVerify bool 1693 ClientAuth ClientAuthType 1694 ClientCertificates bool 1695 }{ 1696 { 1697 name: "defaults", 1698 }, 1699 { 1700 name: "InsecureSkipVerify", 1701 InsecureSkipVerify: true, 1702 }, 1703 { 1704 name: "RequestClientCert with no certs", 1705 ClientAuth: RequestClientCert, 1706 }, 1707 { 1708 name: "RequestClientCert with certs", 1709 ClientAuth: RequestClientCert, 1710 ClientCertificates: true, 1711 }, 1712 { 1713 name: "RequireAnyClientCert", 1714 ClientAuth: RequireAnyClientCert, 1715 ClientCertificates: true, 1716 }, 1717 { 1718 name: "VerifyClientCertIfGiven with no certs", 1719 ClientAuth: VerifyClientCertIfGiven, 1720 }, 1721 { 1722 name: "VerifyClientCertIfGiven with certs", 1723 ClientAuth: VerifyClientCertIfGiven, 1724 ClientCertificates: true, 1725 }, 1726 { 1727 name: "RequireAndVerifyClientCert", 1728 ClientAuth: RequireAndVerifyClientCert, 1729 ClientCertificates: true, 1730 }, 1731 } 1732 1733 issuer, err := x509.ParseCertificate(testRSACertificateIssuer) 1734 if err != nil { 1735 t.Fatal(err) 1736 } 1737 rootCAs 
:= x509.NewCertPool() 1738 rootCAs.AddCert(issuer) 1739 1740 for _, test := range tests { 1741 test := test 1742 t.Run(test.name, func(t *testing.T) { 1743 t.Parallel() 1744 1745 var serverVerifyConnection, clientVerifyConnection bool 1746 var serverVerifyPeerCertificates, clientVerifyPeerCertificates bool 1747 1748 clientConfig := testConfig.Clone() 1749 clientConfig.Time = func() time.Time { return time.Unix(1476984729, 0) } 1750 clientConfig.MaxVersion = version 1751 clientConfig.MinVersion = version 1752 clientConfig.RootCAs = rootCAs 1753 clientConfig.ServerName = "example.golang" 1754 clientConfig.ClientSessionCache = NewLRUClientSessionCache(1) 1755 serverConfig := clientConfig.Clone() 1756 serverConfig.ClientCAs = rootCAs 1757 1758 clientConfig.VerifyConnection = func(cs ConnectionState) error { 1759 clientVerifyConnection = true 1760 return nil 1761 } 1762 clientConfig.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { 1763 clientVerifyPeerCertificates = true 1764 return nil 1765 } 1766 serverConfig.VerifyConnection = func(cs ConnectionState) error { 1767 serverVerifyConnection = true 1768 return nil 1769 } 1770 serverConfig.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { 1771 serverVerifyPeerCertificates = true 1772 return nil 1773 } 1774 1775 clientConfig.InsecureSkipVerify = test.InsecureSkipVerify 1776 serverConfig.ClientAuth = test.ClientAuth 1777 if !test.ClientCertificates { 1778 clientConfig.Certificates = nil 1779 } 1780 1781 if _, _, err := testHandshake(t, clientConfig, serverConfig); err != nil { 1782 t.Fatal(err) 1783 } 1784 1785 want := serverConfig.ClientAuth != NoClientCert 1786 if serverVerifyPeerCertificates != want { 1787 t.Errorf("VerifyPeerCertificates on the server: got %v, want %v", 1788 serverVerifyPeerCertificates, want) 1789 } 1790 if !clientVerifyPeerCertificates { 1791 t.Errorf("VerifyPeerCertificates not called on the client") 1792 } 1793 
if !serverVerifyConnection { 1794 t.Error("VerifyConnection did not get called on the server") 1795 } 1796 if !clientVerifyConnection { 1797 t.Error("VerifyConnection did not get called on the client") 1798 } 1799 1800 serverVerifyPeerCertificates, clientVerifyPeerCertificates = false, false 1801 serverVerifyConnection, clientVerifyConnection = false, false 1802 cs, _, err := testHandshake(t, clientConfig, serverConfig) 1803 if err != nil { 1804 t.Fatal(err) 1805 } 1806 if !cs.DidResume { 1807 t.Error("expected resumption") 1808 } 1809 1810 if serverVerifyPeerCertificates { 1811 t.Error("VerifyPeerCertificates got called on the server on resumption") 1812 } 1813 if clientVerifyPeerCertificates { 1814 t.Error("VerifyPeerCertificates got called on the client on resumption") 1815 } 1816 if !serverVerifyConnection { 1817 t.Error("VerifyConnection did not get called on the server on resumption") 1818 } 1819 if !clientVerifyConnection { 1820 t.Error("VerifyConnection did not get called on the client on resumption") 1821 } 1822 }) 1823 } 1824 }