github.com/sagernet/quic-go@v0.43.1-beta.1/ech/client.go (about)

     1  package quic
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"net"
     7  
     8  	"github.com/sagernet/quic-go/internal/protocol"
     9  	"github.com/sagernet/quic-go/internal/utils"
    10  	"github.com/sagernet/quic-go/logging"
    11  	"github.com/sagernet/cloudflare-tls"
    12  )
    13  
    14  type client struct {
    15  	sendConn sendConn
    16  
    17  	use0RTT bool
    18  
    19  	packetHandlers packetHandlerManager
    20  	onClose        func()
    21  
    22  	tlsConf *tls.Config
    23  	config  *Config
    24  
    25  	connIDGenerator ConnectionIDGenerator
    26  	srcConnID       protocol.ConnectionID
    27  	destConnID      protocol.ConnectionID
    28  
    29  	initialPacketNumber  protocol.PacketNumber
    30  	hasNegotiatedVersion bool
    31  	version              protocol.Version
    32  
    33  	handshakeChan chan struct{}
    34  
    35  	conn quicConn
    36  
    37  	tracer    *logging.ConnectionTracer
    38  	tracingID ConnectionTracingID
    39  	logger    utils.Logger
    40  }
    41  
    42  // make it possible to mock connection ID for initial generation in the tests
    43  var generateConnectionIDForInitial = protocol.GenerateConnectionIDForInitial
    44  
    45  // DialAddr establishes a new QUIC connection to a server.
    46  // It resolves the address, and then creates a new UDP connection to dial the QUIC server.
    47  // When the QUIC connection is closed, this UDP connection is closed.
    48  // See Dial for more details.
    49  func DialAddr(ctx context.Context, addr string, tlsConf *tls.Config, conf *Config) (Connection, error) {
    50  	udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
    51  	if err != nil {
    52  		return nil, err
    53  	}
    54  	udpAddr, err := net.ResolveUDPAddr("udp", addr)
    55  	if err != nil {
    56  		return nil, err
    57  	}
    58  	tr, err := setupTransport(udpConn, tlsConf, true)
    59  	if err != nil {
    60  		return nil, err
    61  	}
    62  	return tr.dial(ctx, udpAddr, addr, tlsConf, conf, false)
    63  }
    64  
    65  // DialAddrEarly establishes a new 0-RTT QUIC connection to a server.
    66  // See DialAddr for more details.
    67  func DialAddrEarly(ctx context.Context, addr string, tlsConf *tls.Config, conf *Config) (EarlyConnection, error) {
    68  	udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
    69  	if err != nil {
    70  		return nil, err
    71  	}
    72  	udpAddr, err := net.ResolveUDPAddr("udp", addr)
    73  	if err != nil {
    74  		return nil, err
    75  	}
    76  	tr, err := setupTransport(udpConn, tlsConf, true)
    77  	if err != nil {
    78  		return nil, err
    79  	}
    80  	conn, err := tr.dial(ctx, udpAddr, addr, tlsConf, conf, true)
    81  	if err != nil {
    82  		tr.Close()
    83  		return nil, err
    84  	}
    85  	return conn, nil
    86  }
    87  
    88  // DialEarly establishes a new 0-RTT QUIC connection to a server using a net.PacketConn.
    89  // See Dial for more details.
    90  func DialEarly(ctx context.Context, c net.PacketConn, addr net.Addr, tlsConf *tls.Config, conf *Config) (EarlyConnection, error) {
    91  	dl, err := setupTransport(c, tlsConf, false)
    92  	if err != nil {
    93  		return nil, err
    94  	}
    95  	conn, err := dl.DialEarly(ctx, addr, tlsConf, conf)
    96  	if err != nil {
    97  		dl.Close()
    98  		return nil, err
    99  	}
   100  	return conn, nil
   101  }
   102  
   103  // Dial establishes a new QUIC connection to a server using a net.PacketConn.
   104  // If the PacketConn satisfies the OOBCapablePacketConn interface (as a net.UDPConn does),
   105  // ECN and packet info support will be enabled. In this case, ReadMsgUDP and WriteMsgUDP
   106  // will be used instead of ReadFrom and WriteTo to read/write packets.
   107  // The tls.Config must define an application protocol (using NextProtos).
   108  //
   109  // This is a convenience function. More advanced use cases should instantiate a Transport,
   110  // which offers configuration options for a more fine-grained control of the connection establishment,
   111  // including reusing the underlying UDP socket for multiple QUIC connections.
   112  func Dial(ctx context.Context, c net.PacketConn, addr net.Addr, tlsConf *tls.Config, conf *Config) (Connection, error) {
   113  	dl, err := setupTransport(c, tlsConf, false)
   114  	if err != nil {
   115  		return nil, err
   116  	}
   117  	conn, err := dl.Dial(ctx, addr, tlsConf, conf)
   118  	if err != nil {
   119  		dl.Close()
   120  		return nil, err
   121  	}
   122  	return conn, nil
   123  }
   124  
   125  func setupTransport(c net.PacketConn, tlsConf *tls.Config, createdPacketConn bool) (*Transport, error) {
   126  	if tlsConf == nil {
   127  		return nil, errors.New("quic: tls.Config not set")
   128  	}
   129  	return &Transport{
   130  		Conn:        c,
   131  		createdConn: createdPacketConn,
   132  		isSingleUse: true,
   133  	}, nil
   134  }
   135  
   136  func dial(
   137  	ctx context.Context,
   138  	conn sendConn,
   139  	connIDGenerator ConnectionIDGenerator,
   140  	packetHandlers packetHandlerManager,
   141  	tlsConf *tls.Config,
   142  	config *Config,
   143  	onClose func(),
   144  	use0RTT bool,
   145  ) (quicConn, error) {
   146  	c, err := newClient(conn, connIDGenerator, config, tlsConf, onClose, use0RTT)
   147  	if err != nil {
   148  		return nil, err
   149  	}
   150  	c.packetHandlers = packetHandlers
   151  
   152  	c.tracingID = nextConnTracingID()
   153  	if c.config.Tracer != nil {
   154  		c.tracer = c.config.Tracer(context.WithValue(ctx, ConnectionTracingKey, c.tracingID), protocol.PerspectiveClient, c.destConnID)
   155  	}
   156  	if c.tracer != nil && c.tracer.StartedConnection != nil {
   157  		c.tracer.StartedConnection(c.sendConn.LocalAddr(), c.sendConn.RemoteAddr(), c.srcConnID, c.destConnID)
   158  	}
   159  	if err := c.dial(ctx); err != nil {
   160  		return nil, err
   161  	}
   162  	return c.conn, nil
   163  }
   164  
   165  func newClient(sendConn sendConn, connIDGenerator ConnectionIDGenerator, config *Config, tlsConf *tls.Config, onClose func(), use0RTT bool) (*client, error) {
   166  	srcConnID, err := connIDGenerator.GenerateConnectionID()
   167  	if err != nil {
   168  		return nil, err
   169  	}
   170  	destConnID, err := generateConnectionIDForInitial()
   171  	if err != nil {
   172  		return nil, err
   173  	}
   174  	c := &client{
   175  		connIDGenerator: connIDGenerator,
   176  		srcConnID:       srcConnID,
   177  		destConnID:      destConnID,
   178  		sendConn:        sendConn,
   179  		use0RTT:         use0RTT,
   180  		onClose:         onClose,
   181  		tlsConf:         tlsConf,
   182  		config:          config,
   183  		version:         config.Versions[0],
   184  		handshakeChan:   make(chan struct{}),
   185  		logger:          utils.DefaultLogger.WithPrefix("client"),
   186  	}
   187  	return c, nil
   188  }
   189  
   190  func (c *client) dial(ctx context.Context) error {
   191  	c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.tlsConf.ServerName, c.sendConn.LocalAddr(), c.sendConn.RemoteAddr(), c.srcConnID, c.destConnID, c.version)
   192  
   193  	c.conn = newClientConnection(
   194  		c.sendConn,
   195  		c.packetHandlers,
   196  		c.destConnID,
   197  		c.srcConnID,
   198  		c.connIDGenerator,
   199  		c.config,
   200  		c.tlsConf,
   201  		c.initialPacketNumber,
   202  		c.use0RTT,
   203  		c.hasNegotiatedVersion,
   204  		c.tracer,
   205  		c.tracingID,
   206  		c.logger,
   207  		c.version,
   208  	)
   209  	c.packetHandlers.Add(c.srcConnID, c.conn)
   210  
   211  	errorChan := make(chan error, 1)
   212  	recreateChan := make(chan errCloseForRecreating)
   213  	go func() {
   214  		err := c.conn.run()
   215  		var recreateErr *errCloseForRecreating
   216  		if errors.As(err, &recreateErr) {
   217  			recreateChan <- *recreateErr
   218  			return
   219  		}
   220  		if c.onClose != nil {
   221  			c.onClose()
   222  		}
   223  		errorChan <- err // returns as soon as the connection is closed
   224  	}()
   225  
   226  	// only set when we're using 0-RTT
   227  	// Otherwise, earlyConnChan will be nil. Receiving from a nil chan blocks forever.
   228  	var earlyConnChan <-chan struct{}
   229  	if c.use0RTT {
   230  		earlyConnChan = c.conn.earlyConnReady()
   231  	}
   232  
   233  	select {
   234  	case <-ctx.Done():
   235  		c.conn.destroy(nil)
   236  		return context.Cause(ctx)
   237  	case err := <-errorChan:
   238  		return err
   239  	case recreateErr := <-recreateChan:
   240  		c.initialPacketNumber = recreateErr.nextPacketNumber
   241  		c.version = recreateErr.nextVersion
   242  		c.hasNegotiatedVersion = true
   243  		return c.dial(ctx)
   244  	case <-earlyConnChan:
   245  		// ready to send 0-RTT data
   246  		return nil
   247  	case <-c.conn.HandshakeComplete():
   248  		// handshake successfully completed
   249  		return nil
   250  	}
   251  }