github.com/lovishpuri/go-40569/src@v0.0.0-20230519171745-f8623e7c56cf/crypto/tls/handshake_messages_test.go (about) 1 // Copyright 2009 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package tls 6 7 import ( 8 "bytes" 9 "encoding/hex" 10 "math/rand" 11 "reflect" 12 "strings" 13 "testing" 14 "testing/quick" 15 "time" 16 ) 17 18 var tests = []any{ 19 &clientHelloMsg{}, 20 &serverHelloMsg{}, 21 &finishedMsg{}, 22 23 &certificateMsg{}, 24 &certificateRequestMsg{}, 25 &certificateVerifyMsg{ 26 hasSignatureAlgorithm: true, 27 }, 28 &certificateStatusMsg{}, 29 &clientKeyExchangeMsg{}, 30 &newSessionTicketMsg{}, 31 &sessionState{}, 32 &sessionStateTLS13{}, 33 &encryptedExtensionsMsg{}, 34 &endOfEarlyDataMsg{}, 35 &keyUpdateMsg{}, 36 &newSessionTicketMsgTLS13{}, 37 &certificateRequestMsgTLS13{}, 38 &certificateMsgTLS13{}, 39 } 40 41 func mustMarshal(t *testing.T, msg handshakeMessage) []byte { 42 t.Helper() 43 b, err := msg.marshal() 44 if err != nil { 45 t.Fatal(err) 46 } 47 return b 48 } 49 50 func TestMarshalUnmarshal(t *testing.T) { 51 rand := rand.New(rand.NewSource(time.Now().UnixNano())) 52 53 for i, iface := range tests { 54 ty := reflect.ValueOf(iface).Type() 55 56 n := 100 57 if testing.Short() { 58 n = 5 59 } 60 for j := 0; j < n; j++ { 61 v, ok := quick.Value(ty, rand) 62 if !ok { 63 t.Errorf("#%d: failed to create value", i) 64 break 65 } 66 67 m1 := v.Interface().(handshakeMessage) 68 marshaled := mustMarshal(t, m1) 69 m2 := iface.(handshakeMessage) 70 if !m2.unmarshal(marshaled) { 71 t.Errorf("#%d failed to unmarshal %#v %x", i, m1, marshaled) 72 break 73 } 74 m2.marshal() // to fill any marshal cache in the message 75 76 if !reflect.DeepEqual(m1, m2) { 77 t.Errorf("#%d got:%#v want:%#v %x", i, m2, m1, marshaled) 78 break 79 } 80 81 if i >= 3 { 82 // The first three message types (ClientHello, 83 // ServerHello and Finished) are allowed to 84 // have parsable prefixes because the extension 85 // data is optional and the length of the 86 // Finished varies across versions. 87 for j := 0; j < len(marshaled); j++ { 88 if m2.unmarshal(marshaled[0:j]) { 89 t.Errorf("#%d unmarshaled a prefix of length %d of %#v", i, j, m1) 90 break 91 } 92 } 93 } 94 } 95 } 96 } 97 98 func TestFuzz(t *testing.T) { 99 rand := rand.New(rand.NewSource(0)) 100 for _, iface := range tests { 101 m := iface.(handshakeMessage) 102 103 for j := 0; j < 1000; j++ { 104 len := rand.Intn(100) 105 bytes := randomBytes(len, rand) 106 // This just looks for crashes due to bounds errors etc. 107 m.unmarshal(bytes) 108 } 109 } 110 } 111 112 func randomBytes(n int, rand *rand.Rand) []byte { 113 r := make([]byte, n) 114 if _, err := rand.Read(r); err != nil { 115 panic("rand.Read failed: " + err.Error()) 116 } 117 return r 118 } 119 120 func randomString(n int, rand *rand.Rand) string { 121 b := randomBytes(n, rand) 122 return string(b) 123 } 124 125 func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value { 126 m := &clientHelloMsg{} 127 m.vers = uint16(rand.Intn(65536)) 128 m.random = randomBytes(32, rand) 129 m.sessionId = randomBytes(rand.Intn(32), rand) 130 m.cipherSuites = make([]uint16, rand.Intn(63)+1) 131 for i := 0; i < len(m.cipherSuites); i++ { 132 cs := uint16(rand.Int31()) 133 if cs == scsvRenegotiation { 134 cs += 1 135 } 136 m.cipherSuites[i] = cs 137 } 138 m.compressionMethods = randomBytes(rand.Intn(63)+1, rand) 139 if rand.Intn(10) > 5 { 140 m.serverName = randomString(rand.Intn(255), rand) 141 for strings.HasSuffix(m.serverName, ".") { 142 m.serverName = m.serverName[:len(m.serverName)-1] 143 } 144 } 145 m.ocspStapling = rand.Intn(10) > 5 146 m.supportedPoints = randomBytes(rand.Intn(5)+1, rand) 147 m.supportedCurves = make([]CurveID, rand.Intn(5)+1) 148 for i := range m.supportedCurves { 149 m.supportedCurves[i] = CurveID(rand.Intn(30000) + 1) 150 } 151 if rand.Intn(10) > 5 { 152 m.ticketSupported = true 153 if rand.Intn(10) > 5 { 154 m.sessionTicket = randomBytes(rand.Intn(300), rand) 155 } else { 156 m.sessionTicket = make([]byte, 0) 157 } 158 } 159 if rand.Intn(10) > 5 { 160 m.supportedSignatureAlgorithms = supportedSignatureAlgorithms() 161 } 162 if rand.Intn(10) > 5 { 163 m.supportedSignatureAlgorithmsCert = supportedSignatureAlgorithms() 164 } 165 for i := 0; i < rand.Intn(5); i++ { 166 m.alpnProtocols = append(m.alpnProtocols, randomString(rand.Intn(20)+1, rand)) 167 } 168 if rand.Intn(10) > 5 { 169 m.scts = true 170 } 171 if rand.Intn(10) > 5 { 172 m.secureRenegotiationSupported = true 173 m.secureRenegotiation = randomBytes(rand.Intn(50)+1, rand) 174 } 175 for i := 0; i < rand.Intn(5); i++ { 176 m.supportedVersions = append(m.supportedVersions, uint16(rand.Intn(0xffff)+1)) 177 } 178 if rand.Intn(10) > 5 { 179 m.cookie = randomBytes(rand.Intn(500)+1, rand) 180 } 181 for i := 0; i < rand.Intn(5); i++ { 182 var ks keyShare 183 ks.group = CurveID(rand.Intn(30000) + 1) 184 ks.data = randomBytes(rand.Intn(200)+1, rand) 185 m.keyShares = append(m.keyShares, ks) 186 } 187 switch rand.Intn(3) { 188 case 1: 189 m.pskModes = []uint8{pskModeDHE} 190 case 2: 191 m.pskModes = []uint8{pskModeDHE, pskModePlain} 192 } 193 for i := 0; i < rand.Intn(5); i++ { 194 var psk pskIdentity 195 psk.obfuscatedTicketAge = uint32(rand.Intn(500000)) 196 psk.label = randomBytes(rand.Intn(500)+1, rand) 197 m.pskIdentities = append(m.pskIdentities, psk) 198 m.pskBinders = append(m.pskBinders, randomBytes(rand.Intn(50)+32, rand)) 199 } 200 if rand.Intn(10) > 5 { 201 m.earlyData = true 202 } 203 204 return reflect.ValueOf(m) 205 } 206 207 func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value { 208 m := &serverHelloMsg{} 209 m.vers = uint16(rand.Intn(65536)) 210 m.random = randomBytes(32, rand) 211 m.sessionId = randomBytes(rand.Intn(32), rand) 212 m.cipherSuite = uint16(rand.Int31()) 213 m.compressionMethod = uint8(rand.Intn(256)) 214 m.supportedPoints = randomBytes(rand.Intn(5)+1, rand) 215 216 if rand.Intn(10) > 5 { 217 m.ocspStapling = true 218 } 219 if rand.Intn(10) > 5 { 220 m.ticketSupported = true 221 } 222 if rand.Intn(10) > 5 { 223 m.alpnProtocol = randomString(rand.Intn(32)+1, rand) 224 } 225 226 for i := 0; i < rand.Intn(4); i++ { 227 m.scts = append(m.scts, randomBytes(rand.Intn(500)+1, rand)) 228 } 229 230 if rand.Intn(10) > 5 { 231 m.secureRenegotiationSupported = true 232 m.secureRenegotiation = randomBytes(rand.Intn(50)+1, rand) 233 } 234 if rand.Intn(10) > 5 { 235 m.supportedVersion = uint16(rand.Intn(0xffff) + 1) 236 } 237 if rand.Intn(10) > 5 { 238 m.cookie = randomBytes(rand.Intn(500)+1, rand) 239 } 240 if rand.Intn(10) > 5 { 241 for i := 0; i < rand.Intn(5); i++ { 242 m.serverShare.group = CurveID(rand.Intn(30000) + 1) 243 m.serverShare.data = randomBytes(rand.Intn(200)+1, rand) 244 } 245 } else if rand.Intn(10) > 5 { 246 m.selectedGroup = CurveID(rand.Intn(30000) + 1) 247 } 248 if rand.Intn(10) > 5 { 249 m.selectedIdentityPresent = true 250 m.selectedIdentity = uint16(rand.Intn(0xffff)) 251 } 252 253 return reflect.ValueOf(m) 254 } 255 256 func (*encryptedExtensionsMsg) Generate(rand *rand.Rand, size int) reflect.Value { 257 m := &encryptedExtensionsMsg{} 258 259 if rand.Intn(10) > 5 { 260 m.alpnProtocol = randomString(rand.Intn(32)+1, rand) 261 } 262 263 return reflect.ValueOf(m) 264 } 265 266 func (*certificateMsg) Generate(rand *rand.Rand, size int) reflect.Value { 267 m := &certificateMsg{} 268 numCerts := rand.Intn(20) 269 m.certificates = make([][]byte, numCerts) 270 for i := 0; i < numCerts; i++ { 271 m.certificates[i] = randomBytes(rand.Intn(10)+1, rand) 272 } 273 return reflect.ValueOf(m) 274 } 275 276 func (*certificateRequestMsg) Generate(rand *rand.Rand, size int) reflect.Value { 277 m := &certificateRequestMsg{} 278 m.certificateTypes = randomBytes(rand.Intn(5)+1, rand) 279 for i := 0; i < rand.Intn(100); i++ { 280 m.certificateAuthorities = append(m.certificateAuthorities, randomBytes(rand.Intn(15)+1, rand)) 281 } 282 return reflect.ValueOf(m) 283 } 284 285 func (*certificateVerifyMsg) Generate(rand *rand.Rand, size int) reflect.Value { 286 m := &certificateVerifyMsg{} 287 m.hasSignatureAlgorithm = true 288 m.signatureAlgorithm = SignatureScheme(rand.Intn(30000)) 289 m.signature = randomBytes(rand.Intn(15)+1, rand) 290 return reflect.ValueOf(m) 291 } 292 293 func (*certificateStatusMsg) Generate(rand *rand.Rand, size int) reflect.Value { 294 m := &certificateStatusMsg{} 295 m.response = randomBytes(rand.Intn(10)+1, rand) 296 return reflect.ValueOf(m) 297 } 298 299 func (*clientKeyExchangeMsg) Generate(rand *rand.Rand, size int) reflect.Value { 300 m := &clientKeyExchangeMsg{} 301 m.ciphertext = randomBytes(rand.Intn(1000)+1, rand) 302 return reflect.ValueOf(m) 303 } 304 305 func (*finishedMsg) Generate(rand *rand.Rand, size int) reflect.Value { 306 m := &finishedMsg{} 307 m.verifyData = randomBytes(12, rand) 308 return reflect.ValueOf(m) 309 } 310 311 func (*newSessionTicketMsg) Generate(rand *rand.Rand, size int) reflect.Value { 312 m := &newSessionTicketMsg{} 313 m.ticket = randomBytes(rand.Intn(4), rand) 314 return reflect.ValueOf(m) 315 } 316 317 func (*sessionState) Generate(rand *rand.Rand, size int) reflect.Value { 318 s := &sessionState{} 319 s.vers = uint16(rand.Intn(10000)) 320 s.cipherSuite = uint16(rand.Intn(10000)) 321 s.masterSecret = randomBytes(rand.Intn(100)+1, rand) 322 s.createdAt = uint64(rand.Int63()) 323 for i := 0; i < rand.Intn(20); i++ { 324 s.certificates = append(s.certificates, randomBytes(rand.Intn(500)+1, rand)) 325 } 326 return reflect.ValueOf(s) 327 } 328 329 func (*sessionStateTLS13) Generate(rand *rand.Rand, size int) reflect.Value { 330 s := &sessionStateTLS13{} 331 s.cipherSuite = uint16(rand.Intn(10000)) 332 s.resumptionSecret = randomBytes(rand.Intn(100)+1, rand) 333 s.createdAt = uint64(rand.Int63()) 334 for i := 0; i < rand.Intn(2)+1; i++ { 335 s.certificate.Certificate = append( 336 s.certificate.Certificate, randomBytes(rand.Intn(500)+1, rand)) 337 } 338 if rand.Intn(10) > 5 { 339 s.certificate.OCSPStaple = randomBytes(rand.Intn(100)+1, rand) 340 } 341 if rand.Intn(10) > 5 { 342 for i := 0; i < rand.Intn(2)+1; i++ { 343 s.certificate.SignedCertificateTimestamps = append( 344 s.certificate.SignedCertificateTimestamps, randomBytes(rand.Intn(500)+1, rand)) 345 } 346 } 347 return reflect.ValueOf(s) 348 } 349 350 func (*endOfEarlyDataMsg) Generate(rand *rand.Rand, size int) reflect.Value { 351 m := &endOfEarlyDataMsg{} 352 return reflect.ValueOf(m) 353 } 354 355 func (*keyUpdateMsg) Generate(rand *rand.Rand, size int) reflect.Value { 356 m := &keyUpdateMsg{} 357 m.updateRequested = rand.Intn(10) > 5 358 return reflect.ValueOf(m) 359 } 360 361 func (*newSessionTicketMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value { 362 m := &newSessionTicketMsgTLS13{} 363 m.lifetime = uint32(rand.Intn(500000)) 364 m.ageAdd = uint32(rand.Intn(500000)) 365 m.nonce = randomBytes(rand.Intn(100), rand) 366 m.label = randomBytes(rand.Intn(1000), rand) 367 if rand.Intn(10) > 5 { 368 m.maxEarlyData = uint32(rand.Intn(500000)) 369 } 370 return reflect.ValueOf(m) 371 } 372 373 func (*certificateRequestMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value { 374 m := &certificateRequestMsgTLS13{} 375 if rand.Intn(10) > 5 { 376 m.ocspStapling = true 377 } 378 if rand.Intn(10) > 5 { 379 m.scts = true 380 } 381 if rand.Intn(10) > 5 { 382 m.supportedSignatureAlgorithms = supportedSignatureAlgorithms() 383 } 384 if rand.Intn(10) > 5 { 385 m.supportedSignatureAlgorithmsCert = supportedSignatureAlgorithms() 386 } 387 if rand.Intn(10) > 5 { 388 m.certificateAuthorities = make([][]byte, 3) 389 for i := 0; i < 3; i++ { 390 m.certificateAuthorities[i] = randomBytes(rand.Intn(10)+1, rand) 391 } 392 } 393 return reflect.ValueOf(m) 394 } 395 396 func (*certificateMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value { 397 m := &certificateMsgTLS13{} 398 for i := 0; i < rand.Intn(2)+1; i++ { 399 m.certificate.Certificate = append( 400 m.certificate.Certificate, randomBytes(rand.Intn(500)+1, rand)) 401 } 402 if rand.Intn(10) > 5 { 403 m.ocspStapling = true 404 m.certificate.OCSPStaple = randomBytes(rand.Intn(100)+1, rand) 405 } 406 if rand.Intn(10) > 5 { 407 m.scts = true 408 for i := 0; i < rand.Intn(2)+1; i++ { 409 m.certificate.SignedCertificateTimestamps = append( 410 m.certificate.SignedCertificateTimestamps, randomBytes(rand.Intn(500)+1, rand)) 411 } 412 } 413 return reflect.ValueOf(m) 414 } 415 416 func TestRejectEmptySCTList(t *testing.T) { 417 // RFC 6962, Section 3.3.1 specifies that empty SCT lists are invalid. 418 419 var random [32]byte 420 sct := []byte{0x42, 0x42, 0x42, 0x42} 421 serverHello := &serverHelloMsg{ 422 vers: VersionTLS12, 423 random: random[:], 424 scts: [][]byte{sct}, 425 } 426 serverHelloBytes := mustMarshal(t, serverHello) 427 428 var serverHelloCopy serverHelloMsg 429 if !serverHelloCopy.unmarshal(serverHelloBytes) { 430 t.Fatal("Failed to unmarshal initial message") 431 } 432 433 // Change serverHelloBytes so that the SCT list is empty 434 i := bytes.Index(serverHelloBytes, sct) 435 if i < 0 { 436 t.Fatal("Cannot find SCT in ServerHello") 437 } 438 439 var serverHelloEmptySCT []byte 440 serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[:i-6]...) 441 // Append the extension length and SCT list length for an empty list. 442 serverHelloEmptySCT = append(serverHelloEmptySCT, []byte{0, 2, 0, 0}...) 443 serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[i+4:]...) 444 445 // Update the handshake message length. 446 serverHelloEmptySCT[1] = byte((len(serverHelloEmptySCT) - 4) >> 16) 447 serverHelloEmptySCT[2] = byte((len(serverHelloEmptySCT) - 4) >> 8) 448 serverHelloEmptySCT[3] = byte(len(serverHelloEmptySCT) - 4) 449 450 // Update the extensions length 451 serverHelloEmptySCT[42] = byte((len(serverHelloEmptySCT) - 44) >> 8) 452 serverHelloEmptySCT[43] = byte((len(serverHelloEmptySCT) - 44)) 453 454 if serverHelloCopy.unmarshal(serverHelloEmptySCT) { 455 t.Fatal("Unmarshaled ServerHello with empty SCT list") 456 } 457 } 458 459 func TestRejectEmptySCT(t *testing.T) { 460 // Not only must the SCT list be non-empty, but the SCT elements must 461 // not be zero length. 462 463 var random [32]byte 464 serverHello := &serverHelloMsg{ 465 vers: VersionTLS12, 466 random: random[:], 467 scts: [][]byte{nil}, 468 } 469 serverHelloBytes := mustMarshal(t, serverHello) 470 471 var serverHelloCopy serverHelloMsg 472 if serverHelloCopy.unmarshal(serverHelloBytes) { 473 t.Fatal("Unmarshaled ServerHello with zero-length SCT") 474 } 475 } 476 477 func TestRejectDuplicateExtensions(t *testing.T) { 478 clientHelloBytes, err := hex.DecodeString("010000440303000000000000000000000000000000000000000000000000000000000000000000000000001c0000000a000800000568656c6c6f0000000a000800000568656c6c6f") 479 if err != nil { 480 t.Fatalf("failed to decode test ClientHello: %s", err) 481 } 482 var clientHelloCopy clientHelloMsg 483 if clientHelloCopy.unmarshal(clientHelloBytes) { 484 t.Error("Unmarshaled ClientHello with duplicate extensions") 485 } 486 487 serverHelloBytes, err := hex.DecodeString("02000030030300000000000000000000000000000000000000000000000000000000000000000000000000080005000000050000") 488 if err != nil { 489 t.Fatalf("failed to decode test ServerHello: %s", err) 490 } 491 var serverHelloCopy serverHelloMsg 492 if serverHelloCopy.unmarshal(serverHelloBytes) { 493 t.Fatal("Unmarshaled ServerHello with duplicate extensions") 494 } 495 }