github.com/TugasAkhir-QUIC/quic-go@v0.0.2-0.20240215011318-d20e25a9054c/http3/roundtrip_test.go (about)

     1  package http3
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"crypto/tls"
     7  	"errors"
     8  	"io"
     9  	"net/http"
    10  	"sync/atomic"
    11  	"time"
    12  
    13  	"github.com/TugasAkhir-QUIC/quic-go"
    14  	"github.com/TugasAkhir-QUIC/quic-go/internal/qerr"
    15  
    16  	. "github.com/onsi/ginkgo/v2"
    17  	. "github.com/onsi/gomega"
    18  	"go.uber.org/mock/gomock"
    19  )
    20  
    21  type mockBody struct {
    22  	reader   bytes.Reader
    23  	readErr  error
    24  	closeErr error
    25  	closed   bool
    26  }
    27  
    28  // make sure the mockBody can be used as a http.Request.Body
    29  var _ io.ReadCloser = &mockBody{}
    30  
    31  func (m *mockBody) Read(p []byte) (int, error) {
    32  	if m.readErr != nil {
    33  		return 0, m.readErr
    34  	}
    35  	return m.reader.Read(p)
    36  }
    37  
    38  func (m *mockBody) SetData(data []byte) {
    39  	m.reader = *bytes.NewReader(data)
    40  }
    41  
    42  func (m *mockBody) Close() error {
    43  	m.closed = true
    44  	return m.closeErr
    45  }
    46  
    47  var _ = Describe("RoundTripper", func() {
    48  	var (
    49  		rt  *RoundTripper
    50  		req *http.Request
    51  	)
    52  
    53  	BeforeEach(func() {
    54  		rt = &RoundTripper{}
    55  		var err error
    56  		req, err = http.NewRequest("GET", "https://www.example.org/file1.html", nil)
    57  		Expect(err).ToNot(HaveOccurred())
    58  	})
    59  
    60  	Context("dialing hosts", func() {
    61  		It("creates new clients", func() {
    62  			testErr := errors.New("test err")
    63  			req, err := http.NewRequest("GET", "https://quic.clemente.io/foobar.html", nil)
    64  			Expect(err).ToNot(HaveOccurred())
    65  			rt.newClient = func(string, *tls.Config, *roundTripperOpts, *quic.Config, dialFunc) (roundTripCloser, error) {
    66  				cl := NewMockRoundTripCloser(mockCtrl)
    67  				cl.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).Return(nil, testErr)
    68  				return cl, nil
    69  			}
    70  			_, err = rt.RoundTrip(req)
    71  			Expect(err).To(MatchError(testErr))
    72  		})
    73  
    74  		It("creates new clients with additional settings", func() {
    75  			testErr := errors.New("test err")
    76  			req, err := http.NewRequest("GET", "https://quic.clemente.io/foobar.html", nil)
    77  			Expect(err).ToNot(HaveOccurred())
    78  			rt.AdditionalSettings = map[uint64]uint64{1337: 42}
    79  			rt.newClient = func(_ string, _ *tls.Config, opts *roundTripperOpts, conf *quic.Config, _ dialFunc) (roundTripCloser, error) {
    80  				cl := NewMockRoundTripCloser(mockCtrl)
    81  				cl.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).Return(nil, testErr)
    82  				Expect(opts.AdditionalSettings).To(HaveKeyWithValue(uint64(1337), uint64(42)))
    83  				return cl, nil
    84  			}
    85  			_, err = rt.RoundTrip(req)
    86  			Expect(err).To(MatchError(testErr))
    87  		})
    88  
    89  		It("uses the quic.Config, if provided", func() {
    90  			config := &quic.Config{HandshakeIdleTimeout: time.Millisecond}
    91  			var receivedConfig *quic.Config
    92  			rt.Dial = func(_ context.Context, _ string, _ *tls.Config, config *quic.Config) (quic.EarlyConnection, error) {
    93  				receivedConfig = config
    94  				return nil, errors.New("handshake error")
    95  			}
    96  			rt.QuicConfig = config
    97  			_, err := rt.RoundTrip(req)
    98  			Expect(err).To(MatchError("handshake error"))
    99  			Expect(receivedConfig.HandshakeIdleTimeout).To(Equal(config.HandshakeIdleTimeout))
   100  		})
   101  
   102  		It("uses the custom dialer, if provided", func() {
   103  			var dialed bool
   104  			dialer := func(_ context.Context, _ string, tlsCfgP *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
   105  				dialed = true
   106  				return nil, errors.New("handshake error")
   107  			}
   108  			rt.Dial = dialer
   109  			_, err := rt.RoundTrip(req)
   110  			Expect(err).To(MatchError("handshake error"))
   111  			Expect(dialed).To(BeTrue())
   112  		})
   113  	})
   114  
   115  	Context("reusing clients", func() {
   116  		var req1, req2 *http.Request
   117  
   118  		BeforeEach(func() {
   119  			var err error
   120  			req1, err = http.NewRequest("GET", "https://quic.clemente.io/file1.html", nil)
   121  			Expect(err).ToNot(HaveOccurred())
   122  			req2, err = http.NewRequest("GET", "https://quic.clemente.io/file2.html", nil)
   123  			Expect(err).ToNot(HaveOccurred())
   124  			Expect(req1.URL).ToNot(Equal(req2.URL))
   125  		})
   126  
   127  		It("reuses existing clients", func() {
   128  			var count int
   129  			rt.newClient = func(string, *tls.Config, *roundTripperOpts, *quic.Config, dialFunc) (roundTripCloser, error) {
   130  				count++
   131  				cl := NewMockRoundTripCloser(mockCtrl)
   132  				cl.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).DoAndReturn(func(req *http.Request, _ RoundTripOpt) (*http.Response, error) {
   133  					return &http.Response{Request: req}, nil
   134  				}).Times(2)
   135  				cl.EXPECT().HandshakeComplete().Return(true)
   136  				return cl, nil
   137  			}
   138  			rsp1, err := rt.RoundTrip(req1)
   139  			Expect(err).ToNot(HaveOccurred())
   140  			Expect(rsp1.Request.URL).To(Equal(req1.URL))
   141  			rsp2, err := rt.RoundTrip(req2)
   142  			Expect(err).ToNot(HaveOccurred())
   143  			Expect(rsp2.Request.URL).To(Equal(req2.URL))
   144  			Expect(count).To(Equal(1))
   145  		})
   146  
   147  		It("immediately removes a clients when a request errored", func() {
   148  			testErr := errors.New("test err")
   149  
   150  			var count int
   151  			rt.newClient = func(string, *tls.Config, *roundTripperOpts, *quic.Config, dialFunc) (roundTripCloser, error) {
   152  				count++
   153  				cl := NewMockRoundTripCloser(mockCtrl)
   154  				cl.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).Return(nil, testErr)
   155  				return cl, nil
   156  			}
   157  			_, err := rt.RoundTrip(req1)
   158  			Expect(err).To(MatchError(testErr))
   159  			_, err = rt.RoundTrip(req2)
   160  			Expect(err).To(MatchError(testErr))
   161  			Expect(count).To(Equal(2))
   162  		})
   163  
   164  		It("recreates a client when a request times out", func() {
   165  			var reqCount int
   166  			cl1 := NewMockRoundTripCloser(mockCtrl)
   167  			cl1.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).DoAndReturn(func(req *http.Request, _ RoundTripOpt) (*http.Response, error) {
   168  				reqCount++
   169  				if reqCount == 1 { // the first request is successful...
   170  					Expect(req.URL).To(Equal(req1.URL))
   171  					return &http.Response{Request: req}, nil
   172  				}
   173  				// ... after that, the connection timed out in the background
   174  				Expect(req.URL).To(Equal(req2.URL))
   175  				return nil, &qerr.IdleTimeoutError{}
   176  			}).Times(2)
   177  			cl1.EXPECT().HandshakeComplete().Return(true)
   178  			cl2 := NewMockRoundTripCloser(mockCtrl)
   179  			cl2.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).DoAndReturn(func(req *http.Request, _ RoundTripOpt) (*http.Response, error) {
   180  				return &http.Response{Request: req}, nil
   181  			})
   182  
   183  			var count int
   184  			rt.newClient = func(string, *tls.Config, *roundTripperOpts, *quic.Config, dialFunc) (roundTripCloser, error) {
   185  				count++
   186  				if count == 1 {
   187  					return cl1, nil
   188  				}
   189  				return cl2, nil
   190  			}
   191  			rsp1, err := rt.RoundTrip(req1)
   192  			Expect(err).ToNot(HaveOccurred())
   193  			Expect(rsp1.Request.RemoteAddr).To(Equal(req1.RemoteAddr))
   194  			rsp2, err := rt.RoundTrip(req2)
   195  			Expect(err).ToNot(HaveOccurred())
   196  			Expect(rsp2.Request.RemoteAddr).To(Equal(req2.RemoteAddr))
   197  		})
   198  
   199  		It("only issues a request once, even if a timeout error occurs", func() {
   200  			var count int
   201  			rt.newClient = func(string, *tls.Config, *roundTripperOpts, *quic.Config, dialFunc) (roundTripCloser, error) {
   202  				count++
   203  				cl := NewMockRoundTripCloser(mockCtrl)
   204  				cl.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).Return(nil, &qerr.IdleTimeoutError{})
   205  				return cl, nil
   206  			}
   207  			_, err := rt.RoundTrip(req1)
   208  			Expect(err).To(MatchError(&qerr.IdleTimeoutError{}))
   209  			Expect(count).To(Equal(1))
   210  		})
   211  
   212  		It("handles a burst of requests", func() {
   213  			wait := make(chan struct{})
   214  			reqs := make(chan struct{}, 2)
   215  			var count int
   216  			rt.newClient = func(string, *tls.Config, *roundTripperOpts, *quic.Config, dialFunc) (roundTripCloser, error) {
   217  				count++
   218  				cl := NewMockRoundTripCloser(mockCtrl)
   219  				cl.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).DoAndReturn(func(req *http.Request, _ RoundTripOpt) (*http.Response, error) {
   220  					reqs <- struct{}{}
   221  					<-wait
   222  					return nil, &qerr.IdleTimeoutError{}
   223  				}).Times(2)
   224  				cl.EXPECT().HandshakeComplete()
   225  				return cl, nil
   226  			}
   227  			done := make(chan struct{}, 2)
   228  			go func() {
   229  				defer GinkgoRecover()
   230  				defer func() { done <- struct{}{} }()
   231  				_, err := rt.RoundTrip(req1)
   232  				Expect(err).To(MatchError(&qerr.IdleTimeoutError{}))
   233  			}()
   234  			go func() {
   235  				defer GinkgoRecover()
   236  				defer func() { done <- struct{}{} }()
   237  				_, err := rt.RoundTrip(req2)
   238  				Expect(err).To(MatchError(&qerr.IdleTimeoutError{}))
   239  			}()
   240  			// wait for both requests to be issued
   241  			Eventually(reqs).Should(Receive())
   242  			Eventually(reqs).Should(Receive())
   243  			close(wait) // now return the requests
   244  			Eventually(done).Should(Receive())
   245  			Eventually(done).Should(Receive())
   246  			Expect(count).To(Equal(1))
   247  		})
   248  
   249  		It("doesn't create new clients if RoundTripOpt.OnlyCachedConn is set", func() {
   250  			req, err := http.NewRequest("GET", "https://quic.clemente.io/foobar.html", nil)
   251  			Expect(err).ToNot(HaveOccurred())
   252  			_, err = rt.RoundTripOpt(req, RoundTripOpt{OnlyCachedConn: true})
   253  			Expect(err).To(MatchError(ErrNoCachedConn))
   254  		})
   255  	})
   256  
   257  	Context("validating request", func() {
   258  		It("rejects plain HTTP requests", func() {
   259  			req, err := http.NewRequest("GET", "http://www.example.org/", nil)
   260  			req.Body = &mockBody{}
   261  			Expect(err).ToNot(HaveOccurred())
   262  			_, err = rt.RoundTrip(req)
   263  			Expect(err).To(MatchError("http3: unsupported protocol scheme: http"))
   264  			Expect(req.Body.(*mockBody).closed).To(BeTrue())
   265  		})
   266  
   267  		It("rejects requests without a URL", func() {
   268  			req.URL = nil
   269  			req.Body = &mockBody{}
   270  			_, err := rt.RoundTrip(req)
   271  			Expect(err).To(MatchError("http3: nil Request.URL"))
   272  			Expect(req.Body.(*mockBody).closed).To(BeTrue())
   273  		})
   274  
   275  		It("rejects request without a URL Host", func() {
   276  			req.URL.Host = ""
   277  			req.Body = &mockBody{}
   278  			_, err := rt.RoundTrip(req)
   279  			Expect(err).To(MatchError("http3: no Host in request URL"))
   280  			Expect(req.Body.(*mockBody).closed).To(BeTrue())
   281  		})
   282  
   283  		It("doesn't try to close the body if the request doesn't have one", func() {
   284  			req.URL = nil
   285  			Expect(req.Body).To(BeNil())
   286  			_, err := rt.RoundTrip(req)
   287  			Expect(err).To(MatchError("http3: nil Request.URL"))
   288  		})
   289  
   290  		It("rejects requests without a header", func() {
   291  			req.Header = nil
   292  			req.Body = &mockBody{}
   293  			_, err := rt.RoundTrip(req)
   294  			Expect(err).To(MatchError("http3: nil Request.Header"))
   295  			Expect(req.Body.(*mockBody).closed).To(BeTrue())
   296  		})
   297  
   298  		It("rejects requests with invalid header name fields", func() {
   299  			req.Header.Add("foobär", "value")
   300  			_, err := rt.RoundTrip(req)
   301  			Expect(err).To(MatchError("http3: invalid http header field name \"foobär\""))
   302  		})
   303  
   304  		It("rejects requests with invalid header name values", func() {
   305  			req.Header.Add("foo", string([]byte{0x7}))
   306  			_, err := rt.RoundTrip(req)
   307  			Expect(err.Error()).To(ContainSubstring("http3: invalid http header field value"))
   308  		})
   309  
   310  		It("rejects requests with an invalid request method", func() {
   311  			req.Method = "foobär"
   312  			req.Body = &mockBody{}
   313  			_, err := rt.RoundTrip(req)
   314  			Expect(err).To(MatchError("http3: invalid method \"foobär\""))
   315  			Expect(req.Body.(*mockBody).closed).To(BeTrue())
   316  		})
   317  	})
   318  
   319  	Context("closing", func() {
   320  		It("closes", func() {
   321  			rt.clients = make(map[string]*roundTripCloserWithCount)
   322  			cl := NewMockRoundTripCloser(mockCtrl)
   323  			cl.EXPECT().Close()
   324  			rt.clients["foo.bar"] = &roundTripCloserWithCount{cl, atomic.Int64{}}
   325  			err := rt.Close()
   326  			Expect(err).ToNot(HaveOccurred())
   327  			Expect(len(rt.clients)).To(BeZero())
   328  		})
   329  
   330  		It("closes a RoundTripper that has never been used", func() {
   331  			Expect(len(rt.clients)).To(BeZero())
   332  			err := rt.Close()
   333  			Expect(err).ToNot(HaveOccurred())
   334  			Expect(len(rt.clients)).To(BeZero())
   335  		})
   336  
   337  		It("closes idle connections", func() {
   338  			Expect(len(rt.clients)).To(Equal(0))
   339  			req1, err := http.NewRequest("GET", "https://site1.com", nil)
   340  			Expect(err).ToNot(HaveOccurred())
   341  			req2, err := http.NewRequest("GET", "https://site2.com", nil)
   342  			Expect(err).ToNot(HaveOccurred())
   343  			Expect(req1.Host).ToNot(Equal(req2.Host))
   344  			ctx1, cancel1 := context.WithCancel(context.Background())
   345  			ctx2, cancel2 := context.WithCancel(context.Background())
   346  			req1 = req1.WithContext(ctx1)
   347  			req2 = req2.WithContext(ctx2)
   348  			roundTripCalled := make(chan struct{})
   349  			reqFinished := make(chan struct{})
   350  			rt.newClient = func(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc) (roundTripCloser, error) {
   351  				cl := NewMockRoundTripCloser(mockCtrl)
   352  				cl.EXPECT().Close()
   353  				cl.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).DoAndReturn(func(r *http.Request, _ RoundTripOpt) (*http.Response, error) {
   354  					roundTripCalled <- struct{}{}
   355  					<-r.Context().Done()
   356  					return nil, nil
   357  				})
   358  				return cl, nil
   359  			}
   360  			go func() {
   361  				rt.RoundTrip(req1)
   362  				reqFinished <- struct{}{}
   363  			}()
   364  			go func() {
   365  				rt.RoundTrip(req2)
   366  				reqFinished <- struct{}{}
   367  			}()
   368  			<-roundTripCalled
   369  			<-roundTripCalled
   370  			// Both two requests are started.
   371  			Expect(len(rt.clients)).To(Equal(2))
   372  			cancel1()
   373  			<-reqFinished
   374  			// req1 is finished
   375  			rt.CloseIdleConnections()
   376  			Expect(len(rt.clients)).To(Equal(1))
   377  			cancel2()
   378  			<-reqFinished
   379  			// all requests are finished
   380  			rt.CloseIdleConnections()
   381  			Expect(len(rt.clients)).To(Equal(0))
   382  		})
   383  	})
   384  })