github.com/rsc/go@v0.0.0-20150416155037-e040fd465409/src/crypto/tls/handshake_client_test.go (about)

     1  // Copyright 2010 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package tls
     6  
     7  import (
     8  	"bytes"
     9  	"crypto/ecdsa"
    10  	"crypto/rsa"
    11  	"crypto/x509"
    12  	"encoding/pem"
    13  	"fmt"
    14  	"io"
    15  	"net"
    16  	"os"
    17  	"os/exec"
    18  	"path/filepath"
    19  	"strconv"
    20  	"testing"
    21  	"time"
    22  )
    23  
    24  // Note: see comment in handshake_test.go for details of how the reference
    25  // tests work.
    26  
    27  // blockingSource is an io.Reader that blocks a Read call until it's closed.
    28  type blockingSource chan bool
    29  
    30  func (b blockingSource) Read([]byte) (n int, err error) {
    31  	<-b
    32  	return 0, io.EOF
    33  }
    34  
    35  // clientTest represents a test of the TLS client handshake against a reference
    36  // implementation.
    37  type clientTest struct {
    38  	// name is a freeform string identifying the test and the file in which
    39  	// the expected results will be stored.
    40  	name string
    41  	// command, if not empty, contains a series of arguments for the
    42  	// command to run for the reference server.
    43  	command []string
    44  	// config, if not nil, contains a custom Config to use for this test.
    45  	config *Config
    46  	// cert, if not empty, contains a DER-encoded certificate for the
    47  	// reference server.
    48  	cert []byte
    49  	// key, if not nil, contains either a *rsa.PrivateKey or
    50  	// *ecdsa.PrivateKey which is the private key for the reference server.
    51  	key interface{}
    52  	// validate, if not nil, is a function that will be called with the
    53  	// ConnectionState of the resulting connection. It returns a non-nil
    54  	// error if the ConnectionState is unacceptable.
    55  	validate func(ConnectionState) error
    56  }
    57  
    58  var defaultServerCommand = []string{"openssl", "s_server"}
    59  
    60  // connFromCommand starts the reference server process, connects to it and
    61  // returns a recordingConn for the connection. The stdin return value is a
    62  // blockingSource for the stdin of the child process. It must be closed before
    63  // Waiting for child.
    64  func (test *clientTest) connFromCommand() (conn *recordingConn, child *exec.Cmd, stdin blockingSource, err error) {
    65  	cert := testRSACertificate
    66  	if len(test.cert) > 0 {
    67  		cert = test.cert
    68  	}
    69  	certPath := tempFile(string(cert))
    70  	defer os.Remove(certPath)
    71  
    72  	var key interface{} = testRSAPrivateKey
    73  	if test.key != nil {
    74  		key = test.key
    75  	}
    76  	var pemType string
    77  	var derBytes []byte
    78  	switch key := key.(type) {
    79  	case *rsa.PrivateKey:
    80  		pemType = "RSA"
    81  		derBytes = x509.MarshalPKCS1PrivateKey(key)
    82  	case *ecdsa.PrivateKey:
    83  		pemType = "EC"
    84  		var err error
    85  		derBytes, err = x509.MarshalECPrivateKey(key)
    86  		if err != nil {
    87  			panic(err)
    88  		}
    89  	default:
    90  		panic("unknown key type")
    91  	}
    92  
    93  	var pemOut bytes.Buffer
    94  	pem.Encode(&pemOut, &pem.Block{Type: pemType + " PRIVATE KEY", Bytes: derBytes})
    95  
    96  	keyPath := tempFile(string(pemOut.Bytes()))
    97  	defer os.Remove(keyPath)
    98  
    99  	var command []string
   100  	if len(test.command) > 0 {
   101  		command = append(command, test.command...)
   102  	} else {
   103  		command = append(command, defaultServerCommand...)
   104  	}
   105  	command = append(command, "-cert", certPath, "-certform", "DER", "-key", keyPath)
   106  	// serverPort contains the port that OpenSSL will listen on. OpenSSL
   107  	// can't take "0" as an argument here so we have to pick a number and
   108  	// hope that it's not in use on the machine. Since this only occurs
   109  	// when -update is given and thus when there's a human watching the
   110  	// test, this isn't too bad.
   111  	const serverPort = 24323
   112  	command = append(command, "-accept", strconv.Itoa(serverPort))
   113  
   114  	cmd := exec.Command(command[0], command[1:]...)
   115  	stdin = blockingSource(make(chan bool))
   116  	cmd.Stdin = stdin
   117  	var out bytes.Buffer
   118  	cmd.Stdout = &out
   119  	cmd.Stderr = &out
   120  	if err := cmd.Start(); err != nil {
   121  		return nil, nil, nil, err
   122  	}
   123  
   124  	// OpenSSL does print an "ACCEPT" banner, but it does so *before*
   125  	// opening the listening socket, so we can't use that to wait until it
   126  	// has started listening. Thus we are forced to poll until we get a
   127  	// connection.
   128  	var tcpConn net.Conn
   129  	for i := uint(0); i < 5; i++ {
   130  		tcpConn, err = net.DialTCP("tcp", nil, &net.TCPAddr{
   131  			IP:   net.IPv4(127, 0, 0, 1),
   132  			Port: serverPort,
   133  		})
   134  		if err == nil {
   135  			break
   136  		}
   137  		time.Sleep((1 << i) * 5 * time.Millisecond)
   138  	}
   139  	if err != nil {
   140  		close(stdin)
   141  		out.WriteTo(os.Stdout)
   142  		cmd.Process.Kill()
   143  		return nil, nil, nil, cmd.Wait()
   144  	}
   145  
   146  	record := &recordingConn{
   147  		Conn: tcpConn,
   148  	}
   149  
   150  	return record, cmd, stdin, nil
   151  }
   152  
   153  func (test *clientTest) dataPath() string {
   154  	return filepath.Join("testdata", "Client-"+test.name)
   155  }
   156  
   157  func (test *clientTest) loadData() (flows [][]byte, err error) {
   158  	in, err := os.Open(test.dataPath())
   159  	if err != nil {
   160  		return nil, err
   161  	}
   162  	defer in.Close()
   163  	return parseTestData(in)
   164  }
   165  
   166  func (test *clientTest) run(t *testing.T, write bool) {
   167  	var clientConn, serverConn net.Conn
   168  	var recordingConn *recordingConn
   169  	var childProcess *exec.Cmd
   170  	var stdin blockingSource
   171  
   172  	if write {
   173  		var err error
   174  		recordingConn, childProcess, stdin, err = test.connFromCommand()
   175  		if err != nil {
   176  			t.Fatalf("Failed to start subcommand: %s", err)
   177  		}
   178  		clientConn = recordingConn
   179  	} else {
   180  		clientConn, serverConn = net.Pipe()
   181  	}
   182  
   183  	config := test.config
   184  	if config == nil {
   185  		config = testConfig
   186  	}
   187  	client := Client(clientConn, config)
   188  
   189  	doneChan := make(chan bool)
   190  	go func() {
   191  		if _, err := client.Write([]byte("hello\n")); err != nil {
   192  			t.Errorf("Client.Write failed: %s", err)
   193  		}
   194  		if test.validate != nil {
   195  			if err := test.validate(client.ConnectionState()); err != nil {
   196  				t.Logf("validate callback returned error: %s", err)
   197  			}
   198  		}
   199  		client.Close()
   200  		clientConn.Close()
   201  		doneChan <- true
   202  	}()
   203  
   204  	if !write {
   205  		flows, err := test.loadData()
   206  		if err != nil {
   207  			t.Fatalf("%s: failed to load data from %s: %v", test.name, test.dataPath(), err)
   208  		}
   209  		for i, b := range flows {
   210  			if i%2 == 1 {
   211  				serverConn.Write(b)
   212  				continue
   213  			}
   214  			bb := make([]byte, len(b))
   215  			_, err := io.ReadFull(serverConn, bb)
   216  			if err != nil {
   217  				t.Fatalf("%s #%d: %s", test.name, i, err)
   218  			}
   219  			if !bytes.Equal(b, bb) {
   220  				t.Fatalf("%s #%d: mismatch on read: got:%x want:%x", test.name, i, bb, b)
   221  			}
   222  		}
   223  		serverConn.Close()
   224  	}
   225  
   226  	<-doneChan
   227  
   228  	if write {
   229  		path := test.dataPath()
   230  		out, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
   231  		if err != nil {
   232  			t.Fatalf("Failed to create output file: %s", err)
   233  		}
   234  		defer out.Close()
   235  		recordingConn.Close()
   236  		close(stdin)
   237  		childProcess.Process.Kill()
   238  		childProcess.Wait()
   239  		if len(recordingConn.flows) < 3 {
   240  			childProcess.Stdout.(*bytes.Buffer).WriteTo(os.Stdout)
   241  			t.Fatalf("Client connection didn't work")
   242  		}
   243  		recordingConn.WriteTo(out)
   244  		fmt.Printf("Wrote %s\n", path)
   245  	}
   246  }
   247  
   248  func runClientTestForVersion(t *testing.T, template *clientTest, prefix, option string) {
   249  	test := *template
   250  	test.name = prefix + test.name
   251  	if len(test.command) == 0 {
   252  		test.command = defaultClientCommand
   253  	}
   254  	test.command = append([]string(nil), test.command...)
   255  	test.command = append(test.command, option)
   256  	test.run(t, *update)
   257  }
   258  
   259  func runClientTestTLS10(t *testing.T, template *clientTest) {
   260  	runClientTestForVersion(t, template, "TLSv10-", "-tls1")
   261  }
   262  
   263  func runClientTestTLS11(t *testing.T, template *clientTest) {
   264  	runClientTestForVersion(t, template, "TLSv11-", "-tls1_1")
   265  }
   266  
   267  func runClientTestTLS12(t *testing.T, template *clientTest) {
   268  	runClientTestForVersion(t, template, "TLSv12-", "-tls1_2")
   269  }
   270  
   271  func TestHandshakeClientRSARC4(t *testing.T) {
   272  	test := &clientTest{
   273  		name:    "RSA-RC4",
   274  		command: []string{"openssl", "s_server", "-cipher", "RC4-SHA"},
   275  	}
   276  	runClientTestTLS10(t, test)
   277  	runClientTestTLS11(t, test)
   278  	runClientTestTLS12(t, test)
   279  }
   280  
   281  func TestHandshakeClientECDHERSAAES(t *testing.T) {
   282  	test := &clientTest{
   283  		name:    "ECDHE-RSA-AES",
   284  		command: []string{"openssl", "s_server", "-cipher", "ECDHE-RSA-AES128-SHA"},
   285  	}
   286  	runClientTestTLS10(t, test)
   287  	runClientTestTLS11(t, test)
   288  	runClientTestTLS12(t, test)
   289  }
   290  
   291  func TestHandshakeClientECDHEECDSAAES(t *testing.T) {
   292  	test := &clientTest{
   293  		name:    "ECDHE-ECDSA-AES",
   294  		command: []string{"openssl", "s_server", "-cipher", "ECDHE-ECDSA-AES128-SHA"},
   295  		cert:    testECDSACertificate,
   296  		key:     testECDSAPrivateKey,
   297  	}
   298  	runClientTestTLS10(t, test)
   299  	runClientTestTLS11(t, test)
   300  	runClientTestTLS12(t, test)
   301  }
   302  
   303  func TestHandshakeClientECDHEECDSAAESGCM(t *testing.T) {
   304  	test := &clientTest{
   305  		name:    "ECDHE-ECDSA-AES-GCM",
   306  		command: []string{"openssl", "s_server", "-cipher", "ECDHE-ECDSA-AES128-GCM-SHA256"},
   307  		cert:    testECDSACertificate,
   308  		key:     testECDSAPrivateKey,
   309  	}
   310  	runClientTestTLS12(t, test)
   311  }
   312  
   313  func TestHandshakeClientAES256GCMSHA384(t *testing.T) {
   314  	test := &clientTest{
   315  		name:    "ECDHE-ECDSA-AES256-GCM-SHA384",
   316  		command: []string{"openssl", "s_server", "-cipher", "ECDHE-ECDSA-AES256-GCM-SHA384"},
   317  		cert:    testECDSACertificate,
   318  		key:     testECDSAPrivateKey,
   319  	}
   320  	runClientTestTLS12(t, test)
   321  }
   322  
   323  func TestHandshakeClientCertRSA(t *testing.T) {
   324  	config := *testConfig
   325  	cert, _ := X509KeyPair([]byte(clientCertificatePEM), []byte(clientKeyPEM))
   326  	config.Certificates = []Certificate{cert}
   327  
   328  	test := &clientTest{
   329  		name:    "ClientCert-RSA-RSA",
   330  		command: []string{"openssl", "s_server", "-cipher", "RC4-SHA", "-verify", "1"},
   331  		config:  &config,
   332  	}
   333  
   334  	runClientTestTLS10(t, test)
   335  	runClientTestTLS12(t, test)
   336  
   337  	test = &clientTest{
   338  		name:    "ClientCert-RSA-ECDSA",
   339  		command: []string{"openssl", "s_server", "-cipher", "ECDHE-ECDSA-AES128-SHA", "-verify", "1"},
   340  		config:  &config,
   341  		cert:    testECDSACertificate,
   342  		key:     testECDSAPrivateKey,
   343  	}
   344  
   345  	runClientTestTLS10(t, test)
   346  	runClientTestTLS12(t, test)
   347  
   348  	test = &clientTest{
   349  		name:    "ClientCert-RSA-AES256-GCM-SHA384",
   350  		command: []string{"openssl", "s_server", "-cipher", "ECDHE-RSA-AES256-GCM-SHA384", "-verify", "1"},
   351  		config:  &config,
   352  		cert:    testRSACertificate,
   353  		key:     testRSAPrivateKey,
   354  	}
   355  
   356  	runClientTestTLS12(t, test)
   357  }
   358  
   359  func TestHandshakeClientCertECDSA(t *testing.T) {
   360  	config := *testConfig
   361  	cert, _ := X509KeyPair([]byte(clientECDSACertificatePEM), []byte(clientECDSAKeyPEM))
   362  	config.Certificates = []Certificate{cert}
   363  
   364  	test := &clientTest{
   365  		name:    "ClientCert-ECDSA-RSA",
   366  		command: []string{"openssl", "s_server", "-cipher", "RC4-SHA", "-verify", "1"},
   367  		config:  &config,
   368  	}
   369  
   370  	runClientTestTLS10(t, test)
   371  	runClientTestTLS12(t, test)
   372  
   373  	test = &clientTest{
   374  		name:    "ClientCert-ECDSA-ECDSA",
   375  		command: []string{"openssl", "s_server", "-cipher", "ECDHE-ECDSA-AES128-SHA", "-verify", "1"},
   376  		config:  &config,
   377  		cert:    testECDSACertificate,
   378  		key:     testECDSAPrivateKey,
   379  	}
   380  
   381  	runClientTestTLS10(t, test)
   382  	runClientTestTLS12(t, test)
   383  }
   384  
   385  func TestClientResumption(t *testing.T) {
   386  	serverConfig := &Config{
   387  		CipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA},
   388  		Certificates: testConfig.Certificates,
   389  	}
   390  	clientConfig := &Config{
   391  		CipherSuites:       []uint16{TLS_RSA_WITH_RC4_128_SHA},
   392  		InsecureSkipVerify: true,
   393  		ClientSessionCache: NewLRUClientSessionCache(32),
   394  	}
   395  
   396  	testResumeState := func(test string, didResume bool) {
   397  		hs, err := testHandshake(clientConfig, serverConfig)
   398  		if err != nil {
   399  			t.Fatalf("%s: handshake failed: %s", test, err)
   400  		}
   401  		if hs.DidResume != didResume {
   402  			t.Fatalf("%s resumed: %v, expected: %v", test, hs.DidResume, didResume)
   403  		}
   404  	}
   405  
   406  	testResumeState("Handshake", false)
   407  	testResumeState("Resume", true)
   408  
   409  	if _, err := io.ReadFull(serverConfig.rand(), serverConfig.SessionTicketKey[:]); err != nil {
   410  		t.Fatalf("Failed to invalidate SessionTicketKey")
   411  	}
   412  	testResumeState("InvalidSessionTicketKey", false)
   413  	testResumeState("ResumeAfterInvalidSessionTicketKey", true)
   414  
   415  	clientConfig.CipherSuites = []uint16{TLS_ECDHE_RSA_WITH_RC4_128_SHA}
   416  	testResumeState("DifferentCipherSuite", false)
   417  	testResumeState("DifferentCipherSuiteRecovers", true)
   418  
   419  	clientConfig.ClientSessionCache = nil
   420  	testResumeState("WithoutSessionCache", false)
   421  }
   422  
   423  func TestLRUClientSessionCache(t *testing.T) {
   424  	// Initialize cache of capacity 4.
   425  	cache := NewLRUClientSessionCache(4)
   426  	cs := make([]ClientSessionState, 6)
   427  	keys := []string{"0", "1", "2", "3", "4", "5", "6"}
   428  
   429  	// Add 4 entries to the cache and look them up.
   430  	for i := 0; i < 4; i++ {
   431  		cache.Put(keys[i], &cs[i])
   432  	}
   433  	for i := 0; i < 4; i++ {
   434  		if s, ok := cache.Get(keys[i]); !ok || s != &cs[i] {
   435  			t.Fatalf("session cache failed lookup for added key: %s", keys[i])
   436  		}
   437  	}
   438  
   439  	// Add 2 more entries to the cache. First 2 should be evicted.
   440  	for i := 4; i < 6; i++ {
   441  		cache.Put(keys[i], &cs[i])
   442  	}
   443  	for i := 0; i < 2; i++ {
   444  		if s, ok := cache.Get(keys[i]); ok || s != nil {
   445  			t.Fatalf("session cache should have evicted key: %s", keys[i])
   446  		}
   447  	}
   448  
   449  	// Touch entry 2. LRU should evict 3 next.
   450  	cache.Get(keys[2])
   451  	cache.Put(keys[0], &cs[0])
   452  	if s, ok := cache.Get(keys[3]); ok || s != nil {
   453  		t.Fatalf("session cache should have evicted key 3")
   454  	}
   455  
   456  	// Update entry 0 in place.
   457  	cache.Put(keys[0], &cs[3])
   458  	if s, ok := cache.Get(keys[0]); !ok || s != &cs[3] {
   459  		t.Fatalf("session cache failed update for key 0")
   460  	}
   461  
   462  	// Adding a nil entry is valid.
   463  	cache.Put(keys[0], nil)
   464  	if s, ok := cache.Get(keys[0]); !ok || s != nil {
   465  		t.Fatalf("failed to add nil entry to cache")
   466  	}
   467  }
   468  
   469  func TestHandshakeClientALPNMatch(t *testing.T) {
   470  	config := *testConfig
   471  	config.NextProtos = []string{"proto2", "proto1"}
   472  
   473  	test := &clientTest{
   474  		name: "ALPN",
   475  		// Note that this needs OpenSSL 1.0.2 because that is the first
   476  		// version that supports the -alpn flag.
   477  		command: []string{"openssl", "s_server", "-alpn", "proto1,proto2"},
   478  		config:  &config,
   479  		validate: func(state ConnectionState) error {
   480  			// The server's preferences should override the client.
   481  			if state.NegotiatedProtocol != "proto1" {
   482  				return fmt.Errorf("Got protocol %q, wanted proto1", state.NegotiatedProtocol)
   483  			}
   484  			return nil
   485  		},
   486  	}
   487  	runClientTestTLS12(t, test)
   488  }
   489  
   490  func TestHandshakeClientALPNNoMatch(t *testing.T) {
   491  	config := *testConfig
   492  	config.NextProtos = []string{"proto3"}
   493  
   494  	test := &clientTest{
   495  		name: "ALPN-NoMatch",
   496  		// Note that this needs OpenSSL 1.0.2 because that is the first
   497  		// version that supports the -alpn flag.
   498  		command: []string{"openssl", "s_server", "-alpn", "proto1,proto2"},
   499  		config:  &config,
   500  		validate: func(state ConnectionState) error {
   501  			// There's no overlap so OpenSSL will not select a protocol.
   502  			if state.NegotiatedProtocol != "" {
   503  				return fmt.Errorf("Got protocol %q, wanted ''", state.NegotiatedProtocol)
   504  			}
   505  			return nil
   506  		},
   507  	}
   508  	runClientTestTLS12(t, test)
   509  }