github.com/mikelsr/quic-go@v0.36.1-0.20230701132136-1d9415b66898/http3/client_test.go (about)

     1  package http3
     2  
     3  import (
     4  	"bytes"
     5  	"compress/gzip"
     6  	"context"
     7  	"crypto/tls"
     8  	"errors"
     9  	"fmt"
    10  	"io"
    11  	"net/http"
    12  	"time"
    13  
    14  	"github.com/mikelsr/quic-go"
    15  	mockquic "github.com/mikelsr/quic-go/internal/mocks/quic"
    16  	"github.com/mikelsr/quic-go/internal/protocol"
    17  	"github.com/mikelsr/quic-go/internal/utils"
    18  	"github.com/mikelsr/quic-go/quicvarint"
    19  
    20  	"github.com/golang/mock/gomock"
    21  	"github.com/quic-go/qpack"
    22  
    23  	. "github.com/onsi/ginkgo/v2"
    24  	. "github.com/onsi/gomega"
    25  )
    26  
    27  var _ = Describe("Client", func() {
    28  	var (
    29  		cl            *client
    30  		req           *http.Request
    31  		origDialAddr  = dialAddr
    32  		handshakeChan <-chan struct{} // a closed chan
    33  	)
    34  
    35  	BeforeEach(func() {
    36  		origDialAddr = dialAddr
    37  		hostname := "quic.clemente.io:1337"
    38  		c, err := newClient(hostname, nil, &roundTripperOpts{MaxHeaderBytes: 1337}, nil, nil)
    39  		Expect(err).ToNot(HaveOccurred())
    40  		cl = c.(*client)
    41  		Expect(cl.hostname).To(Equal(hostname))
    42  
    43  		req, err = http.NewRequest("GET", "https://localhost:1337", nil)
    44  		Expect(err).ToNot(HaveOccurred())
    45  
    46  		ch := make(chan struct{})
    47  		close(ch)
    48  		handshakeChan = ch
    49  	})
    50  
    51  	AfterEach(func() {
    52  		dialAddr = origDialAddr
    53  	})
    54  
    55  	It("rejects quic.Configs that allow multiple QUIC versions", func() {
    56  		qconf := &quic.Config{
    57  			Versions: []quic.VersionNumber{protocol.Version2, protocol.Version1},
    58  		}
    59  		_, err := newClient("localhost:1337", nil, &roundTripperOpts{}, qconf, nil)
    60  		Expect(err).To(MatchError("can only use a single QUIC version for dialing a HTTP/3 connection"))
    61  	})
    62  
    63  	It("uses the default QUIC and TLS config if none is give", func() {
    64  		client, err := newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
    65  		Expect(err).ToNot(HaveOccurred())
    66  		var dialAddrCalled bool
    67  		dialAddr = func(_ context.Context, _ string, tlsConf *tls.Config, quicConf *quic.Config) (quic.EarlyConnection, error) {
    68  			Expect(quicConf.MaxIncomingStreams).To(Equal(defaultQuicConfig.MaxIncomingStreams))
    69  			Expect(tlsConf.NextProtos).To(Equal([]string{NextProtoH3}))
    70  			Expect(quicConf.Versions).To(Equal([]protocol.VersionNumber{protocol.Version1}))
    71  			dialAddrCalled = true
    72  			return nil, errors.New("test done")
    73  		}
    74  		client.RoundTripOpt(req, RoundTripOpt{})
    75  		Expect(dialAddrCalled).To(BeTrue())
    76  	})
    77  
    78  	It("adds the port to the hostname, if none is given", func() {
    79  		client, err := newClient("quic.clemente.io", nil, &roundTripperOpts{}, nil, nil)
    80  		Expect(err).ToNot(HaveOccurred())
    81  		var dialAddrCalled bool
    82  		dialAddr = func(_ context.Context, hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlyConnection, error) {
    83  			Expect(hostname).To(Equal("quic.clemente.io:443"))
    84  			dialAddrCalled = true
    85  			return nil, errors.New("test done")
    86  		}
    87  		req, err := http.NewRequest("GET", "https://quic.clemente.io:443", nil)
    88  		Expect(err).ToNot(HaveOccurred())
    89  		client.RoundTripOpt(req, RoundTripOpt{})
    90  		Expect(dialAddrCalled).To(BeTrue())
    91  	})
    92  
    93  	It("sets the ServerName in the tls.Config, if not set", func() {
    94  		const host = "foo.bar"
    95  		dialCalled := false
    96  		dialFunc := func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
    97  			Expect(tlsCfg.ServerName).To(Equal(host))
    98  			dialCalled = true
    99  			return nil, errors.New("test done")
   100  		}
   101  		client, err := newClient(host, nil, &roundTripperOpts{}, nil, dialFunc)
   102  		Expect(err).ToNot(HaveOccurred())
   103  		req, err := http.NewRequest("GET", "https://foo.bar", nil)
   104  		Expect(err).ToNot(HaveOccurred())
   105  		client.RoundTripOpt(req, RoundTripOpt{})
   106  		Expect(dialCalled).To(BeTrue())
   107  	})
   108  
   109  	It("uses the TLS config and QUIC config", func() {
   110  		tlsConf := &tls.Config{
   111  			ServerName: "foo.bar",
   112  			NextProtos: []string{"proto foo", "proto bar"},
   113  		}
   114  		quicConf := &quic.Config{MaxIdleTimeout: time.Nanosecond}
   115  		client, err := newClient("localhost:1337", tlsConf, &roundTripperOpts{}, quicConf, nil)
   116  		Expect(err).ToNot(HaveOccurred())
   117  		var dialAddrCalled bool
   118  		dialAddr = func(_ context.Context, host string, tlsConfP *tls.Config, quicConfP *quic.Config) (quic.EarlyConnection, error) {
   119  			Expect(host).To(Equal("localhost:1337"))
   120  			Expect(tlsConfP.ServerName).To(Equal(tlsConf.ServerName))
   121  			Expect(tlsConfP.NextProtos).To(Equal([]string{NextProtoH3}))
   122  			Expect(quicConfP.MaxIdleTimeout).To(Equal(quicConf.MaxIdleTimeout))
   123  			dialAddrCalled = true
   124  			return nil, errors.New("test done")
   125  		}
   126  		client.RoundTripOpt(req, RoundTripOpt{})
   127  		Expect(dialAddrCalled).To(BeTrue())
   128  		// make sure the original tls.Config was not modified
   129  		Expect(tlsConf.NextProtos).To(Equal([]string{"proto foo", "proto bar"}))
   130  	})
   131  
   132  	It("uses the custom dialer, if provided", func() {
   133  		testErr := errors.New("test done")
   134  		tlsConf := &tls.Config{ServerName: "foo.bar"}
   135  		quicConf := &quic.Config{MaxIdleTimeout: 1337 * time.Second}
   136  		ctx, cancel := context.WithTimeout(context.Background(), time.Hour)
   137  		defer cancel()
   138  		var dialerCalled bool
   139  		dialer := func(ctxP context.Context, address string, tlsConfP *tls.Config, quicConfP *quic.Config) (quic.EarlyConnection, error) {
   140  			Expect(ctxP).To(Equal(ctx))
   141  			Expect(address).To(Equal("localhost:1337"))
   142  			Expect(tlsConfP.ServerName).To(Equal("foo.bar"))
   143  			Expect(quicConfP.MaxIdleTimeout).To(Equal(quicConf.MaxIdleTimeout))
   144  			dialerCalled = true
   145  			return nil, testErr
   146  		}
   147  		client, err := newClient("localhost:1337", tlsConf, &roundTripperOpts{}, quicConf, dialer)
   148  		Expect(err).ToNot(HaveOccurred())
   149  		_, err = client.RoundTripOpt(req.WithContext(ctx), RoundTripOpt{})
   150  		Expect(err).To(MatchError(testErr))
   151  		Expect(dialerCalled).To(BeTrue())
   152  	})
   153  
   154  	It("enables HTTP/3 Datagrams", func() {
   155  		testErr := errors.New("handshake error")
   156  		client, err := newClient("localhost:1337", nil, &roundTripperOpts{EnableDatagram: true}, nil, nil)
   157  		Expect(err).ToNot(HaveOccurred())
   158  		dialAddr = func(_ context.Context, _ string, _ *tls.Config, quicConf *quic.Config) (quic.EarlyConnection, error) {
   159  			Expect(quicConf.EnableDatagrams).To(BeTrue())
   160  			return nil, testErr
   161  		}
   162  		_, err = client.RoundTripOpt(req, RoundTripOpt{})
   163  		Expect(err).To(MatchError(testErr))
   164  	})
   165  
   166  	It("errors when dialing fails", func() {
   167  		testErr := errors.New("handshake error")
   168  		client, err := newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
   169  		Expect(err).ToNot(HaveOccurred())
   170  		dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
   171  			return nil, testErr
   172  		}
   173  		_, err = client.RoundTripOpt(req, RoundTripOpt{})
   174  		Expect(err).To(MatchError(testErr))
   175  	})
   176  
   177  	It("closes correctly if connection was not created", func() {
   178  		client, err := newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
   179  		Expect(err).ToNot(HaveOccurred())
   180  		Expect(client.Close()).To(Succeed())
   181  	})
   182  
   183  	Context("validating the address", func() {
   184  		It("refuses to do requests for the wrong host", func() {
   185  			req, err := http.NewRequest("https", "https://quic.clemente.io:1336/foobar.html", nil)
   186  			Expect(err).ToNot(HaveOccurred())
   187  			_, err = cl.RoundTripOpt(req, RoundTripOpt{})
   188  			Expect(err).To(MatchError("http3 client BUG: RoundTripOpt called for the wrong client (expected quic.clemente.io:1337, got quic.clemente.io:1336)"))
   189  		})
   190  
   191  		It("allows requests using a different scheme", func() {
   192  			testErr := errors.New("handshake error")
   193  			req, err := http.NewRequest("masque", "masque://quic.clemente.io:1337/foobar.html", nil)
   194  			Expect(err).ToNot(HaveOccurred())
   195  			dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
   196  				return nil, testErr
   197  			}
   198  			_, err = cl.RoundTripOpt(req, RoundTripOpt{})
   199  			Expect(err).To(MatchError(testErr))
   200  		})
   201  	})
   202  
   203  	Context("hijacking bidirectional streams", func() {
   204  		var (
   205  			request              *http.Request
   206  			conn                 *mockquic.MockEarlyConnection
   207  			settingsFrameWritten chan struct{}
   208  		)
   209  		testDone := make(chan struct{})
   210  
   211  		BeforeEach(func() {
   212  			testDone = make(chan struct{})
   213  			settingsFrameWritten = make(chan struct{})
   214  			controlStr := mockquic.NewMockStream(mockCtrl)
   215  			controlStr.EXPECT().Write(gomock.Any()).Do(func(b []byte) {
   216  				defer GinkgoRecover()
   217  				close(settingsFrameWritten)
   218  			})
   219  			conn = mockquic.NewMockEarlyConnection(mockCtrl)
   220  			conn.EXPECT().OpenUniStream().Return(controlStr, nil)
   221  			conn.EXPECT().HandshakeComplete().Return(handshakeChan)
   222  			conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done"))
   223  			conn.EXPECT().AcceptUniStream(gomock.Any()).Return(nil, errors.New("done")).AnyTimes()
   224  			dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
   225  				return conn, nil
   226  			}
   227  			var err error
   228  			request, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil)
   229  			Expect(err).ToNot(HaveOccurred())
   230  		})
   231  
   232  		AfterEach(func() {
   233  			testDone <- struct{}{}
   234  			Eventually(settingsFrameWritten).Should(BeClosed())
   235  		})
   236  
   237  		It("hijacks a bidirectional stream of unknown frame type", func() {
   238  			frameTypeChan := make(chan FrameType, 1)
   239  			cl.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) {
   240  				Expect(e).ToNot(HaveOccurred())
   241  				frameTypeChan <- ft
   242  				return true, nil
   243  			}
   244  
   245  			buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41))
   246  			unknownStr := mockquic.NewMockStream(mockCtrl)
   247  			unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
   248  			conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil)
   249  			conn.EXPECT().AcceptStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.Stream, error) {
   250  				<-testDone
   251  				return nil, errors.New("test done")
   252  			})
   253  			_, err := cl.RoundTripOpt(request, RoundTripOpt{})
   254  			Expect(err).To(MatchError("done"))
   255  			Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41)))
   256  			time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
   257  		})
   258  
   259  		It("closes the connection when hijacker didn't hijack a bidirectional stream", func() {
   260  			frameTypeChan := make(chan FrameType, 1)
   261  			cl.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) {
   262  				Expect(e).ToNot(HaveOccurred())
   263  				frameTypeChan <- ft
   264  				return false, nil
   265  			}
   266  
   267  			buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41))
   268  			unknownStr := mockquic.NewMockStream(mockCtrl)
   269  			unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
   270  			conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil)
   271  			conn.EXPECT().AcceptStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.Stream, error) {
   272  				<-testDone
   273  				return nil, errors.New("test done")
   274  			})
   275  			conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), gomock.Any()).Return(nil).AnyTimes()
   276  			_, err := cl.RoundTripOpt(request, RoundTripOpt{})
   277  			Expect(err).To(MatchError("done"))
   278  			Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41)))
   279  		})
   280  
   281  		It("closes the connection when hijacker returned error", func() {
   282  			frameTypeChan := make(chan FrameType, 1)
   283  			cl.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) {
   284  				Expect(e).ToNot(HaveOccurred())
   285  				frameTypeChan <- ft
   286  				return false, errors.New("error in hijacker")
   287  			}
   288  
   289  			buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41))
   290  			unknownStr := mockquic.NewMockStream(mockCtrl)
   291  			unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
   292  			conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil)
   293  			conn.EXPECT().AcceptStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.Stream, error) {
   294  				<-testDone
   295  				return nil, errors.New("test done")
   296  			})
   297  			conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), gomock.Any()).Return(nil).AnyTimes()
   298  			_, err := cl.RoundTripOpt(request, RoundTripOpt{})
   299  			Expect(err).To(MatchError("done"))
   300  			Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41)))
   301  		})
   302  
   303  		It("handles errors that occur when reading the frame type", func() {
   304  			testErr := errors.New("test error")
   305  			unknownStr := mockquic.NewMockStream(mockCtrl)
   306  			done := make(chan struct{})
   307  			cl.opts.StreamHijacker = func(ft FrameType, c quic.Connection, str quic.Stream, e error) (hijacked bool, err error) {
   308  				defer close(done)
   309  				Expect(e).To(MatchError(testErr))
   310  				Expect(ft).To(BeZero())
   311  				Expect(str).To(Equal(unknownStr))
   312  				return false, nil
   313  			}
   314  
   315  			unknownStr.EXPECT().Read(gomock.Any()).Return(0, testErr).AnyTimes()
   316  			conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil)
   317  			conn.EXPECT().AcceptStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.Stream, error) {
   318  				<-testDone
   319  				return nil, errors.New("test done")
   320  			})
   321  			conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), gomock.Any()).Return(nil).AnyTimes()
   322  			_, err := cl.RoundTripOpt(request, RoundTripOpt{})
   323  			Expect(err).To(MatchError("done"))
   324  			Eventually(done).Should(BeClosed())
   325  			time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
   326  		})
   327  	})
   328  
   329  	Context("hijacking unidirectional streams", func() {
   330  		var (
   331  			req                  *http.Request
   332  			conn                 *mockquic.MockEarlyConnection
   333  			settingsFrameWritten chan struct{}
   334  		)
   335  		testDone := make(chan struct{})
   336  
   337  		BeforeEach(func() {
   338  			testDone = make(chan struct{})
   339  			settingsFrameWritten = make(chan struct{})
   340  			controlStr := mockquic.NewMockStream(mockCtrl)
   341  			controlStr.EXPECT().Write(gomock.Any()).Do(func(b []byte) {
   342  				defer GinkgoRecover()
   343  				close(settingsFrameWritten)
   344  			})
   345  			conn = mockquic.NewMockEarlyConnection(mockCtrl)
   346  			conn.EXPECT().OpenUniStream().Return(controlStr, nil)
   347  			conn.EXPECT().HandshakeComplete().Return(handshakeChan)
   348  			conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done"))
   349  			dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
   350  				return conn, nil
   351  			}
   352  			var err error
   353  			req, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil)
   354  			Expect(err).ToNot(HaveOccurred())
   355  		})
   356  
   357  		AfterEach(func() {
   358  			testDone <- struct{}{}
   359  			Eventually(settingsFrameWritten).Should(BeClosed())
   360  		})
   361  
   362  		It("hijacks an unidirectional stream of unknown stream type", func() {
   363  			streamTypeChan := make(chan StreamType, 1)
   364  			cl.opts.UniStreamHijacker = func(st StreamType, _ quic.Connection, _ quic.ReceiveStream, err error) bool {
   365  				Expect(err).ToNot(HaveOccurred())
   366  				streamTypeChan <- st
   367  				return true
   368  			}
   369  
   370  			buf := bytes.NewBuffer(quicvarint.Append(nil, 0x54))
   371  			unknownStr := mockquic.NewMockStream(mockCtrl)
   372  			unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
   373  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   374  				return unknownStr, nil
   375  			})
   376  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   377  				<-testDone
   378  				return nil, errors.New("test done")
   379  			})
   380  			_, err := cl.RoundTripOpt(req, RoundTripOpt{})
   381  			Expect(err).To(MatchError("done"))
   382  			Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54)))
   383  			time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
   384  		})
   385  
   386  		It("handles errors that occur when reading the stream type", func() {
   387  			testErr := errors.New("test error")
   388  			done := make(chan struct{})
   389  			unknownStr := mockquic.NewMockStream(mockCtrl)
   390  			cl.opts.UniStreamHijacker = func(st StreamType, _ quic.Connection, str quic.ReceiveStream, err error) bool {
   391  				defer close(done)
   392  				Expect(st).To(BeZero())
   393  				Expect(str).To(Equal(unknownStr))
   394  				Expect(err).To(MatchError(testErr))
   395  				return true
   396  			}
   397  
   398  			unknownStr.EXPECT().Read(gomock.Any()).Return(0, testErr)
   399  			conn.EXPECT().AcceptUniStream(gomock.Any()).Return(unknownStr, nil)
   400  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   401  				<-testDone
   402  				return nil, errors.New("test done")
   403  			})
   404  			_, err := cl.RoundTripOpt(req, RoundTripOpt{})
   405  			Expect(err).To(MatchError("done"))
   406  			Eventually(done).Should(BeClosed())
   407  			time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
   408  		})
   409  
   410  		It("cancels reading when hijacker didn't hijack an unidirectional stream", func() {
   411  			streamTypeChan := make(chan StreamType, 1)
   412  			cl.opts.UniStreamHijacker = func(st StreamType, _ quic.Connection, _ quic.ReceiveStream, err error) bool {
   413  				Expect(err).ToNot(HaveOccurred())
   414  				streamTypeChan <- st
   415  				return false
   416  			}
   417  
   418  			buf := bytes.NewBuffer(quicvarint.Append(nil, 0x54))
   419  			unknownStr := mockquic.NewMockStream(mockCtrl)
   420  			unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
   421  			unknownStr.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError))
   422  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   423  				return unknownStr, nil
   424  			})
   425  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   426  				<-testDone
   427  				return nil, errors.New("test done")
   428  			})
   429  			_, err := cl.RoundTripOpt(req, RoundTripOpt{})
   430  			Expect(err).To(MatchError("done"))
   431  			Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54)))
   432  			time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
   433  		})
   434  	})
   435  
   436  	Context("control stream handling", func() {
   437  		var (
   438  			req                  *http.Request
   439  			conn                 *mockquic.MockEarlyConnection
   440  			settingsFrameWritten chan struct{}
   441  		)
   442  		testDone := make(chan struct{})
   443  
   444  		BeforeEach(func() {
   445  			settingsFrameWritten = make(chan struct{})
   446  			controlStr := mockquic.NewMockStream(mockCtrl)
   447  			controlStr.EXPECT().Write(gomock.Any()).Do(func(b []byte) {
   448  				defer GinkgoRecover()
   449  				close(settingsFrameWritten)
   450  			})
   451  			conn = mockquic.NewMockEarlyConnection(mockCtrl)
   452  			conn.EXPECT().OpenUniStream().Return(controlStr, nil)
   453  			conn.EXPECT().HandshakeComplete().Return(handshakeChan)
   454  			conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done"))
   455  			dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
   456  				return conn, nil
   457  			}
   458  			var err error
   459  			req, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil)
   460  			Expect(err).ToNot(HaveOccurred())
   461  		})
   462  
   463  		AfterEach(func() {
   464  			testDone <- struct{}{}
   465  			Eventually(settingsFrameWritten).Should(BeClosed())
   466  		})
   467  
   468  		It("parses the SETTINGS frame", func() {
   469  			b := quicvarint.Append(nil, streamTypeControlStream)
   470  			b = (&settingsFrame{}).Append(b)
   471  			r := bytes.NewReader(b)
   472  			controlStr := mockquic.NewMockStream(mockCtrl)
   473  			controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes()
   474  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   475  				return controlStr, nil
   476  			})
   477  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   478  				<-testDone
   479  				return nil, errors.New("test done")
   480  			})
   481  			_, err := cl.RoundTripOpt(req, RoundTripOpt{})
   482  			Expect(err).To(MatchError("done"))
   483  			time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
   484  		})
   485  
   486  		for _, t := range []uint64{streamTypeQPACKEncoderStream, streamTypeQPACKDecoderStream} {
   487  			streamType := t
   488  			name := "encoder"
   489  			if streamType == streamTypeQPACKDecoderStream {
   490  				name = "decoder"
   491  			}
   492  
   493  			It(fmt.Sprintf("ignores the QPACK %s streams", name), func() {
   494  				buf := bytes.NewBuffer(quicvarint.Append(nil, streamType))
   495  				str := mockquic.NewMockStream(mockCtrl)
   496  				str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
   497  
   498  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   499  					return str, nil
   500  				})
   501  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   502  					<-testDone
   503  					return nil, errors.New("test done")
   504  				})
   505  				_, err := cl.RoundTripOpt(req, RoundTripOpt{})
   506  				Expect(err).To(MatchError("done"))
   507  				time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to str.CancelRead
   508  			})
   509  		}
   510  
   511  		It("resets streams Other than the control stream and the QPACK streams", func() {
   512  			buf := bytes.NewBuffer(quicvarint.Append(nil, 0x1337))
   513  			str := mockquic.NewMockStream(mockCtrl)
   514  			str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
   515  			done := make(chan struct{})
   516  			str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError)).Do(func(code quic.StreamErrorCode) {
   517  				close(done)
   518  			})
   519  
   520  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   521  				return str, nil
   522  			})
   523  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   524  				<-testDone
   525  				return nil, errors.New("test done")
   526  			})
   527  			_, err := cl.RoundTripOpt(req, RoundTripOpt{})
   528  			Expect(err).To(MatchError("done"))
   529  			Eventually(done).Should(BeClosed())
   530  		})
   531  
   532  		It("errors when the first frame on the control stream is not a SETTINGS frame", func() {
   533  			b := quicvarint.Append(nil, streamTypeControlStream)
   534  			b = (&dataFrame{}).Append(b)
   535  			r := bytes.NewReader(b)
   536  			controlStr := mockquic.NewMockStream(mockCtrl)
   537  			controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes()
   538  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   539  				return controlStr, nil
   540  			})
   541  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   542  				<-testDone
   543  				return nil, errors.New("test done")
   544  			})
   545  			done := make(chan struct{})
   546  			conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, _ string) {
   547  				defer GinkgoRecover()
   548  				Expect(code).To(BeEquivalentTo(ErrCodeMissingSettings))
   549  				close(done)
   550  			})
   551  			_, err := cl.RoundTripOpt(req, RoundTripOpt{})
   552  			Expect(err).To(MatchError("done"))
   553  			Eventually(done).Should(BeClosed())
   554  		})
   555  
   556  		It("errors when parsing the frame on the control stream fails", func() {
   557  			b := quicvarint.Append(nil, streamTypeControlStream)
   558  			b = (&settingsFrame{}).Append(b)
   559  			r := bytes.NewReader(b[:len(b)-1])
   560  			controlStr := mockquic.NewMockStream(mockCtrl)
   561  			controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes()
   562  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   563  				return controlStr, nil
   564  			})
   565  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   566  				<-testDone
   567  				return nil, errors.New("test done")
   568  			})
   569  			done := make(chan struct{})
   570  			conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, _ string) {
   571  				defer GinkgoRecover()
   572  				Expect(code).To(BeEquivalentTo(ErrCodeFrameError))
   573  				close(done)
   574  			})
   575  			_, err := cl.RoundTripOpt(req, RoundTripOpt{})
   576  			Expect(err).To(MatchError("done"))
   577  			Eventually(done).Should(BeClosed())
   578  		})
   579  
   580  		It("errors when parsing the server opens a push stream", func() {
   581  			buf := bytes.NewBuffer(quicvarint.Append(nil, streamTypePushStream))
   582  			controlStr := mockquic.NewMockStream(mockCtrl)
   583  			controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
   584  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   585  				return controlStr, nil
   586  			})
   587  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   588  				<-testDone
   589  				return nil, errors.New("test done")
   590  			})
   591  			done := make(chan struct{})
   592  			conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, _ string) {
   593  				defer GinkgoRecover()
   594  				Expect(code).To(BeEquivalentTo(ErrCodeIDError))
   595  				close(done)
   596  			})
   597  			_, err := cl.RoundTripOpt(req, RoundTripOpt{})
   598  			Expect(err).To(MatchError("done"))
   599  			Eventually(done).Should(BeClosed())
   600  		})
   601  
   602  		It("errors when the server advertises datagram support (and we enabled support for it)", func() {
   603  			cl.opts.EnableDatagram = true
   604  			b := quicvarint.Append(nil, streamTypeControlStream)
   605  			b = (&settingsFrame{Datagram: true}).Append(b)
   606  			r := bytes.NewReader(b)
   607  			controlStr := mockquic.NewMockStream(mockCtrl)
   608  			controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes()
   609  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   610  				return controlStr, nil
   611  			})
   612  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   613  				<-testDone
   614  				return nil, errors.New("test done")
   615  			})
   616  			conn.EXPECT().ConnectionState().Return(quic.ConnectionState{SupportsDatagrams: false})
   617  			done := make(chan struct{})
   618  			conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, reason string) {
   619  				defer GinkgoRecover()
   620  				Expect(code).To(BeEquivalentTo(ErrCodeSettingsError))
   621  				Expect(reason).To(Equal("missing QUIC Datagram support"))
   622  				close(done)
   623  			})
   624  			_, err := cl.RoundTripOpt(req, RoundTripOpt{})
   625  			Expect(err).To(MatchError("done"))
   626  			Eventually(done).Should(BeClosed())
   627  		})
   628  	})
   629  
   630  	Context("Doing requests", func() {
   631  		var (
   632  			req                  *http.Request
   633  			str                  *mockquic.MockStream
   634  			conn                 *mockquic.MockEarlyConnection
   635  			settingsFrameWritten chan struct{}
   636  		)
   637  		testDone := make(chan struct{})
   638  
   639  		getHeadersFrame := func(headers map[string]string) []byte {
   640  			headerBuf := &bytes.Buffer{}
   641  			enc := qpack.NewEncoder(headerBuf)
   642  			for name, value := range headers {
   643  				Expect(enc.WriteField(qpack.HeaderField{Name: name, Value: value})).To(Succeed())
   644  			}
   645  			Expect(enc.Close()).To(Succeed())
   646  			b := (&headersFrame{Length: uint64(headerBuf.Len())}).Append(nil)
   647  			b = append(b, headerBuf.Bytes()...)
   648  			return b
   649  		}
   650  
   651  		decodeHeader := func(str io.Reader) map[string]string {
   652  			fields := make(map[string]string)
   653  			decoder := qpack.NewDecoder(nil)
   654  
   655  			frame, err := parseNextFrame(str, nil)
   656  			ExpectWithOffset(1, err).ToNot(HaveOccurred())
   657  			ExpectWithOffset(1, frame).To(BeAssignableToTypeOf(&headersFrame{}))
   658  			headersFrame := frame.(*headersFrame)
   659  			data := make([]byte, headersFrame.Length)
   660  			_, err = io.ReadFull(str, data)
   661  			ExpectWithOffset(1, err).ToNot(HaveOccurred())
   662  			hfs, err := decoder.DecodeFull(data)
   663  			ExpectWithOffset(1, err).ToNot(HaveOccurred())
   664  			for _, p := range hfs {
   665  				fields[p.Name] = p.Value
   666  			}
   667  			return fields
   668  		}
   669  
   670  		getResponse := func(status int) []byte {
   671  			buf := &bytes.Buffer{}
   672  			rstr := mockquic.NewMockStream(mockCtrl)
   673  			rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes()
   674  			rw := newResponseWriter(rstr, nil, utils.DefaultLogger)
   675  			rw.WriteHeader(status)
   676  			rw.Flush()
   677  			return buf.Bytes()
   678  		}
   679  
   680  		BeforeEach(func() {
   681  			settingsFrameWritten = make(chan struct{})
   682  			controlStr := mockquic.NewMockStream(mockCtrl)
   683  			controlStr.EXPECT().Write(gomock.Any()).Do(func(b []byte) {
   684  				defer GinkgoRecover()
   685  				r := bytes.NewReader(b)
   686  				streamType, err := quicvarint.Read(r)
   687  				Expect(err).ToNot(HaveOccurred())
   688  				Expect(streamType).To(BeEquivalentTo(streamTypeControlStream))
   689  				close(settingsFrameWritten)
   690  			}) // SETTINGS frame
   691  			str = mockquic.NewMockStream(mockCtrl)
   692  			conn = mockquic.NewMockEarlyConnection(mockCtrl)
   693  			conn.EXPECT().OpenUniStream().Return(controlStr, nil)
   694  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   695  				<-testDone
   696  				return nil, errors.New("test done")
   697  			})
   698  			dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
   699  				return conn, nil
   700  			}
   701  			var err error
   702  			req, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil)
   703  			Expect(err).ToNot(HaveOccurred())
   704  		})
   705  
   706  		AfterEach(func() {
   707  			testDone <- struct{}{}
   708  			Eventually(settingsFrameWritten).Should(BeClosed())
   709  		})
   710  
   711  		It("errors if it can't open a stream", func() {
   712  			testErr := errors.New("stream open error")
   713  			conn.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr)
   714  			conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).MaxTimes(1)
   715  			conn.EXPECT().HandshakeComplete().Return(handshakeChan)
   716  			_, err := cl.RoundTripOpt(req, RoundTripOpt{})
   717  			Expect(err).To(MatchError(testErr))
   718  		})
   719  
   720  		It("performs a 0-RTT request", func() {
   721  			testErr := errors.New("stream open error")
   722  			req.Method = MethodGet0RTT
   723  			// don't EXPECT any calls to HandshakeComplete()
   724  			conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
   725  			buf := &bytes.Buffer{}
   726  			str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write).AnyTimes()
   727  			str.EXPECT().Close()
   728  			str.EXPECT().CancelWrite(gomock.Any())
   729  			str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) {
   730  				return 0, testErr
   731  			})
   732  			_, err := cl.RoundTripOpt(req, RoundTripOpt{})
   733  			Expect(err).To(MatchError(testErr))
   734  			Expect(decodeHeader(buf)).To(HaveKeyWithValue(":method", "GET"))
   735  		})
   736  
   737  		It("returns a response", func() {
   738  			rspBuf := bytes.NewBuffer(getResponse(418))
   739  			gomock.InOrder(
   740  				conn.EXPECT().HandshakeComplete().Return(handshakeChan),
   741  				conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil),
   742  				conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}),
   743  			)
   744  			str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })
   745  			str.EXPECT().Close()
   746  			str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes()
   747  			rsp, err := cl.RoundTripOpt(req, RoundTripOpt{})
   748  			Expect(err).ToNot(HaveOccurred())
   749  			Expect(rsp.Proto).To(Equal("HTTP/3.0"))
   750  			Expect(rsp.ProtoMajor).To(Equal(3))
   751  			Expect(rsp.StatusCode).To(Equal(418))
   752  			Expect(rsp.Request).ToNot(BeNil())
   753  		})
   754  
   755  		It("doesn't close the request stream, with DontCloseRequestStream set", func() {
   756  			rspBuf := bytes.NewBuffer(getResponse(418))
   757  			gomock.InOrder(
   758  				conn.EXPECT().HandshakeComplete().Return(handshakeChan),
   759  				conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil),
   760  				conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}),
   761  			)
   762  			str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })
   763  			str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes()
   764  			rsp, err := cl.RoundTripOpt(req, RoundTripOpt{DontCloseRequestStream: true})
   765  			Expect(err).ToNot(HaveOccurred())
   766  			Expect(rsp.Proto).To(Equal("HTTP/3.0"))
   767  			Expect(rsp.ProtoMajor).To(Equal(3))
   768  			Expect(rsp.StatusCode).To(Equal(418))
   769  		})
   770  
   771  		Context("requests containing a Body", func() {
   772  			var strBuf *bytes.Buffer
   773  
   774  			BeforeEach(func() {
   775  				strBuf = &bytes.Buffer{}
   776  				gomock.InOrder(
   777  					conn.EXPECT().HandshakeComplete().Return(handshakeChan),
   778  					conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil),
   779  				)
   780  				body := &mockBody{}
   781  				body.SetData([]byte("request body"))
   782  				var err error
   783  				req, err = http.NewRequest("POST", "https://quic.clemente.io:1337/upload", body)
   784  				Expect(err).ToNot(HaveOccurred())
   785  				str.EXPECT().Write(gomock.Any()).DoAndReturn(strBuf.Write).AnyTimes()
   786  			})
   787  
   788  			It("sends a request", func() {
   789  				done := make(chan struct{})
   790  				gomock.InOrder(
   791  					str.EXPECT().Close().Do(func() { close(done) }),
   792  					str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1), // when reading the response errors
   793  				)
   794  				// the response body is sent asynchronously, while already reading the response
   795  				str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) {
   796  					<-done
   797  					return 0, errors.New("test done")
   798  				})
   799  				_, err := cl.RoundTripOpt(req, RoundTripOpt{})
   800  				Expect(err).To(MatchError("test done"))
   801  				hfs := decodeHeader(strBuf)
   802  				Expect(hfs).To(HaveKeyWithValue(":method", "POST"))
   803  				Expect(hfs).To(HaveKeyWithValue(":path", "/upload"))
   804  			})
   805  
   806  			It("returns the error that occurred when reading the body", func() {
   807  				req.Body.(*mockBody).readErr = errors.New("testErr")
   808  				done := make(chan struct{})
   809  				gomock.InOrder(
   810  					str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled)).Do(func(quic.StreamErrorCode) {
   811  						close(done)
   812  					}),
   813  					str.EXPECT().CancelWrite(gomock.Any()),
   814  				)
   815  
   816  				// the response body is sent asynchronously, while already reading the response
   817  				str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) {
   818  					<-done
   819  					return 0, errors.New("test done")
   820  				})
   821  				closed := make(chan struct{})
   822  				str.EXPECT().Close().Do(func() { close(closed) })
   823  				_, err := cl.RoundTripOpt(req, RoundTripOpt{})
   824  				Expect(err).To(MatchError("test done"))
   825  				Eventually(closed).Should(BeClosed())
   826  			})
   827  
   828  			It("sets the Content-Length", func() {
   829  				done := make(chan struct{})
   830  				b := getHeadersFrame(map[string]string{
   831  					":status":        "200",
   832  					"Content-Length": "1337",
   833  				})
   834  				b = (&dataFrame{Length: 0x6}).Append(b)
   835  				b = append(b, []byte("foobar")...)
   836  				r := bytes.NewReader(b)
   837  				str.EXPECT().Close().Do(func() { close(done) })
   838  				conn.EXPECT().ConnectionState().Return(quic.ConnectionState{})
   839  				str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1) // when reading the response errors
   840  				// the response body is sent asynchronously, while already reading the response
   841  				str.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes()
   842  				req, err := cl.RoundTripOpt(req, RoundTripOpt{})
   843  				Expect(err).ToNot(HaveOccurred())
   844  				Expect(req.ContentLength).To(BeEquivalentTo(1337))
   845  				Eventually(done).Should(BeClosed())
   846  			})
   847  
   848  			It("closes the connection when the first frame is not a HEADERS frame", func() {
   849  				b := (&dataFrame{Length: 0x42}).Append(nil)
   850  				conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), gomock.Any())
   851  				closed := make(chan struct{})
   852  				r := bytes.NewReader(b)
   853  				str.EXPECT().Close().Do(func() { close(closed) })
   854  				str.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes()
   855  				_, err := cl.RoundTripOpt(req, RoundTripOpt{})
   856  				Expect(err).To(MatchError("expected first frame to be a HEADERS frame"))
   857  				Eventually(closed).Should(BeClosed())
   858  			})
   859  
   860  			It("cancels the stream when the HEADERS frame is too large", func() {
   861  				b := (&headersFrame{Length: 1338}).Append(nil)
   862  				r := bytes.NewReader(b)
   863  				str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeFrameError))
   864  				closed := make(chan struct{})
   865  				str.EXPECT().Close().Do(func() { close(closed) })
   866  				str.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes()
   867  				_, err := cl.RoundTripOpt(req, RoundTripOpt{})
   868  				Expect(err).To(MatchError("HEADERS frame too large: 1338 bytes (max: 1337)"))
   869  				Eventually(closed).Should(BeClosed())
   870  			})
   871  		})
   872  
   873  		Context("request cancellations", func() {
   874  			for _, dontClose := range []bool{false, true} {
   875  				dontClose := dontClose
   876  
   877  				Context(fmt.Sprintf("with DontCloseRequestStream: %t", dontClose), func() {
   878  					roundTripOpt := RoundTripOpt{DontCloseRequestStream: dontClose}
   879  
   880  					It("cancels a request while waiting for the handshake to complete", func() {
   881  						ctx, cancel := context.WithCancel(context.Background())
   882  						req := req.WithContext(ctx)
   883  						conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
   884  
   885  						errChan := make(chan error)
   886  						go func() {
   887  							_, err := cl.RoundTripOpt(req, roundTripOpt)
   888  							errChan <- err
   889  						}()
   890  						Consistently(errChan).ShouldNot(Receive())
   891  						cancel()
   892  						Eventually(errChan).Should(Receive(MatchError("context canceled")))
   893  					})
   894  
   895  					It("cancels a request while the request is still in flight", func() {
   896  						ctx, cancel := context.WithCancel(context.Background())
   897  						req := req.WithContext(ctx)
   898  						conn.EXPECT().HandshakeComplete().Return(handshakeChan)
   899  						conn.EXPECT().OpenStreamSync(ctx).Return(str, nil)
   900  						buf := &bytes.Buffer{}
   901  						str.EXPECT().Close().MaxTimes(1)
   902  
   903  						str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write)
   904  
   905  						done := make(chan struct{})
   906  						canceled := make(chan struct{})
   907  						gomock.InOrder(
   908  							str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled)).Do(func(quic.StreamErrorCode) { close(canceled) }),
   909  							str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeRequestCanceled)).Do(func(quic.StreamErrorCode) { close(done) }),
   910  						)
   911  						str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1)
   912  						str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) {
   913  							cancel()
   914  							<-canceled
   915  							return 0, errors.New("test done")
   916  						})
   917  						_, err := cl.RoundTripOpt(req, roundTripOpt)
   918  						Expect(err).To(MatchError("test done"))
   919  						Eventually(done).Should(BeClosed())
   920  					})
   921  				})
   922  			}
   923  
   924  			It("cancels a request after the response arrived", func() {
   925  				rspBuf := bytes.NewBuffer(getResponse(404))
   926  
   927  				ctx, cancel := context.WithCancel(context.Background())
   928  				req := req.WithContext(ctx)
   929  				conn.EXPECT().HandshakeComplete().Return(handshakeChan)
   930  				conn.EXPECT().OpenStreamSync(ctx).Return(str, nil)
   931  				conn.EXPECT().ConnectionState().Return(quic.ConnectionState{})
   932  				buf := &bytes.Buffer{}
   933  				str.EXPECT().Close().MaxTimes(1)
   934  
   935  				done := make(chan struct{})
   936  				str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write)
   937  				str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes()
   938  				str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled))
   939  				str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeRequestCanceled)).Do(func(quic.StreamErrorCode) { close(done) })
   940  				_, err := cl.RoundTripOpt(req, RoundTripOpt{})
   941  				Expect(err).ToNot(HaveOccurred())
   942  				cancel()
   943  				Eventually(done).Should(BeClosed())
   944  			})
   945  		})
   946  
   947  		Context("gzip compression", func() {
   948  			BeforeEach(func() {
   949  				conn.EXPECT().HandshakeComplete().Return(handshakeChan)
   950  			})
   951  
   952  			It("adds the gzip header to requests", func() {
   953  				conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
   954  				buf := &bytes.Buffer{}
   955  				str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write)
   956  				gomock.InOrder(
   957  					str.EXPECT().Close(),
   958  					str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1), // when the Read errors
   959  				)
   960  				str.EXPECT().Read(gomock.Any()).Return(0, errors.New("test done"))
   961  				_, err := cl.RoundTripOpt(req, RoundTripOpt{})
   962  				Expect(err).To(MatchError("test done"))
   963  				hfs := decodeHeader(buf)
   964  				Expect(hfs).To(HaveKeyWithValue("accept-encoding", "gzip"))
   965  			})
   966  
   967  			It("doesn't add gzip if the header disable it", func() {
   968  				client, err := newClient("quic.clemente.io:1337", nil, &roundTripperOpts{DisableCompression: true}, nil, nil)
   969  				Expect(err).ToNot(HaveOccurred())
   970  				conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
   971  				buf := &bytes.Buffer{}
   972  				str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write)
   973  				gomock.InOrder(
   974  					str.EXPECT().Close(),
   975  					str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1), // when the Read errors
   976  				)
   977  				str.EXPECT().Read(gomock.Any()).Return(0, errors.New("test done"))
   978  				_, err = client.RoundTripOpt(req, RoundTripOpt{})
   979  				Expect(err).To(MatchError("test done"))
   980  				hfs := decodeHeader(buf)
   981  				Expect(hfs).ToNot(HaveKey("accept-encoding"))
   982  			})
   983  
   984  			It("decompresses the response", func() {
   985  				conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
   986  				conn.EXPECT().ConnectionState().Return(quic.ConnectionState{})
   987  				buf := &bytes.Buffer{}
   988  				rstr := mockquic.NewMockStream(mockCtrl)
   989  				rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes()
   990  				rw := newResponseWriter(rstr, nil, utils.DefaultLogger)
   991  				rw.Header().Set("Content-Encoding", "gzip")
   992  				gz := gzip.NewWriter(rw)
   993  				gz.Write([]byte("gzipped response"))
   994  				gz.Close()
   995  				rw.Flush()
   996  				str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })
   997  				str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
   998  				str.EXPECT().Close()
   999  
  1000  				rsp, err := cl.RoundTripOpt(req, RoundTripOpt{})
  1001  				Expect(err).ToNot(HaveOccurred())
  1002  				data, err := io.ReadAll(rsp.Body)
  1003  				Expect(err).ToNot(HaveOccurred())
  1004  				Expect(rsp.ContentLength).To(BeEquivalentTo(-1))
  1005  				Expect(string(data)).To(Equal("gzipped response"))
  1006  				Expect(rsp.Header.Get("Content-Encoding")).To(BeEmpty())
  1007  				Expect(rsp.Uncompressed).To(BeTrue())
  1008  			})
  1009  
  1010  			It("only decompresses the response if the response contains the right content-encoding header", func() {
  1011  				conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
  1012  				conn.EXPECT().ConnectionState().Return(quic.ConnectionState{})
  1013  				buf := &bytes.Buffer{}
  1014  				rstr := mockquic.NewMockStream(mockCtrl)
  1015  				rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes()
  1016  				rw := newResponseWriter(rstr, nil, utils.DefaultLogger)
  1017  				rw.Write([]byte("not gzipped"))
  1018  				rw.Flush()
  1019  				str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })
  1020  				str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
  1021  				str.EXPECT().Close()
  1022  
  1023  				rsp, err := cl.RoundTripOpt(req, RoundTripOpt{})
  1024  				Expect(err).ToNot(HaveOccurred())
  1025  				data, err := io.ReadAll(rsp.Body)
  1026  				Expect(err).ToNot(HaveOccurred())
  1027  				Expect(string(data)).To(Equal("not gzipped"))
  1028  				Expect(rsp.Header.Get("Content-Encoding")).To(BeEmpty())
  1029  			})
  1030  		})
  1031  	})
  1032  })