// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package tls

import (
	"bufio"
	"encoding/hex"
	"errors"
	"flag"
	"fmt"
	"io"
	"io/ioutil"
	"net"
	"os"
	"os/exec"
	"strconv"
	"strings"
	"sync"
	"testing"
)

// TLS reference tests run a connection against a reference implementation
// (OpenSSL) of TLS and record the bytes of the resulting connection. The Go
// code, during a test, is configured with deterministic randomness and so the
// reference test can be reproduced exactly in the future.
//
// In order to save everyone who wishes to run the tests from needing the
// reference implementation installed, the reference connections are saved in
// files in the testdata directory. Thus running the tests involves nothing
// external, but creating and updating them requires the reference
// implementation.
//
// Tests can be updated by running them with the -update flag. This will cause
// the test files to be regenerated. Generally one should combine the -update
// flag with -test.run to update a specific test. Since the reference
// implementation will always generate fresh random numbers, large parts of
// the reference connection will always change.

var (
	update = flag.Bool("update", false, "update golden files on disk")

	// opensslVersionTestOnce ensures testOpenSSLVersion runs at most once;
	// opensslVersionTestErr holds its result for every later caller.
	opensslVersionTestOnce sync.Once
	opensslVersionTestErr  error
)

// checkOpenSSLVersion fails t if the OpenSSL version check (run at most once
// per test binary) reported that the installed OpenSSL is unsuitable for
// regenerating test data.
func checkOpenSSLVersion(t *testing.T) {
	opensslVersionTestOnce.Do(testOpenSSLVersion)
	if opensslVersionTestErr != nil {
		t.Fatal(opensslVersionTestErr)
	}
}

// testOpenSSLVersion records in opensslVersionTestErr whether `openssl version`
// reports OpenSSL 1.1.0. It does nothing unless the -update flag is set, since
// OpenSSL is only needed when regenerating test data.
func testOpenSSLVersion() {
	// This test ensures that the version of OpenSSL looks reasonable
	// before updating the test data.

	if !*update {
		return
	}

	openssl := exec.Command("openssl", "version")
	output, err := openssl.CombinedOutput()
	if err != nil {
		opensslVersionTestErr = err
		return
	}

	version := string(output)
	if strings.HasPrefix(version, "OpenSSL 1.1.0") {
		return
	}

	println("***********************************************")
	println("")
	println("You need to build OpenSSL 1.1.0 from source in order")
	println("to update the test data.")
	println("")
	println("Configure it with:")
	println("./Configure enable-weak-ssl-ciphers enable-ssl3 enable-ssl3-method -static linux-x86_64")
	println("and then add the apps/ directory at the front of your PATH.")
	println("***********************************************")

	opensslVersionTestErr = errors.New("version of OpenSSL does not appear to be suitable for updating test data")
}

// recordingConn is a net.Conn that records the traffic that passes through it.
// WriteTo can be used to produce output that can later be loaded with
// parseTestData.
type recordingConn struct {
	net.Conn

	// The embedded Mutex guards flows and reading.
	sync.Mutex

	// flows holds the recorded traffic. Each entry is a maximal run of
	// bytes that travelled in a single direction.
	flows [][]byte

	// reading reports whether the most recently recorded bytes were read
	// (true) or written (false).
	reading bool
}

// record appends b to the recorded traffic. Bytes travelling in the same
// direction as the previous chunk extend the current flow; a change of
// direction (or the first chunk) starts a new flow. b is copied, so callers
// may reuse their buffer.
func (r *recordingConn) record(b []byte, reading bool) {
	r.Lock()
	defer r.Unlock()

	if l := len(r.flows); l > 0 && r.reading == reading {
		r.flows[l-1] = append(r.flows[l-1], b...)
	} else {
		r.flows = append(r.flows, append([]byte(nil), b...))
	}
	r.reading = reading
}

// Read reads from the underlying connection and records any bytes received.
func (r *recordingConn) Read(b []byte) (n int, err error) {
	n, err = r.Conn.Read(b)
	if n > 0 {
		r.record(b[:n], true)
	}
	return n, err
}

// Write writes to the underlying connection and records any bytes sent.
func (r *recordingConn) Write(b []byte) (n int, err error) {
	n, err = r.Conn.Write(b)
	if n > 0 {
		r.record(b[:n], false)
	}
	return n, err
}

// WriteTo writes the recorded traffic to w in the text format read by
// parseTestData: for each flow, a ">>> Flow N (src to dst)" header line
// followed by a hex dump of the flow's bytes. Flows alternate direction,
// starting with client to server. It returns the number of bytes written
// to w.
func (r *recordingConn) WriteTo(w io.Writer) (int64, error) {
	r.Lock()
	defer r.Unlock()

	// TLS always starts with a client to server flow.
	clientToServer := true
	var written int64
	for i, flow := range r.flows {
		source, dest := "client", "server"
		if !clientToServer {
			source, dest = dest, source
		}
		// hex.Dump produces exactly what a hex.Dumper would emit (including
		// the trailer written by Close). Writing it as one string lets the
		// byte count reflect what actually reached w, not len(flow).
		n, err := fmt.Fprintf(w, ">>> Flow %d (%s to %s)\n%s", i+1, source, dest, hex.Dump(flow))
		written += int64(n)
		if err != nil {
			return written, err
		}
		clientToServer = !clientToServer
	}
	return written, nil
}

// parseTestData parses traffic in the format produced by
// recordingConn.WriteTo and returns the flows in order. It returns an error
// if a dump line is malformed or if r cannot be read completely.
func parseTestData(r io.Reader) (flows [][]byte, err error) {
	var currentFlow []byte

	scanner := bufio.NewScanner(r)
	for scanner.Scan() {
		line := scanner.Text()
		// If the line starts with ">>> " then it marks the beginning
		// of a new flow.
		if strings.HasPrefix(line, ">>> ") {
			// A header closes the flow in progress; the initial header,
			// seen before any data, has nothing to close.
			if len(currentFlow) > 0 || len(flows) > 0 {
				flows = append(flows, currentFlow)
				currentFlow = nil
			}
			continue
		}

		// Otherwise the line is a line of hex dump that looks like:
		// 00000170  fc f5 06 bf (...)  |.....X{&?......!|
		// (Some bytes have been omitted from the middle section.)
		// Strip the leading offset and the trailing ASCII column so only
		// the hex bytes remain.
		i := strings.IndexByte(line, ' ')
		if i < 0 {
			return nil, errors.New("invalid test data")
		}
		line = line[i:]

		i = strings.IndexByte(line, '|')
		if i < 0 {
			return nil, errors.New("invalid test data")
		}
		line = line[:i]

		for _, hexByte := range strings.Fields(line) {
			val, err := strconv.ParseUint(hexByte, 16, 8)
			if err != nil {
				return nil, errors.New("invalid hex byte in test data: " + err.Error())
			}
			currentFlow = append(currentFlow, byte(val))
		}
	}
	// Scan stops silently on a read error or an over-long line; without this
	// check such input would be returned as truncated but "valid" data.
	if err := scanner.Err(); err != nil {
		return nil, err
	}

	if len(currentFlow) > 0 {
		flows = append(flows, currentFlow)
	}

	return flows, nil
}

// tempFile creates a temp file containing contents and returns its path.
// The caller is responsible for removing the file. It panics if the file
// cannot be created or fully written.
func tempFile(contents string) string {
	file, err := ioutil.TempFile("", "go-tls-test")
	if err != nil {
		panic("failed to create temp file: " + err.Error())
	}
	path := file.Name()

	_, err = file.WriteString(contents)
	// Close can report a deferred write failure, so its error matters too.
	if closeErr := file.Close(); err == nil {
		err = closeErr
	}
	if err != nil {
		os.Remove(path) // best effort: don't leave a partial file behind
		panic("failed to write temp file: " + err.Error())
	}
	return path
}