golang.org/x/net@v0.25.1-0.20240516223405-c87a5b62e243/quic/endpoint.go (about)

     1  // Copyright 2023 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  //go:build go1.21
     6  
     7  package quic
     8  
     9  import (
    10  	"context"
    11  	"crypto/rand"
    12  	"errors"
    13  	"net"
    14  	"net/netip"
    15  	"sync"
    16  	"sync/atomic"
    17  	"time"
    18  )
    19  
    20  // An Endpoint handles QUIC traffic on a network address.
    21  // It can accept inbound connections or create outbound ones.
    22  //
    23  // Multiple goroutines may invoke methods on an Endpoint simultaneously.
    24  type Endpoint struct {
    25  	listenConfig *Config
    26  	packetConn   packetConn
    27  	testHooks    endpointTestHooks
    28  	resetGen     statelessResetTokenGenerator
    29  	retry        retryState
    30  
    31  	acceptQueue queue[*Conn] // new inbound connections
    32  	connsMap    connsMap     // only accessed by the listen loop
    33  
    34  	connsMu sync.Mutex
    35  	conns   map[*Conn]struct{}
    36  	closing bool          // set when Close is called
    37  	closec  chan struct{} // closed when the listen loop exits
    38  }
    39  
    40  type endpointTestHooks interface {
    41  	timeNow() time.Time
    42  	newConn(c *Conn)
    43  }
    44  
    45  // A packetConn is the interface to sending and receiving UDP packets.
    46  type packetConn interface {
    47  	Close() error
    48  	LocalAddr() netip.AddrPort
    49  	Read(f func(*datagram))
    50  	Write(datagram) error
    51  }
    52  
    53  // Listen listens on a local network address.
    54  //
    55  // The config is used to for connections accepted by the endpoint.
    56  // If the config is nil, the endpoint will not accept connections.
    57  func Listen(network, address string, listenConfig *Config) (*Endpoint, error) {
    58  	if listenConfig != nil && listenConfig.TLSConfig == nil {
    59  		return nil, errors.New("TLSConfig is not set")
    60  	}
    61  	a, err := net.ResolveUDPAddr(network, address)
    62  	if err != nil {
    63  		return nil, err
    64  	}
    65  	udpConn, err := net.ListenUDP(network, a)
    66  	if err != nil {
    67  		return nil, err
    68  	}
    69  	pc, err := newNetUDPConn(udpConn)
    70  	if err != nil {
    71  		return nil, err
    72  	}
    73  	return newEndpoint(pc, listenConfig, nil)
    74  }
    75  
    76  func newEndpoint(pc packetConn, config *Config, hooks endpointTestHooks) (*Endpoint, error) {
    77  	e := &Endpoint{
    78  		listenConfig: config,
    79  		packetConn:   pc,
    80  		testHooks:    hooks,
    81  		conns:        make(map[*Conn]struct{}),
    82  		acceptQueue:  newQueue[*Conn](),
    83  		closec:       make(chan struct{}),
    84  	}
    85  	var statelessResetKey [32]byte
    86  	if config != nil {
    87  		statelessResetKey = config.StatelessResetKey
    88  	}
    89  	e.resetGen.init(statelessResetKey)
    90  	e.connsMap.init()
    91  	if config != nil && config.RequireAddressValidation {
    92  		if err := e.retry.init(); err != nil {
    93  			return nil, err
    94  		}
    95  	}
    96  	go e.listen()
    97  	return e, nil
    98  }
    99  
   100  // LocalAddr returns the local network address.
   101  func (e *Endpoint) LocalAddr() netip.AddrPort {
   102  	return e.packetConn.LocalAddr()
   103  }
   104  
   105  // Close closes the Endpoint.
   106  // Any blocked operations on the Endpoint or associated Conns and Stream will be unblocked
   107  // and return errors.
   108  //
   109  // Close aborts every open connection.
   110  // Data in stream read and write buffers is discarded.
   111  // It waits for the peers of any open connection to acknowledge the connection has been closed.
   112  func (e *Endpoint) Close(ctx context.Context) error {
   113  	e.acceptQueue.close(errors.New("endpoint closed"))
   114  
   115  	// It isn't safe to call Conn.Abort or conn.exit with connsMu held,
   116  	// so copy the list of conns.
   117  	var conns []*Conn
   118  	e.connsMu.Lock()
   119  	if !e.closing {
   120  		e.closing = true // setting e.closing prevents new conns from being created
   121  		for c := range e.conns {
   122  			conns = append(conns, c)
   123  		}
   124  		if len(e.conns) == 0 {
   125  			e.packetConn.Close()
   126  		}
   127  	}
   128  	e.connsMu.Unlock()
   129  
   130  	for _, c := range conns {
   131  		c.Abort(localTransportError{code: errNo})
   132  	}
   133  	select {
   134  	case <-e.closec:
   135  	case <-ctx.Done():
   136  		for _, c := range conns {
   137  			c.exit()
   138  		}
   139  		return ctx.Err()
   140  	}
   141  	return nil
   142  }
   143  
   144  // Accept waits for and returns the next connection.
   145  func (e *Endpoint) Accept(ctx context.Context) (*Conn, error) {
   146  	return e.acceptQueue.get(ctx, nil)
   147  }
   148  
   149  // Dial creates and returns a connection to a network address.
   150  // The config cannot be nil.
   151  func (e *Endpoint) Dial(ctx context.Context, network, address string, config *Config) (*Conn, error) {
   152  	u, err := net.ResolveUDPAddr(network, address)
   153  	if err != nil {
   154  		return nil, err
   155  	}
   156  	addr := u.AddrPort()
   157  	addr = netip.AddrPortFrom(addr.Addr().Unmap(), addr.Port())
   158  	c, err := e.newConn(time.Now(), config, clientSide, newServerConnIDs{}, address, addr)
   159  	if err != nil {
   160  		return nil, err
   161  	}
   162  	if err := c.waitReady(ctx); err != nil {
   163  		c.Abort(nil)
   164  		return nil, err
   165  	}
   166  	return c, nil
   167  }
   168  
   169  func (e *Endpoint) newConn(now time.Time, config *Config, side connSide, cids newServerConnIDs, peerHostname string, peerAddr netip.AddrPort) (*Conn, error) {
   170  	e.connsMu.Lock()
   171  	defer e.connsMu.Unlock()
   172  	if e.closing {
   173  		return nil, errors.New("endpoint closed")
   174  	}
   175  	c, err := newConn(now, side, cids, peerHostname, peerAddr, config, e)
   176  	if err != nil {
   177  		return nil, err
   178  	}
   179  	e.conns[c] = struct{}{}
   180  	return c, nil
   181  }
   182  
   183  // serverConnEstablished is called by a conn when the handshake completes
   184  // for an inbound (serverSide) connection.
   185  func (e *Endpoint) serverConnEstablished(c *Conn) {
   186  	e.acceptQueue.put(c)
   187  }
   188  
   189  // connDrained is called by a conn when it leaves the draining state,
   190  // either when the peer acknowledges connection closure or the drain timeout expires.
   191  func (e *Endpoint) connDrained(c *Conn) {
   192  	var cids [][]byte
   193  	for i := range c.connIDState.local {
   194  		cids = append(cids, c.connIDState.local[i].cid)
   195  	}
   196  	var tokens []statelessResetToken
   197  	for i := range c.connIDState.remote {
   198  		tokens = append(tokens, c.connIDState.remote[i].resetToken)
   199  	}
   200  	e.connsMap.updateConnIDs(func(conns *connsMap) {
   201  		for _, cid := range cids {
   202  			conns.retireConnID(c, cid)
   203  		}
   204  		for _, token := range tokens {
   205  			conns.retireResetToken(c, token)
   206  		}
   207  	})
   208  	e.connsMu.Lock()
   209  	defer e.connsMu.Unlock()
   210  	delete(e.conns, c)
   211  	if e.closing && len(e.conns) == 0 {
   212  		e.packetConn.Close()
   213  	}
   214  }
   215  
   216  func (e *Endpoint) listen() {
   217  	defer close(e.closec)
   218  	e.packetConn.Read(func(m *datagram) {
   219  		if e.connsMap.updateNeeded.Load() {
   220  			e.connsMap.applyUpdates()
   221  		}
   222  		e.handleDatagram(m)
   223  	})
   224  }
   225  
   226  func (e *Endpoint) handleDatagram(m *datagram) {
   227  	dstConnID, ok := dstConnIDForDatagram(m.b)
   228  	if !ok {
   229  		m.recycle()
   230  		return
   231  	}
   232  	c := e.connsMap.byConnID[string(dstConnID)]
   233  	if c == nil {
   234  		// TODO: Move this branch into a separate goroutine to avoid blocking
   235  		// the endpoint while processing packets.
   236  		e.handleUnknownDestinationDatagram(m)
   237  		return
   238  	}
   239  
   240  	// TODO: This can block the endpoint while waiting for the conn to accept the dgram.
   241  	// Think about buffering between the receive loop and the conn.
   242  	c.sendMsg(m)
   243  }
   244  
   245  func (e *Endpoint) handleUnknownDestinationDatagram(m *datagram) {
   246  	defer func() {
   247  		if m != nil {
   248  			m.recycle()
   249  		}
   250  	}()
   251  	const minimumValidPacketSize = 21
   252  	if len(m.b) < minimumValidPacketSize {
   253  		return
   254  	}
   255  	var now time.Time
   256  	if e.testHooks != nil {
   257  		now = e.testHooks.timeNow()
   258  	} else {
   259  		now = time.Now()
   260  	}
   261  	// Check to see if this is a stateless reset.
   262  	var token statelessResetToken
   263  	copy(token[:], m.b[len(m.b)-len(token):])
   264  	if c := e.connsMap.byResetToken[token]; c != nil {
   265  		c.sendMsg(func(now time.Time, c *Conn) {
   266  			c.handleStatelessReset(now, token)
   267  		})
   268  		return
   269  	}
   270  	// If this is a 1-RTT packet, there's nothing productive we can do with it.
   271  	// Send a stateless reset if possible.
   272  	if !isLongHeader(m.b[0]) {
   273  		e.maybeSendStatelessReset(m.b, m.peerAddr)
   274  		return
   275  	}
   276  	p, ok := parseGenericLongHeaderPacket(m.b)
   277  	if !ok || len(m.b) < paddedInitialDatagramSize {
   278  		return
   279  	}
   280  	switch p.version {
   281  	case quicVersion1:
   282  	case 0:
   283  		// Version Negotiation for an unknown connection.
   284  		return
   285  	default:
   286  		// Unknown version.
   287  		e.sendVersionNegotiation(p, m.peerAddr)
   288  		return
   289  	}
   290  	if getPacketType(m.b) != packetTypeInitial {
   291  		// This packet isn't trying to create a new connection.
   292  		// It might be associated with some connection we've lost state for.
   293  		// We are technically permitted to send a stateless reset for
   294  		// a long-header packet, but this isn't generally useful. See:
   295  		// https://www.rfc-editor.org/rfc/rfc9000#section-10.3-16
   296  		return
   297  	}
   298  	if e.listenConfig == nil {
   299  		// We are not configured to accept connections.
   300  		return
   301  	}
   302  	cids := newServerConnIDs{
   303  		srcConnID: p.srcConnID,
   304  		dstConnID: p.dstConnID,
   305  	}
   306  	if e.listenConfig.RequireAddressValidation {
   307  		var ok bool
   308  		cids.retrySrcConnID = p.dstConnID
   309  		cids.originalDstConnID, ok = e.validateInitialAddress(now, p, m.peerAddr)
   310  		if !ok {
   311  			return
   312  		}
   313  	} else {
   314  		cids.originalDstConnID = p.dstConnID
   315  	}
   316  	var err error
   317  	c, err := e.newConn(now, e.listenConfig, serverSide, cids, "", m.peerAddr)
   318  	if err != nil {
   319  		// The accept queue is probably full.
   320  		// We could send a CONNECTION_CLOSE to the peer to reject the connection.
   321  		// Currently, we just drop the datagram.
   322  		// https://www.rfc-editor.org/rfc/rfc9000.html#section-5.2.2-5
   323  		return
   324  	}
   325  	c.sendMsg(m)
   326  	m = nil // don't recycle, sendMsg takes ownership
   327  }
   328  
   329  func (e *Endpoint) maybeSendStatelessReset(b []byte, peerAddr netip.AddrPort) {
   330  	if !e.resetGen.canReset {
   331  		// Config.StatelessResetKey isn't set, so we don't send stateless resets.
   332  		return
   333  	}
   334  	// The smallest possible valid packet a peer can send us is:
   335  	//   1 byte of header
   336  	//   connIDLen bytes of destination connection ID
   337  	//   1 byte of packet number
   338  	//   1 byte of payload
   339  	//   16 bytes AEAD expansion
   340  	if len(b) < 1+connIDLen+1+1+16 {
   341  		return
   342  	}
   343  	// TODO: Rate limit stateless resets.
   344  	cid := b[1:][:connIDLen]
   345  	token := e.resetGen.tokenForConnID(cid)
   346  	// We want to generate a stateless reset that is as short as possible,
   347  	// but long enough to be difficult to distinguish from a 1-RTT packet.
   348  	//
   349  	// The minimal 1-RTT packet is:
   350  	//   1 byte of header
   351  	//   0-20 bytes of destination connection ID
   352  	//   1-4 bytes of packet number
   353  	//   1 byte of payload
   354  	//   16 bytes AEAD expansion
   355  	//
   356  	// Assuming the maximum possible connection ID and packet number size,
   357  	// this gives 1 + 20 + 4 + 1 + 16 = 42 bytes.
   358  	//
   359  	// We also must generate a stateless reset that is shorter than the datagram
   360  	// we are responding to, in order to ensure that reset loops terminate.
   361  	//
   362  	// See: https://www.rfc-editor.org/rfc/rfc9000#section-10.3
   363  	size := min(len(b)-1, 42)
   364  	// Reuse the input buffer for generating the stateless reset.
   365  	b = b[:size]
   366  	rand.Read(b[:len(b)-statelessResetTokenLen])
   367  	b[0] &^= headerFormLong // clear long header bit
   368  	b[0] |= fixedBit        // set fixed bit
   369  	copy(b[len(b)-statelessResetTokenLen:], token[:])
   370  	e.sendDatagram(datagram{
   371  		b:        b,
   372  		peerAddr: peerAddr,
   373  	})
   374  }
   375  
   376  func (e *Endpoint) sendVersionNegotiation(p genericLongPacket, peerAddr netip.AddrPort) {
   377  	m := newDatagram()
   378  	m.b = appendVersionNegotiation(m.b[:0], p.srcConnID, p.dstConnID, quicVersion1)
   379  	m.peerAddr = peerAddr
   380  	e.sendDatagram(*m)
   381  	m.recycle()
   382  }
   383  
   384  func (e *Endpoint) sendConnectionClose(in genericLongPacket, peerAddr netip.AddrPort, code transportError) {
   385  	keys := initialKeys(in.dstConnID, serverSide)
   386  	var w packetWriter
   387  	p := longPacket{
   388  		ptype:     packetTypeInitial,
   389  		version:   quicVersion1,
   390  		num:       0,
   391  		dstConnID: in.srcConnID,
   392  		srcConnID: in.dstConnID,
   393  	}
   394  	const pnumMaxAcked = 0
   395  	w.reset(paddedInitialDatagramSize)
   396  	w.startProtectedLongHeaderPacket(pnumMaxAcked, p)
   397  	w.appendConnectionCloseTransportFrame(code, 0, "")
   398  	w.finishProtectedLongHeaderPacket(pnumMaxAcked, keys.w, p)
   399  	buf := w.datagram()
   400  	if len(buf) == 0 {
   401  		return
   402  	}
   403  	e.sendDatagram(datagram{
   404  		b:        buf,
   405  		peerAddr: peerAddr,
   406  	})
   407  }
   408  
   409  func (e *Endpoint) sendDatagram(dgram datagram) error {
   410  	return e.packetConn.Write(dgram)
   411  }
   412  
   413  // A connsMap is an endpoint's mapping of conn ids and reset tokens to conns.
   414  type connsMap struct {
   415  	byConnID     map[string]*Conn
   416  	byResetToken map[statelessResetToken]*Conn
   417  
   418  	updateMu     sync.Mutex
   419  	updateNeeded atomic.Bool
   420  	updates      []func(*connsMap)
   421  }
   422  
   423  func (m *connsMap) init() {
   424  	m.byConnID = map[string]*Conn{}
   425  	m.byResetToken = map[statelessResetToken]*Conn{}
   426  }
   427  
   428  func (m *connsMap) addConnID(c *Conn, cid []byte) {
   429  	m.byConnID[string(cid)] = c
   430  }
   431  
   432  func (m *connsMap) retireConnID(c *Conn, cid []byte) {
   433  	delete(m.byConnID, string(cid))
   434  }
   435  
   436  func (m *connsMap) addResetToken(c *Conn, token statelessResetToken) {
   437  	m.byResetToken[token] = c
   438  }
   439  
   440  func (m *connsMap) retireResetToken(c *Conn, token statelessResetToken) {
   441  	delete(m.byResetToken, token)
   442  }
   443  
   444  func (m *connsMap) updateConnIDs(f func(*connsMap)) {
   445  	m.updateMu.Lock()
   446  	defer m.updateMu.Unlock()
   447  	m.updates = append(m.updates, f)
   448  	m.updateNeeded.Store(true)
   449  }
   450  
   451  // applyConnIDUpdates is called by the datagram receive loop to update its connection ID map.
   452  func (m *connsMap) applyUpdates() {
   453  	m.updateMu.Lock()
   454  	defer m.updateMu.Unlock()
   455  	for _, f := range m.updates {
   456  		f(m)
   457  	}
   458  	clear(m.updates)
   459  	m.updates = m.updates[:0]
   460  	m.updateNeeded.Store(false)
   461  }