github.com/daeuniverse/quic-go@v0.0.0-20240413031024-943f218e0810/server_test.go (about)

     1  package quic
     2  
     3  import (
     4  	"context"
     5  	"crypto/rand"
     6  	"crypto/tls"
     7  	"errors"
     8  	"net"
     9  	"sync"
    10  	"sync/atomic"
    11  	"time"
    12  
    13  	"golang.org/x/time/rate"
    14  
    15  	"github.com/daeuniverse/quic-go/internal/handshake"
    16  	mocklogging "github.com/daeuniverse/quic-go/internal/mocks/logging"
    17  	"github.com/daeuniverse/quic-go/internal/protocol"
    18  	"github.com/daeuniverse/quic-go/internal/qerr"
    19  	"github.com/daeuniverse/quic-go/internal/testdata"
    20  	"github.com/daeuniverse/quic-go/internal/utils"
    21  	"github.com/daeuniverse/quic-go/internal/wire"
    22  	"github.com/daeuniverse/quic-go/logging"
    23  
    24  	. "github.com/onsi/ginkgo/v2"
    25  	. "github.com/onsi/gomega"
    26  	"go.uber.org/mock/gomock"
    27  )
    28  
    29  var _ = Describe("Server", func() {
    30  	var (
    31  		conn    *MockPacketConn
    32  		tlsConf *tls.Config
    33  	)
    34  
        	// getPacket builds a receivedPacket from hdr and the payload p.
        	// It serializes the header with a 4-byte packet number (0x42), appends p,
        	// seals the payload and protects the header using the sealer returned by
        	// handshake.NewInitialAEAD for hdr.DestConnectionID (client perspective).
        	// Side effect: hdr.Length is overwritten.
        	// The returned packet claims to come from 4.5.6.7:456.
    35  	getPacket := func(hdr *wire.Header, p []byte) receivedPacket {
    36  		buf := getPacketBuffer()
        		// 4-byte packet number + payload + the 16 bytes the data grows by after sealing (see below)
    37  		hdr.Length = 4 + protocol.ByteCount(len(p)) + 16
    38  		var err error
    39  		buf.Data, err = (&wire.ExtendedHeader{
    40  			Header:          *hdr,
    41  			PacketNumber:    0x42,
    42  			PacketNumberLen: protocol.PacketNumberLen4,
    43  		}).Append(buf.Data, protocol.Version1)
    44  		Expect(err).ToNot(HaveOccurred())
        		// n marks the end of the serialized header / start of the payload
    45  		n := len(buf.Data)
    46  		buf.Data = append(buf.Data, p...)
    47  		data := buf.Data
    48  		sealer, _ := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveClient, hdr.Version)
        		// seal in place: the destination slice starts at the payload offset, the header is the associated data
    49  		_ = sealer.Seal(data[n:n], data[n:], 0x42, data[:n])
        		// extend the slice by the 16 bytes that Seal wrote past the plaintext
    50  		data = data[:len(data)+16]
        		// the last 4 header bytes are the packet number; the first 16 payload bytes serve as the sample
    51  		sealer.EncryptHeader(data[n:n+16], &data[0], data[n-4:n])
    52  		return receivedPacket{
    53  			rcvTime:    time.Now(),
    54  			remoteAddr: &net.UDPAddr{IP: net.IPv4(4, 5, 6, 7), Port: 456},
    55  			data:       data,
    56  			buffer:     buf,
    57  		}
    58  	}
    59  
        	// getInitial returns a sealed QUIC v1 Initial packet addressed to destConnID,
        	// with source connection ID {5, 4, 3, 2, 1}, a payload of
        	// protocol.MinInitialPacketSize zero bytes and remote address 1.2.3.4:42.
    60  	getInitial := func(destConnID protocol.ConnectionID) receivedPacket {
    61  		senderAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 42}
    62  		hdr := &wire.Header{
    63  			Type:             protocol.PacketTypeInitial,
    64  			SrcConnectionID:  protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
    65  			DestConnectionID: destConnID,
    66  			Version:          protocol.Version1,
    67  		}
    68  		p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
        		// Swap in a fresh packet buffer; p.data still refers to the bytes built by getPacket.
        		// NOTE(review): the buffer obtained inside getPacket is not released here - confirm this is intended.
    69  		p.buffer = getPacketBuffer()
        		// override the default remote address set by getPacket
    70  		p.remoteAddr = senderAddr
    71  		return p
    72  	}
    73  
        	// getInitialWithRandomDestConnID returns an Initial packet (see getInitial)
        	// whose destination connection ID consists of 10 bytes read from crypto/rand.
    74  	getInitialWithRandomDestConnID := func() receivedPacket {
    75  		b := make([]byte, 10)
    76  		_, err := rand.Read(b)
    77  		Expect(err).ToNot(HaveOccurred())
    78  
    79  		return getInitial(protocol.ParseConnectionID(b))
    80  	}
    81  
        	// parseHeader parses data with wire.ParsePacket and returns only the header.
        	// A parse error fails the current spec.
    82  	parseHeader := func(data []byte) *wire.Header {
    83  		hdr, _, _, err := wire.ParsePacket(data)
    84  		Expect(err).ToNot(HaveOccurred())
    85  		return hdr
    86  	}
    87  
        	// checkConnectionCloseError asserts that b is a reply to the packet described
        	// by origHdr that rejects the connection: an Initial packet with the connection
        	// IDs of origHdr swapped, whose first frame is a transport-level (non-application)
        	// CONNECTION_CLOSE frame carrying errorCode and an empty reason phrase.
    88  	checkConnectionCloseError := func(b []byte, origHdr *wire.Header, errorCode qerr.TransportErrorCode) {
    89  		replyHdr := parseHeader(b)
    90  		Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial))
    91  		Expect(replyHdr.SrcConnectionID).To(Equal(origHdr.DestConnectionID))
    92  		Expect(replyHdr.DestConnectionID).To(Equal(origHdr.SrcConnectionID))
        		// use the opener half of the client-perspective Initial AEAD, keyed by the original destination connection ID
    93  		_, opener := handshake.NewInitialAEAD(origHdr.DestConnectionID, protocol.PerspectiveClient, replyHdr.Version)
    94  		extHdr, err := unpackLongHeader(opener, replyHdr, b, origHdr.Version)
    95  		Expect(err).ToNot(HaveOccurred())
        		// decrypt the payload; the parsed header bytes are passed as associated data
    96  		data, err := opener.Open(nil, b[extHdr.ParsedLen():], extHdr.PacketNumber, b[:extHdr.ParsedLen()])
    97  		Expect(err).ToNot(HaveOccurred())
        		// only the first frame of the payload is inspected
    98  		_, f, err := wire.NewFrameParser(false).ParseNext(data, protocol.EncryptionInitial, origHdr.Version)
    99  		Expect(err).ToNot(HaveOccurred())
   100  		Expect(f).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{}))
   101  		ccf := f.(*wire.ConnectionCloseFrame)
   102  		Expect(ccf.IsApplicationError).To(BeFalse())
   103  		Expect(ccf.ErrorCode).To(BeEquivalentTo(errorCode))
   104  		Expect(ccf.ReasonPhrase).To(BeEmpty())
   105  	}
   106  
        	// Per-spec setup: a fresh mock packet conn and a TLS config offering "proto1".
        	// The mock's ReadFrom (allowed at most once) blocks until SetReadDeadline is
        	// called and then fails with "done".
    107  	BeforeEach(func() {
    108  		conn = NewMockPacketConn(mockCtrl)
    109  		conn.EXPECT().LocalAddr().Return(&net.UDPAddr{}).AnyTimes()
        		// closed by the SetReadDeadline expectation below to unblock ReadFrom
    110  		wait := make(chan struct{})
    111  		conn.EXPECT().ReadFrom(gomock.Any()).DoAndReturn(func(_ []byte) (int, net.Addr, error) {
    112  			<-wait
    113  			return 0, nil, errors.New("done")
    114  		}).MaxTimes(1)
    115  		conn.EXPECT().SetReadDeadline(gomock.Any()).Do(func(time.Time) error {
    116  			close(wait)
        			// once the first deadline was set, expect one follow-up call with the zero time
    117  			conn.EXPECT().SetReadDeadline(time.Time{})
    118  			return nil
    119  		}).MaxTimes(1)
    120  		tlsConf = testdata.GetTLSConfig()
    121  		tlsConf.NextProtos = []string{"proto1"}
    122  	})
   123  
        	// ListenAddr must fail with a descriptive error when the tls.Config is nil.
    124  	It("errors when no tls.Config is given", func() {
    125  		_, err := ListenAddr("localhost:0", nil, nil)
    126  		Expect(err).To(HaveOccurred())
    127  		Expect(err.Error()).To(ContainSubstring("quic: tls.Config not set"))
    128  	})
   129  
        	// Listen must reject a Config whose Versions list contains 0x1234.
        	// No packet conn is needed (nil is passed) for this check.
    130  	It("errors when the Config contains an invalid version", func() {
    131  		version := protocol.Version(0x1234)
    132  		_, err := Listen(nil, tlsConf, &Config{Versions: []protocol.Version{version}})
    133  		Expect(err).To(MatchError("invalid QUIC version: 0x1234"))
    134  	})
   135  
        	// An empty Config must be populated with the protocol defaults for versions,
        	// handshake idle timeout and max idle timeout; KeepAlivePeriod stays zero.
    136  	It("fills in default values if options are not set in the Config", func() {
    137  		ln, err := Listen(conn, tlsConf, &Config{})
    138  		Expect(err).ToNot(HaveOccurred())
    139  		server := ln.baseServer
    140  		Expect(server.config.Versions).To(Equal(protocol.SupportedVersions))
    141  		Expect(server.config.HandshakeIdleTimeout).To(Equal(protocol.DefaultHandshakeIdleTimeout))
    142  		Expect(server.config.MaxIdleTimeout).To(Equal(protocol.DefaultIdleTimeout))
    143  		Expect(server.config.KeepAlivePeriod).To(BeZero())
    144  		// stop the listener
    145  		Expect(ln.Close()).To(Succeed())
    146  	})
   147  
        	// Values explicitly set in the Config must be carried over unchanged
        	// into the server's config, and a connection handler must be installed.
    148  	It("setups with the right values", func() {
    149  		supportedVersions := []protocol.Version{protocol.Version1}
    150  		config := Config{
    151  			Versions:             supportedVersions,
    152  			HandshakeIdleTimeout: 1337 * time.Hour,
    153  			MaxIdleTimeout:       42 * time.Minute,
    154  			KeepAlivePeriod:      5 * time.Second,
    155  		}
    156  		ln, err := Listen(conn, tlsConf, &config)
    157  		Expect(err).ToNot(HaveOccurred())
    158  		server := ln.baseServer
    159  		Expect(server.connHandler).ToNot(BeNil())
    160  		Expect(server.config.Versions).To(Equal(supportedVersions))
    161  		Expect(server.config.HandshakeIdleTimeout).To(Equal(1337 * time.Hour))
    162  		Expect(server.config.MaxIdleTimeout).To(Equal(42 * time.Minute))
    163  		Expect(server.config.KeepAlivePeriod).To(Equal(5 * time.Second))
    164  		// stop the listener
    165  		Expect(ln.Close()).To(Succeed())
    166  	})
   167  
        	// ListenAddr must open a real socket on the requested address and report it via Addr().
        	// NOTE(review): the port (13579) is hard-coded, so this spec presumably fails
        	// if that port is already in use on the test machine.
    168  	It("listens on a given address", func() {
    169  		addr := "127.0.0.1:13579"
    170  		ln, err := ListenAddr(addr, tlsConf, &Config{})
    171  		Expect(err).ToNot(HaveOccurred())
    172  		Expect(ln.Addr().String()).To(Equal(addr))
    173  		// stop the listener
    174  		Expect(ln.Close()).To(Succeed())
    175  	})
   176  
        	// An address string without a port must be rejected with a *net.AddrError.
    177  	It("errors if given an invalid address", func() {
        		// no port given
    178  		addr := "127.0.0.1"
    179  		_, err := ListenAddr(addr, tlsConf, &Config{})
    180  		Expect(err).To(BeAssignableToTypeOf(&net.AddrError{}))
    181  	})
   182  
   183  	It("errors if given an invalid address", func() {
   184  		addr := "1.1.1.1:1111"
   185  		_, err := ListenAddr(addr, tlsConf, &Config{})
   186  		Expect(err).To(BeAssignableToTypeOf(&net.OpError{}))
   187  	})
   188  
   189  	Context("server accepting connections that completed the handshake", func() {
        		// Fixtures shared by the specs of this Context; (re)initialized in BeforeEach.
   190  		var (
   191  			tr     *Transport                // transport wrapping the mock packet conn
   192  			serv   *baseServer               // server under test, taken from the listener
   193  			phm    *MockPacketHandlerManager // mock installed as serv.connHandler
   194  			tracer *mocklogging.MockTracer   // mock backing the transport's Tracer
   195  		)
   196  
        		// Creates a Transport on the mock conn with a mock tracer, starts listening
        		// with the default config (nil) and replaces the server's connection handler
        		// with a mock so specs can control connection lookup and registration.
   197  		BeforeEach(func() {
   198  			var t *logging.Tracer
   199  			t, tracer = mocklogging.NewMockTracer(mockCtrl)
   200  			tr = &Transport{Conn: conn, Tracer: t}
   201  			ln, err := tr.Listen(tlsConf, nil)
   202  			Expect(err).ToNot(HaveOccurred())
   203  			serv = ln.baseServer
   204  			phm = NewMockPacketHandlerManager(mockCtrl)
   205  			serv.connHandler = phm
   206  		})
   207  
        		// Shuts the transport down after every spec; the tracer is expected to be closed as part of that.
   208  		AfterEach(func() {
   209  			tracer.EXPECT().Close()
   210  			tr.Close()
   211  		})
   212  
   213  		Context("handling packets", func() {
        			// An Initial whose destination connection ID is only 4 bytes long must be
        			// dropped and reported to the tracer as an unexpected packet.
   214  			It("drops Initial packets with a too short connection ID", func() {
   215  				p := getPacket(&wire.Header{
   216  					Type:             protocol.PacketTypeInitial,
   217  					DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}),
   218  					Version:          serv.config.Versions[0],
   219  				}, nil)
   220  				tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket)
   221  				serv.handlePacket(p)
   222  				// make sure there are no Write calls on the packet conn
   223  				time.Sleep(50 * time.Millisecond)
   224  			})
   225  
        			// An Initial whose payload is 100 bytes short of protocol.MinInitialPacketSize
        			// must be dropped and reported to the tracer as an unexpected packet.
   226  			It("drops too small Initial", func() {
   227  				p := getPacket(&wire.Header{
   228  					Type:             protocol.PacketTypeInitial,
   229  					DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}),
   230  					Version:          serv.config.Versions[0],
   231  				}, make([]byte, protocol.MinInitialPacketSize-100))
   232  				tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket)
   233  				serv.handlePacket(p)
   234  				// make sure there are no Write calls on the packet conn
   235  				time.Sleep(50 * time.Millisecond)
   236  			})
   237  
        			// A Handshake packet (no connection IDs set in the header) must be dropped
        			// and reported to the tracer as an unexpected packet.
   238  			It("drops non-Initial packets", func() {
   239  				p := getPacket(&wire.Header{
   240  					Type:    protocol.PacketTypeHandshake,
   241  					Version: serv.config.Versions[0],
   242  				}, []byte("invalid"))
   243  				tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketTypeHandshake, p.Size(), logging.PacketDropUnexpectedPacket)
   244  				serv.handlePacket(p)
   245  				// make sure there are no Write calls on the packet conn
   246  				time.Sleep(50 * time.Millisecond)
   247  			})
   248  
        			// When the packet handler manager already knows the destination connection ID,
        			// the packet must be handed to that handler.
   249  			It("passes packets to existing connections", func() {
   250  				connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
   251  				p := getPacket(&wire.Header{
   252  					Type:             protocol.PacketTypeInitial,
   253  					DestConnectionID: connID,
   254  					Version:          serv.config.Versions[0],
   255  				}, make([]byte, protocol.MinInitialPacketSize))
        				// note: this local conn (a mock packet handler) shadows the outer mock packet conn
   256  				conn := NewMockPacketHandler(mockCtrl)
   257  				phm.EXPECT().Get(connID).Return(conn, true)
   258  				handled := make(chan struct{})
   259  				conn.EXPECT().handlePacket(p).Do(func(receivedPacket) { close(handled) })
   260  				serv.handlePacket(p)
   261  				Eventually(handled).Should(BeClosed())
   262  			})
   263  
        			// With source address verification required, an Initial carrying a Retry token
        			// generated for the same remote address must result in a new connection.
        			// The newConn stub checks that the connection IDs stored in the token
        			// (original destination and retry source) are passed through, and that the
        			// server picks its own source connection ID.
   264  			It("creates a connection when the token is accepted", func() {
   265  				serv.verifySourceAddress = func(net.Addr) bool { return true }
   266  				raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
        				// token arguments: remote address, original destination conn ID, retry source conn ID
        				// (see the matching assertions in the newConn stub below)
   267  				retryToken, err := serv.tokenGenerator.NewRetryToken(
   268  					raddr,
   269  					protocol.ParseConnectionID([]byte{0xde, 0xad, 0xc0, 0xde}),
   270  					protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}),
   271  				)
   272  				Expect(err).ToNot(HaveOccurred())
   273  				connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})
   274  				hdr := &wire.Header{
   275  					Type:             protocol.PacketTypeInitial,
   276  					SrcConnectionID:  protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
   277  					DestConnectionID: connID,
   278  					Version:          protocol.Version1,
   279  					Token:            retryToken,
   280  				}
   281  				p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
        				// the packet must come from the address the token was issued for
   282  				p.remoteAddr = raddr
        				// closed once the mock connection's run() is invoked
   283  				run := make(chan struct{})
        				// random stateless reset token handed out by the mocked handler manager
   284  				var token protocol.StatelessResetToken
   285  				rand.Read(token[:])
   286  
        				// captured in newConn, compared in the AddWithConnID expectation
   287  				var newConnID protocol.ConnectionID
   288  				conn := NewMockQUICConn(mockCtrl)
   289  				serv.newConn = func(
   290  					_ sendConn,
   291  					_ connRunner,
   292  					origDestConnID protocol.ConnectionID,
   293  					retrySrcConnID *protocol.ConnectionID,
   294  					clientDestConnID protocol.ConnectionID,
   295  					destConnID protocol.ConnectionID,
   296  					srcConnID protocol.ConnectionID,
   297  					_ ConnectionIDGenerator,
   298  					tokenP protocol.StatelessResetToken,
   299  					_ *Config,
   300  					_ *tls.Config,
   301  					_ *handshake.TokenGenerator,
   302  					_ bool,
   303  					_ *logging.ConnectionTracer,
   304  					_ uint64,
   305  					_ utils.Logger,
   306  					_ protocol.Version,
   307  				) quicConn {
   308  					Expect(origDestConnID).To(Equal(protocol.ParseConnectionID([]byte{0xde, 0xad, 0xc0, 0xde})))
   309  					Expect(*retrySrcConnID).To(Equal(protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad})))
   310  					Expect(clientDestConnID).To(Equal(hdr.DestConnectionID))
   311  					Expect(destConnID).To(Equal(hdr.SrcConnectionID))
   312  					// make sure we're using a server-generated connection ID
   313  					Expect(srcConnID).ToNot(Equal(hdr.DestConnectionID))
   314  					Expect(srcConnID).ToNot(Equal(hdr.SrcConnectionID))
   315  					newConnID = srcConnID
   316  					Expect(tokenP).To(Equal(token))
   317  					conn.EXPECT().handlePacket(p)
   318  					conn.EXPECT().run().Do(func() error { close(run); return nil })
   319  					conn.EXPECT().Context().Return(context.Background())
   320  					conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
   321  					return conn
   322  				}
        				// unknown connection ID: Get returns the zero values
   323  				phm.EXPECT().Get(connID)
   324  				phm.EXPECT().GetStatelessResetToken(gomock.Any()).Return(token)
   325  				phm.EXPECT().AddWithConnID(connID, gomock.Any(), gomock.Any()).DoAndReturn(func(_, cid protocol.ConnectionID, h packetHandler) bool {
   326  					Expect(cid).To(Equal(newConnID))
   327  					return true
   328  				})
   329  
   330  				done := make(chan struct{})
   331  				go func() {
   332  					defer GinkgoRecover()
   333  					serv.handlePacket(p)
   334  					// the Handshake packet is written by the connection.
   335  					// Make sure there are no Write calls on the packet conn.
   336  					time.Sleep(50 * time.Millisecond)
   337  					close(done)
   338  				}()
   339  				// wait until the connection was created and started, and the packet was handled
   340  				Eventually(run).Should(BeClosed())
   341  				Eventually(done).Should(BeClosed())
   342  				// shutdown
   343  				conn.EXPECT().closeWithTransportError(gomock.Any())
   344  			})
   345  
        			// A sufficiently large packet with an unknown version (0x42) must be answered
        			// with a Version Negotiation packet that swaps the connection IDs and does
        			// not offer the unknown version. The tracer must be notified as well.
   346  			It("sends a Version Negotiation Packet for unsupported versions", func() {
   347  				srcConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5})
   348  				destConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6})
   349  				packet := getPacket(&wire.Header{
   350  					Type:             protocol.PacketTypeHandshake,
   351  					SrcConnectionID:  srcConnID,
   352  					DestConnectionID: destConnID,
   353  					Version:          0x42,
   354  				}, make([]byte, protocol.MinUnknownVersionPacketSize))
   355  				raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
   356  				packet.remoteAddr = raddr
        				// traced source / destination are the received destination / source, respectively
   357  				tracer.EXPECT().SentVersionNegotiationPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, src, dest protocol.ArbitraryLenConnectionID, _ []protocol.Version) {
   358  					Expect(src).To(Equal(protocol.ArbitraryLenConnectionID(destConnID.Bytes())))
   359  					Expect(dest).To(Equal(protocol.ArbitraryLenConnectionID(srcConnID.Bytes())))
   360  				})
   361  				done := make(chan struct{})
        				// inspect the bytes actually written to the packet conn
   362  				conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
   363  					defer close(done)
   364  					Expect(wire.IsVersionNegotiationPacket(b)).To(BeTrue())
   365  					dest, src, versions, err := wire.ParseVersionNegotiationPacket(b)
   366  					Expect(err).ToNot(HaveOccurred())
   367  					Expect(dest).To(Equal(protocol.ArbitraryLenConnectionID(srcConnID.Bytes())))
   368  					Expect(src).To(Equal(protocol.ArbitraryLenConnectionID(destConnID.Bytes())))
   369  					Expect(versions).ToNot(ContainElement(protocol.Version(0x42)))
   370  					return len(b), nil
   371  				})
   372  				serv.handlePacket(packet)
   373  				Eventually(done).Should(BeClosed())
   374  			})
   375  
        			// With disableVersionNegotiation set, a packet with an unknown version must
        			// not trigger any reply: no WriteTo and no tracer expectation is registered.
        			// NOTE(review): done is never closed anywhere in this spec, so the Consistently
        			// assertion cannot fail and effectively only waits for 50ms. The spec presumably
        			// relies on the mock controller rejecting an unexpected WriteTo call - confirm.
   376  			It("doesn't send a Version Negotiation packets if sending them is disabled", func() {
   377  				serv.disableVersionNegotiation = true
   378  				srcConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5})
   379  				destConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6})
   380  				packet := getPacket(&wire.Header{
   381  					Type:             protocol.PacketTypeHandshake,
   382  					SrcConnectionID:  srcConnID,
   383  					DestConnectionID: destConnID,
   384  					Version:          0x42,
   385  				}, make([]byte, protocol.MinUnknownVersionPacketSize))
   386  				raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
   387  				packet.remoteAddr = raddr
   388  				done := make(chan struct{})
   389  				serv.handlePacket(packet)
   390  				Consistently(done, 50*time.Millisecond).ShouldNot(BeClosed())
   391  			})
   392  
        			// A Version Negotiation packet received by the server must be dropped and
        			// reported to the tracer as an unexpected packet; nothing may be sent in response.
   393  			It("ignores Version Negotiation packets", func() {
   394  				data := wire.ComposeVersionNegotiation(
   395  					protocol.ArbitraryLenConnectionID{1, 2, 3, 4},
   396  					protocol.ArbitraryLenConnectionID{4, 3, 2, 1},
   397  					[]protocol.Version{1, 2, 3},
   398  				)
   399  				raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
        				// closed when the drop is traced
   400  				done := make(chan struct{})
   401  				tracer.EXPECT().DroppedPacket(raddr, logging.PacketTypeVersionNegotiation, protocol.ByteCount(len(data)), logging.PacketDropUnexpectedPacket).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) {
   402  					close(done)
   403  				})
        				// the packet is built by hand here, since getPacket would seal it as a regular packet
   404  				serv.handlePacket(receivedPacket{
   405  					remoteAddr: raddr,
   406  					data:       data,
   407  					buffer:     getPacketBuffer(),
   408  				})
   409  				Eventually(done).Should(BeClosed())
   410  				// make sure no other packet is sent
   411  				time.Sleep(scaleDuration(20 * time.Millisecond))
   412  			})
   413  
        			// A packet with an unknown version that is smaller than
        			// protocol.MinUnknownVersionPacketSize must be dropped (traced with an
        			// undetermined packet type) instead of being answered with Version Negotiation.
   414  			It("doesn't send a Version Negotiation Packet for unsupported versions, if the packet is too small", func() {
   415  				srcConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5})
   416  				destConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6})
   417  				p := getPacket(&wire.Header{
   418  					Type:             protocol.PacketTypeHandshake,
   419  					SrcConnectionID:  srcConnID,
   420  					DestConnectionID: destConnID,
   421  					Version:          0x42,
   422  				}, make([]byte, protocol.MinUnknownVersionPacketSize-50))
        				// sanity check: header and sealing overhead must not push the packet over the threshold
   423  				Expect(p.Size()).To(BeNumerically("<", protocol.MinUnknownVersionPacketSize))
   424  				raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
   425  				p.remoteAddr = raddr
        				// closed when the drop is traced
   426  				done := make(chan struct{})
   427  				tracer.EXPECT().DroppedPacket(raddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) {
   428  					close(done)
   429  				})
   430  				serv.handlePacket(p)
   431  				Eventually(done).Should(BeClosed())
   432  				// make sure no other packet is sent
   433  				time.Sleep(scaleDuration(20 * time.Millisecond))
   434  			})
   435  
        			// When verifySourceAddress demands verification and the Initial carries no
        			// token, the server must answer with a Retry packet: new source connection ID,
        			// destination set to the client's source connection ID, a non-empty token and
        			// a valid Retry integrity tag. Both the tracer and the written bytes are checked.
   436  			It("replies with a Retry packet, if a token is required", func() {
   437  				connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})
   438  				raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
        				// records that the callback was consulted; read only after done was closed
   439  				var called bool
   440  				serv.verifySourceAddress = func(addr net.Addr) bool {
   441  					Expect(addr).To(Equal(raddr))
   442  					called = true
   443  					return true
   444  				}
   445  				hdr := &wire.Header{
   446  					Type:             protocol.PacketTypeInitial,
   447  					SrcConnectionID:  protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
   448  					DestConnectionID: connID,
   449  					Version:          protocol.Version1,
   450  				}
   451  				packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
   452  				packet.remoteAddr = raddr
        				// the traced packet carries no frames (nil)
   453  				tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), nil).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, _ []logging.Frame) {
   454  					Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry))
   455  					Expect(replyHdr.SrcConnectionID).ToNot(Equal(hdr.DestConnectionID))
   456  					Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))
   457  					Expect(replyHdr.Token).ToNot(BeEmpty())
   458  				})
   459  				done := make(chan struct{})
   460  				conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
   461  					defer close(done)
   462  					replyHdr := parseHeader(b)
   463  					Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry))
   464  					Expect(replyHdr.SrcConnectionID).ToNot(Equal(hdr.DestConnectionID))
   465  					Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))
   466  					Expect(replyHdr.Token).ToNot(BeEmpty())
        					// the last 16 bytes must equal the integrity tag computed over the rest of the packet
   467  					Expect(b[len(b)-16:]).To(Equal(handshake.GetRetryIntegrityTag(b[:len(b)-16], hdr.DestConnectionID, hdr.Version)[:]))
   468  					return len(b), nil
   469  				})
   470  				phm.EXPECT().Get(connID)
   471  				serv.handlePacket(packet)
   472  				Eventually(done).Should(BeClosed())
   473  				Expect(called).To(BeTrue())
   474  			})
   475  
        			// Without a token requirement (verifySourceAddress is left untouched), a valid
        			// Initial must directly create a connection. The newConn stub checks that the
        			// original destination connection ID is the one from the packet, that no retry
        			// source connection ID is set, and that the server picks its own source connection ID.
   476  			It("creates a connection, if no token is required", func() {
   477  				connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})
   478  				hdr := &wire.Header{
   479  					Type:             protocol.PacketTypeInitial,
   480  					SrcConnectionID:  protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
   481  					DestConnectionID: connID,
   482  					Version:          protocol.Version1,
   483  				}
   484  				p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
        				// closed once the mock connection's run() is invoked
   485  				run := make(chan struct{})
        				// random stateless reset token handed out by the mocked handler manager
   486  				var token protocol.StatelessResetToken
   487  				rand.Read(token[:])
   488  
        				// captured in newConn, compared in the AddWithConnID expectation
   489  				var newConnID protocol.ConnectionID
   490  				conn := NewMockQUICConn(mockCtrl)
   491  				serv.newConn = func(
   492  					_ sendConn,
   493  					_ connRunner,
   494  					origDestConnID protocol.ConnectionID,
   495  					retrySrcConnID *protocol.ConnectionID,
   496  					clientDestConnID protocol.ConnectionID,
   497  					destConnID protocol.ConnectionID,
   498  					srcConnID protocol.ConnectionID,
   499  					_ ConnectionIDGenerator,
   500  					tokenP protocol.StatelessResetToken,
   501  					_ *Config,
   502  					_ *tls.Config,
   503  					_ *handshake.TokenGenerator,
   504  					_ bool,
   505  					_ *logging.ConnectionTracer,
   506  					_ uint64,
   507  					_ utils.Logger,
   508  					_ protocol.Version,
   509  				) quicConn {
   510  					Expect(origDestConnID).To(Equal(hdr.DestConnectionID))
   511  					Expect(retrySrcConnID).To(BeNil())
   512  					Expect(clientDestConnID).To(Equal(hdr.DestConnectionID))
   513  					Expect(destConnID).To(Equal(hdr.SrcConnectionID))
   514  					// make sure we're using a server-generated connection ID
   515  					Expect(srcConnID).ToNot(Equal(hdr.DestConnectionID))
   516  					Expect(srcConnID).ToNot(Equal(hdr.SrcConnectionID))
   517  					newConnID = srcConnID
   518  					Expect(tokenP).To(Equal(token))
   519  					conn.EXPECT().handlePacket(p)
   520  					conn.EXPECT().run().Do(func() error { close(run); return nil })
   521  					conn.EXPECT().Context().Return(context.Background())
   522  					conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
   523  					return conn
   524  				}
        				// the handler manager calls must happen in exactly this order
   525  				gomock.InOrder(
   526  					phm.EXPECT().Get(connID),
   527  					phm.EXPECT().GetStatelessResetToken(gomock.Any()).Return(token),
   528  					phm.EXPECT().AddWithConnID(connID, gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, h packetHandler) bool {
   529  						Expect(c).To(Equal(newConnID))
   530  						return true
   531  					}),
   532  				)
   533  
   534  				done := make(chan struct{})
   535  				go func() {
   536  					defer GinkgoRecover()
   537  					serv.handlePacket(p)
   538  					// the Handshake packet is written by the connection
   539  					// make sure there are no Write calls on the packet conn
   540  					time.Sleep(50 * time.Millisecond)
   541  					close(done)
   542  				}()
   543  				// wait until the connection was created and started, and the packet was handled
   544  				Eventually(run).Should(BeClosed())
   545  				Eventually(done).Should(BeClosed())
   546  				// shutdown
   547  				conn.EXPECT().closeWithTransportError(gomock.Any()).MaxTimes(1)
   548  			})
   549  
        			// Floods the server with Initial packets while connection creation is blocked.
        			// The newConn stub blocks on acceptConn, so the first packet stays in processing
        			// while 3*protocol.MaxServerUnprocessedPackets further packets arrive. Once
        			// unblocked, exactly MaxServerUnprocessedPackets+1 connections must have been
        			// created; at least one of the packets must have been traced as dropped for
        			// DoS prevention.
   550  			It("drops packets if the receive queue is full", func() {
   551  				serv.verifySourceAddress = func(net.Addr) bool { return false }
   552  
   553  				phm.EXPECT().Get(gomock.Any()).AnyTimes()
   554  				phm.EXPECT().GetStatelessResetToken(gomock.Any()).AnyTimes()
   555  				phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true).AnyTimes()
   556  
        				// gate that holds every newConn call until the flood has been sent
   557  				acceptConn := make(chan struct{})
        				// number of connections created; incremented from the server side, hence atomic
   558  				var counter atomic.Uint32
   559  				serv.newConn = func(
   560  					_ sendConn,
   561  					runner connRunner,
   562  					_ protocol.ConnectionID,
   563  					_ *protocol.ConnectionID,
   564  					_ protocol.ConnectionID,
   565  					_ protocol.ConnectionID,
   566  					_ protocol.ConnectionID,
   567  					_ ConnectionIDGenerator,
   568  					_ protocol.StatelessResetToken,
   569  					_ *Config,
   570  					_ *tls.Config,
   571  					_ *handshake.TokenGenerator,
   572  					_ bool,
   573  					_ *logging.ConnectionTracer,
   574  					_ uint64,
   575  					_ utils.Logger,
   576  					_ protocol.Version,
   577  				) quicConn {
   578  					<-acceptConn
   579  					counter.Add(1)
   580  					conn := NewMockQUICConn(mockCtrl)
        					// all calls are optional (MaxTimes), since the spec may end before they happen
   581  					conn.EXPECT().handlePacket(gomock.Any()).MaxTimes(1)
   582  					conn.EXPECT().run().MaxTimes(1)
   583  					conn.EXPECT().Context().Return(context.Background()).MaxTimes(1)
   584  					conn.EXPECT().HandshakeComplete().Return(make(chan struct{})).MaxTimes(1)
   585  					// shutdown
   586  					conn.EXPECT().closeWithTransportError(gomock.Any()).MaxTimes(1)
   587  					return conn
   588  				}
   589  
        				// first packet: gets picked up and blocks inside newConn
   590  				p := getInitial(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}))
   591  				serv.handlePacket(p)
        				// all flood packets share p's remote address and size, so this expectation matches their drops
   592  				tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropDOSPrevention).MinTimes(1)
   593  				var wg sync.WaitGroup
   594  				for i := 0; i < 3*protocol.MaxServerUnprocessedPackets; i++ {
   595  					wg.Add(1)
   596  					go func() {
   597  						defer GinkgoRecover()
   598  						defer wg.Done()
   599  						serv.handlePacket(getInitial(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})))
   600  					}()
   601  				}
   602  				wg.Wait()
   603  
        				// unblock connection creation and check that only the queued packets were processed
   604  				close(acceptConn)
   605  				Eventually(
   606  					func() uint32 { return counter.Load() },
   607  					scaleDuration(100*time.Millisecond),
   608  				).Should(BeEquivalentTo(protocol.MaxServerUnprocessedPackets + 1))
        				// the count must not grow any further
   609  				Consistently(func() uint32 { return counter.Load() }).Should(BeEquivalentTo(protocol.MaxServerUnprocessedPackets + 1))
   610  			})
   611  
        			// If registering the new connection fails (AddWithConnID returns false, i.e.
        			// the connection ID is already taken), the freshly created connection must be
        			// closed with ConnectionRefused; run() is never expected on it.
   612  			It("only creates a single connection for a duplicate Initial", func() {
        				// closed when the rejected connection is shut down
   613  				done := make(chan struct{})
   614  				serv.newConn = func(
   615  					_ sendConn,
   616  					runner connRunner,
   617  					_ protocol.ConnectionID,
   618  					_ *protocol.ConnectionID,
   619  					_ protocol.ConnectionID,
   620  					_ protocol.ConnectionID,
   621  					_ protocol.ConnectionID,
   622  					_ ConnectionIDGenerator,
   623  					_ protocol.StatelessResetToken,
   624  					_ *Config,
   625  					_ *tls.Config,
   626  					_ *handshake.TokenGenerator,
   627  					_ bool,
   628  					_ *logging.ConnectionTracer,
   629  					_ uint64,
   630  					_ utils.Logger,
   631  					_ protocol.Version,
   632  				) quicConn {
   633  					conn := NewMockQUICConn(mockCtrl)
   634  					conn.EXPECT().handlePacket(gomock.Any())
   635  					conn.EXPECT().closeWithTransportError(qerr.ConnectionRefused).Do(func(qerr.TransportErrorCode) {
   636  						close(done)
   637  					})
   638  					return conn
   639  				}
   640  
   641  				connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9})
   642  				p := getInitial(connID)
   643  				phm.EXPECT().Get(connID)
   644  				phm.EXPECT().GetStatelessResetToken(gomock.Any())
   645  				phm.EXPECT().AddWithConnID(connID, gomock.Any(), gomock.Any()).Return(false) // connection ID collision
        				// call the implementation directly (instead of handlePacket) to check its return value
   646  				Expect(serv.handlePacketImpl(p)).To(BeTrue())
   647  				Eventually(done).Should(BeClosed())
   648  			})
   649  
        			// The first `limit` connection attempts are accepted without address
        			// validation; the next attempt must be answered with a Retry packet.
        			// Afterwards all `limit` connections must be returned by Accept.
   650  			It("limits the number of unvalidated handshakes", func() {
   651  				const limit = 3
        				// refill rate 0 with a burst of `limit`: Allow() succeeds exactly `limit` times
   652  				limiter := rate.NewLimiter(0, limit)
        				// verification is required only once the limiter is exhausted
   653  				serv.verifySourceAddress = func(net.Addr) bool { return !limiter.Allow() }
   654  
   655  				phm.EXPECT().Get(gomock.Any()).AnyTimes()
   656  				phm.EXPECT().GetStatelessResetToken(gomock.Any()).AnyTimes()
   657  				phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true).AnyTimes()
   658  
        				// hands pre-created mock connections to the newConn stub, one at a time
   659  				connChan := make(chan *MockQUICConn, 1)
        				// counts down as HandshakeComplete is queried on each of the `limit` connections
   660  				var wg sync.WaitGroup
   661  				wg.Add(limit)
        				// doubles as the handshake-complete channel of every connection and as the
        				// signal that the Retry was written (closed in the WriteTo expectation below)
   662  				done := make(chan struct{})
   663  				serv.newConn = func(
   664  					_ sendConn,
   665  					runner connRunner,
   666  					_ protocol.ConnectionID,
   667  					_ *protocol.ConnectionID,
   668  					_ protocol.ConnectionID,
   669  					_ protocol.ConnectionID,
   670  					_ protocol.ConnectionID,
   671  					_ ConnectionIDGenerator,
   672  					_ protocol.StatelessResetToken,
   673  					_ *Config,
   674  					_ *tls.Config,
   675  					_ *handshake.TokenGenerator,
   676  					_ bool,
   677  					_ *logging.ConnectionTracer,
   678  					_ uint64,
   679  					_ utils.Logger,
   680  					_ protocol.Version,
   681  				) quicConn {
   682  					conn := <-connChan
   683  					conn.EXPECT().handlePacket(gomock.Any())
   684  					conn.EXPECT().run()
   685  					conn.EXPECT().Context().Return(context.Background())
   686  					conn.EXPECT().HandshakeComplete().DoAndReturn(func() <-chan struct{} { wg.Done(); return done })
   687  					return conn
   688  				}
   689  
   690  				// Initiate the maximum number of allowed connection attempts.
   691  				for i := 0; i < limit; i++ {
        					// this conn is scoped to the loop body and shadows the outer mock packet conn
   692  					conn := NewMockQUICConn(mockCtrl)
   693  					connChan <- conn
   694  					serv.handlePacket(getInitialWithRandomDestConnID())
   695  				}
   696  
   697  				// Now initiate another connection attempt.
   698  				p := getInitialWithRandomDestConnID()
   699  				tracer.EXPECT().SentPacket(p.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) {
   700  					defer GinkgoRecover()
   701  					Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry))
   702  				})
        				// conn is the outer mock packet conn again: the Retry is written directly to it
   703  				conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
   704  					defer GinkgoRecover()
   705  					defer close(done)
   706  					hdr, _, _, err := wire.ParsePacket(b)
   707  					Expect(err).ToNot(HaveOccurred())
   708  					Expect(hdr.Type).To(Equal(protocol.PacketTypeRetry))
   709  					return len(b), nil
   710  				})
   711  				serv.handlePacket(p)
   712  				Eventually(done).Should(BeClosed())
   713  
        				// done is closed now, so the handshakes count as complete and Accept can return each connection
   714  				for i := 0; i < limit; i++ {
   715  					_, err := serv.Accept(context.Background())
   716  					Expect(err).ToNot(HaveOccurred())
   717  				}
   718  				wg.Wait()
   719  			})
   720  		})
   721  
   722  		Context("token validation", func() {
   723  			It("decodes the token from the token field", func() {
        				// An Initial carrying a retry token issued for its own remote address is
        				// fed to the server; the spec passes once AddWithConnID is called.
   724  				serv.newConn = func(
   725  					_ sendConn,
   726  					_ connRunner,
   727  					_ protocol.ConnectionID,
   728  					_ *protocol.ConnectionID,
   729  					_ protocol.ConnectionID,
   730  					_ protocol.ConnectionID,
   731  					_ protocol.ConnectionID,
   732  					_ ConnectionIDGenerator,
   733  					_ protocol.StatelessResetToken,
   734  					_ *Config,
   735  					_ *tls.Config,
   736  					_ *handshake.TokenGenerator,
   737  					_ bool,
   738  					_ *logging.ConnectionTracer,
   739  					_ uint64,
   740  					_ utils.Logger,
   741  					_ protocol.Version,
   742  				) quicConn {
   743  					c := NewMockQUICConn(mockCtrl)
   744  					c.EXPECT().handlePacket(gomock.Any())
   745  					c.EXPECT().run()
   746  					c.EXPECT().HandshakeComplete()
        					// The mock reports a context that is already canceled.
   747  					ctx, cancel := context.WithCancel(context.Background())
   748  					cancel()
   749  					c.EXPECT().Context().Return(ctx)
   750  					return c
   751  				}
   752  				raddr := &net.UDPAddr{IP: net.IPv4(192, 168, 13, 37), Port: 1337}
   753  				token, err := serv.tokenGenerator.NewRetryToken(raddr, protocol.ConnectionID{}, protocol.ConnectionID{})
   754  				Expect(err).ToNot(HaveOccurred())
   755  				packet := getPacket(&wire.Header{
   756  					Type:    protocol.PacketTypeInitial,
   757  					Token:   token,
   758  					Version: serv.config.Versions[0],
   759  				}, make([]byte, protocol.MinInitialPacketSize))
   760  				packet.remoteAddr = raddr
        				// A single reply on the packet conn (and its trace event) is tolerated but not required.
   761  				conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).MaxTimes(1)
   762  				tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1)
   763  
   764  				done := make(chan struct{})
   765  				phm.EXPECT().Get(gomock.Any())
   766  				phm.EXPECT().GetStatelessResetToken(gomock.Any())
        				// Registration of the new connection is the success signal of this spec.
   767  				phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, _ packetHandler) bool {
   768  					close(done)
   769  					return true
   770  				})
   771  				phm.EXPECT().Remove(gomock.Any()).AnyTimes()
   772  				serv.handlePacket(packet)
   773  				Eventually(done).Should(BeClosed())
   774  			})
   775  
   776  			It("sends an INVALID_TOKEN error, if an invalid retry token is received", func() {
        				// With address verification forced on, an Initial carrying a retry token
        				// that was issued for a different address must be answered with a
        				// CONNECTION_CLOSE (INVALID_TOKEN) in an Initial packet.
   777  				serv.verifySourceAddress = func(net.Addr) bool { return true }
        				// The token is issued for an empty UDPAddr, not for the packet's remote address set below.
   778  				token, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, protocol.ConnectionID{}, protocol.ConnectionID{})
   779  				Expect(err).ToNot(HaveOccurred())
   780  				hdr := &wire.Header{
   781  					Type:             protocol.PacketTypeInitial,
   782  					SrcConnectionID:  protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
   783  					DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}),
   784  					Token:            token,
   785  					Version:          protocol.Version1,
   786  				}
   787  				packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
   788  				packet.data = append(packet.data, []byte("coalesced packet")...) // add some garbage to simulate a coalesced packet
   789  				raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
   790  				packet.remoteAddr = raddr
   791  				tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) {
        					// This callback is invoked while the server processes the packet, not from
        					// the spec body; GinkgoRecover turns a failed assertion into a spec failure.
        					defer GinkgoRecover()
   792  					Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial))
        					// The reply swaps the connection IDs of the received header.
   793  					Expect(replyHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID))
   794  					Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))
   795  					Expect(frames).To(HaveLen(1))
   796  					Expect(frames[0]).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{}))
   797  					ccf := frames[0].(*logging.ConnectionCloseFrame)
   798  					Expect(ccf.IsApplicationError).To(BeFalse())
   799  					Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.InvalidToken))
   800  				})
   801  				done := make(chan struct{})
   802  				conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
        					// checkConnectionCloseError is handed the raw reply to verify; recover here
        					// as well, since this also runs outside the spec body. Deferred first so it
        					// runs last, after done has been closed.
        					defer GinkgoRecover()
   803  					defer close(done)
   804  					checkConnectionCloseError(b, hdr, qerr.InvalidToken)
   805  					return len(b), nil
   806  				})
   807  				phm.EXPECT().Get(gomock.Any())
   808  				serv.handlePacket(packet)
   809  				Eventually(done).Should(BeClosed())
   810  			})
   811  
   812  			It("sends an INVALID_TOKEN error, if an expired retry token is received", func() {
        				// A retry token for the right address, but older than the maximum retry
        				// token age, must be answered with CONNECTION_CLOSE (INVALID_TOKEN).
   813  				serv.verifySourceAddress = func(net.Addr) bool { return true }
   814  				serv.config.HandshakeIdleTimeout = time.Millisecond / 2 // the maximum retry token age is equivalent to the handshake timeout
   815  				Expect(serv.config.maxRetryTokenAge()).To(Equal(time.Millisecond))
   816  				raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
   817  				token, err := serv.tokenGenerator.NewRetryToken(raddr, protocol.ConnectionID{}, protocol.ConnectionID{})
   818  				Expect(err).ToNot(HaveOccurred())
   819  				time.Sleep(2 * time.Millisecond) // make sure the token is expired
   820  				hdr := &wire.Header{
   821  					Type:             protocol.PacketTypeInitial,
   822  					SrcConnectionID:  protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
   823  					DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}),
   824  					Token:            token,
   825  					Version:          protocol.Version1,
   826  				}
   827  				packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
   828  				packet.remoteAddr = raddr
   829  				tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) {
        					// This callback is invoked while the server processes the packet, not from
        					// the spec body; GinkgoRecover turns a failed assertion into a spec failure.
        					defer GinkgoRecover()
   830  					Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial))
        					// The reply swaps the connection IDs of the received header.
   831  					Expect(replyHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID))
   832  					Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))
   833  					Expect(frames).To(HaveLen(1))
   834  					Expect(frames[0]).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{}))
   835  					ccf := frames[0].(*logging.ConnectionCloseFrame)
   836  					Expect(ccf.IsApplicationError).To(BeFalse())
   837  					Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.InvalidToken))
   838  				})
   839  				done := make(chan struct{})
   840  				conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
        					// checkConnectionCloseError is handed the raw reply to verify; recover here
        					// as well, since this also runs outside the spec body. Deferred first so it
        					// runs last, after done has been closed.
        					defer GinkgoRecover()
   841  					defer close(done)
   842  					checkConnectionCloseError(b, hdr, qerr.InvalidToken)
   843  					return len(b), nil
   844  				})
   845  				phm.EXPECT().Get(gomock.Any())
   846  				serv.handlePacket(packet)
   847  				Eventually(done).Should(BeClosed())
   848  			})
   849  
   850  			It("doesn't send an INVALID_TOKEN error, if an invalid non-retry token is received", func() {
        				// A non-retry token issued for a different address must not trigger an
        				// INVALID_TOKEN close; the reply written to the packet conn has to be a Retry.
   851  				serv.verifySourceAddress = func(net.Addr) bool { return true }
        				// Token address (192.168.0.1) differs from the packet's remote address (127.0.0.1).
   852  				token, err := serv.tokenGenerator.NewToken(&net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337})
   853  				Expect(err).ToNot(HaveOccurred())
   854  				hdr := &wire.Header{
   855  					Type:             protocol.PacketTypeInitial,
   856  					SrcConnectionID:  protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
   857  					DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}),
   858  					Token:            token,
   859  					Version:          protocol.Version1,
   860  				}
   861  				packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
   862  				packet.data[len(packet.data)-10] ^= 0xff // corrupt the packet
   863  				raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
   864  				packet.remoteAddr = raddr
   865  				tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1)
   866  				done := make(chan struct{})
   867  				conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
        					// This callback is invoked while the server processes the packet, not from
        					// the spec body; GinkgoRecover turns a failed assertion into a spec failure.
        					// Deferred first so it runs last, after done has been closed.
        					defer GinkgoRecover()
   868  					defer close(done)
   869  					replyHdr := parseHeader(b)
   870  					Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry))
   871  					return len(b), nil
   872  				})
   873  				phm.EXPECT().Get(gomock.Any())
   874  				serv.handlePacket(packet)
   875  				// wait until the Retry has been written to the packet conn
   876  				Eventually(done).Should(BeClosed())
   877  			})
   878  
   879  			It("sends a Retry, if an expired non-retry token is received", func() {
        				// A non-retry token for the right address, but older than maxTokenAge, is
        				// presented while address verification is forced on; the traced reply
        				// must be a Retry packet.
   880  				serv.verifySourceAddress = func(net.Addr) bool { return true }
   881  				serv.maxTokenAge = time.Millisecond
   882  				raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
   883  				token, err := serv.tokenGenerator.NewToken(raddr)
   884  				Expect(err).ToNot(HaveOccurred())
   885  				time.Sleep(2 * time.Millisecond) // make sure the token is expired
   886  				hdr := &wire.Header{
   887  					Type:             protocol.PacketTypeInitial,
   888  					SrcConnectionID:  protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
   889  					DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}),
   890  					Token:            token,
   891  					Version:          protocol.Version1,
   892  				}
   893  				packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
   894  				packet.remoteAddr = raddr
   895  				tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) {
        					// This callback is invoked while the server processes the packet, not from
        					// the spec body; GinkgoRecover turns a failed assertion into a spec failure.
        					defer GinkgoRecover()
   896  					Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry))
   897  				})
   898  				done := make(chan struct{})
        				// The write itself is not inspected; it only signals that a reply went out.
   899  				conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
   900  					defer close(done)
   901  					return len(b), nil
   902  				})
   903  				phm.EXPECT().Get(gomock.Any())
   904  				serv.handlePacket(packet)
   905  				Eventually(done).Should(BeClosed())
   906  			})
   907  
   908  			It("doesn't send an INVALID_TOKEN error, if the packet is corrupted", func() {
        				// The Initial carries a retry token issued for a different address, but its
        				// payload is corrupted; the tracer must report a payload decryption drop and
        				// no WriteTo expectation is registered on the packet conn.
   909  				token, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, protocol.ConnectionID{}, protocol.ConnectionID{})
   910  				Expect(err).ToNot(HaveOccurred())
   911  				hdr := &wire.Header{
   912  					Type:             protocol.PacketTypeInitial,
   913  					SrcConnectionID:  protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
   914  					DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}),
   915  					Token:            token,
   916  					Version:          protocol.Version1,
   917  				}
   918  				packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
   919  				packet.data[len(packet.data)-10] ^= 0xff // corrupt the packet
   920  				packet.remoteAddr = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
   921  				done := make(chan struct{})
        				// The drop event closes done, which is the success signal of this spec.
   922  				tracer.EXPECT().DroppedPacket(packet.remoteAddr, logging.PacketTypeInitial, packet.Size(), logging.PacketDropPayloadDecryptError).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { close(done) })
   923  				phm.EXPECT().Get(gomock.Any())
   924  				serv.handlePacket(packet)
   925  				// make sure there are no Write calls on the packet conn
   926  				time.Sleep(50 * time.Millisecond)
   927  				Eventually(done).Should(BeClosed())
   928  			})
   929  		})
   930  
   931  		Context("accepting connections", func() {
   932  			It("returns Accept when closed", func() {
        				// An Accept call started in a goroutine must return ErrServerClosed
        				// once the server is closed.
   933  				done := make(chan struct{})
   934  				go func() {
   935  					defer GinkgoRecover()
   936  					_, err := serv.Accept(context.Background())
   937  					Expect(err).To(MatchError(ErrServerClosed))
   938  					close(done)
   939  				}()
   940  
   941  				serv.Close()
   942  				Eventually(done).Should(BeClosed())
   943  			})
   944  
   945  			It("returns immediately, if an error occurred before", func() {
        				// After Close, repeated Accept calls must each report ErrServerClosed.
   946  				serv.Close()
   947  				for i := 0; i < 3; i++ {
   948  					_, err := serv.Accept(context.Background())
   949  					Expect(err).To(MatchError(ErrServerClosed))
   950  				}
   951  			})
   952  
        			// Pending spec (PIt): Ginkgo does not run it.
   953  			PIt("closes connection that are still handshaking after Close", func() {
        				// The server is closed first; an Initial is then fed directly to
        				// handleInitialImpl and the resulting connection is expected to be
        				// closed with ConnectionRefused.
   954  				serv.Close()
   955  
   956  				destroyed := make(chan struct{})
   957  				serv.newConn = func(
   958  					_ sendConn,
   959  					_ connRunner,
   960  					_ protocol.ConnectionID,
   961  					_ *protocol.ConnectionID,
   962  					_ protocol.ConnectionID,
   963  					_ protocol.ConnectionID,
   964  					_ protocol.ConnectionID,
   965  					_ ConnectionIDGenerator,
   966  					_ protocol.StatelessResetToken,
   967  					conf *Config,
   968  					_ *tls.Config,
   969  					_ *handshake.TokenGenerator,
   970  					_ bool,
   971  					_ *logging.ConnectionTracer,
   972  					_ uint64,
   973  					_ utils.Logger,
   974  					_ protocol.Version,
   975  				) quicConn {
   976  					conn := NewMockQUICConn(mockCtrl)
   977  					conn.EXPECT().handlePacket(gomock.Any())
        					// The refusal closes `destroyed`, which is the success signal of this spec.
   978  					conn.EXPECT().closeWithTransportError(ConnectionRefused).Do(func(TransportErrorCode) { close(destroyed) })
        					// The handshake never completes: this channel is never closed.
   979  					conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
   980  					conn.EXPECT().run().MaxTimes(1)
   981  					conn.EXPECT().Context().Return(context.Background())
   982  					return conn
   983  				}
   984  				phm.EXPECT().Get(gomock.Any())
   985  				phm.EXPECT().GetStatelessResetToken(gomock.Any())
   986  				phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true)
   987  				serv.handleInitialImpl(
   988  					receivedPacket{buffer: getPacketBuffer()},
   989  					&wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})},
   990  				)
   991  				Eventually(destroyed).Should(BeClosed())
   992  			})
   993  
   994  			It("returns when the context is canceled", func() {
        				// Accept must keep blocking until its context is canceled and then
        				// return an error matching "context canceled".
   995  				ctx, cancel := context.WithCancel(context.Background())
   996  				done := make(chan struct{})
   997  				go func() {
   998  					defer GinkgoRecover()
   999  					_, err := serv.Accept(ctx)
  1000  					Expect(err).To(MatchError("context canceled"))
  1001  					close(done)
  1002  				}()
  1003  
        				// Accept must not return before cancel is called.
  1004  				Consistently(done).ShouldNot(BeClosed())
  1005  				cancel()
  1006  				Eventually(done).Should(BeClosed())
  1007  			})
  1008  
  1009  			It("uses the config returned by GetConfigClient", func() {
        				// The server's GetConfigForClient callback returns a config with
        				// MaxIncomingStreams set to 1234; newConn asserts that it receives a
        				// config with that value.
  1010  				conn := NewMockQUICConn(mockCtrl)
  1011  
  1012  				conf := &Config{MaxIncomingStreams: 1234}
  1013  				serv.config = populateConfig(&Config{GetConfigForClient: func(*ClientHelloInfo) (*Config, error) { return conf, nil }})
  1014  				done := make(chan struct{})
  1015  				go func() {
  1016  					defer GinkgoRecover()
  1017  					s, err := serv.Accept(context.Background())
  1018  					Expect(err).ToNot(HaveOccurred())
  1019  					Expect(s).To(Equal(conn))
  1020  					close(done)
  1021  				}()
  1022  
  1023  				handshakeChan := make(chan struct{})
  1024  				serv.newConn = func(
  1025  					_ sendConn,
  1026  					_ connRunner,
  1027  					_ protocol.ConnectionID,
  1028  					_ *protocol.ConnectionID,
  1029  					_ protocol.ConnectionID,
  1030  					_ protocol.ConnectionID,
  1031  					_ protocol.ConnectionID,
  1032  					_ ConnectionIDGenerator,
  1033  					_ protocol.StatelessResetToken,
  1034  					conf *Config,
  1035  					_ *tls.Config,
  1036  					_ *handshake.TokenGenerator,
  1037  					_ bool,
  1038  					_ *logging.ConnectionTracer,
  1039  					_ uint64,
  1040  					_ utils.Logger,
  1041  					_ protocol.Version,
  1042  				) quicConn {
        					// `conf` is the parameter passed by the server; it shadows the outer `conf`.
  1043  					Expect(conf.MaxIncomingStreams).To(BeEquivalentTo(1234))
  1044  					conn.EXPECT().handlePacket(gomock.Any())
  1045  					conn.EXPECT().HandshakeComplete().Return(handshakeChan)
  1046  					conn.EXPECT().run()
  1047  					conn.EXPECT().Context().Return(context.Background())
  1048  					return conn
  1049  				}
  1050  				phm.EXPECT().Get(gomock.Any())
  1051  				phm.EXPECT().GetStatelessResetToken(gomock.Any())
  1052  				phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true)
        				// handleInitialImpl is called directly from the spec body here.
  1053  				serv.handleInitialImpl(
  1054  					receivedPacket{buffer: getPacketBuffer()},
  1055  					&wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})},
  1056  				)
        				// Accept must not return while handshakeChan is still open.
  1057  				Consistently(done).ShouldNot(BeClosed())
  1058  				close(handshakeChan) // complete the handshake
  1059  				Eventually(done).Should(BeClosed())
  1060  			})
  1061  
  1062  			It("rejects a connection attempt when GetConfigClient returns an error", func() {
        				// GetConfigForClient fails for every client; the server is expected to
        				// write a reply whose header type is Initial to the packet conn.
  1063  				serv.config = populateConfig(&Config{GetConfigForClient: func(*ClientHelloInfo) (*Config, error) { return nil, errors.New("rejected") }})
  1064  
  1065  				phm.EXPECT().Get(gomock.Any())
  1066  				done := make(chan struct{})
  1067  				tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
  1068  				conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
        					// The spec waits for this write with Eventually, so it may happen off the
        					// spec goroutine; GinkgoRecover turns a failed assertion into a spec failure.
        					// Deferred first so it runs last, after done has been closed.
        					defer GinkgoRecover()
  1069  					defer close(done)
  1070  					rejectHdr := parseHeader(b)
  1071  					Expect(rejectHdr.Type).To(Equal(protocol.PacketTypeInitial))
  1072  					return len(b), nil
  1073  				})
  1074  				serv.handleInitialImpl(
  1075  					receivedPacket{buffer: getPacketBuffer()},
  1076  					&wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), Version: protocol.Version1},
  1077  				)
  1078  				Eventually(done).Should(BeClosed())
  1079  			})
  1080  
  1081  			It("accepts new connections when the handshake completes", func() {
        				// Accept must only return the connection after the channel handed out by
        				// HandshakeComplete has been closed.
  1082  				conn := NewMockQUICConn(mockCtrl)
  1083  
  1084  				done := make(chan struct{})
  1085  				go func() {
  1086  					defer GinkgoRecover()
  1087  					s, err := serv.Accept(context.Background())
  1088  					Expect(err).ToNot(HaveOccurred())
  1089  					Expect(s).To(Equal(conn))
  1090  					close(done)
  1091  				}()
  1092  
  1093  				handshakeChan := make(chan struct{})
  1094  				serv.newConn = func(
  1095  					_ sendConn,
  1096  					runner connRunner,
  1097  					_ protocol.ConnectionID,
  1098  					_ *protocol.ConnectionID,
  1099  					_ protocol.ConnectionID,
  1100  					_ protocol.ConnectionID,
  1101  					_ protocol.ConnectionID,
  1102  					_ ConnectionIDGenerator,
  1103  					_ protocol.StatelessResetToken,
  1104  					_ *Config,
  1105  					_ *tls.Config,
  1106  					_ *handshake.TokenGenerator,
  1107  					_ bool,
  1108  					_ *logging.ConnectionTracer,
  1109  					_ uint64,
  1110  					_ utils.Logger,
  1111  					_ protocol.Version,
  1112  				) quicConn {
  1113  					conn.EXPECT().handlePacket(gomock.Any())
  1114  					conn.EXPECT().HandshakeComplete().Return(handshakeChan)
  1115  					conn.EXPECT().run()
  1116  					conn.EXPECT().Context().Return(context.Background())
  1117  					return conn
  1118  				}
  1119  				phm.EXPECT().Get(gomock.Any())
  1120  				phm.EXPECT().GetStatelessResetToken(gomock.Any())
  1121  				phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true)
  1122  				serv.handleInitialImpl(
  1123  					receivedPacket{buffer: getPacketBuffer()},
  1124  					&wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})},
  1125  				)
        				// Accept must not return while handshakeChan is still open.
  1126  				Consistently(done).ShouldNot(BeClosed())
  1127  				close(handshakeChan) // complete the handshake
  1128  				Eventually(done).Should(BeClosed())
  1129  			})
  1130  		})
  1131  	})
  1132  
  1133  	Context("server accepting connections that haven't completed the handshake", func() {
  1134  		var (
        			// serv is the EarlyListener under test; phm is the mock that replaces its connection handler.
  1135  			serv *EarlyListener
  1136  			phm  *MockPacketHandlerManager
  1137  		)
  1138  
        		// Each spec gets a fresh EarlyListener on the shared mock packet conn,
        		// with the base server's connection handler swapped for a mock.
  1139  		BeforeEach(func() {
  1140  			var err error
  1141  			serv, err = ListenEarly(conn, tlsConf, nil)
  1142  			Expect(err).ToNot(HaveOccurred())
  1143  			phm = NewMockPacketHandlerManager(mockCtrl)
  1144  			serv.baseServer.connHandler = phm
  1145  		})
  1146  
        		// The listener is closed after every spec; the return value of Close is not checked here.
  1147  		AfterEach(func() {
  1148  			serv.Close()
  1149  		})
  1150  
  1151  		It("accepts new connections when they become ready", func() {
        			// Accept on the EarlyListener must only return the connection after the
        			// channel handed out by earlyConnReady has been closed.
  1152  			conn := NewMockQUICConn(mockCtrl)
  1153  
  1154  			done := make(chan struct{})
  1155  			go func() {
  1156  				defer GinkgoRecover()
  1157  				s, err := serv.Accept(context.Background())
  1158  				Expect(err).ToNot(HaveOccurred())
  1159  				Expect(s).To(Equal(conn))
  1160  				close(done)
  1161  			}()
  1162  
  1163  			ready := make(chan struct{})
  1164  			serv.baseServer.newConn = func(
  1165  				_ sendConn,
  1166  				runner connRunner,
  1167  				_ protocol.ConnectionID,
  1168  				_ *protocol.ConnectionID,
  1169  				_ protocol.ConnectionID,
  1170  				_ protocol.ConnectionID,
  1171  				_ protocol.ConnectionID,
  1172  				_ ConnectionIDGenerator,
  1173  				_ protocol.StatelessResetToken,
  1174  				_ *Config,
  1175  				_ *tls.Config,
  1176  				_ *handshake.TokenGenerator,
  1177  				_ bool,
  1178  				_ *logging.ConnectionTracer,
  1179  				_ uint64,
  1180  				_ utils.Logger,
  1181  				_ protocol.Version,
  1182  			) quicConn {
  1183  				conn.EXPECT().handlePacket(gomock.Any())
  1184  				conn.EXPECT().run()
        				// This spec expects earlyConnReady on the mock, not HandshakeComplete.
  1185  				conn.EXPECT().earlyConnReady().Return(ready)
  1186  				conn.EXPECT().Context().Return(context.Background())
  1187  				return conn
  1188  			}
  1189  			phm.EXPECT().Get(gomock.Any())
  1190  			phm.EXPECT().GetStatelessResetToken(gomock.Any())
  1191  			phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true)
  1192  			serv.baseServer.handleInitialImpl(
  1193  				receivedPacket{buffer: getPacketBuffer()},
  1194  				&wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})},
  1195  			)
        			// Accept must not return while `ready` is still open.
  1196  			Consistently(done).ShouldNot(BeClosed())
  1197  			close(ready)
  1198  			Eventually(done).Should(BeClosed())
  1199  		})
  1200  
  1201  		It("rejects new connection attempts if the accept queue is full", func() {
        			// MaxAcceptQueueSize connections that are immediately ready fill the accept
        			// queue (nothing calls Accept); one more attempt is then expected to be
        			// closed with ConnectionRefused.
  1202  			connChan := make(chan *MockQUICConn, 1)
  1203  			var wg sync.WaitGroup // to make sure the test fully completes
  1204  			wg.Add(protocol.MaxAcceptQueueSize + 1)
  1205  			serv.baseServer.newConn = func(
  1206  				_ sendConn,
  1207  				runner connRunner,
  1208  				_ protocol.ConnectionID,
  1209  				_ *protocol.ConnectionID,
  1210  				_ protocol.ConnectionID,
  1211  				_ protocol.ConnectionID,
  1212  				_ protocol.ConnectionID,
  1213  				_ ConnectionIDGenerator,
  1214  				_ protocol.StatelessResetToken,
  1215  				_ *Config,
  1216  				_ *tls.Config,
  1217  				_ *handshake.TokenGenerator,
  1218  				_ bool,
  1219  				_ *logging.ConnectionTracer,
  1220  				_ uint64,
  1221  				_ utils.Logger,
  1222  				_ protocol.Version,
  1223  			) quicConn {
        				// wg is released once per newConn call, when this function returns.
  1224  				defer wg.Done()
        				// earlyConnReady hands out an already-closed channel: the connection is ready at once.
  1225  				ready := make(chan struct{})
  1226  				close(ready)
  1227  				conn := <-connChan
  1228  				conn.EXPECT().handlePacket(gomock.Any())
  1229  				conn.EXPECT().run()
  1230  				conn.EXPECT().earlyConnReady().Return(ready)
  1231  				conn.EXPECT().Context().Return(context.Background())
  1232  				return conn
  1233  			}
  1234  
  1235  			phm.EXPECT().Get(gomock.Any()).AnyTimes()
  1236  			phm.EXPECT().GetStatelessResetToken(gomock.Any()).Times(protocol.MaxAcceptQueueSize)
  1237  			phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true).Times(protocol.MaxAcceptQueueSize)
  1238  			for i := 0; i < protocol.MaxAcceptQueueSize; i++ {
  1239  				conn := NewMockQUICConn(mockCtrl)
  1240  				connChan <- conn
  1241  				serv.baseServer.handlePacket(getInitialWithRandomDestConnID())
  1242  			}
  1243  
        			// Wait until every connection has been placed into the accept queue.
  1244  			Eventually(serv.baseServer.connQueue).Should(HaveLen(protocol.MaxAcceptQueueSize))
  1245  
        			// One more attempt: this mock is additionally expected to be closed with ConnectionRefused.
  1246  			phm.EXPECT().GetStatelessResetToken(gomock.Any())
  1247  			phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true)
  1248  			conn := NewMockQUICConn(mockCtrl)
  1249  			conn.EXPECT().closeWithTransportError(ConnectionRefused)
  1250  			connChan <- conn
  1251  			serv.baseServer.handlePacket(getInitialWithRandomDestConnID())
  1252  			wg.Wait()
  1253  		})
  1254  
  1255  		It("doesn't accept new connections if they were closed in the mean time", func() {
        			// A connection whose context is canceled before anyone calls Accept must
        			// not be handed out: Accept keeps blocking until the listener is closed.
  1256  			p := getInitial(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}))
  1257  			ctx, cancel := context.WithCancel(context.Background())
  1258  			connCreated := make(chan struct{})
  1259  			conn := NewMockQUICConn(mockCtrl)
  1260  			serv.baseServer.newConn = func(
  1261  				_ sendConn,
  1262  				runner connRunner,
  1263  				_ protocol.ConnectionID,
  1264  				_ *protocol.ConnectionID,
  1265  				_ protocol.ConnectionID,
  1266  				_ protocol.ConnectionID,
  1267  				_ protocol.ConnectionID,
  1268  				_ ConnectionIDGenerator,
  1269  				_ protocol.StatelessResetToken,
  1270  				_ *Config,
  1271  				_ *tls.Config,
  1272  				_ *handshake.TokenGenerator,
  1273  				_ bool,
  1274  				_ *logging.ConnectionTracer,
  1275  				_ uint64,
  1276  				_ utils.Logger,
  1277  				_ protocol.Version,
  1278  			) quicConn {
  1279  				conn.EXPECT().handlePacket(p)
  1280  				conn.EXPECT().run()
        				// No return value is configured, so the mock yields the zero value (a nil channel).
  1281  				conn.EXPECT().earlyConnReady()
        				// The connection's context is the one this spec cancels below.
  1282  				conn.EXPECT().Context().Return(ctx)
  1283  				close(connCreated)
  1284  				return conn
  1285  			}
  1286  
  1287  			phm.EXPECT().Get(gomock.Any())
  1288  			phm.EXPECT().GetStatelessResetToken(gomock.Any())
  1289  			phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true)
  1290  			serv.baseServer.handlePacket(p)
  1291  			// make sure there are no Write calls on the packet conn
  1292  			time.Sleep(50 * time.Millisecond)
  1293  			Eventually(connCreated).Should(BeClosed())
        			// Cancel the connection's context, then give the server time to notice.
  1294  			cancel()
  1295  			time.Sleep(scaleDuration(200 * time.Millisecond))
  1296  
  1297  			done := make(chan struct{})
  1298  			go func() {
  1299  				defer GinkgoRecover()
        				// The result is deliberately ignored; only whether Accept returns matters.
  1300  				serv.Accept(context.Background())
  1301  				close(done)
  1302  			}()
  1303  			Consistently(done).ShouldNot(BeClosed())
  1304  
  1305  			// make the go routine return
  1306  			Expect(serv.Close()).To(Succeed())
  1307  			Eventually(done).Should(BeClosed())
  1308  		})
  1309  	})
  1310  
  1311  	Context("0-RTT", func() {
  1312  		var (
        			// tr owns the mock packet conn and the tracer; serv is the base server of
        			// the early listener created from it; phm and tracer are mocks.
  1313  			tr     *Transport
  1314  			serv   *baseServer
  1315  			phm    *MockPacketHandlerManager
  1316  			tracer *mocklogging.MockTracer
  1317  		)
  1318  
        		// Each spec gets a fresh Transport with a mock tracer, an early listener
        		// on it, and a mock connection handler installed on the base server.
  1319  		BeforeEach(func() {
  1320  			var t *logging.Tracer
  1321  			t, tracer = mocklogging.NewMockTracer(mockCtrl)
  1322  			tr = &Transport{Conn: conn, Tracer: t}
  1323  			ln, err := tr.ListenEarly(tlsConf, nil)
  1324  			Expect(err).ToNot(HaveOccurred())
  1325  			phm = NewMockPacketHandlerManager(mockCtrl)
  1326  			serv = ln.baseServer
  1327  			serv.connHandler = phm
  1328  		})
  1329  
        		// Closing the transport is expected to close the tracer as well.
  1330  		AfterEach(func() {
  1331  			tracer.EXPECT().Close()
  1332  			Expect(tr.Close()).To(Succeed())
  1333  		})
  1334  
  1335  		It("passes packets to existing connections", func() {
        			// A 0-RTT packet whose destination connection ID is already known to the
        			// handler manager must be forwarded to that packet handler.
  1336  			connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
  1337  			p := getPacket(&wire.Header{
  1338  				Type:             protocol.PacketType0RTT,
  1339  				DestConnectionID: connID,
  1340  				Version:          serv.config.Versions[0],
  1341  			}, make([]byte, 100))
        			// `conn` here is a mock packet handler, shadowing the outer mock packet conn.
  1342  			conn := NewMockPacketHandler(mockCtrl)
  1343  			phm.EXPECT().Get(connID).Return(conn, true)
  1344  			handled := make(chan struct{})
  1345  			conn.EXPECT().handlePacket(p).Do(func(receivedPacket) { close(handled) })
  1346  			serv.handlePacket(p)
  1347  			Eventually(handled).Should(BeClosed())
  1348  		})
  1349  
  1350  		It("queues 0-RTT packets, up to Max0RTTQueueSize", func() {
        			// Max0RTTQueueLen 0-RTT packets for an unknown connection ID are sent
        			// first, one more is expected to be dropped, and the Initial that follows
        			// must see the earlier packets delivered to the new connection in order.
  1351  			connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
  1352  
  1353  			var zeroRTTPackets []receivedPacket
  1354  
  1355  			for i := 0; i < protocol.Max0RTTQueueLen; i++ {
        				// Payload sizes differ (100+i), so the packets are distinguishable.
  1356  				p := getPacket(&wire.Header{
  1357  					Type:             protocol.PacketType0RTT,
  1358  					DestConnectionID: connID,
  1359  					Version:          serv.config.Versions[0],
  1360  				}, make([]byte, 100+i))
        				// No return value configured: the mock yields zero values for Get.
  1361  				phm.EXPECT().Get(connID)
  1362  				serv.handlePacket(p)
  1363  				zeroRTTPackets = append(zeroRTTPackets, p)
  1364  			}
  1365  
  1366  			// send one more packet, this one should be dropped
  1367  			p := getPacket(&wire.Header{
  1368  				Type:             protocol.PacketType0RTT,
  1369  				DestConnectionID: connID,
  1370  				Version:          serv.config.Versions[0],
  1371  			}, make([]byte, 200))
  1372  			phm.EXPECT().Get(connID)
  1373  			tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention)
  1374  			serv.handlePacket(p)
  1375  
  1376  			initial := getPacket(&wire.Header{
  1377  				Type:             protocol.PacketTypeInitial,
  1378  				DestConnectionID: connID,
  1379  				Version:          serv.config.Versions[0],
  1380  			}, make([]byte, protocol.MinInitialPacketSize))
  1381  			called := make(chan struct{})
  1382  			serv.newConn = func(
  1383  				_ sendConn,
  1384  				_ connRunner,
  1385  				_ protocol.ConnectionID,
  1386  				_ *protocol.ConnectionID,
  1387  				_ protocol.ConnectionID,
  1388  				_ protocol.ConnectionID,
  1389  				_ protocol.ConnectionID,
  1390  				_ ConnectionIDGenerator,
  1391  				_ protocol.StatelessResetToken,
  1392  				_ *Config,
  1393  				_ *tls.Config,
  1394  				_ *handshake.TokenGenerator,
  1395  				_ bool,
  1396  				_ *logging.ConnectionTracer,
  1397  				_ uint64,
  1398  				_ utils.Logger,
  1399  				_ protocol.Version,
  1400  			) quicConn {
  1401  				conn := NewMockQUICConn(mockCtrl)
        				// Expected delivery order: the Initial first, then the 0-RTT packets in
        				// the order they were sent.
  1402  				var calls []any
  1403  				calls = append(calls, conn.EXPECT().handlePacket(initial))
  1404  				for _, p := range zeroRTTPackets {
  1405  					calls = append(calls, conn.EXPECT().handlePacket(p))
  1406  				}
  1407  				gomock.InOrder(calls...)
  1408  				conn.EXPECT().run()
  1409  				conn.EXPECT().earlyConnReady()
  1410  				conn.EXPECT().Context().Return(context.Background())
  1411  				close(called)
  1412  				// shutdown
  1413  				conn.EXPECT().closeWithTransportError(gomock.Any())
  1414  				return conn
  1415  			}
  1416  
  1417  			phm.EXPECT().Get(connID)
  1418  			phm.EXPECT().GetStatelessResetToken(gomock.Any())
  1419  			phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true)
  1420  			serv.handlePacket(initial)
  1421  			Eventually(called).Should(BeClosed())
  1422  		})
  1423  
  1424  		It("limits the number of queues", func() {
        			// Verifies that a 0-RTT packet for one more connection ID is reported
        			// to the tracer as dropped (reason PacketDropDOSPrevention) after
        			// protocol.Max0RTTQueues 0-RTT packets for random connection IDs
        			// have already been handled.
        
        			// Phase 1: handle one 0-RTT packet for each of Max0RTTQueues random
        			// 16-byte connection IDs. No DroppedPacket expectation is registered
        			// for these packets.
  1425  			for i := 0; i < protocol.Max0RTTQueues; i++ {
  1426  				b := make([]byte, 16)
        				// The error returned by rand.Read is not checked here.
  1427  				rand.Read(b)
  1428  				connID := protocol.ParseConnectionID(b)
        				// The payload length (100+i) differs on every iteration.
  1429  				p := getPacket(&wire.Header{
  1430  					Type:             protocol.PacketType0RTT,
  1431  					DestConnectionID: connID,
  1432  					Version:          serv.config.Versions[0],
  1433  				}, make([]byte, 100+i))
        				// The expectation is registered before the packet is handled.
  1434  				phm.EXPECT().Get(connID)
  1435  				serv.handlePacket(p)
  1436  			}
  1437  
        			// Phase 2: one more 0-RTT packet, for a fixed 8-byte connection ID.
  1438  			connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
  1439  			p := getPacket(&wire.Header{
  1440  				Type:             protocol.PacketType0RTT,
  1441  				DestConnectionID: connID,
  1442  				Version:          serv.config.Versions[0],
  1443  			}, make([]byte, 200))
  1444  			phm.EXPECT().Get(connID)
        			// The tracer callback closes `dropped`, so the spec can wait for the
        			// drop to be reported instead of assuming it happens synchronously
        			// inside handlePacket.
  1445  			dropped := make(chan struct{})
  1446  			tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) {
  1447  				close(dropped)
  1448  			})
  1449  			serv.handlePacket(p)
  1450  			Eventually(dropped).Should(BeClosed())
  1451  		})
  1452  
  1453  		It("drops queues after a while", func() {
        			// Verifies that queued 0-RTT packets are reported to the tracer as
        			// dropped once a later packet arrives whose receive time is more than
        			// protocol.Max0RTTQueueingDuration after theirs. Receive times are set
        			// by hand on each packet, so the spec does not wait for real time to
        			// pass (apart from the short Consistently windows).
  1454  			now := time.Now()
  1455  
        			// First packet: received at `now`.
  1456  			connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
  1457  			p := getPacket(&wire.Header{
  1458  				Type:             protocol.PacketType0RTT,
  1459  				DestConnectionID: connID,
  1460  				Version:          serv.config.Versions[0],
  1461  			}, make([]byte, 200))
  1462  			p.rcvTime = now
  1463  
        			// Second packet: different connection ID, received half a queueing
        			// duration after the first. Its payload (300 bytes) is larger than the
        			// first packet's (200 bytes).
  1464  			connID2 := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 9})
  1465  			p2Time := now.Add(protocol.Max0RTTQueueingDuration / 2)
  1466  			p2 := getPacket(&wire.Header{
  1467  				Type:             protocol.PacketType0RTT,
  1468  				DestConnectionID: connID2,
  1469  				Version:          serv.config.Versions[0],
  1470  			}, make([]byte, 300))
  1471  			p2.rcvTime = p2Time // doesn't trigger the cleanup of the first packet
  1472  
        			// Each channel is closed by the matching DroppedPacket callback below.
  1473  			dropped1 := make(chan struct{})
  1474  			dropped2 := make(chan struct{})
  1475  			// need to register the call before handling the packet to avoid race condition
        			// InOrder additionally requires that p is reported as dropped before p2.
  1476  			gomock.InOrder(
  1477  				tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) {
  1478  					close(dropped1)
  1479  				}),
  1480  				tracer.EXPECT().DroppedPacket(p2.remoteAddr, logging.PacketType0RTT, p2.Size(), logging.PacketDropDOSPrevention).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) {
  1481  					close(dropped2)
  1482  				}),
  1483  			)
  1484  
  1485  			phm.EXPECT().Get(connID)
  1486  			serv.handlePacket(p)
  1487  
  1488  			// There's no cleanup goroutine.
  1489  			// Cleanup is triggered when new packets are received.
  1490  
  1491  			phm.EXPECT().Get(connID2)
  1492  			serv.handlePacket(p2)
  1493  			// make sure no cleanup is executed
  1494  			Consistently(dropped1, 50*time.Millisecond).ShouldNot(BeClosed())
  1495  
  1496  			// There's no cleanup goroutine.
  1497  			// Cleanup is triggered when new packets are received.
        			// The third packet's receive time is 1ns more than one queueing
        			// duration after the first packet's, but only about half a queueing
        			// duration after the second packet's. Only the first packet is
        			// therefore expected to be dropped at this point.
  1498  			connID3 := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 0})
  1499  			p3 := getPacket(&wire.Header{
  1500  				Type:             protocol.PacketType0RTT,
  1501  				DestConnectionID: connID3,
  1502  				Version:          serv.config.Versions[0],
  1503  			}, make([]byte, 200))
  1504  			p3.rcvTime = now.Add(protocol.Max0RTTQueueingDuration + time.Nanosecond) // now triggers the cleanup
  1505  			phm.EXPECT().Get(connID3)
  1506  			serv.handlePacket(p3)
  1507  			Eventually(dropped1).Should(BeClosed())
  1508  			Consistently(dropped2, 50*time.Millisecond).ShouldNot(BeClosed())
  1509  
  1510  			// make sure the second packet is also cleaned up
        			// The fourth packet's receive time is 1ns more than one queueing
        			// duration after the second packet's receive time (p2Time).
  1511  			connID4 := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 1})
  1512  			p4 := getPacket(&wire.Header{
  1513  				Type:             protocol.PacketType0RTT,
  1514  				DestConnectionID: connID4,
  1515  				Version:          serv.config.Versions[0],
  1516  			}, make([]byte, 200))
  1517  			p4.rcvTime = p2Time.Add(protocol.Max0RTTQueueingDuration + time.Nanosecond) // now triggers the cleanup
  1518  			phm.EXPECT().Get(connID4)
  1519  			serv.handlePacket(p4)
  1520  			Eventually(dropped2).Should(BeClosed())
  1521  		})
  1522  	})
  1523  })