github.com/sagernet/quic-go@v0.43.1-beta.1/ech/transport.go (about)

     1  package quic
     2  
     3  import (
     4  	"context"
     5  	"crypto/rand"
     6  	"errors"
     7  	"net"
     8  	"sync"
     9  	"sync/atomic"
    10  	"time"
    11  
    12  	"github.com/sagernet/quic-go/internal/protocol"
    13  	"github.com/sagernet/quic-go/internal/utils"
    14  	"github.com/sagernet/quic-go/internal/wire"
    15  	"github.com/sagernet/quic-go/logging"
    16  	"github.com/sagernet/cloudflare-tls"
    17  )
    18  
    19  var errListenerAlreadySet = errors.New("listener already set")
    20  
    21  // The Transport is the central point to manage incoming and outgoing QUIC connections.
    22  // QUIC demultiplexes connections based on their QUIC Connection IDs, not based on the 4-tuple.
    23  // This means that a single UDP socket can be used for listening for incoming connections, as well as
    24  // for dialing an arbitrary number of outgoing connections.
    25  // A Transport handles a single net.PacketConn, and offers a range of configuration options
    26  // compared to the simple helper functions like Listen and Dial that this package provides.
    27  type Transport struct {
    28  	// A single net.PacketConn can only be handled by one Transport.
    29  	// Bad things will happen if passed to multiple Transports.
    30  	//
    31  	// A number of optimizations will be enabled if the connections implements the OOBCapablePacketConn interface,
    32  	// as a *net.UDPConn does.
    33  	// 1. It enables the Don't Fragment (DF) bit on the IP header.
    34  	//    This is required to run DPLPMTUD (Path MTU Discovery, RFC 8899).
    35  	// 2. It enables reading of the ECN bits from the IP header.
    36  	//    This allows the remote node to speed up its loss detection and recovery.
    37  	// 3. It uses batched syscalls (recvmmsg) to more efficiently receive packets from the socket.
    38  	// 4. It uses Generic Segmentation Offload (GSO) to efficiently send batches of packets (on Linux).
    39  	//
    40  	// After passing the connection to the Transport, it's invalid to call ReadFrom or WriteTo on the connection.
    41  	Conn net.PacketConn
    42  
    43  	// The length of the connection ID in bytes.
    44  	// It can be any value between 1 and 20.
    45  	// Due to the increased risk of collisions, it is not recommended to use connection IDs shorter than 4 bytes.
    46  	// If unset, a 4 byte connection ID will be used.
    47  	ConnectionIDLength int
    48  
    49  	// Use for generating new connection IDs.
    50  	// This allows the application to control of the connection IDs used,
    51  	// which allows routing / load balancing based on connection IDs.
    52  	// All Connection IDs returned by the ConnectionIDGenerator MUST
    53  	// have the same length.
    54  	ConnectionIDGenerator ConnectionIDGenerator
    55  
    56  	// The StatelessResetKey is used to generate stateless reset tokens.
    57  	// If no key is configured, sending of stateless resets is disabled.
    58  	// It is highly recommended to configure a stateless reset key, as stateless resets
    59  	// allow the peer to quickly recover from crashes and reboots of this node.
    60  	// See section 10.3 of RFC 9000 for details.
    61  	StatelessResetKey *StatelessResetKey
    62  
    63  	// The TokenGeneratorKey is used to encrypt session resumption tokens.
    64  	// If no key is configured, a random key will be generated.
    65  	// If multiple servers are authoritative for the same domain, they should use the same key,
    66  	// see section 8.1.3 of RFC 9000 for details.
    67  	TokenGeneratorKey *TokenGeneratorKey
    68  
    69  	// MaxTokenAge is the maximum age of the resumption token presented during the handshake.
    70  	// These tokens allow skipping address resumption when resuming a QUIC connection,
    71  	// and are especially useful when using 0-RTT.
    72  	// If not set, it defaults to 24 hours.
    73  	// See section 8.1.3 of RFC 9000 for details.
    74  	MaxTokenAge time.Duration
    75  
    76  	// DisableVersionNegotiationPackets disables the sending of Version Negotiation packets.
    77  	// This can be useful if version information is exchanged out-of-band.
    78  	// It has no effect for clients.
    79  	DisableVersionNegotiationPackets bool
    80  
    81  	// VerifySourceAddress decides if a connection attempt originating from unvalidated source
    82  	// addresses first needs to go through source address validation using QUIC's Retry mechanism,
    83  	// as described in RFC 9000 section 8.1.2.
    84  	// Note that the address passed to this callback is unvalidated, and might be spoofed in case
    85  	// of an attack.
    86  	// Validating the source address adds one additional network roundtrip to the handshake,
    87  	// and should therefore only be used if a suspiciously high number of incoming connection is recorded.
    88  	// For most use cases, wrapping the Allow function of a rate.Limiter will be a reasonable
    89  	// implementation of this callback (negating its return value).
    90  	VerifySourceAddress func(net.Addr) bool
    91  
    92  	// A Tracer traces events that don't belong to a single QUIC connection.
    93  	// Tracer.Close is called when the transport is closed.
    94  	Tracer *logging.Tracer
    95  
    96  	handlerMap packetHandlerManager
    97  
    98  	mutex    sync.Mutex
    99  	initOnce sync.Once
   100  	initErr  error
   101  
   102  	// Set in init.
   103  	// If no ConnectionIDGenerator is set, this is the ConnectionIDLength.
   104  	connIDLen int
   105  	// Set in init.
   106  	// If no ConnectionIDGenerator is set, this is set to a default.
   107  	connIDGenerator ConnectionIDGenerator
   108  
   109  	server *baseServer
   110  
   111  	conn rawConn
   112  
   113  	closeQueue          chan closePacket
   114  	statelessResetQueue chan receivedPacket
   115  
   116  	listening   chan struct{} // is closed when listen returns
   117  	closed      bool
   118  	createdConn bool
   119  	isSingleUse bool // was created for a single server or client, i.e. by calling quic.Listen or quic.Dial
   120  
   121  	readingNonQUICPackets atomic.Bool
   122  	nonQUICPackets        chan receivedPacket
   123  
   124  	logger utils.Logger
   125  }
   126  
   127  // Listen starts listening for incoming QUIC connections.
   128  // There can only be a single listener on any net.PacketConn.
   129  // Listen may only be called again after the current Listener was closed.
   130  func (t *Transport) Listen(tlsConf *tls.Config, conf *Config) (*Listener, error) {
   131  	s, err := t.createServer(tlsConf, conf, false)
   132  	if err != nil {
   133  		return nil, err
   134  	}
   135  	return &Listener{baseServer: s}, nil
   136  }
   137  
   138  // ListenEarly starts listening for incoming QUIC connections.
   139  // There can only be a single listener on any net.PacketConn.
   140  // Listen may only be called again after the current Listener was closed.
   141  func (t *Transport) ListenEarly(tlsConf *tls.Config, conf *Config) (*EarlyListener, error) {
   142  	s, err := t.createServer(tlsConf, conf, true)
   143  	if err != nil {
   144  		return nil, err
   145  	}
   146  	return &EarlyListener{baseServer: s}, nil
   147  }
   148  
   149  func (t *Transport) createServer(tlsConf *tls.Config, conf *Config, allow0RTT bool) (*baseServer, error) {
   150  	if tlsConf == nil {
   151  		return nil, errors.New("quic: tls.Config not set")
   152  	}
   153  	if err := validateConfig(conf); err != nil {
   154  		return nil, err
   155  	}
   156  
   157  	t.mutex.Lock()
   158  	defer t.mutex.Unlock()
   159  
   160  	if t.server != nil {
   161  		return nil, errListenerAlreadySet
   162  	}
   163  	conf = populateConfig(conf)
   164  	if err := t.init(false); err != nil {
   165  		return nil, err
   166  	}
   167  	s := newServer(
   168  		t.conn,
   169  		t.handlerMap,
   170  		t.connIDGenerator,
   171  		tlsConf,
   172  		conf,
   173  		t.Tracer,
   174  		t.closeServer,
   175  		*t.TokenGeneratorKey,
   176  		t.MaxTokenAge,
   177  		t.VerifySourceAddress,
   178  		t.DisableVersionNegotiationPackets,
   179  		allow0RTT,
   180  	)
   181  	t.server = s
   182  	return s, nil
   183  }
   184  
   185  // Dial dials a new connection to a remote host (not using 0-RTT).
   186  func (t *Transport) Dial(ctx context.Context, addr net.Addr, tlsConf *tls.Config, conf *Config) (Connection, error) {
   187  	return t.dial(ctx, addr, "", tlsConf, conf, false)
   188  }
   189  
   190  // DialEarly dials a new connection, attempting to use 0-RTT if possible.
   191  func (t *Transport) DialEarly(ctx context.Context, addr net.Addr, tlsConf *tls.Config, conf *Config) (EarlyConnection, error) {
   192  	return t.dial(ctx, addr, "", tlsConf, conf, true)
   193  }
   194  
   195  func (t *Transport) dial(ctx context.Context, addr net.Addr, host string, tlsConf *tls.Config, conf *Config, use0RTT bool) (EarlyConnection, error) {
   196  	if err := validateConfig(conf); err != nil {
   197  		return nil, err
   198  	}
   199  	conf = populateConfig(conf)
   200  	if err := t.init(t.isSingleUse); err != nil {
   201  		return nil, err
   202  	}
   203  	var onClose func()
   204  	if t.isSingleUse {
   205  		onClose = func() { t.Close() }
   206  	}
   207  	tlsConf = tlsConf.Clone()
   208  	// setTLSConfigServerName(tlsConf, addr, host)
   209  	return dial(ctx, newSendConn(t.conn, addr, packetInfo{}, utils.DefaultLogger), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, use0RTT)
   210  }
   211  
   212  func (t *Transport) init(allowZeroLengthConnIDs bool) error {
   213  	t.initOnce.Do(func() {
   214  		var conn rawConn
   215  		if c, ok := t.Conn.(rawConn); ok {
   216  			conn = c
   217  		} else {
   218  			var err error
   219  			conn, err = wrapConn(t.Conn)
   220  			if err != nil {
   221  				t.initErr = err
   222  				return
   223  			}
   224  		}
   225  
   226  		t.logger = utils.DefaultLogger // TODO: make this configurable
   227  		t.conn = conn
   228  		t.handlerMap = newPacketHandlerMap(t.StatelessResetKey, t.enqueueClosePacket, t.logger)
   229  		t.listening = make(chan struct{})
   230  
   231  		t.closeQueue = make(chan closePacket, 4)
   232  		t.statelessResetQueue = make(chan receivedPacket, 4)
   233  		if t.TokenGeneratorKey == nil {
   234  			var key TokenGeneratorKey
   235  			if _, err := rand.Read(key[:]); err != nil {
   236  				t.initErr = err
   237  				return
   238  			}
   239  			t.TokenGeneratorKey = &key
   240  		}
   241  
   242  		if t.ConnectionIDGenerator != nil {
   243  			t.connIDGenerator = t.ConnectionIDGenerator
   244  			t.connIDLen = t.ConnectionIDGenerator.ConnectionIDLen()
   245  		} else {
   246  			connIDLen := t.ConnectionIDLength
   247  			if t.ConnectionIDLength == 0 && !allowZeroLengthConnIDs {
   248  				connIDLen = protocol.DefaultConnectionIDLength
   249  			}
   250  			t.connIDLen = connIDLen
   251  			t.connIDGenerator = &protocol.DefaultConnectionIDGenerator{ConnLen: t.connIDLen}
   252  		}
   253  
   254  		// getMultiplexer().AddConn(t.Conn)
   255  		go t.listen(conn)
   256  		go t.runSendQueue()
   257  	})
   258  	return t.initErr
   259  }
   260  
   261  // WriteTo sends a packet on the underlying connection.
   262  func (t *Transport) WriteTo(b []byte, addr net.Addr) (int, error) {
   263  	if err := t.init(false); err != nil {
   264  		return 0, err
   265  	}
   266  	return t.conn.WritePacket(b, addr, nil, 0, protocol.ECNUnsupported)
   267  }
   268  
   269  func (t *Transport) enqueueClosePacket(p closePacket) {
   270  	select {
   271  	case t.closeQueue <- p:
   272  	default:
   273  		// Oops, we're backlogged.
   274  		// Just drop the packet, sending CONNECTION_CLOSE copies is best effort anyway.
   275  	}
   276  }
   277  
   278  func (t *Transport) runSendQueue() {
   279  	for {
   280  		select {
   281  		case <-t.listening:
   282  			return
   283  		case p := <-t.closeQueue:
   284  			t.conn.WritePacket(p.payload, p.addr, p.info.OOB(), 0, protocol.ECNUnsupported)
   285  		case p := <-t.statelessResetQueue:
   286  			t.sendStatelessReset(p)
   287  		}
   288  	}
   289  }
   290  
   291  // Close closes the underlying connection.
   292  // If any listener was started, it will be closed as well.
   293  // It is invalid to start new listeners or connections after that.
   294  func (t *Transport) Close() error {
   295  	t.close(errors.New("closing"))
   296  	if t.createdConn {
   297  		if err := t.Conn.Close(); err != nil {
   298  			return err
   299  		}
   300  	} else if t.conn != nil {
   301  		t.conn.SetReadDeadline(time.Now())
   302  		defer func() { t.conn.SetReadDeadline(time.Time{}) }()
   303  	}
   304  	if t.listening != nil {
   305  		<-t.listening // wait until listening returns
   306  	}
   307  	return nil
   308  }
   309  
   310  func (t *Transport) closeServer() {
   311  	t.mutex.Lock()
   312  	t.server = nil
   313  	if t.isSingleUse {
   314  		t.closed = true
   315  	}
   316  	t.mutex.Unlock()
   317  	if t.createdConn {
   318  		t.Conn.Close()
   319  	}
   320  	if t.isSingleUse {
   321  		t.conn.SetReadDeadline(time.Now())
   322  		defer func() { t.conn.SetReadDeadline(time.Time{}) }()
   323  		<-t.listening // wait until listening returns
   324  	}
   325  }
   326  
   327  func (t *Transport) close(e error) {
   328  	t.mutex.Lock()
   329  	defer t.mutex.Unlock()
   330  	if t.closed {
   331  		return
   332  	}
   333  
   334  	if t.handlerMap != nil {
   335  		t.handlerMap.Close(e)
   336  	}
   337  	if t.server != nil {
   338  		t.server.close(e, false)
   339  	}
   340  	if t.Tracer != nil && t.Tracer.Close != nil {
   341  		t.Tracer.Close()
   342  	}
   343  	t.closed = true
   344  }
   345  
   346  func (t *Transport) listen(conn rawConn) {
   347  	defer close(t.listening)
   348  	// defer getMultiplexer().RemoveConn(t.Conn)
   349  
   350  	for {
   351  		p, err := conn.ReadPacket()
   352  		//nolint:staticcheck // SA1019 ignore this!
   353  		// TODO: This code is used to ignore wsa errors on Windows.
   354  		// Since net.Error.Temporary is deprecated as of Go 1.18, we should find a better solution.
   355  		// See https://github.com/sagernet/quic-go/issues/1737 for details.
   356  		if nerr, ok := err.(net.Error); ok && nerr.Temporary() {
   357  			t.mutex.Lock()
   358  			closed := t.closed
   359  			t.mutex.Unlock()
   360  			if closed {
   361  				return
   362  			}
   363  			t.logger.Debugf("Temporary error reading from conn: %w", err)
   364  			continue
   365  		}
   366  		if err != nil {
   367  			// Windows returns an error when receiving a UDP datagram that doesn't fit into the provided buffer.
   368  			if isRecvMsgSizeErr(err) {
   369  				continue
   370  			}
   371  			t.close(err)
   372  			return
   373  		}
   374  		t.handlePacket(p)
   375  	}
   376  }
   377  
   378  func (t *Transport) handlePacket(p receivedPacket) {
   379  	if len(p.data) == 0 {
   380  		return
   381  	}
   382  	if !wire.IsPotentialQUICPacket(p.data[0]) && !wire.IsLongHeaderPacket(p.data[0]) {
   383  		t.handleNonQUICPacket(p)
   384  		return
   385  	}
   386  	connID, err := wire.ParseConnectionID(p.data, t.connIDLen)
   387  	if err != nil {
   388  		t.logger.Debugf("error parsing connection ID on packet from %s: %s", p.remoteAddr, err)
   389  		if t.Tracer != nil && t.Tracer.DroppedPacket != nil {
   390  			t.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError)
   391  		}
   392  		p.buffer.MaybeRelease()
   393  		return
   394  	}
   395  
   396  	// If there's a connection associated with the connection ID, pass the packet there.
   397  	if handler, ok := t.handlerMap.Get(connID); ok {
   398  		handler.handlePacket(p)
   399  		return
   400  	}
   401  	// RFC 9000 section 10.3.1 requires that the stateless reset detection logic is run for both
   402  	// packets that cannot be associated with any connections, and for packets that can't be decrypted.
   403  	// We deviate from the RFC and ignore the latter: If a packet's connection ID is associated with an
   404  	// existing connection, it is dropped there if if it can't be decrypted.
   405  	// Stateless resets use random connection IDs, and at reasonable connection ID lengths collisions are
   406  	// exceedingly rare. In the unlikely event that a stateless reset is misrouted to an existing connection,
   407  	// it is to be expected that the next stateless reset will be correctly detected.
   408  	if isStatelessReset := t.maybeHandleStatelessReset(p.data); isStatelessReset {
   409  		return
   410  	}
   411  	if !wire.IsLongHeaderPacket(p.data[0]) {
   412  		t.maybeSendStatelessReset(p)
   413  		return
   414  	}
   415  
   416  	t.mutex.Lock()
   417  	defer t.mutex.Unlock()
   418  	if t.server == nil { // no server set
   419  		t.logger.Debugf("received a packet with an unexpected connection ID %s", connID)
   420  		return
   421  	}
   422  	t.server.handlePacket(p)
   423  }
   424  
   425  func (t *Transport) maybeSendStatelessReset(p receivedPacket) {
   426  	if t.StatelessResetKey == nil {
   427  		p.buffer.Release()
   428  		return
   429  	}
   430  
   431  	// Don't send a stateless reset in response to very small packets.
   432  	// This includes packets that could be stateless resets.
   433  	if len(p.data) <= protocol.MinStatelessResetSize {
   434  		p.buffer.Release()
   435  		return
   436  	}
   437  
   438  	select {
   439  	case t.statelessResetQueue <- p:
   440  	default:
   441  		// it's fine to not send a stateless reset when we're busy
   442  		p.buffer.Release()
   443  	}
   444  }
   445  
   446  func (t *Transport) sendStatelessReset(p receivedPacket) {
   447  	defer p.buffer.Release()
   448  
   449  	connID, err := wire.ParseConnectionID(p.data, t.connIDLen)
   450  	if err != nil {
   451  		t.logger.Errorf("error parsing connection ID on packet from %s: %s", p.remoteAddr, err)
   452  		return
   453  	}
   454  	token := t.handlerMap.GetStatelessResetToken(connID)
   455  	t.logger.Debugf("Sending stateless reset to %s (connection ID: %s). Token: %#x", p.remoteAddr, connID, token)
   456  	data := make([]byte, protocol.MinStatelessResetSize-16, protocol.MinStatelessResetSize)
   457  	rand.Read(data)
   458  	data[0] = (data[0] & 0x7f) | 0x40
   459  	data = append(data, token[:]...)
   460  	if _, err := t.conn.WritePacket(data, p.remoteAddr, p.info.OOB(), 0, protocol.ECNUnsupported); err != nil {
   461  		t.logger.Debugf("Error sending Stateless Reset to %s: %s", p.remoteAddr, err)
   462  	}
   463  }
   464  
   465  func (t *Transport) maybeHandleStatelessReset(data []byte) bool {
   466  	// stateless resets are always short header packets
   467  	if wire.IsLongHeaderPacket(data[0]) {
   468  		return false
   469  	}
   470  	if len(data) < 17 /* type byte + 16 bytes for the reset token */ {
   471  		return false
   472  	}
   473  
   474  	token := *(*protocol.StatelessResetToken)(data[len(data)-16:])
   475  	if conn, ok := t.handlerMap.GetByResetToken(token); ok {
   476  		t.logger.Debugf("Received a stateless reset with token %#x. Closing connection.", token)
   477  		go conn.destroy(&StatelessResetError{Token: token})
   478  		return true
   479  	}
   480  	return false
   481  }
   482  
   483  func (t *Transport) handleNonQUICPacket(p receivedPacket) {
   484  	// Strictly speaking, this is racy,
   485  	// but we only care about receiving packets at some point after ReadNonQUICPacket has been called.
   486  	if !t.readingNonQUICPackets.Load() {
   487  		return
   488  	}
   489  	select {
   490  	case t.nonQUICPackets <- p:
   491  	default:
   492  		if t.Tracer != nil && t.Tracer.DroppedPacket != nil {
   493  			t.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropDOSPrevention)
   494  		}
   495  	}
   496  }
   497  
   498  const maxQueuedNonQUICPackets = 32
   499  
   500  // ReadNonQUICPacket reads non-QUIC packets received on the underlying connection.
   501  // The detection logic is very simple: Any packet that has the first and second bit of the packet set to 0.
   502  // Note that this is stricter than the detection logic defined in RFC 9443.
   503  func (t *Transport) ReadNonQUICPacket(ctx context.Context, b []byte) (int, net.Addr, error) {
   504  	if err := t.init(false); err != nil {
   505  		return 0, nil, err
   506  	}
   507  	if !t.readingNonQUICPackets.Load() {
   508  		t.nonQUICPackets = make(chan receivedPacket, maxQueuedNonQUICPackets)
   509  		t.readingNonQUICPackets.Store(true)
   510  	}
   511  	select {
   512  	case <-ctx.Done():
   513  		return 0, nil, ctx.Err()
   514  	case p := <-t.nonQUICPackets:
   515  		n := copy(b, p.data)
   516  		return n, p.remoteAddr, nil
   517  	case <-t.listening:
   518  		return 0, nil, errors.New("closed")
   519  	}
   520  }
   521  
   522  func setTLSConfigServerName(tlsConf *tls.Config, addr net.Addr, host string) {
   523  	// If no ServerName is set, infer the ServerName from the host we're connecting to.
   524  	if tlsConf.ServerName != "" {
   525  		return
   526  	}
   527  	if host == "" {
   528  		if udpAddr, ok := addr.(*net.UDPAddr); ok {
   529  			tlsConf.ServerName = udpAddr.IP.String()
   530  			return
   531  		}
   532  	}
   533  	h, _, err := net.SplitHostPort(host)
   534  	if err != nil { // This happens if the host doesn't contain a port number.
   535  		tlsConf.ServerName = host
   536  		return
   537  	}
   538  	tlsConf.ServerName = h
   539  }