github.com/danielpfeifer02/quic-go-prio-packs@v0.41.0-28/client.go (about)

     1  package quic
     2  
     3  import (
     4  	"context"
     5  	"crypto/tls"
     6  	"errors"
     7  	"net"
     8  
     9  	"github.com/danielpfeifer02/quic-go-prio-packs/internal/protocol"
    10  	"github.com/danielpfeifer02/quic-go-prio-packs/internal/utils"
    11  	"github.com/danielpfeifer02/quic-go-prio-packs/logging"
    12  )
    13  
    14  type client struct {
    15  	sendConn sendConn
    16  
    17  	use0RTT bool
    18  
    19  	packetHandlers packetHandlerManager
    20  	onClose        func()
    21  
    22  	tlsConf *tls.Config
    23  	config  *Config
    24  
    25  	connIDGenerator ConnectionIDGenerator
    26  	srcConnID       protocol.ConnectionID
    27  	destConnID      protocol.ConnectionID
    28  
    29  	initialPacketNumber  protocol.PacketNumber
    30  	hasNegotiatedVersion bool
    31  	version              protocol.Version
    32  
    33  	handshakeChan chan struct{}
    34  
    35  	conn quicConn
    36  
    37  	tracer    *logging.ConnectionTracer
    38  	tracingID uint64
    39  	logger    utils.Logger
    40  }
    41  
    42  // make it possible to mock connection ID for initial generation in the tests
    43  var generateConnectionIDForInitial = protocol.GenerateConnectionIDForInitial
    44  
    45  // DialAddr establishes a new QUIC connection to a server.
    46  // It resolves the address, and then creates a new UDP connection to dial the QUIC server.
    47  // When the QUIC connection is closed, this UDP connection is closed.
    48  // See Dial for more details.
    49  func DialAddr(ctx context.Context, addr string, tlsConf *tls.Config, conf *Config) (Connection, error) {
    50  	udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
    51  	if err != nil {
    52  		return nil, err
    53  	}
    54  	udpAddr, err := net.ResolveUDPAddr("udp", addr)
    55  	if err != nil {
    56  		return nil, err
    57  	}
    58  	tr, err := setupTransport(udpConn, tlsConf, true)
    59  	if err != nil {
    60  		return nil, err
    61  	}
    62  	return tr.dial(ctx, udpAddr, addr, tlsConf, conf, false)
    63  }
    64  
    65  // DialAddrEarly establishes a new 0-RTT QUIC connection to a server.
    66  // See DialAddr for more details.
    67  func DialAddrEarly(ctx context.Context, addr string, tlsConf *tls.Config, conf *Config) (EarlyConnection, error) {
    68  	udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
    69  	if err != nil {
    70  		return nil, err
    71  	}
    72  	udpAddr, err := net.ResolveUDPAddr("udp", addr)
    73  	if err != nil {
    74  		return nil, err
    75  	}
    76  	tr, err := setupTransport(udpConn, tlsConf, true)
    77  	if err != nil {
    78  		return nil, err
    79  	}
    80  	conn, err := tr.dial(ctx, udpAddr, addr, tlsConf, conf, true)
    81  	if err != nil {
    82  		tr.Close()
    83  		return nil, err
    84  	}
    85  	return conn, nil
    86  }
    87  
    88  // DialEarly establishes a new 0-RTT QUIC connection to a server using a net.PacketConn.
    89  // See Dial for more details.
    90  func DialEarly(ctx context.Context, c net.PacketConn, addr net.Addr, tlsConf *tls.Config, conf *Config) (EarlyConnection, error) {
    91  	dl, err := setupTransport(c, tlsConf, false)
    92  	if err != nil {
    93  		return nil, err
    94  	}
    95  	conn, err := dl.DialEarly(ctx, addr, tlsConf, conf)
    96  	if err != nil {
    97  		dl.Close()
    98  		return nil, err
    99  	}
   100  	return conn, nil
   101  }
   102  
   103  // Dial establishes a new QUIC connection to a server using a net.PacketConn.
   104  // If the PacketConn satisfies the OOBCapablePacketConn interface (as a net.UDPConn does),
   105  // ECN and packet info support will be enabled. In this case, ReadMsgUDP and WriteMsgUDP
   106  // will be used instead of ReadFrom and WriteTo to read/write packets.
   107  // The tls.Config must define an application protocol (using NextProtos).
   108  //
   109  // This is a convenience function. More advanced use cases should instantiate a Transport,
   110  // which offers configuration options for a more fine-grained control of the connection establishment,
   111  // including reusing the underlying UDP socket for multiple QUIC connections.
   112  func Dial(ctx context.Context, c net.PacketConn, addr net.Addr, tlsConf *tls.Config, conf *Config) (Connection, error) {
   113  	dl, err := setupTransport(c, tlsConf, false)
   114  	if err != nil {
   115  		return nil, err
   116  	}
   117  	conn, err := dl.Dial(ctx, addr, tlsConf, conf)
   118  	if err != nil {
   119  		dl.Close()
   120  		return nil, err
   121  	}
   122  	return conn, nil
   123  }
   124  
   125  func setupTransport(c net.PacketConn, tlsConf *tls.Config, createdPacketConn bool) (*Transport, error) {
   126  	if tlsConf == nil {
   127  		return nil, errors.New("quic: tls.Config not set")
   128  	}
   129  	return &Transport{
   130  		Conn:        c,
   131  		createdConn: createdPacketConn,
   132  		// PRIO_PACKS_TAG
   133  		// setting singleUse to false to force connection ids longer than 0 bytes
   134  		isSingleUse: false,
   135  	}, nil
   136  }
   137  
   138  func dial(
   139  	ctx context.Context,
   140  	conn sendConn,
   141  	connIDGenerator ConnectionIDGenerator,
   142  	packetHandlers packetHandlerManager,
   143  	tlsConf *tls.Config,
   144  	config *Config,
   145  	onClose func(),
   146  	use0RTT bool,
   147  ) (quicConn, error) {
   148  	c, err := newClient(conn, connIDGenerator, config, tlsConf, onClose, use0RTT)
   149  	if err != nil {
   150  		return nil, err
   151  	}
   152  	c.packetHandlers = packetHandlers
   153  
   154  	c.tracingID = nextConnTracingID()
   155  	if c.config.Tracer != nil {
   156  		c.tracer = c.config.Tracer(context.WithValue(ctx, ConnectionTracingKey, c.tracingID), protocol.PerspectiveClient, c.destConnID)
   157  	}
   158  	if c.tracer != nil && c.tracer.StartedConnection != nil {
   159  		c.tracer.StartedConnection(c.sendConn.LocalAddr(), c.sendConn.RemoteAddr(), c.srcConnID, c.destConnID)
   160  	}
   161  	if err := c.dial(ctx); err != nil {
   162  		return nil, err
   163  	}
   164  	return c.conn, nil
   165  }
   166  
   167  func newClient(sendConn sendConn, connIDGenerator ConnectionIDGenerator, config *Config, tlsConf *tls.Config, onClose func(), use0RTT bool) (*client, error) {
   168  	srcConnID, err := connIDGenerator.GenerateConnectionID()
   169  	if err != nil {
   170  		return nil, err
   171  	}
   172  	destConnID, err := generateConnectionIDForInitial()
   173  	if err != nil {
   174  		return nil, err
   175  	}
   176  	c := &client{
   177  		connIDGenerator: connIDGenerator,
   178  		srcConnID:       srcConnID,
   179  		destConnID:      destConnID,
   180  		sendConn:        sendConn,
   181  		use0RTT:         use0RTT,
   182  		onClose:         onClose,
   183  		tlsConf:         tlsConf,
   184  		config:          config,
   185  		version:         config.Versions[0],
   186  		handshakeChan:   make(chan struct{}),
   187  		logger:          utils.DefaultLogger.WithPrefix("client"),
   188  	}
   189  	return c, nil
   190  }
   191  
   192  func (c *client) dial(ctx context.Context) error {
   193  	c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.tlsConf.ServerName, c.sendConn.LocalAddr(), c.sendConn.RemoteAddr(), c.srcConnID, c.destConnID, c.version)
   194  
   195  	c.conn = newClientConnection(
   196  		c.sendConn,
   197  		c.packetHandlers,
   198  		c.destConnID,
   199  		c.srcConnID,
   200  		c.connIDGenerator,
   201  		c.config,
   202  		c.tlsConf,
   203  		c.initialPacketNumber,
   204  		c.use0RTT,
   205  		c.hasNegotiatedVersion,
   206  		c.tracer,
   207  		c.tracingID,
   208  		c.logger,
   209  		c.version,
   210  	)
   211  	c.packetHandlers.Add(c.srcConnID, c.conn)
   212  
   213  	errorChan := make(chan error, 1)
   214  	recreateChan := make(chan errCloseForRecreating)
   215  	go func() {
   216  		err := c.conn.run()
   217  		var recreateErr *errCloseForRecreating
   218  		if errors.As(err, &recreateErr) {
   219  			recreateChan <- *recreateErr
   220  			return
   221  		}
   222  		if c.onClose != nil {
   223  			c.onClose()
   224  		}
   225  		errorChan <- err // returns as soon as the connection is closed
   226  	}()
   227  
   228  	// only set when we're using 0-RTT
   229  	// Otherwise, earlyConnChan will be nil. Receiving from a nil chan blocks forever.
   230  	var earlyConnChan <-chan struct{}
   231  	if c.use0RTT {
   232  		earlyConnChan = c.conn.earlyConnReady()
   233  	}
   234  
   235  	select {
   236  	case <-ctx.Done():
   237  		c.conn.destroy(nil)
   238  		return context.Cause(ctx)
   239  	case err := <-errorChan:
   240  		return err
   241  	case recreateErr := <-recreateChan:
   242  		c.initialPacketNumber = recreateErr.nextPacketNumber
   243  		c.version = recreateErr.nextVersion
   244  		c.hasNegotiatedVersion = true
   245  		return c.dial(ctx)
   246  	case <-earlyConnChan:
   247  		// ready to send 0-RTT data
   248  		return nil
   249  	case <-c.conn.HandshakeComplete():
   250  		// handshake successfully completed
   251  		return nil
   252  	}
   253  }