github.com/danielpfeifer02/quic-go-prio-packs@v0.41.0-28/server_test.go (about)

     1  package quic
     2  
     3  import (
     4  	"context"
     5  	"crypto/rand"
     6  	"crypto/tls"
     7  	"errors"
     8  	"net"
     9  	"sync"
    10  	"sync/atomic"
    11  	"time"
    12  
    13  	"github.com/danielpfeifer02/quic-go-prio-packs/internal/handshake"
    14  	mocklogging "github.com/danielpfeifer02/quic-go-prio-packs/internal/mocks/logging"
    15  	"github.com/danielpfeifer02/quic-go-prio-packs/internal/protocol"
    16  	"github.com/danielpfeifer02/quic-go-prio-packs/internal/qerr"
    17  	"github.com/danielpfeifer02/quic-go-prio-packs/internal/testdata"
    18  	"github.com/danielpfeifer02/quic-go-prio-packs/internal/utils"
    19  	"github.com/danielpfeifer02/quic-go-prio-packs/internal/wire"
    20  	"github.com/danielpfeifer02/quic-go-prio-packs/logging"
    21  
    22  	. "github.com/onsi/ginkgo/v2"
    23  	. "github.com/onsi/gomega"
    24  	"go.uber.org/mock/gomock"
    25  )
    26  
    27  var _ = Describe("Server", func() {
    28  	var (
    29  		conn    *MockPacketConn
    30  		tlsConf *tls.Config
    31  	)
    32  
    33  	getPacket := func(hdr *wire.Header, p []byte) receivedPacket {
    34  		buf := getPacketBuffer()
    35  		hdr.Length = 4 + protocol.ByteCount(len(p)) + 16
    36  		var err error
    37  		buf.Data, err = (&wire.ExtendedHeader{
    38  			Header:          *hdr,
    39  			PacketNumber:    0x42,
    40  			PacketNumberLen: protocol.PacketNumberLen4,
    41  		}).Append(buf.Data, protocol.Version1)
    42  		Expect(err).ToNot(HaveOccurred())
    43  		n := len(buf.Data)
    44  		buf.Data = append(buf.Data, p...)
    45  		data := buf.Data
    46  		sealer, _ := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveClient, hdr.Version)
    47  		_ = sealer.Seal(data[n:n], data[n:], 0x42, data[:n])
    48  		data = data[:len(data)+16]
    49  		sealer.EncryptHeader(data[n:n+16], &data[0], data[n-4:n])
    50  		return receivedPacket{
    51  			remoteAddr: &net.UDPAddr{IP: net.IPv4(4, 5, 6, 7), Port: 456},
    52  			data:       data,
    53  			buffer:     buf,
    54  		}
    55  	}
    56  
    57  	getInitial := func(destConnID protocol.ConnectionID) receivedPacket {
    58  		senderAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 42}
    59  		hdr := &wire.Header{
    60  			Type:             protocol.PacketTypeInitial,
    61  			SrcConnectionID:  protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
    62  			DestConnectionID: destConnID,
    63  			Version:          protocol.Version1,
    64  		}
    65  		p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
    66  		p.buffer = getPacketBuffer()
    67  		p.remoteAddr = senderAddr
    68  		return p
    69  	}
    70  
    71  	getInitialWithRandomDestConnID := func() receivedPacket {
    72  		b := make([]byte, 10)
    73  		_, err := rand.Read(b)
    74  		Expect(err).ToNot(HaveOccurred())
    75  
    76  		return getInitial(protocol.ParseConnectionID(b))
    77  	}
    78  
    79  	parseHeader := func(data []byte) *wire.Header {
    80  		hdr, _, _, err := wire.ParsePacket(data)
    81  		Expect(err).ToNot(HaveOccurred())
    82  		return hdr
    83  	}
    84  
    85  	checkConnectionCloseError := func(b []byte, origHdr *wire.Header, errorCode qerr.TransportErrorCode) {
    86  		replyHdr := parseHeader(b)
    87  		Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial))
    88  		Expect(replyHdr.SrcConnectionID).To(Equal(origHdr.DestConnectionID))
    89  		Expect(replyHdr.DestConnectionID).To(Equal(origHdr.SrcConnectionID))
    90  		_, opener := handshake.NewInitialAEAD(origHdr.DestConnectionID, protocol.PerspectiveClient, replyHdr.Version)
    91  		extHdr, err := unpackLongHeader(opener, replyHdr, b, origHdr.Version)
    92  		Expect(err).ToNot(HaveOccurred())
    93  		data, err := opener.Open(nil, b[extHdr.ParsedLen():], extHdr.PacketNumber, b[:extHdr.ParsedLen()])
    94  		Expect(err).ToNot(HaveOccurred())
    95  		_, f, err := wire.NewFrameParser(false).ParseNext(data, protocol.EncryptionInitial, origHdr.Version)
    96  		Expect(err).ToNot(HaveOccurred())
    97  		Expect(f).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{}))
    98  		ccf := f.(*wire.ConnectionCloseFrame)
    99  		Expect(ccf.IsApplicationError).To(BeFalse())
   100  		Expect(ccf.ErrorCode).To(BeEquivalentTo(errorCode))
   101  		Expect(ccf.ReasonPhrase).To(BeEmpty())
   102  	}
   103  
   104  	BeforeEach(func() {
   105  		conn = NewMockPacketConn(mockCtrl)
   106  		conn.EXPECT().LocalAddr().Return(&net.UDPAddr{}).AnyTimes()
   107  		wait := make(chan struct{})
   108  		conn.EXPECT().ReadFrom(gomock.Any()).DoAndReturn(func(_ []byte) (int, net.Addr, error) {
   109  			<-wait
   110  			return 0, nil, errors.New("done")
   111  		}).MaxTimes(1)
   112  		conn.EXPECT().SetReadDeadline(gomock.Any()).Do(func(time.Time) error {
   113  			close(wait)
   114  			conn.EXPECT().SetReadDeadline(time.Time{})
   115  			return nil
   116  		}).MaxTimes(1)
   117  		tlsConf = testdata.GetTLSConfig()
   118  		tlsConf.NextProtos = []string{"proto1"}
   119  	})
   120  
   121  	It("errors when no tls.Config is given", func() {
   122  		_, err := ListenAddr("localhost:0", nil, nil)
   123  		Expect(err).To(HaveOccurred())
   124  		Expect(err.Error()).To(ContainSubstring("quic: tls.Config not set"))
   125  	})
   126  
   127  	It("errors when the Config contains an invalid version", func() {
   128  		version := protocol.Version(0x1234)
   129  		_, err := Listen(nil, tlsConf, &Config{Versions: []protocol.Version{version}})
   130  		Expect(err).To(MatchError("invalid QUIC version: 0x1234"))
   131  	})
   132  
   133  	It("fills in default values if options are not set in the Config", func() {
   134  		ln, err := Listen(conn, tlsConf, &Config{})
   135  		Expect(err).ToNot(HaveOccurred())
   136  		server := ln.baseServer
   137  		Expect(server.config.Versions).To(Equal(protocol.SupportedVersions))
   138  		Expect(server.config.HandshakeIdleTimeout).To(Equal(protocol.DefaultHandshakeIdleTimeout))
   139  		Expect(server.config.MaxIdleTimeout).To(Equal(protocol.DefaultIdleTimeout))
   140  		Expect(server.config.KeepAlivePeriod).To(BeZero())
   141  		// stop the listener
   142  		Expect(ln.Close()).To(Succeed())
   143  	})
   144  
   145  	It("setups with the right values", func() {
   146  		supportedVersions := []protocol.Version{protocol.Version1}
   147  		config := Config{
   148  			Versions:             supportedVersions,
   149  			HandshakeIdleTimeout: 1337 * time.Hour,
   150  			MaxIdleTimeout:       42 * time.Minute,
   151  			KeepAlivePeriod:      5 * time.Second,
   152  		}
   153  		ln, err := Listen(conn, tlsConf, &config)
   154  		Expect(err).ToNot(HaveOccurred())
   155  		server := ln.baseServer
   156  		Expect(server.connHandler).ToNot(BeNil())
   157  		Expect(server.config.Versions).To(Equal(supportedVersions))
   158  		Expect(server.config.HandshakeIdleTimeout).To(Equal(1337 * time.Hour))
   159  		Expect(server.config.MaxIdleTimeout).To(Equal(42 * time.Minute))
   160  		Expect(server.config.KeepAlivePeriod).To(Equal(5 * time.Second))
   161  		// stop the listener
   162  		Expect(ln.Close()).To(Succeed())
   163  	})
   164  
   165  	It("listens on a given address", func() {
   166  		addr := "127.0.0.1:13579"
   167  		ln, err := ListenAddr(addr, tlsConf, &Config{})
   168  		Expect(err).ToNot(HaveOccurred())
   169  		Expect(ln.Addr().String()).To(Equal(addr))
   170  		// stop the listener
   171  		Expect(ln.Close()).To(Succeed())
   172  	})
   173  
   174  	It("errors if given an invalid address", func() {
   175  		addr := "127.0.0.1"
   176  		_, err := ListenAddr(addr, tlsConf, &Config{})
   177  		Expect(err).To(BeAssignableToTypeOf(&net.AddrError{}))
   178  	})
   179  
   180  	It("errors if given an invalid address", func() {
   181  		addr := "1.1.1.1:1111"
   182  		_, err := ListenAddr(addr, tlsConf, &Config{})
   183  		Expect(err).To(BeAssignableToTypeOf(&net.OpError{}))
   184  	})
   185  
   186  	Context("server accepting connections that completed the handshake", func() {
   187  		var (
   188  			tr     *Transport
   189  			serv   *baseServer
   190  			phm    *MockPacketHandlerManager
   191  			tracer *mocklogging.MockTracer
   192  		)
   193  
   194  		BeforeEach(func() {
   195  			var t *logging.Tracer
   196  			t, tracer = mocklogging.NewMockTracer(mockCtrl)
   197  			tr = &Transport{Conn: conn, Tracer: t}
   198  			ln, err := tr.Listen(tlsConf, nil)
   199  			Expect(err).ToNot(HaveOccurred())
   200  			serv = ln.baseServer
   201  			phm = NewMockPacketHandlerManager(mockCtrl)
   202  			serv.connHandler = phm
   203  		})
   204  
   205  		AfterEach(func() {
   206  			tracer.EXPECT().Close()
   207  			tr.Close()
   208  		})
   209  
   210  		Context("handling packets", func() {
   211  			It("drops Initial packets with a too short connection ID", func() {
   212  				p := getPacket(&wire.Header{
   213  					Type:             protocol.PacketTypeInitial,
   214  					DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}),
   215  					Version:          serv.config.Versions[0],
   216  				}, nil)
   217  				tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket)
   218  				serv.handlePacket(p)
   219  				// make sure there are no Write calls on the packet conn
   220  				time.Sleep(50 * time.Millisecond)
   221  			})
   222  
   223  			It("drops too small Initial", func() {
   224  				p := getPacket(&wire.Header{
   225  					Type:             protocol.PacketTypeInitial,
   226  					DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}),
   227  					Version:          serv.config.Versions[0],
   228  				}, make([]byte, protocol.MinInitialPacketSize-100))
   229  				tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket)
   230  				serv.handlePacket(p)
   231  				// make sure there are no Write calls on the packet conn
   232  				time.Sleep(50 * time.Millisecond)
   233  			})
   234  
   235  			It("drops non-Initial packets", func() {
   236  				p := getPacket(&wire.Header{
   237  					Type:    protocol.PacketTypeHandshake,
   238  					Version: serv.config.Versions[0],
   239  				}, []byte("invalid"))
   240  				tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketTypeHandshake, p.Size(), logging.PacketDropUnexpectedPacket)
   241  				serv.handlePacket(p)
   242  				// make sure there are no Write calls on the packet conn
   243  				time.Sleep(50 * time.Millisecond)
   244  			})
   245  
   246  			It("passes packets to existing connections", func() {
   247  				connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
   248  				p := getPacket(&wire.Header{
   249  					Type:             protocol.PacketTypeInitial,
   250  					DestConnectionID: connID,
   251  					Version:          serv.config.Versions[0],
   252  				}, make([]byte, protocol.MinInitialPacketSize))
   253  				conn := NewMockPacketHandler(mockCtrl)
   254  				phm.EXPECT().Get(connID).Return(conn, true)
   255  				handled := make(chan struct{})
   256  				conn.EXPECT().handlePacket(p).Do(func(receivedPacket) { close(handled) })
   257  				serv.handlePacket(p)
   258  				Eventually(handled).Should(BeClosed())
   259  			})
   260  
   261  			It("creates a connection when the token is accepted", func() {
   262  				serv.maxNumHandshakesUnvalidated = 0
   263  				raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
   264  				retryToken, err := serv.tokenGenerator.NewRetryToken(
   265  					raddr,
   266  					protocol.ParseConnectionID([]byte{0xde, 0xad, 0xc0, 0xde}),
   267  					protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}),
   268  				)
   269  				Expect(err).ToNot(HaveOccurred())
   270  				connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})
   271  				hdr := &wire.Header{
   272  					Type:             protocol.PacketTypeInitial,
   273  					SrcConnectionID:  protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
   274  					DestConnectionID: connID,
   275  					Version:          protocol.Version1,
   276  					Token:            retryToken,
   277  				}
   278  				p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
   279  				p.remoteAddr = raddr
   280  				run := make(chan struct{})
   281  				var token protocol.StatelessResetToken
   282  				rand.Read(token[:])
   283  
   284  				var newConnID protocol.ConnectionID
   285  				conn := NewMockQUICConn(mockCtrl)
   286  				serv.newConn = func(
   287  					_ sendConn,
   288  					_ connRunner,
   289  					origDestConnID protocol.ConnectionID,
   290  					retrySrcConnID *protocol.ConnectionID,
   291  					clientDestConnID protocol.ConnectionID,
   292  					destConnID protocol.ConnectionID,
   293  					srcConnID protocol.ConnectionID,
   294  					_ ConnectionIDGenerator,
   295  					tokenP protocol.StatelessResetToken,
   296  					_ *Config,
   297  					_ *tls.Config,
   298  					_ *handshake.TokenGenerator,
   299  					_ bool,
   300  					_ *logging.ConnectionTracer,
   301  					_ uint64,
   302  					_ utils.Logger,
   303  					_ protocol.Version,
   304  				) quicConn {
   305  					Expect(origDestConnID).To(Equal(protocol.ParseConnectionID([]byte{0xde, 0xad, 0xc0, 0xde})))
   306  					Expect(*retrySrcConnID).To(Equal(protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad})))
   307  					Expect(clientDestConnID).To(Equal(hdr.DestConnectionID))
   308  					Expect(destConnID).To(Equal(hdr.SrcConnectionID))
   309  					// make sure we're using a server-generated connection ID
   310  					Expect(srcConnID).ToNot(Equal(hdr.DestConnectionID))
   311  					Expect(srcConnID).ToNot(Equal(hdr.SrcConnectionID))
   312  					newConnID = srcConnID
   313  					Expect(tokenP).To(Equal(token))
   314  					conn.EXPECT().handlePacket(p)
   315  					conn.EXPECT().run().Do(func() error { close(run); return nil })
   316  					conn.EXPECT().Context().Return(context.Background())
   317  					conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
   318  					return conn
   319  				}
   320  				phm.EXPECT().Get(connID)
   321  				phm.EXPECT().GetStatelessResetToken(gomock.Any()).Return(token)
   322  				phm.EXPECT().AddWithConnID(connID, gomock.Any(), gomock.Any()).DoAndReturn(func(_, cid protocol.ConnectionID, h packetHandler) bool {
   323  					Expect(cid).To(Equal(newConnID))
   324  					return true
   325  				})
   326  
   327  				done := make(chan struct{})
   328  				go func() {
   329  					defer GinkgoRecover()
   330  					serv.handlePacket(p)
   331  					// the Handshake packet is written by the connection.
   332  					// Make sure there are no Write calls on the packet conn.
   333  					time.Sleep(50 * time.Millisecond)
   334  					close(done)
   335  				}()
   336  				// make sure we're using a server-generated connection ID
   337  				Eventually(run).Should(BeClosed())
   338  				Eventually(done).Should(BeClosed())
   339  				// shutdown
   340  				conn.EXPECT().closeWithTransportError(gomock.Any())
   341  			})
   342  
   343  			It("sends a Version Negotiation Packet for unsupported versions", func() {
   344  				srcConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5})
   345  				destConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6})
   346  				packet := getPacket(&wire.Header{
   347  					Type:             protocol.PacketTypeHandshake,
   348  					SrcConnectionID:  srcConnID,
   349  					DestConnectionID: destConnID,
   350  					Version:          0x42,
   351  				}, make([]byte, protocol.MinUnknownVersionPacketSize))
   352  				raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
   353  				packet.remoteAddr = raddr
   354  				tracer.EXPECT().SentVersionNegotiationPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, src, dest protocol.ArbitraryLenConnectionID, _ []protocol.Version) {
   355  					Expect(src).To(Equal(protocol.ArbitraryLenConnectionID(destConnID.Bytes())))
   356  					Expect(dest).To(Equal(protocol.ArbitraryLenConnectionID(srcConnID.Bytes())))
   357  				})
   358  				done := make(chan struct{})
   359  				conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
   360  					defer close(done)
   361  					Expect(wire.IsVersionNegotiationPacket(b)).To(BeTrue())
   362  					dest, src, versions, err := wire.ParseVersionNegotiationPacket(b)
   363  					Expect(err).ToNot(HaveOccurred())
   364  					Expect(dest).To(Equal(protocol.ArbitraryLenConnectionID(srcConnID.Bytes())))
   365  					Expect(src).To(Equal(protocol.ArbitraryLenConnectionID(destConnID.Bytes())))
   366  					Expect(versions).ToNot(ContainElement(protocol.Version(0x42)))
   367  					return len(b), nil
   368  				})
   369  				serv.handlePacket(packet)
   370  				Eventually(done).Should(BeClosed())
   371  			})
   372  
   373  			It("doesn't send a Version Negotiation packets if sending them is disabled", func() {
   374  				serv.disableVersionNegotiation = true
   375  				srcConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5})
   376  				destConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6})
   377  				packet := getPacket(&wire.Header{
   378  					Type:             protocol.PacketTypeHandshake,
   379  					SrcConnectionID:  srcConnID,
   380  					DestConnectionID: destConnID,
   381  					Version:          0x42,
   382  				}, make([]byte, protocol.MinUnknownVersionPacketSize))
   383  				raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
   384  				packet.remoteAddr = raddr
   385  				done := make(chan struct{})
   386  				serv.handlePacket(packet)
   387  				Consistently(done, 50*time.Millisecond).ShouldNot(BeClosed())
   388  			})
   389  
   390  			It("ignores Version Negotiation packets", func() {
   391  				data := wire.ComposeVersionNegotiation(
   392  					protocol.ArbitraryLenConnectionID{1, 2, 3, 4},
   393  					protocol.ArbitraryLenConnectionID{4, 3, 2, 1},
   394  					[]protocol.Version{1, 2, 3},
   395  				)
   396  				raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
   397  				done := make(chan struct{})
   398  				tracer.EXPECT().DroppedPacket(raddr, logging.PacketTypeVersionNegotiation, protocol.ByteCount(len(data)), logging.PacketDropUnexpectedPacket).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) {
   399  					close(done)
   400  				})
   401  				serv.handlePacket(receivedPacket{
   402  					remoteAddr: raddr,
   403  					data:       data,
   404  					buffer:     getPacketBuffer(),
   405  				})
   406  				Eventually(done).Should(BeClosed())
   407  				// make sure no other packet is sent
   408  				time.Sleep(scaleDuration(20 * time.Millisecond))
   409  			})
   410  
   411  			It("doesn't send a Version Negotiation Packet for unsupported versions, if the packet is too small", func() {
   412  				srcConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5})
   413  				destConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6})
   414  				p := getPacket(&wire.Header{
   415  					Type:             protocol.PacketTypeHandshake,
   416  					SrcConnectionID:  srcConnID,
   417  					DestConnectionID: destConnID,
   418  					Version:          0x42,
   419  				}, make([]byte, protocol.MinUnknownVersionPacketSize-50))
   420  				Expect(p.Size()).To(BeNumerically("<", protocol.MinUnknownVersionPacketSize))
   421  				raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
   422  				p.remoteAddr = raddr
   423  				done := make(chan struct{})
   424  				tracer.EXPECT().DroppedPacket(raddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) {
   425  					close(done)
   426  				})
   427  				serv.handlePacket(p)
   428  				Eventually(done).Should(BeClosed())
   429  				// make sure no other packet is sent
   430  				time.Sleep(scaleDuration(20 * time.Millisecond))
   431  			})
   432  
   433  			It("replies with a Retry packet, if a token is required", func() {
   434  				connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})
   435  				serv.maxNumHandshakesUnvalidated = 0
   436  				hdr := &wire.Header{
   437  					Type:             protocol.PacketTypeInitial,
   438  					SrcConnectionID:  protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
   439  					DestConnectionID: connID,
   440  					Version:          protocol.Version1,
   441  				}
   442  				packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
   443  				raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
   444  				packet.remoteAddr = raddr
   445  				tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), nil).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, _ []logging.Frame) {
   446  					Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry))
   447  					Expect(replyHdr.SrcConnectionID).ToNot(Equal(hdr.DestConnectionID))
   448  					Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))
   449  					Expect(replyHdr.Token).ToNot(BeEmpty())
   450  				})
   451  				done := make(chan struct{})
   452  				conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
   453  					defer close(done)
   454  					replyHdr := parseHeader(b)
   455  					Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry))
   456  					Expect(replyHdr.SrcConnectionID).ToNot(Equal(hdr.DestConnectionID))
   457  					Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))
   458  					Expect(replyHdr.Token).ToNot(BeEmpty())
   459  					Expect(b[len(b)-16:]).To(Equal(handshake.GetRetryIntegrityTag(b[:len(b)-16], hdr.DestConnectionID, hdr.Version)[:]))
   460  					return len(b), nil
   461  				})
   462  				phm.EXPECT().Get(connID)
   463  				serv.handlePacket(packet)
   464  				Eventually(done).Should(BeClosed())
   465  			})
   466  
   467  			It("creates a connection, if no token is required", func() {
   468  				connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})
   469  				hdr := &wire.Header{
   470  					Type:             protocol.PacketTypeInitial,
   471  					SrcConnectionID:  protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
   472  					DestConnectionID: connID,
   473  					Version:          protocol.Version1,
   474  				}
   475  				p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
   476  				run := make(chan struct{})
   477  				var token protocol.StatelessResetToken
   478  				rand.Read(token[:])
   479  
   480  				var newConnID protocol.ConnectionID
   481  				conn := NewMockQUICConn(mockCtrl)
   482  				serv.newConn = func(
   483  					_ sendConn,
   484  					_ connRunner,
   485  					origDestConnID protocol.ConnectionID,
   486  					retrySrcConnID *protocol.ConnectionID,
   487  					clientDestConnID protocol.ConnectionID,
   488  					destConnID protocol.ConnectionID,
   489  					srcConnID protocol.ConnectionID,
   490  					_ ConnectionIDGenerator,
   491  					tokenP protocol.StatelessResetToken,
   492  					_ *Config,
   493  					_ *tls.Config,
   494  					_ *handshake.TokenGenerator,
   495  					_ bool,
   496  					_ *logging.ConnectionTracer,
   497  					_ uint64,
   498  					_ utils.Logger,
   499  					_ protocol.Version,
   500  				) quicConn {
   501  					Expect(origDestConnID).To(Equal(hdr.DestConnectionID))
   502  					Expect(retrySrcConnID).To(BeNil())
   503  					Expect(clientDestConnID).To(Equal(hdr.DestConnectionID))
   504  					Expect(destConnID).To(Equal(hdr.SrcConnectionID))
   505  					// make sure we're using a server-generated connection ID
   506  					Expect(srcConnID).ToNot(Equal(hdr.DestConnectionID))
   507  					Expect(srcConnID).ToNot(Equal(hdr.SrcConnectionID))
   508  					newConnID = srcConnID
   509  					Expect(tokenP).To(Equal(token))
   510  					conn.EXPECT().handlePacket(p)
   511  					conn.EXPECT().run().Do(func() error { close(run); return nil })
   512  					conn.EXPECT().Context().Return(context.Background())
   513  					conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
   514  					return conn
   515  				}
   516  				gomock.InOrder(
   517  					phm.EXPECT().Get(connID),
   518  					phm.EXPECT().GetStatelessResetToken(gomock.Any()).Return(token),
   519  					phm.EXPECT().AddWithConnID(connID, gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, h packetHandler) bool {
   520  						Expect(c).To(Equal(newConnID))
   521  						return true
   522  					}),
   523  				)
   524  
   525  				done := make(chan struct{})
   526  				go func() {
   527  					defer GinkgoRecover()
   528  					serv.handlePacket(p)
   529  					// the Handshake packet is written by the connection
   530  					// make sure there are no Write calls on the packet conn
   531  					time.Sleep(50 * time.Millisecond)
   532  					close(done)
   533  				}()
   534  				// make sure we're using a server-generated connection ID
   535  				Eventually(run).Should(BeClosed())
   536  				Eventually(done).Should(BeClosed())
   537  				// shutdown
   538  				conn.EXPECT().closeWithTransportError(gomock.Any()).MaxTimes(1)
   539  			})
   540  
   541  			It("drops packets if the receive queue is full", func() {
   542  				serv.maxNumHandshakesTotal = 10000
   543  				serv.maxNumHandshakesUnvalidated = 10000
   544  
   545  				phm.EXPECT().Get(gomock.Any()).AnyTimes()
   546  				phm.EXPECT().GetStatelessResetToken(gomock.Any()).AnyTimes()
   547  				phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true).AnyTimes()
   548  
   549  				acceptConn := make(chan struct{})
   550  				var counter atomic.Uint32
   551  				serv.newConn = func(
   552  					_ sendConn,
   553  					runner connRunner,
   554  					_ protocol.ConnectionID,
   555  					_ *protocol.ConnectionID,
   556  					_ protocol.ConnectionID,
   557  					_ protocol.ConnectionID,
   558  					_ protocol.ConnectionID,
   559  					_ ConnectionIDGenerator,
   560  					_ protocol.StatelessResetToken,
   561  					_ *Config,
   562  					_ *tls.Config,
   563  					_ *handshake.TokenGenerator,
   564  					_ bool,
   565  					_ *logging.ConnectionTracer,
   566  					_ uint64,
   567  					_ utils.Logger,
   568  					_ protocol.Version,
   569  				) quicConn {
   570  					<-acceptConn
   571  					counter.Add(1)
   572  					conn := NewMockQUICConn(mockCtrl)
   573  					conn.EXPECT().handlePacket(gomock.Any()).MaxTimes(1)
   574  					conn.EXPECT().run().MaxTimes(1)
   575  					conn.EXPECT().Context().Return(context.Background()).MaxTimes(1)
   576  					conn.EXPECT().HandshakeComplete().Return(make(chan struct{})).MaxTimes(1)
   577  					// shutdown
   578  					conn.EXPECT().closeWithTransportError(gomock.Any()).MaxTimes(1)
   579  					return conn
   580  				}
   581  
   582  				p := getInitial(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}))
   583  				serv.handlePacket(p)
   584  				tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropDOSPrevention).MinTimes(1)
   585  				var wg sync.WaitGroup
   586  				for i := 0; i < 3*protocol.MaxServerUnprocessedPackets; i++ {
   587  					wg.Add(1)
   588  					go func() {
   589  						defer GinkgoRecover()
   590  						defer wg.Done()
   591  						serv.handlePacket(getInitial(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})))
   592  					}()
   593  				}
   594  				wg.Wait()
   595  
   596  				close(acceptConn)
   597  				Eventually(
   598  					func() uint32 { return counter.Load() },
   599  					scaleDuration(100*time.Millisecond),
   600  				).Should(BeEquivalentTo(protocol.MaxServerUnprocessedPackets + 1))
   601  				Consistently(func() uint32 { return counter.Load() }).Should(BeEquivalentTo(protocol.MaxServerUnprocessedPackets + 1))
   602  			})
   603  
   604  			It("only creates a single connection for a duplicate Initial", func() {
   605  				done := make(chan struct{})
   606  				serv.newConn = func(
   607  					_ sendConn,
   608  					runner connRunner,
   609  					_ protocol.ConnectionID,
   610  					_ *protocol.ConnectionID,
   611  					_ protocol.ConnectionID,
   612  					_ protocol.ConnectionID,
   613  					_ protocol.ConnectionID,
   614  					_ ConnectionIDGenerator,
   615  					_ protocol.StatelessResetToken,
   616  					_ *Config,
   617  					_ *tls.Config,
   618  					_ *handshake.TokenGenerator,
   619  					_ bool,
   620  					_ *logging.ConnectionTracer,
   621  					_ uint64,
   622  					_ utils.Logger,
   623  					_ protocol.Version,
   624  				) quicConn {
   625  					conn := NewMockQUICConn(mockCtrl)
   626  					conn.EXPECT().handlePacket(gomock.Any())
   627  					conn.EXPECT().closeWithTransportError(qerr.ConnectionRefused).Do(func(qerr.TransportErrorCode) {
   628  						close(done)
   629  					})
   630  					return conn
   631  				}
   632  
   633  				connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9})
   634  				p := getInitial(connID)
   635  				phm.EXPECT().Get(connID)
   636  				phm.EXPECT().GetStatelessResetToken(gomock.Any())
   637  				phm.EXPECT().AddWithConnID(connID, gomock.Any(), gomock.Any()).Return(false) // connection ID collision
   638  				Expect(serv.handlePacketImpl(p)).To(BeTrue())
   639  				Eventually(done).Should(BeClosed())
   640  			})
   641  
   642  			It("limits the number of unvalidated handshakes", func() {
   643  				const limit = 3
   644  				serv.maxNumHandshakesTotal = 10000
   645  				serv.maxNumHandshakesUnvalidated = limit
   646  
   647  				phm.EXPECT().Get(gomock.Any()).AnyTimes()
   648  				phm.EXPECT().GetStatelessResetToken(gomock.Any()).AnyTimes()
   649  				phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true).AnyTimes()
   650  
   651  				handshakeChan := make(chan struct{})
   652  				connChan := make(chan *MockQUICConn, 1)
   653  				var wg sync.WaitGroup
   654  				wg.Add(2 * limit)
   655  				serv.newConn = func(
   656  					_ sendConn,
   657  					runner connRunner,
   658  					_ protocol.ConnectionID,
   659  					_ *protocol.ConnectionID,
   660  					_ protocol.ConnectionID,
   661  					_ protocol.ConnectionID,
   662  					_ protocol.ConnectionID,
   663  					_ ConnectionIDGenerator,
   664  					_ protocol.StatelessResetToken,
   665  					_ *Config,
   666  					_ *tls.Config,
   667  					_ *handshake.TokenGenerator,
   668  					_ bool,
   669  					_ *logging.ConnectionTracer,
   670  					_ uint64,
   671  					_ utils.Logger,
   672  					_ protocol.Version,
   673  				) quicConn {
   674  					conn := <-connChan
   675  					conn.EXPECT().handlePacket(gomock.Any())
   676  					conn.EXPECT().run()
   677  					conn.EXPECT().Context().Return(context.Background())
   678  					conn.EXPECT().HandshakeComplete().Return(handshakeChan).Do(func() <-chan struct{} { wg.Done(); return nil })
   679  					return conn
   680  				}
   681  
   682  				// Initiate the maximum number of allowed connection attempts.
   683  				for i := 0; i < limit; i++ {
   684  					conn := NewMockQUICConn(mockCtrl)
   685  					connChan <- conn
   686  					serv.handlePacket(getInitialWithRandomDestConnID())
   687  				}
   688  
   689  				// Now initiate another connection attempt.
   690  				p := getInitialWithRandomDestConnID()
   691  				done := make(chan struct{})
   692  				tracer.EXPECT().SentPacket(p.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) {
   693  					defer GinkgoRecover()
   694  					Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry))
   695  				})
   696  				conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
   697  					defer GinkgoRecover()
   698  					defer close(done)
   699  					hdr, _, _, err := wire.ParsePacket(b)
   700  					Expect(err).ToNot(HaveOccurred())
   701  					Expect(hdr.Type).To(Equal(protocol.PacketTypeRetry))
   702  					return len(b), nil
   703  				})
   704  				serv.handlePacket(p)
   705  				Eventually(done).Should(BeClosed())
   706  
   707  				close(handshakeChan)
   708  				for i := 0; i < limit; i++ {
   709  					_, err := serv.Accept(context.Background())
   710  					Expect(err).ToNot(HaveOccurred())
   711  				}
   712  				for i := 0; i < limit; i++ {
   713  					conn := NewMockQUICConn(mockCtrl)
   714  					conn.EXPECT().closeWithTransportError(gomock.Any()).MaxTimes(1) // called when the server is closed
   715  					connChan <- conn
   716  					serv.handlePacket(getInitialWithRandomDestConnID())
   717  				}
   718  				wg.Wait()
   719  			})
   720  
   721  			It("limits the number of total handshakes", func() {
   722  				const limit = 3
   723  				serv.maxNumHandshakesTotal = limit
   724  				serv.maxNumHandshakesUnvalidated = limit // same limit, but we check that we send CONNECTION_REFUSED and not Retry
   725  
   726  				phm.EXPECT().Get(gomock.Any()).AnyTimes()
   727  				phm.EXPECT().GetStatelessResetToken(gomock.Any()).AnyTimes()
   728  				phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true).AnyTimes()
   729  
   730  				handshakeChan := make(chan struct{})
   731  				connChan := make(chan *MockQUICConn, 1)
   732  				serv.newConn = func(
   733  					_ sendConn,
   734  					runner connRunner,
   735  					_ protocol.ConnectionID,
   736  					_ *protocol.ConnectionID,
   737  					_ protocol.ConnectionID,
   738  					_ protocol.ConnectionID,
   739  					_ protocol.ConnectionID,
   740  					_ ConnectionIDGenerator,
   741  					_ protocol.StatelessResetToken,
   742  					_ *Config,
   743  					_ *tls.Config,
   744  					_ *handshake.TokenGenerator,
   745  					_ bool,
   746  					_ *logging.ConnectionTracer,
   747  					_ uint64,
   748  					_ utils.Logger,
   749  					_ protocol.Version,
   750  				) quicConn {
   751  					conn := <-connChan
   752  					conn.EXPECT().handlePacket(gomock.Any())
   753  					conn.EXPECT().run()
   754  					conn.EXPECT().Context().Return(context.Background())
   755  					conn.EXPECT().HandshakeComplete().Return(handshakeChan)
   756  					return conn
   757  				}
   758  
   759  				for i := 0; i < limit; i++ {
   760  					conn := NewMockQUICConn(mockCtrl)
   761  					connChan <- conn
   762  					serv.handlePacket(getInitialWithRandomDestConnID())
   763  				}
   764  
   765  				p := getInitialWithRandomDestConnID()
   766  				done := make(chan struct{})
   767  				tracer.EXPECT().SentPacket(p.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) {
   768  					defer GinkgoRecover()
   769  					hdr, _, _, err := wire.ParsePacket(p.data)
   770  					Expect(err).ToNot(HaveOccurred())
   771  					Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial))
   772  					Expect(replyHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID))
   773  					Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))
   774  					Expect(frames).To(HaveLen(1))
   775  					Expect(frames[0]).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{}))
   776  					ccf := frames[0].(*logging.ConnectionCloseFrame)
   777  					Expect(ccf.IsApplicationError).To(BeFalse())
   778  					Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.ConnectionRefused))
   779  				})
   780  				conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
   781  					defer GinkgoRecover()
   782  					defer close(done)
   783  					hdr, _, _, err := wire.ParsePacket(p.data)
   784  					Expect(err).ToNot(HaveOccurred())
   785  					checkConnectionCloseError(b, hdr, qerr.ConnectionRefused)
   786  					return len(b), nil
   787  				})
   788  				serv.handlePacket(p)
   789  				Eventually(done).Should(BeClosed())
   790  
   791  				close(handshakeChan)
   792  				for i := 0; i < limit; i++ {
   793  					_, err := serv.Accept(context.Background())
   794  					Expect(err).ToNot(HaveOccurred())
   795  				}
   796  				// make sure we can enqueue and accept more connections after that
   797  				for i := 0; i < limit; i++ {
   798  					conn := NewMockQUICConn(mockCtrl)
   799  					conn.EXPECT().closeWithTransportError(gomock.Any()).MaxTimes(1) // called when the server is closed
   800  					connChan <- conn
   801  					serv.handlePacket(getInitialWithRandomDestConnID())
   802  				}
   803  				for i := 0; i < limit; i++ {
   804  					_, err := serv.Accept(context.Background())
   805  					Expect(err).ToNot(HaveOccurred())
   806  				}
   807  			})
   808  		})
   809  
   810  		Context("token validation", func() {
   811  			It("decodes the token from the token field", func() {
   812  				raddr := &net.UDPAddr{IP: net.IPv4(192, 168, 13, 37), Port: 1337}
   813  				token, err := serv.tokenGenerator.NewRetryToken(raddr, protocol.ConnectionID{}, protocol.ConnectionID{})
   814  				Expect(err).ToNot(HaveOccurred())
   815  				packet := getPacket(&wire.Header{
   816  					Type:    protocol.PacketTypeInitial,
   817  					Token:   token,
   818  					Version: serv.config.Versions[0],
   819  				}, make([]byte, protocol.MinInitialPacketSize))
   820  				packet.remoteAddr = raddr
   821  				conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).MaxTimes(1)
   822  				tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1)
   823  
   824  				done := make(chan struct{})
   825  				phm.EXPECT().Get(gomock.Any())
   826  				phm.EXPECT().GetStatelessResetToken(gomock.Any())
   827  				phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, _ packetHandler) bool {
   828  					close(done)
   829  					return true
   830  				})
   831  				phm.EXPECT().Remove(gomock.Any()).AnyTimes()
   832  				serv.handlePacket(packet)
   833  				Eventually(done).Should(BeClosed())
   834  			})
   835  
   836  			It("sends an INVALID_TOKEN error, if an invalid retry token is received", func() {
   837  				serv.maxNumHandshakesUnvalidated = 0
   838  				token, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, protocol.ConnectionID{}, protocol.ConnectionID{})
   839  				Expect(err).ToNot(HaveOccurred())
   840  				hdr := &wire.Header{
   841  					Type:             protocol.PacketTypeInitial,
   842  					SrcConnectionID:  protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
   843  					DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}),
   844  					Token:            token,
   845  					Version:          protocol.Version1,
   846  				}
   847  				packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
   848  				packet.data = append(packet.data, []byte("coalesced packet")...) // add some garbage to simulate a coalesced packet
   849  				raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
   850  				packet.remoteAddr = raddr
   851  				tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) {
   852  					Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial))
   853  					Expect(replyHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID))
   854  					Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))
   855  					Expect(frames).To(HaveLen(1))
   856  					Expect(frames[0]).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{}))
   857  					ccf := frames[0].(*logging.ConnectionCloseFrame)
   858  					Expect(ccf.IsApplicationError).To(BeFalse())
   859  					Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.InvalidToken))
   860  				})
   861  				done := make(chan struct{})
   862  				conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
   863  					defer close(done)
   864  					checkConnectionCloseError(b, hdr, qerr.InvalidToken)
   865  					return len(b), nil
   866  				})
   867  				phm.EXPECT().Get(gomock.Any())
   868  				serv.handlePacket(packet)
   869  				Eventually(done).Should(BeClosed())
   870  			})
   871  
   872  			It("sends an INVALID_TOKEN error, if an expired retry token is received", func() {
   873  				serv.maxNumHandshakesUnvalidated = 0
   874  				serv.config.HandshakeIdleTimeout = time.Millisecond / 2 // the maximum retry token age is equivalent to the handshake timeout
   875  				Expect(serv.config.maxRetryTokenAge()).To(Equal(time.Millisecond))
   876  				raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
   877  				token, err := serv.tokenGenerator.NewRetryToken(raddr, protocol.ConnectionID{}, protocol.ConnectionID{})
   878  				Expect(err).ToNot(HaveOccurred())
   879  				time.Sleep(2 * time.Millisecond) // make sure the token is expired
   880  				hdr := &wire.Header{
   881  					Type:             protocol.PacketTypeInitial,
   882  					SrcConnectionID:  protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
   883  					DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}),
   884  					Token:            token,
   885  					Version:          protocol.Version1,
   886  				}
   887  				packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
   888  				packet.remoteAddr = raddr
   889  				tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) {
   890  					Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial))
   891  					Expect(replyHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID))
   892  					Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))
   893  					Expect(frames).To(HaveLen(1))
   894  					Expect(frames[0]).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{}))
   895  					ccf := frames[0].(*logging.ConnectionCloseFrame)
   896  					Expect(ccf.IsApplicationError).To(BeFalse())
   897  					Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.InvalidToken))
   898  				})
   899  				done := make(chan struct{})
   900  				conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
   901  					defer close(done)
   902  					checkConnectionCloseError(b, hdr, qerr.InvalidToken)
   903  					return len(b), nil
   904  				})
   905  				phm.EXPECT().Get(gomock.Any())
   906  				serv.handlePacket(packet)
   907  				Eventually(done).Should(BeClosed())
   908  			})
   909  
   910  			It("doesn't send an INVALID_TOKEN error, if an invalid non-retry token is received", func() {
   911  				serv.maxNumHandshakesUnvalidated = 0
   912  				token, err := serv.tokenGenerator.NewToken(&net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337})
   913  				Expect(err).ToNot(HaveOccurred())
   914  				hdr := &wire.Header{
   915  					Type:             protocol.PacketTypeInitial,
   916  					SrcConnectionID:  protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
   917  					DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}),
   918  					Token:            token,
   919  					Version:          protocol.Version1,
   920  				}
   921  				packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
   922  				packet.data[len(packet.data)-10] ^= 0xff // corrupt the packet
   923  				raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
   924  				packet.remoteAddr = raddr
   925  				tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1)
   926  				done := make(chan struct{})
   927  				conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
   928  					defer close(done)
   929  					replyHdr := parseHeader(b)
   930  					Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry))
   931  					return len(b), nil
   932  				})
   933  				phm.EXPECT().Get(gomock.Any())
   934  				serv.handlePacket(packet)
   935  				// make sure there are no Write calls on the packet conn
   936  				Eventually(done).Should(BeClosed())
   937  			})
   938  
   939  			It("sends an INVALID_TOKEN error, if an expired non-retry token is received", func() {
   940  				serv.maxNumHandshakesUnvalidated = 0
   941  				serv.maxTokenAge = time.Millisecond
   942  				raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
   943  				token, err := serv.tokenGenerator.NewToken(raddr)
   944  				Expect(err).ToNot(HaveOccurred())
   945  				time.Sleep(2 * time.Millisecond) // make sure the token is expired
   946  				hdr := &wire.Header{
   947  					Type:             protocol.PacketTypeInitial,
   948  					SrcConnectionID:  protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
   949  					DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}),
   950  					Token:            token,
   951  					Version:          protocol.Version1,
   952  				}
   953  				packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
   954  				packet.remoteAddr = raddr
   955  				tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) {
   956  					Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry))
   957  				})
   958  				done := make(chan struct{})
   959  				conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
   960  					defer close(done)
   961  					return len(b), nil
   962  				})
   963  				phm.EXPECT().Get(gomock.Any())
   964  				serv.handlePacket(packet)
   965  				Eventually(done).Should(BeClosed())
   966  			})
   967  
   968  			It("doesn't send an INVALID_TOKEN error, if the packet is corrupted", func() {
   969  				serv.maxNumHandshakesUnvalidated = 0
   970  				token, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, protocol.ConnectionID{}, protocol.ConnectionID{})
   971  				Expect(err).ToNot(HaveOccurred())
   972  				hdr := &wire.Header{
   973  					Type:             protocol.PacketTypeInitial,
   974  					SrcConnectionID:  protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
   975  					DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}),
   976  					Token:            token,
   977  					Version:          protocol.Version1,
   978  				}
   979  				packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
   980  				packet.data[len(packet.data)-10] ^= 0xff // corrupt the packet
   981  				packet.remoteAddr = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
   982  				done := make(chan struct{})
   983  				tracer.EXPECT().DroppedPacket(packet.remoteAddr, logging.PacketTypeInitial, packet.Size(), logging.PacketDropPayloadDecryptError).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { close(done) })
   984  				phm.EXPECT().Get(gomock.Any())
   985  				serv.handlePacket(packet)
   986  				// make sure there are no Write calls on the packet conn
   987  				time.Sleep(50 * time.Millisecond)
   988  				Eventually(done).Should(BeClosed())
   989  			})
   990  		})
   991  
   992  		Context("accepting connections", func() {
   993  			It("returns Accept when closed", func() {
   994  				done := make(chan struct{})
   995  				go func() {
   996  					defer GinkgoRecover()
   997  					_, err := serv.Accept(context.Background())
   998  					Expect(err).To(MatchError(ErrServerClosed))
   999  					close(done)
  1000  				}()
  1001  
  1002  				serv.Close()
  1003  				Eventually(done).Should(BeClosed())
  1004  			})
  1005  
  1006  			It("returns immediately, if an error occurred before", func() {
  1007  				serv.Close()
  1008  				for i := 0; i < 3; i++ {
  1009  					_, err := serv.Accept(context.Background())
  1010  					Expect(err).To(MatchError(ErrServerClosed))
  1011  				}
  1012  			})
  1013  
  1014  			It("closes connection that are still handshaking after Close", func() {
  1015  				serv.Close()
  1016  
  1017  				destroyed := make(chan struct{})
  1018  				serv.newConn = func(
  1019  					_ sendConn,
  1020  					_ connRunner,
  1021  					_ protocol.ConnectionID,
  1022  					_ *protocol.ConnectionID,
  1023  					_ protocol.ConnectionID,
  1024  					_ protocol.ConnectionID,
  1025  					_ protocol.ConnectionID,
  1026  					_ ConnectionIDGenerator,
  1027  					_ protocol.StatelessResetToken,
  1028  					conf *Config,
  1029  					_ *tls.Config,
  1030  					_ *handshake.TokenGenerator,
  1031  					_ bool,
  1032  					_ *logging.ConnectionTracer,
  1033  					_ uint64,
  1034  					_ utils.Logger,
  1035  					_ protocol.Version,
  1036  				) quicConn {
  1037  					conn := NewMockQUICConn(mockCtrl)
  1038  					conn.EXPECT().handlePacket(gomock.Any())
  1039  					conn.EXPECT().closeWithTransportError(ConnectionRefused).Do(func(TransportErrorCode) { close(destroyed) })
  1040  					conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
  1041  					conn.EXPECT().run().MaxTimes(1)
  1042  					conn.EXPECT().Context().Return(context.Background())
  1043  					return conn
  1044  				}
  1045  				phm.EXPECT().Get(gomock.Any())
  1046  				phm.EXPECT().GetStatelessResetToken(gomock.Any())
  1047  				phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true)
  1048  				serv.handleInitialImpl(
  1049  					receivedPacket{buffer: getPacketBuffer()},
  1050  					&wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})},
  1051  				)
  1052  				Eventually(destroyed).Should(BeClosed())
  1053  			})
  1054  
  1055  			It("returns when the context is canceled", func() {
  1056  				ctx, cancel := context.WithCancel(context.Background())
  1057  				done := make(chan struct{})
  1058  				go func() {
  1059  					defer GinkgoRecover()
  1060  					_, err := serv.Accept(ctx)
  1061  					Expect(err).To(MatchError("context canceled"))
  1062  					close(done)
  1063  				}()
  1064  
  1065  				Consistently(done).ShouldNot(BeClosed())
  1066  				cancel()
  1067  				Eventually(done).Should(BeClosed())
  1068  			})
  1069  
  1070  			It("uses the config returned by GetConfigClient", func() {
  1071  				conn := NewMockQUICConn(mockCtrl)
  1072  
  1073  				conf := &Config{MaxIncomingStreams: 1234}
  1074  				serv.config = populateConfig(&Config{GetConfigForClient: func(*ClientHelloInfo) (*Config, error) { return conf, nil }})
  1075  				done := make(chan struct{})
  1076  				go func() {
  1077  					defer GinkgoRecover()
  1078  					s, err := serv.Accept(context.Background())
  1079  					Expect(err).ToNot(HaveOccurred())
  1080  					Expect(s).To(Equal(conn))
  1081  					close(done)
  1082  				}()
  1083  
  1084  				handshakeChan := make(chan struct{})
  1085  				serv.newConn = func(
  1086  					_ sendConn,
  1087  					_ connRunner,
  1088  					_ protocol.ConnectionID,
  1089  					_ *protocol.ConnectionID,
  1090  					_ protocol.ConnectionID,
  1091  					_ protocol.ConnectionID,
  1092  					_ protocol.ConnectionID,
  1093  					_ ConnectionIDGenerator,
  1094  					_ protocol.StatelessResetToken,
  1095  					conf *Config,
  1096  					_ *tls.Config,
  1097  					_ *handshake.TokenGenerator,
  1098  					_ bool,
  1099  					_ *logging.ConnectionTracer,
  1100  					_ uint64,
  1101  					_ utils.Logger,
  1102  					_ protocol.Version,
  1103  				) quicConn {
  1104  					Expect(conf.MaxIncomingStreams).To(BeEquivalentTo(1234))
  1105  					conn.EXPECT().handlePacket(gomock.Any())
  1106  					conn.EXPECT().HandshakeComplete().Return(handshakeChan)
  1107  					conn.EXPECT().run()
  1108  					conn.EXPECT().Context().Return(context.Background())
  1109  					return conn
  1110  				}
  1111  				phm.EXPECT().Get(gomock.Any())
  1112  				phm.EXPECT().GetStatelessResetToken(gomock.Any())
  1113  				phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true)
  1114  				serv.handleInitialImpl(
  1115  					receivedPacket{buffer: getPacketBuffer()},
  1116  					&wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})},
  1117  				)
  1118  				Consistently(done).ShouldNot(BeClosed())
  1119  				close(handshakeChan) // complete the handshake
  1120  				Eventually(done).Should(BeClosed())
  1121  			})
  1122  
  1123  			It("rejects a connection attempt when GetConfigClient returns an error", func() {
  1124  				serv.config = populateConfig(&Config{GetConfigForClient: func(*ClientHelloInfo) (*Config, error) { return nil, errors.New("rejected") }})
  1125  
  1126  				phm.EXPECT().Get(gomock.Any())
  1127  				done := make(chan struct{})
  1128  				tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
  1129  				conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
  1130  					defer close(done)
  1131  					rejectHdr := parseHeader(b)
  1132  					Expect(rejectHdr.Type).To(Equal(protocol.PacketTypeInitial))
  1133  					return len(b), nil
  1134  				})
  1135  				serv.handleInitialImpl(
  1136  					receivedPacket{buffer: getPacketBuffer()},
  1137  					&wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), Version: protocol.Version1},
  1138  				)
  1139  				Eventually(done).Should(BeClosed())
  1140  			})
  1141  
  1142  			It("accepts new connections when the handshake completes", func() {
  1143  				conn := NewMockQUICConn(mockCtrl)
  1144  
  1145  				done := make(chan struct{})
  1146  				go func() {
  1147  					defer GinkgoRecover()
  1148  					s, err := serv.Accept(context.Background())
  1149  					Expect(err).ToNot(HaveOccurred())
  1150  					Expect(s).To(Equal(conn))
  1151  					close(done)
  1152  				}()
  1153  
  1154  				handshakeChan := make(chan struct{})
  1155  				serv.newConn = func(
  1156  					_ sendConn,
  1157  					runner connRunner,
  1158  					_ protocol.ConnectionID,
  1159  					_ *protocol.ConnectionID,
  1160  					_ protocol.ConnectionID,
  1161  					_ protocol.ConnectionID,
  1162  					_ protocol.ConnectionID,
  1163  					_ ConnectionIDGenerator,
  1164  					_ protocol.StatelessResetToken,
  1165  					_ *Config,
  1166  					_ *tls.Config,
  1167  					_ *handshake.TokenGenerator,
  1168  					_ bool,
  1169  					_ *logging.ConnectionTracer,
  1170  					_ uint64,
  1171  					_ utils.Logger,
  1172  					_ protocol.Version,
  1173  				) quicConn {
  1174  					conn.EXPECT().handlePacket(gomock.Any())
  1175  					conn.EXPECT().HandshakeComplete().Return(handshakeChan)
  1176  					conn.EXPECT().run()
  1177  					conn.EXPECT().Context().Return(context.Background())
  1178  					return conn
  1179  				}
  1180  				phm.EXPECT().Get(gomock.Any())
  1181  				phm.EXPECT().GetStatelessResetToken(gomock.Any())
  1182  				phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true)
  1183  				serv.handleInitialImpl(
  1184  					receivedPacket{buffer: getPacketBuffer()},
  1185  					&wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})},
  1186  				)
  1187  				Consistently(done).ShouldNot(BeClosed())
  1188  				close(handshakeChan) // complete the handshake
  1189  				Eventually(done).Should(BeClosed())
  1190  			})
  1191  		})
  1192  	})
  1193  
  1194  	Context("server accepting connections that haven't completed the handshake", func() {
  1195  		var (
  1196  			serv *EarlyListener
  1197  			phm  *MockPacketHandlerManager
  1198  		)
  1199  
  1200  		BeforeEach(func() {
  1201  			var err error
  1202  			serv, err = ListenEarly(conn, tlsConf, nil)
  1203  			Expect(err).ToNot(HaveOccurred())
  1204  			phm = NewMockPacketHandlerManager(mockCtrl)
  1205  			serv.baseServer.connHandler = phm
  1206  		})
  1207  
  1208  		AfterEach(func() {
  1209  			serv.Close()
  1210  		})
  1211  
  1212  		It("accepts new connections when they become ready", func() {
  1213  			conn := NewMockQUICConn(mockCtrl)
  1214  
  1215  			done := make(chan struct{})
  1216  			go func() {
  1217  				defer GinkgoRecover()
  1218  				s, err := serv.Accept(context.Background())
  1219  				Expect(err).ToNot(HaveOccurred())
  1220  				Expect(s).To(Equal(conn))
  1221  				close(done)
  1222  			}()
  1223  
  1224  			ready := make(chan struct{})
  1225  			serv.baseServer.newConn = func(
  1226  				_ sendConn,
  1227  				runner connRunner,
  1228  				_ protocol.ConnectionID,
  1229  				_ *protocol.ConnectionID,
  1230  				_ protocol.ConnectionID,
  1231  				_ protocol.ConnectionID,
  1232  				_ protocol.ConnectionID,
  1233  				_ ConnectionIDGenerator,
  1234  				_ protocol.StatelessResetToken,
  1235  				_ *Config,
  1236  				_ *tls.Config,
  1237  				_ *handshake.TokenGenerator,
  1238  				_ bool,
  1239  				_ *logging.ConnectionTracer,
  1240  				_ uint64,
  1241  				_ utils.Logger,
  1242  				_ protocol.Version,
  1243  			) quicConn {
  1244  				conn.EXPECT().handlePacket(gomock.Any())
  1245  				conn.EXPECT().run()
  1246  				conn.EXPECT().earlyConnReady().Return(ready)
  1247  				conn.EXPECT().Context().Return(context.Background())
  1248  				return conn
  1249  			}
  1250  			phm.EXPECT().Get(gomock.Any())
  1251  			phm.EXPECT().GetStatelessResetToken(gomock.Any())
  1252  			phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true)
  1253  			serv.baseServer.handleInitialImpl(
  1254  				receivedPacket{buffer: getPacketBuffer()},
  1255  				&wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})},
  1256  			)
  1257  			Consistently(done).ShouldNot(BeClosed())
  1258  			close(ready)
  1259  			Eventually(done).Should(BeClosed())
  1260  		})
  1261  
  1262  		It("rejects new connection attempts if the accept queue is full", func() {
  1263  			connChan := make(chan *MockQUICConn, 1)
  1264  			var wg sync.WaitGroup // to make sure the test fully completes
  1265  			wg.Add(protocol.MaxAcceptQueueSize + 1)
  1266  			serv.baseServer.newConn = func(
  1267  				_ sendConn,
  1268  				runner connRunner,
  1269  				_ protocol.ConnectionID,
  1270  				_ *protocol.ConnectionID,
  1271  				_ protocol.ConnectionID,
  1272  				_ protocol.ConnectionID,
  1273  				_ protocol.ConnectionID,
  1274  				_ ConnectionIDGenerator,
  1275  				_ protocol.StatelessResetToken,
  1276  				_ *Config,
  1277  				_ *tls.Config,
  1278  				_ *handshake.TokenGenerator,
  1279  				_ bool,
  1280  				_ *logging.ConnectionTracer,
  1281  				_ uint64,
  1282  				_ utils.Logger,
  1283  				_ protocol.Version,
  1284  			) quicConn {
  1285  				defer wg.Done()
  1286  				ready := make(chan struct{})
  1287  				close(ready)
  1288  				conn := <-connChan
  1289  				conn.EXPECT().handlePacket(gomock.Any())
  1290  				conn.EXPECT().run()
  1291  				conn.EXPECT().earlyConnReady().Return(ready)
  1292  				conn.EXPECT().Context().Return(context.Background())
  1293  				return conn
  1294  			}
  1295  
  1296  			phm.EXPECT().Get(gomock.Any()).AnyTimes()
  1297  			phm.EXPECT().GetStatelessResetToken(gomock.Any()).Times(protocol.MaxAcceptQueueSize)
  1298  			phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true).Times(protocol.MaxAcceptQueueSize)
  1299  			for i := 0; i < protocol.MaxAcceptQueueSize; i++ {
  1300  				conn := NewMockQUICConn(mockCtrl)
  1301  				connChan <- conn
  1302  				serv.baseServer.handlePacket(getInitialWithRandomDestConnID())
  1303  			}
  1304  
  1305  			Eventually(serv.baseServer.connQueue).Should(HaveLen(protocol.MaxAcceptQueueSize))
  1306  
  1307  			phm.EXPECT().GetStatelessResetToken(gomock.Any())
  1308  			phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true)
  1309  			conn := NewMockQUICConn(mockCtrl)
  1310  			conn.EXPECT().closeWithTransportError(ConnectionRefused)
  1311  			connChan <- conn
  1312  			serv.baseServer.handlePacket(getInitialWithRandomDestConnID())
  1313  			wg.Wait()
  1314  		})
  1315  
  1316  		It("doesn't accept new connections if they were closed in the mean time", func() {
  1317  			p := getInitial(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}))
  1318  			ctx, cancel := context.WithCancel(context.Background())
  1319  			connCreated := make(chan struct{})
  1320  			conn := NewMockQUICConn(mockCtrl)
  1321  			serv.baseServer.newConn = func(
  1322  				_ sendConn,
  1323  				runner connRunner,
  1324  				_ protocol.ConnectionID,
  1325  				_ *protocol.ConnectionID,
  1326  				_ protocol.ConnectionID,
  1327  				_ protocol.ConnectionID,
  1328  				_ protocol.ConnectionID,
  1329  				_ ConnectionIDGenerator,
  1330  				_ protocol.StatelessResetToken,
  1331  				_ *Config,
  1332  				_ *tls.Config,
  1333  				_ *handshake.TokenGenerator,
  1334  				_ bool,
  1335  				_ *logging.ConnectionTracer,
  1336  				_ uint64,
  1337  				_ utils.Logger,
  1338  				_ protocol.Version,
  1339  			) quicConn {
  1340  				conn.EXPECT().handlePacket(p)
  1341  				conn.EXPECT().run()
  1342  				conn.EXPECT().earlyConnReady()
  1343  				conn.EXPECT().Context().Return(ctx)
  1344  				close(connCreated)
  1345  				return conn
  1346  			}
  1347  
  1348  			phm.EXPECT().Get(gomock.Any())
  1349  			phm.EXPECT().GetStatelessResetToken(gomock.Any())
  1350  			phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true)
  1351  			serv.baseServer.handlePacket(p)
  1352  			// make sure there are no Write calls on the packet conn
  1353  			time.Sleep(50 * time.Millisecond)
  1354  			Eventually(connCreated).Should(BeClosed())
  1355  			cancel()
  1356  			time.Sleep(scaleDuration(200 * time.Millisecond))
  1357  
  1358  			done := make(chan struct{})
  1359  			go func() {
  1360  				defer GinkgoRecover()
  1361  				serv.Accept(context.Background())
  1362  				close(done)
  1363  			}()
  1364  			Consistently(done).ShouldNot(BeClosed())
  1365  
  1366  			// make the go routine return
  1367  			Expect(serv.Close()).To(Succeed())
  1368  			Eventually(done).Should(BeClosed())
  1369  		})
  1370  	})
  1371  
  1372  	Context("0-RTT", func() {
  1373  		var (
  1374  			tr     *Transport
  1375  			serv   *baseServer
  1376  			phm    *MockPacketHandlerManager
  1377  			tracer *mocklogging.MockTracer
  1378  		)
  1379  
  1380  		BeforeEach(func() {
  1381  			var t *logging.Tracer
  1382  			t, tracer = mocklogging.NewMockTracer(mockCtrl)
  1383  			tr = &Transport{Conn: conn, Tracer: t}
  1384  			ln, err := tr.ListenEarly(tlsConf, nil)
  1385  			Expect(err).ToNot(HaveOccurred())
  1386  			phm = NewMockPacketHandlerManager(mockCtrl)
  1387  			serv = ln.baseServer
  1388  			serv.connHandler = phm
  1389  		})
  1390  
  1391  		AfterEach(func() {
  1392  			tracer.EXPECT().Close()
  1393  			Expect(tr.Close()).To(Succeed())
  1394  		})
  1395  
  1396  		It("passes packets to existing connections", func() {
  1397  			connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
  1398  			p := getPacket(&wire.Header{
  1399  				Type:             protocol.PacketType0RTT,
  1400  				DestConnectionID: connID,
  1401  				Version:          serv.config.Versions[0],
  1402  			}, make([]byte, 100))
  1403  			conn := NewMockPacketHandler(mockCtrl)
  1404  			phm.EXPECT().Get(connID).Return(conn, true)
  1405  			handled := make(chan struct{})
  1406  			conn.EXPECT().handlePacket(p).Do(func(receivedPacket) { close(handled) })
  1407  			serv.handlePacket(p)
  1408  			Eventually(handled).Should(BeClosed())
  1409  		})
  1410  
  1411  		It("queues 0-RTT packets, up to Max0RTTQueueSize", func() {
  1412  			connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
  1413  
  1414  			var zeroRTTPackets []receivedPacket
  1415  
  1416  			for i := 0; i < protocol.Max0RTTQueueLen; i++ {
  1417  				p := getPacket(&wire.Header{
  1418  					Type:             protocol.PacketType0RTT,
  1419  					DestConnectionID: connID,
  1420  					Version:          serv.config.Versions[0],
  1421  				}, make([]byte, 100+i))
  1422  				phm.EXPECT().Get(connID)
  1423  				serv.handlePacket(p)
  1424  				zeroRTTPackets = append(zeroRTTPackets, p)
  1425  			}
  1426  
  1427  			// send one more packet, this one should be dropped
  1428  			p := getPacket(&wire.Header{
  1429  				Type:             protocol.PacketType0RTT,
  1430  				DestConnectionID: connID,
  1431  				Version:          serv.config.Versions[0],
  1432  			}, make([]byte, 200))
  1433  			phm.EXPECT().Get(connID)
  1434  			tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention)
  1435  			serv.handlePacket(p)
  1436  
  1437  			initial := getPacket(&wire.Header{
  1438  				Type:             protocol.PacketTypeInitial,
  1439  				DestConnectionID: connID,
  1440  				Version:          serv.config.Versions[0],
  1441  			}, make([]byte, protocol.MinInitialPacketSize))
  1442  			called := make(chan struct{})
  1443  			serv.newConn = func(
  1444  				_ sendConn,
  1445  				_ connRunner,
  1446  				_ protocol.ConnectionID,
  1447  				_ *protocol.ConnectionID,
  1448  				_ protocol.ConnectionID,
  1449  				_ protocol.ConnectionID,
  1450  				_ protocol.ConnectionID,
  1451  				_ ConnectionIDGenerator,
  1452  				_ protocol.StatelessResetToken,
  1453  				_ *Config,
  1454  				_ *tls.Config,
  1455  				_ *handshake.TokenGenerator,
  1456  				_ bool,
  1457  				_ *logging.ConnectionTracer,
  1458  				_ uint64,
  1459  				_ utils.Logger,
  1460  				_ protocol.Version,
  1461  			) quicConn {
  1462  				conn := NewMockQUICConn(mockCtrl)
  1463  				var calls []any
  1464  				calls = append(calls, conn.EXPECT().handlePacket(initial))
  1465  				for _, p := range zeroRTTPackets {
  1466  					calls = append(calls, conn.EXPECT().handlePacket(p))
  1467  				}
  1468  				gomock.InOrder(calls...)
  1469  				conn.EXPECT().run()
  1470  				conn.EXPECT().earlyConnReady()
  1471  				conn.EXPECT().Context().Return(context.Background())
  1472  				close(called)
  1473  				// shutdown
  1474  				conn.EXPECT().closeWithTransportError(gomock.Any())
  1475  				return conn
  1476  			}
  1477  
  1478  			phm.EXPECT().Get(connID)
  1479  			phm.EXPECT().GetStatelessResetToken(gomock.Any())
  1480  			phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true)
  1481  			serv.handlePacket(initial)
  1482  			Eventually(called).Should(BeClosed())
  1483  		})
  1484  
  1485  		It("limits the number of queues", func() {
  1486  			for i := 0; i < protocol.Max0RTTQueues; i++ {
  1487  				b := make([]byte, 16)
  1488  				rand.Read(b)
  1489  				connID := protocol.ParseConnectionID(b)
  1490  				p := getPacket(&wire.Header{
  1491  					Type:             protocol.PacketType0RTT,
  1492  					DestConnectionID: connID,
  1493  					Version:          serv.config.Versions[0],
  1494  				}, make([]byte, 100+i))
  1495  				phm.EXPECT().Get(connID)
  1496  				serv.handlePacket(p)
  1497  			}
  1498  
  1499  			connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
  1500  			p := getPacket(&wire.Header{
  1501  				Type:             protocol.PacketType0RTT,
  1502  				DestConnectionID: connID,
  1503  				Version:          serv.config.Versions[0],
  1504  			}, make([]byte, 200))
  1505  			phm.EXPECT().Get(connID)
  1506  			dropped := make(chan struct{})
  1507  			tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) {
  1508  				close(dropped)
  1509  			})
  1510  			serv.handlePacket(p)
  1511  			Eventually(dropped).Should(BeClosed())
  1512  		})
  1513  
  1514  		It("drops queues after a while", func() {
  1515  			now := time.Now()
  1516  
  1517  			connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
  1518  			p := getPacket(&wire.Header{
  1519  				Type:             protocol.PacketType0RTT,
  1520  				DestConnectionID: connID,
  1521  				Version:          serv.config.Versions[0],
  1522  			}, make([]byte, 200))
  1523  			p.rcvTime = now
  1524  
  1525  			connID2 := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 9})
  1526  			p2Time := now.Add(protocol.Max0RTTQueueingDuration / 2)
  1527  			p2 := getPacket(&wire.Header{
  1528  				Type:             protocol.PacketType0RTT,
  1529  				DestConnectionID: connID2,
  1530  				Version:          serv.config.Versions[0],
  1531  			}, make([]byte, 300))
  1532  			p2.rcvTime = p2Time // doesn't trigger the cleanup of the first packet
  1533  
  1534  			dropped1 := make(chan struct{})
  1535  			dropped2 := make(chan struct{})
  1536  			// need to register the call before handling the packet to avoid race condition
  1537  			gomock.InOrder(
  1538  				tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) {
  1539  					close(dropped1)
  1540  				}),
  1541  				tracer.EXPECT().DroppedPacket(p2.remoteAddr, logging.PacketType0RTT, p2.Size(), logging.PacketDropDOSPrevention).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) {
  1542  					close(dropped2)
  1543  				}),
  1544  			)
  1545  
  1546  			phm.EXPECT().Get(connID)
  1547  			serv.handlePacket(p)
  1548  
  1549  			// There's no cleanup Go routine.
  1550  			// Cleanup is triggered when new packets are received.
  1551  
  1552  			phm.EXPECT().Get(connID2)
  1553  			serv.handlePacket(p2)
  1554  			// make sure no cleanup is executed
  1555  			Consistently(dropped1, 50*time.Millisecond).ShouldNot(BeClosed())
  1556  
  1557  			// There's no cleanup Go routine.
  1558  			// Cleanup is triggered when new packets are received.
  1559  			connID3 := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 0})
  1560  			p3 := getPacket(&wire.Header{
  1561  				Type:             protocol.PacketType0RTT,
  1562  				DestConnectionID: connID3,
  1563  				Version:          serv.config.Versions[0],
  1564  			}, make([]byte, 200))
  1565  			p3.rcvTime = now.Add(protocol.Max0RTTQueueingDuration + time.Nanosecond) // now triggers the cleanup
  1566  			phm.EXPECT().Get(connID3)
  1567  			serv.handlePacket(p3)
  1568  			Eventually(dropped1).Should(BeClosed())
  1569  			Consistently(dropped2, 50*time.Millisecond).ShouldNot(BeClosed())
  1570  
  1571  			// make sure the second packet is also cleaned up
  1572  			connID4 := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 1})
  1573  			p4 := getPacket(&wire.Header{
  1574  				Type:             protocol.PacketType0RTT,
  1575  				DestConnectionID: connID4,
  1576  				Version:          serv.config.Versions[0],
  1577  			}, make([]byte, 200))
  1578  			p4.rcvTime = p2Time.Add(protocol.Max0RTTQueueingDuration + time.Nanosecond) // now triggers the cleanup
  1579  			phm.EXPECT().Get(connID4)
  1580  			serv.handlePacket(p4)
  1581  			Eventually(dropped2).Should(BeClosed())
  1582  		})
  1583  	})
  1584  })