github.com/quic-go/quic-go@v0.44.0/http3/roundtrip_test.go (about)

     1  package http3
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"crypto/tls"
     7  	"errors"
     8  	"io"
     9  	"net/http"
    10  	"time"
    11  
    12  	"github.com/quic-go/quic-go"
    13  	mockquic "github.com/quic-go/quic-go/internal/mocks/quic"
    14  	"github.com/quic-go/quic-go/internal/protocol"
    15  	"github.com/quic-go/quic-go/internal/qerr"
    16  
    17  	. "github.com/onsi/ginkgo/v2"
    18  	. "github.com/onsi/gomega"
    19  	"go.uber.org/mock/gomock"
    20  )
    21  
    22  type mockBody struct {
    23  	reader   bytes.Reader
    24  	readErr  error
    25  	closeErr error
    26  	closed   bool
    27  }
    28  
    29  // make sure the mockBody can be used as a http.Request.Body
    30  var _ io.ReadCloser = &mockBody{}
    31  
    32  func (m *mockBody) Read(p []byte) (int, error) {
    33  	if m.readErr != nil {
    34  		return 0, m.readErr
    35  	}
    36  	return m.reader.Read(p)
    37  }
    38  
    39  func (m *mockBody) SetData(data []byte) {
    40  	m.reader = *bytes.NewReader(data)
    41  }
    42  
    43  func (m *mockBody) Close() error {
    44  	m.closed = true
    45  	return m.closeErr
    46  }
    47  
    48  var _ = Describe("RoundTripper", func() {
    49  	var req *http.Request
    50  
    51  	BeforeEach(func() {
    52  		var err error
    53  		req, err = http.NewRequest("GET", "https://www.example.org/file1.html", nil)
    54  		Expect(err).ToNot(HaveOccurred())
    55  	})
    56  
    57  	It("rejects quic.Configs that allow multiple QUIC versions", func() {
    58  		qconf := &quic.Config{
    59  			Versions: []quic.Version{protocol.Version2, protocol.Version1},
    60  		}
    61  		rt := &RoundTripper{QUICConfig: qconf}
    62  		_, err := rt.RoundTrip(req)
    63  		Expect(err).To(MatchError("can only use a single QUIC version for dialing a HTTP/3 connection"))
    64  	})
    65  
    66  	It("uses the default QUIC and TLS config if none is give", func() {
    67  		var dialAddrCalled bool
    68  		rt := &RoundTripper{
    69  			Dial: func(_ context.Context, _ string, tlsConf *tls.Config, quicConf *quic.Config) (quic.EarlyConnection, error) {
    70  				defer GinkgoRecover()
    71  				Expect(quicConf.MaxIncomingStreams).To(Equal(defaultQuicConfig.MaxIncomingStreams))
    72  				Expect(tlsConf.NextProtos).To(Equal([]string{NextProtoH3}))
    73  				Expect(quicConf.Versions).To(Equal([]protocol.Version{protocol.Version1}))
    74  				dialAddrCalled = true
    75  				return nil, errors.New("test done")
    76  			},
    77  		}
    78  		_, err := rt.RoundTripOpt(req, RoundTripOpt{})
    79  		Expect(err).To(MatchError("test done"))
    80  		Expect(dialAddrCalled).To(BeTrue())
    81  	})
    82  
    83  	It("adds the port to the hostname, if none is given", func() {
    84  		var dialAddrCalled bool
    85  		rt := &RoundTripper{
    86  			Dial: func(_ context.Context, hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlyConnection, error) {
    87  				defer GinkgoRecover()
    88  				Expect(hostname).To(Equal("quic.clemente.io:443"))
    89  				dialAddrCalled = true
    90  				return nil, errors.New("test done")
    91  			},
    92  		}
    93  		req, err := http.NewRequest("GET", "https://quic.clemente.io:443", nil)
    94  		Expect(err).ToNot(HaveOccurred())
    95  		_, err = rt.RoundTripOpt(req, RoundTripOpt{})
    96  		Expect(err).To(MatchError("test done"))
    97  		Expect(dialAddrCalled).To(BeTrue())
    98  	})
    99  
   100  	It("sets the ServerName in the tls.Config, if not set", func() {
   101  		const host = "foo.bar"
   102  		var dialCalled bool
   103  		rt := &RoundTripper{
   104  			Dial: func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
   105  				defer GinkgoRecover()
   106  				Expect(tlsCfg.ServerName).To(Equal(host))
   107  				dialCalled = true
   108  				return nil, errors.New("test done")
   109  			},
   110  		}
   111  		req, err := http.NewRequest("GET", "https://foo.bar", nil)
   112  		Expect(err).ToNot(HaveOccurred())
   113  		_, err = rt.RoundTripOpt(req, RoundTripOpt{})
   114  		Expect(err).To(MatchError("test done"))
   115  		Expect(dialCalled).To(BeTrue())
   116  	})
   117  
   118  	It("uses the TLS config and QUIC config", func() {
   119  		tlsConf := &tls.Config{
   120  			ServerName: "foo.bar",
   121  			NextProtos: []string{"proto foo", "proto bar"},
   122  		}
   123  		quicConf := &quic.Config{MaxIdleTimeout: 3 * time.Nanosecond}
   124  		var dialAddrCalled bool
   125  		rt := &RoundTripper{
   126  			Dial: func(_ context.Context, host string, tlsConfP *tls.Config, quicConfP *quic.Config) (quic.EarlyConnection, error) {
   127  				defer GinkgoRecover()
   128  				Expect(host).To(Equal("www.example.org:443"))
   129  				Expect(tlsConfP.ServerName).To(Equal(tlsConf.ServerName))
   130  				Expect(tlsConfP.NextProtos).To(Equal([]string{NextProtoH3}))
   131  				Expect(quicConfP.MaxIdleTimeout).To(Equal(quicConf.MaxIdleTimeout))
   132  				dialAddrCalled = true
   133  				return nil, errors.New("test done")
   134  			},
   135  			QUICConfig:      quicConf,
   136  			TLSClientConfig: tlsConf,
   137  		}
   138  		_, err := rt.RoundTripOpt(req, RoundTripOpt{})
   139  		Expect(err).To(MatchError("test done"))
   140  		Expect(dialAddrCalled).To(BeTrue())
   141  		// make sure the original tls.Config was not modified
   142  		Expect(tlsConf.NextProtos).To(Equal([]string{"proto foo", "proto bar"}))
   143  	})
   144  
   145  	It("uses the custom dialer, if provided", func() {
   146  		testErr := errors.New("test done")
   147  		tlsConf := &tls.Config{ServerName: "foo.bar"}
   148  		quicConf := &quic.Config{MaxIdleTimeout: 1337 * time.Second}
   149  		// nolint:staticcheck // This is a test.
   150  		ctx := context.WithValue(context.Background(), "foo", "bar")
   151  		var dialerCalled bool
   152  		rt := &RoundTripper{
   153  			Dial: func(ctxP context.Context, address string, tlsConfP *tls.Config, quicConfP *quic.Config) (quic.EarlyConnection, error) {
   154  				defer GinkgoRecover()
   155  				Expect(ctx.Value("foo").(string)).To(Equal("bar"))
   156  				Expect(address).To(Equal("www.example.org:443"))
   157  				Expect(tlsConfP.ServerName).To(Equal("foo.bar"))
   158  				Expect(quicConfP.MaxIdleTimeout).To(Equal(quicConf.MaxIdleTimeout))
   159  				dialerCalled = true
   160  				return nil, testErr
   161  			},
   162  			TLSClientConfig: tlsConf,
   163  			QUICConfig:      quicConf,
   164  		}
   165  		_, err := rt.RoundTripOpt(req.WithContext(ctx), RoundTripOpt{})
   166  		Expect(err).To(MatchError(testErr))
   167  		Expect(dialerCalled).To(BeTrue())
   168  	})
   169  
   170  	It("enables HTTP/3 Datagrams", func() {
   171  		testErr := errors.New("handshake error")
   172  		rt := &RoundTripper{
   173  			EnableDatagrams: true,
   174  			Dial: func(_ context.Context, _ string, _ *tls.Config, quicConf *quic.Config) (quic.EarlyConnection, error) {
   175  				defer GinkgoRecover()
   176  				Expect(quicConf.EnableDatagrams).To(BeTrue())
   177  				return nil, testErr
   178  			},
   179  		}
   180  		_, err := rt.RoundTripOpt(req, RoundTripOpt{})
   181  		Expect(err).To(MatchError(testErr))
   182  	})
   183  
   184  	It("requires quic.Config.EnableDatagrams if HTTP/3 datagrams are enabled", func() {
   185  		rt := &RoundTripper{
   186  			QUICConfig:      &quic.Config{EnableDatagrams: false},
   187  			EnableDatagrams: true,
   188  			Dial: func(_ context.Context, _ string, _ *tls.Config, config *quic.Config) (quic.EarlyConnection, error) {
   189  				return nil, errors.New("handshake error")
   190  			},
   191  		}
   192  		_, err := rt.RoundTrip(req)
   193  		Expect(err).To(MatchError("HTTP Datagrams enabled, but QUIC Datagrams disabled"))
   194  	})
   195  
   196  	It("creates new clients", func() {
   197  		testErr := errors.New("test err")
   198  		req1, err := http.NewRequest("GET", "https://quic-go.net/foobar.html", nil)
   199  		Expect(err).ToNot(HaveOccurred())
   200  		req2, err := http.NewRequest("GET", "https://example.com/foobar.html", nil)
   201  		Expect(err).ToNot(HaveOccurred())
   202  		var hostsDialed []string
   203  		rt := &RoundTripper{
   204  			Dial: func(_ context.Context, host string, _ *tls.Config, quicConf *quic.Config) (quic.EarlyConnection, error) {
   205  				hostsDialed = append(hostsDialed, host)
   206  				return nil, testErr
   207  			},
   208  		}
   209  		_, err = rt.RoundTrip(req1)
   210  		Expect(err).To(MatchError(testErr))
   211  		_, err = rt.RoundTrip(req2)
   212  		Expect(err).To(MatchError(testErr))
   213  		Expect(hostsDialed).To(Equal([]string{"quic-go.net:443", "example.com:443"}))
   214  	})
   215  
	// Specs covering when the RoundTripper reuses a client for a host and when
	// it replaces it. newClient is stubbed so that every client the
	// RoundTripper creates is a mock that the spec queued on clientChan
	// beforehand; the Dial stubs count how often a connection is dialed.
	Context("reusing clients", func() {
		var (
			rt         *RoundTripper
			req1, req2 *http.Request
			// clientChan holds the mock clients that newClient hands out, in order.
			clientChan chan *MockSingleRoundTripper
		)

		BeforeEach(func() {
			clientChan = make(chan *MockSingleRoundTripper, 16)
			rt = &RoundTripper{
				// Non-blocking receive: asking for a client when none is queued
				// fails the spec instead of hanging.
				// NOTE(review): there is no GinkgoRecover here; if newClient is
				// invoked off the spec goroutine, the panic raised by Fail would
				// not be recovered — confirm against the RoundTripper implementation.
				newClient: func(quic.EarlyConnection) singleRoundTripper {
					select {
					case c := <-clientChan:
						return c
					default:
						Fail("no client")
						return nil
					}
				},
			}
			var err error
			// Two requests to the same host that differ only in their path.
			req1, err = http.NewRequest("GET", "https://quic-go.net/file1.html", nil)
			Expect(err).ToNot(HaveOccurred())
			req2, err = http.NewRequest("GET", "https://quic-go.net/file2.html", nil)
			Expect(err).ToNot(HaveOccurred())
			Expect(req1.URL).ToNot(Equal(req2.URL))
		})

		// Two successful requests to the same host are served by one mock
		// client, and Dial is counted exactly once.
		It("reuses existing clients", func() {
			cl := NewMockSingleRoundTripper(mockCtrl)
			clientChan <- cl
			conn := mockquic.NewMockEarlyConnection(mockCtrl)
			// An already-closed channel is returned from HandshakeComplete.
			handshakeChan := make(chan struct{})
			close(handshakeChan)
			conn.EXPECT().HandshakeComplete().Return(handshakeChan).MaxTimes(2)

			cl.EXPECT().RoundTrip(req1).Return(&http.Response{Request: req1}, nil)
			cl.EXPECT().RoundTrip(req2).Return(&http.Response{Request: req2}, nil)
			var count int
			rt.Dial = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
				count++
				return conn, nil
			}
			rsp, err := rt.RoundTrip(req1)
			Expect(err).ToNot(HaveOccurred())
			Expect(rsp.Request).To(Equal(req1))
			rsp, err = rt.RoundTrip(req2)
			Expect(err).ToNot(HaveOccurred())
			Expect(rsp.Request).To(Equal(req2))
			Expect(count).To(Equal(1))
		})

		// The first client fails its request with a generic error; the second
		// request is then expected on a second client, with Dial counted twice.
		It("immediately removes a clients when a request errored", func() {
			cl1 := NewMockSingleRoundTripper(mockCtrl)
			clientChan <- cl1
			cl2 := NewMockSingleRoundTripper(mockCtrl)
			clientChan <- cl2

			// These shadow the Context-level req1/req2 for this spec.
			req1, err := http.NewRequest("GET", "https://quic-go.net/foobar.html", nil)
			Expect(err).ToNot(HaveOccurred())
			req2, err := http.NewRequest("GET", "https://quic-go.net/bar.html", nil)
			Expect(err).ToNot(HaveOccurred())

			conn := mockquic.NewMockEarlyConnection(mockCtrl)
			var count int
			rt.Dial = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
				count++
				return conn, nil
			}
			testErr := errors.New("test err")
			handshakeChan := make(chan struct{})
			close(handshakeChan)
			conn.EXPECT().HandshakeComplete().Return(handshakeChan).MaxTimes(2)
			cl1.EXPECT().RoundTrip(req1).Return(nil, testErr)
			cl2.EXPECT().RoundTrip(req2).Return(&http.Response{Request: req2}, nil)
			_, err = rt.RoundTrip(req1)
			Expect(err).To(MatchError(testErr))
			rsp, err := rt.RoundTrip(req2)
			Expect(err).ToNot(HaveOccurred())
			Expect(rsp.Request).To(Equal(req2))
			Expect(count).To(Equal(2))
		})

		// Same setup as the previous spec, but the first request fails with
		// context.Canceled: both requests are expected on cl1 and Dial is
		// counted once. cl2 is queued but has no expectations set on it.
		It("does not remove a client when a request returns context canceled error", func() {
			cl1 := NewMockSingleRoundTripper(mockCtrl)
			clientChan <- cl1
			cl2 := NewMockSingleRoundTripper(mockCtrl)
			clientChan <- cl2

			// These shadow the Context-level req1/req2 for this spec.
			req1, err := http.NewRequest("GET", "https://quic-go.net/foobar.html", nil)
			Expect(err).ToNot(HaveOccurred())
			req2, err := http.NewRequest("GET", "https://quic-go.net/bar.html", nil)
			Expect(err).ToNot(HaveOccurred())

			conn := mockquic.NewMockEarlyConnection(mockCtrl)
			var count int
			rt.Dial = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
				count++
				return conn, nil
			}
			testErr := context.Canceled
			handshakeChan := make(chan struct{})
			close(handshakeChan)
			conn.EXPECT().HandshakeComplete().Return(handshakeChan).MaxTimes(2)
			cl1.EXPECT().RoundTrip(req1).Return(nil, testErr)
			cl1.EXPECT().RoundTrip(req2).Return(&http.Response{Request: req2}, nil)
			_, err = rt.RoundTrip(req1)
			Expect(err).To(MatchError(testErr))
			rsp, err := rt.RoundTrip(req2)
			Expect(err).ToNot(HaveOccurred())
			Expect(rsp.Request).To(Equal(req2))
			Expect(count).To(Equal(1))
		})

		// cl1 answers the first request and returns an idle timeout error for
		// the second; cl2 is expected to receive one request, and the second
		// RoundTrip call must succeed from the caller's point of view.
		It("recreates a client when a request times out", func() {
			var reqCount int
			cl1 := NewMockSingleRoundTripper(mockCtrl)
			cl1.EXPECT().RoundTrip(gomock.Any()).DoAndReturn(func(req *http.Request) (*http.Response, error) {
				reqCount++
				if reqCount == 1 { // the first request is successful...
					Expect(req.URL).To(Equal(req1.URL))
					return &http.Response{Request: req}, nil
				}
				// ... after that, the connection timed out in the background
				Expect(req.URL).To(Equal(req2.URL))
				return nil, &qerr.IdleTimeoutError{}
			}).Times(2)
			cl2 := NewMockSingleRoundTripper(mockCtrl)
			cl2.EXPECT().RoundTrip(gomock.Any()).DoAndReturn(func(req *http.Request) (*http.Response, error) {
				return &http.Response{Request: req}, nil
			})
			clientChan <- cl1
			clientChan <- cl2

			conn := mockquic.NewMockEarlyConnection(mockCtrl)
			handshakeChan := make(chan struct{})
			close(handshakeChan)
			conn.EXPECT().HandshakeComplete().Return(handshakeChan).MaxTimes(2)
			// count is incremented by Dial but not asserted on in this spec.
			var count int
			rt.Dial = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
				count++
				return conn, nil
			}
			rsp1, err := rt.RoundTrip(req1)
			Expect(err).ToNot(HaveOccurred())
			Expect(rsp1.Request.RemoteAddr).To(Equal(req1.RemoteAddr))
			rsp2, err := rt.RoundTrip(req2)
			Expect(err).ToNot(HaveOccurred())
			Expect(rsp2.Request.RemoteAddr).To(Equal(req2.RemoteAddr))
		})

		// newClient is overridden here (clientChan is not used): each client it
		// creates returns an idle timeout error. The error must surface to the
		// caller with Dial counted once.
		It("only issues a request once, even if a timeout error occurs", func() {
			var count int
			rt.Dial = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
				count++
				return mockquic.NewMockEarlyConnection(mockCtrl), nil
			}
			rt.newClient = func(quic.EarlyConnection) singleRoundTripper {
				cl := NewMockSingleRoundTripper(mockCtrl)
				cl.EXPECT().RoundTrip(gomock.Any()).Return(nil, &qerr.IdleTimeoutError{})
				return cl
			}
			_, err := rt.RoundTrip(req1)
			Expect(err).To(MatchError(&qerr.IdleTimeoutError{}))
			Expect(count).To(Equal(1))
		})

		// Two concurrent requests are both held inside the mock's RoundTrip
		// until wait is closed, then both fail with an idle timeout error. Both
		// errors must reach the callers and Dial must be counted once.
		It("handles a burst of requests", func() {
			// wait blocks the mocked RoundTrip calls and is also the channel
			// returned from HandshakeComplete.
			wait := make(chan struct{})
			// reqs receives one element each time the mocked RoundTrip is entered.
			reqs := make(chan struct{}, 2)

			cl := NewMockSingleRoundTripper(mockCtrl)
			cl.EXPECT().RoundTrip(gomock.Any()).DoAndReturn(func(req *http.Request) (*http.Response, error) {
				reqs <- struct{}{}
				<-wait
				return nil, &qerr.IdleTimeoutError{}
			}).Times(2)
			clientChan <- cl

			conn := mockquic.NewMockEarlyConnection(mockCtrl)
			conn.EXPECT().HandshakeComplete().Return(wait).AnyTimes()
			var count int
			rt.Dial = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
				count++
				return conn, nil
			}

			// done is buffered so the goroutines can signal completion without
			// waiting for the spec to receive.
			done := make(chan struct{}, 2)
			go func() {
				defer GinkgoRecover()
				defer func() { done <- struct{}{} }()
				_, err := rt.RoundTrip(req1)
				Expect(err).To(MatchError(&qerr.IdleTimeoutError{}))
			}()
			// wait for the first requests to be issued
			Eventually(reqs).Should(Receive())
			go func() {
				defer GinkgoRecover()
				defer func() { done <- struct{}{} }()
				_, err := rt.RoundTrip(req2)
				Expect(err).To(MatchError(&qerr.IdleTimeoutError{}))
			}()
			Eventually(reqs).Should(Receive())
			close(wait) // now return the requests
			Eventually(done).Should(Receive())
			Eventually(done).Should(Receive())
			Expect(count).To(Equal(1))
		})

		// No Dial function is set and no client is queued: with OnlyCachedConn
		// the call must return ErrNoCachedConn.
		It("doesn't create new clients if RoundTripOpt.OnlyCachedConn is set", func() {
			req, err := http.NewRequest("GET", "https://quic-go.net/foobar.html", nil)
			Expect(err).ToNot(HaveOccurred())
			_, err = rt.RoundTripOpt(req, RoundTripOpt{OnlyCachedConn: true})
			Expect(err).To(MatchError(ErrNoCachedConn))
		})
	})
   432  
   433  	Context("validating request", func() {
   434  		var rt RoundTripper
   435  
   436  		It("rejects plain HTTP requests", func() {
   437  			req, err := http.NewRequest("GET", "http://www.example.org/", nil)
   438  			req.Body = &mockBody{}
   439  			Expect(err).ToNot(HaveOccurred())
   440  			_, err = rt.RoundTrip(req)
   441  			Expect(err).To(MatchError("http3: unsupported protocol scheme: http"))
   442  			Expect(req.Body.(*mockBody).closed).To(BeTrue())
   443  		})
   444  
   445  		It("rejects requests without a URL", func() {
   446  			req.URL = nil
   447  			req.Body = &mockBody{}
   448  			_, err := rt.RoundTrip(req)
   449  			Expect(err).To(MatchError("http3: nil Request.URL"))
   450  			Expect(req.Body.(*mockBody).closed).To(BeTrue())
   451  		})
   452  
   453  		It("rejects request without a URL Host", func() {
   454  			req.URL.Host = ""
   455  			req.Body = &mockBody{}
   456  			_, err := rt.RoundTrip(req)
   457  			Expect(err).To(MatchError("http3: no Host in request URL"))
   458  			Expect(req.Body.(*mockBody).closed).To(BeTrue())
   459  		})
   460  
   461  		It("doesn't try to close the body if the request doesn't have one", func() {
   462  			req.URL = nil
   463  			Expect(req.Body).To(BeNil())
   464  			_, err := rt.RoundTrip(req)
   465  			Expect(err).To(MatchError("http3: nil Request.URL"))
   466  		})
   467  
   468  		It("rejects requests without a header", func() {
   469  			req.Header = nil
   470  			req.Body = &mockBody{}
   471  			_, err := rt.RoundTrip(req)
   472  			Expect(err).To(MatchError("http3: nil Request.Header"))
   473  			Expect(req.Body.(*mockBody).closed).To(BeTrue())
   474  		})
   475  
   476  		It("rejects requests with invalid header name fields", func() {
   477  			req.Header.Add("foobär", "value")
   478  			_, err := rt.RoundTrip(req)
   479  			Expect(err).To(MatchError("http3: invalid http header field name \"foobär\""))
   480  		})
   481  
   482  		It("rejects requests with invalid header name values", func() {
   483  			req.Header.Add("foo", string([]byte{0x7}))
   484  			_, err := rt.RoundTrip(req)
   485  			Expect(err.Error()).To(ContainSubstring("http3: invalid http header field value"))
   486  		})
   487  
   488  		It("rejects requests with an invalid request method", func() {
   489  			req.Method = "foobär"
   490  			req.Body = &mockBody{}
   491  			_, err := rt.RoundTrip(req)
   492  			Expect(err).To(MatchError("http3: invalid method \"foobär\""))
   493  			Expect(req.Body.(*mockBody).closed).To(BeTrue())
   494  		})
   495  	})
   496  
   497  	Context("closing", func() {
   498  		It("closes", func() {
   499  			conn := mockquic.NewMockEarlyConnection(mockCtrl)
   500  			rt := &RoundTripper{
   501  				Dial: func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
   502  					return conn, nil
   503  				},
   504  				newClient: func(quic.EarlyConnection) singleRoundTripper {
   505  					cl := NewMockSingleRoundTripper(mockCtrl)
   506  					cl.EXPECT().RoundTrip(gomock.Any()).Return(&http.Response{}, nil)
   507  					return cl
   508  				},
   509  			}
   510  			req, err := http.NewRequest("GET", "https://quic-go.net/foobar.html", nil)
   511  			Expect(err).ToNot(HaveOccurred())
   512  			_, err = rt.RoundTrip(req)
   513  			Expect(err).ToNot(HaveOccurred())
   514  			conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(0), "")
   515  			Expect(rt.Close()).To(Succeed())
   516  		})
   517  
   518  		It("closes while dialing", func() {
   519  			rt := &RoundTripper{
   520  				Dial: func(ctx context.Context, _ string, _ *tls.Config, _ *quic.Config) (quic.EarlyConnection, error) {
   521  					defer GinkgoRecover()
   522  					Eventually(ctx.Done()).Should(BeClosed())
   523  					return nil, errors.New("cancelled")
   524  				},
   525  			}
   526  			req, err := http.NewRequest("GET", "https://quic-go.net/foobar.html", nil)
   527  			Expect(err).ToNot(HaveOccurred())
   528  
   529  			errChan := make(chan error, 1)
   530  			go func() {
   531  				defer GinkgoRecover()
   532  				_, err := rt.RoundTrip(req)
   533  				errChan <- err
   534  			}()
   535  
   536  			Consistently(errChan, scaleDuration(30*time.Millisecond)).ShouldNot(Receive())
   537  			Expect(rt.Close()).To(Succeed())
   538  			var rtErr error
   539  			Eventually(errChan).Should(Receive(&rtErr))
   540  			Expect(rtErr).To(MatchError("cancelled"))
   541  		})
   542  
   543  		It("closes idle connections", func() {
   544  			conn1 := mockquic.NewMockEarlyConnection(mockCtrl)
   545  			conn2 := mockquic.NewMockEarlyConnection(mockCtrl)
   546  			rt := &RoundTripper{
   547  				Dial: func(_ context.Context, hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlyConnection, error) {
   548  					switch hostname {
   549  					case "site1.com:443":
   550  						return conn1, nil
   551  					case "site2.com:443":
   552  						return conn2, nil
   553  					default:
   554  						Fail("unexpected hostname")
   555  						return nil, errors.New("unexpected hostname")
   556  					}
   557  				},
   558  			}
   559  			req1, err := http.NewRequest("GET", "https://site1.com", nil)
   560  			Expect(err).ToNot(HaveOccurred())
   561  			req2, err := http.NewRequest("GET", "https://site2.com", nil)
   562  			Expect(err).ToNot(HaveOccurred())
   563  			Expect(req1.Host).ToNot(Equal(req2.Host))
   564  			ctx1, cancel1 := context.WithCancel(context.Background())
   565  			ctx2, cancel2 := context.WithCancel(context.Background())
   566  			req1 = req1.WithContext(ctx1)
   567  			req2 = req2.WithContext(ctx2)
   568  			roundTripCalled := make(chan struct{})
   569  			reqFinished := make(chan struct{})
   570  			rt.newClient = func(quic.EarlyConnection) singleRoundTripper {
   571  				cl := NewMockSingleRoundTripper(mockCtrl)
   572  				cl.EXPECT().RoundTrip(gomock.Any()).DoAndReturn(func(r *http.Request) (*http.Response, error) {
   573  					roundTripCalled <- struct{}{}
   574  					<-r.Context().Done()
   575  					return nil, nil
   576  				})
   577  				return cl
   578  			}
   579  			go func() {
   580  				rt.RoundTrip(req1)
   581  				reqFinished <- struct{}{}
   582  			}()
   583  			go func() {
   584  				rt.RoundTrip(req2)
   585  				reqFinished <- struct{}{}
   586  			}()
   587  			<-roundTripCalled
   588  			<-roundTripCalled
   589  			// Both two requests are started.
   590  			cancel1()
   591  			<-reqFinished
   592  			// req1 is finished
   593  			conn1.EXPECT().CloseWithError(gomock.Any(), gomock.Any())
   594  			rt.CloseIdleConnections()
   595  			cancel2()
   596  			<-reqFinished
   597  			// all requests are finished
   598  			conn2.EXPECT().CloseWithError(gomock.Any(), gomock.Any())
   599  			rt.CloseIdleConnections()
   600  		})
   601  	})
   602  })