github.com/3andne/restls-client-go@v0.1.6/u_conn.go (about)

     1  // Copyright 2017 Google Inc. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package tls
     6  
     7  import (
     8  	"bufio"
     9  	"bytes"
    10  	"context"
    11  	"crypto/cipher"
    12  	"encoding/binary"
    13  	"errors"
    14  	"fmt"
    15  	"hash"
    16  	"io"
    17  	"net"
    18  	"strconv"
    19  )
    20  
    21  type UConn struct {
    22  	*Conn
    23  
    24  	Extensions    []TLSExtension
    25  	ClientHelloID ClientHelloID
    26  
    27  	ClientHelloBuilt bool
    28  	HandshakeState   PubClientHandshakeState
    29  
    30  	// sessionID may or may not depend on ticket; nil => random
    31  	GetSessionID func(ticket []byte) [32]byte
    32  
    33  	greaseSeed [ssl_grease_last_index]uint16
    34  
    35  	omitSNIExtension bool
    36  
    37  	// certCompressionAlgs represents the set of advertised certificate compression
    38  	// algorithms, as specified in the ClientHello. This is only relevant client-side, for the
    39  	// server certificate. All other forms of certificate compression are unsupported.
    40  	certCompressionAlgs []CertCompressionAlgo
    41  }
    42  
    43  // UClient returns a new uTLS client, with behavior depending on clientHelloID.
    44  // Config CAN be nil, but make sure to eventually specify ServerName.
    45  func UClient(conn net.Conn, config *Config, clientHelloID ClientHelloID) *UConn {
    46  	if config == nil {
    47  		config = &Config{}
    48  	}
    49  	tlsConn := Conn{conn: conn, config: config, isClient: true}
    50  	handshakeState := PubClientHandshakeState{C: &tlsConn, Hello: &PubClientHelloMsg{}}
    51  	uconn := UConn{Conn: &tlsConn, ClientHelloID: clientHelloID, HandshakeState: handshakeState}
    52  	uconn.HandshakeState.uconn = &uconn
    53  	uconn.handshakeFn = uconn.clientHandshake
    54  	initRestlsPlugin(&uconn.in.restlsPlugin, &uconn.out.restlsPlugin) // #Restls#
    55  	return &uconn
    56  }
    57  
    58  // BuildHandshakeState behavior varies based on ClientHelloID and
    59  // whether it was already called before.
    60  // If HelloGolang:
    61  //
    62  //	[only once] make default ClientHello and overwrite existing state
    63  //
    64  // If any other mimicking ClientHelloID is used:
    65  //
    66  //	[only once] make ClientHello based on ID and overwrite existing state
    67  //	[each call] apply uconn.Extensions config to internal crypto/tls structures
    68  //	[each call] marshal ClientHello.
    69  //
    70  // BuildHandshakeState is automatically called before uTLS performs handshake,
    71  // amd should only be called explicitly to inspect/change fields of
    72  // default/mimicked ClientHello.
    73  func (uconn *UConn) BuildHandshakeState() error {
    74  	if uconn.ClientHelloID == HelloGolang {
    75  		panic("Golang ClientHello is disabled") // #Restls#
    76  		if uconn.ClientHelloBuilt {
    77  			return nil
    78  		}
    79  
    80  		// use default Golang ClientHello.
    81  		hello, ecdheKey, err := uconn.makeClientHello()
    82  		if err != nil {
    83  			return err
    84  		}
    85  
    86  		uconn.HandshakeState.Hello = hello.getPublicPtr()
    87  		uconn.HandshakeState.State13.EcdheKey = ecdheKey
    88  		uconn.HandshakeState.C = uconn.Conn
    89  	} else {
    90  		if !uconn.ClientHelloBuilt {
    91  			err := uconn.applyPresetByID(uconn.ClientHelloID)
    92  			if err != nil {
    93  				return err
    94  			}
    95  			if uconn.omitSNIExtension {
    96  				uconn.removeSNIExtension()
    97  			}
    98  		}
    99  
   100  		err := uconn.ApplyConfig()
   101  		if err != nil {
   102  			return err
   103  		}
   104  		err = uconn.MarshalClientHello()
   105  		if err != nil {
   106  			return err
   107  		}
   108  	}
   109  	uconn.ClientHelloBuilt = true
   110  	return nil
   111  }
   112  
   113  // SetSessionState sets the session ticket, which may be preshared or fake.
   114  // If session is nil, the body of session ticket extension will be unset,
   115  // but the extension itself still MAY be present for mimicking purposes.
   116  // Session tickets to be reused - use same cache on following connections.
   117  func (uconn *UConn) SetSessionState(session *ClientSessionState) error {
   118  	var sessionTicket []uint8
   119  	if session != nil {
   120  		sessionTicket = session.ticket
   121  		uconn.HandshakeState.Session = session.session
   122  	}
   123  	uconn.HandshakeState.Hello.TicketSupported = true
   124  	uconn.HandshakeState.Hello.SessionTicket = sessionTicket
   125  
   126  	for _, ext := range uconn.Extensions {
   127  		st, ok := ext.(*SessionTicketExtension)
   128  		if !ok {
   129  			continue
   130  		}
   131  		st.Session = session
   132  		if session != nil {
   133  			if len(session.SessionTicket()) > 0 {
   134  				if uconn.GetSessionID != nil {
   135  					sid := uconn.GetSessionID(session.SessionTicket())
   136  					uconn.HandshakeState.Hello.SessionId = sid[:]
   137  					return nil
   138  				}
   139  			}
   140  			var sessionID [32]byte
   141  			_, err := io.ReadFull(uconn.config.rand(), sessionID[:])
   142  			if err != nil {
   143  				return err
   144  			}
   145  			uconn.HandshakeState.Hello.SessionId = sessionID[:]
   146  		}
   147  		return nil
   148  	}
   149  	return nil
   150  }
   151  
   152  // If you want session tickets to be reused - use same cache on following connections
   153  func (uconn *UConn) SetSessionCache(cache ClientSessionCache) {
   154  	uconn.config.ClientSessionCache = cache
   155  	uconn.HandshakeState.Hello.TicketSupported = true
   156  }
   157  
   158  // SetClientRandom sets client random explicitly.
   159  // BuildHandshakeFirst() must be called before SetClientRandom.
   160  // r must to be 32 bytes long.
   161  func (uconn *UConn) SetClientRandom(r []byte) error {
   162  	if len(r) != 32 {
   163  		return errors.New("Incorrect client random length! Expected: 32, got: " + strconv.Itoa(len(r)))
   164  	} else {
   165  		uconn.HandshakeState.Hello.Random = make([]byte, 32)
   166  		copy(uconn.HandshakeState.Hello.Random, r)
   167  		return nil
   168  	}
   169  }
   170  
   171  func (uconn *UConn) SetSNI(sni string) {
   172  	hname := hostnameInSNI(sni)
   173  	uconn.config.ServerName = hname
   174  	for _, ext := range uconn.Extensions {
   175  		sniExt, ok := ext.(*SNIExtension)
   176  		if ok {
   177  			sniExt.ServerName = hname
   178  		}
   179  	}
   180  }
   181  
   182  // RemoveSNIExtension removes SNI from the list of extensions sent in ClientHello
   183  // It returns an error when used with HelloGolang ClientHelloID
   184  func (uconn *UConn) RemoveSNIExtension() error {
   185  	if uconn.ClientHelloID == HelloGolang {
   186  		return fmt.Errorf("cannot call RemoveSNIExtension on a UConn with a HelloGolang ClientHelloID")
   187  	}
   188  	uconn.omitSNIExtension = true
   189  	return nil
   190  }
   191  
   192  func (uconn *UConn) removeSNIExtension() {
   193  	filteredExts := make([]TLSExtension, 0, len(uconn.Extensions))
   194  	for _, e := range uconn.Extensions {
   195  		if _, ok := e.(*SNIExtension); !ok {
   196  			filteredExts = append(filteredExts, e)
   197  		}
   198  	}
   199  	uconn.Extensions = filteredExts
   200  }
   201  
   202  // Handshake runs the client handshake using given clientHandshakeState
   203  // Requires hs.hello, and, optionally, hs.session to be set.
   204  func (c *UConn) Handshake() error {
   205  	return c.HandshakeContext(context.Background())
   206  }
   207  
   208  // HandshakeContext runs the client or server handshake
   209  // protocol if it has not yet been run.
   210  //
   211  // The provided Context must be non-nil. If the context is canceled before
   212  // the handshake is complete, the handshake is interrupted and an error is returned.
   213  // Once the handshake has completed, cancellation of the context will not affect the
   214  // connection.
   215  func (c *UConn) HandshakeContext(ctx context.Context) error {
   216  	// Delegate to unexported method for named return
   217  	// without confusing documented signature.
   218  	return c.handshakeContext(ctx)
   219  }
   220  
// handshakeContext implements HandshakeContext. The named result ret lets the
// deferred interrupter cleanup replace the returned error with the context's
// error when ctx is canceled before the handshake finishes.
//
// For client connections it first calls BuildHandshakeState, then runs
// c.handshakeFn and records its result in c.handshakeErr. Later calls return
// that recorded result without running the handshake again.
func (c *UConn) handshakeContext(ctx context.Context) (ret error) {
	// Fast sync/atomic-based exit if there is no handshake in flight and the
	// last one succeeded without an error. Avoids the expensive context setup
	// and mutex for most Read and Write calls.
	if c.isHandshakeComplete.Load() {
		return nil
	}

	handshakeCtx, cancel := context.WithCancel(ctx)
	// Note: defer this before starting the "interrupter" goroutine
	// so that we can tell the difference between the input being canceled and
	// this cancellation. In the former case, we need to close the connection.
	defer cancel()

	// Start the "interrupter" goroutine, if this context might be canceled.
	// (The background context cannot).
	//
	// The interrupter goroutine waits for the input context to be done and
	// closes the connection if this happens before the function returns.
	if c.quic != nil {
		// QUIC: no goroutine; the QUIC state is handed the cancellation
		// channel and cancel func instead.
		c.quic.cancelc = handshakeCtx.Done()
		c.quic.cancel = cancel
	} else if ctx.Done() != nil {
		done := make(chan struct{})
		// Capacity 1: the interrupter sends exactly one value, which the
		// deferred func below receives.
		interruptRes := make(chan error, 1)
		defer func() {
			close(done)
			if ctxErr := <-interruptRes; ctxErr != nil {
				// Return context error to user.
				ret = ctxErr
			}
		}()
		go func() {
			select {
			case <-handshakeCtx.Done():
				// Close the connection, discarding the error
				_ = c.conn.Close()
				interruptRes <- handshakeCtx.Err()
			case <-done:
				interruptRes <- nil
			}
		}()
	}

	c.handshakeMutex.Lock()
	defer c.handshakeMutex.Unlock()

	// A previous attempt's outcome is reused: return its error, or report
	// success, without running the handshake again.
	if err := c.handshakeErr; err != nil {
		return err
	}
	if c.isHandshakeComplete.Load() {
		return nil
	}

	c.in.Lock()
	defer c.in.Unlock()

	// [uTLS section begins]
	// Build (or re-marshal) the ClientHello before running the handshake.
	// NOTE(review): an error here is returned without being stored in
	// c.handshakeErr, and on this path the QUIC channels closed further down
	// are left open — confirm callers tolerate that.
	if c.isClient {
		err := c.BuildHandshakeState()
		if err != nil {
			return err
		}
	}
	// [uTLS section ends]
	c.handshakeErr = c.handshakeFn(handshakeCtx)
	if c.handshakeErr == nil {
		c.handshakes++
	} else {
		// If an error occurred during the handshake try to flush the
		// alert that might be left in the buffer.
		c.flush()
	}

	// Sanity checks: the recorded error and the completion flag must agree.
	if c.handshakeErr == nil && !c.isHandshakeComplete.Load() {
		c.handshakeErr = errors.New("tls: internal error: handshake should have had a result")
	}
	if c.handshakeErr != nil && c.isHandshakeComplete.Load() {
		panic("tls: internal error: handshake returned an error but is marked successful")
	}

	if c.quic != nil {
		if c.handshakeErr == nil {
			c.quicHandshakeComplete()
			// Provide the 1-RTT read secret now that the handshake is complete.
			// The QUIC layer MUST NOT decrypt 1-RTT packets prior to completing
			// the handshake (RFC 9001, Section 5.7).
			c.quicSetReadSecret(QUICEncryptionLevelApplication, c.cipherSuite, c.in.trafficSecret)
		} else {
			var a alert
			c.out.Lock()
			if !errors.As(c.out.err, &a) {
				a = alertInternalError
			}
			c.out.Unlock()
			// Return an error which wraps both the handshake error and
			// any alert error we may have sent, or alertInternalError
			// if we didn't send an alert.
			// Truncate the text of the alert to 0 characters.
			c.handshakeErr = fmt.Errorf("%w%.0w", c.handshakeErr, AlertError(a))
		}
		close(c.quic.blockedc)
		close(c.quic.signalc)
	}

	return c.handshakeErr
}
   328  
   329  // Copy-pasted from tls.Conn in its entirety. But c.Handshake() is now utls' one, not tls.
   330  // Write writes data to the connection.
   331  func (c *UConn) Write(b []byte) (int, error) {
   332  	// interlock with Close below
   333  	for {
   334  		x := c.activeCall.Load()
   335  		if x&1 != 0 {
   336  			return 0, net.ErrClosed
   337  		}
   338  		if c.activeCall.CompareAndSwap(x, x+2) {
   339  			defer c.activeCall.Add(-2)
   340  			break
   341  		}
   342  	}
   343  
   344  	if err := c.Handshake(); err != nil {
   345  		return 0, err
   346  	}
   347  
   348  	c.out.Lock()
   349  	defer c.out.Unlock()
   350  
   351  	if err := c.out.err; err != nil {
   352  		return 0, err
   353  	}
   354  
   355  	if !c.isHandshakeComplete.Load() {
   356  		return 0, alertInternalError
   357  	}
   358  
   359  	if c.closeNotifySent {
   360  		return 0, errShutdown
   361  	}
   362  
   363  	// SSL 3.0 and TLS 1.0 are susceptible to a chosen-plaintext
   364  	// attack when using block mode ciphers due to predictable IVs.
   365  	// This can be prevented by splitting each Application Data
   366  	// record into two records, effectively randomizing the IV.
   367  	//
   368  	// https://www.openssl.org/~bodo/tls-cbc.txt
   369  	// https://bugzilla.mozilla.org/show_bug.cgi?id=665814
   370  	// https://www.imperialviolet.org/2012/01/15/beastfollowup.html
   371  
   372  	var m int
   373  	if len(b) > 1 && c.vers <= VersionTLS10 {
   374  		if _, ok := c.out.cipher.(cipher.BlockMode); ok {
   375  			n, err := c.writeRecordLocked(recordTypeApplicationData, b[:1])
   376  			if err != nil {
   377  				return n, c.out.setErrorLocked(err)
   378  			}
   379  			m, b = 1, b[1:]
   380  		}
   381  	}
   382  
   383  	// #Restls# Begin
   384  	if c.restlsAuthed {
   385  		n, err := c.writeRestlsApplicationRecord(b)
   386  		return n, c.out.setErrorLocked(err)
   387  	}
   388  	// #Restls# End
   389  
   390  	n, err := c.writeRecordLocked(recordTypeApplicationData, b)
   391  	return n + m, c.out.setErrorLocked(err)
   392  }
   393  
// clientHandshake performs the client side of the handshake for a UConn. It is
// installed as handshakeFn by UClient.
//
// It does not build a new ClientHello. Instead it sends the one already stored
// in c.HandshakeState.Hello; handshakeContext calls BuildHandshakeState before
// handshakeFn. Before sending, it applies the Restls session ID handling
// (generateSessionIDForTLS13/12) and patches the session ID into the raw
// bytes. It then reads the ServerHello and picks the version. Finally it
// continues with the TLS 1.3 or TLS 1.2 handshake state derived from
// c.HandshakeState, copying the resulting state back into c.HandshakeState.
func (c *UConn) clientHandshake(ctx context.Context) (err error) {
	// [uTLS section begins]
	hello := c.HandshakeState.Hello.getPrivatePtr()
	// Publish the private hello back to the public state when we return.
	defer func() { c.HandshakeState.Hello = hello.getPublicPtr() }()

	// A session set explicitly by the caller must not be overwritten by one
	// from the session cache (see below).
	sessionIsAlreadySet := c.HandshakeState.Session != nil

	// after this point exactly 1 out of 2 HandshakeState pointers is non-nil,
	// useTLS13 variable tells which pointer
	// NOTE(review): the comment above is inherited; there is no useTLS13
	// variable here — the version is chosen below by pickTLSVersion.
	// [uTLS section ends]

	if c.config == nil {
		c.config = defaultConfig()
	}

	// This may be a renegotiation handshake, in which case some fields
	// need to be reset.
	c.didResume = false

	// [uTLS section begins]
	// don't make new ClientHello, use hs.hello
	// preserve the checks from beginning and end of makeClientHello()
	if len(c.config.ServerName) == 0 && !c.config.InsecureSkipVerify && len(c.config.InsecureServerNameToVerify) == 0 {
		return errors.New("tls: at least one of ServerName, InsecureSkipVerify or InsecureServerNameToVerify must be specified in the tls.Config")
	}

	// Each ALPN protocol is encoded as a 1-byte length plus the name; the
	// whole list must fit a 16-bit length.
	nextProtosLength := 0
	for _, proto := range c.config.NextProtos {
		if l := len(proto); l == 0 || l > 255 {
			return errors.New("tls: invalid NextProtos value")
		} else {
			nextProtosLength += 1 + l
		}
	}

	if nextProtosLength > 0xffff {
		return errors.New("tls: NextProtos values too large")
	}

	if c.handshakes > 0 {
		hello.secureRenegotiation = c.clientFinished[:]
	}
	// [uTLS section ends]

	// #Restls# Begin
	debugf(c.Conn, "hello.keyShares(private) %v\n", hello.keyShares)
	supportTLS13 := AnyTrue(hello.supportedVersions, func(v uint16) bool {
		return v == VersionTLS13
	})
	if c.config.VersionHint == TLS13Hint && supportTLS13 {
		debugf(c.Conn, "u_conn generateSessionIDForTLS13\n")
		c.generateSessionIDForTLS13(hello)
	}
	// #Restls# End

	session, earlySecret, binderKey, err := c.loadSession(hello)
	if err != nil {
		return err
	}
	if session != nil {
		debugf(c.Conn, "session loaded\n")
		defer func() {
			// If we got a handshake failure when resuming a session, throw away
			// the session ticket. See RFC 5077, Section 3.2.
			//
			// RFC 8446 makes no mention of dropping tickets on failure, but it
			// does require servers to abort on invalid binders, so we need to
			// delete tickets to recover from a corrupted PSK.
			if err != nil {
				if cacheKey := c.clientSessionCacheKey(); cacheKey != "" {
					c.config.ClientSessionCache.Put(cacheKey, nil)
				}
			}
		}()
	}

	// #Restls# Begin
	// Note: && binds tighter than ||, so this is
	// TLS12Hint || (TLS13Hint && !supportTLS13).
	if c.config.VersionHint == TLS12Hint || c.config.VersionHint == TLS13Hint && !supportTLS13 {
		debugf(c.Conn, "c.generateSessionIDForTLS12\n")
		if err := c.generateSessionIDForTLS12(hello); err != nil {
			return err
		}
	}
	// Offset 39 = 1-byte handshake type + 3-byte length + 2-byte version +
	// 32-byte random + 1-byte session ID length, matching the layout written
	// by MarshalClientHello.
	// NOTE(review): this assumes hello.raw already has a 32-byte session ID
	// slot of the same length as hello.sessionId (and is at least 71 bytes,
	// for the debugf slice) — confirm; a shorter slot would overwrite the
	// cipher suites.
	debugf(c.Conn, "%v, %v, %v\n", c.HandshakeState.Hello.SessionId, hello.sessionId, hello.raw[39:39+32])
	copy(hello.raw[39:], hello.sessionId) // patch session id
	// #Restls# End

	cacheKey := c.clientSessionCacheKey()
	if c.config.ClientSessionCache != nil {
		cs, ok := c.config.ClientSessionCache.Get(cacheKey)
		// NOTE(review): SetSessionState writes to c.HandshakeState.Hello, while
		// the message sent below is the private `hello` obtained above — confirm
		// the two share storage, otherwise this does not affect the ClientHello.
		if !sessionIsAlreadySet && ok { // uTLS: do not overwrite already set session
			err = c.SetSessionState(cs)
			if err != nil {
				return
			}
		}
	}

	if _, err := c.writeHandshakeRecord(hello, nil); err != nil {
		return err
	}

	msg, err := c.readHandshake(nil)
	if err != nil {
		return err
	}

	serverHello, ok := msg.(*serverHelloMsg)
	if !ok {
		c.sendAlert(alertUnexpectedMessage)
		return unexpectedMessageError(serverHello, msg)
	}

	if err := c.pickTLSVersion(serverHello); err != nil {
		return err
	}

	c.serverRandom = serverHello.random // #Restls#

	// uTLS: do not create new handshakeState, use existing one
	if c.vers == VersionTLS13 {
		hs13 := c.HandshakeState.toPrivate13()
		hs13.serverHello = serverHello
		hs13.hello = hello
		if !sessionIsAlreadySet {
			hs13.earlySecret = earlySecret
			hs13.binderKey = binderKey
		}
		hs13.ctx = ctx
		// In TLS 1.3, session tickets are delivered after the handshake.
		err = hs13.handshake()
		if handshakeState := hs13.toPublic13(); handshakeState != nil {
			c.HandshakeState = *handshakeState
		}
		return err
	}

	hs12 := c.HandshakeState.toPrivate12()
	hs12.serverHello = serverHello
	hs12.hello = hello
	hs12.ctx = ctx
	err = hs12.handshake()
	if handshakeState := hs12.toPublic12(); handshakeState != nil {
		c.HandshakeState = *handshakeState
	}
	if err != nil {
		return err
	}

	// If we had a successful handshake and hs.session is different from
	// the one already cached - cache a new one.
	if cacheKey != "" && hs12.session != nil && session != hs12.session {
		hs12cs := &ClientSessionState{
			ticket:  hs12.ticket,
			session: hs12.session,
		}

		c.config.ClientSessionCache.Put(cacheKey, hs12cs)
	}
	return nil
}
   557  
   558  func (uconn *UConn) ApplyConfig() error {
   559  	for _, ext := range uconn.Extensions {
   560  		err := ext.writeToUConn(uconn)
   561  		if err != nil {
   562  			return err
   563  		}
   564  	}
   565  	return nil
   566  }
   567  
   568  func (uconn *UConn) MarshalClientHello() error {
   569  	hello := uconn.HandshakeState.Hello
   570  	headerLength := 2 + 32 + 1 + len(hello.SessionId) +
   571  		2 + len(hello.CipherSuites)*2 +
   572  		1 + len(hello.CompressionMethods)
   573  
   574  	extensionsLen := 0
   575  	var paddingExt *UtlsPaddingExtension
   576  	for _, ext := range uconn.Extensions {
   577  		if pe, ok := ext.(*UtlsPaddingExtension); !ok {
   578  			// If not padding - just add length of extension to total length
   579  			extensionsLen += ext.Len()
   580  		} else {
   581  			// If padding - process it later
   582  			if paddingExt == nil {
   583  				paddingExt = pe
   584  			} else {
   585  				return errors.New("multiple padding extensions!")
   586  			}
   587  		}
   588  	}
   589  
   590  	if paddingExt != nil {
   591  		// determine padding extension presence and length
   592  		paddingExt.Update(headerLength + 4 + extensionsLen + 2)
   593  		extensionsLen += paddingExt.Len()
   594  	}
   595  
   596  	helloLen := headerLength
   597  	if len(uconn.Extensions) > 0 {
   598  		helloLen += 2 + extensionsLen // 2 bytes for extensions' length
   599  	}
   600  
   601  	helloBuffer := bytes.Buffer{}
   602  	bufferedWriter := bufio.NewWriterSize(&helloBuffer, helloLen+4) // 1 byte for tls record type, 3 for length
   603  	// We use buffered Writer to avoid checking write errors after every Write(): whenever first error happens
   604  	// Write() will become noop, and error will be accessible via Flush(), which is called once in the end
   605  
   606  	binary.Write(bufferedWriter, binary.BigEndian, typeClientHello)
   607  	helloLenBytes := []byte{byte(helloLen >> 16), byte(helloLen >> 8), byte(helloLen)} // poor man's uint24
   608  	binary.Write(bufferedWriter, binary.BigEndian, helloLenBytes)
   609  	binary.Write(bufferedWriter, binary.BigEndian, hello.Vers)
   610  
   611  	binary.Write(bufferedWriter, binary.BigEndian, hello.Random)
   612  
   613  	binary.Write(bufferedWriter, binary.BigEndian, uint8(len(hello.SessionId)))
   614  	binary.Write(bufferedWriter, binary.BigEndian, hello.SessionId)
   615  
   616  	binary.Write(bufferedWriter, binary.BigEndian, uint16(len(hello.CipherSuites)<<1))
   617  	for _, suite := range hello.CipherSuites {
   618  		binary.Write(bufferedWriter, binary.BigEndian, suite)
   619  	}
   620  
   621  	binary.Write(bufferedWriter, binary.BigEndian, uint8(len(hello.CompressionMethods)))
   622  	binary.Write(bufferedWriter, binary.BigEndian, hello.CompressionMethods)
   623  
   624  	if len(uconn.Extensions) > 0 {
   625  		binary.Write(bufferedWriter, binary.BigEndian, uint16(extensionsLen))
   626  		for _, ext := range uconn.Extensions {
   627  			bufferedWriter.ReadFrom(ext)
   628  		}
   629  	}
   630  
   631  	err := bufferedWriter.Flush()
   632  	if err != nil {
   633  		return err
   634  	}
   635  
   636  	if helloBuffer.Len() != 4+helloLen {
   637  		return errors.New("utls: unexpected ClientHello length. Expected: " + strconv.Itoa(4+helloLen) +
   638  			". Got: " + strconv.Itoa(helloBuffer.Len()))
   639  	}
   640  
   641  	hello.Raw = helloBuffer.Bytes()
   642  	return nil
   643  }
   644  
   645  // get current state of cipher and encrypt zeros to get keystream
   646  func (uconn *UConn) GetOutKeystream(length int) ([]byte, error) {
   647  	zeros := make([]byte, length)
   648  
   649  	if outCipher, ok := uconn.out.cipher.(cipher.AEAD); ok {
   650  		// AEAD.Seal() does not mutate internal state, other ciphers might
   651  		return outCipher.Seal(nil, uconn.out.seq[:], zeros, nil), nil
   652  	}
   653  	return nil, errors.New("could not convert OutCipher to cipher.AEAD")
   654  }
   655  
   656  // SetTLSVers sets min and max TLS version in all appropriate places.
   657  // Function will use first non-zero version parsed in following order:
   658  //  1. Provided minTLSVers, maxTLSVers
   659  //  2. specExtensions may have SupportedVersionsExtension
   660  //  3. [default] min = TLS 1.0, max = TLS 1.2
   661  //
   662  // Error is only returned if things are in clearly undesirable state
   663  // to help user fix them.
   664  func (uconn *UConn) SetTLSVers(minTLSVers, maxTLSVers uint16, specExtensions []TLSExtension) error {
   665  	if minTLSVers == 0 && maxTLSVers == 0 {
   666  		// if version is not set explicitly in the ClientHelloSpec, check the SupportedVersions extension
   667  		supportedVersionsExtensionsPresent := 0
   668  		for _, e := range specExtensions {
   669  			switch ext := e.(type) {
   670  			case *SupportedVersionsExtension:
   671  				findVersionsInSupportedVersionsExtensions := func(versions []uint16) (uint16, uint16) {
   672  					// returns (minVers, maxVers)
   673  					minVers := uint16(0)
   674  					maxVers := uint16(0)
   675  					for _, vers := range versions {
   676  						if isGREASEUint16(vers) {
   677  							continue
   678  						}
   679  						if maxVers < vers || maxVers == 0 {
   680  							maxVers = vers
   681  						}
   682  						if minVers > vers || minVers == 0 {
   683  							minVers = vers
   684  						}
   685  					}
   686  					return minVers, maxVers
   687  				}
   688  
   689  				supportedVersionsExtensionsPresent += 1
   690  				minTLSVers, maxTLSVers = findVersionsInSupportedVersionsExtensions(ext.Versions)
   691  				if minTLSVers == 0 && maxTLSVers == 0 {
   692  					return fmt.Errorf("SupportedVersions extension has invalid Versions field")
   693  				} // else: proceed
   694  			}
   695  		}
   696  		switch supportedVersionsExtensionsPresent {
   697  		case 0:
   698  			// if mandatory for TLS 1.3 extension is not present, just default to 1.2
   699  			minTLSVers = VersionTLS10
   700  			maxTLSVers = VersionTLS12
   701  		case 1:
   702  		default:
   703  			return fmt.Errorf("uconn.Extensions contains %v separate SupportedVersions extensions",
   704  				supportedVersionsExtensionsPresent)
   705  		}
   706  	}
   707  
   708  	if minTLSVers < VersionTLS10 || minTLSVers > VersionTLS13 {
   709  		return fmt.Errorf("uTLS does not support 0x%X as min version", minTLSVers)
   710  	}
   711  
   712  	if maxTLSVers < VersionTLS10 || maxTLSVers > VersionTLS13 {
   713  		return fmt.Errorf("uTLS does not support 0x%X as max version", maxTLSVers)
   714  	}
   715  
   716  	uconn.HandshakeState.Hello.SupportedVersions = makeSupportedVersions(minTLSVers, maxTLSVers)
   717  	uconn.config.MinVersion = minTLSVers
   718  	uconn.config.MaxVersion = maxTLSVers
   719  
   720  	return nil
   721  }
   722  
   723  func (uconn *UConn) SetUnderlyingConn(c net.Conn) {
   724  	uconn.Conn.conn = c
   725  }
   726  
   727  func (uconn *UConn) GetUnderlyingConn() net.Conn {
   728  	return uconn.Conn.conn
   729  }
   730  
   731  // MakeConnWithCompleteHandshake allows to forge both server and client side TLS connections.
   732  // Major Hack Alert.
   733  func MakeConnWithCompleteHandshake(tcpConn net.Conn, version uint16, cipherSuite uint16, masterSecret []byte, clientRandom []byte, serverRandom []byte, isClient bool) *Conn {
   734  	tlsConn := &Conn{conn: tcpConn, config: &Config{}, isClient: isClient}
   735  	cs := cipherSuiteByID(cipherSuite)
   736  	if cs != nil {
   737  		// This is mostly borrowed from establishKeys()
   738  		clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV :=
   739  			keysFromMasterSecret(version, cs, masterSecret, clientRandom, serverRandom,
   740  				cs.macLen, cs.keyLen, cs.ivLen)
   741  
   742  		var clientCipher, serverCipher interface{}
   743  		var clientHash, serverHash hash.Hash
   744  		if cs.cipher != nil {
   745  			clientCipher = cs.cipher(clientKey, clientIV, true /* for reading */)
   746  			clientHash = cs.mac(clientMAC)
   747  			serverCipher = cs.cipher(serverKey, serverIV, false /* not for reading */)
   748  			serverHash = cs.mac(serverMAC)
   749  		} else {
   750  			clientCipher = cs.aead(clientKey, clientIV)
   751  			serverCipher = cs.aead(serverKey, serverIV)
   752  		}
   753  
   754  		if isClient {
   755  			tlsConn.in.prepareCipherSpec(version, serverCipher, serverHash)
   756  			tlsConn.out.prepareCipherSpec(version, clientCipher, clientHash)
   757  		} else {
   758  			tlsConn.in.prepareCipherSpec(version, clientCipher, clientHash)
   759  			tlsConn.out.prepareCipherSpec(version, serverCipher, serverHash)
   760  		}
   761  
   762  		// skip the handshake states
   763  		tlsConn.isHandshakeComplete.Store(true)
   764  		tlsConn.cipherSuite = cipherSuite
   765  		tlsConn.haveVers = true
   766  		tlsConn.vers = version
   767  
   768  		// Update to the new cipher specs
   769  		// and consume the finished messages
   770  		tlsConn.in.changeCipherSpec()
   771  		tlsConn.out.changeCipherSpec()
   772  
   773  		tlsConn.in.incSeq()
   774  		tlsConn.out.incSeq()
   775  
   776  		return tlsConn
   777  	} else {
   778  		// TODO: Support TLS 1.3 Cipher Suites
   779  		return nil
   780  	}
   781  }
   782  
   783  func makeSupportedVersions(minVers, maxVers uint16) []uint16 {
   784  	a := make([]uint16, maxVers-minVers+1)
   785  	for i := range a {
   786  		a[i] = maxVers - uint16(i)
   787  	}
   788  	return a
   789  }
   790  
   791  // Extending (*Conn).readHandshake() to support more customized handshake messages.
   792  func (c *Conn) utlsHandshakeMessageType(msgType byte) (handshakeMessage, error) {
   793  	switch msgType {
   794  	case utlsTypeCompressedCertificate:
   795  		return new(utlsCompressedCertificateMsg), nil
   796  	case utlsTypeEncryptedExtensions:
   797  		if c.isClient {
   798  			return new(encryptedExtensionsMsg), nil
   799  		} else {
   800  			return new(utlsClientEncryptedExtensionsMsg), nil
   801  		}
   802  	default:
   803  		return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
   804  	}
   805  }
   806  
   807  // Extending (*Conn).connectionStateLocked()
   808  func (c *Conn) utlsConnectionStateLocked(state *ConnectionState) {
   809  	state.PeerApplicationSettings = c.utls.peerApplicationSettings
   810  }
   811  
   812  type utlsConnExtraFields struct {
   813  	hasApplicationSettings   bool
   814  	peerApplicationSettings  []byte
   815  	localApplicationSettings []byte
   816  }