github.com/euank/go@v0.0.0-20160829210321-495514729181/src/crypto/tls/handshake_test.go (about)

     1  // Copyright 2013 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package tls
     6  
     7  import (
     8  	"bufio"
     9  	"encoding/hex"
    10  	"errors"
    11  	"flag"
    12  	"fmt"
    13  	"io"
    14  	"io/ioutil"
    15  	"net"
    16  	"strconv"
    17  	"strings"
    18  	"sync"
    19  )
    20  
    21  // TLS reference tests run a connection against a reference implementation
    22  // (OpenSSL) of TLS and record the bytes of the resulting connection. The Go
    23  // code, during a test, is configured with deterministic randomness and so the
    24  // reference test can be reproduced exactly in the future.
    25  //
    26  // In order to save everyone who wishes to run the tests from needing the
    27  // reference implementation installed, the reference connections are saved in
    28  // files in the testdata directory. Thus running the tests involves nothing
    29  // external, but creating and updating them requires the reference
    30  // implementation.
    31  //
// Tests can be updated by running them with the -update flag. This will cause
// the files in the testdata directory to be regenerated. Generally one should
// combine the -update flag with -test.run to update a specific test. Since the
// reference implementation will always generate fresh random numbers, large
// parts of the reference connection will always change.
    37  
    38  var update = flag.Bool("update", false, "update golden files on disk")
    39  
    40  // recordingConn is a net.Conn that records the traffic that passes through it.
    41  // WriteTo can be used to produce output that can be later be loaded with
    42  // ParseTestData.
    43  type recordingConn struct {
    44  	net.Conn
    45  	sync.Mutex
    46  	flows   [][]byte
    47  	reading bool
    48  }
    49  
    50  func (r *recordingConn) Read(b []byte) (n int, err error) {
    51  	if n, err = r.Conn.Read(b); n == 0 {
    52  		return
    53  	}
    54  	b = b[:n]
    55  
    56  	r.Lock()
    57  	defer r.Unlock()
    58  
    59  	if l := len(r.flows); l == 0 || !r.reading {
    60  		buf := make([]byte, len(b))
    61  		copy(buf, b)
    62  		r.flows = append(r.flows, buf)
    63  	} else {
    64  		r.flows[l-1] = append(r.flows[l-1], b[:n]...)
    65  	}
    66  	r.reading = true
    67  	return
    68  }
    69  
    70  func (r *recordingConn) Write(b []byte) (n int, err error) {
    71  	if n, err = r.Conn.Write(b); n == 0 {
    72  		return
    73  	}
    74  	b = b[:n]
    75  
    76  	r.Lock()
    77  	defer r.Unlock()
    78  
    79  	if l := len(r.flows); l == 0 || r.reading {
    80  		buf := make([]byte, len(b))
    81  		copy(buf, b)
    82  		r.flows = append(r.flows, buf)
    83  	} else {
    84  		r.flows[l-1] = append(r.flows[l-1], b[:n]...)
    85  	}
    86  	r.reading = false
    87  	return
    88  }
    89  
    90  // WriteTo writes Go source code to w that contains the recorded traffic.
    91  func (r *recordingConn) WriteTo(w io.Writer) (int64, error) {
    92  	// TLS always starts with a client to server flow.
    93  	clientToServer := true
    94  	var written int64
    95  	for i, flow := range r.flows {
    96  		source, dest := "client", "server"
    97  		if !clientToServer {
    98  			source, dest = dest, source
    99  		}
   100  		n, err := fmt.Fprintf(w, ">>> Flow %d (%s to %s)\n", i+1, source, dest)
   101  		written += int64(n)
   102  		if err != nil {
   103  			return written, err
   104  		}
   105  		dumper := hex.Dumper(w)
   106  		n, err = dumper.Write(flow)
   107  		written += int64(n)
   108  		if err != nil {
   109  			return written, err
   110  		}
   111  		err = dumper.Close()
   112  		if err != nil {
   113  			return written, err
   114  		}
   115  		clientToServer = !clientToServer
   116  	}
   117  	return written, nil
   118  }
   119  
   120  func parseTestData(r io.Reader) (flows [][]byte, err error) {
   121  	var currentFlow []byte
   122  
   123  	scanner := bufio.NewScanner(r)
   124  	for scanner.Scan() {
   125  		line := scanner.Text()
   126  		// If the line starts with ">>> " then it marks the beginning
   127  		// of a new flow.
   128  		if strings.HasPrefix(line, ">>> ") {
   129  			if len(currentFlow) > 0 || len(flows) > 0 {
   130  				flows = append(flows, currentFlow)
   131  				currentFlow = nil
   132  			}
   133  			continue
   134  		}
   135  
   136  		// Otherwise the line is a line of hex dump that looks like:
   137  		// 00000170  fc f5 06 bf (...)  |.....X{&?......!|
   138  		// (Some bytes have been omitted from the middle section.)
   139  
   140  		if i := strings.IndexByte(line, ' '); i >= 0 {
   141  			line = line[i:]
   142  		} else {
   143  			return nil, errors.New("invalid test data")
   144  		}
   145  
   146  		if i := strings.IndexByte(line, '|'); i >= 0 {
   147  			line = line[:i]
   148  		} else {
   149  			return nil, errors.New("invalid test data")
   150  		}
   151  
   152  		hexBytes := strings.Fields(line)
   153  		for _, hexByte := range hexBytes {
   154  			val, err := strconv.ParseUint(hexByte, 16, 8)
   155  			if err != nil {
   156  				return nil, errors.New("invalid hex byte in test data: " + err.Error())
   157  			}
   158  			currentFlow = append(currentFlow, byte(val))
   159  		}
   160  	}
   161  
   162  	if len(currentFlow) > 0 {
   163  		flows = append(flows, currentFlow)
   164  	}
   165  
   166  	return flows, nil
   167  }
   168  
   169  // tempFile creates a temp file containing contents and returns its path.
   170  func tempFile(contents string) string {
   171  	file, err := ioutil.TempFile("", "go-tls-test")
   172  	if err != nil {
   173  		panic("failed to create temp file: " + err.Error())
   174  	}
   175  	path := file.Name()
   176  	file.WriteString(contents)
   177  	file.Close()
   178  	return path
   179  }