github.com/zmap/zcrypto@v0.0.0-20240512203510-0fef58d9a9db/tls/handshake_client.go (about) 1 // Copyright 2009 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package tls 6 7 import ( 8 "bytes" 9 "crypto/ecdsa" 10 "crypto/rsa" 11 "crypto/subtle" 12 "encoding/asn1" 13 "encoding/binary" 14 "errors" 15 "fmt" 16 "io" 17 "math/big" 18 "net" 19 "strconv" 20 "time" 21 22 "github.com/zmap/zcrypto/dsa" 23 24 "github.com/zmap/zcrypto/x509" 25 ) 26 27 type clientHandshakeState struct { 28 c *Conn 29 serverHello *serverHelloMsg 30 hello *clientHelloMsg 31 suite *cipherSuite 32 finishedHash finishedHash 33 masterSecret []byte 34 preMasterSecret []byte 35 session *ClientSessionState 36 } 37 38 type CacheKeyGenerator interface { 39 Key(net.Addr) string 40 } 41 42 type ClientFingerprintConfiguration struct { 43 // Version in the handshake header 44 HandshakeVersion uint16 45 46 // if len == 32, it will specify the client random. 47 // Otherwise, the field will be random 48 // except the top 4 bytes if InsertTimestamp is true 49 ClientRandom []byte 50 InsertTimestamp bool 51 52 // if RandomSessionID > 0, will overwrite SessionID w/ that many 53 // random bytes when a session resumption occurs 54 RandomSessionID int 55 SessionID []byte 56 57 // These fields will appear exactly in order in the ClientHello 58 CipherSuites []uint16 59 CompressionMethods []uint8 60 Extensions []ClientExtension 61 62 // Optional, both must be non-nil, or neither. 
63 // Custom Session cache implementations allowed 64 SessionCache ClientSessionCache 65 CacheKey CacheKeyGenerator 66 } 67 68 type ClientExtension interface { 69 // Produce the bytes on the wire for this extension, type and length included 70 Marshal() []byte 71 72 // Function will return an error if zTLS does not implement the necessary features for this extension 73 CheckImplemented() error 74 75 // Modifies the config to reflect the state of the extension 76 WriteToConfig(*Config) error 77 } 78 79 func (c *ClientFingerprintConfiguration) CheckImplementedExtensions() error { 80 for _, ext := range c.Extensions { 81 if err := ext.CheckImplemented(); err != nil { 82 return err 83 } 84 } 85 return nil 86 } 87 88 func (c *clientHelloMsg) WriteToConfig(config *Config) error { 89 config.NextProtos = c.alpnProtocols 90 config.CipherSuites = c.cipherSuites 91 config.MaxVersion = c.vers 92 config.ClientRandom = c.random 93 config.CurvePreferences = c.supportedCurves 94 config.HeartbeatEnabled = c.heartbeatEnabled 95 config.ExtendedRandom = c.extendedRandomEnabled 96 config.ForceSessionTicketExt = c.ticketSupported 97 config.ExtendedMasterSecret = c.extendedMasterSecret 98 config.SignedCertificateTimestampExt = c.sctEnabled 99 return nil 100 } 101 102 func (c *ClientFingerprintConfiguration) WriteToConfig(config *Config) error { 103 config.NextProtos = []string{} 104 config.CipherSuites = c.CipherSuites 105 config.MaxVersion = c.HandshakeVersion 106 config.ClientRandom = c.ClientRandom 107 config.CurvePreferences = []CurveID{} 108 config.HeartbeatEnabled = false 109 config.ExtendedRandom = false 110 config.ForceSessionTicketExt = false 111 config.ExtendedMasterSecret = false 112 config.SignedCertificateTimestampExt = false 113 for _, ext := range c.Extensions { 114 if err := ext.WriteToConfig(config); err != nil { 115 return err 116 } 117 } 118 return nil 119 } 120 121 func currentTimestamp() ([]byte, error) { 122 t := time.Now().Unix() 123 buf := new(bytes.Buffer) 124 
err := binary.Write(buf, binary.BigEndian, t) 125 return buf.Bytes(), err 126 } 127 128 func (c *ClientFingerprintConfiguration) marshal(config *Config) ([]byte, error) { 129 if err := c.CheckImplementedExtensions(); err != nil { 130 return nil, err 131 } 132 head := make([]byte, 38) 133 head[0] = 1 134 head[4] = uint8(c.HandshakeVersion >> 8) 135 head[5] = uint8(c.HandshakeVersion) 136 if len(c.ClientRandom) == 32 { 137 copy(head[6:38], c.ClientRandom[0:32]) 138 } else { 139 start := 6 140 if c.InsertTimestamp { 141 t, err := currentTimestamp() 142 if err != nil { 143 return nil, err 144 } 145 copy(head[start:start+4], t) 146 start = start + 4 147 } 148 _, err := io.ReadFull(config.rand(), head[start:38]) 149 if err != nil { 150 return nil, errors.New("tls: short read from Rand: " + err.Error()) 151 } 152 } 153 154 if len(c.SessionID) >= 256 { 155 return nil, errors.New("tls: SessionID too long") 156 } 157 sessionID := make([]byte, len(c.SessionID)+1) 158 sessionID[0] = uint8(len(c.SessionID)) 159 if len(c.SessionID) > 0 { 160 copy(sessionID[1:], c.SessionID) 161 } 162 163 ciphers := make([]byte, 2+2*len(c.CipherSuites)) 164 ciphers[0] = uint8(len(c.CipherSuites) >> 7) 165 ciphers[1] = uint8(len(c.CipherSuites) << 1) 166 for i, suite := range c.CipherSuites { 167 if !config.ForceSuites { 168 found := false 169 for _, impl := range implementedCipherSuites { 170 if impl.id == suite { 171 found = true 172 } 173 } 174 if !found { 175 return nil, errors.New(fmt.Sprintf("tls: unimplemented cipher suite %d", suite)) 176 } 177 } 178 179 ciphers[2+i*2] = uint8(suite >> 8) 180 ciphers[3+i*2] = uint8(suite) 181 } 182 183 if len(c.CompressionMethods) >= 256 { 184 return nil, errors.New("tls: Too many compression methods") 185 } 186 compressions := make([]byte, len(c.CompressionMethods)+1) 187 compressions[0] = uint8(len(c.CompressionMethods)) 188 if len(c.CompressionMethods) > 0 { 189 copy(compressions[1:], c.CompressionMethods) 190 if c.CompressionMethods[0] != 0 { 191 
return nil, errors.New(fmt.Sprintf("tls: unimplemented compression method %d", c.CompressionMethods[0])) 192 } 193 if len(c.CompressionMethods) > 1 { 194 return nil, errors.New(fmt.Sprintf("tls: unimplemented compression method %d", c.CompressionMethods[1])) 195 } 196 } else { 197 return nil, errors.New("tls: no compression method") 198 } 199 200 var extensions []byte 201 for _, ext := range c.Extensions { 202 extensions = append(extensions, ext.Marshal()...) 203 } 204 if len(extensions) > 0 { 205 length := make([]byte, 2) 206 length[0] = uint8(len(extensions) >> 8) 207 length[1] = uint8(len(extensions)) 208 extensions = append(length, extensions...) 209 } 210 helloArray := [][]byte{head, sessionID, ciphers, compressions, extensions} 211 hello := []byte{} 212 for _, b := range helloArray { 213 hello = append(hello, b...) 214 } 215 lengthOnTheWire := len(hello) - 4 216 if lengthOnTheWire >= 1<<24 { 217 return nil, errors.New("ClientHello message too long") 218 } 219 hello[1] = uint8(lengthOnTheWire >> 16) 220 hello[2] = uint8(lengthOnTheWire >> 8) 221 hello[3] = uint8(lengthOnTheWire) 222 223 return hello, nil 224 } 225 226 func (c *Conn) clientHandshake() error { 227 if c.config == nil { 228 c.config = defaultConfig() 229 } 230 var hello *clientHelloMsg 231 var helloBytes []byte 232 var session *ClientSessionState 233 var sessionCache ClientSessionCache 234 var cacheKey string 235 236 // first, let's check if a ClientFingerprintConfiguration template was provided by the config 237 if c.config.ClientFingerprintConfiguration != nil { 238 if err := c.config.ClientFingerprintConfiguration.WriteToConfig(c.config); err != nil { 239 return err 240 } 241 session = nil 242 sessionCache = c.config.ClientFingerprintConfiguration.SessionCache 243 if sessionCache != nil { 244 if c.config.ClientFingerprintConfiguration.CacheKey == nil { 245 return errors.New("tls: must specify CacheKey if SessionCache is defined in Config.ClientFingerprintConfiguration") 246 } 247 cacheKey = 
c.config.ClientFingerprintConfiguration.CacheKey.Key(c.conn.RemoteAddr()) 248 candidateSession, ok := sessionCache.Get(cacheKey) 249 if ok { 250 cipherSuiteOk := false 251 for _, id := range c.config.ClientFingerprintConfiguration.CipherSuites { 252 if id == candidateSession.cipherSuite { 253 cipherSuiteOk = true 254 break 255 } 256 } 257 versOk := candidateSession.vers >= c.config.minVersion() && 258 candidateSession.vers <= c.config.ClientFingerprintConfiguration.HandshakeVersion 259 if versOk && cipherSuiteOk { 260 session = candidateSession 261 } 262 } 263 } 264 for i, ext := range c.config.ClientFingerprintConfiguration.Extensions { 265 switch casted := ext.(type) { 266 case *SessionTicketExtension: 267 if casted.Autopopulate { 268 if session == nil { 269 if !c.config.ForceSessionTicketExt { 270 c.config.ClientFingerprintConfiguration.Extensions[i] = &NullExtension{} 271 } 272 } else { 273 c.config.ClientFingerprintConfiguration.Extensions[i] = &SessionTicketExtension{session.sessionTicket, true} 274 if c.config.ClientFingerprintConfiguration.RandomSessionID > 0 { 275 c.config.ClientFingerprintConfiguration.SessionID = make([]byte, c.config.ClientFingerprintConfiguration.RandomSessionID) 276 if _, err := io.ReadFull(c.config.rand(), c.config.ClientFingerprintConfiguration.SessionID); err != nil { 277 c.sendAlert(alertInternalError) 278 return errors.New("tls: short read from Rand: " + err.Error()) 279 } 280 281 } 282 } 283 } 284 } 285 } 286 var err error 287 helloBytes, err = c.config.ClientFingerprintConfiguration.marshal(c.config) 288 if err != nil { 289 return err 290 } 291 hello = &clientHelloMsg{} 292 if ok := hello.unmarshal(helloBytes); !ok { 293 return errors.New("tls: incompatible ClientFingerprintConfiguration") 294 } 295 296 // next, let's check if a ClientHello template was provided by the user 297 } else if c.config.ExternalClientHello != nil { 298 299 hello = new(clientHelloMsg) 300 301 if !hello.unmarshal(c.config.ExternalClientHello) { 302 
return errors.New("could not read the ClientHello provided") 303 } 304 if err := hello.WriteToConfig(c.config); err != nil { 305 return err 306 } 307 308 // update the SNI with one name, whether or not the extension was already there 309 hello.serverName = c.config.ServerName 310 311 // then we update the 'raw' value of the message 312 hello.raw = nil 313 helloBytes = hello.marshal() 314 315 session = nil 316 sessionCache = nil 317 } else { 318 if len(c.config.ServerName) == 0 && !c.config.InsecureSkipVerify { 319 return errors.New("tls: either ServerName or InsecureSkipVerify must be specified in the tls.Config") 320 } 321 322 supportedPoints := []uint8{pointFormatUncompressed} 323 if c.config.SupportedPoints != nil { 324 supportedPoints = c.config.SupportedPoints 325 } 326 oscpStapling := true 327 if c.config.NoOcspStapling { 328 oscpStapling = false 329 } 330 331 compressionMethods := []uint8{compressionNone} 332 if c.config.CompressionMethods != nil { 333 compressionMethods = c.config.CompressionMethods 334 } 335 336 hello = &clientHelloMsg{ 337 vers: c.config.maxVersion(), 338 compressionMethods: compressionMethods, 339 random: make([]byte, 32), 340 ocspStapling: oscpStapling, 341 serverName: c.config.ServerName, 342 supportedCurves: c.config.curvePreferences(), 343 supportedPoints: supportedPoints, 344 nextProtoNeg: len(c.config.NextProtos) > 0, 345 secureRenegotiation: true, 346 alpnProtocols: c.config.NextProtos, 347 extendedMasterSecret: c.config.maxVersion() >= VersionTLS10 && c.config.ExtendedMasterSecret, 348 } 349 350 if c.config.ForceSessionTicketExt { 351 hello.ticketSupported = true 352 } 353 if c.config.SignedCertificateTimestampExt { 354 hello.sctEnabled = true 355 } 356 357 if c.config.HeartbeatEnabled && !c.config.ExtendedRandom { 358 hello.heartbeatEnabled = true 359 hello.heartbeatMode = heartbeatModePeerAllowed 360 } 361 362 possibleCipherSuites := c.config.cipherSuites() 363 hello.cipherSuites = make([]uint16, 0, len(possibleCipherSuites)) 
364 365 if c.config.ForceSuites { 366 hello.cipherSuites = possibleCipherSuites 367 } else { 368 369 NextCipherSuite: 370 for _, suiteId := range possibleCipherSuites { 371 for _, suite := range implementedCipherSuites { 372 if suite.id != suiteId { 373 continue 374 } 375 // Don't advertise TLS 1.2-only cipher suites unless 376 // we're attempting TLS 1.2. 377 if hello.vers < VersionTLS12 && suite.flags&suiteTLS12 != 0 { 378 continue 379 } 380 hello.cipherSuites = append(hello.cipherSuites, suiteId) 381 continue NextCipherSuite 382 } 383 } 384 } 385 386 if len(c.config.ClientRandom) == 32 { 387 copy(hello.random, c.config.ClientRandom) 388 } else { 389 _, err := io.ReadFull(c.config.rand(), hello.random) 390 if err != nil { 391 c.sendAlert(alertInternalError) 392 return errors.New("tls: short read from Rand: " + err.Error()) 393 } 394 } 395 396 if c.config.ExtendedRandom { 397 hello.extendedRandomEnabled = true 398 hello.extendedRandom = make([]byte, 32) 399 if _, err := io.ReadFull(c.config.rand(), hello.extendedRandom); err != nil { 400 return errors.New("tls: short read from Rand: " + err.Error()) 401 } 402 } 403 404 if hello.vers >= VersionTLS12 { 405 hello.signatureAndHashes = c.config.signatureAndHashesForClient() 406 } 407 408 sessionCache = c.config.ClientSessionCache 409 if c.config.SessionTicketsDisabled { 410 sessionCache = nil 411 } 412 if sessionCache != nil { 413 hello.ticketSupported = true 414 415 // Try to resume a previously negotiated TLS session, if 416 // available. 417 cacheKey = clientSessionCacheKey(c.conn.RemoteAddr(), c.config) 418 candidateSession, ok := sessionCache.Get(cacheKey) 419 if ok { 420 // Check that the ciphersuite/version used for the 421 // previous session are still valid. 
422 cipherSuiteOk := false 423 for _, id := range hello.cipherSuites { 424 if id == candidateSession.cipherSuite { 425 cipherSuiteOk = true 426 break 427 } 428 } 429 430 versOk := candidateSession.vers >= c.config.minVersion() && 431 candidateSession.vers <= c.config.maxVersion() 432 if versOk && cipherSuiteOk { 433 session = candidateSession 434 } 435 } 436 } 437 438 if session != nil { 439 hello.sessionTicket = session.sessionTicket 440 // A random session ID is used to detect when the 441 // server accepted the ticket and is resuming a session 442 // (see RFC 5077). 443 hello.sessionId = make([]byte, 16) 444 if _, err := io.ReadFull(c.config.rand(), hello.sessionId); err != nil { 445 c.sendAlert(alertInternalError) 446 return errors.New("tls: short read from Rand: " + err.Error()) 447 } 448 449 } 450 451 helloBytes = hello.marshal() 452 } 453 454 c.handshakeLog = new(ServerHandshake) 455 c.heartbleedLog = new(Heartbleed) 456 c.writeRecord(recordTypeHandshake, helloBytes) 457 c.handshakeLog.ClientHello = hello.MakeLog() 458 459 msg, err := c.readHandshake() 460 if err != nil { 461 return err 462 } 463 serverHello, ok := msg.(*serverHelloMsg) 464 if !ok { 465 c.sendAlert(alertUnexpectedMessage) 466 return unexpectedMessageError(serverHello, msg) 467 } 468 c.handshakeLog.ServerHello = serverHello.MakeLog() 469 470 if serverHello.heartbeatEnabled { 471 c.heartbeat = true 472 c.heartbleedLog.HeartbeatEnabled = true 473 } 474 475 vers, ok := c.config.mutualVersion(serverHello.vers) 476 if !ok { 477 c.sendAlert(alertProtocolVersion) 478 return fmt.Errorf("tls: server selected unsupported protocol version %x", serverHello.vers) 479 } 480 c.vers = vers 481 c.haveVers = true 482 483 suite := mutualCipherSuite(c.config.cipherSuites(), serverHello.cipherSuite) 484 cipherImplemented := cipherIDInCipherList(serverHello.cipherSuite, implementedCipherSuites) 485 cipherShared := cipherIDInCipherIDList(serverHello.cipherSuite, c.config.cipherSuites()) 486 if suite == nil { 487 // 
c.sendAlert(alertHandshakeFailure) 488 if !cipherShared { 489 c.cipherError = ErrNoMutualCipher 490 } else if !cipherImplemented { 491 c.cipherError = ErrUnimplementedCipher 492 } 493 } 494 495 hs := &clientHandshakeState{ 496 c: c, 497 serverHello: serverHello, 498 hello: hello, 499 suite: suite, 500 finishedHash: newFinishedHash(c.vers, suite), 501 session: session, 502 } 503 504 hs.finishedHash.Write(helloBytes) 505 hs.finishedHash.Write(hs.serverHello.marshal()) 506 507 isResume, err := hs.processServerHello() 508 if err != nil { 509 return err 510 } 511 if !c.config.DontBufferHandshakes { 512 c.buffering = true 513 defer c.flush() 514 } 515 if isResume { 516 if c.cipherError != nil { 517 c.sendAlert(alertHandshakeFailure) 518 return c.cipherError 519 } 520 if err := hs.establishKeys(); err != nil { 521 return err 522 } 523 if err := hs.readSessionTicket(); err != nil { 524 return err 525 } 526 if err := hs.readFinished(); err != nil { 527 return err 528 } 529 if err := hs.sendFinished(); err != nil { 530 return err 531 } 532 if _, err := c.flush(); err != nil { 533 return err 534 } 535 } else { 536 if err := hs.doFullHandshake(); err != nil { 537 if err == ErrCertsOnly { 538 c.sendAlert(alertCloseNotify) 539 } 540 return err 541 } 542 if err := hs.establishKeys(); err != nil { 543 return err 544 } 545 if err := hs.sendFinished(); err != nil { 546 return err 547 } 548 if _, err := c.flush(); err != nil { 549 return err 550 } 551 if err := hs.readSessionTicket(); err != nil { 552 return err 553 } 554 if err := hs.readFinished(); err != nil { 555 return err 556 } 557 } 558 559 if hs.session == nil { 560 c.handshakeLog.SessionTicket = nil 561 } else { 562 c.handshakeLog.SessionTicket = hs.session.MakeLog() 563 } 564 565 c.handshakeLog.KeyMaterial = hs.MakeLog() 566 567 if sessionCache != nil && hs.session != nil && session != hs.session { 568 sessionCache.Put(cacheKey, hs.session) 569 } 570 571 c.didResume = isResume 572 c.handshakeComplete = true 573 
c.cipherSuite = suite.id 574 return nil 575 } 576 577 func (hs *clientHandshakeState) doFullHandshake() error { 578 c := hs.c 579 580 msg, err := c.readHandshake() 581 if err != nil { 582 return err 583 } 584 585 var serverCert *x509.Certificate 586 587 isAnon := hs.suite != nil && (hs.suite.flags&suiteAnon > 0) 588 589 if !isAnon { 590 591 certMsg, ok := msg.(*certificateMsg) 592 if !ok || len(certMsg.certificates) == 0 { 593 c.sendAlert(alertUnexpectedMessage) 594 return unexpectedMessageError(certMsg, msg) 595 } 596 hs.finishedHash.Write(certMsg.marshal()) 597 598 certs := make([]*x509.Certificate, len(certMsg.certificates)) 599 invalidCert := false 600 var invalidCertErr error 601 for i, asn1Data := range certMsg.certificates { 602 cert, err := x509.ParseCertificate(asn1Data) 603 if err != nil { 604 invalidCert = true 605 invalidCertErr = err 606 break 607 } 608 certs[i] = cert 609 } 610 611 c.handshakeLog.ServerCertificates = certMsg.MakeLog() 612 613 if c.config.CertsOnly { 614 // short circuit! 
615 err = ErrCertsOnly 616 return err 617 } 618 619 if !invalidCert { 620 opts := x509.VerifyOptions{ 621 Roots: c.config.RootCAs, 622 CurrentTime: c.config.time(), 623 DNSName: c.config.ServerName, 624 Intermediates: x509.NewCertPool(), 625 } 626 627 // Always check validity of the certificates 628 for _, cert := range certs { 629 /* 630 if i == 0 { 631 continue 632 } 633 */ 634 opts.Intermediates.AddCert(cert) 635 } 636 var validation *x509.Validation 637 c.verifiedChains, validation, err = certs[0].ValidateWithStupidDetail(opts) 638 c.handshakeLog.ServerCertificates.addParsed(certs, validation) 639 640 // If actually verifying and invalid, reject 641 if !c.config.InsecureSkipVerify { 642 if err != nil { 643 c.sendAlert(alertBadCertificate) 644 return err 645 } 646 } 647 } 648 649 if invalidCert { 650 c.sendAlert(alertBadCertificate) 651 return errors.New("tls: failed to parse certificate from server: " + invalidCertErr.Error()) 652 } 653 654 c.peerCertificates = certs 655 656 if hs.serverHello.ocspStapling { 657 msg, err = c.readHandshake() 658 if err != nil { 659 return err 660 } 661 cs, ok := msg.(*certificateStatusMsg) 662 if !ok { 663 c.sendAlert(alertUnexpectedMessage) 664 return unexpectedMessageError(cs, msg) 665 } 666 hs.finishedHash.Write(cs.marshal()) 667 668 if cs.statusType == statusTypeOCSP { 669 c.ocspResponse = cs.response 670 } 671 } 672 673 serverCert = certs[0] 674 675 var supportedCertKeyType bool 676 switch serverCert.PublicKey.(type) { 677 case *rsa.PublicKey, *ecdsa.PublicKey, *x509.AugmentedECDSA: 678 supportedCertKeyType = true 679 break 680 case *dsa.PublicKey: 681 if c.config.ClientDSAEnabled { 682 supportedCertKeyType = true 683 } 684 default: 685 break 686 } 687 688 if !supportedCertKeyType { 689 c.sendAlert(alertUnsupportedCertificate) 690 return fmt.Errorf("tls: server's certificate contains an unsupported type of public key: %T", serverCert.PublicKey) 691 } 692 693 msg, err = c.readHandshake() 694 if err != nil { 695 return err 696 
} 697 } 698 699 // If we don't support the cipher, quit before we need to read the hs.suite 700 // variable 701 if c.cipherError != nil { 702 return c.cipherError 703 } 704 705 skx, ok := msg.(*serverKeyExchangeMsg) 706 707 keyAgreement := hs.suite.ka(c.vers) 708 709 if ok { 710 hs.finishedHash.Write(skx.marshal()) 711 712 err = keyAgreement.processServerKeyExchange(c.config, hs.hello, hs.serverHello, serverCert, skx) 713 c.handshakeLog.ServerKeyExchange = skx.MakeLog(keyAgreement) 714 if err != nil { 715 c.sendAlert(alertUnexpectedMessage) 716 return err 717 } 718 719 msg, err = c.readHandshake() 720 if err != nil { 721 return err 722 } 723 } 724 725 var chainToSend *Certificate 726 var certRequested bool 727 certReq, ok := msg.(*certificateRequestMsg) 728 if ok { 729 certRequested = true 730 731 // RFC 4346 on the certificateAuthorities field: 732 // A list of the distinguished names of acceptable certificate 733 // authorities. These distinguished names may specify a desired 734 // distinguished name for a root CA or for a subordinate CA; 735 // thus, this message can be used to describe both known roots 736 // and a desired authorization space. If the 737 // certificate_authorities list is empty then the client MAY 738 // send any certificate of the appropriate 739 // ClientCertificateType, unless there is some external 740 // arrangement to the contrary. 
741 742 hs.finishedHash.Write(certReq.marshal()) 743 744 var rsaAvail, ecdsaAvail bool 745 for _, certType := range certReq.certificateTypes { 746 switch certType { 747 case certTypeRSASign: 748 rsaAvail = true 749 case certTypeECDSASign: 750 ecdsaAvail = true 751 } 752 } 753 754 // We need to search our list of client certs for one 755 // where SignatureAlgorithm is RSA and the Issuer is in 756 // certReq.certificateAuthorities 757 findCert: 758 for i, chain := range c.config.Certificates { 759 if !rsaAvail && !ecdsaAvail { 760 continue 761 } 762 763 for j, cert := range chain.Certificate { 764 x509Cert := chain.Leaf 765 // parse the certificate if this isn't the leaf 766 // node, or if chain.Leaf was nil 767 if j != 0 || x509Cert == nil { 768 if x509Cert, err = x509.ParseCertificate(cert); err != nil { 769 c.sendAlert(alertInternalError) 770 return errors.New("tls: failed to parse client certificate #" + strconv.Itoa(i) + ": " + err.Error()) 771 } 772 } 773 774 switch { 775 case rsaAvail && x509Cert.PublicKeyAlgorithm == x509.RSA: 776 case ecdsaAvail && x509Cert.PublicKeyAlgorithm == x509.ECDSA: 777 default: 778 continue findCert 779 } 780 781 if len(certReq.certificateAuthorities) == 0 { 782 // they gave us an empty list, so just take the 783 // first RSA cert from c.config.Certificates 784 chainToSend = &chain 785 break findCert 786 } 787 788 for _, ca := range certReq.certificateAuthorities { 789 if bytes.Equal(x509Cert.RawIssuer, ca) { 790 chainToSend = &chain 791 break findCert 792 } 793 } 794 } 795 } 796 797 msg, err = c.readHandshake() 798 if err != nil { 799 return err 800 } 801 } 802 803 shd, ok := msg.(*serverHelloDoneMsg) 804 if !ok { 805 c.sendAlert(alertUnexpectedMessage) 806 return unexpectedMessageError(shd, msg) 807 } 808 hs.finishedHash.Write(shd.marshal()) 809 810 // If the server requested a certificate then we have to send a 811 // Certificate message, even if it's empty because we don't have a 812 // certificate to send. 
813 if certRequested { 814 certMsg := new(certificateMsg) 815 if chainToSend != nil { 816 certMsg.certificates = chainToSend.Certificate 817 } 818 hs.finishedHash.Write(certMsg.marshal()) 819 c.writeRecord(recordTypeHandshake, certMsg.marshal()) 820 } 821 822 preMasterSecret, ckx, err := keyAgreement.generateClientKeyExchange(c.config, hs.hello, serverCert) 823 if err != nil { 824 c.sendAlert(alertInternalError) 825 return err 826 } 827 828 c.handshakeLog.ClientKeyExchange = ckx.MakeLog(keyAgreement) 829 830 if ckx != nil { 831 hs.finishedHash.Write(ckx.marshal()) 832 c.writeRecord(recordTypeHandshake, ckx.marshal()) 833 } 834 835 if chainToSend != nil { 836 var signed []byte 837 certVerify := &certificateVerifyMsg{ 838 hasSignatureAndHash: c.vers >= VersionTLS12, 839 } 840 841 // Determine the hash to sign. 842 var signatureType uint8 843 switch c.config.Certificates[0].PrivateKey.(type) { 844 case *ecdsa.PrivateKey: 845 signatureType = signatureECDSA 846 case *rsa.PrivateKey: 847 signatureType = signatureRSA 848 default: 849 c.sendAlert(alertInternalError) 850 return errors.New("unknown private key type") 851 } 852 certVerify.signatureAndHash, err = hs.finishedHash.selectClientCertSignatureAlgorithm(certReq.signatureAndHashes, c.config.signatureAndHashesForClient(), signatureType) 853 if err != nil { 854 c.sendAlert(alertInternalError) 855 return err 856 } 857 digest, hashFunc, err := hs.finishedHash.hashForClientCertificate(certVerify.signatureAndHash, hs.masterSecret) 858 if err != nil { 859 c.sendAlert(alertInternalError) 860 return err 861 } 862 863 switch key := c.config.Certificates[0].PrivateKey.(type) { 864 case *ecdsa.PrivateKey: 865 var r, s *big.Int 866 r, s, err = ecdsa.Sign(c.config.rand(), key, digest) 867 if err == nil { 868 signed, err = asn1.Marshal(ecdsaSignature{r, s}) 869 } 870 case *rsa.PrivateKey: 871 signed, err = rsa.SignPKCS1v15(c.config.rand(), key, hashFunc, digest) 872 default: 873 err = errors.New("unknown private key type") 874 } 875 
if err != nil { 876 c.sendAlert(alertInternalError) 877 return errors.New("tls: failed to sign handshake with client certificate: " + err.Error()) 878 } 879 certVerify.signature = signed 880 881 hs.writeClientHash(certVerify.marshal()) 882 c.writeRecord(recordTypeHandshake, certVerify.marshal()) 883 } 884 885 var cr, sr []byte 886 if hs.hello.extendedRandomEnabled { 887 helloRandomLen := len(hs.hello.random) 888 helloExtendedRandomLen := len(hs.hello.extendedRandom) 889 890 cr = make([]byte, helloRandomLen+helloExtendedRandomLen) 891 copy(cr, hs.hello.random) 892 copy(cr[helloRandomLen:], hs.hello.extendedRandom) 893 } 894 895 if hs.serverHello.extendedRandomEnabled { 896 serverRandomLen := len(hs.serverHello.random) 897 serverExtendedRandomLen := len(hs.serverHello.extendedRandom) 898 899 sr = make([]byte, serverRandomLen+serverExtendedRandomLen) 900 copy(sr, hs.serverHello.random) 901 copy(sr[serverRandomLen:], hs.serverHello.extendedRandom) 902 } 903 904 hs.preMasterSecret = make([]byte, len(preMasterSecret)) 905 copy(hs.preMasterSecret, preMasterSecret) 906 907 if hs.serverHello.extendedMasterSecret && c.vers >= VersionTLS10 { 908 hs.masterSecret = extendedMasterFromPreMasterSecret(c.vers, hs.suite, preMasterSecret, hs.finishedHash) 909 c.extendedMasterSecret = true 910 } else { 911 hs.masterSecret = masterFromPreMasterSecret(c.vers, hs.suite, preMasterSecret, hs.hello.random, hs.serverHello.random) 912 } 913 914 return nil 915 } 916 917 func (hs *clientHandshakeState) establishKeys() error { 918 c := hs.c 919 920 clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV := keysFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.hello.random, hs.serverHello.random, hs.suite.macLen, hs.suite.keyLen, hs.suite.ivLen) 921 var clientCipher, serverCipher interface{} 922 var clientHash, serverHash macFunction 923 if hs.suite.cipher != nil { 924 clientCipher = hs.suite.cipher(clientKey, clientIV, false /* not for reading */) 925 clientHash = 
hs.suite.mac(c.vers, clientMAC) 926 serverCipher = hs.suite.cipher(serverKey, serverIV, true /* for reading */) 927 serverHash = hs.suite.mac(c.vers, serverMAC) 928 } else { 929 clientCipher = hs.suite.aead(clientKey, clientIV) 930 serverCipher = hs.suite.aead(serverKey, serverIV) 931 } 932 933 c.in.prepareCipherSpec(c.vers, serverCipher, serverHash) 934 c.out.prepareCipherSpec(c.vers, clientCipher, clientHash) 935 return nil 936 } 937 938 func (hs *clientHandshakeState) serverResumedSession() bool { 939 // If the server responded with the same sessionId then it means the 940 // sessionTicket is being used to resume a TLS session. 941 return hs.session != nil && hs.hello.sessionId != nil && 942 bytes.Equal(hs.serverHello.sessionId, hs.hello.sessionId) 943 } 944 945 func (hs *clientHandshakeState) processServerHello() (bool, error) { 946 c := hs.c 947 948 if hs.serverHello.compressionMethod != compressionNone { 949 c.sendAlert(alertUnexpectedMessage) 950 return false, errors.New("tls: server selected unsupported compression format") 951 } 952 953 clientDidNPN := hs.hello.nextProtoNeg 954 clientDidALPN := len(hs.hello.alpnProtocols) > 0 955 serverHasNPN := hs.serverHello.nextProtoNeg 956 serverHasALPN := len(hs.serverHello.alpnProtocol) > 0 957 958 if !clientDidNPN && serverHasNPN { 959 c.sendAlert(alertHandshakeFailure) 960 return false, errors.New("tls: server advertised unrequested NPN extension") 961 } 962 963 if !clientDidALPN && serverHasALPN { 964 c.sendAlert(alertHandshakeFailure) 965 return false, errors.New("tls: server advertised unrequested ALPN extension") 966 } 967 968 if serverHasNPN && serverHasALPN { 969 c.sendAlert(alertHandshakeFailure) 970 return false, errors.New("tls: server advertised both NPN and ALPN extensions") 971 } 972 973 if serverHasALPN { 974 c.clientProtocol = hs.serverHello.alpnProtocol 975 c.clientProtocolFallback = false 976 } 977 978 if hs.serverResumedSession() { 979 // Restore masterSecret and peerCerts from previous state 980 
hs.masterSecret = hs.session.masterSecret 981 c.extendedMasterSecret = hs.session.extendedMasterSecret 982 c.peerCertificates = hs.session.serverCertificates 983 return true, nil 984 } 985 return false, nil 986 } 987 988 func (hs *clientHandshakeState) readFinished() error { 989 c := hs.c 990 991 c.readRecord(recordTypeChangeCipherSpec) 992 if err := c.in.error(); err != nil { 993 return err 994 } 995 996 msg, err := c.readHandshake() 997 if err != nil { 998 return err 999 } 1000 serverFinished, ok := msg.(*finishedMsg) 1001 if !ok { 1002 c.sendAlert(alertUnexpectedMessage) 1003 return unexpectedMessageError(serverFinished, msg) 1004 } 1005 c.handshakeLog.ServerFinished = serverFinished.MakeLog() 1006 1007 verify := hs.finishedHash.serverSum(hs.masterSecret) 1008 if len(verify) != len(serverFinished.verifyData) || 1009 subtle.ConstantTimeCompare(verify, serverFinished.verifyData) != 1 { 1010 c.sendAlert(alertHandshakeFailure) 1011 return errors.New("tls: server's Finished message was incorrect") 1012 } 1013 hs.finishedHash.Write(serverFinished.marshal()) 1014 return nil 1015 } 1016 1017 func (hs *clientHandshakeState) readSessionTicket() error { 1018 if !hs.serverHello.ticketSupported { 1019 return nil 1020 } 1021 1022 c := hs.c 1023 msg, err := c.readHandshake() 1024 if err != nil { 1025 return err 1026 } 1027 sessionTicketMsg, ok := msg.(*newSessionTicketMsg) 1028 if !ok { 1029 c.sendAlert(alertUnexpectedMessage) 1030 return unexpectedMessageError(sessionTicketMsg, msg) 1031 } 1032 hs.finishedHash.Write(sessionTicketMsg.marshal()) 1033 1034 hs.session = &ClientSessionState{ 1035 sessionTicket: sessionTicketMsg.ticket, 1036 vers: c.vers, 1037 cipherSuite: hs.suite.id, 1038 masterSecret: hs.masterSecret, 1039 serverCertificates: c.peerCertificates, 1040 lifetimeHint: sessionTicketMsg.lifetimeHint, 1041 } 1042 1043 return nil 1044 } 1045 1046 func (hs *clientHandshakeState) sendFinished() error { 1047 c := hs.c 1048 1049 c.writeRecord(recordTypeChangeCipherSpec, 
[]byte{1}) 1050 if hs.serverHello.nextProtoNeg { 1051 nextProto := new(nextProtoMsg) 1052 proto, fallback := mutualProtocol(c.config.NextProtos, hs.serverHello.nextProtos) 1053 nextProto.proto = proto 1054 c.clientProtocol = proto 1055 c.clientProtocolFallback = fallback 1056 1057 hs.finishedHash.Write(nextProto.marshal()) 1058 c.writeRecord(recordTypeHandshake, nextProto.marshal()) 1059 } 1060 1061 finished := new(finishedMsg) 1062 finished.verifyData = hs.finishedHash.clientSum(hs.masterSecret) 1063 hs.finishedHash.Write(finished.marshal()) 1064 1065 c.handshakeLog.ClientFinished = finished.MakeLog() 1066 1067 c.writeRecord(recordTypeHandshake, finished.marshal()) 1068 return nil 1069 } 1070 1071 func (hs *clientHandshakeState) writeClientHash(msg []byte) { 1072 // writeClientHash is called before writeRecord. 1073 hs.writeHash(msg, 0) 1074 } 1075 1076 func (hs *clientHandshakeState) writeServerHash(msg []byte) { 1077 // writeServerHash is called after readHandshake. 1078 hs.writeHash(msg, 0) 1079 } 1080 1081 func (hs *clientHandshakeState) writeHash(msg []byte, seqno uint16) { 1082 hs.finishedHash.Write(msg) 1083 } 1084 1085 // clientSessionCacheKey returns a key used to cache sessionTickets that could 1086 // be used to resume previously negotiated TLS sessions with a server. 1087 func clientSessionCacheKey(serverAddr net.Addr, config *Config) string { 1088 if len(config.ServerName) > 0 { 1089 return config.ServerName 1090 } 1091 return serverAddr.String() 1092 } 1093 1094 // mutualProtocol finds the mutual Next Protocol Negotiation or ALPN protocol 1095 // given list of possible protocols and a list of the preference order. The 1096 // first list must not be empty. It returns the resulting protocol and flag 1097 // indicating if the fallback case was reached. 
func mutualProtocol(protos, preferenceProtos []string) (string, bool) {
	// With nothing of our own to offer there is neither a match nor a
	// fallback entry. Report the fallback case with an empty protocol
	// instead of panicking on protos[0]. This can occur when a supplied
	// ClientHello advertised NPN while Config.NextProtos is empty.
	if len(protos) == 0 {
		return "", true
	}

	// Walk the peer's list in its preference order; the first protocol we
	// also support wins.
	for _, s := range preferenceProtos {
		for _, c := range protos {
			if s == c {
				return s, false
			}
		}
	}

	// No overlap: fall back to our own first choice.
	return protos[0], true
}