github.com/ader1990/go@v0.0.0-20140630135419-8c24447fa791/src/pkg/crypto/tls/handshake_client_test.go (about) 1 // Copyright 2010 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package tls 6 7 import ( 8 "bytes" 9 "crypto/ecdsa" 10 "crypto/rsa" 11 "crypto/x509" 12 "encoding/pem" 13 "fmt" 14 "io" 15 "net" 16 "os" 17 "os/exec" 18 "path/filepath" 19 "strconv" 20 "testing" 21 "time" 22 ) 23 24 // Note: see comment in handshake_test.go for details of how the reference 25 // tests work. 26 27 // blockingSource is an io.Reader that blocks a Read call until it's closed. 28 type blockingSource chan bool 29 30 func (b blockingSource) Read([]byte) (n int, err error) { 31 <-b 32 return 0, io.EOF 33 } 34 35 // clientTest represents a test of the TLS client handshake against a reference 36 // implementation. 37 type clientTest struct { 38 // name is a freeform string identifying the test and the file in which 39 // the expected results will be stored. 40 name string 41 // command, if not empty, contains a series of arguments for the 42 // command to run for the reference server. 43 command []string 44 // config, if not nil, contains a custom Config to use for this test. 45 config *Config 46 // cert, if not empty, contains a DER-encoded certificate for the 47 // reference server. 48 cert []byte 49 // key, if not nil, contains either a *rsa.PrivateKey or 50 // *ecdsa.PrivateKey which is the private key for the reference server. 51 key interface{} 52 } 53 54 var defaultServerCommand = []string{"openssl", "s_server"} 55 56 // connFromCommand starts the reference server process, connects to it and 57 // returns a recordingConn for the connection. The stdin return value is a 58 // blockingSource for the stdin of the child process. It must be closed before 59 // Waiting for child. 60 func (test *clientTest) connFromCommand() (conn *recordingConn, child *exec.Cmd, stdin blockingSource, err error) { 61 cert := testRSACertificate 62 if len(test.cert) > 0 { 63 cert = test.cert 64 } 65 certPath := tempFile(string(cert)) 66 defer os.Remove(certPath) 67 68 var key interface{} = testRSAPrivateKey 69 if test.key != nil { 70 key = test.key 71 } 72 var pemType string 73 var derBytes []byte 74 switch key := key.(type) { 75 case *rsa.PrivateKey: 76 pemType = "RSA" 77 derBytes = x509.MarshalPKCS1PrivateKey(key) 78 case *ecdsa.PrivateKey: 79 pemType = "EC" 80 var err error 81 derBytes, err = x509.MarshalECPrivateKey(key) 82 if err != nil { 83 panic(err) 84 } 85 default: 86 panic("unknown key type") 87 } 88 89 var pemOut bytes.Buffer 90 pem.Encode(&pemOut, &pem.Block{Type: pemType + " PRIVATE KEY", Bytes: derBytes}) 91 92 keyPath := tempFile(string(pemOut.Bytes())) 93 defer os.Remove(keyPath) 94 95 var command []string 96 if len(test.command) > 0 { 97 command = append(command, test.command...) 98 } else { 99 command = append(command, defaultServerCommand...) 100 } 101 command = append(command, "-cert", certPath, "-certform", "DER", "-key", keyPath) 102 // serverPort contains the port that OpenSSL will listen on. OpenSSL 103 // can't take "0" as an argument here so we have to pick a number and 104 // hope that it's not in use on the machine. Since this only occurs 105 // when -update is given and thus when there's a human watching the 106 // test, this isn't too bad. 107 const serverPort = 24323 108 command = append(command, "-accept", strconv.Itoa(serverPort)) 109 110 cmd := exec.Command(command[0], command[1:]...) 111 stdin = blockingSource(make(chan bool)) 112 cmd.Stdin = stdin 113 var out bytes.Buffer 114 cmd.Stdout = &out 115 cmd.Stderr = &out 116 if err := cmd.Start(); err != nil { 117 return nil, nil, nil, err 118 } 119 120 // OpenSSL does print an "ACCEPT" banner, but it does so *before* 121 // opening the listening socket, so we can't use that to wait until it 122 // has started listening. Thus we are forced to poll until we get a 123 // connection. 124 var tcpConn net.Conn 125 for i := uint(0); i < 5; i++ { 126 var err error 127 tcpConn, err = net.DialTCP("tcp", nil, &net.TCPAddr{ 128 IP: net.IPv4(127, 0, 0, 1), 129 Port: serverPort, 130 }) 131 if err == nil { 132 break 133 } 134 time.Sleep((1 << i) * 5 * time.Millisecond) 135 } 136 if tcpConn == nil { 137 close(stdin) 138 out.WriteTo(os.Stdout) 139 cmd.Process.Kill() 140 return nil, nil, nil, cmd.Wait() 141 } 142 143 record := &recordingConn{ 144 Conn: tcpConn, 145 } 146 147 return record, cmd, stdin, nil 148 } 149 150 func (test *clientTest) dataPath() string { 151 return filepath.Join("testdata", "Client-"+test.name) 152 } 153 154 func (test *clientTest) loadData() (flows [][]byte, err error) { 155 in, err := os.Open(test.dataPath()) 156 if err != nil { 157 return nil, err 158 } 159 defer in.Close() 160 return parseTestData(in) 161 } 162 163 func (test *clientTest) run(t *testing.T, write bool) { 164 var clientConn, serverConn net.Conn 165 var recordingConn *recordingConn 166 var childProcess *exec.Cmd 167 var stdin blockingSource 168 169 if write { 170 var err error 171 recordingConn, childProcess, stdin, err = test.connFromCommand() 172 if err != nil { 173 t.Fatalf("Failed to start subcommand: %s", err) 174 } 175 clientConn = recordingConn 176 } else { 177 clientConn, serverConn = net.Pipe() 178 } 179 180 config := test.config 181 if config == nil { 182 config = testConfig 183 } 184 client := Client(clientConn, config) 185 186 doneChan := make(chan bool) 187 go func() { 188 if _, err := client.Write([]byte("hello\n")); err != nil { 189 t.Logf("Client.Write failed: %s", err) 190 } 191 client.Close() 192 clientConn.Close() 193 doneChan <- true 194 }() 195 196 if !write { 197 flows, err := test.loadData() 198 if err != nil { 199 t.Fatalf("%s: failed to load data from %s", test.name, test.dataPath()) 200 } 201 for i, b := range flows { 202 if i%2 == 1 { 203 serverConn.Write(b) 204 continue 205 } 206 bb := make([]byte, len(b)) 207 _, err := io.ReadFull(serverConn, bb) 208 if err != nil { 209 t.Fatalf("%s #%d: %s", test.name, i, err) 210 } 211 if !bytes.Equal(b, bb) { 212 t.Fatalf("%s #%d: mismatch on read: got:%x want:%x", test.name, i, bb, b) 213 } 214 } 215 serverConn.Close() 216 } 217 218 <-doneChan 219 220 if write { 221 path := test.dataPath() 222 out, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644) 223 if err != nil { 224 t.Fatalf("Failed to create output file: %s", err) 225 } 226 defer out.Close() 227 recordingConn.Close() 228 close(stdin) 229 childProcess.Process.Kill() 230 childProcess.Wait() 231 if len(recordingConn.flows) < 3 { 232 childProcess.Stdout.(*bytes.Buffer).WriteTo(os.Stdout) 233 t.Fatalf("Client connection didn't work") 234 } 235 recordingConn.WriteTo(out) 236 fmt.Printf("Wrote %s\n", path) 237 } 238 } 239 240 func runClientTestForVersion(t *testing.T, template *clientTest, prefix, option string) { 241 test := *template 242 test.name = prefix + test.name 243 if len(test.command) == 0 { 244 test.command = defaultClientCommand 245 } 246 test.command = append([]string(nil), test.command...) 247 test.command = append(test.command, option) 248 test.run(t, *update) 249 } 250 251 func runClientTestTLS10(t *testing.T, template *clientTest) { 252 runClientTestForVersion(t, template, "TLSv10-", "-tls1") 253 } 254 255 func runClientTestTLS11(t *testing.T, template *clientTest) { 256 runClientTestForVersion(t, template, "TLSv11-", "-tls1_1") 257 } 258 259 func runClientTestTLS12(t *testing.T, template *clientTest) { 260 runClientTestForVersion(t, template, "TLSv12-", "-tls1_2") 261 } 262 263 func TestHandshakeClientRSARC4(t *testing.T) { 264 test := &clientTest{ 265 name: "RSA-RC4", 266 command: []string{"openssl", "s_server", "-cipher", "RC4-SHA"}, 267 } 268 runClientTestTLS10(t, test) 269 runClientTestTLS11(t, test) 270 runClientTestTLS12(t, test) 271 } 272 273 func TestHandshakeClientECDHERSAAES(t *testing.T) { 274 test := &clientTest{ 275 name: "ECDHE-RSA-AES", 276 command: []string{"openssl", "s_server", "-cipher", "ECDHE-RSA-AES128-SHA"}, 277 } 278 runClientTestTLS10(t, test) 279 runClientTestTLS11(t, test) 280 runClientTestTLS12(t, test) 281 } 282 283 func TestHandshakeClientECDHEECDSAAES(t *testing.T) { 284 test := &clientTest{ 285 name: "ECDHE-ECDSA-AES", 286 command: []string{"openssl", "s_server", "-cipher", "ECDHE-ECDSA-AES128-SHA"}, 287 cert: testECDSACertificate, 288 key: testECDSAPrivateKey, 289 } 290 runClientTestTLS10(t, test) 291 runClientTestTLS11(t, test) 292 runClientTestTLS12(t, test) 293 } 294 295 func TestHandshakeClientECDHEECDSAAESGCM(t *testing.T) { 296 test := &clientTest{ 297 name: "ECDHE-ECDSA-AES-GCM", 298 command: []string{"openssl", "s_server", "-cipher", "ECDHE-ECDSA-AES128-GCM-SHA256"}, 299 cert: testECDSACertificate, 300 key: testECDSAPrivateKey, 301 } 302 runClientTestTLS12(t, test) 303 } 304 305 func TestHandshakeClientCertRSA(t *testing.T) { 306 config := *testConfig 307 cert, _ := X509KeyPair([]byte(clientCertificatePEM), []byte(clientKeyPEM)) 308 config.Certificates = []Certificate{cert} 309 310 test := &clientTest{ 311 name: "ClientCert-RSA-RSA", 312 command: []string{"openssl", "s_server", "-cipher", "RC4-SHA", "-verify", "1"}, 313 config: &config, 314 } 315 316 runClientTestTLS10(t, test) 317 runClientTestTLS12(t, test) 318 319 test = &clientTest{ 320 name: "ClientCert-RSA-ECDSA", 321 command: []string{"openssl", "s_server", "-cipher", "ECDHE-ECDSA-AES128-SHA", "-verify", "1"}, 322 config: &config, 323 cert: testECDSACertificate, 324 key: testECDSAPrivateKey, 325 } 326 327 runClientTestTLS10(t, test) 328 runClientTestTLS12(t, test) 329 } 330 331 func TestHandshakeClientCertECDSA(t *testing.T) { 332 config := *testConfig 333 cert, _ := X509KeyPair([]byte(clientECDSACertificatePEM), []byte(clientECDSAKeyPEM)) 334 config.Certificates = []Certificate{cert} 335 336 test := &clientTest{ 337 name: "ClientCert-ECDSA-RSA", 338 command: []string{"openssl", "s_server", "-cipher", "RC4-SHA", "-verify", "1"}, 339 config: &config, 340 } 341 342 runClientTestTLS10(t, test) 343 runClientTestTLS12(t, test) 344 345 test = &clientTest{ 346 name: "ClientCert-ECDSA-ECDSA", 347 command: []string{"openssl", "s_server", "-cipher", "ECDHE-ECDSA-AES128-SHA", "-verify", "1"}, 348 config: &config, 349 cert: testECDSACertificate, 350 key: testECDSAPrivateKey, 351 } 352 353 runClientTestTLS10(t, test) 354 runClientTestTLS12(t, test) 355 } 356 357 func TestClientResumption(t *testing.T) { 358 serverConfig := &Config{ 359 CipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA}, 360 Certificates: testConfig.Certificates, 361 } 362 clientConfig := &Config{ 363 CipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA}, 364 InsecureSkipVerify: true, 365 ClientSessionCache: NewLRUClientSessionCache(32), 366 } 367 368 testResumeState := func(test string, didResume bool) { 369 hs, err := testHandshake(clientConfig, serverConfig) 370 if err != nil { 371 t.Fatalf("%s: handshake failed: %s", test, err) 372 } 373 if hs.DidResume != didResume { 374 t.Fatalf("%s resumed: %v, expected: %v", test, hs.DidResume, didResume) 375 } 376 } 377 378 testResumeState("Handshake", false) 379 testResumeState("Resume", true) 380 381 if _, err := io.ReadFull(serverConfig.rand(), serverConfig.SessionTicketKey[:]); err != nil { 382 t.Fatalf("Failed to invalidate SessionTicketKey") 383 } 384 testResumeState("InvalidSessionTicketKey", false) 385 testResumeState("ResumeAfterInvalidSessionTicketKey", true) 386 387 clientConfig.CipherSuites = []uint16{TLS_ECDHE_RSA_WITH_RC4_128_SHA} 388 testResumeState("DifferentCipherSuite", false) 389 testResumeState("DifferentCipherSuiteRecovers", true) 390 391 clientConfig.ClientSessionCache = nil 392 testResumeState("WithoutSessionCache", false) 393 } 394 395 func TestLRUClientSessionCache(t *testing.T) { 396 // Initialize cache of capacity 4. 397 cache := NewLRUClientSessionCache(4) 398 cs := make([]ClientSessionState, 6) 399 keys := []string{"0", "1", "2", "3", "4", "5", "6"} 400 401 // Add 4 entries to the cache and look them up. 402 for i := 0; i < 4; i++ { 403 cache.Put(keys[i], &cs[i]) 404 } 405 for i := 0; i < 4; i++ { 406 if s, ok := cache.Get(keys[i]); !ok || s != &cs[i] { 407 t.Fatalf("session cache failed lookup for added key: %s", keys[i]) 408 } 409 } 410 411 // Add 2 more entries to the cache. First 2 should be evicted. 412 for i := 4; i < 6; i++ { 413 cache.Put(keys[i], &cs[i]) 414 } 415 for i := 0; i < 2; i++ { 416 if s, ok := cache.Get(keys[i]); ok || s != nil { 417 t.Fatalf("session cache should have evicted key: %s", keys[i]) 418 } 419 } 420 421 // Touch entry 2. LRU should evict 3 next. 422 cache.Get(keys[2]) 423 cache.Put(keys[0], &cs[0]) 424 if s, ok := cache.Get(keys[3]); ok || s != nil { 425 t.Fatalf("session cache should have evicted key 3") 426 } 427 428 // Update entry 0 in place. 429 cache.Put(keys[0], &cs[3]) 430 if s, ok := cache.Get(keys[0]); !ok || s != &cs[3] { 431 t.Fatalf("session cache failed update for key 0") 432 } 433 434 // Adding a nil entry is valid. 435 cache.Put(keys[0], nil) 436 if s, ok := cache.Get(keys[0]); !ok || s != nil { 437 t.Fatalf("failed to add nil entry to cache") 438 } 439 }