github.com/Carcraftz/utls@v0.0.0-20220413235215-6b7c52fd78b6/handshake_messages.go (about)

     1  // Copyright 2009 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package tls
     6  
     7  import (
     8  	"fmt"
     9  	"golang.org/x/crypto/cryptobyte"
    10  	"strings"
    11  )
    12  
    13  // The marshalingFunction type is an adapter to allow the use of ordinary
    14  // functions as cryptobyte.MarshalingValue.
    15  type marshalingFunction func(b *cryptobyte.Builder) error
    16  
    17  func (f marshalingFunction) Marshal(b *cryptobyte.Builder) error {
    18  	return f(b)
    19  }
    20  
    21  // addBytesWithLength appends a sequence of bytes to the cryptobyte.Builder. If
    22  // the length of the sequence is not the value specified, it produces an error.
    23  func addBytesWithLength(b *cryptobyte.Builder, v []byte, n int) {
    24  	b.AddValue(marshalingFunction(func(b *cryptobyte.Builder) error {
    25  		if len(v) != n {
    26  			return fmt.Errorf("invalid value length: expected %d, got %d", n, len(v))
    27  		}
    28  		b.AddBytes(v)
    29  		return nil
    30  	}))
    31  }
    32  
    33  // addUint64 appends a big-endian, 64-bit value to the cryptobyte.Builder.
    34  func addUint64(b *cryptobyte.Builder, v uint64) {
    35  	b.AddUint32(uint32(v >> 32))
    36  	b.AddUint32(uint32(v))
    37  }
    38  
    39  // readUint64 decodes a big-endian, 64-bit value into out and advances over it.
    40  // It reports whether the read was successful.
    41  func readUint64(s *cryptobyte.String, out *uint64) bool {
    42  	var hi, lo uint32
    43  	if !s.ReadUint32(&hi) || !s.ReadUint32(&lo) {
    44  		return false
    45  	}
    46  	*out = uint64(hi)<<32 | uint64(lo)
    47  	return true
    48  }
    49  
    50  // readUint8LengthPrefixed acts like s.ReadUint8LengthPrefixed, but targets a
    51  // []byte instead of a cryptobyte.String.
    52  func readUint8LengthPrefixed(s *cryptobyte.String, out *[]byte) bool {
    53  	return s.ReadUint8LengthPrefixed((*cryptobyte.String)(out))
    54  }
    55  
    56  // readUint16LengthPrefixed acts like s.ReadUint16LengthPrefixed, but targets a
    57  // []byte instead of a cryptobyte.String.
    58  func readUint16LengthPrefixed(s *cryptobyte.String, out *[]byte) bool {
    59  	return s.ReadUint16LengthPrefixed((*cryptobyte.String)(out))
    60  }
    61  
    62  // readUint24LengthPrefixed acts like s.ReadUint24LengthPrefixed, but targets a
    63  // []byte instead of a cryptobyte.String.
    64  func readUint24LengthPrefixed(s *cryptobyte.String, out *[]byte) bool {
    65  	return s.ReadUint24LengthPrefixed((*cryptobyte.String)(out))
    66  }
    67  
// clientHelloMsg is the in-memory form of a ClientHello handshake message,
// encoded by marshal and decoded by unmarshal.
//
// Field notes (as used by marshal/unmarshal in this file):
//   - raw caches the full wire encoding, including the 4-byte header. marshal
//     returns it unchanged when non-nil; unmarshal sets it to its input.
//   - random must be exactly 32 bytes or marshal fails.
//   - Slice-typed extension fields are emitted only when non-empty, and bool
//     flags only when true; otherwise the extension is omitted.
//   - secureRenegotiationSupported is also set by unmarshal when the
//     renegotiation SCSV appears in cipherSuites.
//   - pskIdentities and pskBinders form the pre_shared_key extension, which
//     marshal always writes last. updateBinders requires the replacement
//     binders to keep the same count and lengths.
//   - NOTE(review): ems is neither written by marshal nor read by unmarshal
//     here; presumably handled elsewhere in uTLS — confirm.
type clientHelloMsg struct {
	raw                              []byte
	vers                             uint16
	random                           []byte
	sessionId                        []byte
	cipherSuites                     []uint16
	compressionMethods               []uint8
	nextProtoNeg                     bool
	serverName                       string
	ocspStapling                     bool
	supportedCurves                  []CurveID
	supportedPoints                  []uint8
	ticketSupported                  bool
	sessionTicket                    []uint8
	supportedSignatureAlgorithms     []SignatureScheme
	supportedSignatureAlgorithmsCert []SignatureScheme
	secureRenegotiationSupported     bool
	secureRenegotiation              []byte
	alpnProtocols                    []string
	scts                             bool
	ems                              bool // [UTLS] actually implemented due to its prevalence
	supportedVersions                []uint16
	cookie                           []byte
	keyShares                        []keyShare
	earlyData                        bool
	pskModes                         []uint8
	pskIdentities                    []pskIdentity
	pskBinders                       [][]byte
}
    97  
// marshal encodes m as a complete ClientHello handshake message (type byte,
// uint24 length, body) and caches the result in m.raw. When m.raw is already
// set it is returned unchanged, so later field edits are not reflected unless
// m.raw is cleared first.
//
// Extensions are written in a fixed order with pre_shared_key last. The whole
// extensions vector is dropped when no extension is written. m.ems is not
// encoded here. If m.random is not exactly 32 bytes, addBytesWithLength
// leaves the builder in an error state and the final BytesOrPanic panics.
func (m *clientHelloMsg) marshal() []byte {
	// Reuse the cached encoding (from an earlier marshal, or from unmarshal).
	if m.raw != nil {
		return m.raw
	}

	var b cryptobyte.Builder
	b.AddUint8(typeClientHello)
	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
		b.AddUint16(m.vers)
		addBytesWithLength(b, m.random, 32)
		b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
			b.AddBytes(m.sessionId)
		})
		b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
			for _, suite := range m.cipherSuites {
				b.AddUint16(suite)
			}
		})
		b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
			b.AddBytes(m.compressionMethods)
		})

		// If extensions aren't present, omit them.
		var extensionsPresent bool
		// Snapshot the builder state so the extensions vector, including its
		// uint16 length prefix, can be rolled back below if it ends up empty.
		bWithoutExtensions := *b

		b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
			if m.nextProtoNeg {
				// draft-agl-tls-nextprotoneg-04
				b.AddUint16(extensionNextProtoNeg)
				b.AddUint16(0) // empty extension_data
			}
			if len(m.serverName) > 0 {
				// RFC 6066, Section 3
				b.AddUint16(extensionServerName)
				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
						b.AddUint8(0) // name_type = host_name
						b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
							b.AddBytes([]byte(m.serverName))
						})
					})
				})
			}
			if m.ocspStapling {
				// RFC 4366, Section 3.6
				b.AddUint16(extensionStatusRequest)
				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
					b.AddUint8(1)  // status_type = ocsp
					b.AddUint16(0) // empty responder_id_list
					b.AddUint16(0) // empty request_extensions
				})
			}
			if len(m.supportedCurves) > 0 {
				// RFC 4492, sections 5.1.1 and RFC 8446, Section 4.2.7
				b.AddUint16(extensionSupportedCurves)
				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
						for _, curve := range m.supportedCurves {
							b.AddUint16(uint16(curve))
						}
					})
				})
			}
			if len(m.supportedPoints) > 0 {
				// RFC 4492, Section 5.1.2
				b.AddUint16(extensionSupportedPoints)
				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
					b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
						b.AddBytes(m.supportedPoints)
					})
				})
			}
			if m.ticketSupported {
				// RFC 5077, Section 3.2
				// The ticket (possibly empty) is the raw extension_data,
				// with no inner length prefix.
				b.AddUint16(extensionSessionTicket)
				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
					b.AddBytes(m.sessionTicket)
				})
			}
			if len(m.supportedSignatureAlgorithms) > 0 {
				// RFC 5246, Section 7.4.1.4.1
				b.AddUint16(extensionSignatureAlgorithms)
				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
						for _, sigAlgo := range m.supportedSignatureAlgorithms {
							b.AddUint16(uint16(sigAlgo))
						}
					})
				})
			}
			if len(m.supportedSignatureAlgorithmsCert) > 0 {
				// RFC 8446, Section 4.2.3
				b.AddUint16(extensionSignatureAlgorithmsCert)
				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
						for _, sigAlgo := range m.supportedSignatureAlgorithmsCert {
							b.AddUint16(uint16(sigAlgo))
						}
					})
				})
			}
			if m.secureRenegotiationSupported {
				// RFC 5746, Section 3.2
				b.AddUint16(extensionRenegotiationInfo)
				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
					b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
						b.AddBytes(m.secureRenegotiation)
					})
				})
			}
			if len(m.alpnProtocols) > 0 {
				// RFC 7301, Section 3.1
				b.AddUint16(extensionALPN)
				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
						for _, proto := range m.alpnProtocols {
							b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
								b.AddBytes([]byte(proto))
							})
						}
					})
				})
			}
			if m.scts {
				// RFC 6962, Section 3.3.1
				b.AddUint16(extensionSCT)
				b.AddUint16(0) // empty extension_data
			}
			if len(m.supportedVersions) > 0 {
				// RFC 8446, Section 4.2.1
				b.AddUint16(extensionSupportedVersions)
				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
					b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
						for _, vers := range m.supportedVersions {
							b.AddUint16(vers)
						}
					})
				})
			}
			if len(m.cookie) > 0 {
				// RFC 8446, Section 4.2.2
				b.AddUint16(extensionCookie)
				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
						b.AddBytes(m.cookie)
					})
				})
			}
			if len(m.keyShares) > 0 {
				// RFC 8446, Section 4.2.8
				b.AddUint16(extensionKeyShare)
				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
						for _, ks := range m.keyShares {
							b.AddUint16(uint16(ks.group))
							b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
								b.AddBytes(ks.data)
							})
						}
					})
				})
			}
			if m.earlyData {
				// RFC 8446, Section 4.2.10
				b.AddUint16(extensionEarlyData)
				b.AddUint16(0) // empty extension_data
			}
			if len(m.pskModes) > 0 {
				// RFC 8446, Section 4.2.9
				b.AddUint16(extensionPSKModes)
				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
					b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
						b.AddBytes(m.pskModes)
					})
				})
			}
			if len(m.pskIdentities) > 0 { // pre_shared_key must be the last extension
				// RFC 8446, Section 4.2.11
				// The binders vector ends the whole message. That is what lets
				// marshalWithoutBinders and updateBinders work on the tail of
				// m.raw.
				b.AddUint16(extensionPreSharedKey)
				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
						for _, psk := range m.pskIdentities {
							b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
								b.AddBytes(psk.label)
							})
							b.AddUint32(psk.obfuscatedTicketAge)
						}
					})
					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
						for _, binder := range m.pskBinders {
							b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
								b.AddBytes(binder)
							})
						}
					})
				})
			}

			// NOTE(review): the child builder's bytes appear to include its
			// own 2-byte length placeholder (a cryptobyte implementation
			// detail). So "> 2" means at least one extension was written.
			extensionsPresent = len(b.BytesOrPanic()) > 2
		})

		if !extensionsPresent {
			// Roll back the empty extensions vector entirely.
			*b = bWithoutExtensions
		}
	})

	m.raw = b.BytesOrPanic()
	return m.raw
}
   308  
   309  // marshalWithoutBinders returns the ClientHello through the
   310  // PreSharedKeyExtension.identities field, according to RFC 8446, Section
   311  // 4.2.11.2. Note that m.pskBinders must be set to slices of the correct length.
   312  func (m *clientHelloMsg) marshalWithoutBinders() []byte {
   313  	bindersLen := 2 // uint16 length prefix
   314  	for _, binder := range m.pskBinders {
   315  		bindersLen += 1 // uint8 length prefix
   316  		bindersLen += len(binder)
   317  	}
   318  
   319  	fullMessage := m.marshal()
   320  	return fullMessage[:len(fullMessage)-bindersLen]
   321  }
   322  
   323  // updateBinders updates the m.pskBinders field, if necessary updating the
   324  // cached marshaled representation. The supplied binders must have the same
   325  // length as the current m.pskBinders.
   326  func (m *clientHelloMsg) updateBinders(pskBinders [][]byte) {
   327  	if len(pskBinders) != len(m.pskBinders) {
   328  		panic("tls: internal error: pskBinders length mismatch")
   329  	}
   330  	for i := range m.pskBinders {
   331  		if len(pskBinders[i]) != len(m.pskBinders[i]) {
   332  			panic("tls: internal error: pskBinders length mismatch")
   333  		}
   334  	}
   335  	m.pskBinders = pskBinders
   336  	if m.raw != nil {
   337  		lenWithoutBinders := len(m.marshalWithoutBinders())
   338  		// TODO(filippo): replace with NewFixedBuilder once CL 148882 is imported.
   339  		b := cryptobyte.NewBuilder(m.raw[:lenWithoutBinders])
   340  		b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   341  			for _, binder := range m.pskBinders {
   342  				b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
   343  					b.AddBytes(binder)
   344  				})
   345  			}
   346  		})
   347  		if len(b.BytesOrPanic()) != len(m.raw) {
   348  			panic("tls: internal error: failed to update binders")
   349  		}
   350  	}
   351  }
   352  
   353  func (m *clientHelloMsg) unmarshal(data []byte) bool {
   354  	*m = clientHelloMsg{raw: data}
   355  	s := cryptobyte.String(data)
   356  
   357  	if !s.Skip(4) || // message type and uint24 length field
   358  		!s.ReadUint16(&m.vers) || !s.ReadBytes(&m.random, 32) ||
   359  		!readUint8LengthPrefixed(&s, &m.sessionId) {
   360  		return false
   361  	}
   362  
   363  	var cipherSuites cryptobyte.String
   364  	if !s.ReadUint16LengthPrefixed(&cipherSuites) {
   365  		return false
   366  	}
   367  	m.cipherSuites = []uint16{}
   368  	m.secureRenegotiationSupported = false
   369  	for !cipherSuites.Empty() {
   370  		var suite uint16
   371  		if !cipherSuites.ReadUint16(&suite) {
   372  			return false
   373  		}
   374  		if suite == scsvRenegotiation {
   375  			m.secureRenegotiationSupported = true
   376  		}
   377  		m.cipherSuites = append(m.cipherSuites, suite)
   378  	}
   379  
   380  	if !readUint8LengthPrefixed(&s, &m.compressionMethods) {
   381  		return false
   382  	}
   383  
   384  	if s.Empty() {
   385  		// ClientHello is optionally followed by extension data
   386  		return true
   387  	}
   388  
   389  	var extensions cryptobyte.String
   390  	if !s.ReadUint16LengthPrefixed(&extensions) || !s.Empty() {
   391  		return false
   392  	}
   393  
   394  	for !extensions.Empty() {
   395  		var extension uint16
   396  		var extData cryptobyte.String
   397  		if !extensions.ReadUint16(&extension) ||
   398  			!extensions.ReadUint16LengthPrefixed(&extData) {
   399  			return false
   400  		}
   401  
   402  		switch extension {
   403  		case extensionServerName:
   404  			// RFC 6066, Section 3
   405  			var nameList cryptobyte.String
   406  			if !extData.ReadUint16LengthPrefixed(&nameList) || nameList.Empty() {
   407  				return false
   408  			}
   409  			for !nameList.Empty() {
   410  				var nameType uint8
   411  				var serverName cryptobyte.String
   412  				if !nameList.ReadUint8(&nameType) ||
   413  					!nameList.ReadUint16LengthPrefixed(&serverName) ||
   414  					serverName.Empty() {
   415  					return false
   416  				}
   417  				if nameType != 0 {
   418  					continue
   419  				}
   420  				if len(m.serverName) != 0 {
   421  					// Multiple names of the same name_type are prohibited.
   422  					return false
   423  				}
   424  				m.serverName = string(serverName)
   425  				// An SNI value may not include a trailing dot.
   426  				if strings.HasSuffix(m.serverName, ".") {
   427  					return false
   428  				}
   429  			}
   430  		case extensionNextProtoNeg:
   431  			// draft-agl-tls-nextprotoneg-04
   432  			m.nextProtoNeg = true
   433  		case extensionStatusRequest:
   434  			// RFC 4366, Section 3.6
   435  			var statusType uint8
   436  			var ignored cryptobyte.String
   437  			if !extData.ReadUint8(&statusType) ||
   438  				!extData.ReadUint16LengthPrefixed(&ignored) ||
   439  				!extData.ReadUint16LengthPrefixed(&ignored) {
   440  				return false
   441  			}
   442  			m.ocspStapling = statusType == statusTypeOCSP
   443  		case extensionSupportedCurves:
   444  			// RFC 4492, sections 5.1.1 and RFC 8446, Section 4.2.7
   445  			var curves cryptobyte.String
   446  			if !extData.ReadUint16LengthPrefixed(&curves) || curves.Empty() {
   447  				return false
   448  			}
   449  			for !curves.Empty() {
   450  				var curve uint16
   451  				if !curves.ReadUint16(&curve) {
   452  					return false
   453  				}
   454  				m.supportedCurves = append(m.supportedCurves, CurveID(curve))
   455  			}
   456  		case extensionSupportedPoints:
   457  			// RFC 4492, Section 5.1.2
   458  			if !readUint8LengthPrefixed(&extData, &m.supportedPoints) ||
   459  				len(m.supportedPoints) == 0 {
   460  				return false
   461  			}
   462  		case extensionSessionTicket:
   463  			// RFC 5077, Section 3.2
   464  			m.ticketSupported = true
   465  			extData.ReadBytes(&m.sessionTicket, len(extData))
   466  		case extensionSignatureAlgorithms:
   467  			// RFC 5246, Section 7.4.1.4.1
   468  			var sigAndAlgs cryptobyte.String
   469  			if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() {
   470  				return false
   471  			}
   472  			for !sigAndAlgs.Empty() {
   473  				var sigAndAlg uint16
   474  				if !sigAndAlgs.ReadUint16(&sigAndAlg) {
   475  					return false
   476  				}
   477  				m.supportedSignatureAlgorithms = append(
   478  					m.supportedSignatureAlgorithms, SignatureScheme(sigAndAlg))
   479  			}
   480  		case extensionSignatureAlgorithmsCert:
   481  			// RFC 8446, Section 4.2.3
   482  			var sigAndAlgs cryptobyte.String
   483  			if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() {
   484  				return false
   485  			}
   486  			for !sigAndAlgs.Empty() {
   487  				var sigAndAlg uint16
   488  				if !sigAndAlgs.ReadUint16(&sigAndAlg) {
   489  					return false
   490  				}
   491  				m.supportedSignatureAlgorithmsCert = append(
   492  					m.supportedSignatureAlgorithmsCert, SignatureScheme(sigAndAlg))
   493  			}
   494  		case extensionRenegotiationInfo:
   495  			// RFC 5746, Section 3.2
   496  			if !readUint8LengthPrefixed(&extData, &m.secureRenegotiation) {
   497  				return false
   498  			}
   499  			m.secureRenegotiationSupported = true
   500  		case extensionALPN:
   501  			// RFC 7301, Section 3.1
   502  			var protoList cryptobyte.String
   503  			if !extData.ReadUint16LengthPrefixed(&protoList) || protoList.Empty() {
   504  				return false
   505  			}
   506  			for !protoList.Empty() {
   507  				var proto cryptobyte.String
   508  				if !protoList.ReadUint8LengthPrefixed(&proto) || proto.Empty() {
   509  					return false
   510  				}
   511  				m.alpnProtocols = append(m.alpnProtocols, string(proto))
   512  			}
   513  		case extensionSCT:
   514  			// RFC 6962, Section 3.3.1
   515  			m.scts = true
   516  		case extensionSupportedVersions:
   517  			// RFC 8446, Section 4.2.1
   518  			var versList cryptobyte.String
   519  			if !extData.ReadUint8LengthPrefixed(&versList) || versList.Empty() {
   520  				return false
   521  			}
   522  			for !versList.Empty() {
   523  				var vers uint16
   524  				if !versList.ReadUint16(&vers) {
   525  					return false
   526  				}
   527  				m.supportedVersions = append(m.supportedVersions, vers)
   528  			}
   529  		case extensionCookie:
   530  			// RFC 8446, Section 4.2.2
   531  			if !readUint16LengthPrefixed(&extData, &m.cookie) ||
   532  				len(m.cookie) == 0 {
   533  				return false
   534  			}
   535  		case extensionKeyShare:
   536  			// RFC 8446, Section 4.2.8
   537  			var clientShares cryptobyte.String
   538  			if !extData.ReadUint16LengthPrefixed(&clientShares) {
   539  				return false
   540  			}
   541  			for !clientShares.Empty() {
   542  				var ks keyShare
   543  				if !clientShares.ReadUint16((*uint16)(&ks.group)) ||
   544  					!readUint16LengthPrefixed(&clientShares, &ks.data) ||
   545  					len(ks.data) == 0 {
   546  					return false
   547  				}
   548  				m.keyShares = append(m.keyShares, ks)
   549  			}
   550  		case extensionEarlyData:
   551  			// RFC 8446, Section 4.2.10
   552  			m.earlyData = true
   553  		case extensionPSKModes:
   554  			// RFC 8446, Section 4.2.9
   555  			if !readUint8LengthPrefixed(&extData, &m.pskModes) {
   556  				return false
   557  			}
   558  		case extensionPreSharedKey:
   559  			// RFC 8446, Section 4.2.11
   560  			if !extensions.Empty() {
   561  				return false // pre_shared_key must be the last extension
   562  			}
   563  			var identities cryptobyte.String
   564  			if !extData.ReadUint16LengthPrefixed(&identities) || identities.Empty() {
   565  				return false
   566  			}
   567  			for !identities.Empty() {
   568  				var psk pskIdentity
   569  				if !readUint16LengthPrefixed(&identities, &psk.label) ||
   570  					!identities.ReadUint32(&psk.obfuscatedTicketAge) ||
   571  					len(psk.label) == 0 {
   572  					return false
   573  				}
   574  				m.pskIdentities = append(m.pskIdentities, psk)
   575  			}
   576  			var binders cryptobyte.String
   577  			if !extData.ReadUint16LengthPrefixed(&binders) || binders.Empty() {
   578  				return false
   579  			}
   580  			for !binders.Empty() {
   581  				var binder []byte
   582  				if !readUint8LengthPrefixed(&binders, &binder) ||
   583  					len(binder) == 0 {
   584  					return false
   585  				}
   586  				m.pskBinders = append(m.pskBinders, binder)
   587  			}
   588  		default:
   589  			// Ignore unknown extensions.
   590  			continue
   591  		}
   592  
   593  		if !extData.Empty() {
   594  			return false
   595  		}
   596  	}
   597  
   598  	return true
   599  }
   600  
// serverHelloMsg is the in-memory form of a ServerHello handshake message,
// also used for HelloRetryRequest, encoded by marshal and decoded by
// unmarshal.
//
// Field notes (as used by marshal/unmarshal in this file):
//   - raw caches the full wire encoding, including the 4-byte header. marshal
//     returns it unchanged when non-nil; unmarshal sets it to its input.
//   - random must be exactly 32 bytes or marshal fails.
//   - supportedVersion, serverShare.group and selectedGroup use 0 to mean
//     "extension absent".
//   - serverShare (ServerHello) and selectedGroup (HelloRetryRequest) share
//     the key_share extension type. unmarshal distinguishes them by payload
//     length (2 bytes means selectedGroup).
//   - ems is set by unmarshal when the extended_master_secret extension is
//     present; marshal does not write it.
type serverHelloMsg struct {
	raw                          []byte
	vers                         uint16
	random                       []byte
	sessionId                    []byte
	cipherSuite                  uint16
	compressionMethod            uint8
	nextProtoNeg                 bool
	nextProtos                   []string
	ocspStapling                 bool
	ticketSupported              bool
	secureRenegotiationSupported bool
	secureRenegotiation          []byte
	alpnProtocol                 string
	ems                          bool
	scts                         [][]byte
	supportedVersion             uint16
	serverShare                  keyShare
	selectedIdentityPresent      bool
	selectedIdentity             uint16

	// HelloRetryRequest extensions
	cookie        []byte
	selectedGroup CurveID
}
   626  
// marshal encodes m as a complete ServerHello (or HelloRetryRequest)
// handshake message (type byte, uint24 length, body) and caches the result
// in m.raw. When m.raw is already set it is returned unchanged.
//
// Extensions are written in a fixed order, and the extensions vector is
// dropped entirely when no extension is written. m.ems is not encoded here.
// If m.random is not exactly 32 bytes, the final BytesOrPanic panics.
func (m *serverHelloMsg) marshal() []byte {
	// Reuse the cached encoding (from an earlier marshal, or from unmarshal).
	if m.raw != nil {
		return m.raw
	}

	var b cryptobyte.Builder
	b.AddUint8(typeServerHello)
	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
		b.AddUint16(m.vers)
		addBytesWithLength(b, m.random, 32)
		b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
			b.AddBytes(m.sessionId)
		})
		b.AddUint16(m.cipherSuite)
		b.AddUint8(m.compressionMethod)

		// If extensions aren't present, omit them.
		var extensionsPresent bool
		// Snapshot the builder state so the extensions vector, including its
		// uint16 length prefix, can be rolled back below if it ends up empty.
		bWithoutExtensions := *b

		b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
			if m.nextProtoNeg {
				// The protocol list is the extension_data itself, with no
				// outer uint16 list length (unlike ALPN).
				b.AddUint16(extensionNextProtoNeg)
				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
					for _, proto := range m.nextProtos {
						b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
							b.AddBytes([]byte(proto))
						})
					}
				})
			}
			if m.ocspStapling {
				b.AddUint16(extensionStatusRequest)
				b.AddUint16(0) // empty extension_data
			}
			if m.ticketSupported {
				b.AddUint16(extensionSessionTicket)
				b.AddUint16(0) // empty extension_data
			}
			if m.secureRenegotiationSupported {
				b.AddUint16(extensionRenegotiationInfo)
				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
					b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
						b.AddBytes(m.secureRenegotiation)
					})
				})
			}
			if len(m.alpnProtocol) > 0 {
				// A one-element protocol_name_list holding the selected protocol.
				b.AddUint16(extensionALPN)
				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
						b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
							b.AddBytes([]byte(m.alpnProtocol))
						})
					})
				})
			}
			if len(m.scts) > 0 {
				b.AddUint16(extensionSCT)
				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
						for _, sct := range m.scts {
							b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
								b.AddBytes(sct)
							})
						}
					})
				})
			}
			if m.supportedVersion != 0 {
				b.AddUint16(extensionSupportedVersions)
				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
					b.AddUint16(m.supportedVersion)
				})
			}
			if m.serverShare.group != 0 {
				b.AddUint16(extensionKeyShare)
				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
					b.AddUint16(uint16(m.serverShare.group))
					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
						b.AddBytes(m.serverShare.data)
					})
				})
			}
			if m.selectedIdentityPresent {
				b.AddUint16(extensionPreSharedKey)
				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
					b.AddUint16(m.selectedIdentity)
				})
			}

			if len(m.cookie) > 0 {
				b.AddUint16(extensionCookie)
				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
						b.AddBytes(m.cookie)
					})
				})
			}
			// NOTE(review): this also emits extensionKeyShare. If both
			// serverShare.group and selectedGroup are non-zero, the extension
			// is written twice. Callers presumably set only one (ServerHello
			// vs HelloRetryRequest) — confirm.
			if m.selectedGroup != 0 {
				b.AddUint16(extensionKeyShare)
				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
					b.AddUint16(uint16(m.selectedGroup))
				})
			}

			// NOTE(review): the child builder's bytes appear to include its
			// own 2-byte length placeholder (a cryptobyte implementation
			// detail). So "> 2" means at least one extension was written.
			extensionsPresent = len(b.BytesOrPanic()) > 2
		})

		if !extensionsPresent {
			// Roll back the empty extensions vector entirely.
			*b = bWithoutExtensions
		}
	})

	m.raw = b.BytesOrPanic()
	return m.raw
}
   744  
   745  func (m *serverHelloMsg) unmarshal(data []byte) bool {
   746  	*m = serverHelloMsg{raw: data}
   747  	s := cryptobyte.String(data)
   748  
   749  	if !s.Skip(4) || // message type and uint24 length field
   750  		!s.ReadUint16(&m.vers) || !s.ReadBytes(&m.random, 32) ||
   751  		!readUint8LengthPrefixed(&s, &m.sessionId) ||
   752  		!s.ReadUint16(&m.cipherSuite) ||
   753  		!s.ReadUint8(&m.compressionMethod) {
   754  		return false
   755  	}
   756  
   757  	if s.Empty() {
   758  		// ServerHello is optionally followed by extension data
   759  		return true
   760  	}
   761  
   762  	var extensions cryptobyte.String
   763  	if !s.ReadUint16LengthPrefixed(&extensions) || !s.Empty() {
   764  		return false
   765  	}
   766  
   767  	for !extensions.Empty() {
   768  		var extension uint16
   769  		var extData cryptobyte.String
   770  		if !extensions.ReadUint16(&extension) ||
   771  			!extensions.ReadUint16LengthPrefixed(&extData) {
   772  			return false
   773  		}
   774  
   775  		switch extension {
   776  		case extensionNextProtoNeg:
   777  			m.nextProtoNeg = true
   778  			for !extData.Empty() {
   779  				var proto cryptobyte.String
   780  				if !extData.ReadUint8LengthPrefixed(&proto) ||
   781  					proto.Empty() {
   782  					return false
   783  				}
   784  				m.nextProtos = append(m.nextProtos, string(proto))
   785  			}
   786  		case extensionStatusRequest:
   787  			m.ocspStapling = true
   788  		case extensionSessionTicket:
   789  			m.ticketSupported = true
   790  		case utlsExtensionExtendedMasterSecret:
   791  			// No sanity check for this extension: pretending not to know it.
   792  			// if length > 0 {
   793  			// 	return false
   794  			// }
   795  			m.ems = true
   796  		case extensionRenegotiationInfo:
   797  			if !readUint8LengthPrefixed(&extData, &m.secureRenegotiation) {
   798  				return false
   799  			}
   800  			m.secureRenegotiationSupported = true
   801  		case extensionALPN:
   802  			var protoList cryptobyte.String
   803  			if !extData.ReadUint16LengthPrefixed(&protoList) || protoList.Empty() {
   804  				return false
   805  			}
   806  			var proto cryptobyte.String
   807  			if !protoList.ReadUint8LengthPrefixed(&proto) ||
   808  				proto.Empty() || !protoList.Empty() {
   809  				return false
   810  			}
   811  			m.alpnProtocol = string(proto)
   812  		case extensionSCT:
   813  			var sctList cryptobyte.String
   814  			if !extData.ReadUint16LengthPrefixed(&sctList) || sctList.Empty() {
   815  				return false
   816  			}
   817  			for !sctList.Empty() {
   818  				var sct []byte
   819  				if !readUint16LengthPrefixed(&sctList, &sct) ||
   820  					len(sct) == 0 {
   821  					return false
   822  				}
   823  				m.scts = append(m.scts, sct)
   824  			}
   825  		case extensionSupportedVersions:
   826  			if !extData.ReadUint16(&m.supportedVersion) {
   827  				return false
   828  			}
   829  		case extensionCookie:
   830  			if !readUint16LengthPrefixed(&extData, &m.cookie) ||
   831  				len(m.cookie) == 0 {
   832  				return false
   833  			}
   834  		case extensionKeyShare:
   835  			// This extension has different formats in SH and HRR, accept either
   836  			// and let the handshake logic decide. See RFC 8446, Section 4.2.8.
   837  			if len(extData) == 2 {
   838  				if !extData.ReadUint16((*uint16)(&m.selectedGroup)) {
   839  					return false
   840  				}
   841  			} else {
   842  				if !extData.ReadUint16((*uint16)(&m.serverShare.group)) ||
   843  					!readUint16LengthPrefixed(&extData, &m.serverShare.data) {
   844  					return false
   845  				}
   846  			}
   847  		case extensionPreSharedKey:
   848  			m.selectedIdentityPresent = true
   849  			if !extData.ReadUint16(&m.selectedIdentity) {
   850  				return false
   851  			}
   852  		default:
   853  			// Ignore unknown extensions.
   854  			continue
   855  		}
   856  
   857  		if !extData.Empty() {
   858  			return false
   859  		}
   860  	}
   861  
   862  	return true
   863  }
   864  
   865  type encryptedExtensionsMsg struct {
   866  	raw          []byte
   867  	alpnProtocol string
   868  }
   869  
   870  func (m *encryptedExtensionsMsg) marshal() []byte {
   871  	if m.raw != nil {
   872  		return m.raw
   873  	}
   874  
   875  	var b cryptobyte.Builder
   876  	b.AddUint8(typeEncryptedExtensions)
   877  	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
   878  		b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   879  			if len(m.alpnProtocol) > 0 {
   880  				b.AddUint16(extensionALPN)
   881  				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   882  					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   883  						b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
   884  							b.AddBytes([]byte(m.alpnProtocol))
   885  						})
   886  					})
   887  				})
   888  			}
   889  		})
   890  	})
   891  
   892  	m.raw = b.BytesOrPanic()
   893  	return m.raw
   894  }
   895  
   896  func (m *encryptedExtensionsMsg) unmarshal(data []byte) bool {
   897  	*m = encryptedExtensionsMsg{raw: data}
   898  	s := cryptobyte.String(data)
   899  
   900  	var extensions cryptobyte.String
   901  	if !s.Skip(4) || // message type and uint24 length field
   902  		!s.ReadUint16LengthPrefixed(&extensions) || !s.Empty() {
   903  		return false
   904  	}
   905  
   906  	for !extensions.Empty() {
   907  		var extension uint16
   908  		var extData cryptobyte.String
   909  		if !extensions.ReadUint16(&extension) ||
   910  			!extensions.ReadUint16LengthPrefixed(&extData) {
   911  			return false
   912  		}
   913  
   914  		switch extension {
   915  		case extensionALPN:
   916  			var protoList cryptobyte.String
   917  			if !extData.ReadUint16LengthPrefixed(&protoList) || protoList.Empty() {
   918  				return false
   919  			}
   920  			var proto cryptobyte.String
   921  			if !protoList.ReadUint8LengthPrefixed(&proto) ||
   922  				proto.Empty() || !protoList.Empty() {
   923  				return false
   924  			}
   925  			m.alpnProtocol = string(proto)
   926  		default:
   927  			// Ignore unknown extensions.
   928  			continue
   929  		}
   930  
   931  		if !extData.Empty() {
   932  			return false
   933  		}
   934  	}
   935  
   936  	return true
   937  }
   938  
   939  type endOfEarlyDataMsg struct{}
   940  
   941  func (m *endOfEarlyDataMsg) marshal() []byte {
   942  	x := make([]byte, 4)
   943  	x[0] = typeEndOfEarlyData
   944  	return x
   945  }
   946  
   947  func (m *endOfEarlyDataMsg) unmarshal(data []byte) bool {
   948  	return len(data) == 4
   949  }
   950  
   951  type keyUpdateMsg struct {
   952  	raw             []byte
   953  	updateRequested bool
   954  }
   955  
   956  func (m *keyUpdateMsg) marshal() []byte {
   957  	if m.raw != nil {
   958  		return m.raw
   959  	}
   960  
   961  	var b cryptobyte.Builder
   962  	b.AddUint8(typeKeyUpdate)
   963  	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
   964  		if m.updateRequested {
   965  			b.AddUint8(1)
   966  		} else {
   967  			b.AddUint8(0)
   968  		}
   969  	})
   970  
   971  	m.raw = b.BytesOrPanic()
   972  	return m.raw
   973  }
   974  
   975  func (m *keyUpdateMsg) unmarshal(data []byte) bool {
   976  	m.raw = data
   977  	s := cryptobyte.String(data)
   978  
   979  	var updateRequested uint8
   980  	if !s.Skip(4) || // message type and uint24 length field
   981  		!s.ReadUint8(&updateRequested) || !s.Empty() {
   982  		return false
   983  	}
   984  	switch updateRequested {
   985  	case 0:
   986  		m.updateRequested = false
   987  	case 1:
   988  		m.updateRequested = true
   989  	default:
   990  		return false
   991  	}
   992  	return true
   993  }
   994  
   995  type newSessionTicketMsgTLS13 struct {
   996  	raw          []byte
   997  	lifetime     uint32
   998  	ageAdd       uint32
   999  	nonce        []byte
  1000  	label        []byte
  1001  	maxEarlyData uint32
  1002  }
  1003  
  1004  func (m *newSessionTicketMsgTLS13) marshal() []byte {
  1005  	if m.raw != nil {
  1006  		return m.raw
  1007  	}
  1008  
  1009  	var b cryptobyte.Builder
  1010  	b.AddUint8(typeNewSessionTicket)
  1011  	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
  1012  		b.AddUint32(m.lifetime)
  1013  		b.AddUint32(m.ageAdd)
  1014  		b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
  1015  			b.AddBytes(m.nonce)
  1016  		})
  1017  		b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1018  			b.AddBytes(m.label)
  1019  		})
  1020  
  1021  		b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1022  			if m.maxEarlyData > 0 {
  1023  				b.AddUint16(extensionEarlyData)
  1024  				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1025  					b.AddUint32(m.maxEarlyData)
  1026  				})
  1027  			}
  1028  		})
  1029  	})
  1030  
  1031  	m.raw = b.BytesOrPanic()
  1032  	return m.raw
  1033  }
  1034  
  1035  func (m *newSessionTicketMsgTLS13) unmarshal(data []byte) bool {
  1036  	*m = newSessionTicketMsgTLS13{raw: data}
  1037  	s := cryptobyte.String(data)
  1038  
  1039  	var extensions cryptobyte.String
  1040  	if !s.Skip(4) || // message type and uint24 length field
  1041  		!s.ReadUint32(&m.lifetime) ||
  1042  		!s.ReadUint32(&m.ageAdd) ||
  1043  		!readUint8LengthPrefixed(&s, &m.nonce) ||
  1044  		!readUint16LengthPrefixed(&s, &m.label) ||
  1045  		!s.ReadUint16LengthPrefixed(&extensions) ||
  1046  		!s.Empty() {
  1047  		return false
  1048  	}
  1049  
  1050  	for !extensions.Empty() {
  1051  		var extension uint16
  1052  		var extData cryptobyte.String
  1053  		if !extensions.ReadUint16(&extension) ||
  1054  			!extensions.ReadUint16LengthPrefixed(&extData) {
  1055  			return false
  1056  		}
  1057  
  1058  		switch extension {
  1059  		case extensionEarlyData:
  1060  			if !extData.ReadUint32(&m.maxEarlyData) {
  1061  				return false
  1062  			}
  1063  		default:
  1064  			// Ignore unknown extensions.
  1065  			continue
  1066  		}
  1067  
  1068  		if !extData.Empty() {
  1069  			return false
  1070  		}
  1071  	}
  1072  
  1073  	return true
  1074  }
  1075  
  1076  type certificateRequestMsgTLS13 struct {
  1077  	raw                              []byte
  1078  	ocspStapling                     bool
  1079  	scts                             bool
  1080  	supportedSignatureAlgorithms     []SignatureScheme
  1081  	supportedSignatureAlgorithmsCert []SignatureScheme
  1082  	certificateAuthorities           [][]byte
  1083  }
  1084  
  1085  func (m *certificateRequestMsgTLS13) marshal() []byte {
  1086  	if m.raw != nil {
  1087  		return m.raw
  1088  	}
  1089  
  1090  	var b cryptobyte.Builder
  1091  	b.AddUint8(typeCertificateRequest)
  1092  	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
  1093  		// certificate_request_context (SHALL be zero length unless used for
  1094  		// post-handshake authentication)
  1095  		b.AddUint8(0)
  1096  
  1097  		b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1098  			if m.ocspStapling {
  1099  				b.AddUint16(extensionStatusRequest)
  1100  				b.AddUint16(0) // empty extension_data
  1101  			}
  1102  			if m.scts {
  1103  				// RFC 8446, Section 4.4.2.1 makes no mention of
  1104  				// signed_certificate_timestamp in CertificateRequest, but
  1105  				// "Extensions in the Certificate message from the client MUST
  1106  				// correspond to extensions in the CertificateRequest message
  1107  				// from the server." and it appears in the table in Section 4.2.
  1108  				b.AddUint16(extensionSCT)
  1109  				b.AddUint16(0) // empty extension_data
  1110  			}
  1111  			if len(m.supportedSignatureAlgorithms) > 0 {
  1112  				b.AddUint16(extensionSignatureAlgorithms)
  1113  				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1114  					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1115  						for _, sigAlgo := range m.supportedSignatureAlgorithms {
  1116  							b.AddUint16(uint16(sigAlgo))
  1117  						}
  1118  					})
  1119  				})
  1120  			}
  1121  			if len(m.supportedSignatureAlgorithmsCert) > 0 {
  1122  				b.AddUint16(extensionSignatureAlgorithmsCert)
  1123  				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1124  					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1125  						for _, sigAlgo := range m.supportedSignatureAlgorithmsCert {
  1126  							b.AddUint16(uint16(sigAlgo))
  1127  						}
  1128  					})
  1129  				})
  1130  			}
  1131  			if len(m.certificateAuthorities) > 0 {
  1132  				b.AddUint16(extensionCertificateAuthorities)
  1133  				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1134  					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1135  						for _, ca := range m.certificateAuthorities {
  1136  							b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1137  								b.AddBytes(ca)
  1138  							})
  1139  						}
  1140  					})
  1141  				})
  1142  			}
  1143  		})
  1144  	})
  1145  
  1146  	m.raw = b.BytesOrPanic()
  1147  	return m.raw
  1148  }
  1149  
  1150  func (m *certificateRequestMsgTLS13) unmarshal(data []byte) bool {
  1151  	*m = certificateRequestMsgTLS13{raw: data}
  1152  	s := cryptobyte.String(data)
  1153  
  1154  	var context, extensions cryptobyte.String
  1155  	if !s.Skip(4) || // message type and uint24 length field
  1156  		!s.ReadUint8LengthPrefixed(&context) || !context.Empty() ||
  1157  		!s.ReadUint16LengthPrefixed(&extensions) ||
  1158  		!s.Empty() {
  1159  		return false
  1160  	}
  1161  
  1162  	for !extensions.Empty() {
  1163  		var extension uint16
  1164  		var extData cryptobyte.String
  1165  		if !extensions.ReadUint16(&extension) ||
  1166  			!extensions.ReadUint16LengthPrefixed(&extData) {
  1167  			return false
  1168  		}
  1169  
  1170  		switch extension {
  1171  		case extensionStatusRequest:
  1172  			m.ocspStapling = true
  1173  		case extensionSCT:
  1174  			m.scts = true
  1175  		case extensionSignatureAlgorithms:
  1176  			var sigAndAlgs cryptobyte.String
  1177  			if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() {
  1178  				return false
  1179  			}
  1180  			for !sigAndAlgs.Empty() {
  1181  				var sigAndAlg uint16
  1182  				if !sigAndAlgs.ReadUint16(&sigAndAlg) {
  1183  					return false
  1184  				}
  1185  				m.supportedSignatureAlgorithms = append(
  1186  					m.supportedSignatureAlgorithms, SignatureScheme(sigAndAlg))
  1187  			}
  1188  		case extensionSignatureAlgorithmsCert:
  1189  			var sigAndAlgs cryptobyte.String
  1190  			if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() {
  1191  				return false
  1192  			}
  1193  			for !sigAndAlgs.Empty() {
  1194  				var sigAndAlg uint16
  1195  				if !sigAndAlgs.ReadUint16(&sigAndAlg) {
  1196  					return false
  1197  				}
  1198  				m.supportedSignatureAlgorithmsCert = append(
  1199  					m.supportedSignatureAlgorithmsCert, SignatureScheme(sigAndAlg))
  1200  			}
  1201  		case extensionCertificateAuthorities:
  1202  			var auths cryptobyte.String
  1203  			if !extData.ReadUint16LengthPrefixed(&auths) || auths.Empty() {
  1204  				return false
  1205  			}
  1206  			for !auths.Empty() {
  1207  				var ca []byte
  1208  				if !readUint16LengthPrefixed(&auths, &ca) || len(ca) == 0 {
  1209  					return false
  1210  				}
  1211  				m.certificateAuthorities = append(m.certificateAuthorities, ca)
  1212  			}
  1213  		default:
  1214  			// Ignore unknown extensions.
  1215  			continue
  1216  		}
  1217  
  1218  		if !extData.Empty() {
  1219  			return false
  1220  		}
  1221  	}
  1222  
  1223  	return true
  1224  }
  1225  
  1226  type certificateMsg struct {
  1227  	raw          []byte
  1228  	certificates [][]byte
  1229  }
  1230  
  1231  func (m *certificateMsg) marshal() (x []byte) {
  1232  	if m.raw != nil {
  1233  		return m.raw
  1234  	}
  1235  
  1236  	var i int
  1237  	for _, slice := range m.certificates {
  1238  		i += len(slice)
  1239  	}
  1240  
  1241  	length := 3 + 3*len(m.certificates) + i
  1242  	x = make([]byte, 4+length)
  1243  	x[0] = typeCertificate
  1244  	x[1] = uint8(length >> 16)
  1245  	x[2] = uint8(length >> 8)
  1246  	x[3] = uint8(length)
  1247  
  1248  	certificateOctets := length - 3
  1249  	x[4] = uint8(certificateOctets >> 16)
  1250  	x[5] = uint8(certificateOctets >> 8)
  1251  	x[6] = uint8(certificateOctets)
  1252  
  1253  	y := x[7:]
  1254  	for _, slice := range m.certificates {
  1255  		y[0] = uint8(len(slice) >> 16)
  1256  		y[1] = uint8(len(slice) >> 8)
  1257  		y[2] = uint8(len(slice))
  1258  		copy(y[3:], slice)
  1259  		y = y[3+len(slice):]
  1260  	}
  1261  
  1262  	m.raw = x
  1263  	return
  1264  }
  1265  
  1266  func (m *certificateMsg) unmarshal(data []byte) bool {
  1267  	if len(data) < 7 {
  1268  		return false
  1269  	}
  1270  
  1271  	m.raw = data
  1272  	certsLen := uint32(data[4])<<16 | uint32(data[5])<<8 | uint32(data[6])
  1273  	if uint32(len(data)) != certsLen+7 {
  1274  		return false
  1275  	}
  1276  
  1277  	numCerts := 0
  1278  	d := data[7:]
  1279  	for certsLen > 0 {
  1280  		if len(d) < 4 {
  1281  			return false
  1282  		}
  1283  		certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2])
  1284  		if uint32(len(d)) < 3+certLen {
  1285  			return false
  1286  		}
  1287  		d = d[3+certLen:]
  1288  		certsLen -= 3 + certLen
  1289  		numCerts++
  1290  	}
  1291  
  1292  	m.certificates = make([][]byte, numCerts)
  1293  	d = data[7:]
  1294  	for i := 0; i < numCerts; i++ {
  1295  		certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2])
  1296  		m.certificates[i] = d[3 : 3+certLen]
  1297  		d = d[3+certLen:]
  1298  	}
  1299  
  1300  	return true
  1301  }
  1302  
  1303  type certificateMsgTLS13 struct {
  1304  	raw          []byte
  1305  	certificate  Certificate
  1306  	ocspStapling bool
  1307  	scts         bool
  1308  }
  1309  
  1310  func (m *certificateMsgTLS13) marshal() []byte {
  1311  	if m.raw != nil {
  1312  		return m.raw
  1313  	}
  1314  
  1315  	var b cryptobyte.Builder
  1316  	b.AddUint8(typeCertificate)
  1317  	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
  1318  		b.AddUint8(0) // certificate_request_context
  1319  
  1320  		certificate := m.certificate
  1321  		if !m.ocspStapling {
  1322  			certificate.OCSPStaple = nil
  1323  		}
  1324  		if !m.scts {
  1325  			certificate.SignedCertificateTimestamps = nil
  1326  		}
  1327  		marshalCertificate(b, certificate)
  1328  	})
  1329  
  1330  	m.raw = b.BytesOrPanic()
  1331  	return m.raw
  1332  }
  1333  
  1334  func marshalCertificate(b *cryptobyte.Builder, certificate Certificate) {
  1335  	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
  1336  		for i, cert := range certificate.Certificate {
  1337  			b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
  1338  				b.AddBytes(cert)
  1339  			})
  1340  			b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1341  				if i > 0 {
  1342  					// This library only supports OCSP and SCT for leaf certificates.
  1343  					return
  1344  				}
  1345  				if certificate.OCSPStaple != nil {
  1346  					b.AddUint16(extensionStatusRequest)
  1347  					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1348  						b.AddUint8(statusTypeOCSP)
  1349  						b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
  1350  							b.AddBytes(certificate.OCSPStaple)
  1351  						})
  1352  					})
  1353  				}
  1354  				if certificate.SignedCertificateTimestamps != nil {
  1355  					b.AddUint16(extensionSCT)
  1356  					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1357  						b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1358  							for _, sct := range certificate.SignedCertificateTimestamps {
  1359  								b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1360  									b.AddBytes(sct)
  1361  								})
  1362  							}
  1363  						})
  1364  					})
  1365  				}
  1366  			})
  1367  		}
  1368  	})
  1369  }
  1370  
  1371  func (m *certificateMsgTLS13) unmarshal(data []byte) bool {
  1372  	*m = certificateMsgTLS13{raw: data}
  1373  	s := cryptobyte.String(data)
  1374  
  1375  	var context cryptobyte.String
  1376  	if !s.Skip(4) || // message type and uint24 length field
  1377  		!s.ReadUint8LengthPrefixed(&context) || !context.Empty() ||
  1378  		!unmarshalCertificate(&s, &m.certificate) ||
  1379  		!s.Empty() {
  1380  		return false
  1381  	}
  1382  
  1383  	m.scts = m.certificate.SignedCertificateTimestamps != nil
  1384  	m.ocspStapling = m.certificate.OCSPStaple != nil
  1385  
  1386  	return true
  1387  }
  1388  
  1389  func unmarshalCertificate(s *cryptobyte.String, certificate *Certificate) bool {
  1390  	var certList cryptobyte.String
  1391  	if !s.ReadUint24LengthPrefixed(&certList) {
  1392  		return false
  1393  	}
  1394  	for !certList.Empty() {
  1395  		var cert []byte
  1396  		var extensions cryptobyte.String
  1397  		if !readUint24LengthPrefixed(&certList, &cert) ||
  1398  			!certList.ReadUint16LengthPrefixed(&extensions) {
  1399  			return false
  1400  		}
  1401  		certificate.Certificate = append(certificate.Certificate, cert)
  1402  		for !extensions.Empty() {
  1403  			var extension uint16
  1404  			var extData cryptobyte.String
  1405  			if !extensions.ReadUint16(&extension) ||
  1406  				!extensions.ReadUint16LengthPrefixed(&extData) {
  1407  				return false
  1408  			}
  1409  			if len(certificate.Certificate) > 1 {
  1410  				// This library only supports OCSP and SCT for leaf certificates.
  1411  				continue
  1412  			}
  1413  
  1414  			switch extension {
  1415  			case extensionStatusRequest:
  1416  				var statusType uint8
  1417  				if !extData.ReadUint8(&statusType) || statusType != statusTypeOCSP ||
  1418  					!readUint24LengthPrefixed(&extData, &certificate.OCSPStaple) ||
  1419  					len(certificate.OCSPStaple) == 0 {
  1420  					return false
  1421  				}
  1422  			case extensionSCT:
  1423  				var sctList cryptobyte.String
  1424  				if !extData.ReadUint16LengthPrefixed(&sctList) || sctList.Empty() {
  1425  					return false
  1426  				}
  1427  				for !sctList.Empty() {
  1428  					var sct []byte
  1429  					if !readUint16LengthPrefixed(&sctList, &sct) ||
  1430  						len(sct) == 0 {
  1431  						return false
  1432  					}
  1433  					certificate.SignedCertificateTimestamps = append(
  1434  						certificate.SignedCertificateTimestamps, sct)
  1435  				}
  1436  			default:
  1437  				// Ignore unknown extensions.
  1438  				continue
  1439  			}
  1440  
  1441  			if !extData.Empty() {
  1442  				return false
  1443  			}
  1444  		}
  1445  	}
  1446  	return true
  1447  }
  1448  
  1449  type serverKeyExchangeMsg struct {
  1450  	raw []byte
  1451  	key []byte
  1452  }
  1453  
  1454  func (m *serverKeyExchangeMsg) marshal() []byte {
  1455  	if m.raw != nil {
  1456  		return m.raw
  1457  	}
  1458  	length := len(m.key)
  1459  	x := make([]byte, length+4)
  1460  	x[0] = typeServerKeyExchange
  1461  	x[1] = uint8(length >> 16)
  1462  	x[2] = uint8(length >> 8)
  1463  	x[3] = uint8(length)
  1464  	copy(x[4:], m.key)
  1465  
  1466  	m.raw = x
  1467  	return x
  1468  }
  1469  
  1470  func (m *serverKeyExchangeMsg) unmarshal(data []byte) bool {
  1471  	m.raw = data
  1472  	if len(data) < 4 {
  1473  		return false
  1474  	}
  1475  	m.key = data[4:]
  1476  	return true
  1477  }
  1478  
  1479  type certificateStatusMsg struct {
  1480  	raw      []byte
  1481  	response []byte
  1482  }
  1483  
  1484  func (m *certificateStatusMsg) marshal() []byte {
  1485  	if m.raw != nil {
  1486  		return m.raw
  1487  	}
  1488  
  1489  	var b cryptobyte.Builder
  1490  	b.AddUint8(typeCertificateStatus)
  1491  	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
  1492  		b.AddUint8(statusTypeOCSP)
  1493  		b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
  1494  			b.AddBytes(m.response)
  1495  		})
  1496  	})
  1497  
  1498  	m.raw = b.BytesOrPanic()
  1499  	return m.raw
  1500  }
  1501  
  1502  func (m *certificateStatusMsg) unmarshal(data []byte) bool {
  1503  	m.raw = data
  1504  	s := cryptobyte.String(data)
  1505  
  1506  	var statusType uint8
  1507  	if !s.Skip(4) || // message type and uint24 length field
  1508  		!s.ReadUint8(&statusType) || statusType != statusTypeOCSP ||
  1509  		!readUint24LengthPrefixed(&s, &m.response) ||
  1510  		len(m.response) == 0 || !s.Empty() {
  1511  		return false
  1512  	}
  1513  	return true
  1514  }
  1515  
  1516  type serverHelloDoneMsg struct{}
  1517  
  1518  func (m *serverHelloDoneMsg) marshal() []byte {
  1519  	x := make([]byte, 4)
  1520  	x[0] = typeServerHelloDone
  1521  	return x
  1522  }
  1523  
  1524  func (m *serverHelloDoneMsg) unmarshal(data []byte) bool {
  1525  	return len(data) == 4
  1526  }
  1527  
  1528  type clientKeyExchangeMsg struct {
  1529  	raw        []byte
  1530  	ciphertext []byte
  1531  }
  1532  
  1533  func (m *clientKeyExchangeMsg) marshal() []byte {
  1534  	if m.raw != nil {
  1535  		return m.raw
  1536  	}
  1537  	length := len(m.ciphertext)
  1538  	x := make([]byte, length+4)
  1539  	x[0] = typeClientKeyExchange
  1540  	x[1] = uint8(length >> 16)
  1541  	x[2] = uint8(length >> 8)
  1542  	x[3] = uint8(length)
  1543  	copy(x[4:], m.ciphertext)
  1544  
  1545  	m.raw = x
  1546  	return x
  1547  }
  1548  
  1549  func (m *clientKeyExchangeMsg) unmarshal(data []byte) bool {
  1550  	m.raw = data
  1551  	if len(data) < 4 {
  1552  		return false
  1553  	}
  1554  	l := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
  1555  	if l != len(data)-4 {
  1556  		return false
  1557  	}
  1558  	m.ciphertext = data[4:]
  1559  	return true
  1560  }
  1561  
  1562  type finishedMsg struct {
  1563  	raw        []byte
  1564  	verifyData []byte
  1565  }
  1566  
  1567  func (m *finishedMsg) marshal() []byte {
  1568  	if m.raw != nil {
  1569  		return m.raw
  1570  	}
  1571  
  1572  	var b cryptobyte.Builder
  1573  	b.AddUint8(typeFinished)
  1574  	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
  1575  		b.AddBytes(m.verifyData)
  1576  	})
  1577  
  1578  	m.raw = b.BytesOrPanic()
  1579  	return m.raw
  1580  }
  1581  
  1582  func (m *finishedMsg) unmarshal(data []byte) bool {
  1583  	m.raw = data
  1584  	s := cryptobyte.String(data)
  1585  	return s.Skip(1) &&
  1586  		readUint24LengthPrefixed(&s, &m.verifyData) &&
  1587  		s.Empty()
  1588  }
  1589  
  1590  type nextProtoMsg struct {
  1591  	raw   []byte
  1592  	proto string
  1593  }
  1594  
  1595  func (m *nextProtoMsg) marshal() []byte {
  1596  	if m.raw != nil {
  1597  		return m.raw
  1598  	}
  1599  	l := len(m.proto)
  1600  	if l > 255 {
  1601  		l = 255
  1602  	}
  1603  
  1604  	padding := 32 - (l+2)%32
  1605  	length := l + padding + 2
  1606  	x := make([]byte, length+4)
  1607  	x[0] = typeNextProtocol
  1608  	x[1] = uint8(length >> 16)
  1609  	x[2] = uint8(length >> 8)
  1610  	x[3] = uint8(length)
  1611  
  1612  	y := x[4:]
  1613  	y[0] = byte(l)
  1614  	copy(y[1:], []byte(m.proto[0:l]))
  1615  	y = y[1+l:]
  1616  	y[0] = byte(padding)
  1617  
  1618  	m.raw = x
  1619  
  1620  	return x
  1621  }
  1622  
  1623  func (m *nextProtoMsg) unmarshal(data []byte) bool {
  1624  	m.raw = data
  1625  
  1626  	if len(data) < 5 {
  1627  		return false
  1628  	}
  1629  	data = data[4:]
  1630  	protoLen := int(data[0])
  1631  	data = data[1:]
  1632  	if len(data) < protoLen {
  1633  		return false
  1634  	}
  1635  	m.proto = string(data[0:protoLen])
  1636  	data = data[protoLen:]
  1637  
  1638  	if len(data) < 1 {
  1639  		return false
  1640  	}
  1641  	paddingLen := int(data[0])
  1642  	data = data[1:]
  1643  	if len(data) != paddingLen {
  1644  		return false
  1645  	}
  1646  
  1647  	return true
  1648  }
  1649  
  1650  type certificateRequestMsg struct {
  1651  	raw []byte
  1652  	// hasSignatureAlgorithm indicates whether this message includes a list of
  1653  	// supported signature algorithms. This change was introduced with TLS 1.2.
  1654  	hasSignatureAlgorithm bool
  1655  
  1656  	certificateTypes             []byte
  1657  	supportedSignatureAlgorithms []SignatureScheme
  1658  	certificateAuthorities       [][]byte
  1659  }
  1660  
  1661  func (m *certificateRequestMsg) marshal() (x []byte) {
  1662  	if m.raw != nil {
  1663  		return m.raw
  1664  	}
  1665  
  1666  	// See RFC 4346, Section 7.4.4.
  1667  	length := 1 + len(m.certificateTypes) + 2
  1668  	casLength := 0
  1669  	for _, ca := range m.certificateAuthorities {
  1670  		casLength += 2 + len(ca)
  1671  	}
  1672  	length += casLength
  1673  
  1674  	if m.hasSignatureAlgorithm {
  1675  		length += 2 + 2*len(m.supportedSignatureAlgorithms)
  1676  	}
  1677  
  1678  	x = make([]byte, 4+length)
  1679  	x[0] = typeCertificateRequest
  1680  	x[1] = uint8(length >> 16)
  1681  	x[2] = uint8(length >> 8)
  1682  	x[3] = uint8(length)
  1683  
  1684  	x[4] = uint8(len(m.certificateTypes))
  1685  
  1686  	copy(x[5:], m.certificateTypes)
  1687  	y := x[5+len(m.certificateTypes):]
  1688  
  1689  	if m.hasSignatureAlgorithm {
  1690  		n := len(m.supportedSignatureAlgorithms) * 2
  1691  		y[0] = uint8(n >> 8)
  1692  		y[1] = uint8(n)
  1693  		y = y[2:]
  1694  		for _, sigAlgo := range m.supportedSignatureAlgorithms {
  1695  			y[0] = uint8(sigAlgo >> 8)
  1696  			y[1] = uint8(sigAlgo)
  1697  			y = y[2:]
  1698  		}
  1699  	}
  1700  
  1701  	y[0] = uint8(casLength >> 8)
  1702  	y[1] = uint8(casLength)
  1703  	y = y[2:]
  1704  	for _, ca := range m.certificateAuthorities {
  1705  		y[0] = uint8(len(ca) >> 8)
  1706  		y[1] = uint8(len(ca))
  1707  		y = y[2:]
  1708  		copy(y, ca)
  1709  		y = y[len(ca):]
  1710  	}
  1711  
  1712  	m.raw = x
  1713  	return
  1714  }
  1715  
  1716  func (m *certificateRequestMsg) unmarshal(data []byte) bool {
  1717  	m.raw = data
  1718  
  1719  	if len(data) < 5 {
  1720  		return false
  1721  	}
  1722  
  1723  	length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3])
  1724  	if uint32(len(data))-4 != length {
  1725  		return false
  1726  	}
  1727  
  1728  	numCertTypes := int(data[4])
  1729  	data = data[5:]
  1730  	if numCertTypes == 0 || len(data) <= numCertTypes {
  1731  		return false
  1732  	}
  1733  
  1734  	m.certificateTypes = make([]byte, numCertTypes)
  1735  	if copy(m.certificateTypes, data) != numCertTypes {
  1736  		return false
  1737  	}
  1738  
  1739  	data = data[numCertTypes:]
  1740  
  1741  	if m.hasSignatureAlgorithm {
  1742  		if len(data) < 2 {
  1743  			return false
  1744  		}
  1745  		sigAndHashLen := uint16(data[0])<<8 | uint16(data[1])
  1746  		data = data[2:]
  1747  		if sigAndHashLen&1 != 0 {
  1748  			return false
  1749  		}
  1750  		if len(data) < int(sigAndHashLen) {
  1751  			return false
  1752  		}
  1753  		numSigAlgos := sigAndHashLen / 2
  1754  		m.supportedSignatureAlgorithms = make([]SignatureScheme, numSigAlgos)
  1755  		for i := range m.supportedSignatureAlgorithms {
  1756  			m.supportedSignatureAlgorithms[i] = SignatureScheme(data[0])<<8 | SignatureScheme(data[1])
  1757  			data = data[2:]
  1758  		}
  1759  	}
  1760  
  1761  	if len(data) < 2 {
  1762  		return false
  1763  	}
  1764  	casLength := uint16(data[0])<<8 | uint16(data[1])
  1765  	data = data[2:]
  1766  	if len(data) < int(casLength) {
  1767  		return false
  1768  	}
  1769  	cas := make([]byte, casLength)
  1770  	copy(cas, data)
  1771  	data = data[casLength:]
  1772  
  1773  	m.certificateAuthorities = nil
  1774  	for len(cas) > 0 {
  1775  		if len(cas) < 2 {
  1776  			return false
  1777  		}
  1778  		caLen := uint16(cas[0])<<8 | uint16(cas[1])
  1779  		cas = cas[2:]
  1780  
  1781  		if len(cas) < int(caLen) {
  1782  			return false
  1783  		}
  1784  
  1785  		m.certificateAuthorities = append(m.certificateAuthorities, cas[:caLen])
  1786  		cas = cas[caLen:]
  1787  	}
  1788  
  1789  	return len(data) == 0
  1790  }
  1791  
  1792  type certificateVerifyMsg struct {
  1793  	raw                   []byte
  1794  	hasSignatureAlgorithm bool // format change introduced in TLS 1.2
  1795  	signatureAlgorithm    SignatureScheme
  1796  	signature             []byte
  1797  }
  1798  
  1799  func (m *certificateVerifyMsg) marshal() (x []byte) {
  1800  	if m.raw != nil {
  1801  		return m.raw
  1802  	}
  1803  
  1804  	var b cryptobyte.Builder
  1805  	b.AddUint8(typeCertificateVerify)
  1806  	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
  1807  		if m.hasSignatureAlgorithm {
  1808  			b.AddUint16(uint16(m.signatureAlgorithm))
  1809  		}
  1810  		b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1811  			b.AddBytes(m.signature)
  1812  		})
  1813  	})
  1814  
  1815  	m.raw = b.BytesOrPanic()
  1816  	return m.raw
  1817  }
  1818  
  1819  func (m *certificateVerifyMsg) unmarshal(data []byte) bool {
  1820  	m.raw = data
  1821  	s := cryptobyte.String(data)
  1822  
  1823  	if !s.Skip(4) { // message type and uint24 length field
  1824  		return false
  1825  	}
  1826  	if m.hasSignatureAlgorithm {
  1827  		if !s.ReadUint16((*uint16)(&m.signatureAlgorithm)) {
  1828  			return false
  1829  		}
  1830  	}
  1831  	return readUint16LengthPrefixed(&s, &m.signature) && s.Empty()
  1832  }
  1833  
  1834  type newSessionTicketMsg struct {
  1835  	raw    []byte
  1836  	ticket []byte
  1837  }
  1838  
  1839  func (m *newSessionTicketMsg) marshal() (x []byte) {
  1840  	if m.raw != nil {
  1841  		return m.raw
  1842  	}
  1843  
  1844  	// See RFC 5077, Section 3.3.
  1845  	ticketLen := len(m.ticket)
  1846  	length := 2 + 4 + ticketLen
  1847  	x = make([]byte, 4+length)
  1848  	x[0] = typeNewSessionTicket
  1849  	x[1] = uint8(length >> 16)
  1850  	x[2] = uint8(length >> 8)
  1851  	x[3] = uint8(length)
  1852  	x[8] = uint8(ticketLen >> 8)
  1853  	x[9] = uint8(ticketLen)
  1854  	copy(x[10:], m.ticket)
  1855  
  1856  	m.raw = x
  1857  
  1858  	return
  1859  }
  1860  
  1861  func (m *newSessionTicketMsg) unmarshal(data []byte) bool {
  1862  	m.raw = data
  1863  
  1864  	if len(data) < 10 {
  1865  		return false
  1866  	}
  1867  
  1868  	length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3])
  1869  	if uint32(len(data))-4 != length {
  1870  		return false
  1871  	}
  1872  
  1873  	ticketLen := int(data[8])<<8 + int(data[9])
  1874  	if len(data)-10 != ticketLen {
  1875  		return false
  1876  	}
  1877  
  1878  	m.ticket = data[10:]
  1879  
  1880  	return true
  1881  }
  1882  
  1883  type helloRequestMsg struct {
  1884  }
  1885  
  1886  func (*helloRequestMsg) marshal() []byte {
  1887  	return []byte{typeHelloRequest, 0, 0, 0}
  1888  }
  1889  
  1890  func (*helloRequestMsg) unmarshal(data []byte) bool {
  1891  	return len(data) == 4
  1892  }