github.com/tumi8/quic-go@v0.37.4-tum/server_test.go (about)

     1  package quic
     2  
     3  import (
     4  	"context"
     5  	"crypto/rand"
     6  	"crypto/tls"
     7  	"errors"
     8  	"net"
     9  	"reflect"
    10  	"sync"
    11  	"sync/atomic"
    12  	"time"
    13  
    14  	"github.com/tumi8/quic-go/noninternal/handshake"
    15  	mocklogging "github.com/tumi8/quic-go/noninternal/mocks/logging"
    16  	"github.com/tumi8/quic-go/noninternal/protocol"
    17  	"github.com/tumi8/quic-go/noninternal/qerr"
    18  	"github.com/tumi8/quic-go/noninternal/testdata"
    19  	"github.com/tumi8/quic-go/noninternal/utils"
    20  	"github.com/tumi8/quic-go/noninternal/wire"
    21  	"github.com/tumi8/quic-go/logging"
    22  
    23  	"github.com/golang/mock/gomock"
    24  	. "github.com/onsi/ginkgo/v2"
    25  	. "github.com/onsi/gomega"
    26  )
    27  
    28  var _ = Describe("Server", func() {
	var (
		// conn is the mocked packet conn handed to Listen / Transport.
		// It is replaced with a fresh mock in the BeforeEach for every spec.
		conn    *MockPacketConn
		// tlsConf is the server TLS configuration, rebuilt for every spec.
		tlsConf *tls.Config
	)
    33  
    34  	getPacket := func(hdr *wire.Header, p []byte) receivedPacket {
    35  		buf := getPacketBuffer()
    36  		hdr.Length = 4 + protocol.ByteCount(len(p)) + 16
    37  		var err error
    38  		buf.Data, err = (&wire.ExtendedHeader{
    39  			Header:          *hdr,
    40  			PacketNumber:    0x42,
    41  			PacketNumberLen: protocol.PacketNumberLen4,
    42  		}).Append(buf.Data, protocol.Version1)
    43  		Expect(err).ToNot(HaveOccurred())
    44  		n := len(buf.Data)
    45  		buf.Data = append(buf.Data, p...)
    46  		data := buf.Data
    47  		sealer, _ := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveClient, hdr.Version)
    48  		_ = sealer.Seal(data[n:n], data[n:], 0x42, data[:n])
    49  		data = data[:len(data)+16]
    50  		sealer.EncryptHeader(data[n:n+16], &data[0], data[n-4:n])
    51  		return receivedPacket{
    52  			remoteAddr: &net.UDPAddr{IP: net.IPv4(4, 5, 6, 7), Port: 456},
    53  			data:       data,
    54  			buffer:     buf,
    55  		}
    56  	}
    57  
    58  	getInitial := func(destConnID protocol.ConnectionID) receivedPacket {
    59  		senderAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 42}
    60  		hdr := &wire.Header{
    61  			Type:             protocol.PacketTypeInitial,
    62  			SrcConnectionID:  protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
    63  			DestConnectionID: destConnID,
    64  			Version:          protocol.Version1,
    65  		}
    66  		p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
    67  		p.buffer = getPacketBuffer()
    68  		p.remoteAddr = senderAddr
    69  		return p
    70  	}
    71  
    72  	getInitialWithRandomDestConnID := func() receivedPacket {
    73  		b := make([]byte, 10)
    74  		_, err := rand.Read(b)
    75  		Expect(err).ToNot(HaveOccurred())
    76  
    77  		return getInitial(protocol.ParseConnectionID(b))
    78  	}
    79  
    80  	parseHeader := func(data []byte) *wire.Header {
    81  		hdr, _, _, err := wire.ParsePacket(data)
    82  		Expect(err).ToNot(HaveOccurred())
    83  		return hdr
    84  	}
    85  
	// Before every spec: create a fresh mock packet conn and a TLS config that
	// advertises the single ALPN protocol "proto1".
	BeforeEach(func() {
		conn = NewMockPacketConn(mockCtrl)
		conn.EXPECT().LocalAddr().Return(&net.UDPAddr{}).AnyTimes()
		// At most one ReadFrom is allowed; it blocks until wait is closed and
		// then returns an error.
		wait := make(chan struct{})
		conn.EXPECT().ReadFrom(gomock.Any()).DoAndReturn(func(_ []byte) (int, net.Addr, error) {
			<-wait
			return 0, nil, errors.New("done")
		}).MaxTimes(1)
		// The first SetReadDeadline call unblocks that ReadFrom, then expects one
		// more call resetting the deadline to the zero time.
		// NOTE(review): presumably this is how closing the listener/transport
		// interrupts its read loop — confirm against Transport.Close.
		conn.EXPECT().SetReadDeadline(gomock.Any()).Do(func(time.Time) {
			close(wait)
			conn.EXPECT().SetReadDeadline(time.Time{})
		}).MaxTimes(1)
		tlsConf = testdata.GetTLSConfig()
		tlsConf.NextProtos = []string{"proto1"}
	})
   101  
   102  	It("errors when no tls.Config is given", func() {
   103  		_, err := ListenAddr("localhost:0", nil, nil)
   104  		Expect(err).To(HaveOccurred())
   105  		Expect(err.Error()).To(ContainSubstring("quic: tls.Config not set"))
   106  	})
   107  
   108  	It("errors when the Config contains an invalid version", func() {
   109  		version := protocol.VersionNumber(0x1234)
   110  		_, err := Listen(nil, tlsConf, &Config{Versions: []protocol.VersionNumber{version}})
   111  		Expect(err).To(MatchError("invalid QUIC version: 0x1234"))
   112  	})
   113  
   114  	It("fills in default values if options are not set in the Config", func() {
   115  		ln, err := Listen(conn, tlsConf, &Config{})
   116  		Expect(err).ToNot(HaveOccurred())
   117  		server := ln.baseServer
   118  		Expect(server.config.Versions).To(Equal(protocol.SupportedVersions))
   119  		Expect(server.config.HandshakeIdleTimeout).To(Equal(protocol.DefaultHandshakeIdleTimeout))
   120  		Expect(server.config.MaxIdleTimeout).To(Equal(protocol.DefaultIdleTimeout))
   121  		Expect(server.config.RequireAddressValidation).ToNot(BeNil())
   122  		Expect(server.config.KeepAlivePeriod).To(BeZero())
   123  		// stop the listener
   124  		Expect(ln.Close()).To(Succeed())
   125  	})
   126  
   127  	It("setups with the right values", func() {
   128  		supportedVersions := []protocol.VersionNumber{protocol.Version1}
   129  		requireAddrVal := func(net.Addr) bool { return true }
   130  		config := Config{
   131  			Versions:                 supportedVersions,
   132  			HandshakeIdleTimeout:     1337 * time.Hour,
   133  			MaxIdleTimeout:           42 * time.Minute,
   134  			KeepAlivePeriod:          5 * time.Second,
   135  			RequireAddressValidation: requireAddrVal,
   136  		}
   137  		ln, err := Listen(conn, tlsConf, &config)
   138  		Expect(err).ToNot(HaveOccurred())
   139  		server := ln.baseServer
   140  		Expect(server.connHandler).ToNot(BeNil())
   141  		Expect(server.config.Versions).To(Equal(supportedVersions))
   142  		Expect(server.config.HandshakeIdleTimeout).To(Equal(1337 * time.Hour))
   143  		Expect(server.config.MaxIdleTimeout).To(Equal(42 * time.Minute))
   144  		Expect(reflect.ValueOf(server.config.RequireAddressValidation)).To(Equal(reflect.ValueOf(requireAddrVal)))
   145  		Expect(server.config.KeepAlivePeriod).To(Equal(5 * time.Second))
   146  		// stop the listener
   147  		Expect(ln.Close()).To(Succeed())
   148  	})
   149  
   150  	It("listens on a given address", func() {
   151  		addr := "127.0.0.1:13579"
   152  		ln, err := ListenAddr(addr, tlsConf, &Config{})
   153  		Expect(err).ToNot(HaveOccurred())
   154  		Expect(ln.Addr().String()).To(Equal(addr))
   155  		// stop the listener
   156  		Expect(ln.Close()).To(Succeed())
   157  	})
   158  
   159  	It("errors if given an invalid address", func() {
   160  		addr := "127.0.0.1"
   161  		_, err := ListenAddr(addr, tlsConf, &Config{})
   162  		Expect(err).To(BeAssignableToTypeOf(&net.AddrError{}))
   163  	})
   164  
   165  	It("errors if given an invalid address", func() {
   166  		addr := "1.1.1.1:1111"
   167  		_, err := ListenAddr(addr, tlsConf, &Config{})
   168  		Expect(err).To(BeAssignableToTypeOf(&net.OpError{}))
   169  	})
   170  
   171  	Context("server accepting connections that completed the handshake", func() {
		var (
			// tr wraps the mock conn and has the mock tracer installed.
			tr     *Transport
			// serv is the baseServer of the listener created on tr.
			serv   *baseServer
			// phm is the mock installed as serv.connHandler.
			phm    *MockPacketHandlerManager
			// tracer is the mock tracer passed to tr.
			tracer *mocklogging.MockTracer
		)
   178  
   179  		BeforeEach(func() {
   180  			tracer = mocklogging.NewMockTracer(mockCtrl)
   181  			tr = &Transport{Conn: conn, Tracer: tracer}
   182  			ln, err := tr.Listen(tlsConf, nil)
   183  			Expect(err).ToNot(HaveOccurred())
   184  			serv = ln.baseServer
   185  			phm = NewMockPacketHandlerManager(mockCtrl)
   186  			serv.connHandler = phm
   187  		})
   188  
		AfterEach(func() {
			// Close the transport created in BeforeEach; the returned error is ignored.
			tr.Close()
		})
   192  
   193  		Context("handling packets", func() {
   194  			It("drops Initial packets with a too short connection ID", func() {
   195  				p := getPacket(&wire.Header{
   196  					Type:             protocol.PacketTypeInitial,
   197  					DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}),
   198  					Version:          serv.config.Versions[0],
   199  				}, nil)
   200  				tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket)
   201  				serv.handlePacket(p)
   202  				// make sure there are no Write calls on the packet conn
   203  				time.Sleep(50 * time.Millisecond)
   204  			})
   205  
   206  			It("drops too small Initial", func() {
   207  				p := getPacket(&wire.Header{
   208  					Type:             protocol.PacketTypeInitial,
   209  					DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}),
   210  					Version:          serv.config.Versions[0],
   211  				}, make([]byte, protocol.MinInitialPacketSize-100))
   212  				tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket)
   213  				serv.handlePacket(p)
   214  				// make sure there are no Write calls on the packet conn
   215  				time.Sleep(50 * time.Millisecond)
   216  			})
   217  
   218  			It("drops non-Initial packets", func() {
   219  				p := getPacket(&wire.Header{
   220  					Type:    protocol.PacketTypeHandshake,
   221  					Version: serv.config.Versions[0],
   222  				}, []byte("invalid"))
   223  				tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketTypeHandshake, p.Size(), logging.PacketDropUnexpectedPacket)
   224  				serv.handlePacket(p)
   225  				// make sure there are no Write calls on the packet conn
   226  				time.Sleep(50 * time.Millisecond)
   227  			})
   228  
   229  			It("passes packets to existing connections", func() {
   230  				connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
   231  				p := getPacket(&wire.Header{
   232  					Type:             protocol.PacketTypeInitial,
   233  					DestConnectionID: connID,
   234  					Version:          serv.config.Versions[0],
   235  				}, make([]byte, protocol.MinInitialPacketSize))
   236  				conn := NewMockPacketHandler(mockCtrl)
   237  				phm.EXPECT().Get(connID).Return(conn, true)
   238  				handled := make(chan struct{})
   239  				conn.EXPECT().handlePacket(p).Do(func(receivedPacket) { close(handled) })
   240  				serv.handlePacket(p)
   241  				Eventually(handled).Should(BeClosed())
   242  			})
   243  
   244  			It("creates a connection when the token is accepted", func() {
   245  				serv.config.RequireAddressValidation = func(net.Addr) bool { return true }
   246  				raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
   247  				retryToken, err := serv.tokenGenerator.NewRetryToken(
   248  					raddr,
   249  					protocol.ParseConnectionID([]byte{0xde, 0xad, 0xc0, 0xde}),
   250  					protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}),
   251  				)
   252  				Expect(err).ToNot(HaveOccurred())
   253  				connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})
   254  				hdr := &wire.Header{
   255  					Type:             protocol.PacketTypeInitial,
   256  					SrcConnectionID:  protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
   257  					DestConnectionID: connID,
   258  					Version:          protocol.Version1,
   259  					Token:            retryToken,
   260  				}
   261  				p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
   262  				p.remoteAddr = raddr
   263  				run := make(chan struct{})
   264  				var token protocol.StatelessResetToken
   265  				rand.Read(token[:])
   266  
   267  				var newConnID protocol.ConnectionID
   268  
   269  				phm.EXPECT().Get(connID)
   270  				phm.EXPECT().AddWithConnID(connID, gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
   271  					newConnID = c
   272  					phm.EXPECT().GetStatelessResetToken(gomock.Any()).DoAndReturn(func(c protocol.ConnectionID) protocol.StatelessResetToken {
   273  						newConnID = c
   274  						return token
   275  					})
   276  					_, ok := fn()
   277  					return ok
   278  				})
   279  				conn := NewMockQUICConn(mockCtrl)
   280  				serv.newConn = func(
   281  					_ sendConn,
   282  					_ connRunner,
   283  					origDestConnID protocol.ConnectionID,
   284  					retrySrcConnID *protocol.ConnectionID,
   285  					clientDestConnID protocol.ConnectionID,
   286  					destConnID protocol.ConnectionID,
   287  					srcConnID protocol.ConnectionID,
   288  					_ ConnectionIDGenerator,
   289  					tokenP protocol.StatelessResetToken,
   290  					_ *Config,
   291  					_ *tls.Config,
   292  					_ *handshake.TokenGenerator,
   293  					_ bool,
   294  					_ logging.ConnectionTracer,
   295  					_ uint64,
   296  					_ utils.Logger,
   297  					_ protocol.VersionNumber,
   298  				) quicConn {
   299  					Expect(origDestConnID).To(Equal(protocol.ParseConnectionID([]byte{0xde, 0xad, 0xc0, 0xde})))
   300  					Expect(*retrySrcConnID).To(Equal(protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad})))
   301  					Expect(clientDestConnID).To(Equal(hdr.DestConnectionID))
   302  					Expect(destConnID).To(Equal(hdr.SrcConnectionID))
   303  					// make sure we're using a server-generated connection ID
   304  					Expect(srcConnID).ToNot(Equal(hdr.DestConnectionID))
   305  					Expect(srcConnID).ToNot(Equal(hdr.SrcConnectionID))
   306  					Expect(srcConnID).To(Equal(newConnID))
   307  					Expect(tokenP).To(Equal(token))
   308  					conn.EXPECT().handlePacket(p)
   309  					conn.EXPECT().run().Do(func() { close(run) })
   310  					conn.EXPECT().Context().Return(context.Background())
   311  					conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
   312  					return conn
   313  				}
   314  
   315  				done := make(chan struct{})
   316  				go func() {
   317  					defer GinkgoRecover()
   318  					serv.handlePacket(p)
   319  					// the Handshake packet is written by the connection.
   320  					// Make sure there are no Write calls on the packet conn.
   321  					time.Sleep(50 * time.Millisecond)
   322  					close(done)
   323  				}()
   324  				// make sure we're using a server-generated connection ID
   325  				Eventually(run).Should(BeClosed())
   326  				Eventually(done).Should(BeClosed())
   327  			})
   328  
			// A long-header packet with version 0x42 must be answered with a Version
			// Negotiation packet that mirrors the connection IDs (VN source = client's
			// destination, VN destination = client's source) and doesn't list 0x42.
			It("sends a Version Negotiation Packet for unsupported versions", func() {
				srcConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5})
				destConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6})
				packet := getPacket(&wire.Header{
					Type:             protocol.PacketTypeHandshake,
					SrcConnectionID:  srcConnID,
					DestConnectionID: destConnID,
					Version:          0x42,
				}, make([]byte, protocol.MinUnknownVersionPacketSize))
				raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
				packet.remoteAddr = raddr
				// The tracer must be told about the VN packet with swapped connection IDs.
				tracer.EXPECT().SentVersionNegotiationPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, src, dest protocol.ArbitraryLenConnectionID, _ []protocol.VersionNumber) {
					Expect(src).To(Equal(protocol.ArbitraryLenConnectionID(destConnID.Bytes())))
					Expect(dest).To(Equal(protocol.ArbitraryLenConnectionID(srcConnID.Bytes())))
				})
				done := make(chan struct{})
				// Inspect the bytes actually written to the client.
				conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
					defer close(done)
					Expect(wire.IsVersionNegotiationPacket(b)).To(BeTrue())
					dest, src, versions, err := wire.ParseVersionNegotiationPacket(b)
					Expect(err).ToNot(HaveOccurred())
					Expect(dest).To(Equal(protocol.ArbitraryLenConnectionID(srcConnID.Bytes())))
					Expect(src).To(Equal(protocol.ArbitraryLenConnectionID(destConnID.Bytes())))
					Expect(versions).ToNot(ContainElement(protocol.VersionNumber(0x42)))
					return len(b), nil
				})
				serv.handlePacket(packet)
				Eventually(done).Should(BeClosed())
			})
   358  
   359  			It("doesn't send a Version Negotiation packets if sending them is disabled", func() {
   360  				serv.config.DisableVersionNegotiationPackets = true
   361  				srcConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5})
   362  				destConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6})
   363  				packet := getPacket(&wire.Header{
   364  					Type:             protocol.PacketTypeHandshake,
   365  					SrcConnectionID:  srcConnID,
   366  					DestConnectionID: destConnID,
   367  					Version:          0x42,
   368  				}, make([]byte, protocol.MinUnknownVersionPacketSize))
   369  				raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
   370  				packet.remoteAddr = raddr
   371  				done := make(chan struct{})
   372  				conn.EXPECT().WriteTo(gomock.Any(), raddr).Do(func() { close(done) }).Times(0)
   373  				serv.handlePacket(packet)
   374  				Consistently(done, 50*time.Millisecond).ShouldNot(BeClosed())
   375  			})
   376  
			// A Version Negotiation packet received by the server must be dropped
			// (reported to the tracer as an unexpected packet) and not answered.
			It("ignores Version Negotiation packets", func() {
				data := wire.ComposeVersionNegotiation(
					protocol.ArbitraryLenConnectionID{1, 2, 3, 4},
					protocol.ArbitraryLenConnectionID{4, 3, 2, 1},
					[]protocol.VersionNumber{1, 2, 3},
				)
				raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
				done := make(chan struct{})
				tracer.EXPECT().DroppedPacket(raddr, logging.PacketTypeVersionNegotiation, protocol.ByteCount(len(data)), logging.PacketDropUnexpectedPacket).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) {
					close(done)
				})
				// The raw VN bytes are delivered directly; getPacket isn't used here.
				serv.handlePacket(receivedPacket{
					remoteAddr: raddr,
					data:       data,
					buffer:     getPacketBuffer(),
				})
				Eventually(done).Should(BeClosed())
				// make sure no other packet is sent
				time.Sleep(scaleDuration(20 * time.Millisecond))
			})
   397  
			// An unsupported-version packet smaller than MinUnknownVersionPacketSize
			// is dropped without sending a Version Negotiation packet.
			It("doesn't send a Version Negotiation Packet for unsupported versions, if the packet is too small", func() {
				srcConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5})
				destConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6})
				p := getPacket(&wire.Header{
					Type:             protocol.PacketTypeHandshake,
					SrcConnectionID:  srcConnID,
					DestConnectionID: destConnID,
					Version:          0x42,
				}, make([]byte, protocol.MinUnknownVersionPacketSize-50))
				// Guard the premise: header and tag overhead must not push the packet
				// back over the threshold.
				Expect(p.Size()).To(BeNumerically("<", protocol.MinUnknownVersionPacketSize))
				raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
				p.remoteAddr = raddr
				done := make(chan struct{})
				tracer.EXPECT().DroppedPacket(raddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) {
					close(done)
				})
				serv.handlePacket(p)
				Eventually(done).Should(BeClosed())
				// make sure no other packet is sent
				time.Sleep(scaleDuration(20 * time.Millisecond))
			})
   419  
			// With address validation required and no token present, the server
			// answers an Initial with a Retry packet instead of creating a connection.
			It("replies with a Retry packet, if a token is required", func() {
				connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})
				serv.config.RequireAddressValidation = func(net.Addr) bool { return true }
				hdr := &wire.Header{
					Type:             protocol.PacketTypeInitial,
					SrcConnectionID:  protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
					DestConnectionID: connID,
					Version:          protocol.Version1,
				}
				packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
				raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
				packet.remoteAddr = raddr
				// The traced Retry header: fresh source connection ID, addressed to the
				// client's source connection ID, carrying a non-empty token.
				tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), nil).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, _ []logging.Frame) {
					Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry))
					Expect(replyHdr.SrcConnectionID).ToNot(Equal(hdr.DestConnectionID))
					Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))
					Expect(replyHdr.Token).ToNot(BeEmpty())
				})
				done := make(chan struct{})
				conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
					defer close(done)
					replyHdr := parseHeader(b)
					Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry))
					Expect(replyHdr.SrcConnectionID).ToNot(Equal(hdr.DestConnectionID))
					Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))
					Expect(replyHdr.Token).ToNot(BeEmpty())
					// The last 16 bytes must be the Retry Integrity Tag computed over the
					// rest of the packet and the client's original destination connection ID.
					Expect(b[len(b)-16:]).To(Equal(handshake.GetRetryIntegrityTag(b[:len(b)-16], hdr.DestConnectionID, hdr.Version)[:]))
					return len(b), nil
				})
				phm.EXPECT().Get(connID)
				serv.handlePacket(packet)
				Eventually(done).Should(BeClosed())
			})
   453  
   454  			It("creates a connection, if no token is required", func() {
   455  				connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})
   456  				hdr := &wire.Header{
   457  					Type:             protocol.PacketTypeInitial,
   458  					SrcConnectionID:  protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
   459  					DestConnectionID: connID,
   460  					Version:          protocol.Version1,
   461  				}
   462  				p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
   463  				run := make(chan struct{})
   464  				var token protocol.StatelessResetToken
   465  				rand.Read(token[:])
   466  
   467  				var newConnID protocol.ConnectionID
   468  				gomock.InOrder(
   469  					phm.EXPECT().Get(connID),
   470  					phm.EXPECT().AddWithConnID(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
   471  						newConnID = c
   472  						phm.EXPECT().GetStatelessResetToken(gomock.Any()).DoAndReturn(func(c protocol.ConnectionID) protocol.StatelessResetToken {
   473  							newConnID = c
   474  							return token
   475  						})
   476  						_, ok := fn()
   477  						return ok
   478  					}),
   479  				)
   480  
   481  				conn := NewMockQUICConn(mockCtrl)
   482  				serv.newConn = func(
   483  					_ sendConn,
   484  					_ connRunner,
   485  					origDestConnID protocol.ConnectionID,
   486  					retrySrcConnID *protocol.ConnectionID,
   487  					clientDestConnID protocol.ConnectionID,
   488  					destConnID protocol.ConnectionID,
   489  					srcConnID protocol.ConnectionID,
   490  					_ ConnectionIDGenerator,
   491  					tokenP protocol.StatelessResetToken,
   492  					_ *Config,
   493  					_ *tls.Config,
   494  					_ *handshake.TokenGenerator,
   495  					_ bool,
   496  					_ logging.ConnectionTracer,
   497  					_ uint64,
   498  					_ utils.Logger,
   499  					_ protocol.VersionNumber,
   500  				) quicConn {
   501  					Expect(origDestConnID).To(Equal(hdr.DestConnectionID))
   502  					Expect(retrySrcConnID).To(BeNil())
   503  					Expect(clientDestConnID).To(Equal(hdr.DestConnectionID))
   504  					Expect(destConnID).To(Equal(hdr.SrcConnectionID))
   505  					// make sure we're using a server-generated connection ID
   506  					Expect(srcConnID).ToNot(Equal(hdr.DestConnectionID))
   507  					Expect(srcConnID).ToNot(Equal(hdr.SrcConnectionID))
   508  					Expect(srcConnID).To(Equal(newConnID))
   509  					Expect(tokenP).To(Equal(token))
   510  					conn.EXPECT().handlePacket(p)
   511  					conn.EXPECT().run().Do(func() { close(run) })
   512  					conn.EXPECT().Context().Return(context.Background())
   513  					conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
   514  					return conn
   515  				}
   516  
   517  				done := make(chan struct{})
   518  				go func() {
   519  					defer GinkgoRecover()
   520  					serv.handlePacket(p)
   521  					// the Handshake packet is written by the connection
   522  					// make sure there are no Write calls on the packet conn
   523  					time.Sleep(50 * time.Millisecond)
   524  					close(done)
   525  				}()
   526  				// make sure we're using a server-generated connection ID
   527  				Eventually(run).Should(BeClosed())
   528  				Eventually(done).Should(BeClosed())
   529  			})
   530  
			// While the server is busy with one packet, further Initials pile up in the
			// receive queue; those beyond its capacity must be dropped for DoS prevention.
			It("drops packets if the receive queue is full", func() {
				phm.EXPECT().Get(gomock.Any()).AnyTimes()
				phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
					phm.EXPECT().GetStatelessResetToken(gomock.Any())
					_, ok := fn()
					return ok
				}).AnyTimes()

				// acceptConn stays open until all packets are enqueued, so newConn blocks
				// the server while the queue fills up.
				acceptConn := make(chan struct{})
				var counter uint32 // to be used as an atomic, so we query it in Eventually
				serv.newConn = func(
					_ sendConn,
					runner connRunner,
					_ protocol.ConnectionID,
					_ *protocol.ConnectionID,
					_ protocol.ConnectionID,
					_ protocol.ConnectionID,
					_ protocol.ConnectionID,
					_ ConnectionIDGenerator,
					_ protocol.StatelessResetToken,
					_ *Config,
					_ *tls.Config,
					_ *handshake.TokenGenerator,
					_ bool,
					_ logging.ConnectionTracer,
					_ uint64,
					_ utils.Logger,
					_ protocol.VersionNumber,
				) quicConn {
					<-acceptConn
					atomic.AddUint32(&counter, 1)
					conn := NewMockQUICConn(mockCtrl)
					conn.EXPECT().handlePacket(gomock.Any()).MaxTimes(1)
					conn.EXPECT().run().MaxTimes(1)
					conn.EXPECT().Context().Return(context.Background()).MaxTimes(1)
					conn.EXPECT().HandshakeComplete().Return(make(chan struct{})).MaxTimes(1)
					return conn
				}

				// The first packet occupies the server (blocked in newConn); the drop
				// expectation is only set up afterwards.
				p := getInitial(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}))
				serv.handlePacket(p)
				tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropDOSPrevention).MinTimes(1)
				// Flood with three times the queue capacity from concurrent senders.
				var wg sync.WaitGroup
				for i := 0; i < 3*protocol.MaxServerUnprocessedPackets; i++ {
					wg.Add(1)
					go func() {
						defer GinkgoRecover()
						defer wg.Done()
						serv.handlePacket(getInitial(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})))
					}()
				}
				wg.Wait()

				close(acceptConn)
				// Presumably: the blocked first packet plus a full queue of
				// MaxServerUnprocessedPackets, and nothing more afterwards.
				Eventually(
					func() uint32 { return atomic.LoadUint32(&counter) },
					scaleDuration(100*time.Millisecond),
				).Should(BeEquivalentTo(protocol.MaxServerUnprocessedPackets + 1))
				Consistently(func() uint32 { return atomic.LoadUint32(&counter) }).Should(BeEquivalentTo(protocol.MaxServerUnprocessedPackets + 1))
			})
   591  
			// If AddWithConnID reports that the connection ID could not be added
			// (returns false without calling the constructor), no connection may be
			// created; the server still writes exactly one packet (contents not checked).
			It("only creates a single connection for a duplicate Initial", func() {
				var createdConn bool
				serv.newConn = func(
					_ sendConn,
					runner connRunner,
					_ protocol.ConnectionID,
					_ *protocol.ConnectionID,
					_ protocol.ConnectionID,
					_ protocol.ConnectionID,
					_ protocol.ConnectionID,
					_ ConnectionIDGenerator,
					_ protocol.StatelessResetToken,
					_ *Config,
					_ *tls.Config,
					_ *handshake.TokenGenerator,
					_ bool,
					_ logging.ConnectionTracer,
					_ uint64,
					_ utils.Logger,
					_ protocol.VersionNumber,
				) quicConn {
					createdConn = true
					return NewMockQUICConn(mockCtrl)
				}

				connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9})
				p := getInitial(connID)
				phm.EXPECT().Get(connID)
				phm.EXPECT().AddWithConnID(connID, gomock.Any(), gomock.Any()).Return(false)
				tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
				done := make(chan struct{})
				conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).Do(func([]byte, net.Addr) { close(done) })
				// handlePacketImpl is called synchronously, so createdConn can be read right after.
				Expect(serv.handlePacketImpl(p)).To(BeTrue())
				Expect(createdConn).To(BeFalse())
				Eventually(done).Should(BeClosed())
			})
   628  
   629  			It("rejects new connection attempts if the accept queue is full", func() {
   630  				serv.newConn = func(
   631  					_ sendConn,
   632  					runner connRunner,
   633  					_ protocol.ConnectionID,
   634  					_ *protocol.ConnectionID,
   635  					_ protocol.ConnectionID,
   636  					_ protocol.ConnectionID,
   637  					_ protocol.ConnectionID,
   638  					_ ConnectionIDGenerator,
   639  					_ protocol.StatelessResetToken,
   640  					_ *Config,
   641  					_ *tls.Config,
   642  					_ *handshake.TokenGenerator,
   643  					_ bool,
   644  					_ logging.ConnectionTracer,
   645  					_ uint64,
   646  					_ utils.Logger,
   647  					_ protocol.VersionNumber,
   648  				) quicConn {
   649  					conn := NewMockQUICConn(mockCtrl)
   650  					conn.EXPECT().handlePacket(gomock.Any())
   651  					conn.EXPECT().run()
   652  					conn.EXPECT().Context().Return(context.Background())
   653  					c := make(chan struct{})
   654  					close(c)
   655  					conn.EXPECT().HandshakeComplete().Return(c)
   656  					return conn
   657  				}
   658  
   659  				phm.EXPECT().Get(gomock.Any()).Times(protocol.MaxAcceptQueueSize + 1)
   660  				phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
   661  					phm.EXPECT().GetStatelessResetToken(gomock.Any())
   662  					_, ok := fn()
   663  					return ok
   664  				}).Times(protocol.MaxAcceptQueueSize)
   665  
   666  				var wg sync.WaitGroup
   667  				wg.Add(protocol.MaxAcceptQueueSize)
   668  				for i := 0; i < protocol.MaxAcceptQueueSize; i++ {
   669  					go func() {
   670  						defer GinkgoRecover()
   671  						defer wg.Done()
   672  						serv.handlePacket(getInitialWithRandomDestConnID())
   673  						// make sure there are no Write calls on the packet conn
   674  						time.Sleep(50 * time.Millisecond)
   675  					}()
   676  				}
   677  				wg.Wait()
   678  				p := getInitialWithRandomDestConnID()
   679  				hdr, _, _, err := wire.ParsePacket(p.data)
   680  				Expect(err).ToNot(HaveOccurred())
   681  				tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
   682  				done := make(chan struct{})
   683  				conn.EXPECT().WriteTo(gomock.Any(), p.remoteAddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
   684  					defer close(done)
   685  					rejectHdr := parseHeader(b)
   686  					Expect(rejectHdr.Type).To(Equal(protocol.PacketTypeInitial))
   687  					Expect(rejectHdr.Version).To(Equal(hdr.Version))
   688  					Expect(rejectHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))
   689  					Expect(rejectHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID))
   690  					return len(b), nil
   691  				})
   692  				serv.handlePacket(p)
   693  				Eventually(done).Should(BeClosed())
   694  			})
   695  
			// A connection whose handshake completes but whose context is canceled before
			// anyone calls Accept must never be handed out by Accept.
			It("doesn't accept new connections if they were closed in the mean time", func() {
				p := getInitial(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}))
				// ctx is returned as the connection's context; canceling it simulates the
				// connection being closed after it was created.
				ctx, cancel := context.WithCancel(context.Background())
				connCreated := make(chan struct{})
				conn := NewMockQUICConn(mockCtrl)
				serv.newConn = func(
					_ sendConn,
					runner connRunner,
					_ protocol.ConnectionID,
					_ *protocol.ConnectionID,
					_ protocol.ConnectionID,
					_ protocol.ConnectionID,
					_ protocol.ConnectionID,
					_ ConnectionIDGenerator,
					_ protocol.StatelessResetToken,
					_ *Config,
					_ *tls.Config,
					_ *handshake.TokenGenerator,
					_ bool,
					_ logging.ConnectionTracer,
					_ uint64,
					_ utils.Logger,
					_ protocol.VersionNumber,
				) quicConn {
					conn.EXPECT().handlePacket(p)
					conn.EXPECT().run()
					conn.EXPECT().Context().Return(ctx)
					// an already-closed channel: the handshake is reported as complete immediately
					c := make(chan struct{})
					close(c)
					conn.EXPECT().HandshakeComplete().Return(c)
					close(connCreated)
					return conn
				}

				phm.EXPECT().Get(gomock.Any())
				// run the constructor callback so that serv.newConn above is actually invoked
				phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
					phm.EXPECT().GetStatelessResetToken(gomock.Any())
					_, ok := fn()
					return ok
				})

				serv.handlePacket(p)
				// make sure there are no Write calls on the packet conn
				time.Sleep(50 * time.Millisecond)
				Eventually(connCreated).Should(BeClosed())
				// close the connection and give the server time to notice
				cancel()
				time.Sleep(scaleDuration(200 * time.Millisecond))

				// Accept must keep blocking: the canceled connection must not be returned
				done := make(chan struct{})
				go func() {
					defer GinkgoRecover()
					serv.Accept(context.Background())
					close(done)
				}()
				Consistently(done).ShouldNot(BeClosed())

				// make the go routine return
				conn.EXPECT().getPerspective().MaxTimes(2) // initOnce for every conn ID
				Expect(serv.Close()).To(Succeed())
				Eventually(done).Should(BeClosed())
			})
   757  		})
   758  
   759  		Context("token validation", func() {
   760  			checkInvalidToken := func(b []byte, origHdr *wire.Header) {
   761  				replyHdr := parseHeader(b)
   762  				Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial))
   763  				Expect(replyHdr.SrcConnectionID).To(Equal(origHdr.DestConnectionID))
   764  				Expect(replyHdr.DestConnectionID).To(Equal(origHdr.SrcConnectionID))
   765  				_, opener := handshake.NewInitialAEAD(origHdr.DestConnectionID, protocol.PerspectiveClient, replyHdr.Version)
   766  				extHdr, err := unpackLongHeader(opener, replyHdr, b, origHdr.Version)
   767  				Expect(err).ToNot(HaveOccurred())
   768  				data, err := opener.Open(nil, b[extHdr.ParsedLen():], extHdr.PacketNumber, b[:extHdr.ParsedLen()])
   769  				Expect(err).ToNot(HaveOccurred())
   770  				_, f, err := wire.NewFrameParser(false).ParseNext(data, protocol.EncryptionInitial, origHdr.Version)
   771  				Expect(err).ToNot(HaveOccurred())
   772  				Expect(f).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{}))
   773  				ccf := f.(*wire.ConnectionCloseFrame)
   774  				Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.InvalidToken))
   775  				Expect(ccf.ReasonPhrase).To(BeEmpty())
   776  			}
   777  
   778  			It("decodes the token from the token field", func() {
   779  				raddr := &net.UDPAddr{IP: net.IPv4(192, 168, 13, 37), Port: 1337}
   780  				token, err := serv.tokenGenerator.NewRetryToken(raddr, protocol.ConnectionID{}, protocol.ConnectionID{})
   781  				Expect(err).ToNot(HaveOccurred())
   782  				packet := getPacket(&wire.Header{
   783  					Type:    protocol.PacketTypeInitial,
   784  					Token:   token,
   785  					Version: serv.config.Versions[0],
   786  				}, make([]byte, protocol.MinInitialPacketSize))
   787  				packet.remoteAddr = raddr
   788  				conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).MaxTimes(1)
   789  				tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1)
   790  
   791  				done := make(chan struct{})
   792  				phm.EXPECT().Get(gomock.Any())
   793  				phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_, _ protocol.ConnectionID, _ func() (packetHandler, bool)) { close(done) })
   794  				serv.handlePacket(packet)
   795  				Eventually(done).Should(BeClosed())
   796  			})
   797  
   798  			It("sends an INVALID_TOKEN error, if an invalid retry token is received", func() {
   799  				serv.config.RequireAddressValidation = func(net.Addr) bool { return true }
   800  				token, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, protocol.ConnectionID{}, protocol.ConnectionID{})
   801  				Expect(err).ToNot(HaveOccurred())
   802  				hdr := &wire.Header{
   803  					Type:             protocol.PacketTypeInitial,
   804  					SrcConnectionID:  protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
   805  					DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}),
   806  					Token:            token,
   807  					Version:          protocol.Version1,
   808  				}
   809  				packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
   810  				packet.data = append(packet.data, []byte("coalesced packet")...) // add some garbage to simulate a coalesced packet
   811  				raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
   812  				packet.remoteAddr = raddr
   813  				tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) {
   814  					Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial))
   815  					Expect(replyHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID))
   816  					Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))
   817  					Expect(frames).To(HaveLen(1))
   818  					Expect(frames[0]).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{}))
   819  					ccf := frames[0].(*logging.ConnectionCloseFrame)
   820  					Expect(ccf.IsApplicationError).To(BeFalse())
   821  					Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.InvalidToken))
   822  				})
   823  				done := make(chan struct{})
   824  				conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
   825  					defer close(done)
   826  					checkInvalidToken(b, hdr)
   827  					return len(b), nil
   828  				})
   829  				phm.EXPECT().Get(gomock.Any())
   830  				serv.handlePacket(packet)
   831  				Eventually(done).Should(BeClosed())
   832  			})
   833  
   834  			It("sends an INVALID_TOKEN error, if an expired retry token is received", func() {
   835  				serv.config.RequireAddressValidation = func(net.Addr) bool { return true }
   836  				serv.config.MaxRetryTokenAge = time.Millisecond
   837  				raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
   838  				token, err := serv.tokenGenerator.NewRetryToken(raddr, protocol.ConnectionID{}, protocol.ConnectionID{})
   839  				Expect(err).ToNot(HaveOccurred())
   840  				time.Sleep(2 * time.Millisecond) // make sure the token is expired
   841  				hdr := &wire.Header{
   842  					Type:             protocol.PacketTypeInitial,
   843  					SrcConnectionID:  protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
   844  					DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}),
   845  					Token:            token,
   846  					Version:          protocol.Version1,
   847  				}
   848  				packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
   849  				packet.remoteAddr = raddr
   850  				tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) {
   851  					Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial))
   852  					Expect(replyHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID))
   853  					Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))
   854  					Expect(frames).To(HaveLen(1))
   855  					Expect(frames[0]).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{}))
   856  					ccf := frames[0].(*logging.ConnectionCloseFrame)
   857  					Expect(ccf.IsApplicationError).To(BeFalse())
   858  					Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.InvalidToken))
   859  				})
   860  				done := make(chan struct{})
   861  				conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
   862  					defer close(done)
   863  					checkInvalidToken(b, hdr)
   864  					return len(b), nil
   865  				})
   866  				phm.EXPECT().Get(gomock.Any())
   867  				serv.handlePacket(packet)
   868  				Eventually(done).Should(BeClosed())
   869  			})
   870  
   871  			It("doesn't send an INVALID_TOKEN error, if an invalid non-retry token is received", func() {
   872  				serv.config.RequireAddressValidation = func(net.Addr) bool { return true }
   873  				token, err := serv.tokenGenerator.NewToken(&net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337})
   874  				Expect(err).ToNot(HaveOccurred())
   875  				hdr := &wire.Header{
   876  					Type:             protocol.PacketTypeInitial,
   877  					SrcConnectionID:  protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
   878  					DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}),
   879  					Token:            token,
   880  					Version:          protocol.Version1,
   881  				}
   882  				packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
   883  				packet.data[len(packet.data)-10] ^= 0xff // corrupt the packet
   884  				raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
   885  				packet.remoteAddr = raddr
   886  				tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1)
   887  				done := make(chan struct{})
   888  				conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
   889  					defer close(done)
   890  					replyHdr := parseHeader(b)
   891  					Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry))
   892  					return len(b), nil
   893  				})
   894  				phm.EXPECT().Get(gomock.Any())
   895  				serv.handlePacket(packet)
   896  				// make sure there are no Write calls on the packet conn
   897  				Eventually(done).Should(BeClosed())
   898  			})
   899  
   900  			It("sends an INVALID_TOKEN error, if an expired non-retry token is received", func() {
   901  				serv.config.RequireAddressValidation = func(net.Addr) bool { return true }
   902  				serv.config.MaxTokenAge = time.Millisecond
   903  				raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
   904  				token, err := serv.tokenGenerator.NewToken(raddr)
   905  				Expect(err).ToNot(HaveOccurred())
   906  				time.Sleep(2 * time.Millisecond) // make sure the token is expired
   907  				hdr := &wire.Header{
   908  					Type:             protocol.PacketTypeInitial,
   909  					SrcConnectionID:  protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
   910  					DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}),
   911  					Token:            token,
   912  					Version:          protocol.Version1,
   913  				}
   914  				packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
   915  				packet.remoteAddr = raddr
   916  				tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) {
   917  					Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry))
   918  				})
   919  				done := make(chan struct{})
   920  				conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
   921  					defer close(done)
   922  					return len(b), nil
   923  				})
   924  				phm.EXPECT().Get(gomock.Any())
   925  				serv.handlePacket(packet)
   926  				Eventually(done).Should(BeClosed())
   927  			})
   928  
   929  			It("doesn't send an INVALID_TOKEN error, if the packet is corrupted", func() {
   930  				serv.config.RequireAddressValidation = func(net.Addr) bool { return true }
   931  				token, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, protocol.ConnectionID{}, protocol.ConnectionID{})
   932  				Expect(err).ToNot(HaveOccurred())
   933  				hdr := &wire.Header{
   934  					Type:             protocol.PacketTypeInitial,
   935  					SrcConnectionID:  protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
   936  					DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}),
   937  					Token:            token,
   938  					Version:          protocol.Version1,
   939  				}
   940  				packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
   941  				packet.data[len(packet.data)-10] ^= 0xff // corrupt the packet
   942  				packet.remoteAddr = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
   943  				done := make(chan struct{})
   944  				tracer.EXPECT().DroppedPacket(packet.remoteAddr, logging.PacketTypeInitial, packet.Size(), logging.PacketDropPayloadDecryptError).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { close(done) })
   945  				phm.EXPECT().Get(gomock.Any())
   946  				serv.handlePacket(packet)
   947  				// make sure there are no Write calls on the packet conn
   948  				time.Sleep(50 * time.Millisecond)
   949  				Eventually(done).Should(BeClosed())
   950  			})
   951  		})
   952  
   953  		Context("accepting connections", func() {
   954  			It("returns Accept when an error occurs", func() {
   955  				testErr := errors.New("test err")
   956  
   957  				done := make(chan struct{})
   958  				go func() {
   959  					defer GinkgoRecover()
   960  					_, err := serv.Accept(context.Background())
   961  					Expect(err).To(MatchError(testErr))
   962  					close(done)
   963  				}()
   964  
   965  				serv.setCloseError(testErr)
   966  				Eventually(done).Should(BeClosed())
   967  				serv.onClose() // shutdown
   968  			})
   969  
   970  			It("returns immediately, if an error occurred before", func() {
   971  				testErr := errors.New("test err")
   972  				serv.setCloseError(testErr)
   973  				for i := 0; i < 3; i++ {
   974  					_, err := serv.Accept(context.Background())
   975  					Expect(err).To(MatchError(testErr))
   976  				}
   977  				serv.onClose() // shutdown
   978  			})
   979  
   980  			It("returns when the context is canceled", func() {
   981  				ctx, cancel := context.WithCancel(context.Background())
   982  				done := make(chan struct{})
   983  				go func() {
   984  					defer GinkgoRecover()
   985  					_, err := serv.Accept(ctx)
   986  					Expect(err).To(MatchError("context canceled"))
   987  					close(done)
   988  				}()
   989  
   990  				Consistently(done).ShouldNot(BeClosed())
   991  				cancel()
   992  				Eventually(done).Should(BeClosed())
   993  			})
   994  
   995  			It("uses the config returned by GetConfigClient", func() {
   996  				conn := NewMockQUICConn(mockCtrl)
   997  
   998  				conf := &Config{MaxIncomingStreams: 1234}
   999  				serv.config = populateServerConfig(&Config{GetConfigForClient: func(*ClientHelloInfo) (*Config, error) { return conf, nil }})
  1000  				done := make(chan struct{})
  1001  				go func() {
  1002  					defer GinkgoRecover()
  1003  					s, err := serv.Accept(context.Background())
  1004  					Expect(err).ToNot(HaveOccurred())
  1005  					Expect(s).To(Equal(conn))
  1006  					close(done)
  1007  				}()
  1008  
  1009  				handshakeChan := make(chan struct{})
  1010  				serv.newConn = func(
  1011  					_ sendConn,
  1012  					_ connRunner,
  1013  					_ protocol.ConnectionID,
  1014  					_ *protocol.ConnectionID,
  1015  					_ protocol.ConnectionID,
  1016  					_ protocol.ConnectionID,
  1017  					_ protocol.ConnectionID,
  1018  					_ ConnectionIDGenerator,
  1019  					_ protocol.StatelessResetToken,
  1020  					conf *Config,
  1021  					_ *tls.Config,
  1022  					_ *handshake.TokenGenerator,
  1023  					_ bool,
  1024  					_ logging.ConnectionTracer,
  1025  					_ uint64,
  1026  					_ utils.Logger,
  1027  					_ protocol.VersionNumber,
  1028  				) quicConn {
  1029  					Expect(conf.MaxIncomingStreams).To(BeEquivalentTo(1234))
  1030  					conn.EXPECT().handlePacket(gomock.Any())
  1031  					conn.EXPECT().HandshakeComplete().Return(handshakeChan)
  1032  					conn.EXPECT().run().Do(func() {})
  1033  					conn.EXPECT().Context().Return(context.Background())
  1034  					return conn
  1035  				}
  1036  				phm.EXPECT().Get(gomock.Any())
  1037  				phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
  1038  					phm.EXPECT().GetStatelessResetToken(gomock.Any())
  1039  					_, ok := fn()
  1040  					return ok
  1041  				})
  1042  				serv.handleInitialImpl(
  1043  					receivedPacket{buffer: getPacketBuffer()},
  1044  					&wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})},
  1045  				)
  1046  				Consistently(done).ShouldNot(BeClosed())
  1047  				close(handshakeChan) // complete the handshake
  1048  				Eventually(done).Should(BeClosed())
  1049  			})
  1050  
  1051  			It("rejects a connection attempt when GetConfigClient returns an error", func() {
  1052  				serv.config = populateServerConfig(&Config{GetConfigForClient: func(*ClientHelloInfo) (*Config, error) { return nil, errors.New("rejected") }})
  1053  
  1054  				phm.EXPECT().Get(gomock.Any())
  1055  				phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
  1056  					_, ok := fn()
  1057  					return ok
  1058  				})
  1059  				done := make(chan struct{})
  1060  				tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
  1061  				conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
  1062  					defer close(done)
  1063  					rejectHdr := parseHeader(b)
  1064  					Expect(rejectHdr.Type).To(Equal(protocol.PacketTypeInitial))
  1065  					return len(b), nil
  1066  				})
  1067  				serv.handleInitialImpl(
  1068  					receivedPacket{buffer: getPacketBuffer()},
  1069  					&wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), Version: protocol.Version1},
  1070  				)
  1071  				Eventually(done).Should(BeClosed())
  1072  			})
  1073  
  1074  			It("accepts new connections when the handshake completes", func() {
  1075  				conn := NewMockQUICConn(mockCtrl)
  1076  
  1077  				done := make(chan struct{})
  1078  				go func() {
  1079  					defer GinkgoRecover()
  1080  					s, err := serv.Accept(context.Background())
  1081  					Expect(err).ToNot(HaveOccurred())
  1082  					Expect(s).To(Equal(conn))
  1083  					close(done)
  1084  				}()
  1085  
  1086  				handshakeChan := make(chan struct{})
  1087  				serv.newConn = func(
  1088  					_ sendConn,
  1089  					runner connRunner,
  1090  					_ protocol.ConnectionID,
  1091  					_ *protocol.ConnectionID,
  1092  					_ protocol.ConnectionID,
  1093  					_ protocol.ConnectionID,
  1094  					_ protocol.ConnectionID,
  1095  					_ ConnectionIDGenerator,
  1096  					_ protocol.StatelessResetToken,
  1097  					_ *Config,
  1098  					_ *tls.Config,
  1099  					_ *handshake.TokenGenerator,
  1100  					_ bool,
  1101  					_ logging.ConnectionTracer,
  1102  					_ uint64,
  1103  					_ utils.Logger,
  1104  					_ protocol.VersionNumber,
  1105  				) quicConn {
  1106  					conn.EXPECT().handlePacket(gomock.Any())
  1107  					conn.EXPECT().HandshakeComplete().Return(handshakeChan)
  1108  					conn.EXPECT().run().Do(func() {})
  1109  					conn.EXPECT().Context().Return(context.Background())
  1110  					return conn
  1111  				}
  1112  				phm.EXPECT().Get(gomock.Any())
  1113  				phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
  1114  					phm.EXPECT().GetStatelessResetToken(gomock.Any())
  1115  					_, ok := fn()
  1116  					return ok
  1117  				})
  1118  				serv.handleInitialImpl(
  1119  					receivedPacket{buffer: getPacketBuffer()},
  1120  					&wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})},
  1121  				)
  1122  				Consistently(done).ShouldNot(BeClosed())
  1123  				close(handshakeChan) // complete the handshake
  1124  				Eventually(done).Should(BeClosed())
  1125  			})
  1126  		})
  1127  	})
  1128  
	// Tests for the listener returned by ListenEarly: connections are handed out by
	// Accept once they signal readiness via earlyConnReady, rather than on handshake completion.
	Context("server accepting connections that haven't completed the handshake", func() {
		var (
			serv *EarlyListener
			phm  *MockPacketHandlerManager
		)

		BeforeEach(func() {
			var err error
			serv, err = ListenEarly(conn, tlsConf, nil)
			Expect(err).ToNot(HaveOccurred())
			// replace the listener's packet handler manager with a mock
			phm = NewMockPacketHandlerManager(mockCtrl)
			serv.baseServer.connHandler = phm
		})

		AfterEach(func() {
			serv.Close()
		})

		It("accepts new connections when they become ready", func() {
			conn := NewMockQUICConn(mockCtrl)

			done := make(chan struct{})
			go func() {
				defer GinkgoRecover()
				s, err := serv.Accept(context.Background())
				Expect(err).ToNot(HaveOccurred())
				Expect(s).To(Equal(conn))
				close(done)
			}()

			// the connection only becomes ready once the test closes this channel
			ready := make(chan struct{})
			serv.baseServer.newConn = func(
				_ sendConn,
				runner connRunner,
				_ protocol.ConnectionID,
				_ *protocol.ConnectionID,
				_ protocol.ConnectionID,
				_ protocol.ConnectionID,
				_ protocol.ConnectionID,
				_ ConnectionIDGenerator,
				_ protocol.StatelessResetToken,
				_ *Config,
				_ *tls.Config,
				_ *handshake.TokenGenerator,
				_ bool,
				_ logging.ConnectionTracer,
				_ uint64,
				_ utils.Logger,
				_ protocol.VersionNumber,
			) quicConn {
				conn.EXPECT().handlePacket(gomock.Any())
				conn.EXPECT().run().Do(func() {})
				conn.EXPECT().earlyConnReady().Return(ready)
				conn.EXPECT().Context().Return(context.Background())
				return conn
			}
			phm.EXPECT().Get(gomock.Any())
			// run the constructor callback so that newConn above is actually invoked
			phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
				phm.EXPECT().GetStatelessResetToken(gomock.Any())
				_, ok := fn()
				return ok
			})
			serv.baseServer.handleInitialImpl(
				receivedPacket{buffer: getPacketBuffer()},
				&wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})},
			)
			// Accept blocks until the connection is ready
			Consistently(done).ShouldNot(BeClosed())
			close(ready)
			Eventually(done).Should(BeClosed())
		})

		It("rejects new connection attempts if the accept queue is full", func() {
			// NOTE(review): assumed to match the remote address set by getInitialWithRandomDestConnID — confirm
			senderAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 42}

			// every new connection is immediately ready (closed channel), so it is queued for Accept
			serv.baseServer.newConn = func(
				_ sendConn,
				runner connRunner,
				_ protocol.ConnectionID,
				_ *protocol.ConnectionID,
				_ protocol.ConnectionID,
				_ protocol.ConnectionID,
				_ protocol.ConnectionID,
				_ ConnectionIDGenerator,
				_ protocol.StatelessResetToken,
				_ *Config,
				_ *tls.Config,
				_ *handshake.TokenGenerator,
				_ bool,
				_ logging.ConnectionTracer,
				_ uint64,
				_ utils.Logger,
				_ protocol.VersionNumber,
			) quicConn {
				ready := make(chan struct{})
				close(ready)
				conn := NewMockQUICConn(mockCtrl)
				conn.EXPECT().handlePacket(gomock.Any())
				conn.EXPECT().run()
				conn.EXPECT().earlyConnReady().Return(ready)
				conn.EXPECT().Context().Return(context.Background())
				return conn
			}

			// fill the accept queue; nobody calls Accept, so it stays full
			phm.EXPECT().Get(gomock.Any()).AnyTimes()
			phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
				phm.EXPECT().GetStatelessResetToken(gomock.Any())
				_, ok := fn()
				return ok
			}).Times(protocol.MaxAcceptQueueSize)
			for i := 0; i < protocol.MaxAcceptQueueSize; i++ {
				serv.baseServer.handlePacket(getInitialWithRandomDestConnID())
			}

			Eventually(func() int32 { return atomic.LoadInt32(&serv.baseServer.connQueueLen) }).Should(BeEquivalentTo(protocol.MaxAcceptQueueSize))
			// make sure there are no Write calls on the packet conn
			time.Sleep(50 * time.Millisecond)

			// one more connection attempt is rejected with an Initial packet that
			// mirrors the version and swaps the connection IDs of the client's packet
			p := getInitialWithRandomDestConnID()
			hdr := parseHeader(p.data)
			done := make(chan struct{})
			conn.EXPECT().WriteTo(gomock.Any(), senderAddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
				defer close(done)
				rejectHdr := parseHeader(b)
				Expect(rejectHdr.Type).To(Equal(protocol.PacketTypeInitial))
				Expect(rejectHdr.Version).To(Equal(hdr.Version))
				Expect(rejectHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))
				Expect(rejectHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID))
				return len(b), nil
			})
			serv.baseServer.handlePacket(p)
			Eventually(done).Should(BeClosed())
		})

		It("doesn't accept new connections if they were closed in the mean time", func() {
			p := getInitial(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}))
			// ctx is returned as the connection's context; canceling it simulates the
			// connection being closed before it becomes ready.
			ctx, cancel := context.WithCancel(context.Background())
			connCreated := make(chan struct{})
			conn := NewMockQUICConn(mockCtrl)
			serv.baseServer.newConn = func(
				_ sendConn,
				runner connRunner,
				_ protocol.ConnectionID,
				_ *protocol.ConnectionID,
				_ protocol.ConnectionID,
				_ protocol.ConnectionID,
				_ protocol.ConnectionID,
				_ ConnectionIDGenerator,
				_ protocol.StatelessResetToken,
				_ *Config,
				_ *tls.Config,
				_ *handshake.TokenGenerator,
				_ bool,
				_ logging.ConnectionTracer,
				_ uint64,
				_ utils.Logger,
				_ protocol.VersionNumber,
			) quicConn {
				conn.EXPECT().handlePacket(p)
				conn.EXPECT().run()
				// no Return: the mock yields a nil channel, so the connection never becomes ready
				conn.EXPECT().earlyConnReady()
				conn.EXPECT().Context().Return(ctx)
				close(connCreated)
				return conn
			}

			phm.EXPECT().Get(gomock.Any())
			phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
				phm.EXPECT().GetStatelessResetToken(gomock.Any())
				_, ok := fn()
				return ok
			})
			serv.baseServer.handlePacket(p)
			// make sure there are no Write calls on the packet conn
			time.Sleep(50 * time.Millisecond)
			Eventually(connCreated).Should(BeClosed())
			// close the connection and give the server time to notice
			cancel()
			time.Sleep(scaleDuration(200 * time.Millisecond))

			// Accept must keep blocking: the canceled connection must not be returned
			done := make(chan struct{})
			go func() {
				defer GinkgoRecover()
				serv.Accept(context.Background())
				close(done)
			}()
			Consistently(done).ShouldNot(BeClosed())

			// make the go routine return
			conn.EXPECT().getPerspective().MaxTimes(2) // initOnce for every conn ID
			Expect(serv.Close()).To(Succeed())
			Eventually(done).Should(BeClosed())
		})
	})
  1321  
  1322  	Context("0-RTT", func() {
		// Fixtures shared by the 0-RTT tests, re-created in BeforeEach: tracer is installed
		// on tr, serv is the baseServer of an early listener created on tr, and phm
		// replaces serv's connection handler.
		var (
			tr     *Transport
			serv   *baseServer
			phm    *MockPacketHandlerManager
			tracer *mocklogging.MockTracer
		)
  1329  
		// Creates an early listener on a Transport that wraps the mock packet conn and the
		// mock tracer, then swaps the server's connection handler for a mock.
		BeforeEach(func() {
			tracer = mocklogging.NewMockTracer(mockCtrl)
			tr = &Transport{Conn: conn, Tracer: tracer}
			ln, err := tr.ListenEarly(tlsConf, nil)
			Expect(err).ToNot(HaveOccurred())
			phm = NewMockPacketHandlerManager(mockCtrl)
			serv = ln.baseServer
			serv.connHandler = phm
		})
  1339  
		// Shuts down the transport. CloseServer on the mocked handler manager is
		// permitted (at most once) during shutdown but not required.
		AfterEach(func() {
			phm.EXPECT().CloseServer().MaxTimes(1)
			tr.Close()
		})
  1344  
  1345  		It("passes packets to existing connections", func() {
  1346  			connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
  1347  			p := getPacket(&wire.Header{
  1348  				Type:             protocol.PacketType0RTT,
  1349  				DestConnectionID: connID,
  1350  				Version:          serv.config.Versions[0],
  1351  			}, make([]byte, 100))
  1352  			conn := NewMockPacketHandler(mockCtrl)
  1353  			phm.EXPECT().Get(connID).Return(conn, true)
  1354  			handled := make(chan struct{})
  1355  			conn.EXPECT().handlePacket(p).Do(func(receivedPacket) { close(handled) })
  1356  			serv.handlePacket(p)
  1357  			Eventually(handled).Should(BeClosed())
  1358  		})
  1359  
		// 0-RTT packets for an unknown connection ID are buffered (at most Max0RTTQueueLen
		// of them) and replayed, in arrival order, to the connection created by the Initial.
		It("queues 0-RTT packets, up to Max0RTTQueueSize", func() {
			connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})

			var zeroRTTPackets []receivedPacket

			// Get returns zero values (no handler), so each 0-RTT packet gets queued.
			// The payload size differs per packet, presumably so each handlePacket
			// expectation below matches exactly one packet.
			for i := 0; i < protocol.Max0RTTQueueLen; i++ {
				p := getPacket(&wire.Header{
					Type:             protocol.PacketType0RTT,
					DestConnectionID: connID,
					Version:          serv.config.Versions[0],
				}, make([]byte, 100+i))
				phm.EXPECT().Get(connID)
				serv.handlePacket(p)
				zeroRTTPackets = append(zeroRTTPackets, p)
			}

			// send one more packet, this one should be dropped
			p := getPacket(&wire.Header{
				Type:             protocol.PacketType0RTT,
				DestConnectionID: connID,
				Version:          serv.config.Versions[0],
			}, make([]byte, 200))
			phm.EXPECT().Get(connID)
			tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention)
			serv.handlePacket(p)

			initial := getPacket(&wire.Header{
				Type:             protocol.PacketTypeInitial,
				DestConnectionID: connID,
				Version:          serv.config.Versions[0],
			}, make([]byte, protocol.MinInitialPacketSize))
			called := make(chan struct{})
			serv.newConn = func(
				_ sendConn,
				_ connRunner,
				_ protocol.ConnectionID,
				_ *protocol.ConnectionID,
				_ protocol.ConnectionID,
				_ protocol.ConnectionID,
				_ protocol.ConnectionID,
				_ ConnectionIDGenerator,
				_ protocol.StatelessResetToken,
				_ *Config,
				_ *tls.Config,
				_ *handshake.TokenGenerator,
				_ bool,
				_ logging.ConnectionTracer,
				_ uint64,
				_ utils.Logger,
				_ protocol.VersionNumber,
			) quicConn {
				conn := NewMockQUICConn(mockCtrl)
				// the Initial must be handled first, followed by the queued 0-RTT packets in order
				var calls []*gomock.Call
				calls = append(calls, conn.EXPECT().handlePacket(initial))
				for _, p := range zeroRTTPackets {
					calls = append(calls, conn.EXPECT().handlePacket(p))
				}
				gomock.InOrder(calls...)
				conn.EXPECT().run()
				conn.EXPECT().earlyConnReady()
				conn.EXPECT().Context().Return(context.Background())
				close(called)
				return conn
			}

			phm.EXPECT().Get(connID)
			// run the constructor callback so that newConn above is actually invoked
			phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
				phm.EXPECT().GetStatelessResetToken(gomock.Any())
				_, ok := fn()
				return ok
			})
			serv.handlePacket(initial)
			Eventually(called).Should(BeClosed())
		})
  1434  
  1435  		It("limits the number of queues", func() {
  1436  			for i := 0; i < protocol.Max0RTTQueues; i++ {
  1437  				b := make([]byte, 16)
  1438  				rand.Read(b)
  1439  				connID := protocol.ParseConnectionID(b)
  1440  				p := getPacket(&wire.Header{
  1441  					Type:             protocol.PacketType0RTT,
  1442  					DestConnectionID: connID,
  1443  					Version:          serv.config.Versions[0],
  1444  				}, make([]byte, 100+i))
  1445  				phm.EXPECT().Get(connID)
  1446  				serv.handlePacket(p)
  1447  			}
  1448  
  1449  			connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
  1450  			p := getPacket(&wire.Header{
  1451  				Type:             protocol.PacketType0RTT,
  1452  				DestConnectionID: connID,
  1453  				Version:          serv.config.Versions[0],
  1454  			}, make([]byte, 200))
  1455  			phm.EXPECT().Get(connID)
  1456  			dropped := make(chan struct{})
  1457  			tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) {
  1458  				close(dropped)
  1459  			})
  1460  			serv.handlePacket(p)
  1461  			Eventually(dropped).Should(BeClosed())
  1462  		})
  1463  
  1464  		It("drops queues after a while", func() {
  1465  			now := time.Now()
  1466  
  1467  			connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
  1468  			p := getPacket(&wire.Header{
  1469  				Type:             protocol.PacketType0RTT,
  1470  				DestConnectionID: connID,
  1471  				Version:          serv.config.Versions[0],
  1472  			}, make([]byte, 200))
  1473  			p.rcvTime = now
  1474  
  1475  			connID2 := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 9})
  1476  			p2Time := now.Add(protocol.Max0RTTQueueingDuration / 2)
  1477  			p2 := getPacket(&wire.Header{
  1478  				Type:             protocol.PacketType0RTT,
  1479  				DestConnectionID: connID2,
  1480  				Version:          serv.config.Versions[0],
  1481  			}, make([]byte, 300))
  1482  			p2.rcvTime = p2Time // doesn't trigger the cleanup of the first packet
  1483  
  1484  			dropped1 := make(chan struct{})
  1485  			dropped2 := make(chan struct{})
  1486  			// need to register the call before handling the packet to avoid race condition
  1487  			gomock.InOrder(
  1488  				tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) {
  1489  					close(dropped1)
  1490  				}),
  1491  				tracer.EXPECT().DroppedPacket(p2.remoteAddr, logging.PacketType0RTT, p2.Size(), logging.PacketDropDOSPrevention).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) {
  1492  					close(dropped2)
  1493  				}),
  1494  			)
  1495  
  1496  			phm.EXPECT().Get(connID)
  1497  			serv.handlePacket(p)
  1498  
  1499  			// There's no cleanup Go routine.
  1500  			// Cleanup is triggered when new packets are received.
  1501  
  1502  			phm.EXPECT().Get(connID2)
  1503  			serv.handlePacket(p2)
  1504  			// make sure no cleanup is executed
  1505  			Consistently(dropped1, 50*time.Millisecond).ShouldNot(BeClosed())
  1506  
  1507  			// There's no cleanup Go routine.
  1508  			// Cleanup is triggered when new packets are received.
  1509  			connID3 := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 0})
  1510  			p3 := getPacket(&wire.Header{
  1511  				Type:             protocol.PacketType0RTT,
  1512  				DestConnectionID: connID3,
  1513  				Version:          serv.config.Versions[0],
  1514  			}, make([]byte, 200))
  1515  			p3.rcvTime = now.Add(protocol.Max0RTTQueueingDuration + time.Nanosecond) // now triggers the cleanup
  1516  			phm.EXPECT().Get(connID3)
  1517  			serv.handlePacket(p3)
  1518  			Eventually(dropped1).Should(BeClosed())
  1519  			Consistently(dropped2, 50*time.Millisecond).ShouldNot(BeClosed())
  1520  
  1521  			// make sure the second packet is also cleaned up
  1522  			connID4 := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 1})
  1523  			p4 := getPacket(&wire.Header{
  1524  				Type:             protocol.PacketType0RTT,
  1525  				DestConnectionID: connID4,
  1526  				Version:          serv.config.Versions[0],
  1527  			}, make([]byte, 200))
  1528  			p4.rcvTime = p2Time.Add(protocol.Max0RTTQueueingDuration + time.Nanosecond) // now triggers the cleanup
  1529  			phm.EXPECT().Get(connID4)
  1530  			serv.handlePacket(p4)
  1531  			Eventually(dropped2).Should(BeClosed())
  1532  		})
  1533  	})
  1534  })