github.com/psiphon-Labs/psiphon-tunnel-core@v2.0.28+incompatible/psiphon/common/quic/gquic-go/client.go (about)

     1  package gquic
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"crypto/tls"
     7  	"errors"
     8  	"fmt"
     9  	"net"
    10  	"sync"
    11  
    12  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/quic/gquic-go/internal/handshake"
    13  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/quic/gquic-go/internal/protocol"
    14  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/quic/gquic-go/internal/utils"
    15  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/quic/gquic-go/internal/wire"
    16  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/quic/gquic-go/qerr"
    17  	"github.com/bifurcation/mint"
    18  )
    19  
    20  type client struct {
    21  	mutex sync.Mutex
    22  
    23  	conn connection
    24  	// If the client is created with DialAddr, we create a packet conn.
    25  	// If it is started with Dial, we take a packet conn as a parameter.
    26  	createdPacketConn bool
    27  
    28  	packetHandlers packetHandlerManager
    29  
    30  	token []byte
    31  
    32  	versionNegotiated                bool // has the server accepted our version
    33  	receivedVersionNegotiationPacket bool
    34  	negotiatedVersions               []protocol.VersionNumber // the list of versions from the version negotiation packet
    35  
    36  	tlsConf  *tls.Config
    37  	mintConf *mint.Config
    38  	config   *Config
    39  
    40  	srcConnID  protocol.ConnectionID
    41  	destConnID protocol.ConnectionID
    42  
    43  	initialVersion protocol.VersionNumber
    44  	version        protocol.VersionNumber
    45  
    46  	handshakeChan chan struct{}
    47  	closeCallback func(protocol.ConnectionID)
    48  
    49  	session quicSession
    50  
    51  	logger utils.Logger
    52  }
    53  
    54  var _ packetHandler = &client{}
    55  
    56  var (
    57  	// make it possible to mock connection ID generation in the tests
    58  	generateConnectionID           = protocol.GenerateConnectionID
    59  	generateConnectionIDForInitial = protocol.GenerateConnectionIDForInitial
    60  	errCloseSessionForNewVersion   = errors.New("closing session in order to recreate it with a new version")
    61  	errCloseSessionForRetry        = errors.New("closing session in response to a stateless retry")
    62  )
    63  
    64  // DialAddr establishes a new QUIC connection to a server.
    65  // The hostname for SNI is taken from the given address.
    66  func DialAddr(
    67  	addr string,
    68  	tlsConf *tls.Config,
    69  	config *Config,
    70  ) (Session, error) {
    71  	return DialAddrContext(context.Background(), addr, tlsConf, config)
    72  }
    73  
    74  // DialAddrContext establishes a new QUIC connection to a server using the provided context.
    75  // The hostname for SNI is taken from the given address.
    76  func DialAddrContext(
    77  	ctx context.Context,
    78  	addr string,
    79  	tlsConf *tls.Config,
    80  	config *Config,
    81  ) (Session, error) {
    82  	udpAddr, err := net.ResolveUDPAddr("udp", addr)
    83  	if err != nil {
    84  		return nil, err
    85  	}
    86  	udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
    87  	if err != nil {
    88  		return nil, err
    89  	}
    90  	return dialContext(ctx, udpConn, udpAddr, addr, tlsConf, config, true)
    91  }
    92  
    93  // Dial establishes a new QUIC connection to a server using a net.PacketConn.
    94  // The host parameter is used for SNI.
    95  func Dial(
    96  	pconn net.PacketConn,
    97  	remoteAddr net.Addr,
    98  	host string,
    99  	tlsConf *tls.Config,
   100  	config *Config,
   101  ) (Session, error) {
   102  	return DialContext(context.Background(), pconn, remoteAddr, host, tlsConf, config)
   103  }
   104  
   105  // DialContext establishes a new QUIC connection to a server using a net.PacketConn using the provided context.
   106  // The host parameter is used for SNI.
   107  func DialContext(
   108  	ctx context.Context,
   109  	pconn net.PacketConn,
   110  	remoteAddr net.Addr,
   111  	host string,
   112  	tlsConf *tls.Config,
   113  	config *Config,
   114  ) (Session, error) {
   115  	return dialContext(ctx, pconn, remoteAddr, host, tlsConf, config, false)
   116  }
   117  
   118  func dialContext(
   119  	ctx context.Context,
   120  	pconn net.PacketConn,
   121  	remoteAddr net.Addr,
   122  	host string,
   123  	tlsConf *tls.Config,
   124  	config *Config,
   125  	createdPacketConn bool,
   126  ) (Session, error) {
   127  	// [Psiphon]
   128  	// We call DialContext as we need to create a custom net.PacketConn.
   129  	// There is one custom net.PacketConn per QUIC connection, which
   130  	// satisfies the gQUIC 44 constraint.
   131  	config = populateClientConfig(config, true)
   132  	/*
   133  			config = populateClientConfig(config, createdPacketConn)
   134  			if !createdPacketConn {
   135  				for _, v := range config.Versions {
   136  					if v == protocol.Version44 {
   137  						return nil, errors.New("Cannot multiplex connections using gQUIC 44, see https://groups.google.com/a/chromium.org/forum/#!topic/proto-quic/pE9NlLLjizE. Please disable gQUIC 44 in the quic.Config, or use DialAddr")
   138  					}
   139  				}
   140  			}
   141  		}
   142  	*/
   143  	// [Psiphon]
   144  	packetHandlers, err := getMultiplexer().AddConn(pconn, config.ConnectionIDLength)
   145  	if err != nil {
   146  		return nil, err
   147  	}
   148  	c, err := newClient(pconn, remoteAddr, config, tlsConf, host, packetHandlers.Remove, createdPacketConn)
   149  	if err != nil {
   150  		return nil, err
   151  	}
   152  	c.packetHandlers = packetHandlers
   153  	if err := c.dial(ctx); err != nil {
   154  		return nil, err
   155  	}
   156  	return c.session, nil
   157  }
   158  
   159  func newClient(
   160  	pconn net.PacketConn,
   161  	remoteAddr net.Addr,
   162  	config *Config,
   163  	tlsConf *tls.Config,
   164  	host string,
   165  	closeCallback func(protocol.ConnectionID),
   166  	createdPacketConn bool,
   167  ) (*client, error) {
   168  	if tlsConf == nil {
   169  		tlsConf = &tls.Config{}
   170  	}
   171  	if tlsConf.ServerName == "" {
   172  		var err error
   173  		tlsConf.ServerName, _, err = net.SplitHostPort(host)
   174  		if err != nil {
   175  			return nil, err
   176  		}
   177  	}
   178  
   179  	// check that all versions are actually supported
   180  	if config != nil {
   181  		for _, v := range config.Versions {
   182  			if !protocol.IsValidVersion(v) {
   183  				return nil, fmt.Errorf("%s is not a valid QUIC version", v)
   184  			}
   185  		}
   186  	}
   187  	onClose := func(protocol.ConnectionID) {}
   188  	if closeCallback != nil {
   189  		onClose = closeCallback
   190  	}
   191  	c := &client{
   192  		conn:              &conn{pconn: pconn, currentAddr: remoteAddr},
   193  		createdPacketConn: createdPacketConn,
   194  		tlsConf:           tlsConf,
   195  		config:            config,
   196  		version:           config.Versions[0],
   197  		handshakeChan:     make(chan struct{}),
   198  		closeCallback:     onClose,
   199  		logger:            utils.DefaultLogger.WithPrefix("client"),
   200  	}
   201  	return c, c.generateConnectionIDs()
   202  }
   203  
   204  // populateClientConfig populates fields in the quic.Config with their default values, if none are set
   205  // it may be called with nil
   206  func populateClientConfig(config *Config, createdPacketConn bool) *Config {
   207  	if config == nil {
   208  		config = &Config{}
   209  	}
   210  	versions := config.Versions
   211  	if len(versions) == 0 {
   212  		versions = protocol.SupportedVersions
   213  	}
   214  
   215  	handshakeTimeout := protocol.DefaultHandshakeTimeout
   216  	if config.HandshakeTimeout != 0 {
   217  		handshakeTimeout = config.HandshakeTimeout
   218  	}
   219  	idleTimeout := protocol.DefaultIdleTimeout
   220  	if config.IdleTimeout != 0 {
   221  		idleTimeout = config.IdleTimeout
   222  	}
   223  
   224  	maxReceiveStreamFlowControlWindow := config.MaxReceiveStreamFlowControlWindow
   225  	if maxReceiveStreamFlowControlWindow == 0 {
   226  		maxReceiveStreamFlowControlWindow = protocol.DefaultMaxReceiveStreamFlowControlWindowClient
   227  	}
   228  	maxReceiveConnectionFlowControlWindow := config.MaxReceiveConnectionFlowControlWindow
   229  	if maxReceiveConnectionFlowControlWindow == 0 {
   230  		maxReceiveConnectionFlowControlWindow = protocol.DefaultMaxReceiveConnectionFlowControlWindowClient
   231  	}
   232  	maxIncomingStreams := config.MaxIncomingStreams
   233  	if maxIncomingStreams == 0 {
   234  		maxIncomingStreams = protocol.DefaultMaxIncomingStreams
   235  	} else if maxIncomingStreams < 0 {
   236  		maxIncomingStreams = 0
   237  	}
   238  	maxIncomingUniStreams := config.MaxIncomingUniStreams
   239  	if maxIncomingUniStreams == 0 {
   240  		maxIncomingUniStreams = protocol.DefaultMaxIncomingUniStreams
   241  	} else if maxIncomingUniStreams < 0 {
   242  		maxIncomingUniStreams = 0
   243  	}
   244  	connIDLen := config.ConnectionIDLength
   245  	if connIDLen == 0 && !createdPacketConn {
   246  		connIDLen = protocol.DefaultConnectionIDLength
   247  	}
   248  	for _, v := range versions {
   249  		if v == protocol.Version44 {
   250  			connIDLen = 0
   251  		}
   252  	}
   253  
   254  	return &Config{
   255  		Versions:                              versions,
   256  		HandshakeTimeout:                      handshakeTimeout,
   257  		IdleTimeout:                           idleTimeout,
   258  		RequestConnectionIDOmission:           config.RequestConnectionIDOmission,
   259  		ConnectionIDLength:                    connIDLen,
   260  		MaxReceiveStreamFlowControlWindow:     maxReceiveStreamFlowControlWindow,
   261  		MaxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindow,
   262  		MaxIncomingStreams:                    maxIncomingStreams,
   263  		MaxIncomingUniStreams:                 maxIncomingUniStreams,
   264  		KeepAlive:                             config.KeepAlive,
   265  	}
   266  }
   267  
   268  func (c *client) generateConnectionIDs() error {
   269  	connIDLen := protocol.ConnectionIDLenGQUIC
   270  	if c.version.UsesTLS() {
   271  		connIDLen = c.config.ConnectionIDLength
   272  	}
   273  	srcConnID, err := generateConnectionID(connIDLen)
   274  	if err != nil {
   275  		return err
   276  	}
   277  	destConnID := srcConnID
   278  	if c.version.UsesTLS() {
   279  		destConnID, err = generateConnectionIDForInitial()
   280  		if err != nil {
   281  			return err
   282  		}
   283  	}
   284  	c.srcConnID = srcConnID
   285  	c.destConnID = destConnID
   286  	if c.version == protocol.Version44 {
   287  		c.srcConnID = nil
   288  	}
   289  	return nil
   290  }
   291  
   292  func (c *client) dial(ctx context.Context) error {
   293  	c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.tlsConf.ServerName, c.conn.LocalAddr(), c.conn.RemoteAddr(), c.srcConnID, c.destConnID, c.version)
   294  
   295  	var err error
   296  	if c.version.UsesTLS() {
   297  		err = c.dialTLS(ctx)
   298  	} else {
   299  		err = c.dialGQUIC(ctx)
   300  	}
   301  	return err
   302  }
   303  
   304  func (c *client) dialGQUIC(ctx context.Context) error {
   305  	if err := c.createNewGQUICSession(); err != nil {
   306  		return err
   307  	}
   308  	err := c.establishSecureConnection(ctx)
   309  	if err == errCloseSessionForNewVersion {
   310  		return c.dial(ctx)
   311  	}
   312  	return err
   313  }
   314  
   315  func (c *client) dialTLS(ctx context.Context) error {
   316  	params := &handshake.TransportParameters{
   317  		StreamFlowControlWindow:     protocol.ReceiveStreamFlowControlWindow,
   318  		ConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow,
   319  		IdleTimeout:                 c.config.IdleTimeout,
   320  		OmitConnectionID:            c.config.RequestConnectionIDOmission,
   321  		MaxBidiStreams:              uint16(c.config.MaxIncomingStreams),
   322  		MaxUniStreams:               uint16(c.config.MaxIncomingUniStreams),
   323  		DisableMigration:            true,
   324  	}
   325  	extHandler := handshake.NewExtensionHandlerClient(params, c.initialVersion, c.config.Versions, c.version, c.logger)
   326  	mintConf, err := tlsToMintConfig(c.tlsConf, protocol.PerspectiveClient)
   327  	if err != nil {
   328  		return err
   329  	}
   330  	mintConf.ExtensionHandler = extHandler
   331  	c.mintConf = mintConf
   332  
   333  	if err := c.createNewTLSSession(extHandler.GetPeerParams(), c.version); err != nil {
   334  		return err
   335  	}
   336  	err = c.establishSecureConnection(ctx)
   337  	if err == errCloseSessionForRetry || err == errCloseSessionForNewVersion {
   338  		return c.dial(ctx)
   339  	}
   340  	return err
   341  }
   342  
   343  // establishSecureConnection runs the session, and tries to establish a secure connection
   344  // It returns:
   345  // - errCloseSessionForNewVersion when the server sends a version negotiation packet
   346  // - handshake.ErrCloseSessionForRetry when the server performs a stateless retry (for IETF QUIC)
   347  // - any other error that might occur
   348  // - when the connection is secure (for gQUIC), or forward-secure (for IETF QUIC)
   349  func (c *client) establishSecureConnection(ctx context.Context) error {
   350  	errorChan := make(chan error, 1)
   351  
   352  	go func() {
   353  		err := c.session.run() // returns as soon as the session is closed
   354  		if err != errCloseSessionForRetry && err != errCloseSessionForNewVersion && c.createdPacketConn {
   355  			c.conn.Close()
   356  		}
   357  		errorChan <- err
   358  	}()
   359  
   360  	select {
   361  	case <-ctx.Done():
   362  		// The session will send a PeerGoingAway error to the server.
   363  		c.session.Close()
   364  		return ctx.Err()
   365  	case err := <-errorChan:
   366  		return err
   367  	case <-c.handshakeChan:
   368  		// handshake successfully completed
   369  		return nil
   370  	}
   371  }
   372  
   373  func (c *client) handlePacket(p *receivedPacket) {
   374  	if err := c.handlePacketImpl(p); err != nil {
   375  		c.logger.Errorf("error handling packet: %s", err)
   376  	}
   377  }
   378  
   379  func (c *client) handlePacketImpl(p *receivedPacket) error {
   380  	c.mutex.Lock()
   381  	defer c.mutex.Unlock()
   382  
   383  	// handle Version Negotiation Packets
   384  	if p.header.IsVersionNegotiation {
   385  		err := c.handleVersionNegotiationPacket(p.header)
   386  		if err != nil {
   387  			c.session.destroy(err)
   388  		}
   389  		// version negotiation packets have no payload
   390  		return err
   391  	}
   392  
   393  	if !c.version.UsesIETFHeaderFormat() {
   394  		connID := p.header.DestConnectionID
   395  		// reject packets with truncated connection id if we didn't request truncation
   396  		if !c.config.RequestConnectionIDOmission && connID.Len() == 0 {
   397  			return errors.New("received packet with truncated connection ID, but didn't request truncation")
   398  		}
   399  		// reject packets with the wrong connection ID
   400  		if connID.Len() > 0 && !connID.Equal(c.srcConnID) {
   401  			return fmt.Errorf("received a packet with an unexpected connection ID (%s, expected %s)", connID, c.srcConnID)
   402  		}
   403  		if p.header.ResetFlag {
   404  			return c.handlePublicReset(p)
   405  		}
   406  	} else {
   407  		// reject packets with the wrong connection ID
   408  		if !p.header.DestConnectionID.Equal(c.srcConnID) {
   409  			return fmt.Errorf("received a packet with an unexpected connection ID (%s, expected %s)", p.header.DestConnectionID, c.srcConnID)
   410  		}
   411  	}
   412  
   413  	if p.header.IsLongHeader {
   414  		switch p.header.Type {
   415  		case protocol.PacketTypeRetry:
   416  			c.handleRetryPacket(p.header)
   417  			return nil
   418  		case protocol.PacketTypeHandshake, protocol.PacketType0RTT:
   419  
   420  		// [Psiphon]
   421  		// The fix in https://github.com/lucas-clemente/quic-go/commit/386b77f422028fe86aae7ae9c017ca2c692c8184 must
   422  		// also be applied here.
   423  		case protocol.PacketTypeInitial:
   424  			if p.header.Version == protocol.Version44 {
   425  				break
   426  			}
   427  			fallthrough
   428  		// [Psiphon]
   429  
   430  		default:
   431  			return fmt.Errorf("Received unsupported packet type: %s", p.header.Type)
   432  		}
   433  	}
   434  
   435  	// this is the first packet we are receiving
   436  	// since it is not a Version Negotiation Packet, this means the server supports the suggested version
   437  	if !c.versionNegotiated {
   438  		c.versionNegotiated = true
   439  	}
   440  
   441  	c.session.handlePacket(p)
   442  	return nil
   443  }
   444  
   445  func (c *client) handlePublicReset(p *receivedPacket) error {
   446  	cr := c.conn.RemoteAddr()
   447  	// check if the remote address and the connection ID match
   448  	// otherwise this might be an attacker trying to inject a PUBLIC_RESET to kill the connection
   449  	if cr.Network() != p.remoteAddr.Network() || cr.String() != p.remoteAddr.String() || !p.header.DestConnectionID.Equal(c.srcConnID) {
   450  		return errors.New("Received a spoofed Public Reset")
   451  	}
   452  	pr, err := wire.ParsePublicReset(bytes.NewReader(p.data))
   453  	if err != nil {
   454  		return fmt.Errorf("Received a Public Reset. An error occurred parsing the packet: %s", err)
   455  	}
   456  	c.session.closeRemote(qerr.Error(qerr.PublicReset, fmt.Sprintf("Received a Public Reset for packet number %#x", pr.RejectedPacketNumber)))
   457  	c.logger.Infof("Received Public Reset, rejected packet number: %#x", pr.RejectedPacketNumber)
   458  	return nil
   459  }
   460  
   461  func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error {
   462  	// ignore delayed / duplicated version negotiation packets
   463  	if c.receivedVersionNegotiationPacket || c.versionNegotiated {
   464  		c.logger.Debugf("Received a delayed Version Negotiation Packet.")
   465  		return nil
   466  	}
   467  
   468  	for _, v := range hdr.SupportedVersions {
   469  		if v == c.version {
   470  			// the version negotiation packet contains the version that we offered
   471  			// this might be a packet sent by an attacker (or by a terribly broken server implementation)
   472  			// ignore it
   473  			return nil
   474  		}
   475  	}
   476  
   477  	c.logger.Infof("Received a Version Negotiation Packet. Supported Versions: %s", hdr.SupportedVersions)
   478  	newVersion, ok := protocol.ChooseSupportedVersion(c.config.Versions, hdr.SupportedVersions)
   479  	if !ok {
   480  		return qerr.InvalidVersion
   481  	}
   482  	c.receivedVersionNegotiationPacket = true
   483  	c.negotiatedVersions = hdr.SupportedVersions
   484  
   485  	// switch to negotiated version
   486  	c.initialVersion = c.version
   487  	c.version = newVersion
   488  	if err := c.generateConnectionIDs(); err != nil {
   489  		return err
   490  	}
   491  
   492  	c.logger.Infof("Switching to QUIC version %s. New connection ID: %s", newVersion, c.destConnID)
   493  	c.session.destroy(errCloseSessionForNewVersion)
   494  	return nil
   495  }
   496  
   497  func (c *client) handleRetryPacket(hdr *wire.Header) {
   498  	c.logger.Debugf("<- Received Retry")
   499  	hdr.Log(c.logger)
   500  	if !hdr.OrigDestConnectionID.Equal(c.destConnID) {
   501  		c.logger.Debugf("Ignoring spoofed Retry. Original Destination Connection ID: %s, expected: %s", hdr.OrigDestConnectionID, c.destConnID)
   502  		return
   503  	}
   504  	if hdr.SrcConnectionID.Equal(c.destConnID) {
   505  		c.logger.Debugf("Ignoring Retry, since the server didn't change the Source Connection ID.")
   506  		return
   507  	}
   508  	// If a token is already set, this means that we already received a Retry from the server.
   509  	// Ignore this Retry packet.
   510  	if len(c.token) > 0 {
   511  		c.logger.Debugf("Ignoring Retry, since a Retry was already received.")
   512  		return
   513  	}
   514  	c.destConnID = hdr.SrcConnectionID
   515  	c.token = hdr.Token
   516  	c.session.destroy(errCloseSessionForRetry)
   517  }
   518  
   519  func (c *client) createNewGQUICSession() error {
   520  	c.mutex.Lock()
   521  	defer c.mutex.Unlock()
   522  	runner := &runner{
   523  		onHandshakeCompleteImpl: func(_ Session) { close(c.handshakeChan) },
   524  		removeConnectionIDImpl:  c.closeCallback,
   525  	}
   526  	sess, err := newClientSession(
   527  		c.conn,
   528  		runner,
   529  		c.version,
   530  		c.destConnID,
   531  		c.srcConnID,
   532  		c.tlsConf,
   533  		c.config,
   534  		c.initialVersion,
   535  		c.negotiatedVersions,
   536  		c.logger,
   537  	)
   538  	if err != nil {
   539  		return err
   540  	}
   541  	c.session = sess
   542  	c.packetHandlers.Add(c.srcConnID, c)
   543  	if c.config.RequestConnectionIDOmission {
   544  		c.packetHandlers.Add(protocol.ConnectionID{}, c)
   545  	}
   546  	return nil
   547  }
   548  
   549  func (c *client) createNewTLSSession(
   550  	paramsChan <-chan handshake.TransportParameters,
   551  	version protocol.VersionNumber,
   552  ) error {
   553  	c.mutex.Lock()
   554  	defer c.mutex.Unlock()
   555  	runner := &runner{
   556  		onHandshakeCompleteImpl: func(_ Session) { close(c.handshakeChan) },
   557  		removeConnectionIDImpl:  c.closeCallback,
   558  	}
   559  	sess, err := newTLSClientSession(
   560  		c.conn,
   561  		runner,
   562  		c.token,
   563  		c.destConnID,
   564  		c.srcConnID,
   565  		c.config,
   566  		c.mintConf,
   567  		paramsChan,
   568  		1,
   569  		c.logger,
   570  		c.version,
   571  	)
   572  	if err != nil {
   573  		return err
   574  	}
   575  	c.session = sess
   576  	c.packetHandlers.Add(c.srcConnID, c)
   577  	return nil
   578  }
   579  
   580  func (c *client) Close() error {
   581  	c.mutex.Lock()
   582  	defer c.mutex.Unlock()
   583  	if c.session == nil {
   584  		return nil
   585  	}
   586  	return c.session.Close()
   587  }
   588  
   589  func (c *client) destroy(e error) {
   590  	c.mutex.Lock()
   591  	defer c.mutex.Unlock()
   592  	if c.session == nil {
   593  		return
   594  	}
   595  	c.session.destroy(e)
   596  }
   597  
   598  func (c *client) GetVersion() protocol.VersionNumber {
   599  	c.mutex.Lock()
   600  	v := c.version
   601  	c.mutex.Unlock()
   602  	return v
   603  }
   604  
   605  func (c *client) GetPerspective() protocol.Perspective {
   606  	return protocol.PerspectiveClient
   607  }