github.com/panjjo/go@v0.0.0-20161104043856-d62b31386338/src/crypto/tls/handshake_client_test.go (about)

     1  // Copyright 2010 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package tls
     6  
     7  import (
     8  	"bytes"
     9  	"crypto/ecdsa"
    10  	"crypto/rsa"
    11  	"crypto/x509"
    12  	"encoding/base64"
    13  	"encoding/binary"
    14  	"encoding/pem"
    15  	"errors"
    16  	"fmt"
    17  	"io"
    18  	"math/big"
    19  	"net"
    20  	"os"
    21  	"os/exec"
    22  	"path/filepath"
    23  	"strconv"
    24  	"strings"
    25  	"testing"
    26  	"time"
    27  )
    28  
    29  // Note: see comment in handshake_test.go for details of how the reference
    30  // tests work.
    31  
    32  // opensslInputEvent enumerates possible inputs that can be sent to an `openssl
    33  // s_client` process.
    34  type opensslInputEvent int
    35  
    36  const (
    37  	// opensslRenegotiate causes OpenSSL to request a renegotiation of the
    38  	// connection.
    39  	opensslRenegotiate opensslInputEvent = iota
    40  
    41  	// opensslSendBanner causes OpenSSL to send the contents of
    42  	// opensslSentinel on the connection.
    43  	opensslSendSentinel
    44  )
    45  
    46  const opensslSentinel = "SENTINEL\n"
    47  
    48  type opensslInput chan opensslInputEvent
    49  
    50  func (i opensslInput) Read(buf []byte) (n int, err error) {
    51  	for event := range i {
    52  		switch event {
    53  		case opensslRenegotiate:
    54  			return copy(buf, []byte("R\n")), nil
    55  		case opensslSendSentinel:
    56  			return copy(buf, []byte(opensslSentinel)), nil
    57  		default:
    58  			panic("unknown event")
    59  		}
    60  	}
    61  
    62  	return 0, io.EOF
    63  }
    64  
    65  // opensslOutputSink is an io.Writer that receives the stdout and stderr from
    66  // an `openssl` process and sends a value to handshakeComplete when it sees a
    67  // log message from a completed server handshake.
    68  type opensslOutputSink struct {
    69  	handshakeComplete chan struct{}
    70  	all               []byte
    71  	line              []byte
    72  }
    73  
    74  func newOpensslOutputSink() *opensslOutputSink {
    75  	return &opensslOutputSink{make(chan struct{}), nil, nil}
    76  }
    77  
    78  // opensslEndOfHandshake is a message that the “openssl s_server” tool will
    79  // print when a handshake completes if run with “-state”.
    80  const opensslEndOfHandshake = "SSL_accept:SSLv3/TLS write finished"
    81  
    82  func (o *opensslOutputSink) Write(data []byte) (n int, err error) {
    83  	o.line = append(o.line, data...)
    84  	o.all = append(o.all, data...)
    85  
    86  	for {
    87  		i := bytes.Index(o.line, []byte{'\n'})
    88  		if i < 0 {
    89  			break
    90  		}
    91  
    92  		if bytes.Equal([]byte(opensslEndOfHandshake), o.line[:i]) {
    93  			o.handshakeComplete <- struct{}{}
    94  		}
    95  		o.line = o.line[i+1:]
    96  	}
    97  
    98  	return len(data), nil
    99  }
   100  
   101  func (o *opensslOutputSink) WriteTo(w io.Writer) (int64, error) {
   102  	n, err := w.Write(o.all)
   103  	return int64(n), err
   104  }
   105  
   106  // clientTest represents a test of the TLS client handshake against a reference
   107  // implementation.
   108  type clientTest struct {
   109  	// name is a freeform string identifying the test and the file in which
   110  	// the expected results will be stored.
   111  	name string
   112  	// command, if not empty, contains a series of arguments for the
   113  	// command to run for the reference server.
   114  	command []string
   115  	// config, if not nil, contains a custom Config to use for this test.
   116  	config *Config
   117  	// cert, if not empty, contains a DER-encoded certificate for the
   118  	// reference server.
   119  	cert []byte
   120  	// key, if not nil, contains either a *rsa.PrivateKey or
   121  	// *ecdsa.PrivateKey which is the private key for the reference server.
   122  	key interface{}
   123  	// extensions, if not nil, contains a list of extension data to be returned
   124  	// from the ServerHello. The data should be in standard TLS format with
   125  	// a 2-byte uint16 type, 2-byte data length, followed by the extension data.
   126  	extensions [][]byte
   127  	// validate, if not nil, is a function that will be called with the
   128  	// ConnectionState of the resulting connection. It returns a non-nil
   129  	// error if the ConnectionState is unacceptable.
   130  	validate func(ConnectionState) error
   131  	// numRenegotiations is the number of times that the connection will be
   132  	// renegotiated.
   133  	numRenegotiations int
   134  	// renegotiationExpectedToFail, if not zero, is the number of the
   135  	// renegotiation attempt that is expected to fail.
   136  	renegotiationExpectedToFail int
   137  	// checkRenegotiationError, if not nil, is called with any error
   138  	// arising from renegotiation. It can map expected errors to nil to
   139  	// ignore them.
   140  	checkRenegotiationError func(renegotiationNum int, err error) error
   141  }
   142  
   143  var defaultServerCommand = []string{"openssl", "s_server"}
   144  
   145  // connFromCommand starts the reference server process, connects to it and
   146  // returns a recordingConn for the connection. The stdin return value is an
   147  // opensslInput for the stdin of the child process. It must be closed before
   148  // Waiting for child.
   149  func (test *clientTest) connFromCommand() (conn *recordingConn, child *exec.Cmd, stdin opensslInput, stdout *opensslOutputSink, err error) {
   150  	cert := testRSACertificate
   151  	if len(test.cert) > 0 {
   152  		cert = test.cert
   153  	}
   154  	certPath := tempFile(string(cert))
   155  	defer os.Remove(certPath)
   156  
   157  	var key interface{} = testRSAPrivateKey
   158  	if test.key != nil {
   159  		key = test.key
   160  	}
   161  	var pemType string
   162  	var derBytes []byte
   163  	switch key := key.(type) {
   164  	case *rsa.PrivateKey:
   165  		pemType = "RSA"
   166  		derBytes = x509.MarshalPKCS1PrivateKey(key)
   167  	case *ecdsa.PrivateKey:
   168  		pemType = "EC"
   169  		var err error
   170  		derBytes, err = x509.MarshalECPrivateKey(key)
   171  		if err != nil {
   172  			panic(err)
   173  		}
   174  	default:
   175  		panic("unknown key type")
   176  	}
   177  
   178  	var pemOut bytes.Buffer
   179  	pem.Encode(&pemOut, &pem.Block{Type: pemType + " PRIVATE KEY", Bytes: derBytes})
   180  
   181  	keyPath := tempFile(string(pemOut.Bytes()))
   182  	defer os.Remove(keyPath)
   183  
   184  	var command []string
   185  	if len(test.command) > 0 {
   186  		command = append(command, test.command...)
   187  	} else {
   188  		command = append(command, defaultServerCommand...)
   189  	}
   190  	command = append(command, "-cert", certPath, "-certform", "DER", "-key", keyPath)
   191  	// serverPort contains the port that OpenSSL will listen on. OpenSSL
   192  	// can't take "0" as an argument here so we have to pick a number and
   193  	// hope that it's not in use on the machine. Since this only occurs
   194  	// when -update is given and thus when there's a human watching the
   195  	// test, this isn't too bad.
   196  	const serverPort = 24323
   197  	command = append(command, "-accept", strconv.Itoa(serverPort))
   198  
   199  	if len(test.extensions) > 0 {
   200  		var serverInfo bytes.Buffer
   201  		for _, ext := range test.extensions {
   202  			pem.Encode(&serverInfo, &pem.Block{
   203  				Type:  fmt.Sprintf("SERVERINFO FOR EXTENSION %d", binary.BigEndian.Uint16(ext)),
   204  				Bytes: ext,
   205  			})
   206  		}
   207  		serverInfoPath := tempFile(serverInfo.String())
   208  		defer os.Remove(serverInfoPath)
   209  		command = append(command, "-serverinfo", serverInfoPath)
   210  	}
   211  
   212  	if test.numRenegotiations > 0 {
   213  		found := false
   214  		for _, flag := range command[1:] {
   215  			if flag == "-state" {
   216  				found = true
   217  				break
   218  			}
   219  		}
   220  
   221  		if !found {
   222  			panic("-state flag missing to OpenSSL. You need this if testing renegotiation")
   223  		}
   224  	}
   225  
   226  	cmd := exec.Command(command[0], command[1:]...)
   227  	stdin = opensslInput(make(chan opensslInputEvent))
   228  	cmd.Stdin = stdin
   229  	out := newOpensslOutputSink()
   230  	cmd.Stdout = out
   231  	cmd.Stderr = out
   232  	if err := cmd.Start(); err != nil {
   233  		return nil, nil, nil, nil, err
   234  	}
   235  
   236  	// OpenSSL does print an "ACCEPT" banner, but it does so *before*
   237  	// opening the listening socket, so we can't use that to wait until it
   238  	// has started listening. Thus we are forced to poll until we get a
   239  	// connection.
   240  	var tcpConn net.Conn
   241  	for i := uint(0); i < 5; i++ {
   242  		tcpConn, err = net.DialTCP("tcp", nil, &net.TCPAddr{
   243  			IP:   net.IPv4(127, 0, 0, 1),
   244  			Port: serverPort,
   245  		})
   246  		if err == nil {
   247  			break
   248  		}
   249  		time.Sleep((1 << i) * 5 * time.Millisecond)
   250  	}
   251  	if err != nil {
   252  		close(stdin)
   253  		out.WriteTo(os.Stdout)
   254  		cmd.Process.Kill()
   255  		return nil, nil, nil, nil, cmd.Wait()
   256  	}
   257  
   258  	record := &recordingConn{
   259  		Conn: tcpConn,
   260  	}
   261  
   262  	return record, cmd, stdin, out, nil
   263  }
   264  
   265  func (test *clientTest) dataPath() string {
   266  	return filepath.Join("testdata", "Client-"+test.name)
   267  }
   268  
   269  func (test *clientTest) loadData() (flows [][]byte, err error) {
   270  	in, err := os.Open(test.dataPath())
   271  	if err != nil {
   272  		return nil, err
   273  	}
   274  	defer in.Close()
   275  	return parseTestData(in)
   276  }
   277  
   278  func (test *clientTest) run(t *testing.T, write bool) {
   279  	checkOpenSSLVersion(t)
   280  
   281  	var clientConn, serverConn net.Conn
   282  	var recordingConn *recordingConn
   283  	var childProcess *exec.Cmd
   284  	var stdin opensslInput
   285  	var stdout *opensslOutputSink
   286  
   287  	if write {
   288  		var err error
   289  		recordingConn, childProcess, stdin, stdout, err = test.connFromCommand()
   290  		if err != nil {
   291  			t.Fatalf("Failed to start subcommand: %s", err)
   292  		}
   293  		clientConn = recordingConn
   294  	} else {
   295  		clientConn, serverConn = net.Pipe()
   296  	}
   297  
   298  	config := test.config
   299  	if config == nil {
   300  		config = testConfig
   301  	}
   302  	client := Client(clientConn, config)
   303  
   304  	doneChan := make(chan bool)
   305  	go func() {
   306  		defer func() { doneChan <- true }()
   307  		defer clientConn.Close()
   308  		defer client.Close()
   309  
   310  		if _, err := client.Write([]byte("hello\n")); err != nil {
   311  			t.Errorf("Client.Write failed: %s", err)
   312  			return
   313  		}
   314  
   315  		for i := 1; i <= test.numRenegotiations; i++ {
   316  			// The initial handshake will generate a
   317  			// handshakeComplete signal which needs to be quashed.
   318  			if i == 1 && write {
   319  				<-stdout.handshakeComplete
   320  			}
   321  
   322  			// OpenSSL will try to interleave application data and
   323  			// a renegotiation if we send both concurrently.
   324  			// Therefore: ask OpensSSL to start a renegotiation, run
   325  			// a goroutine to call client.Read and thus process the
   326  			// renegotiation request, watch for OpenSSL's stdout to
   327  			// indicate that the handshake is complete and,
   328  			// finally, have OpenSSL write something to cause
   329  			// client.Read to complete.
   330  			if write {
   331  				stdin <- opensslRenegotiate
   332  			}
   333  
   334  			signalChan := make(chan struct{})
   335  
   336  			go func() {
   337  				defer func() { signalChan <- struct{}{} }()
   338  
   339  				buf := make([]byte, 256)
   340  				n, err := client.Read(buf)
   341  
   342  				if test.checkRenegotiationError != nil {
   343  					newErr := test.checkRenegotiationError(i, err)
   344  					if err != nil && newErr == nil {
   345  						return
   346  					}
   347  					err = newErr
   348  				}
   349  
   350  				if err != nil {
   351  					t.Errorf("Client.Read failed after renegotiation #%d: %s", i, err)
   352  					return
   353  				}
   354  
   355  				buf = buf[:n]
   356  				if !bytes.Equal([]byte(opensslSentinel), buf) {
   357  					t.Errorf("Client.Read returned %q, but wanted %q", string(buf), opensslSentinel)
   358  				}
   359  
   360  				if expected := i + 1; client.handshakes != expected {
   361  					t.Errorf("client should have recorded %d handshakes, but believes that %d have occured", expected, client.handshakes)
   362  				}
   363  			}()
   364  
   365  			if write && test.renegotiationExpectedToFail != i {
   366  				<-stdout.handshakeComplete
   367  				stdin <- opensslSendSentinel
   368  			}
   369  			<-signalChan
   370  		}
   371  
   372  		if test.validate != nil {
   373  			if err := test.validate(client.ConnectionState()); err != nil {
   374  				t.Errorf("validate callback returned error: %s", err)
   375  			}
   376  		}
   377  	}()
   378  
   379  	if !write {
   380  		flows, err := test.loadData()
   381  		if err != nil {
   382  			t.Fatalf("%s: failed to load data from %s: %v", test.name, test.dataPath(), err)
   383  		}
   384  		for i, b := range flows {
   385  			if i%2 == 1 {
   386  				serverConn.Write(b)
   387  				continue
   388  			}
   389  			bb := make([]byte, len(b))
   390  			_, err := io.ReadFull(serverConn, bb)
   391  			if err != nil {
   392  				t.Fatalf("%s #%d: %s", test.name, i, err)
   393  			}
   394  			if !bytes.Equal(b, bb) {
   395  				t.Fatalf("%s #%d: mismatch on read: got:%x want:%x", test.name, i, bb, b)
   396  			}
   397  		}
   398  		serverConn.Close()
   399  	}
   400  
   401  	<-doneChan
   402  
   403  	if write {
   404  		path := test.dataPath()
   405  		out, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
   406  		if err != nil {
   407  			t.Fatalf("Failed to create output file: %s", err)
   408  		}
   409  		defer out.Close()
   410  		recordingConn.Close()
   411  		close(stdin)
   412  		childProcess.Process.Kill()
   413  		childProcess.Wait()
   414  		if len(recordingConn.flows) < 3 {
   415  			os.Stdout.Write(childProcess.Stdout.(*opensslOutputSink).all)
   416  			t.Fatalf("Client connection didn't work")
   417  		}
   418  		recordingConn.WriteTo(out)
   419  		fmt.Printf("Wrote %s\n", path)
   420  	}
   421  }
   422  
   423  func runClientTestForVersion(t *testing.T, template *clientTest, prefix, option string) {
   424  	test := *template
   425  	test.name = prefix + test.name
   426  	if len(test.command) == 0 {
   427  		test.command = defaultClientCommand
   428  	}
   429  	test.command = append([]string(nil), test.command...)
   430  	test.command = append(test.command, option)
   431  	test.run(t, *update)
   432  }
   433  
   434  func runClientTestTLS10(t *testing.T, template *clientTest) {
   435  	runClientTestForVersion(t, template, "TLSv10-", "-tls1")
   436  }
   437  
   438  func runClientTestTLS11(t *testing.T, template *clientTest) {
   439  	runClientTestForVersion(t, template, "TLSv11-", "-tls1_1")
   440  }
   441  
   442  func runClientTestTLS12(t *testing.T, template *clientTest) {
   443  	runClientTestForVersion(t, template, "TLSv12-", "-tls1_2")
   444  }
   445  
   446  func TestHandshakeClientRSARC4(t *testing.T) {
   447  	test := &clientTest{
   448  		name:    "RSA-RC4",
   449  		command: []string{"openssl", "s_server", "-cipher", "RC4-SHA"},
   450  	}
   451  	runClientTestTLS10(t, test)
   452  	runClientTestTLS11(t, test)
   453  	runClientTestTLS12(t, test)
   454  }
   455  
   456  func TestHandshakeClientRSAAES128GCM(t *testing.T) {
   457  	test := &clientTest{
   458  		name:    "AES128-GCM-SHA256",
   459  		command: []string{"openssl", "s_server", "-cipher", "AES128-GCM-SHA256"},
   460  	}
   461  	runClientTestTLS12(t, test)
   462  }
   463  
   464  func TestHandshakeClientRSAAES256GCM(t *testing.T) {
   465  	test := &clientTest{
   466  		name:    "AES256-GCM-SHA384",
   467  		command: []string{"openssl", "s_server", "-cipher", "AES256-GCM-SHA384"},
   468  	}
   469  	runClientTestTLS12(t, test)
   470  }
   471  
   472  func TestHandshakeClientECDHERSAAES(t *testing.T) {
   473  	test := &clientTest{
   474  		name:    "ECDHE-RSA-AES",
   475  		command: []string{"openssl", "s_server", "-cipher", "ECDHE-RSA-AES128-SHA"},
   476  	}
   477  	runClientTestTLS10(t, test)
   478  	runClientTestTLS11(t, test)
   479  	runClientTestTLS12(t, test)
   480  }
   481  
   482  func TestHandshakeClientECDHEECDSAAES(t *testing.T) {
   483  	test := &clientTest{
   484  		name:    "ECDHE-ECDSA-AES",
   485  		command: []string{"openssl", "s_server", "-cipher", "ECDHE-ECDSA-AES128-SHA"},
   486  		cert:    testECDSACertificate,
   487  		key:     testECDSAPrivateKey,
   488  	}
   489  	runClientTestTLS10(t, test)
   490  	runClientTestTLS11(t, test)
   491  	runClientTestTLS12(t, test)
   492  }
   493  
   494  func TestHandshakeClientECDHEECDSAAESGCM(t *testing.T) {
   495  	test := &clientTest{
   496  		name:    "ECDHE-ECDSA-AES-GCM",
   497  		command: []string{"openssl", "s_server", "-cipher", "ECDHE-ECDSA-AES128-GCM-SHA256"},
   498  		cert:    testECDSACertificate,
   499  		key:     testECDSAPrivateKey,
   500  	}
   501  	runClientTestTLS12(t, test)
   502  }
   503  
   504  func TestHandshakeClientAES256GCMSHA384(t *testing.T) {
   505  	test := &clientTest{
   506  		name:    "ECDHE-ECDSA-AES256-GCM-SHA384",
   507  		command: []string{"openssl", "s_server", "-cipher", "ECDHE-ECDSA-AES256-GCM-SHA384"},
   508  		cert:    testECDSACertificate,
   509  		key:     testECDSAPrivateKey,
   510  	}
   511  	runClientTestTLS12(t, test)
   512  }
   513  
   514  func TestHandshakeClientAES128CBCSHA256(t *testing.T) {
   515  	test := &clientTest{
   516  		name:    "AES128-SHA256",
   517  		command: []string{"openssl", "s_server", "-cipher", "AES128-SHA256"},
   518  	}
   519  	runClientTestTLS12(t, test)
   520  }
   521  
   522  func TestHandshakeClientECDHERSAAES128CBCSHA256(t *testing.T) {
   523  	test := &clientTest{
   524  		name:    "ECDHE-RSA-AES128-SHA256",
   525  		command: []string{"openssl", "s_server", "-cipher", "ECDHE-RSA-AES128-SHA256"},
   526  	}
   527  	runClientTestTLS12(t, test)
   528  }
   529  
   530  func TestHandshakeClientECDHEECDSAAES128CBCSHA256(t *testing.T) {
   531  	test := &clientTest{
   532  		name:    "ECDHE-ECDSA-AES128-SHA256",
   533  		command: []string{"openssl", "s_server", "-cipher", "ECDHE-ECDSA-AES128-SHA256"},
   534  		cert:    testECDSACertificate,
   535  		key:     testECDSAPrivateKey,
   536  	}
   537  	runClientTestTLS12(t, test)
   538  }
   539  
   540  func TestHandshakeClientX25519(t *testing.T) {
   541  	config := testConfig.Clone()
   542  	config.CurvePreferences = []CurveID{X25519}
   543  
   544  	test := &clientTest{
   545  		name:    "X25519-ECDHE-RSA-AES-GCM",
   546  		command: []string{"openssl", "s_server", "-cipher", "ECDHE-RSA-AES128-GCM-SHA256"},
   547  		config:  config,
   548  	}
   549  
   550  	runClientTestTLS12(t, test)
   551  }
   552  
   553  func TestHandshakeClientECDHERSAChaCha20(t *testing.T) {
   554  	config := testConfig.Clone()
   555  	config.CipherSuites = []uint16{TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305}
   556  
   557  	test := &clientTest{
   558  		name:    "ECDHE-RSA-CHACHA20-POLY1305",
   559  		command: []string{"openssl", "s_server", "-cipher", "ECDHE-RSA-CHACHA20-POLY1305"},
   560  		config:  config,
   561  	}
   562  
   563  	runClientTestTLS12(t, test)
   564  }
   565  
   566  func TestHandshakeClientECDHEECDSAChaCha20(t *testing.T) {
   567  	config := testConfig.Clone()
   568  	config.CipherSuites = []uint16{TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305}
   569  
   570  	test := &clientTest{
   571  		name:    "ECDHE-ECDSA-CHACHA20-POLY1305",
   572  		command: []string{"openssl", "s_server", "-cipher", "ECDHE-ECDSA-CHACHA20-POLY1305"},
   573  		config:  config,
   574  		cert:    testECDSACertificate,
   575  		key:     testECDSAPrivateKey,
   576  	}
   577  
   578  	runClientTestTLS12(t, test)
   579  }
   580  
   581  func TestHandshakeClientCertRSA(t *testing.T) {
   582  	config := testConfig.Clone()
   583  	cert, _ := X509KeyPair([]byte(clientCertificatePEM), []byte(clientKeyPEM))
   584  	config.Certificates = []Certificate{cert}
   585  
   586  	test := &clientTest{
   587  		name:    "ClientCert-RSA-RSA",
   588  		command: []string{"openssl", "s_server", "-cipher", "AES128", "-verify", "1"},
   589  		config:  config,
   590  	}
   591  
   592  	runClientTestTLS10(t, test)
   593  	runClientTestTLS12(t, test)
   594  
   595  	test = &clientTest{
   596  		name:    "ClientCert-RSA-ECDSA",
   597  		command: []string{"openssl", "s_server", "-cipher", "ECDHE-ECDSA-AES128-SHA", "-verify", "1"},
   598  		config:  config,
   599  		cert:    testECDSACertificate,
   600  		key:     testECDSAPrivateKey,
   601  	}
   602  
   603  	runClientTestTLS10(t, test)
   604  	runClientTestTLS12(t, test)
   605  
   606  	test = &clientTest{
   607  		name:    "ClientCert-RSA-AES256-GCM-SHA384",
   608  		command: []string{"openssl", "s_server", "-cipher", "ECDHE-RSA-AES256-GCM-SHA384", "-verify", "1"},
   609  		config:  config,
   610  		cert:    testRSACertificate,
   611  		key:     testRSAPrivateKey,
   612  	}
   613  
   614  	runClientTestTLS12(t, test)
   615  }
   616  
   617  func TestHandshakeClientCertECDSA(t *testing.T) {
   618  	config := testConfig.Clone()
   619  	cert, _ := X509KeyPair([]byte(clientECDSACertificatePEM), []byte(clientECDSAKeyPEM))
   620  	config.Certificates = []Certificate{cert}
   621  
   622  	test := &clientTest{
   623  		name:    "ClientCert-ECDSA-RSA",
   624  		command: []string{"openssl", "s_server", "-cipher", "AES128", "-verify", "1"},
   625  		config:  config,
   626  	}
   627  
   628  	runClientTestTLS10(t, test)
   629  	runClientTestTLS12(t, test)
   630  
   631  	test = &clientTest{
   632  		name:    "ClientCert-ECDSA-ECDSA",
   633  		command: []string{"openssl", "s_server", "-cipher", "ECDHE-ECDSA-AES128-SHA", "-verify", "1"},
   634  		config:  config,
   635  		cert:    testECDSACertificate,
   636  		key:     testECDSAPrivateKey,
   637  	}
   638  
   639  	runClientTestTLS10(t, test)
   640  	runClientTestTLS12(t, test)
   641  }
   642  
   643  func TestClientResumption(t *testing.T) {
   644  	serverConfig := &Config{
   645  		CipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA},
   646  		Certificates: testConfig.Certificates,
   647  	}
   648  
   649  	issuer, err := x509.ParseCertificate(testRSACertificateIssuer)
   650  	if err != nil {
   651  		panic(err)
   652  	}
   653  
   654  	rootCAs := x509.NewCertPool()
   655  	rootCAs.AddCert(issuer)
   656  
   657  	clientConfig := &Config{
   658  		CipherSuites:       []uint16{TLS_RSA_WITH_RC4_128_SHA},
   659  		ClientSessionCache: NewLRUClientSessionCache(32),
   660  		RootCAs:            rootCAs,
   661  		ServerName:         "example.golang",
   662  	}
   663  
   664  	testResumeState := func(test string, didResume bool) {
   665  		_, hs, err := testHandshake(clientConfig, serverConfig)
   666  		if err != nil {
   667  			t.Fatalf("%s: handshake failed: %s", test, err)
   668  		}
   669  		if hs.DidResume != didResume {
   670  			t.Fatalf("%s resumed: %v, expected: %v", test, hs.DidResume, didResume)
   671  		}
   672  		if didResume && (hs.PeerCertificates == nil || hs.VerifiedChains == nil) {
   673  			t.Fatalf("expected non-nil certificates after resumption. Got peerCertificates: %#v, verifiedCertificates: %#v", hs.PeerCertificates, hs.VerifiedChains)
   674  		}
   675  	}
   676  
   677  	getTicket := func() []byte {
   678  		return clientConfig.ClientSessionCache.(*lruSessionCache).q.Front().Value.(*lruSessionCacheEntry).state.sessionTicket
   679  	}
   680  	randomKey := func() [32]byte {
   681  		var k [32]byte
   682  		if _, err := io.ReadFull(serverConfig.rand(), k[:]); err != nil {
   683  			t.Fatalf("Failed to read new SessionTicketKey: %s", err)
   684  		}
   685  		return k
   686  	}
   687  
   688  	testResumeState("Handshake", false)
   689  	ticket := getTicket()
   690  	testResumeState("Resume", true)
   691  	if !bytes.Equal(ticket, getTicket()) {
   692  		t.Fatal("first ticket doesn't match ticket after resumption")
   693  	}
   694  
   695  	key1 := randomKey()
   696  	serverConfig.SetSessionTicketKeys([][32]byte{key1})
   697  
   698  	testResumeState("InvalidSessionTicketKey", false)
   699  	testResumeState("ResumeAfterInvalidSessionTicketKey", true)
   700  
   701  	key2 := randomKey()
   702  	serverConfig.SetSessionTicketKeys([][32]byte{key2, key1})
   703  	ticket = getTicket()
   704  	testResumeState("KeyChange", true)
   705  	if bytes.Equal(ticket, getTicket()) {
   706  		t.Fatal("new ticket wasn't included while resuming")
   707  	}
   708  	testResumeState("KeyChangeFinish", true)
   709  
   710  	// Reset serverConfig to ensure that calling SetSessionTicketKeys
   711  	// before the serverConfig is used works.
   712  	serverConfig = &Config{
   713  		CipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA},
   714  		Certificates: testConfig.Certificates,
   715  	}
   716  	serverConfig.SetSessionTicketKeys([][32]byte{key2})
   717  
   718  	testResumeState("FreshConfig", true)
   719  
   720  	clientConfig.CipherSuites = []uint16{TLS_ECDHE_RSA_WITH_RC4_128_SHA}
   721  	testResumeState("DifferentCipherSuite", false)
   722  	testResumeState("DifferentCipherSuiteRecovers", true)
   723  
   724  	clientConfig.ClientSessionCache = nil
   725  	testResumeState("WithoutSessionCache", false)
   726  }
   727  
   728  func TestLRUClientSessionCache(t *testing.T) {
   729  	// Initialize cache of capacity 4.
   730  	cache := NewLRUClientSessionCache(4)
   731  	cs := make([]ClientSessionState, 6)
   732  	keys := []string{"0", "1", "2", "3", "4", "5", "6"}
   733  
   734  	// Add 4 entries to the cache and look them up.
   735  	for i := 0; i < 4; i++ {
   736  		cache.Put(keys[i], &cs[i])
   737  	}
   738  	for i := 0; i < 4; i++ {
   739  		if s, ok := cache.Get(keys[i]); !ok || s != &cs[i] {
   740  			t.Fatalf("session cache failed lookup for added key: %s", keys[i])
   741  		}
   742  	}
   743  
   744  	// Add 2 more entries to the cache. First 2 should be evicted.
   745  	for i := 4; i < 6; i++ {
   746  		cache.Put(keys[i], &cs[i])
   747  	}
   748  	for i := 0; i < 2; i++ {
   749  		if s, ok := cache.Get(keys[i]); ok || s != nil {
   750  			t.Fatalf("session cache should have evicted key: %s", keys[i])
   751  		}
   752  	}
   753  
   754  	// Touch entry 2. LRU should evict 3 next.
   755  	cache.Get(keys[2])
   756  	cache.Put(keys[0], &cs[0])
   757  	if s, ok := cache.Get(keys[3]); ok || s != nil {
   758  		t.Fatalf("session cache should have evicted key 3")
   759  	}
   760  
   761  	// Update entry 0 in place.
   762  	cache.Put(keys[0], &cs[3])
   763  	if s, ok := cache.Get(keys[0]); !ok || s != &cs[3] {
   764  		t.Fatalf("session cache failed update for key 0")
   765  	}
   766  
   767  	// Adding a nil entry is valid.
   768  	cache.Put(keys[0], nil)
   769  	if s, ok := cache.Get(keys[0]); !ok || s != nil {
   770  		t.Fatalf("failed to add nil entry to cache")
   771  	}
   772  }
   773  
   774  func TestKeyLog(t *testing.T) {
   775  	var serverBuf, clientBuf bytes.Buffer
   776  
   777  	clientConfig := testConfig.Clone()
   778  	clientConfig.KeyLogWriter = &clientBuf
   779  
   780  	serverConfig := testConfig.Clone()
   781  	serverConfig.KeyLogWriter = &serverBuf
   782  
   783  	c, s := net.Pipe()
   784  	done := make(chan bool)
   785  
   786  	go func() {
   787  		defer close(done)
   788  
   789  		if err := Server(s, serverConfig).Handshake(); err != nil {
   790  			t.Errorf("server: %s", err)
   791  			return
   792  		}
   793  		s.Close()
   794  	}()
   795  
   796  	if err := Client(c, clientConfig).Handshake(); err != nil {
   797  		t.Fatalf("client: %s", err)
   798  	}
   799  
   800  	c.Close()
   801  	<-done
   802  
   803  	checkKeylogLine := func(side, loggedLine string) {
   804  		if len(loggedLine) == 0 {
   805  			t.Fatalf("%s: no keylog line was produced", side)
   806  		}
   807  		const expectedLen = 13 /* "CLIENT_RANDOM" */ +
   808  			1 /* space */ +
   809  			32*2 /* hex client nonce */ +
   810  			1 /* space */ +
   811  			48*2 /* hex master secret */ +
   812  			1 /* new line */
   813  		if len(loggedLine) != expectedLen {
   814  			t.Fatalf("%s: keylog line has incorrect length (want %d, got %d): %q", side, expectedLen, len(loggedLine), loggedLine)
   815  		}
   816  		if !strings.HasPrefix(loggedLine, "CLIENT_RANDOM "+strings.Repeat("0", 64)+" ") {
   817  			t.Fatalf("%s: keylog line has incorrect structure or nonce: %q", side, loggedLine)
   818  		}
   819  	}
   820  
   821  	checkKeylogLine("client", string(clientBuf.Bytes()))
   822  	checkKeylogLine("server", string(serverBuf.Bytes()))
   823  }
   824  
   825  func TestHandshakeClientALPNMatch(t *testing.T) {
   826  	config := testConfig.Clone()
   827  	config.NextProtos = []string{"proto2", "proto1"}
   828  
   829  	test := &clientTest{
   830  		name: "ALPN",
   831  		// Note that this needs OpenSSL 1.0.2 because that is the first
   832  		// version that supports the -alpn flag.
   833  		command: []string{"openssl", "s_server", "-alpn", "proto1,proto2"},
   834  		config:  config,
   835  		validate: func(state ConnectionState) error {
   836  			// The server's preferences should override the client.
   837  			if state.NegotiatedProtocol != "proto1" {
   838  				return fmt.Errorf("Got protocol %q, wanted proto1", state.NegotiatedProtocol)
   839  			}
   840  			return nil
   841  		},
   842  	}
   843  	runClientTestTLS12(t, test)
   844  }
   845  
   846  // sctsBase64 contains data from `openssl s_client -serverinfo 18 -connect ritter.vg:443`
   847  const sctsBase64 = "ABIBaQFnAHUApLkJkLQYWBSHuxOizGdwCjw1mAT5G9+443fNDsgN3BAAAAFHl5nuFgAABAMARjBEAiAcS4JdlW5nW9sElUv2zvQyPoZ6ejKrGGB03gjaBZFMLwIgc1Qbbn+hsH0RvObzhS+XZhr3iuQQJY8S9G85D9KeGPAAdgBo9pj4H2SCvjqM7rkoHUz8cVFdZ5PURNEKZ6y7T0/7xAAAAUeX4bVwAAAEAwBHMEUCIDIhFDgG2HIuADBkGuLobU5a4dlCHoJLliWJ1SYT05z6AiEAjxIoZFFPRNWMGGIjskOTMwXzQ1Wh2e7NxXE1kd1J0QsAdgDuS723dc5guuFCaR+r4Z5mow9+X7By2IMAxHuJeqj9ywAAAUhcZIqHAAAEAwBHMEUCICmJ1rBT09LpkbzxtUC+Hi7nXLR0J+2PmwLp+sJMuqK+AiEAr0NkUnEVKVhAkccIFpYDqHOlZaBsuEhWWrYpg2RtKp0="
   848  
   849  func TestHandshakClientSCTs(t *testing.T) {
   850  	config := testConfig.Clone()
   851  
   852  	scts, err := base64.StdEncoding.DecodeString(sctsBase64)
   853  	if err != nil {
   854  		t.Fatal(err)
   855  	}
   856  
   857  	test := &clientTest{
   858  		name: "SCT",
   859  		// Note that this needs OpenSSL 1.0.2 because that is the first
   860  		// version that supports the -serverinfo flag.
   861  		command:    []string{"openssl", "s_server"},
   862  		config:     config,
   863  		extensions: [][]byte{scts},
   864  		validate: func(state ConnectionState) error {
   865  			expectedSCTs := [][]byte{
   866  				scts[8:125],
   867  				scts[127:245],
   868  				scts[247:],
   869  			}
   870  			if n := len(state.SignedCertificateTimestamps); n != len(expectedSCTs) {
   871  				return fmt.Errorf("Got %d scts, wanted %d", n, len(expectedSCTs))
   872  			}
   873  			for i, expected := range expectedSCTs {
   874  				if sct := state.SignedCertificateTimestamps[i]; !bytes.Equal(sct, expected) {
   875  					return fmt.Errorf("SCT #%d contained %x, expected %x", i, sct, expected)
   876  				}
   877  			}
   878  			return nil
   879  		},
   880  	}
   881  	runClientTestTLS12(t, test)
   882  }
   883  
   884  func TestRenegotiationRejected(t *testing.T) {
   885  	config := testConfig.Clone()
   886  	test := &clientTest{
   887  		name:                        "RenegotiationRejected",
   888  		command:                     []string{"openssl", "s_server", "-state"},
   889  		config:                      config,
   890  		numRenegotiations:           1,
   891  		renegotiationExpectedToFail: 1,
   892  		checkRenegotiationError: func(renegotiationNum int, err error) error {
   893  			if err == nil {
   894  				return errors.New("expected error from renegotiation but got nil")
   895  			}
   896  			if !strings.Contains(err.Error(), "no renegotiation") {
   897  				return fmt.Errorf("expected renegotiation to be rejected but got %q", err)
   898  			}
   899  			return nil
   900  		},
   901  	}
   902  
   903  	runClientTestTLS12(t, test)
   904  }
   905  
   906  func TestRenegotiateOnce(t *testing.T) {
   907  	config := testConfig.Clone()
   908  	config.Renegotiation = RenegotiateOnceAsClient
   909  
   910  	test := &clientTest{
   911  		name:              "RenegotiateOnce",
   912  		command:           []string{"openssl", "s_server", "-state"},
   913  		config:            config,
   914  		numRenegotiations: 1,
   915  	}
   916  
   917  	runClientTestTLS12(t, test)
   918  }
   919  
   920  func TestRenegotiateTwice(t *testing.T) {
   921  	config := testConfig.Clone()
   922  	config.Renegotiation = RenegotiateFreelyAsClient
   923  
   924  	test := &clientTest{
   925  		name:              "RenegotiateTwice",
   926  		command:           []string{"openssl", "s_server", "-state"},
   927  		config:            config,
   928  		numRenegotiations: 2,
   929  	}
   930  
   931  	runClientTestTLS12(t, test)
   932  }
   933  
   934  func TestRenegotiateTwiceRejected(t *testing.T) {
   935  	config := testConfig.Clone()
   936  	config.Renegotiation = RenegotiateOnceAsClient
   937  
   938  	test := &clientTest{
   939  		name:                        "RenegotiateTwiceRejected",
   940  		command:                     []string{"openssl", "s_server", "-state"},
   941  		config:                      config,
   942  		numRenegotiations:           2,
   943  		renegotiationExpectedToFail: 2,
   944  		checkRenegotiationError: func(renegotiationNum int, err error) error {
   945  			if renegotiationNum == 1 {
   946  				return err
   947  			}
   948  
   949  			if err == nil {
   950  				return errors.New("expected error from renegotiation but got nil")
   951  			}
   952  			if !strings.Contains(err.Error(), "no renegotiation") {
   953  				return fmt.Errorf("expected renegotiation to be rejected but got %q", err)
   954  			}
   955  			return nil
   956  		},
   957  	}
   958  
   959  	runClientTestTLS12(t, test)
   960  }
   961  
   962  var hostnameInSNITests = []struct {
   963  	in, out string
   964  }{
   965  	// Opaque string
   966  	{"", ""},
   967  	{"localhost", "localhost"},
   968  	{"foo, bar, baz and qux", "foo, bar, baz and qux"},
   969  
   970  	// DNS hostname
   971  	{"golang.org", "golang.org"},
   972  	{"golang.org.", "golang.org"},
   973  
   974  	// Literal IPv4 address
   975  	{"1.2.3.4", ""},
   976  
   977  	// Literal IPv6 address
   978  	{"::1", ""},
   979  	{"::1%lo0", ""}, // with zone identifier
   980  	{"[::1]", ""},   // as per RFC 5952 we allow the [] style as IPv6 literal
   981  	{"[::1%lo0]", ""},
   982  }
   983  
   984  func TestHostnameInSNI(t *testing.T) {
   985  	for _, tt := range hostnameInSNITests {
   986  		c, s := net.Pipe()
   987  
   988  		go func(host string) {
   989  			Client(c, &Config{ServerName: host, InsecureSkipVerify: true}).Handshake()
   990  		}(tt.in)
   991  
   992  		var header [5]byte
   993  		if _, err := io.ReadFull(s, header[:]); err != nil {
   994  			t.Fatal(err)
   995  		}
   996  		recordLen := int(header[3])<<8 | int(header[4])
   997  
   998  		record := make([]byte, recordLen)
   999  		if _, err := io.ReadFull(s, record[:]); err != nil {
  1000  			t.Fatal(err)
  1001  		}
  1002  
  1003  		c.Close()
  1004  		s.Close()
  1005  
  1006  		var m clientHelloMsg
  1007  		if !m.unmarshal(record) {
  1008  			t.Errorf("unmarshaling ClientHello for %q failed", tt.in)
  1009  			continue
  1010  		}
  1011  		if tt.in != tt.out && m.serverName == tt.in {
  1012  			t.Errorf("prohibited %q found in ClientHello: %x", tt.in, record)
  1013  		}
  1014  		if m.serverName != tt.out {
  1015  			t.Errorf("expected %q not found in ClientHello: %x", tt.out, record)
  1016  		}
  1017  	}
  1018  }
  1019  
  1020  func TestServerSelectingUnconfiguredCipherSuite(t *testing.T) {
  1021  	// This checks that the server can't select a cipher suite that the
  1022  	// client didn't offer. See #13174.
  1023  
  1024  	c, s := net.Pipe()
  1025  	errChan := make(chan error, 1)
  1026  
  1027  	go func() {
  1028  		client := Client(c, &Config{
  1029  			ServerName:   "foo",
  1030  			CipherSuites: []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256},
  1031  		})
  1032  		errChan <- client.Handshake()
  1033  	}()
  1034  
  1035  	var header [5]byte
  1036  	if _, err := io.ReadFull(s, header[:]); err != nil {
  1037  		t.Fatal(err)
  1038  	}
  1039  	recordLen := int(header[3])<<8 | int(header[4])
  1040  
  1041  	record := make([]byte, recordLen)
  1042  	if _, err := io.ReadFull(s, record); err != nil {
  1043  		t.Fatal(err)
  1044  	}
  1045  
  1046  	// Create a ServerHello that selects a different cipher suite than the
  1047  	// sole one that the client offered.
  1048  	serverHello := &serverHelloMsg{
  1049  		vers:        VersionTLS12,
  1050  		random:      make([]byte, 32),
  1051  		cipherSuite: TLS_RSA_WITH_AES_256_GCM_SHA384,
  1052  	}
  1053  	serverHelloBytes := serverHello.marshal()
  1054  
  1055  	s.Write([]byte{
  1056  		byte(recordTypeHandshake),
  1057  		byte(VersionTLS12 >> 8),
  1058  		byte(VersionTLS12 & 0xff),
  1059  		byte(len(serverHelloBytes) >> 8),
  1060  		byte(len(serverHelloBytes)),
  1061  	})
  1062  	s.Write(serverHelloBytes)
  1063  	s.Close()
  1064  
  1065  	if err := <-errChan; !strings.Contains(err.Error(), "unconfigured cipher") {
  1066  		t.Fatalf("Expected error about unconfigured cipher suite but got %q", err)
  1067  	}
  1068  }
  1069  
  1070  func TestVerifyPeerCertificate(t *testing.T) {
  1071  	issuer, err := x509.ParseCertificate(testRSACertificateIssuer)
  1072  	if err != nil {
  1073  		panic(err)
  1074  	}
  1075  
  1076  	rootCAs := x509.NewCertPool()
  1077  	rootCAs.AddCert(issuer)
  1078  
  1079  	now := func() time.Time { return time.Unix(1476984729, 0) }
  1080  
  1081  	sentinelErr := errors.New("TestVerifyPeerCertificate")
  1082  
  1083  	verifyCallback := func(called *bool, rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
  1084  		if l := len(rawCerts); l != 1 {
  1085  			return fmt.Errorf("got len(rawCerts) = %d, wanted 1", l)
  1086  		}
  1087  		if len(validatedChains) == 0 {
  1088  			return errors.New("got len(validatedChains) = 0, wanted non-zero")
  1089  		}
  1090  		*called = true
  1091  		return nil
  1092  	}
  1093  
  1094  	tests := []struct {
  1095  		configureServer func(*Config, *bool)
  1096  		configureClient func(*Config, *bool)
  1097  		validate        func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error)
  1098  	}{
  1099  		{
  1100  			configureServer: func(config *Config, called *bool) {
  1101  				config.InsecureSkipVerify = false
  1102  				config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
  1103  					return verifyCallback(called, rawCerts, validatedChains)
  1104  				}
  1105  			},
  1106  			configureClient: func(config *Config, called *bool) {
  1107  				config.InsecureSkipVerify = false
  1108  				config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
  1109  					return verifyCallback(called, rawCerts, validatedChains)
  1110  				}
  1111  			},
  1112  			validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
  1113  				if clientErr != nil {
  1114  					t.Errorf("test[%d]: client handshake failed: %v", testNo, clientErr)
  1115  				}
  1116  				if serverErr != nil {
  1117  					t.Errorf("test[%d]: server handshake failed: %v", testNo, serverErr)
  1118  				}
  1119  				if !clientCalled {
  1120  					t.Errorf("test[%d]: client did not call callback", testNo)
  1121  				}
  1122  				if !serverCalled {
  1123  					t.Errorf("test[%d]: server did not call callback", testNo)
  1124  				}
  1125  			},
  1126  		},
  1127  		{
  1128  			configureServer: func(config *Config, called *bool) {
  1129  				config.InsecureSkipVerify = false
  1130  				config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
  1131  					return sentinelErr
  1132  				}
  1133  			},
  1134  			configureClient: func(config *Config, called *bool) {
  1135  				config.VerifyPeerCertificate = nil
  1136  			},
  1137  			validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
  1138  				if serverErr != sentinelErr {
  1139  					t.Errorf("#%d: got server error %v, wanted sentinelErr", testNo, serverErr)
  1140  				}
  1141  			},
  1142  		},
  1143  		{
  1144  			configureServer: func(config *Config, called *bool) {
  1145  				config.InsecureSkipVerify = false
  1146  			},
  1147  			configureClient: func(config *Config, called *bool) {
  1148  				config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
  1149  					return sentinelErr
  1150  				}
  1151  			},
  1152  			validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
  1153  				if clientErr != sentinelErr {
  1154  					t.Errorf("#%d: got client error %v, wanted sentinelErr", testNo, clientErr)
  1155  				}
  1156  			},
  1157  		},
  1158  		{
  1159  			configureServer: func(config *Config, called *bool) {
  1160  				config.InsecureSkipVerify = false
  1161  			},
  1162  			configureClient: func(config *Config, called *bool) {
  1163  				config.InsecureSkipVerify = true
  1164  				config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
  1165  					if l := len(rawCerts); l != 1 {
  1166  						return fmt.Errorf("got len(rawCerts) = %d, wanted 1", l)
  1167  					}
  1168  					// With InsecureSkipVerify set, this
  1169  					// callback should still be called but
  1170  					// validatedChains must be empty.
  1171  					if l := len(validatedChains); l != 0 {
  1172  						return errors.New("got len(validatedChains) = 0, wanted zero")
  1173  					}
  1174  					*called = true
  1175  					return nil
  1176  				}
  1177  			},
  1178  			validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
  1179  				if clientErr != nil {
  1180  					t.Errorf("test[%d]: client handshake failed: %v", testNo, clientErr)
  1181  				}
  1182  				if serverErr != nil {
  1183  					t.Errorf("test[%d]: server handshake failed: %v", testNo, serverErr)
  1184  				}
  1185  				if !clientCalled {
  1186  					t.Errorf("test[%d]: client did not call callback", testNo)
  1187  				}
  1188  			},
  1189  		},
  1190  	}
  1191  
  1192  	for i, test := range tests {
  1193  		c, s := net.Pipe()
  1194  		done := make(chan error)
  1195  
  1196  		var clientCalled, serverCalled bool
  1197  
  1198  		go func() {
  1199  			config := testConfig.Clone()
  1200  			config.ServerName = "example.golang"
  1201  			config.ClientAuth = RequireAndVerifyClientCert
  1202  			config.ClientCAs = rootCAs
  1203  			config.Time = now
  1204  			test.configureServer(config, &serverCalled)
  1205  
  1206  			err = Server(s, config).Handshake()
  1207  			s.Close()
  1208  			done <- err
  1209  		}()
  1210  
  1211  		config := testConfig.Clone()
  1212  		config.ServerName = "example.golang"
  1213  		config.RootCAs = rootCAs
  1214  		config.Time = now
  1215  		test.configureClient(config, &clientCalled)
  1216  		clientErr := Client(c, config).Handshake()
  1217  		c.Close()
  1218  		serverErr := <-done
  1219  
  1220  		test.validate(t, i, clientCalled, serverCalled, clientErr, serverErr)
  1221  	}
  1222  }
  1223  
  1224  // brokenConn wraps a net.Conn and causes all Writes after a certain number to
  1225  // fail with brokenConnErr.
  1226  type brokenConn struct {
  1227  	net.Conn
  1228  
  1229  	// breakAfter is the number of successful writes that will be allowed
  1230  	// before all subsequent writes fail.
  1231  	breakAfter int
  1232  
  1233  	// numWrites is the number of writes that have been done.
  1234  	numWrites int
  1235  }
  1236  
  1237  // brokenConnErr is the error that brokenConn returns once exhausted.
  1238  var brokenConnErr = errors.New("too many writes to brokenConn")
  1239  
  1240  func (b *brokenConn) Write(data []byte) (int, error) {
  1241  	if b.numWrites >= b.breakAfter {
  1242  		return 0, brokenConnErr
  1243  	}
  1244  
  1245  	b.numWrites++
  1246  	return b.Conn.Write(data)
  1247  }
  1248  
  1249  func TestFailedWrite(t *testing.T) {
  1250  	// Test that a write error during the handshake is returned.
  1251  	for _, breakAfter := range []int{0, 1} {
  1252  		c, s := net.Pipe()
  1253  		done := make(chan bool)
  1254  
  1255  		go func() {
  1256  			Server(s, testConfig).Handshake()
  1257  			s.Close()
  1258  			done <- true
  1259  		}()
  1260  
  1261  		brokenC := &brokenConn{Conn: c, breakAfter: breakAfter}
  1262  		err := Client(brokenC, testConfig).Handshake()
  1263  		if err != brokenConnErr {
  1264  			t.Errorf("#%d: expected error from brokenConn but got %q", breakAfter, err)
  1265  		}
  1266  		brokenC.Close()
  1267  
  1268  		<-done
  1269  	}
  1270  }
  1271  
  1272  // writeCountingConn wraps a net.Conn and counts the number of Write calls.
  1273  type writeCountingConn struct {
  1274  	net.Conn
  1275  
  1276  	// numWrites is the number of writes that have been done.
  1277  	numWrites int
  1278  }
  1279  
  1280  func (wcc *writeCountingConn) Write(data []byte) (int, error) {
  1281  	wcc.numWrites++
  1282  	return wcc.Conn.Write(data)
  1283  }
  1284  
  1285  func TestBuffering(t *testing.T) {
  1286  	c, s := net.Pipe()
  1287  	done := make(chan bool)
  1288  
  1289  	clientWCC := &writeCountingConn{Conn: c}
  1290  	serverWCC := &writeCountingConn{Conn: s}
  1291  
  1292  	go func() {
  1293  		Server(serverWCC, testConfig).Handshake()
  1294  		serverWCC.Close()
  1295  		done <- true
  1296  	}()
  1297  
  1298  	err := Client(clientWCC, testConfig).Handshake()
  1299  	if err != nil {
  1300  		t.Fatal(err)
  1301  	}
  1302  	clientWCC.Close()
  1303  	<-done
  1304  
  1305  	if n := clientWCC.numWrites; n != 2 {
  1306  		t.Errorf("expected client handshake to complete with only two writes, but saw %d", n)
  1307  	}
  1308  
  1309  	if n := serverWCC.numWrites; n != 2 {
  1310  		t.Errorf("expected server handshake to complete with only two writes, but saw %d", n)
  1311  	}
  1312  }
  1313  
  1314  func TestAlertFlushing(t *testing.T) {
  1315  	c, s := net.Pipe()
  1316  	done := make(chan bool)
  1317  
  1318  	clientWCC := &writeCountingConn{Conn: c}
  1319  	serverWCC := &writeCountingConn{Conn: s}
  1320  
  1321  	serverConfig := testConfig.Clone()
  1322  
  1323  	// Cause a signature-time error
  1324  	brokenKey := rsa.PrivateKey{PublicKey: testRSAPrivateKey.PublicKey}
  1325  	brokenKey.D = big.NewInt(42)
  1326  	serverConfig.Certificates = []Certificate{{
  1327  		Certificate: [][]byte{testRSACertificate},
  1328  		PrivateKey:  &brokenKey,
  1329  	}}
  1330  
  1331  	go func() {
  1332  		Server(serverWCC, serverConfig).Handshake()
  1333  		serverWCC.Close()
  1334  		done <- true
  1335  	}()
  1336  
  1337  	err := Client(clientWCC, testConfig).Handshake()
  1338  	if err == nil {
  1339  		t.Fatal("client unexpectedly returned no error")
  1340  	}
  1341  
  1342  	const expectedError = "remote error: tls: handshake failure"
  1343  	if e := err.Error(); !strings.Contains(e, expectedError) {
  1344  		t.Fatalf("expected to find %q in error but error was %q", expectedError, e)
  1345  	}
  1346  	clientWCC.Close()
  1347  	<-done
  1348  
  1349  	if n := clientWCC.numWrites; n != 1 {
  1350  		t.Errorf("expected client handshake to complete with one write, but saw %d", n)
  1351  	}
  1352  
  1353  	if n := serverWCC.numWrites; n != 1 {
  1354  		t.Errorf("expected server handshake to complete with one write, but saw %d", n)
  1355  	}
  1356  }
  1357  
  1358  func TestHandshakeRace(t *testing.T) {
  1359  	// This test races a Read and Write to try and complete a handshake in
  1360  	// order to provide some evidence that there are no races or deadlocks
  1361  	// in the handshake locking.
  1362  	for i := 0; i < 32; i++ {
  1363  		c, s := net.Pipe()
  1364  
  1365  		go func() {
  1366  			server := Server(s, testConfig)
  1367  			if err := server.Handshake(); err != nil {
  1368  				panic(err)
  1369  			}
  1370  
  1371  			var request [1]byte
  1372  			if n, err := server.Read(request[:]); err != nil || n != 1 {
  1373  				panic(err)
  1374  			}
  1375  
  1376  			server.Write(request[:])
  1377  			server.Close()
  1378  		}()
  1379  
  1380  		startWrite := make(chan struct{})
  1381  		startRead := make(chan struct{})
  1382  		readDone := make(chan struct{})
  1383  
  1384  		client := Client(c, testConfig)
  1385  		go func() {
  1386  			<-startWrite
  1387  			var request [1]byte
  1388  			client.Write(request[:])
  1389  		}()
  1390  
  1391  		go func() {
  1392  			<-startRead
  1393  			var reply [1]byte
  1394  			if n, err := client.Read(reply[:]); err != nil || n != 1 {
  1395  				panic(err)
  1396  			}
  1397  			c.Close()
  1398  			readDone <- struct{}{}
  1399  		}()
  1400  
  1401  		if i&1 == 1 {
  1402  			startWrite <- struct{}{}
  1403  			startRead <- struct{}{}
  1404  		} else {
  1405  			startRead <- struct{}{}
  1406  			startWrite <- struct{}{}
  1407  		}
  1408  		<-readDone
  1409  	}
  1410  }
  1411  
  1412  func TestTLS11SignatureSchemes(t *testing.T) {
  1413  	expected := tls11SignatureSchemesNumECDSA + tls11SignatureSchemesNumRSA
  1414  	if expected != len(tls11SignatureSchemes) {
  1415  		t.Errorf("expected to find %d TLS 1.1 signature schemes, but found %d", expected, len(tls11SignatureSchemes))
  1416  	}
  1417  }
  1418  
  1419  var getClientCertificateTests = []struct {
  1420  	setup               func(*Config)
  1421  	expectedClientError string
  1422  	verify              func(*testing.T, int, *ConnectionState)
  1423  }{
  1424  	{
  1425  		func(clientConfig *Config) {
  1426  			// Returning a Certificate with no certificate data
  1427  			// should result in an empty message being sent to the
  1428  			// server.
  1429  			clientConfig.GetClientCertificate = func(cri *CertificateRequestInfo) (*Certificate, error) {
  1430  				if len(cri.SignatureSchemes) == 0 {
  1431  					panic("empty SignatureSchemes")
  1432  				}
  1433  				return new(Certificate), nil
  1434  			}
  1435  		},
  1436  		"",
  1437  		func(t *testing.T, testNum int, cs *ConnectionState) {
  1438  			if l := len(cs.PeerCertificates); l != 0 {
  1439  				t.Errorf("#%d: expected no certificates but got %d", testNum, l)
  1440  			}
  1441  		},
  1442  	},
  1443  	{
  1444  		func(clientConfig *Config) {
  1445  			// With TLS 1.1, the SignatureSchemes should be
  1446  			// synthesised from the supported certificate types.
  1447  			clientConfig.MaxVersion = VersionTLS11
  1448  			clientConfig.GetClientCertificate = func(cri *CertificateRequestInfo) (*Certificate, error) {
  1449  				if len(cri.SignatureSchemes) == 0 {
  1450  					panic("empty SignatureSchemes")
  1451  				}
  1452  				return new(Certificate), nil
  1453  			}
  1454  		},
  1455  		"",
  1456  		func(t *testing.T, testNum int, cs *ConnectionState) {
  1457  			if l := len(cs.PeerCertificates); l != 0 {
  1458  				t.Errorf("#%d: expected no certificates but got %d", testNum, l)
  1459  			}
  1460  		},
  1461  	},
  1462  	{
  1463  		func(clientConfig *Config) {
  1464  			// Returning an error should abort the handshake with
  1465  			// that error.
  1466  			clientConfig.GetClientCertificate = func(cri *CertificateRequestInfo) (*Certificate, error) {
  1467  				return nil, errors.New("GetClientCertificate")
  1468  			}
  1469  		},
  1470  		"GetClientCertificate",
  1471  		func(t *testing.T, testNum int, cs *ConnectionState) {
  1472  		},
  1473  	},
  1474  	{
  1475  		func(clientConfig *Config) {
  1476  			clientConfig.GetClientCertificate = func(cri *CertificateRequestInfo) (*Certificate, error) {
  1477  				return &testConfig.Certificates[0], nil
  1478  			}
  1479  		},
  1480  		"",
  1481  		func(t *testing.T, testNum int, cs *ConnectionState) {
  1482  			if l := len(cs.VerifiedChains); l != 0 {
  1483  				t.Errorf("#%d: expected some verified chains, but found none", testNum)
  1484  			}
  1485  		},
  1486  	},
  1487  }
  1488  
  1489  func TestGetClientCertificate(t *testing.T) {
  1490  	issuer, err := x509.ParseCertificate(testRSACertificateIssuer)
  1491  	if err != nil {
  1492  		panic(err)
  1493  	}
  1494  
  1495  	for i, test := range getClientCertificateTests {
  1496  		serverConfig := testConfig.Clone()
  1497  		serverConfig.ClientAuth = RequestClientCert
  1498  		serverConfig.RootCAs = x509.NewCertPool()
  1499  		serverConfig.RootCAs.AddCert(issuer)
  1500  
  1501  		clientConfig := testConfig.Clone()
  1502  
  1503  		test.setup(clientConfig)
  1504  
  1505  		type serverResult struct {
  1506  			cs  ConnectionState
  1507  			err error
  1508  		}
  1509  
  1510  		c, s := net.Pipe()
  1511  		done := make(chan serverResult)
  1512  
  1513  		go func() {
  1514  			defer s.Close()
  1515  			server := Server(s, serverConfig)
  1516  			err := server.Handshake()
  1517  
  1518  			var cs ConnectionState
  1519  			if err == nil {
  1520  				cs = server.ConnectionState()
  1521  			}
  1522  			done <- serverResult{cs, err}
  1523  		}()
  1524  
  1525  		clientErr := Client(c, clientConfig).Handshake()
  1526  		c.Close()
  1527  
  1528  		result := <-done
  1529  
  1530  		if clientErr != nil {
  1531  			if len(test.expectedClientError) == 0 {
  1532  				t.Errorf("#%d: client error: %v", i, clientErr)
  1533  			} else if got := clientErr.Error(); got != test.expectedClientError {
  1534  				t.Errorf("#%d: expected client error %q, but got %q", i, test.expectedClientError, got)
  1535  			}
  1536  		} else if len(test.expectedClientError) > 0 {
  1537  			t.Errorf("#%d: expected client error %q, but got no error", i, test.expectedClientError)
  1538  		} else if err := result.err; err != nil {
  1539  			t.Errorf("#%d: server error: %v", i, err)
  1540  		} else {
  1541  			test.verify(t, i, &result.cs)
  1542  		}
  1543  	}
  1544  }