gitee.com/ks-custle/core-gm@v0.0.0-20230922171213-b83bdd97b62c/gmtls/handshake_messages.go (about)

     1  // Copyright (c) 2022 zhaochun
     2  // core-gm is licensed under Mulan PSL v2.
     3  // You can use this software according to the terms and conditions of the Mulan PSL v2.
     4  // You may obtain a copy of Mulan PSL v2 at:
     5  //          http://license.coscl.org.cn/MulanPSL2
     6  // THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
     7  // See the Mulan PSL v2 for more details.
     8  
     9  /*
    10  gmtls是基于`golang/go`的`tls`包实现的国密改造版本。
    11  对应版权声明: thrid_licenses/github.com/golang/go/LICENSE
    12  */
    13  
    14  package gmtls
    15  
    16  import (
    17  	"fmt"
    18  	"strings"
    19  
    20  	"golang.org/x/crypto/cryptobyte"
    21  )
    22  
    23  // The marshalingFunction type is an adapter to allow the use of ordinary
    24  // functions as cryptobyte.MarshalingValue.
    25  type marshalingFunction func(b *cryptobyte.Builder) error
    26  
    27  func (f marshalingFunction) Marshal(b *cryptobyte.Builder) error {
    28  	return f(b)
    29  }
    30  
    31  // addBytesWithLength appends a sequence of bytes to the cryptobyte.Builder. If
    32  // the length of the sequence is not the value specified, it produces an error.
    33  func addBytesWithLength(b *cryptobyte.Builder, v []byte, n int) {
    34  	b.AddValue(marshalingFunction(func(b *cryptobyte.Builder) error {
    35  		if len(v) != n {
    36  			return fmt.Errorf("invalid value length: expected %d, got %d", n, len(v))
    37  		}
    38  		b.AddBytes(v)
    39  		return nil
    40  	}))
    41  }
    42  
    43  // addUint64 appends a big-endian, 64-bit value to the cryptobyte.Builder.
    44  func addUint64(b *cryptobyte.Builder, v uint64) {
    45  	b.AddUint32(uint32(v >> 32))
    46  	b.AddUint32(uint32(v))
    47  }
    48  
    49  // readUint64 decodes a big-endian, 64-bit value into out and advances over it.
    50  // It reports whether the read was successful.
    51  func readUint64(s *cryptobyte.String, out *uint64) bool {
    52  	var hi, lo uint32
    53  	if !s.ReadUint32(&hi) || !s.ReadUint32(&lo) {
    54  		return false
    55  	}
    56  	*out = uint64(hi)<<32 | uint64(lo)
    57  	return true
    58  }
    59  
    60  // readUint8LengthPrefixed acts like s.ReadUint8LengthPrefixed, but targets a
    61  // []byte instead of a cryptobyte.String.
    62  func readUint8LengthPrefixed(s *cryptobyte.String, out *[]byte) bool {
    63  	return s.ReadUint8LengthPrefixed((*cryptobyte.String)(out))
    64  }
    65  
    66  // readUint16LengthPrefixed acts like s.ReadUint16LengthPrefixed, but targets a
    67  // []byte instead of a cryptobyte.String.
    68  func readUint16LengthPrefixed(s *cryptobyte.String, out *[]byte) bool {
    69  	return s.ReadUint16LengthPrefixed((*cryptobyte.String)(out))
    70  }
    71  
    72  // readUint24LengthPrefixed acts like s.ReadUint24LengthPrefixed, but targets a
    73  // []byte instead of a cryptobyte.String.
    74  func readUint24LengthPrefixed(s *cryptobyte.String, out *[]byte) bool {
    75  	return s.ReadUint24LengthPrefixed((*cryptobyte.String)(out))
    76  }
    77  
// clientHelloMsg is the in-memory form of a ClientHello handshake message.
// marshal encodes it to the wire format and unmarshal parses it back.
type clientHelloMsg struct {
	raw                              []byte            // serialized message; set by unmarshal and cached by marshal
	vers                             uint16            // highest offered protocol version (legacy_version); used by TLS 1.2 and earlier
	random                           []byte            // ClientRandom, exactly 32 bytes; in TLS 1.2 and earlier one of the inputs to master-secret derivation
	sessionId                        []byte            // session ID
	cipherSuites                     []uint16          // offered cipher suites
	compressionMethods               []uint8           // offered compression methods
	serverName                       string            // requested server name (server_name / SNI)
	ocspStapling                     bool              // whether OCSP stapling is requested (status_request with status_type ocsp)
	supportedCurves                  []CurveID         // supported elliptic curves / groups
	supportedPoints                  []uint8           // supported EC point formats
	ticketSupported                  bool              // whether the session_ticket extension is present
	sessionTicket                    []uint8           // session ticket carried in the session_ticket extension
	supportedSignatureAlgorithms     []SignatureScheme // signature_algorithms extension
	supportedSignatureAlgorithmsCert []SignatureScheme // signature_algorithms_cert extension (RFC 8446, Section 4.2.3)
	secureRenegotiationSupported     bool              // whether secure renegotiation is supported (extension or SCSV)
	secureRenegotiation              []byte            // renegotiation_info payload; per the original author, on a non-initial handshake this carries the client's Finished data from the previous handshake
	alpnProtocols                    []string          // offered ALPN protocols
	scts                             bool              // signed certificate timestamps from server
	supportedVersions                []uint16          // supported_versions extension (TLS 1.3 and later)
	cookie                           []byte            // cookie extension
	keyShares                        []keyShare        // key_share entries
	earlyData                        bool              // whether the early_data extension is present
	pskModes                         []uint8           // psk_key_exchange_modes (pre-shared key modes)
	pskIdentities                    []pskIdentity     // pre_shared_key identities
	pskBinders                       [][]byte          // pre_shared_key binders
}
   105  
// marshal serializes m as a ClientHello handshake message: a one-byte message
// type (typeClientHello), a uint24 body length, then the body.
//
// The encoding is cached in m.raw. If m.raw is already non-nil it is returned
// as-is and m's other fields are not re-encoded. marshal panics (via
// BytesOrPanic) if encoding fails, for example when m.random is not exactly
// 32 bytes.
func (m *clientHelloMsg) marshal() []byte {
	if m.raw != nil {
		return m.raw
	}

	var b cryptobyte.Builder
	b.AddUint8(typeClientHello)
	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
		b.AddUint16(m.vers)
		addBytesWithLength(b, m.random, 32)
		b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
			b.AddBytes(m.sessionId)
		})
		b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
			for _, suite := range m.cipherSuites {
				b.AddUint16(suite)
			}
		})
		b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
			b.AddBytes(m.compressionMethods)
		})

		// If extensions aren't present, omit them.
		// Snapshot the builder before the extensions block so that an empty
		// block (just its length prefix) can be rolled back below.
		var extensionsPresent bool
		bWithoutExtensions := *b

		b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
			if len(m.serverName) > 0 {
				// RFC 6066, Section 3
				b.AddUint16(extensionServerName)
				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
						b.AddUint8(0) // name_type = host_name
						b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
							b.AddBytes([]byte(m.serverName))
						})
					})
				})
			}
			if m.ocspStapling {
				// RFC 4366, Section 3.6
				b.AddUint16(extensionStatusRequest)
				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
					b.AddUint8(1)  // status_type = ocsp
					b.AddUint16(0) // empty responder_id_list
					b.AddUint16(0) // empty request_extensions
				})
			}
			if len(m.supportedCurves) > 0 {
				// RFC 4492, sections 5.1.1 and RFC 8446, Section 4.2.7
				b.AddUint16(extensionSupportedCurves)
				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
						for _, curve := range m.supportedCurves {
							b.AddUint16(uint16(curve))
						}
					})
				})
			}
			if len(m.supportedPoints) > 0 {
				// RFC 4492, Section 5.1.2
				b.AddUint16(extensionSupportedPoints)
				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
					b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
						b.AddBytes(m.supportedPoints)
					})
				})
			}
			if m.ticketSupported {
				// RFC 5077, Section 3.2
				// The ticket bytes (possibly empty) form the whole extension_data.
				b.AddUint16(extensionSessionTicket)
				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
					b.AddBytes(m.sessionTicket)
				})
			}
			if len(m.supportedSignatureAlgorithms) > 0 {
				// RFC 5246, Section 7.4.1.4.1
				b.AddUint16(extensionSignatureAlgorithms)
				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
						for _, sigAlgo := range m.supportedSignatureAlgorithms {
							b.AddUint16(uint16(sigAlgo))
						}
					})
				})
			}
			if len(m.supportedSignatureAlgorithmsCert) > 0 {
				// RFC 8446, Section 4.2.3
				b.AddUint16(extensionSignatureAlgorithmsCert)
				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
						for _, sigAlgo := range m.supportedSignatureAlgorithmsCert {
							b.AddUint16(uint16(sigAlgo))
						}
					})
				})
			}
			if m.secureRenegotiationSupported {
				// RFC 5746, Section 3.2
				b.AddUint16(extensionRenegotiationInfo)
				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
					b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
						b.AddBytes(m.secureRenegotiation)
					})
				})
			}
			if len(m.alpnProtocols) > 0 {
				// RFC 7301, Section 3.1
				b.AddUint16(extensionALPN)
				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
						for _, proto := range m.alpnProtocols {
							b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
								b.AddBytes([]byte(proto))
							})
						}
					})
				})
			}
			if m.scts {
				// RFC 6962, Section 3.3.1
				b.AddUint16(extensionSCT)
				b.AddUint16(0) // empty extension_data
			}
			if len(m.supportedVersions) > 0 {
				// RFC 8446, Section 4.2.1
				b.AddUint16(extensionSupportedVersions)
				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
					b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
						for _, vers := range m.supportedVersions {
							b.AddUint16(vers)
						}
					})
				})
			}
			if len(m.cookie) > 0 {
				// RFC 8446, Section 4.2.2
				b.AddUint16(extensionCookie)
				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
						b.AddBytes(m.cookie)
					})
				})
			}
			if len(m.keyShares) > 0 {
				// RFC 8446, Section 4.2.8
				b.AddUint16(extensionKeyShare)
				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
						for _, ks := range m.keyShares {
							b.AddUint16(uint16(ks.group))
							b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
								b.AddBytes(ks.data)
							})
						}
					})
				})
			}
			if m.earlyData {
				// RFC 8446, Section 4.2.10
				b.AddUint16(extensionEarlyData)
				b.AddUint16(0) // empty extension_data
			}
			if len(m.pskModes) > 0 {
				// RFC 8446, Section 4.2.9
				b.AddUint16(extensionPSKModes)
				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
					b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
						b.AddBytes(m.pskModes)
					})
				})
			}
			if len(m.pskIdentities) > 0 { // pre_shared_key must be the last extension
				// RFC 8446, Section 4.2.11
				// Because this extension is emitted last, the binders list is
				// the final bytes of the message. marshalWithoutBinders and
				// updateBinders rely on that layout.
				b.AddUint16(extensionPreSharedKey)
				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
						for _, psk := range m.pskIdentities {
							b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
								b.AddBytes(psk.label)
							})
							b.AddUint32(psk.obfuscatedTicketAge)
						}
					})
					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
						for _, binder := range m.pskBinders {
							b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
								b.AddBytes(binder)
							})
						}
					})
				})
			}

			// This child builder's bytes begin with its own 2-byte length
			// prefix (cryptobyte semantics). Anything longer means at least
			// one extension was written.
			extensionsPresent = len(b.BytesOrPanic()) > 2
		})

		// No extensions: restore the snapshot so that not even the empty
		// uint16 length prefix is emitted.
		if !extensionsPresent {
			*b = bWithoutExtensions
		}
	})

	m.raw = b.BytesOrPanic()
	return m.raw
}
   311  
   312  // marshalWithoutBinders returns the ClientHello through the
   313  // PreSharedKeyExtension.identities field, according to RFC 8446, Section
   314  // 4.2.11.2. Note that m.pskBinders must be set to slices of the correct length.
   315  func (m *clientHelloMsg) marshalWithoutBinders() []byte {
   316  	bindersLen := 2 // uint16 length prefix
   317  	for _, binder := range m.pskBinders {
   318  		bindersLen += 1 // uint8 length prefix
   319  		bindersLen += len(binder)
   320  	}
   321  
   322  	fullMessage := m.marshal()
   323  	return fullMessage[:len(fullMessage)-bindersLen]
   324  }
   325  
   326  // updateBinders updates the m.pskBinders field, if necessary updating the
   327  // cached marshaled representation. The supplied binders must have the same
   328  // length as the current m.pskBinders.
   329  func (m *clientHelloMsg) updateBinders(pskBinders [][]byte) {
   330  	if len(pskBinders) != len(m.pskBinders) {
   331  		panic("gmtls: internal error: pskBinders length mismatch")
   332  	}
   333  	for i := range m.pskBinders {
   334  		if len(pskBinders[i]) != len(m.pskBinders[i]) {
   335  			panic("gmtls: internal error: pskBinders length mismatch")
   336  		}
   337  	}
   338  	m.pskBinders = pskBinders
   339  	if m.raw != nil {
   340  		lenWithoutBinders := len(m.marshalWithoutBinders())
   341  		// TODO(filippo): replace with NewFixedBuilder once CL 148882 is imported.
   342  		b := cryptobyte.NewBuilder(m.raw[:lenWithoutBinders])
   343  		b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   344  			for _, binder := range m.pskBinders {
   345  				b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
   346  					b.AddBytes(binder)
   347  				})
   348  			}
   349  		})
   350  		if len(b.BytesOrPanic()) != len(m.raw) {
   351  			panic("gmtls: internal error: failed to update binders")
   352  		}
   353  	}
   354  }
   355  
   356  func (m *clientHelloMsg) unmarshal(data []byte) bool {
   357  	*m = clientHelloMsg{raw: data}
   358  	s := cryptobyte.String(data)
   359  
   360  	if !s.Skip(4) || // message type and uint24 length field
   361  		!s.ReadUint16(&m.vers) || !s.ReadBytes(&m.random, 32) ||
   362  		!readUint8LengthPrefixed(&s, &m.sessionId) {
   363  		return false
   364  	}
   365  
   366  	var cipherSuites cryptobyte.String
   367  	if !s.ReadUint16LengthPrefixed(&cipherSuites) {
   368  		return false
   369  	}
   370  	m.cipherSuites = []uint16{}
   371  	m.secureRenegotiationSupported = false
   372  	for !cipherSuites.Empty() {
   373  		var suite uint16
   374  		if !cipherSuites.ReadUint16(&suite) {
   375  			return false
   376  		}
   377  		if suite == scsvRenegotiation {
   378  			m.secureRenegotiationSupported = true
   379  		}
   380  		m.cipherSuites = append(m.cipherSuites, suite)
   381  	}
   382  
   383  	if !readUint8LengthPrefixed(&s, &m.compressionMethods) {
   384  		return false
   385  	}
   386  
   387  	if s.Empty() {
   388  		// ClientHello is optionally followed by extension data
   389  		return true
   390  	}
   391  
   392  	var extensions cryptobyte.String
   393  	if !s.ReadUint16LengthPrefixed(&extensions) || !s.Empty() {
   394  		return false
   395  	}
   396  
   397  	for !extensions.Empty() {
   398  		var extension uint16
   399  		var extData cryptobyte.String
   400  		if !extensions.ReadUint16(&extension) ||
   401  			!extensions.ReadUint16LengthPrefixed(&extData) {
   402  			return false
   403  		}
   404  
   405  		switch extension {
   406  		case extensionServerName:
   407  			// RFC 6066, Section 3
   408  			var nameList cryptobyte.String
   409  			if !extData.ReadUint16LengthPrefixed(&nameList) || nameList.Empty() {
   410  				return false
   411  			}
   412  			for !nameList.Empty() {
   413  				var nameType uint8
   414  				var serverName cryptobyte.String
   415  				if !nameList.ReadUint8(&nameType) ||
   416  					!nameList.ReadUint16LengthPrefixed(&serverName) ||
   417  					serverName.Empty() {
   418  					return false
   419  				}
   420  				if nameType != 0 {
   421  					continue
   422  				}
   423  				if len(m.serverName) != 0 {
   424  					// Multiple names of the same name_type are prohibited.
   425  					return false
   426  				}
   427  				m.serverName = string(serverName)
   428  				// An SNI value may not include a trailing dot.
   429  				if strings.HasSuffix(m.serverName, ".") {
   430  					return false
   431  				}
   432  			}
   433  		case extensionStatusRequest:
   434  			// RFC 4366, Section 3.6
   435  			var statusType uint8
   436  			var ignored cryptobyte.String
   437  			if !extData.ReadUint8(&statusType) ||
   438  				!extData.ReadUint16LengthPrefixed(&ignored) ||
   439  				!extData.ReadUint16LengthPrefixed(&ignored) {
   440  				return false
   441  			}
   442  			m.ocspStapling = statusType == statusTypeOCSP
   443  		case extensionSupportedCurves:
   444  			// RFC 4492, sections 5.1.1 and RFC 8446, Section 4.2.7
   445  			var curves cryptobyte.String
   446  			if !extData.ReadUint16LengthPrefixed(&curves) || curves.Empty() {
   447  				return false
   448  			}
   449  			for !curves.Empty() {
   450  				var curve uint16
   451  				if !curves.ReadUint16(&curve) {
   452  					return false
   453  				}
   454  				m.supportedCurves = append(m.supportedCurves, CurveID(curve))
   455  			}
   456  		case extensionSupportedPoints:
   457  			// RFC 4492, Section 5.1.2
   458  			if !readUint8LengthPrefixed(&extData, &m.supportedPoints) ||
   459  				len(m.supportedPoints) == 0 {
   460  				return false
   461  			}
   462  		case extensionSessionTicket:
   463  			// RFC 5077, Section 3.2
   464  			m.ticketSupported = true
   465  			extData.ReadBytes(&m.sessionTicket, len(extData))
   466  		case extensionSignatureAlgorithms:
   467  			// RFC 5246, Section 7.4.1.4.1
   468  			var sigAndAlgs cryptobyte.String
   469  			if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() {
   470  				return false
   471  			}
   472  			for !sigAndAlgs.Empty() {
   473  				var sigAndAlg uint16
   474  				if !sigAndAlgs.ReadUint16(&sigAndAlg) {
   475  					return false
   476  				}
   477  				m.supportedSignatureAlgorithms = append(
   478  					m.supportedSignatureAlgorithms, SignatureScheme(sigAndAlg))
   479  			}
   480  		case extensionSignatureAlgorithmsCert:
   481  			// RFC 8446, Section 4.2.3
   482  			var sigAndAlgs cryptobyte.String
   483  			if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() {
   484  				return false
   485  			}
   486  			for !sigAndAlgs.Empty() {
   487  				var sigAndAlg uint16
   488  				if !sigAndAlgs.ReadUint16(&sigAndAlg) {
   489  					return false
   490  				}
   491  				m.supportedSignatureAlgorithmsCert = append(
   492  					m.supportedSignatureAlgorithmsCert, SignatureScheme(sigAndAlg))
   493  			}
   494  		case extensionRenegotiationInfo:
   495  			// RFC 5746, Section 3.2
   496  			if !readUint8LengthPrefixed(&extData, &m.secureRenegotiation) {
   497  				return false
   498  			}
   499  			m.secureRenegotiationSupported = true
   500  		case extensionALPN:
   501  			// RFC 7301, Section 3.1
   502  			var protoList cryptobyte.String
   503  			if !extData.ReadUint16LengthPrefixed(&protoList) || protoList.Empty() {
   504  				return false
   505  			}
   506  			for !protoList.Empty() {
   507  				var proto cryptobyte.String
   508  				if !protoList.ReadUint8LengthPrefixed(&proto) || proto.Empty() {
   509  					return false
   510  				}
   511  				m.alpnProtocols = append(m.alpnProtocols, string(proto))
   512  			}
   513  		case extensionSCT:
   514  			// RFC 6962, Section 3.3.1
   515  			m.scts = true
   516  		case extensionSupportedVersions:
   517  			// RFC 8446, Section 4.2.1
   518  			var versList cryptobyte.String
   519  			if !extData.ReadUint8LengthPrefixed(&versList) || versList.Empty() {
   520  				return false
   521  			}
   522  			for !versList.Empty() {
   523  				var vers uint16
   524  				if !versList.ReadUint16(&vers) {
   525  					return false
   526  				}
   527  				m.supportedVersions = append(m.supportedVersions, vers)
   528  			}
   529  		case extensionCookie:
   530  			// RFC 8446, Section 4.2.2
   531  			if !readUint16LengthPrefixed(&extData, &m.cookie) ||
   532  				len(m.cookie) == 0 {
   533  				return false
   534  			}
   535  		case extensionKeyShare:
   536  			// RFC 8446, Section 4.2.8
   537  			var clientShares cryptobyte.String
   538  			if !extData.ReadUint16LengthPrefixed(&clientShares) {
   539  				return false
   540  			}
   541  			for !clientShares.Empty() {
   542  				var ks keyShare
   543  				if !clientShares.ReadUint16((*uint16)(&ks.group)) ||
   544  					!readUint16LengthPrefixed(&clientShares, &ks.data) ||
   545  					len(ks.data) == 0 {
   546  					return false
   547  				}
   548  				m.keyShares = append(m.keyShares, ks)
   549  			}
   550  		case extensionEarlyData:
   551  			// RFC 8446, Section 4.2.10
   552  			m.earlyData = true
   553  		case extensionPSKModes:
   554  			// RFC 8446, Section 4.2.9
   555  			if !readUint8LengthPrefixed(&extData, &m.pskModes) {
   556  				return false
   557  			}
   558  		case extensionPreSharedKey:
   559  			// RFC 8446, Section 4.2.11
   560  			if !extensions.Empty() {
   561  				return false // pre_shared_key must be the last extension
   562  			}
   563  			var identities cryptobyte.String
   564  			if !extData.ReadUint16LengthPrefixed(&identities) || identities.Empty() {
   565  				return false
   566  			}
   567  			for !identities.Empty() {
   568  				var psk pskIdentity
   569  				if !readUint16LengthPrefixed(&identities, &psk.label) ||
   570  					!identities.ReadUint32(&psk.obfuscatedTicketAge) ||
   571  					len(psk.label) == 0 {
   572  					return false
   573  				}
   574  				m.pskIdentities = append(m.pskIdentities, psk)
   575  			}
   576  			var binders cryptobyte.String
   577  			if !extData.ReadUint16LengthPrefixed(&binders) || binders.Empty() {
   578  				return false
   579  			}
   580  			for !binders.Empty() {
   581  				var binder []byte
   582  				if !readUint8LengthPrefixed(&binders, &binder) ||
   583  					len(binder) == 0 {
   584  					return false
   585  				}
   586  				m.pskBinders = append(m.pskBinders, binder)
   587  			}
   588  		default:
   589  			// Ignore unknown extensions.
   590  			continue
   591  		}
   592  
   593  		if !extData.Empty() {
   594  			return false
   595  		}
   596  	}
   597  
   598  	return true
   599  }
   600  
// serverHelloMsg is the in-memory form of a ServerHello handshake message.
// marshal and unmarshal also handle HelloRetryRequest-specific fields (cookie,
// selectedGroup) through the same structure.
type serverHelloMsg struct {
	raw                          []byte   // serialized message; set by unmarshal and cached by marshal
	vers                         uint16   // negotiated protocol version; used by TLS 1.2 and earlier
	random                       []byte   // ServerRandom, exactly 32 bytes; in TLS 1.2 and earlier an input to master-secret derivation, in TLS 1.3 used to identify a HelloRetryRequest
	sessionId                    []byte   // session ID
	cipherSuite                  uint16   // negotiated cipher suite
	compressionMethod            uint8    // selected compression method
	ocspStapling                 bool     // whether the (empty) status_request extension is present
	ticketSupported              bool     // whether the (empty) session_ticket extension is present
	secureRenegotiationSupported bool     // whether secure renegotiation is supported
	secureRenegotiation          []byte   // renegotiation_info payload
	alpnProtocol                 string   // negotiated ALPN protocol
	scts                         [][]byte // signed certificate timestamps from server
	supportedVersion             uint16   // selected version from supported_versions (TLS 1.3 and later)
	serverShare                  keyShare // server's key_share entry
	selectedIdentityPresent      bool     // whether a pre_shared_key identity was selected to resume a session
	selectedIdentity             uint16   // selected pre_shared_key identity for session resumption
	supportedPoints              []uint8  // supported EC point formats
	cookie                       []byte   // cookie; HelloRetryRequest extension
	selectedGroup                CurveID  // selected group; HelloRetryRequest form of key_share
}
   622  
// marshal serializes m as a ServerHello handshake message: a one-byte message
// type (typeServerHello), a uint24 body length, then the body.
//
// The encoding is cached in m.raw. If m.raw is already non-nil it is returned
// as-is. marshal panics (via BytesOrPanic) if encoding fails, for example when
// m.random is not exactly 32 bytes. Both serverShare and selectedGroup are
// encoded as a key_share extension, so callers are expected to set at most
// one of them.
func (m *serverHelloMsg) marshal() []byte {
	if m.raw != nil {
		return m.raw
	}

	var b cryptobyte.Builder
	b.AddUint8(typeServerHello)
	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
		b.AddUint16(m.vers)
		addBytesWithLength(b, m.random, 32)
		b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
			b.AddBytes(m.sessionId)
		})
		b.AddUint16(m.cipherSuite)
		b.AddUint8(m.compressionMethod)

		// If extensions aren't present, omit them.
		// Snapshot the builder before the extensions block so that an empty
		// block (just its length prefix) can be rolled back below.
		var extensionsPresent bool
		bWithoutExtensions := *b

		b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
			if m.ocspStapling {
				b.AddUint16(extensionStatusRequest)
				b.AddUint16(0) // empty extension_data
			}
			if m.ticketSupported {
				b.AddUint16(extensionSessionTicket)
				b.AddUint16(0) // empty extension_data
			}
			if m.secureRenegotiationSupported {
				b.AddUint16(extensionRenegotiationInfo)
				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
					b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
						b.AddBytes(m.secureRenegotiation)
					})
				})
			}
			if len(m.alpnProtocol) > 0 {
				// A protocol_name_list containing exactly one protocol.
				b.AddUint16(extensionALPN)
				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
						b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
							b.AddBytes([]byte(m.alpnProtocol))
						})
					})
				})
			}
			if len(m.scts) > 0 {
				b.AddUint16(extensionSCT)
				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
						for _, sct := range m.scts {
							b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
								b.AddBytes(sct)
							})
						}
					})
				})
			}
			if m.supportedVersion != 0 {
				// ServerHello form: a single selected version, not a list.
				b.AddUint16(extensionSupportedVersions)
				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
					b.AddUint16(m.supportedVersion)
				})
			}
			if m.serverShare.group != 0 {
				// ServerHello form of key_share: group plus key exchange data.
				b.AddUint16(extensionKeyShare)
				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
					b.AddUint16(uint16(m.serverShare.group))
					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
						b.AddBytes(m.serverShare.data)
					})
				})
			}
			if m.selectedIdentityPresent {
				b.AddUint16(extensionPreSharedKey)
				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
					b.AddUint16(m.selectedIdentity)
				})
			}

			if len(m.cookie) > 0 {
				b.AddUint16(extensionCookie)
				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
						b.AddBytes(m.cookie)
					})
				})
			}
			if m.selectedGroup != 0 {
				// HelloRetryRequest form of key_share: only the selected group.
				b.AddUint16(extensionKeyShare)
				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
					b.AddUint16(uint16(m.selectedGroup))
				})
			}
			if len(m.supportedPoints) > 0 {
				b.AddUint16(extensionSupportedPoints)
				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
					b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
						b.AddBytes(m.supportedPoints)
					})
				})
			}

			// This child builder's bytes begin with its own 2-byte length
			// prefix (cryptobyte semantics). Anything longer means at least
			// one extension was written.
			extensionsPresent = len(b.BytesOrPanic()) > 2
		})

		// No extensions: restore the snapshot so that not even the empty
		// uint16 length prefix is emitted.
		if !extensionsPresent {
			*b = bWithoutExtensions
		}
	})

	m.raw = b.BytesOrPanic()
	return m.raw
}
   738  
   739  func (m *serverHelloMsg) unmarshal(data []byte) bool {
   740  	*m = serverHelloMsg{raw: data}
   741  	s := cryptobyte.String(data)
   742  
   743  	if !s.Skip(4) || // message type and uint24 length field
   744  		!s.ReadUint16(&m.vers) || !s.ReadBytes(&m.random, 32) ||
   745  		!readUint8LengthPrefixed(&s, &m.sessionId) ||
   746  		!s.ReadUint16(&m.cipherSuite) ||
   747  		!s.ReadUint8(&m.compressionMethod) {
   748  		return false
   749  	}
   750  
   751  	if s.Empty() {
   752  		// ServerHello is optionally followed by extension data
   753  		return true
   754  	}
   755  
   756  	var extensions cryptobyte.String
   757  	if !s.ReadUint16LengthPrefixed(&extensions) || !s.Empty() {
   758  		return false
   759  	}
   760  
   761  	for !extensions.Empty() {
   762  		var extension uint16
   763  		var extData cryptobyte.String
   764  		if !extensions.ReadUint16(&extension) ||
   765  			!extensions.ReadUint16LengthPrefixed(&extData) {
   766  			return false
   767  		}
   768  
   769  		switch extension {
   770  		case extensionStatusRequest:
   771  			m.ocspStapling = true
   772  		case extensionSessionTicket:
   773  			m.ticketSupported = true
   774  		case extensionRenegotiationInfo:
   775  			if !readUint8LengthPrefixed(&extData, &m.secureRenegotiation) {
   776  				return false
   777  			}
   778  			m.secureRenegotiationSupported = true
   779  		case extensionALPN:
   780  			var protoList cryptobyte.String
   781  			if !extData.ReadUint16LengthPrefixed(&protoList) || protoList.Empty() {
   782  				return false
   783  			}
   784  			var proto cryptobyte.String
   785  			if !protoList.ReadUint8LengthPrefixed(&proto) ||
   786  				proto.Empty() || !protoList.Empty() {
   787  				return false
   788  			}
   789  			m.alpnProtocol = string(proto)
   790  		case extensionSCT:
   791  			var sctList cryptobyte.String
   792  			if !extData.ReadUint16LengthPrefixed(&sctList) || sctList.Empty() {
   793  				return false
   794  			}
   795  			for !sctList.Empty() {
   796  				var sct []byte
   797  				if !readUint16LengthPrefixed(&sctList, &sct) ||
   798  					len(sct) == 0 {
   799  					return false
   800  				}
   801  				m.scts = append(m.scts, sct)
   802  			}
   803  		case extensionSupportedVersions:
   804  			if !extData.ReadUint16(&m.supportedVersion) {
   805  				return false
   806  			}
   807  		case extensionCookie:
   808  			if !readUint16LengthPrefixed(&extData, &m.cookie) ||
   809  				len(m.cookie) == 0 {
   810  				return false
   811  			}
   812  		case extensionKeyShare:
   813  			// This extension has different formats in SH and HRR, accept either
   814  			// and let the handshake logic decide. See RFC 8446, Section 4.2.8.
   815  			if len(extData) == 2 {
   816  				if !extData.ReadUint16((*uint16)(&m.selectedGroup)) {
   817  					return false
   818  				}
   819  			} else {
   820  				if !extData.ReadUint16((*uint16)(&m.serverShare.group)) ||
   821  					!readUint16LengthPrefixed(&extData, &m.serverShare.data) {
   822  					return false
   823  				}
   824  			}
   825  		case extensionPreSharedKey:
   826  			m.selectedIdentityPresent = true
   827  			if !extData.ReadUint16(&m.selectedIdentity) {
   828  				return false
   829  			}
   830  		case extensionSupportedPoints:
   831  			// RFC 4492, Section 5.1.2
   832  			if !readUint8LengthPrefixed(&extData, &m.supportedPoints) ||
   833  				len(m.supportedPoints) == 0 {
   834  				return false
   835  			}
   836  		default:
   837  			// Ignore unknown extensions.
   838  			continue
   839  		}
   840  
   841  		if !extData.Empty() {
   842  			return false
   843  		}
   844  	}
   845  
   846  	return true
   847  }
   848  
   849  type encryptedExtensionsMsg struct {
   850  	raw          []byte
   851  	alpnProtocol string
   852  }
   853  
   854  func (m *encryptedExtensionsMsg) marshal() []byte {
   855  	if m.raw != nil {
   856  		return m.raw
   857  	}
   858  
   859  	var b cryptobyte.Builder
   860  	b.AddUint8(typeEncryptedExtensions)
   861  	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
   862  		b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   863  			if len(m.alpnProtocol) > 0 {
   864  				b.AddUint16(extensionALPN)
   865  				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   866  					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   867  						b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
   868  							b.AddBytes([]byte(m.alpnProtocol))
   869  						})
   870  					})
   871  				})
   872  			}
   873  		})
   874  	})
   875  
   876  	m.raw = b.BytesOrPanic()
   877  	return m.raw
   878  }
   879  
   880  func (m *encryptedExtensionsMsg) unmarshal(data []byte) bool {
   881  	*m = encryptedExtensionsMsg{raw: data}
   882  	s := cryptobyte.String(data)
   883  
   884  	var extensions cryptobyte.String
   885  	if !s.Skip(4) || // message type and uint24 length field
   886  		!s.ReadUint16LengthPrefixed(&extensions) || !s.Empty() {
   887  		return false
   888  	}
   889  
   890  	for !extensions.Empty() {
   891  		var extension uint16
   892  		var extData cryptobyte.String
   893  		if !extensions.ReadUint16(&extension) ||
   894  			!extensions.ReadUint16LengthPrefixed(&extData) {
   895  			return false
   896  		}
   897  
   898  		switch extension {
   899  		case extensionALPN:
   900  			var protoList cryptobyte.String
   901  			if !extData.ReadUint16LengthPrefixed(&protoList) || protoList.Empty() {
   902  				return false
   903  			}
   904  			var proto cryptobyte.String
   905  			if !protoList.ReadUint8LengthPrefixed(&proto) ||
   906  				proto.Empty() || !protoList.Empty() {
   907  				return false
   908  			}
   909  			m.alpnProtocol = string(proto)
   910  		default:
   911  			// Ignore unknown extensions.
   912  			continue
   913  		}
   914  
   915  		if !extData.Empty() {
   916  			return false
   917  		}
   918  	}
   919  
   920  	return true
   921  }
   922  
   923  type endOfEarlyDataMsg struct{}
   924  
   925  func (m *endOfEarlyDataMsg) marshal() []byte {
   926  	x := make([]byte, 4)
   927  	x[0] = typeEndOfEarlyData
   928  	return x
   929  }
   930  
   931  func (m *endOfEarlyDataMsg) unmarshal(data []byte) bool {
   932  	return len(data) == 4
   933  }
   934  
   935  type keyUpdateMsg struct {
   936  	raw             []byte
   937  	updateRequested bool
   938  }
   939  
   940  func (m *keyUpdateMsg) marshal() []byte {
   941  	if m.raw != nil {
   942  		return m.raw
   943  	}
   944  
   945  	var b cryptobyte.Builder
   946  	b.AddUint8(typeKeyUpdate)
   947  	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
   948  		if m.updateRequested {
   949  			b.AddUint8(1)
   950  		} else {
   951  			b.AddUint8(0)
   952  		}
   953  	})
   954  
   955  	m.raw = b.BytesOrPanic()
   956  	return m.raw
   957  }
   958  
   959  func (m *keyUpdateMsg) unmarshal(data []byte) bool {
   960  	m.raw = data
   961  	s := cryptobyte.String(data)
   962  
   963  	var updateRequested uint8
   964  	if !s.Skip(4) || // message type and uint24 length field
   965  		!s.ReadUint8(&updateRequested) || !s.Empty() {
   966  		return false
   967  	}
   968  	switch updateRequested {
   969  	case 0:
   970  		m.updateRequested = false
   971  	case 1:
   972  		m.updateRequested = true
   973  	default:
   974  		return false
   975  	}
   976  	return true
   977  }
   978  
   979  type newSessionTicketMsgTLS13 struct {
   980  	raw          []byte
   981  	lifetime     uint32
   982  	ageAdd       uint32
   983  	nonce        []byte
   984  	label        []byte
   985  	maxEarlyData uint32
   986  }
   987  
   988  func (m *newSessionTicketMsgTLS13) marshal() []byte {
   989  	if m.raw != nil {
   990  		return m.raw
   991  	}
   992  
   993  	var b cryptobyte.Builder
   994  	b.AddUint8(typeNewSessionTicket)
   995  	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
   996  		b.AddUint32(m.lifetime)
   997  		b.AddUint32(m.ageAdd)
   998  		b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
   999  			b.AddBytes(m.nonce)
  1000  		})
  1001  		b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1002  			b.AddBytes(m.label)
  1003  		})
  1004  
  1005  		b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1006  			if m.maxEarlyData > 0 {
  1007  				b.AddUint16(extensionEarlyData)
  1008  				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1009  					b.AddUint32(m.maxEarlyData)
  1010  				})
  1011  			}
  1012  		})
  1013  	})
  1014  
  1015  	m.raw = b.BytesOrPanic()
  1016  	return m.raw
  1017  }
  1018  
  1019  func (m *newSessionTicketMsgTLS13) unmarshal(data []byte) bool {
  1020  	*m = newSessionTicketMsgTLS13{raw: data}
  1021  	s := cryptobyte.String(data)
  1022  
  1023  	var extensions cryptobyte.String
  1024  	if !s.Skip(4) || // message type and uint24 length field
  1025  		!s.ReadUint32(&m.lifetime) ||
  1026  		!s.ReadUint32(&m.ageAdd) ||
  1027  		!readUint8LengthPrefixed(&s, &m.nonce) ||
  1028  		!readUint16LengthPrefixed(&s, &m.label) ||
  1029  		!s.ReadUint16LengthPrefixed(&extensions) ||
  1030  		!s.Empty() {
  1031  		return false
  1032  	}
  1033  
  1034  	for !extensions.Empty() {
  1035  		var extension uint16
  1036  		var extData cryptobyte.String
  1037  		if !extensions.ReadUint16(&extension) ||
  1038  			!extensions.ReadUint16LengthPrefixed(&extData) {
  1039  			return false
  1040  		}
  1041  
  1042  		switch extension {
  1043  		case extensionEarlyData:
  1044  			if !extData.ReadUint32(&m.maxEarlyData) {
  1045  				return false
  1046  			}
  1047  		default:
  1048  			// Ignore unknown extensions.
  1049  			continue
  1050  		}
  1051  
  1052  		if !extData.Empty() {
  1053  			return false
  1054  		}
  1055  	}
  1056  
  1057  	return true
  1058  }
  1059  
  1060  type certificateRequestMsgTLS13 struct {
  1061  	raw                              []byte
  1062  	ocspStapling                     bool
  1063  	scts                             bool
  1064  	supportedSignatureAlgorithms     []SignatureScheme
  1065  	supportedSignatureAlgorithmsCert []SignatureScheme
  1066  	certificateAuthorities           [][]byte
  1067  }
  1068  
  1069  func (m *certificateRequestMsgTLS13) marshal() []byte {
  1070  	if m.raw != nil {
  1071  		return m.raw
  1072  	}
  1073  
  1074  	var b cryptobyte.Builder
  1075  	b.AddUint8(typeCertificateRequest)
  1076  	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
  1077  		// certificate_request_context (SHALL be zero length unless used for
  1078  		// post-handshake authentication)
  1079  		b.AddUint8(0)
  1080  
  1081  		b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1082  			if m.ocspStapling {
  1083  				b.AddUint16(extensionStatusRequest)
  1084  				b.AddUint16(0) // empty extension_data
  1085  			}
  1086  			if m.scts {
  1087  				// RFC 8446, Section 4.4.2.1 makes no mention of
  1088  				// signed_certificate_timestamp in CertificateRequest, but
  1089  				// "Extensions in the Certificate message from the client MUST
  1090  				// correspond to extensions in the CertificateRequest message
  1091  				// from the server." and it appears in the table in Section 4.2.
  1092  				b.AddUint16(extensionSCT)
  1093  				b.AddUint16(0) // empty extension_data
  1094  			}
  1095  			if len(m.supportedSignatureAlgorithms) > 0 {
  1096  				b.AddUint16(extensionSignatureAlgorithms)
  1097  				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1098  					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1099  						for _, sigAlgo := range m.supportedSignatureAlgorithms {
  1100  							b.AddUint16(uint16(sigAlgo))
  1101  						}
  1102  					})
  1103  				})
  1104  			}
  1105  			if len(m.supportedSignatureAlgorithmsCert) > 0 {
  1106  				b.AddUint16(extensionSignatureAlgorithmsCert)
  1107  				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1108  					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1109  						for _, sigAlgo := range m.supportedSignatureAlgorithmsCert {
  1110  							b.AddUint16(uint16(sigAlgo))
  1111  						}
  1112  					})
  1113  				})
  1114  			}
  1115  			if len(m.certificateAuthorities) > 0 {
  1116  				b.AddUint16(extensionCertificateAuthorities)
  1117  				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1118  					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1119  						for _, ca := range m.certificateAuthorities {
  1120  							b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1121  								b.AddBytes(ca)
  1122  							})
  1123  						}
  1124  					})
  1125  				})
  1126  			}
  1127  		})
  1128  	})
  1129  
  1130  	m.raw = b.BytesOrPanic()
  1131  	return m.raw
  1132  }
  1133  
  1134  func (m *certificateRequestMsgTLS13) unmarshal(data []byte) bool {
  1135  	*m = certificateRequestMsgTLS13{raw: data}
  1136  	s := cryptobyte.String(data)
  1137  
  1138  	var context, extensions cryptobyte.String
  1139  	if !s.Skip(4) || // message type and uint24 length field
  1140  		!s.ReadUint8LengthPrefixed(&context) || !context.Empty() ||
  1141  		!s.ReadUint16LengthPrefixed(&extensions) ||
  1142  		!s.Empty() {
  1143  		return false
  1144  	}
  1145  
  1146  	for !extensions.Empty() {
  1147  		var extension uint16
  1148  		var extData cryptobyte.String
  1149  		if !extensions.ReadUint16(&extension) ||
  1150  			!extensions.ReadUint16LengthPrefixed(&extData) {
  1151  			return false
  1152  		}
  1153  
  1154  		switch extension {
  1155  		case extensionStatusRequest:
  1156  			m.ocspStapling = true
  1157  		case extensionSCT:
  1158  			m.scts = true
  1159  		case extensionSignatureAlgorithms:
  1160  			var sigAndAlgs cryptobyte.String
  1161  			if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() {
  1162  				return false
  1163  			}
  1164  			for !sigAndAlgs.Empty() {
  1165  				var sigAndAlg uint16
  1166  				if !sigAndAlgs.ReadUint16(&sigAndAlg) {
  1167  					return false
  1168  				}
  1169  				m.supportedSignatureAlgorithms = append(
  1170  					m.supportedSignatureAlgorithms, SignatureScheme(sigAndAlg))
  1171  			}
  1172  		case extensionSignatureAlgorithmsCert:
  1173  			var sigAndAlgs cryptobyte.String
  1174  			if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() {
  1175  				return false
  1176  			}
  1177  			for !sigAndAlgs.Empty() {
  1178  				var sigAndAlg uint16
  1179  				if !sigAndAlgs.ReadUint16(&sigAndAlg) {
  1180  					return false
  1181  				}
  1182  				m.supportedSignatureAlgorithmsCert = append(
  1183  					m.supportedSignatureAlgorithmsCert, SignatureScheme(sigAndAlg))
  1184  			}
  1185  		case extensionCertificateAuthorities:
  1186  			var auths cryptobyte.String
  1187  			if !extData.ReadUint16LengthPrefixed(&auths) || auths.Empty() {
  1188  				return false
  1189  			}
  1190  			for !auths.Empty() {
  1191  				var ca []byte
  1192  				if !readUint16LengthPrefixed(&auths, &ca) || len(ca) == 0 {
  1193  					return false
  1194  				}
  1195  				m.certificateAuthorities = append(m.certificateAuthorities, ca)
  1196  			}
  1197  		default:
  1198  			// Ignore unknown extensions.
  1199  			continue
  1200  		}
  1201  
  1202  		if !extData.Empty() {
  1203  			return false
  1204  		}
  1205  	}
  1206  
  1207  	return true
  1208  }
  1209  
  1210  type certificateMsg struct {
  1211  	raw          []byte
  1212  	certificates [][]byte
  1213  }
  1214  
  1215  func (m *certificateMsg) marshal() (x []byte) {
  1216  	if m.raw != nil {
  1217  		return m.raw
  1218  	}
  1219  
  1220  	var i int
  1221  	for _, slice := range m.certificates {
  1222  		i += len(slice)
  1223  	}
  1224  
  1225  	length := 3 + 3*len(m.certificates) + i
  1226  	x = make([]byte, 4+length)
  1227  	x[0] = typeCertificate
  1228  	x[1] = uint8(length >> 16)
  1229  	x[2] = uint8(length >> 8)
  1230  	x[3] = uint8(length)
  1231  
  1232  	certificateOctets := length - 3
  1233  	x[4] = uint8(certificateOctets >> 16)
  1234  	x[5] = uint8(certificateOctets >> 8)
  1235  	x[6] = uint8(certificateOctets)
  1236  
  1237  	y := x[7:]
  1238  	for _, slice := range m.certificates {
  1239  		y[0] = uint8(len(slice) >> 16)
  1240  		y[1] = uint8(len(slice) >> 8)
  1241  		y[2] = uint8(len(slice))
  1242  		copy(y[3:], slice)
  1243  		y = y[3+len(slice):]
  1244  	}
  1245  
  1246  	m.raw = x
  1247  	return
  1248  }
  1249  
  1250  func (m *certificateMsg) unmarshal(data []byte) bool {
  1251  	if len(data) < 7 {
  1252  		return false
  1253  	}
  1254  
  1255  	m.raw = data
  1256  	certsLen := uint32(data[4])<<16 | uint32(data[5])<<8 | uint32(data[6])
  1257  	if uint32(len(data)) != certsLen+7 {
  1258  		return false
  1259  	}
  1260  
  1261  	numCerts := 0
  1262  	d := data[7:]
  1263  	for certsLen > 0 {
  1264  		if len(d) < 4 {
  1265  			return false
  1266  		}
  1267  		certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2])
  1268  		if uint32(len(d)) < 3+certLen {
  1269  			return false
  1270  		}
  1271  		d = d[3+certLen:]
  1272  		certsLen -= 3 + certLen
  1273  		numCerts++
  1274  	}
  1275  
  1276  	m.certificates = make([][]byte, numCerts)
  1277  	d = data[7:]
  1278  	for i := 0; i < numCerts; i++ {
  1279  		certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2])
  1280  		m.certificates[i] = d[3 : 3+certLen]
  1281  		d = d[3+certLen:]
  1282  	}
  1283  
  1284  	return true
  1285  }
  1286  
  1287  type certificateMsgTLS13 struct {
  1288  	raw          []byte
  1289  	certificate  Certificate
  1290  	ocspStapling bool
  1291  	scts         bool
  1292  }
  1293  
  1294  func (m *certificateMsgTLS13) marshal() []byte {
  1295  	if m.raw != nil {
  1296  		return m.raw
  1297  	}
  1298  
  1299  	var b cryptobyte.Builder
  1300  	b.AddUint8(typeCertificate)
  1301  	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
  1302  		b.AddUint8(0) // certificate_request_context
  1303  
  1304  		certificate := m.certificate
  1305  		if !m.ocspStapling {
  1306  			certificate.OCSPStaple = nil
  1307  		}
  1308  		if !m.scts {
  1309  			certificate.SignedCertificateTimestamps = nil
  1310  		}
  1311  		marshalCertificate(b, certificate)
  1312  	})
  1313  
  1314  	m.raw = b.BytesOrPanic()
  1315  	return m.raw
  1316  }
  1317  
  1318  func marshalCertificate(b *cryptobyte.Builder, certificate Certificate) {
  1319  	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
  1320  		for i, cert := range certificate.Certificate {
  1321  			b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
  1322  				b.AddBytes(cert)
  1323  			})
  1324  			b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1325  				if i > 0 {
  1326  					// This library only supports OCSP and SCT for leaf certificates.
  1327  					return
  1328  				}
  1329  				if certificate.OCSPStaple != nil {
  1330  					b.AddUint16(extensionStatusRequest)
  1331  					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1332  						b.AddUint8(statusTypeOCSP)
  1333  						b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
  1334  							b.AddBytes(certificate.OCSPStaple)
  1335  						})
  1336  					})
  1337  				}
  1338  				if certificate.SignedCertificateTimestamps != nil {
  1339  					b.AddUint16(extensionSCT)
  1340  					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1341  						b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1342  							for _, sct := range certificate.SignedCertificateTimestamps {
  1343  								b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1344  									b.AddBytes(sct)
  1345  								})
  1346  							}
  1347  						})
  1348  					})
  1349  				}
  1350  			})
  1351  		}
  1352  	})
  1353  }
  1354  
  1355  func (m *certificateMsgTLS13) unmarshal(data []byte) bool {
  1356  	*m = certificateMsgTLS13{raw: data}
  1357  	s := cryptobyte.String(data)
  1358  
  1359  	var context cryptobyte.String
  1360  	if !s.Skip(4) || // message type and uint24 length field
  1361  		!s.ReadUint8LengthPrefixed(&context) || !context.Empty() ||
  1362  		!unmarshalCertificate(&s, &m.certificate) ||
  1363  		!s.Empty() {
  1364  		return false
  1365  	}
  1366  
  1367  	m.scts = m.certificate.SignedCertificateTimestamps != nil
  1368  	m.ocspStapling = m.certificate.OCSPStaple != nil
  1369  
  1370  	return true
  1371  }
  1372  
// unmarshalCertificate reads one uint24-length-prefixed TLS 1.3
// certificate_list from s and appends each entry's cert_data to
// certificate.Certificate. It does not require s to be exhausted afterwards;
// that check is left to the caller.
//
// Every entry is a uint24-length-prefixed cert_data followed by a
// uint16-length-prefixed extensions block. Extensions are interpreted only
// while len(certificate.Certificate) == 1. That means "the first entry of the
// list" only if certificate starts out with no certificates (NOTE(review):
// callers appear to pass a freshly zeroed Certificate; confirm).
//
// For that entry, status_request must carry statusTypeOCSP and a non-empty
// response, which is stored in certificate.OCSPStaple.
// signed_certificate_timestamp must carry a non-empty list of non-empty SCTs,
// which are appended to certificate.SignedCertificateTimestamps. Unknown
// extensions, and every extension on later entries, are skipped without
// inspecting their bodies.
//
// An empty certificate_list is accepted, and zero-length cert_data is not
// rejected here. The function returns false on any framing or validation
// error; certificate may already be partially filled in that case.
func unmarshalCertificate(s *cryptobyte.String, certificate *Certificate) bool {
	var certList cryptobyte.String
	if !s.ReadUint24LengthPrefixed(&certList) {
		return false
	}
	for !certList.Empty() {
		var cert []byte
		var extensions cryptobyte.String
		if !readUint24LengthPrefixed(&certList, &cert) ||
			!certList.ReadUint16LengthPrefixed(&extensions) {
			return false
		}
		certificate.Certificate = append(certificate.Certificate, cert)
		for !extensions.Empty() {
			var extension uint16
			var extData cryptobyte.String
			if !extensions.ReadUint16(&extension) ||
				!extensions.ReadUint16LengthPrefixed(&extData) {
				return false
			}
			if len(certificate.Certificate) > 1 {
				// This library only supports OCSP and SCT for leaf certificates.
				// Non-leaf extension bodies are skipped entirely, including
				// the trailing-bytes check at the bottom of this loop.
				continue
			}

			switch extension {
			case extensionStatusRequest:
				var statusType uint8
				if !extData.ReadUint8(&statusType) || statusType != statusTypeOCSP ||
					!readUint24LengthPrefixed(&extData, &certificate.OCSPStaple) ||
					len(certificate.OCSPStaple) == 0 {
					return false
				}
			case extensionSCT:
				var sctList cryptobyte.String
				if !extData.ReadUint16LengthPrefixed(&sctList) || sctList.Empty() {
					return false
				}
				for !sctList.Empty() {
					var sct []byte
					if !readUint16LengthPrefixed(&sctList, &sct) ||
						len(sct) == 0 {
						return false
					}
					certificate.SignedCertificateTimestamps = append(
						certificate.SignedCertificateTimestamps, sct)
				}
			default:
				// Ignore unknown extensions (this also bypasses the
				// trailing-bytes check below).
				continue
			}

			// A recognised extension must be consumed exactly.
			if !extData.Empty() {
				return false
			}
		}
	}
	return true
}
  1432  
  1433  type serverKeyExchangeMsg struct {
  1434  	raw []byte
  1435  	key []byte
  1436  }
  1437  
  1438  func (m *serverKeyExchangeMsg) marshal() []byte {
  1439  	if m.raw != nil {
  1440  		return m.raw
  1441  	}
  1442  	length := len(m.key)
  1443  	x := make([]byte, length+4)
  1444  	x[0] = typeServerKeyExchange
  1445  	x[1] = uint8(length >> 16)
  1446  	x[2] = uint8(length >> 8)
  1447  	x[3] = uint8(length)
  1448  	copy(x[4:], m.key)
  1449  
  1450  	m.raw = x
  1451  	return x
  1452  }
  1453  
  1454  func (m *serverKeyExchangeMsg) unmarshal(data []byte) bool {
  1455  	m.raw = data
  1456  	if len(data) < 4 {
  1457  		return false
  1458  	}
  1459  	m.key = data[4:]
  1460  	return true
  1461  }
  1462  
  1463  type certificateStatusMsg struct {
  1464  	raw      []byte
  1465  	response []byte
  1466  }
  1467  
  1468  func (m *certificateStatusMsg) marshal() []byte {
  1469  	if m.raw != nil {
  1470  		return m.raw
  1471  	}
  1472  
  1473  	var b cryptobyte.Builder
  1474  	b.AddUint8(typeCertificateStatus)
  1475  	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
  1476  		b.AddUint8(statusTypeOCSP)
  1477  		b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
  1478  			b.AddBytes(m.response)
  1479  		})
  1480  	})
  1481  
  1482  	m.raw = b.BytesOrPanic()
  1483  	return m.raw
  1484  }
  1485  
  1486  func (m *certificateStatusMsg) unmarshal(data []byte) bool {
  1487  	m.raw = data
  1488  	s := cryptobyte.String(data)
  1489  
  1490  	var statusType uint8
  1491  	if !s.Skip(4) || // message type and uint24 length field
  1492  		!s.ReadUint8(&statusType) || statusType != statusTypeOCSP ||
  1493  		!readUint24LengthPrefixed(&s, &m.response) ||
  1494  		len(m.response) == 0 || !s.Empty() {
  1495  		return false
  1496  	}
  1497  	return true
  1498  }
  1499  
  1500  type serverHelloDoneMsg struct{}
  1501  
  1502  func (m *serverHelloDoneMsg) marshal() []byte {
  1503  	x := make([]byte, 4)
  1504  	x[0] = typeServerHelloDone
  1505  	return x
  1506  }
  1507  
  1508  func (m *serverHelloDoneMsg) unmarshal(data []byte) bool {
  1509  	return len(data) == 4
  1510  }
  1511  
  1512  type clientKeyExchangeMsg struct {
  1513  	raw        []byte
  1514  	ciphertext []byte
  1515  }
  1516  
  1517  func (m *clientKeyExchangeMsg) marshal() []byte {
  1518  	if m.raw != nil {
  1519  		return m.raw
  1520  	}
  1521  	length := len(m.ciphertext)
  1522  	x := make([]byte, length+4)
  1523  	x[0] = typeClientKeyExchange
  1524  	x[1] = uint8(length >> 16)
  1525  	x[2] = uint8(length >> 8)
  1526  	x[3] = uint8(length)
  1527  	copy(x[4:], m.ciphertext)
  1528  
  1529  	m.raw = x
  1530  	return x
  1531  }
  1532  
  1533  func (m *clientKeyExchangeMsg) unmarshal(data []byte) bool {
  1534  	m.raw = data
  1535  	if len(data) < 4 {
  1536  		return false
  1537  	}
  1538  	l := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
  1539  	if l != len(data)-4 {
  1540  		return false
  1541  	}
  1542  	m.ciphertext = data[4:]
  1543  	return true
  1544  }
  1545  
  1546  type finishedMsg struct {
  1547  	raw        []byte
  1548  	verifyData []byte
  1549  }
  1550  
  1551  func (m *finishedMsg) marshal() []byte {
  1552  	if m.raw != nil {
  1553  		return m.raw
  1554  	}
  1555  
  1556  	var b cryptobyte.Builder
  1557  	b.AddUint8(typeFinished)
  1558  	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
  1559  		b.AddBytes(m.verifyData)
  1560  	})
  1561  
  1562  	m.raw = b.BytesOrPanic()
  1563  	return m.raw
  1564  }
  1565  
  1566  func (m *finishedMsg) unmarshal(data []byte) bool {
  1567  	m.raw = data
  1568  	s := cryptobyte.String(data)
  1569  	return s.Skip(1) &&
  1570  		readUint24LengthPrefixed(&s, &m.verifyData) &&
  1571  		s.Empty()
  1572  }
  1573  
  1574  type certificateRequestMsg struct {
  1575  	raw []byte
  1576  	// hasSignatureAlgorithm indicates whether this message includes a list of
  1577  	// supported signature algorithms. This change was introduced with TLS 1.2.
  1578  	hasSignatureAlgorithm bool
  1579  
  1580  	certificateTypes             []byte
  1581  	supportedSignatureAlgorithms []SignatureScheme
  1582  	certificateAuthorities       [][]byte
  1583  }
  1584  
  1585  func (m *certificateRequestMsg) marshal() (x []byte) {
  1586  	if m.raw != nil {
  1587  		return m.raw
  1588  	}
  1589  
  1590  	// See RFC 4346, Section 7.4.4.
  1591  	length := 1 + len(m.certificateTypes) + 2
  1592  	casLength := 0
  1593  	for _, ca := range m.certificateAuthorities {
  1594  		casLength += 2 + len(ca)
  1595  	}
  1596  	length += casLength
  1597  
  1598  	if m.hasSignatureAlgorithm {
  1599  		length += 2 + 2*len(m.supportedSignatureAlgorithms)
  1600  	}
  1601  
  1602  	x = make([]byte, 4+length)
  1603  	x[0] = typeCertificateRequest
  1604  	x[1] = uint8(length >> 16)
  1605  	x[2] = uint8(length >> 8)
  1606  	x[3] = uint8(length)
  1607  
  1608  	x[4] = uint8(len(m.certificateTypes))
  1609  
  1610  	copy(x[5:], m.certificateTypes)
  1611  	y := x[5+len(m.certificateTypes):]
  1612  
  1613  	if m.hasSignatureAlgorithm {
  1614  		n := len(m.supportedSignatureAlgorithms) * 2
  1615  		y[0] = uint8(n >> 8)
  1616  		y[1] = uint8(n)
  1617  		y = y[2:]
  1618  		for _, sigAlgo := range m.supportedSignatureAlgorithms {
  1619  			y[0] = uint8(sigAlgo >> 8)
  1620  			y[1] = uint8(sigAlgo)
  1621  			y = y[2:]
  1622  		}
  1623  	}
  1624  
  1625  	y[0] = uint8(casLength >> 8)
  1626  	y[1] = uint8(casLength)
  1627  	y = y[2:]
  1628  	for _, ca := range m.certificateAuthorities {
  1629  		y[0] = uint8(len(ca) >> 8)
  1630  		y[1] = uint8(len(ca))
  1631  		y = y[2:]
  1632  		copy(y, ca)
  1633  		y = y[len(ca):]
  1634  	}
  1635  
  1636  	m.raw = x
  1637  	return
  1638  }
  1639  
  1640  func (m *certificateRequestMsg) unmarshal(data []byte) bool {
  1641  	m.raw = data
  1642  
  1643  	if len(data) < 5 {
  1644  		return false
  1645  	}
  1646  
  1647  	length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3])
  1648  	if uint32(len(data))-4 != length {
  1649  		return false
  1650  	}
  1651  
  1652  	numCertTypes := int(data[4])
  1653  	data = data[5:]
  1654  	if numCertTypes == 0 || len(data) <= numCertTypes {
  1655  		return false
  1656  	}
  1657  
  1658  	m.certificateTypes = make([]byte, numCertTypes)
  1659  	if copy(m.certificateTypes, data) != numCertTypes {
  1660  		return false
  1661  	}
  1662  
  1663  	data = data[numCertTypes:]
  1664  
  1665  	if m.hasSignatureAlgorithm {
  1666  		if len(data) < 2 {
  1667  			return false
  1668  		}
  1669  		sigAndHashLen := uint16(data[0])<<8 | uint16(data[1])
  1670  		data = data[2:]
  1671  		if sigAndHashLen&1 != 0 {
  1672  			return false
  1673  		}
  1674  		if len(data) < int(sigAndHashLen) {
  1675  			return false
  1676  		}
  1677  		numSigAlgos := sigAndHashLen / 2
  1678  		m.supportedSignatureAlgorithms = make([]SignatureScheme, numSigAlgos)
  1679  		for i := range m.supportedSignatureAlgorithms {
  1680  			m.supportedSignatureAlgorithms[i] = SignatureScheme(data[0])<<8 | SignatureScheme(data[1])
  1681  			data = data[2:]
  1682  		}
  1683  	}
  1684  
  1685  	if len(data) < 2 {
  1686  		return false
  1687  	}
  1688  	casLength := uint16(data[0])<<8 | uint16(data[1])
  1689  	data = data[2:]
  1690  	if len(data) < int(casLength) {
  1691  		return false
  1692  	}
  1693  	cas := make([]byte, casLength)
  1694  	copy(cas, data)
  1695  	data = data[casLength:]
  1696  
  1697  	m.certificateAuthorities = nil
  1698  	for len(cas) > 0 {
  1699  		if len(cas) < 2 {
  1700  			return false
  1701  		}
  1702  		caLen := uint16(cas[0])<<8 | uint16(cas[1])
  1703  		cas = cas[2:]
  1704  
  1705  		if len(cas) < int(caLen) {
  1706  			return false
  1707  		}
  1708  
  1709  		m.certificateAuthorities = append(m.certificateAuthorities, cas[:caLen])
  1710  		cas = cas[caLen:]
  1711  	}
  1712  
  1713  	return len(data) == 0
  1714  }
  1715  
  1716  type certificateVerifyMsg struct {
  1717  	raw                   []byte
  1718  	hasSignatureAlgorithm bool // format change introduced in TLS 1.2
  1719  	signatureAlgorithm    SignatureScheme
  1720  	signature             []byte
  1721  }
  1722  
  1723  func (m *certificateVerifyMsg) marshal() (x []byte) {
  1724  	if m.raw != nil {
  1725  		return m.raw
  1726  	}
  1727  
  1728  	var b cryptobyte.Builder
  1729  	b.AddUint8(typeCertificateVerify)
  1730  	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
  1731  		if m.hasSignatureAlgorithm {
  1732  			b.AddUint16(uint16(m.signatureAlgorithm))
  1733  		}
  1734  		b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1735  			b.AddBytes(m.signature)
  1736  		})
  1737  	})
  1738  
  1739  	m.raw = b.BytesOrPanic()
  1740  	return m.raw
  1741  }
  1742  
  1743  func (m *certificateVerifyMsg) unmarshal(data []byte) bool {
  1744  	m.raw = data
  1745  	s := cryptobyte.String(data)
  1746  
  1747  	if !s.Skip(4) { // message type and uint24 length field
  1748  		return false
  1749  	}
  1750  	if m.hasSignatureAlgorithm {
  1751  		if !s.ReadUint16((*uint16)(&m.signatureAlgorithm)) {
  1752  			return false
  1753  		}
  1754  	}
  1755  	return readUint16LengthPrefixed(&s, &m.signature) && s.Empty()
  1756  }
  1757  
  1758  type newSessionTicketMsg struct {
  1759  	raw    []byte
  1760  	ticket []byte
  1761  }
  1762  
  1763  func (m *newSessionTicketMsg) marshal() (x []byte) {
  1764  	if m.raw != nil {
  1765  		return m.raw
  1766  	}
  1767  
  1768  	// See RFC 5077, Section 3.3.
  1769  	ticketLen := len(m.ticket)
  1770  	length := 2 + 4 + ticketLen
  1771  	x = make([]byte, 4+length)
  1772  	x[0] = typeNewSessionTicket
  1773  	x[1] = uint8(length >> 16)
  1774  	x[2] = uint8(length >> 8)
  1775  	x[3] = uint8(length)
  1776  	x[8] = uint8(ticketLen >> 8)
  1777  	x[9] = uint8(ticketLen)
  1778  	copy(x[10:], m.ticket)
  1779  
  1780  	m.raw = x
  1781  
  1782  	return
  1783  }
  1784  
  1785  func (m *newSessionTicketMsg) unmarshal(data []byte) bool {
  1786  	m.raw = data
  1787  
  1788  	if len(data) < 10 {
  1789  		return false
  1790  	}
  1791  
  1792  	length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3])
  1793  	if uint32(len(data))-4 != length {
  1794  		return false
  1795  	}
  1796  
  1797  	ticketLen := int(data[8])<<8 + int(data[9])
  1798  	if len(data)-10 != ticketLen {
  1799  		return false
  1800  	}
  1801  
  1802  	m.ticket = data[10:]
  1803  
  1804  	return true
  1805  }
  1806  
  1807  type helloRequestMsg struct {
  1808  }
  1809  
  1810  func (*helloRequestMsg) marshal() []byte {
  1811  	return []byte{typeHelloRequest, 0, 0, 0}
  1812  }
  1813  
  1814  func (*helloRequestMsg) unmarshal(data []byte) bool {
  1815  	return len(data) == 4
  1816  }