github.com/Carcraftz/utls@v0.0.0-20220413235215-6b7c52fd78b6/u_conn.go (about) 1 // Copyright 2017 Google Inc. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package tls 6 7 import ( 8 "bufio" 9 "bytes" 10 "crypto/cipher" 11 "encoding/binary" 12 "errors" 13 "fmt" 14 "io" 15 "net" 16 "strconv" 17 "sync/atomic" 18 ) 19 20 type UConn struct { 21 *Conn 22 23 Extensions []TLSExtension 24 ClientHelloID ClientHelloID 25 26 ClientHelloBuilt bool 27 HandshakeState ClientHandshakeState 28 29 // sessionID may or may not depend on ticket; nil => random 30 GetSessionID func(ticket []byte) [32]byte 31 32 greaseSeed [ssl_grease_last_index]uint16 33 34 extCompressCerts bool 35 } 36 37 // UClient returns a new uTLS client, with behavior depending on clientHelloID. 38 // Config CAN be nil, but make sure to eventually specify ServerName. 39 func UClient(conn net.Conn, config *Config, clientHelloID ClientHelloID) *UConn { 40 if config == nil { 41 config = &Config{} 42 } 43 tlsConn := Conn{conn: conn, config: config, isClient: true} 44 handshakeState := ClientHandshakeState{C: &tlsConn, Hello: &ClientHelloMsg{}} 45 uconn := UConn{Conn: &tlsConn, ClientHelloID: clientHelloID, HandshakeState: handshakeState} 46 uconn.HandshakeState.uconn = &uconn 47 return &uconn 48 } 49 50 // BuildHandshakeState behavior varies based on ClientHelloID and 51 // whether it was already called before. 52 // If HelloGolang: 53 // [only once] make default ClientHello and overwrite existing state 54 // If any other mimicking ClientHelloID is used: 55 // [only once] make ClientHello based on ID and overwrite existing state 56 // [each call] apply uconn.Extensions config to internal crypto/tls structures 57 // [each call] marshal ClientHello. 58 // 59 // BuildHandshakeState is automatically called before uTLS performs handshake, 60 // amd should only be called explicitly to inspect/change fields of 61 // default/mimicked ClientHello. 62 func (uconn *UConn) BuildHandshakeState() error { 63 if uconn.ClientHelloID == HelloGolang { 64 if uconn.ClientHelloBuilt { 65 return nil 66 } 67 68 // use default Golang ClientHello. 69 hello, ecdheParams, err := uconn.makeClientHello() 70 if err != nil { 71 return err 72 } 73 74 uconn.HandshakeState.Hello = hello.getPublicPtr() 75 uconn.HandshakeState.State13.EcdheParams = ecdheParamMapToPublic(ecdheParams) 76 uconn.HandshakeState.C = uconn.Conn 77 } else { 78 if !uconn.ClientHelloBuilt { 79 err := uconn.applyPresetByID(uconn.ClientHelloID) 80 if err != nil { 81 return err 82 } 83 } 84 85 err := uconn.ApplyConfig() 86 if err != nil { 87 return err 88 } 89 err = uconn.MarshalClientHello() 90 if err != nil { 91 return err 92 } 93 } 94 uconn.ClientHelloBuilt = true 95 return nil 96 } 97 98 // SetSessionState sets the session ticket, which may be preshared or fake. 99 // If session is nil, the body of session ticket extension will be unset, 100 // but the extension itself still MAY be present for mimicking purposes. 101 // Session tickets to be reused - use same cache on following connections. 102 func (uconn *UConn) SetSessionState(session *ClientSessionState) error { 103 uconn.HandshakeState.Session = session 104 var sessionTicket []uint8 105 if session != nil { 106 sessionTicket = session.sessionTicket 107 } 108 uconn.HandshakeState.Hello.TicketSupported = true 109 uconn.HandshakeState.Hello.SessionTicket = sessionTicket 110 111 for _, ext := range uconn.Extensions { 112 st, ok := ext.(*SessionTicketExtension) 113 if !ok { 114 continue 115 } 116 st.Session = session 117 if session != nil { 118 if len(session.SessionTicket()) > 0 { 119 if uconn.GetSessionID != nil { 120 sid := uconn.GetSessionID(session.SessionTicket()) 121 uconn.HandshakeState.Hello.SessionId = sid[:] 122 return nil 123 } 124 } 125 var sessionID [32]byte 126 _, err := io.ReadFull(uconn.config.rand(), sessionID[:]) 127 if err != nil { 128 return err 129 } 130 uconn.HandshakeState.Hello.SessionId = sessionID[:] 131 } 132 return nil 133 } 134 return nil 135 } 136 137 // If you want session tickets to be reused - use same cache on following connections 138 func (uconn *UConn) SetSessionCache(cache ClientSessionCache) { 139 uconn.config.ClientSessionCache = cache 140 uconn.HandshakeState.Hello.TicketSupported = true 141 } 142 143 // SetClientRandom sets client random explicitly. 144 // BuildHandshakeFirst() must be called before SetClientRandom. 145 // r must to be 32 bytes long. 146 func (uconn *UConn) SetClientRandom(r []byte) error { 147 if len(r) != 32 { 148 return errors.New("Incorrect client random length! Expected: 32, got: " + strconv.Itoa(len(r))) 149 } else { 150 uconn.HandshakeState.Hello.Random = make([]byte, 32) 151 copy(uconn.HandshakeState.Hello.Random, r) 152 return nil 153 } 154 } 155 156 func (uconn *UConn) SetSNI(sni string) { 157 hname := hostnameInSNI(sni) 158 uconn.config.ServerName = hname 159 for _, ext := range uconn.Extensions { 160 sniExt, ok := ext.(*SNIExtension) 161 if ok { 162 sniExt.ServerName = hname 163 } 164 } 165 } 166 167 // Handshake runs the client handshake using given clientHandshakeState 168 // Requires hs.hello, and, optionally, hs.session to be set. 169 func (c *UConn) Handshake() error { 170 c.handshakeMutex.Lock() 171 defer c.handshakeMutex.Unlock() 172 173 if err := c.handshakeErr; err != nil { 174 return err 175 } 176 if c.handshakeComplete() { 177 return nil 178 } 179 180 c.in.Lock() 181 defer c.in.Unlock() 182 183 if c.isClient { 184 // [uTLS section begins] 185 err := c.BuildHandshakeState() 186 if err != nil { 187 return err 188 } 189 // [uTLS section ends] 190 191 c.handshakeErr = c.clientHandshake() 192 } else { 193 c.handshakeErr = c.serverHandshake() 194 } 195 if c.handshakeErr == nil { 196 c.handshakes++ 197 } else { 198 // If an error occurred during the hadshake try to flush the 199 // alert that might be left in the buffer. 200 c.flush() 201 } 202 203 if c.handshakeErr == nil && !c.handshakeComplete() { 204 c.handshakeErr = errors.New("tls: internal error: handshake should have had a result") 205 } 206 207 return c.handshakeErr 208 } 209 210 // Copy-pasted from tls.Conn in its entirety. But c.Handshake() is now utls' one, not tls. 211 // Write writes data to the connection. 212 func (c *UConn) Write(b []byte) (int, error) { 213 // interlock with Close below 214 for { 215 x := atomic.LoadInt32(&c.activeCall) 216 if x&1 != 0 { 217 return 0, errClosed 218 } 219 if atomic.CompareAndSwapInt32(&c.activeCall, x, x+2) { 220 defer atomic.AddInt32(&c.activeCall, -2) 221 break 222 } 223 } 224 225 if err := c.Handshake(); err != nil { 226 return 0, err 227 } 228 229 c.out.Lock() 230 defer c.out.Unlock() 231 232 if err := c.out.err; err != nil { 233 return 0, err 234 } 235 236 if !c.handshakeComplete() { 237 return 0, alertInternalError 238 } 239 240 if c.closeNotifySent { 241 return 0, errShutdown 242 } 243 244 // SSL 3.0 and TLS 1.0 are susceptible to a chosen-plaintext 245 // attack when using block mode ciphers due to predictable IVs. 246 // This can be prevented by splitting each Application Data 247 // record into two records, effectively randomizing the IV. 248 // 249 // https://www.openssl.org/~bodo/tls-cbc.txt 250 // https://bugzilla.mozilla.org/show_bug.cgi?id=665814 251 // https://www.imperialviolet.org/2012/01/15/beastfollowup.html 252 253 var m int 254 if len(b) > 1 && c.vers <= VersionTLS10 { 255 if _, ok := c.out.cipher.(cipher.BlockMode); ok { 256 n, err := c.writeRecordLocked(recordTypeApplicationData, b[:1]) 257 if err != nil { 258 return n, c.out.setErrorLocked(err) 259 } 260 m, b = 1, b[1:] 261 } 262 } 263 264 n, err := c.writeRecordLocked(recordTypeApplicationData, b) 265 return n + m, c.out.setErrorLocked(err) 266 } 267 268 // clientHandshakeWithOneState checks that exactly one expected state is set (1.2 or 1.3) 269 // and performs client TLS handshake with that state 270 func (c *UConn) clientHandshake() (err error) { 271 // [uTLS section begins] 272 hello := c.HandshakeState.Hello.getPrivatePtr() 273 defer func() { c.HandshakeState.Hello = hello.getPublicPtr() }() 274 275 sessionIsAlreadySet := c.HandshakeState.Session != nil 276 277 // after this point exactly 1 out of 2 HandshakeState pointers is non-nil, 278 // useTLS13 variable tells which pointer 279 // [uTLS section ends] 280 281 if c.config == nil { 282 c.config = defaultConfig() 283 } 284 285 // This may be a renegotiation handshake, in which case some fields 286 // need to be reset. 287 c.didResume = false 288 289 // [uTLS section begins] 290 // don't make new ClientHello, use hs.hello 291 // preserve the checks from beginning and end of makeClientHello() 292 if len(c.config.ServerName) == 0 && !c.config.InsecureSkipVerify { 293 return errors.New("tls: either ServerName or InsecureSkipVerify must be specified in the tls.Config") 294 } 295 296 nextProtosLength := 0 297 for _, proto := range c.config.NextProtos { 298 if l := len(proto); l == 0 || l > 255 { 299 return errors.New("tls: invalid NextProtos value") 300 } else { 301 nextProtosLength += 1 + l 302 } 303 } 304 305 if nextProtosLength > 0xffff { 306 return errors.New("tls: NextProtos values too large") 307 } 308 309 if c.handshakes > 0 { 310 hello.secureRenegotiation = c.clientFinished[:] 311 } 312 // [uTLS section ends] 313 314 cacheKey, session, earlySecret, binderKey := c.loadSession(hello) 315 if cacheKey != "" && session != nil { 316 defer func() { 317 // If we got a handshake failure when resuming a session, throw away 318 // the session ticket. See RFC 5077, Section 3.2. 319 // 320 // RFC 8446 makes no mention of dropping tickets on failure, but it 321 // does require servers to abort on invalid binders, so we need to 322 // delete tickets to recover from a corrupted PSK. 323 if err != nil { 324 c.config.ClientSessionCache.Put(cacheKey, nil) 325 } 326 }() 327 } 328 329 if !sessionIsAlreadySet { // uTLS: do not overwrite already set session 330 err = c.SetSessionState(session) 331 if err != nil { 332 return 333 } 334 } 335 336 if _, err := c.writeRecord(recordTypeHandshake, hello.marshal()); err != nil { 337 return err 338 } 339 340 msg, err := c.readHandshake() 341 if err != nil { 342 return err 343 } 344 345 serverHello, ok := msg.(*serverHelloMsg) 346 if !ok { 347 c.sendAlert(alertUnexpectedMessage) 348 return unexpectedMessageError(serverHello, msg) 349 } 350 351 if err := c.pickTLSVersion(serverHello); err != nil { 352 return err 353 } 354 355 // uTLS: do not create new handshakeState, use existing one 356 if c.vers == VersionTLS13 { 357 hs13 := c.HandshakeState.toPrivate13() 358 hs13.serverHello = serverHello 359 hs13.hello = hello 360 if !sessionIsAlreadySet { 361 hs13.earlySecret = earlySecret 362 hs13.binderKey = binderKey 363 } 364 // In TLS 1.3, session tickets are delivered after the handshake. 365 err = hs13.handshake() 366 if handshakeState := hs13.toPublic13(); handshakeState != nil { 367 c.HandshakeState = *handshakeState 368 } 369 return err 370 } 371 372 hs12 := c.HandshakeState.toPrivate12() 373 hs12.serverHello = serverHello 374 hs12.hello = hello 375 err = hs12.handshake() 376 if handshakeState := hs12.toPublic12(); handshakeState != nil { 377 c.HandshakeState = *handshakeState 378 } 379 if err != nil { 380 return err 381 } 382 383 // If we had a successful handshake and hs.session is different from 384 // the one already cached - cache a new one. 385 if cacheKey != "" && hs12.session != nil && session != hs12.session { 386 c.config.ClientSessionCache.Put(cacheKey, hs12.session) 387 } 388 return nil 389 } 390 391 func (uconn *UConn) ApplyConfig() error { 392 for _, ext := range uconn.Extensions { 393 err := ext.writeToUConn(uconn) 394 if err != nil { 395 return err 396 } 397 } 398 return nil 399 } 400 401 func (uconn *UConn) MarshalClientHello() error { 402 hello := uconn.HandshakeState.Hello 403 headerLength := 2 + 32 + 1 + len(hello.SessionId) + 404 2 + len(hello.CipherSuites)*2 + 405 1 + len(hello.CompressionMethods) 406 407 extensionsLen := 0 408 var paddingExt *UtlsPaddingExtension 409 for _, ext := range uconn.Extensions { 410 if pe, ok := ext.(*UtlsPaddingExtension); !ok { 411 // If not padding - just add length of extension to total length 412 extensionsLen += ext.Len() 413 } else { 414 // If padding - process it later 415 if paddingExt == nil { 416 paddingExt = pe 417 } else { 418 return errors.New("Multiple padding extensions!") 419 } 420 } 421 } 422 423 if paddingExt != nil { 424 // determine padding extension presence and length 425 paddingExt.Update(headerLength + 4 + extensionsLen + 2) 426 extensionsLen += paddingExt.Len() 427 } 428 429 helloLen := headerLength 430 if len(uconn.Extensions) > 0 { 431 helloLen += 2 + extensionsLen // 2 bytes for extensions' length 432 } 433 434 helloBuffer := bytes.Buffer{} 435 bufferedWriter := bufio.NewWriterSize(&helloBuffer, helloLen+4) // 1 byte for tls record type, 3 for length 436 // We use buffered Writer to avoid checking write errors after every Write(): whenever first error happens 437 // Write() will become noop, and error will be accessible via Flush(), which is called once in the end 438 439 binary.Write(bufferedWriter, binary.BigEndian, typeClientHello) 440 helloLenBytes := []byte{byte(helloLen >> 16), byte(helloLen >> 8), byte(helloLen)} // poor man's uint24 441 binary.Write(bufferedWriter, binary.BigEndian, helloLenBytes) 442 binary.Write(bufferedWriter, binary.BigEndian, hello.Vers) 443 444 binary.Write(bufferedWriter, binary.BigEndian, hello.Random) 445 446 binary.Write(bufferedWriter, binary.BigEndian, uint8(len(hello.SessionId))) 447 binary.Write(bufferedWriter, binary.BigEndian, hello.SessionId) 448 449 binary.Write(bufferedWriter, binary.BigEndian, uint16(len(hello.CipherSuites)<<1)) 450 for _, suite := range hello.CipherSuites { 451 binary.Write(bufferedWriter, binary.BigEndian, suite) 452 } 453 454 binary.Write(bufferedWriter, binary.BigEndian, uint8(len(hello.CompressionMethods))) 455 binary.Write(bufferedWriter, binary.BigEndian, hello.CompressionMethods) 456 457 if len(uconn.Extensions) > 0 { 458 binary.Write(bufferedWriter, binary.BigEndian, uint16(extensionsLen)) 459 for _, ext := range uconn.Extensions { 460 bufferedWriter.ReadFrom(ext) 461 } 462 } 463 464 err := bufferedWriter.Flush() 465 if err != nil { 466 return err 467 } 468 469 if helloBuffer.Len() != 4+helloLen { 470 return errors.New("utls: unexpected ClientHello length. Expected: " + strconv.Itoa(4+helloLen) + 471 ". Got: " + strconv.Itoa(helloBuffer.Len())) 472 } 473 474 hello.Raw = helloBuffer.Bytes() 475 return nil 476 } 477 478 // get current state of cipher and encrypt zeros to get keystream 479 func (uconn *UConn) GetOutKeystream(length int) ([]byte, error) { 480 zeros := make([]byte, length) 481 482 if outCipher, ok := uconn.out.cipher.(cipher.AEAD); ok { 483 // AEAD.Seal() does not mutate internal state, other ciphers might 484 return outCipher.Seal(nil, uconn.out.seq[:], zeros, nil), nil 485 } 486 return nil, errors.New("Could not convert OutCipher to cipher.AEAD") 487 } 488 489 // SetTLSVers sets min and max TLS version in all appropriate places. 490 // Function will use first non-zero version parsed in following order: 491 // 1) Provided minTLSVers, maxTLSVers 492 // 2) specExtensions may have SupportedVersionsExtension 493 // 3) [default] min = TLS 1.0, max = TLS 1.2 494 // 495 // Error is only returned if things are in clearly undesirable state 496 // to help user fix them. 497 func (uconn *UConn) SetTLSVers(minTLSVers, maxTLSVers uint16, specExtensions []TLSExtension) error { 498 if minTLSVers == 0 && maxTLSVers == 0 { 499 // if version is not set explicitly in the ClientHelloSpec, check the SupportedVersions extension 500 supportedVersionsExtensionsPresent := 0 501 for _, e := range specExtensions { 502 switch ext := e.(type) { 503 case *SupportedVersionsExtension: 504 findVersionsInSupportedVersionsExtensions := func(versions []uint16) (uint16, uint16) { 505 // returns (minVers, maxVers) 506 minVers := uint16(0) 507 maxVers := uint16(0) 508 for _, vers := range versions { 509 if vers == GREASE_PLACEHOLDER { 510 continue 511 } 512 if maxVers < vers || maxVers == 0 { 513 maxVers = vers 514 } 515 if minVers > vers || minVers == 0 { 516 minVers = vers 517 } 518 } 519 return minVers, maxVers 520 } 521 522 supportedVersionsExtensionsPresent += 1 523 minTLSVers, maxTLSVers = findVersionsInSupportedVersionsExtensions(ext.Versions) 524 if minTLSVers == 0 && maxTLSVers == 0 { 525 return fmt.Errorf("SupportedVersions extension has invalid Versions field") 526 } // else: proceed 527 } 528 } 529 switch supportedVersionsExtensionsPresent { 530 case 0: 531 // if mandatory for TLS 1.3 extension is not present, just default to 1.2 532 minTLSVers = VersionTLS10 533 maxTLSVers = VersionTLS12 534 case 1: 535 default: 536 return fmt.Errorf("uconn.Extensions contains %v separate SupportedVersions extensions", 537 supportedVersionsExtensionsPresent) 538 } 539 } 540 541 if minTLSVers < VersionTLS10 || minTLSVers > VersionTLS12 { 542 return fmt.Errorf("uTLS does not support 0x%X as min version", minTLSVers) 543 } 544 545 if maxTLSVers < VersionTLS10 || maxTLSVers > VersionTLS13 { 546 return fmt.Errorf("uTLS does not support 0x%X as max version", maxTLSVers) 547 } 548 549 uconn.HandshakeState.Hello.SupportedVersions = makeSupportedVersions(minTLSVers, maxTLSVers) 550 uconn.config.MinVersion = minTLSVers 551 uconn.config.MaxVersion = maxTLSVers 552 553 return nil 554 } 555 556 func (uconn *UConn) SetUnderlyingConn(c net.Conn) { 557 uconn.Conn.conn = c 558 } 559 560 func (uconn *UConn) GetUnderlyingConn() net.Conn { 561 return uconn.Conn.conn 562 } 563 564 // MakeConnWithCompleteHandshake allows to forge both server and client side TLS connections. 565 // Major Hack Alert. 566 func MakeConnWithCompleteHandshake(tcpConn net.Conn, version uint16, cipherSuite uint16, masterSecret []byte, clientRandom []byte, serverRandom []byte, isClient bool) *Conn { 567 tlsConn := &Conn{conn: tcpConn, config: &Config{}, isClient: isClient} 568 cs := cipherSuiteByID(cipherSuite) 569 if cs == nil { 570 return nil 571 } 572 573 // This is mostly borrowed from establishKeys() 574 clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV := 575 keysFromMasterSecret(version, cs, masterSecret, clientRandom, serverRandom, 576 cs.macLen, cs.keyLen, cs.ivLen) 577 578 var clientCipher, serverCipher interface{} 579 var clientHash, serverHash macFunction 580 if cs.cipher != nil { 581 clientCipher = cs.cipher(clientKey, clientIV, true /* for reading */) 582 clientHash = cs.mac(version, clientMAC) 583 serverCipher = cs.cipher(serverKey, serverIV, false /* not for reading */) 584 serverHash = cs.mac(version, serverMAC) 585 } else { 586 clientCipher = cs.aead(clientKey, clientIV) 587 serverCipher = cs.aead(serverKey, serverIV) 588 } 589 590 if isClient { 591 tlsConn.in.prepareCipherSpec(version, serverCipher, serverHash) 592 tlsConn.out.prepareCipherSpec(version, clientCipher, clientHash) 593 } else { 594 tlsConn.in.prepareCipherSpec(version, clientCipher, clientHash) 595 tlsConn.out.prepareCipherSpec(version, serverCipher, serverHash) 596 } 597 598 // skip the handshake states 599 tlsConn.handshakeStatus = 1 600 tlsConn.cipherSuite = cipherSuite 601 tlsConn.haveVers = true 602 tlsConn.vers = version 603 604 // Update to the new cipher specs 605 // and consume the finished messages 606 tlsConn.in.changeCipherSpec() 607 tlsConn.out.changeCipherSpec() 608 609 tlsConn.in.incSeq() 610 tlsConn.out.incSeq() 611 612 return tlsConn 613 } 614 615 func makeSupportedVersions(minVers, maxVers uint16) []uint16 { 616 a := make([]uint16, maxVers-minVers+1) 617 for i := range a { 618 a[i] = maxVers - uint16(i) 619 } 620 return a 621 }