github.com/icodeface/tls@v0.0.0-20230910023335-34df9250cd12/handshake_test.go (about) 1 // Copyright 2013 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package tls 6 7 import ( 8 "bufio" 9 "encoding/hex" 10 "errors" 11 "flag" 12 "fmt" 13 "io" 14 "io/ioutil" 15 "net" 16 "os" 17 "os/exec" 18 "strconv" 19 "strings" 20 "sync" 21 "testing" 22 ) 23 24 // TLS reference tests run a connection against a reference implementation 25 // (OpenSSL) of TLS and record the bytes of the resulting connection. The Go 26 // code, during a test, is configured with deterministic randomness and so the 27 // reference test can be reproduced exactly in the future. 28 // 29 // In order to save everyone who wishes to run the tests from needing the 30 // reference implementation installed, the reference connections are saved in 31 // files in the testdata directory. Thus running the tests involves nothing 32 // external, but creating and updating them requires the reference 33 // implementation. 34 // 35 // Tests can be updated by running them with the -update flag. This will cause 36 // the test files to be regenerated. Generally one should combine the -update 37 // flag with -test.run to updated a specific test. Since the reference 38 // implementation will always generate fresh random numbers, large parts of 39 // the reference connection will always change. 40 41 var ( 42 update = flag.Bool("update", false, "update golden files on disk") 43 44 opensslVersionTestOnce sync.Once 45 opensslVersionTestErr error 46 ) 47 48 func checkOpenSSLVersion(t *testing.T) { 49 opensslVersionTestOnce.Do(testOpenSSLVersion) 50 if opensslVersionTestErr != nil { 51 t.Fatal(opensslVersionTestErr) 52 } 53 } 54 55 func testOpenSSLVersion() { 56 // This test ensures that the version of OpenSSL looks reasonable 57 // before updating the test data. 
58 59 if !*update { 60 return 61 } 62 63 openssl := exec.Command("openssl", "version") 64 output, err := openssl.CombinedOutput() 65 if err != nil { 66 opensslVersionTestErr = err 67 return 68 } 69 70 version := string(output) 71 if strings.HasPrefix(version, "OpenSSL 1.1.1") { 72 return 73 } 74 75 println("***********************************************") 76 println("") 77 println("You need to build OpenSSL 1.1.1 from source in order") 78 println("to update the test data.") 79 println("") 80 println("Configure it with:") 81 println("./Configure enable-weak-ssl-ciphers enable-ssl3 enable-ssl3-method") 82 println("and then add the apps/ directory at the front of your PATH.") 83 println("***********************************************") 84 85 opensslVersionTestErr = errors.New("version of OpenSSL does not appear to be suitable for updating test data") 86 } 87 88 // recordingConn is a net.Conn that records the traffic that passes through it. 89 // WriteTo can be used to produce output that can be later be loaded with 90 // ParseTestData. 91 type recordingConn struct { 92 net.Conn 93 sync.Mutex 94 flows [][]byte 95 reading bool 96 } 97 98 func (r *recordingConn) Read(b []byte) (n int, err error) { 99 if n, err = r.Conn.Read(b); n == 0 { 100 return 101 } 102 b = b[:n] 103 104 r.Lock() 105 defer r.Unlock() 106 107 if l := len(r.flows); l == 0 || !r.reading { 108 buf := make([]byte, len(b)) 109 copy(buf, b) 110 r.flows = append(r.flows, buf) 111 } else { 112 r.flows[l-1] = append(r.flows[l-1], b[:n]...) 113 } 114 r.reading = true 115 return 116 } 117 118 func (r *recordingConn) Write(b []byte) (n int, err error) { 119 if n, err = r.Conn.Write(b); n == 0 { 120 return 121 } 122 b = b[:n] 123 124 r.Lock() 125 defer r.Unlock() 126 127 if l := len(r.flows); l == 0 || r.reading { 128 buf := make([]byte, len(b)) 129 copy(buf, b) 130 r.flows = append(r.flows, buf) 131 } else { 132 r.flows[l-1] = append(r.flows[l-1], b[:n]...) 
133 } 134 r.reading = false 135 return 136 } 137 138 // WriteTo writes Go source code to w that contains the recorded traffic. 139 func (r *recordingConn) WriteTo(w io.Writer) (int64, error) { 140 // TLS always starts with a client to server flow. 141 clientToServer := true 142 var written int64 143 for i, flow := range r.flows { 144 source, dest := "client", "server" 145 if !clientToServer { 146 source, dest = dest, source 147 } 148 n, err := fmt.Fprintf(w, ">>> Flow %d (%s to %s)\n", i+1, source, dest) 149 written += int64(n) 150 if err != nil { 151 return written, err 152 } 153 dumper := hex.Dumper(w) 154 n, err = dumper.Write(flow) 155 written += int64(n) 156 if err != nil { 157 return written, err 158 } 159 err = dumper.Close() 160 if err != nil { 161 return written, err 162 } 163 clientToServer = !clientToServer 164 } 165 return written, nil 166 } 167 168 func parseTestData(r io.Reader) (flows [][]byte, err error) { 169 var currentFlow []byte 170 171 scanner := bufio.NewScanner(r) 172 for scanner.Scan() { 173 line := scanner.Text() 174 // If the line starts with ">>> " then it marks the beginning 175 // of a new flow. 176 if strings.HasPrefix(line, ">>> ") { 177 if len(currentFlow) > 0 || len(flows) > 0 { 178 flows = append(flows, currentFlow) 179 currentFlow = nil 180 } 181 continue 182 } 183 184 // Otherwise the line is a line of hex dump that looks like: 185 // 00000170 fc f5 06 bf (...) |.....X{&?......!| 186 // (Some bytes have been omitted from the middle section.) 
187 188 if i := strings.IndexByte(line, ' '); i >= 0 { 189 line = line[i:] 190 } else { 191 return nil, errors.New("invalid test data") 192 } 193 194 if i := strings.IndexByte(line, '|'); i >= 0 { 195 line = line[:i] 196 } else { 197 return nil, errors.New("invalid test data") 198 } 199 200 hexBytes := strings.Fields(line) 201 for _, hexByte := range hexBytes { 202 val, err := strconv.ParseUint(hexByte, 16, 8) 203 if err != nil { 204 return nil, errors.New("invalid hex byte in test data: " + err.Error()) 205 } 206 currentFlow = append(currentFlow, byte(val)) 207 } 208 } 209 210 if len(currentFlow) > 0 { 211 flows = append(flows, currentFlow) 212 } 213 214 return flows, nil 215 } 216 217 // tempFile creates a temp file containing contents and returns its path. 218 func tempFile(contents string) string { 219 file, err := ioutil.TempFile("", "go-tls-test") 220 if err != nil { 221 panic("failed to create temp file: " + err.Error()) 222 } 223 path := file.Name() 224 file.WriteString(contents) 225 file.Close() 226 return path 227 } 228 229 // localListener is set up by TestMain and used by localPipe to create Conn 230 // pairs like net.Pipe, but connected by an actual buffered TCP connection. 
var localListener struct {
	sync.Mutex
	net.Listener
}

// localPipe returns the two ends of a fresh TCP connection made through
// localListener: the first is the dialing side and the second the accepted
// side. Calls are serialized by localListener's mutex. A failure to dial or
// accept is fatal to t, so callers never receive a nil net.Conn. The caller is
// responsible for closing both connections.
func localPipe(t testing.TB) (net.Conn, net.Conn) {
	localListener.Lock()
	defer localListener.Unlock()

	type acceptResult struct {
		conn net.Conn
		err  error
	}
	// Buffered so the accepting goroutine can always deliver its result
	// and exit, even if this function has already returned via t.Fatalf.
	accepted := make(chan acceptResult, 1)
	go func() {
		conn, err := localListener.Accept()
		accepted <- acceptResult{conn, err}
	}()

	addr := localListener.Addr()
	c1, err := net.Dial(addr.Network(), addr.String())
	if err != nil {
		t.Fatalf("Failed to dial local connection: %v", err)
	}
	res := <-accepted
	if res.err != nil {
		c1.Close()
		// Reported from the calling goroutine (Fatalf must not be called
		// from other goroutines) rather than returning a nil conn.
		t.Fatalf("Failed to accept local connection: %v", res.err)
	}
	return c1, res.conn
}

// TestMain opens the loopback listener used by localPipe, trying IPv4 first
// and falling back to IPv6, runs the tests, and closes the listener before
// exiting with the tests' exit code.
func TestMain(m *testing.M) {
	l, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		l, err = net.Listen("tcp6", "[::1]:0")
	}
	if err != nil {
		fmt.Fprintf(os.Stderr, "Failed to open local listener: %v\n", err)
		os.Exit(1)
	}
	localListener.Listener = l
	exitCode := m.Run()
	localListener.Close()
	os.Exit(exitCode)
}