github.com/hikaru7719/go@v0.0.0-20181025140707-c8b2ac68906a/src/crypto/tls/handshake_messages.go (about)

     1  // Copyright 2009 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package tls
     6  
     7  import (
     8  	"bytes"
     9  	"strings"
    10  )
    11  
// clientHelloMsg is the in-memory form of a TLS ClientHello handshake
// message. Its marshal method encodes the fields into the wire format and
// its unmarshal method fills them in from a received message.
//
// raw holds the encoded message: marshal caches its output there (and
// returns it unchanged on later calls), and unmarshal records the slice it
// parsed. The remaining fields mirror the fixed ClientHello fields (vers,
// random, sessionId, cipherSuites, compressionMethods) and the extensions
// that marshal/unmarshal understand: next protocol negotiation, server
// name, OCSP status request, SCT, supported curves and point formats,
// session ticket, signature algorithms, renegotiation info and ALPN.
type clientHelloMsg struct {
	raw                          []byte
	vers                         uint16
	random                       []byte
	sessionId                    []byte
	cipherSuites                 []uint16
	compressionMethods           []uint8
	nextProtoNeg                 bool
	serverName                   string
	ocspStapling                 bool
	scts                         bool
	supportedCurves              []CurveID
	supportedPoints              []uint8
	ticketSupported              bool
	sessionTicket                []uint8
	supportedSignatureAlgorithms []SignatureScheme
	secureRenegotiation          []byte
	secureRenegotiationSupported bool
	alpnProtocols                []string
}
    32  
    33  func (m *clientHelloMsg) equal(i interface{}) bool {
    34  	m1, ok := i.(*clientHelloMsg)
    35  	if !ok {
    36  		return false
    37  	}
    38  
    39  	return bytes.Equal(m.raw, m1.raw) &&
    40  		m.vers == m1.vers &&
    41  		bytes.Equal(m.random, m1.random) &&
    42  		bytes.Equal(m.sessionId, m1.sessionId) &&
    43  		eqUint16s(m.cipherSuites, m1.cipherSuites) &&
    44  		bytes.Equal(m.compressionMethods, m1.compressionMethods) &&
    45  		m.nextProtoNeg == m1.nextProtoNeg &&
    46  		m.serverName == m1.serverName &&
    47  		m.ocspStapling == m1.ocspStapling &&
    48  		m.scts == m1.scts &&
    49  		eqCurveIDs(m.supportedCurves, m1.supportedCurves) &&
    50  		bytes.Equal(m.supportedPoints, m1.supportedPoints) &&
    51  		m.ticketSupported == m1.ticketSupported &&
    52  		bytes.Equal(m.sessionTicket, m1.sessionTicket) &&
    53  		eqSignatureAlgorithms(m.supportedSignatureAlgorithms, m1.supportedSignatureAlgorithms) &&
    54  		m.secureRenegotiationSupported == m1.secureRenegotiationSupported &&
    55  		bytes.Equal(m.secureRenegotiation, m1.secureRenegotiation) &&
    56  		eqStrings(m.alpnProtocols, m1.alpnProtocols)
    57  }
    58  
// marshal returns the wire encoding of the ClientHello: a 4-byte handshake
// header (message type plus 24-bit body length) followed by the version,
// the 32-byte random slot, the session ID, the cipher suites, the
// compression methods and, when at least one extension is enabled, a
// 16-bit-length-prefixed extensions block.
//
// The result is cached in m.raw; if m.raw is already non-nil it is returned
// as is and the other fields are not consulted.
//
// marshal panics with "invalid ALPN protocol" if any entry of
// m.alpnProtocols is empty or longer than 255 bytes.
func (m *clientHelloMsg) marshal() []byte {
	if m.raw != nil {
		return m.raw
	}

	// First pass: work out the exact size of the message so that the
	// buffer can be allocated once. extensionsLength accumulates only the
	// extension bodies here; the 4-byte type+length header of each
	// extension is added at the end via numExtensions.
	length := 2 + 32 + 1 + len(m.sessionId) + 2 + len(m.cipherSuites)*2 + 1 + len(m.compressionMethods)
	numExtensions := 0
	extensionsLength := 0
	if m.nextProtoNeg {
		numExtensions++
	}
	if m.ocspStapling {
		extensionsLength += 1 + 2 + 2
		numExtensions++
	}
	if len(m.serverName) > 0 {
		extensionsLength += 5 + len(m.serverName)
		numExtensions++
	}
	if len(m.supportedCurves) > 0 {
		extensionsLength += 2 + 2*len(m.supportedCurves)
		numExtensions++
	}
	if len(m.supportedPoints) > 0 {
		extensionsLength += 1 + len(m.supportedPoints)
		numExtensions++
	}
	if m.ticketSupported {
		extensionsLength += len(m.sessionTicket)
		numExtensions++
	}
	if len(m.supportedSignatureAlgorithms) > 0 {
		extensionsLength += 2 + 2*len(m.supportedSignatureAlgorithms)
		numExtensions++
	}
	if m.secureRenegotiationSupported {
		extensionsLength += 1 + len(m.secureRenegotiation)
		numExtensions++
	}
	if len(m.alpnProtocols) > 0 {
		extensionsLength += 2
		for _, s := range m.alpnProtocols {
			// Each protocol name is prefixed by a single length byte,
			// so it must be 1..255 bytes long.
			if l := len(s); l == 0 || l > 255 {
				panic("invalid ALPN protocol")
			}
			extensionsLength++
			extensionsLength += len(s)
		}
		numExtensions++
	}
	if m.scts {
		numExtensions++
	}
	if numExtensions > 0 {
		// 4 header bytes per extension, plus 2 bytes for the length of
		// the extensions block itself.
		extensionsLength += 4 * numExtensions
		length += 2 + extensionsLength
	}

	// Second pass: fill in the buffer. x is the whole message; y and z are
	// moving windows into it.
	x := make([]byte, 4+length)
	x[0] = typeClientHello
	x[1] = uint8(length >> 16)
	x[2] = uint8(length >> 8)
	x[3] = uint8(length)
	x[4] = uint8(m.vers >> 8)
	x[5] = uint8(m.vers)
	copy(x[6:38], m.random)
	x[38] = uint8(len(m.sessionId))
	copy(x[39:39+len(m.sessionId)], m.sessionId)
	y := x[39+len(m.sessionId):]
	// The cipher suite list is prefixed by its length in bytes, i.e.
	// 2*len(m.cipherSuites), written big-endian; the shifts by 7 and 1
	// produce the high and low bytes of that doubled value.
	y[0] = uint8(len(m.cipherSuites) >> 7)
	y[1] = uint8(len(m.cipherSuites) << 1)
	for i, suite := range m.cipherSuites {
		y[2+i*2] = uint8(suite >> 8)
		y[3+i*2] = uint8(suite)
	}
	z := y[2+len(m.cipherSuites)*2:]
	z[0] = uint8(len(m.compressionMethods))
	copy(z[1:], m.compressionMethods)

	z = z[1+len(m.compressionMethods):]
	if numExtensions > 0 {
		z[0] = byte(extensionsLength >> 8)
		z[1] = byte(extensionsLength)
		z = z[2:]
	}
	// Each extension below is written as: 2-byte type, 2-byte body length,
	// body. The order here must consume exactly the sizes counted above.
	if m.nextProtoNeg {
		z[0] = byte(extensionNextProtoNeg >> 8)
		z[1] = byte(extensionNextProtoNeg & 0xff)
		// The length is always 0
		z = z[4:]
	}
	if len(m.serverName) > 0 {
		z[0] = byte(extensionServerName >> 8)
		z[1] = byte(extensionServerName & 0xff)
		l := len(m.serverName) + 5
		z[2] = byte(l >> 8)
		z[3] = byte(l)
		z = z[4:]

		// RFC 3546, Section 3.1
		//
		// struct {
		//     NameType name_type;
		//     select (name_type) {
		//         case host_name: HostName;
		//     } name;
		// } ServerName;
		//
		// enum {
		//     host_name(0), (255)
		// } NameType;
		//
		// opaque HostName<1..2^16-1>;
		//
		// struct {
		//     ServerName server_name_list<1..2^16-1>
		// } ServerNameList;

		// z[2], the name type, is left at its zero value (host_name).
		z[0] = byte((len(m.serverName) + 3) >> 8)
		z[1] = byte(len(m.serverName) + 3)
		z[3] = byte(len(m.serverName) >> 8)
		z[4] = byte(len(m.serverName))
		copy(z[5:], []byte(m.serverName))
		z = z[l:]
	}
	if m.ocspStapling {
		// RFC 4366, Section 3.6
		z[0] = byte(extensionStatusRequest >> 8)
		z[1] = byte(extensionStatusRequest)
		z[2] = 0
		z[3] = 5
		z[4] = 1 // OCSP type
		// Two zero valued uint16s for the two lengths.
		z = z[9:]
	}
	if len(m.supportedCurves) > 0 {
		// RFC 4492, Section 5.5.1
		z[0] = byte(extensionSupportedCurves >> 8)
		z[1] = byte(extensionSupportedCurves)
		l := 2 + 2*len(m.supportedCurves)
		z[2] = byte(l >> 8)
		z[3] = byte(l)
		l -= 2
		z[4] = byte(l >> 8)
		z[5] = byte(l)
		z = z[6:]
		for _, curve := range m.supportedCurves {
			z[0] = byte(curve >> 8)
			z[1] = byte(curve)
			z = z[2:]
		}
	}
	if len(m.supportedPoints) > 0 {
		// RFC 4492, Section 5.5.2
		z[0] = byte(extensionSupportedPoints >> 8)
		z[1] = byte(extensionSupportedPoints)
		l := 1 + len(m.supportedPoints)
		z[2] = byte(l >> 8)
		z[3] = byte(l)
		l--
		z[4] = byte(l)
		z = z[5:]
		for _, pointFormat := range m.supportedPoints {
			z[0] = pointFormat
			z = z[1:]
		}
	}
	if m.ticketSupported {
		// RFC 5077, Section 3.2
		z[0] = byte(extensionSessionTicket >> 8)
		z[1] = byte(extensionSessionTicket)
		l := len(m.sessionTicket)
		z[2] = byte(l >> 8)
		z[3] = byte(l)
		z = z[4:]
		copy(z, m.sessionTicket)
		z = z[len(m.sessionTicket):]
	}
	if len(m.supportedSignatureAlgorithms) > 0 {
		// RFC 5246, Section 7.4.1.4.1
		z[0] = byte(extensionSignatureAlgorithms >> 8)
		z[1] = byte(extensionSignatureAlgorithms)
		l := 2 + 2*len(m.supportedSignatureAlgorithms)
		z[2] = byte(l >> 8)
		z[3] = byte(l)
		z = z[4:]

		l -= 2
		z[0] = byte(l >> 8)
		z[1] = byte(l)
		z = z[2:]
		for _, sigAlgo := range m.supportedSignatureAlgorithms {
			z[0] = byte(sigAlgo >> 8)
			z[1] = byte(sigAlgo)
			z = z[2:]
		}
	}
	if m.secureRenegotiationSupported {
		z[0] = byte(extensionRenegotiationInfo >> 8)
		z[1] = byte(extensionRenegotiationInfo & 0xff)
		z[2] = 0
		z[3] = byte(len(m.secureRenegotiation) + 1)
		z[4] = byte(len(m.secureRenegotiation))
		z = z[5:]
		copy(z, m.secureRenegotiation)
		z = z[len(m.secureRenegotiation):]
	}
	if len(m.alpnProtocols) > 0 {
		z[0] = byte(extensionALPN >> 8)
		z[1] = byte(extensionALPN & 0xff)
		// lengths aliases the 4 bytes holding the extension length and
		// the protocol list length; they are back-filled once the
		// protocol names have been written.
		lengths := z[2:]
		z = z[6:]

		stringsLength := 0
		for _, s := range m.alpnProtocols {
			l := len(s)
			z[0] = byte(l)
			copy(z[1:], s)
			z = z[1+l:]
			stringsLength += 1 + l
		}

		lengths[2] = byte(stringsLength >> 8)
		lengths[3] = byte(stringsLength)
		stringsLength += 2
		lengths[0] = byte(stringsLength >> 8)
		lengths[1] = byte(stringsLength)
	}
	if m.scts {
		// RFC 6962, Section 3.3.1
		z[0] = byte(extensionSCT >> 8)
		z[1] = byte(extensionSCT)
		// zero uint16 for the zero-length extension_data
		z = z[4:]
	}

	m.raw = x

	return x
}
   299  
   300  func (m *clientHelloMsg) unmarshal(data []byte) bool {
   301  	if len(data) < 42 {
   302  		return false
   303  	}
   304  	m.raw = data
   305  	m.vers = uint16(data[4])<<8 | uint16(data[5])
   306  	m.random = data[6:38]
   307  	sessionIdLen := int(data[38])
   308  	if sessionIdLen > 32 || len(data) < 39+sessionIdLen {
   309  		return false
   310  	}
   311  	m.sessionId = data[39 : 39+sessionIdLen]
   312  	data = data[39+sessionIdLen:]
   313  	if len(data) < 2 {
   314  		return false
   315  	}
   316  	// cipherSuiteLen is the number of bytes of cipher suite numbers. Since
   317  	// they are uint16s, the number must be even.
   318  	cipherSuiteLen := int(data[0])<<8 | int(data[1])
   319  	if cipherSuiteLen%2 == 1 || len(data) < 2+cipherSuiteLen {
   320  		return false
   321  	}
   322  	numCipherSuites := cipherSuiteLen / 2
   323  	m.cipherSuites = make([]uint16, numCipherSuites)
   324  	for i := 0; i < numCipherSuites; i++ {
   325  		m.cipherSuites[i] = uint16(data[2+2*i])<<8 | uint16(data[3+2*i])
   326  		if m.cipherSuites[i] == scsvRenegotiation {
   327  			m.secureRenegotiationSupported = true
   328  		}
   329  	}
   330  	data = data[2+cipherSuiteLen:]
   331  	if len(data) < 1 {
   332  		return false
   333  	}
   334  	compressionMethodsLen := int(data[0])
   335  	if len(data) < 1+compressionMethodsLen {
   336  		return false
   337  	}
   338  	m.compressionMethods = data[1 : 1+compressionMethodsLen]
   339  
   340  	data = data[1+compressionMethodsLen:]
   341  
   342  	m.nextProtoNeg = false
   343  	m.serverName = ""
   344  	m.ocspStapling = false
   345  	m.ticketSupported = false
   346  	m.sessionTicket = nil
   347  	m.supportedSignatureAlgorithms = nil
   348  	m.alpnProtocols = nil
   349  	m.scts = false
   350  
   351  	if len(data) == 0 {
   352  		// ClientHello is optionally followed by extension data
   353  		return true
   354  	}
   355  	if len(data) < 2 {
   356  		return false
   357  	}
   358  
   359  	extensionsLength := int(data[0])<<8 | int(data[1])
   360  	data = data[2:]
   361  	if extensionsLength != len(data) {
   362  		return false
   363  	}
   364  
   365  	for len(data) != 0 {
   366  		if len(data) < 4 {
   367  			return false
   368  		}
   369  		extension := uint16(data[0])<<8 | uint16(data[1])
   370  		length := int(data[2])<<8 | int(data[3])
   371  		data = data[4:]
   372  		if len(data) < length {
   373  			return false
   374  		}
   375  
   376  		switch extension {
   377  		case extensionServerName:
   378  			d := data[:length]
   379  			if len(d) < 2 {
   380  				return false
   381  			}
   382  			namesLen := int(d[0])<<8 | int(d[1])
   383  			d = d[2:]
   384  			if len(d) != namesLen {
   385  				return false
   386  			}
   387  			for len(d) > 0 {
   388  				if len(d) < 3 {
   389  					return false
   390  				}
   391  				nameType := d[0]
   392  				nameLen := int(d[1])<<8 | int(d[2])
   393  				d = d[3:]
   394  				if len(d) < nameLen {
   395  					return false
   396  				}
   397  				if nameType == 0 {
   398  					m.serverName = string(d[:nameLen])
   399  					// An SNI value may not include a trailing dot.
   400  					// See RFC 6066, Section 3.
   401  					if strings.HasSuffix(m.serverName, ".") {
   402  						return false
   403  					}
   404  					break
   405  				}
   406  				d = d[nameLen:]
   407  			}
   408  		case extensionNextProtoNeg:
   409  			if length > 0 {
   410  				return false
   411  			}
   412  			m.nextProtoNeg = true
   413  		case extensionStatusRequest:
   414  			m.ocspStapling = length > 0 && data[0] == statusTypeOCSP
   415  		case extensionSupportedCurves:
   416  			// RFC 4492, Section 5.5.1
   417  			if length < 2 {
   418  				return false
   419  			}
   420  			l := int(data[0])<<8 | int(data[1])
   421  			if l%2 == 1 || length != l+2 {
   422  				return false
   423  			}
   424  			numCurves := l / 2
   425  			m.supportedCurves = make([]CurveID, numCurves)
   426  			d := data[2:]
   427  			for i := 0; i < numCurves; i++ {
   428  				m.supportedCurves[i] = CurveID(d[0])<<8 | CurveID(d[1])
   429  				d = d[2:]
   430  			}
   431  		case extensionSupportedPoints:
   432  			// RFC 4492, Section 5.5.2
   433  			if length < 1 {
   434  				return false
   435  			}
   436  			l := int(data[0])
   437  			if length != l+1 {
   438  				return false
   439  			}
   440  			m.supportedPoints = make([]uint8, l)
   441  			copy(m.supportedPoints, data[1:])
   442  		case extensionSessionTicket:
   443  			// RFC 5077, Section 3.2
   444  			m.ticketSupported = true
   445  			m.sessionTicket = data[:length]
   446  		case extensionSignatureAlgorithms:
   447  			// RFC 5246, Section 7.4.1.4.1
   448  			if length < 2 || length&1 != 0 {
   449  				return false
   450  			}
   451  			l := int(data[0])<<8 | int(data[1])
   452  			if l != length-2 {
   453  				return false
   454  			}
   455  			n := l / 2
   456  			d := data[2:]
   457  			m.supportedSignatureAlgorithms = make([]SignatureScheme, n)
   458  			for i := range m.supportedSignatureAlgorithms {
   459  				m.supportedSignatureAlgorithms[i] = SignatureScheme(d[0])<<8 | SignatureScheme(d[1])
   460  				d = d[2:]
   461  			}
   462  		case extensionRenegotiationInfo:
   463  			if length == 0 {
   464  				return false
   465  			}
   466  			d := data[:length]
   467  			l := int(d[0])
   468  			d = d[1:]
   469  			if l != len(d) {
   470  				return false
   471  			}
   472  
   473  			m.secureRenegotiation = d
   474  			m.secureRenegotiationSupported = true
   475  		case extensionALPN:
   476  			if length < 2 {
   477  				return false
   478  			}
   479  			l := int(data[0])<<8 | int(data[1])
   480  			if l != length-2 {
   481  				return false
   482  			}
   483  			d := data[2:length]
   484  			for len(d) != 0 {
   485  				stringLen := int(d[0])
   486  				d = d[1:]
   487  				if stringLen == 0 || stringLen > len(d) {
   488  					return false
   489  				}
   490  				m.alpnProtocols = append(m.alpnProtocols, string(d[:stringLen]))
   491  				d = d[stringLen:]
   492  			}
   493  		case extensionSCT:
   494  			m.scts = true
   495  			if length != 0 {
   496  				return false
   497  			}
   498  		}
   499  		data = data[length:]
   500  	}
   501  
   502  	return true
   503  }
   504  
// serverHelloMsg is the in-memory form of a TLS ServerHello handshake
// message. Its marshal method encodes the fields into the wire format and
// its unmarshal method fills them in from a received message.
//
// raw holds the encoded message: marshal caches its output there (and
// returns it unchanged on later calls), and unmarshal records the slice it
// parsed. The remaining fields mirror the fixed ServerHello fields (vers,
// random, sessionId, cipherSuite, compressionMethod) and the extensions
// that marshal/unmarshal understand: next protocol negotiation and its
// protocol list, OCSP status request, the SCT list, session ticket,
// renegotiation info and the single selected ALPN protocol.
type serverHelloMsg struct {
	raw                          []byte
	vers                         uint16
	random                       []byte
	sessionId                    []byte
	cipherSuite                  uint16
	compressionMethod            uint8
	nextProtoNeg                 bool
	nextProtos                   []string
	ocspStapling                 bool
	scts                         [][]byte
	ticketSupported              bool
	secureRenegotiation          []byte
	secureRenegotiationSupported bool
	alpnProtocol                 string
}
   521  
   522  func (m *serverHelloMsg) equal(i interface{}) bool {
   523  	m1, ok := i.(*serverHelloMsg)
   524  	if !ok {
   525  		return false
   526  	}
   527  
   528  	if len(m.scts) != len(m1.scts) {
   529  		return false
   530  	}
   531  	for i, sct := range m.scts {
   532  		if !bytes.Equal(sct, m1.scts[i]) {
   533  			return false
   534  		}
   535  	}
   536  
   537  	return bytes.Equal(m.raw, m1.raw) &&
   538  		m.vers == m1.vers &&
   539  		bytes.Equal(m.random, m1.random) &&
   540  		bytes.Equal(m.sessionId, m1.sessionId) &&
   541  		m.cipherSuite == m1.cipherSuite &&
   542  		m.compressionMethod == m1.compressionMethod &&
   543  		m.nextProtoNeg == m1.nextProtoNeg &&
   544  		eqStrings(m.nextProtos, m1.nextProtos) &&
   545  		m.ocspStapling == m1.ocspStapling &&
   546  		m.ticketSupported == m1.ticketSupported &&
   547  		m.secureRenegotiationSupported == m1.secureRenegotiationSupported &&
   548  		bytes.Equal(m.secureRenegotiation, m1.secureRenegotiation) &&
   549  		m.alpnProtocol == m1.alpnProtocol
   550  }
   551  
   552  func (m *serverHelloMsg) marshal() []byte {
   553  	if m.raw != nil {
   554  		return m.raw
   555  	}
   556  
   557  	length := 38 + len(m.sessionId)
   558  	numExtensions := 0
   559  	extensionsLength := 0
   560  
   561  	nextProtoLen := 0
   562  	if m.nextProtoNeg {
   563  		numExtensions++
   564  		for _, v := range m.nextProtos {
   565  			nextProtoLen += len(v)
   566  		}
   567  		nextProtoLen += len(m.nextProtos)
   568  		extensionsLength += nextProtoLen
   569  	}
   570  	if m.ocspStapling {
   571  		numExtensions++
   572  	}
   573  	if m.ticketSupported {
   574  		numExtensions++
   575  	}
   576  	if m.secureRenegotiationSupported {
   577  		extensionsLength += 1 + len(m.secureRenegotiation)
   578  		numExtensions++
   579  	}
   580  	if alpnLen := len(m.alpnProtocol); alpnLen > 0 {
   581  		if alpnLen >= 256 {
   582  			panic("invalid ALPN protocol")
   583  		}
   584  		extensionsLength += 2 + 1 + alpnLen
   585  		numExtensions++
   586  	}
   587  	sctLen := 0
   588  	if len(m.scts) > 0 {
   589  		for _, sct := range m.scts {
   590  			sctLen += len(sct) + 2
   591  		}
   592  		extensionsLength += 2 + sctLen
   593  		numExtensions++
   594  	}
   595  
   596  	if numExtensions > 0 {
   597  		extensionsLength += 4 * numExtensions
   598  		length += 2 + extensionsLength
   599  	}
   600  
   601  	x := make([]byte, 4+length)
   602  	x[0] = typeServerHello
   603  	x[1] = uint8(length >> 16)
   604  	x[2] = uint8(length >> 8)
   605  	x[3] = uint8(length)
   606  	x[4] = uint8(m.vers >> 8)
   607  	x[5] = uint8(m.vers)
   608  	copy(x[6:38], m.random)
   609  	x[38] = uint8(len(m.sessionId))
   610  	copy(x[39:39+len(m.sessionId)], m.sessionId)
   611  	z := x[39+len(m.sessionId):]
   612  	z[0] = uint8(m.cipherSuite >> 8)
   613  	z[1] = uint8(m.cipherSuite)
   614  	z[2] = m.compressionMethod
   615  
   616  	z = z[3:]
   617  	if numExtensions > 0 {
   618  		z[0] = byte(extensionsLength >> 8)
   619  		z[1] = byte(extensionsLength)
   620  		z = z[2:]
   621  	}
   622  	if m.nextProtoNeg {
   623  		z[0] = byte(extensionNextProtoNeg >> 8)
   624  		z[1] = byte(extensionNextProtoNeg & 0xff)
   625  		z[2] = byte(nextProtoLen >> 8)
   626  		z[3] = byte(nextProtoLen)
   627  		z = z[4:]
   628  
   629  		for _, v := range m.nextProtos {
   630  			l := len(v)
   631  			if l > 255 {
   632  				l = 255
   633  			}
   634  			z[0] = byte(l)
   635  			copy(z[1:], []byte(v[0:l]))
   636  			z = z[1+l:]
   637  		}
   638  	}
   639  	if m.ocspStapling {
   640  		z[0] = byte(extensionStatusRequest >> 8)
   641  		z[1] = byte(extensionStatusRequest)
   642  		z = z[4:]
   643  	}
   644  	if m.ticketSupported {
   645  		z[0] = byte(extensionSessionTicket >> 8)
   646  		z[1] = byte(extensionSessionTicket)
   647  		z = z[4:]
   648  	}
   649  	if m.secureRenegotiationSupported {
   650  		z[0] = byte(extensionRenegotiationInfo >> 8)
   651  		z[1] = byte(extensionRenegotiationInfo & 0xff)
   652  		z[2] = 0
   653  		z[3] = byte(len(m.secureRenegotiation) + 1)
   654  		z[4] = byte(len(m.secureRenegotiation))
   655  		z = z[5:]
   656  		copy(z, m.secureRenegotiation)
   657  		z = z[len(m.secureRenegotiation):]
   658  	}
   659  	if alpnLen := len(m.alpnProtocol); alpnLen > 0 {
   660  		z[0] = byte(extensionALPN >> 8)
   661  		z[1] = byte(extensionALPN & 0xff)
   662  		l := 2 + 1 + alpnLen
   663  		z[2] = byte(l >> 8)
   664  		z[3] = byte(l)
   665  		l -= 2
   666  		z[4] = byte(l >> 8)
   667  		z[5] = byte(l)
   668  		l -= 1
   669  		z[6] = byte(l)
   670  		copy(z[7:], []byte(m.alpnProtocol))
   671  		z = z[7+alpnLen:]
   672  	}
   673  	if sctLen > 0 {
   674  		z[0] = byte(extensionSCT >> 8)
   675  		z[1] = byte(extensionSCT)
   676  		l := sctLen + 2
   677  		z[2] = byte(l >> 8)
   678  		z[3] = byte(l)
   679  		z[4] = byte(sctLen >> 8)
   680  		z[5] = byte(sctLen)
   681  
   682  		z = z[6:]
   683  		for _, sct := range m.scts {
   684  			z[0] = byte(len(sct) >> 8)
   685  			z[1] = byte(len(sct))
   686  			copy(z[2:], sct)
   687  			z = z[len(sct)+2:]
   688  		}
   689  	}
   690  
   691  	m.raw = x
   692  
   693  	return x
   694  }
   695  
   696  func (m *serverHelloMsg) unmarshal(data []byte) bool {
   697  	if len(data) < 42 {
   698  		return false
   699  	}
   700  	m.raw = data
   701  	m.vers = uint16(data[4])<<8 | uint16(data[5])
   702  	m.random = data[6:38]
   703  	sessionIdLen := int(data[38])
   704  	if sessionIdLen > 32 || len(data) < 39+sessionIdLen {
   705  		return false
   706  	}
   707  	m.sessionId = data[39 : 39+sessionIdLen]
   708  	data = data[39+sessionIdLen:]
   709  	if len(data) < 3 {
   710  		return false
   711  	}
   712  	m.cipherSuite = uint16(data[0])<<8 | uint16(data[1])
   713  	m.compressionMethod = data[2]
   714  	data = data[3:]
   715  
   716  	m.nextProtoNeg = false
   717  	m.nextProtos = nil
   718  	m.ocspStapling = false
   719  	m.scts = nil
   720  	m.ticketSupported = false
   721  	m.alpnProtocol = ""
   722  
   723  	if len(data) == 0 {
   724  		// ServerHello is optionally followed by extension data
   725  		return true
   726  	}
   727  	if len(data) < 2 {
   728  		return false
   729  	}
   730  
   731  	extensionsLength := int(data[0])<<8 | int(data[1])
   732  	data = data[2:]
   733  	if len(data) != extensionsLength {
   734  		return false
   735  	}
   736  
   737  	for len(data) != 0 {
   738  		if len(data) < 4 {
   739  			return false
   740  		}
   741  		extension := uint16(data[0])<<8 | uint16(data[1])
   742  		length := int(data[2])<<8 | int(data[3])
   743  		data = data[4:]
   744  		if len(data) < length {
   745  			return false
   746  		}
   747  
   748  		switch extension {
   749  		case extensionNextProtoNeg:
   750  			m.nextProtoNeg = true
   751  			d := data[:length]
   752  			for len(d) > 0 {
   753  				l := int(d[0])
   754  				d = d[1:]
   755  				if l == 0 || l > len(d) {
   756  					return false
   757  				}
   758  				m.nextProtos = append(m.nextProtos, string(d[:l]))
   759  				d = d[l:]
   760  			}
   761  		case extensionStatusRequest:
   762  			if length > 0 {
   763  				return false
   764  			}
   765  			m.ocspStapling = true
   766  		case extensionSessionTicket:
   767  			if length > 0 {
   768  				return false
   769  			}
   770  			m.ticketSupported = true
   771  		case extensionRenegotiationInfo:
   772  			if length == 0 {
   773  				return false
   774  			}
   775  			d := data[:length]
   776  			l := int(d[0])
   777  			d = d[1:]
   778  			if l != len(d) {
   779  				return false
   780  			}
   781  
   782  			m.secureRenegotiation = d
   783  			m.secureRenegotiationSupported = true
   784  		case extensionALPN:
   785  			d := data[:length]
   786  			if len(d) < 3 {
   787  				return false
   788  			}
   789  			l := int(d[0])<<8 | int(d[1])
   790  			if l != len(d)-2 {
   791  				return false
   792  			}
   793  			d = d[2:]
   794  			l = int(d[0])
   795  			if l != len(d)-1 {
   796  				return false
   797  			}
   798  			d = d[1:]
   799  			if len(d) == 0 {
   800  				// ALPN protocols must not be empty.
   801  				return false
   802  			}
   803  			m.alpnProtocol = string(d)
   804  		case extensionSCT:
   805  			d := data[:length]
   806  
   807  			if len(d) < 2 {
   808  				return false
   809  			}
   810  			l := int(d[0])<<8 | int(d[1])
   811  			d = d[2:]
   812  			if len(d) != l || l == 0 {
   813  				return false
   814  			}
   815  
   816  			m.scts = make([][]byte, 0, 3)
   817  			for len(d) != 0 {
   818  				if len(d) < 2 {
   819  					return false
   820  				}
   821  				sctLen := int(d[0])<<8 | int(d[1])
   822  				d = d[2:]
   823  				if sctLen == 0 || len(d) < sctLen {
   824  					return false
   825  				}
   826  				m.scts = append(m.scts, d[:sctLen])
   827  				d = d[sctLen:]
   828  			}
   829  		}
   830  		data = data[length:]
   831  	}
   832  
   833  	return true
   834  }
   835  
// certificateMsg is the in-memory form of a TLS Certificate handshake
// message. raw holds the encoded message (cached by marshal, recorded by
// unmarshal) and certificates holds one byte slice per certificate entry,
// in message order.
type certificateMsg struct {
	raw          []byte
	certificates [][]byte
}
   840  
   841  func (m *certificateMsg) equal(i interface{}) bool {
   842  	m1, ok := i.(*certificateMsg)
   843  	if !ok {
   844  		return false
   845  	}
   846  
   847  	return bytes.Equal(m.raw, m1.raw) &&
   848  		eqByteSlices(m.certificates, m1.certificates)
   849  }
   850  
   851  func (m *certificateMsg) marshal() (x []byte) {
   852  	if m.raw != nil {
   853  		return m.raw
   854  	}
   855  
   856  	var i int
   857  	for _, slice := range m.certificates {
   858  		i += len(slice)
   859  	}
   860  
   861  	length := 3 + 3*len(m.certificates) + i
   862  	x = make([]byte, 4+length)
   863  	x[0] = typeCertificate
   864  	x[1] = uint8(length >> 16)
   865  	x[2] = uint8(length >> 8)
   866  	x[3] = uint8(length)
   867  
   868  	certificateOctets := length - 3
   869  	x[4] = uint8(certificateOctets >> 16)
   870  	x[5] = uint8(certificateOctets >> 8)
   871  	x[6] = uint8(certificateOctets)
   872  
   873  	y := x[7:]
   874  	for _, slice := range m.certificates {
   875  		y[0] = uint8(len(slice) >> 16)
   876  		y[1] = uint8(len(slice) >> 8)
   877  		y[2] = uint8(len(slice))
   878  		copy(y[3:], slice)
   879  		y = y[3+len(slice):]
   880  	}
   881  
   882  	m.raw = x
   883  	return
   884  }
   885  
   886  func (m *certificateMsg) unmarshal(data []byte) bool {
   887  	if len(data) < 7 {
   888  		return false
   889  	}
   890  
   891  	m.raw = data
   892  	certsLen := uint32(data[4])<<16 | uint32(data[5])<<8 | uint32(data[6])
   893  	if uint32(len(data)) != certsLen+7 {
   894  		return false
   895  	}
   896  
   897  	numCerts := 0
   898  	d := data[7:]
   899  	for certsLen > 0 {
   900  		if len(d) < 4 {
   901  			return false
   902  		}
   903  		certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2])
   904  		if uint32(len(d)) < 3+certLen {
   905  			return false
   906  		}
   907  		d = d[3+certLen:]
   908  		certsLen -= 3 + certLen
   909  		numCerts++
   910  	}
   911  
   912  	m.certificates = make([][]byte, numCerts)
   913  	d = data[7:]
   914  	for i := 0; i < numCerts; i++ {
   915  		certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2])
   916  		m.certificates[i] = d[3 : 3+certLen]
   917  		d = d[3+certLen:]
   918  	}
   919  
   920  	return true
   921  }
   922  
// serverKeyExchangeMsg is the in-memory form of a TLS ServerKeyExchange
// handshake message. raw holds the encoded message (cached by marshal,
// recorded by unmarshal) and key holds the message body that follows the
// 4-byte handshake header, uninterpreted.
type serverKeyExchangeMsg struct {
	raw []byte
	key []byte
}
   927  
   928  func (m *serverKeyExchangeMsg) equal(i interface{}) bool {
   929  	m1, ok := i.(*serverKeyExchangeMsg)
   930  	if !ok {
   931  		return false
   932  	}
   933  
   934  	return bytes.Equal(m.raw, m1.raw) &&
   935  		bytes.Equal(m.key, m1.key)
   936  }
   937  
   938  func (m *serverKeyExchangeMsg) marshal() []byte {
   939  	if m.raw != nil {
   940  		return m.raw
   941  	}
   942  	length := len(m.key)
   943  	x := make([]byte, length+4)
   944  	x[0] = typeServerKeyExchange
   945  	x[1] = uint8(length >> 16)
   946  	x[2] = uint8(length >> 8)
   947  	x[3] = uint8(length)
   948  	copy(x[4:], m.key)
   949  
   950  	m.raw = x
   951  	return x
   952  }
   953  
   954  func (m *serverKeyExchangeMsg) unmarshal(data []byte) bool {
   955  	m.raw = data
   956  	if len(data) < 4 {
   957  		return false
   958  	}
   959  	m.key = data[4:]
   960  	return true
   961  }
   962  
// certificateStatusMsg is the in-memory form of a TLS CertificateStatus
// handshake message. raw holds the encoded message (cached by marshal,
// recorded by unmarshal), statusType is the status type byte, and response
// is the response body, which marshal and unmarshal only handle when
// statusType is statusTypeOCSP.
type certificateStatusMsg struct {
	raw        []byte
	statusType uint8
	response   []byte
}
   968  
   969  func (m *certificateStatusMsg) equal(i interface{}) bool {
   970  	m1, ok := i.(*certificateStatusMsg)
   971  	if !ok {
   972  		return false
   973  	}
   974  
   975  	return bytes.Equal(m.raw, m1.raw) &&
   976  		m.statusType == m1.statusType &&
   977  		bytes.Equal(m.response, m1.response)
   978  }
   979  
   980  func (m *certificateStatusMsg) marshal() []byte {
   981  	if m.raw != nil {
   982  		return m.raw
   983  	}
   984  
   985  	var x []byte
   986  	if m.statusType == statusTypeOCSP {
   987  		x = make([]byte, 4+4+len(m.response))
   988  		x[0] = typeCertificateStatus
   989  		l := len(m.response) + 4
   990  		x[1] = byte(l >> 16)
   991  		x[2] = byte(l >> 8)
   992  		x[3] = byte(l)
   993  		x[4] = statusTypeOCSP
   994  
   995  		l -= 4
   996  		x[5] = byte(l >> 16)
   997  		x[6] = byte(l >> 8)
   998  		x[7] = byte(l)
   999  		copy(x[8:], m.response)
  1000  	} else {
  1001  		x = []byte{typeCertificateStatus, 0, 0, 1, m.statusType}
  1002  	}
  1003  
  1004  	m.raw = x
  1005  	return x
  1006  }
  1007  
  1008  func (m *certificateStatusMsg) unmarshal(data []byte) bool {
  1009  	m.raw = data
  1010  	if len(data) < 5 {
  1011  		return false
  1012  	}
  1013  	m.statusType = data[4]
  1014  
  1015  	m.response = nil
  1016  	if m.statusType == statusTypeOCSP {
  1017  		if len(data) < 8 {
  1018  			return false
  1019  		}
  1020  		respLen := uint32(data[5])<<16 | uint32(data[6])<<8 | uint32(data[7])
  1021  		if uint32(len(data)) != 4+4+respLen {
  1022  			return false
  1023  		}
  1024  		m.response = data[8:]
  1025  	}
  1026  	return true
  1027  }
  1028  
// serverHelloDoneMsg represents a TLS ServerHelloDone handshake message. It
// has no fields: the encoded message is the 4-byte handshake header only.
type serverHelloDoneMsg struct{}
  1030  
  1031  func (m *serverHelloDoneMsg) equal(i interface{}) bool {
  1032  	_, ok := i.(*serverHelloDoneMsg)
  1033  	return ok
  1034  }
  1035  
  1036  func (m *serverHelloDoneMsg) marshal() []byte {
  1037  	x := make([]byte, 4)
  1038  	x[0] = typeServerHelloDone
  1039  	return x
  1040  }
  1041  
  1042  func (m *serverHelloDoneMsg) unmarshal(data []byte) bool {
  1043  	return len(data) == 4
  1044  }
  1045  
// clientKeyExchangeMsg is the in-memory form of a TLS ClientKeyExchange
// handshake message. raw holds the encoded message (cached by marshal,
// recorded by unmarshal) and ciphertext holds the message body that follows
// the 4-byte handshake header, uninterpreted.
type clientKeyExchangeMsg struct {
	raw        []byte
	ciphertext []byte
}
  1050  
  1051  func (m *clientKeyExchangeMsg) equal(i interface{}) bool {
  1052  	m1, ok := i.(*clientKeyExchangeMsg)
  1053  	if !ok {
  1054  		return false
  1055  	}
  1056  
  1057  	return bytes.Equal(m.raw, m1.raw) &&
  1058  		bytes.Equal(m.ciphertext, m1.ciphertext)
  1059  }
  1060  
  1061  func (m *clientKeyExchangeMsg) marshal() []byte {
  1062  	if m.raw != nil {
  1063  		return m.raw
  1064  	}
  1065  	length := len(m.ciphertext)
  1066  	x := make([]byte, length+4)
  1067  	x[0] = typeClientKeyExchange
  1068  	x[1] = uint8(length >> 16)
  1069  	x[2] = uint8(length >> 8)
  1070  	x[3] = uint8(length)
  1071  	copy(x[4:], m.ciphertext)
  1072  
  1073  	m.raw = x
  1074  	return x
  1075  }
  1076  
  1077  func (m *clientKeyExchangeMsg) unmarshal(data []byte) bool {
  1078  	m.raw = data
  1079  	if len(data) < 4 {
  1080  		return false
  1081  	}
  1082  	l := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
  1083  	if l != len(data)-4 {
  1084  		return false
  1085  	}
  1086  	m.ciphertext = data[4:]
  1087  	return true
  1088  }
  1089  
  1090  type finishedMsg struct {
  1091  	raw        []byte
  1092  	verifyData []byte
  1093  }
  1094  
  1095  func (m *finishedMsg) equal(i interface{}) bool {
  1096  	m1, ok := i.(*finishedMsg)
  1097  	if !ok {
  1098  		return false
  1099  	}
  1100  
  1101  	return bytes.Equal(m.raw, m1.raw) &&
  1102  		bytes.Equal(m.verifyData, m1.verifyData)
  1103  }
  1104  
  1105  func (m *finishedMsg) marshal() (x []byte) {
  1106  	if m.raw != nil {
  1107  		return m.raw
  1108  	}
  1109  
  1110  	x = make([]byte, 4+len(m.verifyData))
  1111  	x[0] = typeFinished
  1112  	x[3] = byte(len(m.verifyData))
  1113  	copy(x[4:], m.verifyData)
  1114  	m.raw = x
  1115  	return
  1116  }
  1117  
  1118  func (m *finishedMsg) unmarshal(data []byte) bool {
  1119  	m.raw = data
  1120  	if len(data) < 4 {
  1121  		return false
  1122  	}
  1123  	m.verifyData = data[4:]
  1124  	return true
  1125  }
  1126  
// nextProtoMsg holds a next-protocol handshake message.
type nextProtoMsg struct {
	// raw is the complete wire encoding, cached by marshal or stored by
	// unmarshal.
	raw   []byte
	// proto is the protocol name carried in the message. marshal encodes at
	// most its first 255 bytes.
	proto string
}
  1131  
  1132  func (m *nextProtoMsg) equal(i interface{}) bool {
  1133  	m1, ok := i.(*nextProtoMsg)
  1134  	if !ok {
  1135  		return false
  1136  	}
  1137  
  1138  	return bytes.Equal(m.raw, m1.raw) &&
  1139  		m.proto == m1.proto
  1140  }
  1141  
  1142  func (m *nextProtoMsg) marshal() []byte {
  1143  	if m.raw != nil {
  1144  		return m.raw
  1145  	}
  1146  	l := len(m.proto)
  1147  	if l > 255 {
  1148  		l = 255
  1149  	}
  1150  
  1151  	padding := 32 - (l+2)%32
  1152  	length := l + padding + 2
  1153  	x := make([]byte, length+4)
  1154  	x[0] = typeNextProtocol
  1155  	x[1] = uint8(length >> 16)
  1156  	x[2] = uint8(length >> 8)
  1157  	x[3] = uint8(length)
  1158  
  1159  	y := x[4:]
  1160  	y[0] = byte(l)
  1161  	copy(y[1:], []byte(m.proto[0:l]))
  1162  	y = y[1+l:]
  1163  	y[0] = byte(padding)
  1164  
  1165  	m.raw = x
  1166  
  1167  	return x
  1168  }
  1169  
  1170  func (m *nextProtoMsg) unmarshal(data []byte) bool {
  1171  	m.raw = data
  1172  
  1173  	if len(data) < 5 {
  1174  		return false
  1175  	}
  1176  	data = data[4:]
  1177  	protoLen := int(data[0])
  1178  	data = data[1:]
  1179  	if len(data) < protoLen {
  1180  		return false
  1181  	}
  1182  	m.proto = string(data[0:protoLen])
  1183  	data = data[protoLen:]
  1184  
  1185  	if len(data) < 1 {
  1186  		return false
  1187  	}
  1188  	paddingLen := int(data[0])
  1189  	data = data[1:]
  1190  	if len(data) != paddingLen {
  1191  		return false
  1192  	}
  1193  
  1194  	return true
  1195  }
  1196  
// certificateRequestMsg holds a certificate request handshake message.
type certificateRequestMsg struct {
	// raw is the complete wire encoding, cached by marshal or stored by
	// unmarshal.
	raw []byte
	// hasSignatureAndHash indicates whether this message includes a list
	// of signature and hash functions. This change was introduced with TLS
	// 1.2. unmarshal reads this flag rather than setting it, so it must be
	// set before unmarshal is called.
	hasSignatureAndHash bool

	// certificateTypes is encoded as a list with a one-byte length prefix;
	// unmarshal rejects an empty list.
	certificateTypes             []byte
	// supportedSignatureAlgorithms is only encoded or parsed when
	// hasSignatureAndHash is set.
	supportedSignatureAlgorithms []SignatureScheme
	// certificateAuthorities holds the opaque entries of the authorities
	// list, each encoded with a two-byte length prefix.
	certificateAuthorities       [][]byte
}
  1208  
  1209  func (m *certificateRequestMsg) equal(i interface{}) bool {
  1210  	m1, ok := i.(*certificateRequestMsg)
  1211  	if !ok {
  1212  		return false
  1213  	}
  1214  
  1215  	return bytes.Equal(m.raw, m1.raw) &&
  1216  		bytes.Equal(m.certificateTypes, m1.certificateTypes) &&
  1217  		eqByteSlices(m.certificateAuthorities, m1.certificateAuthorities) &&
  1218  		eqSignatureAlgorithms(m.supportedSignatureAlgorithms, m1.supportedSignatureAlgorithms)
  1219  }
  1220  
  1221  func (m *certificateRequestMsg) marshal() (x []byte) {
  1222  	if m.raw != nil {
  1223  		return m.raw
  1224  	}
  1225  
  1226  	// See RFC 4346, Section 7.4.4.
  1227  	length := 1 + len(m.certificateTypes) + 2
  1228  	casLength := 0
  1229  	for _, ca := range m.certificateAuthorities {
  1230  		casLength += 2 + len(ca)
  1231  	}
  1232  	length += casLength
  1233  
  1234  	if m.hasSignatureAndHash {
  1235  		length += 2 + 2*len(m.supportedSignatureAlgorithms)
  1236  	}
  1237  
  1238  	x = make([]byte, 4+length)
  1239  	x[0] = typeCertificateRequest
  1240  	x[1] = uint8(length >> 16)
  1241  	x[2] = uint8(length >> 8)
  1242  	x[3] = uint8(length)
  1243  
  1244  	x[4] = uint8(len(m.certificateTypes))
  1245  
  1246  	copy(x[5:], m.certificateTypes)
  1247  	y := x[5+len(m.certificateTypes):]
  1248  
  1249  	if m.hasSignatureAndHash {
  1250  		n := len(m.supportedSignatureAlgorithms) * 2
  1251  		y[0] = uint8(n >> 8)
  1252  		y[1] = uint8(n)
  1253  		y = y[2:]
  1254  		for _, sigAlgo := range m.supportedSignatureAlgorithms {
  1255  			y[0] = uint8(sigAlgo >> 8)
  1256  			y[1] = uint8(sigAlgo)
  1257  			y = y[2:]
  1258  		}
  1259  	}
  1260  
  1261  	y[0] = uint8(casLength >> 8)
  1262  	y[1] = uint8(casLength)
  1263  	y = y[2:]
  1264  	for _, ca := range m.certificateAuthorities {
  1265  		y[0] = uint8(len(ca) >> 8)
  1266  		y[1] = uint8(len(ca))
  1267  		y = y[2:]
  1268  		copy(y, ca)
  1269  		y = y[len(ca):]
  1270  	}
  1271  
  1272  	m.raw = x
  1273  	return
  1274  }
  1275  
  1276  func (m *certificateRequestMsg) unmarshal(data []byte) bool {
  1277  	m.raw = data
  1278  
  1279  	if len(data) < 5 {
  1280  		return false
  1281  	}
  1282  
  1283  	length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3])
  1284  	if uint32(len(data))-4 != length {
  1285  		return false
  1286  	}
  1287  
  1288  	numCertTypes := int(data[4])
  1289  	data = data[5:]
  1290  	if numCertTypes == 0 || len(data) <= numCertTypes {
  1291  		return false
  1292  	}
  1293  
  1294  	m.certificateTypes = make([]byte, numCertTypes)
  1295  	if copy(m.certificateTypes, data) != numCertTypes {
  1296  		return false
  1297  	}
  1298  
  1299  	data = data[numCertTypes:]
  1300  
  1301  	if m.hasSignatureAndHash {
  1302  		if len(data) < 2 {
  1303  			return false
  1304  		}
  1305  		sigAndHashLen := uint16(data[0])<<8 | uint16(data[1])
  1306  		data = data[2:]
  1307  		if sigAndHashLen&1 != 0 {
  1308  			return false
  1309  		}
  1310  		if len(data) < int(sigAndHashLen) {
  1311  			return false
  1312  		}
  1313  		numSigAlgos := sigAndHashLen / 2
  1314  		m.supportedSignatureAlgorithms = make([]SignatureScheme, numSigAlgos)
  1315  		for i := range m.supportedSignatureAlgorithms {
  1316  			m.supportedSignatureAlgorithms[i] = SignatureScheme(data[0])<<8 | SignatureScheme(data[1])
  1317  			data = data[2:]
  1318  		}
  1319  	}
  1320  
  1321  	if len(data) < 2 {
  1322  		return false
  1323  	}
  1324  	casLength := uint16(data[0])<<8 | uint16(data[1])
  1325  	data = data[2:]
  1326  	if len(data) < int(casLength) {
  1327  		return false
  1328  	}
  1329  	cas := make([]byte, casLength)
  1330  	copy(cas, data)
  1331  	data = data[casLength:]
  1332  
  1333  	m.certificateAuthorities = nil
  1334  	for len(cas) > 0 {
  1335  		if len(cas) < 2 {
  1336  			return false
  1337  		}
  1338  		caLen := uint16(cas[0])<<8 | uint16(cas[1])
  1339  		cas = cas[2:]
  1340  
  1341  		if len(cas) < int(caLen) {
  1342  			return false
  1343  		}
  1344  
  1345  		m.certificateAuthorities = append(m.certificateAuthorities, cas[:caLen])
  1346  		cas = cas[caLen:]
  1347  	}
  1348  
  1349  	return len(data) == 0
  1350  }
  1351  
// certificateVerifyMsg holds a certificate verify handshake message.
type certificateVerifyMsg struct {
	// raw is the complete wire encoding, cached by marshal or stored by
	// unmarshal.
	raw                 []byte
	// hasSignatureAndHash selects the encoding: when set, a two-byte
	// signature algorithm precedes the signature. unmarshal reads this flag
	// rather than setting it.
	hasSignatureAndHash bool
	// signatureAlgorithm is only encoded or parsed when hasSignatureAndHash
	// is set.
	signatureAlgorithm  SignatureScheme
	// signature is encoded with a two-byte length prefix.
	signature           []byte
}
  1358  
  1359  func (m *certificateVerifyMsg) equal(i interface{}) bool {
  1360  	m1, ok := i.(*certificateVerifyMsg)
  1361  	if !ok {
  1362  		return false
  1363  	}
  1364  
  1365  	return bytes.Equal(m.raw, m1.raw) &&
  1366  		m.hasSignatureAndHash == m1.hasSignatureAndHash &&
  1367  		m.signatureAlgorithm == m1.signatureAlgorithm &&
  1368  		bytes.Equal(m.signature, m1.signature)
  1369  }
  1370  
  1371  func (m *certificateVerifyMsg) marshal() (x []byte) {
  1372  	if m.raw != nil {
  1373  		return m.raw
  1374  	}
  1375  
  1376  	// See RFC 4346, Section 7.4.8.
  1377  	siglength := len(m.signature)
  1378  	length := 2 + siglength
  1379  	if m.hasSignatureAndHash {
  1380  		length += 2
  1381  	}
  1382  	x = make([]byte, 4+length)
  1383  	x[0] = typeCertificateVerify
  1384  	x[1] = uint8(length >> 16)
  1385  	x[2] = uint8(length >> 8)
  1386  	x[3] = uint8(length)
  1387  	y := x[4:]
  1388  	if m.hasSignatureAndHash {
  1389  		y[0] = uint8(m.signatureAlgorithm >> 8)
  1390  		y[1] = uint8(m.signatureAlgorithm)
  1391  		y = y[2:]
  1392  	}
  1393  	y[0] = uint8(siglength >> 8)
  1394  	y[1] = uint8(siglength)
  1395  	copy(y[2:], m.signature)
  1396  
  1397  	m.raw = x
  1398  
  1399  	return
  1400  }
  1401  
  1402  func (m *certificateVerifyMsg) unmarshal(data []byte) bool {
  1403  	m.raw = data
  1404  
  1405  	if len(data) < 6 {
  1406  		return false
  1407  	}
  1408  
  1409  	length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3])
  1410  	if uint32(len(data))-4 != length {
  1411  		return false
  1412  	}
  1413  
  1414  	data = data[4:]
  1415  	if m.hasSignatureAndHash {
  1416  		m.signatureAlgorithm = SignatureScheme(data[0])<<8 | SignatureScheme(data[1])
  1417  		data = data[2:]
  1418  	}
  1419  
  1420  	if len(data) < 2 {
  1421  		return false
  1422  	}
  1423  	siglength := int(data[0])<<8 + int(data[1])
  1424  	data = data[2:]
  1425  	if len(data) != siglength {
  1426  		return false
  1427  	}
  1428  
  1429  	m.signature = data
  1430  
  1431  	return true
  1432  }
  1433  
  1434  type newSessionTicketMsg struct {
  1435  	raw    []byte
  1436  	ticket []byte
  1437  }
  1438  
  1439  func (m *newSessionTicketMsg) equal(i interface{}) bool {
  1440  	m1, ok := i.(*newSessionTicketMsg)
  1441  	if !ok {
  1442  		return false
  1443  	}
  1444  
  1445  	return bytes.Equal(m.raw, m1.raw) &&
  1446  		bytes.Equal(m.ticket, m1.ticket)
  1447  }
  1448  
  1449  func (m *newSessionTicketMsg) marshal() (x []byte) {
  1450  	if m.raw != nil {
  1451  		return m.raw
  1452  	}
  1453  
  1454  	// See RFC 5077, Section 3.3.
  1455  	ticketLen := len(m.ticket)
  1456  	length := 2 + 4 + ticketLen
  1457  	x = make([]byte, 4+length)
  1458  	x[0] = typeNewSessionTicket
  1459  	x[1] = uint8(length >> 16)
  1460  	x[2] = uint8(length >> 8)
  1461  	x[3] = uint8(length)
  1462  	x[8] = uint8(ticketLen >> 8)
  1463  	x[9] = uint8(ticketLen)
  1464  	copy(x[10:], m.ticket)
  1465  
  1466  	m.raw = x
  1467  
  1468  	return
  1469  }
  1470  
  1471  func (m *newSessionTicketMsg) unmarshal(data []byte) bool {
  1472  	m.raw = data
  1473  
  1474  	if len(data) < 10 {
  1475  		return false
  1476  	}
  1477  
  1478  	length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3])
  1479  	if uint32(len(data))-4 != length {
  1480  		return false
  1481  	}
  1482  
  1483  	ticketLen := int(data[8])<<8 + int(data[9])
  1484  	if len(data)-10 != ticketLen {
  1485  		return false
  1486  	}
  1487  
  1488  	m.ticket = data[10:]
  1489  
  1490  	return true
  1491  }
  1492  
// helloRequestMsg is a handshake message with no fields: its marshal method
// emits a 4-byte header only, and its unmarshal method accepts exactly 4 bytes.
type helloRequestMsg struct {
}
  1495  
  1496  func (*helloRequestMsg) marshal() []byte {
  1497  	return []byte{typeHelloRequest, 0, 0, 0}
  1498  }
  1499  
  1500  func (*helloRequestMsg) unmarshal(data []byte) bool {
  1501  	return len(data) == 4
  1502  }
  1503  
  1504  func eqUint16s(x, y []uint16) bool {
  1505  	if len(x) != len(y) {
  1506  		return false
  1507  	}
  1508  	for i, v := range x {
  1509  		if y[i] != v {
  1510  			return false
  1511  		}
  1512  	}
  1513  	return true
  1514  }
  1515  
  1516  func eqCurveIDs(x, y []CurveID) bool {
  1517  	if len(x) != len(y) {
  1518  		return false
  1519  	}
  1520  	for i, v := range x {
  1521  		if y[i] != v {
  1522  			return false
  1523  		}
  1524  	}
  1525  	return true
  1526  }
  1527  
  1528  func eqStrings(x, y []string) bool {
  1529  	if len(x) != len(y) {
  1530  		return false
  1531  	}
  1532  	for i, v := range x {
  1533  		if y[i] != v {
  1534  			return false
  1535  		}
  1536  	}
  1537  	return true
  1538  }
  1539  
  1540  func eqByteSlices(x, y [][]byte) bool {
  1541  	if len(x) != len(y) {
  1542  		return false
  1543  	}
  1544  	for i, v := range x {
  1545  		if !bytes.Equal(v, y[i]) {
  1546  			return false
  1547  		}
  1548  	}
  1549  	return true
  1550  }
  1551  
  1552  func eqSignatureAlgorithms(x, y []SignatureScheme) bool {
  1553  	if len(x) != len(y) {
  1554  		return false
  1555  	}
  1556  	for i, v := range x {
  1557  		if v != y[i] {
  1558  			return false
  1559  		}
  1560  	}
  1561  	return true
  1562  }