github.com/ooni/psiphon/tunnel-core@v0.0.0-20230105123940-fe12a24c96ee/oovendor/quic-go/client.go (about)

     1  package quic
     2  
     3  import (
     4  	"context"
     5  	"crypto/tls"
     6  	"errors"
     7  	"fmt"
     8  	"net"
     9  	"strings"
    10  
    11  	"github.com/ooni/psiphon/tunnel-core/oovendor/quic-go/internal/protocol"
    12  	"github.com/ooni/psiphon/tunnel-core/oovendor/quic-go/internal/utils"
    13  	"github.com/ooni/psiphon/tunnel-core/oovendor/quic-go/logging"
    14  )
    15  
    16  type client struct {
    17  	conn sendConn
    18  	// If the client is created with DialAddr, we create a packet conn.
    19  	// If it is started with Dial, we take a packet conn as a parameter.
    20  	createdPacketConn bool
    21  
    22  	use0RTT bool
    23  
    24  	packetHandlers packetHandlerManager
    25  
    26  	tlsConf *tls.Config
    27  	config  *Config
    28  
    29  	srcConnID  protocol.ConnectionID
    30  	destConnID protocol.ConnectionID
    31  
    32  	initialPacketNumber  protocol.PacketNumber
    33  	hasNegotiatedVersion bool
    34  	version              protocol.VersionNumber
    35  
    36  	handshakeChan chan struct{}
    37  
    38  	session quicSession
    39  
    40  	tracer    logging.ConnectionTracer
    41  	tracingID uint64
    42  	logger    utils.Logger
    43  }
    44  
    45  var (
    46  	// make it possible to mock connection ID generation in the tests
    47  	generateConnectionID           = protocol.GenerateConnectionID
    48  	generateConnectionIDForInitial = protocol.GenerateConnectionIDForInitial
    49  )
    50  
    51  // DialAddr establishes a new QUIC connection to a server.
    52  // It uses a new UDP connection and closes this connection when the QUIC session is closed.
    53  // The hostname for SNI is taken from the given address.
    54  // The tls.Config.CipherSuites allows setting of TLS 1.3 cipher suites.
    55  func DialAddr(
    56  	addr string,
    57  	tlsConf *tls.Config,
    58  	config *Config,
    59  ) (Session, error) {
    60  	return DialAddrContext(context.Background(), addr, tlsConf, config)
    61  }
    62  
    63  // DialAddrEarly establishes a new 0-RTT QUIC connection to a server.
    64  // It uses a new UDP connection and closes this connection when the QUIC session is closed.
    65  // The hostname for SNI is taken from the given address.
    66  // The tls.Config.CipherSuites allows setting of TLS 1.3 cipher suites.
    67  func DialAddrEarly(
    68  	addr string,
    69  	tlsConf *tls.Config,
    70  	config *Config,
    71  ) (EarlySession, error) {
    72  	return DialAddrEarlyContext(context.Background(), addr, tlsConf, config)
    73  }
    74  
    75  // DialAddrEarlyContext establishes a new 0-RTT QUIC connection to a server using provided context.
    76  // See DialAddrEarly for details
    77  func DialAddrEarlyContext(
    78  	ctx context.Context,
    79  	addr string,
    80  	tlsConf *tls.Config,
    81  	config *Config,
    82  ) (EarlySession, error) {
    83  	sess, err := dialAddrContext(ctx, addr, tlsConf, config, true)
    84  	if err != nil {
    85  		return nil, err
    86  	}
    87  	utils.Logger.WithPrefix(utils.DefaultLogger, "client").Debugf("Returning early session")
    88  	return sess, nil
    89  }
    90  
    91  // DialAddrContext establishes a new QUIC connection to a server using the provided context.
    92  // See DialAddr for details.
    93  func DialAddrContext(
    94  	ctx context.Context,
    95  	addr string,
    96  	tlsConf *tls.Config,
    97  	config *Config,
    98  ) (Session, error) {
    99  	return dialAddrContext(ctx, addr, tlsConf, config, false)
   100  }
   101  
   102  func dialAddrContext(
   103  	ctx context.Context,
   104  	addr string,
   105  	tlsConf *tls.Config,
   106  	config *Config,
   107  	use0RTT bool,
   108  ) (quicSession, error) {
   109  	udpAddr, err := net.ResolveUDPAddr("udp", addr)
   110  	if err != nil {
   111  		return nil, err
   112  	}
   113  	udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
   114  	if err != nil {
   115  		return nil, err
   116  	}
   117  	return dialContext(ctx, udpConn, udpAddr, addr, tlsConf, config, use0RTT, true)
   118  }
   119  
   120  // Dial establishes a new QUIC connection to a server using a net.PacketConn. If
   121  // the PacketConn satisfies the OOBCapablePacketConn interface (as a net.UDPConn
   122  // does), ECN and packet info support will be enabled. In this case, ReadMsgUDP
   123  // and WriteMsgUDP will be used instead of ReadFrom and WriteTo to read/write
   124  // packets. The same PacketConn can be used for multiple calls to Dial and
   125  // Listen, QUIC connection IDs are used for demultiplexing the different
   126  // connections. The host parameter is used for SNI. The tls.Config must define
   127  // an application protocol (using NextProtos).
   128  func Dial(
   129  	pconn net.PacketConn,
   130  	remoteAddr net.Addr,
   131  	host string,
   132  	tlsConf *tls.Config,
   133  	config *Config,
   134  ) (Session, error) {
   135  	return dialContext(context.Background(), pconn, remoteAddr, host, tlsConf, config, false, false)
   136  }
   137  
   138  // DialEarly establishes a new 0-RTT QUIC connection to a server using a net.PacketConn.
   139  // The same PacketConn can be used for multiple calls to Dial and Listen,
   140  // QUIC connection IDs are used for demultiplexing the different connections.
   141  // The host parameter is used for SNI.
   142  // The tls.Config must define an application protocol (using NextProtos).
   143  func DialEarly(
   144  	pconn net.PacketConn,
   145  	remoteAddr net.Addr,
   146  	host string,
   147  	tlsConf *tls.Config,
   148  	config *Config,
   149  ) (EarlySession, error) {
   150  	return DialEarlyContext(context.Background(), pconn, remoteAddr, host, tlsConf, config)
   151  }
   152  
   153  // DialEarlyContext establishes a new 0-RTT QUIC connection to a server using a net.PacketConn using the provided context.
   154  // See DialEarly for details.
   155  func DialEarlyContext(
   156  	ctx context.Context,
   157  	pconn net.PacketConn,
   158  	remoteAddr net.Addr,
   159  	host string,
   160  	tlsConf *tls.Config,
   161  	config *Config,
   162  ) (EarlySession, error) {
   163  	return dialContext(ctx, pconn, remoteAddr, host, tlsConf, config, true, false)
   164  }
   165  
   166  // DialContext establishes a new QUIC connection to a server using a net.PacketConn using the provided context.
   167  // See Dial for details.
   168  func DialContext(
   169  	ctx context.Context,
   170  	pconn net.PacketConn,
   171  	remoteAddr net.Addr,
   172  	host string,
   173  	tlsConf *tls.Config,
   174  	config *Config,
   175  ) (Session, error) {
   176  	return dialContext(ctx, pconn, remoteAddr, host, tlsConf, config, false, false)
   177  }
   178  
   179  func dialContext(
   180  	ctx context.Context,
   181  	pconn net.PacketConn,
   182  	remoteAddr net.Addr,
   183  	host string,
   184  	tlsConf *tls.Config,
   185  	config *Config,
   186  	use0RTT bool,
   187  	createdPacketConn bool,
   188  ) (quicSession, error) {
   189  	if tlsConf == nil {
   190  		return nil, errors.New("quic: tls.Config not set")
   191  	}
   192  	if err := validateConfig(config); err != nil {
   193  		return nil, err
   194  	}
   195  	config = populateClientConfig(config, createdPacketConn)
   196  	packetHandlers, err := getMultiplexer().AddConn(pconn, config.ConnectionIDLength, config.StatelessResetKey, config.Tracer)
   197  	if err != nil {
   198  		return nil, err
   199  	}
   200  	c, err := newClient(pconn, remoteAddr, config, tlsConf, host, use0RTT, createdPacketConn)
   201  	if err != nil {
   202  		return nil, err
   203  	}
   204  	c.packetHandlers = packetHandlers
   205  
   206  	c.tracingID = nextSessionTracingID()
   207  	if c.config.Tracer != nil {
   208  		c.tracer = c.config.Tracer.TracerForConnection(
   209  			context.WithValue(ctx, SessionTracingKey, c.tracingID),
   210  			protocol.PerspectiveClient,
   211  			c.destConnID,
   212  		)
   213  	}
   214  	if c.tracer != nil {
   215  		c.tracer.StartedConnection(c.conn.LocalAddr(), c.conn.RemoteAddr(), c.srcConnID, c.destConnID)
   216  	}
   217  	if err := c.dial(ctx); err != nil {
   218  		return nil, err
   219  	}
   220  	return c.session, nil
   221  }
   222  
   223  func newClient(
   224  	pconn net.PacketConn,
   225  	remoteAddr net.Addr,
   226  	config *Config,
   227  	tlsConf *tls.Config,
   228  	host string,
   229  	use0RTT bool,
   230  	createdPacketConn bool,
   231  ) (*client, error) {
   232  	if tlsConf == nil {
   233  		tlsConf = &tls.Config{}
   234  	}
   235  	if tlsConf.ServerName == "" {
   236  		sni := host
   237  		if strings.IndexByte(sni, ':') != -1 {
   238  			var err error
   239  			sni, _, err = net.SplitHostPort(sni)
   240  			if err != nil {
   241  				return nil, err
   242  			}
   243  		}
   244  
   245  		tlsConf.ServerName = sni
   246  	}
   247  
   248  	// check that all versions are actually supported
   249  	if config != nil {
   250  		for _, v := range config.Versions {
   251  			if !protocol.IsValidVersion(v) {
   252  				return nil, fmt.Errorf("%s is not a valid QUIC version", v)
   253  			}
   254  		}
   255  	}
   256  
   257  	srcConnID, err := generateConnectionID(config.ConnectionIDLength)
   258  	if err != nil {
   259  		return nil, err
   260  	}
   261  	destConnID, err := generateConnectionIDForInitial()
   262  	if err != nil {
   263  		return nil, err
   264  	}
   265  	c := &client{
   266  		srcConnID:         srcConnID,
   267  		destConnID:        destConnID,
   268  		conn:              newSendPconn(pconn, remoteAddr),
   269  		createdPacketConn: createdPacketConn,
   270  		use0RTT:           use0RTT,
   271  		tlsConf:           tlsConf,
   272  		config:            config,
   273  		version:           config.Versions[0],
   274  		handshakeChan:     make(chan struct{}),
   275  		logger:            utils.DefaultLogger.WithPrefix("client"),
   276  	}
   277  	return c, nil
   278  }
   279  
   280  func (c *client) dial(ctx context.Context) error {
   281  	c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.tlsConf.ServerName, c.conn.LocalAddr(), c.conn.RemoteAddr(), c.srcConnID, c.destConnID, c.version)
   282  
   283  	c.session = newClientSession(
   284  		c.conn,
   285  		c.packetHandlers,
   286  		c.destConnID,
   287  		c.srcConnID,
   288  		c.config,
   289  		c.tlsConf,
   290  		c.initialPacketNumber,
   291  		c.use0RTT,
   292  		c.hasNegotiatedVersion,
   293  		c.tracer,
   294  		c.tracingID,
   295  		c.logger,
   296  		c.version,
   297  	)
   298  	c.packetHandlers.Add(c.srcConnID, c.session)
   299  
   300  	errorChan := make(chan error, 1)
   301  	go func() {
   302  		err := c.session.run() // returns as soon as the session is closed
   303  
   304  		if e := (&errCloseForRecreating{}); !errors.As(err, &e) && c.createdPacketConn {
   305  			c.packetHandlers.Destroy()
   306  		}
   307  		errorChan <- err
   308  	}()
   309  
   310  	// only set when we're using 0-RTT
   311  	// Otherwise, earlySessionChan will be nil. Receiving from a nil chan blocks forever.
   312  	var earlySessionChan <-chan struct{}
   313  	if c.use0RTT {
   314  		earlySessionChan = c.session.earlySessionReady()
   315  	}
   316  
   317  	select {
   318  	case <-ctx.Done():
   319  		c.session.shutdown()
   320  		return ctx.Err()
   321  	case err := <-errorChan:
   322  		var recreateErr *errCloseForRecreating
   323  		if errors.As(err, &recreateErr) {
   324  			c.initialPacketNumber = recreateErr.nextPacketNumber
   325  			c.version = recreateErr.nextVersion
   326  			c.hasNegotiatedVersion = true
   327  			return c.dial(ctx)
   328  		}
   329  		return err
   330  	case <-earlySessionChan:
   331  		// ready to send 0-RTT data
   332  		return nil
   333  	case <-c.session.HandshakeComplete().Done():
   334  		// handshake successfully completed
   335  		return nil
   336  	}
   337  }