github.com/zmap/zcrypto@v0.0.0-20240512203510-0fef58d9a9db/tls/handshake_client.go (about)

     1  // Copyright 2009 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package tls
     6  
     7  import (
     8  	"bytes"
     9  	"crypto/ecdsa"
    10  	"crypto/rsa"
    11  	"crypto/subtle"
    12  	"encoding/asn1"
    13  	"encoding/binary"
    14  	"errors"
    15  	"fmt"
    16  	"io"
    17  	"math/big"
    18  	"net"
    19  	"strconv"
    20  	"time"
    21  
    22  	"github.com/zmap/zcrypto/dsa"
    23  
    24  	"github.com/zmap/zcrypto/x509"
    25  )
    26  
    27  type clientHandshakeState struct {
    28  	c               *Conn
    29  	serverHello     *serverHelloMsg
    30  	hello           *clientHelloMsg
    31  	suite           *cipherSuite
    32  	finishedHash    finishedHash
    33  	masterSecret    []byte
    34  	preMasterSecret []byte
    35  	session         *ClientSessionState
    36  }
    37  
    38  type CacheKeyGenerator interface {
    39  	Key(net.Addr) string
    40  }
    41  
    42  type ClientFingerprintConfiguration struct {
    43  	// Version in the handshake header
    44  	HandshakeVersion uint16
    45  
    46  	// if len == 32, it will specify the client random.
    47  	// Otherwise, the field will be random
    48  	// except the top 4 bytes if InsertTimestamp is true
    49  	ClientRandom    []byte
    50  	InsertTimestamp bool
    51  
    52  	// if RandomSessionID > 0, will overwrite SessionID w/ that many
    53  	// random bytes when a session resumption occurs
    54  	RandomSessionID int
    55  	SessionID       []byte
    56  
    57  	// These fields will appear exactly in order in the ClientHello
    58  	CipherSuites       []uint16
    59  	CompressionMethods []uint8
    60  	Extensions         []ClientExtension
    61  
    62  	// Optional, both must be non-nil, or neither.
    63  	// Custom Session cache implementations allowed
    64  	SessionCache ClientSessionCache
    65  	CacheKey     CacheKeyGenerator
    66  }
    67  
    68  type ClientExtension interface {
    69  	// Produce the bytes on the wire for this extension, type and length included
    70  	Marshal() []byte
    71  
    72  	// Function will return an error if zTLS does not implement the necessary features for this extension
    73  	CheckImplemented() error
    74  
    75  	// Modifies the config to reflect the state of the extension
    76  	WriteToConfig(*Config) error
    77  }
    78  
    79  func (c *ClientFingerprintConfiguration) CheckImplementedExtensions() error {
    80  	for _, ext := range c.Extensions {
    81  		if err := ext.CheckImplemented(); err != nil {
    82  			return err
    83  		}
    84  	}
    85  	return nil
    86  }
    87  
    88  func (c *clientHelloMsg) WriteToConfig(config *Config) error {
    89  	config.NextProtos = c.alpnProtocols
    90  	config.CipherSuites = c.cipherSuites
    91  	config.MaxVersion = c.vers
    92  	config.ClientRandom = c.random
    93  	config.CurvePreferences = c.supportedCurves
    94  	config.HeartbeatEnabled = c.heartbeatEnabled
    95  	config.ExtendedRandom = c.extendedRandomEnabled
    96  	config.ForceSessionTicketExt = c.ticketSupported
    97  	config.ExtendedMasterSecret = c.extendedMasterSecret
    98  	config.SignedCertificateTimestampExt = c.sctEnabled
    99  	return nil
   100  }
   101  
   102  func (c *ClientFingerprintConfiguration) WriteToConfig(config *Config) error {
   103  	config.NextProtos = []string{}
   104  	config.CipherSuites = c.CipherSuites
   105  	config.MaxVersion = c.HandshakeVersion
   106  	config.ClientRandom = c.ClientRandom
   107  	config.CurvePreferences = []CurveID{}
   108  	config.HeartbeatEnabled = false
   109  	config.ExtendedRandom = false
   110  	config.ForceSessionTicketExt = false
   111  	config.ExtendedMasterSecret = false
   112  	config.SignedCertificateTimestampExt = false
   113  	for _, ext := range c.Extensions {
   114  		if err := ext.WriteToConfig(config); err != nil {
   115  			return err
   116  		}
   117  	}
   118  	return nil
   119  }
   120  
   121  func currentTimestamp() ([]byte, error) {
   122  	t := time.Now().Unix()
   123  	buf := new(bytes.Buffer)
   124  	err := binary.Write(buf, binary.BigEndian, t)
   125  	return buf.Bytes(), err
   126  }
   127  
   128  func (c *ClientFingerprintConfiguration) marshal(config *Config) ([]byte, error) {
   129  	if err := c.CheckImplementedExtensions(); err != nil {
   130  		return nil, err
   131  	}
   132  	head := make([]byte, 38)
   133  	head[0] = 1
   134  	head[4] = uint8(c.HandshakeVersion >> 8)
   135  	head[5] = uint8(c.HandshakeVersion)
   136  	if len(c.ClientRandom) == 32 {
   137  		copy(head[6:38], c.ClientRandom[0:32])
   138  	} else {
   139  		start := 6
   140  		if c.InsertTimestamp {
   141  			t, err := currentTimestamp()
   142  			if err != nil {
   143  				return nil, err
   144  			}
   145  			copy(head[start:start+4], t)
   146  			start = start + 4
   147  		}
   148  		_, err := io.ReadFull(config.rand(), head[start:38])
   149  		if err != nil {
   150  			return nil, errors.New("tls: short read from Rand: " + err.Error())
   151  		}
   152  	}
   153  
   154  	if len(c.SessionID) >= 256 {
   155  		return nil, errors.New("tls: SessionID too long")
   156  	}
   157  	sessionID := make([]byte, len(c.SessionID)+1)
   158  	sessionID[0] = uint8(len(c.SessionID))
   159  	if len(c.SessionID) > 0 {
   160  		copy(sessionID[1:], c.SessionID)
   161  	}
   162  
   163  	ciphers := make([]byte, 2+2*len(c.CipherSuites))
   164  	ciphers[0] = uint8(len(c.CipherSuites) >> 7)
   165  	ciphers[1] = uint8(len(c.CipherSuites) << 1)
   166  	for i, suite := range c.CipherSuites {
   167  		if !config.ForceSuites {
   168  			found := false
   169  			for _, impl := range implementedCipherSuites {
   170  				if impl.id == suite {
   171  					found = true
   172  				}
   173  			}
   174  			if !found {
   175  				return nil, errors.New(fmt.Sprintf("tls: unimplemented cipher suite %d", suite))
   176  			}
   177  		}
   178  
   179  		ciphers[2+i*2] = uint8(suite >> 8)
   180  		ciphers[3+i*2] = uint8(suite)
   181  	}
   182  
   183  	if len(c.CompressionMethods) >= 256 {
   184  		return nil, errors.New("tls: Too many compression methods")
   185  	}
   186  	compressions := make([]byte, len(c.CompressionMethods)+1)
   187  	compressions[0] = uint8(len(c.CompressionMethods))
   188  	if len(c.CompressionMethods) > 0 {
   189  		copy(compressions[1:], c.CompressionMethods)
   190  		if c.CompressionMethods[0] != 0 {
   191  			return nil, errors.New(fmt.Sprintf("tls: unimplemented compression method %d", c.CompressionMethods[0]))
   192  		}
   193  		if len(c.CompressionMethods) > 1 {
   194  			return nil, errors.New(fmt.Sprintf("tls: unimplemented compression method %d", c.CompressionMethods[1]))
   195  		}
   196  	} else {
   197  		return nil, errors.New("tls: no compression method")
   198  	}
   199  
   200  	var extensions []byte
   201  	for _, ext := range c.Extensions {
   202  		extensions = append(extensions, ext.Marshal()...)
   203  	}
   204  	if len(extensions) > 0 {
   205  		length := make([]byte, 2)
   206  		length[0] = uint8(len(extensions) >> 8)
   207  		length[1] = uint8(len(extensions))
   208  		extensions = append(length, extensions...)
   209  	}
   210  	helloArray := [][]byte{head, sessionID, ciphers, compressions, extensions}
   211  	hello := []byte{}
   212  	for _, b := range helloArray {
   213  		hello = append(hello, b...)
   214  	}
   215  	lengthOnTheWire := len(hello) - 4
   216  	if lengthOnTheWire >= 1<<24 {
   217  		return nil, errors.New("ClientHello message too long")
   218  	}
   219  	hello[1] = uint8(lengthOnTheWire >> 16)
   220  	hello[2] = uint8(lengthOnTheWire >> 8)
   221  	hello[3] = uint8(lengthOnTheWire)
   222  
   223  	return hello, nil
   224  }
   225  
   226  func (c *Conn) clientHandshake() error {
   227  	if c.config == nil {
   228  		c.config = defaultConfig()
   229  	}
   230  	var hello *clientHelloMsg
   231  	var helloBytes []byte
   232  	var session *ClientSessionState
   233  	var sessionCache ClientSessionCache
   234  	var cacheKey string
   235  
   236  	// first, let's check if a ClientFingerprintConfiguration template was provided by the config
   237  	if c.config.ClientFingerprintConfiguration != nil {
   238  		if err := c.config.ClientFingerprintConfiguration.WriteToConfig(c.config); err != nil {
   239  			return err
   240  		}
   241  		session = nil
   242  		sessionCache = c.config.ClientFingerprintConfiguration.SessionCache
   243  		if sessionCache != nil {
   244  			if c.config.ClientFingerprintConfiguration.CacheKey == nil {
   245  				return errors.New("tls: must specify CacheKey if SessionCache is defined in Config.ClientFingerprintConfiguration")
   246  			}
   247  			cacheKey = c.config.ClientFingerprintConfiguration.CacheKey.Key(c.conn.RemoteAddr())
   248  			candidateSession, ok := sessionCache.Get(cacheKey)
   249  			if ok {
   250  				cipherSuiteOk := false
   251  				for _, id := range c.config.ClientFingerprintConfiguration.CipherSuites {
   252  					if id == candidateSession.cipherSuite {
   253  						cipherSuiteOk = true
   254  						break
   255  					}
   256  				}
   257  				versOk := candidateSession.vers >= c.config.minVersion() &&
   258  					candidateSession.vers <= c.config.ClientFingerprintConfiguration.HandshakeVersion
   259  				if versOk && cipherSuiteOk {
   260  					session = candidateSession
   261  				}
   262  			}
   263  		}
   264  		for i, ext := range c.config.ClientFingerprintConfiguration.Extensions {
   265  			switch casted := ext.(type) {
   266  			case *SessionTicketExtension:
   267  				if casted.Autopopulate {
   268  					if session == nil {
   269  						if !c.config.ForceSessionTicketExt {
   270  							c.config.ClientFingerprintConfiguration.Extensions[i] = &NullExtension{}
   271  						}
   272  					} else {
   273  						c.config.ClientFingerprintConfiguration.Extensions[i] = &SessionTicketExtension{session.sessionTicket, true}
   274  						if c.config.ClientFingerprintConfiguration.RandomSessionID > 0 {
   275  							c.config.ClientFingerprintConfiguration.SessionID = make([]byte, c.config.ClientFingerprintConfiguration.RandomSessionID)
   276  							if _, err := io.ReadFull(c.config.rand(), c.config.ClientFingerprintConfiguration.SessionID); err != nil {
   277  								c.sendAlert(alertInternalError)
   278  								return errors.New("tls: short read from Rand: " + err.Error())
   279  							}
   280  
   281  						}
   282  					}
   283  				}
   284  			}
   285  		}
   286  		var err error
   287  		helloBytes, err = c.config.ClientFingerprintConfiguration.marshal(c.config)
   288  		if err != nil {
   289  			return err
   290  		}
   291  		hello = &clientHelloMsg{}
   292  		if ok := hello.unmarshal(helloBytes); !ok {
   293  			return errors.New("tls: incompatible ClientFingerprintConfiguration")
   294  		}
   295  
   296  		// next, let's check if a ClientHello template was provided by the user
   297  	} else if c.config.ExternalClientHello != nil {
   298  
   299  		hello = new(clientHelloMsg)
   300  
   301  		if !hello.unmarshal(c.config.ExternalClientHello) {
   302  			return errors.New("could not read the ClientHello provided")
   303  		}
   304  		if err := hello.WriteToConfig(c.config); err != nil {
   305  			return err
   306  		}
   307  
   308  		// update the SNI with one name, whether or not the extension was already there
   309  		hello.serverName = c.config.ServerName
   310  
   311  		// then we update the 'raw' value of the message
   312  		hello.raw = nil
   313  		helloBytes = hello.marshal()
   314  
   315  		session = nil
   316  		sessionCache = nil
   317  	} else {
   318  		if len(c.config.ServerName) == 0 && !c.config.InsecureSkipVerify {
   319  			return errors.New("tls: either ServerName or InsecureSkipVerify must be specified in the tls.Config")
   320  		}
   321  
   322  		supportedPoints := []uint8{pointFormatUncompressed}
   323  		if c.config.SupportedPoints != nil {
   324  			supportedPoints = c.config.SupportedPoints
   325  		}
   326  		oscpStapling := true
   327  		if c.config.NoOcspStapling {
   328  			oscpStapling = false
   329  		}
   330  
   331  		compressionMethods := []uint8{compressionNone}
   332  		if c.config.CompressionMethods != nil {
   333  			compressionMethods = c.config.CompressionMethods
   334  		}
   335  
   336  		hello = &clientHelloMsg{
   337  			vers:                 c.config.maxVersion(),
   338  			compressionMethods:   compressionMethods,
   339  			random:               make([]byte, 32),
   340  			ocspStapling:         oscpStapling,
   341  			serverName:           c.config.ServerName,
   342  			supportedCurves:      c.config.curvePreferences(),
   343  			supportedPoints:      supportedPoints,
   344  			nextProtoNeg:         len(c.config.NextProtos) > 0,
   345  			secureRenegotiation:  true,
   346  			alpnProtocols:        c.config.NextProtos,
   347  			extendedMasterSecret: c.config.maxVersion() >= VersionTLS10 && c.config.ExtendedMasterSecret,
   348  		}
   349  
   350  		if c.config.ForceSessionTicketExt {
   351  			hello.ticketSupported = true
   352  		}
   353  		if c.config.SignedCertificateTimestampExt {
   354  			hello.sctEnabled = true
   355  		}
   356  
   357  		if c.config.HeartbeatEnabled && !c.config.ExtendedRandom {
   358  			hello.heartbeatEnabled = true
   359  			hello.heartbeatMode = heartbeatModePeerAllowed
   360  		}
   361  
   362  		possibleCipherSuites := c.config.cipherSuites()
   363  		hello.cipherSuites = make([]uint16, 0, len(possibleCipherSuites))
   364  
   365  		if c.config.ForceSuites {
   366  			hello.cipherSuites = possibleCipherSuites
   367  		} else {
   368  
   369  		NextCipherSuite:
   370  			for _, suiteId := range possibleCipherSuites {
   371  				for _, suite := range implementedCipherSuites {
   372  					if suite.id != suiteId {
   373  						continue
   374  					}
   375  					// Don't advertise TLS 1.2-only cipher suites unless
   376  					// we're attempting TLS 1.2.
   377  					if hello.vers < VersionTLS12 && suite.flags&suiteTLS12 != 0 {
   378  						continue
   379  					}
   380  					hello.cipherSuites = append(hello.cipherSuites, suiteId)
   381  					continue NextCipherSuite
   382  				}
   383  			}
   384  		}
   385  
   386  		if len(c.config.ClientRandom) == 32 {
   387  			copy(hello.random, c.config.ClientRandom)
   388  		} else {
   389  			_, err := io.ReadFull(c.config.rand(), hello.random)
   390  			if err != nil {
   391  				c.sendAlert(alertInternalError)
   392  				return errors.New("tls: short read from Rand: " + err.Error())
   393  			}
   394  		}
   395  
   396  		if c.config.ExtendedRandom {
   397  			hello.extendedRandomEnabled = true
   398  			hello.extendedRandom = make([]byte, 32)
   399  			if _, err := io.ReadFull(c.config.rand(), hello.extendedRandom); err != nil {
   400  				return errors.New("tls: short read from Rand: " + err.Error())
   401  			}
   402  		}
   403  
   404  		if hello.vers >= VersionTLS12 {
   405  			hello.signatureAndHashes = c.config.signatureAndHashesForClient()
   406  		}
   407  
   408  		sessionCache = c.config.ClientSessionCache
   409  		if c.config.SessionTicketsDisabled {
   410  			sessionCache = nil
   411  		}
   412  		if sessionCache != nil {
   413  			hello.ticketSupported = true
   414  
   415  			// Try to resume a previously negotiated TLS session, if
   416  			// available.
   417  			cacheKey = clientSessionCacheKey(c.conn.RemoteAddr(), c.config)
   418  			candidateSession, ok := sessionCache.Get(cacheKey)
   419  			if ok {
   420  				// Check that the ciphersuite/version used for the
   421  				// previous session are still valid.
   422  				cipherSuiteOk := false
   423  				for _, id := range hello.cipherSuites {
   424  					if id == candidateSession.cipherSuite {
   425  						cipherSuiteOk = true
   426  						break
   427  					}
   428  				}
   429  
   430  				versOk := candidateSession.vers >= c.config.minVersion() &&
   431  					candidateSession.vers <= c.config.maxVersion()
   432  				if versOk && cipherSuiteOk {
   433  					session = candidateSession
   434  				}
   435  			}
   436  		}
   437  
   438  		if session != nil {
   439  			hello.sessionTicket = session.sessionTicket
   440  			// A random session ID is used to detect when the
   441  			// server accepted the ticket and is resuming a session
   442  			// (see RFC 5077).
   443  			hello.sessionId = make([]byte, 16)
   444  			if _, err := io.ReadFull(c.config.rand(), hello.sessionId); err != nil {
   445  				c.sendAlert(alertInternalError)
   446  				return errors.New("tls: short read from Rand: " + err.Error())
   447  			}
   448  
   449  		}
   450  
   451  		helloBytes = hello.marshal()
   452  	}
   453  
   454  	c.handshakeLog = new(ServerHandshake)
   455  	c.heartbleedLog = new(Heartbleed)
   456  	c.writeRecord(recordTypeHandshake, helloBytes)
   457  	c.handshakeLog.ClientHello = hello.MakeLog()
   458  
   459  	msg, err := c.readHandshake()
   460  	if err != nil {
   461  		return err
   462  	}
   463  	serverHello, ok := msg.(*serverHelloMsg)
   464  	if !ok {
   465  		c.sendAlert(alertUnexpectedMessage)
   466  		return unexpectedMessageError(serverHello, msg)
   467  	}
   468  	c.handshakeLog.ServerHello = serverHello.MakeLog()
   469  
   470  	if serverHello.heartbeatEnabled {
   471  		c.heartbeat = true
   472  		c.heartbleedLog.HeartbeatEnabled = true
   473  	}
   474  
   475  	vers, ok := c.config.mutualVersion(serverHello.vers)
   476  	if !ok {
   477  		c.sendAlert(alertProtocolVersion)
   478  		return fmt.Errorf("tls: server selected unsupported protocol version %x", serverHello.vers)
   479  	}
   480  	c.vers = vers
   481  	c.haveVers = true
   482  
   483  	suite := mutualCipherSuite(c.config.cipherSuites(), serverHello.cipherSuite)
   484  	cipherImplemented := cipherIDInCipherList(serverHello.cipherSuite, implementedCipherSuites)
   485  	cipherShared := cipherIDInCipherIDList(serverHello.cipherSuite, c.config.cipherSuites())
   486  	if suite == nil {
   487  		// c.sendAlert(alertHandshakeFailure)
   488  		if !cipherShared {
   489  			c.cipherError = ErrNoMutualCipher
   490  		} else if !cipherImplemented {
   491  			c.cipherError = ErrUnimplementedCipher
   492  		}
   493  	}
   494  
   495  	hs := &clientHandshakeState{
   496  		c:            c,
   497  		serverHello:  serverHello,
   498  		hello:        hello,
   499  		suite:        suite,
   500  		finishedHash: newFinishedHash(c.vers, suite),
   501  		session:      session,
   502  	}
   503  
   504  	hs.finishedHash.Write(helloBytes)
   505  	hs.finishedHash.Write(hs.serverHello.marshal())
   506  
   507  	isResume, err := hs.processServerHello()
   508  	if err != nil {
   509  		return err
   510  	}
   511  	if !c.config.DontBufferHandshakes {
   512  		c.buffering = true
   513  		defer c.flush()
   514  	}
   515  	if isResume {
   516  		if c.cipherError != nil {
   517  			c.sendAlert(alertHandshakeFailure)
   518  			return c.cipherError
   519  		}
   520  		if err := hs.establishKeys(); err != nil {
   521  			return err
   522  		}
   523  		if err := hs.readSessionTicket(); err != nil {
   524  			return err
   525  		}
   526  		if err := hs.readFinished(); err != nil {
   527  			return err
   528  		}
   529  		if err := hs.sendFinished(); err != nil {
   530  			return err
   531  		}
   532  		if _, err := c.flush(); err != nil {
   533  			return err
   534  		}
   535  	} else {
   536  		if err := hs.doFullHandshake(); err != nil {
   537  			if err == ErrCertsOnly {
   538  				c.sendAlert(alertCloseNotify)
   539  			}
   540  			return err
   541  		}
   542  		if err := hs.establishKeys(); err != nil {
   543  			return err
   544  		}
   545  		if err := hs.sendFinished(); err != nil {
   546  			return err
   547  		}
   548  		if _, err := c.flush(); err != nil {
   549  			return err
   550  		}
   551  		if err := hs.readSessionTicket(); err != nil {
   552  			return err
   553  		}
   554  		if err := hs.readFinished(); err != nil {
   555  			return err
   556  		}
   557  	}
   558  
   559  	if hs.session == nil {
   560  		c.handshakeLog.SessionTicket = nil
   561  	} else {
   562  		c.handshakeLog.SessionTicket = hs.session.MakeLog()
   563  	}
   564  
   565  	c.handshakeLog.KeyMaterial = hs.MakeLog()
   566  
   567  	if sessionCache != nil && hs.session != nil && session != hs.session {
   568  		sessionCache.Put(cacheKey, hs.session)
   569  	}
   570  
   571  	c.didResume = isResume
   572  	c.handshakeComplete = true
   573  	c.cipherSuite = suite.id
   574  	return nil
   575  }
   576  
   577  func (hs *clientHandshakeState) doFullHandshake() error {
   578  	c := hs.c
   579  
   580  	msg, err := c.readHandshake()
   581  	if err != nil {
   582  		return err
   583  	}
   584  
   585  	var serverCert *x509.Certificate
   586  
   587  	isAnon := hs.suite != nil && (hs.suite.flags&suiteAnon > 0)
   588  
   589  	if !isAnon {
   590  
   591  		certMsg, ok := msg.(*certificateMsg)
   592  		if !ok || len(certMsg.certificates) == 0 {
   593  			c.sendAlert(alertUnexpectedMessage)
   594  			return unexpectedMessageError(certMsg, msg)
   595  		}
   596  		hs.finishedHash.Write(certMsg.marshal())
   597  
   598  		certs := make([]*x509.Certificate, len(certMsg.certificates))
   599  		invalidCert := false
   600  		var invalidCertErr error
   601  		for i, asn1Data := range certMsg.certificates {
   602  			cert, err := x509.ParseCertificate(asn1Data)
   603  			if err != nil {
   604  				invalidCert = true
   605  				invalidCertErr = err
   606  				break
   607  			}
   608  			certs[i] = cert
   609  		}
   610  
   611  		c.handshakeLog.ServerCertificates = certMsg.MakeLog()
   612  
   613  		if c.config.CertsOnly {
   614  			// short circuit!
   615  			err = ErrCertsOnly
   616  			return err
   617  		}
   618  
   619  		if !invalidCert {
   620  			opts := x509.VerifyOptions{
   621  				Roots:         c.config.RootCAs,
   622  				CurrentTime:   c.config.time(),
   623  				DNSName:       c.config.ServerName,
   624  				Intermediates: x509.NewCertPool(),
   625  			}
   626  
   627  			// Always check validity of the certificates
   628  			for _, cert := range certs {
   629  				/*
   630  					if i == 0 {
   631  						continue
   632  					}
   633  				*/
   634  				opts.Intermediates.AddCert(cert)
   635  			}
   636  			var validation *x509.Validation
   637  			c.verifiedChains, validation, err = certs[0].ValidateWithStupidDetail(opts)
   638  			c.handshakeLog.ServerCertificates.addParsed(certs, validation)
   639  
   640  			// If actually verifying and invalid, reject
   641  			if !c.config.InsecureSkipVerify {
   642  				if err != nil {
   643  					c.sendAlert(alertBadCertificate)
   644  					return err
   645  				}
   646  			}
   647  		}
   648  
   649  		if invalidCert {
   650  			c.sendAlert(alertBadCertificate)
   651  			return errors.New("tls: failed to parse certificate from server: " + invalidCertErr.Error())
   652  		}
   653  
   654  		c.peerCertificates = certs
   655  
   656  		if hs.serverHello.ocspStapling {
   657  			msg, err = c.readHandshake()
   658  			if err != nil {
   659  				return err
   660  			}
   661  			cs, ok := msg.(*certificateStatusMsg)
   662  			if !ok {
   663  				c.sendAlert(alertUnexpectedMessage)
   664  				return unexpectedMessageError(cs, msg)
   665  			}
   666  			hs.finishedHash.Write(cs.marshal())
   667  
   668  			if cs.statusType == statusTypeOCSP {
   669  				c.ocspResponse = cs.response
   670  			}
   671  		}
   672  
   673  		serverCert = certs[0]
   674  
   675  		var supportedCertKeyType bool
   676  		switch serverCert.PublicKey.(type) {
   677  		case *rsa.PublicKey, *ecdsa.PublicKey, *x509.AugmentedECDSA:
   678  			supportedCertKeyType = true
   679  			break
   680  		case *dsa.PublicKey:
   681  			if c.config.ClientDSAEnabled {
   682  				supportedCertKeyType = true
   683  			}
   684  		default:
   685  			break
   686  		}
   687  
   688  		if !supportedCertKeyType {
   689  			c.sendAlert(alertUnsupportedCertificate)
   690  			return fmt.Errorf("tls: server's certificate contains an unsupported type of public key: %T", serverCert.PublicKey)
   691  		}
   692  
   693  		msg, err = c.readHandshake()
   694  		if err != nil {
   695  			return err
   696  		}
   697  	}
   698  
   699  	// If we don't support the cipher, quit before we need to read the hs.suite
   700  	// variable
   701  	if c.cipherError != nil {
   702  		return c.cipherError
   703  	}
   704  
   705  	skx, ok := msg.(*serverKeyExchangeMsg)
   706  
   707  	keyAgreement := hs.suite.ka(c.vers)
   708  
   709  	if ok {
   710  		hs.finishedHash.Write(skx.marshal())
   711  
   712  		err = keyAgreement.processServerKeyExchange(c.config, hs.hello, hs.serverHello, serverCert, skx)
   713  		c.handshakeLog.ServerKeyExchange = skx.MakeLog(keyAgreement)
   714  		if err != nil {
   715  			c.sendAlert(alertUnexpectedMessage)
   716  			return err
   717  		}
   718  
   719  		msg, err = c.readHandshake()
   720  		if err != nil {
   721  			return err
   722  		}
   723  	}
   724  
   725  	var chainToSend *Certificate
   726  	var certRequested bool
   727  	certReq, ok := msg.(*certificateRequestMsg)
   728  	if ok {
   729  		certRequested = true
   730  
   731  		// RFC 4346 on the certificateAuthorities field:
   732  		// A list of the distinguished names of acceptable certificate
   733  		// authorities. These distinguished names may specify a desired
   734  		// distinguished name for a root CA or for a subordinate CA;
   735  		// thus, this message can be used to describe both known roots
   736  		// and a desired authorization space. If the
   737  		// certificate_authorities list is empty then the client MAY
   738  		// send any certificate of the appropriate
   739  		// ClientCertificateType, unless there is some external
   740  		// arrangement to the contrary.
   741  
   742  		hs.finishedHash.Write(certReq.marshal())
   743  
   744  		var rsaAvail, ecdsaAvail bool
   745  		for _, certType := range certReq.certificateTypes {
   746  			switch certType {
   747  			case certTypeRSASign:
   748  				rsaAvail = true
   749  			case certTypeECDSASign:
   750  				ecdsaAvail = true
   751  			}
   752  		}
   753  
   754  		// We need to search our list of client certs for one
   755  		// where SignatureAlgorithm is RSA and the Issuer is in
   756  		// certReq.certificateAuthorities
   757  	findCert:
   758  		for i, chain := range c.config.Certificates {
   759  			if !rsaAvail && !ecdsaAvail {
   760  				continue
   761  			}
   762  
   763  			for j, cert := range chain.Certificate {
   764  				x509Cert := chain.Leaf
   765  				// parse the certificate if this isn't the leaf
   766  				// node, or if chain.Leaf was nil
   767  				if j != 0 || x509Cert == nil {
   768  					if x509Cert, err = x509.ParseCertificate(cert); err != nil {
   769  						c.sendAlert(alertInternalError)
   770  						return errors.New("tls: failed to parse client certificate #" + strconv.Itoa(i) + ": " + err.Error())
   771  					}
   772  				}
   773  
   774  				switch {
   775  				case rsaAvail && x509Cert.PublicKeyAlgorithm == x509.RSA:
   776  				case ecdsaAvail && x509Cert.PublicKeyAlgorithm == x509.ECDSA:
   777  				default:
   778  					continue findCert
   779  				}
   780  
   781  				if len(certReq.certificateAuthorities) == 0 {
   782  					// they gave us an empty list, so just take the
   783  					// first RSA cert from c.config.Certificates
   784  					chainToSend = &chain
   785  					break findCert
   786  				}
   787  
   788  				for _, ca := range certReq.certificateAuthorities {
   789  					if bytes.Equal(x509Cert.RawIssuer, ca) {
   790  						chainToSend = &chain
   791  						break findCert
   792  					}
   793  				}
   794  			}
   795  		}
   796  
   797  		msg, err = c.readHandshake()
   798  		if err != nil {
   799  			return err
   800  		}
   801  	}
   802  
   803  	shd, ok := msg.(*serverHelloDoneMsg)
   804  	if !ok {
   805  		c.sendAlert(alertUnexpectedMessage)
   806  		return unexpectedMessageError(shd, msg)
   807  	}
   808  	hs.finishedHash.Write(shd.marshal())
   809  
   810  	// If the server requested a certificate then we have to send a
   811  	// Certificate message, even if it's empty because we don't have a
   812  	// certificate to send.
   813  	if certRequested {
   814  		certMsg := new(certificateMsg)
   815  		if chainToSend != nil {
   816  			certMsg.certificates = chainToSend.Certificate
   817  		}
   818  		hs.finishedHash.Write(certMsg.marshal())
   819  		c.writeRecord(recordTypeHandshake, certMsg.marshal())
   820  	}
   821  
   822  	preMasterSecret, ckx, err := keyAgreement.generateClientKeyExchange(c.config, hs.hello, serverCert)
   823  	if err != nil {
   824  		c.sendAlert(alertInternalError)
   825  		return err
   826  	}
   827  
   828  	c.handshakeLog.ClientKeyExchange = ckx.MakeLog(keyAgreement)
   829  
   830  	if ckx != nil {
   831  		hs.finishedHash.Write(ckx.marshal())
   832  		c.writeRecord(recordTypeHandshake, ckx.marshal())
   833  	}
   834  
   835  	if chainToSend != nil {
   836  		var signed []byte
   837  		certVerify := &certificateVerifyMsg{
   838  			hasSignatureAndHash: c.vers >= VersionTLS12,
   839  		}
   840  
   841  		// Determine the hash to sign.
   842  		var signatureType uint8
   843  		switch c.config.Certificates[0].PrivateKey.(type) {
   844  		case *ecdsa.PrivateKey:
   845  			signatureType = signatureECDSA
   846  		case *rsa.PrivateKey:
   847  			signatureType = signatureRSA
   848  		default:
   849  			c.sendAlert(alertInternalError)
   850  			return errors.New("unknown private key type")
   851  		}
   852  		certVerify.signatureAndHash, err = hs.finishedHash.selectClientCertSignatureAlgorithm(certReq.signatureAndHashes, c.config.signatureAndHashesForClient(), signatureType)
   853  		if err != nil {
   854  			c.sendAlert(alertInternalError)
   855  			return err
   856  		}
   857  		digest, hashFunc, err := hs.finishedHash.hashForClientCertificate(certVerify.signatureAndHash, hs.masterSecret)
   858  		if err != nil {
   859  			c.sendAlert(alertInternalError)
   860  			return err
   861  		}
   862  
   863  		switch key := c.config.Certificates[0].PrivateKey.(type) {
   864  		case *ecdsa.PrivateKey:
   865  			var r, s *big.Int
   866  			r, s, err = ecdsa.Sign(c.config.rand(), key, digest)
   867  			if err == nil {
   868  				signed, err = asn1.Marshal(ecdsaSignature{r, s})
   869  			}
   870  		case *rsa.PrivateKey:
   871  			signed, err = rsa.SignPKCS1v15(c.config.rand(), key, hashFunc, digest)
   872  		default:
   873  			err = errors.New("unknown private key type")
   874  		}
   875  		if err != nil {
   876  			c.sendAlert(alertInternalError)
   877  			return errors.New("tls: failed to sign handshake with client certificate: " + err.Error())
   878  		}
   879  		certVerify.signature = signed
   880  
   881  		hs.writeClientHash(certVerify.marshal())
   882  		c.writeRecord(recordTypeHandshake, certVerify.marshal())
   883  	}
   884  
   885  	var cr, sr []byte
   886  	if hs.hello.extendedRandomEnabled {
   887  		helloRandomLen := len(hs.hello.random)
   888  		helloExtendedRandomLen := len(hs.hello.extendedRandom)
   889  
   890  		cr = make([]byte, helloRandomLen+helloExtendedRandomLen)
   891  		copy(cr, hs.hello.random)
   892  		copy(cr[helloRandomLen:], hs.hello.extendedRandom)
   893  	}
   894  
   895  	if hs.serverHello.extendedRandomEnabled {
   896  		serverRandomLen := len(hs.serverHello.random)
   897  		serverExtendedRandomLen := len(hs.serverHello.extendedRandom)
   898  
   899  		sr = make([]byte, serverRandomLen+serverExtendedRandomLen)
   900  		copy(sr, hs.serverHello.random)
   901  		copy(sr[serverRandomLen:], hs.serverHello.extendedRandom)
   902  	}
   903  
   904  	hs.preMasterSecret = make([]byte, len(preMasterSecret))
   905  	copy(hs.preMasterSecret, preMasterSecret)
   906  
   907  	if hs.serverHello.extendedMasterSecret && c.vers >= VersionTLS10 {
   908  		hs.masterSecret = extendedMasterFromPreMasterSecret(c.vers, hs.suite, preMasterSecret, hs.finishedHash)
   909  		c.extendedMasterSecret = true
   910  	} else {
   911  		hs.masterSecret = masterFromPreMasterSecret(c.vers, hs.suite, preMasterSecret, hs.hello.random, hs.serverHello.random)
   912  	}
   913  
   914  	return nil
   915  }
   916  
   917  func (hs *clientHandshakeState) establishKeys() error {
   918  	c := hs.c
   919  
   920  	clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV := keysFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.hello.random, hs.serverHello.random, hs.suite.macLen, hs.suite.keyLen, hs.suite.ivLen)
   921  	var clientCipher, serverCipher interface{}
   922  	var clientHash, serverHash macFunction
   923  	if hs.suite.cipher != nil {
   924  		clientCipher = hs.suite.cipher(clientKey, clientIV, false /* not for reading */)
   925  		clientHash = hs.suite.mac(c.vers, clientMAC)
   926  		serverCipher = hs.suite.cipher(serverKey, serverIV, true /* for reading */)
   927  		serverHash = hs.suite.mac(c.vers, serverMAC)
   928  	} else {
   929  		clientCipher = hs.suite.aead(clientKey, clientIV)
   930  		serverCipher = hs.suite.aead(serverKey, serverIV)
   931  	}
   932  
   933  	c.in.prepareCipherSpec(c.vers, serverCipher, serverHash)
   934  	c.out.prepareCipherSpec(c.vers, clientCipher, clientHash)
   935  	return nil
   936  }
   937  
   938  func (hs *clientHandshakeState) serverResumedSession() bool {
   939  	// If the server responded with the same sessionId then it means the
   940  	// sessionTicket is being used to resume a TLS session.
   941  	return hs.session != nil && hs.hello.sessionId != nil &&
   942  		bytes.Equal(hs.serverHello.sessionId, hs.hello.sessionId)
   943  }
   944  
   945  func (hs *clientHandshakeState) processServerHello() (bool, error) {
   946  	c := hs.c
   947  
   948  	if hs.serverHello.compressionMethod != compressionNone {
   949  		c.sendAlert(alertUnexpectedMessage)
   950  		return false, errors.New("tls: server selected unsupported compression format")
   951  	}
   952  
   953  	clientDidNPN := hs.hello.nextProtoNeg
   954  	clientDidALPN := len(hs.hello.alpnProtocols) > 0
   955  	serverHasNPN := hs.serverHello.nextProtoNeg
   956  	serverHasALPN := len(hs.serverHello.alpnProtocol) > 0
   957  
   958  	if !clientDidNPN && serverHasNPN {
   959  		c.sendAlert(alertHandshakeFailure)
   960  		return false, errors.New("tls: server advertised unrequested NPN extension")
   961  	}
   962  
   963  	if !clientDidALPN && serverHasALPN {
   964  		c.sendAlert(alertHandshakeFailure)
   965  		return false, errors.New("tls: server advertised unrequested ALPN extension")
   966  	}
   967  
   968  	if serverHasNPN && serverHasALPN {
   969  		c.sendAlert(alertHandshakeFailure)
   970  		return false, errors.New("tls: server advertised both NPN and ALPN extensions")
   971  	}
   972  
   973  	if serverHasALPN {
   974  		c.clientProtocol = hs.serverHello.alpnProtocol
   975  		c.clientProtocolFallback = false
   976  	}
   977  
   978  	if hs.serverResumedSession() {
   979  		// Restore masterSecret and peerCerts from previous state
   980  		hs.masterSecret = hs.session.masterSecret
   981  		c.extendedMasterSecret = hs.session.extendedMasterSecret
   982  		c.peerCertificates = hs.session.serverCertificates
   983  		return true, nil
   984  	}
   985  	return false, nil
   986  }
   987  
   988  func (hs *clientHandshakeState) readFinished() error {
   989  	c := hs.c
   990  
   991  	c.readRecord(recordTypeChangeCipherSpec)
   992  	if err := c.in.error(); err != nil {
   993  		return err
   994  	}
   995  
   996  	msg, err := c.readHandshake()
   997  	if err != nil {
   998  		return err
   999  	}
  1000  	serverFinished, ok := msg.(*finishedMsg)
  1001  	if !ok {
  1002  		c.sendAlert(alertUnexpectedMessage)
  1003  		return unexpectedMessageError(serverFinished, msg)
  1004  	}
  1005  	c.handshakeLog.ServerFinished = serverFinished.MakeLog()
  1006  
  1007  	verify := hs.finishedHash.serverSum(hs.masterSecret)
  1008  	if len(verify) != len(serverFinished.verifyData) ||
  1009  		subtle.ConstantTimeCompare(verify, serverFinished.verifyData) != 1 {
  1010  		c.sendAlert(alertHandshakeFailure)
  1011  		return errors.New("tls: server's Finished message was incorrect")
  1012  	}
  1013  	hs.finishedHash.Write(serverFinished.marshal())
  1014  	return nil
  1015  }
  1016  
  1017  func (hs *clientHandshakeState) readSessionTicket() error {
  1018  	if !hs.serverHello.ticketSupported {
  1019  		return nil
  1020  	}
  1021  
  1022  	c := hs.c
  1023  	msg, err := c.readHandshake()
  1024  	if err != nil {
  1025  		return err
  1026  	}
  1027  	sessionTicketMsg, ok := msg.(*newSessionTicketMsg)
  1028  	if !ok {
  1029  		c.sendAlert(alertUnexpectedMessage)
  1030  		return unexpectedMessageError(sessionTicketMsg, msg)
  1031  	}
  1032  	hs.finishedHash.Write(sessionTicketMsg.marshal())
  1033  
  1034  	hs.session = &ClientSessionState{
  1035  		sessionTicket:      sessionTicketMsg.ticket,
  1036  		vers:               c.vers,
  1037  		cipherSuite:        hs.suite.id,
  1038  		masterSecret:       hs.masterSecret,
  1039  		serverCertificates: c.peerCertificates,
  1040  		lifetimeHint:       sessionTicketMsg.lifetimeHint,
  1041  	}
  1042  
  1043  	return nil
  1044  }
  1045  
  1046  func (hs *clientHandshakeState) sendFinished() error {
  1047  	c := hs.c
  1048  
  1049  	c.writeRecord(recordTypeChangeCipherSpec, []byte{1})
  1050  	if hs.serverHello.nextProtoNeg {
  1051  		nextProto := new(nextProtoMsg)
  1052  		proto, fallback := mutualProtocol(c.config.NextProtos, hs.serverHello.nextProtos)
  1053  		nextProto.proto = proto
  1054  		c.clientProtocol = proto
  1055  		c.clientProtocolFallback = fallback
  1056  
  1057  		hs.finishedHash.Write(nextProto.marshal())
  1058  		c.writeRecord(recordTypeHandshake, nextProto.marshal())
  1059  	}
  1060  
  1061  	finished := new(finishedMsg)
  1062  	finished.verifyData = hs.finishedHash.clientSum(hs.masterSecret)
  1063  	hs.finishedHash.Write(finished.marshal())
  1064  
  1065  	c.handshakeLog.ClientFinished = finished.MakeLog()
  1066  
  1067  	c.writeRecord(recordTypeHandshake, finished.marshal())
  1068  	return nil
  1069  }
  1070  
  1071  func (hs *clientHandshakeState) writeClientHash(msg []byte) {
  1072  	// writeClientHash is called before writeRecord.
  1073  	hs.writeHash(msg, 0)
  1074  }
  1075  
  1076  func (hs *clientHandshakeState) writeServerHash(msg []byte) {
  1077  	// writeServerHash is called after readHandshake.
  1078  	hs.writeHash(msg, 0)
  1079  }
  1080  
  1081  func (hs *clientHandshakeState) writeHash(msg []byte, seqno uint16) {
  1082  	hs.finishedHash.Write(msg)
  1083  }
  1084  
  1085  // clientSessionCacheKey returns a key used to cache sessionTickets that could
  1086  // be used to resume previously negotiated TLS sessions with a server.
  1087  func clientSessionCacheKey(serverAddr net.Addr, config *Config) string {
  1088  	if len(config.ServerName) > 0 {
  1089  		return config.ServerName
  1090  	}
  1091  	return serverAddr.String()
  1092  }
  1093  
  1094  // mutualProtocol finds the mutual Next Protocol Negotiation or ALPN protocol
  1095  // given list of possible protocols and a list of the preference order. The
  1096  // first list must not be empty. It returns the resulting protocol and flag
  1097  // indicating if the fallback case was reached.
  1098  func mutualProtocol(protos, preferenceProtos []string) (string, bool) {
  1099  	for _, s := range preferenceProtos {
  1100  		for _, c := range protos {
  1101  			if s == c {
  1102  				return s, false
  1103  			}
  1104  		}
  1105  	}
  1106  
  1107  	return protos[0], true
  1108  }