github.com/metacubex/quic-go@v0.44.1-0.20240520163451-20b689a59136/integrationtests/self/http_test.go (about)

     1  package self_test
     2  
     3  import (
     4  	"bufio"
     5  	"bytes"
     6  	"compress/gzip"
     7  	"context"
     8  	"crypto/tls"
     9  	"encoding/binary"
    10  	"errors"
    11  	"fmt"
    12  	"io"
    13  	"net"
    14  	"net/http"
    15  	"net/http/httptrace"
    16  	"net/textproto"
    17  	"net/url"
    18  	"os"
    19  	"strconv"
    20  	"sync/atomic"
    21  	"time"
    22  
    23  	"golang.org/x/sync/errgroup"
    24  
    25  	"github.com/metacubex/quic-go"
    26  	"github.com/metacubex/quic-go/http3"
    27  	quicproxy "github.com/metacubex/quic-go/integrationtests/tools/proxy"
    28  
    29  	. "github.com/onsi/ginkgo/v2"
    30  	. "github.com/onsi/gomega"
    31  	"github.com/onsi/gomega/gbytes"
    32  )
    33  
    34  type neverEnding byte
    35  
    36  func (b neverEnding) Read(p []byte) (n int, err error) {
    37  	for i := range p {
    38  		p[i] = byte(b)
    39  	}
    40  	return len(p), nil
    41  }
    42  
    43  const deadlineDelay = 250 * time.Millisecond
    44  
    45  var _ = Describe("HTTP tests", func() {
    46  	var (
    47  		mux            *http.ServeMux
    48  		client         *http.Client
    49  		rt             *http3.RoundTripper
    50  		server         *http3.Server
    51  		stoppedServing chan struct{}
    52  		port           int
    53  	)
    54  
    55  	BeforeEach(func() {
    56  		mux = http.NewServeMux()
    57  		mux.HandleFunc("/hello", func(w http.ResponseWriter, r *http.Request) {
    58  			defer GinkgoRecover()
    59  			io.WriteString(w, "Hello, World!\n") // don't check the error here. Stream may be reset.
    60  		})
    61  
    62  		mux.HandleFunc("/prdata", func(w http.ResponseWriter, r *http.Request) {
    63  			defer GinkgoRecover()
    64  			sl := r.URL.Query().Get("len")
    65  			if sl != "" {
    66  				var err error
    67  				l, err := strconv.Atoi(sl)
    68  				Expect(err).NotTo(HaveOccurred())
    69  				w.Write(GeneratePRData(l)) // don't check the error here. Stream may be reset.
    70  			} else {
    71  				w.Write(PRData) // don't check the error here. Stream may be reset.
    72  			}
    73  		})
    74  
    75  		mux.HandleFunc("/prdatalong", func(w http.ResponseWriter, r *http.Request) {
    76  			defer GinkgoRecover()
    77  			w.Write(PRDataLong) // don't check the error here. Stream may be reset.
    78  		})
    79  
    80  		mux.HandleFunc("/echo", func(w http.ResponseWriter, r *http.Request) {
    81  			defer GinkgoRecover()
    82  			body, err := io.ReadAll(r.Body)
    83  			Expect(err).NotTo(HaveOccurred())
    84  			w.Write(body) // don't check the error here. Stream may be reset.
    85  		})
    86  
    87  		mux.HandleFunc("/remoteAddr", func(w http.ResponseWriter, r *http.Request) {
    88  			defer GinkgoRecover()
    89  			w.Header().Set("X-RemoteAddr", r.RemoteAddr)
    90  			w.WriteHeader(http.StatusOK)
    91  		})
    92  
    93  		server = &http3.Server{
    94  			Handler:    mux,
    95  			TLSConfig:  getTLSConfig(),
    96  			QUICConfig: getQuicConfig(&quic.Config{Allow0RTT: true, EnableDatagrams: true}),
    97  		}
    98  
    99  		addr, err := net.ResolveUDPAddr("udp", "0.0.0.0:0")
   100  		Expect(err).NotTo(HaveOccurred())
   101  		conn, err := net.ListenUDP("udp", addr)
   102  		Expect(err).NotTo(HaveOccurred())
   103  		port = conn.LocalAddr().(*net.UDPAddr).Port
   104  
   105  		stoppedServing = make(chan struct{})
   106  
   107  		go func() {
   108  			defer GinkgoRecover()
   109  			server.Serve(conn)
   110  			close(stoppedServing)
   111  		}()
   112  	})
   113  
   114  	AfterEach(func() {
   115  		Expect(rt.Close()).NotTo(HaveOccurred())
   116  		Expect(server.Close()).NotTo(HaveOccurred())
   117  		Eventually(stoppedServing).Should(BeClosed())
   118  	})
   119  
   120  	BeforeEach(func() {
   121  		rt = &http3.RoundTripper{
   122  			TLSClientConfig: getTLSClientConfigWithoutServerName(),
   123  			QUICConfig: getQuicConfig(&quic.Config{
   124  				MaxIdleTimeout: 10 * time.Second,
   125  			}),
   126  			DisableCompression: true,
   127  		}
   128  		client = &http.Client{Transport: rt}
   129  	})
   130  
   131  	It("downloads a hello", func() {
   132  		resp, err := client.Get(fmt.Sprintf("https://localhost:%d/hello", port))
   133  		Expect(err).ToNot(HaveOccurred())
   134  		Expect(resp.StatusCode).To(Equal(200))
   135  		body, err := io.ReadAll(gbytes.TimeoutReader(resp.Body, 3*time.Second))
   136  		Expect(err).ToNot(HaveOccurred())
   137  		Expect(string(body)).To(Equal("Hello, World!\n"))
   138  	})
   139  
   140  	It("sets content-length for small response", func() {
   141  		mux.HandleFunc("/small", func(w http.ResponseWriter, r *http.Request) {
   142  			defer GinkgoRecover()
   143  			w.Write([]byte("foo"))
   144  			w.Write([]byte("bar"))
   145  		})
   146  
   147  		resp, err := client.Get(fmt.Sprintf("https://localhost:%d/small", port))
   148  		Expect(err).ToNot(HaveOccurred())
   149  		Expect(resp.StatusCode).To(Equal(200))
   150  		Expect(resp.Header.Get("Content-Length")).To(Equal("6"))
   151  	})
   152  
   153  	It("detects stream errors when server panics when writing response", func() {
   154  		respChan := make(chan struct{})
   155  		mux.HandleFunc("/writing_and_panicking", func(w http.ResponseWriter, r *http.Request) {
   156  			// no recover here as it will interfere with the handler
   157  			w.Write([]byte("foobar"))
   158  			w.(http.Flusher).Flush()
   159  			// wait for the client to receive the response
   160  			<-respChan
   161  			panic(http.ErrAbortHandler)
   162  		})
   163  
   164  		resp, err := client.Get(fmt.Sprintf("https://localhost:%d/writing_and_panicking", port))
   165  		close(respChan)
   166  		Expect(err).ToNot(HaveOccurred())
   167  		body, err := io.ReadAll(resp.Body)
   168  		Expect(err).To(HaveOccurred())
   169  		// the body will be a prefix of what's written
   170  		Expect(bytes.HasPrefix([]byte("foobar"), body)).To(BeTrue())
   171  	})
   172  
   173  	It("requests to different servers with the same udpconn", func() {
   174  		resp, err := client.Get(fmt.Sprintf("https://localhost:%d/remoteAddr", port))
   175  		Expect(err).ToNot(HaveOccurred())
   176  		Expect(resp.StatusCode).To(Equal(200))
   177  		addr1 := resp.Header.Get("X-RemoteAddr")
   178  		Expect(addr1).ToNot(Equal(""))
   179  		resp, err = client.Get(fmt.Sprintf("https://127.0.0.1:%d/remoteAddr", port))
   180  		Expect(err).ToNot(HaveOccurred())
   181  		Expect(resp.StatusCode).To(Equal(200))
   182  		addr2 := resp.Header.Get("X-RemoteAddr")
   183  		Expect(addr2).ToNot(Equal(""))
   184  		Expect(addr1).To(Equal(addr2))
   185  	})
   186  
   187  	It("downloads concurrently", func() {
   188  		group, ctx := errgroup.WithContext(context.Background())
   189  		for i := 0; i < 2; i++ {
   190  			group.Go(func() error {
   191  				defer GinkgoRecover()
   192  				req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("https://localhost:%d/hello", port), nil)
   193  				Expect(err).ToNot(HaveOccurred())
   194  				resp, err := client.Do(req)
   195  				Expect(err).ToNot(HaveOccurred())
   196  				Expect(resp.StatusCode).To(Equal(200))
   197  				body, err := io.ReadAll(gbytes.TimeoutReader(resp.Body, 3*time.Second))
   198  				Expect(err).ToNot(HaveOccurred())
   199  				Expect(string(body)).To(Equal("Hello, World!\n"))
   200  
   201  				return nil
   202  			})
   203  		}
   204  
   205  		err := group.Wait()
   206  		Expect(err).ToNot(HaveOccurred())
   207  	})
   208  
   209  	It("sets and gets request headers", func() {
   210  		handlerCalled := make(chan struct{})
   211  		mux.HandleFunc("/headers/request", func(w http.ResponseWriter, r *http.Request) {
   212  			defer GinkgoRecover()
   213  			Expect(r.Header.Get("foo")).To(Equal("bar"))
   214  			Expect(r.Header.Get("lorem")).To(Equal("ipsum"))
   215  			close(handlerCalled)
   216  		})
   217  
   218  		req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("https://localhost:%d/headers/request", port), nil)
   219  		Expect(err).ToNot(HaveOccurred())
   220  		req.Header.Set("foo", "bar")
   221  		req.Header.Set("lorem", "ipsum")
   222  		resp, err := client.Do(req)
   223  		Expect(err).ToNot(HaveOccurred())
   224  		Expect(resp.StatusCode).To(Equal(200))
   225  		Eventually(handlerCalled).Should(BeClosed())
   226  	})
   227  
   228  	It("sets and gets response headers", func() {
   229  		mux.HandleFunc("/headers/response", func(w http.ResponseWriter, r *http.Request) {
   230  			defer GinkgoRecover()
   231  			w.Header().Set("foo", "bar")
   232  			w.Header().Set("lorem", "ipsum")
   233  		})
   234  
   235  		resp, err := client.Get(fmt.Sprintf("https://localhost:%d/headers/response", port))
   236  		Expect(err).ToNot(HaveOccurred())
   237  		Expect(resp.StatusCode).To(Equal(200))
   238  		Expect(resp.Header.Get("foo")).To(Equal("bar"))
   239  		Expect(resp.Header.Get("lorem")).To(Equal("ipsum"))
   240  	})
   241  
   242  	It("downloads a small file", func() {
   243  		resp, err := client.Get(fmt.Sprintf("https://localhost:%d/prdata", port))
   244  		Expect(err).ToNot(HaveOccurred())
   245  		Expect(resp.StatusCode).To(Equal(200))
   246  		body, err := io.ReadAll(gbytes.TimeoutReader(resp.Body, 5*time.Second))
   247  		Expect(err).ToNot(HaveOccurred())
   248  		Expect(body).To(Equal(PRData))
   249  	})
   250  
   251  	It("downloads a large file", func() {
   252  		resp, err := client.Get(fmt.Sprintf("https://localhost:%d/prdatalong", port))
   253  		Expect(err).ToNot(HaveOccurred())
   254  		Expect(resp.StatusCode).To(Equal(200))
   255  		body, err := io.ReadAll(gbytes.TimeoutReader(resp.Body, 20*time.Second))
   256  		Expect(err).ToNot(HaveOccurred())
   257  		Expect(body).To(Equal(PRDataLong))
   258  	})
   259  
   260  	It("downloads many hellos", func() {
   261  		const num = 150
   262  
   263  		for i := 0; i < num; i++ {
   264  			resp, err := client.Get(fmt.Sprintf("https://localhost:%d/hello", port))
   265  			Expect(err).ToNot(HaveOccurred())
   266  			Expect(resp.StatusCode).To(Equal(200))
   267  			body, err := io.ReadAll(gbytes.TimeoutReader(resp.Body, 3*time.Second))
   268  			Expect(err).ToNot(HaveOccurred())
   269  			Expect(string(body)).To(Equal("Hello, World!\n"))
   270  		}
   271  	})
   272  
   273  	It("downloads many files, if the response is not read", func() {
   274  		const num = 150
   275  
   276  		for i := 0; i < num; i++ {
   277  			resp, err := client.Get(fmt.Sprintf("https://localhost:%d/prdata", port))
   278  			Expect(err).ToNot(HaveOccurred())
   279  			Expect(resp.StatusCode).To(Equal(200))
   280  			Expect(resp.Body.Close()).To(Succeed())
   281  		}
   282  	})
   283  
   284  	It("posts a small message", func() {
   285  		resp, err := client.Post(
   286  			fmt.Sprintf("https://localhost:%d/echo", port),
   287  			"text/plain",
   288  			bytes.NewReader([]byte("Hello, world!")),
   289  		)
   290  		Expect(err).ToNot(HaveOccurred())
   291  		Expect(resp.StatusCode).To(Equal(200))
   292  		body, err := io.ReadAll(gbytes.TimeoutReader(resp.Body, 5*time.Second))
   293  		Expect(err).ToNot(HaveOccurred())
   294  		Expect(body).To(Equal([]byte("Hello, world!")))
   295  	})
   296  
   297  	It("uploads a file", func() {
   298  		resp, err := client.Post(
   299  			fmt.Sprintf("https://localhost:%d/echo", port),
   300  			"text/plain",
   301  			bytes.NewReader(PRData),
   302  		)
   303  		Expect(err).ToNot(HaveOccurred())
   304  		Expect(resp.StatusCode).To(Equal(200))
   305  		body, err := io.ReadAll(gbytes.TimeoutReader(resp.Body, 5*time.Second))
   306  		Expect(err).ToNot(HaveOccurred())
   307  		Expect(body).To(Equal(PRData))
   308  	})
   309  
   310  	It("uses gzip compression", func() {
   311  		mux.HandleFunc("/gzipped/hello", func(w http.ResponseWriter, r *http.Request) {
   312  			defer GinkgoRecover()
   313  			Expect(r.Header.Get("Accept-Encoding")).To(Equal("gzip"))
   314  			w.Header().Set("Content-Encoding", "gzip")
   315  			w.Header().Set("foo", "bar")
   316  
   317  			gw := gzip.NewWriter(w)
   318  			defer gw.Close()
   319  			gw.Write([]byte("Hello, World!\n"))
   320  		})
   321  
   322  		client.Transport.(*http3.RoundTripper).DisableCompression = false
   323  		resp, err := client.Get(fmt.Sprintf("https://localhost:%d/gzipped/hello", port))
   324  		Expect(err).ToNot(HaveOccurred())
   325  		Expect(resp.StatusCode).To(Equal(200))
   326  		Expect(resp.Uncompressed).To(BeTrue())
   327  
   328  		body, err := io.ReadAll(gbytes.TimeoutReader(resp.Body, 3*time.Second))
   329  		Expect(err).ToNot(HaveOccurred())
   330  		Expect(string(body)).To(Equal("Hello, World!\n"))
   331  	})
   332  
   333  	It("handles context cancellations", func() {
   334  		mux.HandleFunc("/cancel", func(w http.ResponseWriter, r *http.Request) {
   335  			<-r.Context().Done()
   336  		})
   337  
   338  		ctx, cancel := context.WithCancel(context.Background())
   339  		req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("https://localhost:%d/cancel", port), nil)
   340  		Expect(err).ToNot(HaveOccurred())
   341  		time.AfterFunc(50*time.Millisecond, cancel)
   342  
   343  		_, err = client.Do(req)
   344  		Expect(err).To(HaveOccurred())
   345  		Expect(err).To(MatchError(context.Canceled))
   346  	})
   347  
   348  	It("cancels requests", func() {
   349  		handlerCalled := make(chan struct{})
   350  		mux.HandleFunc("/cancel", func(w http.ResponseWriter, r *http.Request) {
   351  			defer GinkgoRecover()
   352  			defer close(handlerCalled)
   353  			for {
   354  				if _, err := w.Write([]byte("foobar")); err != nil {
   355  					Expect(r.Context().Done()).To(BeClosed())
   356  					var http3Err *http3.Error
   357  					Expect(errors.As(err, &http3Err)).To(BeTrue())
   358  					Expect(http3Err.ErrorCode).To(Equal(http3.ErrCode(0x10c)))
   359  					Expect(http3Err.Error()).To(Equal("H3_REQUEST_CANCELLED"))
   360  					return
   361  				}
   362  			}
   363  		})
   364  
   365  		req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("https://localhost:%d/cancel", port), nil)
   366  		Expect(err).ToNot(HaveOccurred())
   367  		ctx, cancel := context.WithCancel(context.Background())
   368  		req = req.WithContext(ctx)
   369  		resp, err := client.Do(req)
   370  		Expect(err).ToNot(HaveOccurred())
   371  		Expect(resp.StatusCode).To(Equal(200))
   372  		cancel()
   373  		Eventually(handlerCalled).Should(BeClosed())
   374  		_, err = resp.Body.Read([]byte{0})
   375  		var http3Err *http3.Error
   376  		Expect(errors.As(err, &http3Err)).To(BeTrue())
   377  		Expect(http3Err.ErrorCode).To(Equal(http3.ErrCode(0x10c)))
   378  		Expect(http3Err.Error()).To(Equal("H3_REQUEST_CANCELLED (local)"))
   379  	})
   380  
   381  	It("allows streamed HTTP requests", func() {
   382  		done := make(chan struct{})
   383  		mux.HandleFunc("/echoline", func(w http.ResponseWriter, r *http.Request) {
   384  			defer GinkgoRecover()
   385  			defer close(done)
   386  			w.WriteHeader(200)
   387  			w.(http.Flusher).Flush()
   388  			reader := bufio.NewReader(r.Body)
   389  			for {
   390  				msg, err := reader.ReadString('\n')
   391  				if err != nil {
   392  					return
   393  				}
   394  				_, err = w.Write([]byte(msg))
   395  				Expect(err).ToNot(HaveOccurred())
   396  				w.(http.Flusher).Flush()
   397  			}
   398  		})
   399  
   400  		r, w := io.Pipe()
   401  		req, err := http.NewRequest(http.MethodPut, fmt.Sprintf("https://localhost:%d/echoline", port), r)
   402  		Expect(err).ToNot(HaveOccurred())
   403  		rsp, err := client.Do(req)
   404  		Expect(err).ToNot(HaveOccurred())
   405  		Expect(rsp.StatusCode).To(Equal(200))
   406  
   407  		reader := bufio.NewReader(rsp.Body)
   408  		for i := 0; i < 5; i++ {
   409  			msg := fmt.Sprintf("Hello world, %d!\n", i)
   410  			fmt.Fprint(w, msg)
   411  			msgRcvd, err := reader.ReadString('\n')
   412  			Expect(err).ToNot(HaveOccurred())
   413  			Expect(msgRcvd).To(Equal(msg))
   414  		}
   415  		Expect(req.Body.Close()).To(Succeed())
   416  		Eventually(done).Should(BeClosed())
   417  	})
   418  
   419  	It("allows taking over the stream", func() {
   420  		handlerCalled := make(chan struct{})
   421  		mux.HandleFunc("/httpstreamer", func(w http.ResponseWriter, r *http.Request) {
   422  			defer GinkgoRecover()
   423  			close(handlerCalled)
   424  			w.WriteHeader(http.StatusOK)
   425  
   426  			str := w.(http3.HTTPStreamer).HTTPStream()
   427  			str.Write([]byte("foobar"))
   428  
   429  			// Do this in a Go routine, so that the handler returns early.
   430  			// This way, we can also check that the HTTP/3 doesn't close the stream.
   431  			go func() {
   432  				defer GinkgoRecover()
   433  				_, err := io.Copy(str, str)
   434  				Expect(err).ToNot(HaveOccurred())
   435  				Expect(str.Close()).To(Succeed())
   436  			}()
   437  		})
   438  
   439  		req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("https://localhost:%d/httpstreamer", port), nil)
   440  		Expect(err).ToNot(HaveOccurred())
   441  		tlsConf := getTLSClientConfigWithoutServerName()
   442  		tlsConf.NextProtos = []string{http3.NextProtoH3}
   443  		conn, err := quic.DialAddr(
   444  			context.Background(),
   445  			fmt.Sprintf("localhost:%d", port),
   446  			tlsConf,
   447  			getQuicConfig(nil),
   448  		)
   449  		Expect(err).ToNot(HaveOccurred())
   450  		defer conn.CloseWithError(0, "")
   451  		rt := http3.SingleDestinationRoundTripper{Connection: conn}
   452  		str, err := rt.OpenRequestStream(context.Background())
   453  		Expect(err).ToNot(HaveOccurred())
   454  		Expect(str.SendRequestHeader(req)).To(Succeed())
   455  		// make sure the request is received (and not stuck in some buffer, for example)
   456  		Eventually(handlerCalled).Should(BeClosed())
   457  
   458  		rsp, err := str.ReadResponse()
   459  		Expect(err).ToNot(HaveOccurred())
   460  		Expect(rsp.StatusCode).To(Equal(200))
   461  
   462  		b := make([]byte, 6)
   463  		_, err = io.ReadFull(str, b)
   464  		Expect(err).ToNot(HaveOccurred())
   465  		Expect(b).To(Equal([]byte("foobar")))
   466  
   467  		data := GeneratePRData(8 * 1024)
   468  		_, err = str.Write(data)
   469  		Expect(err).ToNot(HaveOccurred())
   470  		Expect(str.Close()).To(Succeed())
   471  		repl, err := io.ReadAll(str)
   472  		Expect(err).ToNot(HaveOccurred())
   473  		Expect(repl).To(Equal(data))
   474  	})
   475  
   476  	It("serves QUIC connections", func() {
   477  		tlsConf := getTLSConfig()
   478  		tlsConf.NextProtos = []string{http3.NextProtoH3}
   479  		ln, err := quic.ListenAddr("localhost:0", tlsConf, getQuicConfig(nil))
   480  		Expect(err).ToNot(HaveOccurred())
   481  		defer ln.Close()
   482  		done := make(chan struct{})
   483  		go func() {
   484  			defer GinkgoRecover()
   485  			defer close(done)
   486  			conn, err := ln.Accept(context.Background())
   487  			Expect(err).ToNot(HaveOccurred())
   488  			server.ServeQUICConn(conn) // returns once the client closes
   489  		}()
   490  
   491  		resp, err := client.Get(fmt.Sprintf("https://localhost:%d/hello", ln.Addr().(*net.UDPAddr).Port))
   492  		Expect(err).ToNot(HaveOccurred())
   493  		Expect(resp.StatusCode).To(Equal(http.StatusOK))
   494  		client.Transport.(io.Closer).Close()
   495  		Eventually(done).Should(BeClosed())
   496  	})
   497  
   498  	It("supports read deadlines", func() {
   499  		mux.HandleFunc("/read-deadline", func(w http.ResponseWriter, r *http.Request) {
   500  			defer GinkgoRecover()
   501  			rc := http.NewResponseController(w)
   502  			Expect(rc.SetReadDeadline(time.Now().Add(deadlineDelay))).To(Succeed())
   503  
   504  			body, err := io.ReadAll(r.Body)
   505  			Expect(err).To(MatchError(os.ErrDeadlineExceeded))
   506  			Expect(body).To(ContainSubstring("aa"))
   507  
   508  			w.Write([]byte("ok"))
   509  		})
   510  
   511  		expectedEnd := time.Now().Add(deadlineDelay)
   512  		resp, err := client.Post(
   513  			fmt.Sprintf("https://localhost:%d/read-deadline", port),
   514  			"text/plain",
   515  			neverEnding('a'),
   516  		)
   517  		Expect(err).ToNot(HaveOccurred())
   518  		Expect(resp.StatusCode).To(Equal(200))
   519  
   520  		body, err := io.ReadAll(gbytes.TimeoutReader(resp.Body, 2*deadlineDelay))
   521  		Expect(err).ToNot(HaveOccurred())
   522  		Expect(time.Now().After(expectedEnd)).To(BeTrue())
   523  		Expect(string(body)).To(Equal("ok"))
   524  	})
   525  
   526  	It("supports write deadlines", func() {
   527  		mux.HandleFunc("/write-deadline", func(w http.ResponseWriter, r *http.Request) {
   528  			defer GinkgoRecover()
   529  			rc := http.NewResponseController(w)
   530  			Expect(rc.SetWriteDeadline(time.Now().Add(deadlineDelay))).To(Succeed())
   531  
   532  			_, err := io.Copy(w, neverEnding('a'))
   533  			Expect(err).To(MatchError(os.ErrDeadlineExceeded))
   534  		})
   535  
   536  		expectedEnd := time.Now().Add(deadlineDelay)
   537  
   538  		resp, err := client.Get(fmt.Sprintf("https://localhost:%d/write-deadline", port))
   539  		Expect(err).ToNot(HaveOccurred())
   540  		Expect(resp.StatusCode).To(Equal(200))
   541  
   542  		body, err := io.ReadAll(gbytes.TimeoutReader(resp.Body, 2*deadlineDelay))
   543  		Expect(err).ToNot(HaveOccurred())
   544  		Expect(time.Now().After(expectedEnd)).To(BeTrue())
   545  		Expect(string(body)).To(ContainSubstring("aa"))
   546  	})
   547  
   548  	It("sets remote address", func() {
   549  		mux.HandleFunc("/remote-addr", func(w http.ResponseWriter, r *http.Request) {
   550  			defer GinkgoRecover()
   551  			_, ok := r.Context().Value(http3.RemoteAddrContextKey).(net.Addr)
   552  			Expect(ok).To(BeTrue())
   553  		})
   554  
   555  		resp, err := client.Get(fmt.Sprintf("https://localhost:%d/remote-addr", port))
   556  		Expect(err).ToNot(HaveOccurred())
   557  		Expect(resp.StatusCode).To(Equal(200))
   558  	})
   559  
   560  	It("sets conn context", func() {
   561  		type ctxKey int
   562  		var tracingID quic.ConnectionTracingID
   563  		server.ConnContext = func(ctx context.Context, c quic.Connection) context.Context {
   564  			serv, ok := ctx.Value(http3.ServerContextKey).(*http3.Server)
   565  			Expect(ok).To(BeTrue())
   566  			Expect(serv).To(Equal(server))
   567  
   568  			ctx = context.WithValue(ctx, ctxKey(0), "Hello")
   569  			ctx = context.WithValue(ctx, ctxKey(1), c)
   570  			tracingID = c.Context().Value(quic.ConnectionTracingKey).(quic.ConnectionTracingID)
   571  			return ctx
   572  		}
   573  		mux.HandleFunc("/conn-context", func(w http.ResponseWriter, r *http.Request) {
   574  			defer GinkgoRecover()
   575  			v, ok := r.Context().Value(ctxKey(0)).(string)
   576  			Expect(ok).To(BeTrue())
   577  			Expect(v).To(Equal("Hello"))
   578  
   579  			c, ok := r.Context().Value(ctxKey(1)).(quic.Connection)
   580  			Expect(ok).To(BeTrue())
   581  			Expect(c).ToNot(BeNil())
   582  
   583  			serv, ok := r.Context().Value(http3.ServerContextKey).(*http3.Server)
   584  			Expect(ok).To(BeTrue())
   585  			Expect(serv).To(Equal(server))
   586  
   587  			id, ok := r.Context().Value(quic.ConnectionTracingKey).(quic.ConnectionTracingID)
   588  			Expect(ok).To(BeTrue())
   589  			Expect(id).To(Equal(tracingID))
   590  		})
   591  
   592  		resp, err := client.Get(fmt.Sprintf("https://localhost:%d/conn-context", port))
   593  		Expect(err).ToNot(HaveOccurred())
   594  		Expect(resp.StatusCode).To(Equal(200))
   595  	})
   596  
   597  	It("checks the server's settings", func() {
   598  		tlsConf := tlsClientConfigWithoutServerName.Clone()
   599  		tlsConf.NextProtos = []string{http3.NextProtoH3}
   600  		conn, err := quic.DialAddr(
   601  			context.Background(),
   602  			fmt.Sprintf("localhost:%d", port),
   603  			tlsConf,
   604  			getQuicConfig(nil),
   605  		)
   606  		Expect(err).ToNot(HaveOccurred())
   607  		defer conn.CloseWithError(0, "")
   608  		rt := http3.SingleDestinationRoundTripper{Connection: conn}
   609  		hconn := rt.Start()
   610  		Eventually(hconn.ReceivedSettings(), 5*time.Second, 10*time.Millisecond).Should(BeClosed())
   611  		settings := hconn.Settings()
   612  		Expect(settings.EnableExtendedConnect).To(BeTrue())
   613  		Expect(settings.EnableDatagrams).To(BeFalse())
   614  		Expect(settings.Other).To(BeEmpty())
   615  	})
   616  
   617  	It("receives the client's settings", func() {
   618  		settingsChan := make(chan *http3.Settings, 1)
   619  		mux.HandleFunc("/settings", func(w http.ResponseWriter, r *http.Request) {
   620  			defer GinkgoRecover()
   621  			conn := w.(http3.Hijacker).Connection()
   622  			Eventually(conn.ReceivedSettings(), 5*time.Second, 10*time.Millisecond).Should(BeClosed())
   623  			settingsChan <- conn.Settings()
   624  			w.WriteHeader(http.StatusOK)
   625  		})
   626  
   627  		rt = &http3.RoundTripper{
   628  			TLSClientConfig: getTLSClientConfigWithoutServerName(),
   629  			QUICConfig: getQuicConfig(&quic.Config{
   630  				MaxIdleTimeout:  10 * time.Second,
   631  				EnableDatagrams: true,
   632  			}),
   633  			EnableDatagrams:    true,
   634  			AdditionalSettings: map[uint64]uint64{1337: 42},
   635  		}
   636  		req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("https://localhost:%d/settings", port), nil)
   637  		Expect(err).ToNot(HaveOccurred())
   638  
   639  		_, err = rt.RoundTrip(req)
   640  		Expect(err).ToNot(HaveOccurred())
   641  		var settings *http3.Settings
   642  		Expect(settingsChan).To(Receive(&settings))
   643  		Expect(settings).ToNot(BeNil())
   644  		Expect(settings.EnableDatagrams).To(BeTrue())
   645  		Expect(settings.EnableExtendedConnect).To(BeFalse())
   646  		Expect(settings.Other).To(HaveKeyWithValue(uint64(1337), uint64(42)))
   647  	})
   648  
   649  	It("processes 1xx response", func() {
   650  		header1 := "</style.css>; rel=preload; as=style"
   651  		header2 := "</script.js>; rel=preload; as=script"
   652  		data := "1xx-test-data"
   653  		mux.HandleFunc("/103-early-data", func(w http.ResponseWriter, r *http.Request) {
   654  			defer GinkgoRecover()
   655  			w.Header().Add("Link", header1)
   656  			w.Header().Add("Link", header2)
   657  			w.WriteHeader(http.StatusEarlyHints)
   658  			n, err := w.Write([]byte(data))
   659  			Expect(err).NotTo(HaveOccurred())
   660  			Expect(n).To(Equal(len(data)))
   661  			w.WriteHeader(http.StatusOK)
   662  		})
   663  
   664  		var (
   665  			cnt    int
   666  			status int
   667  			hdr    textproto.MIMEHeader
   668  		)
   669  		ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
   670  			Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
   671  				hdr = header
   672  				status = code
   673  				cnt++
   674  				return nil
   675  			},
   676  		})
   677  
   678  		req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("https://localhost:%d/103-early-data", port), nil)
   679  		Expect(err).ToNot(HaveOccurred())
   680  		resp, err := client.Do(req)
   681  		Expect(err).ToNot(HaveOccurred())
   682  		Expect(resp.StatusCode).To(Equal(http.StatusOK))
   683  		body, err := io.ReadAll(resp.Body)
   684  		Expect(err).ToNot(HaveOccurred())
   685  		Expect(string(body)).To(Equal(data))
   686  		Expect(status).To(Equal(http.StatusEarlyHints))
   687  		Expect(hdr).To(HaveKeyWithValue("Link", []string{header1, header2}))
   688  		Expect(cnt).To(Equal(1))
   689  		Expect(resp.Header).To(HaveKeyWithValue("Link", []string{header1, header2}))
   690  		Expect(resp.Body.Close()).To(Succeed())
   691  	})
   692  
   693  	It("processes 1xx terminal response", func() {
   694  		mux.HandleFunc("/101-switch-protocols", func(w http.ResponseWriter, r *http.Request) {
   695  			defer GinkgoRecover()
   696  			w.Header().Add("Connection", "upgrade")
   697  			w.Header().Add("Upgrade", "proto")
   698  			w.WriteHeader(http.StatusSwitchingProtocols)
   699  		})
   700  
   701  		var (
   702  			cnt    int
   703  			status int
   704  		)
   705  		ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
   706  			Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
   707  				status = code
   708  				cnt++
   709  				return nil
   710  			},
   711  		})
   712  
   713  		req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("https://localhost:%d/101-switch-protocols", port), nil)
   714  		Expect(err).ToNot(HaveOccurred())
   715  		resp, err := client.Do(req)
   716  		Expect(err).ToNot(HaveOccurred())
   717  		Expect(resp.StatusCode).To(Equal(http.StatusSwitchingProtocols))
   718  		Expect(resp.Header).To(HaveKeyWithValue("Connection", []string{"upgrade"}))
   719  		Expect(resp.Header).To(HaveKeyWithValue("Upgrade", []string{"proto"}))
   720  		Expect(status).To(Equal(0))
   721  		Expect(cnt).To(Equal(0))
   722  	})
   723  
   724  	Context("HTTP datagrams", func() {
   725  		openDatagramStream := func(h string) (_ http3.RequestStream, closeFn func()) {
   726  			tlsConf := getTLSClientConfigWithoutServerName()
   727  			tlsConf.NextProtos = []string{http3.NextProtoH3}
   728  			conn, err := quic.DialAddr(
   729  				context.Background(),
   730  				fmt.Sprintf("localhost:%d", port),
   731  				tlsConf,
   732  				getQuicConfig(&quic.Config{EnableDatagrams: true}),
   733  			)
   734  			Expect(err).ToNot(HaveOccurred())
   735  
   736  			rt := &http3.SingleDestinationRoundTripper{
   737  				Connection:      conn,
   738  				EnableDatagrams: true,
   739  			}
   740  			str, err := rt.OpenRequestStream(context.Background())
   741  			Expect(err).ToNot(HaveOccurred())
   742  			u, err := url.Parse(h)
   743  			Expect(err).ToNot(HaveOccurred())
   744  			req := &http.Request{
   745  				Method: http.MethodConnect,
   746  				Proto:  "datagrams",
   747  				Host:   u.Host,
   748  				URL:    u,
   749  			}
   750  			Expect(str.SendRequestHeader(req)).To(Succeed())
   751  
   752  			rsp, err := str.ReadResponse()
   753  			Expect(err).ToNot(HaveOccurred())
   754  			Expect(rsp.StatusCode).To(Equal(http.StatusOK))
   755  			return str, func() { conn.CloseWithError(0, "") }
   756  		}
   757  
   758  		It("sends an receives HTTP datagrams", func() {
   759  			errChan := make(chan error, 1)
   760  			const num = 5
   761  			datagramChan := make(chan struct{}, num)
   762  			mux.HandleFunc("/datagrams", func(w http.ResponseWriter, r *http.Request) {
   763  				defer GinkgoRecover()
   764  				Expect(r.Method).To(Equal(http.MethodConnect))
   765  				conn := w.(http3.Hijacker).Connection()
   766  				Eventually(conn.ReceivedSettings()).Should(BeClosed())
   767  				Expect(conn.Settings().EnableDatagrams).To(BeTrue())
   768  				w.WriteHeader(http.StatusOK)
   769  
   770  				str := w.(http3.HTTPStreamer).HTTPStream()
   771  				go str.Read([]byte{0}) // need to continue reading from stream to observe state transitions
   772  
   773  				for {
   774  					if _, err := str.ReceiveDatagram(context.Background()); err != nil {
   775  						errChan <- err
   776  						return
   777  					}
   778  					datagramChan <- struct{}{}
   779  				}
   780  			})
   781  
   782  			str, closeFn := openDatagramStream(fmt.Sprintf("https://localhost:%d/datagrams", port))
   783  			defer closeFn()
   784  
   785  			for i := 0; i < num; i++ {
   786  				b := make([]byte, 8)
   787  				binary.BigEndian.PutUint64(b, uint64(i))
   788  				Expect(str.SendDatagram(bytes.Repeat(b, 100))).To(Succeed())
   789  			}
   790  			var count int
   791  		loop:
   792  			for {
   793  				select {
   794  				case <-datagramChan:
   795  					count++
   796  					if count >= num*4/5 {
   797  						break loop
   798  					}
   799  				case err := <-errChan:
   800  					Fail(fmt.Sprintf("receiving datagrams failed: %s", err))
   801  				}
   802  			}
   803  			str.CancelWrite(42)
   804  
   805  			var resetErr error
   806  			Eventually(errChan).Should(Receive(&resetErr))
   807  			Expect(resetErr.(*quic.StreamError).ErrorCode).To(BeEquivalentTo(42))
   808  		})
   809  
   810  		It("closes the send direction", func() {
   811  			errChan := make(chan error, 1)
   812  			datagramChan := make(chan []byte, 1)
   813  			mux.HandleFunc("/datagrams", func(w http.ResponseWriter, r *http.Request) {
   814  				defer GinkgoRecover()
   815  				conn := w.(http3.Hijacker).Connection()
   816  				Eventually(conn.ReceivedSettings()).Should(BeClosed())
   817  				Expect(conn.Settings().EnableDatagrams).To(BeTrue())
   818  				w.WriteHeader(http.StatusOK)
   819  
   820  				str := w.(http3.HTTPStreamer).HTTPStream()
   821  				go str.Read([]byte{0}) // need to continue reading from stream to observe state transitions
   822  
   823  				for {
   824  					data, err := str.ReceiveDatagram(context.Background())
   825  					if err != nil {
   826  						errChan <- err
   827  						return
   828  					}
   829  					datagramChan <- data
   830  				}
   831  			})
   832  
   833  			str, closeFn := openDatagramStream(fmt.Sprintf("https://localhost:%d/datagrams", port))
   834  			defer closeFn()
   835  			go str.Read([]byte{0})
   836  
   837  			Expect(str.SendDatagram([]byte("foo"))).To(Succeed())
   838  			Eventually(datagramChan).Should(Receive(Equal([]byte("foo"))))
   839  			// signal that we're done sending
   840  			str.Close()
   841  
   842  			var resetErr error
   843  			Eventually(errChan).Should(Receive(&resetErr))
   844  			Expect(resetErr).To(Equal(io.EOF))
   845  
   846  			// make sure we can't send anymore
   847  			Expect(str.SendDatagram([]byte("foo"))).ToNot(Succeed())
   848  		})
   849  
   850  		It("detecting a stream reset from the server", func() {
   851  			errChan := make(chan error, 1)
   852  			datagramChan := make(chan []byte, 1)
   853  			mux.HandleFunc("/datagrams", func(w http.ResponseWriter, r *http.Request) {
   854  				defer GinkgoRecover()
   855  				conn := w.(http3.Hijacker).Connection()
   856  				Eventually(conn.ReceivedSettings()).Should(BeClosed())
   857  				Expect(conn.Settings().EnableDatagrams).To(BeTrue())
   858  				w.WriteHeader(http.StatusOK)
   859  
   860  				str := w.(http3.HTTPStreamer).HTTPStream()
   861  				go str.Read([]byte{0}) // need to continue reading from stream to observe state transitions
   862  
   863  				for {
   864  					data, err := str.ReceiveDatagram(context.Background())
   865  					if err != nil {
   866  						errChan <- err
   867  						return
   868  					}
   869  					str.CancelRead(42)
   870  					datagramChan <- data
   871  				}
   872  			})
   873  
   874  			str, closeFn := openDatagramStream(fmt.Sprintf("https://localhost:%d/datagrams", port))
   875  			defer closeFn()
   876  			go str.Read([]byte{0})
   877  
   878  			Expect(str.SendDatagram([]byte("foo"))).To(Succeed())
   879  			Eventually(datagramChan).Should(Receive(Equal([]byte("foo"))))
   880  			// signal that we're done sending
   881  
   882  			var resetErr error
   883  			Eventually(errChan).Should(Receive(&resetErr))
   884  			Expect(resetErr).To(Equal(&quic.StreamError{ErrorCode: 42, Remote: false}))
   885  
   886  			// make sure we can't send anymore
   887  			Expect(str.SendDatagram([]byte("foo"))).To(Equal(&quic.StreamError{ErrorCode: 42, Remote: true}))
   888  		})
   889  	})
   890  
   891  	Context("0-RTT", func() {
   892  		runCountingProxy := func(serverPort int, rtt time.Duration) (*quicproxy.QuicProxy, *atomic.Uint32) {
   893  			var num0RTTPackets atomic.Uint32
   894  			proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
   895  				RemoteAddr: fmt.Sprintf("localhost:%d", serverPort),
   896  				DelayPacket: func(_ quicproxy.Direction, data []byte) time.Duration {
   897  					if contains0RTTPacket(data) {
   898  						num0RTTPackets.Add(1)
   899  					}
   900  					return rtt / 2
   901  				},
   902  			})
   903  			Expect(err).ToNot(HaveOccurred())
   904  			return proxy, &num0RTTPackets
   905  		}
   906  
   907  		It("sends 0-RTT GET requests", func() {
   908  			proxy, num0RTTPackets := runCountingProxy(port, scaleDuration(50*time.Millisecond))
   909  			defer proxy.Close()
   910  
   911  			tlsConf := getTLSClientConfigWithoutServerName()
   912  			puts := make(chan string, 10)
   913  			tlsConf.ClientSessionCache = newClientSessionCache(tls.NewLRUClientSessionCache(10), nil, puts)
   914  			rt := &http3.RoundTripper{
   915  				TLSClientConfig: tlsConf,
   916  				QUICConfig: getQuicConfig(&quic.Config{
   917  					MaxIdleTimeout: 10 * time.Second,
   918  				}),
   919  				DisableCompression: true,
   920  			}
   921  			defer rt.Close()
   922  
   923  			mux.HandleFunc("/0rtt", func(w http.ResponseWriter, r *http.Request) {
   924  				w.Write([]byte(strconv.FormatBool(!r.TLS.HandshakeComplete)))
   925  			})
   926  			req, err := http.NewRequest(http3.MethodGet0RTT, fmt.Sprintf("https://localhost:%d/0rtt", proxy.LocalPort()), nil)
   927  			Expect(err).ToNot(HaveOccurred())
   928  			rsp, err := rt.RoundTrip(req)
   929  			Expect(err).ToNot(HaveOccurred())
   930  			Expect(rsp.StatusCode).To(BeEquivalentTo(200))
   931  			data, err := io.ReadAll(rsp.Body)
   932  			Expect(err).ToNot(HaveOccurred())
   933  			Expect(string(data)).To(Equal("false"))
   934  			Expect(num0RTTPackets.Load()).To(BeZero())
   935  			Eventually(puts).Should(Receive())
   936  
   937  			rt2 := &http3.RoundTripper{
   938  				TLSClientConfig:    rt.TLSClientConfig,
   939  				QUICConfig:         rt.QUICConfig,
   940  				DisableCompression: true,
   941  			}
   942  			defer rt2.Close()
   943  			rsp, err = rt2.RoundTrip(req)
   944  			Expect(err).ToNot(HaveOccurred())
   945  			Expect(rsp.StatusCode).To(BeEquivalentTo(200))
   946  			data, err = io.ReadAll(rsp.Body)
   947  			Expect(err).ToNot(HaveOccurred())
   948  			Expect(string(data)).To(Equal("true"))
   949  			Expect(num0RTTPackets.Load()).To(BeNumerically(">", 0))
   950  		})
   951  	})
   952  })