gitee.com/ks-custle/core-gm@v0.0.0-20230922171213-b83bdd97b62c/gmtls/handshake_client_test.go (about) 1 // Copyright (c) 2022 zhaochun 2 // core-gm is licensed under Mulan PSL v2. 3 // You can use this software according to the terms and conditions of the Mulan PSL v2. 4 // You may obtain a copy of Mulan PSL v2 at: 5 // http://license.coscl.org.cn/MulanPSL2 6 // THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. 7 // See the Mulan PSL v2 for more details. 8 9 /* 10 gmtls是基于`golang/go`的`tls`包实现的国密改造版本。 11 对应版权声明: thrid_licenses/github.com/golang/go/LICENSE 12 */ 13 14 package gmtls 15 16 import ( 17 "bytes" 18 "context" 19 "crypto/rsa" 20 "encoding/base64" 21 "encoding/binary" 22 "encoding/pem" 23 "errors" 24 "fmt" 25 "io" 26 "math/big" 27 "net" 28 "os" 29 "os/exec" 30 "path/filepath" 31 "reflect" 32 "runtime" 33 "strconv" 34 "strings" 35 "testing" 36 "time" 37 38 "gitee.com/ks-custle/core-gm/x509" 39 ) 40 41 // Note: see comment in handshake_test.go for details of how the reference 42 // tests work. 43 44 // opensslInputEvent enumerates possible inputs that can be sent to an `openssl 45 // s_client` process. 46 type opensslInputEvent int 47 48 const ( 49 // opensslRenegotiate causes OpenSSL to request a renegotiation of the 50 // connection. 51 opensslRenegotiate opensslInputEvent = iota 52 53 // opensslSendBanner causes OpenSSL to send the contents of 54 // opensslSentinel on the connection. 55 opensslSendSentinel 56 57 // opensslKeyUpdate causes OpenSSL to send a key update message to the 58 // client and request one back. 59 opensslKeyUpdate 60 ) 61 62 const opensslSentinel = "SENTINEL\n" 63 64 type opensslInput chan opensslInputEvent 65 66 func (i opensslInput) Read(buf []byte) (n int, err error) { 67 for event := range i { 68 switch event { 69 case opensslRenegotiate: 70 return copy(buf, "R\n"), nil 71 case opensslKeyUpdate: 72 return copy(buf, "K\n"), nil 73 case opensslSendSentinel: 74 return copy(buf, opensslSentinel), nil 75 default: 76 panic("unknown event") 77 } 78 } 79 80 return 0, io.EOF 81 } 82 83 // opensslOutputSink is an io.Writer that receives the stdout and stderr from an 84 // `openssl` process and sends a value to handshakeComplete or readKeyUpdate 85 // when certain messages are seen. 86 type opensslOutputSink struct { 87 handshakeComplete chan struct{} 88 readKeyUpdate chan struct{} 89 all []byte 90 line []byte 91 } 92 93 func newOpensslOutputSink() *opensslOutputSink { 94 return &opensslOutputSink{make(chan struct{}), make(chan struct{}), nil, nil} 95 } 96 97 // opensslEndOfHandshake is a message that the “openssl s_server” tool will 98 // print when a handshake completes if run with “-state”. 99 const opensslEndOfHandshake = "SSL_accept:SSLv3/TLS write finished" 100 101 // opensslReadKeyUpdate is a message that the “openssl s_server” tool will 102 // print when a KeyUpdate message is received if run with “-state”. 103 const opensslReadKeyUpdate = "SSL_accept:TLSv1.3 read client key update" 104 105 func (o *opensslOutputSink) Write(data []byte) (n int, err error) { 106 o.line = append(o.line, data...) 107 o.all = append(o.all, data...) 108 109 for { 110 i := bytes.IndexByte(o.line, '\n') 111 if i < 0 { 112 break 113 } 114 115 if bytes.Equal([]byte(opensslEndOfHandshake), o.line[:i]) { 116 o.handshakeComplete <- struct{}{} 117 } 118 if bytes.Equal([]byte(opensslReadKeyUpdate), o.line[:i]) { 119 o.readKeyUpdate <- struct{}{} 120 } 121 o.line = o.line[i+1:] 122 } 123 124 return len(data), nil 125 } 126 127 func (o *opensslOutputSink) String() string { 128 return string(o.all) 129 } 130 131 // clientTest represents a test of the TLS client handshake against a reference 132 // implementation. 133 type clientTest struct { 134 // name is a freeform string identifying the test and the file in which 135 // the expected results will be stored. 136 name string 137 // args, if not empty, contains a series of arguments for the 138 // command to run for the reference server. 139 args []string 140 // config, if not nil, contains a custom Config to use for this test. 141 config *Config 142 // cert, if not empty, contains a DER-encoded certificate for the 143 // reference server. 144 cert []byte 145 // key, if not nil, contains either a *rsa.PrivateKey, ed25519.PrivateKey or 146 // *ecdsa.PrivateKey which is the private key for the reference server. 147 key interface{} 148 // extensions, if not nil, contains a list of extension data to be returned 149 // from the ServerHello. The data should be in standard TLS format with 150 // a 2-byte uint16 type, 2-byte data length, followed by the extension data. 151 extensions [][]byte 152 // validate, if not nil, is a function that will be called with the 153 // ConnectionState of the resulting connection. It returns a non-nil 154 // error if the ConnectionState is unacceptable. 155 validate func(ConnectionState) error 156 // numRenegotiations is the number of times that the connection will be 157 // renegotiated. 158 numRenegotiations int 159 // renegotiationExpectedToFail, if not zero, is the number of the 160 // renegotiation attempt that is expected to fail. 161 renegotiationExpectedToFail int 162 // checkRenegotiationError, if not nil, is called with any error 163 // arising from renegotiation. It can map expected errors to nil to 164 // ignore them. 165 checkRenegotiationError func(renegotiationNum int, err error) error 166 // sendKeyUpdate will cause the server to send a KeyUpdate message. 167 sendKeyUpdate bool 168 } 169 170 var serverCommand = []string{"openssl", "s_server", "-no_ticket", "-num_tickets", "0"} 171 172 // connFromCommand starts the reference server process, connects to it and 173 // returns a recordingConn for the connection. The stdin return value is an 174 // opensslInput for the stdin of the child process. It must be closed before 175 // Waiting for child. 176 func (test *clientTest) connFromCommand() (conn *recordingConn, child *exec.Cmd, stdin opensslInput, stdout *opensslOutputSink, err error) { 177 cert := testRSACertificate 178 if len(test.cert) > 0 { 179 cert = test.cert 180 } 181 certPath := tempFile(string(cert)) 182 defer func(name string) { 183 err := os.Remove(name) 184 if err != nil { 185 panic(err) 186 } 187 }(certPath) 188 189 var key interface{} = testRSAPrivateKey 190 if test.key != nil { 191 key = test.key 192 } 193 derBytes, err := x509.MarshalPKCS8PrivateKey(key) 194 if err != nil { 195 panic(err) 196 } 197 198 var pemOut bytes.Buffer 199 err = pem.Encode(&pemOut, &pem.Block{Type: "PRIVATE KEY", Bytes: derBytes}) 200 if err != nil { 201 panic(err) 202 } 203 204 keyPath := tempFile(pemOut.String()) 205 defer func(name string) { 206 err := os.Remove(name) 207 if err != nil { 208 panic(err) 209 } 210 }(keyPath) 211 212 var command []string 213 command = append(command, serverCommand...) 214 command = append(command, test.args...) 215 command = append(command, "-cert", certPath, "-certform", "DER", "-key", keyPath) 216 // serverPort contains the port that OpenSSL will listen on. OpenSSL 217 // can't take "0" as an argument here so we have to pick a number and 218 // hope that it's not in use on the machine. Since this only occurs 219 // when -update is given and thus when there's a human watching the 220 // test, this isn't too bad. 221 const serverPort = 24323 222 command = append(command, "-accept", strconv.Itoa(serverPort)) 223 224 if len(test.extensions) > 0 { 225 var serverInfo bytes.Buffer 226 for _, ext := range test.extensions { 227 err := pem.Encode(&serverInfo, &pem.Block{ 228 Type: fmt.Sprintf("SERVERINFO FOR EXTENSION %d", binary.BigEndian.Uint16(ext)), 229 Bytes: ext, 230 }) 231 if err != nil { 232 panic(err) 233 } 234 } 235 serverInfoPath := tempFile(serverInfo.String()) 236 defer func(name string) { 237 err := os.Remove(name) 238 if err != nil { 239 panic(err) 240 } 241 }(serverInfoPath) 242 command = append(command, "-serverinfo", serverInfoPath) 243 } 244 245 if test.numRenegotiations > 0 || test.sendKeyUpdate { 246 found := false 247 for _, flag := range command[1:] { 248 if flag == "-state" { 249 found = true 250 break 251 } 252 } 253 254 if !found { 255 panic("-state flag missing to OpenSSL, you need this if testing renegotiation or KeyUpdate") 256 } 257 } 258 259 cmd := exec.Command(command[0], command[1:]...) 260 stdin = make(chan opensslInputEvent) 261 cmd.Stdin = stdin 262 out := newOpensslOutputSink() 263 cmd.Stdout = out 264 cmd.Stderr = out 265 if err := cmd.Start(); err != nil { 266 return nil, nil, nil, nil, err 267 } 268 269 // OpenSSL does print an "ACCEPT" banner, but it does so *before* 270 // opening the listening socket, so we can't use that to wait until it 271 // has started listening. Thus we are forced to poll until we get a 272 // connection. 273 var tcpConn net.Conn 274 for i := uint(0); i < 5; i++ { 275 tcpConn, err = net.DialTCP("tcp", nil, &net.TCPAddr{ 276 IP: net.IPv4(127, 0, 0, 1), 277 Port: serverPort, 278 }) 279 if err == nil { 280 break 281 } 282 time.Sleep((1 << i) * 5 * time.Millisecond) 283 } 284 if err != nil { 285 close(stdin) 286 err1 := cmd.Process.Kill() 287 if err1 != nil { 288 panic(err1) 289 } 290 err = fmt.Errorf("error connecting to the OpenSSL server: %v (%v)\n\n%s", err, cmd.Wait(), out) 291 return nil, nil, nil, nil, err 292 } 293 294 record := &recordingConn{ 295 Conn: tcpConn, 296 } 297 298 return record, cmd, stdin, out, nil 299 } 300 301 func (test *clientTest) dataPath() string { 302 return filepath.Join("testdata", "Client-"+test.name) 303 } 304 305 func (test *clientTest) loadData() (flows [][]byte, err error) { 306 in, err := os.Open(test.dataPath()) 307 if err != nil { 308 return nil, err 309 } 310 defer func(in *os.File) { 311 err := in.Close() 312 if err != nil { 313 panic(err) 314 } 315 }(in) 316 return parseTestData(in) 317 } 318 319 func (test *clientTest) run(t *testing.T, write bool) { 320 var clientConn, serverConn net.Conn 321 var recordingConn *recordingConn 322 var childProcess *exec.Cmd 323 var stdin opensslInput 324 var stdout *opensslOutputSink 325 326 if write { 327 var err error 328 recordingConn, childProcess, stdin, stdout, err = test.connFromCommand() 329 if err != nil { 330 t.Fatalf("Failed to start subcommand: %s", err) 331 } 332 clientConn = recordingConn 333 defer func() { 334 if t.Failed() { 335 t.Logf("OpenSSL output:\n\n%s", stdout.all) 336 } 337 }() 338 } else { 339 clientConn, serverConn = localPipe(t) 340 } 341 342 doneChan := make(chan bool) 343 defer func() { 344 err := clientConn.Close() 345 if err != nil { 346 panic(err) 347 } 348 <-doneChan 349 }() 350 go func() { 351 defer close(doneChan) 352 353 config := test.config 354 if config == nil { 355 config = testConfig 356 } 357 client := Client(clientConn, config) 358 defer func(client *Conn) { 359 err := client.Close() 360 if err != nil { 361 panic(err) 362 } 363 }(client) 364 365 if _, err := client.Write([]byte("hello\n")); err != nil { 366 t.Errorf("Client.Write failed: %s", err) 367 return 368 } 369 370 for i := 1; i <= test.numRenegotiations; i++ { 371 // The initial handshake will generate a 372 // handshakeComplete signal which needs to be quashed. 373 if i == 1 && write { 374 <-stdout.handshakeComplete 375 } 376 377 // OpenSSL will try to interleave application data and 378 // a renegotiation if we send both concurrently. 379 // Therefore: ask OpensSSL to start a renegotiation, run 380 // a goroutine to call client.Read and thus process the 381 // renegotiation request, watch for OpenSSL's stdout to 382 // indicate that the handshake is complete and, 383 // finally, have OpenSSL write something to cause 384 // client.Read to complete. 385 if write { 386 stdin <- opensslRenegotiate 387 } 388 389 signalChan := make(chan struct{}) 390 391 go func() { 392 defer close(signalChan) 393 394 buf := make([]byte, 256) 395 n, err := client.Read(buf) 396 397 if test.checkRenegotiationError != nil { 398 newErr := test.checkRenegotiationError(i, err) 399 if err != nil && newErr == nil { 400 return 401 } 402 err = newErr 403 } 404 405 if err != nil { 406 t.Errorf("Client.Read failed after renegotiation #%d: %s", i, err) 407 return 408 } 409 410 buf = buf[:n] 411 if !bytes.Equal([]byte(opensslSentinel), buf) { 412 t.Errorf("Client.Read returned %q, but wanted %q", string(buf), opensslSentinel) 413 } 414 415 if expected := i + 1; client.handshakes != expected { 416 t.Errorf("client should have recorded %d handshakes, but believes that %d have occurred", expected, client.handshakes) 417 } 418 }() 419 420 if write && test.renegotiationExpectedToFail != i { 421 <-stdout.handshakeComplete 422 stdin <- opensslSendSentinel 423 } 424 <-signalChan 425 } 426 427 if test.sendKeyUpdate { 428 if write { 429 <-stdout.handshakeComplete 430 stdin <- opensslKeyUpdate 431 } 432 433 doneRead := make(chan struct{}) 434 435 go func() { 436 defer close(doneRead) 437 438 buf := make([]byte, 256) 439 n, err := client.Read(buf) 440 441 if err != nil { 442 t.Errorf("Client.Read failed after KeyUpdate: %s", err) 443 return 444 } 445 446 buf = buf[:n] 447 if !bytes.Equal([]byte(opensslSentinel), buf) { 448 t.Errorf("Client.Read returned %q, but wanted %q", string(buf), opensslSentinel) 449 } 450 }() 451 452 if write { 453 // There's no real reason to wait for the client KeyUpdate to 454 // send data with the new server keys, except that s_server 455 // drops writes if they are sent at the wrong time. 456 <-stdout.readKeyUpdate 457 stdin <- opensslSendSentinel 458 } 459 <-doneRead 460 461 if _, err := client.Write([]byte("hello again\n")); err != nil { 462 t.Errorf("Client.Write failed: %s", err) 463 return 464 } 465 } 466 467 if test.validate != nil { 468 if err := test.validate(client.ConnectionState()); err != nil { 469 t.Errorf("validate callback returned error: %s", err) 470 } 471 } 472 473 // If the server sent us an alert after our last flight, give it a 474 // chance to arrive. 475 if write && test.renegotiationExpectedToFail == 0 { 476 if err := peekError(client); err != nil { 477 t.Errorf("final Read returned an error: %s", err) 478 } 479 } 480 }() 481 482 if !write { 483 flows, err := test.loadData() 484 if err != nil { 485 t.Fatalf("%s: failed to load data from %s: %v", test.name, test.dataPath(), err) 486 } 487 for i, b := range flows { 488 if i%2 == 1 { 489 if *fast { 490 err := serverConn.SetWriteDeadline(time.Now().Add(1 * time.Second)) 491 if err != nil { 492 panic(err) 493 } 494 } else { 495 err := serverConn.SetWriteDeadline(time.Now().Add(1 * time.Minute)) 496 if err != nil { 497 panic(err) 498 } 499 } 500 _, err := serverConn.Write(b) 501 if err != nil { 502 panic(err) 503 } 504 continue 505 } 506 bb := make([]byte, len(b)) 507 if *fast { 508 err := serverConn.SetReadDeadline(time.Now().Add(1 * time.Second)) 509 if err != nil { 510 panic(err) 511 } 512 } else { 513 err := serverConn.SetReadDeadline(time.Now().Add(1 * time.Minute)) 514 if err != nil { 515 panic(err) 516 } 517 } 518 _, err := io.ReadFull(serverConn, bb) 519 if err != nil { 520 t.Fatalf("%s, flow %d: %s", test.name, i+1, err) 521 } 522 if !bytes.Equal(b, bb) { 523 t.Fatalf("%s, flow %d: mismatch on read: got:%x want:%x", test.name, i+1, bb, b) 524 } 525 } 526 } 527 528 <-doneChan 529 if !write { 530 err := serverConn.Close() 531 if err != nil { 532 panic(err) 533 } 534 } 535 536 if write { 537 path := test.dataPath() 538 out, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644) 539 if err != nil { 540 t.Fatalf("Failed to create output file: %s", err) 541 } 542 defer func(out *os.File) { 543 err := out.Close() 544 if err != nil { 545 panic(err) 546 } 547 }(out) 548 err = recordingConn.Close() 549 if err != nil { 550 panic(err) 551 } 552 close(stdin) 553 err = childProcess.Process.Kill() 554 if err != nil { 555 panic(err) 556 } 557 err = childProcess.Wait() 558 if err != nil { 559 panic(err) 560 } 561 if len(recordingConn.flows) < 3 { 562 t.Fatalf("Client connection didn't work") 563 } 564 _, err = recordingConn.WriteTo(out) 565 if err != nil { 566 panic(err) 567 } 568 t.Logf("Wrote %s\n", path) 569 } 570 } 571 572 // peekError does a read with a short timeout to check if the next read would 573 // cause an error, for example if there is an alert waiting on the wire. 574 func peekError(conn net.Conn) error { 575 err := conn.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) 576 if err != nil { 577 return err 578 } 579 if n, err := conn.Read(make([]byte, 1)); n != 0 { 580 return errors.New("unexpectedly read data") 581 } else if err != nil { 582 if netErr, ok := err.(net.Error); !ok || !netErr.Timeout() { 583 return err 584 } 585 } 586 return nil 587 } 588 589 func runClientTestForVersion(t *testing.T, template *clientTest, version, option string) { 590 // Make a deep copy of the template before going parallel. 591 test := *template 592 if template.config != nil { 593 test.config = template.config.Clone() 594 } 595 test.name = version + "-" + test.name 596 test.args = append([]string{option}, test.args...) 597 598 runTestAndUpdateIfNeeded(t, version, test.run, false) 599 } 600 601 func runClientTestTLS10(t *testing.T, template *clientTest) { 602 runClientTestForVersion(t, template, "TLSv10", "-tls1") 603 } 604 605 func runClientTestTLS11(t *testing.T, template *clientTest) { 606 runClientTestForVersion(t, template, "TLSv11", "-tls1_1") 607 } 608 609 func runClientTestTLS12(t *testing.T, template *clientTest) { 610 runClientTestForVersion(t, template, "TLSv12", "-tls1_2") 611 } 612 613 func runClientTestTLS13(t *testing.T, template *clientTest) { 614 runClientTestForVersion(t, template, "TLSv13", "-tls1_3") 615 } 616 617 func TestHandshakeClientRSARC4(t *testing.T) { 618 test := &clientTest{ 619 name: "RSA-RC4", 620 args: []string{"-cipher", "RC4-SHA"}, 621 } 622 runClientTestTLS10(t, test) 623 runClientTestTLS11(t, test) 624 runClientTestTLS12(t, test) 625 } 626 627 func TestHandshakeClientRSAAES128GCM(t *testing.T) { 628 test := &clientTest{ 629 name: "AES128-GCM-SHA256", 630 args: []string{"-cipher", "AES128-GCM-SHA256"}, 631 } 632 runClientTestTLS12(t, test) 633 } 634 635 func TestHandshakeClientRSAAES256GCM(t *testing.T) { 636 test := &clientTest{ 637 name: "AES256-GCM-SHA384", 638 args: []string{"-cipher", "AES256-GCM-SHA384"}, 639 } 640 runClientTestTLS12(t, test) 641 } 642 643 func TestHandshakeClientECDHERSAAES(t *testing.T) { 644 test := &clientTest{ 645 name: "ECDHE-RSA-AES", 646 args: []string{"-cipher", "ECDHE-RSA-AES128-SHA"}, 647 } 648 runClientTestTLS10(t, test) 649 runClientTestTLS11(t, test) 650 runClientTestTLS12(t, test) 651 } 652 653 func TestHandshakeClientECDHEECDSAAES(t *testing.T) { 654 test := &clientTest{ 655 name: "ECDHE-ECDSA-AES", 656 args: []string{"-cipher", "ECDHE-ECDSA-AES128-SHA"}, 657 cert: testECDSACertificate, 658 key: testECDSAPrivateKey, 659 } 660 runClientTestTLS10(t, test) 661 runClientTestTLS11(t, test) 662 runClientTestTLS12(t, test) 663 } 664 665 func TestHandshakeClientECDHEECDSAAESGCM(t *testing.T) { 666 test := &clientTest{ 667 name: "ECDHE-ECDSA-AES-GCM", 668 args: []string{"-cipher", "ECDHE-ECDSA-AES128-GCM-SHA256"}, 669 cert: testECDSACertificate, 670 key: testECDSAPrivateKey, 671 } 672 runClientTestTLS12(t, test) 673 } 674 675 func TestHandshakeClientAES256GCMSHA384(t *testing.T) { 676 test := &clientTest{ 677 name: "ECDHE-ECDSA-AES256-GCM-SHA384", 678 args: []string{"-cipher", "ECDHE-ECDSA-AES256-GCM-SHA384"}, 679 cert: testECDSACertificate, 680 key: testECDSAPrivateKey, 681 } 682 runClientTestTLS12(t, test) 683 } 684 685 func TestHandshakeClientAES128CBCSHA256(t *testing.T) { 686 test := &clientTest{ 687 name: "AES128-SHA256", 688 args: []string{"-cipher", "AES128-SHA256"}, 689 } 690 runClientTestTLS12(t, test) 691 } 692 693 func TestHandshakeClientECDHERSAAES128CBCSHA256(t *testing.T) { 694 test := &clientTest{ 695 name: "ECDHE-RSA-AES128-SHA256", 696 args: []string{"-cipher", "ECDHE-RSA-AES128-SHA256"}, 697 } 698 runClientTestTLS12(t, test) 699 } 700 701 func TestHandshakeClientECDHEECDSAAES128CBCSHA256(t *testing.T) { 702 test := &clientTest{ 703 name: "ECDHE-ECDSA-AES128-SHA256", 704 args: []string{"-cipher", "ECDHE-ECDSA-AES128-SHA256"}, 705 cert: testECDSACertificate, 706 key: testECDSAPrivateKey, 707 } 708 runClientTestTLS12(t, test) 709 } 710 711 func TestHandshakeClientX25519(t *testing.T) { 712 config := testConfig.Clone() 713 config.CurvePreferences = []CurveID{X25519} 714 715 test := &clientTest{ 716 name: "X25519-ECDHE", 717 args: []string{"-cipher", "ECDHE-RSA-AES128-GCM-SHA256", "-curves", "X25519"}, 718 config: config, 719 } 720 721 runClientTestTLS12(t, test) 722 runClientTestTLS13(t, test) 723 } 724 725 func TestHandshakeClientP256(t *testing.T) { 726 config := testConfig.Clone() 727 config.CurvePreferences = []CurveID{CurveP256} 728 729 test := &clientTest{ 730 name: "P256-ECDHE", 731 args: []string{"-cipher", "ECDHE-RSA-AES128-GCM-SHA256", "-curves", "P-256"}, 732 config: config, 733 } 734 735 runClientTestTLS12(t, test) 736 runClientTestTLS13(t, test) 737 } 738 739 func TestHandshakeClientHelloRetryRequest(t *testing.T) { 740 config := testConfig.Clone() 741 config.CurvePreferences = []CurveID{X25519, CurveP256} 742 743 test := &clientTest{ 744 name: "HelloRetryRequest", 745 args: []string{"-cipher", "ECDHE-RSA-AES128-GCM-SHA256", "-curves", "P-256"}, 746 config: config, 747 } 748 749 runClientTestTLS13(t, test) 750 } 751 752 func TestHandshakeClientECDHERSAChaCha20(t *testing.T) { 753 config := testConfig.Clone() 754 config.CipherSuites = []uint16{TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305} 755 756 test := &clientTest{ 757 name: "ECDHE-RSA-CHACHA20-POLY1305", 758 args: []string{"-cipher", "ECDHE-RSA-CHACHA20-POLY1305"}, 759 config: config, 760 } 761 762 runClientTestTLS12(t, test) 763 } 764 765 func TestHandshakeClientECDHEECDSAChaCha20(t *testing.T) { 766 config := testConfig.Clone() 767 config.CipherSuites = []uint16{TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305} 768 769 test := &clientTest{ 770 name: "ECDHE-ECDSA-CHACHA20-POLY1305", 771 args: []string{"-cipher", "ECDHE-ECDSA-CHACHA20-POLY1305"}, 772 config: config, 773 cert: testECDSACertificate, 774 key: testECDSAPrivateKey, 775 } 776 777 runClientTestTLS12(t, test) 778 } 779 780 func TestHandshakeClientAES128SHA256(t *testing.T) { 781 test := &clientTest{ 782 name: "AES128-SHA256", 783 args: []string{"-ciphersuites", "TLS_AES_128_GCM_SHA256"}, 784 } 785 runClientTestTLS13(t, test) 786 } 787 func TestHandshakeClientAES256SHA384(t *testing.T) { 788 test := &clientTest{ 789 name: "AES256-SHA384", 790 args: []string{"-ciphersuites", "TLS_AES_256_GCM_SHA384"}, 791 } 792 runClientTestTLS13(t, test) 793 } 794 func TestHandshakeClientCHACHA20SHA256(t *testing.T) { 795 test := &clientTest{ 796 name: "CHACHA20-SHA256", 797 args: []string{"-ciphersuites", "TLS_CHACHA20_POLY1305_SHA256"}, 798 } 799 runClientTestTLS13(t, test) 800 } 801 802 func TestHandshakeClientECDSATLS13(t *testing.T) { 803 test := &clientTest{ 804 name: "ECDSA", 805 cert: testECDSACertificate, 806 key: testECDSAPrivateKey, 807 } 808 runClientTestTLS13(t, test) 809 } 810 811 func TestHandshakeClientEd25519(t *testing.T) { 812 test := &clientTest{ 813 name: "Ed25519", 814 cert: testEd25519Certificate, 815 key: testEd25519PrivateKey, 816 } 817 runClientTestTLS12(t, test) 818 runClientTestTLS13(t, test) 819 820 config := testConfig.Clone() 821 cert, _ := X509KeyPair([]byte(clientEd25519CertificatePEM), []byte(clientEd25519KeyPEM)) 822 config.Certificates = []Certificate{cert} 823 824 test = &clientTest{ 825 name: "ClientCert-Ed25519", 826 args: []string{"-Verify", "1"}, 827 config: config, 828 } 829 830 runClientTestTLS12(t, test) 831 runClientTestTLS13(t, test) 832 } 833 834 func TestHandshakeClientCertRSA(t *testing.T) { 835 config := testConfig.Clone() 836 cert, _ := X509KeyPair([]byte(clientCertificatePEM), []byte(clientKeyPEM)) 837 config.Certificates = []Certificate{cert} 838 839 test := &clientTest{ 840 name: "ClientCert-RSA-RSA", 841 args: []string{"-cipher", "AES128", "-Verify", "1"}, 842 config: config, 843 } 844 845 runClientTestTLS10(t, test) 846 runClientTestTLS12(t, test) 847 848 test = &clientTest{ 849 name: "ClientCert-RSA-ECDSA", 850 args: []string{"-cipher", "ECDHE-ECDSA-AES128-SHA", "-Verify", "1"}, 851 config: config, 852 cert: testECDSACertificate, 853 key: testECDSAPrivateKey, 854 } 855 856 runClientTestTLS10(t, test) 857 runClientTestTLS12(t, test) 858 runClientTestTLS13(t, test) 859 860 test = &clientTest{ 861 name: "ClientCert-RSA-AES256-GCM-SHA384", 862 args: []string{"-cipher", "ECDHE-RSA-AES256-GCM-SHA384", "-Verify", "1"}, 863 config: config, 864 cert: testRSACertificate, 865 key: testRSAPrivateKey, 866 } 867 868 runClientTestTLS12(t, test) 869 } 870 871 func TestHandshakeClientCertECDSA(t *testing.T) { 872 config := testConfig.Clone() 873 cert, _ := X509KeyPair([]byte(clientECDSACertificatePEM), []byte(clientECDSAKeyPEM)) 874 config.Certificates = []Certificate{cert} 875 876 test := &clientTest{ 877 name: "ClientCert-ECDSA-RSA", 878 args: []string{"-cipher", "AES128", "-Verify", "1"}, 879 config: config, 880 } 881 882 runClientTestTLS10(t, test) 883 runClientTestTLS12(t, test) 884 runClientTestTLS13(t, test) 885 886 test = &clientTest{ 887 name: "ClientCert-ECDSA-ECDSA", 888 args: []string{"-cipher", "ECDHE-ECDSA-AES128-SHA", "-Verify", "1"}, 889 config: config, 890 cert: testECDSACertificate, 891 key: testECDSAPrivateKey, 892 } 893 894 runClientTestTLS10(t, test) 895 runClientTestTLS12(t, test) 896 } 897 898 // TestHandshakeClientCertRSAPSS tests rsa_pss_rsae_sha256 signatures from both 899 // client and server certificates. It also serves from both sides a certificate 900 // signed itself with RSA-PSS, mostly to check that crypto/x509 chain validation 901 // works. 902 func TestHandshakeClientCertRSAPSS(t *testing.T) { 903 cert, err := x509.ParseCertificate(testRSAPSSCertificate) 904 if err != nil { 905 panic(err) 906 } 907 rootCAs := x509.NewCertPool() 908 rootCAs.AddCert(cert) 909 910 config := testConfig.Clone() 911 // Use GetClientCertificate to bypass the client certificate selection logic. 912 config.GetClientCertificate = func(*CertificateRequestInfo) (*Certificate, error) { 913 return &Certificate{ 914 Certificate: [][]byte{testRSAPSSCertificate}, 915 PrivateKey: testRSAPrivateKey, 916 }, nil 917 } 918 config.RootCAs = rootCAs 919 920 test := &clientTest{ 921 name: "ClientCert-RSA-RSAPSS", 922 args: []string{"-cipher", "AES128", "-Verify", "1", "-client_sigalgs", 923 "rsa_pss_rsae_sha256", "-sigalgs", "rsa_pss_rsae_sha256"}, 924 config: config, 925 cert: testRSAPSSCertificate, 926 key: testRSAPrivateKey, 927 } 928 runClientTestTLS12(t, test) 929 runClientTestTLS13(t, test) 930 } 931 932 func TestHandshakeClientCertRSAPKCS1v15(t *testing.T) { 933 config := testConfig.Clone() 934 cert, _ := X509KeyPair([]byte(clientCertificatePEM), []byte(clientKeyPEM)) 935 config.Certificates = []Certificate{cert} 936 937 test := &clientTest{ 938 name: "ClientCert-RSA-RSAPKCS1v15", 939 args: []string{"-cipher", "AES128", "-Verify", "1", "-client_sigalgs", 940 "rsa_pkcs1_sha256", "-sigalgs", "rsa_pkcs1_sha256"}, 941 config: config, 942 } 943 944 runClientTestTLS12(t, test) 945 } 946 947 func TestClientKeyUpdate(t *testing.T) { 948 test := &clientTest{ 949 name: "KeyUpdate", 950 args: []string{"-state"}, 951 sendKeyUpdate: true, 952 } 953 runClientTestTLS13(t, test) 954 } 955 956 func TestResumption(t *testing.T) { 957 // t.Run("TLSv12", func(t *testing.T) { testResumption(t, VersionTLS12) }) 958 t.Run("TLSv13", func(t *testing.T) { testResumption(t, VersionTLS13) }) 959 } 960 961 func testResumption(t *testing.T, version uint16) { 962 if testing.Short() { 963 t.Skip("skipping in -short mode") 964 } 965 serverConfig := &Config{ 966 MaxVersion: version, 967 CipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA}, 968 Certificates: testConfig.Certificates, 969 } 970 971 issuer, err := x509.ParseCertificate(testRSACertificateIssuer) 972 if err != nil { 973 panic(err) 974 } 975 976 rootCAs := x509.NewCertPool() 977 rootCAs.AddCert(issuer) 978 979 clientConfig := &Config{ 980 MaxVersion: version, 981 CipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA}, 982 ClientSessionCache: NewLRUClientSessionCache(32), 983 RootCAs: rootCAs, 984 ServerName: "example.golang", 985 } 986 987 testResumeState := func(test string, didResume bool) { 988 _, hs, err := testHandshake(t, clientConfig, serverConfig) 989 if err != nil { 990 t.Fatalf("%s: handshake failed: %s", test, err) 991 } 992 if hs.DidResume != didResume { 993 t.Fatalf("%s resumed: %v, expected: %v", test, hs.DidResume, didResume) 994 } 995 if didResume && (hs.PeerCertificates == nil || hs.VerifiedChains == nil) { 996 t.Fatalf("expected non-nil certificates after resumption. Got peerCertificates: %#v, verifiedCertificates: %#v", hs.PeerCertificates, hs.VerifiedChains) 997 } 998 if got, want := hs.ServerName, clientConfig.ServerName; got != want { 999 t.Errorf("%s: server name %s, want %s", test, got, want) 1000 } 1001 } 1002 1003 getTicket := func() []byte { 1004 return clientConfig.ClientSessionCache.(*lruSessionCache).q.Front().Value.(*lruSessionCacheEntry).state.sessionTicket 1005 } 1006 deleteTicket := func() { 1007 ticketKey := clientConfig.ClientSessionCache.(*lruSessionCache).q.Front().Value.(*lruSessionCacheEntry).sessionKey 1008 clientConfig.ClientSessionCache.Put(ticketKey, nil) 1009 } 1010 corruptTicket := func() { 1011 clientConfig.ClientSessionCache.(*lruSessionCache).q.Front().Value.(*lruSessionCacheEntry).state.masterSecret[0] ^= 0xff 1012 } 1013 randomKey := func() [32]byte { 1014 var k [32]byte 1015 if _, err := io.ReadFull(serverConfig.rand(), k[:]); err != nil { 1016 t.Fatalf("Failed to read new SessionTicketKey: %s", err) 1017 } 1018 return k 1019 } 1020 1021 testResumeState("Handshake", false) 1022 ticket := getTicket() 1023 testResumeState("Resume", true) 1024 if !bytes.Equal(ticket, getTicket()) && version != VersionTLS13 { 1025 t.Fatal("first ticket doesn't match ticket after resumption") 1026 } 1027 if bytes.Equal(ticket, getTicket()) && version == VersionTLS13 { 1028 t.Fatal("ticket didn't change after resumption") 1029 } 1030 1031 // An old session ticket can resume, but the server will provide a ticket encrypted with a fresh key. 1032 serverConfig.Time = func() time.Time { return time.Now().Add(24*time.Hour + time.Minute) } 1033 testResumeState("ResumeWithOldTicket", true) 1034 if bytes.Equal(ticket[:ticketKeyNameLen], getTicket()[:ticketKeyNameLen]) { 1035 t.Fatal("old first ticket matches the fresh one") 1036 } 1037 1038 // Now the session tickey key is expired, so a full handshake should occur. 1039 serverConfig.Time = func() time.Time { return time.Now().Add(24*8*time.Hour + time.Minute) } 1040 testResumeState("ResumeWithExpiredTicket", false) 1041 if bytes.Equal(ticket, getTicket()) { 1042 t.Fatal("expired first ticket matches the fresh one") 1043 } 1044 1045 serverConfig.Time = func() time.Time { return time.Now() } // reset the time back 1046 key1 := randomKey() 1047 serverConfig.SetSessionTicketKeys([][32]byte{key1}) 1048 1049 testResumeState("InvalidSessionTicketKey", false) 1050 testResumeState("ResumeAfterInvalidSessionTicketKey", true) 1051 1052 key2 := randomKey() 1053 serverConfig.SetSessionTicketKeys([][32]byte{key2, key1}) 1054 ticket = getTicket() 1055 testResumeState("KeyChange", true) 1056 if bytes.Equal(ticket, getTicket()) { 1057 t.Fatal("new ticket wasn't included while resuming") 1058 } 1059 testResumeState("KeyChangeFinish", true) 1060 1061 // Age the session ticket a bit, but not yet expired. 1062 serverConfig.Time = func() time.Time { return time.Now().Add(24*time.Hour + time.Minute) } 1063 testResumeState("OldSessionTicket", true) 1064 ticket = getTicket() 1065 // Expire the session ticket, which would force a full handshake. 1066 serverConfig.Time = func() time.Time { return time.Now().Add(24*8*time.Hour + time.Minute) } 1067 testResumeState("ExpiredSessionTicket", false) 1068 if bytes.Equal(ticket, getTicket()) { 1069 t.Fatal("new ticket wasn't provided after old ticket expired") 1070 } 1071 1072 // Age the session ticket a bit at a time, but don't expire it. 1073 d := 0 * time.Hour 1074 for i := 0; i < 13; i++ { 1075 d += 12 * time.Hour 1076 serverConfig.Time = func() time.Time { return time.Now().Add(d) } 1077 testResumeState("OldSessionTicket", true) 1078 } 1079 // Expire it (now a little more than 7 days) and make sure a full 1080 // handshake occurs for TLS 1.2. Resumption should still occur for 1081 // TLS 1.3 since the client should be using a fresh ticket sent over 1082 // by the server. 1083 d += 12 * time.Hour 1084 serverConfig.Time = func() time.Time { return time.Now().Add(d) } 1085 if version == VersionTLS13 { 1086 testResumeState("ExpiredSessionTicket", true) 1087 } else { 1088 testResumeState("ExpiredSessionTicket", false) 1089 } 1090 if bytes.Equal(ticket, getTicket()) { 1091 t.Fatal("new ticket wasn't provided after old ticket expired") 1092 } 1093 1094 // Reset serverConfig to ensure that calling SetSessionTicketKeys 1095 // before the serverConfig is used works. 1096 serverConfig = &Config{ 1097 MaxVersion: version, 1098 CipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA}, 1099 Certificates: testConfig.Certificates, 1100 } 1101 serverConfig.SetSessionTicketKeys([][32]byte{key2}) 1102 1103 testResumeState("FreshConfig", true) 1104 1105 // In TLS 1.3, cross-cipher suite resumption is allowed as long as the KDF 1106 // hash matches. Also, Config.CipherSuites does not apply to TLS 1.3. 1107 if version != VersionTLS13 { 1108 clientConfig.CipherSuites = []uint16{TLS_ECDHE_RSA_WITH_RC4_128_SHA} 1109 testResumeState("DifferentCipherSuite", false) 1110 testResumeState("DifferentCipherSuiteRecovers", true) 1111 } 1112 1113 deleteTicket() 1114 testResumeState("WithoutSessionTicket", false) 1115 1116 // Session resumption should work when using client certificates 1117 deleteTicket() 1118 serverConfig.ClientCAs = rootCAs 1119 serverConfig.ClientAuth = RequireAndVerifyClientCert 1120 clientConfig.Certificates = serverConfig.Certificates 1121 testResumeState("InitialHandshake", false) 1122 testResumeState("WithClientCertificates", true) 1123 serverConfig.ClientAuth = NoClientCert 1124 1125 // Tickets should be removed from the session cache on TLS handshake 1126 // failure, and the client should recover from a corrupted PSK 1127 testResumeState("FetchTicketToCorrupt", false) 1128 corruptTicket() 1129 _, _, err = testHandshake(t, clientConfig, serverConfig) 1130 if err == nil { 1131 t.Fatalf("handshake did not fail with a corrupted client secret") 1132 } 1133 testResumeState("AfterHandshakeFailure", false) 1134 1135 clientConfig.ClientSessionCache = nil 1136 testResumeState("WithoutSessionCache", false) 1137 } 1138 1139 func TestLRUClientSessionCache(t *testing.T) { 1140 // Initialize cache of capacity 4. 1141 cache := NewLRUClientSessionCache(4) 1142 cs := make([]ClientSessionState, 6) 1143 keys := []string{"0", "1", "2", "3", "4", "5", "6"} 1144 1145 // Add 4 entries to the cache and look them up. 1146 for i := 0; i < 4; i++ { 1147 cache.Put(keys[i], &cs[i]) 1148 } 1149 for i := 0; i < 4; i++ { 1150 if s, ok := cache.Get(keys[i]); !ok || s != &cs[i] { 1151 t.Fatalf("session cache failed lookup for added key: %s", keys[i]) 1152 } 1153 } 1154 1155 // Add 2 more entries to the cache. First 2 should be evicted. 1156 for i := 4; i < 6; i++ { 1157 cache.Put(keys[i], &cs[i]) 1158 } 1159 for i := 0; i < 2; i++ { 1160 if s, ok := cache.Get(keys[i]); ok || s != nil { 1161 t.Fatalf("session cache should have evicted key: %s", keys[i]) 1162 } 1163 } 1164 1165 // Touch entry 2. LRU should evict 3 next. 1166 cache.Get(keys[2]) 1167 cache.Put(keys[0], &cs[0]) 1168 if s, ok := cache.Get(keys[3]); ok || s != nil { 1169 t.Fatalf("session cache should have evicted key 3") 1170 } 1171 1172 // Update entry 0 in place. 1173 cache.Put(keys[0], &cs[3]) 1174 if s, ok := cache.Get(keys[0]); !ok || s != &cs[3] { 1175 t.Fatalf("session cache failed update for key 0") 1176 } 1177 1178 // Calling Put with a nil entry deletes the key. 1179 cache.Put(keys[0], nil) 1180 if _, ok := cache.Get(keys[0]); ok { 1181 t.Fatalf("session cache failed to delete key 0") 1182 } 1183 1184 // Delete entry 2. LRU should keep 4 and 5 1185 cache.Put(keys[2], nil) 1186 if _, ok := cache.Get(keys[2]); ok { 1187 t.Fatalf("session cache failed to delete key 4") 1188 } 1189 for i := 4; i < 6; i++ { 1190 if s, ok := cache.Get(keys[i]); !ok || s != &cs[i] { 1191 t.Fatalf("session cache should not have deleted key: %s", keys[i]) 1192 } 1193 } 1194 } 1195 1196 func TestKeyLogTLS12(t *testing.T) { 1197 var serverBuf, clientBuf bytes.Buffer 1198 1199 clientConfig := testConfig.Clone() 1200 clientConfig.KeyLogWriter = &clientBuf 1201 clientConfig.MaxVersion = VersionTLS12 1202 1203 serverConfig := testConfig.Clone() 1204 serverConfig.KeyLogWriter = &serverBuf 1205 serverConfig.MaxVersion = VersionTLS12 1206 1207 c, s := localPipe(t) 1208 done := make(chan bool) 1209 1210 go func() { 1211 defer close(done) 1212 1213 if err := Server(s, serverConfig).Handshake(); err != nil { 1214 t.Errorf("server: %s", err) 1215 return 1216 } 1217 err := s.Close() 1218 if err != nil { 1219 panic(err) 1220 } 1221 }() 1222 1223 if err := Client(c, clientConfig).Handshake(); err != nil { 1224 t.Fatalf("client: %s", err) 1225 } 1226 1227 err := c.Close() 1228 if err != nil { 1229 panic(err) 1230 } 1231 <-done 1232 1233 checkKeylogLine := func(side, loggedLine string) { 1234 if len(loggedLine) == 0 { 1235 t.Fatalf("%s: no keylog line was produced", side) 1236 } 1237 const expectedLen = 13 /* "CLIENT_RANDOM" */ + 1238 1 /* space */ + 1239 32*2 /* hex client nonce */ + 1240 1 /* space */ + 1241 48*2 /* hex master secret */ + 1242 1 /* new line */ 1243 if len(loggedLine) != expectedLen { 1244 t.Fatalf("%s: keylog line has incorrect length (want %d, got %d): %q", side, expectedLen, len(loggedLine), loggedLine) 1245 } 1246 if !strings.HasPrefix(loggedLine, "CLIENT_RANDOM "+strings.Repeat("0", 64)+" ") { 1247 t.Fatalf("%s: keylog line has incorrect structure or nonce: %q", side, loggedLine) 1248 } 1249 } 1250 1251 checkKeylogLine("client", clientBuf.String()) 1252 checkKeylogLine("server", serverBuf.String()) 1253 } 1254 1255 func TestKeyLogTLS13(t *testing.T) { 1256 var serverBuf, clientBuf bytes.Buffer 1257 1258 clientConfig := testConfig.Clone() 1259 clientConfig.KeyLogWriter = &clientBuf 1260 1261 serverConfig := testConfig.Clone() 1262 serverConfig.KeyLogWriter = &serverBuf 1263 1264 c, s := localPipe(t) 1265 done := make(chan bool) 1266 1267 go func() { 1268 defer close(done) 1269 1270 if err := Server(s, serverConfig).Handshake(); err != nil { 1271 t.Errorf("server: %s", err) 1272 return 1273 } 1274 err := s.Close() 1275 if err != nil { 1276 panic(err) 1277 } 1278 }() 1279 1280 if err := Client(c, clientConfig).Handshake(); err != nil { 1281 t.Fatalf("client: %s", err) 1282 } 1283 1284 err := c.Close() 1285 if err != nil { 1286 panic(err) 1287 } 1288 <-done 1289 1290 checkKeylogLines := func(side, loggedLines string) { 1291 loggedLines = strings.TrimSpace(loggedLines) 1292 lines := strings.Split(loggedLines, "\n") 1293 if len(lines) != 4 { 1294 t.Errorf("Expected the %s to log 4 lines, got %d", side, len(lines)) 1295 } 1296 } 1297 1298 checkKeylogLines("client", clientBuf.String()) 1299 checkKeylogLines("server", serverBuf.String()) 1300 } 1301 1302 func TestHandshakeClientALPNMatch(t *testing.T) { 1303 config := testConfig.Clone() 1304 config.NextProtos = []string{"proto2", "proto1"} 1305 1306 test := &clientTest{ 1307 name: "ALPN", 1308 // Note that this needs OpenSSL 1.0.2 because that is the first 1309 // version that supports the -alpn flag. 1310 args: []string{"-alpn", "proto1,proto2"}, 1311 config: config, 1312 validate: func(state ConnectionState) error { 1313 // The server's preferences should override the client. 1314 if state.NegotiatedProtocol != "proto1" { 1315 //goland:noinspection GoErrorStringFormat 1316 return fmt.Errorf("Got protocol %q, wanted proto1", state.NegotiatedProtocol) 1317 } 1318 return nil 1319 }, 1320 } 1321 runClientTestTLS12(t, test) 1322 runClientTestTLS13(t, test) 1323 } 1324 1325 func TestServerSelectingUnconfiguredApplicationProtocol(t *testing.T) { 1326 // This checks that the server can't select an application protocol that the 1327 // client didn't offer. 1328 1329 c, s := localPipe(t) 1330 errChan := make(chan error, 1) 1331 1332 go func() { 1333 client := Client(c, &Config{ 1334 ServerName: "foo", 1335 CipherSuites: []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256}, 1336 NextProtos: []string{"http", "something-else"}, 1337 }) 1338 errChan <- client.Handshake() 1339 }() 1340 1341 var header [5]byte 1342 if _, err := io.ReadFull(s, header[:]); err != nil { 1343 t.Fatal(err) 1344 } 1345 recordLen := int(header[3])<<8 | int(header[4]) 1346 1347 record := make([]byte, recordLen) 1348 if _, err := io.ReadFull(s, record); err != nil { 1349 t.Fatal(err) 1350 } 1351 1352 serverHello := &serverHelloMsg{ 1353 vers: VersionTLS12, 1354 random: make([]byte, 32), 1355 cipherSuite: TLS_RSA_WITH_AES_128_GCM_SHA256, 1356 alpnProtocol: "how-about-this", 1357 } 1358 serverHelloBytes := serverHello.marshal() 1359 1360 _, err := s.Write([]byte{ 1361 byte(recordTypeHandshake), 1362 byte(VersionTLS12 >> 8), 1363 byte(VersionTLS12 & 0xff), 1364 byte(len(serverHelloBytes) >> 8), 1365 byte(len(serverHelloBytes)), 1366 }) 1367 if err != nil { 1368 panic(err) 1369 } 1370 _, err = s.Write(serverHelloBytes) 1371 if err != nil { 1372 panic(err) 1373 } 1374 err = s.Close() 1375 if err != nil { 1376 panic(err) 1377 } 1378 1379 if err := <-errChan; !strings.Contains(err.Error(), "server selected unadvertised ALPN protocol") { 1380 t.Fatalf("Expected error about unconfigured cipher suite but got %q", err) 1381 } 1382 } 1383 1384 // sctsBase64 contains data from `openssl s_client -serverinfo 18 -connect ritter.vg:443` 1385 const sctsBase64 = "ABIBaQFnAHUApLkJkLQYWBSHuxOizGdwCjw1mAT5G9+443fNDsgN3BAAAAFHl5nuFgAABAMARjBEAiAcS4JdlW5nW9sElUv2zvQyPoZ6ejKrGGB03gjaBZFMLwIgc1Qbbn+hsH0RvObzhS+XZhr3iuQQJY8S9G85D9KeGPAAdgBo9pj4H2SCvjqM7rkoHUz8cVFdZ5PURNEKZ6y7T0/7xAAAAUeX4bVwAAAEAwBHMEUCIDIhFDgG2HIuADBkGuLobU5a4dlCHoJLliWJ1SYT05z6AiEAjxIoZFFPRNWMGGIjskOTMwXzQ1Wh2e7NxXE1kd1J0QsAdgDuS723dc5guuFCaR+r4Z5mow9+X7By2IMAxHuJeqj9ywAAAUhcZIqHAAAEAwBHMEUCICmJ1rBT09LpkbzxtUC+Hi7nXLR0J+2PmwLp+sJMuqK+AiEAr0NkUnEVKVhAkccIFpYDqHOlZaBsuEhWWrYpg2RtKp0=" 1386 1387 func TestHandshakClientSCTs(t *testing.T) { 1388 config := testConfig.Clone() 1389 1390 scts, err := base64.StdEncoding.DecodeString(sctsBase64) 1391 if err != nil { 1392 t.Fatal(err) 1393 } 1394 1395 // Note that this needs OpenSSL 1.0.2 because that is the first 1396 // version that supports the -serverinfo flag. 1397 test := &clientTest{ 1398 name: "SCT", 1399 config: config, 1400 extensions: [][]byte{scts}, 1401 validate: func(state ConnectionState) error { 1402 expectedSCTs := [][]byte{ 1403 scts[8:125], 1404 scts[127:245], 1405 scts[247:], 1406 } 1407 if n := len(state.SignedCertificateTimestamps); n != len(expectedSCTs) { 1408 //goland:noinspection GoErrorStringFormat 1409 return fmt.Errorf("Got %d scts, wanted %d", n, len(expectedSCTs)) 1410 } 1411 for i, expected := range expectedSCTs { 1412 if sct := state.SignedCertificateTimestamps[i]; !bytes.Equal(sct, expected) { 1413 return fmt.Errorf("SCT #%d contained %x, expected %x", i, sct, expected) 1414 } 1415 } 1416 return nil 1417 }, 1418 } 1419 runClientTestTLS12(t, test) 1420 1421 // TLS 1.3 moved SCTs to the Certificate extensions and -serverinfo only 1422 // supports ServerHello extensions. 1423 } 1424 1425 func TestRenegotiationRejected(t *testing.T) { 1426 config := testConfig.Clone() 1427 test := &clientTest{ 1428 name: "RenegotiationRejected", 1429 args: []string{"-state"}, 1430 config: config, 1431 numRenegotiations: 1, 1432 renegotiationExpectedToFail: 1, 1433 checkRenegotiationError: func(renegotiationNum int, err error) error { 1434 if err == nil { 1435 return errors.New("expected error from renegotiation but got nil") 1436 } 1437 if !strings.Contains(err.Error(), "no renegotiation") { 1438 return fmt.Errorf("expected renegotiation to be rejected but got %q", err) 1439 } 1440 return nil 1441 }, 1442 } 1443 runClientTestTLS12(t, test) 1444 } 1445 1446 func TestRenegotiateOnce(t *testing.T) { 1447 config := testConfig.Clone() 1448 config.Renegotiation = RenegotiateOnceAsClient 1449 1450 test := &clientTest{ 1451 name: "RenegotiateOnce", 1452 args: []string{"-state"}, 1453 config: config, 1454 numRenegotiations: 1, 1455 } 1456 1457 runClientTestTLS12(t, test) 1458 } 1459 1460 func TestRenegotiateTwice(t *testing.T) { 1461 config := testConfig.Clone() 1462 config.Renegotiation = RenegotiateFreelyAsClient 1463 1464 test := &clientTest{ 1465 name: "RenegotiateTwice", 1466 args: []string{"-state"}, 1467 config: config, 1468 numRenegotiations: 2, 1469 } 1470 1471 runClientTestTLS12(t, test) 1472 } 1473 1474 func TestRenegotiateTwiceRejected(t *testing.T) { 1475 config := testConfig.Clone() 1476 config.Renegotiation = RenegotiateOnceAsClient 1477 1478 test := &clientTest{ 1479 name: "RenegotiateTwiceRejected", 1480 args: []string{"-state"}, 1481 config: config, 1482 numRenegotiations: 2, 1483 renegotiationExpectedToFail: 2, 1484 checkRenegotiationError: func(renegotiationNum int, err error) error { 1485 if renegotiationNum == 1 { 1486 return err 1487 } 1488 1489 if err == nil { 1490 return errors.New("expected error from renegotiation but got nil") 1491 } 1492 if !strings.Contains(err.Error(), "no renegotiation") { 1493 return fmt.Errorf("expected renegotiation to be rejected but got %q", err) 1494 } 1495 return nil 1496 }, 1497 } 1498 1499 runClientTestTLS12(t, test) 1500 } 1501 1502 func TestHandshakeClientExportKeyingMaterial(t *testing.T) { 1503 test := &clientTest{ 1504 name: "ExportKeyingMaterial", 1505 config: testConfig.Clone(), 1506 validate: func(state ConnectionState) error { 1507 if km, err := state.ExportKeyingMaterial("test", nil, 42); err != nil { 1508 return fmt.Errorf("ExportKeyingMaterial failed: %v", err) 1509 } else if len(km) != 42 { 1510 //goland:noinspection GoErrorStringFormat 1511 return fmt.Errorf("Got %d bytes from ExportKeyingMaterial, wanted %d", len(km), 42) 1512 } 1513 return nil 1514 }, 1515 } 1516 runClientTestTLS10(t, test) 1517 runClientTestTLS12(t, test) 1518 runClientTestTLS13(t, test) 1519 } 1520 1521 var hostnameInSNITests = []struct { 1522 in, out string 1523 }{ 1524 // Opaque string 1525 {"", ""}, 1526 {"localhost", "localhost"}, 1527 {"foo, bar, baz and qux", "foo, bar, baz and qux"}, 1528 1529 // DNS hostname 1530 {"golang.org", "golang.org"}, 1531 {"golang.org.", "golang.org"}, 1532 1533 // Literal IPv4 address 1534 {"1.2.3.4", ""}, 1535 1536 // Literal IPv6 address 1537 {"::1", ""}, 1538 {"::1%lo0", ""}, // with zone identifier 1539 {"[::1]", ""}, // as per RFC 5952 we allow the [] style as IPv6 literal 1540 {"[::1%lo0]", ""}, 1541 } 1542 1543 func TestHostnameInSNI(t *testing.T) { 1544 for _, tt := range hostnameInSNITests { 1545 c, s := localPipe(t) 1546 1547 go func(host string) { 1548 err := Client(c, &Config{ServerName: host, InsecureSkipVerify: true}).Handshake() 1549 if err != nil { 1550 panic(err) 1551 } 1552 }(tt.in) 1553 1554 var header [5]byte 1555 if _, err := io.ReadFull(s, header[:]); err != nil { 1556 t.Fatal(err) 1557 } 1558 recordLen := int(header[3])<<8 | int(header[4]) 1559 1560 record := make([]byte, recordLen) 1561 if _, err := io.ReadFull(s, record[:]); err != nil { 1562 t.Fatal(err) 1563 } 1564 1565 err := c.Close() 1566 if err != nil { 1567 panic(err) 1568 } 1569 err = s.Close() 1570 if err != nil { 1571 panic(err) 1572 } 1573 1574 var m clientHelloMsg 1575 if !m.unmarshal(record) { 1576 t.Errorf("unmarshaling ClientHello for %q failed", tt.in) 1577 continue 1578 } 1579 if tt.in != tt.out && m.serverName == tt.in { 1580 t.Errorf("prohibited %q found in ClientHello: %x", tt.in, record) 1581 } 1582 if m.serverName != tt.out { 1583 t.Errorf("expected %q not found in ClientHello: %x", tt.out, record) 1584 } 1585 } 1586 } 1587 1588 func TestServerSelectingUnconfiguredCipherSuite(t *testing.T) { 1589 // This checks that the server can't select a cipher suite that the 1590 // client didn't offer. See #13174. 1591 1592 c, s := localPipe(t) 1593 errChan := make(chan error, 1) 1594 1595 go func() { 1596 client := Client(c, &Config{ 1597 ServerName: "foo", 1598 CipherSuites: []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256}, 1599 }) 1600 errChan <- client.Handshake() 1601 }() 1602 1603 var header [5]byte 1604 if _, err := io.ReadFull(s, header[:]); err != nil { 1605 t.Fatal(err) 1606 } 1607 recordLen := int(header[3])<<8 | int(header[4]) 1608 1609 record := make([]byte, recordLen) 1610 if _, err := io.ReadFull(s, record); err != nil { 1611 t.Fatal(err) 1612 } 1613 1614 // Create a ServerHello that selects a different cipher suite than the 1615 // sole one that the client offered. 1616 serverHello := &serverHelloMsg{ 1617 vers: VersionTLS12, 1618 random: make([]byte, 32), 1619 cipherSuite: TLS_RSA_WITH_AES_256_GCM_SHA384, 1620 } 1621 serverHelloBytes := serverHello.marshal() 1622 1623 _, err := s.Write([]byte{ 1624 byte(recordTypeHandshake), 1625 byte(VersionTLS12 >> 8), 1626 byte(VersionTLS12 & 0xff), 1627 byte(len(serverHelloBytes) >> 8), 1628 byte(len(serverHelloBytes)), 1629 }) 1630 if err != nil { 1631 panic(err) 1632 } 1633 _, err = s.Write(serverHelloBytes) 1634 if err != nil { 1635 panic(err) 1636 } 1637 err = s.Close() 1638 if err != nil { 1639 panic(err) 1640 } 1641 1642 if err := <-errChan; !strings.Contains(err.Error(), "unconfigured cipher") { 1643 t.Fatalf("Expected error about unconfigured cipher suite but got %q", err) 1644 } 1645 } 1646 1647 func TestVerifyConnection(t *testing.T) { 1648 t.Run("TLSv12", func(t *testing.T) { testVerifyConnection(t, VersionTLS12) }) 1649 t.Run("TLSv13", func(t *testing.T) { testVerifyConnection(t, VersionTLS13) }) 1650 } 1651 1652 func testVerifyConnection(t *testing.T, version uint16) { 1653 checkFields := func(c ConnectionState, called *int, errorType string) error { 1654 if c.Version != version { 1655 return fmt.Errorf("%s: got Version %v, want %v", errorType, c.Version, version) 1656 } 1657 if c.HandshakeComplete { 1658 return fmt.Errorf("%s: got HandshakeComplete, want false", errorType) 1659 } 1660 if c.ServerName != "example.golang" { 1661 return fmt.Errorf("%s: got ServerName %s, want %s", errorType, c.ServerName, "example.golang") 1662 } 1663 if c.NegotiatedProtocol != "protocol1" { 1664 return fmt.Errorf("%s: got NegotiatedProtocol %s, want %s", errorType, c.NegotiatedProtocol, "protocol1") 1665 } 1666 if c.CipherSuite == 0 { 1667 return fmt.Errorf("%s: got CipherSuite 0, want non-zero", errorType) 1668 } 1669 wantDidResume := false 1670 if *called == 2 { // if this is the second time, then it should be a resumption 1671 wantDidResume = true 1672 } 1673 if c.DidResume != wantDidResume { 1674 return fmt.Errorf("%s: got DidResume %t, want %t", errorType, c.DidResume, wantDidResume) 1675 } 1676 return nil 1677 } 1678 1679 tests := []struct { 1680 name string 1681 configureServer func(*Config, *int) 1682 configureClient func(*Config, *int) 1683 }{ 1684 { 1685 name: "RequireAndVerifyClientCert", 1686 configureServer: func(config *Config, called *int) { 1687 config.ClientAuth = RequireAndVerifyClientCert 1688 config.VerifyConnection = func(c ConnectionState) error { 1689 *called++ 1690 if l := len(c.PeerCertificates); l != 1 { 1691 return fmt.Errorf("server: got len(PeerCertificates) = %d, wanted 1", l) 1692 } 1693 if len(c.VerifiedChains) == 0 { 1694 return fmt.Errorf("server: got len(VerifiedChains) = 0, wanted non-zero") 1695 } 1696 return checkFields(c, called, "server") 1697 } 1698 }, 1699 configureClient: func(config *Config, called *int) { 1700 config.VerifyConnection = func(c ConnectionState) error { 1701 *called++ 1702 if l := len(c.PeerCertificates); l != 1 { 1703 return fmt.Errorf("client: got len(PeerCertificates) = %d, wanted 1", l) 1704 } 1705 if len(c.VerifiedChains) == 0 { 1706 return fmt.Errorf("client: got len(VerifiedChains) = 0, wanted non-zero") 1707 } 1708 if c.DidResume { 1709 return nil 1710 // The SCTs and OCSP Response are dropped on resumption. 1711 // See http://golang.org/issue/39075. 1712 } 1713 if len(c.OCSPResponse) == 0 { 1714 return fmt.Errorf("client: got len(OCSPResponse) = 0, wanted non-zero") 1715 } 1716 if len(c.SignedCertificateTimestamps) == 0 { 1717 return fmt.Errorf("client: got len(SignedCertificateTimestamps) = 0, wanted non-zero") 1718 } 1719 return checkFields(c, called, "client") 1720 } 1721 }, 1722 }, 1723 { 1724 name: "InsecureSkipVerify", 1725 configureServer: func(config *Config, called *int) { 1726 config.ClientAuth = RequireAnyClientCert 1727 config.InsecureSkipVerify = true 1728 config.VerifyConnection = func(c ConnectionState) error { 1729 *called++ 1730 if l := len(c.PeerCertificates); l != 1 { 1731 return fmt.Errorf("server: got len(PeerCertificates) = %d, wanted 1", l) 1732 } 1733 if c.VerifiedChains != nil { 1734 return fmt.Errorf("server: got Verified Chains %v, want nil", c.VerifiedChains) 1735 } 1736 return checkFields(c, called, "server") 1737 } 1738 }, 1739 configureClient: func(config *Config, called *int) { 1740 config.InsecureSkipVerify = true 1741 config.VerifyConnection = func(c ConnectionState) error { 1742 *called++ 1743 if l := len(c.PeerCertificates); l != 1 { 1744 return fmt.Errorf("client: got len(PeerCertificates) = %d, wanted 1", l) 1745 } 1746 if c.VerifiedChains != nil { 1747 return fmt.Errorf("server: got Verified Chains %v, want nil", c.VerifiedChains) 1748 } 1749 if c.DidResume { 1750 return nil 1751 // The SCTs and OCSP Response are dropped on resumption. 1752 // See http://golang.org/issue/39075. 1753 } 1754 if len(c.OCSPResponse) == 0 { 1755 return fmt.Errorf("client: got len(OCSPResponse) = 0, wanted non-zero") 1756 } 1757 if len(c.SignedCertificateTimestamps) == 0 { 1758 return fmt.Errorf("client: got len(SignedCertificateTimestamps) = 0, wanted non-zero") 1759 } 1760 return checkFields(c, called, "client") 1761 } 1762 }, 1763 }, 1764 { 1765 name: "NoClientCert", 1766 configureServer: func(config *Config, called *int) { 1767 config.ClientAuth = NoClientCert 1768 config.VerifyConnection = func(c ConnectionState) error { 1769 *called++ 1770 return checkFields(c, called, "server") 1771 } 1772 }, 1773 configureClient: func(config *Config, called *int) { 1774 config.VerifyConnection = func(c ConnectionState) error { 1775 *called++ 1776 return checkFields(c, called, "client") 1777 } 1778 }, 1779 }, 1780 { 1781 name: "RequestClientCert", 1782 configureServer: func(config *Config, called *int) { 1783 config.ClientAuth = RequestClientCert 1784 config.VerifyConnection = func(c ConnectionState) error { 1785 *called++ 1786 return checkFields(c, called, "server") 1787 } 1788 }, 1789 configureClient: func(config *Config, called *int) { 1790 config.Certificates = nil // clear the client cert 1791 config.VerifyConnection = func(c ConnectionState) error { 1792 *called++ 1793 if l := len(c.PeerCertificates); l != 1 { 1794 return fmt.Errorf("client: got len(PeerCertificates) = %d, wanted 1", l) 1795 } 1796 if len(c.VerifiedChains) == 0 { 1797 return fmt.Errorf("client: got len(VerifiedChains) = 0, wanted non-zero") 1798 } 1799 if c.DidResume { 1800 return nil 1801 // The SCTs and OCSP Response are dropped on resumption. 1802 // See http://golang.org/issue/39075. 1803 } 1804 if len(c.OCSPResponse) == 0 { 1805 return fmt.Errorf("client: got len(OCSPResponse) = 0, wanted non-zero") 1806 } 1807 if len(c.SignedCertificateTimestamps) == 0 { 1808 return fmt.Errorf("client: got len(SignedCertificateTimestamps) = 0, wanted non-zero") 1809 } 1810 return checkFields(c, called, "client") 1811 } 1812 }, 1813 }, 1814 } 1815 for _, test := range tests { 1816 issuer, err := x509.ParseCertificate(testRSACertificateIssuer) 1817 if err != nil { 1818 panic(err) 1819 } 1820 rootCAs := x509.NewCertPool() 1821 rootCAs.AddCert(issuer) 1822 1823 var serverCalled, clientCalled int 1824 1825 serverConfig := &Config{ 1826 MaxVersion: version, 1827 Certificates: []Certificate{testConfig.Certificates[0]}, 1828 ClientCAs: rootCAs, 1829 NextProtos: []string{"protocol1"}, 1830 } 1831 serverConfig.Certificates[0].SignedCertificateTimestamps = [][]byte{[]byte("dummy sct 1"), []byte("dummy sct 2")} 1832 serverConfig.Certificates[0].OCSPStaple = []byte("dummy ocsp") 1833 test.configureServer(serverConfig, &serverCalled) 1834 1835 clientConfig := &Config{ 1836 MaxVersion: version, 1837 ClientSessionCache: NewLRUClientSessionCache(32), 1838 RootCAs: rootCAs, 1839 ServerName: "example.golang", 1840 Certificates: []Certificate{testConfig.Certificates[0]}, 1841 NextProtos: []string{"protocol1"}, 1842 } 1843 test.configureClient(clientConfig, &clientCalled) 1844 1845 testHandshakeState := func(name string, didResume bool) { 1846 _, hs, err := testHandshake(t, clientConfig, serverConfig) 1847 if err != nil { 1848 t.Fatalf("%s: handshake failed: %s", name, err) 1849 } 1850 if hs.DidResume != didResume { 1851 t.Errorf("%s: resumed: %v, expected: %v", name, hs.DidResume, didResume) 1852 } 1853 wantCalled := 1 1854 if didResume { 1855 wantCalled = 2 // resumption would mean this is the second time it was called in this test 1856 } 1857 if clientCalled != wantCalled { 1858 t.Errorf("%s: expected client VerifyConnection called %d times, did %d times", name, wantCalled, clientCalled) 1859 } 1860 if serverCalled != wantCalled { 1861 t.Errorf("%s: expected server VerifyConnection called %d times, did %d times", name, wantCalled, serverCalled) 1862 } 1863 } 1864 testHandshakeState(fmt.Sprintf("%s-FullHandshake", test.name), false) 1865 testHandshakeState(fmt.Sprintf("%s-Resumption", test.name), true) 1866 } 1867 } 1868 1869 func TestVerifyPeerCertificate(t *testing.T) { 1870 t.Run("TLSv12", func(t *testing.T) { testVerifyPeerCertificate(t, VersionTLS12) }) 1871 t.Run("TLSv13", func(t *testing.T) { testVerifyPeerCertificate(t, VersionTLS13) }) 1872 } 1873 1874 func testVerifyPeerCertificate(t *testing.T, version uint16) { 1875 issuer, err := x509.ParseCertificate(testRSACertificateIssuer) 1876 if err != nil { 1877 panic(err) 1878 } 1879 1880 rootCAs := x509.NewCertPool() 1881 rootCAs.AddCert(issuer) 1882 1883 now := func() time.Time { return time.Unix(1476984729, 0) } 1884 1885 sentinelErr := errors.New("TestVerifyPeerCertificate") 1886 1887 verifyPeerCertificateCallback := func(called *bool, rawCerts [][]byte, validatedChains [][]*x509.Certificate) error { 1888 if l := len(rawCerts); l != 1 { 1889 return fmt.Errorf("got len(rawCerts) = %d, wanted 1", l) 1890 } 1891 if len(validatedChains) == 0 { 1892 return errors.New("got len(validatedChains) = 0, wanted non-zero") 1893 } 1894 *called = true 1895 return nil 1896 } 1897 verifyConnectionCallback := func(called *bool, isClient bool, c ConnectionState) error { 1898 if l := len(c.PeerCertificates); l != 1 { 1899 return fmt.Errorf("got len(PeerCertificates) = %d, wanted 1", l) 1900 } 1901 if len(c.VerifiedChains) == 0 { 1902 return fmt.Errorf("got len(VerifiedChains) = 0, wanted non-zero") 1903 } 1904 if isClient && len(c.OCSPResponse) == 0 { 1905 return fmt.Errorf("got len(OCSPResponse) = 0, wanted non-zero") 1906 } 1907 *called = true 1908 return nil 1909 } 1910 1911 tests := []struct { 1912 configureServer func(*Config, *bool) 1913 configureClient func(*Config, *bool) 1914 validate func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) 1915 }{ 1916 { 1917 configureServer: func(config *Config, called *bool) { 1918 config.InsecureSkipVerify = false 1919 config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error { 1920 return verifyPeerCertificateCallback(called, rawCerts, validatedChains) 1921 } 1922 }, 1923 configureClient: func(config *Config, called *bool) { 1924 config.InsecureSkipVerify = false 1925 config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error { 1926 return verifyPeerCertificateCallback(called, rawCerts, validatedChains) 1927 } 1928 }, 1929 validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) { 1930 if clientErr != nil { 1931 t.Errorf("test[%d]: client handshake failed: %v", testNo, clientErr) 1932 } 1933 if serverErr != nil { 1934 t.Errorf("test[%d]: server handshake failed: %v", testNo, serverErr) 1935 } 1936 if !clientCalled { 1937 t.Errorf("test[%d]: client did not call callback", testNo) 1938 } 1939 if !serverCalled { 1940 t.Errorf("test[%d]: server did not call callback", testNo) 1941 } 1942 }, 1943 }, 1944 { 1945 configureServer: func(config *Config, called *bool) { 1946 config.InsecureSkipVerify = false 1947 config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error { 1948 return sentinelErr 1949 } 1950 }, 1951 configureClient: func(config *Config, called *bool) { 1952 config.VerifyPeerCertificate = nil 1953 }, 1954 validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) { 1955 if serverErr != sentinelErr { 1956 t.Errorf("#%d: got server error %v, wanted sentinelErr", testNo, serverErr) 1957 } 1958 }, 1959 }, 1960 { 1961 configureServer: func(config *Config, called *bool) { 1962 config.InsecureSkipVerify = false 1963 }, 1964 configureClient: func(config *Config, called *bool) { 1965 config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error { 1966 return sentinelErr 1967 } 1968 }, 1969 validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) { 1970 if clientErr != sentinelErr { 1971 t.Errorf("#%d: got client error %v, wanted sentinelErr", testNo, clientErr) 1972 } 1973 }, 1974 }, 1975 { 1976 configureServer: func(config *Config, called *bool) { 1977 config.InsecureSkipVerify = false 1978 }, 1979 configureClient: func(config *Config, called *bool) { 1980 config.InsecureSkipVerify = true 1981 config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error { 1982 if l := len(rawCerts); l != 1 { 1983 return fmt.Errorf("got len(rawCerts) = %d, wanted 1", l) 1984 } 1985 // With InsecureSkipVerify set, this 1986 // callback should still be called but 1987 // validatedChains must be empty. 1988 if l := len(validatedChains); l != 0 { 1989 return fmt.Errorf("got len(validatedChains) = %d, wanted zero", l) 1990 } 1991 *called = true 1992 return nil 1993 } 1994 }, 1995 validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) { 1996 if clientErr != nil { 1997 t.Errorf("test[%d]: client handshake failed: %v", testNo, clientErr) 1998 } 1999 if serverErr != nil { 2000 t.Errorf("test[%d]: server handshake failed: %v", testNo, serverErr) 2001 } 2002 if !clientCalled { 2003 t.Errorf("test[%d]: client did not call callback", testNo) 2004 } 2005 }, 2006 }, 2007 { 2008 configureServer: func(config *Config, called *bool) { 2009 config.InsecureSkipVerify = false 2010 config.VerifyConnection = func(c ConnectionState) error { 2011 return verifyConnectionCallback(called, false, c) 2012 } 2013 }, 2014 configureClient: func(config *Config, called *bool) { 2015 config.InsecureSkipVerify = false 2016 config.VerifyConnection = func(c ConnectionState) error { 2017 return verifyConnectionCallback(called, true, c) 2018 } 2019 }, 2020 validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) { 2021 if clientErr != nil { 2022 t.Errorf("test[%d]: client handshake failed: %v", testNo, clientErr) 2023 } 2024 if serverErr != nil { 2025 t.Errorf("test[%d]: server handshake failed: %v", testNo, serverErr) 2026 } 2027 if !clientCalled { 2028 t.Errorf("test[%d]: client did not call callback", testNo) 2029 } 2030 if !serverCalled { 2031 t.Errorf("test[%d]: server did not call callback", testNo) 2032 } 2033 }, 2034 }, 2035 { 2036 configureServer: func(config *Config, called *bool) { 2037 config.InsecureSkipVerify = false 2038 config.VerifyConnection = func(c ConnectionState) error { 2039 return sentinelErr 2040 } 2041 }, 2042 configureClient: func(config *Config, called *bool) { 2043 config.InsecureSkipVerify = false 2044 config.VerifyConnection = nil 2045 }, 2046 validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) { 2047 if serverErr != sentinelErr { 2048 t.Errorf("#%d: got server error %v, wanted sentinelErr", testNo, serverErr) 2049 } 2050 }, 2051 }, 2052 { 2053 configureServer: func(config *Config, called *bool) { 2054 config.InsecureSkipVerify = false 2055 config.VerifyConnection = nil 2056 }, 2057 configureClient: func(config *Config, called *bool) { 2058 config.InsecureSkipVerify = false 2059 config.VerifyConnection = func(c ConnectionState) error { 2060 return sentinelErr 2061 } 2062 }, 2063 validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) { 2064 if clientErr != sentinelErr { 2065 t.Errorf("#%d: got client error %v, wanted sentinelErr", testNo, clientErr) 2066 } 2067 }, 2068 }, 2069 { 2070 configureServer: func(config *Config, called *bool) { 2071 config.InsecureSkipVerify = false 2072 config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error { 2073 return verifyPeerCertificateCallback(called, rawCerts, validatedChains) 2074 } 2075 config.VerifyConnection = func(c ConnectionState) error { 2076 return sentinelErr 2077 } 2078 }, 2079 configureClient: func(config *Config, called *bool) { 2080 config.InsecureSkipVerify = false 2081 config.VerifyPeerCertificate = nil 2082 config.VerifyConnection = nil 2083 }, 2084 validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) { 2085 if serverErr != sentinelErr { 2086 t.Errorf("#%d: got server error %v, wanted sentinelErr", testNo, serverErr) 2087 } 2088 if !serverCalled { 2089 t.Errorf("test[%d]: server did not call callback", testNo) 2090 } 2091 }, 2092 }, 2093 { 2094 configureServer: func(config *Config, called *bool) { 2095 config.InsecureSkipVerify = false 2096 config.VerifyPeerCertificate = nil 2097 config.VerifyConnection = nil 2098 }, 2099 configureClient: func(config *Config, called *bool) { 2100 config.InsecureSkipVerify = false 2101 config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error { 2102 return verifyPeerCertificateCallback(called, rawCerts, validatedChains) 2103 } 2104 config.VerifyConnection = func(c ConnectionState) error { 2105 return sentinelErr 2106 } 2107 }, 2108 validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) { 2109 if clientErr != sentinelErr { 2110 t.Errorf("#%d: got client error %v, wanted sentinelErr", testNo, clientErr) 2111 } 2112 if !clientCalled { 2113 t.Errorf("test[%d]: client did not call callback", testNo) 2114 } 2115 }, 2116 }, 2117 } 2118 2119 for i, test := range tests { 2120 c, s := localPipe(t) 2121 done := make(chan error) 2122 2123 var clientCalled, serverCalled bool 2124 2125 go func() { 2126 config := testConfig.Clone() 2127 config.ServerName = "example.golang" 2128 config.ClientAuth = RequireAndVerifyClientCert 2129 config.ClientCAs = rootCAs 2130 config.Time = now 2131 config.MaxVersion = version 2132 config.Certificates = make([]Certificate, 1) 2133 config.Certificates[0].Certificate = [][]byte{testRSACertificate} 2134 config.Certificates[0].PrivateKey = testRSAPrivateKey 2135 config.Certificates[0].SignedCertificateTimestamps = [][]byte{[]byte("dummy sct 1"), []byte("dummy sct 2")} 2136 config.Certificates[0].OCSPStaple = []byte("dummy ocsp") 2137 test.configureServer(config, &serverCalled) 2138 2139 err = Server(s, config).Handshake() 2140 err := s.Close() 2141 if err != nil { 2142 panic(err) 2143 } 2144 done <- err 2145 }() 2146 2147 config := testConfig.Clone() 2148 config.ServerName = "example.golang" 2149 config.RootCAs = rootCAs 2150 config.Time = now 2151 config.MaxVersion = version 2152 test.configureClient(config, &clientCalled) 2153 clientErr := Client(c, config).Handshake() 2154 err := c.Close() 2155 if err != nil { 2156 panic(err) 2157 } 2158 serverErr := <-done 2159 2160 test.validate(t, i, clientCalled, serverCalled, clientErr, serverErr) 2161 } 2162 } 2163 2164 // brokenConn wraps a net.Conn and causes all Writes after a certain number to 2165 // fail with brokenConnErr. 2166 type brokenConn struct { 2167 net.Conn 2168 2169 // breakAfter is the number of successful writes that will be allowed 2170 // before all subsequent writes fail. 2171 breakAfter int 2172 2173 // numWrites is the number of writes that have been done. 2174 numWrites int 2175 } 2176 2177 // errbrokenConn is the error that brokenConn returns once exhausted. 2178 var errbrokenConn = errors.New("too many writes to brokenConn") 2179 2180 func (b *brokenConn) Write(data []byte) (int, error) { 2181 if b.numWrites >= b.breakAfter { 2182 return 0, errbrokenConn 2183 } 2184 2185 b.numWrites++ 2186 return b.Conn.Write(data) 2187 } 2188 2189 func TestFailedWrite(t *testing.T) { 2190 // Test that a write error during the handshake is returned. 2191 for _, breakAfter := range []int{0, 1} { 2192 c, s := localPipe(t) 2193 done := make(chan bool) 2194 2195 go func() { 2196 err := Server(s, testConfig).Handshake() 2197 if err != nil { 2198 panic(err) 2199 } 2200 err = s.Close() 2201 if err != nil { 2202 panic(err) 2203 } 2204 done <- true 2205 }() 2206 2207 brokenC := &brokenConn{Conn: c, breakAfter: breakAfter} 2208 err := Client(brokenC, testConfig).Handshake() 2209 if err != errbrokenConn { 2210 t.Errorf("#%d: expected error from brokenConn but got %q", breakAfter, err) 2211 } 2212 err = brokenC.Close() 2213 if err != nil { 2214 panic(err) 2215 } 2216 2217 <-done 2218 } 2219 } 2220 2221 // writeCountingConn wraps a net.Conn and counts the number of Write calls. 2222 type writeCountingConn struct { 2223 net.Conn 2224 2225 // numWrites is the number of writes that have been done. 2226 numWrites int 2227 } 2228 2229 func (wcc *writeCountingConn) Write(data []byte) (int, error) { 2230 wcc.numWrites++ 2231 return wcc.Conn.Write(data) 2232 } 2233 2234 func TestBuffering(t *testing.T) { 2235 t.Run("TLSv12", func(t *testing.T) { testBuffering(t, VersionTLS12) }) 2236 t.Run("TLSv13", func(t *testing.T) { testBuffering(t, VersionTLS13) }) 2237 } 2238 2239 func testBuffering(t *testing.T, version uint16) { 2240 c, s := localPipe(t) 2241 done := make(chan bool) 2242 2243 clientWCC := &writeCountingConn{Conn: c} 2244 serverWCC := &writeCountingConn{Conn: s} 2245 2246 go func() { 2247 config := testConfig.Clone() 2248 config.MaxVersion = version 2249 err := Server(serverWCC, config).Handshake() 2250 if err != nil { 2251 panic(err) 2252 } 2253 err = serverWCC.Close() 2254 if err != nil { 2255 panic(err) 2256 } 2257 done <- true 2258 }() 2259 2260 err := Client(clientWCC, testConfig).Handshake() 2261 if err != nil { 2262 t.Fatal(err) 2263 } 2264 err = clientWCC.Close() 2265 if err != nil { 2266 panic(err) 2267 } 2268 <-done 2269 2270 var expectedClient, expectedServer int 2271 if version == VersionTLS13 { 2272 expectedClient = 2 2273 expectedServer = 1 2274 } else { 2275 expectedClient = 2 2276 expectedServer = 2 2277 } 2278 2279 if n := clientWCC.numWrites; n != expectedClient { 2280 t.Errorf("expected client handshake to complete with %d writes, but saw %d", expectedClient, n) 2281 } 2282 2283 if n := serverWCC.numWrites; n != expectedServer { 2284 t.Errorf("expected server handshake to complete with %d writes, but saw %d", expectedServer, n) 2285 } 2286 } 2287 2288 func TestAlertFlushing(t *testing.T) { 2289 c, s := localPipe(t) 2290 done := make(chan bool) 2291 2292 clientWCC := &writeCountingConn{Conn: c} 2293 serverWCC := &writeCountingConn{Conn: s} 2294 2295 serverConfig := testConfig.Clone() 2296 2297 // Cause a signature-time error 2298 brokenKey := rsa.PrivateKey{PublicKey: testRSAPrivateKey.PublicKey} 2299 brokenKey.D = big.NewInt(42) 2300 serverConfig.Certificates = []Certificate{{ 2301 Certificate: [][]byte{testRSACertificate}, 2302 PrivateKey: &brokenKey, 2303 }} 2304 2305 go func() { 2306 err := Server(serverWCC, serverConfig).Handshake() 2307 if err != nil { 2308 panic(err) 2309 } 2310 err = serverWCC.Close() 2311 if err != nil { 2312 panic(err) 2313 } 2314 done <- true 2315 }() 2316 2317 err := Client(clientWCC, testConfig).Handshake() 2318 if err == nil { 2319 t.Fatal("client unexpectedly returned no error") 2320 } 2321 2322 const expectedError = "remote error: tls: internal error" 2323 if e := err.Error(); !strings.Contains(e, expectedError) { 2324 t.Fatalf("expected to find %q in error but error was %q", expectedError, e) 2325 } 2326 err = clientWCC.Close() 2327 if err != nil { 2328 panic(err) 2329 } 2330 <-done 2331 2332 if n := serverWCC.numWrites; n != 1 { 2333 t.Errorf("expected server handshake to complete with one write, but saw %d", n) 2334 } 2335 } 2336 2337 func TestHandshakeRace(t *testing.T) { 2338 if testing.Short() { 2339 t.Skip("skipping in -short mode") 2340 } 2341 t.Parallel() 2342 // This test races a Read and Write to try and complete a handshake in 2343 // order to provide some evidence that there are no races or deadlocks 2344 // in the handshake locking. 2345 for i := 0; i < 32; i++ { 2346 c, s := localPipe(t) 2347 2348 go func() { 2349 server := Server(s, testConfig) 2350 if err := server.Handshake(); err != nil { 2351 panic(err) 2352 } 2353 2354 var request [1]byte 2355 if n, err := server.Read(request[:]); err != nil || n != 1 { 2356 panic(err) 2357 } 2358 2359 _, err := server.Write(request[:]) 2360 if err != nil { 2361 panic(err) 2362 } 2363 err = server.Close() 2364 if err != nil { 2365 panic(err) 2366 } 2367 }() 2368 2369 startWrite := make(chan struct{}) 2370 startRead := make(chan struct{}) 2371 readDone := make(chan struct{}, 1) 2372 2373 client := Client(c, testConfig) 2374 go func() { 2375 <-startWrite 2376 var request [1]byte 2377 _, err := client.Write(request[:]) 2378 if err != nil { 2379 panic(err) 2380 } 2381 }() 2382 2383 go func() { 2384 <-startRead 2385 var reply [1]byte 2386 if _, err := io.ReadFull(client, reply[:]); err != nil { 2387 panic(err) 2388 } 2389 err := c.Close() 2390 if err != nil { 2391 panic(err) 2392 } 2393 readDone <- struct{}{} 2394 }() 2395 2396 if i&1 == 1 { 2397 startWrite <- struct{}{} 2398 startRead <- struct{}{} 2399 } else { 2400 startRead <- struct{}{} 2401 startWrite <- struct{}{} 2402 } 2403 <-readDone 2404 } 2405 } 2406 2407 var getClientCertificateTests = []struct { 2408 setup func(*Config, *Config) 2409 expectedClientError string 2410 verify func(*testing.T, int, *ConnectionState) 2411 }{ 2412 { 2413 func(clientConfig, serverConfig *Config) { 2414 // Returning a Certificate with no certificate data 2415 // should result in an empty message being sent to the 2416 // server. 2417 serverConfig.ClientCAs = nil 2418 clientConfig.GetClientCertificate = func(cri *CertificateRequestInfo) (*Certificate, error) { 2419 if len(cri.SignatureSchemes) == 0 { 2420 panic("empty SignatureSchemes") 2421 } 2422 if len(cri.AcceptableCAs) != 0 { 2423 panic("AcceptableCAs should have been empty") 2424 } 2425 return new(Certificate), nil 2426 } 2427 }, 2428 "", 2429 func(t *testing.T, testNum int, cs *ConnectionState) { 2430 if l := len(cs.PeerCertificates); l != 0 { 2431 t.Errorf("#%d: expected no certificates but got %d", testNum, l) 2432 } 2433 }, 2434 }, 2435 { 2436 func(clientConfig, serverConfig *Config) { 2437 // With TLS 1.1, the SignatureSchemes should be 2438 // synthesised from the supported certificate types. 2439 clientConfig.MaxVersion = VersionTLS11 2440 clientConfig.GetClientCertificate = func(cri *CertificateRequestInfo) (*Certificate, error) { 2441 if len(cri.SignatureSchemes) == 0 { 2442 panic("empty SignatureSchemes") 2443 } 2444 return new(Certificate), nil 2445 } 2446 }, 2447 "", 2448 func(t *testing.T, testNum int, cs *ConnectionState) { 2449 if l := len(cs.PeerCertificates); l != 0 { 2450 t.Errorf("#%d: expected no certificates but got %d", testNum, l) 2451 } 2452 }, 2453 }, 2454 { 2455 func(clientConfig, serverConfig *Config) { 2456 // Returning an error should abort the handshake with 2457 // that error. 2458 clientConfig.GetClientCertificate = func(cri *CertificateRequestInfo) (*Certificate, error) { 2459 return nil, errors.New("GetClientCertificate") 2460 } 2461 }, 2462 "GetClientCertificate", 2463 func(t *testing.T, testNum int, cs *ConnectionState) { 2464 }, 2465 }, 2466 { 2467 func(clientConfig, serverConfig *Config) { 2468 clientConfig.GetClientCertificate = func(cri *CertificateRequestInfo) (*Certificate, error) { 2469 if len(cri.AcceptableCAs) == 0 { 2470 panic("empty AcceptableCAs") 2471 } 2472 cert := &Certificate{ 2473 Certificate: [][]byte{testRSACertificate}, 2474 PrivateKey: testRSAPrivateKey, 2475 } 2476 return cert, nil 2477 } 2478 }, 2479 "", 2480 func(t *testing.T, testNum int, cs *ConnectionState) { 2481 if len(cs.VerifiedChains) == 0 { 2482 t.Errorf("#%d: expected some verified chains, but found none", testNum) 2483 } 2484 }, 2485 }, 2486 } 2487 2488 func TestGetClientCertificate(t *testing.T) { 2489 t.Run("TLSv12", func(t *testing.T) { testGetClientCertificate(t, VersionTLS12) }) 2490 t.Run("TLSv13", func(t *testing.T) { testGetClientCertificate(t, VersionTLS13) }) 2491 } 2492 2493 func testGetClientCertificate(t *testing.T, version uint16) { 2494 issuer, err := x509.ParseCertificate(testRSACertificateIssuer) 2495 if err != nil { 2496 panic(err) 2497 } 2498 2499 for i, test := range getClientCertificateTests { 2500 serverConfig := testConfig.Clone() 2501 serverConfig.ClientAuth = VerifyClientCertIfGiven 2502 serverConfig.RootCAs = x509.NewCertPool() 2503 serverConfig.RootCAs.AddCert(issuer) 2504 serverConfig.ClientCAs = serverConfig.RootCAs 2505 serverConfig.Time = func() time.Time { return time.Unix(1476984729, 0) } 2506 serverConfig.MaxVersion = version 2507 2508 clientConfig := testConfig.Clone() 2509 clientConfig.MaxVersion = version 2510 2511 test.setup(clientConfig, serverConfig) 2512 2513 type serverResult struct { 2514 cs ConnectionState 2515 err error 2516 } 2517 2518 c, s := localPipe(t) 2519 done := make(chan serverResult) 2520 2521 go func() { 2522 defer func(s net.Conn) { 2523 err := s.Close() 2524 if err != nil { 2525 panic(err) 2526 } 2527 }(s) 2528 server := Server(s, serverConfig) 2529 err := server.Handshake() 2530 2531 var cs ConnectionState 2532 if err == nil { 2533 cs = server.ConnectionState() 2534 } 2535 done <- serverResult{cs, err} 2536 }() 2537 2538 clientErr := Client(c, clientConfig).Handshake() 2539 err := c.Close() 2540 if err != nil { 2541 panic(err) 2542 } 2543 2544 result := <-done 2545 2546 if clientErr != nil { 2547 if len(test.expectedClientError) == 0 { 2548 t.Errorf("#%d: client error: %v", i, clientErr) 2549 } else if got := clientErr.Error(); got != test.expectedClientError { 2550 t.Errorf("#%d: expected client error %q, but got %q", i, test.expectedClientError, got) 2551 } else { 2552 test.verify(t, i, &result.cs) 2553 } 2554 } else if len(test.expectedClientError) > 0 { 2555 t.Errorf("#%d: expected client error %q, but got no error", i, test.expectedClientError) 2556 } else if err := result.err; err != nil { 2557 t.Errorf("#%d: server error: %v", i, err) 2558 } else { 2559 test.verify(t, i, &result.cs) 2560 } 2561 } 2562 } 2563 2564 func TestRSAPSSKeyError(t *testing.T) { 2565 // crypto/tls does not support the rsa_pss_pss_* SignatureSchemes. If support for 2566 // public keys with OID RSASSA-PSS is added to crypto/x509, they will be misused with 2567 // the rsa_pss_rsae_* SignatureSchemes. Assert that RSASSA-PSS certificates don't 2568 // parse, or that they don't carry *rsa.PublicKey keys. 2569 b, _ := pem.Decode([]byte(` 2570 -----BEGIN CERTIFICATE----- 2571 MIIDZTCCAhygAwIBAgIUCF2x0FyTgZG0CC9QTDjGWkB5vgEwPgYJKoZIhvcNAQEK 2572 MDGgDTALBglghkgBZQMEAgGhGjAYBgkqhkiG9w0BAQgwCwYJYIZIAWUDBAIBogQC 2573 AgDeMBIxEDAOBgNVBAMMB1JTQS1QU1MwHhcNMTgwNjI3MjI0NDM2WhcNMTgwNzI3 2574 MjI0NDM2WjASMRAwDgYDVQQDDAdSU0EtUFNTMIIBIDALBgkqhkiG9w0BAQoDggEP 2575 ADCCAQoCggEBANxDm0f76JdI06YzsjB3AmmjIYkwUEGxePlafmIASFjDZl/elD0Z 2576 /a7xLX468b0qGxLS5al7XCcEprSdsDR6DF5L520+pCbpfLyPOjuOvGmk9KzVX4x5 2577 b05YXYuXdsQ0Kjxcx2i3jjCday6scIhMJVgBZxTEyMj1thPQM14SHzKCd/m6HmCL 2578 QmswpH2yMAAcBRWzRpp/vdH5DeOJEB3aelq7094no731mrLUCHRiZ1htq8BDB3ou 2579 czwqgwspbqZ4dnMXl2MvfySQ5wJUxQwILbiuAKO2lVVPUbFXHE9pgtznNoPvKwQT 2580 JNcX8ee8WIZc2SEGzofjk3NpjR+2ADB2u3sCAwEAAaNTMFEwHQYDVR0OBBYEFNEz 2581 AdyJ2f+fU+vSCS6QzohnOnprMB8GA1UdIwQYMBaAFNEzAdyJ2f+fU+vSCS6Qzohn 2582 OnprMA8GA1UdEwEB/wQFMAMBAf8wPgYJKoZIhvcNAQEKMDGgDTALBglghkgBZQME 2583 AgGhGjAYBgkqhkiG9w0BAQgwCwYJYIZIAWUDBAIBogQCAgDeA4IBAQCjEdrR5aab 2584 sZmCwrMeKidXgfkmWvfuLDE+TCbaqDZp7BMWcMQXT9O0UoUT5kqgKj2ARm2pEW0Z 2585 H3Z1vj3bbds72qcDIJXp+l0fekyLGeCrX/CbgnMZXEP7+/+P416p34ChR1Wz4dU1 2586 KD3gdsUuTKKeMUog3plxlxQDhRQmiL25ygH1LmjLd6dtIt0GVRGr8lj3euVeprqZ 2587 bZ3Uq5eLfsn8oPgfC57gpO6yiN+UURRTlK3bgYvLh4VWB3XXk9UaQZ7Mq1tpXjoD 2588 HYFybkWzibkZp4WRo+Fa28rirH+/wHt0vfeN7UCceURZEx4JaxIIfe4ku7uDRhJi 2589 RwBA9Xk1KBNF 2590 -----END CERTIFICATE-----`)) 2591 if b == nil { 2592 t.Fatal("Failed to decode certificate") 2593 } 2594 cert, err := x509.ParseCertificate(b.Bytes) 2595 if err != nil { 2596 return 2597 } 2598 if _, ok := cert.PublicKey.(*rsa.PublicKey); ok { 2599 t.Error("A RSASSA-PSS certificate was parsed like a PKCS#1 v1.5 one, and it will be mistakenly used with rsa_pss_rsae_* signature algorithms") 2600 } 2601 } 2602 2603 func TestCloseClientConnectionOnIdleServer(t *testing.T) { 2604 clientConn, serverConn := localPipe(t) 2605 client := Client(clientConn, testConfig.Clone()) 2606 go func() { 2607 var b [1]byte 2608 _, err := serverConn.Read(b[:]) 2609 if err != nil { 2610 panic(err) 2611 } 2612 err = client.Close() 2613 if err != nil { 2614 panic(err) 2615 } 2616 }() 2617 err := client.SetWriteDeadline(time.Now().Add(time.Minute)) 2618 if err != nil { 2619 panic(err) 2620 } 2621 err = client.Handshake() 2622 if err != nil { 2623 if err, ok := err.(net.Error); ok && err.Timeout() { 2624 t.Errorf("Expected a closed network connection error but got '%s'", err.Error()) 2625 } 2626 } else { 2627 t.Errorf("Error expected, but no error returned") 2628 } 2629 } 2630 2631 func testDowngradeCanary(t *testing.T, clientVersion, serverVersion uint16) error { 2632 defer func() { testingOnlyForceDowngradeCanary = false }() 2633 testingOnlyForceDowngradeCanary = true 2634 2635 clientConfig := testConfig.Clone() 2636 clientConfig.MaxVersion = clientVersion 2637 serverConfig := testConfig.Clone() 2638 serverConfig.MaxVersion = serverVersion 2639 _, _, err := testHandshake(t, clientConfig, serverConfig) 2640 return err 2641 } 2642 2643 func TestDowngradeCanary(t *testing.T) { 2644 if err := testDowngradeCanary(t, VersionTLS13, VersionTLS12); err == nil { 2645 t.Errorf("downgrade from TLS 1.3 to TLS 1.2 was not detected") 2646 } 2647 if testing.Short() { 2648 t.Skip("skipping the rest of the checks in short mode") 2649 } 2650 if err := testDowngradeCanary(t, VersionTLS13, VersionTLS11); err == nil { 2651 t.Errorf("downgrade from TLS 1.3 to TLS 1.1 was not detected") 2652 } 2653 if err := testDowngradeCanary(t, VersionTLS13, VersionTLS10); err == nil { 2654 t.Errorf("downgrade from TLS 1.3 to TLS 1.0 was not detected") 2655 } 2656 if err := testDowngradeCanary(t, VersionTLS12, VersionTLS11); err == nil { 2657 t.Errorf("downgrade from TLS 1.2 to TLS 1.1 was not detected") 2658 } 2659 if err := testDowngradeCanary(t, VersionTLS12, VersionTLS10); err == nil { 2660 t.Errorf("downgrade from TLS 1.2 to TLS 1.0 was not detected") 2661 } 2662 if err := testDowngradeCanary(t, VersionTLS13, VersionTLS13); err != nil { 2663 t.Errorf("server unexpectedly sent downgrade canary for TLS 1.3") 2664 } 2665 if err := testDowngradeCanary(t, VersionTLS12, VersionTLS12); err != nil { 2666 t.Errorf("client didn't ignore expected TLS 1.2 canary") 2667 } 2668 if err := testDowngradeCanary(t, VersionTLS11, VersionTLS11); err != nil { 2669 t.Errorf("client unexpectedly reacted to a canary in TLS 1.1") 2670 } 2671 if err := testDowngradeCanary(t, VersionTLS10, VersionTLS10); err != nil { 2672 t.Errorf("client unexpectedly reacted to a canary in TLS 1.0") 2673 } 2674 } 2675 2676 func TestResumptionKeepsOCSPAndSCT(t *testing.T) { 2677 t.Run("TLSv12", func(t *testing.T) { testResumptionKeepsOCSPAndSCT(t, VersionTLS12) }) 2678 t.Run("TLSv13", func(t *testing.T) { testResumptionKeepsOCSPAndSCT(t, VersionTLS13) }) 2679 } 2680 2681 func testResumptionKeepsOCSPAndSCT(t *testing.T, ver uint16) { 2682 issuer, err := x509.ParseCertificate(testRSACertificateIssuer) 2683 if err != nil { 2684 t.Fatalf("failed to parse test issuer") 2685 } 2686 roots := x509.NewCertPool() 2687 roots.AddCert(issuer) 2688 clientConfig := &Config{ 2689 MaxVersion: ver, 2690 ClientSessionCache: NewLRUClientSessionCache(32), 2691 ServerName: "example.golang", 2692 RootCAs: roots, 2693 } 2694 serverConfig := testConfig.Clone() 2695 serverConfig.MaxVersion = ver 2696 serverConfig.Certificates[0].OCSPStaple = []byte{1, 2, 3} 2697 serverConfig.Certificates[0].SignedCertificateTimestamps = [][]byte{{4, 5, 6}} 2698 2699 _, ccs, err := testHandshake(t, clientConfig, serverConfig) 2700 if err != nil { 2701 t.Fatalf("handshake failed: %s", err) 2702 } 2703 // after a new session we expect to see OCSPResponse and 2704 // SignedCertificateTimestamps populated as usual 2705 if !bytes.Equal(ccs.OCSPResponse, serverConfig.Certificates[0].OCSPStaple) { 2706 t.Errorf("client ConnectionState contained unexpected OCSPResponse: wanted %v, got %v", 2707 serverConfig.Certificates[0].OCSPStaple, ccs.OCSPResponse) 2708 } 2709 if !reflect.DeepEqual(ccs.SignedCertificateTimestamps, serverConfig.Certificates[0].SignedCertificateTimestamps) { 2710 t.Errorf("client ConnectionState contained unexpected SignedCertificateTimestamps: wanted %v, got %v", 2711 serverConfig.Certificates[0].SignedCertificateTimestamps, ccs.SignedCertificateTimestamps) 2712 } 2713 2714 // if the server doesn't send any SCTs, repopulate the old SCTs 2715 oldSCTs := serverConfig.Certificates[0].SignedCertificateTimestamps 2716 serverConfig.Certificates[0].SignedCertificateTimestamps = nil 2717 _, ccs, err = testHandshake(t, clientConfig, serverConfig) 2718 if err != nil { 2719 t.Fatalf("handshake failed: %s", err) 2720 } 2721 if !ccs.DidResume { 2722 t.Fatalf("expected session to be resumed") 2723 } 2724 // after a resumed session we also expect to see OCSPResponse 2725 // and SignedCertificateTimestamps populated 2726 if !bytes.Equal(ccs.OCSPResponse, serverConfig.Certificates[0].OCSPStaple) { 2727 t.Errorf("client ConnectionState contained unexpected OCSPResponse after resumption: wanted %v, got %v", 2728 serverConfig.Certificates[0].OCSPStaple, ccs.OCSPResponse) 2729 } 2730 if !reflect.DeepEqual(ccs.SignedCertificateTimestamps, oldSCTs) { 2731 t.Errorf("client ConnectionState contained unexpected SignedCertificateTimestamps after resumption: wanted %v, got %v", 2732 oldSCTs, ccs.SignedCertificateTimestamps) 2733 } 2734 2735 // Only test overriding the SCTs for TLS 1.2, since in 1.3 2736 // the server won't send the message containing them 2737 if ver == VersionTLS13 { 2738 return 2739 } 2740 2741 // if the server changes the SCTs it sends, they should override the saved SCTs 2742 serverConfig.Certificates[0].SignedCertificateTimestamps = [][]byte{{7, 8, 9}} 2743 _, ccs, err = testHandshake(t, clientConfig, serverConfig) 2744 if err != nil { 2745 t.Fatalf("handshake failed: %s", err) 2746 } 2747 if !ccs.DidResume { 2748 t.Fatalf("expected session to be resumed") 2749 } 2750 if !reflect.DeepEqual(ccs.SignedCertificateTimestamps, serverConfig.Certificates[0].SignedCertificateTimestamps) { 2751 t.Errorf("client ConnectionState contained unexpected SignedCertificateTimestamps after resumption: wanted %v, got %v", 2752 serverConfig.Certificates[0].SignedCertificateTimestamps, ccs.SignedCertificateTimestamps) 2753 } 2754 } 2755 2756 // TestClientHandshakeContextCancellation tests that cancelling 2757 // the context given to the client side conn.HandshakeContext 2758 // interrupts the in-progress handshake. 2759 func TestClientHandshakeContextCancellation(t *testing.T) { 2760 c, s := localPipe(t) 2761 ctx, cancel := context.WithCancel(context.Background()) 2762 unblockServer := make(chan struct{}) 2763 defer close(unblockServer) 2764 go func() { 2765 cancel() 2766 <-unblockServer 2767 _ = s.Close() 2768 }() 2769 cli := Client(c, testConfig) 2770 // Initiates client side handshake, which will block until the client hello is read 2771 // by the server, unless the cancellation works. 2772 err := cli.HandshakeContext(ctx) 2773 if err == nil { 2774 t.Fatal("Client handshake did not error when the context was canceled") 2775 } 2776 if err != context.Canceled { 2777 t.Errorf("Unexpected client handshake error: %v", err) 2778 } 2779 if runtime.GOARCH == "wasm" { 2780 t.Skip("conn.Close does not error as expected when called multiple times on WASM") 2781 } 2782 err = cli.Close() 2783 if err == nil { 2784 t.Error("Client connection was not closed when the context was canceled") 2785 } 2786 }