github.com/Hyperledger-TWGC/tjfoc-gm@v1.4.0/gmtls/handshake_messages.go (about)

     1  /*
     2  Copyright Suzhou Tongji Fintech Research Institute 2017 All Rights Reserved.
     3  Licensed under the Apache License, Version 2.0 (the "License");
     4  you may not use this file except in compliance with the License.
     5  You may obtain a copy of the License at
     6  
     7  	http://www.apache.org/licenses/LICENSE-2.0
     8  
     9  Unless required by applicable law or agreed to in writing, software
    10  distributed under the License is distributed on an "AS IS" BASIS,
    11  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  See the License for the specific language governing permissions and
    13  limitations under the License.
    14  */
    15  
    16  package gmtls
    17  
    18  import "bytes"
    19  
    20  type clientHelloMsg struct {
    21  	raw                          []byte
    22  	vers                         uint16
    23  	random                       []byte
    24  	sessionId                    []byte
    25  	cipherSuites                 []uint16
    26  	compressionMethods           []uint8
    27  	nextProtoNeg                 bool
    28  	serverName                   string
    29  	ocspStapling                 bool
    30  	scts                         bool
    31  	supportedCurves              []CurveID
    32  	supportedPoints              []uint8
    33  	ticketSupported              bool
    34  	sessionTicket                []uint8
    35  	supportedSignatureAlgorithms []SignatureScheme
    36  	secureRenegotiation          []byte
    37  	secureRenegotiationSupported bool
    38  	alpnProtocols                []string
    39  }
    40  
    41  func (m *clientHelloMsg) equal(i interface{}) bool {
    42  	m1, ok := i.(*clientHelloMsg)
    43  	if !ok {
    44  		return false
    45  	}
    46  
    47  	return bytes.Equal(m.raw, m1.raw) &&
    48  		m.vers == m1.vers &&
    49  		bytes.Equal(m.random, m1.random) &&
    50  		bytes.Equal(m.sessionId, m1.sessionId) &&
    51  		eqUint16s(m.cipherSuites, m1.cipherSuites) &&
    52  		bytes.Equal(m.compressionMethods, m1.compressionMethods) &&
    53  		m.nextProtoNeg == m1.nextProtoNeg &&
    54  		m.serverName == m1.serverName &&
    55  		m.ocspStapling == m1.ocspStapling &&
    56  		m.scts == m1.scts &&
    57  		eqCurveIDs(m.supportedCurves, m1.supportedCurves) &&
    58  		bytes.Equal(m.supportedPoints, m1.supportedPoints) &&
    59  		m.ticketSupported == m1.ticketSupported &&
    60  		bytes.Equal(m.sessionTicket, m1.sessionTicket) &&
    61  		eqSignatureAlgorithms(m.supportedSignatureAlgorithms, m1.supportedSignatureAlgorithms) &&
    62  		m.secureRenegotiationSupported == m1.secureRenegotiationSupported &&
    63  		bytes.Equal(m.secureRenegotiation, m1.secureRenegotiation) &&
    64  		eqStrings(m.alpnProtocols, m1.alpnProtocols)
    65  }
    66  
// marshal returns the wire encoding of the ClientHello message and caches it
// in m.raw. The encoding includes the 4-byte handshake header (message type
// plus 24-bit body length). If m.raw is already set it is returned unchanged,
// so fields modified after an earlier marshal/unmarshal are not reflected
// until raw is cleared.
//
// The encoding is built in two passes. The first pass computes the body length
// and counts the extensions. The second writes every byte into a buffer of
// exactly that size. The write order of extensions is independent of the
// order in which their sizes are summed.
//
// It panics if any ALPN protocol name is empty or longer than 255 bytes.
//
// NOTE(review): no other length is range-checked. For example,
// secureRenegotiation, supportedPoints and compressionMethods have 1-byte
// length fields, and serverName and the extensions block have 2-byte fields;
// all of them are silently truncated by the byte conversions below if too long.
func (m *clientHelloMsg) marshal() []byte {
	if m.raw != nil {
		return m.raw
	}

	// Fixed body: version (2), random (32), session ID length (1) + ID,
	// cipher-suite vector length (2) + 2 bytes per suite, compression-method
	// count (1) + methods.
	length := 2 + 32 + 1 + len(m.sessionId) + 2 + len(m.cipherSuites)*2 + 1 + len(m.compressionMethods)
	numExtensions := 0
	// extensionsLength sums only extension_data sizes in the checks below; the
	// 4-byte type+length header of each counted extension is added afterwards.
	extensionsLength := 0
	if m.nextProtoNeg {
		numExtensions++
	}
	if m.ocspStapling {
		// status_type (1) followed by two empty 2-byte length fields.
		extensionsLength += 1 + 2 + 2
		numExtensions++
	}
	if len(m.serverName) > 0 {
		// list length (2) + name type (1) + name length (2) + name.
		extensionsLength += 5 + len(m.serverName)
		numExtensions++
	}
	if len(m.supportedCurves) > 0 {
		// vector length (2) + 2 bytes per curve.
		extensionsLength += 2 + 2*len(m.supportedCurves)
		numExtensions++
	}
	if len(m.supportedPoints) > 0 {
		// vector length (1) + 1 byte per point format.
		extensionsLength += 1 + len(m.supportedPoints)
		numExtensions++
	}
	if m.ticketSupported {
		// The ticket bytes are the whole extension_data.
		extensionsLength += len(m.sessionTicket)
		numExtensions++
	}
	if len(m.supportedSignatureAlgorithms) > 0 {
		// vector length (2) + 2 bytes per scheme.
		extensionsLength += 2 + 2*len(m.supportedSignatureAlgorithms)
		numExtensions++
	}
	if m.secureRenegotiationSupported {
		// renegotiated_connection length (1) + payload.
		extensionsLength += 1 + len(m.secureRenegotiation)
		numExtensions++
	}
	if len(m.alpnProtocols) > 0 {
		// protocol list length (2) + a 1-byte length prefix per protocol name.
		extensionsLength += 2
		for _, s := range m.alpnProtocols {
			if l := len(s); l == 0 || l > 255 {
				panic("invalid ALPN protocol")
			}
			extensionsLength++
			extensionsLength += len(s)
		}
		numExtensions++
	}
	if m.scts {
		numExtensions++
	}
	if numExtensions > 0 {
		// Add the per-extension headers and the 2-byte extensions block length.
		extensionsLength += 4 * numExtensions
		length += 2 + extensionsLength
	}

	x := make([]byte, 4+length)
	x[0] = typeClientHello
	x[1] = uint8(length >> 16)
	x[2] = uint8(length >> 8)
	x[3] = uint8(length)
	x[4] = uint8(m.vers >> 8)
	x[5] = uint8(m.vers)
	copy(x[6:38], m.random)
	x[38] = uint8(len(m.sessionId))
	copy(x[39:39+len(m.sessionId)], m.sessionId)
	y := x[39+len(m.sessionId):]
	// The vector length is in bytes (2 per suite), so its high byte is
	// (2n)>>8 == n>>7 and its low byte is n<<1.
	y[0] = uint8(len(m.cipherSuites) >> 7)
	y[1] = uint8(len(m.cipherSuites) << 1)
	for i, suite := range m.cipherSuites {
		y[2+i*2] = uint8(suite >> 8)
		y[3+i*2] = uint8(suite)
	}
	z := y[2+len(m.cipherSuites)*2:]
	z[0] = uint8(len(m.compressionMethods))
	copy(z[1:], m.compressionMethods)

	// From here on z always points at the next unwritten byte.
	z = z[1+len(m.compressionMethods):]
	if numExtensions > 0 {
		z[0] = byte(extensionsLength >> 8)
		z[1] = byte(extensionsLength)
		z = z[2:]
	}
	if m.nextProtoNeg {
		z[0] = byte(extensionNextProtoNeg >> 8)
		z[1] = byte(extensionNextProtoNeg & 0xff)
		// The length is always 0
		z = z[4:]
	}
	if len(m.serverName) > 0 {
		z[0] = byte(extensionServerName >> 8)
		z[1] = byte(extensionServerName & 0xff)
		l := len(m.serverName) + 5
		z[2] = byte(l >> 8)
		z[3] = byte(l)
		z = z[4:]

		// RFC 3546, section 3.1
		//
		// struct {
		//     NameType name_type;
		//     select (name_type) {
		//         case host_name: HostName;
		//     } name;
		// } ServerName;
		//
		// enum {
		//     host_name(0), (255)
		// } NameType;
		//
		// opaque HostName<1..2^16-1>;
		//
		// struct {
		//     ServerName server_name_list<1..2^16-1>
		// } ServerNameList;

		// z[2] is name_type host_name(0); it is left zero from make.
		z[0] = byte((len(m.serverName) + 3) >> 8)
		z[1] = byte(len(m.serverName) + 3)
		z[3] = byte(len(m.serverName) >> 8)
		z[4] = byte(len(m.serverName))
		copy(z[5:], []byte(m.serverName))
		z = z[l:]
	}
	if m.ocspStapling {
		// RFC 4366, section 3.6
		z[0] = byte(extensionStatusRequest >> 8)
		z[1] = byte(extensionStatusRequest)
		z[2] = 0
		z[3] = 5
		z[4] = 1 // OCSP type
		// Two zero valued uint16s for the two lengths.
		z = z[9:]
	}
	if len(m.supportedCurves) > 0 {
		// http://tools.ietf.org/html/rfc4492#section-5.5.1
		z[0] = byte(extensionSupportedCurves >> 8)
		z[1] = byte(extensionSupportedCurves)
		l := 2 + 2*len(m.supportedCurves)
		z[2] = byte(l >> 8)
		z[3] = byte(l)
		l -= 2
		z[4] = byte(l >> 8)
		z[5] = byte(l)
		z = z[6:]
		for _, curve := range m.supportedCurves {
			z[0] = byte(curve >> 8)
			z[1] = byte(curve)
			z = z[2:]
		}
	}
	if len(m.supportedPoints) > 0 {
		// http://tools.ietf.org/html/rfc4492#section-5.5.2
		z[0] = byte(extensionSupportedPoints >> 8)
		z[1] = byte(extensionSupportedPoints)
		l := 1 + len(m.supportedPoints)
		z[2] = byte(l >> 8)
		z[3] = byte(l)
		l--
		z[4] = byte(l)
		z = z[5:]
		for _, pointFormat := range m.supportedPoints {
			z[0] = pointFormat
			z = z[1:]
		}
	}
	if m.ticketSupported {
		// http://tools.ietf.org/html/rfc5077#section-3.2
		z[0] = byte(extensionSessionTicket >> 8)
		z[1] = byte(extensionSessionTicket)
		l := len(m.sessionTicket)
		z[2] = byte(l >> 8)
		z[3] = byte(l)
		z = z[4:]
		copy(z, m.sessionTicket)
		z = z[len(m.sessionTicket):]
	}
	if len(m.supportedSignatureAlgorithms) > 0 {
		// https://tools.ietf.org/html/rfc5246#section-7.4.1.4.1
		z[0] = byte(extensionSignatureAlgorithms >> 8)
		z[1] = byte(extensionSignatureAlgorithms)
		l := 2 + 2*len(m.supportedSignatureAlgorithms)
		z[2] = byte(l >> 8)
		z[3] = byte(l)
		z = z[4:]

		l -= 2
		z[0] = byte(l >> 8)
		z[1] = byte(l)
		z = z[2:]
		for _, sigAlgo := range m.supportedSignatureAlgorithms {
			z[0] = byte(sigAlgo >> 8)
			z[1] = byte(sigAlgo)
			z = z[2:]
		}
	}
	if m.secureRenegotiationSupported {
		z[0] = byte(extensionRenegotiationInfo >> 8)
		z[1] = byte(extensionRenegotiationInfo & 0xff)
		z[2] = 0
		z[3] = byte(len(m.secureRenegotiation) + 1)
		z[4] = byte(len(m.secureRenegotiation))
		z = z[5:]
		copy(z, m.secureRenegotiation)
		z = z[len(m.secureRenegotiation):]
	}
	if len(m.alpnProtocols) > 0 {
		z[0] = byte(extensionALPN >> 8)
		z[1] = byte(extensionALPN & 0xff)
		// lengths[0:2] is the extension length and lengths[2:4] the protocol
		// list length; both are filled in once the names have been written.
		lengths := z[2:]
		z = z[6:]

		stringsLength := 0
		for _, s := range m.alpnProtocols {
			l := len(s)
			z[0] = byte(l)
			copy(z[1:], s)
			z = z[1+l:]
			stringsLength += 1 + l
		}

		lengths[2] = byte(stringsLength >> 8)
		lengths[3] = byte(stringsLength)
		stringsLength += 2
		lengths[0] = byte(stringsLength >> 8)
		lengths[1] = byte(stringsLength)
	}
	if m.scts {
		// https://tools.ietf.org/html/rfc6962#section-3.3.1
		z[0] = byte(extensionSCT >> 8)
		z[1] = byte(extensionSCT)
		// zero uint16 for the zero-length extension_data
		z = z[4:]
	}

	m.raw = x

	return x
}
   307  
   308  func (m *clientHelloMsg) unmarshal(data []byte) bool {
   309  	if len(data) < 42 {
   310  		return false
   311  	}
   312  	m.raw = data
   313  	m.vers = uint16(data[4])<<8 | uint16(data[5])
   314  	m.random = data[6:38]
   315  	sessionIdLen := int(data[38])
   316  	if sessionIdLen > 32 || len(data) < 39+sessionIdLen {
   317  		return false
   318  	}
   319  	m.sessionId = data[39 : 39+sessionIdLen]
   320  	data = data[39+sessionIdLen:]
   321  	if len(data) < 2 {
   322  		return false
   323  	}
   324  	// cipherSuiteLen is the number of bytes of cipher suite numbers. Since
   325  	// they are uint16s, the number must be even.
   326  	cipherSuiteLen := int(data[0])<<8 | int(data[1])
   327  	if cipherSuiteLen%2 == 1 || len(data) < 2+cipherSuiteLen {
   328  		return false
   329  	}
   330  	numCipherSuites := cipherSuiteLen / 2
   331  	m.cipherSuites = make([]uint16, numCipherSuites)
   332  	for i := 0; i < numCipherSuites; i++ {
   333  		m.cipherSuites[i] = uint16(data[2+2*i])<<8 | uint16(data[3+2*i])
   334  		if m.cipherSuites[i] == scsvRenegotiation {
   335  			m.secureRenegotiationSupported = true
   336  		}
   337  	}
   338  	data = data[2+cipherSuiteLen:]
   339  	if len(data) < 1 {
   340  		return false
   341  	}
   342  	compressionMethodsLen := int(data[0])
   343  	if len(data) < 1+compressionMethodsLen {
   344  		return false
   345  	}
   346  	m.compressionMethods = data[1 : 1+compressionMethodsLen]
   347  
   348  	data = data[1+compressionMethodsLen:]
   349  
   350  	m.nextProtoNeg = false
   351  	m.serverName = ""
   352  	m.ocspStapling = false
   353  	m.ticketSupported = false
   354  	m.sessionTicket = nil
   355  	m.supportedSignatureAlgorithms = nil
   356  	m.alpnProtocols = nil
   357  	m.scts = false
   358  
   359  	if len(data) == 0 {
   360  		// ClientHello is optionally followed by extension data
   361  		return true
   362  	}
   363  	if len(data) < 2 {
   364  		return false
   365  	}
   366  
   367  	extensionsLength := int(data[0])<<8 | int(data[1])
   368  	data = data[2:]
   369  	if extensionsLength != len(data) {
   370  		return false
   371  	}
   372  
   373  	for len(data) != 0 {
   374  		if len(data) < 4 {
   375  			return false
   376  		}
   377  		extension := uint16(data[0])<<8 | uint16(data[1])
   378  		length := int(data[2])<<8 | int(data[3])
   379  		data = data[4:]
   380  		if len(data) < length {
   381  			return false
   382  		}
   383  
   384  		switch extension {
   385  		case extensionServerName:
   386  			d := data[:length]
   387  			if len(d) < 2 {
   388  				return false
   389  			}
   390  			namesLen := int(d[0])<<8 | int(d[1])
   391  			d = d[2:]
   392  			if len(d) != namesLen {
   393  				return false
   394  			}
   395  			for len(d) > 0 {
   396  				if len(d) < 3 {
   397  					return false
   398  				}
   399  				nameType := d[0]
   400  				nameLen := int(d[1])<<8 | int(d[2])
   401  				d = d[3:]
   402  				if len(d) < nameLen {
   403  					return false
   404  				}
   405  				if nameType == 0 {
   406  					m.serverName = string(d[:nameLen])
   407  					break
   408  				}
   409  				d = d[nameLen:]
   410  			}
   411  		case extensionNextProtoNeg:
   412  			if length > 0 {
   413  				return false
   414  			}
   415  			m.nextProtoNeg = true
   416  		case extensionStatusRequest:
   417  			m.ocspStapling = length > 0 && data[0] == statusTypeOCSP
   418  		case extensionSupportedCurves:
   419  			// http://tools.ietf.org/html/rfc4492#section-5.5.1
   420  			if length < 2 {
   421  				return false
   422  			}
   423  			l := int(data[0])<<8 | int(data[1])
   424  			if l%2 == 1 || length != l+2 {
   425  				return false
   426  			}
   427  			numCurves := l / 2
   428  			m.supportedCurves = make([]CurveID, numCurves)
   429  			d := data[2:]
   430  			for i := 0; i < numCurves; i++ {
   431  				m.supportedCurves[i] = CurveID(d[0])<<8 | CurveID(d[1])
   432  				d = d[2:]
   433  			}
   434  		case extensionSupportedPoints:
   435  			// http://tools.ietf.org/html/rfc4492#section-5.5.2
   436  			if length < 1 {
   437  				return false
   438  			}
   439  			l := int(data[0])
   440  			if length != l+1 {
   441  				return false
   442  			}
   443  			m.supportedPoints = make([]uint8, l)
   444  			copy(m.supportedPoints, data[1:])
   445  		case extensionSessionTicket:
   446  			// http://tools.ietf.org/html/rfc5077#section-3.2
   447  			m.ticketSupported = true
   448  			m.sessionTicket = data[:length]
   449  		case extensionSignatureAlgorithms:
   450  			// https://tools.ietf.org/html/rfc5246#section-7.4.1.4.1
   451  			if length < 2 || length&1 != 0 {
   452  				return false
   453  			}
   454  			l := int(data[0])<<8 | int(data[1])
   455  			if l != length-2 {
   456  				return false
   457  			}
   458  			n := l / 2
   459  			d := data[2:]
   460  			m.supportedSignatureAlgorithms = make([]SignatureScheme, n)
   461  			for i := range m.supportedSignatureAlgorithms {
   462  				m.supportedSignatureAlgorithms[i] = SignatureScheme(d[0])<<8 | SignatureScheme(d[1])
   463  				d = d[2:]
   464  			
   465  			}
   466  		case extensionRenegotiationInfo:
   467  			if length == 0 {
   468  				return false
   469  			}
   470  			d := data[:length]
   471  			l := int(d[0])
   472  			d = d[1:]
   473  			if l != len(d) {
   474  				return false
   475  			}
   476  
   477  			m.secureRenegotiation = d
   478  			m.secureRenegotiationSupported = true
   479  		case extensionALPN:
   480  			if length < 2 {
   481  				return false
   482  			}
   483  			l := int(data[0])<<8 | int(data[1])
   484  			if l != length-2 {
   485  				return false
   486  			}
   487  			d := data[2:length]
   488  			for len(d) != 0 {
   489  				stringLen := int(d[0])
   490  				d = d[1:]
   491  				if stringLen == 0 || stringLen > len(d) {
   492  					return false
   493  				}
   494  				m.alpnProtocols = append(m.alpnProtocols, string(d[:stringLen]))
   495  				d = d[stringLen:]
   496  			}
   497  		case extensionSCT:
   498  			m.scts = true
   499  			if length != 0 {
   500  				return false
   501  			}
   502  		}
   503  		data = data[length:]
   504  	}
   505  
   506  	return true
   507  }
   508  
   509  type serverHelloMsg struct {
   510  	raw                          []byte
   511  	vers                         uint16
   512  	random                       []byte
   513  	sessionId                    []byte
   514  	cipherSuite                  uint16
   515  	compressionMethod            uint8
   516  	nextProtoNeg                 bool
   517  	nextProtos                   []string
   518  	ocspStapling                 bool
   519  	scts                         [][]byte
   520  	ticketSupported              bool
   521  	secureRenegotiation          []byte
   522  	secureRenegotiationSupported bool
   523  	alpnProtocol                 string
   524  }
   525  
   526  func (m *serverHelloMsg) equal(i interface{}) bool {
   527  	m1, ok := i.(*serverHelloMsg)
   528  	if !ok {
   529  		return false
   530  	}
   531  
   532  	if len(m.scts) != len(m1.scts) {
   533  		return false
   534  	}
   535  	for i, sct := range m.scts {
   536  		if !bytes.Equal(sct, m1.scts[i]) {
   537  			return false
   538  		}
   539  	}
   540  
   541  	return bytes.Equal(m.raw, m1.raw) &&
   542  		m.vers == m1.vers &&
   543  		bytes.Equal(m.random, m1.random) &&
   544  		bytes.Equal(m.sessionId, m1.sessionId) &&
   545  		m.cipherSuite == m1.cipherSuite &&
   546  		m.compressionMethod == m1.compressionMethod &&
   547  		m.nextProtoNeg == m1.nextProtoNeg &&
   548  		eqStrings(m.nextProtos, m1.nextProtos) &&
   549  		m.ocspStapling == m1.ocspStapling &&
   550  		m.ticketSupported == m1.ticketSupported &&
   551  		m.secureRenegotiationSupported == m1.secureRenegotiationSupported &&
   552  		bytes.Equal(m.secureRenegotiation, m1.secureRenegotiation) &&
   553  		m.alpnProtocol == m1.alpnProtocol
   554  }
   555  
   556  func (m *serverHelloMsg) marshal() []byte {
   557  	if m.raw != nil {
   558  		return m.raw
   559  	}
   560  
   561  	length := 38 + len(m.sessionId)
   562  	numExtensions := 0
   563  	extensionsLength := 0
   564  
   565  	nextProtoLen := 0
   566  	if m.nextProtoNeg {
   567  		numExtensions++
   568  		for _, v := range m.nextProtos {
   569  			nextProtoLen += len(v)
   570  		}
   571  		nextProtoLen += len(m.nextProtos)
   572  		extensionsLength += nextProtoLen
   573  	}
   574  	if m.ocspStapling {
   575  		numExtensions++
   576  	}
   577  	if m.ticketSupported {
   578  		numExtensions++
   579  	}
   580  	if m.secureRenegotiationSupported {
   581  		extensionsLength += 1 + len(m.secureRenegotiation)
   582  		numExtensions++
   583  	}
   584  	if alpnLen := len(m.alpnProtocol); alpnLen > 0 {
   585  		if alpnLen >= 256 {
   586  			panic("invalid ALPN protocol")
   587  		}
   588  		extensionsLength += 2 + 1 + alpnLen
   589  		numExtensions++
   590  	}
   591  	sctLen := 0
   592  	if len(m.scts) > 0 {
   593  		for _, sct := range m.scts {
   594  			sctLen += len(sct) + 2
   595  		}
   596  		extensionsLength += 2 + sctLen
   597  		numExtensions++
   598  	}
   599  
   600  	if numExtensions > 0 {
   601  		extensionsLength += 4 * numExtensions
   602  		length += 2 + extensionsLength
   603  	}
   604  
   605  	x := make([]byte, 4+length)
   606  	x[0] = typeServerHello
   607  	x[1] = uint8(length >> 16)
   608  	x[2] = uint8(length >> 8)
   609  	x[3] = uint8(length)
   610  	x[4] = uint8(m.vers >> 8)
   611  	x[5] = uint8(m.vers)
   612  	copy(x[6:38], m.random)
   613  	x[38] = uint8(len(m.sessionId))
   614  	copy(x[39:39+len(m.sessionId)], m.sessionId)
   615  	z := x[39+len(m.sessionId):]
   616  	z[0] = uint8(m.cipherSuite >> 8)
   617  	z[1] = uint8(m.cipherSuite)
   618  	z[2] = m.compressionMethod
   619  
   620  	z = z[3:]
   621  	if numExtensions > 0 {
   622  		z[0] = byte(extensionsLength >> 8)
   623  		z[1] = byte(extensionsLength)
   624  		z = z[2:]
   625  	}
   626  	if m.nextProtoNeg {
   627  		z[0] = byte(extensionNextProtoNeg >> 8)
   628  		z[1] = byte(extensionNextProtoNeg & 0xff)
   629  		z[2] = byte(nextProtoLen >> 8)
   630  		z[3] = byte(nextProtoLen)
   631  		z = z[4:]
   632  
   633  		for _, v := range m.nextProtos {
   634  			l := len(v)
   635  			if l > 255 {
   636  				l = 255
   637  			}
   638  			z[0] = byte(l)
   639  			copy(z[1:], []byte(v[0:l]))
   640  			z = z[1+l:]
   641  		}
   642  	}
   643  	if m.ocspStapling {
   644  		z[0] = byte(extensionStatusRequest >> 8)
   645  		z[1] = byte(extensionStatusRequest)
   646  		z = z[4:]
   647  	}
   648  	if m.ticketSupported {
   649  		z[0] = byte(extensionSessionTicket >> 8)
   650  		z[1] = byte(extensionSessionTicket)
   651  		z = z[4:]
   652  	}
   653  	if m.secureRenegotiationSupported {
   654  		z[0] = byte(extensionRenegotiationInfo >> 8)
   655  		z[1] = byte(extensionRenegotiationInfo & 0xff)
   656  		z[2] = 0
   657  		z[3] = byte(len(m.secureRenegotiation) + 1)
   658  		z[4] = byte(len(m.secureRenegotiation))
   659  		z = z[5:]
   660  		copy(z, m.secureRenegotiation)
   661  		z = z[len(m.secureRenegotiation):]
   662  	}
   663  	if alpnLen := len(m.alpnProtocol); alpnLen > 0 {
   664  		z[0] = byte(extensionALPN >> 8)
   665  		z[1] = byte(extensionALPN & 0xff)
   666  		l := 2 + 1 + alpnLen
   667  		z[2] = byte(l >> 8)
   668  		z[3] = byte(l)
   669  		l -= 2
   670  		z[4] = byte(l >> 8)
   671  		z[5] = byte(l)
   672  		l -= 1
   673  		z[6] = byte(l)
   674  		copy(z[7:], []byte(m.alpnProtocol))
   675  		z = z[7+alpnLen:]
   676  	}
   677  	if sctLen > 0 {
   678  		z[0] = byte(extensionSCT >> 8)
   679  		z[1] = byte(extensionSCT)
   680  		l := sctLen + 2
   681  		z[2] = byte(l >> 8)
   682  		z[3] = byte(l)
   683  		z[4] = byte(sctLen >> 8)
   684  		z[5] = byte(sctLen)
   685  
   686  		z = z[6:]
   687  		for _, sct := range m.scts {
   688  			z[0] = byte(len(sct) >> 8)
   689  			z[1] = byte(len(sct))
   690  			copy(z[2:], sct)
   691  			z = z[len(sct)+2:]
   692  		}
   693  	}
   694  
   695  	m.raw = x
   696  
   697  	return x
   698  }
   699  
   700  func (m *serverHelloMsg) unmarshal(data []byte) bool {
   701  	if len(data) < 42 {
   702  		return false
   703  	}
   704  	m.raw = data
   705  	m.vers = uint16(data[4])<<8 | uint16(data[5])
   706  	m.random = data[6:38]
   707  	sessionIdLen := int(data[38])
   708  	if sessionIdLen > 32 || len(data) < 39+sessionIdLen {
   709  		return false
   710  	}
   711  	m.sessionId = data[39 : 39+sessionIdLen]
   712  	data = data[39+sessionIdLen:]
   713  	if len(data) < 3 {
   714  		return false
   715  	}
   716  	m.cipherSuite = uint16(data[0])<<8 | uint16(data[1])
   717  	m.compressionMethod = data[2]
   718  	data = data[3:]
   719  
   720  	m.nextProtoNeg = false
   721  	m.nextProtos = nil
   722  	m.ocspStapling = false
   723  	m.scts = nil
   724  	m.ticketSupported = false
   725  	m.alpnProtocol = ""
   726  
   727  	if len(data) == 0 {
   728  		// ServerHello is optionally followed by extension data
   729  		return true
   730  	}
   731  	if len(data) < 2 {
   732  		return false
   733  	}
   734  
   735  	extensionsLength := int(data[0])<<8 | int(data[1])
   736  	data = data[2:]
   737  	if len(data) != extensionsLength {
   738  		return false
   739  	}
   740  
   741  	for len(data) != 0 {
   742  		if len(data) < 4 {
   743  			return false
   744  		}
   745  		extension := uint16(data[0])<<8 | uint16(data[1])
   746  		length := int(data[2])<<8 | int(data[3])
   747  		data = data[4:]
   748  		if len(data) < length {
   749  			return false
   750  		}
   751  
   752  		switch extension {
   753  		case extensionNextProtoNeg:
   754  			m.nextProtoNeg = true
   755  			d := data[:length]
   756  			for len(d) > 0 {
   757  				l := int(d[0])
   758  				d = d[1:]
   759  				if l == 0 || l > len(d) {
   760  					return false
   761  				}
   762  				m.nextProtos = append(m.nextProtos, string(d[:l]))
   763  				d = d[l:]
   764  			}
   765  		case extensionStatusRequest:
   766  			if length > 0 {
   767  				return false
   768  			}
   769  			m.ocspStapling = true
   770  		case extensionSessionTicket:
   771  			if length > 0 {
   772  				return false
   773  			}
   774  			m.ticketSupported = true
   775  		case extensionRenegotiationInfo:
   776  			if length == 0 {
   777  				return false
   778  			}
   779  			d := data[:length]
   780  			l := int(d[0])
   781  			d = d[1:]
   782  			if l != len(d) {
   783  				return false
   784  			}
   785  
   786  			m.secureRenegotiation = d
   787  			m.secureRenegotiationSupported = true
   788  		case extensionALPN:
   789  			d := data[:length]
   790  			if len(d) < 3 {
   791  				return false
   792  			}
   793  			l := int(d[0])<<8 | int(d[1])
   794  			if l != len(d)-2 {
   795  				return false
   796  			}
   797  			d = d[2:]
   798  			l = int(d[0])
   799  			if l != len(d)-1 {
   800  				return false
   801  			}
   802  			d = d[1:]
   803  			if len(d) == 0 {
   804  				// ALPN protocols must not be empty.
   805  				return false
   806  			}
   807  			m.alpnProtocol = string(d)
   808  		case extensionSCT:
   809  			d := data[:length]
   810  
   811  			if len(d) < 2 {
   812  				return false
   813  			}
   814  			l := int(d[0])<<8 | int(d[1])
   815  			d = d[2:]
   816  			if len(d) != l || l == 0 {
   817  				return false
   818  			}
   819  
   820  			m.scts = make([][]byte, 0, 3)
   821  			for len(d) != 0 {
   822  				if len(d) < 2 {
   823  					return false
   824  				}
   825  				sctLen := int(d[0])<<8 | int(d[1])
   826  				d = d[2:]
   827  				if sctLen == 0 || len(d) < sctLen {
   828  					return false
   829  				}
   830  				m.scts = append(m.scts, d[:sctLen])
   831  				d = d[sctLen:]
   832  			}
   833  		}
   834  		data = data[length:]
   835  	}
   836  
   837  	return true
   838  }
   839  
   840  type certificateMsg struct {
   841  	raw          []byte
   842  	certificates [][]byte
   843  }
   844  
   845  func (m *certificateMsg) equal(i interface{}) bool {
   846  	m1, ok := i.(*certificateMsg)
   847  	if !ok {
   848  		return false
   849  	}
   850  
   851  	return bytes.Equal(m.raw, m1.raw) &&
   852  		eqByteSlices(m.certificates, m1.certificates)
   853  }
   854  
   855  func (m *certificateMsg) marshal() (x []byte) {
   856  	if m.raw != nil {
   857  		return m.raw
   858  	}
   859  
   860  	var i int
   861  	for _, slice := range m.certificates {
   862  		i += len(slice)
   863  	}
   864  
   865  	length := 3 + 3*len(m.certificates) + i
   866  	x = make([]byte, 4+length)
   867  	x[0] = typeCertificate
   868  	x[1] = uint8(length >> 16)
   869  	x[2] = uint8(length >> 8)
   870  	x[3] = uint8(length)
   871  
   872  	certificateOctets := length - 3
   873  	x[4] = uint8(certificateOctets >> 16)
   874  	x[5] = uint8(certificateOctets >> 8)
   875  	x[6] = uint8(certificateOctets)
   876  
   877  	y := x[7:]
   878  	for _, slice := range m.certificates {
   879  		y[0] = uint8(len(slice) >> 16)
   880  		y[1] = uint8(len(slice) >> 8)
   881  		y[2] = uint8(len(slice))
   882  		copy(y[3:], slice)
   883  		y = y[3+len(slice):]
   884  	}
   885  
   886  	m.raw = x
   887  	return
   888  }
   889  
   890  func (m *certificateMsg) unmarshal(data []byte) bool {
   891  	if len(data) < 7 {
   892  		return false
   893  	}
   894  
   895  	m.raw = data
   896  	certsLen := uint32(data[4])<<16 | uint32(data[5])<<8 | uint32(data[6])
   897  	if uint32(len(data)) != certsLen+7 {
   898  		return false
   899  	}
   900  
   901  	numCerts := 0
   902  	d := data[7:]
   903  	for certsLen > 0 {
   904  		if len(d) < 4 {
   905  			return false
   906  		}
   907  		certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2])
   908  		if uint32(len(d)) < 3+certLen {
   909  			return false
   910  		}
   911  		d = d[3+certLen:]
   912  		certsLen -= 3 + certLen
   913  		numCerts++
   914  	}
   915  
   916  	m.certificates = make([][]byte, numCerts)
   917  	d = data[7:]
   918  	for i := 0; i < numCerts; i++ {
   919  		certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2])
   920  		m.certificates[i] = d[3 : 3+certLen]
   921  		d = d[3+certLen:]
   922  	}
   923  
   924  	return true
   925  }
   926  
   927  type serverKeyExchangeMsg struct {
   928  	raw []byte
   929  	key []byte
   930  }
   931  
   932  func (m *serverKeyExchangeMsg) equal(i interface{}) bool {
   933  	m1, ok := i.(*serverKeyExchangeMsg)
   934  	if !ok {
   935  		return false
   936  	}
   937  
   938  	return bytes.Equal(m.raw, m1.raw) &&
   939  		bytes.Equal(m.key, m1.key)
   940  }
   941  
   942  func (m *serverKeyExchangeMsg) marshal() []byte {
   943  	if m.raw != nil {
   944  		return m.raw
   945  	}
   946  	length := len(m.key)
   947  	x := make([]byte, length+4)
   948  	x[0] = typeServerKeyExchange
   949  	x[1] = uint8(length >> 16)
   950  	x[2] = uint8(length >> 8)
   951  	x[3] = uint8(length)
   952  	copy(x[4:], m.key)
   953  
   954  	m.raw = x
   955  	return x
   956  }
   957  
   958  func (m *serverKeyExchangeMsg) unmarshal(data []byte) bool {
   959  	m.raw = data
   960  	if len(data) < 4 {
   961  		return false
   962  	}
   963  	m.key = data[4:]
   964  	return true
   965  }
   966  
   967  type certificateStatusMsg struct {
   968  	raw        []byte
   969  	statusType uint8
   970  	response   []byte
   971  }
   972  
   973  func (m *certificateStatusMsg) equal(i interface{}) bool {
   974  	m1, ok := i.(*certificateStatusMsg)
   975  	if !ok {
   976  		return false
   977  	}
   978  
   979  	return bytes.Equal(m.raw, m1.raw) &&
   980  		m.statusType == m1.statusType &&
   981  		bytes.Equal(m.response, m1.response)
   982  }
   983  
   984  func (m *certificateStatusMsg) marshal() []byte {
   985  	if m.raw != nil {
   986  		return m.raw
   987  	}
   988  
   989  	var x []byte
   990  	if m.statusType == statusTypeOCSP {
   991  		x = make([]byte, 4+4+len(m.response))
   992  		x[0] = typeCertificateStatus
   993  		l := len(m.response) + 4
   994  		x[1] = byte(l >> 16)
   995  		x[2] = byte(l >> 8)
   996  		x[3] = byte(l)
   997  		x[4] = statusTypeOCSP
   998  
   999  		l -= 4
  1000  		x[5] = byte(l >> 16)
  1001  		x[6] = byte(l >> 8)
  1002  		x[7] = byte(l)
  1003  		copy(x[8:], m.response)
  1004  	} else {
  1005  		x = []byte{typeCertificateStatus, 0, 0, 1, m.statusType}
  1006  	}
  1007  
  1008  	m.raw = x
  1009  	return x
  1010  }
  1011  
  1012  func (m *certificateStatusMsg) unmarshal(data []byte) bool {
  1013  	m.raw = data
  1014  	if len(data) < 5 {
  1015  		return false
  1016  	}
  1017  	m.statusType = data[4]
  1018  
  1019  	m.response = nil
  1020  	if m.statusType == statusTypeOCSP {
  1021  		if len(data) < 8 {
  1022  			return false
  1023  		}
  1024  		respLen := uint32(data[5])<<16 | uint32(data[6])<<8 | uint32(data[7])
  1025  		if uint32(len(data)) != 4+4+respLen {
  1026  			return false
  1027  		}
  1028  		m.response = data[8:]
  1029  	}
  1030  	return true
  1031  }
  1032  
  1033  type serverHelloDoneMsg struct{}
  1034  
  1035  func (m *serverHelloDoneMsg) equal(i interface{}) bool {
  1036  	_, ok := i.(*serverHelloDoneMsg)
  1037  	return ok
  1038  }
  1039  
  1040  func (m *serverHelloDoneMsg) marshal() []byte {
  1041  	x := make([]byte, 4)
  1042  	x[0] = typeServerHelloDone
  1043  	return x
  1044  }
  1045  
  1046  func (m *serverHelloDoneMsg) unmarshal(data []byte) bool {
  1047  	return len(data) == 4
  1048  }
  1049  
  1050  type clientKeyExchangeMsg struct {
  1051  	raw        []byte
  1052  	ciphertext []byte
  1053  }
  1054  
  1055  func (m *clientKeyExchangeMsg) equal(i interface{}) bool {
  1056  	m1, ok := i.(*clientKeyExchangeMsg)
  1057  	if !ok {
  1058  		return false
  1059  	}
  1060  
  1061  	return bytes.Equal(m.raw, m1.raw) &&
  1062  		bytes.Equal(m.ciphertext, m1.ciphertext)
  1063  }
  1064  
  1065  func (m *clientKeyExchangeMsg) marshal() []byte {
  1066  	if m.raw != nil {
  1067  		return m.raw
  1068  	}
  1069  	length := len(m.ciphertext)
  1070  	x := make([]byte, length+4)
  1071  	x[0] = typeClientKeyExchange
  1072  	x[1] = uint8(length >> 16)
  1073  	x[2] = uint8(length >> 8)
  1074  	x[3] = uint8(length)
  1075  	copy(x[4:], m.ciphertext)
  1076  
  1077  	m.raw = x
  1078  	return x
  1079  }
  1080  
  1081  func (m *clientKeyExchangeMsg) unmarshal(data []byte) bool {
  1082  	m.raw = data
  1083  	if len(data) < 4 {
  1084  		return false
  1085  	}
  1086  	l := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
  1087  	if l != len(data)-4 {
  1088  		return false
  1089  	}
  1090  	m.ciphertext = data[4:]
  1091  	return true
  1092  }
  1093  
  1094  type finishedMsg struct {
  1095  	raw        []byte
  1096  	verifyData []byte
  1097  }
  1098  
  1099  func (m *finishedMsg) equal(i interface{}) bool {
  1100  	m1, ok := i.(*finishedMsg)
  1101  	if !ok {
  1102  		return false
  1103  	}
  1104  
  1105  	return bytes.Equal(m.raw, m1.raw) &&
  1106  		bytes.Equal(m.verifyData, m1.verifyData)
  1107  }
  1108  
  1109  func (m *finishedMsg) marshal() (x []byte) {
  1110  	if m.raw != nil {
  1111  		return m.raw
  1112  	}
  1113  
  1114  	x = make([]byte, 4+len(m.verifyData))
  1115  	x[0] = typeFinished
  1116  	x[3] = byte(len(m.verifyData))
  1117  	copy(x[4:], m.verifyData)
  1118  	m.raw = x
  1119  	return
  1120  }
  1121  
  1122  func (m *finishedMsg) unmarshal(data []byte) bool {
  1123  	m.raw = data
  1124  	if len(data) < 4 {
  1125  		return false
  1126  	}
  1127  	m.verifyData = data[4:]
  1128  	return true
  1129  }
  1130  
  1131  type nextProtoMsg struct {
  1132  	raw   []byte
  1133  	proto string
  1134  }
  1135  
  1136  func (m *nextProtoMsg) equal(i interface{}) bool {
  1137  	m1, ok := i.(*nextProtoMsg)
  1138  	if !ok {
  1139  		return false
  1140  	}
  1141  
  1142  	return bytes.Equal(m.raw, m1.raw) &&
  1143  		m.proto == m1.proto
  1144  }
  1145  
  1146  func (m *nextProtoMsg) marshal() []byte {
  1147  	if m.raw != nil {
  1148  		return m.raw
  1149  	}
  1150  	l := len(m.proto)
  1151  	if l > 255 {
  1152  		l = 255
  1153  	}
  1154  
  1155  	padding := 32 - (l+2)%32
  1156  	length := l + padding + 2
  1157  	x := make([]byte, length+4)
  1158  	x[0] = typeNextProtocol
  1159  	x[1] = uint8(length >> 16)
  1160  	x[2] = uint8(length >> 8)
  1161  	x[3] = uint8(length)
  1162  
  1163  	y := x[4:]
  1164  	y[0] = byte(l)
  1165  	copy(y[1:], []byte(m.proto[0:l]))
  1166  	y = y[1+l:]
  1167  	y[0] = byte(padding)
  1168  
  1169  	m.raw = x
  1170  
  1171  	return x
  1172  }
  1173  
  1174  func (m *nextProtoMsg) unmarshal(data []byte) bool {
  1175  	m.raw = data
  1176  
  1177  	if len(data) < 5 {
  1178  		return false
  1179  	}
  1180  	data = data[4:]
  1181  	protoLen := int(data[0])
  1182  	data = data[1:]
  1183  	if len(data) < protoLen {
  1184  		return false
  1185  	}
  1186  	m.proto = string(data[0:protoLen])
  1187  	data = data[protoLen:]
  1188  
  1189  	if len(data) < 1 {
  1190  		return false
  1191  	}
  1192  	paddingLen := int(data[0])
  1193  	data = data[1:]
  1194  	if len(data) != paddingLen {
  1195  		return false
  1196  	}
  1197  
  1198  	return true
  1199  }
  1200  
  1201  type certificateRequestMsg struct {
  1202  	raw []byte
  1203  	// hasSignatureAndHash indicates whether this message includes a list
  1204  	// of signature and hash functions. This change was introduced with TLS
  1205  	// 1.2.
  1206  	hasSignatureAndHash bool
  1207  
  1208  	certificateTypes       []byte
  1209  	supportedSignatureAlgorithms []SignatureScheme
  1210  	certificateAuthorities [][]byte
  1211  }
  1212  
  1213  func (m *certificateRequestMsg) equal(i interface{}) bool {
  1214  	m1, ok := i.(*certificateRequestMsg)
  1215  	if !ok {
  1216  		return false
  1217  	}
  1218  
  1219  	return bytes.Equal(m.raw, m1.raw) &&
  1220  		bytes.Equal(m.certificateTypes, m1.certificateTypes) &&
  1221  		eqByteSlices(m.certificateAuthorities, m1.certificateAuthorities) &&
  1222  		eqSignatureAlgorithms(m.supportedSignatureAlgorithms, m1.supportedSignatureAlgorithms)
  1223  }
  1224  
  1225  func (m *certificateRequestMsg) marshal() (x []byte) {
  1226  	if m.raw != nil {
  1227  		return m.raw
  1228  	}
  1229  
  1230  	// See http://tools.ietf.org/html/rfc4346#section-7.4.4
  1231  	length := 1 + len(m.certificateTypes) + 2
  1232  	casLength := 0
  1233  	for _, ca := range m.certificateAuthorities {
  1234  		casLength += 2 + len(ca)
  1235  	}
  1236  	length += casLength
  1237  
  1238  	if m.hasSignatureAndHash {
  1239  		length += 2 +2*len(m.supportedSignatureAlgorithms)
  1240  	}
  1241  
  1242  	x = make([]byte, 4+length)
  1243  	x[0] = typeCertificateRequest
  1244  	x[1] = uint8(length >> 16)
  1245  	x[2] = uint8(length >> 8)
  1246  	x[3] = uint8(length)
  1247  
  1248  	x[4] = uint8(len(m.certificateTypes))
  1249  
  1250  	copy(x[5:], m.certificateTypes)
  1251  	y := x[5+len(m.certificateTypes):]
  1252  
  1253  	if m.hasSignatureAndHash {
  1254  		n := len(m.supportedSignatureAlgorithms) * 2
  1255  		y[0] = uint8(n >> 8)
  1256  		y[1] = uint8(n)
  1257  		y = y[2:]
  1258  		for _, sigAlgo := range m.supportedSignatureAlgorithms {
  1259  			y[0] = uint8(sigAlgo >> 8)
  1260  			y[1] = uint8(sigAlgo)
  1261  			y = y[2:]
  1262  		}
  1263  	}
  1264  
  1265  	y[0] = uint8(casLength >> 8)
  1266  	y[1] = uint8(casLength)
  1267  	y = y[2:]
  1268  	for _, ca := range m.certificateAuthorities {
  1269  		y[0] = uint8(len(ca) >> 8)
  1270  		y[1] = uint8(len(ca))
  1271  		y = y[2:]
  1272  		copy(y, ca)
  1273  		y = y[len(ca):]
  1274  	}
  1275  
  1276  	m.raw = x
  1277  	return
  1278  }
  1279  
  1280  func (m *certificateRequestMsg) unmarshal(data []byte) bool {
  1281  	m.raw = data
  1282  
  1283  	if len(data) < 5 {
  1284  		return false
  1285  	}
  1286  
  1287  	length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3])
  1288  	if uint32(len(data))-4 != length {
  1289  		return false
  1290  	}
  1291  
  1292  	numCertTypes := int(data[4])
  1293  	data = data[5:]
  1294  	if numCertTypes == 0 || len(data) <= numCertTypes {
  1295  		return false
  1296  	}
  1297  
  1298  	m.certificateTypes = make([]byte, numCertTypes)
  1299  	if copy(m.certificateTypes, data) != numCertTypes {
  1300  		return false
  1301  	}
  1302  
  1303  	data = data[numCertTypes:]
  1304  
  1305  	if m.hasSignatureAndHash {
  1306  		if len(data) < 2 {
  1307  			return false
  1308  		}
  1309  		sigAndHashLen := uint16(data[0])<<8 | uint16(data[1])
  1310  		data = data[2:]
  1311  		if sigAndHashLen&1 != 0 {
  1312  			return false
  1313  		}
  1314  		if len(data) < int(sigAndHashLen) {
  1315  			return false
  1316  		}
  1317  		numSigAlgos := sigAndHashLen / 2
  1318  		m.supportedSignatureAlgorithms = make([]SignatureScheme, numSigAlgos)
  1319  		for i := range m.supportedSignatureAlgorithms {
  1320  			m.supportedSignatureAlgorithms[i] = SignatureScheme(data[0])<<8 | SignatureScheme(data[1])
  1321  			data = data[2:]
  1322  		}
  1323  	}
  1324  
  1325  	if len(data) < 2 {
  1326  		return false
  1327  	}
  1328  	casLength := uint16(data[0])<<8 | uint16(data[1])
  1329  	data = data[2:]
  1330  	if len(data) < int(casLength) {
  1331  		return false
  1332  	}
  1333  	cas := make([]byte, casLength)
  1334  	copy(cas, data)
  1335  	data = data[casLength:]
  1336  
  1337  	m.certificateAuthorities = nil
  1338  	for len(cas) > 0 {
  1339  		if len(cas) < 2 {
  1340  			return false
  1341  		}
  1342  		caLen := uint16(cas[0])<<8 | uint16(cas[1])
  1343  		cas = cas[2:]
  1344  
  1345  		if len(cas) < int(caLen) {
  1346  			return false
  1347  		}
  1348  
  1349  		m.certificateAuthorities = append(m.certificateAuthorities, cas[:caLen])
  1350  		cas = cas[caLen:]
  1351  	}
  1352  
  1353  	return len(data) == 0
  1354  }
  1355  
  1356  type certificateVerifyMsg struct {
  1357  	raw                 []byte
  1358  	hasSignatureAndHash bool
  1359  	signatureAlgorithm    SignatureScheme
  1360  	signature           []byte
  1361  }
  1362  
  1363  func (m *certificateVerifyMsg) equal(i interface{}) bool {
  1364  	m1, ok := i.(*certificateVerifyMsg)
  1365  	if !ok {
  1366  		return false
  1367  	}
  1368  
  1369  	return bytes.Equal(m.raw, m1.raw) &&
  1370  		m.hasSignatureAndHash == m1.hasSignatureAndHash &&
  1371  		m.signatureAlgorithm == m1.signatureAlgorithm &&
  1372  		bytes.Equal(m.signature, m1.signature)
  1373  }
  1374  
  1375  func (m *certificateVerifyMsg) marshal() (x []byte) {
  1376  	if m.raw != nil {
  1377  		return m.raw
  1378  	}
  1379  
  1380  	// See http://tools.ietf.org/html/rfc4346#section-7.4.8
  1381  	siglength := len(m.signature)
  1382  	length := 2 + siglength
  1383  	if m.hasSignatureAndHash {
  1384  		length += 2
  1385  	}
  1386  	x = make([]byte, 4+length)
  1387  	x[0] = typeCertificateVerify
  1388  	x[1] = uint8(length >> 16)
  1389  	x[2] = uint8(length >> 8)
  1390  	x[3] = uint8(length)
  1391  	y := x[4:]
  1392  	if m.hasSignatureAndHash {
  1393  		y[0] = uint8(m.signatureAlgorithm >> 8)
  1394  		y[1] = uint8(m.signatureAlgorithm)
  1395  		y = y[2:]
  1396  	}
  1397  	y[0] = uint8(siglength >> 8)
  1398  	y[1] = uint8(siglength)
  1399  	copy(y[2:], m.signature)
  1400  
  1401  	m.raw = x
  1402  
  1403  	return
  1404  }
  1405  
  1406  func (m *certificateVerifyMsg) unmarshal(data []byte) bool {
  1407  	m.raw = data
  1408  
  1409  	if len(data) < 6 {
  1410  		return false
  1411  	}
  1412  
  1413  	length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3])
  1414  	if uint32(len(data))-4 != length {
  1415  		return false
  1416  	}
  1417  
  1418  	data = data[4:]
  1419  	if m.hasSignatureAndHash {
  1420  		m.signatureAlgorithm = SignatureScheme(data[0])<<8 | SignatureScheme(data[1])
  1421  		data = data[2:]
  1422  	}
  1423  
  1424  	if len(data) < 2 {
  1425  		return false
  1426  	}
  1427  	siglength := int(data[0])<<8 + int(data[1])
  1428  	data = data[2:]
  1429  	if len(data) != siglength {
  1430  		return false
  1431  	}
  1432  
  1433  	m.signature = data
  1434  
  1435  	return true
  1436  }
  1437  
  1438  type newSessionTicketMsg struct {
  1439  	raw    []byte
  1440  	ticket []byte
  1441  }
  1442  
  1443  func (m *newSessionTicketMsg) equal(i interface{}) bool {
  1444  	m1, ok := i.(*newSessionTicketMsg)
  1445  	if !ok {
  1446  		return false
  1447  	}
  1448  
  1449  	return bytes.Equal(m.raw, m1.raw) &&
  1450  		bytes.Equal(m.ticket, m1.ticket)
  1451  }
  1452  
  1453  func (m *newSessionTicketMsg) marshal() (x []byte) {
  1454  	if m.raw != nil {
  1455  		return m.raw
  1456  	}
  1457  
  1458  	// See http://tools.ietf.org/html/rfc5077#section-3.3
  1459  	ticketLen := len(m.ticket)
  1460  	length := 2 + 4 + ticketLen
  1461  	x = make([]byte, 4+length)
  1462  	x[0] = typeNewSessionTicket
  1463  	x[1] = uint8(length >> 16)
  1464  	x[2] = uint8(length >> 8)
  1465  	x[3] = uint8(length)
  1466  	x[8] = uint8(ticketLen >> 8)
  1467  	x[9] = uint8(ticketLen)
  1468  	copy(x[10:], m.ticket)
  1469  
  1470  	m.raw = x
  1471  
  1472  	return
  1473  }
  1474  
  1475  func (m *newSessionTicketMsg) unmarshal(data []byte) bool {
  1476  	m.raw = data
  1477  
  1478  	if len(data) < 10 {
  1479  		return false
  1480  	}
  1481  
  1482  	length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3])
  1483  	if uint32(len(data))-4 != length {
  1484  		return false
  1485  	}
  1486  
  1487  	ticketLen := int(data[8])<<8 + int(data[9])
  1488  	if len(data)-10 != ticketLen {
  1489  		return false
  1490  	}
  1491  
  1492  	m.ticket = data[10:]
  1493  
  1494  	return true
  1495  }
  1496  
  1497  type helloRequestMsg struct {
  1498  }
  1499  
  1500  func (*helloRequestMsg) marshal() []byte {
  1501  	return []byte{typeHelloRequest, 0, 0, 0}
  1502  }
  1503  
  1504  func (*helloRequestMsg) unmarshal(data []byte) bool {
  1505  	return len(data) == 4
  1506  }
  1507  
  1508  func eqUint16s(x, y []uint16) bool {
  1509  	if len(x) != len(y) {
  1510  		return false
  1511  	}
  1512  	for i, v := range x {
  1513  		if y[i] != v {
  1514  			return false
  1515  		}
  1516  	}
  1517  	return true
  1518  }
  1519  
  1520  func eqCurveIDs(x, y []CurveID) bool {
  1521  	if len(x) != len(y) {
  1522  		return false
  1523  	}
  1524  	for i, v := range x {
  1525  		if y[i] != v {
  1526  			return false
  1527  		}
  1528  	}
  1529  	return true
  1530  }
  1531  
  1532  func eqStrings(x, y []string) bool {
  1533  	if len(x) != len(y) {
  1534  		return false
  1535  	}
  1536  	for i, v := range x {
  1537  		if y[i] != v {
  1538  			return false
  1539  		}
  1540  	}
  1541  	return true
  1542  }
  1543  
  1544  func eqByteSlices(x, y [][]byte) bool {
  1545  	if len(x) != len(y) {
  1546  		return false
  1547  	}
  1548  	for i, v := range x {
  1549  		if !bytes.Equal(v, y[i]) {
  1550  			return false
  1551  		}
  1552  	}
  1553  	return true
  1554  }
  1555  
  1556  func  eqSignatureAlgorithms(x, y []SignatureScheme) bool  {
  1557  	if len(x) != len(y) {
  1558  		return false
  1559  	}
  1560  	for i, v := range x {
  1561  		if v != y[i] {
  1562  			return false
  1563  		}
  1564  	}
  1565  	return true
  1566  }