github.com/zmap/zcrypto@v0.0.0-20240512203510-0fef58d9a9db/tls/handshake_messages.go (about)

     1  // Copyright 2009 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package tls
     6  
     7  import (
     8  	"bytes"
     9  	"reflect"
    10  )
    11  
// clientHelloMsg is the in-memory form of a TLS ClientHello handshake
// message; unmarshal fills it from wire bytes and marshal serializes it.
type clientHelloMsg struct {
	raw                   []byte       // full encoding incl. 4-byte handshake header; returned as-is by marshal when non-nil, set to the input by unmarshal
	vers                  uint16       // client_version (first two body bytes)
	random                []byte       // 32-byte client random
	sessionId             []byte       // session ID; unmarshal rejects more than 32 bytes
	cipherSuites          []uint16     // offered cipher suites, in wire order
	compressionMethods    []uint8      // offered compression methods
	nextProtoNeg          bool         // empty next_protocol_negotiation extension present
	serverName            string       // SNI host_name; marshal omits the extension when empty
	ocspStapling          bool         // status_request extension requesting OCSP
	scts                  bool         // set by unmarshal when the SCT extension is present (marshal reads sctEnabled instead)
	supportedCurves       []CurveID    // supported curves extension entries
	supportedPoints       []uint8      // EC point formats extension entries
	ticketSupported       bool         // session_ticket extension present
	sessionTicket         []uint8      // session_ticket extension body (may be empty)
	signatureAndHashes    []SigAndHash // signature_algorithms extension entries
	secureRenegotiation   bool         // renegotiation_info extension; unmarshal also sets it for the renegotiation SCSV
	heartbeatEnabled      bool         // heartbeat extension present
	heartbeatMode         uint8        // heartbeat extension's mode byte
	extendedRandomEnabled bool         // extended random extension present
	extendedRandom        []byte       // extended random value, without its 2-byte length prefix
	extendedMasterSecret  bool         // empty extended_master_secret extension present
	sctEnabled            bool         // makes marshal emit an empty SCT extension; unmarshal sets scts, not this field
	alpnProtocols         []string     // ALPN protocol names; marshal panics unless each is 1-255 bytes
	unknownExtensions     [][]byte     // unrecognized extensions, each including its 4-byte type/length header; marshal re-emits them verbatim
}
    38  
    39  func (m *clientHelloMsg) equal(i interface{}) bool {
    40  	m1, ok := i.(*clientHelloMsg)
    41  	if !ok {
    42  		return false
    43  	}
    44  
    45  	return bytes.Equal(m.raw, m1.raw) &&
    46  		m.vers == m1.vers &&
    47  		bytes.Equal(m.random, m1.random) &&
    48  		bytes.Equal(m.sessionId, m1.sessionId) &&
    49  		eqUint16s(m.cipherSuites, m1.cipherSuites) &&
    50  		bytes.Equal(m.compressionMethods, m1.compressionMethods) &&
    51  		m.nextProtoNeg == m1.nextProtoNeg &&
    52  		m.serverName == m1.serverName &&
    53  		m.ocspStapling == m1.ocspStapling &&
    54  		m.scts == m1.scts &&
    55  		eqCurveIDs(m.supportedCurves, m1.supportedCurves) &&
    56  		bytes.Equal(m.supportedPoints, m1.supportedPoints) &&
    57  		m.ticketSupported == m1.ticketSupported &&
    58  		bytes.Equal(m.sessionTicket, m1.sessionTicket) &&
    59  		eqSignatureAndHashes(m.signatureAndHashes, m1.signatureAndHashes) &&
    60  		m.secureRenegotiation == m1.secureRenegotiation &&
    61  		m.heartbeatEnabled == m1.heartbeatEnabled &&
    62  		m.heartbeatMode == m1.heartbeatMode &&
    63  		m.extendedRandomEnabled == m1.extendedRandomEnabled &&
    64  		bytes.Equal(m.extendedRandom, m1.extendedRandom) &&
    65  		m.extendedMasterSecret == m1.extendedMasterSecret &&
    66  		eqStrings(m.alpnProtocols, m1.alpnProtocols) &&
    67  		reflect.DeepEqual(m.unknownExtensions, m1.unknownExtensions)
    68  }
    69  
    70  func (m *clientHelloMsg) marshal() []byte {
    71  	if m.raw != nil {
    72  		return m.raw
    73  	}
    74  
    75  	length := 2 + 32 + 1 + len(m.sessionId) + 2 + len(m.cipherSuites)*2 + 1 + len(m.compressionMethods)
    76  	numExtensions := 0
    77  	extensionsLength := 0
    78  	if m.nextProtoNeg {
    79  		numExtensions++
    80  	}
    81  	if m.ocspStapling {
    82  		extensionsLength += 1 + 2 + 2
    83  		numExtensions++
    84  	}
    85  	if len(m.serverName) > 0 {
    86  		extensionsLength += 5 + len(m.serverName)
    87  		numExtensions++
    88  	}
    89  	if len(m.supportedCurves) > 0 {
    90  		extensionsLength += 2 + 2*len(m.supportedCurves)
    91  		numExtensions++
    92  	}
    93  	if len(m.supportedPoints) > 0 {
    94  		extensionsLength += 1 + len(m.supportedPoints)
    95  		numExtensions++
    96  	}
    97  	if m.ticketSupported {
    98  		extensionsLength += len(m.sessionTicket)
    99  		numExtensions++
   100  	}
   101  	if len(m.signatureAndHashes) > 0 {
   102  		extensionsLength += 2 + 2*len(m.signatureAndHashes)
   103  		numExtensions++
   104  	}
   105  	if m.secureRenegotiation {
   106  		extensionsLength += 1
   107  		numExtensions++
   108  	}
   109  	if len(m.alpnProtocols) > 0 {
   110  		extensionsLength += 2
   111  		for _, s := range m.alpnProtocols {
   112  			if l := len(s); l == 0 || l > 255 {
   113  				panic("invalid ALPN protocol")
   114  			}
   115  			extensionsLength++
   116  			extensionsLength += len(s)
   117  		}
   118  		numExtensions++
   119  	}
   120  	if m.heartbeatEnabled {
   121  		extensionsLength += 1
   122  		numExtensions++
   123  	}
   124  	if m.extendedRandomEnabled {
   125  		extensionsLength += 2 + len(m.extendedRandom)
   126  		numExtensions++
   127  	}
   128  	if m.extendedMasterSecret {
   129  		numExtensions++
   130  	}
   131  	if m.sctEnabled {
   132  		numExtensions++
   133  	}
   134  	if len(m.unknownExtensions) > 0 {
   135  		// we do not update numExtensions because the extension code and length
   136  		// are already contained at the beginning of every 'ext' below
   137  		for _, ext := range m.unknownExtensions {
   138  			extensionsLength += len(ext)
   139  		}
   140  	}
   141  	if numExtensions > 0 {
   142  		extensionsLength += 4 * numExtensions
   143  		length += 2 + extensionsLength
   144  	}
   145  
   146  	x := make([]byte, 4+length)
   147  	x[0] = typeClientHello
   148  	x[1] = uint8(length >> 16)
   149  	x[2] = uint8(length >> 8)
   150  	x[3] = uint8(length)
   151  	x[4] = uint8(m.vers >> 8)
   152  	x[5] = uint8(m.vers)
   153  	copy(x[6:38], m.random)
   154  	x[38] = uint8(len(m.sessionId))
   155  	copy(x[39:39+len(m.sessionId)], m.sessionId)
   156  	y := x[39+len(m.sessionId):]
   157  	// This is a clever way to store the lower 16 bits of 2*len(m.cipherSuites)
   158  	y[0] = uint8(len(m.cipherSuites) >> 7)
   159  	y[1] = uint8(len(m.cipherSuites) << 1)
   160  	for i, suite := range m.cipherSuites {
   161  		y[2+i*2] = uint8(suite >> 8)
   162  		y[3+i*2] = uint8(suite)
   163  	}
   164  	z := y[2+len(m.cipherSuites)*2:]
   165  	z[0] = uint8(len(m.compressionMethods))
   166  	copy(z[1:], m.compressionMethods)
   167  
   168  	z = z[1+len(m.compressionMethods):]
   169  	if numExtensions > 0 || len(m.unknownExtensions) > 0 {
   170  		z[0] = byte(extensionsLength >> 8)
   171  		z[1] = byte(extensionsLength)
   172  		z = z[2:]
   173  	}
   174  	if m.nextProtoNeg {
   175  		z[0] = byte(extensionNextProtoNeg >> 8)
   176  		z[1] = byte(extensionNextProtoNeg & 0xff)
   177  		// The length is always 0
   178  		z = z[4:]
   179  	}
   180  	if len(m.serverName) > 0 {
   181  		z[0] = byte(extensionServerName >> 8)
   182  		z[1] = byte(extensionServerName & 0xff)
   183  		l := len(m.serverName) + 5
   184  		z[2] = byte(l >> 8)
   185  		z[3] = byte(l)
   186  		z = z[4:]
   187  
   188  		// RFC 3546, section 3.1
   189  		//
   190  		// struct {
   191  		//     NameType name_type;
   192  		//     select (name_type) {
   193  		//         case host_name: HostName;
   194  		//     } name;
   195  		// } ServerName;
   196  		//
   197  		// enum {
   198  		//     host_name(0), (255)
   199  		// } NameType;
   200  		//
   201  		// opaque HostName<1..2^16-1>;
   202  		//
   203  		// struct {
   204  		//     ServerName server_name_list<1..2^16-1>
   205  		// } ServerNameList;
   206  
   207  		z[0] = byte((len(m.serverName) + 3) >> 8)
   208  		z[1] = byte(len(m.serverName) + 3)
   209  		z[3] = byte(len(m.serverName) >> 8)
   210  		z[4] = byte(len(m.serverName))
   211  		copy(z[5:], []byte(m.serverName))
   212  		z = z[l:]
   213  	}
   214  	if m.ocspStapling {
   215  		// RFC 4366, section 3.6
   216  		z[0] = byte(extensionStatusRequest >> 8)
   217  		z[1] = byte(extensionStatusRequest)
   218  		z[2] = 0
   219  		z[3] = 5
   220  		z[4] = 1 // OCSP type
   221  		// Two zero valued uint16s for the two lengths.
   222  		z = z[9:]
   223  	}
   224  	if len(m.supportedCurves) > 0 {
   225  		// http://tools.ietf.org/html/rfc4492#section-5.5.1
   226  		z[0] = byte(extensionSupportedCurves >> 8)
   227  		z[1] = byte(extensionSupportedCurves)
   228  		l := 2 + 2*len(m.supportedCurves)
   229  		z[2] = byte(l >> 8)
   230  		z[3] = byte(l)
   231  		l -= 2
   232  		z[4] = byte(l >> 8)
   233  		z[5] = byte(l)
   234  		z = z[6:]
   235  		for _, curve := range m.supportedCurves {
   236  			z[0] = byte(curve >> 8)
   237  			z[1] = byte(curve)
   238  			z = z[2:]
   239  		}
   240  	}
   241  	if len(m.supportedPoints) > 0 {
   242  		// http://tools.ietf.org/html/rfc4492#section-5.5.2
   243  		z[0] = byte(extensionSupportedPoints >> 8)
   244  		z[1] = byte(extensionSupportedPoints)
   245  		l := 1 + len(m.supportedPoints)
   246  		z[2] = byte(l >> 8)
   247  		z[3] = byte(l)
   248  		l--
   249  		z[4] = byte(l)
   250  		z = z[5:]
   251  		for _, pointFormat := range m.supportedPoints {
   252  			z[0] = byte(pointFormat)
   253  			z = z[1:]
   254  		}
   255  	}
   256  	if m.ticketSupported {
   257  		// http://tools.ietf.org/html/rfc5077#section-3.2
   258  		z[0] = byte(extensionSessionTicket >> 8)
   259  		z[1] = byte(extensionSessionTicket)
   260  		l := len(m.sessionTicket)
   261  		z[2] = byte(l >> 8)
   262  		z[3] = byte(l)
   263  		z = z[4:]
   264  		copy(z, m.sessionTicket)
   265  		z = z[len(m.sessionTicket):]
   266  	}
   267  	if len(m.signatureAndHashes) > 0 {
   268  		// https://tools.ietf.org/html/rfc5246#section-7.4.1.4.1
   269  		z[0] = byte(extensionSignatureAlgorithms >> 8)
   270  		z[1] = byte(extensionSignatureAlgorithms)
   271  		l := 2 + 2*len(m.signatureAndHashes)
   272  		z[2] = byte(l >> 8)
   273  		z[3] = byte(l)
   274  		z = z[4:]
   275  
   276  		l -= 2
   277  		z[0] = byte(l >> 8)
   278  		z[1] = byte(l)
   279  		z = z[2:]
   280  		for _, sigAndHash := range m.signatureAndHashes {
   281  			z[0] = sigAndHash.Hash
   282  			z[1] = sigAndHash.Signature
   283  			z = z[2:]
   284  		}
   285  	}
   286  	if m.secureRenegotiation {
   287  		z[0] = byte(extensionRenegotiationInfo >> 8)
   288  		z[1] = byte(extensionRenegotiationInfo & 0xff)
   289  		z[2] = 0
   290  		z[3] = 1
   291  		z = z[5:]
   292  	}
   293  	if len(m.alpnProtocols) > 0 {
   294  		z[0] = byte(extensionALPN >> 8)
   295  		z[1] = byte(extensionALPN & 0xff)
   296  		lengths := z[2:]
   297  		z = z[6:]
   298  
   299  		stringsLength := 0
   300  		for _, s := range m.alpnProtocols {
   301  			l := len(s)
   302  			z[0] = byte(l)
   303  			copy(z[1:], s)
   304  			z = z[1+l:]
   305  			stringsLength += 1 + l
   306  		}
   307  
   308  		lengths[2] = byte(stringsLength >> 8)
   309  		lengths[3] = byte(stringsLength)
   310  		stringsLength += 2
   311  		lengths[0] = byte(stringsLength >> 8)
   312  		lengths[1] = byte(stringsLength)
   313  	}
   314  	if m.heartbeatEnabled {
   315  		z[0] = byte(extensionHeartbeat >> 8)
   316  		z[1] = byte(extensionHeartbeat)
   317  		length := 1
   318  		z[2] = byte(length >> 8)
   319  		z[3] = byte(length)
   320  		z[4] = m.heartbeatMode
   321  		z = z[5:]
   322  	}
   323  	if m.extendedRandomEnabled {
   324  		z[0] = byte(extensionExtendedRandom >> 8)
   325  		z[1] = byte(extensionExtendedRandom)
   326  		exLen := len(m.extendedRandom)
   327  		length := 2 + exLen
   328  		z[2] = byte(length >> 8)
   329  		z[3] = byte(length)
   330  		z[4] = byte(exLen >> 8)
   331  		z[5] = byte(exLen)
   332  		z = z[6:]
   333  		copy(z, m.extendedRandom)
   334  		z = z[exLen:]
   335  	}
   336  	if m.extendedMasterSecret {
   337  		// https://tools.ietf.org/html/draft-ietf-tls-session-hash-01
   338  		z[0] = byte(extensionExtendedMasterSecret >> 8)
   339  		z[1] = byte(extensionExtendedMasterSecret & 0xff)
   340  		z = z[4:]
   341  	}
   342  	if m.sctEnabled {
   343  		// https://tools.ietf.org/html/rfc6962#section-3.3.1
   344  		z[0] = byte(extensionSCT >> 8)
   345  		z[1] = byte(extensionSCT)
   346  		// zero uint16 for the zero-length extension_data
   347  		z = z[4:]
   348  	}
   349  	if len(m.unknownExtensions) > 0 {
   350  		for _, ext := range m.unknownExtensions {
   351  			copy(z, ext)
   352  			z = z[len(ext):]
   353  		}
   354  	}
   355  
   356  	m.raw = x
   357  
   358  	return x
   359  }
   360  
   361  func (m *clientHelloMsg) unmarshal(data []byte) bool {
   362  	if len(data) < 42 {
   363  		return false
   364  	}
   365  	m.raw = data
   366  	m.vers = uint16(data[4])<<8 | uint16(data[5])
   367  	m.random = data[6:38]
   368  	sessionIdLen := int(data[38])
   369  	if sessionIdLen > 32 || len(data) < 39+sessionIdLen {
   370  		return false
   371  	}
   372  	m.sessionId = data[39 : 39+sessionIdLen]
   373  	data = data[39+sessionIdLen:]
   374  	if len(data) < 2 {
   375  		return false
   376  	}
   377  	// cipherSuiteLen is the number of bytes of cipher suite numbers. Since
   378  	// they are uint16s, the number must be even.
   379  	cipherSuiteLen := int(data[0])<<8 | int(data[1])
   380  	if cipherSuiteLen%2 == 1 || len(data) < 2+cipherSuiteLen {
   381  		return false
   382  	}
   383  	numCipherSuites := cipherSuiteLen / 2
   384  	m.cipherSuites = make([]uint16, numCipherSuites)
   385  	for i := 0; i < numCipherSuites; i++ {
   386  		m.cipherSuites[i] = uint16(data[2+2*i])<<8 | uint16(data[3+2*i])
   387  		if m.cipherSuites[i] == scsvRenegotiation {
   388  			m.secureRenegotiation = true
   389  		}
   390  	}
   391  	data = data[2+cipherSuiteLen:]
   392  	if len(data) < 1 {
   393  		return false
   394  	}
   395  	compressionMethodsLen := int(data[0])
   396  	if len(data) < 1+compressionMethodsLen {
   397  		return false
   398  	}
   399  	m.compressionMethods = data[1 : 1+compressionMethodsLen]
   400  
   401  	data = data[1+compressionMethodsLen:]
   402  
   403  	m.nextProtoNeg = false
   404  	m.serverName = ""
   405  	m.ocspStapling = false
   406  	m.ticketSupported = false
   407  	m.sessionTicket = nil
   408  	m.signatureAndHashes = nil
   409  	m.heartbeatEnabled = false
   410  	m.extendedMasterSecret = false
   411  	m.alpnProtocols = nil
   412  	m.scts = false
   413  	m.unknownExtensions = [][]byte(nil)
   414  
   415  	if len(data) == 0 {
   416  		// ClientHello is optionally followed by extension data
   417  		return true
   418  	}
   419  	if len(data) < 2 {
   420  		return false
   421  	}
   422  
   423  	extensionsLength := int(data[0])<<8 | int(data[1])
   424  	data = data[2:]
   425  	if extensionsLength != len(data) {
   426  		return false
   427  	}
   428  
   429  	for len(data) != 0 {
   430  		if len(data) < 4 {
   431  			return false
   432  		}
   433  		fullData := data
   434  		extension := uint16(data[0])<<8 | uint16(data[1])
   435  		length := int(data[2])<<8 | int(data[3])
   436  		data = data[4:]
   437  		if len(data) < length {
   438  			return false
   439  		}
   440  
   441  		switch extension {
   442  		case extensionServerName:
   443  			// we keep only the last name for m.serverName
   444  			if length < 2 {
   445  				return false
   446  			}
   447  			numNames := int(data[0])<<8 | int(data[1])
   448  			d := data[2:]
   449  			for i := 0; i < numNames; i++ {
   450  				if len(d) < 3 {
   451  					return false
   452  				}
   453  				nameType := d[0]
   454  				nameLen := int(d[1])<<8 | int(d[2])
   455  				d = d[3:]
   456  				if len(d) < nameLen {
   457  					return false
   458  				}
   459  				if nameType == 0 {
   460  					m.serverName = string(d[0:nameLen])
   461  					break
   462  				}
   463  				d = d[nameLen:]
   464  			}
   465  		case extensionNextProtoNeg:
   466  			if length > 0 {
   467  				return false
   468  			}
   469  			m.nextProtoNeg = true
   470  		case extensionStatusRequest:
   471  			m.ocspStapling = length > 0 && data[0] == statusTypeOCSP
   472  		case extensionSupportedCurves:
   473  			// http://tools.ietf.org/html/rfc4492#section-5.5.1
   474  			if length < 2 {
   475  				return false
   476  			}
   477  			l := int(data[0])<<8 | int(data[1])
   478  			if l%2 == 1 || length != l+2 {
   479  				return false
   480  			}
   481  			numCurves := l / 2
   482  			m.supportedCurves = make([]CurveID, numCurves)
   483  			d := data[2:]
   484  			for i := 0; i < numCurves; i++ {
   485  				m.supportedCurves[i] = CurveID(d[0])<<8 | CurveID(d[1])
   486  				d = d[2:]
   487  			}
   488  		case extensionSupportedPoints:
   489  			// http://tools.ietf.org/html/rfc4492#section-5.5.2
   490  			if length < 1 {
   491  				return false
   492  			}
   493  			l := int(data[0])
   494  			if length != l+1 {
   495  				return false
   496  			}
   497  			m.supportedPoints = make([]uint8, l)
   498  			copy(m.supportedPoints, data[1:])
   499  		case extensionSessionTicket:
   500  			// http://tools.ietf.org/html/rfc5077#section-3.2
   501  			m.ticketSupported = true
   502  			m.sessionTicket = data[:length]
   503  		case extensionSignatureAlgorithms:
   504  			// https://tools.ietf.org/html/rfc5246#section-7.4.1.4.1
   505  			if length < 2 || length&1 != 0 {
   506  				return false
   507  			}
   508  			l := int(data[0])<<8 | int(data[1])
   509  			if l != length-2 {
   510  				return false
   511  			}
   512  			n := l / 2
   513  			d := data[2:]
   514  			m.signatureAndHashes = make([]SigAndHash, n)
   515  			for i := range m.signatureAndHashes {
   516  				m.signatureAndHashes[i].Hash = d[0]
   517  				m.signatureAndHashes[i].Signature = d[1]
   518  				d = d[2:]
   519  			}
   520  		case extensionRenegotiationInfo:
   521  			if length != 1 || data[0] != 0 {
   522  				return false
   523  			}
   524  			m.secureRenegotiation = true
   525  		case extensionALPN:
   526  			if length < 2 {
   527  				return false
   528  			}
   529  			l := int(data[0])<<8 | int(data[1])
   530  			if l != length-2 {
   531  				return false
   532  			}
   533  			d := data[2:length]
   534  			for len(d) != 0 {
   535  				stringLen := int(d[0])
   536  				d = d[1:]
   537  				if stringLen == 0 || stringLen > len(d) {
   538  					return false
   539  				}
   540  				m.alpnProtocols = append(m.alpnProtocols, string(d[:stringLen]))
   541  				d = d[stringLen:]
   542  			}
   543  		case extensionHeartbeat:
   544  			// https://tools.ietf.org/html/rfc6520
   545  			if length != 1 {
   546  				return false
   547  			}
   548  			mode := data[0]
   549  			if mode != heartbeatModePeerAllowed &&
   550  				mode != heartbeatModePeerNotAllowed {
   551  				return false
   552  			}
   553  			m.heartbeatEnabled = true
   554  			m.heartbeatMode = mode
   555  		case extensionExtendedRandom:
   556  			if length < 3 {
   557  				return false
   558  			}
   559  			exLen := int(data[0])<<8 | int(data[1])
   560  			if length != exLen+2 {
   561  				return false
   562  			}
   563  			if exLen > len(data) {
   564  				return false
   565  			}
   566  			m.extendedRandomEnabled = true
   567  			m.extendedRandom = make([]byte, exLen)
   568  			copy(m.extendedRandom, data[2:])
   569  		case extensionExtendedMasterSecret:
   570  			if length != 0 {
   571  				return false
   572  			}
   573  			m.extendedMasterSecret = true
   574  		case extensionSCT:
   575  			m.scts = true
   576  			if length != 0 {
   577  				return false
   578  			}
   579  		default:
   580  			fullExt := append(fullData[:4], data[:length]...)
   581  			m.unknownExtensions = append(m.unknownExtensions, fullExt)
   582  		}
   583  		data = data[length:]
   584  	}
   585  
   586  	return true
   587  }
   588  
// serverHelloMsg is the in-memory form of a TLS ServerHello handshake
// message; unmarshal fills it from wire bytes and marshal serializes it.
type serverHelloMsg struct {
	raw                   []byte   // full encoding incl. 4-byte handshake header; returned as-is by marshal when non-nil, set to the input by unmarshal
	vers                  uint16   // server_version (first two body bytes)
	random                []byte   // 32-byte server random
	sessionId             []byte   // session ID; unmarshal rejects more than 32 bytes
	cipherSuite           uint16   // selected cipher suite
	compressionMethod     uint8    // selected compression method
	nextProtoNeg          bool     // next_protocol_negotiation extension present
	nextProtos            []string // protocols listed in next_protocol_negotiation
	ocspStapling          bool     // empty status_request extension present
	scts                  [][]byte // SCTs from the SCT extension, each without its 2-byte length prefix
	ticketSupported       bool     // empty session_ticket extension present
	secureRenegotiation   bool     // renegotiation_info extension present
	heartbeatEnabled      bool     // heartbeat extension present
	heartbeatMode         uint8    // heartbeat extension's mode byte
	extendedRandomEnabled bool     // extended random extension present
	extendedRandom        []byte   // extended random value, without its 2-byte length prefix
	extendedMasterSecret  bool     // empty extended_master_secret extension present
	alpnProtocol          string   // selected ALPN protocol; marshal omits the extension when empty
	unknownExtensions     [][]byte // unrecognized extensions, each including its 4-byte type/length header; marshal re-emits them verbatim
}
   610  
   611  func (m *serverHelloMsg) equal(i interface{}) bool {
   612  	m1, ok := i.(*serverHelloMsg)
   613  	if !ok {
   614  		return false
   615  	}
   616  
   617  	return bytes.Equal(m.raw, m1.raw) &&
   618  		m.vers == m1.vers &&
   619  		bytes.Equal(m.random, m1.random) &&
   620  		bytes.Equal(m.sessionId, m1.sessionId) &&
   621  		m.cipherSuite == m1.cipherSuite &&
   622  		m.compressionMethod == m1.compressionMethod &&
   623  		m.nextProtoNeg == m1.nextProtoNeg &&
   624  		eqStrings(m.nextProtos, m1.nextProtos) &&
   625  		m.ocspStapling == m1.ocspStapling &&
   626  		m.ticketSupported == m1.ticketSupported &&
   627  		m.secureRenegotiation == m1.secureRenegotiation &&
   628  		m.extendedMasterSecret == m1.extendedMasterSecret &&
   629  		m.alpnProtocol == m1.alpnProtocol &&
   630  		reflect.DeepEqual(m.unknownExtensions, m1.unknownExtensions)
   631  }
   632  
   633  func (m *serverHelloMsg) marshal() []byte {
   634  	if m.raw != nil {
   635  		return m.raw
   636  	}
   637  
   638  	length := 38 + len(m.sessionId)
   639  	numExtensions := 0
   640  	extensionsLength := 0
   641  
   642  	nextProtoLen := 0
   643  	if m.nextProtoNeg {
   644  		numExtensions++
   645  		for _, v := range m.nextProtos {
   646  			nextProtoLen += len(v)
   647  		}
   648  		nextProtoLen += len(m.nextProtos)
   649  		extensionsLength += nextProtoLen
   650  	}
   651  	if m.ocspStapling {
   652  		numExtensions++
   653  	}
   654  	if m.ticketSupported {
   655  		numExtensions++
   656  	}
   657  	if m.secureRenegotiation {
   658  		extensionsLength += 1
   659  		numExtensions++
   660  	}
   661  	if alpnLen := len(m.alpnProtocol); alpnLen > 0 {
   662  		if alpnLen >= 256 {
   663  			panic("invalid ALPN protocol")
   664  		}
   665  		extensionsLength += 2 + 1 + alpnLen
   666  		numExtensions++
   667  	}
   668  	if m.heartbeatEnabled {
   669  		extensionsLength += 1
   670  		numExtensions++
   671  	}
   672  	if m.extendedRandomEnabled {
   673  		extensionsLength += 2 + len(m.extendedRandom)
   674  		numExtensions++
   675  	}
   676  	if m.extendedMasterSecret {
   677  		numExtensions++
   678  	}
   679  	sctLen := 0
   680  	if len(m.scts) > 0 {
   681  		for _, sct := range m.scts {
   682  			sctLen += len(sct) + 2
   683  		}
   684  		extensionsLength += 2 + sctLen
   685  		numExtensions++
   686  	}
   687  	if len(m.unknownExtensions) > 0 {
   688  		// we do not update numExtensions because the extension code and length
   689  		// are already contained at the beginning of every 'ext' below
   690  		for _, ext := range m.unknownExtensions {
   691  			extensionsLength += len(ext)
   692  		}
   693  	}
   694  	if numExtensions > 0 {
   695  		extensionsLength += 4 * numExtensions
   696  		length += 2 + extensionsLength
   697  	}
   698  
   699  	x := make([]byte, 4+length)
   700  	x[0] = typeServerHello
   701  	x[1] = uint8(length >> 16)
   702  	x[2] = uint8(length >> 8)
   703  	x[3] = uint8(length)
   704  	x[4] = uint8(m.vers >> 8)
   705  	x[5] = uint8(m.vers)
   706  	copy(x[6:38], m.random)
   707  	x[38] = uint8(len(m.sessionId))
   708  	copy(x[39:39+len(m.sessionId)], m.sessionId)
   709  	z := x[39+len(m.sessionId):]
   710  	z[0] = uint8(m.cipherSuite >> 8)
   711  	z[1] = uint8(m.cipherSuite)
   712  	z[2] = uint8(m.compressionMethod)
   713  
   714  	z = z[3:]
   715  	if numExtensions > 0 || len(m.unknownExtensions) > 0 {
   716  		z[0] = byte(extensionsLength >> 8)
   717  		z[1] = byte(extensionsLength)
   718  		z = z[2:]
   719  	}
   720  	if m.nextProtoNeg {
   721  		z[0] = byte(extensionNextProtoNeg >> 8)
   722  		z[1] = byte(extensionNextProtoNeg & 0xff)
   723  		z[2] = byte(nextProtoLen >> 8)
   724  		z[3] = byte(nextProtoLen)
   725  		z = z[4:]
   726  
   727  		for _, v := range m.nextProtos {
   728  			l := len(v)
   729  			if l > 255 {
   730  				l = 255
   731  			}
   732  			z[0] = byte(l)
   733  			copy(z[1:], []byte(v[0:l]))
   734  			z = z[1+l:]
   735  		}
   736  	}
   737  	if m.ocspStapling {
   738  		z[0] = byte(extensionStatusRequest >> 8)
   739  		z[1] = byte(extensionStatusRequest)
   740  		z = z[4:]
   741  	}
   742  	if m.ticketSupported {
   743  		z[0] = byte(extensionSessionTicket >> 8)
   744  		z[1] = byte(extensionSessionTicket)
   745  		z = z[4:]
   746  	}
   747  	if m.secureRenegotiation {
   748  		z[0] = byte(extensionRenegotiationInfo >> 8)
   749  		z[1] = byte(extensionRenegotiationInfo & 0xff)
   750  		z[2] = 0
   751  		z[3] = 1
   752  		z = z[5:]
   753  	}
   754  	if alpnLen := len(m.alpnProtocol); alpnLen > 0 {
   755  		z[0] = byte(extensionALPN >> 8)
   756  		z[1] = byte(extensionALPN & 0xff)
   757  		l := 2 + 1 + alpnLen
   758  		z[2] = byte(l >> 8)
   759  		z[3] = byte(l)
   760  		l -= 2
   761  		z[4] = byte(l >> 8)
   762  		z[5] = byte(l)
   763  		l -= 1
   764  		z[6] = byte(l)
   765  		copy(z[7:], []byte(m.alpnProtocol))
   766  		z = z[7+alpnLen:]
   767  	}
   768  	if m.heartbeatEnabled {
   769  		z[0] = byte(extensionHeartbeat >> 8)
   770  		z[1] = byte(extensionHeartbeat)
   771  		z[2] = byte(1 >> 8)
   772  		z[3] = byte(1)
   773  		z[4] = m.heartbeatMode
   774  		z = z[5:]
   775  	}
   776  	if m.extendedRandomEnabled {
   777  		z[0] = byte(extensionExtendedRandom >> 8)
   778  		z[1] = byte(extensionExtendedRandom)
   779  		exLen := len(m.extendedRandom)
   780  		fullLength := 2 + exLen
   781  		z[2] = byte(fullLength << 8)
   782  		z[3] = byte(fullLength)
   783  		z[4] = byte(exLen << 8)
   784  		z[5] = byte(exLen)
   785  		z = z[6:]
   786  		copy(z, m.extendedRandom)
   787  		z = z[exLen:]
   788  	}
   789  	if m.extendedMasterSecret {
   790  		z[0] = byte(extensionExtendedMasterSecret >> 8)
   791  		z[1] = byte(extensionExtendedMasterSecret & 0xff)
   792  		z = z[4:]
   793  	}
   794  	if sctLen > 0 {
   795  		z[0] = byte(extensionSCT >> 8)
   796  		z[1] = byte(extensionSCT)
   797  		l := sctLen + 2
   798  		z[2] = byte(l >> 8)
   799  		z[3] = byte(l)
   800  		z[4] = byte(sctLen >> 8)
   801  		z[5] = byte(sctLen)
   802  
   803  		z = z[6:]
   804  		for _, sct := range m.scts {
   805  			z[0] = byte(len(sct) >> 8)
   806  			z[1] = byte(len(sct))
   807  			copy(z[2:], sct)
   808  			z = z[len(sct)+2:]
   809  		}
   810  	}
   811  	if len(m.unknownExtensions) > 0 {
   812  		for _, ext := range m.unknownExtensions {
   813  			copy(z, ext)
   814  			z = z[len(ext):]
   815  		}
   816  	}
   817  	m.raw = x
   818  	return x
   819  }
   820  
   821  func (m *serverHelloMsg) unmarshal(data []byte) bool {
   822  	if len(data) < 42 {
   823  		return false
   824  	}
   825  	m.raw = data
   826  	m.vers = uint16(data[4])<<8 | uint16(data[5])
   827  	m.random = data[6:38]
   828  	sessionIdLen := int(data[38])
   829  	if sessionIdLen > 32 || len(data) < 39+sessionIdLen {
   830  		return false
   831  	}
   832  	m.sessionId = data[39 : 39+sessionIdLen]
   833  	data = data[39+sessionIdLen:]
   834  	if len(data) < 3 {
   835  		return false
   836  	}
   837  	m.cipherSuite = uint16(data[0])<<8 | uint16(data[1])
   838  	m.compressionMethod = data[2]
   839  	data = data[3:]
   840  
   841  	m.nextProtoNeg = false
   842  	m.nextProtos = nil
   843  	m.scts = nil
   844  	m.ocspStapling = false
   845  	m.ticketSupported = false
   846  	m.heartbeatEnabled = false
   847  	m.extendedRandomEnabled = false
   848  	m.extendedMasterSecret = false
   849  	m.alpnProtocol = ""
   850  	m.unknownExtensions = [][]byte(nil)
   851  
   852  	if len(data) == 0 {
   853  		// ServerHello is optionally followed by extension data
   854  		return true
   855  	}
   856  	if len(data) < 2 {
   857  		return false
   858  	}
   859  
   860  	extensionsLength := int(data[0])<<8 | int(data[1])
   861  	data = data[2:]
   862  	if len(data) != extensionsLength {
   863  		return false
   864  	}
   865  
   866  	for len(data) != 0 {
   867  		if len(data) < 4 {
   868  			return false
   869  		}
   870  		fullData := data
   871  		extension := uint16(data[0])<<8 | uint16(data[1])
   872  		length := int(data[2])<<8 | int(data[3])
   873  		data = data[4:]
   874  		if len(data) < length {
   875  			return false
   876  		}
   877  
   878  		switch extension {
   879  		case extensionNextProtoNeg:
   880  			m.nextProtoNeg = true
   881  			d := data[:length]
   882  			for len(d) > 0 {
   883  				l := int(d[0])
   884  				d = d[1:]
   885  				if l == 0 || l > len(d) {
   886  					return false
   887  				}
   888  				m.nextProtos = append(m.nextProtos, string(d[:l]))
   889  				d = d[l:]
   890  			}
   891  		case extensionStatusRequest:
   892  			if length > 0 {
   893  				return false
   894  			}
   895  			m.ocspStapling = true
   896  		case extensionSessionTicket:
   897  			if length > 0 {
   898  				return false
   899  			}
   900  			m.ticketSupported = true
   901  		case extensionRenegotiationInfo:
   902  			if length != 1 || data[0] != 0 {
   903  				return false
   904  			}
   905  			m.secureRenegotiation = true
   906  		case extensionALPN:
   907  			d := data[:length]
   908  			if len(d) < 3 {
   909  				return false
   910  			}
   911  			l := int(d[0])<<8 | int(d[1])
   912  			if l != len(d)-2 {
   913  				return false
   914  			}
   915  			d = d[2:]
   916  			l = int(d[0])
   917  			if l != len(d)-1 {
   918  				return false
   919  			}
   920  			d = d[1:]
   921  			if len(d) == 0 {
   922  				// ALPN protocols must not be empty.
   923  				return false
   924  			}
   925  			m.alpnProtocol = string(d)
   926  		case extensionHeartbeat:
   927  			m.heartbeatEnabled = true
   928  			m.heartbeatMode = data[0]
   929  		case extensionExtendedRandom:
   930  			m.extendedRandomEnabled = true
   931  			if length < 3 {
   932  				return false
   933  			}
   934  			exRandLen := int(data[0])<<8 | int(data[1])
   935  			if length != exRandLen+2 {
   936  				return false
   937  			}
   938  			exRand := data[2:]
   939  			if len(exRand) < exRandLen {
   940  				return false
   941  			}
   942  			m.extendedRandom = make([]byte, exRandLen)
   943  			copy(m.extendedRandom, data[2:])
   944  		case extensionExtendedMasterSecret:
   945  			if length != 0 {
   946  				return false
   947  			}
   948  			m.extendedMasterSecret = true
   949  
   950  		case extensionSCT:
   951  			d := data[:length]
   952  			if len(d) < 2 {
   953  				return false
   954  			}
   955  			l := int(d[0])<<8 | int(d[1])
   956  			d = d[2:]
   957  			if len(d) != l || l == 0 {
   958  				return false
   959  			}
   960  			m.scts = make([][]byte, 0, 3)
   961  			for len(d) != 0 {
   962  				if len(d) < 2 {
   963  					return false
   964  				}
   965  				sctLen := int(d[0])<<8 | int(d[1])
   966  				d = d[2:]
   967  				if sctLen == 0 || len(d) < sctLen {
   968  					return false
   969  				}
   970  				m.scts = append(m.scts, d[:sctLen])
   971  				d = d[sctLen:]
   972  			}
   973  		default:
   974  			fullExt := append(fullData[:4], data[:length]...)
   975  			m.unknownExtensions = append(m.unknownExtensions, fullExt)
   976  		}
   977  		data = data[length:]
   978  	}
   979  
   980  	return true
   981  }
   982  
// certificateMsg is the in-memory form of a TLS Certificate handshake
// message.
type certificateMsg struct {
	raw          []byte   // full encoding incl. 4-byte handshake header; returned as-is by marshal when non-nil, set to the input by unmarshal
	certificates [][]byte // certificate list entries, each without its 3-byte length prefix
}
   987  
   988  func (m *certificateMsg) equal(i interface{}) bool {
   989  	m1, ok := i.(*certificateMsg)
   990  	if !ok {
   991  		return false
   992  	}
   993  
   994  	return bytes.Equal(m.raw, m1.raw) &&
   995  		eqByteSlices(m.certificates, m1.certificates)
   996  }
   997  
   998  func (m *certificateMsg) marshal() (x []byte) {
   999  	if m.raw != nil {
  1000  		return m.raw
  1001  	}
  1002  
  1003  	var i int
  1004  	for _, slice := range m.certificates {
  1005  		i += len(slice)
  1006  	}
  1007  
  1008  	length := 3 + 3*len(m.certificates) + i
  1009  	x = make([]byte, 4+length)
  1010  	x[0] = typeCertificate
  1011  	x[1] = uint8(length >> 16)
  1012  	x[2] = uint8(length >> 8)
  1013  	x[3] = uint8(length)
  1014  
  1015  	certificateOctets := length - 3
  1016  	x[4] = uint8(certificateOctets >> 16)
  1017  	x[5] = uint8(certificateOctets >> 8)
  1018  	x[6] = uint8(certificateOctets)
  1019  
  1020  	y := x[7:]
  1021  	for _, slice := range m.certificates {
  1022  		y[0] = uint8(len(slice) >> 16)
  1023  		y[1] = uint8(len(slice) >> 8)
  1024  		y[2] = uint8(len(slice))
  1025  		copy(y[3:], slice)
  1026  		y = y[3+len(slice):]
  1027  	}
  1028  
  1029  	m.raw = x
  1030  	return
  1031  }
  1032  
  1033  func (m *certificateMsg) unmarshal(data []byte) bool {
  1034  	if len(data) < 7 {
  1035  		return false
  1036  	}
  1037  
  1038  	m.raw = data
  1039  	certsLen := uint32(data[4])<<16 | uint32(data[5])<<8 | uint32(data[6])
  1040  	if uint32(len(data)) != certsLen+7 {
  1041  		return false
  1042  	}
  1043  
  1044  	numCerts := 0
  1045  	d := data[7:]
  1046  	for certsLen > 0 {
  1047  		if len(d) < 4 {
  1048  			return false
  1049  		}
  1050  		certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2])
  1051  		if uint32(len(d)) < 3+certLen {
  1052  			return false
  1053  		}
  1054  		d = d[3+certLen:]
  1055  		certsLen -= 3 + certLen
  1056  		numCerts++
  1057  	}
  1058  
  1059  	m.certificates = make([][]byte, numCerts)
  1060  	d = data[7:]
  1061  	for i := 0; i < numCerts; i++ {
  1062  		certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2])
  1063  		m.certificates[i] = d[3 : 3+certLen]
  1064  		d = d[3+certLen:]
  1065  	}
  1066  
  1067  	return true
  1068  }
  1069  
// serverKeyExchangeMsg is the ServerKeyExchange handshake message. Its body
// is stored uninterpreted in key.
type serverKeyExchangeMsg struct {
	raw    []byte // cached wire encoding, including the 4-byte handshake header
	digest []byte // not set or compared by the methods here; NOTE(review): confirm where it is populated
	key    []byte // entire message body following the header, uninterpreted
}
  1075  
  1076  func (m *serverKeyExchangeMsg) equal(i interface{}) bool {
  1077  	m1, ok := i.(*serverKeyExchangeMsg)
  1078  	if !ok {
  1079  		return false
  1080  	}
  1081  
  1082  	return bytes.Equal(m.raw, m1.raw) &&
  1083  		bytes.Equal(m.key, m1.key)
  1084  }
  1085  
  1086  func (m *serverKeyExchangeMsg) marshal() []byte {
  1087  	if m.raw != nil {
  1088  		return m.raw
  1089  	}
  1090  	length := len(m.key)
  1091  	x := make([]byte, length+4)
  1092  	x[0] = typeServerKeyExchange
  1093  	x[1] = uint8(length >> 16)
  1094  	x[2] = uint8(length >> 8)
  1095  	x[3] = uint8(length)
  1096  	copy(x[4:], m.key)
  1097  
  1098  	m.raw = x
  1099  	return x
  1100  }
  1101  
  1102  func (m *serverKeyExchangeMsg) unmarshal(data []byte) bool {
  1103  	m.raw = data
  1104  	if len(data) < 4 {
  1105  		return false
  1106  	}
  1107  	m.key = data[4:]
  1108  	return true
  1109  }
  1110  
// certificateStatusMsg is the CertificateStatus handshake message.
type certificateStatusMsg struct {
	raw        []byte // cached wire encoding, including the 4-byte handshake header
	statusType uint8  // status type byte; only statusTypeOCSP carries a response body here
	response   []byte // OCSP response bytes when statusType == statusTypeOCSP; nil otherwise after unmarshal
}
  1116  
  1117  func (m *certificateStatusMsg) equal(i interface{}) bool {
  1118  	m1, ok := i.(*certificateStatusMsg)
  1119  	if !ok {
  1120  		return false
  1121  	}
  1122  
  1123  	return bytes.Equal(m.raw, m1.raw) &&
  1124  		m.statusType == m1.statusType &&
  1125  		bytes.Equal(m.response, m1.response)
  1126  }
  1127  
  1128  func (m *certificateStatusMsg) marshal() []byte {
  1129  	if m.raw != nil {
  1130  		return m.raw
  1131  	}
  1132  
  1133  	var x []byte
  1134  	if m.statusType == statusTypeOCSP {
  1135  		x = make([]byte, 4+4+len(m.response))
  1136  		x[0] = typeCertificateStatus
  1137  		l := len(m.response) + 4
  1138  		x[1] = byte(l >> 16)
  1139  		x[2] = byte(l >> 8)
  1140  		x[3] = byte(l)
  1141  		x[4] = statusTypeOCSP
  1142  
  1143  		l -= 4
  1144  		x[5] = byte(l >> 16)
  1145  		x[6] = byte(l >> 8)
  1146  		x[7] = byte(l)
  1147  		copy(x[8:], m.response)
  1148  	} else {
  1149  		x = []byte{typeCertificateStatus, 0, 0, 1, m.statusType}
  1150  	}
  1151  
  1152  	m.raw = x
  1153  	return x
  1154  }
  1155  
  1156  func (m *certificateStatusMsg) unmarshal(data []byte) bool {
  1157  	m.raw = data
  1158  	if len(data) < 5 {
  1159  		return false
  1160  	}
  1161  	m.statusType = data[4]
  1162  
  1163  	m.response = nil
  1164  	if m.statusType == statusTypeOCSP {
  1165  		if len(data) < 8 {
  1166  			return false
  1167  		}
  1168  		respLen := uint32(data[5])<<16 | uint32(data[6])<<8 | uint32(data[7])
  1169  		if uint32(len(data)) != 4+4+respLen {
  1170  			return false
  1171  		}
  1172  		m.response = data[8:]
  1173  	}
  1174  	return true
  1175  }
  1176  
// serverHelloDoneMsg is the ServerHelloDone handshake message. Its body is
// empty, so the type carries no state.
type serverHelloDoneMsg struct{}
  1178  
  1179  func (m *serverHelloDoneMsg) equal(i interface{}) bool {
  1180  	_, ok := i.(*serverHelloDoneMsg)
  1181  	return ok
  1182  }
  1183  
  1184  func (m *serverHelloDoneMsg) marshal() []byte {
  1185  	x := make([]byte, 4)
  1186  	x[0] = typeServerHelloDone
  1187  	return x
  1188  }
  1189  
  1190  func (m *serverHelloDoneMsg) unmarshal(data []byte) bool {
  1191  	return len(data) == 4
  1192  }
  1193  
// clientKeyExchangeMsg is the ClientKeyExchange handshake message.
type clientKeyExchangeMsg struct {
	raw        []byte // cached wire encoding, including the 4-byte handshake header
	ciphertext []byte // entire message body after the header, uninterpreted
}
  1198  
  1199  func (m *clientKeyExchangeMsg) equal(i interface{}) bool {
  1200  	m1, ok := i.(*clientKeyExchangeMsg)
  1201  	if !ok {
  1202  		return false
  1203  	}
  1204  
  1205  	return bytes.Equal(m.raw, m1.raw) &&
  1206  		bytes.Equal(m.ciphertext, m1.ciphertext)
  1207  }
  1208  
  1209  func (m *clientKeyExchangeMsg) marshal() []byte {
  1210  	if m.raw != nil {
  1211  		return m.raw
  1212  	}
  1213  	length := len(m.ciphertext)
  1214  	x := make([]byte, length+4)
  1215  	x[0] = typeClientKeyExchange
  1216  	x[1] = uint8(length >> 16)
  1217  	x[2] = uint8(length >> 8)
  1218  	x[3] = uint8(length)
  1219  	copy(x[4:], m.ciphertext)
  1220  
  1221  	m.raw = x
  1222  	return x
  1223  }
  1224  
  1225  func (m *clientKeyExchangeMsg) unmarshal(data []byte) bool {
  1226  	m.raw = data
  1227  	if len(data) < 4 {
  1228  		return false
  1229  	}
  1230  	l := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
  1231  	if l != len(data)-4 {
  1232  		return false
  1233  	}
  1234  	m.ciphertext = data[4:]
  1235  	return true
  1236  }
  1237  
// finishedMsg is the Finished handshake message.
type finishedMsg struct {
	raw        []byte // cached wire encoding, including the 4-byte handshake header
	verifyData []byte // message body following the header
}
  1242  
  1243  func (m *finishedMsg) equal(i interface{}) bool {
  1244  	m1, ok := i.(*finishedMsg)
  1245  	if !ok {
  1246  		return false
  1247  	}
  1248  
  1249  	return bytes.Equal(m.raw, m1.raw) &&
  1250  		bytes.Equal(m.verifyData, m1.verifyData)
  1251  }
  1252  
  1253  func (m *finishedMsg) marshal() (x []byte) {
  1254  	if m.raw != nil {
  1255  		return m.raw
  1256  	}
  1257  
  1258  	x = make([]byte, 4+len(m.verifyData))
  1259  	x[0] = typeFinished
  1260  	x[3] = byte(len(m.verifyData))
  1261  	copy(x[4:], m.verifyData)
  1262  	m.raw = x
  1263  	return
  1264  }
  1265  
  1266  func (m *finishedMsg) unmarshal(data []byte) bool {
  1267  	m.raw = data
  1268  	if len(data) < 4 {
  1269  		return false
  1270  	}
  1271  	m.verifyData = data[4:]
  1272  	return true
  1273  }
  1274  
// nextProtoMsg is the NextProtocol handshake message: a length-prefixed
// protocol name followed by length-prefixed zero padding.
type nextProtoMsg struct {
	raw   []byte // cached wire encoding, including the 4-byte handshake header
	proto string // selected protocol name; marshal encodes at most 255 bytes of it
}
  1279  
  1280  func (m *nextProtoMsg) equal(i interface{}) bool {
  1281  	m1, ok := i.(*nextProtoMsg)
  1282  	if !ok {
  1283  		return false
  1284  	}
  1285  
  1286  	return bytes.Equal(m.raw, m1.raw) &&
  1287  		m.proto == m1.proto
  1288  }
  1289  
  1290  func (m *nextProtoMsg) marshal() []byte {
  1291  	if m.raw != nil {
  1292  		return m.raw
  1293  	}
  1294  	l := len(m.proto)
  1295  	if l > 255 {
  1296  		l = 255
  1297  	}
  1298  
  1299  	padding := 32 - (l+2)%32
  1300  	length := l + padding + 2
  1301  	x := make([]byte, length+4)
  1302  	x[0] = typeNextProtocol
  1303  	x[1] = uint8(length >> 16)
  1304  	x[2] = uint8(length >> 8)
  1305  	x[3] = uint8(length)
  1306  
  1307  	y := x[4:]
  1308  	y[0] = byte(l)
  1309  	copy(y[1:], []byte(m.proto[0:l]))
  1310  	y = y[1+l:]
  1311  	y[0] = byte(padding)
  1312  
  1313  	m.raw = x
  1314  
  1315  	return x
  1316  }
  1317  
  1318  func (m *nextProtoMsg) unmarshal(data []byte) bool {
  1319  	m.raw = data
  1320  
  1321  	if len(data) < 5 {
  1322  		return false
  1323  	}
  1324  	data = data[4:]
  1325  	protoLen := int(data[0])
  1326  	data = data[1:]
  1327  	if len(data) < protoLen {
  1328  		return false
  1329  	}
  1330  	m.proto = string(data[0:protoLen])
  1331  	data = data[protoLen:]
  1332  
  1333  	if len(data) < 1 {
  1334  		return false
  1335  	}
  1336  	paddingLen := int(data[0])
  1337  	data = data[1:]
  1338  	if len(data) != paddingLen {
  1339  		return false
  1340  	}
  1341  
  1342  	return true
  1343  }
  1344  
// certificateRequestMsg is the CertificateRequest handshake message.
type certificateRequestMsg struct {
	raw []byte // cached wire encoding, including the 4-byte handshake header
	// hasSignatureAndHash indicates whether this message includes a list
	// of signature and hash functions. This change was introduced with TLS
	// 1.2. It is not read from the wire: set it before calling unmarshal,
	// which uses it to decide whether to parse that list.
	hasSignatureAndHash bool

	certificateTypes       []byte       // requested certificate types (1-byte count prefix on the wire)
	signatureAndHashes     []SigAndHash // encoded/decoded only when hasSignatureAndHash is true
	certificateAuthorities [][]byte     // certificate authority entries, kept opaque (each 2-byte length-prefixed)
}
  1356  
  1357  func (m *certificateRequestMsg) equal(i interface{}) bool {
  1358  	m1, ok := i.(*certificateRequestMsg)
  1359  	if !ok {
  1360  		return false
  1361  	}
  1362  
  1363  	return bytes.Equal(m.raw, m1.raw) &&
  1364  		bytes.Equal(m.certificateTypes, m1.certificateTypes) &&
  1365  		eqByteSlices(m.certificateAuthorities, m1.certificateAuthorities) &&
  1366  		eqSignatureAndHashes(m.signatureAndHashes, m1.signatureAndHashes)
  1367  }
  1368  
  1369  func (m *certificateRequestMsg) marshal() (x []byte) {
  1370  	if m.raw != nil {
  1371  		return m.raw
  1372  	}
  1373  
  1374  	// See http://tools.ietf.org/html/rfc4346#section-7.4.4
  1375  	length := 1 + len(m.certificateTypes) + 2
  1376  	casLength := 0
  1377  	for _, ca := range m.certificateAuthorities {
  1378  		casLength += 2 + len(ca)
  1379  	}
  1380  	length += casLength
  1381  
  1382  	if m.hasSignatureAndHash {
  1383  		length += 2 + 2*len(m.signatureAndHashes)
  1384  	}
  1385  
  1386  	x = make([]byte, 4+length)
  1387  	x[0] = typeCertificateRequest
  1388  	x[1] = uint8(length >> 16)
  1389  	x[2] = uint8(length >> 8)
  1390  	x[3] = uint8(length)
  1391  
  1392  	x[4] = uint8(len(m.certificateTypes))
  1393  
  1394  	copy(x[5:], m.certificateTypes)
  1395  	y := x[5+len(m.certificateTypes):]
  1396  
  1397  	if m.hasSignatureAndHash {
  1398  		n := len(m.signatureAndHashes) * 2
  1399  		y[0] = uint8(n >> 8)
  1400  		y[1] = uint8(n)
  1401  		y = y[2:]
  1402  		for _, sigAndHash := range m.signatureAndHashes {
  1403  			y[0] = sigAndHash.Hash
  1404  			y[1] = sigAndHash.Signature
  1405  			y = y[2:]
  1406  		}
  1407  	}
  1408  
  1409  	y[0] = uint8(casLength >> 8)
  1410  	y[1] = uint8(casLength)
  1411  	y = y[2:]
  1412  	for _, ca := range m.certificateAuthorities {
  1413  		y[0] = uint8(len(ca) >> 8)
  1414  		y[1] = uint8(len(ca))
  1415  		y = y[2:]
  1416  		copy(y, ca)
  1417  		y = y[len(ca):]
  1418  	}
  1419  
  1420  	m.raw = x
  1421  	return
  1422  }
  1423  
  1424  func (m *certificateRequestMsg) unmarshal(data []byte) bool {
  1425  	m.raw = data
  1426  
  1427  	if len(data) < 5 {
  1428  		return false
  1429  	}
  1430  
  1431  	length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3])
  1432  	if uint32(len(data))-4 != length {
  1433  		return false
  1434  	}
  1435  
  1436  	numCertTypes := int(data[4])
  1437  	data = data[5:]
  1438  	if numCertTypes == 0 || len(data) <= numCertTypes {
  1439  		return false
  1440  	}
  1441  
  1442  	m.certificateTypes = make([]byte, numCertTypes)
  1443  	if copy(m.certificateTypes, data) != numCertTypes {
  1444  		return false
  1445  	}
  1446  
  1447  	data = data[numCertTypes:]
  1448  
  1449  	if m.hasSignatureAndHash {
  1450  		if len(data) < 2 {
  1451  			return false
  1452  		}
  1453  		sigAndHashLen := uint16(data[0])<<8 | uint16(data[1])
  1454  		data = data[2:]
  1455  		if sigAndHashLen&1 != 0 {
  1456  			return false
  1457  		}
  1458  		if len(data) < int(sigAndHashLen) {
  1459  			return false
  1460  		}
  1461  		numSigAndHash := sigAndHashLen / 2
  1462  		m.signatureAndHashes = make([]SigAndHash, numSigAndHash)
  1463  		for i := range m.signatureAndHashes {
  1464  			m.signatureAndHashes[i].Hash = data[0]
  1465  			m.signatureAndHashes[i].Signature = data[1]
  1466  			data = data[2:]
  1467  		}
  1468  	}
  1469  
  1470  	if len(data) < 2 {
  1471  		return false
  1472  	}
  1473  	casLength := uint16(data[0])<<8 | uint16(data[1])
  1474  	data = data[2:]
  1475  	if len(data) < int(casLength) {
  1476  		return false
  1477  	}
  1478  	cas := make([]byte, casLength)
  1479  	copy(cas, data)
  1480  	data = data[casLength:]
  1481  
  1482  	m.certificateAuthorities = nil
  1483  	for len(cas) > 0 {
  1484  		if len(cas) < 2 {
  1485  			return false
  1486  		}
  1487  		caLen := uint16(cas[0])<<8 | uint16(cas[1])
  1488  		cas = cas[2:]
  1489  
  1490  		if len(cas) < int(caLen) {
  1491  			return false
  1492  		}
  1493  
  1494  		m.certificateAuthorities = append(m.certificateAuthorities, cas[:caLen])
  1495  		cas = cas[caLen:]
  1496  	}
  1497  	if len(data) > 0 {
  1498  		return false
  1499  	}
  1500  
  1501  	return true
  1502  }
  1503  
// certificateVerifyMsg is the CertificateVerify handshake message.
type certificateVerifyMsg struct {
	raw                 []byte     // cached wire encoding, including the 4-byte handshake header
	hasSignatureAndHash bool       // set before unmarshal; selects whether a hash/signature byte pair precedes the signature
	signatureAndHash    SigAndHash // algorithm pair, encoded/decoded only when hasSignatureAndHash is true
	signature           []byte     // signature bytes (2-byte length-prefixed on the wire)
}
  1510  
  1511  func (m *certificateVerifyMsg) equal(i interface{}) bool {
  1512  	m1, ok := i.(*certificateVerifyMsg)
  1513  	if !ok {
  1514  		return false
  1515  	}
  1516  
  1517  	return bytes.Equal(m.raw, m1.raw) &&
  1518  		m.hasSignatureAndHash == m1.hasSignatureAndHash &&
  1519  		m.signatureAndHash.Hash == m1.signatureAndHash.Hash &&
  1520  		m.signatureAndHash.Signature == m1.signatureAndHash.Signature &&
  1521  		bytes.Equal(m.signature, m1.signature)
  1522  }
  1523  
  1524  func (m *certificateVerifyMsg) marshal() (x []byte) {
  1525  	if m.raw != nil {
  1526  		return m.raw
  1527  	}
  1528  
  1529  	// See http://tools.ietf.org/html/rfc4346#section-7.4.8
  1530  	siglength := len(m.signature)
  1531  	length := 2 + siglength
  1532  	if m.hasSignatureAndHash {
  1533  		length += 2
  1534  	}
  1535  	x = make([]byte, 4+length)
  1536  	x[0] = typeCertificateVerify
  1537  	x[1] = uint8(length >> 16)
  1538  	x[2] = uint8(length >> 8)
  1539  	x[3] = uint8(length)
  1540  	y := x[4:]
  1541  	if m.hasSignatureAndHash {
  1542  		y[0] = m.signatureAndHash.Hash
  1543  		y[1] = m.signatureAndHash.Signature
  1544  		y = y[2:]
  1545  	}
  1546  	y[0] = uint8(siglength >> 8)
  1547  	y[1] = uint8(siglength)
  1548  	copy(y[2:], m.signature)
  1549  
  1550  	m.raw = x
  1551  
  1552  	return
  1553  }
  1554  
  1555  func (m *certificateVerifyMsg) unmarshal(data []byte) bool {
  1556  	m.raw = data
  1557  
  1558  	if len(data) < 6 {
  1559  		return false
  1560  	}
  1561  
  1562  	length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3])
  1563  	if uint32(len(data))-4 != length {
  1564  		return false
  1565  	}
  1566  
  1567  	data = data[4:]
  1568  	if m.hasSignatureAndHash {
  1569  		m.signatureAndHash.Hash = data[0]
  1570  		m.signatureAndHash.Signature = data[1]
  1571  		data = data[2:]
  1572  	}
  1573  
  1574  	if len(data) < 2 {
  1575  		return false
  1576  	}
  1577  	siglength := int(data[0])<<8 + int(data[1])
  1578  	data = data[2:]
  1579  	if len(data) != siglength {
  1580  		return false
  1581  	}
  1582  
  1583  	m.signature = data
  1584  
  1585  	return true
  1586  }
  1587  
// newSessionTicketMsg is the NewSessionTicket handshake message (RFC 5077).
type newSessionTicketMsg struct {
	raw          []byte // cached wire encoding, including the 4-byte handshake header
	ticket       []byte // opaque session ticket
	lifetimeHint uint32 // ticket lifetime hint, read from body bytes 0..3
}
  1593  
  1594  func (m *newSessionTicketMsg) equal(i interface{}) bool {
  1595  	m1, ok := i.(*newSessionTicketMsg)
  1596  	if !ok {
  1597  		return false
  1598  	}
  1599  
  1600  	return bytes.Equal(m.raw, m1.raw) &&
  1601  		bytes.Equal(m.ticket, m1.ticket)
  1602  }
  1603  
  1604  func (m *newSessionTicketMsg) marshal() (x []byte) {
  1605  	if m.raw != nil {
  1606  		return m.raw
  1607  	}
  1608  
  1609  	// See http://tools.ietf.org/html/rfc5077#section-3.3
  1610  	ticketLen := len(m.ticket)
  1611  	length := 2 + 4 + ticketLen
  1612  	x = make([]byte, 4+length)
  1613  	x[0] = typeNewSessionTicket
  1614  	x[1] = uint8(length >> 16)
  1615  	x[2] = uint8(length >> 8)
  1616  	x[3] = uint8(length)
  1617  	x[8] = uint8(ticketLen >> 8)
  1618  	x[9] = uint8(ticketLen)
  1619  	copy(x[10:], m.ticket)
  1620  
  1621  	m.raw = x
  1622  
  1623  	return
  1624  }
  1625  
  1626  func (m *newSessionTicketMsg) unmarshal(data []byte) bool {
  1627  	m.raw = data
  1628  
  1629  	if len(data) < 10 {
  1630  		return false
  1631  	}
  1632  
  1633  	length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3])
  1634  	if uint32(len(data))-4 != length {
  1635  		return false
  1636  	}
  1637  
  1638  	ticketLen := int(data[8])<<8 + int(data[9])
  1639  	if len(data)-10 != ticketLen {
  1640  		return false
  1641  	}
  1642  
  1643  	m.lifetimeHint = uint32(data[4])<<24 | uint32(data[5])<<16 | uint32(data[6])<<8 | uint32(data[7])
  1644  	m.ticket = data[10:]
  1645  
  1646  	return true
  1647  }
  1648  
// helloRequestMsg is the HelloRequest handshake message. Its body is empty,
// so the type carries no state.
type helloRequestMsg struct {
}
  1651  
  1652  func (*helloRequestMsg) marshal() []byte {
  1653  	return []byte{typeHelloRequest, 0, 0, 0}
  1654  }
  1655  
  1656  func (*helloRequestMsg) unmarshal(data []byte) bool {
  1657  	return len(data) == 4
  1658  }
  1659  
  1660  func eqUint16s(x, y []uint16) bool {
  1661  	if len(x) != len(y) {
  1662  		return false
  1663  	}
  1664  	for i, v := range x {
  1665  		if y[i] != v {
  1666  			return false
  1667  		}
  1668  	}
  1669  	return true
  1670  }
  1671  
  1672  func eqCurveIDs(x, y []CurveID) bool {
  1673  	if len(x) != len(y) {
  1674  		return false
  1675  	}
  1676  	for i, v := range x {
  1677  		if y[i] != v {
  1678  			return false
  1679  		}
  1680  	}
  1681  	return true
  1682  }
  1683  
  1684  func eqStrings(x, y []string) bool {
  1685  	if len(x) != len(y) {
  1686  		return false
  1687  	}
  1688  	for i, v := range x {
  1689  		if y[i] != v {
  1690  			return false
  1691  		}
  1692  	}
  1693  	return true
  1694  }
  1695  
  1696  func eqByteSlices(x, y [][]byte) bool {
  1697  	if len(x) != len(y) {
  1698  		return false
  1699  	}
  1700  	for i, v := range x {
  1701  		if !bytes.Equal(v, y[i]) {
  1702  			return false
  1703  		}
  1704  	}
  1705  	return true
  1706  }
  1707  
  1708  func eqSignatureAndHashes(x, y []SigAndHash) bool {
  1709  	if len(x) != len(y) {
  1710  		return false
  1711  	}
  1712  	for i, v := range x {
  1713  		v2 := y[i]
  1714  		if v.Hash != v2.Hash || v.Signature != v2.Signature {
  1715  			return false
  1716  		}
  1717  	}
  1718  	return true
  1719  }