github.com/TugasAkhir-QUIC/quic-go@v0.0.2-0.20240215011318-d20e25a9054c/http3/client_test.go (about)

     1  package http3
     2  
     3  import (
     4  	"bytes"
     5  	"compress/gzip"
     6  	"context"
     7  	"crypto/tls"
     8  	"errors"
     9  	"fmt"
    10  	"io"
    11  	"net/http"
    12  	"sync"
    13  	"time"
    14  
    15  	"github.com/TugasAkhir-QUIC/quic-go"
    16  	mockquic "github.com/TugasAkhir-QUIC/quic-go/internal/mocks/quic"
    17  	"github.com/TugasAkhir-QUIC/quic-go/internal/protocol"
    18  	"github.com/TugasAkhir-QUIC/quic-go/internal/utils"
    19  	"github.com/TugasAkhir-QUIC/quic-go/quicvarint"
    20  
    21  	"github.com/quic-go/qpack"
    22  
    23  	. "github.com/onsi/ginkgo/v2"
    24  	. "github.com/onsi/gomega"
    25  	"go.uber.org/mock/gomock"
    26  )
    27  
    28  var _ = Describe("Client", func() {
    29  	var (
    30  		cl            *client
    31  		req           *http.Request
    32  		origDialAddr  = dialAddr
    33  		handshakeChan <-chan struct{} // a closed chan
    34  	)
    35  
    36  	BeforeEach(func() {
    37  		origDialAddr = dialAddr
    38  		hostname := "quic.clemente.io:1337"
    39  		c, err := newClient(hostname, nil, &roundTripperOpts{MaxHeaderBytes: 1337}, nil, nil)
    40  		Expect(err).ToNot(HaveOccurred())
    41  		cl = c.(*client)
    42  		Expect(cl.hostname).To(Equal(hostname))
    43  
    44  		req, err = http.NewRequest("GET", "https://localhost:1337", nil)
    45  		Expect(err).ToNot(HaveOccurred())
    46  
    47  		ch := make(chan struct{})
    48  		close(ch)
    49  		handshakeChan = ch
    50  	})
    51  
	// Restore the dial function saved in BeforeEach, undoing any replacement
	// of dialAddr made by the spec that just ran.
	AfterEach(func() {
		dialAddr = origDialAddr
	})
    55  
    56  	It("rejects quic.Configs that allow multiple QUIC versions", func() {
    57  		qconf := &quic.Config{
    58  			Versions: []quic.Version{protocol.Version2, protocol.Version1},
    59  		}
    60  		_, err := newClient("localhost:1337", nil, &roundTripperOpts{}, qconf, nil)
    61  		Expect(err).To(MatchError("can only use a single QUIC version for dialing a HTTP/3 connection"))
    62  	})
    63  
    64  	It("uses the default QUIC and TLS config if none is give", func() {
    65  		client, err := newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
    66  		Expect(err).ToNot(HaveOccurred())
    67  		var dialAddrCalled bool
    68  		dialAddr = func(_ context.Context, _ string, tlsConf *tls.Config, quicConf *quic.Config) (quic.EarlyConnection, error) {
    69  			Expect(quicConf.MaxIncomingStreams).To(Equal(defaultQuicConfig.MaxIncomingStreams))
    70  			Expect(tlsConf.NextProtos).To(Equal([]string{NextProtoH3}))
    71  			Expect(quicConf.Versions).To(Equal([]protocol.Version{protocol.Version1}))
    72  			dialAddrCalled = true
    73  			return nil, errors.New("test done")
    74  		}
    75  		client.RoundTripOpt(req, RoundTripOpt{})
    76  		Expect(dialAddrCalled).To(BeTrue())
    77  	})
    78  
    79  	It("adds the port to the hostname, if none is given", func() {
    80  		client, err := newClient("quic.clemente.io", nil, &roundTripperOpts{}, nil, nil)
    81  		Expect(err).ToNot(HaveOccurred())
    82  		var dialAddrCalled bool
    83  		dialAddr = func(_ context.Context, hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlyConnection, error) {
    84  			Expect(hostname).To(Equal("quic.clemente.io:443"))
    85  			dialAddrCalled = true
    86  			return nil, errors.New("test done")
    87  		}
    88  		req, err := http.NewRequest("GET", "https://quic.clemente.io:443", nil)
    89  		Expect(err).ToNot(HaveOccurred())
    90  		client.RoundTripOpt(req, RoundTripOpt{})
    91  		Expect(dialAddrCalled).To(BeTrue())
    92  	})
    93  
    94  	It("sets the ServerName in the tls.Config, if not set", func() {
    95  		const host = "foo.bar"
    96  		dialCalled := false
    97  		dialFunc := func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
    98  			Expect(tlsCfg.ServerName).To(Equal(host))
    99  			dialCalled = true
   100  			return nil, errors.New("test done")
   101  		}
   102  		client, err := newClient(host, nil, &roundTripperOpts{}, nil, dialFunc)
   103  		Expect(err).ToNot(HaveOccurred())
   104  		req, err := http.NewRequest("GET", "https://foo.bar", nil)
   105  		Expect(err).ToNot(HaveOccurred())
   106  		client.RoundTripOpt(req, RoundTripOpt{})
   107  		Expect(dialCalled).To(BeTrue())
   108  	})
   109  
   110  	It("uses the TLS config and QUIC config", func() {
   111  		tlsConf := &tls.Config{
   112  			ServerName: "foo.bar",
   113  			NextProtos: []string{"proto foo", "proto bar"},
   114  		}
   115  		quicConf := &quic.Config{MaxIdleTimeout: time.Nanosecond}
   116  		client, err := newClient("localhost:1337", tlsConf, &roundTripperOpts{}, quicConf, nil)
   117  		Expect(err).ToNot(HaveOccurred())
   118  		var dialAddrCalled bool
   119  		dialAddr = func(_ context.Context, host string, tlsConfP *tls.Config, quicConfP *quic.Config) (quic.EarlyConnection, error) {
   120  			Expect(host).To(Equal("localhost:1337"))
   121  			Expect(tlsConfP.ServerName).To(Equal(tlsConf.ServerName))
   122  			Expect(tlsConfP.NextProtos).To(Equal([]string{NextProtoH3}))
   123  			Expect(quicConfP.MaxIdleTimeout).To(Equal(quicConf.MaxIdleTimeout))
   124  			dialAddrCalled = true
   125  			return nil, errors.New("test done")
   126  		}
   127  		client.RoundTripOpt(req, RoundTripOpt{})
   128  		Expect(dialAddrCalled).To(BeTrue())
   129  		// make sure the original tls.Config was not modified
   130  		Expect(tlsConf.NextProtos).To(Equal([]string{"proto foo", "proto bar"}))
   131  	})
   132  
   133  	It("uses the custom dialer, if provided", func() {
   134  		testErr := errors.New("test done")
   135  		tlsConf := &tls.Config{ServerName: "foo.bar"}
   136  		quicConf := &quic.Config{MaxIdleTimeout: 1337 * time.Second}
   137  		ctx, cancel := context.WithTimeout(context.Background(), time.Hour)
   138  		defer cancel()
   139  		var dialerCalled bool
   140  		dialer := func(ctxP context.Context, address string, tlsConfP *tls.Config, quicConfP *quic.Config) (quic.EarlyConnection, error) {
   141  			Expect(ctxP).To(Equal(ctx))
   142  			Expect(address).To(Equal("localhost:1337"))
   143  			Expect(tlsConfP.ServerName).To(Equal("foo.bar"))
   144  			Expect(quicConfP.MaxIdleTimeout).To(Equal(quicConf.MaxIdleTimeout))
   145  			dialerCalled = true
   146  			return nil, testErr
   147  		}
   148  		client, err := newClient("localhost:1337", tlsConf, &roundTripperOpts{}, quicConf, dialer)
   149  		Expect(err).ToNot(HaveOccurred())
   150  		_, err = client.RoundTripOpt(req.WithContext(ctx), RoundTripOpt{})
   151  		Expect(err).To(MatchError(testErr))
   152  		Expect(dialerCalled).To(BeTrue())
   153  	})
   154  
   155  	It("enables HTTP/3 Datagrams", func() {
   156  		testErr := errors.New("handshake error")
   157  		client, err := newClient("localhost:1337", nil, &roundTripperOpts{EnableDatagram: true}, nil, nil)
   158  		Expect(err).ToNot(HaveOccurred())
   159  		dialAddr = func(_ context.Context, _ string, _ *tls.Config, quicConf *quic.Config) (quic.EarlyConnection, error) {
   160  			Expect(quicConf.EnableDatagrams).To(BeTrue())
   161  			return nil, testErr
   162  		}
   163  		_, err = client.RoundTripOpt(req, RoundTripOpt{})
   164  		Expect(err).To(MatchError(testErr))
   165  	})
   166  
   167  	It("errors when dialing fails", func() {
   168  		testErr := errors.New("handshake error")
   169  		client, err := newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
   170  		Expect(err).ToNot(HaveOccurred())
   171  		dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
   172  			return nil, testErr
   173  		}
   174  		_, err = client.RoundTripOpt(req, RoundTripOpt{})
   175  		Expect(err).To(MatchError(testErr))
   176  	})
   177  
   178  	It("closes correctly if connection was not created", func() {
   179  		client, err := newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
   180  		Expect(err).ToNot(HaveOccurred())
   181  		Expect(client.Close()).To(Succeed())
   182  	})
   183  
   184  	Context("validating the address", func() {
   185  		It("refuses to do requests for the wrong host", func() {
   186  			req, err := http.NewRequest("https", "https://quic.clemente.io:1336/foobar.html", nil)
   187  			Expect(err).ToNot(HaveOccurred())
   188  			_, err = cl.RoundTripOpt(req, RoundTripOpt{})
   189  			Expect(err).To(MatchError("http3 client BUG: RoundTripOpt called for the wrong client (expected quic.clemente.io:1337, got quic.clemente.io:1336)"))
   190  		})
   191  
   192  		It("allows requests using a different scheme", func() {
   193  			testErr := errors.New("handshake error")
   194  			req, err := http.NewRequest("masque", "masque://quic.clemente.io:1337/foobar.html", nil)
   195  			Expect(err).ToNot(HaveOccurred())
   196  			dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
   197  				return nil, testErr
   198  			}
   199  			_, err = cl.RoundTripOpt(req, RoundTripOpt{})
   200  			Expect(err).To(MatchError(testErr))
   201  		})
   202  	})
   203  
   204  	Context("hijacking bidirectional streams", func() {
   205  		var (
   206  			request              *http.Request
   207  			conn                 *mockquic.MockEarlyConnection
   208  			settingsFrameWritten chan struct{}
   209  		)
   210  		testDone := make(chan struct{})
   211  
   212  		BeforeEach(func() {
   213  			testDone = make(chan struct{})
   214  			settingsFrameWritten = make(chan struct{})
   215  			controlStr := mockquic.NewMockStream(mockCtrl)
   216  			controlStr.EXPECT().Write(gomock.Any()).Do(func(b []byte) (int, error) {
   217  				defer GinkgoRecover()
   218  				close(settingsFrameWritten)
   219  				return len(b), nil
   220  			})
   221  			conn = mockquic.NewMockEarlyConnection(mockCtrl)
   222  			conn.EXPECT().OpenUniStream().Return(controlStr, nil)
   223  			conn.EXPECT().HandshakeComplete().Return(handshakeChan)
   224  			conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done"))
   225  			conn.EXPECT().AcceptUniStream(gomock.Any()).Return(nil, errors.New("done")).AnyTimes()
   226  			dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
   227  				return conn, nil
   228  			}
   229  			var err error
   230  			request, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil)
   231  			Expect(err).ToNot(HaveOccurred())
   232  		})
   233  
   234  		AfterEach(func() {
   235  			testDone <- struct{}{}
   236  			Eventually(settingsFrameWritten).Should(BeClosed())
   237  		})
   238  
   239  		It("hijacks a bidirectional stream of unknown frame type", func() {
   240  			frameTypeChan := make(chan FrameType, 1)
   241  			cl.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) {
   242  				Expect(e).ToNot(HaveOccurred())
   243  				frameTypeChan <- ft
   244  				return true, nil
   245  			}
   246  
   247  			buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41))
   248  			unknownStr := mockquic.NewMockStream(mockCtrl)
   249  			unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
   250  			conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil)
   251  			conn.EXPECT().AcceptStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.Stream, error) {
   252  				<-testDone
   253  				return nil, errors.New("test done")
   254  			})
   255  			_, err := cl.RoundTripOpt(request, RoundTripOpt{})
   256  			Expect(err).To(MatchError("done"))
   257  			Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41)))
   258  			time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
   259  		})
   260  
   261  		It("closes the connection when hijacker didn't hijack a bidirectional stream", func() {
   262  			frameTypeChan := make(chan FrameType, 1)
   263  			cl.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) {
   264  				Expect(e).ToNot(HaveOccurred())
   265  				frameTypeChan <- ft
   266  				return false, nil
   267  			}
   268  
   269  			buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41))
   270  			unknownStr := mockquic.NewMockStream(mockCtrl)
   271  			unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
   272  			conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil)
   273  			conn.EXPECT().AcceptStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.Stream, error) {
   274  				<-testDone
   275  				return nil, errors.New("test done")
   276  			})
   277  			conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), gomock.Any()).Return(nil).AnyTimes()
   278  			_, err := cl.RoundTripOpt(request, RoundTripOpt{})
   279  			Expect(err).To(MatchError("done"))
   280  			Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41)))
   281  		})
   282  
   283  		It("closes the connection when hijacker returned error", func() {
   284  			frameTypeChan := make(chan FrameType, 1)
   285  			cl.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) {
   286  				Expect(e).ToNot(HaveOccurred())
   287  				frameTypeChan <- ft
   288  				return false, errors.New("error in hijacker")
   289  			}
   290  
   291  			buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41))
   292  			unknownStr := mockquic.NewMockStream(mockCtrl)
   293  			unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
   294  			conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil)
   295  			conn.EXPECT().AcceptStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.Stream, error) {
   296  				<-testDone
   297  				return nil, errors.New("test done")
   298  			})
   299  			conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), gomock.Any()).Return(nil).AnyTimes()
   300  			_, err := cl.RoundTripOpt(request, RoundTripOpt{})
   301  			Expect(err).To(MatchError("done"))
   302  			Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41)))
   303  		})
   304  
   305  		It("handles errors that occur when reading the frame type", func() {
   306  			testErr := errors.New("test error")
   307  			unknownStr := mockquic.NewMockStream(mockCtrl)
   308  			done := make(chan struct{})
   309  			cl.opts.StreamHijacker = func(ft FrameType, c quic.Connection, str quic.Stream, e error) (hijacked bool, err error) {
   310  				defer close(done)
   311  				Expect(e).To(MatchError(testErr))
   312  				Expect(ft).To(BeZero())
   313  				Expect(str).To(Equal(unknownStr))
   314  				return false, nil
   315  			}
   316  
   317  			unknownStr.EXPECT().Read(gomock.Any()).Return(0, testErr).AnyTimes()
   318  			conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil)
   319  			conn.EXPECT().AcceptStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.Stream, error) {
   320  				<-testDone
   321  				return nil, errors.New("test done")
   322  			})
   323  			conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), gomock.Any()).Return(nil).AnyTimes()
   324  			_, err := cl.RoundTripOpt(request, RoundTripOpt{})
   325  			Expect(err).To(MatchError("done"))
   326  			Eventually(done).Should(BeClosed())
   327  			time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
   328  		})
   329  	})
   330  
   331  	Context("hijacking unidirectional streams", func() {
   332  		var (
   333  			req                  *http.Request
   334  			conn                 *mockquic.MockEarlyConnection
   335  			settingsFrameWritten chan struct{}
   336  		)
   337  		testDone := make(chan struct{})
   338  
   339  		BeforeEach(func() {
   340  			testDone = make(chan struct{})
   341  			settingsFrameWritten = make(chan struct{})
   342  			controlStr := mockquic.NewMockStream(mockCtrl)
   343  			controlStr.EXPECT().Write(gomock.Any()).Do(func(b []byte) (int, error) {
   344  				defer GinkgoRecover()
   345  				close(settingsFrameWritten)
   346  				return len(b), nil
   347  			})
   348  			conn = mockquic.NewMockEarlyConnection(mockCtrl)
   349  			conn.EXPECT().OpenUniStream().Return(controlStr, nil)
   350  			conn.EXPECT().HandshakeComplete().Return(handshakeChan)
   351  			conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done"))
   352  			dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
   353  				return conn, nil
   354  			}
   355  			var err error
   356  			req, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil)
   357  			Expect(err).ToNot(HaveOccurred())
   358  		})
   359  
   360  		AfterEach(func() {
   361  			testDone <- struct{}{}
   362  			Eventually(settingsFrameWritten).Should(BeClosed())
   363  		})
   364  
   365  		It("hijacks an unidirectional stream of unknown stream type", func() {
   366  			streamTypeChan := make(chan StreamType, 1)
   367  			cl.opts.UniStreamHijacker = func(st StreamType, _ quic.Connection, _ quic.ReceiveStream, err error) bool {
   368  				Expect(err).ToNot(HaveOccurred())
   369  				streamTypeChan <- st
   370  				return true
   371  			}
   372  
   373  			buf := bytes.NewBuffer(quicvarint.Append(nil, 0x54))
   374  			unknownStr := mockquic.NewMockStream(mockCtrl)
   375  			unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
   376  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   377  				return unknownStr, nil
   378  			})
   379  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   380  				<-testDone
   381  				return nil, errors.New("test done")
   382  			})
   383  			_, err := cl.RoundTripOpt(req, RoundTripOpt{})
   384  			Expect(err).To(MatchError("done"))
   385  			Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54)))
   386  			time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
   387  		})
   388  
   389  		It("handles errors that occur when reading the stream type", func() {
   390  			testErr := errors.New("test error")
   391  			done := make(chan struct{})
   392  			unknownStr := mockquic.NewMockStream(mockCtrl)
   393  			cl.opts.UniStreamHijacker = func(st StreamType, _ quic.Connection, str quic.ReceiveStream, err error) bool {
   394  				defer close(done)
   395  				Expect(st).To(BeZero())
   396  				Expect(str).To(Equal(unknownStr))
   397  				Expect(err).To(MatchError(testErr))
   398  				return true
   399  			}
   400  
   401  			unknownStr.EXPECT().Read(gomock.Any()).Return(0, testErr)
   402  			conn.EXPECT().AcceptUniStream(gomock.Any()).Return(unknownStr, nil)
   403  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   404  				<-testDone
   405  				return nil, errors.New("test done")
   406  			})
   407  			_, err := cl.RoundTripOpt(req, RoundTripOpt{})
   408  			Expect(err).To(MatchError("done"))
   409  			Eventually(done).Should(BeClosed())
   410  			time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
   411  		})
   412  
   413  		It("cancels reading when hijacker didn't hijack an unidirectional stream", func() {
   414  			streamTypeChan := make(chan StreamType, 1)
   415  			cl.opts.UniStreamHijacker = func(st StreamType, _ quic.Connection, _ quic.ReceiveStream, err error) bool {
   416  				Expect(err).ToNot(HaveOccurred())
   417  				streamTypeChan <- st
   418  				return false
   419  			}
   420  
   421  			buf := bytes.NewBuffer(quicvarint.Append(nil, 0x54))
   422  			unknownStr := mockquic.NewMockStream(mockCtrl)
   423  			unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
   424  			unknownStr.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError))
   425  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   426  				return unknownStr, nil
   427  			})
   428  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   429  				<-testDone
   430  				return nil, errors.New("test done")
   431  			})
   432  			_, err := cl.RoundTripOpt(req, RoundTripOpt{})
   433  			Expect(err).To(MatchError("done"))
   434  			Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54)))
   435  			time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
   436  		})
   437  	})
   438  
   439  	Context("control stream handling", func() {
   440  		var (
   441  			req                  *http.Request
   442  			conn                 *mockquic.MockEarlyConnection
   443  			settingsFrameWritten chan struct{}
   444  		)
   445  		testDone := make(chan struct{})
   446  
   447  		BeforeEach(func() {
   448  			settingsFrameWritten = make(chan struct{})
   449  			controlStr := mockquic.NewMockStream(mockCtrl)
   450  			controlStr.EXPECT().Write(gomock.Any()).Do(func(b []byte) (int, error) {
   451  				defer GinkgoRecover()
   452  				close(settingsFrameWritten)
   453  				return len(b), nil
   454  			})
   455  			conn = mockquic.NewMockEarlyConnection(mockCtrl)
   456  			conn.EXPECT().OpenUniStream().Return(controlStr, nil)
   457  			conn.EXPECT().HandshakeComplete().Return(handshakeChan)
   458  			conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done"))
   459  			dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
   460  				return conn, nil
   461  			}
   462  			var err error
   463  			req, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil)
   464  			Expect(err).ToNot(HaveOccurred())
   465  		})
   466  
   467  		AfterEach(func() {
   468  			testDone <- struct{}{}
   469  			Eventually(settingsFrameWritten).Should(BeClosed())
   470  		})
   471  
   472  		It("parses the SETTINGS frame", func() {
   473  			b := quicvarint.Append(nil, streamTypeControlStream)
   474  			b = (&settingsFrame{}).Append(b)
   475  			r := bytes.NewReader(b)
   476  			controlStr := mockquic.NewMockStream(mockCtrl)
   477  			controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes()
   478  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   479  				return controlStr, nil
   480  			})
   481  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   482  				<-testDone
   483  				return nil, errors.New("test done")
   484  			})
   485  			_, err := cl.RoundTripOpt(req, RoundTripOpt{})
   486  			Expect(err).To(MatchError("done"))
   487  			time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
   488  		})
   489  
   490  		for _, t := range []uint64{streamTypeQPACKEncoderStream, streamTypeQPACKDecoderStream} {
   491  			streamType := t
   492  			name := "encoder"
   493  			if streamType == streamTypeQPACKDecoderStream {
   494  				name = "decoder"
   495  			}
   496  
   497  			It(fmt.Sprintf("ignores the QPACK %s streams", name), func() {
   498  				buf := bytes.NewBuffer(quicvarint.Append(nil, streamType))
   499  				str := mockquic.NewMockStream(mockCtrl)
   500  				str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
   501  
   502  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   503  					return str, nil
   504  				})
   505  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   506  					<-testDone
   507  					return nil, errors.New("test done")
   508  				})
   509  				_, err := cl.RoundTripOpt(req, RoundTripOpt{})
   510  				Expect(err).To(MatchError("done"))
   511  				time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to str.CancelRead
   512  			})
   513  		}
   514  
   515  		It("resets streams Other than the control stream and the QPACK streams", func() {
   516  			buf := bytes.NewBuffer(quicvarint.Append(nil, 0x1337))
   517  			str := mockquic.NewMockStream(mockCtrl)
   518  			str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
   519  			done := make(chan struct{})
   520  			str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError)).Do(func(quic.StreamErrorCode) { close(done) })
   521  
   522  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   523  				return str, nil
   524  			})
   525  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   526  				<-testDone
   527  				return nil, errors.New("test done")
   528  			})
   529  			_, err := cl.RoundTripOpt(req, RoundTripOpt{})
   530  			Expect(err).To(MatchError("done"))
   531  			Eventually(done).Should(BeClosed())
   532  		})
   533  
   534  		It("errors when the first frame on the control stream is not a SETTINGS frame", func() {
   535  			b := quicvarint.Append(nil, streamTypeControlStream)
   536  			b = (&dataFrame{}).Append(b)
   537  			r := bytes.NewReader(b)
   538  			controlStr := mockquic.NewMockStream(mockCtrl)
   539  			controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes()
   540  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   541  				return controlStr, nil
   542  			})
   543  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   544  				<-testDone
   545  				return nil, errors.New("test done")
   546  			})
   547  			done := make(chan struct{})
   548  			conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeMissingSettings), gomock.Any()).Do(func(quic.ApplicationErrorCode, string) error {
   549  				close(done)
   550  				return nil
   551  			})
   552  			_, err := cl.RoundTripOpt(req, RoundTripOpt{})
   553  			Expect(err).To(MatchError("done"))
   554  			Eventually(done).Should(BeClosed())
   555  		})
   556  
   557  		It("errors when parsing the frame on the control stream fails", func() {
   558  			b := quicvarint.Append(nil, streamTypeControlStream)
   559  			b = (&settingsFrame{}).Append(b)
   560  			r := bytes.NewReader(b[:len(b)-1])
   561  			controlStr := mockquic.NewMockStream(mockCtrl)
   562  			controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes()
   563  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   564  				return controlStr, nil
   565  			})
   566  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   567  				<-testDone
   568  				return nil, errors.New("test done")
   569  			})
   570  			done := make(chan struct{})
   571  			conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameError), gomock.Any()).Do(func(code quic.ApplicationErrorCode, _ string) error {
   572  				close(done)
   573  				return nil
   574  			})
   575  			_, err := cl.RoundTripOpt(req, RoundTripOpt{})
   576  			Expect(err).To(MatchError("done"))
   577  			Eventually(done).Should(BeClosed())
   578  		})
   579  
   580  		It("errors when parsing the server opens a push stream", func() {
   581  			buf := bytes.NewBuffer(quicvarint.Append(nil, streamTypePushStream))
   582  			controlStr := mockquic.NewMockStream(mockCtrl)
   583  			controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
   584  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   585  				return controlStr, nil
   586  			})
   587  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   588  				<-testDone
   589  				return nil, errors.New("test done")
   590  			})
   591  			done := make(chan struct{})
   592  			conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeIDError), gomock.Any()).Do(func(quic.ApplicationErrorCode, string) error {
   593  				close(done)
   594  				return nil
   595  			})
   596  			_, err := cl.RoundTripOpt(req, RoundTripOpt{})
   597  			Expect(err).To(MatchError("done"))
   598  			Eventually(done).Should(BeClosed())
   599  		})
   600  
   601  		It("errors when the server advertises datagram support (and we enabled support for it)", func() {
   602  			cl.opts.EnableDatagram = true
   603  			b := quicvarint.Append(nil, streamTypeControlStream)
   604  			b = (&settingsFrame{Datagram: true}).Append(b)
   605  			r := bytes.NewReader(b)
   606  			controlStr := mockquic.NewMockStream(mockCtrl)
   607  			controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes()
   608  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   609  				return controlStr, nil
   610  			})
   611  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   612  				<-testDone
   613  				return nil, errors.New("test done")
   614  			})
   615  			conn.EXPECT().ConnectionState().Return(quic.ConnectionState{SupportsDatagrams: false})
   616  			done := make(chan struct{})
   617  			conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeSettingsError), "missing QUIC Datagram support").Do(func(quic.ApplicationErrorCode, string) error {
   618  				close(done)
   619  				return nil
   620  			})
   621  			_, err := cl.RoundTripOpt(req, RoundTripOpt{})
   622  			Expect(err).To(MatchError("done"))
   623  			Eventually(done).Should(BeClosed())
   624  		})
   625  	})
   626  
   627  	Context("Doing requests", func() {
		// Fixtures for the "Doing requests" specs; req, str, conn and
		// settingsFrameWritten are (re)assigned in this context's BeforeEach.
		var (
			// req is the GET request used by the specs.
			req                  *http.Request
			// str is the mock stream specs hand out as the request stream.
			str                  *mockquic.MockStream
			// conn is the mock connection returned by the replaced dialAddr.
			conn                 *mockquic.MockEarlyConnection
			// settingsFrameWritten is closed by the mocked control-stream Write.
			settingsFrameWritten chan struct{}
		)
		// testDone is sent to by AfterEach to release the mocked
		// AcceptUniStream call that blocks on it.
		testDone := make(chan struct{})
   635  
   636  		decodeHeader := func(str io.Reader) map[string]string {
   637  			fields := make(map[string]string)
   638  			decoder := qpack.NewDecoder(nil)
   639  
   640  			frame, err := parseNextFrame(str, nil)
   641  			ExpectWithOffset(1, err).ToNot(HaveOccurred())
   642  			ExpectWithOffset(1, frame).To(BeAssignableToTypeOf(&headersFrame{}))
   643  			headersFrame := frame.(*headersFrame)
   644  			data := make([]byte, headersFrame.Length)
   645  			_, err = io.ReadFull(str, data)
   646  			ExpectWithOffset(1, err).ToNot(HaveOccurred())
   647  			hfs, err := decoder.DecodeFull(data)
   648  			ExpectWithOffset(1, err).ToNot(HaveOccurred())
   649  			for _, p := range hfs {
   650  				fields[p.Name] = p.Value
   651  			}
   652  			return fields
   653  		}
   654  
   655  		getResponse := func(status int) []byte {
   656  			buf := &bytes.Buffer{}
   657  			rstr := mockquic.NewMockStream(mockCtrl)
   658  			rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes()
   659  			rw := newResponseWriter(rstr, nil, utils.DefaultLogger)
   660  			rw.WriteHeader(status)
   661  			rw.Flush()
   662  			return buf.Bytes()
   663  		}
   664  
   665  		BeforeEach(func() {
   666  			settingsFrameWritten = make(chan struct{})
   667  			controlStr := mockquic.NewMockStream(mockCtrl)
   668  			controlStr.EXPECT().Write(gomock.Any()).Do(func(b []byte) (int, error) {
   669  				defer GinkgoRecover()
   670  				r := bytes.NewReader(b)
   671  				streamType, err := quicvarint.Read(r)
   672  				Expect(err).ToNot(HaveOccurred())
   673  				Expect(streamType).To(BeEquivalentTo(streamTypeControlStream))
   674  				close(settingsFrameWritten)
   675  				return len(b), nil
   676  			}) // SETTINGS frame
   677  			str = mockquic.NewMockStream(mockCtrl)
   678  			conn = mockquic.NewMockEarlyConnection(mockCtrl)
   679  			conn.EXPECT().OpenUniStream().Return(controlStr, nil)
   680  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   681  				<-testDone
   682  				return nil, errors.New("test done")
   683  			})
   684  			dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
   685  				return conn, nil
   686  			}
   687  			var err error
   688  			req, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil)
   689  			Expect(err).ToNot(HaveOccurred())
   690  		})
   691  
   692  		AfterEach(func() {
   693  			testDone <- struct{}{}
   694  			Eventually(settingsFrameWritten).Should(BeClosed())
   695  		})
   696  
   697  		It("errors if it can't open a stream", func() {
   698  			testErr := errors.New("stream open error")
   699  			conn.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr)
   700  			conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).MaxTimes(1)
   701  			conn.EXPECT().HandshakeComplete().Return(handshakeChan)
   702  			_, err := cl.RoundTripOpt(req, RoundTripOpt{})
   703  			Expect(err).To(MatchError(testErr))
   704  		})
   705  
   706  		It("performs a 0-RTT request", func() {
   707  			testErr := errors.New("stream open error")
   708  			req.Method = MethodGet0RTT
   709  			// don't EXPECT any calls to HandshakeComplete()
   710  			conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
   711  			buf := &bytes.Buffer{}
   712  			str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write).AnyTimes()
   713  			str.EXPECT().Close()
   714  			str.EXPECT().CancelWrite(gomock.Any())
   715  			str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) {
   716  				return 0, testErr
   717  			})
   718  			_, err := cl.RoundTripOpt(req, RoundTripOpt{})
   719  			Expect(err).To(MatchError(testErr))
   720  			Expect(decodeHeader(buf)).To(HaveKeyWithValue(":method", "GET"))
   721  		})
   722  
   723  		It("returns a response", func() {
   724  			rspBuf := bytes.NewBuffer(getResponse(418))
   725  			gomock.InOrder(
   726  				conn.EXPECT().HandshakeComplete().Return(handshakeChan),
   727  				conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil),
   728  				conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}),
   729  			)
   730  			str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })
   731  			str.EXPECT().Close()
   732  			str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes()
   733  			rsp, err := cl.RoundTripOpt(req, RoundTripOpt{})
   734  			Expect(err).ToNot(HaveOccurred())
   735  			Expect(rsp.Proto).To(Equal("HTTP/3.0"))
   736  			Expect(rsp.ProtoMajor).To(Equal(3))
   737  			Expect(rsp.StatusCode).To(Equal(418))
   738  			Expect(rsp.Request).ToNot(BeNil())
   739  		})
   740  
   741  		It("doesn't close the request stream, with DontCloseRequestStream set", func() {
   742  			rspBuf := bytes.NewBuffer(getResponse(418))
   743  			gomock.InOrder(
   744  				conn.EXPECT().HandshakeComplete().Return(handshakeChan),
   745  				conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil),
   746  				conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}),
   747  			)
   748  			str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })
   749  			str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes()
   750  			rsp, err := cl.RoundTripOpt(req, RoundTripOpt{DontCloseRequestStream: true})
   751  			Expect(err).ToNot(HaveOccurred())
   752  			Expect(rsp.Proto).To(Equal("HTTP/3.0"))
   753  			Expect(rsp.ProtoMajor).To(Equal(3))
   754  			Expect(rsp.StatusCode).To(Equal(418))
   755  		})
   756  
   757  		Context("requests containing a Body", func() {
   758  			var strBuf *bytes.Buffer
   759  
   760  			BeforeEach(func() {
   761  				strBuf = &bytes.Buffer{}
   762  				gomock.InOrder(
   763  					conn.EXPECT().HandshakeComplete().Return(handshakeChan),
   764  					conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil),
   765  				)
   766  				body := &mockBody{}
   767  				body.SetData([]byte("request body"))
   768  				var err error
   769  				req, err = http.NewRequest("POST", "https://quic.clemente.io:1337/upload", body)
   770  				Expect(err).ToNot(HaveOccurred())
   771  				str.EXPECT().Write(gomock.Any()).DoAndReturn(strBuf.Write).AnyTimes()
   772  			})
   773  
   774  			It("sends a request", func() {
   775  				done := make(chan struct{})
   776  				gomock.InOrder(
   777  					str.EXPECT().Close().Do(func() error { close(done); return nil }),
   778  					str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1), // when reading the response errors
   779  				)
   780  				// the response body is sent asynchronously, while already reading the response
   781  				str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) {
   782  					<-done
   783  					return 0, errors.New("test done")
   784  				})
   785  				_, err := cl.RoundTripOpt(req, RoundTripOpt{})
   786  				Expect(err).To(MatchError("test done"))
   787  				hfs := decodeHeader(strBuf)
   788  				Expect(hfs).To(HaveKeyWithValue(":method", "POST"))
   789  				Expect(hfs).To(HaveKeyWithValue(":path", "/upload"))
   790  			})
   791  
   792  			It("doesn't send more bytes than allowed by http.Request.ContentLength", func() {
   793  				req.ContentLength = 7
   794  				var once sync.Once
   795  				done := make(chan struct{})
   796  				gomock.InOrder(
   797  					str.EXPECT().CancelWrite(gomock.Any()).Do(func(c quic.StreamErrorCode) {
   798  						once.Do(func() {
   799  							Expect(c).To(Equal(quic.StreamErrorCode(ErrCodeRequestCanceled)))
   800  							close(done)
   801  						})
   802  					}).AnyTimes(),
   803  					str.EXPECT().Close().MaxTimes(1),
   804  					str.EXPECT().CancelWrite(gomock.Any()).AnyTimes(),
   805  				)
   806  				str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) {
   807  					<-done
   808  					return 0, errors.New("done")
   809  				})
   810  				cl.RoundTripOpt(req, RoundTripOpt{})
   811  				Expect(strBuf.String()).To(ContainSubstring("request"))
   812  				Expect(strBuf.String()).ToNot(ContainSubstring("request body"))
   813  			})
   814  
   815  			It("returns the error that occurred when reading the body", func() {
   816  				req.Body.(*mockBody).readErr = errors.New("testErr")
   817  				done := make(chan struct{})
   818  				gomock.InOrder(
   819  					str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled)).Do(func(quic.StreamErrorCode) {
   820  						close(done)
   821  					}),
   822  					str.EXPECT().CancelWrite(gomock.Any()),
   823  				)
   824  
   825  				// the response body is sent asynchronously, while already reading the response
   826  				str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) {
   827  					<-done
   828  					return 0, errors.New("test done")
   829  				})
   830  				closed := make(chan struct{})
   831  				str.EXPECT().Close().Do(func() error { close(closed); return nil })
   832  				_, err := cl.RoundTripOpt(req, RoundTripOpt{})
   833  				Expect(err).To(MatchError("test done"))
   834  				Eventually(closed).Should(BeClosed())
   835  			})
   836  
   837  			It("closes the connection when the first frame is not a HEADERS frame", func() {
   838  				b := (&dataFrame{Length: 0x42}).Append(nil)
   839  				conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), gomock.Any())
   840  				closed := make(chan struct{})
   841  				r := bytes.NewReader(b)
   842  				str.EXPECT().Close().Do(func() error { close(closed); return nil })
   843  				str.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes()
   844  				_, err := cl.RoundTripOpt(req, RoundTripOpt{})
   845  				Expect(err).To(MatchError("expected first frame to be a HEADERS frame"))
   846  				Eventually(closed).Should(BeClosed())
   847  			})
   848  
   849  			It("cancels the stream when parsing the headers fails", func() {
   850  				headerBuf := &bytes.Buffer{}
   851  				enc := qpack.NewEncoder(headerBuf)
   852  				Expect(enc.WriteField(qpack.HeaderField{Name: ":method", Value: "GET"})).To(Succeed()) // not a valid response pseudo header
   853  				Expect(enc.Close()).To(Succeed())
   854  				b := (&headersFrame{Length: uint64(headerBuf.Len())}).Append(nil)
   855  				b = append(b, headerBuf.Bytes()...)
   856  
   857  				r := bytes.NewReader(b)
   858  				str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeMessageError))
   859  				closed := make(chan struct{})
   860  				str.EXPECT().Close().Do(func() error { close(closed); return nil })
   861  				str.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes()
   862  				_, err := cl.RoundTripOpt(req, RoundTripOpt{})
   863  				Expect(err).To(HaveOccurred())
   864  				Eventually(closed).Should(BeClosed())
   865  			})
   866  
   867  			It("cancels the stream when the HEADERS frame is too large", func() {
   868  				b := (&headersFrame{Length: 1338}).Append(nil)
   869  				r := bytes.NewReader(b)
   870  				str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeFrameError))
   871  				closed := make(chan struct{})
   872  				str.EXPECT().Close().Do(func() error { close(closed); return nil })
   873  				str.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes()
   874  				_, err := cl.RoundTripOpt(req, RoundTripOpt{})
   875  				Expect(err).To(MatchError("HEADERS frame too large: 1338 bytes (max: 1337)"))
   876  				Eventually(closed).Should(BeClosed())
   877  			})
   878  		})
   879  
   880  		Context("request cancellations", func() {
   881  			for _, dontClose := range []bool{false, true} {
   882  				dontClose := dontClose
   883  
   884  				Context(fmt.Sprintf("with DontCloseRequestStream: %t", dontClose), func() {
   885  					roundTripOpt := RoundTripOpt{DontCloseRequestStream: dontClose}
   886  
   887  					It("cancels a request while waiting for the handshake to complete", func() {
   888  						ctx, cancel := context.WithCancel(context.Background())
   889  						req := req.WithContext(ctx)
   890  						conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
   891  
   892  						errChan := make(chan error)
   893  						go func() {
   894  							_, err := cl.RoundTripOpt(req, roundTripOpt)
   895  							errChan <- err
   896  						}()
   897  						Consistently(errChan).ShouldNot(Receive())
   898  						cancel()
   899  						Eventually(errChan).Should(Receive(MatchError("context canceled")))
   900  					})
   901  
   902  					It("cancels a request while the request is still in flight", func() {
   903  						ctx, cancel := context.WithCancel(context.Background())
   904  						req := req.WithContext(ctx)
   905  						conn.EXPECT().HandshakeComplete().Return(handshakeChan)
   906  						conn.EXPECT().OpenStreamSync(ctx).Return(str, nil)
   907  						buf := &bytes.Buffer{}
   908  						str.EXPECT().Close().MaxTimes(1)
   909  
   910  						str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write)
   911  
   912  						done := make(chan struct{})
   913  						canceled := make(chan struct{})
   914  						gomock.InOrder(
   915  							str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled)).Do(func(quic.StreamErrorCode) { close(canceled) }),
   916  							str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeRequestCanceled)).Do(func(quic.StreamErrorCode) { close(done) }),
   917  						)
   918  						str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1)
   919  						str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) {
   920  							cancel()
   921  							<-canceled
   922  							return 0, errors.New("test done")
   923  						})
   924  						_, err := cl.RoundTripOpt(req, roundTripOpt)
   925  						Expect(err).To(MatchError(context.Canceled))
   926  						Eventually(done).Should(BeClosed())
   927  					})
   928  				})
   929  			}
   930  
   931  			It("cancels a request after the response arrived", func() {
   932  				rspBuf := bytes.NewBuffer(getResponse(404))
   933  
   934  				ctx, cancel := context.WithCancel(context.Background())
   935  				req := req.WithContext(ctx)
   936  				conn.EXPECT().HandshakeComplete().Return(handshakeChan)
   937  				conn.EXPECT().OpenStreamSync(ctx).Return(str, nil)
   938  				conn.EXPECT().ConnectionState().Return(quic.ConnectionState{})
   939  				buf := &bytes.Buffer{}
   940  				str.EXPECT().Close().MaxTimes(1)
   941  
   942  				done := make(chan struct{})
   943  				str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write)
   944  				str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes()
   945  				str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled))
   946  				str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeRequestCanceled)).Do(func(quic.StreamErrorCode) { close(done) })
   947  				_, err := cl.RoundTripOpt(req, RoundTripOpt{})
   948  				Expect(err).ToNot(HaveOccurred())
   949  				cancel()
   950  				Eventually(done).Should(BeClosed())
   951  			})
   952  		})
   953  
   954  		Context("gzip compression", func() {
   955  			BeforeEach(func() {
   956  				conn.EXPECT().HandshakeComplete().Return(handshakeChan)
   957  			})
   958  
   959  			It("adds the gzip header to requests", func() {
   960  				conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
   961  				buf := &bytes.Buffer{}
   962  				str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write)
   963  				gomock.InOrder(
   964  					str.EXPECT().Close(),
   965  					str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1), // when the Read errors
   966  				)
   967  				str.EXPECT().Read(gomock.Any()).Return(0, errors.New("test done"))
   968  				_, err := cl.RoundTripOpt(req, RoundTripOpt{})
   969  				Expect(err).To(MatchError("test done"))
   970  				hfs := decodeHeader(buf)
   971  				Expect(hfs).To(HaveKeyWithValue("accept-encoding", "gzip"))
   972  			})
   973  
   974  			It("doesn't add gzip if the header disable it", func() {
   975  				client, err := newClient("quic.clemente.io:1337", nil, &roundTripperOpts{DisableCompression: true}, nil, nil)
   976  				Expect(err).ToNot(HaveOccurred())
   977  				conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
   978  				buf := &bytes.Buffer{}
   979  				str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write)
   980  				gomock.InOrder(
   981  					str.EXPECT().Close(),
   982  					str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1), // when the Read errors
   983  				)
   984  				str.EXPECT().Read(gomock.Any()).Return(0, errors.New("test done"))
   985  				_, err = client.RoundTripOpt(req, RoundTripOpt{})
   986  				Expect(err).To(MatchError("test done"))
   987  				hfs := decodeHeader(buf)
   988  				Expect(hfs).ToNot(HaveKey("accept-encoding"))
   989  			})
   990  
   991  			It("decompresses the response", func() {
   992  				conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
   993  				conn.EXPECT().ConnectionState().Return(quic.ConnectionState{})
   994  				buf := &bytes.Buffer{}
   995  				rstr := mockquic.NewMockStream(mockCtrl)
   996  				rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes()
   997  				rw := newResponseWriter(rstr, nil, utils.DefaultLogger)
   998  				rw.Header().Set("Content-Encoding", "gzip")
   999  				gz := gzip.NewWriter(rw)
  1000  				gz.Write([]byte("gzipped response"))
  1001  				gz.Close()
  1002  				rw.Flush()
  1003  				str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })
  1004  				str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
  1005  				str.EXPECT().Close()
  1006  
  1007  				rsp, err := cl.RoundTripOpt(req, RoundTripOpt{})
  1008  				Expect(err).ToNot(HaveOccurred())
  1009  				data, err := io.ReadAll(rsp.Body)
  1010  				Expect(err).ToNot(HaveOccurred())
  1011  				Expect(rsp.ContentLength).To(BeEquivalentTo(-1))
  1012  				Expect(string(data)).To(Equal("gzipped response"))
  1013  				Expect(rsp.Header.Get("Content-Encoding")).To(BeEmpty())
  1014  				Expect(rsp.Uncompressed).To(BeTrue())
  1015  			})
  1016  
  1017  			It("only decompresses the response if the response contains the right content-encoding header", func() {
  1018  				conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
  1019  				conn.EXPECT().ConnectionState().Return(quic.ConnectionState{})
  1020  				buf := &bytes.Buffer{}
  1021  				rstr := mockquic.NewMockStream(mockCtrl)
  1022  				rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes()
  1023  				rw := newResponseWriter(rstr, nil, utils.DefaultLogger)
  1024  				rw.Write([]byte("not gzipped"))
  1025  				rw.Flush()
  1026  				str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })
  1027  				str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
  1028  				str.EXPECT().Close()
  1029  
  1030  				rsp, err := cl.RoundTripOpt(req, RoundTripOpt{})
  1031  				Expect(err).ToNot(HaveOccurred())
  1032  				data, err := io.ReadAll(rsp.Body)
  1033  				Expect(err).ToNot(HaveOccurred())
  1034  				Expect(string(data)).To(Equal("not gzipped"))
  1035  				Expect(rsp.Header.Get("Content-Encoding")).To(BeEmpty())
  1036  			})
  1037  		})
  1038  	})
  1039  })