github.com/twelsh-aw/go/src@v0.0.0-20230516233729-a56fe86a7c81/crypto/tls/handshake_messages.go (about)

     1  // Copyright 2009 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package tls
     6  
     7  import (
     8  	"errors"
     9  	"fmt"
    10  	"strings"
    11  
    12  	"golang.org/x/crypto/cryptobyte"
    13  )
    14  
    15  // The marshalingFunction type is an adapter to allow the use of ordinary
    16  // functions as cryptobyte.MarshalingValue.
    17  type marshalingFunction func(b *cryptobyte.Builder) error
    18  
    19  func (f marshalingFunction) Marshal(b *cryptobyte.Builder) error {
    20  	return f(b)
    21  }
    22  
    23  // addBytesWithLength appends a sequence of bytes to the cryptobyte.Builder. If
    24  // the length of the sequence is not the value specified, it produces an error.
    25  func addBytesWithLength(b *cryptobyte.Builder, v []byte, n int) {
    26  	b.AddValue(marshalingFunction(func(b *cryptobyte.Builder) error {
    27  		if len(v) != n {
    28  			return fmt.Errorf("invalid value length: expected %d, got %d", n, len(v))
    29  		}
    30  		b.AddBytes(v)
    31  		return nil
    32  	}))
    33  }
    34  
    35  // addUint64 appends a big-endian, 64-bit value to the cryptobyte.Builder.
    36  func addUint64(b *cryptobyte.Builder, v uint64) {
    37  	b.AddUint32(uint32(v >> 32))
    38  	b.AddUint32(uint32(v))
    39  }
    40  
    41  // readUint64 decodes a big-endian, 64-bit value into out and advances over it.
    42  // It reports whether the read was successful.
    43  func readUint64(s *cryptobyte.String, out *uint64) bool {
    44  	var hi, lo uint32
    45  	if !s.ReadUint32(&hi) || !s.ReadUint32(&lo) {
    46  		return false
    47  	}
    48  	*out = uint64(hi)<<32 | uint64(lo)
    49  	return true
    50  }
    51  
    52  // readUint8LengthPrefixed acts like s.ReadUint8LengthPrefixed, but targets a
    53  // []byte instead of a cryptobyte.String.
    54  func readUint8LengthPrefixed(s *cryptobyte.String, out *[]byte) bool {
    55  	return s.ReadUint8LengthPrefixed((*cryptobyte.String)(out))
    56  }
    57  
    58  // readUint16LengthPrefixed acts like s.ReadUint16LengthPrefixed, but targets a
    59  // []byte instead of a cryptobyte.String.
    60  func readUint16LengthPrefixed(s *cryptobyte.String, out *[]byte) bool {
    61  	return s.ReadUint16LengthPrefixed((*cryptobyte.String)(out))
    62  }
    63  
    64  // readUint24LengthPrefixed acts like s.ReadUint24LengthPrefixed, but targets a
    65  // []byte instead of a cryptobyte.String.
    66  func readUint24LengthPrefixed(s *cryptobyte.String, out *[]byte) bool {
    67  	return s.ReadUint24LengthPrefixed((*cryptobyte.String)(out))
    68  }
    69  
// clientHelloMsg is the in-memory form of a ClientHello handshake message.
//
// raw holds the serialized message: marshal caches its output there and
// returns it unchanged on later calls, and unmarshal records the bytes it
// parsed. vers, random (which marshal requires to be exactly 32 bytes),
// sessionId, cipherSuites and compressionMethods are the fixed fields that
// precede the extensions block. Every other field mirrors one extension:
// marshal emits an extension only when its field is non-empty (or true, for
// the bool fields), and unmarshal fills the field in when the extension is
// present.
//
// secureRenegotiationSupported is also set by unmarshal when the cipher suite
// list contains scsvRenegotiation. pskIdentities and pskBinders both travel in
// the pre_shared_key extension, which marshal writes last and unmarshal
// requires to be the final extension.
type clientHelloMsg struct {
	raw                              []byte
	vers                             uint16
	random                           []byte
	sessionId                        []byte
	cipherSuites                     []uint16
	compressionMethods               []uint8
	serverName                       string
	ocspStapling                     bool
	supportedCurves                  []CurveID
	supportedPoints                  []uint8
	ticketSupported                  bool
	sessionTicket                    []uint8
	supportedSignatureAlgorithms     []SignatureScheme
	supportedSignatureAlgorithmsCert []SignatureScheme
	secureRenegotiationSupported     bool
	secureRenegotiation              []byte
	alpnProtocols                    []string
	scts                             bool
	supportedVersions                []uint16
	cookie                           []byte
	keyShares                        []keyShare
	earlyData                        bool
	pskModes                         []uint8
	pskIdentities                    []pskIdentity
	pskBinders                       [][]byte
}
    97  
// marshal serializes the ClientHello. If m.raw is already set it is returned
// as-is; otherwise the message is built, stored in m.raw and returned.
//
// The extensions are assembled first in their own builder so that the
// extensions block can be left out entirely when no extension is present.
// Each extension is written only when its field is non-empty (or true, for
// the bool fields). An error is returned if building fails, for example when
// m.random is not exactly 32 bytes long.
func (m *clientHelloMsg) marshal() ([]byte, error) {
	if m.raw != nil {
		return m.raw, nil
	}

	// Each extension below is encoded as a uint16 type followed by a
	// uint16-length-prefixed body.
	var exts cryptobyte.Builder
	if len(m.serverName) > 0 {
		// RFC 6066, Section 3
		exts.AddUint16(extensionServerName)
		exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
			exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
				exts.AddUint8(0) // name_type = host_name
				exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
					exts.AddBytes([]byte(m.serverName))
				})
			})
		})
	}
	if m.ocspStapling {
		// RFC 4366, Section 3.6
		exts.AddUint16(extensionStatusRequest)
		exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
			exts.AddUint8(1)  // status_type = ocsp
			exts.AddUint16(0) // empty responder_id_list
			exts.AddUint16(0) // empty request_extensions
		})
	}
	if len(m.supportedCurves) > 0 {
		// RFC 4492, sections 5.1.1 and RFC 8446, Section 4.2.7
		exts.AddUint16(extensionSupportedCurves)
		exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
			exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
				for _, curve := range m.supportedCurves {
					exts.AddUint16(uint16(curve))
				}
			})
		})
	}
	if len(m.supportedPoints) > 0 {
		// RFC 4492, Section 5.1.2
		exts.AddUint16(extensionSupportedPoints)
		exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
			exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) {
				exts.AddBytes(m.supportedPoints)
			})
		})
	}
	if m.ticketSupported {
		// RFC 5077, Section 3.2
		// The extension body is the ticket itself, which may be empty.
		exts.AddUint16(extensionSessionTicket)
		exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
			exts.AddBytes(m.sessionTicket)
		})
	}
	if len(m.supportedSignatureAlgorithms) > 0 {
		// RFC 5246, Section 7.4.1.4.1
		exts.AddUint16(extensionSignatureAlgorithms)
		exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
			exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
				for _, sigAlgo := range m.supportedSignatureAlgorithms {
					exts.AddUint16(uint16(sigAlgo))
				}
			})
		})
	}
	if len(m.supportedSignatureAlgorithmsCert) > 0 {
		// RFC 8446, Section 4.2.3
		exts.AddUint16(extensionSignatureAlgorithmsCert)
		exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
			exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
				for _, sigAlgo := range m.supportedSignatureAlgorithmsCert {
					exts.AddUint16(uint16(sigAlgo))
				}
			})
		})
	}
	if m.secureRenegotiationSupported {
		// RFC 5746, Section 3.2
		exts.AddUint16(extensionRenegotiationInfo)
		exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
			exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) {
				exts.AddBytes(m.secureRenegotiation)
			})
		})
	}
	if len(m.alpnProtocols) > 0 {
		// RFC 7301, Section 3.1
		exts.AddUint16(extensionALPN)
		exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
			exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
				for _, proto := range m.alpnProtocols {
					exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) {
						exts.AddBytes([]byte(proto))
					})
				}
			})
		})
	}
	if m.scts {
		// RFC 6962, Section 3.3.1
		exts.AddUint16(extensionSCT)
		exts.AddUint16(0) // empty extension_data
	}
	if len(m.supportedVersions) > 0 {
		// RFC 8446, Section 4.2.1
		exts.AddUint16(extensionSupportedVersions)
		exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
			exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) {
				for _, vers := range m.supportedVersions {
					exts.AddUint16(vers)
				}
			})
		})
	}
	if len(m.cookie) > 0 {
		// RFC 8446, Section 4.2.2
		exts.AddUint16(extensionCookie)
		exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
			exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
				exts.AddBytes(m.cookie)
			})
		})
	}
	if len(m.keyShares) > 0 {
		// RFC 8446, Section 4.2.8
		exts.AddUint16(extensionKeyShare)
		exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
			exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
				for _, ks := range m.keyShares {
					exts.AddUint16(uint16(ks.group))
					exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
						exts.AddBytes(ks.data)
					})
				}
			})
		})
	}
	if m.earlyData {
		// RFC 8446, Section 4.2.10
		exts.AddUint16(extensionEarlyData)
		exts.AddUint16(0) // empty extension_data
	}
	if len(m.pskModes) > 0 {
		// RFC 8446, Section 4.2.9
		exts.AddUint16(extensionPSKModes)
		exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
			exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) {
				exts.AddBytes(m.pskModes)
			})
		})
	}
	if len(m.pskIdentities) > 0 { // pre_shared_key must be the last extension
		// RFC 8446, Section 4.2.11
		// The identities list is written first and the binders list second, so
		// the binders are the final bytes of the message.
		exts.AddUint16(extensionPreSharedKey)
		exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
			exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
				for _, psk := range m.pskIdentities {
					exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
						exts.AddBytes(psk.label)
					})
					exts.AddUint32(psk.obfuscatedTicketAge)
				}
			})
			exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
				for _, binder := range m.pskBinders {
					exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) {
						exts.AddBytes(binder)
					})
				}
			})
		})
	}
	extBytes, err := exts.Bytes()
	if err != nil {
		return nil, err
	}

	// Assemble the message: a one-byte type, then a uint24-length-prefixed
	// body holding the fixed fields followed by the optional extensions.
	var b cryptobyte.Builder
	b.AddUint8(typeClientHello)
	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
		b.AddUint16(m.vers)
		addBytesWithLength(b, m.random, 32)
		b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
			b.AddBytes(m.sessionId)
		})
		b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
			for _, suite := range m.cipherSuites {
				b.AddUint16(suite)
			}
		})
		b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
			b.AddBytes(m.compressionMethods)
		})

		// Omit the extensions block, including its length prefix, when there
		// are no extensions.
		if len(extBytes) > 0 {
			b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
				b.AddBytes(extBytes)
			})
		}
	})

	m.raw, err = b.Bytes()
	return m.raw, err
}
   302  
   303  // marshalWithoutBinders returns the ClientHello through the
   304  // PreSharedKeyExtension.identities field, according to RFC 8446, Section
   305  // 4.2.11.2. Note that m.pskBinders must be set to slices of the correct length.
   306  func (m *clientHelloMsg) marshalWithoutBinders() ([]byte, error) {
   307  	bindersLen := 2 // uint16 length prefix
   308  	for _, binder := range m.pskBinders {
   309  		bindersLen += 1 // uint8 length prefix
   310  		bindersLen += len(binder)
   311  	}
   312  
   313  	fullMessage, err := m.marshal()
   314  	if err != nil {
   315  		return nil, err
   316  	}
   317  	return fullMessage[:len(fullMessage)-bindersLen], nil
   318  }
   319  
   320  // updateBinders updates the m.pskBinders field, if necessary updating the
   321  // cached marshaled representation. The supplied binders must have the same
   322  // length as the current m.pskBinders.
   323  func (m *clientHelloMsg) updateBinders(pskBinders [][]byte) error {
   324  	if len(pskBinders) != len(m.pskBinders) {
   325  		return errors.New("tls: internal error: pskBinders length mismatch")
   326  	}
   327  	for i := range m.pskBinders {
   328  		if len(pskBinders[i]) != len(m.pskBinders[i]) {
   329  			return errors.New("tls: internal error: pskBinders length mismatch")
   330  		}
   331  	}
   332  	m.pskBinders = pskBinders
   333  	if m.raw != nil {
   334  		helloBytes, err := m.marshalWithoutBinders()
   335  		if err != nil {
   336  			return err
   337  		}
   338  		lenWithoutBinders := len(helloBytes)
   339  		b := cryptobyte.NewFixedBuilder(m.raw[:lenWithoutBinders])
   340  		b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   341  			for _, binder := range m.pskBinders {
   342  				b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
   343  					b.AddBytes(binder)
   344  				})
   345  			}
   346  		})
   347  		if out, err := b.Bytes(); err != nil || len(out) != len(m.raw) {
   348  			return errors.New("tls: internal error: failed to update binders")
   349  		}
   350  	}
   351  
   352  	return nil
   353  }
   354  
// unmarshal parses data as a ClientHello. It first resets m, keeping data in
// m.raw, and then reports whether parsing succeeded. On failure m may be left
// partially populated.
//
// The message is rejected if any extension type appears more than once, if a
// recognized extension has bytes left over after its contents are parsed, or
// if pre_shared_key is not the final extension. Unrecognized extensions are
// skipped.
func (m *clientHelloMsg) unmarshal(data []byte) bool {
	*m = clientHelloMsg{raw: data}
	s := cryptobyte.String(data)

	if !s.Skip(4) || // message type and uint24 length field
		!s.ReadUint16(&m.vers) || !s.ReadBytes(&m.random, 32) ||
		!readUint8LengthPrefixed(&s, &m.sessionId) {
		return false
	}

	var cipherSuites cryptobyte.String
	if !s.ReadUint16LengthPrefixed(&cipherSuites) {
		return false
	}
	m.cipherSuites = []uint16{}
	m.secureRenegotiationSupported = false
	for !cipherSuites.Empty() {
		var suite uint16
		if !cipherSuites.ReadUint16(&suite) {
			return false
		}
		// The scsvRenegotiation value in the suite list marks secure
		// renegotiation as supported; it is still kept in m.cipherSuites.
		if suite == scsvRenegotiation {
			m.secureRenegotiationSupported = true
		}
		m.cipherSuites = append(m.cipherSuites, suite)
	}

	if !readUint8LengthPrefixed(&s, &m.compressionMethods) {
		return false
	}

	if s.Empty() {
		// ClientHello is optionally followed by extension data
		return true
	}

	// The extensions block must account for all remaining bytes.
	var extensions cryptobyte.String
	if !s.ReadUint16LengthPrefixed(&extensions) || !s.Empty() {
		return false
	}

	// seenExts tracks extension types already encountered so that duplicates,
	// including unrecognized ones, are rejected.
	seenExts := make(map[uint16]bool)
	for !extensions.Empty() {
		var extension uint16
		var extData cryptobyte.String
		if !extensions.ReadUint16(&extension) ||
			!extensions.ReadUint16LengthPrefixed(&extData) {
			return false
		}

		if seenExts[extension] {
			return false
		}
		seenExts[extension] = true

		switch extension {
		case extensionServerName:
			// RFC 6066, Section 3
			var nameList cryptobyte.String
			if !extData.ReadUint16LengthPrefixed(&nameList) || nameList.Empty() {
				return false
			}
			for !nameList.Empty() {
				var nameType uint8
				var serverName cryptobyte.String
				if !nameList.ReadUint8(&nameType) ||
					!nameList.ReadUint16LengthPrefixed(&serverName) ||
					serverName.Empty() {
					return false
				}
				// Only name_type 0 (host_name) is recorded; other types are
				// skipped.
				if nameType != 0 {
					continue
				}
				if len(m.serverName) != 0 {
					// Multiple names of the same name_type are prohibited.
					return false
				}
				m.serverName = string(serverName)
				// An SNI value may not include a trailing dot.
				if strings.HasSuffix(m.serverName, ".") {
					return false
				}
			}
		case extensionStatusRequest:
			// RFC 4366, Section 3.6
			// Both length-prefixed lists are parsed only to validate the
			// framing; their contents are discarded.
			var statusType uint8
			var ignored cryptobyte.String
			if !extData.ReadUint8(&statusType) ||
				!extData.ReadUint16LengthPrefixed(&ignored) ||
				!extData.ReadUint16LengthPrefixed(&ignored) {
				return false
			}
			m.ocspStapling = statusType == statusTypeOCSP
		case extensionSupportedCurves:
			// RFC 4492, sections 5.1.1 and RFC 8446, Section 4.2.7
			var curves cryptobyte.String
			if !extData.ReadUint16LengthPrefixed(&curves) || curves.Empty() {
				return false
			}
			for !curves.Empty() {
				var curve uint16
				if !curves.ReadUint16(&curve) {
					return false
				}
				m.supportedCurves = append(m.supportedCurves, CurveID(curve))
			}
		case extensionSupportedPoints:
			// RFC 4492, Section 5.1.2
			if !readUint8LengthPrefixed(&extData, &m.supportedPoints) ||
				len(m.supportedPoints) == 0 {
				return false
			}
		case extensionSessionTicket:
			// RFC 5077, Section 3.2
			// Everything remaining in the extension body is taken as the
			// ticket.
			m.ticketSupported = true
			extData.ReadBytes(&m.sessionTicket, len(extData))
		case extensionSignatureAlgorithms:
			// RFC 5246, Section 7.4.1.4.1
			var sigAndAlgs cryptobyte.String
			if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() {
				return false
			}
			for !sigAndAlgs.Empty() {
				var sigAndAlg uint16
				if !sigAndAlgs.ReadUint16(&sigAndAlg) {
					return false
				}
				m.supportedSignatureAlgorithms = append(
					m.supportedSignatureAlgorithms, SignatureScheme(sigAndAlg))
			}
		case extensionSignatureAlgorithmsCert:
			// RFC 8446, Section 4.2.3
			var sigAndAlgs cryptobyte.String
			if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() {
				return false
			}
			for !sigAndAlgs.Empty() {
				var sigAndAlg uint16
				if !sigAndAlgs.ReadUint16(&sigAndAlg) {
					return false
				}
				m.supportedSignatureAlgorithmsCert = append(
					m.supportedSignatureAlgorithmsCert, SignatureScheme(sigAndAlg))
			}
		case extensionRenegotiationInfo:
			// RFC 5746, Section 3.2
			if !readUint8LengthPrefixed(&extData, &m.secureRenegotiation) {
				return false
			}
			m.secureRenegotiationSupported = true
		case extensionALPN:
			// RFC 7301, Section 3.1
			var protoList cryptobyte.String
			if !extData.ReadUint16LengthPrefixed(&protoList) || protoList.Empty() {
				return false
			}
			for !protoList.Empty() {
				var proto cryptobyte.String
				if !protoList.ReadUint8LengthPrefixed(&proto) || proto.Empty() {
					return false
				}
				m.alpnProtocols = append(m.alpnProtocols, string(proto))
			}
		case extensionSCT:
			// RFC 6962, Section 3.3.1
			m.scts = true
		case extensionSupportedVersions:
			// RFC 8446, Section 4.2.1
			var versList cryptobyte.String
			if !extData.ReadUint8LengthPrefixed(&versList) || versList.Empty() {
				return false
			}
			for !versList.Empty() {
				var vers uint16
				if !versList.ReadUint16(&vers) {
					return false
				}
				m.supportedVersions = append(m.supportedVersions, vers)
			}
		case extensionCookie:
			// RFC 8446, Section 4.2.2
			if !readUint16LengthPrefixed(&extData, &m.cookie) ||
				len(m.cookie) == 0 {
				return false
			}
		case extensionKeyShare:
			// RFC 8446, Section 4.2.8
			// The list of shares may be empty, but each share's data may not.
			var clientShares cryptobyte.String
			if !extData.ReadUint16LengthPrefixed(&clientShares) {
				return false
			}
			for !clientShares.Empty() {
				var ks keyShare
				if !clientShares.ReadUint16((*uint16)(&ks.group)) ||
					!readUint16LengthPrefixed(&clientShares, &ks.data) ||
					len(ks.data) == 0 {
					return false
				}
				m.keyShares = append(m.keyShares, ks)
			}
		case extensionEarlyData:
			// RFC 8446, Section 4.2.10
			m.earlyData = true
		case extensionPSKModes:
			// RFC 8446, Section 4.2.9
			if !readUint8LengthPrefixed(&extData, &m.pskModes) {
				return false
			}
		case extensionPreSharedKey:
			// RFC 8446, Section 4.2.11
			if !extensions.Empty() {
				return false // pre_shared_key must be the last extension
			}
			var identities cryptobyte.String
			if !extData.ReadUint16LengthPrefixed(&identities) || identities.Empty() {
				return false
			}
			for !identities.Empty() {
				var psk pskIdentity
				if !readUint16LengthPrefixed(&identities, &psk.label) ||
					!identities.ReadUint32(&psk.obfuscatedTicketAge) ||
					len(psk.label) == 0 {
					return false
				}
				m.pskIdentities = append(m.pskIdentities, psk)
			}
			var binders cryptobyte.String
			if !extData.ReadUint16LengthPrefixed(&binders) || binders.Empty() {
				return false
			}
			for !binders.Empty() {
				var binder []byte
				if !readUint8LengthPrefixed(&binders, &binder) ||
					len(binder) == 0 {
					return false
				}
				m.pskBinders = append(m.pskBinders, binder)
			}
		default:
			// Ignore unknown extensions. The continue also bypasses the
			// trailing-bytes check below, so their bodies are not inspected.
			continue
		}

		// A recognized extension must have been consumed completely.
		if !extData.Empty() {
			return false
		}
	}

	return true
}
   605  
// serverHelloMsg is the in-memory form of a ServerHello handshake message.
//
// raw holds the serialized message: marshal caches its output there and
// unmarshal records the bytes it parsed. vers, random (which marshal requires
// to be exactly 32 bytes), sessionId, cipherSuite and compressionMethod are
// the fixed fields that precede the extensions block; the remaining fields
// each correspond to one extension.
//
// The key_share extension is represented twice: serverShare holds the
// group-plus-key form and selectedGroup holds the two-byte, group-only form.
// selectedIdentityPresent records whether a pre_shared_key extension is
// present, presumably because a selectedIdentity of zero is a valid value.
type serverHelloMsg struct {
	raw                          []byte
	vers                         uint16
	random                       []byte
	sessionId                    []byte
	cipherSuite                  uint16
	compressionMethod            uint8
	ocspStapling                 bool
	ticketSupported              bool
	secureRenegotiationSupported bool
	secureRenegotiation          []byte
	alpnProtocol                 string
	scts                         [][]byte
	supportedVersion             uint16
	serverShare                  keyShare
	selectedIdentityPresent      bool
	selectedIdentity             uint16
	supportedPoints              []uint8

	// HelloRetryRequest extensions
	cookie        []byte
	selectedGroup CurveID
}
   629  
// marshal serializes the ServerHello. If m.raw is already set it is returned
// as-is; otherwise the message is built, stored in m.raw and returned.
//
// Extensions are assembled in their own builder, and the extensions block is
// omitted when it would be empty. An error is returned if building fails, for
// example when m.random is not exactly 32 bytes long.
//
// NOTE(review): extensionKeyShare is written once for serverShare and once
// for selectedGroup when both are non-zero; nothing here prevents setting
// both, so callers are presumably expected to set at most one.
func (m *serverHelloMsg) marshal() ([]byte, error) {
	if m.raw != nil {
		return m.raw, nil
	}

	// Each extension below is encoded as a uint16 type followed by a
	// uint16-length-prefixed body.
	var exts cryptobyte.Builder
	if m.ocspStapling {
		exts.AddUint16(extensionStatusRequest)
		exts.AddUint16(0) // empty extension_data
	}
	if m.ticketSupported {
		exts.AddUint16(extensionSessionTicket)
		exts.AddUint16(0) // empty extension_data
	}
	if m.secureRenegotiationSupported {
		exts.AddUint16(extensionRenegotiationInfo)
		exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
			exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) {
				exts.AddBytes(m.secureRenegotiation)
			})
		})
	}
	if len(m.alpnProtocol) > 0 {
		// A protocol list containing exactly one entry.
		exts.AddUint16(extensionALPN)
		exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
			exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
				exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) {
					exts.AddBytes([]byte(m.alpnProtocol))
				})
			})
		})
	}
	if len(m.scts) > 0 {
		exts.AddUint16(extensionSCT)
		exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
			exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
				for _, sct := range m.scts {
					exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
						exts.AddBytes(sct)
					})
				}
			})
		})
	}
	if m.supportedVersion != 0 {
		exts.AddUint16(extensionSupportedVersions)
		exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
			exts.AddUint16(m.supportedVersion)
		})
	}
	if m.serverShare.group != 0 {
		// key_share in its group-plus-key form.
		exts.AddUint16(extensionKeyShare)
		exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
			exts.AddUint16(uint16(m.serverShare.group))
			exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
				exts.AddBytes(m.serverShare.data)
			})
		})
	}
	if m.selectedIdentityPresent {
		exts.AddUint16(extensionPreSharedKey)
		exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
			exts.AddUint16(m.selectedIdentity)
		})
	}

	if len(m.cookie) > 0 {
		exts.AddUint16(extensionCookie)
		exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
			exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
				exts.AddBytes(m.cookie)
			})
		})
	}
	if m.selectedGroup != 0 {
		// key_share in its group-only form.
		exts.AddUint16(extensionKeyShare)
		exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
			exts.AddUint16(uint16(m.selectedGroup))
		})
	}
	if len(m.supportedPoints) > 0 {
		exts.AddUint16(extensionSupportedPoints)
		exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
			exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) {
				exts.AddBytes(m.supportedPoints)
			})
		})
	}

	extBytes, err := exts.Bytes()
	if err != nil {
		return nil, err
	}

	// Assemble the message: a one-byte type, then a uint24-length-prefixed
	// body holding the fixed fields followed by the optional extensions.
	var b cryptobyte.Builder
	b.AddUint8(typeServerHello)
	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
		b.AddUint16(m.vers)
		addBytesWithLength(b, m.random, 32)
		b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
			b.AddBytes(m.sessionId)
		})
		b.AddUint16(m.cipherSuite)
		b.AddUint8(m.compressionMethod)

		// Omit the extensions block, including its length prefix, when there
		// are no extensions.
		if len(extBytes) > 0 {
			b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
				b.AddBytes(extBytes)
			})
		}
	})

	m.raw, err = b.Bytes()
	return m.raw, err
}
   745  
// unmarshal parses data as a ServerHello. It first resets m, keeping data in
// m.raw, and then reports whether parsing succeeded. On failure m may be left
// partially populated.
//
// The message is rejected if any extension type appears more than once or if
// a recognized extension has bytes left over after its contents are parsed.
// Unrecognized extensions are skipped.
func (m *serverHelloMsg) unmarshal(data []byte) bool {
	*m = serverHelloMsg{raw: data}
	s := cryptobyte.String(data)

	if !s.Skip(4) || // message type and uint24 length field
		!s.ReadUint16(&m.vers) || !s.ReadBytes(&m.random, 32) ||
		!readUint8LengthPrefixed(&s, &m.sessionId) ||
		!s.ReadUint16(&m.cipherSuite) ||
		!s.ReadUint8(&m.compressionMethod) {
		return false
	}

	if s.Empty() {
		// ServerHello is optionally followed by extension data
		return true
	}

	// The extensions block must account for all remaining bytes.
	var extensions cryptobyte.String
	if !s.ReadUint16LengthPrefixed(&extensions) || !s.Empty() {
		return false
	}

	// seenExts tracks extension types already encountered so that duplicates,
	// including unrecognized ones, are rejected.
	seenExts := make(map[uint16]bool)
	for !extensions.Empty() {
		var extension uint16
		var extData cryptobyte.String
		if !extensions.ReadUint16(&extension) ||
			!extensions.ReadUint16LengthPrefixed(&extData) {
			return false
		}

		if seenExts[extension] {
			return false
		}
		seenExts[extension] = true

		switch extension {
		case extensionStatusRequest:
			// Nothing is read here, so the check after the switch requires the
			// extension body to be empty.
			m.ocspStapling = true
		case extensionSessionTicket:
			// As above: the body must be empty.
			m.ticketSupported = true
		case extensionRenegotiationInfo:
			if !readUint8LengthPrefixed(&extData, &m.secureRenegotiation) {
				return false
			}
			m.secureRenegotiationSupported = true
		case extensionALPN:
			var protoList cryptobyte.String
			if !extData.ReadUint16LengthPrefixed(&protoList) || protoList.Empty() {
				return false
			}
			// The list must contain exactly one non-empty protocol.
			var proto cryptobyte.String
			if !protoList.ReadUint8LengthPrefixed(&proto) ||
				proto.Empty() || !protoList.Empty() {
				return false
			}
			m.alpnProtocol = string(proto)
		case extensionSCT:
			var sctList cryptobyte.String
			if !extData.ReadUint16LengthPrefixed(&sctList) || sctList.Empty() {
				return false
			}
			for !sctList.Empty() {
				var sct []byte
				if !readUint16LengthPrefixed(&sctList, &sct) ||
					len(sct) == 0 {
					return false
				}
				m.scts = append(m.scts, sct)
			}
		case extensionSupportedVersions:
			if !extData.ReadUint16(&m.supportedVersion) {
				return false
			}
		case extensionCookie:
			if !readUint16LengthPrefixed(&extData, &m.cookie) ||
				len(m.cookie) == 0 {
				return false
			}
		case extensionKeyShare:
			// This extension has different formats in SH and HRR, accept either
			// and let the handshake logic decide. See RFC 8446, Section 4.2.8.
			// A body of exactly two bytes is read as the group-only form into
			// selectedGroup; anything else is read as group plus key data.
			if len(extData) == 2 {
				if !extData.ReadUint16((*uint16)(&m.selectedGroup)) {
					return false
				}
			} else {
				if !extData.ReadUint16((*uint16)(&m.serverShare.group)) ||
					!readUint16LengthPrefixed(&extData, &m.serverShare.data) {
					return false
				}
			}
		case extensionPreSharedKey:
			m.selectedIdentityPresent = true
			if !extData.ReadUint16(&m.selectedIdentity) {
				return false
			}
		case extensionSupportedPoints:
			// RFC 4492, Section 5.1.2
			if !readUint8LengthPrefixed(&extData, &m.supportedPoints) ||
				len(m.supportedPoints) == 0 {
				return false
			}
		default:
			// Ignore unknown extensions. The continue also bypasses the
			// trailing-bytes check below, so their bodies are not inspected.
			continue
		}

		// A recognized extension must have been consumed completely.
		if !extData.Empty() {
			return false
		}
	}

	return true
}
   861  
// encryptedExtensionsMsg is the in-memory form of an EncryptedExtensions
// handshake message. raw caches the serialized message returned by marshal;
// alpnProtocol, when non-empty, is written by marshal as an ALPN extension
// containing that single protocol.
type encryptedExtensionsMsg struct {
	raw          []byte
	alpnProtocol string
}
   866  
   867  func (m *encryptedExtensionsMsg) marshal() ([]byte, error) {
   868  	if m.raw != nil {
   869  		return m.raw, nil
   870  	}
   871  
   872  	var b cryptobyte.Builder
   873  	b.AddUint8(typeEncryptedExtensions)
   874  	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
   875  		b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   876  			if len(m.alpnProtocol) > 0 {
   877  				b.AddUint16(extensionALPN)
   878  				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   879  					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   880  						b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
   881  							b.AddBytes([]byte(m.alpnProtocol))
   882  						})
   883  					})
   884  				})
   885  			}
   886  		})
   887  	})
   888  
   889  	var err error
   890  	m.raw, err = b.Bytes()
   891  	return m.raw, err
   892  }
   893  
   894  func (m *encryptedExtensionsMsg) unmarshal(data []byte) bool {
   895  	*m = encryptedExtensionsMsg{raw: data}
   896  	s := cryptobyte.String(data)
   897  
   898  	var extensions cryptobyte.String
   899  	if !s.Skip(4) || // message type and uint24 length field
   900  		!s.ReadUint16LengthPrefixed(&extensions) || !s.Empty() {
   901  		return false
   902  	}
   903  
   904  	for !extensions.Empty() {
   905  		var extension uint16
   906  		var extData cryptobyte.String
   907  		if !extensions.ReadUint16(&extension) ||
   908  			!extensions.ReadUint16LengthPrefixed(&extData) {
   909  			return false
   910  		}
   911  
   912  		switch extension {
   913  		case extensionALPN:
   914  			var protoList cryptobyte.String
   915  			if !extData.ReadUint16LengthPrefixed(&protoList) || protoList.Empty() {
   916  				return false
   917  			}
   918  			var proto cryptobyte.String
   919  			if !protoList.ReadUint8LengthPrefixed(&proto) ||
   920  				proto.Empty() || !protoList.Empty() {
   921  				return false
   922  			}
   923  			m.alpnProtocol = string(proto)
   924  		default:
   925  			// Ignore unknown extensions.
   926  			continue
   927  		}
   928  
   929  		if !extData.Empty() {
   930  			return false
   931  		}
   932  	}
   933  
   934  	return true
   935  }
   936  
   937  type endOfEarlyDataMsg struct{}
   938  
   939  func (m *endOfEarlyDataMsg) marshal() ([]byte, error) {
   940  	x := make([]byte, 4)
   941  	x[0] = typeEndOfEarlyData
   942  	return x, nil
   943  }
   944  
   945  func (m *endOfEarlyDataMsg) unmarshal(data []byte) bool {
   946  	return len(data) == 4
   947  }
   948  
// keyUpdateMsg holds the single flag carried by a KeyUpdate handshake message.
type keyUpdateMsg struct {
	raw             []byte // wire encoding: cached by marshal or recorded by unmarshal
	updateRequested bool   // encoded on the wire as one byte: 1 for true, 0 for false
}
   953  
   954  func (m *keyUpdateMsg) marshal() ([]byte, error) {
   955  	if m.raw != nil {
   956  		return m.raw, nil
   957  	}
   958  
   959  	var b cryptobyte.Builder
   960  	b.AddUint8(typeKeyUpdate)
   961  	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
   962  		if m.updateRequested {
   963  			b.AddUint8(1)
   964  		} else {
   965  			b.AddUint8(0)
   966  		}
   967  	})
   968  
   969  	var err error
   970  	m.raw, err = b.Bytes()
   971  	return m.raw, err
   972  }
   973  
   974  func (m *keyUpdateMsg) unmarshal(data []byte) bool {
   975  	m.raw = data
   976  	s := cryptobyte.String(data)
   977  
   978  	var updateRequested uint8
   979  	if !s.Skip(4) || // message type and uint24 length field
   980  		!s.ReadUint8(&updateRequested) || !s.Empty() {
   981  		return false
   982  	}
   983  	switch updateRequested {
   984  	case 0:
   985  		m.updateRequested = false
   986  	case 1:
   987  		m.updateRequested = true
   988  	default:
   989  		return false
   990  	}
   991  	return true
   992  }
   993  
// newSessionTicketMsgTLS13 holds the fields of a TLS 1.3 NewSessionTicket
// handshake message, in the order they appear in the message body.
type newSessionTicketMsgTLS13 struct {
	raw          []byte // wire encoding: cached by marshal or recorded by unmarshal
	lifetime     uint32 // first 32-bit field of the body
	ageAdd       uint32 // second 32-bit field of the body
	nonce        []byte // encoded with a one-byte length prefix
	label        []byte // encoded with a two-byte length prefix
	maxEarlyData uint32 // early_data extension value; the extension is omitted when zero
}
  1002  
  1003  func (m *newSessionTicketMsgTLS13) marshal() ([]byte, error) {
  1004  	if m.raw != nil {
  1005  		return m.raw, nil
  1006  	}
  1007  
  1008  	var b cryptobyte.Builder
  1009  	b.AddUint8(typeNewSessionTicket)
  1010  	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
  1011  		b.AddUint32(m.lifetime)
  1012  		b.AddUint32(m.ageAdd)
  1013  		b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
  1014  			b.AddBytes(m.nonce)
  1015  		})
  1016  		b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1017  			b.AddBytes(m.label)
  1018  		})
  1019  
  1020  		b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1021  			if m.maxEarlyData > 0 {
  1022  				b.AddUint16(extensionEarlyData)
  1023  				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1024  					b.AddUint32(m.maxEarlyData)
  1025  				})
  1026  			}
  1027  		})
  1028  	})
  1029  
  1030  	var err error
  1031  	m.raw, err = b.Bytes()
  1032  	return m.raw, err
  1033  }
  1034  
  1035  func (m *newSessionTicketMsgTLS13) unmarshal(data []byte) bool {
  1036  	*m = newSessionTicketMsgTLS13{raw: data}
  1037  	s := cryptobyte.String(data)
  1038  
  1039  	var extensions cryptobyte.String
  1040  	if !s.Skip(4) || // message type and uint24 length field
  1041  		!s.ReadUint32(&m.lifetime) ||
  1042  		!s.ReadUint32(&m.ageAdd) ||
  1043  		!readUint8LengthPrefixed(&s, &m.nonce) ||
  1044  		!readUint16LengthPrefixed(&s, &m.label) ||
  1045  		!s.ReadUint16LengthPrefixed(&extensions) ||
  1046  		!s.Empty() {
  1047  		return false
  1048  	}
  1049  
  1050  	for !extensions.Empty() {
  1051  		var extension uint16
  1052  		var extData cryptobyte.String
  1053  		if !extensions.ReadUint16(&extension) ||
  1054  			!extensions.ReadUint16LengthPrefixed(&extData) {
  1055  			return false
  1056  		}
  1057  
  1058  		switch extension {
  1059  		case extensionEarlyData:
  1060  			if !extData.ReadUint32(&m.maxEarlyData) {
  1061  				return false
  1062  			}
  1063  		default:
  1064  			// Ignore unknown extensions.
  1065  			continue
  1066  		}
  1067  
  1068  		if !extData.Empty() {
  1069  			return false
  1070  		}
  1071  	}
  1072  
  1073  	return true
  1074  }
  1075  
// certificateRequestMsgTLS13 holds the extensions of a TLS 1.3
// CertificateRequest handshake message that this file encodes and decodes.
type certificateRequestMsgTLS13 struct {
	raw                              []byte            // wire encoding: cached by marshal or recorded by unmarshal
	ocspStapling                     bool              // status_request extension present
	scts                             bool              // signed_certificate_timestamp extension present
	supportedSignatureAlgorithms     []SignatureScheme // signature_algorithms list; omitted when empty
	supportedSignatureAlgorithmsCert []SignatureScheme // signature_algorithms_cert list; omitted when empty
	certificateAuthorities           [][]byte          // certificate_authorities entries; omitted when empty
}
  1084  
  1085  func (m *certificateRequestMsgTLS13) marshal() ([]byte, error) {
  1086  	if m.raw != nil {
  1087  		return m.raw, nil
  1088  	}
  1089  
  1090  	var b cryptobyte.Builder
  1091  	b.AddUint8(typeCertificateRequest)
  1092  	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
  1093  		// certificate_request_context (SHALL be zero length unless used for
  1094  		// post-handshake authentication)
  1095  		b.AddUint8(0)
  1096  
  1097  		b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1098  			if m.ocspStapling {
  1099  				b.AddUint16(extensionStatusRequest)
  1100  				b.AddUint16(0) // empty extension_data
  1101  			}
  1102  			if m.scts {
  1103  				// RFC 8446, Section 4.4.2.1 makes no mention of
  1104  				// signed_certificate_timestamp in CertificateRequest, but
  1105  				// "Extensions in the Certificate message from the client MUST
  1106  				// correspond to extensions in the CertificateRequest message
  1107  				// from the server." and it appears in the table in Section 4.2.
  1108  				b.AddUint16(extensionSCT)
  1109  				b.AddUint16(0) // empty extension_data
  1110  			}
  1111  			if len(m.supportedSignatureAlgorithms) > 0 {
  1112  				b.AddUint16(extensionSignatureAlgorithms)
  1113  				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1114  					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1115  						for _, sigAlgo := range m.supportedSignatureAlgorithms {
  1116  							b.AddUint16(uint16(sigAlgo))
  1117  						}
  1118  					})
  1119  				})
  1120  			}
  1121  			if len(m.supportedSignatureAlgorithmsCert) > 0 {
  1122  				b.AddUint16(extensionSignatureAlgorithmsCert)
  1123  				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1124  					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1125  						for _, sigAlgo := range m.supportedSignatureAlgorithmsCert {
  1126  							b.AddUint16(uint16(sigAlgo))
  1127  						}
  1128  					})
  1129  				})
  1130  			}
  1131  			if len(m.certificateAuthorities) > 0 {
  1132  				b.AddUint16(extensionCertificateAuthorities)
  1133  				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1134  					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1135  						for _, ca := range m.certificateAuthorities {
  1136  							b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1137  								b.AddBytes(ca)
  1138  							})
  1139  						}
  1140  					})
  1141  				})
  1142  			}
  1143  		})
  1144  	})
  1145  
  1146  	var err error
  1147  	m.raw, err = b.Bytes()
  1148  	return m.raw, err
  1149  }
  1150  
  1151  func (m *certificateRequestMsgTLS13) unmarshal(data []byte) bool {
  1152  	*m = certificateRequestMsgTLS13{raw: data}
  1153  	s := cryptobyte.String(data)
  1154  
  1155  	var context, extensions cryptobyte.String
  1156  	if !s.Skip(4) || // message type and uint24 length field
  1157  		!s.ReadUint8LengthPrefixed(&context) || !context.Empty() ||
  1158  		!s.ReadUint16LengthPrefixed(&extensions) ||
  1159  		!s.Empty() {
  1160  		return false
  1161  	}
  1162  
  1163  	for !extensions.Empty() {
  1164  		var extension uint16
  1165  		var extData cryptobyte.String
  1166  		if !extensions.ReadUint16(&extension) ||
  1167  			!extensions.ReadUint16LengthPrefixed(&extData) {
  1168  			return false
  1169  		}
  1170  
  1171  		switch extension {
  1172  		case extensionStatusRequest:
  1173  			m.ocspStapling = true
  1174  		case extensionSCT:
  1175  			m.scts = true
  1176  		case extensionSignatureAlgorithms:
  1177  			var sigAndAlgs cryptobyte.String
  1178  			if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() {
  1179  				return false
  1180  			}
  1181  			for !sigAndAlgs.Empty() {
  1182  				var sigAndAlg uint16
  1183  				if !sigAndAlgs.ReadUint16(&sigAndAlg) {
  1184  					return false
  1185  				}
  1186  				m.supportedSignatureAlgorithms = append(
  1187  					m.supportedSignatureAlgorithms, SignatureScheme(sigAndAlg))
  1188  			}
  1189  		case extensionSignatureAlgorithmsCert:
  1190  			var sigAndAlgs cryptobyte.String
  1191  			if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() {
  1192  				return false
  1193  			}
  1194  			for !sigAndAlgs.Empty() {
  1195  				var sigAndAlg uint16
  1196  				if !sigAndAlgs.ReadUint16(&sigAndAlg) {
  1197  					return false
  1198  				}
  1199  				m.supportedSignatureAlgorithmsCert = append(
  1200  					m.supportedSignatureAlgorithmsCert, SignatureScheme(sigAndAlg))
  1201  			}
  1202  		case extensionCertificateAuthorities:
  1203  			var auths cryptobyte.String
  1204  			if !extData.ReadUint16LengthPrefixed(&auths) || auths.Empty() {
  1205  				return false
  1206  			}
  1207  			for !auths.Empty() {
  1208  				var ca []byte
  1209  				if !readUint16LengthPrefixed(&auths, &ca) || len(ca) == 0 {
  1210  					return false
  1211  				}
  1212  				m.certificateAuthorities = append(m.certificateAuthorities, ca)
  1213  			}
  1214  		default:
  1215  			// Ignore unknown extensions.
  1216  			continue
  1217  		}
  1218  
  1219  		if !extData.Empty() {
  1220  			return false
  1221  		}
  1222  	}
  1223  
  1224  	return true
  1225  }
  1226  
// certificateMsg holds a Certificate handshake message in the format that
// carries a bare list of certificates, each with a three-byte length prefix.
type certificateMsg struct {
	raw          []byte   // wire encoding: cached by marshal or recorded by unmarshal
	certificates [][]byte // certificate entries in wire order
}
  1231  
  1232  func (m *certificateMsg) marshal() ([]byte, error) {
  1233  	if m.raw != nil {
  1234  		return m.raw, nil
  1235  	}
  1236  
  1237  	var i int
  1238  	for _, slice := range m.certificates {
  1239  		i += len(slice)
  1240  	}
  1241  
  1242  	length := 3 + 3*len(m.certificates) + i
  1243  	x := make([]byte, 4+length)
  1244  	x[0] = typeCertificate
  1245  	x[1] = uint8(length >> 16)
  1246  	x[2] = uint8(length >> 8)
  1247  	x[3] = uint8(length)
  1248  
  1249  	certificateOctets := length - 3
  1250  	x[4] = uint8(certificateOctets >> 16)
  1251  	x[5] = uint8(certificateOctets >> 8)
  1252  	x[6] = uint8(certificateOctets)
  1253  
  1254  	y := x[7:]
  1255  	for _, slice := range m.certificates {
  1256  		y[0] = uint8(len(slice) >> 16)
  1257  		y[1] = uint8(len(slice) >> 8)
  1258  		y[2] = uint8(len(slice))
  1259  		copy(y[3:], slice)
  1260  		y = y[3+len(slice):]
  1261  	}
  1262  
  1263  	m.raw = x
  1264  	return m.raw, nil
  1265  }
  1266  
  1267  func (m *certificateMsg) unmarshal(data []byte) bool {
  1268  	if len(data) < 7 {
  1269  		return false
  1270  	}
  1271  
  1272  	m.raw = data
  1273  	certsLen := uint32(data[4])<<16 | uint32(data[5])<<8 | uint32(data[6])
  1274  	if uint32(len(data)) != certsLen+7 {
  1275  		return false
  1276  	}
  1277  
  1278  	numCerts := 0
  1279  	d := data[7:]
  1280  	for certsLen > 0 {
  1281  		if len(d) < 4 {
  1282  			return false
  1283  		}
  1284  		certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2])
  1285  		if uint32(len(d)) < 3+certLen {
  1286  			return false
  1287  		}
  1288  		d = d[3+certLen:]
  1289  		certsLen -= 3 + certLen
  1290  		numCerts++
  1291  	}
  1292  
  1293  	m.certificates = make([][]byte, numCerts)
  1294  	d = data[7:]
  1295  	for i := 0; i < numCerts; i++ {
  1296  		certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2])
  1297  		m.certificates[i] = d[3 : 3+certLen]
  1298  		d = d[3+certLen:]
  1299  	}
  1300  
  1301  	return true
  1302  }
  1303  
// certificateMsgTLS13 holds a TLS 1.3 Certificate handshake message.
type certificateMsgTLS13 struct {
	raw          []byte      // wire encoding: cached by marshal or recorded by unmarshal
	certificate  Certificate // certificate chain plus OCSP staple and SCTs
	ocspStapling bool        // marshal drops the OCSP staple when false
	scts         bool        // marshal drops the SCT list when false
}
  1310  
  1311  func (m *certificateMsgTLS13) marshal() ([]byte, error) {
  1312  	if m.raw != nil {
  1313  		return m.raw, nil
  1314  	}
  1315  
  1316  	var b cryptobyte.Builder
  1317  	b.AddUint8(typeCertificate)
  1318  	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
  1319  		b.AddUint8(0) // certificate_request_context
  1320  
  1321  		certificate := m.certificate
  1322  		if !m.ocspStapling {
  1323  			certificate.OCSPStaple = nil
  1324  		}
  1325  		if !m.scts {
  1326  			certificate.SignedCertificateTimestamps = nil
  1327  		}
  1328  		marshalCertificate(b, certificate)
  1329  	})
  1330  
  1331  	var err error
  1332  	m.raw, err = b.Bytes()
  1333  	return m.raw, err
  1334  }
  1335  
  1336  func marshalCertificate(b *cryptobyte.Builder, certificate Certificate) {
  1337  	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
  1338  		for i, cert := range certificate.Certificate {
  1339  			b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
  1340  				b.AddBytes(cert)
  1341  			})
  1342  			b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1343  				if i > 0 {
  1344  					// This library only supports OCSP and SCT for leaf certificates.
  1345  					return
  1346  				}
  1347  				if certificate.OCSPStaple != nil {
  1348  					b.AddUint16(extensionStatusRequest)
  1349  					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1350  						b.AddUint8(statusTypeOCSP)
  1351  						b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
  1352  							b.AddBytes(certificate.OCSPStaple)
  1353  						})
  1354  					})
  1355  				}
  1356  				if certificate.SignedCertificateTimestamps != nil {
  1357  					b.AddUint16(extensionSCT)
  1358  					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1359  						b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1360  							for _, sct := range certificate.SignedCertificateTimestamps {
  1361  								b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1362  									b.AddBytes(sct)
  1363  								})
  1364  							}
  1365  						})
  1366  					})
  1367  				}
  1368  			})
  1369  		}
  1370  	})
  1371  }
  1372  
  1373  func (m *certificateMsgTLS13) unmarshal(data []byte) bool {
  1374  	*m = certificateMsgTLS13{raw: data}
  1375  	s := cryptobyte.String(data)
  1376  
  1377  	var context cryptobyte.String
  1378  	if !s.Skip(4) || // message type and uint24 length field
  1379  		!s.ReadUint8LengthPrefixed(&context) || !context.Empty() ||
  1380  		!unmarshalCertificate(&s, &m.certificate) ||
  1381  		!s.Empty() {
  1382  		return false
  1383  	}
  1384  
  1385  	m.scts = m.certificate.SignedCertificateTimestamps != nil
  1386  	m.ocspStapling = m.certificate.OCSPStaple != nil
  1387  
  1388  	return true
  1389  }
  1390  
  1391  func unmarshalCertificate(s *cryptobyte.String, certificate *Certificate) bool {
  1392  	var certList cryptobyte.String
  1393  	if !s.ReadUint24LengthPrefixed(&certList) {
  1394  		return false
  1395  	}
  1396  	for !certList.Empty() {
  1397  		var cert []byte
  1398  		var extensions cryptobyte.String
  1399  		if !readUint24LengthPrefixed(&certList, &cert) ||
  1400  			!certList.ReadUint16LengthPrefixed(&extensions) {
  1401  			return false
  1402  		}
  1403  		certificate.Certificate = append(certificate.Certificate, cert)
  1404  		for !extensions.Empty() {
  1405  			var extension uint16
  1406  			var extData cryptobyte.String
  1407  			if !extensions.ReadUint16(&extension) ||
  1408  				!extensions.ReadUint16LengthPrefixed(&extData) {
  1409  				return false
  1410  			}
  1411  			if len(certificate.Certificate) > 1 {
  1412  				// This library only supports OCSP and SCT for leaf certificates.
  1413  				continue
  1414  			}
  1415  
  1416  			switch extension {
  1417  			case extensionStatusRequest:
  1418  				var statusType uint8
  1419  				if !extData.ReadUint8(&statusType) || statusType != statusTypeOCSP ||
  1420  					!readUint24LengthPrefixed(&extData, &certificate.OCSPStaple) ||
  1421  					len(certificate.OCSPStaple) == 0 {
  1422  					return false
  1423  				}
  1424  			case extensionSCT:
  1425  				var sctList cryptobyte.String
  1426  				if !extData.ReadUint16LengthPrefixed(&sctList) || sctList.Empty() {
  1427  					return false
  1428  				}
  1429  				for !sctList.Empty() {
  1430  					var sct []byte
  1431  					if !readUint16LengthPrefixed(&sctList, &sct) ||
  1432  						len(sct) == 0 {
  1433  						return false
  1434  					}
  1435  					certificate.SignedCertificateTimestamps = append(
  1436  						certificate.SignedCertificateTimestamps, sct)
  1437  				}
  1438  			default:
  1439  				// Ignore unknown extensions.
  1440  				continue
  1441  			}
  1442  
  1443  			if !extData.Empty() {
  1444  				return false
  1445  			}
  1446  		}
  1447  	}
  1448  	return true
  1449  }
  1450  
// serverKeyExchangeMsg holds a ServerKeyExchange handshake message whose body
// is kept as opaque bytes.
type serverKeyExchangeMsg struct {
	raw []byte // wire encoding: cached by marshal or recorded by unmarshal
	key []byte // message body, i.e. everything after the four-byte header
}
  1455  
  1456  func (m *serverKeyExchangeMsg) marshal() ([]byte, error) {
  1457  	if m.raw != nil {
  1458  		return m.raw, nil
  1459  	}
  1460  	length := len(m.key)
  1461  	x := make([]byte, length+4)
  1462  	x[0] = typeServerKeyExchange
  1463  	x[1] = uint8(length >> 16)
  1464  	x[2] = uint8(length >> 8)
  1465  	x[3] = uint8(length)
  1466  	copy(x[4:], m.key)
  1467  
  1468  	m.raw = x
  1469  	return x, nil
  1470  }
  1471  
  1472  func (m *serverKeyExchangeMsg) unmarshal(data []byte) bool {
  1473  	m.raw = data
  1474  	if len(data) < 4 {
  1475  		return false
  1476  	}
  1477  	m.key = data[4:]
  1478  	return true
  1479  }
  1480  
// certificateStatusMsg holds a CertificateStatus handshake message carrying
// an OCSP response.
type certificateStatusMsg struct {
	raw      []byte // wire encoding: cached by marshal or recorded by unmarshal
	response []byte // response bytes, encoded with a three-byte length prefix
}
  1485  
  1486  func (m *certificateStatusMsg) marshal() ([]byte, error) {
  1487  	if m.raw != nil {
  1488  		return m.raw, nil
  1489  	}
  1490  
  1491  	var b cryptobyte.Builder
  1492  	b.AddUint8(typeCertificateStatus)
  1493  	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
  1494  		b.AddUint8(statusTypeOCSP)
  1495  		b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
  1496  			b.AddBytes(m.response)
  1497  		})
  1498  	})
  1499  
  1500  	var err error
  1501  	m.raw, err = b.Bytes()
  1502  	return m.raw, err
  1503  }
  1504  
  1505  func (m *certificateStatusMsg) unmarshal(data []byte) bool {
  1506  	m.raw = data
  1507  	s := cryptobyte.String(data)
  1508  
  1509  	var statusType uint8
  1510  	if !s.Skip(4) || // message type and uint24 length field
  1511  		!s.ReadUint8(&statusType) || statusType != statusTypeOCSP ||
  1512  		!readUint24LengthPrefixed(&s, &m.response) ||
  1513  		len(m.response) == 0 || !s.Empty() {
  1514  		return false
  1515  	}
  1516  	return true
  1517  }
  1518  
  1519  type serverHelloDoneMsg struct{}
  1520  
  1521  func (m *serverHelloDoneMsg) marshal() ([]byte, error) {
  1522  	x := make([]byte, 4)
  1523  	x[0] = typeServerHelloDone
  1524  	return x, nil
  1525  }
  1526  
  1527  func (m *serverHelloDoneMsg) unmarshal(data []byte) bool {
  1528  	return len(data) == 4
  1529  }
  1530  
// clientKeyExchangeMsg holds a ClientKeyExchange handshake message whose body
// is kept as opaque bytes.
type clientKeyExchangeMsg struct {
	raw        []byte // wire encoding: cached by marshal or recorded by unmarshal
	ciphertext []byte // message body, i.e. everything after the four-byte header
}
  1535  
  1536  func (m *clientKeyExchangeMsg) marshal() ([]byte, error) {
  1537  	if m.raw != nil {
  1538  		return m.raw, nil
  1539  	}
  1540  	length := len(m.ciphertext)
  1541  	x := make([]byte, length+4)
  1542  	x[0] = typeClientKeyExchange
  1543  	x[1] = uint8(length >> 16)
  1544  	x[2] = uint8(length >> 8)
  1545  	x[3] = uint8(length)
  1546  	copy(x[4:], m.ciphertext)
  1547  
  1548  	m.raw = x
  1549  	return x, nil
  1550  }
  1551  
  1552  func (m *clientKeyExchangeMsg) unmarshal(data []byte) bool {
  1553  	m.raw = data
  1554  	if len(data) < 4 {
  1555  		return false
  1556  	}
  1557  	l := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
  1558  	if l != len(data)-4 {
  1559  		return false
  1560  	}
  1561  	m.ciphertext = data[4:]
  1562  	return true
  1563  }
  1564  
// finishedMsg holds a Finished handshake message.
type finishedMsg struct {
	raw        []byte // wire encoding: cached by marshal or recorded by unmarshal
	verifyData []byte // message body, i.e. everything after the four-byte header
}
  1569  
  1570  func (m *finishedMsg) marshal() ([]byte, error) {
  1571  	if m.raw != nil {
  1572  		return m.raw, nil
  1573  	}
  1574  
  1575  	var b cryptobyte.Builder
  1576  	b.AddUint8(typeFinished)
  1577  	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
  1578  		b.AddBytes(m.verifyData)
  1579  	})
  1580  
  1581  	var err error
  1582  	m.raw, err = b.Bytes()
  1583  	return m.raw, err
  1584  }
  1585  
  1586  func (m *finishedMsg) unmarshal(data []byte) bool {
  1587  	m.raw = data
  1588  	s := cryptobyte.String(data)
  1589  	return s.Skip(1) &&
  1590  		readUint24LengthPrefixed(&s, &m.verifyData) &&
  1591  		s.Empty()
  1592  }
  1593  
// certificateRequestMsg holds a CertificateRequest handshake message in the
// format used before TLS 1.3.
type certificateRequestMsg struct {
	raw []byte
	// hasSignatureAlgorithm indicates whether this message includes a list of
	// supported signature algorithms. This change was introduced with TLS 1.2.
	// It must be set before unmarshal is called, because it selects which
	// layout unmarshal expects.
	hasSignatureAlgorithm bool

	certificateTypes             []byte            // one byte per type, one-byte count prefix
	supportedSignatureAlgorithms []SignatureScheme // only encoded when hasSignatureAlgorithm is set
	certificateAuthorities       [][]byte          // each entry has a two-byte length prefix
}
  1604  
  1605  func (m *certificateRequestMsg) marshal() ([]byte, error) {
  1606  	if m.raw != nil {
  1607  		return m.raw, nil
  1608  	}
  1609  
  1610  	// See RFC 4346, Section 7.4.4.
  1611  	length := 1 + len(m.certificateTypes) + 2
  1612  	casLength := 0
  1613  	for _, ca := range m.certificateAuthorities {
  1614  		casLength += 2 + len(ca)
  1615  	}
  1616  	length += casLength
  1617  
  1618  	if m.hasSignatureAlgorithm {
  1619  		length += 2 + 2*len(m.supportedSignatureAlgorithms)
  1620  	}
  1621  
  1622  	x := make([]byte, 4+length)
  1623  	x[0] = typeCertificateRequest
  1624  	x[1] = uint8(length >> 16)
  1625  	x[2] = uint8(length >> 8)
  1626  	x[3] = uint8(length)
  1627  
  1628  	x[4] = uint8(len(m.certificateTypes))
  1629  
  1630  	copy(x[5:], m.certificateTypes)
  1631  	y := x[5+len(m.certificateTypes):]
  1632  
  1633  	if m.hasSignatureAlgorithm {
  1634  		n := len(m.supportedSignatureAlgorithms) * 2
  1635  		y[0] = uint8(n >> 8)
  1636  		y[1] = uint8(n)
  1637  		y = y[2:]
  1638  		for _, sigAlgo := range m.supportedSignatureAlgorithms {
  1639  			y[0] = uint8(sigAlgo >> 8)
  1640  			y[1] = uint8(sigAlgo)
  1641  			y = y[2:]
  1642  		}
  1643  	}
  1644  
  1645  	y[0] = uint8(casLength >> 8)
  1646  	y[1] = uint8(casLength)
  1647  	y = y[2:]
  1648  	for _, ca := range m.certificateAuthorities {
  1649  		y[0] = uint8(len(ca) >> 8)
  1650  		y[1] = uint8(len(ca))
  1651  		y = y[2:]
  1652  		copy(y, ca)
  1653  		y = y[len(ca):]
  1654  	}
  1655  
  1656  	m.raw = x
  1657  	return m.raw, nil
  1658  }
  1659  
  1660  func (m *certificateRequestMsg) unmarshal(data []byte) bool {
  1661  	m.raw = data
  1662  
  1663  	if len(data) < 5 {
  1664  		return false
  1665  	}
  1666  
  1667  	length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3])
  1668  	if uint32(len(data))-4 != length {
  1669  		return false
  1670  	}
  1671  
  1672  	numCertTypes := int(data[4])
  1673  	data = data[5:]
  1674  	if numCertTypes == 0 || len(data) <= numCertTypes {
  1675  		return false
  1676  	}
  1677  
  1678  	m.certificateTypes = make([]byte, numCertTypes)
  1679  	if copy(m.certificateTypes, data) != numCertTypes {
  1680  		return false
  1681  	}
  1682  
  1683  	data = data[numCertTypes:]
  1684  
  1685  	if m.hasSignatureAlgorithm {
  1686  		if len(data) < 2 {
  1687  			return false
  1688  		}
  1689  		sigAndHashLen := uint16(data[0])<<8 | uint16(data[1])
  1690  		data = data[2:]
  1691  		if sigAndHashLen&1 != 0 {
  1692  			return false
  1693  		}
  1694  		if len(data) < int(sigAndHashLen) {
  1695  			return false
  1696  		}
  1697  		numSigAlgos := sigAndHashLen / 2
  1698  		m.supportedSignatureAlgorithms = make([]SignatureScheme, numSigAlgos)
  1699  		for i := range m.supportedSignatureAlgorithms {
  1700  			m.supportedSignatureAlgorithms[i] = SignatureScheme(data[0])<<8 | SignatureScheme(data[1])
  1701  			data = data[2:]
  1702  		}
  1703  	}
  1704  
  1705  	if len(data) < 2 {
  1706  		return false
  1707  	}
  1708  	casLength := uint16(data[0])<<8 | uint16(data[1])
  1709  	data = data[2:]
  1710  	if len(data) < int(casLength) {
  1711  		return false
  1712  	}
  1713  	cas := make([]byte, casLength)
  1714  	copy(cas, data)
  1715  	data = data[casLength:]
  1716  
  1717  	m.certificateAuthorities = nil
  1718  	for len(cas) > 0 {
  1719  		if len(cas) < 2 {
  1720  			return false
  1721  		}
  1722  		caLen := uint16(cas[0])<<8 | uint16(cas[1])
  1723  		cas = cas[2:]
  1724  
  1725  		if len(cas) < int(caLen) {
  1726  			return false
  1727  		}
  1728  
  1729  		m.certificateAuthorities = append(m.certificateAuthorities, cas[:caLen])
  1730  		cas = cas[caLen:]
  1731  	}
  1732  
  1733  	return len(data) == 0
  1734  }
  1735  
// certificateVerifyMsg holds a CertificateVerify handshake message.
// hasSignatureAlgorithm must be set before unmarshal is called, because it
// selects which layout unmarshal expects.
type certificateVerifyMsg struct {
	raw                   []byte
	hasSignatureAlgorithm bool // format change introduced in TLS 1.2
	signatureAlgorithm    SignatureScheme
	signature             []byte
}
  1742  
  1743  func (m *certificateVerifyMsg) marshal() ([]byte, error) {
  1744  	if m.raw != nil {
  1745  		return m.raw, nil
  1746  	}
  1747  
  1748  	var b cryptobyte.Builder
  1749  	b.AddUint8(typeCertificateVerify)
  1750  	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
  1751  		if m.hasSignatureAlgorithm {
  1752  			b.AddUint16(uint16(m.signatureAlgorithm))
  1753  		}
  1754  		b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1755  			b.AddBytes(m.signature)
  1756  		})
  1757  	})
  1758  
  1759  	var err error
  1760  	m.raw, err = b.Bytes()
  1761  	return m.raw, err
  1762  }
  1763  
  1764  func (m *certificateVerifyMsg) unmarshal(data []byte) bool {
  1765  	m.raw = data
  1766  	s := cryptobyte.String(data)
  1767  
  1768  	if !s.Skip(4) { // message type and uint24 length field
  1769  		return false
  1770  	}
  1771  	if m.hasSignatureAlgorithm {
  1772  		if !s.ReadUint16((*uint16)(&m.signatureAlgorithm)) {
  1773  			return false
  1774  		}
  1775  	}
  1776  	return readUint16LengthPrefixed(&s, &m.signature) && s.Empty()
  1777  }
  1778  
// newSessionTicketMsg holds a NewSessionTicket handshake message in the
// format used before TLS 1.3. Only the ticket bytes are represented.
type newSessionTicketMsg struct {
	raw    []byte // wire encoding: cached by marshal or recorded by unmarshal
	ticket []byte // ticket bytes, encoded with a two-byte length prefix
}
  1783  
  1784  func (m *newSessionTicketMsg) marshal() ([]byte, error) {
  1785  	if m.raw != nil {
  1786  		return m.raw, nil
  1787  	}
  1788  
  1789  	// See RFC 5077, Section 3.3.
  1790  	ticketLen := len(m.ticket)
  1791  	length := 2 + 4 + ticketLen
  1792  	x := make([]byte, 4+length)
  1793  	x[0] = typeNewSessionTicket
  1794  	x[1] = uint8(length >> 16)
  1795  	x[2] = uint8(length >> 8)
  1796  	x[3] = uint8(length)
  1797  	x[8] = uint8(ticketLen >> 8)
  1798  	x[9] = uint8(ticketLen)
  1799  	copy(x[10:], m.ticket)
  1800  
  1801  	m.raw = x
  1802  
  1803  	return m.raw, nil
  1804  }
  1805  
  1806  func (m *newSessionTicketMsg) unmarshal(data []byte) bool {
  1807  	m.raw = data
  1808  
  1809  	if len(data) < 10 {
  1810  		return false
  1811  	}
  1812  
  1813  	length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3])
  1814  	if uint32(len(data))-4 != length {
  1815  		return false
  1816  	}
  1817  
  1818  	ticketLen := int(data[8])<<8 + int(data[9])
  1819  	if len(data)-10 != ticketLen {
  1820  		return false
  1821  	}
  1822  
  1823  	m.ticket = data[10:]
  1824  
  1825  	return true
  1826  }
  1827  
  1828  type helloRequestMsg struct {
  1829  }
  1830  
  1831  func (*helloRequestMsg) marshal() ([]byte, error) {
  1832  	return []byte{typeHelloRequest, 0, 0, 0}, nil
  1833  }
  1834  
  1835  func (*helloRequestMsg) unmarshal(data []byte) bool {
  1836  	return len(data) == 4
  1837  }
  1838  
// transcriptHash is the destination that transcriptMsg writes marshaled
// handshake messages to.
type transcriptHash interface {
	Write([]byte) (int, error)
}
  1842  
  1843  // transcriptMsg is a helper used to marshal and hash messages which typically
  1844  // are not written to the wire, and as such aren't hashed during Conn.writeRecord.
  1845  func transcriptMsg(msg handshakeMessage, h transcriptHash) error {
  1846  	data, err := msg.marshal()
  1847  	if err != nil {
  1848  		return err
  1849  	}
  1850  	h.Write(data)
  1851  	return nil
  1852  }