github.com/mikelsr/quic-go@v0.36.1-0.20230701132136-1d9415b66898/transport.go (about)

     1  package quic
     2  
     3  import (
     4  	"context"
     5  	"crypto/rand"
     6  	"crypto/tls"
     7  	"errors"
     8  	"log"
     9  	"net"
    10  	"os"
    11  	"strconv"
    12  	"strings"
    13  	"sync"
    14  	"time"
    15  
    16  	"github.com/mikelsr/quic-go/internal/wire"
    17  
    18  	"github.com/mikelsr/quic-go/internal/protocol"
    19  	"github.com/mikelsr/quic-go/internal/utils"
    20  	"github.com/mikelsr/quic-go/logging"
    21  )
    22  
    23  // The Transport is the central point to manage incoming and outgoing QUIC connections.
    24  // QUIC demultiplexes connections based on their QUIC Connection IDs, not based on the 4-tuple.
    25  // This means that a single UDP socket can be used for listening for incoming connections, as well as
    26  // for dialing an arbitrary number of outgoing connections.
    27  // A Transport handles a single net.PacketConn, and offers a range of configuration options
    28  // compared to the simple helper functions like Listen and Dial that this package provides.
    29  type Transport struct {
    30  	// A single net.PacketConn can only be handled by one Transport.
    31  	// Bad things will happen if passed to multiple Transports.
    32  	//
    33  	// If not done by the user, the connection is passed through OptimizeConn to enable a number of optimizations.
    34  	// After passing the connection to the Transport, it's invalid to call ReadFrom on the connection.
    35  	// Calling WriteTo is only valid on the connection returned by OptimizeConn.
    36  	Conn net.PacketConn
    37  
    38  	// The length of the connection ID in bytes.
    39  	// It can be 0, or any value between 4 and 18.
    40  	// If unset, a 4 byte connection ID will be used.
    41  	ConnectionIDLength int
    42  
    43  	// Use for generating new connection IDs.
    44  	// This allows the application to control of the connection IDs used,
    45  	// which allows routing / load balancing based on connection IDs.
    46  	// All Connection IDs returned by the ConnectionIDGenerator MUST
    47  	// have the same length.
    48  	ConnectionIDGenerator ConnectionIDGenerator
    49  
    50  	// The StatelessResetKey is used to generate stateless reset tokens.
    51  	// If no key is configured, sending of stateless resets is disabled.
    52  	// It is highly recommended to configure a stateless reset key, as stateless resets
    53  	// allow the peer to quickly recover from crashes and reboots of this node.
    54  	// See section 10.3 of RFC 9000 for details.
    55  	StatelessResetKey *StatelessResetKey
    56  
    57  	// A Tracer traces events that don't belong to a single QUIC connection.
    58  	Tracer logging.Tracer
    59  
    60  	handlerMap packetHandlerManager
    61  
    62  	mutex    sync.Mutex
    63  	initOnce sync.Once
    64  	initErr  error
    65  
    66  	// Set in init.
    67  	// If no ConnectionIDGenerator is set, this is the ConnectionIDLength.
    68  	connIDLen int
    69  	// Set in init.
    70  	// If no ConnectionIDGenerator is set, this is set to a default.
    71  	connIDGenerator ConnectionIDGenerator
    72  
    73  	server unknownPacketHandler
    74  
    75  	conn rawConn
    76  
    77  	closeQueue          chan closePacket
    78  	statelessResetQueue chan receivedPacket
    79  
    80  	listening   chan struct{} // is closed when listen returns
    81  	closed      bool
    82  	createdConn bool
    83  	isSingleUse bool // was created for a single server or client, i.e. by calling quic.Listen or quic.Dial
    84  
    85  	logger utils.Logger
    86  }
    87  
    88  // Listen starts listening for incoming QUIC connections.
    89  // There can only be a single listener on any net.PacketConn.
    90  // Listen may only be called again after the current Listener was closed.
    91  func (t *Transport) Listen(tlsConf *tls.Config, conf *Config) (*Listener, error) {
    92  	if tlsConf == nil {
    93  		return nil, errors.New("quic: tls.Config not set")
    94  	}
    95  	if err := validateConfig(conf); err != nil {
    96  		return nil, err
    97  	}
    98  
    99  	t.mutex.Lock()
   100  	defer t.mutex.Unlock()
   101  
   102  	if t.server != nil {
   103  		return nil, errListenerAlreadySet
   104  	}
   105  	conf = populateServerConfig(conf)
   106  	if err := t.init(true); err != nil {
   107  		return nil, err
   108  	}
   109  	s, err := newServer(t.conn, t.handlerMap, t.connIDGenerator, tlsConf, conf, t.Tracer, t.closeServer, false)
   110  	if err != nil {
   111  		return nil, err
   112  	}
   113  	t.server = s
   114  	return &Listener{baseServer: s}, nil
   115  }
   116  
   117  // ListenEarly starts listening for incoming QUIC connections.
   118  // There can only be a single listener on any net.PacketConn.
   119  // Listen may only be called again after the current Listener was closed.
   120  func (t *Transport) ListenEarly(tlsConf *tls.Config, conf *Config) (*EarlyListener, error) {
   121  	if tlsConf == nil {
   122  		return nil, errors.New("quic: tls.Config not set")
   123  	}
   124  	if err := validateConfig(conf); err != nil {
   125  		return nil, err
   126  	}
   127  
   128  	t.mutex.Lock()
   129  	defer t.mutex.Unlock()
   130  
   131  	if t.server != nil {
   132  		return nil, errListenerAlreadySet
   133  	}
   134  	conf = populateServerConfig(conf)
   135  	if err := t.init(true); err != nil {
   136  		return nil, err
   137  	}
   138  	s, err := newServer(t.conn, t.handlerMap, t.connIDGenerator, tlsConf, conf, t.Tracer, t.closeServer, true)
   139  	if err != nil {
   140  		return nil, err
   141  	}
   142  	t.server = s
   143  	return &EarlyListener{baseServer: s}, nil
   144  }
   145  
   146  // Dial dials a new connection to a remote host (not using 0-RTT).
   147  func (t *Transport) Dial(ctx context.Context, addr net.Addr, tlsConf *tls.Config, conf *Config) (Connection, error) {
   148  	if err := validateConfig(conf); err != nil {
   149  		return nil, err
   150  	}
   151  	conf = populateConfig(conf)
   152  	if err := t.init(false); err != nil {
   153  		return nil, err
   154  	}
   155  	var onClose func()
   156  	if t.isSingleUse {
   157  		onClose = func() { t.Close() }
   158  	}
   159  	tlsConf = tlsConf.Clone()
   160  	tlsConf.MinVersion = tls.VersionTLS13
   161  	return dial(ctx, newSendConn(t.conn, addr), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, false)
   162  }
   163  
   164  // DialEarly dials a new connection, attempting to use 0-RTT if possible.
   165  func (t *Transport) DialEarly(ctx context.Context, addr net.Addr, tlsConf *tls.Config, conf *Config) (EarlyConnection, error) {
   166  	if err := validateConfig(conf); err != nil {
   167  		return nil, err
   168  	}
   169  	conf = populateConfig(conf)
   170  	if err := t.init(false); err != nil {
   171  		return nil, err
   172  	}
   173  	var onClose func()
   174  	if t.isSingleUse {
   175  		onClose = func() { t.Close() }
   176  	}
   177  	tlsConf = tlsConf.Clone()
   178  	tlsConf.MinVersion = tls.VersionTLS13
   179  	return dial(ctx, newSendConn(t.conn, addr), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, true)
   180  }
   181  
   182  func (t *Transport) init(isServer bool) error {
   183  	t.initOnce.Do(func() {
   184  		var conn rawConn
   185  		if c, ok := t.Conn.(rawConn); ok {
   186  			conn = c
   187  		} else {
   188  			var err error
   189  			conn, err = wrapConn(t.Conn)
   190  			if err != nil {
   191  				t.initErr = err
   192  				return
   193  			}
   194  		}
   195  		t.conn = conn
   196  
   197  		t.logger = utils.DefaultLogger // TODO: make this configurable
   198  		t.conn = conn
   199  		t.handlerMap = newPacketHandlerMap(t.StatelessResetKey, t.enqueueClosePacket, t.logger)
   200  		t.listening = make(chan struct{})
   201  
   202  		t.closeQueue = make(chan closePacket, 4)
   203  		t.statelessResetQueue = make(chan receivedPacket, 4)
   204  
   205  		if t.ConnectionIDGenerator != nil {
   206  			t.connIDGenerator = t.ConnectionIDGenerator
   207  			t.connIDLen = t.ConnectionIDGenerator.ConnectionIDLen()
   208  		} else {
   209  			connIDLen := t.ConnectionIDLength
   210  			if t.ConnectionIDLength == 0 && (!t.isSingleUse || isServer) {
   211  				connIDLen = protocol.DefaultConnectionIDLength
   212  			}
   213  			t.connIDLen = connIDLen
   214  			t.connIDGenerator = &protocol.DefaultConnectionIDGenerator{ConnLen: t.connIDLen}
   215  		}
   216  
   217  		getMultiplexer().AddConn(t.Conn)
   218  		go t.listen(conn)
   219  		go t.runSendQueue()
   220  	})
   221  	return t.initErr
   222  }
   223  
   224  func (t *Transport) enqueueClosePacket(p closePacket) {
   225  	select {
   226  	case t.closeQueue <- p:
   227  	default:
   228  		// Oops, we're backlogged.
   229  		// Just drop the packet, sending CONNECTION_CLOSE copies is best effort anyway.
   230  	}
   231  }
   232  
   233  func (t *Transport) runSendQueue() {
   234  	for {
   235  		select {
   236  		case <-t.listening:
   237  			return
   238  		case p := <-t.closeQueue:
   239  			t.conn.WritePacket(p.payload, uint16(len(p.payload)), p.addr, p.info.OOB())
   240  		case p := <-t.statelessResetQueue:
   241  			t.sendStatelessReset(p)
   242  		}
   243  	}
   244  }
   245  
   246  // Close closes the underlying connection and waits until listen has returned.
   247  // It is invalid to start new listeners or connections after that.
   248  func (t *Transport) Close() error {
   249  	t.close(errors.New("closing"))
   250  	if t.createdConn {
   251  		if err := t.Conn.Close(); err != nil {
   252  			return err
   253  		}
   254  	} else if t.conn != nil {
   255  		t.conn.SetReadDeadline(time.Now())
   256  		defer func() { t.conn.SetReadDeadline(time.Time{}) }()
   257  	}
   258  	if t.listening != nil {
   259  		<-t.listening // wait until listening returns
   260  	}
   261  	return nil
   262  }
   263  
   264  func (t *Transport) closeServer() {
   265  	t.handlerMap.CloseServer()
   266  	t.mutex.Lock()
   267  	t.server = nil
   268  	if t.isSingleUse {
   269  		t.closed = true
   270  	}
   271  	t.mutex.Unlock()
   272  	if t.createdConn {
   273  		t.Conn.Close()
   274  	}
   275  	if t.isSingleUse {
   276  		t.conn.SetReadDeadline(time.Now())
   277  		defer func() { t.conn.SetReadDeadline(time.Time{}) }()
   278  		<-t.listening // wait until listening returns
   279  	}
   280  }
   281  
   282  func (t *Transport) close(e error) {
   283  	t.mutex.Lock()
   284  	defer t.mutex.Unlock()
   285  	if t.closed {
   286  		return
   287  	}
   288  
   289  	if t.handlerMap != nil {
   290  		t.handlerMap.Close(e)
   291  	}
   292  	if t.server != nil {
   293  		t.server.setCloseError(e)
   294  	}
   295  	t.closed = true
   296  }
   297  
   298  // only print warnings about the UDP receive buffer size once
   299  var setBufferWarningOnce sync.Once
   300  
   301  func (t *Transport) listen(conn rawConn) {
   302  	defer close(t.listening)
   303  	defer getMultiplexer().RemoveConn(t.Conn)
   304  
   305  	if err := setReceiveBuffer(t.Conn, t.logger); err != nil {
   306  		if !strings.Contains(err.Error(), "use of closed network connection") {
   307  			setBufferWarningOnce.Do(func() {
   308  				if disable, _ := strconv.ParseBool(os.Getenv("QUIC_GO_DISABLE_RECEIVE_BUFFER_WARNING")); disable {
   309  					return
   310  				}
   311  				log.Printf("%s. See https://github.com/mikelsr/quic-go/wiki/UDP-Receive-Buffer-Size for details.", err)
   312  			})
   313  		}
   314  	}
   315  	if err := setSendBuffer(t.Conn, t.logger); err != nil {
   316  		if !strings.Contains(err.Error(), "use of closed network connection") {
   317  			setBufferWarningOnce.Do(func() {
   318  				if disable, _ := strconv.ParseBool(os.Getenv("QUIC_GO_DISABLE_RECEIVE_BUFFER_WARNING")); disable {
   319  					return
   320  				}
   321  				log.Printf("%s. See https://github.com/mikelsr/quic-go/wiki/UDP-Receive-Buffer-Size for details.", err)
   322  			})
   323  		}
   324  	}
   325  
   326  	for {
   327  		p, err := conn.ReadPacket()
   328  		//nolint:staticcheck // SA1019 ignore this!
   329  		// TODO: This code is used to ignore wsa errors on Windows.
   330  		// Since net.Error.Temporary is deprecated as of Go 1.18, we should find a better solution.
   331  		// See https://github.com/mikelsr/quic-go/issues/1737 for details.
   332  		if nerr, ok := err.(net.Error); ok && nerr.Temporary() {
   333  			t.mutex.Lock()
   334  			closed := t.closed
   335  			t.mutex.Unlock()
   336  			if closed {
   337  				return
   338  			}
   339  			t.logger.Debugf("Temporary error reading from conn: %w", err)
   340  			continue
   341  		}
   342  		if err != nil {
   343  			t.close(err)
   344  			return
   345  		}
   346  		t.handlePacket(p)
   347  	}
   348  }
   349  
   350  func (t *Transport) handlePacket(p receivedPacket) {
   351  	connID, err := wire.ParseConnectionID(p.data, t.connIDLen)
   352  	if err != nil {
   353  		t.logger.Debugf("error parsing connection ID on packet from %s: %s", p.remoteAddr, err)
   354  		if t.Tracer != nil {
   355  			t.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError)
   356  		}
   357  		p.buffer.MaybeRelease()
   358  		return
   359  	}
   360  
   361  	if isStatelessReset := t.maybeHandleStatelessReset(p.data); isStatelessReset {
   362  		return
   363  	}
   364  	if handler, ok := t.handlerMap.Get(connID); ok {
   365  		handler.handlePacket(p)
   366  		return
   367  	}
   368  	if !wire.IsLongHeaderPacket(p.data[0]) {
   369  		t.maybeSendStatelessReset(p)
   370  		return
   371  	}
   372  
   373  	t.mutex.Lock()
   374  	defer t.mutex.Unlock()
   375  	if t.server == nil { // no server set
   376  		t.logger.Debugf("received a packet with an unexpected connection ID %s", connID)
   377  		return
   378  	}
   379  	t.server.handlePacket(p)
   380  }
   381  
   382  func (t *Transport) maybeSendStatelessReset(p receivedPacket) {
   383  	if t.StatelessResetKey == nil {
   384  		p.buffer.Release()
   385  		return
   386  	}
   387  
   388  	// Don't send a stateless reset in response to very small packets.
   389  	// This includes packets that could be stateless resets.
   390  	if len(p.data) <= protocol.MinStatelessResetSize {
   391  		p.buffer.Release()
   392  		return
   393  	}
   394  
   395  	select {
   396  	case t.statelessResetQueue <- p:
   397  	default:
   398  		// it's fine to not send a stateless reset when we're busy
   399  		p.buffer.Release()
   400  	}
   401  }
   402  
   403  func (t *Transport) sendStatelessReset(p receivedPacket) {
   404  	defer p.buffer.Release()
   405  
   406  	connID, err := wire.ParseConnectionID(p.data, t.connIDLen)
   407  	if err != nil {
   408  		t.logger.Errorf("error parsing connection ID on packet from %s: %s", p.remoteAddr, err)
   409  		return
   410  	}
   411  	token := t.handlerMap.GetStatelessResetToken(connID)
   412  	t.logger.Debugf("Sending stateless reset to %s (connection ID: %s). Token: %#x", p.remoteAddr, connID, token)
   413  	data := make([]byte, protocol.MinStatelessResetSize-16, protocol.MinStatelessResetSize)
   414  	rand.Read(data)
   415  	data[0] = (data[0] & 0x7f) | 0x40
   416  	data = append(data, token[:]...)
   417  	if _, err := t.conn.WritePacket(data, uint16(len(data)), p.remoteAddr, p.info.OOB()); err != nil {
   418  		t.logger.Debugf("Error sending Stateless Reset to %s: %s", p.remoteAddr, err)
   419  	}
   420  }
   421  
   422  func (t *Transport) maybeHandleStatelessReset(data []byte) bool {
   423  	// stateless resets are always short header packets
   424  	if wire.IsLongHeaderPacket(data[0]) {
   425  		return false
   426  	}
   427  	if len(data) < 17 /* type byte + 16 bytes for the reset token */ {
   428  		return false
   429  	}
   430  
   431  	token := *(*protocol.StatelessResetToken)(data[len(data)-16:])
   432  	if conn, ok := t.handlerMap.GetByResetToken(token); ok {
   433  		t.logger.Debugf("Received a stateless reset with token %#x. Closing connection.", token)
   434  		go conn.destroy(&StatelessResetError{Token: token})
   435  		return true
   436  	}
   437  	return false
   438  }