github.com/ader1990/go@v0.0.0-20140630135419-8c24447fa791/src/pkg/crypto/tls/handshake_client_test.go (about)

     1  // Copyright 2010 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package tls
     6  
     7  import (
     8  	"bytes"
     9  	"crypto/ecdsa"
    10  	"crypto/rsa"
    11  	"crypto/x509"
    12  	"encoding/pem"
    13  	"fmt"
    14  	"io"
    15  	"net"
    16  	"os"
    17  	"os/exec"
    18  	"path/filepath"
    19  	"strconv"
    20  	"testing"
    21  	"time"
    22  )
    23  
    24  // Note: see comment in handshake_test.go for details of how the reference
    25  // tests work.
    26  
    27  // blockingSource is an io.Reader that blocks a Read call until it's closed.
    28  type blockingSource chan bool
    29  
    30  func (b blockingSource) Read([]byte) (n int, err error) {
    31  	<-b
    32  	return 0, io.EOF
    33  }
    34  
    35  // clientTest represents a test of the TLS client handshake against a reference
    36  // implementation.
    37  type clientTest struct {
    38  	// name is a freeform string identifying the test and the file in which
    39  	// the expected results will be stored.
    40  	name string
    41  	// command, if not empty, contains a series of arguments for the
    42  	// command to run for the reference server.
    43  	command []string
    44  	// config, if not nil, contains a custom Config to use for this test.
    45  	config *Config
    46  	// cert, if not empty, contains a DER-encoded certificate for the
    47  	// reference server.
    48  	cert []byte
    49  	// key, if not nil, contains either a *rsa.PrivateKey or
    50  	// *ecdsa.PrivateKey which is the private key for the reference server.
    51  	key interface{}
    52  }
    53  
    54  var defaultServerCommand = []string{"openssl", "s_server"}
    55  
    56  // connFromCommand starts the reference server process, connects to it and
    57  // returns a recordingConn for the connection. The stdin return value is a
    58  // blockingSource for the stdin of the child process. It must be closed before
    59  // Waiting for child.
    60  func (test *clientTest) connFromCommand() (conn *recordingConn, child *exec.Cmd, stdin blockingSource, err error) {
    61  	cert := testRSACertificate
    62  	if len(test.cert) > 0 {
    63  		cert = test.cert
    64  	}
    65  	certPath := tempFile(string(cert))
    66  	defer os.Remove(certPath)
    67  
    68  	var key interface{} = testRSAPrivateKey
    69  	if test.key != nil {
    70  		key = test.key
    71  	}
    72  	var pemType string
    73  	var derBytes []byte
    74  	switch key := key.(type) {
    75  	case *rsa.PrivateKey:
    76  		pemType = "RSA"
    77  		derBytes = x509.MarshalPKCS1PrivateKey(key)
    78  	case *ecdsa.PrivateKey:
    79  		pemType = "EC"
    80  		var err error
    81  		derBytes, err = x509.MarshalECPrivateKey(key)
    82  		if err != nil {
    83  			panic(err)
    84  		}
    85  	default:
    86  		panic("unknown key type")
    87  	}
    88  
    89  	var pemOut bytes.Buffer
    90  	pem.Encode(&pemOut, &pem.Block{Type: pemType + " PRIVATE KEY", Bytes: derBytes})
    91  
    92  	keyPath := tempFile(string(pemOut.Bytes()))
    93  	defer os.Remove(keyPath)
    94  
    95  	var command []string
    96  	if len(test.command) > 0 {
    97  		command = append(command, test.command...)
    98  	} else {
    99  		command = append(command, defaultServerCommand...)
   100  	}
   101  	command = append(command, "-cert", certPath, "-certform", "DER", "-key", keyPath)
   102  	// serverPort contains the port that OpenSSL will listen on. OpenSSL
   103  	// can't take "0" as an argument here so we have to pick a number and
   104  	// hope that it's not in use on the machine. Since this only occurs
   105  	// when -update is given and thus when there's a human watching the
   106  	// test, this isn't too bad.
   107  	const serverPort = 24323
   108  	command = append(command, "-accept", strconv.Itoa(serverPort))
   109  
   110  	cmd := exec.Command(command[0], command[1:]...)
   111  	stdin = blockingSource(make(chan bool))
   112  	cmd.Stdin = stdin
   113  	var out bytes.Buffer
   114  	cmd.Stdout = &out
   115  	cmd.Stderr = &out
   116  	if err := cmd.Start(); err != nil {
   117  		return nil, nil, nil, err
   118  	}
   119  
   120  	// OpenSSL does print an "ACCEPT" banner, but it does so *before*
   121  	// opening the listening socket, so we can't use that to wait until it
   122  	// has started listening. Thus we are forced to poll until we get a
   123  	// connection.
   124  	var tcpConn net.Conn
   125  	for i := uint(0); i < 5; i++ {
   126  		var err error
   127  		tcpConn, err = net.DialTCP("tcp", nil, &net.TCPAddr{
   128  			IP:   net.IPv4(127, 0, 0, 1),
   129  			Port: serverPort,
   130  		})
   131  		if err == nil {
   132  			break
   133  		}
   134  		time.Sleep((1 << i) * 5 * time.Millisecond)
   135  	}
   136  	if tcpConn == nil {
   137  		close(stdin)
   138  		out.WriteTo(os.Stdout)
   139  		cmd.Process.Kill()
   140  		return nil, nil, nil, cmd.Wait()
   141  	}
   142  
   143  	record := &recordingConn{
   144  		Conn: tcpConn,
   145  	}
   146  
   147  	return record, cmd, stdin, nil
   148  }
   149  
   150  func (test *clientTest) dataPath() string {
   151  	return filepath.Join("testdata", "Client-"+test.name)
   152  }
   153  
   154  func (test *clientTest) loadData() (flows [][]byte, err error) {
   155  	in, err := os.Open(test.dataPath())
   156  	if err != nil {
   157  		return nil, err
   158  	}
   159  	defer in.Close()
   160  	return parseTestData(in)
   161  }
   162  
   163  func (test *clientTest) run(t *testing.T, write bool) {
   164  	var clientConn, serverConn net.Conn
   165  	var recordingConn *recordingConn
   166  	var childProcess *exec.Cmd
   167  	var stdin blockingSource
   168  
   169  	if write {
   170  		var err error
   171  		recordingConn, childProcess, stdin, err = test.connFromCommand()
   172  		if err != nil {
   173  			t.Fatalf("Failed to start subcommand: %s", err)
   174  		}
   175  		clientConn = recordingConn
   176  	} else {
   177  		clientConn, serverConn = net.Pipe()
   178  	}
   179  
   180  	config := test.config
   181  	if config == nil {
   182  		config = testConfig
   183  	}
   184  	client := Client(clientConn, config)
   185  
   186  	doneChan := make(chan bool)
   187  	go func() {
   188  		if _, err := client.Write([]byte("hello\n")); err != nil {
   189  			t.Logf("Client.Write failed: %s", err)
   190  		}
   191  		client.Close()
   192  		clientConn.Close()
   193  		doneChan <- true
   194  	}()
   195  
   196  	if !write {
   197  		flows, err := test.loadData()
   198  		if err != nil {
   199  			t.Fatalf("%s: failed to load data from %s", test.name, test.dataPath())
   200  		}
   201  		for i, b := range flows {
   202  			if i%2 == 1 {
   203  				serverConn.Write(b)
   204  				continue
   205  			}
   206  			bb := make([]byte, len(b))
   207  			_, err := io.ReadFull(serverConn, bb)
   208  			if err != nil {
   209  				t.Fatalf("%s #%d: %s", test.name, i, err)
   210  			}
   211  			if !bytes.Equal(b, bb) {
   212  				t.Fatalf("%s #%d: mismatch on read: got:%x want:%x", test.name, i, bb, b)
   213  			}
   214  		}
   215  		serverConn.Close()
   216  	}
   217  
   218  	<-doneChan
   219  
   220  	if write {
   221  		path := test.dataPath()
   222  		out, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
   223  		if err != nil {
   224  			t.Fatalf("Failed to create output file: %s", err)
   225  		}
   226  		defer out.Close()
   227  		recordingConn.Close()
   228  		close(stdin)
   229  		childProcess.Process.Kill()
   230  		childProcess.Wait()
   231  		if len(recordingConn.flows) < 3 {
   232  			childProcess.Stdout.(*bytes.Buffer).WriteTo(os.Stdout)
   233  			t.Fatalf("Client connection didn't work")
   234  		}
   235  		recordingConn.WriteTo(out)
   236  		fmt.Printf("Wrote %s\n", path)
   237  	}
   238  }
   239  
   240  func runClientTestForVersion(t *testing.T, template *clientTest, prefix, option string) {
   241  	test := *template
   242  	test.name = prefix + test.name
   243  	if len(test.command) == 0 {
   244  		test.command = defaultClientCommand
   245  	}
   246  	test.command = append([]string(nil), test.command...)
   247  	test.command = append(test.command, option)
   248  	test.run(t, *update)
   249  }
   250  
   251  func runClientTestTLS10(t *testing.T, template *clientTest) {
   252  	runClientTestForVersion(t, template, "TLSv10-", "-tls1")
   253  }
   254  
   255  func runClientTestTLS11(t *testing.T, template *clientTest) {
   256  	runClientTestForVersion(t, template, "TLSv11-", "-tls1_1")
   257  }
   258  
   259  func runClientTestTLS12(t *testing.T, template *clientTest) {
   260  	runClientTestForVersion(t, template, "TLSv12-", "-tls1_2")
   261  }
   262  
   263  func TestHandshakeClientRSARC4(t *testing.T) {
   264  	test := &clientTest{
   265  		name:    "RSA-RC4",
   266  		command: []string{"openssl", "s_server", "-cipher", "RC4-SHA"},
   267  	}
   268  	runClientTestTLS10(t, test)
   269  	runClientTestTLS11(t, test)
   270  	runClientTestTLS12(t, test)
   271  }
   272  
   273  func TestHandshakeClientECDHERSAAES(t *testing.T) {
   274  	test := &clientTest{
   275  		name:    "ECDHE-RSA-AES",
   276  		command: []string{"openssl", "s_server", "-cipher", "ECDHE-RSA-AES128-SHA"},
   277  	}
   278  	runClientTestTLS10(t, test)
   279  	runClientTestTLS11(t, test)
   280  	runClientTestTLS12(t, test)
   281  }
   282  
   283  func TestHandshakeClientECDHEECDSAAES(t *testing.T) {
   284  	test := &clientTest{
   285  		name:    "ECDHE-ECDSA-AES",
   286  		command: []string{"openssl", "s_server", "-cipher", "ECDHE-ECDSA-AES128-SHA"},
   287  		cert:    testECDSACertificate,
   288  		key:     testECDSAPrivateKey,
   289  	}
   290  	runClientTestTLS10(t, test)
   291  	runClientTestTLS11(t, test)
   292  	runClientTestTLS12(t, test)
   293  }
   294  
   295  func TestHandshakeClientECDHEECDSAAESGCM(t *testing.T) {
   296  	test := &clientTest{
   297  		name:    "ECDHE-ECDSA-AES-GCM",
   298  		command: []string{"openssl", "s_server", "-cipher", "ECDHE-ECDSA-AES128-GCM-SHA256"},
   299  		cert:    testECDSACertificate,
   300  		key:     testECDSAPrivateKey,
   301  	}
   302  	runClientTestTLS12(t, test)
   303  }
   304  
   305  func TestHandshakeClientCertRSA(t *testing.T) {
   306  	config := *testConfig
   307  	cert, _ := X509KeyPair([]byte(clientCertificatePEM), []byte(clientKeyPEM))
   308  	config.Certificates = []Certificate{cert}
   309  
   310  	test := &clientTest{
   311  		name:    "ClientCert-RSA-RSA",
   312  		command: []string{"openssl", "s_server", "-cipher", "RC4-SHA", "-verify", "1"},
   313  		config:  &config,
   314  	}
   315  
   316  	runClientTestTLS10(t, test)
   317  	runClientTestTLS12(t, test)
   318  
   319  	test = &clientTest{
   320  		name:    "ClientCert-RSA-ECDSA",
   321  		command: []string{"openssl", "s_server", "-cipher", "ECDHE-ECDSA-AES128-SHA", "-verify", "1"},
   322  		config:  &config,
   323  		cert:    testECDSACertificate,
   324  		key:     testECDSAPrivateKey,
   325  	}
   326  
   327  	runClientTestTLS10(t, test)
   328  	runClientTestTLS12(t, test)
   329  }
   330  
   331  func TestHandshakeClientCertECDSA(t *testing.T) {
   332  	config := *testConfig
   333  	cert, _ := X509KeyPair([]byte(clientECDSACertificatePEM), []byte(clientECDSAKeyPEM))
   334  	config.Certificates = []Certificate{cert}
   335  
   336  	test := &clientTest{
   337  		name:    "ClientCert-ECDSA-RSA",
   338  		command: []string{"openssl", "s_server", "-cipher", "RC4-SHA", "-verify", "1"},
   339  		config:  &config,
   340  	}
   341  
   342  	runClientTestTLS10(t, test)
   343  	runClientTestTLS12(t, test)
   344  
   345  	test = &clientTest{
   346  		name:    "ClientCert-ECDSA-ECDSA",
   347  		command: []string{"openssl", "s_server", "-cipher", "ECDHE-ECDSA-AES128-SHA", "-verify", "1"},
   348  		config:  &config,
   349  		cert:    testECDSACertificate,
   350  		key:     testECDSAPrivateKey,
   351  	}
   352  
   353  	runClientTestTLS10(t, test)
   354  	runClientTestTLS12(t, test)
   355  }
   356  
   357  func TestClientResumption(t *testing.T) {
   358  	serverConfig := &Config{
   359  		CipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA},
   360  		Certificates: testConfig.Certificates,
   361  	}
   362  	clientConfig := &Config{
   363  		CipherSuites:       []uint16{TLS_RSA_WITH_RC4_128_SHA},
   364  		InsecureSkipVerify: true,
   365  		ClientSessionCache: NewLRUClientSessionCache(32),
   366  	}
   367  
   368  	testResumeState := func(test string, didResume bool) {
   369  		hs, err := testHandshake(clientConfig, serverConfig)
   370  		if err != nil {
   371  			t.Fatalf("%s: handshake failed: %s", test, err)
   372  		}
   373  		if hs.DidResume != didResume {
   374  			t.Fatalf("%s resumed: %v, expected: %v", test, hs.DidResume, didResume)
   375  		}
   376  	}
   377  
   378  	testResumeState("Handshake", false)
   379  	testResumeState("Resume", true)
   380  
   381  	if _, err := io.ReadFull(serverConfig.rand(), serverConfig.SessionTicketKey[:]); err != nil {
   382  		t.Fatalf("Failed to invalidate SessionTicketKey")
   383  	}
   384  	testResumeState("InvalidSessionTicketKey", false)
   385  	testResumeState("ResumeAfterInvalidSessionTicketKey", true)
   386  
   387  	clientConfig.CipherSuites = []uint16{TLS_ECDHE_RSA_WITH_RC4_128_SHA}
   388  	testResumeState("DifferentCipherSuite", false)
   389  	testResumeState("DifferentCipherSuiteRecovers", true)
   390  
   391  	clientConfig.ClientSessionCache = nil
   392  	testResumeState("WithoutSessionCache", false)
   393  }
   394  
   395  func TestLRUClientSessionCache(t *testing.T) {
   396  	// Initialize cache of capacity 4.
   397  	cache := NewLRUClientSessionCache(4)
   398  	cs := make([]ClientSessionState, 6)
   399  	keys := []string{"0", "1", "2", "3", "4", "5", "6"}
   400  
   401  	// Add 4 entries to the cache and look them up.
   402  	for i := 0; i < 4; i++ {
   403  		cache.Put(keys[i], &cs[i])
   404  	}
   405  	for i := 0; i < 4; i++ {
   406  		if s, ok := cache.Get(keys[i]); !ok || s != &cs[i] {
   407  			t.Fatalf("session cache failed lookup for added key: %s", keys[i])
   408  		}
   409  	}
   410  
   411  	// Add 2 more entries to the cache. First 2 should be evicted.
   412  	for i := 4; i < 6; i++ {
   413  		cache.Put(keys[i], &cs[i])
   414  	}
   415  	for i := 0; i < 2; i++ {
   416  		if s, ok := cache.Get(keys[i]); ok || s != nil {
   417  			t.Fatalf("session cache should have evicted key: %s", keys[i])
   418  		}
   419  	}
   420  
   421  	// Touch entry 2. LRU should evict 3 next.
   422  	cache.Get(keys[2])
   423  	cache.Put(keys[0], &cs[0])
   424  	if s, ok := cache.Get(keys[3]); ok || s != nil {
   425  		t.Fatalf("session cache should have evicted key 3")
   426  	}
   427  
   428  	// Update entry 0 in place.
   429  	cache.Put(keys[0], &cs[3])
   430  	if s, ok := cache.Get(keys[0]); !ok || s != &cs[3] {
   431  		t.Fatalf("session cache failed update for key 0")
   432  	}
   433  
   434  	// Adding a nil entry is valid.
   435  	cache.Put(keys[0], nil)
   436  	if s, ok := cache.Get(keys[0]); !ok || s != nil {
   437  		t.Fatalf("failed to add nil entry to cache")
   438  	}
   439  }