github.com/kdevb0x/go@v0.0.0-20180115030120-39687051e9e7/src/crypto/tls/handshake_messages_test.go (about) 1 // Copyright 2009 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package tls 6 7 import ( 8 "bytes" 9 "math/rand" 10 "reflect" 11 "strings" 12 "testing" 13 "testing/quick" 14 ) 15 16 var tests = []interface{}{ 17 &clientHelloMsg{}, 18 &serverHelloMsg{}, 19 &finishedMsg{}, 20 21 &certificateMsg{}, 22 &certificateRequestMsg{}, 23 &certificateVerifyMsg{}, 24 &certificateStatusMsg{}, 25 &clientKeyExchangeMsg{}, 26 &nextProtoMsg{}, 27 &newSessionTicketMsg{}, 28 &sessionState{}, 29 } 30 31 type testMessage interface { 32 marshal() []byte 33 unmarshal([]byte) bool 34 equal(interface{}) bool 35 } 36 37 func TestMarshalUnmarshal(t *testing.T) { 38 rand := rand.New(rand.NewSource(0)) 39 40 for i, iface := range tests { 41 ty := reflect.ValueOf(iface).Type() 42 43 n := 100 44 if testing.Short() { 45 n = 5 46 } 47 for j := 0; j < n; j++ { 48 v, ok := quick.Value(ty, rand) 49 if !ok { 50 t.Errorf("#%d: failed to create value", i) 51 break 52 } 53 54 m1 := v.Interface().(testMessage) 55 marshaled := m1.marshal() 56 m2 := iface.(testMessage) 57 if !m2.unmarshal(marshaled) { 58 t.Errorf("#%d failed to unmarshal %#v %x", i, m1, marshaled) 59 break 60 } 61 m2.marshal() // to fill any marshal cache in the message 62 63 if !m1.equal(m2) { 64 t.Errorf("#%d got:%#v want:%#v %x", i, m2, m1, marshaled) 65 break 66 } 67 68 if i >= 3 { 69 // The first three message types (ClientHello, 70 // ServerHello and Finished) are allowed to 71 // have parsable prefixes because the extension 72 // data is optional and the length of the 73 // Finished varies across versions. 74 for j := 0; j < len(marshaled); j++ { 75 if m2.unmarshal(marshaled[0:j]) { 76 t.Errorf("#%d unmarshaled a prefix of length %d of %#v", i, j, m1) 77 break 78 } 79 } 80 } 81 } 82 } 83 } 84 85 func TestFuzz(t *testing.T) { 86 rand := rand.New(rand.NewSource(0)) 87 for _, iface := range tests { 88 m := iface.(testMessage) 89 90 for j := 0; j < 1000; j++ { 91 len := rand.Intn(100) 92 bytes := randomBytes(len, rand) 93 // This just looks for crashes due to bounds errors etc. 94 m.unmarshal(bytes) 95 } 96 } 97 } 98 99 func randomBytes(n int, rand *rand.Rand) []byte { 100 r := make([]byte, n) 101 if _, err := rand.Read(r); err != nil { 102 panic("rand.Read failed: " + err.Error()) 103 } 104 return r 105 } 106 107 func randomString(n int, rand *rand.Rand) string { 108 b := randomBytes(n, rand) 109 return string(b) 110 } 111 112 func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value { 113 m := &clientHelloMsg{} 114 m.vers = uint16(rand.Intn(65536)) 115 m.random = randomBytes(32, rand) 116 m.sessionId = randomBytes(rand.Intn(32), rand) 117 m.cipherSuites = make([]uint16, rand.Intn(63)+1) 118 for i := 0; i < len(m.cipherSuites); i++ { 119 cs := uint16(rand.Int31()) 120 if cs == scsvRenegotiation { 121 cs += 1 122 } 123 m.cipherSuites[i] = cs 124 } 125 m.compressionMethods = randomBytes(rand.Intn(63)+1, rand) 126 if rand.Intn(10) > 5 { 127 m.nextProtoNeg = true 128 } 129 if rand.Intn(10) > 5 { 130 m.serverName = randomString(rand.Intn(255), rand) 131 for strings.HasSuffix(m.serverName, ".") { 132 m.serverName = m.serverName[:len(m.serverName)-1] 133 } 134 } 135 m.ocspStapling = rand.Intn(10) > 5 136 m.supportedPoints = randomBytes(rand.Intn(5)+1, rand) 137 m.supportedCurves = make([]CurveID, rand.Intn(5)+1) 138 for i := range m.supportedCurves { 139 m.supportedCurves[i] = CurveID(rand.Intn(30000)) 140 } 141 if rand.Intn(10) > 5 { 142 m.ticketSupported = true 143 if rand.Intn(10) > 5 { 144 m.sessionTicket = randomBytes(rand.Intn(300), rand) 145 } 146 } 147 if rand.Intn(10) > 5 { 148 m.supportedSignatureAlgorithms = supportedSignatureAlgorithms 149 } 150 m.alpnProtocols = make([]string, rand.Intn(5)) 151 for i := range m.alpnProtocols { 152 m.alpnProtocols[i] = randomString(rand.Intn(20)+1, rand) 153 } 154 if rand.Intn(10) > 5 { 155 m.scts = true 156 } 157 158 return reflect.ValueOf(m) 159 } 160 161 func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value { 162 m := &serverHelloMsg{} 163 m.vers = uint16(rand.Intn(65536)) 164 m.random = randomBytes(32, rand) 165 m.sessionId = randomBytes(rand.Intn(32), rand) 166 m.cipherSuite = uint16(rand.Int31()) 167 m.compressionMethod = uint8(rand.Intn(256)) 168 169 if rand.Intn(10) > 5 { 170 m.nextProtoNeg = true 171 172 n := rand.Intn(10) 173 m.nextProtos = make([]string, n) 174 for i := 0; i < n; i++ { 175 m.nextProtos[i] = randomString(20, rand) 176 } 177 } 178 179 if rand.Intn(10) > 5 { 180 m.ocspStapling = true 181 } 182 if rand.Intn(10) > 5 { 183 m.ticketSupported = true 184 } 185 m.alpnProtocol = randomString(rand.Intn(32)+1, rand) 186 187 if rand.Intn(10) > 5 { 188 numSCTs := rand.Intn(4) 189 m.scts = make([][]byte, numSCTs) 190 for i := range m.scts { 191 m.scts[i] = randomBytes(rand.Intn(500), rand) 192 } 193 } 194 195 return reflect.ValueOf(m) 196 } 197 198 func (*certificateMsg) Generate(rand *rand.Rand, size int) reflect.Value { 199 m := &certificateMsg{} 200 numCerts := rand.Intn(20) 201 m.certificates = make([][]byte, numCerts) 202 for i := 0; i < numCerts; i++ { 203 m.certificates[i] = randomBytes(rand.Intn(10)+1, rand) 204 } 205 return reflect.ValueOf(m) 206 } 207 208 func (*certificateRequestMsg) Generate(rand *rand.Rand, size int) reflect.Value { 209 m := &certificateRequestMsg{} 210 m.certificateTypes = randomBytes(rand.Intn(5)+1, rand) 211 numCAs := rand.Intn(100) 212 m.certificateAuthorities = make([][]byte, numCAs) 213 for i := 0; i < numCAs; i++ { 214 m.certificateAuthorities[i] = randomBytes(rand.Intn(15)+1, rand) 215 } 216 return reflect.ValueOf(m) 217 } 218 219 func (*certificateVerifyMsg) Generate(rand *rand.Rand, size int) reflect.Value { 220 m := &certificateVerifyMsg{} 221 m.signature = randomBytes(rand.Intn(15)+1, rand) 222 return reflect.ValueOf(m) 223 } 224 225 func (*certificateStatusMsg) Generate(rand *rand.Rand, size int) reflect.Value { 226 m := &certificateStatusMsg{} 227 if rand.Intn(10) > 5 { 228 m.statusType = statusTypeOCSP 229 m.response = randomBytes(rand.Intn(10)+1, rand) 230 } else { 231 m.statusType = 42 232 } 233 return reflect.ValueOf(m) 234 } 235 236 func (*clientKeyExchangeMsg) Generate(rand *rand.Rand, size int) reflect.Value { 237 m := &clientKeyExchangeMsg{} 238 m.ciphertext = randomBytes(rand.Intn(1000)+1, rand) 239 return reflect.ValueOf(m) 240 } 241 242 func (*finishedMsg) Generate(rand *rand.Rand, size int) reflect.Value { 243 m := &finishedMsg{} 244 m.verifyData = randomBytes(12, rand) 245 return reflect.ValueOf(m) 246 } 247 248 func (*nextProtoMsg) Generate(rand *rand.Rand, size int) reflect.Value { 249 m := &nextProtoMsg{} 250 m.proto = randomString(rand.Intn(255), rand) 251 return reflect.ValueOf(m) 252 } 253 254 func (*newSessionTicketMsg) Generate(rand *rand.Rand, size int) reflect.Value { 255 m := &newSessionTicketMsg{} 256 m.ticket = randomBytes(rand.Intn(4), rand) 257 return reflect.ValueOf(m) 258 } 259 260 func (*sessionState) Generate(rand *rand.Rand, size int) reflect.Value { 261 s := &sessionState{} 262 s.vers = uint16(rand.Intn(10000)) 263 s.cipherSuite = uint16(rand.Intn(10000)) 264 s.masterSecret = randomBytes(rand.Intn(100), rand) 265 numCerts := rand.Intn(20) 266 s.certificates = make([][]byte, numCerts) 267 for i := 0; i < numCerts; i++ { 268 s.certificates[i] = randomBytes(rand.Intn(10)+1, rand) 269 } 270 return reflect.ValueOf(s) 271 } 272 273 func TestRejectEmptySCTList(t *testing.T) { 274 // https://tools.ietf.org/html/rfc6962#section-3.3.1 specifies that 275 // empty SCT lists are invalid. 276 277 var random [32]byte 278 sct := []byte{0x42, 0x42, 0x42, 0x42} 279 serverHello := serverHelloMsg{ 280 vers: VersionTLS12, 281 random: random[:], 282 scts: [][]byte{sct}, 283 } 284 serverHelloBytes := serverHello.marshal() 285 286 var serverHelloCopy serverHelloMsg 287 if !serverHelloCopy.unmarshal(serverHelloBytes) { 288 t.Fatal("Failed to unmarshal initial message") 289 } 290 291 // Change serverHelloBytes so that the SCT list is empty 292 i := bytes.Index(serverHelloBytes, sct) 293 if i < 0 { 294 t.Fatal("Cannot find SCT in ServerHello") 295 } 296 297 var serverHelloEmptySCT []byte 298 serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[:i-6]...) 299 // Append the extension length and SCT list length for an empty list. 300 serverHelloEmptySCT = append(serverHelloEmptySCT, []byte{0, 2, 0, 0}...) 301 serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[i+4:]...) 302 303 // Update the handshake message length. 304 serverHelloEmptySCT[1] = byte((len(serverHelloEmptySCT) - 4) >> 16) 305 serverHelloEmptySCT[2] = byte((len(serverHelloEmptySCT) - 4) >> 8) 306 serverHelloEmptySCT[3] = byte(len(serverHelloEmptySCT) - 4) 307 308 // Update the extensions length 309 serverHelloEmptySCT[42] = byte((len(serverHelloEmptySCT) - 44) >> 8) 310 serverHelloEmptySCT[43] = byte((len(serverHelloEmptySCT) - 44)) 311 312 if serverHelloCopy.unmarshal(serverHelloEmptySCT) { 313 t.Fatal("Unmarshaled ServerHello with empty SCT list") 314 } 315 } 316 317 func TestRejectEmptySCT(t *testing.T) { 318 // Not only must the SCT list be non-empty, but the SCT elements must 319 // not be zero length. 320 321 var random [32]byte 322 serverHello := serverHelloMsg{ 323 vers: VersionTLS12, 324 random: random[:], 325 scts: [][]byte{nil}, 326 } 327 serverHelloBytes := serverHello.marshal() 328 329 var serverHelloCopy serverHelloMsg 330 if serverHelloCopy.unmarshal(serverHelloBytes) { 331 t.Fatal("Unmarshaled ServerHello with zero-length SCT") 332 } 333 }