// Source: github.com/fumiama/terasu@v0.0.0-20240507144117-547a591149c0/handshake.go

//go:build go1.21

package terasu

import (
	"context"
	"crypto"
	"crypto/ecdh"
	"crypto/tls"
	"errors"
	"hash"
	"unsafe"
)

// This file drives crypto/tls's client handshake from outside that package.
// It links to unexported crypto/tls functions and methods with go:linkname
// directives and re-declares ("mirrors") several unexported crypto/tls struct
// types so that their fields can be read here and so that values can be handed
// back to the linked code. The mirrors are only sound while their field order
// and field types match the crypto/tls of the Go toolchain in use; the go1.21
// build tag above is the only version guard visible in this file.
//
// Each bodyless function below gets its implementation from the crypto/tls
// symbol named in its go:linkname directive. The methods declared on mirror
// types only forward to those linked functions.
//
// NOTE(review): since Go 1.23 the linker rejects go:linkname references to
// standard-library internals that are not explicitly marked for it, unless
// building with -ldflags=-checklinkname=0 — confirm for the toolchains you target.

// defaultConfig is linked to crypto/tls's unexported defaultConfig. The
// handshake uses it as the fallback when a connection has no config.
//
//go:linkname defaultConfig crypto/tls.defaultConfig
func defaultConfig() *tls.Config

// clientHelloMsg mirrors crypto/tls's unexported clientHelloMsg. Pointers to
// it come from the linked makeClientHello and are passed back to the linked
// marshal, so field order and field types must track crypto/tls exactly.
//
// NOTE(review): upstream may declare further fields after earlyData; if so
// only a leading prefix is mirrored here, and values must never be allocated
// or copied in this package (the linked unmarshal/marshal would touch memory
// past the end of this struct). Confirm against the targeted Go version.
type clientHelloMsg struct {
	raw                              []byte
	vers                             uint16
	random                           []byte
	sessionId                        []byte
	cipherSuites                     []uint16
	compressionMethods               []uint8
	serverName                       string
	ocspStapling                     bool
	supportedCurves                  []tls.CurveID
	supportedPoints                  []uint8
	ticketSupported                  bool
	sessionTicket                    []uint8
	supportedSignatureAlgorithms     []tls.SignatureScheme
	supportedSignatureAlgorithmsCert []tls.SignatureScheme
	secureRenegotiationSupported     bool
	secureRenegotiation              []byte
	extendedMasterSecret             bool
	alpnProtocols                    []string
	scts                             bool
	supportedVersions                []uint16
	cookie                           []byte
	keyShares                        []byte
	earlyData                        bool
}

// marshal is linked to (*crypto/tls.clientHelloMsg).marshal, with the mirror
// pointer passed in place of the receiver.
//
//go:linkname marshal crypto/tls.(*clientHelloMsg).marshal
func marshal(m *clientHelloMsg) ([]byte, error)

// marshal forwards to the linked crypto/tls marshal. Together with unmarshal
// it makes *clientHelloMsg satisfy handshakeMessage.
func (m *clientHelloMsg) marshal() ([]byte, error) {
	return marshal(m)
}

// unmarshal is linked to (*crypto/tls.clientHelloMsg).unmarshal, with the
// mirror pointer passed in place of the receiver.
//
//go:linkname unmarshal crypto/tls.(*clientHelloMsg).unmarshal
func unmarshal(m *clientHelloMsg, data []byte) bool

// unmarshal forwards to the linked crypto/tls unmarshal.
func (m *clientHelloMsg) unmarshal(data []byte) bool {
	return unmarshal(m, data)
}

// makeClientHello is linked to (*crypto/tls.Conn).makeClientHello. The
// receiver is passed as *_trsconn (declared elsewhere in this package), which
// is assumed to share crypto/tls.Conn's memory layout.
//
//go:linkname makeClientHello crypto/tls.(*Conn).makeClientHello
func makeClientHello(c *_trsconn) (*clientHelloMsg, *ecdh.PrivateKey, error)

// makeClientHello forwards to the linked crypto/tls makeClientHello. It
// returns the ClientHello to send and an ECDH private key, which the
// handshake later passes on as ecdheKey.
func (c *_trsconn) makeClientHello() (*clientHelloMsg, *ecdh.PrivateKey, error) {
	return makeClientHello(c)
}

// A sessionState is a resumable session.
//
// It mirrors crypto/tls's session state as returned by the linked
// loadSession. The encoding described below carries more state (created_at,
// secret, certificates, …) than the fields declared here; presumably only a
// leading prefix of the crypto/tls struct is mirrored, so values should only
// be reached through pointers obtained from crypto/tls — TODO confirm.
type sessionState struct {
	// Encoded as a SessionState (in the language of RFC 8446, Section 3).
	//
	//   enum { server(1), client(2) } SessionStateType;
	//
	//   opaque Certificate<1..2^24-1>;
	//
	//   Certificate CertificateChain<0..2^24-1>;
	//
	//   opaque Extra<0..2^24-1>;
	//
	//   struct {
	//       uint16 version;
	//       SessionStateType type;
	//       uint16 cipher_suite;
	//       uint64 created_at;
	//       opaque secret<1..2^8-1>;
	//       Extra extra<0..2^24-1>;
	//       uint8 ext_master_secret = { 0, 1 };
	//       uint8 early_data = { 0, 1 };
	//       CertificateEntry certificate_list<0..2^24-1>;
	//       CertificateChain verified_chains<0..2^24-1>; /* excluding leaf */
	//       select (SessionState.early_data) {
	//           case 0: Empty;
	//           case 1: opaque alpn<1..2^8-1>;
	//       };
	//       select (SessionState.type) {
	//           case server: Empty;
	//           case client: struct {
	//               select (SessionState.version) {
	//                   case VersionTLS10..VersionTLS12: Empty;
	//                   case VersionTLS13: struct {
	//                       uint64 use_by;
	//                       uint32 age_add;
	//                   };
	//               };
	//           };
	//       };
	//   } SessionState;
	//

	// Extra is ignored by crypto/tls, but is encoded by [tls.SessionState.Bytes]
	// and parsed by [tls.ParseSessionState].
	//
	// This allows [tls.Config.UnwrapSession]/[tls.Config.WrapSession] and
	// [tls.ClientSessionCache] implementations to store and retrieve additional
	// data alongside this session.
	//
	// To allow different layers in a protocol stack to share this field,
	// applications must only append to it, not replace it, and must use entries
	// that can be recognized even if out of order (for example, by starting
	// with an id and version prefix).
	Extra [][]byte

	// EarlyData indicates whether the ticket can be used for 0-RTT in a QUIC
	// connection. The application may set this to false if it is true to
	// decline to offer 0-RTT even if supported.
	EarlyData bool

	version     uint16
	isClient    bool
	cipherSuite uint16
}

// loadSession is linked to (*crypto/tls.Conn).loadSession. Judging by its
// results, it yields a session to resume for hello together with the early
// secret and PSK binder key; the handshake treats a nil session as "nothing
// to resume" — TODO confirm against crypto/tls.
//
//go:linkname loadSession crypto/tls.(*Conn).loadSession
func loadSession(c *_trsconn, hello *clientHelloMsg) (
	session *sessionState, earlySecret, binderKey []byte, err error,
)

// loadSession forwards to the linked crypto/tls loadSession.
func (c *_trsconn) loadSession(hello *clientHelloMsg) (
	session *sessionState, earlySecret, binderKey []byte, err error,
) {
	return loadSession(c, hello)
}

// clientSessionCacheKey is linked to (*crypto/tls.Conn).clientSessionCacheKey.
// The handshake treats an empty result as "no cache key" and then skips the
// session-cache update.
//
//go:linkname clientSessionCacheKey crypto/tls.(*Conn).clientSessionCacheKey
func clientSessionCacheKey(c *_trsconn) string

// clientSessionCacheKey forwards to the linked crypto/tls clientSessionCacheKey.
func (c *_trsconn) clientSessionCacheKey() string {
	return clientSessionCacheKey(c)
}

// A cipherSuiteTLS13 defines only the pair of the AEAD algorithm and hash
// algorithm to be used with HKDF. See RFC 8446, Appendix B.4.
//
// It mirrors crypto/tls's unexported cipherSuiteTLS13; pointers come from the
// linked cipherSuiteTLS13ByID. aead is declared here as returning any;
// presumably upstream returns a different (unexported) interface type, which
// leaves the field's size unchanged but means aead must not be called through
// this mirror — TODO confirm.
type cipherSuiteTLS13 struct {
	id     uint16
	keyLen int
	aead   func(key, fixedNonce []byte) any
	hash   crypto.Hash
}

// deriveSecret is linked to (*crypto/tls.cipherSuiteTLS13).deriveSecret,
// presumably RFC 8446's Derive-Secret(secret, label, transcript).
//
//go:linkname deriveSecret crypto/tls.(*cipherSuiteTLS13).deriveSecret
func deriveSecret(c *cipherSuiteTLS13, secret []byte, label string, transcript hash.Hash) []byte

// deriveSecret forwards to the linked crypto/tls deriveSecret.
func (c *cipherSuiteTLS13) deriveSecret(secret []byte, label string, transcript hash.Hash) []byte {
	return deriveSecret(c, secret, label, transcript)
}

// cipherSuiteTLS13ByID is linked to crypto/tls's cipherSuiteTLS13ByID and
// looks up the TLS 1.3 suite for id (presumably nil if unknown — confirm).
//
//go:linkname cipherSuiteTLS13ByID crypto/tls.cipherSuiteTLS13ByID
func cipherSuiteTLS13ByID(id uint16) *cipherSuiteTLS13

// handshakeMessage is the message interface the linked code and
// writeHandshakeRecord work with; *clientHelloMsg implements it.
type handshakeMessage interface {
	marshal() ([]byte, error)
	unmarshal([]byte) bool
}

// transcriptHash is anything handshake bytes can be written to; a hash.Hash
// satisfies it.
type transcriptHash interface {
	Write([]byte) (int, error)
}

// transcriptMsg is linked to crypto/tls's transcriptMsg; presumably it
// marshals msg and writes the bytes to h.
//
//go:linkname transcriptMsg crypto/tls.transcriptMsg
func transcriptMsg(msg handshakeMessage, h transcriptHash) error

// clientEarlyTrafficLabel is the HKDF label for client_early_traffic_secret
// (RFC 8446, Section 7.1).
const clientEarlyTrafficLabel = "c e traffic"

// quicSetWriteSecret is linked to (*crypto/tls.Conn).quicSetWriteSecret; the
// handshake uses it to install the 0-RTT write secret at the QUIC early level.
//
//go:linkname quicSetWriteSecret crypto/tls.(*Conn).quicSetWriteSecret
func quicSetWriteSecret(c *_trsconn, level tls.QUICEncryptionLevel, suite uint16, secret []byte)

// readHandshake is linked to (*crypto/tls.Conn).readHandshake. The returned
// value holds a crypto/tls message type, so callers must check its dynamic
// type before reinterpreting it.
//
//go:linkname readHandshake crypto/tls.(*Conn).readHandshake
func readHandshake(c *_trsconn, transcript transcriptHash) (any, error)

// readHandshake forwards to the linked crypto/tls readHandshake.
func (c *_trsconn) readHandshake(transcript transcriptHash) (any, error) {
	return readHandshake(c, transcript)
}

// serverHelloMsg mirrors the leading fields of crypto/tls's serverHelloMsg
// (presumably only a prefix — confirm). Values must only be obtained from
// crypto/tls and never allocated here, because linked code receiving a
// pointer may read past these fields.
type serverHelloMsg struct {
	raw    []byte
	vers   uint16
	random []byte
}

// sendAlert is linked to (*crypto/tls.Conn).sendAlert. The alert type is
// declared elsewhere in this package.
//
//go:linkname sendAlert crypto/tls.(*Conn).sendAlert
func sendAlert(c *_trsconn, err alert) error

// sendAlert forwards to the linked crypto/tls sendAlert.
func (c *_trsconn) sendAlert(err alert) error {
	return sendAlert(c, err)
}

// unexpectedMessageError is linked to crypto/tls's unexpectedMessageError,
// which builds the error returned when a handshake message of the wrong type
// arrives.
//
//go:linkname unexpectedMessageError crypto/tls.unexpectedMessageError
func unexpectedMessageError(wanted, got any) error

// TLS alert descriptions used here (RFC 8446, Section 6).
const (
	alertUnexpectedMessage alert = 10
	alertIllegalParameter  alert = 47
)

// pickTLSVersion is linked to (*crypto/tls.Conn).pickTLSVersion. The handshake
// reads the connection's vers field right after calling it, so presumably it
// records the negotiated version there — confirm.
//
//go:linkname pickTLSVersion crypto/tls.(*Conn).pickTLSVersion
func pickTLSVersion(c *_trsconn, serverHello *serverHelloMsg) error

// pickTLSVersion forwards to the linked crypto/tls pickTLSVersion.
func (c *_trsconn) pickTLSVersion(serverHello *serverHelloMsg) error {
	return pickTLSVersion(c, serverHello)
}

// maxSupportedVersion is linked to (*crypto/tls.Config).maxSupportedVersion,
// with the config passed in place of the receiver.
//
//go:linkname maxSupportedVersion crypto/tls.(*Config).maxSupportedVersion
func maxSupportedVersion(c *tls.Config, isClient bool) uint16

// roleClient is the value passed as maxSupportedVersion's isClient argument.
const roleClient = true

const (
	// downgradeCanaryTLS12 or downgradeCanaryTLS11 is embedded in the server
	// random as a downgrade protection if the server would be capable of
	// negotiating a higher version. See RFC 8446, Section 4.1.3.
	downgradeCanaryTLS12 = "DOWNGRD\x01"
	downgradeCanaryTLS11 = "DOWNGRD\x00"
)

// clientHandshakeStateTLS13 mirrors crypto/tls's clientHandshakeStateTLS13.
// Unlike the message mirrors, values of this type are allocated in this
// package and handed to the linked handshake13, so the complete field list —
// order and types — must match crypto/tls exactly. certReq is declared as an
// opaque *uintptr, presumably standing in for a pointer to a type that is not
// mirrored here.
type clientHandshakeStateTLS13 struct {
	c           *Conn
	ctx         context.Context
	serverHello *serverHelloMsg
	hello       *clientHelloMsg
	ecdheKey    *ecdh.PrivateKey

	session     *sessionState
	earlySecret []byte
	binderKey   []byte

	certReq       *uintptr
	usingPSK      bool
	sentDummyCCS  bool
	suite         *cipherSuiteTLS13
	transcript    hash.Hash
	masterSecret  []byte
	trafficSecret []byte // client_application_traffic_secret_0
}

// handshake13 is linked to (*crypto/tls.clientHandshakeStateTLS13).handshake.
//
//go:linkname handshake13 crypto/tls.(*clientHandshakeStateTLS13).handshake
func handshake13(hs *clientHandshakeStateTLS13) error

// handshake runs the linked crypto/tls TLS 1.3 client handshake on hs.
func (hs *clientHandshakeStateTLS13) handshake() error {
	return handshake13(hs)
}

// A finishedHash calculates the hash of a set of handshake messages suitable
// for including in a Finished message.
//
// It is mirrored because clientHandshakeState embeds it by value, so its size
// and layout must match crypto/tls for that struct's layout to match.
type finishedHash struct {
	client hash.Hash
	server hash.Hash

	// Prior to TLS 1.2, an additional MD5 hash is required.
	clientMD5 hash.Hash
	serverMD5 hash.Hash

	// In TLS 1.2, a full buffer is sadly required.
	buffer []byte

	version uint16
	prf     func(result, secret, label, seed []byte)
}

// clientHandshakeState mirrors crypto/tls's clientHandshakeState, which is
// used for every version other than TLS 1.3. Values are allocated in this
// package and handed to the linked handshake, so all fields must match
// crypto/tls exactly. suite is an opaque *uintptr, presumably standing in for
// a pointer to an unmirrored cipher-suite type.
type clientHandshakeState struct {
	c            *Conn
	ctx          context.Context
	serverHello  *serverHelloMsg
	hello        *clientHelloMsg
	suite        *uintptr
	finishedHash finishedHash
	masterSecret []byte
	session      *sessionState // the session being resumed
	ticket       []byte        // a fresh ticket received during this handshake
}

// handshake is linked to (*crypto/tls.clientHandshakeState).handshake.
//
//go:linkname handshake crypto/tls.(*clientHandshakeState).handshake
func handshake(hs *clientHandshakeState) error

// handshake runs the linked crypto/tls pre-TLS-1.3 client handshake on hs.
func (hs *clientHandshakeState) handshake() error {
	return handshake(hs)
}

// writeHandshakeRecord writes a handshake message to the connection and updates
// the record layer state.
If transcript is non-nil the marshalled message is 297 // written to it. 298 func (c *_trsconn) writeHandshakeRecord(msg handshakeMessage, transcript transcriptHash, firstFragmentLen uint8) (int, error) { 299 c.out.Lock() 300 defer c.out.Unlock() 301 302 data, err := msg.marshal() 303 if err != nil { 304 return 0, err 305 } 306 if transcript != nil { 307 transcript.Write(data) 308 } 309 310 return c.writeRecordLocked(recordTypeHandshake, firstFragmentLen, data) 311 } 312 313 func (cout *Conn) clientHandshake(firstFragmentLen uint8) func(context.Context) error { 314 return func(ctx context.Context) (err error) { 315 c := (*_trsconn)(unsafe.Pointer(cout)) 316 317 if c.config == nil { 318 c.config = defaultConfig() 319 } 320 321 // This may be a renegotiation handshake, in which case some fields 322 // need to be reset. 323 c.didResume = false 324 325 hello, ecdheKey, err := c.makeClientHello() 326 if err != nil { 327 return err 328 } 329 c.serverName = hello.serverName 330 331 session, earlySecret, binderKey, err := c.loadSession(hello) 332 if err != nil { 333 return err 334 } 335 if session != nil { 336 defer func() { 337 // If we got a handshake failure when resuming a session, throw away 338 // the session ticket. See RFC 5077, Section 3.2. 339 // 340 // RFC 8446 makes no mention of dropping tickets on failure, but it 341 // does require servers to abort on invalid binders, so we need to 342 // delete tickets to recover from a corrupted PSK. 
343 if err != nil { 344 if cacheKey := c.clientSessionCacheKey(); cacheKey != "" { 345 c.config.ClientSessionCache.Put(cacheKey, nil) 346 } 347 } 348 }() 349 } 350 351 if _, err := c.writeHandshakeRecord(hello, nil, firstFragmentLen); err != nil { 352 return err 353 } 354 355 if hello.earlyData { 356 suite := cipherSuiteTLS13ByID(session.cipherSuite) 357 transcript := suite.hash.New() 358 if err := transcriptMsg(hello, transcript); err != nil { 359 return err 360 } 361 earlyTrafficSecret := suite.deriveSecret(earlySecret, clientEarlyTrafficLabel, transcript) 362 quicSetWriteSecret(c, tls.QUICEncryptionLevelEarly, suite.id, earlyTrafficSecret) 363 } 364 365 // serverHelloMsg is not included in the transcript 366 msg, err := c.readHandshake(nil) 367 if err != nil { 368 return err 369 } 370 371 var serverHello *serverHelloMsg 372 if !isTypeEqual(msg, "*tls.serverHelloMsg") { 373 c.sendAlert(alertUnexpectedMessage) 374 return unexpectedMessageError(serverHello, msg) 375 } 376 serverHello = (*serverHelloMsg)(*(*unsafe.Pointer)( 377 unsafe.Add(unsafe.Pointer(&msg), unsafe.Sizeof(uintptr(0))), 378 )) 379 380 if err := c.pickTLSVersion(serverHello); err != nil { 381 return err 382 } 383 384 // If we are negotiating a protocol version that's lower than what we 385 // support, check for the server downgrade canaries. 386 // See RFC 8446, Section 4.1.3. 
387 maxVers := maxSupportedVersion(c.config, roleClient) 388 tls12Downgrade := string(serverHello.random[24:]) == downgradeCanaryTLS12 389 tls11Downgrade := string(serverHello.random[24:]) == downgradeCanaryTLS11 390 if maxVers == tls.VersionTLS13 && c.vers <= tls.VersionTLS12 && (tls12Downgrade || tls11Downgrade) || 391 maxVers == tls.VersionTLS12 && c.vers <= tls.VersionTLS11 && tls11Downgrade { 392 c.sendAlert(alertIllegalParameter) 393 return errors.New("tls: downgrade attempt detected, possibly due to a MitM attack or a broken middlebox") 394 } 395 396 if c.vers == tls.VersionTLS13 { 397 hs := &clientHandshakeStateTLS13{ 398 c: cout, 399 ctx: ctx, 400 serverHello: serverHello, 401 hello: hello, 402 ecdheKey: ecdheKey, 403 session: session, 404 earlySecret: earlySecret, 405 binderKey: binderKey, 406 } 407 408 // In TLS 1.3, session tickets are delivered after the handshake. 409 return hs.handshake() 410 } 411 412 hs := &clientHandshakeState{ 413 c: cout, 414 ctx: ctx, 415 serverHello: serverHello, 416 hello: hello, 417 session: session, 418 } 419 420 if err := hs.handshake(); err != nil { 421 return err 422 } 423 424 return nil 425 } 426 }