github.com/daeuniverse/quic-go@v0.0.0-20240413031024-943f218e0810/http3/client_test.go (about)

     1  package http3
     2  
     3  import (
     4  	"bytes"
     5  	"compress/gzip"
     6  	"context"
     7  	"crypto/tls"
     8  	"errors"
     9  	"fmt"
    10  	"io"
    11  	"net/http"
    12  	"sync"
    13  	"time"
    14  
    15  	"github.com/daeuniverse/quic-go"
    16  	mockquic "github.com/daeuniverse/quic-go/internal/mocks/quic"
    17  	"github.com/daeuniverse/quic-go/internal/protocol"
    18  	"github.com/daeuniverse/quic-go/internal/qerr"
    19  	"github.com/daeuniverse/quic-go/internal/utils"
    20  	"github.com/daeuniverse/quic-go/quicvarint"
    21  
    22  	"github.com/quic-go/qpack"
    23  
    24  	. "github.com/onsi/ginkgo/v2"
    25  	. "github.com/onsi/gomega"
    26  	"go.uber.org/mock/gomock"
    27  )
    28  
    29  var _ = Describe("Client", func() {
    30  	var (
    31  		cl            *client
    32  		req           *http.Request
    33  		origDialAddr  = dialAddr
    34  		handshakeChan <-chan struct{} // a closed chan
    35  	)
    36  
    37  	BeforeEach(func() {
    38  		origDialAddr = dialAddr
    39  		hostname := "quic.clemente.io:1337"
    40  		c, err := newClient(hostname, nil, &roundTripperOpts{MaxHeaderBytes: 1337}, nil, nil)
    41  		Expect(err).ToNot(HaveOccurred())
    42  		cl = c.(*client)
    43  		Expect(cl.hostname).To(Equal(hostname))
    44  
    45  		req, err = http.NewRequest("GET", "https://localhost:1337", nil)
    46  		Expect(err).ToNot(HaveOccurred())
    47  
    48  		ch := make(chan struct{})
    49  		close(ch)
    50  		handshakeChan = ch
    51  	})
    52  
    53  	AfterEach(func() {
    54  		dialAddr = origDialAddr
    55  	})
    56  
    57  	It("rejects quic.Configs that allow multiple QUIC versions", func() {
    58  		qconf := &quic.Config{
    59  			Versions: []quic.Version{protocol.Version2, protocol.Version1},
    60  		}
    61  		_, err := newClient("localhost:1337", nil, &roundTripperOpts{}, qconf, nil)
    62  		Expect(err).To(MatchError("can only use a single QUIC version for dialing a HTTP/3 connection"))
    63  	})
    64  
    65  	It("uses the default QUIC and TLS config if none is give", func() {
    66  		client, err := newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
    67  		Expect(err).ToNot(HaveOccurred())
    68  		var dialAddrCalled bool
    69  		dialAddr = func(_ context.Context, _ string, tlsConf *tls.Config, quicConf *quic.Config) (quic.EarlyConnection, error) {
    70  			Expect(quicConf.MaxIncomingStreams).To(Equal(defaultQuicConfig.MaxIncomingStreams))
    71  			Expect(tlsConf.NextProtos).To(Equal([]string{NextProtoH3}))
    72  			Expect(quicConf.Versions).To(Equal([]protocol.Version{protocol.Version1}))
    73  			dialAddrCalled = true
    74  			return nil, errors.New("test done")
    75  		}
    76  		client.RoundTripOpt(req, RoundTripOpt{})
    77  		Expect(dialAddrCalled).To(BeTrue())
    78  	})
    79  
    80  	It("adds the port to the hostname, if none is given", func() {
    81  		client, err := newClient("quic.clemente.io", nil, &roundTripperOpts{}, nil, nil)
    82  		Expect(err).ToNot(HaveOccurred())
    83  		var dialAddrCalled bool
    84  		dialAddr = func(_ context.Context, hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlyConnection, error) {
    85  			Expect(hostname).To(Equal("quic.clemente.io:443"))
    86  			dialAddrCalled = true
    87  			return nil, errors.New("test done")
    88  		}
    89  		req, err := http.NewRequest("GET", "https://quic.clemente.io:443", nil)
    90  		Expect(err).ToNot(HaveOccurred())
    91  		client.RoundTripOpt(req, RoundTripOpt{})
    92  		Expect(dialAddrCalled).To(BeTrue())
    93  	})
    94  
    95  	It("sets the ServerName in the tls.Config, if not set", func() {
    96  		const host = "foo.bar"
    97  		dialCalled := false
    98  		dialFunc := func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
    99  			Expect(tlsCfg.ServerName).To(Equal(host))
   100  			dialCalled = true
   101  			return nil, errors.New("test done")
   102  		}
   103  		client, err := newClient(host, nil, &roundTripperOpts{}, nil, dialFunc)
   104  		Expect(err).ToNot(HaveOccurred())
   105  		req, err := http.NewRequest("GET", "https://foo.bar", nil)
   106  		Expect(err).ToNot(HaveOccurred())
   107  		client.RoundTripOpt(req, RoundTripOpt{})
   108  		Expect(dialCalled).To(BeTrue())
   109  	})
   110  
   111  	It("uses the TLS config and QUIC config", func() {
   112  		tlsConf := &tls.Config{
   113  			ServerName: "foo.bar",
   114  			NextProtos: []string{"proto foo", "proto bar"},
   115  		}
   116  		quicConf := &quic.Config{MaxIdleTimeout: time.Nanosecond}
   117  		client, err := newClient("localhost:1337", tlsConf, &roundTripperOpts{}, quicConf, nil)
   118  		Expect(err).ToNot(HaveOccurred())
   119  		var dialAddrCalled bool
   120  		dialAddr = func(_ context.Context, host string, tlsConfP *tls.Config, quicConfP *quic.Config) (quic.EarlyConnection, error) {
   121  			Expect(host).To(Equal("localhost:1337"))
   122  			Expect(tlsConfP.ServerName).To(Equal(tlsConf.ServerName))
   123  			Expect(tlsConfP.NextProtos).To(Equal([]string{NextProtoH3}))
   124  			Expect(quicConfP.MaxIdleTimeout).To(Equal(quicConf.MaxIdleTimeout))
   125  			dialAddrCalled = true
   126  			return nil, errors.New("test done")
   127  		}
   128  		client.RoundTripOpt(req, RoundTripOpt{})
   129  		Expect(dialAddrCalled).To(BeTrue())
   130  		// make sure the original tls.Config was not modified
   131  		Expect(tlsConf.NextProtos).To(Equal([]string{"proto foo", "proto bar"}))
   132  	})
   133  
   134  	It("uses the custom dialer, if provided", func() {
   135  		testErr := errors.New("test done")
   136  		tlsConf := &tls.Config{ServerName: "foo.bar"}
   137  		quicConf := &quic.Config{MaxIdleTimeout: 1337 * time.Second}
   138  		ctx, cancel := context.WithTimeout(context.Background(), time.Hour)
   139  		defer cancel()
   140  		var dialerCalled bool
   141  		dialer := func(ctxP context.Context, address string, tlsConfP *tls.Config, quicConfP *quic.Config) (quic.EarlyConnection, error) {
   142  			Expect(ctxP).To(Equal(ctx))
   143  			Expect(address).To(Equal("localhost:1337"))
   144  			Expect(tlsConfP.ServerName).To(Equal("foo.bar"))
   145  			Expect(quicConfP.MaxIdleTimeout).To(Equal(quicConf.MaxIdleTimeout))
   146  			dialerCalled = true
   147  			return nil, testErr
   148  		}
   149  		client, err := newClient("localhost:1337", tlsConf, &roundTripperOpts{}, quicConf, dialer)
   150  		Expect(err).ToNot(HaveOccurred())
   151  		_, err = client.RoundTripOpt(req.WithContext(ctx), RoundTripOpt{})
   152  		Expect(err).To(MatchError(testErr))
   153  		Expect(dialerCalled).To(BeTrue())
   154  	})
   155  
   156  	It("enables HTTP/3 Datagrams", func() {
   157  		testErr := errors.New("handshake error")
   158  		client, err := newClient("localhost:1337", nil, &roundTripperOpts{EnableDatagram: true}, nil, nil)
   159  		Expect(err).ToNot(HaveOccurred())
   160  		dialAddr = func(_ context.Context, _ string, _ *tls.Config, quicConf *quic.Config) (quic.EarlyConnection, error) {
   161  			Expect(quicConf.EnableDatagrams).To(BeTrue())
   162  			return nil, testErr
   163  		}
   164  		_, err = client.RoundTripOpt(req, RoundTripOpt{})
   165  		Expect(err).To(MatchError(testErr))
   166  	})
   167  
   168  	It("errors when dialing fails", func() {
   169  		testErr := errors.New("handshake error")
   170  		client, err := newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
   171  		Expect(err).ToNot(HaveOccurred())
   172  		dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
   173  			return nil, testErr
   174  		}
   175  		_, err = client.RoundTripOpt(req, RoundTripOpt{})
   176  		Expect(err).To(MatchError(testErr))
   177  	})
   178  
   179  	It("closes correctly if connection was not created", func() {
   180  		client, err := newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
   181  		Expect(err).ToNot(HaveOccurred())
   182  		Expect(client.Close()).To(Succeed())
   183  	})
   184  
   185  	Context("validating the address", func() {
   186  		It("refuses to do requests for the wrong host", func() {
   187  			req, err := http.NewRequest("https", "https://quic.clemente.io:1336/foobar.html", nil)
   188  			Expect(err).ToNot(HaveOccurred())
   189  			_, err = cl.RoundTripOpt(req, RoundTripOpt{})
   190  			Expect(err).To(MatchError("http3 client BUG: RoundTripOpt called for the wrong client (expected quic.clemente.io:1337, got quic.clemente.io:1336)"))
   191  		})
   192  
   193  		It("allows requests using a different scheme", func() {
   194  			testErr := errors.New("handshake error")
   195  			req, err := http.NewRequest("masque", "masque://quic.clemente.io:1337/foobar.html", nil)
   196  			Expect(err).ToNot(HaveOccurred())
   197  			dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
   198  				return nil, testErr
   199  			}
   200  			_, err = cl.RoundTripOpt(req, RoundTripOpt{})
   201  			Expect(err).To(MatchError(testErr))
   202  		})
   203  	})
   204  
   205  	Context("hijacking bidirectional streams", func() {
   206  		var (
   207  			request              *http.Request
   208  			conn                 *mockquic.MockEarlyConnection
   209  			settingsFrameWritten chan struct{}
   210  		)
   211  		testDone := make(chan struct{})
   212  
   213  		BeforeEach(func() {
   214  			testDone = make(chan struct{})
   215  			settingsFrameWritten = make(chan struct{})
   216  			controlStr := mockquic.NewMockStream(mockCtrl)
   217  			controlStr.EXPECT().Write(gomock.Any()).Do(func(b []byte) (int, error) {
   218  				defer GinkgoRecover()
   219  				close(settingsFrameWritten)
   220  				return len(b), nil
   221  			})
   222  			conn = mockquic.NewMockEarlyConnection(mockCtrl)
   223  			conn.EXPECT().OpenUniStream().Return(controlStr, nil)
   224  			conn.EXPECT().HandshakeComplete().Return(handshakeChan)
   225  			conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done"))
   226  			conn.EXPECT().AcceptUniStream(gomock.Any()).Return(nil, errors.New("done")).AnyTimes()
   227  			dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
   228  				return conn, nil
   229  			}
   230  			var err error
   231  			request, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil)
   232  			Expect(err).ToNot(HaveOccurred())
   233  		})
   234  
   235  		AfterEach(func() {
   236  			testDone <- struct{}{}
   237  			Eventually(settingsFrameWritten).Should(BeClosed())
   238  		})
   239  
   240  		It("hijacks a bidirectional stream of unknown frame type", func() {
   241  			frameTypeChan := make(chan FrameType, 1)
   242  			cl.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) {
   243  				Expect(e).ToNot(HaveOccurred())
   244  				frameTypeChan <- ft
   245  				return true, nil
   246  			}
   247  
   248  			buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41))
   249  			unknownStr := mockquic.NewMockStream(mockCtrl)
   250  			unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
   251  			conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil)
   252  			conn.EXPECT().AcceptStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.Stream, error) {
   253  				<-testDone
   254  				return nil, errors.New("test done")
   255  			})
   256  			_, err := cl.RoundTripOpt(request, RoundTripOpt{})
   257  			Expect(err).To(MatchError("done"))
   258  			Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41)))
   259  			time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
   260  		})
   261  
   262  		It("closes the connection when hijacker didn't hijack a bidirectional stream", func() {
   263  			frameTypeChan := make(chan FrameType, 1)
   264  			cl.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) {
   265  				Expect(e).ToNot(HaveOccurred())
   266  				frameTypeChan <- ft
   267  				return false, nil
   268  			}
   269  
   270  			buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41))
   271  			unknownStr := mockquic.NewMockStream(mockCtrl)
   272  			unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
   273  			conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil)
   274  			conn.EXPECT().AcceptStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.Stream, error) {
   275  				<-testDone
   276  				return nil, errors.New("test done")
   277  			})
   278  			conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), gomock.Any()).Return(nil).AnyTimes()
   279  			_, err := cl.RoundTripOpt(request, RoundTripOpt{})
   280  			Expect(err).To(MatchError("done"))
   281  			Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41)))
   282  		})
   283  
   284  		It("closes the connection when hijacker returned error", func() {
   285  			frameTypeChan := make(chan FrameType, 1)
   286  			cl.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) {
   287  				Expect(e).ToNot(HaveOccurred())
   288  				frameTypeChan <- ft
   289  				return false, errors.New("error in hijacker")
   290  			}
   291  
   292  			buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41))
   293  			unknownStr := mockquic.NewMockStream(mockCtrl)
   294  			unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
   295  			conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil)
   296  			conn.EXPECT().AcceptStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.Stream, error) {
   297  				<-testDone
   298  				return nil, errors.New("test done")
   299  			})
   300  			conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), gomock.Any()).Return(nil).AnyTimes()
   301  			_, err := cl.RoundTripOpt(request, RoundTripOpt{})
   302  			Expect(err).To(MatchError("done"))
   303  			Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41)))
   304  		})
   305  
   306  		It("handles errors that occur when reading the frame type", func() {
   307  			testErr := errors.New("test error")
   308  			unknownStr := mockquic.NewMockStream(mockCtrl)
   309  			done := make(chan struct{})
   310  			cl.opts.StreamHijacker = func(ft FrameType, c quic.Connection, str quic.Stream, e error) (hijacked bool, err error) {
   311  				defer close(done)
   312  				Expect(e).To(MatchError(testErr))
   313  				Expect(ft).To(BeZero())
   314  				Expect(str).To(Equal(unknownStr))
   315  				return false, nil
   316  			}
   317  
   318  			unknownStr.EXPECT().Read(gomock.Any()).Return(0, testErr).AnyTimes()
   319  			conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil)
   320  			conn.EXPECT().AcceptStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.Stream, error) {
   321  				<-testDone
   322  				return nil, errors.New("test done")
   323  			})
   324  			conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), gomock.Any()).Return(nil).AnyTimes()
   325  			_, err := cl.RoundTripOpt(request, RoundTripOpt{})
   326  			Expect(err).To(MatchError("done"))
   327  			Eventually(done).Should(BeClosed())
   328  			time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
   329  		})
   330  	})
   331  
   332  	Context("hijacking unidirectional streams", func() {
   333  		var (
   334  			req                  *http.Request
   335  			conn                 *mockquic.MockEarlyConnection
   336  			settingsFrameWritten chan struct{}
   337  		)
   338  		testDone := make(chan struct{})
   339  
   340  		BeforeEach(func() {
   341  			testDone = make(chan struct{})
   342  			settingsFrameWritten = make(chan struct{})
   343  			controlStr := mockquic.NewMockStream(mockCtrl)
   344  			controlStr.EXPECT().Write(gomock.Any()).Do(func(b []byte) (int, error) {
   345  				defer GinkgoRecover()
   346  				close(settingsFrameWritten)
   347  				return len(b), nil
   348  			})
   349  			conn = mockquic.NewMockEarlyConnection(mockCtrl)
   350  			conn.EXPECT().OpenUniStream().Return(controlStr, nil)
   351  			conn.EXPECT().HandshakeComplete().Return(handshakeChan)
   352  			conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done"))
   353  			dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
   354  				return conn, nil
   355  			}
   356  			var err error
   357  			req, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil)
   358  			Expect(err).ToNot(HaveOccurred())
   359  		})
   360  
   361  		AfterEach(func() {
   362  			testDone <- struct{}{}
   363  			Eventually(settingsFrameWritten).Should(BeClosed())
   364  		})
   365  
   366  		It("hijacks an unidirectional stream of unknown stream type", func() {
   367  			streamTypeChan := make(chan StreamType, 1)
   368  			cl.opts.UniStreamHijacker = func(st StreamType, _ quic.Connection, _ quic.ReceiveStream, err error) bool {
   369  				Expect(err).ToNot(HaveOccurred())
   370  				streamTypeChan <- st
   371  				return true
   372  			}
   373  
   374  			buf := bytes.NewBuffer(quicvarint.Append(nil, 0x54))
   375  			unknownStr := mockquic.NewMockStream(mockCtrl)
   376  			unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
   377  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   378  				return unknownStr, nil
   379  			})
   380  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   381  				<-testDone
   382  				return nil, errors.New("test done")
   383  			})
   384  			_, err := cl.RoundTripOpt(req, RoundTripOpt{})
   385  			Expect(err).To(MatchError("done"))
   386  			Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54)))
   387  			time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
   388  		})
   389  
   390  		It("handles errors that occur when reading the stream type", func() {
   391  			testErr := errors.New("test error")
   392  			done := make(chan struct{})
   393  			unknownStr := mockquic.NewMockStream(mockCtrl)
   394  			cl.opts.UniStreamHijacker = func(st StreamType, _ quic.Connection, str quic.ReceiveStream, err error) bool {
   395  				defer close(done)
   396  				Expect(st).To(BeZero())
   397  				Expect(str).To(Equal(unknownStr))
   398  				Expect(err).To(MatchError(testErr))
   399  				return true
   400  			}
   401  
   402  			unknownStr.EXPECT().Read(gomock.Any()).Return(0, testErr)
   403  			conn.EXPECT().AcceptUniStream(gomock.Any()).Return(unknownStr, nil)
   404  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   405  				<-testDone
   406  				return nil, errors.New("test done")
   407  			})
   408  			_, err := cl.RoundTripOpt(req, RoundTripOpt{})
   409  			Expect(err).To(MatchError("done"))
   410  			Eventually(done).Should(BeClosed())
   411  			time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
   412  		})
   413  
   414  		It("cancels reading when hijacker didn't hijack an unidirectional stream", func() {
   415  			streamTypeChan := make(chan StreamType, 1)
   416  			cl.opts.UniStreamHijacker = func(st StreamType, _ quic.Connection, _ quic.ReceiveStream, err error) bool {
   417  				Expect(err).ToNot(HaveOccurred())
   418  				streamTypeChan <- st
   419  				return false
   420  			}
   421  
   422  			buf := bytes.NewBuffer(quicvarint.Append(nil, 0x54))
   423  			unknownStr := mockquic.NewMockStream(mockCtrl)
   424  			unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
   425  			unknownStr.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError))
   426  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   427  				return unknownStr, nil
   428  			})
   429  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   430  				<-testDone
   431  				return nil, errors.New("test done")
   432  			})
   433  			_, err := cl.RoundTripOpt(req, RoundTripOpt{})
   434  			Expect(err).To(MatchError("done"))
   435  			Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54)))
   436  			time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
   437  		})
   438  	})
   439  
   440  	Context("control stream handling", func() {
   441  		var (
   442  			req                  *http.Request
   443  			conn                 *mockquic.MockEarlyConnection
   444  			settingsFrameWritten chan struct{}
   445  		)
   446  		testDone := make(chan struct{}, 1)
   447  
   448  		BeforeEach(func() {
   449  			settingsFrameWritten = make(chan struct{})
   450  			controlStr := mockquic.NewMockStream(mockCtrl)
   451  			controlStr.EXPECT().Write(gomock.Any()).Do(func(b []byte) (int, error) {
   452  				defer GinkgoRecover()
   453  				close(settingsFrameWritten)
   454  				return len(b), nil
   455  			})
   456  			conn = mockquic.NewMockEarlyConnection(mockCtrl)
   457  			conn.EXPECT().OpenUniStream().Return(controlStr, nil)
   458  			conn.EXPECT().HandshakeComplete().Return(handshakeChan)
   459  			dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
   460  				return conn, nil
   461  			}
   462  			var err error
   463  			req, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil)
   464  			Expect(err).ToNot(HaveOccurred())
   465  		})
   466  
   467  		AfterEach(func() {
   468  			testDone <- struct{}{}
   469  			Eventually(settingsFrameWritten).Should(BeClosed())
   470  		})
   471  
   472  		It("parses the SETTINGS frame", func() {
   473  			b := quicvarint.Append(nil, streamTypeControlStream)
   474  			b = (&settingsFrame{
   475  				Datagram:        true,
   476  				ExtendedConnect: true,
   477  				Other:           map[uint64]uint64{1337: 42},
   478  			}).Append(b)
   479  			r := bytes.NewReader(b)
   480  			controlStr := mockquic.NewMockStream(mockCtrl)
   481  			controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes()
   482  			conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done"))
   483  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   484  				return controlStr, nil
   485  			})
   486  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   487  				<-testDone
   488  				return nil, errors.New("test done")
   489  			})
   490  			conn.EXPECT().Context().Return(context.Background())
   491  			_, err := cl.RoundTripOpt(req, RoundTripOpt{CheckSettings: func(settings Settings) error {
   492  				defer GinkgoRecover()
   493  				Expect(settings.EnableDatagram).To(BeTrue())
   494  				Expect(settings.EnableExtendedConnect).To(BeTrue())
   495  				Expect(settings.Other).To(HaveLen(1))
   496  				Expect(settings.Other).To(HaveKeyWithValue(uint64(1337), uint64(42)))
   497  				return nil
   498  			}})
   499  			Expect(err).To(MatchError("done"))
   500  			time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
   501  		})
   502  
   503  		It("allows the client to reject the SETTINGS using the CheckSettings RoundTripOpt", func() {
   504  			b := quicvarint.Append(nil, streamTypeControlStream)
   505  			b = (&settingsFrame{}).Append(b)
   506  			r := bytes.NewReader(b)
   507  			controlStr := mockquic.NewMockStream(mockCtrl)
   508  			controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes()
   509  			// Don't EXPECT any call to OpenStreamSync.
   510  			// When the SETTINGS are rejected, we don't even open the request stream.
   511  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   512  				return controlStr, nil
   513  			})
   514  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   515  				<-testDone
   516  				return nil, errors.New("test done")
   517  			})
   518  			conn.EXPECT().Context().Return(context.Background())
   519  			_, err := cl.RoundTripOpt(req, RoundTripOpt{CheckSettings: func(settings Settings) error {
   520  				return errors.New("wrong settings")
   521  			}})
   522  			Expect(err).To(MatchError("wrong settings"))
   523  			time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
   524  		})
   525  
   526  		It("rejects duplicate control streams", func() {
   527  			b := quicvarint.Append(nil, streamTypeControlStream)
   528  			b = (&settingsFrame{}).Append(b)
   529  			r1 := bytes.NewReader(b)
   530  			controlStr1 := mockquic.NewMockStream(mockCtrl)
   531  			controlStr1.EXPECT().Read(gomock.Any()).DoAndReturn(r1.Read).AnyTimes()
   532  			r2 := bytes.NewReader(b)
   533  			controlStr2 := mockquic.NewMockStream(mockCtrl)
   534  			controlStr2.EXPECT().Read(gomock.Any()).DoAndReturn(r2.Read).AnyTimes()
   535  			done := make(chan struct{})
   536  			conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done"))
   537  			conn.EXPECT().CloseWithError(qerr.ApplicationErrorCode(ErrCodeStreamCreationError), "duplicate control stream").Do(func(qerr.ApplicationErrorCode, string) error {
   538  				close(done)
   539  				return nil
   540  			})
   541  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   542  				return controlStr1, nil
   543  			})
   544  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   545  				return controlStr2, nil
   546  			})
   547  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   548  				<-done
   549  				return nil, errors.New("test done")
   550  			})
   551  			_, err := cl.RoundTripOpt(req, RoundTripOpt{})
   552  			Expect(err).To(HaveOccurred())
   553  			Eventually(done).Should(BeClosed())
   554  		})
   555  
   556  		for _, t := range []uint64{streamTypeQPACKEncoderStream, streamTypeQPACKDecoderStream} {
   557  			streamType := t
   558  			name := "encoder"
   559  			if streamType == streamTypeQPACKDecoderStream {
   560  				name = "decoder"
   561  			}
   562  
   563  			It(fmt.Sprintf("ignores the QPACK %s streams", name), func() {
   564  				buf := bytes.NewBuffer(quicvarint.Append(nil, streamType))
   565  				str := mockquic.NewMockStream(mockCtrl)
   566  				str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
   567  
   568  				conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done"))
   569  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   570  					return str, nil
   571  				})
   572  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   573  					<-testDone
   574  					return nil, errors.New("test done")
   575  				})
   576  				_, err := cl.RoundTripOpt(req, RoundTripOpt{})
   577  				Expect(err).To(MatchError("done"))
   578  				time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to str.CancelRead
   579  			})
   580  		}
   581  
   582  		It("resets streams other than the control stream and the QPACK streams", func() {
   583  			buf := bytes.NewBuffer(quicvarint.Append(nil, 0x1337))
   584  			str := mockquic.NewMockStream(mockCtrl)
   585  			str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
   586  			done := make(chan struct{})
   587  			str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError)).Do(func(quic.StreamErrorCode) { close(done) })
   588  
   589  			conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done"))
   590  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   591  				return str, nil
   592  			})
   593  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   594  				<-testDone
   595  				return nil, errors.New("test done")
   596  			})
   597  			_, err := cl.RoundTripOpt(req, RoundTripOpt{})
   598  			Expect(err).To(MatchError("done"))
   599  			Eventually(done).Should(BeClosed())
   600  		})
   601  
   602  		It("errors when the first frame on the control stream is not a SETTINGS frame", func() {
   603  			b := quicvarint.Append(nil, streamTypeControlStream)
   604  			b = (&dataFrame{}).Append(b)
   605  			r := bytes.NewReader(b)
   606  			controlStr := mockquic.NewMockStream(mockCtrl)
   607  			controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes()
   608  
   609  			conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done"))
   610  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   611  				return controlStr, nil
   612  			})
   613  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   614  				<-testDone
   615  				return nil, errors.New("test done")
   616  			})
   617  			done := make(chan struct{})
   618  			conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeMissingSettings), gomock.Any()).Do(func(quic.ApplicationErrorCode, string) error {
   619  				close(done)
   620  				return nil
   621  			})
   622  			_, err := cl.RoundTripOpt(req, RoundTripOpt{})
   623  			Expect(err).To(MatchError("done"))
   624  			Eventually(done).Should(BeClosed())
   625  		})
   626  
   627  		It("errors when the first frame on the control stream is not a SETTINGS frame, when checking SETTINGS", func() {
   628  			b := quicvarint.Append(nil, streamTypeControlStream)
   629  			b = (&dataFrame{}).Append(b)
   630  			r := bytes.NewReader(b)
   631  			controlStr := mockquic.NewMockStream(mockCtrl)
   632  			controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes()
   633  
   634  			// Don't EXPECT any calls to OpenStreamSync.
   635  			// We fail before we even get the chance to open the request stream.
   636  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   637  				return controlStr, nil
   638  			})
   639  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   640  				<-testDone
   641  				return nil, errors.New("test done")
   642  			})
   643  			doneCtx, doneCancel := context.WithCancelCause(context.Background())
   644  			conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeMissingSettings), gomock.Any()).Do(func(quic.ApplicationErrorCode, string) error {
   645  				doneCancel(errors.New("done"))
   646  				return nil
   647  			})
   648  			conn.EXPECT().Context().Return(doneCtx).Times(2)
   649  			var checked bool
   650  			_, err := cl.RoundTripOpt(req, RoundTripOpt{
   651  				CheckSettings: func(Settings) error { checked = true; return nil },
   652  			})
   653  			Expect(checked).To(BeFalse())
   654  			Expect(err).To(MatchError("done"))
   655  			Eventually(doneCtx.Done()).Should(BeClosed())
   656  		})
   657  
   658  		It("errors when parsing the frame on the control stream fails", func() {
   659  			b := quicvarint.Append(nil, streamTypeControlStream)
   660  			b = (&settingsFrame{}).Append(b)
   661  			r := bytes.NewReader(b[:len(b)-1])
   662  			controlStr := mockquic.NewMockStream(mockCtrl)
   663  			controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes()
   664  			conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done"))
   665  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   666  				return controlStr, nil
   667  			})
   668  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   669  				<-testDone
   670  				return nil, errors.New("test done")
   671  			})
   672  			done := make(chan struct{})
   673  			conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameError), gomock.Any()).Do(func(code quic.ApplicationErrorCode, _ string) error {
   674  				close(done)
   675  				return nil
   676  			})
   677  			_, err := cl.RoundTripOpt(req, RoundTripOpt{})
   678  			Expect(err).To(MatchError("done"))
   679  			Eventually(done).Should(BeClosed())
   680  		})
   681  
   682  		It("errors when parsing the server opens a push stream", func() {
   683  			buf := bytes.NewBuffer(quicvarint.Append(nil, streamTypePushStream))
   684  			controlStr := mockquic.NewMockStream(mockCtrl)
   685  			controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
   686  			conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done"))
   687  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   688  				return controlStr, nil
   689  			})
   690  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   691  				<-testDone
   692  				return nil, errors.New("test done")
   693  			})
   694  			done := make(chan struct{})
   695  			conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeIDError), gomock.Any()).Do(func(quic.ApplicationErrorCode, string) error {
   696  				close(done)
   697  				return nil
   698  			})
   699  			_, err := cl.RoundTripOpt(req, RoundTripOpt{})
   700  			Expect(err).To(MatchError("done"))
   701  			Eventually(done).Should(BeClosed())
   702  		})
   703  
   704  		It("errors when the server advertises datagram support (and we enabled support for it)", func() {
   705  			cl.opts.EnableDatagram = true
   706  			b := quicvarint.Append(nil, streamTypeControlStream)
   707  			b = (&settingsFrame{Datagram: true}).Append(b)
   708  			r := bytes.NewReader(b)
   709  			controlStr := mockquic.NewMockStream(mockCtrl)
   710  			controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes()
   711  			conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done"))
   712  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   713  				return controlStr, nil
   714  			})
   715  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   716  				<-testDone
   717  				return nil, errors.New("test done")
   718  			})
   719  			conn.EXPECT().ConnectionState().Return(quic.ConnectionState{SupportsDatagrams: false})
   720  			done := make(chan struct{})
   721  			conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeSettingsError), "missing QUIC Datagram support").Do(func(quic.ApplicationErrorCode, string) error {
   722  				close(done)
   723  				return nil
   724  			})
   725  			_, err := cl.RoundTripOpt(req, RoundTripOpt{})
   726  			Expect(err).To(MatchError("done"))
   727  			Eventually(done).Should(BeClosed())
   728  		})
   729  	})
   730  
   731  	Context("Doing requests", func() {
   732  		var (
   733  			req                  *http.Request
   734  			str                  *mockquic.MockStream
   735  			conn                 *mockquic.MockEarlyConnection
   736  			settingsFrameWritten chan struct{}
   737  		)
   738  		testDone := make(chan struct{})
   739  
   740  		decodeHeader := func(str io.Reader) map[string]string {
   741  			fields := make(map[string]string)
   742  			decoder := qpack.NewDecoder(nil)
   743  
   744  			frame, err := parseNextFrame(str, nil)
   745  			ExpectWithOffset(1, err).ToNot(HaveOccurred())
   746  			ExpectWithOffset(1, frame).To(BeAssignableToTypeOf(&headersFrame{}))
   747  			headersFrame := frame.(*headersFrame)
   748  			data := make([]byte, headersFrame.Length)
   749  			_, err = io.ReadFull(str, data)
   750  			ExpectWithOffset(1, err).ToNot(HaveOccurred())
   751  			hfs, err := decoder.DecodeFull(data)
   752  			ExpectWithOffset(1, err).ToNot(HaveOccurred())
   753  			for _, p := range hfs {
   754  				fields[p.Name] = p.Value
   755  			}
   756  			return fields
   757  		}
   758  
   759  		getResponse := func(status int) []byte {
   760  			buf := &bytes.Buffer{}
   761  			rstr := mockquic.NewMockStream(mockCtrl)
   762  			rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes()
   763  			rw := newResponseWriter(rstr, nil, utils.DefaultLogger)
   764  			rw.WriteHeader(status)
   765  			rw.Flush()
   766  			return buf.Bytes()
   767  		}
   768  
   769  		BeforeEach(func() {
   770  			settingsFrameWritten = make(chan struct{})
   771  			controlStr := mockquic.NewMockStream(mockCtrl)
   772  			controlStr.EXPECT().Write(gomock.Any()).Do(func(b []byte) (int, error) {
   773  				defer GinkgoRecover()
   774  				r := bytes.NewReader(b)
   775  				streamType, err := quicvarint.Read(r)
   776  				Expect(err).ToNot(HaveOccurred())
   777  				Expect(streamType).To(BeEquivalentTo(streamTypeControlStream))
   778  				close(settingsFrameWritten)
   779  				return len(b), nil
   780  			}) // SETTINGS frame
   781  			str = mockquic.NewMockStream(mockCtrl)
   782  			conn = mockquic.NewMockEarlyConnection(mockCtrl)
   783  			conn.EXPECT().OpenUniStream().Return(controlStr, nil)
   784  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   785  				<-testDone
   786  				return nil, errors.New("test done")
   787  			})
   788  			dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
   789  				return conn, nil
   790  			}
   791  			var err error
   792  			req, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil)
   793  			Expect(err).ToNot(HaveOccurred())
   794  		})
   795  
		// Teardown: release the AcceptUniStream expectation registered in
		// BeforeEach (it blocks on testDone), then check that the control-stream
		// Write callback ran, i.e. settingsFrameWritten was eventually closed.
		AfterEach(func() {
			testDone <- struct{}{}
			Eventually(settingsFrameWritten).Should(BeClosed())
		})
   800  
   801  		It("errors if it can't open a stream", func() {
   802  			testErr := errors.New("stream open error")
   803  			conn.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr)
   804  			conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).MaxTimes(1)
   805  			conn.EXPECT().HandshakeComplete().Return(handshakeChan)
   806  			_, err := cl.RoundTripOpt(req, RoundTripOpt{})
   807  			Expect(err).To(MatchError(testErr))
   808  		})
   809  
   810  		It("performs a 0-RTT request", func() {
   811  			testErr := errors.New("stream open error")
   812  			req.Method = MethodGet0RTT
   813  			// don't EXPECT any calls to HandshakeComplete()
   814  			conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
   815  			buf := &bytes.Buffer{}
   816  			str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write).AnyTimes()
   817  			str.EXPECT().Close()
   818  			str.EXPECT().CancelWrite(gomock.Any())
   819  			str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) {
   820  				return 0, testErr
   821  			})
   822  			_, err := cl.RoundTripOpt(req, RoundTripOpt{})
   823  			Expect(err).To(MatchError(testErr))
   824  			Expect(decodeHeader(buf)).To(HaveKeyWithValue(":method", "GET"))
   825  		})
   826  
   827  		It("returns a response", func() {
   828  			rspBuf := bytes.NewBuffer(getResponse(418))
   829  			gomock.InOrder(
   830  				conn.EXPECT().HandshakeComplete().Return(handshakeChan),
   831  				conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil),
   832  				conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}),
   833  			)
   834  			str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })
   835  			str.EXPECT().Close()
   836  			str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes()
   837  			rsp, err := cl.RoundTripOpt(req, RoundTripOpt{})
   838  			Expect(err).ToNot(HaveOccurred())
   839  			Expect(rsp.Proto).To(Equal("HTTP/3.0"))
   840  			Expect(rsp.ProtoMajor).To(Equal(3))
   841  			Expect(rsp.StatusCode).To(Equal(418))
   842  			Expect(rsp.Request).ToNot(BeNil())
   843  		})
   844  
   845  		It("doesn't close the request stream, with DontCloseRequestStream set", func() {
   846  			rspBuf := bytes.NewBuffer(getResponse(418))
   847  			gomock.InOrder(
   848  				conn.EXPECT().HandshakeComplete().Return(handshakeChan),
   849  				conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil),
   850  				conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}),
   851  			)
   852  			str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })
   853  			str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes()
   854  			rsp, err := cl.RoundTripOpt(req, RoundTripOpt{DontCloseRequestStream: true})
   855  			Expect(err).ToNot(HaveOccurred())
   856  			Expect(rsp.Proto).To(Equal("HTTP/3.0"))
   857  			Expect(rsp.ProtoMajor).To(Equal(3))
   858  			Expect(rsp.StatusCode).To(Equal(418))
   859  		})
   860  
   861  		Context("requests containing a Body", func() {
   862  			var strBuf *bytes.Buffer
   863  
   864  			BeforeEach(func() {
   865  				strBuf = &bytes.Buffer{}
   866  				gomock.InOrder(
   867  					conn.EXPECT().HandshakeComplete().Return(handshakeChan),
   868  					conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil),
   869  				)
   870  				body := &mockBody{}
   871  				body.SetData([]byte("request body"))
   872  				var err error
   873  				req, err = http.NewRequest("POST", "https://quic.clemente.io:1337/upload", body)
   874  				Expect(err).ToNot(HaveOccurred())
   875  				str.EXPECT().Write(gomock.Any()).DoAndReturn(strBuf.Write).AnyTimes()
   876  			})
   877  
   878  			It("sends a request", func() {
   879  				done := make(chan struct{})
   880  				gomock.InOrder(
   881  					str.EXPECT().Close().Do(func() error { close(done); return nil }),
   882  					str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1), // when reading the response errors
   883  				)
   884  				// the response body is sent asynchronously, while already reading the response
   885  				str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) {
   886  					<-done
   887  					return 0, errors.New("test done")
   888  				})
   889  				_, err := cl.RoundTripOpt(req, RoundTripOpt{})
   890  				Expect(err).To(MatchError("test done"))
   891  				hfs := decodeHeader(strBuf)
   892  				Expect(hfs).To(HaveKeyWithValue(":method", "POST"))
   893  				Expect(hfs).To(HaveKeyWithValue(":path", "/upload"))
   894  			})
   895  
   896  			It("doesn't send more bytes than allowed by http.Request.ContentLength", func() {
   897  				req.ContentLength = 7
   898  				var once sync.Once
   899  				done := make(chan struct{})
   900  				gomock.InOrder(
   901  					str.EXPECT().CancelWrite(gomock.Any()).Do(func(c quic.StreamErrorCode) {
   902  						once.Do(func() {
   903  							Expect(c).To(Equal(quic.StreamErrorCode(ErrCodeRequestCanceled)))
   904  							close(done)
   905  						})
   906  					}).AnyTimes(),
   907  					str.EXPECT().Close().MaxTimes(1),
   908  					str.EXPECT().CancelWrite(gomock.Any()).AnyTimes(),
   909  				)
   910  				str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) {
   911  					<-done
   912  					return 0, errors.New("done")
   913  				})
   914  				cl.RoundTripOpt(req, RoundTripOpt{})
   915  				Expect(strBuf.String()).To(ContainSubstring("request"))
   916  				Expect(strBuf.String()).ToNot(ContainSubstring("request body"))
   917  			})
   918  
   919  			It("returns the error that occurred when reading the body", func() {
   920  				req.Body.(*mockBody).readErr = errors.New("testErr")
   921  				done := make(chan struct{})
   922  				gomock.InOrder(
   923  					str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled)).Do(func(quic.StreamErrorCode) {
   924  						close(done)
   925  					}),
   926  					str.EXPECT().CancelWrite(gomock.Any()),
   927  				)
   928  
   929  				// the response body is sent asynchronously, while already reading the response
   930  				str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) {
   931  					<-done
   932  					return 0, errors.New("test done")
   933  				})
   934  				closed := make(chan struct{})
   935  				str.EXPECT().Close().Do(func() error { close(closed); return nil })
   936  				_, err := cl.RoundTripOpt(req, RoundTripOpt{})
   937  				Expect(err).To(MatchError("test done"))
   938  				Eventually(closed).Should(BeClosed())
   939  			})
   940  
   941  			It("closes the connection when the first frame is not a HEADERS frame", func() {
   942  				b := (&dataFrame{Length: 0x42}).Append(nil)
   943  				conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), gomock.Any())
   944  				closed := make(chan struct{})
   945  				r := bytes.NewReader(b)
   946  				str.EXPECT().Close().Do(func() error { close(closed); return nil })
   947  				str.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes()
   948  				_, err := cl.RoundTripOpt(req, RoundTripOpt{})
   949  				Expect(err).To(MatchError("expected first frame to be a HEADERS frame"))
   950  				Eventually(closed).Should(BeClosed())
   951  			})
   952  
   953  			It("cancels the stream when parsing the headers fails", func() {
   954  				headerBuf := &bytes.Buffer{}
   955  				enc := qpack.NewEncoder(headerBuf)
   956  				Expect(enc.WriteField(qpack.HeaderField{Name: ":method", Value: "GET"})).To(Succeed()) // not a valid response pseudo header
   957  				Expect(enc.Close()).To(Succeed())
   958  				b := (&headersFrame{Length: uint64(headerBuf.Len())}).Append(nil)
   959  				b = append(b, headerBuf.Bytes()...)
   960  
   961  				r := bytes.NewReader(b)
   962  				str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeMessageError))
   963  				closed := make(chan struct{})
   964  				str.EXPECT().Close().Do(func() error { close(closed); return nil })
   965  				str.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes()
   966  				_, err := cl.RoundTripOpt(req, RoundTripOpt{})
   967  				Expect(err).To(HaveOccurred())
   968  				Eventually(closed).Should(BeClosed())
   969  			})
   970  
   971  			It("cancels the stream when the HEADERS frame is too large", func() {
   972  				b := (&headersFrame{Length: 1338}).Append(nil)
   973  				r := bytes.NewReader(b)
   974  				str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeFrameError))
   975  				closed := make(chan struct{})
   976  				str.EXPECT().Close().Do(func() error { close(closed); return nil })
   977  				str.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes()
   978  				_, err := cl.RoundTripOpt(req, RoundTripOpt{})
   979  				Expect(err).To(MatchError("HEADERS frame too large: 1338 bytes (max: 1337)"))
   980  				Eventually(closed).Should(BeClosed())
   981  			})
   982  		})
   983  
   984  		Context("request cancellations", func() {
   985  			for _, dontClose := range []bool{false, true} {
   986  				dontClose := dontClose
   987  
   988  				Context(fmt.Sprintf("with DontCloseRequestStream: %t", dontClose), func() {
   989  					roundTripOpt := RoundTripOpt{DontCloseRequestStream: dontClose}
   990  
   991  					It("cancels a request while waiting for the handshake to complete", func() {
   992  						ctx, cancel := context.WithCancel(context.Background())
   993  						req := req.WithContext(ctx)
   994  						conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
   995  
   996  						errChan := make(chan error)
   997  						go func() {
   998  							_, err := cl.RoundTripOpt(req, roundTripOpt)
   999  							errChan <- err
  1000  						}()
  1001  						Consistently(errChan).ShouldNot(Receive())
  1002  						cancel()
  1003  						Eventually(errChan).Should(Receive(MatchError("context canceled")))
  1004  					})
  1005  
  1006  					It("cancels a request while the request is still in flight", func() {
  1007  						ctx, cancel := context.WithCancel(context.Background())
  1008  						req := req.WithContext(ctx)
  1009  						conn.EXPECT().HandshakeComplete().Return(handshakeChan)
  1010  						conn.EXPECT().OpenStreamSync(ctx).Return(str, nil)
  1011  						buf := &bytes.Buffer{}
  1012  						str.EXPECT().Close().MaxTimes(1)
  1013  
  1014  						str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write)
  1015  
  1016  						done := make(chan struct{})
  1017  						canceled := make(chan struct{})
  1018  						gomock.InOrder(
  1019  							str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled)).Do(func(quic.StreamErrorCode) { close(canceled) }),
  1020  							str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeRequestCanceled)).Do(func(quic.StreamErrorCode) { close(done) }),
  1021  						)
  1022  						str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1)
  1023  						str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) {
  1024  							cancel()
  1025  							<-canceled
  1026  							return 0, errors.New("test done")
  1027  						})
  1028  						_, err := cl.RoundTripOpt(req, roundTripOpt)
  1029  						Expect(err).To(MatchError(context.Canceled))
  1030  						Eventually(done).Should(BeClosed())
  1031  					})
  1032  				})
  1033  			}
  1034  
  1035  			It("cancels a request after the response arrived", func() {
  1036  				rspBuf := bytes.NewBuffer(getResponse(404))
  1037  
  1038  				ctx, cancel := context.WithCancel(context.Background())
  1039  				req := req.WithContext(ctx)
  1040  				conn.EXPECT().HandshakeComplete().Return(handshakeChan)
  1041  				conn.EXPECT().OpenStreamSync(ctx).Return(str, nil)
  1042  				conn.EXPECT().ConnectionState().Return(quic.ConnectionState{})
  1043  				buf := &bytes.Buffer{}
  1044  				str.EXPECT().Close().MaxTimes(1)
  1045  
  1046  				done := make(chan struct{})
  1047  				str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write)
  1048  				str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes()
  1049  				str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled))
  1050  				str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeRequestCanceled)).Do(func(quic.StreamErrorCode) { close(done) })
  1051  				_, err := cl.RoundTripOpt(req, RoundTripOpt{})
  1052  				Expect(err).ToNot(HaveOccurred())
  1053  				cancel()
  1054  				Eventually(done).Should(BeClosed())
  1055  			})
  1056  		})
  1057  
  1058  		Context("gzip compression", func() {
  1059  			BeforeEach(func() {
  1060  				conn.EXPECT().HandshakeComplete().Return(handshakeChan)
  1061  			})
  1062  
  1063  			It("adds the gzip header to requests", func() {
  1064  				conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
  1065  				buf := &bytes.Buffer{}
  1066  				str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write)
  1067  				gomock.InOrder(
  1068  					str.EXPECT().Close(),
  1069  					str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1), // when the Read errors
  1070  				)
  1071  				str.EXPECT().Read(gomock.Any()).Return(0, errors.New("test done"))
  1072  				_, err := cl.RoundTripOpt(req, RoundTripOpt{})
  1073  				Expect(err).To(MatchError("test done"))
  1074  				hfs := decodeHeader(buf)
  1075  				Expect(hfs).To(HaveKeyWithValue("accept-encoding", "gzip"))
  1076  			})
  1077  
  1078  			It("doesn't add gzip if the header disable it", func() {
  1079  				client, err := newClient("quic.clemente.io:1337", nil, &roundTripperOpts{DisableCompression: true}, nil, nil)
  1080  				Expect(err).ToNot(HaveOccurred())
  1081  				conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
  1082  				buf := &bytes.Buffer{}
  1083  				str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write)
  1084  				gomock.InOrder(
  1085  					str.EXPECT().Close(),
  1086  					str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1), // when the Read errors
  1087  				)
  1088  				str.EXPECT().Read(gomock.Any()).Return(0, errors.New("test done"))
  1089  				_, err = client.RoundTripOpt(req, RoundTripOpt{})
  1090  				Expect(err).To(MatchError("test done"))
  1091  				hfs := decodeHeader(buf)
  1092  				Expect(hfs).ToNot(HaveKey("accept-encoding"))
  1093  			})
  1094  
  1095  			It("decompresses the response", func() {
  1096  				conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
  1097  				conn.EXPECT().ConnectionState().Return(quic.ConnectionState{})
  1098  				buf := &bytes.Buffer{}
  1099  				rstr := mockquic.NewMockStream(mockCtrl)
  1100  				rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes()
  1101  				rw := newResponseWriter(rstr, nil, utils.DefaultLogger)
  1102  				rw.Header().Set("Content-Encoding", "gzip")
  1103  				gz := gzip.NewWriter(rw)
  1104  				gz.Write([]byte("gzipped response"))
  1105  				gz.Close()
  1106  				rw.Flush()
  1107  				str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })
  1108  				str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
  1109  				str.EXPECT().Close()
  1110  
  1111  				rsp, err := cl.RoundTripOpt(req, RoundTripOpt{})
  1112  				Expect(err).ToNot(HaveOccurred())
  1113  				data, err := io.ReadAll(rsp.Body)
  1114  				Expect(err).ToNot(HaveOccurred())
  1115  				Expect(rsp.ContentLength).To(BeEquivalentTo(-1))
  1116  				Expect(string(data)).To(Equal("gzipped response"))
  1117  				Expect(rsp.Header.Get("Content-Encoding")).To(BeEmpty())
  1118  				Expect(rsp.Uncompressed).To(BeTrue())
  1119  			})
  1120  
  1121  			It("only decompresses the response if the response contains the right content-encoding header", func() {
  1122  				conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
  1123  				conn.EXPECT().ConnectionState().Return(quic.ConnectionState{})
  1124  				buf := &bytes.Buffer{}
  1125  				rstr := mockquic.NewMockStream(mockCtrl)
  1126  				rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes()
  1127  				rw := newResponseWriter(rstr, nil, utils.DefaultLogger)
  1128  				rw.Write([]byte("not gzipped"))
  1129  				rw.Flush()
  1130  				str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })
  1131  				str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
  1132  				str.EXPECT().Close()
  1133  
  1134  				rsp, err := cl.RoundTripOpt(req, RoundTripOpt{})
  1135  				Expect(err).ToNot(HaveOccurred())
  1136  				data, err := io.ReadAll(rsp.Body)
  1137  				Expect(err).ToNot(HaveOccurred())
  1138  				Expect(string(data)).To(Equal("not gzipped"))
  1139  				Expect(rsp.Header.Get("Content-Encoding")).To(BeEmpty())
  1140  			})
  1141  		})
  1142  	})
  1143  })