github.com/MerlinKodo/quic-go@v0.39.2/transport.go (about)

     1  package quic
     2  
     3  import (
     4  	"context"
     5  	"crypto/rand"
     6  	"crypto/tls"
     7  	"errors"
     8  	"net"
     9  	"sync"
    10  	"sync/atomic"
    11  	"time"
    12  
    13  	"github.com/MerlinKodo/quic-go/internal/protocol"
    14  	"github.com/MerlinKodo/quic-go/internal/utils"
    15  	"github.com/MerlinKodo/quic-go/internal/wire"
    16  	"github.com/MerlinKodo/quic-go/logging"
    17  )
    18  
    19  var errListenerAlreadySet = errors.New("listener already set")
    20  
    21  // The Transport is the central point to manage incoming and outgoing QUIC connections.
    22  // QUIC demultiplexes connections based on their QUIC Connection IDs, not based on the 4-tuple.
    23  // This means that a single UDP socket can be used for listening for incoming connections, as well as
    24  // for dialing an arbitrary number of outgoing connections.
    25  // A Transport handles a single net.PacketConn, and offers a range of configuration options
    26  // compared to the simple helper functions like Listen and Dial that this package provides.
    27  type Transport struct {
    28  	// A single net.PacketConn can only be handled by one Transport.
    29  	// Bad things will happen if passed to multiple Transports.
    30  	//
    31  	// A number of optimizations will be enabled if the connections implements the OOBCapablePacketConn interface,
    32  	// as a *net.UDPConn does.
    33  	// 1. It enables the Don't Fragment (DF) bit on the IP header.
    34  	//    This is required to run DPLPMTUD (Path MTU Discovery, RFC 8899).
    35  	// 2. It enables reading of the ECN bits from the IP header.
    36  	//    This allows the remote node to speed up its loss detection and recovery.
    37  	// 3. It uses batched syscalls (recvmmsg) to more efficiently receive packets from the socket.
    38  	// 4. It uses Generic Segmentation Offload (GSO) to efficiently send batches of packets (on Linux).
    39  	//
    40  	// After passing the connection to the Transport, it's invalid to call ReadFrom or WriteTo on the connection.
    41  	Conn net.PacketConn
    42  
    43  	// The length of the connection ID in bytes.
    44  	// It can be 0, or any value between 4 and 18.
    45  	// If unset, a 4 byte connection ID will be used.
    46  	ConnectionIDLength int
    47  
    48  	// Use for generating new connection IDs.
    49  	// This allows the application to control of the connection IDs used,
    50  	// which allows routing / load balancing based on connection IDs.
    51  	// All Connection IDs returned by the ConnectionIDGenerator MUST
    52  	// have the same length.
    53  	ConnectionIDGenerator ConnectionIDGenerator
    54  
    55  	// The StatelessResetKey is used to generate stateless reset tokens.
    56  	// If no key is configured, sending of stateless resets is disabled.
    57  	// It is highly recommended to configure a stateless reset key, as stateless resets
    58  	// allow the peer to quickly recover from crashes and reboots of this node.
    59  	// See section 10.3 of RFC 9000 for details.
    60  	StatelessResetKey *StatelessResetKey
    61  
    62  	// The TokenGeneratorKey is used to encrypt session resumption tokens.
    63  	// If no key is configured, a random key will be generated.
    64  	// If multiple servers are authoritative for the same domain, they should use the same key,
    65  	// see section 8.1.3 of RFC 9000 for details.
    66  	TokenGeneratorKey *TokenGeneratorKey
    67  
    68  	// MaxTokenAge is the maximum age of the resumption token presented during the handshake.
    69  	// These tokens allow skipping address resumption when resuming a QUIC connection,
    70  	// and are especially useful when using 0-RTT.
    71  	// If not set, it defaults to 24 hours.
    72  	// See section 8.1.3 of RFC 9000 for details.
    73  	MaxTokenAge time.Duration
    74  
    75  	// DisableVersionNegotiationPackets disables the sending of Version Negotiation packets.
    76  	// This can be useful if version information is exchanged out-of-band.
    77  	// It has no effect for clients.
    78  	DisableVersionNegotiationPackets bool
    79  
    80  	// A Tracer traces events that don't belong to a single QUIC connection.
    81  	Tracer *logging.Tracer
    82  
    83  	handlerMap packetHandlerManager
    84  
    85  	mutex    sync.Mutex
    86  	initOnce sync.Once
    87  	initErr  error
    88  
    89  	// Set in init.
    90  	// If no ConnectionIDGenerator is set, this is the ConnectionIDLength.
    91  	connIDLen int
    92  	// Set in init.
    93  	// If no ConnectionIDGenerator is set, this is set to a default.
    94  	connIDGenerator ConnectionIDGenerator
    95  
    96  	server *baseServer
    97  
    98  	conn rawConn
    99  
   100  	closeQueue          chan closePacket
   101  	statelessResetQueue chan receivedPacket
   102  
   103  	listening   chan struct{} // is closed when listen returns
   104  	closed      bool
   105  	createdConn bool
   106  	isSingleUse bool // was created for a single server or client, i.e. by calling quic.Listen or quic.Dial
   107  
   108  	readingNonQUICPackets atomic.Bool
   109  	nonQUICPackets        chan receivedPacket
   110  
   111  	logger utils.Logger
   112  }
   113  
   114  // Listen starts listening for incoming QUIC connections.
   115  // There can only be a single listener on any net.PacketConn.
   116  // Listen may only be called again after the current Listener was closed.
   117  func (t *Transport) Listen(tlsConf *tls.Config, conf *Config) (*Listener, error) {
   118  	s, err := t.createServer(tlsConf, conf, false)
   119  	if err != nil {
   120  		return nil, err
   121  	}
   122  	return &Listener{baseServer: s}, nil
   123  }
   124  
   125  // ListenEarly starts listening for incoming QUIC connections.
   126  // There can only be a single listener on any net.PacketConn.
   127  // Listen may only be called again after the current Listener was closed.
   128  func (t *Transport) ListenEarly(tlsConf *tls.Config, conf *Config) (*EarlyListener, error) {
   129  	s, err := t.createServer(tlsConf, conf, true)
   130  	if err != nil {
   131  		return nil, err
   132  	}
   133  	return &EarlyListener{baseServer: s}, nil
   134  }
   135  
   136  func (t *Transport) createServer(tlsConf *tls.Config, conf *Config, allow0RTT bool) (*baseServer, error) {
   137  	if tlsConf == nil {
   138  		return nil, errors.New("quic: tls.Config not set")
   139  	}
   140  	if err := validateConfig(conf); err != nil {
   141  		return nil, err
   142  	}
   143  
   144  	t.mutex.Lock()
   145  	defer t.mutex.Unlock()
   146  
   147  	if t.server != nil {
   148  		return nil, errListenerAlreadySet
   149  	}
   150  	conf = populateServerConfig(conf)
   151  	if err := t.init(false); err != nil {
   152  		return nil, err
   153  	}
   154  	s := newServer(
   155  		t.conn,
   156  		t.handlerMap,
   157  		t.connIDGenerator,
   158  		tlsConf,
   159  		conf,
   160  		t.Tracer,
   161  		t.closeServer,
   162  		*t.TokenGeneratorKey,
   163  		t.MaxTokenAge,
   164  		t.DisableVersionNegotiationPackets,
   165  		allow0RTT,
   166  	)
   167  	t.server = s
   168  	return s, nil
   169  }
   170  
   171  // Dial dials a new connection to a remote host (not using 0-RTT).
   172  func (t *Transport) Dial(ctx context.Context, addr net.Addr, tlsConf *tls.Config, conf *Config) (Connection, error) {
   173  	return t.dial(ctx, addr, "", tlsConf, conf, false)
   174  }
   175  
   176  // DialEarly dials a new connection, attempting to use 0-RTT if possible.
   177  func (t *Transport) DialEarly(ctx context.Context, addr net.Addr, tlsConf *tls.Config, conf *Config) (EarlyConnection, error) {
   178  	return t.dial(ctx, addr, "", tlsConf, conf, true)
   179  }
   180  
   181  func (t *Transport) dial(ctx context.Context, addr net.Addr, host string, tlsConf *tls.Config, conf *Config, use0RTT bool) (EarlyConnection, error) {
   182  	if err := validateConfig(conf); err != nil {
   183  		return nil, err
   184  	}
   185  	conf = populateConfig(conf)
   186  	if err := t.init(t.isSingleUse); err != nil {
   187  		return nil, err
   188  	}
   189  	var onClose func()
   190  	if t.isSingleUse {
   191  		onClose = func() { t.Close() }
   192  	}
   193  	tlsConf = tlsConf.Clone()
   194  	tlsConf.MinVersion = tls.VersionTLS13
   195  	//setTLSConfigServerName(tlsConf, addr, host)
   196  	return dial(ctx, newSendConn(t.conn, addr, packetInfo{}, utils.DefaultLogger), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, use0RTT)
   197  }
   198  
   199  func (t *Transport) init(allowZeroLengthConnIDs bool) error {
   200  	t.initOnce.Do(func() {
   201  		var conn rawConn
   202  		if c, ok := t.Conn.(rawConn); ok {
   203  			conn = c
   204  		} else {
   205  			var err error
   206  			conn, err = wrapConn(t.Conn)
   207  			if err != nil {
   208  				t.initErr = err
   209  				return
   210  			}
   211  		}
   212  
   213  		t.logger = utils.DefaultLogger // TODO: make this configurable
   214  		t.conn = conn
   215  		t.handlerMap = newPacketHandlerMap(t.StatelessResetKey, t.enqueueClosePacket, t.logger)
   216  		t.listening = make(chan struct{})
   217  
   218  		t.closeQueue = make(chan closePacket, 4)
   219  		t.statelessResetQueue = make(chan receivedPacket, 4)
   220  		if t.TokenGeneratorKey == nil {
   221  			var key TokenGeneratorKey
   222  			if _, err := rand.Read(key[:]); err != nil {
   223  				t.initErr = err
   224  				return
   225  			}
   226  			t.TokenGeneratorKey = &key
   227  		}
   228  
   229  		if t.ConnectionIDGenerator != nil {
   230  			t.connIDGenerator = t.ConnectionIDGenerator
   231  			t.connIDLen = t.ConnectionIDGenerator.ConnectionIDLen()
   232  		} else {
   233  			connIDLen := t.ConnectionIDLength
   234  			if t.ConnectionIDLength == 0 && !allowZeroLengthConnIDs {
   235  				connIDLen = protocol.DefaultConnectionIDLength
   236  			}
   237  			t.connIDLen = connIDLen
   238  			t.connIDGenerator = &protocol.DefaultConnectionIDGenerator{ConnLen: t.connIDLen}
   239  		}
   240  
   241  		//getMultiplexer().AddConn(t.Conn)
   242  		go t.listen(conn)
   243  		go t.runSendQueue()
   244  	})
   245  	return t.initErr
   246  }
   247  
   248  // WriteTo sends a packet on the underlying connection.
   249  func (t *Transport) WriteTo(b []byte, addr net.Addr) (int, error) {
   250  	if err := t.init(false); err != nil {
   251  		return 0, err
   252  	}
   253  	return t.conn.WritePacket(b, addr, nil, 0, protocol.ECNUnsupported)
   254  }
   255  
   256  func (t *Transport) enqueueClosePacket(p closePacket) {
   257  	select {
   258  	case t.closeQueue <- p:
   259  	default:
   260  		// Oops, we're backlogged.
   261  		// Just drop the packet, sending CONNECTION_CLOSE copies is best effort anyway.
   262  	}
   263  }
   264  
   265  func (t *Transport) runSendQueue() {
   266  	for {
   267  		select {
   268  		case <-t.listening:
   269  			return
   270  		case p := <-t.closeQueue:
   271  			t.conn.WritePacket(p.payload, p.addr, p.info.OOB(), 0, protocol.ECNUnsupported)
   272  		case p := <-t.statelessResetQueue:
   273  			t.sendStatelessReset(p)
   274  		}
   275  	}
   276  }
   277  
   278  // Close closes the underlying connection and waits until listen has returned.
   279  // It is invalid to start new listeners or connections after that.
   280  func (t *Transport) Close() error {
   281  	t.close(errors.New("closing"))
   282  	if t.createdConn {
   283  		if err := t.Conn.Close(); err != nil {
   284  			return err
   285  		}
   286  	} else if t.conn != nil {
   287  		t.conn.SetReadDeadline(time.Now())
   288  		defer func() { t.conn.SetReadDeadline(time.Time{}) }()
   289  	}
   290  	if t.listening != nil {
   291  		<-t.listening // wait until listening returns
   292  	}
   293  	return nil
   294  }
   295  
   296  func (t *Transport) closeServer() {
   297  	t.handlerMap.CloseServer()
   298  	t.mutex.Lock()
   299  	t.server = nil
   300  	if t.isSingleUse {
   301  		t.closed = true
   302  	}
   303  	t.mutex.Unlock()
   304  	if t.createdConn {
   305  		t.Conn.Close()
   306  	}
   307  	if t.isSingleUse {
   308  		t.conn.SetReadDeadline(time.Now())
   309  		defer func() { t.conn.SetReadDeadline(time.Time{}) }()
   310  		<-t.listening // wait until listening returns
   311  	}
   312  }
   313  
   314  func (t *Transport) close(e error) {
   315  	t.mutex.Lock()
   316  	defer t.mutex.Unlock()
   317  	if t.closed {
   318  		return
   319  	}
   320  
   321  	if t.handlerMap != nil {
   322  		t.handlerMap.Close(e)
   323  	}
   324  	if t.server != nil {
   325  		t.server.setCloseError(e)
   326  	}
   327  	t.closed = true
   328  }
   329  
   330  func (t *Transport) listen(conn rawConn) {
   331  	defer close(t.listening)
   332  	//defer getMultiplexer().RemoveConn(t.Conn)
   333  
   334  	for {
   335  		p, err := conn.ReadPacket()
   336  		//nolint:staticcheck // SA1019 ignore this!
   337  		// TODO: This code is used to ignore wsa errors on Windows.
   338  		// Since net.Error.Temporary is deprecated as of Go 1.18, we should find a better solution.
   339  		// See https://github.com/MerlinKodo/quic-go/issues/1737 for details.
   340  		if nerr, ok := err.(net.Error); ok && nerr.Temporary() {
   341  			t.mutex.Lock()
   342  			closed := t.closed
   343  			t.mutex.Unlock()
   344  			if closed {
   345  				return
   346  			}
   347  			t.logger.Debugf("Temporary error reading from conn: %w", err)
   348  			continue
   349  		}
   350  		if err != nil {
   351  			// Windows returns an error when receiving a UDP datagram that doesn't fit into the provided buffer.
   352  			if isRecvMsgSizeErr(err) {
   353  				continue
   354  			}
   355  			t.close(err)
   356  			return
   357  		}
   358  		t.handlePacket(p)
   359  	}
   360  }
   361  
   362  func (t *Transport) handlePacket(p receivedPacket) {
   363  	if len(p.data) == 0 {
   364  		return
   365  	}
   366  	if !wire.IsPotentialQUICPacket(p.data[0]) && !wire.IsLongHeaderPacket(p.data[0]) {
   367  		t.handleNonQUICPacket(p)
   368  		return
   369  	}
   370  	connID, err := wire.ParseConnectionID(p.data, t.connIDLen)
   371  	if err != nil {
   372  		t.logger.Debugf("error parsing connection ID on packet from %s: %s", p.remoteAddr, err)
   373  		if t.Tracer != nil && t.Tracer.DroppedPacket != nil {
   374  			t.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError)
   375  		}
   376  		p.buffer.MaybeRelease()
   377  		return
   378  	}
   379  
   380  	if isStatelessReset := t.maybeHandleStatelessReset(p.data); isStatelessReset {
   381  		return
   382  	}
   383  	if handler, ok := t.handlerMap.Get(connID); ok {
   384  		handler.handlePacket(p)
   385  		return
   386  	}
   387  	if !wire.IsLongHeaderPacket(p.data[0]) {
   388  		t.maybeSendStatelessReset(p)
   389  		return
   390  	}
   391  
   392  	t.mutex.Lock()
   393  	defer t.mutex.Unlock()
   394  	if t.server == nil { // no server set
   395  		t.logger.Debugf("received a packet with an unexpected connection ID %s", connID)
   396  		return
   397  	}
   398  	t.server.handlePacket(p)
   399  }
   400  
   401  func (t *Transport) maybeSendStatelessReset(p receivedPacket) {
   402  	if t.StatelessResetKey == nil {
   403  		p.buffer.Release()
   404  		return
   405  	}
   406  
   407  	// Don't send a stateless reset in response to very small packets.
   408  	// This includes packets that could be stateless resets.
   409  	if len(p.data) <= protocol.MinStatelessResetSize {
   410  		p.buffer.Release()
   411  		return
   412  	}
   413  
   414  	select {
   415  	case t.statelessResetQueue <- p:
   416  	default:
   417  		// it's fine to not send a stateless reset when we're busy
   418  		p.buffer.Release()
   419  	}
   420  }
   421  
   422  func (t *Transport) sendStatelessReset(p receivedPacket) {
   423  	defer p.buffer.Release()
   424  
   425  	connID, err := wire.ParseConnectionID(p.data, t.connIDLen)
   426  	if err != nil {
   427  		t.logger.Errorf("error parsing connection ID on packet from %s: %s", p.remoteAddr, err)
   428  		return
   429  	}
   430  	token := t.handlerMap.GetStatelessResetToken(connID)
   431  	t.logger.Debugf("Sending stateless reset to %s (connection ID: %s). Token: %#x", p.remoteAddr, connID, token)
   432  	data := make([]byte, protocol.MinStatelessResetSize-16, protocol.MinStatelessResetSize)
   433  	rand.Read(data)
   434  	data[0] = (data[0] & 0x7f) | 0x40
   435  	data = append(data, token[:]...)
   436  	if _, err := t.conn.WritePacket(data, p.remoteAddr, p.info.OOB(), 0, protocol.ECNUnsupported); err != nil {
   437  		t.logger.Debugf("Error sending Stateless Reset to %s: %s", p.remoteAddr, err)
   438  	}
   439  }
   440  
   441  func (t *Transport) maybeHandleStatelessReset(data []byte) bool {
   442  	// stateless resets are always short header packets
   443  	if wire.IsLongHeaderPacket(data[0]) {
   444  		return false
   445  	}
   446  	if len(data) < 17 /* type byte + 16 bytes for the reset token */ {
   447  		return false
   448  	}
   449  
   450  	token := *(*protocol.StatelessResetToken)(data[len(data)-16:])
   451  	if conn, ok := t.handlerMap.GetByResetToken(token); ok {
   452  		t.logger.Debugf("Received a stateless reset with token %#x. Closing connection.", token)
   453  		go conn.destroy(&StatelessResetError{Token: token})
   454  		return true
   455  	}
   456  	return false
   457  }
   458  
   459  func (t *Transport) handleNonQUICPacket(p receivedPacket) {
   460  	// Strictly speaking, this is racy,
   461  	// but we only care about receiving packets at some point after ReadNonQUICPacket has been called.
   462  	if !t.readingNonQUICPackets.Load() {
   463  		return
   464  	}
   465  	select {
   466  	case t.nonQUICPackets <- p:
   467  	default:
   468  		if t.Tracer != nil && t.Tracer.DroppedPacket != nil {
   469  			t.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropDOSPrevention)
   470  		}
   471  	}
   472  }
   473  
   474  const maxQueuedNonQUICPackets = 32
   475  
   476  // ReadNonQUICPacket reads non-QUIC packets received on the underlying connection.
   477  // The detection logic is very simple: Any packet that has the first and second bit of the packet set to 0.
   478  // Note that this is stricter than the detection logic defined in RFC 9443.
   479  func (t *Transport) ReadNonQUICPacket(ctx context.Context, b []byte) (int, net.Addr, error) {
   480  	if err := t.init(false); err != nil {
   481  		return 0, nil, err
   482  	}
   483  	if !t.readingNonQUICPackets.Load() {
   484  		t.nonQUICPackets = make(chan receivedPacket, maxQueuedNonQUICPackets)
   485  		t.readingNonQUICPackets.Store(true)
   486  	}
   487  	select {
   488  	case <-ctx.Done():
   489  		return 0, nil, ctx.Err()
   490  	case p := <-t.nonQUICPackets:
   491  		n := copy(b, p.data)
   492  		return n, p.remoteAddr, nil
   493  	case <-t.listening:
   494  		return 0, nil, errors.New("closed")
   495  	}
   496  }
   497  
   498  func setTLSConfigServerName(tlsConf *tls.Config, addr net.Addr, host string) {
   499  	// If no ServerName is set, infer the ServerName from the host we're connecting to.
   500  	if tlsConf.ServerName != "" {
   501  		return
   502  	}
   503  	if host == "" {
   504  		if udpAddr, ok := addr.(*net.UDPAddr); ok {
   505  			tlsConf.ServerName = udpAddr.IP.String()
   506  			return
   507  		}
   508  	}
   509  	h, _, err := net.SplitHostPort(host)
   510  	if err != nil { // This happens if the host doesn't contain a port number.
   511  		tlsConf.ServerName = host
   512  		return
   513  	}
   514  	tlsConf.ServerName = h
   515  }
   516  
   517  func (t *Transport) SetCreatedConn(createdConn bool) {
   518  	t.createdConn = createdConn
   519  }
   520  
   521  func (t *Transport) SetSingleUse(isSingleUse bool) {
   522  	t.isSingleUse = isSingleUse
   523  }