github.com/3andne/restls-client-go@v0.1.6/u_handshake_client.go (about)

     1  // Copyright 2022 uTLS Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package tls
     6  
     7  import (
     8  	"bytes"
     9  	"compress/zlib"
    10  	"crypto/ecdh"
    11  	"errors"
    12  	"fmt"
    13  	"io"
    14  
    15  	"github.com/andybalholm/brotli"
    16  	"github.com/klauspost/compress/zstd"
    17  )
    18  
    19  // This function is called by (*clientHandshakeStateTLS13).readServerCertificate()
    20  // to retrieve the certificate out of a message read by (*Conn).readHandshake()
    21  func (hs *clientHandshakeStateTLS13) utlsReadServerCertificate(msg any) (processedMsg any, err error) {
    22  	for _, ext := range hs.uconn.Extensions {
    23  		switch ext.(type) {
    24  		case *UtlsCompressCertExtension:
    25  			// Included Compressed Certificate extension
    26  			if len(hs.uconn.certCompressionAlgs) > 0 {
    27  				compressedCertMsg, ok := msg.(*utlsCompressedCertificateMsg)
    28  				if ok {
    29  					if err = transcriptMsg(compressedCertMsg, hs.transcript); err != nil {
    30  						return nil, err
    31  					}
    32  					msg, err = hs.decompressCert(*compressedCertMsg)
    33  					if err != nil {
    34  						return nil, fmt.Errorf("tls: failed to decompress certificate message: %w", err)
    35  					} else {
    36  						return msg, nil
    37  					}
    38  				}
    39  			}
    40  		default:
    41  			continue
    42  		}
    43  	}
    44  	return nil, nil
    45  }
    46  
    47  // called by (*clientHandshakeStateTLS13).utlsReadServerCertificate() when UtlsCompressCertExtension is used
    48  func (hs *clientHandshakeStateTLS13) decompressCert(m utlsCompressedCertificateMsg) (*certificateMsgTLS13, error) {
    49  	var (
    50  		decompressed io.Reader
    51  		compressed   = bytes.NewReader(m.compressedCertificateMessage)
    52  		c            = hs.c
    53  	)
    54  
    55  	// Check to see if the peer responded with an algorithm we advertised.
    56  	supportedAlg := false
    57  	for _, alg := range hs.uconn.certCompressionAlgs {
    58  		if m.algorithm == uint16(alg) {
    59  			supportedAlg = true
    60  		}
    61  	}
    62  	if !supportedAlg {
    63  		c.sendAlert(alertBadCertificate)
    64  		return nil, fmt.Errorf("unadvertised algorithm (%d)", m.algorithm)
    65  	}
    66  
    67  	switch CertCompressionAlgo(m.algorithm) {
    68  	case CertCompressionBrotli:
    69  		decompressed = brotli.NewReader(compressed)
    70  
    71  	case CertCompressionZlib:
    72  		rc, err := zlib.NewReader(compressed)
    73  		if err != nil {
    74  			c.sendAlert(alertBadCertificate)
    75  			return nil, fmt.Errorf("failed to open zlib reader: %w", err)
    76  		}
    77  		defer rc.Close()
    78  		decompressed = rc
    79  
    80  	case CertCompressionZstd:
    81  		rc, err := zstd.NewReader(compressed)
    82  		if err != nil {
    83  			c.sendAlert(alertBadCertificate)
    84  			return nil, fmt.Errorf("failed to open zstd reader: %w", err)
    85  		}
    86  		defer rc.Close()
    87  		decompressed = rc
    88  
    89  	default:
    90  		c.sendAlert(alertBadCertificate)
    91  		return nil, fmt.Errorf("unsupported algorithm (%d)", m.algorithm)
    92  	}
    93  
    94  	rawMsg := make([]byte, m.uncompressedLength+4) // +4 for message type and uint24 length field
    95  	rawMsg[0] = typeCertificate
    96  	rawMsg[1] = uint8(m.uncompressedLength >> 16)
    97  	rawMsg[2] = uint8(m.uncompressedLength >> 8)
    98  	rawMsg[3] = uint8(m.uncompressedLength)
    99  
   100  	n, err := decompressed.Read(rawMsg[4:])
   101  	if err != nil && !errors.Is(err, io.EOF) {
   102  		c.sendAlert(alertBadCertificate)
   103  		return nil, err
   104  	}
   105  	if n < len(rawMsg)-4 {
   106  		// If, after decompression, the specified length does not match the actual length, the party
   107  		// receiving the invalid message MUST abort the connection with the "bad_certificate" alert.
   108  		// https://datatracker.ietf.org/doc/html/rfc8879#section-4
   109  		c.sendAlert(alertBadCertificate)
   110  		return nil, fmt.Errorf("decompressed len (%d) does not match specified len (%d)", n, m.uncompressedLength)
   111  	}
   112  	certMsg := new(certificateMsgTLS13)
   113  	if !certMsg.unmarshal(rawMsg) {
   114  		return nil, c.sendAlert(alertUnexpectedMessage)
   115  	}
   116  	return certMsg, nil
   117  }
   118  
   119  // to be called in (*clientHandshakeStateTLS13).handshake(),
   120  // after hs.readServerFinished() and before hs.sendClientCertificate()
   121  func (hs *clientHandshakeStateTLS13) serverFinishedReceived() error {
   122  	if err := hs.sendClientEncryptedExtensions(); err != nil {
   123  		return err
   124  	}
   125  	return nil
   126  }
   127  
   128  func (hs *clientHandshakeStateTLS13) sendClientEncryptedExtensions() error {
   129  	c := hs.c
   130  	clientEncryptedExtensions := new(utlsClientEncryptedExtensionsMsg)
   131  	if c.utls.hasApplicationSettings {
   132  		clientEncryptedExtensions.hasApplicationSettings = true
   133  		clientEncryptedExtensions.applicationSettings = c.utls.localApplicationSettings
   134  		if _, err := c.writeHandshakeRecord(clientEncryptedExtensions, hs.transcript); err != nil {
   135  			return err
   136  		}
   137  	}
   138  
   139  	return nil
   140  }
   141  
   142  func (hs *clientHandshakeStateTLS13) utlsReadServerParameters(encryptedExtensions *encryptedExtensionsMsg) error {
   143  	hs.c.utls.hasApplicationSettings = encryptedExtensions.utls.hasApplicationSettings
   144  	hs.c.utls.peerApplicationSettings = encryptedExtensions.utls.applicationSettings
   145  
   146  	if hs.c.utls.hasApplicationSettings {
   147  		if hs.uconn.vers < VersionTLS13 {
   148  			return errors.New("tls: server sent application settings at invalid version")
   149  		}
   150  		if len(hs.uconn.clientProtocol) == 0 {
   151  			return errors.New("tls: server sent application settings without ALPN")
   152  		}
   153  
   154  		// Check if the ALPN selected by the server exists in the client's list.
   155  		if alps, ok := hs.uconn.config.ApplicationSettings[hs.serverHello.alpnProtocol]; ok {
   156  			hs.c.utls.localApplicationSettings = alps
   157  		} else {
   158  			// return errors.New("tls: server selected ALPN doesn't match a client ALPS")
   159  			return nil // ignore if client doesn't have ALPS in use.
   160  			// TODO: is this a issue or not?
   161  		}
   162  	}
   163  
   164  	return nil
   165  }
   166  
   167  func (c *Conn) makeClientHelloForApplyPreset() (*clientHelloMsg, *ecdh.PrivateKey, error) {
   168  	config := c.config
   169  
   170  	// [UTLS SECTION START]
   171  	if len(config.ServerName) == 0 && !config.InsecureSkipVerify && len(config.InsecureServerNameToVerify) == 0 {
   172  		return nil, nil, errors.New("tls: at least one of ServerName, InsecureSkipVerify or InsecureServerNameToVerify must be specified in the tls.Config")
   173  	}
   174  	// [UTLS SECTION END]
   175  
   176  	nextProtosLength := 0
   177  	for _, proto := range config.NextProtos {
   178  		if l := len(proto); l == 0 || l > 255 {
   179  			return nil, nil, errors.New("tls: invalid NextProtos value")
   180  		} else {
   181  			nextProtosLength += 1 + l
   182  		}
   183  	}
   184  	if nextProtosLength > 0xffff {
   185  		return nil, nil, errors.New("tls: NextProtos values too large")
   186  	}
   187  
   188  	supportedVersions := config.supportedVersions(roleClient)
   189  	if len(supportedVersions) == 0 {
   190  		return nil, nil, errors.New("tls: no supported versions satisfy MinVersion and MaxVersion")
   191  	}
   192  
   193  	clientHelloVersion := config.maxSupportedVersion(roleClient)
   194  	// The version at the beginning of the ClientHello was capped at TLS 1.2
   195  	// for compatibility reasons. The supported_versions extension is used
   196  	// to negotiate versions now. See RFC 8446, Section 4.2.1.
   197  	if clientHelloVersion > VersionTLS12 {
   198  		clientHelloVersion = VersionTLS12
   199  	}
   200  
   201  	hello := &clientHelloMsg{
   202  		vers:                         clientHelloVersion,
   203  		compressionMethods:           []uint8{compressionNone},
   204  		random:                       make([]byte, 32),
   205  		extendedMasterSecret:         true,
   206  		ocspStapling:                 true,
   207  		scts:                         true,
   208  		serverName:                   hostnameInSNI(config.ServerName),
   209  		supportedCurves:              config.curvePreferences(),
   210  		supportedPoints:              []uint8{pointFormatUncompressed},
   211  		secureRenegotiationSupported: true,
   212  		alpnProtocols:                config.NextProtos,
   213  		supportedVersions:            supportedVersions,
   214  	}
   215  
   216  	if c.handshakes > 0 {
   217  		hello.secureRenegotiation = c.clientFinished[:]
   218  	}
   219  
   220  	preferenceOrder := cipherSuitesPreferenceOrder
   221  	if !hasAESGCMHardwareSupport {
   222  		preferenceOrder = cipherSuitesPreferenceOrderNoAES
   223  	}
   224  	configCipherSuites := config.cipherSuites()
   225  	hello.cipherSuites = make([]uint16, 0, len(configCipherSuites))
   226  
   227  	for _, suiteId := range preferenceOrder {
   228  		suite := mutualCipherSuite(configCipherSuites, suiteId)
   229  		if suite == nil {
   230  			continue
   231  		}
   232  		// Don't advertise TLS 1.2-only cipher suites unless
   233  		// we're attempting TLS 1.2.
   234  		if hello.vers < VersionTLS12 && suite.flags&suiteTLS12 != 0 {
   235  			continue
   236  		}
   237  		hello.cipherSuites = append(hello.cipherSuites, suiteId)
   238  	}
   239  
   240  	_, err := io.ReadFull(config.rand(), hello.random)
   241  	if err != nil {
   242  		return nil, nil, errors.New("tls: short read from Rand: " + err.Error())
   243  	}
   244  
   245  	// A random session ID is used to detect when the server accepted a ticket
   246  	// and is resuming a session (see RFC 5077). In TLS 1.3, it's always set as
   247  	// a compatibility measure (see RFC 8446, Section 4.1.2).
   248  	//
   249  	// The session ID is not set for QUIC connections (see RFC 9001, Section 8.4).
   250  	if c.quic == nil {
   251  		hello.sessionId = make([]byte, 32)
   252  		if _, err := io.ReadFull(config.rand(), hello.sessionId); err != nil {
   253  			return nil, nil, errors.New("tls: short read from Rand: " + err.Error())
   254  		}
   255  	}
   256  
   257  	if hello.vers >= VersionTLS12 {
   258  		hello.supportedSignatureAlgorithms = supportedSignatureAlgorithms()
   259  	}
   260  	if testingOnlyForceClientHelloSignatureAlgorithms != nil {
   261  		hello.supportedSignatureAlgorithms = testingOnlyForceClientHelloSignatureAlgorithms
   262  	}
   263  
   264  	var key *ecdh.PrivateKey
   265  	if hello.supportedVersions[0] == VersionTLS13 {
   266  		// Reset the list of ciphers when the client only supports TLS 1.3.
   267  		if len(hello.supportedVersions) == 1 {
   268  			hello.cipherSuites = nil
   269  		}
   270  		if hasAESGCMHardwareSupport {
   271  			hello.cipherSuites = append(hello.cipherSuites, defaultCipherSuitesTLS13...)
   272  		} else {
   273  			hello.cipherSuites = append(hello.cipherSuites, defaultCipherSuitesTLS13NoAES...)
   274  		}
   275  
   276  		curveID := config.curvePreferences()[0]
   277  		if _, ok := curveForCurveID(curveID); !ok {
   278  			return nil, nil, errors.New("tls: CurvePreferences includes unsupported curve")
   279  		}
   280  		key, err = generateECDHEKey(config.rand(), curveID)
   281  		if err != nil {
   282  			return nil, nil, err
   283  		}
   284  		hello.keyShares = []keyShare{{group: curveID, data: key.PublicKey().Bytes()}}
   285  	}
   286  
   287  	// [UTLS] We don't need this, since it is not ready yet
   288  	// if c.quic != nil {
   289  	// 	p, err := c.quicGetTransportParameters()
   290  	// 	if err != nil {
   291  	// 		return nil, nil, err
   292  	// 	}
   293  	// 	if p == nil {
   294  	// 		p = []byte{}
   295  	// 	}
   296  	// 	hello.quicTransportParameters = p
   297  	// }
   298  
   299  	return hello, key, nil
   300  }