github.com/goproxy0/go@v0.0.0-20171111080102-49cc0c489d2c/src/crypto/tls/handshake_messages.go (about)

     1  // Copyright 2009 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package tls
     6  
     7  import (
     8  	"bytes"
     9  	"strings"
    10  )
    11  
// clientHelloMsg holds the fields of a ClientHello handshake message.
// unmarshal fills it from wire bytes and marshal encodes it back, caching the
// encoding in raw. Note that marshal does not encode psks or
// pskKeyExchangeModes; those are only populated by unmarshal.
type clientHelloMsg struct {
	// raw is the complete encoded message, including the 4-byte handshake header.
	raw                          []byte
	// rawTruncated is set by unmarshal to raw cut just before the
	// pre_shared_key binders list.
	rawTruncated                 []byte // for PSK binding
	vers                         uint16
	// random is the 32-byte client random.
	random                       []byte
	// sessionId is at most 32 bytes (unmarshal rejects longer values).
	sessionId                    []byte
	cipherSuites                 []uint16
	compressionMethods           []uint8
	// nextProtoNeg records an (empty) NPN extension.
	nextProtoNeg                 bool
	// serverName is the host_name entry of the server_name (SNI) extension.
	serverName                   string
	// ocspStapling records a status_request extension of OCSP type.
	ocspStapling                 bool
	// scts records an (empty) SCT extension.
	scts                         bool
	supportedCurves              []CurveID
	supportedPoints              []uint8
	// ticketSupported records a session_ticket extension; sessionTicket is
	// its (possibly empty) contents.
	ticketSupported              bool
	sessionTicket                []uint8
	supportedSignatureAlgorithms []SignatureScheme
	// secureRenegotiation is the renegotiation_info payload.
	// secureRenegotiationSupported is set by that extension or by the
	// renegotiation SCSV in cipherSuites.
	secureRenegotiation          []byte
	secureRenegotiationSupported bool
	alpnProtocols                []string
	keyShares                    []keyShare
	supportedVersions            []uint16
	// psks holds the identities, obfuscated ticket ages and binders parsed
	// from the pre_shared_key extension.
	psks                         []psk
	pskKeyExchangeModes          []uint8
	// earlyData records an early_data extension.
	earlyData                    bool
}
    38  
    39  func (m *clientHelloMsg) equal(i interface{}) bool {
    40  	m1, ok := i.(*clientHelloMsg)
    41  	if !ok {
    42  		return false
    43  	}
    44  
    45  	return bytes.Equal(m.raw, m1.raw) &&
    46  		m.vers == m1.vers &&
    47  		bytes.Equal(m.random, m1.random) &&
    48  		bytes.Equal(m.sessionId, m1.sessionId) &&
    49  		eqUint16s(m.cipherSuites, m1.cipherSuites) &&
    50  		bytes.Equal(m.compressionMethods, m1.compressionMethods) &&
    51  		m.nextProtoNeg == m1.nextProtoNeg &&
    52  		m.serverName == m1.serverName &&
    53  		m.ocspStapling == m1.ocspStapling &&
    54  		m.scts == m1.scts &&
    55  		eqCurveIDs(m.supportedCurves, m1.supportedCurves) &&
    56  		bytes.Equal(m.supportedPoints, m1.supportedPoints) &&
    57  		m.ticketSupported == m1.ticketSupported &&
    58  		bytes.Equal(m.sessionTicket, m1.sessionTicket) &&
    59  		eqSignatureAlgorithms(m.supportedSignatureAlgorithms, m1.supportedSignatureAlgorithms) &&
    60  		m.secureRenegotiationSupported == m1.secureRenegotiationSupported &&
    61  		bytes.Equal(m.secureRenegotiation, m1.secureRenegotiation) &&
    62  		eqStrings(m.alpnProtocols, m1.alpnProtocols) &&
    63  		eqKeyShares(m.keyShares, m1.keyShares) &&
    64  		eqUint16s(m.supportedVersions, m1.supportedVersions) &&
    65  		m.earlyData == m1.earlyData
    66  }
    67  
// marshal encodes m as a ClientHello handshake message and caches the result
// in m.raw. If m.raw is already set it is returned unchanged, so a caller that
// changes fields after a previous marshal/unmarshal must clear m.raw first.
//
// The output is a 4-byte handshake header (typeClientHello followed by a
// 24-bit body length) and then the ClientHello body. The psks and
// pskKeyExchangeModes fields are not encoded. marshal panics if any ALPN
// protocol name is empty or longer than 255 bytes.
func (m *clientHelloMsg) marshal() []byte {
	if m.raw != nil {
		return m.raw
	}

	// First pass: size the body. The fixed part is version(2) + random(32) +
	// session_id length(1) + session_id + cipher_suites length(2) + suites +
	// compression_methods length(1) + methods.
	length := 2 + 32 + 1 + len(m.sessionId) + 2 + len(m.cipherSuites)*2 + 1 + len(m.compressionMethods)
	// extensionsLength accumulates the extension_data size of each extension.
	// The 4-byte type/length header of every extension is added at the end,
	// once all extensions have been counted in numExtensions.
	numExtensions := 0
	extensionsLength := 0
	if m.nextProtoNeg {
		// Empty extension_data.
		numExtensions++
	}
	if m.ocspStapling {
		// status_type(1) + two 2-byte length fields.
		extensionsLength += 1 + 2 + 2
		numExtensions++
	}
	if len(m.serverName) > 0 {
		// list length(2) + name_type(1) + name length(2) + name.
		extensionsLength += 5 + len(m.serverName)
		numExtensions++
	}
	if len(m.supportedCurves) > 0 {
		extensionsLength += 2 + 2*len(m.supportedCurves)
		numExtensions++
	}
	if len(m.supportedPoints) > 0 {
		extensionsLength += 1 + len(m.supportedPoints)
		numExtensions++
	}
	if m.ticketSupported {
		// The ticket is written with no inner length prefix.
		extensionsLength += len(m.sessionTicket)
		numExtensions++
	}
	if len(m.supportedSignatureAlgorithms) > 0 {
		extensionsLength += 2 + 2*len(m.supportedSignatureAlgorithms)
		numExtensions++
	}
	if m.secureRenegotiationSupported {
		extensionsLength += 1 + len(m.secureRenegotiation)
		numExtensions++
	}
	if len(m.alpnProtocols) > 0 {
		extensionsLength += 2
		for _, s := range m.alpnProtocols {
			// Each name has a one-byte length prefix and must be non-empty.
			if l := len(s); l == 0 || l > 255 {
				panic("invalid ALPN protocol")
			}
			extensionsLength++
			extensionsLength += len(s)
		}
		numExtensions++
	}
	if m.scts {
		// Empty extension_data.
		numExtensions++
	}
	if len(m.keyShares) > 0 {
		extensionsLength += 2
		for _, k := range m.keyShares {
			// group(2) + key_exchange length(2) + key_exchange.
			extensionsLength += 4 + len(k.data)
		}
		numExtensions++
	}
	if len(m.supportedVersions) > 0 {
		// One-byte list length followed by 2 bytes per version.
		extensionsLength += 1 + 2*len(m.supportedVersions)
		numExtensions++
	}
	if m.earlyData {
		// Empty extension_data.
		numExtensions++
	}
	if numExtensions > 0 {
		extensionsLength += 4 * numExtensions
		length += 2 + extensionsLength
	}

	// Second pass: fill the buffer. x is the whole message; y and z are
	// cursors that advance past each field as it is written. make zero-fills
	// x, so any byte not written explicitly below is 0.
	x := make([]byte, 4+length)
	x[0] = typeClientHello
	x[1] = uint8(length >> 16)
	x[2] = uint8(length >> 8)
	x[3] = uint8(length)
	x[4] = uint8(m.vers >> 8)
	x[5] = uint8(m.vers)
	copy(x[6:38], m.random)
	x[38] = uint8(len(m.sessionId))
	copy(x[39:39+len(m.sessionId)], m.sessionId)
	y := x[39+len(m.sessionId):]
	// cipher_suites length in bytes (2*count), as high and low bytes.
	y[0] = uint8(len(m.cipherSuites) >> 7)
	y[1] = uint8(len(m.cipherSuites) << 1)
	for i, suite := range m.cipherSuites {
		y[2+i*2] = uint8(suite >> 8)
		y[3+i*2] = uint8(suite)
	}
	z := y[2+len(m.cipherSuites)*2:]
	z[0] = uint8(len(m.compressionMethods))
	copy(z[1:], m.compressionMethods)

	z = z[1+len(m.compressionMethods):]
	if numExtensions > 0 {
		z[0] = byte(extensionsLength >> 8)
		z[1] = byte(extensionsLength)
		z = z[2:]
	}
	if m.nextProtoNeg {
		z[0] = byte(extensionNextProtoNeg >> 8)
		z[1] = byte(extensionNextProtoNeg & 0xff)
		// The length is always 0
		z = z[4:]
	}
	if len(m.serverName) > 0 {
		z[0] = byte(extensionServerName >> 8)
		z[1] = byte(extensionServerName & 0xff)
		l := len(m.serverName) + 5
		z[2] = byte(l >> 8)
		z[3] = byte(l)
		z = z[4:]

		// RFC 3546, section 3.1
		//
		// struct {
		//     NameType name_type;
		//     select (name_type) {
		//         case host_name: HostName;
		//     } name;
		// } ServerName;
		//
		// enum {
		//     host_name(0), (255)
		// } NameType;
		//
		// opaque HostName<1..2^16-1>;
		//
		// struct {
		//     ServerName server_name_list<1..2^16-1>
		// } ServerNameList;

		z[0] = byte((len(m.serverName) + 3) >> 8)
		z[1] = byte(len(m.serverName) + 3)
		// z[2] (name_type) stays 0, i.e. host_name.
		z[3] = byte(len(m.serverName) >> 8)
		z[4] = byte(len(m.serverName))
		copy(z[5:], []byte(m.serverName))
		z = z[l:]
	}
	if m.ocspStapling {
		// RFC 4366, section 3.6
		z[0] = byte(extensionStatusRequest >> 8)
		z[1] = byte(extensionStatusRequest)
		z[2] = 0
		z[3] = 5
		z[4] = 1 // OCSP type
		// Two zero valued uint16s for the two lengths.
		z = z[9:]
	}
	if len(m.supportedCurves) > 0 {
		// http://tools.ietf.org/html/rfc4492#section-5.5.1
		// https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.4
		z[0] = byte(extensionSupportedCurves >> 8)
		z[1] = byte(extensionSupportedCurves)
		l := 2 + 2*len(m.supportedCurves)
		z[2] = byte(l >> 8)
		z[3] = byte(l)
		// Inner list length excludes its own 2-byte prefix.
		l -= 2
		z[4] = byte(l >> 8)
		z[5] = byte(l)
		z = z[6:]
		for _, curve := range m.supportedCurves {
			z[0] = byte(curve >> 8)
			z[1] = byte(curve)
			z = z[2:]
		}
	}
	if len(m.supportedPoints) > 0 {
		// http://tools.ietf.org/html/rfc4492#section-5.5.2
		z[0] = byte(extensionSupportedPoints >> 8)
		z[1] = byte(extensionSupportedPoints)
		l := 1 + len(m.supportedPoints)
		z[2] = byte(l >> 8)
		z[3] = byte(l)
		// One-byte inner list length.
		l--
		z[4] = byte(l)
		z = z[5:]
		for _, pointFormat := range m.supportedPoints {
			z[0] = pointFormat
			z = z[1:]
		}
	}
	if m.ticketSupported {
		// http://tools.ietf.org/html/rfc5077#section-3.2
		z[0] = byte(extensionSessionTicket >> 8)
		z[1] = byte(extensionSessionTicket)
		l := len(m.sessionTicket)
		z[2] = byte(l >> 8)
		z[3] = byte(l)
		z = z[4:]
		copy(z, m.sessionTicket)
		z = z[len(m.sessionTicket):]
	}
	if len(m.supportedSignatureAlgorithms) > 0 {
		// https://tools.ietf.org/html/rfc5246#section-7.4.1.4.1
		// https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.3
		z[0] = byte(extensionSignatureAlgorithms >> 8)
		z[1] = byte(extensionSignatureAlgorithms)
		l := 2 + 2*len(m.supportedSignatureAlgorithms)
		z[2] = byte(l >> 8)
		z[3] = byte(l)
		z = z[4:]

		// Inner list length excludes its own 2-byte prefix.
		l -= 2
		z[0] = byte(l >> 8)
		z[1] = byte(l)
		z = z[2:]
		for _, sigAlgo := range m.supportedSignatureAlgorithms {
			z[0] = byte(sigAlgo >> 8)
			z[1] = byte(sigAlgo)
			z = z[2:]
		}
	}
	if m.secureRenegotiationSupported {
		z[0] = byte(extensionRenegotiationInfo >> 8)
		z[1] = byte(extensionRenegotiationInfo & 0xff)
		// Extension length (high byte 0) and one-byte payload length.
		z[2] = 0
		z[3] = byte(len(m.secureRenegotiation) + 1)
		z[4] = byte(len(m.secureRenegotiation))
		z = z[5:]
		copy(z, m.secureRenegotiation)
		z = z[len(m.secureRenegotiation):]
	}
	if len(m.alpnProtocols) > 0 {
		z[0] = byte(extensionALPN >> 8)
		z[1] = byte(extensionALPN & 0xff)
		// Both length fields are filled in after the names are written.
		lengths := z[2:]
		z = z[6:]

		stringsLength := 0
		for _, s := range m.alpnProtocols {
			l := len(s)
			z[0] = byte(l)
			copy(z[1:], s)
			z = z[1+l:]
			stringsLength += 1 + l
		}

		lengths[2] = byte(stringsLength >> 8)
		lengths[3] = byte(stringsLength)
		stringsLength += 2
		lengths[0] = byte(stringsLength >> 8)
		lengths[1] = byte(stringsLength)
	}
	if m.scts {
		// https://tools.ietf.org/html/rfc6962#section-3.3.1
		z[0] = byte(extensionSCT >> 8)
		z[1] = byte(extensionSCT)
		// zero uint16 for the zero-length extension_data
		z = z[4:]
	}
	if len(m.keyShares) > 0 {
		// https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.5
		z[0] = byte(extensionKeyShare >> 8)
		z[1] = byte(extensionKeyShare)
		// Both length fields are filled in after the entries are written.
		lengths := z[2:]
		z = z[6:]

		totalLength := 0
		for _, ks := range m.keyShares {
			z[0] = byte(ks.group >> 8)
			z[1] = byte(ks.group)
			z[2] = byte(len(ks.data) >> 8)
			z[3] = byte(len(ks.data))
			copy(z[4:], ks.data)
			z = z[4+len(ks.data):]
			totalLength += 4 + len(ks.data)
		}

		lengths[2] = byte(totalLength >> 8)
		lengths[3] = byte(totalLength)
		totalLength += 2
		lengths[0] = byte(totalLength >> 8)
		lengths[1] = byte(totalLength)
	}
	if len(m.supportedVersions) > 0 {
		z[0] = byte(extensionSupportedVersions >> 8)
		z[1] = byte(extensionSupportedVersions)
		l := 1 + 2*len(m.supportedVersions)
		z[2] = byte(l >> 8)
		z[3] = byte(l)
		// One-byte inner list length.
		l -= 1
		z[4] = byte(l)
		z = z[5:]
		for _, v := range m.supportedVersions {
			z[0] = byte(v >> 8)
			z[1] = byte(v)
			z = z[2:]
		}
	}
	if m.earlyData {
		// Empty extension_data.
		z[0] = byte(extensionEarlyData >> 8)
		z[1] = byte(extensionEarlyData)
		z = z[4:]
	}

	m.raw = x

	return x
}
   368  
// unmarshal parses data, a complete ClientHello handshake message including
// its 4-byte header, into m. It returns alertSuccess when the message is well
// formed. Otherwise it returns:
//   - alertDecodeError for malformed lengths;
//   - alertIllegalParameter for pre_shared_key placement or
//     identity/binder count mismatches;
//   - alertUnexpectedMessage for an SNI host name ending in a dot.
//
// The handshake type and length bytes (data[0:4]) are not checked here.
// Most byte-slice fields of m alias data rather than copying it.
// Unrecognized extensions are skipped.
//
// bindersOffset tracks the absolute offset of the parse cursor within the
// original data. When a pre_shared_key extension is parsed, it is used to set
// m.rawTruncated to the prefix of the message that ends just before the
// binders list.
func (m *clientHelloMsg) unmarshal(data []byte) alert {
	// Minimum size: header(4) + version(2) + random(32) + session_id
	// length(1) + cipher_suites length(2) + compression_methods length(1).
	if len(data) < 42 {
		return alertDecodeError
	}
	m.raw = data
	m.vers = uint16(data[4])<<8 | uint16(data[5])
	m.random = data[6:38]
	sessionIdLen := int(data[38])
	if sessionIdLen > 32 || len(data) < 39+sessionIdLen {
		return alertDecodeError
	}
	m.sessionId = data[39 : 39+sessionIdLen]
	data = data[39+sessionIdLen:]
	bindersOffset := 39 + sessionIdLen
	if len(data) < 2 {
		return alertDecodeError
	}
	// cipherSuiteLen is the number of bytes of cipher suite numbers. Since
	// they are uint16s, the number must be even.
	cipherSuiteLen := int(data[0])<<8 | int(data[1])
	if cipherSuiteLen%2 == 1 || len(data) < 2+cipherSuiteLen {
		return alertDecodeError
	}
	numCipherSuites := cipherSuiteLen / 2
	m.cipherSuites = make([]uint16, numCipherSuites)
	for i := 0; i < numCipherSuites; i++ {
		m.cipherSuites[i] = uint16(data[2+2*i])<<8 | uint16(data[3+2*i])
		// The renegotiation SCSV sets the same flag as the
		// renegotiation_info extension below.
		if m.cipherSuites[i] == scsvRenegotiation {
			m.secureRenegotiationSupported = true
		}
	}
	data = data[2+cipherSuiteLen:]
	bindersOffset += 2 + cipherSuiteLen
	if len(data) < 1 {
		return alertDecodeError
	}
	compressionMethodsLen := int(data[0])
	if len(data) < 1+compressionMethodsLen {
		return alertDecodeError
	}
	m.compressionMethods = data[1 : 1+compressionMethodsLen]

	data = data[1+compressionMethodsLen:]
	bindersOffset += 1 + compressionMethodsLen

	// Clear extension-derived fields so that absent extensions read as unset.
	// NOTE(review): supportedCurves, supportedPoints, secureRenegotiation,
	// secureRenegotiationSupported and rawTruncated are not reset here and
	// keep any earlier value if m is reused.
	m.nextProtoNeg = false
	m.serverName = ""
	m.ocspStapling = false
	m.ticketSupported = false
	m.sessionTicket = nil
	m.supportedSignatureAlgorithms = nil
	m.alpnProtocols = nil
	m.scts = false
	m.keyShares = nil
	m.supportedVersions = nil
	m.psks = nil
	m.pskKeyExchangeModes = nil
	m.earlyData = false

	if len(data) == 0 {
		// ClientHello is optionally followed by extension data
		return alertSuccess
	}
	if len(data) < 2 {
		return alertDecodeError
	}

	extensionsLength := int(data[0])<<8 | int(data[1])
	data = data[2:]
	bindersOffset += 2
	if extensionsLength != len(data) {
		return alertDecodeError
	}

	for len(data) != 0 {
		// Each extension: type(2) + length(2) + extension_data.
		if len(data) < 4 {
			return alertDecodeError
		}
		extension := uint16(data[0])<<8 | uint16(data[1])
		length := int(data[2])<<8 | int(data[3])
		data = data[4:]
		bindersOffset += 4
		if len(data) < length {
			return alertDecodeError
		}

		switch extension {
		case extensionServerName:
			d := data[:length]
			if len(d) < 2 {
				return alertDecodeError
			}
			namesLen := int(d[0])<<8 | int(d[1])
			d = d[2:]
			if len(d) != namesLen {
				return alertDecodeError
			}
			for len(d) > 0 {
				if len(d) < 3 {
					return alertDecodeError
				}
				nameType := d[0]
				nameLen := int(d[1])<<8 | int(d[2])
				d = d[3:]
				if len(d) < nameLen {
					return alertDecodeError
				}
				if nameType == 0 {
					m.serverName = string(d[:nameLen])
					// An SNI value may not include a
					// trailing dot. See
					// https://tools.ietf.org/html/rfc6066#section-3.
					if strings.HasSuffix(m.serverName, ".") {
						// TODO use alertDecodeError?
						return alertUnexpectedMessage
					}
					// Stop at the first host_name entry.
					break
				}
				d = d[nameLen:]
			}
		case extensionNextProtoNeg:
			if length > 0 {
				return alertDecodeError
			}
			m.nextProtoNeg = true
		case extensionStatusRequest:
			m.ocspStapling = length > 0 && data[0] == statusTypeOCSP
		case extensionSupportedCurves:
			// http://tools.ietf.org/html/rfc4492#section-5.5.1
			// https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.4
			if length < 2 {
				return alertDecodeError
			}
			l := int(data[0])<<8 | int(data[1])
			if l%2 == 1 || length != l+2 {
				return alertDecodeError
			}
			numCurves := l / 2
			m.supportedCurves = make([]CurveID, numCurves)
			d := data[2:]
			for i := 0; i < numCurves; i++ {
				m.supportedCurves[i] = CurveID(d[0])<<8 | CurveID(d[1])
				d = d[2:]
			}
		case extensionSupportedPoints:
			// http://tools.ietf.org/html/rfc4492#section-5.5.2
			if length < 1 {
				return alertDecodeError
			}
			l := int(data[0])
			if length != l+1 {
				return alertDecodeError
			}
			m.supportedPoints = make([]uint8, l)
			copy(m.supportedPoints, data[1:])
		case extensionSessionTicket:
			// http://tools.ietf.org/html/rfc5077#section-3.2
			m.ticketSupported = true
			m.sessionTicket = data[:length]
		case extensionSignatureAlgorithms:
			// https://tools.ietf.org/html/rfc5246#section-7.4.1.4.1
			// https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.3
			if length < 2 || length&1 != 0 {
				return alertDecodeError
			}
			l := int(data[0])<<8 | int(data[1])
			if l != length-2 {
				return alertDecodeError
			}
			n := l / 2
			d := data[2:]
			m.supportedSignatureAlgorithms = make([]SignatureScheme, n)
			for i := range m.supportedSignatureAlgorithms {
				m.supportedSignatureAlgorithms[i] = SignatureScheme(d[0])<<8 | SignatureScheme(d[1])
				d = d[2:]
			}
		case extensionRenegotiationInfo:
			if length == 0 {
				return alertDecodeError
			}
			d := data[:length]
			l := int(d[0])
			d = d[1:]
			if l != len(d) {
				return alertDecodeError
			}

			m.secureRenegotiation = d
			m.secureRenegotiationSupported = true
		case extensionALPN:
			if length < 2 {
				return alertDecodeError
			}
			l := int(data[0])<<8 | int(data[1])
			if l != length-2 {
				return alertDecodeError
			}
			d := data[2:length]
			for len(d) != 0 {
				stringLen := int(d[0])
				d = d[1:]
				if stringLen == 0 || stringLen > len(d) {
					return alertDecodeError
				}
				m.alpnProtocols = append(m.alpnProtocols, string(d[:stringLen]))
				d = d[stringLen:]
			}
		case extensionSCT:
			m.scts = true
			if length != 0 {
				return alertDecodeError
			}
		case extensionKeyShare:
			// https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.5
			if length < 2 {
				return alertDecodeError
			}
			l := int(data[0])<<8 | int(data[1])
			if l != length-2 {
				return alertDecodeError
			}
			d := data[2:length]
			for len(d) != 0 {
				// group(2) + key_exchange length(2) + non-empty key_exchange.
				if len(d) < 4 {
					return alertDecodeError
				}
				dataLen := int(d[2])<<8 | int(d[3])
				if dataLen == 0 || 4+dataLen > len(d) {
					return alertDecodeError
				}
				m.keyShares = append(m.keyShares, keyShare{
					group: CurveID(d[0])<<8 | CurveID(d[1]),
					data:  d[4 : 4+dataLen],
				})
				d = d[4+dataLen:]
			}
		case extensionSupportedVersions:
			// https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.1
			if length < 1 {
				return alertDecodeError
			}
			l := int(data[0])
			if l%2 == 1 || length != l+1 {
				return alertDecodeError
			}
			n := l / 2
			d := data[1:]
			for i := 0; i < n; i++ {
				v := uint16(d[0])<<8 + uint16(d[1])
				m.supportedVersions = append(m.supportedVersions, v)
				d = d[2:]
			}
		case extensionPreSharedKey:
			// https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.6
			if length < 2 {
				return alertDecodeError
			}
			// Ensure this extension is the last one in the Client Hello
			if len(data) != length {
				return alertIllegalParameter
			}
			// identities length(2) + identities, then binders length(2) + binders.
			li := int(data[0])<<8 | int(data[1])
			if 2+li+2 > length {
				return alertDecodeError
			}
			d := data[2 : 2+li]
			bindersOffset += 2 + li
			for len(d) > 0 {
				// identity length(2) + identity + obfuscated_ticket_age(4).
				if len(d) < 6 {
					return alertDecodeError
				}
				l := int(d[0])<<8 | int(d[1])
				if len(d) < 2+l+4 {
					return alertDecodeError
				}
				m.psks = append(m.psks, psk{
					identity: d[2 : 2+l],
					obfTicketAge: uint32(d[l+2])<<24 | uint32(d[l+3])<<16 |
						uint32(d[l+4])<<8 | uint32(d[l+5]),
				})
				d = d[2+l+4:]
			}
			lb := int(data[li+2])<<8 | int(data[li+3])
			d = data[2+li+2:]
			if lb != len(d) || lb == 0 {
				return alertDecodeError
			}
			// Assign one 1-byte-length-prefixed binder to each identity, in
			// order; the counts must match exactly.
			// NOTE(review): the len(d) < 1 check is unreachable under the loop
			// condition, and the second i >= len(m.psks) check repeats the
			// first; both are harmless.
			i := 0
			for len(d) > 0 {
				if i >= len(m.psks) {
					return alertIllegalParameter
				}
				if len(d) < 1 {
					return alertDecodeError
				}
				l := int(d[0])
				if l > len(d)-1 {
					return alertDecodeError
				}
				if i >= len(m.psks) {
					return alertIllegalParameter
				}
				m.psks[i].binder = d[1 : 1+l]
				d = d[1+l:]
				i++
			}
			if i != len(m.psks) {
				return alertIllegalParameter
			}
			// bindersOffset now points at the binders length field, so this
			// prefix covers everything up to, but excluding, the binders list.
			m.rawTruncated = m.raw[:bindersOffset]
		case extensionPSKKeyExchangeModes:
			// One-byte list length followed by at least one mode.
			if length < 2 {
				return alertDecodeError
			}
			l := int(data[0])
			if length != l+1 {
				return alertDecodeError
			}
			m.pskKeyExchangeModes = data[1:length]
		case extensionEarlyData:
			// https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.8
			// NOTE(review): the extension body is not checked for emptiness.
			m.earlyData = true
		}
		data = data[length:]
		bindersOffset += length
	}

	return alertSuccess
}
   698  
// serverHelloMsg holds the fields of a ServerHello handshake message.
// unmarshal fills it from wire bytes and marshal encodes it back, caching the
// encoding in raw.
type serverHelloMsg struct {
	// raw is the complete encoded message, including the 4-byte handshake header.
	raw                          []byte
	vers                         uint16
	// random is the 32-byte server random.
	random                       []byte
	// sessionId is at most 32 bytes (unmarshal rejects longer values).
	sessionId                    []byte
	cipherSuite                  uint16
	compressionMethod            uint8
	// nextProtoNeg records an NPN extension; nextProtos is its protocol list.
	nextProtoNeg                 bool
	nextProtos                   []string
	// ocspStapling records an empty status_request extension.
	ocspStapling                 bool
	// scts holds the entries of the SCT extension.
	scts                         [][]byte
	// ticketSupported records an empty session_ticket extension.
	ticketSupported              bool
	// secureRenegotiation is the renegotiation_info payload; the bool
	// records that the extension was present.
	secureRenegotiation          []byte
	secureRenegotiationSupported bool
	// alpnProtocol is the single protocol carried in the ALPN extension.
	alpnProtocol                 string
}
   715  
   716  func (m *serverHelloMsg) equal(i interface{}) bool {
   717  	m1, ok := i.(*serverHelloMsg)
   718  	if !ok {
   719  		return false
   720  	}
   721  
   722  	if len(m.scts) != len(m1.scts) {
   723  		return false
   724  	}
   725  	for i, sct := range m.scts {
   726  		if !bytes.Equal(sct, m1.scts[i]) {
   727  			return false
   728  		}
   729  	}
   730  
   731  	return bytes.Equal(m.raw, m1.raw) &&
   732  		m.vers == m1.vers &&
   733  		bytes.Equal(m.random, m1.random) &&
   734  		bytes.Equal(m.sessionId, m1.sessionId) &&
   735  		m.cipherSuite == m1.cipherSuite &&
   736  		m.compressionMethod == m1.compressionMethod &&
   737  		m.nextProtoNeg == m1.nextProtoNeg &&
   738  		eqStrings(m.nextProtos, m1.nextProtos) &&
   739  		m.ocspStapling == m1.ocspStapling &&
   740  		m.ticketSupported == m1.ticketSupported &&
   741  		m.secureRenegotiationSupported == m1.secureRenegotiationSupported &&
   742  		bytes.Equal(m.secureRenegotiation, m1.secureRenegotiation) &&
   743  		m.alpnProtocol == m1.alpnProtocol
   744  }
   745  
   746  func (m *serverHelloMsg) marshal() []byte {
   747  	if m.raw != nil {
   748  		return m.raw
   749  	}
   750  
   751  	length := 38 + len(m.sessionId)
   752  	numExtensions := 0
   753  	extensionsLength := 0
   754  
   755  	nextProtoLen := 0
   756  	if m.nextProtoNeg {
   757  		numExtensions++
   758  		for _, v := range m.nextProtos {
   759  			nextProtoLen += len(v)
   760  		}
   761  		nextProtoLen += len(m.nextProtos)
   762  		extensionsLength += nextProtoLen
   763  	}
   764  	if m.ocspStapling {
   765  		numExtensions++
   766  	}
   767  	if m.ticketSupported {
   768  		numExtensions++
   769  	}
   770  	if m.secureRenegotiationSupported {
   771  		extensionsLength += 1 + len(m.secureRenegotiation)
   772  		numExtensions++
   773  	}
   774  	if alpnLen := len(m.alpnProtocol); alpnLen > 0 {
   775  		if alpnLen >= 256 {
   776  			panic("invalid ALPN protocol")
   777  		}
   778  		extensionsLength += 2 + 1 + alpnLen
   779  		numExtensions++
   780  	}
   781  	sctLen := 0
   782  	if len(m.scts) > 0 {
   783  		for _, sct := range m.scts {
   784  			sctLen += len(sct) + 2
   785  		}
   786  		extensionsLength += 2 + sctLen
   787  		numExtensions++
   788  	}
   789  
   790  	if numExtensions > 0 {
   791  		extensionsLength += 4 * numExtensions
   792  		length += 2 + extensionsLength
   793  	}
   794  
   795  	x := make([]byte, 4+length)
   796  	x[0] = typeServerHello
   797  	x[1] = uint8(length >> 16)
   798  	x[2] = uint8(length >> 8)
   799  	x[3] = uint8(length)
   800  	x[4] = uint8(m.vers >> 8)
   801  	x[5] = uint8(m.vers)
   802  	copy(x[6:38], m.random)
   803  	x[38] = uint8(len(m.sessionId))
   804  	copy(x[39:39+len(m.sessionId)], m.sessionId)
   805  	z := x[39+len(m.sessionId):]
   806  	z[0] = uint8(m.cipherSuite >> 8)
   807  	z[1] = uint8(m.cipherSuite)
   808  	z[2] = m.compressionMethod
   809  
   810  	z = z[3:]
   811  	if numExtensions > 0 {
   812  		z[0] = byte(extensionsLength >> 8)
   813  		z[1] = byte(extensionsLength)
   814  		z = z[2:]
   815  	}
   816  	if m.nextProtoNeg {
   817  		z[0] = byte(extensionNextProtoNeg >> 8)
   818  		z[1] = byte(extensionNextProtoNeg & 0xff)
   819  		z[2] = byte(nextProtoLen >> 8)
   820  		z[3] = byte(nextProtoLen)
   821  		z = z[4:]
   822  
   823  		for _, v := range m.nextProtos {
   824  			l := len(v)
   825  			if l > 255 {
   826  				l = 255
   827  			}
   828  			z[0] = byte(l)
   829  			copy(z[1:], []byte(v[0:l]))
   830  			z = z[1+l:]
   831  		}
   832  	}
   833  	if m.ocspStapling {
   834  		z[0] = byte(extensionStatusRequest >> 8)
   835  		z[1] = byte(extensionStatusRequest)
   836  		z = z[4:]
   837  	}
   838  	if m.ticketSupported {
   839  		z[0] = byte(extensionSessionTicket >> 8)
   840  		z[1] = byte(extensionSessionTicket)
   841  		z = z[4:]
   842  	}
   843  	if m.secureRenegotiationSupported {
   844  		z[0] = byte(extensionRenegotiationInfo >> 8)
   845  		z[1] = byte(extensionRenegotiationInfo & 0xff)
   846  		z[2] = 0
   847  		z[3] = byte(len(m.secureRenegotiation) + 1)
   848  		z[4] = byte(len(m.secureRenegotiation))
   849  		z = z[5:]
   850  		copy(z, m.secureRenegotiation)
   851  		z = z[len(m.secureRenegotiation):]
   852  	}
   853  	if alpnLen := len(m.alpnProtocol); alpnLen > 0 {
   854  		z[0] = byte(extensionALPN >> 8)
   855  		z[1] = byte(extensionALPN & 0xff)
   856  		l := 2 + 1 + alpnLen
   857  		z[2] = byte(l >> 8)
   858  		z[3] = byte(l)
   859  		l -= 2
   860  		z[4] = byte(l >> 8)
   861  		z[5] = byte(l)
   862  		l -= 1
   863  		z[6] = byte(l)
   864  		copy(z[7:], []byte(m.alpnProtocol))
   865  		z = z[7+alpnLen:]
   866  	}
   867  	if sctLen > 0 {
   868  		z[0] = byte(extensionSCT >> 8)
   869  		z[1] = byte(extensionSCT)
   870  		l := sctLen + 2
   871  		z[2] = byte(l >> 8)
   872  		z[3] = byte(l)
   873  		z[4] = byte(sctLen >> 8)
   874  		z[5] = byte(sctLen)
   875  
   876  		z = z[6:]
   877  		for _, sct := range m.scts {
   878  			z[0] = byte(len(sct) >> 8)
   879  			z[1] = byte(len(sct))
   880  			copy(z[2:], sct)
   881  			z = z[len(sct)+2:]
   882  		}
   883  	}
   884  
   885  	m.raw = x
   886  
   887  	return x
   888  }
   889  
   890  func (m *serverHelloMsg) unmarshal(data []byte) alert {
   891  	if len(data) < 42 {
   892  		return alertDecodeError
   893  	}
   894  	m.raw = data
   895  	m.vers = uint16(data[4])<<8 | uint16(data[5])
   896  	m.random = data[6:38]
   897  	sessionIdLen := int(data[38])
   898  	if sessionIdLen > 32 || len(data) < 39+sessionIdLen {
   899  		return alertDecodeError
   900  	}
   901  	m.sessionId = data[39 : 39+sessionIdLen]
   902  	data = data[39+sessionIdLen:]
   903  	if len(data) < 3 {
   904  		return alertDecodeError
   905  	}
   906  	m.cipherSuite = uint16(data[0])<<8 | uint16(data[1])
   907  	m.compressionMethod = data[2]
   908  	data = data[3:]
   909  
   910  	m.nextProtoNeg = false
   911  	m.nextProtos = nil
   912  	m.ocspStapling = false
   913  	m.scts = nil
   914  	m.ticketSupported = false
   915  	m.alpnProtocol = ""
   916  
   917  	if len(data) == 0 {
   918  		// ServerHello is optionally followed by extension data
   919  		return alertSuccess
   920  	}
   921  	if len(data) < 2 {
   922  		return alertDecodeError
   923  	}
   924  
   925  	extensionsLength := int(data[0])<<8 | int(data[1])
   926  	data = data[2:]
   927  	if len(data) != extensionsLength {
   928  		return alertDecodeError
   929  	}
   930  
   931  	for len(data) != 0 {
   932  		if len(data) < 4 {
   933  			return alertDecodeError
   934  		}
   935  		extension := uint16(data[0])<<8 | uint16(data[1])
   936  		length := int(data[2])<<8 | int(data[3])
   937  		data = data[4:]
   938  		if len(data) < length {
   939  			return alertDecodeError
   940  		}
   941  
   942  		switch extension {
   943  		case extensionNextProtoNeg:
   944  			m.nextProtoNeg = true
   945  			d := data[:length]
   946  			for len(d) > 0 {
   947  				l := int(d[0])
   948  				d = d[1:]
   949  				if l == 0 || l > len(d) {
   950  					return alertDecodeError
   951  				}
   952  				m.nextProtos = append(m.nextProtos, string(d[:l]))
   953  				d = d[l:]
   954  			}
   955  		case extensionStatusRequest:
   956  			if length > 0 {
   957  				return alertDecodeError
   958  			}
   959  			m.ocspStapling = true
   960  		case extensionSessionTicket:
   961  			if length > 0 {
   962  				return alertDecodeError
   963  			}
   964  			m.ticketSupported = true
   965  		case extensionRenegotiationInfo:
   966  			if length == 0 {
   967  				return alertDecodeError
   968  			}
   969  			d := data[:length]
   970  			l := int(d[0])
   971  			d = d[1:]
   972  			if l != len(d) {
   973  				return alertDecodeError
   974  			}
   975  
   976  			m.secureRenegotiation = d
   977  			m.secureRenegotiationSupported = true
   978  		case extensionALPN:
   979  			d := data[:length]
   980  			if len(d) < 3 {
   981  				return alertDecodeError
   982  			}
   983  			l := int(d[0])<<8 | int(d[1])
   984  			if l != len(d)-2 {
   985  				return alertDecodeError
   986  			}
   987  			d = d[2:]
   988  			l = int(d[0])
   989  			if l != len(d)-1 {
   990  				return alertDecodeError
   991  			}
   992  			d = d[1:]
   993  			if len(d) == 0 {
   994  				// ALPN protocols must not be empty.
   995  				return alertDecodeError
   996  			}
   997  			m.alpnProtocol = string(d)
   998  		case extensionSCT:
   999  			d := data[:length]
  1000  
  1001  			if len(d) < 2 {
  1002  				return alertDecodeError
  1003  			}
  1004  			l := int(d[0])<<8 | int(d[1])
  1005  			d = d[2:]
  1006  			if len(d) != l || l == 0 {
  1007  				return alertDecodeError
  1008  			}
  1009  
  1010  			m.scts = make([][]byte, 0, 3)
  1011  			for len(d) != 0 {
  1012  				if len(d) < 2 {
  1013  					return alertDecodeError
  1014  				}
  1015  				sctLen := int(d[0])<<8 | int(d[1])
  1016  				d = d[2:]
  1017  				if sctLen == 0 || len(d) < sctLen {
  1018  					return alertDecodeError
  1019  				}
  1020  				m.scts = append(m.scts, d[:sctLen])
  1021  				d = d[sctLen:]
  1022  			}
  1023  		}
  1024  		data = data[length:]
  1025  	}
  1026  
  1027  	return alertSuccess
  1028  }
  1029  
// serverHelloMsg13 is the TLS 1.3 ServerHello handshake message. Its fixed
// part is only version, random and cipher suite; the selected key share and
// PSK are conveyed as extensions.
type serverHelloMsg13 struct {
	raw         []byte // full encoding, including the 4-byte handshake header
	vers        uint16
	random      []byte
	cipherSuite uint16
	keyShare    keyShare // selected key share; group 0 means no key_share extension
	psk         bool     // whether a pre_shared_key extension is present
	pskIdentity uint16   // PSK identity value carried in the pre_shared_key extension
}
  1039  
  1040  func (m *serverHelloMsg13) equal(i interface{}) bool {
  1041  	m1, ok := i.(*serverHelloMsg13)
  1042  	if !ok {
  1043  		return false
  1044  	}
  1045  
  1046  	return bytes.Equal(m.raw, m1.raw) &&
  1047  		m.vers == m1.vers &&
  1048  		bytes.Equal(m.random, m1.random) &&
  1049  		m.cipherSuite == m1.cipherSuite &&
  1050  		m.keyShare.group == m1.keyShare.group &&
  1051  		bytes.Equal(m.keyShare.data, m1.keyShare.data) &&
  1052  		m.psk == m1.psk &&
  1053  		m.pskIdentity == m1.pskIdentity
  1054  }
  1055  
  1056  func (m *serverHelloMsg13) marshal() []byte {
  1057  	if m.raw != nil {
  1058  		return m.raw
  1059  	}
  1060  
  1061  	length := 38
  1062  	if m.keyShare.group != 0 {
  1063  		length += 8 + len(m.keyShare.data)
  1064  	}
  1065  	if m.psk {
  1066  		length += 6
  1067  	}
  1068  
  1069  	x := make([]byte, 4+length)
  1070  	x[0] = typeServerHello
  1071  	x[1] = uint8(length >> 16)
  1072  	x[2] = uint8(length >> 8)
  1073  	x[3] = uint8(length)
  1074  	x[4] = uint8(m.vers >> 8)
  1075  	x[5] = uint8(m.vers)
  1076  	copy(x[6:38], m.random)
  1077  	x[38] = uint8(m.cipherSuite >> 8)
  1078  	x[39] = uint8(m.cipherSuite)
  1079  
  1080  	z := x[42:]
  1081  	x[40] = uint8(len(z) >> 8)
  1082  	x[41] = uint8(len(z))
  1083  
  1084  	if m.psk {
  1085  		z[0] = byte(extensionPreSharedKey >> 8)
  1086  		z[1] = byte(extensionPreSharedKey)
  1087  		z[3] = 2
  1088  		z[4] = byte(m.pskIdentity >> 8)
  1089  		z[5] = byte(m.pskIdentity)
  1090  		z = z[6:]
  1091  	}
  1092  
  1093  	if m.keyShare.group != 0 {
  1094  		z[0] = uint8(extensionKeyShare >> 8)
  1095  		z[1] = uint8(extensionKeyShare)
  1096  		l := 4 + len(m.keyShare.data)
  1097  		z[2] = uint8(l >> 8)
  1098  		z[3] = uint8(l)
  1099  		z[4] = uint8(m.keyShare.group >> 8)
  1100  		z[5] = uint8(m.keyShare.group)
  1101  		l -= 4
  1102  		z[6] = uint8(l >> 8)
  1103  		z[7] = uint8(l)
  1104  		copy(z[8:], m.keyShare.data)
  1105  	}
  1106  
  1107  	m.raw = x
  1108  	return x
  1109  }
  1110  
  1111  func (m *serverHelloMsg13) unmarshal(data []byte) alert {
  1112  	if len(data) < 50 {
  1113  		return alertDecodeError
  1114  	}
  1115  	m.raw = data
  1116  	m.vers = uint16(data[4])<<8 | uint16(data[5])
  1117  	m.random = data[6:38]
  1118  	m.cipherSuite = uint16(data[38])<<8 | uint16(data[39])
  1119  	m.psk = false
  1120  	m.pskIdentity = 0
  1121  
  1122  	extensionsLength := int(data[40])<<8 | int(data[41])
  1123  	data = data[42:]
  1124  	if len(data) != extensionsLength {
  1125  		return alertDecodeError
  1126  	}
  1127  
  1128  	for len(data) != 0 {
  1129  		if len(data) < 4 {
  1130  			return alertDecodeError
  1131  		}
  1132  		extension := uint16(data[0])<<8 | uint16(data[1])
  1133  		length := int(data[2])<<8 | int(data[3])
  1134  		data = data[4:]
  1135  		if len(data) < length {
  1136  			return alertDecodeError
  1137  		}
  1138  
  1139  		switch extension {
  1140  		default:
  1141  			return alertDecodeError
  1142  		case extensionPreSharedKey:
  1143  			if length != 2 {
  1144  				return alertDecodeError
  1145  			}
  1146  			m.psk = true
  1147  			m.pskIdentity = uint16(data[0])<<8 | uint16(data[1])
  1148  		case extensionKeyShare:
  1149  			if length < 2 {
  1150  				return alertDecodeError
  1151  			}
  1152  			m.keyShare.group = CurveID(data[0])<<8 | CurveID(data[1])
  1153  			if length-4 != int(data[2])<<8|int(data[3]) {
  1154  				return alertDecodeError
  1155  			}
  1156  			m.keyShare.data = data[4:length]
  1157  		}
  1158  		data = data[length:]
  1159  	}
  1160  
  1161  	return alertSuccess
  1162  }
  1163  
// encryptedExtensionsMsg is the TLS 1.3 EncryptedExtensions handshake
// message. Only the ALPN and early_data extensions are represented.
type encryptedExtensionsMsg struct {
	raw          []byte
	alpnProtocol string // negotiated ALPN protocol; empty means the extension is absent
	earlyData    bool   // whether an early_data extension is present
}
  1169  
  1170  func (m *encryptedExtensionsMsg) equal(i interface{}) bool {
  1171  	m1, ok := i.(*encryptedExtensionsMsg)
  1172  	if !ok {
  1173  		return false
  1174  	}
  1175  
  1176  	return bytes.Equal(m.raw, m1.raw) &&
  1177  		m.alpnProtocol == m1.alpnProtocol &&
  1178  		m.earlyData == m1.earlyData
  1179  }
  1180  
  1181  func (m *encryptedExtensionsMsg) marshal() []byte {
  1182  	if m.raw != nil {
  1183  		return m.raw
  1184  	}
  1185  
  1186  	length := 2
  1187  
  1188  	if m.earlyData {
  1189  		length += 4
  1190  	}
  1191  	alpnLen := len(m.alpnProtocol)
  1192  	if alpnLen > 0 {
  1193  		if alpnLen >= 256 {
  1194  			panic("invalid ALPN protocol")
  1195  		}
  1196  		length += 2 + 2 + 2 + 1 + alpnLen
  1197  	}
  1198  
  1199  	x := make([]byte, 4+length)
  1200  	x[0] = typeEncryptedExtensions
  1201  	x[1] = uint8(length >> 16)
  1202  	x[2] = uint8(length >> 8)
  1203  	x[3] = uint8(length)
  1204  	length -= 2
  1205  	x[4] = uint8(length >> 8)
  1206  	x[5] = uint8(length)
  1207  
  1208  	z := x[6:]
  1209  	if alpnLen > 0 {
  1210  		z[0] = byte(extensionALPN >> 8)
  1211  		z[1] = byte(extensionALPN)
  1212  		l := 2 + 1 + alpnLen
  1213  		z[2] = byte(l >> 8)
  1214  		z[3] = byte(l)
  1215  		l -= 2
  1216  		z[4] = byte(l >> 8)
  1217  		z[5] = byte(l)
  1218  		l -= 1
  1219  		z[6] = byte(l)
  1220  		copy(z[7:], []byte(m.alpnProtocol))
  1221  		z = z[7+alpnLen:]
  1222  	}
  1223  
  1224  	if m.earlyData {
  1225  		z[0] = byte(extensionEarlyData >> 8)
  1226  		z[1] = byte(extensionEarlyData)
  1227  		z = z[4:]
  1228  	}
  1229  
  1230  	m.raw = x
  1231  	return x
  1232  }
  1233  
  1234  func (m *encryptedExtensionsMsg) unmarshal(data []byte) alert {
  1235  	if len(data) < 6 {
  1236  		return alertDecodeError
  1237  	}
  1238  	m.raw = data
  1239  
  1240  	m.alpnProtocol = ""
  1241  	m.earlyData = false
  1242  
  1243  	extensionsLength := int(data[4])<<8 | int(data[5])
  1244  	data = data[6:]
  1245  	if len(data) != extensionsLength {
  1246  		return alertDecodeError
  1247  	}
  1248  
  1249  	for len(data) != 0 {
  1250  		if len(data) < 4 {
  1251  			return alertDecodeError
  1252  		}
  1253  		extension := uint16(data[0])<<8 | uint16(data[1])
  1254  		length := int(data[2])<<8 | int(data[3])
  1255  		data = data[4:]
  1256  		if len(data) < length {
  1257  			return alertDecodeError
  1258  		}
  1259  
  1260  		switch extension {
  1261  		case extensionALPN:
  1262  			d := data[:length]
  1263  			if len(d) < 3 {
  1264  				return alertDecodeError
  1265  			}
  1266  			l := int(d[0])<<8 | int(d[1])
  1267  			if l != len(d)-2 {
  1268  				return alertDecodeError
  1269  			}
  1270  			d = d[2:]
  1271  			l = int(d[0])
  1272  			if l != len(d)-1 {
  1273  				return alertDecodeError
  1274  			}
  1275  			d = d[1:]
  1276  			if len(d) == 0 {
  1277  				// ALPN protocols must not be empty.
  1278  				return alertDecodeError
  1279  			}
  1280  			m.alpnProtocol = string(d)
  1281  		case extensionEarlyData:
  1282  			// https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.8
  1283  			m.earlyData = true
  1284  		}
  1285  
  1286  		data = data[length:]
  1287  	}
  1288  
  1289  	return alertSuccess
  1290  }
  1291  
// certificateMsg is the Certificate handshake message in its non-TLS-1.3
// form: a length-prefixed list of opaque certificate byte strings. Unlike
// certificateMsg13, it has no request context or per-entry extensions.
type certificateMsg struct {
	raw          []byte
	certificates [][]byte
}
  1296  
  1297  func (m *certificateMsg) equal(i interface{}) bool {
  1298  	m1, ok := i.(*certificateMsg)
  1299  	if !ok {
  1300  		return false
  1301  	}
  1302  
  1303  	return bytes.Equal(m.raw, m1.raw) &&
  1304  		eqByteSlices(m.certificates, m1.certificates)
  1305  }
  1306  
  1307  func (m *certificateMsg) marshal() (x []byte) {
  1308  	if m.raw != nil {
  1309  		return m.raw
  1310  	}
  1311  
  1312  	var i int
  1313  	for _, slice := range m.certificates {
  1314  		i += len(slice)
  1315  	}
  1316  
  1317  	length := 3 + 3*len(m.certificates) + i
  1318  	x = make([]byte, 4+length)
  1319  	x[0] = typeCertificate
  1320  	x[1] = uint8(length >> 16)
  1321  	x[2] = uint8(length >> 8)
  1322  	x[3] = uint8(length)
  1323  
  1324  	certificateOctets := length - 3
  1325  	x[4] = uint8(certificateOctets >> 16)
  1326  	x[5] = uint8(certificateOctets >> 8)
  1327  	x[6] = uint8(certificateOctets)
  1328  
  1329  	y := x[7:]
  1330  	for _, slice := range m.certificates {
  1331  		y[0] = uint8(len(slice) >> 16)
  1332  		y[1] = uint8(len(slice) >> 8)
  1333  		y[2] = uint8(len(slice))
  1334  		copy(y[3:], slice)
  1335  		y = y[3+len(slice):]
  1336  	}
  1337  
  1338  	m.raw = x
  1339  	return
  1340  }
  1341  
  1342  func (m *certificateMsg) unmarshal(data []byte) alert {
  1343  	if len(data) < 7 {
  1344  		return alertDecodeError
  1345  	}
  1346  
  1347  	m.raw = data
  1348  	certsLen := uint32(data[4])<<16 | uint32(data[5])<<8 | uint32(data[6])
  1349  	if uint32(len(data)) != certsLen+7 {
  1350  		return alertDecodeError
  1351  	}
  1352  
  1353  	numCerts := 0
  1354  	d := data[7:]
  1355  	for certsLen > 0 {
  1356  		if len(d) < 4 {
  1357  			return alertDecodeError
  1358  		}
  1359  		certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2])
  1360  		if uint32(len(d)) < 3+certLen {
  1361  			return alertDecodeError
  1362  		}
  1363  		d = d[3+certLen:]
  1364  		certsLen -= 3 + certLen
  1365  		numCerts++
  1366  	}
  1367  
  1368  	m.certificates = make([][]byte, numCerts)
  1369  	d = data[7:]
  1370  	for i := 0; i < numCerts; i++ {
  1371  		certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2])
  1372  		m.certificates[i] = d[3 : 3+certLen]
  1373  		d = d[3+certLen:]
  1374  	}
  1375  
  1376  	return alertSuccess
  1377  }
  1378  
// certificateEntry is one element of a TLS 1.3 certificate list: the
// certificate bytes plus the per-certificate extensions this file encodes
// and decodes.
type certificateEntry struct {
	data       []byte
	ocspStaple []byte   // OCSP response from the status_request extension; nil if absent
	sctList    [][]byte // SCTs from the SCT extension (extensionSCT); nil if absent
}
  1384  
// certificateMsg13 is the TLS 1.3 Certificate handshake message. Compared to
// certificateMsg, it adds a certificate request context and per-certificate
// extensions (see certificateEntry).
type certificateMsg13 struct {
	raw            []byte
	requestContext []byte
	certificates   []certificateEntry
}
  1390  
  1391  func (m *certificateMsg13) equal(i interface{}) bool {
  1392  	m1, ok := i.(*certificateMsg13)
  1393  	if !ok {
  1394  		return false
  1395  	}
  1396  
  1397  	if len(m.certificates) != len(m1.certificates) {
  1398  		return false
  1399  	}
  1400  	for i, _ := range m.certificates {
  1401  		ok := bytes.Equal(m.certificates[i].data, m1.certificates[i].data)
  1402  		ok = ok && bytes.Equal(m.certificates[i].ocspStaple, m1.certificates[i].ocspStaple)
  1403  		ok = ok && eqByteSlices(m.certificates[i].sctList, m1.certificates[i].sctList)
  1404  		if !ok {
  1405  			return false
  1406  		}
  1407  	}
  1408  
  1409  	return bytes.Equal(m.raw, m1.raw) &&
  1410  		bytes.Equal(m.requestContext, m1.requestContext)
  1411  }
  1412  
// marshal encodes m as a TLS 1.3 Certificate handshake message and caches the
// result in m.raw; an already-cached encoding is returned unchanged.
//
// The body is:
//   - a 1-byte request context length and the context bytes;
//   - a 3-byte certificate list length;
//   - one entry per certificate: a 3-byte length, the certificate bytes, and
//     a 2-byte extensions length.
//
// The extensions are then written: a status_request extension carrying the
// OCSP staple and an SCT extension carrying the SCT list, each only if
// non-empty. Length fields whose values depend on later content are
// back-patched after that content is written.
func (m *certificateMsg13) marshal() (x []byte) {
	if m.raw != nil {
		return m.raw
	}

	// i accumulates each entry's certificate bytes and extension bytes
	// (excluding the fixed per-entry length fields added below).
	var i int
	for _, cert := range m.certificates {
		i += len(cert.data)
		if len(cert.ocspStaple) != 0 {
			// type (2) + length (2) + status type (1) + response length (3)
			i += 8 + len(cert.ocspStaple)
		}
		if len(cert.sctList) != 0 {
			// type (2) + length (2) + SCT list length (2)
			i += 6
			for _, sct := range cert.sctList {
				i += 2 + len(sct)
			}
		}
	}

	// list length (3) + per-entry cert length (3) and extensions length (2)
	// + request context length byte and context.
	length := 3 + 3*len(m.certificates) + i
	length += 2 * len(m.certificates) // extensions
	length += 1 + len(m.requestContext)
	x = make([]byte, 4+length)
	x[0] = typeCertificate
	x[1] = uint8(length >> 16)
	x[2] = uint8(length >> 8)
	x[3] = uint8(length)

	z := x[4:]

	// NOTE(review): a request context longer than 255 bytes would have its
	// length byte truncated here while the full context is still copied.
	z[0] = byte(len(m.requestContext))
	copy(z[1:], m.requestContext)
	z = z[1+len(m.requestContext):]

	// Everything after the 3-byte list length belongs to the certificate list.
	certificateOctets := len(z) - 3
	z[0] = uint8(certificateOctets >> 16)
	z[1] = uint8(certificateOctets >> 8)
	z[2] = uint8(certificateOctets)

	z = z[3:]
	for _, cert := range m.certificates {
		z[0] = uint8(len(cert.data) >> 16)
		z[1] = uint8(len(cert.data) >> 8)
		z[2] = uint8(len(cert.data))
		copy(z[3:], cert.data)
		z = z[3+len(cert.data):]

		// Reserve the 2-byte extensions length; it is filled in after this
		// entry's extensions have been written.
		extLenPos := z[:2]
		z = z[2:]

		extensionLen := 0
		if len(cert.ocspStaple) != 0 {
			// Extension data = status type (1) + response length (3) + response.
			stapleLen := 4 + len(cert.ocspStaple)
			z[0] = uint8(extensionStatusRequest >> 8)
			z[1] = uint8(extensionStatusRequest)
			z[2] = uint8(stapleLen >> 8)
			z[3] = uint8(stapleLen)

			stapleLen -= 4
			z[4] = statusTypeOCSP
			z[5] = uint8(stapleLen >> 16)
			z[6] = uint8(stapleLen >> 8)
			z[7] = uint8(stapleLen)
			copy(z[8:], cert.ocspStaple)
			z = z[8+stapleLen:]

			extensionLen += 8 + stapleLen
		}
		if len(cert.sctList) != 0 {
			z[0] = uint8(extensionSCT >> 8)
			z[1] = uint8(extensionSCT)
			// sctLenPos covers the 2-byte extension length and the 2-byte
			// SCT list length; both are back-patched after the SCTs.
			sctLenPos := z[2:6]
			z = z[6:]
			extensionLen += 6

			sctLen := 2
			for _, sct := range cert.sctList {
				z[0] = uint8(len(sct) >> 8)
				z[1] = uint8(len(sct))
				copy(z[2:], sct)
				z = z[2+len(sct):]

				extensionLen += 2 + len(sct)
				sctLen += 2 + len(sct)
			}
			sctLenPos[0] = uint8(sctLen >> 8)
			sctLenPos[1] = uint8(sctLen)
			sctLen -= 2
			sctLenPos[2] = uint8(sctLen >> 8)
			sctLenPos[3] = uint8(sctLen)
		}
		extLenPos[0] = uint8(extensionLen >> 8)
		extLenPos[1] = uint8(extensionLen)
	}

	m.raw = x
	return
}
  1511  
  1512  func (m *certificateMsg13) unmarshal(data []byte) alert {
  1513  	if len(data) < 5 {
  1514  		return alertDecodeError
  1515  	}
  1516  
  1517  	m.raw = data
  1518  
  1519  	ctxLen := data[4]
  1520  	if len(data) < int(ctxLen)+5+3 {
  1521  		return alertDecodeError
  1522  	}
  1523  	m.requestContext = data[5 : 5+ctxLen]
  1524  
  1525  	d := data[5+ctxLen:]
  1526  	certsLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2])
  1527  	if uint32(len(d)) != certsLen+3 {
  1528  		return alertDecodeError
  1529  	}
  1530  
  1531  	numCerts := 0
  1532  	d = d[3:]
  1533  	for certsLen > 0 {
  1534  		if len(d) < 4 {
  1535  			return alertDecodeError
  1536  		}
  1537  		certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2])
  1538  		if uint32(len(d)) < 3+certLen {
  1539  			return alertDecodeError
  1540  		}
  1541  		d = d[3+certLen:]
  1542  
  1543  		if len(d) < 2 {
  1544  			return alertDecodeError
  1545  		}
  1546  		extLen := uint16(d[0])<<8 | uint16(d[1])
  1547  		if uint16(len(d)) < 2+extLen {
  1548  			return alertDecodeError
  1549  		}
  1550  		d = d[2+extLen:]
  1551  
  1552  		certsLen -= 3 + certLen + 2 + uint32(extLen)
  1553  		numCerts++
  1554  	}
  1555  
  1556  	m.certificates = make([]certificateEntry, numCerts)
  1557  	d = data[8+ctxLen:]
  1558  	for i := 0; i < numCerts; i++ {
  1559  		certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2])
  1560  		m.certificates[i].data = d[3 : 3+certLen]
  1561  		d = d[3+certLen:]
  1562  
  1563  		extLen := uint16(d[0])<<8 | uint16(d[1])
  1564  		d = d[2:]
  1565  		for extLen > 0 {
  1566  			if extLen < 4 {
  1567  				return alertDecodeError
  1568  			}
  1569  			typ := uint16(d[0])<<8 | uint16(d[1])
  1570  			bodyLen := uint16(d[2])<<8 | uint16(d[3])
  1571  			if extLen < 4+bodyLen {
  1572  				return alertDecodeError
  1573  			}
  1574  			body := d[4 : 4+bodyLen]
  1575  			d = d[4+bodyLen:]
  1576  			extLen -= 4 + bodyLen
  1577  
  1578  			switch typ {
  1579  			case extensionStatusRequest:
  1580  				if len(body) < 4 || body[0] != 0x01 {
  1581  					return alertDecodeError
  1582  				}
  1583  				ocspLen := int(body[1])<<16 | int(body[2])<<8 | int(body[3])
  1584  				if len(body) != 4+ocspLen {
  1585  					return alertDecodeError
  1586  				}
  1587  				m.certificates[i].ocspStaple = body[4:]
  1588  
  1589  			case extensionSCT:
  1590  				if len(body) < 2 {
  1591  					return alertDecodeError
  1592  				}
  1593  				listLen := int(body[0])<<8 | int(body[1])
  1594  				body = body[2:]
  1595  				if len(body) != listLen {
  1596  					return alertDecodeError
  1597  				}
  1598  				for len(body) > 0 {
  1599  					if len(body) < 2 {
  1600  						return alertDecodeError
  1601  					}
  1602  					sctLen := int(body[0])<<8 | int(body[1])
  1603  					if len(body) < 2+sctLen {
  1604  						return alertDecodeError
  1605  					}
  1606  					m.certificates[i].sctList = append(m.certificates[i].sctList, body[2:2+sctLen])
  1607  					body = body[2+sctLen:]
  1608  				}
  1609  			}
  1610  		}
  1611  	}
  1612  
  1613  	return alertSuccess
  1614  }
  1615  
// serverKeyExchangeMsg is the ServerKeyExchange handshake message. Its body
// is not parsed here; it is stored verbatim in key.
type serverKeyExchangeMsg struct {
	raw []byte
	key []byte
}
  1620  
  1621  func (m *serverKeyExchangeMsg) equal(i interface{}) bool {
  1622  	m1, ok := i.(*serverKeyExchangeMsg)
  1623  	if !ok {
  1624  		return false
  1625  	}
  1626  
  1627  	return bytes.Equal(m.raw, m1.raw) &&
  1628  		bytes.Equal(m.key, m1.key)
  1629  }
  1630  
  1631  func (m *serverKeyExchangeMsg) marshal() []byte {
  1632  	if m.raw != nil {
  1633  		return m.raw
  1634  	}
  1635  	length := len(m.key)
  1636  	x := make([]byte, length+4)
  1637  	x[0] = typeServerKeyExchange
  1638  	x[1] = uint8(length >> 16)
  1639  	x[2] = uint8(length >> 8)
  1640  	x[3] = uint8(length)
  1641  	copy(x[4:], m.key)
  1642  
  1643  	m.raw = x
  1644  	return x
  1645  }
  1646  
  1647  func (m *serverKeyExchangeMsg) unmarshal(data []byte) alert {
  1648  	m.raw = data
  1649  	if len(data) < 4 {
  1650  		return alertDecodeError
  1651  	}
  1652  	m.key = data[4:]
  1653  	return alertSuccess
  1654  }
  1655  
// certificateStatusMsg is the CertificateStatus handshake message. For the
// OCSP status type, response holds the OCSP response bytes; for any other
// status type, only statusType is encoded.
type certificateStatusMsg struct {
	raw        []byte
	statusType uint8
	response   []byte
}
  1661  
  1662  func (m *certificateStatusMsg) equal(i interface{}) bool {
  1663  	m1, ok := i.(*certificateStatusMsg)
  1664  	if !ok {
  1665  		return false
  1666  	}
  1667  
  1668  	return bytes.Equal(m.raw, m1.raw) &&
  1669  		m.statusType == m1.statusType &&
  1670  		bytes.Equal(m.response, m1.response)
  1671  }
  1672  
  1673  func (m *certificateStatusMsg) marshal() []byte {
  1674  	if m.raw != nil {
  1675  		return m.raw
  1676  	}
  1677  
  1678  	var x []byte
  1679  	if m.statusType == statusTypeOCSP {
  1680  		x = make([]byte, 4+4+len(m.response))
  1681  		x[0] = typeCertificateStatus
  1682  		l := len(m.response) + 4
  1683  		x[1] = byte(l >> 16)
  1684  		x[2] = byte(l >> 8)
  1685  		x[3] = byte(l)
  1686  		x[4] = statusTypeOCSP
  1687  
  1688  		l -= 4
  1689  		x[5] = byte(l >> 16)
  1690  		x[6] = byte(l >> 8)
  1691  		x[7] = byte(l)
  1692  		copy(x[8:], m.response)
  1693  	} else {
  1694  		x = []byte{typeCertificateStatus, 0, 0, 1, m.statusType}
  1695  	}
  1696  
  1697  	m.raw = x
  1698  	return x
  1699  }
  1700  
  1701  func (m *certificateStatusMsg) unmarshal(data []byte) alert {
  1702  	m.raw = data
  1703  	if len(data) < 5 {
  1704  		return alertDecodeError
  1705  	}
  1706  	m.statusType = data[4]
  1707  
  1708  	m.response = nil
  1709  	if m.statusType == statusTypeOCSP {
  1710  		if len(data) < 8 {
  1711  			return alertDecodeError
  1712  		}
  1713  		respLen := uint32(data[5])<<16 | uint32(data[6])<<8 | uint32(data[7])
  1714  		if uint32(len(data)) != 4+4+respLen {
  1715  			return alertDecodeError
  1716  		}
  1717  		m.response = data[8:]
  1718  	}
  1719  	return alertSuccess
  1720  }
  1721  
// serverHelloDoneMsg is the ServerHelloDone handshake message. It has no
// fields because its body is empty.
type serverHelloDoneMsg struct{}
  1723  
  1724  func (m *serverHelloDoneMsg) equal(i interface{}) bool {
  1725  	_, ok := i.(*serverHelloDoneMsg)
  1726  	return ok
  1727  }
  1728  
  1729  func (m *serverHelloDoneMsg) marshal() []byte {
  1730  	x := make([]byte, 4)
  1731  	x[0] = typeServerHelloDone
  1732  	return x
  1733  }
  1734  
  1735  func (m *serverHelloDoneMsg) unmarshal(data []byte) alert {
  1736  	if len(data) != 4 {
  1737  		return alertDecodeError
  1738  	}
  1739  	return alertSuccess
  1740  }
  1741  
// clientKeyExchangeMsg is the ClientKeyExchange handshake message. Its body
// is stored verbatim in ciphertext.
type clientKeyExchangeMsg struct {
	raw        []byte
	ciphertext []byte
}
  1746  
  1747  func (m *clientKeyExchangeMsg) equal(i interface{}) bool {
  1748  	m1, ok := i.(*clientKeyExchangeMsg)
  1749  	if !ok {
  1750  		return false
  1751  	}
  1752  
  1753  	return bytes.Equal(m.raw, m1.raw) &&
  1754  		bytes.Equal(m.ciphertext, m1.ciphertext)
  1755  }
  1756  
  1757  func (m *clientKeyExchangeMsg) marshal() []byte {
  1758  	if m.raw != nil {
  1759  		return m.raw
  1760  	}
  1761  	length := len(m.ciphertext)
  1762  	x := make([]byte, length+4)
  1763  	x[0] = typeClientKeyExchange
  1764  	x[1] = uint8(length >> 16)
  1765  	x[2] = uint8(length >> 8)
  1766  	x[3] = uint8(length)
  1767  	copy(x[4:], m.ciphertext)
  1768  
  1769  	m.raw = x
  1770  	return x
  1771  }
  1772  
  1773  func (m *clientKeyExchangeMsg) unmarshal(data []byte) alert {
  1774  	m.raw = data
  1775  	if len(data) < 4 {
  1776  		return alertDecodeError
  1777  	}
  1778  	l := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
  1779  	if l != len(data)-4 {
  1780  		return alertDecodeError
  1781  	}
  1782  	m.ciphertext = data[4:]
  1783  	return alertSuccess
  1784  }
  1785  
// finishedMsg is the Finished handshake message. Its body (the verify data)
// is stored verbatim in verifyData.
type finishedMsg struct {
	raw        []byte
	verifyData []byte
}
  1790  
  1791  func (m *finishedMsg) equal(i interface{}) bool {
  1792  	m1, ok := i.(*finishedMsg)
  1793  	if !ok {
  1794  		return false
  1795  	}
  1796  
  1797  	return bytes.Equal(m.raw, m1.raw) &&
  1798  		bytes.Equal(m.verifyData, m1.verifyData)
  1799  }
  1800  
  1801  func (m *finishedMsg) marshal() (x []byte) {
  1802  	if m.raw != nil {
  1803  		return m.raw
  1804  	}
  1805  
  1806  	x = make([]byte, 4+len(m.verifyData))
  1807  	x[0] = typeFinished
  1808  	x[3] = byte(len(m.verifyData))
  1809  	copy(x[4:], m.verifyData)
  1810  	m.raw = x
  1811  	return
  1812  }
  1813  
  1814  func (m *finishedMsg) unmarshal(data []byte) alert {
  1815  	m.raw = data
  1816  	if len(data) < 4 {
  1817  		return alertDecodeError
  1818  	}
  1819  	m.verifyData = data[4:]
  1820  	return alertSuccess
  1821  }
  1822  
// nextProtoMsg is the NextProtocol handshake message used by Next Protocol
// Negotiation; proto is the selected protocol name.
type nextProtoMsg struct {
	raw   []byte
	proto string
}
  1827  
  1828  func (m *nextProtoMsg) equal(i interface{}) bool {
  1829  	m1, ok := i.(*nextProtoMsg)
  1830  	if !ok {
  1831  		return false
  1832  	}
  1833  
  1834  	return bytes.Equal(m.raw, m1.raw) &&
  1835  		m.proto == m1.proto
  1836  }
  1837  
  1838  func (m *nextProtoMsg) marshal() []byte {
  1839  	if m.raw != nil {
  1840  		return m.raw
  1841  	}
  1842  	l := len(m.proto)
  1843  	if l > 255 {
  1844  		l = 255
  1845  	}
  1846  
  1847  	padding := 32 - (l+2)%32
  1848  	length := l + padding + 2
  1849  	x := make([]byte, length+4)
  1850  	x[0] = typeNextProtocol
  1851  	x[1] = uint8(length >> 16)
  1852  	x[2] = uint8(length >> 8)
  1853  	x[3] = uint8(length)
  1854  
  1855  	y := x[4:]
  1856  	y[0] = byte(l)
  1857  	copy(y[1:], []byte(m.proto[0:l]))
  1858  	y = y[1+l:]
  1859  	y[0] = byte(padding)
  1860  
  1861  	m.raw = x
  1862  
  1863  	return x
  1864  }
  1865  
  1866  func (m *nextProtoMsg) unmarshal(data []byte) alert {
  1867  	m.raw = data
  1868  
  1869  	if len(data) < 5 {
  1870  		return alertDecodeError
  1871  	}
  1872  	data = data[4:]
  1873  	protoLen := int(data[0])
  1874  	data = data[1:]
  1875  	if len(data) < protoLen {
  1876  		return alertDecodeError
  1877  	}
  1878  	m.proto = string(data[0:protoLen])
  1879  	data = data[protoLen:]
  1880  
  1881  	if len(data) < 1 {
  1882  		return alertDecodeError
  1883  	}
  1884  	paddingLen := int(data[0])
  1885  	data = data[1:]
  1886  	if len(data) != paddingLen {
  1887  		return alertDecodeError
  1888  	}
  1889  
  1890  	return alertSuccess
  1891  }
  1892  
// certificateRequestMsg is the (pre-TLS 1.3) CertificateRequest handshake
// message.
type certificateRequestMsg struct {
	raw []byte
	// hasSignatureAndHash indicates whether this message includes a list
	// of signature and hash functions. This change was introduced with TLS
	// 1.2. It must be set before calling unmarshal; it is not decoded from
	// the wire.
	hasSignatureAndHash bool

	certificateTypes             []byte
	supportedSignatureAlgorithms []SignatureScheme
	certificateAuthorities       [][]byte
}
  1904  
  1905  func (m *certificateRequestMsg) equal(i interface{}) bool {
  1906  	m1, ok := i.(*certificateRequestMsg)
  1907  	if !ok {
  1908  		return false
  1909  	}
  1910  
  1911  	return bytes.Equal(m.raw, m1.raw) &&
  1912  		bytes.Equal(m.certificateTypes, m1.certificateTypes) &&
  1913  		eqByteSlices(m.certificateAuthorities, m1.certificateAuthorities) &&
  1914  		eqSignatureAlgorithms(m.supportedSignatureAlgorithms, m1.supportedSignatureAlgorithms)
  1915  }
  1916  
  1917  func (m *certificateRequestMsg) marshal() (x []byte) {
  1918  	if m.raw != nil {
  1919  		return m.raw
  1920  	}
  1921  
  1922  	// See http://tools.ietf.org/html/rfc4346#section-7.4.4
  1923  	length := 1 + len(m.certificateTypes) + 2
  1924  	casLength := 0
  1925  	for _, ca := range m.certificateAuthorities {
  1926  		casLength += 2 + len(ca)
  1927  	}
  1928  	length += casLength
  1929  
  1930  	if m.hasSignatureAndHash {
  1931  		length += 2 + 2*len(m.supportedSignatureAlgorithms)
  1932  	}
  1933  
  1934  	x = make([]byte, 4+length)
  1935  	x[0] = typeCertificateRequest
  1936  	x[1] = uint8(length >> 16)
  1937  	x[2] = uint8(length >> 8)
  1938  	x[3] = uint8(length)
  1939  
  1940  	x[4] = uint8(len(m.certificateTypes))
  1941  
  1942  	copy(x[5:], m.certificateTypes)
  1943  	y := x[5+len(m.certificateTypes):]
  1944  
  1945  	if m.hasSignatureAndHash {
  1946  		n := len(m.supportedSignatureAlgorithms) * 2
  1947  		y[0] = uint8(n >> 8)
  1948  		y[1] = uint8(n)
  1949  		y = y[2:]
  1950  		for _, sigAlgo := range m.supportedSignatureAlgorithms {
  1951  			y[0] = uint8(sigAlgo >> 8)
  1952  			y[1] = uint8(sigAlgo)
  1953  			y = y[2:]
  1954  		}
  1955  	}
  1956  
  1957  	y[0] = uint8(casLength >> 8)
  1958  	y[1] = uint8(casLength)
  1959  	y = y[2:]
  1960  	for _, ca := range m.certificateAuthorities {
  1961  		y[0] = uint8(len(ca) >> 8)
  1962  		y[1] = uint8(len(ca))
  1963  		y = y[2:]
  1964  		copy(y, ca)
  1965  		y = y[len(ca):]
  1966  	}
  1967  
  1968  	m.raw = x
  1969  	return
  1970  }
  1971  
  1972  func (m *certificateRequestMsg) unmarshal(data []byte) alert {
  1973  	m.raw = data
  1974  
  1975  	if len(data) < 5 {
  1976  		return alertDecodeError
  1977  	}
  1978  
  1979  	length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3])
  1980  	if uint32(len(data))-4 != length {
  1981  		return alertDecodeError
  1982  	}
  1983  
  1984  	numCertTypes := int(data[4])
  1985  	data = data[5:]
  1986  	if numCertTypes == 0 || len(data) <= numCertTypes {
  1987  		return alertDecodeError
  1988  	}
  1989  
  1990  	m.certificateTypes = make([]byte, numCertTypes)
  1991  	if copy(m.certificateTypes, data) != numCertTypes {
  1992  		return alertDecodeError
  1993  	}
  1994  
  1995  	data = data[numCertTypes:]
  1996  
  1997  	if m.hasSignatureAndHash {
  1998  		if len(data) < 2 {
  1999  			return alertDecodeError
  2000  		}
  2001  		sigAndHashLen := uint16(data[0])<<8 | uint16(data[1])
  2002  		data = data[2:]
  2003  		if sigAndHashLen&1 != 0 {
  2004  			return alertDecodeError
  2005  		}
  2006  		if len(data) < int(sigAndHashLen) {
  2007  			return alertDecodeError
  2008  		}
  2009  		numSigAlgos := sigAndHashLen / 2
  2010  		m.supportedSignatureAlgorithms = make([]SignatureScheme, numSigAlgos)
  2011  		for i := range m.supportedSignatureAlgorithms {
  2012  			m.supportedSignatureAlgorithms[i] = SignatureScheme(data[0])<<8 | SignatureScheme(data[1])
  2013  			data = data[2:]
  2014  		}
  2015  	}
  2016  
  2017  	if len(data) < 2 {
  2018  		return alertDecodeError
  2019  	}
  2020  	casLength := uint16(data[0])<<8 | uint16(data[1])
  2021  	data = data[2:]
  2022  	if len(data) < int(casLength) {
  2023  		return alertDecodeError
  2024  	}
  2025  	cas := make([]byte, casLength)
  2026  	copy(cas, data)
  2027  	data = data[casLength:]
  2028  
  2029  	m.certificateAuthorities = nil
  2030  	for len(cas) > 0 {
  2031  		if len(cas) < 2 {
  2032  			return alertDecodeError
  2033  		}
  2034  		caLen := uint16(cas[0])<<8 | uint16(cas[1])
  2035  		cas = cas[2:]
  2036  
  2037  		if len(cas) < int(caLen) {
  2038  			return alertDecodeError
  2039  		}
  2040  
  2041  		m.certificateAuthorities = append(m.certificateAuthorities, cas[:caLen])
  2042  		cas = cas[caLen:]
  2043  	}
  2044  
  2045  	if len(data) != 0 {
  2046  		return alertDecodeError
  2047  	}
  2048  	return alertSuccess
  2049  }
  2050  
// certificateVerifyMsg is the CertificateVerify handshake message.
type certificateVerifyMsg struct {
	raw                 []byte
	hasSignatureAndHash bool // input to unmarshal: whether a 2-byte algorithm precedes the signature
	signatureAlgorithm  SignatureScheme
	signature           []byte
}
  2057  
  2058  func (m *certificateVerifyMsg) equal(i interface{}) bool {
  2059  	m1, ok := i.(*certificateVerifyMsg)
  2060  	if !ok {
  2061  		return false
  2062  	}
  2063  
  2064  	return bytes.Equal(m.raw, m1.raw) &&
  2065  		m.hasSignatureAndHash == m1.hasSignatureAndHash &&
  2066  		m.signatureAlgorithm == m1.signatureAlgorithm &&
  2067  		bytes.Equal(m.signature, m1.signature)
  2068  }
  2069  
  2070  func (m *certificateVerifyMsg) marshal() (x []byte) {
  2071  	if m.raw != nil {
  2072  		return m.raw
  2073  	}
  2074  
  2075  	// See http://tools.ietf.org/html/rfc4346#section-7.4.8
  2076  	siglength := len(m.signature)
  2077  	length := 2 + siglength
  2078  	if m.hasSignatureAndHash {
  2079  		length += 2
  2080  	}
  2081  	x = make([]byte, 4+length)
  2082  	x[0] = typeCertificateVerify
  2083  	x[1] = uint8(length >> 16)
  2084  	x[2] = uint8(length >> 8)
  2085  	x[3] = uint8(length)
  2086  	y := x[4:]
  2087  	if m.hasSignatureAndHash {
  2088  		y[0] = uint8(m.signatureAlgorithm >> 8)
  2089  		y[1] = uint8(m.signatureAlgorithm)
  2090  		y = y[2:]
  2091  	}
  2092  	y[0] = uint8(siglength >> 8)
  2093  	y[1] = uint8(siglength)
  2094  	copy(y[2:], m.signature)
  2095  
  2096  	m.raw = x
  2097  
  2098  	return
  2099  }
  2100  
  2101  func (m *certificateVerifyMsg) unmarshal(data []byte) alert {
  2102  	m.raw = data
  2103  
  2104  	if len(data) < 6 {
  2105  		return alertDecodeError
  2106  	}
  2107  
  2108  	length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3])
  2109  	if uint32(len(data))-4 != length {
  2110  		return alertDecodeError
  2111  	}
  2112  
  2113  	data = data[4:]
  2114  	if m.hasSignatureAndHash {
  2115  		m.signatureAlgorithm = SignatureScheme(data[0])<<8 | SignatureScheme(data[1])
  2116  		data = data[2:]
  2117  	}
  2118  
  2119  	if len(data) < 2 {
  2120  		return alertDecodeError
  2121  	}
  2122  	siglength := int(data[0])<<8 + int(data[1])
  2123  	data = data[2:]
  2124  	if len(data) != siglength {
  2125  		return alertDecodeError
  2126  	}
  2127  
  2128  	m.signature = data
  2129  
  2130  	return alertSuccess
  2131  }
  2132  
// newSessionTicketMsg is the NewSessionTicket handshake message of
// http://tools.ietf.org/html/rfc5077#section-3.3. newSessionTicketMsg13 is
// the separate TLS 1.3 draft variant.
//
// raw holds the complete encoding, including the 4-byte handshake header, and
// ticket is the opaque ticket. The 4-byte lifetime hint has no field: marshal
// always writes it as zero, and unmarshal does not read it.
type newSessionTicketMsg struct {
	raw    []byte
	ticket []byte
}
  2137  
  2138  func (m *newSessionTicketMsg) equal(i interface{}) bool {
  2139  	m1, ok := i.(*newSessionTicketMsg)
  2140  	if !ok {
  2141  		return false
  2142  	}
  2143  
  2144  	return bytes.Equal(m.raw, m1.raw) &&
  2145  		bytes.Equal(m.ticket, m1.ticket)
  2146  }
  2147  
  2148  func (m *newSessionTicketMsg) marshal() (x []byte) {
  2149  	if m.raw != nil {
  2150  		return m.raw
  2151  	}
  2152  
  2153  	// See http://tools.ietf.org/html/rfc5077#section-3.3
  2154  	ticketLen := len(m.ticket)
  2155  	length := 2 + 4 + ticketLen
  2156  	x = make([]byte, 4+length)
  2157  	x[0] = typeNewSessionTicket
  2158  	x[1] = uint8(length >> 16)
  2159  	x[2] = uint8(length >> 8)
  2160  	x[3] = uint8(length)
  2161  	x[8] = uint8(ticketLen >> 8)
  2162  	x[9] = uint8(ticketLen)
  2163  	copy(x[10:], m.ticket)
  2164  
  2165  	m.raw = x
  2166  
  2167  	return
  2168  }
  2169  
  2170  func (m *newSessionTicketMsg) unmarshal(data []byte) alert {
  2171  	m.raw = data
  2172  
  2173  	if len(data) < 10 {
  2174  		return alertDecodeError
  2175  	}
  2176  
  2177  	length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3])
  2178  	if uint32(len(data))-4 != length {
  2179  		return alertDecodeError
  2180  	}
  2181  
  2182  	ticketLen := int(data[8])<<8 + int(data[9])
  2183  	if len(data)-10 != ticketLen {
  2184  		return alertDecodeError
  2185  	}
  2186  
  2187  	m.ticket = data[10:]
  2188  
  2189  	return alertSuccess
  2190  }
  2191  
// newSessionTicketMsg13 is the NewSessionTicket handshake message in the
// layout of https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.6.
//
// raw holds the complete encoding, including the 4-byte handshake header.
// lifetime and ageAdd are the two leading 32-bit fields, and ticket is the
// opaque ticket. withEarlyDataInfo records whether an
// extensionTicketEarlyDataInfo extension is present. maxEarlyDataLength is
// that extension's 32-bit payload, and unmarshal resets it to zero when the
// extension is absent.
type newSessionTicketMsg13 struct {
	raw                []byte
	lifetime           uint32
	ageAdd             uint32
	ticket             []byte
	withEarlyDataInfo  bool
	maxEarlyDataLength uint32
}
  2200  
  2201  func (m *newSessionTicketMsg13) equal(i interface{}) bool {
  2202  	m1, ok := i.(*newSessionTicketMsg13)
  2203  	if !ok {
  2204  		return false
  2205  	}
  2206  
  2207  	return bytes.Equal(m.raw, m1.raw) &&
  2208  		m.lifetime == m1.lifetime &&
  2209  		m.ageAdd == m1.ageAdd &&
  2210  		bytes.Equal(m.ticket, m1.ticket) &&
  2211  		m.withEarlyDataInfo == m1.withEarlyDataInfo &&
  2212  		m.maxEarlyDataLength == m1.maxEarlyDataLength
  2213  }
  2214  
  2215  func (m *newSessionTicketMsg13) marshal() (x []byte) {
  2216  	if m.raw != nil {
  2217  		return m.raw
  2218  	}
  2219  
  2220  	// See https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.6
  2221  	ticketLen := len(m.ticket)
  2222  	length := 12 + ticketLen
  2223  	if m.withEarlyDataInfo {
  2224  		length += 8
  2225  	}
  2226  	x = make([]byte, 4+length)
  2227  	x[0] = typeNewSessionTicket
  2228  	x[1] = uint8(length >> 16)
  2229  	x[2] = uint8(length >> 8)
  2230  	x[3] = uint8(length)
  2231  
  2232  	x[4] = uint8(m.lifetime >> 24)
  2233  	x[5] = uint8(m.lifetime >> 16)
  2234  	x[6] = uint8(m.lifetime >> 8)
  2235  	x[7] = uint8(m.lifetime)
  2236  	x[8] = uint8(m.ageAdd >> 24)
  2237  	x[9] = uint8(m.ageAdd >> 16)
  2238  	x[10] = uint8(m.ageAdd >> 8)
  2239  	x[11] = uint8(m.ageAdd)
  2240  
  2241  	x[12] = uint8(ticketLen >> 8)
  2242  	x[13] = uint8(ticketLen)
  2243  	copy(x[14:], m.ticket)
  2244  
  2245  	if m.withEarlyDataInfo {
  2246  		z := x[14+ticketLen:]
  2247  		z[1] = 8
  2248  		z[2] = uint8(extensionTicketEarlyDataInfo >> 8)
  2249  		z[3] = uint8(extensionTicketEarlyDataInfo)
  2250  		z[5] = 4
  2251  		z[6] = uint8(m.maxEarlyDataLength >> 24)
  2252  		z[7] = uint8(m.maxEarlyDataLength >> 16)
  2253  		z[8] = uint8(m.maxEarlyDataLength >> 8)
  2254  		z[9] = uint8(m.maxEarlyDataLength)
  2255  	}
  2256  
  2257  	m.raw = x
  2258  
  2259  	return
  2260  }
  2261  
  2262  func (m *newSessionTicketMsg13) unmarshal(data []byte) alert {
  2263  	m.raw = data
  2264  	m.maxEarlyDataLength = 0
  2265  	m.withEarlyDataInfo = false
  2266  
  2267  	if len(data) < 16 {
  2268  		return alertDecodeError
  2269  	}
  2270  
  2271  	length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3])
  2272  	if uint32(len(data))-4 != length {
  2273  		return alertDecodeError
  2274  	}
  2275  
  2276  	m.lifetime = uint32(data[4])<<24 | uint32(data[5])<<16 |
  2277  		uint32(data[6])<<8 | uint32(data[7])
  2278  	m.ageAdd = uint32(data[8])<<24 | uint32(data[9])<<16 |
  2279  		uint32(data[10])<<8 | uint32(data[11])
  2280  
  2281  	ticketLen := int(data[12])<<8 + int(data[13])
  2282  	if 14+ticketLen > len(data) {
  2283  		return alertDecodeError
  2284  	}
  2285  	m.ticket = data[14 : 14+ticketLen]
  2286  
  2287  	data = data[14+ticketLen:]
  2288  	extLen := int(data[0])<<8 + int(data[1])
  2289  	if extLen != len(data)-2 {
  2290  		return alertDecodeError
  2291  	}
  2292  
  2293  	data = data[2:]
  2294  	for len(data) > 0 {
  2295  		if len(data) < 4 {
  2296  			return alertDecodeError
  2297  		}
  2298  		extType := uint16(data[0])<<8 + uint16(data[1])
  2299  		length := int(data[2])<<8 + int(data[3])
  2300  		data = data[4:]
  2301  
  2302  		switch extType {
  2303  		case extensionTicketEarlyDataInfo:
  2304  			if length != 4 {
  2305  				return alertDecodeError
  2306  			}
  2307  			m.withEarlyDataInfo = true
  2308  			m.maxEarlyDataLength = uint32(data[0])<<24 | uint32(data[1])<<16 |
  2309  				uint32(data[2])<<8 | uint32(data[3])
  2310  		}
  2311  		data = data[length:]
  2312  	}
  2313  
  2314  	return alertSuccess
  2315  }
  2316  
// helloRequestMsg is the HelloRequest handshake message. The message has no
// body, so the type has no fields, and its encoding is just the 4-byte
// handshake header.
type helloRequestMsg struct {
}
  2319  
  2320  func (*helloRequestMsg) marshal() []byte {
  2321  	return []byte{typeHelloRequest, 0, 0, 0}
  2322  }
  2323  
  2324  func (*helloRequestMsg) unmarshal(data []byte) alert {
  2325  	if len(data) != 4 {
  2326  		return alertDecodeError
  2327  	}
  2328  	return alertSuccess
  2329  }
  2330  
  2331  func eqUint16s(x, y []uint16) bool {
  2332  	if len(x) != len(y) {
  2333  		return false
  2334  	}
  2335  	for i, v := range x {
  2336  		if y[i] != v {
  2337  			return false
  2338  		}
  2339  	}
  2340  	return true
  2341  }
  2342  
  2343  func eqCurveIDs(x, y []CurveID) bool {
  2344  	if len(x) != len(y) {
  2345  		return false
  2346  	}
  2347  	for i, v := range x {
  2348  		if y[i] != v {
  2349  			return false
  2350  		}
  2351  	}
  2352  	return true
  2353  }
  2354  
  2355  func eqStrings(x, y []string) bool {
  2356  	if len(x) != len(y) {
  2357  		return false
  2358  	}
  2359  	for i, v := range x {
  2360  		if y[i] != v {
  2361  			return false
  2362  		}
  2363  	}
  2364  	return true
  2365  }
  2366  
  2367  func eqByteSlices(x, y [][]byte) bool {
  2368  	if len(x) != len(y) {
  2369  		return false
  2370  	}
  2371  	for i, v := range x {
  2372  		if !bytes.Equal(v, y[i]) {
  2373  			return false
  2374  		}
  2375  	}
  2376  	return true
  2377  }
  2378  
  2379  func eqSignatureAlgorithms(x, y []SignatureScheme) bool {
  2380  	if len(x) != len(y) {
  2381  		return false
  2382  	}
  2383  	for i, v := range x {
  2384  		if v != y[i] {
  2385  			return false
  2386  		}
  2387  	}
  2388  	return true
  2389  }
  2390  
  2391  func eqKeyShares(x, y []keyShare) bool {
  2392  	if len(x) != len(y) {
  2393  		return false
  2394  	}
  2395  	for i := range x {
  2396  		if x[i].group != y[i].group {
  2397  			return false
  2398  		}
  2399  		if !bytes.Equal(x[i].data, y[i].data) {
  2400  			return false
  2401  		}
  2402  	}
  2403  	return true
  2404  }