github.com/Psiphon-Labs/tls-tris@v0.0.0-20230824155421-58bf6d336a9a/handshake_test.go (about)

     1  // Copyright 2013 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package tls
     6  
     7  import (
     8  	"bufio"
     9  	"encoding/hex"
    10  	"errors"
    11  	"flag"
    12  	"fmt"
    13  	"io"
    14  	"io/ioutil"
    15  	"net"
    16  	"os/exec"
    17  	"strconv"
    18  	"strings"
    19  	"sync"
    20  	"testing"
    21  )
    22  
    23  // TLS reference tests run a connection against a reference implementation
    24  // (OpenSSL) of TLS and record the bytes of the resulting connection. The Go
    25  // code, during a test, is configured with deterministic randomness and so the
    26  // reference test can be reproduced exactly in the future.
    27  //
    28  // In order to save everyone who wishes to run the tests from needing the
    29  // reference implementation installed, the reference connections are saved in
    30  // files in the testdata directory. Thus running the tests involves nothing
    31  // external, but creating and updating them requires the reference
    32  // implementation.
    33  //
    34  // Tests can be updated by running them with the -update flag. This will cause
    35  // the test files to be regenerated. Generally one should combine the -update
    36  // flag with -test.run to update a specific test. Since the reference
    37  // implementation will always generate fresh random numbers, large parts of
    38  // the reference connection will always change.
    39  
    40  var (
    41  	update = flag.Bool("update", false, "update golden files on disk")
    42  
    43  	opensslVersionTestOnce sync.Once
    44  	opensslVersionTestErr  error
    45  )
    46  
    47  func checkOpenSSLVersion(t *testing.T) {
    48  	opensslVersionTestOnce.Do(testOpenSSLVersion)
    49  	if opensslVersionTestErr != nil {
    50  		t.Fatal(opensslVersionTestErr)
    51  	}
    52  }
    53  
    54  func testOpenSSLVersion() {
    55  	// This test ensures that the version of OpenSSL looks reasonable
    56  	// before updating the test data.
    57  
    58  	if !*update {
    59  		return
    60  	}
    61  
    62  	openssl := exec.Command("openssl", "version")
    63  	output, err := openssl.CombinedOutput()
    64  	if err != nil {
    65  		opensslVersionTestErr = err
    66  		return
    67  	}
    68  
    69  	version := string(output)
    70  	if strings.HasPrefix(version, "OpenSSL 1.1.0") {
    71  		return
    72  	}
    73  
    74  	println("***********************************************")
    75  	println("")
    76  	println("You need to build OpenSSL 1.1.0 from source in order")
    77  	println("to update the test data.")
    78  	println("")
    79  	println("Configure it with:")
    80  	println("./Configure enable-weak-ssl-ciphers enable-ssl3 enable-ssl3-method -static linux-x86_64")
    81  	println("and then add the apps/ directory at the front of your PATH.")
    82  	println("***********************************************")
    83  
    84  	opensslVersionTestErr = errors.New("version of OpenSSL does not appear to be suitable for updating test data")
    85  }
    86  
    87  // recordingConn is a net.Conn that records the traffic that passes through it.
    88  // WriteTo can be used to produce output that can be later be loaded with
    89  // ParseTestData.
    90  type recordingConn struct {
    91  	net.Conn
    92  	sync.Mutex
    93  	flows   [][]byte
    94  	reading bool
    95  }
    96  
    97  func (r *recordingConn) Read(b []byte) (n int, err error) {
    98  	if n, err = r.Conn.Read(b); n == 0 {
    99  		return
   100  	}
   101  	b = b[:n]
   102  
   103  	r.Lock()
   104  	defer r.Unlock()
   105  
   106  	if l := len(r.flows); l == 0 || !r.reading {
   107  		buf := make([]byte, len(b))
   108  		copy(buf, b)
   109  		r.flows = append(r.flows, buf)
   110  	} else {
   111  		r.flows[l-1] = append(r.flows[l-1], b[:n]...)
   112  	}
   113  	r.reading = true
   114  	return
   115  }
   116  
   117  func (r *recordingConn) Write(b []byte) (n int, err error) {
   118  	if n, err = r.Conn.Write(b); n == 0 {
   119  		return
   120  	}
   121  	b = b[:n]
   122  
   123  	r.Lock()
   124  	defer r.Unlock()
   125  
   126  	if l := len(r.flows); l == 0 || r.reading {
   127  		buf := make([]byte, len(b))
   128  		copy(buf, b)
   129  		r.flows = append(r.flows, buf)
   130  	} else {
   131  		r.flows[l-1] = append(r.flows[l-1], b[:n]...)
   132  	}
   133  	r.reading = false
   134  	return
   135  }
   136  
   137  // WriteTo writes Go source code to w that contains the recorded traffic.
   138  func (r *recordingConn) WriteTo(w io.Writer) (int64, error) {
   139  	// TLS always starts with a client to server flow.
   140  	clientToServer := true
   141  	var written int64
   142  	for i, flow := range r.flows {
   143  		source, dest := "client", "server"
   144  		if !clientToServer {
   145  			source, dest = dest, source
   146  		}
   147  		n, err := fmt.Fprintf(w, ">>> Flow %d (%s to %s)\n", i+1, source, dest)
   148  		written += int64(n)
   149  		if err != nil {
   150  			return written, err
   151  		}
   152  		dumper := hex.Dumper(w)
   153  		n, err = dumper.Write(flow)
   154  		written += int64(n)
   155  		if err != nil {
   156  			return written, err
   157  		}
   158  		err = dumper.Close()
   159  		if err != nil {
   160  			return written, err
   161  		}
   162  		clientToServer = !clientToServer
   163  	}
   164  	return written, nil
   165  }
   166  
   167  func parseTestData(r io.Reader) (flows [][]byte, err error) {
   168  	var currentFlow []byte
   169  
   170  	scanner := bufio.NewScanner(r)
   171  	for scanner.Scan() {
   172  		line := scanner.Text()
   173  		// If the line starts with ">>> " then it marks the beginning
   174  		// of a new flow.
   175  		if strings.HasPrefix(line, ">>> ") {
   176  			if len(currentFlow) > 0 || len(flows) > 0 {
   177  				flows = append(flows, currentFlow)
   178  				currentFlow = nil
   179  			}
   180  			continue
   181  		}
   182  
   183  		// Otherwise the line is a line of hex dump that looks like:
   184  		// 00000170  fc f5 06 bf (...)  |.....X{&?......!|
   185  		// (Some bytes have been omitted from the middle section.)
   186  
   187  		if i := strings.IndexByte(line, ' '); i >= 0 {
   188  			line = line[i:]
   189  		} else {
   190  			return nil, errors.New("invalid test data")
   191  		}
   192  
   193  		if i := strings.IndexByte(line, '|'); i >= 0 {
   194  			line = line[:i]
   195  		} else {
   196  			return nil, errors.New("invalid test data")
   197  		}
   198  
   199  		hexBytes := strings.Fields(line)
   200  		for _, hexByte := range hexBytes {
   201  			val, err := strconv.ParseUint(hexByte, 16, 8)
   202  			if err != nil {
   203  				return nil, errors.New("invalid hex byte in test data: " + err.Error())
   204  			}
   205  			currentFlow = append(currentFlow, byte(val))
   206  		}
   207  	}
   208  
   209  	if len(currentFlow) > 0 {
   210  		flows = append(flows, currentFlow)
   211  	}
   212  
   213  	return flows, nil
   214  }
   215  
   216  // tempFile creates a temp file containing contents and returns its path.
   217  func tempFile(contents string) string {
   218  	file, err := ioutil.TempFile("", "go-tls-test")
   219  	if err != nil {
   220  		panic("failed to create temp file: " + err.Error())
   221  	}
   222  	path := file.Name()
   223  	file.WriteString(contents)
   224  	file.Close()
   225  	return path
   226  }