gitlab.com/go-extension/tls@v0.0.0-20240304171319-e6745021905e/tls_test.go (about) 1 // Copyright 2012 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package tls 6 7 import ( 8 "bytes" 9 "context" 10 "crypto" 11 "crypto/x509" 12 "encoding/json" 13 "errors" 14 "fmt" 15 "io" 16 "math" 17 "net" 18 "os" 19 "reflect" 20 "sort" 21 "strings" 22 "testing" 23 "time" 24 25 "github.com/rogpeppe/go-internal/testenv" 26 ) 27 28 var rsaCertPEM = `-----BEGIN CERTIFICATE----- 29 MIIB0zCCAX2gAwIBAgIJAI/M7BYjwB+uMA0GCSqGSIb3DQEBBQUAMEUxCzAJBgNV 30 BAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYDVQQKDBhJbnRlcm5ldCBX 31 aWRnaXRzIFB0eSBMdGQwHhcNMTIwOTEyMjE1MjAyWhcNMTUwOTEyMjE1MjAyWjBF 32 MQswCQYDVQQGEwJBVTETMBEGA1UECAwKU29tZS1TdGF0ZTEhMB8GA1UECgwYSW50 33 ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMFwwDQYJKoZIhvcNAQEBBQADSwAwSAJBANLJ 34 hPHhITqQbPklG3ibCVxwGMRfp/v4XqhfdQHdcVfHap6NQ5Wok/4xIA+ui35/MmNa 35 rtNuC+BdZ1tMuVCPFZcCAwEAAaNQME4wHQYDVR0OBBYEFJvKs8RfJaXTH08W+SGv 36 zQyKn0H8MB8GA1UdIwQYMBaAFJvKs8RfJaXTH08W+SGvzQyKn0H8MAwGA1UdEwQF 37 MAMBAf8wDQYJKoZIhvcNAQEFBQADQQBJlffJHybjDGxRMqaRmDhX0+6v02TUKZsW 38 r5QuVbpQhH6u+0UgcW0jp9QwpxoPTLTWGXEWBBBurxFwiCBhkQ+V 39 -----END CERTIFICATE----- 40 ` 41 42 var rsaKeyPEM = testingKey(`-----BEGIN RSA TESTING KEY----- 43 MIIBOwIBAAJBANLJhPHhITqQbPklG3ibCVxwGMRfp/v4XqhfdQHdcVfHap6NQ5Wo 44 k/4xIA+ui35/MmNartNuC+BdZ1tMuVCPFZcCAwEAAQJAEJ2N+zsR0Xn8/Q6twa4G 45 6OB1M1WO+k+ztnX/1SvNeWu8D6GImtupLTYgjZcHufykj09jiHmjHx8u8ZZB/o1N 46 MQIhAPW+eyZo7ay3lMz1V01WVjNKK9QSn1MJlb06h/LuYv9FAiEA25WPedKgVyCW 47 SmUwbPw8fnTcpqDWE3yTO3vKcebqMSsCIBF3UmVue8YU3jybC3NxuXq3wNm34R8T 48 xVLHwDXh/6NJAiEAl2oHGGLz64BuAfjKrqwz7qMYr9HCLIe/YsoWq/olzScCIQDi 49 D2lWusoe2/nEqfDVVWGWlyJ7yOmqaVm/iNUN9B2N2g== 50 -----END RSA TESTING KEY----- 51 `) 52 53 // keyPEM is the same as rsaKeyPEM, but declares itself as just 54 // "PRIVATE KEY", not "RSA PRIVATE KEY". 
https://golang.org/issue/4477 55 var keyPEM = testingKey(`-----BEGIN TESTING KEY----- 56 MIIBOwIBAAJBANLJhPHhITqQbPklG3ibCVxwGMRfp/v4XqhfdQHdcVfHap6NQ5Wo 57 k/4xIA+ui35/MmNartNuC+BdZ1tMuVCPFZcCAwEAAQJAEJ2N+zsR0Xn8/Q6twa4G 58 6OB1M1WO+k+ztnX/1SvNeWu8D6GImtupLTYgjZcHufykj09jiHmjHx8u8ZZB/o1N 59 MQIhAPW+eyZo7ay3lMz1V01WVjNKK9QSn1MJlb06h/LuYv9FAiEA25WPedKgVyCW 60 SmUwbPw8fnTcpqDWE3yTO3vKcebqMSsCIBF3UmVue8YU3jybC3NxuXq3wNm34R8T 61 xVLHwDXh/6NJAiEAl2oHGGLz64BuAfjKrqwz7qMYr9HCLIe/YsoWq/olzScCIQDi 62 D2lWusoe2/nEqfDVVWGWlyJ7yOmqaVm/iNUN9B2N2g== 63 -----END TESTING KEY----- 64 `) 65 66 var ecdsaCertPEM = `-----BEGIN CERTIFICATE----- 67 MIIB/jCCAWICCQDscdUxw16XFDAJBgcqhkjOPQQBMEUxCzAJBgNVBAYTAkFVMRMw 68 EQYDVQQIEwpTb21lLVN0YXRlMSEwHwYDVQQKExhJbnRlcm5ldCBXaWRnaXRzIFB0 69 eSBMdGQwHhcNMTIxMTE0MTI0MDQ4WhcNMTUxMTE0MTI0MDQ4WjBFMQswCQYDVQQG 70 EwJBVTETMBEGA1UECBMKU29tZS1TdGF0ZTEhMB8GA1UEChMYSW50ZXJuZXQgV2lk 71 Z2l0cyBQdHkgTHRkMIGbMBAGByqGSM49AgEGBSuBBAAjA4GGAAQBY9+my9OoeSUR 72 lDQdV/x8LsOuLilthhiS1Tz4aGDHIPwC1mlvnf7fg5lecYpMCrLLhauAc1UJXcgl 73 01xoLuzgtAEAgv2P/jgytzRSpUYvgLBt1UA0leLYBy6mQQbrNEuqT3INapKIcUv8 74 XxYP0xMEUksLPq6Ca+CRSqTtrd/23uTnapkwCQYHKoZIzj0EAQOBigAwgYYCQXJo 75 A7Sl2nLVf+4Iu/tAX/IF4MavARKC4PPHK3zfuGfPR3oCCcsAoz3kAzOeijvd0iXb 76 H5jBImIxPL4WxQNiBTexAkF8D1EtpYuWdlVQ80/h/f4pBcGiXPqX5h2PQSQY7hP1 77 +jwM1FGS4fREIOvlBYr/SzzQRtwrvrzGYxDEDbsC0ZGRnA== 78 -----END CERTIFICATE----- 79 ` 80 81 var ecdsaKeyPEM = testingKey(`-----BEGIN EC PARAMETERS----- 82 BgUrgQQAIw== 83 -----END EC PARAMETERS----- 84 -----BEGIN EC TESTING KEY----- 85 MIHcAgEBBEIBrsoKp0oqcv6/JovJJDoDVSGWdirrkgCWxrprGlzB9o0X8fV675X0 86 NwuBenXFfeZvVcwluO7/Q9wkYoPd/t3jGImgBwYFK4EEACOhgYkDgYYABAFj36bL 87 06h5JRGUNB1X/Hwuw64uKW2GGJLVPPhoYMcg/ALWaW+d/t+DmV5xikwKssuFq4Bz 88 VQldyCXTXGgu7OC0AQCC/Y/+ODK3NFKlRi+AsG3VQDSV4tgHLqZBBus0S6pPcg1q 89 kohxS/xfFg/TEwRSSws+roJr4JFKpO2t3/be5OdqmQ== 90 -----END EC TESTING KEY----- 91 `) 92 93 var keyPairTests = []struct { 94 algo string 95 cert string 96 key string 97 }{ 98 
{"ECDSA", ecdsaCertPEM, ecdsaKeyPEM}, 99 {"RSA", rsaCertPEM, rsaKeyPEM}, 100 {"RSA-untyped", rsaCertPEM, keyPEM}, // golang.org/issue/4477 101 } 102 103 func TestX509KeyPair(t *testing.T) { 104 t.Parallel() 105 var pem []byte 106 for _, test := range keyPairTests { 107 pem = []byte(test.cert + test.key) 108 if _, err := X509KeyPair(pem, pem); err != nil { 109 t.Errorf("Failed to load %s cert followed by %s key: %s", test.algo, test.algo, err) 110 } 111 pem = []byte(test.key + test.cert) 112 if _, err := X509KeyPair(pem, pem); err != nil { 113 t.Errorf("Failed to load %s key followed by %s cert: %s", test.algo, test.algo, err) 114 } 115 } 116 } 117 118 func TestX509KeyPairErrors(t *testing.T) { 119 _, err := X509KeyPair([]byte(rsaKeyPEM), []byte(rsaCertPEM)) 120 if err == nil { 121 t.Fatalf("X509KeyPair didn't return an error when arguments were switched") 122 } 123 if subStr := "been switched"; !strings.Contains(err.Error(), subStr) { 124 t.Fatalf("Expected %q in the error when switching arguments to X509KeyPair, but the error was %q", subStr, err) 125 } 126 127 _, err = X509KeyPair([]byte(rsaCertPEM), []byte(rsaCertPEM)) 128 if err == nil { 129 t.Fatalf("X509KeyPair didn't return an error when both arguments were certificates") 130 } 131 if subStr := "certificate"; !strings.Contains(err.Error(), subStr) { 132 t.Fatalf("Expected %q in the error when both arguments to X509KeyPair were certificates, but the error was %q", subStr, err) 133 } 134 135 const nonsensePEM = ` 136 -----BEGIN NONSENSE----- 137 Zm9vZm9vZm9v 138 -----END NONSENSE----- 139 ` 140 141 _, err = X509KeyPair([]byte(nonsensePEM), []byte(nonsensePEM)) 142 if err == nil { 143 t.Fatalf("X509KeyPair didn't return an error when both arguments were nonsense") 144 } 145 if subStr := "NONSENSE"; !strings.Contains(err.Error(), subStr) { 146 t.Fatalf("Expected %q in the error when both arguments to X509KeyPair were nonsense, but the error was %q", subStr, err) 147 } 148 } 149 150 func TestX509MixedKeyPair(t 
*testing.T) { 151 if _, err := X509KeyPair([]byte(rsaCertPEM), []byte(ecdsaKeyPEM)); err == nil { 152 t.Error("Load of RSA certificate succeeded with ECDSA private key") 153 } 154 if _, err := X509KeyPair([]byte(ecdsaCertPEM), []byte(rsaKeyPEM)); err == nil { 155 t.Error("Load of ECDSA certificate succeeded with RSA private key") 156 } 157 } 158 159 func newLocalListener(t testing.TB) net.Listener { 160 ln, err := net.Listen("tcp", "127.0.0.1:0") 161 if err != nil { 162 ln, err = net.Listen("tcp6", "[::1]:0") 163 } 164 if err != nil { 165 t.Fatal(err) 166 } 167 return ln 168 } 169 170 func TestDialTimeout(t *testing.T) { 171 if testing.Short() { 172 t.Skip("skipping in short mode") 173 } 174 175 timeout := 100 * time.Microsecond 176 for !t.Failed() { 177 acceptc := make(chan net.Conn) 178 listener := newLocalListener(t) 179 go func() { 180 for { 181 conn, err := listener.Accept() 182 if err != nil { 183 close(acceptc) 184 return 185 } 186 acceptc <- conn 187 } 188 }() 189 190 addr := listener.Addr().String() 191 dialer := &net.Dialer{ 192 Timeout: timeout, 193 } 194 if conn, err := DialWithDialer(dialer, "tcp", addr, nil); err == nil { 195 conn.Close() 196 t.Errorf("DialWithTimeout unexpectedly completed successfully") 197 } else if !isTimeoutError(err) { 198 t.Errorf("resulting error not a timeout: %v\nType %T: %#v", err, err, err) 199 } 200 201 listener.Close() 202 203 // We're looking for a timeout during the handshake, so check that the 204 // Listener actually accepted the connection to initiate it. (If the server 205 // takes too long to accept the connection, we might cancel before the 206 // underlying net.Conn is ever dialed — without ever attempting a 207 // handshake.) 208 lconn, ok := <-acceptc 209 if ok { 210 // The Listener accepted a connection, so assume that it was from our 211 // Dial: we triggered the timeout at the point where we wanted it! 
212 t.Logf("Listener accepted a connection from %s", lconn.RemoteAddr()) 213 lconn.Close() 214 } 215 // Close any spurious extra connecitions from the listener. (This is 216 // possible if there are, for example, stray Dial calls from other tests.) 217 for extraConn := range acceptc { 218 t.Logf("spurious extra connection from %s", extraConn.RemoteAddr()) 219 extraConn.Close() 220 } 221 if ok { 222 break 223 } 224 225 t.Logf("with timeout %v, DialWithDialer returned before listener accepted any connections; retrying", timeout) 226 timeout *= 2 227 } 228 } 229 230 func TestDeadlineOnWrite(t *testing.T) { 231 if testing.Short() { 232 t.Skip("skipping in short mode") 233 } 234 235 ln := newLocalListener(t) 236 defer ln.Close() 237 238 srvCh := make(chan *Conn, 1) 239 240 go func() { 241 sconn, err := ln.Accept() 242 if err != nil { 243 srvCh <- nil 244 return 245 } 246 srv := Server(sconn, testConfig.Clone()) 247 if err := srv.Handshake(); err != nil { 248 srvCh <- nil 249 return 250 } 251 srvCh <- srv 252 }() 253 254 clientConfig := testConfig.Clone() 255 clientConfig.MaxVersion = VersionTLS12 256 conn, err := Dial("tcp", ln.Addr().String(), clientConfig) 257 if err != nil { 258 t.Fatal(err) 259 } 260 defer conn.Close() 261 262 srv := <-srvCh 263 if srv == nil { 264 t.Error(err) 265 } 266 267 // Make sure the client/server is setup correctly and is able to do a typical Write/Read 268 buf := make([]byte, 6) 269 if _, err := srv.Write([]byte("foobar")); err != nil { 270 t.Errorf("Write err: %v", err) 271 } 272 if n, err := conn.Read(buf); n != 6 || err != nil || string(buf) != "foobar" { 273 t.Errorf("Read = %d, %v, data %q; want 6, nil, foobar", n, err, buf) 274 } 275 276 // Set a deadline which should cause Write to timeout 277 if err = srv.SetDeadline(time.Now()); err != nil { 278 t.Fatalf("SetDeadline(time.Now()) err: %v", err) 279 } 280 if _, err = srv.Write([]byte("should fail")); err == nil { 281 t.Fatal("Write should have timed out") 282 } 283 284 // Clear 
deadline and make sure it still times out 285 if err = srv.SetDeadline(time.Time{}); err != nil { 286 t.Fatalf("SetDeadline(time.Time{}) err: %v", err) 287 } 288 if _, err = srv.Write([]byte("This connection is permanently broken")); err == nil { 289 t.Fatal("Write which previously failed should still time out") 290 } 291 292 // Verify the error 293 if ne := err.(net.Error); ne.Temporary() != false { 294 t.Error("Write timed out but incorrectly classified the error as Temporary") 295 } 296 if !isTimeoutError(err) { 297 t.Error("Write timed out but did not classify the error as a Timeout") 298 } 299 } 300 301 type readerFunc func([]byte) (int, error) 302 303 func (f readerFunc) Read(b []byte) (int, error) { return f(b) } 304 305 // TestDialer tests that tls.Dialer.DialContext can abort in the middle of a handshake. 306 // (The other cases are all handled by the existing dial tests in this package, which 307 // all also flow through the same code shared code paths) 308 func TestDialer(t *testing.T) { 309 ln := newLocalListener(t) 310 defer ln.Close() 311 312 unblockServer := make(chan struct{}) // close-only 313 defer close(unblockServer) 314 go func() { 315 conn, err := ln.Accept() 316 if err != nil { 317 return 318 } 319 defer conn.Close() 320 <-unblockServer 321 }() 322 323 ctx, cancel := context.WithCancel(context.Background()) 324 d := Dialer{Config: &Config{ 325 Rand: readerFunc(func(b []byte) (n int, err error) { 326 // By the time crypto/tls wants randomness, that means it has a TCP 327 // connection, so we're past the Dialer's dial and now blocked 328 // in a handshake. Cancel our context and see if we get unstuck. 
329 // (Our TCP listener above never reads or writes, so the Handshake 330 // would otherwise be stuck forever) 331 cancel() 332 return len(b), nil 333 }), 334 ServerName: "foo", 335 }} 336 _, err := d.DialContext(ctx, "tcp", ln.Addr().String()) 337 if err != context.Canceled { 338 t.Errorf("err = %v; want context.Canceled", err) 339 } 340 } 341 342 func isTimeoutError(err error) bool { 343 if ne, ok := err.(net.Error); ok { 344 return ne.Timeout() 345 } 346 return false 347 } 348 349 // tests that Conn.Read returns (non-zero, io.EOF) instead of 350 // (non-zero, nil) when a Close (alertCloseNotify) is sitting right 351 // behind the application data in the buffer. 352 func TestConnReadNonzeroAndEOF(t *testing.T) { 353 // This test is racy: it assumes that after a write to a 354 // localhost TCP connection, the peer TCP connection can 355 // immediately read it. Because it's racy, we skip this test 356 // in short mode, and then retry it several times with an 357 // increasing sleep in between our final write (via srv.Close 358 // below) and the following read. 
359 if testing.Short() { 360 t.Skip("skipping in short mode") 361 } 362 var err error 363 for delay := time.Millisecond; delay <= 64*time.Millisecond; delay *= 2 { 364 if err = testConnReadNonzeroAndEOF(t, delay); err == nil { 365 return 366 } 367 } 368 t.Error(err) 369 } 370 371 func testConnReadNonzeroAndEOF(t *testing.T, delay time.Duration) error { 372 ln := newLocalListener(t) 373 defer ln.Close() 374 375 srvCh := make(chan *Conn, 1) 376 var serr error 377 go func() { 378 sconn, err := ln.Accept() 379 if err != nil { 380 serr = err 381 srvCh <- nil 382 return 383 } 384 serverConfig := testConfig.Clone() 385 srv := Server(sconn, serverConfig) 386 if err := srv.Handshake(); err != nil { 387 serr = fmt.Errorf("handshake: %v", err) 388 srvCh <- nil 389 return 390 } 391 srvCh <- srv 392 }() 393 394 clientConfig := testConfig.Clone() 395 // In TLS 1.3, alerts are encrypted and disguised as application data, so 396 // the opportunistic peek won't work. 397 clientConfig.MaxVersion = VersionTLS12 398 conn, err := Dial("tcp", ln.Addr().String(), clientConfig) 399 if err != nil { 400 t.Fatal(err) 401 } 402 defer conn.Close() 403 404 srv := <-srvCh 405 if srv == nil { 406 return serr 407 } 408 409 buf := make([]byte, 6) 410 411 srv.Write([]byte("foobar")) 412 n, err := conn.Read(buf) 413 if n != 6 || err != nil || string(buf) != "foobar" { 414 return fmt.Errorf("Read = %d, %v, data %q; want 6, nil, foobar", n, err, buf) 415 } 416 417 srv.Write([]byte("abcdef")) 418 srv.Close() 419 time.Sleep(delay) 420 n, err = conn.Read(buf) 421 if n != 6 || string(buf) != "abcdef" { 422 return fmt.Errorf("Read = %d, buf= %q; want 6, abcdef", n, buf) 423 } 424 if err != io.EOF { 425 return fmt.Errorf("Second Read error = %v; want io.EOF", err) 426 } 427 return nil 428 } 429 430 func TestTLSUniqueMatches(t *testing.T) { 431 ln := newLocalListener(t) 432 defer ln.Close() 433 434 serverTLSUniques := make(chan []byte) 435 parentDone := make(chan struct{}) 436 childDone := make(chan struct{}) 
437 defer close(parentDone) 438 go func() { 439 defer close(childDone) 440 for i := 0; i < 2; i++ { 441 sconn, err := ln.Accept() 442 if err != nil { 443 t.Error(err) 444 return 445 } 446 serverConfig := testConfig.Clone() 447 serverConfig.MaxVersion = VersionTLS12 // TLSUnique is not defined in TLS 1.3 448 srv := Server(sconn, serverConfig) 449 if err := srv.Handshake(); err != nil { 450 t.Error(err) 451 return 452 } 453 select { 454 case <-parentDone: 455 return 456 case serverTLSUniques <- srv.ConnectionState().TLSUnique: 457 } 458 } 459 }() 460 461 clientConfig := testConfig.Clone() 462 clientConfig.ClientSessionCache = NewLRUClientSessionCache(1) 463 conn, err := Dial("tcp", ln.Addr().String(), clientConfig) 464 if err != nil { 465 t.Fatal(err) 466 } 467 468 var serverTLSUniquesValue []byte 469 select { 470 case <-childDone: 471 return 472 case serverTLSUniquesValue = <-serverTLSUniques: 473 } 474 475 if !bytes.Equal(conn.ConnectionState().TLSUnique, serverTLSUniquesValue) { 476 t.Error("client and server channel bindings differ") 477 } 478 if serverTLSUniquesValue == nil || bytes.Equal(serverTLSUniquesValue, make([]byte, 12)) { 479 t.Error("tls-unique is empty or zero") 480 } 481 conn.Close() 482 483 conn, err = Dial("tcp", ln.Addr().String(), clientConfig) 484 if err != nil { 485 t.Fatal(err) 486 } 487 defer conn.Close() 488 if !conn.ConnectionState().DidResume { 489 t.Error("second session did not use resumption") 490 } 491 492 select { 493 case <-childDone: 494 return 495 case serverTLSUniquesValue = <-serverTLSUniques: 496 } 497 498 if !bytes.Equal(conn.ConnectionState().TLSUnique, serverTLSUniquesValue) { 499 t.Error("client and server channel bindings differ when session resumption is used") 500 } 501 if serverTLSUniquesValue == nil || bytes.Equal(serverTLSUniquesValue, make([]byte, 12)) { 502 t.Error("resumption tls-unique is empty or zero") 503 } 504 } 505 506 func TestVerifyHostname(t *testing.T) { 507 testenv.MustHaveExternalNetwork(t) 508 509 c, 
err := Dial("tcp", "www.google.com:https", nil) 510 if err != nil { 511 t.Fatal(err) 512 } 513 if err := c.VerifyHostname("www.google.com"); err != nil { 514 t.Fatalf("verify www.google.com: %v", err) 515 } 516 if err := c.VerifyHostname("www.yahoo.com"); err == nil { 517 t.Fatalf("verify www.yahoo.com succeeded") 518 } 519 520 c, err = Dial("tcp", "www.google.com:https", &Config{InsecureSkipVerify: true}) 521 if err != nil { 522 t.Fatal(err) 523 } 524 if err := c.VerifyHostname("www.google.com"); err == nil { 525 t.Fatalf("verify www.google.com succeeded with InsecureSkipVerify=true") 526 } 527 } 528 529 func TestConnCloseBreakingWrite(t *testing.T) { 530 ln := newLocalListener(t) 531 defer ln.Close() 532 533 srvCh := make(chan *Conn, 1) 534 var serr error 535 var sconn net.Conn 536 go func() { 537 var err error 538 sconn, err = ln.Accept() 539 if err != nil { 540 serr = err 541 srvCh <- nil 542 return 543 } 544 serverConfig := testConfig.Clone() 545 srv := Server(sconn, serverConfig) 546 if err := srv.Handshake(); err != nil { 547 serr = fmt.Errorf("handshake: %v", err) 548 srvCh <- nil 549 return 550 } 551 srvCh <- srv 552 }() 553 554 cconn, err := net.Dial("tcp", ln.Addr().String()) 555 if err != nil { 556 t.Fatal(err) 557 } 558 defer cconn.Close() 559 560 conn := &changeImplConn{ 561 Conn: cconn, 562 } 563 564 clientConfig := testConfig.Clone() 565 tconn := Client(conn, clientConfig) 566 if err := tconn.Handshake(); err != nil { 567 t.Fatal(err) 568 } 569 570 srv := <-srvCh 571 if srv == nil { 572 t.Fatal(serr) 573 } 574 defer sconn.Close() 575 576 connClosed := make(chan struct{}) 577 conn.closeFunc = func() error { 578 close(connClosed) 579 return nil 580 } 581 582 inWrite := make(chan bool, 1) 583 var errConnClosed = errors.New("conn closed for test") 584 conn.writeFunc = func(p []byte) (n int, err error) { 585 inWrite <- true 586 <-connClosed 587 return 0, errConnClosed 588 } 589 590 closeReturned := make(chan bool, 1) 591 go func() { 592 <-inWrite 593 
tconn.Close() // test that this doesn't block forever. 594 closeReturned <- true 595 }() 596 597 _, err = tconn.Write([]byte("foo")) 598 if err != errConnClosed { 599 t.Errorf("Write error = %v; want errConnClosed", err) 600 } 601 602 <-closeReturned 603 if err := tconn.Close(); err != net.ErrClosed { 604 t.Errorf("Close error = %v; want net.ErrClosed", err) 605 } 606 } 607 608 func TestConnCloseWrite(t *testing.T) { 609 ln := newLocalListener(t) 610 defer ln.Close() 611 612 clientDoneChan := make(chan struct{}) 613 614 serverCloseWrite := func() error { 615 sconn, err := ln.Accept() 616 if err != nil { 617 return fmt.Errorf("accept: %v", err) 618 } 619 defer sconn.Close() 620 621 serverConfig := testConfig.Clone() 622 srv := Server(sconn, serverConfig) 623 if err := srv.Handshake(); err != nil { 624 return fmt.Errorf("handshake: %v", err) 625 } 626 defer srv.Close() 627 628 data, err := io.ReadAll(srv) 629 if err != nil { 630 return err 631 } 632 if len(data) > 0 { 633 return fmt.Errorf("Read data = %q; want nothing", data) 634 } 635 636 if err := srv.CloseWrite(); err != nil { 637 return fmt.Errorf("server CloseWrite: %v", err) 638 } 639 640 // Wait for clientCloseWrite to finish, so we know we 641 // tested the CloseWrite before we defer the 642 // sconn.Close above, which would also cause the 643 // client to unblock like CloseWrite. 
644 <-clientDoneChan 645 return nil 646 } 647 648 clientCloseWrite := func() error { 649 defer close(clientDoneChan) 650 651 clientConfig := testConfig.Clone() 652 conn, err := Dial("tcp", ln.Addr().String(), clientConfig) 653 if err != nil { 654 return err 655 } 656 if err := conn.Handshake(); err != nil { 657 return err 658 } 659 defer conn.Close() 660 661 if err := conn.CloseWrite(); err != nil { 662 return fmt.Errorf("client CloseWrite: %v", err) 663 } 664 665 if _, err := conn.Write([]byte{0}); err != errShutdown { 666 return fmt.Errorf("CloseWrite error = %v; want errShutdown", err) 667 } 668 669 data, err := io.ReadAll(conn) 670 if err != nil { 671 return err 672 } 673 if len(data) > 0 { 674 return fmt.Errorf("Read data = %q; want nothing", data) 675 } 676 return nil 677 } 678 679 errChan := make(chan error, 2) 680 681 go func() { errChan <- serverCloseWrite() }() 682 go func() { errChan <- clientCloseWrite() }() 683 684 for i := 0; i < 2; i++ { 685 select { 686 case err := <-errChan: 687 if err != nil { 688 t.Fatal(err) 689 } 690 case <-time.After(10 * time.Second): 691 t.Fatal("deadlock") 692 } 693 } 694 695 // Also test CloseWrite being called before the handshake is 696 // finished: 697 { 698 ln2 := newLocalListener(t) 699 defer ln2.Close() 700 701 netConn, err := net.Dial("tcp", ln2.Addr().String()) 702 if err != nil { 703 t.Fatal(err) 704 } 705 defer netConn.Close() 706 conn := Client(netConn, testConfig.Clone()) 707 708 if err := conn.CloseWrite(); err != errEarlyCloseWrite { 709 t.Errorf("CloseWrite error = %v; want errEarlyCloseWrite", err) 710 } 711 } 712 } 713 714 func TestWarningAlertFlood(t *testing.T) { 715 ln := newLocalListener(t) 716 defer ln.Close() 717 718 server := func() error { 719 sconn, err := ln.Accept() 720 if err != nil { 721 return fmt.Errorf("accept: %v", err) 722 } 723 defer sconn.Close() 724 725 serverConfig := testConfig.Clone() 726 srv := Server(sconn, serverConfig) 727 if err := srv.Handshake(); err != nil { 728 return 
fmt.Errorf("handshake: %v", err) 729 } 730 defer srv.Close() 731 732 _, err = io.ReadAll(srv) 733 if err == nil { 734 return errors.New("unexpected lack of error from server") 735 } 736 const expected = "too many ignored" 737 if str := err.Error(); !strings.Contains(str, expected) { 738 return fmt.Errorf("expected error containing %q, but saw: %s", expected, str) 739 } 740 741 return nil 742 } 743 744 errChan := make(chan error, 1) 745 go func() { errChan <- server() }() 746 747 clientConfig := testConfig.Clone() 748 clientConfig.MaxVersion = VersionTLS12 // there are no warning alerts in TLS 1.3 749 conn, err := Dial("tcp", ln.Addr().String(), clientConfig) 750 if err != nil { 751 t.Fatal(err) 752 } 753 defer conn.Close() 754 if err := conn.Handshake(); err != nil { 755 t.Fatal(err) 756 } 757 758 for i := 0; i < maxUselessRecords+1; i++ { 759 conn.sendAlert(alertNoRenegotiation) 760 } 761 762 if err := <-errChan; err != nil { 763 t.Fatal(err) 764 } 765 } 766 767 func TestCloneFuncFields(t *testing.T) { 768 const expectedCount = 8 769 called := 0 770 771 c1 := Config{ 772 Time: func() time.Time { 773 called |= 1 << 0 774 return time.Time{} 775 }, 776 GetCertificate: func(*ClientHelloInfo) (*Certificate, error) { 777 called |= 1 << 1 778 return nil, nil 779 }, 780 GetClientCertificate: func(*CertificateRequestInfo) (*Certificate, error) { 781 called |= 1 << 2 782 return nil, nil 783 }, 784 GetConfigForClient: func(*ClientHelloInfo) (*Config, error) { 785 called |= 1 << 3 786 return nil, nil 787 }, 788 VerifyPeerCertificate: func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { 789 called |= 1 << 4 790 return nil 791 }, 792 VerifyConnection: func(ConnectionState) error { 793 called |= 1 << 5 794 return nil 795 }, 796 UnwrapSession: func(identity []byte, cs ConnectionState) (*SessionState, error) { 797 called |= 1 << 6 798 return nil, nil 799 }, 800 WrapSession: func(cs ConnectionState, ss *SessionState) ([]byte, error) { 801 called |= 1 << 7 802 
return nil, nil 803 }, 804 } 805 806 c2 := c1.Clone() 807 808 c2.Time() 809 c2.GetCertificate(nil) 810 c2.GetClientCertificate(nil) 811 c2.GetConfigForClient(nil) 812 c2.VerifyPeerCertificate(nil, nil) 813 c2.VerifyConnection(ConnectionState{}) 814 c2.UnwrapSession(nil, ConnectionState{}) 815 c2.WrapSession(ConnectionState{}, nil) 816 817 if called != (1<<expectedCount)-1 { 818 t.Fatalf("expected %d calls but saw calls %b", expectedCount, called) 819 } 820 } 821 822 func TestCloneNonFuncFields(t *testing.T) { 823 var c1 Config 824 v := reflect.ValueOf(&c1).Elem() 825 826 typ := v.Type() 827 for i := 0; i < typ.NumField(); i++ { 828 f := v.Field(i) 829 // testing/quick can't handle functions or interfaces and so 830 // isn't used here. 831 switch fn := typ.Field(i).Name; fn { 832 case "Rand": 833 f.Set(reflect.ValueOf(io.Reader(os.Stdin))) 834 case "Time", "GetCertificate", "GetConfigForClient", "VerifyPeerCertificate", "VerifyConnection", "GetClientCertificate", "GetProtoSetting", "WrapSession", "UnwrapSession": 835 // DeepEqual can't compare functions. If you add a 836 // function field to this list, you must also change 837 // TestCloneFuncFields to ensure that the func field is 838 // cloned. 
839 case "Certificates": 840 f.Set(reflect.ValueOf([]Certificate{ 841 {Certificate: [][]byte{{'b'}}}, 842 })) 843 case "NameToCertificate": 844 f.Set(reflect.ValueOf(map[string]*Certificate{"a": nil})) 845 case "CertCompressionDisabled": 846 f.Set(reflect.ValueOf(true)) 847 case "CertCompressionPreferences": 848 f.Set(reflect.ValueOf([]CertCompressionAlgorithm{Zlib})) 849 case "SignatureAlgorithmsPreferences": 850 f.Set(reflect.ValueOf([]SignatureScheme{PKCS1WithSHA1})) 851 case "RootCAs", "ClientCAs": 852 f.Set(reflect.ValueOf(x509.NewCertPool())) 853 case "ClientSessionCache": 854 f.Set(reflect.ValueOf(NewLRUClientSessionCache(10))) 855 case "KeyLogWriter": 856 f.Set(reflect.ValueOf(io.Writer(os.Stdout))) 857 case "NextProtos": 858 f.Set(reflect.ValueOf([]string{"a", "b"})) 859 case "NextProtoSettings": 860 f.Set(reflect.ValueOf(map[string][]byte{"a": nil})) 861 case "ServerName": 862 f.Set(reflect.ValueOf("b")) 863 case "ClientAuth": 864 f.Set(reflect.ValueOf(VerifyClientCertIfGiven)) 865 case "InsecureSkipVerify", "SessionTicketsDisabled", "DynamicRecordSizingDisabled", "PreferServerCipherSuites": 866 f.Set(reflect.ValueOf(true)) 867 case "MinVersion", "MaxVersion": 868 f.Set(reflect.ValueOf(uint16(VersionTLS12))) 869 case "SupportedVersions": 870 f.Set(reflect.ValueOf([]uint16{VersionTLS12})) 871 case "SessionTicketKey": 872 f.Set(reflect.ValueOf([32]byte{})) 873 case "PreferCipherSuites": 874 f.Set(reflect.ValueOf(true)) 875 case "CipherSuites": 876 f.Set(reflect.ValueOf([]uint16{1, 2})) 877 case "CurvePreferences": 878 f.Set(reflect.ValueOf([]CurveID{CurveP256})) 879 case "Renegotiation": 880 f.Set(reflect.ValueOf(RenegotiateOnceAsClient)) 881 case "Allow0RTT": 882 f.Set(reflect.ValueOf(true)) 883 case "Max0RTTData": 884 f.Set(reflect.ValueOf(uint32(2048))) 885 case "MinClientHelloSize": 886 f.Set(reflect.ValueOf(uint16(512))) 887 case "ShuffleExtensions": 888 f.Set(reflect.ValueOf(true)) 889 case "Extensions": 890 
f.Set(reflect.ValueOf([]Extension{&ServerNameIndicationExtension{}})) 891 case "KernelTX", "KernelRX": 892 f.Set(reflect.ValueOf(true)) 893 case "ECHEnabled": 894 f.Set(reflect.ValueOf(true)) 895 case "ClientECHDummyConfig": 896 f.Set(reflect.ValueOf(ECHDummyConfig{ 897 ConfigIds: []uint8{1}, 898 })) 899 case "ClientECHConfigs": 900 f.Set(reflect.ValueOf([]ECHConfig{ 901 {raw: []byte{'b'}}, 902 })) 903 case "ServerECHProvider": 904 f.Set(reflect.ValueOf(ECHProvider(&EXP_ECHKeySet{}))) 905 case "mutex", "autoSessionTicketKeys", "sessionTicketKeys": 906 continue // these are unexported fields that are handled separately 907 default: 908 t.Errorf("all fields must be accounted for, but saw unknown field %q", fn) 909 } 910 } 911 // Set the unexported fields related to session ticket keys, which are copied with Clone(). 912 c1.autoSessionTicketKeys = []ticketKey{c1.ticketKeyFromBytes(c1.SessionTicketKey)} 913 c1.sessionTicketKeys = []ticketKey{c1.ticketKeyFromBytes(c1.SessionTicketKey)} 914 915 c2 := c1.Clone() 916 if !reflect.DeepEqual(&c1, c2) { 917 t.Errorf("clone failed to copy a field") 918 } 919 } 920 921 func TestCloneNilConfig(t *testing.T) { 922 var config *Config 923 if cc := config.Clone(); cc != nil { 924 t.Fatalf("Clone with nil should return nil, got: %+v", cc) 925 } 926 } 927 928 // changeImplConn is a net.Conn which can change its Write and Close 929 // methods. 
type changeImplConn struct {
	net.Conn
	// writeFunc, when non-nil, is called instead of the embedded Conn's Write.
	writeFunc func([]byte) (int, error)
	// closeFunc, when non-nil, is called instead of the embedded Conn's Close.
	closeFunc func() error
}

// Write calls w.writeFunc if it is set; otherwise it forwards to the
// embedded net.Conn.
func (w *changeImplConn) Write(p []byte) (n int, err error) {
	if w.writeFunc != nil {
		return w.writeFunc(p)
	}
	return w.Conn.Write(p)
}

// Close calls w.closeFunc if it is set; otherwise it forwards to the
// embedded net.Conn.
func (w *changeImplConn) Close() error {
	if w.closeFunc != nil {
		return w.closeFunc()
	}
	return w.Conn.Close()
}

// throughput benchmarks echo round trips over local TLS connections.
// Each of the b.N iterations dials a new connection and performs a
// handshake. It then writes at least totalBytes in bufsize chunks and reads
// each chunk back from a server that echoes everything it reads.
// version sets the client's MaxVersion. dynamicRecordSizingDisabled is
// applied to both client and server configs.
func throughput(b *testing.B, version uint16, totalBytes int64, dynamicRecordSizingDisabled bool) {
	ln := newLocalListener(b)
	defer ln.Close()

	// Capture b.N once so the server goroutine and the client loop agree on
	// the number of connections.
	N := b.N

	// Less than 64KB because Windows appears to use a TCP rwin < 64KB.
	// See Issue #15899.
	const bufsize = 32 << 10

	go func() {
		buf := make([]byte, bufsize)
		for i := 0; i < N; i++ {
			sconn, err := ln.Accept()
			if err != nil {
				// panic rather than synchronize to avoid benchmark overhead
				// (cannot call b.Fatal in goroutine)
				panic(fmt.Errorf("accept: %v", err))
			}
			serverConfig := testConfig.Clone()
			serverConfig.CipherSuites = nil // the defaults may prefer faster ciphers
			serverConfig.DynamicRecordSizingDisabled = dynamicRecordSizingDisabled
			srv := Server(sconn, serverConfig)
			if err := srv.Handshake(); err != nil {
				panic(fmt.Errorf("handshake: %v", err))
			}
			// Echo until the client closes its side.
			if _, err := io.CopyBuffer(srv, srv, buf); err != nil {
				panic(fmt.Errorf("copy buffer: %v", err))
			}
		}
	}()

	b.SetBytes(totalBytes)
	clientConfig := testConfig.Clone()
	clientConfig.CipherSuites = nil // the defaults may prefer faster ciphers
	clientConfig.DynamicRecordSizingDisabled = dynamicRecordSizingDisabled
	clientConfig.MaxVersion = version

	buf := make([]byte, bufsize)
	// Round up, so the last chunk may push the total past totalBytes.
	chunks := int(math.Ceil(float64(totalBytes) / float64(len(buf))))
	for i := 0; i < N; i++ {
		conn, err := Dial("tcp", ln.Addr().String(), clientConfig)
		if err != nil {
			b.Fatal(err)
		}
		for j := 0; j < chunks; j++ {
			_, err := conn.Write(buf)
			if err != nil {
				b.Fatal(err)
			}
			_, err = io.ReadFull(conn, buf)
			if err != nil {
				b.Fatal(err)
			}
		}
		conn.Close()
	}
}

// BenchmarkThroughput runs throughput for payloads of 1, 2, 4, ... 64 MB. It
// covers TLS 1.2 and TLS 1.3, with dynamic record sizing disabled ("Max")
// and enabled ("Dynamic").
func BenchmarkThroughput(b *testing.B) {
	for _, mode := range []string{"Max", "Dynamic"} {
		for size := 1; size <= 64; size <<= 1 {
			name := fmt.Sprintf("%sPacket/%dMB", mode, size)
			b.Run(name, func(b *testing.B) {
				b.Run("TLSv12", func(b *testing.B) {
					throughput(b, VersionTLS12, int64(size<<20), mode == "Max")
				})
				b.Run("TLSv13", func(b *testing.B) {
					throughput(b, VersionTLS13, int64(size<<20), mode == "Max")
				})
			})
		}
	}
}

// slowConn is a net.Conn whose Write is throttled to roughly bps bits per
// second (Write divides by 8 to get a byte budget).
type slowConn struct {
	net.Conn
	bps int
}

// Write sends p in pieces, sleeping 100µs between checks. After each sleep
// it forwards the bytes allowed by the time elapsed since the call started.
// It panics if c.bps is zero, because no bytes would ever be allowed.
func (c *slowConn) Write(p []byte) (int, error) {
	if c.bps == 0 {
		panic("too slow")
	}
	t0 := time.Now()
	wrote := 0
	for wrote < len(p) {
		time.Sleep(100 * time.Microsecond)
		allowed := int(time.Since(t0).Seconds()*float64(c.bps)) / 8
		if allowed > len(p) {
			allowed = len(p)
		}
		if wrote < allowed {
			n, err := c.Conn.Write(p[wrote:allowed])
			wrote += n
			if err != nil {
				return wrote, err
			}
		}
	}
	return len(p), nil
}

// latency benchmarks how long it takes the first byte of a 16 KiB write to
// come back. The echoing server writes through a slowConn limited to bps.
// For each of the b.N iterations the client dials a new connection and
// completes a 1-byte round trip. It then writes 16384 bytes and reads a
// single echoed byte.
func latency(b *testing.B, version uint16, bps int, dynamicRecordSizingDisabled bool) {
	ln := newLocalListener(b)
	defer ln.Close()

	N := b.N

	go func() {
		for i := 0; i < N; i++ {
			sconn, err := ln.Accept()
			if err != nil {
				// panic rather than synchronize to avoid benchmark overhead
				// (cannot call b.Fatal in goroutine)
				panic(fmt.Errorf("accept: %v", err))
			}
			serverConfig := testConfig.Clone()
			serverConfig.DynamicRecordSizingDisabled = dynamicRecordSizingDisabled
			srv := Server(&slowConn{sconn, bps}, serverConfig)
			if err := srv.Handshake(); err != nil {
				panic(fmt.Errorf("handshake: %v", err))
			}
			// The copy error is ignored here; the client side reports failures.
			io.Copy(srv, srv)
		}
	}()

	clientConfig := testConfig.Clone()
	clientConfig.DynamicRecordSizingDisabled = dynamicRecordSizingDisabled
	clientConfig.MaxVersion = version

	buf := make([]byte, 16384)
	peek := make([]byte, 1)

	for i := 0; i < N; i++ {
		conn, err := Dial("tcp", ln.Addr().String(), clientConfig)
		if err != nil {
			b.Fatal(err)
		}
		// make sure we're connected and previous connection has stopped
		if _, err := conn.Write(buf[:1]); err != nil {
			b.Fatal(err)
		}
		if _, err := io.ReadFull(conn, peek); err != nil {
			b.Fatal(err)
		}
		if _, err := conn.Write(buf); err != nil {
			b.Fatal(err)
		}
		if _, err = io.ReadFull(conn, peek); err != nil {
			b.Fatal(err)
		}
		conn.Close()
	}
}

// BenchmarkLatency runs latency at 200 to 5000 kbps (bps = kbps*1000). It
// covers TLS 1.2 and TLS 1.3, with dynamic record sizing disabled ("Max")
// and enabled ("Dynamic").
func BenchmarkLatency(b *testing.B) {
	for _, mode := range []string{"Max", "Dynamic"} {
		for _, kbps := range []int{200, 500, 1000, 2000, 5000} {
			name := fmt.Sprintf("%sPacket/%dkbps", mode, kbps)
			b.Run(name, func(b *testing.B) {
				b.Run("TLSv12", func(b *testing.B) {
					latency(b, VersionTLS12, kbps*1000, mode == "Max")
				})
				b.Run("TLSv13", func(b *testing.B) {
					latency(b, VersionTLS13, kbps*1000, mode == "Max")
				})
			})
		}
	}
}

// TestConnectionStateMarshal checks that json.Marshal accepts a zero-value
// ConnectionState without error.
func TestConnectionStateMarshal(t *testing.T) {
	cs := &ConnectionState{}
	_, err := json.Marshal(cs)
	if err != nil {
		t.Errorf("json.Marshal failed on ConnectionState: %v", err)
	}
}

// TestConnectionState handshakes at TLS 1.2 and TLS 1.3 with a config that
// requires and verifies client certificates and offers ALPN. The same config
// is used for both sides. It then checks the ConnectionState reported by
// the server (ss) and the client (cs).
func TestConnectionState(t *testing.T) {
	issuer, err := x509.ParseCertificate(testRSACertificateIssuer)
	if err != nil {
		panic(err)
	}
	rootCAs := x509.NewCertPool()
	rootCAs.AddCert(issuer)

	// Fixed clock and zero randomness keep the handshake deterministic.
	now := func() time.Time { return time.Unix(1476984729, 0) }

	const alpnProtocol = "golang"
	const serverName = "example.golang"
	var scts = [][]byte{[]byte("dummy sct 1"), []byte("dummy sct 2")}
	var ocsp = []byte("dummy ocsp")

	for _, v := range []uint16{VersionTLS12, VersionTLS13} {
		var name string
		switch v {
		case VersionTLS12:
			name = "TLSv12"
		case VersionTLS13:
			name = "TLSv13"
		}
		t.Run(name, func(t *testing.T) {
			config := &Config{
				Time:         now,
				Rand:         zeroSource{},
				Certificates: make([]Certificate, 1),
				MaxVersion:   v,
				RootCAs:      rootCAs,
				ClientCAs:    rootCAs,
				ClientAuth:   RequireAndVerifyClientCert,
				NextProtos:   []string{alpnProtocol},
				ServerName:   serverName,
			}
			config.Certificates[0].Certificate = [][]byte{testRSACertificate}
			config.Certificates[0].PrivateKey = testRSAPrivateKey
			config.Certificates[0].SignedCertificateTimestamps = scts
			config.Certificates[0].OCSPStaple = ocsp

			ss, cs, err := testHandshake(t, config, config)
			if err != nil {
				t.Fatalf("Handshake failed: %v", err)
			}

			if ss.Version != v || cs.Version != v {
				t.Errorf("Got versions %x (server) and %x (client), expected %x", ss.Version, cs.Version, v)
			}

			if !ss.HandshakeComplete || !cs.HandshakeComplete {
				t.Errorf("Got HandshakeComplete %v (server) and %v (client), expected true", ss.HandshakeComplete, cs.HandshakeComplete)
			}

			if ss.DidResume || cs.DidResume {
				t.Errorf("Got DidResume %v (server) and %v (client), expected false", ss.DidResume, cs.DidResume)
			}

			if ss.CipherSuite == 0 || cs.CipherSuite == 0 {
				t.Errorf("Got invalid cipher suite: %v (server) and %v (client)", ss.CipherSuite, cs.CipherSuite)
			}

			if ss.NegotiatedProtocol != alpnProtocol || cs.NegotiatedProtocol != alpnProtocol {
				t.Errorf("Got negotiated protocol %q (server) and %q (client), expected %q", ss.NegotiatedProtocol, cs.NegotiatedProtocol, alpnProtocol)
			}

			if !cs.NegotiatedProtocolIsMutual {
				t.Errorf("Got false NegotiatedProtocolIsMutual on the client side")
			}
			//
NegotiatedProtocolIsMutual on the server side is unspecified. 1199 1200 if ss.ServerName != serverName { 1201 t.Errorf("Got server name %q, expected %q", ss.ServerName, serverName) 1202 } 1203 if cs.ServerName != serverName { 1204 t.Errorf("Got server name on client connection %q, expected %q", cs.ServerName, serverName) 1205 } 1206 1207 if len(ss.PeerCertificates) != 1 || len(cs.PeerCertificates) != 1 { 1208 t.Errorf("Got %d (server) and %d (client) peer certificates, expected %d", len(ss.PeerCertificates), len(cs.PeerCertificates), 1) 1209 } 1210 1211 if len(ss.VerifiedChains) != 1 || len(cs.VerifiedChains) != 1 { 1212 t.Errorf("Got %d (server) and %d (client) verified chains, expected %d", len(ss.VerifiedChains), len(cs.VerifiedChains), 1) 1213 } else if len(ss.VerifiedChains[0]) != 2 || len(cs.VerifiedChains[0]) != 2 { 1214 t.Errorf("Got %d (server) and %d (client) long verified chain, expected %d", len(ss.VerifiedChains[0]), len(cs.VerifiedChains[0]), 2) 1215 } 1216 1217 if len(cs.SignedCertificateTimestamps) != 2 { 1218 t.Errorf("Got %d SCTs, expected %d", len(cs.SignedCertificateTimestamps), 2) 1219 } 1220 if !bytes.Equal(cs.OCSPResponse, ocsp) { 1221 t.Errorf("Got OCSPs %x, expected %x", cs.OCSPResponse, ocsp) 1222 } 1223 // Only TLS 1.3 supports OCSP and SCTs on client certs. 
1224 if v == VersionTLS13 { 1225 if len(ss.SignedCertificateTimestamps) != 2 { 1226 t.Errorf("Got %d client SCTs, expected %d", len(ss.SignedCertificateTimestamps), 2) 1227 } 1228 if !bytes.Equal(ss.OCSPResponse, ocsp) { 1229 t.Errorf("Got client OCSPs %x, expected %x", ss.OCSPResponse, ocsp) 1230 } 1231 } 1232 1233 if v == VersionTLS13 { 1234 if ss.TLSUnique != nil || cs.TLSUnique != nil { 1235 t.Errorf("Got TLSUnique %x (server) and %x (client), expected nil in TLS 1.3", ss.TLSUnique, cs.TLSUnique) 1236 } 1237 } else { 1238 if ss.TLSUnique == nil || cs.TLSUnique == nil { 1239 t.Errorf("Got TLSUnique %x (server) and %x (client), expected non-nil", ss.TLSUnique, cs.TLSUnique) 1240 } 1241 } 1242 }) 1243 } 1244 } 1245 1246 // Issue 28744: Ensure that we don't modify memory 1247 // that Config doesn't own such as Certificates. 1248 func TestBuildNameToCertificate_doesntModifyCertificates(t *testing.T) { 1249 c0 := Certificate{ 1250 Certificate: [][]byte{testRSACertificate}, 1251 PrivateKey: testRSAPrivateKey, 1252 } 1253 c1 := Certificate{ 1254 Certificate: [][]byte{testSNICertificate}, 1255 PrivateKey: testRSAPrivateKey, 1256 } 1257 config := testConfig.Clone() 1258 config.Certificates = []Certificate{c0, c1} 1259 1260 config.BuildNameToCertificate() 1261 got := config.Certificates 1262 want := []Certificate{c0, c1} 1263 if !reflect.DeepEqual(got, want) { 1264 t.Fatalf("Certificates were mutated by BuildNameToCertificate\nGot: %#v\nWant: %#v\n", got, want) 1265 } 1266 } 1267 1268 func testingKey(s string) string { return strings.ReplaceAll(s, "TESTING KEY", "PRIVATE KEY") } 1269 1270 func TestClientHelloInfo_SupportsCertificate(t *testing.T) { 1271 rsaCert := &Certificate{ 1272 Certificate: [][]byte{testRSACertificate}, 1273 PrivateKey: testRSAPrivateKey, 1274 } 1275 pkcs1Cert := &Certificate{ 1276 Certificate: [][]byte{testRSACertificate}, 1277 PrivateKey: testRSAPrivateKey, 1278 SupportedSignatureAlgorithms: []SignatureScheme{PKCS1WithSHA1, PKCS1WithSHA256}, 1279 } 
1280 ecdsaCert := &Certificate{ 1281 // ECDSA P-256 certificate 1282 Certificate: [][]byte{testP256Certificate}, 1283 PrivateKey: testP256PrivateKey, 1284 } 1285 ed25519Cert := &Certificate{ 1286 Certificate: [][]byte{testEd25519Certificate}, 1287 PrivateKey: testEd25519PrivateKey, 1288 } 1289 1290 tests := []struct { 1291 c *Certificate 1292 chi *ClientHelloInfo 1293 wantErr string 1294 }{ 1295 {rsaCert, &ClientHelloInfo{ 1296 ServerName: "example.golang", 1297 SignatureSchemes: []SignatureScheme{PSSWithSHA256}, 1298 SupportedVersions: []uint16{VersionTLS13}, 1299 }, ""}, 1300 {ecdsaCert, &ClientHelloInfo{ 1301 SignatureSchemes: []SignatureScheme{PSSWithSHA256, ECDSAWithP256AndSHA256}, 1302 SupportedVersions: []uint16{VersionTLS13, VersionTLS12}, 1303 }, ""}, 1304 {rsaCert, &ClientHelloInfo{ 1305 ServerName: "example.com", 1306 SignatureSchemes: []SignatureScheme{PSSWithSHA256}, 1307 SupportedVersions: []uint16{VersionTLS13}, 1308 }, "not valid for requested server name"}, 1309 {ecdsaCert, &ClientHelloInfo{ 1310 SignatureSchemes: []SignatureScheme{ECDSAWithP384AndSHA384}, 1311 SupportedVersions: []uint16{VersionTLS13}, 1312 }, "signature algorithms"}, 1313 {pkcs1Cert, &ClientHelloInfo{ 1314 SignatureSchemes: []SignatureScheme{PSSWithSHA256, ECDSAWithP256AndSHA256}, 1315 SupportedVersions: []uint16{VersionTLS13}, 1316 }, "signature algorithms"}, 1317 1318 {rsaCert, &ClientHelloInfo{ 1319 CipherSuites: []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256}, 1320 SignatureSchemes: []SignatureScheme{PKCS1WithSHA1}, 1321 SupportedVersions: []uint16{VersionTLS13, VersionTLS12}, 1322 }, "signature algorithms"}, 1323 {rsaCert, &ClientHelloInfo{ 1324 CipherSuites: []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256}, 1325 SignatureSchemes: []SignatureScheme{PKCS1WithSHA1}, 1326 SupportedVersions: []uint16{VersionTLS13, VersionTLS12}, 1327 config: &Config{ 1328 CipherSuites: []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256}, 1329 MaxVersion: VersionTLS12, 1330 }, 1331 }, ""}, // Check that mutual version 
selection works. 1332 1333 {ecdsaCert, &ClientHelloInfo{ 1334 CipherSuites: []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, 1335 SupportedCurves: []CurveID{CurveP256}, 1336 SupportedPoints: []uint8{pointFormatUncompressed}, 1337 SignatureSchemes: []SignatureScheme{ECDSAWithP256AndSHA256}, 1338 SupportedVersions: []uint16{VersionTLS12}, 1339 }, ""}, 1340 {ecdsaCert, &ClientHelloInfo{ 1341 CipherSuites: []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, 1342 SupportedCurves: []CurveID{CurveP256}, 1343 SupportedPoints: []uint8{pointFormatUncompressed}, 1344 SignatureSchemes: []SignatureScheme{ECDSAWithP384AndSHA384}, 1345 SupportedVersions: []uint16{VersionTLS12}, 1346 }, ""}, // TLS 1.2 does not restrict curves based on the SignatureScheme. 1347 {ecdsaCert, &ClientHelloInfo{ 1348 CipherSuites: []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, 1349 SupportedCurves: []CurveID{CurveP256}, 1350 SupportedPoints: []uint8{pointFormatUncompressed}, 1351 SignatureSchemes: nil, 1352 SupportedVersions: []uint16{VersionTLS12}, 1353 }, ""}, // TLS 1.2 comes with default signature schemes. 
1354 {ecdsaCert, &ClientHelloInfo{ 1355 CipherSuites: []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256}, 1356 SupportedCurves: []CurveID{CurveP256}, 1357 SupportedPoints: []uint8{pointFormatUncompressed}, 1358 SignatureSchemes: []SignatureScheme{ECDSAWithP256AndSHA256}, 1359 SupportedVersions: []uint16{VersionTLS12}, 1360 }, "cipher suite"}, 1361 {ecdsaCert, &ClientHelloInfo{ 1362 CipherSuites: []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, 1363 SupportedCurves: []CurveID{CurveP256}, 1364 SupportedPoints: []uint8{pointFormatUncompressed}, 1365 SignatureSchemes: []SignatureScheme{ECDSAWithP256AndSHA256}, 1366 SupportedVersions: []uint16{VersionTLS12}, 1367 config: &Config{ 1368 CipherSuites: []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256}, 1369 }, 1370 }, "cipher suite"}, 1371 {ecdsaCert, &ClientHelloInfo{ 1372 CipherSuites: []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, 1373 SupportedCurves: []CurveID{CurveP384}, 1374 SupportedPoints: []uint8{pointFormatUncompressed}, 1375 SignatureSchemes: []SignatureScheme{ECDSAWithP256AndSHA256}, 1376 SupportedVersions: []uint16{VersionTLS12}, 1377 }, "certificate curve"}, 1378 {ecdsaCert, &ClientHelloInfo{ 1379 CipherSuites: []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, 1380 SupportedCurves: []CurveID{CurveP256}, 1381 SupportedPoints: []uint8{1}, 1382 SignatureSchemes: []SignatureScheme{ECDSAWithP256AndSHA256}, 1383 SupportedVersions: []uint16{VersionTLS12}, 1384 }, "doesn't support ECDHE"}, 1385 {ecdsaCert, &ClientHelloInfo{ 1386 CipherSuites: []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, 1387 SupportedCurves: []CurveID{CurveP256}, 1388 SupportedPoints: []uint8{pointFormatUncompressed}, 1389 SignatureSchemes: []SignatureScheme{PSSWithSHA256}, 1390 SupportedVersions: []uint16{VersionTLS12}, 1391 }, "signature algorithms"}, 1392 1393 {ed25519Cert, &ClientHelloInfo{ 1394 CipherSuites: []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, 1395 SupportedCurves: []CurveID{CurveP256}, // only relevant for ECDHE support 1396 
SupportedPoints: []uint8{pointFormatUncompressed}, 1397 SignatureSchemes: []SignatureScheme{Ed25519}, 1398 SupportedVersions: []uint16{VersionTLS12}, 1399 }, ""}, 1400 {ed25519Cert, &ClientHelloInfo{ 1401 CipherSuites: []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, 1402 SupportedCurves: []CurveID{CurveP256}, // only relevant for ECDHE support 1403 SupportedPoints: []uint8{pointFormatUncompressed}, 1404 SignatureSchemes: []SignatureScheme{Ed25519}, 1405 SupportedVersions: []uint16{VersionTLS10}, 1406 config: &Config{MinVersion: VersionTLS10}, 1407 }, "doesn't support Ed25519"}, 1408 {ed25519Cert, &ClientHelloInfo{ 1409 CipherSuites: []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, 1410 SupportedCurves: []CurveID{}, 1411 SupportedPoints: []uint8{pointFormatUncompressed}, 1412 SignatureSchemes: []SignatureScheme{Ed25519}, 1413 SupportedVersions: []uint16{VersionTLS12}, 1414 }, "doesn't support ECDHE"}, 1415 1416 {rsaCert, &ClientHelloInfo{ 1417 CipherSuites: []uint16{TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA}, 1418 SupportedCurves: []CurveID{CurveP256}, // only relevant for ECDHE support 1419 SupportedPoints: []uint8{pointFormatUncompressed}, 1420 SupportedVersions: []uint16{VersionTLS10}, 1421 config: &Config{MinVersion: VersionTLS10}, 1422 }, ""}, 1423 {rsaCert, &ClientHelloInfo{ 1424 CipherSuites: []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256}, 1425 SupportedVersions: []uint16{VersionTLS12}, 1426 config: &Config{ 1427 CipherSuites: []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256}, 1428 }, 1429 }, ""}, // static RSA fallback 1430 } 1431 for i, tt := range tests { 1432 err := tt.chi.SupportsCertificate(tt.c) 1433 switch { 1434 case tt.wantErr == "" && err != nil: 1435 t.Errorf("%d: unexpected error: %v", i, err) 1436 case tt.wantErr != "" && err == nil: 1437 t.Errorf("%d: unexpected success", i) 1438 case tt.wantErr != "" && !strings.Contains(err.Error(), tt.wantErr): 1439 t.Errorf("%d: got error %q, expected %q", i, err, tt.wantErr) 1440 } 1441 } 1442 } 1443 1444 func 
TestCipherSuites(t *testing.T) { 1445 var lastID uint16 1446 for _, c := range CipherSuites() { 1447 if lastID > c.ID { 1448 t.Errorf("CipherSuites are not ordered by ID: got %#04x after %#04x", c.ID, lastID) 1449 } else { 1450 lastID = c.ID 1451 } 1452 1453 if c.Insecure { 1454 t.Errorf("%#04x: Insecure CipherSuite returned by CipherSuites()", c.ID) 1455 } 1456 } 1457 lastID = 0 1458 for _, c := range InsecureCipherSuites() { 1459 if lastID > c.ID { 1460 t.Errorf("InsecureCipherSuites are not ordered by ID: got %#04x after %#04x", c.ID, lastID) 1461 } else { 1462 lastID = c.ID 1463 } 1464 1465 if !c.Insecure { 1466 t.Errorf("%#04x: not Insecure CipherSuite returned by InsecureCipherSuites()", c.ID) 1467 } 1468 } 1469 1470 CipherSuiteByID := func(id uint16) *CipherSuite { 1471 for _, c := range CipherSuites() { 1472 if c.ID == id { 1473 return c 1474 } 1475 } 1476 for _, c := range InsecureCipherSuites() { 1477 if c.ID == id { 1478 return c 1479 } 1480 } 1481 return nil 1482 } 1483 1484 for _, c := range cipherSuites { 1485 cc := CipherSuiteByID(c.id) 1486 if cc == nil { 1487 t.Errorf("%#04x: no CipherSuite entry", c.id) 1488 continue 1489 } 1490 1491 if tls12Only := c.flags&suiteTLS12 != 0; tls12Only && len(cc.SupportedVersions) != 1 { 1492 t.Errorf("%#04x: suite is TLS 1.2 only, but SupportedVersions is %v", c.id, cc.SupportedVersions) 1493 } else if !tls12Only && len(cc.SupportedVersions) != 3 { 1494 t.Errorf("%#04x: suite TLS 1.0-1.2, but SupportedVersions is %v", c.id, cc.SupportedVersions) 1495 } 1496 1497 if got := CipherSuiteName(c.id); got != cc.Name { 1498 t.Errorf("%#04x: unexpected CipherSuiteName: got %q, expected %q", c.id, got, cc.Name) 1499 } 1500 } 1501 for _, c := range cipherSuitesTLS13 { 1502 cc := CipherSuiteByID(c.id) 1503 if cc == nil { 1504 t.Errorf("%#04x: no CipherSuite entry", c.id) 1505 continue 1506 } 1507 1508 if cc.Insecure { 1509 t.Errorf("%#04x: Insecure %v, expected false", c.id, cc.Insecure) 1510 } 1511 if 
len(cc.SupportedVersions) != 1 || cc.SupportedVersions[0] != VersionTLS13 { 1512 t.Errorf("%#04x: suite is TLS 1.3 only, but SupportedVersions is %v", c.id, cc.SupportedVersions) 1513 } 1514 1515 if got := CipherSuiteName(c.id); got != cc.Name { 1516 t.Errorf("%#04x: unexpected CipherSuiteName: got %q, expected %q", c.id, got, cc.Name) 1517 } 1518 } 1519 1520 if got := CipherSuiteName(0xabc); got != "0x0ABC" { 1521 t.Errorf("unexpected fallback CipherSuiteName: got %q, expected 0x0ABC", got) 1522 } 1523 1524 if len(cipherSuitesPreferenceOrder) != len(cipherSuites)+len(cipherSuitesTLS13) { 1525 t.Errorf("cipherSuitesPreferenceOrder is not the same size as cipherSuites + cipherSuitesTLS13") 1526 } 1527 if len(cipherSuitesPreferenceOrderNoAES) != len(cipherSuitesPreferenceOrder) { 1528 t.Errorf("cipherSuitesPreferenceOrderNoAES is not the same size as cipherSuitesPreferenceOrder") 1529 } 1530 if len(defaultCipherSuites) >= len(defaultCipherSuitesWithRSAKex) { 1531 t.Errorf("defaultCipherSuitesWithRSAKex should be longer than defaultCipherSuites") 1532 } 1533 1534 // Check that disabled suites are marked insecure. 1535 for _, badSuites := range []map[uint16]bool{disabledCipherSuites, rsaKexCiphers} { 1536 for id := range badSuites { 1537 c := CipherSuiteByID(id) 1538 if c == nil { 1539 t.Errorf("%#04x: no CipherSuite entry", id) 1540 continue 1541 } 1542 if !c.Insecure { 1543 t.Errorf("%#04x: disabled by default but not marked insecure", id) 1544 } 1545 } 1546 } 1547 1548 for i, prefOrder := range [][]uint16{cipherSuitesPreferenceOrder, cipherSuitesPreferenceOrderNoAES} { 1549 // Check that insecure and HTTP/2 bad cipher suites are at the end of 1550 // the preference lists. 
1551 var sawInsecure, sawBad bool 1552 for _, id := range prefOrder { 1553 c := CipherSuiteByID(id) 1554 if c == nil { 1555 t.Errorf("%#04x: no CipherSuite entry", id) 1556 continue 1557 } 1558 1559 if c.Insecure { 1560 sawInsecure = true 1561 } else if sawInsecure { 1562 t.Errorf("%#04x: secure suite after insecure one(s)", id) 1563 } 1564 1565 if http2isBadCipher(id) { 1566 sawBad = true 1567 } else if sawBad { 1568 t.Errorf("%#04x: non-bad suite after bad HTTP/2 one(s)", id) 1569 } 1570 } 1571 1572 // Check that the list is sorted according to the documented criteria. 1573 isBetter := func(a, b int) bool { 1574 aName, bName := CipherSuiteName(prefOrder[a]), CipherSuiteName(prefOrder[b]) 1575 1576 aSuite13, bSuite13 := cipherSuiteTLS13ByID(prefOrder[a]), cipherSuiteTLS13ByID(prefOrder[b]) 1577 switch { 1578 case aSuite13 != nil && bSuite13 != nil: // Both TLS 1.3 cipher 1579 // AES < ChaCha20 1580 if strings.Contains(aName, "AES") && strings.Contains(bName, "CHACHA20") { 1581 return i == 0 // true for cipherSuitesPreferenceOrder 1582 } else if strings.Contains(aName, "CHACHA20") && strings.Contains(bName, "AES") { 1583 return i != 0 // true for cipherSuitesPreferenceOrderNoAES 1584 } 1585 // AES-GCM < AES-CCM 1586 if strings.Contains(aName, "GCM") && strings.Contains(bName, "CCM") { 1587 return true 1588 } else if strings.Contains(aName, "CCM") && strings.Contains(bName, "GCM") { 1589 return false 1590 } 1591 // AES-128 < AES-256 1592 if strings.Contains(aName, "AES_128") && strings.Contains(bName, "AES_256") { 1593 return true 1594 } else if strings.Contains(aName, "AES_256") && strings.Contains(bName, "AES_128") { 1595 return false 1596 } 1597 // AES-CCM < AES-CCM-8 1598 if !strings.Contains(aName, "CCM_8") && strings.Contains(bName, "CCM_8") { 1599 return true 1600 } else if strings.Contains(aName, "CCM_8") && !strings.Contains(bName, "CCM_8") { 1601 return false 1602 } 1603 t.Fatalf("two ciphersuites are equal by all criteria: %v and %v", aName, bName) 1604 
panic("unreachable") 1605 case aSuite13 != nil: // aSuite13 is TLS 1.3 cipher, bSuite13 is TLS 1.2 cipher 1606 // TLS 1.3 < TLS 1.2 1607 return true 1608 case bSuite13 != nil: // bSuite13 is TLS 1.3 cipher, aSuite13 is TLS 1.2 cipher 1609 // TLS 1.3 < TLS 1.2 1610 return false 1611 default: // Both TLS 1.2 cipher 1612 } 1613 1614 aSuite, bSuite := cipherSuiteByID(prefOrder[a]), cipherSuiteByID(prefOrder[b]) 1615 // * < RC4 1616 if !strings.Contains(aName, "RC4") && strings.Contains(bName, "RC4") { 1617 return true 1618 } else if strings.Contains(aName, "RC4") && !strings.Contains(bName, "RC4") { 1619 return false 1620 } 1621 // * < CBC_SHA256 1622 if !strings.Contains(aName, "CBC_SHA256") && strings.Contains(bName, "CBC_SHA256") { 1623 return true 1624 } else if strings.Contains(aName, "CBC_SHA256") && !strings.Contains(bName, "CBC_SHA256") { 1625 return false 1626 } 1627 // * < 3DES 1628 if !strings.Contains(aName, "3DES") && strings.Contains(bName, "3DES") { 1629 return true 1630 } else if strings.Contains(aName, "3DES") && !strings.Contains(bName, "3DES") { 1631 return false 1632 } 1633 // ECDHE < * 1634 if aSuite.flags&suiteECDHE != 0 && bSuite.flags&suiteECDHE == 0 { 1635 return true 1636 } else if aSuite.flags&suiteECDHE == 0 && bSuite.flags&suiteECDHE != 0 { 1637 return false 1638 } 1639 // AEAD < CBC 1640 if aSuite.aead != nil && bSuite.aead == nil { 1641 return true 1642 } else if aSuite.aead == nil && bSuite.aead != nil { 1643 return false 1644 } 1645 // AES < ChaCha20 1646 if strings.Contains(aName, "AES") && strings.Contains(bName, "CHACHA20") { 1647 return i == 0 // true for cipherSuitesPreferenceOrder 1648 } else if strings.Contains(aName, "AES") && strings.Contains(bName, "CAMELLIA") { 1649 return i != 0 // true for cipherSuitesPreferenceOrderNoAES 1650 } 1651 // AES < ARIA 1652 if strings.Contains(aName, "AES") && strings.Contains(bName, "ARIA") { 1653 return true 1654 } else if strings.Contains(aName, "ARIA") && strings.Contains(bName, "AES") { 
1655 return false 1656 } 1657 // ARIA < CAMELLIA 1658 if strings.Contains(aName, "ARIA") && strings.Contains(bName, "CAMELLIA") { 1659 return true 1660 } else if strings.Contains(aName, "CAMELLIA") && strings.Contains(bName, "ARIA") { 1661 return false 1662 } 1663 // CAMELLIA < ChaCha20 1664 if strings.Contains(aName, "CAMELLIA") && strings.Contains(bName, "CHACHA20") { 1665 return i == 0 // true for cipherSuitesPreferenceOrder 1666 } else if strings.Contains(aName, "CHACHA20") && strings.Contains(bName, "CAMELLIA") { 1667 return i != 0 // true for cipherSuitesPreferenceOrderNoAES 1668 } 1669 // AES-GCM < AES-CCM 1670 if strings.Contains(aName, "GCM") && strings.Contains(bName, "CCM") { 1671 return true 1672 } else if strings.Contains(aName, "CCM") && strings.Contains(bName, "GCM") { 1673 return false 1674 } 1675 // ARIA-128 < ARIA-256 1676 if strings.Contains(aName, "ARIA_128") && strings.Contains(bName, "ARIA_256") { 1677 return true 1678 } else if strings.Contains(aName, "ARIA_256") && strings.Contains(bName, "ARIA_128") { 1679 return false 1680 } 1681 // CAMELLIA-128 < CAMELLIA-256 1682 if strings.Contains(aName, "CAMELLIA_128") && strings.Contains(bName, "CAMELLIA_256") { 1683 return true 1684 } else if strings.Contains(aName, "CAMELLIA_256") && strings.Contains(bName, "CAMELLIA_128") { 1685 return false 1686 } 1687 // AES-128 < AES-256 1688 if strings.Contains(aName, "AES_128") && strings.Contains(bName, "AES_256") { 1689 return true 1690 } else if strings.Contains(aName, "AES_256") && strings.Contains(bName, "AES_128") { 1691 return false 1692 } 1693 // AES-CCM < AES-CCM-8 1694 if !strings.Contains(aName, "CCM_8") && strings.Contains(bName, "CCM_8") { 1695 return true 1696 } else if strings.Contains(aName, "CCM_8") && !strings.Contains(bName, "CCM_8") { 1697 return false 1698 } 1699 // ECDSA < RSA 1700 if aSuite.flags&suiteECSign != 0 && bSuite.flags&suiteECSign == 0 { 1701 return true 1702 } else if aSuite.flags&suiteECSign == 0 && bSuite.flags&suiteECSign 
!= 0 { 1703 return false 1704 } 1705 // * < SHA384 1706 if aSuite.flags&suiteSHA384 == 0 && bSuite.flags&suiteSHA384 != 0 { 1707 return true 1708 } else if aSuite.flags&suiteSHA384 != 0 && bSuite.flags&suiteSHA384 == 0 { 1709 return false 1710 } 1711 t.Fatalf("two ciphersuites are equal by all criteria: %v and %v", aName, bName) 1712 panic("unreachable") 1713 } 1714 if !sort.SliceIsSorted(prefOrder, isBetter) { 1715 t.Error("preference order is not sorted according to the rules") 1716 } 1717 } 1718 } 1719 1720 func TestVersionName(t *testing.T) { 1721 if got, exp := VersionName(VersionTLS13), "TLS 1.3"; got != exp { 1722 t.Errorf("unexpected VersionName: got %q, expected %q", got, exp) 1723 } 1724 if got, exp := VersionName(0x12a), "0x012A"; got != exp { 1725 t.Errorf("unexpected fallback VersionName: got %q, expected %q", got, exp) 1726 } 1727 } 1728 1729 // http2isBadCipher is copied from net/http. 1730 // TODO: if it ends up exposed somewhere, use that instead. 1731 func http2isBadCipher(cipher uint16) bool { 1732 switch cipher { 1733 case TLS_RSA_WITH_RC4_128_SHA, 1734 TLS_RSA_WITH_3DES_EDE_CBC_SHA, 1735 TLS_RSA_WITH_AES_128_CBC_SHA, 1736 TLS_RSA_WITH_AES_256_CBC_SHA, 1737 TLS_RSA_WITH_AES_128_CBC_SHA256, 1738 TLS_RSA_WITH_ARIA_128_CBC_SHA256, 1739 TLS_RSA_WITH_ARIA_256_CBC_SHA384, 1740 TLS_RSA_WITH_AES_128_GCM_SHA256, 1741 TLS_RSA_WITH_AES_256_GCM_SHA384, 1742 TLS_RSA_WITH_AES_128_CCM_SHA256, 1743 TLS_RSA_WITH_AES_256_CCM_SHA256, 1744 TLS_RSA_WITH_AES_128_CCM_8_SHA256, 1745 TLS_RSA_WITH_AES_256_CCM_8_SHA256, 1746 TLS_RSA_WITH_ARIA_128_GCM_SHA256, 1747 TLS_RSA_WITH_ARIA_256_GCM_SHA384, 1748 TLS_RSA_WITH_CAMELLIA_128_GCM_SHA256, 1749 TLS_RSA_WITH_CAMELLIA_256_GCM_SHA384, 1750 TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, 1751 TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, 1752 TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, 1753 TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384, 1754 TLS_ECDHE_ECDSA_WITH_ARIA_128_CBC_SHA256, 1755 TLS_ECDHE_ECDSA_WITH_ARIA_256_CBC_SHA384, 1756 
TLS_ECDHE_ECDSA_WITH_CAMELLIA_128_CBC_SHA256, 1757 TLS_ECDHE_ECDSA_WITH_CAMELLIA_256_CBC_SHA384, 1758 TLS_ECDHE_RSA_WITH_RC4_128_SHA, 1759 TLS_ECDHE_ECDSA_WITH_3DES_EDE_CBC_SHA, 1760 TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, 1761 TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, 1762 TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, 1763 TLS_ECDHE_RSA_WITH_ARIA_128_CBC_SHA256, 1764 TLS_ECDHE_RSA_WITH_ARIA_256_CBC_SHA384, 1765 TLS_ECDHE_RSA_WITH_CAMELLIA_128_CBC_SHA256, 1766 TLS_ECDHE_RSA_WITH_CAMELLIA_256_CBC_SHA384, 1767 TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, 1768 TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256: 1769 return true 1770 default: 1771 return false 1772 } 1773 } 1774 1775 type brokenSigner struct{ crypto.Signer } 1776 1777 func (s brokenSigner) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) (signature []byte, err error) { 1778 // Replace opts with opts.HashFunc(), so rsa.PSSOptions are discarded. 1779 return s.Signer.Sign(rand, digest, opts.HashFunc()) 1780 } 1781 1782 // TestPKCS1OnlyCert uses a client certificate with a broken crypto.Signer that 1783 // always makes PKCS #1 v1.5 signatures, so can't be used with RSA-PSS. 1784 func TestPKCS1OnlyCert(t *testing.T) { 1785 clientConfig := testConfig.Clone() 1786 clientConfig.Certificates = []Certificate{{ 1787 Certificate: [][]byte{testRSACertificate}, 1788 PrivateKey: brokenSigner{testRSAPrivateKey}, 1789 }} 1790 serverConfig := testConfig.Clone() 1791 serverConfig.MaxVersion = VersionTLS12 // TLS 1.3 doesn't support PKCS #1 v1.5 1792 serverConfig.ClientAuth = RequireAnyClientCert 1793 1794 // If RSA-PSS is selected, the handshake should fail. 
1795 if _, _, err := testHandshake(t, clientConfig, serverConfig); err == nil { 1796 t.Fatal("expected broken certificate to cause connection to fail") 1797 } 1798 1799 clientConfig.Certificates[0].SupportedSignatureAlgorithms = 1800 []SignatureScheme{PKCS1WithSHA1, PKCS1WithSHA256} 1801 1802 // But if the certificate restricts supported algorithms, RSA-PSS should not 1803 // be selected, and the handshake should succeed. 1804 if _, _, err := testHandshake(t, clientConfig, serverConfig); err != nil { 1805 t.Error(err) 1806 } 1807 } 1808 1809 func TestVerifyCertificates(t *testing.T) { 1810 // See https://go.dev/issue/31641. 1811 t.Run("TLSv12", func(t *testing.T) { testVerifyCertificates(t, VersionTLS12) }) 1812 t.Run("TLSv13", func(t *testing.T) { testVerifyCertificates(t, VersionTLS13) }) 1813 } 1814 1815 func testVerifyCertificates(t *testing.T, version uint16) { 1816 tests := []struct { 1817 name string 1818 1819 InsecureSkipVerify bool 1820 ClientAuth ClientAuthType 1821 ClientCertificates bool 1822 }{ 1823 { 1824 name: "defaults", 1825 }, 1826 { 1827 name: "InsecureSkipVerify", 1828 InsecureSkipVerify: true, 1829 }, 1830 { 1831 name: "RequestClientCert with no certs", 1832 ClientAuth: RequestClientCert, 1833 }, 1834 { 1835 name: "RequestClientCert with certs", 1836 ClientAuth: RequestClientCert, 1837 ClientCertificates: true, 1838 }, 1839 { 1840 name: "RequireAnyClientCert", 1841 ClientAuth: RequireAnyClientCert, 1842 ClientCertificates: true, 1843 }, 1844 { 1845 name: "VerifyClientCertIfGiven with no certs", 1846 ClientAuth: VerifyClientCertIfGiven, 1847 }, 1848 { 1849 name: "VerifyClientCertIfGiven with certs", 1850 ClientAuth: VerifyClientCertIfGiven, 1851 ClientCertificates: true, 1852 }, 1853 { 1854 name: "RequireAndVerifyClientCert", 1855 ClientAuth: RequireAndVerifyClientCert, 1856 ClientCertificates: true, 1857 }, 1858 } 1859 1860 issuer, err := x509.ParseCertificate(testRSACertificateIssuer) 1861 if err != nil { 1862 t.Fatal(err) 1863 } 1864 rootCAs 
:= x509.NewCertPool() 1865 rootCAs.AddCert(issuer) 1866 1867 for _, test := range tests { 1868 test := test 1869 t.Run(test.name, func(t *testing.T) { 1870 t.Parallel() 1871 1872 var serverVerifyConnection, clientVerifyConnection bool 1873 var serverVerifyPeerCertificates, clientVerifyPeerCertificates bool 1874 1875 clientConfig := testConfig.Clone() 1876 clientConfig.Time = func() time.Time { return time.Unix(1476984729, 0) } 1877 clientConfig.MaxVersion = version 1878 clientConfig.MinVersion = version 1879 clientConfig.RootCAs = rootCAs 1880 clientConfig.ServerName = "example.golang" 1881 clientConfig.ClientSessionCache = NewLRUClientSessionCache(1) 1882 serverConfig := clientConfig.Clone() 1883 serverConfig.ClientCAs = rootCAs 1884 1885 clientConfig.VerifyConnection = func(cs ConnectionState) error { 1886 clientVerifyConnection = true 1887 return nil 1888 } 1889 clientConfig.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { 1890 clientVerifyPeerCertificates = true 1891 return nil 1892 } 1893 serverConfig.VerifyConnection = func(cs ConnectionState) error { 1894 serverVerifyConnection = true 1895 return nil 1896 } 1897 serverConfig.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { 1898 serverVerifyPeerCertificates = true 1899 return nil 1900 } 1901 1902 clientConfig.InsecureSkipVerify = test.InsecureSkipVerify 1903 serverConfig.ClientAuth = test.ClientAuth 1904 if !test.ClientCertificates { 1905 clientConfig.Certificates = nil 1906 } 1907 1908 if _, _, err := testHandshake(t, clientConfig, serverConfig); err != nil { 1909 t.Fatal(err) 1910 } 1911 1912 want := serverConfig.ClientAuth != NoClientCert 1913 if serverVerifyPeerCertificates != want { 1914 t.Errorf("VerifyPeerCertificates on the server: got %v, want %v", 1915 serverVerifyPeerCertificates, want) 1916 } 1917 if !clientVerifyPeerCertificates { 1918 t.Errorf("VerifyPeerCertificates not called on the client") 1919 } 1920 
if !serverVerifyConnection { 1921 t.Error("VerifyConnection did not get called on the server") 1922 } 1923 if !clientVerifyConnection { 1924 t.Error("VerifyConnection did not get called on the client") 1925 } 1926 1927 serverVerifyPeerCertificates, clientVerifyPeerCertificates = false, false 1928 serverVerifyConnection, clientVerifyConnection = false, false 1929 cs, _, err := testHandshake(t, clientConfig, serverConfig) 1930 if err != nil { 1931 t.Fatal(err) 1932 } 1933 if !cs.DidResume { 1934 t.Error("expected resumption") 1935 } 1936 1937 if serverVerifyPeerCertificates { 1938 t.Error("VerifyPeerCertificates got called on the server on resumption") 1939 } 1940 if clientVerifyPeerCertificates { 1941 t.Error("VerifyPeerCertificates got called on the client on resumption") 1942 } 1943 if !serverVerifyConnection { 1944 t.Error("VerifyConnection did not get called on the server on resumption") 1945 } 1946 if !clientVerifyConnection { 1947 t.Error("VerifyConnection did not get called on the client on resumption") 1948 } 1949 }) 1950 } 1951 }