github.com/sbinet/go@v0.0.0-20160827155028-54d7de7dd62b/src/crypto/tls/handshake_client_test.go (about) 1 // Copyright 2010 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package tls 6 7 import ( 8 "bytes" 9 "crypto/ecdsa" 10 "crypto/rsa" 11 "crypto/x509" 12 "encoding/base64" 13 "encoding/binary" 14 "encoding/pem" 15 "errors" 16 "fmt" 17 "io" 18 "net" 19 "os" 20 "os/exec" 21 "path/filepath" 22 "strconv" 23 "strings" 24 "testing" 25 "time" 26 ) 27 28 // Note: see comment in handshake_test.go for details of how the reference 29 // tests work. 30 31 // opensslInputEvent enumerates possible inputs that can be sent to an `openssl 32 // s_client` process. 33 type opensslInputEvent int 34 35 const ( 36 // opensslRenegotiate causes OpenSSL to request a renegotiation of the 37 // connection. 38 opensslRenegotiate opensslInputEvent = iota 39 40 // opensslSendBanner causes OpenSSL to send the contents of 41 // opensslSentinel on the connection. 42 opensslSendSentinel 43 ) 44 45 const opensslSentinel = "SENTINEL\n" 46 47 type opensslInput chan opensslInputEvent 48 49 func (i opensslInput) Read(buf []byte) (n int, err error) { 50 for event := range i { 51 switch event { 52 case opensslRenegotiate: 53 return copy(buf, []byte("R\n")), nil 54 case opensslSendSentinel: 55 return copy(buf, []byte(opensslSentinel)), nil 56 default: 57 panic("unknown event") 58 } 59 } 60 61 return 0, io.EOF 62 } 63 64 // opensslOutputSink is an io.Writer that receives the stdout and stderr from 65 // an `openssl` process and sends a value to handshakeComplete when it sees a 66 // log message from a completed server handshake. 67 type opensslOutputSink struct { 68 handshakeComplete chan struct{} 69 all []byte 70 line []byte 71 } 72 73 func newOpensslOutputSink() *opensslOutputSink { 74 return &opensslOutputSink{make(chan struct{}), nil, nil} 75 } 76 77 // opensslEndOfHandshake is a message that the “openssl s_server” tool will 78 // print when a handshake completes if run with “-state”. 79 const opensslEndOfHandshake = "SSL_accept:SSLv3 write finished A" 80 81 func (o *opensslOutputSink) Write(data []byte) (n int, err error) { 82 o.line = append(o.line, data...) 83 o.all = append(o.all, data...) 84 85 for { 86 i := bytes.Index(o.line, []byte{'\n'}) 87 if i < 0 { 88 break 89 } 90 91 if bytes.Equal([]byte(opensslEndOfHandshake), o.line[:i]) { 92 o.handshakeComplete <- struct{}{} 93 } 94 o.line = o.line[i+1:] 95 } 96 97 return len(data), nil 98 } 99 100 func (o *opensslOutputSink) WriteTo(w io.Writer) (int64, error) { 101 n, err := w.Write(o.all) 102 return int64(n), err 103 } 104 105 // clientTest represents a test of the TLS client handshake against a reference 106 // implementation. 107 type clientTest struct { 108 // name is a freeform string identifying the test and the file in which 109 // the expected results will be stored. 110 name string 111 // command, if not empty, contains a series of arguments for the 112 // command to run for the reference server. 113 command []string 114 // config, if not nil, contains a custom Config to use for this test. 115 config *Config 116 // cert, if not empty, contains a DER-encoded certificate for the 117 // reference server. 118 cert []byte 119 // key, if not nil, contains either a *rsa.PrivateKey or 120 // *ecdsa.PrivateKey which is the private key for the reference server. 121 key interface{} 122 // extensions, if not nil, contains a list of extension data to be returned 123 // from the ServerHello. The data should be in standard TLS format with 124 // a 2-byte uint16 type, 2-byte data length, followed by the extension data. 125 extensions [][]byte 126 // validate, if not nil, is a function that will be called with the 127 // ConnectionState of the resulting connection. It returns a non-nil 128 // error if the ConnectionState is unacceptable. 129 validate func(ConnectionState) error 130 // numRenegotiations is the number of times that the connection will be 131 // renegotiated. 132 numRenegotiations int 133 // renegotiationExpectedToFail, if not zero, is the number of the 134 // renegotiation attempt that is expected to fail. 135 renegotiationExpectedToFail int 136 // checkRenegotiationError, if not nil, is called with any error 137 // arising from renegotiation. It can map expected errors to nil to 138 // ignore them. 139 checkRenegotiationError func(renegotiationNum int, err error) error 140 } 141 142 var defaultServerCommand = []string{"openssl", "s_server"} 143 144 // connFromCommand starts the reference server process, connects to it and 145 // returns a recordingConn for the connection. The stdin return value is an 146 // opensslInput for the stdin of the child process. It must be closed before 147 // Waiting for child. 148 func (test *clientTest) connFromCommand() (conn *recordingConn, child *exec.Cmd, stdin opensslInput, stdout *opensslOutputSink, err error) { 149 cert := testRSACertificate 150 if len(test.cert) > 0 { 151 cert = test.cert 152 } 153 certPath := tempFile(string(cert)) 154 defer os.Remove(certPath) 155 156 var key interface{} = testRSAPrivateKey 157 if test.key != nil { 158 key = test.key 159 } 160 var pemType string 161 var derBytes []byte 162 switch key := key.(type) { 163 case *rsa.PrivateKey: 164 pemType = "RSA" 165 derBytes = x509.MarshalPKCS1PrivateKey(key) 166 case *ecdsa.PrivateKey: 167 pemType = "EC" 168 var err error 169 derBytes, err = x509.MarshalECPrivateKey(key) 170 if err != nil { 171 panic(err) 172 } 173 default: 174 panic("unknown key type") 175 } 176 177 var pemOut bytes.Buffer 178 pem.Encode(&pemOut, &pem.Block{Type: pemType + " PRIVATE KEY", Bytes: derBytes}) 179 180 keyPath := tempFile(string(pemOut.Bytes())) 181 defer os.Remove(keyPath) 182 183 var command []string 184 if len(test.command) > 0 { 185 command = append(command, test.command...) 186 } else { 187 command = append(command, defaultServerCommand...) 188 } 189 command = append(command, "-cert", certPath, "-certform", "DER", "-key", keyPath) 190 // serverPort contains the port that OpenSSL will listen on. OpenSSL 191 // can't take "0" as an argument here so we have to pick a number and 192 // hope that it's not in use on the machine. Since this only occurs 193 // when -update is given and thus when there's a human watching the 194 // test, this isn't too bad. 195 const serverPort = 24323 196 command = append(command, "-accept", strconv.Itoa(serverPort)) 197 198 if len(test.extensions) > 0 { 199 var serverInfo bytes.Buffer 200 for _, ext := range test.extensions { 201 pem.Encode(&serverInfo, &pem.Block{ 202 Type: fmt.Sprintf("SERVERINFO FOR EXTENSION %d", binary.BigEndian.Uint16(ext)), 203 Bytes: ext, 204 }) 205 } 206 serverInfoPath := tempFile(serverInfo.String()) 207 defer os.Remove(serverInfoPath) 208 command = append(command, "-serverinfo", serverInfoPath) 209 } 210 211 if test.numRenegotiations > 0 { 212 found := false 213 for _, flag := range command[1:] { 214 if flag == "-state" { 215 found = true 216 break 217 } 218 } 219 220 if !found { 221 panic("-state flag missing to OpenSSL. You need this if testing renegotiation") 222 } 223 } 224 225 cmd := exec.Command(command[0], command[1:]...) 226 stdin = opensslInput(make(chan opensslInputEvent)) 227 cmd.Stdin = stdin 228 out := newOpensslOutputSink() 229 cmd.Stdout = out 230 cmd.Stderr = out 231 if err := cmd.Start(); err != nil { 232 return nil, nil, nil, nil, err 233 } 234 235 // OpenSSL does print an "ACCEPT" banner, but it does so *before* 236 // opening the listening socket, so we can't use that to wait until it 237 // has started listening. Thus we are forced to poll until we get a 238 // connection. 239 var tcpConn net.Conn 240 for i := uint(0); i < 5; i++ { 241 tcpConn, err = net.DialTCP("tcp", nil, &net.TCPAddr{ 242 IP: net.IPv4(127, 0, 0, 1), 243 Port: serverPort, 244 }) 245 if err == nil { 246 break 247 } 248 time.Sleep((1 << i) * 5 * time.Millisecond) 249 } 250 if err != nil { 251 close(stdin) 252 out.WriteTo(os.Stdout) 253 cmd.Process.Kill() 254 return nil, nil, nil, nil, cmd.Wait() 255 } 256 257 record := &recordingConn{ 258 Conn: tcpConn, 259 } 260 261 return record, cmd, stdin, out, nil 262 } 263 264 func (test *clientTest) dataPath() string { 265 return filepath.Join("testdata", "Client-"+test.name) 266 } 267 268 func (test *clientTest) loadData() (flows [][]byte, err error) { 269 in, err := os.Open(test.dataPath()) 270 if err != nil { 271 return nil, err 272 } 273 defer in.Close() 274 return parseTestData(in) 275 } 276 277 func (test *clientTest) run(t *testing.T, write bool) { 278 var clientConn, serverConn net.Conn 279 var recordingConn *recordingConn 280 var childProcess *exec.Cmd 281 var stdin opensslInput 282 var stdout *opensslOutputSink 283 284 if write { 285 var err error 286 recordingConn, childProcess, stdin, stdout, err = test.connFromCommand() 287 if err != nil { 288 t.Fatalf("Failed to start subcommand: %s", err) 289 } 290 clientConn = recordingConn 291 } else { 292 clientConn, serverConn = net.Pipe() 293 } 294 295 config := test.config 296 if config == nil { 297 config = testConfig 298 } 299 client := Client(clientConn, config) 300 301 doneChan := make(chan bool) 302 go func() { 303 defer func() { doneChan <- true }() 304 defer clientConn.Close() 305 defer client.Close() 306 307 if _, err := client.Write([]byte("hello\n")); err != nil { 308 t.Errorf("Client.Write failed: %s", err) 309 return 310 } 311 312 for i := 1; i <= test.numRenegotiations; i++ { 313 // The initial handshake will generate a 314 // handshakeComplete signal which needs to be quashed. 315 if i == 1 && write { 316 <-stdout.handshakeComplete 317 } 318 319 // OpenSSL will try to interleave application data and 320 // a renegotiation if we send both concurrently. 321 // Therefore: ask OpensSSL to start a renegotiation, run 322 // a goroutine to call client.Read and thus process the 323 // renegotiation request, watch for OpenSSL's stdout to 324 // indicate that the handshake is complete and, 325 // finally, have OpenSSL write something to cause 326 // client.Read to complete. 327 if write { 328 stdin <- opensslRenegotiate 329 } 330 331 signalChan := make(chan struct{}) 332 333 go func() { 334 defer func() { signalChan <- struct{}{} }() 335 336 buf := make([]byte, 256) 337 n, err := client.Read(buf) 338 339 if test.checkRenegotiationError != nil { 340 newErr := test.checkRenegotiationError(i, err) 341 if err != nil && newErr == nil { 342 return 343 } 344 err = newErr 345 } 346 347 if err != nil { 348 t.Errorf("Client.Read failed after renegotiation #%d: %s", i, err) 349 return 350 } 351 352 buf = buf[:n] 353 if !bytes.Equal([]byte(opensslSentinel), buf) { 354 t.Errorf("Client.Read returned %q, but wanted %q", string(buf), opensslSentinel) 355 } 356 357 if expected := i + 1; client.handshakes != expected { 358 t.Errorf("client should have recorded %d handshakes, but believes that %d have occured", expected, client.handshakes) 359 } 360 }() 361 362 if write && test.renegotiationExpectedToFail != i { 363 <-stdout.handshakeComplete 364 stdin <- opensslSendSentinel 365 } 366 <-signalChan 367 } 368 369 if test.validate != nil { 370 if err := test.validate(client.ConnectionState()); err != nil { 371 t.Errorf("validate callback returned error: %s", err) 372 } 373 } 374 }() 375 376 if !write { 377 flows, err := test.loadData() 378 if err != nil { 379 t.Fatalf("%s: failed to load data from %s: %v", test.name, test.dataPath(), err) 380 } 381 for i, b := range flows { 382 if i%2 == 1 { 383 serverConn.Write(b) 384 continue 385 } 386 bb := make([]byte, len(b)) 387 _, err := io.ReadFull(serverConn, bb) 388 if err != nil { 389 t.Fatalf("%s #%d: %s", test.name, i, err) 390 } 391 if !bytes.Equal(b, bb) { 392 t.Fatalf("%s #%d: mismatch on read: got:%x want:%x", test.name, i, bb, b) 393 } 394 } 395 serverConn.Close() 396 } 397 398 <-doneChan 399 400 if write { 401 path := test.dataPath() 402 out, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644) 403 if err != nil { 404 t.Fatalf("Failed to create output file: %s", err) 405 } 406 defer out.Close() 407 recordingConn.Close() 408 close(stdin) 409 childProcess.Process.Kill() 410 childProcess.Wait() 411 if len(recordingConn.flows) < 3 { 412 childProcess.Stdout.(*bytes.Buffer).WriteTo(os.Stdout) 413 t.Fatalf("Client connection didn't work") 414 } 415 recordingConn.WriteTo(out) 416 fmt.Printf("Wrote %s\n", path) 417 } 418 } 419 420 func runClientTestForVersion(t *testing.T, template *clientTest, prefix, option string) { 421 test := *template 422 test.name = prefix + test.name 423 if len(test.command) == 0 { 424 test.command = defaultClientCommand 425 } 426 test.command = append([]string(nil), test.command...) 427 test.command = append(test.command, option) 428 test.run(t, *update) 429 } 430 431 func runClientTestTLS10(t *testing.T, template *clientTest) { 432 runClientTestForVersion(t, template, "TLSv10-", "-tls1") 433 } 434 435 func runClientTestTLS11(t *testing.T, template *clientTest) { 436 runClientTestForVersion(t, template, "TLSv11-", "-tls1_1") 437 } 438 439 func runClientTestTLS12(t *testing.T, template *clientTest) { 440 runClientTestForVersion(t, template, "TLSv12-", "-tls1_2") 441 } 442 443 func TestHandshakeClientRSARC4(t *testing.T) { 444 test := &clientTest{ 445 name: "RSA-RC4", 446 command: []string{"openssl", "s_server", "-cipher", "RC4-SHA"}, 447 } 448 runClientTestTLS10(t, test) 449 runClientTestTLS11(t, test) 450 runClientTestTLS12(t, test) 451 } 452 453 func TestHandshakeClientRSAAES128GCM(t *testing.T) { 454 test := &clientTest{ 455 name: "AES128-GCM-SHA256", 456 command: []string{"openssl", "s_server", "-cipher", "AES128-GCM-SHA256"}, 457 } 458 runClientTestTLS12(t, test) 459 } 460 461 func TestHandshakeClientRSAAES256GCM(t *testing.T) { 462 test := &clientTest{ 463 name: "AES256-GCM-SHA384", 464 command: []string{"openssl", "s_server", "-cipher", "AES256-GCM-SHA384"}, 465 } 466 runClientTestTLS12(t, test) 467 } 468 469 func TestHandshakeClientECDHERSAAES(t *testing.T) { 470 test := &clientTest{ 471 name: "ECDHE-RSA-AES", 472 command: []string{"openssl", "s_server", "-cipher", "ECDHE-RSA-AES128-SHA"}, 473 } 474 runClientTestTLS10(t, test) 475 runClientTestTLS11(t, test) 476 runClientTestTLS12(t, test) 477 } 478 479 func TestHandshakeClientECDHEECDSAAES(t *testing.T) { 480 test := &clientTest{ 481 name: "ECDHE-ECDSA-AES", 482 command: []string{"openssl", "s_server", "-cipher", "ECDHE-ECDSA-AES128-SHA"}, 483 cert: testECDSACertificate, 484 key: testECDSAPrivateKey, 485 } 486 runClientTestTLS10(t, test) 487 runClientTestTLS11(t, test) 488 runClientTestTLS12(t, test) 489 } 490 491 func TestHandshakeClientECDHEECDSAAESGCM(t *testing.T) { 492 test := &clientTest{ 493 name: "ECDHE-ECDSA-AES-GCM", 494 command: []string{"openssl", "s_server", "-cipher", "ECDHE-ECDSA-AES128-GCM-SHA256"}, 495 cert: testECDSACertificate, 496 key: testECDSAPrivateKey, 497 } 498 runClientTestTLS12(t, test) 499 } 500 501 func TestHandshakeClientAES256GCMSHA384(t *testing.T) { 502 test := &clientTest{ 503 name: "ECDHE-ECDSA-AES256-GCM-SHA384", 504 command: []string{"openssl", "s_server", "-cipher", "ECDHE-ECDSA-AES256-GCM-SHA384"}, 505 cert: testECDSACertificate, 506 key: testECDSAPrivateKey, 507 } 508 runClientTestTLS12(t, test) 509 } 510 511 func TestHandshakeClientAES128CBCSHA256(t *testing.T) { 512 test := &clientTest{ 513 name: "AES128-SHA256", 514 command: []string{"openssl", "s_server", "-cipher", "AES128-SHA256"}, 515 } 516 runClientTestTLS12(t, test) 517 } 518 519 func TestHandshakeClientECDHERSAAES128CBCSHA256(t *testing.T) { 520 test := &clientTest{ 521 name: "ECDHE-RSA-AES128-SHA256", 522 command: []string{"openssl", "s_server", "-cipher", "ECDHE-RSA-AES128-SHA256"}, 523 } 524 runClientTestTLS12(t, test) 525 } 526 527 func TestHandshakeClientECDHEECDSAAES128CBCSHA256(t *testing.T) { 528 test := &clientTest{ 529 name: "ECDHE-ECDSA-AES128-SHA256", 530 command: []string{"openssl", "s_server", "-cipher", "ECDHE-ECDSA-AES128-SHA256"}, 531 cert: testECDSACertificate, 532 key: testECDSAPrivateKey, 533 } 534 runClientTestTLS12(t, test) 535 } 536 537 func TestHandshakeClientCertRSA(t *testing.T) { 538 config := testConfig.clone() 539 cert, _ := X509KeyPair([]byte(clientCertificatePEM), []byte(clientKeyPEM)) 540 config.Certificates = []Certificate{cert} 541 542 test := &clientTest{ 543 name: "ClientCert-RSA-RSA", 544 command: []string{"openssl", "s_server", "-cipher", "RC4-SHA", "-verify", "1"}, 545 config: config, 546 } 547 548 runClientTestTLS10(t, test) 549 runClientTestTLS12(t, test) 550 551 test = &clientTest{ 552 name: "ClientCert-RSA-ECDSA", 553 command: []string{"openssl", "s_server", "-cipher", "ECDHE-ECDSA-AES128-SHA", "-verify", "1"}, 554 config: config, 555 cert: testECDSACertificate, 556 key: testECDSAPrivateKey, 557 } 558 559 runClientTestTLS10(t, test) 560 runClientTestTLS12(t, test) 561 562 test = &clientTest{ 563 name: "ClientCert-RSA-AES256-GCM-SHA384", 564 command: []string{"openssl", "s_server", "-cipher", "ECDHE-RSA-AES256-GCM-SHA384", "-verify", "1"}, 565 config: config, 566 cert: testRSACertificate, 567 key: testRSAPrivateKey, 568 } 569 570 runClientTestTLS12(t, test) 571 } 572 573 func TestHandshakeClientCertECDSA(t *testing.T) { 574 config := testConfig.clone() 575 cert, _ := X509KeyPair([]byte(clientECDSACertificatePEM), []byte(clientECDSAKeyPEM)) 576 config.Certificates = []Certificate{cert} 577 578 test := &clientTest{ 579 name: "ClientCert-ECDSA-RSA", 580 command: []string{"openssl", "s_server", "-cipher", "RC4-SHA", "-verify", "1"}, 581 config: config, 582 } 583 584 runClientTestTLS10(t, test) 585 runClientTestTLS12(t, test) 586 587 test = &clientTest{ 588 name: "ClientCert-ECDSA-ECDSA", 589 command: []string{"openssl", "s_server", "-cipher", "ECDHE-ECDSA-AES128-SHA", "-verify", "1"}, 590 config: config, 591 cert: testECDSACertificate, 592 key: testECDSAPrivateKey, 593 } 594 595 runClientTestTLS10(t, test) 596 runClientTestTLS12(t, test) 597 } 598 599 func TestClientResumption(t *testing.T) { 600 serverConfig := &Config{ 601 CipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA}, 602 Certificates: testConfig.Certificates, 603 } 604 605 issuer, err := x509.ParseCertificate(testRSACertificateIssuer) 606 if err != nil { 607 panic(err) 608 } 609 610 rootCAs := x509.NewCertPool() 611 rootCAs.AddCert(issuer) 612 613 clientConfig := &Config{ 614 CipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA}, 615 ClientSessionCache: NewLRUClientSessionCache(32), 616 RootCAs: rootCAs, 617 ServerName: "example.golang", 618 } 619 620 testResumeState := func(test string, didResume bool) { 621 _, hs, err := testHandshake(clientConfig, serverConfig) 622 if err != nil { 623 t.Fatalf("%s: handshake failed: %s", test, err) 624 } 625 if hs.DidResume != didResume { 626 t.Fatalf("%s resumed: %v, expected: %v", test, hs.DidResume, didResume) 627 } 628 if didResume && (hs.PeerCertificates == nil || hs.VerifiedChains == nil) { 629 t.Fatalf("expected non-nil certificates after resumption. Got peerCertificates: %#v, verifiedCertificates: %#v", hs.PeerCertificates, hs.VerifiedChains) 630 } 631 } 632 633 getTicket := func() []byte { 634 return clientConfig.ClientSessionCache.(*lruSessionCache).q.Front().Value.(*lruSessionCacheEntry).state.sessionTicket 635 } 636 randomKey := func() [32]byte { 637 var k [32]byte 638 if _, err := io.ReadFull(serverConfig.rand(), k[:]); err != nil { 639 t.Fatalf("Failed to read new SessionTicketKey: %s", err) 640 } 641 return k 642 } 643 644 testResumeState("Handshake", false) 645 ticket := getTicket() 646 testResumeState("Resume", true) 647 if !bytes.Equal(ticket, getTicket()) { 648 t.Fatal("first ticket doesn't match ticket after resumption") 649 } 650 651 key1 := randomKey() 652 serverConfig.SetSessionTicketKeys([][32]byte{key1}) 653 654 testResumeState("InvalidSessionTicketKey", false) 655 testResumeState("ResumeAfterInvalidSessionTicketKey", true) 656 657 key2 := randomKey() 658 serverConfig.SetSessionTicketKeys([][32]byte{key2, key1}) 659 ticket = getTicket() 660 testResumeState("KeyChange", true) 661 if bytes.Equal(ticket, getTicket()) { 662 t.Fatal("new ticket wasn't included while resuming") 663 } 664 testResumeState("KeyChangeFinish", true) 665 666 // Reset serverConfig to ensure that calling SetSessionTicketKeys 667 // before the serverConfig is used works. 668 serverConfig = &Config{ 669 CipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA}, 670 Certificates: testConfig.Certificates, 671 } 672 serverConfig.SetSessionTicketKeys([][32]byte{key2}) 673 674 testResumeState("FreshConfig", true) 675 676 clientConfig.CipherSuites = []uint16{TLS_ECDHE_RSA_WITH_RC4_128_SHA} 677 testResumeState("DifferentCipherSuite", false) 678 testResumeState("DifferentCipherSuiteRecovers", true) 679 680 clientConfig.ClientSessionCache = nil 681 testResumeState("WithoutSessionCache", false) 682 } 683 684 func TestLRUClientSessionCache(t *testing.T) { 685 // Initialize cache of capacity 4. 686 cache := NewLRUClientSessionCache(4) 687 cs := make([]ClientSessionState, 6) 688 keys := []string{"0", "1", "2", "3", "4", "5", "6"} 689 690 // Add 4 entries to the cache and look them up. 691 for i := 0; i < 4; i++ { 692 cache.Put(keys[i], &cs[i]) 693 } 694 for i := 0; i < 4; i++ { 695 if s, ok := cache.Get(keys[i]); !ok || s != &cs[i] { 696 t.Fatalf("session cache failed lookup for added key: %s", keys[i]) 697 } 698 } 699 700 // Add 2 more entries to the cache. First 2 should be evicted. 701 for i := 4; i < 6; i++ { 702 cache.Put(keys[i], &cs[i]) 703 } 704 for i := 0; i < 2; i++ { 705 if s, ok := cache.Get(keys[i]); ok || s != nil { 706 t.Fatalf("session cache should have evicted key: %s", keys[i]) 707 } 708 } 709 710 // Touch entry 2. LRU should evict 3 next. 711 cache.Get(keys[2]) 712 cache.Put(keys[0], &cs[0]) 713 if s, ok := cache.Get(keys[3]); ok || s != nil { 714 t.Fatalf("session cache should have evicted key 3") 715 } 716 717 // Update entry 0 in place. 718 cache.Put(keys[0], &cs[3]) 719 if s, ok := cache.Get(keys[0]); !ok || s != &cs[3] { 720 t.Fatalf("session cache failed update for key 0") 721 } 722 723 // Adding a nil entry is valid. 724 cache.Put(keys[0], nil) 725 if s, ok := cache.Get(keys[0]); !ok || s != nil { 726 t.Fatalf("failed to add nil entry to cache") 727 } 728 } 729 730 func TestHandshakeClientALPNMatch(t *testing.T) { 731 config := testConfig.clone() 732 config.NextProtos = []string{"proto2", "proto1"} 733 734 test := &clientTest{ 735 name: "ALPN", 736 // Note that this needs OpenSSL 1.0.2 because that is the first 737 // version that supports the -alpn flag. 738 command: []string{"openssl", "s_server", "-alpn", "proto1,proto2"}, 739 config: config, 740 validate: func(state ConnectionState) error { 741 // The server's preferences should override the client. 742 if state.NegotiatedProtocol != "proto1" { 743 return fmt.Errorf("Got protocol %q, wanted proto1", state.NegotiatedProtocol) 744 } 745 return nil 746 }, 747 } 748 runClientTestTLS12(t, test) 749 } 750 751 func TestHandshakeClientALPNNoMatch(t *testing.T) { 752 config := testConfig.clone() 753 config.NextProtos = []string{"proto3"} 754 755 test := &clientTest{ 756 name: "ALPN-NoMatch", 757 // Note that this needs OpenSSL 1.0.2 because that is the first 758 // version that supports the -alpn flag. 759 command: []string{"openssl", "s_server", "-alpn", "proto1,proto2"}, 760 config: config, 761 validate: func(state ConnectionState) error { 762 // There's no overlap so OpenSSL will not select a protocol. 763 if state.NegotiatedProtocol != "" { 764 return fmt.Errorf("Got protocol %q, wanted ''", state.NegotiatedProtocol) 765 } 766 return nil 767 }, 768 } 769 runClientTestTLS12(t, test) 770 } 771 772 // sctsBase64 contains data from `openssl s_client -serverinfo 18 -connect ritter.vg:443` 773 const sctsBase64 = "ABIBaQFnAHUApLkJkLQYWBSHuxOizGdwCjw1mAT5G9+443fNDsgN3BAAAAFHl5nuFgAABAMARjBEAiAcS4JdlW5nW9sElUv2zvQyPoZ6ejKrGGB03gjaBZFMLwIgc1Qbbn+hsH0RvObzhS+XZhr3iuQQJY8S9G85D9KeGPAAdgBo9pj4H2SCvjqM7rkoHUz8cVFdZ5PURNEKZ6y7T0/7xAAAAUeX4bVwAAAEAwBHMEUCIDIhFDgG2HIuADBkGuLobU5a4dlCHoJLliWJ1SYT05z6AiEAjxIoZFFPRNWMGGIjskOTMwXzQ1Wh2e7NxXE1kd1J0QsAdgDuS723dc5guuFCaR+r4Z5mow9+X7By2IMAxHuJeqj9ywAAAUhcZIqHAAAEAwBHMEUCICmJ1rBT09LpkbzxtUC+Hi7nXLR0J+2PmwLp+sJMuqK+AiEAr0NkUnEVKVhAkccIFpYDqHOlZaBsuEhWWrYpg2RtKp0=" 774 775 func TestHandshakClientSCTs(t *testing.T) { 776 config := testConfig.clone() 777 778 scts, err := base64.StdEncoding.DecodeString(sctsBase64) 779 if err != nil { 780 t.Fatal(err) 781 } 782 783 test := &clientTest{ 784 name: "SCT", 785 // Note that this needs OpenSSL 1.0.2 because that is the first 786 // version that supports the -serverinfo flag. 787 command: []string{"openssl", "s_server"}, 788 config: config, 789 extensions: [][]byte{scts}, 790 validate: func(state ConnectionState) error { 791 expectedSCTs := [][]byte{ 792 scts[8:125], 793 scts[127:245], 794 scts[247:], 795 } 796 if n := len(state.SignedCertificateTimestamps); n != len(expectedSCTs) { 797 return fmt.Errorf("Got %d scts, wanted %d", n, len(expectedSCTs)) 798 } 799 for i, expected := range expectedSCTs { 800 if sct := state.SignedCertificateTimestamps[i]; !bytes.Equal(sct, expected) { 801 return fmt.Errorf("SCT #%d contained %x, expected %x", i, sct, expected) 802 } 803 } 804 return nil 805 }, 806 } 807 runClientTestTLS12(t, test) 808 } 809 810 func TestRenegotiationRejected(t *testing.T) { 811 config := testConfig.clone() 812 test := &clientTest{ 813 name: "RenegotiationRejected", 814 command: []string{"openssl", "s_server", "-state"}, 815 config: config, 816 numRenegotiations: 1, 817 renegotiationExpectedToFail: 1, 818 checkRenegotiationError: func(renegotiationNum int, err error) error { 819 if err == nil { 820 return errors.New("expected error from renegotiation but got nil") 821 } 822 if !strings.Contains(err.Error(), "no renegotiation") { 823 return fmt.Errorf("expected renegotiation to be rejected but got %q", err) 824 } 825 return nil 826 }, 827 } 828 829 runClientTestTLS12(t, test) 830 } 831 832 func TestRenegotiateOnce(t *testing.T) { 833 config := testConfig.clone() 834 config.Renegotiation = RenegotiateOnceAsClient 835 836 test := &clientTest{ 837 name: "RenegotiateOnce", 838 command: []string{"openssl", "s_server", "-state"}, 839 config: config, 840 numRenegotiations: 1, 841 } 842 843 runClientTestTLS12(t, test) 844 } 845 846 func TestRenegotiateTwice(t *testing.T) { 847 config := testConfig.clone() 848 config.Renegotiation = RenegotiateFreelyAsClient 849 850 test := &clientTest{ 851 name: "RenegotiateTwice", 852 command: []string{"openssl", "s_server", "-state"}, 853 config: config, 854 numRenegotiations: 2, 855 } 856 857 runClientTestTLS12(t, test) 858 } 859 860 func TestRenegotiateTwiceRejected(t *testing.T) { 861 config := testConfig.clone() 862 config.Renegotiation = RenegotiateOnceAsClient 863 864 test := &clientTest{ 865 name: "RenegotiateTwiceRejected", 866 command: []string{"openssl", "s_server", "-state"}, 867 config: config, 868 numRenegotiations: 2, 869 renegotiationExpectedToFail: 2, 870 checkRenegotiationError: func(renegotiationNum int, err error) error { 871 if renegotiationNum == 1 { 872 return err 873 } 874 875 if err == nil { 876 return errors.New("expected error from renegotiation but got nil") 877 } 878 if !strings.Contains(err.Error(), "no renegotiation") { 879 return fmt.Errorf("expected renegotiation to be rejected but got %q", err) 880 } 881 return nil 882 }, 883 } 884 885 runClientTestTLS12(t, test) 886 } 887 888 var hostnameInSNITests = []struct { 889 in, out string 890 }{ 891 // Opaque string 892 {"", ""}, 893 {"localhost", "localhost"}, 894 {"foo, bar, baz and qux", "foo, bar, baz and qux"}, 895 896 // DNS hostname 897 {"golang.org", "golang.org"}, 898 {"golang.org.", "golang.org"}, 899 900 // Literal IPv4 address 901 {"1.2.3.4", ""}, 902 903 // Literal IPv6 address 904 {"::1", ""}, 905 {"::1%lo0", ""}, // with zone identifier 906 {"[::1]", ""}, // as per RFC 5952 we allow the [] style as IPv6 literal 907 {"[::1%lo0]", ""}, 908 } 909 910 func TestHostnameInSNI(t *testing.T) { 911 for _, tt := range hostnameInSNITests { 912 c, s := net.Pipe() 913 914 go func(host string) { 915 Client(c, &Config{ServerName: host, InsecureSkipVerify: true}).Handshake() 916 }(tt.in) 917 918 var header [5]byte 919 if _, err := io.ReadFull(s, header[:]); err != nil { 920 t.Fatal(err) 921 } 922 recordLen := int(header[3])<<8 | int(header[4]) 923 924 record := make([]byte, recordLen) 925 if _, err := io.ReadFull(s, record[:]); err != nil { 926 t.Fatal(err) 927 } 928 929 c.Close() 930 s.Close() 931 932 var m clientHelloMsg 933 if !m.unmarshal(record) { 934 t.Errorf("unmarshaling ClientHello for %q failed", tt.in) 935 continue 936 } 937 if tt.in != tt.out && m.serverName == tt.in { 938 t.Errorf("prohibited %q found in ClientHello: %x", tt.in, record) 939 } 940 if m.serverName != tt.out { 941 t.Errorf("expected %q not found in ClientHello: %x", tt.out, record) 942 } 943 } 944 } 945 946 func TestServerSelectingUnconfiguredCipherSuite(t *testing.T) { 947 // This checks that the server can't select a cipher suite that the 948 // client didn't offer. See #13174. 949 950 c, s := net.Pipe() 951 errChan := make(chan error, 1) 952 953 go func() { 954 client := Client(c, &Config{ 955 ServerName: "foo", 956 CipherSuites: []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256}, 957 }) 958 errChan <- client.Handshake() 959 }() 960 961 var header [5]byte 962 if _, err := io.ReadFull(s, header[:]); err != nil { 963 t.Fatal(err) 964 } 965 recordLen := int(header[3])<<8 | int(header[4]) 966 967 record := make([]byte, recordLen) 968 if _, err := io.ReadFull(s, record); err != nil { 969 t.Fatal(err) 970 } 971 972 // Create a ServerHello that selects a different cipher suite than the 973 // sole one that the client offered. 974 serverHello := &serverHelloMsg{ 975 vers: VersionTLS12, 976 random: make([]byte, 32), 977 cipherSuite: TLS_RSA_WITH_AES_256_GCM_SHA384, 978 } 979 serverHelloBytes := serverHello.marshal() 980 981 s.Write([]byte{ 982 byte(recordTypeHandshake), 983 byte(VersionTLS12 >> 8), 984 byte(VersionTLS12 & 0xff), 985 byte(len(serverHelloBytes) >> 8), 986 byte(len(serverHelloBytes)), 987 }) 988 s.Write(serverHelloBytes) 989 s.Close() 990 991 if err := <-errChan; !strings.Contains(err.Error(), "unconfigured cipher") { 992 t.Fatalf("Expected error about unconfigured cipher suite but got %q", err) 993 } 994 } 995 996 // brokenConn wraps a net.Conn and causes all Writes after a certain number to 997 // fail with brokenConnErr. 998 type brokenConn struct { 999 net.Conn 1000 1001 // breakAfter is the number of successful writes that will be allowed 1002 // before all subsequent writes fail. 1003 breakAfter int 1004 1005 // numWrites is the number of writes that have been done. 1006 numWrites int 1007 } 1008 1009 // brokenConnErr is the error that brokenConn returns once exhausted. 1010 var brokenConnErr = errors.New("too many writes to brokenConn") 1011 1012 func (b *brokenConn) Write(data []byte) (int, error) { 1013 if b.numWrites >= b.breakAfter { 1014 return 0, brokenConnErr 1015 } 1016 1017 b.numWrites++ 1018 return b.Conn.Write(data) 1019 } 1020 1021 func TestFailedWrite(t *testing.T) { 1022 // Test that a write error during the handshake is returned. 1023 for _, breakAfter := range []int{0, 1} { 1024 c, s := net.Pipe() 1025 done := make(chan bool) 1026 1027 go func() { 1028 Server(s, testConfig).Handshake() 1029 s.Close() 1030 done <- true 1031 }() 1032 1033 brokenC := &brokenConn{Conn: c, breakAfter: breakAfter} 1034 err := Client(brokenC, testConfig).Handshake() 1035 if err != brokenConnErr { 1036 t.Errorf("#%d: expected error from brokenConn but got %q", breakAfter, err) 1037 } 1038 brokenC.Close() 1039 1040 <-done 1041 } 1042 } 1043 1044 // writeCountingConn wraps a net.Conn and counts the number of Write calls. 1045 type writeCountingConn struct { 1046 net.Conn 1047 1048 // numWrites is the number of writes that have been done. 1049 numWrites int 1050 } 1051 1052 func (wcc *writeCountingConn) Write(data []byte) (int, error) { 1053 wcc.numWrites++ 1054 return wcc.Conn.Write(data) 1055 } 1056 1057 func TestBuffering(t *testing.T) { 1058 c, s := net.Pipe() 1059 done := make(chan bool) 1060 1061 clientWCC := &writeCountingConn{Conn: c} 1062 serverWCC := &writeCountingConn{Conn: s} 1063 1064 go func() { 1065 Server(serverWCC, testConfig).Handshake() 1066 serverWCC.Close() 1067 done <- true 1068 }() 1069 1070 err := Client(clientWCC, testConfig).Handshake() 1071 if err != nil { 1072 t.Fatal(err) 1073 } 1074 clientWCC.Close() 1075 <-done 1076 1077 if n := clientWCC.numWrites; n != 2 { 1078 t.Errorf("expected client handshake to complete with only two writes, but saw %d", n) 1079 } 1080 1081 if n := serverWCC.numWrites; n != 2 { 1082 t.Errorf("expected server handshake to complete with only two writes, but saw %d", n) 1083 } 1084 }