github.com/icodeface/tls@v0.0.0-20230910023335-34df9250cd12/handshake_messages_test.go (about) 1 // Copyright 2009 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package tls 6 7 import ( 8 "bytes" 9 "math/rand" 10 "reflect" 11 "strings" 12 "testing" 13 "testing/quick" 14 "time" 15 ) 16 17 var tests = []interface{}{ 18 &clientHelloMsg{}, 19 &serverHelloMsg{}, 20 &finishedMsg{}, 21 22 &certificateMsg{}, 23 &certificateRequestMsg{}, 24 &certificateVerifyMsg{ 25 hasSignatureAlgorithm: true, 26 }, 27 &certificateStatusMsg{}, 28 &clientKeyExchangeMsg{}, 29 &nextProtoMsg{}, 30 &newSessionTicketMsg{}, 31 &sessionState{}, 32 &sessionStateTLS13{}, 33 &encryptedExtensionsMsg{}, 34 &endOfEarlyDataMsg{}, 35 &keyUpdateMsg{}, 36 &newSessionTicketMsgTLS13{}, 37 &certificateRequestMsgTLS13{}, 38 &certificateMsgTLS13{}, 39 } 40 41 func TestMarshalUnmarshal(t *testing.T) { 42 rand := rand.New(rand.NewSource(time.Now().UnixNano())) 43 44 for i, iface := range tests { 45 ty := reflect.ValueOf(iface).Type() 46 47 n := 100 48 if testing.Short() { 49 n = 5 50 } 51 for j := 0; j < n; j++ { 52 v, ok := quick.Value(ty, rand) 53 if !ok { 54 t.Errorf("#%d: failed to create value", i) 55 break 56 } 57 58 m1 := v.Interface().(handshakeMessage) 59 marshaled := m1.marshal() 60 m2 := iface.(handshakeMessage) 61 if !m2.unmarshal(marshaled) { 62 t.Errorf("#%d failed to unmarshal %#v %x", i, m1, marshaled) 63 break 64 } 65 m2.marshal() // to fill any marshal cache in the message 66 67 if !reflect.DeepEqual(m1, m2) { 68 t.Errorf("#%d got:%#v want:%#v %x", i, m2, m1, marshaled) 69 break 70 } 71 72 if i >= 3 { 73 // The first three message types (ClientHello, 74 // ServerHello and Finished) are allowed to 75 // have parsable prefixes because the extension 76 // data is optional and the length of the 77 // Finished varies across versions. 78 for j := 0; j < len(marshaled); j++ { 79 if m2.unmarshal(marshaled[0:j]) { 80 t.Errorf("#%d unmarshaled a prefix of length %d of %#v", i, j, m1) 81 break 82 } 83 } 84 } 85 } 86 } 87 } 88 89 func TestFuzz(t *testing.T) { 90 rand := rand.New(rand.NewSource(0)) 91 for _, iface := range tests { 92 m := iface.(handshakeMessage) 93 94 for j := 0; j < 1000; j++ { 95 len := rand.Intn(100) 96 bytes := randomBytes(len, rand) 97 // This just looks for crashes due to bounds errors etc. 98 m.unmarshal(bytes) 99 } 100 } 101 } 102 103 func randomBytes(n int, rand *rand.Rand) []byte { 104 r := make([]byte, n) 105 if _, err := rand.Read(r); err != nil { 106 panic("rand.Read failed: " + err.Error()) 107 } 108 return r 109 } 110 111 func randomString(n int, rand *rand.Rand) string { 112 b := randomBytes(n, rand) 113 return string(b) 114 } 115 116 func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value { 117 m := &clientHelloMsg{} 118 m.vers = uint16(rand.Intn(65536)) 119 m.random = randomBytes(32, rand) 120 m.sessionId = randomBytes(rand.Intn(32), rand) 121 m.cipherSuites = make([]uint16, rand.Intn(63)+1) 122 for i := 0; i < len(m.cipherSuites); i++ { 123 cs := uint16(rand.Int31()) 124 if cs == scsvRenegotiation { 125 cs += 1 126 } 127 m.cipherSuites[i] = cs 128 } 129 m.compressionMethods = randomBytes(rand.Intn(63)+1, rand) 130 if rand.Intn(10) > 5 { 131 m.nextProtoNeg = true 132 } 133 if rand.Intn(10) > 5 { 134 m.serverName = randomString(rand.Intn(255), rand) 135 for strings.HasSuffix(m.serverName, ".") { 136 m.serverName = m.serverName[:len(m.serverName)-1] 137 } 138 } 139 m.ocspStapling = rand.Intn(10) > 5 140 m.supportedPoints = randomBytes(rand.Intn(5)+1, rand) 141 m.supportedCurves = make([]CurveID, rand.Intn(5)+1) 142 for i := range m.supportedCurves { 143 m.supportedCurves[i] = CurveID(rand.Intn(30000) + 1) 144 } 145 if rand.Intn(10) > 5 { 146 m.ticketSupported = true 147 if rand.Intn(10) > 5 { 148 m.sessionTicket = randomBytes(rand.Intn(300), rand) 149 } else { 150 m.sessionTicket = make([]byte, 0) 151 } 152 } 153 if rand.Intn(10) > 5 { 154 m.supportedSignatureAlgorithms = supportedSignatureAlgorithms 155 } 156 if rand.Intn(10) > 5 { 157 m.supportedSignatureAlgorithmsCert = supportedSignatureAlgorithms 158 } 159 for i := 0; i < rand.Intn(5); i++ { 160 m.alpnProtocols = append(m.alpnProtocols, randomString(rand.Intn(20)+1, rand)) 161 } 162 if rand.Intn(10) > 5 { 163 m.scts = true 164 } 165 if rand.Intn(10) > 5 { 166 m.secureRenegotiationSupported = true 167 m.secureRenegotiation = randomBytes(rand.Intn(50)+1, rand) 168 } 169 for i := 0; i < rand.Intn(5); i++ { 170 m.supportedVersions = append(m.supportedVersions, uint16(rand.Intn(0xffff)+1)) 171 } 172 if rand.Intn(10) > 5 { 173 m.cookie = randomBytes(rand.Intn(500)+1, rand) 174 } 175 for i := 0; i < rand.Intn(5); i++ { 176 var ks keyShare 177 ks.group = CurveID(rand.Intn(30000) + 1) 178 ks.data = randomBytes(rand.Intn(200)+1, rand) 179 m.keyShares = append(m.keyShares, ks) 180 } 181 switch rand.Intn(3) { 182 case 1: 183 m.pskModes = []uint8{pskModeDHE} 184 case 2: 185 m.pskModes = []uint8{pskModeDHE, pskModePlain} 186 } 187 for i := 0; i < rand.Intn(5); i++ { 188 var psk pskIdentity 189 psk.obfuscatedTicketAge = uint32(rand.Intn(500000)) 190 psk.label = randomBytes(rand.Intn(500)+1, rand) 191 m.pskIdentities = append(m.pskIdentities, psk) 192 m.pskBinders = append(m.pskBinders, randomBytes(rand.Intn(50)+32, rand)) 193 } 194 if rand.Intn(10) > 5 { 195 m.earlyData = true 196 } 197 198 return reflect.ValueOf(m) 199 } 200 201 func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value { 202 m := &serverHelloMsg{} 203 m.vers = uint16(rand.Intn(65536)) 204 m.random = randomBytes(32, rand) 205 m.sessionId = randomBytes(rand.Intn(32), rand) 206 m.cipherSuite = uint16(rand.Int31()) 207 m.compressionMethod = uint8(rand.Intn(256)) 208 209 if rand.Intn(10) > 5 { 210 m.nextProtoNeg = true 211 for i := 0; i < rand.Intn(10); i++ { 212 m.nextProtos = append(m.nextProtos, randomString(20, rand)) 213 } 214 } 215 216 if rand.Intn(10) > 5 { 217 m.ocspStapling = true 218 } 219 if rand.Intn(10) > 5 { 220 m.ticketSupported = true 221 } 222 if rand.Intn(10) > 5 { 223 m.alpnProtocol = randomString(rand.Intn(32)+1, rand) 224 } 225 226 for i := 0; i < rand.Intn(4); i++ { 227 m.scts = append(m.scts, randomBytes(rand.Intn(500)+1, rand)) 228 } 229 230 if rand.Intn(10) > 5 { 231 m.secureRenegotiationSupported = true 232 m.secureRenegotiation = randomBytes(rand.Intn(50)+1, rand) 233 } 234 if rand.Intn(10) > 5 { 235 m.supportedVersion = uint16(rand.Intn(0xffff) + 1) 236 } 237 if rand.Intn(10) > 5 { 238 m.cookie = randomBytes(rand.Intn(500)+1, rand) 239 } 240 if rand.Intn(10) > 5 { 241 for i := 0; i < rand.Intn(5); i++ { 242 m.serverShare.group = CurveID(rand.Intn(30000) + 1) 243 m.serverShare.data = randomBytes(rand.Intn(200)+1, rand) 244 } 245 } else if rand.Intn(10) > 5 { 246 m.selectedGroup = CurveID(rand.Intn(30000) + 1) 247 } 248 if rand.Intn(10) > 5 { 249 m.selectedIdentityPresent = true 250 m.selectedIdentity = uint16(rand.Intn(0xffff)) 251 } 252 253 return reflect.ValueOf(m) 254 } 255 256 func (*encryptedExtensionsMsg) Generate(rand *rand.Rand, size int) reflect.Value { 257 m := &encryptedExtensionsMsg{} 258 259 if rand.Intn(10) > 5 { 260 m.alpnProtocol = randomString(rand.Intn(32)+1, rand) 261 } 262 263 return reflect.ValueOf(m) 264 } 265 266 func (*certificateMsg) Generate(rand *rand.Rand, size int) reflect.Value { 267 m := &certificateMsg{} 268 numCerts := rand.Intn(20) 269 m.certificates = make([][]byte, numCerts) 270 for i := 0; i < numCerts; i++ { 271 m.certificates[i] = randomBytes(rand.Intn(10)+1, rand) 272 } 273 return reflect.ValueOf(m) 274 } 275 276 func (*certificateRequestMsg) Generate(rand *rand.Rand, size int) reflect.Value { 277 m := &certificateRequestMsg{} 278 m.certificateTypes = randomBytes(rand.Intn(5)+1, rand) 279 for i := 0; i < rand.Intn(100); i++ { 280 m.certificateAuthorities = append(m.certificateAuthorities, randomBytes(rand.Intn(15)+1, rand)) 281 } 282 return reflect.ValueOf(m) 283 } 284 285 func (*certificateVerifyMsg) Generate(rand *rand.Rand, size int) reflect.Value { 286 m := &certificateVerifyMsg{} 287 m.hasSignatureAlgorithm = true 288 m.signatureAlgorithm = SignatureScheme(rand.Intn(30000)) 289 m.signature = randomBytes(rand.Intn(15)+1, rand) 290 return reflect.ValueOf(m) 291 } 292 293 func (*certificateStatusMsg) Generate(rand *rand.Rand, size int) reflect.Value { 294 m := &certificateStatusMsg{} 295 m.response = randomBytes(rand.Intn(10)+1, rand) 296 return reflect.ValueOf(m) 297 } 298 299 func (*clientKeyExchangeMsg) Generate(rand *rand.Rand, size int) reflect.Value { 300 m := &clientKeyExchangeMsg{} 301 m.ciphertext = randomBytes(rand.Intn(1000)+1, rand) 302 return reflect.ValueOf(m) 303 } 304 305 func (*finishedMsg) Generate(rand *rand.Rand, size int) reflect.Value { 306 m := &finishedMsg{} 307 m.verifyData = randomBytes(12, rand) 308 return reflect.ValueOf(m) 309 } 310 311 func (*nextProtoMsg) Generate(rand *rand.Rand, size int) reflect.Value { 312 m := &nextProtoMsg{} 313 m.proto = randomString(rand.Intn(255), rand) 314 return reflect.ValueOf(m) 315 } 316 317 func (*newSessionTicketMsg) Generate(rand *rand.Rand, size int) reflect.Value { 318 m := &newSessionTicketMsg{} 319 m.ticket = randomBytes(rand.Intn(4), rand) 320 return reflect.ValueOf(m) 321 } 322 323 func (*sessionState) Generate(rand *rand.Rand, size int) reflect.Value { 324 s := &sessionState{} 325 s.vers = uint16(rand.Intn(10000)) 326 s.cipherSuite = uint16(rand.Intn(10000)) 327 s.masterSecret = randomBytes(rand.Intn(100), rand) 328 numCerts := rand.Intn(20) 329 s.certificates = make([][]byte, numCerts) 330 for i := 0; i < numCerts; i++ { 331 s.certificates[i] = randomBytes(rand.Intn(10)+1, rand) 332 } 333 return reflect.ValueOf(s) 334 } 335 336 func (*sessionStateTLS13) Generate(rand *rand.Rand, size int) reflect.Value { 337 s := &sessionStateTLS13{} 338 s.cipherSuite = uint16(rand.Intn(10000)) 339 s.resumptionSecret = randomBytes(rand.Intn(100)+1, rand) 340 s.createdAt = uint64(rand.Int63()) 341 for i := 0; i < rand.Intn(2)+1; i++ { 342 s.certificate.Certificate = append( 343 s.certificate.Certificate, randomBytes(rand.Intn(500)+1, rand)) 344 } 345 if rand.Intn(10) > 5 { 346 s.certificate.OCSPStaple = randomBytes(rand.Intn(100)+1, rand) 347 } 348 if rand.Intn(10) > 5 { 349 for i := 0; i < rand.Intn(2)+1; i++ { 350 s.certificate.SignedCertificateTimestamps = append( 351 s.certificate.SignedCertificateTimestamps, randomBytes(rand.Intn(500)+1, rand)) 352 } 353 } 354 return reflect.ValueOf(s) 355 } 356 357 func (*endOfEarlyDataMsg) Generate(rand *rand.Rand, size int) reflect.Value { 358 m := &endOfEarlyDataMsg{} 359 return reflect.ValueOf(m) 360 } 361 362 func (*keyUpdateMsg) Generate(rand *rand.Rand, size int) reflect.Value { 363 m := &keyUpdateMsg{} 364 m.updateRequested = rand.Intn(10) > 5 365 return reflect.ValueOf(m) 366 } 367 368 func (*newSessionTicketMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value { 369 m := &newSessionTicketMsgTLS13{} 370 m.lifetime = uint32(rand.Intn(500000)) 371 m.ageAdd = uint32(rand.Intn(500000)) 372 m.nonce = randomBytes(rand.Intn(100), rand) 373 m.label = randomBytes(rand.Intn(1000), rand) 374 if rand.Intn(10) > 5 { 375 m.maxEarlyData = uint32(rand.Intn(500000)) 376 } 377 return reflect.ValueOf(m) 378 } 379 380 func (*certificateRequestMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value { 381 m := &certificateRequestMsgTLS13{} 382 if rand.Intn(10) > 5 { 383 m.ocspStapling = true 384 } 385 if rand.Intn(10) > 5 { 386 m.scts = true 387 } 388 if rand.Intn(10) > 5 { 389 m.supportedSignatureAlgorithms = supportedSignatureAlgorithms 390 } 391 if rand.Intn(10) > 5 { 392 m.supportedSignatureAlgorithmsCert = supportedSignatureAlgorithms 393 } 394 if rand.Intn(10) > 5 { 395 m.certificateAuthorities = make([][]byte, 3) 396 for i := 0; i < 3; i++ { 397 m.certificateAuthorities[i] = randomBytes(rand.Intn(10)+1, rand) 398 } 399 } 400 return reflect.ValueOf(m) 401 } 402 403 func (*certificateMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value { 404 m := &certificateMsgTLS13{} 405 for i := 0; i < rand.Intn(2)+1; i++ { 406 m.certificate.Certificate = append( 407 m.certificate.Certificate, randomBytes(rand.Intn(500)+1, rand)) 408 } 409 if rand.Intn(10) > 5 { 410 m.ocspStapling = true 411 m.certificate.OCSPStaple = randomBytes(rand.Intn(100)+1, rand) 412 } 413 if rand.Intn(10) > 5 { 414 m.scts = true 415 for i := 0; i < rand.Intn(2)+1; i++ { 416 m.certificate.SignedCertificateTimestamps = append( 417 m.certificate.SignedCertificateTimestamps, randomBytes(rand.Intn(500)+1, rand)) 418 } 419 } 420 return reflect.ValueOf(m) 421 } 422 423 func TestRejectEmptySCTList(t *testing.T) { 424 // RFC 6962, Section 3.3.1 specifies that empty SCT lists are invalid. 425 426 var random [32]byte 427 sct := []byte{0x42, 0x42, 0x42, 0x42} 428 serverHello := serverHelloMsg{ 429 vers: VersionTLS12, 430 random: random[:], 431 scts: [][]byte{sct}, 432 } 433 serverHelloBytes := serverHello.marshal() 434 435 var serverHelloCopy serverHelloMsg 436 if !serverHelloCopy.unmarshal(serverHelloBytes) { 437 t.Fatal("Failed to unmarshal initial message") 438 } 439 440 // Change serverHelloBytes so that the SCT list is empty 441 i := bytes.Index(serverHelloBytes, sct) 442 if i < 0 { 443 t.Fatal("Cannot find SCT in ServerHello") 444 } 445 446 var serverHelloEmptySCT []byte 447 serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[:i-6]...) 448 // Append the extension length and SCT list length for an empty list. 449 serverHelloEmptySCT = append(serverHelloEmptySCT, []byte{0, 2, 0, 0}...) 450 serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[i+4:]...) 451 452 // Update the handshake message length. 453 serverHelloEmptySCT[1] = byte((len(serverHelloEmptySCT) - 4) >> 16) 454 serverHelloEmptySCT[2] = byte((len(serverHelloEmptySCT) - 4) >> 8) 455 serverHelloEmptySCT[3] = byte(len(serverHelloEmptySCT) - 4) 456 457 // Update the extensions length 458 serverHelloEmptySCT[42] = byte((len(serverHelloEmptySCT) - 44) >> 8) 459 serverHelloEmptySCT[43] = byte((len(serverHelloEmptySCT) - 44)) 460 461 if serverHelloCopy.unmarshal(serverHelloEmptySCT) { 462 t.Fatal("Unmarshaled ServerHello with empty SCT list") 463 } 464 } 465 466 func TestRejectEmptySCT(t *testing.T) { 467 // Not only must the SCT list be non-empty, but the SCT elements must 468 // not be zero length. 469 470 var random [32]byte 471 serverHello := serverHelloMsg{ 472 vers: VersionTLS12, 473 random: random[:], 474 scts: [][]byte{nil}, 475 } 476 serverHelloBytes := serverHello.marshal() 477 478 var serverHelloCopy serverHelloMsg 479 if serverHelloCopy.unmarshal(serverHelloBytes) { 480 t.Fatal("Unmarshaled ServerHello with zero-length SCT") 481 } 482 }