gitee.com/ks-custle/core-gm@v0.0.0-20230922171213-b83bdd97b62c/gmtls/handshake_messages_test.go (about) 1 // Copyright (c) 2022 zhaochun 2 // core-gm is licensed under Mulan PSL v2. 3 // You can use this software according to the terms and conditions of the Mulan PSL v2. 4 // You may obtain a copy of Mulan PSL v2 at: 5 // http://license.coscl.org.cn/MulanPSL2 6 // THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. 7 // See the Mulan PSL v2 for more details. 8 9 /* 10 gmtls是基于`golang/go`的`tls`包实现的国密改造版本。 11 对应版权声明: thrid_licenses/github.com/golang/go/LICENSE 12 */ 13 14 package gmtls 15 16 import ( 17 "bytes" 18 "math/rand" 19 "reflect" 20 "strings" 21 "testing" 22 "testing/quick" 23 "time" 24 ) 25 26 var tests = []interface{}{ 27 &clientHelloMsg{}, 28 &serverHelloMsg{}, 29 &finishedMsg{}, 30 31 &certificateMsg{}, 32 &certificateRequestMsg{}, 33 &certificateVerifyMsg{ 34 hasSignatureAlgorithm: true, 35 }, 36 &certificateStatusMsg{}, 37 &clientKeyExchangeMsg{}, 38 &newSessionTicketMsg{}, 39 &sessionState{}, 40 &sessionStateTLS13{}, 41 &encryptedExtensionsMsg{}, 42 &endOfEarlyDataMsg{}, 43 &keyUpdateMsg{}, 44 &newSessionTicketMsgTLS13{}, 45 &certificateRequestMsgTLS13{}, 46 &certificateMsgTLS13{}, 47 } 48 49 func TestMarshalUnmarshal(t *testing.T) { 50 randNew := rand.New(rand.NewSource(time.Now().UnixNano())) 51 52 for i, iface := range tests { 53 ty := reflect.ValueOf(iface).Type() 54 55 n := 100 56 if testing.Short() { 57 n = 5 58 } 59 for j := 0; j < n; j++ { 60 v, ok := quick.Value(ty, randNew) 61 if !ok { 62 t.Errorf("#%d: failed to create value", i) 63 break 64 } 65 66 m1 := v.Interface().(handshakeMessage) 67 marshaled := m1.marshal() 68 m2 := iface.(handshakeMessage) 69 if !m2.unmarshal(marshaled) { 70 t.Errorf("#%d failed to unmarshal %#v %x", i, m1, marshaled) 71 break 72 } 73 m2.marshal() // to fill any marshal cache in the message 74 75 if !reflect.DeepEqual(m1, m2) { 76 t.Errorf("#%d got:%#v want:%#v %x", i, m2, m1, marshaled) 77 break 78 } 79 80 if i >= 3 { 81 // The first three message types (ClientHello, 82 // ServerHello and Finished) are allowed to 83 // have parsable prefixes because the extension 84 // data is optional and the length of the 85 // Finished varies across versions. 86 for j := 0; j < len(marshaled); j++ { 87 if m2.unmarshal(marshaled[0:j]) { 88 t.Errorf("#%d unmarshaled a prefix of length %d of %#v", i, j, m1) 89 break 90 } 91 } 92 } 93 } 94 } 95 } 96 97 func TestFuzz(t *testing.T) { 98 randNew := rand.New(rand.NewSource(0)) 99 for _, iface := range tests { 100 m := iface.(handshakeMessage) 101 102 for j := 0; j < 1000; j++ { 103 length := randNew.Intn(100) 104 ranBytes := randomBytes(length, randNew) 105 // This just looks for crashes due to bounds errors etc. 106 m.unmarshal(ranBytes) 107 } 108 } 109 } 110 111 func randomBytes(n int, rand *rand.Rand) []byte { 112 r := make([]byte, n) 113 if _, err := rand.Read(r); err != nil { 114 panic("rand.Read failed: " + err.Error()) 115 } 116 return r 117 } 118 119 func randomString(n int, rand *rand.Rand) string { 120 b := randomBytes(n, rand) 121 return string(b) 122 } 123 124 //goland:noinspection GoUnusedParameter 125 func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value { 126 m := &clientHelloMsg{} 127 m.vers = uint16(rand.Intn(65536)) 128 m.random = randomBytes(32, rand) 129 m.sessionId = randomBytes(rand.Intn(32), rand) 130 m.cipherSuites = make([]uint16, rand.Intn(63)+1) 131 for i := 0; i < len(m.cipherSuites); i++ { 132 cs := uint16(rand.Int31()) 133 if cs == scsvRenegotiation { 134 cs += 1 135 } 136 m.cipherSuites[i] = cs 137 } 138 m.compressionMethods = randomBytes(rand.Intn(63)+1, rand) 139 if rand.Intn(10) > 5 { 140 m.serverName = randomString(rand.Intn(255), rand) 141 for strings.HasSuffix(m.serverName, ".") { 142 m.serverName = m.serverName[:len(m.serverName)-1] 143 } 144 } 145 m.ocspStapling = rand.Intn(10) > 5 146 m.supportedPoints = randomBytes(rand.Intn(5)+1, rand) 147 m.supportedCurves = make([]CurveID, rand.Intn(5)+1) 148 for i := range m.supportedCurves { 149 m.supportedCurves[i] = CurveID(rand.Intn(30000) + 1) 150 } 151 if rand.Intn(10) > 5 { 152 m.ticketSupported = true 153 if rand.Intn(10) > 5 { 154 m.sessionTicket = randomBytes(rand.Intn(300), rand) 155 } else { 156 m.sessionTicket = make([]byte, 0) 157 } 158 } 159 if rand.Intn(10) > 5 { 160 m.supportedSignatureAlgorithms = supportedSignatureAlgorithms 161 } 162 if rand.Intn(10) > 5 { 163 m.supportedSignatureAlgorithmsCert = supportedSignatureAlgorithms 164 } 165 for i := 0; i < rand.Intn(5); i++ { 166 m.alpnProtocols = append(m.alpnProtocols, randomString(rand.Intn(20)+1, rand)) 167 } 168 if rand.Intn(10) > 5 { 169 m.scts = true 170 } 171 if rand.Intn(10) > 5 { 172 m.secureRenegotiationSupported = true 173 m.secureRenegotiation = randomBytes(rand.Intn(50)+1, rand) 174 } 175 for i := 0; i < rand.Intn(5); i++ { 176 m.supportedVersions = append(m.supportedVersions, uint16(rand.Intn(0xffff)+1)) 177 } 178 if rand.Intn(10) > 5 { 179 m.cookie = randomBytes(rand.Intn(500)+1, rand) 180 } 181 for i := 0; i < rand.Intn(5); i++ { 182 var ks keyShare 183 ks.group = CurveID(rand.Intn(30000) + 1) 184 ks.data = randomBytes(rand.Intn(200)+1, rand) 185 m.keyShares = append(m.keyShares, ks) 186 } 187 switch rand.Intn(3) { 188 case 1: 189 m.pskModes = []uint8{pskModeDHE} 190 case 2: 191 m.pskModes = []uint8{pskModeDHE, pskModePlain} 192 } 193 for i := 0; i < rand.Intn(5); i++ { 194 var psk pskIdentity 195 psk.obfuscatedTicketAge = uint32(rand.Intn(500000)) 196 psk.label = randomBytes(rand.Intn(500)+1, rand) 197 m.pskIdentities = append(m.pskIdentities, psk) 198 m.pskBinders = append(m.pskBinders, randomBytes(rand.Intn(50)+32, rand)) 199 } 200 if rand.Intn(10) > 5 { 201 m.earlyData = true 202 } 203 204 return reflect.ValueOf(m) 205 } 206 207 //goland:noinspection GoUnusedParameter 208 func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value { 209 m := &serverHelloMsg{} 210 m.vers = uint16(rand.Intn(65536)) 211 m.random = randomBytes(32, rand) 212 m.sessionId = randomBytes(rand.Intn(32), rand) 213 m.cipherSuite = uint16(rand.Int31()) 214 m.compressionMethod = uint8(rand.Intn(256)) 215 m.supportedPoints = randomBytes(rand.Intn(5)+1, rand) 216 217 if rand.Intn(10) > 5 { 218 m.ocspStapling = true 219 } 220 if rand.Intn(10) > 5 { 221 m.ticketSupported = true 222 } 223 if rand.Intn(10) > 5 { 224 m.alpnProtocol = randomString(rand.Intn(32)+1, rand) 225 } 226 227 for i := 0; i < rand.Intn(4); i++ { 228 m.scts = append(m.scts, randomBytes(rand.Intn(500)+1, rand)) 229 } 230 231 if rand.Intn(10) > 5 { 232 m.secureRenegotiationSupported = true 233 m.secureRenegotiation = randomBytes(rand.Intn(50)+1, rand) 234 } 235 if rand.Intn(10) > 5 { 236 m.supportedVersion = uint16(rand.Intn(0xffff) + 1) 237 } 238 if rand.Intn(10) > 5 { 239 m.cookie = randomBytes(rand.Intn(500)+1, rand) 240 } 241 if rand.Intn(10) > 5 { 242 for i := 0; i < rand.Intn(5); i++ { 243 m.serverShare.group = CurveID(rand.Intn(30000) + 1) 244 m.serverShare.data = randomBytes(rand.Intn(200)+1, rand) 245 } 246 } else if rand.Intn(10) > 5 { 247 m.selectedGroup = CurveID(rand.Intn(30000) + 1) 248 } 249 if rand.Intn(10) > 5 { 250 m.selectedIdentityPresent = true 251 m.selectedIdentity = uint16(rand.Intn(0xffff)) 252 } 253 254 return reflect.ValueOf(m) 255 } 256 257 //goland:noinspection GoUnusedParameter 258 func (*encryptedExtensionsMsg) Generate(rand *rand.Rand, size int) reflect.Value { 259 m := &encryptedExtensionsMsg{} 260 261 if rand.Intn(10) > 5 { 262 m.alpnProtocol = randomString(rand.Intn(32)+1, rand) 263 } 264 265 return reflect.ValueOf(m) 266 } 267 268 //goland:noinspection GoUnusedParameter 269 func (*certificateMsg) Generate(rand *rand.Rand, size int) reflect.Value { 270 m := &certificateMsg{} 271 numCerts := rand.Intn(20) 272 m.certificates = make([][]byte, numCerts) 273 for i := 0; i < numCerts; i++ { 274 m.certificates[i] = randomBytes(rand.Intn(10)+1, rand) 275 } 276 return reflect.ValueOf(m) 277 } 278 279 //goland:noinspection GoUnusedParameter 280 func (*certificateRequestMsg) Generate(rand *rand.Rand, size int) reflect.Value { 281 m := &certificateRequestMsg{} 282 m.certificateTypes = randomBytes(rand.Intn(5)+1, rand) 283 for i := 0; i < rand.Intn(100); i++ { 284 m.certificateAuthorities = append(m.certificateAuthorities, randomBytes(rand.Intn(15)+1, rand)) 285 } 286 return reflect.ValueOf(m) 287 } 288 289 //goland:noinspection GoUnusedParameter 290 func (*certificateVerifyMsg) Generate(rand *rand.Rand, size int) reflect.Value { 291 m := &certificateVerifyMsg{} 292 m.hasSignatureAlgorithm = true 293 m.signatureAlgorithm = SignatureScheme(rand.Intn(30000)) 294 m.signature = randomBytes(rand.Intn(15)+1, rand) 295 return reflect.ValueOf(m) 296 } 297 298 //goland:noinspection GoUnusedParameter 299 func (*certificateStatusMsg) Generate(rand *rand.Rand, size int) reflect.Value { 300 m := &certificateStatusMsg{} 301 m.response = randomBytes(rand.Intn(10)+1, rand) 302 return reflect.ValueOf(m) 303 } 304 305 //goland:noinspection GoUnusedParameter 306 func (*clientKeyExchangeMsg) Generate(rand *rand.Rand, size int) reflect.Value { 307 m := &clientKeyExchangeMsg{} 308 m.ciphertext = randomBytes(rand.Intn(1000)+1, rand) 309 return reflect.ValueOf(m) 310 } 311 312 //goland:noinspection GoUnusedParameter 313 func (*finishedMsg) Generate(rand *rand.Rand, size int) reflect.Value { 314 m := &finishedMsg{} 315 m.verifyData = randomBytes(12, rand) 316 return reflect.ValueOf(m) 317 } 318 319 //goland:noinspection GoUnusedParameter 320 func (*newSessionTicketMsg) Generate(rand *rand.Rand, size int) reflect.Value { 321 m := &newSessionTicketMsg{} 322 m.ticket = randomBytes(rand.Intn(4), rand) 323 return reflect.ValueOf(m) 324 } 325 326 //goland:noinspection GoUnusedParameter 327 func (*sessionState) Generate(rand *rand.Rand, size int) reflect.Value { 328 s := &sessionState{} 329 s.vers = uint16(rand.Intn(10000)) 330 s.cipherSuite = uint16(rand.Intn(10000)) 331 s.masterSecret = randomBytes(rand.Intn(100)+1, rand) 332 s.createdAt = uint64(rand.Int63()) 333 for i := 0; i < rand.Intn(20); i++ { 334 s.certificates = append(s.certificates, randomBytes(rand.Intn(500)+1, rand)) 335 } 336 return reflect.ValueOf(s) 337 } 338 339 //goland:noinspection GoUnusedParameter 340 func (*sessionStateTLS13) Generate(rand *rand.Rand, size int) reflect.Value { 341 s := &sessionStateTLS13{} 342 s.cipherSuite = uint16(rand.Intn(10000)) 343 s.resumptionSecret = randomBytes(rand.Intn(100)+1, rand) 344 s.createdAt = uint64(rand.Int63()) 345 for i := 0; i < rand.Intn(2)+1; i++ { 346 s.certificate.Certificate = append( 347 s.certificate.Certificate, randomBytes(rand.Intn(500)+1, rand)) 348 } 349 if rand.Intn(10) > 5 { 350 s.certificate.OCSPStaple = randomBytes(rand.Intn(100)+1, rand) 351 } 352 if rand.Intn(10) > 5 { 353 for i := 0; i < rand.Intn(2)+1; i++ { 354 s.certificate.SignedCertificateTimestamps = append( 355 s.certificate.SignedCertificateTimestamps, randomBytes(rand.Intn(500)+1, rand)) 356 } 357 } 358 return reflect.ValueOf(s) 359 } 360 361 //goland:noinspection GoUnusedParameter 362 func (*endOfEarlyDataMsg) Generate(rand *rand.Rand, size int) reflect.Value { 363 m := &endOfEarlyDataMsg{} 364 return reflect.ValueOf(m) 365 } 366 367 //goland:noinspection GoUnusedParameter 368 func (*keyUpdateMsg) Generate(rand *rand.Rand, size int) reflect.Value { 369 m := &keyUpdateMsg{} 370 m.updateRequested = rand.Intn(10) > 5 371 return reflect.ValueOf(m) 372 } 373 374 //goland:noinspection GoUnusedParameter 375 func (*newSessionTicketMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value { 376 m := &newSessionTicketMsgTLS13{} 377 m.lifetime = uint32(rand.Intn(500000)) 378 m.ageAdd = uint32(rand.Intn(500000)) 379 m.nonce = randomBytes(rand.Intn(100), rand) 380 m.label = randomBytes(rand.Intn(1000), rand) 381 if rand.Intn(10) > 5 { 382 m.maxEarlyData = uint32(rand.Intn(500000)) 383 } 384 return reflect.ValueOf(m) 385 } 386 387 //goland:noinspection GoUnusedParameter 388 func (*certificateRequestMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value { 389 m := &certificateRequestMsgTLS13{} 390 if rand.Intn(10) > 5 { 391 m.ocspStapling = true 392 } 393 if rand.Intn(10) > 5 { 394 m.scts = true 395 } 396 if rand.Intn(10) > 5 { 397 m.supportedSignatureAlgorithms = supportedSignatureAlgorithms 398 } 399 if rand.Intn(10) > 5 { 400 m.supportedSignatureAlgorithmsCert = supportedSignatureAlgorithms 401 } 402 if rand.Intn(10) > 5 { 403 m.certificateAuthorities = make([][]byte, 3) 404 for i := 0; i < 3; i++ { 405 m.certificateAuthorities[i] = randomBytes(rand.Intn(10)+1, rand) 406 } 407 } 408 return reflect.ValueOf(m) 409 } 410 411 //goland:noinspection GoUnusedParameter 412 func (*certificateMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value { 413 m := &certificateMsgTLS13{} 414 for i := 0; i < rand.Intn(2)+1; i++ { 415 m.certificate.Certificate = append( 416 m.certificate.Certificate, randomBytes(rand.Intn(500)+1, rand)) 417 } 418 if rand.Intn(10) > 5 { 419 m.ocspStapling = true 420 m.certificate.OCSPStaple = randomBytes(rand.Intn(100)+1, rand) 421 } 422 if rand.Intn(10) > 5 { 423 m.scts = true 424 for i := 0; i < rand.Intn(2)+1; i++ { 425 m.certificate.SignedCertificateTimestamps = append( 426 m.certificate.SignedCertificateTimestamps, randomBytes(rand.Intn(500)+1, rand)) 427 } 428 } 429 return reflect.ValueOf(m) 430 } 431 432 func TestRejectEmptySCTList(t *testing.T) { 433 // RFC 6962, Section 3.3.1 specifies that empty SCT lists are invalid. 434 435 var random [32]byte 436 sct := []byte{0x42, 0x42, 0x42, 0x42} 437 serverHello := serverHelloMsg{ 438 vers: VersionTLS12, 439 random: random[:], 440 scts: [][]byte{sct}, 441 } 442 serverHelloBytes := serverHello.marshal() 443 444 var serverHelloCopy serverHelloMsg 445 if !serverHelloCopy.unmarshal(serverHelloBytes) { 446 t.Fatal("Failed to unmarshal initial message") 447 } 448 449 // Change serverHelloBytes so that the SCT list is empty 450 i := bytes.Index(serverHelloBytes, sct) 451 if i < 0 { 452 t.Fatal("Cannot find SCT in ServerHello") 453 } 454 455 var serverHelloEmptySCT []byte 456 serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[:i-6]...) 457 // Append the extension length and SCT list length for an empty list. 458 serverHelloEmptySCT = append(serverHelloEmptySCT, []byte{0, 2, 0, 0}...) 459 serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[i+4:]...) 460 461 // Update the handshake message length. 462 serverHelloEmptySCT[1] = byte((len(serverHelloEmptySCT) - 4) >> 16) 463 serverHelloEmptySCT[2] = byte((len(serverHelloEmptySCT) - 4) >> 8) 464 serverHelloEmptySCT[3] = byte(len(serverHelloEmptySCT) - 4) 465 466 // Update the extensions length 467 serverHelloEmptySCT[42] = byte((len(serverHelloEmptySCT) - 44) >> 8) 468 serverHelloEmptySCT[43] = byte(len(serverHelloEmptySCT) - 44) 469 470 if serverHelloCopy.unmarshal(serverHelloEmptySCT) { 471 t.Fatal("Unmarshaled ServerHello with empty SCT list") 472 } 473 } 474 475 func TestRejectEmptySCT(t *testing.T) { 476 // Not only must the SCT list be non-empty, but the SCT elements must 477 // not be zero length. 478 479 var random [32]byte 480 serverHello := serverHelloMsg{ 481 vers: VersionTLS12, 482 random: random[:], 483 scts: [][]byte{nil}, 484 } 485 serverHelloBytes := serverHello.marshal() 486 487 var serverHelloCopy serverHelloMsg 488 if serverHelloCopy.unmarshal(serverHelloBytes) { 489 t.Fatal("Unmarshaled ServerHello with zero-length SCT") 490 } 491 }