// Source: github.com/MerlinKodo/quic-go@v0.39.2/server_test.go

     1  package quic
     2  
     3  import (
     4  	"context"
     5  	"crypto/rand"
     6  	"crypto/tls"
     7  	"errors"
     8  	"net"
     9  	"reflect"
    10  	"sync"
    11  	"sync/atomic"
    12  	"time"
    13  
    14  	"github.com/MerlinKodo/quic-go/internal/handshake"
    15  	mocklogging "github.com/MerlinKodo/quic-go/internal/mocks/logging"
    16  	"github.com/MerlinKodo/quic-go/internal/protocol"
    17  	"github.com/MerlinKodo/quic-go/internal/qerr"
    18  	"github.com/MerlinKodo/quic-go/internal/testdata"
    19  	"github.com/MerlinKodo/quic-go/internal/utils"
    20  	"github.com/MerlinKodo/quic-go/internal/wire"
    21  	"github.com/MerlinKodo/quic-go/logging"
    22  
    23  	. "github.com/onsi/ginkgo/v2"
    24  	. "github.com/onsi/gomega"
    25  	"go.uber.org/mock/gomock"
    26  )
    27  
    28  var _ = Describe("Server", func() {
    29  	var (
    30  		conn    *MockPacketConn
    31  		tlsConf *tls.Config
    32  	)
    33  
    34  	getPacket := func(hdr *wire.Header, p []byte) receivedPacket {
    35  		buf := getPacketBuffer()
    36  		hdr.Length = 4 + protocol.ByteCount(len(p)) + 16
    37  		var err error
    38  		buf.Data, err = (&wire.ExtendedHeader{
    39  			Header:          *hdr,
    40  			PacketNumber:    0x42,
    41  			PacketNumberLen: protocol.PacketNumberLen4,
    42  		}).Append(buf.Data, protocol.Version1)
    43  		Expect(err).ToNot(HaveOccurred())
    44  		n := len(buf.Data)
    45  		buf.Data = append(buf.Data, p...)
    46  		data := buf.Data
    47  		sealer, _ := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveClient, hdr.Version)
    48  		_ = sealer.Seal(data[n:n], data[n:], 0x42, data[:n])
    49  		data = data[:len(data)+16]
    50  		sealer.EncryptHeader(data[n:n+16], &data[0], data[n-4:n])
    51  		return receivedPacket{
    52  			remoteAddr: &net.UDPAddr{IP: net.IPv4(4, 5, 6, 7), Port: 456},
    53  			data:       data,
    54  			buffer:     buf,
    55  		}
    56  	}
    57  
    58  	getInitial := func(destConnID protocol.ConnectionID) receivedPacket {
    59  		senderAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 42}
    60  		hdr := &wire.Header{
    61  			Type:             protocol.PacketTypeInitial,
    62  			SrcConnectionID:  protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
    63  			DestConnectionID: destConnID,
    64  			Version:          protocol.Version1,
    65  		}
    66  		p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
    67  		p.buffer = getPacketBuffer()
    68  		p.remoteAddr = senderAddr
    69  		return p
    70  	}
    71  
    72  	getInitialWithRandomDestConnID := func() receivedPacket {
    73  		b := make([]byte, 10)
    74  		_, err := rand.Read(b)
    75  		Expect(err).ToNot(HaveOccurred())
    76  
    77  		return getInitial(protocol.ParseConnectionID(b))
    78  	}
    79  
    80  	parseHeader := func(data []byte) *wire.Header {
    81  		hdr, _, _, err := wire.ParsePacket(data)
    82  		Expect(err).ToNot(HaveOccurred())
    83  		return hdr
    84  	}
    85  
    86  	BeforeEach(func() {
    87  		conn = NewMockPacketConn(mockCtrl)
    88  		conn.EXPECT().LocalAddr().Return(&net.UDPAddr{}).AnyTimes()
    89  		wait := make(chan struct{})
    90  		conn.EXPECT().ReadFrom(gomock.Any()).DoAndReturn(func(_ []byte) (int, net.Addr, error) {
    91  			<-wait
    92  			return 0, nil, errors.New("done")
    93  		}).MaxTimes(1)
    94  		conn.EXPECT().SetReadDeadline(gomock.Any()).Do(func(time.Time) {
    95  			close(wait)
    96  			conn.EXPECT().SetReadDeadline(time.Time{})
    97  		}).MaxTimes(1)
    98  		tlsConf = testdata.GetTLSConfig()
    99  		tlsConf.NextProtos = []string{"proto1"}
   100  	})
   101  
   102  	It("errors when no tls.Config is given", func() {
   103  		_, err := ListenAddr("localhost:0", nil, nil)
   104  		Expect(err).To(HaveOccurred())
   105  		Expect(err.Error()).To(ContainSubstring("quic: tls.Config not set"))
   106  	})
   107  
   108  	It("errors when the Config contains an invalid version", func() {
   109  		version := protocol.VersionNumber(0x1234)
   110  		_, err := Listen(nil, tlsConf, &Config{Versions: []protocol.VersionNumber{version}})
   111  		Expect(err).To(MatchError("invalid QUIC version: 0x1234"))
   112  	})
   113  
   114  	It("fills in default values if options are not set in the Config", func() {
   115  		ln, err := Listen(conn, tlsConf, &Config{})
   116  		Expect(err).ToNot(HaveOccurred())
   117  		server := ln.baseServer
   118  		Expect(server.config.Versions).To(Equal(protocol.SupportedVersions))
   119  		Expect(server.config.HandshakeIdleTimeout).To(Equal(protocol.DefaultHandshakeIdleTimeout))
   120  		Expect(server.config.MaxIdleTimeout).To(Equal(protocol.DefaultIdleTimeout))
   121  		Expect(server.config.RequireAddressValidation).ToNot(BeNil())
   122  		Expect(server.config.KeepAlivePeriod).To(BeZero())
   123  		// stop the listener
   124  		Expect(ln.Close()).To(Succeed())
   125  	})
   126  
   127  	It("setups with the right values", func() {
   128  		supportedVersions := []protocol.VersionNumber{protocol.Version1}
   129  		requireAddrVal := func(net.Addr) bool { return true }
   130  		config := Config{
   131  			Versions:                 supportedVersions,
   132  			HandshakeIdleTimeout:     1337 * time.Hour,
   133  			MaxIdleTimeout:           42 * time.Minute,
   134  			KeepAlivePeriod:          5 * time.Second,
   135  			RequireAddressValidation: requireAddrVal,
   136  		}
   137  		ln, err := Listen(conn, tlsConf, &config)
   138  		Expect(err).ToNot(HaveOccurred())
   139  		server := ln.baseServer
   140  		Expect(server.connHandler).ToNot(BeNil())
   141  		Expect(server.config.Versions).To(Equal(supportedVersions))
   142  		Expect(server.config.HandshakeIdleTimeout).To(Equal(1337 * time.Hour))
   143  		Expect(server.config.MaxIdleTimeout).To(Equal(42 * time.Minute))
   144  		Expect(reflect.ValueOf(server.config.RequireAddressValidation)).To(Equal(reflect.ValueOf(requireAddrVal)))
   145  		Expect(server.config.KeepAlivePeriod).To(Equal(5 * time.Second))
   146  		// stop the listener
   147  		Expect(ln.Close()).To(Succeed())
   148  	})
   149  
   150  	It("listens on a given address", func() {
   151  		addr := "127.0.0.1:13579"
   152  		ln, err := ListenAddr(addr, tlsConf, &Config{})
   153  		Expect(err).ToNot(HaveOccurred())
   154  		Expect(ln.Addr().String()).To(Equal(addr))
   155  		// stop the listener
   156  		Expect(ln.Close()).To(Succeed())
   157  	})
   158  
   159  	It("errors if given an invalid address", func() {
   160  		addr := "127.0.0.1"
   161  		_, err := ListenAddr(addr, tlsConf, &Config{})
   162  		Expect(err).To(BeAssignableToTypeOf(&net.AddrError{}))
   163  	})
   164  
   165  	It("errors if given an invalid address", func() {
   166  		addr := "1.1.1.1:1111"
   167  		_, err := ListenAddr(addr, tlsConf, &Config{})
   168  		Expect(err).To(BeAssignableToTypeOf(&net.OpError{}))
   169  	})
   170  
   171  	Context("server accepting connections that completed the handshake", func() {
   172  		var (
   173  			tr     *Transport
   174  			serv   *baseServer
   175  			phm    *MockPacketHandlerManager
   176  			tracer *mocklogging.MockTracer
   177  		)
   178  
   179  		BeforeEach(func() {
   180  			var t *logging.Tracer
   181  			t, tracer = mocklogging.NewMockTracer(mockCtrl)
   182  			tr = &Transport{Conn: conn, Tracer: t}
   183  			ln, err := tr.Listen(tlsConf, nil)
   184  			Expect(err).ToNot(HaveOccurred())
   185  			serv = ln.baseServer
   186  			phm = NewMockPacketHandlerManager(mockCtrl)
   187  			serv.connHandler = phm
   188  		})
   189  
   190  		AfterEach(func() {
   191  			tr.Close()
   192  		})
   193  
   194  		Context("handling packets", func() {
   195  			It("drops Initial packets with a too short connection ID", func() {
   196  				p := getPacket(&wire.Header{
   197  					Type:             protocol.PacketTypeInitial,
   198  					DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}),
   199  					Version:          serv.config.Versions[0],
   200  				}, nil)
   201  				tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket)
   202  				serv.handlePacket(p)
   203  				// make sure there are no Write calls on the packet conn
   204  				time.Sleep(50 * time.Millisecond)
   205  			})
   206  
   207  			It("drops too small Initial", func() {
   208  				p := getPacket(&wire.Header{
   209  					Type:             protocol.PacketTypeInitial,
   210  					DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}),
   211  					Version:          serv.config.Versions[0],
   212  				}, make([]byte, protocol.MinInitialPacketSize-100))
   213  				tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket)
   214  				serv.handlePacket(p)
   215  				// make sure there are no Write calls on the packet conn
   216  				time.Sleep(50 * time.Millisecond)
   217  			})
   218  
   219  			It("drops non-Initial packets", func() {
   220  				p := getPacket(&wire.Header{
   221  					Type:    protocol.PacketTypeHandshake,
   222  					Version: serv.config.Versions[0],
   223  				}, []byte("invalid"))
   224  				tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketTypeHandshake, p.Size(), logging.PacketDropUnexpectedPacket)
   225  				serv.handlePacket(p)
   226  				// make sure there are no Write calls on the packet conn
   227  				time.Sleep(50 * time.Millisecond)
   228  			})
   229  
   230  			It("passes packets to existing connections", func() {
   231  				connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
   232  				p := getPacket(&wire.Header{
   233  					Type:             protocol.PacketTypeInitial,
   234  					DestConnectionID: connID,
   235  					Version:          serv.config.Versions[0],
   236  				}, make([]byte, protocol.MinInitialPacketSize))
   237  				conn := NewMockPacketHandler(mockCtrl)
   238  				phm.EXPECT().Get(connID).Return(conn, true)
   239  				handled := make(chan struct{})
   240  				conn.EXPECT().handlePacket(p).Do(func(receivedPacket) { close(handled) })
   241  				serv.handlePacket(p)
   242  				Eventually(handled).Should(BeClosed())
   243  			})
   244  
   245  			It("creates a connection when the token is accepted", func() {
   246  				serv.config.RequireAddressValidation = func(net.Addr) bool { return true }
   247  				raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
   248  				retryToken, err := serv.tokenGenerator.NewRetryToken(
   249  					raddr,
   250  					protocol.ParseConnectionID([]byte{0xde, 0xad, 0xc0, 0xde}),
   251  					protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}),
   252  				)
   253  				Expect(err).ToNot(HaveOccurred())
   254  				connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})
   255  				hdr := &wire.Header{
   256  					Type:             protocol.PacketTypeInitial,
   257  					SrcConnectionID:  protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
   258  					DestConnectionID: connID,
   259  					Version:          protocol.Version1,
   260  					Token:            retryToken,
   261  				}
   262  				p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
   263  				p.remoteAddr = raddr
   264  				run := make(chan struct{})
   265  				var token protocol.StatelessResetToken
   266  				rand.Read(token[:])
   267  
   268  				var newConnID protocol.ConnectionID
   269  
   270  				phm.EXPECT().Get(connID)
   271  				phm.EXPECT().AddWithConnID(connID, gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
   272  					newConnID = c
   273  					phm.EXPECT().GetStatelessResetToken(gomock.Any()).DoAndReturn(func(c protocol.ConnectionID) protocol.StatelessResetToken {
   274  						newConnID = c
   275  						return token
   276  					})
   277  					_, ok := fn()
   278  					return ok
   279  				})
   280  				conn := NewMockQUICConn(mockCtrl)
   281  				serv.newConn = func(
   282  					_ sendConn,
   283  					_ connRunner,
   284  					origDestConnID protocol.ConnectionID,
   285  					retrySrcConnID *protocol.ConnectionID,
   286  					clientDestConnID protocol.ConnectionID,
   287  					destConnID protocol.ConnectionID,
   288  					srcConnID protocol.ConnectionID,
   289  					_ ConnectionIDGenerator,
   290  					tokenP protocol.StatelessResetToken,
   291  					_ *Config,
   292  					_ *tls.Config,
   293  					_ *handshake.TokenGenerator,
   294  					_ bool,
   295  					_ *logging.ConnectionTracer,
   296  					_ uint64,
   297  					_ utils.Logger,
   298  					_ protocol.VersionNumber,
   299  				) quicConn {
   300  					Expect(origDestConnID).To(Equal(protocol.ParseConnectionID([]byte{0xde, 0xad, 0xc0, 0xde})))
   301  					Expect(*retrySrcConnID).To(Equal(protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad})))
   302  					Expect(clientDestConnID).To(Equal(hdr.DestConnectionID))
   303  					Expect(destConnID).To(Equal(hdr.SrcConnectionID))
   304  					// make sure we're using a server-generated connection ID
   305  					Expect(srcConnID).ToNot(Equal(hdr.DestConnectionID))
   306  					Expect(srcConnID).ToNot(Equal(hdr.SrcConnectionID))
   307  					Expect(srcConnID).To(Equal(newConnID))
   308  					Expect(tokenP).To(Equal(token))
   309  					conn.EXPECT().handlePacket(p)
   310  					conn.EXPECT().run().Do(func() { close(run) })
   311  					conn.EXPECT().Context().Return(context.Background())
   312  					conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
   313  					return conn
   314  				}
   315  
   316  				done := make(chan struct{})
   317  				go func() {
   318  					defer GinkgoRecover()
   319  					serv.handlePacket(p)
   320  					// the Handshake packet is written by the connection.
   321  					// Make sure there are no Write calls on the packet conn.
   322  					time.Sleep(50 * time.Millisecond)
   323  					close(done)
   324  				}()
   325  				// make sure we're using a server-generated connection ID
   326  				Eventually(run).Should(BeClosed())
   327  				Eventually(done).Should(BeClosed())
   328  			})
   329  
   330  			It("sends a Version Negotiation Packet for unsupported versions", func() {
   331  				srcConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5})
   332  				destConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6})
   333  				packet := getPacket(&wire.Header{
   334  					Type:             protocol.PacketTypeHandshake,
   335  					SrcConnectionID:  srcConnID,
   336  					DestConnectionID: destConnID,
   337  					Version:          0x42,
   338  				}, make([]byte, protocol.MinUnknownVersionPacketSize))
   339  				raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
   340  				packet.remoteAddr = raddr
   341  				tracer.EXPECT().SentVersionNegotiationPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, src, dest protocol.ArbitraryLenConnectionID, _ []protocol.VersionNumber) {
   342  					Expect(src).To(Equal(protocol.ArbitraryLenConnectionID(destConnID.Bytes())))
   343  					Expect(dest).To(Equal(protocol.ArbitraryLenConnectionID(srcConnID.Bytes())))
   344  				})
   345  				done := make(chan struct{})
   346  				conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
   347  					defer close(done)
   348  					Expect(wire.IsVersionNegotiationPacket(b)).To(BeTrue())
   349  					dest, src, versions, err := wire.ParseVersionNegotiationPacket(b)
   350  					Expect(err).ToNot(HaveOccurred())
   351  					Expect(dest).To(Equal(protocol.ArbitraryLenConnectionID(srcConnID.Bytes())))
   352  					Expect(src).To(Equal(protocol.ArbitraryLenConnectionID(destConnID.Bytes())))
   353  					Expect(versions).ToNot(ContainElement(protocol.VersionNumber(0x42)))
   354  					return len(b), nil
   355  				})
   356  				serv.handlePacket(packet)
   357  				Eventually(done).Should(BeClosed())
   358  			})
   359  
   360  			It("doesn't send a Version Negotiation packets if sending them is disabled", func() {
   361  				serv.disableVersionNegotiation = true
   362  				srcConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5})
   363  				destConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6})
   364  				packet := getPacket(&wire.Header{
   365  					Type:             protocol.PacketTypeHandshake,
   366  					SrcConnectionID:  srcConnID,
   367  					DestConnectionID: destConnID,
   368  					Version:          0x42,
   369  				}, make([]byte, protocol.MinUnknownVersionPacketSize))
   370  				raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
   371  				packet.remoteAddr = raddr
   372  				done := make(chan struct{})
   373  				conn.EXPECT().WriteTo(gomock.Any(), raddr).Do(func() { close(done) }).Times(0)
   374  				serv.handlePacket(packet)
   375  				Consistently(done, 50*time.Millisecond).ShouldNot(BeClosed())
   376  			})
   377  
   378  			It("ignores Version Negotiation packets", func() {
   379  				data := wire.ComposeVersionNegotiation(
   380  					protocol.ArbitraryLenConnectionID{1, 2, 3, 4},
   381  					protocol.ArbitraryLenConnectionID{4, 3, 2, 1},
   382  					[]protocol.VersionNumber{1, 2, 3},
   383  				)
   384  				raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
   385  				done := make(chan struct{})
   386  				tracer.EXPECT().DroppedPacket(raddr, logging.PacketTypeVersionNegotiation, protocol.ByteCount(len(data)), logging.PacketDropUnexpectedPacket).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) {
   387  					close(done)
   388  				})
   389  				serv.handlePacket(receivedPacket{
   390  					remoteAddr: raddr,
   391  					data:       data,
   392  					buffer:     getPacketBuffer(),
   393  				})
   394  				Eventually(done).Should(BeClosed())
   395  				// make sure no other packet is sent
   396  				time.Sleep(scaleDuration(20 * time.Millisecond))
   397  			})
   398  
   399  			It("doesn't send a Version Negotiation Packet for unsupported versions, if the packet is too small", func() {
   400  				srcConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5})
   401  				destConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6})
   402  				p := getPacket(&wire.Header{
   403  					Type:             protocol.PacketTypeHandshake,
   404  					SrcConnectionID:  srcConnID,
   405  					DestConnectionID: destConnID,
   406  					Version:          0x42,
   407  				}, make([]byte, protocol.MinUnknownVersionPacketSize-50))
   408  				Expect(p.Size()).To(BeNumerically("<", protocol.MinUnknownVersionPacketSize))
   409  				raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
   410  				p.remoteAddr = raddr
   411  				done := make(chan struct{})
   412  				tracer.EXPECT().DroppedPacket(raddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) {
   413  					close(done)
   414  				})
   415  				serv.handlePacket(p)
   416  				Eventually(done).Should(BeClosed())
   417  				// make sure no other packet is sent
   418  				time.Sleep(scaleDuration(20 * time.Millisecond))
   419  			})
   420  
   421  			It("replies with a Retry packet, if a token is required", func() {
   422  				connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})
   423  				serv.config.RequireAddressValidation = func(net.Addr) bool { return true }
   424  				hdr := &wire.Header{
   425  					Type:             protocol.PacketTypeInitial,
   426  					SrcConnectionID:  protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
   427  					DestConnectionID: connID,
   428  					Version:          protocol.Version1,
   429  				}
   430  				packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
   431  				raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
   432  				packet.remoteAddr = raddr
   433  				tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), nil).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, _ []logging.Frame) {
   434  					Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry))
   435  					Expect(replyHdr.SrcConnectionID).ToNot(Equal(hdr.DestConnectionID))
   436  					Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))
   437  					Expect(replyHdr.Token).ToNot(BeEmpty())
   438  				})
   439  				done := make(chan struct{})
   440  				conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
   441  					defer close(done)
   442  					replyHdr := parseHeader(b)
   443  					Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry))
   444  					Expect(replyHdr.SrcConnectionID).ToNot(Equal(hdr.DestConnectionID))
   445  					Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))
   446  					Expect(replyHdr.Token).ToNot(BeEmpty())
   447  					Expect(b[len(b)-16:]).To(Equal(handshake.GetRetryIntegrityTag(b[:len(b)-16], hdr.DestConnectionID, hdr.Version)[:]))
   448  					return len(b), nil
   449  				})
   450  				phm.EXPECT().Get(connID)
   451  				serv.handlePacket(packet)
   452  				Eventually(done).Should(BeClosed())
   453  			})
   454  
   455  			It("creates a connection, if no token is required", func() {
   456  				connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})
   457  				hdr := &wire.Header{
   458  					Type:             protocol.PacketTypeInitial,
   459  					SrcConnectionID:  protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
   460  					DestConnectionID: connID,
   461  					Version:          protocol.Version1,
   462  				}
   463  				p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
   464  				run := make(chan struct{})
   465  				var token protocol.StatelessResetToken
   466  				rand.Read(token[:])
   467  
   468  				var newConnID protocol.ConnectionID
   469  				gomock.InOrder(
   470  					phm.EXPECT().Get(connID),
   471  					phm.EXPECT().AddWithConnID(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
   472  						newConnID = c
   473  						phm.EXPECT().GetStatelessResetToken(gomock.Any()).DoAndReturn(func(c protocol.ConnectionID) protocol.StatelessResetToken {
   474  							newConnID = c
   475  							return token
   476  						})
   477  						_, ok := fn()
   478  						return ok
   479  					}),
   480  				)
   481  
   482  				conn := NewMockQUICConn(mockCtrl)
   483  				serv.newConn = func(
   484  					_ sendConn,
   485  					_ connRunner,
   486  					origDestConnID protocol.ConnectionID,
   487  					retrySrcConnID *protocol.ConnectionID,
   488  					clientDestConnID protocol.ConnectionID,
   489  					destConnID protocol.ConnectionID,
   490  					srcConnID protocol.ConnectionID,
   491  					_ ConnectionIDGenerator,
   492  					tokenP protocol.StatelessResetToken,
   493  					_ *Config,
   494  					_ *tls.Config,
   495  					_ *handshake.TokenGenerator,
   496  					_ bool,
   497  					_ *logging.ConnectionTracer,
   498  					_ uint64,
   499  					_ utils.Logger,
   500  					_ protocol.VersionNumber,
   501  				) quicConn {
   502  					Expect(origDestConnID).To(Equal(hdr.DestConnectionID))
   503  					Expect(retrySrcConnID).To(BeNil())
   504  					Expect(clientDestConnID).To(Equal(hdr.DestConnectionID))
   505  					Expect(destConnID).To(Equal(hdr.SrcConnectionID))
   506  					// make sure we're using a server-generated connection ID
   507  					Expect(srcConnID).ToNot(Equal(hdr.DestConnectionID))
   508  					Expect(srcConnID).ToNot(Equal(hdr.SrcConnectionID))
   509  					Expect(srcConnID).To(Equal(newConnID))
   510  					Expect(tokenP).To(Equal(token))
   511  					conn.EXPECT().handlePacket(p)
   512  					conn.EXPECT().run().Do(func() { close(run) })
   513  					conn.EXPECT().Context().Return(context.Background())
   514  					conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
   515  					return conn
   516  				}
   517  
   518  				done := make(chan struct{})
   519  				go func() {
   520  					defer GinkgoRecover()
   521  					serv.handlePacket(p)
   522  					// the Handshake packet is written by the connection
   523  					// make sure there are no Write calls on the packet conn
   524  					time.Sleep(50 * time.Millisecond)
   525  					close(done)
   526  				}()
   527  				// make sure we're using a server-generated connection ID
   528  				Eventually(run).Should(BeClosed())
   529  				Eventually(done).Should(BeClosed())
   530  			})
   531  
   532  			It("drops packets if the receive queue is full", func() {
   533  				phm.EXPECT().Get(gomock.Any()).AnyTimes()
   534  				phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
   535  					phm.EXPECT().GetStatelessResetToken(gomock.Any())
   536  					_, ok := fn()
   537  					return ok
   538  				}).AnyTimes()
   539  
   540  				acceptConn := make(chan struct{})
   541  				var counter uint32 // to be used as an atomic, so we query it in Eventually
   542  				serv.newConn = func(
   543  					_ sendConn,
   544  					runner connRunner,
   545  					_ protocol.ConnectionID,
   546  					_ *protocol.ConnectionID,
   547  					_ protocol.ConnectionID,
   548  					_ protocol.ConnectionID,
   549  					_ protocol.ConnectionID,
   550  					_ ConnectionIDGenerator,
   551  					_ protocol.StatelessResetToken,
   552  					_ *Config,
   553  					_ *tls.Config,
   554  					_ *handshake.TokenGenerator,
   555  					_ bool,
   556  					_ *logging.ConnectionTracer,
   557  					_ uint64,
   558  					_ utils.Logger,
   559  					_ protocol.VersionNumber,
   560  				) quicConn {
   561  					<-acceptConn
   562  					atomic.AddUint32(&counter, 1)
   563  					conn := NewMockQUICConn(mockCtrl)
   564  					conn.EXPECT().handlePacket(gomock.Any()).MaxTimes(1)
   565  					conn.EXPECT().run().MaxTimes(1)
   566  					conn.EXPECT().Context().Return(context.Background()).MaxTimes(1)
   567  					conn.EXPECT().HandshakeComplete().Return(make(chan struct{})).MaxTimes(1)
   568  					return conn
   569  				}
   570  
   571  				p := getInitial(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}))
   572  				serv.handlePacket(p)
   573  				tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropDOSPrevention).MinTimes(1)
   574  				var wg sync.WaitGroup
   575  				for i := 0; i < 3*protocol.MaxServerUnprocessedPackets; i++ {
   576  					wg.Add(1)
   577  					go func() {
   578  						defer GinkgoRecover()
   579  						defer wg.Done()
   580  						serv.handlePacket(getInitial(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})))
   581  					}()
   582  				}
   583  				wg.Wait()
   584  
   585  				close(acceptConn)
   586  				Eventually(
   587  					func() uint32 { return atomic.LoadUint32(&counter) },
   588  					scaleDuration(100*time.Millisecond),
   589  				).Should(BeEquivalentTo(protocol.MaxServerUnprocessedPackets + 1))
   590  				Consistently(func() uint32 { return atomic.LoadUint32(&counter) }).Should(BeEquivalentTo(protocol.MaxServerUnprocessedPackets + 1))
   591  			})
   592  
   593  			It("only creates a single connection for a duplicate Initial", func() {
   594  				var createdConn bool
   595  				serv.newConn = func(
   596  					_ sendConn,
   597  					runner connRunner,
   598  					_ protocol.ConnectionID,
   599  					_ *protocol.ConnectionID,
   600  					_ protocol.ConnectionID,
   601  					_ protocol.ConnectionID,
   602  					_ protocol.ConnectionID,
   603  					_ ConnectionIDGenerator,
   604  					_ protocol.StatelessResetToken,
   605  					_ *Config,
   606  					_ *tls.Config,
   607  					_ *handshake.TokenGenerator,
   608  					_ bool,
   609  					_ *logging.ConnectionTracer,
   610  					_ uint64,
   611  					_ utils.Logger,
   612  					_ protocol.VersionNumber,
   613  				) quicConn {
   614  					createdConn = true
   615  					return NewMockQUICConn(mockCtrl)
   616  				}
   617  
   618  				connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9})
   619  				p := getInitial(connID)
   620  				phm.EXPECT().Get(connID)
   621  				phm.EXPECT().AddWithConnID(connID, gomock.Any(), gomock.Any()).Return(false)
   622  				tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
   623  				done := make(chan struct{})
   624  				conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).Do(func([]byte, net.Addr) { close(done) })
   625  				Expect(serv.handlePacketImpl(p)).To(BeTrue())
   626  				Expect(createdConn).To(BeFalse())
   627  				Eventually(done).Should(BeClosed())
   628  			})
   629  
   630  			It("rejects new connection attempts if the accept queue is full", func() {
   631  				serv.newConn = func(
   632  					_ sendConn,
   633  					runner connRunner,
   634  					_ protocol.ConnectionID,
   635  					_ *protocol.ConnectionID,
   636  					_ protocol.ConnectionID,
   637  					_ protocol.ConnectionID,
   638  					_ protocol.ConnectionID,
   639  					_ ConnectionIDGenerator,
   640  					_ protocol.StatelessResetToken,
   641  					_ *Config,
   642  					_ *tls.Config,
   643  					_ *handshake.TokenGenerator,
   644  					_ bool,
   645  					_ *logging.ConnectionTracer,
   646  					_ uint64,
   647  					_ utils.Logger,
   648  					_ protocol.VersionNumber,
   649  				) quicConn {
   650  					conn := NewMockQUICConn(mockCtrl)
   651  					conn.EXPECT().handlePacket(gomock.Any())
   652  					conn.EXPECT().run()
   653  					conn.EXPECT().Context().Return(context.Background())
   654  					c := make(chan struct{})
   655  					close(c)
   656  					conn.EXPECT().HandshakeComplete().Return(c)
   657  					return conn
   658  				}
   659  
   660  				phm.EXPECT().Get(gomock.Any()).Times(protocol.MaxAcceptQueueSize + 1)
   661  				phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
   662  					phm.EXPECT().GetStatelessResetToken(gomock.Any())
   663  					_, ok := fn()
   664  					return ok
   665  				}).Times(protocol.MaxAcceptQueueSize)
   666  
   667  				var wg sync.WaitGroup
   668  				wg.Add(protocol.MaxAcceptQueueSize)
   669  				for i := 0; i < protocol.MaxAcceptQueueSize; i++ {
   670  					go func() {
   671  						defer GinkgoRecover()
   672  						defer wg.Done()
   673  						serv.handlePacket(getInitialWithRandomDestConnID())
   674  						// make sure there are no Write calls on the packet conn
   675  						time.Sleep(50 * time.Millisecond)
   676  					}()
   677  				}
   678  				wg.Wait()
   679  				p := getInitialWithRandomDestConnID()
   680  				hdr, _, _, err := wire.ParsePacket(p.data)
   681  				Expect(err).ToNot(HaveOccurred())
   682  				tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
   683  				done := make(chan struct{})
   684  				conn.EXPECT().WriteTo(gomock.Any(), p.remoteAddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
   685  					defer close(done)
   686  					rejectHdr := parseHeader(b)
   687  					Expect(rejectHdr.Type).To(Equal(protocol.PacketTypeInitial))
   688  					Expect(rejectHdr.Version).To(Equal(hdr.Version))
   689  					Expect(rejectHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))
   690  					Expect(rejectHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID))
   691  					return len(b), nil
   692  				})
   693  				serv.handlePacket(p)
   694  				Eventually(done).Should(BeClosed())
   695  			})
   696  
			// A connection whose context is canceled after it was created, but before
			// anyone called Accept, must never be returned from Accept: Accept keeps
			// blocking (Consistently below) and only returns once the server is closed.
			It("doesn't accept new connections if they were closed in the mean time", func() {
				p := getInitial(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}))
				// ctx is what the mock connection's Context() returns; canceling it
				// simulates the connection being closed.
				ctx, cancel := context.WithCancel(context.Background())
				connCreated := make(chan struct{})
				conn := NewMockQUICConn(mockCtrl)
				serv.newConn = func(
					_ sendConn,
					runner connRunner,
					_ protocol.ConnectionID,
					_ *protocol.ConnectionID,
					_ protocol.ConnectionID,
					_ protocol.ConnectionID,
					_ protocol.ConnectionID,
					_ ConnectionIDGenerator,
					_ protocol.StatelessResetToken,
					_ *Config,
					_ *tls.Config,
					_ *handshake.TokenGenerator,
					_ bool,
					_ *logging.ConnectionTracer,
					_ uint64,
					_ utils.Logger,
					_ protocol.VersionNumber,
				) quicConn {
					conn.EXPECT().handlePacket(p)
					conn.EXPECT().run()
					conn.EXPECT().Context().Return(ctx)
					// HandshakeComplete returns an already-closed channel, so waiting on it never blocks.
					c := make(chan struct{})
					close(c)
					conn.EXPECT().HandshakeComplete().Return(c)
					close(connCreated)
					return conn
				}

				phm.EXPECT().Get(gomock.Any())
				phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
					phm.EXPECT().GetStatelessResetToken(gomock.Any())
					// Run the constructor callback so that serv.newConn above is invoked.
					_, ok := fn()
					return ok
				})

				serv.handlePacket(p)
				// make sure there are no Write calls on the packet conn
				time.Sleep(50 * time.Millisecond)
				Eventually(connCreated).Should(BeClosed())
				cancel()
				// NOTE(review): presumably gives the server time to observe the
				// cancellation before Accept is called below.
				time.Sleep(scaleDuration(200 * time.Millisecond))

				done := make(chan struct{})
				go func() {
					defer GinkgoRecover()
					serv.Accept(context.Background())
					close(done)
				}()
				// The canceled connection must not be handed out.
				Consistently(done).ShouldNot(BeClosed())

				// make the go routine return
				conn.EXPECT().getPerspective().MaxTimes(2) // initOnce for every conn ID
				Expect(serv.Close()).To(Succeed())
				Eventually(done).Should(BeClosed())
			})
   758  		})
   759  
   760  		Context("token validation", func() {
			// checkInvalidToken asserts that b, the packet the server wrote in reply to
			// the Initial described by origHdr, is an Initial packet sent back to the
			// client (source and destination connection IDs swapped) whose first frame,
			// after decryption, is a CONNECTION_CLOSE with error code INVALID_TOKEN and
			// an empty reason phrase.
			checkInvalidToken := func(b []byte, origHdr *wire.Header) {
				replyHdr := parseHeader(b)
				Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial))
				Expect(replyHdr.SrcConnectionID).To(Equal(origHdr.DestConnectionID))
				Expect(replyHdr.DestConnectionID).To(Equal(origHdr.SrcConnectionID))
				// Decrypt as the client would: Initial keys are derived from the
				// destination connection ID of the client's original Initial.
				_, opener := handshake.NewInitialAEAD(origHdr.DestConnectionID, protocol.PerspectiveClient, replyHdr.Version)
				extHdr, err := unpackLongHeader(opener, replyHdr, b, origHdr.Version)
				Expect(err).ToNot(HaveOccurred())
				data, err := opener.Open(nil, b[extHdr.ParsedLen():], extHdr.PacketNumber, b[:extHdr.ParsedLen()])
				Expect(err).ToNot(HaveOccurred())
				_, f, err := wire.NewFrameParser(false).ParseNext(data, protocol.EncryptionInitial, origHdr.Version)
				Expect(err).ToNot(HaveOccurred())
				Expect(f).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{}))
				ccf := f.(*wire.ConnectionCloseFrame)
				Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.InvalidToken))
				Expect(ccf.ReasonPhrase).To(BeEmpty())
			}
   778  
			// An Initial carrying a retry token issued for its own source address is
			// processed as a new connection attempt: the test passes once the server
			// reaches AddWithConnID (whose constructor callback is deliberately not run).
			It("decodes the token from the token field", func() {
				raddr := &net.UDPAddr{IP: net.IPv4(192, 168, 13, 37), Port: 1337}
				token, err := serv.tokenGenerator.NewRetryToken(raddr, protocol.ConnectionID{}, protocol.ConnectionID{})
				Expect(err).ToNot(HaveOccurred())
				packet := getPacket(&wire.Header{
					Type:    protocol.PacketTypeInitial,
					Token:   token,
					Version: serv.config.Versions[0],
				}, make([]byte, protocol.MinInitialPacketSize))
				// Send from the same address the token was issued for.
				packet.remoteAddr = raddr
				// At most one write (and its trace event) is tolerated but not required.
				conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).MaxTimes(1)
				tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1)

				done := make(chan struct{})
				phm.EXPECT().Get(gomock.Any())
				phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_, _ protocol.ConnectionID, _ func() (packetHandler, bool)) { close(done) })
				serv.handlePacket(packet)
				Eventually(done).Should(BeClosed())
			})
   798  
			// A retry token issued for a different address than the sender's is
			// answered with an Initial containing a CONNECTION_CLOSE (INVALID_TOKEN).
			// This is checked both on the traced packet and on the bytes written.
			It("sends an INVALID_TOKEN error, if an invalid retry token is received", func() {
				serv.config.RequireAddressValidation = func(net.Addr) bool { return true }
				// Issued for the zero address, not for raddr used below.
				token, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, protocol.ConnectionID{}, protocol.ConnectionID{})
				Expect(err).ToNot(HaveOccurred())
				hdr := &wire.Header{
					Type:             protocol.PacketTypeInitial,
					SrcConnectionID:  protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
					DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}),
					Token:            token,
					Version:          protocol.Version1,
				}
				packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
				packet.data = append(packet.data, []byte("coalesced packet")...) // add some garbage to simulate a coalesced packet
				raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
				packet.remoteAddr = raddr
				// The traced reply: an Initial with swapped connection IDs carrying a
				// single transport-level CONNECTION_CLOSE with INVALID_TOKEN.
				tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) {
					Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial))
					Expect(replyHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID))
					Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))
					Expect(frames).To(HaveLen(1))
					Expect(frames[0]).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{}))
					ccf := frames[0].(*logging.ConnectionCloseFrame)
					Expect(ccf.IsApplicationError).To(BeFalse())
					Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.InvalidToken))
				})
				done := make(chan struct{})
				conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
					defer close(done)
					checkInvalidToken(b, hdr)
					return len(b), nil
				})
				phm.EXPECT().Get(gomock.Any())
				serv.handlePacket(packet)
				Eventually(done).Should(BeClosed())
			})
   834  
			// A retry token that has outlived maxRetryTokenAge is rejected the same way
			// as an invalid one: an Initial containing CONNECTION_CLOSE (INVALID_TOKEN).
			It("sends an INVALID_TOKEN error, if an expired retry token is received", func() {
				serv.config.RequireAddressValidation = func(net.Addr) bool { return true }
				serv.config.HandshakeIdleTimeout = time.Millisecond / 2 // the maximum retry token age is equivalent to the handshake timeout
				// With a 0.5ms handshake idle timeout, the retry token age limit works out to 1ms.
				Expect(serv.config.maxRetryTokenAge()).To(Equal(time.Millisecond))
				raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
				token, err := serv.tokenGenerator.NewRetryToken(raddr, protocol.ConnectionID{}, protocol.ConnectionID{})
				Expect(err).ToNot(HaveOccurred())
				time.Sleep(2 * time.Millisecond) // make sure the token is expired
				hdr := &wire.Header{
					Type:             protocol.PacketTypeInitial,
					SrcConnectionID:  protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
					DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}),
					Token:            token,
					Version:          protocol.Version1,
				}
				packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
				packet.remoteAddr = raddr
				tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) {
					Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial))
					Expect(replyHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID))
					Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))
					Expect(frames).To(HaveLen(1))
					Expect(frames[0]).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{}))
					ccf := frames[0].(*logging.ConnectionCloseFrame)
					Expect(ccf.IsApplicationError).To(BeFalse())
					Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.InvalidToken))
				})
				done := make(chan struct{})
				conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
					defer close(done)
					checkInvalidToken(b, hdr)
					return len(b), nil
				})
				phm.EXPECT().Get(gomock.Any())
				serv.handlePacket(packet)
				Eventually(done).Should(BeClosed())
			})
   872  
			// A (non-retry) address-validation token issued for another address does not
			// trigger INVALID_TOKEN; with address validation required, the server replies
			// with a Retry packet instead.
			It("doesn't send an INVALID_TOKEN error, if an invalid non-retry token is received", func() {
				serv.config.RequireAddressValidation = func(net.Addr) bool { return true }
				token, err := serv.tokenGenerator.NewToken(&net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337})
				Expect(err).ToNot(HaveOccurred())
				hdr := &wire.Header{
					Type:             protocol.PacketTypeInitial,
					SrcConnectionID:  protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
					DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}),
					Token:            token,
					Version:          protocol.Version1,
				}
				packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
				// NOTE(review): the Retry below is still expected even though the payload is corrupted.
				packet.data[len(packet.data)-10] ^= 0xff // corrupt the packet
				raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
				packet.remoteAddr = raddr
				tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1)
				done := make(chan struct{})
				conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
					defer close(done)
					replyHdr := parseHeader(b)
					Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry))
					return len(b), nil
				})
				phm.EXPECT().Get(gomock.Any())
				serv.handlePacket(packet)
				// wait until the Retry (not an INVALID_TOKEN CONNECTION_CLOSE) has been written
				Eventually(done).Should(BeClosed())
			})
   901  
   902  			It("sends an INVALID_TOKEN error, if an expired non-retry token is received", func() {
   903  				serv.config.RequireAddressValidation = func(net.Addr) bool { return true }
   904  				serv.maxTokenAge = time.Millisecond
   905  				raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
   906  				token, err := serv.tokenGenerator.NewToken(raddr)
   907  				Expect(err).ToNot(HaveOccurred())
   908  				time.Sleep(2 * time.Millisecond) // make sure the token is expired
   909  				hdr := &wire.Header{
   910  					Type:             protocol.PacketTypeInitial,
   911  					SrcConnectionID:  protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
   912  					DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}),
   913  					Token:            token,
   914  					Version:          protocol.Version1,
   915  				}
   916  				packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
   917  				packet.remoteAddr = raddr
   918  				tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) {
   919  					Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry))
   920  				})
   921  				done := make(chan struct{})
   922  				conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
   923  					defer close(done)
   924  					return len(b), nil
   925  				})
   926  				phm.EXPECT().Get(gomock.Any())
   927  				serv.handlePacket(packet)
   928  				Eventually(done).Should(BeClosed())
   929  			})
   930  
			// An Initial whose payload fails to decrypt is dropped and traced as
			// PacketDropPayloadDecryptError, even though it carries an invalid retry
			// token. No WriteTo expectation is registered, so any reply would fail the test.
			It("doesn't send an INVALID_TOKEN error, if the packet is corrupted", func() {
				serv.config.RequireAddressValidation = func(net.Addr) bool { return true }
				token, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, protocol.ConnectionID{}, protocol.ConnectionID{})
				Expect(err).ToNot(HaveOccurred())
				hdr := &wire.Header{
					Type:             protocol.PacketTypeInitial,
					SrcConnectionID:  protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
					DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}),
					Token:            token,
					Version:          protocol.Version1,
				}
				packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
				packet.data[len(packet.data)-10] ^= 0xff // corrupt the packet
				packet.remoteAddr = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
				done := make(chan struct{})
				tracer.EXPECT().DroppedPacket(packet.remoteAddr, logging.PacketTypeInitial, packet.Size(), logging.PacketDropPayloadDecryptError).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { close(done) })
				phm.EXPECT().Get(gomock.Any())
				serv.handlePacket(packet)
				// make sure there are no Write calls on the packet conn
				time.Sleep(50 * time.Millisecond)
				Eventually(done).Should(BeClosed())
			})
   953  		})
   954  
   955  		Context("accepting connections", func() {
   956  			It("returns Accept when an error occurs", func() {
   957  				testErr := errors.New("test err")
   958  
   959  				done := make(chan struct{})
   960  				go func() {
   961  					defer GinkgoRecover()
   962  					_, err := serv.Accept(context.Background())
   963  					Expect(err).To(MatchError(testErr))
   964  					close(done)
   965  				}()
   966  
   967  				serv.setCloseError(testErr)
   968  				Eventually(done).Should(BeClosed())
   969  				serv.onClose() // shutdown
   970  			})
   971  
   972  			It("returns immediately, if an error occurred before", func() {
   973  				testErr := errors.New("test err")
   974  				serv.setCloseError(testErr)
   975  				for i := 0; i < 3; i++ {
   976  					_, err := serv.Accept(context.Background())
   977  					Expect(err).To(MatchError(testErr))
   978  				}
   979  				serv.onClose() // shutdown
   980  			})
   981  
   982  			It("returns when the context is canceled", func() {
   983  				ctx, cancel := context.WithCancel(context.Background())
   984  				done := make(chan struct{})
   985  				go func() {
   986  					defer GinkgoRecover()
   987  					_, err := serv.Accept(ctx)
   988  					Expect(err).To(MatchError("context canceled"))
   989  					close(done)
   990  				}()
   991  
   992  				Consistently(done).ShouldNot(BeClosed())
   993  				cancel()
   994  				Eventually(done).Should(BeClosed())
   995  			})
   996  
   997  			It("uses the config returned by GetConfigClient", func() {
   998  				conn := NewMockQUICConn(mockCtrl)
   999  
  1000  				conf := &Config{MaxIncomingStreams: 1234}
  1001  				serv.config = populateServerConfig(&Config{GetConfigForClient: func(*ClientHelloInfo) (*Config, error) { return conf, nil }})
  1002  				done := make(chan struct{})
  1003  				go func() {
  1004  					defer GinkgoRecover()
  1005  					s, err := serv.Accept(context.Background())
  1006  					Expect(err).ToNot(HaveOccurred())
  1007  					Expect(s).To(Equal(conn))
  1008  					close(done)
  1009  				}()
  1010  
  1011  				handshakeChan := make(chan struct{})
  1012  				serv.newConn = func(
  1013  					_ sendConn,
  1014  					_ connRunner,
  1015  					_ protocol.ConnectionID,
  1016  					_ *protocol.ConnectionID,
  1017  					_ protocol.ConnectionID,
  1018  					_ protocol.ConnectionID,
  1019  					_ protocol.ConnectionID,
  1020  					_ ConnectionIDGenerator,
  1021  					_ protocol.StatelessResetToken,
  1022  					conf *Config,
  1023  					_ *tls.Config,
  1024  					_ *handshake.TokenGenerator,
  1025  					_ bool,
  1026  					_ *logging.ConnectionTracer,
  1027  					_ uint64,
  1028  					_ utils.Logger,
  1029  					_ protocol.VersionNumber,
  1030  				) quicConn {
  1031  					Expect(conf.MaxIncomingStreams).To(BeEquivalentTo(1234))
  1032  					conn.EXPECT().handlePacket(gomock.Any())
  1033  					conn.EXPECT().HandshakeComplete().Return(handshakeChan)
  1034  					conn.EXPECT().run().Do(func() {})
  1035  					conn.EXPECT().Context().Return(context.Background())
  1036  					return conn
  1037  				}
  1038  				phm.EXPECT().Get(gomock.Any())
  1039  				phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
  1040  					phm.EXPECT().GetStatelessResetToken(gomock.Any())
  1041  					_, ok := fn()
  1042  					return ok
  1043  				})
  1044  				serv.handleInitialImpl(
  1045  					receivedPacket{buffer: getPacketBuffer()},
  1046  					&wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})},
  1047  				)
  1048  				Consistently(done).ShouldNot(BeClosed())
  1049  				close(handshakeChan) // complete the handshake
  1050  				Eventually(done).Should(BeClosed())
  1051  			})
  1052  
  1053  			It("rejects a connection attempt when GetConfigClient returns an error", func() {
  1054  				serv.config = populateServerConfig(&Config{GetConfigForClient: func(*ClientHelloInfo) (*Config, error) { return nil, errors.New("rejected") }})
  1055  
  1056  				phm.EXPECT().Get(gomock.Any())
  1057  				phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
  1058  					_, ok := fn()
  1059  					return ok
  1060  				})
  1061  				done := make(chan struct{})
  1062  				tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
  1063  				conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
  1064  					defer close(done)
  1065  					rejectHdr := parseHeader(b)
  1066  					Expect(rejectHdr.Type).To(Equal(protocol.PacketTypeInitial))
  1067  					return len(b), nil
  1068  				})
  1069  				serv.handleInitialImpl(
  1070  					receivedPacket{buffer: getPacketBuffer()},
  1071  					&wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), Version: protocol.Version1},
  1072  				)
  1073  				Eventually(done).Should(BeClosed())
  1074  			})
  1075  
  1076  			It("accepts new connections when the handshake completes", func() {
  1077  				conn := NewMockQUICConn(mockCtrl)
  1078  
  1079  				done := make(chan struct{})
  1080  				go func() {
  1081  					defer GinkgoRecover()
  1082  					s, err := serv.Accept(context.Background())
  1083  					Expect(err).ToNot(HaveOccurred())
  1084  					Expect(s).To(Equal(conn))
  1085  					close(done)
  1086  				}()
  1087  
  1088  				handshakeChan := make(chan struct{})
  1089  				serv.newConn = func(
  1090  					_ sendConn,
  1091  					runner connRunner,
  1092  					_ protocol.ConnectionID,
  1093  					_ *protocol.ConnectionID,
  1094  					_ protocol.ConnectionID,
  1095  					_ protocol.ConnectionID,
  1096  					_ protocol.ConnectionID,
  1097  					_ ConnectionIDGenerator,
  1098  					_ protocol.StatelessResetToken,
  1099  					_ *Config,
  1100  					_ *tls.Config,
  1101  					_ *handshake.TokenGenerator,
  1102  					_ bool,
  1103  					_ *logging.ConnectionTracer,
  1104  					_ uint64,
  1105  					_ utils.Logger,
  1106  					_ protocol.VersionNumber,
  1107  				) quicConn {
  1108  					conn.EXPECT().handlePacket(gomock.Any())
  1109  					conn.EXPECT().HandshakeComplete().Return(handshakeChan)
  1110  					conn.EXPECT().run().Do(func() {})
  1111  					conn.EXPECT().Context().Return(context.Background())
  1112  					return conn
  1113  				}
  1114  				phm.EXPECT().Get(gomock.Any())
  1115  				phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
  1116  					phm.EXPECT().GetStatelessResetToken(gomock.Any())
  1117  					_, ok := fn()
  1118  					return ok
  1119  				})
  1120  				serv.handleInitialImpl(
  1121  					receivedPacket{buffer: getPacketBuffer()},
  1122  					&wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})},
  1123  				)
  1124  				Consistently(done).ShouldNot(BeClosed())
  1125  				close(handshakeChan) // complete the handshake
  1126  				Eventually(done).Should(BeClosed())
  1127  			})
  1128  		})
  1129  	})
  1130  
  1131  	Context("server accepting connections that haven't completed the handshake", func() {
		var (
			serv *EarlyListener
			phm  *MockPacketHandlerManager
		)

		// Each test gets a fresh EarlyListener on the shared mock packet conn, with
		// its base server's packet handler manager replaced by a mock.
		BeforeEach(func() {
			var err error
			serv, err = ListenEarly(conn, tlsConf, nil)
			Expect(err).ToNot(HaveOccurred())
			phm = NewMockPacketHandlerManager(mockCtrl)
			serv.baseServer.connHandler = phm
		})

		// Close the listener after each test; the returned error is not checked.
		AfterEach(func() {
			serv.Close()
		})
  1148  
  1149  		It("accepts new connections when they become ready", func() {
  1150  			conn := NewMockQUICConn(mockCtrl)
  1151  
  1152  			done := make(chan struct{})
  1153  			go func() {
  1154  				defer GinkgoRecover()
  1155  				s, err := serv.Accept(context.Background())
  1156  				Expect(err).ToNot(HaveOccurred())
  1157  				Expect(s).To(Equal(conn))
  1158  				close(done)
  1159  			}()
  1160  
  1161  			ready := make(chan struct{})
  1162  			serv.baseServer.newConn = func(
  1163  				_ sendConn,
  1164  				runner connRunner,
  1165  				_ protocol.ConnectionID,
  1166  				_ *protocol.ConnectionID,
  1167  				_ protocol.ConnectionID,
  1168  				_ protocol.ConnectionID,
  1169  				_ protocol.ConnectionID,
  1170  				_ ConnectionIDGenerator,
  1171  				_ protocol.StatelessResetToken,
  1172  				_ *Config,
  1173  				_ *tls.Config,
  1174  				_ *handshake.TokenGenerator,
  1175  				_ bool,
  1176  				_ *logging.ConnectionTracer,
  1177  				_ uint64,
  1178  				_ utils.Logger,
  1179  				_ protocol.VersionNumber,
  1180  			) quicConn {
  1181  				conn.EXPECT().handlePacket(gomock.Any())
  1182  				conn.EXPECT().run().Do(func() {})
  1183  				conn.EXPECT().earlyConnReady().Return(ready)
  1184  				conn.EXPECT().Context().Return(context.Background())
  1185  				return conn
  1186  			}
  1187  			phm.EXPECT().Get(gomock.Any())
  1188  			phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
  1189  				phm.EXPECT().GetStatelessResetToken(gomock.Any())
  1190  				_, ok := fn()
  1191  				return ok
  1192  			})
  1193  			serv.baseServer.handleInitialImpl(
  1194  				receivedPacket{buffer: getPacketBuffer()},
  1195  				&wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})},
  1196  			)
  1197  			Consistently(done).ShouldNot(BeClosed())
  1198  			close(ready)
  1199  			Eventually(done).Should(BeClosed())
  1200  		})
  1201  
  1202  		It("rejects new connection attempts if the accept queue is full", func() {
  1203  			senderAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 42}
  1204  
  1205  			serv.baseServer.newConn = func(
  1206  				_ sendConn,
  1207  				runner connRunner,
  1208  				_ protocol.ConnectionID,
  1209  				_ *protocol.ConnectionID,
  1210  				_ protocol.ConnectionID,
  1211  				_ protocol.ConnectionID,
  1212  				_ protocol.ConnectionID,
  1213  				_ ConnectionIDGenerator,
  1214  				_ protocol.StatelessResetToken,
  1215  				_ *Config,
  1216  				_ *tls.Config,
  1217  				_ *handshake.TokenGenerator,
  1218  				_ bool,
  1219  				_ *logging.ConnectionTracer,
  1220  				_ uint64,
  1221  				_ utils.Logger,
  1222  				_ protocol.VersionNumber,
  1223  			) quicConn {
  1224  				ready := make(chan struct{})
  1225  				close(ready)
  1226  				conn := NewMockQUICConn(mockCtrl)
  1227  				conn.EXPECT().handlePacket(gomock.Any())
  1228  				conn.EXPECT().run()
  1229  				conn.EXPECT().earlyConnReady().Return(ready)
  1230  				conn.EXPECT().Context().Return(context.Background())
  1231  				return conn
  1232  			}
  1233  
  1234  			phm.EXPECT().Get(gomock.Any()).AnyTimes()
  1235  			phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
  1236  				phm.EXPECT().GetStatelessResetToken(gomock.Any())
  1237  				_, ok := fn()
  1238  				return ok
  1239  			}).Times(protocol.MaxAcceptQueueSize)
  1240  			for i := 0; i < protocol.MaxAcceptQueueSize; i++ {
  1241  				serv.baseServer.handlePacket(getInitialWithRandomDestConnID())
  1242  			}
  1243  
  1244  			Eventually(func() int32 { return atomic.LoadInt32(&serv.baseServer.connQueueLen) }).Should(BeEquivalentTo(protocol.MaxAcceptQueueSize))
  1245  			// make sure there are no Write calls on the packet conn
  1246  			time.Sleep(50 * time.Millisecond)
  1247  
  1248  			p := getInitialWithRandomDestConnID()
  1249  			hdr := parseHeader(p.data)
  1250  			done := make(chan struct{})
  1251  			conn.EXPECT().WriteTo(gomock.Any(), senderAddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
  1252  				defer close(done)
  1253  				rejectHdr := parseHeader(b)
  1254  				Expect(rejectHdr.Type).To(Equal(protocol.PacketTypeInitial))
  1255  				Expect(rejectHdr.Version).To(Equal(hdr.Version))
  1256  				Expect(rejectHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))
  1257  				Expect(rejectHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID))
  1258  				return len(b), nil
  1259  			})
  1260  			serv.baseServer.handlePacket(p)
  1261  			Eventually(done).Should(BeClosed())
  1262  		})
  1263  
		// Early-listener variant: a connection whose context is canceled before it
		// is accepted must never be returned from Accept; Accept keeps blocking
		// (Consistently below) and only returns once the listener is closed.
		It("doesn't accept new connections if they were closed in the mean time", func() {
			p := getInitial(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}))
			// ctx is what the mock connection's Context() returns; canceling it
			// simulates the connection being closed.
			ctx, cancel := context.WithCancel(context.Background())
			connCreated := make(chan struct{})
			conn := NewMockQUICConn(mockCtrl)
			serv.baseServer.newConn = func(
				_ sendConn,
				runner connRunner,
				_ protocol.ConnectionID,
				_ *protocol.ConnectionID,
				_ protocol.ConnectionID,
				_ protocol.ConnectionID,
				_ protocol.ConnectionID,
				_ ConnectionIDGenerator,
				_ protocol.StatelessResetToken,
				_ *Config,
				_ *tls.Config,
				_ *handshake.TokenGenerator,
				_ bool,
				_ *logging.ConnectionTracer,
				_ uint64,
				_ utils.Logger,
				_ protocol.VersionNumber,
			) quicConn {
				conn.EXPECT().handlePacket(p)
				conn.EXPECT().run()
				// No Return is set, so earlyConnReady yields a nil channel.
				conn.EXPECT().earlyConnReady()
				conn.EXPECT().Context().Return(ctx)
				close(connCreated)
				return conn
			}

			phm.EXPECT().Get(gomock.Any())
			phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
				phm.EXPECT().GetStatelessResetToken(gomock.Any())
				// Run the constructor callback so that newConn above is invoked.
				_, ok := fn()
				return ok
			})
			serv.baseServer.handlePacket(p)
			// make sure there are no Write calls on the packet conn
			time.Sleep(50 * time.Millisecond)
			Eventually(connCreated).Should(BeClosed())
			cancel()
			// NOTE(review): presumably gives the server time to observe the
			// cancellation before Accept is called below.
			time.Sleep(scaleDuration(200 * time.Millisecond))

			done := make(chan struct{})
			go func() {
				defer GinkgoRecover()
				serv.Accept(context.Background())
				close(done)
			}()
			// The canceled connection must not be handed out.
			Consistently(done).ShouldNot(BeClosed())

			// make the go routine return
			conn.EXPECT().getPerspective().MaxTimes(2) // initOnce for every conn ID
			Expect(serv.Close()).To(Succeed())
			Eventually(done).Should(BeClosed())
		})
  1322  	})
  1323  
  1324  	Context("0-RTT", func() {
		var (
			tr     *Transport
			serv   *baseServer
			phm    *MockPacketHandlerManager
			tracer *mocklogging.MockTracer
		)

		// Each test runs against a fresh early listener created through a Transport
		// that uses a mock tracer. The listener's packet handler manager is swapped
		// for a mock.
		BeforeEach(func() {
			var t *logging.Tracer
			t, tracer = mocklogging.NewMockTracer(mockCtrl)
			tr = &Transport{Conn: conn, Tracer: t}
			ln, err := tr.ListenEarly(tlsConf, nil)
			Expect(err).ToNot(HaveOccurred())
			phm = NewMockPacketHandlerManager(mockCtrl)
			serv = ln.baseServer
			serv.connHandler = phm
		})

		AfterEach(func() {
			// Closing the transport is allowed to call CloseServer at most once.
			phm.EXPECT().CloseServer().MaxTimes(1)
			tr.Close()
		})
  1347  
		// A 0-RTT packet whose destination connection ID is already known to the
		// packet handler manager is handed directly to that handler's handlePacket.
		It("passes packets to existing connections", func() {
			connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
			p := getPacket(&wire.Header{
				Type:             protocol.PacketType0RTT,
				DestConnectionID: connID,
				Version:          serv.config.Versions[0],
			}, make([]byte, 100))
			conn := NewMockPacketHandler(mockCtrl)
			// The manager reports an existing handler for connID.
			phm.EXPECT().Get(connID).Return(conn, true)
			handled := make(chan struct{})
			conn.EXPECT().handlePacket(p).Do(func(receivedPacket) { close(handled) })
			serv.handlePacket(p)
			Eventually(handled).Should(BeClosed())
		})
  1362  
  1363  		It("queues 0-RTT packets, up to Max0RTTQueueSize", func() {
  1364  			connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
  1365  
  1366  			var zeroRTTPackets []receivedPacket
  1367  
  1368  			for i := 0; i < protocol.Max0RTTQueueLen; i++ {
  1369  				p := getPacket(&wire.Header{
  1370  					Type:             protocol.PacketType0RTT,
  1371  					DestConnectionID: connID,
  1372  					Version:          serv.config.Versions[0],
  1373  				}, make([]byte, 100+i))
  1374  				phm.EXPECT().Get(connID)
  1375  				serv.handlePacket(p)
  1376  				zeroRTTPackets = append(zeroRTTPackets, p)
  1377  			}
  1378  
  1379  			// send one more packet, this one should be dropped
  1380  			p := getPacket(&wire.Header{
  1381  				Type:             protocol.PacketType0RTT,
  1382  				DestConnectionID: connID,
  1383  				Version:          serv.config.Versions[0],
  1384  			}, make([]byte, 200))
  1385  			phm.EXPECT().Get(connID)
  1386  			tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention)
  1387  			serv.handlePacket(p)
  1388  
  1389  			initial := getPacket(&wire.Header{
  1390  				Type:             protocol.PacketTypeInitial,
  1391  				DestConnectionID: connID,
  1392  				Version:          serv.config.Versions[0],
  1393  			}, make([]byte, protocol.MinInitialPacketSize))
  1394  			called := make(chan struct{})
  1395  			serv.newConn = func(
  1396  				_ sendConn,
  1397  				_ connRunner,
  1398  				_ protocol.ConnectionID,
  1399  				_ *protocol.ConnectionID,
  1400  				_ protocol.ConnectionID,
  1401  				_ protocol.ConnectionID,
  1402  				_ protocol.ConnectionID,
  1403  				_ ConnectionIDGenerator,
  1404  				_ protocol.StatelessResetToken,
  1405  				_ *Config,
  1406  				_ *tls.Config,
  1407  				_ *handshake.TokenGenerator,
  1408  				_ bool,
  1409  				_ *logging.ConnectionTracer,
  1410  				_ uint64,
  1411  				_ utils.Logger,
  1412  				_ protocol.VersionNumber,
  1413  			) quicConn {
  1414  				conn := NewMockQUICConn(mockCtrl)
  1415  				var calls []any
  1416  				calls = append(calls, conn.EXPECT().handlePacket(initial))
  1417  				for _, p := range zeroRTTPackets {
  1418  					calls = append(calls, conn.EXPECT().handlePacket(p))
  1419  				}
  1420  				gomock.InOrder(calls...)
  1421  				conn.EXPECT().run()
  1422  				conn.EXPECT().earlyConnReady()
  1423  				conn.EXPECT().Context().Return(context.Background())
  1424  				close(called)
  1425  				return conn
  1426  			}
  1427  
  1428  			phm.EXPECT().Get(connID)
  1429  			phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
  1430  				phm.EXPECT().GetStatelessResetToken(gomock.Any())
  1431  				_, ok := fn()
  1432  				return ok
  1433  			})
  1434  			serv.handlePacket(initial)
  1435  			Eventually(called).Should(BeClosed())
  1436  		})
  1437  
  1438  		It("limits the number of queues", func() {
  1439  			for i := 0; i < protocol.Max0RTTQueues; i++ {
  1440  				b := make([]byte, 16)
  1441  				rand.Read(b)
  1442  				connID := protocol.ParseConnectionID(b)
  1443  				p := getPacket(&wire.Header{
  1444  					Type:             protocol.PacketType0RTT,
  1445  					DestConnectionID: connID,
  1446  					Version:          serv.config.Versions[0],
  1447  				}, make([]byte, 100+i))
  1448  				phm.EXPECT().Get(connID)
  1449  				serv.handlePacket(p)
  1450  			}
  1451  
  1452  			connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
  1453  			p := getPacket(&wire.Header{
  1454  				Type:             protocol.PacketType0RTT,
  1455  				DestConnectionID: connID,
  1456  				Version:          serv.config.Versions[0],
  1457  			}, make([]byte, 200))
  1458  			phm.EXPECT().Get(connID)
  1459  			dropped := make(chan struct{})
  1460  			tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) {
  1461  				close(dropped)
  1462  			})
  1463  			serv.handlePacket(p)
  1464  			Eventually(dropped).Should(BeClosed())
  1465  		})
  1466  
  1467  		It("drops queues after a while", func() {
  1468  			now := time.Now()
  1469  
  1470  			connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
  1471  			p := getPacket(&wire.Header{
  1472  				Type:             protocol.PacketType0RTT,
  1473  				DestConnectionID: connID,
  1474  				Version:          serv.config.Versions[0],
  1475  			}, make([]byte, 200))
  1476  			p.rcvTime = now
  1477  
  1478  			connID2 := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 9})
  1479  			p2Time := now.Add(protocol.Max0RTTQueueingDuration / 2)
  1480  			p2 := getPacket(&wire.Header{
  1481  				Type:             protocol.PacketType0RTT,
  1482  				DestConnectionID: connID2,
  1483  				Version:          serv.config.Versions[0],
  1484  			}, make([]byte, 300))
  1485  			p2.rcvTime = p2Time // doesn't trigger the cleanup of the first packet
  1486  
  1487  			dropped1 := make(chan struct{})
  1488  			dropped2 := make(chan struct{})
  1489  			// need to register the call before handling the packet to avoid race condition
  1490  			gomock.InOrder(
  1491  				tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) {
  1492  					close(dropped1)
  1493  				}),
  1494  				tracer.EXPECT().DroppedPacket(p2.remoteAddr, logging.PacketType0RTT, p2.Size(), logging.PacketDropDOSPrevention).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) {
  1495  					close(dropped2)
  1496  				}),
  1497  			)
  1498  
  1499  			phm.EXPECT().Get(connID)
  1500  			serv.handlePacket(p)
  1501  
  1502  			// There's no cleanup Go routine.
  1503  			// Cleanup is triggered when new packets are received.
  1504  
  1505  			phm.EXPECT().Get(connID2)
  1506  			serv.handlePacket(p2)
  1507  			// make sure no cleanup is executed
  1508  			Consistently(dropped1, 50*time.Millisecond).ShouldNot(BeClosed())
  1509  
  1510  			// There's no cleanup Go routine.
  1511  			// Cleanup is triggered when new packets are received.
  1512  			connID3 := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 0})
  1513  			p3 := getPacket(&wire.Header{
  1514  				Type:             protocol.PacketType0RTT,
  1515  				DestConnectionID: connID3,
  1516  				Version:          serv.config.Versions[0],
  1517  			}, make([]byte, 200))
  1518  			p3.rcvTime = now.Add(protocol.Max0RTTQueueingDuration + time.Nanosecond) // now triggers the cleanup
  1519  			phm.EXPECT().Get(connID3)
  1520  			serv.handlePacket(p3)
  1521  			Eventually(dropped1).Should(BeClosed())
  1522  			Consistently(dropped2, 50*time.Millisecond).ShouldNot(BeClosed())
  1523  
  1524  			// make sure the second packet is also cleaned up
  1525  			connID4 := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 1})
  1526  			p4 := getPacket(&wire.Header{
  1527  				Type:             protocol.PacketType0RTT,
  1528  				DestConnectionID: connID4,
  1529  				Version:          serv.config.Versions[0],
  1530  			}, make([]byte, 200))
  1531  			p4.rcvTime = p2Time.Add(protocol.Max0RTTQueueingDuration + time.Nanosecond) // now triggers the cleanup
  1532  			phm.EXPECT().Get(connID4)
  1533  			serv.handlePacket(p4)
  1534  			Eventually(dropped2).Should(BeClosed())
  1535  		})
  1536  	})
  1537  })