github.com/hxx258456/ccgo@v0.0.5-0.20230213014102-48b35f46f66f/gmtls/tls_test.go (about) 1 // Copyright 2022 s1ren@github.com/hxx258456. 2 3 /* 4 gmtls是基于`golang/go`的`tls`包实现的国密改造版本。 5 对应版权声明: thrid_licenses/github.com/golang/go/LICENSE 6 */ 7 8 package gmtls 9 10 import ( 11 "bytes" 12 "context" 13 "crypto" 14 "encoding/json" 15 "errors" 16 "fmt" 17 "io" 18 "math" 19 "net" 20 "os" 21 "reflect" 22 "sort" 23 "strings" 24 "testing" 25 "time" 26 27 "github.com/hxx258456/ccgo/internal/testenv" 28 "github.com/hxx258456/ccgo/x509" 29 ) 30 31 var rsaCertPEM = `-----BEGIN CERTIFICATE----- 32 MIIB0zCCAX2gAwIBAgIJAI/M7BYjwB+uMA0GCSqGSIb3DQEBBQUAMEUxCzAJBgNV 33 BAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYDVQQKDBhJbnRlcm5ldCBX 34 aWRnaXRzIFB0eSBMdGQwHhcNMTIwOTEyMjE1MjAyWhcNMTUwOTEyMjE1MjAyWjBF 35 MQswCQYDVQQGEwJBVTETMBEGA1UECAwKU29tZS1TdGF0ZTEhMB8GA1UECgwYSW50 36 ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMFwwDQYJKoZIhvcNAQEBBQADSwAwSAJBANLJ 37 hPHhITqQbPklG3ibCVxwGMRfp/v4XqhfdQHdcVfHap6NQ5Wok/4xIA+ui35/MmNa 38 rtNuC+BdZ1tMuVCPFZcCAwEAAaNQME4wHQYDVR0OBBYEFJvKs8RfJaXTH08W+SGv 39 zQyKn0H8MB8GA1UdIwQYMBaAFJvKs8RfJaXTH08W+SGvzQyKn0H8MAwGA1UdEwQF 40 MAMBAf8wDQYJKoZIhvcNAQEFBQADQQBJlffJHybjDGxRMqaRmDhX0+6v02TUKZsW 41 r5QuVbpQhH6u+0UgcW0jp9QwpxoPTLTWGXEWBBBurxFwiCBhkQ+V 42 -----END CERTIFICATE----- 43 ` 44 45 var rsaKeyPEM = testingKey(`-----BEGIN RSA TESTING KEY----- 46 MIIBOwIBAAJBANLJhPHhITqQbPklG3ibCVxwGMRfp/v4XqhfdQHdcVfHap6NQ5Wo 47 k/4xIA+ui35/MmNartNuC+BdZ1tMuVCPFZcCAwEAAQJAEJ2N+zsR0Xn8/Q6twa4G 48 6OB1M1WO+k+ztnX/1SvNeWu8D6GImtupLTYgjZcHufykj09jiHmjHx8u8ZZB/o1N 49 MQIhAPW+eyZo7ay3lMz1V01WVjNKK9QSn1MJlb06h/LuYv9FAiEA25WPedKgVyCW 50 SmUwbPw8fnTcpqDWE3yTO3vKcebqMSsCIBF3UmVue8YU3jybC3NxuXq3wNm34R8T 51 xVLHwDXh/6NJAiEAl2oHGGLz64BuAfjKrqwz7qMYr9HCLIe/YsoWq/olzScCIQDi 52 D2lWusoe2/nEqfDVVWGWlyJ7yOmqaVm/iNUN9B2N2g== 53 -----END RSA TESTING KEY----- 54 `) 55 56 // keyPEM is the same as rsaKeyPEM, but declares itself as just 57 // "PRIVATE KEY", not "RSA PRIVATE KEY". https://golang.org/issue/4477 58 var keyPEM = testingKey(`-----BEGIN TESTING KEY----- 59 MIIBOwIBAAJBANLJhPHhITqQbPklG3ibCVxwGMRfp/v4XqhfdQHdcVfHap6NQ5Wo 60 k/4xIA+ui35/MmNartNuC+BdZ1tMuVCPFZcCAwEAAQJAEJ2N+zsR0Xn8/Q6twa4G 61 6OB1M1WO+k+ztnX/1SvNeWu8D6GImtupLTYgjZcHufykj09jiHmjHx8u8ZZB/o1N 62 MQIhAPW+eyZo7ay3lMz1V01WVjNKK9QSn1MJlb06h/LuYv9FAiEA25WPedKgVyCW 63 SmUwbPw8fnTcpqDWE3yTO3vKcebqMSsCIBF3UmVue8YU3jybC3NxuXq3wNm34R8T 64 xVLHwDXh/6NJAiEAl2oHGGLz64BuAfjKrqwz7qMYr9HCLIe/YsoWq/olzScCIQDi 65 D2lWusoe2/nEqfDVVWGWlyJ7yOmqaVm/iNUN9B2N2g== 66 -----END TESTING KEY----- 67 `) 68 69 var ecdsaCertPEM = `-----BEGIN CERTIFICATE----- 70 MIIB/jCCAWICCQDscdUxw16XFDAJBgcqhkjOPQQBMEUxCzAJBgNVBAYTAkFVMRMw 71 EQYDVQQIEwpTb21lLVN0YXRlMSEwHwYDVQQKExhJbnRlcm5ldCBXaWRnaXRzIFB0 72 eSBMdGQwHhcNMTIxMTE0MTI0MDQ4WhcNMTUxMTE0MTI0MDQ4WjBFMQswCQYDVQQG 73 EwJBVTETMBEGA1UECBMKU29tZS1TdGF0ZTEhMB8GA1UEChMYSW50ZXJuZXQgV2lk 74 Z2l0cyBQdHkgTHRkMIGbMBAGByqGSM49AgEGBSuBBAAjA4GGAAQBY9+my9OoeSUR 75 lDQdV/x8LsOuLilthhiS1Tz4aGDHIPwC1mlvnf7fg5lecYpMCrLLhauAc1UJXcgl 76 01xoLuzgtAEAgv2P/jgytzRSpUYvgLBt1UA0leLYBy6mQQbrNEuqT3INapKIcUv8 77 XxYP0xMEUksLPq6Ca+CRSqTtrd/23uTnapkwCQYHKoZIzj0EAQOBigAwgYYCQXJo 78 A7Sl2nLVf+4Iu/tAX/IF4MavARKC4PPHK3zfuGfPR3oCCcsAoz3kAzOeijvd0iXb 79 H5jBImIxPL4WxQNiBTexAkF8D1EtpYuWdlVQ80/h/f4pBcGiXPqX5h2PQSQY7hP1 80 +jwM1FGS4fREIOvlBYr/SzzQRtwrvrzGYxDEDbsC0ZGRnA== 81 -----END CERTIFICATE----- 82 ` 83 84 var ecdsaKeyPEM = testingKey(`-----BEGIN EC PARAMETERS----- 85 BgUrgQQAIw== 86 -----END EC PARAMETERS----- 87 -----BEGIN EC TESTING KEY----- 88 MIHcAgEBBEIBrsoKp0oqcv6/JovJJDoDVSGWdirrkgCWxrprGlzB9o0X8fV675X0 89 NwuBenXFfeZvVcwluO7/Q9wkYoPd/t3jGImgBwYFK4EEACOhgYkDgYYABAFj36bL 90 06h5JRGUNB1X/Hwuw64uKW2GGJLVPPhoYMcg/ALWaW+d/t+DmV5xikwKssuFq4Bz 91 VQldyCXTXGgu7OC0AQCC/Y/+ODK3NFKlRi+AsG3VQDSV4tgHLqZBBus0S6pPcg1q 92 kohxS/xfFg/TEwRSSws+roJr4JFKpO2t3/be5OdqmQ== 93 -----END EC TESTING KEY----- 94 `) 95 96 var keyPairTests = []struct { 97 algo string 98 cert string 99 key string 100 }{ 101 {"ECDSA", ecdsaCertPEM, ecdsaKeyPEM}, 102 {"RSA", rsaCertPEM, rsaKeyPEM}, 103 {"RSA-untyped", rsaCertPEM, keyPEM}, // golang.org/issue/4477 104 } 105 106 func TestX509KeyPair(t *testing.T) { 107 t.Parallel() 108 var pem []byte 109 for _, test := range keyPairTests { 110 pem = []byte(test.cert + test.key) 111 if _, err := X509KeyPair(pem, pem); err != nil { 112 t.Errorf("Failed to load %s cert followed by %s key: %s", test.algo, test.algo, err) 113 } 114 pem = []byte(test.key + test.cert) 115 if _, err := X509KeyPair(pem, pem); err != nil { 116 t.Errorf("Failed to load %s key followed by %s cert: %s", test.algo, test.algo, err) 117 } 118 } 119 } 120 121 func TestX509KeyPairErrors(t *testing.T) { 122 _, err := X509KeyPair([]byte(rsaKeyPEM), []byte(rsaCertPEM)) 123 if err == nil { 124 t.Fatalf("X509KeyPair didn't return an error when arguments were switched") 125 } 126 if subStr := "been switched"; !strings.Contains(err.Error(), subStr) { 127 t.Fatalf("Expected %q in the error when switching arguments to X509KeyPair, but the error was %q", subStr, err) 128 } 129 130 _, err = X509KeyPair([]byte(rsaCertPEM), []byte(rsaCertPEM)) 131 if err == nil { 132 t.Fatalf("X509KeyPair didn't return an error when both arguments were certificates") 133 } 134 if subStr := "certificate"; !strings.Contains(err.Error(), subStr) { 135 t.Fatalf("Expected %q in the error when both arguments to X509KeyPair were certificates, but the error was %q", subStr, err) 136 } 137 138 const nonsensePEM = ` 139 -----BEGIN NONSENSE----- 140 Zm9vZm9vZm9v 141 -----END NONSENSE----- 142 ` 143 144 _, err = X509KeyPair([]byte(nonsensePEM), []byte(nonsensePEM)) 145 if err == nil { 146 t.Fatalf("X509KeyPair didn't return an error when both arguments were nonsense") 147 } 148 if subStr := "NONSENSE"; !strings.Contains(err.Error(), subStr) { 149 t.Fatalf("Expected %q in the error when both arguments to X509KeyPair were nonsense, but the error was %q", subStr, err) 150 } 151 } 152 153 func TestX509MixedKeyPair(t *testing.T) { 154 if _, err := X509KeyPair([]byte(rsaCertPEM), []byte(ecdsaKeyPEM)); err == nil { 155 t.Error("Load of RSA certificate succeeded with ECDSA private key") 156 } 157 if _, err := X509KeyPair([]byte(ecdsaCertPEM), []byte(rsaKeyPEM)); err == nil { 158 t.Error("Load of ECDSA certificate succeeded with RSA private key") 159 } 160 } 161 162 func newLocalListener(t testing.TB) net.Listener { 163 ln, err := net.Listen("tcp", "127.0.0.1:0") 164 if err != nil { 165 ln, err = net.Listen("tcp6", "[::1]:0") 166 } 167 if err != nil { 168 t.Fatal(err) 169 } 170 return ln 171 } 172 173 func TestDialTimeout(t *testing.T) { 174 if testing.Short() { 175 t.Skip("skipping in short mode") 176 } 177 listener := newLocalListener(t) 178 179 addr := listener.Addr().String() 180 defer func(listener net.Listener) { 181 err := listener.Close() 182 if err != nil { 183 panic(err) 184 } 185 }(listener) 186 187 complete := make(chan bool) 188 defer close(complete) 189 190 go func() { 191 conn, err := listener.Accept() 192 if err != nil { 193 t.Error(err) 194 return 195 } 196 <-complete 197 _ = conn.Close() 198 }() 199 200 dialer := &net.Dialer{ 201 Timeout: 10 * time.Millisecond, 202 } 203 204 var err error 205 if _, err = DialWithDialer(dialer, "tcp", addr, nil); err == nil { 206 t.Fatal("DialWithTimeout completed successfully") 207 } 208 209 if !isTimeoutError(err) { 210 t.Errorf("resulting error not a timeout: %v\nType %T: %#v", err, err, err) 211 } 212 } 213 214 func TestDeadlineOnWrite(t *testing.T) { 215 if testing.Short() { 216 t.Skip("skipping in short mode") 217 } 218 219 ln := newLocalListener(t) 220 defer func(ln net.Listener) { 221 err := ln.Close() 222 if err != nil { 223 panic(err) 224 } 225 }(ln) 226 227 srvCh := make(chan *Conn, 1) 228 229 go func() { 230 sconn, err := ln.Accept() 231 if err != nil { 232 srvCh <- nil 233 return 234 } 235 srv := Server(sconn, testConfig.Clone()) 236 if err := srv.Handshake(); err != nil { 237 srvCh <- nil 238 return 239 } 240 srvCh <- srv 241 }() 242 243 clientConfig := testConfig.Clone() 244 // clientConfig.MaxVersion = VersionTLS12 245 clientConfig.MaxVersion = VersionGMSSL 246 conn, err := Dial("tcp", ln.Addr().String(), clientConfig) 247 if err != nil { 248 t.Fatal(err) 249 } 250 defer func(conn *Conn) { 251 err := conn.Close() 252 if err != nil { 253 panic(err) 254 } 255 }(conn) 256 257 srv := <-srvCh 258 if srv == nil { 259 t.Error(err) 260 } 261 262 // Make sure the client/server is setup correctly and is able to do a typical Write/Read 263 buf := make([]byte, 6) 264 if _, err := srv.Write([]byte("foobar")); err != nil { 265 t.Errorf("Write err: %v", err) 266 } 267 if n, err := conn.Read(buf); n != 6 || err != nil || string(buf) != "foobar" { 268 t.Errorf("Read = %d, %v, data %q; want 6, nil, foobar", n, err, buf) 269 } 270 271 // Set a deadline which should cause Write to timeout 272 if err = srv.SetDeadline(time.Now()); err != nil { 273 t.Fatalf("SetDeadline(time.Now()) err: %v", err) 274 } 275 if _, err = srv.Write([]byte("should fail")); err == nil { 276 t.Fatal("Write should have timed out") 277 } 278 279 // Clear deadline and make sure it still times out 280 if err = srv.SetDeadline(time.Time{}); err != nil { 281 t.Fatalf("SetDeadline(time.Time{}) err: %v", err) 282 } 283 if _, err = srv.Write([]byte("This connection is permanently broken")); err == nil { 284 t.Fatal("Write which previously failed should still time out") 285 } 286 287 // Verify the error 288 if ne := err.(net.Error); ne.Temporary() != false { 289 t.Error("Write timed out but incorrectly classified the error as Temporary") 290 } 291 if !isTimeoutError(err) { 292 t.Error("Write timed out but did not classify the error as a Timeout") 293 } 294 } 295 296 type readerFunc func([]byte) (int, error) 297 298 func (f readerFunc) Read(b []byte) (int, error) { return f(b) } 299 300 // TestDialer tests that tls.Dialer.DialContext can abort in the middle of a handshake. 301 // (The other cases are all handled by the existing dial tests in this package, which 302 // all also flow through the same code shared code paths) 303 func TestDialer(t *testing.T) { 304 ln := newLocalListener(t) 305 defer func(ln net.Listener) { 306 err := ln.Close() 307 if err != nil { 308 panic(err) 309 } 310 }(ln) 311 312 unblockServer := make(chan struct{}) // close-only 313 defer close(unblockServer) 314 go func() { 315 conn, err := ln.Accept() 316 if err != nil { 317 return 318 } 319 defer func(conn net.Conn) { 320 err := conn.Close() 321 if err != nil { 322 panic(err) 323 } 324 }(conn) 325 <-unblockServer 326 }() 327 328 ctx, cancel := context.WithCancel(context.Background()) 329 d := Dialer{Config: &Config{ 330 Rand: readerFunc(func(b []byte) (n int, err error) { 331 // By the time crypto/tls wants randomness, that means it has a TCP 332 // connection, so we're past the Dialer's dial and now blocked 333 // in a handshake. Cancel our context and see if we get unstuck. 334 // (Our TCP listener above never reads or writes, so the Handshake 335 // would otherwise be stuck forever) 336 cancel() 337 return len(b), nil 338 }), 339 ServerName: "foo", 340 }} 341 _, err := d.DialContext(ctx, "tcp", ln.Addr().String()) 342 if err != context.Canceled { 343 t.Errorf("err = %v; want context.Canceled", err) 344 } 345 } 346 347 func isTimeoutError(err error) bool { 348 if ne, ok := err.(net.Error); ok { 349 return ne.Timeout() 350 } 351 return false 352 } 353 354 // tests that Conn.Read returns (non-zero, io.EOF) instead of 355 // (non-zero, nil) when a Close (alertCloseNotify) is sitting right 356 // behind the application data in the buffer. 357 func TestConnReadNonzeroAndEOF(t *testing.T) { 358 // This test is racy: it assumes that after a write to a 359 // localhost TCP connection, the peer TCP connection can 360 // immediately read it. Because it's racy, we skip this test 361 // in short mode, and then retry it several times with an 362 // increasing sleep in between our final write (via srv.Close 363 // below) and the following read. 364 if testing.Short() { 365 t.Skip("skipping in short mode") 366 } 367 var err error 368 for delay := time.Millisecond; delay <= 64*time.Millisecond; delay *= 2 { 369 if err = testConnReadNonzeroAndEOF(t, delay); err == nil { 370 return 371 } 372 } 373 t.Error(err) 374 } 375 376 func testConnReadNonzeroAndEOF(t *testing.T, delay time.Duration) error { 377 ln := newLocalListener(t) 378 defer func(ln net.Listener) { 379 err := ln.Close() 380 if err != nil { 381 panic(err) 382 } 383 }(ln) 384 385 srvCh := make(chan *Conn, 1) 386 var serr error 387 go func() { 388 sconn, err := ln.Accept() 389 if err != nil { 390 serr = err 391 srvCh <- nil 392 return 393 } 394 serverConfig := testConfig.Clone() 395 srv := Server(sconn, serverConfig) 396 if err := srv.Handshake(); err != nil { 397 serr = fmt.Errorf("handshake: %v", err) 398 srvCh <- nil 399 return 400 } 401 srvCh <- srv 402 }() 403 404 clientConfig := testConfig.Clone() 405 // In TLS 1.3, alerts are encrypted and disguised as application data, so 406 // the opportunistic peek won't work. 407 clientConfig.MaxVersion = VersionTLS12 408 conn, err := Dial("tcp", ln.Addr().String(), clientConfig) 409 if err != nil { 410 t.Fatal(err) 411 } 412 defer func(conn *Conn) { 413 err := conn.Close() 414 if err != nil { 415 panic(err) 416 } 417 }(conn) 418 419 srv := <-srvCh 420 if srv == nil { 421 return serr 422 } 423 424 buf := make([]byte, 6) 425 426 srv.Write([]byte("foobar")) 427 n, err := conn.Read(buf) 428 if n != 6 || err != nil || string(buf) != "foobar" { 429 return fmt.Errorf("Read = %d, %v, data %q; want 6, nil, foobar", n, err, buf) 430 } 431 432 _, _ = srv.Write([]byte("abcdef")) 433 _ = srv.Close() 434 time.Sleep(delay) 435 n, err = conn.Read(buf) 436 if n != 6 || string(buf) != "abcdef" { 437 //goland:noinspection GoErrorStringFormat 438 return fmt.Errorf("Read = %d, buf= %q; want 6, abcdef", n, buf) 439 } 440 if err != io.EOF { 441 //goland:noinspection GoErrorStringFormat 442 return fmt.Errorf("Second Read error = %v; want io.EOF", err) 443 } 444 return nil 445 } 446 447 func TestTLSUniqueMatches(t *testing.T) { 448 ln := newLocalListener(t) 449 defer func(ln net.Listener) { 450 err := ln.Close() 451 if err != nil { 452 panic(err) 453 } 454 }(ln) 455 456 serverTLSUniques := make(chan []byte) 457 parentDone := make(chan struct{}) 458 childDone := make(chan struct{}) 459 defer close(parentDone) 460 go func() { 461 defer close(childDone) 462 for i := 0; i < 2; i++ { 463 sconn, err := ln.Accept() 464 if err != nil { 465 t.Error(err) 466 return 467 } 468 serverConfig := testConfig.Clone() 469 serverConfig.MaxVersion = VersionTLS12 // TLSUnique is not defined in TLS 1.3 470 srv := Server(sconn, serverConfig) 471 if err := srv.Handshake(); err != nil { 472 t.Error(err) 473 return 474 } 475 select { 476 case <-parentDone: 477 return 478 case serverTLSUniques <- srv.ConnectionState().TLSUnique: 479 } 480 } 481 }() 482 483 clientConfig := testConfig.Clone() 484 clientConfig.ClientSessionCache = NewLRUClientSessionCache(1) 485 conn, err := Dial("tcp", ln.Addr().String(), clientConfig) 486 if err != nil { 487 t.Fatal(err) 488 } 489 490 var serverTLSUniquesValue []byte 491 select { 492 case <-childDone: 493 return 494 case serverTLSUniquesValue = <-serverTLSUniques: 495 } 496 497 if !bytes.Equal(conn.ConnectionState().TLSUnique, serverTLSUniquesValue) { 498 t.Error("client and server channel bindings differ") 499 } 500 _ = conn.Close() 501 502 conn, err = Dial("tcp", ln.Addr().String(), clientConfig) 503 if err != nil { 504 t.Fatal(err) 505 } 506 defer func(conn *Conn) { 507 err := conn.Close() 508 if err != nil { 509 panic(err) 510 } 511 }(conn) 512 if !conn.ConnectionState().DidResume { 513 t.Error("second session did not use resumption") 514 } 515 516 select { 517 case <-childDone: 518 return 519 case serverTLSUniquesValue = <-serverTLSUniques: 520 } 521 522 if !bytes.Equal(conn.ConnectionState().TLSUnique, serverTLSUniquesValue) { 523 t.Error("client and server channel bindings differ when session resumption is used") 524 } 525 } 526 527 func TestVerifyHostname(t *testing.T) { 528 // 想找个国内的支持tls1.3的站点,但是大部分都还只支持tls1.2,好不容易找到一个思否是支持tls1.3的。 529 // github等国外网站倒是支持tls1.3,奈何访问不稳定。 530 uri := "segmentfault.com" 531 532 testenv.MustHaveExternalNetwork(t) 533 clientConfig := &Config{ 534 CurvePreferences: []CurveID{CurveP256}, 535 // MaxVersion: VersionTLS12, 536 } 537 fmt.Println("---------- 访问 https://" + uri + " ----------") 538 c, err := Dial("tcp", uri+":https", clientConfig) 539 if err != nil { 540 t.Fatal(err) 541 } 542 if err := c.VerifyHostname(uri); err != nil { 543 t.Fatalf("verify %s: %v", uri, err) 544 } 545 if err := c.VerifyHostname("test.com"); err == nil { 546 t.Fatalf("verify test.com succeeded") 547 } 548 549 fmt.Println("---------- 访问 https://" + uri + " 不做证书检查 ----------") 550 clientConfig.InsecureSkipVerify = true 551 c, err = Dial("tcp", uri+":https", clientConfig) 552 if err != nil { 553 t.Fatal(err) 554 } 555 if err := c.VerifyHostname(uri); err == nil { 556 t.Fatalf("verify %s succeeded with InsecureSkipVerify=true", uri) 557 } 558 } 559 560 func TestConnCloseBreakingWrite(t *testing.T) { 561 ln := newLocalListener(t) 562 defer func(ln net.Listener) { 563 err := ln.Close() 564 if err != nil { 565 panic(err) 566 } 567 }(ln) 568 569 srvCh := make(chan *Conn, 1) 570 var serr error 571 var sconn net.Conn 572 go func() { 573 var err error 574 sconn, err = ln.Accept() 575 if err != nil { 576 serr = err 577 srvCh <- nil 578 return 579 } 580 serverConfig := testConfig.Clone() 581 srv := Server(sconn, serverConfig) 582 if err := srv.Handshake(); err != nil { 583 serr = fmt.Errorf("handshake: %v", err) 584 srvCh <- nil 585 return 586 } 587 srvCh <- srv 588 }() 589 590 cconn, err := net.Dial("tcp", ln.Addr().String()) 591 if err != nil { 592 t.Fatal(err) 593 } 594 defer func(cconn net.Conn) { 595 err := cconn.Close() 596 if err != nil { 597 panic(err) 598 } 599 }(cconn) 600 601 conn := &changeImplConn{ 602 Conn: cconn, 603 } 604 605 clientConfig := testConfig.Clone() 606 tconn := Client(conn, clientConfig) 607 if err := tconn.Handshake(); err != nil { 608 t.Fatal(err) 609 } 610 611 srv := <-srvCh 612 if srv == nil { 613 t.Fatal(serr) 614 } 615 defer func(sconn net.Conn) { 616 err := sconn.Close() 617 if err != nil { 618 panic(err) 619 } 620 }(sconn) 621 622 connClosed := make(chan struct{}) 623 conn.closeFunc = func() error { 624 close(connClosed) 625 return nil 626 } 627 628 inWrite := make(chan bool, 1) 629 var errConnClosed = errors.New("conn closed for test") 630 conn.writeFunc = func(p []byte) (n int, err error) { 631 inWrite <- true 632 <-connClosed 633 return 0, errConnClosed 634 } 635 636 closeReturned := make(chan bool, 1) 637 go func() { 638 <-inWrite 639 _ = tconn.Close() // test that this doesn't block forever. 640 closeReturned <- true 641 }() 642 643 _, err = tconn.Write([]byte("foo")) 644 if err != errConnClosed { 645 t.Errorf("Write error = %v; want errConnClosed", err) 646 } 647 648 <-closeReturned 649 if err := tconn.Close(); err != net.ErrClosed { 650 t.Errorf("Close error = %v; want net.ErrClosed", err) 651 } 652 } 653 654 func TestConnCloseWrite(t *testing.T) { 655 ln := newLocalListener(t) 656 defer func(ln net.Listener) { 657 err := ln.Close() 658 if err != nil { 659 panic(err) 660 } 661 }(ln) 662 663 clientDoneChan := make(chan struct{}) 664 665 serverCloseWrite := func() error { 666 sconn, err := ln.Accept() 667 if err != nil { 668 return fmt.Errorf("accept: %v", err) 669 } 670 defer func(sconn net.Conn) { 671 err := sconn.Close() 672 if err != nil { 673 panic(err) 674 } 675 }(sconn) 676 677 serverConfig := testConfig.Clone() 678 srv := Server(sconn, serverConfig) 679 if err := srv.Handshake(); err != nil { 680 return fmt.Errorf("handshake: %v", err) 681 } 682 defer func(srv *Conn) { 683 err := srv.Close() 684 if err != nil { 685 panic(err) 686 } 687 }(srv) 688 689 data, err := io.ReadAll(srv) 690 if err != nil { 691 return err 692 } 693 if len(data) > 0 { 694 //goland:noinspection GoErrorStringFormat 695 return fmt.Errorf("Read data = %q; want nothing", data) 696 } 697 698 if err := srv.CloseWrite(); err != nil { 699 return fmt.Errorf("server CloseWrite: %v", err) 700 } 701 702 // Wait for clientCloseWrite to finish, so we know we 703 // tested the CloseWrite before we defer the 704 // sconn.Close above, which would also cause the 705 // client to unblock like CloseWrite. 706 <-clientDoneChan 707 return nil 708 } 709 710 clientCloseWrite := func() error { 711 defer close(clientDoneChan) 712 713 clientConfig := testConfig.Clone() 714 conn, err := Dial("tcp", ln.Addr().String(), clientConfig) 715 if err != nil { 716 return err 717 } 718 if err := conn.Handshake(); err != nil { 719 return err 720 } 721 defer func(conn *Conn) { 722 err := conn.Close() 723 if err != nil { 724 panic(err) 725 } 726 }(conn) 727 728 if err := conn.CloseWrite(); err != nil { 729 return fmt.Errorf("client CloseWrite: %v", err) 730 } 731 732 if _, err := conn.Write([]byte{0}); err != errShutdown { 733 return fmt.Errorf("CloseWrite error = %v; want errShutdown", err) 734 } 735 736 data, err := io.ReadAll(conn) 737 if err != nil { 738 return err 739 } 740 if len(data) > 0 { 741 //goland:noinspection GoErrorStringFormat 742 return fmt.Errorf("Read data = %q; want nothing", data) 743 } 744 return nil 745 } 746 747 errChan := make(chan error, 2) 748 749 go func() { errChan <- serverCloseWrite() }() 750 go func() { errChan <- clientCloseWrite() }() 751 752 for i := 0; i < 2; i++ { 753 select { 754 case err := <-errChan: 755 if err != nil { 756 t.Fatal(err) 757 } 758 case <-time.After(10 * time.Second): 759 t.Fatal("deadlock") 760 } 761 } 762 763 // Also test CloseWrite being called before the handshake is 764 // finished: 765 { 766 ln2 := newLocalListener(t) 767 defer func(ln2 net.Listener) { 768 err := ln2.Close() 769 if err != nil { 770 panic(err) 771 } 772 }(ln2) 773 774 netConn, err := net.Dial("tcp", ln2.Addr().String()) 775 if err != nil { 776 t.Fatal(err) 777 } 778 defer func(netConn net.Conn) { 779 err := netConn.Close() 780 if err != nil { 781 panic(err) 782 } 783 }(netConn) 784 conn := Client(netConn, testConfig.Clone()) 785 786 if err := conn.CloseWrite(); err != errEarlyCloseWrite { 787 t.Errorf("CloseWrite error = %v; want errEarlyCloseWrite", err) 788 } 789 } 790 } 791 792 func TestWarningAlertFlood(t *testing.T) { 793 ln := newLocalListener(t) 794 defer func(ln net.Listener) { 795 err := ln.Close() 796 if err != nil { 797 panic(err) 798 } 799 }(ln) 800 801 server := func() error { 802 sconn, err := ln.Accept() 803 if err != nil { 804 return fmt.Errorf("accept: %v", err) 805 } 806 defer func(sconn net.Conn) { 807 err := sconn.Close() 808 if err != nil { 809 panic(err) 810 } 811 }(sconn) 812 813 serverConfig := testConfig.Clone() 814 srv := Server(sconn, serverConfig) 815 if err := srv.Handshake(); err != nil { 816 return fmt.Errorf("handshake: %v", err) 817 } 818 defer func(srv *Conn) { 819 err := srv.Close() 820 if err != nil { 821 panic(err) 822 } 823 }(srv) 824 825 _, err = io.ReadAll(srv) 826 if err == nil { 827 return errors.New("unexpected lack of error from server") 828 } 829 const expected = "too many ignored" 830 if str := err.Error(); !strings.Contains(str, expected) { 831 return fmt.Errorf("expected error containing %q, but saw: %s", expected, str) 832 } 833 834 return nil 835 } 836 837 errChan := make(chan error, 1) 838 go func() { errChan <- server() }() 839 840 clientConfig := testConfig.Clone() 841 clientConfig.MaxVersion = VersionTLS12 // there are no warning alerts in TLS 1.3 842 conn, err := Dial("tcp", ln.Addr().String(), clientConfig) 843 if err != nil { 844 t.Fatal(err) 845 } 846 defer func(conn *Conn) { 847 err := conn.Close() 848 if err != nil { 849 panic(err) 850 } 851 }(conn) 852 if err := conn.Handshake(); err != nil { 853 t.Fatal(err) 854 } 855 856 for i := 0; i < maxUselessRecords+1; i++ { 857 _ = conn.sendAlert(alertNoRenegotiation) 858 } 859 860 if err := <-errChan; err != nil { 861 t.Fatal(err) 862 } 863 } 864 865 func TestCloneFuncFields(t *testing.T) { 866 const expectedCount = 6 867 called := 0 868 869 c1 := Config{ 870 Time: func() time.Time { 871 called |= 1 << 0 872 return time.Time{} 873 }, 874 GetCertificate: func(*ClientHelloInfo) (*Certificate, error) { 875 called |= 1 << 1 876 return nil, nil 877 }, 878 GetClientCertificate: func(*CertificateRequestInfo) (*Certificate, error) { 879 called |= 1 << 2 880 return nil, nil 881 }, 882 GetConfigForClient: func(*ClientHelloInfo) (*Config, error) { 883 called |= 1 << 3 884 return nil, nil 885 }, 886 VerifyPeerCertificate: func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { 887 called |= 1 << 4 888 return nil 889 }, 890 VerifyConnection: func(ConnectionState) error { 891 called |= 1 << 5 892 return nil 893 }, 894 } 895 896 c2 := c1.Clone() 897 898 c2.Time() 899 _, _ = c2.GetCertificate(nil) 900 _, _ = c2.GetClientCertificate(nil) 901 _, _ = c2.GetConfigForClient(nil) 902 _ = c2.VerifyPeerCertificate(nil, nil) 903 _ = c2.VerifyConnection(ConnectionState{}) 904 905 if called != (1<<expectedCount)-1 { 906 t.Fatalf("expected %d calls but saw calls %b", expectedCount, called) 907 } 908 } 909 910 func TestCloneNonFuncFields(t *testing.T) { 911 var c1 Config 912 v := reflect.ValueOf(&c1).Elem() 913 914 typ := v.Type() 915 for i := 0; i < typ.NumField(); i++ { 916 f := v.Field(i) 917 // testing/quick can't handle functions or interfaces and so 918 // isn't used here. 919 switch fn := typ.Field(i).Name; fn { 920 case "Rand": 921 f.Set(reflect.ValueOf(io.Reader(os.Stdin))) 922 case "Time", "GetCertificate", "GetConfigForClient", "VerifyPeerCertificate", "VerifyConnection", "GetClientCertificate": 923 // DeepEqual can't compare functions. If you add a 924 // function field to this list, you must also change 925 // TestCloneFuncFields to ensure that the func field is 926 // cloned. 927 case "Certificates": 928 f.Set(reflect.ValueOf([]Certificate{ 929 {Certificate: [][]byte{{'b'}}}, 930 })) 931 case "NameToCertificate": 932 f.Set(reflect.ValueOf(map[string]*Certificate{"a": nil})) 933 case "RootCAs", "ClientCAs": 934 f.Set(reflect.ValueOf(x509.NewCertPool())) 935 case "ClientSessionCache": 936 f.Set(reflect.ValueOf(NewLRUClientSessionCache(10))) 937 case "KeyLogWriter": 938 f.Set(reflect.ValueOf(io.Writer(os.Stdout))) 939 case "NextProtos": 940 f.Set(reflect.ValueOf([]string{"a", "b"})) 941 case "ServerName": 942 f.Set(reflect.ValueOf("b")) 943 case "ClientAuth": 944 f.Set(reflect.ValueOf(VerifyClientCertIfGiven)) 945 case "InsecureSkipVerify", "SessionTicketsDisabled", "DynamicRecordSizingDisabled", "PreferServerCipherSuites": 946 f.Set(reflect.ValueOf(true)) 947 case "MinVersion", "MaxVersion": 948 f.Set(reflect.ValueOf(uint16(VersionTLS12))) 949 case "SessionTicketKey": 950 f.Set(reflect.ValueOf([32]byte{})) 951 case "CipherSuites": 952 f.Set(reflect.ValueOf([]uint16{1, 2})) 953 case "CurvePreferences": 954 f.Set(reflect.ValueOf([]CurveID{CurveP256})) 955 case "Renegotiation": 956 f.Set(reflect.ValueOf(RenegotiateOnceAsClient)) 957 case "mutex", "autoSessionTicketKeys", "sessionTicketKeys": 958 continue // these are unexported fields that are handled separately 959 default: 960 t.Errorf("all fields must be accounted for, but saw unknown field %q", fn) 961 } 962 } 963 // Set the unexported fields related to session ticket keys, which are copied with Clone(). 964 c1.autoSessionTicketKeys = []ticketKey{c1.ticketKeyFromBytes(c1.SessionTicketKey)} 965 c1.sessionTicketKeys = []ticketKey{c1.ticketKeyFromBytes(c1.SessionTicketKey)} 966 967 c2 := c1.Clone() 968 if !reflect.DeepEqual(&c1, c2) { 969 t.Errorf("clone failed to copy a field") 970 } 971 } 972 973 func TestCloneNilConfig(t *testing.T) { 974 var config *Config 975 if cc := config.Clone(); cc != nil { 976 t.Fatalf("Clone with nil should return nil, got: %+v", cc) 977 } 978 } 979 980 // changeImplConn is a net.Conn which can change its Write and Close 981 // methods. 982 type changeImplConn struct { 983 net.Conn 984 writeFunc func([]byte) (int, error) 985 closeFunc func() error 986 } 987 988 func (w *changeImplConn) Write(p []byte) (n int, err error) { 989 if w.writeFunc != nil { 990 return w.writeFunc(p) 991 } 992 return w.Conn.Write(p) 993 } 994 995 func (w *changeImplConn) Close() error { 996 if w.closeFunc != nil { 997 return w.closeFunc() 998 } 999 return w.Conn.Close() 1000 } 1001 1002 func throughput(b *testing.B, version uint16, totalBytes int64, dynamicRecordSizingDisabled bool) { 1003 ln := newLocalListener(b) 1004 defer func(ln net.Listener) { 1005 err := ln.Close() 1006 if err != nil { 1007 panic(err) 1008 } 1009 }(ln) 1010 1011 N := b.N 1012 1013 // Less than 64KB because Windows appears to use a TCP rwin < 64KB. 1014 // See Issue #15899. 1015 const bufsize = 32 << 10 1016 1017 go func() { 1018 buf := make([]byte, bufsize) 1019 for i := 0; i < N; i++ { 1020 sconn, err := ln.Accept() 1021 if err != nil { 1022 // panic rather than synchronize to avoid benchmark overhead 1023 // (cannot call b.Fatal in goroutine) 1024 panic(fmt.Errorf("accept: %v", err)) 1025 } 1026 serverConfig := testConfig.Clone() 1027 serverConfig.CipherSuites = nil // the defaults may prefer faster ciphers 1028 serverConfig.DynamicRecordSizingDisabled = dynamicRecordSizingDisabled 1029 srv := Server(sconn, serverConfig) 1030 if err := srv.Handshake(); err != nil { 1031 panic(fmt.Errorf("handshake: %v", err)) 1032 } 1033 if _, err := io.CopyBuffer(srv, srv, buf); err != nil { 1034 panic(fmt.Errorf("copy buffer: %v", err)) 1035 } 1036 } 1037 }() 1038 1039 b.SetBytes(totalBytes) 1040 clientConfig := testConfig.Clone() 1041 clientConfig.CipherSuites = nil // the defaults may prefer faster ciphers 1042 clientConfig.DynamicRecordSizingDisabled = dynamicRecordSizingDisabled 1043 clientConfig.MaxVersion = version 1044 1045 buf := make([]byte, bufsize) 1046 chunks := int(math.Ceil(float64(totalBytes) / float64(len(buf)))) 1047 for i := 0; i < N; i++ { 1048 conn, err := Dial("tcp", ln.Addr().String(), clientConfig) 1049 if err != nil { 1050 b.Fatal(err) 1051 } 1052 for j := 0; j < chunks; j++ { 1053 _, err := conn.Write(buf) 1054 if err != nil { 1055 b.Fatal(err) 1056 } 1057 _, err = io.ReadFull(conn, buf) 1058 if err != nil { 1059 b.Fatal(err) 1060 } 1061 } 1062 _ = conn.Close() 1063 } 1064 } 1065 1066 func BenchmarkThroughput(b *testing.B) { 1067 for _, mode := range []string{"Max", "Dynamic"} { 1068 for size := 1; size <= 64; size <<= 1 { 1069 name := fmt.Sprintf("%sPacket/%dMB", mode, size) 1070 b.Run(name, func(b *testing.B) { 1071 b.Run("TLSv12", func(b *testing.B) { 1072 throughput(b, VersionTLS12, int64(size<<20), mode == "Max") 1073 }) 1074 b.Run("TLSv13", func(b *testing.B) { 1075 throughput(b, VersionTLS13, int64(size<<20), mode == "Max") 1076 }) 1077 }) 1078 } 1079 } 1080 } 1081 1082 type slowConn struct { 1083 net.Conn 1084 bps int 1085 } 1086 1087 func (c *slowConn) Write(p []byte) (int, error) { 1088 if c.bps == 0 { 1089 panic("too slow") 1090 } 1091 t0 := time.Now() 1092 wrote := 0 1093 for wrote < len(p) { 1094 time.Sleep(100 * time.Microsecond) 1095 allowed := int(time.Since(t0).Seconds()*float64(c.bps)) / 8 1096 if allowed > len(p) { 1097 allowed = len(p) 1098 } 1099 if wrote < allowed { 1100 n, err := c.Conn.Write(p[wrote:allowed]) 1101 wrote += n 1102 if err != nil { 1103 return wrote, err 1104 } 1105 } 1106 } 1107 return len(p), nil 1108 } 1109 1110 func latency(b *testing.B, version uint16, bps int, dynamicRecordSizingDisabled bool) { 1111 ln := newLocalListener(b) 1112 defer func(ln net.Listener) { 1113 err := ln.Close() 1114 if err != nil { 1115 panic(err) 1116 } 1117 }(ln) 1118 1119 N := b.N 1120 1121 go func() { 1122 for i := 0; i < N; i++ { 1123 sconn, err := ln.Accept() 1124 if err != nil { 1125 // panic rather than synchronize to avoid benchmark overhead 1126 // (cannot call b.Fatal in goroutine) 1127 panic(fmt.Errorf("accept: %v", err)) 1128 } 1129 serverConfig := testConfig.Clone() 1130 serverConfig.DynamicRecordSizingDisabled = dynamicRecordSizingDisabled 1131 srv := Server(&slowConn{sconn, bps}, serverConfig) 1132 if err := srv.Handshake(); err != nil { 1133 panic(fmt.Errorf("handshake: %v", err)) 1134 } 1135 _, _ = io.Copy(srv, srv) 1136 } 1137 }() 1138 1139 clientConfig := testConfig.Clone() 1140 clientConfig.DynamicRecordSizingDisabled = dynamicRecordSizingDisabled 1141 clientConfig.MaxVersion = version 1142 1143 buf := make([]byte, 16384) 1144 peek := make([]byte, 1) 1145 1146 for i := 0; i < N; i++ { 1147 conn, err := Dial("tcp", ln.Addr().String(), clientConfig) 1148 if err != nil { 1149 b.Fatal(err) 1150 } 1151 // make sure we're connected and previous connection has stopped 1152 if _, err := conn.Write(buf[:1]); err != nil { 1153 b.Fatal(err) 1154 } 1155 if _, err := io.ReadFull(conn, peek); err != nil { 1156 b.Fatal(err) 1157 } 1158 if _, err := conn.Write(buf); err != nil { 1159 b.Fatal(err) 1160 } 1161 if _, err = io.ReadFull(conn, peek); err != nil { 1162 b.Fatal(err) 1163 } 1164 _ = conn.Close() 1165 } 1166 } 1167 1168 func BenchmarkLatency(b *testing.B) { 1169 for _, mode := range []string{"Max", "Dynamic"} { 1170 for _, kbps := range []int{200, 500, 1000, 2000, 5000} { 1171 name := fmt.Sprintf("%sPacket/%dkbps", mode, kbps) 1172 b.Run(name, func(b *testing.B) { 1173 b.Run("TLSv12", func(b *testing.B) { 1174 latency(b, VersionTLS12, kbps*1000, mode == "Max") 1175 }) 1176 b.Run("TLSv13", func(b *testing.B) { 1177 latency(b, VersionTLS13, kbps*1000, mode == "Max") 1178 }) 1179 }) 1180 } 1181 } 1182 } 1183 1184 func TestConnectionStateMarshal(t *testing.T) { 1185 cs := &ConnectionState{} 1186 _, err := json.Marshal(cs) 1187 if err != nil { 1188 t.Errorf("json.Marshal failed on ConnectionState: %v", err) 1189 } 1190 } 1191 1192 func TestConnectionState(t *testing.T) { 1193 issuer, err := x509.ParseCertificate(testRSACertificateIssuer) 1194 if err != nil { 1195 panic(err) 1196 } 1197 rootCAs := x509.NewCertPool() 1198 rootCAs.AddCert(issuer) 1199 1200 now := func() time.Time { return time.Unix(1476984729, 0) } 1201 1202 const alpnProtocol = "golang" 1203 const serverName = "example.golang" 1204 var scts = [][]byte{[]byte("dummy sct 1"), []byte("dummy sct 2")} 1205 var ocsp = []byte("dummy ocsp") 1206 1207 for _, v := range []uint16{VersionTLS12, VersionTLS13} { 1208 var name string 1209 switch v { 1210 case VersionTLS12: 1211 name = "TLSv12" 1212 case VersionTLS13: 1213 name = "TLSv13" 1214 } 1215 t.Run(name, func(t *testing.T) { 1216 config := &Config{ 1217 Time: now, 1218 Rand: zeroSource{}, 1219 Certificates: make([]Certificate, 1), 1220 MaxVersion: v, 1221 RootCAs: rootCAs, 1222 ClientCAs: rootCAs, 1223 ClientAuth: RequireAndVerifyClientCert, 1224 NextProtos: []string{alpnProtocol}, 1225 ServerName: serverName, 1226 } 1227 config.Certificates[0].Certificate = [][]byte{testRSACertificate} 1228 config.Certificates[0].PrivateKey = testRSAPrivateKey 1229 config.Certificates[0].SignedCertificateTimestamps = scts 1230 config.Certificates[0].OCSPStaple = ocsp 1231 1232 ss, cs, err := testHandshake(t, config, config) 1233 if err != nil { 1234 t.Fatalf("Handshake failed: %v", err) 1235 } 1236 1237 if ss.Version != v || cs.Version != v { 1238 t.Errorf("Got versions %x (server) and %x (client), expected %x", ss.Version, cs.Version, v) 1239 } 1240 1241 if !ss.HandshakeComplete || !cs.HandshakeComplete { 1242 t.Errorf("Got HandshakeComplete %v (server) and %v (client), expected true", ss.HandshakeComplete, cs.HandshakeComplete) 1243 } 1244 1245 if ss.DidResume || cs.DidResume { 1246 t.Errorf("Got DidResume %v (server) and %v (client), expected false", ss.DidResume, cs.DidResume) 1247 } 1248 1249 if ss.CipherSuite == 0 || cs.CipherSuite == 0 { 1250 t.Errorf("Got invalid cipher suite: %v (server) and %v (client)", ss.CipherSuite, cs.CipherSuite) 1251 } 1252 1253 if ss.NegotiatedProtocol != alpnProtocol || cs.NegotiatedProtocol != alpnProtocol { 1254 t.Errorf("Got negotiated protocol %q (server) and %q (client), expected %q", ss.NegotiatedProtocol, cs.NegotiatedProtocol, alpnProtocol) 1255 } 1256 1257 if !cs.NegotiatedProtocolIsMutual { 1258 t.Errorf("Got false NegotiatedProtocolIsMutual on the client side") 1259 } 1260 // NegotiatedProtocolIsMutual on the server side is unspecified. 1261 1262 if ss.ServerName != serverName { 1263 t.Errorf("Got server name %q, expected %q", ss.ServerName, serverName) 1264 } 1265 if cs.ServerName != serverName { 1266 t.Errorf("Got server name on client connection %q, expected %q", cs.ServerName, serverName) 1267 } 1268 1269 if len(ss.PeerCertificates) != 1 || len(cs.PeerCertificates) != 1 { 1270 t.Errorf("Got %d (server) and %d (client) peer certificates, expected %d", len(ss.PeerCertificates), len(cs.PeerCertificates), 1) 1271 } 1272 1273 if len(ss.VerifiedChains) != 1 || len(cs.VerifiedChains) != 1 { 1274 t.Errorf("Got %d (server) and %d (client) verified chains, expected %d", len(ss.VerifiedChains), len(cs.VerifiedChains), 1) 1275 } else if len(ss.VerifiedChains[0]) != 2 || len(cs.VerifiedChains[0]) != 2 { 1276 t.Errorf("Got %d (server) and %d (client) long verified chain, expected %d", len(ss.VerifiedChains[0]), len(cs.VerifiedChains[0]), 2) 1277 } 1278 1279 if len(cs.SignedCertificateTimestamps) != 2 { 1280 t.Errorf("Got %d SCTs, expected %d", len(cs.SignedCertificateTimestamps), 2) 1281 } 1282 if !bytes.Equal(cs.OCSPResponse, ocsp) { 1283 t.Errorf("Got OCSPs %x, expected %x", cs.OCSPResponse, ocsp) 1284 } 1285 // Only TLS 1.3 supports OCSP and SCTs on client certs. 1286 if v == VersionTLS13 { 1287 if len(ss.SignedCertificateTimestamps) != 2 { 1288 t.Errorf("Got %d client SCTs, expected %d", len(ss.SignedCertificateTimestamps), 2) 1289 } 1290 if !bytes.Equal(ss.OCSPResponse, ocsp) { 1291 t.Errorf("Got client OCSPs %x, expected %x", ss.OCSPResponse, ocsp) 1292 } 1293 } 1294 1295 if v == VersionTLS13 { 1296 if ss.TLSUnique != nil || cs.TLSUnique != nil { 1297 t.Errorf("Got TLSUnique %x (server) and %x (client), expected nil in TLS 1.3", ss.TLSUnique, cs.TLSUnique) 1298 } 1299 } else { 1300 if ss.TLSUnique == nil || cs.TLSUnique == nil { 1301 t.Errorf("Got TLSUnique %x (server) and %x (client), expected non-nil", ss.TLSUnique, cs.TLSUnique) 1302 } 1303 } 1304 }) 1305 } 1306 } 1307 1308 // Issue 28744: Ensure that we don't modify memory 1309 // that Config doesn't own such as Certificates. 1310 func TestBuildNameToCertificate_doesntModifyCertificates(t *testing.T) { 1311 c0 := Certificate{ 1312 Certificate: [][]byte{testRSACertificate}, 1313 PrivateKey: testRSAPrivateKey, 1314 } 1315 c1 := Certificate{ 1316 Certificate: [][]byte{testSNICertificate}, 1317 PrivateKey: testRSAPrivateKey, 1318 } 1319 config := testConfig.Clone() 1320 config.Certificates = []Certificate{c0, c1} 1321 1322 config.BuildNameToCertificate() 1323 got := config.Certificates 1324 want := []Certificate{c0, c1} 1325 if !reflect.DeepEqual(got, want) { 1326 t.Fatalf("Certificates were mutated by BuildNameToCertificate\nGot: %#v\nWant: %#v\n", got, want) 1327 } 1328 } 1329 1330 func testingKey(s string) string { return strings.ReplaceAll(s, "TESTING KEY", "PRIVATE KEY") } 1331 1332 func TestClientHelloInfo_SupportsCertificate(t *testing.T) { 1333 rsaCert := &Certificate{ 1334 Certificate: [][]byte{testRSACertificate}, 1335 PrivateKey: testRSAPrivateKey, 1336 } 1337 pkcs1Cert := &Certificate{ 1338 Certificate: [][]byte{testRSACertificate}, 1339 PrivateKey: testRSAPrivateKey, 1340 SupportedSignatureAlgorithms: []SignatureScheme{PKCS1WithSHA1, PKCS1WithSHA256}, 1341 } 1342 ecdsaCert := &Certificate{ 1343 // ECDSA P-256 certificate 1344 Certificate: [][]byte{testP256Certificate}, 1345 PrivateKey: testP256PrivateKey, 1346 } 1347 ed25519Cert := &Certificate{ 1348 Certificate: [][]byte{testEd25519Certificate}, 1349 PrivateKey: testEd25519PrivateKey, 1350 } 1351 1352 tests := []struct { 1353 c *Certificate 1354 chi *ClientHelloInfo 1355 wantErr string 1356 }{ 1357 {rsaCert, &ClientHelloInfo{ 1358 ServerName: "example.golang", 1359 SignatureSchemes: []SignatureScheme{PSSWithSHA256}, 1360 SupportedVersions: []uint16{VersionTLS13}, 1361 }, ""}, 1362 {ecdsaCert, &ClientHelloInfo{ 1363 SignatureSchemes: []SignatureScheme{PSSWithSHA256, ECDSAWithP256AndSHA256}, 1364 SupportedVersions: []uint16{VersionTLS13, VersionTLS12}, 1365 }, ""}, 1366 {rsaCert, &ClientHelloInfo{ 1367 ServerName: "example.com", 1368 SignatureSchemes: []SignatureScheme{PSSWithSHA256}, 1369 SupportedVersions: []uint16{VersionTLS13}, 1370 }, "not valid for requested server name"}, 1371 {ecdsaCert, &ClientHelloInfo{ 1372 SignatureSchemes: []SignatureScheme{ECDSAWithP384AndSHA384}, 1373 SupportedVersions: []uint16{VersionTLS13}, 1374 }, "signature algorithms"}, 1375 {pkcs1Cert, &ClientHelloInfo{ 1376 SignatureSchemes: []SignatureScheme{PSSWithSHA256, ECDSAWithP256AndSHA256}, 1377 SupportedVersions: []uint16{VersionTLS13}, 1378 }, "signature algorithms"}, 1379 1380 {rsaCert, &ClientHelloInfo{ 1381 CipherSuites: []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256}, 1382 SignatureSchemes: []SignatureScheme{PKCS1WithSHA1}, 1383 SupportedVersions: []uint16{VersionTLS13, VersionTLS12}, 1384 }, "signature algorithms"}, 1385 {rsaCert, &ClientHelloInfo{ 1386 CipherSuites: []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256}, 1387 SignatureSchemes: []SignatureScheme{PKCS1WithSHA1}, 1388 SupportedVersions: []uint16{VersionTLS13, VersionTLS12}, 1389 config: &Config{ 1390 MaxVersion: VersionTLS12, 1391 }, 1392 }, ""}, // Check that mutual version selection works. 1393 1394 {ecdsaCert, &ClientHelloInfo{ 1395 CipherSuites: []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, 1396 SupportedCurves: []CurveID{CurveP256}, 1397 SupportedPoints: []uint8{pointFormatUncompressed}, 1398 SignatureSchemes: []SignatureScheme{ECDSAWithP256AndSHA256}, 1399 SupportedVersions: []uint16{VersionTLS12}, 1400 }, ""}, 1401 {ecdsaCert, &ClientHelloInfo{ 1402 CipherSuites: []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, 1403 SupportedCurves: []CurveID{CurveP256}, 1404 SupportedPoints: []uint8{pointFormatUncompressed}, 1405 SignatureSchemes: []SignatureScheme{ECDSAWithP384AndSHA384}, 1406 SupportedVersions: []uint16{VersionTLS12}, 1407 }, ""}, // TLS 1.2 does not restrict curves based on the SignatureScheme. 1408 {ecdsaCert, &ClientHelloInfo{ 1409 CipherSuites: []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, 1410 SupportedCurves: []CurveID{CurveP256}, 1411 SupportedPoints: []uint8{pointFormatUncompressed}, 1412 SignatureSchemes: nil, 1413 SupportedVersions: []uint16{VersionTLS12}, 1414 }, ""}, // TLS 1.2 comes with default signature schemes. 1415 {ecdsaCert, &ClientHelloInfo{ 1416 CipherSuites: []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256}, 1417 SupportedCurves: []CurveID{CurveP256}, 1418 SupportedPoints: []uint8{pointFormatUncompressed}, 1419 SignatureSchemes: []SignatureScheme{ECDSAWithP256AndSHA256}, 1420 SupportedVersions: []uint16{VersionTLS12}, 1421 }, "cipher suite"}, 1422 {ecdsaCert, &ClientHelloInfo{ 1423 CipherSuites: []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, 1424 SupportedCurves: []CurveID{CurveP256}, 1425 SupportedPoints: []uint8{pointFormatUncompressed}, 1426 SignatureSchemes: []SignatureScheme{ECDSAWithP256AndSHA256}, 1427 SupportedVersions: []uint16{VersionTLS12}, 1428 config: &Config{ 1429 CipherSuites: []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256}, 1430 }, 1431 }, "cipher suite"}, 1432 {ecdsaCert, &ClientHelloInfo{ 1433 CipherSuites: []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, 1434 SupportedCurves: []CurveID{CurveP384}, 1435 SupportedPoints: []uint8{pointFormatUncompressed}, 1436 SignatureSchemes: []SignatureScheme{ECDSAWithP256AndSHA256}, 1437 SupportedVersions: []uint16{VersionTLS12}, 1438 }, "certificate curve"}, 1439 {ecdsaCert, &ClientHelloInfo{ 1440 CipherSuites: []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, 1441 SupportedCurves: []CurveID{CurveP256}, 1442 SupportedPoints: []uint8{1}, 1443 SignatureSchemes: []SignatureScheme{ECDSAWithP256AndSHA256}, 1444 SupportedVersions: []uint16{VersionTLS12}, 1445 }, "doesn't support ECDHE"}, 1446 {ecdsaCert, &ClientHelloInfo{ 1447 CipherSuites: []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, 1448 SupportedCurves: []CurveID{CurveP256}, 1449 SupportedPoints: []uint8{pointFormatUncompressed}, 1450 SignatureSchemes: []SignatureScheme{PSSWithSHA256}, 1451 SupportedVersions: []uint16{VersionTLS12}, 1452 }, "signature algorithms"}, 1453 1454 {ed25519Cert, &ClientHelloInfo{ 1455 CipherSuites: []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, 1456 SupportedCurves: []CurveID{CurveP256}, // only relevant for ECDHE support 1457 SupportedPoints: []uint8{pointFormatUncompressed}, 1458 SignatureSchemes: []SignatureScheme{Ed25519}, 1459 SupportedVersions: []uint16{VersionTLS12}, 1460 }, ""}, 1461 {ed25519Cert, &ClientHelloInfo{ 1462 CipherSuites: []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, 1463 SupportedCurves: []CurveID{CurveP256}, // only relevant for ECDHE support 1464 SupportedPoints: []uint8{pointFormatUncompressed}, 1465 SignatureSchemes: []SignatureScheme{Ed25519}, 1466 SupportedVersions: []uint16{VersionTLS10}, 1467 }, "doesn't support Ed25519"}, 1468 {ed25519Cert, &ClientHelloInfo{ 1469 CipherSuites: []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, 1470 SupportedCurves: []CurveID{}, 1471 SupportedPoints: []uint8{pointFormatUncompressed}, 1472 SignatureSchemes: []SignatureScheme{Ed25519}, 1473 SupportedVersions: []uint16{VersionTLS12}, 1474 }, "doesn't support ECDHE"}, 1475 1476 {rsaCert, &ClientHelloInfo{ 1477 CipherSuites: []uint16{TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA}, 1478 SupportedCurves: []CurveID{CurveP256}, // only relevant for ECDHE support 1479 SupportedPoints: []uint8{pointFormatUncompressed}, 1480 SupportedVersions: []uint16{VersionTLS10}, 1481 }, ""}, 1482 {rsaCert, &ClientHelloInfo{ 1483 CipherSuites: []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256}, 1484 SupportedVersions: []uint16{VersionTLS12}, 1485 }, ""}, // static RSA fallback 1486 } 1487 for i, tt := range tests { 1488 err := tt.chi.SupportsCertificate(tt.c) 1489 switch { 1490 case tt.wantErr == "" && err != nil: 1491 t.Errorf("%d: unexpected error: %v", i, err) 1492 case tt.wantErr != "" && err == nil: 1493 t.Errorf("%d: unexpected success", i) 1494 case tt.wantErr != "" && !strings.Contains(err.Error(), tt.wantErr): 1495 t.Errorf("%d: got error %q, expected %q", i, err, tt.wantErr) 1496 } 1497 } 1498 } 1499 1500 func TestCipherSuites(t *testing.T) { 1501 var lastID uint16 1502 for _, c := range CipherSuites() { 1503 if lastID > c.ID { 1504 t.Errorf("CipherSuites are not ordered by ID: got %#04x after %#04x", c.ID, lastID) 1505 } else { 1506 lastID = c.ID 1507 } 1508 1509 if c.Insecure { 1510 t.Errorf("%#04x: Insecure CipherSuite returned by CipherSuites()", c.ID) 1511 } 1512 } 1513 lastID = 0 1514 for _, c := range InsecureCipherSuites() { 1515 if lastID > c.ID { 1516 t.Errorf("InsecureCipherSuites are not ordered by ID: got %#04x after %#04x", c.ID, lastID) 1517 } else { 1518 lastID = c.ID 1519 } 1520 1521 if !c.Insecure { 1522 t.Errorf("%#04x: not Insecure CipherSuite returned by InsecureCipherSuites()", c.ID) 1523 } 1524 } 1525 1526 CipherSuiteByID := func(id uint16) *CipherSuite { 1527 for _, c := range CipherSuites() { 1528 if c.ID == id { 1529 return c 1530 } 1531 } 1532 for _, c := range InsecureCipherSuites() { 1533 if c.ID == id { 1534 return c 1535 } 1536 } 1537 return nil 1538 } 1539 1540 for _, c := range cipherSuites { 1541 cc := CipherSuiteByID(c.id) 1542 if cc == nil { 1543 t.Errorf("%#04x: no CipherSuite entry", c.id) 1544 continue 1545 } 1546 1547 if tls12Only := c.flags&suiteTLS12 != 0; tls12Only && len(cc.SupportedVersions) != 1 { 1548 t.Errorf("%#04x: suite is TLS 1.2 only, but SupportedVersions is %v", c.id, cc.SupportedVersions) 1549 } else if !tls12Only && len(cc.SupportedVersions) != 3 { 1550 t.Errorf("%#04x: suite TLS 1.0-1.2, but SupportedVersions is %v", c.id, cc.SupportedVersions) 1551 } 1552 1553 if got := CipherSuiteName(c.id); got != cc.Name { 1554 t.Errorf("%#04x: unexpected CipherSuiteName: got %q, expected %q", c.id, got, cc.Name) 1555 } 1556 } 1557 for _, c := range cipherSuitesTLS13 { 1558 cc := CipherSuiteByID(c.id) 1559 if cc == nil { 1560 t.Errorf("%#04x: no CipherSuite entry", c.id) 1561 continue 1562 } 1563 1564 if cc.Insecure { 1565 t.Errorf("%#04x: Insecure %v, expected false", c.id, cc.Insecure) 1566 } 1567 if len(cc.SupportedVersions) <= 0 { 1568 t.Errorf("%#04x: suite is TLS 1.3 only, but SupportedVersions is %v", c.id, cc.SupportedVersions) 1569 } 1570 for _, version := range cc.SupportedVersions { 1571 if version != VersionTLS13 && version != VersionGMSSL { 1572 t.Errorf("%#04x: suite is TLS 1.3 only, but SupportedVersions is %v", c.id, cc.SupportedVersions) 1573 } 1574 } 1575 // if len(cc.SupportedVersions) != 1 || cc.SupportedVersions[0] != VersionTLS13 { 1576 // t.Errorf("%#04x: suite is TLS 1.3 only, but SupportedVersions is %v", c.id, cc.SupportedVersions) 1577 // } 1578 1579 if got := CipherSuiteName(c.id); got != cc.Name { 1580 t.Errorf("%#04x: unexpected CipherSuiteName: got %q, expected %q", c.id, got, cc.Name) 1581 } 1582 } 1583 1584 if got := CipherSuiteName(0xabc); got != "0x0ABC" { 1585 t.Errorf("unexpected fallback CipherSuiteName: got %q, expected 0x0ABC", got) 1586 } 1587 1588 if len(cipherSuitesPreferenceOrder) != len(cipherSuites) { 1589 t.Errorf("cipherSuitesPreferenceOrder is not the same size as cipherSuites") 1590 } 1591 if len(cipherSuitesPreferenceOrderNoAES) != len(cipherSuitesPreferenceOrder) { 1592 t.Errorf("cipherSuitesPreferenceOrderNoAES is not the same size as cipherSuitesPreferenceOrder") 1593 } 1594 1595 // Check that disabled suites are at the end of the preference lists, and 1596 // that they are marked insecure. 1597 for i, id := range disabledCipherSuites { 1598 offset := len(cipherSuitesPreferenceOrder) - len(disabledCipherSuites) 1599 if cipherSuitesPreferenceOrder[offset+i] != id { 1600 t.Errorf("disabledCipherSuites[%d]: not at the end of cipherSuitesPreferenceOrder", i) 1601 } 1602 if cipherSuitesPreferenceOrderNoAES[offset+i] != id { 1603 t.Errorf("disabledCipherSuites[%d]: not at the end of cipherSuitesPreferenceOrderNoAES", i) 1604 } 1605 c := CipherSuiteByID(id) 1606 if c == nil { 1607 t.Errorf("%#04x: no CipherSuite entry", id) 1608 continue 1609 } 1610 if !c.Insecure { 1611 t.Errorf("%#04x: disabled by default but not marked insecure", id) 1612 } 1613 } 1614 1615 for i, prefOrder := range [][]uint16{cipherSuitesPreferenceOrder, cipherSuitesPreferenceOrderNoAES} { 1616 // Check that insecure and HTTP/2 bad cipher suites are at the end of 1617 // the preference lists. 1618 var sawInsecure, sawBad bool 1619 for _, id := range prefOrder { 1620 c := CipherSuiteByID(id) 1621 if c == nil { 1622 t.Errorf("%#04x: no CipherSuite entry", id) 1623 continue 1624 } 1625 1626 if c.Insecure { 1627 sawInsecure = true 1628 } else if sawInsecure { 1629 t.Errorf("%#04x: secure suite after insecure one(s)", id) 1630 } 1631 1632 if http2isBadCipher(id) { 1633 sawBad = true 1634 } else if sawBad { 1635 t.Errorf("%#04x: non-bad suite after bad HTTP/2 one(s)", id) 1636 } 1637 } 1638 1639 // Check that the list is sorted according to the documented criteria. 1640 isBetter := func(a, b int) bool { 1641 aSuite, bSuite := cipherSuiteByID(prefOrder[a]), cipherSuiteByID(prefOrder[b]) 1642 aName, bName := CipherSuiteName(prefOrder[a]), CipherSuiteName(prefOrder[b]) 1643 // * < RC4 1644 if !strings.Contains(aName, "RC4") && strings.Contains(bName, "RC4") { 1645 return true 1646 } else if strings.Contains(aName, "RC4") && !strings.Contains(bName, "RC4") { 1647 return false 1648 } 1649 // * < CBC_SHA256 1650 if !strings.Contains(aName, "CBC_SHA256") && strings.Contains(bName, "CBC_SHA256") { 1651 return true 1652 } else if strings.Contains(aName, "CBC_SHA256") && !strings.Contains(bName, "CBC_SHA256") { 1653 return false 1654 } 1655 // * < 3DES 1656 if !strings.Contains(aName, "3DES") && strings.Contains(bName, "3DES") { 1657 return true 1658 } else if strings.Contains(aName, "3DES") && !strings.Contains(bName, "3DES") { 1659 return false 1660 } 1661 // ECDHE < * 1662 if aSuite.flags&suiteECDHE != 0 && bSuite.flags&suiteECDHE == 0 { 1663 return true 1664 } else if aSuite.flags&suiteECDHE == 0 && bSuite.flags&suiteECDHE != 0 { 1665 return false 1666 } 1667 // AEAD < CBC 1668 if aSuite.aead != nil && bSuite.aead == nil { 1669 return true 1670 } else if aSuite.aead == nil && bSuite.aead != nil { 1671 return false 1672 } 1673 // AES < ChaCha20 1674 if strings.Contains(aName, "AES") && strings.Contains(bName, "CHACHA20") { 1675 return i == 0 // true for cipherSuitesPreferenceOrder 1676 } else if strings.Contains(aName, "CHACHA20") && strings.Contains(bName, "AES") { 1677 return i != 0 // true for cipherSuitesPreferenceOrderNoAES 1678 } 1679 // AES-128 < AES-256 1680 if strings.Contains(aName, "AES_128") && strings.Contains(bName, "AES_256") { 1681 return true 1682 } else if strings.Contains(aName, "AES_256") && strings.Contains(bName, "AES_128") { 1683 return false 1684 } 1685 // ECDSA < RSA 1686 if aSuite.flags&suiteECSign != 0 && bSuite.flags&suiteECSign == 0 { 1687 return true 1688 } else if aSuite.flags&suiteECSign == 0 && bSuite.flags&suiteECSign != 0 { 1689 return false 1690 } 1691 t.Fatalf("two ciphersuites are equal by all criteria: %v and %v", aName, bName) 1692 //goland:noinspection GoUnreachableCode 1693 panic("unreachable") 1694 } 1695 if !sort.SliceIsSorted(prefOrder, isBetter) { 1696 t.Error("preference order is not sorted according to the rules") 1697 } 1698 } 1699 } 1700 1701 // http2isBadCipher is copied from net/http. 1702 // TODO: if it ends up exposed somewhere, use that instead. 1703 func http2isBadCipher(cipher uint16) bool { 1704 switch cipher { 1705 case TLS_RSA_WITH_RC4_128_SHA, 1706 TLS_RSA_WITH_3DES_EDE_CBC_SHA, 1707 TLS_RSA_WITH_AES_128_CBC_SHA, 1708 TLS_RSA_WITH_AES_256_CBC_SHA, 1709 TLS_RSA_WITH_AES_128_CBC_SHA256, 1710 TLS_RSA_WITH_AES_128_GCM_SHA256, 1711 TLS_RSA_WITH_AES_256_GCM_SHA384, 1712 TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, 1713 TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, 1714 TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, 1715 TLS_ECDHE_RSA_WITH_RC4_128_SHA, 1716 TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, 1717 TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, 1718 TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, 1719 TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, 1720 TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256: 1721 return true 1722 default: 1723 return false 1724 } 1725 } 1726 1727 type brokenSigner struct{ crypto.Signer } 1728 1729 func (s brokenSigner) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) (signature []byte, err error) { 1730 // Replace opts with opts.HashFunc(), so rsa.PSSOptions are discarded. 1731 return s.Signer.Sign(rand, digest, opts.HashFunc()) 1732 } 1733 1734 // TestPKCS1OnlyCert uses a client certificate with a broken crypto.Signer that 1735 // always makes PKCS #1 v1.5 signatures, so can't be used with RSA-PSS. 1736 func TestPKCS1OnlyCert(t *testing.T) { 1737 clientConfig := testConfig.Clone() 1738 clientConfig.Certificates = []Certificate{{ 1739 Certificate: [][]byte{testRSACertificate}, 1740 PrivateKey: brokenSigner{testRSAPrivateKey}, 1741 }} 1742 serverConfig := testConfig.Clone() 1743 serverConfig.MaxVersion = VersionTLS12 // TLS 1.3 doesn't support PKCS #1 v1.5 1744 serverConfig.ClientAuth = RequireAnyClientCert 1745 1746 // If RSA-PSS is selected, the handshake should fail. 1747 if _, _, err := testHandshake(t, clientConfig, serverConfig); err == nil { 1748 t.Fatal("expected broken certificate to cause connection to fail") 1749 } 1750 1751 clientConfig.Certificates[0].SupportedSignatureAlgorithms = 1752 []SignatureScheme{PKCS1WithSHA1, PKCS1WithSHA256} 1753 1754 // But if the certificate restricts supported algorithms, RSA-PSS should not 1755 // be selected, and the handshake should succeed. 1756 if _, _, err := testHandshake(t, clientConfig, serverConfig); err != nil { 1757 t.Error(err) 1758 } 1759 }