github.com/MerlinKodo/quic-go@v0.39.2/http3/client_test.go (about)

     1  package http3
     2  
     3  import (
     4  	"bytes"
     5  	"compress/gzip"
     6  	"context"
     7  	"crypto/tls"
     8  	"errors"
     9  	"fmt"
    10  	"io"
    11  	"net/http"
    12  	"sync"
    13  	"time"
    14  
    15  	"github.com/MerlinKodo/quic-go"
    16  	mockquic "github.com/MerlinKodo/quic-go/internal/mocks/quic"
    17  	"github.com/MerlinKodo/quic-go/internal/protocol"
    18  	"github.com/MerlinKodo/quic-go/internal/utils"
    19  	"github.com/MerlinKodo/quic-go/quicvarint"
    20  
    21  	"github.com/quic-go/qpack"
    22  
    23  	. "github.com/onsi/ginkgo/v2"
    24  	. "github.com/onsi/gomega"
    25  	"go.uber.org/mock/gomock"
    26  )
    27  
    28  var _ = Describe("Client", func() {
    29  	var (
    30  		cl            *client
    31  		req           *http.Request
    32  		origDialAddr  = dialAddr
    33  		handshakeChan <-chan struct{} // a closed chan
    34  	)
    35  
    36  	BeforeEach(func() {
    37  		origDialAddr = dialAddr
    38  		hostname := "quic.clemente.io:1337"
    39  		c, err := newClient(hostname, nil, &roundTripperOpts{MaxHeaderBytes: 1337}, nil, nil)
    40  		Expect(err).ToNot(HaveOccurred())
    41  		cl = c.(*client)
    42  		Expect(cl.hostname).To(Equal(hostname))
    43  
    44  		req, err = http.NewRequest("GET", "https://localhost:1337", nil)
    45  		Expect(err).ToNot(HaveOccurred())
    46  
    47  		ch := make(chan struct{})
    48  		close(ch)
    49  		handshakeChan = ch
    50  	})
    51  
    52  	AfterEach(func() {
    53  		dialAddr = origDialAddr
    54  	})
    55  
    56  	It("rejects quic.Configs that allow multiple QUIC versions", func() {
    57  		qconf := &quic.Config{
    58  			Versions: []quic.VersionNumber{protocol.Version2, protocol.Version1},
    59  		}
    60  		_, err := newClient("localhost:1337", nil, &roundTripperOpts{}, qconf, nil)
    61  		Expect(err).To(MatchError("can only use a single QUIC version for dialing a HTTP/3 connection"))
    62  	})
    63  
    64  	It("uses the default QUIC and TLS config if none is give", func() {
    65  		client, err := newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
    66  		Expect(err).ToNot(HaveOccurred())
    67  		var dialAddrCalled bool
    68  		dialAddr = func(_ context.Context, _ string, tlsConf *tls.Config, quicConf *quic.Config) (quic.EarlyConnection, error) {
    69  			Expect(quicConf.MaxIncomingStreams).To(Equal(defaultQuicConfig.MaxIncomingStreams))
    70  			Expect(tlsConf.NextProtos).To(Equal([]string{NextProtoH3}))
    71  			Expect(quicConf.Versions).To(Equal([]protocol.VersionNumber{protocol.Version1}))
    72  			dialAddrCalled = true
    73  			return nil, errors.New("test done")
    74  		}
    75  		client.RoundTripOpt(req, RoundTripOpt{})
    76  		Expect(dialAddrCalled).To(BeTrue())
    77  	})
    78  
    79  	It("adds the port to the hostname, if none is given", func() {
    80  		client, err := newClient("quic.clemente.io", nil, &roundTripperOpts{}, nil, nil)
    81  		Expect(err).ToNot(HaveOccurred())
    82  		var dialAddrCalled bool
    83  		dialAddr = func(_ context.Context, hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlyConnection, error) {
    84  			Expect(hostname).To(Equal("quic.clemente.io:443"))
    85  			dialAddrCalled = true
    86  			return nil, errors.New("test done")
    87  		}
    88  		req, err := http.NewRequest("GET", "https://quic.clemente.io:443", nil)
    89  		Expect(err).ToNot(HaveOccurred())
    90  		client.RoundTripOpt(req, RoundTripOpt{})
    91  		Expect(dialAddrCalled).To(BeTrue())
    92  	})
    93  
    94  	It("sets the ServerName in the tls.Config, if not set", func() {
    95  		const host = "foo.bar"
    96  		dialCalled := false
    97  		dialFunc := func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
    98  			Expect(tlsCfg.ServerName).To(Equal(host))
    99  			dialCalled = true
   100  			return nil, errors.New("test done")
   101  		}
   102  		client, err := newClient(host, nil, &roundTripperOpts{}, nil, dialFunc)
   103  		Expect(err).ToNot(HaveOccurred())
   104  		req, err := http.NewRequest("GET", "https://foo.bar", nil)
   105  		Expect(err).ToNot(HaveOccurred())
   106  		client.RoundTripOpt(req, RoundTripOpt{})
   107  		Expect(dialCalled).To(BeTrue())
   108  	})
   109  
   110  	It("uses the TLS config and QUIC config", func() {
   111  		tlsConf := &tls.Config{
   112  			ServerName: "foo.bar",
   113  			NextProtos: []string{"proto foo", "proto bar"},
   114  		}
   115  		quicConf := &quic.Config{MaxIdleTimeout: time.Nanosecond}
   116  		client, err := newClient("localhost:1337", tlsConf, &roundTripperOpts{}, quicConf, nil)
   117  		Expect(err).ToNot(HaveOccurred())
   118  		var dialAddrCalled bool
   119  		dialAddr = func(_ context.Context, host string, tlsConfP *tls.Config, quicConfP *quic.Config) (quic.EarlyConnection, error) {
   120  			Expect(host).To(Equal("localhost:1337"))
   121  			Expect(tlsConfP.ServerName).To(Equal(tlsConf.ServerName))
   122  			Expect(tlsConfP.NextProtos).To(Equal([]string{NextProtoH3}))
   123  			Expect(quicConfP.MaxIdleTimeout).To(Equal(quicConf.MaxIdleTimeout))
   124  			dialAddrCalled = true
   125  			return nil, errors.New("test done")
   126  		}
   127  		client.RoundTripOpt(req, RoundTripOpt{})
   128  		Expect(dialAddrCalled).To(BeTrue())
   129  		// make sure the original tls.Config was not modified
   130  		Expect(tlsConf.NextProtos).To(Equal([]string{"proto foo", "proto bar"}))
   131  	})
   132  
   133  	It("uses the custom dialer, if provided", func() {
   134  		testErr := errors.New("test done")
   135  		tlsConf := &tls.Config{ServerName: "foo.bar"}
   136  		quicConf := &quic.Config{MaxIdleTimeout: 1337 * time.Second}
   137  		ctx, cancel := context.WithTimeout(context.Background(), time.Hour)
   138  		defer cancel()
   139  		var dialerCalled bool
   140  		dialer := func(ctxP context.Context, address string, tlsConfP *tls.Config, quicConfP *quic.Config) (quic.EarlyConnection, error) {
   141  			Expect(ctxP).To(Equal(ctx))
   142  			Expect(address).To(Equal("localhost:1337"))
   143  			Expect(tlsConfP.ServerName).To(Equal("foo.bar"))
   144  			Expect(quicConfP.MaxIdleTimeout).To(Equal(quicConf.MaxIdleTimeout))
   145  			dialerCalled = true
   146  			return nil, testErr
   147  		}
   148  		client, err := newClient("localhost:1337", tlsConf, &roundTripperOpts{}, quicConf, dialer)
   149  		Expect(err).ToNot(HaveOccurred())
   150  		_, err = client.RoundTripOpt(req.WithContext(ctx), RoundTripOpt{})
   151  		Expect(err).To(MatchError(testErr))
   152  		Expect(dialerCalled).To(BeTrue())
   153  	})
   154  
   155  	It("enables HTTP/3 Datagrams", func() {
   156  		testErr := errors.New("handshake error")
   157  		client, err := newClient("localhost:1337", nil, &roundTripperOpts{EnableDatagram: true}, nil, nil)
   158  		Expect(err).ToNot(HaveOccurred())
   159  		dialAddr = func(_ context.Context, _ string, _ *tls.Config, quicConf *quic.Config) (quic.EarlyConnection, error) {
   160  			Expect(quicConf.EnableDatagrams).To(BeTrue())
   161  			return nil, testErr
   162  		}
   163  		_, err = client.RoundTripOpt(req, RoundTripOpt{})
   164  		Expect(err).To(MatchError(testErr))
   165  	})
   166  
   167  	It("errors when dialing fails", func() {
   168  		testErr := errors.New("handshake error")
   169  		client, err := newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
   170  		Expect(err).ToNot(HaveOccurred())
   171  		dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
   172  			return nil, testErr
   173  		}
   174  		_, err = client.RoundTripOpt(req, RoundTripOpt{})
   175  		Expect(err).To(MatchError(testErr))
   176  	})
   177  
   178  	It("closes correctly if connection was not created", func() {
   179  		client, err := newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
   180  		Expect(err).ToNot(HaveOccurred())
   181  		Expect(client.Close()).To(Succeed())
   182  	})
   183  
   184  	Context("validating the address", func() {
   185  		It("refuses to do requests for the wrong host", func() {
   186  			req, err := http.NewRequest("https", "https://quic.clemente.io:1336/foobar.html", nil)
   187  			Expect(err).ToNot(HaveOccurred())
   188  			_, err = cl.RoundTripOpt(req, RoundTripOpt{})
   189  			Expect(err).To(MatchError("http3 client BUG: RoundTripOpt called for the wrong client (expected quic.clemente.io:1337, got quic.clemente.io:1336)"))
   190  		})
   191  
   192  		It("allows requests using a different scheme", func() {
   193  			testErr := errors.New("handshake error")
   194  			req, err := http.NewRequest("masque", "masque://quic.clemente.io:1337/foobar.html", nil)
   195  			Expect(err).ToNot(HaveOccurred())
   196  			dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
   197  				return nil, testErr
   198  			}
   199  			_, err = cl.RoundTripOpt(req, RoundTripOpt{})
   200  			Expect(err).To(MatchError(testErr))
   201  		})
   202  	})
   203  
   204  	Context("hijacking bidirectional streams", func() {
   205  		var (
   206  			request              *http.Request
   207  			conn                 *mockquic.MockEarlyConnection
   208  			settingsFrameWritten chan struct{}
   209  		)
   210  		testDone := make(chan struct{})
   211  
   212  		BeforeEach(func() {
   213  			testDone = make(chan struct{})
   214  			settingsFrameWritten = make(chan struct{})
   215  			controlStr := mockquic.NewMockStream(mockCtrl)
   216  			controlStr.EXPECT().Write(gomock.Any()).Do(func(b []byte) {
   217  				defer GinkgoRecover()
   218  				close(settingsFrameWritten)
   219  			})
   220  			conn = mockquic.NewMockEarlyConnection(mockCtrl)
   221  			conn.EXPECT().OpenUniStream().Return(controlStr, nil)
   222  			conn.EXPECT().HandshakeComplete().Return(handshakeChan)
   223  			conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done"))
   224  			conn.EXPECT().AcceptUniStream(gomock.Any()).Return(nil, errors.New("done")).AnyTimes()
   225  			dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
   226  				return conn, nil
   227  			}
   228  			var err error
   229  			request, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil)
   230  			Expect(err).ToNot(HaveOccurred())
   231  		})
   232  
   233  		AfterEach(func() {
   234  			testDone <- struct{}{}
   235  			Eventually(settingsFrameWritten).Should(BeClosed())
   236  		})
   237  
   238  		It("hijacks a bidirectional stream of unknown frame type", func() {
   239  			frameTypeChan := make(chan FrameType, 1)
   240  			cl.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) {
   241  				Expect(e).ToNot(HaveOccurred())
   242  				frameTypeChan <- ft
   243  				return true, nil
   244  			}
   245  
   246  			buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41))
   247  			unknownStr := mockquic.NewMockStream(mockCtrl)
   248  			unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
   249  			conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil)
   250  			conn.EXPECT().AcceptStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.Stream, error) {
   251  				<-testDone
   252  				return nil, errors.New("test done")
   253  			})
   254  			_, err := cl.RoundTripOpt(request, RoundTripOpt{})
   255  			Expect(err).To(MatchError("done"))
   256  			Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41)))
   257  			time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
   258  		})
   259  
   260  		It("closes the connection when hijacker didn't hijack a bidirectional stream", func() {
   261  			frameTypeChan := make(chan FrameType, 1)
   262  			cl.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) {
   263  				Expect(e).ToNot(HaveOccurred())
   264  				frameTypeChan <- ft
   265  				return false, nil
   266  			}
   267  
   268  			buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41))
   269  			unknownStr := mockquic.NewMockStream(mockCtrl)
   270  			unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
   271  			conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil)
   272  			conn.EXPECT().AcceptStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.Stream, error) {
   273  				<-testDone
   274  				return nil, errors.New("test done")
   275  			})
   276  			conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), gomock.Any()).Return(nil).AnyTimes()
   277  			_, err := cl.RoundTripOpt(request, RoundTripOpt{})
   278  			Expect(err).To(MatchError("done"))
   279  			Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41)))
   280  		})
   281  
   282  		It("closes the connection when hijacker returned error", func() {
   283  			frameTypeChan := make(chan FrameType, 1)
   284  			cl.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) {
   285  				Expect(e).ToNot(HaveOccurred())
   286  				frameTypeChan <- ft
   287  				return false, errors.New("error in hijacker")
   288  			}
   289  
   290  			buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41))
   291  			unknownStr := mockquic.NewMockStream(mockCtrl)
   292  			unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
   293  			conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil)
   294  			conn.EXPECT().AcceptStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.Stream, error) {
   295  				<-testDone
   296  				return nil, errors.New("test done")
   297  			})
   298  			conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), gomock.Any()).Return(nil).AnyTimes()
   299  			_, err := cl.RoundTripOpt(request, RoundTripOpt{})
   300  			Expect(err).To(MatchError("done"))
   301  			Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41)))
   302  		})
   303  
   304  		It("handles errors that occur when reading the frame type", func() {
   305  			testErr := errors.New("test error")
   306  			unknownStr := mockquic.NewMockStream(mockCtrl)
   307  			done := make(chan struct{})
   308  			cl.opts.StreamHijacker = func(ft FrameType, c quic.Connection, str quic.Stream, e error) (hijacked bool, err error) {
   309  				defer close(done)
   310  				Expect(e).To(MatchError(testErr))
   311  				Expect(ft).To(BeZero())
   312  				Expect(str).To(Equal(unknownStr))
   313  				return false, nil
   314  			}
   315  
   316  			unknownStr.EXPECT().Read(gomock.Any()).Return(0, testErr).AnyTimes()
   317  			conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil)
   318  			conn.EXPECT().AcceptStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.Stream, error) {
   319  				<-testDone
   320  				return nil, errors.New("test done")
   321  			})
   322  			conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), gomock.Any()).Return(nil).AnyTimes()
   323  			_, err := cl.RoundTripOpt(request, RoundTripOpt{})
   324  			Expect(err).To(MatchError("done"))
   325  			Eventually(done).Should(BeClosed())
   326  			time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
   327  		})
   328  	})
   329  
   330  	Context("hijacking unidirectional streams", func() {
   331  		var (
   332  			req                  *http.Request
   333  			conn                 *mockquic.MockEarlyConnection
   334  			settingsFrameWritten chan struct{}
   335  		)
   336  		testDone := make(chan struct{})
   337  
   338  		BeforeEach(func() {
   339  			testDone = make(chan struct{})
   340  			settingsFrameWritten = make(chan struct{})
   341  			controlStr := mockquic.NewMockStream(mockCtrl)
   342  			controlStr.EXPECT().Write(gomock.Any()).Do(func(b []byte) {
   343  				defer GinkgoRecover()
   344  				close(settingsFrameWritten)
   345  			})
   346  			conn = mockquic.NewMockEarlyConnection(mockCtrl)
   347  			conn.EXPECT().OpenUniStream().Return(controlStr, nil)
   348  			conn.EXPECT().HandshakeComplete().Return(handshakeChan)
   349  			conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done"))
   350  			dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
   351  				return conn, nil
   352  			}
   353  			var err error
   354  			req, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil)
   355  			Expect(err).ToNot(HaveOccurred())
   356  		})
   357  
   358  		AfterEach(func() {
   359  			testDone <- struct{}{}
   360  			Eventually(settingsFrameWritten).Should(BeClosed())
   361  		})
   362  
   363  		It("hijacks an unidirectional stream of unknown stream type", func() {
   364  			streamTypeChan := make(chan StreamType, 1)
   365  			cl.opts.UniStreamHijacker = func(st StreamType, _ quic.Connection, _ quic.ReceiveStream, err error) bool {
   366  				Expect(err).ToNot(HaveOccurred())
   367  				streamTypeChan <- st
   368  				return true
   369  			}
   370  
   371  			buf := bytes.NewBuffer(quicvarint.Append(nil, 0x54))
   372  			unknownStr := mockquic.NewMockStream(mockCtrl)
   373  			unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
   374  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   375  				return unknownStr, nil
   376  			})
   377  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   378  				<-testDone
   379  				return nil, errors.New("test done")
   380  			})
   381  			_, err := cl.RoundTripOpt(req, RoundTripOpt{})
   382  			Expect(err).To(MatchError("done"))
   383  			Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54)))
   384  			time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
   385  		})
   386  
   387  		It("handles errors that occur when reading the stream type", func() {
   388  			testErr := errors.New("test error")
   389  			done := make(chan struct{})
   390  			unknownStr := mockquic.NewMockStream(mockCtrl)
   391  			cl.opts.UniStreamHijacker = func(st StreamType, _ quic.Connection, str quic.ReceiveStream, err error) bool {
   392  				defer close(done)
   393  				Expect(st).To(BeZero())
   394  				Expect(str).To(Equal(unknownStr))
   395  				Expect(err).To(MatchError(testErr))
   396  				return true
   397  			}
   398  
   399  			unknownStr.EXPECT().Read(gomock.Any()).Return(0, testErr)
   400  			conn.EXPECT().AcceptUniStream(gomock.Any()).Return(unknownStr, nil)
   401  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   402  				<-testDone
   403  				return nil, errors.New("test done")
   404  			})
   405  			_, err := cl.RoundTripOpt(req, RoundTripOpt{})
   406  			Expect(err).To(MatchError("done"))
   407  			Eventually(done).Should(BeClosed())
   408  			time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
   409  		})
   410  
   411  		It("cancels reading when hijacker didn't hijack an unidirectional stream", func() {
   412  			streamTypeChan := make(chan StreamType, 1)
   413  			cl.opts.UniStreamHijacker = func(st StreamType, _ quic.Connection, _ quic.ReceiveStream, err error) bool {
   414  				Expect(err).ToNot(HaveOccurred())
   415  				streamTypeChan <- st
   416  				return false
   417  			}
   418  
   419  			buf := bytes.NewBuffer(quicvarint.Append(nil, 0x54))
   420  			unknownStr := mockquic.NewMockStream(mockCtrl)
   421  			unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
   422  			unknownStr.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError))
   423  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   424  				return unknownStr, nil
   425  			})
   426  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   427  				<-testDone
   428  				return nil, errors.New("test done")
   429  			})
   430  			_, err := cl.RoundTripOpt(req, RoundTripOpt{})
   431  			Expect(err).To(MatchError("done"))
   432  			Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54)))
   433  			time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
   434  		})
   435  	})
   436  
   437  	Context("control stream handling", func() {
   438  		var (
   439  			req                  *http.Request
   440  			conn                 *mockquic.MockEarlyConnection
   441  			settingsFrameWritten chan struct{}
   442  		)
   443  		testDone := make(chan struct{})
   444  
   445  		BeforeEach(func() {
   446  			settingsFrameWritten = make(chan struct{})
   447  			controlStr := mockquic.NewMockStream(mockCtrl)
   448  			controlStr.EXPECT().Write(gomock.Any()).Do(func(b []byte) {
   449  				defer GinkgoRecover()
   450  				close(settingsFrameWritten)
   451  			})
   452  			conn = mockquic.NewMockEarlyConnection(mockCtrl)
   453  			conn.EXPECT().OpenUniStream().Return(controlStr, nil)
   454  			conn.EXPECT().HandshakeComplete().Return(handshakeChan)
   455  			conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done"))
   456  			dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
   457  				return conn, nil
   458  			}
   459  			var err error
   460  			req, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil)
   461  			Expect(err).ToNot(HaveOccurred())
   462  		})
   463  
   464  		AfterEach(func() {
   465  			testDone <- struct{}{}
   466  			Eventually(settingsFrameWritten).Should(BeClosed())
   467  		})
   468  
   469  		It("parses the SETTINGS frame", func() {
   470  			b := quicvarint.Append(nil, streamTypeControlStream)
   471  			b = (&settingsFrame{}).Append(b)
   472  			r := bytes.NewReader(b)
   473  			controlStr := mockquic.NewMockStream(mockCtrl)
   474  			controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes()
   475  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   476  				return controlStr, nil
   477  			})
   478  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   479  				<-testDone
   480  				return nil, errors.New("test done")
   481  			})
   482  			_, err := cl.RoundTripOpt(req, RoundTripOpt{})
   483  			Expect(err).To(MatchError("done"))
   484  			time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
   485  		})
   486  
   487  		for _, t := range []uint64{streamTypeQPACKEncoderStream, streamTypeQPACKDecoderStream} {
   488  			streamType := t
   489  			name := "encoder"
   490  			if streamType == streamTypeQPACKDecoderStream {
   491  				name = "decoder"
   492  			}
   493  
   494  			It(fmt.Sprintf("ignores the QPACK %s streams", name), func() {
   495  				buf := bytes.NewBuffer(quicvarint.Append(nil, streamType))
   496  				str := mockquic.NewMockStream(mockCtrl)
   497  				str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
   498  
   499  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   500  					return str, nil
   501  				})
   502  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   503  					<-testDone
   504  					return nil, errors.New("test done")
   505  				})
   506  				_, err := cl.RoundTripOpt(req, RoundTripOpt{})
   507  				Expect(err).To(MatchError("done"))
   508  				time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to str.CancelRead
   509  			})
   510  		}
   511  
   512  		It("resets streams Other than the control stream and the QPACK streams", func() {
   513  			buf := bytes.NewBuffer(quicvarint.Append(nil, 0x1337))
   514  			str := mockquic.NewMockStream(mockCtrl)
   515  			str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
   516  			done := make(chan struct{})
   517  			str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError)).Do(func(code quic.StreamErrorCode) {
   518  				close(done)
   519  			})
   520  
   521  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   522  				return str, nil
   523  			})
   524  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   525  				<-testDone
   526  				return nil, errors.New("test done")
   527  			})
   528  			_, err := cl.RoundTripOpt(req, RoundTripOpt{})
   529  			Expect(err).To(MatchError("done"))
   530  			Eventually(done).Should(BeClosed())
   531  		})
   532  
   533  		It("errors when the first frame on the control stream is not a SETTINGS frame", func() {
   534  			b := quicvarint.Append(nil, streamTypeControlStream)
   535  			b = (&dataFrame{}).Append(b)
   536  			r := bytes.NewReader(b)
   537  			controlStr := mockquic.NewMockStream(mockCtrl)
   538  			controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes()
   539  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   540  				return controlStr, nil
   541  			})
   542  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   543  				<-testDone
   544  				return nil, errors.New("test done")
   545  			})
   546  			done := make(chan struct{})
   547  			conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, _ string) {
   548  				defer GinkgoRecover()
   549  				Expect(code).To(BeEquivalentTo(ErrCodeMissingSettings))
   550  				close(done)
   551  			})
   552  			_, err := cl.RoundTripOpt(req, RoundTripOpt{})
   553  			Expect(err).To(MatchError("done"))
   554  			Eventually(done).Should(BeClosed())
   555  		})
   556  
   557  		It("errors when parsing the frame on the control stream fails", func() {
   558  			b := quicvarint.Append(nil, streamTypeControlStream)
   559  			b = (&settingsFrame{}).Append(b)
   560  			r := bytes.NewReader(b[:len(b)-1])
   561  			controlStr := mockquic.NewMockStream(mockCtrl)
   562  			controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes()
   563  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   564  				return controlStr, nil
   565  			})
   566  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   567  				<-testDone
   568  				return nil, errors.New("test done")
   569  			})
   570  			done := make(chan struct{})
   571  			conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, _ string) {
   572  				defer GinkgoRecover()
   573  				Expect(code).To(BeEquivalentTo(ErrCodeFrameError))
   574  				close(done)
   575  			})
   576  			_, err := cl.RoundTripOpt(req, RoundTripOpt{})
   577  			Expect(err).To(MatchError("done"))
   578  			Eventually(done).Should(BeClosed())
   579  		})
   580  
   581  		It("errors when parsing the server opens a push stream", func() {
   582  			buf := bytes.NewBuffer(quicvarint.Append(nil, streamTypePushStream))
   583  			controlStr := mockquic.NewMockStream(mockCtrl)
   584  			controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
   585  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   586  				return controlStr, nil
   587  			})
   588  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   589  				<-testDone
   590  				return nil, errors.New("test done")
   591  			})
   592  			done := make(chan struct{})
   593  			conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, _ string) {
   594  				defer GinkgoRecover()
   595  				Expect(code).To(BeEquivalentTo(ErrCodeIDError))
   596  				close(done)
   597  			})
   598  			_, err := cl.RoundTripOpt(req, RoundTripOpt{})
   599  			Expect(err).To(MatchError("done"))
   600  			Eventually(done).Should(BeClosed())
   601  		})
   602  
   603  		It("errors when the server advertises datagram support (and we enabled support for it)", func() {
   604  			cl.opts.EnableDatagram = true
   605  			b := quicvarint.Append(nil, streamTypeControlStream)
   606  			b = (&settingsFrame{Datagram: true}).Append(b)
   607  			r := bytes.NewReader(b)
   608  			controlStr := mockquic.NewMockStream(mockCtrl)
   609  			controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes()
   610  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   611  				return controlStr, nil
   612  			})
   613  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   614  				<-testDone
   615  				return nil, errors.New("test done")
   616  			})
   617  			conn.EXPECT().ConnectionState().Return(quic.ConnectionState{SupportsDatagrams: false})
   618  			done := make(chan struct{})
   619  			conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, reason string) {
   620  				defer GinkgoRecover()
   621  				Expect(code).To(BeEquivalentTo(ErrCodeSettingsError))
   622  				Expect(reason).To(Equal("missing QUIC Datagram support"))
   623  				close(done)
   624  			})
   625  			_, err := cl.RoundTripOpt(req, RoundTripOpt{})
   626  			Expect(err).To(MatchError("done"))
   627  			Eventually(done).Should(BeClosed())
   628  		})
   629  	})
   630  
   631  	Context("Doing requests", func() {
   632  		var (
   633  			req                  *http.Request
   634  			str                  *mockquic.MockStream
   635  			conn                 *mockquic.MockEarlyConnection
   636  			settingsFrameWritten chan struct{}
   637  		)
   638  		testDone := make(chan struct{})
   639  
   640  		decodeHeader := func(str io.Reader) map[string]string {
   641  			fields := make(map[string]string)
   642  			decoder := qpack.NewDecoder(nil)
   643  
   644  			frame, err := parseNextFrame(str, nil)
   645  			ExpectWithOffset(1, err).ToNot(HaveOccurred())
   646  			ExpectWithOffset(1, frame).To(BeAssignableToTypeOf(&headersFrame{}))
   647  			headersFrame := frame.(*headersFrame)
   648  			data := make([]byte, headersFrame.Length)
   649  			_, err = io.ReadFull(str, data)
   650  			ExpectWithOffset(1, err).ToNot(HaveOccurred())
   651  			hfs, err := decoder.DecodeFull(data)
   652  			ExpectWithOffset(1, err).ToNot(HaveOccurred())
   653  			for _, p := range hfs {
   654  				fields[p.Name] = p.Value
   655  			}
   656  			return fields
   657  		}
   658  
   659  		getResponse := func(status int) []byte {
   660  			buf := &bytes.Buffer{}
   661  			rstr := mockquic.NewMockStream(mockCtrl)
   662  			rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes()
   663  			rw := newResponseWriter(rstr, nil, utils.DefaultLogger)
   664  			rw.WriteHeader(status)
   665  			rw.Flush()
   666  			return buf.Bytes()
   667  		}
   668  
   669  		BeforeEach(func() {
   670  			settingsFrameWritten = make(chan struct{})
   671  			controlStr := mockquic.NewMockStream(mockCtrl)
   672  			controlStr.EXPECT().Write(gomock.Any()).Do(func(b []byte) {
   673  				defer GinkgoRecover()
   674  				r := bytes.NewReader(b)
   675  				streamType, err := quicvarint.Read(r)
   676  				Expect(err).ToNot(HaveOccurred())
   677  				Expect(streamType).To(BeEquivalentTo(streamTypeControlStream))
   678  				close(settingsFrameWritten)
   679  			}) // SETTINGS frame
   680  			str = mockquic.NewMockStream(mockCtrl)
   681  			conn = mockquic.NewMockEarlyConnection(mockCtrl)
   682  			conn.EXPECT().OpenUniStream().Return(controlStr, nil)
   683  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   684  				<-testDone
   685  				return nil, errors.New("test done")
   686  			})
   687  			dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
   688  				return conn, nil
   689  			}
   690  			var err error
   691  			req, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil)
   692  			Expect(err).ToNot(HaveOccurred())
   693  		})
   694  
   695  		AfterEach(func() { // releases the AcceptUniStream stub, then checks that the control stream was written
   696  			testDone <- struct{}{} // unbuffered send: blocks until the stub installed in BeforeEach receives
   697  			Eventually(settingsFrameWritten).Should(BeClosed()) // the control stream's Write callback must have run
   698  		})
   699  
   700  		It("errors if it can't open a stream", func() { // an OpenStreamSync failure must surface from RoundTripOpt
   701  			openErr := errors.New("stream open error")
   702  			conn.EXPECT().HandshakeComplete().Return(handshakeChan)
   703  			conn.EXPECT().OpenStreamSync(context.Background()).Return(nil, openErr)
   704  			conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).MaxTimes(1) // closing the connection afterwards is tolerated, not required
   705  			_, rtErr := cl.RoundTripOpt(req, RoundTripOpt{})
   706  			Expect(rtErr).To(MatchError(openErr))
   707  		})
   708  
   709  		It("performs a 0-RTT request", func() {
   710  			testErr := errors.New("stream open error")
   711  			req.Method = MethodGet0RTT
   712  			// don't EXPECT any calls to HandshakeComplete()
   713  			conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
   714  			buf := &bytes.Buffer{}
   715  			str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write).AnyTimes()
   716  			str.EXPECT().Close()
   717  			str.EXPECT().CancelWrite(gomock.Any())
   718  			str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) {
   719  				return 0, testErr
   720  			})
   721  			_, err := cl.RoundTripOpt(req, RoundTripOpt{})
   722  			Expect(err).To(MatchError(testErr))
   723  			Expect(decodeHeader(buf)).To(HaveKeyWithValue(":method", "GET"))
   724  		})
   725  
   726  		It("returns a response", func() { // happy path: a serialized 418 response is turned into an HTTP/3 http.Response
   727  			wire := bytes.NewBuffer(getResponse(418))
   728  			gomock.InOrder(
   729  				conn.EXPECT().HandshakeComplete().Return(handshakeChan),
   730  				conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil),
   731  				conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}),
   732  			)
   733  			str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { return len(p), nil }).AnyTimes() // request bytes are accepted and dropped
   734  			str.EXPECT().Close()
   735  			str.EXPECT().Read(gomock.Any()).DoAndReturn(wire.Read).AnyTimes()
   736  			resp, rtErr := cl.RoundTripOpt(req, RoundTripOpt{})
   737  			Expect(rtErr).ToNot(HaveOccurred())
   738  			Expect(resp.StatusCode).To(Equal(418))
   739  			Expect(resp.Proto).To(Equal("HTTP/3.0"))
   740  			Expect(resp.ProtoMajor).To(Equal(3))
   741  			Expect(resp.Request).ToNot(BeNil())
   742  		})
   743  
   744  		It("doesn't close the request stream, with DontCloseRequestStream set", func() { // no str.Close() expectation is registered, so gomock would flag a close
   745  			wire := bytes.NewBuffer(getResponse(418))
   746  			gomock.InOrder(
   747  				conn.EXPECT().HandshakeComplete().Return(handshakeChan),
   748  				conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil),
   749  				conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}),
   750  			)
   751  			str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { return len(p), nil }).AnyTimes() // request bytes are accepted and dropped
   752  			str.EXPECT().Read(gomock.Any()).DoAndReturn(wire.Read).AnyTimes()
   753  			resp, rtErr := cl.RoundTripOpt(req, RoundTripOpt{DontCloseRequestStream: true})
   754  			Expect(rtErr).ToNot(HaveOccurred())
   755  			Expect(resp.StatusCode).To(Equal(418))
   756  			Expect(resp.Proto).To(Equal("HTTP/3.0"))
   757  			Expect(resp.ProtoMajor).To(Equal(3))
   758  		})
   759  
   760  		Context("requests containing a Body", func() {
   761  			var strBuf *bytes.Buffer // receives everything written to the request stream mock
   762    
   763  			BeforeEach(func() { // replaces req with a POST whose body is "request body" and records stream writes in strBuf
   764  				strBuf = new(bytes.Buffer)
   765  				str.EXPECT().Write(gomock.Any()).DoAndReturn(strBuf.Write).AnyTimes()
   766  				gomock.InOrder(
   767  					conn.EXPECT().HandshakeComplete().Return(handshakeChan),
   768  					conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil),
   769  				)
   770  				payload := &mockBody{}
   771  				payload.SetData([]byte("request body"))
   772  				post, err := http.NewRequest("POST", "https://quic.clemente.io:1337/upload", payload)
   773  				req = post
   774  				Expect(err).ToNot(HaveOccurred())
   775  			})
   776  
   777  			It("sends a request", func() { // the POST's :method and :path must have been written to the stream
   778  				streamClosed := make(chan struct{})
   779  				gomock.InOrder(
   780  					str.EXPECT().Close().Do(func() { close(streamClosed) }),
   781  					str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1), // tolerated after the response read has failed
   782  				)
   783  				// keep the response read pending until the request stream has been closed, then fail it
   784  				str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) {
   785  					<-streamClosed
   786  					return 0, errors.New("test done")
   787  				})
   788  				_, rtErr := cl.RoundTripOpt(req, RoundTripOpt{})
   789  				Expect(rtErr).To(MatchError("test done"))
   790  				sent := decodeHeader(strBuf)
   791  				Expect(sent).To(HaveKeyWithValue(":method", "POST"))
   792  				Expect(sent).To(HaveKeyWithValue(":path", "/upload"))
   793  			})
   794  
   795  			It("doesn't send more bytes than allowed by http.Request.ContentLength", func() { // only the first ContentLength bytes of the body may reach the stream
   796  				req.ContentLength = 7 // len("request"); the body installed by BeforeEach is "request body"
   797  				var once sync.Once
   798  				done := make(chan struct{})
   799  				gomock.InOrder(
   800  					str.EXPECT().CancelWrite(gomock.Any()).Do(func(c quic.StreamErrorCode) {
        						defer GinkgoRecover() // the callback may run on a goroutine other than the spec's; report a failed Expect instead of panicking the suite
   801  						once.Do(func() {
   802  							defer close(done) // release the Read stub even if the assertion below fails
   803  							Expect(c).To(Equal(quic.StreamErrorCode(ErrCodeRequestCanceled)))
   804  						})
   805  					}).AnyTimes(),
   806  					str.EXPECT().Close().MaxTimes(1),
   807  					str.EXPECT().CancelWrite(gomock.Any()).AnyTimes(),
   808  				)
   809  				str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) {
   810  					<-done // fail the response read only after the first CancelWrite was observed
   811  					return 0, errors.New("done")
   812  				})
   813  				cl.RoundTripOpt(req, RoundTripOpt{}) // result not checked: this spec only inspects what reached the stream
   814  				Expect(strBuf.String()).To(ContainSubstring("request"))
   815  				Expect(strBuf.String()).ToNot(ContainSubstring("request body"))
   816  			})
   817  
   818  			It("returns the error that occurred when reading the body", func() { // NOTE(review): the asserted error is the stream read error "test done", not readErr ("testErr") — confirm this matches the spec name
   819  				req.Body.(*mockBody).readErr = errors.New("testErr") // presumably makes reads of the request body fail; mockBody is defined elsewhere
   820  				done := make(chan struct{})
   821  				gomock.InOrder(
   822  					str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled)).Do(func(quic.StreamErrorCode) {
   823  						close(done) // first CancelWrite releases the Read stub below
   824  					}),
   825  					str.EXPECT().CancelWrite(gomock.Any()), // a second CancelWrite, with any code, must follow
   826  				)
   827  
   828  				// hold the response read until the first CancelWrite has fired, then fail it
   829  				str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) {
   830  					<-done
   831  					return 0, errors.New("test done")
   832  				})
   833  				closed := make(chan struct{})
   834  				str.EXPECT().Close().Do(func() { close(closed) })
   835  				_, err := cl.RoundTripOpt(req, RoundTripOpt{})
   836  				Expect(err).To(MatchError("test done"))
   837  				Eventually(closed).Should(BeClosed()) // Close may happen after RoundTripOpt has already returned
   838  			})
   839  
   840  			It("closes the connection when the first frame is not a HEADERS frame", func() { // a DATA frame as first response frame must close the connection with ErrCodeFrameUnexpected
   841  				// the response is nothing but the header of a DATA frame
   842  				rsp := bytes.NewReader((&dataFrame{Length: 0x42}).Append(nil))
   843  				conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), gomock.Any())
   844  				streamClosed := make(chan struct{})
   845  				str.EXPECT().Close().Do(func() { close(streamClosed) })
   846  				str.EXPECT().Read(gomock.Any()).DoAndReturn(rsp.Read).AnyTimes()
   847  				_, rtErr := cl.RoundTripOpt(req, RoundTripOpt{})
   848  				Expect(rtErr).To(MatchError("expected first frame to be a HEADERS frame"))
   849  				Eventually(streamClosed).Should(BeClosed())
   850  			})
   851  
   852  			It("cancels the stream when parsing the headers fails", func() {
   853  				headerBuf := &bytes.Buffer{}
   854  				enc := qpack.NewEncoder(headerBuf)
   855  				Expect(enc.WriteField(qpack.HeaderField{Name: ":method", Value: "GET"})).To(Succeed()) // not a valid response pseudo header
   856  				Expect(enc.Close()).To(Succeed())
   857  				b := (&headersFrame{Length: uint64(headerBuf.Len())}).Append(nil)
   858  				b = append(b, headerBuf.Bytes()...)
   859  
   860  				r := bytes.NewReader(b)
   861  				str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeMessageError))
   862  				closed := make(chan struct{})
   863  				str.EXPECT().Close().Do(func() { close(closed) })
   864  				str.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes()
   865  				_, err := cl.RoundTripOpt(req, RoundTripOpt{})
   866  				Expect(err).To(HaveOccurred())
   867  				Eventually(closed).Should(BeClosed())
   868  			})
   869  
   870  			It("cancels the stream when the HEADERS frame is too large", func() { // an oversized HEADERS frame must cancel the stream with ErrCodeFrameError
   871  				// announces 1338 bytes, one more than the MaxHeaderBytes (1337) cl was created with
   872  				rsp := bytes.NewReader((&headersFrame{Length: 1338}).Append(nil))
   873  				str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeFrameError))
   874  				streamClosed := make(chan struct{})
   875  				str.EXPECT().Close().Do(func() { close(streamClosed) })
   876  				str.EXPECT().Read(gomock.Any()).DoAndReturn(rsp.Read).AnyTimes()
   877  				_, rtErr := cl.RoundTripOpt(req, RoundTripOpt{})
   878  				Expect(rtErr).To(MatchError("HEADERS frame too large: 1338 bytes (max: 1337)"))
   879  				Eventually(streamClosed).Should(BeClosed())
   880  			})
   881  		})
   882  
   883  		Context("request cancellations", func() {
   884  			for _, dontClose := range []bool{false, true} { // generate the same two specs once per DontCloseRequestStream value
   885  				dontClose := dontClose // per-iteration copy for the closures below (required before Go 1.22)
   886  
   887  				Context(fmt.Sprintf("with DontCloseRequestStream: %t", dontClose), func() {
   888  					roundTripOpt := RoundTripOpt{DontCloseRequestStream: dontClose}
   889  
   890  					It("cancels a request while waiting for the handshake to complete", func() {
   891  						ctx, cancel := context.WithCancel(context.Background())
   892  						req := req.WithContext(ctx) // shadows the shared req with a cancellable copy
   893  						conn.EXPECT().HandshakeComplete().Return(make(chan struct{})) // never closed: the handshake does not complete in this spec
   894  
   895  						errChan := make(chan error)
   896  						go func() {
   897  							_, err := cl.RoundTripOpt(req, roundTripOpt)
   898  							errChan <- err
   899  						}()
   900  						Consistently(errChan).ShouldNot(Receive()) // the request must still be pending before it is canceled
   901  						cancel()
   902  						Eventually(errChan).Should(Receive(MatchError("context canceled")))
   903  					})
   904  
   905  					It("cancels a request while the request is still in flight", func() {
   906  						ctx, cancel := context.WithCancel(context.Background())
   907  						req := req.WithContext(ctx) // shadows the shared req with a cancellable copy
   908  						conn.EXPECT().HandshakeComplete().Return(handshakeChan)
   909  						conn.EXPECT().OpenStreamSync(ctx).Return(str, nil)
   910  						buf := &bytes.Buffer{}
   911  						str.EXPECT().Close().MaxTimes(1) // Close is optional here; it depends on DontCloseRequestStream
   912  
   913  						str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write) // exactly one Write is expected
   914  
   915  						done := make(chan struct{})
   916  						canceled := make(chan struct{})
   917  						gomock.InOrder( // CancelWrite must be observed before CancelRead, both with ErrCodeRequestCanceled
   918  							str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled)).Do(func(quic.StreamErrorCode) { close(canceled) }),
   919  							str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeRequestCanceled)).Do(func(quic.StreamErrorCode) { close(done) }),
   920  						)
   921  						str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1) // one further CancelWrite is tolerated
   922  						str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) {
   923  							cancel() // cancel the request context from inside the response read
   924  							<-canceled // wait until the ordered CancelWrite has fired before failing the read
   925  							return 0, errors.New("test done")
   926  						})
   927  						_, err := cl.RoundTripOpt(req, roundTripOpt)
   928  						Expect(err).To(MatchError("test done"))
   929  						Eventually(done).Should(BeClosed()) // CancelRead may arrive after RoundTripOpt has returned
   930  					})
   931  				})
   932  			}
   933  
   934  			It("cancels a request after the response arrived", func() {
   935  				rspBuf := bytes.NewBuffer(getResponse(404))
   936  
   937  				ctx, cancel := context.WithCancel(context.Background())
   938  				req := req.WithContext(ctx)
   939  				conn.EXPECT().HandshakeComplete().Return(handshakeChan)
   940  				conn.EXPECT().OpenStreamSync(ctx).Return(str, nil)
   941  				conn.EXPECT().ConnectionState().Return(quic.ConnectionState{})
   942  				buf := &bytes.Buffer{}
   943  				str.EXPECT().Close().MaxTimes(1)
   944  
   945  				done := make(chan struct{})
   946  				str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write)
   947  				str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes()
   948  				str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled))
   949  				str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeRequestCanceled)).Do(func(quic.StreamErrorCode) { close(done) })
   950  				_, err := cl.RoundTripOpt(req, RoundTripOpt{})
   951  				Expect(err).ToNot(HaveOccurred())
   952  				cancel()
   953  				Eventually(done).Should(BeClosed())
   954  			})
   955  		})
   956  
   957  		Context("gzip compression", func() {
   958  			BeforeEach(func() { // every gzip spec expects exactly one HandshakeComplete call
   959  				conn.EXPECT().HandshakeComplete().Return(handshakeChan) // handshakeChan is already closed
   960  			})
   961  
   962  			It("adds the gzip header to requests", func() {
   963  				conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
   964  				buf := &bytes.Buffer{}
   965  				str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write)
   966  				gomock.InOrder(
   967  					str.EXPECT().Close(),
   968  					str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1), // when the Read errors
   969  				)
   970  				str.EXPECT().Read(gomock.Any()).Return(0, errors.New("test done"))
   971  				_, err := cl.RoundTripOpt(req, RoundTripOpt{})
   972  				Expect(err).To(MatchError("test done"))
   973  				hfs := decodeHeader(buf)
   974  				Expect(hfs).To(HaveKeyWithValue("accept-encoding", "gzip"))
   975  			})
   976  
   977  			It("doesn't add gzip if the header disable it", func() {
   978  				client, err := newClient("quic.clemente.io:1337", nil, &roundTripperOpts{DisableCompression: true}, nil, nil)
   979  				Expect(err).ToNot(HaveOccurred())
   980  				conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
   981  				buf := &bytes.Buffer{}
   982  				str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write)
   983  				gomock.InOrder(
   984  					str.EXPECT().Close(),
   985  					str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1), // when the Read errors
   986  				)
   987  				str.EXPECT().Read(gomock.Any()).Return(0, errors.New("test done"))
   988  				_, err = client.RoundTripOpt(req, RoundTripOpt{})
   989  				Expect(err).To(MatchError("test done"))
   990  				hfs := decodeHeader(buf)
   991  				Expect(hfs).ToNot(HaveKey("accept-encoding"))
   992  			})
   993  
   994  			It("decompresses the response", func() { // a gzip-encoded response body must be returned decompressed, with the encoding header removed
   995  				conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
   996  				conn.EXPECT().ConnectionState().Return(quic.ConnectionState{})
   997  				buf := &bytes.Buffer{}
   998  				rstr := mockquic.NewMockStream(mockCtrl)
   999  				rstr.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write).AnyTimes() // DoAndReturn, not Do: Do drops buf.Write's results, so the mock would report (0, nil), a short write
  1000  				rw := newResponseWriter(rstr, nil, utils.DefaultLogger)
  1001  				rw.Header().Set("Content-Encoding", "gzip")
  1002  				gz := gzip.NewWriter(rw)
  1003  				Expect(gz.Write([]byte("gzipped response"))).To(Equal(len("gzipped response"))) // Gomega also fails this if the trailing error is non-nil
  1004  				Expect(gz.Close()).To(Succeed()) // Close emits the remaining compressed data and the gzip trailer
  1005  				rw.Flush()
  1006  				str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })
  1007  				str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() // serve the serialized response back to the client
  1008  				str.EXPECT().Close()
  1009  
  1010  				rsp, err := cl.RoundTripOpt(req, RoundTripOpt{})
  1011  				Expect(err).ToNot(HaveOccurred())
  1012  				data, err := io.ReadAll(rsp.Body)
  1013  				Expect(err).ToNot(HaveOccurred())
  1014  				Expect(rsp.ContentLength).To(BeEquivalentTo(-1))
  1015  				Expect(string(data)).To(Equal("gzipped response"))
  1016  				Expect(rsp.Header.Get("Content-Encoding")).To(BeEmpty())
  1017  				Expect(rsp.Uncompressed).To(BeTrue())
  1018  			})
  1019  
  1020  			It("only decompresses the response if the response contains the right content-encoding header", func() { // a body sent without Content-Encoding must be passed through untouched
  1021  				conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
  1022  				conn.EXPECT().ConnectionState().Return(quic.ConnectionState{})
  1023  				buf := &bytes.Buffer{}
  1024  				rstr := mockquic.NewMockStream(mockCtrl)
  1025  				rstr.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write).AnyTimes() // DoAndReturn, not Do: Do drops buf.Write's results, so the mock would report (0, nil), a short write
  1026  				rw := newResponseWriter(rstr, nil, utils.DefaultLogger)
  1027  				rw.Write([]byte("not gzipped")) // result not checked here; the body is verified via the client below
  1028  				rw.Flush()
  1029  				str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })
  1030  				str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() // serve the serialized response back to the client
  1031  				str.EXPECT().Close()
  1032  
  1033  				rsp, err := cl.RoundTripOpt(req, RoundTripOpt{})
  1034  				Expect(err).ToNot(HaveOccurred())
  1035  				data, err := io.ReadAll(rsp.Body)
  1036  				Expect(err).ToNot(HaveOccurred())
  1037  				Expect(string(data)).To(Equal("not gzipped"))
  1038  				Expect(rsp.Header.Get("Content-Encoding")).To(BeEmpty())
  1039  			})
  1040  		})
  1041  	})
  1042  })