github.com/MerlinKodo/quic-go@v0.39.2/integrationtests/self/zero_rtt_test.go (about)

     1  //go:build go1.21
     2  
     3  package self_test
     4  
     5  import (
     6  	"context"
     7  	"crypto/tls"
     8  	"fmt"
     9  	"io"
    10  	mrand "math/rand"
    11  	"net"
    12  	"sync"
    13  	"sync/atomic"
    14  	"time"
    15  
    16  	"github.com/MerlinKodo/quic-go"
    17  	quicproxy "github.com/MerlinKodo/quic-go/integrationtests/tools/proxy"
    18  	"github.com/MerlinKodo/quic-go/internal/protocol"
    19  	"github.com/MerlinKodo/quic-go/internal/wire"
    20  	"github.com/MerlinKodo/quic-go/logging"
    21  
    22  	. "github.com/onsi/ginkgo/v2"
    23  	. "github.com/onsi/gomega"
    24  )
    25  
    26  type metadataClientSessionCache struct {
    27  	toAdd    []byte
    28  	restored func([]byte)
    29  
    30  	cache tls.ClientSessionCache
    31  }
    32  
    33  func (m metadataClientSessionCache) Get(key string) (*tls.ClientSessionState, bool) {
    34  	session, ok := m.cache.Get(key)
    35  	if !ok || session == nil {
    36  		return session, ok
    37  	}
    38  	ticket, state, err := session.ResumptionState()
    39  	Expect(err).ToNot(HaveOccurred())
    40  	Expect(state.Extra).To(HaveLen(2)) // ours, and the quic-go's
    41  	m.restored(state.Extra[1])
    42  	session, err = tls.NewResumptionState(ticket, state)
    43  	Expect(err).ToNot(HaveOccurred())
    44  	return session, true
    45  }
    46  
    47  func (m metadataClientSessionCache) Put(key string, session *tls.ClientSessionState) {
    48  	ticket, state, err := session.ResumptionState()
    49  	Expect(err).ToNot(HaveOccurred())
    50  	state.Extra = append(state.Extra, m.toAdd)
    51  	session, err = tls.NewResumptionState(ticket, state)
    52  	Expect(err).ToNot(HaveOccurred())
    53  	m.cache.Put(key, session)
    54  }
    55  
    56  var _ = Describe("0-RTT", func() {
    57  	rtt := scaleDuration(5 * time.Millisecond)
    58  
    59  	runCountingProxy := func(serverPort int) (*quicproxy.QuicProxy, *uint32) {
    60  		var num0RTTPackets uint32 // to be used as an atomic
    61  		proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
    62  			RemoteAddr: fmt.Sprintf("localhost:%d", serverPort),
    63  			DelayPacket: func(_ quicproxy.Direction, data []byte) time.Duration {
    64  				for len(data) > 0 {
    65  					if !wire.IsLongHeaderPacket(data[0]) {
    66  						break
    67  					}
    68  					hdr, _, rest, err := wire.ParsePacket(data)
    69  					Expect(err).ToNot(HaveOccurred())
    70  					if hdr.Type == protocol.PacketType0RTT {
    71  						atomic.AddUint32(&num0RTTPackets, 1)
    72  						break
    73  					}
    74  					data = rest
    75  				}
    76  				return rtt / 2
    77  			},
    78  		})
    79  		Expect(err).ToNot(HaveOccurred())
    80  
    81  		return proxy, &num0RTTPackets
    82  	}
    83  
    84  	dialAndReceiveSessionTicket := func(serverTLSConf *tls.Config, serverConf *quic.Config, clientTLSConf *tls.Config) {
    85  		if serverConf == nil {
    86  			serverConf = getQuicConfig(nil)
    87  		}
    88  		serverConf.Allow0RTT = true
    89  		ln, err := quic.ListenAddrEarly(
    90  			"localhost:0",
    91  			serverTLSConf,
    92  			serverConf,
    93  		)
    94  		Expect(err).ToNot(HaveOccurred())
    95  		defer ln.Close()
    96  
    97  		proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
    98  			RemoteAddr:  fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port),
    99  			DelayPacket: func(_ quicproxy.Direction, data []byte) time.Duration { return rtt / 2 },
   100  		})
   101  		Expect(err).ToNot(HaveOccurred())
   102  		defer proxy.Close()
   103  
   104  		// dial the first connection in order to receive a session ticket
   105  		done := make(chan struct{})
   106  		go func() {
   107  			defer GinkgoRecover()
   108  			defer close(done)
   109  			conn, err := ln.Accept(context.Background())
   110  			Expect(err).ToNot(HaveOccurred())
   111  			<-conn.Context().Done()
   112  		}()
   113  
   114  		puts := make(chan string, 100)
   115  		cache := clientTLSConf.ClientSessionCache
   116  		if cache == nil {
   117  			cache = tls.NewLRUClientSessionCache(100)
   118  		}
   119  		clientTLSConf.ClientSessionCache = newClientSessionCache(cache, make(chan string, 100), puts)
   120  		conn, err := quic.DialAddr(
   121  			context.Background(),
   122  			fmt.Sprintf("localhost:%d", proxy.LocalPort()),
   123  			clientTLSConf,
   124  			getQuicConfig(nil),
   125  		)
   126  		Expect(err).ToNot(HaveOccurred())
   127  		Eventually(puts).Should(Receive())
   128  		// received the session ticket. We're done here.
   129  		Expect(conn.CloseWithError(0, "")).To(Succeed())
   130  		Eventually(done).Should(BeClosed())
   131  	}
   132  
   133  	transfer0RTTData := func(
   134  		ln *quic.EarlyListener,
   135  		proxyPort int,
   136  		connIDLen int,
   137  		clientTLSConf *tls.Config,
   138  		clientConf *quic.Config,
   139  		testdata []byte, // data to transfer
   140  	) {
   141  		// accept the second connection, and receive the data sent in 0-RTT
   142  		done := make(chan struct{})
   143  		go func() {
   144  			defer GinkgoRecover()
   145  			conn, err := ln.Accept(context.Background())
   146  			Expect(err).ToNot(HaveOccurred())
   147  			str, err := conn.AcceptStream(context.Background())
   148  			Expect(err).ToNot(HaveOccurred())
   149  			data, err := io.ReadAll(str)
   150  			Expect(err).ToNot(HaveOccurred())
   151  			Expect(data).To(Equal(testdata))
   152  			Expect(str.Close()).To(Succeed())
   153  			Expect(conn.ConnectionState().Used0RTT).To(BeTrue())
   154  			<-conn.Context().Done()
   155  			close(done)
   156  		}()
   157  
   158  		if clientConf == nil {
   159  			clientConf = getQuicConfig(nil)
   160  		}
   161  		var conn quic.EarlyConnection
   162  		if connIDLen == 0 {
   163  			var err error
   164  			conn, err = quic.DialAddrEarly(
   165  				context.Background(),
   166  				fmt.Sprintf("localhost:%d", proxyPort),
   167  				clientTLSConf,
   168  				clientConf,
   169  			)
   170  			Expect(err).ToNot(HaveOccurred())
   171  		} else {
   172  			addr, err := net.ResolveUDPAddr("udp", "localhost:0")
   173  			Expect(err).ToNot(HaveOccurred())
   174  			udpConn, err := net.ListenUDP("udp", addr)
   175  			Expect(err).ToNot(HaveOccurred())
   176  			defer udpConn.Close()
   177  			tr := &quic.Transport{
   178  				Conn:               udpConn,
   179  				ConnectionIDLength: connIDLen,
   180  			}
   181  			defer tr.Close()
   182  			conn, err = tr.DialEarly(
   183  				context.Background(),
   184  				&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: proxyPort},
   185  				clientTLSConf,
   186  				clientConf,
   187  			)
   188  			Expect(err).ToNot(HaveOccurred())
   189  		}
   190  		defer conn.CloseWithError(0, "")
   191  		str, err := conn.OpenStream()
   192  		Expect(err).ToNot(HaveOccurred())
   193  		_, err = str.Write(testdata)
   194  		Expect(err).ToNot(HaveOccurred())
   195  		Expect(str.Close()).To(Succeed())
   196  		<-conn.HandshakeComplete()
   197  		Expect(conn.ConnectionState().Used0RTT).To(BeTrue())
   198  		io.ReadAll(str) // wait for the EOF from the server to arrive before closing the conn
   199  		conn.CloseWithError(0, "")
   200  		Eventually(done).Should(BeClosed())
   201  		Eventually(conn.Context().Done()).Should(BeClosed())
   202  	}
   203  
   204  	check0RTTRejected := func(
   205  		ln *quic.EarlyListener,
   206  		proxyPort int,
   207  		clientConf *tls.Config,
   208  	) {
   209  		conn, err := quic.DialAddrEarly(
   210  			context.Background(),
   211  			fmt.Sprintf("localhost:%d", proxyPort),
   212  			clientConf,
   213  			getQuicConfig(nil),
   214  		)
   215  		Expect(err).ToNot(HaveOccurred())
   216  		str, err := conn.OpenUniStream()
   217  		Expect(err).ToNot(HaveOccurred())
   218  		_, err = str.Write(make([]byte, 3000))
   219  		Expect(err).ToNot(HaveOccurred())
   220  		Expect(str.Close()).To(Succeed())
   221  		Expect(conn.ConnectionState().Used0RTT).To(BeFalse())
   222  
   223  		// make sure the server doesn't process the data
   224  		ctx, cancel := context.WithTimeout(context.Background(), scaleDuration(50*time.Millisecond))
   225  		defer cancel()
   226  		serverConn, err := ln.Accept(ctx)
   227  		Expect(err).ToNot(HaveOccurred())
   228  		Expect(serverConn.ConnectionState().Used0RTT).To(BeFalse())
   229  		_, err = serverConn.AcceptUniStream(ctx)
   230  		Expect(err).To(Equal(context.DeadlineExceeded))
   231  		Expect(serverConn.CloseWithError(0, "")).To(Succeed())
   232  		Eventually(conn.Context().Done()).Should(BeClosed())
   233  	}
   234  
   235  	// can be used to extract 0-RTT from a packetCounter
   236  	get0RTTPackets := func(packets []packet) []protocol.PacketNumber {
   237  		var zeroRTTPackets []protocol.PacketNumber
   238  		for _, p := range packets {
   239  			if p.hdr.Type == protocol.PacketType0RTT {
   240  				zeroRTTPackets = append(zeroRTTPackets, p.hdr.PacketNumber)
   241  			}
   242  		}
   243  		return zeroRTTPackets
   244  	}
   245  
   246  	for _, l := range []int{0, 15} {
   247  		connIDLen := l
   248  
   249  		It(fmt.Sprintf("transfers 0-RTT data, with %d byte connection IDs", connIDLen), func() {
   250  			tlsConf := getTLSConfig()
   251  			clientTLSConf := getTLSClientConfig()
   252  			dialAndReceiveSessionTicket(tlsConf, nil, clientTLSConf)
   253  
   254  			counter, tracer := newPacketTracer()
   255  			ln, err := quic.ListenAddrEarly(
   256  				"localhost:0",
   257  				tlsConf,
   258  				getQuicConfig(&quic.Config{
   259  					Allow0RTT: true,
   260  					Tracer:    newTracer(tracer),
   261  				}),
   262  			)
   263  			Expect(err).ToNot(HaveOccurred())
   264  			defer ln.Close()
   265  
   266  			proxy, num0RTTPackets := runCountingProxy(ln.Addr().(*net.UDPAddr).Port)
   267  			defer proxy.Close()
   268  
   269  			transfer0RTTData(
   270  				ln,
   271  				proxy.LocalPort(),
   272  				connIDLen,
   273  				clientTLSConf,
   274  				getQuicConfig(nil),
   275  				PRData,
   276  			)
   277  
   278  			var numNewConnIDs int
   279  			for _, p := range counter.getRcvdLongHeaderPackets() {
   280  				for _, f := range p.frames {
   281  					if _, ok := f.(*logging.NewConnectionIDFrame); ok {
   282  						numNewConnIDs++
   283  					}
   284  				}
   285  			}
   286  			if connIDLen == 0 {
   287  				Expect(numNewConnIDs).To(BeZero())
   288  			} else {
   289  				Expect(numNewConnIDs).ToNot(BeZero())
   290  			}
   291  
   292  			num0RTT := atomic.LoadUint32(num0RTTPackets)
   293  			fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT)
   294  			Expect(num0RTT).ToNot(BeZero())
   295  			zeroRTTPackets := get0RTTPackets(counter.getRcvdLongHeaderPackets())
   296  			Expect(len(zeroRTTPackets)).To(BeNumerically(">", 10))
   297  			Expect(zeroRTTPackets).To(ContainElement(protocol.PacketNumber(0)))
   298  		})
   299  	}
   300  
   301  	// Test that data intended to be sent with 1-RTT protection is not sent in 0-RTT packets.
   302  	It("waits for a connection until the handshake is done", func() {
   303  		tlsConf := getTLSConfig()
   304  		clientConf := getTLSClientConfig()
   305  		dialAndReceiveSessionTicket(tlsConf, nil, clientConf)
   306  
   307  		zeroRTTData := GeneratePRData(5 << 10)
   308  		oneRTTData := PRData
   309  
   310  		counter, tracer := newPacketTracer()
   311  		ln, err := quic.ListenAddrEarly(
   312  			"localhost:0",
   313  			tlsConf,
   314  			getQuicConfig(&quic.Config{
   315  				Allow0RTT: true,
   316  				Tracer:    newTracer(tracer),
   317  			}),
   318  		)
   319  		Expect(err).ToNot(HaveOccurred())
   320  		defer ln.Close()
   321  
   322  		// now accept the second connection, and receive the 0-RTT data
   323  		go func() {
   324  			defer GinkgoRecover()
   325  			conn, err := ln.Accept(context.Background())
   326  			Expect(err).ToNot(HaveOccurred())
   327  			str, err := conn.AcceptUniStream(context.Background())
   328  			Expect(err).ToNot(HaveOccurred())
   329  			data, err := io.ReadAll(str)
   330  			Expect(err).ToNot(HaveOccurred())
   331  			Expect(data).To(Equal(zeroRTTData))
   332  			str, err = conn.AcceptUniStream(context.Background())
   333  			Expect(err).ToNot(HaveOccurred())
   334  			data, err = io.ReadAll(str)
   335  			Expect(err).ToNot(HaveOccurred())
   336  			Expect(data).To(Equal(oneRTTData))
   337  			Expect(conn.CloseWithError(0, "")).To(Succeed())
   338  		}()
   339  
   340  		proxy, _ := runCountingProxy(ln.Addr().(*net.UDPAddr).Port)
   341  		defer proxy.Close()
   342  
   343  		conn, err := quic.DialAddrEarly(
   344  			context.Background(),
   345  			fmt.Sprintf("localhost:%d", proxy.LocalPort()),
   346  			clientConf,
   347  			getQuicConfig(nil),
   348  		)
   349  		Expect(err).ToNot(HaveOccurred())
   350  		firstStr, err := conn.OpenUniStream()
   351  		Expect(err).ToNot(HaveOccurred())
   352  		_, err = firstStr.Write(zeroRTTData)
   353  		Expect(err).ToNot(HaveOccurred())
   354  		Expect(firstStr.Close()).To(Succeed())
   355  
   356  		// wait for the handshake to complete
   357  		Eventually(conn.HandshakeComplete()).Should(BeClosed())
   358  		str, err := conn.OpenUniStream()
   359  		Expect(err).ToNot(HaveOccurred())
   360  		_, err = str.Write(PRData)
   361  		Expect(err).ToNot(HaveOccurred())
   362  		Expect(str.Close()).To(Succeed())
   363  		<-conn.Context().Done()
   364  
   365  		// check that 0-RTT packets only contain STREAM frames for the first stream
   366  		var num0RTT int
   367  		for _, p := range counter.getRcvdLongHeaderPackets() {
   368  			if p.hdr.Header.Type != protocol.PacketType0RTT {
   369  				continue
   370  			}
   371  			for _, f := range p.frames {
   372  				sf, ok := f.(*logging.StreamFrame)
   373  				if !ok {
   374  					continue
   375  				}
   376  				num0RTT++
   377  				Expect(sf.StreamID).To(Equal(firstStr.StreamID()))
   378  			}
   379  		}
   380  		fmt.Fprintf(GinkgoWriter, "received %d STREAM frames in 0-RTT packets\n", num0RTT)
   381  		Expect(num0RTT).ToNot(BeZero())
   382  	})
   383  
   384  	It("transfers 0-RTT data, when 0-RTT packets are lost", func() {
   385  		var (
   386  			num0RTTPackets uint32 // to be used as an atomic
   387  			num0RTTDropped uint32
   388  		)
   389  
   390  		tlsConf := getTLSConfig()
   391  		clientConf := getTLSClientConfig()
   392  		dialAndReceiveSessionTicket(tlsConf, nil, clientConf)
   393  
   394  		counter, tracer := newPacketTracer()
   395  		ln, err := quic.ListenAddrEarly(
   396  			"localhost:0",
   397  			tlsConf,
   398  			getQuicConfig(&quic.Config{
   399  				Allow0RTT: true,
   400  				Tracer:    newTracer(tracer),
   401  			}),
   402  		)
   403  		Expect(err).ToNot(HaveOccurred())
   404  		defer ln.Close()
   405  
   406  		proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
   407  			RemoteAddr: fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port),
   408  			DelayPacket: func(_ quicproxy.Direction, data []byte) time.Duration {
   409  				if wire.IsLongHeaderPacket(data[0]) {
   410  					hdr, _, _, err := wire.ParsePacket(data)
   411  					Expect(err).ToNot(HaveOccurred())
   412  					if hdr.Type == protocol.PacketType0RTT {
   413  						atomic.AddUint32(&num0RTTPackets, 1)
   414  					}
   415  				}
   416  				return rtt / 2
   417  			},
   418  			DropPacket: func(_ quicproxy.Direction, data []byte) bool {
   419  				if !wire.IsLongHeaderPacket(data[0]) {
   420  					return false
   421  				}
   422  				hdr, _, _, err := wire.ParsePacket(data)
   423  				Expect(err).ToNot(HaveOccurred())
   424  				if hdr.Type == protocol.PacketType0RTT {
   425  					// drop 25% of the 0-RTT packets
   426  					drop := mrand.Intn(4) == 0
   427  					if drop {
   428  						atomic.AddUint32(&num0RTTDropped, 1)
   429  					}
   430  					return drop
   431  				}
   432  				return false
   433  			},
   434  		})
   435  		Expect(err).ToNot(HaveOccurred())
   436  		defer proxy.Close()
   437  
   438  		transfer0RTTData(ln, proxy.LocalPort(), protocol.DefaultConnectionIDLength, clientConf, nil, PRData)
   439  
   440  		num0RTT := atomic.LoadUint32(&num0RTTPackets)
   441  		numDropped := atomic.LoadUint32(&num0RTTDropped)
   442  		fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets. Dropped %d of those.", num0RTT, numDropped)
   443  		Expect(numDropped).ToNot(BeZero())
   444  		Expect(num0RTT).ToNot(BeZero())
   445  		Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).ToNot(BeEmpty())
   446  	})
   447  
   448  	It("retransmits all 0-RTT data when the server performs a Retry", func() {
   449  		var mutex sync.Mutex
   450  		var firstConnID, secondConnID *protocol.ConnectionID
   451  		var firstCounter, secondCounter protocol.ByteCount
   452  
   453  		tlsConf := getTLSConfig()
   454  		clientConf := getTLSClientConfig()
   455  		dialAndReceiveSessionTicket(tlsConf, nil, clientConf)
   456  
   457  		countZeroRTTBytes := func(data []byte) (n protocol.ByteCount) {
   458  			for len(data) > 0 {
   459  				hdr, _, rest, err := wire.ParsePacket(data)
   460  				if err != nil {
   461  					return
   462  				}
   463  				data = rest
   464  				if hdr.Type == protocol.PacketType0RTT {
   465  					n += hdr.Length - 16 /* AEAD tag */
   466  				}
   467  			}
   468  			return
   469  		}
   470  
   471  		counter, tracer := newPacketTracer()
   472  		ln, err := quic.ListenAddrEarly(
   473  			"localhost:0",
   474  			tlsConf,
   475  			getQuicConfig(&quic.Config{
   476  				RequireAddressValidation: func(net.Addr) bool { return true },
   477  				Allow0RTT:                true,
   478  				Tracer:                   newTracer(tracer),
   479  			}),
   480  		)
   481  		Expect(err).ToNot(HaveOccurred())
   482  		defer ln.Close()
   483  
   484  		proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
   485  			RemoteAddr: fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port),
   486  			DelayPacket: func(dir quicproxy.Direction, data []byte) time.Duration {
   487  				connID, err := wire.ParseConnectionID(data, 0)
   488  				Expect(err).ToNot(HaveOccurred())
   489  
   490  				mutex.Lock()
   491  				defer mutex.Unlock()
   492  
   493  				if zeroRTTBytes := countZeroRTTBytes(data); zeroRTTBytes > 0 {
   494  					if firstConnID == nil {
   495  						firstConnID = &connID
   496  						firstCounter += zeroRTTBytes
   497  					} else if firstConnID != nil && *firstConnID == connID {
   498  						Expect(secondConnID).To(BeNil())
   499  						firstCounter += zeroRTTBytes
   500  					} else if secondConnID == nil {
   501  						secondConnID = &connID
   502  						secondCounter += zeroRTTBytes
   503  					} else if secondConnID != nil && *secondConnID == connID {
   504  						secondCounter += zeroRTTBytes
   505  					} else {
   506  						Fail("received 3 connection IDs on 0-RTT packets")
   507  					}
   508  				}
   509  				return rtt / 2
   510  			},
   511  		})
   512  		Expect(err).ToNot(HaveOccurred())
   513  		defer proxy.Close()
   514  
   515  		transfer0RTTData(ln, proxy.LocalPort(), protocol.DefaultConnectionIDLength, clientConf, nil, GeneratePRData(5000)) // ~5 packets
   516  
   517  		mutex.Lock()
   518  		defer mutex.Unlock()
   519  		Expect(firstCounter).To(BeNumerically("~", 5000+100 /* framing overhead */, 100)) // the FIN bit might be sent extra
   520  		Expect(secondCounter).To(BeNumerically("~", firstCounter, 20))
   521  		zeroRTTPackets := get0RTTPackets(counter.getRcvdLongHeaderPackets())
   522  		Expect(len(zeroRTTPackets)).To(BeNumerically(">=", 5))
   523  		Expect(zeroRTTPackets[0]).To(BeNumerically(">=", protocol.PacketNumber(5)))
   524  	})
   525  
   526  	It("doesn't reject 0-RTT when the server's transport stream limit increased", func() {
   527  		const maxStreams = 1
   528  		tlsConf := getTLSConfig()
   529  		clientConf := getTLSClientConfig()
   530  		dialAndReceiveSessionTicket(tlsConf, getQuicConfig(&quic.Config{
   531  			MaxIncomingUniStreams: maxStreams,
   532  		}), clientConf)
   533  
   534  		ln, err := quic.ListenAddrEarly(
   535  			"localhost:0",
   536  			tlsConf,
   537  			getQuicConfig(&quic.Config{
   538  				MaxIncomingUniStreams: maxStreams + 1,
   539  				Allow0RTT:             true,
   540  			}),
   541  		)
   542  		Expect(err).ToNot(HaveOccurred())
   543  		defer ln.Close()
   544  		proxy, _ := runCountingProxy(ln.Addr().(*net.UDPAddr).Port)
   545  		defer proxy.Close()
   546  
   547  		conn, err := quic.DialAddrEarly(
   548  			context.Background(),
   549  			fmt.Sprintf("localhost:%d", proxy.LocalPort()),
   550  			clientConf,
   551  			getQuicConfig(nil),
   552  		)
   553  		Expect(err).ToNot(HaveOccurred())
   554  		str, err := conn.OpenUniStream()
   555  		Expect(err).ToNot(HaveOccurred())
   556  		_, err = str.Write([]byte("foobar"))
   557  		Expect(err).ToNot(HaveOccurred())
   558  		Expect(str.Close()).To(Succeed())
   559  		// The client remembers the old limit and refuses to open a new stream.
   560  		_, err = conn.OpenUniStream()
   561  		Expect(err).To(HaveOccurred())
   562  		Expect(err.Error()).To(ContainSubstring("too many open streams"))
   563  		ctx, cancel := context.WithTimeout(context.Background(), time.Second)
   564  		defer cancel()
   565  		_, err = conn.OpenUniStreamSync(ctx)
   566  		Expect(err).ToNot(HaveOccurred())
   567  		Expect(conn.ConnectionState().Used0RTT).To(BeTrue())
   568  		Expect(conn.CloseWithError(0, "")).To(Succeed())
   569  	})
   570  
   571  	It("rejects 0-RTT when the server's stream limit decreased", func() {
   572  		const maxStreams = 42
   573  		tlsConf := getTLSConfig()
   574  		clientConf := getTLSClientConfig()
   575  		dialAndReceiveSessionTicket(tlsConf, getQuicConfig(&quic.Config{
   576  			MaxIncomingStreams: maxStreams,
   577  		}), clientConf)
   578  
   579  		counter, tracer := newPacketTracer()
   580  		ln, err := quic.ListenAddrEarly(
   581  			"localhost:0",
   582  			tlsConf,
   583  			getQuicConfig(&quic.Config{
   584  				MaxIncomingStreams: maxStreams - 1,
   585  				Allow0RTT:          true,
   586  				Tracer:             newTracer(tracer),
   587  			}),
   588  		)
   589  		Expect(err).ToNot(HaveOccurred())
   590  		defer ln.Close()
   591  		proxy, num0RTTPackets := runCountingProxy(ln.Addr().(*net.UDPAddr).Port)
   592  		defer proxy.Close()
   593  
   594  		check0RTTRejected(ln, proxy.LocalPort(), clientConf)
   595  
   596  		// The client should send 0-RTT packets, but the server doesn't process them.
   597  		num0RTT := atomic.LoadUint32(num0RTTPackets)
   598  		fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT)
   599  		Expect(num0RTT).ToNot(BeZero())
   600  		Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty())
   601  	})
   602  
   603  	It("rejects 0-RTT when the ALPN changed", func() {
   604  		tlsConf := getTLSConfig()
   605  		clientConf := getTLSClientConfig()
   606  		dialAndReceiveSessionTicket(tlsConf, nil, clientConf)
   607  
   608  		// switch to different ALPN on the server side
   609  		tlsConf.NextProtos = []string{"new-alpn"}
   610  		// Append to the client's ALPN.
   611  		// crypto/tls will attempt to resume with the ALPN from the original connection
   612  		clientConf.NextProtos = append(clientConf.NextProtos, "new-alpn")
   613  		counter, tracer := newPacketTracer()
   614  		ln, err := quic.ListenAddrEarly(
   615  			"localhost:0",
   616  			tlsConf,
   617  			getQuicConfig(&quic.Config{
   618  				Allow0RTT: true,
   619  				Tracer:    newTracer(tracer),
   620  			}),
   621  		)
   622  		Expect(err).ToNot(HaveOccurred())
   623  		defer ln.Close()
   624  		proxy, num0RTTPackets := runCountingProxy(ln.Addr().(*net.UDPAddr).Port)
   625  		defer proxy.Close()
   626  
   627  		check0RTTRejected(ln, proxy.LocalPort(), clientConf)
   628  
   629  		// The client should send 0-RTT packets, but the server doesn't process them.
   630  		num0RTT := atomic.LoadUint32(num0RTTPackets)
   631  		fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT)
   632  		Expect(num0RTT).ToNot(BeZero())
   633  		Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty())
   634  	})
   635  
   636  	It("rejects 0-RTT when the application doesn't allow it", func() {
   637  		tlsConf := getTLSConfig()
   638  		clientConf := getTLSClientConfig()
   639  		dialAndReceiveSessionTicket(tlsConf, nil, clientConf)
   640  
   641  		// now close the listener and dial new connection with a different ALPN
   642  		counter, tracer := newPacketTracer()
   643  		ln, err := quic.ListenAddrEarly(
   644  			"localhost:0",
   645  			tlsConf,
   646  			getQuicConfig(&quic.Config{
   647  				Allow0RTT: false, // application rejects 0-RTT
   648  				Tracer:    newTracer(tracer),
   649  			}),
   650  		)
   651  		Expect(err).ToNot(HaveOccurred())
   652  		defer ln.Close()
   653  		proxy, num0RTTPackets := runCountingProxy(ln.Addr().(*net.UDPAddr).Port)
   654  		defer proxy.Close()
   655  
   656  		check0RTTRejected(ln, proxy.LocalPort(), clientConf)
   657  
   658  		// The client should send 0-RTT packets, but the server doesn't process them.
   659  		num0RTT := atomic.LoadUint32(num0RTTPackets)
   660  		fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT)
   661  		Expect(num0RTT).ToNot(BeZero())
   662  		Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty())
   663  	})
   664  
   665  	DescribeTable("flow control limits",
   666  		func(addFlowControlLimit func(*quic.Config, uint64)) {
   667  			counter, tracer := newPacketTracer()
   668  			firstConf := getQuicConfig(&quic.Config{Allow0RTT: true})
   669  			addFlowControlLimit(firstConf, 3)
   670  			tlsConf := getTLSConfig()
   671  			clientConf := getTLSClientConfig()
   672  			dialAndReceiveSessionTicket(tlsConf, firstConf, clientConf)
   673  
   674  			secondConf := getQuicConfig(&quic.Config{
   675  				Allow0RTT: true,
   676  				Tracer:    newTracer(tracer),
   677  			})
   678  			addFlowControlLimit(secondConf, 100)
   679  			ln, err := quic.ListenAddrEarly(
   680  				"localhost:0",
   681  				tlsConf,
   682  				secondConf,
   683  			)
   684  			Expect(err).ToNot(HaveOccurred())
   685  			defer ln.Close()
   686  			proxy, _ := runCountingProxy(ln.Addr().(*net.UDPAddr).Port)
   687  			defer proxy.Close()
   688  
   689  			conn, err := quic.DialAddrEarly(
   690  				context.Background(),
   691  				fmt.Sprintf("localhost:%d", proxy.LocalPort()),
   692  				clientConf,
   693  				getQuicConfig(nil),
   694  			)
   695  			Expect(err).ToNot(HaveOccurred())
   696  			str, err := conn.OpenUniStream()
   697  			Expect(err).ToNot(HaveOccurred())
   698  			written := make(chan struct{})
   699  			go func() {
   700  				defer GinkgoRecover()
   701  				defer close(written)
   702  				_, err := str.Write([]byte("foobar"))
   703  				Expect(err).ToNot(HaveOccurred())
   704  				Expect(str.Close()).To(Succeed())
   705  			}()
   706  
   707  			Eventually(written).Should(BeClosed())
   708  
   709  			serverConn, err := ln.Accept(context.Background())
   710  			Expect(err).ToNot(HaveOccurred())
   711  			rstr, err := serverConn.AcceptUniStream(context.Background())
   712  			Expect(err).ToNot(HaveOccurred())
   713  			data, err := io.ReadAll(rstr)
   714  			Expect(err).ToNot(HaveOccurred())
   715  			Expect(data).To(Equal([]byte("foobar")))
   716  			Expect(serverConn.ConnectionState().Used0RTT).To(BeTrue())
   717  			Expect(serverConn.CloseWithError(0, "")).To(Succeed())
   718  			Eventually(conn.Context().Done()).Should(BeClosed())
   719  
   720  			var processedFirst bool
   721  			for _, p := range counter.getRcvdLongHeaderPackets() {
   722  				for _, f := range p.frames {
   723  					if sf, ok := f.(*logging.StreamFrame); ok {
   724  						if !processedFirst {
   725  							// The first STREAM should have been sent in a 0-RTT packet.
   726  							// Due to the flow control limit, the STREAM frame was limit to the first 3 bytes.
   727  							Expect(p.hdr.Type).To(Equal(protocol.PacketType0RTT))
   728  							Expect(sf.Length).To(BeEquivalentTo(3))
   729  							processedFirst = true
   730  						} else {
   731  							Fail("STREAM was shouldn't have been sent in 0-RTT")
   732  						}
   733  					}
   734  				}
   735  			}
   736  		},
   737  		Entry("doesn't reject 0-RTT when the server's transport stream flow control limit increased", func(c *quic.Config, limit uint64) { c.InitialStreamReceiveWindow = limit }),
   738  		Entry("doesn't reject 0-RTT when the server's transport connection flow control limit increased", func(c *quic.Config, limit uint64) { c.InitialConnectionReceiveWindow = limit }),
   739  	)
   740  
   741  	for _, l := range []int{0, 15} {
   742  		connIDLen := l
   743  
   744  		It(fmt.Sprintf("correctly deals with 0-RTT rejections, for %d byte connection IDs", connIDLen), func() {
   745  			tlsConf := getTLSConfig()
   746  			clientConf := getTLSClientConfig()
   747  			dialAndReceiveSessionTicket(tlsConf, nil, clientConf)
   748  			// now dial new connection with different transport parameters
   749  			counter, tracer := newPacketTracer()
   750  			ln, err := quic.ListenAddrEarly(
   751  				"localhost:0",
   752  				tlsConf,
   753  				getQuicConfig(&quic.Config{
   754  					MaxIncomingUniStreams: 1,
   755  					Tracer:                newTracer(tracer),
   756  				}),
   757  			)
   758  			Expect(err).ToNot(HaveOccurred())
   759  			defer ln.Close()
   760  			proxy, num0RTTPackets := runCountingProxy(ln.Addr().(*net.UDPAddr).Port)
   761  			defer proxy.Close()
   762  
   763  			conn, err := quic.DialAddrEarly(
   764  				context.Background(),
   765  				fmt.Sprintf("localhost:%d", proxy.LocalPort()),
   766  				clientConf,
   767  				getQuicConfig(nil),
   768  			)
   769  			Expect(err).ToNot(HaveOccurred())
   770  			// The client remembers that it was allowed to open 2 uni-directional streams.
   771  			firstStr, err := conn.OpenUniStream()
   772  			Expect(err).ToNot(HaveOccurred())
   773  			written := make(chan struct{}, 2)
   774  			go func() {
   775  				defer GinkgoRecover()
   776  				defer func() { written <- struct{}{} }()
   777  				_, err := firstStr.Write([]byte("first flight"))
   778  				Expect(err).ToNot(HaveOccurred())
   779  			}()
   780  			secondStr, err := conn.OpenUniStream()
   781  			Expect(err).ToNot(HaveOccurred())
   782  			go func() {
   783  				defer GinkgoRecover()
   784  				defer func() { written <- struct{}{} }()
   785  				_, err := secondStr.Write([]byte("first flight"))
   786  				Expect(err).ToNot(HaveOccurred())
   787  			}()
   788  
   789  			ctx, cancel := context.WithTimeout(context.Background(), time.Second)
   790  			defer cancel()
   791  			_, err = conn.AcceptStream(ctx)
   792  			Expect(err).To(MatchError(quic.Err0RTTRejected))
   793  			Eventually(written).Should(Receive())
   794  			Eventually(written).Should(Receive())
   795  			_, err = firstStr.Write([]byte("foobar"))
   796  			Expect(err).To(MatchError(quic.Err0RTTRejected))
   797  			_, err = conn.OpenUniStream()
   798  			Expect(err).To(MatchError(quic.Err0RTTRejected))
   799  
   800  			_, err = conn.AcceptStream(ctx)
   801  			Expect(err).To(Equal(quic.Err0RTTRejected))
   802  
   803  			newConn := conn.NextConnection()
   804  			str, err := newConn.OpenUniStream()
   805  			Expect(err).ToNot(HaveOccurred())
   806  			_, err = newConn.OpenUniStream()
   807  			Expect(err).To(HaveOccurred())
   808  			Expect(err.Error()).To(ContainSubstring("too many open streams"))
   809  			_, err = str.Write([]byte("second flight"))
   810  			Expect(err).ToNot(HaveOccurred())
   811  			Expect(str.Close()).To(Succeed())
   812  			Expect(conn.CloseWithError(0, "")).To(Succeed())
   813  
   814  			// The client should send 0-RTT packets, but the server doesn't process them.
   815  			num0RTT := atomic.LoadUint32(num0RTTPackets)
   816  			fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT)
   817  			Expect(num0RTT).ToNot(BeZero())
   818  			Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty())
   819  		})
   820  	}
   821  
   822  	It("queues 0-RTT packets, if the Initial is delayed", func() {
   823  		tlsConf := getTLSConfig()
   824  		clientConf := getTLSClientConfig()
   825  		dialAndReceiveSessionTicket(tlsConf, nil, clientConf)
   826  
   827  		counter, tracer := newPacketTracer()
   828  		ln, err := quic.ListenAddrEarly(
   829  			"localhost:0",
   830  			tlsConf,
   831  			getQuicConfig(&quic.Config{
   832  				Allow0RTT: true,
   833  				Tracer:    newTracer(tracer),
   834  			}),
   835  		)
   836  		Expect(err).ToNot(HaveOccurred())
   837  		defer ln.Close()
   838  		proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
   839  			RemoteAddr: ln.Addr().String(),
   840  			DelayPacket: func(dir quicproxy.Direction, data []byte) time.Duration {
   841  				if dir == quicproxy.DirectionIncoming && wire.IsLongHeaderPacket(data[0]) && data[0]&0x30>>4 == 0 { // Initial packet from client
   842  					return rtt/2 + rtt
   843  				}
   844  				return rtt / 2
   845  			},
   846  		})
   847  		Expect(err).ToNot(HaveOccurred())
   848  		defer proxy.Close()
   849  
   850  		transfer0RTTData(ln, proxy.LocalPort(), protocol.DefaultConnectionIDLength, clientConf, nil, PRData)
   851  
   852  		Expect(counter.getRcvdLongHeaderPackets()[0].hdr.Type).To(Equal(protocol.PacketTypeInitial))
   853  		zeroRTTPackets := get0RTTPackets(counter.getRcvdLongHeaderPackets())
   854  		Expect(len(zeroRTTPackets)).To(BeNumerically(">", 10))
   855  		Expect(zeroRTTPackets[0]).To(Equal(protocol.PacketNumber(0)))
   856  	})
   857  
   858  	It("allows the application to attach data to the session ticket, for the server", func() {
   859  		tlsConf := getTLSConfig()
   860  		tlsConf.WrapSession = func(cs tls.ConnectionState, ss *tls.SessionState) ([]byte, error) {
   861  			ss.Extra = append(ss.Extra, []byte("foobar"))
   862  			return tlsConf.EncryptTicket(cs, ss)
   863  		}
   864  		var unwrapped bool
   865  		tlsConf.UnwrapSession = func(identity []byte, cs tls.ConnectionState) (*tls.SessionState, error) {
   866  			defer GinkgoRecover()
   867  			state, err := tlsConf.DecryptTicket(identity, cs)
   868  			if err != nil {
   869  				return nil, err
   870  			}
   871  			Expect(state.Extra).To(HaveLen(2))
   872  			Expect(state.Extra[1]).To(Equal([]byte("foobar")))
   873  			unwrapped = true
   874  			return state, nil
   875  		}
   876  		clientTLSConf := getTLSClientConfig()
   877  		dialAndReceiveSessionTicket(tlsConf, nil, clientTLSConf)
   878  
   879  		ln, err := quic.ListenAddrEarly(
   880  			"localhost:0",
   881  			tlsConf,
   882  			getQuicConfig(&quic.Config{Allow0RTT: true}),
   883  		)
   884  		Expect(err).ToNot(HaveOccurred())
   885  		defer ln.Close()
   886  
   887  		transfer0RTTData(
   888  			ln,
   889  			ln.Addr().(*net.UDPAddr).Port,
   890  			10,
   891  			clientTLSConf,
   892  			getQuicConfig(nil),
   893  			PRData,
   894  		)
   895  		Expect(unwrapped).To(BeTrue())
   896  	})
   897  
   898  	It("allows the application to attach data to the session ticket, for the client", func() {
   899  		tlsConf := getTLSConfig()
   900  		clientTLSConf := getTLSClientConfig()
   901  		var restored bool
   902  		clientTLSConf.ClientSessionCache = &metadataClientSessionCache{
   903  			toAdd: []byte("foobar"),
   904  			restored: func(b []byte) {
   905  				defer GinkgoRecover()
   906  				Expect(b).To(Equal([]byte("foobar")))
   907  				restored = true
   908  			},
   909  			cache: tls.NewLRUClientSessionCache(100),
   910  		}
   911  		dialAndReceiveSessionTicket(tlsConf, nil, clientTLSConf)
   912  
   913  		ln, err := quic.ListenAddrEarly(
   914  			"localhost:0",
   915  			tlsConf,
   916  			getQuicConfig(&quic.Config{Allow0RTT: true}),
   917  		)
   918  		Expect(err).ToNot(HaveOccurred())
   919  		defer ln.Close()
   920  
   921  		transfer0RTTData(
   922  			ln,
   923  			ln.Addr().(*net.UDPAddr).Port,
   924  			10,
   925  			clientTLSConf,
   926  			getQuicConfig(nil),
   927  			PRData,
   928  		)
   929  		Expect(restored).To(BeTrue())
   930  	})
   931  
   932  	It("sends 0-RTT datagrams", func() {
   933  		tlsConf := getTLSConfig()
   934  		clientTLSConf := getTLSClientConfig()
   935  		dialAndReceiveSessionTicket(tlsConf, getQuicConfig(&quic.Config{
   936  			EnableDatagrams: true,
   937  		}), clientTLSConf)
   938  
   939  		counter, tracer := newPacketTracer()
   940  		ln, err := quic.ListenAddrEarly(
   941  			"localhost:0",
   942  			tlsConf,
   943  			getQuicConfig(&quic.Config{
   944  				Allow0RTT:       true,
   945  				EnableDatagrams: true,
   946  				Tracer:          newTracer(tracer),
   947  			}),
   948  		)
   949  		Expect(err).ToNot(HaveOccurred())
   950  		defer ln.Close()
   951  		proxy, num0RTTPackets := runCountingProxy(ln.Addr().(*net.UDPAddr).Port)
   952  		defer proxy.Close()
   953  
   954  		// second connection
   955  		sentMessage := GeneratePRData(100)
   956  		var receivedMessage []byte
   957  		received := make(chan struct{})
   958  		go func() {
   959  			defer GinkgoRecover()
   960  			defer close(received)
   961  			conn, err := ln.Accept(context.Background())
   962  			Expect(err).ToNot(HaveOccurred())
   963  			receivedMessage, err = conn.ReceiveMessage(context.Background())
   964  			Expect(err).ToNot(HaveOccurred())
   965  			Expect(conn.ConnectionState().Used0RTT).To(BeTrue())
   966  		}()
   967  		conn, err := quic.DialAddrEarly(
   968  			context.Background(),
   969  			fmt.Sprintf("localhost:%d", proxy.LocalPort()),
   970  			clientTLSConf,
   971  			getQuicConfig(&quic.Config{
   972  				EnableDatagrams: true,
   973  			}),
   974  		)
   975  		Expect(err).ToNot(HaveOccurred())
   976  		Expect(conn.ConnectionState().SupportsDatagrams).To(BeTrue())
   977  		Expect(conn.SendMessage(sentMessage)).To(Succeed())
   978  		<-conn.HandshakeComplete()
   979  		<-received
   980  
   981  		Expect(conn.ConnectionState().Used0RTT).To(BeTrue())
   982  		Expect(conn.CloseWithError(0, "")).To(Succeed())
   983  		Expect(receivedMessage).To(Equal(sentMessage))
   984  		num0RTT := atomic.LoadUint32(num0RTTPackets)
   985  		fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT)
   986  		Expect(num0RTT).ToNot(BeZero())
   987  		zeroRTTPackets := get0RTTPackets(counter.getRcvdLongHeaderPackets())
   988  		Expect(zeroRTTPackets).To(HaveLen(1))
   989  	})
   990  
   991  	It("rejects 0-RTT datagrams when the server doesn't support datagrams anymore", func() {
   992  		tlsConf := getTLSConfig()
   993  		clientTLSConf := getTLSClientConfig()
   994  		dialAndReceiveSessionTicket(tlsConf, getQuicConfig(&quic.Config{
   995  			EnableDatagrams: true,
   996  		}), clientTLSConf)
   997  
   998  		counter, tracer := newPacketTracer()
   999  		ln, err := quic.ListenAddrEarly(
  1000  			"localhost:0",
  1001  			tlsConf,
  1002  			getQuicConfig(&quic.Config{
  1003  				Allow0RTT:       true,
  1004  				EnableDatagrams: false,
  1005  				Tracer:          newTracer(tracer),
  1006  			}),
  1007  		)
  1008  		Expect(err).ToNot(HaveOccurred())
  1009  		defer ln.Close()
  1010  
  1011  		proxy, num0RTTPackets := runCountingProxy(ln.Addr().(*net.UDPAddr).Port)
  1012  		defer proxy.Close()
  1013  
  1014  		// second connection
  1015  		go func() {
  1016  			defer GinkgoRecover()
  1017  			conn, err := ln.Accept(context.Background())
  1018  			Expect(err).ToNot(HaveOccurred())
  1019  			_, err = conn.ReceiveMessage(context.Background())
  1020  			Expect(err.Error()).To(Equal("datagram support disabled"))
  1021  			<-conn.HandshakeComplete()
  1022  			Expect(conn.ConnectionState().Used0RTT).To(BeFalse())
  1023  		}()
  1024  		conn, err := quic.DialAddrEarly(
  1025  			context.Background(),
  1026  			fmt.Sprintf("localhost:%d", proxy.LocalPort()),
  1027  			clientTLSConf,
  1028  			getQuicConfig(&quic.Config{
  1029  				EnableDatagrams: true,
  1030  			}),
  1031  		)
  1032  		Expect(err).ToNot(HaveOccurred())
  1033  		// the client can temporarily send datagrams but the server doesn't process them.
  1034  		Expect(conn.ConnectionState().SupportsDatagrams).To(BeTrue())
  1035  		Expect(conn.SendMessage(make([]byte, 100))).To(Succeed())
  1036  		<-conn.HandshakeComplete()
  1037  
  1038  		Expect(conn.ConnectionState().SupportsDatagrams).To(BeFalse())
  1039  		Expect(conn.ConnectionState().Used0RTT).To(BeFalse())
  1040  		Expect(conn.CloseWithError(0, "")).To(Succeed())
  1041  		num0RTT := atomic.LoadUint32(num0RTTPackets)
  1042  		fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT)
  1043  		Expect(num0RTT).ToNot(BeZero())
  1044  		Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty())
  1045  	})
  1046  })