github.com/hikaru7719/go@v0.0.0-20181025140707-c8b2ac68906a/src/crypto/tls/handshake_messages.go (about) 1 // Copyright 2009 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package tls 6 7 import ( 8 "bytes" 9 "strings" 10 ) 11 12 type clientHelloMsg struct { 13 raw []byte 14 vers uint16 15 random []byte 16 sessionId []byte 17 cipherSuites []uint16 18 compressionMethods []uint8 19 nextProtoNeg bool 20 serverName string 21 ocspStapling bool 22 scts bool 23 supportedCurves []CurveID 24 supportedPoints []uint8 25 ticketSupported bool 26 sessionTicket []uint8 27 supportedSignatureAlgorithms []SignatureScheme 28 secureRenegotiation []byte 29 secureRenegotiationSupported bool 30 alpnProtocols []string 31 } 32 33 func (m *clientHelloMsg) equal(i interface{}) bool { 34 m1, ok := i.(*clientHelloMsg) 35 if !ok { 36 return false 37 } 38 39 return bytes.Equal(m.raw, m1.raw) && 40 m.vers == m1.vers && 41 bytes.Equal(m.random, m1.random) && 42 bytes.Equal(m.sessionId, m1.sessionId) && 43 eqUint16s(m.cipherSuites, m1.cipherSuites) && 44 bytes.Equal(m.compressionMethods, m1.compressionMethods) && 45 m.nextProtoNeg == m1.nextProtoNeg && 46 m.serverName == m1.serverName && 47 m.ocspStapling == m1.ocspStapling && 48 m.scts == m1.scts && 49 eqCurveIDs(m.supportedCurves, m1.supportedCurves) && 50 bytes.Equal(m.supportedPoints, m1.supportedPoints) && 51 m.ticketSupported == m1.ticketSupported && 52 bytes.Equal(m.sessionTicket, m1.sessionTicket) && 53 eqSignatureAlgorithms(m.supportedSignatureAlgorithms, m1.supportedSignatureAlgorithms) && 54 m.secureRenegotiationSupported == m1.secureRenegotiationSupported && 55 bytes.Equal(m.secureRenegotiation, m1.secureRenegotiation) && 56 eqStrings(m.alpnProtocols, m1.alpnProtocols) 57 } 58 59 func (m *clientHelloMsg) marshal() []byte { 60 if m.raw != nil { 61 return m.raw 62 } 63 64 length := 2 + 32 + 1 + len(m.sessionId) + 2 + len(m.cipherSuites)*2 + 1 + len(m.compressionMethods) 65 numExtensions := 0 66 extensionsLength := 0 67 if m.nextProtoNeg { 68 numExtensions++ 69 } 70 if m.ocspStapling { 71 extensionsLength += 1 + 2 + 2 72 numExtensions++ 73 } 74 if len(m.serverName) > 0 { 75 extensionsLength += 5 + len(m.serverName) 76 numExtensions++ 77 } 78 if len(m.supportedCurves) > 0 { 79 extensionsLength += 2 + 2*len(m.supportedCurves) 80 numExtensions++ 81 } 82 if len(m.supportedPoints) > 0 { 83 extensionsLength += 1 + len(m.supportedPoints) 84 numExtensions++ 85 } 86 if m.ticketSupported { 87 extensionsLength += len(m.sessionTicket) 88 numExtensions++ 89 } 90 if len(m.supportedSignatureAlgorithms) > 0 { 91 extensionsLength += 2 + 2*len(m.supportedSignatureAlgorithms) 92 numExtensions++ 93 } 94 if m.secureRenegotiationSupported { 95 extensionsLength += 1 + len(m.secureRenegotiation) 96 numExtensions++ 97 } 98 if len(m.alpnProtocols) > 0 { 99 extensionsLength += 2 100 for _, s := range m.alpnProtocols { 101 if l := len(s); l == 0 || l > 255 { 102 panic("invalid ALPN protocol") 103 } 104 extensionsLength++ 105 extensionsLength += len(s) 106 } 107 numExtensions++ 108 } 109 if m.scts { 110 numExtensions++ 111 } 112 if numExtensions > 0 { 113 extensionsLength += 4 * numExtensions 114 length += 2 + extensionsLength 115 } 116 117 x := make([]byte, 4+length) 118 x[0] = typeClientHello 119 x[1] = uint8(length >> 16) 120 x[2] = uint8(length >> 8) 121 x[3] = uint8(length) 122 x[4] = uint8(m.vers >> 8) 123 x[5] = uint8(m.vers) 124 copy(x[6:38], m.random) 125 x[38] = uint8(len(m.sessionId)) 126 copy(x[39:39+len(m.sessionId)], m.sessionId) 127 y := x[39+len(m.sessionId):] 128 y[0] = uint8(len(m.cipherSuites) >> 7) 129 y[1] = uint8(len(m.cipherSuites) << 1) 130 for i, suite := range m.cipherSuites { 131 y[2+i*2] = uint8(suite >> 8) 132 y[3+i*2] = uint8(suite) 133 } 134 z := y[2+len(m.cipherSuites)*2:] 135 z[0] = uint8(len(m.compressionMethods)) 136 copy(z[1:], m.compressionMethods) 137 138 z = z[1+len(m.compressionMethods):] 139 if numExtensions > 0 { 140 z[0] = byte(extensionsLength >> 8) 141 z[1] = byte(extensionsLength) 142 z = z[2:] 143 } 144 if m.nextProtoNeg { 145 z[0] = byte(extensionNextProtoNeg >> 8) 146 z[1] = byte(extensionNextProtoNeg & 0xff) 147 // The length is always 0 148 z = z[4:] 149 } 150 if len(m.serverName) > 0 { 151 z[0] = byte(extensionServerName >> 8) 152 z[1] = byte(extensionServerName & 0xff) 153 l := len(m.serverName) + 5 154 z[2] = byte(l >> 8) 155 z[3] = byte(l) 156 z = z[4:] 157 158 // RFC 3546, Section 3.1 159 // 160 // struct { 161 // NameType name_type; 162 // select (name_type) { 163 // case host_name: HostName; 164 // } name; 165 // } ServerName; 166 // 167 // enum { 168 // host_name(0), (255) 169 // } NameType; 170 // 171 // opaque HostName<1..2^16-1>; 172 // 173 // struct { 174 // ServerName server_name_list<1..2^16-1> 175 // } ServerNameList; 176 177 z[0] = byte((len(m.serverName) + 3) >> 8) 178 z[1] = byte(len(m.serverName) + 3) 179 z[3] = byte(len(m.serverName) >> 8) 180 z[4] = byte(len(m.serverName)) 181 copy(z[5:], []byte(m.serverName)) 182 z = z[l:] 183 } 184 if m.ocspStapling { 185 // RFC 4366, Section 3.6 186 z[0] = byte(extensionStatusRequest >> 8) 187 z[1] = byte(extensionStatusRequest) 188 z[2] = 0 189 z[3] = 5 190 z[4] = 1 // OCSP type 191 // Two zero valued uint16s for the two lengths. 192 z = z[9:] 193 } 194 if len(m.supportedCurves) > 0 { 195 // RFC 4492, Section 5.5.1 196 z[0] = byte(extensionSupportedCurves >> 8) 197 z[1] = byte(extensionSupportedCurves) 198 l := 2 + 2*len(m.supportedCurves) 199 z[2] = byte(l >> 8) 200 z[3] = byte(l) 201 l -= 2 202 z[4] = byte(l >> 8) 203 z[5] = byte(l) 204 z = z[6:] 205 for _, curve := range m.supportedCurves { 206 z[0] = byte(curve >> 8) 207 z[1] = byte(curve) 208 z = z[2:] 209 } 210 } 211 if len(m.supportedPoints) > 0 { 212 // RFC 4492, Section 5.5.2 213 z[0] = byte(extensionSupportedPoints >> 8) 214 z[1] = byte(extensionSupportedPoints) 215 l := 1 + len(m.supportedPoints) 216 z[2] = byte(l >> 8) 217 z[3] = byte(l) 218 l-- 219 z[4] = byte(l) 220 z = z[5:] 221 for _, pointFormat := range m.supportedPoints { 222 z[0] = pointFormat 223 z = z[1:] 224 } 225 } 226 if m.ticketSupported { 227 // RFC 5077, Section 3.2 228 z[0] = byte(extensionSessionTicket >> 8) 229 z[1] = byte(extensionSessionTicket) 230 l := len(m.sessionTicket) 231 z[2] = byte(l >> 8) 232 z[3] = byte(l) 233 z = z[4:] 234 copy(z, m.sessionTicket) 235 z = z[len(m.sessionTicket):] 236 } 237 if len(m.supportedSignatureAlgorithms) > 0 { 238 // RFC 5246, Section 7.4.1.4.1 239 z[0] = byte(extensionSignatureAlgorithms >> 8) 240 z[1] = byte(extensionSignatureAlgorithms) 241 l := 2 + 2*len(m.supportedSignatureAlgorithms) 242 z[2] = byte(l >> 8) 243 z[3] = byte(l) 244 z = z[4:] 245 246 l -= 2 247 z[0] = byte(l >> 8) 248 z[1] = byte(l) 249 z = z[2:] 250 for _, sigAlgo := range m.supportedSignatureAlgorithms { 251 z[0] = byte(sigAlgo >> 8) 252 z[1] = byte(sigAlgo) 253 z = z[2:] 254 } 255 } 256 if m.secureRenegotiationSupported { 257 z[0] = byte(extensionRenegotiationInfo >> 8) 258 z[1] = byte(extensionRenegotiationInfo & 0xff) 259 z[2] = 0 260 z[3] = byte(len(m.secureRenegotiation) + 1) 261 z[4] = byte(len(m.secureRenegotiation)) 262 z = z[5:] 263 copy(z, m.secureRenegotiation) 264 z = z[len(m.secureRenegotiation):] 265 } 266 if len(m.alpnProtocols) > 0 { 267 z[0] = byte(extensionALPN >> 8) 268 z[1] = byte(extensionALPN & 0xff) 269 lengths := z[2:] 270 z = z[6:] 271 272 stringsLength := 0 273 for _, s := range m.alpnProtocols { 274 l := len(s) 275 z[0] = byte(l) 276 copy(z[1:], s) 277 z = z[1+l:] 278 stringsLength += 1 + l 279 } 280 281 lengths[2] = byte(stringsLength >> 8) 282 lengths[3] = byte(stringsLength) 283 stringsLength += 2 284 lengths[0] = byte(stringsLength >> 8) 285 lengths[1] = byte(stringsLength) 286 } 287 if m.scts { 288 // RFC 6962, Section 3.3.1 289 z[0] = byte(extensionSCT >> 8) 290 z[1] = byte(extensionSCT) 291 // zero uint16 for the zero-length extension_data 292 z = z[4:] 293 } 294 295 m.raw = x 296 297 return x 298 } 299 300 func (m *clientHelloMsg) unmarshal(data []byte) bool { 301 if len(data) < 42 { 302 return false 303 } 304 m.raw = data 305 m.vers = uint16(data[4])<<8 | uint16(data[5]) 306 m.random = data[6:38] 307 sessionIdLen := int(data[38]) 308 if sessionIdLen > 32 || len(data) < 39+sessionIdLen { 309 return false 310 } 311 m.sessionId = data[39 : 39+sessionIdLen] 312 data = data[39+sessionIdLen:] 313 if len(data) < 2 { 314 return false 315 } 316 // cipherSuiteLen is the number of bytes of cipher suite numbers. Since 317 // they are uint16s, the number must be even. 318 cipherSuiteLen := int(data[0])<<8 | int(data[1]) 319 if cipherSuiteLen%2 == 1 || len(data) < 2+cipherSuiteLen { 320 return false 321 } 322 numCipherSuites := cipherSuiteLen / 2 323 m.cipherSuites = make([]uint16, numCipherSuites) 324 for i := 0; i < numCipherSuites; i++ { 325 m.cipherSuites[i] = uint16(data[2+2*i])<<8 | uint16(data[3+2*i]) 326 if m.cipherSuites[i] == scsvRenegotiation { 327 m.secureRenegotiationSupported = true 328 } 329 } 330 data = data[2+cipherSuiteLen:] 331 if len(data) < 1 { 332 return false 333 } 334 compressionMethodsLen := int(data[0]) 335 if len(data) < 1+compressionMethodsLen { 336 return false 337 } 338 m.compressionMethods = data[1 : 1+compressionMethodsLen] 339 340 data = data[1+compressionMethodsLen:] 341 342 m.nextProtoNeg = false 343 m.serverName = "" 344 m.ocspStapling = false 345 m.ticketSupported = false 346 m.sessionTicket = nil 347 m.supportedSignatureAlgorithms = nil 348 m.alpnProtocols = nil 349 m.scts = false 350 351 if len(data) == 0 { 352 // ClientHello is optionally followed by extension data 353 return true 354 } 355 if len(data) < 2 { 356 return false 357 } 358 359 extensionsLength := int(data[0])<<8 | int(data[1]) 360 data = data[2:] 361 if extensionsLength != len(data) { 362 return false 363 } 364 365 for len(data) != 0 { 366 if len(data) < 4 { 367 return false 368 } 369 extension := uint16(data[0])<<8 | uint16(data[1]) 370 length := int(data[2])<<8 | int(data[3]) 371 data = data[4:] 372 if len(data) < length { 373 return false 374 } 375 376 switch extension { 377 case extensionServerName: 378 d := data[:length] 379 if len(d) < 2 { 380 return false 381 } 382 namesLen := int(d[0])<<8 | int(d[1]) 383 d = d[2:] 384 if len(d) != namesLen { 385 return false 386 } 387 for len(d) > 0 { 388 if len(d) < 3 { 389 return false 390 } 391 nameType := d[0] 392 nameLen := int(d[1])<<8 | int(d[2]) 393 d = d[3:] 394 if len(d) < nameLen { 395 return false 396 } 397 if nameType == 0 { 398 m.serverName = string(d[:nameLen]) 399 // An SNI value may not include a trailing dot. 400 // See RFC 6066, Section 3. 401 if strings.HasSuffix(m.serverName, ".") { 402 return false 403 } 404 break 405 } 406 d = d[nameLen:] 407 } 408 case extensionNextProtoNeg: 409 if length > 0 { 410 return false 411 } 412 m.nextProtoNeg = true 413 case extensionStatusRequest: 414 m.ocspStapling = length > 0 && data[0] == statusTypeOCSP 415 case extensionSupportedCurves: 416 // RFC 4492, Section 5.5.1 417 if length < 2 { 418 return false 419 } 420 l := int(data[0])<<8 | int(data[1]) 421 if l%2 == 1 || length != l+2 { 422 return false 423 } 424 numCurves := l / 2 425 m.supportedCurves = make([]CurveID, numCurves) 426 d := data[2:] 427 for i := 0; i < numCurves; i++ { 428 m.supportedCurves[i] = CurveID(d[0])<<8 | CurveID(d[1]) 429 d = d[2:] 430 } 431 case extensionSupportedPoints: 432 // RFC 4492, Section 5.5.2 433 if length < 1 { 434 return false 435 } 436 l := int(data[0]) 437 if length != l+1 { 438 return false 439 } 440 m.supportedPoints = make([]uint8, l) 441 copy(m.supportedPoints, data[1:]) 442 case extensionSessionTicket: 443 // RFC 5077, Section 3.2 444 m.ticketSupported = true 445 m.sessionTicket = data[:length] 446 case extensionSignatureAlgorithms: 447 // RFC 5246, Section 7.4.1.4.1 448 if length < 2 || length&1 != 0 { 449 return false 450 } 451 l := int(data[0])<<8 | int(data[1]) 452 if l != length-2 { 453 return false 454 } 455 n := l / 2 456 d := data[2:] 457 m.supportedSignatureAlgorithms = make([]SignatureScheme, n) 458 for i := range m.supportedSignatureAlgorithms { 459 m.supportedSignatureAlgorithms[i] = SignatureScheme(d[0])<<8 | SignatureScheme(d[1]) 460 d = d[2:] 461 } 462 case extensionRenegotiationInfo: 463 if length == 0 { 464 return false 465 } 466 d := data[:length] 467 l := int(d[0]) 468 d = d[1:] 469 if l != len(d) { 470 return false 471 } 472 473 m.secureRenegotiation = d 474 m.secureRenegotiationSupported = true 475 case extensionALPN: 476 if length < 2 { 477 return false 478 } 479 l := int(data[0])<<8 | int(data[1]) 480 if l != length-2 { 481 return false 482 } 483 d := data[2:length] 484 for len(d) != 0 { 485 stringLen := int(d[0]) 486 d = d[1:] 487 if stringLen == 0 || stringLen > len(d) { 488 return false 489 } 490 m.alpnProtocols = append(m.alpnProtocols, string(d[:stringLen])) 491 d = d[stringLen:] 492 } 493 case extensionSCT: 494 m.scts = true 495 if length != 0 { 496 return false 497 } 498 } 499 data = data[length:] 500 } 501 502 return true 503 } 504 505 type serverHelloMsg struct { 506 raw []byte 507 vers uint16 508 random []byte 509 sessionId []byte 510 cipherSuite uint16 511 compressionMethod uint8 512 nextProtoNeg bool 513 nextProtos []string 514 ocspStapling bool 515 scts [][]byte 516 ticketSupported bool 517 secureRenegotiation []byte 518 secureRenegotiationSupported bool 519 alpnProtocol string 520 } 521 522 func (m *serverHelloMsg) equal(i interface{}) bool { 523 m1, ok := i.(*serverHelloMsg) 524 if !ok { 525 return false 526 } 527 528 if len(m.scts) != len(m1.scts) { 529 return false 530 } 531 for i, sct := range m.scts { 532 if !bytes.Equal(sct, m1.scts[i]) { 533 return false 534 } 535 } 536 537 return bytes.Equal(m.raw, m1.raw) && 538 m.vers == m1.vers && 539 bytes.Equal(m.random, m1.random) && 540 bytes.Equal(m.sessionId, m1.sessionId) && 541 m.cipherSuite == m1.cipherSuite && 542 m.compressionMethod == m1.compressionMethod && 543 m.nextProtoNeg == m1.nextProtoNeg && 544 eqStrings(m.nextProtos, m1.nextProtos) && 545 m.ocspStapling == m1.ocspStapling && 546 m.ticketSupported == m1.ticketSupported && 547 m.secureRenegotiationSupported == m1.secureRenegotiationSupported && 548 bytes.Equal(m.secureRenegotiation, m1.secureRenegotiation) && 549 m.alpnProtocol == m1.alpnProtocol 550 } 551 552 func (m *serverHelloMsg) marshal() []byte { 553 if m.raw != nil { 554 return m.raw 555 } 556 557 length := 38 + len(m.sessionId) 558 numExtensions := 0 559 extensionsLength := 0 560 561 nextProtoLen := 0 562 if m.nextProtoNeg { 563 numExtensions++ 564 for _, v := range m.nextProtos { 565 nextProtoLen += len(v) 566 } 567 nextProtoLen += len(m.nextProtos) 568 extensionsLength += nextProtoLen 569 } 570 if m.ocspStapling { 571 numExtensions++ 572 } 573 if m.ticketSupported { 574 numExtensions++ 575 } 576 if m.secureRenegotiationSupported { 577 extensionsLength += 1 + len(m.secureRenegotiation) 578 numExtensions++ 579 } 580 if alpnLen := len(m.alpnProtocol); alpnLen > 0 { 581 if alpnLen >= 256 { 582 panic("invalid ALPN protocol") 583 } 584 extensionsLength += 2 + 1 + alpnLen 585 numExtensions++ 586 } 587 sctLen := 0 588 if len(m.scts) > 0 { 589 for _, sct := range m.scts { 590 sctLen += len(sct) + 2 591 } 592 extensionsLength += 2 + sctLen 593 numExtensions++ 594 } 595 596 if numExtensions > 0 { 597 extensionsLength += 4 * numExtensions 598 length += 2 + extensionsLength 599 } 600 601 x := make([]byte, 4+length) 602 x[0] = typeServerHello 603 x[1] = uint8(length >> 16) 604 x[2] = uint8(length >> 8) 605 x[3] = uint8(length) 606 x[4] = uint8(m.vers >> 8) 607 x[5] = uint8(m.vers) 608 copy(x[6:38], m.random) 609 x[38] = uint8(len(m.sessionId)) 610 copy(x[39:39+len(m.sessionId)], m.sessionId) 611 z := x[39+len(m.sessionId):] 612 z[0] = uint8(m.cipherSuite >> 8) 613 z[1] = uint8(m.cipherSuite) 614 z[2] = m.compressionMethod 615 616 z = z[3:] 617 if numExtensions > 0 { 618 z[0] = byte(extensionsLength >> 8) 619 z[1] = byte(extensionsLength) 620 z = z[2:] 621 } 622 if m.nextProtoNeg { 623 z[0] = byte(extensionNextProtoNeg >> 8) 624 z[1] = byte(extensionNextProtoNeg & 0xff) 625 z[2] = byte(nextProtoLen >> 8) 626 z[3] = byte(nextProtoLen) 627 z = z[4:] 628 629 for _, v := range m.nextProtos { 630 l := len(v) 631 if l > 255 { 632 l = 255 633 } 634 z[0] = byte(l) 635 copy(z[1:], []byte(v[0:l])) 636 z = z[1+l:] 637 } 638 } 639 if m.ocspStapling { 640 z[0] = byte(extensionStatusRequest >> 8) 641 z[1] = byte(extensionStatusRequest) 642 z = z[4:] 643 } 644 if m.ticketSupported { 645 z[0] = byte(extensionSessionTicket >> 8) 646 z[1] = byte(extensionSessionTicket) 647 z = z[4:] 648 } 649 if m.secureRenegotiationSupported { 650 z[0] = byte(extensionRenegotiationInfo >> 8) 651 z[1] = byte(extensionRenegotiationInfo & 0xff) 652 z[2] = 0 653 z[3] = byte(len(m.secureRenegotiation) + 1) 654 z[4] = byte(len(m.secureRenegotiation)) 655 z = z[5:] 656 copy(z, m.secureRenegotiation) 657 z = z[len(m.secureRenegotiation):] 658 } 659 if alpnLen := len(m.alpnProtocol); alpnLen > 0 { 660 z[0] = byte(extensionALPN >> 8) 661 z[1] = byte(extensionALPN & 0xff) 662 l := 2 + 1 + alpnLen 663 z[2] = byte(l >> 8) 664 z[3] = byte(l) 665 l -= 2 666 z[4] = byte(l >> 8) 667 z[5] = byte(l) 668 l -= 1 669 z[6] = byte(l) 670 copy(z[7:], []byte(m.alpnProtocol)) 671 z = z[7+alpnLen:] 672 } 673 if sctLen > 0 { 674 z[0] = byte(extensionSCT >> 8) 675 z[1] = byte(extensionSCT) 676 l := sctLen + 2 677 z[2] = byte(l >> 8) 678 z[3] = byte(l) 679 z[4] = byte(sctLen >> 8) 680 z[5] = byte(sctLen) 681 682 z = z[6:] 683 for _, sct := range m.scts { 684 z[0] = byte(len(sct) >> 8) 685 z[1] = byte(len(sct)) 686 copy(z[2:], sct) 687 z = z[len(sct)+2:] 688 } 689 } 690 691 m.raw = x 692 693 return x 694 } 695 696 func (m *serverHelloMsg) unmarshal(data []byte) bool { 697 if len(data) < 42 { 698 return false 699 } 700 m.raw = data 701 m.vers = uint16(data[4])<<8 | uint16(data[5]) 702 m.random = data[6:38] 703 sessionIdLen := int(data[38]) 704 if sessionIdLen > 32 || len(data) < 39+sessionIdLen { 705 return false 706 } 707 m.sessionId = data[39 : 39+sessionIdLen] 708 data = data[39+sessionIdLen:] 709 if len(data) < 3 { 710 return false 711 } 712 m.cipherSuite = uint16(data[0])<<8 | uint16(data[1]) 713 m.compressionMethod = data[2] 714 data = data[3:] 715 716 m.nextProtoNeg = false 717 m.nextProtos = nil 718 m.ocspStapling = false 719 m.scts = nil 720 m.ticketSupported = false 721 m.alpnProtocol = "" 722 723 if len(data) == 0 { 724 // ServerHello is optionally followed by extension data 725 return true 726 } 727 if len(data) < 2 { 728 return false 729 } 730 731 extensionsLength := int(data[0])<<8 | int(data[1]) 732 data = data[2:] 733 if len(data) != extensionsLength { 734 return false 735 } 736 737 for len(data) != 0 { 738 if len(data) < 4 { 739 return false 740 } 741 extension := uint16(data[0])<<8 | uint16(data[1]) 742 length := int(data[2])<<8 | int(data[3]) 743 data = data[4:] 744 if len(data) < length { 745 return false 746 } 747 748 switch extension { 749 case extensionNextProtoNeg: 750 m.nextProtoNeg = true 751 d := data[:length] 752 for len(d) > 0 { 753 l := int(d[0]) 754 d = d[1:] 755 if l == 0 || l > len(d) { 756 return false 757 } 758 m.nextProtos = append(m.nextProtos, string(d[:l])) 759 d = d[l:] 760 } 761 case extensionStatusRequest: 762 if length > 0 { 763 return false 764 } 765 m.ocspStapling = true 766 case extensionSessionTicket: 767 if length > 0 { 768 return false 769 } 770 m.ticketSupported = true 771 case extensionRenegotiationInfo: 772 if length == 0 { 773 return false 774 } 775 d := data[:length] 776 l := int(d[0]) 777 d = d[1:] 778 if l != len(d) { 779 return false 780 } 781 782 m.secureRenegotiation = d 783 m.secureRenegotiationSupported = true 784 case extensionALPN: 785 d := data[:length] 786 if len(d) < 3 { 787 return false 788 } 789 l := int(d[0])<<8 | int(d[1]) 790 if l != len(d)-2 { 791 return false 792 } 793 d = d[2:] 794 l = int(d[0]) 795 if l != len(d)-1 { 796 return false 797 } 798 d = d[1:] 799 if len(d) == 0 { 800 // ALPN protocols must not be empty. 801 return false 802 } 803 m.alpnProtocol = string(d) 804 case extensionSCT: 805 d := data[:length] 806 807 if len(d) < 2 { 808 return false 809 } 810 l := int(d[0])<<8 | int(d[1]) 811 d = d[2:] 812 if len(d) != l || l == 0 { 813 return false 814 } 815 816 m.scts = make([][]byte, 0, 3) 817 for len(d) != 0 { 818 if len(d) < 2 { 819 return false 820 } 821 sctLen := int(d[0])<<8 | int(d[1]) 822 d = d[2:] 823 if sctLen == 0 || len(d) < sctLen { 824 return false 825 } 826 m.scts = append(m.scts, d[:sctLen]) 827 d = d[sctLen:] 828 } 829 } 830 data = data[length:] 831 } 832 833 return true 834 } 835 836 type certificateMsg struct { 837 raw []byte 838 certificates [][]byte 839 } 840 841 func (m *certificateMsg) equal(i interface{}) bool { 842 m1, ok := i.(*certificateMsg) 843 if !ok { 844 return false 845 } 846 847 return bytes.Equal(m.raw, m1.raw) && 848 eqByteSlices(m.certificates, m1.certificates) 849 } 850 851 func (m *certificateMsg) marshal() (x []byte) { 852 if m.raw != nil { 853 return m.raw 854 } 855 856 var i int 857 for _, slice := range m.certificates { 858 i += len(slice) 859 } 860 861 length := 3 + 3*len(m.certificates) + i 862 x = make([]byte, 4+length) 863 x[0] = typeCertificate 864 x[1] = uint8(length >> 16) 865 x[2] = uint8(length >> 8) 866 x[3] = uint8(length) 867 868 certificateOctets := length - 3 869 x[4] = uint8(certificateOctets >> 16) 870 x[5] = uint8(certificateOctets >> 8) 871 x[6] = uint8(certificateOctets) 872 873 y := x[7:] 874 for _, slice := range m.certificates { 875 y[0] = uint8(len(slice) >> 16) 876 y[1] = uint8(len(slice) >> 8) 877 y[2] = uint8(len(slice)) 878 copy(y[3:], slice) 879 y = y[3+len(slice):] 880 } 881 882 m.raw = x 883 return 884 } 885 886 func (m *certificateMsg) unmarshal(data []byte) bool { 887 if len(data) < 7 { 888 return false 889 } 890 891 m.raw = data 892 certsLen := uint32(data[4])<<16 | uint32(data[5])<<8 | uint32(data[6]) 893 if uint32(len(data)) != certsLen+7 { 894 return false 895 } 896 897 numCerts := 0 898 d := data[7:] 899 for certsLen > 0 { 900 if len(d) < 4 { 901 return false 902 } 903 certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2]) 904 if uint32(len(d)) < 3+certLen { 905 return false 906 } 907 d = d[3+certLen:] 908 certsLen -= 3 + certLen 909 numCerts++ 910 } 911 912 m.certificates = make([][]byte, numCerts) 913 d = data[7:] 914 for i := 0; i < numCerts; i++ { 915 certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2]) 916 m.certificates[i] = d[3 : 3+certLen] 917 d = d[3+certLen:] 918 } 919 920 return true 921 } 922 923 type serverKeyExchangeMsg struct { 924 raw []byte 925 key []byte 926 } 927 928 func (m *serverKeyExchangeMsg) equal(i interface{}) bool { 929 m1, ok := i.(*serverKeyExchangeMsg) 930 if !ok { 931 return false 932 } 933 934 return bytes.Equal(m.raw, m1.raw) && 935 bytes.Equal(m.key, m1.key) 936 } 937 938 func (m *serverKeyExchangeMsg) marshal() []byte { 939 if m.raw != nil { 940 return m.raw 941 } 942 length := len(m.key) 943 x := make([]byte, length+4) 944 x[0] = typeServerKeyExchange 945 x[1] = uint8(length >> 16) 946 x[2] = uint8(length >> 8) 947 x[3] = uint8(length) 948 copy(x[4:], m.key) 949 950 m.raw = x 951 return x 952 } 953 954 func (m *serverKeyExchangeMsg) unmarshal(data []byte) bool { 955 m.raw = data 956 if len(data) < 4 { 957 return false 958 } 959 m.key = data[4:] 960 return true 961 } 962 963 type certificateStatusMsg struct { 964 raw []byte 965 statusType uint8 966 response []byte 967 } 968 969 func (m *certificateStatusMsg) equal(i interface{}) bool { 970 m1, ok := i.(*certificateStatusMsg) 971 if !ok { 972 return false 973 } 974 975 return bytes.Equal(m.raw, m1.raw) && 976 m.statusType == m1.statusType && 977 bytes.Equal(m.response, m1.response) 978 } 979 980 func (m *certificateStatusMsg) marshal() []byte { 981 if m.raw != nil { 982 return m.raw 983 } 984 985 var x []byte 986 if m.statusType == statusTypeOCSP { 987 x = make([]byte, 4+4+len(m.response)) 988 x[0] = typeCertificateStatus 989 l := len(m.response) + 4 990 x[1] = byte(l >> 16) 991 x[2] = byte(l >> 8) 992 x[3] = byte(l) 993 x[4] = statusTypeOCSP 994 995 l -= 4 996 x[5] = byte(l >> 16) 997 x[6] = byte(l >> 8) 998 x[7] = byte(l) 999 copy(x[8:], m.response) 1000 } else { 1001 x = []byte{typeCertificateStatus, 0, 0, 1, m.statusType} 1002 } 1003 1004 m.raw = x 1005 return x 1006 } 1007 1008 func (m *certificateStatusMsg) unmarshal(data []byte) bool { 1009 m.raw = data 1010 if len(data) < 5 { 1011 return false 1012 } 1013 m.statusType = data[4] 1014 1015 m.response = nil 1016 if m.statusType == statusTypeOCSP { 1017 if len(data) < 8 { 1018 return false 1019 } 1020 respLen := uint32(data[5])<<16 | uint32(data[6])<<8 | uint32(data[7]) 1021 if uint32(len(data)) != 4+4+respLen { 1022 return false 1023 } 1024 m.response = data[8:] 1025 } 1026 return true 1027 } 1028 1029 type serverHelloDoneMsg struct{} 1030 1031 func (m *serverHelloDoneMsg) equal(i interface{}) bool { 1032 _, ok := i.(*serverHelloDoneMsg) 1033 return ok 1034 } 1035 1036 func (m *serverHelloDoneMsg) marshal() []byte { 1037 x := make([]byte, 4) 1038 x[0] = typeServerHelloDone 1039 return x 1040 } 1041 1042 func (m *serverHelloDoneMsg) unmarshal(data []byte) bool { 1043 return len(data) == 4 1044 } 1045 1046 type clientKeyExchangeMsg struct { 1047 raw []byte 1048 ciphertext []byte 1049 } 1050 1051 func (m *clientKeyExchangeMsg) equal(i interface{}) bool { 1052 m1, ok := i.(*clientKeyExchangeMsg) 1053 if !ok { 1054 return false 1055 } 1056 1057 return bytes.Equal(m.raw, m1.raw) && 1058 bytes.Equal(m.ciphertext, m1.ciphertext) 1059 } 1060 1061 func (m *clientKeyExchangeMsg) marshal() []byte { 1062 if m.raw != nil { 1063 return m.raw 1064 } 1065 length := len(m.ciphertext) 1066 x := make([]byte, length+4) 1067 x[0] = typeClientKeyExchange 1068 x[1] = uint8(length >> 16) 1069 x[2] = uint8(length >> 8) 1070 x[3] = uint8(length) 1071 copy(x[4:], m.ciphertext) 1072 1073 m.raw = x 1074 return x 1075 } 1076 1077 func (m *clientKeyExchangeMsg) unmarshal(data []byte) bool { 1078 m.raw = data 1079 if len(data) < 4 { 1080 return false 1081 } 1082 l := int(data[1])<<16 | int(data[2])<<8 | int(data[3]) 1083 if l != len(data)-4 { 1084 return false 1085 } 1086 m.ciphertext = data[4:] 1087 return true 1088 } 1089 1090 type finishedMsg struct { 1091 raw []byte 1092 verifyData []byte 1093 } 1094 1095 func (m *finishedMsg) equal(i interface{}) bool { 1096 m1, ok := i.(*finishedMsg) 1097 if !ok { 1098 return false 1099 } 1100 1101 return bytes.Equal(m.raw, m1.raw) && 1102 bytes.Equal(m.verifyData, m1.verifyData) 1103 } 1104 1105 func (m *finishedMsg) marshal() (x []byte) { 1106 if m.raw != nil { 1107 return m.raw 1108 } 1109 1110 x = make([]byte, 4+len(m.verifyData)) 1111 x[0] = typeFinished 1112 x[3] = byte(len(m.verifyData)) 1113 copy(x[4:], m.verifyData) 1114 m.raw = x 1115 return 1116 } 1117 1118 func (m *finishedMsg) unmarshal(data []byte) bool { 1119 m.raw = data 1120 if len(data) < 4 { 1121 return false 1122 } 1123 m.verifyData = data[4:] 1124 return true 1125 } 1126 1127 type nextProtoMsg struct { 1128 raw []byte 1129 proto string 1130 } 1131 1132 func (m *nextProtoMsg) equal(i interface{}) bool { 1133 m1, ok := i.(*nextProtoMsg) 1134 if !ok { 1135 return false 1136 } 1137 1138 return bytes.Equal(m.raw, m1.raw) && 1139 m.proto == m1.proto 1140 } 1141 1142 func (m *nextProtoMsg) marshal() []byte { 1143 if m.raw != nil { 1144 return m.raw 1145 } 1146 l := len(m.proto) 1147 if l > 255 { 1148 l = 255 1149 } 1150 1151 padding := 32 - (l+2)%32 1152 length := l + padding + 2 1153 x := make([]byte, length+4) 1154 x[0] = typeNextProtocol 1155 x[1] = uint8(length >> 16) 1156 x[2] = uint8(length >> 8) 1157 x[3] = uint8(length) 1158 1159 y := x[4:] 1160 y[0] = byte(l) 1161 copy(y[1:], []byte(m.proto[0:l])) 1162 y = y[1+l:] 1163 y[0] = byte(padding) 1164 1165 m.raw = x 1166 1167 return x 1168 } 1169 1170 func (m *nextProtoMsg) unmarshal(data []byte) bool { 1171 m.raw = data 1172 1173 if len(data) < 5 { 1174 return false 1175 } 1176 data = data[4:] 1177 protoLen := int(data[0]) 1178 data = data[1:] 1179 if len(data) < protoLen { 1180 return false 1181 } 1182 m.proto = string(data[0:protoLen]) 1183 data = data[protoLen:] 1184 1185 if len(data) < 1 { 1186 return false 1187 } 1188 paddingLen := int(data[0]) 1189 data = data[1:] 1190 if len(data) != paddingLen { 1191 return false 1192 } 1193 1194 return true 1195 } 1196 1197 type certificateRequestMsg struct { 1198 raw []byte 1199 // hasSignatureAndHash indicates whether this message includes a list 1200 // of signature and hash functions. This change was introduced with TLS 1201 // 1.2. 1202 hasSignatureAndHash bool 1203 1204 certificateTypes []byte 1205 supportedSignatureAlgorithms []SignatureScheme 1206 certificateAuthorities [][]byte 1207 } 1208 1209 func (m *certificateRequestMsg) equal(i interface{}) bool { 1210 m1, ok := i.(*certificateRequestMsg) 1211 if !ok { 1212 return false 1213 } 1214 1215 return bytes.Equal(m.raw, m1.raw) && 1216 bytes.Equal(m.certificateTypes, m1.certificateTypes) && 1217 eqByteSlices(m.certificateAuthorities, m1.certificateAuthorities) && 1218 eqSignatureAlgorithms(m.supportedSignatureAlgorithms, m1.supportedSignatureAlgorithms) 1219 } 1220 1221 func (m *certificateRequestMsg) marshal() (x []byte) { 1222 if m.raw != nil { 1223 return m.raw 1224 } 1225 1226 // See RFC 4346, Section 7.4.4. 1227 length := 1 + len(m.certificateTypes) + 2 1228 casLength := 0 1229 for _, ca := range m.certificateAuthorities { 1230 casLength += 2 + len(ca) 1231 } 1232 length += casLength 1233 1234 if m.hasSignatureAndHash { 1235 length += 2 + 2*len(m.supportedSignatureAlgorithms) 1236 } 1237 1238 x = make([]byte, 4+length) 1239 x[0] = typeCertificateRequest 1240 x[1] = uint8(length >> 16) 1241 x[2] = uint8(length >> 8) 1242 x[3] = uint8(length) 1243 1244 x[4] = uint8(len(m.certificateTypes)) 1245 1246 copy(x[5:], m.certificateTypes) 1247 y := x[5+len(m.certificateTypes):] 1248 1249 if m.hasSignatureAndHash { 1250 n := len(m.supportedSignatureAlgorithms) * 2 1251 y[0] = uint8(n >> 8) 1252 y[1] = uint8(n) 1253 y = y[2:] 1254 for _, sigAlgo := range m.supportedSignatureAlgorithms { 1255 y[0] = uint8(sigAlgo >> 8) 1256 y[1] = uint8(sigAlgo) 1257 y = y[2:] 1258 } 1259 } 1260 1261 y[0] = uint8(casLength >> 8) 1262 y[1] = uint8(casLength) 1263 y = y[2:] 1264 for _, ca := range m.certificateAuthorities { 1265 y[0] = uint8(len(ca) >> 8) 1266 y[1] = uint8(len(ca)) 1267 y = y[2:] 1268 copy(y, ca) 1269 y = y[len(ca):] 1270 } 1271 1272 m.raw = x 1273 return 1274 } 1275 1276 func (m *certificateRequestMsg) unmarshal(data []byte) bool { 1277 m.raw = data 1278 1279 if len(data) < 5 { 1280 return false 1281 } 1282 1283 length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3]) 1284 if uint32(len(data))-4 != length { 1285 return false 1286 } 1287 1288 numCertTypes := int(data[4]) 1289 data = data[5:] 1290 if numCertTypes == 0 || len(data) <= numCertTypes { 1291 return false 1292 } 1293 1294 m.certificateTypes = make([]byte, numCertTypes) 1295 if copy(m.certificateTypes, data) != numCertTypes { 1296 return false 1297 } 1298 1299 data = data[numCertTypes:] 1300 1301 if m.hasSignatureAndHash { 1302 if len(data) < 2 { 1303 return false 1304 } 1305 sigAndHashLen := uint16(data[0])<<8 | uint16(data[1]) 1306 data = data[2:] 1307 if sigAndHashLen&1 != 0 { 1308 return false 1309 } 1310 if len(data) < int(sigAndHashLen) { 1311 return false 1312 } 1313 numSigAlgos := sigAndHashLen / 2 1314 m.supportedSignatureAlgorithms = make([]SignatureScheme, numSigAlgos) 1315 for i := range m.supportedSignatureAlgorithms { 1316 m.supportedSignatureAlgorithms[i] = SignatureScheme(data[0])<<8 | SignatureScheme(data[1]) 1317 data = data[2:] 1318 } 1319 } 1320 1321 if len(data) < 2 { 1322 return false 1323 } 1324 casLength := uint16(data[0])<<8 | uint16(data[1]) 1325 data = data[2:] 1326 if len(data) < int(casLength) { 1327 return false 1328 } 1329 cas := make([]byte, casLength) 1330 copy(cas, data) 1331 data = data[casLength:] 1332 1333 m.certificateAuthorities = nil 1334 for len(cas) > 0 { 1335 if len(cas) < 2 { 1336 return false 1337 } 1338 caLen := uint16(cas[0])<<8 | uint16(cas[1]) 1339 cas = cas[2:] 1340 1341 if len(cas) < int(caLen) { 1342 return false 1343 } 1344 1345 m.certificateAuthorities = append(m.certificateAuthorities, cas[:caLen]) 1346 cas = cas[caLen:] 1347 } 1348 1349 return len(data) == 0 1350 } 1351 1352 type certificateVerifyMsg struct { 1353 raw []byte 1354 hasSignatureAndHash bool 1355 signatureAlgorithm SignatureScheme 1356 signature []byte 1357 } 1358 1359 func (m *certificateVerifyMsg) equal(i interface{}) bool { 1360 m1, ok := i.(*certificateVerifyMsg) 1361 if !ok { 1362 return false 1363 } 1364 1365 return bytes.Equal(m.raw, m1.raw) && 1366 m.hasSignatureAndHash == m1.hasSignatureAndHash && 1367 m.signatureAlgorithm == m1.signatureAlgorithm && 1368 bytes.Equal(m.signature, m1.signature) 1369 } 1370 1371 func (m *certificateVerifyMsg) marshal() (x []byte) { 1372 if m.raw != nil { 1373 return m.raw 1374 } 1375 1376 // See RFC 4346, Section 7.4.8. 1377 siglength := len(m.signature) 1378 length := 2 + siglength 1379 if m.hasSignatureAndHash { 1380 length += 2 1381 } 1382 x = make([]byte, 4+length) 1383 x[0] = typeCertificateVerify 1384 x[1] = uint8(length >> 16) 1385 x[2] = uint8(length >> 8) 1386 x[3] = uint8(length) 1387 y := x[4:] 1388 if m.hasSignatureAndHash { 1389 y[0] = uint8(m.signatureAlgorithm >> 8) 1390 y[1] = uint8(m.signatureAlgorithm) 1391 y = y[2:] 1392 } 1393 y[0] = uint8(siglength >> 8) 1394 y[1] = uint8(siglength) 1395 copy(y[2:], m.signature) 1396 1397 m.raw = x 1398 1399 return 1400 } 1401 1402 func (m *certificateVerifyMsg) unmarshal(data []byte) bool { 1403 m.raw = data 1404 1405 if len(data) < 6 { 1406 return false 1407 } 1408 1409 length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3]) 1410 if uint32(len(data))-4 != length { 1411 return false 1412 } 1413 1414 data = data[4:] 1415 if m.hasSignatureAndHash { 1416 m.signatureAlgorithm = SignatureScheme(data[0])<<8 | SignatureScheme(data[1]) 1417 data = data[2:] 1418 } 1419 1420 if len(data) < 2 { 1421 return false 1422 } 1423 siglength := int(data[0])<<8 + int(data[1]) 1424 data = data[2:] 1425 if len(data) != siglength { 1426 return false 1427 } 1428 1429 m.signature = data 1430 1431 return true 1432 } 1433 1434 type newSessionTicketMsg struct { 1435 raw []byte 1436 ticket []byte 1437 } 1438 1439 func (m *newSessionTicketMsg) equal(i interface{}) bool { 1440 m1, ok := i.(*newSessionTicketMsg) 1441 if !ok { 1442 return false 1443 } 1444 1445 return bytes.Equal(m.raw, m1.raw) && 1446 bytes.Equal(m.ticket, m1.ticket) 1447 } 1448 1449 func (m *newSessionTicketMsg) marshal() (x []byte) { 1450 if m.raw != nil { 1451 return m.raw 1452 } 1453 1454 // See RFC 5077, Section 3.3. 1455 ticketLen := len(m.ticket) 1456 length := 2 + 4 + ticketLen 1457 x = make([]byte, 4+length) 1458 x[0] = typeNewSessionTicket 1459 x[1] = uint8(length >> 16) 1460 x[2] = uint8(length >> 8) 1461 x[3] = uint8(length) 1462 x[8] = uint8(ticketLen >> 8) 1463 x[9] = uint8(ticketLen) 1464 copy(x[10:], m.ticket) 1465 1466 m.raw = x 1467 1468 return 1469 } 1470 1471 func (m *newSessionTicketMsg) unmarshal(data []byte) bool { 1472 m.raw = data 1473 1474 if len(data) < 10 { 1475 return false 1476 } 1477 1478 length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3]) 1479 if uint32(len(data))-4 != length { 1480 return false 1481 } 1482 1483 ticketLen := int(data[8])<<8 + int(data[9]) 1484 if len(data)-10 != ticketLen { 1485 return false 1486 } 1487 1488 m.ticket = data[10:] 1489 1490 return true 1491 } 1492 1493 type helloRequestMsg struct { 1494 } 1495 1496 func (*helloRequestMsg) marshal() []byte { 1497 return []byte{typeHelloRequest, 0, 0, 0} 1498 } 1499 1500 func (*helloRequestMsg) unmarshal(data []byte) bool { 1501 return len(data) == 4 1502 } 1503 1504 func eqUint16s(x, y []uint16) bool { 1505 if len(x) != len(y) { 1506 return false 1507 } 1508 for i, v := range x { 1509 if y[i] != v { 1510 return false 1511 } 1512 } 1513 return true 1514 } 1515 1516 func eqCurveIDs(x, y []CurveID) bool { 1517 if len(x) != len(y) { 1518 return false 1519 } 1520 for i, v := range x { 1521 if y[i] != v { 1522 return false 1523 } 1524 } 1525 return true 1526 } 1527 1528 func eqStrings(x, y []string) bool { 1529 if len(x) != len(y) { 1530 return false 1531 } 1532 for i, v := range x { 1533 if y[i] != v { 1534 return false 1535 } 1536 } 1537 return true 1538 } 1539 1540 func eqByteSlices(x, y [][]byte) bool { 1541 if len(x) != len(y) { 1542 return false 1543 } 1544 for i, v := range x { 1545 if !bytes.Equal(v, y[i]) { 1546 return false 1547 } 1548 } 1549 return true 1550 } 1551 1552 func eqSignatureAlgorithms(x, y []SignatureScheme) bool { 1553 if len(x) != len(y) { 1554 return false 1555 } 1556 for i, v := range x { 1557 if v != y[i] { 1558 return false 1559 } 1560 } 1561 return true 1562 }