github.com/rohankumardubey/syslog-redirector-golang@v0.0.0-20140320174030-4859f03d829a/src/pkg/crypto/tls/handshake_messages.go (about)

     1  // Copyright 2009 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package tls
     6  
     7  import "bytes"
     8  
// clientHelloMsg holds the fields of a ClientHello handshake message.
//
// raw caches the complete wire encoding, 4-byte handshake header included.
// marshal returns raw unchanged when it is non-nil, and unmarshal sets it to
// its input. After unmarshal, random, sessionId, compressionMethods and
// sessionTicket are sub-slices of that input rather than copies.
//
// cipherSuites and supportedCurves are lists of 16-bit values.
// compressionMethods and supportedPoints are lists of 8-bit values.
// nextProtoNeg, serverName, ocspStapling, supportedCurves, supportedPoints,
// ticketSupported (with sessionTicket) and signatureAndHashes each
// correspond to one optional extension. marshal emits an extension only when
// its field is set (true or non-empty).
type clientHelloMsg struct {
	raw                []byte
	vers               uint16
	random             []byte
	sessionId          []byte
	cipherSuites       []uint16
	compressionMethods []uint8
	nextProtoNeg       bool
	serverName         string
	ocspStapling       bool
	supportedCurves    []uint16
	supportedPoints    []uint8
	ticketSupported    bool
	sessionTicket      []uint8
	signatureAndHashes []signatureAndHash
}
    25  
    26  func (m *clientHelloMsg) equal(i interface{}) bool {
    27  	m1, ok := i.(*clientHelloMsg)
    28  	if !ok {
    29  		return false
    30  	}
    31  
    32  	return bytes.Equal(m.raw, m1.raw) &&
    33  		m.vers == m1.vers &&
    34  		bytes.Equal(m.random, m1.random) &&
    35  		bytes.Equal(m.sessionId, m1.sessionId) &&
    36  		eqUint16s(m.cipherSuites, m1.cipherSuites) &&
    37  		bytes.Equal(m.compressionMethods, m1.compressionMethods) &&
    38  		m.nextProtoNeg == m1.nextProtoNeg &&
    39  		m.serverName == m1.serverName &&
    40  		m.ocspStapling == m1.ocspStapling &&
    41  		eqUint16s(m.supportedCurves, m1.supportedCurves) &&
    42  		bytes.Equal(m.supportedPoints, m1.supportedPoints) &&
    43  		m.ticketSupported == m1.ticketSupported &&
    44  		bytes.Equal(m.sessionTicket, m1.sessionTicket) &&
    45  		eqSignatureAndHashes(m.signatureAndHashes, m1.signatureAndHashes)
    46  }
    47  
// marshal encodes m as a complete ClientHello handshake message and returns
// the encoding. The result is cached in m.raw. When m.raw is already non-nil
// it is returned unchanged, so field changes made after an earlier marshal
// or unmarshal are not reflected.
//
// The encoding consists of:
//   - a 4-byte handshake header (typeClientHello plus a 24-bit body length);
//   - the 2-byte version and the 32-byte random;
//   - the session ID with a 1-byte length;
//   - the cipher suites with a 2-byte byte length;
//   - the compression methods with a 1-byte length;
//   - only if at least one extension is enabled, a 2-byte total extensions
//     length followed by each extension as a 2-byte type, a 2-byte data
//     length and its data.
//
// NOTE(review): lengths are not range-checked. The session ID and
// compression method counts are stored through uint8 conversions. m.random
// is copied into a fixed 32-byte slot: a shorter value leaves trailing
// zeros, and a longer one is cut off.
func (m *clientHelloMsg) marshal() []byte {
	if m.raw != nil {
		return m.raw
	}

	// Body length before extensions: version (2) + random (32) + session ID
	// length byte (1) + session ID + cipher suites length (2) + 2 bytes per
	// suite + compression methods length byte (1) + 1 byte per method.
	length := 2 + 32 + 1 + len(m.sessionId) + 2 + len(m.cipherSuites)*2 + 1 + len(m.compressionMethods)
	numExtensions := 0
	// extensionsLength first sums the data length of each enabled extension.
	// The 4-byte type+length header of each is added once the count is known.
	extensionsLength := 0
	if m.nextProtoNeg {
		// next_protocol_negotiation is sent with empty data.
		numExtensions++
	}
	if m.ocspStapling {
		// status_type (1) + two zero uint16 lengths (2 + 2).
		extensionsLength += 1 + 2 + 2
		numExtensions++
	}
	if len(m.serverName) > 0 {
		// list length (2) + name type (1) + name length (2) + name bytes.
		extensionsLength += 5 + len(m.serverName)
		numExtensions++
	}
	if len(m.supportedCurves) > 0 {
		// list length (2) + 2 bytes per curve.
		extensionsLength += 2 + 2*len(m.supportedCurves)
		numExtensions++
	}
	if len(m.supportedPoints) > 0 {
		// list length (1) + 1 byte per point format.
		extensionsLength += 1 + len(m.supportedPoints)
		numExtensions++
	}
	if m.ticketSupported {
		// The ticket bytes themselves; this may be zero.
		extensionsLength += len(m.sessionTicket)
		numExtensions++
	}
	if len(m.signatureAndHashes) > 0 {
		// list length (2) + (hash, signature) byte pair per entry.
		extensionsLength += 2 + 2*len(m.signatureAndHashes)
		numExtensions++
	}
	if numExtensions > 0 {
		extensionsLength += 4 * numExtensions
		length += 2 + extensionsLength
	}

	x := make([]byte, 4+length)
	x[0] = typeClientHello
	x[1] = uint8(length >> 16)
	x[2] = uint8(length >> 8)
	x[3] = uint8(length)
	x[4] = uint8(m.vers >> 8)
	x[5] = uint8(m.vers)
	copy(x[6:38], m.random)
	x[38] = uint8(len(m.sessionId))
	copy(x[39:39+len(m.sessionId)], m.sessionId)
	y := x[39+len(m.sessionId):]
	// Big-endian byte length of the cipher suite list, i.e. 2*n:
	// (2n)>>8 == n>>7, and the low byte of 2n is the low byte of n<<1.
	y[0] = uint8(len(m.cipherSuites) >> 7)
	y[1] = uint8(len(m.cipherSuites) << 1)
	for i, suite := range m.cipherSuites {
		y[2+i*2] = uint8(suite >> 8)
		y[3+i*2] = uint8(suite)
	}
	z := y[2+len(m.cipherSuites)*2:]
	z[0] = uint8(len(m.compressionMethods))
	copy(z[1:], m.compressionMethods)

	// Extensions follow, in this order: next_protocol_negotiation,
	// server_name, status_request, supported curves, supported point formats,
	// session_ticket, signature_algorithms. Bytes that must be zero are left
	// as allocated by make.
	z = z[1+len(m.compressionMethods):]
	if numExtensions > 0 {
		z[0] = byte(extensionsLength >> 8)
		z[1] = byte(extensionsLength)
		z = z[2:]
	}
	if m.nextProtoNeg {
		z[0] = byte(extensionNextProtoNeg >> 8)
		z[1] = byte(extensionNextProtoNeg)
		// The length is always 0
		z = z[4:]
	}
	if len(m.serverName) > 0 {
		z[0] = byte(extensionServerName >> 8)
		z[1] = byte(extensionServerName)
		l := len(m.serverName) + 5
		z[2] = byte(l >> 8)
		z[3] = byte(l)
		z = z[4:]

		// RFC 3546, section 3.1
		//
		// struct {
		//     NameType name_type;
		//     select (name_type) {
		//         case host_name: HostName;
		//     } name;
		// } ServerName;
		//
		// enum {
		//     host_name(0), (255)
		// } NameType;
		//
		// opaque HostName<1..2^16-1>;
		//
		// struct {
		//     ServerName server_name_list<1..2^16-1>
		// } ServerNameList;

		// z[0:2] is the list length and z[2] the name type, which stays 0
		// (host_name). z[3:5] is the host name length, followed by the name.
		z[0] = byte((len(m.serverName) + 3) >> 8)
		z[1] = byte(len(m.serverName) + 3)
		z[3] = byte(len(m.serverName) >> 8)
		z[4] = byte(len(m.serverName))
		copy(z[5:], []byte(m.serverName))
		z = z[l:]
	}
	if m.ocspStapling {
		// RFC 4366, section 3.6
		z[0] = byte(extensionStatusRequest >> 8)
		z[1] = byte(extensionStatusRequest)
		z[2] = 0
		z[3] = 5
		z[4] = 1 // OCSP type
		// Two zero valued uint16s for the two lengths.
		z = z[9:]
	}
	if len(m.supportedCurves) > 0 {
		// http://tools.ietf.org/html/rfc4492#section-5.5.1
		z[0] = byte(extensionSupportedCurves >> 8)
		z[1] = byte(extensionSupportedCurves)
		l := 2 + 2*len(m.supportedCurves)
		z[2] = byte(l >> 8)
		z[3] = byte(l)
		// Inner list length excludes its own 2-byte prefix.
		l -= 2
		z[4] = byte(l >> 8)
		z[5] = byte(l)
		z = z[6:]
		for _, curve := range m.supportedCurves {
			z[0] = byte(curve >> 8)
			z[1] = byte(curve)
			z = z[2:]
		}
	}
	if len(m.supportedPoints) > 0 {
		// http://tools.ietf.org/html/rfc4492#section-5.5.2
		z[0] = byte(extensionSupportedPoints >> 8)
		z[1] = byte(extensionSupportedPoints)
		l := 1 + len(m.supportedPoints)
		z[2] = byte(l >> 8)
		z[3] = byte(l)
		// Inner list length excludes its own 1-byte prefix.
		l--
		z[4] = byte(l)
		z = z[5:]
		for _, pointFormat := range m.supportedPoints {
			z[0] = byte(pointFormat)
			z = z[1:]
		}
	}
	if m.ticketSupported {
		// http://tools.ietf.org/html/rfc5077#section-3.2
		z[0] = byte(extensionSessionTicket >> 8)
		z[1] = byte(extensionSessionTicket)
		l := len(m.sessionTicket)
		z[2] = byte(l >> 8)
		z[3] = byte(l)
		z = z[4:]
		copy(z, m.sessionTicket)
		z = z[len(m.sessionTicket):]
	}
	if len(m.signatureAndHashes) > 0 {
		// https://tools.ietf.org/html/rfc5246#section-7.4.1.4.1
		z[0] = byte(extensionSignatureAlgorithms >> 8)
		z[1] = byte(extensionSignatureAlgorithms)
		l := 2 + 2*len(m.signatureAndHashes)
		z[2] = byte(l >> 8)
		z[3] = byte(l)
		z = z[4:]

		// Inner list length excludes its own 2-byte prefix.
		l -= 2
		z[0] = byte(l >> 8)
		z[1] = byte(l)
		z = z[2:]
		for _, sigAndHash := range m.signatureAndHashes {
			z[0] = sigAndHash.hash
			z[1] = sigAndHash.signature
			z = z[2:]
		}
	}

	m.raw = x

	return x
}
   232  
   233  func (m *clientHelloMsg) unmarshal(data []byte) bool {
   234  	if len(data) < 42 {
   235  		return false
   236  	}
   237  	m.raw = data
   238  	m.vers = uint16(data[4])<<8 | uint16(data[5])
   239  	m.random = data[6:38]
   240  	sessionIdLen := int(data[38])
   241  	if sessionIdLen > 32 || len(data) < 39+sessionIdLen {
   242  		return false
   243  	}
   244  	m.sessionId = data[39 : 39+sessionIdLen]
   245  	data = data[39+sessionIdLen:]
   246  	if len(data) < 2 {
   247  		return false
   248  	}
   249  	// cipherSuiteLen is the number of bytes of cipher suite numbers. Since
   250  	// they are uint16s, the number must be even.
   251  	cipherSuiteLen := int(data[0])<<8 | int(data[1])
   252  	if cipherSuiteLen%2 == 1 || len(data) < 2+cipherSuiteLen {
   253  		return false
   254  	}
   255  	numCipherSuites := cipherSuiteLen / 2
   256  	m.cipherSuites = make([]uint16, numCipherSuites)
   257  	for i := 0; i < numCipherSuites; i++ {
   258  		m.cipherSuites[i] = uint16(data[2+2*i])<<8 | uint16(data[3+2*i])
   259  	}
   260  	data = data[2+cipherSuiteLen:]
   261  	if len(data) < 1 {
   262  		return false
   263  	}
   264  	compressionMethodsLen := int(data[0])
   265  	if len(data) < 1+compressionMethodsLen {
   266  		return false
   267  	}
   268  	m.compressionMethods = data[1 : 1+compressionMethodsLen]
   269  
   270  	data = data[1+compressionMethodsLen:]
   271  
   272  	m.nextProtoNeg = false
   273  	m.serverName = ""
   274  	m.ocspStapling = false
   275  	m.ticketSupported = false
   276  	m.sessionTicket = nil
   277  	m.signatureAndHashes = nil
   278  
   279  	if len(data) == 0 {
   280  		// ClientHello is optionally followed by extension data
   281  		return true
   282  	}
   283  	if len(data) < 2 {
   284  		return false
   285  	}
   286  
   287  	extensionsLength := int(data[0])<<8 | int(data[1])
   288  	data = data[2:]
   289  	if extensionsLength != len(data) {
   290  		return false
   291  	}
   292  
   293  	for len(data) != 0 {
   294  		if len(data) < 4 {
   295  			return false
   296  		}
   297  		extension := uint16(data[0])<<8 | uint16(data[1])
   298  		length := int(data[2])<<8 | int(data[3])
   299  		data = data[4:]
   300  		if len(data) < length {
   301  			return false
   302  		}
   303  
   304  		switch extension {
   305  		case extensionServerName:
   306  			if length < 2 {
   307  				return false
   308  			}
   309  			numNames := int(data[0])<<8 | int(data[1])
   310  			d := data[2:]
   311  			for i := 0; i < numNames; i++ {
   312  				if len(d) < 3 {
   313  					return false
   314  				}
   315  				nameType := d[0]
   316  				nameLen := int(d[1])<<8 | int(d[2])
   317  				d = d[3:]
   318  				if len(d) < nameLen {
   319  					return false
   320  				}
   321  				if nameType == 0 {
   322  					m.serverName = string(d[0:nameLen])
   323  					break
   324  				}
   325  				d = d[nameLen:]
   326  			}
   327  		case extensionNextProtoNeg:
   328  			if length > 0 {
   329  				return false
   330  			}
   331  			m.nextProtoNeg = true
   332  		case extensionStatusRequest:
   333  			m.ocspStapling = length > 0 && data[0] == statusTypeOCSP
   334  		case extensionSupportedCurves:
   335  			// http://tools.ietf.org/html/rfc4492#section-5.5.1
   336  			if length < 2 {
   337  				return false
   338  			}
   339  			l := int(data[0])<<8 | int(data[1])
   340  			if l%2 == 1 || length != l+2 {
   341  				return false
   342  			}
   343  			numCurves := l / 2
   344  			m.supportedCurves = make([]uint16, numCurves)
   345  			d := data[2:]
   346  			for i := 0; i < numCurves; i++ {
   347  				m.supportedCurves[i] = uint16(d[0])<<8 | uint16(d[1])
   348  				d = d[2:]
   349  			}
   350  		case extensionSupportedPoints:
   351  			// http://tools.ietf.org/html/rfc4492#section-5.5.2
   352  			if length < 1 {
   353  				return false
   354  			}
   355  			l := int(data[0])
   356  			if length != l+1 {
   357  				return false
   358  			}
   359  			m.supportedPoints = make([]uint8, l)
   360  			copy(m.supportedPoints, data[1:])
   361  		case extensionSessionTicket:
   362  			// http://tools.ietf.org/html/rfc5077#section-3.2
   363  			m.ticketSupported = true
   364  			m.sessionTicket = data[:length]
   365  		case extensionSignatureAlgorithms:
   366  			// https://tools.ietf.org/html/rfc5246#section-7.4.1.4.1
   367  			if length < 2 || length&1 != 0 {
   368  				return false
   369  			}
   370  			l := int(data[0])<<8 | int(data[1])
   371  			if l != length-2 {
   372  				return false
   373  			}
   374  			n := l / 2
   375  			d := data[2:]
   376  			m.signatureAndHashes = make([]signatureAndHash, n)
   377  			for i := range m.signatureAndHashes {
   378  				m.signatureAndHashes[i].hash = d[0]
   379  				m.signatureAndHashes[i].signature = d[1]
   380  				d = d[2:]
   381  			}
   382  		}
   383  		data = data[length:]
   384  	}
   385  
   386  	return true
   387  }
   388  
// serverHelloMsg holds the fields of a ServerHello handshake message.
//
// raw caches the complete wire encoding, 4-byte handshake header included.
// marshal returns raw unchanged when it is non-nil, and unmarshal sets it to
// its input. After unmarshal, random and sessionId are sub-slices of that
// input.
//
// nextProtoNeg (with nextProtos), ocspStapling and ticketSupported each
// correspond to one optional extension. marshal writes the status_request
// and session_ticket extensions with empty data.
type serverHelloMsg struct {
	raw               []byte
	vers              uint16
	random            []byte
	sessionId         []byte
	cipherSuite       uint16
	compressionMethod uint8
	nextProtoNeg      bool
	nextProtos        []string
	ocspStapling      bool
	ticketSupported   bool
}
   401  
   402  func (m *serverHelloMsg) equal(i interface{}) bool {
   403  	m1, ok := i.(*serverHelloMsg)
   404  	if !ok {
   405  		return false
   406  	}
   407  
   408  	return bytes.Equal(m.raw, m1.raw) &&
   409  		m.vers == m1.vers &&
   410  		bytes.Equal(m.random, m1.random) &&
   411  		bytes.Equal(m.sessionId, m1.sessionId) &&
   412  		m.cipherSuite == m1.cipherSuite &&
   413  		m.compressionMethod == m1.compressionMethod &&
   414  		m.nextProtoNeg == m1.nextProtoNeg &&
   415  		eqStrings(m.nextProtos, m1.nextProtos) &&
   416  		m.ocspStapling == m1.ocspStapling &&
   417  		m.ticketSupported == m1.ticketSupported
   418  }
   419  
   420  func (m *serverHelloMsg) marshal() []byte {
   421  	if m.raw != nil {
   422  		return m.raw
   423  	}
   424  
   425  	length := 38 + len(m.sessionId)
   426  	numExtensions := 0
   427  	extensionsLength := 0
   428  
   429  	nextProtoLen := 0
   430  	if m.nextProtoNeg {
   431  		numExtensions++
   432  		for _, v := range m.nextProtos {
   433  			nextProtoLen += len(v)
   434  		}
   435  		nextProtoLen += len(m.nextProtos)
   436  		extensionsLength += nextProtoLen
   437  	}
   438  	if m.ocspStapling {
   439  		numExtensions++
   440  	}
   441  	if m.ticketSupported {
   442  		numExtensions++
   443  	}
   444  	if numExtensions > 0 {
   445  		extensionsLength += 4 * numExtensions
   446  		length += 2 + extensionsLength
   447  	}
   448  
   449  	x := make([]byte, 4+length)
   450  	x[0] = typeServerHello
   451  	x[1] = uint8(length >> 16)
   452  	x[2] = uint8(length >> 8)
   453  	x[3] = uint8(length)
   454  	x[4] = uint8(m.vers >> 8)
   455  	x[5] = uint8(m.vers)
   456  	copy(x[6:38], m.random)
   457  	x[38] = uint8(len(m.sessionId))
   458  	copy(x[39:39+len(m.sessionId)], m.sessionId)
   459  	z := x[39+len(m.sessionId):]
   460  	z[0] = uint8(m.cipherSuite >> 8)
   461  	z[1] = uint8(m.cipherSuite)
   462  	z[2] = uint8(m.compressionMethod)
   463  
   464  	z = z[3:]
   465  	if numExtensions > 0 {
   466  		z[0] = byte(extensionsLength >> 8)
   467  		z[1] = byte(extensionsLength)
   468  		z = z[2:]
   469  	}
   470  	if m.nextProtoNeg {
   471  		z[0] = byte(extensionNextProtoNeg >> 8)
   472  		z[1] = byte(extensionNextProtoNeg)
   473  		z[2] = byte(nextProtoLen >> 8)
   474  		z[3] = byte(nextProtoLen)
   475  		z = z[4:]
   476  
   477  		for _, v := range m.nextProtos {
   478  			l := len(v)
   479  			if l > 255 {
   480  				l = 255
   481  			}
   482  			z[0] = byte(l)
   483  			copy(z[1:], []byte(v[0:l]))
   484  			z = z[1+l:]
   485  		}
   486  	}
   487  	if m.ocspStapling {
   488  		z[0] = byte(extensionStatusRequest >> 8)
   489  		z[1] = byte(extensionStatusRequest)
   490  		z = z[4:]
   491  	}
   492  	if m.ticketSupported {
   493  		z[0] = byte(extensionSessionTicket >> 8)
   494  		z[1] = byte(extensionSessionTicket)
   495  		z = z[4:]
   496  	}
   497  
   498  	m.raw = x
   499  
   500  	return x
   501  }
   502  
// unmarshal parses a complete ServerHello handshake message, including its
// 4-byte handshake header, from data into m. It reports whether the message
// was well formed.
//
// m.raw is set to data, and random and sessionId are sub-slices of data.
// The handshake header's own length bytes (data[1:4]) are not checked here.
//
// Only next_protocol_negotiation, status_request and session_ticket are
// interpreted; other extensions are skipped. status_request and
// session_ticket must carry empty data. Each next_protocol_negotiation entry
// must be a non-empty protocol name with a 1-byte length prefix.
func (m *serverHelloMsg) unmarshal(data []byte) bool {
	// Minimum size: header (4) + version (2) + random (32) + session ID
	// length (1) + cipher suite (2) + compression method (1).
	if len(data) < 42 {
		return false
	}
	m.raw = data
	m.vers = uint16(data[4])<<8 | uint16(data[5])
	m.random = data[6:38]
	sessionIdLen := int(data[38])
	if sessionIdLen > 32 || len(data) < 39+sessionIdLen {
		return false
	}
	m.sessionId = data[39 : 39+sessionIdLen]
	data = data[39+sessionIdLen:]
	if len(data) < 3 {
		return false
	}
	m.cipherSuite = uint16(data[0])<<8 | uint16(data[1])
	m.compressionMethod = data[2]
	data = data[3:]

	// Reset extension-derived fields so that values from an earlier parse
	// do not survive.
	m.nextProtoNeg = false
	m.nextProtos = nil
	m.ocspStapling = false
	m.ticketSupported = false

	if len(data) == 0 {
		// ServerHello is optionally followed by extension data
		return true
	}
	if len(data) < 2 {
		return false
	}

	// The extensions block must fill the rest of the message exactly.
	extensionsLength := int(data[0])<<8 | int(data[1])
	data = data[2:]
	if len(data) != extensionsLength {
		return false
	}

	for len(data) != 0 {
		// Each extension: 2-byte type, 2-byte data length, data.
		if len(data) < 4 {
			return false
		}
		extension := uint16(data[0])<<8 | uint16(data[1])
		length := int(data[2])<<8 | int(data[3])
		data = data[4:]
		if len(data) < length {
			return false
		}

		switch extension {
		case extensionNextProtoNeg:
			m.nextProtoNeg = true
			// Parse only this extension's data: a sequence of non-empty
			// protocol names, each with a 1-byte length prefix.
			d := data[:length]
			for len(d) > 0 {
				l := int(d[0])
				d = d[1:]
				if l == 0 || l > len(d) {
					return false
				}
				m.nextProtos = append(m.nextProtos, string(d[:l]))
				d = d[l:]
			}
		case extensionStatusRequest:
			if length > 0 {
				return false
			}
			m.ocspStapling = true
		case extensionSessionTicket:
			if length > 0 {
				return false
			}
			m.ticketSupported = true
		}
		data = data[length:]
	}

	return true
}
   582  
// certificateMsg is a Certificate handshake message. certificates holds the
// individual certificate entries in message order. On the wire, a 3-byte
// total list length is followed by the entries, each with a 3-byte length
// prefix. The entry contents are not interpreted here.
//
// raw caches the full encoding. marshal returns it unchanged when non-nil,
// and unmarshal sets it to the input; each parsed entry is then a sub-slice
// of that input.
type certificateMsg struct {
	raw          []byte
	certificates [][]byte
}
   587  
   588  func (m *certificateMsg) equal(i interface{}) bool {
   589  	m1, ok := i.(*certificateMsg)
   590  	if !ok {
   591  		return false
   592  	}
   593  
   594  	return bytes.Equal(m.raw, m1.raw) &&
   595  		eqByteSlices(m.certificates, m1.certificates)
   596  }
   597  
   598  func (m *certificateMsg) marshal() (x []byte) {
   599  	if m.raw != nil {
   600  		return m.raw
   601  	}
   602  
   603  	var i int
   604  	for _, slice := range m.certificates {
   605  		i += len(slice)
   606  	}
   607  
   608  	length := 3 + 3*len(m.certificates) + i
   609  	x = make([]byte, 4+length)
   610  	x[0] = typeCertificate
   611  	x[1] = uint8(length >> 16)
   612  	x[2] = uint8(length >> 8)
   613  	x[3] = uint8(length)
   614  
   615  	certificateOctets := length - 3
   616  	x[4] = uint8(certificateOctets >> 16)
   617  	x[5] = uint8(certificateOctets >> 8)
   618  	x[6] = uint8(certificateOctets)
   619  
   620  	y := x[7:]
   621  	for _, slice := range m.certificates {
   622  		y[0] = uint8(len(slice) >> 16)
   623  		y[1] = uint8(len(slice) >> 8)
   624  		y[2] = uint8(len(slice))
   625  		copy(y[3:], slice)
   626  		y = y[3+len(slice):]
   627  	}
   628  
   629  	m.raw = x
   630  	return
   631  }
   632  
// unmarshal parses a complete Certificate handshake message from data into
// m and reports whether it was well formed.
//
// m.raw is set to data once data holds at least 7 bytes, and every element
// of m.certificates is a sub-slice of data. The handshake header's own
// length bytes (data[1:4]) are not read here. Instead, the 3-byte
// certificate list length at data[4:7] must account for every remaining
// byte.
func (m *certificateMsg) unmarshal(data []byte) bool {
	// 4-byte handshake header + 3-byte certificate list length.
	if len(data) < 7 {
		return false
	}

	m.raw = data
	certsLen := uint32(data[4])<<16 | uint32(data[5])<<8 | uint32(data[6])
	if uint32(len(data)) != certsLen+7 {
		return false
	}

	// First pass: validate each length-prefixed entry and count them.
	// len(d) and certsLen start out equal and shrink by the same amount on
	// every iteration, so the unsigned subtraction cannot underflow.
	numCerts := 0
	d := data[7:]
	for certsLen > 0 {
		// Requires at least one byte past the 3-byte length prefix, so an
		// empty entry at the very end of the list is rejected.
		if len(d) < 4 {
			return false
		}
		certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2])
		if uint32(len(d)) < 3+certLen {
			return false
		}
		d = d[3+certLen:]
		certsLen -= 3 + certLen
		numCerts++
	}

	// Second pass: slice out the entries. The bounds were checked above.
	m.certificates = make([][]byte, numCerts)
	d = data[7:]
	for i := 0; i < numCerts; i++ {
		certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2])
		m.certificates[i] = d[3 : 3+certLen]
		d = d[3+certLen:]
	}

	return true
}
   669  
   670  type serverKeyExchangeMsg struct {
   671  	raw []byte
   672  	key []byte
   673  }
   674  
   675  func (m *serverKeyExchangeMsg) equal(i interface{}) bool {
   676  	m1, ok := i.(*serverKeyExchangeMsg)
   677  	if !ok {
   678  		return false
   679  	}
   680  
   681  	return bytes.Equal(m.raw, m1.raw) &&
   682  		bytes.Equal(m.key, m1.key)
   683  }
   684  
   685  func (m *serverKeyExchangeMsg) marshal() []byte {
   686  	if m.raw != nil {
   687  		return m.raw
   688  	}
   689  	length := len(m.key)
   690  	x := make([]byte, length+4)
   691  	x[0] = typeServerKeyExchange
   692  	x[1] = uint8(length >> 16)
   693  	x[2] = uint8(length >> 8)
   694  	x[3] = uint8(length)
   695  	copy(x[4:], m.key)
   696  
   697  	m.raw = x
   698  	return x
   699  }
   700  
   701  func (m *serverKeyExchangeMsg) unmarshal(data []byte) bool {
   702  	m.raw = data
   703  	if len(data) < 4 {
   704  		return false
   705  	}
   706  	m.key = data[4:]
   707  	return true
   708  }
   709  
   710  type certificateStatusMsg struct {
   711  	raw        []byte
   712  	statusType uint8
   713  	response   []byte
   714  }
   715  
   716  func (m *certificateStatusMsg) equal(i interface{}) bool {
   717  	m1, ok := i.(*certificateStatusMsg)
   718  	if !ok {
   719  		return false
   720  	}
   721  
   722  	return bytes.Equal(m.raw, m1.raw) &&
   723  		m.statusType == m1.statusType &&
   724  		bytes.Equal(m.response, m1.response)
   725  }
   726  
   727  func (m *certificateStatusMsg) marshal() []byte {
   728  	if m.raw != nil {
   729  		return m.raw
   730  	}
   731  
   732  	var x []byte
   733  	if m.statusType == statusTypeOCSP {
   734  		x = make([]byte, 4+4+len(m.response))
   735  		x[0] = typeCertificateStatus
   736  		l := len(m.response) + 4
   737  		x[1] = byte(l >> 16)
   738  		x[2] = byte(l >> 8)
   739  		x[3] = byte(l)
   740  		x[4] = statusTypeOCSP
   741  
   742  		l -= 4
   743  		x[5] = byte(l >> 16)
   744  		x[6] = byte(l >> 8)
   745  		x[7] = byte(l)
   746  		copy(x[8:], m.response)
   747  	} else {
   748  		x = []byte{typeCertificateStatus, 0, 0, 1, m.statusType}
   749  	}
   750  
   751  	m.raw = x
   752  	return x
   753  }
   754  
   755  func (m *certificateStatusMsg) unmarshal(data []byte) bool {
   756  	m.raw = data
   757  	if len(data) < 5 {
   758  		return false
   759  	}
   760  	m.statusType = data[4]
   761  
   762  	m.response = nil
   763  	if m.statusType == statusTypeOCSP {
   764  		if len(data) < 8 {
   765  			return false
   766  		}
   767  		respLen := uint32(data[5])<<16 | uint32(data[6])<<8 | uint32(data[7])
   768  		if uint32(len(data)) != 4+4+respLen {
   769  			return false
   770  		}
   771  		m.response = data[8:]
   772  	}
   773  	return true
   774  }
   775  
   776  type serverHelloDoneMsg struct{}
   777  
   778  func (m *serverHelloDoneMsg) equal(i interface{}) bool {
   779  	_, ok := i.(*serverHelloDoneMsg)
   780  	return ok
   781  }
   782  
   783  func (m *serverHelloDoneMsg) marshal() []byte {
   784  	x := make([]byte, 4)
   785  	x[0] = typeServerHelloDone
   786  	return x
   787  }
   788  
   789  func (m *serverHelloDoneMsg) unmarshal(data []byte) bool {
   790  	return len(data) == 4
   791  }
   792  
   793  type clientKeyExchangeMsg struct {
   794  	raw        []byte
   795  	ciphertext []byte
   796  }
   797  
   798  func (m *clientKeyExchangeMsg) equal(i interface{}) bool {
   799  	m1, ok := i.(*clientKeyExchangeMsg)
   800  	if !ok {
   801  		return false
   802  	}
   803  
   804  	return bytes.Equal(m.raw, m1.raw) &&
   805  		bytes.Equal(m.ciphertext, m1.ciphertext)
   806  }
   807  
   808  func (m *clientKeyExchangeMsg) marshal() []byte {
   809  	if m.raw != nil {
   810  		return m.raw
   811  	}
   812  	length := len(m.ciphertext)
   813  	x := make([]byte, length+4)
   814  	x[0] = typeClientKeyExchange
   815  	x[1] = uint8(length >> 16)
   816  	x[2] = uint8(length >> 8)
   817  	x[3] = uint8(length)
   818  	copy(x[4:], m.ciphertext)
   819  
   820  	m.raw = x
   821  	return x
   822  }
   823  
   824  func (m *clientKeyExchangeMsg) unmarshal(data []byte) bool {
   825  	m.raw = data
   826  	if len(data) < 4 {
   827  		return false
   828  	}
   829  	l := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
   830  	if l != len(data)-4 {
   831  		return false
   832  	}
   833  	m.ciphertext = data[4:]
   834  	return true
   835  }
   836  
   837  type finishedMsg struct {
   838  	raw        []byte
   839  	verifyData []byte
   840  }
   841  
   842  func (m *finishedMsg) equal(i interface{}) bool {
   843  	m1, ok := i.(*finishedMsg)
   844  	if !ok {
   845  		return false
   846  	}
   847  
   848  	return bytes.Equal(m.raw, m1.raw) &&
   849  		bytes.Equal(m.verifyData, m1.verifyData)
   850  }
   851  
   852  func (m *finishedMsg) marshal() (x []byte) {
   853  	if m.raw != nil {
   854  		return m.raw
   855  	}
   856  
   857  	x = make([]byte, 4+len(m.verifyData))
   858  	x[0] = typeFinished
   859  	x[3] = byte(len(m.verifyData))
   860  	copy(x[4:], m.verifyData)
   861  	m.raw = x
   862  	return
   863  }
   864  
   865  func (m *finishedMsg) unmarshal(data []byte) bool {
   866  	m.raw = data
   867  	if len(data) < 4 {
   868  		return false
   869  	}
   870  	m.verifyData = data[4:]
   871  	return true
   872  }
   873  
   874  type nextProtoMsg struct {
   875  	raw   []byte
   876  	proto string
   877  }
   878  
   879  func (m *nextProtoMsg) equal(i interface{}) bool {
   880  	m1, ok := i.(*nextProtoMsg)
   881  	if !ok {
   882  		return false
   883  	}
   884  
   885  	return bytes.Equal(m.raw, m1.raw) &&
   886  		m.proto == m1.proto
   887  }
   888  
   889  func (m *nextProtoMsg) marshal() []byte {
   890  	if m.raw != nil {
   891  		return m.raw
   892  	}
   893  	l := len(m.proto)
   894  	if l > 255 {
   895  		l = 255
   896  	}
   897  
   898  	padding := 32 - (l+2)%32
   899  	length := l + padding + 2
   900  	x := make([]byte, length+4)
   901  	x[0] = typeNextProtocol
   902  	x[1] = uint8(length >> 16)
   903  	x[2] = uint8(length >> 8)
   904  	x[3] = uint8(length)
   905  
   906  	y := x[4:]
   907  	y[0] = byte(l)
   908  	copy(y[1:], []byte(m.proto[0:l]))
   909  	y = y[1+l:]
   910  	y[0] = byte(padding)
   911  
   912  	m.raw = x
   913  
   914  	return x
   915  }
   916  
   917  func (m *nextProtoMsg) unmarshal(data []byte) bool {
   918  	m.raw = data
   919  
   920  	if len(data) < 5 {
   921  		return false
   922  	}
   923  	data = data[4:]
   924  	protoLen := int(data[0])
   925  	data = data[1:]
   926  	if len(data) < protoLen {
   927  		return false
   928  	}
   929  	m.proto = string(data[0:protoLen])
   930  	data = data[protoLen:]
   931  
   932  	if len(data) < 1 {
   933  		return false
   934  	}
   935  	paddingLen := int(data[0])
   936  	data = data[1:]
   937  	if len(data) != paddingLen {
   938  		return false
   939  	}
   940  
   941  	return true
   942  }
   943  
// certificateRequestMsg is a CertificateRequest handshake message. marshal
// follows RFC 4346, section 7.4.4, plus an optional signature/hash algorithm
// list.
//
// hasSignatureAndHash is not encoded on its own. marshal and unmarshal both
// read it to decide whether that algorithm list is written or expected, so
// it must be set before calling unmarshal. raw caches the full encoding:
// marshal returns it unchanged when non-nil, and unmarshal sets it to the
// input.
type certificateRequestMsg struct {
	raw []byte
	// hasSignatureAndHash indicates whether this message includes a list
	// of signature and hash functions. This change was introduced with TLS
	// 1.2.
	hasSignatureAndHash bool

	certificateTypes       []byte
	signatureAndHashes     []signatureAndHash
	certificateAuthorities [][]byte
}
   955  
   956  func (m *certificateRequestMsg) equal(i interface{}) bool {
   957  	m1, ok := i.(*certificateRequestMsg)
   958  	if !ok {
   959  		return false
   960  	}
   961  
   962  	return bytes.Equal(m.raw, m1.raw) &&
   963  		bytes.Equal(m.certificateTypes, m1.certificateTypes) &&
   964  		eqByteSlices(m.certificateAuthorities, m1.certificateAuthorities) &&
   965  		eqSignatureAndHashes(m.signatureAndHashes, m1.signatureAndHashes)
   966  }
   967  
   968  func (m *certificateRequestMsg) marshal() (x []byte) {
   969  	if m.raw != nil {
   970  		return m.raw
   971  	}
   972  
   973  	// See http://tools.ietf.org/html/rfc4346#section-7.4.4
   974  	length := 1 + len(m.certificateTypes) + 2
   975  	casLength := 0
   976  	for _, ca := range m.certificateAuthorities {
   977  		casLength += 2 + len(ca)
   978  	}
   979  	length += casLength
   980  
   981  	if m.hasSignatureAndHash {
   982  		length += 2 + 2*len(m.signatureAndHashes)
   983  	}
   984  
   985  	x = make([]byte, 4+length)
   986  	x[0] = typeCertificateRequest
   987  	x[1] = uint8(length >> 16)
   988  	x[2] = uint8(length >> 8)
   989  	x[3] = uint8(length)
   990  
   991  	x[4] = uint8(len(m.certificateTypes))
   992  
   993  	copy(x[5:], m.certificateTypes)
   994  	y := x[5+len(m.certificateTypes):]
   995  
   996  	if m.hasSignatureAndHash {
   997  		n := len(m.signatureAndHashes) * 2
   998  		y[0] = uint8(n >> 8)
   999  		y[1] = uint8(n)
  1000  		y = y[2:]
  1001  		for _, sigAndHash := range m.signatureAndHashes {
  1002  			y[0] = sigAndHash.hash
  1003  			y[1] = sigAndHash.signature
  1004  			y = y[2:]
  1005  		}
  1006  	}
  1007  
  1008  	y[0] = uint8(casLength >> 8)
  1009  	y[1] = uint8(casLength)
  1010  	y = y[2:]
  1011  	for _, ca := range m.certificateAuthorities {
  1012  		y[0] = uint8(len(ca) >> 8)
  1013  		y[1] = uint8(len(ca))
  1014  		y = y[2:]
  1015  		copy(y, ca)
  1016  		y = y[len(ca):]
  1017  	}
  1018  
  1019  	m.raw = x
  1020  	return
  1021  }
  1022  
// unmarshal parses a complete CertificateRequest handshake message from data
// into m and reports whether it was well formed. m.raw is set to data even
// on failure.
//
// m.hasSignatureAndHash is an input, not an output. It must be set before
// calling, to say whether a signature/hash algorithm list follows the
// certificate types. When it is false, m.signatureAndHashes is left
// untouched. certificateTypes and certificateAuthorities are backed by
// freshly allocated buffers, not by data.
func (m *certificateRequestMsg) unmarshal(data []byte) bool {
	m.raw = data

	// 4-byte handshake header + certificate types count (1).
	if len(data) < 5 {
		return false
	}

	// The 24-bit handshake length must match the bytes after the header.
	length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3])
	if uint32(len(data))-4 != length {
		return false
	}

	// At least one certificate type is required, and more bytes must follow
	// the types.
	numCertTypes := int(data[4])
	data = data[5:]
	if numCertTypes == 0 || len(data) <= numCertTypes {
		return false
	}

	m.certificateTypes = make([]byte, numCertTypes)
	// Given the length check above, this comparison cannot fail.
	if copy(m.certificateTypes, data) != numCertTypes {
		return false
	}

	data = data[numCertTypes:]

	if m.hasSignatureAndHash {
		// 2-byte byte length of the (hash, signature) pairs; it must be even.
		if len(data) < 2 {
			return false
		}
		sigAndHashLen := uint16(data[0])<<8 | uint16(data[1])
		data = data[2:]
		if sigAndHashLen&1 != 0 {
			return false
		}
		if len(data) < int(sigAndHashLen) {
			return false
		}
		numSigAndHash := sigAndHashLen / 2
		m.signatureAndHashes = make([]signatureAndHash, numSigAndHash)
		for i := range m.signatureAndHashes {
			m.signatureAndHashes[i].hash = data[0]
			m.signatureAndHashes[i].signature = data[1]
			data = data[2:]
		}
	}

	// 2-byte total length of the certificate authorities list. The list is
	// copied into cas, so the parsed entries below are sub-slices of that
	// copy rather than of data.
	if len(data) < 2 {
		return false
	}
	casLength := uint16(data[0])<<8 | uint16(data[1])
	data = data[2:]
	if len(data) < int(casLength) {
		return false
	}
	cas := make([]byte, casLength)
	copy(cas, data)
	data = data[casLength:]

	// Each entry has a 2-byte length prefix. The contents are not validated
	// here; presumably they are distinguished names (TODO confirm).
	m.certificateAuthorities = nil
	for len(cas) > 0 {
		if len(cas) < 2 {
			return false
		}
		caLen := uint16(cas[0])<<8 | uint16(cas[1])
		cas = cas[2:]

		if len(cas) < int(caLen) {
			return false
		}

		m.certificateAuthorities = append(m.certificateAuthorities, cas[:caLen])
		cas = cas[caLen:]
	}
	// Nothing may follow the certificate authorities list.
	if len(data) > 0 {
		return false
	}

	return true
}
  1102  
// certificateVerifyMsg is a CertificateVerify handshake message
// (see http://tools.ietf.org/html/rfc4346#section-7.4.8).
type certificateVerifyMsg struct {
	// raw caches the full wire encoding, including the 4-byte handshake
	// header; marshal returns it unchanged when it is non-nil.
	raw                 []byte
	// hasSignatureAndHash selects whether the encoding carries a leading
	// (hash, signature) byte pair. unmarshal reads this flag but never sets
	// it, so it must be configured before parsing.
	hasSignatureAndHash bool
	// signatureAndHash is the algorithm pair; only meaningful when
	// hasSignatureAndHash is true.
	signatureAndHash    signatureAndHash
	// signature holds the signature bytes. After unmarshal it aliases the
	// input buffer rather than being a copy.
	signature           []byte
}
  1109  
  1110  func (m *certificateVerifyMsg) equal(i interface{}) bool {
  1111  	m1, ok := i.(*certificateVerifyMsg)
  1112  	if !ok {
  1113  		return false
  1114  	}
  1115  
  1116  	return bytes.Equal(m.raw, m1.raw) &&
  1117  		m.hasSignatureAndHash == m1.hasSignatureAndHash &&
  1118  		m.signatureAndHash.hash == m1.signatureAndHash.hash &&
  1119  		m.signatureAndHash.signature == m1.signatureAndHash.signature &&
  1120  		bytes.Equal(m.signature, m1.signature)
  1121  }
  1122  
  1123  func (m *certificateVerifyMsg) marshal() (x []byte) {
  1124  	if m.raw != nil {
  1125  		return m.raw
  1126  	}
  1127  
  1128  	// See http://tools.ietf.org/html/rfc4346#section-7.4.8
  1129  	siglength := len(m.signature)
  1130  	length := 2 + siglength
  1131  	if m.hasSignatureAndHash {
  1132  		length += 2
  1133  	}
  1134  	x = make([]byte, 4+length)
  1135  	x[0] = typeCertificateVerify
  1136  	x[1] = uint8(length >> 16)
  1137  	x[2] = uint8(length >> 8)
  1138  	x[3] = uint8(length)
  1139  	y := x[4:]
  1140  	if m.hasSignatureAndHash {
  1141  		y[0] = m.signatureAndHash.hash
  1142  		y[1] = m.signatureAndHash.signature
  1143  		y = y[2:]
  1144  	}
  1145  	y[0] = uint8(siglength >> 8)
  1146  	y[1] = uint8(siglength)
  1147  	copy(y[2:], m.signature)
  1148  
  1149  	m.raw = x
  1150  
  1151  	return
  1152  }
  1153  
  1154  func (m *certificateVerifyMsg) unmarshal(data []byte) bool {
  1155  	m.raw = data
  1156  
  1157  	if len(data) < 6 {
  1158  		return false
  1159  	}
  1160  
  1161  	length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3])
  1162  	if uint32(len(data))-4 != length {
  1163  		return false
  1164  	}
  1165  
  1166  	data = data[4:]
  1167  	if m.hasSignatureAndHash {
  1168  		m.signatureAndHash.hash = data[0]
  1169  		m.signatureAndHash.signature = data[1]
  1170  		data = data[2:]
  1171  	}
  1172  
  1173  	if len(data) < 2 {
  1174  		return false
  1175  	}
  1176  	siglength := int(data[0])<<8 + int(data[1])
  1177  	data = data[2:]
  1178  	if len(data) != siglength {
  1179  		return false
  1180  	}
  1181  
  1182  	m.signature = data
  1183  
  1184  	return true
  1185  }
  1186  
// newSessionTicketMsg is a NewSessionTicket handshake message
// (see http://tools.ietf.org/html/rfc5077#section-3.3).
type newSessionTicketMsg struct {
	// raw caches the full wire encoding, including the 4-byte handshake
	// header; marshal returns it unchanged when it is non-nil.
	raw    []byte
	// ticket is the opaque session ticket. After unmarshal it aliases the
	// input buffer rather than being a copy.
	ticket []byte
}
  1191  
  1192  func (m *newSessionTicketMsg) equal(i interface{}) bool {
  1193  	m1, ok := i.(*newSessionTicketMsg)
  1194  	if !ok {
  1195  		return false
  1196  	}
  1197  
  1198  	return bytes.Equal(m.raw, m1.raw) &&
  1199  		bytes.Equal(m.ticket, m1.ticket)
  1200  }
  1201  
  1202  func (m *newSessionTicketMsg) marshal() (x []byte) {
  1203  	if m.raw != nil {
  1204  		return m.raw
  1205  	}
  1206  
  1207  	// See http://tools.ietf.org/html/rfc5077#section-3.3
  1208  	ticketLen := len(m.ticket)
  1209  	length := 2 + 4 + ticketLen
  1210  	x = make([]byte, 4+length)
  1211  	x[0] = typeNewSessionTicket
  1212  	x[1] = uint8(length >> 16)
  1213  	x[2] = uint8(length >> 8)
  1214  	x[3] = uint8(length)
  1215  	x[8] = uint8(ticketLen >> 8)
  1216  	x[9] = uint8(ticketLen)
  1217  	copy(x[10:], m.ticket)
  1218  
  1219  	m.raw = x
  1220  
  1221  	return
  1222  }
  1223  
  1224  func (m *newSessionTicketMsg) unmarshal(data []byte) bool {
  1225  	m.raw = data
  1226  
  1227  	if len(data) < 10 {
  1228  		return false
  1229  	}
  1230  
  1231  	length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3])
  1232  	if uint32(len(data))-4 != length {
  1233  		return false
  1234  	}
  1235  
  1236  	ticketLen := int(data[8])<<8 + int(data[9])
  1237  	if len(data)-10 != ticketLen {
  1238  		return false
  1239  	}
  1240  
  1241  	m.ticket = data[10:]
  1242  
  1243  	return true
  1244  }
  1245  
  1246  func eqUint16s(x, y []uint16) bool {
  1247  	if len(x) != len(y) {
  1248  		return false
  1249  	}
  1250  	for i, v := range x {
  1251  		if y[i] != v {
  1252  			return false
  1253  		}
  1254  	}
  1255  	return true
  1256  }
  1257  
  1258  func eqStrings(x, y []string) bool {
  1259  	if len(x) != len(y) {
  1260  		return false
  1261  	}
  1262  	for i, v := range x {
  1263  		if y[i] != v {
  1264  			return false
  1265  		}
  1266  	}
  1267  	return true
  1268  }
  1269  
  1270  func eqByteSlices(x, y [][]byte) bool {
  1271  	if len(x) != len(y) {
  1272  		return false
  1273  	}
  1274  	for i, v := range x {
  1275  		if !bytes.Equal(v, y[i]) {
  1276  			return false
  1277  		}
  1278  	}
  1279  	return true
  1280  }
  1281  
  1282  func eqSignatureAndHashes(x, y []signatureAndHash) bool {
  1283  	if len(x) != len(y) {
  1284  		return false
  1285  	}
  1286  	for i, v := range x {
  1287  		v2 := y[i]
  1288  		if v.hash != v2.hash || v.signature != v2.signature {
  1289  			return false
  1290  		}
  1291  	}
  1292  	return true
  1293  }