github.com/mtsmfm/go/src@v0.0.0-20221020090648-44bdcb9f8fde/crypto/tls/handshake_messages_test.go (about) 1 // Copyright 2009 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package tls 6 7 import ( 8 "bytes" 9 "encoding/hex" 10 "math/rand" 11 "reflect" 12 "strings" 13 "testing" 14 "testing/quick" 15 "time" 16 ) 17 18 var tests = []any{ 19 &clientHelloMsg{}, 20 &serverHelloMsg{}, 21 &finishedMsg{}, 22 23 &certificateMsg{}, 24 &certificateRequestMsg{}, 25 &certificateVerifyMsg{ 26 hasSignatureAlgorithm: true, 27 }, 28 &certificateStatusMsg{}, 29 &clientKeyExchangeMsg{}, 30 &newSessionTicketMsg{}, 31 &sessionState{}, 32 &sessionStateTLS13{}, 33 &encryptedExtensionsMsg{}, 34 &endOfEarlyDataMsg{}, 35 &keyUpdateMsg{}, 36 &newSessionTicketMsgTLS13{}, 37 &certificateRequestMsgTLS13{}, 38 &certificateMsgTLS13{}, 39 } 40 41 func TestMarshalUnmarshal(t *testing.T) { 42 rand := rand.New(rand.NewSource(time.Now().UnixNano())) 43 44 for i, iface := range tests { 45 ty := reflect.ValueOf(iface).Type() 46 47 n := 100 48 if testing.Short() { 49 n = 5 50 } 51 for j := 0; j < n; j++ { 52 v, ok := quick.Value(ty, rand) 53 if !ok { 54 t.Errorf("#%d: failed to create value", i) 55 break 56 } 57 58 m1 := v.Interface().(handshakeMessage) 59 marshaled := m1.marshal() 60 m2 := iface.(handshakeMessage) 61 if !m2.unmarshal(marshaled) { 62 t.Errorf("#%d failed to unmarshal %#v %x", i, m1, marshaled) 63 break 64 } 65 m2.marshal() // to fill any marshal cache in the message 66 67 if !reflect.DeepEqual(m1, m2) { 68 t.Errorf("#%d got:%#v want:%#v %x", i, m2, m1, marshaled) 69 break 70 } 71 72 if i >= 3 { 73 // The first three message types (ClientHello, 74 // ServerHello and Finished) are allowed to 75 // have parsable prefixes because the extension 76 // data is optional and the length of the 77 // Finished varies across versions. 78 for j := 0; j < len(marshaled); j++ { 79 if m2.unmarshal(marshaled[0:j]) { 80 t.Errorf("#%d unmarshaled a prefix of length %d of %#v", i, j, m1) 81 break 82 } 83 } 84 } 85 } 86 } 87 } 88 89 func TestFuzz(t *testing.T) { 90 rand := rand.New(rand.NewSource(0)) 91 for _, iface := range tests { 92 m := iface.(handshakeMessage) 93 94 for j := 0; j < 1000; j++ { 95 len := rand.Intn(100) 96 bytes := randomBytes(len, rand) 97 // This just looks for crashes due to bounds errors etc. 98 m.unmarshal(bytes) 99 } 100 } 101 } 102 103 func randomBytes(n int, rand *rand.Rand) []byte { 104 r := make([]byte, n) 105 if _, err := rand.Read(r); err != nil { 106 panic("rand.Read failed: " + err.Error()) 107 } 108 return r 109 } 110 111 func randomString(n int, rand *rand.Rand) string { 112 b := randomBytes(n, rand) 113 return string(b) 114 } 115 116 func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value { 117 m := &clientHelloMsg{} 118 m.vers = uint16(rand.Intn(65536)) 119 m.random = randomBytes(32, rand) 120 m.sessionId = randomBytes(rand.Intn(32), rand) 121 m.cipherSuites = make([]uint16, rand.Intn(63)+1) 122 for i := 0; i < len(m.cipherSuites); i++ { 123 cs := uint16(rand.Int31()) 124 if cs == scsvRenegotiation { 125 cs += 1 126 } 127 m.cipherSuites[i] = cs 128 } 129 m.compressionMethods = randomBytes(rand.Intn(63)+1, rand) 130 if rand.Intn(10) > 5 { 131 m.serverName = randomString(rand.Intn(255), rand) 132 for strings.HasSuffix(m.serverName, ".") { 133 m.serverName = m.serverName[:len(m.serverName)-1] 134 } 135 } 136 m.ocspStapling = rand.Intn(10) > 5 137 m.supportedPoints = randomBytes(rand.Intn(5)+1, rand) 138 m.supportedCurves = make([]CurveID, rand.Intn(5)+1) 139 for i := range m.supportedCurves { 140 m.supportedCurves[i] = CurveID(rand.Intn(30000) + 1) 141 } 142 if rand.Intn(10) > 5 { 143 m.ticketSupported = true 144 if rand.Intn(10) > 5 { 145 m.sessionTicket = randomBytes(rand.Intn(300), rand) 146 } else { 147 m.sessionTicket = make([]byte, 0) 148 } 149 } 150 if rand.Intn(10) > 5 { 151 m.supportedSignatureAlgorithms = supportedSignatureAlgorithms() 152 } 153 if rand.Intn(10) > 5 { 154 m.supportedSignatureAlgorithmsCert = supportedSignatureAlgorithms() 155 } 156 for i := 0; i < rand.Intn(5); i++ { 157 m.alpnProtocols = append(m.alpnProtocols, randomString(rand.Intn(20)+1, rand)) 158 } 159 if rand.Intn(10) > 5 { 160 m.scts = true 161 } 162 if rand.Intn(10) > 5 { 163 m.secureRenegotiationSupported = true 164 m.secureRenegotiation = randomBytes(rand.Intn(50)+1, rand) 165 } 166 for i := 0; i < rand.Intn(5); i++ { 167 m.supportedVersions = append(m.supportedVersions, uint16(rand.Intn(0xffff)+1)) 168 } 169 if rand.Intn(10) > 5 { 170 m.cookie = randomBytes(rand.Intn(500)+1, rand) 171 } 172 for i := 0; i < rand.Intn(5); i++ { 173 var ks keyShare 174 ks.group = CurveID(rand.Intn(30000) + 1) 175 ks.data = randomBytes(rand.Intn(200)+1, rand) 176 m.keyShares = append(m.keyShares, ks) 177 } 178 switch rand.Intn(3) { 179 case 1: 180 m.pskModes = []uint8{pskModeDHE} 181 case 2: 182 m.pskModes = []uint8{pskModeDHE, pskModePlain} 183 } 184 for i := 0; i < rand.Intn(5); i++ { 185 var psk pskIdentity 186 psk.obfuscatedTicketAge = uint32(rand.Intn(500000)) 187 psk.label = randomBytes(rand.Intn(500)+1, rand) 188 m.pskIdentities = append(m.pskIdentities, psk) 189 m.pskBinders = append(m.pskBinders, randomBytes(rand.Intn(50)+32, rand)) 190 } 191 if rand.Intn(10) > 5 { 192 m.earlyData = true 193 } 194 195 return reflect.ValueOf(m) 196 } 197 198 func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value { 199 m := &serverHelloMsg{} 200 m.vers = uint16(rand.Intn(65536)) 201 m.random = randomBytes(32, rand) 202 m.sessionId = randomBytes(rand.Intn(32), rand) 203 m.cipherSuite = uint16(rand.Int31()) 204 m.compressionMethod = uint8(rand.Intn(256)) 205 m.supportedPoints = randomBytes(rand.Intn(5)+1, rand) 206 207 if rand.Intn(10) > 5 { 208 m.ocspStapling = true 209 } 210 if rand.Intn(10) > 5 { 211 m.ticketSupported = true 212 } 213 if rand.Intn(10) > 5 { 214 m.alpnProtocol = randomString(rand.Intn(32)+1, rand) 215 } 216 217 for i := 0; i < rand.Intn(4); i++ { 218 m.scts = append(m.scts, randomBytes(rand.Intn(500)+1, rand)) 219 } 220 221 if rand.Intn(10) > 5 { 222 m.secureRenegotiationSupported = true 223 m.secureRenegotiation = randomBytes(rand.Intn(50)+1, rand) 224 } 225 if rand.Intn(10) > 5 { 226 m.supportedVersion = uint16(rand.Intn(0xffff) + 1) 227 } 228 if rand.Intn(10) > 5 { 229 m.cookie = randomBytes(rand.Intn(500)+1, rand) 230 } 231 if rand.Intn(10) > 5 { 232 for i := 0; i < rand.Intn(5); i++ { 233 m.serverShare.group = CurveID(rand.Intn(30000) + 1) 234 m.serverShare.data = randomBytes(rand.Intn(200)+1, rand) 235 } 236 } else if rand.Intn(10) > 5 { 237 m.selectedGroup = CurveID(rand.Intn(30000) + 1) 238 } 239 if rand.Intn(10) > 5 { 240 m.selectedIdentityPresent = true 241 m.selectedIdentity = uint16(rand.Intn(0xffff)) 242 } 243 244 return reflect.ValueOf(m) 245 } 246 247 func (*encryptedExtensionsMsg) Generate(rand *rand.Rand, size int) reflect.Value { 248 m := &encryptedExtensionsMsg{} 249 250 if rand.Intn(10) > 5 { 251 m.alpnProtocol = randomString(rand.Intn(32)+1, rand) 252 } 253 254 return reflect.ValueOf(m) 255 } 256 257 func (*certificateMsg) Generate(rand *rand.Rand, size int) reflect.Value { 258 m := &certificateMsg{} 259 numCerts := rand.Intn(20) 260 m.certificates = make([][]byte, numCerts) 261 for i := 0; i < numCerts; i++ { 262 m.certificates[i] = randomBytes(rand.Intn(10)+1, rand) 263 } 264 return reflect.ValueOf(m) 265 } 266 267 func (*certificateRequestMsg) Generate(rand *rand.Rand, size int) reflect.Value { 268 m := &certificateRequestMsg{} 269 m.certificateTypes = randomBytes(rand.Intn(5)+1, rand) 270 for i := 0; i < rand.Intn(100); i++ { 271 m.certificateAuthorities = append(m.certificateAuthorities, randomBytes(rand.Intn(15)+1, rand)) 272 } 273 return reflect.ValueOf(m) 274 } 275 276 func (*certificateVerifyMsg) Generate(rand *rand.Rand, size int) reflect.Value { 277 m := &certificateVerifyMsg{} 278 m.hasSignatureAlgorithm = true 279 m.signatureAlgorithm = SignatureScheme(rand.Intn(30000)) 280 m.signature = randomBytes(rand.Intn(15)+1, rand) 281 return reflect.ValueOf(m) 282 } 283 284 func (*certificateStatusMsg) Generate(rand *rand.Rand, size int) reflect.Value { 285 m := &certificateStatusMsg{} 286 m.response = randomBytes(rand.Intn(10)+1, rand) 287 return reflect.ValueOf(m) 288 } 289 290 func (*clientKeyExchangeMsg) Generate(rand *rand.Rand, size int) reflect.Value { 291 m := &clientKeyExchangeMsg{} 292 m.ciphertext = randomBytes(rand.Intn(1000)+1, rand) 293 return reflect.ValueOf(m) 294 } 295 296 func (*finishedMsg) Generate(rand *rand.Rand, size int) reflect.Value { 297 m := &finishedMsg{} 298 m.verifyData = randomBytes(12, rand) 299 return reflect.ValueOf(m) 300 } 301 302 func (*newSessionTicketMsg) Generate(rand *rand.Rand, size int) reflect.Value { 303 m := &newSessionTicketMsg{} 304 m.ticket = randomBytes(rand.Intn(4), rand) 305 return reflect.ValueOf(m) 306 } 307 308 func (*sessionState) Generate(rand *rand.Rand, size int) reflect.Value { 309 s := &sessionState{} 310 s.vers = uint16(rand.Intn(10000)) 311 s.cipherSuite = uint16(rand.Intn(10000)) 312 s.masterSecret = randomBytes(rand.Intn(100)+1, rand) 313 s.createdAt = uint64(rand.Int63()) 314 for i := 0; i < rand.Intn(20); i++ { 315 s.certificates = append(s.certificates, randomBytes(rand.Intn(500)+1, rand)) 316 } 317 return reflect.ValueOf(s) 318 } 319 320 func (*sessionStateTLS13) Generate(rand *rand.Rand, size int) reflect.Value { 321 s := &sessionStateTLS13{} 322 s.cipherSuite = uint16(rand.Intn(10000)) 323 s.resumptionSecret = randomBytes(rand.Intn(100)+1, rand) 324 s.createdAt = uint64(rand.Int63()) 325 for i := 0; i < rand.Intn(2)+1; i++ { 326 s.certificate.Certificate = append( 327 s.certificate.Certificate, randomBytes(rand.Intn(500)+1, rand)) 328 } 329 if rand.Intn(10) > 5 { 330 s.certificate.OCSPStaple = randomBytes(rand.Intn(100)+1, rand) 331 } 332 if rand.Intn(10) > 5 { 333 for i := 0; i < rand.Intn(2)+1; i++ { 334 s.certificate.SignedCertificateTimestamps = append( 335 s.certificate.SignedCertificateTimestamps, randomBytes(rand.Intn(500)+1, rand)) 336 } 337 } 338 return reflect.ValueOf(s) 339 } 340 341 func (*endOfEarlyDataMsg) Generate(rand *rand.Rand, size int) reflect.Value { 342 m := &endOfEarlyDataMsg{} 343 return reflect.ValueOf(m) 344 } 345 346 func (*keyUpdateMsg) Generate(rand *rand.Rand, size int) reflect.Value { 347 m := &keyUpdateMsg{} 348 m.updateRequested = rand.Intn(10) > 5 349 return reflect.ValueOf(m) 350 } 351 352 func (*newSessionTicketMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value { 353 m := &newSessionTicketMsgTLS13{} 354 m.lifetime = uint32(rand.Intn(500000)) 355 m.ageAdd = uint32(rand.Intn(500000)) 356 m.nonce = randomBytes(rand.Intn(100), rand) 357 m.label = randomBytes(rand.Intn(1000), rand) 358 if rand.Intn(10) > 5 { 359 m.maxEarlyData = uint32(rand.Intn(500000)) 360 } 361 return reflect.ValueOf(m) 362 } 363 364 func (*certificateRequestMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value { 365 m := &certificateRequestMsgTLS13{} 366 if rand.Intn(10) > 5 { 367 m.ocspStapling = true 368 } 369 if rand.Intn(10) > 5 { 370 m.scts = true 371 } 372 if rand.Intn(10) > 5 { 373 m.supportedSignatureAlgorithms = supportedSignatureAlgorithms() 374 } 375 if rand.Intn(10) > 5 { 376 m.supportedSignatureAlgorithmsCert = supportedSignatureAlgorithms() 377 } 378 if rand.Intn(10) > 5 { 379 m.certificateAuthorities = make([][]byte, 3) 380 for i := 0; i < 3; i++ { 381 m.certificateAuthorities[i] = randomBytes(rand.Intn(10)+1, rand) 382 } 383 } 384 return reflect.ValueOf(m) 385 } 386 387 func (*certificateMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value { 388 m := &certificateMsgTLS13{} 389 for i := 0; i < rand.Intn(2)+1; i++ { 390 m.certificate.Certificate = append( 391 m.certificate.Certificate, randomBytes(rand.Intn(500)+1, rand)) 392 } 393 if rand.Intn(10) > 5 { 394 m.ocspStapling = true 395 m.certificate.OCSPStaple = randomBytes(rand.Intn(100)+1, rand) 396 } 397 if rand.Intn(10) > 5 { 398 m.scts = true 399 for i := 0; i < rand.Intn(2)+1; i++ { 400 m.certificate.SignedCertificateTimestamps = append( 401 m.certificate.SignedCertificateTimestamps, randomBytes(rand.Intn(500)+1, rand)) 402 } 403 } 404 return reflect.ValueOf(m) 405 } 406 407 func TestRejectEmptySCTList(t *testing.T) { 408 // RFC 6962, Section 3.3.1 specifies that empty SCT lists are invalid. 409 410 var random [32]byte 411 sct := []byte{0x42, 0x42, 0x42, 0x42} 412 serverHello := serverHelloMsg{ 413 vers: VersionTLS12, 414 random: random[:], 415 scts: [][]byte{sct}, 416 } 417 serverHelloBytes := serverHello.marshal() 418 419 var serverHelloCopy serverHelloMsg 420 if !serverHelloCopy.unmarshal(serverHelloBytes) { 421 t.Fatal("Failed to unmarshal initial message") 422 } 423 424 // Change serverHelloBytes so that the SCT list is empty 425 i := bytes.Index(serverHelloBytes, sct) 426 if i < 0 { 427 t.Fatal("Cannot find SCT in ServerHello") 428 } 429 430 var serverHelloEmptySCT []byte 431 serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[:i-6]...) 432 // Append the extension length and SCT list length for an empty list. 433 serverHelloEmptySCT = append(serverHelloEmptySCT, []byte{0, 2, 0, 0}...) 434 serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[i+4:]...) 435 436 // Update the handshake message length. 437 serverHelloEmptySCT[1] = byte((len(serverHelloEmptySCT) - 4) >> 16) 438 serverHelloEmptySCT[2] = byte((len(serverHelloEmptySCT) - 4) >> 8) 439 serverHelloEmptySCT[3] = byte(len(serverHelloEmptySCT) - 4) 440 441 // Update the extensions length 442 serverHelloEmptySCT[42] = byte((len(serverHelloEmptySCT) - 44) >> 8) 443 serverHelloEmptySCT[43] = byte((len(serverHelloEmptySCT) - 44)) 444 445 if serverHelloCopy.unmarshal(serverHelloEmptySCT) { 446 t.Fatal("Unmarshaled ServerHello with empty SCT list") 447 } 448 } 449 450 func TestRejectEmptySCT(t *testing.T) { 451 // Not only must the SCT list be non-empty, but the SCT elements must 452 // not be zero length. 453 454 var random [32]byte 455 serverHello := serverHelloMsg{ 456 vers: VersionTLS12, 457 random: random[:], 458 scts: [][]byte{nil}, 459 } 460 serverHelloBytes := serverHello.marshal() 461 462 var serverHelloCopy serverHelloMsg 463 if serverHelloCopy.unmarshal(serverHelloBytes) { 464 t.Fatal("Unmarshaled ServerHello with zero-length SCT") 465 } 466 } 467 468 func TestRejectDuplicateExtensions(t *testing.T) { 469 clientHelloBytes, err := hex.DecodeString("010000440303000000000000000000000000000000000000000000000000000000000000000000000000001c0000000a000800000568656c6c6f0000000a000800000568656c6c6f") 470 if err != nil { 471 t.Fatalf("failed to decode test ClientHello: %s", err) 472 } 473 var clientHelloCopy clientHelloMsg 474 if clientHelloCopy.unmarshal(clientHelloBytes) { 475 t.Error("Unmarshaled ClientHello with duplicate extensions") 476 } 477 478 serverHelloBytes, err := hex.DecodeString("02000030030300000000000000000000000000000000000000000000000000000000000000000000000000080005000000050000") 479 if err != nil { 480 t.Fatalf("failed to decode test ServerHello: %s", err) 481 } 482 var serverHelloCopy serverHelloMsg 483 if serverHelloCopy.unmarshal(serverHelloBytes) { 484 t.Fatal("Unmarshaled ServerHello with duplicate extensions") 485 } 486 }