github.com/hxx258456/ccgo@v0.0.5-0.20230213014102-48b35f46f66f/gmtls/handshake_messages_test.go (about) 1 // Copyright 2022 s1ren@github.com/hxx258456. 2 3 /* 4 gmtls是基于`golang/go`的`tls`包实现的国密改造版本。 5 对应版权声明: thrid_licenses/github.com/golang/go/LICENSE 6 */ 7 8 package gmtls 9 10 import ( 11 "bytes" 12 "math/rand" 13 "reflect" 14 "strings" 15 "testing" 16 "testing/quick" 17 "time" 18 ) 19 20 var tests = []interface{}{ 21 &clientHelloMsg{}, 22 &serverHelloMsg{}, 23 &finishedMsg{}, 24 25 &certificateMsg{}, 26 &certificateRequestMsg{}, 27 &certificateVerifyMsg{ 28 hasSignatureAlgorithm: true, 29 }, 30 &certificateStatusMsg{}, 31 &clientKeyExchangeMsg{}, 32 &newSessionTicketMsg{}, 33 &sessionState{}, 34 &sessionStateTLS13{}, 35 &encryptedExtensionsMsg{}, 36 &endOfEarlyDataMsg{}, 37 &keyUpdateMsg{}, 38 &newSessionTicketMsgTLS13{}, 39 &certificateRequestMsgTLS13{}, 40 &certificateMsgTLS13{}, 41 } 42 43 func TestMarshalUnmarshal(t *testing.T) { 44 randNew := rand.New(rand.NewSource(time.Now().UnixNano())) 45 46 for i, iface := range tests { 47 ty := reflect.ValueOf(iface).Type() 48 49 n := 100 50 if testing.Short() { 51 n = 5 52 } 53 for j := 0; j < n; j++ { 54 v, ok := quick.Value(ty, randNew) 55 if !ok { 56 t.Errorf("#%d: failed to create value", i) 57 break 58 } 59 60 m1 := v.Interface().(handshakeMessage) 61 marshaled := m1.marshal() 62 m2 := iface.(handshakeMessage) 63 if !m2.unmarshal(marshaled) { 64 t.Errorf("#%d failed to unmarshal %#v %x", i, m1, marshaled) 65 break 66 } 67 m2.marshal() // to fill any marshal cache in the message 68 69 if !reflect.DeepEqual(m1, m2) { 70 t.Errorf("#%d got:%#v want:%#v %x", i, m2, m1, marshaled) 71 break 72 } 73 74 if i >= 3 { 75 // The first three message types (ClientHello, 76 // ServerHello and Finished) are allowed to 77 // have parsable prefixes because the extension 78 // data is optional and the length of the 79 // Finished varies across versions. 80 for j := 0; j < len(marshaled); j++ { 81 if m2.unmarshal(marshaled[0:j]) { 82 t.Errorf("#%d unmarshaled a prefix of length %d of %#v", i, j, m1) 83 break 84 } 85 } 86 } 87 } 88 } 89 } 90 91 func TestFuzz(t *testing.T) { 92 randNew := rand.New(rand.NewSource(0)) 93 for _, iface := range tests { 94 m := iface.(handshakeMessage) 95 96 for j := 0; j < 1000; j++ { 97 length := randNew.Intn(100) 98 ranBytes := randomBytes(length, randNew) 99 // This just looks for crashes due to bounds errors etc. 100 m.unmarshal(ranBytes) 101 } 102 } 103 } 104 func randomBytes(n int, rand *rand.Rand) []byte { 105 r := make([]byte, n) 106 if _, err := rand.Read(r); err != nil { 107 panic("rand.Read failed: " + err.Error()) 108 } 109 return r 110 } 111 112 func randomString(n int, rand *rand.Rand) string { 113 b := randomBytes(n, rand) 114 return string(b) 115 } 116 117 //goland:noinspection GoUnusedParameter 118 func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value { 119 m := &clientHelloMsg{} 120 m.vers = uint16(rand.Intn(65536)) 121 m.random = randomBytes(32, rand) 122 m.sessionId = randomBytes(rand.Intn(32), rand) 123 m.cipherSuites = make([]uint16, rand.Intn(63)+1) 124 for i := 0; i < len(m.cipherSuites); i++ { 125 cs := uint16(rand.Int31()) 126 if cs == scsvRenegotiation { 127 cs += 1 128 } 129 m.cipherSuites[i] = cs 130 } 131 m.compressionMethods = randomBytes(rand.Intn(63)+1, rand) 132 if rand.Intn(10) > 5 { 133 m.serverName = randomString(rand.Intn(255), rand) 134 for strings.HasSuffix(m.serverName, ".") { 135 m.serverName = m.serverName[:len(m.serverName)-1] 136 } 137 } 138 m.ocspStapling = rand.Intn(10) > 5 139 m.supportedPoints = randomBytes(rand.Intn(5)+1, rand) 140 m.supportedCurves = make([]CurveID, rand.Intn(5)+1) 141 for i := range m.supportedCurves { 142 m.supportedCurves[i] = CurveID(rand.Intn(30000) + 1) 143 } 144 if rand.Intn(10) > 5 { 145 m.ticketSupported = true 146 if rand.Intn(10) > 5 { 147 m.sessionTicket = randomBytes(rand.Intn(300), rand) 148 } else { 149 m.sessionTicket = make([]byte, 0) 150 } 151 } 152 if rand.Intn(10) > 5 { 153 m.supportedSignatureAlgorithms = supportedSignatureAlgorithms 154 } 155 if rand.Intn(10) > 5 { 156 m.supportedSignatureAlgorithmsCert = supportedSignatureAlgorithms 157 } 158 for i := 0; i < rand.Intn(5); i++ { 159 m.alpnProtocols = append(m.alpnProtocols, randomString(rand.Intn(20)+1, rand)) 160 } 161 if rand.Intn(10) > 5 { 162 m.scts = true 163 } 164 if rand.Intn(10) > 5 { 165 m.secureRenegotiationSupported = true 166 m.secureRenegotiation = randomBytes(rand.Intn(50)+1, rand) 167 } 168 for i := 0; i < rand.Intn(5); i++ { 169 m.supportedVersions = append(m.supportedVersions, uint16(rand.Intn(0xffff)+1)) 170 } 171 if rand.Intn(10) > 5 { 172 m.cookie = randomBytes(rand.Intn(500)+1, rand) 173 } 174 for i := 0; i < rand.Intn(5); i++ { 175 var ks keyShare 176 ks.group = CurveID(rand.Intn(30000) + 1) 177 ks.data = randomBytes(rand.Intn(200)+1, rand) 178 m.keyShares = append(m.keyShares, ks) 179 } 180 switch rand.Intn(3) { 181 case 1: 182 m.pskModes = []uint8{pskModeDHE} 183 case 2: 184 m.pskModes = []uint8{pskModeDHE, pskModePlain} 185 } 186 for i := 0; i < rand.Intn(5); i++ { 187 var psk pskIdentity 188 psk.obfuscatedTicketAge = uint32(rand.Intn(500000)) 189 psk.label = randomBytes(rand.Intn(500)+1, rand) 190 m.pskIdentities = append(m.pskIdentities, psk) 191 m.pskBinders = append(m.pskBinders, randomBytes(rand.Intn(50)+32, rand)) 192 } 193 if rand.Intn(10) > 5 { 194 m.earlyData = true 195 } 196 197 return reflect.ValueOf(m) 198 } 199 200 //goland:noinspection GoUnusedParameter 201 func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value { 202 m := &serverHelloMsg{} 203 m.vers = uint16(rand.Intn(65536)) 204 m.random = randomBytes(32, rand) 205 m.sessionId = randomBytes(rand.Intn(32), rand) 206 m.cipherSuite = uint16(rand.Int31()) 207 m.compressionMethod = uint8(rand.Intn(256)) 208 m.supportedPoints = randomBytes(rand.Intn(5)+1, rand) 209 210 if rand.Intn(10) > 5 { 211 m.ocspStapling = true 212 } 213 if rand.Intn(10) > 5 { 214 m.ticketSupported = true 215 } 216 if rand.Intn(10) > 5 { 217 m.alpnProtocol = randomString(rand.Intn(32)+1, rand) 218 } 219 220 for i := 0; i < rand.Intn(4); i++ { 221 m.scts = append(m.scts, randomBytes(rand.Intn(500)+1, rand)) 222 } 223 224 if rand.Intn(10) > 5 { 225 m.secureRenegotiationSupported = true 226 m.secureRenegotiation = randomBytes(rand.Intn(50)+1, rand) 227 } 228 if rand.Intn(10) > 5 { 229 m.supportedVersion = uint16(rand.Intn(0xffff) + 1) 230 } 231 if rand.Intn(10) > 5 { 232 m.cookie = randomBytes(rand.Intn(500)+1, rand) 233 } 234 if rand.Intn(10) > 5 { 235 for i := 0; i < rand.Intn(5); i++ { 236 m.serverShare.group = CurveID(rand.Intn(30000) + 1) 237 m.serverShare.data = randomBytes(rand.Intn(200)+1, rand) 238 } 239 } else if rand.Intn(10) > 5 { 240 m.selectedGroup = CurveID(rand.Intn(30000) + 1) 241 } 242 if rand.Intn(10) > 5 { 243 m.selectedIdentityPresent = true 244 m.selectedIdentity = uint16(rand.Intn(0xffff)) 245 } 246 247 return reflect.ValueOf(m) 248 } 249 250 //goland:noinspection GoUnusedParameter 251 func (*encryptedExtensionsMsg) Generate(rand *rand.Rand, size int) reflect.Value { 252 m := &encryptedExtensionsMsg{} 253 254 if rand.Intn(10) > 5 { 255 m.alpnProtocol = randomString(rand.Intn(32)+1, rand) 256 } 257 258 return reflect.ValueOf(m) 259 } 260 261 //goland:noinspection GoUnusedParameter 262 func (*certificateMsg) Generate(rand *rand.Rand, size int) reflect.Value { 263 m := &certificateMsg{} 264 numCerts := rand.Intn(20) 265 m.certificates = make([][]byte, numCerts) 266 for i := 0; i < numCerts; i++ { 267 m.certificates[i] = randomBytes(rand.Intn(10)+1, rand) 268 } 269 return reflect.ValueOf(m) 270 } 271 272 //goland:noinspection GoUnusedParameter 273 func (*certificateRequestMsg) Generate(rand *rand.Rand, size int) reflect.Value { 274 m := &certificateRequestMsg{} 275 m.certificateTypes = randomBytes(rand.Intn(5)+1, rand) 276 for i := 0; i < rand.Intn(100); i++ { 277 m.certificateAuthorities = append(m.certificateAuthorities, randomBytes(rand.Intn(15)+1, rand)) 278 } 279 return reflect.ValueOf(m) 280 } 281 282 //goland:noinspection GoUnusedParameter 283 func (*certificateVerifyMsg) Generate(rand *rand.Rand, size int) reflect.Value { 284 m := &certificateVerifyMsg{} 285 m.hasSignatureAlgorithm = true 286 m.signatureAlgorithm = SignatureScheme(rand.Intn(30000)) 287 m.signature = randomBytes(rand.Intn(15)+1, rand) 288 return reflect.ValueOf(m) 289 } 290 291 //goland:noinspection GoUnusedParameter 292 func (*certificateStatusMsg) Generate(rand *rand.Rand, size int) reflect.Value { 293 m := &certificateStatusMsg{} 294 m.response = randomBytes(rand.Intn(10)+1, rand) 295 return reflect.ValueOf(m) 296 } 297 298 //goland:noinspection GoUnusedParameter 299 func (*clientKeyExchangeMsg) Generate(rand *rand.Rand, size int) reflect.Value { 300 m := &clientKeyExchangeMsg{} 301 m.ciphertext = randomBytes(rand.Intn(1000)+1, rand) 302 return reflect.ValueOf(m) 303 } 304 305 //goland:noinspection GoUnusedParameter 306 func (*finishedMsg) Generate(rand *rand.Rand, size int) reflect.Value { 307 m := &finishedMsg{} 308 m.verifyData = randomBytes(12, rand) 309 return reflect.ValueOf(m) 310 } 311 312 //goland:noinspection GoUnusedParameter 313 func (*newSessionTicketMsg) Generate(rand *rand.Rand, size int) reflect.Value { 314 m := &newSessionTicketMsg{} 315 m.ticket = randomBytes(rand.Intn(4), rand) 316 return reflect.ValueOf(m) 317 } 318 319 //goland:noinspection GoUnusedParameter 320 func (*sessionState) Generate(rand *rand.Rand, size int) reflect.Value { 321 s := &sessionState{} 322 s.vers = uint16(rand.Intn(10000)) 323 s.cipherSuite = uint16(rand.Intn(10000)) 324 s.masterSecret = randomBytes(rand.Intn(100)+1, rand) 325 s.createdAt = uint64(rand.Int63()) 326 for i := 0; i < rand.Intn(20); i++ { 327 s.certificates = append(s.certificates, randomBytes(rand.Intn(500)+1, rand)) 328 } 329 return reflect.ValueOf(s) 330 } 331 332 //goland:noinspection GoUnusedParameter 333 func (*sessionStateTLS13) Generate(rand *rand.Rand, size int) reflect.Value { 334 s := &sessionStateTLS13{} 335 s.cipherSuite = uint16(rand.Intn(10000)) 336 s.resumptionSecret = randomBytes(rand.Intn(100)+1, rand) 337 s.createdAt = uint64(rand.Int63()) 338 for i := 0; i < rand.Intn(2)+1; i++ { 339 s.certificate.Certificate = append( 340 s.certificate.Certificate, randomBytes(rand.Intn(500)+1, rand)) 341 } 342 if rand.Intn(10) > 5 { 343 s.certificate.OCSPStaple = randomBytes(rand.Intn(100)+1, rand) 344 } 345 if rand.Intn(10) > 5 { 346 for i := 0; i < rand.Intn(2)+1; i++ { 347 s.certificate.SignedCertificateTimestamps = append( 348 s.certificate.SignedCertificateTimestamps, randomBytes(rand.Intn(500)+1, rand)) 349 } 350 } 351 return reflect.ValueOf(s) 352 } 353 354 //goland:noinspection GoUnusedParameter 355 func (*endOfEarlyDataMsg) Generate(rand *rand.Rand, size int) reflect.Value { 356 m := &endOfEarlyDataMsg{} 357 return reflect.ValueOf(m) 358 } 359 360 //goland:noinspection GoUnusedParameter 361 func (*keyUpdateMsg) Generate(rand *rand.Rand, size int) reflect.Value { 362 m := &keyUpdateMsg{} 363 m.updateRequested = rand.Intn(10) > 5 364 return reflect.ValueOf(m) 365 } 366 367 //goland:noinspection GoUnusedParameter 368 func (*newSessionTicketMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value { 369 m := &newSessionTicketMsgTLS13{} 370 m.lifetime = uint32(rand.Intn(500000)) 371 m.ageAdd = uint32(rand.Intn(500000)) 372 m.nonce = randomBytes(rand.Intn(100), rand) 373 m.label = randomBytes(rand.Intn(1000), rand) 374 if rand.Intn(10) > 5 { 375 m.maxEarlyData = uint32(rand.Intn(500000)) 376 } 377 return reflect.ValueOf(m) 378 } 379 380 //goland:noinspection GoUnusedParameter 381 func (*certificateRequestMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value { 382 m := &certificateRequestMsgTLS13{} 383 if rand.Intn(10) > 5 { 384 m.ocspStapling = true 385 } 386 if rand.Intn(10) > 5 { 387 m.scts = true 388 } 389 if rand.Intn(10) > 5 { 390 m.supportedSignatureAlgorithms = supportedSignatureAlgorithms 391 } 392 if rand.Intn(10) > 5 { 393 m.supportedSignatureAlgorithmsCert = supportedSignatureAlgorithms 394 } 395 if rand.Intn(10) > 5 { 396 m.certificateAuthorities = make([][]byte, 3) 397 for i := 0; i < 3; i++ { 398 m.certificateAuthorities[i] = randomBytes(rand.Intn(10)+1, rand) 399 } 400 } 401 return reflect.ValueOf(m) 402 } 403 404 //goland:noinspection GoUnusedParameter 405 func (*certificateMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value { 406 m := &certificateMsgTLS13{} 407 for i := 0; i < rand.Intn(2)+1; i++ { 408 m.certificate.Certificate = append( 409 m.certificate.Certificate, randomBytes(rand.Intn(500)+1, rand)) 410 } 411 if rand.Intn(10) > 5 { 412 m.ocspStapling = true 413 m.certificate.OCSPStaple = randomBytes(rand.Intn(100)+1, rand) 414 } 415 if rand.Intn(10) > 5 { 416 m.scts = true 417 for i := 0; i < rand.Intn(2)+1; i++ { 418 m.certificate.SignedCertificateTimestamps = append( 419 m.certificate.SignedCertificateTimestamps, randomBytes(rand.Intn(500)+1, rand)) 420 } 421 } 422 return reflect.ValueOf(m) 423 } 424 425 func TestRejectEmptySCTList(t *testing.T) { 426 // RFC 6962, Section 3.3.1 specifies that empty SCT lists are invalid. 427 428 var random [32]byte 429 sct := []byte{0x42, 0x42, 0x42, 0x42} 430 serverHello := serverHelloMsg{ 431 vers: VersionTLS12, 432 random: random[:], 433 scts: [][]byte{sct}, 434 } 435 serverHelloBytes := serverHello.marshal() 436 437 var serverHelloCopy serverHelloMsg 438 if !serverHelloCopy.unmarshal(serverHelloBytes) { 439 t.Fatal("Failed to unmarshal initial message") 440 } 441 442 // Change serverHelloBytes so that the SCT list is empty 443 i := bytes.Index(serverHelloBytes, sct) 444 if i < 0 { 445 t.Fatal("Cannot find SCT in ServerHello") 446 } 447 448 var serverHelloEmptySCT []byte 449 serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[:i-6]...) 450 // Append the extension length and SCT list length for an empty list. 451 serverHelloEmptySCT = append(serverHelloEmptySCT, []byte{0, 2, 0, 0}...) 452 serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[i+4:]...) 453 454 // Update the handshake message length. 455 serverHelloEmptySCT[1] = byte((len(serverHelloEmptySCT) - 4) >> 16) 456 serverHelloEmptySCT[2] = byte((len(serverHelloEmptySCT) - 4) >> 8) 457 serverHelloEmptySCT[3] = byte(len(serverHelloEmptySCT) - 4) 458 459 // Update the extensions length 460 serverHelloEmptySCT[42] = byte((len(serverHelloEmptySCT) - 44) >> 8) 461 serverHelloEmptySCT[43] = byte(len(serverHelloEmptySCT) - 44) 462 463 if serverHelloCopy.unmarshal(serverHelloEmptySCT) { 464 t.Fatal("Unmarshaled ServerHello with empty SCT list") 465 } 466 } 467 468 func TestRejectEmptySCT(t *testing.T) { 469 // Not only must the SCT list be non-empty, but the SCT elements must 470 // not be zero length. 471 472 var random [32]byte 473 serverHello := serverHelloMsg{ 474 vers: VersionTLS12, 475 random: random[:], 476 scts: [][]byte{nil}, 477 } 478 serverHelloBytes := serverHello.marshal() 479 480 var serverHelloCopy serverHelloMsg 481 if serverHelloCopy.unmarshal(serverHelloBytes) { 482 t.Fatal("Unmarshaled ServerHello with zero-length SCT") 483 } 484 }