gitlab.com/go-extension/tls@v0.0.0-20240304171319-e6745021905e/ech.go (about)

     1  // Copyright 2009 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package tls
     6  
     7  import (
     8  	"errors"
     9  	"fmt"
    10  	"io"
    11  
    12  	"github.com/cloudflare/circl/hpke"
    13  
    14  	"golang.org/x/crypto/cryptobyte"
    15  )
    16  
const (
	// Constants for TLS operations
	//
	// NOTE(review): judging by their values these are the labels used when
	// deriving the ECH acceptance confirmation, the second one being the
	// HelloRetryRequest variant. Their use sites are outside this section;
	// confirm against the callers.
	echAcceptConfLabel    = "ech accept confirmation"
	echAcceptConfHRRLabel = "hrr ech accept confirmation"

	// Constants for HPKE operations
	//
	// NOTE(review): presumably the "info" string prefix used when setting up
	// the HPKE context; not referenced in this section, so confirm.
	echHpkeInfoSetup = "tls ech"

	// When sent in the ClientHello, the first byte of the payload of the ECH
	// extension indicates whether the message is the ClientHelloOuter or
	// ClientHelloInner.
	echClientHelloOuterVariant uint8 = 0
	echClientHelloInnerVariant uint8 = 1
)

var (
	// zeros is an all-zero 8-byte array.
	//
	// NOTE(review): presumably the placeholder written in place of the
	// 8-byte acceptance confirmation while it is being computed; confirm
	// against the use sites.
	zeros = [8]byte{}
)
    35  
    36  // echOfferOrGrease is called by the client after generating its ClientHello
    37  // message to decide if it will offer or GREASE ECH. It does neither if ECH is
    38  // disabled. Returns a pair of ClientHello messages, hello and helloInner. If
    39  // offering ECH, these are the ClienthelloOuter and ClientHelloInner
    40  // respectively. Otherwise, hello is the ClientHello and helloInner == nil.
    41  //
    42  // TODO(cjpatton): "[When offering ECH, the client] MUST NOT offer to resume any
    43  // session for TLS 1.2 and below [in ClientHelloInner]."
    44  func (c *Conn) echOfferOrGrease(helloBase *clientHelloMsg) (hello, helloInner *clientHelloMsg, err error) {
    45  	config := c.config
    46  
    47  	if !config.ECHEnabled {
    48  		// Bypass ECH.
    49  		return helloBase, nil, nil
    50  	}
    51  
    52  	// Choose the ECHConfig to use for this connection. If none is available, or
    53  	// if we're not offering TLS 1.3 or above, then GREASE.
    54  	echConfig := config.echSelectConfig()
    55  	if echConfig == nil || config.maxSupportedVersion(roleClient) < VersionTLS13 {
    56  		var err error
    57  
    58  		// Generate a dummy ClientECH.
    59  		helloBase.ech, err = config.echGrease()
    60  		if err != nil {
    61  			return nil, nil, fmt.Errorf("tls: ech: failed to generate grease ECH: %s", err)
    62  		}
    63  
    64  		// GREASE ECH.
    65  		c.ech.offered = false
    66  		c.ech.greased = true
    67  		helloBase.raw = nil
    68  		return helloBase, nil, nil
    69  	}
    70  
    71  	// Store the ECH config parameters that are needed later.
    72  	c.ech.configId = echConfig.configId
    73  	c.ech.maxNameLen = int(echConfig.maxNameLen)
    74  
    75  	// Generate the HPKE context. Store it in case of HRR.
    76  	var enc []byte
    77  	enc, c.ech.sealer, err = echConfig.setupSealer(config.rand())
    78  	if err != nil {
    79  		return nil, nil, fmt.Errorf("tls: ech: %s", err)
    80  	}
    81  
    82  	// ClientHelloInner is constructed from the base ClientHello. The payload of
    83  	// the "encrypted_client_hello" extension is a single 1 byte indicating that
    84  	// this is the ClientHelloInner.
    85  	helloInner = helloBase
    86  	helloInner.ech = []byte{echClientHelloInnerVariant}
    87  
    88  	// Ensure that only TLS 1.3 and above are offered in the inner handshake.
    89  	if v := helloInner.supportedVersions; len(v) == 0 || v[len(v)-1] < VersionTLS13 {
    90  		return nil, nil, errors.New("tls: ech: only TLS 1.3 is allowed in ClientHelloInner")
    91  	}
    92  
    93  	// ClientHelloOuter is constructed by generating a fresh ClientHello and
    94  	// copying "session_id" from ClientHelloInner, setting "server_name" to the
    95  	// client-facing server, and adding the "encrypted_client_hello" extension.
    96  	//
    97  	// In addition, we discard the "key_share" and instead use the one from
    98  	// ClientHelloInner.
    99  	hello, _, err = c.makeClientHello()
   100  	if err != nil {
   101  		return nil, nil, fmt.Errorf("tls: ech: %s", err)
   102  	}
   103  	hello.sessionId = helloBase.sessionId
   104  	hello.serverName = hostnameInSNI(string(echConfig.rawPublicName))
   105  	if err := c.echUpdateClientHelloOuter(hello, helloInner, enc); err != nil {
   106  		return nil, nil, err
   107  	}
   108  
   109  	// Offer ECH.
   110  	c.ech.offered = true
   111  	helloInner.raw = nil
   112  	hello.raw = nil
   113  	return hello, helloInner, nil
   114  }
   115  
   116  // Generates a grease ECH extension using a hard-coded KEM public key.
   117  func (config *Config) echGrease() ([]byte, error) {
   118  	suite := config.ClientECHDummyConfig.cipherSuite()
   119  	kem, kdf, aead := suite.Params()
   120  
   121  	var dummyPublicKey []byte
   122  	switch kem {
   123  	case hpke.KEM_X25519_HKDF_SHA256:
   124  		dummyPublicKey = dummyX25519PublicKey
   125  	default:
   126  		return nil, fmt.Errorf("tls: grease ech: invalid kem %#x", kem)
   127  	}
   128  
   129  	pk, err := kem.Scheme().UnmarshalBinaryPublicKey(dummyPublicKey)
   130  	if err != nil {
   131  		return nil, fmt.Errorf("tls: grease ech: failed to parse dummy public key: %s", err)
   132  	}
   133  
   134  	sender, err := suite.NewSender(pk, nil)
   135  	if err != nil {
   136  		return nil, fmt.Errorf("tls: grease ech: failed to create sender: %s", err)
   137  	}
   138  
   139  	var ech echClientOuter
   140  	ech.handle.suite.KDF = uint16(kdf)
   141  	ech.handle.suite.AEAD = uint16(aead)
   142  	ech.handle.configId = config.ClientECHDummyConfig.configId()
   143  	ech.handle.enc, _, err = sender.Setup(config.rand())
   144  	if err != nil {
   145  		return nil, fmt.Errorf("tls: grease ech: %s", err)
   146  	}
   147  
   148  	ech.payload = make([]byte, aead.CipherLen(uint(config.ClientECHDummyConfig.helloInnerLen())))
   149  	if _, err = io.ReadFull(config.rand(), ech.payload); err != nil {
   150  		return nil, fmt.Errorf("tls: grease ech: %s", err)
   151  	}
   152  
   153  	return ech.marshal(), nil
   154  }
   155  
   156  // echUpdateClientHelloOuter is called by the client to construct the payload of
   157  // the ECH extension in the outer handshake.
   158  func (c *Conn) echUpdateClientHelloOuter(hello, helloInner *clientHelloMsg, enc []byte) error {
   159  	var (
   160  		ech echClientOuter
   161  		err error
   162  	)
   163  
   164  	// Copy all compressed extensions from ClientHelloInner into
   165  	// ClientHelloOuter.
   166  	for _, ext := range echOuterExtensions() {
   167  		echCopyExtensionFromClientHelloInner(hello, helloInner, ext)
   168  	}
   169  
   170  	// Always copy the "key_shares" extension from ClientHelloInner, regardless
   171  	// of whether it gets compressed.
   172  	hello.keyShares = helloInner.keyShares
   173  
   174  	_, kdf, aead := c.ech.sealer.Suite().Params()
   175  	ech.handle.suite.KDF = uint16(kdf)
   176  	ech.handle.suite.AEAD = uint16(aead)
   177  	ech.handle.configId = c.ech.configId
   178  	ech.handle.enc = enc
   179  
   180  	// EncodedClientHelloInner
   181  	helloInner.raw = nil
   182  	helloInnerMarshalled, err := helloInner.marshal()
   183  	if err != nil {
   184  		return fmt.Errorf("tls: ech: failed to marshal helloInner: %w", err)
   185  	}
   186  	encodedHelloInner := echEncodeClientHelloInner(
   187  		helloInnerMarshalled,
   188  		len(helloInner.serverName),
   189  		c.ech.maxNameLen)
   190  	if encodedHelloInner == nil {
   191  		return errors.New("tls: ech: encoding of EncodedClientHelloInner failed")
   192  	}
   193  
   194  	// ClientHelloOuterAAD
   195  	hello.raw = nil
   196  	hello.ech = ech.marshal()
   197  	helloMarshalled, err := hello.marshal()
   198  	if err != nil {
   199  		return fmt.Errorf("tls: ech: failed to marshal hello: %w", err)
   200  	}
   201  	helloOuterAad := echEncodeClientHelloOuterAAD(helloMarshalled,
   202  		aead.CipherLen(uint(len(encodedHelloInner))))
   203  	if helloOuterAad == nil {
   204  		return errors.New("tls: ech: encoding of ClientHelloOuterAAD failed")
   205  	}
   206  
   207  	ech.payload, err = c.ech.sealer.Seal(encodedHelloInner, helloOuterAad)
   208  	if err != nil {
   209  		return fmt.Errorf("tls: ech: seal failed: %s", err)
   210  	}
   211  	ech.raw = nil
   212  	hello.ech = ech.marshal()
   213  
   214  	helloInner.raw = nil
   215  	hello.raw = nil
   216  	return nil
   217  }
   218  
   219  // echAcceptOrReject is called by the client-facing server to determine whether
   220  // ECH was offered by the client, and if so, whether to accept or reject. The
   221  // return value is the ClientHello that will be used for the connection.
   222  //
   223  // This function is called prior to processing the ClientHello. In case of
   224  // HelloRetryRequest, it is also called before processing the second
   225  // ClientHello. This is indicated by the afterHRR flag.
   226  func (c *Conn) echAcceptOrReject(hello *clientHelloMsg, afterHRR bool) (*clientHelloMsg, error) {
   227  	config := c.config
   228  	p := config.ServerECHProvider
   229  
   230  	if !config.echCanAccept() {
   231  		// Bypass ECH.
   232  		return hello, nil
   233  	}
   234  
   235  	if len(hello.ech) > 0 { // The ECH extension is present
   236  		switch hello.ech[0] {
   237  		case echClientHelloInnerVariant: // inner handshake
   238  			if len(hello.ech) > 1 {
   239  				c.sendAlert(alertIllegalParameter)
   240  				return nil, errors.New("ech: inner handshake has non-empty payload")
   241  			}
   242  
   243  			// Continue as the backend server.
   244  			return hello, nil
   245  		case echClientHelloOuterVariant: // outer handshake
   246  		default:
   247  			c.sendAlert(alertIllegalParameter)
   248  			return nil, errors.New("ech: inner handshake has non-empty payload")
   249  		}
   250  	} else {
   251  		if c.ech.offered {
   252  			// This occurs if the server accepted prior to HRR, but the client
   253  			// failed to send the ECH extension in the second ClientHelloOuter. This
   254  			// would cause ClientHelloOuter to be used after ClientHelloInner, which
   255  			// is illegal.
   256  			c.sendAlert(alertMissingExtension)
   257  			return nil, errors.New("ech: hrr: bypass after offer")
   258  		}
   259  
   260  		// Bypass ECH.
   261  		return hello, nil
   262  	}
   263  
   264  	if afterHRR && !c.ech.offered && !c.ech.greased {
   265  		// The client bypassed ECH prior to HRR, but not after. This could
   266  		// cause ClientHelloInner to be used after ClientHelloOuter, which is
   267  		// illegal.
   268  		c.sendAlert(alertIllegalParameter)
   269  		return nil, errors.New("ech: hrr: offer or grease after bypass")
   270  	}
   271  
   272  	// Parse ClientECH.
   273  	ech, err := echUnmarshalClientOuter(hello.ech)
   274  	if err != nil {
   275  		c.sendAlert(alertIllegalParameter)
   276  		return nil, fmt.Errorf("ech: failed to parse extension: %s", err)
   277  	}
   278  
   279  	// Make sure that the HPKE suite and config id don't change across HRR and
   280  	// that the encapsulated key is not present after HRR.
   281  	if afterHRR && c.ech.offered {
   282  		_, kdf, aead := c.ech.opener.Suite().Params()
   283  		if ech.handle.suite.KDF != uint16(kdf) ||
   284  			ech.handle.suite.AEAD != uint16(aead) ||
   285  			ech.handle.configId != c.ech.configId ||
   286  			len(ech.handle.enc) > 0 {
   287  			c.sendAlert(alertIllegalParameter)
   288  			return nil, errors.New("ech: hrr: illegal handle in second hello")
   289  		}
   290  	}
   291  
   292  	// Store the config id in case of HRR.
   293  	c.ech.configId = ech.handle.configId
   294  
   295  	// Ask the ECH provider for the HPKE context.
   296  	if c.ech.opener == nil {
   297  		res := p.GetDecryptionContext(ech.handle.marshal(), extensionECH)
   298  
   299  		// Compute retry configurations, skipping those indicating an
   300  		// unsupported version.
   301  		if len(res.RetryConfigs) > 0 {
   302  			configs, err := UnmarshalECHConfigs(res.RetryConfigs) // skips unrecognized versions
   303  			if err != nil {
   304  				c.sendAlert(alertInternalError)
   305  				return nil, fmt.Errorf("ech: %s", err)
   306  			}
   307  
   308  			if len(configs) > 0 {
   309  				c.ech.retryConfigs, err = echMarshalConfigs(configs)
   310  				if err != nil {
   311  					c.sendAlert(alertInternalError)
   312  					return nil, fmt.Errorf("ech: %s", err)
   313  				}
   314  			}
   315  		}
   316  
   317  		switch res.Status {
   318  		case ECHProviderSuccess:
   319  			c.ech.opener, err = hpke.UnmarshalOpener(res.Context)
   320  			if err != nil {
   321  				c.sendAlert(alertInternalError)
   322  				return nil, fmt.Errorf("ech: %s", err)
   323  			}
   324  		case ECHProviderReject:
   325  			// Reject ECH. We do not know at this point whether the client
   326  			// intended to offer or grease ECH, so we presume grease until the
   327  			// client indicates rejection by sending an "ech_required" alert.
   328  			c.ech.greased = true
   329  			return hello, nil
   330  		case ECHProviderAbort:
   331  			c.sendAlert(alert(res.Alert))
   332  			return nil, fmt.Errorf("ech: provider aborted: %s", res.Error)
   333  		default:
   334  			c.sendAlert(alertInternalError)
   335  			return nil, errors.New("ech: unexpected provider status")
   336  		}
   337  	}
   338  
   339  	// ClientHelloOuterAAD
   340  	helloMarshalled, err := hello.marshal()
   341  	if err != nil {
   342  		return nil, fmt.Errorf("tls: ech: failed to marshal hello: %w", err)
   343  	}
   344  	rawHelloOuterAad := echEncodeClientHelloOuterAAD(helloMarshalled, uint(len(ech.payload)))
   345  	if rawHelloOuterAad == nil {
   346  		// This occurs if the ClientHelloOuter is malformed. This values was
   347  		// already parsed into `hello`, so this should not happen.
   348  		c.sendAlert(alertInternalError)
   349  		return nil, fmt.Errorf("ech: failed to encode ClientHelloOuterAAD")
   350  	}
   351  
   352  	// EncodedClientHelloInner
   353  	rawEncodedHelloInner, err := c.ech.opener.Open(ech.payload, rawHelloOuterAad)
   354  	if err != nil {
   355  		if afterHRR && c.ech.accepted {
   356  			// Don't reject after accept, as this would result in processing the
   357  			// ClientHelloOuter after processing the ClientHelloInner.
   358  			c.sendAlert(alertDecryptError)
   359  			return nil, fmt.Errorf("ech: hrr: reject after accept: %s", err)
   360  		}
   361  
   362  		// Reject ECH. We do not know at this point whether the client
   363  		// intended to offer or grease ECH, so we presume grease until the
   364  		// client indicates rejection by sending an "ech_required" alert.
   365  		c.ech.greased = true
   366  		return hello, nil
   367  	}
   368  
   369  	// ClientHelloInner
   370  	rawHelloInner := echDecodeClientHelloInner(rawEncodedHelloInner, helloMarshalled, hello.sessionId)
   371  	if rawHelloInner == nil {
   372  		c.sendAlert(alertIllegalParameter)
   373  		return nil, fmt.Errorf("ech: failed to decode EncodedClientHelloInner")
   374  	}
   375  	helloInner := new(clientHelloMsg)
   376  	if !helloInner.unmarshal(rawHelloInner) {
   377  		c.sendAlert(alertIllegalParameter)
   378  		return nil, fmt.Errorf("ech: failed to parse ClientHelloInner")
   379  	}
   380  
   381  	// Check for a well-formed ECH extension.
   382  	if len(helloInner.ech) != 1 ||
   383  		helloInner.ech[0] != echClientHelloInnerVariant {
   384  		c.sendAlert(alertIllegalParameter)
   385  		return nil, fmt.Errorf("ech: ClientHelloInner does not have a well-formed ECH extension")
   386  	}
   387  
   388  	// Check that the client did not offer TLS 1.2 or below in the inner
   389  	// handshake.
   390  	helloInnerSupportsTLS12OrBelow := len(helloInner.supportedVersions) == 0
   391  	for _, v := range helloInner.supportedVersions {
   392  		if v < VersionTLS13 {
   393  			helloInnerSupportsTLS12OrBelow = true
   394  		}
   395  	}
   396  	if helloInnerSupportsTLS12OrBelow {
   397  		c.sendAlert(alertIllegalParameter)
   398  		return nil, errors.New("ech: ClientHelloInner offers TLS 1.2 or below")
   399  	}
   400  
   401  	// Accept ECH.
   402  	c.ech.offered = true
   403  	c.ech.accepted = true
   404  	return helloInner, nil
   405  }
   406  
// echClientOuter represents a ClientECH structure, the payload of the client's
// "encrypted_client_hello" extension that appears in the outer handshake.
type echClientOuter struct {
	// raw is the encoded structure. When non-nil, marshal returns it verbatim
	// instead of re-encoding the fields below, so it must be reset to nil
	// after any of them is modified.
	raw []byte

	// Parsed from raw
	//
	// handle is the context handle (cipher suite, config id and encapsulated
	// key) and payload is the content of the length-prefixed payload field.
	handle  echContextHandle
	payload []byte
}
   416  
   417  // echUnmarshalClientOuter parses a ClientECH structure. The caller provides the
   418  // ECH version indicated by the client.
   419  func echUnmarshalClientOuter(raw []byte) (*echClientOuter, error) {
   420  	s := cryptobyte.String(raw)
   421  	ech := new(echClientOuter)
   422  	ech.raw = raw
   423  
   424  	// Make sure this is the outer handshake.
   425  	var variant uint8
   426  	if !s.ReadUint8(&variant) {
   427  		return nil, fmt.Errorf("error parsing ClientECH.type")
   428  	}
   429  	if variant != echClientHelloOuterVariant {
   430  		return nil, fmt.Errorf("unexpected ClientECH.type (want outer (0))")
   431  	}
   432  
   433  	// Parse the context handle.
   434  	if !echReadContextHandle(&s, &ech.handle) {
   435  		return nil, fmt.Errorf("error parsing context handle")
   436  	}
   437  	endOfContextHandle := len(raw) - len(s)
   438  	ech.handle.raw = raw[1:endOfContextHandle]
   439  
   440  	// Parse the payload.
   441  	var t cryptobyte.String
   442  	if !s.ReadUint16LengthPrefixed(&t) ||
   443  		!t.ReadBytes(&ech.payload, len(t)) || !s.Empty() {
   444  		return nil, fmt.Errorf("error parsing payload")
   445  	}
   446  
   447  	return ech, nil
   448  }
   449  
   450  func (ech *echClientOuter) marshal() []byte {
   451  	if ech.raw != nil {
   452  		return ech.raw
   453  	}
   454  	var b cryptobyte.Builder
   455  	b.AddUint8(echClientHelloOuterVariant)
   456  	b.AddBytes(ech.handle.marshal())
   457  	b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   458  		b.AddBytes(ech.payload)
   459  	})
   460  	return b.BytesOrPanic()
   461  }
   462  
// echContextHandle represents the prefix of a ClientECH structure used by
// the server to compute the HPKE context.
type echContextHandle struct {
	// raw is the encoded handle. When non-nil, marshal returns it verbatim
	// instead of re-encoding the fields below.
	raw []byte

	// Parsed from raw
	//
	// suite holds the KDF and AEAD identifiers, configId the one-byte config
	// identifier, and enc the content of the length-prefixed "enc" field.
	suite    hpkeSymmetricCipherSuite
	configId uint8
	enc      []byte
}
   473  
   474  func (handle *echContextHandle) marshal() []byte {
   475  	if handle.raw != nil {
   476  		return handle.raw
   477  	}
   478  	var b cryptobyte.Builder
   479  	b.AddUint16(handle.suite.KDF)
   480  	b.AddUint16(handle.suite.AEAD)
   481  	b.AddUint8(handle.configId)
   482  	b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   483  		b.AddBytes(handle.enc)
   484  	})
   485  	return b.BytesOrPanic()
   486  }
   487  
   488  func echReadContextHandle(s *cryptobyte.String, handle *echContextHandle) bool {
   489  	var t cryptobyte.String
   490  	if !s.ReadUint16(&handle.suite.KDF) || // cipher_suite.kdf_id
   491  		!s.ReadUint16(&handle.suite.AEAD) || // cipher_suite.aead_id
   492  		!s.ReadUint8(&handle.configId) || // config_id
   493  		!s.ReadUint16LengthPrefixed(&t) || // enc
   494  		!t.ReadBytes(&handle.enc, len(t)) {
   495  		return false
   496  	}
   497  	return true
   498  }
   499  
   500  // echEncodeClientHelloInner interprets innerData as a ClientHelloInner message
   501  // and transforms it into an EncodedClientHelloInner. Returns nil if parsing
   502  // innerData fails.
   503  func echEncodeClientHelloInner(innerData []byte, serverNameLen, maxNameLen int) []byte {
   504  	var (
   505  		errIllegalParameter      = errors.New("illegal parameter")
   506  		outerExtensions          = echOuterExtensions()
   507  		msgType                  uint8
   508  		legacyVersion            uint16
   509  		random                   []byte
   510  		legacySessionId          cryptobyte.String
   511  		cipherSuites             cryptobyte.String
   512  		legacyCompressionMethods cryptobyte.String
   513  		extensions               cryptobyte.String
   514  		s                        cryptobyte.String
   515  		b                        cryptobyte.Builder
   516  	)
   517  
   518  	u := cryptobyte.String(innerData)
   519  	if !u.ReadUint8(&msgType) ||
   520  		!u.ReadUint24LengthPrefixed(&s) || !u.Empty() {
   521  		return nil
   522  	}
   523  
   524  	if !s.ReadUint16(&legacyVersion) ||
   525  		!s.ReadBytes(&random, 32) ||
   526  		!s.ReadUint8LengthPrefixed(&legacySessionId) ||
   527  		!s.ReadUint16LengthPrefixed(&cipherSuites) ||
   528  		!s.ReadUint8LengthPrefixed(&legacyCompressionMethods) {
   529  		return nil
   530  	}
   531  
   532  	if s.Empty() {
   533  		// Extensions field must be present in TLS 1.3.
   534  		return nil
   535  	}
   536  
   537  	if !s.ReadUint16LengthPrefixed(&extensions) || !s.Empty() {
   538  		return nil
   539  	}
   540  
   541  	b.AddUint16(legacyVersion)
   542  	b.AddBytes(random)
   543  	b.AddUint8(0) // 0-length legacy_session_id
   544  	b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   545  		b.AddBytes(cipherSuites)
   546  	})
   547  	b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
   548  		b.AddBytes(legacyCompressionMethods)
   549  	})
   550  	b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   551  		for !extensions.Empty() {
   552  			var ext uint16
   553  			var extData cryptobyte.String
   554  			if !extensions.ReadUint16(&ext) ||
   555  				!extensions.ReadUint16LengthPrefixed(&extData) {
   556  				panic(cryptobyte.BuildError{Err: errIllegalParameter})
   557  			}
   558  
   559  			if len(outerExtensions) > 0 && ext == outerExtensions[0] {
   560  				// Replace outer extensions with "outer_extension" extension.
   561  				echAddOuterExtensions(b, outerExtensions)
   562  
   563  				// Consume the remaining outer extensions.
   564  				for _, outerExt := range outerExtensions[1:] {
   565  					if !extensions.ReadUint16(&ext) ||
   566  						!extensions.ReadUint16LengthPrefixed(&extData) {
   567  						panic(cryptobyte.BuildError{Err: errIllegalParameter})
   568  					}
   569  					if ext != outerExt {
   570  						panic("internal error: malformed ClientHelloInner")
   571  					}
   572  				}
   573  
   574  			} else {
   575  				b.AddUint16(ext)
   576  				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   577  					b.AddBytes(extData)
   578  				})
   579  			}
   580  		}
   581  	})
   582  
   583  	encodedData, err := b.Bytes()
   584  	if err == errIllegalParameter {
   585  		return nil // Input malformed
   586  	} else if err != nil {
   587  		panic(err) // Host encountered internal error
   588  	}
   589  
   590  	// Add padding.
   591  	paddingLen := 0
   592  	if serverNameLen > 0 {
   593  		// draft-ietf-tls-esni-13, Section 6.1.3:
   594  		//
   595  		// If the ClientHelloInner contained a "server_name" extension with a
   596  		// name of length D, add max(0, L - D) bytes of padding.
   597  		if n := maxNameLen - serverNameLen; n > 0 {
   598  			paddingLen += n
   599  		}
   600  	} else {
   601  		// draft-ietf-tls-esni-13, Section 6.1.3:
   602  		//
   603  		// If the ClientHelloInner did not contain a "server_name" extension
   604  		// (e.g., if the client is connecting to an IP address), add L + 9 bytes
   605  		// of padding.  This is the length of a "server_name" extension with an
   606  		// L-byte name.
   607  		const sniPaddingLen = 9
   608  		paddingLen += sniPaddingLen + maxNameLen
   609  	}
   610  	paddingLen = 31 - ((len(encodedData) + paddingLen - 1) % 32)
   611  	for i := 0; i < paddingLen; i++ {
   612  		encodedData = append(encodedData, 0)
   613  	}
   614  
   615  	return encodedData
   616  }
   617  
   618  func echAddOuterExtensions(b *cryptobyte.Builder, outerExtensions []uint16) {
   619  	b.AddUint16(extensionECHOuterExtensions)
   620  	b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   621  		b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
   622  			for _, outerExt := range outerExtensions {
   623  				b.AddUint16(outerExt)
   624  			}
   625  		})
   626  	})
   627  }
   628  
   629  // echDecodeClientHelloInner interprets encodedData as an EncodedClientHelloInner
   630  // message and substitutes the "outer_extension" extension with extensions from
   631  // outerData, interpreted as the ClientHelloOuter message. Returns nil if
   632  // parsing encodedData fails.
   633  func echDecodeClientHelloInner(encodedData, outerData, outerSessionId []byte) []byte {
   634  	var (
   635  		errIllegalParameter      = errors.New("illegal parameter")
   636  		legacyVersion            uint16
   637  		random                   []byte
   638  		legacySessionId          cryptobyte.String
   639  		cipherSuites             cryptobyte.String
   640  		legacyCompressionMethods cryptobyte.String
   641  		extensions               cryptobyte.String
   642  		b                        cryptobyte.Builder
   643  	)
   644  
   645  	s := cryptobyte.String(encodedData)
   646  	if !s.ReadUint16(&legacyVersion) ||
   647  		!s.ReadBytes(&random, 32) ||
   648  		!s.ReadUint8LengthPrefixed(&legacySessionId) ||
   649  		!s.ReadUint16LengthPrefixed(&cipherSuites) ||
   650  		!s.ReadUint8LengthPrefixed(&legacyCompressionMethods) {
   651  		return nil
   652  	}
   653  
   654  	if len(legacySessionId) > 0 {
   655  		return nil
   656  	}
   657  
   658  	if s.Empty() {
   659  		// Extensions field must be present in TLS 1.3.
   660  		return nil
   661  	}
   662  
   663  	if !s.ReadUint16LengthPrefixed(&extensions) {
   664  		return nil
   665  	}
   666  
   667  	b.AddUint8(typeClientHello)
   668  	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
   669  		b.AddUint16(legacyVersion)
   670  		b.AddBytes(random)
   671  		b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
   672  			b.AddBytes(outerSessionId) // ClientHelloOuter.legacy_session_id
   673  		})
   674  		b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   675  			b.AddBytes(cipherSuites)
   676  		})
   677  		b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
   678  			b.AddBytes(legacyCompressionMethods)
   679  		})
   680  		b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   681  			var handledOuterExtensions bool
   682  			for !extensions.Empty() {
   683  				var ext uint16
   684  				var extData cryptobyte.String
   685  				if !extensions.ReadUint16(&ext) ||
   686  					!extensions.ReadUint16LengthPrefixed(&extData) {
   687  					panic(cryptobyte.BuildError{Err: errIllegalParameter})
   688  				}
   689  
   690  				if ext == extensionECHOuterExtensions {
   691  					if handledOuterExtensions {
   692  						// It is an error to send any extension more than once in a
   693  						// single message.
   694  						panic(cryptobyte.BuildError{Err: errIllegalParameter})
   695  					}
   696  					handledOuterExtensions = true
   697  
   698  					// Read the referenced outer extensions.
   699  					referencedExts := make([]uint16, 0, 10)
   700  					var outerExtData cryptobyte.String
   701  					if !extData.ReadUint8LengthPrefixed(&outerExtData) ||
   702  						len(outerExtData)%2 != 0 ||
   703  						!extData.Empty() {
   704  						panic(cryptobyte.BuildError{Err: errIllegalParameter})
   705  					}
   706  					for !outerExtData.Empty() {
   707  						if !outerExtData.ReadUint16(&ext) ||
   708  							ext == extensionECH {
   709  							panic(cryptobyte.BuildError{Err: errIllegalParameter})
   710  						}
   711  						referencedExts = append(referencedExts, ext)
   712  					}
   713  
   714  					// Add the outer extensions from the ClientHelloOuter into the
   715  					// ClientHelloInner.
   716  					outerCt := 0
   717  					r := processClientHelloExtensions(outerData, func(ext uint16, extData cryptobyte.String) bool {
   718  						if outerCt < len(referencedExts) && ext == referencedExts[outerCt] {
   719  							outerCt++
   720  							b.AddUint16(ext)
   721  							b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   722  								b.AddBytes(extData)
   723  							})
   724  						}
   725  						return true
   726  					})
   727  
   728  					// Ensure that all outer extensions have been incorporated
   729  					// exactly once, and in the correct order.
   730  					if !r || outerCt != len(referencedExts) {
   731  						panic(cryptobyte.BuildError{Err: errIllegalParameter})
   732  					}
   733  				} else {
   734  					b.AddUint16(ext)
   735  					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   736  						b.AddBytes(extData)
   737  					})
   738  				}
   739  			}
   740  		})
   741  	})
   742  
   743  	innerData, err := b.Bytes()
   744  	if err == errIllegalParameter {
   745  		return nil // Input malformed
   746  	} else if err != nil {
   747  		panic(err) // Host encountered internal error
   748  	}
   749  
   750  	// Read the padding.
   751  	for !s.Empty() {
   752  		var zero uint8
   753  		if !s.ReadUint8(&zero) || zero != 0 {
   754  			return nil
   755  		}
   756  	}
   757  
   758  	return innerData
   759  }
   760  
   761  // echEncodeClientHelloOuterAAD interprets outerData as ClientHelloOuter and
   762  // constructs a ClientHelloOuterAAD. The output doesn't have the 4-byte prefix
   763  // that indicates the handshake message type and its length.
   764  func echEncodeClientHelloOuterAAD(outerData []byte, payloadLen uint) []byte {
   765  	var (
   766  		errIllegalParameter      = errors.New("illegal parameter")
   767  		msgType                  uint8
   768  		legacyVersion            uint16
   769  		random                   []byte
   770  		legacySessionId          cryptobyte.String
   771  		cipherSuites             cryptobyte.String
   772  		legacyCompressionMethods cryptobyte.String
   773  		extensions               cryptobyte.String
   774  		s                        cryptobyte.String
   775  		b                        cryptobyte.Builder
   776  	)
   777  
   778  	u := cryptobyte.String(outerData)
   779  	if !u.ReadUint8(&msgType) ||
   780  		!u.ReadUint24LengthPrefixed(&s) || !u.Empty() {
   781  		return nil
   782  	}
   783  
   784  	if !s.ReadUint16(&legacyVersion) ||
   785  		!s.ReadBytes(&random, 32) ||
   786  		!s.ReadUint8LengthPrefixed(&legacySessionId) ||
   787  		!s.ReadUint16LengthPrefixed(&cipherSuites) ||
   788  		!s.ReadUint8LengthPrefixed(&legacyCompressionMethods) {
   789  		return nil
   790  	}
   791  
   792  	if s.Empty() {
   793  		// Extensions field must be present in TLS 1.3.
   794  		return nil
   795  	}
   796  
   797  	if !s.ReadUint16LengthPrefixed(&extensions) || !s.Empty() {
   798  		return nil
   799  	}
   800  
   801  	b.AddUint16(legacyVersion)
   802  	b.AddBytes(random)
   803  	b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
   804  		b.AddBytes(legacySessionId)
   805  	})
   806  	b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   807  		b.AddBytes(cipherSuites)
   808  	})
   809  	b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
   810  		b.AddBytes(legacyCompressionMethods)
   811  	})
   812  	b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   813  		for !extensions.Empty() {
   814  			var ext uint16
   815  			var extData cryptobyte.String
   816  			if !extensions.ReadUint16(&ext) ||
   817  				!extensions.ReadUint16LengthPrefixed(&extData) {
   818  				panic(cryptobyte.BuildError{Err: errIllegalParameter})
   819  			}
   820  
   821  			// If this is the ECH extension and the payload is the outer variant
   822  			// of ClientECH, then replace the payloadLen 0 bytes.
   823  			if ext == extensionECH {
   824  				ech, err := echUnmarshalClientOuter(extData)
   825  				if err != nil {
   826  					panic(cryptobyte.BuildError{Err: errIllegalParameter})
   827  				}
   828  				ech.payload = make([]byte, payloadLen)
   829  				ech.raw = nil
   830  				extData = ech.marshal()
   831  			}
   832  
   833  			b.AddUint16(ext)
   834  			b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   835  				b.AddBytes(extData)
   836  			})
   837  		}
   838  	})
   839  
   840  	outerAadData, err := b.Bytes()
   841  	if err == errIllegalParameter {
   842  		return nil // Input malformed
   843  	} else if err != nil {
   844  		panic(err) // Host encountered internal error
   845  	}
   846  
   847  	return outerAadData
   848  }
   849  
   850  // echEncodeAcceptConfHelloRetryRequest interprets data as a ServerHello message
   851  // and replaces the payload of the ECH extension with 8 zero bytes. The output
   852  // includes the 4-byte prefix that indicates the message type and its length.
   853  func echEncodeAcceptConfHelloRetryRequest(data []byte) []byte {
   854  	var (
   855  		errIllegalParameter = errors.New("illegal parameter")
   856  		vers                uint16
   857  		random              []byte
   858  		sessionId           []byte
   859  		cipherSuite         uint16
   860  		compressionMethod   uint8
   861  		s                   cryptobyte.String
   862  		b                   cryptobyte.Builder
   863  	)
   864  
   865  	s = cryptobyte.String(data)
   866  	if !s.Skip(4) || // message type and uint24 length field
   867  		!s.ReadUint16(&vers) || !s.ReadBytes(&random, 32) ||
   868  		!readUint8LengthPrefixed(&s, &sessionId) ||
   869  		!s.ReadUint16(&cipherSuite) ||
   870  		!s.ReadUint8(&compressionMethod) {
   871  		return nil
   872  	}
   873  
   874  	if s.Empty() {
   875  		// ServerHello is optionally followed by extension data
   876  		return nil
   877  	}
   878  
   879  	var extensions cryptobyte.String
   880  	if !s.ReadUint16LengthPrefixed(&extensions) || !s.Empty() {
   881  		return nil
   882  	}
   883  
   884  	b.AddUint8(typeServerHello)
   885  	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
   886  		b.AddUint16(vers)
   887  		b.AddBytes(random)
   888  		b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
   889  			b.AddBytes(sessionId)
   890  		})
   891  		b.AddUint16(cipherSuite)
   892  		b.AddUint8(compressionMethod)
   893  		b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   894  			for !extensions.Empty() {
   895  				var extension uint16
   896  				var extData cryptobyte.String
   897  				if !extensions.ReadUint16(&extension) ||
   898  					!extensions.ReadUint16LengthPrefixed(&extData) {
   899  					panic(cryptobyte.BuildError{Err: errIllegalParameter})
   900  				}
   901  
   902  				b.AddUint16(extension)
   903  				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   904  					if extension == extensionECH {
   905  						b.AddBytes(zeros[:8])
   906  					} else {
   907  						b.AddBytes(extData)
   908  					}
   909  				})
   910  			}
   911  		})
   912  	})
   913  
   914  	encodedData, err := b.Bytes()
   915  	if err == errIllegalParameter {
   916  		return nil // Input malformed
   917  	} else if err != nil {
   918  		panic(err) // Host encountered internal error
   919  	}
   920  
   921  	return encodedData
   922  }
   923  
   924  // processClientHelloExtensions interprets data as a ClientHello and applies a
   925  // function proc to each extension. Returns a bool indicating whether parsing
   926  // succeeded.
   927  func processClientHelloExtensions(data []byte, proc func(ext uint16, extData cryptobyte.String) bool) bool {
   928  	_, extensionsData := splitClientHelloExtensions(data)
   929  	if extensionsData == nil {
   930  		return false
   931  	}
   932  
   933  	s := cryptobyte.String(extensionsData)
   934  	if s.Empty() {
   935  		// Extensions field not present.
   936  		return true
   937  	}
   938  
   939  	var extensions cryptobyte.String
   940  	if !s.ReadUint16LengthPrefixed(&extensions) || !s.Empty() {
   941  		return false
   942  	}
   943  
   944  	for !extensions.Empty() {
   945  		var ext uint16
   946  		var extData cryptobyte.String
   947  		if !extensions.ReadUint16(&ext) ||
   948  			!extensions.ReadUint16LengthPrefixed(&extData) {
   949  			return false
   950  		}
   951  		if ok := proc(ext, extData); !ok {
   952  			return false
   953  		}
   954  	}
   955  	return true
   956  }
   957  
   958  // splitClientHelloExtensions interprets data as a ClientHello message and
   959  // returns two strings: the first contains the start of the ClientHello up to
   960  // the start of the extensions; and the second is the length-prefixed
   961  // extensions. Returns (nil, nil) if parsing of data fails.
   962  func splitClientHelloExtensions(data []byte) ([]byte, []byte) {
   963  	s := cryptobyte.String(data)
   964  
   965  	var ignored uint16
   966  	var t cryptobyte.String
   967  	if !s.Skip(4) || // message type and uint24 length field
   968  		!s.ReadUint16(&ignored) || !s.Skip(32) || // vers, random
   969  		!s.ReadUint8LengthPrefixed(&t) { // session_id
   970  		return nil, nil
   971  	}
   972  
   973  	if !s.ReadUint16LengthPrefixed(&t) { // cipher_suites
   974  		return nil, nil
   975  	}
   976  
   977  	if !s.ReadUint8LengthPrefixed(&t) { // compression_methods
   978  		return nil, nil
   979  	}
   980  
   981  	return data[:len(data)-len(s)], s
   982  }
   983  
   984  // TODO(cjpatton): Handle public name as described in draft-ietf-tls-esni-13,
   985  // Section 4.
   986  //
   987  // TODO(cjpatton): Implement ECH config extensions as described in
   988  // draft-ietf-tls-esni-13, Section 4.1.
   989  func (c *Config) echSelectConfig() *ECHConfig {
   990  	for _, echConfig := range c.ClientECHConfigs {
   991  		if _, err := echConfig.selectSuite(); err == nil &&
   992  			echConfig.version == extensionECH {
   993  			return &echConfig
   994  		}
   995  	}
   996  	return nil
   997  }
   998  
   999  func (c *Config) echCanOffer() bool {
  1000  	if c == nil {
  1001  		return false
  1002  	}
  1003  	return c.ECHEnabled &&
  1004  		c.echSelectConfig() != nil &&
  1005  		c.maxSupportedVersion(roleClient) >= VersionTLS13
  1006  }
  1007  
  1008  func (c *Config) echCanAccept() bool {
  1009  	if c == nil {
  1010  		return false
  1011  	}
  1012  	return c.ECHEnabled &&
  1013  		c.ServerECHProvider != nil &&
  1014  		c.maxSupportedVersion(roleServer) >= VersionTLS13
  1015  }
  1016  
  1017  // echOuterExtensions returns the list of extensions of the ClientHelloOuter
  1018  // that will be incorporated into the CleintHelloInner.
  1019  func echOuterExtensions() []uint16 {
  1020  	// NOTE(cjpatton): It would be nice to incorporate more extensions, but
  1021  	// "key_share" is the last extension to appear in the ClientHello before
  1022  	// "pre_shared_key". As a result, the only contiguous sequence of outer
  1023  	// extensions that contains "key_share" is "key_share" itself. Note that
  1024  	// we cannot change the order of extensions in the ClientHello, as the
  1025  	// unit tests expect "key_share" to be the second to last extension.
  1026  	outerExtensions := []uint16{extensionKeyShare}
  1027  	return outerExtensions
  1028  }
  1029  
  1030  func echCopyExtensionFromClientHelloInner(hello, helloInner *clientHelloMsg, ext uint16) {
  1031  	switch ext {
  1032  	case extensionStatusRequest:
  1033  		hello.ocspStapling = helloInner.ocspStapling
  1034  	case extensionSupportedCurves:
  1035  		hello.supportedCurves = helloInner.supportedCurves
  1036  	case extensionSupportedPoints:
  1037  		hello.supportedPoints = helloInner.supportedPoints
  1038  	case extensionKeyShare:
  1039  		hello.keyShares = helloInner.keyShares
  1040  	default:
  1041  		panic(fmt.Errorf("tried to copy unrecognized extension: %04x", ext))
  1042  	}
  1043  }