github.com/icodeface/tls@v0.0.0-20230910023335-34df9250cd12/handshake_test.go (about)

     1  // Copyright 2013 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package tls
     6  
     7  import (
     8  	"bufio"
     9  	"encoding/hex"
    10  	"errors"
    11  	"flag"
    12  	"fmt"
    13  	"io"
    14  	"io/ioutil"
    15  	"net"
    16  	"os"
    17  	"os/exec"
    18  	"strconv"
    19  	"strings"
    20  	"sync"
    21  	"testing"
    22  )
    23  
    24  // TLS reference tests run a connection against a reference implementation
    25  // (OpenSSL) of TLS and record the bytes of the resulting connection. The Go
    26  // code, during a test, is configured with deterministic randomness and so the
    27  // reference test can be reproduced exactly in the future.
    28  //
    29  // In order to save everyone who wishes to run the tests from needing the
    30  // reference implementation installed, the reference connections are saved in
    31  // files in the testdata directory. Thus running the tests involves nothing
    32  // external, but creating and updating them requires the reference
    33  // implementation.
    34  //
// Tests can be updated by running them with the -update flag. This will cause
// the test files to be regenerated. Generally one should combine the -update
// flag with -test.run to update a specific test. Since the reference
// implementation will always generate fresh random numbers, large parts of
// the reference connection will always change.
    40  
    41  var (
    42  	update = flag.Bool("update", false, "update golden files on disk")
    43  
    44  	opensslVersionTestOnce sync.Once
    45  	opensslVersionTestErr  error
    46  )
    47  
    48  func checkOpenSSLVersion(t *testing.T) {
    49  	opensslVersionTestOnce.Do(testOpenSSLVersion)
    50  	if opensslVersionTestErr != nil {
    51  		t.Fatal(opensslVersionTestErr)
    52  	}
    53  }
    54  
    55  func testOpenSSLVersion() {
    56  	// This test ensures that the version of OpenSSL looks reasonable
    57  	// before updating the test data.
    58  
    59  	if !*update {
    60  		return
    61  	}
    62  
    63  	openssl := exec.Command("openssl", "version")
    64  	output, err := openssl.CombinedOutput()
    65  	if err != nil {
    66  		opensslVersionTestErr = err
    67  		return
    68  	}
    69  
    70  	version := string(output)
    71  	if strings.HasPrefix(version, "OpenSSL 1.1.1") {
    72  		return
    73  	}
    74  
    75  	println("***********************************************")
    76  	println("")
    77  	println("You need to build OpenSSL 1.1.1 from source in order")
    78  	println("to update the test data.")
    79  	println("")
    80  	println("Configure it with:")
    81  	println("./Configure enable-weak-ssl-ciphers enable-ssl3 enable-ssl3-method")
    82  	println("and then add the apps/ directory at the front of your PATH.")
    83  	println("***********************************************")
    84  
    85  	opensslVersionTestErr = errors.New("version of OpenSSL does not appear to be suitable for updating test data")
    86  }
    87  
    88  // recordingConn is a net.Conn that records the traffic that passes through it.
    89  // WriteTo can be used to produce output that can be later be loaded with
    90  // ParseTestData.
    91  type recordingConn struct {
    92  	net.Conn
    93  	sync.Mutex
    94  	flows   [][]byte
    95  	reading bool
    96  }
    97  
    98  func (r *recordingConn) Read(b []byte) (n int, err error) {
    99  	if n, err = r.Conn.Read(b); n == 0 {
   100  		return
   101  	}
   102  	b = b[:n]
   103  
   104  	r.Lock()
   105  	defer r.Unlock()
   106  
   107  	if l := len(r.flows); l == 0 || !r.reading {
   108  		buf := make([]byte, len(b))
   109  		copy(buf, b)
   110  		r.flows = append(r.flows, buf)
   111  	} else {
   112  		r.flows[l-1] = append(r.flows[l-1], b[:n]...)
   113  	}
   114  	r.reading = true
   115  	return
   116  }
   117  
   118  func (r *recordingConn) Write(b []byte) (n int, err error) {
   119  	if n, err = r.Conn.Write(b); n == 0 {
   120  		return
   121  	}
   122  	b = b[:n]
   123  
   124  	r.Lock()
   125  	defer r.Unlock()
   126  
   127  	if l := len(r.flows); l == 0 || r.reading {
   128  		buf := make([]byte, len(b))
   129  		copy(buf, b)
   130  		r.flows = append(r.flows, buf)
   131  	} else {
   132  		r.flows[l-1] = append(r.flows[l-1], b[:n]...)
   133  	}
   134  	r.reading = false
   135  	return
   136  }
   137  
   138  // WriteTo writes Go source code to w that contains the recorded traffic.
   139  func (r *recordingConn) WriteTo(w io.Writer) (int64, error) {
   140  	// TLS always starts with a client to server flow.
   141  	clientToServer := true
   142  	var written int64
   143  	for i, flow := range r.flows {
   144  		source, dest := "client", "server"
   145  		if !clientToServer {
   146  			source, dest = dest, source
   147  		}
   148  		n, err := fmt.Fprintf(w, ">>> Flow %d (%s to %s)\n", i+1, source, dest)
   149  		written += int64(n)
   150  		if err != nil {
   151  			return written, err
   152  		}
   153  		dumper := hex.Dumper(w)
   154  		n, err = dumper.Write(flow)
   155  		written += int64(n)
   156  		if err != nil {
   157  			return written, err
   158  		}
   159  		err = dumper.Close()
   160  		if err != nil {
   161  			return written, err
   162  		}
   163  		clientToServer = !clientToServer
   164  	}
   165  	return written, nil
   166  }
   167  
   168  func parseTestData(r io.Reader) (flows [][]byte, err error) {
   169  	var currentFlow []byte
   170  
   171  	scanner := bufio.NewScanner(r)
   172  	for scanner.Scan() {
   173  		line := scanner.Text()
   174  		// If the line starts with ">>> " then it marks the beginning
   175  		// of a new flow.
   176  		if strings.HasPrefix(line, ">>> ") {
   177  			if len(currentFlow) > 0 || len(flows) > 0 {
   178  				flows = append(flows, currentFlow)
   179  				currentFlow = nil
   180  			}
   181  			continue
   182  		}
   183  
   184  		// Otherwise the line is a line of hex dump that looks like:
   185  		// 00000170  fc f5 06 bf (...)  |.....X{&?......!|
   186  		// (Some bytes have been omitted from the middle section.)
   187  
   188  		if i := strings.IndexByte(line, ' '); i >= 0 {
   189  			line = line[i:]
   190  		} else {
   191  			return nil, errors.New("invalid test data")
   192  		}
   193  
   194  		if i := strings.IndexByte(line, '|'); i >= 0 {
   195  			line = line[:i]
   196  		} else {
   197  			return nil, errors.New("invalid test data")
   198  		}
   199  
   200  		hexBytes := strings.Fields(line)
   201  		for _, hexByte := range hexBytes {
   202  			val, err := strconv.ParseUint(hexByte, 16, 8)
   203  			if err != nil {
   204  				return nil, errors.New("invalid hex byte in test data: " + err.Error())
   205  			}
   206  			currentFlow = append(currentFlow, byte(val))
   207  		}
   208  	}
   209  
   210  	if len(currentFlow) > 0 {
   211  		flows = append(flows, currentFlow)
   212  	}
   213  
   214  	return flows, nil
   215  }
   216  
   217  // tempFile creates a temp file containing contents and returns its path.
   218  func tempFile(contents string) string {
   219  	file, err := ioutil.TempFile("", "go-tls-test")
   220  	if err != nil {
   221  		panic("failed to create temp file: " + err.Error())
   222  	}
   223  	path := file.Name()
   224  	file.WriteString(contents)
   225  	file.Close()
   226  	return path
   227  }
   228  
   229  // localListener is set up by TestMain and used by localPipe to create Conn
   230  // pairs like net.Pipe, but connected by an actual buffered TCP connection.
   231  var localListener struct {
   232  	sync.Mutex
   233  	net.Listener
   234  }
   235  
   236  func localPipe(t testing.TB) (net.Conn, net.Conn) {
   237  	localListener.Lock()
   238  	defer localListener.Unlock()
   239  	c := make(chan net.Conn)
   240  	go func() {
   241  		conn, err := localListener.Accept()
   242  		if err != nil {
   243  			t.Errorf("Failed to accept local connection: %v", err)
   244  		}
   245  		c <- conn
   246  	}()
   247  	addr := localListener.Addr()
   248  	c1, err := net.Dial(addr.Network(), addr.String())
   249  	if err != nil {
   250  		t.Fatalf("Failed to dial local connection: %v", err)
   251  	}
   252  	c2 := <-c
   253  	return c1, c2
   254  }
   255  
   256  func TestMain(m *testing.M) {
   257  	l, err := net.Listen("tcp", "127.0.0.1:0")
   258  	if err != nil {
   259  		l, err = net.Listen("tcp6", "[::1]:0")
   260  	}
   261  	if err != nil {
   262  		fmt.Fprintf(os.Stderr, "Failed to open local listener: %v", err)
   263  		os.Exit(1)
   264  	}
   265  	localListener.Listener = l
   266  	exitCode := m.Run()
   267  	localListener.Close()
   268  	os.Exit(exitCode)
   269  }