github.com/danielpfeifer02/quic-go-prio-packs@v0.41.0-28/server.go (about)

     1  package quic
     2  
     3  import (
     4  	"context"
     5  	"crypto/tls"
     6  	"errors"
     7  	"fmt"
     8  	"net"
     9  	"sync"
    10  	"sync/atomic"
    11  	"time"
    12  
    13  	"github.com/danielpfeifer02/quic-go-prio-packs/internal/handshake"
    14  	"github.com/danielpfeifer02/quic-go-prio-packs/internal/protocol"
    15  	"github.com/danielpfeifer02/quic-go-prio-packs/internal/qerr"
    16  	"github.com/danielpfeifer02/quic-go-prio-packs/internal/utils"
    17  	"github.com/danielpfeifer02/quic-go-prio-packs/internal/wire"
    18  	"github.com/danielpfeifer02/quic-go-prio-packs/logging"
    19  )
    20  
// ErrServerClosed is returned by the Listener or EarlyListener's Accept method after a call to Close.
var ErrServerClosed = errors.New("quic: server closed")

// packetHandler is anything incoming packets can be routed to.
// Connections created by the server are registered as packetHandlers with the
// packetHandlerManager (see AddWithConnID in handleInitialImpl).
type packetHandler interface {
	// handlePacket passes a received packet to the handler.
	handlePacket(receivedPacket)
	// destroy presumably tears the handler down with the given error (not used in this part of the file).
	destroy(error)
	// closeWithTransportError closes the handler with a transport error, e.g. CONNECTION_REFUSED.
	closeWithTransportError(qerr.TransportErrorCode)
}

// packetHandlerManager looks up packet handlers by connection ID or stateless
// reset token. The server uses Get to find connections that were created after
// a packet was queued, and AddWithConnID to register a new connection under the
// client-chosen destination connection ID and the server-chosen one.
// It also implements connRunner, which is passed to every new connection.
type packetHandlerManager interface {
	Get(protocol.ConnectionID) (packetHandler, bool)
	GetByResetToken(protocol.StatelessResetToken) (packetHandler, bool)
	// AddWithConnID returns false if destConnID is already in use.
	AddWithConnID(destConnID, newConnID protocol.ConnectionID, h packetHandler) bool
	Close(error)
	connRunner
}

// quicConn is the server's internal view of a connection: the public
// EarlyConnection API plus the hooks the server needs to drive it.
type quicConn interface {
	EarlyConnection
	// earlyConnReady is awaited by handleNewConn when early connections are accepted.
	earlyConnReady() <-chan struct{}
	handlePacket(receivedPacket)
	// run is started in its own goroutine by handleInitialImpl once the connection is registered.
	run() error
	destroy(error)
	closeWithTransportError(TransportErrorCode)
}

// zeroRTTQueue buffers 0-RTT packets (in arrival order) that arrive before the
// Initial packet that creates their connection. The packets are handed to the
// connection once it is created, or released by cleanupZeroRTTQueues after
// expiration has passed.
type zeroRTTQueue struct {
	packets    []receivedPacket
	expiration time.Time
}

// rejectedPacket is a received packet, together with its parsed header, for
// which the server does not create a connection. It is queued for
// runSendQueue, which sends the corresponding response
// (INVALID_TOKEN, CONNECTION_REFUSED or Retry).
type rejectedPacket struct {
	receivedPacket
	hdr *wire.Header
}
    57  
// baseServer is the server implementation shared by Listener and EarlyListener.
// Packets are handed to it via handlePacket and processed sequentially by the
// run goroutine; responses that don't belong to a connection (Version
// Negotiation, INVALID_TOKEN, CONNECTION_REFUSED, Retry) are sent by the
// runSendQueue goroutine.
type baseServer struct {
	// With disableVersionNegotiation, packets with an unsupported version are
	// dropped instead of being answered with a Version Negotiation packet.
	// With acceptEarlyConns, connections are queued for Accept once
	// earlyConnReady fires instead of after the handshake completes.
	disableVersionNegotiation bool
	acceptEarlyConns          bool

	tlsConf *tls.Config
	config  *Config

	// conn is used to send server-generated packets (e.g. Retry) and is wrapped
	// in a sendConn for every new connection. Its local address is reported by Addr.
	conn rawConn

	// tokenGenerator encodes and decodes address validation tokens.
	// maxTokenAge bounds the age of non-Retry tokens (see validateToken).
	tokenGenerator *handshake.TokenGenerator
	maxTokenAge    time.Duration

	// onClose is called by Close (but not by close(..., false)).
	connIDGenerator ConnectionIDGenerator
	connHandler     packetHandlerManager
	onClose         func()

	// receivedPackets is filled by handlePacket and drained by run.
	receivedPackets chan receivedPacket

	// nextZeroRTTCleanup is the earliest expiration among zeroRTTQueues,
	// or the zero time if there is nothing to clean up.
	nextZeroRTTCleanup time.Time
	zeroRTTQueues      map[protocol.ConnectionID]*zeroRTTQueue // only initialized if acceptEarlyConns == true

	// set as a member, so they can be set in the tests
	newConn func(
		sendConn,
		connRunner,
		protocol.ConnectionID, /* original dest connection ID */
		*protocol.ConnectionID, /* retry src connection ID */
		protocol.ConnectionID, /* client dest connection ID */
		protocol.ConnectionID, /* destination connection ID */
		protocol.ConnectionID, /* source connection ID */
		ConnectionIDGenerator,
		protocol.StatelessResetToken,
		*Config,
		*tls.Config,
		*handshake.TokenGenerator,
		bool, /* client address validated by an address validation token */
		*logging.ConnectionTracer,
		uint64, /* tracing ID, see nextConnTracingID */
		utils.Logger,
		protocol.Version,
	) quicConn

	// closeMx guards closeErr and the closing of errorChan.
	closeMx   sync.Mutex
	errorChan chan struct{} // is closed when the server is closed
	closeErr  error
	running   chan struct{} // closed as soon as run() returns

	// Packets waiting for a response to be sent by runSendQueue. When a queue
	// is full, the producer drops the packet instead.
	versionNegotiationQueue chan receivedPacket
	invalidTokenQueue       chan rejectedPacket
	connectionRefusedQueue  chan rejectedPacket
	retryQueue              chan rejectedPacket

	// Handshake limits and the number of handshakes currently in flight,
	// split by whether the client's address was validated by a token. They
	// decide whether a new Initial is accepted, answered with a Retry, or refused.
	maxNumHandshakesUnvalidated int
	maxNumHandshakesTotal       int
	numHandshakesUnvalidated    atomic.Int64
	numHandshakesValidated      atomic.Int64

	// connQueue holds connections that are ready to be returned by Accept.
	connQueue chan quicConn

	tracer *logging.Tracer

	logger utils.Logger
}
   122  
   123  // A Listener listens for incoming QUIC connections.
   124  // It returns connections once the handshake has completed.
   125  type Listener struct {
   126  	baseServer *baseServer
   127  }
   128  
   129  // Accept returns new connections. It should be called in a loop.
   130  func (l *Listener) Accept(ctx context.Context) (Connection, error) {
   131  	return l.baseServer.Accept(ctx)
   132  }
   133  
   134  // Close closes the listener.
   135  // Accept will return ErrServerClosed as soon as all connections in the accept queue have been accepted.
   136  // QUIC handshakes that are still in flight will be rejected with a CONNECTION_REFUSED error.
   137  // The effect of closing the listener depends on how it was created:
   138  // * if it was created using Transport.Listen, already established connections will be unaffected
   139  // * if it was created using the Listen convenience method, all established connection will be closed immediately
   140  func (l *Listener) Close() error {
   141  	return l.baseServer.Close()
   142  }
   143  
   144  // Addr returns the local network address that the server is listening on.
   145  func (l *Listener) Addr() net.Addr {
   146  	return l.baseServer.Addr()
   147  }
   148  
   149  // An EarlyListener listens for incoming QUIC connections, and returns them before the handshake completes.
   150  // For connections that don't use 0-RTT, this allows the server to send 0.5-RTT data.
   151  // This data is encrypted with forward-secure keys, however, the client's identity has not yet been verified.
   152  // For connection using 0-RTT, this allows the server to accept and respond to streams that the client opened in the
   153  // 0-RTT data it sent. Note that at this point during the handshake, the live-ness of the
   154  // client has not yet been confirmed, and the 0-RTT data could have been replayed by an attacker.
   155  type EarlyListener struct {
   156  	baseServer *baseServer
   157  }
   158  
   159  // Accept returns a new connections. It should be called in a loop.
   160  func (l *EarlyListener) Accept(ctx context.Context) (EarlyConnection, error) {
   161  	return l.baseServer.accept(ctx)
   162  }
   163  
   164  // Close the server. All active connections will be closed.
   165  func (l *EarlyListener) Close() error {
   166  	return l.baseServer.Close()
   167  }
   168  
   169  // Addr returns the local network addr that the server is listening on.
   170  func (l *EarlyListener) Addr() net.Addr {
   171  	return l.baseServer.Addr()
   172  }
   173  
   174  // ListenAddr creates a QUIC server listening on a given address.
   175  // See Listen for more details.
   176  func ListenAddr(addr string, tlsConf *tls.Config, config *Config) (*Listener, error) {
   177  	conn, err := listenUDP(addr)
   178  	if err != nil {
   179  		return nil, err
   180  	}
   181  	return (&Transport{
   182  		Conn:        conn,
   183  		createdConn: true,
   184  		isSingleUse: true,
   185  	}).Listen(tlsConf, config)
   186  }
   187  
   188  // ListenAddrEarly works like ListenAddr, but it returns connections before the handshake completes.
   189  func ListenAddrEarly(addr string, tlsConf *tls.Config, config *Config) (*EarlyListener, error) {
   190  	conn, err := listenUDP(addr)
   191  	if err != nil {
   192  		return nil, err
   193  	}
   194  	return (&Transport{
   195  		Conn:        conn,
   196  		createdConn: true,
   197  		isSingleUse: true,
   198  	}).ListenEarly(tlsConf, config)
   199  }
   200  
   201  func listenUDP(addr string) (*net.UDPConn, error) {
   202  	udpAddr, err := net.ResolveUDPAddr("udp", addr)
   203  	if err != nil {
   204  		return nil, err
   205  	}
   206  	return net.ListenUDP("udp", udpAddr)
   207  }
   208  
   209  // Listen listens for QUIC connections on a given net.PacketConn.
   210  // If the PacketConn satisfies the OOBCapablePacketConn interface (as a net.UDPConn does),
   211  // ECN and packet info support will be enabled. In this case, ReadMsgUDP and WriteMsgUDP
   212  // will be used instead of ReadFrom and WriteTo to read/write packets.
   213  // A single net.PacketConn can only be used for a single call to Listen.
   214  //
   215  // The tls.Config must not be nil and must contain a certificate configuration.
   216  // Furthermore, it must define an application control (using NextProtos).
   217  // The quic.Config may be nil, in that case the default values will be used.
   218  //
   219  // This is a convenience function. More advanced use cases should instantiate a Transport,
   220  // which offers configuration options for a more fine-grained control of the connection establishment,
   221  // including reusing the underlying UDP socket for outgoing QUIC connections.
   222  // When closing a listener created with Listen, all established QUIC connections will be closed immediately.
   223  func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (*Listener, error) {
   224  	tr := &Transport{Conn: conn, isSingleUse: true}
   225  	return tr.Listen(tlsConf, config)
   226  }
   227  
   228  // ListenEarly works like Listen, but it returns connections before the handshake completes.
   229  func ListenEarly(conn net.PacketConn, tlsConf *tls.Config, config *Config) (*EarlyListener, error) {
   230  	tr := &Transport{Conn: conn, isSingleUse: true}
   231  	return tr.ListenEarly(tlsConf, config)
   232  }
   233  
   234  func newServer(
   235  	conn rawConn,
   236  	connHandler packetHandlerManager,
   237  	connIDGenerator ConnectionIDGenerator,
   238  	tlsConf *tls.Config,
   239  	config *Config,
   240  	tracer *logging.Tracer,
   241  	onClose func(),
   242  	tokenGeneratorKey TokenGeneratorKey,
   243  	maxTokenAge time.Duration,
   244  	maxNumHandshakesUnvalidated, maxNumHandshakesTotal int,
   245  	disableVersionNegotiation bool,
   246  	acceptEarly bool,
   247  ) *baseServer {
   248  	s := &baseServer{
   249  		conn:                        conn,
   250  		tlsConf:                     tlsConf,
   251  		config:                      config,
   252  		tokenGenerator:              handshake.NewTokenGenerator(tokenGeneratorKey),
   253  		maxTokenAge:                 maxTokenAge,
   254  		maxNumHandshakesUnvalidated: maxNumHandshakesUnvalidated,
   255  		maxNumHandshakesTotal:       maxNumHandshakesTotal,
   256  		connIDGenerator:             connIDGenerator,
   257  		connHandler:                 connHandler,
   258  		connQueue:                   make(chan quicConn, protocol.MaxAcceptQueueSize),
   259  		errorChan:                   make(chan struct{}),
   260  		running:                     make(chan struct{}),
   261  		receivedPackets:             make(chan receivedPacket, protocol.MaxServerUnprocessedPackets),
   262  		versionNegotiationQueue:     make(chan receivedPacket, 4),
   263  		invalidTokenQueue:           make(chan rejectedPacket, 4),
   264  		connectionRefusedQueue:      make(chan rejectedPacket, 4),
   265  		retryQueue:                  make(chan rejectedPacket, 8),
   266  		newConn:                     newConnection,
   267  		tracer:                      tracer,
   268  		logger:                      utils.DefaultLogger.WithPrefix("server"),
   269  		acceptEarlyConns:            acceptEarly,
   270  		disableVersionNegotiation:   disableVersionNegotiation,
   271  		onClose:                     onClose,
   272  	}
   273  	if acceptEarly {
   274  		s.zeroRTTQueues = map[protocol.ConnectionID]*zeroRTTQueue{}
   275  	}
   276  	go s.run()
   277  	go s.runSendQueue()
   278  	s.logger.Debugf("Listening for %s connections on %s", conn.LocalAddr().Network(), conn.LocalAddr().String())
   279  	return s
   280  }
   281  
   282  func (s *baseServer) run() {
   283  	defer close(s.running)
   284  	for {
   285  		select {
   286  		case <-s.errorChan:
   287  			return
   288  		default:
   289  		}
   290  		select {
   291  		case <-s.errorChan:
   292  			return
   293  		case p := <-s.receivedPackets:
   294  			if bufferStillInUse := s.handlePacketImpl(p); !bufferStillInUse {
   295  				p.buffer.Release()
   296  			}
   297  		}
   298  	}
   299  }
   300  
   301  func (s *baseServer) runSendQueue() {
   302  	for {
   303  		select {
   304  		case <-s.running:
   305  			return
   306  		case p := <-s.versionNegotiationQueue:
   307  			s.maybeSendVersionNegotiationPacket(p)
   308  		case p := <-s.invalidTokenQueue:
   309  			s.maybeSendInvalidToken(p)
   310  		case p := <-s.connectionRefusedQueue:
   311  			s.sendConnectionRefused(p)
   312  		case p := <-s.retryQueue:
   313  			s.sendRetry(p)
   314  		}
   315  	}
   316  }
   317  
   318  // Accept returns connections that already completed the handshake.
   319  // It is only valid if acceptEarlyConns is false.
   320  func (s *baseServer) Accept(ctx context.Context) (Connection, error) {
   321  	return s.accept(ctx)
   322  }
   323  
   324  func (s *baseServer) accept(ctx context.Context) (quicConn, error) {
   325  	select {
   326  	case <-ctx.Done():
   327  		return nil, ctx.Err()
   328  	case conn := <-s.connQueue:
   329  		return conn, nil
   330  	case <-s.errorChan:
   331  		return nil, s.closeErr
   332  	}
   333  }
   334  
   335  func (s *baseServer) Close() error {
   336  	s.close(ErrServerClosed, true)
   337  	return nil
   338  }
   339  
   340  func (s *baseServer) close(e error, notifyOnClose bool) {
   341  	s.closeMx.Lock()
   342  	if s.closeErr != nil {
   343  		s.closeMx.Unlock()
   344  		return
   345  	}
   346  	s.closeErr = e
   347  	close(s.errorChan)
   348  	<-s.running
   349  	s.closeMx.Unlock()
   350  
   351  	if notifyOnClose {
   352  		s.onClose()
   353  	}
   354  }
   355  
   356  // Addr returns the server's network address
   357  func (s *baseServer) Addr() net.Addr {
   358  	return s.conn.LocalAddr()
   359  }
   360  
   361  func (s *baseServer) handlePacket(p receivedPacket) {
   362  	select {
   363  	case s.receivedPackets <- p:
   364  	default:
   365  		s.logger.Debugf("Dropping packet from %s (%d bytes). Server receive queue full.", p.remoteAddr, p.Size())
   366  		if s.tracer != nil && s.tracer.DroppedPacket != nil {
   367  			s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropDOSPrevention)
   368  		}
   369  	}
   370  }
   371  
   372  func (s *baseServer) handlePacketImpl(p receivedPacket) bool /* is the buffer still in use? */ {
   373  	if !s.nextZeroRTTCleanup.IsZero() && p.rcvTime.After(s.nextZeroRTTCleanup) {
   374  		defer s.cleanupZeroRTTQueues(p.rcvTime)
   375  	}
   376  
   377  	if wire.IsVersionNegotiationPacket(p.data) {
   378  		s.logger.Debugf("Dropping Version Negotiation packet.")
   379  		if s.tracer != nil && s.tracer.DroppedPacket != nil {
   380  			s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeVersionNegotiation, p.Size(), logging.PacketDropUnexpectedPacket)
   381  		}
   382  		return false
   383  	}
   384  	// Short header packets should never end up here in the first place
   385  	if !wire.IsLongHeaderPacket(p.data[0]) {
   386  		panic(fmt.Sprintf("misrouted packet: %#v", p.data))
   387  	}
   388  	v, err := wire.ParseVersion(p.data)
   389  	// drop the packet if we failed to parse the protocol version
   390  	if err != nil {
   391  		s.logger.Debugf("Dropping a packet with an unknown version")
   392  		if s.tracer != nil && s.tracer.DroppedPacket != nil {
   393  			s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket)
   394  		}
   395  		return false
   396  	}
   397  	// send a Version Negotiation Packet if the client is speaking a different protocol version
   398  	if !protocol.IsSupportedVersion(s.config.Versions, v) {
   399  		if s.disableVersionNegotiation {
   400  			return false
   401  		}
   402  
   403  		if p.Size() < protocol.MinUnknownVersionPacketSize {
   404  			s.logger.Debugf("Dropping a packet with an unsupported version number %d that is too small (%d bytes)", v, p.Size())
   405  			if s.tracer != nil && s.tracer.DroppedPacket != nil {
   406  				s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket)
   407  			}
   408  			return false
   409  		}
   410  		return s.enqueueVersionNegotiationPacket(p)
   411  	}
   412  
   413  	if wire.Is0RTTPacket(p.data) {
   414  		if !s.acceptEarlyConns {
   415  			if s.tracer != nil && s.tracer.DroppedPacket != nil {
   416  				s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropUnexpectedPacket)
   417  			}
   418  			return false
   419  		}
   420  		return s.handle0RTTPacket(p)
   421  	}
   422  
   423  	// If we're creating a new connection, the packet will be passed to the connection.
   424  	// The header will then be parsed again.
   425  	hdr, _, _, err := wire.ParsePacket(p.data)
   426  	if err != nil {
   427  		if s.tracer != nil && s.tracer.DroppedPacket != nil {
   428  			s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError)
   429  		}
   430  		s.logger.Debugf("Error parsing packet: %s", err)
   431  		return false
   432  	}
   433  	if hdr.Type == protocol.PacketTypeInitial && p.Size() < protocol.MinInitialPacketSize {
   434  		s.logger.Debugf("Dropping a packet that is too small to be a valid Initial (%d bytes)", p.Size())
   435  		if s.tracer != nil && s.tracer.DroppedPacket != nil {
   436  			s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket)
   437  		}
   438  		return false
   439  	}
   440  
   441  	if hdr.Type != protocol.PacketTypeInitial {
   442  		// Drop long header packets.
   443  		// There's little point in sending a Stateless Reset, since the client
   444  		// might not have received the token yet.
   445  		s.logger.Debugf("Dropping long header packet of type %s (%d bytes)", hdr.Type, len(p.data))
   446  		if s.tracer != nil && s.tracer.DroppedPacket != nil {
   447  			s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeFromHeader(hdr), p.Size(), logging.PacketDropUnexpectedPacket)
   448  		}
   449  		return false
   450  	}
   451  
   452  	s.logger.Debugf("<- Received Initial packet.")
   453  
   454  	if err := s.handleInitialImpl(p, hdr); err != nil {
   455  		s.logger.Errorf("Error occurred handling initial packet: %s", err)
   456  	}
   457  	// Don't put the packet buffer back.
   458  	// handleInitialImpl deals with the buffer.
   459  	return true
   460  }
   461  
   462  func (s *baseServer) handle0RTTPacket(p receivedPacket) bool {
   463  	connID, err := wire.ParseConnectionID(p.data, 0)
   464  	if err != nil {
   465  		if s.tracer != nil && s.tracer.DroppedPacket != nil {
   466  			s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropHeaderParseError)
   467  		}
   468  		return false
   469  	}
   470  
   471  	// check again if we might have a connection now
   472  	if handler, ok := s.connHandler.Get(connID); ok {
   473  		handler.handlePacket(p)
   474  		return true
   475  	}
   476  
   477  	if q, ok := s.zeroRTTQueues[connID]; ok {
   478  		if len(q.packets) >= protocol.Max0RTTQueueLen {
   479  			if s.tracer != nil && s.tracer.DroppedPacket != nil {
   480  				s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention)
   481  			}
   482  			return false
   483  		}
   484  		q.packets = append(q.packets, p)
   485  		return true
   486  	}
   487  
   488  	if len(s.zeroRTTQueues) >= protocol.Max0RTTQueues {
   489  		if s.tracer != nil && s.tracer.DroppedPacket != nil {
   490  			s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention)
   491  		}
   492  		return false
   493  	}
   494  	queue := &zeroRTTQueue{packets: make([]receivedPacket, 1, 8)}
   495  	queue.packets[0] = p
   496  	expiration := p.rcvTime.Add(protocol.Max0RTTQueueingDuration)
   497  	queue.expiration = expiration
   498  	if s.nextZeroRTTCleanup.IsZero() || s.nextZeroRTTCleanup.After(expiration) {
   499  		s.nextZeroRTTCleanup = expiration
   500  	}
   501  	s.zeroRTTQueues[connID] = queue
   502  	return true
   503  }
   504  
   505  func (s *baseServer) cleanupZeroRTTQueues(now time.Time) {
   506  	// Iterate over all queues to find those that are expired.
   507  	// This is ok since we're placing a pretty low limit on the number of queues.
   508  	var nextCleanup time.Time
   509  	for connID, q := range s.zeroRTTQueues {
   510  		if q.expiration.After(now) {
   511  			if nextCleanup.IsZero() || nextCleanup.After(q.expiration) {
   512  				nextCleanup = q.expiration
   513  			}
   514  			continue
   515  		}
   516  		for _, p := range q.packets {
   517  			if s.tracer != nil && s.tracer.DroppedPacket != nil {
   518  				s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention)
   519  			}
   520  			p.buffer.Release()
   521  		}
   522  		delete(s.zeroRTTQueues, connID)
   523  		if s.logger.Debug() {
   524  			s.logger.Debugf("Removing 0-RTT queue for %s.", connID)
   525  		}
   526  	}
   527  	s.nextZeroRTTCleanup = nextCleanup
   528  }
   529  
   530  // validateToken returns false if:
   531  //   - address is invalid
   532  //   - token is expired
   533  //   - token is null
   534  func (s *baseServer) validateToken(token *handshake.Token, addr net.Addr) bool {
   535  	if token == nil {
   536  		return false
   537  	}
   538  	if !token.ValidateRemoteAddr(addr) {
   539  		return false
   540  	}
   541  	if !token.IsRetryToken && time.Since(token.SentTime) > s.maxTokenAge {
   542  		return false
   543  	}
   544  	if token.IsRetryToken && time.Since(token.SentTime) > s.config.maxRetryTokenAge() {
   545  		return false
   546  	}
   547  	return true
   548  }
   549  
   550  func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error {
   551  	if len(hdr.Token) == 0 && hdr.DestConnectionID.Len() < protocol.MinConnectionIDLenInitial {
   552  		if s.tracer != nil && s.tracer.DroppedPacket != nil {
   553  			s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket)
   554  		}
   555  		p.buffer.Release()
   556  		return errors.New("too short connection ID")
   557  	}
   558  
   559  	// The server queues packets for a while, and we might already have established a connection by now.
   560  	// This results in a second check in the connection map.
   561  	// That's ok since it's not the hot path (it's only taken by some Initial and 0-RTT packets).
   562  	if handler, ok := s.connHandler.Get(hdr.DestConnectionID); ok {
   563  		handler.handlePacket(p)
   564  		return nil
   565  	}
   566  
   567  	var (
   568  		token          *handshake.Token
   569  		retrySrcConnID *protocol.ConnectionID
   570  	)
   571  	origDestConnID := hdr.DestConnectionID
   572  	if len(hdr.Token) > 0 {
   573  		tok, err := s.tokenGenerator.DecodeToken(hdr.Token)
   574  		if err == nil {
   575  			if tok.IsRetryToken {
   576  				origDestConnID = tok.OriginalDestConnectionID
   577  				retrySrcConnID = &tok.RetrySrcConnectionID
   578  			}
   579  			token = tok
   580  		}
   581  	}
   582  
   583  	clientAddrValidated := s.validateToken(token, p.remoteAddr)
   584  	if token != nil && !clientAddrValidated {
   585  		// For invalid and expired non-retry tokens, we don't send an INVALID_TOKEN error.
   586  		// We just ignore them, and act as if there was no token on this packet at all.
   587  		// This also means we might send a Retry later.
   588  		if !token.IsRetryToken {
   589  			token = nil
   590  		} else {
   591  			// For Retry tokens, we send an INVALID_ERROR if
   592  			// * the token is too old, or
   593  			// * the token is invalid, in case of a retry token.
   594  			select {
   595  			case s.invalidTokenQueue <- rejectedPacket{receivedPacket: p, hdr: hdr}:
   596  			default:
   597  				// drop packet if we can't send out the  INVALID_TOKEN packets fast enough
   598  				p.buffer.Release()
   599  			}
   600  			return nil
   601  		}
   602  	}
   603  
   604  	// Until the next call to handleInitialImpl, these numbers are guaranteed to not increase.
   605  	// They might decrease if another connection completes the handshake.
   606  	numHandshakesUnvalidated := s.numHandshakesUnvalidated.Load()
   607  	numHandshakesValidated := s.numHandshakesValidated.Load()
   608  
   609  	// Check the total handshake limit first. It's better to reject than to initiate a retry.
   610  	if total := numHandshakesUnvalidated + numHandshakesValidated; total >= int64(s.maxNumHandshakesTotal) {
   611  		s.logger.Debugf("Rejecting new connection. Server currently busy. Currently handshaking: %d (max %d)", total, s.maxNumHandshakesTotal)
   612  		delete(s.zeroRTTQueues, hdr.DestConnectionID)
   613  		select {
   614  		case s.connectionRefusedQueue <- rejectedPacket{receivedPacket: p, hdr: hdr}:
   615  		default:
   616  			// drop packet if we can't send out the CONNECTION_REFUSED fast enough
   617  			p.buffer.Release()
   618  		}
   619  		return nil
   620  	}
   621  	if token == nil && numHandshakesUnvalidated >= int64(s.maxNumHandshakesUnvalidated) {
   622  		// Retry invalidates all 0-RTT packets sent.
   623  		delete(s.zeroRTTQueues, hdr.DestConnectionID)
   624  		select {
   625  		case s.retryQueue <- rejectedPacket{receivedPacket: p, hdr: hdr}:
   626  		default:
   627  			// drop packet if we can't send out Retry packets fast enough
   628  			p.buffer.Release()
   629  		}
   630  		return nil
   631  	}
   632  
   633  	connID, err := s.connIDGenerator.GenerateConnectionID()
   634  	if err != nil {
   635  		return err
   636  	}
   637  	s.logger.Debugf("Changing connection ID to %s.", connID)
   638  	var conn quicConn
   639  	tracingID := nextConnTracingID()
   640  	config := s.config
   641  	if s.config.GetConfigForClient != nil {
   642  		conf, err := s.config.GetConfigForClient(&ClientHelloInfo{RemoteAddr: p.remoteAddr})
   643  		if err != nil {
   644  			s.logger.Debugf("Rejecting new connection due to GetConfigForClient callback")
   645  			delete(s.zeroRTTQueues, hdr.DestConnectionID)
   646  			select {
   647  			case s.connectionRefusedQueue <- rejectedPacket{receivedPacket: p, hdr: hdr}:
   648  			default:
   649  				// drop packet if we can't send out the CONNECTION_REFUSED fast enough
   650  				p.buffer.Release()
   651  			}
   652  			return nil
   653  		}
   654  		config = populateConfig(conf)
   655  	}
   656  	var tracer *logging.ConnectionTracer
   657  	if config.Tracer != nil {
   658  		// Use the same connection ID that is passed to the client's GetLogWriter callback.
   659  		connID := hdr.DestConnectionID
   660  		if origDestConnID.Len() > 0 {
   661  			connID = origDestConnID
   662  		}
   663  		tracer = config.Tracer(context.WithValue(context.Background(), ConnectionTracingKey, tracingID), protocol.PerspectiveServer, connID)
   664  	}
   665  	conn = s.newConn(
   666  		newSendConn(s.conn, p.remoteAddr, p.info, s.logger),
   667  		s.connHandler,
   668  		origDestConnID,
   669  		retrySrcConnID,
   670  		hdr.DestConnectionID,
   671  		hdr.SrcConnectionID,
   672  		connID,
   673  		s.connIDGenerator,
   674  		s.connHandler.GetStatelessResetToken(connID),
   675  		config,
   676  		s.tlsConf,
   677  		s.tokenGenerator,
   678  		clientAddrValidated,
   679  		tracer,
   680  		tracingID,
   681  		s.logger,
   682  		hdr.Version,
   683  	)
   684  	conn.handlePacket(p)
   685  	// Adding the connection will fail if the client's chosen Destination Connection ID is already in use.
   686  	// This is very unlikely: Even if an attacker chooses a connection ID that's already in use,
   687  	// under normal circumstances the packet would just be routed to that connection.
   688  	// The only time this collision will occur if we receive the two Initial packets at the same time.
   689  	if added := s.connHandler.AddWithConnID(hdr.DestConnectionID, connID, conn); !added {
   690  		delete(s.zeroRTTQueues, hdr.DestConnectionID)
   691  		conn.closeWithTransportError(qerr.ConnectionRefused)
   692  		return nil
   693  	}
   694  	// Pass queued 0-RTT to the newly established connection.
   695  	if q, ok := s.zeroRTTQueues[hdr.DestConnectionID]; ok {
   696  		for _, p := range q.packets {
   697  			conn.handlePacket(p)
   698  		}
   699  		delete(s.zeroRTTQueues, hdr.DestConnectionID)
   700  	}
   701  
   702  	if clientAddrValidated {
   703  		s.numHandshakesValidated.Add(1)
   704  	} else {
   705  		s.numHandshakesUnvalidated.Add(1)
   706  	}
   707  	go conn.run()
   708  	go func() {
   709  		completed := s.handleNewConn(conn)
   710  		if clientAddrValidated {
   711  			if s.numHandshakesValidated.Add(-1) < 0 {
   712  				panic("server BUG: number of validated handshakes negative")
   713  			}
   714  		} else if s.numHandshakesUnvalidated.Add(-1) < 0 {
   715  			panic("server BUG: number of unvalidated handshakes negative")
   716  		}
   717  		if !completed {
   718  			return
   719  		}
   720  
   721  		select {
   722  		case s.connQueue <- conn:
   723  		default:
   724  			conn.closeWithTransportError(ConnectionRefused)
   725  		}
   726  	}()
   727  	return nil
   728  }
   729  
   730  func (s *baseServer) handleNewConn(conn quicConn) bool {
   731  	if s.acceptEarlyConns {
   732  		// wait until the early connection is ready, the handshake fails, or the server is closed
   733  		select {
   734  		case <-s.errorChan:
   735  			conn.closeWithTransportError(ConnectionRefused)
   736  			return false
   737  		case <-conn.Context().Done():
   738  			return false
   739  		case <-conn.earlyConnReady():
   740  			return true
   741  		}
   742  	}
   743  	// wait until the handshake completes, fails, or the server is closed
   744  	select {
   745  	case <-s.errorChan:
   746  		conn.closeWithTransportError(ConnectionRefused)
   747  		return false
   748  	case <-conn.Context().Done():
   749  		return false
   750  	case <-conn.HandshakeComplete():
   751  		return true
   752  	}
   753  }
   754  
   755  func (s *baseServer) sendRetry(p rejectedPacket) {
   756  	if err := s.sendRetryPacket(p); err != nil {
   757  		s.logger.Debugf("Error sending Retry packet: %s", err)
   758  	}
   759  }
   760  
   761  func (s *baseServer) sendRetryPacket(p rejectedPacket) error {
   762  	hdr := p.hdr
   763  	// Log the Initial packet now.
   764  	// If no Retry is sent, the packet will be logged by the connection.
   765  	(&wire.ExtendedHeader{Header: *hdr}).Log(s.logger)
   766  	srcConnID, err := s.connIDGenerator.GenerateConnectionID()
   767  	if err != nil {
   768  		return err
   769  	}
   770  	token, err := s.tokenGenerator.NewRetryToken(p.remoteAddr, hdr.DestConnectionID, srcConnID)
   771  	if err != nil {
   772  		return err
   773  	}
   774  	replyHdr := &wire.ExtendedHeader{}
   775  	replyHdr.Type = protocol.PacketTypeRetry
   776  	replyHdr.Version = hdr.Version
   777  	replyHdr.SrcConnectionID = srcConnID
   778  	replyHdr.DestConnectionID = hdr.SrcConnectionID
   779  	replyHdr.Token = token
   780  	if s.logger.Debug() {
   781  		s.logger.Debugf("Changing connection ID to %s.", srcConnID)
   782  		s.logger.Debugf("-> Sending Retry")
   783  		replyHdr.Log(s.logger)
   784  	}
   785  
   786  	buf := getPacketBuffer()
   787  	defer buf.Release()
   788  	buf.Data, err = replyHdr.Append(buf.Data, hdr.Version)
   789  	if err != nil {
   790  		return err
   791  	}
   792  	// append the Retry integrity tag
   793  	tag := handshake.GetRetryIntegrityTag(buf.Data, hdr.DestConnectionID, hdr.Version)
   794  	buf.Data = append(buf.Data, tag[:]...)
   795  	if s.tracer != nil && s.tracer.SentPacket != nil {
   796  		s.tracer.SentPacket(p.remoteAddr, &replyHdr.Header, protocol.ByteCount(len(buf.Data)), nil)
   797  	}
   798  	_, err = s.conn.WritePacket(buf.Data, p.remoteAddr, p.info.OOB(), 0, protocol.ECNUnsupported)
   799  	return err
   800  }
   801  
   802  func (s *baseServer) maybeSendInvalidToken(p rejectedPacket) {
   803  	defer p.buffer.Release()
   804  
   805  	// Only send INVALID_TOKEN if we can unprotect the packet.
   806  	// This makes sure that we won't send it for packets that were corrupted.
   807  	hdr := p.hdr
   808  	sealer, opener := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveServer, hdr.Version)
   809  	data := p.data[:hdr.ParsedLen()+hdr.Length]
   810  	extHdr, err := unpackLongHeader(opener, hdr, data, hdr.Version)
   811  	// Only send INVALID_TOKEN if we can unprotect the packet.
   812  	// This makes sure that we won't send it for packets that were corrupted.
   813  	if err != nil {
   814  		if s.tracer != nil && s.tracer.DroppedPacket != nil {
   815  			s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropHeaderParseError)
   816  		}
   817  		return
   818  	}
   819  	hdrLen := extHdr.ParsedLen()
   820  	if _, err := opener.Open(data[hdrLen:hdrLen], data[hdrLen:], extHdr.PacketNumber, data[:hdrLen]); err != nil {
   821  		if s.tracer != nil && s.tracer.DroppedPacket != nil {
   822  			s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropPayloadDecryptError)
   823  		}
   824  		return
   825  	}
   826  	if s.logger.Debug() {
   827  		s.logger.Debugf("Client sent an invalid retry token. Sending INVALID_TOKEN to %s.", p.remoteAddr)
   828  	}
   829  	if err := s.sendError(p.remoteAddr, hdr, sealer, qerr.InvalidToken, p.info); err != nil {
   830  		s.logger.Debugf("Error sending INVALID_TOKEN error: %s", err)
   831  	}
   832  }
   833  
   834  func (s *baseServer) sendConnectionRefused(p rejectedPacket) {
   835  	defer p.buffer.Release()
   836  	sealer, _ := handshake.NewInitialAEAD(p.hdr.DestConnectionID, protocol.PerspectiveServer, p.hdr.Version)
   837  	if err := s.sendError(p.remoteAddr, p.hdr, sealer, qerr.ConnectionRefused, p.info); err != nil {
   838  		s.logger.Debugf("Error sending CONNECTION_REFUSED error: %s", err)
   839  	}
   840  }
   841  
// sendError sends the error as a response to the packet received with header hdr.
//
// It writes a single Initial packet to remoteAddr containing one
// CONNECTION_CLOSE frame with errorCode. Properties of the reply:
//   - The connection IDs of hdr are swapped (hdr's destination becomes the
//     reply's source and vice versa).
//   - The packet number is 4 bytes long and left at its zero value.
//   - The payload is sealed with sealer, and header protection is applied with
//     the same sealer.
//
// The packet is logged, reported to the tracer and written without ECN. Any
// serialization or write error is returned.
func (s *baseServer) sendError(remoteAddr net.Addr, hdr *wire.Header, sealer handshake.LongHeaderSealer, errorCode qerr.TransportErrorCode, info packetInfo) error {
	b := getPacketBuffer()
	defer b.Release()

	ccf := &wire.ConnectionCloseFrame{ErrorCode: uint64(errorCode)}

	replyHdr := &wire.ExtendedHeader{}
	replyHdr.Type = protocol.PacketTypeInitial
	replyHdr.Version = hdr.Version
	replyHdr.SrcConnectionID = hdr.DestConnectionID
	replyHdr.DestConnectionID = hdr.SrcConnectionID
	replyHdr.PacketNumberLen = protocol.PacketNumberLen4
	// The Length field covers the packet number, the (encrypted) frame and the
	// AEAD overhead. The literal 4 must match PacketNumberLen4 above.
	replyHdr.Length = 4 /* packet number len */ + ccf.Length(hdr.Version) + protocol.ByteCount(sealer.Overhead())
	var err error
	b.Data, err = replyHdr.Append(b.Data, hdr.Version)
	if err != nil {
		return err
	}
	// Everything from here on is payload; the bytes before it are the header,
	// which serves as associated data for the AEAD.
	payloadOffset := len(b.Data)

	b.Data, err = ccf.Append(b.Data, hdr.Version)
	if err != nil {
		return err
	}

	// Seal in place: the ciphertext overwrites the plaintext frame starting at
	// payloadOffset. The return value is discarded, so this assumes Seal wrote
	// into b.Data's spare capacity (same backing array). The reslice below then
	// exposes the appended AEAD tag (Overhead() bytes).
	_ = sealer.Seal(b.Data[payloadOffset:payloadOffset], b.Data[payloadOffset:], replyHdr.PacketNumber, b.Data[:payloadOffset])
	b.Data = b.Data[0 : len(b.Data)+sealer.Overhead()]

	// Apply header protection. The 16-byte sample starts 4 bytes after the
	// start of the packet number (RFC 9001, Section 5.4.2). The mask is applied
	// to the first byte and to the packet number bytes.
	pnOffset := payloadOffset - int(replyHdr.PacketNumberLen)
	sealer.EncryptHeader(
		b.Data[pnOffset+4:pnOffset+4+16],
		&b.Data[0],
		b.Data[pnOffset:payloadOffset],
	)

	replyHdr.Log(s.logger)
	wire.LogFrame(s.logger, ccf, true)
	if s.tracer != nil && s.tracer.SentPacket != nil {
		s.tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(b.Data)), []logging.Frame{ccf})
	}
	_, err = s.conn.WritePacket(b.Data, remoteAddr, info.OOB(), 0, protocol.ECNUnsupported)
	return err
}
   886  
   887  func (s *baseServer) enqueueVersionNegotiationPacket(p receivedPacket) (bufferInUse bool) {
   888  	select {
   889  	case s.versionNegotiationQueue <- p:
   890  		return true
   891  	default:
   892  		// it's fine to not send version negotiation packets when we are busy
   893  	}
   894  	return false
   895  }
   896  
   897  func (s *baseServer) maybeSendVersionNegotiationPacket(p receivedPacket) {
   898  	defer p.buffer.Release()
   899  
   900  	v, err := wire.ParseVersion(p.data)
   901  	if err != nil {
   902  		s.logger.Debugf("failed to parse version for sending version negotiation packet: %s", err)
   903  		return
   904  	}
   905  
   906  	_, src, dest, err := wire.ParseArbitraryLenConnectionIDs(p.data)
   907  	if err != nil { // should never happen
   908  		s.logger.Debugf("Dropping a packet with an unknown version for which we failed to parse connection IDs")
   909  		if s.tracer != nil && s.tracer.DroppedPacket != nil {
   910  			s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket)
   911  		}
   912  		return
   913  	}
   914  
   915  	s.logger.Debugf("Client offered version %s, sending Version Negotiation", v)
   916  
   917  	data := wire.ComposeVersionNegotiation(dest, src, s.config.Versions)
   918  	if s.tracer != nil && s.tracer.SentVersionNegotiationPacket != nil {
   919  		s.tracer.SentVersionNegotiationPacket(p.remoteAddr, src, dest, s.config.Versions)
   920  	}
   921  	if _, err := s.conn.WritePacket(data, p.remoteAddr, p.info.OOB(), 0, protocol.ECNUnsupported); err != nil {
   922  		s.logger.Debugf("Error sending Version Negotiation: %s", err)
   923  	}
   924  }