github.com/gidoBOSSftw5731/go/src@v0.0.0-20210226122457-d24b0edbf019/crypto/tls/handshake_client_test.go (about)

     1  // Copyright 2010 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package tls
     6  
     7  import (
     8  	"bytes"
     9  	"crypto/rsa"
    10  	"crypto/x509"
    11  	"encoding/base64"
    12  	"encoding/binary"
    13  	"encoding/pem"
    14  	"errors"
    15  	"fmt"
    16  	"io"
    17  	"math/big"
    18  	"net"
    19  	"os"
    20  	"os/exec"
    21  	"path/filepath"
    22  	"reflect"
    23  	"strconv"
    24  	"strings"
    25  	"testing"
    26  	"time"
    27  )
    28  
    29  // Note: see comment in handshake_test.go for details of how the reference
    30  // tests work.
    31  
    32  // opensslInputEvent enumerates possible inputs that can be sent to an `openssl
    33  // s_client` process.
    34  type opensslInputEvent int
    35  
    36  const (
    37  	// opensslRenegotiate causes OpenSSL to request a renegotiation of the
    38  	// connection.
    39  	opensslRenegotiate opensslInputEvent = iota
    40  
    41  	// opensslSendBanner causes OpenSSL to send the contents of
    42  	// opensslSentinel on the connection.
    43  	opensslSendSentinel
    44  
    45  	// opensslKeyUpdate causes OpenSSL to send send a key update message to the
    46  	// client and request one back.
    47  	opensslKeyUpdate
    48  )
    49  
    50  const opensslSentinel = "SENTINEL\n"
    51  
    52  type opensslInput chan opensslInputEvent
    53  
    54  func (i opensslInput) Read(buf []byte) (n int, err error) {
    55  	for event := range i {
    56  		switch event {
    57  		case opensslRenegotiate:
    58  			return copy(buf, []byte("R\n")), nil
    59  		case opensslKeyUpdate:
    60  			return copy(buf, []byte("K\n")), nil
    61  		case opensslSendSentinel:
    62  			return copy(buf, []byte(opensslSentinel)), nil
    63  		default:
    64  			panic("unknown event")
    65  		}
    66  	}
    67  
    68  	return 0, io.EOF
    69  }
    70  
    71  // opensslOutputSink is an io.Writer that receives the stdout and stderr from an
    72  // `openssl` process and sends a value to handshakeComplete or readKeyUpdate
    73  // when certain messages are seen.
    74  type opensslOutputSink struct {
    75  	handshakeComplete chan struct{}
    76  	readKeyUpdate     chan struct{}
    77  	all               []byte
    78  	line              []byte
    79  }
    80  
    81  func newOpensslOutputSink() *opensslOutputSink {
    82  	return &opensslOutputSink{make(chan struct{}), make(chan struct{}), nil, nil}
    83  }
    84  
    85  // opensslEndOfHandshake is a message that the “openssl s_server” tool will
    86  // print when a handshake completes if run with “-state”.
    87  const opensslEndOfHandshake = "SSL_accept:SSLv3/TLS write finished"
    88  
    89  // opensslReadKeyUpdate is a message that the “openssl s_server” tool will
    90  // print when a KeyUpdate message is received if run with “-state”.
    91  const opensslReadKeyUpdate = "SSL_accept:TLSv1.3 read client key update"
    92  
    93  func (o *opensslOutputSink) Write(data []byte) (n int, err error) {
    94  	o.line = append(o.line, data...)
    95  	o.all = append(o.all, data...)
    96  
    97  	for {
    98  		i := bytes.IndexByte(o.line, '\n')
    99  		if i < 0 {
   100  			break
   101  		}
   102  
   103  		if bytes.Equal([]byte(opensslEndOfHandshake), o.line[:i]) {
   104  			o.handshakeComplete <- struct{}{}
   105  		}
   106  		if bytes.Equal([]byte(opensslReadKeyUpdate), o.line[:i]) {
   107  			o.readKeyUpdate <- struct{}{}
   108  		}
   109  		o.line = o.line[i+1:]
   110  	}
   111  
   112  	return len(data), nil
   113  }
   114  
   115  func (o *opensslOutputSink) String() string {
   116  	return string(o.all)
   117  }
   118  
// clientTest represents a test of the TLS client handshake against a reference
// implementation: an `openssl s_server` process started from serverCommand.
// In write mode the live exchange is recorded to dataPath; otherwise the
// recording is replayed and the client's traffic is checked against it (see
// run).
type clientTest struct {
	// name is a freeform string identifying the test and the file in which
	// the expected results will be stored (testdata/Client-<name>, see
	// dataPath). runClientTestForVersion prefixes it with the version name.
	name string
	// args, if not empty, contains a series of arguments for the
	// command to run for the reference server. They are appended to
	// serverCommand, before the certificate, key and port flags.
	args []string
	// config, if not nil, contains a custom Config to use for this test.
	// When nil, testConfig is used.
	config *Config
	// cert, if not empty, contains a DER-encoded certificate for the
	// reference server. When empty, testRSACertificate is used.
	cert []byte
	// key, if not nil, contains either a *rsa.PrivateKey, ed25519.PrivateKey or
	// *ecdsa.PrivateKey which is the private key for the reference server.
	// It is encoded with x509.MarshalPKCS8PrivateKey; when nil,
	// testRSAPrivateKey is used.
	key interface{}
	// extensions, if not nil, contains a list of extension data to be returned
	// from the ServerHello. The data should be in standard TLS format with
	// a 2-byte uint16 type, 2-byte data length, followed by the extension data.
	// They are passed to OpenSSL in a PEM file via -serverinfo.
	extensions [][]byte
	// validate, if not nil, is a function that will be called with the
	// ConnectionState of the resulting connection. It returns a non-nil
	// error if the ConnectionState is unacceptable.
	validate func(ConnectionState) error
	// numRenegotiations is the number of times that the connection will be
	// renegotiated. A non-zero value requires "-state" in args, because the
	// harness waits for OpenSSL to report each completed handshake.
	numRenegotiations int
	// renegotiationExpectedToFail, if not zero, is the number (counting from
	// 1) of the renegotiation attempt that is expected to fail.
	renegotiationExpectedToFail int
	// checkRenegotiationError, if not nil, is called with any error
	// arising from renegotiation. It can map expected errors to nil to
	// ignore them.
	checkRenegotiationError func(renegotiationNum int, err error) error
	// sendKeyUpdate will cause the server to send a KeyUpdate message once
	// the handshake has completed. Like numRenegotiations, it requires
	// "-state" in args.
	sendKeyUpdate bool
}

// serverCommand is the command-line prefix used to start the reference
// server. connFromCommand appends the per-test args, the certificate and key
// flags and the listening port to it.
var serverCommand = []string{"openssl", "s_server", "-no_ticket", "-num_tickets", "0"}
   159  
   160  // connFromCommand starts the reference server process, connects to it and
   161  // returns a recordingConn for the connection. The stdin return value is an
   162  // opensslInput for the stdin of the child process. It must be closed before
   163  // Waiting for child.
   164  func (test *clientTest) connFromCommand() (conn *recordingConn, child *exec.Cmd, stdin opensslInput, stdout *opensslOutputSink, err error) {
   165  	cert := testRSACertificate
   166  	if len(test.cert) > 0 {
   167  		cert = test.cert
   168  	}
   169  	certPath := tempFile(string(cert))
   170  	defer os.Remove(certPath)
   171  
   172  	var key interface{} = testRSAPrivateKey
   173  	if test.key != nil {
   174  		key = test.key
   175  	}
   176  	derBytes, err := x509.MarshalPKCS8PrivateKey(key)
   177  	if err != nil {
   178  		panic(err)
   179  	}
   180  
   181  	var pemOut bytes.Buffer
   182  	pem.Encode(&pemOut, &pem.Block{Type: "PRIVATE KEY", Bytes: derBytes})
   183  
   184  	keyPath := tempFile(pemOut.String())
   185  	defer os.Remove(keyPath)
   186  
   187  	var command []string
   188  	command = append(command, serverCommand...)
   189  	command = append(command, test.args...)
   190  	command = append(command, "-cert", certPath, "-certform", "DER", "-key", keyPath)
   191  	// serverPort contains the port that OpenSSL will listen on. OpenSSL
   192  	// can't take "0" as an argument here so we have to pick a number and
   193  	// hope that it's not in use on the machine. Since this only occurs
   194  	// when -update is given and thus when there's a human watching the
   195  	// test, this isn't too bad.
   196  	const serverPort = 24323
   197  	command = append(command, "-accept", strconv.Itoa(serverPort))
   198  
   199  	if len(test.extensions) > 0 {
   200  		var serverInfo bytes.Buffer
   201  		for _, ext := range test.extensions {
   202  			pem.Encode(&serverInfo, &pem.Block{
   203  				Type:  fmt.Sprintf("SERVERINFO FOR EXTENSION %d", binary.BigEndian.Uint16(ext)),
   204  				Bytes: ext,
   205  			})
   206  		}
   207  		serverInfoPath := tempFile(serverInfo.String())
   208  		defer os.Remove(serverInfoPath)
   209  		command = append(command, "-serverinfo", serverInfoPath)
   210  	}
   211  
   212  	if test.numRenegotiations > 0 || test.sendKeyUpdate {
   213  		found := false
   214  		for _, flag := range command[1:] {
   215  			if flag == "-state" {
   216  				found = true
   217  				break
   218  			}
   219  		}
   220  
   221  		if !found {
   222  			panic("-state flag missing to OpenSSL, you need this if testing renegotiation or KeyUpdate")
   223  		}
   224  	}
   225  
   226  	cmd := exec.Command(command[0], command[1:]...)
   227  	stdin = opensslInput(make(chan opensslInputEvent))
   228  	cmd.Stdin = stdin
   229  	out := newOpensslOutputSink()
   230  	cmd.Stdout = out
   231  	cmd.Stderr = out
   232  	if err := cmd.Start(); err != nil {
   233  		return nil, nil, nil, nil, err
   234  	}
   235  
   236  	// OpenSSL does print an "ACCEPT" banner, but it does so *before*
   237  	// opening the listening socket, so we can't use that to wait until it
   238  	// has started listening. Thus we are forced to poll until we get a
   239  	// connection.
   240  	var tcpConn net.Conn
   241  	for i := uint(0); i < 5; i++ {
   242  		tcpConn, err = net.DialTCP("tcp", nil, &net.TCPAddr{
   243  			IP:   net.IPv4(127, 0, 0, 1),
   244  			Port: serverPort,
   245  		})
   246  		if err == nil {
   247  			break
   248  		}
   249  		time.Sleep((1 << i) * 5 * time.Millisecond)
   250  	}
   251  	if err != nil {
   252  		close(stdin)
   253  		cmd.Process.Kill()
   254  		err = fmt.Errorf("error connecting to the OpenSSL server: %v (%v)\n\n%s", err, cmd.Wait(), out)
   255  		return nil, nil, nil, nil, err
   256  	}
   257  
   258  	record := &recordingConn{
   259  		Conn: tcpConn,
   260  	}
   261  
   262  	return record, cmd, stdin, out, nil
   263  }
   264  
   265  func (test *clientTest) dataPath() string {
   266  	return filepath.Join("testdata", "Client-"+test.name)
   267  }
   268  
   269  func (test *clientTest) loadData() (flows [][]byte, err error) {
   270  	in, err := os.Open(test.dataPath())
   271  	if err != nil {
   272  		return nil, err
   273  	}
   274  	defer in.Close()
   275  	return parseTestData(in)
   276  }
   277  
// run executes the test.
//
// When write is true, the client talks to a live OpenSSL server started by
// connFromCommand. Once the client goroutine has finished, the recorded
// records are saved to dataPath. When write is false, the flows stored at
// dataPath are replayed over an in-memory pipe. This goroutine then plays the
// server and checks that every byte the client sends matches the recording.
//
// The TLS client always runs in a separate goroutine. Failures there are
// reported with t.Errorf, because t.Fatalf (FailNow) may only be called from
// the goroutine running the test.
func (test *clientTest) run(t *testing.T, write bool) {
	var clientConn, serverConn net.Conn
	var recordingConn *recordingConn
	var childProcess *exec.Cmd
	var stdin opensslInput
	var stdout *opensslOutputSink

	if write {
		var err error
		recordingConn, childProcess, stdin, stdout, err = test.connFromCommand()
		if err != nil {
			t.Fatalf("Failed to start subcommand: %s", err)
		}
		clientConn = recordingConn
		// On failure, dump everything OpenSSL printed to aid debugging.
		defer func() {
			if t.Failed() {
				t.Logf("OpenSSL output:\n\n%s", stdout.all)
			}
		}()
	} else {
		clientConn, serverConn = localPipe(t)
	}

	// doneChan is closed when the client goroutine exits. Closing clientConn
	// first should unblock any pending client I/O, so an early return (for
	// example via t.Fatalf below) still waits for the goroutine to finish.
	doneChan := make(chan bool)
	defer func() {
		clientConn.Close()
		<-doneChan
	}()
	go func() {
		defer close(doneChan)

		config := test.config
		if config == nil {
			config = testConfig
		}
		client := Client(clientConn, config)
		defer client.Close()

		// This first write also drives the initial handshake.
		if _, err := client.Write([]byte("hello\n")); err != nil {
			t.Errorf("Client.Write failed: %s", err)
			return
		}

		for i := 1; i <= test.numRenegotiations; i++ {
			// The initial handshake will generate a
			// handshakeComplete signal which needs to be quashed.
			if i == 1 && write {
				<-stdout.handshakeComplete
			}

			// OpenSSL will try to interleave application data and
			// a renegotiation if we send both concurrently.
			// Therefore: ask OpenSSL to start a renegotiation, run
			// a goroutine to call client.Read and thus process the
			// renegotiation request, watch for OpenSSL's stdout to
			// indicate that the handshake is complete and,
			// finally, have OpenSSL write something to cause
			// client.Read to complete.
			if write {
				stdin <- opensslRenegotiate
			}

			signalChan := make(chan struct{})

			go func() {
				defer close(signalChan)

				buf := make([]byte, 256)
				n, err := client.Read(buf)

				// checkRenegotiationError may turn an expected
				// error into nil, in which case this attempt is
				// considered done and nothing else is checked.
				if test.checkRenegotiationError != nil {
					newErr := test.checkRenegotiationError(i, err)
					if err != nil && newErr == nil {
						return
					}
					err = newErr
				}

				if err != nil {
					t.Errorf("Client.Read failed after renegotiation #%d: %s", i, err)
					return
				}

				buf = buf[:n]
				if !bytes.Equal([]byte(opensslSentinel), buf) {
					t.Errorf("Client.Read returned %q, but wanted %q", string(buf), opensslSentinel)
				}

				// One initial handshake plus i renegotiations.
				if expected := i + 1; client.handshakes != expected {
					t.Errorf("client should have recorded %d handshakes, but believes that %d have occurred", expected, client.handshakes)
				}
			}()

			// For the attempt that is expected to fail, OpenSSL is not
			// waited on or asked for the sentinel; the reader
			// goroutine is expected to end with an error instead.
			if write && test.renegotiationExpectedToFail != i {
				<-stdout.handshakeComplete
				stdin <- opensslSendSentinel
			}
			<-signalChan
		}

		if test.sendKeyUpdate {
			if write {
				<-stdout.handshakeComplete
				stdin <- opensslKeyUpdate
			}

			doneRead := make(chan struct{})

			// The Read processes the server's KeyUpdate and then
			// returns the sentinel sent under the new keys.
			go func() {
				defer close(doneRead)

				buf := make([]byte, 256)
				n, err := client.Read(buf)

				if err != nil {
					t.Errorf("Client.Read failed after KeyUpdate: %s", err)
					return
				}

				buf = buf[:n]
				if !bytes.Equal([]byte(opensslSentinel), buf) {
					t.Errorf("Client.Read returned %q, but wanted %q", string(buf), opensslSentinel)
				}
			}()

			if write {
				// There's no real reason to wait for the client KeyUpdate to
				// send data with the new server keys, except that s_server
				// drops writes if they are sent at the wrong time.
				<-stdout.readKeyUpdate
				stdin <- opensslSendSentinel
			}
			<-doneRead

			if _, err := client.Write([]byte("hello again\n")); err != nil {
				t.Errorf("Client.Write failed: %s", err)
				return
			}
		}

		if test.validate != nil {
			if err := test.validate(client.ConnectionState()); err != nil {
				t.Errorf("validate callback returned error: %s", err)
			}
		}

		// If the server sent us an alert after our last flight, give it a
		// chance to arrive.
		if write && test.renegotiationExpectedToFail == 0 {
			if err := peekError(client); err != nil {
				t.Errorf("final Read returned an error: %s", err)
			}
		}
	}()

	// Replay mode: even-indexed flows were sent by the client and are
	// compared with what the client writes now. Odd-indexed flows were sent
	// by the server and are written back to the client. Deadlines depend on
	// the package-level fast flag (declared outside this file section).
	if !write {
		flows, err := test.loadData()
		if err != nil {
			t.Fatalf("%s: failed to load data from %s: %v", test.name, test.dataPath(), err)
		}
		for i, b := range flows {
			if i%2 == 1 {
				if *fast {
					serverConn.SetWriteDeadline(time.Now().Add(1 * time.Second))
				} else {
					serverConn.SetWriteDeadline(time.Now().Add(1 * time.Minute))
				}
				serverConn.Write(b)
				continue
			}
			bb := make([]byte, len(b))
			if *fast {
				serverConn.SetReadDeadline(time.Now().Add(1 * time.Second))
			} else {
				serverConn.SetReadDeadline(time.Now().Add(1 * time.Minute))
			}
			_, err := io.ReadFull(serverConn, bb)
			if err != nil {
				t.Fatalf("%s, flow %d: %s", test.name, i+1, err)
			}
			if !bytes.Equal(b, bb) {
				t.Fatalf("%s, flow %d: mismatch on read: got:%x want:%x", test.name, i+1, bb, b)
			}
		}
	}

	// Wait for the client goroutine before tearing anything down.
	<-doneChan
	if !write {
		serverConn.Close()
	}

	// Record mode: stop the connection and the OpenSSL process, then save
	// the captured flows. Fewer than three flows means the exchange never
	// got going.
	// NOTE(review): the output file is created/truncated before that check,
	// so a failed run leaves an empty file at path.
	if write {
		path := test.dataPath()
		out, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
		if err != nil {
			t.Fatalf("Failed to create output file: %s", err)
		}
		defer out.Close()
		recordingConn.Close()
		close(stdin)
		childProcess.Process.Kill()
		childProcess.Wait()
		if len(recordingConn.flows) < 3 {
			t.Fatalf("Client connection didn't work")
		}
		recordingConn.WriteTo(out)
		t.Logf("Wrote %s\n", path)
	}
}
   487  
   488  // peekError does a read with a short timeout to check if the next read would
   489  // cause an error, for example if there is an alert waiting on the wire.
   490  func peekError(conn net.Conn) error {
   491  	conn.SetReadDeadline(time.Now().Add(100 * time.Millisecond))
   492  	if n, err := conn.Read(make([]byte, 1)); n != 0 {
   493  		return errors.New("unexpectedly read data")
   494  	} else if err != nil {
   495  		if netErr, ok := err.(net.Error); !ok || !netErr.Timeout() {
   496  			return err
   497  		}
   498  	}
   499  	return nil
   500  }
   501  
   502  func runClientTestForVersion(t *testing.T, template *clientTest, version, option string) {
   503  	// Make a deep copy of the template before going parallel.
   504  	test := *template
   505  	if template.config != nil {
   506  		test.config = template.config.Clone()
   507  	}
   508  	test.name = version + "-" + test.name
   509  	test.args = append([]string{option}, test.args...)
   510  
   511  	runTestAndUpdateIfNeeded(t, version, test.run, false)
   512  }
   513  
   514  func runClientTestTLS10(t *testing.T, template *clientTest) {
   515  	runClientTestForVersion(t, template, "TLSv10", "-tls1")
   516  }
   517  
   518  func runClientTestTLS11(t *testing.T, template *clientTest) {
   519  	runClientTestForVersion(t, template, "TLSv11", "-tls1_1")
   520  }
   521  
   522  func runClientTestTLS12(t *testing.T, template *clientTest) {
   523  	runClientTestForVersion(t, template, "TLSv12", "-tls1_2")
   524  }
   525  
   526  func runClientTestTLS13(t *testing.T, template *clientTest) {
   527  	runClientTestForVersion(t, template, "TLSv13", "-tls1_3")
   528  }
   529  
   530  func TestHandshakeClientRSARC4(t *testing.T) {
   531  	test := &clientTest{
   532  		name: "RSA-RC4",
   533  		args: []string{"-cipher", "RC4-SHA"},
   534  	}
   535  	runClientTestTLS10(t, test)
   536  	runClientTestTLS11(t, test)
   537  	runClientTestTLS12(t, test)
   538  }
   539  
   540  func TestHandshakeClientRSAAES128GCM(t *testing.T) {
   541  	test := &clientTest{
   542  		name: "AES128-GCM-SHA256",
   543  		args: []string{"-cipher", "AES128-GCM-SHA256"},
   544  	}
   545  	runClientTestTLS12(t, test)
   546  }
   547  
   548  func TestHandshakeClientRSAAES256GCM(t *testing.T) {
   549  	test := &clientTest{
   550  		name: "AES256-GCM-SHA384",
   551  		args: []string{"-cipher", "AES256-GCM-SHA384"},
   552  	}
   553  	runClientTestTLS12(t, test)
   554  }
   555  
   556  func TestHandshakeClientECDHERSAAES(t *testing.T) {
   557  	test := &clientTest{
   558  		name: "ECDHE-RSA-AES",
   559  		args: []string{"-cipher", "ECDHE-RSA-AES128-SHA"},
   560  	}
   561  	runClientTestTLS10(t, test)
   562  	runClientTestTLS11(t, test)
   563  	runClientTestTLS12(t, test)
   564  }
   565  
   566  func TestHandshakeClientECDHEECDSAAES(t *testing.T) {
   567  	test := &clientTest{
   568  		name: "ECDHE-ECDSA-AES",
   569  		args: []string{"-cipher", "ECDHE-ECDSA-AES128-SHA"},
   570  		cert: testECDSACertificate,
   571  		key:  testECDSAPrivateKey,
   572  	}
   573  	runClientTestTLS10(t, test)
   574  	runClientTestTLS11(t, test)
   575  	runClientTestTLS12(t, test)
   576  }
   577  
   578  func TestHandshakeClientECDHEECDSAAESGCM(t *testing.T) {
   579  	test := &clientTest{
   580  		name: "ECDHE-ECDSA-AES-GCM",
   581  		args: []string{"-cipher", "ECDHE-ECDSA-AES128-GCM-SHA256"},
   582  		cert: testECDSACertificate,
   583  		key:  testECDSAPrivateKey,
   584  	}
   585  	runClientTestTLS12(t, test)
   586  }
   587  
   588  func TestHandshakeClientAES256GCMSHA384(t *testing.T) {
   589  	test := &clientTest{
   590  		name: "ECDHE-ECDSA-AES256-GCM-SHA384",
   591  		args: []string{"-cipher", "ECDHE-ECDSA-AES256-GCM-SHA384"},
   592  		cert: testECDSACertificate,
   593  		key:  testECDSAPrivateKey,
   594  	}
   595  	runClientTestTLS12(t, test)
   596  }
   597  
   598  func TestHandshakeClientAES128CBCSHA256(t *testing.T) {
   599  	test := &clientTest{
   600  		name: "AES128-SHA256",
   601  		args: []string{"-cipher", "AES128-SHA256"},
   602  	}
   603  	runClientTestTLS12(t, test)
   604  }
   605  
   606  func TestHandshakeClientECDHERSAAES128CBCSHA256(t *testing.T) {
   607  	test := &clientTest{
   608  		name: "ECDHE-RSA-AES128-SHA256",
   609  		args: []string{"-cipher", "ECDHE-RSA-AES128-SHA256"},
   610  	}
   611  	runClientTestTLS12(t, test)
   612  }
   613  
   614  func TestHandshakeClientECDHEECDSAAES128CBCSHA256(t *testing.T) {
   615  	test := &clientTest{
   616  		name: "ECDHE-ECDSA-AES128-SHA256",
   617  		args: []string{"-cipher", "ECDHE-ECDSA-AES128-SHA256"},
   618  		cert: testECDSACertificate,
   619  		key:  testECDSAPrivateKey,
   620  	}
   621  	runClientTestTLS12(t, test)
   622  }
   623  
   624  func TestHandshakeClientX25519(t *testing.T) {
   625  	config := testConfig.Clone()
   626  	config.CurvePreferences = []CurveID{X25519}
   627  
   628  	test := &clientTest{
   629  		name:   "X25519-ECDHE",
   630  		args:   []string{"-cipher", "ECDHE-RSA-AES128-GCM-SHA256", "-curves", "X25519"},
   631  		config: config,
   632  	}
   633  
   634  	runClientTestTLS12(t, test)
   635  	runClientTestTLS13(t, test)
   636  }
   637  
   638  func TestHandshakeClientP256(t *testing.T) {
   639  	config := testConfig.Clone()
   640  	config.CurvePreferences = []CurveID{CurveP256}
   641  
   642  	test := &clientTest{
   643  		name:   "P256-ECDHE",
   644  		args:   []string{"-cipher", "ECDHE-RSA-AES128-GCM-SHA256", "-curves", "P-256"},
   645  		config: config,
   646  	}
   647  
   648  	runClientTestTLS12(t, test)
   649  	runClientTestTLS13(t, test)
   650  }
   651  
   652  func TestHandshakeClientHelloRetryRequest(t *testing.T) {
   653  	config := testConfig.Clone()
   654  	config.CurvePreferences = []CurveID{X25519, CurveP256}
   655  
   656  	test := &clientTest{
   657  		name:   "HelloRetryRequest",
   658  		args:   []string{"-cipher", "ECDHE-RSA-AES128-GCM-SHA256", "-curves", "P-256"},
   659  		config: config,
   660  	}
   661  
   662  	runClientTestTLS13(t, test)
   663  }
   664  
   665  func TestHandshakeClientECDHERSAChaCha20(t *testing.T) {
   666  	config := testConfig.Clone()
   667  	config.CipherSuites = []uint16{TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305}
   668  
   669  	test := &clientTest{
   670  		name:   "ECDHE-RSA-CHACHA20-POLY1305",
   671  		args:   []string{"-cipher", "ECDHE-RSA-CHACHA20-POLY1305"},
   672  		config: config,
   673  	}
   674  
   675  	runClientTestTLS12(t, test)
   676  }
   677  
   678  func TestHandshakeClientECDHEECDSAChaCha20(t *testing.T) {
   679  	config := testConfig.Clone()
   680  	config.CipherSuites = []uint16{TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305}
   681  
   682  	test := &clientTest{
   683  		name:   "ECDHE-ECDSA-CHACHA20-POLY1305",
   684  		args:   []string{"-cipher", "ECDHE-ECDSA-CHACHA20-POLY1305"},
   685  		config: config,
   686  		cert:   testECDSACertificate,
   687  		key:    testECDSAPrivateKey,
   688  	}
   689  
   690  	runClientTestTLS12(t, test)
   691  }
   692  
   693  func TestHandshakeClientAES128SHA256(t *testing.T) {
   694  	test := &clientTest{
   695  		name: "AES128-SHA256",
   696  		args: []string{"-ciphersuites", "TLS_AES_128_GCM_SHA256"},
   697  	}
   698  	runClientTestTLS13(t, test)
   699  }
   700  func TestHandshakeClientAES256SHA384(t *testing.T) {
   701  	test := &clientTest{
   702  		name: "AES256-SHA384",
   703  		args: []string{"-ciphersuites", "TLS_AES_256_GCM_SHA384"},
   704  	}
   705  	runClientTestTLS13(t, test)
   706  }
   707  func TestHandshakeClientCHACHA20SHA256(t *testing.T) {
   708  	test := &clientTest{
   709  		name: "CHACHA20-SHA256",
   710  		args: []string{"-ciphersuites", "TLS_CHACHA20_POLY1305_SHA256"},
   711  	}
   712  	runClientTestTLS13(t, test)
   713  }
   714  
   715  func TestHandshakeClientECDSATLS13(t *testing.T) {
   716  	test := &clientTest{
   717  		name: "ECDSA",
   718  		cert: testECDSACertificate,
   719  		key:  testECDSAPrivateKey,
   720  	}
   721  	runClientTestTLS13(t, test)
   722  }
   723  
   724  func TestHandshakeClientEd25519(t *testing.T) {
   725  	test := &clientTest{
   726  		name: "Ed25519",
   727  		cert: testEd25519Certificate,
   728  		key:  testEd25519PrivateKey,
   729  	}
   730  	runClientTestTLS12(t, test)
   731  	runClientTestTLS13(t, test)
   732  
   733  	config := testConfig.Clone()
   734  	cert, _ := X509KeyPair([]byte(clientEd25519CertificatePEM), []byte(clientEd25519KeyPEM))
   735  	config.Certificates = []Certificate{cert}
   736  
   737  	test = &clientTest{
   738  		name:   "ClientCert-Ed25519",
   739  		args:   []string{"-Verify", "1"},
   740  		config: config,
   741  	}
   742  
   743  	runClientTestTLS12(t, test)
   744  	runClientTestTLS13(t, test)
   745  }
   746  
   747  func TestHandshakeClientCertRSA(t *testing.T) {
   748  	config := testConfig.Clone()
   749  	cert, _ := X509KeyPair([]byte(clientCertificatePEM), []byte(clientKeyPEM))
   750  	config.Certificates = []Certificate{cert}
   751  
   752  	test := &clientTest{
   753  		name:   "ClientCert-RSA-RSA",
   754  		args:   []string{"-cipher", "AES128", "-Verify", "1"},
   755  		config: config,
   756  	}
   757  
   758  	runClientTestTLS10(t, test)
   759  	runClientTestTLS12(t, test)
   760  
   761  	test = &clientTest{
   762  		name:   "ClientCert-RSA-ECDSA",
   763  		args:   []string{"-cipher", "ECDHE-ECDSA-AES128-SHA", "-Verify", "1"},
   764  		config: config,
   765  		cert:   testECDSACertificate,
   766  		key:    testECDSAPrivateKey,
   767  	}
   768  
   769  	runClientTestTLS10(t, test)
   770  	runClientTestTLS12(t, test)
   771  	runClientTestTLS13(t, test)
   772  
   773  	test = &clientTest{
   774  		name:   "ClientCert-RSA-AES256-GCM-SHA384",
   775  		args:   []string{"-cipher", "ECDHE-RSA-AES256-GCM-SHA384", "-Verify", "1"},
   776  		config: config,
   777  		cert:   testRSACertificate,
   778  		key:    testRSAPrivateKey,
   779  	}
   780  
   781  	runClientTestTLS12(t, test)
   782  }
   783  
   784  func TestHandshakeClientCertECDSA(t *testing.T) {
   785  	config := testConfig.Clone()
   786  	cert, _ := X509KeyPair([]byte(clientECDSACertificatePEM), []byte(clientECDSAKeyPEM))
   787  	config.Certificates = []Certificate{cert}
   788  
   789  	test := &clientTest{
   790  		name:   "ClientCert-ECDSA-RSA",
   791  		args:   []string{"-cipher", "AES128", "-Verify", "1"},
   792  		config: config,
   793  	}
   794  
   795  	runClientTestTLS10(t, test)
   796  	runClientTestTLS12(t, test)
   797  	runClientTestTLS13(t, test)
   798  
   799  	test = &clientTest{
   800  		name:   "ClientCert-ECDSA-ECDSA",
   801  		args:   []string{"-cipher", "ECDHE-ECDSA-AES128-SHA", "-Verify", "1"},
   802  		config: config,
   803  		cert:   testECDSACertificate,
   804  		key:    testECDSAPrivateKey,
   805  	}
   806  
   807  	runClientTestTLS10(t, test)
   808  	runClientTestTLS12(t, test)
   809  }
   810  
   811  // TestHandshakeClientCertRSAPSS tests rsa_pss_rsae_sha256 signatures from both
   812  // client and server certificates. It also serves from both sides a certificate
   813  // signed itself with RSA-PSS, mostly to check that crypto/x509 chain validation
   814  // works.
   815  func TestHandshakeClientCertRSAPSS(t *testing.T) {
   816  	cert, err := x509.ParseCertificate(testRSAPSSCertificate)
   817  	if err != nil {
   818  		panic(err)
   819  	}
   820  	rootCAs := x509.NewCertPool()
   821  	rootCAs.AddCert(cert)
   822  
   823  	config := testConfig.Clone()
   824  	// Use GetClientCertificate to bypass the client certificate selection logic.
   825  	config.GetClientCertificate = func(*CertificateRequestInfo) (*Certificate, error) {
   826  		return &Certificate{
   827  			Certificate: [][]byte{testRSAPSSCertificate},
   828  			PrivateKey:  testRSAPrivateKey,
   829  		}, nil
   830  	}
   831  	config.RootCAs = rootCAs
   832  
   833  	test := &clientTest{
   834  		name: "ClientCert-RSA-RSAPSS",
   835  		args: []string{"-cipher", "AES128", "-Verify", "1", "-client_sigalgs",
   836  			"rsa_pss_rsae_sha256", "-sigalgs", "rsa_pss_rsae_sha256"},
   837  		config: config,
   838  		cert:   testRSAPSSCertificate,
   839  		key:    testRSAPrivateKey,
   840  	}
   841  	runClientTestTLS12(t, test)
   842  	runClientTestTLS13(t, test)
   843  }
   844  
   845  func TestHandshakeClientCertRSAPKCS1v15(t *testing.T) {
   846  	config := testConfig.Clone()
   847  	cert, _ := X509KeyPair([]byte(clientCertificatePEM), []byte(clientKeyPEM))
   848  	config.Certificates = []Certificate{cert}
   849  
   850  	test := &clientTest{
   851  		name: "ClientCert-RSA-RSAPKCS1v15",
   852  		args: []string{"-cipher", "AES128", "-Verify", "1", "-client_sigalgs",
   853  			"rsa_pkcs1_sha256", "-sigalgs", "rsa_pkcs1_sha256"},
   854  		config: config,
   855  	}
   856  
   857  	runClientTestTLS12(t, test)
   858  }
   859  
   860  func TestClientKeyUpdate(t *testing.T) {
   861  	test := &clientTest{
   862  		name:          "KeyUpdate",
   863  		args:          []string{"-state"},
   864  		sendKeyUpdate: true,
   865  	}
   866  	runClientTestTLS13(t, test)
   867  }
   868  
   869  func TestResumption(t *testing.T) {
   870  	t.Run("TLSv12", func(t *testing.T) { testResumption(t, VersionTLS12) })
   871  	t.Run("TLSv13", func(t *testing.T) { testResumption(t, VersionTLS13) })
   872  }
   873  
// testResumption drives a series of handshakes between a client and a Go
// server, both capped at the given protocol version. After each handshake it
// checks whether the session was resumed. It covers:
//   - ticket reuse and renewal as the server clock is advanced,
//   - session ticket key rotation,
//   - switching cipher suites (TLS 1.2 only),
//   - resuming with client certificates,
//   - recovery after a corrupted cached session.
func testResumption(t *testing.T, version uint16) {
	if testing.Short() {
		t.Skip("skipping in -short mode")
	}
	serverConfig := &Config{
		MaxVersion:   version,
		CipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA},
		Certificates: testConfig.Certificates,
	}

	issuer, err := x509.ParseCertificate(testRSACertificateIssuer)
	if err != nil {
		panic(err)
	}

	// The same pool is later used as the server's ClientCAs when client
	// certificates are required.
	rootCAs := x509.NewCertPool()
	rootCAs.AddCert(issuer)

	clientConfig := &Config{
		MaxVersion:         version,
		CipherSuites:       []uint16{TLS_RSA_WITH_RC4_128_SHA},
		ClientSessionCache: NewLRUClientSessionCache(32),
		RootCAs:            rootCAs,
		ServerName:         "example.golang",
	}

	// testResumeState performs one handshake. It fails the test unless the
	// resumption outcome matches didResume. A resumed session must still expose
	// the peer certificates and verified chains, and the server name must match
	// the client's configuration.
	testResumeState := func(test string, didResume bool) {
		_, hs, err := testHandshake(t, clientConfig, serverConfig)
		if err != nil {
			t.Fatalf("%s: handshake failed: %s", test, err)
		}
		if hs.DidResume != didResume {
			t.Fatalf("%s resumed: %v, expected: %v", test, hs.DidResume, didResume)
		}
		if didResume && (hs.PeerCertificates == nil || hs.VerifiedChains == nil) {
			t.Fatalf("expected non-nil certificates after resumption. Got peerCertificates: %#v, verifiedCertificates: %#v", hs.PeerCertificates, hs.VerifiedChains)
		}
		if got, want := hs.ServerName, clientConfig.ServerName; got != want {
			t.Errorf("%s: server name %s, want %s", test, got, want)
		}
	}

	// The helpers below reach into the client's LRU session cache. They act on
	// the entry at the front of its list (presumably the most recently used
	// session; confirm against lruSessionCache).
	getTicket := func() []byte {
		return clientConfig.ClientSessionCache.(*lruSessionCache).q.Front().Value.(*lruSessionCacheEntry).state.sessionTicket
	}
	// deleteTicket evicts that entry by storing nil under its cache key.
	deleteTicket := func() {
		ticketKey := clientConfig.ClientSessionCache.(*lruSessionCache).q.Front().Value.(*lruSessionCacheEntry).sessionKey
		clientConfig.ClientSessionCache.Put(ticketKey, nil)
	}
	// corruptTicket flips the first byte of the cached master secret. The next
	// handshake that tries to resume with it is expected to fail.
	corruptTicket := func() {
		clientConfig.ClientSessionCache.(*lruSessionCache).q.Front().Value.(*lruSessionCacheEntry).state.masterSecret[0] ^= 0xff
	}
	// randomKey reads a fresh 32-byte session ticket key from the server's
	// randomness source.
	randomKey := func() [32]byte {
		var k [32]byte
		if _, err := io.ReadFull(serverConfig.rand(), k[:]); err != nil {
			t.Fatalf("Failed to read new SessionTicketKey: %s", err)
		}
		return k
	}

	testResumeState("Handshake", false)
	ticket := getTicket()
	testResumeState("Resume", true)
	// A TLS 1.2 resumption keeps the original ticket, while TLS 1.3 must hand
	// out a new one.
	if !bytes.Equal(ticket, getTicket()) && version != VersionTLS13 {
		t.Fatal("first ticket doesn't match ticket after resumption")
	}
	if bytes.Equal(ticket, getTicket()) && version == VersionTLS13 {
		t.Fatal("ticket didn't change after resumption")
	}

	// An old session ticket can resume, but the server will provide a ticket encrypted with a fresh key.
	serverConfig.Time = func() time.Time { return time.Now().Add(24*time.Hour + time.Minute) }
	testResumeState("ResumeWithOldTicket", true)
	if bytes.Equal(ticket[:ticketKeyNameLen], getTicket()[:ticketKeyNameLen]) {
		t.Fatal("old first ticket matches the fresh one")
	}

	// Now the session ticket key is expired, so a full handshake should occur.
	serverConfig.Time = func() time.Time { return time.Now().Add(24*8*time.Hour + time.Minute) }
	testResumeState("ResumeWithExpiredTicket", false)
	if bytes.Equal(ticket, getTicket()) {
		t.Fatal("expired first ticket matches the fresh one")
	}

	serverConfig.Time = func() time.Time { return time.Now() } // reset the time back
	key1 := randomKey()
	serverConfig.SetSessionTicketKeys([][32]byte{key1})

	// The cached ticket was issued under a key the server no longer holds, so
	// the first handshake is a full one. The ticket it issues then resumes.
	testResumeState("InvalidSessionTicketKey", false)
	testResumeState("ResumeAfterInvalidSessionTicketKey", true)

	// key1 remains accepted but key2 becomes the primary key. Resuming with a
	// key1 ticket must yield a new ticket.
	key2 := randomKey()
	serverConfig.SetSessionTicketKeys([][32]byte{key2, key1})
	ticket = getTicket()
	testResumeState("KeyChange", true)
	if bytes.Equal(ticket, getTicket()) {
		t.Fatal("new ticket wasn't included while resuming")
	}
	testResumeState("KeyChangeFinish", true)

	// Age the session ticket a bit, but not yet expired.
	serverConfig.Time = func() time.Time { return time.Now().Add(24*time.Hour + time.Minute) }
	testResumeState("OldSessionTicket", true)
	ticket = getTicket()
	// Expire the session ticket, which would force a full handshake.
	serverConfig.Time = func() time.Time { return time.Now().Add(24*8*time.Hour + time.Minute) }
	testResumeState("ExpiredSessionTicket", false)
	if bytes.Equal(ticket, getTicket()) {
		t.Fatal("new ticket wasn't provided after old ticket expired")
	}

	// Age the session ticket a bit at a time, but don't expire it.
	// The closures capture d itself, so each assignment to d below moves the
	// server clock further ahead (up to 13*12h = 156h inside the loop).
	d := 0 * time.Hour
	for i := 0; i < 13; i++ {
		d += 12 * time.Hour
		serverConfig.Time = func() time.Time { return time.Now().Add(d) }
		testResumeState("OldSessionTicket", true)
	}
	// Expire it (now a little more than 7 days) and make sure a full
	// handshake occurs for TLS 1.2. Resumption should still occur for
	// TLS 1.3 since the client should be using a fresh ticket sent over
	// by the server.
	d += 12 * time.Hour
	serverConfig.Time = func() time.Time { return time.Now().Add(d) }
	if version == VersionTLS13 {
		testResumeState("ExpiredSessionTicket", true)
	} else {
		testResumeState("ExpiredSessionTicket", false)
	}
	if bytes.Equal(ticket, getTicket()) {
		t.Fatal("new ticket wasn't provided after old ticket expired")
	}

	// Reset serverConfig to ensure that calling SetSessionTicketKeys
	// before the serverConfig is used works.
	serverConfig = &Config{
		MaxVersion:   version,
		CipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA},
		Certificates: testConfig.Certificates,
	}
	serverConfig.SetSessionTicketKeys([][32]byte{key2})

	testResumeState("FreshConfig", true)

	// In TLS 1.3, cross-cipher suite resumption is allowed as long as the KDF
	// hash matches. Also, Config.CipherSuites does not apply to TLS 1.3.
	if version != VersionTLS13 {
		clientConfig.CipherSuites = []uint16{TLS_ECDHE_RSA_WITH_RC4_128_SHA}
		testResumeState("DifferentCipherSuite", false)
		testResumeState("DifferentCipherSuiteRecovers", true)
	}

	deleteTicket()
	testResumeState("WithoutSessionTicket", false)

	// Session resumption should work when using client certificates
	deleteTicket()
	serverConfig.ClientCAs = rootCAs
	serverConfig.ClientAuth = RequireAndVerifyClientCert
	clientConfig.Certificates = serverConfig.Certificates
	testResumeState("InitialHandshake", false)
	testResumeState("WithClientCertificates", true)
	serverConfig.ClientAuth = NoClientCert

	// Tickets should be removed from the session cache on TLS handshake
	// failure, and the client should recover from a corrupted PSK
	testResumeState("FetchTicketToCorrupt", false)
	corruptTicket()
	_, _, err = testHandshake(t, clientConfig, serverConfig)
	if err == nil {
		t.Fatalf("handshake did not fail with a corrupted client secret")
	}
	testResumeState("AfterHandshakeFailure", false)

	clientConfig.ClientSessionCache = nil
	testResumeState("WithoutSessionCache", false)
}
  1051  
  1052  func TestLRUClientSessionCache(t *testing.T) {
  1053  	// Initialize cache of capacity 4.
  1054  	cache := NewLRUClientSessionCache(4)
  1055  	cs := make([]ClientSessionState, 6)
  1056  	keys := []string{"0", "1", "2", "3", "4", "5", "6"}
  1057  
  1058  	// Add 4 entries to the cache and look them up.
  1059  	for i := 0; i < 4; i++ {
  1060  		cache.Put(keys[i], &cs[i])
  1061  	}
  1062  	for i := 0; i < 4; i++ {
  1063  		if s, ok := cache.Get(keys[i]); !ok || s != &cs[i] {
  1064  			t.Fatalf("session cache failed lookup for added key: %s", keys[i])
  1065  		}
  1066  	}
  1067  
  1068  	// Add 2 more entries to the cache. First 2 should be evicted.
  1069  	for i := 4; i < 6; i++ {
  1070  		cache.Put(keys[i], &cs[i])
  1071  	}
  1072  	for i := 0; i < 2; i++ {
  1073  		if s, ok := cache.Get(keys[i]); ok || s != nil {
  1074  			t.Fatalf("session cache should have evicted key: %s", keys[i])
  1075  		}
  1076  	}
  1077  
  1078  	// Touch entry 2. LRU should evict 3 next.
  1079  	cache.Get(keys[2])
  1080  	cache.Put(keys[0], &cs[0])
  1081  	if s, ok := cache.Get(keys[3]); ok || s != nil {
  1082  		t.Fatalf("session cache should have evicted key 3")
  1083  	}
  1084  
  1085  	// Update entry 0 in place.
  1086  	cache.Put(keys[0], &cs[3])
  1087  	if s, ok := cache.Get(keys[0]); !ok || s != &cs[3] {
  1088  		t.Fatalf("session cache failed update for key 0")
  1089  	}
  1090  
  1091  	// Calling Put with a nil entry deletes the key.
  1092  	cache.Put(keys[0], nil)
  1093  	if _, ok := cache.Get(keys[0]); ok {
  1094  		t.Fatalf("session cache failed to delete key 0")
  1095  	}
  1096  
  1097  	// Delete entry 2. LRU should keep 4 and 5
  1098  	cache.Put(keys[2], nil)
  1099  	if _, ok := cache.Get(keys[2]); ok {
  1100  		t.Fatalf("session cache failed to delete key 4")
  1101  	}
  1102  	for i := 4; i < 6; i++ {
  1103  		if s, ok := cache.Get(keys[i]); !ok || s != &cs[i] {
  1104  			t.Fatalf("session cache should not have deleted key: %s", keys[i])
  1105  		}
  1106  	}
  1107  }
  1108  
  1109  func TestKeyLogTLS12(t *testing.T) {
  1110  	var serverBuf, clientBuf bytes.Buffer
  1111  
  1112  	clientConfig := testConfig.Clone()
  1113  	clientConfig.KeyLogWriter = &clientBuf
  1114  	clientConfig.MaxVersion = VersionTLS12
  1115  
  1116  	serverConfig := testConfig.Clone()
  1117  	serverConfig.KeyLogWriter = &serverBuf
  1118  	serverConfig.MaxVersion = VersionTLS12
  1119  
  1120  	c, s := localPipe(t)
  1121  	done := make(chan bool)
  1122  
  1123  	go func() {
  1124  		defer close(done)
  1125  
  1126  		if err := Server(s, serverConfig).Handshake(); err != nil {
  1127  			t.Errorf("server: %s", err)
  1128  			return
  1129  		}
  1130  		s.Close()
  1131  	}()
  1132  
  1133  	if err := Client(c, clientConfig).Handshake(); err != nil {
  1134  		t.Fatalf("client: %s", err)
  1135  	}
  1136  
  1137  	c.Close()
  1138  	<-done
  1139  
  1140  	checkKeylogLine := func(side, loggedLine string) {
  1141  		if len(loggedLine) == 0 {
  1142  			t.Fatalf("%s: no keylog line was produced", side)
  1143  		}
  1144  		const expectedLen = 13 /* "CLIENT_RANDOM" */ +
  1145  			1 /* space */ +
  1146  			32*2 /* hex client nonce */ +
  1147  			1 /* space */ +
  1148  			48*2 /* hex master secret */ +
  1149  			1 /* new line */
  1150  		if len(loggedLine) != expectedLen {
  1151  			t.Fatalf("%s: keylog line has incorrect length (want %d, got %d): %q", side, expectedLen, len(loggedLine), loggedLine)
  1152  		}
  1153  		if !strings.HasPrefix(loggedLine, "CLIENT_RANDOM "+strings.Repeat("0", 64)+" ") {
  1154  			t.Fatalf("%s: keylog line has incorrect structure or nonce: %q", side, loggedLine)
  1155  		}
  1156  	}
  1157  
  1158  	checkKeylogLine("client", clientBuf.String())
  1159  	checkKeylogLine("server", serverBuf.String())
  1160  }
  1161  
  1162  func TestKeyLogTLS13(t *testing.T) {
  1163  	var serverBuf, clientBuf bytes.Buffer
  1164  
  1165  	clientConfig := testConfig.Clone()
  1166  	clientConfig.KeyLogWriter = &clientBuf
  1167  
  1168  	serverConfig := testConfig.Clone()
  1169  	serverConfig.KeyLogWriter = &serverBuf
  1170  
  1171  	c, s := localPipe(t)
  1172  	done := make(chan bool)
  1173  
  1174  	go func() {
  1175  		defer close(done)
  1176  
  1177  		if err := Server(s, serverConfig).Handshake(); err != nil {
  1178  			t.Errorf("server: %s", err)
  1179  			return
  1180  		}
  1181  		s.Close()
  1182  	}()
  1183  
  1184  	if err := Client(c, clientConfig).Handshake(); err != nil {
  1185  		t.Fatalf("client: %s", err)
  1186  	}
  1187  
  1188  	c.Close()
  1189  	<-done
  1190  
  1191  	checkKeylogLines := func(side, loggedLines string) {
  1192  		loggedLines = strings.TrimSpace(loggedLines)
  1193  		lines := strings.Split(loggedLines, "\n")
  1194  		if len(lines) != 4 {
  1195  			t.Errorf("Expected the %s to log 4 lines, got %d", side, len(lines))
  1196  		}
  1197  	}
  1198  
  1199  	checkKeylogLines("client", clientBuf.String())
  1200  	checkKeylogLines("server", serverBuf.String())
  1201  }
  1202  
  1203  func TestHandshakeClientALPNMatch(t *testing.T) {
  1204  	config := testConfig.Clone()
  1205  	config.NextProtos = []string{"proto2", "proto1"}
  1206  
  1207  	test := &clientTest{
  1208  		name: "ALPN",
  1209  		// Note that this needs OpenSSL 1.0.2 because that is the first
  1210  		// version that supports the -alpn flag.
  1211  		args:   []string{"-alpn", "proto1,proto2"},
  1212  		config: config,
  1213  		validate: func(state ConnectionState) error {
  1214  			// The server's preferences should override the client.
  1215  			if state.NegotiatedProtocol != "proto1" {
  1216  				return fmt.Errorf("Got protocol %q, wanted proto1", state.NegotiatedProtocol)
  1217  			}
  1218  			return nil
  1219  		},
  1220  	}
  1221  	runClientTestTLS12(t, test)
  1222  	runClientTestTLS13(t, test)
  1223  }
  1224  
// sctsBase64 contains data from `openssl s_client -serverinfo 18 -connect ritter.vg:443`.
// Decoded, it is a serverinfo block for TLS extension 18 (0x0012) with this
// layout:
//   - a 2-byte extension type,
//   - a 2-byte extension length,
//   - a 2-byte SCT list length,
//   - three SCTs, each prefixed by a 2-byte length.
//
// The slice offsets used in TestHandshakClientSCTs follow from this layout.
const sctsBase64 = "ABIBaQFnAHUApLkJkLQYWBSHuxOizGdwCjw1mAT5G9+443fNDsgN3BAAAAFHl5nuFgAABAMARjBEAiAcS4JdlW5nW9sElUv2zvQyPoZ6ejKrGGB03gjaBZFMLwIgc1Qbbn+hsH0RvObzhS+XZhr3iuQQJY8S9G85D9KeGPAAdgBo9pj4H2SCvjqM7rkoHUz8cVFdZ5PURNEKZ6y7T0/7xAAAAUeX4bVwAAAEAwBHMEUCIDIhFDgG2HIuADBkGuLobU5a4dlCHoJLliWJ1SYT05z6AiEAjxIoZFFPRNWMGGIjskOTMwXzQ1Wh2e7NxXE1kd1J0QsAdgDuS723dc5guuFCaR+r4Z5mow9+X7By2IMAxHuJeqj9ywAAAUhcZIqHAAAEAwBHMEUCICmJ1rBT09LpkbzxtUC+Hi7nXLR0J+2PmwLp+sJMuqK+AiEAr0NkUnEVKVhAkccIFpYDqHOlZaBsuEhWWrYpg2RtKp0="
  1227  
  1228  func TestHandshakClientSCTs(t *testing.T) {
  1229  	config := testConfig.Clone()
  1230  
  1231  	scts, err := base64.StdEncoding.DecodeString(sctsBase64)
  1232  	if err != nil {
  1233  		t.Fatal(err)
  1234  	}
  1235  
  1236  	// Note that this needs OpenSSL 1.0.2 because that is the first
  1237  	// version that supports the -serverinfo flag.
  1238  	test := &clientTest{
  1239  		name:       "SCT",
  1240  		config:     config,
  1241  		extensions: [][]byte{scts},
  1242  		validate: func(state ConnectionState) error {
  1243  			expectedSCTs := [][]byte{
  1244  				scts[8:125],
  1245  				scts[127:245],
  1246  				scts[247:],
  1247  			}
  1248  			if n := len(state.SignedCertificateTimestamps); n != len(expectedSCTs) {
  1249  				return fmt.Errorf("Got %d scts, wanted %d", n, len(expectedSCTs))
  1250  			}
  1251  			for i, expected := range expectedSCTs {
  1252  				if sct := state.SignedCertificateTimestamps[i]; !bytes.Equal(sct, expected) {
  1253  					return fmt.Errorf("SCT #%d contained %x, expected %x", i, sct, expected)
  1254  				}
  1255  			}
  1256  			return nil
  1257  		},
  1258  	}
  1259  	runClientTestTLS12(t, test)
  1260  
  1261  	// TLS 1.3 moved SCTs to the Certificate extensions and -serverinfo only
  1262  	// supports ServerHello extensions.
  1263  }
  1264  
  1265  func TestRenegotiationRejected(t *testing.T) {
  1266  	config := testConfig.Clone()
  1267  	test := &clientTest{
  1268  		name:                        "RenegotiationRejected",
  1269  		args:                        []string{"-state"},
  1270  		config:                      config,
  1271  		numRenegotiations:           1,
  1272  		renegotiationExpectedToFail: 1,
  1273  		checkRenegotiationError: func(renegotiationNum int, err error) error {
  1274  			if err == nil {
  1275  				return errors.New("expected error from renegotiation but got nil")
  1276  			}
  1277  			if !strings.Contains(err.Error(), "no renegotiation") {
  1278  				return fmt.Errorf("expected renegotiation to be rejected but got %q", err)
  1279  			}
  1280  			return nil
  1281  		},
  1282  	}
  1283  	runClientTestTLS12(t, test)
  1284  }
  1285  
  1286  func TestRenegotiateOnce(t *testing.T) {
  1287  	config := testConfig.Clone()
  1288  	config.Renegotiation = RenegotiateOnceAsClient
  1289  
  1290  	test := &clientTest{
  1291  		name:              "RenegotiateOnce",
  1292  		args:              []string{"-state"},
  1293  		config:            config,
  1294  		numRenegotiations: 1,
  1295  	}
  1296  
  1297  	runClientTestTLS12(t, test)
  1298  }
  1299  
  1300  func TestRenegotiateTwice(t *testing.T) {
  1301  	config := testConfig.Clone()
  1302  	config.Renegotiation = RenegotiateFreelyAsClient
  1303  
  1304  	test := &clientTest{
  1305  		name:              "RenegotiateTwice",
  1306  		args:              []string{"-state"},
  1307  		config:            config,
  1308  		numRenegotiations: 2,
  1309  	}
  1310  
  1311  	runClientTestTLS12(t, test)
  1312  }
  1313  
  1314  func TestRenegotiateTwiceRejected(t *testing.T) {
  1315  	config := testConfig.Clone()
  1316  	config.Renegotiation = RenegotiateOnceAsClient
  1317  
  1318  	test := &clientTest{
  1319  		name:                        "RenegotiateTwiceRejected",
  1320  		args:                        []string{"-state"},
  1321  		config:                      config,
  1322  		numRenegotiations:           2,
  1323  		renegotiationExpectedToFail: 2,
  1324  		checkRenegotiationError: func(renegotiationNum int, err error) error {
  1325  			if renegotiationNum == 1 {
  1326  				return err
  1327  			}
  1328  
  1329  			if err == nil {
  1330  				return errors.New("expected error from renegotiation but got nil")
  1331  			}
  1332  			if !strings.Contains(err.Error(), "no renegotiation") {
  1333  				return fmt.Errorf("expected renegotiation to be rejected but got %q", err)
  1334  			}
  1335  			return nil
  1336  		},
  1337  	}
  1338  
  1339  	runClientTestTLS12(t, test)
  1340  }
  1341  
  1342  func TestHandshakeClientExportKeyingMaterial(t *testing.T) {
  1343  	test := &clientTest{
  1344  		name:   "ExportKeyingMaterial",
  1345  		config: testConfig.Clone(),
  1346  		validate: func(state ConnectionState) error {
  1347  			if km, err := state.ExportKeyingMaterial("test", nil, 42); err != nil {
  1348  				return fmt.Errorf("ExportKeyingMaterial failed: %v", err)
  1349  			} else if len(km) != 42 {
  1350  				return fmt.Errorf("Got %d bytes from ExportKeyingMaterial, wanted %d", len(km), 42)
  1351  			}
  1352  			return nil
  1353  		},
  1354  	}
  1355  	runClientTestTLS10(t, test)
  1356  	runClientTestTLS12(t, test)
  1357  	runClientTestTLS13(t, test)
  1358  }
  1359  
// hostnameInSNITests pairs a Config.ServerName value (in) with the server
// name the client is expected to place in its ClientHello (out). An empty
// out means the ClientHello must carry no server name. This applies to the
// IP address literals below.
var hostnameInSNITests = []struct {
	in, out string
}{
	// Opaque string
	{"", ""},
	{"localhost", "localhost"},
	{"foo, bar, baz and qux", "foo, bar, baz and qux"},

	// DNS hostname
	{"golang.org", "golang.org"},
	{"golang.org.", "golang.org"},

	// Literal IPv4 address
	{"1.2.3.4", ""},

	// Literal IPv6 address
	{"::1", ""},
	{"::1%lo0", ""}, // with zone identifier
	{"[::1]", ""},   // as per RFC 5952 we allow the [] style as IPv6 literal
	{"[::1%lo0]", ""},
}
  1381  
  1382  func TestHostnameInSNI(t *testing.T) {
  1383  	for _, tt := range hostnameInSNITests {
  1384  		c, s := localPipe(t)
  1385  
  1386  		go func(host string) {
  1387  			Client(c, &Config{ServerName: host, InsecureSkipVerify: true}).Handshake()
  1388  		}(tt.in)
  1389  
  1390  		var header [5]byte
  1391  		if _, err := io.ReadFull(s, header[:]); err != nil {
  1392  			t.Fatal(err)
  1393  		}
  1394  		recordLen := int(header[3])<<8 | int(header[4])
  1395  
  1396  		record := make([]byte, recordLen)
  1397  		if _, err := io.ReadFull(s, record[:]); err != nil {
  1398  			t.Fatal(err)
  1399  		}
  1400  
  1401  		c.Close()
  1402  		s.Close()
  1403  
  1404  		var m clientHelloMsg
  1405  		if !m.unmarshal(record) {
  1406  			t.Errorf("unmarshaling ClientHello for %q failed", tt.in)
  1407  			continue
  1408  		}
  1409  		if tt.in != tt.out && m.serverName == tt.in {
  1410  			t.Errorf("prohibited %q found in ClientHello: %x", tt.in, record)
  1411  		}
  1412  		if m.serverName != tt.out {
  1413  			t.Errorf("expected %q not found in ClientHello: %x", tt.out, record)
  1414  		}
  1415  	}
  1416  }
  1417  
  1418  func TestServerSelectingUnconfiguredCipherSuite(t *testing.T) {
  1419  	// This checks that the server can't select a cipher suite that the
  1420  	// client didn't offer. See #13174.
  1421  
  1422  	c, s := localPipe(t)
  1423  	errChan := make(chan error, 1)
  1424  
  1425  	go func() {
  1426  		client := Client(c, &Config{
  1427  			ServerName:   "foo",
  1428  			CipherSuites: []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256},
  1429  		})
  1430  		errChan <- client.Handshake()
  1431  	}()
  1432  
  1433  	var header [5]byte
  1434  	if _, err := io.ReadFull(s, header[:]); err != nil {
  1435  		t.Fatal(err)
  1436  	}
  1437  	recordLen := int(header[3])<<8 | int(header[4])
  1438  
  1439  	record := make([]byte, recordLen)
  1440  	if _, err := io.ReadFull(s, record); err != nil {
  1441  		t.Fatal(err)
  1442  	}
  1443  
  1444  	// Create a ServerHello that selects a different cipher suite than the
  1445  	// sole one that the client offered.
  1446  	serverHello := &serverHelloMsg{
  1447  		vers:        VersionTLS12,
  1448  		random:      make([]byte, 32),
  1449  		cipherSuite: TLS_RSA_WITH_AES_256_GCM_SHA384,
  1450  	}
  1451  	serverHelloBytes := serverHello.marshal()
  1452  
  1453  	s.Write([]byte{
  1454  		byte(recordTypeHandshake),
  1455  		byte(VersionTLS12 >> 8),
  1456  		byte(VersionTLS12 & 0xff),
  1457  		byte(len(serverHelloBytes) >> 8),
  1458  		byte(len(serverHelloBytes)),
  1459  	})
  1460  	s.Write(serverHelloBytes)
  1461  	s.Close()
  1462  
  1463  	if err := <-errChan; !strings.Contains(err.Error(), "unconfigured cipher") {
  1464  		t.Fatalf("Expected error about unconfigured cipher suite but got %q", err)
  1465  	}
  1466  }
  1467  
  1468  func TestVerifyConnection(t *testing.T) {
  1469  	t.Run("TLSv12", func(t *testing.T) { testVerifyConnection(t, VersionTLS12) })
  1470  	t.Run("TLSv13", func(t *testing.T) { testVerifyConnection(t, VersionTLS13) })
  1471  }
  1472  
  1473  func testVerifyConnection(t *testing.T, version uint16) {
  1474  	checkFields := func(c ConnectionState, called *int, errorType string) error {
  1475  		if c.Version != version {
  1476  			return fmt.Errorf("%s: got Version %v, want %v", errorType, c.Version, version)
  1477  		}
  1478  		if c.HandshakeComplete {
  1479  			return fmt.Errorf("%s: got HandshakeComplete, want false", errorType)
  1480  		}
  1481  		if c.ServerName != "example.golang" {
  1482  			return fmt.Errorf("%s: got ServerName %s, want %s", errorType, c.ServerName, "example.golang")
  1483  		}
  1484  		if c.NegotiatedProtocol != "protocol1" {
  1485  			return fmt.Errorf("%s: got NegotiatedProtocol %s, want %s", errorType, c.NegotiatedProtocol, "protocol1")
  1486  		}
  1487  		if c.CipherSuite == 0 {
  1488  			return fmt.Errorf("%s: got CipherSuite 0, want non-zero", errorType)
  1489  		}
  1490  		wantDidResume := false
  1491  		if *called == 2 { // if this is the second time, then it should be a resumption
  1492  			wantDidResume = true
  1493  		}
  1494  		if c.DidResume != wantDidResume {
  1495  			return fmt.Errorf("%s: got DidResume %t, want %t", errorType, c.DidResume, wantDidResume)
  1496  		}
  1497  		return nil
  1498  	}
  1499  
  1500  	tests := []struct {
  1501  		name            string
  1502  		configureServer func(*Config, *int)
  1503  		configureClient func(*Config, *int)
  1504  	}{
  1505  		{
  1506  			name: "RequireAndVerifyClientCert",
  1507  			configureServer: func(config *Config, called *int) {
  1508  				config.ClientAuth = RequireAndVerifyClientCert
  1509  				config.VerifyConnection = func(c ConnectionState) error {
  1510  					*called++
  1511  					if l := len(c.PeerCertificates); l != 1 {
  1512  						return fmt.Errorf("server: got len(PeerCertificates) = %d, wanted 1", l)
  1513  					}
  1514  					if len(c.VerifiedChains) == 0 {
  1515  						return fmt.Errorf("server: got len(VerifiedChains) = 0, wanted non-zero")
  1516  					}
  1517  					return checkFields(c, called, "server")
  1518  				}
  1519  			},
  1520  			configureClient: func(config *Config, called *int) {
  1521  				config.VerifyConnection = func(c ConnectionState) error {
  1522  					*called++
  1523  					if l := len(c.PeerCertificates); l != 1 {
  1524  						return fmt.Errorf("client: got len(PeerCertificates) = %d, wanted 1", l)
  1525  					}
  1526  					if len(c.VerifiedChains) == 0 {
  1527  						return fmt.Errorf("client: got len(VerifiedChains) = 0, wanted non-zero")
  1528  					}
  1529  					if c.DidResume {
  1530  						return nil
  1531  						// The SCTs and OCSP Response are dropped on resumption.
  1532  						// See http://golang.org/issue/39075.
  1533  					}
  1534  					if len(c.OCSPResponse) == 0 {
  1535  						return fmt.Errorf("client: got len(OCSPResponse) = 0, wanted non-zero")
  1536  					}
  1537  					if len(c.SignedCertificateTimestamps) == 0 {
  1538  						return fmt.Errorf("client: got len(SignedCertificateTimestamps) = 0, wanted non-zero")
  1539  					}
  1540  					return checkFields(c, called, "client")
  1541  				}
  1542  			},
  1543  		},
  1544  		{
  1545  			name: "InsecureSkipVerify",
  1546  			configureServer: func(config *Config, called *int) {
  1547  				config.ClientAuth = RequireAnyClientCert
  1548  				config.InsecureSkipVerify = true
  1549  				config.VerifyConnection = func(c ConnectionState) error {
  1550  					*called++
  1551  					if l := len(c.PeerCertificates); l != 1 {
  1552  						return fmt.Errorf("server: got len(PeerCertificates) = %d, wanted 1", l)
  1553  					}
  1554  					if c.VerifiedChains != nil {
  1555  						return fmt.Errorf("server: got Verified Chains %v, want nil", c.VerifiedChains)
  1556  					}
  1557  					return checkFields(c, called, "server")
  1558  				}
  1559  			},
  1560  			configureClient: func(config *Config, called *int) {
  1561  				config.InsecureSkipVerify = true
  1562  				config.VerifyConnection = func(c ConnectionState) error {
  1563  					*called++
  1564  					if l := len(c.PeerCertificates); l != 1 {
  1565  						return fmt.Errorf("client: got len(PeerCertificates) = %d, wanted 1", l)
  1566  					}
  1567  					if c.VerifiedChains != nil {
  1568  						return fmt.Errorf("server: got Verified Chains %v, want nil", c.VerifiedChains)
  1569  					}
  1570  					if c.DidResume {
  1571  						return nil
  1572  						// The SCTs and OCSP Response are dropped on resumption.
  1573  						// See http://golang.org/issue/39075.
  1574  					}
  1575  					if len(c.OCSPResponse) == 0 {
  1576  						return fmt.Errorf("client: got len(OCSPResponse) = 0, wanted non-zero")
  1577  					}
  1578  					if len(c.SignedCertificateTimestamps) == 0 {
  1579  						return fmt.Errorf("client: got len(SignedCertificateTimestamps) = 0, wanted non-zero")
  1580  					}
  1581  					return checkFields(c, called, "client")
  1582  				}
  1583  			},
  1584  		},
  1585  		{
  1586  			name: "NoClientCert",
  1587  			configureServer: func(config *Config, called *int) {
  1588  				config.ClientAuth = NoClientCert
  1589  				config.VerifyConnection = func(c ConnectionState) error {
  1590  					*called++
  1591  					return checkFields(c, called, "server")
  1592  				}
  1593  			},
  1594  			configureClient: func(config *Config, called *int) {
  1595  				config.VerifyConnection = func(c ConnectionState) error {
  1596  					*called++
  1597  					return checkFields(c, called, "client")
  1598  				}
  1599  			},
  1600  		},
  1601  		{
  1602  			name: "RequestClientCert",
  1603  			configureServer: func(config *Config, called *int) {
  1604  				config.ClientAuth = RequestClientCert
  1605  				config.VerifyConnection = func(c ConnectionState) error {
  1606  					*called++
  1607  					return checkFields(c, called, "server")
  1608  				}
  1609  			},
  1610  			configureClient: func(config *Config, called *int) {
  1611  				config.Certificates = nil // clear the client cert
  1612  				config.VerifyConnection = func(c ConnectionState) error {
  1613  					*called++
  1614  					if l := len(c.PeerCertificates); l != 1 {
  1615  						return fmt.Errorf("client: got len(PeerCertificates) = %d, wanted 1", l)
  1616  					}
  1617  					if len(c.VerifiedChains) == 0 {
  1618  						return fmt.Errorf("client: got len(VerifiedChains) = 0, wanted non-zero")
  1619  					}
  1620  					if c.DidResume {
  1621  						return nil
  1622  						// The SCTs and OCSP Response are dropped on resumption.
  1623  						// See http://golang.org/issue/39075.
  1624  					}
  1625  					if len(c.OCSPResponse) == 0 {
  1626  						return fmt.Errorf("client: got len(OCSPResponse) = 0, wanted non-zero")
  1627  					}
  1628  					if len(c.SignedCertificateTimestamps) == 0 {
  1629  						return fmt.Errorf("client: got len(SignedCertificateTimestamps) = 0, wanted non-zero")
  1630  					}
  1631  					return checkFields(c, called, "client")
  1632  				}
  1633  			},
  1634  		},
  1635  	}
  1636  	for _, test := range tests {
  1637  		issuer, err := x509.ParseCertificate(testRSACertificateIssuer)
  1638  		if err != nil {
  1639  			panic(err)
  1640  		}
  1641  		rootCAs := x509.NewCertPool()
  1642  		rootCAs.AddCert(issuer)
  1643  
  1644  		var serverCalled, clientCalled int
  1645  
  1646  		serverConfig := &Config{
  1647  			MaxVersion:   version,
  1648  			Certificates: []Certificate{testConfig.Certificates[0]},
  1649  			ClientCAs:    rootCAs,
  1650  			NextProtos:   []string{"protocol1"},
  1651  		}
  1652  		serverConfig.Certificates[0].SignedCertificateTimestamps = [][]byte{[]byte("dummy sct 1"), []byte("dummy sct 2")}
  1653  		serverConfig.Certificates[0].OCSPStaple = []byte("dummy ocsp")
  1654  		test.configureServer(serverConfig, &serverCalled)
  1655  
  1656  		clientConfig := &Config{
  1657  			MaxVersion:         version,
  1658  			ClientSessionCache: NewLRUClientSessionCache(32),
  1659  			RootCAs:            rootCAs,
  1660  			ServerName:         "example.golang",
  1661  			Certificates:       []Certificate{testConfig.Certificates[0]},
  1662  			NextProtos:         []string{"protocol1"},
  1663  		}
  1664  		test.configureClient(clientConfig, &clientCalled)
  1665  
  1666  		testHandshakeState := func(name string, didResume bool) {
  1667  			_, hs, err := testHandshake(t, clientConfig, serverConfig)
  1668  			if err != nil {
  1669  				t.Fatalf("%s: handshake failed: %s", name, err)
  1670  			}
  1671  			if hs.DidResume != didResume {
  1672  				t.Errorf("%s: resumed: %v, expected: %v", name, hs.DidResume, didResume)
  1673  			}
  1674  			wantCalled := 1
  1675  			if didResume {
  1676  				wantCalled = 2 // resumption would mean this is the second time it was called in this test
  1677  			}
  1678  			if clientCalled != wantCalled {
  1679  				t.Errorf("%s: expected client VerifyConnection called %d times, did %d times", name, wantCalled, clientCalled)
  1680  			}
  1681  			if serverCalled != wantCalled {
  1682  				t.Errorf("%s: expected server VerifyConnection called %d times, did %d times", name, wantCalled, serverCalled)
  1683  			}
  1684  		}
  1685  		testHandshakeState(fmt.Sprintf("%s-FullHandshake", test.name), false)
  1686  		testHandshakeState(fmt.Sprintf("%s-Resumption", test.name), true)
  1687  	}
  1688  }
  1689  
  1690  func TestVerifyPeerCertificate(t *testing.T) {
  1691  	t.Run("TLSv12", func(t *testing.T) { testVerifyPeerCertificate(t, VersionTLS12) })
  1692  	t.Run("TLSv13", func(t *testing.T) { testVerifyPeerCertificate(t, VersionTLS13) })
  1693  }
  1694  
  1695  func testVerifyPeerCertificate(t *testing.T, version uint16) {
  1696  	issuer, err := x509.ParseCertificate(testRSACertificateIssuer)
  1697  	if err != nil {
  1698  		panic(err)
  1699  	}
  1700  
  1701  	rootCAs := x509.NewCertPool()
  1702  	rootCAs.AddCert(issuer)
  1703  
  1704  	now := func() time.Time { return time.Unix(1476984729, 0) }
  1705  
  1706  	sentinelErr := errors.New("TestVerifyPeerCertificate")
  1707  
  1708  	verifyPeerCertificateCallback := func(called *bool, rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
  1709  		if l := len(rawCerts); l != 1 {
  1710  			return fmt.Errorf("got len(rawCerts) = %d, wanted 1", l)
  1711  		}
  1712  		if len(validatedChains) == 0 {
  1713  			return errors.New("got len(validatedChains) = 0, wanted non-zero")
  1714  		}
  1715  		*called = true
  1716  		return nil
  1717  	}
  1718  	verifyConnectionCallback := func(called *bool, isClient bool, c ConnectionState) error {
  1719  		if l := len(c.PeerCertificates); l != 1 {
  1720  			return fmt.Errorf("got len(PeerCertificates) = %d, wanted 1", l)
  1721  		}
  1722  		if len(c.VerifiedChains) == 0 {
  1723  			return fmt.Errorf("got len(VerifiedChains) = 0, wanted non-zero")
  1724  		}
  1725  		if isClient && len(c.OCSPResponse) == 0 {
  1726  			return fmt.Errorf("got len(OCSPResponse) = 0, wanted non-zero")
  1727  		}
  1728  		*called = true
  1729  		return nil
  1730  	}
  1731  
  1732  	tests := []struct {
  1733  		configureServer func(*Config, *bool)
  1734  		configureClient func(*Config, *bool)
  1735  		validate        func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error)
  1736  	}{
  1737  		{
  1738  			configureServer: func(config *Config, called *bool) {
  1739  				config.InsecureSkipVerify = false
  1740  				config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
  1741  					return verifyPeerCertificateCallback(called, rawCerts, validatedChains)
  1742  				}
  1743  			},
  1744  			configureClient: func(config *Config, called *bool) {
  1745  				config.InsecureSkipVerify = false
  1746  				config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
  1747  					return verifyPeerCertificateCallback(called, rawCerts, validatedChains)
  1748  				}
  1749  			},
  1750  			validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
  1751  				if clientErr != nil {
  1752  					t.Errorf("test[%d]: client handshake failed: %v", testNo, clientErr)
  1753  				}
  1754  				if serverErr != nil {
  1755  					t.Errorf("test[%d]: server handshake failed: %v", testNo, serverErr)
  1756  				}
  1757  				if !clientCalled {
  1758  					t.Errorf("test[%d]: client did not call callback", testNo)
  1759  				}
  1760  				if !serverCalled {
  1761  					t.Errorf("test[%d]: server did not call callback", testNo)
  1762  				}
  1763  			},
  1764  		},
  1765  		{
  1766  			configureServer: func(config *Config, called *bool) {
  1767  				config.InsecureSkipVerify = false
  1768  				config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
  1769  					return sentinelErr
  1770  				}
  1771  			},
  1772  			configureClient: func(config *Config, called *bool) {
  1773  				config.VerifyPeerCertificate = nil
  1774  			},
  1775  			validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
  1776  				if serverErr != sentinelErr {
  1777  					t.Errorf("#%d: got server error %v, wanted sentinelErr", testNo, serverErr)
  1778  				}
  1779  			},
  1780  		},
  1781  		{
  1782  			configureServer: func(config *Config, called *bool) {
  1783  				config.InsecureSkipVerify = false
  1784  			},
  1785  			configureClient: func(config *Config, called *bool) {
  1786  				config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
  1787  					return sentinelErr
  1788  				}
  1789  			},
  1790  			validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
  1791  				if clientErr != sentinelErr {
  1792  					t.Errorf("#%d: got client error %v, wanted sentinelErr", testNo, clientErr)
  1793  				}
  1794  			},
  1795  		},
  1796  		{
  1797  			configureServer: func(config *Config, called *bool) {
  1798  				config.InsecureSkipVerify = false
  1799  			},
  1800  			configureClient: func(config *Config, called *bool) {
  1801  				config.InsecureSkipVerify = true
  1802  				config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
  1803  					if l := len(rawCerts); l != 1 {
  1804  						return fmt.Errorf("got len(rawCerts) = %d, wanted 1", l)
  1805  					}
  1806  					// With InsecureSkipVerify set, this
  1807  					// callback should still be called but
  1808  					// validatedChains must be empty.
  1809  					if l := len(validatedChains); l != 0 {
  1810  						return fmt.Errorf("got len(validatedChains) = %d, wanted zero", l)
  1811  					}
  1812  					*called = true
  1813  					return nil
  1814  				}
  1815  			},
  1816  			validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
  1817  				if clientErr != nil {
  1818  					t.Errorf("test[%d]: client handshake failed: %v", testNo, clientErr)
  1819  				}
  1820  				if serverErr != nil {
  1821  					t.Errorf("test[%d]: server handshake failed: %v", testNo, serverErr)
  1822  				}
  1823  				if !clientCalled {
  1824  					t.Errorf("test[%d]: client did not call callback", testNo)
  1825  				}
  1826  			},
  1827  		},
  1828  		{
  1829  			configureServer: func(config *Config, called *bool) {
  1830  				config.InsecureSkipVerify = false
  1831  				config.VerifyConnection = func(c ConnectionState) error {
  1832  					return verifyConnectionCallback(called, false, c)
  1833  				}
  1834  			},
  1835  			configureClient: func(config *Config, called *bool) {
  1836  				config.InsecureSkipVerify = false
  1837  				config.VerifyConnection = func(c ConnectionState) error {
  1838  					return verifyConnectionCallback(called, true, c)
  1839  				}
  1840  			},
  1841  			validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
  1842  				if clientErr != nil {
  1843  					t.Errorf("test[%d]: client handshake failed: %v", testNo, clientErr)
  1844  				}
  1845  				if serverErr != nil {
  1846  					t.Errorf("test[%d]: server handshake failed: %v", testNo, serverErr)
  1847  				}
  1848  				if !clientCalled {
  1849  					t.Errorf("test[%d]: client did not call callback", testNo)
  1850  				}
  1851  				if !serverCalled {
  1852  					t.Errorf("test[%d]: server did not call callback", testNo)
  1853  				}
  1854  			},
  1855  		},
  1856  		{
  1857  			configureServer: func(config *Config, called *bool) {
  1858  				config.InsecureSkipVerify = false
  1859  				config.VerifyConnection = func(c ConnectionState) error {
  1860  					return sentinelErr
  1861  				}
  1862  			},
  1863  			configureClient: func(config *Config, called *bool) {
  1864  				config.InsecureSkipVerify = false
  1865  				config.VerifyConnection = nil
  1866  			},
  1867  			validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
  1868  				if serverErr != sentinelErr {
  1869  					t.Errorf("#%d: got server error %v, wanted sentinelErr", testNo, serverErr)
  1870  				}
  1871  			},
  1872  		},
  1873  		{
  1874  			configureServer: func(config *Config, called *bool) {
  1875  				config.InsecureSkipVerify = false
  1876  				config.VerifyConnection = nil
  1877  			},
  1878  			configureClient: func(config *Config, called *bool) {
  1879  				config.InsecureSkipVerify = false
  1880  				config.VerifyConnection = func(c ConnectionState) error {
  1881  					return sentinelErr
  1882  				}
  1883  			},
  1884  			validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
  1885  				if clientErr != sentinelErr {
  1886  					t.Errorf("#%d: got client error %v, wanted sentinelErr", testNo, clientErr)
  1887  				}
  1888  			},
  1889  		},
  1890  		{
  1891  			configureServer: func(config *Config, called *bool) {
  1892  				config.InsecureSkipVerify = false
  1893  				config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
  1894  					return verifyPeerCertificateCallback(called, rawCerts, validatedChains)
  1895  				}
  1896  				config.VerifyConnection = func(c ConnectionState) error {
  1897  					return sentinelErr
  1898  				}
  1899  			},
  1900  			configureClient: func(config *Config, called *bool) {
  1901  				config.InsecureSkipVerify = false
  1902  				config.VerifyPeerCertificate = nil
  1903  				config.VerifyConnection = nil
  1904  			},
  1905  			validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
  1906  				if serverErr != sentinelErr {
  1907  					t.Errorf("#%d: got server error %v, wanted sentinelErr", testNo, serverErr)
  1908  				}
  1909  				if !serverCalled {
  1910  					t.Errorf("test[%d]: server did not call callback", testNo)
  1911  				}
  1912  			},
  1913  		},
  1914  		{
  1915  			configureServer: func(config *Config, called *bool) {
  1916  				config.InsecureSkipVerify = false
  1917  				config.VerifyPeerCertificate = nil
  1918  				config.VerifyConnection = nil
  1919  			},
  1920  			configureClient: func(config *Config, called *bool) {
  1921  				config.InsecureSkipVerify = false
  1922  				config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
  1923  					return verifyPeerCertificateCallback(called, rawCerts, validatedChains)
  1924  				}
  1925  				config.VerifyConnection = func(c ConnectionState) error {
  1926  					return sentinelErr
  1927  				}
  1928  			},
  1929  			validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
  1930  				if clientErr != sentinelErr {
  1931  					t.Errorf("#%d: got client error %v, wanted sentinelErr", testNo, clientErr)
  1932  				}
  1933  				if !clientCalled {
  1934  					t.Errorf("test[%d]: client did not call callback", testNo)
  1935  				}
  1936  			},
  1937  		},
  1938  	}
  1939  
  1940  	for i, test := range tests {
  1941  		c, s := localPipe(t)
  1942  		done := make(chan error)
  1943  
  1944  		var clientCalled, serverCalled bool
  1945  
  1946  		go func() {
  1947  			config := testConfig.Clone()
  1948  			config.ServerName = "example.golang"
  1949  			config.ClientAuth = RequireAndVerifyClientCert
  1950  			config.ClientCAs = rootCAs
  1951  			config.Time = now
  1952  			config.MaxVersion = version
  1953  			config.Certificates = make([]Certificate, 1)
  1954  			config.Certificates[0].Certificate = [][]byte{testRSACertificate}
  1955  			config.Certificates[0].PrivateKey = testRSAPrivateKey
  1956  			config.Certificates[0].SignedCertificateTimestamps = [][]byte{[]byte("dummy sct 1"), []byte("dummy sct 2")}
  1957  			config.Certificates[0].OCSPStaple = []byte("dummy ocsp")
  1958  			test.configureServer(config, &serverCalled)
  1959  
  1960  			err = Server(s, config).Handshake()
  1961  			s.Close()
  1962  			done <- err
  1963  		}()
  1964  
  1965  		config := testConfig.Clone()
  1966  		config.ServerName = "example.golang"
  1967  		config.RootCAs = rootCAs
  1968  		config.Time = now
  1969  		config.MaxVersion = version
  1970  		test.configureClient(config, &clientCalled)
  1971  		clientErr := Client(c, config).Handshake()
  1972  		c.Close()
  1973  		serverErr := <-done
  1974  
  1975  		test.validate(t, i, clientCalled, serverCalled, clientErr, serverErr)
  1976  	}
  1977  }
  1978  
  1979  // brokenConn wraps a net.Conn and causes all Writes after a certain number to
  1980  // fail with brokenConnErr.
  1981  type brokenConn struct {
  1982  	net.Conn
  1983  
  1984  	// breakAfter is the number of successful writes that will be allowed
  1985  	// before all subsequent writes fail.
  1986  	breakAfter int
  1987  
  1988  	// numWrites is the number of writes that have been done.
  1989  	numWrites int
  1990  }
  1991  
  1992  // brokenConnErr is the error that brokenConn returns once exhausted.
  1993  var brokenConnErr = errors.New("too many writes to brokenConn")
  1994  
  1995  func (b *brokenConn) Write(data []byte) (int, error) {
  1996  	if b.numWrites >= b.breakAfter {
  1997  		return 0, brokenConnErr
  1998  	}
  1999  
  2000  	b.numWrites++
  2001  	return b.Conn.Write(data)
  2002  }
  2003  
  2004  func TestFailedWrite(t *testing.T) {
  2005  	// Test that a write error during the handshake is returned.
  2006  	for _, breakAfter := range []int{0, 1} {
  2007  		c, s := localPipe(t)
  2008  		done := make(chan bool)
  2009  
  2010  		go func() {
  2011  			Server(s, testConfig).Handshake()
  2012  			s.Close()
  2013  			done <- true
  2014  		}()
  2015  
  2016  		brokenC := &brokenConn{Conn: c, breakAfter: breakAfter}
  2017  		err := Client(brokenC, testConfig).Handshake()
  2018  		if err != brokenConnErr {
  2019  			t.Errorf("#%d: expected error from brokenConn but got %q", breakAfter, err)
  2020  		}
  2021  		brokenC.Close()
  2022  
  2023  		<-done
  2024  	}
  2025  }
  2026  
  2027  // writeCountingConn wraps a net.Conn and counts the number of Write calls.
  2028  type writeCountingConn struct {
  2029  	net.Conn
  2030  
  2031  	// numWrites is the number of writes that have been done.
  2032  	numWrites int
  2033  }
  2034  
  2035  func (wcc *writeCountingConn) Write(data []byte) (int, error) {
  2036  	wcc.numWrites++
  2037  	return wcc.Conn.Write(data)
  2038  }
  2039  
  2040  func TestBuffering(t *testing.T) {
  2041  	t.Run("TLSv12", func(t *testing.T) { testBuffering(t, VersionTLS12) })
  2042  	t.Run("TLSv13", func(t *testing.T) { testBuffering(t, VersionTLS13) })
  2043  }
  2044  
  2045  func testBuffering(t *testing.T, version uint16) {
  2046  	c, s := localPipe(t)
  2047  	done := make(chan bool)
  2048  
  2049  	clientWCC := &writeCountingConn{Conn: c}
  2050  	serverWCC := &writeCountingConn{Conn: s}
  2051  
  2052  	go func() {
  2053  		config := testConfig.Clone()
  2054  		config.MaxVersion = version
  2055  		Server(serverWCC, config).Handshake()
  2056  		serverWCC.Close()
  2057  		done <- true
  2058  	}()
  2059  
  2060  	err := Client(clientWCC, testConfig).Handshake()
  2061  	if err != nil {
  2062  		t.Fatal(err)
  2063  	}
  2064  	clientWCC.Close()
  2065  	<-done
  2066  
  2067  	var expectedClient, expectedServer int
  2068  	if version == VersionTLS13 {
  2069  		expectedClient = 2
  2070  		expectedServer = 1
  2071  	} else {
  2072  		expectedClient = 2
  2073  		expectedServer = 2
  2074  	}
  2075  
  2076  	if n := clientWCC.numWrites; n != expectedClient {
  2077  		t.Errorf("expected client handshake to complete with %d writes, but saw %d", expectedClient, n)
  2078  	}
  2079  
  2080  	if n := serverWCC.numWrites; n != expectedServer {
  2081  		t.Errorf("expected server handshake to complete with %d writes, but saw %d", expectedServer, n)
  2082  	}
  2083  }
  2084  
  2085  func TestAlertFlushing(t *testing.T) {
  2086  	c, s := localPipe(t)
  2087  	done := make(chan bool)
  2088  
  2089  	clientWCC := &writeCountingConn{Conn: c}
  2090  	serverWCC := &writeCountingConn{Conn: s}
  2091  
  2092  	serverConfig := testConfig.Clone()
  2093  
  2094  	// Cause a signature-time error
  2095  	brokenKey := rsa.PrivateKey{PublicKey: testRSAPrivateKey.PublicKey}
  2096  	brokenKey.D = big.NewInt(42)
  2097  	serverConfig.Certificates = []Certificate{{
  2098  		Certificate: [][]byte{testRSACertificate},
  2099  		PrivateKey:  &brokenKey,
  2100  	}}
  2101  
  2102  	go func() {
  2103  		Server(serverWCC, serverConfig).Handshake()
  2104  		serverWCC.Close()
  2105  		done <- true
  2106  	}()
  2107  
  2108  	err := Client(clientWCC, testConfig).Handshake()
  2109  	if err == nil {
  2110  		t.Fatal("client unexpectedly returned no error")
  2111  	}
  2112  
  2113  	const expectedError = "remote error: tls: internal error"
  2114  	if e := err.Error(); !strings.Contains(e, expectedError) {
  2115  		t.Fatalf("expected to find %q in error but error was %q", expectedError, e)
  2116  	}
  2117  	clientWCC.Close()
  2118  	<-done
  2119  
  2120  	if n := serverWCC.numWrites; n != 1 {
  2121  		t.Errorf("expected server handshake to complete with one write, but saw %d", n)
  2122  	}
  2123  }
  2124  
// TestHandshakeRace runs 32 client/server exchanges. The client never calls
// Handshake explicitly; a concurrent Read and Write, released in alternating
// order, race to complete it.
func TestHandshakeRace(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping in -short mode")
	}
	t.Parallel()
	// This test races a Read and Write to try and complete a handshake in
	// order to provide some evidence that there are no races or deadlocks
	// in the handshake locking.
	for i := 0; i < 32; i++ {
		c, s := localPipe(t)

		// Server: finish the handshake, read one byte, echo it back, close.
		go func() {
			server := Server(s, testConfig)
			if err := server.Handshake(); err != nil {
				panic(err)
			}

			var request [1]byte
			if n, err := server.Read(request[:]); err != nil || n != 1 {
				panic(err)
			}

			server.Write(request[:])
			server.Close()
		}()

		startWrite := make(chan struct{})
		startRead := make(chan struct{})
		// Capacity 1 so the reader can signal completion without blocking.
		readDone := make(chan struct{}, 1)

		client := Client(c, testConfig)
		// Writer: once released, sends a single byte.
		go func() {
			<-startWrite
			var request [1]byte
			client.Write(request[:])
		}()

		// Reader: once released, waits for the server's one-byte echo, then
		// closes the client side and reports completion.
		go func() {
			<-startRead
			var reply [1]byte
			if _, err := io.ReadFull(client, reply[:]); err != nil {
				panic(err)
			}
			c.Close()
			readDone <- struct{}{}
		}()

		// Odd iterations release the writer first, even iterations the reader.
		if i&1 == 1 {
			startWrite <- struct{}{}
			startRead <- struct{}{}
		} else {
			startRead <- struct{}{}
			startWrite <- struct{}{}
		}
		<-readDone
	}
}
  2182  
  2183  var getClientCertificateTests = []struct {
  2184  	setup               func(*Config, *Config)
  2185  	expectedClientError string
  2186  	verify              func(*testing.T, int, *ConnectionState)
  2187  }{
  2188  	{
  2189  		func(clientConfig, serverConfig *Config) {
  2190  			// Returning a Certificate with no certificate data
  2191  			// should result in an empty message being sent to the
  2192  			// server.
  2193  			serverConfig.ClientCAs = nil
  2194  			clientConfig.GetClientCertificate = func(cri *CertificateRequestInfo) (*Certificate, error) {
  2195  				if len(cri.SignatureSchemes) == 0 {
  2196  					panic("empty SignatureSchemes")
  2197  				}
  2198  				if len(cri.AcceptableCAs) != 0 {
  2199  					panic("AcceptableCAs should have been empty")
  2200  				}
  2201  				return new(Certificate), nil
  2202  			}
  2203  		},
  2204  		"",
  2205  		func(t *testing.T, testNum int, cs *ConnectionState) {
  2206  			if l := len(cs.PeerCertificates); l != 0 {
  2207  				t.Errorf("#%d: expected no certificates but got %d", testNum, l)
  2208  			}
  2209  		},
  2210  	},
  2211  	{
  2212  		func(clientConfig, serverConfig *Config) {
  2213  			// With TLS 1.1, the SignatureSchemes should be
  2214  			// synthesised from the supported certificate types.
  2215  			clientConfig.MaxVersion = VersionTLS11
  2216  			clientConfig.GetClientCertificate = func(cri *CertificateRequestInfo) (*Certificate, error) {
  2217  				if len(cri.SignatureSchemes) == 0 {
  2218  					panic("empty SignatureSchemes")
  2219  				}
  2220  				return new(Certificate), nil
  2221  			}
  2222  		},
  2223  		"",
  2224  		func(t *testing.T, testNum int, cs *ConnectionState) {
  2225  			if l := len(cs.PeerCertificates); l != 0 {
  2226  				t.Errorf("#%d: expected no certificates but got %d", testNum, l)
  2227  			}
  2228  		},
  2229  	},
  2230  	{
  2231  		func(clientConfig, serverConfig *Config) {
  2232  			// Returning an error should abort the handshake with
  2233  			// that error.
  2234  			clientConfig.GetClientCertificate = func(cri *CertificateRequestInfo) (*Certificate, error) {
  2235  				return nil, errors.New("GetClientCertificate")
  2236  			}
  2237  		},
  2238  		"GetClientCertificate",
  2239  		func(t *testing.T, testNum int, cs *ConnectionState) {
  2240  		},
  2241  	},
  2242  	{
  2243  		func(clientConfig, serverConfig *Config) {
  2244  			clientConfig.GetClientCertificate = func(cri *CertificateRequestInfo) (*Certificate, error) {
  2245  				if len(cri.AcceptableCAs) == 0 {
  2246  					panic("empty AcceptableCAs")
  2247  				}
  2248  				cert := &Certificate{
  2249  					Certificate: [][]byte{testRSACertificate},
  2250  					PrivateKey:  testRSAPrivateKey,
  2251  				}
  2252  				return cert, nil
  2253  			}
  2254  		},
  2255  		"",
  2256  		func(t *testing.T, testNum int, cs *ConnectionState) {
  2257  			if len(cs.VerifiedChains) == 0 {
  2258  				t.Errorf("#%d: expected some verified chains, but found none", testNum)
  2259  			}
  2260  		},
  2261  	},
  2262  }
  2263  
  2264  func TestGetClientCertificate(t *testing.T) {
  2265  	t.Run("TLSv12", func(t *testing.T) { testGetClientCertificate(t, VersionTLS12) })
  2266  	t.Run("TLSv13", func(t *testing.T) { testGetClientCertificate(t, VersionTLS13) })
  2267  }
  2268  
  2269  func testGetClientCertificate(t *testing.T, version uint16) {
  2270  	issuer, err := x509.ParseCertificate(testRSACertificateIssuer)
  2271  	if err != nil {
  2272  		panic(err)
  2273  	}
  2274  
  2275  	for i, test := range getClientCertificateTests {
  2276  		serverConfig := testConfig.Clone()
  2277  		serverConfig.ClientAuth = VerifyClientCertIfGiven
  2278  		serverConfig.RootCAs = x509.NewCertPool()
  2279  		serverConfig.RootCAs.AddCert(issuer)
  2280  		serverConfig.ClientCAs = serverConfig.RootCAs
  2281  		serverConfig.Time = func() time.Time { return time.Unix(1476984729, 0) }
  2282  		serverConfig.MaxVersion = version
  2283  
  2284  		clientConfig := testConfig.Clone()
  2285  		clientConfig.MaxVersion = version
  2286  
  2287  		test.setup(clientConfig, serverConfig)
  2288  
  2289  		type serverResult struct {
  2290  			cs  ConnectionState
  2291  			err error
  2292  		}
  2293  
  2294  		c, s := localPipe(t)
  2295  		done := make(chan serverResult)
  2296  
  2297  		go func() {
  2298  			defer s.Close()
  2299  			server := Server(s, serverConfig)
  2300  			err := server.Handshake()
  2301  
  2302  			var cs ConnectionState
  2303  			if err == nil {
  2304  				cs = server.ConnectionState()
  2305  			}
  2306  			done <- serverResult{cs, err}
  2307  		}()
  2308  
  2309  		clientErr := Client(c, clientConfig).Handshake()
  2310  		c.Close()
  2311  
  2312  		result := <-done
  2313  
  2314  		if clientErr != nil {
  2315  			if len(test.expectedClientError) == 0 {
  2316  				t.Errorf("#%d: client error: %v", i, clientErr)
  2317  			} else if got := clientErr.Error(); got != test.expectedClientError {
  2318  				t.Errorf("#%d: expected client error %q, but got %q", i, test.expectedClientError, got)
  2319  			} else {
  2320  				test.verify(t, i, &result.cs)
  2321  			}
  2322  		} else if len(test.expectedClientError) > 0 {
  2323  			t.Errorf("#%d: expected client error %q, but got no error", i, test.expectedClientError)
  2324  		} else if err := result.err; err != nil {
  2325  			t.Errorf("#%d: server error: %v", i, err)
  2326  		} else {
  2327  			test.verify(t, i, &result.cs)
  2328  		}
  2329  	}
  2330  }
  2331  
  2332  func TestRSAPSSKeyError(t *testing.T) {
  2333  	// crypto/tls does not support the rsa_pss_pss_* SignatureSchemes. If support for
  2334  	// public keys with OID RSASSA-PSS is added to crypto/x509, they will be misused with
  2335  	// the rsa_pss_rsae_* SignatureSchemes. Assert that RSASSA-PSS certificates don't
  2336  	// parse, or that they don't carry *rsa.PublicKey keys.
  2337  	b, _ := pem.Decode([]byte(`
  2338  -----BEGIN CERTIFICATE-----
  2339  MIIDZTCCAhygAwIBAgIUCF2x0FyTgZG0CC9QTDjGWkB5vgEwPgYJKoZIhvcNAQEK
  2340  MDGgDTALBglghkgBZQMEAgGhGjAYBgkqhkiG9w0BAQgwCwYJYIZIAWUDBAIBogQC
  2341  AgDeMBIxEDAOBgNVBAMMB1JTQS1QU1MwHhcNMTgwNjI3MjI0NDM2WhcNMTgwNzI3
  2342  MjI0NDM2WjASMRAwDgYDVQQDDAdSU0EtUFNTMIIBIDALBgkqhkiG9w0BAQoDggEP
  2343  ADCCAQoCggEBANxDm0f76JdI06YzsjB3AmmjIYkwUEGxePlafmIASFjDZl/elD0Z
  2344  /a7xLX468b0qGxLS5al7XCcEprSdsDR6DF5L520+pCbpfLyPOjuOvGmk9KzVX4x5
  2345  b05YXYuXdsQ0Kjxcx2i3jjCday6scIhMJVgBZxTEyMj1thPQM14SHzKCd/m6HmCL
  2346  QmswpH2yMAAcBRWzRpp/vdH5DeOJEB3aelq7094no731mrLUCHRiZ1htq8BDB3ou
  2347  czwqgwspbqZ4dnMXl2MvfySQ5wJUxQwILbiuAKO2lVVPUbFXHE9pgtznNoPvKwQT
  2348  JNcX8ee8WIZc2SEGzofjk3NpjR+2ADB2u3sCAwEAAaNTMFEwHQYDVR0OBBYEFNEz
  2349  AdyJ2f+fU+vSCS6QzohnOnprMB8GA1UdIwQYMBaAFNEzAdyJ2f+fU+vSCS6Qzohn
  2350  OnprMA8GA1UdEwEB/wQFMAMBAf8wPgYJKoZIhvcNAQEKMDGgDTALBglghkgBZQME
  2351  AgGhGjAYBgkqhkiG9w0BAQgwCwYJYIZIAWUDBAIBogQCAgDeA4IBAQCjEdrR5aab
  2352  sZmCwrMeKidXgfkmWvfuLDE+TCbaqDZp7BMWcMQXT9O0UoUT5kqgKj2ARm2pEW0Z
  2353  H3Z1vj3bbds72qcDIJXp+l0fekyLGeCrX/CbgnMZXEP7+/+P416p34ChR1Wz4dU1
  2354  KD3gdsUuTKKeMUog3plxlxQDhRQmiL25ygH1LmjLd6dtIt0GVRGr8lj3euVeprqZ
  2355  bZ3Uq5eLfsn8oPgfC57gpO6yiN+UURRTlK3bgYvLh4VWB3XXk9UaQZ7Mq1tpXjoD
  2356  HYFybkWzibkZp4WRo+Fa28rirH+/wHt0vfeN7UCceURZEx4JaxIIfe4ku7uDRhJi
  2357  RwBA9Xk1KBNF
  2358  -----END CERTIFICATE-----`))
  2359  	if b == nil {
  2360  		t.Fatal("Failed to decode certificate")
  2361  	}
  2362  	cert, err := x509.ParseCertificate(b.Bytes)
  2363  	if err != nil {
  2364  		return
  2365  	}
  2366  	if _, ok := cert.PublicKey.(*rsa.PublicKey); ok {
  2367  		t.Error("A RSASSA-PSS certificate was parsed like a PKCS#1 v1.5 one, and it will be mistakenly used with rsa_pss_rsae_* signature algorithms")
  2368  	}
  2369  }
  2370  
  2371  func TestCloseClientConnectionOnIdleServer(t *testing.T) {
  2372  	clientConn, serverConn := localPipe(t)
  2373  	client := Client(clientConn, testConfig.Clone())
  2374  	go func() {
  2375  		var b [1]byte
  2376  		serverConn.Read(b[:])
  2377  		client.Close()
  2378  	}()
  2379  	client.SetWriteDeadline(time.Now().Add(time.Minute))
  2380  	err := client.Handshake()
  2381  	if err != nil {
  2382  		if err, ok := err.(net.Error); ok && err.Timeout() {
  2383  			t.Errorf("Expected a closed network connection error but got '%s'", err.Error())
  2384  		}
  2385  	} else {
  2386  		t.Errorf("Error expected, but no error returned")
  2387  	}
  2388  }
  2389  
  2390  func testDowngradeCanary(t *testing.T, clientVersion, serverVersion uint16) error {
  2391  	defer func() { testingOnlyForceDowngradeCanary = false }()
  2392  	testingOnlyForceDowngradeCanary = true
  2393  
  2394  	clientConfig := testConfig.Clone()
  2395  	clientConfig.MaxVersion = clientVersion
  2396  	serverConfig := testConfig.Clone()
  2397  	serverConfig.MaxVersion = serverVersion
  2398  	_, _, err := testHandshake(t, clientConfig, serverConfig)
  2399  	return err
  2400  }
  2401  
  2402  func TestDowngradeCanary(t *testing.T) {
  2403  	if err := testDowngradeCanary(t, VersionTLS13, VersionTLS12); err == nil {
  2404  		t.Errorf("downgrade from TLS 1.3 to TLS 1.2 was not detected")
  2405  	}
  2406  	if testing.Short() {
  2407  		t.Skip("skipping the rest of the checks in short mode")
  2408  	}
  2409  	if err := testDowngradeCanary(t, VersionTLS13, VersionTLS11); err == nil {
  2410  		t.Errorf("downgrade from TLS 1.3 to TLS 1.1 was not detected")
  2411  	}
  2412  	if err := testDowngradeCanary(t, VersionTLS13, VersionTLS10); err == nil {
  2413  		t.Errorf("downgrade from TLS 1.3 to TLS 1.0 was not detected")
  2414  	}
  2415  	if err := testDowngradeCanary(t, VersionTLS12, VersionTLS11); err == nil {
  2416  		t.Errorf("downgrade from TLS 1.2 to TLS 1.1 was not detected")
  2417  	}
  2418  	if err := testDowngradeCanary(t, VersionTLS12, VersionTLS10); err == nil {
  2419  		t.Errorf("downgrade from TLS 1.2 to TLS 1.0 was not detected")
  2420  	}
  2421  	if err := testDowngradeCanary(t, VersionTLS13, VersionTLS13); err != nil {
  2422  		t.Errorf("server unexpectedly sent downgrade canary for TLS 1.3")
  2423  	}
  2424  	if err := testDowngradeCanary(t, VersionTLS12, VersionTLS12); err != nil {
  2425  		t.Errorf("client didn't ignore expected TLS 1.2 canary")
  2426  	}
  2427  	if err := testDowngradeCanary(t, VersionTLS11, VersionTLS11); err != nil {
  2428  		t.Errorf("client unexpectedly reacted to a canary in TLS 1.1")
  2429  	}
  2430  	if err := testDowngradeCanary(t, VersionTLS10, VersionTLS10); err != nil {
  2431  		t.Errorf("client unexpectedly reacted to a canary in TLS 1.0")
  2432  	}
  2433  }
  2434  
  2435  func TestResumptionKeepsOCSPAndSCT(t *testing.T) {
  2436  	t.Run("TLSv12", func(t *testing.T) { testResumptionKeepsOCSPAndSCT(t, VersionTLS12) })
  2437  	t.Run("TLSv13", func(t *testing.T) { testResumptionKeepsOCSPAndSCT(t, VersionTLS13) })
  2438  }
  2439  
  2440  func testResumptionKeepsOCSPAndSCT(t *testing.T, ver uint16) {
  2441  	issuer, err := x509.ParseCertificate(testRSACertificateIssuer)
  2442  	if err != nil {
  2443  		t.Fatalf("failed to parse test issuer")
  2444  	}
  2445  	roots := x509.NewCertPool()
  2446  	roots.AddCert(issuer)
  2447  	clientConfig := &Config{
  2448  		MaxVersion:         ver,
  2449  		ClientSessionCache: NewLRUClientSessionCache(32),
  2450  		ServerName:         "example.golang",
  2451  		RootCAs:            roots,
  2452  	}
  2453  	serverConfig := testConfig.Clone()
  2454  	serverConfig.MaxVersion = ver
  2455  	serverConfig.Certificates[0].OCSPStaple = []byte{1, 2, 3}
  2456  	serverConfig.Certificates[0].SignedCertificateTimestamps = [][]byte{{4, 5, 6}}
  2457  
  2458  	_, ccs, err := testHandshake(t, clientConfig, serverConfig)
  2459  	if err != nil {
  2460  		t.Fatalf("handshake failed: %s", err)
  2461  	}
  2462  	// after a new session we expect to see OCSPResponse and
  2463  	// SignedCertificateTimestamps populated as usual
  2464  	if !bytes.Equal(ccs.OCSPResponse, serverConfig.Certificates[0].OCSPStaple) {
  2465  		t.Errorf("client ConnectionState contained unexpected OCSPResponse: wanted %v, got %v",
  2466  			serverConfig.Certificates[0].OCSPStaple, ccs.OCSPResponse)
  2467  	}
  2468  	if !reflect.DeepEqual(ccs.SignedCertificateTimestamps, serverConfig.Certificates[0].SignedCertificateTimestamps) {
  2469  		t.Errorf("client ConnectionState contained unexpected SignedCertificateTimestamps: wanted %v, got %v",
  2470  			serverConfig.Certificates[0].SignedCertificateTimestamps, ccs.SignedCertificateTimestamps)
  2471  	}
  2472  
  2473  	// if the server doesn't send any SCTs, repopulate the old SCTs
  2474  	oldSCTs := serverConfig.Certificates[0].SignedCertificateTimestamps
  2475  	serverConfig.Certificates[0].SignedCertificateTimestamps = nil
  2476  	_, ccs, err = testHandshake(t, clientConfig, serverConfig)
  2477  	if err != nil {
  2478  		t.Fatalf("handshake failed: %s", err)
  2479  	}
  2480  	if !ccs.DidResume {
  2481  		t.Fatalf("expected session to be resumed")
  2482  	}
  2483  	// after a resumed session we also expect to see OCSPResponse
  2484  	// and SignedCertificateTimestamps populated
  2485  	if !bytes.Equal(ccs.OCSPResponse, serverConfig.Certificates[0].OCSPStaple) {
  2486  		t.Errorf("client ConnectionState contained unexpected OCSPResponse after resumption: wanted %v, got %v",
  2487  			serverConfig.Certificates[0].OCSPStaple, ccs.OCSPResponse)
  2488  	}
  2489  	if !reflect.DeepEqual(ccs.SignedCertificateTimestamps, oldSCTs) {
  2490  		t.Errorf("client ConnectionState contained unexpected SignedCertificateTimestamps after resumption: wanted %v, got %v",
  2491  			oldSCTs, ccs.SignedCertificateTimestamps)
  2492  	}
  2493  
  2494  	//  Only test overriding the SCTs for TLS 1.2, since in 1.3
  2495  	// the server won't send the message containing them
  2496  	if ver == VersionTLS13 {
  2497  		return
  2498  	}
  2499  
  2500  	// if the server changes the SCTs it sends, they should override the saved SCTs
  2501  	serverConfig.Certificates[0].SignedCertificateTimestamps = [][]byte{{7, 8, 9}}
  2502  	_, ccs, err = testHandshake(t, clientConfig, serverConfig)
  2503  	if err != nil {
  2504  		t.Fatalf("handshake failed: %s", err)
  2505  	}
  2506  	if !ccs.DidResume {
  2507  		t.Fatalf("expected session to be resumed")
  2508  	}
  2509  	if !reflect.DeepEqual(ccs.SignedCertificateTimestamps, serverConfig.Certificates[0].SignedCertificateTimestamps) {
  2510  		t.Errorf("client ConnectionState contained unexpected SignedCertificateTimestamps after resumption: wanted %v, got %v",
  2511  			serverConfig.Certificates[0].SignedCertificateTimestamps, ccs.SignedCertificateTimestamps)
  2512  	}
  2513  }