github.com/zmap/zcrypto@v0.0.0-20240512203510-0fef58d9a9db/tls/handshake_client_test.go (about)

     1  // Copyright 2010 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package tls
     6  
     7  import (
     8  	"bytes"
     9  	"crypto/ecdsa"
    10  	"crypto/rsa"
    11  	"encoding/pem"
    12  	"fmt"
    13  	"io"
    14  	"net"
    15  	"os"
    16  	"os/exec"
    17  	"path/filepath"
    18  	"strconv"
    19  	"testing"
    20  	"time"
    21  
    22  	"github.com/zmap/zcrypto/x509"
    23  )
    24  
    25  // Note: see comment in handshake_test.go for details of how the reference
    26  // tests work.
    27  
    28  // blockingSource is an io.Reader that blocks a Read call until it's closed.
    29  type blockingSource chan bool
    30  
    31  func (b blockingSource) Read([]byte) (n int, err error) {
    32  	<-b
    33  	return 0, io.EOF
    34  }
    35  
// clientTest represents a test of the TLS client handshake against a reference
// implementation. When run with write == false the recorded flows are loaded
// from dataPath and replayed; with write == true a live reference server is
// launched via connFromCommand and the handshake flows are recorded.
type clientTest struct {
	// name is a freeform string identifying the test and the file in which
	// the expected results will be stored (testdata/Client-<name>).
	name string
	// command, if not empty, contains a series of arguments for the
	// command to run for the reference server. When empty,
	// connFromCommand falls back to defaultServerCommand.
	command []string
	// config, if not nil, contains a custom Config to use for this test;
	// otherwise run uses testConfig.
	config *Config
	// cert, if not empty, contains a DER-encoded certificate for the
	// reference server (passed with -certform DER). Defaults to
	// testRSACertificate.
	cert []byte
	// key, if not nil, contains either a *rsa.PrivateKey or
	// *ecdsa.PrivateKey which is the private key for the reference server.
	// Defaults to testRSAPrivateKey; any other type makes connFromCommand panic.
	key interface{}
}

// defaultServerCommand is the reference server invoked by connFromCommand when
// a test supplies no command of its own. Certificate, key and -accept arguments
// are appended to it at run time.
var defaultServerCommand = []string{"openssl", "s_server"}
    56  
    57  // connFromCommand starts the reference server process, connects to it and
    58  // returns a recordingConn for the connection. The stdin return value is a
    59  // blockingSource for the stdin of the child process. It must be closed before
    60  // Waiting for child.
    61  func (test *clientTest) connFromCommand() (conn *recordingConn, child *exec.Cmd, stdin blockingSource, err error) {
    62  	cert := testRSACertificate
    63  	if len(test.cert) > 0 {
    64  		cert = test.cert
    65  	}
    66  	certPath := tempFile(string(cert))
    67  	defer os.Remove(certPath)
    68  
    69  	var key interface{} = testRSAPrivateKey
    70  	if test.key != nil {
    71  		key = test.key
    72  	}
    73  	var pemType string
    74  	var derBytes []byte
    75  	switch key := key.(type) {
    76  	case *rsa.PrivateKey:
    77  		pemType = "RSA"
    78  		derBytes = x509.MarshalPKCS1PrivateKey(key)
    79  	case *ecdsa.PrivateKey:
    80  		pemType = "EC"
    81  		var err error
    82  		derBytes, err = x509.MarshalECPrivateKey(key)
    83  		if err != nil {
    84  			panic(err)
    85  		}
    86  	default:
    87  		panic("unknown key type")
    88  	}
    89  
    90  	var pemOut bytes.Buffer
    91  	pem.Encode(&pemOut, &pem.Block{Type: pemType + " PRIVATE KEY", Bytes: derBytes})
    92  
    93  	keyPath := tempFile(string(pemOut.Bytes()))
    94  	defer os.Remove(keyPath)
    95  
    96  	var command []string
    97  	if len(test.command) > 0 {
    98  		command = append(command, test.command...)
    99  	} else {
   100  		command = append(command, defaultServerCommand...)
   101  	}
   102  	command = append(command, "-cert", certPath, "-certform", "DER", "-key", keyPath)
   103  	// serverPort contains the port that OpenSSL will listen on. OpenSSL
   104  	// can't take "0" as an argument here so we have to pick a number and
   105  	// hope that it's not in use on the machine. Since this only occurs
   106  	// when -update is given and thus when there's a human watching the
   107  	// test, this isn't too bad.
   108  	const serverPort = 24323
   109  	command = append(command, "-accept", strconv.Itoa(serverPort))
   110  
   111  	cmd := exec.Command(command[0], command[1:]...)
   112  	stdin = blockingSource(make(chan bool))
   113  	cmd.Stdin = stdin
   114  	var out bytes.Buffer
   115  	cmd.Stdout = &out
   116  	cmd.Stderr = &out
   117  	if err := cmd.Start(); err != nil {
   118  		return nil, nil, nil, err
   119  	}
   120  
   121  	// OpenSSL does print an "ACCEPT" banner, but it does so *before*
   122  	// opening the listening socket, so we can't use that to wait until it
   123  	// has started listening. Thus we are forced to poll until we get a
   124  	// connection.
   125  	var tcpConn net.Conn
   126  	for i := uint(0); i < 5; i++ {
   127  		var err error
   128  		tcpConn, err = net.DialTCP("tcp", nil, &net.TCPAddr{
   129  			IP:   net.IPv4(127, 0, 0, 1),
   130  			Port: serverPort,
   131  		})
   132  		if err == nil {
   133  			break
   134  		}
   135  		time.Sleep((1 << i) * 5 * time.Millisecond)
   136  	}
   137  	if tcpConn == nil {
   138  		close(stdin)
   139  		out.WriteTo(os.Stdout)
   140  		cmd.Process.Kill()
   141  		return nil, nil, nil, cmd.Wait()
   142  	}
   143  
   144  	record := &recordingConn{
   145  		Conn: tcpConn,
   146  	}
   147  
   148  	return record, cmd, stdin, nil
   149  }
   150  
   151  func (test *clientTest) dataPath() string {
   152  	return filepath.Join("testdata", "Client-"+test.name)
   153  }
   154  
   155  func (test *clientTest) loadData() (flows [][]byte, err error) {
   156  	in, err := os.Open(test.dataPath())
   157  	if err != nil {
   158  		return nil, err
   159  	}
   160  	defer in.Close()
   161  	return parseTestData(in)
   162  }
   163  
   164  func (test *clientTest) run(t *testing.T, write bool) {
   165  	var clientConn, serverConn net.Conn
   166  	var recordingConn *recordingConn
   167  	var childProcess *exec.Cmd
   168  	var stdin blockingSource
   169  
   170  	if write {
   171  		var err error
   172  		recordingConn, childProcess, stdin, err = test.connFromCommand()
   173  		if err != nil {
   174  			t.Fatalf("Failed to start subcommand: %s", err)
   175  		}
   176  		clientConn = recordingConn
   177  	} else {
   178  		clientConn, serverConn = net.Pipe()
   179  	}
   180  
   181  	config := test.config
   182  	if config == nil {
   183  		config = testConfig
   184  	}
   185  	client := Client(clientConn, config)
   186  
   187  	doneChan := make(chan bool)
   188  	go func() {
   189  		if _, err := client.Write([]byte("hello\n")); err != nil {
   190  			t.Logf("Client.Write failed: %s", err)
   191  		}
   192  		client.Close()
   193  		clientConn.Close()
   194  		doneChan <- true
   195  	}()
   196  
   197  	if !write {
   198  		flows, err := test.loadData()
   199  		if err != nil {
   200  			t.Fatalf("%s: failed to load data from %s", test.name, test.dataPath())
   201  		}
   202  		for i, b := range flows {
   203  			if i%2 == 1 {
   204  				serverConn.Write(b)
   205  				continue
   206  			}
   207  			bb := make([]byte, len(b))
   208  			_, err := io.ReadFull(serverConn, bb)
   209  			if err != nil {
   210  				t.Fatalf("%s #%d: %s", test.name, i, err)
   211  			}
   212  			if !bytes.Equal(b, bb) {
   213  				t.Fatalf("%s #%d: mismatch on read: got:%x want:%x", test.name, i, bb, b)
   214  			}
   215  		}
   216  		serverConn.Close()
   217  	}
   218  
   219  	<-doneChan
   220  
   221  	if write {
   222  		path := test.dataPath()
   223  		out, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
   224  		if err != nil {
   225  			t.Fatalf("Failed to create output file: %s", err)
   226  		}
   227  		defer out.Close()
   228  		recordingConn.Close()
   229  		close(stdin)
   230  		childProcess.Process.Kill()
   231  		childProcess.Wait()
   232  		if len(recordingConn.flows) < 3 {
   233  			childProcess.Stdout.(*bytes.Buffer).WriteTo(os.Stdout)
   234  			t.Fatalf("Client connection didn't work")
   235  		}
   236  		recordingConn.WriteTo(out)
   237  		fmt.Printf("Wrote %s\n", path)
   238  	}
   239  }
   240  
   241  func runClientTestForVersion(t *testing.T, template *clientTest, prefix, option string) {
   242  	test := *template
   243  	test.name = prefix + test.name
   244  	if len(test.command) == 0 {
   245  		test.command = defaultClientCommand
   246  	}
   247  	test.command = append([]string(nil), test.command...)
   248  	test.command = append(test.command, option)
   249  	test.run(t, *update)
   250  }
   251  
   252  func runClientTestTLS10(t *testing.T, template *clientTest) {
   253  	runClientTestForVersion(t, template, "TLSv10-", "-tls1")
   254  }
   255  
   256  func runClientTestTLS11(t *testing.T, template *clientTest) {
   257  	runClientTestForVersion(t, template, "TLSv11-", "-tls1_1")
   258  }
   259  
   260  func runClientTestTLS12(t *testing.T, template *clientTest) {
   261  	runClientTestForVersion(t, template, "TLSv12-", "-tls1_2")
   262  }
   263  
   264  //func TestHandshakeClientRSARC4(t *testing.T) {
   265  //	test := &clientTest{
   266  //		name:    "RSA-RC4",
   267  //		command: []string{"openssl", "s_server", "-cipher", "RC4-SHA"},
   268  //	}
   269  //	runClientTestTLS10(t, test)
   270  //	runClientTestTLS11(t, test)
   271  //	runClientTestTLS12(t, test)
   272  //}
   273  //
   274  //func TestHandshakeClientECDHERSAAES(t *testing.T) {
   275  //	test := &clientTest{
   276  //		name:    "ECDHE-RSA-AES",
   277  //		command: []string{"openssl", "s_server", "-cipher", "ECDHE-RSA-AES128-SHA"},
   278  //	}
   279  //	runClientTestTLS10(t, test)
   280  //	runClientTestTLS11(t, test)
   281  //	runClientTestTLS12(t, test)
   282  //}
   283  //
   284  //func TestHandshakeClientECDHEECDSAAES(t *testing.T) {
   285  //	test := &clientTest{
   286  //		name:    "ECDHE-ECDSA-AES",
   287  //		command: []string{"openssl", "s_server", "-cipher", "ECDHE-ECDSA-AES128-SHA"},
   288  //		cert:    testECDSACertificate,
   289  //		key:     testECDSAPrivateKey,
   290  //	}
   291  //	runClientTestTLS10(t, test)
   292  //	runClientTestTLS11(t, test)
   293  //	runClientTestTLS12(t, test)
   294  //}
   295  //
   296  //func TestHandshakeClientECDHEECDSAAESGCM(t *testing.T) {
   297  //	test := &clientTest{
   298  //		name:    "ECDHE-ECDSA-AES-GCM",
   299  //		command: []string{"openssl", "s_server", "-cipher", "ECDHE-ECDSA-AES128-GCM-SHA256"},
   300  //		cert:    testECDSACertificate,
   301  //		key:     testECDSAPrivateKey,
   302  //	}
   303  //	runClientTestTLS12(t, test)
   304  //}
   305  //
   306  //func TestHandshakeClientCertRSA(t *testing.T) {
   307  //	config := *testConfig
   308  //	cert, _ := X509KeyPair([]byte(clientCertificatePEM), []byte(clientKeyPEM))
   309  //	config.Certificates = []Certificate{cert}
   310  //
   311  //	test := &clientTest{
   312  //		name:    "ClientCert-RSA-RSA",
   313  //		command: []string{"openssl", "s_server", "-cipher", "RC4-SHA", "-verify", "1"},
   314  //		config:  &config,
   315  //	}
   316  //
   317  //	runClientTestTLS10(t, test)
   318  //	runClientTestTLS12(t, test)
   319  //
   320  //	test = &clientTest{
   321  //		name:    "ClientCert-RSA-ECDSA",
   322  //		command: []string{"openssl", "s_server", "-cipher", "ECDHE-ECDSA-AES128-SHA", "-verify", "1"},
   323  //		config:  &config,
   324  //		cert:    testECDSACertificate,
   325  //		key:     testECDSAPrivateKey,
   326  //	}
   327  //
   328  //	runClientTestTLS10(t, test)
   329  //	runClientTestTLS12(t, test)
   330  //}
   331  
   332  // TODO: figure out why this test is failing
   333  //func TestHandshakeClientCertECDSA(t *testing.T) {
   334  //	config := *testConfig
   335  //	cert, _ := X509KeyPair([]byte(clientECDSACertificatePEM), []byte(clientECDSAKeyPEM))
   336  //	config.Certificates = []Certificate{cert}
   337  //
   338  //	test := &clientTest{
   339  //		name:    "ClientCert-ECDSA-RSA",
   340  //		command: []string{"openssl", "s_server", "-cipher", "RC4-SHA", "-verify", "1"},
   341  //		config:  &config,
   342  //	}
   343  //
   344  //	runClientTestTLS10(t, test)
   345  //	runClientTestTLS12(t, test)
   346  //
   347  //	test = &clientTest{
   348  //		name:    "ClientCert-ECDSA-ECDSA",
   349  //		command: []string{"openssl", "s_server", "-cipher", "ECDHE-ECDSA-AES128-SHA", "-verify", "1"},
   350  //		config:  &config,
   351  //		cert:    testECDSACertificate,
   352  //		key:     testECDSAPrivateKey,
   353  //	}
   354  //
   355  //	runClientTestTLS10(t, test)
   356  //	runClientTestTLS12(t, test)
   357  //}
   358  
   359  func TestClientResumption(t *testing.T) {
   360  	serverConfig := &Config{
   361  		CipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA},
   362  		Certificates: testConfig.Certificates,
   363  	}
   364  	clientConfig := &Config{
   365  		CipherSuites:       []uint16{TLS_RSA_WITH_RC4_128_SHA},
   366  		InsecureSkipVerify: true,
   367  		ClientSessionCache: NewLRUClientSessionCache(32),
   368  	}
   369  
   370  	testResumeState := func(test string, didResume bool) {
   371  		hs, err := testHandshake(clientConfig, serverConfig)
   372  		if err != nil {
   373  			t.Fatalf("%s: handshake failed: %s", test, err)
   374  		}
   375  		if hs.DidResume != didResume {
   376  			t.Fatalf("%s resumed: %v, expected: %v", test, hs.DidResume, didResume)
   377  		}
   378  	}
   379  
   380  	testResumeState("Handshake", false)
   381  	testResumeState("Resume", true)
   382  
   383  	if _, err := io.ReadFull(serverConfig.rand(), serverConfig.SessionTicketKey[:]); err != nil {
   384  		t.Fatalf("Failed to invalidate SessionTicketKey")
   385  	}
   386  	testResumeState("InvalidSessionTicketKey", false)
   387  	testResumeState("ResumeAfterInvalidSessionTicketKey", true)
   388  
   389  	clientConfig.CipherSuites = []uint16{TLS_ECDHE_RSA_WITH_RC4_128_SHA}
   390  	testResumeState("DifferentCipherSuite", false)
   391  	testResumeState("DifferentCipherSuiteRecovers", true)
   392  
   393  	clientConfig.ClientSessionCache = nil
   394  	testResumeState("WithoutSessionCache", false)
   395  }
   396  
   397  func TestLRUClientSessionCache(t *testing.T) {
   398  	// Initialize cache of capacity 4.
   399  	cache := NewLRUClientSessionCache(4)
   400  	cs := make([]ClientSessionState, 6)
   401  	keys := []string{"0", "1", "2", "3", "4", "5", "6"}
   402  
   403  	// Add 4 entries to the cache and look them up.
   404  	for i := 0; i < 4; i++ {
   405  		cache.Put(keys[i], &cs[i])
   406  	}
   407  	for i := 0; i < 4; i++ {
   408  		if s, ok := cache.Get(keys[i]); !ok || s != &cs[i] {
   409  			t.Fatalf("session cache failed lookup for added key: %s", keys[i])
   410  		}
   411  	}
   412  
   413  	// Add 2 more entries to the cache. First 2 should be evicted.
   414  	for i := 4; i < 6; i++ {
   415  		cache.Put(keys[i], &cs[i])
   416  	}
   417  	for i := 0; i < 2; i++ {
   418  		if s, ok := cache.Get(keys[i]); ok || s != nil {
   419  			t.Fatalf("session cache should have evicted key: %s", keys[i])
   420  		}
   421  	}
   422  
   423  	// Touch entry 2. LRU should evict 3 next.
   424  	cache.Get(keys[2])
   425  	cache.Put(keys[0], &cs[0])
   426  	if s, ok := cache.Get(keys[3]); ok || s != nil {
   427  		t.Fatalf("session cache should have evicted key 3")
   428  	}
   429  
   430  	// Update entry 0 in place.
   431  	cache.Put(keys[0], &cs[3])
   432  	if s, ok := cache.Get(keys[0]); !ok || s != &cs[3] {
   433  		t.Fatalf("session cache failed update for key 0")
   434  	}
   435  
   436  	// Adding a nil entry is valid.
   437  	cache.Put(keys[0], nil)
   438  	if s, ok := cache.Get(keys[0]); !ok || s != nil {
   439  		t.Fatalf("failed to add nil entry to cache")
   440  	}
   441  }
   442  
   443  // Test the custom client hello feature by imitating a Firefox ClientHello message
   444  func TestHandshakeClientCustomHello(t *testing.T) {
   445  	hello := ClientFingerprintConfiguration{}
   446  	hello.HandshakeVersion = 0x0303
   447  
   448  	hello.CipherSuites = []uint16{
   449  		TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
   450  		TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
   451  		TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256,
   452  		TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
   453  		TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
   454  		TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
   455  		TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
   456  		TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
   457  		TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
   458  		TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
   459  		TLS_DHE_RSA_WITH_AES_128_CBC_SHA,
   460  		TLS_DHE_RSA_WITH_AES_256_CBC_SHA,
   461  		TLS_RSA_WITH_AES_128_CBC_SHA,
   462  		TLS_RSA_WITH_AES_256_CBC_SHA,
   463  		TLS_RSA_WITH_3DES_EDE_CBC_SHA,
   464  	}
   465  	hello.CompressionMethods = []uint8{0}
   466  	sni := SNIExtension{[]string{}, true}
   467  	ec := SupportedCurvesExtension{[]CurveID{CurveP256, CurveP384, CurveP521}}
   468  	points := PointFormatExtension{[]uint8{0}}
   469  	st := SessionTicketExtension{[]byte{}, true}
   470  	alpn := ALPNExtension{[]string{"h2", "http/1.1"}}
   471  	sigs := SignatureAlgorithmExtension{[]uint16{0x0401,
   472  		0x0501,
   473  		0x0601,
   474  		0x0201,
   475  		0x0403,
   476  		0x0503,
   477  		0x0603,
   478  		0x0203,
   479  		0x0502,
   480  		0x0402,
   481  		0x0202,
   482  	}}
   483  
   484  	hello.Extensions = []ClientExtension{&sni,
   485  		&ExtendedMasterSecretExtension{},
   486  		&SecureRenegotiationExtension{},
   487  		&ec,
   488  		&points,
   489  		&st,
   490  		&NextProtocolNegotiationExtension{},
   491  		&alpn,
   492  		&StatusRequestExtension{},
   493  		&sigs,
   494  	}
   495  	config := *testConfig
   496  	config.ClientFingerprintConfiguration = &hello
   497  	test := &clientTest{
   498  		name:    "ClientFingerprint",
   499  		command: []string{"openssl", "s_server"},
   500  		config:  &config,
   501  	}
   502  	runClientTestTLS12(t, test)
   503  }
   504  
   505  // writeCountingConn wraps a net.Conn and counts the number of Write calls.
   506  type writeCountingConn struct {
   507  	net.Conn
   508  
   509  	// numWrites is the number of writes that have been done.
   510  	numWrites int
   511  }
   512  
   513  func (wcc *writeCountingConn) Write(data []byte) (int, error) {
   514  	wcc.numWrites++
   515  	return wcc.Conn.Write(data)
   516  }
   517  
   518  func TestBuffering(t *testing.T) {
   519  	c, s := net.Pipe()
   520  	done := make(chan bool)
   521  
   522  	clientWCC := &writeCountingConn{Conn: c}
   523  	serverWCC := &writeCountingConn{Conn: s}
   524  
   525  	go func() {
   526  		Server(serverWCC, testConfig).Handshake()
   527  		serverWCC.Close()
   528  		done <- true
   529  	}()
   530  
   531  	err := Client(clientWCC, testConfig).Handshake()
   532  	if err != nil {
   533  		t.Fatal(err)
   534  	}
   535  	clientWCC.Close()
   536  	<-done
   537  
   538  	if n := clientWCC.numWrites; n != 2 {
   539  		t.Errorf("expected client handshake to complete with only two writes, but saw %d", n)
   540  	}
   541  
   542  	if n := serverWCC.numWrites; n != 2 {
   543  		t.Errorf("expected server handshake to complete with only two writes, but saw %d", n)
   544  	}
   545  }
   546  
   547  func TestDontBuffer(t *testing.T) {
   548  	c, s := net.Pipe()
   549  	done := make(chan bool)
   550  
   551  	clientWCC := &writeCountingConn{Conn: c}
   552  	serverWCC := &writeCountingConn{Conn: s}
   553  	testConfig.DontBufferHandshakes = true
   554  	defer func() {
   555  		testConfig.DontBufferHandshakes = false
   556  	}()
   557  	go func() {
   558  		Server(serverWCC, testConfig).Handshake()
   559  		serverWCC.Close()
   560  		done <- true
   561  	}()
   562  
   563  	err := Client(clientWCC, testConfig).Handshake()
   564  	if err != nil {
   565  		t.Fatal(err)
   566  	}
   567  	clientWCC.Close()
   568  	<-done
   569  
   570  	if n := clientWCC.numWrites; n != 4 {
   571  		t.Errorf("expected client handshake to complete with only two writes, but saw %d", n)
   572  	}
   573  
   574  	if n := serverWCC.numWrites; n != 6 {
   575  		t.Errorf("expected server handshake to complete with only two writes, but saw %d", n)
   576  	}
   577  }