github.com/MerlinKodo/quic-go@v0.39.2/http3/roundtrip_test.go (about)

     1  package http3
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"crypto/tls"
     7  	"errors"
     8  	"io"
     9  	"net/http"
    10  	"sync/atomic"
    11  	"time"
    12  
    13  	"github.com/MerlinKodo/quic-go"
    14  	"github.com/MerlinKodo/quic-go/internal/qerr"
    15  
    16  	. "github.com/onsi/ginkgo/v2"
    17  	. "github.com/onsi/gomega"
    18  	"go.uber.org/mock/gomock"
    19  )
    20  
    21  type mockBody struct {
    22  	reader   bytes.Reader
    23  	readErr  error
    24  	closeErr error
    25  	closed   bool
    26  }
    27  
    28  // make sure the mockBody can be used as a http.Request.Body
    29  var _ io.ReadCloser = &mockBody{}
    30  
    31  func (m *mockBody) Read(p []byte) (int, error) {
    32  	if m.readErr != nil {
    33  		return 0, m.readErr
    34  	}
    35  	return m.reader.Read(p)
    36  }
    37  
    38  func (m *mockBody) SetData(data []byte) {
    39  	m.reader = *bytes.NewReader(data)
    40  }
    41  
    42  func (m *mockBody) Close() error {
    43  	m.closed = true
    44  	return m.closeErr
    45  }
    46  
    47  var _ = Describe("RoundTripper", func() {
    48  	var (
    49  		rt  *RoundTripper
    50  		req *http.Request
    51  	)
    52  
    53  	BeforeEach(func() {
    54  		rt = &RoundTripper{}
    55  		var err error
    56  		req, err = http.NewRequest("GET", "https://www.example.org/file1.html", nil)
    57  		Expect(err).ToNot(HaveOccurred())
    58  	})
    59  
    60  	Context("dialing hosts", func() {
    61  		It("creates new clients", func() {
    62  			testErr := errors.New("test err")
    63  			req, err := http.NewRequest("GET", "https://quic.clemente.io/foobar.html", nil)
    64  			Expect(err).ToNot(HaveOccurred())
    65  			rt.newClient = func(string, *tls.Config, *roundTripperOpts, *quic.Config, dialFunc) (roundTripCloser, error) {
    66  				cl := NewMockRoundTripCloser(mockCtrl)
    67  				cl.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).Return(nil, testErr)
    68  				return cl, nil
    69  			}
    70  			_, err = rt.RoundTrip(req)
    71  			Expect(err).To(MatchError(testErr))
    72  		})
    73  
    74  		It("uses the quic.Config, if provided", func() {
    75  			config := &quic.Config{HandshakeIdleTimeout: time.Millisecond}
    76  			var receivedConfig *quic.Config
    77  			rt.Dial = func(_ context.Context, _ string, _ *tls.Config, config *quic.Config) (quic.EarlyConnection, error) {
    78  				receivedConfig = config
    79  				return nil, errors.New("handshake error")
    80  			}
    81  			rt.QuicConfig = config
    82  			_, err := rt.RoundTrip(req)
    83  			Expect(err).To(MatchError("handshake error"))
    84  			Expect(receivedConfig.HandshakeIdleTimeout).To(Equal(config.HandshakeIdleTimeout))
    85  		})
    86  
    87  		It("uses the custom dialer, if provided", func() {
    88  			var dialed bool
    89  			dialer := func(_ context.Context, _ string, tlsCfgP *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
    90  				dialed = true
    91  				return nil, errors.New("handshake error")
    92  			}
    93  			rt.Dial = dialer
    94  			_, err := rt.RoundTrip(req)
    95  			Expect(err).To(MatchError("handshake error"))
    96  			Expect(dialed).To(BeTrue())
    97  		})
    98  	})
    99  
   100  	Context("reusing clients", func() {
   101  		var req1, req2 *http.Request
   102  
   103  		BeforeEach(func() {
   104  			var err error
   105  			req1, err = http.NewRequest("GET", "https://quic.clemente.io/file1.html", nil)
   106  			Expect(err).ToNot(HaveOccurred())
   107  			req2, err = http.NewRequest("GET", "https://quic.clemente.io/file2.html", nil)
   108  			Expect(err).ToNot(HaveOccurred())
   109  			Expect(req1.URL).ToNot(Equal(req2.URL))
   110  		})
   111  
   112  		It("reuses existing clients", func() {
   113  			var count int
   114  			rt.newClient = func(string, *tls.Config, *roundTripperOpts, *quic.Config, dialFunc) (roundTripCloser, error) {
   115  				count++
   116  				cl := NewMockRoundTripCloser(mockCtrl)
   117  				cl.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).DoAndReturn(func(req *http.Request, _ RoundTripOpt) (*http.Response, error) {
   118  					return &http.Response{Request: req}, nil
   119  				}).Times(2)
   120  				cl.EXPECT().HandshakeComplete().Return(true)
   121  				return cl, nil
   122  			}
   123  			rsp1, err := rt.RoundTrip(req1)
   124  			Expect(err).ToNot(HaveOccurred())
   125  			Expect(rsp1.Request.URL).To(Equal(req1.URL))
   126  			rsp2, err := rt.RoundTrip(req2)
   127  			Expect(err).ToNot(HaveOccurred())
   128  			Expect(rsp2.Request.URL).To(Equal(req2.URL))
   129  			Expect(count).To(Equal(1))
   130  		})
   131  
   132  		It("immediately removes a clients when a request errored", func() {
   133  			testErr := errors.New("test err")
   134  
   135  			var count int
   136  			rt.newClient = func(string, *tls.Config, *roundTripperOpts, *quic.Config, dialFunc) (roundTripCloser, error) {
   137  				count++
   138  				cl := NewMockRoundTripCloser(mockCtrl)
   139  				cl.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).Return(nil, testErr)
   140  				return cl, nil
   141  			}
   142  			_, err := rt.RoundTrip(req1)
   143  			Expect(err).To(MatchError(testErr))
   144  			_, err = rt.RoundTrip(req2)
   145  			Expect(err).To(MatchError(testErr))
   146  			Expect(count).To(Equal(2))
   147  		})
   148  
   149  		It("recreates a client when a request times out", func() {
   150  			var reqCount int
   151  			cl1 := NewMockRoundTripCloser(mockCtrl)
   152  			cl1.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).DoAndReturn(func(req *http.Request, _ RoundTripOpt) (*http.Response, error) {
   153  				reqCount++
   154  				if reqCount == 1 { // the first request is successful...
   155  					Expect(req.URL).To(Equal(req1.URL))
   156  					return &http.Response{Request: req}, nil
   157  				}
   158  				// ... after that, the connection timed out in the background
   159  				Expect(req.URL).To(Equal(req2.URL))
   160  				return nil, &qerr.IdleTimeoutError{}
   161  			}).Times(2)
   162  			cl1.EXPECT().HandshakeComplete().Return(true)
   163  			cl2 := NewMockRoundTripCloser(mockCtrl)
   164  			cl2.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).DoAndReturn(func(req *http.Request, _ RoundTripOpt) (*http.Response, error) {
   165  				return &http.Response{Request: req}, nil
   166  			})
   167  
   168  			var count int
   169  			rt.newClient = func(string, *tls.Config, *roundTripperOpts, *quic.Config, dialFunc) (roundTripCloser, error) {
   170  				count++
   171  				if count == 1 {
   172  					return cl1, nil
   173  				}
   174  				return cl2, nil
   175  			}
   176  			rsp1, err := rt.RoundTrip(req1)
   177  			Expect(err).ToNot(HaveOccurred())
   178  			Expect(rsp1.Request.RemoteAddr).To(Equal(req1.RemoteAddr))
   179  			rsp2, err := rt.RoundTrip(req2)
   180  			Expect(err).ToNot(HaveOccurred())
   181  			Expect(rsp2.Request.RemoteAddr).To(Equal(req2.RemoteAddr))
   182  		})
   183  
   184  		It("only issues a request once, even if a timeout error occurs", func() {
   185  			var count int
   186  			rt.newClient = func(string, *tls.Config, *roundTripperOpts, *quic.Config, dialFunc) (roundTripCloser, error) {
   187  				count++
   188  				cl := NewMockRoundTripCloser(mockCtrl)
   189  				cl.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).Return(nil, &qerr.IdleTimeoutError{})
   190  				return cl, nil
   191  			}
   192  			_, err := rt.RoundTrip(req1)
   193  			Expect(err).To(MatchError(&qerr.IdleTimeoutError{}))
   194  			Expect(count).To(Equal(1))
   195  		})
   196  
   197  		It("handles a burst of requests", func() {
   198  			wait := make(chan struct{})
   199  			reqs := make(chan struct{}, 2)
   200  			var count int
   201  			rt.newClient = func(string, *tls.Config, *roundTripperOpts, *quic.Config, dialFunc) (roundTripCloser, error) {
   202  				count++
   203  				cl := NewMockRoundTripCloser(mockCtrl)
   204  				cl.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).DoAndReturn(func(req *http.Request, _ RoundTripOpt) (*http.Response, error) {
   205  					reqs <- struct{}{}
   206  					<-wait
   207  					return nil, &qerr.IdleTimeoutError{}
   208  				}).Times(2)
   209  				cl.EXPECT().HandshakeComplete()
   210  				return cl, nil
   211  			}
   212  			done := make(chan struct{}, 2)
   213  			go func() {
   214  				defer GinkgoRecover()
   215  				defer func() { done <- struct{}{} }()
   216  				_, err := rt.RoundTrip(req1)
   217  				Expect(err).To(MatchError(&qerr.IdleTimeoutError{}))
   218  			}()
   219  			go func() {
   220  				defer GinkgoRecover()
   221  				defer func() { done <- struct{}{} }()
   222  				_, err := rt.RoundTrip(req2)
   223  				Expect(err).To(MatchError(&qerr.IdleTimeoutError{}))
   224  			}()
   225  			// wait for both requests to be issued
   226  			Eventually(reqs).Should(Receive())
   227  			Eventually(reqs).Should(Receive())
   228  			close(wait) // now return the requests
   229  			Eventually(done).Should(Receive())
   230  			Eventually(done).Should(Receive())
   231  			Expect(count).To(Equal(1))
   232  		})
   233  
   234  		It("doesn't create new clients if RoundTripOpt.OnlyCachedConn is set", func() {
   235  			req, err := http.NewRequest("GET", "https://quic.clemente.io/foobar.html", nil)
   236  			Expect(err).ToNot(HaveOccurred())
   237  			_, err = rt.RoundTripOpt(req, RoundTripOpt{OnlyCachedConn: true})
   238  			Expect(err).To(MatchError(ErrNoCachedConn))
   239  		})
   240  	})
   241  
   242  	Context("validating request", func() {
   243  		It("rejects plain HTTP requests", func() {
   244  			req, err := http.NewRequest("GET", "http://www.example.org/", nil)
   245  			req.Body = &mockBody{}
   246  			Expect(err).ToNot(HaveOccurred())
   247  			_, err = rt.RoundTrip(req)
   248  			Expect(err).To(MatchError("http3: unsupported protocol scheme: http"))
   249  			Expect(req.Body.(*mockBody).closed).To(BeTrue())
   250  		})
   251  
   252  		It("rejects requests without a URL", func() {
   253  			req.URL = nil
   254  			req.Body = &mockBody{}
   255  			_, err := rt.RoundTrip(req)
   256  			Expect(err).To(MatchError("http3: nil Request.URL"))
   257  			Expect(req.Body.(*mockBody).closed).To(BeTrue())
   258  		})
   259  
   260  		It("rejects request without a URL Host", func() {
   261  			req.URL.Host = ""
   262  			req.Body = &mockBody{}
   263  			_, err := rt.RoundTrip(req)
   264  			Expect(err).To(MatchError("http3: no Host in request URL"))
   265  			Expect(req.Body.(*mockBody).closed).To(BeTrue())
   266  		})
   267  
   268  		It("doesn't try to close the body if the request doesn't have one", func() {
   269  			req.URL = nil
   270  			Expect(req.Body).To(BeNil())
   271  			_, err := rt.RoundTrip(req)
   272  			Expect(err).To(MatchError("http3: nil Request.URL"))
   273  		})
   274  
   275  		It("rejects requests without a header", func() {
   276  			req.Header = nil
   277  			req.Body = &mockBody{}
   278  			_, err := rt.RoundTrip(req)
   279  			Expect(err).To(MatchError("http3: nil Request.Header"))
   280  			Expect(req.Body.(*mockBody).closed).To(BeTrue())
   281  		})
   282  
   283  		It("rejects requests with invalid header name fields", func() {
   284  			req.Header.Add("foobär", "value")
   285  			_, err := rt.RoundTrip(req)
   286  			Expect(err).To(MatchError("http3: invalid http header field name \"foobär\""))
   287  		})
   288  
   289  		It("rejects requests with invalid header name values", func() {
   290  			req.Header.Add("foo", string([]byte{0x7}))
   291  			_, err := rt.RoundTrip(req)
   292  			Expect(err.Error()).To(ContainSubstring("http3: invalid http header field value"))
   293  		})
   294  
   295  		It("rejects requests with an invalid request method", func() {
   296  			req.Method = "foobär"
   297  			req.Body = &mockBody{}
   298  			_, err := rt.RoundTrip(req)
   299  			Expect(err).To(MatchError("http3: invalid method \"foobär\""))
   300  			Expect(req.Body.(*mockBody).closed).To(BeTrue())
   301  		})
   302  	})
   303  
   304  	Context("closing", func() {
   305  		It("closes", func() {
   306  			rt.clients = make(map[string]*roundTripCloserWithCount)
   307  			cl := NewMockRoundTripCloser(mockCtrl)
   308  			cl.EXPECT().Close()
   309  			rt.clients["foo.bar"] = &roundTripCloserWithCount{cl, atomic.Int64{}}
   310  			err := rt.Close()
   311  			Expect(err).ToNot(HaveOccurred())
   312  			Expect(len(rt.clients)).To(BeZero())
   313  		})
   314  
   315  		It("closes a RoundTripper that has never been used", func() {
   316  			Expect(len(rt.clients)).To(BeZero())
   317  			err := rt.Close()
   318  			Expect(err).ToNot(HaveOccurred())
   319  			Expect(len(rt.clients)).To(BeZero())
   320  		})
   321  
   322  		It("closes idle connections", func() {
   323  			Expect(len(rt.clients)).To(Equal(0))
   324  			req1, err := http.NewRequest("GET", "https://site1.com", nil)
   325  			Expect(err).ToNot(HaveOccurred())
   326  			req2, err := http.NewRequest("GET", "https://site2.com", nil)
   327  			Expect(err).ToNot(HaveOccurred())
   328  			Expect(req1.Host).ToNot(Equal(req2.Host))
   329  			ctx1, cancel1 := context.WithCancel(context.Background())
   330  			ctx2, cancel2 := context.WithCancel(context.Background())
   331  			req1 = req1.WithContext(ctx1)
   332  			req2 = req2.WithContext(ctx2)
   333  			roundTripCalled := make(chan struct{})
   334  			reqFinished := make(chan struct{})
   335  			rt.newClient = func(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc) (roundTripCloser, error) {
   336  				cl := NewMockRoundTripCloser(mockCtrl)
   337  				cl.EXPECT().Close()
   338  				cl.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).DoAndReturn(func(r *http.Request, _ RoundTripOpt) (*http.Response, error) {
   339  					roundTripCalled <- struct{}{}
   340  					<-r.Context().Done()
   341  					return nil, nil
   342  				})
   343  				return cl, nil
   344  			}
   345  			go func() {
   346  				rt.RoundTrip(req1)
   347  				reqFinished <- struct{}{}
   348  			}()
   349  			go func() {
   350  				rt.RoundTrip(req2)
   351  				reqFinished <- struct{}{}
   352  			}()
   353  			<-roundTripCalled
   354  			<-roundTripCalled
   355  			// Both two requests are started.
   356  			Expect(len(rt.clients)).To(Equal(2))
   357  			cancel1()
   358  			<-reqFinished
   359  			// req1 is finished
   360  			rt.CloseIdleConnections()
   361  			Expect(len(rt.clients)).To(Equal(1))
   362  			cancel2()
   363  			<-reqFinished
   364  			// all requests are finished
   365  			rt.CloseIdleConnections()
   366  			Expect(len(rt.clients)).To(Equal(0))
   367  		})
   368  	})
   369  })