github.com/zmap/zcrypto@v0.0.0-20240512203510-0fef58d9a9db/tls/handshake_client_test.go (about) 1 // Copyright 2010 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package tls 6 7 import ( 8 "bytes" 9 "crypto/ecdsa" 10 "crypto/rsa" 11 "encoding/pem" 12 "fmt" 13 "io" 14 "net" 15 "os" 16 "os/exec" 17 "path/filepath" 18 "strconv" 19 "testing" 20 "time" 21 22 "github.com/zmap/zcrypto/x509" 23 ) 24 25 // Note: see comment in handshake_test.go for details of how the reference 26 // tests work. 27 28 // blockingSource is an io.Reader that blocks a Read call until it's closed. 29 type blockingSource chan bool 30 31 func (b blockingSource) Read([]byte) (n int, err error) { 32 <-b 33 return 0, io.EOF 34 } 35 36 // clientTest represents a test of the TLS client handshake against a reference 37 // implementation. 38 type clientTest struct { 39 // name is a freeform string identifying the test and the file in which 40 // the expected results will be stored. 41 name string 42 // command, if not empty, contains a series of arguments for the 43 // command to run for the reference server. 44 command []string 45 // config, if not nil, contains a custom Config to use for this test. 46 config *Config 47 // cert, if not empty, contains a DER-encoded certificate for the 48 // reference server. 49 cert []byte 50 // key, if not nil, contains either a *rsa.PrivateKey or 51 // *ecdsa.PrivateKey which is the private key for the reference server. 52 key interface{} 53 } 54 55 var defaultServerCommand = []string{"openssl", "s_server"} 56 57 // connFromCommand starts the reference server process, connects to it and 58 // returns a recordingConn for the connection. The stdin return value is a 59 // blockingSource for the stdin of the child process. It must be closed before 60 // Waiting for child. 61 func (test *clientTest) connFromCommand() (conn *recordingConn, child *exec.Cmd, stdin blockingSource, err error) { 62 cert := testRSACertificate 63 if len(test.cert) > 0 { 64 cert = test.cert 65 } 66 certPath := tempFile(string(cert)) 67 defer os.Remove(certPath) 68 69 var key interface{} = testRSAPrivateKey 70 if test.key != nil { 71 key = test.key 72 } 73 var pemType string 74 var derBytes []byte 75 switch key := key.(type) { 76 case *rsa.PrivateKey: 77 pemType = "RSA" 78 derBytes = x509.MarshalPKCS1PrivateKey(key) 79 case *ecdsa.PrivateKey: 80 pemType = "EC" 81 var err error 82 derBytes, err = x509.MarshalECPrivateKey(key) 83 if err != nil { 84 panic(err) 85 } 86 default: 87 panic("unknown key type") 88 } 89 90 var pemOut bytes.Buffer 91 pem.Encode(&pemOut, &pem.Block{Type: pemType + " PRIVATE KEY", Bytes: derBytes}) 92 93 keyPath := tempFile(string(pemOut.Bytes())) 94 defer os.Remove(keyPath) 95 96 var command []string 97 if len(test.command) > 0 { 98 command = append(command, test.command...) 99 } else { 100 command = append(command, defaultServerCommand...) 101 } 102 command = append(command, "-cert", certPath, "-certform", "DER", "-key", keyPath) 103 // serverPort contains the port that OpenSSL will listen on. OpenSSL 104 // can't take "0" as an argument here so we have to pick a number and 105 // hope that it's not in use on the machine. Since this only occurs 106 // when -update is given and thus when there's a human watching the 107 // test, this isn't too bad. 108 const serverPort = 24323 109 command = append(command, "-accept", strconv.Itoa(serverPort)) 110 111 cmd := exec.Command(command[0], command[1:]...) 112 stdin = blockingSource(make(chan bool)) 113 cmd.Stdin = stdin 114 var out bytes.Buffer 115 cmd.Stdout = &out 116 cmd.Stderr = &out 117 if err := cmd.Start(); err != nil { 118 return nil, nil, nil, err 119 } 120 121 // OpenSSL does print an "ACCEPT" banner, but it does so *before* 122 // opening the listening socket, so we can't use that to wait until it 123 // has started listening. Thus we are forced to poll until we get a 124 // connection. 125 var tcpConn net.Conn 126 for i := uint(0); i < 5; i++ { 127 var err error 128 tcpConn, err = net.DialTCP("tcp", nil, &net.TCPAddr{ 129 IP: net.IPv4(127, 0, 0, 1), 130 Port: serverPort, 131 }) 132 if err == nil { 133 break 134 } 135 time.Sleep((1 << i) * 5 * time.Millisecond) 136 } 137 if tcpConn == nil { 138 close(stdin) 139 out.WriteTo(os.Stdout) 140 cmd.Process.Kill() 141 return nil, nil, nil, cmd.Wait() 142 } 143 144 record := &recordingConn{ 145 Conn: tcpConn, 146 } 147 148 return record, cmd, stdin, nil 149 } 150 151 func (test *clientTest) dataPath() string { 152 return filepath.Join("testdata", "Client-"+test.name) 153 } 154 155 func (test *clientTest) loadData() (flows [][]byte, err error) { 156 in, err := os.Open(test.dataPath()) 157 if err != nil { 158 return nil, err 159 } 160 defer in.Close() 161 return parseTestData(in) 162 } 163 164 func (test *clientTest) run(t *testing.T, write bool) { 165 var clientConn, serverConn net.Conn 166 var recordingConn *recordingConn 167 var childProcess *exec.Cmd 168 var stdin blockingSource 169 170 if write { 171 var err error 172 recordingConn, childProcess, stdin, err = test.connFromCommand() 173 if err != nil { 174 t.Fatalf("Failed to start subcommand: %s", err) 175 } 176 clientConn = recordingConn 177 } else { 178 clientConn, serverConn = net.Pipe() 179 } 180 181 config := test.config 182 if config == nil { 183 config = testConfig 184 } 185 client := Client(clientConn, config) 186 187 doneChan := make(chan bool) 188 go func() { 189 if _, err := client.Write([]byte("hello\n")); err != nil { 190 t.Logf("Client.Write failed: %s", err) 191 } 192 client.Close() 193 clientConn.Close() 194 doneChan <- true 195 }() 196 197 if !write { 198 flows, err := test.loadData() 199 if err != nil { 200 t.Fatalf("%s: failed to load data from %s", test.name, test.dataPath()) 201 } 202 for i, b := range flows { 203 if i%2 == 1 { 204 serverConn.Write(b) 205 continue 206 } 207 bb := make([]byte, len(b)) 208 _, err := io.ReadFull(serverConn, bb) 209 if err != nil { 210 t.Fatalf("%s #%d: %s", test.name, i, err) 211 } 212 if !bytes.Equal(b, bb) { 213 t.Fatalf("%s #%d: mismatch on read: got:%x want:%x", test.name, i, bb, b) 214 } 215 } 216 serverConn.Close() 217 } 218 219 <-doneChan 220 221 if write { 222 path := test.dataPath() 223 out, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644) 224 if err != nil { 225 t.Fatalf("Failed to create output file: %s", err) 226 } 227 defer out.Close() 228 recordingConn.Close() 229 close(stdin) 230 childProcess.Process.Kill() 231 childProcess.Wait() 232 if len(recordingConn.flows) < 3 { 233 childProcess.Stdout.(*bytes.Buffer).WriteTo(os.Stdout) 234 t.Fatalf("Client connection didn't work") 235 } 236 recordingConn.WriteTo(out) 237 fmt.Printf("Wrote %s\n", path) 238 } 239 } 240 241 func runClientTestForVersion(t *testing.T, template *clientTest, prefix, option string) { 242 test := *template 243 test.name = prefix + test.name 244 if len(test.command) == 0 { 245 test.command = defaultClientCommand 246 } 247 test.command = append([]string(nil), test.command...) 248 test.command = append(test.command, option) 249 test.run(t, *update) 250 } 251 252 func runClientTestTLS10(t *testing.T, template *clientTest) { 253 runClientTestForVersion(t, template, "TLSv10-", "-tls1") 254 } 255 256 func runClientTestTLS11(t *testing.T, template *clientTest) { 257 runClientTestForVersion(t, template, "TLSv11-", "-tls1_1") 258 } 259 260 func runClientTestTLS12(t *testing.T, template *clientTest) { 261 runClientTestForVersion(t, template, "TLSv12-", "-tls1_2") 262 } 263 264 //func TestHandshakeClientRSARC4(t *testing.T) { 265 // test := &clientTest{ 266 // name: "RSA-RC4", 267 // command: []string{"openssl", "s_server", "-cipher", "RC4-SHA"}, 268 // } 269 // runClientTestTLS10(t, test) 270 // runClientTestTLS11(t, test) 271 // runClientTestTLS12(t, test) 272 //} 273 // 274 //func TestHandshakeClientECDHERSAAES(t *testing.T) { 275 // test := &clientTest{ 276 // name: "ECDHE-RSA-AES", 277 // command: []string{"openssl", "s_server", "-cipher", "ECDHE-RSA-AES128-SHA"}, 278 // } 279 // runClientTestTLS10(t, test) 280 // runClientTestTLS11(t, test) 281 // runClientTestTLS12(t, test) 282 //} 283 // 284 //func TestHandshakeClientECDHEECDSAAES(t *testing.T) { 285 // test := &clientTest{ 286 // name: "ECDHE-ECDSA-AES", 287 // command: []string{"openssl", "s_server", "-cipher", "ECDHE-ECDSA-AES128-SHA"}, 288 // cert: testECDSACertificate, 289 // key: testECDSAPrivateKey, 290 // } 291 // runClientTestTLS10(t, test) 292 // runClientTestTLS11(t, test) 293 // runClientTestTLS12(t, test) 294 //} 295 // 296 //func TestHandshakeClientECDHEECDSAAESGCM(t *testing.T) { 297 // test := &clientTest{ 298 // name: "ECDHE-ECDSA-AES-GCM", 299 // command: []string{"openssl", "s_server", "-cipher", "ECDHE-ECDSA-AES128-GCM-SHA256"}, 300 // cert: testECDSACertificate, 301 // key: testECDSAPrivateKey, 302 // } 303 // runClientTestTLS12(t, test) 304 //} 305 // 306 //func TestHandshakeClientCertRSA(t *testing.T) { 307 // config := *testConfig 308 // cert, _ := X509KeyPair([]byte(clientCertificatePEM), []byte(clientKeyPEM)) 309 // config.Certificates = []Certificate{cert} 310 // 311 // test := &clientTest{ 312 // name: "ClientCert-RSA-RSA", 313 // command: []string{"openssl", "s_server", "-cipher", "RC4-SHA", "-verify", "1"}, 314 // config: &config, 315 // } 316 // 317 // runClientTestTLS10(t, test) 318 // runClientTestTLS12(t, test) 319 // 320 // test = &clientTest{ 321 // name: "ClientCert-RSA-ECDSA", 322 // command: []string{"openssl", "s_server", "-cipher", "ECDHE-ECDSA-AES128-SHA", "-verify", "1"}, 323 // config: &config, 324 // cert: testECDSACertificate, 325 // key: testECDSAPrivateKey, 326 // } 327 // 328 // runClientTestTLS10(t, test) 329 // runClientTestTLS12(t, test) 330 //} 331 332 // TODO: figure out why this test is failing 333 //func TestHandshakeClientCertECDSA(t *testing.T) { 334 // config := *testConfig 335 // cert, _ := X509KeyPair([]byte(clientECDSACertificatePEM), []byte(clientECDSAKeyPEM)) 336 // config.Certificates = []Certificate{cert} 337 // 338 // test := &clientTest{ 339 // name: "ClientCert-ECDSA-RSA", 340 // command: []string{"openssl", "s_server", "-cipher", "RC4-SHA", "-verify", "1"}, 341 // config: &config, 342 // } 343 // 344 // runClientTestTLS10(t, test) 345 // runClientTestTLS12(t, test) 346 // 347 // test = &clientTest{ 348 // name: "ClientCert-ECDSA-ECDSA", 349 // command: []string{"openssl", "s_server", "-cipher", "ECDHE-ECDSA-AES128-SHA", "-verify", "1"}, 350 // config: &config, 351 // cert: testECDSACertificate, 352 // key: testECDSAPrivateKey, 353 // } 354 // 355 // runClientTestTLS10(t, test) 356 // runClientTestTLS12(t, test) 357 //} 358 359 func TestClientResumption(t *testing.T) { 360 serverConfig := &Config{ 361 CipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA}, 362 Certificates: testConfig.Certificates, 363 } 364 clientConfig := &Config{ 365 CipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA}, 366 InsecureSkipVerify: true, 367 ClientSessionCache: NewLRUClientSessionCache(32), 368 } 369 370 testResumeState := func(test string, didResume bool) { 371 hs, err := testHandshake(clientConfig, serverConfig) 372 if err != nil { 373 t.Fatalf("%s: handshake failed: %s", test, err) 374 } 375 if hs.DidResume != didResume { 376 t.Fatalf("%s resumed: %v, expected: %v", test, hs.DidResume, didResume) 377 } 378 } 379 380 testResumeState("Handshake", false) 381 testResumeState("Resume", true) 382 383 if _, err := io.ReadFull(serverConfig.rand(), serverConfig.SessionTicketKey[:]); err != nil { 384 t.Fatalf("Failed to invalidate SessionTicketKey") 385 } 386 testResumeState("InvalidSessionTicketKey", false) 387 testResumeState("ResumeAfterInvalidSessionTicketKey", true) 388 389 clientConfig.CipherSuites = []uint16{TLS_ECDHE_RSA_WITH_RC4_128_SHA} 390 testResumeState("DifferentCipherSuite", false) 391 testResumeState("DifferentCipherSuiteRecovers", true) 392 393 clientConfig.ClientSessionCache = nil 394 testResumeState("WithoutSessionCache", false) 395 } 396 397 func TestLRUClientSessionCache(t *testing.T) { 398 // Initialize cache of capacity 4. 399 cache := NewLRUClientSessionCache(4) 400 cs := make([]ClientSessionState, 6) 401 keys := []string{"0", "1", "2", "3", "4", "5", "6"} 402 403 // Add 4 entries to the cache and look them up. 404 for i := 0; i < 4; i++ { 405 cache.Put(keys[i], &cs[i]) 406 } 407 for i := 0; i < 4; i++ { 408 if s, ok := cache.Get(keys[i]); !ok || s != &cs[i] { 409 t.Fatalf("session cache failed lookup for added key: %s", keys[i]) 410 } 411 } 412 413 // Add 2 more entries to the cache. First 2 should be evicted. 414 for i := 4; i < 6; i++ { 415 cache.Put(keys[i], &cs[i]) 416 } 417 for i := 0; i < 2; i++ { 418 if s, ok := cache.Get(keys[i]); ok || s != nil { 419 t.Fatalf("session cache should have evicted key: %s", keys[i]) 420 } 421 } 422 423 // Touch entry 2. LRU should evict 3 next. 424 cache.Get(keys[2]) 425 cache.Put(keys[0], &cs[0]) 426 if s, ok := cache.Get(keys[3]); ok || s != nil { 427 t.Fatalf("session cache should have evicted key 3") 428 } 429 430 // Update entry 0 in place. 431 cache.Put(keys[0], &cs[3]) 432 if s, ok := cache.Get(keys[0]); !ok || s != &cs[3] { 433 t.Fatalf("session cache failed update for key 0") 434 } 435 436 // Adding a nil entry is valid. 437 cache.Put(keys[0], nil) 438 if s, ok := cache.Get(keys[0]); !ok || s != nil { 439 t.Fatalf("failed to add nil entry to cache") 440 } 441 } 442 443 // Test the custom client hello feature by imitating a Firefox ClientHello message 444 func TestHandshakeClientCustomHello(t *testing.T) { 445 hello := ClientFingerprintConfiguration{} 446 hello.HandshakeVersion = 0x0303 447 448 hello.CipherSuites = []uint16{ 449 TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, 450 TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, 451 TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, 452 TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, 453 TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, 454 TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, 455 TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, 456 TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, 457 TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, 458 TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, 459 TLS_DHE_RSA_WITH_AES_128_CBC_SHA, 460 TLS_DHE_RSA_WITH_AES_256_CBC_SHA, 461 TLS_RSA_WITH_AES_128_CBC_SHA, 462 TLS_RSA_WITH_AES_256_CBC_SHA, 463 TLS_RSA_WITH_3DES_EDE_CBC_SHA, 464 } 465 hello.CompressionMethods = []uint8{0} 466 sni := SNIExtension{[]string{}, true} 467 ec := SupportedCurvesExtension{[]CurveID{CurveP256, CurveP384, CurveP521}} 468 points := PointFormatExtension{[]uint8{0}} 469 st := SessionTicketExtension{[]byte{}, true} 470 alpn := ALPNExtension{[]string{"h2", "http/1.1"}} 471 sigs := SignatureAlgorithmExtension{[]uint16{0x0401, 472 0x0501, 473 0x0601, 474 0x0201, 475 0x0403, 476 0x0503, 477 0x0603, 478 0x0203, 479 0x0502, 480 0x0402, 481 0x0202, 482 }} 483 484 hello.Extensions = []ClientExtension{&sni, 485 &ExtendedMasterSecretExtension{}, 486 &SecureRenegotiationExtension{}, 487 &ec, 488 &points, 489 &st, 490 &NextProtocolNegotiationExtension{}, 491 &alpn, 492 &StatusRequestExtension{}, 493 &sigs, 494 } 495 config := *testConfig 496 config.ClientFingerprintConfiguration = &hello 497 test := &clientTest{ 498 name: "ClientFingerprint", 499 command: []string{"openssl", "s_server"}, 500 config: &config, 501 } 502 runClientTestTLS12(t, test) 503 } 504 505 // writeCountingConn wraps a net.Conn and counts the number of Write calls. 506 type writeCountingConn struct { 507 net.Conn 508 509 // numWrites is the number of writes that have been done. 510 numWrites int 511 } 512 513 func (wcc *writeCountingConn) Write(data []byte) (int, error) { 514 wcc.numWrites++ 515 return wcc.Conn.Write(data) 516 } 517 518 func TestBuffering(t *testing.T) { 519 c, s := net.Pipe() 520 done := make(chan bool) 521 522 clientWCC := &writeCountingConn{Conn: c} 523 serverWCC := &writeCountingConn{Conn: s} 524 525 go func() { 526 Server(serverWCC, testConfig).Handshake() 527 serverWCC.Close() 528 done <- true 529 }() 530 531 err := Client(clientWCC, testConfig).Handshake() 532 if err != nil { 533 t.Fatal(err) 534 } 535 clientWCC.Close() 536 <-done 537 538 if n := clientWCC.numWrites; n != 2 { 539 t.Errorf("expected client handshake to complete with only two writes, but saw %d", n) 540 } 541 542 if n := serverWCC.numWrites; n != 2 { 543 t.Errorf("expected server handshake to complete with only two writes, but saw %d", n) 544 } 545 } 546 547 func TestDontBuffer(t *testing.T) { 548 c, s := net.Pipe() 549 done := make(chan bool) 550 551 clientWCC := &writeCountingConn{Conn: c} 552 serverWCC := &writeCountingConn{Conn: s} 553 testConfig.DontBufferHandshakes = true 554 defer func() { 555 testConfig.DontBufferHandshakes = false 556 }() 557 go func() { 558 Server(serverWCC, testConfig).Handshake() 559 serverWCC.Close() 560 done <- true 561 }() 562 563 err := Client(clientWCC, testConfig).Handshake() 564 if err != nil { 565 t.Fatal(err) 566 } 567 clientWCC.Close() 568 <-done 569 570 if n := clientWCC.numWrites; n != 4 { 571 t.Errorf("expected client handshake to complete with only two writes, but saw %d", n) 572 } 573 574 if n := serverWCC.numWrites; n != 6 { 575 t.Errorf("expected server handshake to complete with only two writes, but saw %d", n) 576 } 577 }