github.com/miolini/go@v0.0.0-20160405192216-fca68c8cb408/src/crypto/tls/handshake_client_test.go (about)

     1  // Copyright 2010 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package tls
     6  
     7  import (
     8  	"bytes"
     9  	"crypto/ecdsa"
    10  	"crypto/rsa"
    11  	"crypto/x509"
    12  	"encoding/base64"
    13  	"encoding/binary"
    14  	"encoding/pem"
    15  	"errors"
    16  	"fmt"
    17  	"io"
    18  	"net"
    19  	"os"
    20  	"os/exec"
    21  	"path/filepath"
    22  	"strconv"
    23  	"strings"
    24  	"testing"
    25  	"time"
    26  )
    27  
    28  // Note: see comment in handshake_test.go for details of how the reference
    29  // tests work.
    30  
    31  // blockingSource is an io.Reader that blocks a Read call until it's closed.
    32  type blockingSource chan bool
    33  
    34  func (b blockingSource) Read([]byte) (n int, err error) {
    35  	<-b
    36  	return 0, io.EOF
    37  }
    38  
// clientTest represents a test of the TLS client handshake against a reference
// implementation. When run in write mode the reference server is started as a
// child process and the exchanged flows are recorded to a file; otherwise the
// previously recorded flows are replayed to the client (see run).
type clientTest struct {
	// name is a freeform string identifying the test and the file in which
	// the expected results will be stored (testdata/Client-<name>, see
	// dataPath).
	name string
	// command, if not empty, contains a series of arguments for the
	// command to run for the reference server. connFromCommand falls back to
	// defaultServerCommand when it is empty, and appends the certificate,
	// key, listen-port and (optional) serverinfo options to it.
	command []string
	// config, if not nil, contains a custom Config to use for this test;
	// otherwise run uses testConfig.
	config *Config
	// cert, if not empty, contains a DER-encoded certificate for the
	// reference server. connFromCommand defaults it to testRSACertificate.
	cert []byte
	// key, if not nil, contains either a *rsa.PrivateKey or
	// *ecdsa.PrivateKey which is the private key for the reference server.
	// connFromCommand defaults it to testRSAPrivateKey.
	key interface{}
	// extensions, if not nil, contains a list of extension data to be returned
	// from the ServerHello. The data should be in standard TLS format with
	// a 2-byte uint16 type, 2-byte data length, followed by the extension data.
	extensions [][]byte
	// validate, if not nil, is a function that will be called with the
	// ConnectionState of the resulting connection, after the client has
	// written its test data. It returns a non-nil error if the
	// ConnectionState is unacceptable.
	validate func(ConnectionState) error
}

// defaultServerCommand is the reference server command connFromCommand uses
// when clientTest.command is empty.
var defaultServerCommand = []string{"openssl", "s_server"}
    67  
    68  // connFromCommand starts the reference server process, connects to it and
    69  // returns a recordingConn for the connection. The stdin return value is a
    70  // blockingSource for the stdin of the child process. It must be closed before
    71  // Waiting for child.
    72  func (test *clientTest) connFromCommand() (conn *recordingConn, child *exec.Cmd, stdin blockingSource, err error) {
    73  	cert := testRSACertificate
    74  	if len(test.cert) > 0 {
    75  		cert = test.cert
    76  	}
    77  	certPath := tempFile(string(cert))
    78  	defer os.Remove(certPath)
    79  
    80  	var key interface{} = testRSAPrivateKey
    81  	if test.key != nil {
    82  		key = test.key
    83  	}
    84  	var pemType string
    85  	var derBytes []byte
    86  	switch key := key.(type) {
    87  	case *rsa.PrivateKey:
    88  		pemType = "RSA"
    89  		derBytes = x509.MarshalPKCS1PrivateKey(key)
    90  	case *ecdsa.PrivateKey:
    91  		pemType = "EC"
    92  		var err error
    93  		derBytes, err = x509.MarshalECPrivateKey(key)
    94  		if err != nil {
    95  			panic(err)
    96  		}
    97  	default:
    98  		panic("unknown key type")
    99  	}
   100  
   101  	var pemOut bytes.Buffer
   102  	pem.Encode(&pemOut, &pem.Block{Type: pemType + " PRIVATE KEY", Bytes: derBytes})
   103  
   104  	keyPath := tempFile(string(pemOut.Bytes()))
   105  	defer os.Remove(keyPath)
   106  
   107  	var command []string
   108  	if len(test.command) > 0 {
   109  		command = append(command, test.command...)
   110  	} else {
   111  		command = append(command, defaultServerCommand...)
   112  	}
   113  	command = append(command, "-cert", certPath, "-certform", "DER", "-key", keyPath)
   114  	// serverPort contains the port that OpenSSL will listen on. OpenSSL
   115  	// can't take "0" as an argument here so we have to pick a number and
   116  	// hope that it's not in use on the machine. Since this only occurs
   117  	// when -update is given and thus when there's a human watching the
   118  	// test, this isn't too bad.
   119  	const serverPort = 24323
   120  	command = append(command, "-accept", strconv.Itoa(serverPort))
   121  
   122  	if len(test.extensions) > 0 {
   123  		var serverInfo bytes.Buffer
   124  		for _, ext := range test.extensions {
   125  			pem.Encode(&serverInfo, &pem.Block{
   126  				Type:  fmt.Sprintf("SERVERINFO FOR EXTENSION %d", binary.BigEndian.Uint16(ext)),
   127  				Bytes: ext,
   128  			})
   129  		}
   130  		serverInfoPath := tempFile(serverInfo.String())
   131  		defer os.Remove(serverInfoPath)
   132  		command = append(command, "-serverinfo", serverInfoPath)
   133  	}
   134  
   135  	cmd := exec.Command(command[0], command[1:]...)
   136  	stdin = blockingSource(make(chan bool))
   137  	cmd.Stdin = stdin
   138  	var out bytes.Buffer
   139  	cmd.Stdout = &out
   140  	cmd.Stderr = &out
   141  	if err := cmd.Start(); err != nil {
   142  		return nil, nil, nil, err
   143  	}
   144  
   145  	// OpenSSL does print an "ACCEPT" banner, but it does so *before*
   146  	// opening the listening socket, so we can't use that to wait until it
   147  	// has started listening. Thus we are forced to poll until we get a
   148  	// connection.
   149  	var tcpConn net.Conn
   150  	for i := uint(0); i < 5; i++ {
   151  		tcpConn, err = net.DialTCP("tcp", nil, &net.TCPAddr{
   152  			IP:   net.IPv4(127, 0, 0, 1),
   153  			Port: serverPort,
   154  		})
   155  		if err == nil {
   156  			break
   157  		}
   158  		time.Sleep((1 << i) * 5 * time.Millisecond)
   159  	}
   160  	if err != nil {
   161  		close(stdin)
   162  		out.WriteTo(os.Stdout)
   163  		cmd.Process.Kill()
   164  		return nil, nil, nil, cmd.Wait()
   165  	}
   166  
   167  	record := &recordingConn{
   168  		Conn: tcpConn,
   169  	}
   170  
   171  	return record, cmd, stdin, nil
   172  }
   173  
   174  func (test *clientTest) dataPath() string {
   175  	return filepath.Join("testdata", "Client-"+test.name)
   176  }
   177  
   178  func (test *clientTest) loadData() (flows [][]byte, err error) {
   179  	in, err := os.Open(test.dataPath())
   180  	if err != nil {
   181  		return nil, err
   182  	}
   183  	defer in.Close()
   184  	return parseTestData(in)
   185  }
   186  
   187  func (test *clientTest) run(t *testing.T, write bool) {
   188  	var clientConn, serverConn net.Conn
   189  	var recordingConn *recordingConn
   190  	var childProcess *exec.Cmd
   191  	var stdin blockingSource
   192  
   193  	if write {
   194  		var err error
   195  		recordingConn, childProcess, stdin, err = test.connFromCommand()
   196  		if err != nil {
   197  			t.Fatalf("Failed to start subcommand: %s", err)
   198  		}
   199  		clientConn = recordingConn
   200  	} else {
   201  		clientConn, serverConn = net.Pipe()
   202  	}
   203  
   204  	config := test.config
   205  	if config == nil {
   206  		config = testConfig
   207  	}
   208  	client := Client(clientConn, config)
   209  
   210  	doneChan := make(chan bool)
   211  	go func() {
   212  		if _, err := client.Write([]byte("hello\n")); err != nil {
   213  			t.Errorf("Client.Write failed: %s", err)
   214  		}
   215  		if test.validate != nil {
   216  			if err := test.validate(client.ConnectionState()); err != nil {
   217  				t.Errorf("validate callback returned error: %s", err)
   218  			}
   219  		}
   220  		client.Close()
   221  		clientConn.Close()
   222  		doneChan <- true
   223  	}()
   224  
   225  	if !write {
   226  		flows, err := test.loadData()
   227  		if err != nil {
   228  			t.Fatalf("%s: failed to load data from %s: %v", test.name, test.dataPath(), err)
   229  		}
   230  		for i, b := range flows {
   231  			if i%2 == 1 {
   232  				serverConn.Write(b)
   233  				continue
   234  			}
   235  			bb := make([]byte, len(b))
   236  			_, err := io.ReadFull(serverConn, bb)
   237  			if err != nil {
   238  				t.Fatalf("%s #%d: %s", test.name, i, err)
   239  			}
   240  			if !bytes.Equal(b, bb) {
   241  				t.Fatalf("%s #%d: mismatch on read: got:%x want:%x", test.name, i, bb, b)
   242  			}
   243  		}
   244  		serverConn.Close()
   245  	}
   246  
   247  	<-doneChan
   248  
   249  	if write {
   250  		path := test.dataPath()
   251  		out, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
   252  		if err != nil {
   253  			t.Fatalf("Failed to create output file: %s", err)
   254  		}
   255  		defer out.Close()
   256  		recordingConn.Close()
   257  		close(stdin)
   258  		childProcess.Process.Kill()
   259  		childProcess.Wait()
   260  		if len(recordingConn.flows) < 3 {
   261  			childProcess.Stdout.(*bytes.Buffer).WriteTo(os.Stdout)
   262  			t.Fatalf("Client connection didn't work")
   263  		}
   264  		recordingConn.WriteTo(out)
   265  		fmt.Printf("Wrote %s\n", path)
   266  	}
   267  }
   268  
   269  func runClientTestForVersion(t *testing.T, template *clientTest, prefix, option string) {
   270  	test := *template
   271  	test.name = prefix + test.name
   272  	if len(test.command) == 0 {
   273  		test.command = defaultClientCommand
   274  	}
   275  	test.command = append([]string(nil), test.command...)
   276  	test.command = append(test.command, option)
   277  	test.run(t, *update)
   278  }
   279  
   280  func runClientTestTLS10(t *testing.T, template *clientTest) {
   281  	runClientTestForVersion(t, template, "TLSv10-", "-tls1")
   282  }
   283  
   284  func runClientTestTLS11(t *testing.T, template *clientTest) {
   285  	runClientTestForVersion(t, template, "TLSv11-", "-tls1_1")
   286  }
   287  
   288  func runClientTestTLS12(t *testing.T, template *clientTest) {
   289  	runClientTestForVersion(t, template, "TLSv12-", "-tls1_2")
   290  }
   291  
   292  func TestHandshakeClientRSARC4(t *testing.T) {
   293  	test := &clientTest{
   294  		name:    "RSA-RC4",
   295  		command: []string{"openssl", "s_server", "-cipher", "RC4-SHA"},
   296  	}
   297  	runClientTestTLS10(t, test)
   298  	runClientTestTLS11(t, test)
   299  	runClientTestTLS12(t, test)
   300  }
   301  
   302  func TestHandshakeClientRSAAES128GCM(t *testing.T) {
   303  	test := &clientTest{
   304  		name:    "AES128-GCM-SHA256",
   305  		command: []string{"openssl", "s_server", "-cipher", "AES128-GCM-SHA256"},
   306  	}
   307  	runClientTestTLS12(t, test)
   308  }
   309  
   310  func TestHandshakeClientRSAAES256GCM(t *testing.T) {
   311  	test := &clientTest{
   312  		name:    "AES256-GCM-SHA384",
   313  		command: []string{"openssl", "s_server", "-cipher", "AES256-GCM-SHA384"},
   314  	}
   315  	runClientTestTLS12(t, test)
   316  }
   317  
   318  func TestHandshakeClientECDHERSAAES(t *testing.T) {
   319  	test := &clientTest{
   320  		name:    "ECDHE-RSA-AES",
   321  		command: []string{"openssl", "s_server", "-cipher", "ECDHE-RSA-AES128-SHA"},
   322  	}
   323  	runClientTestTLS10(t, test)
   324  	runClientTestTLS11(t, test)
   325  	runClientTestTLS12(t, test)
   326  }
   327  
   328  func TestHandshakeClientECDHEECDSAAES(t *testing.T) {
   329  	test := &clientTest{
   330  		name:    "ECDHE-ECDSA-AES",
   331  		command: []string{"openssl", "s_server", "-cipher", "ECDHE-ECDSA-AES128-SHA"},
   332  		cert:    testECDSACertificate,
   333  		key:     testECDSAPrivateKey,
   334  	}
   335  	runClientTestTLS10(t, test)
   336  	runClientTestTLS11(t, test)
   337  	runClientTestTLS12(t, test)
   338  }
   339  
   340  func TestHandshakeClientECDHEECDSAAESGCM(t *testing.T) {
   341  	test := &clientTest{
   342  		name:    "ECDHE-ECDSA-AES-GCM",
   343  		command: []string{"openssl", "s_server", "-cipher", "ECDHE-ECDSA-AES128-GCM-SHA256"},
   344  		cert:    testECDSACertificate,
   345  		key:     testECDSAPrivateKey,
   346  	}
   347  	runClientTestTLS12(t, test)
   348  }
   349  
   350  func TestHandshakeClientAES256GCMSHA384(t *testing.T) {
   351  	test := &clientTest{
   352  		name:    "ECDHE-ECDSA-AES256-GCM-SHA384",
   353  		command: []string{"openssl", "s_server", "-cipher", "ECDHE-ECDSA-AES256-GCM-SHA384"},
   354  		cert:    testECDSACertificate,
   355  		key:     testECDSAPrivateKey,
   356  	}
   357  	runClientTestTLS12(t, test)
   358  }
   359  
   360  func TestHandshakeClientCertRSA(t *testing.T) {
   361  	config := *testConfig
   362  	cert, _ := X509KeyPair([]byte(clientCertificatePEM), []byte(clientKeyPEM))
   363  	config.Certificates = []Certificate{cert}
   364  
   365  	test := &clientTest{
   366  		name:    "ClientCert-RSA-RSA",
   367  		command: []string{"openssl", "s_server", "-cipher", "RC4-SHA", "-verify", "1"},
   368  		config:  &config,
   369  	}
   370  
   371  	runClientTestTLS10(t, test)
   372  	runClientTestTLS12(t, test)
   373  
   374  	test = &clientTest{
   375  		name:    "ClientCert-RSA-ECDSA",
   376  		command: []string{"openssl", "s_server", "-cipher", "ECDHE-ECDSA-AES128-SHA", "-verify", "1"},
   377  		config:  &config,
   378  		cert:    testECDSACertificate,
   379  		key:     testECDSAPrivateKey,
   380  	}
   381  
   382  	runClientTestTLS10(t, test)
   383  	runClientTestTLS12(t, test)
   384  
   385  	test = &clientTest{
   386  		name:    "ClientCert-RSA-AES256-GCM-SHA384",
   387  		command: []string{"openssl", "s_server", "-cipher", "ECDHE-RSA-AES256-GCM-SHA384", "-verify", "1"},
   388  		config:  &config,
   389  		cert:    testRSACertificate,
   390  		key:     testRSAPrivateKey,
   391  	}
   392  
   393  	runClientTestTLS12(t, test)
   394  }
   395  
   396  func TestHandshakeClientCertECDSA(t *testing.T) {
   397  	config := *testConfig
   398  	cert, _ := X509KeyPair([]byte(clientECDSACertificatePEM), []byte(clientECDSAKeyPEM))
   399  	config.Certificates = []Certificate{cert}
   400  
   401  	test := &clientTest{
   402  		name:    "ClientCert-ECDSA-RSA",
   403  		command: []string{"openssl", "s_server", "-cipher", "RC4-SHA", "-verify", "1"},
   404  		config:  &config,
   405  	}
   406  
   407  	runClientTestTLS10(t, test)
   408  	runClientTestTLS12(t, test)
   409  
   410  	test = &clientTest{
   411  		name:    "ClientCert-ECDSA-ECDSA",
   412  		command: []string{"openssl", "s_server", "-cipher", "ECDHE-ECDSA-AES128-SHA", "-verify", "1"},
   413  		config:  &config,
   414  		cert:    testECDSACertificate,
   415  		key:     testECDSAPrivateKey,
   416  	}
   417  
   418  	runClientTestTLS10(t, test)
   419  	runClientTestTLS12(t, test)
   420  }
   421  
   422  func TestClientResumption(t *testing.T) {
   423  	serverConfig := &Config{
   424  		CipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA},
   425  		Certificates: testConfig.Certificates,
   426  	}
   427  
   428  	issuer, err := x509.ParseCertificate(testRSACertificateIssuer)
   429  	if err != nil {
   430  		panic(err)
   431  	}
   432  
   433  	rootCAs := x509.NewCertPool()
   434  	rootCAs.AddCert(issuer)
   435  
   436  	clientConfig := &Config{
   437  		CipherSuites:       []uint16{TLS_RSA_WITH_RC4_128_SHA},
   438  		ClientSessionCache: NewLRUClientSessionCache(32),
   439  		RootCAs:            rootCAs,
   440  		ServerName:         "example.golang",
   441  	}
   442  
   443  	testResumeState := func(test string, didResume bool) {
   444  		_, hs, err := testHandshake(clientConfig, serverConfig)
   445  		if err != nil {
   446  			t.Fatalf("%s: handshake failed: %s", test, err)
   447  		}
   448  		if hs.DidResume != didResume {
   449  			t.Fatalf("%s resumed: %v, expected: %v", test, hs.DidResume, didResume)
   450  		}
   451  		if didResume && (hs.PeerCertificates == nil || hs.VerifiedChains == nil) {
   452  			t.Fatalf("expected non-nil certificates after resumption. Got peerCertificates: %#v, verifiedCertificates: %#v", hs.PeerCertificates, hs.VerifiedChains)
   453  		}
   454  	}
   455  
   456  	getTicket := func() []byte {
   457  		return clientConfig.ClientSessionCache.(*lruSessionCache).q.Front().Value.(*lruSessionCacheEntry).state.sessionTicket
   458  	}
   459  	randomKey := func() [32]byte {
   460  		var k [32]byte
   461  		if _, err := io.ReadFull(serverConfig.rand(), k[:]); err != nil {
   462  			t.Fatalf("Failed to read new SessionTicketKey: %s", err)
   463  		}
   464  		return k
   465  	}
   466  
   467  	testResumeState("Handshake", false)
   468  	ticket := getTicket()
   469  	testResumeState("Resume", true)
   470  	if !bytes.Equal(ticket, getTicket()) {
   471  		t.Fatal("first ticket doesn't match ticket after resumption")
   472  	}
   473  
   474  	key2 := randomKey()
   475  	serverConfig.SetSessionTicketKeys([][32]byte{key2})
   476  
   477  	testResumeState("InvalidSessionTicketKey", false)
   478  	testResumeState("ResumeAfterInvalidSessionTicketKey", true)
   479  
   480  	serverConfig.SetSessionTicketKeys([][32]byte{randomKey(), key2})
   481  	ticket = getTicket()
   482  	testResumeState("KeyChange", true)
   483  	if bytes.Equal(ticket, getTicket()) {
   484  		t.Fatal("new ticket wasn't included while resuming")
   485  	}
   486  	testResumeState("KeyChangeFinish", true)
   487  
   488  	clientConfig.CipherSuites = []uint16{TLS_ECDHE_RSA_WITH_RC4_128_SHA}
   489  	testResumeState("DifferentCipherSuite", false)
   490  	testResumeState("DifferentCipherSuiteRecovers", true)
   491  
   492  	clientConfig.ClientSessionCache = nil
   493  	testResumeState("WithoutSessionCache", false)
   494  }
   495  
   496  func TestLRUClientSessionCache(t *testing.T) {
   497  	// Initialize cache of capacity 4.
   498  	cache := NewLRUClientSessionCache(4)
   499  	cs := make([]ClientSessionState, 6)
   500  	keys := []string{"0", "1", "2", "3", "4", "5", "6"}
   501  
   502  	// Add 4 entries to the cache and look them up.
   503  	for i := 0; i < 4; i++ {
   504  		cache.Put(keys[i], &cs[i])
   505  	}
   506  	for i := 0; i < 4; i++ {
   507  		if s, ok := cache.Get(keys[i]); !ok || s != &cs[i] {
   508  			t.Fatalf("session cache failed lookup for added key: %s", keys[i])
   509  		}
   510  	}
   511  
   512  	// Add 2 more entries to the cache. First 2 should be evicted.
   513  	for i := 4; i < 6; i++ {
   514  		cache.Put(keys[i], &cs[i])
   515  	}
   516  	for i := 0; i < 2; i++ {
   517  		if s, ok := cache.Get(keys[i]); ok || s != nil {
   518  			t.Fatalf("session cache should have evicted key: %s", keys[i])
   519  		}
   520  	}
   521  
   522  	// Touch entry 2. LRU should evict 3 next.
   523  	cache.Get(keys[2])
   524  	cache.Put(keys[0], &cs[0])
   525  	if s, ok := cache.Get(keys[3]); ok || s != nil {
   526  		t.Fatalf("session cache should have evicted key 3")
   527  	}
   528  
   529  	// Update entry 0 in place.
   530  	cache.Put(keys[0], &cs[3])
   531  	if s, ok := cache.Get(keys[0]); !ok || s != &cs[3] {
   532  		t.Fatalf("session cache failed update for key 0")
   533  	}
   534  
   535  	// Adding a nil entry is valid.
   536  	cache.Put(keys[0], nil)
   537  	if s, ok := cache.Get(keys[0]); !ok || s != nil {
   538  		t.Fatalf("failed to add nil entry to cache")
   539  	}
   540  }
   541  
   542  func TestHandshakeClientALPNMatch(t *testing.T) {
   543  	config := *testConfig
   544  	config.NextProtos = []string{"proto2", "proto1"}
   545  
   546  	test := &clientTest{
   547  		name: "ALPN",
   548  		// Note that this needs OpenSSL 1.0.2 because that is the first
   549  		// version that supports the -alpn flag.
   550  		command: []string{"openssl", "s_server", "-alpn", "proto1,proto2"},
   551  		config:  &config,
   552  		validate: func(state ConnectionState) error {
   553  			// The server's preferences should override the client.
   554  			if state.NegotiatedProtocol != "proto1" {
   555  				return fmt.Errorf("Got protocol %q, wanted proto1", state.NegotiatedProtocol)
   556  			}
   557  			return nil
   558  		},
   559  	}
   560  	runClientTestTLS12(t, test)
   561  }
   562  
   563  func TestHandshakeClientALPNNoMatch(t *testing.T) {
   564  	config := *testConfig
   565  	config.NextProtos = []string{"proto3"}
   566  
   567  	test := &clientTest{
   568  		name: "ALPN-NoMatch",
   569  		// Note that this needs OpenSSL 1.0.2 because that is the first
   570  		// version that supports the -alpn flag.
   571  		command: []string{"openssl", "s_server", "-alpn", "proto1,proto2"},
   572  		config:  &config,
   573  		validate: func(state ConnectionState) error {
   574  			// There's no overlap so OpenSSL will not select a protocol.
   575  			if state.NegotiatedProtocol != "" {
   576  				return fmt.Errorf("Got protocol %q, wanted ''", state.NegotiatedProtocol)
   577  			}
   578  			return nil
   579  		},
   580  	}
   581  	runClientTestTLS12(t, test)
   582  }
   583  
// sctsBase64 contains data from `openssl s_client -serverinfo 18 -connect ritter.vg:443`.
// Decoded, it is a complete ServerHello extension in the layout described for
// clientTest.extensions:
//   - a 2-byte extension type (0x0012 = 18, signed_certificate_timestamp);
//   - a 2-byte extension data length;
//   - a 2-byte length of the SCT list;
//   - three SCTs, each preceded by its own 2-byte length.
const sctsBase64 = "ABIBaQFnAHUApLkJkLQYWBSHuxOizGdwCjw1mAT5G9+443fNDsgN3BAAAAFHl5nuFgAABAMARjBEAiAcS4JdlW5nW9sElUv2zvQyPoZ6ejKrGGB03gjaBZFMLwIgc1Qbbn+hsH0RvObzhS+XZhr3iuQQJY8S9G85D9KeGPAAdgBo9pj4H2SCvjqM7rkoHUz8cVFdZ5PURNEKZ6y7T0/7xAAAAUeX4bVwAAAEAwBHMEUCIDIhFDgG2HIuADBkGuLobU5a4dlCHoJLliWJ1SYT05z6AiEAjxIoZFFPRNWMGGIjskOTMwXzQ1Wh2e7NxXE1kd1J0QsAdgDuS723dc5guuFCaR+r4Z5mow9+X7By2IMAxHuJeqj9ywAAAUhcZIqHAAAEAwBHMEUCICmJ1rBT09LpkbzxtUC+Hi7nXLR0J+2PmwLp+sJMuqK+AiEAr0NkUnEVKVhAkccIFpYDqHOlZaBsuEhWWrYpg2RtKp0="
   586  
   587  func TestHandshakClientSCTs(t *testing.T) {
   588  	config := *testConfig
   589  
   590  	scts, err := base64.StdEncoding.DecodeString(sctsBase64)
   591  	if err != nil {
   592  		t.Fatal(err)
   593  	}
   594  
   595  	test := &clientTest{
   596  		name: "SCT",
   597  		// Note that this needs OpenSSL 1.0.2 because that is the first
   598  		// version that supports the -serverinfo flag.
   599  		command:    []string{"openssl", "s_server"},
   600  		config:     &config,
   601  		extensions: [][]byte{scts},
   602  		validate: func(state ConnectionState) error {
   603  			expectedSCTs := [][]byte{
   604  				scts[8:125],
   605  				scts[127:245],
   606  				scts[247:],
   607  			}
   608  			if n := len(state.SignedCertificateTimestamps); n != len(expectedSCTs) {
   609  				return fmt.Errorf("Got %d scts, wanted %d", n, len(expectedSCTs))
   610  			}
   611  			for i, expected := range expectedSCTs {
   612  				if sct := state.SignedCertificateTimestamps[i]; !bytes.Equal(sct, expected) {
   613  					return fmt.Errorf("SCT #%d contained %x, expected %x", i, sct, expected)
   614  				}
   615  			}
   616  			return nil
   617  		},
   618  	}
   619  	runClientTestTLS12(t, test)
   620  }
   621  
   622  var hostnameInSNITests = []struct {
   623  	in, out string
   624  }{
   625  	// Opaque string
   626  	{"", ""},
   627  	{"localhost", "localhost"},
   628  	{"foo, bar, baz and qux", "foo, bar, baz and qux"},
   629  
   630  	// DNS hostname
   631  	{"golang.org", "golang.org"},
   632  	{"golang.org.", "golang.org"},
   633  
   634  	// Literal IPv4 address
   635  	{"1.2.3.4", ""},
   636  
   637  	// Literal IPv6 address
   638  	{"::1", ""},
   639  	{"::1%lo0", ""}, // with zone identifier
   640  	{"[::1]", ""},   // as per RFC 5952 we allow the [] style as IPv6 literal
   641  	{"[::1%lo0]", ""},
   642  }
   643  
   644  func TestHostnameInSNI(t *testing.T) {
   645  	for _, tt := range hostnameInSNITests {
   646  		c, s := net.Pipe()
   647  
   648  		go func(host string) {
   649  			Client(c, &Config{ServerName: host, InsecureSkipVerify: true}).Handshake()
   650  		}(tt.in)
   651  
   652  		var header [5]byte
   653  		if _, err := io.ReadFull(s, header[:]); err != nil {
   654  			t.Fatal(err)
   655  		}
   656  		recordLen := int(header[3])<<8 | int(header[4])
   657  
   658  		record := make([]byte, recordLen)
   659  		if _, err := io.ReadFull(s, record[:]); err != nil {
   660  			t.Fatal(err)
   661  		}
   662  
   663  		c.Close()
   664  		s.Close()
   665  
   666  		var m clientHelloMsg
   667  		if !m.unmarshal(record) {
   668  			t.Errorf("unmarshaling ClientHello for %q failed", tt.in)
   669  			continue
   670  		}
   671  		if tt.in != tt.out && m.serverName == tt.in {
   672  			t.Errorf("prohibited %q found in ClientHello: %x", tt.in, record)
   673  		}
   674  		if m.serverName != tt.out {
   675  			t.Errorf("expected %q not found in ClientHello: %x", tt.out, record)
   676  		}
   677  	}
   678  }
   679  
   680  func TestServerSelectingUnconfiguredCipherSuite(t *testing.T) {
   681  	// This checks that the server can't select a cipher suite that the
   682  	// client didn't offer. See #13174.
   683  
   684  	c, s := net.Pipe()
   685  	errChan := make(chan error, 1)
   686  
   687  	go func() {
   688  		client := Client(c, &Config{
   689  			ServerName:   "foo",
   690  			CipherSuites: []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256},
   691  		})
   692  		errChan <- client.Handshake()
   693  	}()
   694  
   695  	var header [5]byte
   696  	if _, err := io.ReadFull(s, header[:]); err != nil {
   697  		t.Fatal(err)
   698  	}
   699  	recordLen := int(header[3])<<8 | int(header[4])
   700  
   701  	record := make([]byte, recordLen)
   702  	if _, err := io.ReadFull(s, record); err != nil {
   703  		t.Fatal(err)
   704  	}
   705  
   706  	// Create a ServerHello that selects a different cipher suite than the
   707  	// sole one that the client offered.
   708  	serverHello := &serverHelloMsg{
   709  		vers:        VersionTLS12,
   710  		random:      make([]byte, 32),
   711  		cipherSuite: TLS_RSA_WITH_AES_256_GCM_SHA384,
   712  	}
   713  	serverHelloBytes := serverHello.marshal()
   714  
   715  	s.Write([]byte{
   716  		byte(recordTypeHandshake),
   717  		byte(VersionTLS12 >> 8),
   718  		byte(VersionTLS12 & 0xff),
   719  		byte(len(serverHelloBytes) >> 8),
   720  		byte(len(serverHelloBytes)),
   721  	})
   722  	s.Write(serverHelloBytes)
   723  	s.Close()
   724  
   725  	if err := <-errChan; !strings.Contains(err.Error(), "unconfigured cipher") {
   726  		t.Fatalf("Expected error about unconfigured cipher suite but got %q", err)
   727  	}
   728  }
   729  
   730  // brokenConn wraps a net.Conn and causes all Writes after a certain number to
   731  // fail with brokenConnErr.
   732  type brokenConn struct {
   733  	net.Conn
   734  
   735  	// breakAfter is the number of successful writes that will be allowed
   736  	// before all subsequent writes fail.
   737  	breakAfter int
   738  
   739  	// numWrites is the number of writes that have been done.
   740  	numWrites int
   741  }
   742  
   743  // brokenConnErr is the error that brokenConn returns once exhausted.
   744  var brokenConnErr = errors.New("too many writes to brokenConn")
   745  
   746  func (b *brokenConn) Write(data []byte) (int, error) {
   747  	if b.numWrites >= b.breakAfter {
   748  		return 0, brokenConnErr
   749  	}
   750  
   751  	b.numWrites++
   752  	return b.Conn.Write(data)
   753  }
   754  
   755  func TestFailedWrite(t *testing.T) {
   756  	// Test that a write error during the handshake is returned.
   757  	for _, breakAfter := range []int{0, 1, 2, 3} {
   758  		c, s := net.Pipe()
   759  		done := make(chan bool)
   760  
   761  		go func() {
   762  			Server(s, testConfig).Handshake()
   763  			s.Close()
   764  			done <- true
   765  		}()
   766  
   767  		brokenC := &brokenConn{Conn: c, breakAfter: breakAfter}
   768  		err := Client(brokenC, testConfig).Handshake()
   769  		if err != brokenConnErr {
   770  			t.Errorf("#%d: expected error from brokenConn but got %q", breakAfter, err)
   771  		}
   772  		brokenC.Close()
   773  
   774  		<-done
   775  	}
   776  }