github.com/Psiphon-Labs/tls-tris@v0.0.0-20230824155421-58bf6d336a9a/handshake_messages_test.go (about) 1 // Copyright 2009 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package tls 6 7 import ( 8 "bytes" 9 "math/rand" 10 "reflect" 11 "strings" 12 "testing" 13 "testing/quick" 14 ) 15 16 var tests = []interface{}{ 17 &clientHelloMsg{}, 18 &serverHelloMsg{}, 19 &finishedMsg{}, 20 21 &certificateMsg{}, 22 &certificateRequestMsg{}, 23 &certificateRequestMsg13{}, 24 &certificateVerifyMsg{}, 25 &certificateStatusMsg{}, 26 &clientKeyExchangeMsg{}, 27 &nextProtoMsg{}, 28 &newSessionTicketMsg{}, 29 &sessionState{}, 30 &encryptedExtensionsMsg{}, 31 &certificateMsg13{}, 32 &newSessionTicketMsg13{}, 33 &sessionState13{}, 34 } 35 36 type testMessage interface { 37 marshal() []byte 38 unmarshal([]byte) alert 39 equal(interface{}) bool 40 } 41 42 func TestMarshalUnmarshal(t *testing.T) { 43 rand := rand.New(rand.NewSource(0)) 44 45 for i, iface := range tests { 46 ty := reflect.ValueOf(iface).Type() 47 48 n := 100 49 if testing.Short() { 50 n = 5 51 } 52 for j := 0; j < n; j++ { 53 v, ok := quick.Value(ty, rand) 54 if !ok { 55 t.Errorf("#%d: failed to create value", i) 56 break 57 } 58 59 m1 := v.Interface().(testMessage) 60 marshaled := m1.marshal() 61 m2 := iface.(testMessage) 62 if m2.unmarshal(marshaled) != alertSuccess { 63 t.Errorf("#%d.%d failed to unmarshal %#v %x", i, j, m1, marshaled) 64 break 65 } 66 m2.marshal() // to fill any marshal cache in the message 67 68 if !m1.equal(m2) { 69 t.Errorf("#%d.%d got:%#v want:%#v %x", i, j, m2, m1, marshaled) 70 break 71 } 72 73 if i >= 3 { 74 // The first three message types (ClientHello, 75 // ServerHello and Finished) are allowed to 76 // have parsable prefixes because the extension 77 // data is optional and the length of the 78 // Finished varies across versions. 79 for j := 0; j < len(marshaled); j++ { 80 if m2.unmarshal(marshaled[0:j]) == alertSuccess { 81 t.Errorf("#%d unmarshaled a prefix of length %d of %#v", i, j, m1) 82 break 83 } 84 } 85 } 86 } 87 } 88 } 89 90 func TestFuzz(t *testing.T) { 91 rand := rand.New(rand.NewSource(0)) 92 for _, iface := range tests { 93 m := iface.(testMessage) 94 95 for j := 0; j < 1000; j++ { 96 len := rand.Intn(100) 97 bytes := randomBytes(len, rand) 98 // This just looks for crashes due to bounds errors etc. 99 m.unmarshal(bytes) 100 } 101 } 102 } 103 104 func randomBytes(n int, rand *rand.Rand) []byte { 105 r := make([]byte, n) 106 if _, err := rand.Read(r); err != nil { 107 panic("rand.Read failed: " + err.Error()) 108 } 109 return r 110 } 111 112 func randomString(n int, rand *rand.Rand) string { 113 b := randomBytes(n, rand) 114 return string(b) 115 } 116 117 func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value { 118 m := &clientHelloMsg{} 119 m.vers = uint16(rand.Intn(65536)) 120 m.random = randomBytes(32, rand) 121 m.sessionId = randomBytes(rand.Intn(32), rand) 122 m.cipherSuites = make([]uint16, rand.Intn(63)+1) 123 for i := 0; i < len(m.cipherSuites); i++ { 124 cs := uint16(rand.Int31()) 125 if cs == scsvRenegotiation { 126 cs += 1 127 } 128 m.cipherSuites[i] = cs 129 } 130 m.compressionMethods = randomBytes(rand.Intn(63)+1, rand) 131 if rand.Intn(10) > 5 { 132 m.nextProtoNeg = true 133 } 134 if rand.Intn(10) > 5 { 135 m.serverName = randomString(rand.Intn(255), rand) 136 for strings.HasSuffix(m.serverName, ".") { 137 m.serverName = m.serverName[:len(m.serverName)-1] 138 } 139 } 140 m.ocspStapling = rand.Intn(10) > 5 141 m.supportedPoints = randomBytes(rand.Intn(5)+1, rand) 142 m.supportedCurves = make([]CurveID, rand.Intn(5)+1) 143 for i := range m.supportedCurves { 144 m.supportedCurves[i] = CurveID(rand.Intn(30000)) 145 } 146 if rand.Intn(10) > 5 { 147 m.ticketSupported = true 148 if rand.Intn(10) > 5 { 149 m.sessionTicket = randomBytes(rand.Intn(300), rand) 150 } 151 } 152 if rand.Intn(10) > 5 { 153 m.supportedSignatureAlgorithms = supportedSignatureAlgorithms 154 } 155 m.alpnProtocols = make([]string, rand.Intn(5)) 156 for i := range m.alpnProtocols { 157 m.alpnProtocols[i] = randomString(rand.Intn(20)+1, rand) 158 } 159 if rand.Intn(10) > 5 { 160 m.scts = true 161 } 162 m.keyShares = make([]keyShare, rand.Intn(4)) 163 for i := range m.keyShares { 164 m.keyShares[i].group = CurveID(rand.Intn(30000)) 165 m.keyShares[i].data = randomBytes(rand.Intn(300)+1, rand) 166 } 167 m.supportedVersions = make([]uint16, rand.Intn(5)) 168 for i := range m.supportedVersions { 169 m.supportedVersions[i] = uint16(rand.Intn(30000)) 170 } 171 if rand.Intn(10) > 5 { 172 m.earlyData = true 173 } 174 if rand.Intn(10) > 5 { 175 m.delegatedCredential = true 176 } 177 if rand.Intn(10) > 5 { 178 m.extendedMSSupported = true 179 } 180 return reflect.ValueOf(m) 181 } 182 183 func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value { 184 m := &serverHelloMsg{} 185 m.vers = uint16(rand.Intn(65536)) 186 m.random = randomBytes(32, rand) 187 m.sessionId = randomBytes(rand.Intn(32), rand) 188 m.cipherSuite = uint16(rand.Int31()) 189 m.compressionMethod = uint8(rand.Intn(256)) 190 191 if rand.Intn(10) > 5 { 192 m.nextProtoNeg = true 193 194 n := rand.Intn(10) 195 m.nextProtos = make([]string, n) 196 for i := 0; i < n; i++ { 197 m.nextProtos[i] = randomString(20, rand) 198 } 199 } 200 201 if rand.Intn(10) > 5 { 202 m.ocspStapling = true 203 } 204 if rand.Intn(10) > 5 { 205 m.ticketSupported = true 206 } 207 if rand.Intn(10) > 5 { 208 m.extendedMSSupported = true 209 } 210 m.alpnProtocol = randomString(rand.Intn(32)+1, rand) 211 212 if rand.Intn(10) > 5 { 213 numSCTs := rand.Intn(4) 214 m.scts = make([][]byte, numSCTs) 215 for i := range m.scts { 216 m.scts[i] = randomBytes(rand.Intn(500)+1, rand) 217 } 218 } 219 220 if rand.Intn(10) > 5 { 221 m.keyShare.group = CurveID(rand.Intn(30000) + 1) 222 m.keyShare.data = randomBytes(rand.Intn(300)+1, rand) 223 } 224 if rand.Intn(10) > 5 { 225 m.psk = true 226 m.pskIdentity = uint16(rand.Int31()) 227 } 228 229 return reflect.ValueOf(m) 230 } 231 232 func (*encryptedExtensionsMsg) Generate(rand *rand.Rand, size int) reflect.Value { 233 m := &encryptedExtensionsMsg{} 234 if rand.Intn(10) > 5 { 235 m.alpnProtocol = randomString(rand.Intn(32)+1, rand) 236 } 237 if rand.Intn(10) > 5 { 238 m.earlyData = true 239 } 240 241 return reflect.ValueOf(m) 242 } 243 244 func (*certificateMsg) Generate(rand *rand.Rand, size int) reflect.Value { 245 m := &certificateMsg{} 246 numCerts := rand.Intn(20) 247 m.certificates = make([][]byte, numCerts) 248 for i := 0; i < numCerts; i++ { 249 m.certificates[i] = randomBytes(rand.Intn(10)+1, rand) 250 } 251 return reflect.ValueOf(m) 252 } 253 254 func (*certificateMsg13) Generate(rand *rand.Rand, size int) reflect.Value { 255 m := &certificateMsg13{} 256 numCerts := rand.Intn(20) 257 m.certificates = make([]certificateEntry, numCerts) 258 for i := 0; i < numCerts; i++ { 259 m.certificates[i].data = randomBytes(rand.Intn(10)+1, rand) 260 if rand.Intn(2) == 1 { 261 m.certificates[i].ocspStaple = randomBytes(rand.Intn(10)+1, rand) 262 } 263 264 numScts := rand.Intn(3) 265 for j := 0; j < numScts; j++ { 266 m.certificates[i].sctList = append(m.certificates[i].sctList, randomBytes(rand.Intn(10)+1, rand)) 267 } 268 } 269 m.requestContext = randomBytes(rand.Intn(5), rand) 270 return reflect.ValueOf(m) 271 } 272 273 func (*certificateRequestMsg) Generate(rand *rand.Rand, size int) reflect.Value { 274 m := &certificateRequestMsg{} 275 m.certificateTypes = randomBytes(rand.Intn(5)+1, rand) 276 numCAs := rand.Intn(100) 277 m.certificateAuthorities = make([][]byte, numCAs) 278 for i := 0; i < numCAs; i++ { 279 m.certificateAuthorities[i] = randomBytes(rand.Intn(15)+1, rand) 280 } 281 return reflect.ValueOf(m) 282 } 283 284 func (*certificateRequestMsg13) Generate(rand *rand.Rand, size int) reflect.Value { 285 m := &certificateRequestMsg13{} 286 m.requestContext = randomBytes(rand.Intn(5), rand) 287 m.supportedSignatureAlgorithms = supportedSignatureAlgorithms 288 numCAs := rand.Intn(100) 289 m.certificateAuthorities = make([][]byte, numCAs) 290 for i := 0; i < numCAs; i++ { 291 m.certificateAuthorities[i] = randomBytes(rand.Intn(15)+1, rand) 292 } 293 return reflect.ValueOf(m) 294 } 295 296 func (*certificateVerifyMsg) Generate(rand *rand.Rand, size int) reflect.Value { 297 m := &certificateVerifyMsg{} 298 m.signature = randomBytes(rand.Intn(15)+1, rand) 299 return reflect.ValueOf(m) 300 } 301 302 func (*certificateStatusMsg) Generate(rand *rand.Rand, size int) reflect.Value { 303 m := &certificateStatusMsg{} 304 if rand.Intn(10) > 5 { 305 m.statusType = statusTypeOCSP 306 m.response = randomBytes(rand.Intn(10)+1, rand) 307 } else { 308 m.statusType = 42 309 } 310 return reflect.ValueOf(m) 311 } 312 313 func (*clientKeyExchangeMsg) Generate(rand *rand.Rand, size int) reflect.Value { 314 m := &clientKeyExchangeMsg{} 315 m.ciphertext = randomBytes(rand.Intn(1000)+1, rand) 316 return reflect.ValueOf(m) 317 } 318 319 func (*finishedMsg) Generate(rand *rand.Rand, size int) reflect.Value { 320 m := &finishedMsg{} 321 m.verifyData = randomBytes(12, rand) 322 return reflect.ValueOf(m) 323 } 324 325 func (*nextProtoMsg) Generate(rand *rand.Rand, size int) reflect.Value { 326 m := &nextProtoMsg{} 327 m.proto = randomString(rand.Intn(255), rand) 328 return reflect.ValueOf(m) 329 } 330 331 func (*newSessionTicketMsg) Generate(rand *rand.Rand, size int) reflect.Value { 332 m := &newSessionTicketMsg{} 333 m.ticket = randomBytes(rand.Intn(4), rand) 334 return reflect.ValueOf(m) 335 } 336 337 func (*newSessionTicketMsg13) Generate(rand *rand.Rand, size int) reflect.Value { 338 m := &newSessionTicketMsg13{} 339 m.ageAdd = uint32(rand.Intn(0xffffffff)) 340 m.lifetime = uint32(rand.Intn(0xffffffff)) 341 m.nonce = randomBytes(1+rand.Intn(255), rand) 342 m.ticket = randomBytes(1+rand.Intn(40), rand) 343 if rand.Intn(10) > 5 { 344 m.withEarlyDataInfo = true 345 m.maxEarlyDataLength = uint32(rand.Intn(0xffffffff)) 346 } 347 return reflect.ValueOf(m) 348 } 349 350 func (*sessionState) Generate(rand *rand.Rand, size int) reflect.Value { 351 s := &sessionState{} 352 s.vers = uint16(rand.Intn(10000)) 353 s.cipherSuite = uint16(rand.Intn(10000)) 354 s.masterSecret = randomBytes(rand.Intn(100), rand) 355 if rand.Intn(10) > 5 { 356 s.usedEMS = true 357 } 358 numCerts := rand.Intn(20) 359 s.certificates = make([][]byte, numCerts) 360 for i := 0; i < numCerts; i++ { 361 s.certificates[i] = randomBytes(rand.Intn(10)+1, rand) 362 } 363 return reflect.ValueOf(s) 364 } 365 366 func (*sessionState13) Generate(rand *rand.Rand, size int) reflect.Value { 367 s := &sessionState13{} 368 s.vers = uint16(rand.Intn(10000)) 369 s.suite = uint16(rand.Intn(10000)) 370 s.ageAdd = uint32(rand.Intn(0xffffffff)) 371 s.maxEarlyDataLen = uint32(rand.Intn(0xffffffff)) 372 s.createdAt = uint64(rand.Int63n(0xfffffffffffffff)) 373 s.pskSecret = randomBytes(rand.Intn(100), rand) 374 s.alpnProtocol = randomString(rand.Intn(100), rand) 375 s.SNI = randomString(rand.Intn(100), rand) 376 return reflect.ValueOf(s) 377 } 378 379 func TestRejectEmptySCTList(t *testing.T) { 380 // https://tools.ietf.org/html/rfc6962#section-3.3.1 specifies that 381 // empty SCT lists are invalid. 382 383 var random [32]byte 384 sct := []byte{0x42, 0x42, 0x42, 0x42} 385 serverHello := serverHelloMsg{ 386 vers: VersionTLS12, 387 random: random[:], 388 scts: [][]byte{sct}, 389 } 390 serverHelloBytes := serverHello.marshal() 391 392 var serverHelloCopy serverHelloMsg 393 if serverHelloCopy.unmarshal(serverHelloBytes) != alertSuccess { 394 t.Fatal("Failed to unmarshal initial message") 395 } 396 397 // Change serverHelloBytes so that the SCT list is empty 398 i := bytes.Index(serverHelloBytes, sct) 399 if i < 0 { 400 t.Fatal("Cannot find SCT in ServerHello") 401 } 402 403 var serverHelloEmptySCT []byte 404 serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[:i-6]...) 405 // Append the extension length and SCT list length for an empty list. 406 serverHelloEmptySCT = append(serverHelloEmptySCT, []byte{0, 2, 0, 0}...) 407 serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[i+4:]...) 408 409 // Update the handshake message length. 410 serverHelloEmptySCT[1] = byte((len(serverHelloEmptySCT) - 4) >> 16) 411 serverHelloEmptySCT[2] = byte((len(serverHelloEmptySCT) - 4) >> 8) 412 serverHelloEmptySCT[3] = byte(len(serverHelloEmptySCT) - 4) 413 414 // Update the extensions length 415 serverHelloEmptySCT[42] = byte((len(serverHelloEmptySCT) - 44) >> 8) 416 serverHelloEmptySCT[43] = byte((len(serverHelloEmptySCT) - 44)) 417 418 if serverHelloCopy.unmarshal(serverHelloEmptySCT) == alertSuccess { 419 t.Fatal("Unmarshaled ServerHello with empty SCT list") 420 } 421 } 422 423 func TestRejectEmptySCT(t *testing.T) { 424 // Not only must the SCT list be non-empty, but the SCT elements must 425 // not be zero length. 426 427 var random [32]byte 428 serverHello := serverHelloMsg{ 429 vers: VersionTLS12, 430 random: random[:], 431 scts: [][]byte{nil}, 432 } 433 serverHelloBytes := serverHello.marshal() 434 435 var serverHelloCopy serverHelloMsg 436 if serverHelloCopy.unmarshal(serverHelloBytes) == alertSuccess { 437 t.Fatal("Unmarshaled ServerHello with zero-length SCT") 438 } 439 }