github.com/JimmyHuang454/JLS-go@v0.0.0-20230831150107-90d536585ba0/tls/handshake_messages_test.go (about) 1 // Copyright 2009 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package tls 6 7 import ( 8 "bytes" 9 "crypto/x509" 10 "encoding/hex" 11 "math" 12 "math/rand" 13 "reflect" 14 "strings" 15 "testing" 16 "testing/quick" 17 "time" 18 ) 19 20 var tests = []handshakeMessage{ 21 &clientHelloMsg{}, 22 &serverHelloMsg{}, 23 &finishedMsg{}, 24 25 &certificateMsg{}, 26 &certificateRequestMsg{}, 27 &certificateVerifyMsg{ 28 hasSignatureAlgorithm: true, 29 }, 30 &certificateStatusMsg{}, 31 &clientKeyExchangeMsg{}, 32 &newSessionTicketMsg{}, 33 &encryptedExtensionsMsg{}, 34 &endOfEarlyDataMsg{}, 35 &keyUpdateMsg{}, 36 &newSessionTicketMsgTLS13{}, 37 &certificateRequestMsgTLS13{}, 38 &certificateMsgTLS13{}, 39 &SessionState{}, 40 } 41 42 func mustMarshal(t *testing.T, msg handshakeMessage) []byte { 43 t.Helper() 44 b, err := msg.marshal() 45 if err != nil { 46 t.Fatal(err) 47 } 48 return b 49 } 50 51 func TestMarshalUnmarshal(t *testing.T) { 52 rand := rand.New(rand.NewSource(time.Now().UnixNano())) 53 54 for i, m := range tests { 55 ty := reflect.ValueOf(m).Type() 56 57 n := 100 58 if testing.Short() { 59 n = 5 60 } 61 for j := 0; j < n; j++ { 62 v, ok := quick.Value(ty, rand) 63 if !ok { 64 t.Errorf("#%d: failed to create value", i) 65 break 66 } 67 68 m1 := v.Interface().(handshakeMessage) 69 marshaled := mustMarshal(t, m1) 70 if !m.unmarshal(marshaled) { 71 t.Errorf("#%d failed to unmarshal %#v %x", i, m1, marshaled) 72 break 73 } 74 m.marshal() // to fill any marshal cache in the message 75 76 if m, ok := m.(*SessionState); ok { 77 m.activeCertHandles = nil 78 } 79 80 if !reflect.DeepEqual(m1, m) { 81 t.Errorf("#%d got:%#v want:%#v %x", i, m, m1, marshaled) 82 break 83 } 84 85 if i >= 3 { 86 // The first three message types (ClientHello, 87 // ServerHello and Finished) are allowed to 88 // have parsable prefixes because the extension 89 // data is optional and the length of the 90 // Finished varies across versions. 91 for j := 0; j < len(marshaled); j++ { 92 if m.unmarshal(marshaled[0:j]) { 93 t.Errorf("#%d unmarshaled a prefix of length %d of %#v", i, j, m1) 94 break 95 } 96 } 97 } 98 } 99 } 100 } 101 102 func TestFuzz(t *testing.T) { 103 rand := rand.New(rand.NewSource(0)) 104 for _, m := range tests { 105 for j := 0; j < 1000; j++ { 106 len := rand.Intn(1000) 107 bytes := randomBytes(len, rand) 108 // This just looks for crashes due to bounds errors etc. 109 m.unmarshal(bytes) 110 } 111 } 112 } 113 114 func randomBytes(n int, rand *rand.Rand) []byte { 115 r := make([]byte, n) 116 if _, err := rand.Read(r); err != nil { 117 panic("rand.Read failed: " + err.Error()) 118 } 119 return r 120 } 121 122 func randomString(n int, rand *rand.Rand) string { 123 b := randomBytes(n, rand) 124 return string(b) 125 } 126 127 func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value { 128 m := &clientHelloMsg{} 129 m.vers = uint16(rand.Intn(65536)) 130 m.random = randomBytes(32, rand) 131 m.sessionId = randomBytes(rand.Intn(32), rand) 132 m.cipherSuites = make([]uint16, rand.Intn(63)+1) 133 for i := 0; i < len(m.cipherSuites); i++ { 134 cs := uint16(rand.Int31()) 135 if cs == scsvRenegotiation { 136 cs += 1 137 } 138 m.cipherSuites[i] = cs 139 } 140 m.compressionMethods = randomBytes(rand.Intn(63)+1, rand) 141 if rand.Intn(10) > 5 { 142 m.serverName = randomString(rand.Intn(255), rand) 143 for strings.HasSuffix(m.serverName, ".") { 144 m.serverName = m.serverName[:len(m.serverName)-1] 145 } 146 } 147 m.ocspStapling = rand.Intn(10) > 5 148 m.supportedPoints = randomBytes(rand.Intn(5)+1, rand) 149 m.supportedCurves = make([]CurveID, rand.Intn(5)+1) 150 for i := range m.supportedCurves { 151 m.supportedCurves[i] = CurveID(rand.Intn(30000) + 1) 152 } 153 if rand.Intn(10) > 5 { 154 m.ticketSupported = true 155 if rand.Intn(10) > 5 { 156 m.sessionTicket = randomBytes(rand.Intn(300), rand) 157 } else { 158 m.sessionTicket = make([]byte, 0) 159 } 160 } 161 if rand.Intn(10) > 5 { 162 m.supportedSignatureAlgorithms = supportedSignatureAlgorithms() 163 } 164 if rand.Intn(10) > 5 { 165 m.supportedSignatureAlgorithmsCert = supportedSignatureAlgorithms() 166 } 167 for i := 0; i < rand.Intn(5); i++ { 168 m.alpnProtocols = append(m.alpnProtocols, randomString(rand.Intn(20)+1, rand)) 169 } 170 if rand.Intn(10) > 5 { 171 m.scts = true 172 } 173 if rand.Intn(10) > 5 { 174 m.secureRenegotiationSupported = true 175 m.secureRenegotiation = randomBytes(rand.Intn(50)+1, rand) 176 } 177 if rand.Intn(10) > 5 { 178 m.extendedMasterSecret = true 179 } 180 for i := 0; i < rand.Intn(5); i++ { 181 m.supportedVersions = append(m.supportedVersions, uint16(rand.Intn(0xffff)+1)) 182 } 183 if rand.Intn(10) > 5 { 184 m.cookie = randomBytes(rand.Intn(500)+1, rand) 185 } 186 for i := 0; i < rand.Intn(5); i++ { 187 var ks keyShare 188 ks.group = CurveID(rand.Intn(30000) + 1) 189 ks.data = randomBytes(rand.Intn(200)+1, rand) 190 m.keyShares = append(m.keyShares, ks) 191 } 192 switch rand.Intn(3) { 193 case 1: 194 m.pskModes = []uint8{pskModeDHE} 195 case 2: 196 m.pskModes = []uint8{pskModeDHE, pskModePlain} 197 } 198 for i := 0; i < rand.Intn(5); i++ { 199 var psk pskIdentity 200 psk.obfuscatedTicketAge = uint32(rand.Intn(500000)) 201 psk.label = randomBytes(rand.Intn(500)+1, rand) 202 m.pskIdentities = append(m.pskIdentities, psk) 203 m.pskBinders = append(m.pskBinders, randomBytes(rand.Intn(50)+32, rand)) 204 } 205 if rand.Intn(10) > 5 { 206 m.quicTransportParameters = randomBytes(rand.Intn(500), rand) 207 } 208 if rand.Intn(10) > 5 { 209 m.earlyData = true 210 } 211 212 return reflect.ValueOf(m) 213 } 214 215 func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value { 216 m := &serverHelloMsg{} 217 m.vers = uint16(rand.Intn(65536)) 218 m.random = randomBytes(32, rand) 219 m.sessionId = randomBytes(rand.Intn(32), rand) 220 m.cipherSuite = uint16(rand.Int31()) 221 m.compressionMethod = uint8(rand.Intn(256)) 222 m.supportedPoints = randomBytes(rand.Intn(5)+1, rand) 223 224 if rand.Intn(10) > 5 { 225 m.ocspStapling = true 226 } 227 if rand.Intn(10) > 5 { 228 m.ticketSupported = true 229 } 230 if rand.Intn(10) > 5 { 231 m.alpnProtocol = randomString(rand.Intn(32)+1, rand) 232 } 233 234 for i := 0; i < rand.Intn(4); i++ { 235 m.scts = append(m.scts, randomBytes(rand.Intn(500)+1, rand)) 236 } 237 238 if rand.Intn(10) > 5 { 239 m.secureRenegotiationSupported = true 240 m.secureRenegotiation = randomBytes(rand.Intn(50)+1, rand) 241 } 242 if rand.Intn(10) > 5 { 243 m.extendedMasterSecret = true 244 } 245 if rand.Intn(10) > 5 { 246 m.supportedVersion = uint16(rand.Intn(0xffff) + 1) 247 } 248 if rand.Intn(10) > 5 { 249 m.cookie = randomBytes(rand.Intn(500)+1, rand) 250 } 251 if rand.Intn(10) > 5 { 252 for i := 0; i < rand.Intn(5); i++ { 253 m.serverShare.group = CurveID(rand.Intn(30000) + 1) 254 m.serverShare.data = randomBytes(rand.Intn(200)+1, rand) 255 } 256 } else if rand.Intn(10) > 5 { 257 m.selectedGroup = CurveID(rand.Intn(30000) + 1) 258 } 259 if rand.Intn(10) > 5 { 260 m.selectedIdentityPresent = true 261 m.selectedIdentity = uint16(rand.Intn(0xffff)) 262 } 263 264 return reflect.ValueOf(m) 265 } 266 267 func (*encryptedExtensionsMsg) Generate(rand *rand.Rand, size int) reflect.Value { 268 m := &encryptedExtensionsMsg{} 269 270 if rand.Intn(10) > 5 { 271 m.alpnProtocol = randomString(rand.Intn(32)+1, rand) 272 } 273 if rand.Intn(10) > 5 { 274 m.earlyData = true 275 } 276 277 return reflect.ValueOf(m) 278 } 279 280 func (*certificateMsg) Generate(rand *rand.Rand, size int) reflect.Value { 281 m := &certificateMsg{} 282 numCerts := rand.Intn(20) 283 m.certificates = make([][]byte, numCerts) 284 for i := 0; i < numCerts; i++ { 285 m.certificates[i] = randomBytes(rand.Intn(10)+1, rand) 286 } 287 return reflect.ValueOf(m) 288 } 289 290 func (*certificateRequestMsg) Generate(rand *rand.Rand, size int) reflect.Value { 291 m := &certificateRequestMsg{} 292 m.certificateTypes = randomBytes(rand.Intn(5)+1, rand) 293 for i := 0; i < rand.Intn(100); i++ { 294 m.certificateAuthorities = append(m.certificateAuthorities, randomBytes(rand.Intn(15)+1, rand)) 295 } 296 return reflect.ValueOf(m) 297 } 298 299 func (*certificateVerifyMsg) Generate(rand *rand.Rand, size int) reflect.Value { 300 m := &certificateVerifyMsg{} 301 m.hasSignatureAlgorithm = true 302 m.signatureAlgorithm = SignatureScheme(rand.Intn(30000)) 303 m.signature = randomBytes(rand.Intn(15)+1, rand) 304 return reflect.ValueOf(m) 305 } 306 307 func (*certificateStatusMsg) Generate(rand *rand.Rand, size int) reflect.Value { 308 m := &certificateStatusMsg{} 309 m.response = randomBytes(rand.Intn(10)+1, rand) 310 return reflect.ValueOf(m) 311 } 312 313 func (*clientKeyExchangeMsg) Generate(rand *rand.Rand, size int) reflect.Value { 314 m := &clientKeyExchangeMsg{} 315 m.ciphertext = randomBytes(rand.Intn(1000)+1, rand) 316 return reflect.ValueOf(m) 317 } 318 319 func (*finishedMsg) Generate(rand *rand.Rand, size int) reflect.Value { 320 m := &finishedMsg{} 321 m.verifyData = randomBytes(12, rand) 322 return reflect.ValueOf(m) 323 } 324 325 func (*newSessionTicketMsg) Generate(rand *rand.Rand, size int) reflect.Value { 326 m := &newSessionTicketMsg{} 327 m.ticket = randomBytes(rand.Intn(4), rand) 328 return reflect.ValueOf(m) 329 } 330 331 var sessionTestCerts []*x509.Certificate 332 333 func init() { 334 cert, err := x509.ParseCertificate(testRSACertificate) 335 if err != nil { 336 panic(err) 337 } 338 sessionTestCerts = append(sessionTestCerts, cert) 339 cert, err = x509.ParseCertificate(testRSACertificateIssuer) 340 if err != nil { 341 panic(err) 342 } 343 sessionTestCerts = append(sessionTestCerts, cert) 344 } 345 346 func (*SessionState) Generate(rand *rand.Rand, size int) reflect.Value { 347 s := &SessionState{} 348 isTLS13 := rand.Intn(10) > 5 349 if isTLS13 { 350 s.version = VersionTLS13 351 } else { 352 s.version = uint16(rand.Intn(VersionTLS13)) 353 } 354 s.isClient = rand.Intn(10) > 5 355 s.cipherSuite = uint16(rand.Intn(math.MaxUint16)) 356 s.createdAt = uint64(rand.Int63()) 357 s.secret = randomBytes(rand.Intn(100)+1, rand) 358 for n, i := rand.Intn(3), 0; i < n; i++ { 359 s.Extra = append(s.Extra, randomBytes(rand.Intn(100), rand)) 360 } 361 if rand.Intn(10) > 5 { 362 s.EarlyData = true 363 } 364 if rand.Intn(10) > 5 { 365 s.extMasterSecret = true 366 } 367 if s.isClient || rand.Intn(10) > 5 { 368 if rand.Intn(10) > 5 { 369 s.peerCertificates = sessionTestCerts 370 } else { 371 s.peerCertificates = sessionTestCerts[:1] 372 } 373 } 374 if rand.Intn(10) > 5 && s.peerCertificates != nil { 375 s.ocspResponse = randomBytes(rand.Intn(100)+1, rand) 376 } 377 if rand.Intn(10) > 5 && s.peerCertificates != nil { 378 for i := 0; i < rand.Intn(2)+1; i++ { 379 s.scts = append(s.scts, randomBytes(rand.Intn(500)+1, rand)) 380 } 381 } 382 if len(s.peerCertificates) > 0 { 383 for i := 0; i < rand.Intn(3); i++ { 384 if rand.Intn(10) > 5 { 385 s.verifiedChains = append(s.verifiedChains, s.peerCertificates) 386 } else { 387 s.verifiedChains = append(s.verifiedChains, s.peerCertificates[:1]) 388 } 389 } 390 } 391 if rand.Intn(10) > 5 && s.EarlyData { 392 s.alpnProtocol = string(randomBytes(rand.Intn(10), rand)) 393 } 394 if s.isClient { 395 if isTLS13 { 396 s.useBy = uint64(rand.Int63()) 397 s.ageAdd = uint32(rand.Int63() & math.MaxUint32) 398 } 399 } 400 return reflect.ValueOf(s) 401 } 402 403 func (s *SessionState) marshal() ([]byte, error) { return s.Bytes() } 404 func (s *SessionState) unmarshal(b []byte) bool { 405 ss, err := ParseSessionState(b) 406 if err != nil { 407 return false 408 } 409 *s = *ss 410 return true 411 } 412 413 func (*endOfEarlyDataMsg) Generate(rand *rand.Rand, size int) reflect.Value { 414 m := &endOfEarlyDataMsg{} 415 return reflect.ValueOf(m) 416 } 417 418 func (*keyUpdateMsg) Generate(rand *rand.Rand, size int) reflect.Value { 419 m := &keyUpdateMsg{} 420 m.updateRequested = rand.Intn(10) > 5 421 return reflect.ValueOf(m) 422 } 423 424 func (*newSessionTicketMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value { 425 m := &newSessionTicketMsgTLS13{} 426 m.lifetime = uint32(rand.Intn(500000)) 427 m.ageAdd = uint32(rand.Intn(500000)) 428 m.nonce = randomBytes(rand.Intn(100), rand) 429 m.label = randomBytes(rand.Intn(1000), rand) 430 if rand.Intn(10) > 5 { 431 m.maxEarlyData = uint32(rand.Intn(500000)) 432 } 433 return reflect.ValueOf(m) 434 } 435 436 func (*certificateRequestMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value { 437 m := &certificateRequestMsgTLS13{} 438 if rand.Intn(10) > 5 { 439 m.ocspStapling = true 440 } 441 if rand.Intn(10) > 5 { 442 m.scts = true 443 } 444 if rand.Intn(10) > 5 { 445 m.supportedSignatureAlgorithms = supportedSignatureAlgorithms() 446 } 447 if rand.Intn(10) > 5 { 448 m.supportedSignatureAlgorithmsCert = supportedSignatureAlgorithms() 449 } 450 if rand.Intn(10) > 5 { 451 m.certificateAuthorities = make([][]byte, 3) 452 for i := 0; i < 3; i++ { 453 m.certificateAuthorities[i] = randomBytes(rand.Intn(10)+1, rand) 454 } 455 } 456 return reflect.ValueOf(m) 457 } 458 459 func (*certificateMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value { 460 m := &certificateMsgTLS13{} 461 for i := 0; i < rand.Intn(2)+1; i++ { 462 m.certificate.Certificate = append( 463 m.certificate.Certificate, randomBytes(rand.Intn(500)+1, rand)) 464 } 465 if rand.Intn(10) > 5 { 466 m.ocspStapling = true 467 m.certificate.OCSPStaple = randomBytes(rand.Intn(100)+1, rand) 468 } 469 if rand.Intn(10) > 5 { 470 m.scts = true 471 for i := 0; i < rand.Intn(2)+1; i++ { 472 m.certificate.SignedCertificateTimestamps = append( 473 m.certificate.SignedCertificateTimestamps, randomBytes(rand.Intn(500)+1, rand)) 474 } 475 } 476 return reflect.ValueOf(m) 477 } 478 479 func TestRejectEmptySCTList(t *testing.T) { 480 // RFC 6962, Section 3.3.1 specifies that empty SCT lists are invalid. 481 482 var random [32]byte 483 sct := []byte{0x42, 0x42, 0x42, 0x42} 484 serverHello := &serverHelloMsg{ 485 vers: VersionTLS12, 486 random: random[:], 487 scts: [][]byte{sct}, 488 } 489 serverHelloBytes := mustMarshal(t, serverHello) 490 491 var serverHelloCopy serverHelloMsg 492 if !serverHelloCopy.unmarshal(serverHelloBytes) { 493 t.Fatal("Failed to unmarshal initial message") 494 } 495 496 // Change serverHelloBytes so that the SCT list is empty 497 i := bytes.Index(serverHelloBytes, sct) 498 if i < 0 { 499 t.Fatal("Cannot find SCT in ServerHello") 500 } 501 502 var serverHelloEmptySCT []byte 503 serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[:i-6]...) 504 // Append the extension length and SCT list length for an empty list. 505 serverHelloEmptySCT = append(serverHelloEmptySCT, []byte{0, 2, 0, 0}...) 506 serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[i+4:]...) 507 508 // Update the handshake message length. 509 serverHelloEmptySCT[1] = byte((len(serverHelloEmptySCT) - 4) >> 16) 510 serverHelloEmptySCT[2] = byte((len(serverHelloEmptySCT) - 4) >> 8) 511 serverHelloEmptySCT[3] = byte(len(serverHelloEmptySCT) - 4) 512 513 // Update the extensions length 514 serverHelloEmptySCT[42] = byte((len(serverHelloEmptySCT) - 44) >> 8) 515 serverHelloEmptySCT[43] = byte((len(serverHelloEmptySCT) - 44)) 516 517 if serverHelloCopy.unmarshal(serverHelloEmptySCT) { 518 t.Fatal("Unmarshaled ServerHello with empty SCT list") 519 } 520 } 521 522 func TestRejectEmptySCT(t *testing.T) { 523 // Not only must the SCT list be non-empty, but the SCT elements must 524 // not be zero length. 525 526 var random [32]byte 527 serverHello := &serverHelloMsg{ 528 vers: VersionTLS12, 529 random: random[:], 530 scts: [][]byte{nil}, 531 } 532 serverHelloBytes := mustMarshal(t, serverHello) 533 534 var serverHelloCopy serverHelloMsg 535 if serverHelloCopy.unmarshal(serverHelloBytes) { 536 t.Fatal("Unmarshaled ServerHello with zero-length SCT") 537 } 538 } 539 540 func TestRejectDuplicateExtensions(t *testing.T) { 541 clientHelloBytes, err := hex.DecodeString("010000440303000000000000000000000000000000000000000000000000000000000000000000000000001c0000000a000800000568656c6c6f0000000a000800000568656c6c6f") 542 if err != nil { 543 t.Fatalf("failed to decode test ClientHello: %s", err) 544 } 545 var clientHelloCopy clientHelloMsg 546 if clientHelloCopy.unmarshal(clientHelloBytes) { 547 t.Error("Unmarshaled ClientHello with duplicate extensions") 548 } 549 550 serverHelloBytes, err := hex.DecodeString("02000030030300000000000000000000000000000000000000000000000000000000000000000000000000080005000000050000") 551 if err != nil { 552 t.Fatalf("failed to decode test ServerHello: %s", err) 553 } 554 var serverHelloCopy serverHelloMsg 555 if serverHelloCopy.unmarshal(serverHelloBytes) { 556 t.Fatal("Unmarshaled ServerHello with duplicate extensions") 557 } 558 }