github.com/TugasAkhir-QUIC/quic-go@v0.0.2-0.20240215011318-d20e25a9054c/server_test.go (about)

     1  package quic
     2  
     3  import (
     4  	"context"
     5  	"crypto/rand"
     6  	"crypto/tls"
     7  	"errors"
     8  	"net"
     9  	"sync"
    10  	"sync/atomic"
    11  	"time"
    12  
    13  	"github.com/TugasAkhir-QUIC/quic-go/internal/handshake"
    14  	mocklogging "github.com/TugasAkhir-QUIC/quic-go/internal/mocks/logging"
    15  	"github.com/TugasAkhir-QUIC/quic-go/internal/protocol"
    16  	"github.com/TugasAkhir-QUIC/quic-go/internal/qerr"
    17  	"github.com/TugasAkhir-QUIC/quic-go/internal/testdata"
    18  	"github.com/TugasAkhir-QUIC/quic-go/internal/utils"
    19  	"github.com/TugasAkhir-QUIC/quic-go/internal/wire"
    20  	"github.com/TugasAkhir-QUIC/quic-go/logging"
    21  
    22  	. "github.com/onsi/ginkgo/v2"
    23  	. "github.com/onsi/gomega"
    24  	"go.uber.org/mock/gomock"
    25  )
    26  
    27  var _ = Describe("Server", func() {
    28  	var (
    29  		conn    *MockPacketConn
    30  		tlsConf *tls.Config
    31  	)
    32  
	// getPacket builds a long-header packet as a client would send it: the
	// header (with packet number 0x42 encoded in 4 bytes) is followed by p, the
	// payload is sealed with the client Initial AEAD derived from
	// hdr.DestConnectionID, and header protection is applied. The returned
	// packet appears to come from 4.5.6.7:456.
	// Side effect: hdr.Length is overwritten on the caller's header.
	getPacket := func(hdr *wire.Header, p []byte) receivedPacket {
		buf := getPacketBuffer()
		// Length = 4-byte packet number + payload + 16-byte AEAD tag.
		hdr.Length = 4 + protocol.ByteCount(len(p)) + 16
		var err error
		buf.Data, err = (&wire.ExtendedHeader{
			Header:          *hdr,
			PacketNumber:    0x42,
			PacketNumberLen: protocol.PacketNumberLen4,
		}).Append(buf.Data, protocol.Version1)
		Expect(err).ToNot(HaveOccurred())
		// n is the length of the unprotected header; the payload starts at data[n].
		n := len(buf.Data)
		buf.Data = append(buf.Data, p...)
		data := buf.Data
		sealer, _ := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveClient, hdr.Version)
		// Seal in place with the header as associated data. The returned slice is
		// discarded; instead data is re-sliced below to include the 16-byte tag
		// (assumes the packet buffer has at least 16 bytes of spare capacity).
		_ = sealer.Seal(data[n:n], data[n:], 0x42, data[:n])
		data = data[:len(data)+16]
		// Header protection uses the 16 bytes right after the 4-byte packet
		// number as the sample and masks the first byte and the packet number.
		sealer.EncryptHeader(data[n:n+16], &data[0], data[n-4:n])
		return receivedPacket{
			remoteAddr: &net.UDPAddr{IP: net.IPv4(4, 5, 6, 7), Port: 456},
			data:       data,
			buffer:     buf,
		}
	}
    56  
	// getInitial builds a protocol.Version1 Initial packet addressed to
	// destConnID with the fixed source connection ID {5, 4, 3, 2, 1} and a
	// zeroed MinInitialPacketSize-byte payload, appearing to come from 1.2.3.4:42.
	getInitial := func(destConnID protocol.ConnectionID) receivedPacket {
		senderAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 42}
		hdr := &wire.Header{
			Type:             protocol.PacketTypeInitial,
			SrcConnectionID:  protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
			DestConnectionID: destConnID,
			Version:          protocol.Version1,
		}
		p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
		// NOTE(review): this replaces the buffer allocated by getPacket (which
		// still backs p.data) with a fresh one; the reason is not visible here —
		// confirm before relying on p.buffer owning p.data's storage.
		p.buffer = getPacketBuffer()
		p.remoteAddr = senderAddr
		return p
	}
    70  
    71  	getInitialWithRandomDestConnID := func() receivedPacket {
    72  		b := make([]byte, 10)
    73  		_, err := rand.Read(b)
    74  		Expect(err).ToNot(HaveOccurred())
    75  
    76  		return getInitial(protocol.ParseConnectionID(b))
    77  	}
    78  
    79  	parseHeader := func(data []byte) *wire.Header {
    80  		hdr, _, _, err := wire.ParsePacket(data)
    81  		Expect(err).ToNot(HaveOccurred())
    82  		return hdr
    83  	}
    84  
    85  	checkConnectionCloseError := func(b []byte, origHdr *wire.Header, errorCode qerr.TransportErrorCode) {
    86  		replyHdr := parseHeader(b)
    87  		Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial))
    88  		Expect(replyHdr.SrcConnectionID).To(Equal(origHdr.DestConnectionID))
    89  		Expect(replyHdr.DestConnectionID).To(Equal(origHdr.SrcConnectionID))
    90  		_, opener := handshake.NewInitialAEAD(origHdr.DestConnectionID, protocol.PerspectiveClient, replyHdr.Version)
    91  		extHdr, err := unpackLongHeader(opener, replyHdr, b, origHdr.Version)
    92  		Expect(err).ToNot(HaveOccurred())
    93  		data, err := opener.Open(nil, b[extHdr.ParsedLen():], extHdr.PacketNumber, b[:extHdr.ParsedLen()])
    94  		Expect(err).ToNot(HaveOccurred())
    95  		_, f, err := wire.NewFrameParser(false).ParseNext(data, protocol.EncryptionInitial, origHdr.Version)
    96  		Expect(err).ToNot(HaveOccurred())
    97  		Expect(f).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{}))
    98  		ccf := f.(*wire.ConnectionCloseFrame)
    99  		Expect(ccf.IsApplicationError).To(BeFalse())
   100  		Expect(ccf.ErrorCode).To(BeEquivalentTo(errorCode))
   101  		Expect(ccf.ReasonPhrase).To(BeEmpty())
   102  	}
   103  
	// Before each spec: create a fresh mock packet conn and the test TLS config.
	BeforeEach(func() {
		conn = NewMockPacketConn(mockCtrl)
		conn.EXPECT().LocalAddr().Return(&net.UDPAddr{}).AnyTimes()
		// At most one ReadFrom call is allowed; it blocks until wait is closed
		// and then returns an error.
		wait := make(chan struct{})
		conn.EXPECT().ReadFrom(gomock.Any()).DoAndReturn(func(_ []byte) (int, net.Addr, error) {
			<-wait
			return 0, nil, errors.New("done")
		}).MaxTimes(1)
		// The first SetReadDeadline call unblocks ReadFrom and registers the
		// expectation for a follow-up SetReadDeadline(time.Time{}) call
		// (presumably how the transport stops and resets its read loop on close).
		conn.EXPECT().SetReadDeadline(gomock.Any()).Do(func(time.Time) error {
			close(wait)
			conn.EXPECT().SetReadDeadline(time.Time{})
			return nil
		}).MaxTimes(1)
		tlsConf = testdata.GetTLSConfig()
		tlsConf.NextProtos = []string{"proto1"}
	})
   120  
   121  	It("errors when no tls.Config is given", func() {
   122  		_, err := ListenAddr("localhost:0", nil, nil)
   123  		Expect(err).To(HaveOccurred())
   124  		Expect(err.Error()).To(ContainSubstring("quic: tls.Config not set"))
   125  	})
   126  
   127  	It("errors when the Config contains an invalid version", func() {
   128  		version := protocol.Version(0x1234)
   129  		_, err := Listen(nil, tlsConf, &Config{Versions: []protocol.Version{version}})
   130  		Expect(err).To(MatchError("invalid QUIC version: 0x1234"))
   131  	})
   132  
   133  	It("fills in default values if options are not set in the Config", func() {
   134  		ln, err := Listen(conn, tlsConf, &Config{})
   135  		Expect(err).ToNot(HaveOccurred())
   136  		server := ln.baseServer
   137  		Expect(server.config.Versions).To(Equal(protocol.SupportedVersions))
   138  		Expect(server.config.HandshakeIdleTimeout).To(Equal(protocol.DefaultHandshakeIdleTimeout))
   139  		Expect(server.config.MaxIdleTimeout).To(Equal(protocol.DefaultIdleTimeout))
   140  		Expect(server.config.KeepAlivePeriod).To(BeZero())
   141  		// stop the listener
   142  		Expect(ln.Close()).To(Succeed())
   143  	})
   144  
   145  	It("setups with the right values", func() {
   146  		supportedVersions := []protocol.Version{protocol.Version1}
   147  		config := Config{
   148  			Versions:             supportedVersions,
   149  			HandshakeIdleTimeout: 1337 * time.Hour,
   150  			MaxIdleTimeout:       42 * time.Minute,
   151  			KeepAlivePeriod:      5 * time.Second,
   152  		}
   153  		ln, err := Listen(conn, tlsConf, &config)
   154  		Expect(err).ToNot(HaveOccurred())
   155  		server := ln.baseServer
   156  		Expect(server.connHandler).ToNot(BeNil())
   157  		Expect(server.config.Versions).To(Equal(supportedVersions))
   158  		Expect(server.config.HandshakeIdleTimeout).To(Equal(1337 * time.Hour))
   159  		Expect(server.config.MaxIdleTimeout).To(Equal(42 * time.Minute))
   160  		Expect(server.config.KeepAlivePeriod).To(Equal(5 * time.Second))
   161  		// stop the listener
   162  		Expect(ln.Close()).To(Succeed())
   163  	})
   164  
   165  	It("listens on a given address", func() {
   166  		addr := "127.0.0.1:13579"
   167  		ln, err := ListenAddr(addr, tlsConf, &Config{})
   168  		Expect(err).ToNot(HaveOccurred())
   169  		Expect(ln.Addr().String()).To(Equal(addr))
   170  		// stop the listener
   171  		Expect(ln.Close()).To(Succeed())
   172  	})
   173  
   174  	It("errors if given an invalid address", func() {
   175  		addr := "127.0.0.1"
   176  		_, err := ListenAddr(addr, tlsConf, &Config{})
   177  		Expect(err).To(BeAssignableToTypeOf(&net.AddrError{}))
   178  	})
   179  
   180  	It("errors if given an invalid address", func() {
   181  		addr := "1.1.1.1:1111"
   182  		_, err := ListenAddr(addr, tlsConf, &Config{})
   183  		Expect(err).To(BeAssignableToTypeOf(&net.OpError{}))
   184  	})
   185  
   186  	Context("server accepting connections that completed the handshake", func() {
   187  		var (
   188  			tr     *Transport
   189  			serv   *baseServer
   190  			phm    *MockPacketHandlerManager
   191  			tracer *mocklogging.MockTracer
   192  		)
   193  
   194  		BeforeEach(func() {
   195  			var t *logging.Tracer
   196  			t, tracer = mocklogging.NewMockTracer(mockCtrl)
   197  			tr = &Transport{Conn: conn, Tracer: t}
   198  			ln, err := tr.Listen(tlsConf, nil)
   199  			Expect(err).ToNot(HaveOccurred())
   200  			serv = ln.baseServer
   201  			phm = NewMockPacketHandlerManager(mockCtrl)
   202  			serv.connHandler = phm
   203  		})
   204  
   205  		AfterEach(func() {
   206  			tracer.EXPECT().Close()
   207  			tr.Close()
   208  		})
   209  
   210  		Context("handling packets", func() {
   211  			It("drops Initial packets with a too short connection ID", func() {
   212  				p := getPacket(&wire.Header{
   213  					Type:             protocol.PacketTypeInitial,
   214  					DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}),
   215  					Version:          serv.config.Versions[0],
   216  				}, nil)
   217  				tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket)
   218  				serv.handlePacket(p)
   219  				// make sure there are no Write calls on the packet conn
   220  				time.Sleep(50 * time.Millisecond)
   221  			})
   222  
   223  			It("drops too small Initial", func() {
   224  				p := getPacket(&wire.Header{
   225  					Type:             protocol.PacketTypeInitial,
   226  					DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}),
   227  					Version:          serv.config.Versions[0],
   228  				}, make([]byte, protocol.MinInitialPacketSize-100))
   229  				tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket)
   230  				serv.handlePacket(p)
   231  				// make sure there are no Write calls on the packet conn
   232  				time.Sleep(50 * time.Millisecond)
   233  			})
   234  
   235  			It("drops non-Initial packets", func() {
   236  				p := getPacket(&wire.Header{
   237  					Type:    protocol.PacketTypeHandshake,
   238  					Version: serv.config.Versions[0],
   239  				}, []byte("invalid"))
   240  				tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketTypeHandshake, p.Size(), logging.PacketDropUnexpectedPacket)
   241  				serv.handlePacket(p)
   242  				// make sure there are no Write calls on the packet conn
   243  				time.Sleep(50 * time.Millisecond)
   244  			})
   245  
   246  			It("passes packets to existing connections", func() {
   247  				connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
   248  				p := getPacket(&wire.Header{
   249  					Type:             protocol.PacketTypeInitial,
   250  					DestConnectionID: connID,
   251  					Version:          serv.config.Versions[0],
   252  				}, make([]byte, protocol.MinInitialPacketSize))
   253  				conn := NewMockPacketHandler(mockCtrl)
   254  				phm.EXPECT().Get(connID).Return(conn, true)
   255  				handled := make(chan struct{})
   256  				conn.EXPECT().handlePacket(p).Do(func(receivedPacket) { close(handled) })
   257  				serv.handlePacket(p)
   258  				Eventually(handled).Should(BeClosed())
   259  			})
   260  
			// A client presenting a valid Retry token gets a connection right away.
			// newConn must receive the original destination and retry source
			// connection IDs that were put into the token, plus a server-generated
			// source connection ID, and the connection is registered under the
			// client's destination connection ID.
			It("creates a connection when the token is accepted", func() {
				// 0 unvalidated handshakes allowed: a token is required.
				serv.maxNumHandshakesUnvalidated = 0
				raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
				// The token is generated for raddr; the packet below is sent from raddr.
				retryToken, err := serv.tokenGenerator.NewRetryToken(
					raddr,
					protocol.ParseConnectionID([]byte{0xde, 0xad, 0xc0, 0xde}),
					protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}),
				)
				Expect(err).ToNot(HaveOccurred())
				connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})
				hdr := &wire.Header{
					Type:             protocol.PacketTypeInitial,
					SrcConnectionID:  protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
					DestConnectionID: connID,
					Version:          protocol.Version1,
					Token:            retryToken,
				}
				p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
				p.remoteAddr = raddr
				run := make(chan struct{})
				// random stateless reset token, returned by the mocked GetStatelessResetToken
				var token protocol.StatelessResetToken
				rand.Read(token[:])

				var newConnID protocol.ConnectionID
				// Note: this shadows the outer packet conn within this spec.
				conn := NewMockQUICConn(mockCtrl)
				serv.newConn = func(
					_ sendConn,
					_ connRunner,
					origDestConnID protocol.ConnectionID,
					retrySrcConnID *protocol.ConnectionID,
					clientDestConnID protocol.ConnectionID,
					destConnID protocol.ConnectionID,
					srcConnID protocol.ConnectionID,
					_ ConnectionIDGenerator,
					tokenP protocol.StatelessResetToken,
					_ *Config,
					_ *tls.Config,
					_ *handshake.TokenGenerator,
					_ bool,
					_ *logging.ConnectionTracer,
					_ uint64,
					_ utils.Logger,
					_ protocol.Version,
				) quicConn {
					// the connection IDs stored in the Retry token
					Expect(origDestConnID).To(Equal(protocol.ParseConnectionID([]byte{0xde, 0xad, 0xc0, 0xde})))
					Expect(*retrySrcConnID).To(Equal(protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad})))
					Expect(clientDestConnID).To(Equal(hdr.DestConnectionID))
					Expect(destConnID).To(Equal(hdr.SrcConnectionID))
					// make sure we're using a server-generated connection ID
					Expect(srcConnID).ToNot(Equal(hdr.DestConnectionID))
					Expect(srcConnID).ToNot(Equal(hdr.SrcConnectionID))
					newConnID = srcConnID
					Expect(tokenP).To(Equal(token))
					// the new connection receives the packet and is started
					conn.EXPECT().handlePacket(p)
					conn.EXPECT().run().Do(func() error { close(run); return nil })
					conn.EXPECT().Context().Return(context.Background())
					conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
					return conn
				}
				// no existing connection for connID (zero-value return)
				phm.EXPECT().Get(connID)
				phm.EXPECT().GetStatelessResetToken(gomock.Any()).Return(token)
				// registered under the client's connID and the server-generated one
				phm.EXPECT().AddWithConnID(connID, gomock.Any(), gomock.Any()).DoAndReturn(func(_, cid protocol.ConnectionID, h packetHandler) bool {
					Expect(cid).To(Equal(newConnID))
					return true
				})

				done := make(chan struct{})
				go func() {
					defer GinkgoRecover()
					serv.handlePacket(p)
					// the Handshake packet is written by the connection.
					// Make sure there are no Write calls on the packet conn.
					time.Sleep(50 * time.Millisecond)
					close(done)
				}()
				// wait until the new connection has been started, then for the
				// no-Write-calls grace period to elapse
				Eventually(run).Should(BeClosed())
				Eventually(done).Should(BeClosed())
				// shutdown: expected once the transport is closed in AfterEach
				conn.EXPECT().closeWithTransportError(gomock.Any())
			})
   342  
   343  			It("sends a Version Negotiation Packet for unsupported versions", func() {
   344  				srcConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5})
   345  				destConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6})
   346  				packet := getPacket(&wire.Header{
   347  					Type:             protocol.PacketTypeHandshake,
   348  					SrcConnectionID:  srcConnID,
   349  					DestConnectionID: destConnID,
   350  					Version:          0x42,
   351  				}, make([]byte, protocol.MinUnknownVersionPacketSize))
   352  				raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
   353  				packet.remoteAddr = raddr
   354  				tracer.EXPECT().SentVersionNegotiationPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, src, dest protocol.ArbitraryLenConnectionID, _ []protocol.Version) {
   355  					Expect(src).To(Equal(protocol.ArbitraryLenConnectionID(destConnID.Bytes())))
   356  					Expect(dest).To(Equal(protocol.ArbitraryLenConnectionID(srcConnID.Bytes())))
   357  				})
   358  				done := make(chan struct{})
   359  				conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
   360  					defer close(done)
   361  					Expect(wire.IsVersionNegotiationPacket(b)).To(BeTrue())
   362  					dest, src, versions, err := wire.ParseVersionNegotiationPacket(b)
   363  					Expect(err).ToNot(HaveOccurred())
   364  					Expect(dest).To(Equal(protocol.ArbitraryLenConnectionID(srcConnID.Bytes())))
   365  					Expect(src).To(Equal(protocol.ArbitraryLenConnectionID(destConnID.Bytes())))
   366  					Expect(versions).ToNot(ContainElement(protocol.Version(0x42)))
   367  					return len(b), nil
   368  				})
   369  				serv.handlePacket(packet)
   370  				Eventually(done).Should(BeClosed())
   371  			})
   372  
   373  			It("doesn't send a Version Negotiation packets if sending them is disabled", func() {
   374  				serv.disableVersionNegotiation = true
   375  				srcConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5})
   376  				destConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6})
   377  				packet := getPacket(&wire.Header{
   378  					Type:             protocol.PacketTypeHandshake,
   379  					SrcConnectionID:  srcConnID,
   380  					DestConnectionID: destConnID,
   381  					Version:          0x42,
   382  				}, make([]byte, protocol.MinUnknownVersionPacketSize))
   383  				raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
   384  				packet.remoteAddr = raddr
   385  				done := make(chan struct{})
   386  				serv.handlePacket(packet)
   387  				Consistently(done, 50*time.Millisecond).ShouldNot(BeClosed())
   388  			})
   389  
   390  			It("ignores Version Negotiation packets", func() {
   391  				data := wire.ComposeVersionNegotiation(
   392  					protocol.ArbitraryLenConnectionID{1, 2, 3, 4},
   393  					protocol.ArbitraryLenConnectionID{4, 3, 2, 1},
   394  					[]protocol.Version{1, 2, 3},
   395  				)
   396  				raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
   397  				done := make(chan struct{})
   398  				tracer.EXPECT().DroppedPacket(raddr, logging.PacketTypeVersionNegotiation, protocol.ByteCount(len(data)), logging.PacketDropUnexpectedPacket).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) {
   399  					close(done)
   400  				})
   401  				serv.handlePacket(receivedPacket{
   402  					remoteAddr: raddr,
   403  					data:       data,
   404  					buffer:     getPacketBuffer(),
   405  				})
   406  				Eventually(done).Should(BeClosed())
   407  				// make sure no other packet is sent
   408  				time.Sleep(scaleDuration(20 * time.Millisecond))
   409  			})
   410  
   411  			It("doesn't send a Version Negotiation Packet for unsupported versions, if the packet is too small", func() {
   412  				srcConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5})
   413  				destConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6})
   414  				p := getPacket(&wire.Header{
   415  					Type:             protocol.PacketTypeHandshake,
   416  					SrcConnectionID:  srcConnID,
   417  					DestConnectionID: destConnID,
   418  					Version:          0x42,
   419  				}, make([]byte, protocol.MinUnknownVersionPacketSize-50))
   420  				Expect(p.Size()).To(BeNumerically("<", protocol.MinUnknownVersionPacketSize))
   421  				raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
   422  				p.remoteAddr = raddr
   423  				done := make(chan struct{})
   424  				tracer.EXPECT().DroppedPacket(raddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) {
   425  					close(done)
   426  				})
   427  				serv.handlePacket(p)
   428  				Eventually(done).Should(BeClosed())
   429  				// make sure no other packet is sent
   430  				time.Sleep(scaleDuration(20 * time.Millisecond))
   431  			})
   432  
   433  			It("replies with a Retry packet, if a token is required", func() {
   434  				connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})
   435  				serv.maxNumHandshakesUnvalidated = 0
   436  				hdr := &wire.Header{
   437  					Type:             protocol.PacketTypeInitial,
   438  					SrcConnectionID:  protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
   439  					DestConnectionID: connID,
   440  					Version:          protocol.Version1,
   441  				}
   442  				packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
   443  				raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
   444  				packet.remoteAddr = raddr
   445  				tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), nil).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, _ []logging.Frame) {
   446  					Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry))
   447  					Expect(replyHdr.SrcConnectionID).ToNot(Equal(hdr.DestConnectionID))
   448  					Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))
   449  					Expect(replyHdr.Token).ToNot(BeEmpty())
   450  				})
   451  				done := make(chan struct{})
   452  				conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
   453  					defer close(done)
   454  					replyHdr := parseHeader(b)
   455  					Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry))
   456  					Expect(replyHdr.SrcConnectionID).ToNot(Equal(hdr.DestConnectionID))
   457  					Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))
   458  					Expect(replyHdr.Token).ToNot(BeEmpty())
   459  					Expect(b[len(b)-16:]).To(Equal(handshake.GetRetryIntegrityTag(b[:len(b)-16], hdr.DestConnectionID, hdr.Version)[:]))
   460  					return len(b), nil
   461  				})
   462  				phm.EXPECT().Get(connID)
   463  				serv.handlePacket(packet)
   464  				Eventually(done).Should(BeClosed())
   465  			})
   466  
			// Without a token requirement, the first Initial creates a connection
			// directly: no retry source connection ID, the client's destination
			// connection ID is also the original one, and the lookup, reset-token
			// and registration calls on the handler manager happen in that order.
			It("creates a connection, if no token is required", func() {
				connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})
				hdr := &wire.Header{
					Type:             protocol.PacketTypeInitial,
					SrcConnectionID:  protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
					DestConnectionID: connID,
					Version:          protocol.Version1,
				}
				p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
				run := make(chan struct{})
				// random stateless reset token, returned by the mocked GetStatelessResetToken
				var token protocol.StatelessResetToken
				rand.Read(token[:])

				var newConnID protocol.ConnectionID
				// Note: this shadows the outer packet conn within this spec.
				conn := NewMockQUICConn(mockCtrl)
				serv.newConn = func(
					_ sendConn,
					_ connRunner,
					origDestConnID protocol.ConnectionID,
					retrySrcConnID *protocol.ConnectionID,
					clientDestConnID protocol.ConnectionID,
					destConnID protocol.ConnectionID,
					srcConnID protocol.ConnectionID,
					_ ConnectionIDGenerator,
					tokenP protocol.StatelessResetToken,
					_ *Config,
					_ *tls.Config,
					_ *handshake.TokenGenerator,
					_ bool,
					_ *logging.ConnectionTracer,
					_ uint64,
					_ utils.Logger,
					_ protocol.Version,
				) quicConn {
					Expect(origDestConnID).To(Equal(hdr.DestConnectionID))
					// no Retry was involved
					Expect(retrySrcConnID).To(BeNil())
					Expect(clientDestConnID).To(Equal(hdr.DestConnectionID))
					Expect(destConnID).To(Equal(hdr.SrcConnectionID))
					// make sure we're using a server-generated connection ID
					Expect(srcConnID).ToNot(Equal(hdr.DestConnectionID))
					Expect(srcConnID).ToNot(Equal(hdr.SrcConnectionID))
					newConnID = srcConnID
					Expect(tokenP).To(Equal(token))
					// the new connection receives the packet and is started
					conn.EXPECT().handlePacket(p)
					conn.EXPECT().run().Do(func() error { close(run); return nil })
					conn.EXPECT().Context().Return(context.Background())
					conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
					return conn
				}
				gomock.InOrder(
					// no existing connection for connID (zero-value return)
					phm.EXPECT().Get(connID),
					phm.EXPECT().GetStatelessResetToken(gomock.Any()).Return(token),
					// registered under the client's connID and the server-generated one
					phm.EXPECT().AddWithConnID(connID, gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, h packetHandler) bool {
						Expect(c).To(Equal(newConnID))
						return true
					}),
				)

				done := make(chan struct{})
				go func() {
					defer GinkgoRecover()
					serv.handlePacket(p)
					// the Handshake packet is written by the connection
					// make sure there are no Write calls on the packet conn
					time.Sleep(50 * time.Millisecond)
					close(done)
				}()
				// wait until the new connection has been started, then for the
				// no-Write-calls grace period to elapse
				Eventually(run).Should(BeClosed())
				Eventually(done).Should(BeClosed())
				// shutdown: at most one close when the transport is closed in AfterEach
				conn.EXPECT().closeWithTransportError(gomock.Any()).MaxTimes(1)
			})
   540  
			// The first Initial blocks inside newConn until acceptConn is closed,
			// so the packets sent meanwhile pile up in the server's receive queue.
			// Once the queue is full, further packets are dropped (traced as DoS
			// prevention). In the end, exactly MaxServerUnprocessedPackets+1
			// connections get created: the blocked one plus one per queued packet.
			It("drops packets if the receive queue is full", func() {
				// raise the handshake limits so they don't interfere with the queue limit
				serv.maxNumHandshakesTotal = 10000
				serv.maxNumHandshakesUnvalidated = 10000

				phm.EXPECT().Get(gomock.Any()).AnyTimes()
				phm.EXPECT().GetStatelessResetToken(gomock.Any()).AnyTimes()
				phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true).AnyTimes()

				acceptConn := make(chan struct{})
				// number of connections created so far
				var counter atomic.Uint32
				serv.newConn = func(
					_ sendConn,
					runner connRunner,
					_ protocol.ConnectionID,
					_ *protocol.ConnectionID,
					_ protocol.ConnectionID,
					_ protocol.ConnectionID,
					_ protocol.ConnectionID,
					_ ConnectionIDGenerator,
					_ protocol.StatelessResetToken,
					_ *Config,
					_ *tls.Config,
					_ *handshake.TokenGenerator,
					_ bool,
					_ *logging.ConnectionTracer,
					_ uint64,
					_ utils.Logger,
					_ protocol.Version,
				) quicConn {
					// block packet processing until the spec has filled the queue
					<-acceptConn
					counter.Add(1)
					conn := NewMockQUICConn(mockCtrl)
					conn.EXPECT().handlePacket(gomock.Any()).MaxTimes(1)
					conn.EXPECT().run().MaxTimes(1)
					conn.EXPECT().Context().Return(context.Background()).MaxTimes(1)
					conn.EXPECT().HandshakeComplete().Return(make(chan struct{})).MaxTimes(1)
					// shutdown
					conn.EXPECT().closeWithTransportError(gomock.Any()).MaxTimes(1)
					return conn
				}

				p := getInitial(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}))
				serv.handlePacket(p)
				// Registered after the first packet was handed over; that packet
				// presumably isn't dropped, since the queue is empty at that point.
				// All packets below are built by getInitial with the same arguments,
				// so p's address and size match the dropped ones.
				tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropDOSPrevention).MinTimes(1)
				var wg sync.WaitGroup
				for i := 0; i < 3*protocol.MaxServerUnprocessedPackets; i++ {
					wg.Add(1)
					go func() {
						defer GinkgoRecover()
						defer wg.Done()
						serv.handlePacket(getInitial(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})))
					}()
				}
				wg.Wait()

				// unblock newConn and let the queued packets be processed
				close(acceptConn)
				Eventually(
					func() uint32 { return counter.Load() },
					scaleDuration(100*time.Millisecond),
				).Should(BeEquivalentTo(protocol.MaxServerUnprocessedPackets + 1))
				Consistently(func() uint32 { return counter.Load() }).Should(BeEquivalentTo(protocol.MaxServerUnprocessedPackets + 1))
			})
   603  
			// PIt marks this spec as pending: Ginkgo does not run it.
			// NOTE(review): the reason it is pending isn't visible here.
			// Intent: if registering the connection ID fails (AddWithConnID returns
			// false, i.e. a collision), handlePacketImpl must not create a new
			// connection, but still writes a packet to the client.
			PIt("only creates a single connection for a duplicate Initial", func() {
				var createdConn bool
				serv.newConn = func(
					_ sendConn,
					runner connRunner,
					_ protocol.ConnectionID,
					_ *protocol.ConnectionID,
					_ protocol.ConnectionID,
					_ protocol.ConnectionID,
					_ protocol.ConnectionID,
					_ ConnectionIDGenerator,
					_ protocol.StatelessResetToken,
					_ *Config,
					_ *tls.Config,
					_ *handshake.TokenGenerator,
					_ bool,
					_ *logging.ConnectionTracer,
					_ uint64,
					_ utils.Logger,
					_ protocol.Version,
				) quicConn {
					createdConn = true
					return NewMockQUICConn(mockCtrl)
				}

				connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9})
				p := getInitial(connID)
				phm.EXPECT().Get(connID)
				phm.EXPECT().AddWithConnID(connID, gomock.Any(), gomock.Any()).Return(false) // connection ID collision
				// the type of the packet that gets sent is not asserted here
				tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
				done := make(chan struct{})
				conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).Do(func([]byte, net.Addr) (int, error) { close(done); return 0, nil })
				// called synchronously, so reading createdConn right after is race-free
				Expect(serv.handlePacketImpl(p)).To(BeTrue())
				Expect(createdConn).To(BeFalse())
				Eventually(done).Should(BeClosed())
			})
   640  
			// With `limit` handshakes in flight that haven't completed yet, the next
			// Initial is answered with a Retry. After those handshakes complete and
			// the connections are accepted, `limit` further Initials create
			// connections again without a Retry.
			It("limits the number of unvalidated handshakes", func() {
				const limit = 3
				serv.maxNumHandshakesTotal = 10000
				serv.maxNumHandshakesUnvalidated = limit

				phm.EXPECT().Get(gomock.Any()).AnyTimes()
				phm.EXPECT().GetStatelessResetToken(gomock.Any()).AnyTimes()
				phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true).AnyTimes()

				// shared by all connections: closing it completes every handshake at once
				handshakeChan := make(chan struct{})
				// each newConn call consumes the next mock connection prepared by the spec
				connChan := make(chan *MockQUICConn, 1)
				// counts HandshakeComplete calls across both rounds of connections
				var wg sync.WaitGroup
				wg.Add(2 * limit)
				serv.newConn = func(
					_ sendConn,
					runner connRunner,
					_ protocol.ConnectionID,
					_ *protocol.ConnectionID,
					_ protocol.ConnectionID,
					_ protocol.ConnectionID,
					_ protocol.ConnectionID,
					_ ConnectionIDGenerator,
					_ protocol.StatelessResetToken,
					_ *Config,
					_ *tls.Config,
					_ *handshake.TokenGenerator,
					_ bool,
					_ *logging.ConnectionTracer,
					_ uint64,
					_ utils.Logger,
					_ protocol.Version,
				) quicConn {
					conn := <-connChan
					conn.EXPECT().handlePacket(gomock.Any())
					conn.EXPECT().run()
					conn.EXPECT().Context().Return(context.Background())
					// Return supplies handshakeChan; the value returned from Do is ignored by gomock.
					conn.EXPECT().HandshakeComplete().Return(handshakeChan).Do(func() <-chan struct{} { wg.Done(); return nil })
					return conn
				}

				// Initiate the maximum number of allowed connection attempts.
				for i := 0; i < limit; i++ {
					conn := NewMockQUICConn(mockCtrl)
					connChan <- conn
					serv.handlePacket(getInitialWithRandomDestConnID())
				}

				// Now initiate another connection attempt.
				// It exceeds the limit and must be answered with a Retry.
				p := getInitialWithRandomDestConnID()
				done := make(chan struct{})
				tracer.EXPECT().SentPacket(p.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) {
					defer GinkgoRecover()
					Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry))
				})
				// conn here is the outer mock packet conn (the loop variables above shadow it).
				conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
					defer GinkgoRecover()
					defer close(done)
					hdr, _, _, err := wire.ParsePacket(b)
					Expect(err).ToNot(HaveOccurred())
					Expect(hdr.Type).To(Equal(protocol.PacketTypeRetry))
					return len(b), nil
				})
				serv.handlePacket(p)
				Eventually(done).Should(BeClosed())

				// Complete the pending handshakes and accept those connections.
				close(handshakeChan)
				for i := 0; i < limit; i++ {
					_, err := serv.Accept(context.Background())
					Expect(err).ToNot(HaveOccurred())
				}
				// A second round of `limit` Initials must now create connections;
				// wg.Wait returns once all 2*limit HandshakeComplete calls have happened.
				for i := 0; i < limit; i++ {
					conn := NewMockQUICConn(mockCtrl)
					conn.EXPECT().closeWithTransportError(gomock.Any()).MaxTimes(1) // called when the server is closed
					connChan <- conn
					serv.handlePacket(getInitialWithRandomDestConnID())
				}
				wg.Wait()
			})
   719  
   720  			It("limits the number of total handshakes", func() {
   721  				const limit = 3
   722  				serv.maxNumHandshakesTotal = limit
   723  				serv.maxNumHandshakesUnvalidated = limit // same limit, but we check that we send CONNECTION_REFUSED and not Retry
   724  
   725  				phm.EXPECT().Get(gomock.Any()).AnyTimes()
   726  				phm.EXPECT().GetStatelessResetToken(gomock.Any()).AnyTimes()
   727  				phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true).AnyTimes()
   728  
   729  				handshakeChan := make(chan struct{})
   730  				connChan := make(chan *MockQUICConn, 1)
   731  				serv.newConn = func(
   732  					_ sendConn,
   733  					runner connRunner,
   734  					_ protocol.ConnectionID,
   735  					_ *protocol.ConnectionID,
   736  					_ protocol.ConnectionID,
   737  					_ protocol.ConnectionID,
   738  					_ protocol.ConnectionID,
   739  					_ ConnectionIDGenerator,
   740  					_ protocol.StatelessResetToken,
   741  					_ *Config,
   742  					_ *tls.Config,
   743  					_ *handshake.TokenGenerator,
   744  					_ bool,
   745  					_ *logging.ConnectionTracer,
   746  					_ uint64,
   747  					_ utils.Logger,
   748  					_ protocol.Version,
   749  				) quicConn {
   750  					conn := <-connChan
   751  					conn.EXPECT().handlePacket(gomock.Any())
   752  					conn.EXPECT().run()
   753  					conn.EXPECT().Context().Return(context.Background())
   754  					conn.EXPECT().HandshakeComplete().Return(handshakeChan)
   755  					return conn
   756  				}
   757  
   758  				for i := 0; i < limit; i++ {
   759  					conn := NewMockQUICConn(mockCtrl)
   760  					connChan <- conn
   761  					serv.handlePacket(getInitialWithRandomDestConnID())
   762  				}
   763  
   764  				p := getInitialWithRandomDestConnID()
   765  				done := make(chan struct{})
   766  				tracer.EXPECT().SentPacket(p.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) {
   767  					defer GinkgoRecover()
   768  					hdr, _, _, err := wire.ParsePacket(p.data)
   769  					Expect(err).ToNot(HaveOccurred())
   770  					Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial))
   771  					Expect(replyHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID))
   772  					Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))
   773  					Expect(frames).To(HaveLen(1))
   774  					Expect(frames[0]).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{}))
   775  					ccf := frames[0].(*logging.ConnectionCloseFrame)
   776  					Expect(ccf.IsApplicationError).To(BeFalse())
   777  					Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.ConnectionRefused))
   778  				})
   779  				conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
   780  					defer GinkgoRecover()
   781  					defer close(done)
   782  					hdr, _, _, err := wire.ParsePacket(p.data)
   783  					Expect(err).ToNot(HaveOccurred())
   784  					checkConnectionCloseError(b, hdr, qerr.ConnectionRefused)
   785  					return len(b), nil
   786  				})
   787  				serv.handlePacket(p)
   788  				Eventually(done).Should(BeClosed())
   789  
   790  				close(handshakeChan)
   791  				for i := 0; i < limit; i++ {
   792  					_, err := serv.Accept(context.Background())
   793  					Expect(err).ToNot(HaveOccurred())
   794  				}
   795  				// make sure we can enqueue and accept more connections after that
   796  				for i := 0; i < limit; i++ {
   797  					conn := NewMockQUICConn(mockCtrl)
   798  					conn.EXPECT().closeWithTransportError(gomock.Any()).MaxTimes(1) // called when the server is closed
   799  					connChan <- conn
   800  					serv.handlePacket(getInitialWithRandomDestConnID())
   801  				}
   802  				for i := 0; i < limit; i++ {
   803  					_, err := serv.Accept(context.Background())
   804  					Expect(err).ToNot(HaveOccurred())
   805  				}
   806  			})
   807  		})
   808  
   809  		Context("token validation", func() {
   810  			It("decodes the token from the token field", func() {
   811  				raddr := &net.UDPAddr{IP: net.IPv4(192, 168, 13, 37), Port: 1337}
   812  				token, err := serv.tokenGenerator.NewRetryToken(raddr, protocol.ConnectionID{}, protocol.ConnectionID{})
   813  				Expect(err).ToNot(HaveOccurred())
   814  				packet := getPacket(&wire.Header{
   815  					Type:    protocol.PacketTypeInitial,
   816  					Token:   token,
   817  					Version: serv.config.Versions[0],
   818  				}, make([]byte, protocol.MinInitialPacketSize))
   819  				packet.remoteAddr = raddr
   820  				conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).MaxTimes(1)
   821  				tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1)
   822  
   823  				done := make(chan struct{})
   824  				phm.EXPECT().Get(gomock.Any())
   825  				phm.EXPECT().GetStatelessResetToken(gomock.Any())
   826  				phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, _ packetHandler) bool {
   827  					close(done)
   828  					return true
   829  				})
   830  				phm.EXPECT().Remove(gomock.Any()).AnyTimes()
   831  				serv.handlePacket(packet)
   832  				Eventually(done).Should(BeClosed())
   833  			})
   834  
   835  			It("sends an INVALID_TOKEN error, if an invalid retry token is received", func() {
   836  				serv.maxNumHandshakesUnvalidated = 0
   837  				token, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, protocol.ConnectionID{}, protocol.ConnectionID{})
   838  				Expect(err).ToNot(HaveOccurred())
   839  				hdr := &wire.Header{
   840  					Type:             protocol.PacketTypeInitial,
   841  					SrcConnectionID:  protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
   842  					DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}),
   843  					Token:            token,
   844  					Version:          protocol.Version1,
   845  				}
   846  				packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
   847  				packet.data = append(packet.data, []byte("coalesced packet")...) // add some garbage to simulate a coalesced packet
   848  				raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
   849  				packet.remoteAddr = raddr
   850  				tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) {
   851  					Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial))
   852  					Expect(replyHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID))
   853  					Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))
   854  					Expect(frames).To(HaveLen(1))
   855  					Expect(frames[0]).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{}))
   856  					ccf := frames[0].(*logging.ConnectionCloseFrame)
   857  					Expect(ccf.IsApplicationError).To(BeFalse())
   858  					Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.InvalidToken))
   859  				})
   860  				done := make(chan struct{})
   861  				conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
   862  					defer close(done)
   863  					checkConnectionCloseError(b, hdr, qerr.InvalidToken)
   864  					return len(b), nil
   865  				})
   866  				phm.EXPECT().Get(gomock.Any())
   867  				serv.handlePacket(packet)
   868  				Eventually(done).Should(BeClosed())
   869  			})
   870  
   871  			It("sends an INVALID_TOKEN error, if an expired retry token is received", func() {
   872  				serv.maxNumHandshakesUnvalidated = 0
   873  				serv.config.HandshakeIdleTimeout = time.Millisecond / 2 // the maximum retry token age is equivalent to the handshake timeout
   874  				Expect(serv.config.maxRetryTokenAge()).To(Equal(time.Millisecond))
   875  				raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
   876  				token, err := serv.tokenGenerator.NewRetryToken(raddr, protocol.ConnectionID{}, protocol.ConnectionID{})
   877  				Expect(err).ToNot(HaveOccurred())
   878  				time.Sleep(2 * time.Millisecond) // make sure the token is expired
   879  				hdr := &wire.Header{
   880  					Type:             protocol.PacketTypeInitial,
   881  					SrcConnectionID:  protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
   882  					DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}),
   883  					Token:            token,
   884  					Version:          protocol.Version1,
   885  				}
   886  				packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
   887  				packet.remoteAddr = raddr
   888  				tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) {
   889  					Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial))
   890  					Expect(replyHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID))
   891  					Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))
   892  					Expect(frames).To(HaveLen(1))
   893  					Expect(frames[0]).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{}))
   894  					ccf := frames[0].(*logging.ConnectionCloseFrame)
   895  					Expect(ccf.IsApplicationError).To(BeFalse())
   896  					Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.InvalidToken))
   897  				})
   898  				done := make(chan struct{})
   899  				conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
   900  					defer close(done)
   901  					checkConnectionCloseError(b, hdr, qerr.InvalidToken)
   902  					return len(b), nil
   903  				})
   904  				phm.EXPECT().Get(gomock.Any())
   905  				serv.handlePacket(packet)
   906  				Eventually(done).Should(BeClosed())
   907  			})
   908  
   909  			It("doesn't send an INVALID_TOKEN error, if an invalid non-retry token is received", func() {
   910  				serv.maxNumHandshakesUnvalidated = 0
   911  				token, err := serv.tokenGenerator.NewToken(&net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337})
   912  				Expect(err).ToNot(HaveOccurred())
   913  				hdr := &wire.Header{
   914  					Type:             protocol.PacketTypeInitial,
   915  					SrcConnectionID:  protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
   916  					DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}),
   917  					Token:            token,
   918  					Version:          protocol.Version1,
   919  				}
   920  				packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
   921  				packet.data[len(packet.data)-10] ^= 0xff // corrupt the packet
   922  				raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
   923  				packet.remoteAddr = raddr
   924  				tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1)
   925  				done := make(chan struct{})
   926  				conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
   927  					defer close(done)
   928  					replyHdr := parseHeader(b)
   929  					Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry))
   930  					return len(b), nil
   931  				})
   932  				phm.EXPECT().Get(gomock.Any())
   933  				serv.handlePacket(packet)
   934  				// make sure there are no Write calls on the packet conn
   935  				Eventually(done).Should(BeClosed())
   936  			})
   937  
   938  			It("sends an INVALID_TOKEN error, if an expired non-retry token is received", func() {
   939  				serv.maxNumHandshakesUnvalidated = 0
   940  				serv.maxTokenAge = time.Millisecond
   941  				raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
   942  				token, err := serv.tokenGenerator.NewToken(raddr)
   943  				Expect(err).ToNot(HaveOccurred())
   944  				time.Sleep(2 * time.Millisecond) // make sure the token is expired
   945  				hdr := &wire.Header{
   946  					Type:             protocol.PacketTypeInitial,
   947  					SrcConnectionID:  protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
   948  					DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}),
   949  					Token:            token,
   950  					Version:          protocol.Version1,
   951  				}
   952  				packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
   953  				packet.remoteAddr = raddr
   954  				tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) {
   955  					Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry))
   956  				})
   957  				done := make(chan struct{})
   958  				conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
   959  					defer close(done)
   960  					return len(b), nil
   961  				})
   962  				phm.EXPECT().Get(gomock.Any())
   963  				serv.handlePacket(packet)
   964  				Eventually(done).Should(BeClosed())
   965  			})
   966  
   967  			It("doesn't send an INVALID_TOKEN error, if the packet is corrupted", func() {
   968  				serv.maxNumHandshakesUnvalidated = 0
   969  				token, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, protocol.ConnectionID{}, protocol.ConnectionID{})
   970  				Expect(err).ToNot(HaveOccurred())
   971  				hdr := &wire.Header{
   972  					Type:             protocol.PacketTypeInitial,
   973  					SrcConnectionID:  protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
   974  					DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}),
   975  					Token:            token,
   976  					Version:          protocol.Version1,
   977  				}
   978  				packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
   979  				packet.data[len(packet.data)-10] ^= 0xff // corrupt the packet
   980  				packet.remoteAddr = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
   981  				done := make(chan struct{})
   982  				tracer.EXPECT().DroppedPacket(packet.remoteAddr, logging.PacketTypeInitial, packet.Size(), logging.PacketDropPayloadDecryptError).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { close(done) })
   983  				phm.EXPECT().Get(gomock.Any())
   984  				serv.handlePacket(packet)
   985  				// make sure there are no Write calls on the packet conn
   986  				time.Sleep(50 * time.Millisecond)
   987  				Eventually(done).Should(BeClosed())
   988  			})
   989  		})
   990  
   991  		Context("accepting connections", func() {
   992  			It("returns Accept when closed", func() {
   993  				done := make(chan struct{})
   994  				go func() {
   995  					defer GinkgoRecover()
   996  					_, err := serv.Accept(context.Background())
   997  					Expect(err).To(MatchError(ErrServerClosed))
   998  					close(done)
   999  				}()
  1000  
  1001  				serv.Close()
  1002  				Eventually(done).Should(BeClosed())
  1003  			})
  1004  
  1005  			It("returns immediately, if an error occurred before", func() {
  1006  				serv.Close()
  1007  				for i := 0; i < 3; i++ {
  1008  					_, err := serv.Accept(context.Background())
  1009  					Expect(err).To(MatchError(ErrServerClosed))
  1010  				}
  1011  			})
  1012  
  1013  			It("closes connection that are still handshaking after Close", func() {
  1014  				serv.Close()
  1015  
  1016  				destroyed := make(chan struct{})
  1017  				serv.newConn = func(
  1018  					_ sendConn,
  1019  					_ connRunner,
  1020  					_ protocol.ConnectionID,
  1021  					_ *protocol.ConnectionID,
  1022  					_ protocol.ConnectionID,
  1023  					_ protocol.ConnectionID,
  1024  					_ protocol.ConnectionID,
  1025  					_ ConnectionIDGenerator,
  1026  					_ protocol.StatelessResetToken,
  1027  					conf *Config,
  1028  					_ *tls.Config,
  1029  					_ *handshake.TokenGenerator,
  1030  					_ bool,
  1031  					_ *logging.ConnectionTracer,
  1032  					_ uint64,
  1033  					_ utils.Logger,
  1034  					_ protocol.Version,
  1035  				) quicConn {
  1036  					conn := NewMockQUICConn(mockCtrl)
  1037  					conn.EXPECT().handlePacket(gomock.Any())
  1038  					conn.EXPECT().closeWithTransportError(ConnectionRefused).Do(func(TransportErrorCode) { close(destroyed) })
  1039  					conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
  1040  					conn.EXPECT().run().MaxTimes(1)
  1041  					conn.EXPECT().Context().Return(context.Background())
  1042  					return conn
  1043  				}
  1044  				phm.EXPECT().Get(gomock.Any())
  1045  				phm.EXPECT().GetStatelessResetToken(gomock.Any())
  1046  				phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true)
  1047  				serv.handleInitialImpl(
  1048  					receivedPacket{buffer: getPacketBuffer()},
  1049  					&wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})},
  1050  				)
  1051  				Eventually(destroyed).Should(BeClosed())
  1052  			})
  1053  
  1054  			It("returns when the context is canceled", func() {
  1055  				ctx, cancel := context.WithCancel(context.Background())
  1056  				done := make(chan struct{})
  1057  				go func() {
  1058  					defer GinkgoRecover()
  1059  					_, err := serv.Accept(ctx)
  1060  					Expect(err).To(MatchError("context canceled"))
  1061  					close(done)
  1062  				}()
  1063  
  1064  				Consistently(done).ShouldNot(BeClosed())
  1065  				cancel()
  1066  				Eventually(done).Should(BeClosed())
  1067  			})
  1068  
  1069  			It("uses the config returned by GetConfigClient", func() {
  1070  				conn := NewMockQUICConn(mockCtrl)
  1071  
  1072  				conf := &Config{MaxIncomingStreams: 1234}
  1073  				serv.config = populateConfig(&Config{GetConfigForClient: func(*ClientHelloInfo) (*Config, error) { return conf, nil }})
  1074  				done := make(chan struct{})
  1075  				go func() {
  1076  					defer GinkgoRecover()
  1077  					s, err := serv.Accept(context.Background())
  1078  					Expect(err).ToNot(HaveOccurred())
  1079  					Expect(s).To(Equal(conn))
  1080  					close(done)
  1081  				}()
  1082  
  1083  				handshakeChan := make(chan struct{})
  1084  				serv.newConn = func(
  1085  					_ sendConn,
  1086  					_ connRunner,
  1087  					_ protocol.ConnectionID,
  1088  					_ *protocol.ConnectionID,
  1089  					_ protocol.ConnectionID,
  1090  					_ protocol.ConnectionID,
  1091  					_ protocol.ConnectionID,
  1092  					_ ConnectionIDGenerator,
  1093  					_ protocol.StatelessResetToken,
  1094  					conf *Config,
  1095  					_ *tls.Config,
  1096  					_ *handshake.TokenGenerator,
  1097  					_ bool,
  1098  					_ *logging.ConnectionTracer,
  1099  					_ uint64,
  1100  					_ utils.Logger,
  1101  					_ protocol.Version,
  1102  				) quicConn {
  1103  					Expect(conf.MaxIncomingStreams).To(BeEquivalentTo(1234))
  1104  					conn.EXPECT().handlePacket(gomock.Any())
  1105  					conn.EXPECT().HandshakeComplete().Return(handshakeChan)
  1106  					conn.EXPECT().run()
  1107  					conn.EXPECT().Context().Return(context.Background())
  1108  					return conn
  1109  				}
  1110  				phm.EXPECT().Get(gomock.Any())
  1111  				phm.EXPECT().GetStatelessResetToken(gomock.Any())
  1112  				phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true)
  1113  				serv.handleInitialImpl(
  1114  					receivedPacket{buffer: getPacketBuffer()},
  1115  					&wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})},
  1116  				)
  1117  				Consistently(done).ShouldNot(BeClosed())
  1118  				close(handshakeChan) // complete the handshake
  1119  				Eventually(done).Should(BeClosed())
  1120  			})
  1121  
  1122  			It("rejects a connection attempt when GetConfigClient returns an error", func() {
  1123  				serv.config = populateConfig(&Config{GetConfigForClient: func(*ClientHelloInfo) (*Config, error) { return nil, errors.New("rejected") }})
  1124  
  1125  				phm.EXPECT().Get(gomock.Any())
  1126  				done := make(chan struct{})
  1127  				tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
  1128  				conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
  1129  					defer close(done)
  1130  					rejectHdr := parseHeader(b)
  1131  					Expect(rejectHdr.Type).To(Equal(protocol.PacketTypeInitial))
  1132  					return len(b), nil
  1133  				})
  1134  				serv.handleInitialImpl(
  1135  					receivedPacket{buffer: getPacketBuffer()},
  1136  					&wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), Version: protocol.Version1},
  1137  				)
  1138  				Eventually(done).Should(BeClosed())
  1139  			})
  1140  
  1141  			It("accepts new connections when the handshake completes", func() {
  1142  				conn := NewMockQUICConn(mockCtrl)
  1143  
  1144  				done := make(chan struct{})
  1145  				go func() {
  1146  					defer GinkgoRecover()
  1147  					s, err := serv.Accept(context.Background())
  1148  					Expect(err).ToNot(HaveOccurred())
  1149  					Expect(s).To(Equal(conn))
  1150  					close(done)
  1151  				}()
  1152  
  1153  				handshakeChan := make(chan struct{})
  1154  				serv.newConn = func(
  1155  					_ sendConn,
  1156  					runner connRunner,
  1157  					_ protocol.ConnectionID,
  1158  					_ *protocol.ConnectionID,
  1159  					_ protocol.ConnectionID,
  1160  					_ protocol.ConnectionID,
  1161  					_ protocol.ConnectionID,
  1162  					_ ConnectionIDGenerator,
  1163  					_ protocol.StatelessResetToken,
  1164  					_ *Config,
  1165  					_ *tls.Config,
  1166  					_ *handshake.TokenGenerator,
  1167  					_ bool,
  1168  					_ *logging.ConnectionTracer,
  1169  					_ uint64,
  1170  					_ utils.Logger,
  1171  					_ protocol.Version,
  1172  				) quicConn {
  1173  					conn.EXPECT().handlePacket(gomock.Any())
  1174  					conn.EXPECT().HandshakeComplete().Return(handshakeChan)
  1175  					conn.EXPECT().run()
  1176  					conn.EXPECT().Context().Return(context.Background())
  1177  					return conn
  1178  				}
  1179  				phm.EXPECT().Get(gomock.Any())
  1180  				phm.EXPECT().GetStatelessResetToken(gomock.Any())
  1181  				phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true)
  1182  				serv.handleInitialImpl(
  1183  					receivedPacket{buffer: getPacketBuffer()},
  1184  					&wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})},
  1185  				)
  1186  				Consistently(done).ShouldNot(BeClosed())
  1187  				close(handshakeChan) // complete the handshake
  1188  				Eventually(done).Should(BeClosed())
  1189  			})
  1190  		})
  1191  	})
  1192  
  1193  	Context("server accepting connections that haven't completed the handshake", func() {
  1194  		var (
  1195  			serv *EarlyListener
  1196  			phm  *MockPacketHandlerManager
  1197  		)
  1198  
  1199  		BeforeEach(func() {
  1200  			var err error
  1201  			serv, err = ListenEarly(conn, tlsConf, nil)
  1202  			Expect(err).ToNot(HaveOccurred())
  1203  			phm = NewMockPacketHandlerManager(mockCtrl)
  1204  			serv.baseServer.connHandler = phm
  1205  		})
  1206  
  1207  		AfterEach(func() {
  1208  			serv.Close()
  1209  		})
  1210  
  1211  		It("accepts new connections when they become ready", func() {
  1212  			conn := NewMockQUICConn(mockCtrl)
  1213  
  1214  			done := make(chan struct{})
  1215  			go func() {
  1216  				defer GinkgoRecover()
  1217  				s, err := serv.Accept(context.Background())
  1218  				Expect(err).ToNot(HaveOccurred())
  1219  				Expect(s).To(Equal(conn))
  1220  				close(done)
  1221  			}()
  1222  
  1223  			ready := make(chan struct{})
  1224  			serv.baseServer.newConn = func(
  1225  				_ sendConn,
  1226  				runner connRunner,
  1227  				_ protocol.ConnectionID,
  1228  				_ *protocol.ConnectionID,
  1229  				_ protocol.ConnectionID,
  1230  				_ protocol.ConnectionID,
  1231  				_ protocol.ConnectionID,
  1232  				_ ConnectionIDGenerator,
  1233  				_ protocol.StatelessResetToken,
  1234  				_ *Config,
  1235  				_ *tls.Config,
  1236  				_ *handshake.TokenGenerator,
  1237  				_ bool,
  1238  				_ *logging.ConnectionTracer,
  1239  				_ uint64,
  1240  				_ utils.Logger,
  1241  				_ protocol.Version,
  1242  			) quicConn {
  1243  				conn.EXPECT().handlePacket(gomock.Any())
  1244  				conn.EXPECT().run()
  1245  				conn.EXPECT().earlyConnReady().Return(ready)
  1246  				conn.EXPECT().Context().Return(context.Background())
  1247  				return conn
  1248  			}
  1249  			phm.EXPECT().Get(gomock.Any())
  1250  			phm.EXPECT().GetStatelessResetToken(gomock.Any())
  1251  			phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true)
  1252  			serv.baseServer.handleInitialImpl(
  1253  				receivedPacket{buffer: getPacketBuffer()},
  1254  				&wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})},
  1255  			)
  1256  			Consistently(done).ShouldNot(BeClosed())
  1257  			close(ready)
  1258  			Eventually(done).Should(BeClosed())
  1259  		})
  1260  
  1261  		It("rejects new connection attempts if the accept queue is full", func() {
  1262  			connChan := make(chan *MockQUICConn, 1)
  1263  			var wg sync.WaitGroup // to make sure the test fully completes
  1264  			wg.Add(protocol.MaxAcceptQueueSize + 1)
  1265  			serv.baseServer.newConn = func(
  1266  				_ sendConn,
  1267  				runner connRunner,
  1268  				_ protocol.ConnectionID,
  1269  				_ *protocol.ConnectionID,
  1270  				_ protocol.ConnectionID,
  1271  				_ protocol.ConnectionID,
  1272  				_ protocol.ConnectionID,
  1273  				_ ConnectionIDGenerator,
  1274  				_ protocol.StatelessResetToken,
  1275  				_ *Config,
  1276  				_ *tls.Config,
  1277  				_ *handshake.TokenGenerator,
  1278  				_ bool,
  1279  				_ *logging.ConnectionTracer,
  1280  				_ uint64,
  1281  				_ utils.Logger,
  1282  				_ protocol.Version,
  1283  			) quicConn {
  1284  				defer wg.Done()
  1285  				ready := make(chan struct{})
  1286  				close(ready)
  1287  				conn := <-connChan
  1288  				conn.EXPECT().handlePacket(gomock.Any())
  1289  				conn.EXPECT().run()
  1290  				conn.EXPECT().earlyConnReady().Return(ready)
  1291  				conn.EXPECT().Context().Return(context.Background())
  1292  				return conn
  1293  			}
  1294  
  1295  			phm.EXPECT().Get(gomock.Any()).AnyTimes()
  1296  			phm.EXPECT().GetStatelessResetToken(gomock.Any()).Times(protocol.MaxAcceptQueueSize)
  1297  			phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true).Times(protocol.MaxAcceptQueueSize)
  1298  			for i := 0; i < protocol.MaxAcceptQueueSize; i++ {
  1299  				conn := NewMockQUICConn(mockCtrl)
  1300  				connChan <- conn
  1301  				serv.baseServer.handlePacket(getInitialWithRandomDestConnID())
  1302  			}
  1303  
  1304  			Eventually(serv.baseServer.connQueue).Should(HaveLen(protocol.MaxAcceptQueueSize))
  1305  
  1306  			phm.EXPECT().GetStatelessResetToken(gomock.Any())
  1307  			phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true)
  1308  			conn := NewMockQUICConn(mockCtrl)
  1309  			conn.EXPECT().closeWithTransportError(ConnectionRefused)
  1310  			connChan <- conn
  1311  			serv.baseServer.handlePacket(getInitialWithRandomDestConnID())
  1312  			wg.Wait()
  1313  		})
  1314  
  1315  		It("doesn't accept new connections if they were closed in the mean time", func() {
  1316  			p := getInitial(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}))
  1317  			ctx, cancel := context.WithCancel(context.Background())
  1318  			connCreated := make(chan struct{})
  1319  			conn := NewMockQUICConn(mockCtrl)
  1320  			serv.baseServer.newConn = func(
  1321  				_ sendConn,
  1322  				runner connRunner,
  1323  				_ protocol.ConnectionID,
  1324  				_ *protocol.ConnectionID,
  1325  				_ protocol.ConnectionID,
  1326  				_ protocol.ConnectionID,
  1327  				_ protocol.ConnectionID,
  1328  				_ ConnectionIDGenerator,
  1329  				_ protocol.StatelessResetToken,
  1330  				_ *Config,
  1331  				_ *tls.Config,
  1332  				_ *handshake.TokenGenerator,
  1333  				_ bool,
  1334  				_ *logging.ConnectionTracer,
  1335  				_ uint64,
  1336  				_ utils.Logger,
  1337  				_ protocol.Version,
  1338  			) quicConn {
  1339  				conn.EXPECT().handlePacket(p)
  1340  				conn.EXPECT().run()
  1341  				conn.EXPECT().earlyConnReady()
  1342  				conn.EXPECT().Context().Return(ctx)
  1343  				close(connCreated)
  1344  				return conn
  1345  			}
  1346  
  1347  			phm.EXPECT().Get(gomock.Any())
  1348  			phm.EXPECT().GetStatelessResetToken(gomock.Any())
  1349  			phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true)
  1350  			serv.baseServer.handlePacket(p)
  1351  			// make sure there are no Write calls on the packet conn
  1352  			time.Sleep(50 * time.Millisecond)
  1353  			Eventually(connCreated).Should(BeClosed())
  1354  			cancel()
  1355  			time.Sleep(scaleDuration(200 * time.Millisecond))
  1356  
  1357  			done := make(chan struct{})
  1358  			go func() {
  1359  				defer GinkgoRecover()
  1360  				serv.Accept(context.Background())
  1361  				close(done)
  1362  			}()
  1363  			Consistently(done).ShouldNot(BeClosed())
  1364  
  1365  			// make the go routine return
  1366  			conn.EXPECT().getPerspective().MaxTimes(2) // initOnce for every conn ID
  1367  			Expect(serv.Close()).To(Succeed())
  1368  			Eventually(done).Should(BeClosed())
  1369  		})
  1370  	})
  1371  
  1372  	Context("0-RTT", func() {
  1373  		var (
  1374  			tr     *Transport
  1375  			serv   *baseServer
  1376  			phm    *MockPacketHandlerManager
  1377  			tracer *mocklogging.MockTracer
  1378  		)
  1379  
  1380  		BeforeEach(func() {
  1381  			var t *logging.Tracer
  1382  			t, tracer = mocklogging.NewMockTracer(mockCtrl)
  1383  			tr = &Transport{Conn: conn, Tracer: t}
  1384  			ln, err := tr.ListenEarly(tlsConf, nil)
  1385  			Expect(err).ToNot(HaveOccurred())
  1386  			phm = NewMockPacketHandlerManager(mockCtrl)
  1387  			serv = ln.baseServer
  1388  			serv.connHandler = phm
  1389  		})
  1390  
  1391  		AfterEach(func() {
  1392  			tracer.EXPECT().Close()
  1393  			Expect(tr.Close()).To(Succeed())
  1394  		})
  1395  
  1396  		It("passes packets to existing connections", func() {
  1397  			connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
  1398  			p := getPacket(&wire.Header{
  1399  				Type:             protocol.PacketType0RTT,
  1400  				DestConnectionID: connID,
  1401  				Version:          serv.config.Versions[0],
  1402  			}, make([]byte, 100))
  1403  			conn := NewMockPacketHandler(mockCtrl)
  1404  			phm.EXPECT().Get(connID).Return(conn, true)
  1405  			handled := make(chan struct{})
  1406  			conn.EXPECT().handlePacket(p).Do(func(receivedPacket) { close(handled) })
  1407  			serv.handlePacket(p)
  1408  			Eventually(handled).Should(BeClosed())
  1409  		})
  1410  
  1411  		It("queues 0-RTT packets, up to Max0RTTQueueSize", func() {
  1412  			connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
  1413  
  1414  			var zeroRTTPackets []receivedPacket
  1415  
  1416  			for i := 0; i < protocol.Max0RTTQueueLen; i++ {
  1417  				p := getPacket(&wire.Header{
  1418  					Type:             protocol.PacketType0RTT,
  1419  					DestConnectionID: connID,
  1420  					Version:          serv.config.Versions[0],
  1421  				}, make([]byte, 100+i))
  1422  				phm.EXPECT().Get(connID)
  1423  				serv.handlePacket(p)
  1424  				zeroRTTPackets = append(zeroRTTPackets, p)
  1425  			}
  1426  
  1427  			// send one more packet, this one should be dropped
  1428  			p := getPacket(&wire.Header{
  1429  				Type:             protocol.PacketType0RTT,
  1430  				DestConnectionID: connID,
  1431  				Version:          serv.config.Versions[0],
  1432  			}, make([]byte, 200))
  1433  			phm.EXPECT().Get(connID)
  1434  			tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention)
  1435  			serv.handlePacket(p)
  1436  
  1437  			initial := getPacket(&wire.Header{
  1438  				Type:             protocol.PacketTypeInitial,
  1439  				DestConnectionID: connID,
  1440  				Version:          serv.config.Versions[0],
  1441  			}, make([]byte, protocol.MinInitialPacketSize))
  1442  			called := make(chan struct{})
  1443  			serv.newConn = func(
  1444  				_ sendConn,
  1445  				_ connRunner,
  1446  				_ protocol.ConnectionID,
  1447  				_ *protocol.ConnectionID,
  1448  				_ protocol.ConnectionID,
  1449  				_ protocol.ConnectionID,
  1450  				_ protocol.ConnectionID,
  1451  				_ ConnectionIDGenerator,
  1452  				_ protocol.StatelessResetToken,
  1453  				_ *Config,
  1454  				_ *tls.Config,
  1455  				_ *handshake.TokenGenerator,
  1456  				_ bool,
  1457  				_ *logging.ConnectionTracer,
  1458  				_ uint64,
  1459  				_ utils.Logger,
  1460  				_ protocol.Version,
  1461  			) quicConn {
  1462  				conn := NewMockQUICConn(mockCtrl)
  1463  				var calls []any
  1464  				calls = append(calls, conn.EXPECT().handlePacket(initial))
  1465  				for _, p := range zeroRTTPackets {
  1466  					calls = append(calls, conn.EXPECT().handlePacket(p))
  1467  				}
  1468  				gomock.InOrder(calls...)
  1469  				conn.EXPECT().run()
  1470  				conn.EXPECT().earlyConnReady()
  1471  				conn.EXPECT().Context().Return(context.Background())
  1472  				close(called)
  1473  				// shutdown
  1474  				conn.EXPECT().closeWithTransportError(gomock.Any())
  1475  				return conn
  1476  			}
  1477  
  1478  			phm.EXPECT().Get(connID)
  1479  			phm.EXPECT().GetStatelessResetToken(gomock.Any())
  1480  			phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true)
  1481  			serv.handlePacket(initial)
  1482  			Eventually(called).Should(BeClosed())
  1483  		})
  1484  
  1485  		It("limits the number of queues", func() {
  1486  			for i := 0; i < protocol.Max0RTTQueues; i++ {
  1487  				b := make([]byte, 16)
  1488  				rand.Read(b)
  1489  				connID := protocol.ParseConnectionID(b)
  1490  				p := getPacket(&wire.Header{
  1491  					Type:             protocol.PacketType0RTT,
  1492  					DestConnectionID: connID,
  1493  					Version:          serv.config.Versions[0],
  1494  				}, make([]byte, 100+i))
  1495  				phm.EXPECT().Get(connID)
  1496  				serv.handlePacket(p)
  1497  			}
  1498  
  1499  			connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
  1500  			p := getPacket(&wire.Header{
  1501  				Type:             protocol.PacketType0RTT,
  1502  				DestConnectionID: connID,
  1503  				Version:          serv.config.Versions[0],
  1504  			}, make([]byte, 200))
  1505  			phm.EXPECT().Get(connID)
  1506  			dropped := make(chan struct{})
  1507  			tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) {
  1508  				close(dropped)
  1509  			})
  1510  			serv.handlePacket(p)
  1511  			Eventually(dropped).Should(BeClosed())
  1512  		})
  1513  
  1514  		It("drops queues after a while", func() {
  1515  			now := time.Now()
  1516  
  1517  			connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
  1518  			p := getPacket(&wire.Header{
  1519  				Type:             protocol.PacketType0RTT,
  1520  				DestConnectionID: connID,
  1521  				Version:          serv.config.Versions[0],
  1522  			}, make([]byte, 200))
  1523  			p.rcvTime = now
  1524  
  1525  			connID2 := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 9})
  1526  			p2Time := now.Add(protocol.Max0RTTQueueingDuration / 2)
  1527  			p2 := getPacket(&wire.Header{
  1528  				Type:             protocol.PacketType0RTT,
  1529  				DestConnectionID: connID2,
  1530  				Version:          serv.config.Versions[0],
  1531  			}, make([]byte, 300))
  1532  			p2.rcvTime = p2Time // doesn't trigger the cleanup of the first packet
  1533  
  1534  			dropped1 := make(chan struct{})
  1535  			dropped2 := make(chan struct{})
  1536  			// need to register the call before handling the packet to avoid race condition
  1537  			gomock.InOrder(
  1538  				tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) {
  1539  					close(dropped1)
  1540  				}),
  1541  				tracer.EXPECT().DroppedPacket(p2.remoteAddr, logging.PacketType0RTT, p2.Size(), logging.PacketDropDOSPrevention).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) {
  1542  					close(dropped2)
  1543  				}),
  1544  			)
  1545  
  1546  			phm.EXPECT().Get(connID)
  1547  			serv.handlePacket(p)
  1548  
  1549  			// There's no cleanup Go routine.
  1550  			// Cleanup is triggered when new packets are received.
  1551  
  1552  			phm.EXPECT().Get(connID2)
  1553  			serv.handlePacket(p2)
  1554  			// make sure no cleanup is executed
  1555  			Consistently(dropped1, 50*time.Millisecond).ShouldNot(BeClosed())
  1556  
  1557  			// There's no cleanup Go routine.
  1558  			// Cleanup is triggered when new packets are received.
  1559  			connID3 := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 0})
  1560  			p3 := getPacket(&wire.Header{
  1561  				Type:             protocol.PacketType0RTT,
  1562  				DestConnectionID: connID3,
  1563  				Version:          serv.config.Versions[0],
  1564  			}, make([]byte, 200))
  1565  			p3.rcvTime = now.Add(protocol.Max0RTTQueueingDuration + time.Nanosecond) // now triggers the cleanup
  1566  			phm.EXPECT().Get(connID3)
  1567  			serv.handlePacket(p3)
  1568  			Eventually(dropped1).Should(BeClosed())
  1569  			Consistently(dropped2, 50*time.Millisecond).ShouldNot(BeClosed())
  1570  
  1571  			// make sure the second packet is also cleaned up
  1572  			connID4 := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 1})
  1573  			p4 := getPacket(&wire.Header{
  1574  				Type:             protocol.PacketType0RTT,
  1575  				DestConnectionID: connID4,
  1576  				Version:          serv.config.Versions[0],
  1577  			}, make([]byte, 200))
  1578  			p4.rcvTime = p2Time.Add(protocol.Max0RTTQueueingDuration + time.Nanosecond) // now triggers the cleanup
  1579  			phm.EXPECT().Get(connID4)
  1580  			serv.handlePacket(p4)
  1581  			Eventually(dropped2).Should(BeClosed())
  1582  		})
  1583  	})
  1584  })