github.com/Psiphon-Labs/tls-tris@v0.0.0-20230824155421-58bf6d336a9a/conn_test.go (about)

     1  // Copyright 2010 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package tls
     6  
     7  import (
     8  	"bytes"
     9  	"io"
    10  	"net"
    11  	"testing"
    12  )
    13  
    14  func TestRoundUp(t *testing.T) {
    15  	if roundUp(0, 16) != 0 ||
    16  		roundUp(1, 16) != 16 ||
    17  		roundUp(15, 16) != 16 ||
    18  		roundUp(16, 16) != 16 ||
    19  		roundUp(17, 16) != 32 {
    20  		t.Error("roundUp broken")
    21  	}
    22  }
    23  
    24  // will be initialized with {0, 255, 255, ..., 255}
    25  var padding255Bad = [256]byte{}
    26  
    27  // will be initialized with {255, 255, 255, ..., 255}
    28  var padding255Good = [256]byte{255}
    29  
    30  var paddingTests = []struct {
    31  	in          []byte
    32  	good        bool
    33  	expectedLen int
    34  }{
    35  	{[]byte{1, 2, 3, 4, 0}, true, 4},
    36  	{[]byte{1, 2, 3, 4, 0, 1}, false, 0},
    37  	{[]byte{1, 2, 3, 4, 99, 99}, false, 0},
    38  	{[]byte{1, 2, 3, 4, 1, 1}, true, 4},
    39  	{[]byte{1, 2, 3, 2, 2, 2}, true, 3},
    40  	{[]byte{1, 2, 3, 3, 3, 3}, true, 2},
    41  	{[]byte{1, 2, 3, 4, 3, 3}, false, 0},
    42  	{[]byte{1, 4, 4, 4, 4, 4}, true, 1},
    43  	{[]byte{5, 5, 5, 5, 5, 5}, true, 0},
    44  	{[]byte{6, 6, 6, 6, 6, 6}, false, 0},
    45  	{padding255Bad[:], false, 0},
    46  	{padding255Good[:], true, 0},
    47  }
    48  
    49  func TestRemovePadding(t *testing.T) {
    50  	for i := 1; i < len(padding255Bad); i++ {
    51  		padding255Bad[i] = 255
    52  		padding255Good[i] = 255
    53  	}
    54  	for i, test := range paddingTests {
    55  		paddingLen, good := extractPadding(test.in)
    56  		expectedGood := byte(255)
    57  		if !test.good {
    58  			expectedGood = 0
    59  		}
    60  		if good != expectedGood {
    61  			t.Errorf("#%d: wrong validity, want:%d got:%d", i, expectedGood, good)
    62  		}
    63  		if good == 255 && len(test.in)-paddingLen != test.expectedLen {
    64  			t.Errorf("#%d: got %d, want %d", i, len(test.in)-paddingLen, test.expectedLen)
    65  		}
    66  	}
    67  }
    68  
    69  var certExampleCom = `308201713082011ba003020102021005a75ddf21014d5f417083b7a010ba2e300d06092a864886f70d01010b050030123110300e060355040a130741636d6520436f301e170d3136303831373231343135335a170d3137303831373231343135335a30123110300e060355040a130741636d6520436f305c300d06092a864886f70d0101010500034b003048024100b37f0fdd67e715bf532046ac34acbd8fdc4dabe2b598588f3f58b1f12e6219a16cbfe54d2b4b665396013589262360b6721efa27d546854f17cc9aeec6751db10203010001a34d304b300e0603551d0f0101ff0404030205a030130603551d25040c300a06082b06010505070301300c0603551d130101ff0402300030160603551d11040f300d820b6578616d706c652e636f6d300d06092a864886f70d01010b050003410059fc487866d3d855503c8e064ca32aac5e9babcece89ec597f8b2b24c17867f4a5d3b4ece06e795bfc5448ccbd2ffca1b3433171ebf3557a4737b020565350a0`
    70  
    71  var certWildcardExampleCom = `308201743082011ea003020102021100a7aa6297c9416a4633af8bec2958c607300d06092a864886f70d01010b050030123110300e060355040a130741636d6520436f301e170d3136303831373231343231395a170d3137303831373231343231395a30123110300e060355040a130741636d6520436f305c300d06092a864886f70d0101010500034b003048024100b105afc859a711ee864114e7d2d46c2dcbe392d3506249f6c2285b0eb342cc4bf2d803677c61c0abde443f084745c1a6d62080e5664ef2cc8f50ad8a0ab8870b0203010001a34f304d300e0603551d0f0101ff0404030205a030130603551d25040c300a06082b06010505070301300c0603551d130101ff0402300030180603551d110411300f820d2a2e6578616d706c652e636f6d300d06092a864886f70d01010b0500034100af26088584d266e3f6566360cf862c7fecc441484b098b107439543144a2b93f20781988281e108c6d7656934e56950e1e5f2bcf38796b814ccb729445856c34`
    72  
    73  var certFooExampleCom = `308201753082011fa00302010202101bbdb6070b0aeffc49008cde74deef29300d06092a864886f70d01010b050030123110300e060355040a130741636d6520436f301e170d3136303831373231343234345a170d3137303831373231343234345a30123110300e060355040a130741636d6520436f305c300d06092a864886f70d0101010500034b003048024100f00ac69d8ca2829f26216c7b50f1d4bbabad58d447706476cd89a2f3e1859943748aa42c15eedc93ac7c49e40d3b05ed645cb6b81c4efba60d961f44211a54eb0203010001a351304f300e0603551d0f0101ff0404030205a030130603551d25040c300a06082b06010505070301300c0603551d130101ff04023000301a0603551d1104133011820f666f6f2e6578616d706c652e636f6d300d06092a864886f70d01010b0500034100a0957fca6d1e0f1ef4b247348c7a8ca092c29c9c0ecc1898ea6b8065d23af6d922a410dd2335a0ea15edd1394cef9f62c9e876a21e35250a0b4fe1ddceba0f36`
    74  
    75  var certDoubleWildcardExampleCom = `308201753082011fa003020102021039d262d8538db8ffba30d204e02ddeb5300d06092a864886f70d01010b050030123110300e060355040a130741636d6520436f301e170d3136303831373231343331335a170d3137303831373231343331335a30123110300e060355040a130741636d6520436f305c300d06092a864886f70d0101010500034b003048024100abb6bd84b8b9be3fb9415d00f22b4ddcaec7c99855b9d818c09003e084578430e5cfd2e35faa3561f036d496aa43a9ca6e6cf23c72a763c04ae324004f6cbdbb0203010001a351304f300e0603551d0f0101ff0404030205a030130603551d25040c300a06082b06010505070301300c0603551d130101ff04023000301a0603551d1104133011820f2a2e2a2e6578616d706c652e636f6d300d06092a864886f70d01010b05000341004837521004a5b6bc7ad5d6c0dae60bb7ee0fa5e4825be35e2bb6ef07ee29396ca30ceb289431bcfd363888ba2207139933ac7c6369fa8810c819b2e2966abb4b`
    76  
    77  func TestCertificateSelection(t *testing.T) {
    78  	config := Config{
    79  		Certificates: []Certificate{
    80  			{
    81  				Certificate: [][]byte{fromHex(certExampleCom)},
    82  			},
    83  			{
    84  				Certificate: [][]byte{fromHex(certWildcardExampleCom)},
    85  			},
    86  			{
    87  				Certificate: [][]byte{fromHex(certFooExampleCom)},
    88  			},
    89  			{
    90  				Certificate: [][]byte{fromHex(certDoubleWildcardExampleCom)},
    91  			},
    92  		},
    93  	}
    94  
    95  	config.BuildNameToCertificate()
    96  
    97  	pointerToIndex := func(c *Certificate) int {
    98  		for i := range config.Certificates {
    99  			if c == &config.Certificates[i] {
   100  				return i
   101  			}
   102  		}
   103  		return -1
   104  	}
   105  
   106  	certificateForName := func(name string) *Certificate {
   107  		clientHello := &ClientHelloInfo{
   108  			ServerName: name,
   109  		}
   110  		if cert, err := config.getCertificate(clientHello); err != nil {
   111  			t.Errorf("unable to get certificate for name '%s': %s", name, err)
   112  			return nil
   113  		} else {
   114  			return cert
   115  		}
   116  	}
   117  
   118  	if n := pointerToIndex(certificateForName("example.com")); n != 0 {
   119  		t.Errorf("example.com returned certificate %d, not 0", n)
   120  	}
   121  	if n := pointerToIndex(certificateForName("bar.example.com")); n != 1 {
   122  		t.Errorf("bar.example.com returned certificate %d, not 1", n)
   123  	}
   124  	if n := pointerToIndex(certificateForName("foo.example.com")); n != 2 {
   125  		t.Errorf("foo.example.com returned certificate %d, not 2", n)
   126  	}
   127  	if n := pointerToIndex(certificateForName("foo.bar.example.com")); n != 3 {
   128  		t.Errorf("foo.bar.example.com returned certificate %d, not 3", n)
   129  	}
   130  	if n := pointerToIndex(certificateForName("foo.bar.baz.example.com")); n != 0 {
   131  		t.Errorf("foo.bar.baz.example.com returned certificate %d, not 0", n)
   132  	}
   133  }
   134  
   135  // Run with multiple crypto configs to test the logic for computing TLS record overheads.
   136  func runDynamicRecordSizingTest(t *testing.T, config *Config) {
   137  	clientConn, serverConn := net.Pipe()
   138  
   139  	serverConfig := config.Clone()
   140  	serverConfig.DynamicRecordSizingDisabled = false
   141  	tlsConn := Server(serverConn, serverConfig)
   142  
   143  	recordSizesChan := make(chan []int, 1)
   144  	go func() {
   145  		// This goroutine performs a TLS handshake over clientConn and
   146  		// then reads TLS records until EOF. It writes a slice that
   147  		// contains all the record sizes to recordSizesChan.
   148  		defer close(recordSizesChan)
   149  		defer clientConn.Close()
   150  
   151  		tlsConn := Client(clientConn, config)
   152  		if err := tlsConn.Handshake(); err != nil {
   153  			t.Errorf("Error from client handshake: %v", err)
   154  			return
   155  		}
   156  
   157  		var recordHeader [recordHeaderLen]byte
   158  		var record []byte
   159  		var recordSizes []int
   160  
   161  		for {
   162  			n, err := io.ReadFull(clientConn, recordHeader[:])
   163  			if err == io.EOF {
   164  				break
   165  			}
   166  			if err != nil || n != len(recordHeader) {
   167  				t.Errorf("io.ReadFull = %d, %v", n, err)
   168  				return
   169  			}
   170  
   171  			length := int(recordHeader[3])<<8 | int(recordHeader[4])
   172  			if len(record) < length {
   173  				record = make([]byte, length)
   174  			}
   175  
   176  			n, err = io.ReadFull(clientConn, record[:length])
   177  			if err != nil || n != length {
   178  				t.Errorf("io.ReadFull = %d, %v", n, err)
   179  				return
   180  			}
   181  
   182  			// The last record will be a close_notify alert, which
   183  			// we don't wish to record.
   184  			if recordType(recordHeader[0]) == recordTypeApplicationData {
   185  				recordSizes = append(recordSizes, recordHeaderLen+length)
   186  			}
   187  		}
   188  
   189  		recordSizesChan <- recordSizes
   190  	}()
   191  
   192  	if err := tlsConn.Handshake(); err != nil {
   193  		t.Fatalf("Error from server handshake: %s", err)
   194  	}
   195  
   196  	// The server writes these plaintexts in order.
   197  	plaintext := bytes.Join([][]byte{
   198  		bytes.Repeat([]byte("x"), recordSizeBoostThreshold),
   199  		bytes.Repeat([]byte("y"), maxPlaintext*2),
   200  		bytes.Repeat([]byte("z"), maxPlaintext),
   201  	}, nil)
   202  
   203  	if _, err := tlsConn.Write(plaintext); err != nil {
   204  		t.Fatalf("Error from server write: %s", err)
   205  	}
   206  	if err := tlsConn.Close(); err != nil {
   207  		t.Fatalf("Error from server close: %s", err)
   208  	}
   209  
   210  	recordSizes := <-recordSizesChan
   211  	if recordSizes == nil {
   212  		t.Fatalf("Client encountered an error")
   213  	}
   214  
   215  	// Drop the size of last record, which is likely to be truncated.
   216  	recordSizes = recordSizes[:len(recordSizes)-1]
   217  
   218  	// recordSizes should contain a series of records smaller than
   219  	// tcpMSSEstimate followed by some larger than maxPlaintext.
   220  	seenLargeRecord := false
   221  	for i, size := range recordSizes {
   222  		if !seenLargeRecord {
   223  			if size > (i+1)*tcpMSSEstimate {
   224  				t.Fatalf("Record #%d has size %d, which is too large too soon", i, size)
   225  			}
   226  			if size >= maxPlaintext {
   227  				seenLargeRecord = true
   228  			}
   229  		} else if size <= maxPlaintext {
   230  			t.Fatalf("Record #%d has size %d but should be full sized", i, size)
   231  		}
   232  	}
   233  
   234  	if !seenLargeRecord {
   235  		t.Fatalf("No large records observed")
   236  	}
   237  }
   238  
   239  func TestDynamicRecordSizingWithStreamCipher(t *testing.T) {
   240  	config := testConfig.Clone()
   241  	config.CipherSuites = []uint16{TLS_RSA_WITH_RC4_128_SHA}
   242  	runDynamicRecordSizingTest(t, config)
   243  }
   244  
   245  func TestDynamicRecordSizingWithCBC(t *testing.T) {
   246  	config := testConfig.Clone()
   247  	config.CipherSuites = []uint16{TLS_RSA_WITH_AES_256_CBC_SHA}
   248  	runDynamicRecordSizingTest(t, config)
   249  }
   250  
   251  func TestDynamicRecordSizingWithAEAD(t *testing.T) {
   252  	config := testConfig.Clone()
   253  	config.CipherSuites = []uint16{TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256}
   254  	runDynamicRecordSizingTest(t, config)
   255  }
   256  
   257  // hairpinConn is a net.Conn that makes a “hairpin” call when closed, back into
   258  // the tls.Conn which is calling it.
   259  type hairpinConn struct {
   260  	net.Conn
   261  	tlsConn *Conn
   262  }
   263  
   264  func (conn *hairpinConn) Close() error {
   265  	conn.tlsConn.ConnectionState()
   266  	return nil
   267  }
   268  
   269  func TestHairpinInClose(t *testing.T) {
   270  	// This tests that the underlying net.Conn can call back into the
   271  	// tls.Conn when being closed without deadlocking.
   272  	client, server := net.Pipe()
   273  	defer server.Close()
   274  	defer client.Close()
   275  
   276  	conn := &hairpinConn{client, nil}
   277  	tlsConn := Server(conn, &Config{
   278  		GetCertificate: func(*ClientHelloInfo) (*Certificate, error) {
   279  			panic("unreachable")
   280  		},
   281  	})
   282  	conn.tlsConn = tlsConn
   283  
   284  	// This call should not deadlock.
   285  	tlsConn.Close()
   286  }