github.com/miolini/go@v0.0.0-20160405192216-fca68c8cb408/src/crypto/tls/handshake_client_test.go (about) 1 // Copyright 2010 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package tls 6 7 import ( 8 "bytes" 9 "crypto/ecdsa" 10 "crypto/rsa" 11 "crypto/x509" 12 "encoding/base64" 13 "encoding/binary" 14 "encoding/pem" 15 "errors" 16 "fmt" 17 "io" 18 "net" 19 "os" 20 "os/exec" 21 "path/filepath" 22 "strconv" 23 "strings" 24 "testing" 25 "time" 26 ) 27 28 // Note: see comment in handshake_test.go for details of how the reference 29 // tests work. 30 31 // blockingSource is an io.Reader that blocks a Read call until it's closed. 32 type blockingSource chan bool 33 34 func (b blockingSource) Read([]byte) (n int, err error) { 35 <-b 36 return 0, io.EOF 37 } 38 39 // clientTest represents a test of the TLS client handshake against a reference 40 // implementation. 41 type clientTest struct { 42 // name is a freeform string identifying the test and the file in which 43 // the expected results will be stored. 44 name string 45 // command, if not empty, contains a series of arguments for the 46 // command to run for the reference server. 47 command []string 48 // config, if not nil, contains a custom Config to use for this test. 49 config *Config 50 // cert, if not empty, contains a DER-encoded certificate for the 51 // reference server. 52 cert []byte 53 // key, if not nil, contains either a *rsa.PrivateKey or 54 // *ecdsa.PrivateKey which is the private key for the reference server. 55 key interface{} 56 // extensions, if not nil, contains a list of extension data to be returned 57 // from the ServerHello. The data should be in standard TLS format with 58 // a 2-byte uint16 type, 2-byte data length, followed by the extension data. 59 extensions [][]byte 60 // validate, if not nil, is a function that will be called with the 61 // ConnectionState of the resulting connection. It returns a non-nil 62 // error if the ConnectionState is unacceptable. 63 validate func(ConnectionState) error 64 } 65 66 var defaultServerCommand = []string{"openssl", "s_server"} 67 68 // connFromCommand starts the reference server process, connects to it and 69 // returns a recordingConn for the connection. The stdin return value is a 70 // blockingSource for the stdin of the child process. It must be closed before 71 // Waiting for child. 72 func (test *clientTest) connFromCommand() (conn *recordingConn, child *exec.Cmd, stdin blockingSource, err error) { 73 cert := testRSACertificate 74 if len(test.cert) > 0 { 75 cert = test.cert 76 } 77 certPath := tempFile(string(cert)) 78 defer os.Remove(certPath) 79 80 var key interface{} = testRSAPrivateKey 81 if test.key != nil { 82 key = test.key 83 } 84 var pemType string 85 var derBytes []byte 86 switch key := key.(type) { 87 case *rsa.PrivateKey: 88 pemType = "RSA" 89 derBytes = x509.MarshalPKCS1PrivateKey(key) 90 case *ecdsa.PrivateKey: 91 pemType = "EC" 92 var err error 93 derBytes, err = x509.MarshalECPrivateKey(key) 94 if err != nil { 95 panic(err) 96 } 97 default: 98 panic("unknown key type") 99 } 100 101 var pemOut bytes.Buffer 102 pem.Encode(&pemOut, &pem.Block{Type: pemType + " PRIVATE KEY", Bytes: derBytes}) 103 104 keyPath := tempFile(string(pemOut.Bytes())) 105 defer os.Remove(keyPath) 106 107 var command []string 108 if len(test.command) > 0 { 109 command = append(command, test.command...) 110 } else { 111 command = append(command, defaultServerCommand...) 112 } 113 command = append(command, "-cert", certPath, "-certform", "DER", "-key", keyPath) 114 // serverPort contains the port that OpenSSL will listen on. OpenSSL 115 // can't take "0" as an argument here so we have to pick a number and 116 // hope that it's not in use on the machine. Since this only occurs 117 // when -update is given and thus when there's a human watching the 118 // test, this isn't too bad. 119 const serverPort = 24323 120 command = append(command, "-accept", strconv.Itoa(serverPort)) 121 122 if len(test.extensions) > 0 { 123 var serverInfo bytes.Buffer 124 for _, ext := range test.extensions { 125 pem.Encode(&serverInfo, &pem.Block{ 126 Type: fmt.Sprintf("SERVERINFO FOR EXTENSION %d", binary.BigEndian.Uint16(ext)), 127 Bytes: ext, 128 }) 129 } 130 serverInfoPath := tempFile(serverInfo.String()) 131 defer os.Remove(serverInfoPath) 132 command = append(command, "-serverinfo", serverInfoPath) 133 } 134 135 cmd := exec.Command(command[0], command[1:]...) 136 stdin = blockingSource(make(chan bool)) 137 cmd.Stdin = stdin 138 var out bytes.Buffer 139 cmd.Stdout = &out 140 cmd.Stderr = &out 141 if err := cmd.Start(); err != nil { 142 return nil, nil, nil, err 143 } 144 145 // OpenSSL does print an "ACCEPT" banner, but it does so *before* 146 // opening the listening socket, so we can't use that to wait until it 147 // has started listening. Thus we are forced to poll until we get a 148 // connection. 149 var tcpConn net.Conn 150 for i := uint(0); i < 5; i++ { 151 tcpConn, err = net.DialTCP("tcp", nil, &net.TCPAddr{ 152 IP: net.IPv4(127, 0, 0, 1), 153 Port: serverPort, 154 }) 155 if err == nil { 156 break 157 } 158 time.Sleep((1 << i) * 5 * time.Millisecond) 159 } 160 if err != nil { 161 close(stdin) 162 out.WriteTo(os.Stdout) 163 cmd.Process.Kill() 164 return nil, nil, nil, cmd.Wait() 165 } 166 167 record := &recordingConn{ 168 Conn: tcpConn, 169 } 170 171 return record, cmd, stdin, nil 172 } 173 174 func (test *clientTest) dataPath() string { 175 return filepath.Join("testdata", "Client-"+test.name) 176 } 177 178 func (test *clientTest) loadData() (flows [][]byte, err error) { 179 in, err := os.Open(test.dataPath()) 180 if err != nil { 181 return nil, err 182 } 183 defer in.Close() 184 return parseTestData(in) 185 } 186 187 func (test *clientTest) run(t *testing.T, write bool) { 188 var clientConn, serverConn net.Conn 189 var recordingConn *recordingConn 190 var childProcess *exec.Cmd 191 var stdin blockingSource 192 193 if write { 194 var err error 195 recordingConn, childProcess, stdin, err = test.connFromCommand() 196 if err != nil { 197 t.Fatalf("Failed to start subcommand: %s", err) 198 } 199 clientConn = recordingConn 200 } else { 201 clientConn, serverConn = net.Pipe() 202 } 203 204 config := test.config 205 if config == nil { 206 config = testConfig 207 } 208 client := Client(clientConn, config) 209 210 doneChan := make(chan bool) 211 go func() { 212 if _, err := client.Write([]byte("hello\n")); err != nil { 213 t.Errorf("Client.Write failed: %s", err) 214 } 215 if test.validate != nil { 216 if err := test.validate(client.ConnectionState()); err != nil { 217 t.Errorf("validate callback returned error: %s", err) 218 } 219 } 220 client.Close() 221 clientConn.Close() 222 doneChan <- true 223 }() 224 225 if !write { 226 flows, err := test.loadData() 227 if err != nil { 228 t.Fatalf("%s: failed to load data from %s: %v", test.name, test.dataPath(), err) 229 } 230 for i, b := range flows { 231 if i%2 == 1 { 232 serverConn.Write(b) 233 continue 234 } 235 bb := make([]byte, len(b)) 236 _, err := io.ReadFull(serverConn, bb) 237 if err != nil { 238 t.Fatalf("%s #%d: %s", test.name, i, err) 239 } 240 if !bytes.Equal(b, bb) { 241 t.Fatalf("%s #%d: mismatch on read: got:%x want:%x", test.name, i, bb, b) 242 } 243 } 244 serverConn.Close() 245 } 246 247 <-doneChan 248 249 if write { 250 path := test.dataPath() 251 out, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644) 252 if err != nil { 253 t.Fatalf("Failed to create output file: %s", err) 254 } 255 defer out.Close() 256 recordingConn.Close() 257 close(stdin) 258 childProcess.Process.Kill() 259 childProcess.Wait() 260 if len(recordingConn.flows) < 3 { 261 childProcess.Stdout.(*bytes.Buffer).WriteTo(os.Stdout) 262 t.Fatalf("Client connection didn't work") 263 } 264 recordingConn.WriteTo(out) 265 fmt.Printf("Wrote %s\n", path) 266 } 267 } 268 269 func runClientTestForVersion(t *testing.T, template *clientTest, prefix, option string) { 270 test := *template 271 test.name = prefix + test.name 272 if len(test.command) == 0 { 273 test.command = defaultClientCommand 274 } 275 test.command = append([]string(nil), test.command...) 276 test.command = append(test.command, option) 277 test.run(t, *update) 278 } 279 280 func runClientTestTLS10(t *testing.T, template *clientTest) { 281 runClientTestForVersion(t, template, "TLSv10-", "-tls1") 282 } 283 284 func runClientTestTLS11(t *testing.T, template *clientTest) { 285 runClientTestForVersion(t, template, "TLSv11-", "-tls1_1") 286 } 287 288 func runClientTestTLS12(t *testing.T, template *clientTest) { 289 runClientTestForVersion(t, template, "TLSv12-", "-tls1_2") 290 } 291 292 func TestHandshakeClientRSARC4(t *testing.T) { 293 test := &clientTest{ 294 name: "RSA-RC4", 295 command: []string{"openssl", "s_server", "-cipher", "RC4-SHA"}, 296 } 297 runClientTestTLS10(t, test) 298 runClientTestTLS11(t, test) 299 runClientTestTLS12(t, test) 300 } 301 302 func TestHandshakeClientRSAAES128GCM(t *testing.T) { 303 test := &clientTest{ 304 name: "AES128-GCM-SHA256", 305 command: []string{"openssl", "s_server", "-cipher", "AES128-GCM-SHA256"}, 306 } 307 runClientTestTLS12(t, test) 308 } 309 310 func TestHandshakeClientRSAAES256GCM(t *testing.T) { 311 test := &clientTest{ 312 name: "AES256-GCM-SHA384", 313 command: []string{"openssl", "s_server", "-cipher", "AES256-GCM-SHA384"}, 314 } 315 runClientTestTLS12(t, test) 316 } 317 318 func TestHandshakeClientECDHERSAAES(t *testing.T) { 319 test := &clientTest{ 320 name: "ECDHE-RSA-AES", 321 command: []string{"openssl", "s_server", "-cipher", "ECDHE-RSA-AES128-SHA"}, 322 } 323 runClientTestTLS10(t, test) 324 runClientTestTLS11(t, test) 325 runClientTestTLS12(t, test) 326 } 327 328 func TestHandshakeClientECDHEECDSAAES(t *testing.T) { 329 test := &clientTest{ 330 name: "ECDHE-ECDSA-AES", 331 command: []string{"openssl", "s_server", "-cipher", "ECDHE-ECDSA-AES128-SHA"}, 332 cert: testECDSACertificate, 333 key: testECDSAPrivateKey, 334 } 335 runClientTestTLS10(t, test) 336 runClientTestTLS11(t, test) 337 runClientTestTLS12(t, test) 338 } 339 340 func TestHandshakeClientECDHEECDSAAESGCM(t *testing.T) { 341 test := &clientTest{ 342 name: "ECDHE-ECDSA-AES-GCM", 343 command: []string{"openssl", "s_server", "-cipher", "ECDHE-ECDSA-AES128-GCM-SHA256"}, 344 cert: testECDSACertificate, 345 key: testECDSAPrivateKey, 346 } 347 runClientTestTLS12(t, test) 348 } 349 350 func TestHandshakeClientAES256GCMSHA384(t *testing.T) { 351 test := &clientTest{ 352 name: "ECDHE-ECDSA-AES256-GCM-SHA384", 353 command: []string{"openssl", "s_server", "-cipher", "ECDHE-ECDSA-AES256-GCM-SHA384"}, 354 cert: testECDSACertificate, 355 key: testECDSAPrivateKey, 356 } 357 runClientTestTLS12(t, test) 358 } 359 360 func TestHandshakeClientCertRSA(t *testing.T) { 361 config := *testConfig 362 cert, _ := X509KeyPair([]byte(clientCertificatePEM), []byte(clientKeyPEM)) 363 config.Certificates = []Certificate{cert} 364 365 test := &clientTest{ 366 name: "ClientCert-RSA-RSA", 367 command: []string{"openssl", "s_server", "-cipher", "RC4-SHA", "-verify", "1"}, 368 config: &config, 369 } 370 371 runClientTestTLS10(t, test) 372 runClientTestTLS12(t, test) 373 374 test = &clientTest{ 375 name: "ClientCert-RSA-ECDSA", 376 command: []string{"openssl", "s_server", "-cipher", "ECDHE-ECDSA-AES128-SHA", "-verify", "1"}, 377 config: &config, 378 cert: testECDSACertificate, 379 key: testECDSAPrivateKey, 380 } 381 382 runClientTestTLS10(t, test) 383 runClientTestTLS12(t, test) 384 385 test = &clientTest{ 386 name: "ClientCert-RSA-AES256-GCM-SHA384", 387 command: []string{"openssl", "s_server", "-cipher", "ECDHE-RSA-AES256-GCM-SHA384", "-verify", "1"}, 388 config: &config, 389 cert: testRSACertificate, 390 key: testRSAPrivateKey, 391 } 392 393 runClientTestTLS12(t, test) 394 } 395 396 func TestHandshakeClientCertECDSA(t *testing.T) { 397 config := *testConfig 398 cert, _ := X509KeyPair([]byte(clientECDSACertificatePEM), []byte(clientECDSAKeyPEM)) 399 config.Certificates = []Certificate{cert} 400 401 test := &clientTest{ 402 name: "ClientCert-ECDSA-RSA", 403 command: []string{"openssl", "s_server", "-cipher", "RC4-SHA", "-verify", "1"}, 404 config: &config, 405 } 406 407 runClientTestTLS10(t, test) 408 runClientTestTLS12(t, test) 409 410 test = &clientTest{ 411 name: "ClientCert-ECDSA-ECDSA", 412 command: []string{"openssl", "s_server", "-cipher", "ECDHE-ECDSA-AES128-SHA", "-verify", "1"}, 413 config: &config, 414 cert: testECDSACertificate, 415 key: testECDSAPrivateKey, 416 } 417 418 runClientTestTLS10(t, test) 419 runClientTestTLS12(t, test) 420 } 421 422 func TestClientResumption(t *testing.T) { 423 serverConfig := &Config{ 424 CipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA}, 425 Certificates: testConfig.Certificates, 426 } 427 428 issuer, err := x509.ParseCertificate(testRSACertificateIssuer) 429 if err != nil { 430 panic(err) 431 } 432 433 rootCAs := x509.NewCertPool() 434 rootCAs.AddCert(issuer) 435 436 clientConfig := &Config{ 437 CipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA}, 438 ClientSessionCache: NewLRUClientSessionCache(32), 439 RootCAs: rootCAs, 440 ServerName: "example.golang", 441 } 442 443 testResumeState := func(test string, didResume bool) { 444 _, hs, err := testHandshake(clientConfig, serverConfig) 445 if err != nil { 446 t.Fatalf("%s: handshake failed: %s", test, err) 447 } 448 if hs.DidResume != didResume { 449 t.Fatalf("%s resumed: %v, expected: %v", test, hs.DidResume, didResume) 450 } 451 if didResume && (hs.PeerCertificates == nil || hs.VerifiedChains == nil) { 452 t.Fatalf("expected non-nil certificates after resumption. Got peerCertificates: %#v, verifiedCertificates: %#v", hs.PeerCertificates, hs.VerifiedChains) 453 } 454 } 455 456 getTicket := func() []byte { 457 return clientConfig.ClientSessionCache.(*lruSessionCache).q.Front().Value.(*lruSessionCacheEntry).state.sessionTicket 458 } 459 randomKey := func() [32]byte { 460 var k [32]byte 461 if _, err := io.ReadFull(serverConfig.rand(), k[:]); err != nil { 462 t.Fatalf("Failed to read new SessionTicketKey: %s", err) 463 } 464 return k 465 } 466 467 testResumeState("Handshake", false) 468 ticket := getTicket() 469 testResumeState("Resume", true) 470 if !bytes.Equal(ticket, getTicket()) { 471 t.Fatal("first ticket doesn't match ticket after resumption") 472 } 473 474 key2 := randomKey() 475 serverConfig.SetSessionTicketKeys([][32]byte{key2}) 476 477 testResumeState("InvalidSessionTicketKey", false) 478 testResumeState("ResumeAfterInvalidSessionTicketKey", true) 479 480 serverConfig.SetSessionTicketKeys([][32]byte{randomKey(), key2}) 481 ticket = getTicket() 482 testResumeState("KeyChange", true) 483 if bytes.Equal(ticket, getTicket()) { 484 t.Fatal("new ticket wasn't included while resuming") 485 } 486 testResumeState("KeyChangeFinish", true) 487 488 clientConfig.CipherSuites = []uint16{TLS_ECDHE_RSA_WITH_RC4_128_SHA} 489 testResumeState("DifferentCipherSuite", false) 490 testResumeState("DifferentCipherSuiteRecovers", true) 491 492 clientConfig.ClientSessionCache = nil 493 testResumeState("WithoutSessionCache", false) 494 } 495 496 func TestLRUClientSessionCache(t *testing.T) { 497 // Initialize cache of capacity 4. 498 cache := NewLRUClientSessionCache(4) 499 cs := make([]ClientSessionState, 6) 500 keys := []string{"0", "1", "2", "3", "4", "5", "6"} 501 502 // Add 4 entries to the cache and look them up. 503 for i := 0; i < 4; i++ { 504 cache.Put(keys[i], &cs[i]) 505 } 506 for i := 0; i < 4; i++ { 507 if s, ok := cache.Get(keys[i]); !ok || s != &cs[i] { 508 t.Fatalf("session cache failed lookup for added key: %s", keys[i]) 509 } 510 } 511 512 // Add 2 more entries to the cache. First 2 should be evicted. 513 for i := 4; i < 6; i++ { 514 cache.Put(keys[i], &cs[i]) 515 } 516 for i := 0; i < 2; i++ { 517 if s, ok := cache.Get(keys[i]); ok || s != nil { 518 t.Fatalf("session cache should have evicted key: %s", keys[i]) 519 } 520 } 521 522 // Touch entry 2. LRU should evict 3 next. 523 cache.Get(keys[2]) 524 cache.Put(keys[0], &cs[0]) 525 if s, ok := cache.Get(keys[3]); ok || s != nil { 526 t.Fatalf("session cache should have evicted key 3") 527 } 528 529 // Update entry 0 in place. 530 cache.Put(keys[0], &cs[3]) 531 if s, ok := cache.Get(keys[0]); !ok || s != &cs[3] { 532 t.Fatalf("session cache failed update for key 0") 533 } 534 535 // Adding a nil entry is valid. 536 cache.Put(keys[0], nil) 537 if s, ok := cache.Get(keys[0]); !ok || s != nil { 538 t.Fatalf("failed to add nil entry to cache") 539 } 540 } 541 542 func TestHandshakeClientALPNMatch(t *testing.T) { 543 config := *testConfig 544 config.NextProtos = []string{"proto2", "proto1"} 545 546 test := &clientTest{ 547 name: "ALPN", 548 // Note that this needs OpenSSL 1.0.2 because that is the first 549 // version that supports the -alpn flag. 550 command: []string{"openssl", "s_server", "-alpn", "proto1,proto2"}, 551 config: &config, 552 validate: func(state ConnectionState) error { 553 // The server's preferences should override the client. 554 if state.NegotiatedProtocol != "proto1" { 555 return fmt.Errorf("Got protocol %q, wanted proto1", state.NegotiatedProtocol) 556 } 557 return nil 558 }, 559 } 560 runClientTestTLS12(t, test) 561 } 562 563 func TestHandshakeClientALPNNoMatch(t *testing.T) { 564 config := *testConfig 565 config.NextProtos = []string{"proto3"} 566 567 test := &clientTest{ 568 name: "ALPN-NoMatch", 569 // Note that this needs OpenSSL 1.0.2 because that is the first 570 // version that supports the -alpn flag. 571 command: []string{"openssl", "s_server", "-alpn", "proto1,proto2"}, 572 config: &config, 573 validate: func(state ConnectionState) error { 574 // There's no overlap so OpenSSL will not select a protocol. 575 if state.NegotiatedProtocol != "" { 576 return fmt.Errorf("Got protocol %q, wanted ''", state.NegotiatedProtocol) 577 } 578 return nil 579 }, 580 } 581 runClientTestTLS12(t, test) 582 } 583 584 // sctsBase64 contains data from `openssl s_client -serverinfo 18 -connect ritter.vg:443` 585 const sctsBase64 = "ABIBaQFnAHUApLkJkLQYWBSHuxOizGdwCjw1mAT5G9+443fNDsgN3BAAAAFHl5nuFgAABAMARjBEAiAcS4JdlW5nW9sElUv2zvQyPoZ6ejKrGGB03gjaBZFMLwIgc1Qbbn+hsH0RvObzhS+XZhr3iuQQJY8S9G85D9KeGPAAdgBo9pj4H2SCvjqM7rkoHUz8cVFdZ5PURNEKZ6y7T0/7xAAAAUeX4bVwAAAEAwBHMEUCIDIhFDgG2HIuADBkGuLobU5a4dlCHoJLliWJ1SYT05z6AiEAjxIoZFFPRNWMGGIjskOTMwXzQ1Wh2e7NxXE1kd1J0QsAdgDuS723dc5guuFCaR+r4Z5mow9+X7By2IMAxHuJeqj9ywAAAUhcZIqHAAAEAwBHMEUCICmJ1rBT09LpkbzxtUC+Hi7nXLR0J+2PmwLp+sJMuqK+AiEAr0NkUnEVKVhAkccIFpYDqHOlZaBsuEhWWrYpg2RtKp0=" 586 587 func TestHandshakClientSCTs(t *testing.T) { 588 config := *testConfig 589 590 scts, err := base64.StdEncoding.DecodeString(sctsBase64) 591 if err != nil { 592 t.Fatal(err) 593 } 594 595 test := &clientTest{ 596 name: "SCT", 597 // Note that this needs OpenSSL 1.0.2 because that is the first 598 // version that supports the -serverinfo flag. 599 command: []string{"openssl", "s_server"}, 600 config: &config, 601 extensions: [][]byte{scts}, 602 validate: func(state ConnectionState) error { 603 expectedSCTs := [][]byte{ 604 scts[8:125], 605 scts[127:245], 606 scts[247:], 607 } 608 if n := len(state.SignedCertificateTimestamps); n != len(expectedSCTs) { 609 return fmt.Errorf("Got %d scts, wanted %d", n, len(expectedSCTs)) 610 } 611 for i, expected := range expectedSCTs { 612 if sct := state.SignedCertificateTimestamps[i]; !bytes.Equal(sct, expected) { 613 return fmt.Errorf("SCT #%d contained %x, expected %x", i, sct, expected) 614 } 615 } 616 return nil 617 }, 618 } 619 runClientTestTLS12(t, test) 620 } 621 622 var hostnameInSNITests = []struct { 623 in, out string 624 }{ 625 // Opaque string 626 {"", ""}, 627 {"localhost", "localhost"}, 628 {"foo, bar, baz and qux", "foo, bar, baz and qux"}, 629 630 // DNS hostname 631 {"golang.org", "golang.org"}, 632 {"golang.org.", "golang.org"}, 633 634 // Literal IPv4 address 635 {"1.2.3.4", ""}, 636 637 // Literal IPv6 address 638 {"::1", ""}, 639 {"::1%lo0", ""}, // with zone identifier 640 {"[::1]", ""}, // as per RFC 5952 we allow the [] style as IPv6 literal 641 {"[::1%lo0]", ""}, 642 } 643 644 func TestHostnameInSNI(t *testing.T) { 645 for _, tt := range hostnameInSNITests { 646 c, s := net.Pipe() 647 648 go func(host string) { 649 Client(c, &Config{ServerName: host, InsecureSkipVerify: true}).Handshake() 650 }(tt.in) 651 652 var header [5]byte 653 if _, err := io.ReadFull(s, header[:]); err != nil { 654 t.Fatal(err) 655 } 656 recordLen := int(header[3])<<8 | int(header[4]) 657 658 record := make([]byte, recordLen) 659 if _, err := io.ReadFull(s, record[:]); err != nil { 660 t.Fatal(err) 661 } 662 663 c.Close() 664 s.Close() 665 666 var m clientHelloMsg 667 if !m.unmarshal(record) { 668 t.Errorf("unmarshaling ClientHello for %q failed", tt.in) 669 continue 670 } 671 if tt.in != tt.out && m.serverName == tt.in { 672 t.Errorf("prohibited %q found in ClientHello: %x", tt.in, record) 673 } 674 if m.serverName != tt.out { 675 t.Errorf("expected %q not found in ClientHello: %x", tt.out, record) 676 } 677 } 678 } 679 680 func TestServerSelectingUnconfiguredCipherSuite(t *testing.T) { 681 // This checks that the server can't select a cipher suite that the 682 // client didn't offer. See #13174. 683 684 c, s := net.Pipe() 685 errChan := make(chan error, 1) 686 687 go func() { 688 client := Client(c, &Config{ 689 ServerName: "foo", 690 CipherSuites: []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256}, 691 }) 692 errChan <- client.Handshake() 693 }() 694 695 var header [5]byte 696 if _, err := io.ReadFull(s, header[:]); err != nil { 697 t.Fatal(err) 698 } 699 recordLen := int(header[3])<<8 | int(header[4]) 700 701 record := make([]byte, recordLen) 702 if _, err := io.ReadFull(s, record); err != nil { 703 t.Fatal(err) 704 } 705 706 // Create a ServerHello that selects a different cipher suite than the 707 // sole one that the client offered. 708 serverHello := &serverHelloMsg{ 709 vers: VersionTLS12, 710 random: make([]byte, 32), 711 cipherSuite: TLS_RSA_WITH_AES_256_GCM_SHA384, 712 } 713 serverHelloBytes := serverHello.marshal() 714 715 s.Write([]byte{ 716 byte(recordTypeHandshake), 717 byte(VersionTLS12 >> 8), 718 byte(VersionTLS12 & 0xff), 719 byte(len(serverHelloBytes) >> 8), 720 byte(len(serverHelloBytes)), 721 }) 722 s.Write(serverHelloBytes) 723 s.Close() 724 725 if err := <-errChan; !strings.Contains(err.Error(), "unconfigured cipher") { 726 t.Fatalf("Expected error about unconfigured cipher suite but got %q", err) 727 } 728 } 729 730 // brokenConn wraps a net.Conn and causes all Writes after a certain number to 731 // fail with brokenConnErr. 732 type brokenConn struct { 733 net.Conn 734 735 // breakAfter is the number of successful writes that will be allowed 736 // before all subsequent writes fail. 737 breakAfter int 738 739 // numWrites is the number of writes that have been done. 740 numWrites int 741 } 742 743 // brokenConnErr is the error that brokenConn returns once exhausted. 744 var brokenConnErr = errors.New("too many writes to brokenConn") 745 746 func (b *brokenConn) Write(data []byte) (int, error) { 747 if b.numWrites >= b.breakAfter { 748 return 0, brokenConnErr 749 } 750 751 b.numWrites++ 752 return b.Conn.Write(data) 753 } 754 755 func TestFailedWrite(t *testing.T) { 756 // Test that a write error during the handshake is returned. 757 for _, breakAfter := range []int{0, 1, 2, 3} { 758 c, s := net.Pipe() 759 done := make(chan bool) 760 761 go func() { 762 Server(s, testConfig).Handshake() 763 s.Close() 764 done <- true 765 }() 766 767 brokenC := &brokenConn{Conn: c, breakAfter: breakAfter} 768 err := Client(brokenC, testConfig).Handshake() 769 if err != brokenConnErr { 770 t.Errorf("#%d: expected error from brokenConn but got %q", breakAfter, err) 771 } 772 brokenC.Close() 773 774 <-done 775 } 776 }