github.com/danielpfeifer02/quic-go-prio-packs@v0.41.0-28/http3/client_test.go (about)

     1  package http3
     2  
     3  import (
     4  	"bytes"
     5  	"compress/gzip"
     6  	"context"
     7  	"crypto/tls"
     8  	"errors"
     9  	"fmt"
    10  	"io"
    11  	"net/http"
    12  	"sync"
    13  	"time"
    14  
    15  	"github.com/danielpfeifer02/quic-go-prio-packs"
    16  	mockquic "github.com/danielpfeifer02/quic-go-prio-packs/internal/mocks/quic"
    17  	"github.com/danielpfeifer02/quic-go-prio-packs/internal/protocol"
    18  	"github.com/danielpfeifer02/quic-go-prio-packs/internal/qerr"
    19  	"github.com/danielpfeifer02/quic-go-prio-packs/internal/utils"
    20  	"github.com/danielpfeifer02/quic-go-prio-packs/quicvarint"
    21  
    22  	"github.com/quic-go/qpack"
    23  
    24  	. "github.com/onsi/ginkgo/v2"
    25  	. "github.com/onsi/gomega"
    26  	"go.uber.org/mock/gomock"
    27  )
    28  
    29  var _ = Describe("Client", func() {
    30  	var (
    31  		cl            *client
    32  		req           *http.Request
    33  		origDialAddr  = dialAddr
    34  		handshakeChan <-chan struct{} // a closed chan
    35  	)
    36  
    37  	BeforeEach(func() {
    38  		origDialAddr = dialAddr
    39  		hostname := "quic.clemente.io:1337"
    40  		c, err := newClient(hostname, nil, &roundTripperOpts{MaxHeaderBytes: 1337}, nil, nil)
    41  		Expect(err).ToNot(HaveOccurred())
    42  		cl = c.(*client)
    43  		Expect(cl.hostname).To(Equal(hostname))
    44  
    45  		req, err = http.NewRequest("GET", "https://localhost:1337", nil)
    46  		Expect(err).ToNot(HaveOccurred())
    47  
    48  		ch := make(chan struct{})
    49  		close(ch)
    50  		handshakeChan = ch
    51  	})
    52  
    53  	AfterEach(func() {
    54  		dialAddr = origDialAddr
    55  	})
    56  
    57  	It("rejects quic.Configs that allow multiple QUIC versions", func() {
    58  		qconf := &quic.Config{
    59  			Versions: []quic.Version{protocol.Version2, protocol.Version1},
    60  		}
    61  		_, err := newClient("localhost:1337", nil, &roundTripperOpts{}, qconf, nil)
    62  		Expect(err).To(MatchError("can only use a single QUIC version for dialing a HTTP/3 connection"))
    63  	})
    64  
    65  	It("uses the default QUIC and TLS config if none is give", func() {
    66  		client, err := newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
    67  		Expect(err).ToNot(HaveOccurred())
    68  		var dialAddrCalled bool
    69  		dialAddr = func(_ context.Context, _ string, tlsConf *tls.Config, quicConf *quic.Config) (quic.EarlyConnection, error) {
    70  			Expect(quicConf.MaxIncomingStreams).To(Equal(defaultQuicConfig.MaxIncomingStreams))
    71  			Expect(tlsConf.NextProtos).To(Equal([]string{NextProtoH3}))
    72  			Expect(quicConf.Versions).To(Equal([]protocol.Version{protocol.Version1}))
    73  			dialAddrCalled = true
    74  			return nil, errors.New("test done")
    75  		}
    76  		client.RoundTripOpt(req, RoundTripOpt{})
    77  		Expect(dialAddrCalled).To(BeTrue())
    78  	})
    79  
    80  	It("adds the port to the hostname, if none is given", func() {
    81  		client, err := newClient("quic.clemente.io", nil, &roundTripperOpts{}, nil, nil)
    82  		Expect(err).ToNot(HaveOccurred())
    83  		var dialAddrCalled bool
    84  		dialAddr = func(_ context.Context, hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlyConnection, error) {
    85  			Expect(hostname).To(Equal("quic.clemente.io:443"))
    86  			dialAddrCalled = true
    87  			return nil, errors.New("test done")
    88  		}
    89  		req, err := http.NewRequest("GET", "https://quic.clemente.io:443", nil)
    90  		Expect(err).ToNot(HaveOccurred())
    91  		client.RoundTripOpt(req, RoundTripOpt{})
    92  		Expect(dialAddrCalled).To(BeTrue())
    93  	})
    94  
    95  	It("sets the ServerName in the tls.Config, if not set", func() {
    96  		const host = "foo.bar"
    97  		dialCalled := false
    98  		dialFunc := func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
    99  			Expect(tlsCfg.ServerName).To(Equal(host))
   100  			dialCalled = true
   101  			return nil, errors.New("test done")
   102  		}
   103  		client, err := newClient(host, nil, &roundTripperOpts{}, nil, dialFunc)
   104  		Expect(err).ToNot(HaveOccurred())
   105  		req, err := http.NewRequest("GET", "https://foo.bar", nil)
   106  		Expect(err).ToNot(HaveOccurred())
   107  		client.RoundTripOpt(req, RoundTripOpt{})
   108  		Expect(dialCalled).To(BeTrue())
   109  	})
   110  
   111  	It("uses the TLS config and QUIC config", func() {
   112  		tlsConf := &tls.Config{
   113  			ServerName: "foo.bar",
   114  			NextProtos: []string{"proto foo", "proto bar"},
   115  		}
   116  		quicConf := &quic.Config{MaxIdleTimeout: time.Nanosecond}
   117  		client, err := newClient("localhost:1337", tlsConf, &roundTripperOpts{}, quicConf, nil)
   118  		Expect(err).ToNot(HaveOccurred())
   119  		var dialAddrCalled bool
   120  		dialAddr = func(_ context.Context, host string, tlsConfP *tls.Config, quicConfP *quic.Config) (quic.EarlyConnection, error) {
   121  			Expect(host).To(Equal("localhost:1337"))
   122  			Expect(tlsConfP.ServerName).To(Equal(tlsConf.ServerName))
   123  			Expect(tlsConfP.NextProtos).To(Equal([]string{NextProtoH3}))
   124  			Expect(quicConfP.MaxIdleTimeout).To(Equal(quicConf.MaxIdleTimeout))
   125  			dialAddrCalled = true
   126  			return nil, errors.New("test done")
   127  		}
   128  		client.RoundTripOpt(req, RoundTripOpt{})
   129  		Expect(dialAddrCalled).To(BeTrue())
   130  		// make sure the original tls.Config was not modified
   131  		Expect(tlsConf.NextProtos).To(Equal([]string{"proto foo", "proto bar"}))
   132  	})
   133  
   134  	It("uses the custom dialer, if provided", func() {
   135  		testErr := errors.New("test done")
   136  		tlsConf := &tls.Config{ServerName: "foo.bar"}
   137  		quicConf := &quic.Config{MaxIdleTimeout: 1337 * time.Second}
   138  		ctx, cancel := context.WithTimeout(context.Background(), time.Hour)
   139  		defer cancel()
   140  		var dialerCalled bool
   141  		dialer := func(ctxP context.Context, address string, tlsConfP *tls.Config, quicConfP *quic.Config) (quic.EarlyConnection, error) {
   142  			Expect(ctxP).To(Equal(ctx))
   143  			Expect(address).To(Equal("localhost:1337"))
   144  			Expect(tlsConfP.ServerName).To(Equal("foo.bar"))
   145  			Expect(quicConfP.MaxIdleTimeout).To(Equal(quicConf.MaxIdleTimeout))
   146  			dialerCalled = true
   147  			return nil, testErr
   148  		}
   149  		client, err := newClient("localhost:1337", tlsConf, &roundTripperOpts{}, quicConf, dialer)
   150  		Expect(err).ToNot(HaveOccurred())
   151  		_, err = client.RoundTripOpt(req.WithContext(ctx), RoundTripOpt{})
   152  		Expect(err).To(MatchError(testErr))
   153  		Expect(dialerCalled).To(BeTrue())
   154  	})
   155  
   156  	It("enables HTTP/3 Datagrams", func() {
   157  		testErr := errors.New("handshake error")
   158  		client, err := newClient("localhost:1337", nil, &roundTripperOpts{EnableDatagram: true}, nil, nil)
   159  		Expect(err).ToNot(HaveOccurred())
   160  		dialAddr = func(_ context.Context, _ string, _ *tls.Config, quicConf *quic.Config) (quic.EarlyConnection, error) {
   161  			Expect(quicConf.EnableDatagrams).To(BeTrue())
   162  			return nil, testErr
   163  		}
   164  		_, err = client.RoundTripOpt(req, RoundTripOpt{})
   165  		Expect(err).To(MatchError(testErr))
   166  	})
   167  
   168  	It("errors when dialing fails", func() {
   169  		testErr := errors.New("handshake error")
   170  		client, err := newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
   171  		Expect(err).ToNot(HaveOccurred())
   172  		dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
   173  			return nil, testErr
   174  		}
   175  		_, err = client.RoundTripOpt(req, RoundTripOpt{})
   176  		Expect(err).To(MatchError(testErr))
   177  	})
   178  
   179  	It("closes correctly if connection was not created", func() {
   180  		client, err := newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
   181  		Expect(err).ToNot(HaveOccurred())
   182  		Expect(client.Close()).To(Succeed())
   183  	})
   184  
   185  	Context("validating the address", func() {
   186  		It("refuses to do requests for the wrong host", func() {
   187  			req, err := http.NewRequest("https", "https://quic.clemente.io:1336/foobar.html", nil)
   188  			Expect(err).ToNot(HaveOccurred())
   189  			_, err = cl.RoundTripOpt(req, RoundTripOpt{})
   190  			Expect(err).To(MatchError("http3 client BUG: RoundTripOpt called for the wrong client (expected quic.clemente.io:1337, got quic.clemente.io:1336)"))
   191  		})
   192  
   193  		It("allows requests using a different scheme", func() {
   194  			testErr := errors.New("handshake error")
   195  			req, err := http.NewRequest("masque", "masque://quic.clemente.io:1337/foobar.html", nil)
   196  			Expect(err).ToNot(HaveOccurred())
   197  			dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
   198  				return nil, testErr
   199  			}
   200  			_, err = cl.RoundTripOpt(req, RoundTripOpt{})
   201  			Expect(err).To(MatchError(testErr))
   202  		})
   203  	})
   204  
   205  	Context("hijacking bidirectional streams", func() {
   206  		var (
   207  			request              *http.Request
   208  			conn                 *mockquic.MockEarlyConnection
   209  			settingsFrameWritten chan struct{}
   210  		)
   211  		testDone := make(chan struct{})
   212  
   213  		BeforeEach(func() {
   214  			testDone = make(chan struct{})
   215  			settingsFrameWritten = make(chan struct{})
   216  			controlStr := mockquic.NewMockStream(mockCtrl)
   217  			controlStr.EXPECT().Write(gomock.Any()).Do(func(b []byte) (int, error) {
   218  				defer GinkgoRecover()
   219  				close(settingsFrameWritten)
   220  				return len(b), nil
   221  			})
   222  			conn = mockquic.NewMockEarlyConnection(mockCtrl)
   223  			conn.EXPECT().OpenUniStream().Return(controlStr, nil)
   224  			conn.EXPECT().HandshakeComplete().Return(handshakeChan)
   225  			conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done"))
   226  			conn.EXPECT().AcceptUniStream(gomock.Any()).Return(nil, errors.New("done")).AnyTimes()
   227  			dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
   228  				return conn, nil
   229  			}
   230  			var err error
   231  			request, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil)
   232  			Expect(err).ToNot(HaveOccurred())
   233  		})
   234  
   235  		AfterEach(func() {
   236  			testDone <- struct{}{}
   237  			Eventually(settingsFrameWritten).Should(BeClosed())
   238  		})
   239  
   240  		It("hijacks a bidirectional stream of unknown frame type", func() {
   241  			frameTypeChan := make(chan FrameType, 1)
   242  			cl.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) {
   243  				Expect(e).ToNot(HaveOccurred())
   244  				frameTypeChan <- ft
   245  				return true, nil
   246  			}
   247  
   248  			buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41))
   249  			unknownStr := mockquic.NewMockStream(mockCtrl)
   250  			unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
   251  			conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil)
   252  			conn.EXPECT().AcceptStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.Stream, error) {
   253  				<-testDone
   254  				return nil, errors.New("test done")
   255  			})
   256  			_, err := cl.RoundTripOpt(request, RoundTripOpt{})
   257  			Expect(err).To(MatchError("done"))
   258  			Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41)))
   259  			time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
   260  		})
   261  
   262  		It("closes the connection when hijacker didn't hijack a bidirectional stream", func() {
   263  			frameTypeChan := make(chan FrameType, 1)
   264  			cl.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) {
   265  				Expect(e).ToNot(HaveOccurred())
   266  				frameTypeChan <- ft
   267  				return false, nil
   268  			}
   269  
   270  			buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41))
   271  			unknownStr := mockquic.NewMockStream(mockCtrl)
   272  			unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
   273  			conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil)
   274  			conn.EXPECT().AcceptStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.Stream, error) {
   275  				<-testDone
   276  				return nil, errors.New("test done")
   277  			})
   278  			conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), gomock.Any()).Return(nil).AnyTimes()
   279  			_, err := cl.RoundTripOpt(request, RoundTripOpt{})
   280  			Expect(err).To(MatchError("done"))
   281  			Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41)))
   282  		})
   283  
   284  		It("closes the connection when hijacker returned error", func() {
   285  			frameTypeChan := make(chan FrameType, 1)
   286  			cl.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) {
   287  				Expect(e).ToNot(HaveOccurred())
   288  				frameTypeChan <- ft
   289  				return false, errors.New("error in hijacker")
   290  			}
   291  
   292  			buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41))
   293  			unknownStr := mockquic.NewMockStream(mockCtrl)
   294  			unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
   295  			conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil)
   296  			conn.EXPECT().AcceptStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.Stream, error) {
   297  				<-testDone
   298  				return nil, errors.New("test done")
   299  			})
   300  			conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), gomock.Any()).Return(nil).AnyTimes()
   301  			_, err := cl.RoundTripOpt(request, RoundTripOpt{})
   302  			Expect(err).To(MatchError("done"))
   303  			Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41)))
   304  		})
   305  
   306  		It("handles errors that occur when reading the frame type", func() {
   307  			testErr := errors.New("test error")
   308  			unknownStr := mockquic.NewMockStream(mockCtrl)
   309  			done := make(chan struct{})
   310  			cl.opts.StreamHijacker = func(ft FrameType, c quic.Connection, str quic.Stream, e error) (hijacked bool, err error) {
   311  				defer close(done)
   312  				Expect(e).To(MatchError(testErr))
   313  				Expect(ft).To(BeZero())
   314  				Expect(str).To(Equal(unknownStr))
   315  				return false, nil
   316  			}
   317  
   318  			unknownStr.EXPECT().Read(gomock.Any()).Return(0, testErr).AnyTimes()
   319  			conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil)
   320  			conn.EXPECT().AcceptStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.Stream, error) {
   321  				<-testDone
   322  				return nil, errors.New("test done")
   323  			})
   324  			conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), gomock.Any()).Return(nil).AnyTimes()
   325  			_, err := cl.RoundTripOpt(request, RoundTripOpt{})
   326  			Expect(err).To(MatchError("done"))
   327  			Eventually(done).Should(BeClosed())
   328  			time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
   329  		})
   330  	})
   331  
   332  	Context("hijacking unidirectional streams", func() {
   333  		var (
   334  			req                  *http.Request
   335  			conn                 *mockquic.MockEarlyConnection
   336  			settingsFrameWritten chan struct{}
   337  		)
   338  		testDone := make(chan struct{})
   339  
   340  		BeforeEach(func() {
   341  			testDone = make(chan struct{})
   342  			settingsFrameWritten = make(chan struct{})
   343  			controlStr := mockquic.NewMockStream(mockCtrl)
   344  			controlStr.EXPECT().Write(gomock.Any()).Do(func(b []byte) (int, error) {
   345  				defer GinkgoRecover()
   346  				close(settingsFrameWritten)
   347  				return len(b), nil
   348  			})
   349  			conn = mockquic.NewMockEarlyConnection(mockCtrl)
   350  			conn.EXPECT().OpenUniStream().Return(controlStr, nil)
   351  			conn.EXPECT().HandshakeComplete().Return(handshakeChan)
   352  			conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done"))
   353  			dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
   354  				return conn, nil
   355  			}
   356  			var err error
   357  			req, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil)
   358  			Expect(err).ToNot(HaveOccurred())
   359  		})
   360  
   361  		AfterEach(func() {
   362  			testDone <- struct{}{}
   363  			Eventually(settingsFrameWritten).Should(BeClosed())
   364  		})
   365  
   366  		It("hijacks an unidirectional stream of unknown stream type", func() {
   367  			streamTypeChan := make(chan StreamType, 1)
   368  			cl.opts.UniStreamHijacker = func(st StreamType, _ quic.Connection, _ quic.ReceiveStream, err error) bool {
   369  				Expect(err).ToNot(HaveOccurred())
   370  				streamTypeChan <- st
   371  				return true
   372  			}
   373  
   374  			buf := bytes.NewBuffer(quicvarint.Append(nil, 0x54))
   375  			unknownStr := mockquic.NewMockStream(mockCtrl)
   376  			unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
   377  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   378  				return unknownStr, nil
   379  			})
   380  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   381  				<-testDone
   382  				return nil, errors.New("test done")
   383  			})
   384  			_, err := cl.RoundTripOpt(req, RoundTripOpt{})
   385  			Expect(err).To(MatchError("done"))
   386  			Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54)))
   387  			time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
   388  		})
   389  
   390  		It("handles errors that occur when reading the stream type", func() {
   391  			testErr := errors.New("test error")
   392  			done := make(chan struct{})
   393  			unknownStr := mockquic.NewMockStream(mockCtrl)
   394  			cl.opts.UniStreamHijacker = func(st StreamType, _ quic.Connection, str quic.ReceiveStream, err error) bool {
   395  				defer close(done)
   396  				Expect(st).To(BeZero())
   397  				Expect(str).To(Equal(unknownStr))
   398  				Expect(err).To(MatchError(testErr))
   399  				return true
   400  			}
   401  
   402  			unknownStr.EXPECT().Read(gomock.Any()).Return(0, testErr)
   403  			conn.EXPECT().AcceptUniStream(gomock.Any()).Return(unknownStr, nil)
   404  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   405  				<-testDone
   406  				return nil, errors.New("test done")
   407  			})
   408  			_, err := cl.RoundTripOpt(req, RoundTripOpt{})
   409  			Expect(err).To(MatchError("done"))
   410  			Eventually(done).Should(BeClosed())
   411  			time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
   412  		})
   413  
   414  		It("cancels reading when hijacker didn't hijack an unidirectional stream", func() {
   415  			streamTypeChan := make(chan StreamType, 1)
   416  			cl.opts.UniStreamHijacker = func(st StreamType, _ quic.Connection, _ quic.ReceiveStream, err error) bool {
   417  				Expect(err).ToNot(HaveOccurred())
   418  				streamTypeChan <- st
   419  				return false
   420  			}
   421  
   422  			buf := bytes.NewBuffer(quicvarint.Append(nil, 0x54))
   423  			unknownStr := mockquic.NewMockStream(mockCtrl)
   424  			unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
   425  			unknownStr.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError))
   426  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   427  				return unknownStr, nil
   428  			})
   429  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   430  				<-testDone
   431  				return nil, errors.New("test done")
   432  			})
   433  			_, err := cl.RoundTripOpt(req, RoundTripOpt{})
   434  			Expect(err).To(MatchError("done"))
   435  			Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54)))
   436  			time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
   437  		})
   438  	})
   439  
   440  	Context("control stream handling", func() {
   441  		var (
   442  			req                  *http.Request
   443  			conn                 *mockquic.MockEarlyConnection
   444  			settingsFrameWritten chan struct{}
   445  		)
   446  		testDone := make(chan struct{}, 1)
   447  
   448  		BeforeEach(func() {
   449  			settingsFrameWritten = make(chan struct{})
   450  			controlStr := mockquic.NewMockStream(mockCtrl)
   451  			controlStr.EXPECT().Write(gomock.Any()).Do(func(b []byte) (int, error) {
   452  				defer GinkgoRecover()
   453  				close(settingsFrameWritten)
   454  				return len(b), nil
   455  			})
   456  			conn = mockquic.NewMockEarlyConnection(mockCtrl)
   457  			conn.EXPECT().OpenUniStream().Return(controlStr, nil)
   458  			conn.EXPECT().HandshakeComplete().Return(handshakeChan)
   459  			conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done"))
   460  			dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
   461  				return conn, nil
   462  			}
   463  			var err error
   464  			req, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil)
   465  			Expect(err).ToNot(HaveOccurred())
   466  		})
   467  
   468  		AfterEach(func() {
   469  			testDone <- struct{}{}
   470  			Eventually(settingsFrameWritten).Should(BeClosed())
   471  		})
   472  
   473  		It("parses the SETTINGS frame", func() {
   474  			b := quicvarint.Append(nil, streamTypeControlStream)
   475  			b = (&settingsFrame{}).Append(b)
   476  			r := bytes.NewReader(b)
   477  			controlStr := mockquic.NewMockStream(mockCtrl)
   478  			controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes()
   479  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   480  				return controlStr, nil
   481  			})
   482  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   483  				<-testDone
   484  				return nil, errors.New("test done")
   485  			})
   486  			_, err := cl.RoundTripOpt(req, RoundTripOpt{})
   487  			Expect(err).To(MatchError("done"))
   488  			time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
   489  		})
   490  
   491  		It("rejects duplicate control streams", func() {
   492  			b := quicvarint.Append(nil, streamTypeControlStream)
   493  			b = (&settingsFrame{}).Append(b)
   494  			r1 := bytes.NewReader(b)
   495  			controlStr1 := mockquic.NewMockStream(mockCtrl)
   496  			controlStr1.EXPECT().Read(gomock.Any()).DoAndReturn(r1.Read).AnyTimes()
   497  			r2 := bytes.NewReader(b)
   498  			controlStr2 := mockquic.NewMockStream(mockCtrl)
   499  			controlStr2.EXPECT().Read(gomock.Any()).DoAndReturn(r2.Read).AnyTimes()
   500  			done := make(chan struct{})
   501  			conn.EXPECT().CloseWithError(qerr.ApplicationErrorCode(ErrCodeStreamCreationError), "duplicate control stream").Do(func(qerr.ApplicationErrorCode, string) error {
   502  				close(done)
   503  				return nil
   504  			})
   505  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   506  				return controlStr1, nil
   507  			})
   508  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   509  				return controlStr2, nil
   510  			})
   511  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   512  				<-done
   513  				return nil, errors.New("test done")
   514  			})
   515  			_, err := cl.RoundTripOpt(req, RoundTripOpt{})
   516  			Expect(err).To(HaveOccurred())
   517  			Eventually(done).Should(BeClosed())
   518  		})
   519  
   520  		for _, t := range []uint64{streamTypeQPACKEncoderStream, streamTypeQPACKDecoderStream} {
   521  			streamType := t
   522  			name := "encoder"
   523  			if streamType == streamTypeQPACKDecoderStream {
   524  				name = "decoder"
   525  			}
   526  
   527  			It(fmt.Sprintf("ignores the QPACK %s streams", name), func() {
   528  				buf := bytes.NewBuffer(quicvarint.Append(nil, streamType))
   529  				str := mockquic.NewMockStream(mockCtrl)
   530  				str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
   531  
   532  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   533  					return str, nil
   534  				})
   535  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   536  					<-testDone
   537  					return nil, errors.New("test done")
   538  				})
   539  				_, err := cl.RoundTripOpt(req, RoundTripOpt{})
   540  				Expect(err).To(MatchError("done"))
   541  				time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to str.CancelRead
   542  			})
   543  		}
   544  
   545  		It("resets streams Other than the control stream and the QPACK streams", func() {
   546  			buf := bytes.NewBuffer(quicvarint.Append(nil, 0x1337))
   547  			str := mockquic.NewMockStream(mockCtrl)
   548  			str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
   549  			done := make(chan struct{})
   550  			str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError)).Do(func(quic.StreamErrorCode) { close(done) })
   551  
   552  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   553  				return str, nil
   554  			})
   555  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   556  				<-testDone
   557  				return nil, errors.New("test done")
   558  			})
   559  			_, err := cl.RoundTripOpt(req, RoundTripOpt{})
   560  			Expect(err).To(MatchError("done"))
   561  			Eventually(done).Should(BeClosed())
   562  		})
   563  
   564  		It("errors when the first frame on the control stream is not a SETTINGS frame", func() {
   565  			b := quicvarint.Append(nil, streamTypeControlStream)
   566  			b = (&dataFrame{}).Append(b)
   567  			r := bytes.NewReader(b)
   568  			controlStr := mockquic.NewMockStream(mockCtrl)
   569  			controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes()
   570  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   571  				return controlStr, nil
   572  			})
   573  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   574  				<-testDone
   575  				return nil, errors.New("test done")
   576  			})
   577  			done := make(chan struct{})
   578  			conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeMissingSettings), gomock.Any()).Do(func(quic.ApplicationErrorCode, string) error {
   579  				close(done)
   580  				return nil
   581  			})
   582  			_, err := cl.RoundTripOpt(req, RoundTripOpt{})
   583  			Expect(err).To(MatchError("done"))
   584  			Eventually(done).Should(BeClosed())
   585  		})
   586  
   587  		It("errors when parsing the frame on the control stream fails", func() {
   588  			b := quicvarint.Append(nil, streamTypeControlStream)
   589  			b = (&settingsFrame{}).Append(b)
   590  			r := bytes.NewReader(b[:len(b)-1])
   591  			controlStr := mockquic.NewMockStream(mockCtrl)
   592  			controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes()
   593  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   594  				return controlStr, nil
   595  			})
   596  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   597  				<-testDone
   598  				return nil, errors.New("test done")
   599  			})
   600  			done := make(chan struct{})
   601  			conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameError), gomock.Any()).Do(func(code quic.ApplicationErrorCode, _ string) error {
   602  				close(done)
   603  				return nil
   604  			})
   605  			_, err := cl.RoundTripOpt(req, RoundTripOpt{})
   606  			Expect(err).To(MatchError("done"))
   607  			Eventually(done).Should(BeClosed())
   608  		})
   609  
   610  		It("errors when parsing the server opens a push stream", func() {
   611  			buf := bytes.NewBuffer(quicvarint.Append(nil, streamTypePushStream))
   612  			controlStr := mockquic.NewMockStream(mockCtrl)
   613  			controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
   614  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   615  				return controlStr, nil
   616  			})
   617  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   618  				<-testDone
   619  				return nil, errors.New("test done")
   620  			})
   621  			done := make(chan struct{})
   622  			conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeIDError), gomock.Any()).Do(func(quic.ApplicationErrorCode, string) error {
   623  				close(done)
   624  				return nil
   625  			})
   626  			_, err := cl.RoundTripOpt(req, RoundTripOpt{})
   627  			Expect(err).To(MatchError("done"))
   628  			Eventually(done).Should(BeClosed())
   629  		})
   630  
   631  		It("errors when the server advertises datagram support (and we enabled support for it)", func() {
   632  			cl.opts.EnableDatagram = true
   633  			b := quicvarint.Append(nil, streamTypeControlStream)
   634  			b = (&settingsFrame{Datagram: true}).Append(b)
   635  			r := bytes.NewReader(b)
   636  			controlStr := mockquic.NewMockStream(mockCtrl)
   637  			controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes()
   638  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   639  				return controlStr, nil
   640  			})
   641  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   642  				<-testDone
   643  				return nil, errors.New("test done")
   644  			})
   645  			conn.EXPECT().ConnectionState().Return(quic.ConnectionState{SupportsDatagrams: false})
   646  			done := make(chan struct{})
   647  			conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeSettingsError), "missing QUIC Datagram support").Do(func(quic.ApplicationErrorCode, string) error {
   648  				close(done)
   649  				return nil
   650  			})
   651  			_, err := cl.RoundTripOpt(req, RoundTripOpt{})
   652  			Expect(err).To(MatchError("done"))
   653  			Eventually(done).Should(BeClosed())
   654  		})
   655  	})
   656  
   657  	Context("Doing requests", func() {
   658  		var (
   659  			req                  *http.Request
   660  			str                  *mockquic.MockStream
   661  			conn                 *mockquic.MockEarlyConnection
   662  			settingsFrameWritten chan struct{}
   663  		)
   664  		testDone := make(chan struct{})
   665  
   666  		decodeHeader := func(str io.Reader) map[string]string {
   667  			fields := make(map[string]string)
   668  			decoder := qpack.NewDecoder(nil)
   669  
   670  			frame, err := parseNextFrame(str, nil)
   671  			ExpectWithOffset(1, err).ToNot(HaveOccurred())
   672  			ExpectWithOffset(1, frame).To(BeAssignableToTypeOf(&headersFrame{}))
   673  			headersFrame := frame.(*headersFrame)
   674  			data := make([]byte, headersFrame.Length)
   675  			_, err = io.ReadFull(str, data)
   676  			ExpectWithOffset(1, err).ToNot(HaveOccurred())
   677  			hfs, err := decoder.DecodeFull(data)
   678  			ExpectWithOffset(1, err).ToNot(HaveOccurred())
   679  			for _, p := range hfs {
   680  				fields[p.Name] = p.Value
   681  			}
   682  			return fields
   683  		}
   684  
   685  		getResponse := func(status int) []byte {
   686  			buf := &bytes.Buffer{}
   687  			rstr := mockquic.NewMockStream(mockCtrl)
   688  			rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes()
   689  			rw := newResponseWriter(rstr, nil, utils.DefaultLogger)
   690  			rw.WriteHeader(status)
   691  			rw.Flush()
   692  			return buf.Bytes()
   693  		}
   694  
   695  		BeforeEach(func() {
   696  			settingsFrameWritten = make(chan struct{})
   697  			controlStr := mockquic.NewMockStream(mockCtrl)
   698  			controlStr.EXPECT().Write(gomock.Any()).Do(func(b []byte) (int, error) {
   699  				defer GinkgoRecover()
   700  				r := bytes.NewReader(b)
   701  				streamType, err := quicvarint.Read(r)
   702  				Expect(err).ToNot(HaveOccurred())
   703  				Expect(streamType).To(BeEquivalentTo(streamTypeControlStream))
   704  				close(settingsFrameWritten)
   705  				return len(b), nil
   706  			}) // SETTINGS frame
   707  			str = mockquic.NewMockStream(mockCtrl)
   708  			conn = mockquic.NewMockEarlyConnection(mockCtrl)
   709  			conn.EXPECT().OpenUniStream().Return(controlStr, nil)
   710  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   711  				<-testDone
   712  				return nil, errors.New("test done")
   713  			})
   714  			dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
   715  				return conn, nil
   716  			}
   717  			var err error
   718  			req, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil)
   719  			Expect(err).ToNot(HaveOccurred())
   720  		})
   721  
   722  		AfterEach(func() {
   723  			testDone <- struct{}{}
   724  			Eventually(settingsFrameWritten).Should(BeClosed())
   725  		})
   726  
   727  		It("errors if it can't open a stream", func() {
   728  			testErr := errors.New("stream open error")
   729  			conn.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr)
   730  			conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).MaxTimes(1)
   731  			conn.EXPECT().HandshakeComplete().Return(handshakeChan)
   732  			_, err := cl.RoundTripOpt(req, RoundTripOpt{})
   733  			Expect(err).To(MatchError(testErr))
   734  		})
   735  
   736  		It("performs a 0-RTT request", func() {
   737  			testErr := errors.New("stream open error")
   738  			req.Method = MethodGet0RTT
   739  			// don't EXPECT any calls to HandshakeComplete()
   740  			conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
   741  			buf := &bytes.Buffer{}
   742  			str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write).AnyTimes()
   743  			str.EXPECT().Close()
   744  			str.EXPECT().CancelWrite(gomock.Any())
   745  			str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) {
   746  				return 0, testErr
   747  			})
   748  			_, err := cl.RoundTripOpt(req, RoundTripOpt{})
   749  			Expect(err).To(MatchError(testErr))
   750  			Expect(decodeHeader(buf)).To(HaveKeyWithValue(":method", "GET"))
   751  		})
   752  
   753  		It("returns a response", func() {
   754  			rspBuf := bytes.NewBuffer(getResponse(418))
   755  			gomock.InOrder(
   756  				conn.EXPECT().HandshakeComplete().Return(handshakeChan),
   757  				conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil),
   758  				conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}),
   759  			)
   760  			str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })
   761  			str.EXPECT().Close()
   762  			str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes()
   763  			rsp, err := cl.RoundTripOpt(req, RoundTripOpt{})
   764  			Expect(err).ToNot(HaveOccurred())
   765  			Expect(rsp.Proto).To(Equal("HTTP/3.0"))
   766  			Expect(rsp.ProtoMajor).To(Equal(3))
   767  			Expect(rsp.StatusCode).To(Equal(418))
   768  			Expect(rsp.Request).ToNot(BeNil())
   769  		})
   770  
   771  		It("doesn't close the request stream, with DontCloseRequestStream set", func() {
   772  			rspBuf := bytes.NewBuffer(getResponse(418))
   773  			gomock.InOrder(
   774  				conn.EXPECT().HandshakeComplete().Return(handshakeChan),
   775  				conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil),
   776  				conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}),
   777  			)
   778  			str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })
   779  			str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes()
   780  			rsp, err := cl.RoundTripOpt(req, RoundTripOpt{DontCloseRequestStream: true})
   781  			Expect(err).ToNot(HaveOccurred())
   782  			Expect(rsp.Proto).To(Equal("HTTP/3.0"))
   783  			Expect(rsp.ProtoMajor).To(Equal(3))
   784  			Expect(rsp.StatusCode).To(Equal(418))
   785  		})
   786  
   787  		Context("requests containing a Body", func() {
   788  			var strBuf *bytes.Buffer
   789  
   790  			BeforeEach(func() {
   791  				strBuf = &bytes.Buffer{}
   792  				gomock.InOrder(
   793  					conn.EXPECT().HandshakeComplete().Return(handshakeChan),
   794  					conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil),
   795  				)
   796  				body := &mockBody{}
   797  				body.SetData([]byte("request body"))
   798  				var err error
   799  				req, err = http.NewRequest("POST", "https://quic.clemente.io:1337/upload", body)
   800  				Expect(err).ToNot(HaveOccurred())
   801  				str.EXPECT().Write(gomock.Any()).DoAndReturn(strBuf.Write).AnyTimes()
   802  			})
   803  
   804  			It("sends a request", func() {
   805  				done := make(chan struct{})
   806  				gomock.InOrder(
   807  					str.EXPECT().Close().Do(func() error { close(done); return nil }),
   808  					str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1), // when reading the response errors
   809  				)
   810  				// the response body is sent asynchronously, while already reading the response
   811  				str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) {
   812  					<-done
   813  					return 0, errors.New("test done")
   814  				})
   815  				_, err := cl.RoundTripOpt(req, RoundTripOpt{})
   816  				Expect(err).To(MatchError("test done"))
   817  				hfs := decodeHeader(strBuf)
   818  				Expect(hfs).To(HaveKeyWithValue(":method", "POST"))
   819  				Expect(hfs).To(HaveKeyWithValue(":path", "/upload"))
   820  			})
   821  
   822  			It("doesn't send more bytes than allowed by http.Request.ContentLength", func() {
   823  				req.ContentLength = 7
   824  				var once sync.Once
   825  				done := make(chan struct{})
   826  				gomock.InOrder(
   827  					str.EXPECT().CancelWrite(gomock.Any()).Do(func(c quic.StreamErrorCode) {
   828  						once.Do(func() {
   829  							Expect(c).To(Equal(quic.StreamErrorCode(ErrCodeRequestCanceled)))
   830  							close(done)
   831  						})
   832  					}).AnyTimes(),
   833  					str.EXPECT().Close().MaxTimes(1),
   834  					str.EXPECT().CancelWrite(gomock.Any()).AnyTimes(),
   835  				)
   836  				str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) {
   837  					<-done
   838  					return 0, errors.New("done")
   839  				})
   840  				cl.RoundTripOpt(req, RoundTripOpt{})
   841  				Expect(strBuf.String()).To(ContainSubstring("request"))
   842  				Expect(strBuf.String()).ToNot(ContainSubstring("request body"))
   843  			})
   844  
   845  			It("returns the error that occurred when reading the body", func() {
   846  				req.Body.(*mockBody).readErr = errors.New("testErr")
   847  				done := make(chan struct{})
   848  				gomock.InOrder(
   849  					str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled)).Do(func(quic.StreamErrorCode) {
   850  						close(done)
   851  					}),
   852  					str.EXPECT().CancelWrite(gomock.Any()),
   853  				)
   854  
   855  				// the response body is sent asynchronously, while already reading the response
   856  				str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) {
   857  					<-done
   858  					return 0, errors.New("test done")
   859  				})
   860  				closed := make(chan struct{})
   861  				str.EXPECT().Close().Do(func() error { close(closed); return nil })
   862  				_, err := cl.RoundTripOpt(req, RoundTripOpt{})
   863  				Expect(err).To(MatchError("test done"))
   864  				Eventually(closed).Should(BeClosed())
   865  			})
   866  
   867  			It("closes the connection when the first frame is not a HEADERS frame", func() {
   868  				b := (&dataFrame{Length: 0x42}).Append(nil)
   869  				conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), gomock.Any())
   870  				closed := make(chan struct{})
   871  				r := bytes.NewReader(b)
   872  				str.EXPECT().Close().Do(func() error { close(closed); return nil })
   873  				str.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes()
   874  				_, err := cl.RoundTripOpt(req, RoundTripOpt{})
   875  				Expect(err).To(MatchError("expected first frame to be a HEADERS frame"))
   876  				Eventually(closed).Should(BeClosed())
   877  			})
   878  
   879  			It("cancels the stream when parsing the headers fails", func() {
   880  				headerBuf := &bytes.Buffer{}
   881  				enc := qpack.NewEncoder(headerBuf)
   882  				Expect(enc.WriteField(qpack.HeaderField{Name: ":method", Value: "GET"})).To(Succeed()) // not a valid response pseudo header
   883  				Expect(enc.Close()).To(Succeed())
   884  				b := (&headersFrame{Length: uint64(headerBuf.Len())}).Append(nil)
   885  				b = append(b, headerBuf.Bytes()...)
   886  
   887  				r := bytes.NewReader(b)
   888  				str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeMessageError))
   889  				closed := make(chan struct{})
   890  				str.EXPECT().Close().Do(func() error { close(closed); return nil })
   891  				str.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes()
   892  				_, err := cl.RoundTripOpt(req, RoundTripOpt{})
   893  				Expect(err).To(HaveOccurred())
   894  				Eventually(closed).Should(BeClosed())
   895  			})
   896  
   897  			It("cancels the stream when the HEADERS frame is too large", func() {
   898  				b := (&headersFrame{Length: 1338}).Append(nil)
   899  				r := bytes.NewReader(b)
   900  				str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeFrameError))
   901  				closed := make(chan struct{})
   902  				str.EXPECT().Close().Do(func() error { close(closed); return nil })
   903  				str.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes()
   904  				_, err := cl.RoundTripOpt(req, RoundTripOpt{})
   905  				Expect(err).To(MatchError("HEADERS frame too large: 1338 bytes (max: 1337)"))
   906  				Eventually(closed).Should(BeClosed())
   907  			})
   908  		})
   909  
   910  		Context("request cancellations", func() {
   911  			for _, dontClose := range []bool{false, true} {
   912  				dontClose := dontClose
   913  
   914  				Context(fmt.Sprintf("with DontCloseRequestStream: %t", dontClose), func() {
   915  					roundTripOpt := RoundTripOpt{DontCloseRequestStream: dontClose}
   916  
   917  					It("cancels a request while waiting for the handshake to complete", func() {
   918  						ctx, cancel := context.WithCancel(context.Background())
   919  						req := req.WithContext(ctx)
   920  						conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
   921  
   922  						errChan := make(chan error)
   923  						go func() {
   924  							_, err := cl.RoundTripOpt(req, roundTripOpt)
   925  							errChan <- err
   926  						}()
   927  						Consistently(errChan).ShouldNot(Receive())
   928  						cancel()
   929  						Eventually(errChan).Should(Receive(MatchError("context canceled")))
   930  					})
   931  
   932  					It("cancels a request while the request is still in flight", func() {
   933  						ctx, cancel := context.WithCancel(context.Background())
   934  						req := req.WithContext(ctx)
   935  						conn.EXPECT().HandshakeComplete().Return(handshakeChan)
   936  						conn.EXPECT().OpenStreamSync(ctx).Return(str, nil)
   937  						buf := &bytes.Buffer{}
   938  						str.EXPECT().Close().MaxTimes(1)
   939  
   940  						str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write)
   941  
   942  						done := make(chan struct{})
   943  						canceled := make(chan struct{})
   944  						gomock.InOrder(
   945  							str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled)).Do(func(quic.StreamErrorCode) { close(canceled) }),
   946  							str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeRequestCanceled)).Do(func(quic.StreamErrorCode) { close(done) }),
   947  						)
   948  						str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1)
   949  						str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) {
   950  							cancel()
   951  							<-canceled
   952  							return 0, errors.New("test done")
   953  						})
   954  						_, err := cl.RoundTripOpt(req, roundTripOpt)
   955  						Expect(err).To(MatchError(context.Canceled))
   956  						Eventually(done).Should(BeClosed())
   957  					})
   958  				})
   959  			}
   960  
   961  			It("cancels a request after the response arrived", func() {
   962  				rspBuf := bytes.NewBuffer(getResponse(404))
   963  
   964  				ctx, cancel := context.WithCancel(context.Background())
   965  				req := req.WithContext(ctx)
   966  				conn.EXPECT().HandshakeComplete().Return(handshakeChan)
   967  				conn.EXPECT().OpenStreamSync(ctx).Return(str, nil)
   968  				conn.EXPECT().ConnectionState().Return(quic.ConnectionState{})
   969  				buf := &bytes.Buffer{}
   970  				str.EXPECT().Close().MaxTimes(1)
   971  
   972  				done := make(chan struct{})
   973  				str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write)
   974  				str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes()
   975  				str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled))
   976  				str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeRequestCanceled)).Do(func(quic.StreamErrorCode) { close(done) })
   977  				_, err := cl.RoundTripOpt(req, RoundTripOpt{})
   978  				Expect(err).ToNot(HaveOccurred())
   979  				cancel()
   980  				Eventually(done).Should(BeClosed())
   981  			})
   982  		})
   983  
   984  		Context("gzip compression", func() {
   985  			BeforeEach(func() {
   986  				conn.EXPECT().HandshakeComplete().Return(handshakeChan)
   987  			})
   988  
   989  			It("adds the gzip header to requests", func() {
   990  				conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
   991  				buf := &bytes.Buffer{}
   992  				str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write)
   993  				gomock.InOrder(
   994  					str.EXPECT().Close(),
   995  					str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1), // when the Read errors
   996  				)
   997  				str.EXPECT().Read(gomock.Any()).Return(0, errors.New("test done"))
   998  				_, err := cl.RoundTripOpt(req, RoundTripOpt{})
   999  				Expect(err).To(MatchError("test done"))
  1000  				hfs := decodeHeader(buf)
  1001  				Expect(hfs).To(HaveKeyWithValue("accept-encoding", "gzip"))
  1002  			})
  1003  
  1004  			It("doesn't add gzip if the header disable it", func() {
  1005  				client, err := newClient("quic.clemente.io:1337", nil, &roundTripperOpts{DisableCompression: true}, nil, nil)
  1006  				Expect(err).ToNot(HaveOccurred())
  1007  				conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
  1008  				buf := &bytes.Buffer{}
  1009  				str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write)
  1010  				gomock.InOrder(
  1011  					str.EXPECT().Close(),
  1012  					str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1), // when the Read errors
  1013  				)
  1014  				str.EXPECT().Read(gomock.Any()).Return(0, errors.New("test done"))
  1015  				_, err = client.RoundTripOpt(req, RoundTripOpt{})
  1016  				Expect(err).To(MatchError("test done"))
  1017  				hfs := decodeHeader(buf)
  1018  				Expect(hfs).ToNot(HaveKey("accept-encoding"))
  1019  			})
  1020  
  1021  			It("decompresses the response", func() {
  1022  				conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
  1023  				conn.EXPECT().ConnectionState().Return(quic.ConnectionState{})
  1024  				buf := &bytes.Buffer{}
  1025  				rstr := mockquic.NewMockStream(mockCtrl)
  1026  				rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes()
  1027  				rw := newResponseWriter(rstr, nil, utils.DefaultLogger)
  1028  				rw.Header().Set("Content-Encoding", "gzip")
  1029  				gz := gzip.NewWriter(rw)
  1030  				gz.Write([]byte("gzipped response"))
  1031  				gz.Close()
  1032  				rw.Flush()
  1033  				str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })
  1034  				str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
  1035  				str.EXPECT().Close()
  1036  
  1037  				rsp, err := cl.RoundTripOpt(req, RoundTripOpt{})
  1038  				Expect(err).ToNot(HaveOccurred())
  1039  				data, err := io.ReadAll(rsp.Body)
  1040  				Expect(err).ToNot(HaveOccurred())
  1041  				Expect(rsp.ContentLength).To(BeEquivalentTo(-1))
  1042  				Expect(string(data)).To(Equal("gzipped response"))
  1043  				Expect(rsp.Header.Get("Content-Encoding")).To(BeEmpty())
  1044  				Expect(rsp.Uncompressed).To(BeTrue())
  1045  			})
  1046  
  1047  			It("only decompresses the response if the response contains the right content-encoding header", func() {
  1048  				conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
  1049  				conn.EXPECT().ConnectionState().Return(quic.ConnectionState{})
  1050  				buf := &bytes.Buffer{}
  1051  				rstr := mockquic.NewMockStream(mockCtrl)
  1052  				rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes()
  1053  				rw := newResponseWriter(rstr, nil, utils.DefaultLogger)
  1054  				rw.Write([]byte("not gzipped"))
  1055  				rw.Flush()
  1056  				str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })
  1057  				str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
  1058  				str.EXPECT().Close()
  1059  
  1060  				rsp, err := cl.RoundTripOpt(req, RoundTripOpt{})
  1061  				Expect(err).ToNot(HaveOccurred())
  1062  				data, err := io.ReadAll(rsp.Body)
  1063  				Expect(err).ToNot(HaveOccurred())
  1064  				Expect(string(data)).To(Equal("not gzipped"))
  1065  				Expect(rsp.Header.Get("Content-Encoding")).To(BeEmpty())
  1066  			})
  1067  		})
  1068  	})
  1069  })