github.com/daeuniverse/quic-go@v0.0.0-20240413031024-943f218e0810/http3/roundtrip_test.go (about)

     1  package http3
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"crypto/tls"
     7  	"errors"
     8  	"io"
     9  	"net/http"
    10  	"sync/atomic"
    11  	"time"
    12  
    13  	"github.com/daeuniverse/quic-go"
    14  	"github.com/daeuniverse/quic-go/internal/qerr"
    15  
    16  	. "github.com/onsi/ginkgo/v2"
    17  	. "github.com/onsi/gomega"
    18  	"go.uber.org/mock/gomock"
    19  )
    20  
    21  type mockBody struct {
    22  	reader   bytes.Reader
    23  	readErr  error
    24  	closeErr error
    25  	closed   bool
    26  }
    27  
    28  // make sure the mockBody can be used as a http.Request.Body
    29  var _ io.ReadCloser = &mockBody{}
    30  
    31  func (m *mockBody) Read(p []byte) (int, error) {
    32  	if m.readErr != nil {
    33  		return 0, m.readErr
    34  	}
    35  	return m.reader.Read(p)
    36  }
    37  
    38  func (m *mockBody) SetData(data []byte) {
    39  	m.reader = *bytes.NewReader(data)
    40  }
    41  
    42  func (m *mockBody) Close() error {
    43  	m.closed = true
    44  	return m.closeErr
    45  }
    46  
    47  var _ = Describe("RoundTripper", func() {
    48  	var (
    49  		rt  *RoundTripper
    50  		req *http.Request
    51  	)
    52  
    53  	BeforeEach(func() {
    54  		rt = &RoundTripper{}
    55  		var err error
    56  		req, err = http.NewRequest("GET", "https://www.example.org/file1.html", nil)
    57  		Expect(err).ToNot(HaveOccurred())
    58  	})
    59  
    60  	Context("dialing hosts", func() {
    61  		It("creates new clients", func() {
    62  			testErr := errors.New("test err")
    63  			req, err := http.NewRequest("GET", "https://quic.clemente.io/foobar.html", nil)
    64  			Expect(err).ToNot(HaveOccurred())
    65  			rt.newClient = func(string, *tls.Config, *roundTripperOpts, *quic.Config, dialFunc) (roundTripCloser, error) {
    66  				cl := NewMockRoundTripCloser(mockCtrl)
    67  				cl.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).Return(nil, testErr)
    68  				return cl, nil
    69  			}
    70  			_, err = rt.RoundTrip(req)
    71  			Expect(err).To(MatchError(testErr))
    72  		})
    73  
    74  		It("creates new clients with additional settings", func() {
    75  			testErr := errors.New("test err")
    76  			req, err := http.NewRequest("GET", "https://quic.clemente.io/foobar.html", nil)
    77  			Expect(err).ToNot(HaveOccurred())
    78  			rt.AdditionalSettings = map[uint64]uint64{1337: 42}
    79  			rt.newClient = func(_ string, _ *tls.Config, opts *roundTripperOpts, conf *quic.Config, _ dialFunc) (roundTripCloser, error) {
    80  				cl := NewMockRoundTripCloser(mockCtrl)
    81  				cl.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).Return(nil, testErr)
    82  				Expect(opts.AdditionalSettings).To(HaveKeyWithValue(uint64(1337), uint64(42)))
    83  				return cl, nil
    84  			}
    85  			_, err = rt.RoundTrip(req)
    86  			Expect(err).To(MatchError(testErr))
    87  		})
    88  
    89  		It("uses the quic.Config, if provided", func() {
    90  			config := &quic.Config{HandshakeIdleTimeout: time.Millisecond}
    91  			var receivedConfig *quic.Config
    92  			rt.Dial = func(_ context.Context, _ string, _ *tls.Config, config *quic.Config) (quic.EarlyConnection, error) {
    93  				receivedConfig = config
    94  				return nil, errors.New("handshake error")
    95  			}
    96  			rt.QuicConfig = config
    97  			_, err := rt.RoundTrip(req)
    98  			Expect(err).To(MatchError("handshake error"))
    99  			Expect(receivedConfig.HandshakeIdleTimeout).To(Equal(config.HandshakeIdleTimeout))
   100  		})
   101  
   102  		It("requires quic.Config.EnableDatagram if HTTP datagrams are enabled", func() {
   103  			rt.QuicConfig = &quic.Config{EnableDatagrams: false}
   104  			rt.Dial = func(_ context.Context, _ string, _ *tls.Config, config *quic.Config) (quic.EarlyConnection, error) {
   105  				return nil, errors.New("handshake error")
   106  			}
   107  			rt.EnableDatagrams = true
   108  			_, err := rt.RoundTrip(req)
   109  			Expect(err).To(MatchError("HTTP Datagrams enabled, but QUIC Datagrams disabled"))
   110  			rt.QuicConfig.EnableDatagrams = true
   111  			_, err = rt.RoundTrip(req)
   112  			Expect(err).To(MatchError("handshake error"))
   113  		})
   114  
   115  		It("uses the custom dialer, if provided", func() {
   116  			var dialed bool
   117  			dialer := func(_ context.Context, _ string, tlsCfgP *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
   118  				dialed = true
   119  				return nil, errors.New("handshake error")
   120  			}
   121  			rt.Dial = dialer
   122  			_, err := rt.RoundTrip(req)
   123  			Expect(err).To(MatchError("handshake error"))
   124  			Expect(dialed).To(BeTrue())
   125  		})
   126  	})
   127  
   128  	Context("reusing clients", func() {
   129  		var req1, req2 *http.Request
   130  
   131  		BeforeEach(func() {
   132  			var err error
   133  			req1, err = http.NewRequest("GET", "https://quic.clemente.io/file1.html", nil)
   134  			Expect(err).ToNot(HaveOccurred())
   135  			req2, err = http.NewRequest("GET", "https://quic.clemente.io/file2.html", nil)
   136  			Expect(err).ToNot(HaveOccurred())
   137  			Expect(req1.URL).ToNot(Equal(req2.URL))
   138  		})
   139  
   140  		It("reuses existing clients", func() {
   141  			var count int
   142  			rt.newClient = func(string, *tls.Config, *roundTripperOpts, *quic.Config, dialFunc) (roundTripCloser, error) {
   143  				count++
   144  				cl := NewMockRoundTripCloser(mockCtrl)
   145  				cl.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).DoAndReturn(func(req *http.Request, _ RoundTripOpt) (*http.Response, error) {
   146  					return &http.Response{Request: req}, nil
   147  				}).Times(2)
   148  				cl.EXPECT().HandshakeComplete().Return(true)
   149  				return cl, nil
   150  			}
   151  			rsp1, err := rt.RoundTrip(req1)
   152  			Expect(err).ToNot(HaveOccurred())
   153  			Expect(rsp1.Request.URL).To(Equal(req1.URL))
   154  			rsp2, err := rt.RoundTrip(req2)
   155  			Expect(err).ToNot(HaveOccurred())
   156  			Expect(rsp2.Request.URL).To(Equal(req2.URL))
   157  			Expect(count).To(Equal(1))
   158  		})
   159  
   160  		It("immediately removes a clients when a request errored", func() {
   161  			testErr := errors.New("test err")
   162  
   163  			var count int
   164  			rt.newClient = func(string, *tls.Config, *roundTripperOpts, *quic.Config, dialFunc) (roundTripCloser, error) {
   165  				count++
   166  				cl := NewMockRoundTripCloser(mockCtrl)
   167  				cl.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).Return(nil, testErr)
   168  				return cl, nil
   169  			}
   170  			_, err := rt.RoundTrip(req1)
   171  			Expect(err).To(MatchError(testErr))
   172  			_, err = rt.RoundTrip(req2)
   173  			Expect(err).To(MatchError(testErr))
   174  			Expect(count).To(Equal(2))
   175  		})
   176  
   177  		It("recreates a client when a request times out", func() {
   178  			var reqCount int
   179  			cl1 := NewMockRoundTripCloser(mockCtrl)
   180  			cl1.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).DoAndReturn(func(req *http.Request, _ RoundTripOpt) (*http.Response, error) {
   181  				reqCount++
   182  				if reqCount == 1 { // the first request is successful...
   183  					Expect(req.URL).To(Equal(req1.URL))
   184  					return &http.Response{Request: req}, nil
   185  				}
   186  				// ... after that, the connection timed out in the background
   187  				Expect(req.URL).To(Equal(req2.URL))
   188  				return nil, &qerr.IdleTimeoutError{}
   189  			}).Times(2)
   190  			cl1.EXPECT().HandshakeComplete().Return(true)
   191  			cl2 := NewMockRoundTripCloser(mockCtrl)
   192  			cl2.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).DoAndReturn(func(req *http.Request, _ RoundTripOpt) (*http.Response, error) {
   193  				return &http.Response{Request: req}, nil
   194  			})
   195  
   196  			var count int
   197  			rt.newClient = func(string, *tls.Config, *roundTripperOpts, *quic.Config, dialFunc) (roundTripCloser, error) {
   198  				count++
   199  				if count == 1 {
   200  					return cl1, nil
   201  				}
   202  				return cl2, nil
   203  			}
   204  			rsp1, err := rt.RoundTrip(req1)
   205  			Expect(err).ToNot(HaveOccurred())
   206  			Expect(rsp1.Request.RemoteAddr).To(Equal(req1.RemoteAddr))
   207  			rsp2, err := rt.RoundTrip(req2)
   208  			Expect(err).ToNot(HaveOccurred())
   209  			Expect(rsp2.Request.RemoteAddr).To(Equal(req2.RemoteAddr))
   210  		})
   211  
   212  		It("only issues a request once, even if a timeout error occurs", func() {
   213  			var count int
   214  			rt.newClient = func(string, *tls.Config, *roundTripperOpts, *quic.Config, dialFunc) (roundTripCloser, error) {
   215  				count++
   216  				cl := NewMockRoundTripCloser(mockCtrl)
   217  				cl.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).Return(nil, &qerr.IdleTimeoutError{})
   218  				return cl, nil
   219  			}
   220  			_, err := rt.RoundTrip(req1)
   221  			Expect(err).To(MatchError(&qerr.IdleTimeoutError{}))
   222  			Expect(count).To(Equal(1))
   223  		})
   224  
   225  		It("handles a burst of requests", func() {
   226  			wait := make(chan struct{})
   227  			reqs := make(chan struct{}, 2)
   228  			var count int
   229  			rt.newClient = func(string, *tls.Config, *roundTripperOpts, *quic.Config, dialFunc) (roundTripCloser, error) {
   230  				count++
   231  				cl := NewMockRoundTripCloser(mockCtrl)
   232  				cl.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).DoAndReturn(func(req *http.Request, _ RoundTripOpt) (*http.Response, error) {
   233  					reqs <- struct{}{}
   234  					<-wait
   235  					return nil, &qerr.IdleTimeoutError{}
   236  				}).Times(2)
   237  				cl.EXPECT().HandshakeComplete()
   238  				return cl, nil
   239  			}
   240  			done := make(chan struct{}, 2)
   241  			go func() {
   242  				defer GinkgoRecover()
   243  				defer func() { done <- struct{}{} }()
   244  				_, err := rt.RoundTrip(req1)
   245  				Expect(err).To(MatchError(&qerr.IdleTimeoutError{}))
   246  			}()
   247  			go func() {
   248  				defer GinkgoRecover()
   249  				defer func() { done <- struct{}{} }()
   250  				_, err := rt.RoundTrip(req2)
   251  				Expect(err).To(MatchError(&qerr.IdleTimeoutError{}))
   252  			}()
   253  			// wait for both requests to be issued
   254  			Eventually(reqs).Should(Receive())
   255  			Eventually(reqs).Should(Receive())
   256  			close(wait) // now return the requests
   257  			Eventually(done).Should(Receive())
   258  			Eventually(done).Should(Receive())
   259  			Expect(count).To(Equal(1))
   260  		})
   261  
   262  		It("doesn't create new clients if RoundTripOpt.OnlyCachedConn is set", func() {
   263  			req, err := http.NewRequest("GET", "https://quic.clemente.io/foobar.html", nil)
   264  			Expect(err).ToNot(HaveOccurred())
   265  			_, err = rt.RoundTripOpt(req, RoundTripOpt{OnlyCachedConn: true})
   266  			Expect(err).To(MatchError(ErrNoCachedConn))
   267  		})
   268  	})
   269  
   270  	Context("validating request", func() {
   271  		It("rejects plain HTTP requests", func() {
   272  			req, err := http.NewRequest("GET", "http://www.example.org/", nil)
   273  			req.Body = &mockBody{}
   274  			Expect(err).ToNot(HaveOccurred())
   275  			_, err = rt.RoundTrip(req)
   276  			Expect(err).To(MatchError("http3: unsupported protocol scheme: http"))
   277  			Expect(req.Body.(*mockBody).closed).To(BeTrue())
   278  		})
   279  
   280  		It("rejects requests without a URL", func() {
   281  			req.URL = nil
   282  			req.Body = &mockBody{}
   283  			_, err := rt.RoundTrip(req)
   284  			Expect(err).To(MatchError("http3: nil Request.URL"))
   285  			Expect(req.Body.(*mockBody).closed).To(BeTrue())
   286  		})
   287  
   288  		It("rejects request without a URL Host", func() {
   289  			req.URL.Host = ""
   290  			req.Body = &mockBody{}
   291  			_, err := rt.RoundTrip(req)
   292  			Expect(err).To(MatchError("http3: no Host in request URL"))
   293  			Expect(req.Body.(*mockBody).closed).To(BeTrue())
   294  		})
   295  
   296  		It("doesn't try to close the body if the request doesn't have one", func() {
   297  			req.URL = nil
   298  			Expect(req.Body).To(BeNil())
   299  			_, err := rt.RoundTrip(req)
   300  			Expect(err).To(MatchError("http3: nil Request.URL"))
   301  		})
   302  
   303  		It("rejects requests without a header", func() {
   304  			req.Header = nil
   305  			req.Body = &mockBody{}
   306  			_, err := rt.RoundTrip(req)
   307  			Expect(err).To(MatchError("http3: nil Request.Header"))
   308  			Expect(req.Body.(*mockBody).closed).To(BeTrue())
   309  		})
   310  
   311  		It("rejects requests with invalid header name fields", func() {
   312  			req.Header.Add("foobär", "value")
   313  			_, err := rt.RoundTrip(req)
   314  			Expect(err).To(MatchError("http3: invalid http header field name \"foobär\""))
   315  		})
   316  
   317  		It("rejects requests with invalid header name values", func() {
   318  			req.Header.Add("foo", string([]byte{0x7}))
   319  			_, err := rt.RoundTrip(req)
   320  			Expect(err.Error()).To(ContainSubstring("http3: invalid http header field value"))
   321  		})
   322  
   323  		It("rejects requests with an invalid request method", func() {
   324  			req.Method = "foobär"
   325  			req.Body = &mockBody{}
   326  			_, err := rt.RoundTrip(req)
   327  			Expect(err).To(MatchError("http3: invalid method \"foobär\""))
   328  			Expect(req.Body.(*mockBody).closed).To(BeTrue())
   329  		})
   330  	})
   331  
   332  	Context("closing", func() {
   333  		It("closes", func() {
   334  			rt.clients = make(map[string]*roundTripCloserWithCount)
   335  			cl := NewMockRoundTripCloser(mockCtrl)
   336  			cl.EXPECT().Close()
   337  			rt.clients["foo.bar"] = &roundTripCloserWithCount{cl, atomic.Int64{}}
   338  			err := rt.Close()
   339  			Expect(err).ToNot(HaveOccurred())
   340  			Expect(len(rt.clients)).To(BeZero())
   341  		})
   342  
   343  		It("closes a RoundTripper that has never been used", func() {
   344  			Expect(len(rt.clients)).To(BeZero())
   345  			err := rt.Close()
   346  			Expect(err).ToNot(HaveOccurred())
   347  			Expect(len(rt.clients)).To(BeZero())
   348  		})
   349  
   350  		It("closes idle connections", func() {
   351  			Expect(len(rt.clients)).To(Equal(0))
   352  			req1, err := http.NewRequest("GET", "https://site1.com", nil)
   353  			Expect(err).ToNot(HaveOccurred())
   354  			req2, err := http.NewRequest("GET", "https://site2.com", nil)
   355  			Expect(err).ToNot(HaveOccurred())
   356  			Expect(req1.Host).ToNot(Equal(req2.Host))
   357  			ctx1, cancel1 := context.WithCancel(context.Background())
   358  			ctx2, cancel2 := context.WithCancel(context.Background())
   359  			req1 = req1.WithContext(ctx1)
   360  			req2 = req2.WithContext(ctx2)
   361  			roundTripCalled := make(chan struct{})
   362  			reqFinished := make(chan struct{})
   363  			rt.newClient = func(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc) (roundTripCloser, error) {
   364  				cl := NewMockRoundTripCloser(mockCtrl)
   365  				cl.EXPECT().Close()
   366  				cl.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).DoAndReturn(func(r *http.Request, _ RoundTripOpt) (*http.Response, error) {
   367  					roundTripCalled <- struct{}{}
   368  					<-r.Context().Done()
   369  					return nil, nil
   370  				})
   371  				return cl, nil
   372  			}
   373  			go func() {
   374  				rt.RoundTrip(req1)
   375  				reqFinished <- struct{}{}
   376  			}()
   377  			go func() {
   378  				rt.RoundTrip(req2)
   379  				reqFinished <- struct{}{}
   380  			}()
   381  			<-roundTripCalled
   382  			<-roundTripCalled
   383  			// Both two requests are started.
   384  			Expect(len(rt.clients)).To(Equal(2))
   385  			cancel1()
   386  			<-reqFinished
   387  			// req1 is finished
   388  			rt.CloseIdleConnections()
   389  			Expect(len(rt.clients)).To(Equal(1))
   390  			cancel2()
   391  			<-reqFinished
   392  			// all requests are finished
   393  			rt.CloseIdleConnections()
   394  			Expect(len(rt.clients)).To(Equal(0))
   395  		})
   396  	})
   397  })