golang.org/x/net@v0.25.1-0.20240516223405-c87a5b62e243/quic/conn.go (about)

     1  // Copyright 2023 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  //go:build go1.21
     6  
     7  package quic
     8  
     9  import (
    10  	"context"
    11  	"crypto/tls"
    12  	"errors"
    13  	"fmt"
    14  	"log/slog"
    15  	"net/netip"
    16  	"time"
    17  )
    18  
// A Conn is a QUIC connection.
//
// Multiple goroutines may invoke methods on a Conn simultaneously.
// Except where otherwise noted, the connection state below is owned by the
// goroutine running (*Conn).loop; other goroutines reach it by sending the
// loop funcs to execute over msgc (see sendMsg and runOnLoop).
type Conn struct {
	side     connSide
	endpoint *Endpoint
	config   *Config
	// testHooks is non-nil only when a test is driving this conn; code in this
	// file checks it for nil (e.g. loop only arms a real timer when it is nil).
	testHooks connTestHooks
	// peerAddr is stored after unmapAddrPort in newConn.
	peerAddr  netip.AddrPort
	localAddr netip.AddrPort

	// msgc carries messages to the loop goroutine: *datagram, timerEvent,
	// wakeEvent, or func(time.Time, *Conn). loop panics on any other type.
	// newConn creates it with a one-element buffer so wake never blocks.
	msgc  chan any
	donec chan struct{} // closed when conn loop exits

	w           packetWriter
	acks        [numberSpaceCount]ackState // indexed by number space
	lifetime    lifetimeState
	idle        idleState
	connIDState connIDState
	loss        lossState
	streams     streamsState
	path        pathState

	// Packet protection keys, CRYPTO streams, and TLS state.
	keysInitial   fixedKeyPair
	keysHandshake fixedKeyPair
	keysAppData   updatingKeyPair
	crypto        [numberSpaceCount]cryptoStream
	tls           *tls.QUICConn

	// retryToken is the token provided by the peer in a Retry packet.
	// receiveTransportParameters treats a non-nil token as "a Retry was received".
	retryToken []byte

	// handshakeConfirmed is set when the handshake is confirmed.
	// For server connections, it tracks sending HANDSHAKE_DONE.
	handshakeConfirmed sentVal

	// peerAckDelayExponent is copied from the peer's transport parameters.
	peerAckDelayExponent int8 // -1 when unknown

	// Tests only: Send a PING in a specific number space.
	testSendPingSpace numberSpace
	testSendPing      sentVal

	// log is handed to loss recovery (e.g. in discardKeys).
	// NOTE(review): not assigned in this file; presumably set during logging setup — confirm.
	log *slog.Logger
}
    64  
// connTestHooks override conn behavior in tests.
//
// A Conn's testHooks field is nil outside of tests, and the code in this file
// checks for nil before calling any of these methods.
type connTestHooks interface {
	// init is called after a conn is created.
	// newConn calls it just before starting the loop goroutine.
	init()

	// nextMessage is called to request the next event from msgc.
	// Used to give tests control of the connection event loop.
	// It returns the time the loop should treat as "now" along with the message.
	nextMessage(msgc chan any, nextTimeout time.Time) (now time.Time, message any)

	// handleTLSEvent is called with each TLS event.
	handleTLSEvent(tls.QUICEvent)

	// newConnID is called to generate a new connection ID.
	// Permits tests to generate consistent connection IDs rather than random ones.
	newConnID(seq int64) ([]byte, error)

	// waitUntil blocks until the until func returns true or the context is done.
	// Used to synchronize asynchronous blocking operations in tests.
	waitUntil(ctx context.Context, until func() bool) error

	// timeNow returns the current time.
	timeNow() time.Time
}
    88  
// newServerConnIDs holds the connection IDs associated with a new server connection.
//
// newConn also receives this for client connections and reads
// originalDstConnID and retrySrcConnID unconditionally (for logging and for
// the transport parameters it sends).
// NOTE(review): presumably these fields are empty for clients — confirm against callers.
type newServerConnIDs struct {
	srcConnID         []byte // source from client's current Initial
	dstConnID         []byte // destination from client's current Initial
	originalDstConnID []byte // destination from client's first Initial
	retrySrcConnID    []byte // source from server's Retry
}
    96  
    97  func newConn(now time.Time, side connSide, cids newServerConnIDs, peerHostname string, peerAddr netip.AddrPort, config *Config, e *Endpoint) (conn *Conn, _ error) {
    98  	c := &Conn{
    99  		side:                 side,
   100  		endpoint:             e,
   101  		config:               config,
   102  		peerAddr:             unmapAddrPort(peerAddr),
   103  		msgc:                 make(chan any, 1),
   104  		donec:                make(chan struct{}),
   105  		peerAckDelayExponent: -1,
   106  	}
   107  	defer func() {
   108  		// If we hit an error in newConn, close donec so tests don't get stuck waiting for it.
   109  		// This is only relevant if we've got a bug, but it makes tracking that bug down
   110  		// much easier.
   111  		if conn == nil {
   112  			close(c.donec)
   113  		}
   114  	}()
   115  
   116  	// A one-element buffer allows us to wake a Conn's event loop as a
   117  	// non-blocking operation.
   118  	c.msgc = make(chan any, 1)
   119  
   120  	if e.testHooks != nil {
   121  		e.testHooks.newConn(c)
   122  	}
   123  
   124  	// initialConnID is the connection ID used to generate Initial packet protection keys.
   125  	var initialConnID []byte
   126  	if c.side == clientSide {
   127  		if err := c.connIDState.initClient(c); err != nil {
   128  			return nil, err
   129  		}
   130  		initialConnID, _ = c.connIDState.dstConnID()
   131  	} else {
   132  		initialConnID = cids.originalDstConnID
   133  		if cids.retrySrcConnID != nil {
   134  			initialConnID = cids.retrySrcConnID
   135  		}
   136  		if err := c.connIDState.initServer(c, cids); err != nil {
   137  			return nil, err
   138  		}
   139  	}
   140  
   141  	// TODO: PMTU discovery.
   142  	c.logConnectionStarted(cids.originalDstConnID, peerAddr)
   143  	c.keysAppData.init()
   144  	c.loss.init(c.side, smallestMaxDatagramSize, now)
   145  	c.streamsInit()
   146  	c.lifetimeInit()
   147  	c.restartIdleTimer(now)
   148  
   149  	if err := c.startTLS(now, initialConnID, peerHostname, transportParameters{
   150  		initialSrcConnID:               c.connIDState.srcConnID(),
   151  		originalDstConnID:              cids.originalDstConnID,
   152  		retrySrcConnID:                 cids.retrySrcConnID,
   153  		ackDelayExponent:               ackDelayExponent,
   154  		maxUDPPayloadSize:              maxUDPPayloadSize,
   155  		maxAckDelay:                    maxAckDelay,
   156  		disableActiveMigration:         true,
   157  		initialMaxData:                 config.maxConnReadBufferSize(),
   158  		initialMaxStreamDataBidiLocal:  config.maxStreamReadBufferSize(),
   159  		initialMaxStreamDataBidiRemote: config.maxStreamReadBufferSize(),
   160  		initialMaxStreamDataUni:        config.maxStreamReadBufferSize(),
   161  		initialMaxStreamsBidi:          c.streams.remoteLimit[bidiStream].max,
   162  		initialMaxStreamsUni:           c.streams.remoteLimit[uniStream].max,
   163  		activeConnIDLimit:              activeConnIDLimit,
   164  	}); err != nil {
   165  		return nil, err
   166  	}
   167  
   168  	if c.testHooks != nil {
   169  		c.testHooks.init()
   170  	}
   171  	go c.loop(now)
   172  	return c, nil
   173  }
   174  
   175  func (c *Conn) String() string {
   176  	return fmt.Sprintf("quic.Conn(%v,->%v)", c.side, c.peerAddr)
   177  }
   178  
   179  // confirmHandshake is called when the handshake is confirmed.
   180  // https://www.rfc-editor.org/rfc/rfc9001#section-4.1.2
   181  func (c *Conn) confirmHandshake(now time.Time) {
   182  	// If handshakeConfirmed is unset, the handshake is not confirmed.
   183  	// If it is unsent, the handshake is confirmed and we need to send a HANDSHAKE_DONE.
   184  	// If it is sent, we have sent a HANDSHAKE_DONE.
   185  	// If it is received, the handshake is confirmed and we do not need to send anything.
   186  	if c.handshakeConfirmed.isSet() {
   187  		return // already confirmed
   188  	}
   189  	if c.side == serverSide {
   190  		// When the server confirms the handshake, it sends a HANDSHAKE_DONE.
   191  		c.handshakeConfirmed.setUnsent()
   192  		c.endpoint.serverConnEstablished(c)
   193  	} else {
   194  		// The client never sends a HANDSHAKE_DONE, so we set handshakeConfirmed
   195  		// to the received state, indicating that the handshake is confirmed and we
   196  		// don't need to send anything.
   197  		c.handshakeConfirmed.setReceived()
   198  	}
   199  	c.restartIdleTimer(now)
   200  	c.loss.confirmHandshake()
   201  	// "An endpoint MUST discard its Handshake keys when the TLS handshake is confirmed"
   202  	// https://www.rfc-editor.org/rfc/rfc9001#section-4.9.2-1
   203  	c.discardKeys(now, handshakeSpace)
   204  }
   205  
   206  // discardKeys discards unused packet protection keys.
   207  // https://www.rfc-editor.org/rfc/rfc9001#section-4.9
   208  func (c *Conn) discardKeys(now time.Time, space numberSpace) {
   209  	switch space {
   210  	case initialSpace:
   211  		c.keysInitial.discard()
   212  	case handshakeSpace:
   213  		c.keysHandshake.discard()
   214  	}
   215  	c.loss.discardKeys(now, c.log, space)
   216  }
   217  
   218  // receiveTransportParameters applies transport parameters sent by the peer.
   219  func (c *Conn) receiveTransportParameters(p transportParameters) error {
   220  	isRetry := c.retryToken != nil
   221  	if err := c.connIDState.validateTransportParameters(c, isRetry, p); err != nil {
   222  		return err
   223  	}
   224  	c.streams.outflow.setMaxData(p.initialMaxData)
   225  	c.streams.localLimit[bidiStream].setMax(p.initialMaxStreamsBidi)
   226  	c.streams.localLimit[uniStream].setMax(p.initialMaxStreamsUni)
   227  	c.streams.peerInitialMaxStreamDataBidiLocal = p.initialMaxStreamDataBidiLocal
   228  	c.streams.peerInitialMaxStreamDataRemote[bidiStream] = p.initialMaxStreamDataBidiRemote
   229  	c.streams.peerInitialMaxStreamDataRemote[uniStream] = p.initialMaxStreamDataUni
   230  	c.receivePeerMaxIdleTimeout(p.maxIdleTimeout)
   231  	c.peerAckDelayExponent = p.ackDelayExponent
   232  	c.loss.setMaxAckDelay(p.maxAckDelay)
   233  	if err := c.connIDState.setPeerActiveConnIDLimit(c, p.activeConnIDLimit); err != nil {
   234  		return err
   235  	}
   236  	if p.preferredAddrConnID != nil {
   237  		var (
   238  			seq           int64 = 1 // sequence number of this conn id is 1
   239  			retirePriorTo int64 = 0 // retire nothing
   240  			resetToken    [16]byte
   241  		)
   242  		copy(resetToken[:], p.preferredAddrResetToken)
   243  		if err := c.connIDState.handleNewConnID(c, seq, retirePriorTo, p.preferredAddrConnID, resetToken); err != nil {
   244  			return err
   245  		}
   246  	}
   247  	// TODO: stateless_reset_token
   248  	// TODO: max_udp_payload_size
   249  	// TODO: disable_active_migration
   250  	// TODO: preferred_address
   251  	return nil
   252  }
   253  
// Messages handled by (*Conn).loop in addition to datagrams and funcs.
type (
	// timerEvent signals that a connection deadline has passed. It is sent by
	// the loop's time.AfterFunc timer, or synthesized by the loop itself when
	// the next deadline is already in the past.
	timerEvent struct{}
	// wakeEvent is sent by wake to make the loop try sending frames again.
	wakeEvent struct{}
)

// errIdleTimeout is the error passed to abortImmediately when the idle
// timer expires.
var errIdleTimeout = errors.New("idle timeout")
   260  
// loop is the connection main loop.
//
// Except where otherwise noted, all connection state is owned by the loop goroutine.
//
// The loop processes messages from c.msgc and timer events.
// Other goroutines may examine or modify conn state by sending the loop funcs to execute.
//
// Each iteration does the following:
//   - tries to send;
//   - computes the earliest pending deadline (send, idle, and either loss/ack
//     while alive or the end of draining otherwise);
//   - waits for a message or that deadline;
//   - dispatches the message.
//
// The loop ends when lifetime.state reaches connStateDone, when the idle
// timer expires, or when the draining period completes. cleanup always runs
// on exit. now is the time used for the first iteration.
func (c *Conn) loop(now time.Time) {
	defer c.cleanup()

	// The connection timer sends a message to the connection loop on expiry.
	// We need to give it an expiry when creating it, so set the initial timeout to
	// an arbitrary large value. The timer will be reset before this expires (and it
	// isn't a problem if it does anyway). Skip creating the timer in tests which
	// take control of the connection message loop.
	var timer *time.Timer
	// lastTimeout is the deadline the timer was last armed for, so it is only
	// reset when the deadline changes.
	var lastTimeout time.Time
	hooks := c.testHooks
	if hooks == nil {
		timer = time.AfterFunc(1*time.Hour, func() {
			c.sendMsg(timerEvent{})
		})
		defer timer.Stop()
	}

	for c.lifetime.state != connStateDone {
		sendTimeout := c.maybeSend(now) // try sending

		// Note that we only need to consider the ack timer for the App Data space,
		// since the Initial and Handshake spaces always ack immediately.
		nextTimeout := sendTimeout
		nextTimeout = firstTime(nextTimeout, c.idle.nextTimeout)
		if c.isAlive() {
			nextTimeout = firstTime(nextTimeout, c.loss.timer)
			nextTimeout = firstTime(nextTimeout, c.acks[appDataSpace].nextAck)
		} else {
			nextTimeout = firstTime(nextTimeout, c.lifetime.drainEndTime)
		}

		var m any
		if hooks != nil {
			// Tests only: Wait for the test to tell us to continue.
			now, m = hooks.nextMessage(c.msgc, nextTimeout)
		} else if !nextTimeout.IsZero() && nextTimeout.Before(now) {
			// A connection timer has expired.
			// The deadline is already behind the time of the previous event, so
			// handle it now instead of arming the timer and waiting.
			now = time.Now()
			m = timerEvent{}
		} else {
			// Reschedule the connection timer if necessary
			// and wait for the next event.
			// A zero nextTimeout leaves the timer as it is.
			if !nextTimeout.Equal(lastTimeout) && !nextTimeout.IsZero() {
				// Resetting a timer created with time.AfterFunc guarantees
				// that the timer will run again. We might generate a spurious
				// timer event under some circumstances, but that's okay.
				timer.Reset(nextTimeout.Sub(now))
				lastTimeout = nextTimeout
			}
			m = <-c.msgc
			now = time.Now()
		}
		switch m := m.(type) {
		case *datagram:
			if !c.handleDatagram(now, m) {
				if c.logEnabled(QLogLevelPacket) {
					c.logPacketDropped(m)
				}
			}
			// recycle is called once the datagram has been handled,
			// whether or not it was dropped.
			m.recycle()
		case timerEvent:
			// A connection timer has expired.
			if c.idleAdvance(now) {
				// The connection idle timer has expired.
				c.abortImmediately(now, errIdleTimeout)
				return
			}
			c.loss.advance(now, c.handleAckOrLoss)
			if c.lifetimeAdvance(now) {
				// The connection has completed the draining period,
				// and may be shut down.
				return
			}
		case wakeEvent:
			// We're being woken up to try sending some frames.
			// Nothing to do here: the next iteration calls maybeSend.
		case func(time.Time, *Conn):
			// Send a func to msgc to run it on the main Conn goroutine
			m(now, c)
		default:
			panic(fmt.Sprintf("quic: unrecognized conn message %T", m))
		}
	}
}
   351  
// cleanup tears the connection down after the loop exits; loop defers it.
//
// It logs the close, reports the connection as drained to the endpoint, and
// closes the TLS connection. Then, last of all, it closes donec. Because donec
// is closed last, a goroutine woken by donec observes the steps above as
// already done.
func (c *Conn) cleanup() {
	c.logConnectionClosed()
	c.endpoint.connDrained(c)
	c.tls.Close()
	close(c.donec)
}
   358  
   359  // sendMsg sends a message to the conn's loop.
   360  // It does not wait for the message to be processed.
   361  // The conn may close before processing the message, in which case it is lost.
   362  func (c *Conn) sendMsg(m any) {
   363  	select {
   364  	case c.msgc <- m:
   365  	case <-c.donec:
   366  	}
   367  }
   368  
   369  // wake wakes up the conn's loop.
   370  func (c *Conn) wake() {
   371  	select {
   372  	case c.msgc <- wakeEvent{}:
   373  	default:
   374  	}
   375  }
   376  
   377  // runOnLoop executes a function within the conn's loop goroutine.
   378  func (c *Conn) runOnLoop(ctx context.Context, f func(now time.Time, c *Conn)) error {
   379  	donec := make(chan struct{})
   380  	msg := func(now time.Time, c *Conn) {
   381  		defer close(donec)
   382  		f(now, c)
   383  	}
   384  	if c.testHooks != nil {
   385  		// In tests, we can't rely on being able to send a message immediately:
   386  		// c.msgc might be full, and testConnHooks.nextMessage might be waiting
   387  		// for us to block before it processes the next message.
   388  		// To avoid a deadlock, we send the message in waitUntil.
   389  		// If msgc is empty, the message is buffered.
   390  		// If msgc is full, we block and let nextMessage process the queue.
   391  		msgc := c.msgc
   392  		c.testHooks.waitUntil(ctx, func() bool {
   393  			for {
   394  				select {
   395  				case msgc <- msg:
   396  					msgc = nil // send msg only once
   397  				case <-donec:
   398  					return true
   399  				case <-c.donec:
   400  					return true
   401  				default:
   402  					return false
   403  				}
   404  			}
   405  		})
   406  	} else {
   407  		c.sendMsg(msg)
   408  	}
   409  	select {
   410  	case <-donec:
   411  	case <-c.donec:
   412  		return errors.New("quic: connection closed")
   413  	}
   414  	return nil
   415  }
   416  
   417  func (c *Conn) waitOnDone(ctx context.Context, ch <-chan struct{}) error {
   418  	if c.testHooks != nil {
   419  		return c.testHooks.waitUntil(ctx, func() bool {
   420  			select {
   421  			case <-ch:
   422  				return true
   423  			default:
   424  			}
   425  			return false
   426  		})
   427  	}
   428  	// Check the channel before the context.
   429  	// We always prefer to return results when available,
   430  	// even when provided with an already-canceled context.
   431  	select {
   432  	case <-ch:
   433  		return nil
   434  	default:
   435  	}
   436  	select {
   437  	case <-ch:
   438  	case <-ctx.Done():
   439  		return ctx.Err()
   440  	}
   441  	return nil
   442  }
   443  
   444  // firstTime returns the earliest non-zero time, or zero if both times are zero.
   445  func firstTime(a, b time.Time) time.Time {
   446  	switch {
   447  	case a.IsZero():
   448  		return b
   449  	case b.IsZero():
   450  		return a
   451  	case a.Before(b):
   452  		return a
   453  	default:
   454  		return b
   455  	}
   456  }