github.com/tumi8/quic-go@v0.37.4-tum/integrationtests/self/zero_rtt_test.go (about)

     1  //go:build go1.21
     2  
     3  package self_test
     4  
     5  import (
     6  	"context"
     7  	"crypto/tls"
     8  	"fmt"
     9  	"io"
    10  	mrand "math/rand"
    11  	"net"
    12  	"sync"
    13  	"sync/atomic"
    14  	"time"
    15  
    16  	"github.com/tumi8/quic-go"
    17  	quicproxy "github.com/tumi8/quic-go/integrationtests/tools/proxy"
    18  	"github.com/tumi8/quic-go/noninternal/protocol"
    19  	"github.com/tumi8/quic-go/noninternal/wire"
    20  	"github.com/tumi8/quic-go/logging"
    21  
    22  	. "github.com/onsi/ginkgo/v2"
    23  	. "github.com/onsi/gomega"
    24  )
    25  
    26  type metadataClientSessionCache struct {
    27  	toAdd    []byte
    28  	restored func([]byte)
    29  
    30  	cache tls.ClientSessionCache
    31  }
    32  
    33  func (m metadataClientSessionCache) Get(key string) (*tls.ClientSessionState, bool) {
    34  	session, ok := m.cache.Get(key)
    35  	if !ok || session == nil {
    36  		return session, ok
    37  	}
    38  	ticket, state, err := session.ResumptionState()
    39  	Expect(err).ToNot(HaveOccurred())
    40  	Expect(state.Extra).To(HaveLen(2)) // ours, and the quic-go's
    41  	m.restored(state.Extra[1])
    42  	session, err = tls.NewResumptionState(ticket, state)
    43  	Expect(err).ToNot(HaveOccurred())
    44  	return session, true
    45  }
    46  
    47  func (m metadataClientSessionCache) Put(key string, session *tls.ClientSessionState) {
    48  	ticket, state, err := session.ResumptionState()
    49  	Expect(err).ToNot(HaveOccurred())
    50  	state.Extra = append(state.Extra, m.toAdd)
    51  	session, err = tls.NewResumptionState(ticket, state)
    52  	Expect(err).ToNot(HaveOccurred())
    53  	m.cache.Put(key, session)
    54  }
    55  
    56  var _ = Describe("0-RTT", func() {
    57  	rtt := scaleDuration(5 * time.Millisecond)
    58  
    59  	runCountingProxy := func(serverPort int) (*quicproxy.QuicProxy, *uint32) {
    60  		var num0RTTPackets uint32 // to be used as an atomic
    61  		proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
    62  			RemoteAddr: fmt.Sprintf("localhost:%d", serverPort),
    63  			DelayPacket: func(_ quicproxy.Direction, data []byte) time.Duration {
    64  				for len(data) > 0 {
    65  					if !wire.IsLongHeaderPacket(data[0]) {
    66  						break
    67  					}
    68  					hdr, _, rest, err := wire.ParsePacket(data)
    69  					Expect(err).ToNot(HaveOccurred())
    70  					if hdr.Type == protocol.PacketType0RTT {
    71  						atomic.AddUint32(&num0RTTPackets, 1)
    72  						break
    73  					}
    74  					data = rest
    75  				}
    76  				return rtt / 2
    77  			},
    78  		})
    79  		Expect(err).ToNot(HaveOccurred())
    80  
    81  		return proxy, &num0RTTPackets
    82  	}
    83  
    84  	dialAndReceiveSessionTicket := func(serverTLSConf *tls.Config, serverConf *quic.Config, clientTLSConf *tls.Config) {
    85  		if serverConf == nil {
    86  			serverConf = getQuicConfig(nil)
    87  		}
    88  		serverConf.Allow0RTT = true
    89  		ln, err := quic.ListenAddrEarly(
    90  			"localhost:0",
    91  			serverTLSConf,
    92  			serverConf,
    93  		)
    94  		Expect(err).ToNot(HaveOccurred())
    95  		defer ln.Close()
    96  
    97  		proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
    98  			RemoteAddr:  fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port),
    99  			DelayPacket: func(_ quicproxy.Direction, data []byte) time.Duration { return rtt / 2 },
   100  		})
   101  		Expect(err).ToNot(HaveOccurred())
   102  		defer proxy.Close()
   103  
   104  		// dial the first connection in order to receive a session ticket
   105  		done := make(chan struct{})
   106  		go func() {
   107  			defer GinkgoRecover()
   108  			defer close(done)
   109  			conn, err := ln.Accept(context.Background())
   110  			Expect(err).ToNot(HaveOccurred())
   111  			<-conn.Context().Done()
   112  		}()
   113  
   114  		puts := make(chan string, 100)
   115  		cache := clientTLSConf.ClientSessionCache
   116  		if cache == nil {
   117  			cache = tls.NewLRUClientSessionCache(100)
   118  		}
   119  		clientTLSConf.ClientSessionCache = newClientSessionCache(cache, make(chan string, 100), puts)
   120  		conn, err := quic.DialAddr(
   121  			context.Background(),
   122  			fmt.Sprintf("localhost:%d", proxy.LocalPort()),
   123  			clientTLSConf,
   124  			getQuicConfig(nil),
   125  		)
   126  		Expect(err).ToNot(HaveOccurred())
   127  		Eventually(puts).Should(Receive())
   128  		// received the session ticket. We're done here.
   129  		Expect(conn.CloseWithError(0, "")).To(Succeed())
   130  		Eventually(done).Should(BeClosed())
   131  	}
   132  
   133  	transfer0RTTData := func(
   134  		ln *quic.EarlyListener,
   135  		proxyPort int,
   136  		connIDLen int,
   137  		clientTLSConf *tls.Config,
   138  		clientConf *quic.Config,
   139  		testdata []byte, // data to transfer
   140  	) {
   141  		// accept the second connection, and receive the data sent in 0-RTT
   142  		done := make(chan struct{})
   143  		go func() {
   144  			defer GinkgoRecover()
   145  			conn, err := ln.Accept(context.Background())
   146  			Expect(err).ToNot(HaveOccurred())
   147  			str, err := conn.AcceptStream(context.Background())
   148  			Expect(err).ToNot(HaveOccurred())
   149  			data, err := io.ReadAll(str)
   150  			Expect(err).ToNot(HaveOccurred())
   151  			Expect(data).To(Equal(testdata))
   152  			Expect(str.Close()).To(Succeed())
   153  			Expect(conn.ConnectionState().Used0RTT).To(BeTrue())
   154  			<-conn.Context().Done()
   155  			close(done)
   156  		}()
   157  
   158  		if clientConf == nil {
   159  			clientConf = getQuicConfig(nil)
   160  		}
   161  		var conn quic.EarlyConnection
   162  		if connIDLen == 0 {
   163  			var err error
   164  			conn, err = quic.DialAddrEarly(
   165  				context.Background(),
   166  				fmt.Sprintf("localhost:%d", proxyPort),
   167  				clientTLSConf,
   168  				clientConf,
   169  			)
   170  			Expect(err).ToNot(HaveOccurred())
   171  		} else {
   172  			addr, err := net.ResolveUDPAddr("udp", "localhost:0")
   173  			Expect(err).ToNot(HaveOccurred())
   174  			udpConn, err := net.ListenUDP("udp", addr)
   175  			Expect(err).ToNot(HaveOccurred())
   176  			defer udpConn.Close()
   177  			tr := &quic.Transport{
   178  				Conn:               udpConn,
   179  				ConnectionIDLength: connIDLen,
   180  			}
   181  			defer tr.Close()
   182  			conn, err = tr.DialEarly(
   183  				context.Background(),
   184  				&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: proxyPort},
   185  				clientTLSConf,
   186  				clientConf,
   187  			)
   188  			Expect(err).ToNot(HaveOccurred())
   189  		}
   190  		defer conn.CloseWithError(0, "")
   191  		str, err := conn.OpenStream()
   192  		Expect(err).ToNot(HaveOccurred())
   193  		_, err = str.Write(testdata)
   194  		Expect(err).ToNot(HaveOccurred())
   195  		Expect(str.Close()).To(Succeed())
   196  		<-conn.HandshakeComplete()
   197  		Expect(conn.ConnectionState().Used0RTT).To(BeTrue())
   198  		io.ReadAll(str) // wait for the EOF from the server to arrive before closing the conn
   199  		conn.CloseWithError(0, "")
   200  		Eventually(done).Should(BeClosed())
   201  		Eventually(conn.Context().Done()).Should(BeClosed())
   202  	}
   203  
   204  	check0RTTRejected := func(
   205  		ln *quic.EarlyListener,
   206  		proxyPort int,
   207  		clientConf *tls.Config,
   208  	) {
   209  		conn, err := quic.DialAddrEarly(
   210  			context.Background(),
   211  			fmt.Sprintf("localhost:%d", proxyPort),
   212  			clientConf,
   213  			getQuicConfig(nil),
   214  		)
   215  		Expect(err).ToNot(HaveOccurred())
   216  		str, err := conn.OpenUniStream()
   217  		Expect(err).ToNot(HaveOccurred())
   218  		_, err = str.Write(make([]byte, 3000))
   219  		Expect(err).ToNot(HaveOccurred())
   220  		Expect(str.Close()).To(Succeed())
   221  		Expect(conn.ConnectionState().Used0RTT).To(BeFalse())
   222  
   223  		// make sure the server doesn't process the data
   224  		ctx, cancel := context.WithTimeout(context.Background(), scaleDuration(50*time.Millisecond))
   225  		defer cancel()
   226  		serverConn, err := ln.Accept(ctx)
   227  		Expect(err).ToNot(HaveOccurred())
   228  		Expect(serverConn.ConnectionState().Used0RTT).To(BeFalse())
   229  		_, err = serverConn.AcceptUniStream(ctx)
   230  		Expect(err).To(Equal(context.DeadlineExceeded))
   231  		Expect(serverConn.CloseWithError(0, "")).To(Succeed())
   232  		Eventually(conn.Context().Done()).Should(BeClosed())
   233  	}
   234  
   235  	// can be used to extract 0-RTT from a packetTracer
   236  	get0RTTPackets := func(packets []packet) []protocol.PacketNumber {
   237  		var zeroRTTPackets []protocol.PacketNumber
   238  		for _, p := range packets {
   239  			if p.hdr.Type == protocol.PacketType0RTT {
   240  				zeroRTTPackets = append(zeroRTTPackets, p.hdr.PacketNumber)
   241  			}
   242  		}
   243  		return zeroRTTPackets
   244  	}
   245  
   246  	for _, l := range []int{0, 15} {
   247  		connIDLen := l
   248  
   249  		It(fmt.Sprintf("transfers 0-RTT data, with %d byte connection IDs", connIDLen), func() {
   250  			tlsConf := getTLSConfig()
   251  			clientTLSConf := getTLSClientConfig()
   252  			dialAndReceiveSessionTicket(tlsConf, nil, clientTLSConf)
   253  
   254  			tracer := newPacketTracer()
   255  			ln, err := quic.ListenAddrEarly(
   256  				"localhost:0",
   257  				tlsConf,
   258  				getQuicConfig(&quic.Config{
   259  					Allow0RTT: true,
   260  					Tracer:    newTracer(tracer),
   261  				}),
   262  			)
   263  			Expect(err).ToNot(HaveOccurred())
   264  			defer ln.Close()
   265  
   266  			proxy, num0RTTPackets := runCountingProxy(ln.Addr().(*net.UDPAddr).Port)
   267  			defer proxy.Close()
   268  
   269  			transfer0RTTData(
   270  				ln,
   271  				proxy.LocalPort(),
   272  				connIDLen,
   273  				clientTLSConf,
   274  				getQuicConfig(nil),
   275  				PRData,
   276  			)
   277  
   278  			var numNewConnIDs int
   279  			for _, p := range tracer.getRcvdLongHeaderPackets() {
   280  				for _, f := range p.frames {
   281  					if _, ok := f.(*logging.NewConnectionIDFrame); ok {
   282  						numNewConnIDs++
   283  					}
   284  				}
   285  			}
   286  			if connIDLen == 0 {
   287  				Expect(numNewConnIDs).To(BeZero())
   288  			} else {
   289  				Expect(numNewConnIDs).ToNot(BeZero())
   290  			}
   291  
   292  			num0RTT := atomic.LoadUint32(num0RTTPackets)
   293  			fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT)
   294  			Expect(num0RTT).ToNot(BeZero())
   295  			zeroRTTPackets := get0RTTPackets(tracer.getRcvdLongHeaderPackets())
   296  			Expect(len(zeroRTTPackets)).To(BeNumerically(">", 10))
   297  			Expect(zeroRTTPackets).To(ContainElement(protocol.PacketNumber(0)))
   298  		})
   299  	}
   300  
   301  	// Test that data intended to be sent with 1-RTT protection is not sent in 0-RTT packets.
   302  	It("waits for a connection until the handshake is done", func() {
   303  		tlsConf := getTLSConfig()
   304  		clientConf := getTLSClientConfig()
   305  		dialAndReceiveSessionTicket(tlsConf, nil, clientConf)
   306  
   307  		zeroRTTData := GeneratePRData(5 << 10)
   308  		oneRTTData := PRData
   309  
   310  		tracer := newPacketTracer()
   311  		ln, err := quic.ListenAddrEarly(
   312  			"localhost:0",
   313  			tlsConf,
   314  			getQuicConfig(&quic.Config{
   315  				Allow0RTT: true,
   316  				Tracer:    newTracer(tracer),
   317  			}),
   318  		)
   319  		Expect(err).ToNot(HaveOccurred())
   320  		defer ln.Close()
   321  
   322  		// now accept the second connection, and receive the 0-RTT data
   323  		go func() {
   324  			defer GinkgoRecover()
   325  			conn, err := ln.Accept(context.Background())
   326  			Expect(err).ToNot(HaveOccurred())
   327  			str, err := conn.AcceptUniStream(context.Background())
   328  			Expect(err).ToNot(HaveOccurred())
   329  			data, err := io.ReadAll(str)
   330  			Expect(err).ToNot(HaveOccurred())
   331  			Expect(data).To(Equal(zeroRTTData))
   332  			str, err = conn.AcceptUniStream(context.Background())
   333  			Expect(err).ToNot(HaveOccurred())
   334  			data, err = io.ReadAll(str)
   335  			Expect(err).ToNot(HaveOccurred())
   336  			Expect(data).To(Equal(oneRTTData))
   337  			Expect(conn.CloseWithError(0, "")).To(Succeed())
   338  		}()
   339  
   340  		proxy, _ := runCountingProxy(ln.Addr().(*net.UDPAddr).Port)
   341  		defer proxy.Close()
   342  
   343  		conn, err := quic.DialAddrEarly(
   344  			context.Background(),
   345  			fmt.Sprintf("localhost:%d", proxy.LocalPort()),
   346  			clientConf,
   347  			getQuicConfig(nil),
   348  		)
   349  		Expect(err).ToNot(HaveOccurred())
   350  		firstStr, err := conn.OpenUniStream()
   351  		Expect(err).ToNot(HaveOccurred())
   352  		_, err = firstStr.Write(zeroRTTData)
   353  		Expect(err).ToNot(HaveOccurred())
   354  		Expect(firstStr.Close()).To(Succeed())
   355  
   356  		// wait for the handshake to complete
   357  		Eventually(conn.HandshakeComplete()).Should(BeClosed())
   358  		str, err := conn.OpenUniStream()
   359  		Expect(err).ToNot(HaveOccurred())
   360  		_, err = str.Write(PRData)
   361  		Expect(err).ToNot(HaveOccurred())
   362  		Expect(str.Close()).To(Succeed())
   363  		<-conn.Context().Done()
   364  
   365  		// check that 0-RTT packets only contain STREAM frames for the first stream
   366  		var num0RTT int
   367  		for _, p := range tracer.getRcvdLongHeaderPackets() {
   368  			if p.hdr.Header.Type != protocol.PacketType0RTT {
   369  				continue
   370  			}
   371  			for _, f := range p.frames {
   372  				sf, ok := f.(*logging.StreamFrame)
   373  				if !ok {
   374  					continue
   375  				}
   376  				num0RTT++
   377  				Expect(sf.StreamID).To(Equal(firstStr.StreamID()))
   378  			}
   379  		}
   380  		fmt.Fprintf(GinkgoWriter, "received %d STREAM frames in 0-RTT packets\n", num0RTT)
   381  		Expect(num0RTT).ToNot(BeZero())
   382  	})
   383  
   384  	It("transfers 0-RTT data, when 0-RTT packets are lost", func() {
   385  		var (
   386  			num0RTTPackets uint32 // to be used as an atomic
   387  			num0RTTDropped uint32
   388  		)
   389  
   390  		tlsConf := getTLSConfig()
   391  		clientConf := getTLSClientConfig()
   392  		dialAndReceiveSessionTicket(tlsConf, nil, clientConf)
   393  
   394  		tracer := newPacketTracer()
   395  		ln, err := quic.ListenAddrEarly(
   396  			"localhost:0",
   397  			tlsConf,
   398  			getQuicConfig(&quic.Config{
   399  				Allow0RTT: true,
   400  				Tracer:    newTracer(tracer),
   401  			}),
   402  		)
   403  		Expect(err).ToNot(HaveOccurred())
   404  		defer ln.Close()
   405  
   406  		proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
   407  			RemoteAddr: fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port),
   408  			DelayPacket: func(_ quicproxy.Direction, data []byte) time.Duration {
   409  				if wire.IsLongHeaderPacket(data[0]) {
   410  					hdr, _, _, err := wire.ParsePacket(data)
   411  					Expect(err).ToNot(HaveOccurred())
   412  					if hdr.Type == protocol.PacketType0RTT {
   413  						atomic.AddUint32(&num0RTTPackets, 1)
   414  					}
   415  				}
   416  				return rtt / 2
   417  			},
   418  			DropPacket: func(_ quicproxy.Direction, data []byte) bool {
   419  				if !wire.IsLongHeaderPacket(data[0]) {
   420  					return false
   421  				}
   422  				hdr, _, _, err := wire.ParsePacket(data)
   423  				Expect(err).ToNot(HaveOccurred())
   424  				if hdr.Type == protocol.PacketType0RTT {
   425  					// drop 25% of the 0-RTT packets
   426  					drop := mrand.Intn(4) == 0
   427  					if drop {
   428  						atomic.AddUint32(&num0RTTDropped, 1)
   429  					}
   430  					return drop
   431  				}
   432  				return false
   433  			},
   434  		})
   435  		Expect(err).ToNot(HaveOccurred())
   436  		defer proxy.Close()
   437  
   438  		transfer0RTTData(ln, proxy.LocalPort(), protocol.DefaultConnectionIDLength, clientConf, nil, PRData)
   439  
   440  		num0RTT := atomic.LoadUint32(&num0RTTPackets)
   441  		numDropped := atomic.LoadUint32(&num0RTTDropped)
   442  		fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets. Dropped %d of those.", num0RTT, numDropped)
   443  		Expect(numDropped).ToNot(BeZero())
   444  		Expect(num0RTT).ToNot(BeZero())
   445  		Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).ToNot(BeEmpty())
   446  	})
   447  
   448  	It("retransmits all 0-RTT data when the server performs a Retry", func() {
   449  		var mutex sync.Mutex
   450  		var firstConnID, secondConnID *protocol.ConnectionID
   451  		var firstCounter, secondCounter protocol.ByteCount
   452  
   453  		tlsConf := getTLSConfig()
   454  		clientConf := getTLSClientConfig()
   455  		dialAndReceiveSessionTicket(tlsConf, nil, clientConf)
   456  
   457  		countZeroRTTBytes := func(data []byte) (n protocol.ByteCount) {
   458  			for len(data) > 0 {
   459  				hdr, _, rest, err := wire.ParsePacket(data)
   460  				if err != nil {
   461  					return
   462  				}
   463  				data = rest
   464  				if hdr.Type == protocol.PacketType0RTT {
   465  					n += hdr.Length - 16 /* AEAD tag */
   466  				}
   467  			}
   468  			return
   469  		}
   470  
   471  		tracer := newPacketTracer()
   472  		ln, err := quic.ListenAddrEarly(
   473  			"localhost:0",
   474  			tlsConf,
   475  			getQuicConfig(&quic.Config{
   476  				RequireAddressValidation: func(net.Addr) bool { return true },
   477  				Allow0RTT:                true,
   478  				Tracer:                   newTracer(tracer),
   479  			}),
   480  		)
   481  		Expect(err).ToNot(HaveOccurred())
   482  		defer ln.Close()
   483  
   484  		proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
   485  			RemoteAddr: fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port),
   486  			DelayPacket: func(dir quicproxy.Direction, data []byte) time.Duration {
   487  				connID, err := wire.ParseConnectionID(data, 0)
   488  				Expect(err).ToNot(HaveOccurred())
   489  
   490  				mutex.Lock()
   491  				defer mutex.Unlock()
   492  
   493  				if zeroRTTBytes := countZeroRTTBytes(data); zeroRTTBytes > 0 {
   494  					if firstConnID == nil {
   495  						firstConnID = &connID
   496  						firstCounter += zeroRTTBytes
   497  					} else if firstConnID != nil && *firstConnID == connID {
   498  						Expect(secondConnID).To(BeNil())
   499  						firstCounter += zeroRTTBytes
   500  					} else if secondConnID == nil {
   501  						secondConnID = &connID
   502  						secondCounter += zeroRTTBytes
   503  					} else if secondConnID != nil && *secondConnID == connID {
   504  						secondCounter += zeroRTTBytes
   505  					} else {
   506  						Fail("received 3 connection IDs on 0-RTT packets")
   507  					}
   508  				}
   509  				return rtt / 2
   510  			},
   511  		})
   512  		Expect(err).ToNot(HaveOccurred())
   513  		defer proxy.Close()
   514  
   515  		transfer0RTTData(ln, proxy.LocalPort(), protocol.DefaultConnectionIDLength, clientConf, nil, GeneratePRData(5000)) // ~5 packets
   516  
   517  		mutex.Lock()
   518  		defer mutex.Unlock()
   519  		Expect(firstCounter).To(BeNumerically("~", 5000+100 /* framing overhead */, 100)) // the FIN bit might be sent extra
   520  		Expect(secondCounter).To(BeNumerically("~", firstCounter, 20))
   521  		zeroRTTPackets := get0RTTPackets(tracer.getRcvdLongHeaderPackets())
   522  		Expect(len(zeroRTTPackets)).To(BeNumerically(">=", 5))
   523  		Expect(zeroRTTPackets[0]).To(BeNumerically(">=", protocol.PacketNumber(5)))
   524  	})
   525  
   526  	It("doesn't reject 0-RTT when the server's transport stream limit increased", func() {
   527  		const maxStreams = 1
   528  		tlsConf := getTLSConfig()
   529  		clientConf := getTLSClientConfig()
   530  		dialAndReceiveSessionTicket(tlsConf, getQuicConfig(&quic.Config{
   531  			MaxIncomingUniStreams: maxStreams,
   532  		}), clientConf)
   533  
   534  		tracer := newPacketTracer()
   535  		ln, err := quic.ListenAddrEarly(
   536  			"localhost:0",
   537  			tlsConf,
   538  			getQuicConfig(&quic.Config{
   539  				MaxIncomingUniStreams: maxStreams + 1,
   540  				Allow0RTT:             true,
   541  				Tracer:                newTracer(tracer),
   542  			}),
   543  		)
   544  		Expect(err).ToNot(HaveOccurred())
   545  		defer ln.Close()
   546  		proxy, _ := runCountingProxy(ln.Addr().(*net.UDPAddr).Port)
   547  		defer proxy.Close()
   548  
   549  		conn, err := quic.DialAddrEarly(
   550  			context.Background(),
   551  			fmt.Sprintf("localhost:%d", proxy.LocalPort()),
   552  			clientConf,
   553  			getQuicConfig(nil),
   554  		)
   555  		Expect(err).ToNot(HaveOccurred())
   556  		str, err := conn.OpenUniStream()
   557  		Expect(err).ToNot(HaveOccurred())
   558  		_, err = str.Write([]byte("foobar"))
   559  		Expect(err).ToNot(HaveOccurred())
   560  		Expect(str.Close()).To(Succeed())
   561  		// The client remembers the old limit and refuses to open a new stream.
   562  		_, err = conn.OpenUniStream()
   563  		Expect(err).To(HaveOccurred())
   564  		Expect(err.Error()).To(ContainSubstring("too many open streams"))
   565  		ctx, cancel := context.WithTimeout(context.Background(), time.Second)
   566  		defer cancel()
   567  		_, err = conn.OpenUniStreamSync(ctx)
   568  		Expect(err).ToNot(HaveOccurred())
   569  		Expect(conn.ConnectionState().Used0RTT).To(BeTrue())
   570  		Expect(conn.CloseWithError(0, "")).To(Succeed())
   571  	})
   572  
   573  	It("rejects 0-RTT when the server's stream limit decreased", func() {
   574  		const maxStreams = 42
   575  		tlsConf := getTLSConfig()
   576  		clientConf := getTLSClientConfig()
   577  		dialAndReceiveSessionTicket(tlsConf, getQuicConfig(&quic.Config{
   578  			MaxIncomingStreams: maxStreams,
   579  		}), clientConf)
   580  
   581  		tracer := newPacketTracer()
   582  		ln, err := quic.ListenAddrEarly(
   583  			"localhost:0",
   584  			tlsConf,
   585  			getQuicConfig(&quic.Config{
   586  				MaxIncomingStreams: maxStreams - 1,
   587  				Allow0RTT:          true,
   588  				Tracer:             newTracer(tracer),
   589  			}),
   590  		)
   591  		Expect(err).ToNot(HaveOccurred())
   592  		defer ln.Close()
   593  		proxy, num0RTTPackets := runCountingProxy(ln.Addr().(*net.UDPAddr).Port)
   594  		defer proxy.Close()
   595  
   596  		check0RTTRejected(ln, proxy.LocalPort(), clientConf)
   597  
   598  		// The client should send 0-RTT packets, but the server doesn't process them.
   599  		num0RTT := atomic.LoadUint32(num0RTTPackets)
   600  		fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT)
   601  		Expect(num0RTT).ToNot(BeZero())
   602  		Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).To(BeEmpty())
   603  	})
   604  
   605  	It("rejects 0-RTT when the ALPN changed", func() {
   606  		tlsConf := getTLSConfig()
   607  		clientConf := getTLSClientConfig()
   608  		dialAndReceiveSessionTicket(tlsConf, nil, clientConf)
   609  
   610  		// switch to different ALPN on the server side
   611  		tlsConf.NextProtos = []string{"new-alpn"}
   612  		// Append to the client's ALPN.
   613  		// crypto/tls will attempt to resume with the ALPN from the original connection
   614  		clientConf.NextProtos = append(clientConf.NextProtos, "new-alpn")
   615  		tracer := newPacketTracer()
   616  		ln, err := quic.ListenAddrEarly(
   617  			"localhost:0",
   618  			tlsConf,
   619  			getQuicConfig(&quic.Config{
   620  				Allow0RTT: true,
   621  				Tracer:    newTracer(tracer),
   622  			}),
   623  		)
   624  		Expect(err).ToNot(HaveOccurred())
   625  		defer ln.Close()
   626  		proxy, num0RTTPackets := runCountingProxy(ln.Addr().(*net.UDPAddr).Port)
   627  		defer proxy.Close()
   628  
   629  		check0RTTRejected(ln, proxy.LocalPort(), clientConf)
   630  
   631  		// The client should send 0-RTT packets, but the server doesn't process them.
   632  		num0RTT := atomic.LoadUint32(num0RTTPackets)
   633  		fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT)
   634  		Expect(num0RTT).ToNot(BeZero())
   635  		Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).To(BeEmpty())
   636  	})
   637  
   638  	It("rejects 0-RTT when the application doesn't allow it", func() {
   639  		tlsConf := getTLSConfig()
   640  		clientConf := getTLSClientConfig()
   641  		dialAndReceiveSessionTicket(tlsConf, nil, clientConf)
   642  
   643  		// now close the listener and dial new connection with a different ALPN
   644  		tracer := newPacketTracer()
   645  		ln, err := quic.ListenAddrEarly(
   646  			"localhost:0",
   647  			tlsConf,
   648  			getQuicConfig(&quic.Config{
   649  				Allow0RTT: false, // application rejects 0-RTT
   650  				Tracer:    newTracer(tracer),
   651  			}),
   652  		)
   653  		Expect(err).ToNot(HaveOccurred())
   654  		defer ln.Close()
   655  		proxy, num0RTTPackets := runCountingProxy(ln.Addr().(*net.UDPAddr).Port)
   656  		defer proxy.Close()
   657  
   658  		check0RTTRejected(ln, proxy.LocalPort(), clientConf)
   659  
   660  		// The client should send 0-RTT packets, but the server doesn't process them.
   661  		num0RTT := atomic.LoadUint32(num0RTTPackets)
   662  		fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT)
   663  		Expect(num0RTT).ToNot(BeZero())
   664  		Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).To(BeEmpty())
   665  	})
   666  
   667  	DescribeTable("flow control limits",
   668  		func(addFlowControlLimit func(*quic.Config, uint64)) {
   669  			tracer := newPacketTracer()
   670  			firstConf := getQuicConfig(&quic.Config{Allow0RTT: true})
   671  			addFlowControlLimit(firstConf, 3)
   672  			tlsConf := getTLSConfig()
   673  			clientConf := getTLSClientConfig()
   674  			dialAndReceiveSessionTicket(tlsConf, firstConf, clientConf)
   675  
   676  			secondConf := getQuicConfig(&quic.Config{
   677  				Allow0RTT: true,
   678  				Tracer:    newTracer(tracer),
   679  			})
   680  			addFlowControlLimit(secondConf, 100)
   681  			ln, err := quic.ListenAddrEarly(
   682  				"localhost:0",
   683  				tlsConf,
   684  				secondConf,
   685  			)
   686  			Expect(err).ToNot(HaveOccurred())
   687  			defer ln.Close()
   688  			proxy, _ := runCountingProxy(ln.Addr().(*net.UDPAddr).Port)
   689  			defer proxy.Close()
   690  
   691  			conn, err := quic.DialAddrEarly(
   692  				context.Background(),
   693  				fmt.Sprintf("localhost:%d", proxy.LocalPort()),
   694  				clientConf,
   695  				getQuicConfig(nil),
   696  			)
   697  			Expect(err).ToNot(HaveOccurred())
   698  			str, err := conn.OpenUniStream()
   699  			Expect(err).ToNot(HaveOccurred())
   700  			written := make(chan struct{})
   701  			go func() {
   702  				defer GinkgoRecover()
   703  				defer close(written)
   704  				_, err := str.Write([]byte("foobar"))
   705  				Expect(err).ToNot(HaveOccurred())
   706  				Expect(str.Close()).To(Succeed())
   707  			}()
   708  
   709  			Eventually(written).Should(BeClosed())
   710  
   711  			serverConn, err := ln.Accept(context.Background())
   712  			Expect(err).ToNot(HaveOccurred())
   713  			rstr, err := serverConn.AcceptUniStream(context.Background())
   714  			Expect(err).ToNot(HaveOccurred())
   715  			data, err := io.ReadAll(rstr)
   716  			Expect(err).ToNot(HaveOccurred())
   717  			Expect(data).To(Equal([]byte("foobar")))
   718  			Expect(serverConn.ConnectionState().Used0RTT).To(BeTrue())
   719  			Expect(serverConn.CloseWithError(0, "")).To(Succeed())
   720  			Eventually(conn.Context().Done()).Should(BeClosed())
   721  
   722  			var processedFirst bool
   723  			for _, p := range tracer.getRcvdLongHeaderPackets() {
   724  				for _, f := range p.frames {
   725  					if sf, ok := f.(*logging.StreamFrame); ok {
   726  						if !processedFirst {
   727  							// The first STREAM should have been sent in a 0-RTT packet.
   728  							// Due to the flow control limit, the STREAM frame was limit to the first 3 bytes.
   729  							Expect(p.hdr.Type).To(Equal(protocol.PacketType0RTT))
   730  							Expect(sf.Length).To(BeEquivalentTo(3))
   731  							processedFirst = true
   732  						} else {
   733  							Fail("STREAM was shouldn't have been sent in 0-RTT")
   734  						}
   735  					}
   736  				}
   737  			}
   738  		},
   739  		Entry("doesn't reject 0-RTT when the server's transport stream flow control limit increased", func(c *quic.Config, limit uint64) { c.InitialStreamReceiveWindow = limit }),
   740  		Entry("doesn't reject 0-RTT when the server's transport connection flow control limit increased", func(c *quic.Config, limit uint64) { c.InitialConnectionReceiveWindow = limit }),
   741  	)
   742  
   743  	for _, l := range []int{0, 15} {
   744  		connIDLen := l
   745  
   746  		It(fmt.Sprintf("correctly deals with 0-RTT rejections, for %d byte connection IDs", connIDLen), func() {
   747  			tlsConf := getTLSConfig()
   748  			clientConf := getTLSClientConfig()
   749  			dialAndReceiveSessionTicket(tlsConf, nil, clientConf)
   750  			// now dial new connection with different transport parameters
   751  			tracer := newPacketTracer()
   752  			ln, err := quic.ListenAddrEarly(
   753  				"localhost:0",
   754  				tlsConf,
   755  				getQuicConfig(&quic.Config{
   756  					MaxIncomingUniStreams: 1,
   757  					Tracer:                newTracer(tracer),
   758  				}),
   759  			)
   760  			Expect(err).ToNot(HaveOccurred())
   761  			defer ln.Close()
   762  			proxy, num0RTTPackets := runCountingProxy(ln.Addr().(*net.UDPAddr).Port)
   763  			defer proxy.Close()
   764  
   765  			conn, err := quic.DialAddrEarly(
   766  				context.Background(),
   767  				fmt.Sprintf("localhost:%d", proxy.LocalPort()),
   768  				clientConf,
   769  				getQuicConfig(nil),
   770  			)
   771  			Expect(err).ToNot(HaveOccurred())
   772  			// The client remembers that it was allowed to open 2 uni-directional streams.
   773  			firstStr, err := conn.OpenUniStream()
   774  			Expect(err).ToNot(HaveOccurred())
   775  			written := make(chan struct{}, 2)
   776  			go func() {
   777  				defer GinkgoRecover()
   778  				defer func() { written <- struct{}{} }()
   779  				_, err := firstStr.Write([]byte("first flight"))
   780  				Expect(err).ToNot(HaveOccurred())
   781  			}()
   782  			secondStr, err := conn.OpenUniStream()
   783  			Expect(err).ToNot(HaveOccurred())
   784  			go func() {
   785  				defer GinkgoRecover()
   786  				defer func() { written <- struct{}{} }()
   787  				_, err := secondStr.Write([]byte("first flight"))
   788  				Expect(err).ToNot(HaveOccurred())
   789  			}()
   790  
   791  			ctx, cancel := context.WithTimeout(context.Background(), time.Second)
   792  			defer cancel()
   793  			_, err = conn.AcceptStream(ctx)
   794  			Expect(err).To(MatchError(quic.Err0RTTRejected))
   795  			Eventually(written).Should(Receive())
   796  			Eventually(written).Should(Receive())
   797  			_, err = firstStr.Write([]byte("foobar"))
   798  			Expect(err).To(MatchError(quic.Err0RTTRejected))
   799  			_, err = conn.OpenUniStream()
   800  			Expect(err).To(MatchError(quic.Err0RTTRejected))
   801  
   802  			_, err = conn.AcceptStream(ctx)
   803  			Expect(err).To(Equal(quic.Err0RTTRejected))
   804  
   805  			newConn := conn.NextConnection()
   806  			str, err := newConn.OpenUniStream()
   807  			Expect(err).ToNot(HaveOccurred())
   808  			_, err = newConn.OpenUniStream()
   809  			Expect(err).To(HaveOccurred())
   810  			Expect(err.Error()).To(ContainSubstring("too many open streams"))
   811  			_, err = str.Write([]byte("second flight"))
   812  			Expect(err).ToNot(HaveOccurred())
   813  			Expect(str.Close()).To(Succeed())
   814  			Expect(conn.CloseWithError(0, "")).To(Succeed())
   815  
   816  			// The client should send 0-RTT packets, but the server doesn't process them.
   817  			num0RTT := atomic.LoadUint32(num0RTTPackets)
   818  			fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT)
   819  			Expect(num0RTT).ToNot(BeZero())
   820  			Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).To(BeEmpty())
   821  		})
   822  	}
   823  
   824  	It("queues 0-RTT packets, if the Initial is delayed", func() {
   825  		tlsConf := getTLSConfig()
   826  		clientConf := getTLSClientConfig()
   827  		dialAndReceiveSessionTicket(tlsConf, nil, clientConf)
   828  
   829  		tracer := newPacketTracer()
   830  		ln, err := quic.ListenAddrEarly(
   831  			"localhost:0",
   832  			tlsConf,
   833  			getQuicConfig(&quic.Config{
   834  				Allow0RTT: true,
   835  				Tracer:    newTracer(tracer),
   836  			}),
   837  		)
   838  		Expect(err).ToNot(HaveOccurred())
   839  		defer ln.Close()
   840  		proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
   841  			RemoteAddr: ln.Addr().String(),
   842  			DelayPacket: func(dir quicproxy.Direction, data []byte) time.Duration {
   843  				if dir == quicproxy.DirectionIncoming && wire.IsLongHeaderPacket(data[0]) && data[0]&0x30>>4 == 0 { // Initial packet from client
   844  					return rtt/2 + rtt
   845  				}
   846  				return rtt / 2
   847  			},
   848  		})
   849  		Expect(err).ToNot(HaveOccurred())
   850  		defer proxy.Close()
   851  
   852  		transfer0RTTData(ln, proxy.LocalPort(), protocol.DefaultConnectionIDLength, clientConf, nil, PRData)
   853  
   854  		Expect(tracer.getRcvdLongHeaderPackets()[0].hdr.Type).To(Equal(protocol.PacketTypeInitial))
   855  		zeroRTTPackets := get0RTTPackets(tracer.getRcvdLongHeaderPackets())
   856  		Expect(len(zeroRTTPackets)).To(BeNumerically(">", 10))
   857  		Expect(zeroRTTPackets[0]).To(Equal(protocol.PacketNumber(0)))
   858  	})
   859  
   860  	It("allows the application to attach data to the session ticket, for the server", func() {
   861  		tlsConf := getTLSConfig()
   862  		tlsConf.WrapSession = func(cs tls.ConnectionState, ss *tls.SessionState) ([]byte, error) {
   863  			ss.Extra = append(ss.Extra, []byte("foobar"))
   864  			return tlsConf.EncryptTicket(cs, ss)
   865  		}
   866  		var unwrapped bool
   867  		tlsConf.UnwrapSession = func(identity []byte, cs tls.ConnectionState) (*tls.SessionState, error) {
   868  			defer GinkgoRecover()
   869  			state, err := tlsConf.DecryptTicket(identity, cs)
   870  			if err != nil {
   871  				return nil, err
   872  			}
   873  			Expect(state.Extra).To(HaveLen(2))
   874  			Expect(state.Extra[1]).To(Equal([]byte("foobar")))
   875  			unwrapped = true
   876  			return state, nil
   877  		}
   878  		clientTLSConf := getTLSClientConfig()
   879  		dialAndReceiveSessionTicket(tlsConf, nil, clientTLSConf)
   880  
   881  		tracer := newPacketTracer()
   882  		ln, err := quic.ListenAddrEarly(
   883  			"localhost:0",
   884  			tlsConf,
   885  			getQuicConfig(&quic.Config{
   886  				Allow0RTT: true,
   887  				Tracer:    newTracer(tracer),
   888  			}),
   889  		)
   890  		Expect(err).ToNot(HaveOccurred())
   891  		defer ln.Close()
   892  
   893  		transfer0RTTData(
   894  			ln,
   895  			ln.Addr().(*net.UDPAddr).Port,
   896  			10,
   897  			clientTLSConf,
   898  			getQuicConfig(nil),
   899  			PRData,
   900  		)
   901  		Expect(unwrapped).To(BeTrue())
   902  	})
   903  
   904  	It("allows the application to attach data to the session ticket, for the client", func() {
   905  		tlsConf := getTLSConfig()
   906  		clientTLSConf := getTLSClientConfig()
   907  		var restored bool
   908  		clientTLSConf.ClientSessionCache = &metadataClientSessionCache{
   909  			toAdd: []byte("foobar"),
   910  			restored: func(b []byte) {
   911  				defer GinkgoRecover()
   912  				Expect(b).To(Equal([]byte("foobar")))
   913  				restored = true
   914  			},
   915  			cache: tls.NewLRUClientSessionCache(100),
   916  		}
   917  		dialAndReceiveSessionTicket(tlsConf, nil, clientTLSConf)
   918  
   919  		tracer := newPacketTracer()
   920  		ln, err := quic.ListenAddrEarly(
   921  			"localhost:0",
   922  			tlsConf,
   923  			getQuicConfig(&quic.Config{
   924  				Allow0RTT: true,
   925  				Tracer:    newTracer(tracer),
   926  			}),
   927  		)
   928  		Expect(err).ToNot(HaveOccurred())
   929  		defer ln.Close()
   930  
   931  		transfer0RTTData(
   932  			ln,
   933  			ln.Addr().(*net.UDPAddr).Port,
   934  			10,
   935  			clientTLSConf,
   936  			getQuicConfig(nil),
   937  			PRData,
   938  		)
   939  		Expect(restored).To(BeTrue())
   940  	})
   941  })