github.com/FISCO-BCOS/crypto@v0.0.0-20200202032121-bd8ab0b5d4f1/tls/handshake_messages.go (about)

     1  // Copyright 2009 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package tls
     6  
     7  import (
     8  	"fmt"
     9  	"golang.org/x/crypto/cryptobyte"
    10  	"strings"
    11  )
    12  
    13  // The marshalingFunction type is an adapter to allow the use of ordinary
    14  // functions as cryptobyte.MarshalingValue.
    15  type marshalingFunction func(b *cryptobyte.Builder) error
    16  
    17  func (f marshalingFunction) Marshal(b *cryptobyte.Builder) error {
    18  	return f(b)
    19  }
    20  
    21  // addBytesWithLength appends a sequence of bytes to the cryptobyte.Builder. If
    22  // the length of the sequence is not the value specified, it produces an error.
    23  func addBytesWithLength(b *cryptobyte.Builder, v []byte, n int) {
    24  	b.AddValue(marshalingFunction(func(b *cryptobyte.Builder) error {
    25  		if len(v) != n {
    26  			return fmt.Errorf("invalid value length: expected %d, got %d", n, len(v))
    27  		}
    28  		b.AddBytes(v)
    29  		return nil
    30  	}))
    31  }
    32  
    33  // addUint64 appends a big-endian, 64-bit value to the cryptobyte.Builder.
    34  func addUint64(b *cryptobyte.Builder, v uint64) {
    35  	b.AddUint32(uint32(v >> 32))
    36  	b.AddUint32(uint32(v))
    37  }
    38  
    39  // readUint64 decodes a big-endian, 64-bit value into out and advances over it.
    40  // It reports whether the read was successful.
    41  func readUint64(s *cryptobyte.String, out *uint64) bool {
    42  	var hi, lo uint32
    43  	if !s.ReadUint32(&hi) || !s.ReadUint32(&lo) {
    44  		return false
    45  	}
    46  	*out = uint64(hi)<<32 | uint64(lo)
    47  	return true
    48  }
    49  
    50  // readUint8LengthPrefixed acts like s.ReadUint8LengthPrefixed, but targets a
    51  // []byte instead of a cryptobyte.String.
    52  func readUint8LengthPrefixed(s *cryptobyte.String, out *[]byte) bool {
    53  	return s.ReadUint8LengthPrefixed((*cryptobyte.String)(out))
    54  }
    55  
    56  // readUint16LengthPrefixed acts like s.ReadUint16LengthPrefixed, but targets a
    57  // []byte instead of a cryptobyte.String.
    58  func readUint16LengthPrefixed(s *cryptobyte.String, out *[]byte) bool {
    59  	return s.ReadUint16LengthPrefixed((*cryptobyte.String)(out))
    60  }
    61  
    62  // readUint24LengthPrefixed acts like s.ReadUint24LengthPrefixed, but targets a
    63  // []byte instead of a cryptobyte.String.
    64  func readUint24LengthPrefixed(s *cryptobyte.String, out *[]byte) bool {
    65  	return s.ReadUint24LengthPrefixed((*cryptobyte.String)(out))
    66  }
    67  
    68  type clientHelloMsg struct {
    69  	raw                              []byte
    70  	vers                             uint16
    71  	random                           []byte
    72  	sessionId                        []byte
    73  	cipherSuites                     []uint16
    74  	compressionMethods               []uint8
    75  	nextProtoNeg                     bool
    76  	serverName                       string
    77  	ocspStapling                     bool
    78  	supportedCurves                  []CurveID
    79  	supportedPoints                  []uint8
    80  	ticketSupported                  bool
    81  	sessionTicket                    []uint8
    82  	supportedSignatureAlgorithms     []SignatureScheme
    83  	supportedSignatureAlgorithmsCert []SignatureScheme
    84  	secureRenegotiationSupported     bool
    85  	secureRenegotiation              []byte
    86  	alpnProtocols                    []string
    87  	scts                             bool
    88  	supportedVersions                []uint16
    89  	cookie                           []byte
    90  	keyShares                        []keyShare
    91  	earlyData                        bool
    92  	pskModes                         []uint8
    93  	pskIdentities                    []pskIdentity
    94  	pskBinders                       [][]byte
    95  }
    96  
    97  func (m *clientHelloMsg) marshal() []byte {
    98  	if m.raw != nil {
    99  		return m.raw
   100  	}
   101  
   102  	var b cryptobyte.Builder
   103  	b.AddUint8(typeClientHello)
   104  	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
   105  		b.AddUint16(m.vers)
   106  		addBytesWithLength(b, m.random, 32)
   107  		b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
   108  			b.AddBytes(m.sessionId)
   109  		})
   110  		b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   111  			for _, suite := range m.cipherSuites {
   112  				b.AddUint16(suite)
   113  			}
   114  		})
   115  		b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
   116  			b.AddBytes(m.compressionMethods)
   117  		})
   118  
   119  		// If extensions aren't present, omit them.
   120  		var extensionsPresent bool
   121  		bWithoutExtensions := *b
   122  
   123  		b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   124  			if m.nextProtoNeg {
   125  				// draft-agl-tls-nextprotoneg-04
   126  				b.AddUint16(extensionNextProtoNeg)
   127  				b.AddUint16(0) // empty extension_data
   128  			}
   129  			if len(m.serverName) > 0 {
   130  				// RFC 6066, Section 3
   131  				b.AddUint16(extensionServerName)
   132  				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   133  					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   134  						b.AddUint8(0) // name_type = host_name
   135  						b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   136  							b.AddBytes([]byte(m.serverName))
   137  						})
   138  					})
   139  				})
   140  			}
   141  			if m.ocspStapling {
   142  				// RFC 4366, Section 3.6
   143  				b.AddUint16(extensionStatusRequest)
   144  				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   145  					b.AddUint8(1)  // status_type = ocsp
   146  					b.AddUint16(0) // empty responder_id_list
   147  					b.AddUint16(0) // empty request_extensions
   148  				})
   149  			}
   150  			if len(m.supportedCurves) > 0 {
   151  				// RFC 4492, sections 5.1.1 and RFC 8446, Section 4.2.7
   152  				b.AddUint16(extensionSupportedCurves)
   153  				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   154  					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   155  						for _, curve := range m.supportedCurves {
   156  							b.AddUint16(uint16(curve))
   157  						}
   158  					})
   159  				})
   160  			}
   161  			if len(m.supportedPoints) > 0 {
   162  				// RFC 4492, Section 5.1.2
   163  				b.AddUint16(extensionSupportedPoints)
   164  				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   165  					b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
   166  						b.AddBytes(m.supportedPoints)
   167  					})
   168  				})
   169  			}
   170  			if m.ticketSupported {
   171  				// RFC 5077, Section 3.2
   172  				b.AddUint16(extensionSessionTicket)
   173  				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   174  					b.AddBytes(m.sessionTicket)
   175  				})
   176  			}
   177  			if len(m.supportedSignatureAlgorithms) > 0 {
   178  				// RFC 5246, Section 7.4.1.4.1
   179  				b.AddUint16(extensionSignatureAlgorithms)
   180  				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   181  					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   182  						for _, sigAlgo := range m.supportedSignatureAlgorithms {
   183  							b.AddUint16(uint16(sigAlgo))
   184  						}
   185  					})
   186  				})
   187  			}
   188  			if len(m.supportedSignatureAlgorithmsCert) > 0 {
   189  				// RFC 8446, Section 4.2.3
   190  				b.AddUint16(extensionSignatureAlgorithmsCert)
   191  				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   192  					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   193  						for _, sigAlgo := range m.supportedSignatureAlgorithmsCert {
   194  							b.AddUint16(uint16(sigAlgo))
   195  						}
   196  					})
   197  				})
   198  			}
   199  			if m.secureRenegotiationSupported {
   200  				// RFC 5746, Section 3.2
   201  				b.AddUint16(extensionRenegotiationInfo)
   202  				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   203  					b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
   204  						b.AddBytes(m.secureRenegotiation)
   205  					})
   206  				})
   207  			}
   208  			if len(m.alpnProtocols) > 0 {
   209  				// RFC 7301, Section 3.1
   210  				b.AddUint16(extensionALPN)
   211  				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   212  					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   213  						for _, proto := range m.alpnProtocols {
   214  							b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
   215  								b.AddBytes([]byte(proto))
   216  							})
   217  						}
   218  					})
   219  				})
   220  			}
   221  			if m.scts {
   222  				// RFC 6962, Section 3.3.1
   223  				b.AddUint16(extensionSCT)
   224  				b.AddUint16(0) // empty extension_data
   225  			}
   226  			if len(m.supportedVersions) > 0 {
   227  				// RFC 8446, Section 4.2.1
   228  				b.AddUint16(extensionSupportedVersions)
   229  				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   230  					b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
   231  						for _, vers := range m.supportedVersions {
   232  							b.AddUint16(vers)
   233  						}
   234  					})
   235  				})
   236  			}
   237  			if len(m.cookie) > 0 {
   238  				// RFC 8446, Section 4.2.2
   239  				b.AddUint16(extensionCookie)
   240  				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   241  					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   242  						b.AddBytes(m.cookie)
   243  					})
   244  				})
   245  			}
   246  			if len(m.keyShares) > 0 {
   247  				// RFC 8446, Section 4.2.8
   248  				b.AddUint16(extensionKeyShare)
   249  				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   250  					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   251  						for _, ks := range m.keyShares {
   252  							b.AddUint16(uint16(ks.group))
   253  							b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   254  								b.AddBytes(ks.data)
   255  							})
   256  						}
   257  					})
   258  				})
   259  			}
   260  			if m.earlyData {
   261  				// RFC 8446, Section 4.2.10
   262  				b.AddUint16(extensionEarlyData)
   263  				b.AddUint16(0) // empty extension_data
   264  			}
   265  			if len(m.pskModes) > 0 {
   266  				// RFC 8446, Section 4.2.9
   267  				b.AddUint16(extensionPSKModes)
   268  				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   269  					b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
   270  						b.AddBytes(m.pskModes)
   271  					})
   272  				})
   273  			}
   274  			if len(m.pskIdentities) > 0 { // pre_shared_key must be the last extension
   275  				// RFC 8446, Section 4.2.11
   276  				b.AddUint16(extensionPreSharedKey)
   277  				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   278  					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   279  						for _, psk := range m.pskIdentities {
   280  							b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   281  								b.AddBytes(psk.label)
   282  							})
   283  							b.AddUint32(psk.obfuscatedTicketAge)
   284  						}
   285  					})
   286  					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   287  						for _, binder := range m.pskBinders {
   288  							b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
   289  								b.AddBytes(binder)
   290  							})
   291  						}
   292  					})
   293  				})
   294  			}
   295  
   296  			extensionsPresent = len(b.BytesOrPanic()) > 2
   297  		})
   298  
   299  		if !extensionsPresent {
   300  			*b = bWithoutExtensions
   301  		}
   302  	})
   303  
   304  	m.raw = b.BytesOrPanic()
   305  	return m.raw
   306  }
   307  
   308  // marshalWithoutBinders returns the ClientHello through the
   309  // PreSharedKeyExtension.identities field, according to RFC 8446, Section
   310  // 4.2.11.2. Note that m.pskBinders must be set to slices of the correct length.
   311  func (m *clientHelloMsg) marshalWithoutBinders() []byte {
   312  	bindersLen := 2 // uint16 length prefix
   313  	for _, binder := range m.pskBinders {
   314  		bindersLen += 1 // uint8 length prefix
   315  		bindersLen += len(binder)
   316  	}
   317  
   318  	fullMessage := m.marshal()
   319  	return fullMessage[:len(fullMessage)-bindersLen]
   320  }
   321  
   322  // updateBinders updates the m.pskBinders field, if necessary updating the
   323  // cached marshaled representation. The supplied binders must have the same
   324  // length as the current m.pskBinders.
   325  func (m *clientHelloMsg) updateBinders(pskBinders [][]byte) {
   326  	if len(pskBinders) != len(m.pskBinders) {
   327  		panic("tls: internal error: pskBinders length mismatch")
   328  	}
   329  	for i := range m.pskBinders {
   330  		if len(pskBinders[i]) != len(m.pskBinders[i]) {
   331  			panic("tls: internal error: pskBinders length mismatch")
   332  		}
   333  	}
   334  	m.pskBinders = pskBinders
   335  	if m.raw != nil {
   336  		lenWithoutBinders := len(m.marshalWithoutBinders())
   337  		// TODO(filippo): replace with NewFixedBuilder once CL 148882 is imported.
   338  		b := cryptobyte.NewBuilder(m.raw[:lenWithoutBinders])
   339  		b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   340  			for _, binder := range m.pskBinders {
   341  				b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
   342  					b.AddBytes(binder)
   343  				})
   344  			}
   345  		})
   346  		if len(b.BytesOrPanic()) != len(m.raw) {
   347  			panic("tls: internal error: failed to update binders")
   348  		}
   349  	}
   350  }
   351  
   352  func (m *clientHelloMsg) unmarshal(data []byte) bool {
   353  	*m = clientHelloMsg{raw: data}
   354  	s := cryptobyte.String(data)
   355  
   356  	if !s.Skip(4) || // message type and uint24 length field
   357  		!s.ReadUint16(&m.vers) || !s.ReadBytes(&m.random, 32) ||
   358  		!readUint8LengthPrefixed(&s, &m.sessionId) {
   359  		return false
   360  	}
   361  
   362  	var cipherSuites cryptobyte.String
   363  	if !s.ReadUint16LengthPrefixed(&cipherSuites) {
   364  		return false
   365  	}
   366  	m.cipherSuites = []uint16{}
   367  	m.secureRenegotiationSupported = false
   368  	for !cipherSuites.Empty() {
   369  		var suite uint16
   370  		if !cipherSuites.ReadUint16(&suite) {
   371  			return false
   372  		}
   373  		if suite == scsvRenegotiation {
   374  			m.secureRenegotiationSupported = true
   375  		}
   376  		m.cipherSuites = append(m.cipherSuites, suite)
   377  	}
   378  
   379  	if !readUint8LengthPrefixed(&s, &m.compressionMethods) {
   380  		return false
   381  	}
   382  
   383  	if s.Empty() {
   384  		// ClientHello is optionally followed by extension data
   385  		return true
   386  	}
   387  
   388  	var extensions cryptobyte.String
   389  	if !s.ReadUint16LengthPrefixed(&extensions) || !s.Empty() {
   390  		return false
   391  	}
   392  
   393  	for !extensions.Empty() {
   394  		var extension uint16
   395  		var extData cryptobyte.String
   396  		if !extensions.ReadUint16(&extension) ||
   397  			!extensions.ReadUint16LengthPrefixed(&extData) {
   398  			return false
   399  		}
   400  
   401  		switch extension {
   402  		case extensionServerName:
   403  			// RFC 6066, Section 3
   404  			var nameList cryptobyte.String
   405  			if !extData.ReadUint16LengthPrefixed(&nameList) || nameList.Empty() {
   406  				return false
   407  			}
   408  			for !nameList.Empty() {
   409  				var nameType uint8
   410  				var serverName cryptobyte.String
   411  				if !nameList.ReadUint8(&nameType) ||
   412  					!nameList.ReadUint16LengthPrefixed(&serverName) ||
   413  					serverName.Empty() {
   414  					return false
   415  				}
   416  				if nameType != 0 {
   417  					continue
   418  				}
   419  				if len(m.serverName) != 0 {
   420  					// Multiple names of the same name_type are prohibited.
   421  					return false
   422  				}
   423  				m.serverName = string(serverName)
   424  				// An SNI value may not include a trailing dot.
   425  				if strings.HasSuffix(m.serverName, ".") {
   426  					return false
   427  				}
   428  			}
   429  		case extensionNextProtoNeg:
   430  			// draft-agl-tls-nextprotoneg-04
   431  			m.nextProtoNeg = true
   432  		case extensionStatusRequest:
   433  			// RFC 4366, Section 3.6
   434  			var statusType uint8
   435  			var ignored cryptobyte.String
   436  			if !extData.ReadUint8(&statusType) ||
   437  				!extData.ReadUint16LengthPrefixed(&ignored) ||
   438  				!extData.ReadUint16LengthPrefixed(&ignored) {
   439  				return false
   440  			}
   441  			m.ocspStapling = statusType == statusTypeOCSP
   442  		case extensionSupportedCurves:
   443  			// RFC 4492, sections 5.1.1 and RFC 8446, Section 4.2.7
   444  			var curves cryptobyte.String
   445  			if !extData.ReadUint16LengthPrefixed(&curves) || curves.Empty() {
   446  				return false
   447  			}
   448  			for !curves.Empty() {
   449  				var curve uint16
   450  				if !curves.ReadUint16(&curve) {
   451  					return false
   452  				}
   453  				m.supportedCurves = append(m.supportedCurves, CurveID(curve))
   454  			}
   455  		case extensionSupportedPoints:
   456  			// RFC 4492, Section 5.1.2
   457  			if !readUint8LengthPrefixed(&extData, &m.supportedPoints) ||
   458  				len(m.supportedPoints) == 0 {
   459  				return false
   460  			}
   461  		case extensionSessionTicket:
   462  			// RFC 5077, Section 3.2
   463  			m.ticketSupported = true
   464  			extData.ReadBytes(&m.sessionTicket, len(extData))
   465  		case extensionSignatureAlgorithms:
   466  			// RFC 5246, Section 7.4.1.4.1
   467  			var sigAndAlgs cryptobyte.String
   468  			if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() {
   469  				return false
   470  			}
   471  			for !sigAndAlgs.Empty() {
   472  				var sigAndAlg uint16
   473  				if !sigAndAlgs.ReadUint16(&sigAndAlg) {
   474  					return false
   475  				}
   476  				m.supportedSignatureAlgorithms = append(
   477  					m.supportedSignatureAlgorithms, SignatureScheme(sigAndAlg))
   478  			}
   479  		case extensionSignatureAlgorithmsCert:
   480  			// RFC 8446, Section 4.2.3
   481  			var sigAndAlgs cryptobyte.String
   482  			if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() {
   483  				return false
   484  			}
   485  			for !sigAndAlgs.Empty() {
   486  				var sigAndAlg uint16
   487  				if !sigAndAlgs.ReadUint16(&sigAndAlg) {
   488  					return false
   489  				}
   490  				m.supportedSignatureAlgorithmsCert = append(
   491  					m.supportedSignatureAlgorithmsCert, SignatureScheme(sigAndAlg))
   492  			}
   493  		case extensionRenegotiationInfo:
   494  			// RFC 5746, Section 3.2
   495  			if !readUint8LengthPrefixed(&extData, &m.secureRenegotiation) {
   496  				return false
   497  			}
   498  			m.secureRenegotiationSupported = true
   499  		case extensionALPN:
   500  			// RFC 7301, Section 3.1
   501  			var protoList cryptobyte.String
   502  			if !extData.ReadUint16LengthPrefixed(&protoList) || protoList.Empty() {
   503  				return false
   504  			}
   505  			for !protoList.Empty() {
   506  				var proto cryptobyte.String
   507  				if !protoList.ReadUint8LengthPrefixed(&proto) || proto.Empty() {
   508  					return false
   509  				}
   510  				m.alpnProtocols = append(m.alpnProtocols, string(proto))
   511  			}
   512  		case extensionSCT:
   513  			// RFC 6962, Section 3.3.1
   514  			m.scts = true
   515  		case extensionSupportedVersions:
   516  			// RFC 8446, Section 4.2.1
   517  			var versList cryptobyte.String
   518  			if !extData.ReadUint8LengthPrefixed(&versList) || versList.Empty() {
   519  				return false
   520  			}
   521  			for !versList.Empty() {
   522  				var vers uint16
   523  				if !versList.ReadUint16(&vers) {
   524  					return false
   525  				}
   526  				m.supportedVersions = append(m.supportedVersions, vers)
   527  			}
   528  		case extensionCookie:
   529  			// RFC 8446, Section 4.2.2
   530  			if !readUint16LengthPrefixed(&extData, &m.cookie) ||
   531  				len(m.cookie) == 0 {
   532  				return false
   533  			}
   534  		case extensionKeyShare:
   535  			// RFC 8446, Section 4.2.8
   536  			var clientShares cryptobyte.String
   537  			if !extData.ReadUint16LengthPrefixed(&clientShares) {
   538  				return false
   539  			}
   540  			for !clientShares.Empty() {
   541  				var ks keyShare
   542  				if !clientShares.ReadUint16((*uint16)(&ks.group)) ||
   543  					!readUint16LengthPrefixed(&clientShares, &ks.data) ||
   544  					len(ks.data) == 0 {
   545  					return false
   546  				}
   547  				m.keyShares = append(m.keyShares, ks)
   548  			}
   549  		case extensionEarlyData:
   550  			// RFC 8446, Section 4.2.10
   551  			m.earlyData = true
   552  		case extensionPSKModes:
   553  			// RFC 8446, Section 4.2.9
   554  			if !readUint8LengthPrefixed(&extData, &m.pskModes) {
   555  				return false
   556  			}
   557  		case extensionPreSharedKey:
   558  			// RFC 8446, Section 4.2.11
   559  			if !extensions.Empty() {
   560  				return false // pre_shared_key must be the last extension
   561  			}
   562  			var identities cryptobyte.String
   563  			if !extData.ReadUint16LengthPrefixed(&identities) || identities.Empty() {
   564  				return false
   565  			}
   566  			for !identities.Empty() {
   567  				var psk pskIdentity
   568  				if !readUint16LengthPrefixed(&identities, &psk.label) ||
   569  					!identities.ReadUint32(&psk.obfuscatedTicketAge) ||
   570  					len(psk.label) == 0 {
   571  					return false
   572  				}
   573  				m.pskIdentities = append(m.pskIdentities, psk)
   574  			}
   575  			var binders cryptobyte.String
   576  			if !extData.ReadUint16LengthPrefixed(&binders) || binders.Empty() {
   577  				return false
   578  			}
   579  			for !binders.Empty() {
   580  				var binder []byte
   581  				if !readUint8LengthPrefixed(&binders, &binder) ||
   582  					len(binder) == 0 {
   583  					return false
   584  				}
   585  				m.pskBinders = append(m.pskBinders, binder)
   586  			}
   587  		default:
   588  			// Ignore unknown extensions.
   589  			continue
   590  		}
   591  
   592  		if !extData.Empty() {
   593  			return false
   594  		}
   595  	}
   596  
   597  	return true
   598  }
   599  
   600  type serverHelloMsg struct {
   601  	raw                          []byte
   602  	vers                         uint16
   603  	random                       []byte
   604  	sessionId                    []byte
   605  	cipherSuite                  uint16
   606  	compressionMethod            uint8
   607  	nextProtoNeg                 bool
   608  	nextProtos                   []string
   609  	ocspStapling                 bool
   610  	ticketSupported              bool
   611  	secureRenegotiationSupported bool
   612  	secureRenegotiation          []byte
   613  	alpnProtocol                 string
   614  	scts                         [][]byte
   615  	supportedVersion             uint16
   616  	serverShare                  keyShare
   617  	selectedIdentityPresent      bool
   618  	selectedIdentity             uint16
   619  
   620  	// HelloRetryRequest extensions
   621  	cookie        []byte
   622  	selectedGroup CurveID
   623  }
   624  
   625  func (m *serverHelloMsg) marshal() []byte {
   626  	if m.raw != nil {
   627  		return m.raw
   628  	}
   629  
   630  	var b cryptobyte.Builder
   631  	b.AddUint8(typeServerHello)
   632  	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
   633  		b.AddUint16(m.vers)
   634  		addBytesWithLength(b, m.random, 32)
   635  		b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
   636  			b.AddBytes(m.sessionId)
   637  		})
   638  		b.AddUint16(m.cipherSuite)
   639  		b.AddUint8(m.compressionMethod)
   640  
   641  		// If extensions aren't present, omit them.
   642  		var extensionsPresent bool
   643  		bWithoutExtensions := *b
   644  
   645  		b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   646  			if m.nextProtoNeg {
   647  				b.AddUint16(extensionNextProtoNeg)
   648  				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   649  					for _, proto := range m.nextProtos {
   650  						b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
   651  							b.AddBytes([]byte(proto))
   652  						})
   653  					}
   654  				})
   655  			}
   656  			if m.ocspStapling {
   657  				b.AddUint16(extensionStatusRequest)
   658  				b.AddUint16(0) // empty extension_data
   659  			}
   660  			if m.ticketSupported {
   661  				b.AddUint16(extensionSessionTicket)
   662  				b.AddUint16(0) // empty extension_data
   663  			}
   664  			if m.secureRenegotiationSupported {
   665  				b.AddUint16(extensionRenegotiationInfo)
   666  				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   667  					b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
   668  						b.AddBytes(m.secureRenegotiation)
   669  					})
   670  				})
   671  			}
   672  			if len(m.alpnProtocol) > 0 {
   673  				b.AddUint16(extensionALPN)
   674  				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   675  					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   676  						b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
   677  							b.AddBytes([]byte(m.alpnProtocol))
   678  						})
   679  					})
   680  				})
   681  			}
   682  			if len(m.scts) > 0 {
   683  				b.AddUint16(extensionSCT)
   684  				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   685  					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   686  						for _, sct := range m.scts {
   687  							b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   688  								b.AddBytes(sct)
   689  							})
   690  						}
   691  					})
   692  				})
   693  			}
   694  			if m.supportedVersion != 0 {
   695  				b.AddUint16(extensionSupportedVersions)
   696  				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   697  					b.AddUint16(m.supportedVersion)
   698  				})
   699  			}
   700  			if m.serverShare.group != 0 {
   701  				b.AddUint16(extensionKeyShare)
   702  				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   703  					b.AddUint16(uint16(m.serverShare.group))
   704  					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   705  						b.AddBytes(m.serverShare.data)
   706  					})
   707  				})
   708  			}
   709  			if m.selectedIdentityPresent {
   710  				b.AddUint16(extensionPreSharedKey)
   711  				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   712  					b.AddUint16(m.selectedIdentity)
   713  				})
   714  			}
   715  
   716  			if len(m.cookie) > 0 {
   717  				b.AddUint16(extensionCookie)
   718  				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   719  					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   720  						b.AddBytes(m.cookie)
   721  					})
   722  				})
   723  			}
   724  			if m.selectedGroup != 0 {
   725  				b.AddUint16(extensionKeyShare)
   726  				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   727  					b.AddUint16(uint16(m.selectedGroup))
   728  				})
   729  			}
   730  
   731  			extensionsPresent = len(b.BytesOrPanic()) > 2
   732  		})
   733  
   734  		if !extensionsPresent {
   735  			*b = bWithoutExtensions
   736  		}
   737  	})
   738  
   739  	m.raw = b.BytesOrPanic()
   740  	return m.raw
   741  }
   742  
   743  func (m *serverHelloMsg) unmarshal(data []byte) bool {
   744  	*m = serverHelloMsg{raw: data}
   745  	s := cryptobyte.String(data)
   746  
   747  	if !s.Skip(4) || // message type and uint24 length field
   748  		!s.ReadUint16(&m.vers) || !s.ReadBytes(&m.random, 32) ||
   749  		!readUint8LengthPrefixed(&s, &m.sessionId) ||
   750  		!s.ReadUint16(&m.cipherSuite) ||
   751  		!s.ReadUint8(&m.compressionMethod) {
   752  		return false
   753  	}
   754  
   755  	if s.Empty() {
   756  		// ServerHello is optionally followed by extension data
   757  		return true
   758  	}
   759  
   760  	var extensions cryptobyte.String
   761  	if !s.ReadUint16LengthPrefixed(&extensions) || !s.Empty() {
   762  		return false
   763  	}
   764  
   765  	for !extensions.Empty() {
   766  		var extension uint16
   767  		var extData cryptobyte.String
   768  		if !extensions.ReadUint16(&extension) ||
   769  			!extensions.ReadUint16LengthPrefixed(&extData) {
   770  			return false
   771  		}
   772  
   773  		switch extension {
   774  		case extensionNextProtoNeg:
   775  			m.nextProtoNeg = true
   776  			for !extData.Empty() {
   777  				var proto cryptobyte.String
   778  				if !extData.ReadUint8LengthPrefixed(&proto) ||
   779  					proto.Empty() {
   780  					return false
   781  				}
   782  				m.nextProtos = append(m.nextProtos, string(proto))
   783  			}
   784  		case extensionStatusRequest:
   785  			m.ocspStapling = true
   786  		case extensionSessionTicket:
   787  			m.ticketSupported = true
   788  		case extensionRenegotiationInfo:
   789  			if !readUint8LengthPrefixed(&extData, &m.secureRenegotiation) {
   790  				return false
   791  			}
   792  			m.secureRenegotiationSupported = true
   793  		case extensionALPN:
   794  			var protoList cryptobyte.String
   795  			if !extData.ReadUint16LengthPrefixed(&protoList) || protoList.Empty() {
   796  				return false
   797  			}
   798  			var proto cryptobyte.String
   799  			if !protoList.ReadUint8LengthPrefixed(&proto) ||
   800  				proto.Empty() || !protoList.Empty() {
   801  				return false
   802  			}
   803  			m.alpnProtocol = string(proto)
   804  		case extensionSCT:
   805  			var sctList cryptobyte.String
   806  			if !extData.ReadUint16LengthPrefixed(&sctList) || sctList.Empty() {
   807  				return false
   808  			}
   809  			for !sctList.Empty() {
   810  				var sct []byte
   811  				if !readUint16LengthPrefixed(&sctList, &sct) ||
   812  					len(sct) == 0 {
   813  					return false
   814  				}
   815  				m.scts = append(m.scts, sct)
   816  			}
   817  		case extensionSupportedVersions:
   818  			if !extData.ReadUint16(&m.supportedVersion) {
   819  				return false
   820  			}
   821  		case extensionCookie:
   822  			if !readUint16LengthPrefixed(&extData, &m.cookie) ||
   823  				len(m.cookie) == 0 {
   824  				return false
   825  			}
   826  		case extensionKeyShare:
   827  			// This extension has different formats in SH and HRR, accept either
   828  			// and let the handshake logic decide. See RFC 8446, Section 4.2.8.
   829  			if len(extData) == 2 {
   830  				if !extData.ReadUint16((*uint16)(&m.selectedGroup)) {
   831  					return false
   832  				}
   833  			} else {
   834  				if !extData.ReadUint16((*uint16)(&m.serverShare.group)) ||
   835  					!readUint16LengthPrefixed(&extData, &m.serverShare.data) {
   836  					return false
   837  				}
   838  			}
   839  		case extensionPreSharedKey:
   840  			m.selectedIdentityPresent = true
   841  			if !extData.ReadUint16(&m.selectedIdentity) {
   842  				return false
   843  			}
   844  		default:
   845  			// Ignore unknown extensions.
   846  			continue
   847  		}
   848  
   849  		if !extData.Empty() {
   850  			return false
   851  		}
   852  	}
   853  
   854  	return true
   855  }
   856  
   857  type encryptedExtensionsMsg struct {
   858  	raw          []byte
   859  	alpnProtocol string
   860  }
   861  
   862  func (m *encryptedExtensionsMsg) marshal() []byte {
   863  	if m.raw != nil {
   864  		return m.raw
   865  	}
   866  
   867  	var b cryptobyte.Builder
   868  	b.AddUint8(typeEncryptedExtensions)
   869  	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
   870  		b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   871  			if len(m.alpnProtocol) > 0 {
   872  				b.AddUint16(extensionALPN)
   873  				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   874  					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   875  						b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
   876  							b.AddBytes([]byte(m.alpnProtocol))
   877  						})
   878  					})
   879  				})
   880  			}
   881  		})
   882  	})
   883  
   884  	m.raw = b.BytesOrPanic()
   885  	return m.raw
   886  }
   887  
   888  func (m *encryptedExtensionsMsg) unmarshal(data []byte) bool {
   889  	*m = encryptedExtensionsMsg{raw: data}
   890  	s := cryptobyte.String(data)
   891  
   892  	var extensions cryptobyte.String
   893  	if !s.Skip(4) || // message type and uint24 length field
   894  		!s.ReadUint16LengthPrefixed(&extensions) || !s.Empty() {
   895  		return false
   896  	}
   897  
   898  	for !extensions.Empty() {
   899  		var extension uint16
   900  		var extData cryptobyte.String
   901  		if !extensions.ReadUint16(&extension) ||
   902  			!extensions.ReadUint16LengthPrefixed(&extData) {
   903  			return false
   904  		}
   905  
   906  		switch extension {
   907  		case extensionALPN:
   908  			var protoList cryptobyte.String
   909  			if !extData.ReadUint16LengthPrefixed(&protoList) || protoList.Empty() {
   910  				return false
   911  			}
   912  			var proto cryptobyte.String
   913  			if !protoList.ReadUint8LengthPrefixed(&proto) ||
   914  				proto.Empty() || !protoList.Empty() {
   915  				return false
   916  			}
   917  			m.alpnProtocol = string(proto)
   918  		default:
   919  			// Ignore unknown extensions.
   920  			continue
   921  		}
   922  
   923  		if !extData.Empty() {
   924  			return false
   925  		}
   926  	}
   927  
   928  	return true
   929  }
   930  
   931  type endOfEarlyDataMsg struct{}
   932  
   933  func (m *endOfEarlyDataMsg) marshal() []byte {
   934  	x := make([]byte, 4)
   935  	x[0] = typeEndOfEarlyData
   936  	return x
   937  }
   938  
   939  func (m *endOfEarlyDataMsg) unmarshal(data []byte) bool {
   940  	return len(data) == 4
   941  }
   942  
   943  type keyUpdateMsg struct {
   944  	raw             []byte
   945  	updateRequested bool
   946  }
   947  
   948  func (m *keyUpdateMsg) marshal() []byte {
   949  	if m.raw != nil {
   950  		return m.raw
   951  	}
   952  
   953  	var b cryptobyte.Builder
   954  	b.AddUint8(typeKeyUpdate)
   955  	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
   956  		if m.updateRequested {
   957  			b.AddUint8(1)
   958  		} else {
   959  			b.AddUint8(0)
   960  		}
   961  	})
   962  
   963  	m.raw = b.BytesOrPanic()
   964  	return m.raw
   965  }
   966  
   967  func (m *keyUpdateMsg) unmarshal(data []byte) bool {
   968  	m.raw = data
   969  	s := cryptobyte.String(data)
   970  
   971  	var updateRequested uint8
   972  	if !s.Skip(4) || // message type and uint24 length field
   973  		!s.ReadUint8(&updateRequested) || !s.Empty() {
   974  		return false
   975  	}
   976  	switch updateRequested {
   977  	case 0:
   978  		m.updateRequested = false
   979  	case 1:
   980  		m.updateRequested = true
   981  	default:
   982  		return false
   983  	}
   984  	return true
   985  }
   986  
   987  type newSessionTicketMsgTLS13 struct {
   988  	raw          []byte
   989  	lifetime     uint32
   990  	ageAdd       uint32
   991  	nonce        []byte
   992  	label        []byte
   993  	maxEarlyData uint32
   994  }
   995  
   996  func (m *newSessionTicketMsgTLS13) marshal() []byte {
   997  	if m.raw != nil {
   998  		return m.raw
   999  	}
  1000  
  1001  	var b cryptobyte.Builder
  1002  	b.AddUint8(typeNewSessionTicket)
  1003  	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
  1004  		b.AddUint32(m.lifetime)
  1005  		b.AddUint32(m.ageAdd)
  1006  		b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
  1007  			b.AddBytes(m.nonce)
  1008  		})
  1009  		b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1010  			b.AddBytes(m.label)
  1011  		})
  1012  
  1013  		b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1014  			if m.maxEarlyData > 0 {
  1015  				b.AddUint16(extensionEarlyData)
  1016  				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1017  					b.AddUint32(m.maxEarlyData)
  1018  				})
  1019  			}
  1020  		})
  1021  	})
  1022  
  1023  	m.raw = b.BytesOrPanic()
  1024  	return m.raw
  1025  }
  1026  
  1027  func (m *newSessionTicketMsgTLS13) unmarshal(data []byte) bool {
  1028  	*m = newSessionTicketMsgTLS13{raw: data}
  1029  	s := cryptobyte.String(data)
  1030  
  1031  	var extensions cryptobyte.String
  1032  	if !s.Skip(4) || // message type and uint24 length field
  1033  		!s.ReadUint32(&m.lifetime) ||
  1034  		!s.ReadUint32(&m.ageAdd) ||
  1035  		!readUint8LengthPrefixed(&s, &m.nonce) ||
  1036  		!readUint16LengthPrefixed(&s, &m.label) ||
  1037  		!s.ReadUint16LengthPrefixed(&extensions) ||
  1038  		!s.Empty() {
  1039  		return false
  1040  	}
  1041  
  1042  	for !extensions.Empty() {
  1043  		var extension uint16
  1044  		var extData cryptobyte.String
  1045  		if !extensions.ReadUint16(&extension) ||
  1046  			!extensions.ReadUint16LengthPrefixed(&extData) {
  1047  			return false
  1048  		}
  1049  
  1050  		switch extension {
  1051  		case extensionEarlyData:
  1052  			if !extData.ReadUint32(&m.maxEarlyData) {
  1053  				return false
  1054  			}
  1055  		default:
  1056  			// Ignore unknown extensions.
  1057  			continue
  1058  		}
  1059  
  1060  		if !extData.Empty() {
  1061  			return false
  1062  		}
  1063  	}
  1064  
  1065  	return true
  1066  }
  1067  
  1068  type certificateRequestMsgTLS13 struct {
  1069  	raw                              []byte
  1070  	ocspStapling                     bool
  1071  	scts                             bool
  1072  	supportedSignatureAlgorithms     []SignatureScheme
  1073  	supportedSignatureAlgorithmsCert []SignatureScheme
  1074  	certificateAuthorities           [][]byte
  1075  }
  1076  
  1077  func (m *certificateRequestMsgTLS13) marshal() []byte {
  1078  	if m.raw != nil {
  1079  		return m.raw
  1080  	}
  1081  
  1082  	var b cryptobyte.Builder
  1083  	b.AddUint8(typeCertificateRequest)
  1084  	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
  1085  		// certificate_request_context (SHALL be zero length unless used for
  1086  		// post-handshake authentication)
  1087  		b.AddUint8(0)
  1088  
  1089  		b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1090  			if m.ocspStapling {
  1091  				b.AddUint16(extensionStatusRequest)
  1092  				b.AddUint16(0) // empty extension_data
  1093  			}
  1094  			if m.scts {
  1095  				// RFC 8446, Section 4.4.2.1 makes no mention of
  1096  				// signed_certificate_timestamp in CertificateRequest, but
  1097  				// "Extensions in the Certificate message from the client MUST
  1098  				// correspond to extensions in the CertificateRequest message
  1099  				// from the server." and it appears in the table in Section 4.2.
  1100  				b.AddUint16(extensionSCT)
  1101  				b.AddUint16(0) // empty extension_data
  1102  			}
  1103  			if len(m.supportedSignatureAlgorithms) > 0 {
  1104  				b.AddUint16(extensionSignatureAlgorithms)
  1105  				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1106  					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1107  						for _, sigAlgo := range m.supportedSignatureAlgorithms {
  1108  							b.AddUint16(uint16(sigAlgo))
  1109  						}
  1110  					})
  1111  				})
  1112  			}
  1113  			if len(m.supportedSignatureAlgorithmsCert) > 0 {
  1114  				b.AddUint16(extensionSignatureAlgorithmsCert)
  1115  				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1116  					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1117  						for _, sigAlgo := range m.supportedSignatureAlgorithmsCert {
  1118  							b.AddUint16(uint16(sigAlgo))
  1119  						}
  1120  					})
  1121  				})
  1122  			}
  1123  			if len(m.certificateAuthorities) > 0 {
  1124  				b.AddUint16(extensionCertificateAuthorities)
  1125  				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1126  					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1127  						for _, ca := range m.certificateAuthorities {
  1128  							b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1129  								b.AddBytes(ca)
  1130  							})
  1131  						}
  1132  					})
  1133  				})
  1134  			}
  1135  		})
  1136  	})
  1137  
  1138  	m.raw = b.BytesOrPanic()
  1139  	return m.raw
  1140  }
  1141  
  1142  func (m *certificateRequestMsgTLS13) unmarshal(data []byte) bool {
  1143  	*m = certificateRequestMsgTLS13{raw: data}
  1144  	s := cryptobyte.String(data)
  1145  
  1146  	var context, extensions cryptobyte.String
  1147  	if !s.Skip(4) || // message type and uint24 length field
  1148  		!s.ReadUint8LengthPrefixed(&context) || !context.Empty() ||
  1149  		!s.ReadUint16LengthPrefixed(&extensions) ||
  1150  		!s.Empty() {
  1151  		return false
  1152  	}
  1153  
  1154  	for !extensions.Empty() {
  1155  		var extension uint16
  1156  		var extData cryptobyte.String
  1157  		if !extensions.ReadUint16(&extension) ||
  1158  			!extensions.ReadUint16LengthPrefixed(&extData) {
  1159  			return false
  1160  		}
  1161  
  1162  		switch extension {
  1163  		case extensionStatusRequest:
  1164  			m.ocspStapling = true
  1165  		case extensionSCT:
  1166  			m.scts = true
  1167  		case extensionSignatureAlgorithms:
  1168  			var sigAndAlgs cryptobyte.String
  1169  			if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() {
  1170  				return false
  1171  			}
  1172  			for !sigAndAlgs.Empty() {
  1173  				var sigAndAlg uint16
  1174  				if !sigAndAlgs.ReadUint16(&sigAndAlg) {
  1175  					return false
  1176  				}
  1177  				m.supportedSignatureAlgorithms = append(
  1178  					m.supportedSignatureAlgorithms, SignatureScheme(sigAndAlg))
  1179  			}
  1180  		case extensionSignatureAlgorithmsCert:
  1181  			var sigAndAlgs cryptobyte.String
  1182  			if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() {
  1183  				return false
  1184  			}
  1185  			for !sigAndAlgs.Empty() {
  1186  				var sigAndAlg uint16
  1187  				if !sigAndAlgs.ReadUint16(&sigAndAlg) {
  1188  					return false
  1189  				}
  1190  				m.supportedSignatureAlgorithmsCert = append(
  1191  					m.supportedSignatureAlgorithmsCert, SignatureScheme(sigAndAlg))
  1192  			}
  1193  		case extensionCertificateAuthorities:
  1194  			var auths cryptobyte.String
  1195  			if !extData.ReadUint16LengthPrefixed(&auths) || auths.Empty() {
  1196  				return false
  1197  			}
  1198  			for !auths.Empty() {
  1199  				var ca []byte
  1200  				if !readUint16LengthPrefixed(&auths, &ca) || len(ca) == 0 {
  1201  					return false
  1202  				}
  1203  				m.certificateAuthorities = append(m.certificateAuthorities, ca)
  1204  			}
  1205  		default:
  1206  			// Ignore unknown extensions.
  1207  			continue
  1208  		}
  1209  
  1210  		if !extData.Empty() {
  1211  			return false
  1212  		}
  1213  	}
  1214  
  1215  	return true
  1216  }
  1217  
  1218  type certificateMsg struct {
  1219  	raw          []byte
  1220  	certificates [][]byte
  1221  }
  1222  
  1223  func (m *certificateMsg) marshal() (x []byte) {
  1224  	if m.raw != nil {
  1225  		return m.raw
  1226  	}
  1227  
  1228  	var i int
  1229  	for _, slice := range m.certificates {
  1230  		i += len(slice)
  1231  	}
  1232  
  1233  	length := 3 + 3*len(m.certificates) + i
  1234  	x = make([]byte, 4+length)
  1235  	x[0] = typeCertificate
  1236  	x[1] = uint8(length >> 16)
  1237  	x[2] = uint8(length >> 8)
  1238  	x[3] = uint8(length)
  1239  
  1240  	certificateOctets := length - 3
  1241  	x[4] = uint8(certificateOctets >> 16)
  1242  	x[5] = uint8(certificateOctets >> 8)
  1243  	x[6] = uint8(certificateOctets)
  1244  
  1245  	y := x[7:]
  1246  	for _, slice := range m.certificates {
  1247  		y[0] = uint8(len(slice) >> 16)
  1248  		y[1] = uint8(len(slice) >> 8)
  1249  		y[2] = uint8(len(slice))
  1250  		copy(y[3:], slice)
  1251  		y = y[3+len(slice):]
  1252  	}
  1253  
  1254  	m.raw = x
  1255  	return
  1256  }
  1257  
  1258  func (m *certificateMsg) unmarshal(data []byte) bool {
  1259  	if len(data) < 7 {
  1260  		return false
  1261  	}
  1262  
  1263  	m.raw = data
  1264  	certsLen := uint32(data[4])<<16 | uint32(data[5])<<8 | uint32(data[6])
  1265  	if uint32(len(data)) != certsLen+7 {
  1266  		return false
  1267  	}
  1268  
  1269  	numCerts := 0
  1270  	d := data[7:]
  1271  	for certsLen > 0 {
  1272  		if len(d) < 4 {
  1273  			return false
  1274  		}
  1275  		certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2])
  1276  		if uint32(len(d)) < 3+certLen {
  1277  			return false
  1278  		}
  1279  		d = d[3+certLen:]
  1280  		certsLen -= 3 + certLen
  1281  		numCerts++
  1282  	}
  1283  
  1284  	m.certificates = make([][]byte, numCerts)
  1285  	d = data[7:]
  1286  	for i := 0; i < numCerts; i++ {
  1287  		certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2])
  1288  		m.certificates[i] = d[3 : 3+certLen]
  1289  		d = d[3+certLen:]
  1290  	}
  1291  
  1292  	return true
  1293  }
  1294  
  1295  type certificateMsgTLS13 struct {
  1296  	raw          []byte
  1297  	certificate  Certificate
  1298  	ocspStapling bool
  1299  	scts         bool
  1300  }
  1301  
  1302  func (m *certificateMsgTLS13) marshal() []byte {
  1303  	if m.raw != nil {
  1304  		return m.raw
  1305  	}
  1306  
  1307  	var b cryptobyte.Builder
  1308  	b.AddUint8(typeCertificate)
  1309  	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
  1310  		b.AddUint8(0) // certificate_request_context
  1311  
  1312  		certificate := m.certificate
  1313  		if !m.ocspStapling {
  1314  			certificate.OCSPStaple = nil
  1315  		}
  1316  		if !m.scts {
  1317  			certificate.SignedCertificateTimestamps = nil
  1318  		}
  1319  		marshalCertificate(b, certificate)
  1320  	})
  1321  
  1322  	m.raw = b.BytesOrPanic()
  1323  	return m.raw
  1324  }
  1325  
  1326  func marshalCertificate(b *cryptobyte.Builder, certificate Certificate) {
  1327  	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
  1328  		for i, cert := range certificate.Certificate {
  1329  			b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
  1330  				b.AddBytes(cert)
  1331  			})
  1332  			b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1333  				if i > 0 {
  1334  					// This library only supports OCSP and SCT for leaf certificates.
  1335  					return
  1336  				}
  1337  				if certificate.OCSPStaple != nil {
  1338  					b.AddUint16(extensionStatusRequest)
  1339  					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1340  						b.AddUint8(statusTypeOCSP)
  1341  						b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
  1342  							b.AddBytes(certificate.OCSPStaple)
  1343  						})
  1344  					})
  1345  				}
  1346  				if certificate.SignedCertificateTimestamps != nil {
  1347  					b.AddUint16(extensionSCT)
  1348  					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1349  						b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1350  							for _, sct := range certificate.SignedCertificateTimestamps {
  1351  								b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1352  									b.AddBytes(sct)
  1353  								})
  1354  							}
  1355  						})
  1356  					})
  1357  				}
  1358  			})
  1359  		}
  1360  	})
  1361  }
  1362  
  1363  func (m *certificateMsgTLS13) unmarshal(data []byte) bool {
  1364  	*m = certificateMsgTLS13{raw: data}
  1365  	s := cryptobyte.String(data)
  1366  
  1367  	var context cryptobyte.String
  1368  	if !s.Skip(4) || // message type and uint24 length field
  1369  		!s.ReadUint8LengthPrefixed(&context) || !context.Empty() ||
  1370  		!unmarshalCertificate(&s, &m.certificate) ||
  1371  		!s.Empty() {
  1372  		return false
  1373  	}
  1374  
  1375  	m.scts = m.certificate.SignedCertificateTimestamps != nil
  1376  	m.ocspStapling = m.certificate.OCSPStaple != nil
  1377  
  1378  	return true
  1379  }
  1380  
  1381  func unmarshalCertificate(s *cryptobyte.String, certificate *Certificate) bool {
  1382  	var certList cryptobyte.String
  1383  	if !s.ReadUint24LengthPrefixed(&certList) {
  1384  		return false
  1385  	}
  1386  	for !certList.Empty() {
  1387  		var cert []byte
  1388  		var extensions cryptobyte.String
  1389  		if !readUint24LengthPrefixed(&certList, &cert) ||
  1390  			!certList.ReadUint16LengthPrefixed(&extensions) {
  1391  			return false
  1392  		}
  1393  		certificate.Certificate = append(certificate.Certificate, cert)
  1394  		for !extensions.Empty() {
  1395  			var extension uint16
  1396  			var extData cryptobyte.String
  1397  			if !extensions.ReadUint16(&extension) ||
  1398  				!extensions.ReadUint16LengthPrefixed(&extData) {
  1399  				return false
  1400  			}
  1401  			if len(certificate.Certificate) > 1 {
  1402  				// This library only supports OCSP and SCT for leaf certificates.
  1403  				continue
  1404  			}
  1405  
  1406  			switch extension {
  1407  			case extensionStatusRequest:
  1408  				var statusType uint8
  1409  				if !extData.ReadUint8(&statusType) || statusType != statusTypeOCSP ||
  1410  					!readUint24LengthPrefixed(&extData, &certificate.OCSPStaple) ||
  1411  					len(certificate.OCSPStaple) == 0 {
  1412  					return false
  1413  				}
  1414  			case extensionSCT:
  1415  				var sctList cryptobyte.String
  1416  				if !extData.ReadUint16LengthPrefixed(&sctList) || sctList.Empty() {
  1417  					return false
  1418  				}
  1419  				for !sctList.Empty() {
  1420  					var sct []byte
  1421  					if !readUint16LengthPrefixed(&sctList, &sct) ||
  1422  						len(sct) == 0 {
  1423  						return false
  1424  					}
  1425  					certificate.SignedCertificateTimestamps = append(
  1426  						certificate.SignedCertificateTimestamps, sct)
  1427  				}
  1428  			default:
  1429  				// Ignore unknown extensions.
  1430  				continue
  1431  			}
  1432  
  1433  			if !extData.Empty() {
  1434  				return false
  1435  			}
  1436  		}
  1437  	}
  1438  	return true
  1439  }
  1440  
  1441  type serverKeyExchangeMsg struct {
  1442  	raw []byte
  1443  	key []byte
  1444  }
  1445  
  1446  func (m *serverKeyExchangeMsg) marshal() []byte {
  1447  	if m.raw != nil {
  1448  		return m.raw
  1449  	}
  1450  	length := len(m.key)
  1451  	x := make([]byte, length+4)
  1452  	x[0] = typeServerKeyExchange
  1453  	x[1] = uint8(length >> 16)
  1454  	x[2] = uint8(length >> 8)
  1455  	x[3] = uint8(length)
  1456  	copy(x[4:], m.key)
  1457  
  1458  	m.raw = x
  1459  	return x
  1460  }
  1461  
  1462  func (m *serverKeyExchangeMsg) unmarshal(data []byte) bool {
  1463  	m.raw = data
  1464  	if len(data) < 4 {
  1465  		return false
  1466  	}
  1467  	m.key = data[4:]
  1468  	return true
  1469  }
  1470  
  1471  type certificateStatusMsg struct {
  1472  	raw      []byte
  1473  	response []byte
  1474  }
  1475  
  1476  func (m *certificateStatusMsg) marshal() []byte {
  1477  	if m.raw != nil {
  1478  		return m.raw
  1479  	}
  1480  
  1481  	var b cryptobyte.Builder
  1482  	b.AddUint8(typeCertificateStatus)
  1483  	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
  1484  		b.AddUint8(statusTypeOCSP)
  1485  		b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
  1486  			b.AddBytes(m.response)
  1487  		})
  1488  	})
  1489  
  1490  	m.raw = b.BytesOrPanic()
  1491  	return m.raw
  1492  }
  1493  
  1494  func (m *certificateStatusMsg) unmarshal(data []byte) bool {
  1495  	m.raw = data
  1496  	s := cryptobyte.String(data)
  1497  
  1498  	var statusType uint8
  1499  	if !s.Skip(4) || // message type and uint24 length field
  1500  		!s.ReadUint8(&statusType) || statusType != statusTypeOCSP ||
  1501  		!readUint24LengthPrefixed(&s, &m.response) ||
  1502  		len(m.response) == 0 || !s.Empty() {
  1503  		return false
  1504  	}
  1505  	return true
  1506  }
  1507  
  1508  type serverHelloDoneMsg struct{}
  1509  
  1510  func (m *serverHelloDoneMsg) marshal() []byte {
  1511  	x := make([]byte, 4)
  1512  	x[0] = typeServerHelloDone
  1513  	return x
  1514  }
  1515  
  1516  func (m *serverHelloDoneMsg) unmarshal(data []byte) bool {
  1517  	return len(data) == 4
  1518  }
  1519  
  1520  type clientKeyExchangeMsg struct {
  1521  	raw        []byte
  1522  	ciphertext []byte
  1523  }
  1524  
  1525  func (m *clientKeyExchangeMsg) marshal() []byte {
  1526  	if m.raw != nil {
  1527  		return m.raw
  1528  	}
  1529  	length := len(m.ciphertext)
  1530  	x := make([]byte, length+4)
  1531  	x[0] = typeClientKeyExchange
  1532  	x[1] = uint8(length >> 16)
  1533  	x[2] = uint8(length >> 8)
  1534  	x[3] = uint8(length)
  1535  	copy(x[4:], m.ciphertext)
  1536  
  1537  	m.raw = x
  1538  	return x
  1539  }
  1540  
  1541  func (m *clientKeyExchangeMsg) unmarshal(data []byte) bool {
  1542  	m.raw = data
  1543  	if len(data) < 4 {
  1544  		return false
  1545  	}
  1546  	l := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
  1547  	if l != len(data)-4 {
  1548  		return false
  1549  	}
  1550  	m.ciphertext = data[4:]
  1551  	return true
  1552  }
  1553  
  1554  type finishedMsg struct {
  1555  	raw        []byte
  1556  	verifyData []byte
  1557  }
  1558  
  1559  func (m *finishedMsg) marshal() []byte {
  1560  	if m.raw != nil {
  1561  		return m.raw
  1562  	}
  1563  
  1564  	var b cryptobyte.Builder
  1565  	b.AddUint8(typeFinished)
  1566  	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
  1567  		b.AddBytes(m.verifyData)
  1568  	})
  1569  
  1570  	m.raw = b.BytesOrPanic()
  1571  	return m.raw
  1572  }
  1573  
  1574  func (m *finishedMsg) unmarshal(data []byte) bool {
  1575  	m.raw = data
  1576  	s := cryptobyte.String(data)
  1577  	return s.Skip(1) &&
  1578  		readUint24LengthPrefixed(&s, &m.verifyData) &&
  1579  		s.Empty()
  1580  }
  1581  
  1582  type nextProtoMsg struct {
  1583  	raw   []byte
  1584  	proto string
  1585  }
  1586  
  1587  func (m *nextProtoMsg) marshal() []byte {
  1588  	if m.raw != nil {
  1589  		return m.raw
  1590  	}
  1591  	l := len(m.proto)
  1592  	if l > 255 {
  1593  		l = 255
  1594  	}
  1595  
  1596  	padding := 32 - (l+2)%32
  1597  	length := l + padding + 2
  1598  	x := make([]byte, length+4)
  1599  	x[0] = typeNextProtocol
  1600  	x[1] = uint8(length >> 16)
  1601  	x[2] = uint8(length >> 8)
  1602  	x[3] = uint8(length)
  1603  
  1604  	y := x[4:]
  1605  	y[0] = byte(l)
  1606  	copy(y[1:], []byte(m.proto[0:l]))
  1607  	y = y[1+l:]
  1608  	y[0] = byte(padding)
  1609  
  1610  	m.raw = x
  1611  
  1612  	return x
  1613  }
  1614  
  1615  func (m *nextProtoMsg) unmarshal(data []byte) bool {
  1616  	m.raw = data
  1617  
  1618  	if len(data) < 5 {
  1619  		return false
  1620  	}
  1621  	data = data[4:]
  1622  	protoLen := int(data[0])
  1623  	data = data[1:]
  1624  	if len(data) < protoLen {
  1625  		return false
  1626  	}
  1627  	m.proto = string(data[0:protoLen])
  1628  	data = data[protoLen:]
  1629  
  1630  	if len(data) < 1 {
  1631  		return false
  1632  	}
  1633  	paddingLen := int(data[0])
  1634  	data = data[1:]
  1635  	if len(data) != paddingLen {
  1636  		return false
  1637  	}
  1638  
  1639  	return true
  1640  }
  1641  
// certificateRequestMsg is the CertificateRequest message in the format used
// before TLS 1.3 (the TLS 1.3 form is certificateRequestMsgTLS13). Its wire
// layout depends on hasSignatureAlgorithm, which callers must set before
// calling unmarshal.
type certificateRequestMsg struct {
	raw []byte
	// hasSignatureAlgorithm indicates whether this message includes a list of
	// supported signature algorithms. This change was introduced with TLS 1.2.
	hasSignatureAlgorithm bool

	certificateTypes             []byte
	supportedSignatureAlgorithms []SignatureScheme
	certificateAuthorities       [][]byte
}

// marshal encodes the message and caches the encoding in m.raw. The body is
// laid out as three parts:
//   - a uint8 count followed by the certificate type bytes;
//   - only when hasSignatureAlgorithm is set, a uint16 byte length followed
//     by two-byte SignatureScheme values;
//   - a uint16 byte length followed by uint16-prefixed CA names.
//
// Length fields are written with truncating conversions, so oversized inputs
// (for example more than 255 certificate types) are not rejected.
func (m *certificateRequestMsg) marshal() (x []byte) {
	if m.raw != nil {
		return m.raw
	}

	// See RFC 4346, Section 7.4.4.
	// Start with the type-count byte, the types, and the uint16 CA length.
	length := 1 + len(m.certificateTypes) + 2
	casLength := 0
	for _, ca := range m.certificateAuthorities {
		// Each CA name carries its own uint16 length prefix.
		casLength += 2 + len(ca)
	}
	length += casLength

	if m.hasSignatureAlgorithm {
		// uint16 list length plus two bytes per SignatureScheme.
		length += 2 + 2*len(m.supportedSignatureAlgorithms)
	}

	// 4-byte handshake header: type byte plus uint24 body length.
	x = make([]byte, 4+length)
	x[0] = typeCertificateRequest
	x[1] = uint8(length >> 16)
	x[2] = uint8(length >> 8)
	x[3] = uint8(length)

	x[4] = uint8(len(m.certificateTypes))

	copy(x[5:], m.certificateTypes)
	y := x[5+len(m.certificateTypes):]

	if m.hasSignatureAlgorithm {
		// The list length is in bytes, not entries.
		n := len(m.supportedSignatureAlgorithms) * 2
		y[0] = uint8(n >> 8)
		y[1] = uint8(n)
		y = y[2:]
		for _, sigAlgo := range m.supportedSignatureAlgorithms {
			y[0] = uint8(sigAlgo >> 8)
			y[1] = uint8(sigAlgo)
			y = y[2:]
		}
	}

	y[0] = uint8(casLength >> 8)
	y[1] = uint8(casLength)
	y = y[2:]
	for _, ca := range m.certificateAuthorities {
		y[0] = uint8(len(ca) >> 8)
		y[1] = uint8(len(ca))
		y = y[2:]
		copy(y, ca)
		y = y[len(ca):]
	}

	m.raw = x
	return
}

// unmarshal parses a CertificateRequest in the pre-TLS 1.3 format. The
// caller must set m.hasSignatureAlgorithm first, since it decides whether a
// signature algorithm list is expected. m is not reset beforehand, and on
// failure its fields may be partially populated. certificateTypes is copied
// out of data, and certificateAuthorities entries are sub-slices of a private
// copy of the CA block, so neither aliases data.
func (m *certificateRequestMsg) unmarshal(data []byte) bool {
	// Record the full input before data is resliced below.
	m.raw = data

	if len(data) < 5 {
		return false
	}

	// The uint24 body length must match the bytes after the 4-byte header.
	// len(data) >= 5 here, so the subtraction cannot underflow.
	length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3])
	if uint32(len(data))-4 != length {
		return false
	}

	numCertTypes := int(data[4])
	data = data[5:]
	// At least one certificate type is required, and at least one byte must
	// remain after the type list.
	if numCertTypes == 0 || len(data) <= numCertTypes {
		return false
	}

	m.certificateTypes = make([]byte, numCertTypes)
	// Given the length check above, this copy is always complete.
	if copy(m.certificateTypes, data) != numCertTypes {
		return false
	}

	data = data[numCertTypes:]

	if m.hasSignatureAlgorithm {
		if len(data) < 2 {
			return false
		}
		sigAndHashLen := uint16(data[0])<<8 | uint16(data[1])
		data = data[2:]
		// Each SignatureScheme is two bytes, so the byte length must be even.
		if sigAndHashLen&1 != 0 {
			return false
		}
		if len(data) < int(sigAndHashLen) {
			return false
		}
		numSigAlgos := sigAndHashLen / 2
		m.supportedSignatureAlgorithms = make([]SignatureScheme, numSigAlgos)
		for i := range m.supportedSignatureAlgorithms {
			m.supportedSignatureAlgorithms[i] = SignatureScheme(data[0])<<8 | SignatureScheme(data[1])
			data = data[2:]
		}
	}

	if len(data) < 2 {
		return false
	}
	casLength := uint16(data[0])<<8 | uint16(data[1])
	data = data[2:]
	if len(data) < int(casLength) {
		return false
	}
	// Copy the CA block so the parsed names do not alias the input buffer.
	cas := make([]byte, casLength)
	copy(cas, data)
	data = data[casLength:]

	m.certificateAuthorities = nil
	for len(cas) > 0 {
		if len(cas) < 2 {
			return false
		}
		caLen := uint16(cas[0])<<8 | uint16(cas[1])
		cas = cas[2:]

		if len(cas) < int(caLen) {
			return false
		}

		m.certificateAuthorities = append(m.certificateAuthorities, cas[:caLen])
		cas = cas[caLen:]
	}

	// Reject any trailing bytes after the CA list.
	return len(data) == 0
}
  1783  
  1784  type certificateVerifyMsg struct {
  1785  	raw                   []byte
  1786  	hasSignatureAlgorithm bool // format change introduced in TLS 1.2
  1787  	signatureAlgorithm    SignatureScheme
  1788  	signature             []byte
  1789  }
  1790  
  1791  func (m *certificateVerifyMsg) marshal() (x []byte) {
  1792  	if m.raw != nil {
  1793  		return m.raw
  1794  	}
  1795  
  1796  	var b cryptobyte.Builder
  1797  	b.AddUint8(typeCertificateVerify)
  1798  	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
  1799  		if m.hasSignatureAlgorithm {
  1800  			b.AddUint16(uint16(m.signatureAlgorithm))
  1801  		}
  1802  		b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1803  			b.AddBytes(m.signature)
  1804  		})
  1805  	})
  1806  
  1807  	m.raw = b.BytesOrPanic()
  1808  	return m.raw
  1809  }
  1810  
  1811  func (m *certificateVerifyMsg) unmarshal(data []byte) bool {
  1812  	m.raw = data
  1813  	s := cryptobyte.String(data)
  1814  
  1815  	if !s.Skip(4) { // message type and uint24 length field
  1816  		return false
  1817  	}
  1818  	if m.hasSignatureAlgorithm {
  1819  		if !s.ReadUint16((*uint16)(&m.signatureAlgorithm)) {
  1820  			return false
  1821  		}
  1822  	}
  1823  	return readUint16LengthPrefixed(&s, &m.signature) && s.Empty()
  1824  }
  1825  
  1826  type newSessionTicketMsg struct {
  1827  	raw    []byte
  1828  	ticket []byte
  1829  }
  1830  
  1831  func (m *newSessionTicketMsg) marshal() (x []byte) {
  1832  	if m.raw != nil {
  1833  		return m.raw
  1834  	}
  1835  
  1836  	// See RFC 5077, Section 3.3.
  1837  	ticketLen := len(m.ticket)
  1838  	length := 2 + 4 + ticketLen
  1839  	x = make([]byte, 4+length)
  1840  	x[0] = typeNewSessionTicket
  1841  	x[1] = uint8(length >> 16)
  1842  	x[2] = uint8(length >> 8)
  1843  	x[3] = uint8(length)
  1844  	x[8] = uint8(ticketLen >> 8)
  1845  	x[9] = uint8(ticketLen)
  1846  	copy(x[10:], m.ticket)
  1847  
  1848  	m.raw = x
  1849  
  1850  	return
  1851  }
  1852  
  1853  func (m *newSessionTicketMsg) unmarshal(data []byte) bool {
  1854  	m.raw = data
  1855  
  1856  	if len(data) < 10 {
  1857  		return false
  1858  	}
  1859  
  1860  	length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3])
  1861  	if uint32(len(data))-4 != length {
  1862  		return false
  1863  	}
  1864  
  1865  	ticketLen := int(data[8])<<8 + int(data[9])
  1866  	if len(data)-10 != ticketLen {
  1867  		return false
  1868  	}
  1869  
  1870  	m.ticket = data[10:]
  1871  
  1872  	return true
  1873  }
  1874  
  1875  type helloRequestMsg struct {
  1876  }
  1877  
  1878  func (*helloRequestMsg) marshal() []byte {
  1879  	return []byte{typeHelloRequest, 0, 0, 0}
  1880  }
  1881  
  1882  func (*helloRequestMsg) unmarshal(data []byte) bool {
  1883  	return len(data) == 4
  1884  }